Learning with Differentiable Perturbed Optimizers
Quentin Berthet
Optimization for ML - CIRM - 2020
Q. Berthet, M. Blondel, O. Teboul, M. Cuturi, J.-P. Vert, F. Bach • Learning with Differentiable Perturbed Optimizers
Preprint: arXiv:2002.08676
[A lot of] Machine learning these days

Supervised learning: pairs of inputs/responses (X_i, y_i) and a model g_w, with prediction y_pred = g_w(X_i).
[Figure: face-recognition example; inputs X_i are face images, responses y_i are identity labels.]

Goal: optimize the parameters w ∈ R^d of a function g_w such that g_w(X_i) ≈ y_i:
    min_w Σ_i L(g_w(X_i), y_i).

Workhorse: first-order methods, based on ∇_w L(g_w(X_i), y_i) and backpropagation.

Problem: what if these models contain nondifferentiable* operations?
Discrete decisions in Machine learning

[Figure: pipeline X → g_w → θ → y*(θ) → L, compared against the discrete label y.]

Examples: discrete operations (e.g. max, rankings) break autodifferentiation
• θ = scores for k products, y* = vector of ranks, e.g. [5, 2, 4, 3, 1]
• θ = edge costs, y* = shortest path between two points
• θ = classification scores for each class, y* = one-hot vector
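To see why such operations break backpropagation, here is a minimal sketch (not from the talk; names are illustrative): the one-hot argmax is piecewise constant in the scores θ, so finite differences, and hence the Jacobian used by autodifferentiation, are zero almost everywhere.

```python
import numpy as np

def one_hot_argmax(theta):
    """Discrete decision: one-hot vector of the argmax of the scores."""
    y = np.zeros_like(theta)
    y[np.argmax(theta)] = 1.0
    return y

theta = np.array([1.2, 0.3, 2.5])
delta = 1e-4
# The map is piecewise constant: every finite-difference column of the
# Jacobian is zero (or undefined at ties), so it carries no learning signal.
for i in range(len(theta)):
    e_i = np.zeros_like(theta)
    e_i[i] = delta
    print((one_hot_argmax(theta + e_i) - one_hot_argmax(theta - e_i)) / (2 * delta))
```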
Perturbed maximizer

Discrete decisions: optimizers of a linear program over C, the convex hull of Y ⊆ R^d:
    y*(θ) = argmax_{y ∈ C} ⟨y, θ⟩,   F(θ) = max_{y ∈ C} ⟨y, θ⟩,   and   y*(θ) = ∇_θ F(θ).

[Figure: polytope C, input θ, and the vertex y*(θ) it selects.]

Perturbed maximizer: average of solutions for inputs with noise εZ
    y*_ε(θ) = E[y*(θ + εZ)] = E[argmax_{y ∈ C} ⟨y, θ + εZ⟩],   F_ε(θ) = E[max_{y ∈ C} ⟨y, θ + εZ⟩],   and   y*_ε(θ) = ∇_θ F_ε(θ).
Perturbed maximizer (illustration)

[Figure: perturbed inputs θ + εZ select different vertices y*(θ + εZ) of C; averaging them gives y*_ε(θ), which moves off the vertex y*(θ) into the interior of C.]

    y*_ε(θ) = E[y*(θ + εZ)] = E[argmax_{y ∈ C} ⟨y, θ + εZ⟩],   F_ε(θ) = E[max_{y ∈ C} ⟨y, θ + εZ⟩],   y*_ε(θ) = ∇_θ F_ε(θ).
Perturbed model

Model of optimal decision under uncertainty (Luce, 1959; McFadden et al., 1973):
    Y = argmax_{y ∈ C} ⟨y, θ + εZ⟩.
Y follows a perturbed model Y ∼ p_θ(y), with expectation y*_ε(θ) = E_{p_θ}[Y].

Perturb-and-MAP: Papandreou & Yuille (2011). Follow the Perturbed Leader: Kalai & Vempala (2003).

[Figure: features → edge costs → shortest path, with perturbed paths for ε = 0.5 and ε = 2.0.]

Example. Over the unit simplex C = Δ^d with Gumbel noise Z, F(θ) = max_i θ_i and
    F_ε(θ) = ε log Σ_{i ∈ [d]} e^{θ_i/ε},   [y*_ε(θ)]_i = e^{θ_i/ε} / Σ_j e^{θ_j/ε},   p_θ(e_i) ∝ exp(⟨θ, e_i⟩/ε).
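As a sanity check of the simplex example, the perturbed maximizer can be approximated by Monte Carlo and compared to the softmax with temperature ε. A minimal sketch (assuming NumPy; the sample size M is arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
theta, eps, M = np.array([1.0, 0.5, -0.2]), 0.5, 200_000

def one_hot_argmax(t):
    y = np.zeros_like(t)
    y[np.argmax(t)] = 1.0
    return y

# Monte Carlo estimate of y*_eps(theta) = E[argmax over the simplex of <y, theta + eps Z>]
Z = rng.gumbel(size=(M, len(theta)))
estimate = np.mean([one_hot_argmax(theta + eps * z) for z in Z], axis=0)

softmax = np.exp(theta / eps) / np.exp(theta / eps).sum()
print(estimate)  # close to softmax(theta / eps)
print(softmax)
```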
Properties

Link with regularization: εΩ = F_ε* is a convex function with domain C, and
    y*_ε(θ) = argmax_{y ∈ C} { ⟨y, θ⟩ − εΩ(y) }.
Consequence of duality and y*_ε(θ) = ∇_θ F_ε(θ). Generalization of entropy.

[Figure: perturbed maximizers for ε = 0, tiny ε, small ε, large ε.]

Extreme temperatures. When ε → 0, y*_ε(θ) → y*(θ) for a unique maximizer. When ε → ∞, y*_ε(θ) → argmin_y Ω(y). Nonasymptotic results.
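For the simplex example of the previous slide, the induced regularizer can be made explicit; a short worked derivation (up to additive constants, using the standard conjugacy between log-sum-exp and negative entropy):

```latex
F_\varepsilon(\theta) = \varepsilon \log \sum_{i \in [d]} e^{\theta_i/\varepsilon}
\quad\Longrightarrow\quad
\varepsilon\,\Omega(y) = F_\varepsilon^*(y) = \varepsilon \sum_{i \in [d]} y_i \log y_i
\quad \text{for } y \in \Delta^d .
```

Hence Ω is the negative Shannon entropy, y*_ε(θ) is the entropy-regularized argmax (the softmax), and argmin_y Ω(y) is the uniform distribution, consistent with the large-ε limit above.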
Properties

Mirror maps: for C with nonempty interior and Z with a smooth density μ of full support,
F_ε is strictly convex with Lipschitz gradient; Ω is strongly convex, of Legendre type.

[Figure: mirror-map pair ∇_θ F_ε : R^d → C and ∇_y Ω : C → R^d, mapping θ to y*_ε(θ) and back.]

Differentiability. The functions are smooth in the inputs. For μ(z) ∝ exp(−ν(z)):
    y*_ε(θ) = ∇_θ F_ε(θ) = E[y*(θ + εZ)] = E[F(θ + εZ) ∇_z ν(Z)/ε],
    J_θ y*_ε(θ) = ∇²_θ F_ε(θ) = E[y*(θ + εZ) ∇_z ν(Z)^⊤/ε].

The perturbed maximizer y*_ε is never locally constant in θ. Abernethy et al. (2014)
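A hedged numerical check of the Jacobian formula (not from the talk): with Gaussian noise, ν(z) = ‖z‖²/2, so ∇_z ν(z) = z and the Jacobian can be estimated by averaging outer products y*(θ + εZ) Z^⊤/ε.

```python
import numpy as np

rng = np.random.default_rng(0)
theta, eps, M = np.array([1.0, 0.5, -0.2]), 0.5, 200_000
d = len(theta)

def one_hot_argmax(t):
    y = np.zeros_like(t)
    y[np.argmax(t)] = 1.0
    return y

# Gaussian noise: mu(z) ∝ exp(-||z||^2 / 2), so grad_z nu(z) = z and
# J_theta y*_eps(theta) ≈ (1/M) * sum_l y*(theta + eps Z_l) Z_l^T / eps.
Z = rng.standard_normal((M, d))
J = sum(np.outer(one_hot_argmax(theta + eps * z), z) for z in Z) / (M * eps)
print(J)  # approximately symmetric (it is the Hessian of F_eps), rows sum to ~0
```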
Learning with perturbed optimizers

Machine learning pipeline: variable X, discrete label y, model output θ = g_w(X).

[Figures: pipeline X → g_w → θ → y*(θ) → L, and the modified pipeline with the perturbed maximizer y*_ε(θ) in place of y*(θ).]

Labels are solutions of optimization problems (one-hots, ranks, shortest paths).
Small modification of the model: end-to-end differentiable.
Why? and How?

Learning problems: features X_i, model output θ_w = g_w(X_i), prediction y_pred = y*_ε(θ_w), loss L:
    F(w) = L(y*_ε(θ_w), y_i),   whose gradients require J_θ y*_ε(θ_w).

Monte Carlo estimates. The perturbed maximizer and its derivatives are expectations. For θ ∈ R^d and Z^(1), ..., Z^(M) i.i.d. copies of Z, let y^(ℓ) = y*(θ + εZ^(ℓ)). An unbiased estimate of y*_ε(θ) is
    ȳ_{ε,M}(θ) = (1/M) Σ_{ℓ=1}^M y^(ℓ).

[Figure: perturbed inputs θ + εZ and the averaged solution y*_ε(θ) inside C.]
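A minimal sketch of this estimator (assuming NumPy and Gaussian perturbations; `solver` stands for any argmax oracle over C, e.g. argsort for ranks or a shortest-path routine):

```python
import numpy as np

def perturbed_maximizer_mc(theta, solver, eps=1.0, M=1000, rng=None):
    """Unbiased Monte Carlo estimate of y*_eps(theta) = E[y*(theta + eps Z)].

    `solver` is any argmax oracle y*(theta) over the polytope C.
    """
    rng = rng or np.random.default_rng()
    Z = rng.standard_normal((M, len(theta)))
    return np.mean([solver(theta + eps * z) for z in Z], axis=0)

# Example with the one-hot argmax oracle over the simplex:
one_hot = lambda t: np.eye(len(t))[np.argmax(t)]
print(perturbed_maximizer_mc(np.array([1.0, 0.5, -0.2]), one_hot, eps=0.5, M=10_000))
```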
Fenchel-Young losses

A natural loss, defined directly on θ and motivated by duality. Blondel et al. (2019)
    L_ε(θ; y) = F_ε(θ) + εΩ(y) − ⟨θ, y⟩.

Interesting properties in a learning framework:
• Convex in θ, minimized at any θ such that y*_ε(θ) = y, with value 0.
• Equal to the Bregman divergence D_{εΩ}(y*_ε(θ) ‖ y).
• For random Y, E[L_ε(θ; Y)] = L_ε(θ; E[Y]) + C; e.g. for Y = argmax_{y ∈ C} ⟨θ_0 + εZ, y⟩,
  E[L_ε(θ; Y)] = L_ε(θ; y*_ε(θ_0)) + C, so the population loss is minimized at θ_0.
• Convenient gradients: ∇_θ L_ε(θ; y) = y*_ε(θ) − y.
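Since ∇_θ L_ε(θ; y) = y*_ε(θ) − y, a stochastic gradient of the Fenchel-Young loss only needs the Monte Carlo average above, and no Jacobian. A hedged sketch (function and argument names are illustrative):

```python
import numpy as np

def fy_loss_grad(theta, y_true, solver, eps=1.0, M=1000, rng=None):
    """Stochastic gradient of the Fenchel-Young loss L_eps(theta; y):
    grad_theta L_eps(theta; y) = y*_eps(theta) - y, estimated by Monte Carlo."""
    rng = rng or np.random.default_rng()
    Z = rng.standard_normal((M, len(theta)))
    y_eps = np.mean([solver(theta + eps * z) for z in Z], axis=0)
    return y_eps - y_true
```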
Learning with perturbations and F-Y losses

Within the same framework, it is possible to virtually bypass the optimization block.

[Figures: pipeline X → g_w → θ → y*_ε(θ) → L compared against the label y, and the equivalent pipeline where the F-Y loss L_ε is applied directly to θ = g_w(X), bypassing y*_ε.]

Easier to implement: no Jacobian of y*_ε is needed.
The population loss is minimized at the ground truth for the perturbed generative model.
Unsupervised learning - parameter estimation

Observation: Y_1, ..., Y_n i.i.d. copies of
    Y_i = argmax_{y ∈ C} ⟨θ_0 + εZ_i, y⟩.

[Figure: perturbed inputs θ + εZ and the averaged solution y*_ε(θ) inside C.]

Goal: estimating the unknown θ_0.

Minimization of the empirical loss (related to inference in Gibbs models):
    L̄_{ε,n}(θ) = (1/n) Σ_{i=1}^n L_ε(θ; Y_i),   stochastic gradient ∇_θ L_ε(θ; Y_i) = y*_ε(θ) − Y_i.

Equal up to an additive constant to L_ε(θ; Ȳ_n), and in expectation to L_ε(θ; y*_ε(θ_0)).
Asymptotic normality of the minimizer θ̂_n around θ_0.
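A toy illustration of this estimation procedure (an assumed setup, not from the talk): observations are perturbed argmaxes over the simplex, and θ_0 is recovered by SGD on the empirical Fenchel-Young loss, using a single fresh perturbation per step to estimate y*_ε(θ).

```python
import numpy as np

rng = np.random.default_rng(0)
one_hot = lambda t: np.eye(len(t))[np.argmax(t)]
eps, d, n = 0.5, 4, 5000
theta0 = np.array([0.8, 0.1, -0.3, 0.5])

# Observations Y_i = argmax_{y in simplex} <theta0 + eps * Z_i, y>.
Y = np.array([one_hot(theta0 + eps * rng.standard_normal(d)) for _ in range(n)])

# SGD on the empirical F-Y loss; the gradient is y*_eps(theta) - Y_i, with
# y*_eps(theta) replaced by a one-sample perturbed argmax (doubly stochastic).
theta, theta_avg, t = np.zeros(d), np.zeros(d), 0
for epoch in range(10):
    for i in rng.permutation(n):
        t += 1
        y_eps = one_hot(theta + eps * rng.standard_normal(d))
        theta -= 0.1 * (y_eps - Y[i])
        theta_avg += (theta - theta_avg) / t  # Polyak averaging for a stabler estimate

print(theta_avg)  # close to theta0 up to a shift along the all-ones direction
```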
Supervised learning

Motivated by the model where y_i = argmax_{y ∈ C} ⟨g_{w_0}(X_i) + εZ_i, y⟩.

[Figure: pipeline X → g_w → θ → L_ε, compared against the label y.]

Stochastic gradients for the empirical loss only require
    ∇_θ L_ε(θ; y_i) at θ = g_w(X_i),   i.e.   y*_ε(g_w(X_i)) − y_i,
simulated by a doubly stochastic scheme.
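A hedged sketch of the doubly stochastic scheme for a linear model g_W(x) = W x over the simplex (illustrative setup, not from the talk): at each step, sample a data point and a fresh perturbation, then apply the chain rule to the gradient y*_ε(g_W(X_i)) − y_i.

```python
import numpy as np

rng = np.random.default_rng(1)
one_hot = lambda t: np.eye(len(t))[np.argmax(t)]
eps, d_feat, d_out, n = 0.5, 3, 4, 2000

# Toy perturbed generative model: y_i = argmax_y <W0 @ X_i + eps * Z_i, y>.
W0 = rng.standard_normal((d_out, d_feat))
X = rng.standard_normal((n, d_feat))
Y = np.array([one_hot(W0 @ x + eps * rng.standard_normal(d_out)) for x in X])

# Doubly stochastic gradient: one data point and one perturbation per step;
# for g_W(x) = W x the gradient in W is (y*_eps(W x) - y_i) x^T.
W, lr = np.zeros((d_out, d_feat)), 0.1
for epoch in range(20):
    for i in rng.permutation(n):
        theta = W @ X[i]
        y_eps = one_hot(theta + eps * rng.standard_normal(d_out))  # one-sample estimate
        W -= lr * np.outer(y_eps - Y[i], X[i])

print(np.mean([np.argmax(W @ x) == np.argmax(y) for x, y in zip(X, Y)]))  # train agreement
```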