Asymptotic Analysis of the LMS Algorithm with Momentum

László Gerencsér¹, Balázs Csanád Csáji¹, Sotirios Sabanis²

¹ Institute for Computer Science and Control (SZTAKI), Hungarian Academy of Sciences (MTA), Hungary
² School of Mathematics, University of Edinburgh, UK, and Alan Turing Institute, London, UK

57th IEEE CDC, Miami Beach, Florida, December 18, 2018
Introduction

– Stochastic gradient descent (SGD) methods are popular stochastic approximation (SA) algorithms applied in a wide variety of fields.
– Here, we focus on the special case of least mean square (LMS).
– Polyak's momentum is an acceleration technique for gradient methods which has several advantages for deterministic problems.
– K. Yuan, B. Ying and A. H. Sayed (2016) argued that in the stochastic case it is "equivalent" to standard SGD, assuming fixed gains, strongly convex functions and martingale difference noises.
– For LMS, they assumed independent noises to ensure this.
– Here, we provide a significantly simpler asymptotic analysis of LMS with momentum for stationary, ergodic and mixing signals.
– We present weak convergence results and explore the trade-off between the rate of convergence and the asymptotic covariance.
Stochastic Approximation with Fixed Gain

Stochastic Approximation (SA) with Fixed Gain

θ_{n+1} = θ_n + μ H(θ_n, X_{n+1})

(next estimate = current estimate + fixed gain × update operator)

◦ θ_n ∈ ℝ^d is the estimate at time n.
◦ X_n ∈ ℝ^k is the new data available at time n.
◦ μ ∈ [0, ∞) is the fixed gain or step-size.
◦ H : ℝ^d × ℝ^k → ℝ^d is the update operator.

(SA algorithms are typically applied to find roots, fixed points or extrema of functions we only observe at given points with noise.)
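To make the recursion concrete, here is a minimal sketch of the fixed-gain SA iteration; the update operator H, the data stream and the gain value are placeholders supplied by the user, not quantities from the slides.

```python
import numpy as np

def sa_fixed_gain(H, theta0, data, mu=0.01):
    """Fixed-gain SA: theta_{n+1} = theta_n + mu * H(theta_n, X_{n+1})."""
    theta = np.asarray(theta0, dtype=float)
    for x_next in data:                       # X_{n+1}: new observation arriving at step n+1
        theta = theta + mu * H(theta, x_next)
    return theta
```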
Stochastic Gradient Descent

– We want to minimize an unknown function, f : ℝ^d → ℝ, based only on noisy queries about its gradient, ∇f, at selected points.

Stochastic Gradient Descent (SGD)

θ_{n+1} ≐ θ_n + μ (−∇_θ f(θ_n) + ε_n)

– Polyak's heavy-ball or momentum method is defined as

SGD with Momentum Acceleration

θ_{n+1} = θ_n + μ (−∇_θ f(θ_n) + ε_n) + γ (θ_n − θ_{n−1})

– The added term acts both as a smoother and an accelerator. (The extra momentum dampens oscillations and helps us get through narrow valleys, small humps and local minima.)
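As a quick illustration (not part of the slides), here is a minimal sketch of heavy-ball SGD with noisy gradient queries; the callable grad, the noise level and the gain values are illustrative placeholders.

```python
import numpy as np

def sgd_momentum(grad, theta0, mu=0.01, gamma=0.9, noise_std=0.1, n_steps=1000, rng=None):
    """Heavy-ball SGD: theta_{n+1} = theta_n + mu*(-grad(theta_n) + eps_n) + gamma*(theta_n - theta_prev)."""
    rng = np.random.default_rng() if rng is None else rng
    theta_prev = theta = np.asarray(theta0, dtype=float)
    for _ in range(n_steps):
        eps = noise_std * rng.standard_normal(theta.shape)   # noisy gradient query
        theta_next = theta + mu * (-grad(theta) + eps) + gamma * (theta - theta_prev)
        theta_prev, theta = theta, theta_next
    return theta
```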
Mean-Square Optimal Linear Filter

– [C0] Assume we observe a (strictly) stationary and ergodic stochastic process consisting of input-output pairs {(x_t, y_t)}, where the regressor (input) x_t is ℝ^d-valued, while the output y_t is ℝ-valued.
– We want to find the mean-square optimal linear filter coefficients

θ* ≐ arg min_{θ ∈ ℝ^d} E[ ½ (y_n − x_nᵀ θ)² ]

– Using R* ≐ E[x_n x_nᵀ] and b ≐ E[x_n y_n], the optimal solution is

Wiener-Hopf Equation

R* θ* = b   ⟹   θ* = R*⁻¹ b

– [C1] Assume that R* is non-singular, thus θ* is uniquely defined.
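A minimal numerical sketch of the Wiener-Hopf solution, replacing the expectations R* = E[x xᵀ] and b = E[x y] with sample averages over observed pairs; the function and variable names are illustrative assumptions.

```python
import numpy as np

def wiener_hopf(x, y):
    """Estimate R* = E[x xᵀ] and b = E[x y] by sample means, then solve R* θ* = b."""
    n = len(x)
    R = x.T @ x / n               # d×d sample second-moment matrix of the regressors
    b = x.T @ y / n               # d-dimensional sample cross moment
    return np.linalg.solve(R, b)  # θ* = R*⁻¹ b, requires R* non-singular (condition C1)
```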
Least Mean Square

– The least mean square (LMS) algorithm is an SGD method

Least Mean Square (LMS)

θ_{n+1} ≐ θ_n + μ x_{n+1} (y_{n+1} − x_{n+1}ᵀ θ_n)

with μ > 0 and some constant (non-random) initial condition θ_0.
– Introducing the observation and (coefficient) estimation errors as

v_n ≐ y_n − x_nᵀ θ*   and   Δ_n ≐ θ_n − θ*,

the estimation error process, {Δ_n}, follows the dynamics

Δ_{n+1} = Δ_n − μ x_{n+1} x_{n+1}ᵀ Δ_n + μ x_{n+1} v_{n+1}

with Δ_0 ≐ θ_0 − θ*. Note that E[x_n v_n] = 0 for all n ≥ 0.
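The LMS recursion above translates directly into code; the following short sketch (array shapes and names are assumptions, not from the slides) runs it over a recorded sequence of input-output pairs.

```python
import numpy as np

def lms(x, y, mu, theta0):
    """Plain LMS: theta_{n+1} = theta_n + mu * x_{n+1} * (y_{n+1} - x_{n+1}ᵀ theta_n)."""
    theta = np.array(theta0, dtype=float)
    for x_n, y_n in zip(x, y):                         # stream of (regressor, output) pairs
        theta = theta + mu * x_n * (y_n - x_n @ theta)
    return theta
```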
The Associated ODE

– A standard tool for the analysis of SA methods is the associated ordinary differential equation (ODE). In the LMS case (for t ≥ 0)

d θ̄_t / dt = h(θ̄_t) = b − R* θ̄_t,   with θ̄_0 = θ_0,

where h(θ) ≐ E[x_{n+1} (y_{n+1} − x_{n+1}ᵀ θ)] is the mean update for θ.
– A piecewise constant extension of {θ_n} is defined as θᶜ_t ≐ θ_{[t]} (note that here [t] denotes the integer part of t).
– LMS is modified by taking a truncation domain D, where D is the interior of a compact set; then we apply the stopping time τ ≐ inf{ t : θᶜ_t ∉ D }.
– [C2] We assume that the truncation domain is such that the solution of the ODE defined above does not leave D.
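For intuition, the mean ODE can be integrated numerically; below is a simple Euler sketch (the step size and names are illustrative choices, not part of the analysis).

```python
import numpy as np

def lms_mean_ode(R, b, theta0, t_end, dt=1e-3):
    """Euler integration of the mean ODE  dθ̄/dt = b − R* θ̄,  θ̄(0) = θ_0."""
    theta = np.array(theta0, dtype=float)
    for _ in range(int(t_end / dt)):
        theta = theta + dt * (b - R @ theta)   # deterministic mean-field update
    return theta
```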
The Error of the ODE

– Let us define the following error processes for the mean ODE

θ̃_n ≐ θ_n − θ̄_n   and   θ̃ᶜ_t ≐ θᶜ_t − θ̄_t

– The normalized and time-scaled version of the ODE error is

V_t(μ) ≐ μ^{−1/2} θ̃ᶜ_{(t∧τ)/μ} = μ^{−1/2} θ̃_{[(t∧τ)/μ]}

– We will also need the asymptotic covariance matrices of the empirical means of the centered correction terms, given by

S(θ) ≐ Σ_{k=−∞}^{+∞} E[ (H_k(θ) − h(θ)) (H_0(θ) − h(θ))ᵀ ]

where H_n(θ) ≐ x_n (y_n − x_nᵀ θ); this series converges, for example, under various mixing conditions (this will be ensured by [C3]).
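In practice S(θ) can be approximated by truncating the doubly infinite sum at a finite lag; a rough sketch (the truncation lag and names are assumptions) is given below.

```python
import numpy as np

def estimate_S(x, y, theta, max_lag=50):
    """Truncated estimate of S(θ) = Σ_k E[(H_k(θ)−h(θ))(H_0(θ)−h(θ))ᵀ], with H_n(θ) = x_n (y_n − x_nᵀ θ)."""
    H = x * (y - x @ theta)[:, None]      # n×d array of correction terms H_n(θ)
    Hc = H - H.mean(axis=0)               # center by the empirical mean (estimate of h(θ))
    n = len(Hc)
    S = Hc.T @ Hc / n                     # lag-0 term
    for k in range(1, max_lag + 1):
        C = Hc[k:].T @ Hc[:-k] / n        # lag-k sample autocovariance
        S += C + C.T                      # add the symmetric pair for lags ±k
    return S
```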
Weak Convergence for LMS

– [C3] We assume that the process defined by

L_t(μ) ≐ √μ Σ_{n=0}^{[t/μ]−1} ( H_n(θ̄_{μn}) − h(θ̄_{μn}) )

converges weakly, as μ → 0, to a time-inhomogeneous zero-mean Brownian motion {L_t} with local covariances {S(θ̄_t)}.

Theorem 1: Weak Convergence for LMS

Under conditions C0, C1, C2 and C3, the process {V_t(μ)} converges weakly, as μ → 0, to a process {Z_t} satisfying the following linear stochastic differential equation (SDE), for t ≥ 0, with Z_0 = 0,

dZ_t = −R* Z_t dt + S^{1/2}(θ̄_t) dW_t

where {W_t} is a standard Brownian motion in ℝ^d.
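The limit SDE can be simulated to visualize the asymptotic error fluctuations; the Euler-Maruyama sketch below freezes the diffusion matrix at a constant value, an assumption made only for illustration.

```python
import numpy as np

def simulate_limit_sde(R, S_half, t_end, dt=1e-3, rng=None):
    """Euler-Maruyama sketch of dZ_t = −R* Z_t dt + S^{1/2} dW_t with a constant diffusion matrix."""
    rng = np.random.default_rng() if rng is None else rng
    d = R.shape[0]
    Z, path = np.zeros(d), []
    for _ in range(int(t_end / dt)):
        dW = np.sqrt(dt) * rng.standard_normal(d)   # Brownian increment over [t, t + dt]
        Z = Z - R @ Z * dt + S_half @ dW
        path.append(Z.copy())
    return np.array(path)
```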
Momentum LMS

LMS with Momentum Acceleration

θ_{n+1} ≐ θ_n + μ x_{n+1} (y_{n+1} − x_{n+1}ᵀ θ_n) + γ (θ_n − θ_{n−1})

with μ > 0, 0 < γ < 1, and some non-random θ_0 = θ_{−1}.

– The filter coefficient errors now follow a 2nd order dynamics

Δ_{n+1} = Δ_n − μ x_{n+1} x_{n+1}ᵀ Δ_n + μ x_{n+1} v_{n+1} + γ (Δ_n − Δ_{n−1})

with Δ_0 = Δ_{−1} (recall that Δ_n ≐ θ_n − θ* and v_n ≐ y_n − x_nᵀ θ*).
– To handle higher-order dynamics, we can use a state-vector,

U_n ≐ [ Δ_n ; Δ_{n−1} ]
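The momentum recursion is a two-line change to the plain LMS loop; the sketch below (names are illustrative) keeps the previous iterate to form the heavy-ball term.

```python
import numpy as np

def momentum_lms(x, y, mu, gamma, theta0):
    """LMS with momentum: θ_{n+1} = θ_n + μ x_{n+1}(y_{n+1} − x_{n+1}ᵀ θ_n) + γ(θ_n − θ_{n−1})."""
    theta_prev = theta = np.array(theta0, dtype=float)   # θ_0 = θ_{−1}
    for x_n, y_n in zip(x, y):
        theta_next = theta + mu * x_n * (y_n - x_n @ theta) + gamma * (theta - theta_prev)
        theta_prev, theta = theta, theta_next
    return theta
```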
State-Space Form for Momentum LMS

– Using U_n ≐ [Δ_n, Δ_{n−1}]ᵀ, the state-space dynamics becomes

U_{n+1} = U_n + A_{n+1} U_n + μ W_{n+1},

A_{n+1} ≐ [ γI − μ x_{n+1} x_{n+1}ᵀ , −γI ; I , −I ],   W_{n+1} ≐ [ x_{n+1} v_{n+1} ; 0 ]

– This, however, does not have the canonical form of SA methods.
– We apply a state-space transformation by Yuan, Ying and Sayed,

T = T(γ) ≐ 1/(1−γ) [ I , −γI ; I , −I ],
T⁻¹ = T⁻¹(γ) = [ I , −γI ; I , −I ]
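As a sanity check, assuming the block forms of A, W and T written above, the matrices can be assembled with NumPy and the pair (T, T⁻¹) verified to be mutually inverse; all names here are illustrative.

```python
import numpy as np

def momentum_state_matrices(x_next, v_next, mu, gamma):
    """Assemble A_{n+1}, W_{n+1} and the transformation pair (T, T⁻¹) for the momentum LMS state space."""
    d = x_next.shape[0]
    I = np.eye(d)
    S = np.outer(x_next, x_next)                      # x_{n+1} x_{n+1}ᵀ
    A = np.block([[gamma * I - mu * S, -gamma * I],
                  [I,                  -I        ]])
    W = np.concatenate([x_next * v_next, np.zeros(d)])
    T_inv = np.block([[I, -gamma * I], [I, -I]])
    T = T_inv / (1.0 - gamma)                         # here (T⁻¹)² = (1−γ) I, so T = T⁻¹ / (1−γ)
    assert np.allclose(T @ T_inv, np.eye(2 * d))      # sanity check: T T⁻¹ = I
    return A, W, T, T_inv
```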
Transformed State-Space Dynamics

– To get a standard SA form, we also need to synchronize γ and μ,

μ / (1−γ) = c (1−γ),   leading to   μ = c (1−γ)²,

with some fixed constant (hyper-parameter) c > 0.
– After applying T, the transformed dynamics becomes an (almost) canonical SA recursion with the fixed gain λ ≐ 1−γ as follows:

Ū_{n+1} = Ū_n + λ ( B̄_{n+1} + λ D̄_{n+1} ) Ū_n + λ W̄_{n+1}

B̄_n ≐ [ 0 , 0 ; 0 , −I ] + c [ −1 , 1 ; −1 , 1 ] ⊗ x_n x_nᵀ,
D̄_n ≐ c [ 0 , −1 ; 0 , −1 ] ⊗ x_n x_nᵀ,   W̄_n ≐ c [ x_n v_n ; x_n v_n ]
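A tiny helper makes the gain synchronization explicit; the numeric values in the comment are only an illustrative example.

```python
def synchronized_gains(c, gamma):
    """Couple the gains as μ/(1−γ) = c(1−γ), i.e. μ = c(1−γ)²; λ = 1−γ is the effective SA gain."""
    lam = 1.0 - gamma
    mu = c * lam ** 2
    return mu, lam

# Example: c = 1.0, gamma = 0.9  ->  mu = 0.01 and lam = 0.1.
```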
The Associated ODE for Momentum LMS

– Let us introduce the notations

H̄_n(Ū) ≐ ( B̄_n + λ D̄_n ) Ū + W̄_n
h̄(Ū) ≐ E[ H̄_n(Ū) ] = B̄_λ Ū
B̄_λ ≐ E[ B̄_n + λ D̄_n ] = [ 0 , 0 ; 0 , −I ] + c [ −1 , 1−λ ; −1 , 1−λ ] ⊗ R*

Then, the associated ODE takes the form, with U̿_0 = Ū_0,

d U̿_t / dt = h̄(U̿_t) = B̄_λ U̿_t

– The solution for the limit when λ ↓ 0 is denoted by U̿*_t.
– Lemma: If λ is sufficiently small, then B̄_λ is stable.
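The lemma can be checked numerically for a given R*, c and λ; the sketch below assembles B̄_λ, assuming the block form written above, and tests whether all eigenvalues have negative real parts.

```python
import numpy as np

def B_bar_lambda(R, c, lam):
    """Assemble B̄_λ = [0, 0; 0, −I] + c [−1, 1−λ; −1, 1−λ] ⊗ R*."""
    d = R.shape[0]
    const = np.zeros((2 * d, 2 * d))
    const[d:, d:] = -np.eye(d)                            # the [0, 0; 0, −I] block
    M = np.array([[-1.0, 1.0 - lam], [-1.0, 1.0 - lam]])
    return const + c * np.kron(M, R)

def is_hurwitz(B):
    """Stability check: every eigenvalue of B has a negative real part."""
    return bool(np.all(np.linalg.eigvals(B).real < 0))
```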
The ODE Error for Momentum LMS

– [C2'] We again introduce a truncation domain, D̄, as an interior of a compact set, and assume that the ODE does not leave D̄.
– We set a stopping time for leaving the domain

τ̄ ≐ inf{ n : Ū_n ∉ D̄ }

– And define the error process, for n ≥ 0, as

Ũ_n ≐ Ū_n − U̿_n

– Finally, the normalized and time-scaled error process is

V̄_t(λ) ≐ λ^{−1/2} Ũ_{[(t∧τ̄)/λ]}

– However, the weak convergence theorems for SA methods cannot be directly applied, because there is an extra λ term in the update.
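Given a simulated run of the transformed recursion and the corresponding ODE trajectory, the normalized error process can be formed as in the following sketch (the ball-shaped truncation domain is an illustrative choice, not from the slides).

```python
import numpy as np

def normalized_error(U_bar, U_ode, lam, domain_radius=np.inf):
    """Form V̄(λ): stop when Ū_n leaves a ball of given radius, then scale Ũ_n = Ū_n − U̿_n by λ^{-1/2}."""
    tau = next((n for n, u in enumerate(U_bar) if np.linalg.norm(u) > domain_radius), len(U_bar))
    err = np.asarray(U_bar[:tau]) - np.asarray(U_ode[:tau])   # Ũ_n up to the stopping index
    return err / np.sqrt(lam)
```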