Summary
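Linear attention is a technique that modifies the traditional softmax attention mechanism to reduce the time complexity of processing a sequence from quadratic to linear in the sequence length, while keeping a fixed-size memory during decoding. This article explores the mathematical formulation of linear attention, its advantages over softmax attention, its main limitation, and how modern variants address it.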
Terminology
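- $q_t \in \mathbb{R}^{d}$: the query vector at time $t$
- $k_t \in \mathbb{R}^{d}$: the key vector at time $t$
- $v_t \in \mathbb{R}^{d}$: the value vector at time $t$
- $o_t \in \mathbb{R}^{d}$: the output vector at time $t$
- $Q \in \mathbb{R}^{L \times d}$: the matrix of all query vectors (row $t$ is $q_t^\top$)
- $K \in \mathbb{R}^{L \times d}$: the matrix of all key vectors
- $V \in \mathbb{R}^{L \times d}$: the matrix of all value vectors
- $O \in \mathbb{R}^{L \times d}$: the matrix of all output vectors
- $L$: the sequence length
- $d$: the dimension of the query, key, and value vectors
- $M \in \mathbb{R}^{L \times L}$: the causal mask matrix, where $M_{ij} = 0$ if $i \ge j$ and $M_{ij} = -\infty$ otherwise
- $S_t \in \mathbb{R}^{d \times d}$: the state matrix at the $t$-th token, which is the sum of outer products of key and value vectors up to time $t$: $S_t = \sum_{i=1}^{t} v_i k_i^\top$
- $G_t \in \mathbb{R}^{d \times d}$: the forgetting gate matrix at the $t$-th token, which controls the contribution of each token to the state matrix
- $\odot$: element-wise multiplication operator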
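Causal masks come in two equivalent forms: an additive one ($0$ on and below the diagonal, $-\infty$ above), which is added to the scores before a softmax, and a multiplicative lower-triangular one ($1$ / $0$), which is multiplied element-wise with the scores. The plain-Python sketch below builds both (the helper names are hypothetical) and checks that exponentiating the additive mask gives the multiplicative one, which is why adding the mask before the softmax zeroes out the weights on future tokens.

```python
import math

def additive_causal_mask(L):
    # 0.0 where token i may attend to token j (j <= i), -inf for future tokens
    return [[0.0 if j <= i else -math.inf for j in range(L)] for i in range(L)]

def multiplicative_causal_mask(L):
    # Lower-triangular: 1.0 where j <= i, 0.0 for future tokens
    return [[1.0 if j <= i else 0.0 for j in range(L)] for i in range(L)]

L = 4
additive = additive_causal_mask(L)
multiplicative = multiplicative_causal_mask(L)

# exp(0) = 1 and exp(-inf) = 0, so exponentiating the additive mask
# recovers the multiplicative one entry by entry.
assert all(
    math.exp(additive[i][j]) == multiplicative[i][j]
    for i in range(L)
    for j in range(L)
)

for row in multiplicative:
    print(row)
```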
Recap of Softmax Attention
Let’s first recall the formulation of softmax attention.
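Parallel form:
$$O = \mathrm{softmax}\left(QK^\top + M\right) V$$
Recurrent form:
$$o_t = \sum_{i=1}^{t} \frac{\exp\left(q_t^\top k_i\right)}{\sum_{j=1}^{t} \exp\left(q_t^\top k_j\right)}\, v_i$$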
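The parallel and recurrent forms of softmax attention compute the same outputs. A minimal plain-Python sketch, with hypothetical helper names and random inputs, checks this numerically. Note that the recurrent version has to revisit every past key and value for each new token, so its per-token cost and memory grow with $t$.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(xs):
    # Subtract the row max for numerical stability; exp(-inf) contributes 0
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_attention_parallel(Q, K, V):
    # O = softmax(Q K^T + M) V with an additive causal mask M
    L, d_v = len(Q), len(V[0])
    scores = [
        [dot(Q[i], K[j]) + (0.0 if j <= i else -math.inf) for j in range(L)]
        for i in range(L)
    ]
    P = [softmax(row) for row in scores]  # L x L attention matrix
    return [[sum(P[i][j] * V[j][c] for j in range(L)) for c in range(d_v)] for i in range(L)]

def softmax_attention_recurrent(Q, K, V):
    # o_t = sum_{i<=t} exp(q_t.k_i) v_i / sum_{j<=t} exp(q_t.k_j)
    d_v = len(V[0])
    outputs = []
    for t in range(len(Q)):
        weights = [math.exp(dot(Q[t], K[i])) for i in range(t + 1)]  # all past keys
        Z = sum(weights)
        outputs.append([sum(w * V[i][c] for i, w in enumerate(weights)) / Z for c in range(d_v)])
    return outputs

random.seed(0)
L, d = 5, 4
Q = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(L)]
K = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(L)]
V = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(L)]

parallel = softmax_attention_parallel(Q, K, V)
recurrent = softmax_attention_recurrent(Q, K, V)
max_diff = max(abs(a - b) for ra, rb in zip(parallel, recurrent) for a, b in zip(ra, rb))
print(f"max |parallel - recurrent| = {max_diff:.2e}")
```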
The time complexity of softmax attention for prefill (processing all $L$ tokens of a prompt at once) is $O(L^2 d)$.
To see why, let's go through the parallel form step by step.
- $QK^\top$: the matmul takes $O(L^2 d)$ time.
- $QK^\top + M$: adding the mask takes $O(L^2)$ time.
- $\mathrm{softmax}(\cdot)$: the row-wise softmax takes $O(L^2)$ time.
- $\mathrm{softmax}(QK^\top + M)\, V$: the matmul takes $O(L^2 d)$ time.
The two matmuls dominate, so softmax attention costs $O(L^2 d)$ time, quadratic in the sequence length.
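The step-by-step count can be turned into a rough cost model. This sketch counts scalar operations per step with constant factors dropped (an assumption; real kernels and hardware behave differently) and shows that doubling $L$ quadruples the cost. The function name is hypothetical.

```python
def softmax_attention_cost(L, d):
    # Rough scalar-operation counts for the parallel form (constant factors dropped)
    qk_ops = L * L * d       # Q K^T: L x L dot products of length d
    mask_ops = L * L         # add M element-wise
    softmax_ops = L * L      # row-wise softmax over the L x L scores
    pv_ops = L * L * d       # softmax(...) V: (L x L) times (L x d)
    return qk_ops + mask_ops + softmax_ops + pv_ops

d = 128
for L in (1024, 2048, 4096):
    cost = softmax_attention_cost(L, d)
    growth = cost / softmax_attention_cost(L // 2, d)
    print(f"L={L:5d}  ops={cost:.3e}  growth vs L/2 = {growth:.1f}x")
```

Because the total is $2L^2(d+1)$ in this model, the growth factor is exactly 4 each time $L$ doubles.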
Formulation of Linear Attention
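Linear attention modifies softmax attention by removing the softmax function, so the score between $q_t$ and $k_i$ is simply the dot product $q_t^\top k_i$. (Katharopoulos et al. more generally apply a feature map $\phi$ to queries and keys and normalize the result; the unnormalized version with the identity feature map is used here.) The formulation of linear attention is as follows: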
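Parallel form:
$$O = \left(QK^\top \odot M\right) V$$
where $M$ is the lower-triangular causal mask with $M_{ij} = 1$ if $i \ge j$ and $M_{ij} = 0$ otherwise (here the mask is applied multiplicatively, unlike the additive mask of softmax attention).
Recurrent form:
$$o_t = \sum_{i=1}^{t} \left(q_t^\top k_i\right) v_i = \Big(\sum_{i=1}^{t} v_i k_i^\top\Big) q_t$$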
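The regrouping $\sum_i (q_t^\top k_i) v_i = (\sum_i v_i k_i^\top) q_t$ is just associativity of matrix products. A plain-Python sketch (helper names are hypothetical) checks that the masked parallel form and the regrouped recurrent form agree on random inputs. The regrouped version below rebuilds the sum from scratch at every step to mirror the formula, not for speed.

```python
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def linear_attention_parallel(Q, K, V):
    # O = (Q K^T ⊙ M) V with the multiplicative lower-triangular mask M
    L, d_v = len(Q), len(V[0])
    A = [[dot(Q[i], K[j]) * (1.0 if j <= i else 0.0) for j in range(L)] for i in range(L)]
    return [[sum(A[i][j] * V[j][c] for j in range(L)) for c in range(d_v)] for i in range(L)]

def linear_attention_regrouped(Q, K, V):
    # o_t = (sum_{i<=t} v_i k_i^T) q_t, with the d_v x d_k sum rebuilt for every t
    d_k, d_v = len(K[0]), len(V[0])
    outputs = []
    for t in range(len(Q)):
        S = [[sum(V[i][r] * K[i][c] for i in range(t + 1)) for c in range(d_k)]
             for r in range(d_v)]
        outputs.append([dot(S[r], Q[t]) for r in range(d_v)])
    return outputs

random.seed(0)
L, d = 5, 4
Q = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(L)]
K = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(L)]
V = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(L)]

parallel = linear_attention_parallel(Q, K, V)
regrouped = linear_attention_regrouped(Q, K, V)
max_diff = max(abs(a - b) for ra, rb in zip(parallel, regrouped) for a, b in zip(ra, rb))
print(f"max |parallel - regrouped| = {max_diff:.2e}")
```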
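The time complexity of linear attention for prefill is $O(L d^2)$.
To see why, look at the right-hand side of the recurrent form.
- $v_i k_i^\top$: each outer product takes $O(d^2)$ time, so all $L$ of them take $O(L d^2)$.
- $\sum_{i=1}^{t}$: the prefix sums can be accumulated incrementally, because the sum up to $t$ equals the sum up to $t-1$ plus one new term; all $L$ prefix sums therefore take $O(L d^2)$ time.
- $(\cdot)\, q_t$: multiplying each $d \times d$ prefix sum by $q_t$ takes $O(d^2)$ time, or $O(L d^2)$ in total.
Every step is linear in $L$, so the overall cost is $O(L d^2)$. Evaluating the left-hand side $\sum_{i}(q_t^\top k_i)\, v_i$ directly would instead cost $O(L^2 d)$, the same as softmax attention; regrouping the product is what makes linear attention linear.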
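The two complexities can be compared with a rough model that keeps only the dominant terms (an assumption that ignores constant factors, memory traffic, and hardware efficiency). Here softmax attention is charged for its two $L \times L \times d$ matmuls and linear attention for one outer product, one running-sum update, and one readout per token. The function names are hypothetical.

```python
def softmax_dominant_cost(L, d):
    # Two L x L x d matmuls: Q K^T and softmax(...) V
    return 2 * L * L * d

def linear_dominant_cost(L, d):
    outer = L * d * d      # v_i k_i^T for every token
    prefix = L * d * d     # running sum: one d x d addition per token
    readout = L * d * d    # (prefix sum) q_t for every token
    return outer + prefix + readout

d = 128
for L in (128, 1024, 8192, 65536):
    ratio = softmax_dominant_cost(L, d) / linear_dominant_cost(L, d)
    print(f"L={L:6d}  softmax/linear cost ratio = {ratio:.2f}")
```

In this model the ratio is $2L / (3d)$, so linear attention only pays off once $L$ exceeds roughly $1.5\,d$; for short sequences the quadratic form can be cheaper.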
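Expanding Linear Attention with a State Matrix
We can make the running sum explicit by introducing a state matrix $S_t$ that accumulates the outer products of key and value vectors up to time $t$. The formulation becomes:
$$o_t = S_t q_t$$
where $S_t = \sum_{i=1}^{t} v_i k_i^\top \in \mathbb{R}^{d \times d}$.
Each time we receive a new token, we update the state matrix by adding the outer product of the new key and value vectors:
$$S_t = S_{t-1} + v_t k_t^\top, \qquad S_0 = 0$$
During decoding, each new token therefore costs $O(d^2)$ time, and the model only stores the $d \times d$ matrix $S_t$, no matter how many tokens came before. Softmax attention, by contrast, must keep all past keys and values and spends $O(t d)$ time on the $t$-th token.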
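The state update can be sketched as a small streaming object. The class `LinearAttentionState` and its `step` method are hypothetical names for illustration: each call performs the rank-1 update $S_t = S_{t-1} + v_t k_t^\top$ and returns $o_t = S_t q_t$. The sketch checks the streamed outputs against the quadratic parallel form and shows that the state stays $d \times d$ regardless of $L$.

```python
import random

class LinearAttentionState:
    """Hypothetical helper: holds S_t = sum_i v_i k_i^T as a fixed d_v x d_k matrix."""

    def __init__(self, d_k, d_v):
        self.S = [[0.0] * d_k for _ in range(d_v)]  # S_0 = 0

    def step(self, q, k, v):
        # S_t = S_{t-1} + v_t k_t^T: a rank-1 update costing O(d_k * d_v)
        for r, row in enumerate(self.S):
            for c in range(len(row)):
                row[c] += v[r] * k[c]
        # o_t = S_t q_t: one matrix-vector product, also O(d_k * d_v)
        return [sum(s * x for s, x in zip(row, q)) for row in self.S]

def linear_attention_parallel(Q, K, V):
    # Reference: O = (Q K^T ⊙ M) V, quadratic in L
    L, d_v = len(Q), len(V[0])
    A = [[sum(a * b for a, b in zip(Q[i], K[j])) if j <= i else 0.0 for j in range(L)]
         for i in range(L)]
    return [[sum(A[i][j] * V[j][c] for j in range(L)) for c in range(d_v)] for i in range(L)]

random.seed(1)
L, d = 6, 3
Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(L)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(L)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(L)]

state = LinearAttentionState(d_k=d, d_v=d)
streamed = [state.step(q, k, v) for q, k, v in zip(Q, K, V)]
reference = linear_attention_parallel(Q, K, V)
max_diff = max(abs(a - b) for ra, rb in zip(streamed, reference) for a, b in zip(ra, rb))
print(f"max |streamed - parallel| = {max_diff:.2e}")
print("state shape:", len(state.S), "x", len(state.S[0]))  # d x d, independent of L
```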
Limitation of Linear Attention
While linear attention reduces the computational complexity, it is less expressive. The entire context is stored in a single fixed-size state matrix $S_t \in \mathbb{R}^{d \times d}$, no matter how long the sequence is. Once a token's key and value have been summed into $S_t$, they can no longer be recovered exactly, so the model can only capture interactions between keys and values in a limited way.
The state matrix $S_t$ can be viewed as a reconstructing function (an associative memory) that maps a key $k_j$ back to its value $v_j$.
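To see why, try to reconstruct $v_j$ from $k_j$ for some $j \le t$, assuming all $k_i$'s are normalized to unit length ($k_j^\top k_j = 1$):
$$S_t k_j = \sum_{i=1}^{t} v_i \left(k_i^\top k_j\right) = v_j \left(k_j^\top k_j\right) + \sum_{i=1,\, i \ne j}^{t} v_i \left(k_i^\top k_j\right) = v_j + \underbrace{\sum_{i=1,\, i \ne j}^{t} \left(k_i^\top k_j\right) v_i}_{\text{noise}}$$
As the equation shows, multiplying $S_t$ by $k_j$ returns $v_j$ plus a noise term contributed by the other tokens. The noise vanishes only when $k_j$ is orthogonal to every other key, and since at most $d$ vectors in $\mathbb{R}^d$ can be mutually orthogonal, keys inevitably interfere once $t > d$. This noise makes it difficult for the model to retrieve the right values accurately, especially when the sequence is long.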
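The noise term can be observed directly. This sketch (all helper names are hypothetical) stores $t$ key/value pairs in $S = \sum_i v_i k_i^\top$ and measures the mean Euclidean distance between $S k_j$ and $v_j$. With $d$ mutually orthogonal unit keys, recall is exact; with random unit keys, the error grows as more tokens share the same $d \times d$ state.

```python
import math
import random

def build_state(keys, values):
    # S = sum_i v_i k_i^T, stored as a d_v x d_k list of lists
    d_k, d_v = len(keys[0]), len(values[0])
    S = [[0.0] * d_k for _ in range(d_v)]
    for k, v in zip(keys, values):
        for r in range(d_v):
            for c in range(d_k):
                S[r][c] += v[r] * k[c]
    return S

def retrieve(S, k):
    # S k: ideally v_j, plus the noise term from every other stored token
    return [sum(s * x for s, x in zip(row, k)) for row in S]

def mean_retrieval_error(keys, values):
    S = build_state(keys, values)
    errors = [math.dist(retrieve(S, k), v) for k, v in zip(keys, values)]
    return sum(errors) / len(errors)

def random_unit_vector(d):
    x = [random.gauss(0, 1) for _ in range(d)]
    norm = math.sqrt(sum(a * a for a in x))
    return [a / norm for a in x]

random.seed(0)
d = 8

# d mutually orthogonal unit keys (the standard basis): every value is recovered exactly
basis = [[1.0 if i == j else 0.0 for j in range(d)] for i in range(d)]
values = [[random.gauss(0, 1) for _ in range(d)] for _ in range(d)]
orthogonal_error = mean_retrieval_error(basis, values)
print("orthogonal keys, t = d:", orthogonal_error)

# Random unit keys: the noise term grows as more tokens share the same state
errors = {}
for t in (4, 16, 64):
    keys = [random_unit_vector(d) for _ in range(t)]
    vals = [[random.gauss(0, 1) for _ in range(d)] for _ in range(t)]
    errors[t] = mean_retrieval_error(keys, vals)
    print(f"random keys, t = {t:2d}: mean error {errors[t]:.3f}")
```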
How Modern Linear Attention Is Evolving
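To address this limitation, modern linear attention mechanisms such as Mamba2 and GLA use a decay (forgetting) mechanism to reduce the noise term. By applying a decay factor $G_t$ to the previous state $S_{t-1}$ before adding the new token, the model gives more weight to recent tokens and less weight to older tokens, which helps to mitigate the noise from distant tokens:
$$S_t = G_t \odot S_{t-1} + v_t k_t^\top, \qquad o_t = S_t q_t$$
The form of $G_t$ differs per mechanism. In Mamba2 it is a data-dependent scalar, so every entry of the state decays by the same factor; in GLA it is a data-dependent vector that decays each key dimension separately.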
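With a scalar decay $\gamma_t$, unrolling the recurrence gives $S_t = \sum_{i=1}^{t} \big(\prod_{j=i+1}^{t} \gamma_j\big)\, v_i k_i^\top$, so each token's contribution shrinks geometrically with its age. The sketch below implements the gated recurrence with a decay vector $\alpha_t$ over key dimensions (equal entries give the scalar case) and checks the scalar case against the unrolled sum. The function names and the decay values are hypothetical and arbitrary; how Mamba2 and GLA actually compute their gates from the input is not reproduced here.

```python
import random

def gated_linear_attention(Q, K, V, alphas):
    """Recurrence S_t = G_t ⊙ S_{t-1} + v_t k_t^T, o_t = S_t q_t.

    alphas[t] is a length-d_k list of decay factors in (0, 1); column c of S
    (key dimension c) is scaled by alphas[t][c]. Equal entries give a scalar
    decay; varying entries give a per-key-dimension decay.
    """
    d_k, d_v = len(K[0]), len(V[0])
    S = [[0.0] * d_k for _ in range(d_v)]
    outputs = []
    for q, k, v, alpha in zip(Q, K, V, alphas):
        for r in range(d_v):
            for c in range(d_k):
                S[r][c] = alpha[c] * S[r][c] + v[r] * k[c]
        outputs.append([sum(s * x for s, x in zip(row, q)) for row in S])
    return outputs

def scalar_decay_unrolled(Q, K, V, gammas):
    # o_t = sum_{i<=t} (prod_{j=i+1}^{t} gamma_j) (q_t^T k_i) v_i
    L, d_v = len(Q), len(V[0])
    outputs = []
    for t in range(L):
        out = [0.0] * d_v
        for i in range(t + 1):
            weight = 1.0
            for j in range(i + 1, t + 1):
                weight *= gammas[j]  # older tokens pass through more decay factors
            score = weight * sum(a * b for a, b in zip(Q[t], K[i]))
            for r in range(d_v):
                out[r] += score * V[i][r]
        outputs.append(out)
    return outputs

random.seed(2)
L, d = 6, 3
Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(L)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(L)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(L)]
gammas = [random.uniform(0.5, 0.99) for _ in range(L)]  # arbitrary illustrative decays

scalar = gated_linear_attention(Q, K, V, [[g] * d for g in gammas])
unrolled = scalar_decay_unrolled(Q, K, V, gammas)
max_diff = max(abs(a - b) for ra, rb in zip(scalar, unrolled) for a, b in zip(ra, rb))
print(f"max |recurrent - unrolled| = {max_diff:.2e}")

# Per-key-dimension decay (the GLA-style shape); values are arbitrary, not a learned gate
per_dim = gated_linear_attention(
    Q, K, V, [[random.uniform(0.5, 0.99) for _ in range(d)] for _ in range(L)]
)
```

Setting every decay to 1.0 recovers the plain linear attention update $S_t = S_{t-1} + v_t k_t^\top$.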
References
- Vaswani et al. Attention Is All You Need. NeurIPS 2017.
- Katharopoulos et al. Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. ICML 2020.
- Schlag, Irie, Schmidhuber. Linear Transformers Are Secretly Fast Weight Programmers. ICML 2021.
- Yang et al. Gated Linear Attention Transformers with Hardware-Efficient Training. ICML 2024.
- Gu & Dao. Mamba: Linear-Time Sequence Modeling with Selective State Spaces. 2023.
- Dao & Gu. Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality. ICML 2024.
- Songlin Yang. DeltaNet Explained — Part I (covers linear attention foundations, state matrix interpretation)