Stein Point Markov Chain Monte Carlo

  1. Stein Point Markov Chain Monte Carlo

     Wilson Chen, Institute of Statistical Mathematics, Japan. June 15, 2019, ICML Stein's Method Workshop, Long Beach.

  2. Collaborators

     Alessandro Barp, François-Xavier Briol, Jackson Gorham, Mark Girolami, Lester Mackey, Chris Oates.

  5. Empirical Approximation Problem

     A major problem in machine learning and modern statistics is to approximate a difficult-to-compute density $p$, defined on some domain $\mathcal{X} \subseteq \mathbb{R}^d$, whose normalisation constant is unknown, i.e. $p(x) = \tilde{p}(x)/Z$ where $Z > 0$ is unknown. We consider an empirical approximation of $p$ with points $\{x_i\}_{i=1}^n$:
     $$\hat{p}_n(x) = \frac{1}{n} \sum_{i=1}^n \delta(x - x_i),$$
     so that for a test function $f : \mathcal{X} \to \mathbb{R}$:
     $$\int_{\mathcal{X}} f(x)\, p(x)\, \mathrm{d}x \approx \frac{1}{n} \sum_{i=1}^n f(x_i).$$
     A popular approach is Markov chain Monte Carlo (MCMC).
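
The sketch below illustrates the MCMC route. It assumes random-walk Metropolis and a one-dimensional unnormalised standard normal target. Both are illustrative choices and are not taken from the slides. The chain only ever evaluates $\log \tilde{p}$ because $Z$ cancels in the acceptance ratio. The integral is then approximated by the sample average of $f(x) = x^2$.

```python
import numpy as np


def rwm(log_p_tilde, x0, n, step=2.0, seed=0):
    """Random-walk Metropolis chain of length n targeting p = p_tilde / Z.

    Only log p_tilde is evaluated: Z cancels in the acceptance ratio.
    """
    rng = np.random.default_rng(seed)
    x, lp = x0, log_p_tilde(x0)
    chain = np.empty(n)
    for i in range(n):
        prop = x + step * rng.normal()            # symmetric Gaussian proposal
        lp_prop = log_p_tilde(prop)
        if np.log(rng.uniform()) < lp_prop - lp:  # accept w.p. min(1, p_tilde(prop) / p_tilde(x))
            x, lp = prop, lp_prop
        chain[i] = x
    return chain


# Unnormalised standard normal: p_tilde(x) = exp(-x^2 / 2); Z = sqrt(2*pi) is never used.
xs = rwm(lambda x: -0.5 * x ** 2, x0=0.0, n=50_000)
estimate = np.mean(xs ** 2)  # (1/n) sum_i f(x_i) with f(x) = x^2; the exact integral is 1
```

The function name `rwm` and its step size are hypothetical. For this target the exact value of the integral is 1, and a chain of this length should give an estimate close to it.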

  7. Discrepancy

     Idea: construct a measure of discrepancy $D(\hat{p}_n, p)$ with the following desirable features:
     • Detect (non)convergence, i.e. $D(\hat{p}_n, p) \to 0$ only if $\hat{p}_n \xrightarrow{*} p$ (weak convergence).
     • Be efficiently computable with limited access to $p$.

     Unfortunately, this is not the case for many popular discrepancy measures:
     • Kullback-Leibler divergence,
     • Wasserstein distance,
     • maximum mean discrepancy (MMD).

  11. Kernel Embedding and MMD

      The kernel embedding of a distribution $p$ is
      $$\mu_p(\cdot) = \int k(x, \cdot)\, p(x)\, \mathrm{d}x,$$
      a function in the RKHS $\mathcal{K}$. Consider the maximum mean discrepancy (MMD) as an option for $D$:
      $$D(\hat{p}_n, p) := \| \mu_{\hat{p}_n} - \mu_p \|_{\mathcal{K}} =: D_{k,p}(\{x_i\}_{i=1}^n).$$
      Therefore
      $$D_{k,p}(\{x_i\}_{i=1}^n)^2 = \| \mu_{\hat{p}_n} - \mu_p \|_{\mathcal{K}}^2 = \langle \mu_{\hat{p}_n} - \mu_p, \mu_{\hat{p}_n} - \mu_p \rangle = \langle \mu_{\hat{p}_n}, \mu_{\hat{p}_n} \rangle - 2 \langle \mu_{\hat{p}_n}, \mu_p \rangle + \langle \mu_p, \mu_p \rangle.$$
      The last two terms are intractable integrals with respect to $p$. For a Stein kernel $k_0$, however,
      $$\mu_p(\cdot) = \int k_0(x, \cdot)\, p(x)\, \mathrm{d}x = 0,$$
      so
      $$\| \mu_{\hat{p}_n} - \mu_p \|_{\mathcal{K}_0}^2 = \| \mu_{\hat{p}_n} \|_{\mathcal{K}_0}^2 =: D_{k_0,p}(\{x_i\}_{i=1}^n)^2 =: \mathrm{KSD}^2.$$
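
The sketch below shows why the cross and $p$-$p$ terms are the obstacle. It evaluates the three inner products of the MMD expansion as kernel averages in one dimension. It assumes, hypothetically, that a reference sample $z \sim p$ is available to stand in for $\mu_p$, which is exactly what is missing when only $\tilde{p}$ is known. The Gaussian kernel and its length-scale `ell` are illustrative choices.

```python
import numpy as np


def gauss_kernel(a, b, ell=1.0):
    """Gaussian kernel matrix k(a_i, b_j) = exp(-(a_i - b_j)^2 / (2 ell^2)) for 1D arrays."""
    return np.exp(-((a[:, None] - b[None, :]) ** 2) / (2.0 * ell ** 2))


def mmd2(x, z, ell=1.0):
    """Squared MMD between the empirical measures of x (p_hat_n) and z (standing in for p)."""
    kxx = gauss_kernel(x, x, ell).mean()  # <mu_phat, mu_phat>: needs only the points x_i
    kxz = gauss_kernel(x, z, ell).mean()  # <mu_phat, mu_p>: needs draws from p
    kzz = gauss_kernel(z, z, ell).mean()  # <mu_p, mu_p>: needs draws from p
    return kxx - 2.0 * kxz + kzz


rng = np.random.default_rng(1)
z = rng.normal(size=2000)          # hypothetical reference sample from p = N(0, 1)
good = rng.normal(size=200)        # points roughly distributed as p
bad = rng.normal(size=200) + 2.0   # points shifted away from p
```

Without `z`, only `kxx` can be computed. The Stein kernel construction removes the need for the other two terms.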

  14. Kernel Stein Discrepancy (KSD)

      The kernel Stein discrepancy (KSD) is given by
      $$D_{k_0,p}(\{x_i\}_{i=1}^n) = \frac{1}{n} \sqrt{\sum_{i=1}^n \sum_{j=1}^n k_0(x_i, x_j)},$$
      where $k_0$ is the Stein kernel
      $$\begin{aligned}
      k_0(x, x') := \mathcal{T}_p \mathcal{T}'_p k(x, x') ={}& \nabla_x \cdot \nabla_{x'} k(x, x') + \langle \nabla_x \log p(x), \nabla_{x'} k(x, x') \rangle \\
      &+ \langle \nabla_{x'} \log p(x'), \nabla_x k(x, x') \rangle + \langle \nabla_x \log p(x), \nabla_{x'} \log p(x') \rangle\, k(x, x'),
      \end{aligned}$$
      with $\mathcal{T}_p f = \nabla(p f)/p$. ($\mathcal{T}_p$ is a Stein operator; $\mathcal{T}'_p$ acts on the second argument $x'$.)
      • This is computable without the normalisation constant.
      • It requires gradient information $\nabla \log p(x_i)$.
      • It detects (non)convergence for an appropriately chosen $k$ (e.g. the inverse multiquadric (IMQ) kernel).
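
The following NumPy sketch implements the formulas above for the IMQ base kernel $k(x, x') = (c^2 + \|x - x'\|^2)^\beta$. The defaults $c = 1$ and $\beta = -1/2$, and the function names, are assumptions for illustration. The code uses only the score $\nabla \log p$, so $Z$ never appears.

```python
import numpy as np


def imq_stein_kernel(x, y, score_x, score_y, c=1.0, beta=-0.5):
    """Stein kernel k_0(x, y) for the IMQ base kernel k(x, y) = (c^2 + ||x - y||^2)^beta.

    x: (n, d) and y: (m, d) arrays of points; score_x, score_y: grad log p at
    those points (same shapes). Returns the (n, m) matrix of k_0 values.
    """
    d = x.shape[1]
    r = x[:, None, :] - y[None, :, :]   # (n, m, d) pairwise differences x - y
    r2 = np.sum(r ** 2, axis=-1)        # ||x - y||^2
    s = c ** 2 + r2
    # nabla_x . nabla_y k(x, y)
    div_term = (-4.0 * beta * (beta - 1.0) * r2 * s ** (beta - 2.0)
                - 2.0 * beta * d * s ** (beta - 1.0))
    # <grad log p(x), nabla_y k> + <grad log p(y), nabla_x k>
    diff = score_y[None, :, :] - score_x[:, None, :]
    cross_term = 2.0 * beta * s ** (beta - 1.0) * np.sum(diff * r, axis=-1)
    # <grad log p(x), grad log p(y)> k(x, y)
    score_term = (score_x @ score_y.T) * s ** beta
    return div_term + cross_term + score_term


def ksd(x, score):
    """KSD of the points x (shape (n, d)): (1/n) sqrt(sum_ij k_0(x_i, x_j))."""
    b = score(x)
    k0 = imq_stein_kernel(x, x, b, b)
    return np.sqrt(max(k0.sum(), 0.0)) / x.shape[0]  # clip tiny negative round-off
```

For a standard normal target the score is `lambda x: -x`, so `ksd(points, lambda x: -x)` needs no normalising constant. The $O(n^2)$ pairwise matrix is practical for up to a few thousand points.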

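  16. Stein Points (SP)

      The main idea of Stein Points is the greedy minimisation of KSD:
      $$x_j \mid x_1, \ldots, x_{j-1} \leftarrow \arg\min_{x \in \mathcal{X}} D_{k_0,p}\big(\{x_i\}_{i=1}^{j-1} \cup \{x\}\big) = \arg\min_{x \in \mathcal{X}} \Big[ k_0(x, x) + 2 \sum_{i=1}^{j-1} k_0(x, x_i) \Big].$$
      The second form follows by squaring $D_{k_0,p}$ (which preserves the minimiser), then dropping the factor $1/j^2$ and the double sum over $x_1, \ldots, x_{j-1}$, neither of which depends on $x$. A global optimisation step is needed at each iteration.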
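
Below is a minimal 1D sketch of the greedy step. It makes two stated simplifications. First, the global $\arg\min$ over $\mathcal{X}$ is replaced by a search over a fixed grid of candidates, which is only practical in low dimension. Second, the Stein kernel is the IMQ one with $c = 1$, $\beta = -1/2$ and $d = 1$, written out in closed form. The helper names are hypothetical. Keeping a running sum $\sum_{i<j} k_0(x, x_i)$ for every candidate means each iteration only adds one column of the precomputed kernel matrix.

```python
import numpy as np


def stein_kernel_1d(x, y, bx, by):
    """IMQ Stein kernel for d = 1, c = 1, beta = -1/2; bx, by are scores d/dx log p."""
    r = x - y
    s = 1.0 + r ** 2
    return (s ** -1.5 - 3.0 * r ** 2 * s ** -2.5   # d/dx d/dy k
            - (by - bx) * r * s ** -1.5            # score-gradient cross terms
            + bx * by * s ** -0.5)                 # score product times k


def stein_points(score, candidates, n):
    """Greedy Stein Points: x_j = argmin_x k0(x, x) + 2 sum_{i<j} k0(x, x_i) over a candidate grid."""
    b = score(candidates)
    K = stein_kernel_1d(candidates[:, None], candidates[None, :], b[:, None], b[None, :])
    diag = np.diag(K)
    running = np.zeros(len(candidates))  # sum_{i<j} k0(candidate, x_i)
    chosen = []
    for _ in range(n):
        idx = np.argmin(diag + 2.0 * running)
        chosen.append(candidates[idx])
        running += K[:, idx]  # K is symmetric, so column idx holds k0(candidate, x_j)
    return np.array(chosen)


grid = np.linspace(-5.0, 5.0, 1001)
pts = stein_points(lambda x: -x, grid, 20)  # standard normal target
```

For $N(0, 1)$ the first point is the grid point nearest 0, since $k_0(x, x) = 1 + x^2$ there.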

  18. Stein Point Markov Chain Monte Carlo (SP-MCMC)

      We propose to replace the global minimisation at each iteration $j$ of the SP method with a local search based on a $p$-invariant Markov chain of length $m_j$. The proposed SP-MCMC method proceeds as follows:
      1. Fix an initial point $x_1 \in \mathcal{X}$.
      2. For $j = 2, \ldots, n$:
         a. Select $i^* \in \{1, \ldots, j-1\}$ according to the criterion $\mathrm{crit}(\{x_i\}_{i=1}^{j-1})$.
         b. Generate $(y_{j,i})_{i=1}^{m_j}$ from a $p$-invariant Markov chain with $y_{j,1} = x_{i^*}$.
         c. Set $x_j \leftarrow \arg\min_{x \in \{y_{j,i}\}_{i=1}^{m_j}} D_{k_0,p}(\{x_i\}_{i=1}^{j-1} \cup \{x\})$.

      For crit, three different approaches are considered:
      • LAST selects the point last added: $i^* := j - 1$.
      • RAND selects $i^*$ uniformly at random in $\{1, \ldots, j-1\}$.
      • INFL selects $i^*$ to be the index of the most influential point in $\{x_i\}_{i=1}^{j-1}$. We call $x_{i^*}$ the most influential point if removing it from the point set creates the greatest increase in KSD.
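
The loop above can be sketched in 1D under assumptions that the slide does not fix:
- The $p$-invariant chain is random-walk Metropolis, as in the SP-RWM variants of the experiments.
- The chain length is constant, $m_j = m$, with step size `step`.
- The Stein kernel is the 1D IMQ one with $c = 1$ and $\beta = -1/2$.
- All function names are hypothetical.

For INFL, removing $x_l$ changes the double sum $\sum_{i,k} k_0(x_i, x_k)$ by $k_0(x_l, x_l) - 2 \sum_i k_0(x_l, x_i)$. Because the normalisation is the same for every $l$, the most influential point is the one that maximises this quantity.

```python
import numpy as np


def stein_kernel_1d(x, y, bx, by):
    """IMQ Stein kernel for d = 1, c = 1, beta = -1/2; bx, by are scores d/dx log p."""
    r = x - y
    s = 1.0 + r ** 2
    return s ** -1.5 - 3.0 * r ** 2 * s ** -2.5 - (by - bx) * r * s ** -1.5 + bx * by * s ** -0.5


def rwm_chain(x0, log_p, m, step, rng):
    """m states of a random-walk Metropolis chain started at y_1 = x0 (p-invariant)."""
    ys, x, lp = [x0], x0, log_p(x0)
    for _ in range(m - 1):
        prop = x + step * rng.normal()
        lp_prop = log_p(prop)
        if np.log(rng.uniform()) < lp_prop - lp:
            x, lp = prop, lp_prop
        ys.append(x)
    return np.array(ys)


def sp_mcmc(log_p, score, x1, n, m, crit="LAST", step=1.0, seed=0):
    rng = np.random.default_rng(seed)
    xs = [x1]
    for _ in range(1, n):                 # iterations j = 2, ..., n
        pts = np.array(xs)
        b = score(pts)
        if crit == "LAST":                # i* = j - 1
            i_star = len(xs) - 1
        elif crit == "RAND":              # i* uniform on {1, ..., j - 1}
            i_star = int(rng.integers(len(xs)))
        elif crit == "INFL":              # removal that most increases KSD
            K = stein_kernel_1d(pts[:, None], pts[None, :], b[:, None], b[None, :])
            i_star = int(np.argmax(np.diag(K) - 2.0 * K.sum(axis=1)))
        else:
            raise ValueError(f"unknown criterion {crit!r}")
        ys = rwm_chain(xs[i_star], log_p, m, step, rng)
        by = score(ys)
        # k0(y, y) + 2 sum_i k0(y, x_i): the y-dependent part of the extended KSD^2
        obj = stein_kernel_1d(ys, ys, by, by) + 2.0 * stein_kernel_1d(
            ys[:, None], pts[None, :], by[:, None], b[None, :]).sum(axis=1)
        xs.append(ys[np.argmin(obj)])
    return np.array(xs)
```

For example, `sp_mcmc(lambda x: -0.5 * x ** 2, lambda x: -x, 0.0, n=30, m=20, crit="INFL")` targets $N(0, 1)$ using only the unnormalised log-density and its score.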

  19. Gaussian Mixture Model Experiment

      [Figure: columns MCMC, LAST, RAND and INFL; rows show log KSD against $j$, a trace against $j$, and the density of jump sizes for SP-MCMC and MCMC.]

  20. IGARCH Experiment ($d = 2$)

      [Figure: $\log E_P$ against $\log n_{\mathrm{eval}}$ for MALA, RWM, SVGD, MED, SP, SP-MALA LAST, SP-MALA INFL, SP-RWM LAST and SP-RWM INFL.]

      SP-MCMC methods are compared against the original SP (Chen et al., 2018), MED (Roshan Joseph et al., 2015) and SVGD (Liu & Wang, 2016), as well as the Metropolis-adjusted Langevin algorithm (MALA) and random-walk Metropolis (RWM).

  21. ODE Experiment ($d = 4$)

      [Figure: log KSD against $\log n_{\mathrm{eval}}$ for MALA, RWM, SVGD, MED, SP, SP-MALA LAST, SP-MALA INFL, SP-RWM LAST and SP-RWM INFL.]

      SP-MCMC methods are compared against the original SP (Chen et al., 2018), MED (Roshan Joseph et al., 2015) and SVGD (Liu & Wang, 2016), as well as the Metropolis-adjusted Langevin algorithm (MALA) and random-walk Metropolis (RWM).

  22. ODE Experiment ($d = 10$)

      [Figure: log KSD against $\log n_{\mathrm{eval}}$ for MALA, RWM, SVGD, MED, SP, SP-MALA LAST, SP-MALA INFL, SP-RWM LAST and SP-RWM INFL.]

      SP-MCMC methods are compared against the original SP (Chen et al., 2018), MED (Roshan Joseph et al., 2015) and SVGD (Liu & Wang, 2016), as well as the Metropolis-adjusted Langevin algorithm (MALA) and random-walk Metropolis (RWM).
