Unraveling the mysteries of stochastic gradient descent on deep neural networks
Pratik Chaudhari, UCLA Vision Lab
The question
x* = argmin_x f(x), where f measures the disagreement of predictions with the ground truth (cat vs. dog, ...) and x are the weights, a.k.a. parameters.
Stochastic gradient descent: x_{k+1} = x_k − η ∇f_b(x_k), where ∇f_b is the gradient computed on a mini-batch of size b.
Many, many variants: AdaGrad, RMSProp, Adam, SAG, SVRG, Catalyst, APPA, Natasha, Katyusha…
Why is SGD so special?
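A minimal NumPy sketch of the update above; the least-squares loss, synthetic data, learning rate, and batch size are illustrative placeholders, not the networks discussed in the talk.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy least-squares loss f(x) = (1/2N) * sum_k (a_k . x - y_k)^2 on synthetic data.
N, d = 1000, 10
A = rng.normal(size=(N, d))
y = A @ rng.normal(size=d) + 0.1 * rng.normal(size=N)

def grad_minibatch(x, idx):
    """Gradient of the loss restricted to the mini-batch `idx`."""
    residual = A[idx] @ x - y[idx]
    return A[idx].T @ residual / len(idx)

x = np.zeros(d)
eta, b = 0.05, 32                                # learning rate and batch size
for k in range(2000):
    idx = rng.choice(N, size=b, replace=False)   # draw a mini-batch
    x = x - eta * grad_minibatch(x, idx)         # x_{k+1} = x_k - eta * grad f_b(x_k)
```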
Empirical evidence: wide “minima”
[Figure: histograms of eigenvalues at the end of training (frequency on a log scale); positive eigenvalues extend out to roughly 40, while the negative tail is short, reaching only about −0.5.]
A bit of statistical physics
‣ Energy landscape of a binary perceptron: many sharp minima; few wide minima, but they generalize better [Baldassi et al., '15]
‣ Wide minima are a large-deviations phenomenon
Tilting the Gibbs measure
‣ Local Entropy [Chaudhari et al., ICLR '17]
x* = argmin_x f(x) = argmax_x e^{−f(x)}
   ≈ argmin_x −log( G_γ ∗ e^{−f(x)} ),
where G_γ is a Gaussian kernel of variance γ.
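A sketch of how the smoothed loss above could be estimated by Monte Carlo, sampling from a Gaussian of variance γ around x; the 1-D toy loss with one sharp and one wide minimum is purely illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

def local_entropy(f, x, gamma, n_samples=10_000):
    """Monte Carlo estimate of -log( (G_gamma * exp(-f))(x) ): sample x' ~ N(x, gamma I)
    and average exp(-f(x'))."""
    d = x.shape[0]
    xs = x + np.sqrt(gamma) * rng.normal(size=(n_samples, d))
    vals = np.exp(-np.array([f(z) for z in xs]))
    return -np.log(vals.mean() + 1e-300)

# Toy 1-D loss with a sharp minimum at 0 and a wide one at 3.
f = lambda z: min(50.0 * z[0] ** 2, 0.5 * (z[0] - 3.0) ** 2 + 0.1)

print(local_entropy(f, np.array([0.0]), gamma=1.0))  # sharp minimum: large smoothed loss
print(local_entropy(f, np.array([3.0]), gamma=1.0))  # wide minimum: smaller smoothed loss
```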
Parle: parallelization of SGD
‣ State-of-the-art performance [Chaudhari et al., SysML '18]
[Figure panels: Wide-ResNet on CIFAR-10; All-CNN on CIFAR-10 (25% data).]
The question: Why is SGD so special?
A continuous-time view of SGD
‣ Diffusion matrix: the variance of the mini-batch gradients,
D(x) = (1/b) [ (1/N) Σ_{k=1}^{N} ∇f_k(x) ∇f_k(x)ᵀ − ∇f(x) ∇f(x)ᵀ ]
‣ Temperature: the ratio of learning rate to batch size,
β⁻¹ = η / (2b)
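A sketch of how D(x) and β⁻¹ could be estimated from per-sample gradients, mirroring the two formulas above; the toy least-squares loss stands in for a network.

```python
import numpy as np

rng = np.random.default_rng(0)

# Per-sample gradients of a toy least-squares loss: grad f_k(x) = a_k (a_k . x - y_k).
N, d = 1000, 10
A = rng.normal(size=(N, d))
y = A @ rng.normal(size=d) + 0.1 * rng.normal(size=N)

def diffusion_and_temperature(x, eta, b):
    """D(x) = (1/b) [ (1/N) sum_k grad f_k grad f_k^T - grad f grad f^T ],  beta^{-1} = eta/(2b)."""
    per_sample = A * (A @ x - y)[:, None]            # N x d matrix of grad f_k(x)
    g = per_sample.mean(axis=0)                      # full gradient grad f(x)
    second_moment = per_sample.T @ per_sample / N
    D = (second_moment - np.outer(g, g)) / b
    return D, eta / (2 * b)

D, beta_inv = diffusion_and_temperature(x=np.zeros(d), eta=0.1, b=32)
print(np.linalg.eigvalsh(D)[-3:], beta_inv)          # top eigenvalues of D, and the temperature
```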
A continuous-time view of SGD
‣ Continuous-time limit of the discrete-time updates:
dx = −∇f(x) dt + √(2 β⁻¹ D(x)) dW(t);
we will assume x ∈ Ω ⊂ ℝᵈ.
‣ The Fokker-Planck (FP) equation gives the distribution on the weight space induced by SGD:
ρ_t = div( ∇f ρ + β⁻¹ div(D ρ) ),  where x(t) ∼ ρ(t);
the first term is the drift, the second the diffusion.
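An Euler-Maruyama discretization of the SDE above, as a sketch; the quadratic loss, constant diffusion matrix, and temperature are illustrative choices, not quantities estimated from a network.

```python
import numpy as np

rng = np.random.default_rng(0)

# Euler-Maruyama discretization of  dx = -grad f(x) dt + sqrt(2 beta^{-1} D) dW(t).
grad_f = lambda x: x                             # f(x) = |x|^2 / 2
D = np.array([[1.0, 0.0], [0.0, 0.1]])           # constant, non-isotropic diffusion matrix
beta_inv = 0.01
L = np.linalg.cholesky(2.0 * beta_inv * D)       # L @ L.T = 2 beta^{-1} D

dt, T = 1e-2, 10_000
x = np.array([1.0, 1.0])
traj = np.empty((T, 2))
for t in range(T):
    dW = np.sqrt(dt) * rng.normal(size=2)        # Brownian increment
    x = x - grad_f(x) * dt + L @ dW
    traj[t] = x
# After a burn-in, the histogram of `traj` approximates the steady state rho_ss of the FP equation.
```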
Wasserstein gradient flow
‣ The heat equation ρ_t = div(I ∇ρ) performs steepest descent (in the L² metric) on the Dirichlet energy
(1/2) ∫_Ω |∇ρ(x)|² dx.
‣ It is also the steepest descent in the Wasserstein metric of the negative entropy −H(ρ) = ∫_Ω ρ log ρ dx: the minimizing-movement scheme
ρ^τ_{k+1} ∈ argmin_ρ { −H(ρ) + W₂²(ρ, ρ^τ_k) / (2τ) }
converges to trajectories of the heat equation as τ → 0.
‣ Negative entropy is a Lyapunov functional for Brownian motion:
ρ^ss_heat = argmin_ρ −H(ρ).
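A small numerical illustration of the last point: on a periodic 1-D grid, an explicit step of the heat equation is a doubly stochastic averaging of the density, so the negative entropy ∫ ρ log ρ decreases monotonically along the flow. The grid size and initial bump are arbitrary choices.

```python
import numpy as np

# Heat equation rho_t = lap(rho) on a periodic 1-D grid, explicit finite differences.
n, length = 200, 1.0
dx = length / n
dt = 0.2 * dx ** 2                               # stable explicit step (dt <= dx^2 / 2)
x = np.linspace(0.0, length, n, endpoint=False)

rho = np.exp(-200.0 * (x - 0.3) ** 2)            # sharp initial bump
rho /= rho.sum() * dx                            # normalize to a probability density

neg_entropy = []
for _ in range(5000):
    lap = (np.roll(rho, 1) - 2.0 * rho + np.roll(rho, -1)) / dx ** 2
    rho = rho + dt * lap                         # one explicit heat-equation step
    neg_entropy.append(np.sum(rho * np.log(rho + 1e-300)) * dx)

# -H(rho) = integral of rho log rho decreases monotonically along the flow.
assert all(a >= b - 1e-12 for a, b in zip(neg_entropy, neg_entropy[1:]))
```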
Wasserstein gradient flow: with drift
‣ If D = I, the Fokker-Planck equation
ρ_t = div( ∇f ρ + β⁻¹ ∇ρ )
has the Jordan-Kinderlehrer-Otto (JKO) functional [Jordan et al., '97]
F(ρ) = E_{x∼ρ}[ f(x) ] − β⁻¹ H(ρ)
(energetic term minus entropic term) as its Lyapunov functional, with steady state ρ_ss = argmin_ρ F(ρ).
‣ FP is the steepest descent on the JKO functional in the Wasserstein metric.
What happens for non-isotropic noise?
ρ_t = div( ∇f ρ + β⁻¹ div(D ρ) )  (drift + diffusion)
‣ FP monotonically minimizes the free energy
F(ρ) = E_{x∼ρ}[ Φ(x) ] − β⁻¹ H(ρ),  with ρ_ss = argmin_ρ F(ρ).
‣ Rewrite it as F(ρ) = β⁻¹ KL(ρ ‖ ρ_ss); compare with the distance |x − x*| for deterministic optimization.
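A quick numerical check of the rewrite, on a discrete state space with an arbitrary potential Φ and inverse temperature β: the free energy E_ρ[Φ] − β⁻¹H(ρ) equals β⁻¹ KL(ρ ‖ ρ_ss) up to the constant −β⁻¹ log Z.

```python
import numpy as np

rng = np.random.default_rng(0)

# Discrete state space with an arbitrary potential Phi and inverse temperature beta.
n, beta = 50, 2.0
Phi = rng.uniform(0.0, 3.0, size=n)
Z = np.exp(-beta * Phi).sum()
rho_ss = np.exp(-beta * Phi) / Z                 # Gibbs steady state

def free_energy(rho):
    """F(rho) = E_rho[Phi] - beta^{-1} H(rho), with H(rho) = -sum rho log rho."""
    return np.sum(rho * Phi) + (1.0 / beta) * np.sum(rho * np.log(rho))

def kl(p, q):
    return np.sum(p * np.log(p / q))

rho = rng.uniform(size=n)
rho /= rho.sum()
assert np.isclose(free_energy(rho),
                  (1.0 / beta) * kl(rho, rho_ss) - (1.0 / beta) * np.log(Z))
```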
SGD performs variational inference
Theorem [Chaudhari & Soatto, ICLR '18]: The functional
F(ρ) = β⁻¹ KL(ρ ‖ ρ_ss)
is minimized monotonically by trajectories of the Fokker-Planck equation
ρ_t = div( ∇f ρ + β⁻¹ div(D ρ) )
with ρ_ss as the steady-state distribution. Moreover,
Φ = −β⁻¹ log ρ_ss
up to a constant.
Some implications
‣ The learning rate should scale linearly with the batch size: β⁻¹ = η/(2b) should not be small.
‣ Sampling with replacement regularizes better than without:
β⁻¹_{w/o replacement} = (η/(2b)) (1 − b/N)
is smaller, and sampling with replacement also generalizes better.
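A tiny helper that mirrors the two formulas above; the numbers plugged in are illustrative.

```python
def temperature(eta, b, N=None, replacement=True):
    """beta^{-1} = eta / (2b) with replacement; multiplied by (1 - b/N) without."""
    beta_inv = eta / (2.0 * b)
    if not replacement:
        beta_inv *= 1.0 - b / N
    return beta_inv

# Linear scaling: doubling the batch size keeps the temperature fixed only if
# the learning rate is doubled as well.
print(temperature(eta=0.1, b=128), temperature(eta=0.2, b=256))
# Sampling without replacement runs slightly colder (less regularization).
print(temperature(eta=0.1, b=128, N=50_000, replacement=False))
```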
Information Bottleneck Principle
‣ Minimize the mutual information of the representation with the training data [Tishby '99, Achille & Soatto '17]:
IB_β(θ) = E_{x∼ρ_θ}[ f(x) ] − β⁻¹ KL( ρ_θ ‖ prior )
‣ Minimizing these functionals is hard; SGD does it naturally.
Potential Φ vs. original loss f
‣ The solution of the variational problem is
ρ_ss(x) = (1/Z_β) e^{−β Φ(x)}.
‣ Key point: the most likely locations of SGD are not the critical points of the original loss,
ρ_ss(x) ≠ (1/Z′_β) e^{−β f(x)}.
‣ The two losses are equal if and only if the noise is isotropic:
D(x) = I ⇔ Φ(x) = f(x).
Deep networks have highly non-isotropic noise
CIFAR-10:  λ(D) = 0.27 ± 0.84,  rank(D) = 0.34%
CIFAR-100: λ(D) = 0.98 ± 2.16,  rank(D) = 0.47%
‣ Evaluate neural architectures using the diffusion matrix
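A sketch of the summary statistics on this slide, computed for a synthetic diffusion matrix whose per-sample gradients are confined to a low-dimensional subspace; the low-rank construction is an assumption made here only to mimic what is observed for real networks.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic per-sample gradients confined to a 5-dimensional subspace of a d-dimensional
# weight space, so that the resulting D is low rank.
N, d, b = 500, 100, 32
per_sample = rng.normal(size=(N, 5)) @ rng.normal(size=(5, d))
g = per_sample.mean(axis=0)
D = (per_sample.T @ per_sample / N - np.outer(g, g)) / b

eig = np.linalg.eigvalsh(D)
nonzero = eig > 1e-10 * eig.max()
print(f"lambda(D) = {eig.mean():.3f} +/- {eig.std():.3f}")     # mean and spread of eigenvalues
print(f"rank(D)  = {100.0 * nonzero.mean():.2f} % of d")       # fraction of non-zero eigenvalues
```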
How different are cats and dogs, really?
SGD converges to limit cycles
Theorem [Chaudhari & Soatto, ICLR '18]: The most likely trajectories of SGD are
ẋ = j(x),
where the "leftover" vector field
j(x) = −∇f(x) + D(x) ∇Φ(x) − β⁻¹ div D(x)
is such that div j(x) = 0.
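An illustrative construction in the spirit of the theorem, not the j(x) of a real network: any field of the form j = Q ∇Φ with Q antisymmetric is divergence-free, and the most likely dynamics ẋ = j(x) then orbits the level sets of Φ instead of converging to a critical point. Φ and Q below are toy choices.

```python
import numpy as np

# j(x) = Q grad Phi(x) with Q antisymmetric is divergence-free, and the dynamics
# x' = j(x) conserves Phi, so trajectories orbit its level sets rather than converge.
Q = np.array([[0.0, 1.0], [-1.0, 0.0]])                    # antisymmetric
Phi = lambda x: x[0] ** 2 + 0.25 * x[1] ** 2
grad_Phi = lambda x: np.array([2.0 * x[0], 0.5 * x[1]])

dt, T = 1e-3, 20_000
x = np.array([1.0, 0.0])
levels = []
for _ in range(T):
    x = x + dt * (Q @ grad_Phi(x))                         # forward-Euler step of x' = j(x)
    levels.append(Phi(x))

print(min(levels), max(levels))   # Phi stays (approximately) constant: a closed cycle, not a minimum
```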
Trajectories of SGD
‣ Run SGD for 10⁵ epochs
[Figure: FFT of x^i_{k+1} − x^i_k]
An example force-field
[Diagram of the force field around a saddle point of Φ (∇Φ(x) = 0), with annotations marking where ‖j(x)‖ is small, where j(x) = 0, and where ‖j(x)‖ is very large.]
Most likely locations are not the critical points of the original loss
Theorem [Chaudhari & Soatto, ICLR '18]: The Itô SDE
dx = −∇f dt + √(2β⁻¹ D) dW(t)
is equivalent to an A-type SDE
dx = −(D + Q) ∇Φ dt + √(2β⁻¹ D) dW(t)
with the same steady state ρ_ss ∝ e^{−β Φ(x)}, provided
∇f = (D + Q) ∇Φ − β⁻¹ div(D + Q).
Knots in our understanding: ARCHITECTURE, OPTIMIZATION, GENERALIZATION
Punchline: Is SGD special?
Stochastic gradient descent performs variational inference, converges to limit cycles for deep networks. Pratik Chaudhari and Stefano Soatto. arXiv:1710.11029, ICLR '18.
www.pratikac.info
Thank you, questions?