

  1. Approximate Inference: Variational Inference CMSC 691 UMBC

  2. Goal: Posterior Inference — hyperparameters α, unknown “parameters” Θ, data x. Posterior: p_α(Θ | x). Likelihood model: p(x | Θ)

  3. (Some) Learning Techniques — MAP/MLE (point estimation) and basic EM: what we’ve already covered. Variational Inference (functional optimization): today. Sampling/Monte Carlo: next class

  4. Outline Variational Inference Basic Technique Variational Approximation Example: Topic Models

  5. Variational Inference: Core Idea • Observed data x, latent r.v.s θ • We have some joint model p(θ, x) • We want to compute p(θ | x), but this is computationally difficult

  6. Variational Inference: Core Idea • Observed data x, latent r.v.s θ • We have some joint model p(θ, x) • We want to compute p(θ | x), but this is computationally difficult • Solution: approximate p(θ | x) with a different distribution q_λ(θ) and make q_λ(θ) “close” to p(θ | x)

  7. Variational Inference — p(θ | x): difficult to compute

  8. Variational Inference — p(θ | x): difficult to compute. q(θ): controlled by parameters λ, easy(ier) to compute. Minimize the “difference” by changing λ

  9. Variational Inference — p(θ | x): difficult to compute. q(θ): easy(ier) to compute. Minimize the “difference” by changing λ

  10. Variational Inference: A Gradient-Based Optimization Technique — Set t = 0, pick a starting value λ_t. Until converged: 1. Get value y_t = F(q(•; λ_t)) 2. Get gradient g_t = F′(q(•; λ_t)) 3. Get scaling factor ρ_t 4. Set λ_{t+1} = λ_t + ρ_t · g_t 5. Set t += 1
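The loop on this slide can be sketched directly. The concrete objective F (a simple concave quadratic), its gradient, and the decaying step-size schedule ρ_t below are stand-in assumptions for illustration, not the course's actual objective:

```python
def F(lam):
    # Stand-in objective: a concave quadratic, maximized at lam = 3.
    return -(lam - 3.0) ** 2

def F_grad(lam):
    # Gradient of the stand-in objective.
    return -2.0 * (lam - 3.0)

def gradient_ascent(lam0, tol=1e-8, max_iter=1000):
    lam = lam0
    for t in range(max_iter):
        g_t = F_grad(lam)             # 2. get gradient (value F(lam) is step 1)
        rho_t = 1.0 / (10 + t)        # 3. scaling factor: a decaying step size
        lam_new = lam + rho_t * g_t   # 4. update lambda
        if abs(lam_new - lam) < tol:  # "until converged"
            return lam_new
        lam = lam_new                 # 5. t += 1 via the loop
    return lam

lam_star = gradient_ascent(0.0)
```

With a step size that decays but whose sum diverges, the iterates approach the maximizer; in variational inference, F would be the KL-based objective as a function of λ.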

  12. Variational Inference: The Function to Optimize Any easy-to-compute distribution Posterior of desired model

  13. Variational Inference: The Function to Optimize — any easy-to-compute distribution; posterior of desired model. Find the best distribution (calculus of variations)

  14. Variational Inference: The Function to Optimize Find the best distribution Parameters for desired model

  15. Variational Inference: The Function to Optimize Find the best distribution Parameters for desired model Variational parameters for θ

  16. Variational Inference: The Function to Optimize — KL-Divergence (an expectation). Find the best distribution; parameters for desired model; variational parameters for θ. D_KL[ q(θ) || p(θ | x) ] = E_{q(θ)}[ log (q(θ) / p(θ | x)) ]
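The KL definition above can be sanity-checked in the discrete case, where the expectation is a finite sum; the two distributions below are made-up examples:

```python
import math

def kl_divergence(q, p):
    # D_KL(q || p) = E_q[log q(theta) - log p(theta)] for discrete q, p.
    return sum(qi * (math.log(qi) - math.log(pi))
               for qi, pi in zip(q, p) if qi > 0)

q = [0.5, 0.3, 0.2]   # variational approximation (made-up values)
p = [0.4, 0.4, 0.2]   # "true" posterior (made-up values)

d = kl_divergence(q, p)
```

Note that KL is not symmetric: `kl_divergence(q, p)` and `kl_divergence(p, q)` generally differ, and variational inference specifically minimizes KL(q || p).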

  17. Variational Inference Find the best distribution Parameters for desired model Variational parameters for θ

  18. Exponential Family Recap: “Easy” Posterior Inference — p(θ) is the conjugate prior for the likelihood. Exponential Family Recap: “Easy” Expectations

  19. Variational Inference — Find the best distribution: when p and q have the same exponential family form, the variational update for q(θ) is (often) computable in closed form

  20. Variational Inference: A Gradient-Based Optimization Technique — Set t = 0, pick a starting value λ_t, let F(q(•; λ_t)) = KL[q(•; λ_t) || p(•)]. Until converged: 1. Get value y_t = F(q(•; λ_t)) 2. Get gradient g_t = F′(q(•; λ_t)) 3. Get scaling factor ρ_t 4. Set λ_{t+1} = λ_t + ρ_t · g_t 5. Set t += 1

  21. Variational Inference: Maximization or Minimization?

  22. Evidence Lower Bound (ELBO): log p(x) = log ∫ p(x, θ) dθ

  23. Evidence Lower Bound (ELBO): log p(x) = log ∫ p(x, θ) dθ = log ∫ p(x, θ) (q(θ) / q(θ)) dθ

  24. Evidence Lower Bound (ELBO): log p(x) = log ∫ p(x, θ) dθ = log ∫ p(x, θ) (q(θ) / q(θ)) dθ = log E_{q(θ)}[ p(x, θ) / q(θ) ]

  25. Evidence Lower Bound (ELBO): log p(x) = log ∫ p(x, θ) dθ = log ∫ p(x, θ) (q(θ) / q(θ)) dθ = log E_{q(θ)}[ p(x, θ) / q(θ) ] ≥ E_{q(θ)}[ log p(x, θ) ] − E_{q(θ)}[ log q(θ) ] = ℒ(q)
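The bound can be checked numerically on a tiny conjugate model; the Gaussian model, the observed value, and the variational parameters below are illustrative assumptions, chosen so that the log evidence has a closed form:

```python
import math
import random

random.seed(0)

def log_normal_pdf(z, mean, var):
    return -0.5 * math.log(2 * math.pi * var) - (z - mean) ** 2 / (2 * var)

# Conjugate toy model: theta ~ N(0, 1), x | theta ~ N(theta, 1), observed x = 1,
# so the evidence is p(x) = N(x; 0, 2) exactly.
x = 1.0
log_evidence = log_normal_pdf(x, 0.0, 2.0)

# Variational family q(theta) = N(mu, var); these values are an arbitrary guess.
mu, var = 0.3, 0.6

def elbo_estimate(n=50000):
    # Monte Carlo estimate of E_q[log p(x, theta)] - E_q[log q(theta)].
    total = 0.0
    for _ in range(n):
        theta = random.gauss(mu, math.sqrt(var))
        log_joint = log_normal_pdf(theta, 0.0, 1.0) + log_normal_pdf(x, theta, 1.0)
        total += log_joint - log_normal_pdf(theta, mu, var)
    return total / n

elbo = elbo_estimate()
```

The gap between `log_evidence` and `elbo` is exactly KL(q || p(θ|x)), which is why maximizing the ELBO in λ is the same as minimizing the KL.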

  26. Jensen’s Inequality — for a concave function f: weights α ∈ Δ^{K−1} (the probability simplex), a sequence of points x = x_1, …, x_K, then f(α^T x) ≥ Σ_k α_k f(x_k). For convex f, flip the inequality

  27. Jensen’s Inequality — for a concave function f: f(α^T x) ≥ Σ_k α_k f(x_k). For convex f, flip the inequality: f(α^T x) ≤ Σ_k α_k f(x_k). log is concave, so for variational inference: log E_{q(θ)}[ p(x, θ) / q(θ) ] ≥ E_{q(θ)}[ log (p(x, θ) / q(θ)) ]
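Because log is concave, the log of an average dominates the average of the logs. A quick numeric check, with made-up simplex weights and positive points:

```python
import math

# Weights on the simplex and arbitrary positive points (made-up example).
alpha = [0.2, 0.5, 0.3]
x = [1.0, 4.0, 9.0]

lhs = math.log(sum(a * xi for a, xi in zip(alpha, x)))   # f(alpha^T x) = log E[x]
rhs = sum(a * math.log(xi) for a, xi in zip(alpha, x))   # sum_k alpha_k f(x_k) = E[log x]
```

Here `lhs >= rhs`, which is the direction the ELBO derivation uses when it pushes log inside the expectation.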

  28. Throwback — EM: A Maximization-Maximization Procedure. F(θ, q) = E_q[ log p(x, Z | θ) ] − E_q[ log q(Z) ], where q is any distribution over Z and p(x, Z | θ) is the complete-data joint likelihood according to the “true” model — we’ll see this again with variational inference

  29. Steps 1. Write out the objective

  30. Steps 1. Write out the objective 2. Use basic properties of expectations, logs to simplify and expand it – In general, the objective is just a large sum of individual expectations focused on one or two R.V.s

  31. Steps 1. Write out the objective 2. Use basic properties of expectations, logs to simplify and expand it – In general, the objective is just a large sum of individual expectations focused on one or two R.V.s 3. Simplify each expectation

  32. Steps 1. Write out the objective 2. Use basic properties of expectations, logs to simplify and expand it – In general, the objective is just a large sum of individual expectations focused on one or two R.V.s 3. Simplify each expectation 4. Differentiate the objective wrt the variational parameters

  33. Steps 1. Write out the objective 2. Use basic properties of expectations, logs to simplify and expand it – In general, the objective is just a large sum of individual expectations focused on one or two R.V.s 3. Simplify each expectation 4. Differentiate the objective wrt the variational parameters 5. Optimize based on gradients, with two options 1. Closed form solutions • Can lead to better convergence • May not be possible or worth it to get

  34. Steps 1. Write out the objective 2. Use basic properties of expectations, logs to simplify and expand it – In general, the objective is just a large sum of individual expectations focused on one or two R.V.s 3. Simplify each expectation 4. Differentiate the objective wrt the variational parameters 5. Optimize based on gradients, with two options 1. Closed form solutions • Can lead to better convergence • May not be possible or worth it to get 2. Non-closed form (e.g., Newton-like step required) • Differentiation can be handled automatically • Convergence can be slower
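The five steps above can be followed end-to-end on a toy conjugate model (an illustrative assumption, not from the slides): with θ ~ N(0, 1), x | θ ~ N(θ, 1), and q(θ) = N(m, s2), the expectations in the ELBO simplify in closed form, and setting the derivatives to zero gives option 1, closed-form solutions:

```python
import math

# Toy conjugate model (an assumption): theta ~ N(0, 1), x | theta ~ N(theta, 1),
# observed x = 1, variational family q(theta) = N(m, s2).
x = 1.0

def elbo(m, s2):
    # Steps 2-3: expand and simplify E_q[log p(x, theta)] - E_q[log q(theta)]:
    #   E_q[log p(theta)]    = -0.5*log(2*pi) - 0.5*(m**2 + s2)
    #   E_q[log p(x|theta)]  = -0.5*log(2*pi) - 0.5*((x - m)**2 + s2)
    #   -E_q[log q(theta)]   =  0.5*log(2*pi*e*s2)
    # (constants in m and s2 dropped below)
    return (-0.5 * (m ** 2 + s2)
            - 0.5 * ((x - m) ** 2 + s2)
            + 0.5 * math.log(s2))

# Steps 4-5 (closed form): d/dm = x - 2m = 0 and d/ds2 = -1 + 1/(2*s2) = 0.
m_opt, s2_opt = x / 2.0, 0.5
```

Here the closed-form optimum N(0.5, 0.5) equals the exact posterior, as it should whenever the variational family contains the true posterior.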

  35. Outline Variational Inference Basic Technique Variational Approximation Example: Topic Models

  36. What should q be? Terminology: p /our generative story is our “true” model q approximates the “true” model’s posterior

  37. What should q be? Terminology: p /our generative story is our “true” model q approximates the “true” model’s posterior Therefore… q needs to be a distribution over the latent random variables

  38. What should q be? Terminology: p /our generative story is our “true” model q approximates the “true” model’s posterior Therefore… q needs to be a distribution over the latent random variables q’s precise formulation is task & model dependent q ’s complexity (what (in)dependence assumptions it makes) directly influences the computations

  39. Very common: Mean Field Approximation • Let the observed data be X = {x_1, …, x_N} • Let the latent random variables be Θ = {θ_1, …, θ_M} • Goal: learn q(Θ)

  40. Very common: Mean Field Approximation • Let the observed data be X = {x_1, …, x_N} • Let the latent random variables be Θ = {θ_1, …, θ_M} • Goal: learn q(Θ) • Mean field: regardless of dependencies in the true model, assume all θ_j are independent in the q distribution • Under the q distribution, each θ_j has its own parameters

  41. Very common: Mean Field Approximation • Let the observed data be X = {x_1, …, x_N} • Let the latent random variables be Θ = {θ_1, …, θ_M} • Goal: learn q(Θ) • Mean field: regardless of dependencies in the true model, assume all θ_j are independent in the q distribution • Under the q distribution, each θ_j has its own parameters: θ_j ∼ q(θ_j; λ_j)

  42. Very common: Mean Field Approximation • Let the observed data be X = {x_1, …, x_N} • Let the latent random variables be Θ = {θ_1, …, θ_M} • Goal: learn q(Θ) • Mean field: regardless of dependencies in the true model, assume all θ_j are independent in the q distribution • Under the q distribution, each θ_j has its own parameters: θ_j ∼ q(θ_j; λ_j), so q(Θ) = Π_{j=1}^{M} q(θ_j; λ_j)
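Concretely, a mean-field q is just a product of per-variable factors, so sampling it means sampling each θ_j independently and its log-density is a sum over factors. The Gaussian factors and the parameter values λ_j below are placeholders:

```python
import math
import random

random.seed(0)

# Per-variable variational parameters lambda_j = (mean_j, var_j); placeholder values.
lambdas = [(0.0, 1.0), (2.0, 0.5), (-1.0, 2.0)]

def sample_q():
    # Under mean field, each theta_j is drawn from its own factor q(theta_j; lambda_j).
    return [random.gauss(m, math.sqrt(v)) for m, v in lambdas]

def log_q(thetas):
    # log q(Theta) = sum_j log q(theta_j; lambda_j): the mean-field factorization.
    return sum(-0.5 * math.log(2 * math.pi * v) - (t - m) ** 2 / (2 * v)
               for t, (m, v) in zip(thetas, lambdas))

theta = sample_q()
```

The independence is an assumption about q only; the true posterior p(Θ | X) may correlate the θ_j, which is exactly what the approximation gives up for tractability.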

  43. Some General Guidelines Easiest math occurs when: • Conjugacy exists in the true model • Family distributions in q mimic those in p
