

  1. Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
     Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas and François Fleuret
     ICML, July 2020
     https://linear-transformers.com/

  2. Transformers are performant
     Transformer models have demonstrated impressive performance on
     ◮ NLP (Vaswani et al., 2017; Devlin et al., 2019; Dai et al., 2019; Yang et al., 2019; Radford et al., 2019)
       ◮ Neural Machine Translation
       ◮ Question Answering
       ◮ Textual Entailment
     ◮ Speech & audio processing (Sperber et al., 2018)
     ◮ Autoregressive image generation and general computer vision (Child et al., 2019; Parmar et al., 2019; Carion et al., 2020; Cordonnier et al., 2020)

  3. Transformers are hard to scale
     Self-attention computation and memory scale as O(N^2) with respect to the sequence length N.
     [Figure: GPU memory (MB) and time (milliseconds) versus sequence length, for a single self-attention layer on an NVIDIA GTX 1080 Ti.]

  4. Our contributions in a nutshell
     ◮ A transformer model with linear complexity both for memory and computation during training
     ◮ A transformer model with linear computational complexity and constant memory for autoregressive inference
     ◮ Unravel the relation between transformers and RNNs

  5. Definition of a transformer

  6. Self-Attention
     The commonly used attention mechanism is the scaled dot-product attention:

     Q = X W_Q, \quad K = X W_K, \quad V = X W_V

     A_l(X) = V' = \mathrm{softmax}\!\left( \frac{Q K^T}{\sqrt{D}} \right) V

     The softmax over the N x N matrix Q K^T is what gives this computation quadratic complexity in the sequence length.
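To make the quadratic cost concrete, here is a minimal single-head sketch of the scaled dot-product attention above, written in PyTorch. It is a sketch only: tensor and function names are illustrative, and multi-head splitting, masking, and dropout are omitted.

```python
import torch

def softmax_attention(X, W_Q, W_K, W_V):
    """Scaled dot-product self-attention for a single head.

    X: (N, D_model) token representations; W_Q, W_K, W_V: projection matrices.
    The explicit (N, N) score matrix is what makes time and memory O(N^2).
    """
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V          # each (N, D)
    D = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / D ** 0.5  # (N, N): quadratic in N
    A = torch.softmax(scores, dim=-1)            # row-wise attention weights
    return A @ V                                 # (N, D)
```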

  7. Linear Attention
     What if we write the self-attention using an arbitrary similarity score?

     V'_i = \frac{\sum_{j=1}^{N} \mathrm{sim}(Q_i, K_j) V_j}{\sum_{j=1}^{N} \mathrm{sim}(Q_i, K_j)}

     What if this similarity is a kernel, namely \mathrm{sim}(a, b) = \phi(a)^T \phi(b)? Kernelization gives

     V'_i = \frac{\sum_{j=1}^{N} \phi(Q_i)^T \phi(K_j) V_j}{\sum_{j=1}^{N} \phi(Q_i)^T \phi(K_j)}

     Matrix products are associative, which makes the attention computation O(N) with respect to the sequence length:

     V'_i = \frac{\phi(Q_i)^T \sum_{j=1}^{N} \phi(K_j) V_j^T}{\phi(Q_i)^T \sum_{j=1}^{N} \phi(K_j)}
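A minimal sketch of this linearized (non-causal) attention: the shared sums S = Σ_j φ(K_j) V_j^T and Z = Σ_j φ(K_j) are computed once and reused for every query, so the cost is linear in the sequence length. Inputs are assumed to be PyTorch tensors and the feature map `phi` is passed in; all names are illustrative, not the authors' API.

```python
def linear_attention(Q, K, V, phi):
    """Non-causal linear attention: V'_i = phi(Q_i)^T S / (phi(Q_i)^T Z).

    Q, K: (N, D); V: (N, M); phi is an element-wise feature map.
    S is (D, M) and Z is (D,), so nothing scales quadratically with N.
    """
    Q, K = phi(Q), phi(K)
    S = K.transpose(-2, -1) @ V              # sum_j phi(K_j) V_j^T -> (D, M)
    Z = K.sum(dim=0)                         # sum_j phi(K_j)       -> (D,)
    return (Q @ S) / (Q @ Z).unsqueeze(-1)   # (N, M)
```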

  8. Causal Masking
     Causal masking is used to efficiently train autoregressive transformers. In the autoregressive case the sums run only over positions j ≤ i instead of the whole sequence:

     Non-autoregressive:
     V'_i = \frac{\phi(Q_i)^T \sum_{j=1}^{N} \phi(K_j) V_j^T}{\phi(Q_i)^T \sum_{j=1}^{N} \phi(K_j)} = \frac{\phi(Q_i)^T S}{\phi(Q_i)^T Z}

     Autoregressive:
     V'_i = \frac{\phi(Q_i)^T \sum_{j=1}^{i} \phi(K_j) V_j^T}{\phi(Q_i)^T \sum_{j=1}^{i} \phi(K_j)} = \frac{\phi(Q_i)^T S_i}{\phi(Q_i)^T Z_i}

     Naive computation of S_i and Z_i results in quadratic complexity.
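A sketch of the causal case that keeps training linear by computing S_i and Z_i as prefix (cumulative) sums instead of recomputing them per position. This didactic version materializes an (N, D, M) tensor of running states; a production implementation would avoid that extra memory. Names are illustrative.

```python
import torch

def causal_linear_attention(Q, K, V, phi):
    """Causal linear attention: V'_i = phi(Q_i)^T S_i / (phi(Q_i)^T Z_i),
    with S_i = sum_{j<=i} phi(K_j) V_j^T and Z_i = sum_{j<=i} phi(K_j).
    """
    Q, K = phi(Q), phi(K)
    # Outer products phi(K_j) V_j^T for all positions, then prefix-sum them.
    S = torch.cumsum(K.unsqueeze(-1) * V.unsqueeze(-2), dim=0)  # (N, D, M)
    Z = torch.cumsum(K, dim=0)                                  # (N, D)
    num = torch.einsum('nd,ndm->nm', Q, S)                      # phi(Q_i)^T S_i
    den = torch.einsum('nd,nd->n', Q, Z).unsqueeze(-1)          # phi(Q_i)^T Z_i
    return num / den
```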

  9. Transformers are RNNs
     Autoregressive transformers can be written as a function that receives an input x_i, modifies the internal state {s_{i-1}, z_{i-1}}, and predicts an output y_i.
     This gives autoregressive inference with linear complexity and constant memory.
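A sketch of that recurrent view for the attention part of a layer: the state (S, Z) summarizes everything seen so far and is updated once per new token, so each generation step is O(1) in the sequence length and the memory footprint stays constant. Function and variable names are illustrative.

```python
import torch

def attention_rnn_step(x_i, state, W_Q, W_K, W_V, phi):
    """One autoregressive step of linear attention viewed as an RNN.

    x_i: (D_model,) current input; state = (S, Z) with S: (D, M), Z: (D,).
    Returns the attention output y_i and the updated state.
    """
    S, Z = state
    q = phi(x_i @ W_Q)            # (D,)
    k = phi(x_i @ W_K)            # (D,)
    v = x_i @ W_V                 # (M,)
    S = S + torch.outer(k, v)     # S_i = S_{i-1} + phi(K_i) V_i^T
    Z = Z + k                     # Z_i = Z_{i-1} + phi(K_i)
    y_i = (q @ S) / (q @ Z)       # (M,)
    return y_i, (S, Z)
```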

  10. Practical implications
      ◮ Our theoretical analysis holds for all transformers, even when using infinite-dimensional feature maps
      ◮ We need a simple finite-dimensional feature map to speed up computation (a sketch of one such feature map follows this list)
      ◮ We derive the gradients as cumulative sums, which allows for a significant speed-up
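As an example, the feature map used in the paper is φ(x) = elu(x) + 1, which is positive and cheap to evaluate element-wise. The snippet below is a sketch of how it would plug into the earlier illustrative functions; it is not the authors' released API.

```python
import torch
import torch.nn.functional as F

def elu_feature_map(x):
    """phi(x) = elu(x) + 1: a simple, positive, element-wise feature map."""
    return F.elu(x) + 1

# Illustrative usage (shapes chosen arbitrarily); assumes the
# causal_linear_attention sketch from the earlier slide is in scope.
N, D, M = 512, 64, 64
Q, K, V = torch.randn(N, D), torch.randn(N, D), torch.randn(N, M)
out = causal_linear_attention(Q, K, V, phi=elu_feature_map)  # (N, M)
```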

  11. Experimental setup
      Baselines
      ◮ Softmax transformer (Vaswani et al., 2017)
      ◮ LSH attention from Reformer (Kitaev et al., 2020)
      Experiments
      ◮ Artificial benchmark for computational and memory requirements
      ◮ Autoregressive image generation on MNIST and CIFAR-10
      ◮ Automatic speech recognition on Wall Street Journal

  12. Benchmark
      [Figure: log-log plots of GPU memory (MB) and time (milliseconds) versus sequence length (2^10 to 2^16) for softmax, lsh-1, lsh-4, lsh-8, and linear (ours) attention.]
