Learning To Stop While Learning To Predict



  1. Learning To Stop While Learning To Predict. Xinshi Chen 1, Hanjun Dai 2, Yu Li 3, Xin Gao 3, Le Song 1,4. 1 Georgia Tech, 2 Google Brain, 3 KAUST, 4 Ant Financial. ICML 2020

  2. 5-minute Core Message: Dynamic Depth
β€’ Stop at different depths for different input samples.
[Diagram: two inputs pass through the same stack of layers with a "stop?" check after each one; one sample keeps answering "no" until depth 4, where it stops and outputs its state (stopped depth = 4), while the other stops and outputs its state at depth 2 (stopped depth = 2).]
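To make the dynamic-depth idea concrete, here is a minimal Python sketch of inference with per-sample stopping. The names `layers` (a list of per-layer transition functions) and `stop_policy` (the learned stopping policy) are hypothetical stand-ins, not the authors' implementation.

```python
import numpy as np

def predict_with_dynamic_depth(x, layers, stop_policy, rng=None):
    """Run the layers sequentially and stop early when the policy fires.

    `layers[t]` maps x_{t-1} to x_t and `stop_policy(x, x_t)` returns the
    probability of stopping at layer t; both are illustrative assumptions.
    """
    rng = rng or np.random.default_rng(0)
    xt = x
    for depth, layer in enumerate(layers, start=1):
        xt = layer(xt)                          # x_t = f_{theta_t}(x_{t-1})
        if rng.random() < stop_policy(x, xt):   # stop with probability pi_t
            break
    return xt, depth                            # output and stopped depth
```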

  3. 5-minute Core Message: Motivation 1. Task-imbalanced Meta Learning
β€’ Tasks with fewer or more samples need different numbers of gradient steps for adaptation.
[Diagram: from a shared initialization ΞΈ, Task 1 (fewer samples) and Task 2 (more samples) are adapted with gradient steps βˆ‡_ΞΈ β„’_1 and βˆ‡_ΞΈ β„’_2 to task-specific parameters ΞΈ_task1 and ΞΈ_task2.]

  4. 5-minute Core Message: Motivation 2. Data-driven Algorithm Design
Traditional algorithms have stopping criteria that determine the number of iterations for each problem instance. E.g.,
β€’ iterate until convergence
β€’ early stopping to avoid over-fitting
[Diagram: input x β†’ update step β†’ state x_t β†’ hand-designed stop criterion; output x_t if the criterion is satisfied, otherwise keep iterating.]
Deep-learning-based algorithms, in contrast, usually have a fixed number of iterations built into the architecture.

  5. 5-minute Core Message: Motivation 3. Others
Image Denoising
β€’ Images with different noise levels may need different numbers of denoising steps (noisy β†’ less noisy).
Image Recognition
β€’ 'Early exits' have been proposed to improve computational efficiency and avoid 'over-thinking' [Teerapittayanon et al., 2016; Zamir et al., 2017; Huang et al., 2018; Kaya et al., 2019].

  6. 5-minute Core Message: Predictive Model with Stopping Policy
Predictive model F_ΞΈ
β€’ Transforms the input x to generate a path of states x_1, ..., x_T.
Stopping policy Ο€_Ο†
β€’ Sequentially observes the states x_t and determines the probability of stopping at layer t.
Variational stop time distribution q_Ο†
β€’ The stop time distribution induced by the stopping policy Ο€_Ο†: $q_\phi(t) = \pi_\phi(x_t) \prod_{\tau=1}^{t-1} \big(1 - \pi_\phi(x_\tau)\big)$.
[Diagram: the predictive model with layers ΞΈ_2, ..., ΞΈ_5 produces states x_2, ..., x_5 from the input x; the stopping policy Ο€_Ο† outputs 0 at the earlier states and 1 at x_5, where it stops and outputs x_5.]

  7. 5-minute Core Message: How to learn the optimal (F_ΞΈ, Ο€_Ο†) efficiently?
β€’ Design a joint training objective: β„’(F_ΞΈ, q_Ο†).
β€’ Introduce an oracle stop time distribution: $q^* | \mathcal{F}_\theta := \arg\min_{q \in \mathcal{Q}_T} \mathcal{L}(\mathcal{F}_\theta, q)$.
β€’ Then decompose the learning procedure into two stages: (i) the oracle model learning stage; (ii) the imitation learning stage.
[Diagram: Stage (i) minimizes β„’(F_ΞΈ, q*|F_ΞΈ) over ΞΈ to obtain the optimal F_ΞΈ* together with its oracle q*|F_ΞΈ*; Stage (ii) fits q_Ο† to the oracle q*|F_ΞΈ* by minimizing a KL divergence, yielding the optimal q_Ο†*.]

  8. 5-minute Core Message: Advantages of our training procedure
βœ“ Principled
β€’ The two components are optimized towards a joint objective.
βœ“ Tuning-free
β€’ Weights of different layers in the loss are given automatically by the oracle distribution.
β€’ For different input samples, the weights on the layers can be different.
βœ“ Efficient
β€’ Instead of updating ΞΈ and Ο† alternately, ΞΈ is optimized in the 1st stage, and then Ο† is optimized in the 2nd stage.
βœ“ Generic
β€’ Can be applied to a diverse range of applications.
βœ“ Better understanding
β€’ A variational Bayes perspective, for better understanding the proposed model and the joint training.
β€’ A reinforcement learning perspective, for better understanding the learning of the stopping policy.

  9. 5-minute Core Message: Experiments
β€’ Learning to optimize: sparse recovery
β€’ Task-imbalanced meta learning: few-shot learning
β€’ Image denoising
β€’ Some observations on image recognition tasks

  10. Problem Formulation - Models
Predictive model F_ΞΈ
β€’ $x_t = f_{\theta_t}(x_{t-1})$, for $t = 1, 2, \dots, T$.
Stopping policy Ο€_Ο†
β€’ $\pi_t = \pi_\phi(x, x_t)$, for $t = 1, 2, \dots, T$.
Variational stop time distribution q_Ο† (induced by Ο€_Ο†)
β€’ $q_\phi(t) = \pi_t \prod_{\tau=1}^{t-1} (1 - \pi_\tau)$ for $t < T$, where the product term is Pr[not stopped before t]; the remaining probability mass is assigned to $t = T$.
β€’ Helps design the training objective and the algorithm.
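A minimal numerical sketch of how the induced distribution q_Ο† can be computed from the per-layer stop probabilities Ο€_t (the function name and array interface are illustrative assumptions, not the paper's code):

```python
import numpy as np

def stop_time_distribution(pi):
    """Stop time distribution q_phi(t) induced by stop probabilities pi_1..pi_T.

    q(t) = pi_t * prod_{tau < t} (1 - pi_tau) for t < T, i.e. the probability of
    stopping at t times Pr[not stopped before t]; the leftover mass goes to
    t = T so that q sums to one.
    """
    pi = np.asarray(pi, dtype=float)
    survive = np.concatenate(([1.0], np.cumprod(1.0 - pi[:-1])))  # Pr[not stopped before t]
    q = pi * survive
    q[-1] = survive[-1]   # everything that reaches the last layer stops there
    return q

# Example: stop_time_distribution([0.1, 0.5, 0.3, 0.4]) -> [0.1, 0.45, 0.135, 0.315]
```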

  11. Problem Formulation - Optimization Objective
$\mathcal{L}(\mathcal{F}_\theta, q_\phi; x, y) = \mathbb{E}_{t \sim q_\phi}\big[\, \ell(y, x_t; \theta) \,\big] - \gamma H(q_\phi)$, i.e., the expected loss over the stop time t minus a Ξ³-weighted entropy term.
β€’ Variational Bayes perspective: minimizing $\mathcal{L}(\mathcal{F}_\theta, q_\phi; x, y)$ over (ΞΈ, Ο†) is equivalent to maximizing an evidence lower bound $\mathrm{ELBO}_\gamma(\mathcal{F}_\theta, q_\phi; x, y)$ (i.e., a Ξ³-weighted ELBO, as in Ξ²-VAE).
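A small sketch of this objective, assuming the per-layer losses and a stop time distribution are already available (the helper name and interface are illustrative):

```python
import numpy as np

def joint_objective(layer_losses, q, gamma=1.0, eps=1e-12):
    """L(F_theta, q_phi; x, y) = E_{t ~ q_phi}[ loss(y, x_t; theta) ] - gamma * H(q_phi).

    `layer_losses[t-1]` stands for the loss of the prediction made at layer t
    and `q` is the stop time distribution over the T layers.
    """
    layer_losses = np.asarray(layer_losses, dtype=float)
    q = np.asarray(q, dtype=float)
    expected_loss = float(np.sum(q * layer_losses))   # expectation over the stop time t
    entropy = float(-np.sum(q * np.log(q + eps)))     # H(q_phi)
    return expected_loss - gamma * entropy
```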

  12. Training Algorithm - Stage I
Oracle stop time distribution: $q_\theta^*(\cdot \mid y, x) := \arg\max_{q \in \mathcal{Q}_T} \mathrm{ELBO}_\gamma(\mathcal{F}_\theta, q; x, y)$
β€’ Interpretation: it is the optimal stop time distribution given a predictive model F_ΞΈ.
β€’ Closed form: $q_\theta^*(t \mid y, x) = p_\theta(y \mid t, x)^{1/\gamma} \,/\, \sum_{\tau=1}^{T} p_\theta(y \mid \tau, x)^{1/\gamma}$. When Ξ³ = 1, the oracle is the true posterior.
β€’ This posterior is computationally tractable, but it requires knowledge of the true label y.
Stage I. Oracle model learning:
$\max_\theta \frac{1}{|\mathcal{D}|} \sum_{(x,y) \in \mathcal{D}} \mathrm{ELBO}_\gamma(\mathcal{F}_\theta, q_\theta^*; x, y)$, whose data term weights $\log p_\theta(y \mid t, x)$, the likelihood of the output at the t-th layer, by the oracle $q_\theta^*(t \mid y, x)$.
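Since the oracle has the closed form above, it can be computed as a temperature-scaled softmax of the per-layer log-likelihoods; a minimal sketch (names are illustrative, not the authors' code):

```python
import numpy as np

def oracle_stop_distribution(log_liks, gamma=1.0):
    """Closed-form oracle q*_theta(t | y, x) from per-layer log-likelihoods.

    `log_liks[t-1]` stands for log p_theta(y | t, x); the oracle weights layer t
    in proportion to p_theta(y | t, x)^{1/gamma}, which is the true posterior
    over the stop time when gamma = 1.
    """
    z = np.asarray(log_liks, dtype=float) / gamma
    z -= z.max()               # numerically stabilised softmax
    w = np.exp(z)
    return w / w.sum()
```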

  13. Training Algorithm - Stage II
Recall: the variational stop time distribution $q_\phi(t \mid x)$ is induced by the sequential policy Ο€_Ο†.
Hope: $q_\phi(t \mid x)$ can mimic the oracle distribution $q_\theta^*(t \mid y, x)$, by optimizing the forward KL divergence.
Stage II. Imitation With Sequential Policy:
$\mathrm{KL}(q_\theta^* \,\|\, q_\phi) = - \sum_{t=1}^{T} q_\theta^*(t \mid y, x) \log q_\phi(t \mid x) - H(q_\theta^*)$ (forward KL divergence)
Note: if we use the reverse KL divergence instead, it is equivalent to solving a maximum-entropy RL problem.
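Because H(q*_ΞΈ) does not depend on Ο†, Stage II reduces to a cross-entropy between the oracle and the induced distribution; a minimal sketch (illustrative interface, not the authors' code):

```python
import numpy as np

def imitation_loss(q_oracle, q_phi, eps=1e-12):
    """Forward KL(q*_theta || q_phi) up to the constant entropy of q*_theta.

    Since H(q*_theta) does not depend on phi, minimising the forward KL w.r.t.
    phi reduces to the cross-entropy -sum_t q*(t) log q_phi(t | x).
    """
    q_oracle = np.asarray(q_oracle, dtype=float)
    q_phi = np.asarray(q_phi, dtype=float)
    return float(-np.sum(q_oracle * np.log(q_phi + eps)))
```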

  14. Experiment I - Learning To Optimize: Sparse Recovery
β€’ Task: recover the sparse signal $x^*$ from its noisy measurements $b = A x^* + \epsilon$.
β€’ Traditional approach:
  – LASSO formulation: $\min_x \tfrac{1}{2}\|b - Ax\|_2^2 + \rho \|x\|_1$
  – Solved by iterative algorithms such as ISTA.
β€’ Learning-based algorithm:
  – Learned ISTA (LISTA) is a deep architecture designed based on the ISTA update steps.
β€’ Ablation study: whether LISTA with adaptive depth (LISTA-stop) is better than LISTA.
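For reference, here is a minimal NumPy sketch of plain ISTA with its hand-designed stopping criterion, which is the kind of iterative algorithm LISTA unrolls (parameter names such as `rho` and `tol` are illustrative):

```python
import numpy as np

def ista(A, b, rho, max_iters=100, tol=1e-6):
    """Plain ISTA for the LASSO problem min_x 0.5*||b - A x||_2^2 + rho*||x||_1.

    The hand-designed stopping criterion (iterate change below `tol`) chooses
    the number of iterations per problem instance; LISTA unrolls a fixed number
    of such steps with learned weights, and LISTA-stop learns when to stop.
    """
    L = np.linalg.norm(A, 2) ** 2                # Lipschitz constant of the smooth part
    x = np.zeros(A.shape[1])
    for _ in range(max_iters):
        z = x - A.T @ (A @ x - b) / L            # gradient step
        x_new = np.sign(z) * np.maximum(np.abs(z) - rho / L, 0.0)  # soft-thresholding
        if np.linalg.norm(x_new - x) < tol:      # hand-designed stop criterion
            return x_new
        x = x_new
    return x
```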

  15. Experiment II - Task-imbalanced Meta Learning
β€’ Task: task-imbalanced few-shot learning. Each task contains k shots for each class, where k can vary across tasks.
β€’ Our variant, MAML-stop:
  – Built on top of MAML, but MAML-stop learns how many adaptation gradient-descent steps are needed for each task.
[Results tables for the task-imbalanced setting and the vanilla setting.]
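The idea of learning the number of adaptation steps per task can be sketched as an inner loop that consults a stopping signal after each gradient step. This is only a hedged illustration: `task_grad` and `stop_prob` are hypothetical interfaces, and the real MAML-stop trains the stop distribution jointly as described above.

```python
import numpy as np

def adapt_with_stop(theta, task_grad, stop_prob, max_steps=10, lr=0.01):
    """Per-task inner-loop adaptation with a learned stopping signal.

    `task_grad(theta)` returns the adaptation gradient on the task's support
    set and `stop_prob(theta, step)` plays the role of the learned stopping
    policy. Both interfaces are hypothetical, not the actual MAML-stop code.
    """
    step = 0
    for step in range(1, max_steps + 1):
        theta = theta - lr * task_grad(theta)   # one adaptation gradient step
        if stop_prob(theta, step) > 0.5:        # this task has adapted enough
            break
    return theta, step
```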

  16. Experiment III - Image Denoising
β€’ Our variant, DnCNN-stop:
  – Built on top of DnCNN, one of the most popular models for the denoising task.
* Noise levels 65 and 75 are not observed during training.
