Inference and Representation David Sontag New York University Lecture 11, Nov. 24, 2015 David Sontag (NYU) Inference and Representation Lecture 11, Nov. 24, 2015 1 / 32
Approximate marginal inference
Given the joint $p(x_1, \ldots, x_n)$ represented as a graphical model, how do we perform marginal inference, e.g. to compute $p(x_1 \mid e)$?
We showed in Lecture 4 that doing this exactly is NP-hard.
Nearly all approximate inference algorithms are either:
1. Monte Carlo methods (e.g., Gibbs sampling, likelihood reweighting, MCMC)
2. Variational algorithms (e.g., mean-field, loopy belief propagation)
Variational methods
Goal: Approximate a difficult distribution $p(x \mid e)$ with a new distribution $q(x)$ such that:
1. $p(x \mid e)$ and $q(x)$ are "close"
2. Computation on $q(x)$ is easy
How should we measure distance between distributions?
The Kullback-Leibler divergence (KL-divergence) between two distributions $p$ and $q$ is defined as
$$D(p \| q) = \sum_x p(x) \log \frac{p(x)}{q(x)}.$$
(It measures the expected number of extra bits required to describe samples from $p(x)$ using a code based on $q$ instead of $p$.)
$D(p \| q) \geq 0$ for all $p, q$, with equality if and only if $p = q$.
Notice that KL-divergence is asymmetric.
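As a small illustration of the definition and of the asymmetry, here is a minimal sketch for discrete distributions. The function name `kl_divergence` and the two example distributions are assumptions made for illustration, not part of the lecture. It uses the natural logarithm, so the result is in nats rather than bits.

```python
import numpy as np

def kl_divergence(p, q):
    """D(p || q) = sum_x p(x) log(p(x) / q(x)), in nats.

    Terms with p(x) = 0 contribute 0. Assumes q(x) > 0 wherever p(x) > 0.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

# Two hypothetical distributions over three states.
p = np.array([0.5, 0.4, 0.1])
q = np.array([0.3, 0.3, 0.4])

d_pq = kl_divergence(p, q)  # D(p || q)
d_qp = kl_divergence(q, p)  # D(q || p), generally a different number
print(d_pq, d_qp)
```

Swapping the arguments changes the value, which is why the direction of the divergence matters in the slides that follow.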
KL-divergence (see Section 2.8.2 of Murphy)
$$D(p \| q) = \sum_x p(x) \log \frac{p(x)}{q(x)}.$$
Suppose $p$ is the true distribution we wish to do inference with.
What is the difference between the solution to $\arg\min_q D(p \| q)$ (called the M-projection of $p$ onto $Q$) and $\arg\min_q D(q \| p)$ (called the I-projection)?
These two will differ only when $q$ is minimized over a restricted set of probability distributions $Q = \{q_1, \ldots\}$, and in particular when $p \notin Q$.
KL-divergence: M-projection
$$q^* = \arg\min_{q \in Q} D(p \| q), \qquad D(p \| q) = \sum_x p(x) \log \frac{p(x)}{q(x)}.$$
For example, suppose that $p(z)$ is a 2D Gaussian and $Q$ is the set of all Gaussian distributions with diagonal covariance matrices:
[Figure (b): contours over axes $z_1$ and $z_2$; $p$ = green, $q^*$ = red]
KL-divergence: I-projection
$$q^* = \arg\min_{q \in Q} D(q \| p), \qquad D(q \| p) = \sum_x q(x) \log \frac{q(x)}{p(x)}.$$
For example, suppose that $p(z)$ is a 2D Gaussian and $Q$ is the set of all Gaussian distributions with diagonal covariance matrices:
[Figure (a): contours over axes $z_1$ and $z_2$; $p$ = green, $q^*$ = red]
KL-divergence (single Gaussian)
In this simple example, both the M-projection and I-projection find an approximate $q(x)$ that has the correct mean (i.e. $E_p[z] = E_q[z]$):
[Figure: panel (b), the M-projection, and panel (a), the I-projection; axes $z_1$ and $z_2$]
What if $p(x)$ is multi-modal?
KL-divergence: M-projection (mixture of Gaussians)
$$q^* = \arg\min_{q \in Q} D(p \| q), \qquad D(p \| q) = \sum_x p(x) \log \frac{p(x)}{q(x)}.$$
Now suppose that $p(x)$ is a mixture of two 2D Gaussians and $Q$ is the set of all 2D Gaussian distributions (with arbitrary covariance matrices):
[Figure: $p$ = blue, $q^*$ = red]
The M-projection yields a distribution $q(x)$ with the correct mean and covariance.
KL-divergence: I-projection (mixture of Gaussians)
$$q^* = \arg\min_{q \in Q} D(q \| p), \qquad D(q \| p) = \sum_x q(x) \log \frac{q(x)}{p(x)}.$$
[Figure: $p$ = blue, $q^*$ = red (two local minima!)]
Unlike the M-projection, the I-projection does not always yield the correct moments.
Q: $D(p \| q)$ is convex, so why are there local minima?
A: We are using a parametric form for $q$ (i.e., a Gaussian). The objective is not convex in $\mu, \Sigma$.
M-projection does moment matching
Recall that the M-projection is:
$$q^* = \arg\min_{q \in Q} D(p \| q), \qquad D(p \| q) = \sum_x p(x) \log \frac{p(x)}{q(x)}.$$
Suppose that $Q$ is an exponential family ($p(x)$ can be arbitrary) and that we perform the M-projection, finding $q^*$.
Theorem: The expected sufficient statistics, with respect to $q^*(x)$, are exactly the marginals of $p(x)$:
$$E_{q^*}[f(x)] = E_p[f(x)]$$
Thus, solving for the M-projection (exactly) is just as hard as the original inference problem.
M-projection does moment matching
Recall that the M-projection is:
$$q^* = \arg\min_{q(x;\eta) \in Q} D(p \| q), \qquad D(p \| q) = \sum_x p(x) \log \frac{p(x)}{q(x)}.$$
Theorem: $E_{q^*}[f(x)] = E_p[f(x)]$.
Proof: Look at the first-order optimality conditions.
$$
\begin{aligned}
\partial_{\eta_i} D(p \| q) &= -\partial_{\eta_i} \sum_x p(x) \log q(x) \\
&= -\partial_{\eta_i} \sum_x p(x) \log \Big[ h(x) \exp\{ \eta \cdot f(x) - \ln Z(\eta) \} \Big] \\
&= -\partial_{\eta_i} \sum_x p(x) \Big[ \eta \cdot f(x) - \ln Z(\eta) \Big] \\
&= -\sum_x p(x) f_i(x) + E_{q(x;\eta)}[f_i(x)] \qquad (\text{since } \partial_{\eta_i} \ln Z(\eta) = E_q[f_i(x)]) \\
&= -E_p[f_i(x)] + E_{q(x;\eta)}[f_i(x)] = 0.
\end{aligned}
$$
Corollary: Even computing the gradients is hard (can't do gradient descent).
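The theorem can be checked numerically in a simple case. The sketch below assumes a hypothetical joint over two discrete variables and takes $Q$ to be the fully factored distributions, which form an exponential family whose sufficient statistics are the per-variable indicator functions. Moment matching then says the M-projection is the product of the marginals of $p$. The random joint, the helper `kl`, and the random search over alternatives are illustrative assumptions only.

```python
import numpy as np

rng = np.random.default_rng(0)

def kl(p, q):
    """D(p || q) for strictly positive arrays of the same shape."""
    return float(np.sum(p * np.log(p / q)))

# Hypothetical strictly positive joint over (x1, x2) with 3 and 4 states.
p = rng.random((3, 4)) + 0.1
p /= p.sum()

# Moment matching: q*(x1, x2) = p(x1) p(x2).
p1 = p.sum(axis=1)
p2 = p.sum(axis=0)
q_star = np.outer(p1, p2)
best = kl(p, q_star)

# Compare against many random fully factored alternatives.
worst_gap = np.inf
for _ in range(1000):
    a = rng.dirichlet(np.ones(3))
    b = rng.dirichlet(np.ones(4))
    worst_gap = min(worst_gap, kl(p, np.outer(a, b)) - best)

print(best, worst_gap)
```

No sampled factored distribution should beat the product of marginals, since in this case the gap equals $D(p_1 \| a) + D(p_2 \| b) \geq 0$. Note that building `q_star` required the marginals of `p`, which is exactly the inference problem we were trying to avoid.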
Most variational inference algorithms make use of the I-projection
Variational methods
Suppose that we have an arbitrary graphical model:
$$p(x; \theta) = \frac{1}{Z(\theta)} \prod_{c \in C} \phi_c(x_c) = \exp\Big( \sum_{c \in C} \theta_c(x_c) - \ln Z(\theta) \Big)$$
All of the approaches begin as follows:
$$
\begin{aligned}
D(q \| p) &= \sum_x q(x) \ln \frac{q(x)}{p(x)} \\
&= -\sum_x q(x) \ln p(x) - \sum_x q(x) \ln \frac{1}{q(x)} \\
&= -\sum_x q(x) \Big( \sum_{c \in C} \theta_c(x_c) - \ln Z(\theta) \Big) - H(q(x)) \\
&= -\sum_x q(x) \sum_{c \in C} \theta_c(x_c) + \sum_x q(x) \ln Z(\theta) - H(q(x)) \\
&= -\sum_{c \in C} E_q[\theta_c(x_c)] + \ln Z(\theta) - H(q(x)).
\end{aligned}
$$
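The final identity can be verified by brute force on a model small enough to enumerate. The following sketch assumes a hypothetical toy model with three binary variables and two pairwise cliques, with random parameters; none of these specifics come from the lecture.

```python
import itertools
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical toy model: 3 binary variables, cliques C = {(0, 1), (1, 2)}.
cliques = [(0, 1), (1, 2)]
theta = {c: rng.normal(size=(2, 2)) for c in cliques}
states = list(itertools.product([0, 1], repeat=3))

def score(x):
    """sum_c theta_c(x_c) for a full assignment x."""
    return sum(theta[c][x[c[0]], x[c[1]]] for c in cliques)

scores = np.array([score(x) for x in states])
log_Z = np.log(np.sum(np.exp(scores)))
p = np.exp(scores - log_Z)

# An arbitrary distribution q over the 8 joint states.
q = rng.dirichlet(np.ones(len(states)))

kl_direct = float(np.sum(q * np.log(q / p)))
expected_score = float(np.sum(q * scores))  # sum_c E_q[theta_c(x_c)]
entropy = float(-np.sum(q * np.log(q)))     # H(q)
kl_decomposed = -expected_score + log_Z - entropy

print(kl_direct, kl_decomposed)
```

Because $D(q \| p) \geq 0$, the same quantities give the lower bound $\ln Z(\theta) \geq \sum_c E_q[\theta_c(x_c)] + H(q)$ for any $q$, which is what the variational objective maximizes.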
Mean field algorithms for variational inference
$$\max_{q \in Q} \sum_{c \in C} \sum_{x_c} q(x_c) \theta_c(x_c) + H(q(x)).$$
Although this function is concave and thus in theory should be easy to optimize, we need some compact way of representing $q(x)$.
Mean field algorithms assume a factored representation of the joint distribution, e.g.
$$q(x) = \prod_{i \in V} q_i(x_i) \qquad \text{(called naive mean field)}$$
Naive mean-field
Suppose that $Q$ consists of all fully factored distributions, of the form $q(x) = \prod_{i \in V} q_i(x_i)$.
We can use this to simplify
$$\max_{q \in Q} \sum_{c \in C} \sum_{x_c} q(x_c) \theta_c(x_c) + H(q)$$
First, note that $q(x_c) = \prod_{i \in c} q_i(x_i)$.
Next, notice that the joint entropy decomposes as a sum of local entropies:
$$
\begin{aligned}
H(q) &= -\sum_x q(x) \ln q(x) \\
&= -\sum_x q(x) \ln \prod_{i \in V} q_i(x_i) = -\sum_x q(x) \sum_{i \in V} \ln q_i(x_i) \\
&= -\sum_{i \in V} \sum_x q(x) \ln q_i(x_i) \\
&= -\sum_{i \in V} \sum_{x_i} q_i(x_i) \ln q_i(x_i) \sum_{x_{V \setminus i}} q(x_{V \setminus i} \mid x_i) = \sum_{i \in V} H(q_i).
\end{aligned}
$$
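The entropy decomposition is easy to confirm numerically. This sketch assumes three hypothetical variables with 2, 3, and 4 states and random marginals, builds the fully factored joint explicitly, and compares the joint entropy with the sum of local entropies.

```python
import numpy as np

rng = np.random.default_rng(2)

# Hypothetical marginals q_i for variables with 2, 3, and 4 states.
marginals = [rng.dirichlet(np.ones(k)) for k in (2, 3, 4)]

# Fully factored joint q(x) = prod_i q_i(x_i), stored as a 2 x 3 x 4 array.
joint = marginals[0]
for q_i in marginals[1:]:
    joint = np.multiply.outer(joint, q_i)

def entropy(dist):
    """H(dist) = -sum dist log dist, for a strictly positive array."""
    dist = np.asarray(dist).ravel()
    return float(-np.sum(dist * np.log(dist)))

h_joint = entropy(joint)
h_sum = sum(entropy(q_i) for q_i in marginals)
print(h_joint, h_sum)
```

The practical point is that `h_sum` needs only the marginals, while `h_joint` needs a table that grows exponentially with the number of variables.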
Naive mean-field
Suppose that $Q$ consists of all fully factored distributions, of the form $q(x) = \prod_{i \in V} q_i(x_i)$.
We can use this to simplify
$$\max_{q \in Q} \sum_{c \in C} \sum_{x_c} q(x_c) \theta_c(x_c) + H(q)$$
First, note that $q(x_c) = \prod_{i \in c} q_i(x_i)$.
Next, notice that the joint entropy decomposes as $H(q) = \sum_{i \in V} H(q_i)$.
Putting these together, we obtain the following variational objective:
$$(*) \qquad \max_q \sum_{c \in C} \sum_{x_c} \theta_c(x_c) \prod_{i \in c} q_i(x_i) + \sum_{i \in V} H(q_i)$$
subject to the constraints
$$q_i(x_i) \geq 0 \quad \forall i \in V,\ x_i \in \mathrm{Val}(X_i)$$
$$\sum_{x_i \in \mathrm{Val}(X_i)} q_i(x_i) = 1 \quad \forall i \in V$$
Naive mean-field for pairwise MRFs
How do we maximize the variational objective?
$$(*) \qquad \max_q \sum_{ij \in E} \sum_{x_i, x_j} \theta_{ij}(x_i, x_j) q_i(x_i) q_j(x_j) - \sum_{i \in V} \sum_{x_i} q_i(x_i) \ln q_i(x_i)$$
This is a non-concave optimization problem, with many local maxima!
Nonetheless, we can greedily maximize it using block coordinate ascent:
1. Iterate over each of the variables $i \in V$. For variable $i$,
2. Fully maximize (*) with respect to $\{ q_i(x_i), \forall x_i \in \mathrm{Val}(X_i) \}$.
3. Repeat until convergence.
Constructing the Lagrangian, taking the derivative, setting to zero, and solving yields the update (shown on blackboard):
$$q_i(x_i) \leftarrow \frac{1}{Z_i} \exp\Big\{ \theta_i(x_i) + \sum_{j \in N(i)} \sum_{x_j} q_j(x_j) \theta_{ij}(x_i, x_j) \Big\}$$
Here $Z_i$ normalizes $q_i$, and the $\theta_i(x_i)$ term is present when the model also has single-node potentials, which add $\sum_{i \in V} \sum_{x_i} \theta_i(x_i) q_i(x_i)$ to (*).
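The coordinate ascent scheme above can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the lecture's implementation: the function names `objective` and `naive_mean_field`, the dictionary-based parameter layout, the uniform initialization, the fixed number of sweeps, and the random three-node binary cycle are all hypothetical choices. The objective includes single-node potentials $\theta_i$ so that it matches the update.

```python
import itertools
import numpy as np

def objective(q, theta_i, theta_ij):
    """Expected score under the factored q plus the sum of local entropies."""
    val = 0.0
    for i, th in theta_i.items():
        val += q[i] @ th - np.sum(q[i] * np.log(q[i]))
    for (a, b), th in theta_ij.items():
        val += q[a] @ th @ q[b]
    return float(val)

def naive_mean_field(theta_i, theta_ij, n_sweeps=30):
    """Block coordinate ascent on (*), starting from uniform marginals."""
    nodes = sorted(theta_i)
    q = {i: np.full(len(theta_i[i]), 1.0 / len(theta_i[i])) for i in nodes}
    history = [objective(q, theta_i, theta_ij)]
    for _ in range(n_sweeps):
        for i in nodes:
            # s(x_i) = theta_i(x_i) + sum_{j in N(i)} sum_{x_j} q_j(x_j) theta_ij(x_i, x_j)
            s = np.array(theta_i[i], dtype=float)
            for (a, b), th in theta_ij.items():
                if a == i:
                    s += th @ q[b]
                elif b == i:
                    s += th.T @ q[a]
            w = np.exp(s - s.max())  # subtract the max for numerical stability
            q[i] = w / w.sum()       # division by Z_i
        history.append(objective(q, theta_i, theta_ij))
    return q, history

# Hypothetical pairwise MRF: a 3-node cycle over binary variables.
rng = np.random.default_rng(3)
edges = [(0, 1), (1, 2), (0, 2)]
theta_i = {i: rng.normal(size=2) for i in range(3)}
theta_ij = {e: rng.normal(size=(2, 2)) for e in edges}

q, history = naive_mean_field(theta_i, theta_ij)

# Exact log-partition function by enumeration, for comparison only.
log_Z = np.log(sum(
    np.exp(sum(theta_i[i][x[i]] for i in range(3))
           + sum(theta_ij[(a, b)][x[a], x[b]] for (a, b) in edges))
    for x in itertools.product([0, 1], repeat=3)
))
print(history[-1], log_Z)
```

Each block update is the exact maximizer over $q_i$ with the other marginals held fixed, so the recorded objective should never decrease, and it stays below $\ln Z(\theta)$. A different initialization may converge to a different local maximum.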