Ball's Blog

DeltaNet: Improved Linear Attention with Delta Rule


Prerequisites

Please read Mathematical Formulation of Linear Attention first.

Summary

DeltaNet applies the delta rule to linear attention, treating the state matrix as a regression model that maps keys to values. While a naive parallel scan approach leads to a prohibitive $O(Ld^3 \log L)$ time complexity, chunking combined with the WY representation enables an efficient parallel form with $O(LCd + Ld^2)$ complexity (where $C$ is the chunk size), matching that of standard linear attention.

Terminology

  • $q_t \in \mathbb{R}^{1 \times d}$: the query vector at time $t$
  • $k_t \in \mathbb{R}^{1 \times d}$: the key vector at time $t$
  • $v_t \in \mathbb{R}^{1 \times d}$: the value vector at time $t$
  • $o_t \in \mathbb{R}^{1 \times d}$: the output vector at time $t$
  • $Q \in \mathbb{R}^{L \times d}$: the matrix of all query vectors
  • $K \in \mathbb{R}^{L \times d}$: the matrix of all key vectors
  • $V \in \mathbb{R}^{L \times d}$: the matrix of all value vectors
  • $O \in \mathbb{R}^{L \times d}$: the matrix of all output vectors
  • $L$: the sequence length
  • $d$: the dimension of the query, key, and value vectors
  • $S_t \in \mathbb{R}^{d \times d}$: the state matrix at the $t$-th token, which is the sum of outer products of key and value vectors up to time $t$: $S_t = \sum_{i=1}^{t} k_i^T v_i$
  • $G_t \in \mathbb{R}^{d \times d}$: the forgetting gate matrix at the $t$-th token, which controls the contribution of each token to the state matrix
  • $\odot$: element-wise multiplication operator
  • $I$: the identity matrix
  • $\beta_t \in \mathbb{R}$: the learning rate for the delta rule

Delta Rule

Let’s say we have the following regression model:

$$
\begin{aligned}
y &= xW \\
\text{where} &\quad y \in \mathbb{R}^{1 \times d} \text{ is the output vector} \\
&\quad W \in \mathbb{R}^{d \times d} \text{ is the weight matrix} \\
&\quad x \in \mathbb{R}^{1 \times d} \text{ is the input vector}
\end{aligned}
$$

To train this model, we use the delta rule, which updates the weight matrix $W$ based on the MSE loss between the predicted output $y$ and the ground truth $\hat{y}$:

$$
\begin{aligned}
\text{MSE Loss} &= \frac{1}{2} \| \hat{y} - y \|^2 \\
&= \frac{1}{2} \| \hat{y} - xW \|^2
\end{aligned}
$$

$$
\begin{aligned}
\frac{\partial \text{MSE Loss}}{\partial W} &= -x^T (\hat{y} - xW) \\
&= -x^T \hat{y} + x^T x W \\
&= -x^T \hat{y} + (x^T x) W
\end{aligned}
$$

$$
\begin{aligned}
W_{\text{new}} &= W_{\text{old}} - \beta \frac{\partial \text{MSE Loss}}{\partial W} \\
&= W_{\text{old}} + \beta x^T \hat{y} - \beta (x^T x) W_{\text{old}} \\
&= (I - \beta x^T x) W_{\text{old}} + \beta x^T \hat{y}
\end{aligned}
\tag{Eq. 1}
$$
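To make Eq. 1 concrete, here is a minimal NumPy sketch that checks the closed-form update against a plain gradient step. The function names `delta_rule_step` and `gradient_step`, the dimension, and the value of `beta` are illustrative choices, and the input is normalized to unit length purely so that the error shrinkage is easy to read off.

```python
import numpy as np

def delta_rule_step(W, x, y_hat, beta):
    """Closed form of Eq. 1: W_new = (I - beta x^T x) W + beta x^T y_hat."""
    d = W.shape[0]
    return (np.eye(d) - beta * x.T @ x) @ W + beta * x.T @ y_hat

def gradient_step(W, x, y_hat, beta):
    """The same update written as one gradient-descent step on the MSE loss."""
    grad = -x.T @ (y_hat - x @ W)
    return W - beta * grad

rng = np.random.default_rng(0)
d, beta = 4, 0.5
W = rng.standard_normal((d, d))
x = rng.standard_normal((1, d))
x /= np.linalg.norm(x)  # unit-norm input (an assumption made for this demo)
y_hat = rng.standard_normal((1, d))

W_closed = delta_rule_step(W, x, y_hat, beta)
W_grad = gradient_step(W, x, y_hat, beta)
err_before = y_hat - x @ W
err_after = y_hat - x @ W_closed

print(np.allclose(W_closed, W_grad))                    # both forms agree
print(np.allclose(err_after, (1 - beta) * err_before))  # error shrinks by (1 - beta)
```

With a unit-norm input, one step scales the prediction error for that input by $(1 - \beta)$, so $\beta = 1$ would fit the pair $(x, \hat{y})$ exactly.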

Recap of Linear Attention

Let’s first recall the formulation of linear attention.

Recurrent form:

$$
\begin{aligned}
o_t &= \sum_{j=1}^t \left(q_t k_j^T\right) v_j \\
&= \sum_{j=1}^t q_t \left(k_j^T v_j\right) \\
&= q_t \sum_{j=1}^t k_j^T v_j \\
&= q_t S_t
\end{aligned}
$$

$S_t$ is the state matrix at the $t$-th token, which is the sum of outer products of key and value vectors up to time $t$: $S_t = \sum_{i=1}^{t} k_i^T v_i$

We can also view the state matrix $S_t$ as a regression model that maps key vectors to value vectors: multiplying a stored key $k_m$ by $S_t$ should return its value $v_m$.

$$
\begin{aligned}
k_m S_t &= k_m \sum_{i=1}^{t} k_i^T v_i \\
&= \sum_{i=1}^{t} k_m k_i^T v_i \\
&= v_m + \sum_{i=1, i \neq m}^{t} k_m k_i^T v_i \quad (\text{assuming } \|k_m\| = 1) \\
\text{where} &\quad 1 \leq m \leq t \\
&\quad k_m \text{ is the key vector at the } m \text{-th token} \\
&\quad v_m \text{ is the value vector at the } m \text{-th token}
\end{aligned}
$$

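This has the same form as the regression model $y = xW$ from the Delta Rule section, with $x = k_m$, $W = S_t$, and target $\hat{y} = v_m$. The remaining sum is the retrieval error caused by the other keys.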
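The retrieval view can be checked numerically. The sketch below uses random data and illustrative names (`retrieved`, `cross_talk`, `K_orth`) that do not come from the post. It confirms that $k_m S_t$ equals $v_m$ plus the cross-talk term, and that the cross-talk vanishes when the keys happen to be orthonormal.

```python
import numpy as np

rng = np.random.default_rng(0)
t, d = 5, 8
K = rng.standard_normal((t, d))
K /= np.linalg.norm(K, axis=1, keepdims=True)  # unit-norm keys, as assumed above
V = rng.standard_normal((t, d))

S = K.T @ V  # S_t = sum_i k_i^T v_i

m = 2
retrieved = K[m] @ S
# Interference from every other stored key-value pair
cross_talk = sum((K[m] @ K[i]) * V[i] for i in range(t) if i != m)
print(np.allclose(retrieved, V[m] + cross_talk))

# Special case: orthonormal keys give exact retrieval (no cross-talk)
K_orth = np.linalg.qr(rng.standard_normal((d, d)))[0][:t]
S_orth = K_orth.T @ V
print(np.allclose(K_orth @ S_orth, V))
```

In general the keys are not orthogonal, so plain linear attention retrieves a value polluted by the other entries; this is the error the delta rule tries to reduce.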

DeltaNet: Applying Delta Rule to Linear Attention

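Let’s apply the delta rule ($\text{Eq. 1}$) to linear attention. As mentioned in Recap of Linear Attention, we view the state matrix $S_t$ as a regression model: the input $x$ is the key $k_t$, the target $\hat{y}$ is the value $v_t$, the weight matrix $W$ is the state, and the learning rate is $\beta_t$. Substituting these into $\text{Eq. 1}$ gives the recurrent form of DeltaNet: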

$$
\begin{aligned}
S_t &= (I - \beta_t k_t^T k_t) S_{t-1} + \beta_t k_t^T v_t \\
o_t &= q_t S_t
\end{aligned}
$$
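The recurrence translates almost line by line into code. The following is a minimal NumPy sketch rather than an optimized implementation; the function name `deltanet_recurrent` and the toy inputs are made up for illustration. It writes the same unit-norm key twice with $\beta = 1$ to show that the second value overwrites the first, whereas plain linear attention would add them.

```python
import numpy as np

def deltanet_recurrent(Q, K, V, beta):
    """Token-by-token DeltaNet: S_t = (I - b_t k_t^T k_t) S_{t-1} + b_t k_t^T v_t."""
    L, d = Q.shape
    S = np.zeros((d, d))
    O = np.zeros((L, d))
    for t in range(L):
        k, v = K[t:t + 1], V[t:t + 1]          # keep row-vector shape (1, d)
        # Equivalent rank-1 form: S + b_t k^T (v - k S)
        S = S + beta[t] * k.T @ (v - k @ S)
        O[t] = Q[t] @ S
    return O, S

# Toy example: the same key is written twice with different values
K = np.array([[1.0, 0.0], [1.0, 0.0]])
V = np.array([[1.0, 2.0], [3.0, 4.0]])
Q = K.copy()
beta = np.ones(2)

O, S = deltanet_recurrent(Q, K, V, beta)
print(O)             # each query reads back the most recently written value

S_linear = K.T @ V   # plain linear attention accumulates instead
print(K[1] @ S_linear)
```

Here the DeltaNet outputs are `[1, 2]` and then `[3, 4]`, while the linear-attention state returns the sum `[4, 6]` for the same key.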

Deriving Parallel Form of DeltaNet using Parallel Scan (Failed)

Let’s derive the parallel form of DeltaNet using parallel scan. Let’s first simplify the recurrent form of DeltaNet by defining two matrices:

$$
\begin{aligned}
M_t &= I - \beta_t k_t^T k_t \\
X_t &= \beta_t k_t^T v_t
\end{aligned}
$$

Then we can rewrite the recurrent form of DeltaNet as follows:

$$
\begin{aligned}
S_t &= M_t S_{t-1} + X_t \\
o_t &= q_t S_t \\
\text{where} &\quad S_0 = 0
\end{aligned}
$$

If we unroll the recurrence for $S_0$, $S_1$, $S_2$, and $S_3$, we get the following equations:

$$
\begin{aligned}
S_0 &= 0 \\
S_1 &= M_1 S_0 + X_1 \\
S_2 &= M_2 S_1 + X_2 = M_2 M_1 S_0 + M_2 X_1 + X_2 \\
S_3 &= M_3 S_2 + X_3 = M_3 M_2 M_1 S_0 + M_3 M_2 X_1 + M_3 X_2 + X_3
\end{aligned}
\tag{Eq. 2}
$$

From $\text{Eq. 2}$, we can see that $S_t$ can be obtained from a parallel scan over the sequence of pairs $(M_t, X_t)$ using the following binary operator:

$$
\begin{aligned}
(M_a, X_a) \otimes (M_b, X_b) &= (M_b M_a, M_b X_a + X_b) \\
\text{where} &\quad M_a, M_b \in \mathbb{R}^{d \times d} \\
&\quad X_a, X_b \in \mathbb{R}^{d \times d}
\end{aligned}
$$

[Figure: DeltaNet Parallel Scan]
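A scan only works if the operator is associative, so it is worth checking that property numerically. The sketch below uses a hypothetical helper name `combine` and random unit-norm keys (an assumption for numerical stability). It folds the pairs sequentially rather than running a true parallel scan, which is enough to confirm that the operator reproduces the recurrence.

```python
import numpy as np
from functools import reduce

def combine(a, b):
    """(M_a, X_a) combined with (M_b, X_b) -> (M_b M_a, M_b X_a + X_b); a is the earlier element."""
    M_a, X_a = a
    M_b, X_b = b
    return M_b @ M_a, M_b @ X_a + X_b

rng = np.random.default_rng(0)
L, d = 6, 4
K = rng.standard_normal((L, d))
K /= np.linalg.norm(K, axis=1, keepdims=True)
V = rng.standard_normal((L, d))
beta = rng.uniform(0.0, 1.0, size=L)

# One (M_t, X_t) pair per token
pairs = [(np.eye(d) - beta[t] * np.outer(K[t], K[t]),
          beta[t] * np.outer(K[t], V[t])) for t in range(L)]

# Reference: the plain recurrence S_t = M_t S_{t-1} + X_t with S_0 = 0
S = np.zeros((d, d))
for M, X in pairs:
    S = M @ S + X

# Folding all pairs with the operator: the X component equals S_L because S_0 = 0
M_all, X_all = reduce(combine, pairs)
print(np.allclose(X_all, S))

# Associativity, which is what lets a parallel scan regroup the work
left = combine(combine(pairs[0], pairs[1]), pairs[2])
right = combine(pairs[0], combine(pairs[1], pairs[2]))
print(np.allclose(left[0], right[0]) and np.allclose(left[1], right[1]))
```

Each call to `combine` performs two dense $d \times d$ matrix products, which is where the $O(d^3)$ cost per operator application comes from.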

Reason 1 for Failure: Explosion of Time Complexity

The depth of the parallel scan (assuming the Hillis-Steele scan is used) is $O(\log L)$ and the total work is $O(L d^3 \log L)$, which is more expensive than the original linear attention with $O(L d^2)$.

The reason why the total work is $O(L d^3 \log L)$ is that the $\otimes$ operator involves matrix multiplication, which has a time complexity of $O(d^3)$, and the work of the Hillis-Steele algorithm is $O(L \log L \times \text{work of operator}) = O(L d^3 \log L)$.

Reason 2 for Failure: Explosion of Memory Complexity

The downside of parallel scan is that it requires storing all intermediate results (i.e., $M_1, M_2, \ldots, M_L$ and $X_1, X_2, \ldots, X_L$).

$$
\begin{aligned}
\text{Memory Complexity} &= O(L \times \text{size of intermediate result}) \\
&= O(L \times (d^2 + d^2)) \\
&= O(L d^2)
\end{aligned}
$$

How can we overcome these challenges?

Deriving Parallel Form of DeltaNet using Chunking

Notation

  • $\Box_{[i]}^{j} := \Box_{Ci + j}$, where $\Box \in \{q, k, v, o, S, \beta\}$ and $C$ is the chunk size
  • $\triangle_{[i]} := \triangle_{Ci:C(i+1)}$, where $\triangle \in \{Q, K, V, O\}$

Chunking in Linear Attention

Before we derive the parallel form of DeltaNet using chunking, let’s first see how chunking can be applied to linear attention.

$$
S_{[t]}^r = S_{[t]}^{0} + \sum_{j=1}^{r} {k_{[t]}^{j}}^T v_{[t]}^{j}
\tag{Eq. 3}
$$

$\text{Eq. 3}$ shows how the state matrix $S_{[t]}^r$ can be computed from the final state of the previous chunk, $S_{[t]}^{0} = S_{[t-1]}^{C}$, and the key-value pairs of the current chunk. Then the output vector can be computed as follows:

$$
\begin{aligned}
o_{[t]}^r &= q_{[t]}^r S_{[t]}^r \\
&= q_{[t]}^r \left( S_{[t]}^{0} + \sum_{j=1}^{r} {k_{[t]}^{j}}^T v_{[t]}^{j} \right) \\
&= q_{[t]}^r S_{[t]}^{0} + q_{[t]}^r \sum_{j=1}^{r} {k_{[t]}^{j}}^T v_{[t]}^{j} \\
&= q_{[t]}^r S_{[t]}^{0} + \sum_{j=1}^{r} q_{[t]}^r {k_{[t]}^{j}}^T v_{[t]}^{j}
\end{aligned}
\tag{Eq. 4}
$$

If we convert $\text{Eq. 4}$ to matrix form, we get the following equation:

$$
\begin{aligned}
O_{[t]} &= \underset{\text{inter-chunk state passing}}{\underline{Q_{[t]} S_{[t]}^{0}}} + \underset{\text{intra-chunk parallel computation}}{\underline{(Q_{[t]} K_{[t]}^T \odot \text{Mask}) V_{[t]}}}
\end{aligned}
$$

The computational complexity of chunking in linear attention is $O(LCd + Ld^2)$.

The reason is that the inter-chunk state passing for a single chunk costs $O(Cd^2)$, and the intra-chunk parallel computation for a single chunk costs $O(C^2 d)$. Since there are $L/C$ chunks, the total is $\frac{L}{C} \cdot O(C^2 d + C d^2) = O(LCd + Ld^2)$.
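As a sanity check on the matrix form, the sketch below compares a chunked implementation against the token-by-token recurrence. The function names, the sizes, and the requirement that $L$ be divisible by $C$ are simplifying assumptions made for this example; the mask is taken to be lower triangular including the diagonal, which is what $\text{Eq. 4}$ implies.

```python
import numpy as np

def linear_attention_recurrent(Q, K, V):
    """Reference: S_t = S_{t-1} + k_t^T v_t, o_t = q_t S_t."""
    L, d = Q.shape
    S = np.zeros((d, d))
    O = np.zeros((L, d))
    for t in range(L):
        S = S + np.outer(K[t], V[t])
        O[t] = Q[t] @ S
    return O

def linear_attention_chunked(Q, K, V, C):
    """Chunked form: O = Q S_prev + (Q K^T * Mask) V, computed chunk by chunk."""
    L, d = Q.shape
    assert L % C == 0, "this sketch assumes L is a multiple of C"
    S = np.zeros((d, d))
    O = np.zeros((L, d))
    mask = np.tril(np.ones((C, C)))  # causal mask, diagonal included
    for s in range(0, L, C):
        Qc, Kc, Vc = Q[s:s + C], K[s:s + C], V[s:s + C]
        O[s:s + C] = Qc @ S + ((Qc @ Kc.T) * mask) @ Vc  # inter-chunk + intra-chunk
        S = S + Kc.T @ Vc                                # pass the state to the next chunk
    return O

rng = np.random.default_rng(0)
L, d, C = 16, 8, 4
Q, K, V = (rng.standard_normal((L, d)) for _ in range(3))

O_ref = linear_attention_recurrent(Q, K, V)
O_chunk = linear_attention_chunked(Q, K, V, C)
print(np.allclose(O_ref, O_chunk))
```

Only the loop over chunks is sequential; everything inside a chunk is a handful of dense matrix products.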

Chunking in DeltaNet: Naive Approach

Recall the recurrent form of DeltaNet again:

$$
\begin{aligned}
S_t &= (I - \beta_t k_t^T k_t) S_{t-1} + \beta_t k_t^T v_t \\
o_t &= q_t S_t
\end{aligned}
$$

$$
\begin{aligned}
S_{[t]}^r &= (I - \beta_{[t]}^r {k_{[t]}^r}^T k_{[t]}^r) S_{[t]}^{r-1} + \beta_{[t]}^r {k_{[t]}^r}^T v_{[t]}^r \\
&= \underset{\text{intra-chunk parallel computation}}{\underline{\sum_{j=1}^r \left[ \underset{\text{forgetting gate product}}{\underline{\left( \prod_{i=j+1}^r (I - \beta_{[t]}^i {k_{[t]}^i}^T k_{[t]}^i) \right)}} \beta_{[t]}^j {k_{[t]}^j}^T v_{[t]}^j \right]}} \\
&\quad + \underset{\text{inter-chunk state passing}}{\underline{\left( \prod_{l=1}^r (I - \beta_{[t]}^l {k_{[t]}^l}^T k_{[t]}^l) \right) S_{[t-1]}^C}}
\end{aligned}
\tag{Eq. 5}
$$

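$\text{Eq. 5}$ is not as clean as $\text{Eq. 3}$, but it shows that the state matrix $S_{[t]}^r$ can be computed from the key-value pairs of the earlier tokens in the chunk, each multiplied by the product of the forgetting gates of all subsequent tokens, plus the decayed final state of the previous chunk. The factors in each product are ordered with later tokens on the left, since these matrices do not commute.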

Let’s take a deeper look at the product of the forgetting gates of all subsequent tokens:

$$
\prod_{i=j+1}^r (I - \beta_{[t]}^i {k_{[t]}^i}^T k_{[t]}^i)
$$

To simplify the notation, let’s define $P_n$ as:

$$
\begin{aligned}
P_n &= \prod_{i=1}^n (I - \beta_i k_i^T k_i) \\
&= (I - \beta_n k_n^T k_n) P_{n-1}
\end{aligned}
$$

The computational complexity of computing $P_n$ for $n = 1, 2, \ldots, L$ is $O(L d^3)$.

The reason is that deriving $P_n$ from $P_{n-1}$ takes $O(d^3)$ computation if each step is carried out as a dense $d \times d$ matrix product, and we need to compute $P_n$ for all $n = 1, 2, \ldots, L$. As a result, the total computational complexity is $O(L d^3)$.

The memory complexity of storing $P_n$ for $n = 1, 2, \ldots, L$ is $O(L d^2)$.

WY Representation

The WY representation is a mathematical technique that allows us to represent a product of matrices in a more compact form. Specifically, it states that a product of matrices of the form $(I - \beta_i k_i^T k_i)$ can be represented as follows:

$$
\begin{aligned}
\prod_{i=1}^t (I - \beta_i k_i^T k_i) &= I - \sum_{j=1}^t k_j^T w_j \\
\text{where} &\quad w_j \in \mathbb{R}^{1 \times d} \quad \text{for } j = 1, 2, \ldots, t
\end{aligned}
$$

This can be easily proved by mathematical induction. (Proof is omitted here, but you can find it in the Appendix B.1 of the original DeltaNet paper.)

Note that vectors in this blog are defined as row vectors ($\mathbb{R}^{1 \times d}$), whereas the paper uses column vectors ($\mathbb{R}^d$).

If you prove it by mathematical induction, you will find that $w_j$ can be derived from $\beta_j$ and $k_j$ using the following equation:

$$
w_j = \beta_j k_j - \left( \beta_j k_j \sum_{m=1}^{j-1} k_m^T w_m \right)
$$
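The identity is easy to verify numerically. In the sketch below, the product is built with later factors on the left, matching $P_n = (I - \beta_n k_n^T k_n) P_{n-1}$, and the $w_j$ are built from the recurrence above. The sizes and the unit-norm keys are arbitrary choices for the demonstration.

```python
import numpy as np

rng = np.random.default_rng(0)
t, d = 5, 6
K = rng.standard_normal((t, d))
K /= np.linalg.norm(K, axis=1, keepdims=True)  # unit-norm keys (assumption)
beta = rng.uniform(0.0, 1.0, size=t)

# Left-hand side: explicit product, later tokens multiplied on the left
P = np.eye(d)
for j in range(t):
    P = (np.eye(d) - beta[j] * np.outer(K[j], K[j])) @ P

# Right-hand side: w_j = beta_j k_j - beta_j k_j sum_{m<j} k_m^T w_m
W = np.zeros((t, d))
for j in range(t):
    W[j] = beta[j] * (K[j] - K[j] @ (K[:j].T @ W[:j]))
P_wy = np.eye(d) - K.T @ W  # I - sum_j k_j^T w_j

print(np.allclose(P, P_wy))
```

The compact form stores only $t$ row vectors $w_j$ alongside the keys instead of materializing every intermediate $d \times d$ product.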

Chunking in DeltaNet: Applying WY Representation to Forgetting Gate Product

According to the WY representation, the product of the forgetting gates of all subsequent tokens can be represented as follows:

$$
\begin{aligned}
\prod_{i=1}^t (I - \beta_i k_i^T k_i) &= I - \sum_{j=1}^t k_j^T w_j \\
\text{where} &\quad w_j \in \mathbb{R}^{1 \times d} \quad \text{for } j = 1, 2, \ldots, t
\end{aligned}
$$

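The beauty starts from here. By applying the WY representation, chunking of DeltaNet becomes the same as chunking of linear attention. The key step is that the state itself can also be written as a sum of outer products, $S_t = \sum_{j=1}^t k_j^T u_j$, where the vectors $u_j \in \mathbb{R}^{1 \times d}$ are defined in the course of the proof: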

$$
\begin{aligned}
&\text{We want to show } \quad S_t = \sum_{j=1}^t k_j^T u_j \\
&\text{Proof.} \quad \text{Let's use mathematical induction.} \\
&\quad \text{Base case:} \quad t = 1 \\
&\quad S_1 = (I - \beta_1 k_1^T k_1) S_0 + \beta_1 k_1^T v_1 = \beta_1 k_1^T v_1 = k_1^T \left( \beta_1 v_1 \right) = k_1^T u_1 \\
&\quad \text{Induction hypothesis:} \quad \text{Assume that } S_{t-1} = \sum_{j=1}^{t-1} k_j^T u_j \\
&\quad \text{We want to show that } S_t = \sum_{j=1}^t k_j^T u_j \\
&\quad S_t = (I - \beta_t k_t^T k_t) S_{t-1} + \beta_t k_t^T v_t \\
&\quad = (I - \beta_t k_t^T k_t) \sum_{j=1}^{t-1} k_j^T u_j + \beta_t k_t^T v_t \\
&\quad = \sum_{j=1}^{t-1} k_j^T u_j - \beta_t k_t^T k_t \sum_{j=1}^{t-1} k_j^T u_j + \beta_t k_t^T v_t \\
&\quad = \sum_{j=1}^{t-1} k_j^T u_j + \beta_t k_t^T \left( v_t - k_t \sum_{j=1}^{t-1} k_j^T u_j \right) \\
&\quad = \sum_{j=1}^{t-1} k_j^T u_j + k_t^T \underset{u_t}{\underline{\left( \beta_t v_t - \beta_t k_t \sum_{j=1}^{t-1} k_j^T u_j \right)}} \\
&\quad = \sum_{j=1}^{t} k_j^T u_j
\end{aligned}
$$

As a result, the computational complexity of chunking in DeltaNet with the WY representation is the same as that of chunking in linear attention, which is $O(LCd + Ld^2)$!
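The induction result can also be checked on random data. This short sketch builds $S_t$ from the DeltaNet recurrence and, separately, from the vectors $u_j = \beta_j v_j - \beta_j k_j \sum_{m<j} k_m^T u_m$ identified in the proof; the variable names and sizes are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
t, d = 5, 6
K = rng.standard_normal((t, d))
K /= np.linalg.norm(K, axis=1, keepdims=True)  # unit-norm keys (assumption)
V = rng.standard_normal((t, d))
beta = rng.uniform(0.0, 1.0, size=t)

# State from the recurrence S_j = (I - b_j k_j^T k_j) S_{j-1} + b_j k_j^T v_j
S = np.zeros((d, d))
for j in range(t):
    S = (np.eye(d) - beta[j] * np.outer(K[j], K[j])) @ S + beta[j] * np.outer(K[j], V[j])

# State from the u_j vectors: S_t = sum_j k_j^T u_j = K^T U
U = np.zeros((t, d))
for j in range(t):
    U[j] = beta[j] * (V[j] - K[j] @ (K[:j].T @ U[:j]))
S_from_u = K.T @ U

print(np.allclose(S, S_from_u))
```

Written as $K^T U$, the state has exactly the shape of the linear-attention state $K^T V$, with $U$ playing the role of the values.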

Now let’s simplify $\text{Eq. 5}$:

$$
\begin{aligned}
S_{[t]}^r &= (I - \beta_{[t]}^r {k_{[t]}^r}^T k_{[t]}^r) S_{[t]}^{r-1} + \beta_{[t]}^r {k_{[t]}^r}^T v_{[t]}^r \\
&= \underset{\text{intra-chunk computation}}{\underline{\sum_{j=1}^r \left[ \underset{\text{forgetting gate product}}{\underline{\left( \prod_{i=j+1}^r (I - \beta_{[t]}^i {k_{[t]}^i}^T k_{[t]}^i) \right)}} \beta_{[t]}^j {k_{[t]}^j}^T v_{[t]}^j \right]}} \\
&\quad + \underset{\text{inter-chunk state passing}}{\underline{\left( \prod_{l=1}^r (I - \beta_{[t]}^l {k_{[t]}^l}^T k_{[t]}^l) \right) S_{[t-1]}^C}} \\
&= \underset{\text{intra-chunk computation}}{\underline{\sum_{j=1}^r {k_{[t]}^j}^T u_{[t]}^j}} + \underset{\text{inter-chunk state passing}}{\underline{\left( I - \sum_{l=1}^r {k_{[t]}^l}^T w_{[t]}^l \right) S_{[t-1]}^C}} \\
\text{where} &\quad u_{[t]}^j = \beta_{[t]}^j v_{[t]}^j - \beta_{[t]}^j k_{[t]}^j \sum_{m=1}^{j-1} {k_{[t]}^m}^T u_{[t]}^m \\
&\quad w_{[t]}^l = \beta_{[t]}^l k_{[t]}^l - \beta_{[t]}^l k_{[t]}^l \left( \sum_{m=1}^{l-1} {k_{[t]}^m}^T w_{[t]}^m \right)
\end{aligned}
\tag{Eq. 6}
$$

The output is computed as follows:

$$
\begin{aligned}
o_{[t]}^r &= q_{[t]}^r S_{[t]}^r \\
&= q_{[t]}^r \left( \sum_{j=1}^r {k_{[t]}^j}^T u_{[t]}^j + \left( I - \sum_{l=1}^r {k_{[t]}^l}^T w_{[t]}^l \right) S_{[t-1]}^C \right) \\
&= q_{[t]}^r \sum_{j=1}^r {k_{[t]}^j}^T u_{[t]}^j + q_{[t]}^r \left( I - \sum_{l=1}^r {k_{[t]}^l}^T w_{[t]}^l \right) S_{[t-1]}^C \\
&= q_{[t]}^r \sum_{j=1}^r {k_{[t]}^j}^T u_{[t]}^j + q_{[t]}^r S_{[t-1]}^C - q_{[t]}^r \sum_{l=1}^r {k_{[t]}^l}^T w_{[t]}^l S_{[t-1]}^C \\
&= q_{[t]}^r S_{[t-1]}^C + \sum_{j=1}^r q_{[t]}^r {k_{[t]}^j}^T u_{[t]}^j - \sum_{l=1}^r q_{[t]}^r {k_{[t]}^l}^T w_{[t]}^l S_{[t-1]}^C \\
&= q_{[t]}^r S_{[t-1]}^C + \sum_{j=1}^r q_{[t]}^r {k_{[t]}^j}^T \left( u_{[t]}^j - w_{[t]}^j S_{[t-1]}^C \right)
\end{aligned}
\tag{Eq. 7}
$$

If we write $\text{Eq. 6}$ (with $r = C$) and $\text{Eq. 7}$ in matrix form, we get the following equations:

$$
\begin{aligned}
S_{[t]}^{C} &= (I - {K_{[t]}^{0:C}}^T W_{[t]}^{0:C}) S_{[t-1]}^C + {K_{[t]}^{0:C}}^T U_{[t]}^{0:C} \\
&= S_{[t-1]}^C + {K_{[t]}^{0:C}}^T \underset{\text{Delta corrected Value}}{\underline{\left( U_{[t]}^{0:C} - W_{[t]}^{0:C} S_{[t-1]}^C \right)}} \\
O_{[t]}^{0:C} &= Q_{[t]}^{0:C} S_{[t-1]}^C + (Q_{[t]}^{0:C} {K_{[t]}^{0:C}}^T \odot \text{Mask}) \underset{\text{Delta corrected Value}}{\underline{\left( U_{[t]}^{0:C} - W_{[t]}^{0:C} S_{[t-1]}^C \right)}}
\end{aligned}
$$

Last piece of the puzzle: How to compute $U$ and $W$ efficiently?

Recall the definition of $u$ and $w$:

$$
\begin{aligned}
u_{[t]}^j &= \beta_{[t]}^j v_{[t]}^j - \beta_{[t]}^j k_{[t]}^j \sum_{m=1}^{j-1} {k_{[t]}^m}^T u_{[t]}^m \\
w_{[t]}^l &= \beta_{[t]}^l k_{[t]}^l - \beta_{[t]}^l k_{[t]}^l \left( \sum_{m=1}^{l-1} {k_{[t]}^m}^T w_{[t]}^m \right)
\end{aligned}
$$

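So far we have only discussed the recurrent form of $u$ and $w$, but we can also derive a parallel form. To keep the notation simple, we drop the chunk index $[t]$ and number the positions inside the chunk starting from 0. The derivation is shown for $w$; the one for $u$ is identical with $v$ in place of $k$ on the right-hand side.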

Step 1. Finding pattern from the recurrent form

$$
\begin{aligned}
w_0 &= \beta_0 k_0 \\
w_1 &= \beta_1 (k_1 - k_1 k_0^T w_0) \\
&= \beta_1 (k_1 - a_{1,0} \beta_0 k_0) \\
w_2 &= \beta_2 (k_2 - k_2 (k_0^T w_0 + k_1^T w_1)) \\
&= \beta_2 (k_2 - a_{2,0} w_0 - a_{2,1} w_1) \\
&\;\; \vdots \\
\text{where} &\quad a_{i,j} = k_i k_j^T
\end{aligned}
$$

Step 2. Rearranging the equation:

$$
\begin{aligned}
w_0 &= \beta_0 k_0 \\
\beta_1 a_{1,0} w_0 + w_1 &= \beta_1 k_1 \\
\beta_2 a_{2,0} w_0 + \beta_2 a_{2,1} w_1 + w_2 &= \beta_2 k_2 \\
&\;\; \vdots
\end{aligned}
$$

Step 3. Converting equations into matrix form:

$$
\begin{bmatrix}
1 & 0 & 0 & \cdots & 0 \\
\beta_1 a_{1,0} & 1 & 0 & \cdots & 0 \\
\beta_2 a_{2,0} & \beta_2 a_{2,1} & 1 & \cdots & 0 \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
\beta_r a_{r,0} & \beta_r a_{r,1} & \beta_r a_{r,2} & \cdots & 1
\end{bmatrix}
\begin{bmatrix} w_0 \\ w_1 \\ w_2 \\ \vdots \\ w_r \end{bmatrix}
=
\begin{bmatrix} \beta_0 k_0 \\ \beta_1 k_1 \\ \beta_2 k_2 \\ \vdots \\ \beta_r k_r \end{bmatrix}
$$

Let’s define the matrix $I - A$ as follows:

$$
I - A =
\begin{bmatrix}
1 & 0 & 0 & \cdots & 0 \\
\beta_1 a_{1,0} & 1 & 0 & \cdots & 0 \\
\beta_2 a_{2,0} & \beta_2 a_{2,1} & 1 & \cdots & 0 \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
\beta_r a_{r,0} & \beta_r a_{r,1} & \beta_r a_{r,2} & \cdots & 1
\end{bmatrix}
$$

$$
\begin{aligned}
(I - A) W &= \text{diag}(\beta) K \\
\text{where} &\quad W = \begin{bmatrix} w_0 \\ w_1 \\ w_2 \\ \vdots \\ w_r \end{bmatrix} \\
&\quad K = \begin{bmatrix} k_0 \\ k_1 \\ k_2 \\ \vdots \\ k_r \end{bmatrix} \\
&\quad \text{diag}(\beta) = \begin{bmatrix}
\beta_0 & 0 & 0 & \cdots & 0 \\
0 & \beta_1 & 0 & \cdots & 0 \\
0 & 0 & \beta_2 & \cdots & 0 \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
0 & 0 & 0 & \cdots & \beta_r
\end{bmatrix}
\end{aligned}
$$

Step 4. Solving the matrix equation

$$
W = (I - A)^{-1} \text{diag}(\beta) K
$$

Step 5. Simplifying the inverse matrix $(I - A)^{-1}$

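We have a powerful property of $I - A$: it is a lower triangular matrix with all diagonal elements equal to 1, which means $A$ itself is strictly lower triangular.

As a result, the inverse of $I - A$ always exists. Moreover, a strictly lower triangular $C \times C$ matrix is nilpotent ($A^C = 0$), so the inverse is a finite sum of powers of $A$ and can be computed efficiently with matrix multiplications (which map well to Tensor Cores):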

$$
\begin{aligned}
(I - A)^{-1} &= I + A + A^2 + A^3 + \cdots + A^{C-1} \\
\text{where} &\quad C \text{ is the chunk size (all higher powers vanish because } A^C = 0 \text{)}
\end{aligned}
$$
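The finite series can be confirmed on a small example. In this sketch $A$ is read off from the matrix $I - A$ of Step 3, so its entries below the diagonal are $-\beta_i a_{i,j}$; the sizes and random inputs are arbitrary, and `np.linalg.inv` is used only as a reference.

```python
import numpy as np

rng = np.random.default_rng(0)
C, d = 4, 8
K = rng.standard_normal((C, d))
K /= np.linalg.norm(K, axis=1, keepdims=True)  # unit-norm keys (assumption)
beta = rng.uniform(0.0, 1.0, size=C)

# Entry (i, j) of diag(beta) K K^T is beta_i * a_{i,j}; keep only the part below the diagonal
A = -np.tril(beta[:, None] * (K @ K.T), -1)

# Finite Neumann series I + A + ... + A^(C-1)
inv_series = np.eye(C)
power = np.eye(C)
for _ in range(C - 1):
    power = power @ A
    inv_series = inv_series + power

print(np.allclose(np.linalg.matrix_power(A, C), 0.0))          # A is nilpotent
print(np.allclose(inv_series, np.linalg.inv(np.eye(C) - A)))   # series equals the inverse
```

Each extra power of $A$ pushes the nonzero band one diagonal further down, which is why the series stops after $C - 1$ terms.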

Step 6. Final form of $W$ and $U$

$$
\begin{aligned}
W_{[t]}^{0:C} &= (I - A_{[t]}^{0:C})^{-1} \text{diag}(\beta_{[t]}^{0:C}) K_{[t]}^{0:C} \\
U_{[t]}^{0:C} &= (I - A_{[t]}^{0:C})^{-1} \text{diag}(\beta_{[t]}^{0:C}) V_{[t]}^{0:C}
\end{aligned}
$$
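To see that the matrix form really reproduces the recurrences, the sketch below computes $W$ and $U$ both ways for a single chunk. It uses `np.linalg.solve` instead of an explicit inverse, which is a choice made for this example and not necessarily what a GPU kernel would do; $A$ again carries the sign implied by Step 3.

```python
import numpy as np

rng = np.random.default_rng(0)
C, d = 4, 8
K = rng.standard_normal((C, d))
K /= np.linalg.norm(K, axis=1, keepdims=True)  # unit-norm keys (assumption)
V = rng.standard_normal((C, d))
beta = rng.uniform(0.0, 1.0, size=C)

# Recurrent definitions of w_j and u_j (positions numbered from 0)
W_rec = np.zeros((C, d))
U_rec = np.zeros((C, d))
for j in range(C):
    W_rec[j] = beta[j] * (K[j] - K[j] @ (K[:j].T @ W_rec[:j]))
    U_rec[j] = beta[j] * (V[j] - K[j] @ (K[:j].T @ U_rec[:j]))

# Matrix form: W = (I - A)^{-1} diag(beta) K, U = (I - A)^{-1} diag(beta) V
A = -np.tril(beta[:, None] * (K @ K.T), -1)
T = np.linalg.solve(np.eye(C) - A, np.diag(beta))  # (I - A)^{-1} diag(beta)
W_par = T @ K
U_par = T @ V

print(np.allclose(W_rec, W_par), np.allclose(U_rec, U_par))
```

The factor `T` depends only on the keys and $\beta$ of the chunk, so it is computed once and shared between $W$ and $U$.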

Final form of DeltaNet

The final form closely resembles that of plain linear attention!

$$
\begin{aligned}
&\text{Preparing auxiliary matrices} \\
A_{[t]}^{0:C} &= -\text{tril}(\text{diag}(\beta_{[t]}^{0:C}) K_{[t]}^{0:C} {K_{[t]}^{0:C}}^T, -1) \\
W_{[t]}^{0:C} &= (I - A_{[t]}^{0:C})^{-1} \text{diag}(\beta_{[t]}^{0:C}) K_{[t]}^{0:C} \\
U_{[t]}^{0:C} &= (I - A_{[t]}^{0:C})^{-1} \text{diag}(\beta_{[t]}^{0:C}) V_{[t]}^{0:C} \\
\\
&\text{Main computation} \\
S_{[t]}^{C} &= (I - {K_{[t]}^{0:C}}^T W_{[t]}^{0:C}) S_{[t-1]}^C + {K_{[t]}^{0:C}}^T U_{[t]}^{0:C} \\
&= S_{[t-1]}^C + {K_{[t]}^{0:C}}^T \underset{\text{Delta corrected Value}}{\underline{\left( U_{[t]}^{0:C} - W_{[t]}^{0:C} S_{[t-1]}^C \right)}} \\
O_{[t]}^{0:C} &= Q_{[t]}^{0:C} S_{[t-1]}^C + (Q_{[t]}^{0:C} {K_{[t]}^{0:C}}^T \odot \text{Mask}) \underset{\text{Delta corrected Value}}{\underline{\left( U_{[t]}^{0:C} - W_{[t]}^{0:C} S_{[t-1]}^C \right)}}
\end{aligned}
$$
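Putting the pieces together, here is a minimal NumPy sketch of the chunked forward pass, checked against the token-by-token recurrence. It is meant as a readable reference under a few assumptions (unit-norm keys, $L$ divisible by $C$, a lower-triangular mask including the diagonal, and $I - A$ having $+\beta_i a_{i,j}$ below the diagonal as in Step 3), not as the optimized kernel from the paper. The function names are made up for this example.

```python
import numpy as np

def deltanet_recurrent(Q, K, V, beta):
    """Reference: S_t = (I - b_t k_t^T k_t) S_{t-1} + b_t k_t^T v_t, o_t = q_t S_t."""
    L, d = Q.shape
    S = np.zeros((d, d))
    O = np.zeros((L, d))
    for t in range(L):
        k, v = K[t:t + 1], V[t:t + 1]
        S = S + beta[t] * k.T @ (v - k @ S)
        O[t] = Q[t] @ S
    return O, S

def deltanet_chunked(Q, K, V, beta, C):
    """Chunked form using W, U from the WY representation."""
    L, d = Q.shape
    assert L % C == 0, "this sketch assumes L is a multiple of C"
    S = np.zeros((d, d))
    O = np.zeros((L, d))
    mask = np.tril(np.ones((C, C)))  # causal mask, diagonal included
    for s in range(0, L, C):
        Qc, Kc, Vc, bc = Q[s:s + C], K[s:s + C], V[s:s + C], beta[s:s + C]
        # Auxiliary matrices
        A = -np.tril(bc[:, None] * (Kc @ Kc.T), -1)
        T = np.linalg.solve(np.eye(C) - A, np.diag(bc))  # (I - A)^{-1} diag(beta)
        W, U = T @ Kc, T @ Vc
        # Main computation
        delta_v = U - W @ S                               # delta-corrected values
        O[s:s + C] = Qc @ S + ((Qc @ Kc.T) * mask) @ delta_v
        S = S + Kc.T @ delta_v                            # state passed to the next chunk
    return O, S

rng = np.random.default_rng(0)
L, d, C = 16, 8, 4
Q = rng.standard_normal((L, d))
K = rng.standard_normal((L, d))
K /= np.linalg.norm(K, axis=1, keepdims=True)
V = rng.standard_normal((L, d))
beta = rng.uniform(0.0, 1.0, size=L)

O_ref, S_ref = deltanet_recurrent(Q, K, V, beta)
O_chunk, S_chunk = deltanet_chunked(Q, K, V, beta, C)
print(np.allclose(O_ref, O_chunk), np.allclose(S_ref, S_chunk))
```

Compared with the chunked linear attention loop, the only additions are the auxiliary matrices and the replacement of $V$ by the delta-corrected values.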

Wrap Up

We’ve gone through the mathematical formulation of DeltaNet, and derived the parallel form of DeltaNet using chunking and WY representation.

References