Loss Valleys and Generalization in Deep Learning
Andrew Gordon Wilson, Assistant Professor, Cornell University
https://people.orie.cornell.edu/andrew
The Robotic Vision Probabilistic Object Detection Challenge, CVPR, Long Beach, CA, June 17, 2019
1 / 41
Model Selection
[Figure: airline passengers (thousands) per year, 1949–1961]
Which model should we choose?
(1): $f_1(x) = a_0 + a_1 x$
(2): $f_2(x) = \sum_{j=0}^{3} a_j x^j$
(3): $f_3(x) = \sum_{j=0}^{10^4} a_j x^j$
2 / 41
Bayesian or Frequentist? 3 / 41
How do we learn?
◮ The ability of a system to learn is determined by its support (which solutions are a priori possible) and its inductive biases (which solutions are a priori likely).
◮ An influx of new massive datasets provides great opportunities to automatically learn rich statistical structure, leading to new scientific discoveries.
[Figure: $p(\text{data} \mid \text{model})$ over all possible datasets, for simple, medium, and flexible models]
4 / 41
Bayesian Deep Learning
Why?
◮ A powerful framework for model construction and understanding generalization
◮ Uncertainty representation and calibration (crucial for decision making)
◮ Better point estimates
◮ Interpretably incorporate prior knowledge and domain expertise
◮ It was the most successful approach at the end of the second wave of neural networks (Neal, 1998).
◮ Neural nets are much less mysterious when viewed through the lens of probability theory.
Why not?
◮ Can be computationally intractable (but doesn’t have to be).
◮ Can involve a lot of moving parts (but doesn’t have to).
There has been exciting progress in the last two years addressing these limitations as part of an extremely fruitful research direction.
5 / 41
Wide Optima Generalize Better
Keskar et al. (2017)
◮ Bayesian integration will give very different predictions, especially in deep learning!
6 / 41
Mode Connectivity
[Figure: three train-loss surface panels (shared color scale from 0.065 to > 5) illustrating mode connectivity between independently trained solutions]
Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs
Advances in Neural Information Processing Systems (NeurIPS), 2018
T. Garipov, P. Izmailov, D. Podoprikhin, D. Vetrov, A.G. Wilson
7 / 41
Cyclical Learning Rate Schedule 8 / 41
Trajectory of SGD
[Figure: test error (%) surface in a two-dimensional slice of weight space, showing SGD iterates $W_1$, $W_2$, $W_3$ and their average $W_{\text{SWA}}$]
9–11 / 41
SWA Algorithm
◮ Use a learning rate that doesn’t decay to zero (cyclical or constant)
◮ Average weights:
◮ Cyclical LR: at the end of each cycle
◮ Constant LR: at the end of each epoch
◮ Recompute batch normalization statistics at the end of training; in practice, do one additional forward pass on the training data.
12 / 41
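A minimal sketch (not the reference implementation) of the procedure above in plain PyTorch, averaging at the end of each epoch with a constant learning rate; the model, data loader, loss function, and hyperparameters are illustrative placeholders.

```python
import copy
import torch

def train_swa(model, loader, loss_fn, epochs=150, swa_start=100, lr=0.05):
    """SGD with a non-decaying learning rate; average weights once per epoch."""
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    swa_model, n_avg = copy.deepcopy(model), 0

    for epoch in range(epochs):
        for x, y in loader:
            opt.zero_grad()
            loss_fn(model(x), y).backward()
            opt.step()
        if epoch >= swa_start:
            # Running average of weights: w_swa <- (n * w_swa + w) / (n + 1)
            for p_swa, p in zip(swa_model.parameters(), model.parameters()):
                p_swa.data.mul_(n_avg).add_(p.data).div_(n_avg + 1)
            n_avg += 1

    # Recompute batch-norm statistics with one additional pass over the training data.
    for m in swa_model.modules():
        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
            m.reset_running_stats()
            m.momentum = None  # use a cumulative moving average over the pass
    swa_model.train()
    with torch.no_grad():
        for x, _ in loader:
            swa_model(x)
    return swa_model
```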
Trajectory of SGD
[Figure: test error (%) surface with SGD iterates $W_1$, $W_2$, $W_3$ and $W_{\text{SWA}}$ (top); test error (left) and train loss (right) surfaces in the plane containing $w_{\text{SGD}}$ and $w_{\text{SWA}}$ at epoch 125 (bottom)]
13 / 41
Following Random Paths
[Figure: test error (%) as a function of distance along random directions in weight space]
14 / 41
Path from $w_{\text{SWA}}$ to $w_{\text{SGD}}$
[Figure: test error (%) and train loss as a function of distance along the path from $w_{\text{SWA}}$ to $w_{\text{SGD}}$]
15 / 41
Approximating an FGE Ensemble
Because the points sampled from an FGE ensemble take small steps in weight space by design, we can do a linearization analysis to show that
$$f(w_{\text{SWA}}) \approx \frac{1}{n} \sum_{i=1}^{n} f(w_i)$$
16 / 41
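A brief sketch of why the linearization works, assuming $w_{\text{SWA}} = \frac{1}{n}\sum_i w_i$ and that the deviations $\Delta_i = w_i - w_{\text{SWA}}$ are small:
$$\frac{1}{n}\sum_{i=1}^{n} f(w_i) = \frac{1}{n}\sum_{i=1}^{n}\left[ f(w_{\text{SWA}}) + \langle \nabla f(w_{\text{SWA}}), \Delta_i \rangle + O(\|\Delta_i\|^2) \right] = f(w_{\text{SWA}}) + O\!\left(\max_i \|\Delta_i\|^2\right),$$
since the first-order terms average to zero by the definition of $w_{\text{SWA}}$.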
SWA Results, CIFAR 17 / 41
SWA Results, ImageNet (Top-1 Error Rate) 18 / 41
Sampling from a High Dimensional Gaussian SGD (with constant LR) proposals are on the surface of a hypersphere. Averaging lets us go inside the sphere to a point of higher density. 19 / 41
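A quick numerical illustration of this concentration effect (not from the slides; the dimension and sample count are arbitrary): samples from a high-dimensional isotropic Gaussian all sit near a thin shell of radius $\sqrt{d}\,\sigma$, while their average lies far inside the shell, in a region of much higher density.

```python
import numpy as np

d, n, sigma = 10_000, 30, 1.0                   # dimension, number of SGD-like samples, scale
rng = np.random.default_rng(0)

samples = rng.normal(0.0, sigma, size=(n, d))   # i.i.d. draws from N(0, sigma^2 I_d)
radii = np.linalg.norm(samples, axis=1)
avg_radius = np.linalg.norm(samples.mean(axis=0))

print(f"typical sample radius : {radii.mean():.1f}  (~ sqrt(d)*sigma = {np.sqrt(d):.1f})")
print(f"radius of the average : {avg_radius:.1f}  (~ sqrt(d/n)*sigma = {np.sqrt(d / n):.1f})")
# The averaged point lies roughly sqrt(n) times closer to the mode, i.e. at much higher density.
```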
High Constant LR
[Figure: test error (%) over 300 epochs for SGD, constant-LR SGD, and constant-LR SWA]
Side observation: Averaging bad models does not give good solutions. Averaging bad weights can give great solutions.
20 / 41
Stochastic Weight Averaging
◮ Simple drop-in replacement for SGD or other optimizers
◮ Works by finding flat regions of the loss surface
◮ No runtime overhead, but often significant improvements in generalization for many tasks
◮ Available in PyTorch contrib (call optim.swa)
◮ https://people.orie.cornell.edu/andrew/code
Averaging Weights Leads to Wider Optima and Better Generalization, UAI 2018
P. Izmailov, D. Podoprikhin, T. Garipov, D. Vetrov, A.G. Wilson
21 / 41
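For library usage, a minimal sketch assuming the torch.optim.swa_utils module that ships with more recent PyTorch releases (the slide points to the earlier contrib implementation); the model, data loader, and hyperparameters are illustrative. It is the library counterpart of the manual loop sketched earlier.

```python
import torch
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

def fit_with_swa(model, train_loader, loss_fn, epochs=150, swa_start=100):
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    swa_model = AveragedModel(model)               # maintains the running weight average
    swa_scheduler = SWALR(optimizer, swa_lr=0.05)  # constant SWA learning rate

    for epoch in range(epochs):
        for x, y in train_loader:
            optimizer.zero_grad()
            loss_fn(model(x), y).backward()
            optimizer.step()
        if epoch >= swa_start:
            swa_model.update_parameters(model)     # average weights once per epoch
            swa_scheduler.step()

    update_bn(train_loader, swa_model)             # recompute batch-norm statistics
    return swa_model
```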
Uncertainty Representation with SWAG
1. Leverage theory that shows SGD with a constant learning rate is approximately sampling from a Gaussian distribution.
2. Compute first two moments of the SGD trajectory (SWA computes just the first).
3. Use these moments to construct a Gaussian approximation in weight space.
4. Sample from this Gaussian distribution, pass samples through the predictive distribution, and form a Bayesian model average.
A Simple Baseline for Bayesian Uncertainty in Deep Learning
W. Maddox, P. Izmailov, T. Garipov, D. Vetrov, A.G. Wilson
22 / 41
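A sketch of the diagonal variant of this idea, assuming weight snapshots have already been collected as flat vectors along the constant-learning-rate SGD trajectory; the full method also maintains a low-rank covariance term, and `predict_fn` is a placeholder mapping a weight vector to predictive probabilities.

```python
import numpy as np

def swag_diag_moments(snapshots):
    """First and second moments of the SGD trajectory (flattened weight vectors)."""
    W = np.stack(snapshots)                          # shape: (num_snapshots, num_params)
    mean = W.mean(axis=0)                            # first moment (the SWA solution)
    second = (W ** 2).mean(axis=0)                   # second moment
    var = np.clip(second - mean ** 2, 1e-30, None)   # diagonal covariance
    return mean, var

def swag_predict(mean, var, predict_fn, num_samples=30, seed=0):
    """Bayesian model average: sample weights, predict, and average probabilities."""
    rng = np.random.default_rng(seed)
    probs = []
    for _ in range(num_samples):
        w = mean + np.sqrt(var) * rng.standard_normal(mean.shape)
        probs.append(predict_fn(w))
    return np.mean(probs, axis=0)
```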
Uncertainty Calibration
[Figure: reliability diagrams (confidence minus accuracy vs. confidence) for WideResNet28x10 on CIFAR-100 and on CIFAR-10 → STL-10 transfer, and for DenseNet-161 and ResNet-152 on ImageNet]
23 / 41
Uncertainty Likelihood 24 / 41
Subspace Inference for Bayesian Deep Learning
A modular approach:
◮ Construct a low-dimensional subspace of the network’s high-dimensional parameter space
◮ Perform inference directly in the subspace
◮ Sample from the approximate posterior for Bayesian model averaging
We can approximate the posterior of a WideResNet with 36 million parameters in a 5D subspace and achieve state-of-the-art results!
25 / 41
Subspace Construction
◮ Choose a shift $\hat{w}$ and basis vectors $\{d_1, \ldots, d_k\}$.
◮ Define the subspace $S = \{\, w \mid w = \hat{w} + t_1 d_1 + \cdots + t_k d_k \,\}$.
◮ Likelihood $p(\mathcal{D} \mid t) = p_M(\mathcal{D} \mid w = \hat{w} + Pt)$, where $P = [d_1, \ldots, d_k]$ stacks the basis vectors as columns.
26 / 41
Inference
◮ Approximate inference over the parameters $t$
◮ MCMC, Variational Inference, Normalizing Flows, . . .
◮ Bayesian model averaging at test time:
$$p(\mathcal{D}^* \mid \mathcal{D}) \approx \frac{1}{J} \sum_{j=1}^{J} p_M(\mathcal{D}^* \mid \tilde{w}_j = \hat{w} + P\tilde{t}_j), \qquad \tilde{t}_j \sim q(t \mid \mathcal{D}) \quad (1)$$
27 / 41
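A sketch of this model average in code, assuming `sample_t` draws subspace coordinates from the approximate posterior $q(t \mid \mathcal{D})$ and `predict_fn` evaluates the model at the full weights $\hat{w} + Pt$ (both are placeholder callables).

```python
import numpy as np

def subspace_bma_predict(w_hat, P, sample_t, predict_fn, num_samples=30):
    """Bayesian model averaging over a k-dimensional subspace, as in Eq. (1).

    w_hat      : flattened shift vector (e.g. the SWA solution), shape (p,)
    P          : projection matrix with basis vectors as columns, shape (p, k)
    sample_t   : callable returning one draw t ~ q(t | D), shape (k,)
    predict_fn : callable mapping a full weight vector to predictive probabilities
    """
    preds = []
    for _ in range(num_samples):
        t = sample_t()                # draw subspace coordinates
        w = w_hat + P @ t             # map back to the full parameter space
        preds.append(predict_fn(w))   # p_M(D* | w)
    return np.mean(preds, axis=0)     # (1/J) * sum_j p_M(D* | w_j)
```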
Subspace Choice
We want a subspace that
◮ Contains diverse models which give rise to different predictions
◮ Is cheap to construct
28 / 41
Random Subspace
◮ Directions $d_1, \ldots, d_k \sim \mathcal{N}(0, I_p)$
◮ Use a pre-trained solution as the shift $\hat{w}$
◮ Subspace $S = \{\, w \mid w = \hat{w} + Pt \,\}$
29 / 41
PCA of SGD Trajectory
◮ Run SGD with a high constant learning rate from a pre-trained solution
◮ Collect snapshots of the weights $w_i$
◮ Use the SWA solution as the shift, $\hat{w} = \frac{1}{M} \sum_{i} w_i$
◮ $\{d_1, \ldots, d_k\}$ are the first $k$ PCA components of the vectors $\hat{w} - w_i$.
30 / 41
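A sketch of this subspace construction, assuming the snapshots are available as flattened weight vectors; the thin-SVD route to the PCA components is an implementation choice here, not a detail from the slides.

```python
import numpy as np

def pca_subspace(snapshots, k):
    """Build a k-dimensional subspace from SGD weight snapshots.

    snapshots : list of flattened weight vectors w_i collected along the
                constant-learning-rate SGD trajectory, each of shape (p,)
    Returns (w_hat, P) with w_hat the SWA shift and P of shape (p, k).
    """
    W = np.stack(snapshots)         # (M, p)
    w_hat = W.mean(axis=0)          # SWA solution: (1/M) * sum_i w_i
    D = w_hat - W                   # deviation vectors w_hat - w_i, shape (M, p)

    # Top-k principal directions of the (zero-mean) deviations via a thin SVD.
    _, _, Vt = np.linalg.svd(D, full_matrices=False)
    P = Vt[:k].T                    # (p, k): columns d_1, ..., d_k
    return w_hat, P

# Usage together with the model-averaging sketch above:
# w_hat, P = pca_subspace(snapshots, k=5)
# probs = subspace_bma_predict(w_hat, P, sample_t, predict_fn)
```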
Curve Subspace 31 / 41
Subspace Comparison (Regression) 32 / 41
Subspace Comparison (Classification)
Subspace Inference for Bayesian Deep Learning
P. Izmailov, W. Maddox, P. Kirichenko, T. Garipov, D. Vetrov, A.G. Wilson
33 / 41
Semi-Supervised Learning
◮ Make label predictions using structure from both unlabelled and labelled training data.
◮ Can quantify recent advances in unsupervised learning.
◮ Crucial for reducing the dependency of deep learning on large labelled datasets.
34 / 41
Semi-Supervised Learning
There Are Many Consistent Explanations of Unlabeled Data: Why You Should Average
B. Athiwaratkun, M. Finzi, P. Izmailov, A.G. Wilson, ICLR 2019
$$L(w_f) = \underbrace{\sum_{(x, y) \in \mathcal{D}_L} \ell_{\text{CE}}(w_f, x, y)}_{L_{\text{CE}}} + \lambda \underbrace{\sum_{x \in \mathcal{D}_L \cup \mathcal{D}_U} \ell_{\text{cons}}(w_f, x)}_{L_{\text{cons}}}$$
35 / 41
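A sketch of one common choice for the consistency term, a Π-model-style penalty between predictions under two stochastic perturbations; the Gaussian input noise and the combined objective below are illustrative, not the paper's exact setup.

```python
import torch
import torch.nn.functional as F

def consistency_loss(model, x, noise_std=0.1):
    """l_cons: penalize disagreement between predictions under two random perturbations."""
    p1 = F.softmax(model(x + noise_std * torch.randn_like(x)), dim=1)
    p2 = F.softmax(model(x + noise_std * torch.randn_like(x)), dim=1)
    return ((p1 - p2) ** 2).sum(dim=1).mean()

def semi_supervised_loss(model, x_l, y_l, x_u, lam=1.0):
    """L = L_CE over labelled data + lambda * L_cons over labelled and unlabelled data."""
    l_ce = F.cross_entropy(model(x_l), y_l)
    l_cons = consistency_loss(model, torch.cat([x_l, x_u], dim=0))
    return l_ce + lam * l_cons
```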
Semi-Supervised Learning World record results on semi-supervised vision benchmarks 36 / 41