Relative Fisher Information and Natural Gradient for Learning Large Modular Models
Ke Sun (1), Frank Nielsen (2, 3)
(1) King Abdullah University of Science & Technology (KAUST); (2) École Polytechnique; (3) Sony CSL
ICML 2017
Fisher Information Metric (FIM)
Consider a statistical model $p(x \mid \Theta)$ of order $D$. The FIM (Hotelling 1929; Rao 1945) $I(\Theta) = (I_{ij})$ is the $D \times D$ positive semi-definite matrix with entries
$$I_{ij} = E_p\left[\frac{\partial l}{\partial \Theta_i}\frac{\partial l}{\partial \Theta_j}\right], \qquad (1)$$
where $l(\Theta) = \log p(x \mid \Theta)$ denotes the log-likelihood.
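Equivalent Expressions
$$I_{ij} = E_p\left[\frac{\partial l}{\partial \Theta_i}\frac{\partial l}{\partial \Theta_j}\right] = -E_p\left[\frac{\partial^2 l}{\partial \Theta_i \partial \Theta_j}\right] = 4\int \frac{\partial \sqrt{p(x \mid \Theta)}}{\partial \Theta_i}\frac{\partial \sqrt{p(x \mid \Theta)}}{\partial \Theta_j}\,dx.$$
Observed FIM (Efron & Hinkley, 1978): with respect to $X_n = \{x_k\}_{k=1}^n$,
$$\hat{I} = -\nabla^2 l(\Theta \mid X_n) = -\sum_{k=1}^n \frac{\partial^2 \log p(x_k \mid \Theta)}{\partial \Theta\, \partial \Theta^\top}.$$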
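For a discrete $x$ the integral becomes a sum, so the three expressions can be checked exactly. A minimal sketch, assuming a Bernoulli model $p(x \mid \theta) = \mu^x (1-\mu)^{1-x}$ with $\mu = \mathrm{sigm}(\theta)$ and $x \in \{0, 1\}$ (an illustrative model, not from the slides); all three values should equal $\mu(1-\mu)$.

```python
import numpy as np

def sigm(t):
    return 1.0 / (1.0 + np.exp(-t))

def fim_bernoulli_logit(theta):
    """Three expressions of the FIM for p(x | theta) = mu^x (1 - mu)^(1 - x), mu = sigm(theta).

    Since x is discrete, the integral over x becomes a sum over x in {0, 1}.
    """
    mu = sigm(theta)
    probs = {0: 1.0 - mu, 1: mu}
    # (a) E[(dl/dtheta)^2], with score dl/dtheta = x - mu
    score_sq = sum(p * (x - mu) ** 2 for x, p in probs.items())
    # (b) -E[d^2 l / dtheta^2]; here d^2 l / dtheta^2 = -mu (1 - mu) for both values of x
    neg_hess = sum(p * mu * (1.0 - mu) for p in probs.values())
    # (c) 4 * sum_x (d sqrt(p) / dtheta)^2, using d sqrt(p) / dtheta = 0.5 * sqrt(p) * (x - mu)
    sqrt_form = 4.0 * sum((0.5 * np.sqrt(p) * (x - mu)) ** 2 for x, p in probs.items())
    return score_sq, neg_hess, sqrt_form

print(fim_bernoulli_logit(0.0))   # all three are 0.25 = mu (1 - mu) up to rounding
```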
FIM and Statistical Learning
◮ Any parametric learning takes place inside a corresponding parameter manifold $M_\Theta$
[Figure: a learning curve on the parameter manifold $M_\Theta$ through a point $\theta$; $T_\theta M_\Theta$ is the tangent space at $\theta$, equipped with a local inner product $g(\theta)$]
◮ The FIM gives an invariant Riemannian metric $g(\Theta) = I(\Theta)$ for any loss function based on a standard f-divergence (KL, cross-entropy, ...)
S. Amari. Information Geometry and Its Applications. 2016.
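Invariance
The FIM matrix is not invariant; it depends on the parameterization:
$$g_\Theta(\Theta) = J^\top g_\Lambda(\Lambda)\, J,$$
where $J$ is the Jacobian matrix with $J_{ij} = \partial \Lambda_i / \partial \Theta_j$. However, the quantities it measures, such as $\langle \delta\Theta, \delta\Theta \rangle_{g(\Theta)}$, are invariant. With $\delta\Lambda = J\,\delta\Theta$ to first order,
$$\langle \delta\Theta, \delta\Theta \rangle_{g(\Theta)} = \delta\Theta^\top g_\Theta(\Theta)\,\delta\Theta = \delta\Theta^\top J^\top g_\Lambda(\Lambda)\, J\,\delta\Theta = \delta\Lambda^\top g_\Lambda(\Lambda)\,\delta\Lambda = \langle \delta\Lambda, \delta\Lambda \rangle_{g(\Lambda)}.$$
Regardless of the choice of coordinate system, it is essentially the same metric!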
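A one-dimensional check of the transformation rule and of the invariance of squared lengths. This sketch assumes a Bernoulli model described in two coordinate systems, the mean $\mu$ playing the role of $\Lambda$ and the logit $\theta$ playing the role of $\Theta$ (an illustrative example, not from the slides). The two metric values differ, while the squared lengths of corresponding small steps agree.

```python
import numpy as np

def sigm(t):
    return 1.0 / (1.0 + np.exp(-t))

# Coordinates: Lambda = mean mu in (0, 1); Theta = logit theta, with mu = sigm(theta).
def g_mean(mu):
    """FIM of a Bernoulli distribution in the mean parameterization."""
    return 1.0 / (mu * (1.0 - mu))

def g_logit(theta):
    """Pull-back g_Theta = J^T g_Lambda J with J = d mu / d theta = mu (1 - mu)."""
    mu = sigm(theta)
    J = mu * (1.0 - mu)
    return J * g_mean(mu) * J

theta, d_theta = 0.7, 1e-3
mu = sigm(theta)
d_mu = mu * (1.0 - mu) * d_theta          # first-order change of coordinates: dLambda = J dTheta
len_theta = d_theta * g_logit(theta) * d_theta
len_mu = d_mu * g_mean(mu) * d_mu
print(g_logit(theta), g_mean(mu))         # the metric values differ between coordinates
print(len_theta, len_mu)                  # the squared step lengths agree
```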
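Statistical Formulation of a Multilayer Perceptron (MLP)
$$p(y \mid x, \Theta) = \sum_{h_1, \cdots, h_{L-1}} p(y \mid h_{L-1}, \theta_L) \cdots p(h_2 \mid h_1, \theta_2)\, p(h_1 \mid x, \theta_1)$$
[Figure: a layered network from the input $x = (x_1, \dots, x_5)$ through hidden layers $h_1, \dots, h_{L-1}$ to the output $y = (y_1, \dots, y_5)$, with parameters $\theta_1, \dots, \theta_L$ on the connections]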
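For a tiny network the sum over hidden configurations can be carried out by brute force. A sketch assuming $L = 2$ (one hidden layer $h_1$), binary hidden and output units, and independent sigmoid units within each layer; these modelling choices and the function names are illustrative. The cost grows as $2^{|h_1|}$, so this only makes the formula concrete.

```python
import itertools
import numpy as np

def sigm(t):
    return 1.0 / (1.0 + np.exp(-t))

def bernoulli_layer(h_in, W):
    """p(h_out = 1 | h_in) for independent sigmoid units; W has shape (len(h_in) + 1, m)."""
    h_aug = np.append(h_in, 1.0)          # augmented input with a bias entry
    return sigm(W.T @ h_aug)

def p_y_given_x(y, x, W1, W2):
    """p(y | x, Theta) = sum over binary h1 of p(y | h1, theta_2) p(h1 | x, theta_1)."""
    q1 = bernoulli_layer(x, W1)
    total = 0.0
    for h1 in itertools.product([0.0, 1.0], repeat=W1.shape[1]):
        h1 = np.array(h1)
        p_h1 = np.prod(np.where(h1 == 1.0, q1, 1.0 - q1))
        q2 = bernoulli_layer(h1, W2)
        p_y = np.prod(np.where(y == 1.0, q2, 1.0 - q2))
        total += p_y * p_h1
    return total
```

Summing `p_y_given_x` over all binary output vectors `y` should give 1, since each factor in the sum is a normalized distribution.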
The FIM of an MLP
The FIM of an MLP has the following expression:
$$g(\Theta) = E_{x \sim \hat{p}(X_n),\, y \sim p(y \mid x, \Theta)}\left[\frac{\partial l}{\partial \Theta}\frac{\partial l}{\partial \Theta^\top}\right] = \frac{1}{n}\sum_{i=1}^n E_{p(y \mid x_i, \Theta)}\left[\frac{\partial l_i}{\partial \Theta}\frac{\partial l_i}{\partial \Theta^\top}\right],$$
where
◮ $\hat{p}(X_n)$ is the empirical distribution of the samples $X_n = \{x_i\}_{i=1}^n$
◮ $l_i(\Theta) = \log p(y \mid x_i, \Theta)$ is the conditional log-likelihood
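Meaning of the FIM of an MLP
Consider a learning step on $M_\Theta$ from $\Theta$ to $\Theta + \delta\Theta$. The squared step size
$$\langle \delta\Theta, \delta\Theta \rangle_{g(\Theta)} = \delta\Theta^\top g(\Theta)\,\delta\Theta = \delta\Theta^\top \left[\frac{1}{n}\sum_{i=1}^n E_{p(y \mid x_i, \Theta)}\left(\frac{\partial l_i}{\partial \Theta}\frac{\partial l_i}{\partial \Theta^\top}\right)\right]\delta\Theta = \frac{1}{n}\sum_{i=1}^n E_{p(y \mid x_i, \Theta)}\left[\left(\delta\Theta^\top \frac{\partial l_i}{\partial \Theta}\right)^2\right]$$
measures how much $\delta\Theta$ is statistically aligned with $\partial l / \partial \Theta$: will $\delta\Theta$ make a significant change to the mapping $x \to y$ or not?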
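A numerical check of the FIM expression and of the step-size identity, assuming the simplest conditional model: a single sigmoid neuron $p(y = 1 \mid x, w) = \mathrm{sigm}(w^\top x)$ with $y \in \{0, 1\}$, i.e. a one-layer network without hidden variables (chosen for illustration). The expectation over $y$ is computed exactly by summing over both labels.

```python
import numpy as np

def sigm(t):
    return 1.0 / (1.0 + np.exp(-t))

def scores(w, x):
    """Scores d/dw log p(y | x, w) and probabilities p(y | x, w) for y in {0, 1}."""
    mu = sigm(w @ x)
    return {1: (1.0 - mu) * x, 0: -mu * x}, {1: mu, 0: 1.0 - mu}

def fim(w, X):
    """g(w) = (1/n) sum_i E_{p(y|x_i,w)}[s s^T], summing exactly over y."""
    g = np.zeros((X.shape[1], X.shape[1]))
    for x in X:
        s, p = scores(w, x)
        g += sum(p[y] * np.outer(s[y], s[y]) for y in (0, 1))
    return g / len(X)

def step_size_sq(w, X, dw):
    """(1/n) sum_i E_{p(y|x_i,w)}[(dw^T s)^2], the last expression of the identity."""
    total = 0.0
    for x in X:
        s, p = scores(w, x)
        total += sum(p[y] * (dw @ s[y]) ** 2 for y in (0, 1))
    return total / len(X)

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
w = rng.normal(size=3)
dw = rng.normal(size=3)
print(dw @ fim(w, X) @ dw, step_size_sq(w, X, dw))   # equal up to rounding
```

For this model the sum over $y$ reduces to $g(w) = \frac{1}{n}\sum_i \mathrm{sigm}(w^\top x_i)[1 - \mathrm{sigm}(w^\top x_i)]\, x_i x_i^\top$.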
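Natural Gradient: Seeking a Short Path
Consider $\min_{\Theta \in M_\Theta} L(\Theta)$. At $\Theta_t \in M_\Theta$, the target is to minimize with respect to $\delta\Theta$
$$\underbrace{L(\Theta_t + \delta\Theta)}_{\text{loss function}} + \frac{1}{2\gamma}\underbrace{\langle \delta\Theta, \delta\Theta \rangle_{g(\Theta_t)}}_{\text{squared step size}} \approx L(\Theta_t) + \delta\Theta^\top \nabla L(\Theta_t) + \frac{1}{2\gamma}\delta\Theta^\top g(\Theta_t)\,\delta\Theta \qquad (\gamma\text{: learning rate}).$$
Setting the gradient with respect to $\delta\Theta$ to zero, $\nabla L(\Theta_t) + \frac{1}{\gamma} g(\Theta_t)\,\delta\Theta = 0$, gives the learning step
$$\delta\Theta_t = -\gamma\, \underbrace{g^{-1}(\Theta_t)\, \nabla L(\Theta_t)}_{\text{natural gradient}}.$$
◮ Equivalence with mirror descent (Raskutti & Mukherjee 2013)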
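A minimal sketch of natural gradient descent, assuming logistic regression with the average negative log-likelihood as $L(\Theta)$ and the model's own FIM as $g$ (an illustrative choice; for this model the FIM coincides with the Hessian of $L$, so each step is a damped Newton step). The `damping` term added to $g$ before solving is an assumption, not part of the slide; it guards against a singular FIM.

```python
import numpy as np

def sigm(t):
    return 1.0 / (1.0 + np.exp(-t))

def loss_grad_fim(w, X, y):
    """Average negative log-likelihood of logistic regression, its gradient, and its FIM."""
    mu = sigm(X @ w)
    eps = 1e-12
    loss = -np.mean(y * np.log(mu + eps) + (1 - y) * np.log(1 - mu + eps))
    grad = X.T @ (mu - y) / len(y)
    g = (X * (mu * (1 - mu))[:, None]).T @ X / len(y)
    return loss, grad, g

def natural_gradient_step(w, X, y, gamma=1.0, damping=1e-6):
    """w + delta with delta = -gamma g^{-1} grad, the minimizer of the quadratic model."""
    _, grad, g = loss_grad_fim(w, X, y)
    delta = -gamma * np.linalg.solve(g + damping * np.eye(len(w)), grad)
    return w + delta

# Usage: synthetic data with a bias column; labels drawn from the model itself
rng = np.random.default_rng(0)
X = np.hstack([rng.normal(size=(200, 2)), np.ones((200, 1))])
w_true = np.array([1.5, -2.0, 0.5])
y = (rng.random(200) < sigm(X @ w_true)).astype(float)
w = np.zeros(3)
losses = []
for _ in range(10):
    losses.append(loss_grad_fim(w, X, y)[0])
    w = natural_gradient_step(w, X, y, gamma=0.5)
```

Solving the linear system with `np.linalg.solve` avoids forming $g^{-1}$ explicitly, which is cheaper and numerically safer than an explicit inverse.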
Natural Gradient: Intrinsics
$$\delta\Theta_t = -\gamma\, g^{-1}(\Theta_t)\, \nabla L(\Theta_t)$$
This Riemannian metric is a property of the parameter space and is independent of the loss function $L(\Theta)$. Natural gradient performs well when $L(\Theta)$ is curved similarly to $\log p(x \mid \Theta)$ with $x \sim p(x \mid \Theta)$; it is not universally good for every loss function.
Natural Gradient: Pros and Cons
Pros
◮ Invariant (intrinsic) gradient
◮ Not trapped in plateaus
◮ Achieves Fisher efficiency in online learning
Cons
◮ Too expensive to compute (no closed-form FIM; requires matrix inversion)
Relative FIM: Informal Ideas
◮ Decompose the learning system into subsystems
◮ The subsystems are interfaced with each other through hidden variables $h_i$
◮ Some subsystems are interfaced with the I/O environment through $x_i$ and $y_i$
◮ Compute the FIM of a subsystem by integrating out its interface variables $h_i$, so that the intrinsic geometry of this subsystem can be discussed independently of the remaining parts
From FIM to Relative FIM (RFIM)
FIM: likelihood scalar $\log p(x \mid \theta)$; parameter vector $\theta$. How sensitive is $x$ to tiny movements of $\theta$ on $M_\theta$?
RFIM: likelihood scalar $\log p(r \mid \theta, \theta_f)$; parameter vector $\theta$. Given $\theta_f$, how sensitive is $r$ to tiny movements of $\theta$?
Relative FIM: Definition
Given $\theta_f$ (the reference), the Relative Fisher Information Metric (RFIM) of $\theta$ with respect to $h$ (the response) is
$$g_h(\theta \mid \theta_f) = E_{p(h \mid \theta, \theta_f)}\left[\frac{\partial}{\partial \theta}\ln p(h \mid \theta, \theta_f)\, \frac{\partial}{\partial \theta^\top}\ln p(h \mid \theta, \theta_f)\right],$$
or simply $g_h(\theta)$. Meaning: given $\theta_f$, how variations of $\theta$ affect the response $h$.
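To make the definition concrete: when the response $h$ takes finitely many values, the expectation is a finite sum and the score can be approximated numerically while the reference stays fixed. A minimal sketch; the function name `rfim_discrete` and the central finite-difference approximation are illustrative choices, not from the slides. The usage example takes a sigmoid neuron whose reference is its input $x$, so the result can be compared with the closed form $\mathrm{sigm}(a)[1 - \mathrm{sigm}(a)]\,\tilde{x}\tilde{x}^\top$ from the List of RFIMs.

```python
import numpy as np

def rfim_discrete(log_probs, theta, ref, eps=1e-5):
    """Numerical RFIM g_h(theta | ref) for a response h with finitely many values.

    log_probs(theta, ref) returns the vector (log p(h_k | theta, ref))_k over all values h_k.
    Scores are approximated by central finite differences in theta; ref is held fixed.
    """
    theta = np.asarray(theta, dtype=float)
    p = np.exp(log_probs(theta, ref))
    scores = np.zeros((p.size, theta.size))    # row k: d/dtheta log p(h_k | theta, ref)
    for j in range(theta.size):
        e = np.zeros(theta.size)
        e[j] = eps
        scores[:, j] = (log_probs(theta + e, ref) - log_probs(theta - e, ref)) / (2.0 * eps)
    return (scores * p[:, None]).T @ scores   # sum_k p_k s_k s_k^T

# Usage: a sigmoid neuron with stochastic output h in {0, 1}; the reference is its input x.
def sigm(t):
    return 1.0 / (1.0 + np.exp(-t))

def log_probs_sigm(w, x):
    a = w @ np.append(x, 1.0)                  # w^T x~ with x~ = (x^T, 1)^T
    return np.log(np.array([1.0 - sigm(a), sigm(a)]))

w = np.array([0.3, -0.8, 0.1])
x = np.array([1.2, 0.5])
G = rfim_discrete(log_probs_sigm, w, x)
```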
Different Subsystems: Simple Examples
[Figure: two simple subsystems with parameter $\theta$. Left, "Generator": variables $h_i$ and $h'_i$. Right, "Discriminator or Regressor": variable $h_i$.]
A Dynamic Geometry
Model: $$p(y \mid \Theta, x) = \sum_{h_1, h_2} p(h_1 \mid \theta_1, x)\, p(h_2 \mid \theta_2, h_1)\, p(y \mid \theta_3, h_2)$$
[Figure: the whole model as one manifold $M_\Theta$ with metric $I(\Theta)$ and computational graph $x \to \Theta \to y$, versus three subsystem manifolds $M_{\theta_1}, M_{\theta_2}, M_{\theta_3}$ with metrics $g_{h_1}(\theta_1)$, $g_{h_2}(\theta_2)$, $g_y(\theta_3)$ along the graph $x \to \theta_1 \to h_1 \to \theta_2 \to h_2 \to \theta_3 \to y$; the references shift as $x + \Delta x$, $h_1 + \Delta h_1$, $h_2 + \Delta h_2$]
◮ As the interface hidden variables $h_i$ change, the subsystem geometry is not absolute but relative to the reference variables provided by adjacent subsystems
RFIM of One tanh Neuron
Consider a neuron with input $x$, weights $w$, a hyperbolic tangent activation function, and a stochastic output $y \in \{-1, 1\}$, given by
$$p(y = 1) = \frac{1 + \tanh(w^\top \tilde{x})}{2}, \qquad \tanh(t) = \frac{\exp(t) - \exp(-t)}{\exp(t) + \exp(-t)},$$
where $\tilde{x} = (x^\top, 1)^\top$ denotes the augmented vector of $x$. Then
$$g_y(w \mid x) = \nu_{\tanh}(w, x)\,\tilde{x}\tilde{x}^\top, \qquad \nu_{\tanh}(w, x) = \mathrm{sech}^2(w^\top \tilde{x}).$$
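A check of the closed form by summing the RFIM definition over $y \in \{-1, 1\}$ with $p(y) = (1 + y\tanh(w^\top\tilde{x}))/2$. Because $y^2 = 1$, the score $y\,\mathrm{sech}^2(a)/(1 + y\tanh a)\,\tilde{x}$ simplifies to $(y - \tanh a)\,\tilde{x}$ with $a = w^\top\tilde{x}$. A sketch; the function names are illustrative.

```python
import numpy as np

def augment(x):
    return np.append(x, 1.0)                 # x~ = (x^T, 1)^T

def rfim_tanh_closed(w, x):
    """Closed form: sech^2(w^T x~) x~ x~^T."""
    xt = augment(x)
    nu = 1.0 / np.cosh(w @ xt) ** 2
    return nu * np.outer(xt, xt)

def rfim_tanh_enum(w, x):
    """Definition: sum over y in {-1, 1} of p(y) s_y s_y^T."""
    xt = augment(x)
    t = np.tanh(w @ xt)
    g = np.zeros((xt.size, xt.size))
    for y in (-1.0, 1.0):
        p = (1.0 + y * t) / 2.0
        s = (y - t) * xt                     # d/dw log p(y | w, x), simplified using y^2 = 1
        g += p * np.outer(s, s)
    return g
```

Summing $p(y)(y - t)^2$ over both labels gives $1 - t^2 = \mathrm{sech}^2(a)$, which is why the two functions agree.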
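RFIM of a Parametric Rectified Linear Unit
$$p(y \mid w, x) = G\left(y \mid \mathrm{relu}(w^\top \tilde{x}), \sigma^2\right) \qquad (G \text{ denotes a Gaussian}),$$
$$\mathrm{relu}(t) = \begin{cases} t & \text{if } t \ge 0 \\ \iota t & \text{if } t < 0 \end{cases} \qquad (0 \le \iota < 1).$$
Under certain assumptions (see the paper),
$$g_y(w \mid x) = \nu_{\mathrm{relu}}(w, x)\,\tilde{x}\tilde{x}^\top, \qquad \nu_{\mathrm{relu}}(w, x) = \frac{1}{\sigma^2}\left[\iota + (1 - \iota)\,\mathrm{sigm}\left(\frac{1 - \iota}{\omega}\, w^\top \tilde{x}\right)\right]^2,$$
where sigm denotes the sigmoid function. Setting $\sigma = 1$ and $\iota = 0$, this simplifies to
$$\nu_{\mathrm{relu}}(w, x) = \mathrm{sigm}^2\left(\frac{1}{\omega}\, w^\top \tilde{x}\right).$$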
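The smoothed coefficient can be compared with the exact Fisher coefficient of a Gaussian mean, $\mathrm{relu}'(a)^2/\sigma^2$ with $a = w^\top\tilde{x}$, which holds wherever relu is differentiable ($a \neq 0$). A sketch with illustrative function names: as $|a|$ grows the two agree, approaching $1/\sigma^2$ for large positive $a$ and $\iota^2/\sigma^2$ for large negative $a$; a smaller $\omega$ makes the sigmoid, and hence the transition between the two values, sharper.

```python
import numpy as np

def sigm(t):
    return 1.0 / (1.0 + np.exp(-t))

def nu_relu(a, iota=0.0, omega=1.0, sigma=1.0):
    """Smoothed coefficient: (1/sigma^2) [iota + (1 - iota) sigm((1 - iota) a / omega)]^2."""
    return (iota + (1.0 - iota) * sigm((1.0 - iota) * a / omega)) ** 2 / sigma ** 2

def nu_relu_exact(a, iota=0.0, sigma=1.0):
    """relu'(a)^2 / sigma^2 with relu'(a) = 1 for a >= 0 (a = 0 included by convention), iota otherwise."""
    slope = np.where(a >= 0, 1.0, iota)
    return slope ** 2 / sigma ** 2

for a in (-20.0, -1.0, 0.0, 1.0, 20.0):
    print(a, nu_relu(a, iota=0.1), nu_relu_exact(a, iota=0.1))
```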
Generic Expression of One-neuron RFIMs
Let $f \in \{\tanh, \mathrm{sigm}, \mathrm{relu}, \mathrm{elu}\}$ be an element-wise nonlinear activation function. The RFIM is
$$g_y(w \mid x) = \nu_f(w, x)\,\tilde{x}\tilde{x}^\top,$$
where $\nu_f(w, x)$ is a positive coefficient that takes large values in the linear region, or effective learning zone, of the neuron.
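A sketch that collects the per-neuron coefficients $\nu_f$ from the List of RFIMs (with $\sigma = 1$) behind one function. The names `nu` and `rfim_neuron` and the default values of $\iota$, $\omega$, $\alpha$ are illustrative assumptions. For tanh and sigm the coefficient peaks at $w^\top\tilde{x} = 0$ and decays as the neuron saturates, matching the "linear region" reading of $\nu_f$.

```python
import numpy as np

def sigm(t):
    return 1.0 / (1.0 + np.exp(-t))

def nu(f, a, iota=0.0, omega=1.0, alpha=1.0):
    """Coefficient nu_f at the scalar pre-activation a = w^T x~ (sigma = 1)."""
    if f == "tanh":
        return 1.0 / np.cosh(a) ** 2
    if f == "sigm":
        return sigm(a) * (1.0 - sigm(a))
    if f == "relu":
        return (iota + (1.0 - iota) * sigm((1.0 - iota) * a / omega)) ** 2
    if f == "elu":
        return 1.0 if a >= 0 else (alpha * np.exp(a)) ** 2
    raise ValueError(f"unknown activation: {f}")

def rfim_neuron(f, w, x, **kwargs):
    """One-neuron RFIM g_y(w | x) = nu_f(w, x) x~ x~^T."""
    xt = np.append(x, 1.0)
    return nu(f, w @ xt, **kwargs) * np.outer(xt, xt)

print(nu("tanh", 0.0), nu("tanh", 3.0))   # large in the linear region, small when saturated
```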
RFIM of a Linear Layer
$x$: input; $W$: connection weights; $y$: stochastic output following
$$p(y \mid W, x) = G(y \mid W^\top \tilde{x}, \sigma^2 I).$$
We vectorize $W$ by stacking its columns $\{w_i\}$. Then
$$g_y(W \mid x) = \frac{1}{\sigma^2}\,\mathrm{diag}\left[\tilde{x}\tilde{x}^\top, \dots, \tilde{x}\tilde{x}^\top\right].$$
RFIM of a Non-linear Layer
A non-linear layer applies an element-wise activation $f$ to $W^\top \tilde{x}$. We have
$$g_y(W \mid x) = \mathrm{diag}\left[\nu_f(w_1, x)\,\tilde{x}\tilde{x}^\top, \dots, \nu_f(w_m, x)\,\tilde{x}\tilde{x}^\top\right],$$
where $\nu_f(w_i, x)$ depends on the activation function $f$.
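A sketch that assembles the block-diagonal RFIM of a layer from a per-unit coefficient, with $\mathrm{vec}(W)$ formed by stacking the columns $w_i$ as on the linear-layer slide. The function name `rfim_layer` is illustrative. The linear layer is the special case $\nu = 1/\sigma^2$; it equals $J^\top J/\sigma^2$, where $J$ is the Jacobian of the mean $W^\top\tilde{x}$ with respect to $\mathrm{vec}(W)$, which is the standard Fisher information of a Gaussian mean with fixed covariance $\sigma^2 I$.

```python
import numpy as np

def sigm(t):
    return 1.0 / (1.0 + np.exp(-t))

def rfim_layer(W, x, nu_fn):
    """Block-diagonal RFIM of a layer with m independent output units.

    W has shape (d + 1, m); vec(W) stacks its columns w_1, ..., w_m.
    Block i is nu_fn(w_i^T x~) * x~ x~^T.
    """
    xt = np.append(x, 1.0)
    d1, m = W.shape
    outer = np.outer(xt, xt)
    G = np.zeros((d1 * m, d1 * m))
    for i in range(m):
        G[i * d1:(i + 1) * d1, i * d1:(i + 1) * d1] = nu_fn(W[:, i] @ xt) * outer
    return G

# Usage: a linear layer (nu = 1 / sigma^2) and a sigmoid layer
rng = np.random.default_rng(3)
x = rng.normal(size=3)
W = rng.normal(size=(4, 2))
sigma2 = 0.5
G_lin = rfim_layer(W, x, lambda a: 1.0 / sigma2)
G_sig = rfim_layer(W, x, lambda a: sigm(a) * (1.0 - sigm(a)))
```

Because the matrix is block-diagonal, inverting it only needs $m$ small $(d+1) \times (d+1)$ solves rather than one large one.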
The RFIMs of single-neuron models, a linear layer, a non-linear layer, a soft-max layer, and two consecutive layers all have simple closed-form expressions (see the paper).
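List of RFIMs
Subsystem | RFIM $g_y(w)$
A tanh neuron | $\mathrm{sech}^2(w^\top \tilde{x})\,\tilde{x}\tilde{x}^\top$
A sigm neuron | $\mathrm{sigm}(w^\top \tilde{x})\,[1 - \mathrm{sigm}(w^\top \tilde{x})]\,\tilde{x}\tilde{x}^\top$
A relu neuron | $\left[\iota + (1 - \iota)\,\mathrm{sigm}\left(\frac{1 - \iota}{\omega}\, w^\top \tilde{x}\right)\right]^2 \tilde{x}\tilde{x}^\top$
An elu neuron | $\tilde{x}\tilde{x}^\top$ if $w^\top \tilde{x} \ge 0$; $\left(\alpha \exp(w^\top \tilde{x})\right)^2 \tilde{x}\tilde{x}^\top$ if $w^\top \tilde{x} < 0$
A linear layer | $\mathrm{diag}\left[\tilde{x}\tilde{x}^\top, \cdots, \tilde{x}\tilde{x}^\top\right]$
A non-linear layer | $\mathrm{diag}\left[\nu_f(w_1, x)\,\tilde{x}\tilde{x}^\top, \cdots, \nu_f(w_m, x)\,\tilde{x}\tilde{x}^\top\right]$
A soft-max layer | $m \times m$ block matrix whose $(i, j)$ block is $(\eta_i \delta_{ij} - \eta_i \eta_j)\,\tilde{x}\tilde{x}^\top$: $(\eta_i - \eta_i^2)\,\tilde{x}\tilde{x}^\top$ on the diagonal and $-\eta_i \eta_j\,\tilde{x}\tilde{x}^\top$ off the diagonal, where $\eta_i$ is the $i$-th soft-max output
Two layers | see the paper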
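A check of the soft-max row of the table, assuming the standard soft-max layer $p(y = k \mid W, x) = \eta_k$ with $\eta = \mathrm{softmax}(W^\top\tilde{x})$ and $\mathrm{vec}(W)$ stacking the columns $w_1, \dots, w_m$ (these modelling details are assumptions; the slide only lists the result). The score of class $k$ has block $j$ equal to $(\delta_{kj} - \eta_j)\,\tilde{x}$, and the closed form is $(\mathrm{diag}(\eta) - \eta\eta^\top) \otimes \tilde{x}\tilde{x}^\top$. Function names are illustrative.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())                  # shift for numerical stability
    return e / e.sum()

def rfim_softmax_closed(W, x):
    """(diag(eta) - eta eta^T) kron x~ x~^T, with vec(W) stacking the columns of W."""
    xt = np.append(x, 1.0)
    eta = softmax(W.T @ xt)
    return np.kron(np.diag(eta) - np.outer(eta, eta), np.outer(xt, xt))

def rfim_softmax_enum(W, x):
    """Definition: sum over classes k of eta_k s_k s_k^T, where block j of s_k is (delta_kj - eta_j) x~."""
    xt = np.append(x, 1.0)
    eta = softmax(W.T @ xt)
    m = eta.size
    G = np.zeros((m * xt.size, m * xt.size))
    for k in range(m):
        s = np.kron(np.eye(m)[k] - eta, xt)
        G += eta[k] * np.outer(s, s)
    return G
```

This RFIM is always singular: adding the same vector to every column $w_i$ leaves the soft-max output unchanged, so $(\mathrm{diag}(\eta) - \eta\eta^\top)\mathbf{1} = 0$. A natural-gradient step on this layer therefore needs damping or a pseudo-inverse.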
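Relative Natural Gradient Descent (RNGD)
For each subsystem,
$$\theta_{t+1} \leftarrow \theta_t - \gamma \cdot \underbrace{\left[\bar{g}_h(\theta_t \mid \theta_f)\right]^{-1}}_{\text{inverse RFIM}} \cdot \left.\frac{\partial L}{\partial \theta}\right|_{\theta = \theta_t},$$
where
$$\bar{g}_h(\theta_t \mid \theta_f) = \frac{1}{n}\sum_{i=1}^n g_h(\theta_t \mid \theta_f^i).$$
By definition, the RFIM is a function of the reference variables; $\bar{g}_h(\theta_t \mid \theta_f)$ is its expectation with respect to the empirical distribution of $\theta_f$.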
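A minimal sketch of one way to apply this update to a two-layer network with a tanh hidden layer and a sigmoid output neuron trained with cross-entropy. Assumptions, none of which come from the slides: the forward pass uses the mean activation of each stochastic tanh neuron ($E[y] = \tanh(w^\top\tilde{x})$); each layer is one subsystem whose RFIM is averaged over its references ($x_i$ for the hidden layer, $h_i$ for the output neuron) using the tanh and sigm coefficients from the List of RFIMs; a small `damping` term is added before solving; gradients come from ordinary backpropagation. This illustrates the update rule, not the paper's implementation.

```python
import numpy as np

def sigm(t):
    return 1.0 / (1.0 + np.exp(-t))

def aug(A):
    """Append a column of ones (the bias input) to a batch of row vectors."""
    return np.hstack([A, np.ones((A.shape[0], 1))])

def forward(W1, w2, X):
    H = np.tanh(aug(X) @ W1)       # hidden layer: mean output of each tanh neuron, shape (n, m)
    mu = sigm(aug(H) @ w2)         # output neuron: p(y = 1 | x), shape (n,)
    return H, mu

def loss(W1, w2, X, y):
    _, mu = forward(W1, w2, X)
    eps = 1e-12
    return -np.mean(y * np.log(mu + eps) + (1 - y) * np.log(1 - mu + eps))

def rngd_step(W1, w2, X, y, gamma=0.1, damping=1e-3):
    """One RNGD step; each layer is preconditioned by its own RFIM averaged over its references."""
    n = len(y)
    Xt = aug(X)
    H, mu = forward(W1, w2, X)
    Ht = aug(H)
    # Gradients of the average cross-entropy (backpropagation)
    grad_w2 = Ht.T @ (mu - y) / n
    delta = np.outer(mu - y, w2[:-1]) * (1.0 - H ** 2)
    grad_W1 = Xt.T @ delta / n
    # Output subsystem: sigm-neuron RFIM averaged over the references h_i
    G2 = (Ht * (mu * (1.0 - mu))[:, None]).T @ Ht / n
    w2_new = w2 - gamma * np.linalg.solve(G2 + damping * np.eye(w2.size), grad_w2)
    # Hidden subsystem: block-diagonal tanh-layer RFIM averaged over the references x_i;
    # block j uses nu_tanh = sech^2(w_j^T x~_i) = 1 - tanh^2(w_j^T x~_i)
    nu = 1.0 - H ** 2
    W1_new = np.empty_like(W1)
    for j in range(W1.shape[1]):
        Gj = (Xt * nu[:, [j]]).T @ Xt / n
        W1_new[:, j] = W1[:, j] - gamma * np.linalg.solve(Gj + damping * np.eye(Xt.shape[1]), grad_W1[:, j])
    return W1_new, w2_new

# Usage on a small synthetic binary classification problem
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
y = (X @ np.array([1.0, -1.0]) + 0.3 * rng.normal(size=200) > 0).astype(float)
W1 = 0.5 * rng.normal(size=(3, 4))     # (d + 1) x m hidden weights
w2 = 0.5 * rng.normal(size=5)          # (m + 1) output weights
losses = [loss(W1, w2, X, y)]
for _ in range(30):
    W1, w2 = rngd_step(W1, w2, X, y)
    losses.append(loss(W1, w2, X, y))
```

Each subsystem only needs small matrix solves, one $(d+1) \times (d+1)$ system per hidden unit and one $(m+1) \times (m+1)$ system for the output neuron, instead of inverting the FIM over all parameters at once.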