IN5550 – Neural Methods in Natural Language Processing
Attention!
Vinit Ravishankar, University of Oslo
April 4, 2019
Coming up:
Last week
◮ Gated RNNs
◮ Structured prediction
◮ RNN applications
Today
◮ Seq2seq
◮ Attention models
Recap: unrolled RNNs
◮ Each state $s_i$ and output $y_i$ depend on the full previous context, e.g. $s_4 = R(R(R(R(x_1, s_0), x_2), x_3), x_4)$
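A minimal sketch of this unrolling in PyTorch (the cell choice and the layer sizes are illustrative assumptions, not the lecture's code):

```python
import torch

# The same cell R is applied at every step, so
# s4 = R(R(R(R(x1, s0), x2), x3), x4).
R = torch.nn.RNNCell(input_size=10, hidden_size=20)

s = torch.zeros(1, 20)                        # s_0: initial state
xs = [torch.randn(1, 10) for _ in range(4)]   # x_1 .. x_4

for x in xs:
    s = R(x, s)                               # s_i = R(x_i, s_{i-1})
# s now holds s_4, conditioned on the entire input sequence
```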
Conditioned generation
◮ Generate words using an RNN and a ‘conditioning’ context vector $c$
◮ $p(t_{j+1}) = f(\mathrm{RNN}([\hat{t}_j, c], s_j))$
◮ Keep generating till you reach some maximum length, or generate </s>
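As a hedged sketch, the generation loop might look like the following; the vocabulary size, layer sizes, and the <s>/</s> indices are all illustrative assumptions:

```python
import torch

embed = torch.nn.Embedding(1000, 32)       # assumed vocab of 1000, dim 32
decoder = torch.nn.RNNCell(32 + 64, 128)   # input is [t̂_j embedding; c]
out_proj = torch.nn.Linear(128, 1000)      # state -> vocabulary logits

c = torch.randn(1, 64)                     # conditioning context vector
s = torch.zeros(1, 128)                    # decoder state s_0
tok = torch.tensor([1])                    # assume index 1 is <s>
EOS, max_len = 2, 50                       # assume index 2 is </s>

generated = []
for _ in range(max_len):
    s = decoder(torch.cat([embed(tok), c], dim=1), s)  # next state s_j
    tok = out_proj(s).argmax(dim=1)        # greedy choice from p(t_{j+1})
    if tok.item() == EOS:                  # stop on </s> or at max_len
        break
    generated.append(tok.item())
```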
Seq2seq - basic mode
◮ Words go in, words come out...
◮ Traditionally uses the last RNN state as the conditioning context
Seq2seq - why?
◮ Machine translation
◮ Summarisation
◮ Conversation modelling
Seq2seq
Seq2seq on steroids
“You can’t cram the meaning of a whole —ing sentence into a single —ing vector!” – Ray Mooney
◮ He’s not wrong; we can barely cram the meaning of a word into a single vector
◮ We could use multiple vectors, though
Attention
Idea: use a weighted sum of input RNN states for every output RNN state
Attention - mandatory maths
Recap, without attention: $p(t_{j+1}) = f(\mathrm{RNN}([\hat{t}_j, c], s_j))$
Now we’re using a separate context for every output element, i.e. a bunch of $c_j$s for $j = 1, 2, \ldots, T_y$
◮ $c_j$ is a weighted sum of input vectors, i.e. a weighted sum of $h_1, h_2, \ldots, h_{T_x}$
◮ The weights $\alpha$ are conditioned on the input state that they are weighting ($i$) and the output state they’re generating ($j$)
◮ i.e., $c_j = \sum_{i=1}^{T_x} \alpha_{ij} h_i$
◮ In English: the context vector that we use to generate the $j$th output is a weighted sum of all the input hidden states $h_i$.
Attention - mandatory maths
$c_j = \sum_{i=1}^{T_x} \alpha_{ij} h_i$
◮ How do we calculate these weights?
◮ Learn them while learning to translate.
◮ Use a ‘relevance’ function $a$ (called an ‘alignment model’) that tells you how relevant an input state $i$ is to an output token $j$
◮ Relevances: $e_{ij} = a(s_{j-1}, h_i)$
◮ Weights: $\alpha_{ij} = \mathrm{softmax}(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x} \exp(e_{kj})}$
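The slides don’t fix the form of $a$; one common assumption is additive scoring in the style of Bahdanau et al. A sketch with illustrative dimensions:

```python
import torch

T_x, h_dim, s_dim, a_dim = 6, 64, 128, 32
h = torch.randn(T_x, h_dim)          # encoder states h_1 .. h_{T_x}
s_prev = torch.randn(s_dim)          # previous decoder state s_{j-1}

W_h = torch.nn.Linear(h_dim, a_dim, bias=False)
W_s = torch.nn.Linear(s_dim, a_dim, bias=False)
v = torch.nn.Linear(a_dim, 1, bias=False)

# e_ij = a(s_{j-1}, h_i) = v^T tanh(W_s s_{j-1} + W_h h_i)  (assumed form)
e = v(torch.tanh(W_s(s_prev) + W_h(h))).squeeze(-1)   # relevances, (T_x,)
alpha = torch.softmax(e, dim=0)                       # weights α_ij, sum to 1
c_j = (alpha.unsqueeze(-1) * h).sum(dim=0)            # c_j = Σ_i α_ij h_i
```

All three projections are learned jointly with the rest of the translation model, which is what “learn them while learning to translate” amounts to.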
Attention - tl;dr
Pay attention to a weighted combination of input states to generate the right output state
Self-attention
John Lennon, 1967: love is all u need
Vaswani et al., 2017: Attention Is All You Need
Self-attention
Simple principle: instead of a target paying attention to different parts of the source, make the source pay attention to itself.
Okay, maybe that wasn’t so simple.
Self-attention
the man crossed the street because he fancied it
◮ By making parts of a sentence pay attention to other parts of itself, we get fancier representations
◮ This can be an RNN replacement
◮ Where an RNN carries long-term information down a chain, self-attention acts more like a tree
Transformer
Transformer
The important bit, the maths:
$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$
Transformer
$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$
What’s happening at a token level:
◮ Obtain three representations of the input, $Q$, $K$ and $V$ - query, key and value
◮ Obtain a set of relevance strengths: $QK^T$. For words $i$ and $j$, $Q_i \cdot K_j$ represents the strength of the association - exactly like in seq2seq attention.
◮ Scale it (stabler gradients, boring maths) and softmax for $\alpha$s.
◮ Unlike seq2seq, use different ‘value’ vectors to weight.
In a sense, this is exactly like seq2seq attention, except: a) non-recurrent representations, b) same source/target, c) different value vectors
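The whole formula fits in a few lines. A minimal sketch for a single sentence, with illustrative sizes (the learned projections producing $Q$, $K$, $V$ from the same input are what makes this *self*-attention):

```python
import torch

T, d_model, d_k = 8, 64, 64
x = torch.randn(T, d_model)            # one sentence of T tokens

W_q = torch.nn.Linear(d_model, d_k, bias=False)
W_k = torch.nn.Linear(d_model, d_k, bias=False)
W_v = torch.nn.Linear(d_model, d_k, bias=False)

Q, K, V = W_q(x), W_k(x), W_v(x)       # three views of the same input
scores = Q @ K.T / d_k ** 0.5          # QK^T / sqrt(d_k): relevance strengths
alpha = torch.softmax(scores, dim=-1)  # row i: how much token i attends to each j
out = alpha @ V                        # weighted sum of value vectors
```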
Adding heads
Revolutionary idea: if representations learn so much from attention, why not learn many attentions?
Multi-headed attention is many self-attentions
(Simplified) transformer:
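PyTorch ships a ready-made module for this; a minimal sketch, with the head count and sizes as illustrative assumptions:

```python
import torch

T, d_model, n_heads = 8, 64, 4
x = torch.randn(T, 1, d_model)         # (seq_len, batch, dim) layout

# n_heads independent attentions over d_model/n_heads-sized slices,
# run in parallel and concatenated back to d_model
mha = torch.nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads)
out, weights = mha(x, x, x)            # self-attention: query = key = value = x
# out: (T, 1, d_model); weights: (1, T, T), averaged over the 4 heads
```

Each head can learn a different notion of relevance (e.g. syntactic vs. positional associations), which is the point of learning many attentions.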
Transformer - why?
◮ it’s cool
◮ State-of-the-art for en-de NMT when released; state-of-the-art for en-fr (excluding ensembled models)
◮ No recurrence - it’s extremely fast (“1/4 the training resources for French”)
◮ Been used in a bunch of other tasks since
What’s next? - Multitask learning