CS7015 (Deep Learning): Lecture 21
Variational Autoencoders
Mitesh M. Khapra
Department of Computer Science and Engineering, Indian Institute of Technology Madras
Acknowledgments
Tutorial on Variational Autoencoders by Carl Doersch
Blog on Variational Autoencoders by Jaan Altosaar
Module 21.1: Revisiting Autoencoders
Before we start talking about VAEs, let us quickly revisit autoencoders.
An autoencoder contains an encoder which takes the input X and maps it to a hidden representation
$$h = g(WX + b)$$
The decoder then takes this hidden representation and tries to reconstruct the input from it as
$$\hat{X} = f(W^* h + c)$$
The training happens using the following objective function
$$\min_{W, W^*, c, b} \frac{1}{m} \sum_{i=1}^{m} \sum_{j=1}^{n} (\hat{x}_{ij} - x_{ij})^2$$
where m is the number of training instances $\{x_i\}_{i=1}^{m}$ and each $x_i \in \mathbb{R}^n$ ($x_{ij}$ is thus the j-th dimension of the i-th training instance).
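As a concrete reference point, here is a minimal sketch of this autoencoder and its objective in PyTorch. The layer sizes, the batch size, and the choices of a sigmoid for $g$ and the identity for $f$ are illustrative assumptions, not prescribed by the slide.

```python
import torch
import torch.nn as nn

# Minimal autoencoder sketch: h = g(Wx + b), x_hat = f(W* h + c).
class Autoencoder(nn.Module):
    def __init__(self, n=784, k=64):       # n: input dim, k: hidden dim (illustrative)
        super().__init__()
        self.enc = nn.Linear(n, k)         # W, b
        self.dec = nn.Linear(k, n)         # W*, c

    def forward(self, x):
        h = torch.sigmoid(self.enc(x))     # g = sigmoid (an assumption)
        return self.dec(h)                 # f = identity (an assumption)

model = Autoencoder()
x = torch.randn(32, 784)                   # a batch of m = 32 training instances
x_hat = model(x)
# (1/m) * sum_i sum_j (x_hat_ij - x_ij)^2, as in the objective above
loss = ((x_hat - x) ** 2).sum(dim=1).mean()
loss.backward()
```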
But where's the fun in this?
We are taking an input and simply reconstructing it.
Of course, the fun lies in the fact that we are getting a good abstraction of the input.
But RBMs were able to do something more besides abstraction (they were able to do generation).
Let us revisit generation in the context of autoencoders.
Can we do generation with autoencoders?
In other words, once the autoencoder is trained, can I remove the encoder, feed a hidden representation h to the decoder, and decode an $\hat{X}$ from it?
In principle, yes! But in practice there is a problem with this approach: h is a very high dimensional vector and only a few vectors in this space would actually correspond to meaningful latent representations of our input.
So of all the possible values of h, which values should I feed to the decoder? (We had asked a similar question before: slide 67, bullet 5 of Lecture 19.)
Ideally, we should only feed those values of h which are highly likely.
In other words, we are interested in sampling from $P(h|X)$ so that we pick only those h's which have a high probability.
But unlike RBMs, autoencoders do not have such a probabilistic interpretation: they learn a hidden representation h but not a distribution $P(h|X)$.
Similarly, the decoder is also deterministic and does not learn a distribution over X (given an h we can get an X but not $P(X|h)$).
We will now look at variational autoencoders, which have the same structure as autoencoders but learn a distribution over the hidden variables.
Module 21.2: Variational Autoencoders: The Neural Network Perspective
Let $X = \{x_i\}_{i=1}^{N}$ be the training data.
We can think of X as a random variable in $\mathbb{R}^n$. For example, X could be an image and the dimensions of X correspond to pixels of the image.
We are interested in learning an abstraction (i.e., given an X, find the hidden representation z). (Figure: Abstraction)
We are also interested in generation (i.e., given a hidden representation, generate an X). (Figure: Generation)
In probabilistic terms, we are interested in $P(z|X)$ and $P(X|z)$ (to be consistent with the literature on VAEs we will use z instead of H and X instead of V).
Earlier we saw RBMs where we learnt $P(z|X)$ and $P(X|z)$. Below we list certain characteristics of RBMs.
Structural assumptions: We assume certain independencies in the Markov network.
Computational: When training with Gibbs sampling we have to run the Markov chain for many time steps, which is expensive.
Approximation: When using contrastive divergence, we approximate the expectation by a point estimate.
(Nothing wrong with the above, but we just mention them to make the reader aware of these characteristics.)
(Figure: RBM with visible units $V \in \{0,1\}^m$, hidden units $H \in \{0,1\}^n$, weights $W \in \mathbb{R}^{m \times n}$ and biases b, c.)
We now return to our goals.
Goal 1: Learn a distribution over the latent variables ($Q(z|X)$).
Goal 2: Learn a distribution over the visible variables ($P(X|z)$).
VAEs use a neural network based encoder for Goal 1 and a neural network based decoder for Goal 2. We will look at the encoder first.
(Figure: the VAE pipeline, from data X through the encoder $Q_\theta(z|X)$ to z, and through the decoder $P_\phi(X|z)$ to the reconstruction $\hat{X}$.)
$\theta$: the parameters of the encoder neural network; $\phi$: the parameters of the decoder neural network.
Encoder: What do we mean when we say we want to learn a distribution? We mean that we want to learn the parameters of the distribution.
But what are the parameters of $Q(z|X)$? Well, it depends on our modeling assumption!
In VAEs we assume that the latent variables come from a standard normal distribution $\mathcal{N}(0, I)$, and the job of the encoder is to then predict the parameters of this distribution.
Here $X \in \mathbb{R}^n$, $\mu \in \mathbb{R}^m$ and $\Sigma \in \mathbb{R}^{m \times m}$.
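A minimal PyTorch sketch of such an encoder is shown below. The single hidden layer and its size are illustrative, and, as is common in practice (though the slide writes a general $\Sigma \in \mathbb{R}^{m \times m}$), the sketch predicts only a diagonal covariance via log-variances.

```python
import torch
import torch.nn as nn

# Encoder sketch: maps X to the parameters (mu, Sigma) of Q_theta(z|X).
# Assumption: Sigma is diagonal, represented by a vector of log-variances.
class Encoder(nn.Module):
    def __init__(self, n=784, hidden=400, m=20):  # sizes are illustrative
        super().__init__()
        self.fc = nn.Linear(n, hidden)
        self.mu = nn.Linear(hidden, m)             # mu in R^m
        self.log_var = nn.Linear(hidden, m)        # log of the diagonal of Sigma

    def forward(self, x):
        h = torch.relu(self.fc(x))
        return self.mu(h), self.log_var(h)
```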
Now what about the decoder? The job of the decoder is to predict a probability distribution over X: $P(X|z)$.
Once again we will assume a certain form for this distribution. For example, if we want to predict 28 x 28 pixels and each pixel belongs to $\mathbb{R}$ (i.e., $X \in \mathbb{R}^{784}$), then what would be a suitable family for $P(X|z)$?
We could assume that $P(X|z)$ is a Gaussian distribution with unit variance. The job of the decoder f would then be to predict the mean of this distribution as $f_\phi(z)$.
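Under that assumption, the decoder is just a network that outputs the mean $f_\phi(z)$. A minimal sketch, continuing the imports from the encoder sketch above (hidden size again illustrative):

```python
# Decoder sketch: predicts the mean f_phi(z) of P_phi(X|z) = N(f_phi(z), I).
class Decoder(nn.Module):
    def __init__(self, m=20, hidden=400, n=784):  # sizes are illustrative
        super().__init__()
        self.fc1 = nn.Linear(m, hidden)
        self.fc2 = nn.Linear(hidden, n)

    def forward(self, z):
        return self.fc2(torch.relu(self.fc1(z)))   # f_phi(z), the mean of X given z
```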
What would be the objective function of the decoder?
For any given training sample $x_i$ it should maximize $P(x_i)$, given by
$$P(x_i) = \int P(z) P(x_i|z)\, dz$$
This integral is intractable, so instead we sample z from $Q_\theta(z|x_i)$ and maximize the likelihood of $x_i$ under the decoder; the corresponding (reconstruction) loss to minimize is
$$- \mathbb{E}_{z \sim Q_\theta(z|x_i)} [\log P_\phi(x_i|z)]$$
(As usual, we take log for numerical stability.)
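In code, a one-sample Monte Carlo estimate of this reconstruction term might look as follows, continuing the encoder/decoder sketches above. Drawing z as $\mu + \sigma \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$ is the reparameterization trick, which these slides have not introduced yet; it is used here only to keep the sampling step differentiable. With $P(X|z) = \mathcal{N}(f_\phi(z), I)$, the negative log-likelihood is a squared error up to an additive constant.

```python
encoder, decoder = Encoder(), Decoder()
x = torch.randn(32, 784)                     # a batch of inputs (placeholder data)

mu, log_var = encoder(x)                     # parameters of Q_theta(z|x)
eps = torch.randn_like(mu)                   # epsilon ~ N(0, I)
z = mu + torch.exp(0.5 * log_var) * eps      # one sample z ~ Q_theta(z|x)
x_hat = decoder(z)                           # f_phi(z)
# -log P_phi(x|z) for a unit-variance Gaussian, up to an additive constant:
recon = 0.5 * ((x - x_hat) ** 2).sum(dim=1)  # one value per example in the batch
```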
This is the loss function for one data point, $l_i(\theta, \phi)$, and we will just sum over all the data points to get the total loss
$$\mathcal{L}(\theta, \phi) = \sum_{i=1}^{m} l_i(\theta, \phi)$$
In addition, we also want a constraint on the distribution over the latent variables. Specifically, we had assumed $P(z)$ to be $\mathcal{N}(0, I)$ and we want $Q(z|X)$ to be as close to $P(z)$ as possible.
Thus, we will modify the loss function using the KL divergence, which captures the difference (or distance) between two distributions:
$$l_i(\theta, \phi) = - \mathbb{E}_{z \sim Q_\theta(z|x_i)} [\log P_\phi(x_i|z)] + KL(Q_\theta(z|x_i)\,||\,P(z))$$
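Putting the pieces together, a per-batch loss in the spirit of this slide might be written as below (again a sketch continuing the snippets above, under the diagonal-Gaussian assumption; the KL term uses the standard closed form for $KL(\mathcal{N}(\mu, \mathrm{diag}(\sigma^2))\,\|\,\mathcal{N}(0, I))$).

```python
def vae_loss(x, encoder, decoder):
    # Reconstruction term: -E_{z ~ Q(z|x)}[log P(x|z)], one-sample estimate.
    mu, log_var = encoder(x)
    z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)
    x_hat = decoder(z)
    recon = 0.5 * ((x - x_hat) ** 2).sum(dim=1)
    # KL(Q(z|x) || N(0, I)) in closed form for a diagonal Gaussian:
    kl = 0.5 * (torch.exp(log_var) + mu ** 2 - 1.0 - log_var).sum(dim=1)
    return (recon + kl).sum()   # total loss: sum of l_i over the data points

loss = vae_loss(x, encoder, decoder)
loss.backward()                 # theta and phi are trained jointly by backpropagation
```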
The second term in the loss function can actually be thought of as a regularizer.
It ensures that the encoder does not cheat by mapping each $x_i$ to a different point (a normal distribution with very low variance) in the Euclidean space. In other words, in the absence of the regularizer the encoder could learn a unique mapping for each $x_i$ and the decoder could then decode from this unique mapping.
Even with high variance in samples from the distribution, we want the decoder to be able to reconstruct the original data very well (the motivation is similar to that of adding noise).
To summarize, for each data point we predict a distribution such that, with high probability, a sample from this distribution should be able to reconstruct the original data point.
But why do we choose a normal distribution? Isn't it too simplistic to assume that z follows a normal distribution?
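For the diagonal-Gaussian encoder assumed in the sketches above (the slide itself writes a general $\Sigma$), the KL term has a closed form that makes this intuition concrete: the $-\log \sigma_j^2$ piece blows up if the encoder predicts a near-zero variance, and the $\mu_j^2$ piece keeps the predicted means from drifting arbitrarily far apart.

$$KL\big(\mathcal{N}(\mu, \mathrm{diag}(\sigma^2))\,\|\,\mathcal{N}(0, I)\big) = \frac{1}{2}\sum_{j=1}^{m} \left(\sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2\right)$$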
Isn't it a very strong assumption that $P(z) \sim \mathcal{N}(0, I)$? For example, in the 2-dimensional case, how can we be sure that $P(z)$ is a normal distribution and not any other distribution?
The key insight here is that any distribution in d dimensions can be generated by the following two steps (see the sketch below):
Step 1: Start with a set of d variables that are normally distributed (that's exactly what we are assuming for $P(z)$).
Step 2: Map these variables through a sufficiently complex function (that's exactly what the first few layers of the decoder can do).
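Below is a tiny illustration of this two-step recipe (similar in spirit to an example in Doersch's tutorial): 2-dimensional standard normal samples are pushed through a fixed deterministic map and end up distributed roughly on a ring, a decidedly non-Gaussian shape. The specific map $g(z) = z/10 + z/\|z\|$ is an illustrative choice.

```python
import torch

# Step 1: d = 2 variables, normally distributed (what we assume for P(z)).
z = torch.randn(10000, 2)

# Step 2: push them through a (here hand-picked) function g; a learned decoder's
# first few layers can play the same role.
g = z / 10.0 + z / z.norm(dim=1, keepdim=True)

# g is now concentrated around a ring of radius ~1, i.e., a clearly non-Gaussian
# distribution obtained from purely Gaussian inputs.
print(g.norm(dim=1).mean(), g.norm(dim=1).std())
```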