On the Complexity of Training a Neural Network. Santosh Vempala, Algorithms and Randomness Center, Georgia Tech.
Deep learning conquers the world: its successes are incredible. Do you want to beat 9-dan Go players? Classify images more accurately than humans? Recognize speech? Recommend movies? Drive an autonomous vehicle? Publish a paper in NIPS? Try deep learning!
The Learning Problem. Given labeled samples (x, f(x)) where x ∼ D and f : R^n → R, find g : R^n → R such that E_x [(f(x) − g(x))^2] ≤ ε.
Deep learning for a theoretician. We want to approximate a concept f. How do we choose the model g? A simple "neural network" (NN): for input x ∼ D on R^n, hidden units y_1 = σ(W_1 · x + b_1), y_2 = σ(W_2 · x + b_2), ..., and output g(x) = W_3 · y + b_3. Here the W's are the weights and σ is the activation function, e.g., the sigmoid σ(x) = 1 / (1 + e^{−x}).
Deep learning for a theoretician. How do we "train" the network, i.e., choose W, b? Gradient descent: estimate the gradient from samples, update the weights, repeat. From labeled data (x_1, f(x_1)), (x_2, f(x_2)), ..., estimate E_x ∇_W (f − g)^2 and update W ← W − η ∇_W for a step size η.
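This recipe is easy to make concrete. Below is a minimal sketch (not the talk's code) of training a one-hidden-layer sigmoid network with squared loss and plain SGD; the width hidden_width, learning rate lr, and the target concept f used here are illustrative placeholders.

```python
# Minimal sketch: SGD on a one-hidden-layer sigmoid network with squared loss.
import numpy as np

rng = np.random.default_rng(0)
n, hidden_width, lr = 20, 50, 0.1

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def f(x):
    # Illustrative target concept; in the realizable case it would itself
    # be computed by a small one-hidden-layer network.
    return np.tanh(x.sum())

# Model g(x) = W2 @ sigmoid(W1 @ x + b1) + b2.
W1 = rng.normal(scale=1.0 / np.sqrt(n), size=(hidden_width, n))
b1 = np.zeros(hidden_width)
W2 = rng.normal(scale=1.0 / np.sqrt(hidden_width), size=hidden_width)
b2 = 0.0

for step in range(10_000):
    x = rng.normal(size=n)             # x ~ D = N(0, I_n)
    h = sigmoid(W1 @ x + b1)           # hidden-layer activations
    g = W2 @ h + b2                    # network output
    err = g - f(x)                     # residual of the squared loss
    # Gradient of (g - f(x))^2 / 2 with respect to each parameter.
    grad_W2 = err * h
    grad_b2 = err
    grad_h = err * W2 * h * (1.0 - h)  # backprop through the sigmoid
    grad_W1 = np.outer(grad_h, x)
    grad_b1 = grad_h
    # Gradient step: W <- W - lr * (gradient estimate from this sample).
    W2 -= lr * grad_W2
    b2 -= lr * grad_b2
    W1 -= lr * grad_W1
    b1 -= lr * grad_b1
```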
Guarantees for deep learning? Goal: provable guarantees for NN training algorithms when the data is generated by a small one-hidden-layer NN. Theorem (Cybenko, 1989): continuous functions can be approximated by 1-hidden-layer NNs with sigmoid units.
Provable guarantees for deep learning? What could the form of such guarantees be? Under what conditions (on input distribution, function) does Stochastic Gradient Descent work? Does it help if the data is generated by a NN? (Is the “ realizable ” case easier?)
Guarantees for deep learning? Lower bounds for the realizable case: it is NP-hard to train a neural network with 3 threshold neurons (Blum–Rivest, 1993). Under complexity/cryptographic assumptions, small-depth networks cannot be learned efficiently, even improperly (Klivans–Sherstov, 2006; Daniely–Linial–Shalev-Shwartz, 2014). Even with nice input distributions, some deep learning algorithms don't work (Shamir, 2016).
Outline. 1. Lower bounds for learning neural networks: "You can't efficiently learn functions computed by small, single-hidden-layer neural networks, even over nice input distributions." 2. Polynomial-time analysis of gradient descent: "Gradient descent can efficiently train single-hidden-layer neural networks of unbiased activation units." 3. Open questions.
A nice family of neural networks: the construction. An n-dimensional input with i.i.d. N(0, 1) coordinates, a hidden layer of Õ(n) sigmoid units, and a linear output.
A nice family of neural networks: smooth activation functions. The sigmoid of sharpness s is σ_s(x) = e^{sx} / (1 + e^{sx}). [Plot: σ_s for s = 1, 4, 12.]
A nice family of neural networks: a little more generality. The hidden units can be ReLU, PReLU, softplus, sigmoid, ...; the n-dimensional input can have i.i.d. coordinates from any logconcave distribution; the output is still linear.
Use deep learning! Choose your favorite network architecture, activation units, loss function, gradient descent variant, regularization scheme...
A computational lower bound. Choose your favorite network architecture, activation units, loss function, gradient descent variant, regularization scheme, etc. "Theorem": if you use only "black box" functions of the input (e.g., gradients via TensorFlow, Hessians via Autograd), then you need 2^{Ω(n)} / s^2 function evaluations of accuracy at most 1/(s√n).
A little more context. [Diagram: n-dim N(0, 1) input, Õ(n) sigmoid units, linear output.] Janzamin–Sedghi–Anandkumar (2015): a tensor decomposition algorithm under additional assumptions; sample size = poly(n, condition number of the weights). Shamir (2016) gives exponential lower bounds against "vanilla" SGD with mean-squared loss, ReLU units, and Gaussian input (non-realizable, but a similar construction). More recent improvements on upper bounds (coming up).
A little more generality. [Diagram: n-dim logconcave input, ReLU/PReLU/softplus/sigmoid/... units, linear output.] The lower bound applies to algorithms of the following form: 1. Estimate v = E_{(x,y)∼D} [h(W, x, y)], where W is the current weights, (x, y) is a labeled example from the input distribution D, and h is an arbitrary [0, 1]-valued function. 2. Use v to update W. A minimal sketch of this algorithm class is given below.
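The sketch below is hedged and illustrative (the names sample_batch, h, update, and the batch size are placeholders, not the paper's notation): the learner touches the data only through empirical estimates of E[h(W, x, y)] for a bounded h, and updates W from those estimates alone.

```python
# Sketch of the algorithm class covered by the lower bound: the learner sees
# the data only through estimates of E[h(W, x, y)] for a bounded function h.
import numpy as np

def run_statistical_learner(sample_batch, h, update, W0, num_rounds=100, batch_size=64):
    """Generic learner of the estimate-then-update form.

    sample_batch(m) -> list of (x, y) labeled examples drawn from D.
    h(W, x, y)      -> value in [0, 1] (entrywise, if vector-valued).
    update(W, v)    -> new weights, computed from the estimate v alone.
    """
    W = W0
    for _ in range(num_rounds):
        batch = sample_batch(batch_size)
        # Step 1: estimate v = E_{(x,y)~D}[h(W, x, y)] by an empirical average.
        v = np.mean([h(W, x, y) for x, y in batch], axis=0)
        # Step 2: use only v (never the raw samples) to update W.
        W = update(W, v)
    return W
```

Vanilla minibatch SGD fits this template: its h is (a suitably bounded version of) the per-example gradient of the loss, and its update is a gradient step.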
The hard family of functions. φ(x) = σ(1/s + x) + σ(1/s − x) − 1; F_σ(x) = φ(x) + φ(x − 2/s) + φ(x + 2/s) + ···. F_σ : R → R is an affine combination of σ-units, almost periodic on [−Õ(n), Õ(n)], with period 1/s.
The hard family of functions, with ReLU units r: φ_0(x) = r(x + 1/(2s)) − r(x − 1/(2s)); φ(x) = φ_0(1/(2s) + x) + φ_0(1/(2s) − x) − 1; F_r(x) = φ(x) + φ(x − 2/s) + φ(x + 2/s) + ···.
The hard family of functions. For each S ⊆ {1, ..., n} with |S| = n/2, define f_S(x) = F_σ(Σ_{i∈S} x_i), where F_σ : R → R is the affine combination of σ-units above (weights ±1), almost periodic on [−Õ(n), Õ(n)] with period 1/s, and the input coordinates are drawn from a logconcave distribution. A concrete sketch follows.
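For concreteness, here is a hedged sketch following the formulas above; the truncation parameter K and the sharpness s = √n used in the example are illustrative choices, not values fixed by the slides.

```python
# Sketch of the hard family: F_sigma is an affine combination of sharp sigmoid
# units, and f_S applies it to the sum of the coordinates indexed by S.
import numpy as np

def sigma(x, s):
    """Sigmoid of sharpness s: e^{sx} / (1 + e^{sx}), computed stably."""
    z = s * np.asarray(x, dtype=float)
    e = np.exp(-np.abs(z))
    return np.where(z >= 0, 1.0 / (1.0 + e), e / (1.0 + e))

def phi(x, s):
    """One bump: sigma(1/s + x) + sigma(1/s - x) - 1."""
    return sigma(1.0 / s + x, s) + sigma(1.0 / s - x, s) - 1.0

def F_sigma(x, s, K):
    """Sum of bumps shifted by multiples of 2/s (truncated to |k| <= K),
    nearly periodic on a bounded range."""
    return sum(phi(x - 2.0 * k / s, s) for k in range(-K, K + 1))

def f_S(x, S, s, K):
    """The hidden concept: F_sigma applied to the sum of the coordinates in S."""
    return F_sigma(np.sum(x[S]), s, K)

# Example: n = 20, a random hidden subset S of size n/2, illustrative sharpness s = sqrt(n).
rng = np.random.default_rng(1)
n = 20
S = rng.choice(n, size=n // 2, replace=False)
s, K = np.sqrt(n), 100
x = rng.normal(size=n)   # logconcave (here Gaussian) input coordinates
print(f_S(x, S, s, K))
```

The ReLU version F_r on the previous slide follows the same pattern, with φ_0(x) = r(x + 1/(2s)) − r(x − 1/(2s)) in place of the sigmoid bump.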
Throw some "deep learning" at it! Choose your favorite: network architecture, activation units, loss function, gradient descent variant, regularization scheme...
Theory vs. practice, revisited. [Plot: training error vs. s·√n for sigmoid networks with hidden-layer widths 50, 100, 200, 500, and 1000.]
Try "deep learning"! Choose your favorite: network architecture, activation units, loss function, gradient descent variant, regularization scheme... Were you successful? "Theorem": No!
Statistical Query Algorithms. Recall the gradient descent training algorithm: from labeled data (x_1, f(x_1)), (x_2, f(x_2)), ..., estimate the gradient E_x ∇_W (f − g)^2 and update the weights W ← W − ∇_W.
Statistical Query Algorithms. Recall the gradient descent training algorithm: query the gradient E_x ∇_W (f − g)^2 and update the weights W ← W − ∇_W. The algorithm needs only a gradient estimate, not the labeled examples themselves.
Statistical Query Algorithms. Statistical query (SQ) algorithms were introduced by Kearns in 1993. The algorithm has no direct access to samples; it queries expectations of functions over the labeled-example distribution. Query (h, τ): h : R^n × R → [0, 1], τ > 0. The oracle responds with a value v such that |E[h] − v| < τ. E.g., for gradient descent, query h = ∇_W (f − g)^2.
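As a concrete illustration (a hedged sketch, not Kearns' formal definition), an SQ oracle can be simulated from i.i.d. labeled samples by answering each query with an empirical average; the class name SQOracle and the sample-size constant are illustrative.

```python
# Sketch of an SQ oracle simulated from i.i.d. labeled samples.
import numpy as np

class SQOracle:
    """Answers a query (h, tau), with h mapping (x, y) into [0, 1] and
    tolerance tau > 0, by an empirical average of h, which lies within tau
    of E[h(x, y)] with high probability."""

    def __init__(self, sample):
        self.sample = sample   # sample(m) -> list of m labeled (x, y) pairs

    def query(self, h, tau):
        # By Hoeffding's inequality, O(1/tau^2) samples of a [0, 1]-valued h
        # give an estimate within tau of the true expectation w.h.p.
        # (the constant 8 is an illustrative choice).
        m = int(np.ceil(8.0 / tau ** 2))
        return float(np.mean([h(x, y) for x, y in self.sample(m)]))
```

A gradient-descent learner in this model would issue one such query per (suitably bounded and rescaled) gradient coordinate of (f − g)^2.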
Statistical Query Algorithms. SQ algorithms are extremely general: almost all "robust" machine learning guarantees can be achieved with SQ algorithms.