Learning To Stop While Learning To Predict
Xinshi Chen¹, Hanjun Dai², Yu Li³, Xin Gao³, Le Song¹,⁴
¹Georgia Tech, ²Google Brain, ³KAUST, ⁴Ant Financial
ICML 2020
5-minute Core Message: Dynamic Depth
Stop at different depths for different input samples.
[Figure: two unrolled models with a "stop?" decision after every layer; one input answers "no" three times and stops at depth 4, while another answers "yes" earlier and stops at depth 2, outputting at the stopped layer.]
5-minute Core Message: Motivation 1. Task-Imbalanced Meta Learning
• Task 1 has fewer samples; Task 2 has more samples.
• Different tasks need different numbers of gradient steps for adaptation.
[Figure: a shared meta-initialization adapted to each task with a different number of gradient steps.]
5-minute Core Message: Motivation 2. Data-Driven Algorithm Design
• Traditional algorithms have stopping criteria that determine the number of iterations for each problem instance, e.g.,
  – iterate until convergence;
  – early stopping to avoid over-fitting.
• Deep-learning-based algorithms usually have a fixed number of iterations built into the architecture.
[Figure: an update step repeated until a hand-designed stopping criterion is satisfied, then the output is returned.]
5-minute Core Message: Motivation 3. Others
Image denoising
• Images with different noise levels may need different numbers of denoising steps (noisy → less noisy).
Image recognition
• "Early exits" have been proposed to improve computational efficiency and avoid "over-thinking" [Teerapittayanon et al., 2016; Zamir et al., 2017; Huang et al., 2018; Kaya et al., 2019].
5-minute Core Message: Predictive Model with Stopping Policy
Predictive model $\mathcal{F}_\theta$
• Transforms the input $x$ to generate a path of states $x^1, \ldots, x^T$.
Stopping policy $\pi_\phi$
• Sequentially observes the states $x^t$ and determines the probability of stopping at layer $t$.
Variational stop time distribution $q_\phi$
• The stop time distribution induced by the stopping policy $\pi_\phi$: $q_\phi(t) = \pi_\phi(x^t) \prod_{\tau=1}^{t-1} \big(1 - \pi_\phi(x^\tau)\big)$.
[Figure: the predictive model unrolled over layers; the stopping policy emits a stop probability at each layer, and the output is read at the stopped layer.]
5-minute Core Message: How to learn the optimal $(\mathcal{F}_\theta, \pi_\phi)$ efficiently?
• Design a joint training objective $\mathcal{L}(\mathcal{F}_\theta, q_\phi)$.
• Introduce an oracle stop time distribution $q^* \,|\, \mathcal{F}_\theta := \arg\min_{q \in \mathcal{P}_T} \mathcal{L}(\mathcal{F}_\theta, q)$, where $\mathcal{P}_T$ is the set of distributions over $\{1, \ldots, T\}$.
• Decompose the learning procedure into two stages:
  (i) the oracle model learning stage: optimize $\mathcal{L}(\mathcal{F}_\theta, q^* | \mathcal{F}_\theta)$ over $\theta$ to obtain the optimal $\mathcal{F}_{\theta^*}$;
  (ii) the imitation learning stage: fit $q_\phi$ to the oracle $q^* | \mathcal{F}_{\theta^*}$ by minimizing a KL divergence to obtain the optimal $\pi_{\phi^*}$.
5-minute Core Message: Advantages of Our Training Procedure
✓ Principled
• The two components are optimized towards a joint objective.
✓ Tuning-free
• The weights of different layers in the loss are given automatically by the oracle distribution.
• The weights on the layers can differ across input samples.
✓ Efficient
• Instead of updating $\theta$ and $\phi$ alternately, $\theta$ is optimized in the first stage and then $\phi$ is optimized in the second stage.
✓ Generic
• The framework can be applied to a diverse range of applications.
✓ Better understanding
• A variational Bayes perspective helps in understanding the proposed model and the joint training.
• A reinforcement learning perspective helps in understanding the learning of the stopping policy.
5-minute Core Message: Experiments
• Learning to optimize: sparse recovery
• Task-imbalanced meta learning: few-shot learning
• Image denoising
• Some observations on image recognition tasks
Problem Formulation – Models
Predictive model $\mathcal{F}_\theta$
• $x^t = f_{\theta_t}(x^{t-1})$, for $t = 1, 2, \ldots, T$.
Stopping policy $\pi_\phi$
• $\pi_t = \pi_\phi(x, x^t)$, for $t = 1, 2, \ldots, T$.
Variational stop time distribution $q_\phi$ (induced by $\pi_\phi$)
• $q_\phi(t) = \pi_t \prod_{\tau=1}^{t-1} (1 - \pi_\tau)$ for $t < T$, where the product is Pr[not stopped before $t$]; the remaining probability mass is assigned to $t = T$.
• It helps design the training objective and the algorithm.
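A minimal NumPy sketch of how a stopping policy induces the stop time distribution above. The toy layer function `f_theta`, policy `pi_phi`, and their random weights are illustrative placeholders, not the paper's actual architectures; only the recursion $q_\phi(t) = \pi_t \prod_{\tau<t}(1-\pi_\tau)$ follows the slide.

```python
import numpy as np

rng = np.random.default_rng(0)
W = [rng.standard_normal((4, 4)) * 0.1 for _ in range(5)]   # toy layer weights (assumption)
v = rng.standard_normal(4)                                   # toy policy weights (assumption)

def f_theta(x, t):
    """Toy t-th layer of the predictive model: x^t = f_{theta_t}(x^{t-1})."""
    return np.tanh(W[t] @ x)

def pi_phi(x0, xt):
    """Toy stopping policy: probability of stopping after observing (x, x^t)."""
    return 1.0 / (1.0 + np.exp(-(v @ (xt - x0))))

def stop_time_distribution(x0, T=5):
    """q_phi(t) = pi_t * prod_{tau<t} (1 - pi_tau); remaining mass goes to t = T."""
    q, not_stopped, x = np.zeros(T), 1.0, x0
    for t in range(T):
        x = f_theta(x, t)
        pi_t = pi_phi(x0, x) if t < T - 1 else 1.0   # force a stop at the last layer
        q[t] = not_stopped * pi_t
        not_stopped *= (1.0 - pi_t)
    return q

q = stop_time_distribution(rng.standard_normal(4))
print(q, q.sum())   # a valid distribution over layers 1..T
```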
Problem Formulation – Optimization Objective
$\mathcal{L}(\mathcal{F}_\theta, q_\phi; x, y) = \mathbb{E}_{t \sim q_\phi}\big[\ell(y, x^t; \theta)\big] - \beta H(q_\phi)$
• The first term is the expectation of the per-layer loss over the stop time $t$; the second term is an entropy regularizer.
• Variational Bayes perspective: minimizing $\mathcal{L}(\mathcal{F}_\theta, q_\phi; x, y)$ over $(\theta, \phi)$ is equivalent to maximizing $\mathcal{L}_{\beta\text{-vae}}(\mathcal{F}_\theta, q_\phi; x, y)$ (cf. $\beta$-VAE, ELBO).
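A sketch of the joint objective as written above, assuming the per-layer losses $\ell(y, x^t; \theta)$ have already been computed and stacked into a vector `losses`; the numbers in the example are arbitrary.

```python
import numpy as np

def joint_objective(losses, q, beta=1.0, eps=1e-12):
    """L(F_theta, q_phi; x, y) = E_{t ~ q_phi}[ell(y, x^t)] - beta * H(q_phi)."""
    losses, q = np.asarray(losses), np.asarray(q)
    expected_loss = np.sum(q * losses)          # expectation of the loss over t
    entropy = -np.sum(q * np.log(q + eps))      # entropy of the stop time distribution
    return expected_loss - beta * entropy

# Example: deeper layers fit better; the entropy term keeps q_phi from collapsing too early.
print(joint_objective(losses=[0.9, 0.5, 0.2, 0.1], q=[0.1, 0.2, 0.3, 0.4], beta=0.1))
```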
Training Algorithm – Stage I
Oracle stop time distribution:
• $q^*_\theta(\cdot \mid y, x) := \arg\max_{q \in \mathcal{P}_T} \mathcal{L}_{\beta\text{-vae}}(\mathcal{F}_\theta, q; x, y)$
• Interpretation: it is the optimal stop time distribution given a predictive model $\mathcal{F}_\theta$.
• When $\beta = 1$, the oracle is the true posterior $p_\theta(t \mid y, x)$.
• In closed form, $q^*_\theta(t \mid y, x) = \dfrac{p_\theta(y \mid t, x)^{1/\beta}}{\sum_{\tau=1}^{T} p_\theta(y \mid \tau, x)^{1/\beta}}$. This posterior is computationally tractable, but it requires knowledge of the true label $y$.

Stage I. Oracle model learning:
$\max_\theta \; \frac{1}{|\mathcal{D}|} \sum_{(x,y) \in \mathcal{D}} \mathcal{L}_{\beta\text{-vae}}(\mathcal{F}_\theta, q^*_\theta; x, y) = \max_\theta \; \frac{1}{|\mathcal{D}|} \sum_{(x,y) \in \mathcal{D}} \sum_{t=1}^{T} q^*_\theta(t \mid y, x) \log p_\theta(y \mid t, x)$,
where $p_\theta(y \mid t, x)$ is the likelihood of the output at the $t$-th layer.
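A small sketch of the Stage-I quantities, assuming the per-layer likelihoods $p_\theta(y \mid t, x)$ are available as a vector `p_y_given_t`; the actual gradient update on $\theta$ is omitted, and the numeric values are illustrative.

```python
import numpy as np

def oracle_stop_distribution(p_y_given_t, beta=1.0):
    """q*_theta(t|y,x) = p_theta(y|t,x)^{1/beta} / sum_tau p_theta(y|tau,x)^{1/beta}."""
    w = np.asarray(p_y_given_t) ** (1.0 / beta)
    return w / w.sum()

def stage1_objective(p_y_given_t, beta=1.0):
    """Oracle-weighted log-likelihood: sum_t q*(t|y,x) * log p_theta(y|t,x)."""
    q_star = oracle_stop_distribution(p_y_given_t, beta)
    return np.sum(q_star * np.log(p_y_given_t))

p_y_given_t = np.array([0.2, 0.5, 0.9, 0.7])   # toy per-layer likelihoods (assumption)
print(oracle_stop_distribution(p_y_given_t), stage1_objective(p_y_given_t))
```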
Training Algorithm – Stage II
Recall: the variational stop time distribution $q_\phi(t \mid x)$ is induced by the sequential policy $\pi_\phi$.
Hope: $q_\phi(t \mid x)$ can mimic the oracle distribution $q^*_\theta(t \mid y, x)$, by optimizing the forward KL divergence.

Stage II. Imitation with sequential policy:
$\min_\phi \; \mathrm{KL}(q^*_\theta \,\|\, q_\phi) = -\sum_{t=1}^{T} q^*_\theta(t \mid y, x) \log q_\phi(t \mid x) - H(q^*_\theta)$ (forward KL divergence)
Note: if we use the reverse KL divergence instead, the problem is equivalent to solving a maximum-entropy RL problem.
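A sketch of the Stage-II imitation loss under the decomposition above: the forward KL between a fixed oracle distribution and the policy-induced distribution reduces to a cross-entropy term plus a constant in $\phi$. The two distributions below are toy values, not learned ones.

```python
import numpy as np

def forward_kl(q_star, q_phi, eps=1e-12):
    """KL(q* || q_phi) = -sum_t q*(t) log q_phi(t) - H(q*)."""
    q_star, q_phi = np.asarray(q_star), np.asarray(q_phi)
    cross_entropy = -np.sum(q_star * np.log(q_phi + eps))   # the only phi-dependent term
    entropy = -np.sum(q_star * np.log(q_star + eps))        # H(q*), constant in phi
    return cross_entropy - entropy

q_star = [0.05, 0.15, 0.60, 0.20]   # oracle from Stage I (toy values)
q_phi  = [0.25, 0.25, 0.25, 0.25]   # induced by an untrained policy (toy values)
print(forward_kl(q_star, q_phi))
```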
Experiment I – Learning to Optimize: Sparse Recovery
• Task: recover $x^*$ from its noisy measurements $b = A x^* + \epsilon$.
• Traditional approach:
  – LASSO formulation: $\min_x \; \tfrac{1}{2}\|b - A x\|_2^2 + \rho \|x\|_1$
  – Solved by iterative algorithms such as ISTA.
• Learning-based algorithm:
  – Learned ISTA (LISTA) is a deep architecture designed based on ISTA update steps.
• Ablation study: whether LISTA with adaptive depth (LISTA-stop) is better than LISTA.
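For reference, a minimal sketch of classical ISTA for the LASSO problem above, i.e., the iterative algorithm that LISTA unrolls into a fixed-depth network (and that LISTA-stop additionally equips with a learned stopping policy). The step size, $\rho$, iteration count, and the toy problem instance are illustrative choices, not the paper's experimental settings.

```python
import numpy as np

def soft_threshold(z, lam):
    """Proximal operator of lam * ||.||_1."""
    return np.sign(z) * np.maximum(np.abs(z) - lam, 0.0)

def ista(A, b, rho=0.1, n_iter=100):
    """ISTA for min_x 0.5*||b - A x||_2^2 + rho*||x||_1."""
    L = np.linalg.norm(A, 2) ** 2            # Lipschitz constant of the smooth part's gradient
    x = np.zeros(A.shape[1])
    for _ in range(n_iter):
        x = soft_threshold(x - (A.T @ (A @ x - b)) / L, rho / L)
    return x

rng = np.random.default_rng(0)
A = rng.standard_normal((20, 50))
x_true = np.zeros(50); x_true[[3, 17, 42]] = [1.0, -2.0, 0.5]   # sparse ground truth
b = A @ x_true + 0.01 * rng.standard_normal(20)
print(np.round(ista(A, b), 2)[[3, 17, 42]])   # recovered values on the true support
```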
Experiment II – Task-Imbalanced Meta Learning
• Task: task-imbalanced few-shot learning. Each task contains $k$ shots per class, where $k$ can vary across tasks.
• Our variant, MAML-stop:
  – Built on top of MAML; MAML-stop learns how many adaptation gradient-descent steps are needed for each task.
[Results tables for the task-imbalanced setting and the vanilla setting are omitted here.]
Experiment III – Image Denoising
• Our variant, DnCNN-stop:
  – Built on top of DnCNN, one of the most popular models for the denoising task.
* Noise levels 65 and 75 are not observed during training.