Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas and François Fleuret
ICML, July 2020
https://linear-transformers.com/
Transformers are performant
Transformer models have demonstrated impressive performance on
◮ NLP (Vaswani et al., 2017; Devlin et al., 2019; Dai et al., 2019; Yang et al., 2019; Radford et al., 2019)
  ◮ Neural Machine Translation
  ◮ Question Answering
  ◮ Textual Entailment
◮ Speech & audio processing (Sperber et al., 2018)
◮ Autoregressive image generation and general computer vision (Child et al., 2019; Parmar et al., 2019; Carion et al., 2020; Cordonnier et al., 2020)
Transformers are hard to scale
Self-attention computation and memory scale as $O(N^2)$ with respect to the sequence length $N$.
[Figure: GPU memory (MB) and time (milliseconds) of a single self-attention layer versus sequence length (1000–4000), measured on an NVIDIA GTX 1080 Ti.]
Our contributions in a nutshell
◮ A transformer model with linear complexity both for memory and computation during training
◮ A transformer model with linear computational complexity and constant memory for autoregressive inference
◮ Unravel the relation between transformers and RNNs
Definition of a transformer
[Figure-only slide.]
Self-Attention
The commonly used attention mechanism is the scaled dot-product attention:
$$Q = X W_Q, \qquad K = X W_K, \qquad V = X W_V$$
$$A_l(X) = V' = \mathrm{softmax}\!\left(\frac{Q K^T}{\sqrt{D}}\right) V$$
The $\mathrm{softmax}\!\left(Q K^T / \sqrt{D}\right)$ term is where the quadratic complexity comes from.
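As a minimal sketch of the computation above (assuming a single attention head, an input `X` of shape `(N, D)`, and hypothetical projection matrices `W_q`, `W_k`, `W_v`), it might look as follows; the `(N, N)` score matrix is what makes the cost quadratic:

```python
# A minimal single-head sketch of scaled dot-product attention.
# X: (N, D) inputs; W_q, W_k, W_v: (D, D) hypothetical projection matrices.
import torch

def softmax_attention(X, W_q, W_k, W_v):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v      # (N, D) each
    scores = Q @ K.T / K.shape[-1] ** 0.5    # (N, N) -- quadratic in N
    A = torch.softmax(scores, dim=-1)        # row-wise attention weights
    return A @ V                             # (N, D) output V'
```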
Linear Attention
What if we write self-attention using an arbitrary similarity score?
$$V'_i = \frac{\sum_{j=1}^{N} \mathrm{sim}(Q_i, K_j)\, V_j}{\sum_{j=1}^{N} \mathrm{sim}(Q_i, K_j)}$$
What if this similarity is a kernel, namely $\mathrm{sim}(a, b) = \phi(a)^T \phi(b)$?
$$V'_i = \frac{\sum_{j=1}^{N} \phi(Q_i)^T \phi(K_j)\, V_j}{\sum_{j=1}^{N} \phi(Q_i)^T \phi(K_j)}$$
Matrix products are associative, which makes the attention computation $O(N)$ with respect to the sequence length:
$$V'_i = \frac{\phi(Q_i)^T \sum_{j=1}^{N} \phi(K_j)\, V_j^T}{\phi(Q_i)^T \sum_{j=1}^{N} \phi(K_j)}$$
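A minimal sketch of the non-causal linear attention above, assuming already-projected `Q`, `K`, `V` of shape `(N, D)` and a hypothetical row-wise feature map `phi` (any map producing positive features works for this sketch):

```python
# A minimal sketch of non-causal linear attention via associativity.
# Q, K, V: (N, D); phi maps (N, D) -> (N, M) feature space.
import torch

def linear_attention(Q, K, V, phi):
    Qp, Kp = phi(Q), phi(K)          # (N, M) each
    S = Kp.T @ V                     # (M, D): sum_j phi(K_j) V_j^T
    Z = Kp.sum(dim=0)                # (M,):   sum_j phi(K_j)
    num = Qp @ S                     # (N, D): phi(Q_i)^T S
    den = Qp @ Z                     # (N,):   phi(Q_i)^T Z
    return num / den.unsqueeze(-1)   # O(N) in sequence length
```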
Causal Masking
Causal masking is used to efficiently train autoregressive transformers.
Non-autoregressive:
$$V'_i = \frac{\sum_{j=1}^{N} \mathrm{sim}(Q_i, K_j)\, V_j}{\sum_{j=1}^{N} \mathrm{sim}(Q_i, K_j)}$$
Autoregressive:
$$V'_i = \frac{\sum_{j=1}^{i} \mathrm{sim}(Q_i, K_j)\, V_j}{\sum_{j=1}^{i} \mathrm{sim}(Q_i, K_j)}$$
With the kernel feature map these become
$$V'_i = \frac{\phi(Q_i)^T \overbrace{\textstyle\sum_{j=1}^{N} \phi(K_j)\, V_j^T}^{S}}{\phi(Q_i)^T \underbrace{\textstyle\sum_{j=1}^{N} \phi(K_j)}_{Z}}
\qquad \text{and} \qquad
V'_i = \frac{\phi(Q_i)^T \overbrace{\textstyle\sum_{j=1}^{i} \phi(K_j)\, V_j^T}^{S_i}}{\phi(Q_i)^T \underbrace{\textstyle\sum_{j=1}^{i} \phi(K_j)}_{Z_i}}$$
Naive computation of $S_i$ and $Z_i$ results in quadratic complexity.
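A minimal sketch of the causal case, computing $S_i$ and $Z_i$ with cumulative sums so the whole sequence is still processed in $O(N)$ time. For simplicity this sketch materializes all prefix sums (an `(N, M, D)` tensor); a memory-efficient implementation would keep only a running state, as in the recurrent view below.

```python
# A minimal sketch of causal linear attention via cumulative sums.
# Q, K, V: (N, D); phi maps (N, D) -> (N, M) feature space.
import torch

def causal_linear_attention(Q, K, V, phi):
    Qp, Kp = phi(Q), phi(K)                                      # (N, M)
    S = torch.cumsum(Kp.unsqueeze(-1) * V.unsqueeze(1), dim=0)   # (N, M, D): S_i
    Z = torch.cumsum(Kp, dim=0)                                  # (N, M):    Z_i
    num = torch.einsum("nm,nmd->nd", Qp, S)                      # phi(Q_i)^T S_i
    den = torch.einsum("nm,nm->n", Qp, Z)                        # phi(Q_i)^T Z_i
    return num / den.unsqueeze(-1)
```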
Transformers are RNNs
Autoregressive transformers can be written as a function that receives an input $x_i$, modifies the internal state $\{s_{i-1}, z_{i-1}\}$ and predicts an output $y_i$.
This yields autoregressive inference with linear complexity and constant memory.
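A minimal sketch of one recurrent inference step under the same assumptions (already-projected vectors `q_i`, `k_i`, `v_i` for the current position, state `(s, z)` initialized to zeros), mirroring the $S_i$, $Z_i$ recurrence: constant memory per step, linear total time.

```python
# A minimal sketch of one RNN-style inference step of linear attention.
# q_i, k_i, v_i: (D,) vectors for the current position; s: (M, D); z: (M,).
import torch

def recurrent_step(q_i, k_i, v_i, s, z, phi):
    s = s + torch.outer(phi(k_i), v_i)     # s_i = s_{i-1} + phi(k_i) v_i^T
    z = z + phi(k_i)                       # z_i = z_{i-1} + phi(k_i)
    y_i = (phi(q_i) @ s) / (phi(q_i) @ z)  # attention output for position i
    return y_i, s, z
```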
Practical implications
◮ Our theoretical analysis holds for all transformers, even when using infinite-dimensional feature maps
◮ We need a simple finite-dimensional feature map to speed up computation (a sketch of one such map follows below)
◮ We derive the gradients as cumulative sums, which allows for a significant speed-up
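The feature map used in the paper is the simple elementwise $\phi(x) = \mathrm{elu}(x) + 1$, which keeps all features positive so the normalizer $\phi(Q_i)^T Z_i$ stays well defined; a minimal sketch:

```python
# A minimal sketch of the elu(x) + 1 feature map.
import torch
import torch.nn.functional as F

def elu_feature_map(x):
    return F.elu(x) + 1.0  # strictly positive features, same dimensionality as x
```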
Experimental setup
Baselines
◮ Softmax transformer (Vaswani et al., 2017)
◮ LSH attention from Reformer (Kitaev et al., 2020)
Experiments
◮ Artificial benchmark for computational and memory requirements
◮ Autoregressive image generation on MNIST and CIFAR-10
◮ Automatic speech recognition on the Wall Street Journal dataset
Benchmark
[Figure: GPU memory (MB) and time (milliseconds) versus sequence length ($2^{10}$ to $2^{16}$), log-log scale, comparing softmax, lsh-1, lsh-4, lsh-8, and linear (ours).]