Exponential convergence of testing error for stochastic gradient methods

Loucas Pillaud-Vivien (INRIA - École Normale Supérieure, Paris, France)
Joint work with Alessandro Rudi and Francis Bach
COLT - July 2018
Stochastic Gradient Descent

Minimizes a function $F$ given unbiased estimates of its gradients:
$$g_k = g_{k-1} - \gamma_k \nabla F_k(g_{k-1}).$$

A workhorse in Machine Learning:
◮ $n$ input-output samples $(x_i, y_i)_{i \leqslant n}$.
◮ One observation at each step → complexity $O(d)$ per iteration.

Regression problems: best convergence rates $O(1/\sqrt{n})$ or $O(1/n)$ (Nemirovski and Yudin, 1983; Polyak and Juditsky, 1992).
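As a concrete reference point, here is a minimal NumPy sketch of the recursion above on a least-squares regression problem, processing one observation per step. The data, step-size schedule, and helper names are illustrative choices, not taken from the talk.

```python
import numpy as np

def sgd(grad_estimate, g0, step_size, n_steps):
    """Run the recursion g_k = g_{k-1} - gamma_k * grad_estimate(k, g_{k-1}),
    where grad_estimate(k, g) returns an unbiased estimate of the gradient of F at g."""
    g = g0.copy()
    for k in range(1, n_steps + 1):
        g = g - step_size(k) * grad_estimate(k, g)
    return g

# Least-squares regression, one observation per step: each iteration costs O(d).
rng = np.random.default_rng(0)
n, d = 1000, 5
X = rng.standard_normal((n, d))
w_true = rng.standard_normal(d)
y = X @ w_true + 0.1 * rng.standard_normal(n)

# Stochastic gradient of the (half) squared error on sample k; unbiased if samples are i.i.d.
grad = lambda k, g: (X[k - 1] @ g - y[k - 1]) * X[k - 1]
w_hat = sgd(grad, np.zeros(d), step_size=lambda k: 0.1 / np.sqrt(k), n_steps=n)
```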
Stochastic Gradient Descent

◮ Regression problems: best convergence rates $O(1/\sqrt{n})$ or $O(1/n)$.
◮ Can it be faster for classification problems?

Take-home message: Yes, SGD converges exponentially fast in classification error under a margin condition.
Binary classification: problem setting

◮ Data: $(x, y) \in \mathcal{X} \times \{-1, 1\}$ distributed according to $\rho$.
◮ Prediction: $\hat{y} = \operatorname{sign} g(x)$, with $g(x) = \langle g, \phi(x) \rangle_{\mathcal{H}}$.
◮ Aim: minimize over $g \in \mathcal{H}$ the error, $F_{01}(g) = \mathbb{E}\, \ell_{01}(y, g(x)) = \mathbb{E}\, 1_{y g(x) < 0}$.

From error to losses: as $\ell_{01}$ is non-convex, we use the square loss.

[Figure: the 0-1, square, hinge and logistic losses as functions of the margin $y g(x)$.]

◮ Square loss: $F(g) = \mathbb{E}\, \ell(y, g(x)) = \mathbb{E}\, (y - g(x))^2$, minimized by $g_*(x) = \mathbb{E}(y \mid x)$.
◮ Ridge regression: $F_\lambda(g) = \mathbb{E}\, (y - g(x))^2 + \lambda \|g\|_{\mathcal{H}}^2$, minimized by $g_\lambda$.
◮ Excess error and loss (Bartlett et al., 2006):
$$\underbrace{\mathbb{E}\, \ell_{01}(y, g(x)) - \ell_{01}^*}_{\text{Excess error}} \;\lesssim\; \sqrt{\underbrace{\mathbb{E}\, (y - g(x))^2 - \ell^*}_{\text{Excess loss}}}.$$
◮ If we use existing results for SGD: $\mathbb{E}\, \ell_{01}(y, g(x)) - \ell_{01}^* \lesssim \frac{1}{\sqrt{\lambda n}}$ → not exponential.
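The loss curves in the figure are easy to reproduce as functions of the margin $m = y g(x)$, using the fact that $y \in \{-1, 1\}$ implies $(y - g(x))^2 = (1 - y g(x))^2$. A minimal sketch follows; the normalization of the logistic loss in the original plot may differ.

```python
import numpy as np

# Losses as functions of the margin m = y * g(x); since y in {-1, 1},
# (y - g(x))^2 = (1 - y g(x))^2, so the square loss is (1 - m)^2.
def zero_one(m):  return (m < 0).astype(float)
def square(m):    return (1.0 - m) ** 2
def hinge(m):     return np.maximum(0.0, 1.0 - m)
def logistic(m):  return np.log1p(np.exp(-m))

m = np.linspace(-3.0, 3.0, 601)
# The convex surrogates dominate the 0-1 loss (up to rescaling for the logistic),
# which is what makes minimizing them meaningful for classification.
assert np.all(square(m) >= zero_one(m)) and np.all(hinge(m) >= zero_one(m))
```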
Main assumptions

Margin condition (Mammen and Tsybakov, 1999)
◮ Hard inputs to predict: $P(y = 1 \mid x) = 1/2$, i.e., $\mathbb{E}(y \mid x) = 0$.
◮ Easy inputs to predict: $P(y = 1 \mid x) \in \{0, 1\}$, i.e., $|\mathbb{E}(y \mid x)| = 1$.
→ Margin condition: $\exists\, \delta > 0$ such that $|\mathbb{E}(y \mid x)| \geqslant \delta$ for all $x \in \operatorname{supp}(\rho_X)$.
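For intuition, a toy distribution satisfying the margin condition can be built by keeping $|\mathbb{E}(y \mid x)|$ bounded away from zero. The sketch below uses an arbitrary choice with $\delta = 0.8$; it is not the distribution used in the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

def conditional_mean(x):
    # E(y | x) on X = [0, 1]; |E(y | x)| = 0.8 everywhere, so the margin condition holds with delta = 0.8.
    return 0.8 * np.sign(np.sin(2.0 * np.pi * x))

def sample(n):
    x = rng.uniform(0.0, 1.0, size=n)
    p_plus = (1.0 + conditional_mean(x)) / 2.0  # P(y = 1 | x)
    y = np.where(rng.uniform(size=n) < p_plus, 1.0, -1.0)
    return x, y

x, y = sample(10_000)
print(f"empirical min of |E(y|x)|: {np.abs(conditional_mean(x)).min():.2f}")  # ~0.80
```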
Main assumptions

◮ (A1) Margin condition: $\exists\, \delta > 0$ such that $|\mathbb{E}(y \mid x)| \geqslant \delta$ for all $x \in \operatorname{supp}(\rho_X)$.
◮ (A2) Technical condition: $\exists\, \lambda > 0$ such that $\operatorname{sign}(\mathbb{E}(y \mid x))\, g_\lambda(x) \geqslant \delta/2$ for all $x \in \operatorname{supp}(\rho_X)$.

Consequence: for $\hat{g}$ such that $\|g_\lambda - \hat{g}\|_{L^\infty} < \delta/2$, $\operatorname{sign} \hat{g}(x) = \operatorname{sign}(\mathbb{E}(y \mid x))$.

Single-pass SGD through the data on the regularized problem:
$$g_n = g_{n-1} - \gamma_n \big[ (\langle \phi(x_n), g_{n-1} \rangle - y_n)\, \phi(x_n) + \lambda (g_{n-1} - g_0) \big].$$
Take the tail-averaged estimator $\bar{g}^{\text{tail}}_n = \frac{1}{n/2} \sum_{i = n/2}^{n} g_i$ (Jain et al., 2016).

Theorem: Exponential convergence of SGD for the test error
Assume $n \geqslant \frac{1}{\lambda\gamma} \log \frac{R}{\delta}$ and $\gamma \leqslant 1/(4R^2)$; then
$$\mathbb{E}_{x_1, \dots, x_n}\, \mathbb{E}\, \ell_{01}\big(y, \bar{g}^{\text{tail}}_n(x)\big) - \ell_{01}^* \;\leqslant\; 4 \exp\big(-\lambda^2 \delta^2 n / R^2\big).$$
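A possible finite-dimensional sketch of this single-pass recursion with tail averaging is given below, using random Fourier features as a stand-in for the RKHS feature map $\phi$ and the toy margin-condition distribution from the previous sketch. All concrete values (feature dimension, $\gamma$, $\lambda$) are illustrative choices, not the paper's.

```python
import numpy as np

def tail_averaged_sgd(phi, xs, ys, gamma, lam, dim):
    """Single pass of the regularized recursion
        g_n = g_{n-1} - gamma * [ (<phi(x_n), g_{n-1}> - y_n) * phi(x_n) + lam * (g_{n-1} - g_0) ],
    then average the iterates of the second half of the pass (tail averaging)."""
    n = len(xs)
    g = np.zeros(dim)                      # g_0 = 0
    g0 = g.copy()
    tail_sum, tail_count = np.zeros(dim), 0
    for i in range(n):
        f = phi(xs[i])
        g = g - gamma * ((f @ g - ys[i]) * f + lam * (g - g0))
        if i >= n // 2:                    # keep only the tail iterates
            tail_sum += g
            tail_count += 1
    return tail_sum / tail_count

# Random Fourier features as a stand-in for the RKHS feature map phi.
rng = np.random.default_rng(1)
dim = 200
w = 5.0 * rng.standard_normal(dim)
b = rng.uniform(0.0, 2.0 * np.pi, size=dim)
phi = lambda x: np.sqrt(2.0 / dim) * np.cos(w * x + b)    # ||phi(x)||^2 <= 2, so R^2 <= 2

# Data from the toy margin-condition distribution (delta = 0.8).
n = 4000
xs = rng.uniform(0.0, 1.0, size=n)
eta = 0.8 * np.sign(np.sin(2.0 * np.pi * xs))
ys = np.where(rng.uniform(size=n) < (1.0 + eta) / 2.0, 1.0, -1.0)

g_bar = tail_averaged_sgd(phi, xs, ys, gamma=0.1, lam=1e-3, dim=dim)   # gamma <= 1/(4 R^2)

# Excess 0-1 error of sign(<phi(x), g_bar>) on fresh test points.
x_test = rng.uniform(0.0, 1.0, size=2000)
eta_test = 0.8 * np.sign(np.sin(2.0 * np.pi * x_test))
preds = np.sqrt(2.0 / dim) * np.cos(np.outer(x_test, w) + b) @ g_bar
print(f"excess test error: {np.mean(np.abs(eta_test) * (np.sign(preds) != np.sign(eta_test))):.3f}")
```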
Main result

Theorem: Exponential convergence of SGD for the test error
Assume $n \geqslant \frac{1}{\lambda\gamma} \log \frac{R}{\delta}$ and $\gamma \leqslant 1/(4R^2)$; then
$$\mathbb{E}_{x_1, \dots, x_n}\, \mathbb{E}\, \ell_{01}\big(y, \bar{g}^{\text{tail}}_n(x)\big) - \ell_{01}^* \;\leqslant\; 4 \exp\big(-\lambda^2 \delta^2 n / R^2\big).$$

◮ Main tool for the proof: a high-probability bound in $\|\cdot\|_{L^\infty}$ for the SGD recursion (a result of independent interest).
◮ The excess testing loss is not exponentially convergent.
◮ Motivations to look at the paper and come to the poster session:
  ◮ Sharper bounds
  ◮ High-probability bound in $\|\cdot\|_{L^\infty}$ (instead of the usual $\|\cdot\|_{L^2}$) for the SGD recursion
  ◮ Bounds for regular (full) averaging
  ◮ Bounds under a general low-noise condition
Synthetic experiments

◮ Comparing test/train losses and errors for tail-averaged SGD ($\mathcal{X} = [0, 1]$, $\mathcal{H}$ a Sobolev space).

[Figure: (a) excess losses, (b) excess errors.]
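As a sketch of how the two plotted quantities could be estimated on synthetic data (where the conditional mean $\eta(x) = \mathbb{E}(y \mid x)$ is known by construction), one can use the standard identities below; `predict` and `eta` are placeholder callables, not code from the paper.

```python
import numpy as np

def excess_error_and_loss(predict, eta, x_test):
    """Monte-Carlo estimates of the two quantities plotted in the experiments, given the
    conditional mean eta(x) = E(y | x) of the synthetic distribution and a predictor g = predict."""
    g = predict(x_test)
    eta_x = eta(x_test)
    # Excess 0-1 error: the classifier pays |E(y|x)| exactly where sign(g) differs from sign(E(y|x)).
    excess_error = np.mean(np.abs(eta_x) * (np.sign(g) != np.sign(eta_x)))
    # Excess square loss: F(g) - F^* = E[(g(x) - E(y|x))^2].
    excess_loss = np.mean((g - eta_x) ** 2)
    return excess_error, excess_loss
```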
Conclusion

Take-home message:
◮ Exponential convergence of the test error, but not of the test loss
◮ Importance of the margin condition
◮ High-probability bound for averaged and regularized SGD

Possible extensions:
◮ No regularization
◮ Study the effect of the regularity of the problem
◮ Beyond least-squares
Thank you for your attention! Come see us at the poster session!