Not All Samples Are Created Equal: Deep Learning with Importance Sampling
Angelos Katharopoulos & François Fleuret
ICML, July 11, 2018
Evolution of gradient norms during training (small CNN on MNIST)
[Figure: left panel, training loss vs. iterations; right panel, CDF of the per-sample gradient norms]
85% of the samples have negligible gradient
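As a rough illustration of the quantity behind that CDF, the sketch below computes per-sample gradient norms with one backward pass per sample. It is a slow reference implementation; the toy linear classifier, the cross-entropy loss, and the function name are assumptions for the example, not the authors' setup.

```python
import torch
import torch.nn.functional as F

def per_sample_grad_norms(model, inputs, targets):
    """Slow but exact reference: one backward pass per sample, returning the
    L2 norm of the full parameter gradient for each sample."""
    norms = []
    for x, y in zip(inputs, targets):
        model.zero_grad()
        loss = F.cross_entropy(model(x.unsqueeze(0)), y.unsqueeze(0))
        loss.backward()
        sq_norm = sum((p.grad.detach() ** 2).sum()
                      for p in model.parameters() if p.grad is not None)
        norms.append(float(sq_norm) ** 0.5)
    return torch.tensor(norms)

if __name__ == "__main__":
    model = torch.nn.Linear(20, 10)              # toy classifier, purely illustrative
    xs = torch.randn(8, 20)
    ys = torch.randint(0, 10, (8,))
    print(per_sample_grad_norms(model, xs, ys))  # most norms end up small as training progresses
```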
Related work
◮ Sample points proportionally to the gradient norm (Needell et al., 2014; Zhao and Zhang, 2015; Alain et al., 2015)
◮ SVRG-type methods (Johnson and Zhang, 2013; Defazio et al., 2014; Lei et al., 2017)
◮ Sample using the loss
  ◮ Hard/semi-hard sample mining (Schroff et al., 2015; Simo-Serra et al., 2015)
  ◮ Online Batch Selection (Loshchilov and Hutter, 2015)
  ◮ Prioritized Experience Replay (Schaul et al., 2015)
Contributions
◮ Derive a fast-to-compute importance distribution
◮ Variance cannot always be reduced, so start importance sampling only when it is useful
BONUS
◮ Package everything in an embarrassingly simple to use library
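The library referred to here is the authors' Keras add-on; the snippet below sketches the intended one-line-change style of usage, but the import path, the ImportanceTraining wrapper, and its fit signature should be treated as assumptions rather than a verified API.

```python
# Hedged sketch: the import path and wrapper class below are assumptions about
# the authors' library, not a documented interface.
import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from importance_sampling.training import ImportanceTraining  # assumed module path

model = Sequential([
    Dense(64, activation="relu", input_shape=(32,)),
    Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

x = np.random.rand(1024, 32).astype("float32")
y = np.eye(10)[np.random.randint(10, size=1024)]     # one-hot toy labels

# Wrapping the compiled model would be the only change compared to plain model.fit(...).
ImportanceTraining(model).fit(x, y, batch_size=128, epochs=3)
```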
Deriving the sampling distribution (1)
Similar to Zhao and Zhang (2015), we want to minimize the variance of the gradients:
$$P^{*} = \arg\min_{P} \operatorname{Tr}\!\left(\mathbb{V}_P[w_i G_i]\right) = \arg\min_{P} \mathbb{E}_P\!\left[w_i^2 \|G_i\|_2^2\right]$$
To simplify, we minimize an upper bound:
$$\|G_i\|_2 \le \hat{G}_i \;\Longrightarrow\; \min_{P} \mathbb{E}_P\!\left[w_i^2 \|G_i\|_2^2\right] \le \min_{P} \mathbb{E}_P\!\left[w_i^2 \hat{G}_i^2\right]$$
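To make the resulting estimator concrete: given any non-negative per-sample scores standing in for the upper bounds Ĝ_i, the sampling probabilities and the re-weighting factors that keep the gradient estimate unbiased can be formed as in the small numpy sketch below (the function name and the epsilon guard are my own additions).

```python
import numpy as np

def importance_distribution(scores, eps=1e-8):
    """Turn per-sample scores (e.g. the upper bounds G_hat_i) into sampling
    probabilities p_i and the weights w_i = 1 / (B * p_i) that keep the
    resulting gradient estimator unbiased."""
    scores = np.asarray(scores, dtype=np.float64) + eps   # guard against all-zero scores
    p = scores / scores.sum()
    w = 1.0 / (len(scores) * p)
    return p, w

# Example: the third sample has a much larger (estimated) gradient norm.
p, w = importance_distribution([0.1, 0.4, 2.0])
print(p.round(3))   # sampling probabilities, sum to 1: [0.04, 0.16, 0.8]
print(w.round(3))   # re-weighting factors: [8.333, 2.083, 0.417]
```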
Deriving the sampling distribution (2)
We show that we can upper bound the gradient norm of the parameters using the norm of the gradient with respect to the pre-activation outputs of the last layer. We conjecture that batch normalization and weight initialization make it tight.
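As a hedged illustration of why this is cheap: for a softmax classifier trained with cross-entropy, the gradient of the loss with respect to the last layer's pre-activations is softmax(z) − onehot(y), so a per-sample score can be read off a forward pass alone. The sketch below assumes exactly that setup; any constant factor common to all samples cancels once the scores are normalized into a distribution.

```python
import numpy as np

def last_layer_score(logits, labels):
    """Per-sample importance score from the gradient of the loss w.r.t. the
    pre-activation outputs (logits) of the last layer. Assumes softmax +
    cross-entropy, where that gradient is softmax(z) - onehot(y)."""
    z = logits - logits.max(axis=1, keepdims=True)           # numerically stable softmax
    probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    probs[np.arange(len(labels)), labels] -= 1.0              # softmax(z) - onehot(y)
    return np.linalg.norm(probs, axis=1)                      # one score per sample

scores = last_layer_score(np.random.randn(4, 10), np.array([3, 1, 7, 0]))
print(scores)   # larger values flag samples whose last-layer gradient is larger
```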
Variance reduction achieved with our upper bound
[Figure: empirical variance reduction vs. iterations on CIFAR-100 and Downsampled ImageNet, comparing uniform, loss, gradient-norm, and upper-bound (ours) sampling]
Is the upper bound enough to speed up training?
Not really, because
◮ a forward pass on the whole dataset is still prohibitive
◮ the importance distribution can be arbitrarily close to uniform
Two key ideas
◮ Sample a large batch (B) uniformly at random and resample a small batch (b) with importance (see the sketch after this list)
◮ Start importance sampling only when the variance will be reduced
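The first idea amounts to a simple two-stage sampling routine. The sketch below is a minimal illustration under my own assumptions (the uniform large batch is already drawn and the scores already computed; the names are invented for the example): it resamples b points with replacement and returns the 1/(B p_i) weights that keep the gradient estimate unbiased.

```python
import numpy as np

def resample_with_importance(x_big, y_big, scores, b, rng=None):
    """Draw a small batch of size b (with replacement) from a uniformly sampled
    large batch of size B, using per-sample scores as importance, and return
    the weights w_i = 1 / (B * p_i) for an unbiased gradient estimate."""
    rng = rng or np.random.default_rng()
    B = len(scores)
    p = scores / scores.sum()
    idx = rng.choice(B, size=b, replace=True, p=p)
    w = 1.0 / (B * p[idx])
    return x_big[idx], y_big[idx], w

# Toy usage: scores would come from the last-layer upper bound in practice.
x_big = np.random.randn(1024, 32)
y_big = np.random.randint(10, size=1024)
scores = np.random.rand(1024)
xb, yb, wb = resample_with_importance(x_big, y_big, scores, b=128)
```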
When do we start importance sampling?
We start importance sampling when the variance reduction is large enough:
$$\operatorname{Tr}\!\left(\mathbb{V}_u[G_i]\right) - \operatorname{Tr}\!\left(\mathbb{V}_P[w_i G_i]\right) = \frac{1}{B}\Bigl(\sum_{i=1}^{B} \|G_i\|_2\Bigr)^{2} \sum_{i=1}^{B} (p_i - u)^2 \;\propto\; \underbrace{\sum_{i=1}^{B} (p_i - u)^2}_{\text{distance of the importance distribution to uniform}}$$
We show that the equivalent batch increment satisfies
$$\tau \ge \Bigl(1 - \frac{\sum_i (p_i - u)^2}{\sum_i p_i^2}\Bigr)^{-1},$$
which allows us to perform importance sampling when
$$\underbrace{B\, t_{\text{forward}} + b\,(t_{\text{forward}} + t_{\text{backward}})}_{\text{time for an importance sampling iteration}} \;\le\; \underbrace{\tau\, b\,(t_{\text{forward}} + t_{\text{backward}})}_{\text{time for the equivalent uniform sampling iteration}}$$
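A minimal sketch of that test, assuming the probabilities p over the large batch are already available and that t_fwd and t_bwd are measured per-sample timings; the helper names and the toy numbers are illustrative, not part of the method's specification.

```python
import numpy as np

def equivalent_batch_increment(p):
    """tau from the slide: how much larger a uniform batch would have to be to
    match the variance reduction of importance sampling with probabilities p."""
    u = 1.0 / len(p)
    return 1.0 / (1.0 - ((p - u) ** 2).sum() / (p ** 2).sum())

def worth_switching(p, B, b, t_fwd, t_bwd):
    """Timing condition: the extra B forward passes must cost no more than the
    time of the equivalent (tau times larger) uniform-sampling iteration."""
    tau = equivalent_batch_increment(p)
    return B * t_fwd + b * (t_fwd + t_bwd) <= tau * b * (t_fwd + t_bwd)

p = np.array([0.4, 0.3, 0.2, 0.1])   # toy importance distribution over a batch of B = 4
print(equivalent_batch_increment(p))
print(worth_switching(p, B=4, b=2, t_fwd=1.0, t_bwd=2.0))
```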
Experimental setup
◮ We fix a time budget for all methods and compare the achieved training loss and test error
◮ We evaluate on three tasks
  1. WideResNets on CIFAR-10/100 (image classification)
  2. Pretrained ResNet-50 on MIT67 (fine-tuning)
  3. LSTM on permuted MNIST (sequence classification)
Importance sampling for image classification
◮ SVRG methods do not work for Deep Learning
◮ Our loss-based sampling outperforms existing loss-based methods
[Figure: training loss and test error relative to uniform on CIFAR-10 and CIFAR-100, comparing SVRG, Katyusha, SCSG, Uniform, Loshchilov 2015, Schaul 2015, and Loss (ours)]