Gradient Estimation for Implicit Models with Stein's Method


1. Gradient Estimation for Implicit Models with Stein's Method
Yingzhen Li, Microsoft Research Cambridge
Joint work with Rich Turner, Wenbo Gong, and José Miguel Hernández-Lobato.

2. A little about my research...
[Diagram: current methods placed on a scalability vs. accuracy trade-off — MCMC, Bayesian deep learning, VI + Gaussian, and VI + implicit distributions.]

3. Examples for implicit (generative) models
Implicit distributions:
+ easy to sample: z ∼ q(z|x) ⇔ ε ∼ π(ε), z = f(ε, x)
+ super flexible
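As a concrete illustration, here is a minimal PyTorch sketch of such an implicit distribution (the class name, layer sizes, and noise dimension are all made up for this example): samples are produced by pushing noise ε ∼ π(ε) through a network f(ε, x), and no density q(z|x) is ever written down.

```python
import torch
import torch.nn as nn

class ImplicitPosterior(nn.Module):
    """Implicit q(z|x): easy to sample from, but with no tractable density."""
    def __init__(self, x_dim=10, z_dim=2, noise_dim=5, hidden=32):
        super().__init__()
        self.noise_dim = noise_dim
        self.net = nn.Sequential(
            nn.Linear(x_dim + noise_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, z_dim),
        )

    def sample(self, x, K):
        # z = f(eps, x) with eps ~ pi(eps) = N(0, I)
        eps = torch.randn(K, self.noise_dim)
        return self.net(torch.cat([x.expand(K, -1), eps], dim=-1))

q = ImplicitPosterior()
z = q.sample(torch.randn(1, 10), K=100)      # 100 samples z ~ q(z|x), shape (100, 2)
```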

4. Examples for implicit (generative) models
• Bayesian inference goal: compute E_{p(z|x)}[F(z)]
• Approximate inference: find q(z|x) in some family Q such that q(z|x) ≈ p(z|x)
• At inference time, Monte Carlo integration:
  E_{p(z|x)}[F(z)] ≈ (1/K) Σ_{k=1}^K F(z_k),  z_k ∼ q(z|x)

5. Examples for implicit (generative) models
• Bayesian inference goal: compute E_{p(z|x)}[F(z)]
• Approximate inference: find q(z|x) in some family Q such that q(z|x) ≈ p(z|x)
• At inference time, Monte Carlo integration:
  E_{p(z|x)}[F(z)] ≈ (1/K) Σ_{k=1}^K F(z_k),  z_k ∼ q(z|x)
Tractability requirement: fast sampling from q (no need for point-wise density evaluation)
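Reusing the hypothetical ImplicitPosterior sketched above, the Monte Carlo estimate touches only samples; the test function F below is just a placeholder integrand.

```python
x = torch.randn(1, 10)
z = q.sample(x, K=1000)                      # z_k ~ q(z|x), k = 1, ..., K
F = lambda z: (z ** 2).sum(dim=-1)           # any integrand F(z) of interest
estimate = F(z).mean()                       # (1/K) * sum_k F(z_k)  ≈  E_q[F(z)]
```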

6. Examples for implicit (generative) models
Implicit distributions:
+ easy to sample: z ∼ q(z|x) ⇔ ε ∼ π(ε), z = f(ε, x)
+ super flexible
+ better approximate posterior
Fig. source: Mescheder et al. (2017)

7. Examples for implicit (generative) models
Implicit distributions:
+ easy to sample: z ∼ q(z|x) ⇔ ε ∼ π(ε), z = f(ε, x)
+ super flexible
+ better approximate posterior
− hard to evaluate density; need some tricks for training
Fig. source: Mescheder et al. (2017)

8. Loss approximation vs gradient approximation
To train the implicit generative model p_φ(x), e.g. with the generative adversarial network (GAN) method (Goodfellow et al., 2014):
  min_φ JS[p_D || p_φ] = min_φ max_D E_{p_D}[log D(x)] + E_{p_φ}[log(1 − D(x))]
[Figure: the true loss with its true minimum vs. an approximate loss with its minima.]
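For concreteness, a rough sketch of how this saddle-point objective is typically coded (network sizes and the data batch are placeholders, and the usual practical tricks such as the −log D generator loss are ignored):

```python
import torch
import torch.nn as nn

G = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 2))                # generator p_phi
D = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Linear(32, 1), nn.Sigmoid())  # discriminator

x_data = torch.randn(64, 2)                  # a batch from p_D (placeholder data)
x_gen = G(torch.randn(64, 8))                # a batch from p_phi

# D ascends and G descends the same objective E_pD[log D(x)] + E_pphi[log(1 - D(x))]
loss_D = -(torch.log(D(x_data)) + torch.log(1 - D(x_gen).detach())).mean()
loss_G = torch.log(1 - D(x_gen)).mean()
```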

9. Loss approximation vs gradient approximation
Often we use gradient-based optimisation methods to train machine learning models...
...which only require evaluating the gradient, rather than the loss function!
[Figure: approximating the loss (true loss/true minimum vs. approx. loss/minima) compared with approximating the gradient directly.]

10. Gradient approximation for VI
Variational inference with the q distribution parameterised by φ:
  φ* = arg min_φ KL[q_φ(z|x) || p(z|x)] = arg max_φ L_VI(q_φ)
  L_VI(q_φ) = log p(x) − KL[q_φ(z|x) || p(z|x)]
            = log p(x) − E_q[log (q_φ(z|x) / p(z|x))]
            = E_q[log (p(x, z) / q_φ(z|x))]
            = E_q[log p(x, z)] + H[q_φ(z|x)]
L_VI(q_φ) is also called the variational lower-bound.

11. Gradient approximation for VI
Variational lower-bound: assume z ∼ q_φ ⇔ ε ∼ π(ε), z = f_φ(ε, x)
  L_VI(q_φ) = E_q[log p(x, z)] + H[q_φ(z|x)]
            = E_π[log p(x, f_φ(ε, x))] + H[q_φ(z|x)]   // reparam. trick

12. Gradient approximation for VI
Variational lower-bound: assume z ∼ q_φ ⇔ ε ∼ π(ε), z = f_φ(ε, x)
  L_VI(q_φ) = E_q[log p(x, z)] + H[q_φ(z|x)]
            = E_π[log p(x, f_φ(ε, x))] + H[q_φ(z|x)]   // reparam. trick
If you use gradient descent for optimisation, then you only need gradients!

13. Gradient approximation for VI
Variational lower-bound: assume z ∼ q_φ ⇔ ε ∼ π(ε), z = f_φ(ε, x)
  L_VI(q_φ) = E_q[log p(x, z)] + H[q_φ(z|x)]
            = E_π[log p(x, f_φ(ε, x))] + H[q_φ(z|x)]   // reparam. trick
If you use gradient descent for optimisation, then you only need gradients!
The gradient of the variational lower-bound:
  ∇_φ L_VI(q_φ) = E_π[∇_f log p(x, f_φ(ε, x))^T ∇_φ f_φ(ε, x)] + ∇_φ H[q_φ(z|x)]
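With automatic differentiation, the first (chain-rule) term comes for free by backpropagating through the sampler. A minimal sketch, where the affine f_φ and the log-joint are illustrative stand-ins rather than a real model:

```python
import torch
import torch.nn as nn

# toy reparameterised sampler z = f_phi(eps) (conditioning on x dropped for brevity;
# the affine f_phi and the log-joint below are illustrative placeholders)
f_phi = nn.Linear(5, 2)
x = torch.randn(3)

def log_joint(x, z):                         # stand-in for log p(x, z)
    return -0.5 * (z ** 2).sum(-1) - 0.5 * (x.sum() - z.sum(-1)) ** 2

eps = torch.randn(100, 5)                    # eps ~ pi(eps)
z = f_phi(eps)                               # z_k = f_phi(eps_k)
first_term = log_joint(x, z).mean()          # MC estimate of E_pi[log p(x, f_phi(eps, x))]
# backpropagating first_term through z yields
# E_pi[ grad_f log p(x, f_phi(eps, x))^T grad_phi f_phi(eps, x) ]  automatically
```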

14. Gradient approximation for VI
Variational lower-bound: assume z ∼ q_φ ⇔ ε ∼ π(ε), z = f_φ(ε, x)
  L_VI(q_φ) = E_q[log p(x, z)] + H[q_φ(z|x)]
            = E_π[log p(x, f_φ(ε, x))] + H[q_φ(z|x)]   // reparam. trick
If you use gradient descent for optimisation, then you only need gradients!
The gradient of the variational lower-bound:
  ∇_φ L_VI(q_φ) = E_π[∇_f log p(x, f_φ(ε, x))^T ∇_φ f_φ(ε, x)] + ∇_φ H[q_φ(z|x)]
The gradient of the entropy term:
  ∇_φ H[q_φ(z|x)] = − E_π[∇_f log q(f_φ(ε, x)|x)^T ∇_φ f_φ(ε, x)] − E_q[∇_φ log q_φ(z|x)],
and the second term E_q[∇_φ log q_φ(z|x)] is 0.
It remains to approximate ∇_z log q(z|x)!
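Continuing the sketch above: once an estimate ĝ(z) ≈ ∇_z log q(z|x) at the samples is available (e.g. from the Stein gradient estimator introduced next), the entropy-gradient term can be obtained from a simple surrogate. The random g_hat below is a placeholder for such an estimate.

```python
# g_hat[k] should estimate grad_z log q(z_k | x) at z_k = f_phi(eps_k); random placeholder here,
# to be replaced by the output of a gradient estimator
g_hat = torch.randn_like(z)
entropy_surrogate = -(g_hat.detach() * z).sum(-1).mean()
# grad_phi entropy_surrogate = -E_pi[ g_hat(z)^T grad_phi f_phi(eps, x) ]  ~=  grad_phi H[q_phi(z|x)]

(first_term + entropy_surrogate).backward()  # gradients of the lower-bound estimate w.r.t. phi
```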

15. Stein gradient estimator
Goal: approximate ∇_x log q(x) for a given distribution q(x)

16. Stein gradient estimator
Goal: approximate ∇_x log q(x) for a given distribution q(x)
Stein's identity: define h(x), a (column vector) test function satisfying the boundary condition
  lim_{x→∞} q(x) h(x) = 0.
Then we can derive Stein's identity using integration by parts:
  E_q[h(x) ∇_x log q(x)^T + ∇_x h(x)] = 0

17. Stein gradient estimator
Goal: approximate ∇_x log q(x) for a given distribution q(x)
Stein's identity: define h(x), a (column vector) test function satisfying the boundary condition
  lim_{x→∞} q(x) h(x) = 0.
Then we can derive Stein's identity using integration by parts:
  E_q[h(x) ∇_x log q(x)^T + ∇_x h(x)] = 0
Invert Stein's identity to obtain ∇_x log q(x)!
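A quick numerical sanity check of the identity, for a case where ∇_x log q(x) is known in closed form (standard Gaussian q and an elementwise tanh test function; purely illustrative, not part of the method):

```python
import numpy as np

rng = np.random.default_rng(0)
K, d = 200000, 2
x = rng.standard_normal((K, d))              # x ~ q = N(0, I), so grad_x log q(x) = -x

h = np.tanh(x)                               # vector test function h(x) = tanh(x), elementwise
dh = 1.0 - np.tanh(x) ** 2                   # entries of its (diagonal) Jacobian

outer = np.einsum('ki,kj->ij', h, -x) / K    # E_q[ h(x) grad_x log q(x)^T ]
jac = np.zeros((d, d))
np.fill_diagonal(jac, dh.mean(0))            # E_q[ grad_x h(x) ]  (diagonal Jacobian here)
print(outer + jac)                           # ~ 0, as Stein's identity predicts
```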

18. Stein gradient estimator (kernel based)
Goal: approximate ∇_x log q(x) for a given distribution q(x)
Main idea: invert Stein's identity E_q[h(x) ∇_x log q(x)^T + ∇_x h(x)] = 0.
1. Monte Carlo (MC) approximation to Stein's identity:
  −(1/K) Σ_{k=1}^K h(x^k) ∇_{x^k} log q(x^k)^T + err = (1/K) Σ_{k=1}^K ∇_{x^k} h(x^k),  x^k ∼ q(x)

19. Stein gradient estimator (kernel based)
Goal: approximate ∇_x log q(x) for a given distribution q(x)
Main idea: invert Stein's identity E_q[h(x) ∇_x log q(x)^T + ∇_x h(x)] = 0.
1. Monte Carlo (MC) approximation to Stein's identity:
  −(1/K) Σ_{k=1}^K h(x^k) ∇_{x^k} log q(x^k)^T + err = (1/K) Σ_{k=1}^K ∇_{x^k} h(x^k),  x^k ∼ q(x)
2. Rewrite the MC equations in matrix form: denoting
  ∇_x h := (1/K) Σ_{k=1}^K ∇_{x^k} h(x^k),   H := (h(x^1), ..., h(x^K)),
  G := (∇_{x^1} log q(x^1), ..., ∇_{x^K} log q(x^K))^T,
then
  −(1/K) H G + err = ∇_x h.

20. Stein gradient estimator (kernel based)
Goal: approximate ∇_x log q(x) for a given distribution q(x)
Main idea: invert Stein's identity E_q[h(x) ∇_x log q(x)^T + ∇_x h(x)] = 0.
Matrix form (MC): −(1/K) H G + err = ∇_x h.
3. Now solve a ridge regression problem:
  Ĝ_V^Stein := arg min_{Ĝ ∈ R^{K×d}} ||∇_x h + (1/K) H Ĝ||_F^2 + (η/K^2) ||Ĝ||_F^2

21. Stein gradient estimator (kernel based)
Goal: approximate ∇_x log q(x) for a given distribution q(x)
Main idea: invert Stein's identity E_q[h(x) ∇_x log q(x)^T + ∇_x h(x)] = 0.
Matrix form (MC): −(1/K) H G + err = ∇_x h.
3. Now solve a ridge regression problem:
  Ĝ_V^Stein := arg min_{Ĝ ∈ R^{K×d}} ||∇_x h + (1/K) H Ĝ||_F^2 + (η/K^2) ||Ĝ||_F^2
Analytic solution:
  Ĝ_V^Stein = −(K + η I)^{−1} ⟨∇, K⟩,
with
  K := H^T H,  K_ij = 𝒦(x^i, x^j) := h(x^i)^T h(x^j),
  ⟨∇, K⟩ := K H^T ∇_x h (here K is the number of samples), i.e. ⟨∇, K⟩_ij = Σ_{k=1}^K ∇_{x^k_j} 𝒦(x^i, x^k).
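Putting steps 1–3 together, here is a minimal NumPy sketch of the analytic solution with an RBF kernel (the median-heuristic bandwidth, the value of η, and the Gaussian sanity check are illustrative choices, not the paper's exact settings):

```python
import numpy as np

def stein_gradient_estimator(X, eta=0.01, sigma2=None):
    """G_hat ~= (grad_x log q(x^1), ..., grad_x log q(x^K))^T from samples X of shape (K, d)."""
    K = X.shape[0]
    sq_dist = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)   # ||x^i - x^j||^2
    if sigma2 is None:
        sigma2 = 0.5 * np.median(sq_dist)                             # median heuristic bandwidth
    Kmat = np.exp(-sq_dist / (2.0 * sigma2))                          # K_ij = k(x^i, x^j), RBF

    # <grad, K>_ij = sum_k  d k(x^i, x^k) / d x^k_j,  in closed form for the RBF kernel
    grad_K = (Kmat.sum(axis=1, keepdims=True) * X - Kmat @ X) / sigma2

    # analytic ridge-regression solution:  G_hat = -(K + eta I)^{-1} <grad, K>
    return -np.linalg.solve(Kmat + eta * np.eye(K), grad_K)

# sanity check on a standard Gaussian, where grad_x log q(x) = -x
X = np.random.default_rng(0).standard_normal((100, 2))
print(np.abs(stein_gradient_estimator(X) + X).mean())                 # should be fairly small
```

Note that only samples from q enter the computation, which is what makes the estimator usable with implicit distributions.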

22. Stein gradient estimator (kernel based)
Kernelized Stein Discrepancy (KSD):
  S^2(q, q̂) = E_{x,x′∼q}[ ĝ(x)^T K_{xx′} ĝ(x′) + ĝ(x)^T ∇_{x′} K_{xx′} + ∇_x K_{xx′}^T ĝ(x′) + Tr(∇_{x,x′} K_{xx′}) ],
  g(x) = ∇_x log q(x),  ĝ(x) = ∇_x log q̂(x),  K_{xx′} = 𝒦(x, x′).
One can show that the V-statistic of the KSD is
  S^2_V(q, q̂) = (1/K^2) Tr(Ĝ^T K Ĝ + 2 Ĝ^T ⟨∇, K⟩) + C.
This means
  Ĝ_V^Stein = arg min_{Ĝ ∈ R^{K×d}} S^2_V(q, q̂) + (η/K^2) ||Ĝ||_F^2
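The matrix identity above gives a cheap way to evaluate S^2_V (up to the constant C) for any candidate score matrix G; a small sketch reusing the RBF quantities from the estimator above (the explicit bandwidth argument is an assumed interface):

```python
def ksd_v_up_to_const(X, G, sigma2):
    """V-statistic of the KSD between q (samples X, shape (K, d)) and candidate scores G,
    up to the additive constant C, using the matrix identity above."""
    K = X.shape[0]
    sq_dist = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    Kmat = np.exp(-sq_dist / (2.0 * sigma2))                              # RBF kernel matrix
    grad_K = (Kmat.sum(axis=1, keepdims=True) * X - Kmat @ X) / sigma2    # <grad, K>
    return np.trace(G.T @ Kmat @ G + 2.0 * G.T @ grad_K) / K ** 2
```

Minimising this quantity plus (η/K^2) ||G||_F^2 over G recovers the analytic Ĝ_V^Stein of the previous slide.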

23. Comparisons to existing approaches
[Diagram: a taxonomy of score/gradient estimators along parametric vs. non-parametric and direct vs. indirect axes — the kernelised Stein estimator (our approach), denoising auto-encoders, score matching (a special case; NN-based variants improve sample efficiency), and the KDE plug-in density estimator (indirect).]
KDE plug-in estimator: Singh (1977)
Score matching estimator: Hyvärinen (2005), Sasaki et al. (2014), Strathmann et al. (2015)
Denoising auto-encoder: Vincent et al. (2008), Alain and Bengio (2014)

24. Comparisons to existing approaches
Compare to the denoising auto-encoder (DAE):
• DAE: for x ∼ q(x), denoise x̃ = x + σε back to x (by minimising the ℓ2 loss, ε ∼ N(0, I))
• When σ → 0, DAE*(x) ≈ x + σ^2 ∇_x log q(x)
− unstable estimate: depends on σ
+ functional gradient in RKHS: ||∇ DAE loss||^2_H ∝ KSD
Vincent et al. (2008), Alain and Bengio (2014)
with Wenbo Gong and José Miguel Hernández-Lobato
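A hedged sketch of this DAE route to the same score estimate (architecture, σ, the toy data, and the training length are arbitrary choices for illustration):

```python
import torch
import torch.nn as nn

# estimate grad_x log q(x) with a denoising auto-encoder (Alain and Bengio, 2014)
X = torch.randn(2000, 2) @ torch.tensor([[1.0, 0.0], [0.8, 0.6]])   # samples from some q
sigma = 0.1
dae = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 2))
opt = torch.optim.Adam(dae.parameters(), lr=1e-3)

for _ in range(2000):                                 # min E || dae(x + sigma * eps) - x ||^2
    noisy = X + sigma * torch.randn_like(X)
    loss = ((dae(noisy) - X) ** 2).mean()
    opt.zero_grad(); loss.backward(); opt.step()

with torch.no_grad():
    score_estimate = (dae(X) - X) / sigma ** 2        # ~ grad_x log q(x) for small sigma,
                                                      # but sensitive to the choice of sigma
```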

25. Example: entropy regularised GANs
• Addressing mode collapse: train your generator using entropy regularisation:
  min L_gen(p_gen) − H[p_gen]
• L_gen(p_gen) is the generator loss of your favourite GAN method
• Again the gradient of H[p_gen] is approximated by the gradient estimators
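Combining the pieces, a sketch of one entropy-regularised generator step, reusing G and D from the earlier GAN sketch and stein_gradient_estimator from above (the surrogate mirrors the entropy-gradient slide; all settings are illustrative):

```python
import numpy as np
import torch

x_gen = G(torch.randn(64, 8))                                 # x ~ p_gen
g_hat = torch.from_numpy(
    stein_gradient_estimator(x_gen.detach().numpy())          # ~ grad_x log p_gen(x) at samples
).float()

loss_G = torch.log(1 - D(x_gen)).mean()                       # your favourite GAN generator loss
entropy_surrogate = -(g_hat * x_gen).sum(-1).mean()           # its grad w.r.t. G matches grad H[p_gen]
(loss_G - entropy_surrogate).backward()                       # one step of min L_gen(p_gen) - H[p_gen]
```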
