  1. Reparameterization Gradient for Non-differentiable Models. Wonyeol Lee, Hangyeol Yu, Hongseok Yang (KAIST). Published at NeurIPS 2018.

  5. Background

  6. Posterior inference
     • Latent variable $z \in \mathbb{R}^n$.
     • Observed variable $x \in \mathbb{R}^m$.
     • Joint density $p(x, z)$.
     • Want to infer the posterior $p(z \mid x_0)$ given a particular value $x_0$ of $x$.
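For a one-dimensional model the posterior can be computed straight from its definition, $p(z \mid x_0) = p(x_0, z)/p(x_0)$ with $p(x_0) = \int p(x_0, z)\,dz$. The sketch below does this on a grid in Python. It assumes a toy conjugate model ($z \sim \mathcal{N}(0,1)$, $x \mid z \sim \mathcal{N}(z,1)$, $n = m = 1$) that does not come from the slides, and the helper names are hypothetical. Treat it as a sanity check only: grid integration does not scale to large $n$, which is one reason approximate methods such as variational inference are used.

```python
import math

def log_normal(x, mu, sigma):
    """Log density of N(mu, sigma^2) at x."""
    return -0.5 * math.log(2 * math.pi * sigma**2) - 0.5 * ((x - mu) / sigma) ** 2

# Assumed toy model (not from the slides): z ~ N(0, 1), x | z ~ N(z, 1).
def log_joint(x, z):
    return log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)

def grid_posterior(x0, lo=-10.0, hi=10.0, num=20001):
    """Approximate p(z | x0) = p(x0, z) / p(x0) by normalizing the joint on a grid."""
    step = (hi - lo) / (num - 1)
    zs = [lo + i * step for i in range(num)]
    joint = [math.exp(log_joint(x0, z)) for z in zs]
    evidence = sum(joint) * step          # p(x0) = integral of p(x0, z) dz (Riemann sum)
    post = [p / evidence for p in joint]  # posterior density values on the grid
    return zs, post, evidence

zs, post, evidence = grid_posterior(x0=1.0)
post_mean = sum(z * p for z, p in zip(zs, post)) * (zs[1] - zs[0])
print(post_mean)  # conjugate answer: x0 / 2
print(evidence)   # exact value: N(x0 | 0, 2)
```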

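  7. Variational inference
     1. Fix a family of variational distributions $\{q_\theta(z)\}_\theta$.
     2. Find a $q_\theta(z)$ that approximates $p(z \mid x_0)$ well.
        • Typically, by solving $\operatorname{argmax}_\theta \mathrm{ELBO}_\theta$, where $\mathrm{ELBO}_\theta = \mathbb{E}_{q_\theta(z)}\big[\log\big(p(x_0, z)/q_\theta(z)\big)\big]$.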

  8. The variational family $\{q_\theta(z)\}_\theta$ is chosen to be differentiable and easy to sample from.
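The ELBO is an expectation under $q_\theta$, so drawing samples from $q_\theta$ gives a Monte Carlo estimate of it. This is a minimal sketch. It assumes the same toy Gaussian model as a stand-in for $p(x_0, z)$ and a Gaussian family $q_\theta = \mathcal{N}(\mu, \sigma^2)$ with $\theta = (\mu, \sigma)$. Both are illustrative choices, and `elbo_estimate` and the other names are hypothetical. When $q_\theta$ is the exact posterior, $\log(p(x_0,z)/q_\theta(z)) = \log p(x_0)$ for every $z$, so the estimate equals the log evidence with no Monte Carlo noise.

```python
import math
import random

def log_normal(x, mu, sigma):
    return -0.5 * math.log(2 * math.pi * sigma**2) - 0.5 * ((x - mu) / sigma) ** 2

# Assumed toy model: z ~ N(0, 1), x | z ~ N(z, 1); x0 is the observed value.
def log_joint(x0, z):
    return log_normal(z, 0.0, 1.0) + log_normal(x0, z, 1.0)

def elbo_estimate(x0, mu, sigma, num_samples=10_000, rng=random):
    """Monte Carlo estimate of E_{q_theta(z)}[log p(x0, z) - log q_theta(z)]
    for q_theta = N(mu, sigma^2)."""
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(mu, sigma)  # q_theta is easy to sample from
        total += log_joint(x0, z) - log_normal(z, mu, sigma)
    return total / num_samples

x0 = 1.0
log_evidence = log_normal(x0, 0.0, math.sqrt(2.0))  # log p(x0), since x ~ N(0, 2)
rng = random.Random(0)
# q_theta equal to the exact posterior N(x0/2, 1/2): every term equals log p(x0).
best = elbo_estimate(x0, x0 / 2, math.sqrt(0.5), rng=rng)
# A poorly placed q_theta.
worse = elbo_estimate(x0, -1.0, 1.0, rng=rng)
print(best, log_evidence)
print(worse < best)
```

Any other $q_\theta$ gives a smaller value, because $\mathrm{ELBO}_\theta = \log p(x_0) - \mathrm{KL}\big(q_\theta(z)\,\|\,p(z \mid x_0)\big)$.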

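  11. Gradient ascent
     $\theta_{n+1} = \theta_n + \eta \times \nabla_\theta \mathrm{ELBO}_\theta \big|_{\theta=\theta_n}$
     • It is difficult to compute $\nabla_\theta \mathrm{ELBO}_\theta$ exactly, since it is the gradient of an expectation over $q_\theta(z)$.
     • Use an estimated gradient instead.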
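The update rule only needs an estimate of the gradient at each step. The sketch below runs the update with a hypothetical gradient oracle. The oracle returns the true gradient of a toy concave objective plus zero-mean Gaussian noise, and stands in for a stochastic ELBO gradient. The objective, `eta`, and the step count are illustrative assumptions.

```python
import random

def gradient_ascent(grad_estimate, theta0, eta, num_steps):
    """theta_{n+1} = theta_n + eta * (estimated gradient at theta_n)."""
    theta = theta0
    for _ in range(num_steps):
        theta = theta + eta * grad_estimate(theta)
    return theta

# Hypothetical objective f(theta) = -(theta - 2)^2 / 2, whose gradient is 2 - theta.
# The exact gradient is treated as unavailable; only a noisy, unbiased version is seen.
rng = random.Random(0)
def noisy_grad(theta):
    return (2.0 - theta) + rng.gauss(0.0, 1.0)

theta = gradient_ascent(noisy_grad, theta0=-5.0, eta=0.01, num_steps=3000)
print(theta)  # near the maximizer 2.0, up to residual noise
```

With a constant step size the iterates keep fluctuating around the maximizer. A smaller `eta` shrinks the fluctuation but slows convergence.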

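  14. Reparameterization estimator
     • Works if $p(x_0, z)$ is differentiable with respect to $z$.
     • Needs a distribution $q(\epsilon)$ and a smooth function $f_\theta(\epsilon)$ such that $f_\theta(\epsilon)$, for $\epsilon \sim q(\epsilon)$, has the distribution $q_\theta(z)$.
     • Derived from the equation $\nabla_\theta \mathrm{ELBO}_\theta = \mathbb{E}_{q(\epsilon)}\big[\nabla_\theta(\cdots f_\theta(\epsilon) \cdots f_\theta(\epsilon) \cdots)\big]$.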
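A common choice, assumed here and not taken from the slides, is the Gaussian location-scale family. It uses $q(\epsilon) = \mathcal{N}(0, 1)$ and $f_\theta(\epsilon) = \mu + \sigma\epsilon$ with $\theta = (\mu, \sigma)$, which makes $f_\theta(\epsilon) \sim \mathcal{N}(\mu, \sigma^2)$. The sketch checks this empirically. The function `f` and the parameter values are hypothetical.

```python
import random
import statistics

def f(theta, eps):
    """Smooth reparameterization map f_theta(eps) = mu + sigma * eps, with theta = (mu, sigma)."""
    mu, sigma = theta
    return mu + sigma * eps

rng = random.Random(1)
theta = (1.5, 0.5)
# eps ~ q(eps) = N(0, 1); pushing it through f_theta should give samples from N(mu, sigma^2).
zs = [f(theta, rng.gauss(0.0, 1.0)) for _ in range(100_000)]
print(statistics.fmean(zs), statistics.stdev(zs))  # approximately mu and sigma
```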

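  16. Reparameterization estimator: derivation, where $\cdots z \cdots z \cdots$ abbreviates the integrand $\log\big(p(x_0, z)/q_\theta(z)\big)$:
$$
\begin{aligned}
\nabla_\theta \mathrm{ELBO}_\theta
  &= \nabla_\theta\, \mathbb{E}_{q_\theta(z)}[\cdots z \cdots z \cdots] \\
  &= \nabla_\theta\, \mathbb{E}_{q(\epsilon)}[\cdots f_\theta(\epsilon) \cdots f_\theta(\epsilon) \cdots] \\
  &= \mathbb{E}_{q(\epsilon)}\big[\nabla_\theta(\cdots f_\theta(\epsilon) \cdots f_\theta(\epsilon) \cdots)\big]
\end{aligned}
$$
     The second step substitutes $z = f_\theta(\epsilon)$. The third exchanges $\nabla_\theta$ and $\mathbb{E}_{q(\epsilon)}$, and this exchange is where differentiability of $p(x_0, z)$ in $z$ is required.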
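The last line of the derivation suggests the estimator: draw $\epsilon \sim q(\epsilon)$, set $z = f_\theta(\epsilon)$, and average the gradient of the integrand. This is a minimal sketch for a model where $p(x_0, z)$ is differentiable in $z$. It assumes the toy model $\log p(x_0, z) = \log\mathcal{N}(z \mid 0,1) + \log\mathcal{N}(x_0 \mid z, 1)$ and $q_\theta = \mathcal{N}(\theta, 1)$ with $f_\theta(\epsilon) = \epsilon + \theta$. The gradient is written by hand instead of using an autodiff library, and all names are hypothetical. Here $\log q_\theta(f_\theta(\epsilon)) = \log\mathcal{N}(\epsilon \mid 0,1)$ does not depend on $\theta$, so only the log-joint contributes.

```python
import random

# Assumed toy model: log p(x0, z) = log N(z | 0, 1) + log N(x0 | z, 1), differentiable in z.
def dlogjoint_dz(x0, z):
    # d/dz [ -z^2/2 - (x0 - z)^2/2 ] = -z + (x0 - z)
    return -z + (x0 - z)

def reparam_grad(x0, theta, num_samples, rng):
    """Estimate d/dtheta ELBO_theta for q_theta = N(theta, 1) using z = f_theta(eps) = eps + theta.
    The -log q_theta term equals -log N(eps | 0, 1), which has zero theta-gradient."""
    total = 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)
        z = eps + theta
        total += dlogjoint_dz(x0, z) * 1.0  # chain rule: dz/dtheta = 1
    return total / num_samples

rng = random.Random(2)
x0, theta = 1.0, -0.3
est = reparam_grad(x0, theta, 100_000, rng)
exact = x0 - 2 * theta  # E[-z + (x0 - z)] with z ~ N(theta, 1)
print(est, exact)
```

For this model the exact gradient is $x_0 - 2\theta$, and the estimate matches it up to Monte Carlo error.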

  20. Non-differentiable models from probabilistic programming

  21. A program whose density is non-differentiable:

        (let [z (sample (normal 0 1))]
          (if (> z 0)
            (observe (normal 3 1) 0)
            (observe (normal -2 1) 0))
          z)
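One standard reading of `sample` and `observe` is likelihood weighting. This reading is assumed here for illustration and is not taken from the slides: `sample` draws from the prior, and `observe` multiplies the run's weight by the likelihood of the observed value. The Python sketch below translates the program on slide 21 this way. It then estimates $\mathbb{E}[z \mid x = 0]$ by self-normalized importance sampling and compares the result against a closed form derived from the two branches. `run_program` and `posterior_mean` are hypothetical names.

```python
import math
import random

def log_normal(x, mu, sigma):
    return -0.5 * math.log(2 * math.pi * sigma**2) - 0.5 * ((x - mu) / sigma) ** 2

def run_program(rng):
    """One weighted execution of the program: returns (z, log_weight)."""
    z = rng.gauss(0.0, 1.0)                 # (sample (normal 0 1))
    if z > 0:
        log_w = log_normal(0.0, 3.0, 1.0)   # (observe (normal 3 1) 0)
    else:
        log_w = log_normal(0.0, -2.0, 1.0)  # (observe (normal -2 1) 0)
    return z, log_w

def posterior_mean(num_runs, rng):
    """Self-normalized importance sampling (likelihood weighting) estimate of E[z | x = 0]."""
    runs = [run_program(rng) for _ in range(num_runs)]
    weights = [math.exp(lw) for _, lw in runs]
    return sum(w * z for w, (z, _) in zip(weights, runs)) / sum(weights)

# Closed form for comparison: with A = N(0|3,1), B = N(0|-2,1), and E[z * [z > 0]] = 1/sqrt(2*pi)
# under N(0, 1), the posterior mean is 2 (A - B) / (sqrt(2*pi) * (A + B)).
A = math.exp(log_normal(0.0, 3.0, 1.0))
B = math.exp(log_normal(0.0, -2.0, 1.0))
exact = 2 * (A - B) / (math.sqrt(2 * math.pi) * (A + B))
est = posterior_mean(200_000, random.Random(3))
print(est, exact)
```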

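  26. Joint density of the program on slide 21 at the observation $x = 0$:
     $r_1(z) = \mathcal{N}(z \mid 0, 1)\,\mathcal{N}(x{=}0 \mid 3, 1)$
     $r_2(z) = \mathcal{N}(z \mid 0, 1)\,\mathcal{N}(x{=}0 \mid -2, 1)$
     $p(z, x{=}0) = [z > 0]\, r_1(z) + [z \le 0]\, r_2(z)$
     Here $[\cdot]$ is the indicator: 1 if the condition holds, 0 otherwise. Because $r_1(0) \ne r_2(0)$, $p(z, x{=}0)$ jumps at $z = 0$ and is not differentiable there.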
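Evaluating the densities directly shows the discontinuity. This minimal sketch (helper names hypothetical) codes $r_1$, $r_2$ and $p(z, x{=}0)$, then compares $p$ just to the right of $z = 0$ with its value at $z = 0$. The ratio is $\mathcal{N}(0 \mid 3,1)/\mathcal{N}(0 \mid -2,1) = e^{-4.5}/e^{-2} = e^{-2.5}$.

```python
import math

def normal_pdf(x, mu, sigma):
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))

def r1(z):
    return normal_pdf(z, 0.0, 1.0) * normal_pdf(0.0, 3.0, 1.0)

def r2(z):
    return normal_pdf(z, 0.0, 1.0) * normal_pdf(0.0, -2.0, 1.0)

def p(z):
    """p(z, x=0) = [z > 0] r1(z) + [z <= 0] r2(z), where [.] is the 0/1 indicator."""
    return r1(z) if z > 0 else r2(z)

# r1 and r2 are each smooth, but they disagree at z = 0, so p jumps there.
left, right = p(0.0), p(1e-12)
print(right / left)  # about exp(-2.5), i.e. roughly 0.082
```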

  29. [Figure: plot of $p(z, x{=}0)$ as a function of $z$]

  30. The posterior of the program on slide 21 is approximated (≈) by a variational program with parameter θ:

        (let [ε (sample (normal 0 1))
              z (+ ε θ)]
          z)

  34. The variational program uses $q(\epsilon) = \mathcal{N}(\epsilon \mid 0, 1)$ and $z = \epsilon + \theta$, i.e. $f_\theta(\epsilon) = \epsilon + \theta$, so that $q_\theta(z) = \mathcal{N}(z \mid \theta, 1)$.
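Slide 14 requires $p(x_0, z)$ to be differentiable in $z$, and the program on slide 21 violates this at $z = 0$. The worked example below is an illustration under stated assumptions, not taken from the slides. It applies the reparameterization estimator anyway, with $q_\theta = \mathcal{N}(\theta, 1)$ and $z = \epsilon + \theta$ as on slide 34.

Inside either branch the `observe` term is constant in $z$. Differentiating pointwise therefore gives $-z$, and the estimator averages to $-\theta$. The exact ELBO, however, contains $\Phi(\theta)\log\mathcal{N}(0 \mid 3,1) + (1 - \Phi(\theta))\log\mathcal{N}(0 \mid -2,1)$, where $\Phi$ is the standard normal CDF. Its gradient is therefore $-\theta + \varphi(\theta)(-4.5 + 2) = -\theta - 2.5\,\varphi(\theta)$, with $\varphi$ the standard normal density. The extra term comes from probability mass crossing the boundary $z = 0$, which the pointwise derivative cannot see. The sketch checks the closed form by finite differences. All names are hypothetical.

```python
import math
import random

LOG_A = -0.5 * math.log(2 * math.pi) - 4.5  # log N(0 | 3, 1), branch z > 0
LOG_B = -0.5 * math.log(2 * math.pi) - 2.0  # log N(0 | -2, 1), branch z <= 0

def Phi(t):
    return 0.5 * (1.0 + math.erf(t / math.sqrt(2.0)))

def phi(t):
    return math.exp(-0.5 * t * t) / math.sqrt(2 * math.pi)

def exact_elbo(theta):
    """Closed-form ELBO for q_theta = N(theta, 1) against p(z, x=0); P(z > 0) = Phi(theta)."""
    entropy = 0.5 * math.log(2 * math.pi * math.e)                   # -E[log q_theta(z)]
    e_log_prior = -0.5 * math.log(2 * math.pi) - 0.5 * (theta**2 + 1.0)
    e_log_lik = Phi(theta) * LOG_A + (1.0 - Phi(theta)) * LOG_B
    return e_log_prior + e_log_lik + entropy

def exact_grad(theta):
    # -theta from the smooth prior term, plus a boundary term from the jump at z = 0
    return -theta + phi(theta) * (LOG_A - LOG_B)

def naive_reparam_grad(theta, num_samples, rng):
    """Differentiate log p(eps + theta, x=0) pointwise in theta. Within either branch the
    observe term is constant in z, so only d/dz log N(z | 0, 1) = -z survives."""
    return sum(-(rng.gauss(0.0, 1.0) + theta) for _ in range(num_samples)) / num_samples

theta = 0.0
h = 1e-5
fd = (exact_elbo(theta + h) - exact_elbo(theta - h)) / (2 * h)
naive = naive_reparam_grad(theta, 100_000, random.Random(4))
print(fd, exact_grad(theta), naive)  # fd matches exact_grad; naive is near 0
```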

  36. How to find a good θ? By gradient ascent on $\mathrm{ELBO}_\theta$:
     $\theta_{n+1} \leftarrow \theta_n + \eta \times \nabla_\theta \mathrm{ELBO}_\theta \big|_{\theta=\theta_n}$
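For the running example with $q_\theta = \mathcal{N}(\theta,1)$, the exact gradient works out to $-\theta - 2.5\,\varphi(\theta)$, where $\varphi$ is the standard normal density. The $-2.5 = \log\mathcal{N}(0\mid 3,1) - \log\mathcal{N}(0 \mid -2,1)$ factor comes from $P(z>0) = \Phi(\theta)$ changing with $\theta$. Differentiating $\log p(\epsilon+\theta, x{=}0)$ pointwise inside each branch instead gives an estimator whose mean is $-\theta$, because the jump at $z = 0$ contributes nothing to the pointwise derivative. The sketch runs the update on slide 36 with each gradient. The step size, starting point, and sample counts are illustrative assumptions. The exact gradient is available only because this toy ELBO has a closed form.

```python
import math
import random

LOG_A_MINUS_LOG_B = -2.5  # log N(0 | 3, 1) - log N(0 | -2, 1)

def phi(t):
    return math.exp(-0.5 * t * t) / math.sqrt(2 * math.pi)

def exact_grad(theta):
    return -theta + phi(theta) * LOG_A_MINUS_LOG_B

def naive_grad(theta, rng, num_samples=100):
    # Pointwise reparameterization estimate; ignores the jump at z = 0.
    return sum(-(rng.gauss(0.0, 1.0) + theta) for _ in range(num_samples)) / num_samples

def ascend(grad, theta0=2.0, eta=0.05, num_steps=2000):
    theta = theta0
    for _ in range(num_steps):
        theta = theta + eta * grad(theta)  # theta_{n+1} <- theta_n + eta * gradient
    return theta

rng = random.Random(5)
theta_exact = ascend(exact_grad)
theta_naive = ascend(lambda t: naive_grad(t, rng))
print(theta_exact, theta_naive)  # about -0.75 vs about 0.0
```

The exact-gradient run settles at the root of $\theta + 2.5\,\varphi(\theta) = 0$, about $-0.75$. The naive run hovers around $0$, which is the optimum of the prior term alone.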
