ARSM: Augment-REINFORCE-Swap-Merge Estimator for Gradient Backpropagation Through Categorical Variables
Mingzhang Yin*, Yuguang Yue*, Mingyuan Zhou
The University of Texas at Austin: Department of Statistics and Data Sciences; IROM Department, McCombs School of Business
International Conference on Machine Learning, Long Beach, CA, June 13, 2019
Categorical latent variable optimization

Goal: maximize the expectation with respect to categorical variables
E(φ) = ∫ f(z) q_φ(z) dz = E_{z∼q_φ(z)}[f(z)]

Notation:
1. f(z) is the reward function for a categorical z
2. z = (z_1, ..., z_K) ∈ {1, 2, ..., C}^K is a K-dimensional C-way multivariate categorical vector
3. q_φ(z) = ∏_{k=1}^K Categorical(z_k; σ(φ_k)) is the categorical distribution whose parameters φ ∈ R^{K×C} need to be optimized

Challenge: it is difficult to estimate ∇_φ E(φ), especially for large K and C.
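For reference, below is a minimal numpy sketch of the naive single-sample REINFORCE (score-function) estimator for the K = 1 case; the function name and interface are illustrative and not taken from the paper. Its high variance is the difficulty the slide refers to.

```python
import numpy as np

def reinforce_gradient(phi, f, rng=None):
    """Single-sample REINFORCE estimate of d/dphi E_{z ~ Cat(sigma(phi))}[f(z)].

    A sketch for a single C-way categorical (K = 1); f maps a category index
    in {0, ..., C-1} to a scalar reward. Unbiased but typically high-variance.
    """
    rng = np.random.default_rng() if rng is None else rng
    prob = np.exp(phi - phi.max())
    prob /= prob.sum()                      # sigma(phi): softmax probabilities
    z = rng.choice(len(phi), p=prob)        # draw z ~ Categorical(sigma(phi))
    onehot = np.zeros(len(phi))
    onehot[z] = 1.0
    return f(z) * (onehot - prob)           # f(z) * grad_phi log q_phi(z)
```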
Derivation of ARSM

Augment: the categorical variable z ∼ Cat(σ(φ)) can be equivalently generated as
z = argmin_{i ∈ {1,...,C}} π_i e^{−φ_i},  π ∼ Dir(1_C).
Thus E(φ) = E_{z∼q_φ(z)}[f(z)] = E_{π∼Dir(1_C)}[f(argmin_i π_i e^{−φ_i})].

REINFORCE: ∇_φ E(φ) = E_{π∼Dir(1_C)}[f(argmin_i π_i e^{−φ_i})(1 − Cπ)]

Swap: swapping the i-th and j-th elements of π would not change the expectation, which is a property used to provide self-controlled variance reduction (without any tuning parameters).

Merge: sharing random numbers between differently expressed but equivalent expectations leads to
∇_{φ_c} E(φ) = E_{π∼Dir(1_C)}[g_ARSM(π)_c],
g_ARSM(π)_c := (1/C) Σ_{j=1}^C [ f(z^{c⇌j}) − (1/C) Σ_{m=1}^C f(z^{m⇌j}) ] (1 − Cπ_j)
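A minimal numpy sketch of the ARSM gradient above for a single C-way categorical variable (K = 1), following the augment, swap, and merge steps on this slide; the function name and interface are my own, and the naive double loop over swaps is written for clarity rather than efficiency.

```python
import numpy as np

def arsm_gradient(phi, f, rng=None):
    """One-sample ARSM estimate of d/dphi E_{z ~ Cat(sigma(phi))}[f(z)].

    A sketch for a single C-way categorical (K = 1); f maps a category index
    in {0, ..., C-1} to a scalar reward.
    """
    rng = np.random.default_rng() if rng is None else rng
    C = len(phi)
    pi = rng.dirichlet(np.ones(C))                  # Augment: pi ~ Dir(1_C)

    # Evaluate f at every swapped pseudo action z^{c<->j}.
    f_swap = np.empty((C, C))
    for c in range(C):
        for j in range(C):
            pi_cj = pi.copy()
            pi_cj[c], pi_cj[j] = pi[j], pi[c]       # Swap: exchange pi_c and pi_j
            z_cj = np.argmin(pi_cj * np.exp(-phi))  # z^{c<->j} = argmin_i pi^{c<->j}_i e^{-phi_i}
            f_swap[c, j] = f(z_cj)

    # Merge: centre each column by its mean, weight by (1 - C*pi_j), average over j.
    f_bar = f_swap.mean(axis=0)                     # (1/C) sum_m f(z^{m<->j})
    return ((f_swap - f_bar) * (1.0 - C * pi)).mean(axis=1)
```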
An illustrative example

Optimize φ ∈ R^C to maximize E_{z∼Cat(σ(φ))}[f(z)], with f(z) := 0.5 + z/(CR).

[Figure: one column per estimator (True, REINFORCE, Gumbel, RELAX, AR, ARS, ARSM), with rows tracing the reward, gradient, category probabilities, and gradient variance over 5000 iterations.]

Figure: The optimal solution is σ(φ) = (0, ..., 1). The reward is computed analytically as E_{z∼Cat(σ(φ))}[f(z)], with maximum 0.533.
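A hedged usage sketch for the toy problem above, reusing the arsm_gradient function from the previous slide. C = R = 30 are assumptions: the slide only fixes the maximum reward at 0.533 (i.e. 0.5 + 1/R with z ∈ {1, ..., C}) and does not state C and R explicitly; the step size and iteration count are likewise arbitrary choices for illustration.

```python
# Toy reward from this slide: f(z) = 0.5 + z / (C * R), z in {1, ..., C}.
# C = R = 30 is an assumption consistent with the 0.533 maximum in the caption.
C, R = 30, 30

def f(z):
    return 0.5 + (z + 1) / (C * R)          # +1: the sketch uses 0-indexed categories

phi = np.zeros(C)
for _ in range(5000):                        # plain gradient ascent, matching the 5000-iteration plots
    phi += 1.0 * arsm_gradient(phi, f)

prob = np.exp(phi - phi.max())
prob /= prob.sum()
# Analytic expected reward E_{z ~ Cat(sigma(phi))}[f(z)]; should move toward 0.533.
print(prob @ (0.5 + np.arange(1, C + 1) / (C * R)))
```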
VAEs with one or two categorical hidden layers (20-dimensional 10-way categorical)

[Figure: −ELBO (nats) versus training iterations (×1000) for REINFORCE, AR, RELAX, Gumbel-S., ARS, ARSM, Gumbel-S._2layer, and ARSM_2layer.]

Figure: Plots of negative ELBOs (nats) on binarized MNIST against training iterations. The solid and dashed lines correspond to training and testing, respectively.
Reinforcement Learning (a sequence of categorical actions)

Figure: Moving-average reward and log-variance of the gradient estimator. In each plot, the solid lines are the medians of ten independent runs, and the opaque bars mark the 10th and 90th percentiles.
Thank you! Welcome to our poster at Pacific Ballroom #85