Non-convex Optimization for Machine Learning Prateek Jain Microsoft Research India
Outline
• Optimization for Machine Learning
• Non-convex Optimization
  • Convergence to Stationary Points
    • First order stationary points
    • Second order stationary points
• Non-convex Optimization in ML
  • Neural Networks
  • Learning with Structure
    • Alternating Minimization
    • Projected Gradient Descent
Relevant Monograph (Shameless Ad)
Optimization in ML
Supervised Learning
• Given points (x_i, y_i)
• Prediction function: ŷ_i = φ(x_i, w)
• Minimize loss: min_w Σ_i ℓ(φ(x_i, w), y_i)
Unsupervised Learning
• Given points (x_1, x_2, …, x_N)
• Find cluster centers or train GANs
• Represent x̂_i = φ(x_i, w)
• Minimize loss: min_w Σ_i ℓ(φ(x_i, w), x_i)
Optimization Problems
• Unconstrained optimization: min_{w ∈ R^d} f(w)
• Constrained optimization: min_w f(w)  s.t.  w ∈ C
Examples:
• Regression
• Support Vector Machines
• Deep networks
• Sparse regression
• Recommendation systems
• Gradient Boosted Decision Trees
• …
Convex Optimization
min_w f(w)  s.t.  w ∈ C
• Convex set: ∀ w_1, w_2 ∈ C, λ·w_1 + (1 − λ)·w_2 ∈ C, 0 ≤ λ ≤ 1
• Convex function: f(λ·w_1 + (1 − λ)·w_2) ≤ λ·f(w_1) + (1 − λ)·f(w_2), 0 ≤ λ ≤ 1
Slide credit: Purushottam Kar
Examples
• Linear Programming
• Quadratic Programming
• Semidefinite Programming
Slide credit: Purushottam Kar
Convex Optimization
• Unconstrained optimization: min_{w ∈ R^d} f(w)
  • Optima: just ensure ∇f(w) = 0
• Constrained optimization: min_w f(w)  s.t.  w ∈ C
  • Optima: KKT conditions
In this talk, let's assume f is L-smooth, i.e., f is differentiable and
f(x) ≤ f(y) + ⟨∇f(y), x − y⟩ + (L/2)·||x − y||², or equivalently, ||∇f(x) − ∇f(y)|| ≤ L·||x − y||
Gradient Descent Methods
• Projected gradient descent method:
  • For t = 1, 2, … (until convergence)
    • w_{t+1} = P_C(w_t − η·∇f(w_t))
  • η: step-size
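The update above is straightforward to write in code. Below is a minimal NumPy sketch, not from the slides: the function `projected_gradient_descent`, the quadratic objective, and the unit-ball projection are all illustrative choices.

```python
import numpy as np

def projected_gradient_descent(grad_f, project, w0, eta=0.1, T=1000):
    """Run w_{t+1} = P_C(w_t - eta * grad f(w_t)) for T steps (illustrative sketch)."""
    w = np.asarray(w0, dtype=float)
    for _ in range(T):
        w = project(w - eta * grad_f(w))
    return w

# Example (assumed problem): minimize f(w) = ||w - b||^2 over the unit Euclidean ball.
b = np.array([2.0, 1.0])
grad_f = lambda w: 2.0 * (w - b)
project = lambda w: w / max(1.0, np.linalg.norm(w))  # projection onto {||w|| <= 1}
w_hat = projected_gradient_descent(grad_f, project, w0=np.zeros(2), eta=0.1, T=500)
```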
Convergence Proof
f(w_{t+1}) ≤ f(w_t) + ⟨∇f(w_t), w_{t+1} − w_t⟩ + (L/2)·||w_{t+1} − w_t||²
f(w_{t+1}) ≤ f(w_t) − η·(1 − Lη/2)·||∇f(w_t)||² ≤ f(w_t) − (η/2)·||∇f(w_t)||²   (for η ≤ 1/L)
f(w_{t+1}) ≤ f(w*) + ⟨∇f(w_t), w_t − w*⟩ − (1/2η)·||w_{t+1} − w_t||²
f(w_{t+1}) ≤ f(w*) + (1/2η)·(||w_t − w*||² − ||w_{t+1} − w*||²)
Summing and telescoping: f(w_T) ≤ f(w*) + ||w_0 − w*||²/(2ηT)
⇒ f(w_T) ≤ f(w*) + ε for T = O(L·||w_0 − w*||²/ε)
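As a sanity check on this O(1/T) rate, the following sketch runs gradient descent with η = 1/L on an assumed convex least-squares problem (the data and all names are illustrative) and verifies f(w_T) − f(w*) ≤ L·||w_0 − w*||²/(2T).

```python
import numpy as np

# Illustrative check of f(w_T) - f(w*) <= L * ||w_0 - w*||^2 / (2T) on a convex quadratic
# f(w) = 0.5 * ||A w - y||^2 (assumed example problem, not from the slides).
rng = np.random.default_rng(0)
A = rng.standard_normal((50, 10))
y = rng.standard_normal(50)

f = lambda w: 0.5 * np.linalg.norm(A @ w - y) ** 2
grad = lambda w: A.T @ (A @ w - y)
L = np.linalg.eigvalsh(A.T @ A).max()          # smoothness constant of f
w_star = np.linalg.lstsq(A, y, rcond=None)[0]  # global minimizer

w, T = np.zeros(10), 200
for t in range(T):
    w = w - (1.0 / L) * grad(w)                # step size eta = 1/L

assert f(w) - f(w_star) <= L * np.linalg.norm(np.zeros(10) - w_star) ** 2 / (2 * T)
```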
Non-convexity?
min_{w ∈ R^d} f(w)
• Critical points: ∇f(w) = 0
• But: ∇f(w) = 0 ⇏ optimality
Local Optima
• w is a local minimum if f(w) ≤ f(w′), ∀ ||w − w′|| ≤ ε
image credit: academo.org
First Order Stationary Points
First Order Stationary Point (FOSP)
• Defined by: ∇f(w) = 0
• But ∇²f(w) need not be positive semi-definite
image credit: academo.org
First Order Stationary Points
First Order Stationary Point (FOSP)
• E.g., f(w) = 0.5·(w_1² − w_2²)
• ∇f(w) = (w_1, −w_2)
• ∇f(0) = 0
• But ∇²f(w) = [[1, 0], [0, −1]] ⇒ indefinite
• f(ε/2, ε) = −(3/8)·ε² ⇒ (0, 0) is not a local minimum
image credit: academo.org
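A quick numeric illustration of this example (the code is an illustrative sketch, not from the slides):

```python
import numpy as np

# f(w) = 0.5 * (w1^2 - w2^2): the origin is an FOSP but not a local minimum.
f = lambda w: 0.5 * (w[0] ** 2 - w[1] ** 2)
grad = lambda w: np.array([w[0], -w[1]])
hess = lambda w: np.array([[1.0, 0.0], [0.0, -1.0]])

eps = 1e-3
print(grad(np.zeros(2)))                       # [0, 0] -> gradient vanishes at the origin
print(np.linalg.eigvalsh(hess(np.zeros(2))))   # [-1, 1] -> indefinite Hessian
print(f(np.array([eps / 2, eps])))             # -(3/8) eps^2 < f(0) = 0
```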
Second Order Stationary Points
Second Order Stationary Point (SOSP) if:
• ∇f(w) = 0
• ∇²f(w) ≽ 0
Does it imply local optimality?
image credit: academo.org
Second Order Stationary Points
• f(w) = (1/3)·w_1³ − w_1·w_2²
• ∇f(w) = (w_1² − w_2², −2·w_1·w_2)
• ∇²f(w) = [[2w_1, −2w_2], [−2w_2, −2w_1]]
• ∇f(0) = 0, ∇²f(0) = 0 ⇒ 0 is a SOSP
• f(ε, ε) = −(2/3)·ε³ < f(0) ⇒ 0 is not a local minimum
image credit: academo.org
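The same check for this example (again an illustrative sketch, not from the slides):

```python
import numpy as np

# f(w) = (1/3) w1^3 - w1 w2^2: the origin is an SOSP but not a local minimum.
f = lambda w: w[0] ** 3 / 3.0 - w[0] * w[1] ** 2
grad = lambda w: np.array([w[0] ** 2 - w[1] ** 2, -2.0 * w[0] * w[1]])
hess = lambda w: np.array([[2.0 * w[0], -2.0 * w[1]], [-2.0 * w[1], -2.0 * w[0]]])

eps = 1e-3
print(grad(np.zeros(2)), hess(np.zeros(2)))    # gradient and Hessian both vanish at the origin
print(f(np.array([eps, eps])))                 # -(2/3) eps^3 < f(0) = 0
```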
Stationarity and local optima
• w is a local optimum implies: f(w) ≤ f(w′), ∀ ||w − w′|| ≤ ε
• w is a FOSP implies: f(w) ≤ f(w′) + O(||w − w′||²)
• w is a SOSP implies: f(w) ≤ f(w′) + O(||w − w′||³)
• w is a p-th order SP implies: f(w) ≤ f(w′) + O(||w − w′||^{p+1})
• That is, a local optimum corresponds to p = ∞
Computability?
f(w) ≤ f(w′) + O(||w − w′||^{p+1})
• First Order Stationary Point
• Second Order Stationary Point
• Third Order Stationary Point
• p ≥ 4 Stationary Point: NP-Hard
• Local Optima: NP-Hard
Anandkumar and Ge-2016
Does Gradient Descent Work for Local Optimality?
• Yes! In fact, with high probability it converges to a "local minimizer"
  • If initialized randomly!!!
  • But no rates known
• NP-hard in general!!
• Big open problem ☺
image credit: academo.org
Finding First Order Stationary Points
First Order Stationary Point (FOSP)
• Defined by: ∇f(w) = 0
• But ∇²f(w) need not be positive semi-definite
image credit: academo.org
Gradient Descent Methods
• Gradient descent:
  • For t = 1, 2, … (until convergence)
    • w_{t+1} = w_t − η·∇f(w_t)
  • η: step-size
• Assume: ||∇f(x) − ∇f(y)|| ≤ L·||x − y||
Convergence to FOSP
f(w_{t+1}) ≤ f(w_t) + ⟨∇f(w_t), w_{t+1} − w_t⟩ + (L/2)·||w_{t+1} − w_t||²
f(w_{t+1}) ≤ f(w_t) − η·(1 − Lη/2)·||∇f(w_t)||² ≤ f(w_t) − (1/2L)·||∇f(w_t)||²   (for η = 1/L)
⇒ (1/2L)·||∇f(w_t)||² ≤ f(w_t) − f(w_{t+1})
Summing over t: (1/2L)·Σ_t ||∇f(w_t)||² ≤ f(w_0) − f(w*)
⇒ min_t ||∇f(w_t)|| ≤ √(2L·(f(w_0) − f(w*))/T) ≤ ε
for T = O(L·(f(w_0) − f(w*))/ε²)
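A small sketch that mirrors this bound: run GD with η = 1/L on an assumed one-dimensional non-convex function whose smoothness constant and global minimum are known, and compare min_t ||∇f(w_t)|| against √(2L·(f(w_0) − f(w*))/T). The function and all constants below are illustrative choices.

```python
import numpy as np

# f(w) = w^2 + 3 sin^2(w): non-convex, L-smooth with L = 8, global minimum f(w*) = 0 at w* = 0.
f = lambda w: w ** 2 + 3.0 * np.sin(w) ** 2
grad = lambda w: 2.0 * w + 3.0 * np.sin(2.0 * w)
L, T, w0 = 8.0, 500, 4.0                       # start away from the minimum

w, grad_norms = w0, []
for t in range(T):
    grad_norms.append(abs(grad(w)))
    w = w - (1.0 / L) * grad(w)                # step size eta = 1/L

bound = np.sqrt(2.0 * L * (f(w0) - 0.0) / T)   # min_t ||grad f(w_t)|| <= sqrt(2L(f(w_0) - f*)/T)
print(min(grad_norms), "<=", bound)
```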
Accelerated Gradient Descent for FOSP?
• Maintain three sequences: w^{ag}_t, w^{md}_t, w_t
• For t = 1, 2, …, T
  • w^{md}_{t+1} = (1 − α_t)·w^{ag}_t + α_t·w_t
  • w_{t+1} = w_t − η_t·∇f(w^{md}_{t+1})
  • w^{ag}_{t+1} = w^{md}_{t+1} − β_t·∇f(w^{md}_{t+1})
• Convergence? min_t ||∇f(w^{md}_t)|| ≤ ε
  • For T = O(L·(f(w_0) − f(w*))/ε²)
  • If convex: T = O((L·(f(w_0) − f(w*)))^{1/4}/ε)
Ghadimi and Lan - 2013
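A minimal sketch of the three-sequence scheme above. The constant parameters α, η, β, the function name `agd_nonconvex`, and the one-dimensional test function are assumptions chosen for illustration, not the schedules analyzed by Ghadimi and Lan.

```python
import numpy as np

def agd_nonconvex(grad_f, w0, T=1000, alpha=0.5, eta=0.01, beta=0.01):
    """Three-sequence accelerated gradient sketch (constant parameters, for illustration only)."""
    w = w_ag = np.asarray(w0, dtype=float)
    best_grad_norm = np.inf
    for t in range(T):
        w_md = (1.0 - alpha) * w_ag + alpha * w   # interpolation step
        g = grad_f(w_md)
        w = w - eta * g                           # update of the "w" sequence
        w_ag = w_md - beta * g                    # update of the "aggregated" sequence
        best_grad_norm = min(best_grad_norm, np.linalg.norm(g))
    return w_ag, best_grad_norm

# Example: f(w) = w^4 - 2 w^2 (non-convex), gradient 4w^3 - 4w.
w_ag, gnorm = agd_nonconvex(lambda w: 4.0 * w ** 3 - 4.0 * w, np.array([2.0]))
```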
Non-convex Optimization: Sum of Functions
• What if the function has more structure?
  min_w f(w) = (1/n)·Σ_{i=1}^n f_i(w)
• ∇f(w) = (1/n)·Σ_{i=1}^n ∇f_i(w)
• I.e., computing the gradient requires O(n) computation
Does Stochastic Gradient Descent Work?
• For t = 1, 2, … (until convergence)
  • Sample i_t ∼ Unif[1, n]
  • w_{t+1} = w_t − η·∇f_{i_t}(w_t)
Proof?
E_{i_t}[w_{t+1} − w_t] = −η·∇f(w_t)
f(w_{t+1}) ≤ f(w_t) + ⟨∇f(w_t), w_{t+1} − w_t⟩ + (L/2)·||w_{t+1} − w_t||²
E[f(w_{t+1})] ≤ E[f(w_t)] − (η/2)·E||∇f(w_t)||² + (L/2)·η²·Var
⇒ min_t E||∇f(w_t)|| ≤ (L·Var·(f(w_0) − f(w*))/T)^{1/4} ≤ ε
for T = O(L·Var·(f(w_0) − f(w*))/ε⁴)
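A minimal sketch of this sampling scheme on a finite sum. The least-squares components f_i, the data, and all names are illustrative assumptions.

```python
import numpy as np

def sgd_finite_sum(grad_fi, n, w0, eta, T, seed=0):
    """SGD on f(w) = (1/n) sum_i f_i(w): sample a component uniformly, step along its gradient."""
    rng = np.random.default_rng(seed)
    w = np.asarray(w0, dtype=float)
    for t in range(T):
        i = rng.integers(n)              # i_t sampled uniformly among the n components
        w = w - eta * grad_fi(i, w)      # w_{t+1} = w_t - eta * grad f_{i_t}(w_t)
    return w

# Example: f_i(w) = 0.5 * (a_i . w - y_i)^2 on assumed random data.
rng = np.random.default_rng(1)
A, y = rng.standard_normal((200, 5)), rng.standard_normal(200)
grad_fi = lambda i, w: (A[i] @ w - y[i]) * A[i]
w_hat = sgd_finite_sum(grad_fi, n=200, w0=np.zeros(5), eta=0.01, T=5000)
```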
Summary: Convergence to FOSP

General smooth f:
Algorithm | No. of Gradient Calls (Non-convex) | No. of Gradient Calls (Convex)
GD [Folklore; Nesterov] | O(1/ε²) | O(1/ε)
AGD [Ghadimi & Lan-2013] | O(1/ε²) | O(1/√ε)

Finite sums f(w) = (1/n)·Σ_{i=1}^n f_i(w):
Algorithm | No. of Gradient Calls | Convex Case
GD [Folklore] | O(n/ε²) | O(n/ε)
AGD [Ghadimi & Lan-2013] | O(n/ε²) | O(n/√ε)
SGD [Ghadimi & Lan-2013] | O(1/ε⁴) | O(1/ε²)
SVRG [Reddi et al-2016, Allen-Zhu & Hazan-2016] | O(n + n^{2/3}/ε²) | O(n + √n/ε²)
MSVRG [Reddi et al-2016] | O(min(1/ε⁴, n^{2/3}/ε²)) | O(n + n^{2/3}/ε²)
Finding Second Order Stationary Points (SOSP)
Second Order Stationary Point (SOSP) if:
• ∇f(w) = 0
• ∇²f(w) ≽ 0
Approximate SOSP:
• ||∇f(w)|| ≤ ε
• λ_min(∇²f(w)) ≥ −√(ρε)
image credit: academo.org
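A small helper that checks the approximate-SOSP conditions numerically, assuming the gradient and Hessian are available in closed form. The function name, tolerances, and the saddle example are illustrative assumptions.

```python
import numpy as np

def is_approx_sosp(grad_f, hess_f, w, eps, rho):
    """Check ||grad f(w)|| <= eps and lambda_min(hess f(w)) >= -sqrt(rho * eps)."""
    small_grad = np.linalg.norm(grad_f(w)) <= eps
    lam_min = np.linalg.eigvalsh(hess_f(w)).min()
    return small_grad and lam_min >= -np.sqrt(rho * eps)

# The saddle of f(w) = 0.5*(w1^2 - w2^2) passes the FOSP test but fails the SOSP test.
grad = lambda w: np.array([w[0], -w[1]])
hess = lambda w: np.diag([1.0, -1.0])
print(is_approx_sosp(grad, hess, np.zeros(2), eps=1e-3, rho=1.0))  # False
```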