

  1. Scaling Back-Propagation by Parallel Scan Algorithm. Shang Wang 1,2, Yifan Bai 1, Gennady Pekhimenko 1,2. The original PPTX file can be downloaded from here.

  2. Executive Summary. The back-propagation (BP) algorithm is widely used in training deep learning (DL) models and is implemented in many DL frameworks (e.g., PyTorch and TensorFlow). Problem: BP imposes a strong sequential dependency along layers during the gradient computations.

  3. Executive Summary. The back-propagation (BP) algorithm is widely used in training deep learning (DL) models and is implemented in many DL frameworks (e.g., PyTorch and TensorFlow). Problem: BP imposes a strong sequential dependency along layers during the gradient computations. Key idea: we propose scaling BP by the Parallel Scan Algorithm (BPPSA): • Reformulate BP into a scan operation. (Illustration: input 1 2 3 4 5 6 7 8 → exclusive scan 0 1 3 6 10 15 21 28.)

  4. Executive Summary. The back-propagation (BP) algorithm is widely used in training deep learning (DL) models and is implemented in many DL frameworks (e.g., PyTorch and TensorFlow). Problem: BP imposes a strong sequential dependency along layers during the gradient computations. Key idea: we propose scaling BP by the Parallel Scan Algorithm (BPPSA): • Reformulate BP into a scan operation. • Scale it with a customized parallel algorithm. Key results: Θ(log n) vs. Θ(n) steps on parallel systems; up to 108× backward-pass speedup (→ 2.17× overall speedup).

  5. Back-Propagation 1 (BP) Everywhere. 1 Rumelhart et al., “Learning representations by back-propagating errors.”, Nature (1986).

  6. BP’s Strong Sequential Dependency. (Figure: a network of Linear → ReLU → Linear → Loss layers. During the backward pass, the gradient with respect to a layer’s input y⃗ is obtained from the gradient with respect to its output g(y⃗) through the layer’s Jacobian: ∂l/∂y⃗ = (∂g(y⃗)/∂y⃗)ᵀ · ∂l/∂g(y⃗).)

  7. BP’s Strong Sequential Dependency. (Figure: the same Linear → ReLU → Linear → Loss network and Jacobian equation as the previous slide.) Strong sequential dependency along layers.
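To make that dependency concrete, here is a minimal NumPy sketch of BP's backward pass (not the paper's implementation): each layer's gradient is a transposed-Jacobian product with the gradient of the layer above, so step i cannot begin before step i+1 has finished. The layer count, sizes, and random Jacobians below are made up for illustration.

```python
import numpy as np

# Illustrative sizes and random matrices; real layers would supply their own
# Jacobians (or, in practice, Jacobian-transpose-vector products).
rng = np.random.default_rng(0)
sizes = [8, 6, 5, 4, 3]                       # activation sizes y_0 .. y_4
J = [rng.standard_normal((sizes[i + 1], sizes[i])) for i in range(4)]  # J[i] = dy_{i+1}/dy_i

g = rng.standard_normal(sizes[-1])            # dl/dy_4 from the loss layer

grads = [None] * 5
grads[4] = g
for i in reversed(range(4)):                  # step i cannot start before step i+1
    grads[i] = J[i].T @ grads[i + 1]          # dl/dy_i = (dy_{i+1}/dy_i)^T dl/dy_{i+1}

print([v.shape for v in grads])               # one gradient per layer, computed strictly in series
```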

  8. Data Parallel Training. (Figure: replicas 1…i, each holding a full copy of the model and processing its own mini-batch; the strong sequential dependency remains inside every replica.) Respects BP’s strong sequential dependency. Conceptually simple, widely used. Effectively increases the batch size: • Generalization gap 1 • Batch-size scaling limit 2. Constraint: the model must fit in one device. 1 Keskar, Nitish Shirish et al., “On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima.”, ICLR (2017). 2 Shallue, Christopher J. et al., “Measuring the Effects of Data Parallelism on Neural Network Training.”, Journal of Machine Learning Research 20 (2019).

  9. Model Parallel Training. Used when the model cannot fit in one device. BP’s strong sequential dependency limits scalability. Prior works on pipeline parallel training 1,2 mitigate this problem, but have their own limitations: • Linear per-device space complexity. • Trade-off between the “bubble of idleness” and a potential effect on convergence. (Figure: a pipeline of Conv, Conv, …, Linear stages split across devices.) 1 Harlap, Aaron et al., “PipeDream: Fast and Efficient Pipeline Parallel DNN Training.”, SOSP (2019). 2 Huang, Yanping et al., “GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism.”, NeurIPS (2019).

  10. Rethinking BP from an Algorithm Perspective

  11. Rethinking BP from an Algorithm Perspective. • Problems with strong sequential dependency were studied in the past (the 1980s), but in a much simpler context. • We propose scaling Back-Propagation by Parallel Scan Algorithm (BPPSA): • Reformulate BP as a scan operation. • Scale BP by a customized Blelloch scan algorithm. • Leverage sparsity in the Jacobians.

  12. What is a Scan 1 Operation? Binary, associative operator: +. Identity: 0. Input sequence: 1 2 3 4 5 6 7 8. Exclusive scan: 0 1 3 6 10 15 21 28. Compute partial reductions at each step of the sequence. 1 Blelloch, Guy E., “Prefix sums and their applications.”, Technical Report (1990).

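For reference, a minimal single-worker exclusive scan in Python that reproduces the numbers on the slide (the operator, identity, and input are exactly the slide's example):

```python
def exclusive_scan(xs, op, identity):
    """Exclusive scan: out[i] = xs[0] op xs[1] op ... op xs[i-1]."""
    out, acc = [], identity
    for x in xs:
        out.append(acc)      # emit the reduction of everything before x
        acc = op(acc, x)     # then fold x in; n elements -> n steps on one worker
    return out

print(exclusive_scan([1, 2, 3, 4, 5, 6, 7, 8], lambda a, b: a + b, 0))
# [0, 1, 3, 6, 10, 15, 21, 28]
```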

  14. Linear Scan. Step: executing the operator once. Worker (p): an instance of execution, e.g., a core in a multi-core CPU. On a single worker, the scan is performed linearly and takes n steps for n elements. With more workers, can we achieve sublinear steps? (Figure: time vs. number of elements n for a single-worker linear scan over 1…8, producing the partial sums 3, 6, 10, 15, 21, 28 one step at a time.)

  15. Blelloch Scan: ① Up-Sweep Phase. Compute partial sums via a reduction tree. (Figure: the inputs 1 2 3 4 5 6 7 8 are combined pairwise into 3, 7, 11, 15, then into 10 and 26; each node combines children A and B into A+B.)

  16. Blelloch Scan: ② Down-Sweep Phase. Combine partial sums across branches. (Figure: the down-sweep steps run in parallel; at each node with incoming value B and left-child partial sum A, the left child receives B and the right child receives A+B, yielding the exclusive scan 0 1 3 6 10 15 21 28 at the leaves.)

  17. Blelloch Scan: Efficiency. (Figure: the full up-sweep and down-sweep trees over the input 1…8.) Logarithmic number of steps along the critical path: about 2·log₂ n in total, log₂ n for the up-sweep and log₂ n for the down-sweep.
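The sketch below is an illustrative single-process simulation of the Blelloch scan, not the paper's GPU implementation: the up-sweep builds the reduction tree and the down-sweep pushes exclusive prefixes back down. In a parallel implementation, every iteration of each inner loop runs concurrently, which is where the roughly 2·log₂ n critical-path steps come from. It is written for a generic associative operator, so the same routine also works for the noncommutative operator introduced on the next slides.

```python
def blelloch_exclusive_scan(xs, op, identity):
    """Exclusive scan via Blelloch's algorithm (assumes len(xs) is a power of two)."""
    a = list(xs)
    n = len(a)

    # Up-sweep: build partial reductions over a binary tree.
    step = 1
    while step < n:
        for i in range(0, n, 2 * step):          # these combines are independent
            a[i + 2 * step - 1] = op(a[i + step - 1], a[i + 2 * step - 1])
        step *= 2

    # Down-sweep: push the reduction of "everything to the left" back down the tree.
    a[n - 1] = identity
    step = n // 2
    while step >= 1:
        for i in range(0, n, 2 * step):          # also independent within a level
            t = a[i + step - 1]
            a[i + step - 1] = a[i + 2 * step - 1]             # left child <- incoming value
            a[i + 2 * step - 1] = op(a[i + 2 * step - 1], t)  # right child <- incoming op left
        step //= 2
    return a

print(blelloch_exclusive_scan([1, 2, 3, 4, 5, 6, 7, 8], lambda a, b: a + b, 0))
# [0, 1, 3, 6, 10, 15, 21, 28]
```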

  18. Reformulate BP as a Scan Operation. Binary, associative operator: +. Identity: 0. Input sequence: 1 2 3 4 5 6 7 8. Exclusive scan: 0 1 3 6 10 15 21 28. Key insight: matrix multiplication in BP is also binary & associative!

  19. Reformulate BP as a Scan Operation. Let G_i = ∂l/∂y⃗_i (the gradient of the loss with respect to layer i’s activations) and J_{i+1} = ∂y⃗_{i+1}/∂y⃗_i (the Jacobian of layer i+1). Binary, associative operator: A ◊ B = BA. Identity: I. Input sequence: G_7, J_7, J_6, J_5, J_4, J_3, J_2, J_1. Exclusive scan: I, G_7, G_6, G_5, G_4, G_3, G_2, G_1. Key insight: matrix multiplication in BP is also binary & associative!
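A small self-contained check of this reformulation (illustrative shapes and random Jacobians, not the paper's code). The slide writes the products without transposes; in this sketch the sequence elements are the transposed layer Jacobians so that gradients stay ordinary column vectors. Every prefix reduction of the sequence under ◊ equals one of BP's gradients, and because ◊ is associative the prefixes can be regrouped freely, which is exactly what the Blelloch scan exploits.

```python
import numpy as np
from functools import reduce

rng = np.random.default_rng(0)
sizes = [5, 4, 4, 3, 2]                        # activations y_0 .. y_4 (made up)
Jt = [rng.standard_normal((sizes[i], sizes[i + 1])) for i in range(4)]  # Jt[i] = (dy_{i+1}/dy_i)^T
G4 = rng.standard_normal((sizes[-1], 1))       # dl/dy_4 as a column vector

# Sequential BP: G_i = Jt_i @ G_{i+1}, one layer after another.
bp = [G4]
for i in reversed(range(4)):
    bp.append(Jt[i] @ bp[-1])                  # [G_4, G_3, G_2, G_1, G_0]

# Scan view: operator A ◊ B = B @ A over the sequence [G_4, Jt_3, Jt_2, Jt_1, Jt_0].
op = lambda A, B: B @ A
seq = [G4, Jt[3], Jt[2], Jt[1], Jt[0]]

# Every prefix reduction of the sequence is one of BP's gradients ...
for k in range(1, len(seq) + 1):
    assert np.allclose(reduce(op, seq[:k]), bp[k - 1])

# ... and ◊ is associative, so prefixes may be combined in any grouping.
assert np.allclose(op(op(G4, Jt[3]), Jt[2]), op(G4, op(Jt[3], Jt[2])))
print("prefix reductions under the operator reproduce BP's gradients")
```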

  20. Scale BP by Blelloch Scan. (Figure: the numeric Blelloch scan example, up-sweep and down-sweep over 1…8, shown again for reference.) Logarithmic number of steps along the critical path!

  21. Scale BP by Blelloch Scan. (Figure: the same up-sweep/down-sweep tree applied to the sequence G_7, J_7, …, J_1; the up-sweep builds composite Jacobians such as J_5:6, J_3:4, J_1:2 and J_1:4, and the down-sweep produces the gradients G_7, G_6, …, G_1.) Logarithmic number of steps along the critical path!

  22. Scale BP by Blelloch Scan. (Figure: the same tree, highlighting one down-sweep step with operands A and B.) Matrix multiplications are noncommutative: the step that produced A+B in the numeric example must now compute the product in the right order (BA, not AB).

  23. Scale BP by Blelloch Scan. (Figure: the corrected down-sweep step, producing B and BA.) Matrix multiplications are noncommutative.
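A two-line reminder of why the operand order inside a down-sweep step matters (random matrices, for illustration only):

```python
import numpy as np

rng = np.random.default_rng(0)
A, B = rng.standard_normal((3, 3)), rng.standard_normal((3, 3))
# Unlike +, matrix multiplication is noncommutative: swapping the operands in a
# down-sweep step would produce a different (wrong) composite Jacobian.
print(np.allclose(A @ B, B @ A))   # False
```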

  24. Reconstructs the Original BP Exactly. Our method produces gradients mathematically equivalent to BP. The Jacobians are multiplied in a different order → numerical differences. We empirically show that such differences do not affect convergence. (Figure: training curves for LeNet-5 on CIFAR-10; baseline: PyTorch Autograd.)

  25. Jacobians are Memory & Compute Hungry. A full Jacobian can be prohibitively expensive to handle. • e.g., for the 1st convolution in VGG-11 on CIFAR-10 images, the input y⃗ has 3072 elements and the output g(y⃗) has 65536 elements, so the full Jacobian ∂g(y⃗)/∂y⃗ occupies 768 MB of memory.
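The 768 MB figure follows directly from the layer's shapes, assuming one single-precision (4-byte) float per Jacobian entry:

```python
in_elems = 3 * 32 * 32                       # CIFAR-10 input to VGG-11's first conv: 3,072
out_elems = 64 * 32 * 32                     # its output feature map: 65,536
jacobian_bytes = out_elems * in_elems * 4    # one float32 per Jacobian entry
print(jacobian_bytes / 2**20)                # 768.0, matching the slide's 768 MB
```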

  26. Jacobians are Memory & Compute Hungry. A full Jacobian can be prohibitively expensive to handle. • e.g., the 1st convolution in VGG-11 on CIFAR-10 images occupies 768 MB of memory. • Generated one row at a time by passing basis vectors into Op_Grad() (the VJP function). (Figure: a one-hot basis vector fed to Conv2d_Grad(·), producing one row of the Jacobian.)

  27. Jacobians are Memory & Compute Hungry. A full Jacobian can be prohibitively expensive to handle. • e.g., the 1st convolution in VGG-11 on CIFAR-10 images occupies 768 MB of memory. • Generated one row at a time by passing basis vectors into Op_Grad() (the VJP function). (Figure: each basis vector fed to Conv2d_Grad(·) yields one row; repeating this for every basis vector materializes the full Jacobian.)
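A hedged PyTorch sketch of the row-at-a-time construction described above: each VJP call with a one-hot vector in the output space yields one row of the Jacobian. Here torch.autograd.grad stands in for the slide's Op_Grad(), and a tiny linear layer stands in for the convolution; sizes are illustrative only.

```python
import torch

# A tiny op standing in for the first convolution (illustrative sizes only).
op = torch.nn.Linear(4, 3, bias=False)
x = torch.randn(4, requires_grad=True)
y = op(x)                                    # forward pass

rows = []
for i in range(y.numel()):
    e_i = torch.zeros_like(y)
    e_i[i] = 1.0                             # basis vector in the output space
    # VJP: e_i^T (dy/dx), i.e. row i of the Jacobian.
    (row,) = torch.autograd.grad(y, x, grad_outputs=e_i, retain_graph=True)
    rows.append(row)

J = torch.stack(rows)                        # shape (3, 4): the full Jacobian
print(torch.allclose(J, op.weight))          # True for a linear op
```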
