Ball's Blog

Mathematical Formulation of Linear Attention


Summary

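Linear attention is a technique that modifies the traditional softmax attention mechanism. It reduces the computational cost of prefill from quadratic to linear in the sequence length. It also lets decoding run with a fixed-size state instead of a cache that grows with every token. This article explores the mathematical formulation of linear attention, its advantages over softmax attention, and its limitations.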

Terminology

  • $q_t \in \mathbb{R}^{1 \times d}$: the query vector at time $t$
  • $k_t \in \mathbb{R}^{1 \times d}$: the key vector at time $t$
  • $v_t \in \mathbb{R}^{1 \times d}$: the value vector at time $t$
  • $o_t \in \mathbb{R}^{1 \times d}$: the output vector at time $t$
  • $Q \in \mathbb{R}^{L \times d}$: the matrix of all query vectors
  • $K \in \mathbb{R}^{L \times d}$: the matrix of all key vectors
  • $V \in \mathbb{R}^{L \times d}$: the matrix of all value vectors
  • $O \in \mathbb{R}^{L \times d}$: the matrix of all output vectors
  • $L$: the sequence length
  • $d$: the dimension of the query, key, and value vectors
  • $M \in \mathbb{R}^{L \times L}$: the causal mask matrix, where $M_{ij} = 0$ if $i \geq j$ and $-\infty$ otherwise
  • $S_t \in \mathbb{R}^{d \times d}$: the state matrix at the $t$-th token, which is the sum of outer products of key and value vectors up to time $t$: $S_t = \sum_{i=1}^{t} k_i^T v_i$
  • $G_t \in \mathbb{R}^{d \times d}$: the forgetting gate matrix at the $t$-th token, which controls how much of the previous state $S_{t-1}$ is kept
  • $\odot$: element-wise multiplication operator

Recap of Softmax Attention

Let’s first recall the formulation of softmax attention.

Parallel form:

$$O = \text{softmax}\left(QK^T + M\right)V$$

Recurrent form:

$$o_t = \sum_{j=1}^t \frac{\exp\left(q_tk_j^T\right)}{\sum_{l=1}^t \exp\left(q_tk_l^T\right)} v_j$$
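The sketch below makes the two forms concrete. It computes softmax attention both ways and checks that they agree. It uses plain Python lists of rows so it has no dependencies. The function names are hypothetical, and, like the formulas above, it omits the usual $1/\sqrt{d}$ scaling. Note that the recurrent version has to keep every past key and value around, so its memory grows with $t$.

```python
import math


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def softmax_attention_parallel(Q, K, V):
    """O = softmax(QK^T + M) V, with M = 0 on/below the diagonal and -inf above it."""
    L = len(Q)
    scores = matmul(Q, transpose(K))
    weights = []
    for i in range(L):
        # Apply the causal mask: query i may only attend to keys j <= i.
        row = [scores[i][j] if j <= i else -math.inf for j in range(L)]
        m = max(row)                           # subtract the row max for numerical stability
        exps = [math.exp(x - m) for x in row]  # exp(-inf) == 0.0, so masked entries vanish
        z = sum(exps)
        weights.append([e / z for e in exps])
    return matmul(weights, V)


def softmax_attention_recurrent(Q, K, V):
    """o_t = sum_{j<=t} softmax_j(q_t k_j^T) v_j, computed one token at a time."""
    d_v = len(V[0])
    outputs = []
    for t in range(len(Q)):
        # Scores of the current query against every key seen so far.
        logits = [sum(q * k for q, k in zip(Q[t], K[j])) for j in range(t + 1)]
        m = max(logits)
        w = [math.exp(x - m) for x in logits]
        z = sum(w)
        outputs.append([sum(w[j] * V[j][c] for j in range(t + 1)) / z for c in range(d_v)])
    return outputs


Q = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
K = [[0.2, 0.1], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
parallel = softmax_attention_parallel(Q, K, V)
recurrent = softmax_attention_recurrent(Q, K, V)
# The first token can only attend to itself, so its output is exactly v_1.
assert parallel[0] == [1.0, 2.0]
assert all(abs(a - b) < 1e-12 for pr, rr in zip(parallel, recurrent) for a, b in zip(pr, rr))
```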

The time complexity of softmax attention for prefill is $\text{O}\left(L^2d\right)$.

You may wonder why. Let's think step by step through the parallel form equation.

  1. $QK^T$: the matmul takes $\text{O}\left(L^2d\right)$ time.
  2. $QK^T + M$: adding the mask takes $\text{O}\left(L^2\right)$ time.
  3. $\text{softmax}\left(QK^T + M\right)$: the softmax takes $\text{O}\left(L^2\right)$ time.
  4. $\text{softmax}\left(QK^T + M\right)V$: the matmul takes $\text{O}\left(L^2d\right)$ time.

Therefore, the overall time complexity is dominated by the $\text{O}\left(L^2d\right)$ terms, so softmax attention costs $\text{O}\left(L^2d\right)$ for prefill.

Formulation of Linear Attention

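Linear attention modifies softmax attention by removing the softmax function, including its normalization. As a result, the weight of token $j$ for query $t$ is simply the dot product $q_tk_j^T$. The formulation of linear attention is as follows: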

Parallel form:

$$O = \left(QK^T \odot \text{Mask}\right)V$$

where $\text{Mask} \in \mathbb{R}^{L \times L}$ is the lower-triangular causal mask with $\text{Mask}_{ij} = 1$ if $i \geq j$ and $0$ otherwise.
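Here is a minimal pure-Python sketch of the parallel form. The helper names are hypothetical, and matrices are stored as lists of rows. In the toy input, the first query has a nonzero score against the second key. The mask zeroes that score out, so the first output only sees $v_1$.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def linear_attention_parallel(Q, K, V):
    """O = (QK^T ⊙ Mask) V, where Mask_ij = 1 if i >= j else 0. No softmax, no normalization."""
    L = len(Q)
    scores = matmul(Q, transpose(K))  # L x L raw dot products q_i k_j^T
    mask = [[1.0 if i >= j else 0.0 for j in range(L)] for i in range(L)]
    masked = [[scores[i][j] * mask[i][j] for j in range(L)] for i in range(L)]  # element-wise
    return matmul(masked, V)


Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
print(linear_attention_parallel(Q, K, V))  # [[1.0, 2.0], [3.0, 4.0]]
```

Without the mask, the first row would be $1 \cdot v_1 + 1 \cdot v_2 = [4.0, 6.0]$, which leaks information from the future token.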

Recurrent form:

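$$o_t = \sum_{j=1}^t \left(q_tk_j^T\right)v_j = \sum_{j=1}^t q_t\left(k_j^Tv_j\right) = q_t \sum_{j=1}^t k_j^Tv_j$$

The second equality is just associativity of matrix multiplication: $q_tk_j^T$ is a scalar, while $k_j^Tv_j$ is a $d \times d$ outer product.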
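The key step is the regrouping $\left(q_tk_j^T\right)v_j = q_t\left(k_j^Tv_j\right)$. The left side scales $v_j$ by a scalar, and the right side multiplies $q_t$ by a $d \times d$ matrix. Here is a tiny pure-Python check with made-up vectors; the helper names are hypothetical.

```python
def outer(k, v):
    """k^T v for row vectors k and v: a len(k) x len(v) matrix."""
    return [[ki * vj for vj in v] for ki in k]


def vec_mat(q, m):
    """Row vector times matrix: q m."""
    return [sum(q[i] * m[i][j] for i in range(len(q))) for j in range(len(m[0]))]


q = [1.0, 2.0]
k = [3.0, -1.0]
v = [0.5, 4.0]

score = sum(a * b for a, b in zip(q, k))  # q k^T is a scalar (here 1.0)
lhs = [score * x for x in v]              # (q k^T) v
rhs = vec_mat(q, outer(k, v))             # q (k^T v)
print(lhs, rhs)  # [0.5, 4.0] [0.5, 4.0]
```

Once the sum over $j$ is moved inside, the $d \times d$ matrix $\sum_j k_j^Tv_j$ no longer depends on $q_t$. That is what lets it be shared across queries as a running sum.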

The time complexity of linear attention for prefill is $\text{O}\left(Ld^2\right)$.

You may wonder why. Let's think step by step through the recurrent form equation.

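  1. $k_j^Tv_j$: the outer product takes $\text{O}\left(d^2\right)$ time.
  2. $\sum_{j=1}^L k_j^Tv_j$: summing over $L$ tokens takes $\text{O}\left(Ld^2\right)$ time. Because this is a running sum, the partial sum for every $t$ is produced along the way rather than recomputed.
  3. $q_t \sum_{j=1}^t k_j^Tv_j$: multiplying the row vector $q_t$ by the $d \times d$ partial sum takes $\text{O}\left(d^2\right)$ time per token, or $\text{O}\left(Ld^2\right)$ for all $L$ tokens.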
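To compare the two costs, here is a back-of-the-envelope sketch that counts only the leading-order multiply-adds. For softmax attention these are the two $L \times L \times d$ matmuls. For linear attention they are one outer product and one vector-matrix product per token. Lower-order terms (softmax, mask, state additions) are ignored. The function names and the choice $d = 128$ are illustrative assumptions.

```python
def softmax_attention_cost(L, d):
    """Multiply-adds for QK^T (L*L*d) plus the product with V (L*L*d)."""
    return 2 * L * L * d


def linear_attention_cost(L, d):
    """Multiply-adds for L outer products k_j^T v_j plus L products q_t S_t (d*d each)."""
    return 2 * L * d * d


d = 128
for L in (64, 128, 4096):
    ratio = softmax_attention_cost(L, d) / linear_attention_cost(L, d)
    print(f"L={L:5d}  softmax/linear cost ratio = {ratio}")
```

The ratio is exactly $L/d$, so linear attention only wins once the sequence is longer than the head dimension. With $d = 128$ and $L = 4096$, it needs $32\times$ fewer multiply-adds.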

Expanding Linear Attention with a State Matrix

We can further optimize linear attention by introducing a state matrix $S_t$ that accumulates the outer products of key and value vectors up to time $t$. The formulation becomes:

$$\begin{aligned} o_t &= \sum_{j=1}^t \left(q_tk_j^T\right)v_j \\ &= \sum_{j=1}^t q_t\left(k_j^Tv_j\right) \\ &= q_t \sum_{j=1}^t k_j^Tv_j \\ &= q_t S_t \end{aligned}$$

where $S_t = \sum_{j=1}^t k_j^T v_j$.

Each time we receive a new token, we obtain $S_t$ from $S_{t-1}$ by adding the outer product of the new key and value vectors, starting from $S_0 = 0$:

$$S_t = S_{t-1} + k_t^T v_t$$
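Here is a minimal pure-Python sketch of the recurrent form with an explicit state; the function name is hypothetical. Unlike the softmax recurrence, it keeps only the $d \times d$ matrix $S$ between steps, no matter how many tokens it has seen. For the toy inputs below, you can check the outputs by hand against the parallel form $\left(QK^T \odot \text{Mask}\right)V$.

```python
def linear_attention_recurrent(Q, K, V):
    """o_t = q_t S_t with S_t = S_{t-1} + k_t^T v_t and S_0 = 0."""
    d_k, d_v = len(K[0]), len(V[0])
    S = [[0.0] * d_v for _ in range(d_k)]  # fixed-size state, independent of sequence length
    outputs = []
    for q, k, v in zip(Q, K, V):
        # State update: add the outer product k_t^T v_t.
        for i in range(d_k):
            for j in range(d_v):
                S[i][j] += k[i] * v[j]
        # Readout: o_t = q_t S_t (row vector times matrix).
        outputs.append([sum(q[i] * S[i][j] for i in range(d_k)) for j in range(d_v)])
    return outputs


Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
print(linear_attention_recurrent(Q, K, V))  # [[1.0, 2.0], [3.0, 4.0]]
```

Each step costs $\text{O}(d^2)$ time, and the state needs $\text{O}(d^2)$ memory. This is where the fixed memory usage during decoding comes from.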

Limitation of Linear Attention

While linear attention reduces the computational complexity, it is less expressive. The entire context is stored in a single state matrix $S_t \in \mathbb{R}^{d \times d}$, whose size does not grow with the sequence length. Information from every past token must be compressed into these $d^2$ numbers, so the model can only capture interactions between keys and values in a limited way.

The state matrix $S_t$ can be viewed as a function that reconstructs the value vector $v_j$ from the key vector $k_j$.

Let's see why by trying to reconstruct $v_l$ from $S_t$. Assume all $k_i$'s are normalized to unit length, so that $k_lk_l^T = 1$.

$$\begin{aligned} k_l S_t &= \sum_{j=1}^t \left(k_l k_j^T\right) v_j \\ &= v_l + \sum_{j \ne l} \left(k_l k_j^T\right) v_j \end{aligned}$$

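As the equation shows, we can reconstruct $v_l$ by multiplying the row vector $k_l$ with $S_t$, exactly as a query reads the state in $o_t = q_t S_t$. However, there is a noise term $\sum_{j \ne l} \left(k_l k_j^T\right) v_j$ that comes from other tokens. It vanishes only if $k_l$ is orthogonal to every other key. Since at most $d$ nonzero vectors in $\mathbb{R}^{1 \times d}$ can be mutually orthogonal, some interference is unavoidable once $t > d$. This noise makes it difficult for the model to accurately retrieve values, especially when the sequence is long.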
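The pure-Python sketch below illustrates the noise term; its helper names are hypothetical. It builds $S_t$ from two unit-length keys and reads back the first value with $k_1 S_t$. With orthogonal keys, the noise term is zero and the read is exact. With keys at 45 degrees, $k_1k_2^T = 1/\sqrt{2}$, so a scaled copy of $v_2$ leaks into the result.

```python
import math


def normalize(x):
    """Scale x to unit length, as assumed in the derivation."""
    norm = math.sqrt(sum(c * c for c in x))
    return [c / norm for c in x]


def build_state(keys, values):
    """S = sum_j k_j^T v_j, a d_k x d_v matrix."""
    d_k, d_v = len(keys[0]), len(values[0])
    S = [[0.0] * d_v for _ in range(d_k)]
    for k, v in zip(keys, values):
        for i in range(d_k):
            for j in range(d_v):
                S[i][j] += k[i] * v[j]
    return S


def read(k, S):
    """k S, which equals v_l + sum_{j != l} (k k_j^T) v_j when k = k_l has unit length."""
    return [sum(k[i] * S[i][j] for i in range(len(k))) for j in range(len(S[0]))]


values = [[1.0, 2.0], [3.0, 4.0]]

# Orthogonal unit keys: the cross term k_1 k_2^T is 0, so retrieval is exact.
orth_keys = [[1.0, 0.0], [0.0, 1.0]]
exact = read(orth_keys[0], build_state(orth_keys, values))   # [1.0, 2.0]

# Non-orthogonal unit keys: k_1 k_2^T = 1/sqrt(2), so part of v_2 leaks in.
skew_keys = [normalize([1.0, 0.0]), normalize([1.0, 1.0])]
noisy = read(skew_keys[0], build_state(skew_keys, values))   # about [3.12, 4.83]
```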

How Modern Linear Attention Is Evolving

To address this limitation, modern linear attention mechanisms such as Mamba2 and GLA use a decay (forgetting) mechanism to reduce the noise term. They apply a decay factor to the previous state $S_{t-1}$ before adding the new token. This lets the model give more weight to recent tokens and less weight to older ones, which helps mitigate the noise from distant tokens.

$$S_t = G_t \odot S_{t-1} + k_t^T v_t$$

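$G_t$ differs per mechanism. For example, Mamba2 uses a single data-dependent decay scalar per time step, so every entry of $G_t$ has the same value. GLA instead uses a data-dependent gate vector over the key dimension, so each row of $G_t$ shares one value. Setting $G_t$ to all ones recovers plain linear attention.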
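Here is a minimal pure-Python sketch of the gated recurrence. The gate matrices $G_t$ are passed in precomputed, because how they are produced from the input is model-specific and not shown here. The helpers `scalar_gate` and `per_key_gate` are hypothetical. They only mimic the shape of two common gate structures: one scalar shared by every entry of $G_t$, and one value per key dimension shared along each row. They are not any model's actual parameterization. In the example, both queries look up the first key. With a constant decay of 0.5, that key's contribution to the second output is halved.

```python
def gated_linear_attention(Q, K, V, gates):
    """o_t = q_t S_t with S_t = G_t ⊙ S_{t-1} + k_t^T v_t and S_0 = 0.

    gates[t] is the d_k x d_v matrix G_t (assumed precomputed).
    """
    d_k, d_v = len(K[0]), len(V[0])
    S = [[0.0] * d_v for _ in range(d_k)]
    outputs = []
    for q, k, v, G in zip(Q, K, V, gates):
        # Decay the old state element-wise, then add the new outer product.
        S = [[G[i][j] * S[i][j] + k[i] * v[j] for j in range(d_v)] for i in range(d_k)]
        outputs.append([sum(q[i] * S[i][j] for i in range(d_k)) for j in range(d_v)])
    return outputs


def scalar_gate(alpha, d_k, d_v):
    """Every entry of G_t equals one scalar alpha."""
    return [[alpha] * d_v for _ in range(d_k)]


def per_key_gate(alpha, d_v):
    """Row i of G_t repeats alpha[i] (one gate value per key dimension)."""
    return [[a] * d_v for a in alpha]


Q = [[1.0, 0.0], [1.0, 0.0]]  # both queries look up the first key
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]

no_decay = gated_linear_attention(Q, K, V, [scalar_gate(1.0, 2, 2)] * 2)
half_decay = gated_linear_attention(Q, K, V, [scalar_gate(0.5, 2, 2)] * 2)
per_key = gated_linear_attention(Q, K, V, [per_key_gate([0.5, 1.0], 2)] * 2)
print(no_decay)    # [[1.0, 2.0], [1.0, 2.0]]  (all-ones G_t: plain linear attention)
print(half_decay)  # [[1.0, 2.0], [0.5, 1.0]]  (v_1 is down-weighted at the second step)
print(per_key)     # [[1.0, 2.0], [0.5, 1.0]]  (only row 0 of the state decays)
```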
