Variational Inference
Probabilistic Graphical Models, Sharif University of Technology, Soleymani, Spring 2018
Some slides are adapted from Xing's slides.
Exact methods for inference
- Variable elimination
- Message passing (shared terms): sum-product (belief propagation), max-product
- Junction tree
Junction tree
- A general algorithm for graphs with cycles
- Message passing on junction trees
[Figure: neighboring cliques C_j and C_k exchange messages over their separator S_jk.]
Why approximate inference?
- The computational complexity of the junction tree algorithm is exponential in the size of the largest elimination clique (the largest clique in the triangulated graph), i.e., on the order of K^|C| for variables with K states and largest clique C.
- The tree-width of an n x n grid is n, so even moderately sized grid models are intractable.
- For a distribution p associated with a complex graph, computing the marginal (or conditional) probability of arbitrary random variable(s) is intractable.
Learning and inference
- Learning usually needs inference.
- In the Bayesian view, which is one of the principal foundations of machine learning, learning itself is just an inference problem.
- In the Maximum Likelihood approach we also need inference when we have incomplete data or when we work with an undirected model (e.g., to evaluate the partition function).
Approximate inference
Approximate inference techniques:
- Variational algorithms: loopy belief propagation, mean field approximation, expectation propagation
- Stochastic simulation / sampling methods
Variational methods
- "Variational" is a general term for optimization-based formulations.
- Many problems can be expressed as an optimization problem in which the quantity being optimized is a functional.
- Variational inference is a deterministic framework that is widely used for approximate inference.
Variational inference methods
Construct an approximation to the target distribution p, where the approximation takes a simpler form for inference:
- Define a target class 𝒬 of distributions.
- Search for an instance q* in 𝒬 that is the best approximation to p.
- Answer queries using q* rather than p.
This is a constrained optimization over the given family 𝒬 of distributions:
- Simpler families make the optimization problem computationally tractable.
- However, the family may not be sufficiently expressive to encode p.
Setup
Assume the joint p(Z, X | α) is given and we are interested in the posterior distribution:
- Observed variables X = {x_1, ..., x_N}
- Hidden variables Z = {z_1, ..., z_M}
- p(Z | X, α) = p(Z, X | α) / ∫ p(Z, X | α) dZ
The problem of computing the posterior is an instance of the more general problem that variational inference solves.
Main idea:
- Pick a family of distributions over the latent variables with its own variational parameters.
- Find the setting of the parameters that makes q close to the posterior of interest.
- Use q with the fitted parameters as an approximation for the posterior.
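To make the intractability of the normalizer concrete, here is a minimal sketch (the log-joint and the numbers are made up): brute-forcing the evidence p(X) = Σ_Z p(Z, X) over m binary latent variables requires 2^m terms, which is exactly the summation/integration that variational inference tries to sidestep.

```python
import itertools
import numpy as np

# Toy stand-in for log p(Z, X=x) over m binary latent variables (hypothetical).
def log_joint(z):
    z = np.asarray(z, dtype=float)
    return -np.sum((z - 0.3) ** 2)

m = 15
# Brute-force evidence p(x): sum over all 2^m configurations of Z.
log_terms = [log_joint(z) for z in itertools.product([0, 1], repeat=m)]
log_evidence = np.log(np.sum(np.exp(log_terms)))
print(len(log_terms), log_evidence)   # 32768 terms already; the count grows as 2^m
```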
Approximation
Goal: approximate the difficult distribution p(Z | X) with a new distribution q(Z) such that:
- p(Z | X) and q(Z) are close
- Computation on q(Z) is easy
Typically, the true posterior is not in the variational family.
How should we measure the distance between distributions? A standard choice is the Kullback-Leibler (KL) divergence between the two distributions p and q.
KL divergence
Kullback-Leibler divergence between p and q:
  KL(p || q) = ∫ p(x) log [ p(x) / q(x) ] dx
A result from information theory: for any p and q,
  KL(p || q) ≥ 0
  KL(p || q) = 0 if and only if p ≡ q
KL is asymmetric.
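A minimal numeric illustration of these properties; the two discrete distributions below are made up.

```python
import numpy as np

# Hypothetical discrete distributions on 3 states (illustration only).
p = np.array([0.7, 0.2, 0.1])
q = np.array([0.4, 0.4, 0.2])

def kl(p, q):
    """KL(p || q) = sum_x p(x) * log(p(x) / q(x)) for discrete distributions."""
    return np.sum(p * np.log(p / q))

print(kl(p, q))   # ~0.18, nonnegative
print(kl(q, p))   # ~0.19, differs from kl(p, q): KL is asymmetric
print(kl(p, p))   # 0.0, zero iff the distributions coincide
```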
How do we measure the distance between p and q?
- We wish to find a distribution q that is a "good" approximation to p.
- We can therefore use the KL divergence as a scoring function to choose a good q.
- But note that KL(p(Z|X) || q(Z)) ≠ KL(q(Z) || p(Z|X)).
M-projection vs. I-projection
M-projection of p onto the family 𝒬:
  q* = argmin_{q ∈ 𝒬} KL(p || q)
I-projection of p onto the family 𝒬:
  q* = argmin_{q ∈ 𝒬} KL(q || p)
These two differ only when the KL divergence is minimized over a restricted set 𝒬 of distributions (i.e., when p ∉ 𝒬).
KL divergence: M-projection vs. I-projection
Let p be a correlated 2D Gaussian and q be a Gaussian with a diagonal covariance matrix:
  M-projection: q* = argmin_q ∫ p(z) log [ p(z) / q(z) ] dz
  I-projection: q* = argmin_q ∫ q(z) log [ q(z) / p(z) ] dz
In both cases E_p[z] = E_q[z], but the shapes of the two solutions differ.
[Figure from Bishop: p shown in green, q* in red.]
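Both projections have closed forms when p itself is Gaussian. The sketch below (the covariance matrix is made up) uses the standard results: the M-projection onto diagonal Gaussians keeps the marginal variances (moment matching), while the I-projection uses the reciprocals of the diagonal precisions and is therefore narrower.

```python
import numpy as np

# A hypothetical correlated 2D Gaussian p(z) (illustration only).
Sigma = np.array([[1.0, 0.9],
                  [0.9, 1.0]])
Lambda = np.linalg.inv(Sigma)      # precision matrix

# M-projection onto diagonal Gaussians, argmin_q KL(p || q):
# moment matching keeps the marginal variances (diagonal of Sigma).
var_M = np.diag(Sigma)             # [1.0, 1.0]

# I-projection onto diagonal Gaussians, argmin_q KL(q || p):
# for a Gaussian target the factorized optimum has var_i = 1 / Lambda_ii,
# which ignores the covariance and underestimates the marginal variances.
var_I = 1.0 / np.diag(Lambda)      # [0.19, 0.19]

print(var_M, var_I)                # I-projection is much narrower here
```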
KL divergence: M-projection vs. I-projection
Let p be a mixture of two 2D Gaussians and q be a single 2D Gaussian with an arbitrary covariance matrix:
  M-projection: q* = argmin_q ∫ p(z) log [ p(z) / q(z) ] dz  →  E_p[z] = E_q[z] and Cov_p[z] = Cov_q[z]
  I-projection: q* = argmin_q ∫ q(z) log [ q(z) / p(z) ] dz  →  two good solutions (one per mode)!
[Figure from Bishop: p shown in blue, q* in red.]
M-projection
Computing KL(p || q) requires inference on p:
  KL(p || q) = ∫ p(z) log [ p(z) / q(z) ] dz = -H(p) - E_p[log q(z)]
Inference on p (which is difficult) is required!
When q is in the exponential family with sufficient statistics T(z), the M-projection q* satisfies the moment-matching condition E_p[T(z)] = E_{q*}[T(z)] (hence "moment projection").
Expectation Propagation methods are based on minimizing KL(p || q).
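A small sketch of moment projection onto a single Gaussian when p is a 1D two-component mixture (the weights, means, and standard deviations are made up); matching E[z] and E[z^2] gives the broad, mode-covering solution seen in the previous slide's figure.

```python
import numpy as np

# M-projection of a two-component 1D Gaussian mixture onto a single Gaussian:
# for the Gaussian family, E_p[T(z)] = E_q[T(z)] means matching the first two moments.
w   = np.array([0.5, 0.5])     # mixture weights (hypothetical)
mus = np.array([-2.0, 2.0])    # component means
sds = np.array([1.0, 1.0])     # component standard deviations

mean_p = np.sum(w * mus)                        # 0.0
second_moment = np.sum(w * (sds**2 + mus**2))   # 5.0
var_p = second_moment - mean_p**2

# q* = N(mean_p, var_p): a broad Gaussian covering both modes,
# unlike the mode-seeking I-projection.
print(mean_p, var_p)
```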
I-projection
KL(q || p) can be computed without performing full inference on p (only the unnormalized distribution is needed, since the normalizer contributes an additive constant):
  KL(q || p) = ∫ q(z) log [ q(z) / p(z) ] dz = -H(q) - E_q[log p(z)]
Most variational inference algorithms make use of KL(q || p):
- Computing expectations w.r.t. q is tractable (by choosing a suitable class of distributions for q).
- We choose a restricted family of distributions such that the expectations can be evaluated and optimized efficiently, yet which is still sufficiently flexible to give a good approximation.
Example of variational approximation
[Figure from Bishop: variational approximation vs. Laplace approximation.]
Evidence Lower Bound (ELBO)
With observed X = {x_1, ..., x_N} and hidden Z = {z_1, ..., z_M}:
  ln p(X) = L(q) + KL(q || p)
where
  L(q) = ∫ q(Z) ln [ p(X, Z) / q(Z) ] dZ
  KL(q || p) = -∫ q(Z) ln [ p(Z | X) / q(Z) ] dZ
We will also refer to L(q) as the functional F[p, q] later.
Maximizing the lower bound L(q) is equivalent to minimizing the KL divergence.
If we allow any possible choice for q(Z), the maximum of the lower bound occurs when the KL divergence vanishes, i.e., when q(Z) equals the posterior distribution p(Z | X).
The difference between ln p(X) and the ELBO is exactly the KL divergence; ln p(X) is the quantity the ELBO bounds.
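A quick numeric check of the decomposition ln p(X) = L(q) + KL(q || p(Z|X)) on a made-up discrete joint: any valid q gives the same sum, and L(q) alone always stays below ln p(X).

```python
import numpy as np

# Hypothetical joint p(Z, X=x) over 3 latent states (numbers are made up).
p_joint = np.array([0.10, 0.25, 0.15])   # p(Z=k, X=x)
p_x = p_joint.sum()                      # evidence p(X=x) = 0.5
p_post = p_joint / p_x                   # posterior p(Z | X=x)

q = np.array([0.2, 0.5, 0.3])            # any valid q(Z)

elbo = np.sum(q * (np.log(p_joint) - np.log(q)))   # L(q)
kl   = np.sum(q * (np.log(q) - np.log(p_post)))    # KL(q || p(Z|X))

print(elbo + kl, np.log(p_x))   # both equal ln p(X=x); elbo alone is <= ln p(X=x)
```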
Evidence Lower Bound (ELBO)
- L(q) is a lower bound on the log marginal likelihood ln p(X).
- This quantity should increase monotonically with each iteration.
- We maximize the ELBO to find the parameters that give as tight a bound as possible on the marginal likelihood.
- The optimization is generally non-convex, so the ELBO converges to a local optimum.
- Variational inference is closely related to EM.
Factorized distributions q (mean-field variational inference)
Restrict the distributions with a factorization assumption:
  q(Z) = prod_j q_j(Z_j)
Then
  L(q) = ∫ prod_j q_j(Z_j) [ ln p(X, Z) - sum_j ln q_j(Z_j) ] dZ
Coordinate ascent to optimize L(q): first isolate the terms that depend on q_k, obtaining a functional L_k of q_k:
  L_k(q) = ∫ q_k(Z_k) [ ∫ ln p(X, Z) prod_{j≠k} q_j(Z_j) dZ_j ] dZ_k - ∫ q_k(Z_k) ln q_k(Z_k) dZ_k + const
         = ∫ q_k(Z_k) E_{-k}[ ln p(X, Z) ] dZ_k - ∫ q_k(Z_k) ln q_k(Z_k) dZ_k + const
where
  E_{-k}[ ln p(X, Z) ] = ∫ ln p(X, Z) prod_{j≠k} q_j(Z_j) dZ_j
Factorized distributions q: optimization
Introduce a Lagrange multiplier for the normalization constraint and set the functional derivative to zero:
  L(q_k, λ) = L_k(q) + λ ( ∫ q(Z_k) dZ_k - 1 )
  dL / dq(Z_k) = E_{-k}[ ln p(X, Z) ] - ln q(Z_k) - 1 + λ = 0
  ⇒ q*(Z_k) ∝ exp( E_{-k}[ ln p(X, Z) ] )
  equivalently q*(Z_k) ∝ exp( E_{-k}[ ln p(Z_k | Z_{-k}, X) ] )
This formula determines the form of the optimal q(Z_k): we did not specify the form in advance, only the factorization was assumed.
Depending on that form, the optimal q(Z_k) might not be easy to work with; nonetheless, for many models it is easy.
Since we replace the neighboring variables by their mean (expected) values, the method is known as mean field. See the coordinate-ascent sketch below.
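A generic coordinate-ascent sketch for a toy, fully discrete model where ln p(X, Z) is available as a table (the table here is random and purely illustrative; real models exploit structure instead of enumerating a full table).

```python
import numpy as np

m, K = 3, 2                                      # number of latents, states per latent
rng = np.random.default_rng(0)
log_joint = rng.normal(size=(K,) * m)            # stand-in for the table log p(X, Z)

q = [np.full(K, 1.0 / K) for _ in range(m)]      # factorized q(Z) = prod_j q_j(Z_j)

for _ in range(50):                              # coordinate-ascent sweeps
    for j in range(m):
        # Build prod_{i != j} q_i(z_i) as a broadcastable weight tensor.
        weights = np.ones((K,) * m)
        for i in range(m):
            if i != j:
                shape = [1] * m
                shape[i] = K
                weights = weights * q[i].reshape(shape)
        # E_{-j}[log p(X, Z)] as a function of z_j: sum out all other coordinates.
        other_axes = tuple(a for a in range(m) if a != j)
        expected = np.sum(log_joint * weights, axis=other_axes)
        # Mean-field update: q_j*(z_j) proportional to exp(E_{-j}[log p(X, Z)]).
        unnorm = np.exp(expected - expected.max())
        q[j] = unnorm / unnorm.sum()

print([np.round(qj, 3) for qj in q])
```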
Example: Gaussian factorized distribution
Example: Gaussian factorized distribution — Solution
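Since the derivation itself is not reproduced on these slides, here is a sketch of the classic instance: a univariate Gaussian with unknown mean mu and precision tau, conjugate-style priors p(mu | tau) = N(mu0, (lam0 tau)^-1) and p(tau) = Gamma(a0, b0), and a factorized approximation q(mu, tau) = q(mu) q(tau). The prior hyperparameter values and the simulated data below are made up.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=1.5, size=200)     # simulated data (hypothetical)
N, xbar = len(x), x.mean()
mu0, lam0, a0, b0 = 0.0, 1.0, 1e-3, 1e-3         # made-up prior hyperparameters

E_tau = 1.0                                      # initial guess for E_q[tau]
for _ in range(100):
    # q(mu) = N(mu_N, 1 / lam_N)
    mu_N = (lam0 * mu0 + N * xbar) / (lam0 + N)
    lam_N = (lam0 + N) * E_tau
    E_mu, E_mu2 = mu_N, mu_N**2 + 1.0 / lam_N
    # q(tau) = Gamma(a_N, b_N), from q*(tau) ∝ exp(E_mu[ln p(X, mu, tau)])
    a_N = a0 + (N + 1) / 2.0
    b_N = b0 + 0.5 * (np.sum(x**2) - 2 * E_mu * np.sum(x) + N * E_mu2
                      + lam0 * (E_mu2 - 2 * E_mu * mu0 + mu0**2))
    E_tau = a_N / b_N

print(mu_N, 1.0 / E_tau)   # approx posterior mean of mu, and roughly the data variance
```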