Distributed Optimization of CNNs and RNNs


  1. Distributed Optimization of CNNs and RNNs GTC 2015 William Chan williamchan.ca williamchan@cmu.edu March 19, 2015

  2. Outline 1. Motivation 2. Distributed ASGD 3. CNNs 4. RNNs 5. Conclusion Carnegie Mellon University 2

  3. Motivation ◮ Why do we need distributed training? Carnegie Mellon University 3

  4. Motivation ◮ More data → better models ◮ More data → longer training times Example: Baidu Deep Speech ◮ Synthetic training data generated by overlapping noise onto clean speech ◮ Synthetic training data → effectively unlimited training data Carnegie Mellon University 4

  5. Motivation ◮ Complex models (e.g., CNNs and RNNs) are better than simple models (DNNs) ◮ Complex models → longer training times Example: GoogLeNet ◮ A 22-layer-deep CNN Carnegie Mellon University 5

  6. GoogLeNet [Full architecture diagram. Caption: "Figure 3: GoogLeNet network with all the bells and whistles." Repeated modules of 1x1, 3x3, and 5x5 convolutions with max pooling and depth concatenation, topped by average pooling, fully connected layers, and a softmax, plus two auxiliary softmax classifiers.] Carnegie Mellon University 6

  7. Distributed Asynchronous Stochastic Gradient Descent ◮ Google Cats, DistBelief, 32 000 CPU cores and more... Figure 1: Google showed we can apply ASGD with Deep Learning. Carnegie Mellon University 7

  8. Distributed Asynchronous Stochastic Gradient Descent ◮ CPUs are expensive ◮ PhD students are poor : ( ◮ Let us use GPUs! Carnegie Mellon University 8

  9. Distributed Asynchronous Stochastic Gradient Descent Stochastic Gradient Descent: θ = θ − η ∇θ (1) Distributed Asynchronous Stochastic Gradient Descent: θ = θ − η ∇θ_i (2), where ∇θ_i is the gradient computed independently by shard i on its own copy of the parameters. Carnegie Mellon University 9
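
To make the two update rules concrete, here is a minimal NumPy sketch (an illustration, not the talk's implementation). Equation (1) is the plain SGD step; equation (2) is the server-side ASGD update, where the gradient arrives from shard i and may have been computed on a stale copy of the parameters. The toy least-squares `loss_grad` exists only so the sketch runs.

```python
import numpy as np

def loss_grad(theta, minibatch):
    # Toy least-squares gradient so the sketch is runnable end to end:
    # minibatch is (X, y), loss = 0.5 * ||X @ theta - y||^2 / batch size.
    X, y = minibatch
    return X.T @ (X @ theta - y) / len(y)

def sgd_step(theta, minibatch, eta=0.01):
    # Eq. (1): one worker computes a gradient and immediately updates theta.
    return theta - eta * loss_grad(theta, minibatch)

def asgd_apply(theta, grad_from_shard, eta=0.01):
    # Eq. (2): the parameter server applies a gradient computed by shard i
    # on its own, possibly stale, parameter copy, as soon as it arrives;
    # there is no synchronization barrier across shards.
    return theta - eta * grad_from_shard
```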

  10. Distributed Asynchronous Stochastic Gradient Descent CMU SPEECH3: ◮ x1 GPU Master Parameter Server ◮ xN GPU ASGD Shards Carnegie Mellon University 10

  11. Distributed Asynchronous Stochastic Gradient Descent ◮ Independent SGD GPU shards ◮ Synchronization with the Parameter Server is done via PCIe DMA, bypassing the CPU Figure 2: CMU SPEECH3 GPU ASGD (Parameter Server linked to multiple SGD Shards over PCIe DMA). Carnegie Mellon University 11

  12. Distributed Asynchronous Stochastic Gradient Descent SPEECH3 ASGD Shard ↔ Parameter Server Sync: ◮ Compute a minibatch (e.g., 128). ◮ If Parameter Server is free, sync. ◮ Else compute another minibatch. ◮ Easy to implement, < 300 lines of code. ◮ Works surprisingly well. Carnegie Mellon University 12
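
A sketch of that shard loop is below. The ParameterServer object, its try_sync method, and the threading lock are assumptions made for illustration; in SPEECH3 the actual exchange happens over PCIe DMA between GPUs rather than through Python calls.

```python
import threading
import numpy as np

class ParameterServer:
    """Holds the master copy of theta; at most one shard syncs at a time."""

    def __init__(self, theta, eta=0.01):
        self.theta = theta
        self.eta = eta
        self._lock = threading.Lock()

    def try_sync(self, accumulated_grad):
        # Non-blocking: apply the shard's gradient and hand back fresh
        # parameters if the server is free, otherwise return None.
        if not self._lock.acquire(blocking=False):
            return None
        try:
            self.theta -= self.eta * accumulated_grad
            return self.theta.copy()
        finally:
            self._lock.release()

def shard_loop(server, theta, minibatches, grad_fn):
    accumulated = np.zeros_like(theta)
    for batch in minibatches:                  # compute a minibatch (e.g. 128)
        accumulated += grad_fn(theta, batch)
        fresh = server.try_sync(accumulated)   # if the server is free, sync
        if fresh is not None:
            theta = fresh
            accumulated = np.zeros_like(theta)
        # else: keep the local gradient and compute another minibatch
    return theta
```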

  13. Distributed Asynchronous Stochastic Gradient Descent Minor tricks: ◮ Momentum / Gradient Projection on Parameter Server ◮ Gradient Decay on Parameter Server ◮ Tunable max distance limit between Parameter Server and Shard. Carnegie Mellon University 13
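
A hedged sketch of how the server-side tricks might fit together follows; the momentum and decay formulas and the reading of "distance" as the number of server updates a shard lags behind are assumptions (the slide only names the ideas, and gradient projection is omitted here).

```python
import numpy as np

class TrickyParameterServer:
    def __init__(self, theta, eta=0.01, momentum=0.9, decay=1e-4, max_distance=4):
        self.theta = theta
        self.velocity = np.zeros_like(theta)
        self.eta, self.momentum, self.decay = eta, momentum, decay
        self.max_distance = max_distance   # tunable staleness limit (assumed meaning)
        self.version = 0                   # counts server-side updates

    def apply(self, grad, shard_version):
        # Reject gradients computed against parameters that lag too far behind.
        if self.version - shard_version > self.max_distance:
            return self.theta.copy(), self.version   # shard refreshes and retries
        grad = grad + self.decay * self.theta        # gradient decay (assumed L2-style)
        self.velocity = self.momentum * self.velocity - self.eta * grad
        self.theta = self.theta + self.velocity      # momentum applied on the server
        self.version += 1
        return self.theta.copy(), self.version
```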

  14. CNNs Convolutional Neural Networks (CNNs) ◮ Computer Vision ◮ Automatic Speech Recognition ◮ CNNs are typically ≈ 5% relative Word Error Rate (WER) better than DNNs Carnegie Mellon University 14

  15. CNNs Figure 3: CNN for Acoustic Modelling: spectrum × time input → 2D convolution → max pooling → 2D convolution → fully connected → fully connected → posteriors. Carnegie Mellon University 15
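
For concreteness, here is the Figure 3 pipeline sketched in PyTorch (which postdates this talk). The layer order follows the slide; the filter counts, kernel sizes, pooling shape, hidden widths, and number of output states are illustrative assumptions.

```python
import torch
import torch.nn as nn

class AcousticCNN(nn.Module):
    """Spectrum x time input -> 2D conv -> max pool -> 2D conv -> FC -> FC -> posteriors."""

    def __init__(self, n_states=3000):            # n_states: assumed output size
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(9, 9)), nn.ReLU(),    # 2D convolution
            nn.MaxPool2d(kernel_size=(3, 1)),                   # max pooling (frequency only, assumed)
            nn.Conv2d(64, 128, kernel_size=(3, 3)), nn.ReLU(),  # 2D convolution
        )
        self.classifier = nn.Sequential(
            nn.LazyLinear(1024), nn.ReLU(),        # fully connected
            nn.Linear(1024, 1024), nn.ReLU(),      # fully connected
            nn.Linear(1024, n_states),             # posteriors over acoustic states
        )

    def forward(self, x):                          # x: (batch, 1, freq, time)
        h = self.features(x)
        return self.classifier(h.flatten(1)).log_softmax(dim=-1)
```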

  16. CNNs Figure 4: SGD vs ASGD: test frame accuracy (30-50%) vs. time (0-20 hours) for the SGD baseline and ASGD with 2x, 3x, 4x, and 5x workers. Carnegie Mellon University 16

  17. CNNs
  Workers    40% FA          43% FA          44% FA
  1          5:50 (100%)     14:36 (100%)    19:29 (100%)
  2          3:36 (81.0%)    8:59 (81.3%)    11:58 (81.4%)
  3          2:48 (69.4%)    5:59 (81.3%)    7:58 (81.5%)
  4          2:05 (70.0%)    4:28 (81.7%)    6:32 (74.6%)
  5          1:40 (70.0%)    3:49 (76.5%)    5:43 (68.2%)
  Table 1: Time (hh:mm) and scaling efficiency (in brackets) comparison for convergence to 40%, 43% and 44% Frame Accuracy (FA). Carnegie Mellon University 17

  18. RNNs Recurrent Neural Networks (RNNs) ◮ Machine Translation ◮ Automatic Speech Recognition ◮ RNNs are typically ≈ 5-10% relative WER better than DNNs Minor Tricks: ◮ Long Short-Term Memory (LSTM) ◮ Cell activation clipping Carnegie Mellon University 18
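
The cell activation clipping trick can be sketched like this in PyTorch; the clamp threshold and the choice to clip the cell state after every time step are assumptions, since the slide only names the technique.

```python
import torch
import torch.nn as nn

class ClippedLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, cell_clip=50.0):  # threshold is assumed
        super().__init__()
        self.cell = nn.LSTMCell(input_size, hidden_size)
        self.cell_clip = cell_clip

    def forward(self, x):                          # x: (time, batch, input_size)
        batch = x.size(1)
        h = x.new_zeros(batch, self.cell.hidden_size)
        c = x.new_zeros(batch, self.cell.hidden_size)
        outputs = []
        for x_t in x:                              # unroll over time steps
            h, c = self.cell(x_t, (h, c))
            # Clamp the cell activation so it cannot grow without bound,
            # which helps keep training stable on long utterances.
            c = torch.clamp(c, -self.cell_clip, self.cell_clip)
            outputs.append(h)
        return torch.stack(outputs)                # (time, batch, hidden_size)
```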

  19. RNNs Figure 5: DNN vs. RNN. Carnegie Mellon University 19

  20. RNNs
  Workers    46.5% FA        47.5% FA        48.5% FA
  1          1:51 (100%)     3:42 (100%)     7:41 (100%)
  2          1:00 (92.5%)    2:00 (92.5%)    3:01 (128%)
  5          -               -               1:15 (122%)
  Table 2: Time (hh:mm) and scaling efficiency (in brackets) comparison for convergence to 46.5%, 47.5% and 48.5% Frame Accuracy (FA). ◮ RNNs seem to really like distributed training! Carnegie Mellon University 20

  21. RNNs
  Workers    WER     Time (hh:mm)
  1          3.95    18:37
  2          4.11    8:04
  5          4.06    5:24
  Table 3: WERs. ◮ No (major) difference in WER! Carnegie Mellon University 21

  22. Conclusion ◮ Distributed ASGD on GPU, easy to implement! ◮ Speed up your training! ◮ Minor difference in loss against SGD baseline! Carnegie Mellon University 22
