Advanced inference in probabilistic programs
Brooks Paige
Inference thus far
• Likelihood weighting / importance sampling
• MCMC (single-dimension, coded by hand)
• "Lightweight" Metropolis-Hastings (update one random choice at a time, by re-running the remainder of the program)
Inference: this talk
How can we make inference more computationally efficient?
• Sequential Monte Carlo uses importance sampling as a building block for an inference algorithm that can succeed in models with higher-dimensional latent spaces
• Algorithms which extend SMC: particle MCMC, and asynchronous SMC
• What sort of proposal distributions should we be simulating from in these methods? Can we learn importance sampling proposals automatically?
Inference in Anglican
(doquery :algorithm model [args] options)
• How do you implement an inference algorithm in Anglican? (JW will show you this afternoon)
• Two important special forms are the interface between model code and inference code:
(sample ...)
(observe ...)
• Q: what kinds of inference algorithms can we develop and implement using this interface?
Incremental evidence
• If we can write our programs in such a way that we see early, incremental evidence, then we can use more efficient inference algorithms.
• Intuition: sample statements which come after observe statements can be informed by the data.

(defquery monolithic-observe []
  ... ;; many sample statements
  (sample ...)
  (sample ...)
  (sample ...)
  ... ;; single observe / conditioning
  ;; statement at the end
  (observe ...))

(defquery incremental-observe []
  (loop ...
    ;; interleaved sample and
    ;; observe statements
    (sample ...)
    (observe ...)
    (recur ...)))
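As a rough illustration in Python rather than Anglican (the random-walk model, Gaussian densities, and function names here are assumed for the example), a trace with interleaved observes can report its accumulated log-weight after every data point, giving an inference engine feedback mid-run instead of only at the end:

```python
import math
import random

def normal_logpdf(x, mu, sigma):
    # Log-density of a univariate normal, used to score observe statements.
    return -0.5 * math.log(2 * math.pi * sigma**2) - (x - mu)**2 / (2 * sigma**2)

def incremental_trace(data, sigma=1.0):
    """Run one trace of a toy random-walk model, yielding the accumulated
    log-weight after each interleaved observe statement."""
    x, log_w = 0.0, 0.0
    for y in data:
        x = random.gauss(x, 1.0)             # (sample ...): latent transition
        log_w += normal_logpdf(y, x, sigma)  # (observe ...): score the data point
        yield x, log_w                       # feedback is available *now*

random.seed(0)
for x, log_w in incremental_trace([0.5, 1.0, 1.5]):
    print(x, log_w)
```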
Hidden Markov model
[Graphical model: latent chain x_0 → x_1 → x_2 → x_3 → ···, with observations y_1, y_2, y_3]
Place a massive observe statement at the end
No "feedback" until all random variables have been sampled
Hidden Markov model
[Graphical model: latent chain x_0 → x_1 → x_2 → x_3 → ···, with observations y_1, y_2, y_3]
Place observe statements as early as possible
Does y_1 have high probability given x_0 and x_1?
Does y_2 have high probability given x_0, x_1, and x_2?
Incremental evidence == computational efficiency?
Incremental evidence
• Many models and settings are naturally written incrementally!
‣ Canonical example: time series models (observe at discrete timesteps)
‣ Planning problems (observe at discrete timesteps)
‣ Models which factor into global and "local" (per-datapoint) observes, such as mixture models and many multilevel Bayesian models
‣ Models such as image synthesis, where the entire "canvas" is always visible and can be evaluated according to a fitness function at any time
State-space models
[Figure: latent trajectory over "time" (n) and "space" (x)]
• Running example: inference in state-space models
• Observed data $y_n$ and latent state $x_n$
• Inference goals: estimate latent state; predict future data; estimate marginal likelihood
$$p(x_{0:N}, y_{0:N}) = \prod_{n=0}^{N} g(y_n \mid x_{0:n})\, f(x_n \mid x_{0:n-1})$$
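The factorization above can be evaluated directly. A minimal sketch in Python, assuming a concrete Gaussian random-walk instance of $f$ and $g$ (Markov, so $g$ depends only on the current state; the model choice and function names are this example's, not the slides'):

```python
import math

def normal_logpdf(x, mu, sigma):
    # Log-density of a univariate normal distribution.
    return -0.5 * math.log(2 * math.pi * sigma**2) - (x - mu)**2 / (2 * sigma**2)

def joint_logdensity(xs, ys, sigma_x=1.0, sigma_y=0.5):
    """log p(x_{0:N}, y_{0:N}) as a sum of transition terms f(x_n | x_{n-1})
    and observation terms g(y_n | x_n), for an assumed Gaussian random walk."""
    log_p = normal_logpdf(xs[0], 0.0, sigma_x)             # prior f(x_0)
    for n in range(1, len(xs)):
        log_p += normal_logpdf(xs[n], xs[n - 1], sigma_x)  # f(x_n | x_{n-1})
    for n, y in enumerate(ys):
        log_p += normal_logpdf(y, xs[n], sigma_y)          # g(y_n | x_n)
    return log_p

lp = joint_logdensity([0.0, 0.5, 1.0], [0.2, 0.6, 1.1])
```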
Sequential Monte Carlo (n = 1)
• Basic idea: approximate the posterior distribution using a weighted set of K particles $x_{0:n}^{(k)}$:
$$p(x_{0:n} \mid y_{0:n}) \approx \sum_{k=1}^{K} w_n^k\, \delta_{x_{0:n}^{(k)}}(x_{0:n})$$
Sequential Monte Carlo (n = 1)
• Each particle is assigned an (unnormalized) weight $W_n^k$ based on its likelihood; the normalized weights satisfy $w_n^k \propto W_n^k$
Sequential Monte Carlo (n = 1 → n = 2)
• Particles are resampled according to their weights, then simulated forward
• Each particle has zero or more children
• The number of children $M_n^k$ is proportional to the weight $W_n^k$
Sequential Monte Carlo (n = 1 → n = 2)
• Particles with low weight are discarded, and particles with high weight are replicated
• Better-than-average particles are replicated more often:
$$\mathbb{E}[M_n^k \mid W_n^{1:K}] = \frac{W_n^k}{\bar{W}_n}$$
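One way to realize this is multinomial resampling: draw K child indices in proportion to the weights, so each particle's expected child count is proportional to its weight. A short sketch (the function name and toy weights are this example's):

```python
import random
from collections import Counter

def resample_children(weights, K, rng):
    """Multinomial resampling: draw K child indices with probability
    proportional to weight, so E[M_n^k | W_n^{1:K}] is proportional
    to W_n^k (better-than-average particles get more children)."""
    idx = rng.choices(range(len(weights)), weights=weights, k=K)
    return Counter(idx)  # M_n^k = number of children of particle k

rng = random.Random(1)
counts = resample_children([0.1, 0.1, 0.8], K=1000, rng=rng)
```

With these weights, particle 2 should receive roughly 800 of the 1000 children.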
Sequential Monte Carlo (n = 1 → n = 2)
Iteratively (with K total particles):
- simulate
- weight
- resample
Sequential Monte Carlo (n = 1 → n = 2 → n = 3)
[Figure: particle trajectories across three resampling steps, K total particles]
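The simulate/weight/resample loop above can be sketched end-to-end as a bootstrap particle filter. This is a minimal Python illustration under an assumed Gaussian random-walk model (the model, parameters, and function names are this example's); it also accumulates the marginal likelihood estimate used later for PIMH:

```python
import math
import random

def normal_pdf(x, mu, sigma):
    return math.exp(-(x - mu)**2 / (2 * sigma**2)) / math.sqrt(2 * math.pi * sigma**2)

def bootstrap_pf(ys, K=500, sigma_x=1.0, sigma_y=0.5, seed=0):
    """Bootstrap particle filter: simulate from the transition prior,
    weight by the observation likelihood, resample at each step.
    Also accumulates log Z-hat, the log marginal likelihood estimate."""
    rng = random.Random(seed)
    xs = [0.0] * K                                      # x_0 fixed at 0 for simplicity
    log_Z = 0.0
    for y in ys:
        xs = [rng.gauss(x, sigma_x) for x in xs]        # simulate x_n ~ f(. | x_{n-1})
        ws = [normal_pdf(y, x, sigma_y) for x in xs]    # weight W^k = g(y_n | x_n^k)
        log_Z += math.log(sum(ws) / K)                  # Z-hat_n = (1/K) sum_k W^k
        xs = rng.choices(xs, weights=ws, k=K)           # resample by weight
    return xs, log_Z

particles, log_Z = bootstrap_pf([0.5, 1.0, 1.5])
```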
Sequential Monte Carlo SMC in action: slowed down for clarity
Probabilistic programs as state spaces?
Trace
• Sequence of N observes encountered: $\{(g_i, \phi_i, y_i)\}_{i=1}^{N}$
• Sequence of M samples: $\{(f_j, \theta_j)\}_{j=1}^{M}$
• Sequence of M sampled values: $\{x_j\}_{j=1}^{M}$
• Conditioned on these sampled values, the entire computation is deterministic
Trace Probability
• Defined (up to a normalization constant) as
$$\gamma(x) \triangleq p(x, y) = \prod_{i=1}^{N} g_i(y_i \mid \phi_i) \prod_{j=1}^{M} f_j(x_j \mid \theta_j).$$
• This hides the true dependency structure: writing $x_{1:j} = x_1 \times \cdots \times x_j$ for the sampled values so far, the parameters themselves are functions of earlier values,
$$\gamma(x) = p(x, y) = \prod_{i=1}^{N} \tilde{g}_i\big(y_i \mid \tilde{\phi}_i(x_{1:n_i})\big) \prod_{j=1}^{M} \tilde{f}_j\big(x_j \mid \tilde{\theta}_j(x_{1:j-1})\big)$$
[Figure: dependency graph over sampled values x_1, ..., x_6 and observations y_1, y_2]
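Scoring a trace is then just summing one log-density term per sample and per observe. A minimal sketch, assuming every $f_j$ and $g_i$ is Gaussian and parameterized by (mu, sigma) (the representation and function names are this example's):

```python
import math

def normal_logpdf(x, mu, sigma):
    return -0.5 * math.log(2 * math.pi * sigma**2) - (x - mu)**2 / (2 * sigma**2)

def trace_log_gamma(samples, observes):
    """log gamma(x) = sum_j log f_j(x_j | theta_j) + sum_i log g_i(y_i | phi_i):
    every sample and every observe contributes one density term."""
    log_gamma = 0.0
    for x_j, (mu, sigma) in samples:      # sample statements: f_j(x_j | theta_j)
        log_gamma += normal_logpdf(x_j, mu, sigma)
    for y_i, (mu, sigma) in observes:     # observe statements: g_i(y_i | phi_i)
        log_gamma += normal_logpdf(y_i, mu, sigma)
    return log_gamma

lg = trace_log_gamma(samples=[(0.0, (0.0, 1.0))], observes=[(0.0, (0.0, 1.0))])
```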
Likelihood Weighting
• Run K independent copies of the program, simulating from the prior:
$$q(x^k) = \prod_{j=1}^{M^k} f_j(x_j^k \mid \theta_j^k)$$
• Accumulate unnormalized weights (likelihoods):
$$w(x^k) = \frac{\gamma(x^k)}{q(x^k)} = \prod_{i=1}^{N^k} g_i(y_i^k \mid \phi_i^k)$$
• Use in approximate (Monte Carlo) integration:
$$\hat{\mathbb{E}}_\pi[R(x)] = \sum_{k=1}^{K} W^k R(x^k), \qquad W^k = \frac{w(x^k)}{\sum_{\ell=1}^{K} w(x^\ell)}$$
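The estimator above in a few lines of Python, under an assumed conjugate model x ~ N(0, 1), y | x ~ N(x, 0.5²) so the answer can be checked analytically (the model and function names are this example's):

```python
import math
import random

def normal_logpdf(x, mu, sigma):
    return -0.5 * math.log(2 * math.pi * sigma**2) - (x - mu)**2 / (2 * sigma**2)

def likelihood_weighting(y, K=2000, seed=0):
    """Draw K traces from the prior q(x) = f(x), weight each by the
    likelihood w(x^k) = g(y | x^k), and form the self-normalized
    estimator sum_k W^k R(x^k) with R(x) = x (a posterior-mean estimate)."""
    rng = random.Random(seed)
    xs = [rng.gauss(0.0, 1.0) for _ in range(K)]        # simulate from the prior
    log_ws = [normal_logpdf(y, x, 0.5) for x in xs]     # unnormalized log-weights
    m = max(log_ws)
    ws = [math.exp(lw - m) for lw in log_ws]            # stabilize before normalizing
    total = sum(ws)
    return sum(w * x for w, x in zip(ws, xs)) / total   # sum_k W^k R(x^k)

est = likelihood_weighting(1.0)
```

For this conjugate model the exact posterior mean is y / (1 + 0.25) = 0.8, so the estimate should land near that value.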
Probabilistic programs as state spaces
• Notation: $\tilde{x}_{1:n} = \tilde{x}_1 \times \cdots \times \tilde{x}_n$, where each $\tilde{x}_n$ is the subspace of $x$ sampled before the n-th observe (the subspaces are disjoint)
[Figure: trace partitioned into blocks x̃_1, x̃_2, ... over sampled values x_1, ..., x_6 and observations y_1, y_2]
• Incrementalized joint:
$$\gamma_n(\tilde{x}_{1:n}) = \prod_{m=1}^{n} g(y_m \mid \tilde{x}_{1:m})\, p(\tilde{x}_m \mid \tilde{x}_{1:m-1})$$
• Incrementalized targets:
$$\pi_n(\tilde{x}_{1:n}) = \frac{1}{Z_n}\, \gamma_n(\tilde{x}_{1:n})$$
Particle Markov chain Monte Carlo
Particle Markov Chain Monte Carlo
• Iterable SMC (run sweep after sweep):
- PIMH: "particle independent Metropolis-Hastings"
- PGIBBS: "iterated conditional SMC"
- PGAS: "particle Gibbs with ancestor sampling"
PIMH Math
• Each sweep of SMC can compute
$$\hat{Z} = \prod_{n=1}^{N} \hat{Z}_n = \prod_{n=1}^{N} \frac{1}{K} \sum_{k=1}^{K} w(\tilde{x}_{1:n}^k)$$
• PIMH is MH that accepts entire new particle sets w.p.
$$\alpha^s_{\mathrm{PIMH}} = \min\left(1, \frac{\hat{Z}^\star}{\hat{Z}^{s-1}}\right)$$
• And all particles can be used:
$$\hat{\mathbb{E}}_{\mathrm{PIMH}}[R(x)] = \frac{1}{S} \sum_{s=1}^{S} \sum_{k=1}^{K} W^{s,k} R(x^{s,k})$$
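The PIMH loop itself is short. A sketch assuming a `run_smc` function that returns a particle set plus its log Z-hat; here a hypothetical toy stand-in plays that role, since a real sweep would come from the filter itself:

```python
import math
import random

def pimh(run_smc, S=200, seed=0):
    """Particle independent Metropolis-Hastings: each step reruns SMC to
    propose a whole new particle set, accepted w.p. min(1, Z_new / Z_old)."""
    rng = random.Random(seed)
    particles, log_Z = run_smc(rng)
    accepted = 0
    for _ in range(S):
        particles_new, log_Z_new = run_smc(rng)          # independent proposal
        if math.log(rng.random()) < log_Z_new - log_Z:   # MH acceptance test
            particles, log_Z = particles_new, log_Z_new
            accepted += 1
    return particles, accepted / S

def toy_smc(rng):
    # Hypothetical stand-in for an SMC sweep: dummy particles plus a noisy
    # log marginal-likelihood estimate (a real sweep returns its log Z-hat).
    return [rng.gauss(0.0, 1.0) for _ in range(10)], -1.0 + rng.gauss(0.0, 0.5)

particles, accept_rate = pimh(toy_smc)
```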
Asynchronous anytime sequential Monte Carlo
Parallelization in SMC
• Forward simulation trivially parallelizes
- this is the sort of parallelization achieved through (e.g.) parfor in MATLAB, or pmap in functional programming languages
• The resampling step (normalizing weights, sampling child counts) is a global synchronous operation
- cannot resample until all particles finish simulation
Particle Cascade
• Replace the resampling step with a branching step
• Launch particles asynchronously
• As each particle arrives at an observation, choose a number of offspring based only on the particles which have arrived so far
• ... don't need to wait for all particles to arrive
• ... only need to track average weights at each observation, which we compute online
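A sketch of that branching rule in Python: maintain an online running-average weight at one observation, and give each arriving particle a stochastically-rounded child count with expectation (its weight) / (running average). This is a simplified illustration, not the full algorithm, which also adjusts the weights carried by the children:

```python
import random

class CascadeBranching:
    """Branching at one observation in a particle-cascade style scheme:
    child counts depend only on the particles that have arrived so far."""

    def __init__(self, seed=0):
        self.rng = random.Random(seed)
        self.count = 0    # particles that have arrived at this observation
        self.avg = 0.0    # running average weight, computed online

    def num_children(self, w):
        self.count += 1
        self.avg += (w - self.avg) / self.count    # online mean update
        ratio = w / self.avg
        base = int(ratio)
        # stochastic rounding keeps E[#children] equal to the ratio
        return base + (1 if self.rng.random() < ratio - base else 0)

branch = CascadeBranching()
first = branch.num_children(2.0)   # first arrival: ratio is exactly 1
```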
Particle Cascade (n = 1)
• Start by simulating particles, one at a time, from $f(x_n \mid x_{1:n-1})$
• Weight by likelihood $g(y_n \mid x_{1:n})$