Unraveling the mysteries of stochastic gradient descent on deep neural networks


  1. Unraveling the mysteries of stochastic gradient descent on deep neural networks. Pratik Chaudhari, UCLA VISION LAB.

  2. The question. $f(x)$ measures the disagreement of the predictions (Cat, Dog, ...) with the ground truth; $x$ are the weights, aka parameters:
     $x^* = \operatorname{argmin}_x f(x)$.
     Stochastic gradient descent: $x_{k+1} = x_k - \eta\,\nabla f_b(x_k)$.
     Many, many variants: AdaGrad, RMSProp, Adam, SAG, SVRG, Catalyst, APPA, Natasha, Katyusha…
     Why is SGD so special?
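As a concrete reference for the update rule on the slide above, here is a minimal numpy sketch of one SGD step on a random mini-batch. The function and variable names (`sgd_step`, `grad_f`, `data`, `eta`, `b`) are illustrative placeholders of mine, not code from the talk.

```python
import numpy as np

def sgd_step(x, grad_f, data, eta, b, rng):
    """One SGD update x_{k+1} = x_k - eta * grad f_b(x_k) on a random mini-batch of size b."""
    idx = rng.choice(len(data), size=b, replace=True)        # sample a mini-batch (with replacement)
    g = np.mean([grad_f(x, data[i]) for i in idx], axis=0)   # mini-batch gradient = mean of per-sample gradients
    return x - eta * g

# Toy usage: per-sample loss f_i(x) = 0.5 * ||x - d_i||^2, so grad f_i(x) = x - d_i
rng = np.random.default_rng(0)
data = rng.normal(size=(1000, 2))
x = np.zeros(2)
for _ in range(500):
    x = sgd_step(x, lambda x, d: x - d, data, eta=0.1, b=32, rng=rng)
```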

  3. Empirical evidence: wide “minima”
     [Figure: eigenvalue histograms (frequency vs. eigenvalue) at the end of training; almost all eigenvalues are concentrated near zero, with a short negative tail.]

  4. A bit of statistical physics
     ‣ Energy landscape of a binary perceptron: many sharp minima; few wide minima, but these generalize better [Baldassi et al., '15]
     ‣ Wide minima are a large-deviations phenomenon

  5. Tilting the Gibbs measure
     ‣ Local entropy [Chaudhari et al., ICLR '17]:
       $x^* = \operatorname{argmin}_x f(x) = \operatorname{argmax}_x\, e^{-f(x)} \;\approx\; \operatorname{argmin}_x\, -\log\big(G_\gamma * e^{-f}\big)(x)$,
       where $G_\gamma$ is a Gaussian kernel of variance $\gamma$.
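One way to see what this tilted objective does is to estimate the local-entropy loss $-\log\big(G_\gamma * e^{-f}\big)(x)$ by Monte Carlo, averaging $e^{-f(x+\epsilon)}$ over Gaussian perturbations $\epsilon \sim \mathcal{N}(0, \gamma I)$. The sketch below is my own illustration (it is not the Entropy-SGD code); it only shows that wide minima keep this value small under perturbation while sharp minima do not.

```python
import numpy as np

def local_entropy_loss(f, x, gamma, n_samples=256, rng=None):
    """Monte Carlo estimate of -log((G_gamma * e^{-f})(x)), up to an additive constant."""
    rng = np.random.default_rng() if rng is None else rng
    x = np.asarray(x, dtype=float)
    eps = rng.normal(scale=np.sqrt(gamma), size=(n_samples,) + x.shape)  # eps ~ N(0, gamma I)
    vals = np.array([f(x + e) for e in eps])
    m = np.max(-vals)                                   # stable log-mean-exp of -f(x + eps)
    return -(m + np.log(np.mean(np.exp(-vals - m))))

# A sharp and a wide quadratic share the same minimum but have very different local entropy
sharp = lambda z: 50.0 * float(z @ z)
wide = lambda z: 0.5 * float(z @ z)
print(local_entropy_loss(sharp, np.zeros(2), gamma=0.1))   # large
print(local_entropy_loss(wide, np.zeros(2), gamma=0.1))    # small
```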

  6. Parle: parallelization of SGD
     ‣ State-of-the-art performance [Chaudhari et al., SysML '18]
     [Figure: results for Wide-ResNet on CIFAR-10 and All-CNN on CIFAR-10 (25% data).]

  7. The question: why is SGD so special?

  8. A continuous-time view of SGD
     ‣ Diffusion matrix: the covariance of mini-batch gradients,
       $\operatorname{var}\big(\nabla f_b(x)\big) = \frac{1}{b}\,\underbrace{\Big[\frac{1}{N}\sum_{k=1}^{N}\nabla f_k(x)\,\nabla f_k(x)^\top - \nabla f(x)\,\nabla f(x)^\top\Big]}_{D(x)}$
     ‣ Temperature: the ratio of learning rate and batch size, $\beta^{-1} = \frac{\eta}{2b}$
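Both quantities on this slide are easy to estimate empirically. The sketch below (my own helper names and array layout, not code from the paper) builds $D(x)$ from a stack of per-sample gradients and computes the temperature $\beta^{-1} = \eta/(2b)$.

```python
import numpy as np

def diffusion_matrix(per_sample_grads):
    """Empirical D(x) from an (N, d) array of per-sample gradients grad f_k(x):
    (1/N) sum_k g_k g_k^T  -  g_bar g_bar^T."""
    G = np.asarray(per_sample_grads, dtype=float)
    g_bar = G.mean(axis=0)
    return G.T @ G / len(G) - np.outer(g_bar, g_bar)

def temperature(eta, b):
    """Temperature of SGD, beta^{-1} = eta / (2 b)."""
    return eta / (2.0 * b)
```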

  9. A continuous-time view of SGD
     ‣ Continuous-time limit of the discrete-time updates (the step size $\eta$ becomes $dt$):
       $dx = -\nabla f(x)\,dt + \sqrt{2\beta^{-1}D(x)}\;dW(t)$; we will assume $x \in \Omega \subset \mathbb{R}^d$.
     ‣ The Fokker-Planck (FP) equation gives the distribution on the weight space induced by SGD:
       $\rho_t = \operatorname{div}\Big(\underbrace{\nabla f\,\rho}_{\text{drift}} + \underbrace{\beta^{-1}\operatorname{div}(D\rho)}_{\text{diffusion}}\Big)$, where $x(t) \sim \rho(t)$.
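To make the continuous-time picture concrete, here is a minimal Euler-Maruyama integration of $dx = -\nabla f(x)\,dt + \sqrt{2\beta^{-1}D(x)}\,dW(t)$. The quadratic loss and constant anisotropic $D$ at the end are toy choices of mine, not the deep-network setting.

```python
import numpy as np

def euler_maruyama(grad_f, D, x0, beta_inv, dt=1e-2, n_steps=10_000, rng=None):
    """Simulate dx = -grad_f(x) dt + sqrt(2 beta^{-1} D(x)) dW(t)."""
    rng = np.random.default_rng() if rng is None else rng
    x = np.array(x0, dtype=float)
    traj = [x.copy()]
    for _ in range(n_steps):
        # Cholesky factor of 2 beta^{-1} D(x); D(x) is assumed symmetric PSD
        L = np.linalg.cholesky(2.0 * beta_inv * D(x) + 1e-12 * np.eye(len(x)))
        x += -grad_f(x) * dt + np.sqrt(dt) * (L @ rng.normal(size=x.shape))
        traj.append(x.copy())
    return np.array(traj)

# Toy example: f(x) = 0.5 ||x||^2 with anisotropic, constant diffusion
traj = euler_maruyama(grad_f=lambda x: x, D=lambda x: np.diag([1.0, 0.1]),
                      x0=[2.0, 2.0], beta_inv=0.05)
```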

  10. Wasserstein gradient flow
      ‣ The heat equation $\rho_t = \operatorname{div}(I\,\nabla\rho)$ performs steepest descent on the Dirichlet energy $\frac{1}{2}\int_\Omega \|\nabla\rho(x)\|^2\,dx$.
      ‣ It is also the steepest descent in the Wasserstein metric for the negative entropy $-H(\rho) = \int_\Omega \rho\log\rho\,dx$: the scheme
        $\rho^\tau_{k+1} \in \operatorname{argmin}_\rho \Big\{ -H(\rho) + \frac{W_2^2(\rho,\rho^\tau_k)}{2\tau} \Big\}$
        converges to trajectories of the heat equation as $\tau \to 0$.
      ‣ Negative entropy is a Lyapunov functional for Brownian motion: $\rho_{ss}^{\text{heat}} = \operatorname{argmin}_\rho\, -H(\rho)$.
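The second bullet can be checked directly from the general form of a Wasserstein gradient flow, $\rho_t = \operatorname{div}\big(\rho\,\nabla\frac{\delta F}{\delta\rho}\big)$: plugging in $F = -H$ recovers the heat equation. This short derivation is standard and is my addition, not a slide.

```latex
\[
F(\rho) = -H(\rho) = \int_\Omega \rho\log\rho\,dx
\;\Longrightarrow\;
\frac{\delta F}{\delta\rho} = \log\rho + 1
\;\Longrightarrow\;
\rho_t = \operatorname{div}\!\Big(\rho\,\nabla(\log\rho + 1)\Big)
       = \operatorname{div}\!\Big(\rho\,\frac{\nabla\rho}{\rho}\Big)
       = \operatorname{div}(\nabla\rho).
\]
```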

  11. Wasserstein gradient flow: with drift
      ‣ If $D = I$, the Fokker-Planck equation
        $\rho_t = \operatorname{div}\big(\nabla f\,\rho + \beta^{-1}\nabla\rho\big)$
        has the Jordan-Kinderlehrer-Otto (JKO) functional [Jordan et al., '97]
        $F(\rho) = \underbrace{\mathbb{E}_{x\sim\rho}\big[f(x)\big]}_{\text{energetic term}} - \underbrace{\beta^{-1}H(\rho)}_{\text{entropic term}}, \qquad \rho_{ss} = \operatorname{argmin}_\rho F(\rho),$
        as the Lyapunov functional.
      ‣ FP is the steepest descent on the JKO functional in the Wasserstein metric.

  12. What happens for non-isotropic noise?
      $\rho_t = \operatorname{div}\Big(\underbrace{\nabla f\,\rho}_{\text{drift}} + \underbrace{\beta^{-1}\operatorname{div}(D\rho)}_{\text{diffusion}}\Big)$
      ‣ FP monotonically minimizes the free energy
        $F(\rho) = \mathbb{E}_{x\sim\rho}\big[\Phi(x)\big] - \beta^{-1}H(\rho), \qquad \rho_{ss} = \operatorname{argmin}_\rho F(\rho).$
      ‣ Rewrite it as $F(\rho) = \beta^{-1}\,\mathrm{KL}(\rho\,\|\,\rho_{ss})$; compare with $\|x - x^*\|$ for deterministic optimization.
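On a finite state space the rewrite is easy to verify numerically: with $\rho_{ss} \propto e^{-\beta\Phi}$, the free energy $\mathbb{E}_\rho[\Phi] - \beta^{-1}H(\rho)$ and $\beta^{-1}\mathrm{KL}(\rho\,\|\,\rho_{ss})$ differ only by the constant $-\beta^{-1}\log Z_\beta$. The toy distributions below are my own illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
beta = 2.0
Phi = rng.normal(size=10)                  # arbitrary potential on 10 states
rho_ss = np.exp(-beta * Phi)
Z = rho_ss.sum()
rho_ss /= Z                                # steady state rho_ss proportional to exp(-beta * Phi)

def free_energy(rho):                      # E_rho[Phi] - beta^{-1} H(rho),  H(rho) = -sum rho log rho
    return np.sum(rho * Phi) + np.sum(rho * np.log(rho)) / beta

def kl_form(rho):                          # beta^{-1} KL(rho || rho_ss)
    return np.sum(rho * np.log(rho / rho_ss)) / beta

for _ in range(3):                         # the gap is the same constant for every rho
    rho = rng.dirichlet(np.ones(10))
    print(free_energy(rho) - kl_form(rho), -np.log(Z) / beta)
```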

  13. SGD performs variational inference
      Theorem [Chaudhari & Soatto, ICLR '18]. The functional $F(\rho) = \beta^{-1}\,\mathrm{KL}(\rho\,\|\,\rho_{ss})$ is minimized monotonically by trajectories of the Fokker-Planck equation
      $\rho_t = \operatorname{div}\big(\nabla f\,\rho + \beta^{-1}\operatorname{div}(D\rho)\big)$
      with $\rho_{ss}$ as the steady-state distribution. Moreover, $\Phi = -\beta^{-1}\log\rho_{ss}$ up to a constant.

  14. Some implications
      ‣ The learning rate should scale linearly with the batch size: $\beta^{-1} = \frac{\eta}{2b}$ should not be small.
      ‣ Sampling with replacement regularizes better than sampling without replacement,
        $\beta^{-1}_{\text{w/o replacement}} = \frac{\eta}{2b}\Big(1 - \frac{b}{N}\Big),$
        and it also generalizes better.
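Both implications reduce to bookkeeping with the temperature, as in the short sketch below (helper names are mine): keeping $\eta/(2b)$ fixed while growing the batch size means scaling the learning rate linearly, and sampling without replacement multiplies the temperature by $(1 - b/N)$.

```python
def temperature(eta, b, N=None, with_replacement=True):
    """beta^{-1} = eta/(2b), times (1 - b/N) when sampling without replacement."""
    t = eta / (2.0 * b)
    if not with_replacement:
        t *= 1.0 - b / N
    return t

def scale_learning_rate(eta_ref, b_ref, b_new):
    """Linear scaling rule: pick eta so that the temperature eta/(2b) stays constant."""
    return eta_ref * (b_new / b_ref)

eta_new = scale_learning_rate(0.1, 128, 1024)                      # 0.1 at b=128  ->  0.8 at b=1024
assert abs(temperature(0.1, 128) - temperature(eta_new, 1024)) < 1e-12
```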

  15. Information Bottleneck Principle
      ‣ Minimize the mutual information of the representation with the training data [Tishby '99, Achille & Soatto '17]:
        $\mathrm{IB}_\beta(\theta) = \mathbb{E}_{x\sim\rho_\theta}\big[f(x)\big] - \beta^{-1}\,\mathrm{KL}\big(\rho_\theta\,\|\,\text{prior}\big)$
      ‣ Minimizing these functionals is hard; SGD does it naturally.

  16. Potential $\Phi$ vs. the original loss $f$
      ‣ The solution of the variational problem is $\rho_{ss}(x) = \frac{1}{Z_\beta}\,e^{-\beta\Phi(x)}$.
      ‣ Key point: the most likely locations of SGD are not the critical points of the original loss, i.e. $\rho_{ss}(x) \neq \frac{1}{Z'_\beta}\,e^{-\beta f(x)}$.
      ‣ The two losses are equal if and only if the noise is isotropic: $D(x) = I \;\Leftrightarrow\; \Phi(x) = f(x)$.

  17. Deep networks have highly non-isotropic noise
        CIFAR-10:  $\lambda(D) = 0.27 \pm 0.84$,  $\operatorname{rank}(D) = 0.34\,\%$
        CIFAR-100: $\lambda(D) = 0.98 \pm 2.16$,  $\operatorname{rank}(D) = 0.47\,\%$
      ‣ Evaluate neural architectures using the diffusion matrix
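The two statistics quoted above are straightforward to read off an empirical diffusion matrix: the mean and standard deviation of its eigenvalues, and the fraction of non-negligible eigenvalues. A sketch with my own tolerance convention:

```python
import numpy as np

def diffusion_stats(D, tol=1e-6):
    """Eigenvalue mean/std and relative rank (in %) of a symmetric PSD diffusion matrix D."""
    eig = np.linalg.eigvalsh(D)                              # spectrum of the symmetric matrix
    rel_rank = np.sum(eig > tol * eig.max()) / len(eig)      # fraction of non-negligible directions
    return eig.mean(), eig.std(), 100.0 * rel_rank
```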

  18. How different are cats and dogs, really?

  19. SGD converges to limit cycles
      Theorem [Chaudhari & Soatto, ICLR '18]. The most likely trajectories of SGD are $\dot{x} = j(x)$, where the "leftover" vector field
      $j(x) = -\nabla f(x) + D(x)\,\nabla\Phi(x) - \beta^{-1}\operatorname{div} D(x)$
      is such that $\operatorname{div} j(x) = 0$.
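To see why a divergence-free, non-zero $j(x)$ gives closed orbits instead of convergence to a critical point, here is a toy deterministic integration of $\dot{x} = j(x)$ with $j(x) = Q\,\nabla\Phi(x)$ for an antisymmetric $Q$ (so $\operatorname{div} j = 0$). The quadratic $\Phi$ and this particular $Q$ are illustrative choices of mine, not the deep-network case.

```python
import numpy as np

Q = np.array([[0.0, 1.0],
              [-1.0, 0.0]])               # antisymmetric, hence div j = trace(Q) = 0

grad_Phi = lambda x: x                    # Phi(x) = 0.5 ||x||^2 (toy potential)
j = lambda x: Q @ grad_Phi(x)             # divergence-free "leftover" field

x = np.array([1.0, 0.0])
dt, n_steps = 1e-3, 20_000
for _ in range(n_steps):
    x = x + dt * j(x)                     # forward Euler for xdot = j(x)

print(np.linalg.norm(x))                  # stays close to 1: the trajectory circles a level set of Phi
```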

  20. Trajectories of SGD
      ‣ Run SGD for $10^5$ epochs. [Figure: FFT of $x^i_{k+1} - x^i_k$.]

  21. An example force-field
      [Figure: an example force field near a saddle-point; annotations mark where $\nabla\Phi(x) = 0$, where $\|j(x)\|$ is small, where $j(x) = 0$, and where $\|j(x)\|$ is very large.]

  22. Most likely locations are not the critical points of the original loss
      Theorem [Chaudhari & Soatto, ICLR '18]. The Ito SDE
      $dx = -\nabla f\,dt + \sqrt{2\beta^{-1}D}\;dW(t)$
      is equivalent to an A-type SDE
      $dx = -\big(D + Q\big)\nabla\Phi\,dt + \sqrt{2\beta^{-1}D}\;dW(t)$
      with the same steady-state $\rho_{ss} \propto e^{-\beta\Phi(x)}$ if
      $\nabla f = \big(D + Q\big)\nabla\Phi - \beta^{-1}\operatorname{div}\big(D + Q\big).$
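A quick sanity check on this condition (my own remark, not a slide): in the isotropic case with no antisymmetric part, the potential collapses back onto the loss, which is exactly the equivalence $D(x) = I \Leftrightarrow \Phi = f$ stated on slide 16.

```latex
\[
D(x) = I,\; Q = 0
\;\Longrightarrow\;
\nabla f = \nabla\Phi - \beta^{-1}\operatorname{div}(I) = \nabla\Phi
\;\Longrightarrow\;
\Phi = f \ \text{up to an additive constant.}
\]
```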

  23. Knots in our understanding: ARCHITECTURE, OPTIMIZATION, GENERALIZATION

  24. Punchline: is SGD special?

  25. "Stochastic gradient descent performs variational inference, converges to limit cycles for deep networks", Pratik Chaudhari and Stefano Soatto, arXiv:1710.11029, ICLR '18. www.pratikac.info. Thank you, questions?
