MaskGAN: Better Text Generation via Filling in the ______



  1. MaskGAN: Better Text Generation via Filling in the ______
     June 5, 2018
     Sungjae Cho (Interdisciplinary Program in Cognitive Science), sj.cho@snu.ac.kr
     Spoken Language Processing Lab, Seoul National University

  2. Abstract
     - Maximum likelihood training with teacher forcing can result in poor sample quality, since generating text requires conditioning on sequences of words that may never have been observed at training time.
     - This paper introduces MaskGAN, an actor-critic conditional GAN.
     - MaskGAN produces more realistic conditional and unconditional text samples than a maximum-likelihood-trained model.

  3. Prerequisites
     - GAN (Goodfellow et al., 2014)
     - Reinforcement learning: reward, V-value, Q-value, advantage
     - Mode collapse (= mode dropping)
     - Policy gradient
     - Seq2seq model (Sutskever et al., 2014)
     - REINFORCE algorithm
     - Maximum likelihood estimation
     - Actor-critic training algorithm
     - Stochastic gradient descent
     - Pretraining
     - Autoregression (autoregressively)
     - BLEU score, n-gram
     - (Validation) perplexity

  4. Motivations (from 1. Introduction & 2. Related Works)
     - Maximum-likelihood RNNs are the most common generative models for sequences.
     - Teacher forcing leads to unstable dynamics in the hidden states.
     - Professor forcing solves this, but does not encourage high sample quality.
     - GANs have shown incredibly high-quality samples for images, but the discrete nature of text makes training a generator harder.
     - A reinforcement learning framework can be leveraged to train the generator via policy gradients.

  5. 1. Introduction
     - GANs have seen only limited use for text sequences.
     - This is due to the discrete nature of text, which makes it infeasible to propagate the gradient from the discriminator back to the generator as in standard GAN training.
     - We overcome this by using reinforcement learning (RL) to train the generator, while the discriminator is still trained via maximum likelihood and stochastic gradient descent.

  6. 2. Related Works
     Main related works:
     - SeqGAN (Yu et al., 2017): trains a language model by using policy gradients to train the generator to fool a CNN-based discriminator that distinguishes real from synthetic text.
     - Professor Forcing (Lamb et al., 2016): an alternative to training an RNN with teacher forcing; a discriminator distinguishes the hidden states of a generator RNN conditioned on real versus synthetic samples.
     - GANs for dialogue generation (Li et al., 2017): applies REINFORCE with Monte Carlo sampling on the generator.
     - An actor-critic algorithm for sequence prediction (Bahdanau et al., 2017): the rewards are task-specific scores such as BLEU, instead of rewards supplied by a discriminator in an adversarial setting.

  7. 2. Related Works
     This work is distinct in:
     - an actor-critic training procedure on a task designed to provide rewards at every time step (cf. Li et al., 2017);
     - the in-filling task, which may mitigate the problem of severe mode collapse;
     - the critic, which helps the generator converge more rapidly by reducing the high variance of the gradient updates.

  8. 3. MaskGAN | 3.1. Notation
     - $x_t$: an input token at time $t$
     - $y_t = x_t^{\text{real}}$: a target token at time $t$
     - $\langle m \rangle$: a masked token (the original token is replaced with a hidden token)
     - $\hat{x}_t$: the filled-in token at time $t$, produced by the generator
     - $\tilde{x}_t$: a filled-in token passed to the discriminator; $\tilde{x}_t$ may be either real or fake

  9. 3. MaskGAN | 3.2. Architecture | Notations
     - $x = (x_1, \dots, x_T)$: a discrete sequence
     - $m = (m_1, \dots, m_T)$: a binary mask of the same length, generated deterministically or stochastically, with $m_t \in \{0, 1\}$; $m_t$ selects whether the token at time $t$ remains
     - $m(x)$: the masked sequence, i.e., the original real context. If $x = (x_1, x_2, x_3)$ and $m = (1, 0, 1)$, then $m(x) = (x_1, \langle m \rangle, x_3)$.

  10. 3. MaskGAN | 3.2. Architecture | Problem Setup
      Start with a ground-truth discrete sequence $x = (x_1, \dots, x_T)$ and a binary mask of the same length, $m = (m_1, \dots, m_T)$. Applying the mask to the input sequence creates $m(x)$, a sequence with blanks. For example (see the sketch below):

          $x$    = a b c d e
          $m$    = 1 0 0 1 1
          $m(x)$ = a _ _ d e

      The goal of the generator is to autoregressively fill in the missing tokens, conditioned on the previous tokens and the mask.
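Below is a minimal Python sketch of this masking step. The function name `apply_mask` and the `"<m>"` placeholder string are illustrative choices, not names from the paper; the tokens and mask mirror the slide's example.

```python
# Minimal sketch of computing m(x): keep a token where m_t = 1,
# replace it with the hidden token <m> where m_t = 0.
def apply_mask(tokens, mask, hidden="<m>"):
    return [tok if keep else hidden for tok, keep in zip(tokens, mask)]

x = ["a", "b", "c", "d", "e"]
m = [1, 0, 0, 1, 1]
print(apply_mask(x, m))  # ['a', '<m>', '<m>', 'd', 'e']
```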

  11. 3. MaskGAN | 3.2. Architecture | Generator
      Generator architecture:
      - Seq2seq encoder-decoder architecture
      - Input: 650-dimensional embedding (soft embedding)
      - Output: vocabulary-sized distribution (one-hot embedding)
      - The encoder reads in a masked sequence.
      - The decoder imputes the missing tokens using the encoder hidden states, filling them in autoregressively (a sketch follows).
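As a rough illustration, here is a PyTorch sketch of such a seq2seq generator, assuming LSTM encoder/decoder layers and the 650-dimensional embeddings mentioned above. Class and variable names are hypothetical, and for brevity it samples a token at every decoder step, whereas the paper's generator copies unmasked tokens through and only fills in the blanks.

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, vocab_size, emb_dim=650, hidden_dim=650):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)    # soft-embedding input
        self.encoder = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.decoder = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, vocab_size)      # distribution over the vocabulary

    def forward(self, masked_ids, max_len):
        # The encoder reads the masked sequence m(x).
        _, state = self.encoder(self.embed(masked_ids))
        tok = masked_ids[:, :1]                           # seed the decoder
        samples, log_probs = [], []
        for _ in range(max_len):
            out, state = self.decoder(self.embed(tok), state)
            dist = torch.distributions.Categorical(logits=self.out(out[:, -1]))
            nxt = dist.sample()                           # x̂_t ~ G(· | x̂_{<t}, m(x))
            log_probs.append(dist.log_prob(nxt))
            tok = nxt.unsqueeze(1)
            samples.append(tok)
        return torch.cat(samples, 1), torch.stack(log_probs, 1)

gen = Generator(vocab_size=10)
filled, logp = gen(torch.tensor([[1, 0, 0, 4, 5]]), max_len=5)  # 0 = <m> id (assumed)
```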

  12. 3. MaskGAN | 3.2. Architecture | Discriminator
      [Figure: the discriminator's encoder reads the masked context $m(x) = (\text{a}, \_, \_, \text{d}, \text{e})$; its decoder scores the filled-in tokens $\tilde{x}_{1:5} = (\text{a}, \text{x}, \text{y}, \text{d}, \text{e})$, outputting $P(\tilde{x}_t = x_t^{\text{real}} \mid \tilde{x}_{0:T}, m(x))$ at each step.]
      Discriminator architecture:
      - The discriminator has an architecture identical to the generator's, except that the output at each time step is a scalar probability, $D_\phi(\tilde{x}_t \mid \tilde{x}_{0:T}, m(x)) = P(\tilde{x}_t = x_t^{\text{real}} \mid \tilde{x}_{0:T}, m(x))$, rather than a distribution over the vocabulary size, as in the generator case.
      - Set the reward at time $t$ as $r_t \equiv \log D_\phi(\tilde{x}_t \mid \tilde{x}_{0:T}, m(x))$.

  13. 3. MaskGAN | 3.2. Architecture | Discriminator
      - The discriminator is given the filled-in sequence $\tilde{x}_{0:T}$ from the generator.
      - The discriminator is also given the true context $m(x)$, i.e., $x_{0:T}^{\text{real}}$ with the masked positions hidden.
      - The discriminator $D_\phi$ computes the probability of each token $\tilde{x}_t$ being real ($\tilde{x}_t = x_t^{\text{real}}$), given the true context of the masked sequence $m(x)$. A sketch follows.
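To make the input and output shapes concrete, here is a hedged PyTorch sketch of the discriminator, paired with the per-token reward from slide 12. It reuses the seq2seq shape of the generator sketch above but ends in a scalar-probability head; all names are illustrative, not the paper's code.

```python
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, vocab_size, emb_dim=650, hidden_dim=650):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)
        self.encoder = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.decoder = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.prob = nn.Linear(hidden_dim, 1)  # scalar per time step, not vocab-sized

    def forward(self, masked_ids, filled_ids):
        # Condition on the true context m(x), then score the filled-in x̃_{0:T}.
        _, state = self.encoder(self.embed(masked_ids))
        h, _ = self.decoder(self.embed(filled_ids), state)
        return torch.sigmoid(self.prob(h)).squeeze(-1)  # P(x̃_t = x_t^real | x̃_{0:T}, m(x))

disc = Discriminator(vocab_size=10)
p_real = disc(torch.tensor([[1, 0, 0, 4, 5]]),   # m(x), with 0 = <m> id (assumed)
              torch.tensor([[1, 7, 8, 4, 5]]))   # x̃: real a, d, e; fake fills at 2, 3
reward = torch.log(p_real + 1e-8)                # r_t ≡ log D_φ(x̃_t | x̃_{0:T}, m(x))
```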

  14. 3. MaskGAN | 3.2. Architecture | Critic
      Critic network:
      - The critic network is implemented as an additional head off the discriminator.
      - The reward at each step is the log-probability of being real: $r_t = \log D_\phi(\tilde{x}_t \mid \tilde{x}_{0:T}, m(x)) = \log P(\tilde{x}_t = x_t^{\text{real}} \mid \tilde{x}_{0:T}, m(x))$.
      - The critic estimates the state value function of the filled-in sequence, $V(\hat{x}_{0:t}) = b_t$, with the discounted total return $R_t = \sum_{s=t}^{T} \gamma^s r_s$, where the state is $(\hat{x}_1, \dots, \hat{x}_{t-1})$. A sketch of the resulting actor-critic update follows.
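Finally, a small sketch of how these quantities combine in the actor-critic update: the discounted return $R_t$, the critic baseline $b_t$, and the advantage $A_t = R_t - b_t$ weighting the generator's log-probabilities. The $\gamma^s$ discount inside the sum follows the slide's formula; the loss shapes and variable names are my own illustrative choices, not the paper's code.

```python
import torch

def discounted_returns(rewards, gamma=0.95):
    """R_t = sum_{s=t}^{T} gamma^s * r_s, accumulated backwards over time."""
    batch, T = rewards.shape
    returns = torch.zeros_like(rewards)
    running = torch.zeros(batch)
    for t in reversed(range(T)):
        running = gamma ** t * rewards[:, t] + running
        returns[:, t] = running
    return returns

# Stand-ins for quantities produced elsewhere during training:
rewards   = torch.log(torch.rand(2, 5) + 1e-8)        # r_t = log D_φ(...)
baseline  = torch.randn(2, 5, requires_grad=True)     # critic head output b_t = V(x̂_{0:t})
log_probs = torch.randn(2, 5, requires_grad=True)     # generator log G(x̂_t | x̂_{<t}, m(x))

returns = discounted_returns(rewards)
advantage = (returns - baseline).detach()             # A_t = R_t - b_t, no grad to the critic
gen_loss = -(advantage * log_probs).mean()            # REINFORCE with a learned baseline
critic_loss = ((baseline - returns) ** 2).mean()      # regress V toward the observed R_t
```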
