Rethinking Self Attention with Kernel and Rank (Towards More Efficient and Effective Transformer)


< 목차 >


Motivation

이번 Post 는 Rethinking Attention with Performers 라는 논문의 내용을 기반으로 작성되었습니다. 이 논문의 motivation 은 "Transformer 의 Attention 연산을 잘 분석해서 모델의 Space Complexity 와 Time Compelxtiy를 줄여보자" 입니다.

LRA_benchmark Fig. Long Range Arena Benchmark Score. Performer 는 Transformer 와 비교했을 때 성능면에서 조금 손해가 있으나 Long Range Sequence 에 대해 더 적은 메모리가 필요하며 (원이 작을수록 적은 메모리가 필요), 속도는 5배 이상 빠르다.

Performer (2020)

Standard Attention Module

Vanilla Tranformer 의 Scaled Dot Product Self-Attention (SA) Mechanism 을 Recap 해 봅시다. 이는 아래와 같은 수식을 따르는데요,

\[\operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V\]

Layer 의 Input Token 을 Query (Q), Key (K), Value (V) 로 변환한 뒤 \(Q^TK\) 를 통해 Token 간의 Relationship 을 계산하고 V 를 곱해 Token 들간의 정보를 서로 Mixing 해줍니다.

paltarion_mha Fig. Attention From Inputs to Outputs (Cross-Attention). Source from peltarion’s post

이 때의 시간 복잡도 (Time Compleixty) 와 공간 복잡도 (Space Complexity) 는 다음과 같은데

  • Space Complexity : \(O(L^2 + Ld)\)
  • Time Complexity : \(O(L^2d)\)

이는 \(Q,K,V\) 가 각각 \(\mathbb{R}^{L \times d}\) 일 때의 복잡도이며, 이 때 \(L\) 은 Sequence 의 길이 (Token의 개수) 이고 \(d\) 는 각 Token Vector 들의 크기 입니다.

attention_is_all_you_need_complexity Fig. Self Attention Mechanism’s Time Complexity

근데 왜 복잡도가 이렇게 Sequence 길이의 제곱 (Quadratic) 이 되는 걸까요?

맨 처음 입력 벡터들을 Q, K, V 로 각각 변환해 주는 것은 Linear Transformation 입니다. Q, K, V Projection 이 원래 차원을 유지하는 변환이라고 생각할 경우 \(d \times d\) 차원의 행렬과 \(L \times d\) 차원의 Input Matrix 를 곱하는 것이므로 \(O(n d^2)\) 의 연산량이 듭니다. 하지만 일반적으로 이 연산은 Self Attention 의 복잡도를 구할때 고려되지 않는다고 합니다. 이미 되어있다고 보고 Self-Attention 에 대해서만 계산하는거죠.

SA_time_comlexity Fig. Scaled Dot Product Self Attention. Source from link

SA_time_comlexity_with_proj Fig. Scaled Dot Product Multi Head Self Attention with Q, K, V Projection. Source from link

그다음으로는 \(Q^T K\) 인 행렬 곱을 해야되는데요, 어떤 \(n \times m\) 크기의 \(A\) 와 \(m \times p\) 크기의 \(B\) 행렬들이 있을 때 두 행렬의 곱인 \(C=AB\) 는 \(n \times p\) 크기가 되며

\[C_{ij} = \sum_{k=1}^{n} A_{ik} B_{kj}\]

이 때 연산은 아래와 같이 단순하게 구현할 수 있습니다.

Input: matrices A and B
Let C be a new matrix of the appropriate size
For i from 1 to n:
    For j from 1 to p:
        Let sum = 0
        For k from 1 to m:
            Set sum ← sum + Aik × Bkj
        Set Cij ← sum
Return C

즉 \(1 \times m\) 크기의 벡터들 p개 끼리 내적 (inner product; dot product) 하는 것을 n 번 반복 하는게 되는 겁니다. 이때의 time complexity 는 \(O(nmp)\) 가 되는데 만일 \(A,B\) 가 모두 \(n \times n\) 이라면 \(O(n^3)\) 가 됩니다. 그러니까 Self-Attention 의 경우 \(L \times d\) 인 두 행렬 \(Q,K\) 에 대해 \(Q^T K\) 연산을 취하므로 \(O(LdL) = O(L^2d)\) 만큼의 시간 복잡도가 드는 것입니다.

\[\frac{Q K^T}{\sqrt{d_k}}\]

여기에 \(\sqrt{d_k}\) 로 행렬을 Normalize 해주고 Row-wise 로 Softmax 연산을 취해주는 것은 row vector 를 순회해서 sum of exponential 을 해줘야 하는 것을 Token 갯수만큼 반복해야 하기 때문에 \(O(Ld)\) 가 될것이고,

\[\operatorname{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right)\]

마지막으로 \(V\) 를 곱하는것 또한 \(L \times L\) 행렬과 \(L \times d\) 간의 행렬곱이기 때문에 \(O(L^2 d)\) 만큼의 연산이 들기 때문에 최종적으로 \(O(2L^2d + Ld) = O(L^2d)\) 가 되는 것입니다.

Scaled Dot Product Self Attention in Performer Paper

Scaled Dot Product Self Attention 을 Performer 논문에서는 아래와 같이 표현하는데요,

\[\operatorname{Att}_{\leftrightarrow}(\mathbf{Q}, \mathbf{K}, \mathbf{V})=\mathbf{D}^{-1} \mathbf{A V}, \quad \mathbf{A}=\exp \left(\mathbf{Q} \mathbf{K}^{\top} / \sqrt{d}\right), \quad \mathbf{D}=\operatorname{diag}\left(\mathbf{A} \mathbf{1}_L\right)\] \[\operatorname{Att}_{\rightarrow}(\mathbf{Q}, \mathbf{K}, \mathbf{V})=\widetilde{\mathbf{D}}^{-1} \widetilde{\mathbf{A}} \mathbf{V}, \quad \widetilde{\mathbf{A}}=\operatorname{tril}(\mathbf{A}), \quad \widetilde{\mathbf{D}}=\operatorname{diag}\left(\tilde{\mathbf{A}} \mathbf{1}_L\right)\]

첫 번째는 일반적인 Self Attention 으로 Encoder 에서 수행되는 것이며, 두 번째는 Decoder 에서 수행되는 Unidirectional Self Attention 입니다.

Notation 이 좀 헷갈리실 수 있는데요, Q, K, V 는 원 논문의 \(L \times d\) 차원의 Query, Key, Value 이며 \(\mathbb{1}_L\) 은 모든 element 를 1로 갖는 길이가 L 인 vector 입니다.

먼저 \(Q,K\) 를 이용해 \(L \times L\) 차원의 Attention Score Matrix \(A\) 를 만들 수 있죠, 여기서 우리가 더 해줘야 하는것은

  • 입력 벡터의 크기로 Scaling 해주는 것
  • Softmax 함수로 Normalization 해줄 것
  • V 를 곱해 최종적인 Context Vector 들을 만들 것

입니다.

Scailing 은 그냥 해주면 되고요,

\[\mathbf{A}=\exp \left(\mathbf{Q} \mathbf{K}^{\top} / \sqrt{d}\right)\]

Softmax 를 그냥 해줄수도 있지만 논문에서는 \(\mathbb{1}_L\) 를 이용해서 이를 처리합니다. 이를 위해 먼저 아래의 연산을 해주는데

\[\mathbf{D}=\operatorname{diag}\left(\mathbf{A} \mathbf{1}_L\right)\]

행렬 \(A\) 에 \(L \times 1\) 짜리 \(\mathbb{1}_L\) 를 곱해주면 row-wise 로 모두 더하는 연산을 해주게 됩니다.

\[\mathbf{A} \mathbf{1}_L = \begin{bmatrix} A_{11} & A_{12} & \cdots & A_{1L} \\ \vdots & \ddots & \cdots & \cdots \\ A_{L1} & A_{L2} & \cdots & A_{LL} \\ \end{bmatrix} \begin{bmatrix} 1 \\ \vdots \\ 1 \\ \end{bmatrix} = \begin{bmatrix} \sum_{i=1}^{L} A_{1i} \\ \vdots \\ \sum_{i=1}^{L} A_{Li} \\ \end{bmatrix}\]

처음에 정의할 때 \(A\) matrix 자체를 \(QK^T\) 행렬에 exponential 연산까지 한 것으로 정의 했기 때문에 이 연산 자체가 Softmax 의 분모가 되고,

\[Softmax (z) = \frac{ \exp (z_i) }{ \sum_{i=1}^{k}{ \exp (z_i) } } \text{ , where } z \in \mathbb{R}^{k}\]

이 \(\mathbf{A} \mathbf{1}_L\) vector를 다시 \(L \times L\) 행렬로 만들기 위해 어떤 input vector 를 diagnoal 로 갖는 행렬로 바꿔주는 새로운 연산자 \(diag\) 를 정의한 뒤

\[\mathbf{D}=\operatorname{diag}\left(\mathbf{A} \mathbf{1}_L\right) = \begin{bmatrix} \sum_{i=1}^{L} A_{1i} & 0 & \cdots & 0 \\ 0 & \sum_{i=1}^{L} A_{2i} & 0 & 0 \\ \vdots & 0 & \ddots & \vdots \\ 0 & 0 & \cdots & \sum_{i=1}^{L} A_{Li} \\ \end{bmatrix}\]

이 \(D\) 행렬의 역행렬을 곱해주면

\[\mathbf{D^{-1}} = \begin{bmatrix} \frac{1}{\sum_{i=1}^{L} A_{1i}} & 0 & \cdots & 0 \\ 0 & \frac{1}{\sum_{i=1}^{L} A_{2i}} & 0 & 0 \\ \vdots & 0 & \ddots & \vdots \\ 0 & 0 & \cdots & \frac{1}{\sum_{i=1}^{L} A_{Li}} \\ \end{bmatrix}\]

사실상 \(Softmax (\frac{QK^T}{\sqrt{d}})\) 연산을 한 것이나 다름 없고,

\[\mathbf{D}^{-1} \mathbf{A}\]

마지막으로 \(V\) 를 곱해주면 Scaled Dot Proudct Self Attention 을 행렬 곱으로만 표현할 수 있게 됩니다.

\[\operatorname{Att}_{\leftrightarrow}(\mathbf{Q}, \mathbf{K}, \mathbf{V})=\mathbf{D}^{-1} \mathbf{A V}, \quad \mathbf{A}=\exp \left(\mathbf{Q} \mathbf{K}^{\top} / \sqrt{d}\right), \quad \mathbf{D}=\operatorname{diag}\left(\mathbf{A} \mathbf{1}_L\right)\]

이를 어떤 행렬의 Diagonal 을 포함한 Lower Triangle part 만을 취하는 연산자 \(tril(\cdot)\) 을 정의해서 아래처럼 쓰면

\[\operatorname{Att}_{\rightarrow}(\mathbf{Q}, \mathbf{K}, \mathbf{V})=\widetilde{\mathbf{D}}^{-1} \widetilde{\mathbf{A}} \mathbf{V}, \quad \widetilde{\mathbf{A}}=\operatorname{tril}(\mathbf{A}), \quad \widetilde{\mathbf{D}}=\operatorname{diag}\left(\tilde{\mathbf{A}} \mathbf{1}_L\right)\]

Transformer Decoder 모듈의 Unidirectional Self Attention 까지 표현할 수 있게 되었습니다.

왜 이렇게까지 해야하느냐? 이는 곧 논문을 살펴보면서 알게 될겁니다.

Standard Sparsification Techniques and Limitations

Self-Attention 의 한계점은 Sequence 의 길이 \(L\)이 늘어나면 늘어날수록 필요한 Memory 와 Latency 가 늘어난다는 데 있습니다. 이를 해결하기 위해서 다양한 Self Attention 의 변형체들이 매 해 수없이 등장해왔는데 그 중 대표적인 논문들로 아래와 같은 논문들이 있었습니다.

standard_sparsification_techniques Fig.

(1) They require efficient sparse-matrix multiplication operations, which are not available on all accelerators

(2) they usually do not provide rigorous theoretical guarantees for their representation power

(3) they are optimized primarily for Transformer models and generative pre-training

(4) they usually stack more attention layers to compensate for sparse representations, making them difficult to use with other pre-trained models, thus requiring retraining and significant energy consumption.

In addition to these shortcomings, sparse attention mechanisms are often still not sufficient to address the full range of problems to which regular attention methods are applied, such as Pointer Networks.

There are also some operations that cannot be sparsified, such as the commonly used softmax operation, which normalizes similarity scores in the attention mechanism and is used heavily in industry-scale recommender systems.

Fast Attention Via positive Orthogonal Random features (FAVOR+)

논문에서 제안하는 알고리즘의 이름은 Fast Attention Via positive Orthogonal Random features (FAVOR+) 입니다. 이제 아래의 순서대로 내용을 살펴볼 것입니다.

  • Fast Attention (FA)
  • Random Features (R)
  • Orthogonal (O)

FA part of FAVOR+ (Generalized Kernelizable Attention)

\[\mathrm{K}(\mathbf{x}, \mathbf{y})=\mathbb{E}\left[\phi(\mathbf{x})^{\top} \phi(\mathbf{y})\right]\] \[\widehat{\operatorname{Att}_{\leftrightarrow}}(\mathbf{Q}, \mathbf{K}, \mathbf{V})=\widehat{\mathbf{D}}^{-1}\left(\mathbf{Q}^{\prime}\left(\left(\mathbf{K}^{\prime}\right)^{\top} \mathbf{V}\right)\right), \quad \widehat{\mathbf{D}}=\operatorname{diag}\left(\mathbf{Q}^{\prime}\left(\left(\mathbf{K}^{\prime}\right)^{\top} \mathbf{1}_L\right)\right)\]

approximated_via_lower_rank_matrices Fig.

FA Fig.

  • Space Complexity : \(O(Lr + Ld + rd)\)
  • Time Complexity : \(O(Lrd)\)

Prefix-Sum Computation

autoregressive_fig Fig.

autoregressive Fig.

R+ part of FAVOR+ (How to and How not to Approximate Softmax-Kernels for Attention)

performer_paper_fig2 Fig.

O part of FAVOR+ (Orthogonal Random Features (ORFS))

Theoretical Results

Pseudocode for FAVOR+

pseudocode_performer Fig.

Experimental Results of FAVOR+

google_blog_result_fig1 Fig.

google_blog_result_fig2 Fig.

google_blog_result_fig3 Fig.

google_blog_result_fig4 Fig.

google_blog_result_fig5 Fig.

Long Range Arena

performer_LRA_score Fig.

Connection to Linformer

Analysis on Rank of Sefl-Attention

Diagonlaization and Singular Vector Decomposition (SVD)

\[M_{m \times m}=P_{m \times m} D_{m \times m} P_{m \times m}^{-1}\] \[\begin{aligned} M & =P D P^{-1} \\ M(P) & =P D P^{-1}(P) \\ M P & =P D \end{aligned}\] \[\begin{aligned} & P=\left[a_1, a_2, \ldots a_m\right] \\ & D=\left[\begin{array}{cccc} \lambda_1 & 0 & \ldots & \ldots \\ 0 & \lambda_2 & 0 & \ldots \\ \vdots & \ldots & \ddots & \ldots \\ 0 & \ldots & 0 & \lambda_m \end{array}\right] \\ & P D=\left[\lambda_1 a_1, \lambda_2 a_2, \ldots \lambda_m a_m\right] \\ & \end{aligned}\]

하지만 모든 행렬에 대해 Diagonalization 을 수행할 수 있는건 (diagonlizable) 아닌데요, 즉 행렬 \(M\) 의 크기가 정사각형 \(n \times n\) 이어야 하며 Invertible 한 경우에만 이 연산이 가능합니다. 이를 해결하기 위해 Singular Vecotr Decomposition (SVD) 를 사용할 수 있는데요,

\[\begin{aligned} M & =U S V^T \\ \end{aligned}\]

어떤 행렬 \(M\) 에 대해서 SVD 를 수행하면 아래와 같이 \(U,S,V\) 행렬을 얻을 수 있게 됩니다.

svd_4x4 svd_4x8 svd_8x4 Fig. Vissualization of Results of SVD. Source From here

SVD 의 결과물을 보면 몇 가지 특이한 점을 알 수 있는데요, 바로

  • U와 V의 모든 Column Vector 들은 크기 (norm) 가 1이다.
  • U와 V의 Column Vector 들 간의 내적 (inner product) 을 취한 n x n Matrix 는 Identity Matrix 이다.
  • U는 원본 행렬 M 의 \(MM^T\) 의 Eigenvector 를 가지고 있다.
  • V는 원본 행렬 M 의 \(M^TM\) 의 Eigenvector 를 가지고 있다.
  • S는 \(M^TM\) 와 \(MM^T\) 와 관련된 Eigenvalue 의 sqaure root 를 가지고 있다.

입니다.

svd_eigen_decomposition_equivalence Fig. Eigenvalue Decomposition 과 SVD 의 관계. Source From here

SVD 를 잘 사용하면 여러 장점이 있지만 그 중에서 Low-Rank Matrix Approximation (LRA) 이라는 것이 가능해 집니다. 이는 \(US\) 와 \(V\) Matrix 들을 더 작은 크기의 Matrix 로 바꾸는 것인데요, SVD는 만약 Matrix 가 \(4 \times 4\) 일 경우 \(U \in \mathbb{R}^{4 \times 4}\), \(S \in \mathbb{R}^{4 \times 4}\), \(V \in \mathbb{R}^{4 \times 4}\) 를 얻을 수 있었죠,

\[\begin{aligned} M & =U S V^T \\ \end{aligned}\]

하지만 우리는 이를 아래처럼 근사할 수 있습니다.

\[\begin{aligned} M & \approx U_k S_k V_k^T, \text{ where } k = 2 \\ & \approx \hat{M}_k \end{aligned}\]

즉 우리는 \(U \in \mathbb{R}^{4 \times 2}\), \(S \in \mathbb{R}^{2 \times 2}\), \(V \in \mathbb{R}^{2 \times 4}\) 를 얻을 수 얻게 되는 건데요, 단순히 Matrix 가 Full Rank 일 때, 즉 Column 이 4일 때 Rank 가 4 인 경우에 이런 근사를 해버리면 좋지 않겠지만 우리가 만약 \(8 \times 8\) 크기의 행렬이지만 Rank 가 4 밖에 안되는 행렬을 가지고 있다면

rank_matrices Fig. Matrix with Redundant Columns. Source From here

원래의 SVD를

\[\begin{aligned} M & =U S V^T \\ & =(U S) V^T \\ & =L R^T, \text { where } \\ L & =(U S), \text { and } \\ R & =V \end{aligned}\]

근사해버려도 된다는 겁니다.

\[\begin{aligned} M & =L R^T \\ & \approx L_k R_k^T \\ & \approx \hat{M} \end{aligned}\]

LRA_matrix Fig. \(m \times n\) 크기의 Matrix \(M\) 을 \(m \times k\), \(k \times n\) 로 분해할 수 있다. 만약 \(k\) 가 데이터의 rank \(r\) 과 같다면 (\(r=k\)) Matrix \(M\) 은 완전히 Decomposition 으로부터 복원될 수 있고 \(k\)가 \(r\)보다 작으면, 즉 \(k<r\) 이면 근사된 Matrix, 즉 Low-Rank Approximated Matrix \(\hat{M}\) 을 얻을 수 있다.. Source From here

이렇게 근사를 해서 얻은 \(L, R (U,S,V)\) 를 사용해도 Redundant Column 이 포함되어 있는 행렬 \(\tilde{X}\) 를 복구하는데는 문제가 없다고 합니다.

rank_matrices_svd Fig. . Source From here

Self-Attention is Low Rank

한편 우리가 Transformer 에서 근사하고 싶은 연산은 \(QK^T\) 연산 이었습니다.

\[P = \text{softmax} [ \frac{QW_i^Q (KW_i^K)^T}{\sqrt{d}} ] = \exp (A) \cdot D_A^{-1}\]

Performer 논문에서와 같이 \(D_A\) 는 \(n \times n\) 의 diagonal matrix 이며, Self Attention 을 위와 같이 나타낼 때 \(QK^T\) 행렬은 Sequence Length \(L\)에 대해 \(L^2\) 만큼의 연산량이 드는 것으로 \(A \in \mathbb{R}^{L \times L}\) 크기를 갖는데요, Linformer 에서는 이를 Eigenvalue 를 사용해 분석해봤습니다.

annotated_encoder_layer_2 Fig. Visualization of Self Attention of Transformer Encdoer 2nd Layer. Source from here

annotated_encoder_layer_6 Fig. Visualization of Self Attention of Transformer Encdoer 6th Layer. Source from here

(논문에서는 Masked Language Modeling (MLM) 방식으로 학습된 Transformer Encoder Only Model, RoBerta 를 분석했지만 제가 가지고 온 Figure 는 Encoder-Decoder 구조의 번역 데이터를 학습 한것이라는 점을 인지해 주시기 바랍니다.)

논문에서는 Transformer 의 층별로 존재하는 12개 Head 들에서 계산된 Context Mapping Matrix (\(QK^T\) 계산을 하고 Scailing 한 것을 Row-wise 로 Normalization 한 것), P를 SVD 연산을 합니다.

이를 10,000개 Setences 에 대해서 한 뒤 누적시켜서 평균을 내고 Normalization 를 한 Normalized Cumulative Singular Value Averaged over 10k Sentences 값을 구해서 결과를 봤는데요

Linformer_fig1 Fig. Spectrum Analysis of the Context Mapping Matrix. 오른쪽은 전체 512개 Token 중에서 128번째로 큰 Singular Value 에서의 Normalized Cumulative Singular Value 에 대한 Heatmap

그 결과 얻은 각 Layer, Head 별로 Long-tail Spectrum Distribution 을 얻었다고 합니다. 이는 대부분의 Attention Map, \(P\) 는 처음 몇개의 큰 Singular Value 들만 있으면 복원이 된다는 것이죠. 이러한 현상은 낮은 Layer 보다 높은 Layer 에서 두드려졌다고 하는데요, 즉 위로 갈수록 Information 이 한쪽으로 집중되는 현상이 컸다는 겁니다. \(P\) 의 Rank 가 갈수록 낮았던 거죠.

\[\tilde{P} = \exp (A) \cdot D_A^{-1} R^T R, \text{ where } R \in \mathbb{R}^{k \times n} \text{ with i.i.d entries from } N(0, \frac{1}{K})\] \[P \approx P_{\text {low }}=\sum_{i=1}^k \sigma_i u_i v_i^T=\underbrace{\left[u_1, \cdots, u_k\right]}_k \operatorname{diag}\left\{\sigma_1, \cdots, \sigma_k\right\}\left[\begin{array}{c} v_1 \\ \vdots \\ v_k \end{array}\right]\} k\]

Linformer Method

이러한 Encoder 의 Self Attention 은 Low Rank 라는 특성을 이용해 저자들은 Layer 와 Head 별로 \(E_i, F_i \in \mathbb{R}^{n \times k}\) 라는 두 개의 Linear Projection Matrix 를 도입해서 K, V 를 계산할 때 쓰기로 하는데요,

\[\begin{aligned} \overline{\text { head }_i} & =\operatorname{Attention}\left(Q W_i^Q, \color{red}{E_i} K W_i^K, \color{blue}{F_i} V W_i^V\right) \\ & =\underbrace{\operatorname{softmax}\left(\frac{Q W_i^Q\left( \color{red}{E_i} K W_i^K\right)^T}{\sqrt{d_k}}\right)}_{\bar{P}: n \times k} \cdot \underbrace{ \color{blue}{F_i} V W_i^V}_{k \times d}, \end{aligned}\]

즉 사실상 Performer 과 거의 유사하게 Attention Score Map 을 Approximation 을 해버립니다.

Linformer_fig2_og Fig.

Performer vs Linformer

Performer 와 Linformer 의 차이는 첫 번째로

performer_fig3 Fig.

performer_fig5 Fig.

performer_fig6 Fig.

Pytorch Implementation

Performer

Source from lucidrains’ implementation

# non-causal linear attention
def linear_attention(q, k, v):
    k_cumsum = k.sum(dim = -2)
    D_inv = 1. / torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q))
    context = torch.einsum('...nd,...ne->...de', k, v)
    out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
    return out
# efficient causal linear attention, created by EPFL
# TODO: rewrite EPFL's CUDA kernel to do mixed precision and remove half to float conversion and back
def causal_linear_attention(q, k, v, eps = 1e-6):
    from fast_transformers.causal_product import CausalDotProduct
    autocast_enabled = torch.is_autocast_enabled()
    is_half = isinstance(q, torch.cuda.HalfTensor)
    assert not is_half or APEX_AVAILABLE, 'half tensors can only be used if nvidia apex is available'
    cuda_context = null_context if not autocast_enabled else partial(autocast, enabled = False)

    causal_dot_product_fn = amp.float_function(CausalDotProduct.apply) if is_half else CausalDotProduct.apply

    k_cumsum = k.cumsum(dim=-2) + eps
    D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q))

    with cuda_context():
        if autocast_enabled:
            q, k, v = map(lambda t: t.float(), (q, k, v))

        out = causal_dot_product_fn(q, k, v)

    out = torch.einsum('...nd,...n->...nd', out, D_inv)
    return out
class FastAttention(nn.Module):
    def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), no_projection = False):
        super().__init__()
        nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))

        self.dim_heads = dim_heads
        self.nb_features = nb_features
        self.ortho_scaling = ortho_scaling

        self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling)
        projection_matrix = self.create_projection()
        self.register_buffer('projection_matrix', projection_matrix)

        self.generalized_attention = generalized_attention
        self.kernel_fn = kernel_fn

        # if this is turned on, no projection will be used
        # queries and keys will be softmax-ed as in the original efficient attention paper
        self.no_projection = no_projection

        self.causal = causal
        if causal:
            try:
                import fast_transformers.causal_product.causal_product_cuda
                self.causal_linear_fn = partial(causal_linear_attention)
            except ImportError:
                print('unable to import cuda code for auto-regressive Performer. will default to the memory inefficient non-cuda version')
                self.causal_linear_fn = causal_linear_attention_noncuda

    @torch.no_grad()
    def redraw_projection_matrix(self, device):
        projections = self.create_projection(device = device)
        self.projection_matrix.copy_(projections)
        del projections

    def forward(self, q, k, v):
        device = q.device

        if self.no_projection:
            q = q.softmax(dim = -1)
            k = torch.exp(k) if self.causal else k.softmax(dim = -2)

        elif self.generalized_attention:
            create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device)
            q, k = map(create_kernel, (q, k))

        else:
            create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device)
            q = create_kernel(q, is_query = True)
            k = create_kernel(k, is_query = False)

        attn_fn = linear_attention if not self.causal else self.causal_linear_fn
        out = attn_fn(q, k, v)
        return out

References