Advanced Section #1: Moving averages, optimization algorithms, understanding dropout and batch normalization
AC 209B: Data Science 2
Javier Zazo, Pavlos Protopapas
Lecture Outline
◮ Moving averages
◮ Optimization algorithms
◮ Tuning the learning rate
◮ Gradient checking
◮ How to address overfitting
◮ Dropout
◮ Batch normalization
Moving averages
Moving averages
◮ Given a stationary process $x[n]$ and a sequence of observations $x_1, x_2, \ldots, x_n, \ldots$, we want to estimate the average of all values dynamically.
◮ We can use a moving average at instant $n$:
$$\bar{x}_{n+1} = \frac{1}{n}(x_1 + x_2 + \ldots + x_n)$$
◮ To save computation and memory:
$$\bar{x}_{n+1} = \frac{1}{n}\sum_{i=1}^{n} x_i = \frac{1}{n}\left(x_n + \sum_{i=1}^{n-1} x_i\right) = \frac{1}{n}\left(x_n + (n-1)\frac{1}{n-1}\sum_{i=1}^{n-1} x_i\right) = \frac{1}{n}\big(x_n + (n-1)\bar{x}_n\big) = \bar{x}_n + \frac{1}{n}(x_n - \bar{x}_n)$$
◮ Essentially, for $\alpha_n = 1/n$,
$$\bar{x}_{n+1} = \bar{x}_n + \alpha_n (x_n - \bar{x}_n)$$
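The incremental update with $\alpha_n = 1/n$ can be checked numerically against the plain average. A minimal sketch, assuming observations arrive one at a time as floats; `running_mean` is a hypothetical helper name:

```python
def running_mean(observations):
    """Return the estimates x_bar_2, ..., x_bar_{n+1} produced by the incremental rule."""
    x_bar = 0.0  # initial value is irrelevant: alpha_1 = 1 overwrites it
    estimates = []
    for n, x_n in enumerate(observations, start=1):
        alpha_n = 1.0 / n                        # dynamic step size alpha_n = 1/n
        x_bar = x_bar + alpha_n * (x_n - x_bar)  # x_bar_{n+1} = x_bar_n + alpha_n (x_n - x_bar_n)
        estimates.append(x_bar)                  # average of x_1..x_n, no stored history
    return estimates


data = [4.0, 8.0, 6.0, 2.0]
print(running_mean(data))  # [4.0, 6.0, 6.0, 5.0]
```

Only the current estimate and the counter $n$ are kept in memory, which is the saving the slide refers to.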
Weighted moving averages
◮ The previous step size $\alpha_n = 1/n$ is dynamic (it changes with $n$).
◮ From stochastic approximation theory, the estimate converges to the true value with probability 1 if
$$\sum_{i=1}^{\infty} \alpha_i = \infty \quad \text{and} \quad \sum_{i=1}^{\infty} \alpha_i^2 < \infty$$
◮ $\alpha_n = 1/n$ satisfies both conditions.
◮ A constant $\alpha$ does not satisfy the second one!
◮ A constant step size keeps giving weight to recent observations, which can be useful to track non-stationary processes.
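To see why a constant step size helps with non-stationary data, the sketch below runs the same update with $\alpha_n = 1/n$ and with a constant $\alpha = 0.1$ on a signal whose mean jumps halfway through. The signal and the value 0.1 are illustrative assumptions:

```python
def track(observations, step_size):
    """Apply x_bar <- x_bar + alpha_n * (x_n - x_bar), with alpha_n = step_size(n)."""
    x_bar = 0.0
    for n, x_n in enumerate(observations, start=1):
        x_bar += step_size(n) * (x_n - x_bar)
    return x_bar


# Hypothetical non-stationary signal: the mean jumps from 0 to 10 halfway through.
signal = [0.0] * 100 + [10.0] * 100

decreasing = track(signal, lambda n: 1.0 / n)  # alpha_n = 1/n: averages the whole history
constant = track(signal, lambda n: 0.1)        # constant alpha: old values fade out
print(decreasing, constant)  # about 5.0 and about 10.0
```

The $1/n$ estimate converges to the average over all time (5), while the constant step size follows the current level (10).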
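Exponentially weighted moving average
◮ The update rule for a constant step size is
$$\begin{aligned}
\bar{x}_{n+1} &= \bar{x}_n + \alpha(x_n - \bar{x}_n) = \alpha x_n + (1-\alpha)\bar{x}_n \\
&= \alpha x_n + (1-\alpha)\left[\alpha x_{n-1} + (1-\alpha)\bar{x}_{n-1}\right] \\
&= \alpha x_n + (1-\alpha)\alpha x_{n-1} + (1-\alpha)^2 \bar{x}_{n-1} \\
&= \alpha x_n + (1-\alpha)\alpha x_{n-1} + (1-\alpha)^2\alpha x_{n-2} + \ldots + (1-\alpha)^{n-1}\alpha x_1 + (1-\alpha)^n \bar{x}_1 \\
&= (1-\alpha)^n \bar{x}_1 + \sum_{i=1}^{n} \alpha(1-\alpha)^{n-i} x_i
\end{aligned}$$
◮ Note that $(1-\alpha)^n + \sum_{i=1}^{n}\alpha(1-\alpha)^{n-i} = 1$.
◮ With infinitely many terms, the contribution of $\bar{x}_1$ vanishes and we get
$$\lim_{n\to\infty} \bar{x}_{n+1} = \lim_{n\to\infty} \frac{x_n + (1-\alpha)x_{n-1} + (1-\alpha)^2 x_{n-2} + (1-\alpha)^3 x_{n-3} + \ldots}{1 + (1-\alpha) + (1-\alpha)^2 + (1-\alpha)^3 + \ldots}$$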
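The unrolled sum can be verified against the recursion directly. This sketch assumes arbitrary example values for the observations, $\alpha$ and $\bar{x}_1$; the function names are hypothetical:

```python
def ewma_recursive(xs, alpha, x_bar_1):
    """Iterate x_bar_{n+1} = alpha * x_n + (1 - alpha) * x_bar_n."""
    x_bar = x_bar_1
    for x in xs:
        x_bar = alpha * x + (1 - alpha) * x_bar
    return x_bar


def ewma_closed_form(xs, alpha, x_bar_1):
    """(1 - alpha)^n * x_bar_1 + sum_i alpha * (1 - alpha)^(n - i) * x_i."""
    n = len(xs)
    total = (1 - alpha) ** n * x_bar_1
    for i, x in enumerate(xs, start=1):
        total += alpha * (1 - alpha) ** (n - i) * x
    return total


def total_weight(n, alpha):
    """Weight on x_bar_1 plus the weights on x_1..x_n; the derivation says this is 1."""
    return (1 - alpha) ** n + sum(alpha * (1 - alpha) ** (n - i) for i in range(1, n + 1))


xs = [3.0, -1.0, 4.0, 1.0, 5.0]  # arbitrary example observations
print(ewma_recursive(xs, 0.3, 2.0), ewma_closed_form(xs, 0.3, 2.0), total_weight(len(xs), 0.3))
```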
Exponentially weighted moving average
◮ Recap of the update rule, with $\beta = 1 - \alpha$:
$$\bar{x}_{n+1} = \beta \bar{x}_n + (1-\beta) x_n$$
◮ $\beta$ controls the number of points effectively averaged (and hence the variance of the estimate):
◮ Rule of thumb: the last $N = \frac{1+\beta}{1-\beta}$ points carry about 86% of the total weight (since $\beta^N \approx e^{-2}$ for $\beta$ close to 1).
– $\beta = 0.9$ corresponds to 19 points.
– $\beta = 0.98$ corresponds to 99 points (wide window).
– $\beta = 0.5$ corresponds to 3 points (susceptible to outliers).
(Figure omitted; axes: points vs. values.)
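The rule of thumb can be checked by summing the EWMA weights $(1-\beta)\beta^k$ of the most recent $N$ inputs, which equals $1-\beta^N$. A short sketch; the helper names are hypothetical:

```python
def effective_window(beta):
    """Rule-of-thumb number of points N = (1 + beta) / (1 - beta)."""
    return (1 + beta) / (1 - beta)


def influence_of_last(beta, n_points):
    """Weight carried by the last n_points inputs: sum_{k<N} (1 - beta) * beta^k = 1 - beta^N."""
    return sum((1 - beta) * beta ** k for k in range(n_points))


for beta in (0.5, 0.9, 0.98):
    n = round(effective_window(beta))
    print(f"beta={beta}: N={n}, share of weight={influence_of_last(beta, n):.3f}")
```

For all three values of $\beta$ on the slide, the share lands between 86% and 88%.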
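Bias correction
◮ The rule of thumb works for sufficiently large $N$.
◮ Otherwise, the first values are biased towards the zero initialization.
◮ We can correct the bias with:
$$\bar{x}^{\text{corrected}} = \frac{\bar{x}}{1 - \beta^t},$$
where $t$ is the number of updates performed since the zero initialization.
(Figure omitted; axes: points vs. values.)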
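Bias correction II
◮ In practice, the bias correction can be ignored (Keras does not implement it).
◮ The bias originates from the zero initialization:
$$\bar{x}_{n+1} = \beta^n \underbrace{\bar{x}_1}_{0} + (1-\beta)\sum_{i=1}^{n} \beta^{n-i} x_i$$
◮ Derivation:
$$\begin{aligned}
\mathbb{E}[\bar{x}_{n+1}] &= \mathbb{E}\left[(1-\beta)\sum_{i=1}^{n}\beta^{n-i}x_i\right] \\
&= \mathbb{E}[x_n](1-\beta)\sum_{i=1}^{n}\beta^{n-i} + \zeta \\
&= \mathbb{E}[x_n](1-\beta^n) + \zeta
\end{aligned}$$
where $\zeta$ collects the error from $\mathbb{E}[x_i] \neq \mathbb{E}[x_n]$ and is zero for a stationary process. Dividing by $1-\beta^n$ therefore removes the bias.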
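For a constant input $c$ the raw estimate after $t$ updates is exactly $c(1-\beta^t)$, so the correction should recover $c$ from the first step. A minimal sketch with an illustrative constant input of 5 and $\beta = 0.9$:

```python
def ewma_with_correction(xs, beta):
    """Return raw and bias-corrected EWMA estimates, starting from zero."""
    v = 0.0  # zero initialization is the source of the bias
    raw, corrected = [], []
    for t, x in enumerate(xs, start=1):  # t = number of updates so far
        v = beta * v + (1 - beta) * x
        raw.append(v)
        corrected.append(v / (1 - beta ** t))  # divide by 1 - beta^t
    return raw, corrected


raw, corrected = ewma_with_correction([5.0] * 5, beta=0.9)
print(raw)        # starts near 0.5 and creeps towards 5
print(corrected)  # 5.0 (up to rounding) from the first step
```

Since $\beta^t \to 0$, the correction factor approaches 1 after a few dozen steps, which is why it matters mainly at the beginning.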
Optimization algorithms
Gradient descent
◮ Gradient descent has high variance if the problem is ill-conditioned.
◮ The aim is to estimate the directions of high variance and reduce their influence.
◮ Gradient descent with momentum, RMSprop, and Adam help reduce the variance and speed up convergence.
Gradient descent with momentum
◮ The algorithm:
1: On iteration $t$ for $W$ update:
2: Compute $dW$ on the current mini-batch.
3: $v_{dW} = \beta v_{dW} + (1-\beta)\, dW$.
4: $W = W - \alpha v_{dW}$.
◮ Gradient descent with momentum performs an exponential moving average over the gradients.
◮ This reduces the variance and gives more stable descent directions.
◮ Bias correction is usually not applied.
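The four steps above can be sketched on a hypothetical ill-conditioned quadratic $f(w) = \tfrac{1}{2}(w_0^2 + 100\,w_1^2)$, with exact gradients standing in for mini-batch gradients. The values $\alpha = 0.1$ and $\beta = 0.9$ are illustrative; with this $\alpha$ plain gradient descent oscillates and diverges along the steep direction, while the averaged gradient does not:

```python
def grad(w):
    # Gradient of the hypothetical quadratic f(w) = 0.5 * (w0^2 + 100 * w1^2)
    return [w[0], 100.0 * w[1]]


def gradient_descent(w, alpha, iters):
    w = list(w)
    for _ in range(iters):
        dw = grad(w)
        w = [wi - alpha * gi for wi, gi in zip(w, dw)]
    return w


def momentum(w, alpha, beta, iters):
    w = list(w)
    v = [0.0] * len(w)  # v_dW starts at zero; no bias correction, as on the slide
    for _ in range(iters):
        dw = grad(w)    # exact gradient stands in for a mini-batch gradient
        v = [beta * vi + (1 - beta) * gi for vi, gi in zip(v, dw)]  # EWMA of gradients
        w = [wi - alpha * vi for wi, vi in zip(w, v)]
    return w


w_gd = gradient_descent([1.0, 1.0], alpha=0.1, iters=200)
w_mom = momentum([1.0, 1.0], alpha=0.1, beta=0.9, iters=200)
print(w_gd)   # the steep coordinate blows up: |1 - 0.1 * 100| = 9 > 1
print(w_mom)  # both coordinates close to 0
```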
RMSprop
◮ The algorithm:
1: On iteration $t$ for $W$ update:
2: Compute $dW$ on the current mini-batch.
3: $s_{dW} = \beta_2 s_{dW} + (1-\beta_2)\, dW^2$ (element-wise square).
4: $W = W - \alpha \frac{dW}{\sqrt{s_{dW}} + \epsilon}$.
◮ $\epsilon = 10^{-8}$ ensures numerical stability.
◮ Directions with high-variance gradients have large squared averages $s_{dW}$, which reduces the step size along them.
◮ This allows a higher learning rate and hence faster convergence.
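A sketch of the RMSprop steps on the same kind of hypothetical quadratic, $f(w) = \tfrac{1}{2}(w_0^2 + 100\,w_1^2)$. The values $\alpha = 0.01$ and $\beta_2 = 0.9$ are illustrative choices. Dividing by $\sqrt{s_{dW}}$ rescales each coordinate, so the first step has the same length in both coordinates even though their gradients differ by a factor of 100:

```python
import math


def grad(w):
    # Gradient of the hypothetical quadratic f(w) = 0.5 * (w0^2 + 100 * w1^2)
    return [w[0], 100.0 * w[1]]


def rmsprop(w, alpha=0.01, beta2=0.9, eps=1e-8, iters=1000):
    w = list(w)
    s = [0.0] * len(w)  # running average of squared gradients, s_dW
    for _ in range(iters):
        dw = grad(w)
        s = [beta2 * si + (1 - beta2) * gi * gi for si, gi in zip(s, dw)]
        # Per-coordinate step: large accumulated squares shrink the step
        w = [wi - alpha * gi / (math.sqrt(si) + eps) for wi, gi, si in zip(w, dw, s)]
    return w


first = rmsprop([1.0, 1.0], iters=1)
print(first)                # both coordinates move by the same amount
print(rmsprop([1.0, 1.0]))  # both coordinates end close to the minimum at 0
```

With a constant $\alpha$ the iterates settle into a small neighbourhood of the minimum (of order $\alpha$) rather than converging exactly.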
Adaptive moment estimation (Adam)
◮ The algorithm:
1: On iteration $t$ for $W$ update:
2: Compute $dW$ on the current mini-batch.
3: $v_{dW} = \beta_1 v_{dW} + (1-\beta_1)\, dW$.
4: $s_{dW} = \beta_2 s_{dW} + (1-\beta_2)\, dW^2$.
5: $v^{\text{corrected}} = \frac{v_{dW}}{1-\beta_1^t}$.
6: $s^{\text{corrected}} = \frac{s_{dW}}{1-\beta_2^t}$.
7: $W = W - \alpha \frac{v^{\text{corrected}}}{\sqrt{s^{\text{corrected}}} + \epsilon}$.
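The seven Adam steps can be sketched as follows, again on a hypothetical quadratic $f(w) = \tfrac{1}{2}(w_0^2 + 100\,w_1^2)$ with illustrative hyperparameters $\alpha = 0.01$, $\beta_1 = 0.9$, $\beta_2 = 0.999$. Because of the bias correction, the very first update has length about $\alpha$ in every coordinate, regardless of the gradient scale:

```python
import math


def grad(w):
    # Gradient of the hypothetical quadratic f(w) = 0.5 * (w0^2 + 100 * w1^2)
    return [w[0], 100.0 * w[1]]


def adam(w, alpha=0.01, beta1=0.9, beta2=0.999, eps=1e-8, iters=300):
    w = list(w)
    v = [0.0] * len(w)  # first moment: EWMA of gradients
    s = [0.0] * len(w)  # second moment: EWMA of squared gradients
    for t in range(1, iters + 1):
        dw = grad(w)
        v = [beta1 * vi + (1 - beta1) * gi for vi, gi in zip(v, dw)]
        s = [beta2 * si + (1 - beta2) * gi * gi for si, gi in zip(s, dw)]
        v_corr = [vi / (1 - beta1 ** t) for vi in v]  # step 5: bias correction
        s_corr = [si / (1 - beta2 ** t) for si in s]  # step 6: bias correction
        w = [wi - alpha * vc / (math.sqrt(sc) + eps) for wi, vc, sc in zip(w, v_corr, s_corr)]
    return w


first = adam([1.0, 1.0], iters=1)
print(first)               # both coordinates moved by about alpha = 0.01
print(adam([1.0, 1.0]))    # both coordinates much closer to the minimum at 0
```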
AMSGrad
◮ Adam and RMSprop fail to converge on certain convex problems.
◮ The reason is that some important descent directions are weakened by large second-order moment estimates.
◮ AMSGrad proposes a conservative fix in which the second-order moment estimator can only increase.
◮ The algorithm:
1: On iteration $t$ for $W$ update:
2: Compute $dW$ on the current mini-batch.
3: $v^{n+1}_{dW} = \beta_1 v^n_{dW} + (1-\beta_1)\, dW$.
4: $s^{n+1}_{dW} = \beta_2 s^n_{dW} + (1-\beta_2)\, dW^2$.
5: $\hat{s}^{n+1}_{dW} = \max(\hat{s}^n_{dW}, s^{n+1}_{dW})$.
6: $W = W - \alpha \frac{v^{\text{corrected}}}{\sqrt{\hat{s}^{n+1}_{dW}} + \epsilon}$.
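The effect of the max in step 5 is easiest to see on a hypothetical gradient sequence with one large spike followed by small gradients. In this sketch (a single scalar parameter; the first moment is bias-corrected as in step 6 above, the max-tracked second moment is not), $s$ decays after the spike while $\hat{s}$ stays at its peak:

```python
import math


def amsgrad_trace(grads, alpha=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
    """Run AMSGrad on a fixed gradient sequence and record s and s_hat."""
    w, v, s, s_hat = 0.0, 0.0, 0.0, 0.0
    s_hist, s_hat_hist = [], []
    for t, dw in enumerate(grads, start=1):
        v = beta1 * v + (1 - beta1) * dw
        s = beta2 * s + (1 - beta2) * dw ** 2
        s_hat = max(s_hat, s)              # second-moment estimate can only grow
        v_corr = v / (1 - beta1 ** t)
        w -= alpha * v_corr / (math.sqrt(s_hat) + eps)
        s_hist.append(s)
        s_hat_hist.append(s_hat)
    return w, s_hist, s_hat_hist


grads = [10.0] + [0.1] * 50  # hypothetical: one large gradient, then small ones
w, s_hist, s_hat_hist = amsgrad_trace(grads)
print(s_hist[0], s_hist[-1], s_hat_hist[-1])
```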
Marginal value of adaptive gradient methods
Tuning the learning rate
Cyclical Learning Rates for Neural Networks
◮ Use cyclical learning rates to escape local extreme points.
◮ Saddle points are abundant in high dimensions, and convergence near them becomes very slow. Furthermore, cyclical learning rates can help escape sharp local minima (associated with overfitting).
◮ Cyclical learning rates raise the learning rate periodically: this has a short-term negative effect but achieves a longer-term beneficial effect.
◮ Decreasing learning rates may still help reduce the error towards the end of training.
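The slide does not fix a particular cycle shape. This sketch assumes the triangular policy (linear ramp up, linear ramp down), which is one common choice; `base_lr`, `max_lr` and `step_size` (half a cycle, in iterations) are illustrative parameter names:

```python
import math


def triangular_lr(iteration, base_lr, max_lr, step_size):
    """Triangular cyclical LR: base_lr -> max_lr -> base_lr every 2 * step_size iterations."""
    cycle = math.floor(1 + iteration / (2 * step_size))
    x = abs(iteration / step_size - 2 * cycle + 1)  # distance from the peak, in [0, 1]
    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x)


schedule = [triangular_lr(i, base_lr=0.001, max_lr=0.006, step_size=4) for i in range(17)]
print([round(lr, 5) for lr in schedule])
```

The learning rate peaks at iterations 4 and 12 and returns to the base value at 0, 8 and 16.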
Estimating the learning rate
◮ How can we get a good LR estimate?
◮ Start with a small LR and increase it exponentially on every batch.
◮ Simultaneously, compute the loss function on the validation set.
◮ This also works for finding the bounds for cyclical LRs.
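A toy version of this range test: the learning rate grows exponentially from `lr_min` to `lr_max` over the batches, one update is made per batch, and the loss is recorded after each update. The one-parameter loss $(w-3)^2$ is a hypothetical stand-in for a network and its validation loss; all names and bounds are illustrative:

```python
def lr_range_test(lr_min=1e-4, lr_max=10.0, num_batches=50):
    """Grow the LR exponentially each batch and record the loss after each update."""
    w = 0.0  # hypothetical single parameter
    lrs, losses = [], []
    for k in range(num_batches):
        lr = lr_min * (lr_max / lr_min) ** (k / (num_batches - 1))  # exponential sweep
        grad = 2.0 * (w - 3.0)         # gradient of the toy loss (w - 3)^2
        w -= lr * grad
        lrs.append(lr)
        losses.append((w - 3.0) ** 2)  # stands in for the validation loss
    return lrs, losses


lrs, losses = lr_range_test()
best = min(range(len(losses)), key=losses.__getitem__)
print(f"loss is lowest at lr={lrs[best]:.3g} and blows up afterwards")
```

The loss keeps improving until the step becomes unstable (here $\text{lr} > 1$) and then explodes; a reasonable LR, or the upper bound of a cyclical range, lies somewhat below that point.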
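SGD with Warm Restarts
◮ Key idea: restart every $T_i$ epochs. Record the best estimates before each restart.
◮ Restarts do not start from scratch but from the last estimate, and the learning rate is increased:
$$\alpha_t = \alpha^i_{\min} + \frac{1}{2}\left(\alpha^i_{\max} - \alpha^i_{\min}\right)\left(1 + \cos\left(\frac{T_{cur}}{T_i}\pi\right)\right),$$
where $T_{cur}$ counts the epochs since the last restart.
◮ The cycle can be lengthened over time.
◮ $\alpha^i_{\min}$ and $\alpha^i_{\max}$ can be decayed after each cycle.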
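The schedule above can be sketched per epoch. This version assumes integer epochs, fixed $\alpha_{\min}$ and $\alpha_{\max}$ (no decay), and a hypothetical multiplier `t_mult` that lengthens each new cycle:

```python
import math


def sgdr_schedule(num_epochs, alpha_min, alpha_max, t_0, t_mult=1):
    """Per-epoch learning rates with cosine annealing and warm restarts."""
    lrs = []
    t_i, t_cur = t_0, 0  # current cycle length and epochs since the last restart
    for _ in range(num_epochs):
        lrs.append(alpha_min + 0.5 * (alpha_max - alpha_min) * (1 + math.cos(math.pi * t_cur / t_i)))
        t_cur += 1
        if t_cur >= t_i:  # end of cycle: warm restart
            t_cur = 0
            t_i *= t_mult  # optionally lengthen the next cycle
    return lrs


lrs = sgdr_schedule(num_epochs=13, alpha_min=0.0, alpha_max=0.1, t_0=4, t_mult=2)
print([round(lr, 4) for lr in lrs])
```

With $T_0 = 4$ and a doubling factor, restarts happen at epochs 4 and 12.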
Snapshot ensembles: Train 1, get M for free
◮ Ensembles of networks are much more robust and accurate than individual networks.
◮ They constitute another type of regularization technique.
◮ The novelty is to train a single neural network but obtain $M$ different models.
◮ The idea is to converge to $M$ different local optima and save the network parameters at each of them.
Snapshot ensembles II
◮ Different initialization points or hyperparameter choices may converge to different local minima.
◮ Although these local minima may perform similarly in terms of average error, they may not make the same mistakes.
◮ Ensemble methods train many NNs and then combine them through majority vote or by averaging the prediction outputs.
◮ The proposal uses a cyclical (cosine) step-size procedure, in which the learning rate is abruptly raised and the network is then left to converge again.
◮ The final ensemble consists of snapshots of the optimization path.
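A sketch of the bookkeeping around this procedure. The slide does not give the exact schedule, so this assumes the learning rate is cosine-annealed from $\alpha_0$ towards 0 within each of $M$ equal-length cycles and then abruptly reset, with a snapshot saved at the last iteration of each cycle and the snapshots' class probabilities averaged. All function names are hypothetical, and the training step itself is omitted:

```python
import math


def snapshot_lr(t, total_iters, num_cycles, alpha_0):
    """Cosine-annealed LR that restarts at alpha_0 at the start of each cycle."""
    cycle_len = math.ceil(total_iters / num_cycles)
    pos = t % cycle_len  # iterations since the last restart
    return alpha_0 / 2 * (math.cos(math.pi * pos / cycle_len) + 1)


def snapshot_iterations(total_iters, num_cycles):
    """Iterations at the end of each cycle, where parameters would be saved."""
    cycle_len = math.ceil(total_iters / num_cycles)
    return [t for t in range(total_iters) if (t + 1) % cycle_len == 0]


def ensemble_average(predictions):
    """Average the class-probability vectors produced by the saved snapshots."""
    m = len(predictions)
    return [sum(column) / m for column in zip(*predictions)]


print(snapshot_iterations(total_iters=300, num_cycles=3))  # [99, 199, 299]
print(ensemble_average([[0.2, 0.8], [0.6, 0.4], [0.4, 0.6]]))
```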
Snapshot ensembles III
Gradient checking
Gradient checking
◮ A useful technique to debug manual implementations of neural networks.
◮ It is not intended for training networks, but it can help identify errors in a backpropagation implementation.
◮ Derivative of a function:
$$f'(x) = \lim_{\epsilon \to 0} \frac{f(x+\epsilon) - f(x-\epsilon)}{2\epsilon} \approx \frac{f(x+\epsilon) - f(x-\epsilon)}{2\epsilon}$$
◮ The approximation error is of order $O(\epsilon^2)$.
◮ In the multivariate case, the $\epsilon$ term affects a single component:
$$\frac{\partial f(\theta)}{\partial \theta_r} \approx \frac{f(\theta^{+}_r) - f(\theta^{-}_r)}{2\epsilon},$$
where $\theta^{+}_r = (\theta_1, \ldots, \theta_r + \epsilon, \ldots, \theta_n)$ and $\theta^{-}_r = (\theta_1, \ldots, \theta_r - \epsilon, \ldots, \theta_n)$.
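The $O(\epsilon^2)$ claim can be checked on $f(x) = \sin x$, whose derivative $\cos x$ is known. Halving $\epsilon$ should divide the two-sided error by about 4, while a one-sided difference (shown only for comparison) is much less accurate:

```python
import math


def central_diff(f, x, eps):
    return (f(x + eps) - f(x - eps)) / (2 * eps)  # two-sided, error O(eps^2)


def forward_diff(f, x, eps):
    return (f(x + eps) - f(x)) / eps  # one-sided, error O(eps)


x = 1.0
exact = math.cos(x)  # d/dx sin(x)
errors = {}
for eps in (1e-2, 5e-3):
    errors[eps] = (abs(central_diff(math.sin, x, eps) - exact),
                   abs(forward_diff(math.sin, x, eps) - exact))
    print(eps, errors[eps])
```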
Algorithm for gradient checking
1: Reshape the input vector into a column vector $\theta$.
2: for each component $r$ do
3: $\theta_{old} \leftarrow \theta_r$
4: Calculate $f(\theta^{+}_r)$ and $f(\theta^{-}_r)$.
5: Compute the approximation of $\frac{\partial f(\theta)}{\partial \theta_r}$.
6: Restore $\theta_r \leftarrow \theta_{old}$.
7: end for
8: Verify that the relative error is below some threshold:
$$\xi = \frac{\|d\theta_{approx} - d\theta\|}{\|d\theta_{approx}\| + \|d\theta\|}$$
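The algorithm translates almost line by line into code. This sketch uses a plain list as the flattened $\theta$ and a hypothetical test function $f(\theta) = \theta_0^2 + 3\theta_0\theta_1 + \sin\theta_2$, checked once against a correct gradient and once against a deliberately buggy one:

```python
import math


def grad_check(f, analytic_grad, theta, eps=1e-7):
    """Return the relative error xi between numerical and analytic gradients."""
    theta = list(theta)  # flat parameter vector
    approx = []
    for r in range(len(theta)):
        old = theta[r]
        theta[r] = old + eps
        f_plus = f(theta)   # f(theta_r^+)
        theta[r] = old - eps
        f_minus = f(theta)  # f(theta_r^-)
        theta[r] = old      # restore component r
        approx.append((f_plus - f_minus) / (2 * eps))
    exact = analytic_grad(theta)
    diff = math.sqrt(sum((a - e) ** 2 for a, e in zip(approx, exact)))
    denom = math.sqrt(sum(a * a for a in approx)) + math.sqrt(sum(e * e for e in exact))
    return diff / denom if denom > 0 else 0.0


def f(t):
    return t[0] ** 2 + 3 * t[0] * t[1] + math.sin(t[2])


def good_grad(t):
    return [2 * t[0] + 3 * t[1], 3 * t[0], math.cos(t[2])]


def buggy_grad(t):
    return [2 * t[0], 3 * t[0], math.cos(t[2])]  # forgot the 3 * t[1] term


theta = [1.0, -2.0, 0.5]
print(grad_check(f, good_grad, theta))   # tiny, well below 1e-6
print(grad_check(f, buggy_grad, theta))  # large, flags the bug
```

The threshold for $\xi$ is a judgment call; this example only shows that a correct gradient yields a value many orders of magnitude smaller than a buggy one.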
How to address overfitting
Estimators
◮ Point estimation is the attempt to provide the single "best" prediction of some quantity of interest:
$$\hat{\theta}_m = g(x^{(1)}, \ldots, x^{(m)})$$
– $\theta$: true value.
– $\hat{\theta}_m$: estimator for $m$ samples.
◮ Frequentist perspective: $\theta$ is fixed but unknown.
◮ The data is random $\Rightarrow$ $\hat{\theta}_m$ is a random variable.
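Bias and Variance
◮ Bias: expected deviation from the true value, $\operatorname{bias}(\hat{\theta}_m) = \mathbb{E}[\hat{\theta}_m] - \theta$.
◮ Variance: deviation from the expected estimator, $\operatorname{Var}(\hat{\theta}_m) = \mathbb{E}\big[(\hat{\theta}_m - \mathbb{E}[\hat{\theta}_m])^2\big]$.
Examples:
– Sample mean: $\hat{\mu}_m = \frac{1}{m}\sum_i x^{(i)}$.
– Sample variance: $\hat{\sigma}^2_m = \frac{1}{m}\sum_i \left(x^{(i)} - \hat{\mu}_m\right)^2$, with $\mathbb{E}[\hat{\sigma}^2_m] = \frac{m-1}{m}\sigma^2$.
– Unbiased sample variance: $\tilde{\sigma}^2_m = \frac{1}{m-1}\sum_i \left(x^{(i)} - \hat{\mu}_m\right)^2$.
◮ How do we choose between estimators with different statistics?
– Mean squared error (MSE).
– Cross-validation: empirical.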
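The factor $\frac{m-1}{m}$ can be verified exactly by enumerating every possible sample of a small discrete distribution. This sketch assumes a fair coin with outcomes 0 and 1 (true variance $\sigma^2 = 1/4$), i.i.d. draws and $m = 3$, and uses exact fractions to avoid rounding:

```python
from fractions import Fraction
from itertools import product


def expected_variance_estimates(values, m):
    """Exact expectations of the 1/m and 1/(m-1) variance estimators.

    Each value in `values` is assumed equally likely and draws are i.i.d.
    """
    samples = list(product(values, repeat=m))
    biased_total = Fraction(0)
    unbiased_total = Fraction(0)
    for sample in samples:
        mu_hat = Fraction(sum(sample), m)
        ss = sum((Fraction(x) - mu_hat) ** 2 for x in sample)
        biased_total += ss / m          # hat sigma^2_m
        unbiased_total += ss / (m - 1)  # tilde sigma^2_m
    n = len(samples)
    return biased_total / n, unbiased_total / n


# Fair coin with outcomes 0 and 1: true variance sigma^2 = 1/4.
biased, unbiased = expected_variance_estimates([0, 1], m=3)
print(biased, unbiased)  # 1/6 1/4, i.e. (m-1)/m * sigma^2 and sigma^2
```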
Bias-Variance Example
(Figure omitted; panels: high bias & underfitting; appropriate; high variance & overfitting; high bias & variance.)