MaskGAN: Better Text Generation via Filling in the ______
June 5, 2018
Sungjae Cho (Interdisciplinary Program in Cognitive Science)
sj.cho@snu.ac.kr
SNU Spoken Language Processing Lab, Seoul National University
Abstract
- Maximum likelihood training with teacher forcing can result in poor sample quality, because generating text requires conditioning on sequences of words that may never have been observed at training time.
- This paper introduces MaskGAN, an actor-critic conditional GAN.
- MaskGAN produces more realistic conditional and unconditional text samples than a maximum-likelihood-trained model.
Prerequisites
- GAN (Goodfellow et al., 2014)
- Reinforcement learning: reward, V-value, Q-value, advantage
- Mode collapse (= mode dropping)
- Policy gradient
- Seq2seq model (Sutskever et al., 2014)
- REINFORCE algorithm
- Maximum likelihood estimation
- Actor-critic training algorithm
- Stochastic gradient descent
- Pretraining
- Autoregression (autoregressive generation)
- BLEU score, n-grams
- (Validation) perplexity
Motivations (from 1. Introduction & 2. Related Works)
- Maximum-likelihood RNNs are the most common generative models for sequences.
- Teacher forcing leads to unstable dynamics in the hidden states.
- Professor forcing solves the above but does not encourage high sample quality.
- GANs have shown incredible sample quality for images, but the discrete nature of text makes training a generator harder.
- A reinforcement learning framework can be leveraged to train the generator via policy gradients.
1. Introduction
- GANs have seen only limited use for text sequences.
- This is due to the discrete nature of text, which makes it infeasible to propagate the gradient from the discriminator back to the generator as in standard GAN training.
- We overcome this by using reinforcement learning (RL) to train the generator, while the discriminator is still trained via maximum likelihood and stochastic gradient descent.
2. Related Works
Main related works:
- SeqGAN (Yu et al., 2017): trains a language model by using policy gradients to train the generator to fool a CNN-based discriminator that discriminates between real and synthetic text.
- Professor Forcing (Lamb et al., 2016): an alternative to training an RNN with teacher forcing; a discriminator discriminates between the hidden states of a generator RNN conditioned on real samples and on synthetic samples.
- GANs for dialogue generation (Li et al., 2017): applies REINFORCE with Monte Carlo sampling on the generator.
- An actor-critic algorithm for sequence prediction (Bahdanau et al., 2017): the rewards are task-specific scores such as BLEU, instead of being supplied by a discriminator in an adversarial setting.
2. Related Works
Our work is distinct in that it uses:
- An actor-critic training procedure on a task designed to provide rewards at every time step (cf. Li et al., 2017)
- An in-filling task that may mitigate the problem of severe mode collapse
- A critic that helps the generator converge more rapidly by reducing the high variance of the gradient updates
3. MaskGAN | 3.1. Notation
- $x_t$: an input token at time $t$
- $y_t$, $x_t^{real}$: a target token at time $t$
- $<m>$: a masked token (the original token is replaced with a hidden token)
- $\hat{x}_t$: the filled-in token for the $t$-th word
- $\tilde{x}_t$: a filled-in token passed to the discriminator ($\tilde{x}_t = \hat{x}_t$); it may be either real or fake.
3. MaskGAN | 3.2. Architecture | Notation
- $\boldsymbol{x} = (x_1, \ldots, x_T)$: a discrete sequence
- $\boldsymbol{m} = (m_1, \ldots, m_T)$: a binary mask of the same length, generated deterministically or stochastically, with $m_t \in \{0, 1\}$
  - $m_t$ selects whether the token at time $t$ remains.
- $\boldsymbol{m}(\boldsymbol{x})$: the masked sequence (the original real context)
  - If $\boldsymbol{x} = (x_1, x_2, x_3)$ and $\boldsymbol{m} = (1, 0, 1)$, then $\boldsymbol{m}(\boldsymbol{x}) = (x_1, <m>, x_3)$.
3. MaskGAN | 3.2. Architecture | Problem Setup
Start with a ground-truth discrete sequence $\boldsymbol{x} = (x_1, \ldots, x_T)$ and a binary mask of the same length, $\boldsymbol{m} = (m_1, \ldots, m_T)$. Applying the mask to the input sequence creates $\boldsymbol{m}(\boldsymbol{x})$, a sequence with blanks. For example:

  x:    a b c d e
  m:    1 0 0 1 1
  m(x): a _ _ d e

The goal of the generator is to autoregressively fill in the missing tokens, conditioned on the previous tokens and the mask.
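Below is a minimal sketch of this masking step (the function name and the mask-token string are mine for illustration, not from the paper):

```python
from typing import List

MASK_TOKEN = "<m>"  # hidden-token placeholder, written <m> in the slides

def apply_mask(x: List[str], m: List[int]) -> List[str]:
    """Keep x_t where m_t = 1; replace it with the hidden token where m_t = 0."""
    assert len(x) == len(m)
    return [x_t if m_t == 1 else MASK_TOKEN for x_t, m_t in zip(x, m)]

# Example from the slide: x = (a, b, c, d, e), m = (1, 0, 0, 1, 1)
print(apply_mask(["a", "b", "c", "d", "e"], [1, 0, 0, 1, 1]))
# ['a', '<m>', '<m>', 'd', 'e']
```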
3. MaskGAN | 3.2. Architecture | Generator
Generator architecture: a seq2seq encoder-decoder.
- Input: 650-dimensional embeddings (soft embedding).
- Output: a vocabulary-sized distribution (one-hot embedding).
- The encoder reads in the masked sequence.
- The decoder imputes the missing tokens using the encoder hidden states, filling them in autoregressively.
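A rough sketch of such a generator, assuming LSTM encoder/decoder layers (class and variable names are illustrative; only the 650-dimensional embedding size comes from the slide):

```python
import torch
import torch.nn as nn

class MaskedSeq2SeqGenerator(nn.Module):
    def __init__(self, vocab_size: int, emb_dim: int = 650, hidden_dim: int = 650):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)
        self.encoder = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.decoder = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, vocab_size)  # distribution over the vocabulary

    def forward(self, masked_ids: torch.Tensor, prev_ids: torch.Tensor) -> torch.Tensor:
        # The encoder reads the masked sequence m(x); its final state conditions the decoder.
        _, enc_state = self.encoder(self.embed(masked_ids))
        # The decoder consumes the previously filled-in tokens (teacher-forced here for
        # brevity; at sample time it would feed back its own samples autoregressively).
        dec_out, _ = self.decoder(self.embed(prev_ids), enc_state)
        return self.out(dec_out)  # logits at every time step

# Toy usage: batch of 2 sequences of length 5 over a 10-token vocabulary.
gen = MaskedSeq2SeqGenerator(vocab_size=10)
logits = gen(torch.randint(0, 10, (2, 5)), torch.randint(0, 10, (2, 5)))
print(logits.shape)  # torch.Size([2, 5, 10])
```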
3. MaskGAN | 3.2. Architecture | Discriminator
[Figure: the discriminator's encoder reads the masked context $\boldsymbol{m}(\boldsymbol{x}) = (a, \_, \_, d, e)$; its decoder reads the filled-in sequence $(\tilde{x}_1{=}a, \tilde{x}_2{=}x, \tilde{x}_3{=}y, \tilde{x}_4{=}d, \tilde{x}_5{=}e)$ and outputs $P(\tilde{x}_t = x_t^{real})$ at every position.]
Discriminator architecture:
- The discriminator has an architecture identical to the generator's, except that the output at each time step is a scalar probability,
  $D_\phi(\tilde{x}_t \mid \tilde{x}_{0:T}, \boldsymbol{m}(\boldsymbol{x})) = P(\tilde{x}_t = x_t^{real} \mid \tilde{x}_{0:T}, \boldsymbol{m}(\boldsymbol{x}))$,
  rather than a distribution over the vocabulary (the generator case).
- Set the reward at time $t$ to $r_t \equiv \log D_\phi(\tilde{x}_t \mid \tilde{x}_{0:T}, \boldsymbol{m}(\boldsymbol{x}))$.
3. MaskGAN | 3.2. Architecture | Discriminator
[Same figure as the previous slide.]
- The discriminator is given the filled-in sequence $\tilde{x}_{0:T}$ from the generator.
- We also give the discriminator the true context $\boldsymbol{m}(\boldsymbol{x})$ of the real sequence $x_{0:T}^{real}$.
- The discriminator $D_\phi$ computes the probability of each token $\tilde{x}_t$ being real ($\tilde{x}_t = x_t^{real}$), given the true context of the masked sequence $\boldsymbol{m}(\boldsymbol{x})$.
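A small sketch of how the per-token probabilities turn into rewards, assuming the discriminator reuses the generator's seq2seq backbone but ends in a scalar head (names and the stand-in tensor are illustrative, not the authors' code):

```python
import torch
import torch.nn as nn

hidden_dim = 650
real_head = nn.Linear(hidden_dim, 1)        # scalar output instead of vocab-sized logits

dec_out = torch.randn(2, 5, hidden_dim)     # stand-in for the decoder hidden states
p_real = torch.sigmoid(real_head(dec_out)).squeeze(-1)  # P(x~_t = x_t^real | x~_{0:T}, m(x))
rewards = torch.log(p_real + 1e-8)          # r_t = log D_phi(x~_t | x~_{0:T}, m(x))
print(rewards.shape)                        # torch.Size([2, 5])
```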
3. MaskGAN | 3.2. Architecture | Critic
[Figure: the same discriminator diagram as before, with the critic head attached.]
- Reward: the log probability from the discriminator,
  $r_t = \log D_\phi(\tilde{x}_t \mid \tilde{x}_{0:T}, \boldsymbol{m}(\boldsymbol{x})) = \log P(\tilde{x}_t = x_t^{real} \mid \tilde{x}_{0:T}, \boldsymbol{m}(\boldsymbol{x}))$.
- Discounted total return: $R_t = \sum_{s=t}^{T} \gamma^s r_s$.
- The critic network is implemented as an additional head off the discriminator.
- It estimates the state value function of the filled-in sequence so far, $\hat{V}_t = V(\hat{x}_{0:t})$ (the baseline $b_t$), where the state is $s_t \equiv (\hat{x}_1, \ldots, \hat{x}_{t-1})$ and the action is $a_t \equiv \hat{x}_t$.
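A short sketch of the discounted return and the advantage used in the policy-gradient update, assuming per-token rewards from the discriminator and baseline values from the critic head (shapes, names, and the gamma value are illustrative):

```python
import torch

def discounted_returns(rewards: torch.Tensor, gamma: float = 0.99) -> torch.Tensor:
    """R_t = sum_{s=t}^{T} gamma^s * r_s, accumulated right to left over time."""
    returns = torch.zeros_like(rewards)
    running = torch.zeros(rewards.size(0))
    for t in reversed(range(rewards.size(1))):
        running = (gamma ** t) * rewards[:, t] + running
        returns[:, t] = running
    return returns

rewards = torch.log(torch.rand(2, 5) + 1e-8)       # stand-in for r_t = log D_phi(...)
values = torch.randn(2, 5)                         # stand-in for the critic's estimates V_t
advantages = discounted_returns(rewards) - values  # A_t = R_t - V_t lowers gradient variance
# Generator (REINFORCE-with-baseline) update: maximize E[ log pi(x_hat_t) * A_t ].
print(advantages.shape)  # torch.Size([2, 5])
```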