Reliable Variational Learning for Hierarchical Dirichlet Processes
Erik Sudderth, Brown University Computer Science
Collaborators: Michael Hughes & Dae Il Kim (Brown University); Prem Gopalan & David Blei (Princeton University)
Learning Structured BNP Models
Applications: genetics, climate change, politics, …
Example document: "There are reasons to believe that the genetics of an organism are likely to shift due to the extreme changes in our climate. To protect them, our politicians must pass environmental legislation that can protect our future species from becoming extinct…"
[Figure: directed graphical model of the Hierarchical Dirichlet Process (Teh et al., JASA 2006): global topic weights β with concentration γ, document topic weights π_d with concentration α, token assignments z_dn, topic parameters φ_k with prior λ_0, and observed words x_dn, for n = 1, …, N_d tokens in each of d = 1, …, D documents and infinitely many topics.]
Goals:
- Nonparametric: data-driven discovery of model structure (topics, behaviors, objects, communities, …)
- Reliable: structure driven by the data and modeling assumptions, not by heuristic algorithm initializations
- Parsimonious: we want a single model structure with good predictive power, not full posterior uncertainty
Memoized Variational Inference for Dirichlet Process Mixture Models
Michael Hughes & E. Sudderth
Neural Information Processing Systems (NIPS), 2013
Dirichlet Process Stick-Breaking (Sethuraman, 1994)
GOAL: Partition data into an a priori unknown number of discrete clusters.
[Figure: cartoon of the stick-breaking construction, showing stick proportions $v_1, v_2, v_3, \ldots$, cluster frequencies $\pi_1, \pi_2, \pi_3$, and cluster shapes $\phi_1, \phi_2, \phi_3$ generating observations $x_n$ with assignments $z_n$.]
Cluster frequencies: $\pi \sim \mathrm{Stick}(\alpha)$. For each cluster k = 1, 2, …:
- Cluster shape: $\phi_k \sim H(\lambda_0)$
- Stick proportion: $v_k \sim \mathrm{Beta}(1, \alpha)$
- Cluster frequency: $\pi_k = v_k \prod_{\ell=1}^{k-1} (1 - v_\ell)$
The first few weights are $\pi_1 = v_1$, $\pi_2 = v_2 (1 - v_1)$, $\pi_3 = v_3 (1 - v_2)(1 - v_1)$, and the mass remaining after K sticks is
\[
1 - \sum_{k=1}^{K} \pi_k = \prod_{k=1}^{K} (1 - v_k).
\]
[Plot: Beta(1, α) densities of the stick proportions for α = 1, 5, 20.]
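To make the construction concrete, here is a minimal NumPy sketch (not part of the talk) that samples the first K stick proportions $v_k \sim \mathrm{Beta}(1, \alpha)$ and forms the frequencies $\pi_k$; the function name and the choice to truncate at K sticks are assumptions of this illustration.

```python
# Minimal sketch of the Sethuraman stick-breaking construction:
# sample v_k ~ Beta(1, alpha) and form pi_k = v_k * prod_{l<k} (1 - v_l).
import numpy as np

def sample_stick_weights(alpha, K, seed=None):
    """Sample the first K stick-breaking weights of a DP(alpha) draw."""
    rng = np.random.default_rng(seed)
    v = rng.beta(1.0, alpha, size=K)                       # stick proportions
    remaining = np.concatenate(([1.0], np.cumprod(1.0 - v)[:-1]))
    pi = v * remaining                                      # pi_k = v_k * prod_{l<k}(1 - v_l)
    leftover = np.prod(1.0 - v)                             # mass beyond the first K sticks
    return pi, leftover

pi, leftover = sample_stick_weights(alpha=5.0, K=20, seed=0)
print(pi[:5], leftover)   # larger alpha spreads mass over more clusters
```

A larger concentration α makes each Beta(1, α) proportion smaller on average, so mass is spread over more clusters, matching the α = 1, 5, 20 curves on the slide.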
Dirichlet Process Mixtures
GOAL: Partition data into an a priori unknown number of discrete clusters.
Cluster frequencies: $\pi \sim \mathrm{Stick}(\alpha)$. For each cluster k = 1, 2, …:
- Cluster shape: $\phi_k \sim H(\lambda_0)$
- Stick proportion: $v_k \sim \mathrm{Beta}(1, \alpha)$, giving cluster frequency $\pi_k$
For each observation n = 1, 2, …, N:
- Cluster assignment: $z_n \sim \mathrm{Cat}(\pi)$
- Observed value: $x_n \sim F(\phi_{z_n})$
Assume exponential family likelihoods with conjugate priors:
\[
f(x_n \mid \phi_k) = \exp\{\phi_k^T t(x_n) - a(\phi_k)\},
\qquad
h(\phi_k \mid \lambda_0) = \exp\{\lambda_0^T \bar t(\phi_k) - \bar a(\lambda_0)\},
\qquad
\bar t(\phi_k) = [\phi_k, -a(\phi_k)].
\]
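As a hedged illustration of the full generative process, the sketch below draws synthetic data from a truncated DP mixture with 1-D Gaussian likelihoods. The Gaussian choice, the unit observation variance, and the truncation level K are assumptions of this example, not part of the general exponential-family model above.

```python
# Illustrative only: synthetic data from a truncated DP mixture with
# 1-D Gaussian likelihoods, via the stick-breaking weights.
import numpy as np

def sample_dp_gaussian_mixture(N, alpha, K=50, mean_scale=10.0, seed=None):
    rng = np.random.default_rng(seed)
    v = rng.beta(1.0, alpha, size=K)
    pi = v * np.concatenate(([1.0], np.cumprod(1.0 - v)[:-1]))
    pi = pi / pi.sum()                        # fold leftover mass into the truncation
    mu = rng.normal(0.0, mean_scale, size=K)  # cluster shapes phi_k ~ H(lambda_0)
    z = rng.choice(K, size=N, p=pi)           # assignments z_n ~ Cat(pi)
    x = rng.normal(mu[z], 1.0)                # observations x_n ~ F(phi_{z_n})
    return x, z, mu, pi

x, z, mu, pi = sample_dp_gaussian_mixture(N=500, alpha=2.0, seed=0)
print(np.unique(z).size, "clusters used out of", pi.size)
```

Even with a generous truncation, only a handful of clusters receive appreciable data, which is the data-driven complexity control the model is meant to provide.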
Dirichlet Process Mixtures
GOAL: Partition data into an a priori unknown number of discrete clusters.
Visually summarize the model structure via a directed graphical model:
[Figure: directed graphical model with hyperparameters α and λ_0, stick weights π, assignments z_n and observations x_n for n = 1, …, N, and cluster shapes φ_k for infinitely many clusters k.]
Cluster frequencies: $\pi \sim \mathrm{Stick}(\alpha)$.
For each cluster k = 1, 2, …: cluster shape $\phi_k \sim H(\lambda_0)$, stick proportion $v_k \sim \mathrm{Beta}(1, \alpha)$, cluster frequency $\pi_k$.
For each observation n = 1, 2, …, N: cluster assignment $z_n \sim \mathrm{Cat}(\pi)$, observed value $x_n \sim F(\phi_{z_n})$, with
\[
f(x_n \mid z_n = k, \phi) = \exp\{\phi_k^T t(x_n) - a(\phi_k)\}.
\]
MCMC for DP Mixtures
Can we sample from the posterior distribution over data clusterings?
Given any fixed partition z:
- Marginalize the stick-breaking weights via the Chinese Restaurant Process, which assigns positive probability to all partitions of the data (large support).
- Via conjugacy of the base measure to the exponential family likelihood, marginalize the cluster shape parameters.
Gibbs sampler (Neal 1992, MacEachern 1994): iteratively resample the cluster assignment of one observation, fixing all others (see the sketch below).
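The sketch below shows the shape of one sweep of such a collapsed Gibbs sampler. It is schematic: `predictive_loglik` is a hypothetical helper that must return the log predictive probability of $x_n$ given the current members of a cluster under the conjugate base measure (or under the prior predictive, for an empty cluster), and is not implemented here.

```python
# Schematic sketch of one sweep of collapsed Gibbs sampling for a DP mixture.
# x: (N,) data array, z: (N,) integer labels, rng: numpy Generator.
import numpy as np

def gibbs_sweep(x, z, alpha, predictive_loglik, rng):
    """Resample each z_n in turn, conditioned on all other assignments."""
    for n in range(len(x)):
        z[n] = -1                                         # remove x_n from its cluster
        labels, counts = np.unique(z[z >= 0], return_counts=True)
        # CRP prior: existing clusters weighted by size, a new cluster by alpha
        log_prior = np.log(np.append(counts, alpha))
        log_lik = np.array(
            [predictive_loglik(x[n], x[z == k]) for k in labels]
            + [predictive_loglik(x[n], x[:0])]            # empty cluster -> prior predictive
        )
        log_post = log_prior + log_lik
        probs = np.exp(log_post - log_post.max())
        probs /= probs.sum()
        choice = rng.choice(len(probs), p=probs)
        z[n] = labels[choice] if choice < len(labels) else z.max() + 1
    return z
```

Because only one assignment changes at a time while the rest stay fixed, moves that split or merge large clusters require many coordinated single-site updates, which is exactly the mixing difficulty the next slide illustrates.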
Mixing for DP Mixture Samplers
MNIST: 60,000 digit images projected to 50 dimensions via PCA.
[Plots: number of clusters and log-probability traces for five random initializations, started from K = 1, K = 50, and K = 300 clusters.]
Reversible jump MCMC? Proposals are slow and acceptance rates are low.
Variational Bounds
What is the marginal likelihood of our observed data?
\[
\log p(x \mid \alpha, \lambda_0)
= \log \sum_z \iint p(x, z, v, \phi \mid \alpha, \lambda_0)\, dv\, d\phi
= \log \sum_z \iint q(z, v, \phi)\, \frac{p(x, z, v, \phi \mid \alpha, \lambda_0)}{q(z, v, \phi)}\, dv\, d\phi
= \log \mathbb{E}_q\!\left[ \frac{p(x, z, v, \phi \mid \alpha, \lambda_0)}{q(z, v, \phi)} \right],
\]
where the expectation is with respect to some variational distribution $q(z, v, \phi)$. By Jensen's inequality,
\[
\log p(x \mid \alpha, \lambda_0) \ge \mathbb{E}_q[\log p(x, z, v, \phi \mid \alpha, \lambda_0)] - \mathbb{E}_q[\log q(z, v, \phi)] = \mathcal{L}(q),
\]
the expected log-likelihood (the negative of the "average energy") plus the variational entropy.
Maximizing this bound recovers the true posterior:
\[
\mathcal{L}(q) = \log p(x \mid \alpha, \lambda_0) - \mathrm{KL}\big(q(z, v, \phi)\,\|\,p(z, v, \phi \mid x, \alpha, \lambda_0)\big).
\]
The simplest mean field variational methods create tractable algorithms via assumed independence: $q(z, v, \phi) = q(z)\, q(v, \phi)$.
Approximating Infinite Models
Mean field factorization:
\[
q(z, v, \phi) = q(z)\, q(v, \phi) = \left[ \prod_{n=1}^{N} q(z_n) \right] \cdot \left[ \prod_{k=1}^{\infty} q(v_k)\, q(\phi_k) \right],
\qquad q(z_n = k) = r_{nk},
\]
where each $q(v_k)$ is a beta distribution and each $q(\phi_k)$ is an exponential family distribution from the conjugate prior. But $q(z_n)$ is a categorical distribution with unbounded support, and there are infinitely many potential clusters!
Top-down model truncation (Ishwaran & James, 2001; Blei & Jordan, 2006):
\[
q(z_n) = \mathrm{Cat}(z_n \mid r_{n1}, r_{n2}, \ldots, r_{nK}),
\qquad
q(v, \phi) = \left[ \prod_{k=1}^{K} q(\phi_k) \right] \cdot \left[ \prod_{k=1}^{K-1} q(v_k) \right],
\qquad
\pi_K = \prod_{k=1}^{K-1} (1 - v_k).
\]
Bottom-up assignment truncation (Teh, Kurihara, & Welling, 2008; Bryant & Sudderth, 2012):
\[
q(z_n) = \mathrm{Cat}(z_n \mid r_{n1}, r_{n2}, \ldots, r_{nK}, 0, 0, 0, \ldots),
\qquad
q(v, \phi) = \prod_{k=1}^{\infty} q(v_k)\, q(\phi_k).
\]
For any k > K, the optimal variational distributions equal the prior and need not be explicitly represented.
[Bar plots: cluster frequencies under the two truncations for α = 4, K = 10.]
Batch Variational Updates
A Bayesian nonparametric analog of Expectation-Maximization (EM).
\[
q(z, v, \phi) = \left[ \prod_{n=1}^{N} q(z_n \mid r_n) \right] \cdot \left[ \prod_{k=1}^{\infty} \mathrm{Beta}(v_k \mid \alpha_{k1}, \alpha_{k0})\, h(\phi_k \mid \lambda_k) \right],
\qquad
q(z_n) = \mathrm{Cat}(z_n \mid r_{n1}, \ldots, r_{nK}, 0, 0, 0, \ldots) \text{ for some } K > 0.
\]
Update assignments (the expectation step): for all N data and $k \le K$,
\[
r_{nk} \propto \exp\big( \mathbb{E}_q[\log \pi_k(v)] + \mathbb{E}_q[\log p(x_n \mid \phi_k)] \big),
\qquad
\mathbb{E}_q[\log \pi_k(v)] = \mathbb{E}_q[\log v_k] + \sum_{\ell=1}^{k-1} \mathbb{E}_q[\log(1 - v_\ell)],
\]
with $\mathbb{E}_q[\log v_k] = \psi(\alpha_{k1}) - \psi(\alpha_{k1} + \alpha_{k0})$ and $\mathbb{E}_q[\log(1 - v_k)] = \psi(\alpha_{k0}) - \psi(\alpha_{k1} + \alpha_{k0})$.
Update cluster parameters (the other expectation step): expected counts and sufficient statistics are only non-zero for the first K clusters,
\[
N_k = \sum_{n=1}^{N} r_{nk},
\qquad
s_k = \sum_{n=1}^{N} r_{nk}\, t(x_n),
\qquad
\lambda_k \leftarrow \lambda_0 + s_k,
\qquad
\alpha_{k1} \leftarrow 1 + N_k,
\qquad
\alpha_{k0} \leftarrow \alpha + \sum_{\ell=k+1}^{K} N_\ell,
\qquad
\mathbb{E}_q[v_k] = \frac{\alpha_{k1}}{\alpha_{k1} + \alpha_{k0}}.
\]
(A code sketch of these two updates follows.)
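Below is a NumPy/SciPy sketch of the two updates for a generic exponential-family likelihood. The matrix `E_loglik` of expected log-likelihoods $\mathbb{E}_q[\log p(x_n \mid \phi_k)]$ and the sufficient statistics `t_x` are assumed to be computed elsewhere for the user's chosen family; they are placeholders, not part of the talk.

```python
# Sketch of one round of batch coordinate-ascent updates.
# E_loglik: (N, K) expected log-likelihoods, t_x: (N, D) sufficient statistics,
# alpha1, alpha0: (K,) variational beta parameters, lam0: (D,) prior natural params.
import numpy as np
from scipy.special import digamma

def update_assignments(E_loglik, alpha1, alpha0):
    """E-step: r_nk proportional to exp(E_q[log pi_k(v)] + E_q[log p(x_n | phi_k)])."""
    E_log_v = digamma(alpha1) - digamma(alpha1 + alpha0)
    E_log_1mv = digamma(alpha0) - digamma(alpha1 + alpha0)
    E_log_pi = E_log_v + np.concatenate(([0.0], np.cumsum(E_log_1mv)[:-1]))
    log_r = E_log_pi[None, :] + E_loglik
    log_r -= log_r.max(axis=1, keepdims=True)      # stabilize before exponentiating
    r = np.exp(log_r)
    return r / r.sum(axis=1, keepdims=True)

def update_global_params(r, t_x, lam0, alpha):
    """M-step: conjugate updates for cluster shapes and stick-breaking betas."""
    N_k = r.sum(axis=0)                            # expected counts
    s_k = r.T @ t_x                                # expected sufficient statistics
    lam = lam0[None, :] + s_k
    alpha1 = 1.0 + N_k
    N_gt_k = np.cumsum(N_k[::-1])[::-1] - N_k      # mass assigned to clusters after k
    alpha0 = alpha + N_gt_k
    return lam, alpha1, alpha0, N_k
```

Alternating these two functions until the bound stops improving gives the batch algorithm; the expected log-likelihood matrix must be recomputed from `lam` after each global update.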
Likelihood Bounds & Convergence
\[
\mathcal{L}(q) = \mathbb{E}_q[\log p(x, z, v, \phi \mid \alpha, \lambda_0)] - \mathbb{E}_q[\log q(z, v, \phi)]
\]
Immediately after the global parameter update, the bound simplifies in terms of the log-normalizers of the cluster shape and beta stick-breaking priors:
\[
\mathcal{L}(q) = H[r] + \sum_{k=1}^{K} \big[ \bar a(\lambda_k) - \bar a(\lambda_0) + \log B(\alpha_{k1}, \alpha_{k0}) - \log B(1, \alpha) \big],
\qquad
H[r] = -\sum_{n=1}^{N} \sum_{k=1}^{\infty} r_{nk} \log r_{nk} = -\sum_{n=1}^{N} \sum_{k=1}^{K} r_{nk} \log r_{nk}.
\]
Coordinate ascent alternates two steps that match expected sufficient statistics:
- For each data item n = 1, 2, …, N and K candidate clusters: $q(z_n = k) = r_{nk} \propto \exp\{ \mathbb{E}_q[\log \pi_k(v) + \log p(x_n \mid \phi_k)] \}$
- For each cluster k = 1, 2, …, K: $s_k = \sum_{n=1}^{N} r_{nk}\, t(x_n)$, $\lambda_k \leftarrow \lambda_0 + s_k$, $\alpha_{k1} \leftarrow 1 + N_k$, $\alpha_{k0} \leftarrow \alpha + N_{>k}$
Properties of the variational optimization algorithm:
+ The likelihood bound is monotonically increasing, with guaranteed convergence to a local optimum.
+ Unlike classical EM for MAP estimation, it allows Bayesian comparison of hypotheses with varying complexity K, which is crucial for BNP models.
- The truncation level K is assumed fixed.
- Sensitive to initialization (the objective has many modes).
- Each iteration must examine all of the data (slow).
Stochastic Variational Inference (Hoffman, Blei, Paisley, & Wang, JMLR 2013)
Stochastically partition the large dataset of N items into B smaller batches $x^{(B_1)}, \ldots, x^{(B_B)}$.
Update, for each batch b and cluster k = 1, 2, …, K:
\[
r^{(B_b)} \leftarrow \mathrm{Estep}(x^{(B_b)}, \alpha, \lambda),
\qquad
s_k^b = \sum_{n \in B_b} r_{nk}\, t(x_n),
\qquad
\hat\lambda_k = \lambda_0 + \frac{N}{|B_b|}\, s_k^b,
\qquad
\lambda_k \leftarrow \rho_t \hat\lambda_k + (1 - \rho_t)\, \lambda_k.
\]
The batch statistics give a noisy estimate of the (natural) gradient. Apply similar updates to the stick weights.
Learning rate: $\rho_t = (\rho_0 + t)^{-\kappa}$ with $\kappa \in (0.5, 1]$, satisfying the Robbins-Monro convergence conditions $\sum_t \rho_t = \infty$ and $\sum_t \rho_t^2 < \infty$.
[Plot: learning rate schedules (a, b, c) versus the number of iterations t, on a log scale from 10^0 to 10^4.]
Properties of stochastic inference (a code sketch of the update follows):
+ Per-iteration cost is low.
+ Initial iterations are often very effective.
- The objective is highly non-convex, so the convergence guarantee is weak.
- Batch size and learning rate significantly impact efficiency and accuracy.
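A sketch of the stochastic natural-parameter update is below. It assumes minibatch responsibilities have already been computed by an E-step like the one in the earlier sketch; the argument names (`r_batch`, `t_x_batch`, etc.) are placeholders for one batch of data, and the default $\rho_0$ and $\kappa$ are illustrative choices.

```python
# Sketch of one stochastic natural-gradient update with the Robbins-Monro
# schedule rho_t = (rho_0 + t)^(-kappa), kappa in (0.5, 1].
import numpy as np

def svi_step(lam, lam0, r_batch, t_x_batch, N_total, batch_size, t,
             rho0=1.0, kappa=0.6):
    """One stochastic update of the (K, D) cluster-shape natural parameters."""
    rho_t = (rho0 + t) ** (-kappa)                 # learning rate at iteration t
    s_batch = r_batch.T @ t_x_batch                # batch sufficient statistics
    # Noisy full-data estimate: scale the batch statistics up by N / |B_b|
    lam_hat = lam0[None, :] + (N_total / batch_size) * s_batch
    return rho_t * lam_hat + (1.0 - rho_t) * lam   # blend with the current estimate
```

The same blending rule, applied to the beta stick-breaking parameters with counts scaled by $N / |B_b|$, gives the corresponding update for the stick weights.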