
Fine-Grained Analysis of Optimization and Generalization for Overparameterized Two-Layer NNs - PowerPoint PPT Presentation



  1. Fine-Grained Analysis of Optimization and Generalization for Overparameterized Two-Layer NNs. Sanjeev Arora (Princeton & IAS), Simon S. Du (CMU), Wei Hu (Princeton), Zhiyuan Li (Princeton), Ruosong Wang (CMU).

  2. “Rethinking generalization” experiment [Zhang et al. ’17]. [Figure: sample training images shown once with their true labels (2, 1, 3, 1, 4) and once with randomly assigned labels (5, 1, 7, 0, 8).]

  3. “Rethinking generalization” experiment [Zhang et al. ’17]. Unexplained phenomena:
     ① SGD achieves nearly 0 training loss for both correct and random labels (overparametrization!)
     ② Good generalization with correct labels
     ③ Faster convergence with correct labels than with random labels

  4. “Rethinking generalization” experiment [Zhang et al. ’17]. Unexplained phenomena:
     ① SGD achieves nearly 0 training loss for both correct and random labels (overparametrization!)
     ② Good generalization with correct labels
     ③ Faster convergence with correct labels than with random labels
     No good explanation in existing generalization theory:
     generalization gap ≤ (model complexity) / (# training samples)

  5. “Rethinking generalization” experiment [Zhang et al. ’17]. Unexplained phenomena:
     ① SGD achieves nearly 0 training loss for both correct and random labels (overparametrization!)
     ② Good generalization with correct labels
     ③ Faster convergence with correct labels than with random labels
     No good explanation in existing generalization theory:
     generalization gap ≤ (model complexity) / (# training samples)
     This paper: a theoretical explanation for overparametrized 2-layer nets using properties of the labels.

  6. Setting: overparameterized two-layer ReLU neural nets.
     Unexplained phenomena (recap): ① SGD achieves nearly 0 training loss for both correct and random labels (overparametrization!); ② good generalization with correct labels; ③ faster convergence with correct labels.
     [Diagram: two-layer ReLU network computing the output g(W, x) from input x through a wide hidden layer with first-layer weights W.]
     • Overparametrization: the number of hidden nodes is large
     • Training objective: ℓ₂ loss, binary classification
     • Initialization: i.i.d. Gaussian
     • Optimization algorithm: gradient descent (GD) on the first layer

  7. Setting: overparameterized two-layer ReLU neural nets.
     Unexplained phenomena (recap): ① SGD achieves nearly 0 training loss for both correct and random labels (overparametrization!); ② good generalization with correct labels; ③ faster convergence with correct labels.
     [Du et al., ICLR ’19]: GD converges to 0 training loss. This explains phenomenon ①, but not ② or ③.
     [Diagram: two-layer ReLU network computing the output g(W, x) from input x through a wide hidden layer with first-layer weights W.]
     • Overparametrization: the number of hidden nodes is large
     • Training objective: ℓ₂ loss, binary classification
     • Initialization: i.i.d. Gaussian
     • Optimization algorithm: gradient descent (GD) on the first layer

  8. Setting: overparameterized two-layer ReLU neural nets.
     Unexplained phenomena (recap): ① SGD achieves nearly 0 training loss for both correct and random labels (overparametrization!); ② good generalization with correct labels; ③ faster convergence with correct labels.
     [Du et al., ICLR ’19]: GD converges to 0 training loss. This explains phenomenon ①, but not ② or ③.
     This paper, addressing ② and ③:
     • Faster convergence with true labels
     • A data-dependent generalization bound (which distinguishes random labels from true labels)
     [Diagram: two-layer ReLU network computing the output g(W, x) from input x through a wide hidden layer with first-layer weights W.]
     • Overparametrization: the number of hidden nodes is large
     • Training objective: ℓ₂ loss, binary classification
     • Initialization: i.i.d. Gaussian
     • Optimization algorithm: gradient descent (GD) on the first layer
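
A minimal NumPy sketch may help make this setting concrete (an illustration, not the authors' code): the network is g(W, x) = (1/√m) Σ_r a_r · ReLU(w_rᵀ x) with i.i.d. Gaussian first-layer weights W, second-layer signs a_r fixed at random ±1, and gradient descent on W only under the ℓ₂ loss. The function names and default hyperparameters below are assumptions made for the sketch.

```python
# Sketch of the overparameterized two-layer ReLU setting (illustrative, not the paper's code).
import numpy as np

def init_net(d, m, rng):
    """i.i.d. Gaussian first layer W (m x d); fixed random +/-1 second-layer signs a."""
    W = rng.standard_normal((m, d))
    a = rng.choice([-1.0, 1.0], size=m)
    return W, a

def forward(W, a, X):
    """g(W, x) = (1/sqrt(m)) * sum_r a_r * ReLU(w_r . x), applied to every row x of X."""
    m = W.shape[0]
    return (np.maximum(X @ W.T, 0.0) @ a) / np.sqrt(m)

def train_gd(X, y, m=4096, lr=1.0, steps=500, seed=0):
    """Gradient descent on the first layer only, minimizing the mean squared (l2) loss."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    W, a = init_net(d, m, rng)
    for _ in range(steps):
        pre = X @ W.T                                         # (n, m) pre-activations
        resid = (np.maximum(pre, 0.0) @ a) / np.sqrt(m) - y   # predictions minus labels
        # dL/dw_r = (a_r / sqrt(m)) * (1/n) * sum_i resid_i * 1{w_r . x_i > 0} * x_i
        grad = ((pre > 0).astype(float) * resid[:, None]).T @ X
        W -= (lr / n) * (a[:, None] / np.sqrt(m)) * grad
    return W, a
```

Training the same inputs once with the true ±1 labels and once with randomly flipped labels should drive the training loss toward zero in both cases when m is large, which is phenomenon ① above; the next slides address why the speed and the generalization differ between the two runs.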

  9. Training speed.
     Theorem: loss at iteration t ≈ ‖(I − ηH)ᵗ y‖₂
     • y: vector of labels
     • H: kernel matrix (“Neural Tangent Kernel”),
       H_ij = E_W ⟨∇_W g(W, x_i), ∇_W g(W, x_j)⟩ = x_iᵀx_j (π − arccos(x_iᵀx_j)) / (2π)

  10. Training speed.
     Theorem: loss at iteration t ≈ ‖(I − ηH)ᵗ y‖₂
     • y: vector of labels
     • H: kernel matrix (“Neural Tangent Kernel”),
       H_ij = E_W ⟨∇_W g(W, x_i), ∇_W g(W, x_j)⟩ = x_iᵀx_j (π − arccos(x_iᵀx_j)) / (2π)
     Implication:
     • Training speed is determined by the projections of y onto the eigenvectors of H: ⟨y, v₁⟩, ⟨y, v₂⟩, ⟨y, v₃⟩, …
     • Components on top eigenvectors converge to 0 faster than components on bottom eigenvectors
     • This explains the different training speeds on correct vs. random labels
     [Plots: label projections sorted by eigenvalue; training loss over time, for true vs. random labels.]
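
To make the training-speed theorem concrete, here is a small NumPy sketch (an illustration under the stated setting, not the authors' code) that builds the NTK matrix H_ij = x_iᵀx_j (π − arccos(x_iᵀx_j)) / (2π) for unit-norm inputs and evaluates the predicted curve ‖(I − ηH)ᵗ y‖₂ through the eigendecomposition of H. The function names and the default step-size choice are assumptions of the sketch.

```python
# Sketch: predicted training-loss curve from the NTK eigendecomposition (illustrative).
import numpy as np

def ntk_matrix(X):
    """H_ij = x_i.x_j * (pi - arccos(x_i.x_j)) / (2*pi); rows of X are assumed unit-norm."""
    G = np.clip(X @ X.T, -1.0, 1.0)        # pairwise inner products, clipped for arccos
    return G * (np.pi - np.arccos(G)) / (2.0 * np.pi)

def predicted_loss_curve(X, y, lr=None, steps=200):
    """||(I - lr*H)^t y||_2 for t = 0..steps, computed via the eigendecomposition of H."""
    H = ntk_matrix(X)
    lam, V = np.linalg.eigh(H)             # eigenvalues (ascending) and eigenvectors v_i
    if lr is None:
        lr = 1.0 / lam[-1]                 # step size small enough that every mode contracts
    proj = V.T @ y                         # projections <v_i, y> of the label vector
    t = np.arange(steps + 1)[:, None]
    # each projection decays geometrically at rate |1 - lr*lam_i|: top eigenvectors shrink fastest
    return np.sqrt(((1.0 - lr * lam) ** (2 * t) * proj ** 2).sum(axis=1))
```

Feeding in the same inputs with true labels and with random labels gives the two behaviours described on this slide: true labels place more of their mass on the top eigenvectors of H, so their predicted curve decays faster.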

  11. Explaining generalization despite vast overparametrization.
      Theorem: for any 1-Lipschitz loss,
      test error ≤ √( 2 yᵀ H⁻¹ y / n ) + small terms,
      where n = # training samples and yᵀ H⁻¹ y is the “data-dependent complexity”.
      Corollary: simple functions are provably learnable (e.g., linear functions and even-degree polynomials).

  12. Explaining generalization despite vast overparametrization.
      Theorem: for any 1-Lipschitz loss,
      test error ≤ √( 2 yᵀ H⁻¹ y / n ) + small terms,
      where n = # training samples and yᵀ H⁻¹ y is the “data-dependent complexity”.
      Corollary: simple functions are provably learnable (e.g., linear functions and even-degree polynomials).
      Poster #75 tonight.

  13. Explaining generalization despite vast overparametrization.
      Theorem: for any 1-Lipschitz loss,
      test error ≤ √( 2 yᵀ H⁻¹ y / n ) + small terms,
      where n = # training samples and yᵀ H⁻¹ y is the “data-dependent complexity”.
      The quantity √(yᵀ H⁻¹ y) can be read both as the “distance to init” of the learned weights and as the “min RKHS norm” of a function fitting the training labels.
      Corollary: simple functions are provably learnable (e.g., linear functions and even-degree polynomials).
      Poster #75 tonight.
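
The bound can also be evaluated directly from data. Below is a minimal sketch of the data-dependent term √(2 yᵀ H⁻¹ y / n); it reuses ntk_matrix from the previous sketch, and the function name is an assumption of this example rather than anything from the paper.

```python
# Sketch: the data-dependent term of the generalization bound (illustrative).
import numpy as np

def complexity_bound(X, y):
    """sqrt(2 * y^T H^{-1} y / n) for the ReLU NTK matrix H of unit-norm inputs X."""
    n = X.shape[0]
    H = ntk_matrix(X)                      # ReLU NTK from the previous sketch
    quad = y @ np.linalg.solve(H, y)       # y^T H^{-1} y without forming the inverse explicitly
    return np.sqrt(2.0 * quad / n)
```

Comparing this quantity for true labels against random labels on the same inputs is the distinction the slide points to: for random labels the quadratic form yᵀ H⁻¹ y tends toward its worst case, so the bound stops being small.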
