Rao-Blackwellized Stochastic Gradients for Discrete Distributions


  1. Rao-Blackwellized Stochastic Gradients for Discrete Distributions. Runjing (Bryan) Liu, June 11, 2019, University of California, Berkeley.

  2–9. Objective
  • We fit a discrete latent variable model.
  • Fitting such a model involves finding
        argmin_η  E_{q_η(z)}[ f_η(z) ],
    where z is a discrete random variable with K categories.
  • Two common approaches are:
      1. Analytically integrate out z. Problem: K might be large.
      2. Sample z ∼ q_η(z) and estimate the gradient with g(z). Problem: g(z) might have high variance.
  • We propose a method that combines these two approaches to reduce the variance of any gradient estimator g(z). (A numerical sketch of the two baselines follows below.)
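
To make the two baselines concrete, here is a minimal numerical sketch (not the authors' code). It assumes, purely for illustration, that q_η is a softmax over a logit vector η and that f does not depend on η; the function names are hypothetical.

```python
import numpy as np

def softmax(eta):
    e = np.exp(eta - eta.max())
    return e / e.sum()

def exact_gradient(eta, f):
    """Approach 1: analytically sum over all K categories (costs K evaluations of f)."""
    q = softmax(eta)
    K = len(eta)
    # d q_k / d eta = q_k * (e_k - q), so the exact gradient is sum_k f(k) q_k (e_k - q).
    return sum(f(k) * q[k] * (np.eye(K)[k] - q) for k in range(K))

def reinforce_gradient(eta, f, rng):
    """Approach 2: score-function (REINFORCE) estimator from a single sample.
    Unbiased, but it can have high variance."""
    q = softmax(eta)
    K = len(eta)
    z = rng.choice(K, p=q)
    # g(z) = f(z) * d log q_eta(z) / d eta = f(z) * (e_z - q)
    return f(z) * (np.eye(K)[z] - q)
```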

  10–13. Our method
  Suppose g is an unbiased estimator of the gradient, so that
        ∇_η L(η) = E_{q_η(z)}[ g(z) ] = Σ_{k=1}^{K} q_η(k) g(k).
  Key observation: in many applications (e.g. variational Bayes), q_η(z) is concentrated on only a few categories.
  Our idea: analytically sum the categories where q_η(z) has high probability, and estimate the remaining terms by sampling.

  14–15. Our method
  In math,
        Σ_{k=1}^{K} q_η(k) g(k)  =  Σ_{z ∈ C_α} q_η(z) g(z)  +  (1 − q_η(C_α)) · E_{q_η(z)}[ g(z) | z ∉ C_α ],
  where the first term is summed analytically, the factor (1 − q_η(C_α)) is small, and the conditional expectation is estimated by sampling. (A sketch of this decomposition follows below.)
  The variance reduction is guaranteed by representing our estimator as an instance of Rao-Blackwellization.
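
A minimal sketch of this decomposition, under the same illustrative softmax assumptions as above. Here C_α is taken to be the α highest-probability categories, and g(k, eta) stands for any per-category gradient estimate; the value of α and the helper names are assumptions for illustration, not the paper's interface.

```python
import numpy as np

def softmax(eta):
    e = np.exp(eta - eta.max())
    return e / e.sum()

def rao_blackwellized_gradient(eta, g, alpha, rng):
    """Sum the high-probability set C_alpha analytically; estimate the rest from one
    sample of z drawn from q_eta restricted to the remaining categories."""
    q = softmax(eta)
    C_alpha = np.argsort(q)[-alpha:]                    # alpha most probable categories
    rest = np.setdiff1d(np.arange(len(q)), C_alpha)

    # Term 1: analytically sum over C_alpha.
    exact_part = sum(q[k] * g(k, eta) for k in C_alpha)

    # Term 2: (1 - q_eta(C_alpha)) * E[g(z) | z not in C_alpha], estimated by
    # sampling z from the renormalized distribution over the remaining categories.
    mass_rest = q[rest].sum()
    if mass_rest <= 0.0:
        return exact_part
    z = rng.choice(rest, p=q[rest] / mass_rest)
    return exact_part + mass_rest * g(z, eta)
```

Averaged over draws of z, this matches the full sum exactly; the sampled term is scaled by the small mass 1 − q_η(C_α), which is where the variance reduction comes from.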

  16–17. Results: Generative semi-supervised classification
  We train a classifier to predict the class label of MNIST digits, and we learn a generative model of MNIST digits conditional on the class label. Our objective is to maximize the evidence lower bound (ELBO),
        log p_η(x) ≥ E_{q_η(z)}[ log p_η(x, z) − log q_η(z) ].
  In this problem, the class label z has ten discrete categories.
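
Since K = 10 here, the exact sum is cheap, which makes this problem a convenient toy sanity check for the estimator sketched earlier. The snippet below builds on the two sketches above; the logits and the stand-in for the per-label ELBO term are made up for illustration, not taken from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
eta = rng.normal(size=10)           # logits over the 10 class labels (illustrative)
f = lambda k: float((k - 3) ** 2)   # stand-in for the per-label ELBO integrand

# REINFORCE-style per-category estimate: g(k) = f(k) * grad log q_eta(k).
g = lambda k, eta: f(k) * (np.eye(10)[k] - softmax(eta))

exact = exact_gradient(eta, f)                              # full sum over 10 labels
rb = rao_blackwellized_gradient(eta, g, alpha=3, rng=rng)   # sum top 3, sample the rest
# Averaging rb over many draws should recover `exact`, with lower variance than REINFORCE.
```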

  18–19. Results: Generative semi-supervised classification (figures)

  20–22. Results: moving MNIST
  We train a generative model for non-centered MNIST digits. To do so, we must first learn the location of the MNIST digit. There are 68 × 68 discrete categories, so computing the exact sum is intractable!
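
A rough back-of-the-envelope count of why: the full sum needs one generative-network evaluation per candidate location at every gradient step, while the Rao-Blackwellized estimator needs only |C_α| analytic terms plus one sampled term. The value of α below is illustrative.

```python
K = 68 * 68            # 4,624 possible digit locations
alpha = 5              # illustrative size of the analytically summed set C_alpha
evals_exact_sum = K    # generative-network evaluations per step for the full sum
evals_rb = alpha + 1   # evaluations for the Rao-Blackwellized estimator
print(evals_exact_sum, evals_rb)   # 4624 vs. 6
```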

  23. Results: moving MNIST (figures: trajectory of the negative ELBO; reconstructions of MNIST digits)

  24. Our paper: Rao-Blackwellized Stochastic Gradients for Discrete Distributions, https://arxiv.org/abs/1810.04777
  Our code: https://github.com/Runjing-Liu120/RaoBlackwellizedSGD
  The collaboration: Bryan Liu, Jeffrey Regier, Nilesh Tripuraneni, Michael I. Jordan, Jon McAuliffe
