Reparameterization Gradient for Non-differentiable Models
Wonyeol Lee, Hangyeol Yu, Hongseok Yang (KAIST)
Published at NeurIPS 2018
Background
Posterior inference
• Latent variable z ∈ ℝⁿ.
• Observed variable x ∈ ℝᵐ.
• Joint density p(x, z).
• Want to infer the posterior p(z|x₀) given a particular value x₀ of x.
Variational inference
1. Fix a family of variational distributions {q_θ(z)}_θ (differentiable & easy to sample).
2. Find q_θ(z) that approximates p(z|x₀) well.
• Typically, by solving argmax_θ (ELBO_θ), where
  ELBO_θ = 𝔼_{q_θ(z)}[ log( p(x₀, z) / q_θ(z) ) ].
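The ELBO above can be estimated by plain Monte Carlo. Below is a minimal Python sketch; the toy model p(z) = 𝒩(0,1), p(x|z) = 𝒩(x|z,1) with observation x₀ = 1.0, the variational family q_θ(z) = 𝒩(θ,1), and all function names are illustrative assumptions, not from the slides.

```python
import math
import random

def log_normal(x, mu, sigma):
    # log density of N(x | mu, sigma^2)
    return -0.5 * math.log(2 * math.pi * sigma ** 2) - (x - mu) ** 2 / (2 * sigma ** 2)

def elbo_estimate(theta, x0=1.0, n=100_000, seed=0):
    # Monte Carlo estimate of ELBO_theta = E_{q_theta(z)}[ log(p(x0, z) / q_theta(z)) ]
    # for the toy model p(z) = N(0,1), p(x|z) = N(x|z,1), q_theta(z) = N(theta,1).
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        z = rng.gauss(theta, 1.0)                      # z ~ q_theta
        log_joint = log_normal(z, 0.0, 1.0) + log_normal(x0, z, 1.0)
        total += log_joint - log_normal(z, theta, 1.0)
    return total / n
```

For x₀ = 1 the exact posterior is 𝒩(0.5, 0.5), so the estimate peaks near θ = 0.5.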
Gradient ascent
θ_{n+1} = θ_n + η × ∇_θ ELBO_θ |_{θ=θ_n}
• Difficult to compute ∇_θ ELBO_θ.
• Use an estimated gradient instead.
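The update rule above is ordinary gradient ascent with a plug-in gradient estimator. A minimal Python sketch, where `grad_fn` stands for any (possibly noisy) gradient estimator; the names and the quadratic test objective are illustrative assumptions:

```python
def gradient_ascent(grad_fn, theta0, eta=0.1, steps=100):
    # theta_{n+1} = theta_n + eta * (estimated gradient at theta_n)
    theta = theta0
    for _ in range(steps):
        theta = theta + eta * grad_fn(theta)
    return theta

# Usage: maximize f(theta) = -(theta - 3)^2, whose exact gradient is -2*(theta - 3).
theta_star = gradient_ascent(lambda t: -2.0 * (t - 3.0), theta0=0.0)
```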
Reparameterization estimator
• Works if p(x₀, z) is differentiable wrt. z.
• Need a distribution q(ε) & a smooth function f_θ(ε) s.t. f_θ(ε) for ε ~ q(ε) has the distribution q_θ(z).
• Derived from the equation:
  ∇_θ ELBO_θ = ∇_θ 𝔼_{q_θ(z)}[ .. z .. z .. ]
             = ∇_θ 𝔼_{q(ε)}[ .. f_θ(ε) .. f_θ(ε) .. ]
             = 𝔼_{q(ε)}[ ∇_θ ( .. f_θ(ε) .. f_θ(ε) .. ) ]
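For a concrete, fully differentiable instance (my own illustration, not the slides' example), take q_θ(z) = 𝒩(θ, 1) reparameterized as f_θ(ε) = θ + ε with ε ~ 𝒩(0,1), and the toy model p(z) = 𝒩(0,1), p(x|z) = 𝒩(x|z,1) with x₀ = 1. Since log q_θ(θ+ε) = −ε²/2 + const has no θ-dependence, only the log-joint path term survives, with d/dz log p(x₀, z) = −z + (x₀ − z):

```python
import random

def reparam_grad(theta, x0=1.0, n=100_000, seed=0):
    # Reparameterization estimate of d/dtheta ELBO_theta for
    # p(z) = N(0,1), p(x|z) = N(x|z,1), q_theta(z) = N(theta,1), z = theta + eps.
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        eps = rng.gauss(0.0, 1.0)
        z = theta + eps                  # z = f_theta(eps)
        # chain rule: d/dtheta log p(x0, z) = [d/dz log p(x0, z)] * dz/dtheta, dz/dtheta = 1
        total += -z + (x0 - z)
    return total / n
```

The exact gradient here is x₀ − 2θ, so the estimate vanishes near the optimum θ = x₀/2.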
Non-differentiable models from probabilistic programming
(let [z (sample (normal 0 1))]
  (if (> z 0)
    (observe (normal 3 1) 0)
    (observe (normal -2 1) 0))
  z)
r₁(z) = 𝒩(z|0,1) · 𝒩(x=0|3,1)
r₂(z) = 𝒩(z|0,1) · 𝒩(x=0|-2,1)
p(z, x=0) = [z>0] r₁(z) + [z≤0] r₂(z)

[Figure: the density p(z, x=0) plotted against z; it is discontinuous at z = 0.]
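The density above can be transcribed directly into code. A Python sketch (function names are my own):

```python
import math

def normal_pdf(x, mu, sigma):
    # density of N(x | mu, sigma^2)
    return math.exp(-(x - mu) ** 2 / (2 * sigma ** 2)) / (sigma * math.sqrt(2 * math.pi))

def joint_density(z):
    # p(z, x=0) = [z>0] r1(z) + [z<=0] r2(z), where
    # r1(z) = N(z|0,1) * N(0|3,1) and r2(z) = N(z|0,1) * N(0|-2,1).
    prior = normal_pdf(z, 0.0, 1.0)
    if z > 0:
        return prior * normal_pdf(0.0, 3.0, 1.0)    # branch: (observe (normal 3 1) 0)
    else:
        return prior * normal_pdf(0.0, -2.0, 1.0)   # branch: (observe (normal -2 1) 0)
```

The density jumps at z = 0 because 𝒩(0|-2,1) > 𝒩(0|3,1); that discontinuity is what makes the model non-differentiable.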
The program is approximately reparameterized as:

(let [z (sample (normal 0 1))]
  (if (> z 0)
    (observe (normal 3 1) 0)
    (observe (normal -2 1) 0))
  z)
      ≈
(let [ε (sample (normal 0 1))
      z (+ ε θ)]
  z)

q(ε) = 𝒩(ε|0,1),  z = ε + θ

How to find a good θ? By gradient ascent on ELBO_θ:
θ_{n+1} ← θ_n + η × ∇_θ ELBO_θ |_{θ=θ_n}
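Applying the reparameterization estimator naively to this model looks as follows (a sketch with illustrative names). With z = θ + ε, the branch terms [z>0]·log 𝒩(0|3,1) and [z≤0]·log 𝒩(0|-2,1) are piecewise constant in z, so differentiating the integrand pathwise sees only d/dz log 𝒩(z|0,1) = −z and silently drops the jump at z = 0, precisely the boundary effect that a correct gradient for a non-differentiable model must account for:

```python
import random

def naive_reparam_grad(theta, n=100_000, seed=0):
    # Naive pathwise gradient with z = theta + eps, eps ~ N(0,1):
    # log p(z, x=0) = log N(z|0,1) + (piecewise-constant branch term),
    # and log q_theta(theta + eps) has no theta-dependence, so the per-sample
    # pathwise derivative is just d/dz log N(z|0,1) = -z.
    # The jump of the branch term at z = 0 is ignored, which biases the estimator.
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        z = theta + rng.gauss(0.0, 1.0)
        total += -z
    return total / n
```

In expectation this naive estimate equals −θ, regardless of the two observe branches.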