
Learning with Differentiable Perturbed Optimizers Quentin Berthet - PowerPoint PPT Presentation



  1. Learning with Differentiable Perturbed Optimizers, Quentin Berthet. Youth in High-dimensions, ICTP, 2020

  2. Q. Berthet, M. Blondel, O. Teboul, M. Cuturi, J.-P. Vert, F. Bach. Learning with Differentiable Perturbed Optimizers. Preprint: arXiv:2002.08676

  3. [A lot of] Machine learning these days
     Supervised learning: pairs of inputs/responses $(X_i, y_i)$ and a model $g_w$.
     [Figure: an image $X_i$ with label $y_i$ (classes such as 'deer', 'ship', 'bird', 'horse', 'truck') is mapped by $g_w$ to $\theta = g_w(X_i)$, which is compared with $y_i$ through a loss $L$.]
     Goal: optimize the parameters $w \in \mathbb{R}^d$ of a function $g_w$ such that $g_w(X_i) \approx y_i$:
     $$\min_w \sum_i L(g_w(X_i), y_i).$$
     Workhorse: first-order methods based on $\nabla_w L(g_w(X_i), y_i)$, computed by backpropagation.
     Problem: what if these models contain nondifferentiable* operations?
     Q.Berthet - ICTP - 2020 1/12

  4. Discrete decisions in Machine learning
     [Figure: pipeline $X \to g_w \to \theta \to y^*(\theta) \to L \leftarrow y$.]
     Examples: discrete operations (e.g. max, rankings) break automatic differentiation.
     • $\theta$ = scores for $k$ products, $y^*$ = vector of ranks, e.g. [5, 2, 4, 3, 1]
     • $\theta$ = edge costs, $y^*$ = shortest path between two points
     • $\theta$ = classification scores for each class, $y^*$ = one-hot vector
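The following minimal Python sketch (standard library only) makes two of the examples above concrete and shows why such operations break automatic differentiation. The helper names are hypothetical. `ranks` assumes that rank 1 goes to the largest score, which reproduces the slide's [5, 2, 4, 3, 1]. Both outputs are piecewise constant in θ, so a small finite-difference step leaves them unchanged. The derivative is therefore zero almost everywhere and undefined at ties.

```python
def one_hot_argmax(theta):
    """y*(theta) over the unit simplex: one-hot vector of the largest score."""
    i = max(range(len(theta)), key=lambda j: theta[j])
    return [1.0 if j == i else 0.0 for j in range(len(theta))]


def ranks(theta):
    """Vector of ranks, assuming rank 1 is given to the largest score."""
    order = sorted(range(len(theta)), key=lambda j: -theta[j])
    r = [0] * len(theta)
    for position, j in enumerate(order, start=1):
        r[j] = position
    return r


def finite_difference(f, theta, k, h=1e-6):
    """Forward difference of a vector-valued f along coordinate k."""
    shifted = list(theta)
    shifted[k] += h
    return [(a - b) / h for a, b in zip(f(shifted), f(theta))]


theta = [0.1, 2.0, 0.5, 1.2, 3.0]
print(one_hot_argmax(theta))                        # [0.0, 0.0, 0.0, 0.0, 1.0]
print(ranks(theta))                                 # [5, 2, 4, 3, 1]
print(finite_difference(one_hot_argmax, theta, 4))  # [0.0, 0.0, 0.0, 0.0, 0.0]
```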


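  6. Perturbed maximizer
     Discrete decisions: optimizers of a linear program over $\mathcal{C}$, the convex hull of $\mathcal{Y} \subseteq \mathbb{R}^d$:
     $$y^*(\theta) = \operatorname*{argmax}_{y \in \mathcal{C}} \langle y, \theta \rangle, \qquad F(\theta) = \max_{y \in \mathcal{C}} \langle y, \theta \rangle, \qquad y^*(\theta) = \nabla_\theta F(\theta) \text{ (when the maximizer is unique)}.$$
     [Figure: the polytope $\mathcal{C}$, a direction $\theta$, and the maximizing vertex $y^*(\theta)$.]
     Perturbed maximizer: average of solutions for inputs with noise $\varepsilon Z$:
     $$y^*_\varepsilon(\theta) = \mathbb{E}[y^*(\theta + \varepsilon Z)] = \mathbb{E}\Big[\operatorname*{argmax}_{y \in \mathcal{C}} \langle y, \theta + \varepsilon Z \rangle\Big], \qquad F_\varepsilon(\theta) = \mathbb{E}\Big[\max_{y \in \mathcal{C}} \langle y, \theta + \varepsilon Z \rangle\Big], \qquad y^*_\varepsilon(\theta) = \nabla_\theta F_\varepsilon(\theta).$$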

  7. Perturbed maximizer (continued)
     [Figure: the perturbed input $\theta + \varepsilon Z$ selects the vertex $y^*(\theta + \varepsilon Z)$ of $\mathcal{C}$; averaging over $Z$ gives the perturbed maximizer $y^*_\varepsilon(\theta)$, shown together with $\theta$ and $y^*(\theta)$.]
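The definition of the perturbed maximizer translates directly into code. This is a minimal sketch under two assumptions: $\mathcal{C}$ is the unit simplex, so $y^*$ is a one-hot argmax, and $Z$ is standard Gaussian. `perturbed_argmax` and its parameters are hypothetical names. The expectation varies smoothly with θ even though $y^*$ does not, and ε controls how much mass moves away from the hard argmax.

```python
import random


def one_hot_argmax(theta):
    """y*(theta) over the unit simplex: one-hot vector of the largest score."""
    i = max(range(len(theta)), key=lambda j: theta[j])
    return [1.0 if j == i else 0.0 for j in range(len(theta))]


def perturbed_argmax(theta, eps=1.0, num_samples=10_000, seed=0):
    """Monte Carlo estimate of y*_eps(theta) = E[y*(theta + eps Z)], Z ~ N(0, I)."""
    rng = random.Random(seed)
    total = [0.0] * len(theta)
    for _ in range(num_samples):
        noisy = [t + eps * rng.gauss(0.0, 1.0) for t in theta]
        total = [s + v for s, v in zip(total, one_hot_argmax(noisy))]
    return [s / num_samples for s in total]


theta = [1.0, 0.5, -0.5]
print(perturbed_argmax(theta, eps=0.01))  # essentially the hard argmax [1, 0, 0]
print(perturbed_argmax(theta, eps=1.0))   # mass spread over near-maximal entries
```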

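  8. Perturbed model
     Model of optimal decision under uncertainty: Luce (1959), McFadden et al. (1973).
     $$Y = \operatorname*{argmax}_{y \in \mathcal{C}} \langle y, \theta + \varepsilon Z \rangle$$
     follows a perturbed model $Y \sim p_\theta(y)$ with expectation $y^*_\varepsilon(\theta) = \mathbb{E}_{p_\theta}[Y]$.
     Related: Perturb-and-MAP, Papandreou & Yuille (2011); Follow the Perturbed Leader, Kalai & Vempala (2003).
     [Figure: features, costs, shortest path, and perturbed paths for $\varepsilon = 0.5$ and $\varepsilon = 2.0$.]
     Example. Over the unit simplex $\mathcal{C} = \Delta^d$ with Gumbel noise $Z$, this gives the Gibbs distribution:
     $$p_\theta(e_i) \propto \exp(\langle \theta, e_i \rangle / \varepsilon), \qquad F_\varepsilon(\theta) = \varepsilon \log \sum_{i \in [d]} e^{\theta_i/\varepsilon}, \qquad [y^*_\varepsilon(\theta)]_i = \frac{e^{\theta_i/\varepsilon}}{\sum_j e^{\theta_j/\varepsilon}}.$$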
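The Gumbel example can be checked numerically. With Gumbel noise, the frequencies of the perturbed argmax should match the softmax formula for $[y^*_\varepsilon(\theta)]_i$. This sketch assumes standard Gumbel noise sampled as $-\log(-\log U)$ with $U$ uniform on (0, 1). The function names are hypothetical.

```python
import math
import random


def softmax(theta, eps=1.0):
    """Closed form: [y*_eps(theta)]_i = exp(theta_i/eps) / sum_j exp(theta_j/eps)."""
    m = max(theta)  # shift by the max for numerical stability
    w = [math.exp((t - m) / eps) for t in theta]
    s = sum(w)
    return [v / s for v in w]


def sample_gumbel(rng):
    """Standard Gumbel sample -log(-log U), with U uniform on (0, 1)."""
    u = rng.random()
    while u == 0.0:  # random() may return 0.0, and log(0) is undefined
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_perturbed_argmax(theta, eps=1.0, num_samples=100_000, seed=0):
    """Empirical law of argmax_i (theta_i + eps * G_i), G_i i.i.d. standard Gumbel."""
    rng = random.Random(seed)
    d = len(theta)
    counts = [0] * d
    for _ in range(num_samples):
        noisy = [t + eps * sample_gumbel(rng) for t in theta]
        counts[max(range(d), key=lambda j: noisy[j])] += 1
    return [c / num_samples for c in counts]


theta = [1.0, 0.5, -1.0]
print(softmax(theta, eps=0.5))
print(gumbel_perturbed_argmax(theta, eps=0.5))  # agrees up to Monte Carlo error
```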

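  9. Properties
     Link with regularization: $\varepsilon\Omega = F_\varepsilon^*$ is a convex function with domain $\mathcal{C}$, and
     $$y^*_\varepsilon(\theta) = \operatorname*{argmax}_{y \in \mathcal{C}} \big\{\langle y, \theta \rangle - \varepsilon\Omega(y)\big\}.$$
     This is a consequence of duality and of $y^*_\varepsilon(\theta) = \nabla_\theta F_\varepsilon(\theta)$; $\Omega$ acts as a generalized entropy.
     [Figure: $y^*_\varepsilon(\theta)$ for $\varepsilon = 0$ and for tiny, small, and large $\varepsilon$.]
     Extreme temperatures. When $\varepsilon \to 0$, $y^*_\varepsilon(\theta) \to y^*(\theta)$ if the maximum is unique. When $\varepsilon \to \infty$, $y^*_\varepsilon(\theta) \to \operatorname*{argmin}_y \Omega(y)$.
     Nonasymptotic results.
     Differentiability. Smoothness in the inputs, with the Jacobian given by simple expectations.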
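"Jacobian as simple expectations" can be made concrete for Gaussian noise. This sketch assumes $Z \sim \mathcal{N}(0, I)$. Gaussian integration by parts (Stein's identity) then gives $J_\theta y^*_\varepsilon(\theta) = \mathbb{E}[y^*(\theta + \varepsilon Z) Z^\top]/\varepsilon$, which needs only calls to the unperturbed maximizer. The code is a Monte Carlo sketch over the simplex with hypothetical function names. It uses $d = 2$ to compare against a closed form.

```python
import math
import random


def one_hot_argmax(theta):
    """y*(theta) over the unit simplex: one-hot vector of the largest score."""
    i = max(range(len(theta)), key=lambda j: theta[j])
    return [1.0 if j == i else 0.0 for j in range(len(theta))]


def perturbed_jacobian(theta, eps=1.0, num_samples=100_000, seed=0):
    """Monte Carlo estimate of J[i][k] = d[y*_eps(theta)]_i / d theta_k,
    using J = E[y*(theta + eps Z) Z^T] / eps for Z ~ N(0, I)."""
    rng = random.Random(seed)
    d = len(theta)
    jac = [[0.0] * d for _ in range(d)]
    for _ in range(num_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(d)]
        y = one_hot_argmax([t + eps * zk for t, zk in zip(theta, z)])
        i = y.index(1.0)  # y is one-hot, so only row i receives z^T
        for k in range(d):
            jac[i][k] += z[k]
    return [[v / (num_samples * eps) for v in row] for row in jac]


# For d = 2, [y*_eps(theta)]_1 = Phi((theta_1 - theta_2) / (eps * sqrt(2))),
# so the exact entry J[0][0] is phi(x) / (eps * sqrt(2)) with x as below.
theta, eps = [0.3, 0.0], 1.0
x = (theta[0] - theta[1]) / (eps * math.sqrt(2))
exact = math.exp(-x * x / 2) / math.sqrt(2 * math.pi) / (eps * math.sqrt(2))
J_hat = perturbed_jacobian(theta, eps)
print(J_hat[0][0], exact)  # close up to Monte Carlo error
```

The estimator never differentiates through `one_hot_argmax`. This is what makes perturbed maximizers usable with black-box solvers.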

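  10. Learning and Fenchel-Young losses
     Learning from $Y_1, \dots, Y_n$ for a model $p_\theta$. For the Gibbs distribution $\propto \exp(\langle \theta, Y \rangle)$, minimize the negative log-likelihood
     $$L_{\text{Gibbs}}(\theta; Y) = -\frac{1}{n} \sum_{i=1}^{n} \langle \theta, Y_i \rangle + \log Z(\theta),$$
     where $Z(\theta)$ is the partition function.
     Stochastic gradient and full (batch) gradient, i.e. moment matching:
     $$\nabla_\theta L_{\text{Gibbs}}(\theta; Y_i) = \mathbb{E}_{\text{Gibbs},\theta}[Y] - Y_i, \qquad \nabla_\theta L_{\text{Gibbs}}(\theta; Y) = \mathbb{E}_{\text{Gibbs},\theta}[Y] - \bar{Y}_n.$$
     Algorithmic challenge: computing $\mathbb{E}_{\text{Gibbs},\theta}[Y]$. Replace it by the perturbed model (Papandreou & Yuille, 2011):
     $$\nabla_\theta L_i(\theta) = \mathbb{E}_{p_\theta}[Y] - Y_i = y^*_\varepsilon(\theta) - Y_i.$$
     This is the stochastic gradient of a modified functional in $\theta$, not of a log-likelihood:
     $$L_\varepsilon(\theta; Y) = -\frac{1}{n} \sum_{i=1}^{n} \langle \theta, Y_i \rangle + F_\varepsilon(\theta).$$
     This is a Fenchel-Young loss (Blondel et al., 2019), with good properties (convexity, randomness).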
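For the Gumbel/simplex example of slide 8, $F_\varepsilon$ has the closed form $\varepsilon \log \sum_i e^{\theta_i/\varepsilon}$. The per-example loss $F_\varepsilon(\theta) - \langle \theta, Y_i \rangle$ and its gradient $y^*_\varepsilon(\theta) - Y_i$ can therefore be checked directly against finite differences. This sketch uses hypothetical names. With $\varepsilon = 1$ and a one-hot $Y_i$, the loss coincides with the usual cross-entropy (negative log-softmax).

```python
import math


def f_eps(theta, eps=1.0):
    """F_eps(theta) = eps * log sum_i exp(theta_i / eps), computed stably."""
    m = max(theta)
    return m + eps * math.log(sum(math.exp((t - m) / eps) for t in theta))


def softmax(theta, eps=1.0):
    """y*_eps(theta) = grad F_eps(theta) in the Gibbs/simplex example."""
    m = max(theta)
    w = [math.exp((t - m) / eps) for t in theta]
    s = sum(w)
    return [v / s for v in w]


def fy_loss(theta, y, eps=1.0):
    """Per-example loss F_eps(theta) - <theta, y>, i.e. L_eps with n = 1."""
    return f_eps(theta, eps) - sum(t * v for t, v in zip(theta, y))


def fy_grad(theta, y, eps=1.0):
    """Gradient y*_eps(theta) - y: no Jacobian of y*_eps is required."""
    return [p - v for p, v in zip(softmax(theta, eps), y)]


def central_difference(f, theta, h=1e-6):
    """Numerical gradient of a scalar function f at theta."""
    grad = []
    for k in range(len(theta)):
        up, down = list(theta), list(theta)
        up[k] += h
        down[k] -= h
        grad.append((f(up) - f(down)) / (2 * h))
    return grad


theta, y, eps = [2.0, -1.0, 0.5], [0.0, 0.0, 1.0], 0.5
print(fy_grad(theta, y, eps))
print(central_difference(lambda t: fy_loss(t, y, eps), theta))  # should agree closely
```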

  11. Learning with perturbations and F-Y losses
     Within the same framework, it is possible to virtually bypass the optimization block.
     [Figure: pipeline $X \to g_w \to \theta \to y^*_\varepsilon(\theta) \to L_\varepsilon \leftarrow y$.]
     Easier to implement: no Jacobian of $y^*_\varepsilon$ is needed.
     The population loss is minimized at the ground truth for a perturbed generative model.

  12. Learning with perturbations and F-Y losses (continued)
     [Figure: the same pipeline with the optimization block bypassed: $X \to g_w \to \theta \to L_\varepsilon \leftarrow y$.]

  13. Computations
     Monte Carlo estimates. The perturbed maximizer and its derivatives are expectations. For $\theta \in \mathbb{R}^d$ and i.i.d. copies $Z^{(1)}, \dots, Z^{(M)}$ of $Z$, let $y^{(\ell)} = y^*(\theta + \varepsilon Z^{(\ell)})$. An unbiased estimate of $y^*_\varepsilon(\theta)$ is given by
     $$\bar{y}_{\varepsilon,M}(\theta) = \frac{1}{M} \sum_{\ell=1}^{M} y^{(\ell)}.$$
     [Figure: the polytope $\mathcal{C}$ with $\theta$, perturbed inputs $\theta + \varepsilon Z$, the vertices $y^*(\theta)$ and $y^*(\theta + \varepsilon Z)$, and $y^*_\varepsilon(\theta)$.]
     Supervised learning: features $X_i$, model output $\theta_w = g_w(X_i)$, prediction $y_{\text{pred}} = y^*_\varepsilon(\theta_w)$.
     Stochastic gradient in $w$:
     $$\nabla_w F_i(w) = J_w g_w(X_i)^\top \big(y^*_\varepsilon(\theta_w) - Y_i\big).$$

  14. Computations (continued)
     Plugging the Monte Carlo estimate into the stochastic gradient in $w$ gives a doubly stochastic scheme:
     $$\nabla_w F_i(w) \approx J_w g_w(X_i)^\top \Big(\frac{1}{M} \sum_{\ell=1}^{M} y^*(\theta_w + \varepsilon Z^{(\ell)}) - Y_i\Big).$$
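The doubly stochastic scheme fits in a few lines when $g_w$ is linear, $\theta_w = W x$. In that case $J_w g_w(X_i)^\top v$ is the outer product $v x^\top$. This toy sketch assumes that $\mathcal{C}$ is the simplex (one-hot argmax) and that the noise is Gaussian. The synthetic data, learning rate, and all function names are hypothetical choices for illustration, not the setup of the CIFAR-10 experiment. Each step draws a fresh example and $M$ fresh noise samples, and it only ever calls the hard maximizer.

```python
import random


def one_hot_argmax(theta):
    """y*(theta) over the simplex: one-hot vector of the largest score."""
    i = max(range(len(theta)), key=lambda j: theta[j])
    return [1.0 if j == i else 0.0 for j in range(len(theta))]


def mc_perturbed_argmax(theta, eps, num_samples, rng):
    """(1/M) sum_l y*(theta + eps Z^(l)) with Z^(l) ~ N(0, I)."""
    avg = [0.0] * len(theta)
    for _ in range(num_samples):
        y = one_hot_argmax([t + eps * rng.gauss(0.0, 1.0) for t in theta])
        avg = [a + v / num_samples for a, v in zip(avg, y)]
    return avg


def scores(W, x):
    """Linear model theta_w = g_w(x) = W x."""
    return [sum(wj * xj for wj, xj in zip(row, x)) for row in W]


def train(data, num_classes, dim, eps=0.5, num_samples=10, lr=0.1, epochs=20, seed=0):
    """Doubly stochastic SGD on the perturbed Fenchel-Young loss."""
    rng = random.Random(seed)
    W = [[0.0] * dim for _ in range(num_classes)]
    for _ in range(epochs):
        order = list(data)
        rng.shuffle(order)
        for x, label in order:
            y_eps = mc_perturbed_argmax(scores(W, x), eps, num_samples, rng)
            # For theta = W x, J^T v is the outer product v x^T, so row c of W
            # moves along (y_eps[c] - Y_i[c]) * x.
            for c in range(num_classes):
                g = y_eps[c] - (1.0 if c == label else 0.0)
                W[c] = [wj - lr * g * xj for wj, xj in zip(W[c], x)]
    return W


def make_data(n_per_class, seed=1):
    """Hypothetical toy data: three 2-D Gaussian blobs plus a constant bias feature."""
    rng = random.Random(seed)
    centers = [(2.0, 0.0), (-1.0, 1.7), (-1.0, -1.7)]
    data = []
    for label, (cx, cy) in enumerate(centers):
        for _ in range(n_per_class):
            data.append(([cx + rng.gauss(0.0, 0.5), cy + rng.gauss(0.0, 0.5), 1.0], label))
    return data


def accuracy(W, data):
    """Fraction of points whose hard prediction y*(theta_w) matches the label."""
    hits = sum(one_hot_argmax(scores(W, x))[label] == 1.0 for x, label in data)
    return hits / len(data)


data = make_data(50)
W = train(data, num_classes=3, dim=3)
print(accuracy(W, data))  # expected to be close to 1.0 on these well-separated blobs
```

Setting `num_samples=1` gives the cheapest variant, with one noise draw per example. Larger values reduce the variance of each step at a proportional cost in solver calls.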

  15. Experiments
     Classification: CIFAR-10, a dataset of images with 10 classes (toy comparison).
     [Figure: an image $X_i$ with label $y_i$ (e.g. 'deer', 'ship', 'bird', 'horse', 'truck') is mapped by $g_w$ to $\theta = g_w(X_i)$ and compared with $y_i$ through the loss $L_\varepsilon$.]
     Architecture: vanilla CNN made of 4 convolutional and 2 fully connected layers.
     Training: 600 epochs with minibatches of size 32; influence of $M$ and $\varepsilon$.
     [Figure: train accuracy and loss against epochs (0 to 600) for the perturbed Fenchel-Young loss with $M = 1$ and $M = 1000$, compared with a cross-entropy baseline; a third panel shows train and test curves for $M = 1$ and $M = 1000$ as a function of $\varepsilon$ (log scale, $10^{-4}$ to $10^{4}$).]

  16. Experiments
     Learning from shortest paths: from 10k examples of Warcraft 96 × 96 RGB images representing 12 × 12 costs, together with the matrix of shortest paths (Vlastelica et al., 2019).
     [Figure: features, costs, shortest path, and perturbed paths for $\varepsilon = 0.5$ and $\varepsilon = 2.0$.]
     Train a CNN for 50 epochs to learn costs for the recovery of optimal paths.
     [Figure: two panels against epochs (0 to 50): shortest-path perfect accuracy (0% to 100%) and cost ratio to optimal (1.00 to 1.10), comparing the perturbed FY loss with the blackbox loss and the squared loss.]

  17. Thank you (Grazie)
