Class Probabilities and the Log-sum-exp Trick

Oren Freifeld
Computer Science, Ben-Gurion University
May 14, 2017
Disclaimer

Both the problem and the solution described in these slides are widely known. I don't remember where I first saw the solution, and could not find out who should be credited with discovering it. Some of the derivations below are based on Ryan Adams' post at
https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
Numerical Issues with Computing Class Probabilities

We often need to compute, for a data point $x$, expressions such as

$$ p(z = k \mid \theta, x) \propto w_k \exp(l_k) \qquad \text{and} \qquad p(z = k \mid \theta, x) = \frac{w_k \exp(l_k)}{\sum_{k'=1}^{K} w_{k'} \exp(l_{k'})} $$

where $l_k \in \mathbb{R}$ for all $k \in \{1, \ldots, K\}$. Here, $l_k$ does not necessarily stand for a log-likelihood; rather, it stands for the nominal value of the exponent of the $k$-th term of interest.
Example

In EM for GMM, the E step involves

$$ r_{i,k} = \frac{\pi_k \, \mathcal{N}(x_i; \mu_k, \Sigma_k)}{\sum_{k'=1}^{K} \pi_{k'} \, \mathcal{N}(x_i; \mu_{k'}, \Sigma_{k'})} = \frac{\pi_k (2\pi)^{-n/2} |\Sigma_k|^{-1/2} \exp\left(-\tfrac{1}{2}(x_i - \mu_k)^T \Sigma_k^{-1} (x_i - \mu_k)\right)}{\sum_{k'=1}^{K} \pi_{k'} (2\pi)^{-n/2} |\Sigma_{k'}|^{-1/2} \exp\left(-\tfrac{1}{2}(x_i - \mu_{k'})^T \Sigma_{k'}^{-1} (x_i - \mu_{k'})\right)} $$

$$ = \frac{\overbrace{\pi_k |\Sigma_k|^{-1/2}}^{w_k} \exp\Big(\overbrace{-\tfrac{1}{2}(x_i - \mu_k)^T \Sigma_k^{-1} (x_i - \mu_k)}^{l_k}\Big)}{\sum_{k'=1}^{K} \pi_{k'} |\Sigma_{k'}|^{-1/2} \exp\left(-\tfrac{1}{2}(x_i - \mu_{k'})^T \Sigma_{k'}^{-1} (x_i - \mu_{k'})\right)} $$

Remark: Here, the $\pi$ in the $2\pi$ term (which cancels out anyway) is the number $\pi$, while $\pi_k$ is the weight of the $k$-th component; confusing, but it is fairly standard notation, especially in Bayesian statistics.
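As a concrete illustration, here is a minimal sketch of this E step computed entirely in the log domain, anticipating the log-sum-exp shift introduced later in these slides. The function name and interface are illustrative, not from the slides; the $(2\pi)^{-n/2}$ factor is omitted since it cancels in the ratio.

```python
import numpy as np

def gmm_responsibilities(X, pis, mus, Sigmas):
    # E-step responsibilities r_{i,k}, computed in the log domain and
    # normalized with the shift a = max_k (l_k + log w_k), where
    # w_k = pi_k |Sigma_k|^{-1/2} and l_k is the quadratic exponent.
    N, K = X.shape[0], len(pis)
    log_r = np.empty((N, K))
    for k in range(K):
        diff = X - mus[k]                                  # (N, n)
        Sinv = np.linalg.inv(Sigmas[k])
        quad = np.einsum('ij,jk,ik->i', diff, Sinv, diff)  # squared Mahalanobis distances
        _, logdet = np.linalg.slogdet(Sigmas[k])
        # log(w_k) + l_k for every data point:
        log_r[:, k] = np.log(pis[k]) - 0.5 * logdet - 0.5 * quad
    log_r -= log_r.max(axis=1, keepdims=True)              # the log-sum-exp shift
    r = np.exp(log_r)
    return r / r.sum(axis=1, keepdims=True)

# Two well-separated components; each point should be assigned
# almost entirely to its own component.
X = np.array([[0.0, 0.0], [10.0, 10.0]])
pis = np.array([0.5, 0.5])
mus = np.array([[0.0, 0.0], [10.0, 10.0]])
Sigmas = np.array([np.eye(2), np.eye(2)])
r = gmm_responsibilities(X, pis, mus, Sigmas)
```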
Numerical Issues with Computing Class Probabilities

If $l_k < 0$ and $|l_k|$ is too large, we might have situations where (on a computer) $\exp(l_k) = 0$ for all $k$; thus, $\sum_{k'=1}^{K} w_{k'} \exp(l_{k'})$ will be zero. Similarly, if $l_k > 0$ (which can happen, for example, for some non-Gaussian conditional class probabilities), we might get $+\infty$ (an overflow) if $l_k$ is too large. These issues appear in many clustering problems, including (either Bayesian or non-Bayesian) mixture models.
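The underflow case is easy to reproduce in double precision, where exponents this negative evaluate to exactly zero (the values below are illustrative, not from the slides):

```python
import numpy as np

# np.exp(x) underflows to exactly 0.0 in double precision
# for x below roughly -745.
l = np.array([-1000.0, -1001.0, -1002.0])
terms = np.exp(l)
denom = terms.sum()
print(denom)  # 0.0 -- normalizing by this sum would divide by zero
```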
The Log-sum-exp Trick

Fact. For all $a \in \mathbb{R}$ and all $\{l_k\}_{k=1}^{K} \subset \mathbb{R}$:

$$ \log \sum_{k=1}^{K} \exp(l_k) = a + \log \sum_{k=1}^{K} \exp(l_k - a) $$
The Log-sum-exp Trick

Proof.

$$ \log \sum_{k=1}^{K} \exp(l_k) = \log \sum_{k=1}^{K} \exp(l_k - a + a) = \log \sum_{k=1}^{K} \exp(l_k - a)\exp(a) $$

$$ = \log \left( \exp(a) \sum_{k=1}^{K} \exp(l_k - a) \right) = \log \exp(a) + \log \sum_{k=1}^{K} \exp(l_k - a) = a + \log \sum_{k=1}^{K} \exp(l_k - a) \qquad \blacksquare $$
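The identity just proven translates directly into a stable implementation with $a = \max_k l_k$ (a choice justified on the last slide); a minimal sketch, with an illustrative function name:

```python
import numpy as np

def logsumexp(l):
    # a + log(sum_k exp(l_k - a)) with a = max_k l_k, per the fact above.
    a = np.max(l)
    return a + np.log(np.sum(np.exp(l - a)))

# Naive evaluation of log(sum_k exp(l_k)) would underflow here;
# the shifted form does not.
val = logsumexp(np.array([-1000.0, -1001.0]))
print(val)  # -1000 + log(1 + e^{-1}), roughly -999.687
```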
The Log-sum-exp Trick

Fact. For all $a \in \mathbb{R}$ and all $\{l_k\}_{k=1}^{K} \subset \mathbb{R}$:

$$ \frac{\exp(l_k - a)}{\sum_{k'=1}^{K} \exp(l_{k'} - a)} = \frac{\exp(l_k)}{\sum_{k'=1}^{K} \exp(l_{k'})} $$
Proof.

$$ (1) \quad \log \sum_{k=1}^{K} \exp(l_k) = a + \log \sum_{k=1}^{K} \exp(l_k - a) \quad \text{(by the previous fact)} $$

$$ (2) \quad \exp(l_k) = \exp\left(\log \exp(l_k)\right) \overset{(1)\text{ with } K=1}{=} \exp\left(a + \log \exp(l_k - a)\right) $$

$$ (3) \quad \sum_{k=1}^{K} \exp(l_k) = \exp\left(\log \sum_{k=1}^{K} \exp(l_k)\right) \overset{(1)}{=} \exp\left(a + \log \sum_{k=1}^{K} \exp(l_k - a)\right) $$

$$ (4) \quad \frac{\exp(l_k - a)}{\sum_{k'=1}^{K} \exp(l_{k'} - a)} = \frac{\exp\left(\log \exp(l_k - a)\right)}{\exp\left(\log \sum_{k'=1}^{K} \exp(l_{k'} - a)\right)} = \frac{\exp(a) \exp\left(\log \exp(l_k - a)\right)}{\exp(a) \exp\left(\log \sum_{k'=1}^{K} \exp(l_{k'} - a)\right)} $$

$$ = \frac{\exp\left(a + \log \exp(l_k - a)\right)}{\exp\left(a + \log \sum_{k'=1}^{K} \exp(l_{k'} - a)\right)} \overset{(2)\,\&\,(3)}{=} \frac{\exp(l_k)}{\sum_{k'=1}^{K} \exp(l_{k'})} \qquad \blacksquare $$
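The shift invariance just proven can also be checked numerically: normalizing the shifted exponents gives the same probabilities as normalizing the unshifted ones (the example values are illustrative):

```python
import numpy as np

def normalize_exp(v):
    # Direct normalization exp(v_k) / sum_k' exp(v_k');
    # fine for moderate values of v.
    e = np.exp(v)
    return e / e.sum()

l = np.array([0.5, -1.2, 2.0])
p_direct = normalize_exp(l)
p_shifted = normalize_exp(l - l.max())  # shift by a = max_k l_k
print(np.allclose(p_direct, p_shifted))  # True
```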
The Log-sum-exp Trick

$$ \frac{\exp(l_k - a)}{\sum_{k'=1}^{K} \exp(l_{k'} - a)} = \frac{\exp(l_k)}{\sum_{k'=1}^{K} \exp(l_{k'})} $$

Choose $a = \max_k l_k$ and compute the LHS, not the problematic RHS. This will prevent $+\infty$, and even if some values vanish, we will have at least one survivor ($e^{\max_k l_k - a} = e^0 = 1 > 0$), so the denominator will be strictly positive (and finite).

More generally, instead of computing

$$ \frac{w_k \exp(l_k)}{\sum_{k'=1}^{K} w_{k'} \exp(l_{k'})} $$

use

$$ \frac{\exp(l_k + \log w_k - a)}{\sum_{k'=1}^{K} \exp(l_{k'} + \log w_{k'} - a)} $$

where $a = \max_k (l_k + \log w_k)$, and where we also used the fact that $w_k \exp(l_k) = \exp(\log w_k) \exp(l_k) = \exp(l_k + \log w_k)$.
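The weighted recipe above can be sketched in a few lines; the function name is illustrative, and the example uses exponents that would underflow if computed naively:

```python
import numpy as np

def class_probabilities(l, w):
    # p_k proportional to w_k * exp(l_k), computed stably by folding the
    # weights into the exponents and shifting by a = max_k (l_k + log w_k).
    s = l + np.log(w)
    s = s - s.max()
    e = np.exp(s)
    return e / e.sum()

# Naively, exp(l) would be all zeros and the denominator zero;
# the shifted form yields valid probabilities.
l = np.array([-1000.0, -1002.0, -1005.0])
w = np.array([0.5, 0.3, 0.2])
p = class_probabilities(l, w)
print(p.sum())  # 1.0 (up to floating point)
```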