
On The Complexity of Training a Neural Network, Santosh Vempala - PowerPoint Presentation



  1. On The Complexity of Training a Neural Network. Santosh Vempala, Algorithms and Randomness Center, Georgia Tech.

  2. Deep learning's successes are incredible. Do you want to beat 9-dan Go players?

  3. Deep learning's successes are incredible. Do you want to beat 9-dan Go players? Classify images more accurately than humans?

  4. Deep learning's successes are incredible. Do you want to beat 9-dan Go players? Classify images more accurately than humans? Recognize speech? Recommend movies? Drive an autonomous vehicle?

  5. Deep learning conquers the world. Do you want to beat 9-dan Go players? Classify images more accurately than humans? Recognize speech? Recommend movies? Drive an autonomous vehicle? Publish a paper in NIPS? Try deep learning!

  6. The Learning Problem. Problem: Given labeled samples (x, f(x)) where x ∼ D and f : R^n → R, find g : R^n → R such that E_{x∼D}[(f(x) − g(x))^2] ≤ ε.
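A minimal NumPy sketch of this objective, estimating E_{x∼D}[(f(x) − g(x))^2] from samples; the particular f, g, distribution, and threshold ε below are placeholders, not details from the talk.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, eps = 5, 100_000, 1e-2

f = lambda x: np.tanh(x.sum())           # target concept f : R^n -> R (placeholder)
g = lambda x: np.tanh(x.sum()) + 0.05    # candidate hypothesis g (placeholder)

X = rng.normal(size=(m, n))              # samples x ~ D, here D = N(0, I)
mse = np.mean([(f(x) - g(x)) ** 2 for x in X])
print(mse, mse <= eps)                   # g "learns" f if E[(f - g)^2] <= eps
```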

  7. Deep Learning for a theoretician. Want to approximate concept f. How to choose model g? A simple "neural network" (NN): input x ∼ D (over R^n); hidden layer y_1 = σ(W_1 · x + b_1), y_2 = σ(W_2 · x + b_2); output g(x) = W_3 · y + b_3.

  8. Deep Learning for a theoretician. Want to approximate concept f. How to choose model g? A simple "neural network" (NN): input x ∼ D (over R^n); hidden layer y_1 = σ(W_1 · x + b_1), y_2 = σ(W_2 · x + b_2); output g(x) = W_3 · y + b_3, with weights W_i and activation function σ, e.g., the sigmoid σ(x) = 1 / (1 + e^{−x}).
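A minimal NumPy sketch of such a one-hidden-layer network; collapsing the hidden units into a single weight matrix, as well as the layer sizes and random initialization, are illustrative assumptions.

```python
import numpy as np

def sigmoid(z):
    # sigma(z) = 1 / (1 + e^{-z})
    return 1.0 / (1.0 + np.exp(-z))

def forward(x, W1, b1, W2, b2):
    # Hidden layer y = sigma(W1 x + b1), followed by a linear output g(x) = W2 y + b2.
    y = sigmoid(W1 @ x + b1)
    return W2 @ y + b2

# Illustrative sizes: n inputs, k hidden units, scalar output.
n, k = 10, 16
rng = np.random.default_rng(0)
W1, b1 = rng.normal(size=(k, n)), np.zeros(k)
W2, b2 = rng.normal(size=(1, k)), np.zeros(1)
x = rng.normal(size=n)                     # x ~ D, here N(0, I) as in later slides
print(forward(x, W1, b1, W2, b2))
```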

  9. Deep Learning for a theoretician. How to "train" the network, i.e., choose W, b? Gradient descent: estimate the gradient from labeled samples (x_1, f(x_1)), (x_2, f(x_2)), ..., update the weights, and repeat: ∇ ≈ E_x[∇_W (f(x) − g(x))^2], then W ← W − η∇ for a step size η.
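A compact NumPy sketch of this loop for the squared loss (f − g)^2; the architecture sizes, step size, batch size, and initialization are illustrative assumptions rather than details from the talk.

```python
import numpy as np

def sgd_train(f, sample_x, steps=1000, lr=0.1, batch=64, n=10, k=16, seed=0):
    """Minimal SGD sketch for the squared loss E_x[(f(x) - g(x))^2].

    f        : target concept, f : R^n -> R (treated as a black box)
    sample_x : callable returning a batch of inputs drawn from D, shape (batch, n)
    """
    rng = rng_local = np.random.default_rng(seed)
    W1, b1 = rng.normal(size=(k, n)) / np.sqrt(n), np.zeros(k)
    w2, b2 = rng.normal(size=k) / np.sqrt(k), 0.0

    for _ in range(steps):
        X = sample_x(batch)                              # batch of x ~ D
        H = 1.0 / (1.0 + np.exp(-(X @ W1.T + b1)))       # hidden sigmoid activations
        pred = H @ w2 + b2                               # g(x) for each x in the batch
        err = pred - np.array([f(x) for x in X])         # g(x) - f(x)

        # Backpropagate the mean squared loss by hand.
        grad_out = 2.0 * err / batch
        gw2 = H.T @ grad_out
        gb2 = grad_out.sum()
        gH = np.outer(grad_out, w2) * H * (1.0 - H)
        gW1 = gH.T @ X
        gb1 = gH.sum(axis=0)

        # Gradient step: W <- W - eta * estimated gradient.
        W1 -= lr * gW1; b1 -= lr * gb1
        w2 -= lr * gw2; b2 -= lr * gb2
    return W1, b1, w2, b2
```

In practice the same gradients would be computed automatically by a framework (the slides later mention TensorFlow and Autograd); the point is only that training touches the data exclusively through such gradient estimates.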

  10. Guarantees for deep learning? Goal: Provable guarantees for NN training algorithms.

  11. Guarantees for deep learning? Goal: Provable guarantees for NN training algorithms when the data is generated by a NN.

  12. Guarantees for deep learning? Goal: Provable guarantees for NN training algorithms when the data is generated by a NN. Theorem (Cybenko 1989): Continuous functions can be approximated by 1-hidden-layer NNs with sigmoids.

  13. Guarantees for deep learning? Goal: Provable guarantees for NN training algorithms when the data is generated by a small one-hidden-layer NN. Theorem (Cybenko 1989): Continuous functions can be approximated by 1-hidden-layer NNs with sigmoids.

  14. Provable guarantees for deep learning? What could the form of such guarantees be? Under what conditions (on the input distribution and the function) does stochastic gradient descent work? Does it help if the data is generated by a NN? (Is the "realizable" case easier?)

  15. Guarantees for deep learning? Lower bounds for the realizable case: It is NP-hard to train a neural network with 3 threshold neurons (Blum and Rivest, 1993). Complexity/crypto assumptions ⇒ small-depth networks cannot be learned efficiently, even improperly (Klivans and Sherstov, 2006; Daniely, Linial, and Shalev-Shwartz, 2014). Even with nice input distributions, some deep learning algorithms don't work (Shamir, 2016).

  16. Outline. 1. Lower bounds for learning neural networks: "You can't efficiently learn functions computed by small, single-hidden-layer neural networks, even over nice input distributions." 2. Polynomial-time analysis of gradient descent: "Gradient descent can efficiently train single-hidden-layer neural networks of unbiased activation units." 3. Open questions.

  17. A nice family of neural networks: the construction. An n-dimensional input with each coordinate drawn from N(0, 1), a hidden layer of Õ(n) sigmoid units, and a linear output.

  18. A nice family of neural networks: smooth activation functions. The sigmoid function of sharpness s is e^{sx} / (1 + e^{sx}), shown for s = 1, 4, 12.
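A quick NumPy check of the sharpness parameter (the grid of x values below is arbitrary):

```python
import numpy as np

def sigmoid_s(x, s):
    # Sigmoid of sharpness s: e^{s x} / (1 + e^{s x}); larger s gives a steeper step.
    return 1.0 / (1.0 + np.exp(-s * x))

xs = np.linspace(-2, 2, 9)
for s in (1, 4, 12):                 # the sharpness values shown on the slide
    print(s, np.round(sigmoid_s(xs, s), 3))
```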

  19. A little more generality. A nice family of neural networks: the construction. An n-dimensional input with each coordinate drawn from a logconcave distribution, a hidden layer of ReLU, PReLU, softplus, or sigmoid units, and a linear output.

  20. Use deep learning! Choose your favorite network architecture, activation units, loss function, gradient descent variant, regularization scheme...

  21. A computational lower bound. Choose your favorite network architecture, activation units, loss function, gradient descent variant, regularization scheme, etc. "Theorem": If the algorithm uses only "black box" functions of the input (e.g., gradients via TensorFlow, Hessians via Autograd), then it needs 2^{Ω(n)} function evaluations, even when each evaluation is computed to inverse-polynomial accuracy in n and s.

  22. A little more context. (The construction again: an n-dimensional input with each coordinate drawn from N(0, 1), a hidden layer of Õ(n) sigmoid units, and a linear output.) Janzamin, Sedghi, and Anandkumar (2015): a tensor decomposition algorithm under additional assumptions; sample size = poly(n, condition number of the weights). Shamir (2016) gives exponential lower bounds against "vanilla" SGD with mean-squared loss, ReLU units, and Gaussian input (non-realizable, but a similar construction). More recent improvements on upper bounds (coming up).

  23. A little more generality. (The construction with an n-dimensional input drawn coordinate-wise from a logconcave distribution, a hidden layer of ReLU, PReLU, softplus, or sigmoid units, and a linear output.) The lower bound applies to algorithms of the following form: 1. Estimate v = E_{(x,y)∼D}[h(W, x, y)], where W is the current weight vector, (x, y) is a labeled example from the input distribution D, and h is an arbitrary [0, 1]-valued function. 2. Use v to update W.

  24. The hard family of functions. φ(x) = σ(1/s + x) + σ(1/s − x) − 1, and F_σ(x) = φ(x) + φ(x − 2/s) + φ(x + 2/s) + ···. F_σ : R → R is an affine combination of σ-units, almost periodic on [−Õ(n), Õ(n)], with period 1/s.

  25. The hard family of functions, built from a unit r (such as a ReLU): φ_0(x) = r(x + 1/(2s)) − r(x − 1/(2s)), φ(x) = φ_0(1/(2s) + x) + φ_0(1/(2s) − x) − 1, and F_r(x) = φ(x) + φ(x − 2/s) + φ(x + 2/s) + ···.

  26. The hard family of functions. F_σ : R → R is an affine combination of σ-units with weights ±1, almost periodic on [−Õ(n), Õ(n)], with period 1/s. For each S ⊆ {1, ..., n} with |S| = n/2, define f_S(x) = F_σ(Σ_{i∈S} x_i); the input coordinates are drawn from a logconcave distribution.
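A small NumPy sketch of this construction; taking σ to be the sharpness-s sigmoid from slide 18 and truncating the sum to finitely many shifted copies of φ are assumptions made purely for illustration.

```python
import numpy as np

def sigmoid_s(x, s):
    # Sigmoid of sharpness s (assumed to be the sigma in the construction).
    return 1.0 / (1.0 + np.exp(-s * x))

def F_sigma(x, s, num_periods=50):
    # F_sigma(x) = sum_k phi(x - 2k/s), phi(x) = sigma(1/s + x) + sigma(1/s - x) - 1.
    # num_periods is an illustrative truncation; the talk uses enough shifted
    # copies to cover roughly [-O~(n), O~(n)].
    phi = lambda z: sigmoid_s(1.0 / s + z, s) + sigmoid_s(1.0 / s - z, s) - 1.0
    ks = np.arange(-num_periods, num_periods + 1)
    return sum(phi(x - 2.0 * k / s) for k in ks)

def f_S(x, S, s):
    # Hard instance: f_S(x) = F_sigma( sum_{i in S} x_i ), S a hidden half of the coordinates.
    return F_sigma(np.sum(x[list(S)]), s)

# Tiny illustration (dimension and sharpness chosen arbitrarily).
n, s = 20, 4
rng = np.random.default_rng(1)
S = rng.choice(n, size=n // 2, replace=False)   # hidden subset of size n/2
x = rng.normal(size=n)                          # x with i.i.d. N(0, 1) coordinates
print(f_S(x, S, s))
```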

  27. Throw some "deep learning" at it! Choose your favorite: network architecture, activation units, loss function, gradient descent variant, regularization scheme...

  28. Theory vs. practice, revisited. (Plot: training error with sigmoid units versus s·√n, with one curve for each of the sizes 50, 100, 200, 500, 1000.)

  29. Theory vs. practice, revisited

  30. Throw some "deep learning" at it! Choose your favorite: network architecture, activation units, loss function, gradient descent variant, regularization scheme... Were you successful?

  31. Try "deep learning"! Choose your favorite: network architecture, activation units, loss function, gradient descent variant, regularization scheme... Were you successful? "Theorem": No!

  32. Statistical Query Algorithms. Recall the gradient descent training algorithm: from labeled data (x_1, f(x_1)), (x_2, f(x_2)), ..., estimate the gradient ∇ ≈ E_x[∇_W (f(x) − g(x))^2], update the weights W ← W − η∇, and repeat.

  33. Statistical Query Algorithms. Recall the gradient descent training algorithm: query the gradient ∇ ≈ E_x[∇_W (f(x) − g(x))^2] and update the weights W ← W − η∇. Only a gradient estimate is needed, not the labeled examples themselves.

  34. Statistical Query Algorithms. Statistical query (SQ) algorithms were introduced by Kearns in 1993. The algorithm has no direct access to samples; it queries expectations of functions over the labeled-example distribution. Query (h, τ): h : R^n × R → [0, 1], τ > 0. The oracle responds with any v such that |E[h] − v| < τ. E.g., for gradient descent, query h = ∇_W (f − g)^2.
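A toy simulation of this oracle interface; implementing it by sampling, and the helper sample_xy, are assumptions made for illustration (in the SQ model the oracle is an abstraction that answers directly about the distribution).

```python
import numpy as np

def sq_oracle(h, tau, sample_xy, num_samples=100_000, rng=None):
    """Simulate STAT(tau): return some v with |E[h(x, y)] - v| < tau.

    h         : query function with values in [0, 1]
    sample_xy : callable returning one labeled example (x, y) from D
    num_samples is assumed large enough that the empirical mean is within
    tau/2 of the true expectation; the oracle may then perturb the answer
    adversarially by up to another tau/2.
    """
    rng = rng or np.random.default_rng()
    est = np.mean([h(*sample_xy()) for _ in range(num_samples)])
    return est + rng.uniform(-tau / 2, tau / 2)
```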

  35. Statistical Query Algorithms. SQ algorithms are extremely general: almost all "robust" machine learning guarantees can be achieved with SQ algorithms.
