Efficient Meta Learning via Minibatch Proximal Update
Pan Zhou
Joint work with Xiao-Tong Yuan, Huan Xu, Shuicheng Yan, Jiashi Feng
National University of Singapore
pzhou@u.nus.edu
Dec 11, 2019
Meta Learning via Minibatch Proximal Update (Meta-MinibatchProx)
Meta-MinibatchProx learns a good prior model initialization $w$ from observed tasks such that it is close to the optimal models of new, similar tasks, promoting new task learning.
Training model: given a task distribution $\mathcal{P}(\mathcal{T})$, we minimize a bi-level meta learning model
• $\min_{w}\ F(w) := \mathbb{E}_{\mathcal{T}\sim\mathcal{P}(\mathcal{T})}\Big[\min_{w_{\mathcal{T}}}\ \mathcal{L}_{\mathcal{S}_{\mathcal{T}}}(w_{\mathcal{T}}) + \tfrac{\lambda}{2}\|w_{\mathcal{T}} - w\|^{2}\Big]$
• the inner problem updates the task-specific solution $w_{\mathcal{T}}$; the outer problem updates the prior model $w$
where each task $\mathcal{T}$ has a set $\mathcal{S}_{\mathcal{T}}$ of training samples and $\mathcal{L}_{\mathcal{S}}(w) = \frac{1}{|\mathcal{S}|}\sum_{(x,y)\in\mathcal{S}} \ell(h_w(x), y)$ is the empirical loss with predictor $h_w$ and loss $\ell$.
Interpretation: the objective asks the prior $w$ to have a small average distance to the optimum models of all tasks in expectation.
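To make the bi-level objective concrete, here is a minimal NumPy sketch (not the paper's code; the toy least-squares tasks, the closed-form inner solver, and names such as inner_prox_solution are illustrative assumptions) that evaluates the inner proximal problem and the outer objective F(w) on a batch of tasks.

    import numpy as np

    def task_loss(w, X, y):
        # Empirical loss L_S(w): mean squared error of the linear predictor h_w(x) = x @ w.
        return 0.5 * np.mean((X @ w - y) ** 2)

    def inner_prox_solution(w_prior, X, y, lam):
        # Inner problem: argmin_{w'} L_S(w') + (lam/2) * ||w' - w_prior||^2.
        # Least squares admits a closed form; a deep model would instead take a few SGD steps.
        n, d = X.shape
        A = X.T @ X / n + lam * np.eye(d)
        rhs = X.T @ y / n + lam * w_prior
        return np.linalg.solve(A, rhs)

    def outer_objective(w_prior, tasks, lam):
        # F(w): average over tasks of the minimum value of the inner proximal problem.
        vals = []
        for X, y in tasks:
            w_task = inner_prox_solution(w_prior, X, y, lam)
            vals.append(task_loss(w_task, X, y)
                        + 0.5 * lam * np.sum((w_task - w_prior) ** 2))
        return float(np.mean(vals))

A smaller F(w) means the prior lies, on average, close to the (regularized) optima of the sampled tasks.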
Test model: given a randomly sampled new task $\mathcal{T}$ consisting of $K$ samples $\mathcal{S}_{\mathcal{T}}$, we adapt
• $w_{\mathcal{T}} = \arg\min_{w'}\ \mathcal{L}_{\mathcal{S}_{\mathcal{T}}}(w') + \tfrac{\lambda}{2}\|w' - w^{*}\|^{2}$
where $w^{*}$ denotes the learnt prior initialization.
Benefit: a few samples are sufficient for adaptation
• a small expected distance to the task optima during training implies the learnt prior initialization is close to the optimum of a new task when training and test tasks are sampled from the same distribution.
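A minimal sketch of this test-time adaptation step, assuming a linear model with squared loss and solving the proximal problem by plain gradient descent (the function name adapt, the learning rate, and the step count are illustrative assumptions, not the paper's hyper-parameters):

    import numpy as np

    def adapt(w_star, X, y, lam=1.0, lr=0.1, steps=200):
        # Approximately solve argmin_{w'} L_S(w') + (lam/2) * ||w' - w_star||^2,
        # where (X, y) are the K samples of the new task and w_star is the learnt prior.
        w = w_star.copy()
        K = X.shape[0]
        for _ in range(steps):
            grad_loss = X.T @ (X @ w - y) / K   # gradient of the MSE empirical loss
            grad_prox = lam * (w - w_star)      # gradient of the proximal term
            w -= lr * (grad_loss + grad_prox)
        return w

The proximal term anchors the adapted model to w_star, which is what makes a few samples sufficient.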
Optimization Algorithm
We use an SGD based algorithm to solve the bi-level training model:
Step 1. Select a mini-batch $\mathcal{B}_t$ of tasks of size $b$.
Step 2. For each task $\mathcal{T}_i \in \mathcal{B}_t$, compute an approximate minimizer
• $w_i \approx \arg\min_{w'}\ \mathcal{L}_{\mathcal{S}_i}(w') + \tfrac{\lambda}{2}\|w' - w_t\|^{2}$
Step 3. Update the prior initialization model:
• $w_{t+1} = w_t - \eta_t\,\lambda\big(w_t - \tfrac{1}{b}\sum_{i=1}^{b} w_i\big)$, i.e. a stochastic gradient step on $F$ using the inner proximal solutions
Theorem 1 (convergence guarantees, informal).
(1) Convex setting, i.e. convex $\ell$: we prove convergence to the global optimum of the training objective $F$.
(2) Nonconvex setting, i.e. smooth $\ell$: we prove convergence to a stationary point of $F$.
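Putting the three steps together, a minimal end-to-end sketch on toy linear-regression tasks (the task generator, batch size, step sizes, and inner iteration count are assumptions chosen for illustration, not the paper's settings):

    import numpy as np

    rng = np.random.default_rng(0)
    d, K, lam, eta, inner_lr, b = 5, 10, 1.0, 0.05, 0.1, 4

    def sample_task():
        # Toy task family: linear regression with task-specific ground-truth weights.
        w_true = rng.normal(size=d)
        X = rng.normal(size=(K, d))
        y = X @ w_true + 0.1 * rng.normal(size=K)
        return X, y

    def inner_solve(w_prior, X, y, steps=50):
        # Step 2: approximate minimizer of L_S(w') + (lam/2) * ||w' - w_prior||^2.
        w = w_prior.copy()
        for _ in range(steps):
            grad = X.T @ (X @ w - y) / K + lam * (w - w_prior)
            w -= inner_lr * grad
        return w

    w = np.zeros(d)  # prior initialization w_0
    for t in range(500):
        batch = [sample_task() for _ in range(b)]            # Step 1: mini-batch of tasks
        w_tasks = [inner_solve(w, X, y) for X, y in batch]   # Step 2: inner proximal solutions
        # Step 3: each task's Moreau-envelope gradient at w is lam * (w - w_task),
        # so average over the batch and take one SGD step on the prior.
        w -= eta * lam * (w - np.mean(w_tasks, axis=0))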
Generalization Performance Guarantee
Ideally, for a given task $\mathcal{T}$, one should train the model on the population risk
• $\min_{w}\ \mathcal{L}_{\mathcal{T}}(w) := \mathbb{E}_{(x,y)\sim\mathcal{T}}\,[\ell(h_w(x), y)]$
In practice, we have only $K$ samples and adapt the learnt prior model $w^{*}$ to the new task:
• $w_{\mathcal{T}} = \arg\min_{w'}\ \mathcal{L}_{\mathcal{S}_{\mathcal{T}}}(w') + \tfrac{\lambda}{2}\|w' - w^{*}\|^{2}$
Since $\mathcal{L}_{\mathcal{S}_{\mathcal{T}}} \neq \mathcal{L}_{\mathcal{T}}$, why is $w_{\mathcal{T}}$ good for generalization in the few-shot learning problem?
Theorem 2 (generalization performance guarantee, informal). Suppose each loss is convex and smooth in $w$. Then the population risk of the adapted model $w_{\mathcal{T}}$ is close to that of the task optimum $w_{\mathcal{T}}^{*}$ whenever the learnt prior $w^{*}$ is close to $w_{\mathcal{T}}^{*}$.
Remark: this implies strong generalization performance, as our training model guarantees the learnt prior $w^{*}$ is close to the optimum model $w_{\mathcal{T}}^{*}$.
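To illustrate the remark numerically, here is a small sketch (an illustrative assumption, not an experiment from the paper) comparing proximal adaptation from a prior close to the task optimum against adaptation from a random prior, with K far smaller than the dimension; the close prior yields a much smaller population error.

    import numpy as np

    rng = np.random.default_rng(1)
    d, K, lam = 20, 5, 1.0
    w_opt = rng.normal(size=d)               # optimum model w_T^* of the new task

    def population_error(w):
        # With isotropic Gaussian inputs and noiseless labels, the excess population
        # risk of a linear predictor is 0.5 * ||w - w_opt||^2.
        return 0.5 * float(np.sum((w - w_opt) ** 2))

    def adapt(w_prior, X, y):
        # Closed-form solution of argmin_w L_S(w) + (lam/2) * ||w - w_prior||^2 for least squares.
        A = X.T @ X / K + lam * np.eye(d)
        rhs = X.T @ y / K + lam * w_prior
        return np.linalg.solve(A, rhs)

    X = rng.normal(size=(K, d))
    y = X @ w_opt                             # K-shot training set of the new task
    close_prior = w_opt + 0.1 * rng.normal(size=d)
    far_prior = rng.normal(size=d)
    print(population_error(adapt(close_prior, X, y)))   # small excess risk
    print(population_error(adapt(far_prior, X, y)))     # noticeably larger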
Experimental results
Few-shot regression: smaller mean square error (MSE) between prediction and ground truth.
Few-shot classification: higher classification accuracy.
[Bar charts: classification accuracy of MAML, FOMAML, Reptile, and Ours on miniImageNet and tieredImageNet under 1-shot and 5-shot settings with 5-way, 10-way, and 20-way classification.]
POSTER #26, 05:00 -- 07:00 PM @ East Exhibition Hall B + C
Thanks!