Unraveling the mysteries of stochastic gradient descent on deep neural networks
Pratik Chaudhari, UCLA Vision Lab
The question
f measures disagreement of predictions (Cat, Dog, ...) with the ground truth:
  x* = argmin_x f(x),   where x are the weights, aka parameters
Stochastic gradient descent (a toy sketch follows after this slide):
  x_{k+1} = x_k − η ∇f_b(x_k)
Many, many variants: AdaGrad, rmsprop, Adam, SAG, SVRG, Catalyst, APPA, Natasha, Katyusha...
Why is SGD so special?
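To make the update rule concrete, here is a minimal numpy sketch (my own toy example, not code from the talk): the least-squares loss, batch size, and learning rate are illustrative choices only.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 1000, 5
A = rng.normal(size=(N, d))                               # toy inputs
y = A @ rng.normal(size=d) + 0.1 * rng.normal(size=N)     # toy targets

def minibatch_grad(x, b=32):
    """Gradient of the mini-batch loss f_b(x) = (1/2b) ||A_b x - y_b||^2."""
    idx = rng.choice(N, size=b, replace=False)
    Ab, yb = A[idx], y[idx]
    return Ab.T @ (Ab @ x - yb) / b

x = np.zeros(d)
eta = 0.05                                 # learning rate
for k in range(2000):
    x = x - eta * minibatch_grad(x)        # x_{k+1} = x_k - eta * grad f_b(x_k)

print(x)                                   # close to the least-squares solution
```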
Empirical evidence: wide "minima"
[Histograms of the eigenvalues of the Hessian at the end of training (frequency vs. eigenvalues): most eigenvalues concentrate near zero, with a short negative tail.]
A bit of statistical physics
‣ Energy landscape of a binary perceptron: many sharp minima; few wide minima, but they generalize better [Baldassi et al., '15]
‣ Wide minima are a large-deviations phenomenon
Tilting the Gibbs measure
‣ Local Entropy [Chaudhari et al., ICLR '17]
  x* = argmin_x f(x) = argmax_x e^{−f(x)}
     ≈ argmin_x −log( G_γ ∗ e^{−f(x)} ),   where G_γ is a Gaussian kernel of variance γ
  (a 1-D illustration follows below)
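A 1-D illustration of the local-entropy objective (my own toy construction; the loss and the value of γ are made up for illustration): convolving e^{−f} with a Gaussian kernel and taking −log penalizes a narrow, sharp minimum relative to a wide one.

```python
import numpy as np

xs = np.linspace(-4, 4, 2001)
dx = xs[1] - xs[0]
# Toy loss: a sharp minimum near x = -1 and a wide minimum near x = +1.
f = -1.2 * np.exp(-((xs + 1) ** 2) / 0.01) - 1.0 * np.exp(-((xs - 1) ** 2) / 1.0)

gamma = 0.25                                   # variance of the Gaussian kernel G_gamma
kernel = np.exp(-xs ** 2 / (2 * gamma))
kernel /= kernel.sum() * dx                    # normalize so the kernel integrates to 1
smoothed = np.convolve(np.exp(-f), kernel, mode="same") * dx
f_gamma = -np.log(smoothed)                    # local-entropy loss -log(G_gamma * exp(-f))

print("original minimum:     ", xs[np.argmin(f)])        # near the sharp minimum at -1
print("local-entropy minimum:", xs[np.argmin(f_gamma)])  # shifts to the wide minimum at +1
```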
Parle: parallelization of SGD
‣ State-of-the-art performance [Chaudhari et al., SysML '18]
[Plots: Wide-ResNet on CIFAR-10; All-CNN on CIFAR-10 (25% data)]
The question
Why is SGD so special?
A continuous-time view of SGD
‣ Diffusion matrix: the variance of the mini-batch gradient,
  D(x) = var( ∇f_b(x) ) = (1/b) [ (1/N) Σ_{k=1}^{N} ∇f_k(x) ∇f_k(x)^T − ∇f(x) ∇f(x)^T ]
‣ Temperature: the ratio of learning rate to batch size,
  β^{−1} = η / (2b)
(a sketch of estimating D and β^{−1} follows below)
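A hedged sketch of how D(x) and the temperature could be estimated on the same kind of toy least-squares problem as above (the data, batch size, and learning rate are illustrative assumptions, not from the talk):

```python
import numpy as np

rng = np.random.default_rng(1)
N, d, b = 1000, 5, 32
A = rng.normal(size=(N, d))
y = A @ rng.normal(size=d) + 0.1 * rng.normal(size=N)
x = rng.normal(size=d)                        # weights at which D(x) is evaluated

# Row k is grad f_k(x) for the per-sample loss f_k(x) = 0.5 * (a_k . x - y_k)^2.
per_sample_grads = A * (A @ x - y)[:, None]   # shape (N, d)
g = per_sample_grads.mean(axis=0)             # full gradient grad f(x)
second_moment = per_sample_grads.T @ per_sample_grads / N
D = (second_moment - np.outer(g, g)) / b      # diffusion matrix of mini-batch SGD

eta = 0.1
beta_inv = eta / (2 * b)                      # temperature beta^{-1} = eta / (2b)
print("temperature:", beta_inv)
print("eigenvalues of D:", np.linalg.eigvalsh(D))
```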
A continuous-time view of SGD
‣ Continuous-time limit of the discrete-time updates (a simulation sketch follows below):
  dx = −∇f(x) dt + √(2 β^{−1} D(x)) dW(t),   assuming x ∈ Ω ⊂ ℝ^d
‣ The Fokker-Planck (FP) equation gives the distribution on the weight space induced by SGD:
  ρ_t = div( ∇f ρ + β^{−1} div(D ρ) ),   where x(t) ∼ ρ(t)
  (first term: drift; second term: diffusion)
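A minimal Euler-Maruyama simulation of this SDE (my own sketch, not the paper's code), on a 2-D quadratic loss f(x) = 0.5 x^T H x with a fixed, hand-picked non-isotropic D:

```python
import numpy as np

rng = np.random.default_rng(2)
H = np.diag([1.0, 0.2])                    # curvature of the toy loss, grad f(x) = H x
D = np.array([[1.0, 0.3], [0.3, 0.1]])     # assumed (non-isotropic) diffusion matrix
beta_inv = 0.05                            # temperature
sqrtD = np.linalg.cholesky(2 * beta_inv * D)

dt, steps = 1e-3, 200_000
x = np.array([2.0, 2.0])
traj = np.empty((steps, 2))
for t in range(steps):
    drift = -H @ x                                             # -grad f(x)
    x = x + drift * dt + sqrtD @ rng.normal(size=2) * np.sqrt(dt)
    traj[t] = x

# Empirical steady-state covariance induced by the SGD-like dynamics.
print(np.cov(traj[steps // 2:].T))
```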
Wasserstein gradient flow
‣ The heat equation ρ_t = div(I ∇ρ) performs steepest descent on the Dirichlet energy (1/2) ∫_Ω |∇ρ(x)|² dx
‣ It is also the steepest descent of the negative entropy −H(ρ) = ∫_Ω ρ log ρ dx in the Wasserstein metric: the scheme
  ρ^τ_{k+1} ∈ argmin_ρ [ −H(ρ) + W_2²(ρ, ρ^τ_k) / (2τ) ]
  converges to trajectories of the heat equation
‣ Negative entropy is a Lyapunov functional for Brownian motion: ρ^ss_heat = argmin_ρ −H(ρ)
Wasserstein gradient flow: with drift
‣ If D = I, the Fokker-Planck equation
  ρ_t = div( ∇f ρ + β^{−1} ∇ρ )
  has the Jordan-Kinderlehrer-Otto (JKO) functional [Jordan et al., '97]
  F(ρ) = E_{x∼ρ}[ f(x) ] − β^{−1} H(ρ)   (energetic term; entropic term)
  as its Lyapunov functional, with steady state ρ^ss = argmin_ρ F(ρ) (a short derivation follows below)
‣ FP is the steepest descent on the JKO functional in the Wasserstein metric
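As a sanity check (a standard calculation, not spelled out on the slide), the minimizer of the JKO functional over probability densities is the Gibbs distribution, which is exactly the ρ^ss that appears later with Φ = f when D = I:

```latex
% Minimize F(rho) = E_{x~rho}[f(x)] - beta^{-1} H(rho) over densities integrating to 1,
% with H(rho) = -\int_\Omega rho log rho dx, using a Lagrange multiplier lambda.
\[
\begin{aligned}
0 &= \frac{\delta}{\delta \rho}\left[ \int_\Omega f\rho \,dx
      + \beta^{-1}\int_\Omega \rho\log\rho \,dx
      + \lambda\left(\int_\Omega \rho\,dx - 1\right)\right]
   = f(x) + \beta^{-1}\bigl(\log\rho(x) + 1\bigr) + \lambda \\
\Rightarrow\ \rho^{\mathrm{ss}}(x) &= \frac{1}{Z_\beta}\, e^{-\beta f(x)},
\qquad Z_\beta = \int_\Omega e^{-\beta f(x)}\,dx .
\end{aligned}
\]
```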
What happens for non-isotropic noise?
  ρ_t = div( ∇f ρ + β^{−1} div(D ρ) )   (drift; diffusion)
‣ FP monotonically minimizes the free energy
  F(ρ) = E_{x∼ρ}[ Φ(x) ] − β^{−1} H(ρ)
‣ Rewrite as F(ρ) = β^{−1} KL(ρ ‖ ρ^ss); compare with |x − x*| for deterministic optimization
SGD performs variational inference
Theorem [Chaudhari & Soatto, ICLR '18]: The functional
  F(ρ) = β^{−1} KL(ρ ‖ ρ^ss)
is minimized monotonically by trajectories of the Fokker-Planck equation
  ρ_t = div( ∇f ρ + β^{−1} div(D ρ) )
with ρ^ss as the steady-state distribution. Moreover, Φ = −β^{−1} log ρ^ss up to a constant.
Some implications
‣ The learning rate should scale linearly with the batch size: the temperature β^{−1} = η / (2b) should not be small
‣ Sampling with replacement regularizes better than sampling without:
  β^{−1}_{w/o replacement} = (η / (2b)) (1 − b/N),
  and it also generalizes better
(a numerical illustration follows below)
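A small numerical illustration (the values are my own, not from the talk): keeping the temperature fixed requires scaling the learning rate linearly with the batch size, and sampling without replacement lowers the temperature by the factor (1 − b/N).

```python
def temperature(eta, b, N=None, replacement=True):
    """Temperature beta^{-1} = eta / (2b), with the without-replacement correction."""
    beta_inv = eta / (2 * b)
    if not replacement:
        beta_inv *= 1 - b / N
    return beta_inv

print(temperature(0.1, 128))                                 # baseline
print(temperature(0.2, 256))                                 # same temperature: eta doubled with b
print(temperature(0.1, 128, N=50_000, replacement=False))    # slightly colder without replacement
```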
Information Bottleneck Principle
‣ Minimize the mutual information of the representation with the training data [Tishby '99, Achille & Soatto '17]:
  IB_β(θ) = E_{x∼ρ_θ}[ f(x) ] − β^{−1} KL( ρ_θ ‖ prior )
‣ Minimizing these functionals is hard; SGD does it naturally
Potential Φ vs. original loss f
‣ The solution of the variational problem is
  ρ^ss(x) = (1/Z_β) e^{−β Φ(x)}
‣ Key point: the most likely locations of SGD are not the critical points of the original loss,
  ρ^ss(x) ≠ (1/Z'_β) e^{−β f(x)}
‣ The two losses are equal if and only if the noise is isotropic:
  D(x) = I  ⇔  Φ(x) = f(x)
Deep networks have highly non-isotropic noise
  CIFAR-10:  λ(D) = 0.27 ± 0.84,  rank(D) = 0.34%
  CIFAR-100: λ(D) = 0.98 ± 2.16,  rank(D) = 0.47%
‣ Evaluate neural architectures using the diffusion matrix
How different are cats and dogs, really?
SGD converges to limit cycles
Theorem [Chaudhari & Soatto, ICLR '18]: The most likely trajectories of SGD are
  ẋ = j(x),
where the "leftover" vector field
  j(x) = −∇f(x) + D(x) ∇Φ(x) − β^{−1} div D(x)
is such that div j(x) = 0.
Trajectories of SGD
‣ Run SGD for 10^5 epochs
[Plot: FFT of the increments x^i_{k+1} − x^i_k]
(a sketch of this diagnostic follows below)
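A sketch of how one might reproduce this diagnostic (my own construction; the helper `fft_of_increments` and the synthetic trajectory are hypothetical): record one weight coordinate once per epoch and take the FFT of its increments. A limit cycle shows up as a peak at a non-zero frequency.

```python
import numpy as np

def fft_of_increments(trajectory):
    """trajectory: 1-D array of one weight coordinate recorded once per epoch."""
    increments = np.diff(trajectory)
    spectrum = np.abs(np.fft.rfft(increments - increments.mean()))
    freqs = np.fft.rfftfreq(len(increments), d=1.0)   # cycles per epoch
    return freqs, spectrum

# Synthetic check over 10^5 epochs: a noisy cycle with period 50 epochs peaks at f = 0.02.
epochs = np.arange(100_000)
rng = np.random.default_rng(3)
toy_traj = 0.1 * np.sin(2 * np.pi * epochs / 50) + 0.01 * rng.normal(size=epochs.size)
freqs, spectrum = fft_of_increments(toy_traj)
print(freqs[np.argmax(spectrum[1:]) + 1])             # ~0.02 cycles/epoch
```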
An example force field
[Diagram of a saddle point, with regions labeled ∇Φ(x) = 0, j(x) = 0, ‖j(x)‖ small, and ‖j(x)‖ very large]
Most likely locations are not the critical points of the original loss
Theorem [Chaudhari & Soatto, ICLR '18]: The Itô SDE
  dx = −∇f dt + √(2 β^{−1} D) dW(t)
is equivalent to an A-type SDE
  dx = −(D + Q) ∇Φ dt + √(2 β^{−1} D) dW(t)
with the same steady state ρ^ss ∝ e^{−β Φ(x)} if
  ∇f = (D + Q) ∇Φ − β^{−1} div(D + Q).
Knots in our understanding
ARCHITECTURE · OPTIMIZATION · GENERALIZATION
Punchline
Is SGD special?
Stochastic gradient descent performs variational inference, converges to limit cycles for deep networks. Pratik Chaudhari and Stefano Soatto, ICLR '18, arXiv:1710.11029.
www.pratikac.info
Thank you, questions?