Gradient Descent Finds Global Minima of Deep Neural Networks
Simon S. Du, Jason D. Lee, Haochuan Li, Liwei Wang, Xiyu Zhai
Empirical Observations on Empirical Risk
• Zhang et al., 2017, Understanding Deep Learning Requires Rethinking Generalization.
• Randomization test: replace the true labels with random labels.
• Observation: the empirical risk goes to 0 for both true labels and random labels.
• Conjecture: this happens because neural networks are over-parameterized.
• Open problem: why can gradient descent find a neural network that fits all labels?
Setup
• Training data: $\{x_i, y_i\}_{i=1}^{n}$, $x_i \in \mathbb{R}^d$, $y_i \in \mathbb{R}$.
• A model. Fully connected neural network:
  $f(\theta, x) = W_L\, \sigma(W_{L-1} \cdots W_2\, \sigma(W_1 x) \cdots)$
• A loss function. Quadratic loss:
  $R(\theta) = \frac{1}{2n} \sum_{i=1}^{n} \left( f(\theta, x_i) - y_i \right)^2$
• An optimization algorithm. Gradient descent:
  $\theta(t+1) \leftarrow \theta(t) - \eta\, \frac{\partial R(\theta(t))}{\partial \theta(t)}$
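The setup above can be made concrete with a short sketch. Below is a minimal JAX version assuming tanh as the smooth activation, a width-$m$ Gaussian initialization, and scalar outputs; the names `init_params`, `f`, `risk`, and `gd_step` are illustrative and not taken from the paper's code.

```python
# Minimal sketch of the setup in JAX (illustrative, not the paper's code).
import jax
import jax.numpy as jnp

def init_params(key, d, m, L):
    """Gaussian initialization of an L-layer fully connected network."""
    dims = [d] + [m] * (L - 1) + [1]          # input -> (L-1) hidden layers -> scalar output
    keys = jax.random.split(key, L)
    return [jax.random.normal(k, (dims[i + 1], dims[i])) / jnp.sqrt(dims[i])
            for i, k in enumerate(keys)]

def f(params, x):
    """f(theta, x) = W_L sigma(W_{L-1} ... sigma(W_1 x) ...) with a smooth activation."""
    h = x
    for W in params[:-1]:
        h = jnp.tanh(W @ h)                   # tanh as a stand-in smooth activation
    return (params[-1] @ h)[0]

def risk(params, X, y):
    """Quadratic loss R(theta) = 1/(2n) sum_i (f(theta, x_i) - y_i)^2."""
    u = jax.vmap(lambda x: f(params, x))(X)
    return 0.5 * jnp.mean((u - y) ** 2)

@jax.jit
def gd_step(params, X, y, eta):
    """One gradient descent step: theta <- theta - eta * dR(theta)/dtheta."""
    grads = jax.grad(risk)(params, X, y)
    return [W - eta * g for W, g in zip(params, grads)]
```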
Trajectory-based Analysis
$\theta(t+1) \leftarrow \theta(t) - \eta\, \frac{\partial R(\theta(t))}{\partial \theta(t)}$
• Trajectory of parameters: $\theta(0), \theta(1), \theta(2), \ldots$
• Predictions: $u_i(t) \triangleq f(\theta(t), x_i)$, $u(t) \triangleq (u_1(t), \ldots, u_n(t))^{\top} \in \mathbb{R}^n$
• Trajectory of predictions: $u(0), u(1), u(2), \ldots$
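Building on the setup sketch above, the trajectory-based view only needs the sequence of prediction vectors $u(0), u(1), \ldots$ recorded along the gradient descent path; `predictions` and `run_gd` are hypothetical helper names.

```python
# Record the prediction trajectory u(0), u(1), ... during gradient descent
# (reuses f and gd_step from the setup sketch above).
def predictions(params, X):
    """u(t) = (f(theta(t), x_1), ..., f(theta(t), x_n))^T in R^n."""
    return jax.vmap(lambda x: f(params, x))(X)

def run_gd(params, X, y, eta, steps):
    """Run gradient descent and keep the trajectory of predictions."""
    traj_u = [predictions(params, X)]
    for _ in range(steps):
        params = gd_step(params, X, y, eta)
        traj_u.append(predictions(params, X))
    return params, traj_u                      # the analysis tracks traj_u, not the parameters
```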
Proof Sketch
• Simplified form (continuous time):
  $\frac{du(t)}{dt} = \sum_{\ell=1}^{L} H^{\ell}(t)\,\big(y - u(t)\big), \qquad H^{\ell}_{ij}(t) = \frac{1}{n}\left\langle \frac{\partial u_i(t)}{\partial W_\ell(t)},\, \frac{\partial u_j(t)}{\partial W_\ell(t)} \right\rangle$
• Random initialization + concentration + perturbation analysis:
  $\lim_{m \to \infty} \sum_{\ell=1}^{L} H^{\ell}(0) = H^{\infty}, \qquad \lim_{m \to \infty} \sum_{\ell=1}^{L} H^{\ell}(t) = \sum_{\ell=1}^{L} H^{\ell}(0) \quad \forall\, t \ge 0$
• Linear ODE theory:
  $\| u(t) - y \|_2^2 \le \exp(-\lambda_0 t)\, \| u(0) - y \|_2^2, \qquad \lambda_0 = \lambda_{\min}(H^{\infty})$
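At finite width the Gram matrices $H^{\ell}(t)$ can be formed directly from per-example Jacobians, as in the following sketch. It reuses `predictions` from the previous block; `gram_matrices` is an illustrative name, and the closing comment restates the linear ODE bound rather than proving it.

```python
# Layer-wise Gram matrices H^l(t) built from per-example Jacobians
# (a finite-width, discrete-time stand-in for the continuous-time argument).
def gram_matrices(params, X):
    """H^l_ij = (1/n) <du_i/dW_l, du_j/dW_l>; returns one n x n matrix per layer."""
    n = X.shape[0]
    jac = jax.jacobian(lambda p: predictions(p, X))(params)   # list: one Jacobian per W_l
    Hs = []
    for J in jac:                                  # J has shape (n, *W_l.shape)
        Jf = J.reshape(n, -1)
        Hs.append(Jf @ Jf.T / n)
    return Hs

# If sum_l H^l(t) stays close to a fixed PSD matrix H_inf with smallest eigenvalue
# lambda_0 > 0, the linear ODE du/dt = H(t)(y - u) gives
#     ||u(t) - y||_2^2 <= exp(-lambda_0 * t) * ||u(0) - y||_2^2.
```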
Main Results
Theorem 1. For a fully connected neural network with smooth activation, if the width satisfies $m = \mathrm{poly}(n, 2^{O(L)}, 1/\lambda_0)$ and the step size satisfies $\eta = O\!\left(\frac{\lambda_0}{n^2\, 2^{O(L)}}\right)$, then with high probability over the random initialization, for $k = 1, 2, \ldots$
  $R(\theta(k)) \le \left(1 - \eta \lambda_0\right)^{k} R(\theta(0)).$
• First global linear convergence guarantee for deep neural networks.
• The exponential dependence on the depth comes from error propagation across layers.
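As a sanity check (not part of the paper), the linear rate in Theorem 1 can be compared numerically against the actual risk along the trajectory, using $\lambda_{\min}$ of the initial Gram matrix $\sum_{\ell} H^{\ell}(0)$ as a stand-in for $\lambda_0 = \lambda_{\min}(H^{\infty})$; all sizes and the step size below are arbitrary choices for illustration.

```python
# Hypothetical numerical check of the linear-rate bound (arbitrary sizes),
# using lambda_min of sum_l H^l(0) as a proxy for lambda_0 = lambda_min(H_inf).
key = jax.random.PRNGKey(0)
d, m, L, n, eta, steps = 5, 512, 3, 20, 0.5, 200
kx, ky, kp = jax.random.split(key, 3)
X = jax.random.normal(kx, (n, d))
X = X / jnp.linalg.norm(X, axis=1, keepdims=True)   # unit-norm inputs
y = jax.random.normal(ky, (n,))

params0 = init_params(kp, d, m, L)
lam0 = float(jnp.linalg.eigvalsh(sum(gram_matrices(params0, X)))[0])

R0 = float(risk(params0, X, y))
params_k, _ = run_gd(params0, X, y, eta, steps)
Rk = float(risk(params_k, X, y))
bound = (1.0 - eta * lam0) ** steps * R0
print(f"R(theta(k)) = {Rk:.3e}   vs. rate bound (1 - eta*lam0)^k R(theta(0)) = {bound:.3e}")
```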
Main Results (Cont'd)
Theorem 2. For a ResNet or convolutional ResNet with smooth activation, if the width satisfies $m = \mathrm{poly}(n, L, 1/\lambda_0)$ and the step size satisfies $\eta = O\!\left(\frac{\lambda_0}{\mathrm{poly}(n, L)}\right)$, then with high probability over the random initialization, for $k = 1, 2, \ldots$
  $R(\theta(k)) \le \left(1 - \eta \lambda_0\right)^{k} R(\theta(0)).$
• The ResNet architecture makes the error propagation more stable, giving an exponential improvement over fully connected networks: the required width and step size depend only polynomially on the depth $L$.
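For contrast with the fully connected forward pass in the setup sketch, a generic residual block looks as follows. This is a schematic residual parameterization, not necessarily the exact one analyzed in the paper, but it shows why the skip connection keeps each layer's output close to its input.

```python
# Schematic residual-block forward pass (not necessarily the paper's exact
# parameterization): the identity skip connection keeps each layer's output
# close to its input, which is what tames the depth-wise error propagation.
def f_res(params, x, scale=0.1):
    """params has the same shapes as in init_params: input layer, residual blocks, output layer."""
    h = params[0] @ x                       # input layer: R^d -> R^m
    for W in params[1:-1]:
        h = h + scale * jnp.tanh(W @ h)     # residual block with a smooth activation
    return (params[-1] @ h)[0]              # output layer: R^m -> R
```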
Learn more @ Pacific Ballroom #80