
Rapid Stochastic Gradient Descent: Accelerating Machine Learning - PowerPoint PPT Presentation



  1. The imagination driving Australia's ICT future. Nicol N. Schraudolph: Rapid Stochastic Gradient Descent - Accelerating Machine Learning. Statistical Machine Learning Program, www.nicta.com.au

  2. Overview
     1. Why Stochastic Gradient?
     2. Stochastic Meta-Descent (SMD): derivation and algorithm; properties and benchmark results; applications and ongoing work
     3. Summary and Outlook

  3. The Information Glut. The flood of information caused by plentiful, affordable sensors (such as webcams) and the ever-increasing networking of these sensors overwhelms our processing ability in, e.g., science (pulsar survey at Arecibo: 1 TB/day), business (Dell website: over 100 page requests/sec), and security (London: over 500’000 security cameras). We need intelligent, adaptive filters to cope!

  4. A Challenge for ML. Coping with the info glut requires ML algorithms for: large, complex, nonlinear models (millions of degrees of freedom); large volumes of low-quality data (noisy, correlated, non-stationary, with outliers); and efficient real-time, online adaptation (no fixed training set; life-long learning). Current ML techniques have difficulty with this.

  5. Online Learning Paradigm (aka adaptive filtering, stochastic approximation, ...). Classical optimization: an iterative optimizer minimizes an objective function over a fixed training data set (nested loops!). Online learning: an online optimizer processes a training data stream.

  6. Stochastic Approximation. Classical formulation of the optimization problem:

     θ* = argmin_θ E_x[J(θ, x)] ≈ argmin_θ (1/|X|) Σ_{x_i ∈ X} J(θ, x_i)

     This is inefficient for large data sets X, and inappropriate for never-ending, potentially non-stationary data streams ⇒ we must resort to stochastic approximation:

     θ_{t+1} ≈ argmin_θ J(θ_t, x_t)      (t = 0, 1, 2, ...)
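The stochastic-approximation loop above can be sketched in a few lines. This is a minimal illustration on a hypothetical streaming least-squares problem; the loss J(θ, x) = ½(θ·x − y)², the data generator, and the fixed gain value are assumptions chosen for the sketch, not taken from the slides:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical streaming problem: J(theta, x) = 0.5 * (theta . x - y)^2,
# with targets generated by a fixed "true" parameter vector.
theta_true = np.array([2.0, -1.0, 0.5])
theta = np.zeros(3)
eta = 0.1  # fixed scalar gain (no adaptation yet)

for t in range(5000):
    x = rng.normal(size=3)        # one sample from the data stream
    y = theta_true @ x            # noiseless target, for simplicity
    g = (theta @ x - y) * x       # stochastic gradient of J(theta_t, x_t)
    theta = theta - eta * g       # theta_{t+1} = theta_t - eta * g_t

# theta should now be close to theta_true
```

No pass over a stored data set is ever made: each sample is used once and discarded, which is what makes the scheme viable for never-ending streams.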

  7. The Key Problem. [Chart: convergence speed vs. cost per iteration (O(1), O(n), O(n²), O(n³)) for online, scalable optimization algorithms: gradient descent, evolutionary algorithms, conjugate gradient, quasi-Newton, Kalman filter, Levenberg-Marquardt; the goal is to accelerate convergence.]

  8. The Key Problem. Stochastic approximation breaks many optimizers: conjugate directions break down due to noise; line minimizations (CG, quasi-Newton) become inaccurate; Newton, Levenberg-Marquardt, and Kalman filter are too expensive for large-scale problems. This only leaves evolutionary algorithms (very inefficient: they don't use the gradient) and simple gradient descent (can be slow to converge).

  9. Gain Vector Adaptation. Given the stochastic gradient g_t := ∂_θ J(θ_t, x_t), adapt θ by gradient descent with gain vector η:

     θ_{t+1} = θ_t − η_t · g_t      (· denotes the Hadamard, i.e. element-wise, product)

     Key idea: simultaneously adapt η by exponentiated gradient:

     ln η_t = ln η_{t−1} − μ ∂_{ln η} J(θ_t, x_t)
     η_t = η_{t−1} · exp(−μ g_t · v_t) ≈ η_{t−1} · max(½, 1 − μ g_t · v_t)

     where v_t := ∂_{ln η} θ_t and μ is a scalar meta-gain (free parameter).
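The linearized, bounded gain update can be written as a one-line element-wise rule. A small sketch (the helper name `gain_update` and the illustrative numbers are assumptions; v is supplied externally here):

```python
import numpy as np

def gain_update(eta_prev, g, v, mu=0.1):
    """eta_t = eta_{t-1} * max(1/2, 1 - mu * g_t * v_t), element-wise.

    The max(1/2, .) floor bounds how fast any gain can shrink in one step,
    guarding the linearization of exp(-mu * g * v) against outliers."""
    return eta_prev * np.maximum(0.5, 1.0 - mu * g * v)

# Illustration: with the single-step sensitivity v = -eta * g_prev (next slide),
# components whose gradient sign persists get a larger gain, a sign flip shrinks it.
eta = np.full(3, 0.1)
g_prev = np.array([1.0, 1.0, -1.0])
g = np.array([1.0, -1.0, -1.0])   # component 1 flips sign
v = -eta * g_prev
eta_new = gain_update(eta, g, v)
# eta_new: components 0 and 2 grow, component 1 shrinks
```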

  10. Single-Step Model. Conventionally, v_{t+1} := ∂_{ln η_t} θ_{t+1} = −η_t · g_t (recall that θ_{t+1} = θ_t − η_t · g_t), giving

     η_t = η_{t−1} · max(½, 1 + μ η_{t−1} · g_{t−1} · g_t)

     ⇒ adaptation of η is driven by the autocorrelation of g.
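Plugging the single-step sensitivity into the toy streaming problem gives a complete adaptive-gain loop. The problem setup, gain and meta-gain values are illustrative assumptions, as before:

```python
import numpy as np

# Single-step model: v_{t+1} = -eta_t * g_t, so the gain update reduces to
# eta_t = eta_{t-1} * max(1/2, 1 + mu * eta_{t-1} * g_{t-1} * g_t)
# (element-wise), i.e. it is driven by the autocorrelation of g.
rng = np.random.default_rng(1)
theta_true = np.array([2.0, -1.0, 0.5])
theta = np.zeros(3)
eta = np.full(3, 0.01)   # per-parameter gain vector
mu = 0.05                # scalar meta-gain (free parameter)
g_prev = np.zeros(3)

for t in range(3000):
    x = rng.normal(size=3)
    g = (theta @ x - theta_true @ x) * x                   # stochastic gradient
    eta = eta * np.maximum(0.5, 1.0 + mu * eta * g_prev * g)
    theta = theta - eta * g
    g_prev = g
```

When successive gradients keep the same sign in a component, its gain grows multiplicatively; a sign flip (overshoot) shrinks it, so each parameter finds its own step size.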

  11. SMD's Multi-Step Model. To capture the long-term dependence of θ on η, define

     v_{t+1} := Σ_{i=0}^{t} λ^i ∂θ_{t+1}/∂ ln η_{t−i}

     with decay factor 0 ≤ λ ≤ 1 (free parameter).

  12. SMD's v-update. Expanding the definition:

     v_{t+1} := Σ_{i=0}^{t} λ^i ∂θ_{t+1}/∂ ln η_{t−i}
             = Σ_{i=0}^{t} λ^i ∂θ_t/∂ ln η_{t−i} − Σ_{i=0}^{t} λ^i ∂(η_t · g_t)/∂ ln η_{t−i}
             ≈ λ v_t − η_t · (g_t + λ H_t v_t)

     We obtain a simple iterative update for v that:
     - performs correct smoothing over correlated input signals
     - involves an implicit Hessian-vector (H v) product
     - can be computed as efficiently as 2-3 gradient evaluations
     - can be done automatically via algorithmic differentiation
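The full SMD recursion can be sketched on the same toy streaming problem. For the assumed loss J(θ, x) = ½(θ·x − y)², the instantaneous Hessian is H_t = x xᵀ, so the Hessian-vector product is just x (x·v) at the cost of two dot products (in general one would use algorithmic differentiation instead); all concrete values here are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(2)
theta_true = np.array([2.0, -1.0, 0.5])
theta = np.zeros(3)
eta = np.full(3, 0.01)   # per-parameter gain vector
v = np.zeros(3)          # long-term sensitivity d theta / d ln(eta)
mu, lam = 0.05, 0.99     # scalar meta-gain and decay (both free parameters)

for t in range(3000):
    x = rng.normal(size=3)
    g = (theta @ x - theta_true @ x) * x              # stochastic gradient
    eta = eta * np.maximum(0.5, 1.0 - mu * g * v)     # gain adaptation
    theta = theta - eta * g                           # parameter update
    Hv = x * (x @ v)                                  # implicit Hessian-vector product
    v = lam * v - eta * (g + lam * Hv)                # SMD v-update
```

Setting λ = 0 recovers the single-step model of the previous slide; λ near 1 lets v accumulate the long-term effect of the gains on the trajectory.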

  13. Fixpoint of v. The fixpoint of v_{t+1} = λ v_t − η_t · (g_t + λ H_t v_t) is a Levenberg-Marquardt-style gradient step:

     v → −[λ H + (1−λ) diag(η)^{−1}]^{−1} g

     v is too noisy to use directly; SMD achieves stability by means of the double integration v → η → θ. The product v · g is well-behaved (self-normalizing property). SMD uses a Gauss-Newton approximation of H.
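The fixpoint claim follows from a short calculation: set v_{t+1} = v_t = v in the update and solve for v.

```latex
\begin{align*}
v &= \lambda v - \operatorname{diag}(\eta)\,\bigl(g + \lambda H v\bigr) \\
\bigl[(1-\lambda)\,I + \lambda \operatorname{diag}(\eta)\,H\bigr]\, v &= -\operatorname{diag}(\eta)\, g \\
v &= -\bigl[\lambda H + (1-\lambda)\operatorname{diag}(\eta)^{-1}\bigr]^{-1} g
\end{align*}
```

The last line (multiplying through by diag(η)^{−1}) is the Levenberg-Marquardt form quoted on the slide: λ interpolates between a plain gain-scaled gradient step (λ = 0) and a full curvature-corrected step (λ = 1).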

  14. Four Regions Benchmark. [Figure: the four-regions classification task in the (x, y) plane.] Compare simple stochastic gradient descent (SGD), conventional gain vector adaptation (ALAP), stochastic meta-descent (SMD), and a global extended Kalman filter (GEKF).

  15. Benchmark: Convergence. [Plot: loss (0.0-1.4) vs. patterns (0k-25k) for SGD, ALAP, SMD, and GEKF.]

  16. Computational Cost

     Algorithm | storage (per weight) | flops (per weight update) | CPU ms (per pattern)
     ----------|----------------------|---------------------------|---------------------
     SGD       | 1                    | 6                         | 0.5
     SMD       | 3                    | 18                        | 1.0
     ALAP      | 4                    | 18                        | 1.0
     GEKF      | >90                  | >1500                     | 40

  17. Benchmark: CPU Usage. [Plot: loss (0.0-1.4) vs. CPU seconds (0-50) for GEKF, ALAP, SGD, and SMD.]

  18. Autocorrelated Data. [Plots: error E vs. patterns under three input samplings (i.i.d. uniform, Sobol, Brownian), comparing ALAP, s-ALAP, ALAP with momentum, vario-eta, momentum, ELK1, ELK1 with momentum, and SMD.]

  19. Comparison to CG. Conjugate gradient, deterministic (1000 pts): overfits. Conjugate gradient, stochastic (1000 pts/iteration): diverges. SMD, stochastic (5 pts/iteration): converges.

  20. Application: Turbulent Flow. (PhD thesis of M. Milano, Inst. of Computational Science, ETH Zürich.) [Figure: original flow (75’000 d.o.f.) vs. linear PCA (160 p.c.) vs. neural network (160 nonlinear p.c.).]

  21. Turbulent Flow Model. A very high-dimensional optimization problem: 15 neural networks, each with about 180’000 parameters; the generic model has over 20 million parameters! [Learning curves: reconstruction error (1e-02 to 1e+01, log scale) vs. iteration ×10³ (0.00-2.00); SMD outperformed bold driver.] The SMD Matlab toolbox was able to train the generic model.

  22. Application: Hand Tracking. (PhD thesis of M. Bray, Computer Vision Lab, ETH Zürich.) Using a detailed hand model (10k vertices, 26 d.o.f.): randomly sample a few points on the model surface, project them to the image, and compare with the camera image at these points. SMD uses the resulting stochastic gradient to adjust the model.
