Ball's Blog

Multihead Latent Attention (MLA)


Summary

Since “Attention Is All You Need” was published in 2017, the attention mechanism has been widely used in AI models. However, the original attention mechanism has a notable limitation: the KV cache grows with the sequence length. To overcome this issue, DeepSeek-V2 introduces Multihead Latent Attention (MLA), a technique that compresses the KV cache into a smaller latent space and also allows more efficient attention computation.

Terminology

$x_t \in \mathbb{R}^{D}$: Hidden state for the $t$-th token

$D$: Dimension of the hidden state $x_t$

$K_t \in \mathbb{R}^{D}$: Key for the $t$-th token

$V_t \in \mathbb{R}^{D}$: Value for the $t$-th token

$Q_t \in \mathbb{R}^{D}$: Query for the $t$-th token

$r_{kv}$: Dimension of the latent representation for the KV cache

$c_t^{KV} \in \mathbb{R}^{r_{kv}}$: Latent representation of the KV cache for the $t$-th token. In other words, this is the compressed representation of $K_t$ and $V_t$. It can be decompressed into $K_t$ and $V_t$ through linear transformations.

$W_D^{KV}$: Down projection matrix for the KV cache. This matrix transforms $x_t$ into $c_t^{KV}$.

$W_U^K$: Up projection matrix for the Key. This matrix transforms $c_t^{KV}$ into $K_t$.

$W_U^V$: Up projection matrix for the Value. This matrix transforms $c_t^{KV}$ into $V_t$.

$r_q$: Dimension of the latent representation for the Query

$c_t^Q \in \mathbb{R}^{r_q}$: Latent representation of the Query for the $t$-th token. This is the compressed representation of $Q_t$. It can be decompressed into $Q_t$ through a linear transformation.

$W_D^Q$: Down projection matrix for the Query. This matrix transforms $x_t$ into $c_t^Q$.

$W_U^Q$: Up projection matrix for the Query. This matrix transforms $c_t^Q$ into $Q_t$.

$D_{position}$: Dimension of the position part of the Query and Key.

$q_{t,position} \in \mathbb{R}^{D_{position}}$: Position part of the Query for the $t$-th token

$k_{t,position} \in \mathbb{R}^{D_{position}}$: Position part of the Key for the $t$-th token

$W_{UR}^Q$: Up projection matrix for the position part of the Query. This matrix transforms $c_t^Q$ into $q_{t,position}$.

$W_{UR}^K$: Up projection matrix for the position part of the Key. This matrix transforms $c_t^{KV}$ into $k_{t,position}$.

Naive Implementation

The following figure illustrates the naive structure of MLA:

[Figure: naive MLA structure]

This implementation still benefits from the compression of the KV cache, so the memory pressure during decoding is reduced. However, the computational cost is high because every latent vector ($c_t^{KV}$ and $c_t^Q$) has to be up-projected into its high-dimensional representation before attention is computed.
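As a rough illustration of the naive path, the following pure-Python sketch caches only the latents and up-projects all of them at every decoding step. The toy sizes, the random weights, and the helper names (`matmul`, `transpose`, `rand_matrix`, `shape`) are assumptions made for this example. Attention heads and softmax are omitted, and a row-vector convention ($y = xW$) is assumed.

```python
import random

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    return [list(col) for col in zip(*m)]

def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

def shape(m):
    return (len(m), len(m[0]))

rng = random.Random(0)
D, r_kv, r_q = 8, 2, 3  # toy sizes (assumed); real models are far larger
T = 5                   # number of tokens seen so far

# Projection matrices, named after the terminology section.
W_D_KV = rand_matrix(D, r_kv, rng)  # x_t -> c_t^{KV}
W_U_K = rand_matrix(r_kv, D, rng)   # c_t^{KV} -> K_t
W_U_V = rand_matrix(r_kv, D, rng)   # c_t^{KV} -> V_t
W_D_Q = rand_matrix(D, r_q, rng)    # x_t -> c_t^Q
W_U_Q = rand_matrix(r_q, D, rng)    # c_t^Q -> Q_t

x = rand_matrix(T, D, rng)          # hidden states, one row per token

# Only the latents are cached: T * r_kv numbers instead of 2 * T * D.
kv_cache = matmul(x, W_D_KV)        # (T, r_kv)

# Decoding step for the last token: every cached latent is up-projected again.
c_q = matmul([x[-1]], W_D_Q)        # (1, r_q)
Q = matmul(c_q, W_U_Q)              # (1, D)
K = matmul(kv_cache, W_U_K)         # (T, D)
V = matmul(kv_cache, W_U_V)         # (T, D)
scores = matmul(Q, transpose(K))    # (1, T), softmax omitted

print(shape(kv_cache), shape(K), shape(V), shape(scores))
```

The cache holds only the `(T, r_kv)` latents, but `K` and `V` are rebuilt at full width `D` for all `T` tokens on every step, which is the cost the optimized implementation tries to avoid.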

Optimized Implementation

DeepSeek-V2 introduces a fused implementation of MLA that reduces the computational cost by fusing the up projections into the attention computation. The following figure illustrates the fused structure of MLA:

[Figure: optimized MLA structure]

In this implementation, we calculate and store $W_U^Q {W_U^K}^T$ in advance.

In the naive implementation, calculating the attention score ($QK^T$) takes the following steps:

  1. Up project $c_{T'..T-1}^Q$ into $Q_{T'..T-1}$ using $W_U^Q$
  2. Up project $c_{0..T-1}^{KV}$ into $K_{0..T-1}$ using $W_U^K$
  3. Calculate the attention score using $Q_{T'..T-1}$ and $K_{0..T-1}$

In the optimized implementation, however, the up projection is fused into the attention computation. This takes the following steps:

  1. Ahead of time (statically), calculate and store $W_U^Q {W_U^K}^T$
  2. Calculate the attention score using $c_{T'..T-1}^Q$, $c_{0..T-1}^{KV}$, and $W_U^Q {W_U^K}^T$: $QK^T = c_{T'..T-1}^Q (W_U^Q {W_U^K}^T) {c_{0..T-1}^{KV}}^T$
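The two procedures above can be compared numerically. The sketch below is a minimal check under assumed toy sizes, using small integer matrices so that both paths can be compared with exact equality. The helper names are hypothetical, heads and softmax are left out, and a row-vector convention ($Q = c^Q W_U^Q$, $K = c^{KV} W_U^K$) is assumed.

```python
import random

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    return [list(col) for col in zip(*m)]

rng = random.Random(0)

def rand_int_matrix(rows, cols):
    # Integer entries keep the arithmetic exact, so == is a valid comparison.
    return [[rng.randint(-3, 3) for _ in range(cols)] for _ in range(rows)]

D, r_q, r_kv = 8, 3, 2   # toy sizes (assumed)
n_new, n_ctx = 2, 5      # query tokens T'..T-1 and cached tokens 0..T-1

W_U_Q = rand_int_matrix(r_q, D)
W_U_K = rand_int_matrix(r_kv, D)
c_q = rand_int_matrix(n_new, r_q)
c_kv = rand_int_matrix(n_ctx, r_kv)

# Naive path: up-project both sides, then take the dot products.
Q = matmul(c_q, W_U_Q)                    # (n_new, D)
K = matmul(c_kv, W_U_K)                   # (n_ctx, D)
scores_naive = matmul(Q, transpose(K))    # (n_new, n_ctx)

# Fused path: the weight product is computed once, ahead of time.
W_QK = matmul(W_U_Q, transpose(W_U_K))    # (r_q, r_kv)
scores_fused = matmul(matmul(c_q, W_QK), transpose(c_kv))

print(scores_naive == scores_fused)
```

Both paths give the same scores because matrix multiplication is associative. The fused path never materializes the `D`-dimensional `Q` and `K`.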

Positional Embedding for MLA

Rotary positional embedding (RoPE) is widely used in attention. MLA also uses RoPE, but quite differently from the standard approach. If RoPE were applied to $Q$ and $K$ directly, the fused implementation could no longer be used: the RoPE matrix is dynamic (it changes with the position), so it would sit between $W_U^Q$ and ${W_U^K}^T$ and their product could not be precomputed. To solve this issue, MLA separates the Query and Key into two parts: the content part and the position part.

$$QK^T = [q_{content}, q_{position}] \begin{bmatrix} k_{content}^T \\ k_{position}^T \end{bmatrix} = q_{content}k_{content}^T + q_{position}k_{position}^T$$
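As a quick numerical check of this identity, the dot product of the concatenated vectors equals the sum of the per-part dot products. The vectors below are made-up values for a single query and key.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Hypothetical content and position parts of one query and one key.
q_content, q_position = [1, 2, 3], [4, 5]
k_content, k_position = [6, 7, 8], [9, 10]

# Score from the concatenated vectors [content, position].
full = dot(q_content + q_position, k_content + k_position)

# Score from the two parts computed separately and then added.
split = dot(q_content, k_content) + dot(q_position, k_position)

print(full, split)  # 130 130
```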

The position parts of the query and key are calculated using the rotary positional embedding:

$$q_{position} = \mathrm{RoPE}(c_{T'..T-1}^Q W_{UR}^Q)$$

$$k_{position} = \mathrm{RoPE}(c_{0..T-1}^{KV} W_{UR}^K)$$

The content part and the position part are calculated in parallel and then added together to get the final attention score. To keep the calculation time of the position part in line with that of the content part, the dimension of the position part ($D_{position}$) is kept small.
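The decoupled score can be sketched as follows. The `rope` function below uses the common interleaved-pair formulation with base 10000, which is an assumption here (implementations differ in how they pair dimensions). The vectors are made-up values, and `score` is a hypothetical helper rather than part of any library.

```python
import math

def rope(vec, pos, base=10000.0):
    """Rotate each pair (vec[i], vec[i+1]) by an angle proportional to pos."""
    d = len(vec)
    out = []
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out += [x * c - y * s, x * s + y * c]
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Hypothetical content and position parts (D_position = 4 here).
q_content, k_content = [0.5, -1.0, 2.0], [1.0, 0.25, -0.5]
q_pos, k_pos = [1.0, 2.0, -1.0, 0.5], [0.5, -1.5, 2.0, 1.0]

def score(m, n):
    """Attention logit between a query at position m and a key at position n."""
    content = dot(q_content, k_content)              # position independent
    position = dot(rope(q_pos, m), rope(k_pos, n))   # depends on m and n
    return content + position

# The rotation changes with the position, so it is not a static weight.
rotated_differs = rope(q_pos, 1) != rope(q_pos, 2)

# Pairs with the same offset m - n give the same score (up to rounding).
same_offset = math.isclose(score(5, 2), score(13, 10), abs_tol=1e-9)

print(rotated_differs, same_offset)
```

Only the position term goes through the rotation, so the content term can still use a precomputed weight product.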

Benefits of MLA

MLA has several benefits compared to the original attention mechanism:

  1. Reduced KV Cache size: By compressing the KV cache into a smaller latent space, MLA significantly reduces the memory pressure during decoding. This allows for longer sequences to be processed without running out of memory.
  2. Efficient Attention Computation: The fused implementation of MLA reduces the computational cost by fusing the up projections into the attention computation. This allows for faster inference without sacrificing model quality, because the fused form is mathematically equivalent to the naive one. Specifically, the precomputed matrix $W_U^Q {W_U^K}^T \in \mathbb{R}^{r_q \times r_{kv}}$ is small, so the attention scores can be computed directly from the low-dimensional latent vectors.
  3. Orthogonal to GQA: MLA can be combined with GQA to further reduce the computational cost. GQA reduces the number of key/value heads by sharing them across groups of query heads, while MLA reduces the dimension of the KV cache. These two techniques can be used together to achieve even more efficient attention computation.
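To put the first benefit in numbers, here is a back-of-the-envelope estimate of the cache size. All values (`D`, `r_kv`, layer count, sequence length, 2-byte elements) are illustrative assumptions, not the configuration of any particular model. The estimate counts only the latent $c_t^{KV}$ and ignores anything else an implementation may cache.

```python
def cache_bytes(elems_per_token, num_tokens, num_layers, bytes_per_elem=2):
    """Total cache size when each token stores elems_per_token values per layer."""
    return elems_per_token * num_tokens * num_layers * bytes_per_elem

# Illustrative, assumed sizes.
D, r_kv = 4096, 512
num_tokens, num_layers = 32768, 32

full_kv = cache_bytes(2 * D, num_tokens, num_layers)  # K_t and V_t per token
latent = cache_bytes(r_kv, num_tokens, num_layers)    # c_t^{KV} per token

GiB = 2 ** 30
print(f"full K/V cache: {full_kv / GiB:.1f} GiB")   # 16.0 GiB
print(f"latent cache:   {latent / GiB:.1f} GiB")    # 1.0 GiB
print(f"ratio:          {full_kv // latent}x")      # 16x
```

Under these assumptions the ratio is simply $2D / r_{kv}$, so the saving depends directly on how small the latent dimension is chosen.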