Learning and Optimization: Lower Bounds and Tight Connections
Nati Srebro, TTI-Chicago
On The Universality of Online Mirror Descent. S, Karthik Sridharan (UPenn), Ambuj Tewari (Michigan), NIPS’11
Learning from an Optimization Viewpoint. Karthik Sridharan, TTIC PhD Thesis
Learning/Optimization over $L_2$ Ball
• SVM: $\ell(h(x);y) = [1 - y\cdot h(x)]_+$
• Stat Learning / Stoch Optimization: $\min_{\|w\|_2 \le B} L(w) = \mathbb{E}_{x,y\sim\mathcal D}[\ell(\langle w,x\rangle;y)]$, based on $m$ iid samples $x,y\sim\mathcal D$ with $\|x\|_2 \le R$; empirical risk $\hat L(w) = \frac1m\sum_t \ell(h(x_t);y_t)$
• Using SAA/ERM, $\hat w = \arg\min \hat L(w)$: $L(\hat w) \le \inf_{\|w\|\le B} L(w) + 2\sqrt{B^2R^2/m}$
• Rate of 1st-order (or any local) optimization: $\hat L(w_T) \le \inf_{\|w\|\le B} \hat L(w) + \sqrt{B^2R^2/T}$
• Using SA/SGD on $L(w)$, $w_{t+1} \leftarrow w_t - \eta_t \nabla_w \ell(\langle w_t,x_t\rangle;y_t)$: $L(\bar w_m) \le \inf_{\|w\|\le B} L(w) + \sqrt{B^2R^2/m}$
[Bottou Bousquet 08][S Shalev-Shwartz 08][Juditsky Lan Nemirovski Shapiro 09]
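The SA/SGD update above can be sketched for the hinge loss as follows. This is a minimal sketch, not the slides' implementation: the step size $\eta_t = B/(R\sqrt t)$, the projection back onto the ball, and the helper names (`hinge_subgradient`, `project_l2_ball`, `one_pass_sgd`) are our own assumptions. Each sample is used exactly once, and the output is the averaged iterate $\bar w_m$.

```python
import math
import random


def hinge_subgradient(w, x, y):
    """Subgradient of [1 - y * <w, x>]_+ with respect to w."""
    margin = y * sum(wi * xi for wi, xi in zip(w, x))
    if margin < 1.0:
        return [-y * xi for xi in x]
    return [0.0] * len(x)


def project_l2_ball(w, radius):
    """Euclidean projection onto {w : ||w||_2 <= radius}."""
    norm = math.sqrt(sum(wi * wi for wi in w))
    if norm <= radius:
        return list(w)
    return [wi * radius / norm for wi in w]


def one_pass_sgd(samples, B, R):
    """One pass of projected SGD; returns the average of the iterates w_1..w_m."""
    d = len(samples[0][0])
    w = [0.0] * d
    w_sum = [0.0] * d
    for t, (x, y) in enumerate(samples, start=1):
        w_sum = [s + wi for s, wi in zip(w_sum, w)]  # accumulate w_t before updating
        eta = B / (R * math.sqrt(t))                 # assumed schedule eta_t = B / (R sqrt(t))
        g = hinge_subgradient(w, x, y)
        w = project_l2_ball([wi - eta * gi for wi, gi in zip(w, g)], B)
    m = len(samples)
    return [s / m for s in w_sum]


# Toy usage: points in the unit ball (R = 1), labels from a linear rule.
random.seed(0)
data = []
for _ in range(200):
    x = [random.uniform(-0.7, 0.7), random.uniform(-0.7, 0.7)]
    data.append((x, 1.0 if x[0] + x[1] > 0 else -1.0))
w_bar = one_pass_sgd(data, B=1.0, R=1.0)
print(w_bar)
```

Because every iterate stays in the ball and $\bar w_m$ is their average, $\bar w_m$ is itself feasible; the cost is one gradient estimate per sample.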
Learning/Optimization over $L_2$ Ball
• (Deterministic) Optimization: $\sqrt{B^2R^2/T}$, with $T$ = runtime (# gradient evaluations), $B$ = radius of the optimization domain, $R$ = Lipschitz constant
• Statistical Learning: $\sqrt{B^2R^2/m}$, with $m$ = # samples, $B$ = radius of the hypothesis class, $R$ = radius of the data
• Stoch. Aprx. / One-pass SGD: $\sqrt{B^2R^2/T}$, with $T$ = # gradient estimates = # samples = runtime
• Online Learning (average regret): $\sqrt{B^2R^2/T}$, with $T$ = # rounds
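Inverting any of these rates gives the budget needed for accuracy $\epsilon$. A worked instance, with the values $B = R = 1$ and $\epsilon = 0.01$ chosen purely for illustration:

```latex
\sqrt{\frac{B^2R^2}{T}} \le \epsilon
\iff T \ge \frac{B^2R^2}{\epsilon^2},
\qquad
B = R = 1,\ \epsilon = 0.01 \;\Rightarrow\; T \ge 10^4 .
```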
Questions
• What about other (convex) learning problems (other geometries)?
  – Is Stochastic Approximation always optimal?
  – Are the rates for learning (# of samples) and optimization (runtime / # of accesses) always the same?
Outline
• Deterministic Optimization vs. Stat. Learning
  – Main result: fat shattering as a lower bound on optimization
  – Conclusion: sample complexity ≤ optimization runtime
• Stochastic Approximation for Learning
  – Online learning (very briefly)
  – Optimality of Online Mirror Descent
Optimization Complexity: $\min_{w\in\mathcal W} f(w)$
• Optimization problem defined by:
  – Optimization space $\mathcal W$
  – Function class $\mathcal F \subseteq \{f:\mathcal W\to\mathbb R\}$
• Runtime to get accuracy $\epsilon$:
  – Input: instance $f\in\mathcal F$, $\epsilon > 0$
  – Output: $w\in\mathcal W$ s.t. $f(w) \le \inf_{w\in\mathcal W} f(w) + \epsilon$
• Count the number of local black-box accesses to $f(\cdot)$: $O_f: w \mapsto f(w), \nabla f(w)$, or any other "local" information (for every neighborhood $N(w)$: $f_1 = f_2$ on $N(w) \Rightarrow O_{f_1}(w) = O_{f_2}(w)$)
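To make the access-counting model concrete, here is a minimal sketch of a black-box first-order oracle that counts queries, driven by a plain fixed-step subgradient method. The class and function names are hypothetical, and the example instance $f(w) = |w_0 - 1| + |w_1 + 2|$ is chosen only for illustration; in this model the complexity is `oracle.calls`, not wall-clock time.

```python
from typing import Callable, List, Tuple

Vector = List[float]


class LocalOracle:
    """Black-box oracle O_f: w -> (f(w), a subgradient of f at w); counts accesses."""

    def __init__(self, f: Callable[[Vector], float],
                 subgrad: Callable[[Vector], Vector]) -> None:
        self._f = f
        self._subgrad = subgrad
        self.calls = 0  # number of local accesses made so far

    def query(self, w: Vector) -> Tuple[float, Vector]:
        self.calls += 1
        return self._f(w), self._subgrad(w)


def subgradient_method(oracle: LocalOracle, w0: Vector,
                       steps: int, eta: float) -> Vector:
    """Fixed-step subgradient method: one oracle access per step; returns the best iterate queried."""
    w = list(w0)
    best_w, best_val = list(w), float("inf")
    for _ in range(steps):
        val, g = oracle.query(w)
        if val < best_val:
            best_w, best_val = list(w), val
        w = [wi - eta * gi for wi, gi in zip(w, g)]
    return best_w


def sign(v: float) -> float:
    return float((v > 0) - (v < 0))


# Example instance: f(w) = |w0 - 1| + |w1 + 2|, minimized at (1, -2) with value 0.
def f(w: Vector) -> float:
    return abs(w[0] - 1.0) + abs(w[1] + 2.0)


def f_subgrad(w: Vector) -> Vector:
    return [sign(w[0] - 1.0), sign(w[1] + 2.0)]


oracle = LocalOracle(f, f_subgrad)
w_best = subgradient_method(oracle, [0.0, 0.0], steps=50, eta=0.1)
print(oracle.calls, f(w_best))
```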
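Generalized Lipschitz Problems: $\min_{w\in\mathcal W} f(w)$
• We will consider problems where:
  – $\mathcal W$ is a convex subset of a vector space $\mathcal L$ (e.g. $\mathbb R^d$ or infinite-dimensional)
  – $\mathcal X$ is a convex subset of $\mathcal L^*$
  – $\mathcal F = \mathcal F_{lip}(\mathcal X) = \{f:\mathcal W\to\mathbb R \text{ convex} \mid \forall w\ \nabla f(w)\in\mathcal X\}$
• Examples:
  – $\mathcal X = \{\|x\|_2 \le 1\}$ corresponds to the standard notion of Lipschitz functions
  – $\mathcal X = \{\|x\|_* \le 1\}$ corresponds to functions Lipschitz w.r.t. the norm $\|\cdot\|$ (whose dual norm is $\|\cdot\|_*$)
• Theorem (Main Result): The $\epsilon$-fat-shattering dimension of $\mathrm{lin}(\mathcal W,\mathcal X)$ is a lower bound on the number of accesses required to optimize $\mathcal F_{lip}$ to within $\epsilon$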
Fat Shattering
• Definition: $x_1,\dots,x_m\in\mathcal X$ are $\epsilon$-fat shattered by $\mathcal W$ if there exist scalars $t_1,\dots,t_m$ such that for every sign pattern $y_1,\dots,y_m$ there exists $w\in\mathcal W$ with $y_i(\langle w,x_i\rangle - t_i) > \epsilon$ for all $i$.
• The $\epsilon$-fat-shattering dimension of $\mathrm{lin}(\mathcal W,\mathcal X)$ is the largest number of points $m$ that can be $\epsilon$-fat shattered.
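For the $L_2$ ball, a standard construction shows how large the fat-shattering dimension can be: take $x_i = R e_i$ and thresholds $t_i = 0$, and for each sign pattern use the witness $w = (B/\sqrt m)\,y$, which has norm $B$ and margin $RB/\sqrt m$. The sketch below (hypothetical function names, our own choice of instance) brute-forces all $2^m$ patterns for this particular witness; a `False` only means this witness fails, not that no witness exists.

```python
import itertools
import math
from typing import List


def orthonormal_witness_works(m: int, B: float, R: float, eps: float) -> bool:
    """Check that x_i = R * e_i (i = 1..m) are eps-fat shattered by the L2 ball of
    radius B with thresholds t_i = 0, using the witness w = (B / sqrt(m)) * y.
    Returning False only means this particular witness fails."""
    xs: List[List[float]] = [[R if j == i else 0.0 for j in range(m)] for i in range(m)]
    for y in itertools.product((-1.0, 1.0), repeat=m):
        w = [B / math.sqrt(m) * yi for yi in y]
        if math.sqrt(sum(wi * wi for wi in w)) > B + 1e-12:
            return False  # witness would leave the hypothesis class W
        margins = [y[i] * sum(wj * xj for wj, xj in zip(w, xs[i])) for i in range(m)]
        if min(margins) <= eps:
            return False  # need y_i (<w, x_i> - t_i) > eps for every i
    return True


def witness_dim(B: float, R: float, eps: float) -> int:
    """Largest m with R * B / sqrt(m) > eps, i.e. the largest integer m < (B * R / eps) ** 2."""
    return math.ceil((B * R / eps) ** 2) - 1


print(witness_dim(1.0, 1.0, 0.3))
```

Since this witness works whenever $m < (BR/\epsilon)^2$, the $\epsilon$-fat-shattering dimension of the $L_2$ problem is at least $\lceil (BR/\epsilon)^2\rceil - 1$, of the same order as the $B^2R^2/\epsilon^2$ budgets implied by the $L_2$ rates.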
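Optimization, ERM and Learning
• Supervised learning with linear predictors: $\hat L(w) = \frac1m\sum_{t=1}^m \mathrm{loss}(\langle w,x_t\rangle, y_t)$, with $x_t\in\mathcal X$ and the loss 1-Lipschitz
• ERM: $\hat w = \arg\min_{w\in\mathcal W} \hat L(w)$
• Gradient of the (empirical) risk: $\nabla\hat L(w) \in \mathrm{conv}(\mathcal X)$
• Learning guarantee: if for some $q \ge 2$, $\text{fat-dim}(\epsilon) \le (V/\epsilon)^q$, then $L(\hat w) \le \inf_{w\in\mathcal W} L(w) + O(V\log^{1.5}(m)/m^{1/q})$
• Conclusion: For $q \ge 2$, if there exists $V$ s.t. the rate of optimization is at most $\epsilon(T) \le V/T^{1/q}$, then the statistical rate of the associated learning problem is at most $\epsilon(m) \le 36\,V\log^{1.5}(m)/m^{1/q}$, since by the main result $\text{fat-dim}(\epsilon)$ is at most the $(V/\epsilon)^q$ accesses that suffice to reach accuracy $\epsilon$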
Convex Learning ⇒ Linear Prediction
• Consider learning with a hypothesis class $\mathcal H = \{h:\mathcal X\to\mathbb R\}$: $\hat L(h) = \frac1m\sum_{t=1}^m \mathrm{loss}(h(x_t), y_t)$
• With any meaningful loss, $\hat L(h_w)$ will be convex in a parameterization $w$ only if $h_w(x)$ is linear in $w$, i.e. $h_w(x) = \langle w,\phi(x)\rangle$
• A rich variety of learning problems is obtained with different (sometimes implicit) choices of linear hypothesis classes, feature mappings $\phi$, and loss functions.
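A small illustration of the linear-in-$w$ point: with a hypothetical polynomial feature map $\phi(x) = (1, x, x^2)$, the predictor $h_w(x) = \langle w,\phi(x)\rangle$ is nonlinear in $x$ but linear in $w$, so the hinge empirical risk stays convex in $w$. The sketch spot-checks midpoint convexity on random pairs; it is a numerical sanity check under these assumptions, not a proof.

```python
import random
from typing import List, Sequence, Tuple

Vector = List[float]


def phi(x: float) -> Vector:
    """Hypothetical feature map: polynomial features (1, x, x^2)."""
    return [1.0, x, x * x]


def h(w: Sequence[float], x: float) -> float:
    """Predictor h_w(x) = <w, phi(x)>: nonlinear in x, linear in w."""
    return sum(wi * fi for wi, fi in zip(w, phi(x)))


def hinge(z: float, y: float) -> float:
    return max(0.0, 1.0 - y * z)


def empirical_risk(w: Sequence[float], data: Sequence[Tuple[float, float]]) -> float:
    return sum(hinge(h(w, x), y) for x, y in data) / len(data)


random.seed(1)
data = [(random.uniform(-1.0, 1.0), random.choice((-1.0, 1.0))) for _ in range(50)]
violations = 0
for _ in range(200):
    a = [random.gauss(0.0, 1.0) for _ in range(3)]
    b = [random.gauss(0.0, 1.0) for _ in range(3)]
    mid = [(ai + bi) / 2 for ai, bi in zip(a, b)]
    # Convexity in w: risk at the midpoint is at most the average of the endpoint risks.
    if empirical_risk(mid, data) > (empirical_risk(a, data) + empirical_risk(b, data)) / 2 + 1e-9:
        violations += 1
print("midpoint-convexity violations:", violations)
```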
Linear Prediction
• Gradient space $\mathcal X$ is the learning data domain (i.e. the space learning inputs come from), or the image of the feature map $\phi$
  – $\phi$ specified via a kernel (as in SVMs, kernelized logistic or ridge regression)
  – In boosting: coordinates of $\phi$ are "weak learners"
  – $\phi$ can specify evaluations (as in collaborative filtering, total variation problems)
• Optimization space $\mathcal W$ is the hypothesis class, the set of allowed linear predictors; it corresponds to the choice of "regularization"
  – $L_2$ (SVMs, ridge regression)
  – $L_1$ (LASSO, Boosting)
  – Elastic net, other interpolations
  – Group norms
  – Matrix norms: trace-norm, max-norm, etc. (e.g. for collaborative filtering and multi-task learning)
• Loss function need only be (scalar) Lipschitz
  – hinge, logistic, etc.
  – structured loss, where $y_i$ is non-binary (CRFs, translation, etc.)
  – exp-loss (Boosting), squared loss ⇒ NOT globally Lipschitz
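To see which losses the framework covers, the sketch below estimates $\sup_z |\ell'(z)|$ on a grid of predictions $z$ with label $y = +1$. The function names, the grid $[-20, 20]$, and the choice of label are assumptions for illustration. Hinge and logistic slopes stay bounded by 1, while the squared and exponential losses have slopes that grow with the range of $z$, i.e. they are not globally Lipschitz.

```python
import math
from typing import Callable


# Derivatives d/dz of common losses at prediction z (hypothetical helper names).
def hinge_grad(z: float, y: float = 1.0) -> float:
    # subgradient of max(0, 1 - y z)
    return -y if y * z < 1.0 else 0.0


def logistic_grad(z: float, y: float = 1.0) -> float:
    # d/dz log(1 + exp(-y z)) = -y / (1 + exp(y z))
    return -y / (1.0 + math.exp(y * z))


def squared_grad(z: float, y: float = 1.0) -> float:
    # d/dz (z - y)^2 / 2
    return z - y


def exp_grad(z: float, y: float = 1.0) -> float:
    # d/dz exp(-y z)
    return -y * math.exp(-y * z)


def max_abs_slope(grad: Callable[[float], float], lo: float = -20.0,
                  hi: float = 20.0, n: int = 4001) -> float:
    """Largest |loss'(z)| on a grid of z in [lo, hi]: an empirical Lipschitz constant."""
    return max(abs(grad(lo + (hi - lo) * k / (n - 1))) for k in range(n))


for name, g in [("hinge", hinge_grad), ("logistic", logistic_grad),
                ("squared", squared_grad), ("exp", exp_grad)]:
    print(f"{name:9s} max |loss'| on [-20, 20]: {max_abs_slope(g):.3g}")
```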
Main Result
• Problems of the form $\min_{w\in\mathcal W} f(w)$:
  – $\mathcal W$ convex $\subset$ vector space $\mathcal B$ (e.g. $\mathbb R^n$, or infinite-dimensional)
  – $\mathcal X$ convex $\subset \mathcal B^*$
  – $f\in\mathcal F = \mathcal F_{lip}(\mathcal X) = \{f:\mathcal W\to\mathbb R \text{ convex} \mid \forall w\ \nabla f(w)\in\mathcal X\}$
• Theorem (Main Result): The $\epsilon$-fat-shattering dimension of $\mathrm{lin}(\mathcal W,\mathcal X)$ is a lower bound on the number of accesses required to optimize $f\in\mathcal F_{lip}$ to within $\epsilon$
• Conclusion: For $q \ge 2$, if for some $V$ the rate of ERM optimization is at most $\epsilon(T) \le V/T^{1/q}$, then the learning rate of the associated problem is at most $\epsilon(m) \le 36\,V\log^{1.5}(m)/m^{1/q}$
Proof of Main Result
• Theorem: The $\epsilon$-fat-shattering dimension of $\mathrm{lin}(\mathcal W,\mathcal X)$ is a lower bound on the number of accesses required to optimize $\mathcal F_{lip}$ to within $\epsilon$.
• That is, for any optimization algorithm, there exists a function $f\in\mathcal F_{lip}$ s.t. after $m = \text{fat-dim}(\epsilon)$ local accesses, the algorithm is still $\ge\epsilon$-suboptimal.
Proof overview:
• View optimization as a game where, at each round $t$:
  – the Optimizer asks for local information at $w_t$;
  – the Adversary responds, ensuring consistency with some $f\in\mathcal F$.
We will play the adversary, ensuring consistency with some $f\in\mathcal F$ where $\inf_w f(w) \le -\epsilon$, but where $f(w_t) \ge 0$.
Playing the Adversary
• $x_1,\dots,x_m$ fat-shattered with thresholds $s_1,\dots,s_m$, i.e. for all signs $y_1,\dots,y_m$ there exists $w$ s.t. $y_i(\langle w,x_i\rangle - s_i) \ge \epsilon$ for all $i$
• We will consider functions of the form $f_y(w) = \max_i y_i(s_i - \langle w,x_i\rangle)$
• Convex, piecewise linear
• (Sub)gradients are $-y_i x_i$ ⇒ $f_y\in\mathcal F_{lip}(\mathcal X)$ (taking $\mathcal X$ symmetric, e.g. a norm ball)
• Fat shattering ⇒ for all $y$, $\inf_w f_y(w) \le -\epsilon$
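A minimal sketch of the adversary's function class, assuming the orthonormal $L_2$ instance ($x_i = R e_i$, $s_i = 0$, ball of radius $B$), which is our choice rather than the slides'. It evaluates $f_y$ together with one subgradient $-y_{i^*} x_{i^*}$ at a maximizing index, and checks that the fat-shattering witness drives $f_y$ down to $-RB/\sqrt m$.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def f_y(w: Sequence[float], xs: Sequence[Sequence[float]], s: Sequence[float],
        y: Sequence[float]) -> Tuple[float, Vector]:
    """f_y(w) = max_i y_i (s_i - <w, x_i>) and one subgradient -y_i* x_i* at an active index i*."""
    terms = [y[i] * (s[i] - dot(w, xs[i])) for i in range(len(xs))]
    i_star = max(range(len(terms)), key=terms.__getitem__)  # ties: lowest index
    return terms[i_star], [-y[i_star] * xj for xj in xs[i_star]]


# Orthonormal instance: x_i = R e_i, thresholds s_i = 0, L2 ball of radius B.
m, B, R = 4, 1.0, 1.0
xs = [[R if j == i else 0.0 for j in range(m)] for i in range(m)]
s = [0.0] * m
y = [1.0, -1.0, -1.0, 1.0]
w_star = [B / math.sqrt(m) * yi for yi in y]  # fat-shattering witness for pattern y
value, g = f_y(w_star, xs, s, y)
print(value, g)
```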
Playing the Adversary: $f_y(w) = \max_i y_i(s_i - \langle w,x_i\rangle)$
• Goal: ensure consistency with some $f_y$ s.t. $f_y(w_t) \ge 0$
• How: maintain the model $f_t(w) = \max_{i\in A_t} y_i(s_i - \langle w,x_i\rangle)$ based on $A_t \subseteq \{1,\dots,m\}$. Initialize $A_0 = \{\}$.
• At each round $t = 1,\dots,m$, add to $A_t$ the index $i_t = \arg\max_{i\notin A_{t-1}} |s_i - \langle w_t,x_i\rangle|$ and set the corresponding $y_{i_t}$ s.t. $y_{i_t}(s_{i_t} - \langle w_t,x_{i_t}\rangle) \ge 0$. Return local information at $w_t$ based on $f_t$.
• Claim: $f_t$ agrees with the final $f_y$ at $w_t$ (every index added later contributes at most $|s_{i_t} - \langle w_t,x_{i_t}\rangle| \le f_t(w_t)$ there), so the adversarial responses to the algorithm are consistent with $f_y$, but $f_y(w_t) = f_t(w_t) \ge 0 \ge \inf_w f_y(w) + \epsilon$
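The adversary can be simulated directly. This is a sketch under assumptions: the class name `Adversary`, tie-breaking by lowest index, the orthonormal instance, and a plain (unprojected) subgradient method as the optimizer are all our own choices. Each response is computed from the partial model $f_t$; afterwards we check that every answer equals the final $f_y$ at the queried point and is $\ge 0$, while the witness for the final signs reaches $-RB/\sqrt m$.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


class Adversary:
    """Resisting oracle for f_y(w) = max_i y_i (s_i - <w, x_i>), with signs y chosen lazily."""

    def __init__(self, xs: Sequence[Sequence[float]], s: Sequence[float]) -> None:
        self.xs, self.s = xs, s
        self.y: List[float] = [0.0] * len(xs)  # sign y_i, fixed once i enters A_t
        self.active: List[int] = []             # the index set A_t

    def _term(self, i: int, w: Sequence[float]) -> float:
        return self.y[i] * (self.s[i] - dot(w, self.xs[i]))

    def respond(self, w: Sequence[float]) -> Tuple[float, Vector]:
        """Round t: add i_t = argmax_{i not in A_{t-1}} |s_i - <w_t, x_i>|, fix y_{i_t} so
        its term is >= 0, then answer (f_t(w_t), a subgradient of f_t at w_t)."""
        free = [i for i in range(len(self.xs)) if i not in self.active]
        i_t = max(free, key=lambda i: abs(self.s[i] - dot(w, self.xs[i])))
        self.y[i_t] = 1.0 if self.s[i_t] - dot(w, self.xs[i_t]) >= 0 else -1.0
        self.active.append(i_t)
        i_star = max(self.active, key=lambda i: self._term(i, w))
        return self._term(i_star, w), [-self.y[i_star] * xj for xj in self.xs[i_star]]

    def final_value(self, w: Sequence[float]) -> float:
        """f_y(w) once all m signs have been fixed."""
        return max(self._term(i, w) for i in range(len(self.xs)))


m, B, R, eta = 5, 1.0, 1.0, 0.1
xs = [[R if j == i else 0.0 for j in range(m)] for i in range(m)]
adv = Adversary(xs, [0.0] * m)
w = [0.0] * m
queries, answers = [], []
for _ in range(m):  # the optimizer: a plain subgradient step using only the adversary's answers
    val, g = adv.respond(w)
    queries.append(list(w))
    answers.append(val)
    w = [wi - eta * gi for wi, gi in zip(w, g)]
w_star = [B / math.sqrt(m) * yi for yi in adv.y]  # fat-shattering witness for the final y
print(answers, adv.final_value(w_star))
```

Every answer is 0 here, while $\inf_w f_y(w) \le -1/\sqrt5$, so each of the $m$ queried points is at least $1/\sqrt5$-suboptimal.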
Optimization vs Learning (deterministic)
• Optimization runtime (# function/gradient accesses) ≥ $d_\epsilon$ = Statistical Learning # samples, where $d_\epsilon$ is the $\epsilon$-fat-shattering dimension
• Converse?
  – Optimize with $d_\epsilon$ accesses? (an intractable algorithm is OK)
  – Learning ⇒ Optimization? With sample size $m$, an exact gradient calculation takes $O(m)$ time, so even if #iterations = #samples, the runtime is $O(m^2)$.
• Stochastic Approximation? (a stochastic, local-access, $O(1)$-memory method)
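The $O(m^2)$ accounting can be checked by counting per-example gradient evaluations. A minimal sketch, assuming unconstrained updates, the hinge loss, and a toy dataset (all hypothetical choices): full-batch gradient descent on $\hat L$ with #iterations = #samples uses $m^2$ evaluations, while one-pass SGD uses $m$.

```python
from typing import List, Sequence

Vector = List[float]


class GradCounter:
    """Counts per-example gradient evaluations of the hinge loss [1 - y<w,x>]_+."""

    def __init__(self) -> None:
        self.count = 0

    def example_grad(self, w: Sequence[float], x: Sequence[float], y: float) -> Vector:
        self.count += 1
        margin = y * sum(wi * xi for wi, xi in zip(w, x))
        return [-y * xi for xi in x] if margin < 1.0 else [0.0] * len(x)


def full_gradient_descent(data, iters: int, eta: float, gc: GradCounter) -> Vector:
    """Each iteration computes the exact ERM gradient: one pass over all m examples."""
    w = [0.0] * len(data[0][0])
    for _ in range(iters):
        grads = [gc.example_grad(w, x, y) for x, y in data]
        avg = [sum(col) / len(data) for col in zip(*grads)]
        w = [wi - eta * gi for wi, gi in zip(w, avg)]
    return w


def one_pass_sgd(data, eta: float, gc: GradCounter) -> Vector:
    """Each step uses a single fresh example: one pass over the data in total."""
    w = [0.0] * len(data[0][0])
    for x, y in data:
        g = gc.example_grad(w, x, y)
        w = [wi - eta * gi for wi, gi in zip(w, g)]
    return w


m = 100
data = [([1.0, (-1.0) ** t], 1.0 if t % 3 else -1.0) for t in range(m)]
gd, sgd = GradCounter(), GradCounter()
full_gradient_descent(data, iters=m, eta=0.1, gc=gd)  # #iterations = #samples
one_pass_sgd(data, eta=0.1, gc=sgd)
print(gd.count, sgd.count)
```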
Online Optimization / Learning
• Online optimization setup:
  – As before, the problem is specified by $\mathcal W$, $\mathcal F$
  – $f_1, f_2, \dots$ are presented sequentially by an "adversary"
  – The "learner" responds with $w_1, w_2, \dots$
    Adversary: $f_1\ f_2\ f_3 \dots$
    Learner: $w_1\ w_2\ w_3 \dots$
  – Formally, a learning rule $A: \mathcal F^* \to \mathcal W$ with $w_t = A(f_1,\dots,f_{t-1})$
• Goal: minimize regret versus the best single response in hindsight.
  – Rule $A$ has regret $\epsilon(m)$ if for all sequences $f_1,\dots,f_m$: $\frac1m\sum_{t=1}^m f_t(w_t) \le \inf_{w\in\mathcal W} \frac1m\sum_{t=1}^m f_t(w) + \epsilon(m)$, where $w_t = A(f_1,\dots,f_{t-1})$
• Examples:
  – Spam filtering
  – Investment return: $w[i]$ = investment in holding $i$; $f_t(w) = -\langle w,z_t\rangle$, where $z_t[i]$ = return on holding $i$
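A minimal sketch of projected online gradient descent on the linear losses $f_t(w) = -\langle w,z_t\rangle$ from the investment example. Assumptions (ours, not the slides'): the decision set is the $L_2$ ball of radius $B$ (a real portfolio would more likely use the simplex), $\|z_t\|_2 \le G$, and a fixed step $\eta = B/(G\sqrt m)$, for which the textbook analysis bounds the average regret by $BG/\sqrt m$.

```python
import math
import random
from typing import List, Sequence

Vector = List[float]


def project_l2_ball(w: Sequence[float], radius: float) -> Vector:
    norm = math.sqrt(sum(wi * wi for wi in w))
    return list(w) if norm <= radius else [wi * radius / norm for wi in w]


def ogd_average_regret(zs: Sequence[Sequence[float]], B: float, G: float) -> float:
    """Projected online gradient descent on f_t(w) = -<w, z_t> over {||w||_2 <= B},
    with fixed step eta = B / (G sqrt(m)); returns the average regret."""
    m, d = len(zs), len(zs[0])
    eta = B / (G * math.sqrt(m))
    w = [0.0] * d
    learner_loss = 0.0
    for z in zs:
        learner_loss += -sum(wi * zi for wi, zi in zip(w, z))            # suffer f_t(w_t)
        w = project_l2_ball([wi + eta * zi for wi, zi in zip(w, z)], B)  # grad f_t = -z_t
    # Best fixed w in hindsight for linear losses over the ball: w = B * sum(z) / ||sum(z)||.
    total = [sum(col) for col in zip(*zs)]
    best_loss = -B * math.sqrt(sum(c * c for c in total))
    return (learner_loss - best_loss) / m


random.seed(0)
m = 400
zs = [[random.uniform(-0.5, 0.5), random.uniform(-0.5, 0.5)] for _ in range(m)]
regret = ogd_average_regret(zs, B=1.0, G=1.0)
print(regret, 1.0 / math.sqrt(m))
```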
Online To Batch
• An online optimization algorithm with regret guarantee $\frac1m\sum_{t=1}^m f_t(w_t) \le \inf_{w\in\mathcal W} \frac1m\sum_{t=1}^m f_t(w) + \epsilon(m)$ can be converted to a learning (stochastic optimization) algorithm by running it on a sample and outputting the average of the iterates $\bar w_m = (w_1+\dots+w_m)/m$ [Cesa-Bianchi et al 04]: $\mathbb E[L(\bar w_m)] \le \inf_{w\in\mathcal W} L(w) + \epsilon(m)$ (in fact, even with high probability rather than in expectation)
• An online optimization algorithm that uses only local information at $w_i$ can also be used for deterministic optimization, by setting $z_i = z$ (i.e. $f_t = f$ at every round); by convexity, $f(\bar w_m) \le \inf_{w\in\mathcal W} f(w) + \epsilon(m)$
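A sketch of the online-to-batch conversion and its deterministic use, assuming a projected OGD learner and the hypothetical names `ProjectedOGD` and `online_to_batch`. The wrapper averages the iterates; by convexity $f(\bar w_m) \le \frac1m\sum_t f(w_t)$, so the regret bound transfers to $\bar w_m$. Feeding the same $f$ every round gives the deterministic-optimization case; with i.i.d. samples one would instead pass $f_t(w) = \ell(w; z_t)$.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Loss = Tuple[Callable[[Vector], float], Callable[[Vector], Vector]]  # (f_t, subgradient of f_t)


class ProjectedOGD:
    """Online learner: projected gradient steps on the L2 ball of radius B, fixed step eta."""

    def __init__(self, dim: int, B: float, eta: float) -> None:
        self.w: Vector = [0.0] * dim
        self.B, self.eta = B, eta

    def update(self, g: Sequence[float]) -> None:
        w = [wi - self.eta * gi for wi, gi in zip(self.w, g)]
        norm = math.sqrt(sum(wi * wi for wi in w))
        self.w = w if norm <= self.B else [wi * self.B / norm for wi in w]


def online_to_batch(learner: ProjectedOGD, losses: Sequence[Loss]) -> Tuple[Vector, float]:
    """Feed f_1..f_m to the learner; return the averaged iterate w_bar_m and the
    learner's average loss (1/m) sum_t f_t(w_t)."""
    m = len(losses)
    w_sum = [0.0] * len(learner.w)
    avg_loss = 0.0
    for f_t, grad_t in losses:
        w_t = list(learner.w)  # w_t depends only on f_1..f_{t-1}
        w_sum = [a + b for a, b in zip(w_sum, w_t)]
        avg_loss += f_t(w_t) / m
        learner.update(grad_t(w_t))
    return [a / m for a in w_sum], avg_loss


# Deterministic use: the same f at every round. Here f(w) = |w_0 - 0.3|, minimized at 0.3.
def f(w: Vector) -> float:
    return abs(w[0] - 0.3)


def f_grad(w: Vector) -> Vector:
    return [1.0 if w[0] > 0.3 else (-1.0 if w[0] < 0.3 else 0.0)]


m, B = 100, 1.0
w_bar, avg_loss = online_to_batch(ProjectedOGD(1, B, eta=B / math.sqrt(m)), [(f, f_grad)] * m)
print(w_bar, f(w_bar), avg_loss)
```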