Implicit Optimization Bias as a Key to Understanding Deep Learning Nati Srebro (TTIC) Based on joint work with Behnam Neyshabur (TTIC → IAS), Ryota Tomioka (TTIC → MSR), Srinadh Bhojanapalli, Suriya Gunasekar, Blake Woodworth, Pedro Savarese (TTIC), Russ Salakhutdinov (CMU), Ashia Wilson, Becca Roelofs, Mitchell Stern, Ben Recht (Berkeley), Daniel Soudry, Elad Hoffer, Mor Shpigel (Technion), Jason Lee (USC)
Increasing the Network Size [Neyshabur Tomioka S ICLR’15]
Increasing the Network Size [Figure: test error as a function of complexity (path norm?)] [Neyshabur Tomioka S ICLR’15]
• What is the relevant “complexity measure” (e.g. a norm)? • How is it minimized (or controlled) by the optimization algorithm? • How does it change if we change the optimization algorithm?
[Figure: Path-SGD vs. SGD (with dropout) on MNIST, CIFAR-10, SVHN, and CIFAR-100; panels show cross-entropy training loss, 0/1 training error, and 0/1 test error vs. epoch] [Neyshabur Salakhutdinov S NIPS’15]
SGD vs. ADAM [Figure: training error (perplexity) and test error (perplexity)] Results on Penn Treebank using a 3-layer LSTM [Wilson Roelofs Stern S Recht, “The Marginal Value of Adaptive Gradient Methods in Machine Learning”, NIPS’17]
The Deep Recurrent Residual Boosting Machine Joe Flow, DeepFace Labs
Section 1: Introduction
We suggest a new amazing architecture and loss function that is great for learning. All you have to do to learn is fit the model on your training data.
Section 2: Learning
Contribution: our model. The model class h_w(x) is amazing. Our learning method is:
arg min_w (1/m) Σ_{i=1}^m loss(h_w(x_i); y_i)   (*)
Section 3: Optimization
This is how we solve the optimization problem (*): […]
Section 4: Experiments
It works!
Different optimization algorithm ➔ different bias in the optimum reached ➔ different inductive bias ➔ different learning properties
Goal: understand optimization algorithms not just as reaching some (global) optimum, but as reaching a specific optimum
Today Precisely understand implicit bias in: • Matrix Factorization • Linear Classification (Logistic Regression) • Linear Convolutional Networks
Matrix Reconstruction
min_{X ∈ ℝ^{n×n}} F(X) = ‖𝒜(X) − y‖²₂, where y ∈ ℝ^m and 𝒜(X)_i = ⟨A_i, X⟩
• Matrix completion (A_i is an indicator matrix)
• Matrix reconstruction from linear measurements
• Multi-task learning (A_i = e_{task of example i} ⋅ φ(example i)ᵀ)
[Figure: a partially observed matrix y and the corresponding indicator matrices A_1, …, A_4]
We are interested in the regime m ≪ n²
• Many global optima for which 𝒜(X) = y
• Easy to have 𝒜(X) = y without reconstruction/generalization, e.g. for matrix completion, set all unobserved entries to 0
• Gradient descent on X will generally yield a trivial non-generalizing solution
Factorized Matrix Reconstruction
min_{U,V ∈ ℝ^{n×n}} f(U, V) = F(UVᵀ) = ‖𝒜(UVᵀ) − y‖²₂
• Since U, V are full-dimensional, there is no constraint on X = UVᵀ; equivalent to min_X F(X)
• Underdetermined, all the same global minima, trivial to minimize without generalizing
What happens when we optimize by gradient descent on U, V?
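This behavior can be sketched numerically. Below is a minimal NumPy simulation of gradient descent on the factors for matrix completion, not the experiments from the talk; the dimensions, sampling rate, and planted rank-2 matrix are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)
n, r, m = 20, 2, 240                  # n x n target of rank r, m observed entries
Xstar = rng.standard_normal((n, r)) @ rng.standard_normal((r, n))
mask = np.zeros(n * n, dtype=bool)
mask[rng.choice(n * n, size=m, replace=False)] = True
mask = mask.reshape(n, n)

# Gradient descent on f(U, V) = ||(U V^T - X*) restricted to observed entries||^2,
# with full-dimensional factors and small initialization.
U = 0.01 * rng.standard_normal((n, n))
V = 0.01 * rng.standard_normal((n, n))
lr = 0.01
for _ in range(20000):
    R = np.where(mask, U @ V.T - Xstar, 0.0)   # residual on observed entries only
    U, V = U - lr * 2 * R @ V, V - lr * 2 * R.T @ U

rel_err = np.linalg.norm(U @ V.T - Xstar) / np.linalg.norm(Xstar)
# The trivial global optimum (all unobserved entries set to 0) has relative error
# sqrt(1 - m/n^2) ~ 0.63 here; the factorized iterate recovers the unobserved
# entries as well, despite no explicit rank or norm constraint.
```

Gradient descent on X itself would stop at the trivial zero-fill optimum; the factorized parametrization with small initialization is what changes the answer.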
Gradient descent on f(U, V) gets to “good” global minima
Gradient descent on f(U, V) generalizes better with smaller step size
Question: Which global minima does gradient descent reach? Why does it generalize well?
Gradient descent on f(U, V) converges to a minimum nuclear norm solution
Conjecture: With step size → 0 (i.e. gradient flow) and initialization → 0, gradient descent on U converges to the minimum nuclear norm solution:
UUᵀ → arg min_{X ⪰ 0} ‖X‖_* s.t. 𝒜(X) = y
[Gunasekar Woodworth Bhojanapalli Neyshabur S 2017]
• Rigorous proof when the A_i commute
• General A_i: empirical validation + hand waving
• Yuanzhi Li, Hongyang Zhang and Tengyu Ma: proved when y = 𝒜(X*), X* low rank, under RIP
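In the commuting case the A_i are simultaneously diagonalizable, so the problem reduces to a vector one: X = diag(x) with x = u ⊙ u, and the nuclear norm becomes the ℓ1 norm of a nonnegative vector. A minimal sketch of that case, assuming SciPy's `linprog` for the ground-truth minimum-ℓ1 solution; the sizes and the planted sparse solution are illustrative:

```python
import numpy as np
from scipy.optimize import linprog

rng = np.random.default_rng(0)
n, m = 20, 8
A = rng.standard_normal((m, n))
xstar = np.zeros(n)
xstar[[3, 11]] = [1.0, 2.0]           # planted sparse nonnegative solution
y = A @ xstar

# Diagonal (commuting) measurements: X = diag(x), x = u*u, so ||X||_* = ||x||_1.
# Gradient descent on g(u) = ||A(u*u) - y||^2 from small initialization.
u = 1e-4 * np.ones(n)
lr = 2e-3
for _ in range(200000):
    u -= lr * 4 * u * (A.T @ (A @ (u * u) - y))
x_gd = u * u

# Ground truth: minimum l1-norm nonnegative solution of A x = y (a linear program).
lp = linprog(c=np.ones(n), A_eq=A, b_eq=y, bounds=(0, None))
print(x_gd.sum(), lp.fun)             # the two l1 norms should nearly match
```

The match is only approximate at any finite initialization scale; the conjecture concerns the limit init → 0.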
Implicit Bias in Least Squares: min_w ‖Aw − b‖²₂
• Gradient descent (+momentum) on w ➔ min_{Aw=b} ‖w‖₂
• Gradient descent on the factorization X = UVᵀ ➔ probably min_{𝒜(X)=b} ‖X‖_* with step size ↘ 0 and init ↘ 0, but only in the limit; depends on step size and init, proved only in special cases
• AdaGrad on w ➔ in some special cases min_{Aw=b} ‖w‖_∞, but not always, and it depends on step size, adaptation parameter, momentum
• Steepest descent w.r.t. ‖w‖ ➔ ??? Not min_{Aw=b} ‖w‖, even as step size ↘ 0! And it depends on step size, init, momentum
• Coordinate descent (steepest descent w.r.t. ‖w‖₁) ➔ related to, but not quite, the Lasso (with step size ↘ 0 and particular tie-breaking ≈ LARS)
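The first bullet is easy to verify numerically: gradient descent from zero initialization stays in the row space of A and converges to the minimum ℓ2-norm interpolating solution, i.e. the pseudo-inverse solution. A minimal NumPy sketch; the dimensions are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 5, 20                          # underdetermined: many solutions to A w = b
A = rng.standard_normal((m, n))
b = rng.standard_normal(m)

# Gradient descent on ||A w - b||^2, initialized at zero.
w = np.zeros(n)
lr = 0.01
for _ in range(20000):
    w -= lr * 2 * A.T @ (A @ w - b)

w_min_norm = np.linalg.pinv(A) @ b    # minimum l2-norm interpolating solution
print(np.allclose(w, w_min_norm, atol=1e-6))
```

Each gradient step adds a vector in the row space of A, so the limit is the interpolating solution with no component orthogonal to that space, which is exactly the minimum-norm one.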
Training Single Unit on Separable Data
arg min_{w ∈ ℝⁿ} ℒ(w) = Σ_{i=1}^m ℓ(y_i⟨w, x_i⟩), with ℓ(z) = log(1 + e^{−z})
• Data {(x_i, y_i)}_{i=1}^m linearly separable (∃w ∀i y_i⟨w, x_i⟩ > 0)
• Where does gradient descent converge? w(t+1) = w(t) − η∇ℒ(w(t))
• inf_w ℒ(w) = 0, but the minimum is unattainable
• GD diverges to infinity: ‖w(t)‖ → ∞, ℒ(w(t)) → 0
• In what direction? What does w(t)/‖w(t)‖ converge to?
• Theorem: w(t)/‖w(t)‖₂ → ŵ/‖ŵ‖₂ where ŵ = arg min ‖w‖₂ s.t. ∀i y_i⟨w, x_i⟩ ≥ 1
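A minimal sketch of the theorem on toy data where the hard-margin solution is known in closed form; the data set is an illustrative assumption, chosen so that the support vectors (1,0) and (0,1) give ŵ = (1, −1):

```python
import numpy as np

# Separable toy data: support vectors (1,0) labeled +1 and (0,1) labeled -1
# give w_hat = (1,-1); the remaining points satisfy the margin strictly.
X = np.array([[1.0, 0.0], [0.0, 1.0], [3.0, -1.0], [-1.0, 3.0]])
y = np.array([1.0, -1.0, 1.0, -1.0])
w_hat = np.array([1.0, -1.0])

w = np.zeros(2)
lr = 0.25
for _ in range(200000):
    margins = y * (X @ w)
    # gradient of sum_i log(1 + exp(-margins_i))
    g = -(X.T * y) @ (1.0 / (1.0 + np.exp(margins)))
    w -= lr * g

cos = w @ w_hat / (np.linalg.norm(w) * np.linalg.norm(w_hat))
print(np.linalg.norm(w), cos)   # norm grows without bound; direction aligns with w_hat
```

The norm keeps growing (only logarithmically in t), while the direction locks onto the max-margin separator, independent of the step size and initialization.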
Other Objectives and Opt Methods
• Single linear unit, logistic loss ➔ hard-margin SVM solution (regardless of init, step size)
• Multi-class problems with softmax loss ➔ multiclass SVM solution (regardless of init, step size)
• Steepest descent w.r.t. ‖w‖ ➔ arg min ‖w‖ s.t. ∀i y_i⟨w, x_i⟩ ≥ 1 (regardless of init, step size)
• Coordinate descent ➔ arg min ‖w‖₁ s.t. ∀i y_i⟨w, x_i⟩ ≥ 1 (regardless of init, step size)
• Matrix factorization problems ℒ(U, V) = Σ_i ℓ(⟨A_i, UVᵀ⟩), including 1-bit matrix completion ➔ arg min ‖X‖_* s.t. ⟨A_i, X⟩ ≥ 1 (regardless of init)
Linear Neural Networks
• Graph G(V, E), with h_v = Σ_{u→v} w_{u→v} h_u
• Input units h_in = x_i ∈ ℝⁿ, single output h_out(x_i), binary labels y_i ∈ ±1
• Training: min_w Σ_{i=1}^m ℓ(y_i h_out(x_i))
• Implements a linear predictor: h_out(x_i) = ⟨𝒫(w), x_i⟩
• Training: min_w ℒ(𝒫(w)) = Σ_{i=1}^m ℓ(y_i⟨𝒫(w), x_i⟩)
• Just a different parametrization of linear classification: min_{β ∈ Im 𝒫} ℒ(β), with Im 𝒫 = ℝⁿ in all our examples
• GD on w: a different optimization procedure for the same argmin problem
• Limit of GD: β^∞ = lim_{t→∞} 𝒫(w(t))/‖𝒫(w(t))‖
Fully Connected Linear NNs
L fully connected layers with d_l ≥ 1 units in layer l:
h_0 = h_in,  h_l = W_l h_{l−1},  h_out = h_L
W_l ∈ ℝ^{d_l × d_{l−1}}, h_l ∈ ℝ^{d_l}, l = 1..L; parameters w = (W_1, …, W_L)
Theorem: β^∞ ∝ arg min ‖β‖₂ s.t. ∀i y_i⟨β, x_i⟩ ≥ 1
for ℓ(z) = exp(−z), almost all linearly separable data sets and initializations w(0), and any bounded step sizes s.t. ℒ(w(t)) → 0 and Δw(t) = w(t) − w(t−1) converges in direction
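A sketch of the same phenomenon for a two-layer linear network: gradient descent on the factors still drives the induced predictor β = W1ᵀw2 toward the max-margin direction. This uses the logistic loss as a numerically gentler stand-in for ℓ(z) = exp(−z) (same exponential tail); the toy data and initialization scale are illustrative assumptions:

```python
import numpy as np

# Separable toy data whose hard-margin direction is (1,-1)/sqrt(2).
X = np.array([[1.0, 0.0], [0.0, 1.0], [3.0, -1.0], [-1.0, 3.0]])
y = np.array([1.0, -1.0, 1.0, -1.0])
w_hat = np.array([1.0, -1.0])

rng = np.random.default_rng(0)
W1 = 0.1 * rng.standard_normal((2, 2))   # layer 1
w2 = 0.1 * rng.standard_normal(2)        # layer 2; induced predictor beta = W1^T w2
lr = 0.02
for _ in range(200000):
    beta = W1.T @ w2
    margins = y * (X @ beta)
    g_beta = -(X.T * y) @ (1.0 / (1.0 + np.exp(margins)))  # dL/dbeta, logistic loss
    g_W1 = np.outer(w2, g_beta)          # chain rule through beta = W1^T w2
    g_w2 = W1 @ g_beta
    W1 -= lr * g_W1
    w2 -= lr * g_w2

beta = W1.T @ w2
cos = beta @ w_hat / (np.linalg.norm(beta) * np.linalg.norm(w_hat))
```

Despite the non-convex parametrization, the direction of β matches what a single unit would reach, illustrating that depth alone does not change the implicit bias of fully connected linear networks.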