Shortened URL: **https://tinyurl.com/amd-comp-mla**

<aside> Problem 3: Decoding Multi-head Latent Attention (MLA) on a single MI300

The multi-head latent attention (MLA) operator introduced in DeepSeek-V2 is a modified version of the attention algorithm that down-projects the keys/values to store an efficient KV cache, then up-projects them for use in the attention operator. We will implement the inference-only decoding version, which includes the down-projected RoPE embeddings and down-projected Q values, without adding YaRN or other optimizations.

[💡Tip]: We will describe this layer in terms of modular functions to make it easier to digest. We also encourage building on techniques like FlashAttention when writing your kernel; you could even start by calling FlashAttention from PyTorch!
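
For instance, a minimal correctness baseline (not a tuned kernel) can lean on PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`. This is only a sketch: the sizes below are illustrative of the decoding setting described later, not the benchmark configuration.

```python
# A minimal sketch (not a tuned kernel): run the attention itself with PyTorch's
# fused scaled_dot_product_attention. Sizes are illustrative, not the benchmark config.
import torch
import torch.nn.functional as F

bs, n_heads, kv_len = 2, 8, 17          # kv_len = cached tokens + 1
d_qk, d_v = 192, 128                    # d_qk = nope + rope head dim, d_v = value head dim

q = torch.randn(bs, n_heads, 1, d_qk, dtype=torch.bfloat16)       # single decoded token
k = torch.randn(bs, n_heads, kv_len, d_qk, dtype=torch.bfloat16)  # cached keys + new key
v = torch.randn(bs, n_heads, kv_len, d_v, dtype=torch.bfloat16)   # cached values + new value

# Default scaling is 1/sqrt(d_qk), matching the MHA formula later in this write-up.
out = F.scaled_dot_product_attention(q, k, v)   # (bs, n_heads, 1, d_v)
```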


The MLA for Decoding, Informally Explained

The differences between MLA and standard MHA are quite friendly from an implementation standpoint: the major changes are in how the $\bold{Q},\bold{K},\bold{V}$ vectors are computed, while the attention computation itself is identical. There are two main changes to multi-head attention that we will implement.

  1. The first is that the input $\bold{X}$ is down-projected, separately for the queries and for the keys/values. The purpose is that our KV cache is stored as down-projected latents, so the new token must first be down-projected before it can be appended to the KV cache.
  2. The second is that the queries and keys split their head dimension into a part that gets a RoPE transformation and a part that does not. The way this is applied is simple but quite specific; reading either the code or the math below will be very helpful.

We are working in the decoding setting, where the input “sequence” is a single token. In other words, the query has sequence length 1, while the keys/values have sequence length (KV cache length + 1). We also require you to update and return the KV cache after this operation is applied; the updated cache is just the old KV cache with the KV latent of the current token appended.
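
A rough sketch of that cache contract follows (the tensor names are illustrative; the reference `KVCache` module below writes into a pre-allocated buffer instead of concatenating):

```python
# Rough sketch of the cache contract (names illustrative). Each decode step appends the
# down-projected latent of the one new token, so the returned cache covers s + 1 positions.
import torch

kv_lora_rank, rope_head_dim, s = 512, 64, 16
cache = torch.randn(1, s, kv_lora_rank + rope_head_dim, dtype=torch.bfloat16)       # old KV cache
new_latent = torch.randn(1, 1, kv_lora_rank + rope_head_dim, dtype=torch.bfloat16)  # current token

cache = torch.cat([cache, new_latent], dim=1)   # updated cache, sequence length s + 1
assert cache.shape[1] == s + 1
```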


The MLA, Formally Explained

Formally, given a sequence of token embeddings $\bold{X} \in \mathbb{R}^{N \times d}$, we compute

$$ \text{MLA}(\bold{X}) \triangleq \text{MHA} \left(q(\bold{X}), k(\bold{X}), v(\bold{X}) \right) $$

The difference between MLA and MHA is how these mappings $q(\cdot),k(\cdot),v(\cdot)$ are computed. Let $n_h$ be the number of attention heads and $d_{n_h}$ the per-head hidden dimension (in DeepSeek-V3 this is a separate hyperparameter rather than $d_{model} / n_h$). This will be important later for the shape of the $\bold{Q},\bold{K},\bold{V}$ vectors as they are passed into the multi-head attention operator. By definition (excluding the batch dimension),

$$ q(\bold{X}), k(\bold{X}), v(\bold{X}) \quad \text{ all have shape } (...,N,n_h,d_{n_h}) $$

(with the caveat that the queries and keys carry an extra $d_{rope}$-dimensional RoPE slice on the last dimension, and the keys/values have sequence length $s+1$ rather than $N$; see Step 4).
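
For concreteness, plugging in the DeepSeek-V3 sizes from the Problem Constraints section, together with $d_{nope} = 128$ (the qk_nope_head_dim used by the reference Config, not listed in the constraints), the decode-time shapes work out to:

$$ \begin{align*} q(\bold{X}) &\in \mathbb{R}^{bs \times 1 \times 128 \times 192}, \\ k(\bold{X}) &\in \mathbb{R}^{bs \times (s+1) \times 128 \times 192}, \\ v(\bold{X}) &\in \mathbb{R}^{bs \times (s+1) \times 128 \times 128}, \end{align*} $$

since $d_{nope} + d_{rope} = 128 + 64 = 192$ and $d_v = 128$.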

We will write this out in the order of how $\bold{X}$ is transformed to make it easier for the reader. In our kernel, $N=1$ because we are writing a decoding kernel.

[Step 1. Down-project to Low Rank and add KV Cache] We first down-project (without biases) and then append the result to the stored KV cache latents. Remember that the purpose of down-projecting is to store a low-rank KV latent:

$$ \begin{align*} \bold{D}_{Q} &\triangleq \bold{W}^{q}_{\text{down}} \bold{X} \in \mathbb{R}^{1 \times d_{q}} \\ \bold{D}^{(1)}_{KV} &\triangleq \bold{W}^{kv}_{\text{down}} \bold{X} \in \mathbb{R}^{1 \times (d_{kv} + d_{rope})} \end{align*} $$

Here $d_q$ is the query LoRA rank, while $d_{kv}$ and $d_{rope}$ are the KV LoRA rank and the RoPE head dimension (q_lora_rank, kv_lora_rank, and qk_rope_head_dim in the code). Note that these latents are shared across all $n_h$ heads; this is what makes the cache compact.

We then concatenate the stored KV cache latents $\bold{D}_{KV,\text{cached}} \in \mathbb{R}^{s \times (d_{kv} + d_{rope})}$ with the new token's latent:

$$ \bold{D}_{KV} \triangleq \bold{D}_{KV,\text{cached}} :: \bold{D}^{(1)}_{KV} \in \mathbb{R}^{(s+1) \times (d_{kv} + d_{rope})} $$

In this problem, you will also have to return the stored KV cache, so make sure to update it. These latents are then transformed by both an up-projection and a RoPE projection, as described below; in other words, steps 2 and 3 act on $\bold{D}_Q, \bold{D}_{KV}$ in parallel, independent streams.
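
This is where the memory savings come from. Using the DeepSeek-V3 sizes from the Problem Constraints (and $d_{nope} = 128$, $d_v = 128$ from the reference config, which are not listed there), each cached token stores only the shared latent instead of full per-head keys and values:

$$ \underbrace{d_{kv} + d_{rope}}_{\text{cached per token}} = 512 + 64 = 576 \qquad \text{vs.} \qquad \underbrace{n_h (d_{nope} + d_{rope}) + n_h d_v}_{\text{full K and V per token}} = 128 \cdot 192 + 128 \cdot 128 = 40960, $$

roughly a 71× reduction per token.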

[Step 2. (NoPE + Prepare RoPE) Up-projecting the Low-rank vectors] These down-projected embeddings are normally used as the KV cache, but in this problem we will not require you to explicitly store them. We continue by up-projecting these vectors as follows:

$$ \bold{Q}_{\text{nope}} :: \bold{Q}_{\text{rope}} \triangleq \bold{W}^{q}_{\text{up}} \bold{D}_{Q} \in \mathbb{R}^{1 \times ((d_{nope}+d_{rope}) * n_h)} $$

The keys/values are transformed slightly differently. First, the values $\bold{V}$ do not have a RoPE stream. Second, due to how we notated everything, $\bold{K}_{\text{rope}}$ does not get up-projected; this is because we tried to use as few weight matrices as possible. We first split our KV latent as:

$$ \bold{D}_{kv,\text{nope}} :: \bold{K}_{\text{rope}} = \bold{D}_{KV} $$

We then compute the nope embeddings for $\bold{K},\bold{V}$ as follows:

$$ \bold{K}_{\text{nope}} :: \bold{V} \triangleq \bold{W}^{kv}_{\text{up}} \bold{D}_{kv,\text{nope}} \in \mathbb{R}^{(s+1) \times ((d_{nope} + d_{v}) * n_h)} $$

If the notation above is confusing, :: is the concatenation operator, which we use to show that we will split the tensor along the last dimension to get two separate tensors. We do this so we only need one weight matrix, but it is equivalent to having separate nope and rope weight matrices instead of a single $\bold{W}_{\text{up}}$ (it is easy to verify this; a quick numerical check is sketched below). Each of these tensors is then reshaped for MHA into the form $\mathbb{R}^{bs \times N \times n_h \times d_{n_h}}$ (with $N = s+1$ for the keys/values).
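
Here is that sanity check as a small sketch with made-up sizes: splitting the output of a single up-projection gives the same result as applying the corresponding row blocks of the weight matrix separately.

```python
# Sanity check (made-up sizes): splitting the output of one up-projection equals applying
# the corresponding row blocks of the weight matrix separately.
import torch

torch.manual_seed(0)
d_latent, d_nope, d_rope = 32, 16, 8
W_up = torch.randn(d_nope + d_rope, d_latent)          # nn.Linear stores weights as (out, in)
x = torch.randn(4, d_latent)

out_nope, out_rope = (x @ W_up.T).split([d_nope, d_rope], dim=-1)
W_nope, W_rope = W_up.split([d_nope, d_rope], dim=0)   # row blocks of W_up

assert torch.allclose(out_nope, x @ W_nope.T)
assert torch.allclose(out_rope, x @ W_rope.T)
```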

[Step 3. (RoPE) Preparing RoPE Embeddings] MLA computes RoPE (Su et al., 2021) embeddings in a unique way, applying them only to part of each embedding. We apply RoPE to the $\bold{Q}_{\text{rope}}, \bold{K}_{\text{rope}}$ slices computed earlier.

We can then compute the RoPE embedding transformation (note that the query starts at position $s$, which is the size of the KV cache, i.e. the position of the token we are decoding):

$$ \begin{align*} \bold{R}_{Q} &\triangleq \text{RoPE}(\bold{Q}_{\text{rope}}, \text{pos}=s) \in \mathbb{R}^{1 \times (n_h * d_{rope})} \\ \bold{R}_{K} &\triangleq \text{RoPE}(\bold{K}_{\text{rope}}, \text{pos}=0,\dots,s) \in \mathbb{R}^{(s+1) \times (n_h * d_{rope})} \end{align*} $$

(The key RoPE slice is a single $d_{rope}$-dimensional vector per position, shared across heads; it is broadcast to all $n_h$ heads before being concatenated with $\bold{K}_{\text{nope}}$.)
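
To make the position handling concrete, here is a small self-contained sketch of this step. The helper below inlines the same duplicated-theta, rotate-half convention as the reference RoPE module further down; the tensor names and sizes are illustrative.

```python
# Illustration of the position offsets: the single query is rotated at its absolute
# position s, while the s+1 key slices are rotated at positions 0..s.
import torch

def rope(x: torch.Tensor, start_pos: int) -> torch.Tensor:
    d = x.size(-1)
    theta = 10000 ** (-torch.arange(0, d // 2, dtype=torch.float32) / (d // 2))
    pos = torch.arange(start_pos, start_pos + x.size(-2), dtype=torch.float32)
    angles = torch.cat([pos[:, None] * theta, pos[:, None] * theta], dim=-1)
    x1, x2 = x.chunk(2, dim=-1)
    return x * angles.cos() + torch.cat((-x2, x1), dim=-1) * angles.sin()

s, d_rope = 16, 64
q_rope = torch.randn(1, 1, d_rope)       # query RoPE slice, sequence length 1
k_rope = torch.randn(1, s + 1, d_rope)   # key RoPE slices for positions 0..s

q_rot = rope(q_rope, start_pos=s)        # query sits at absolute position s
k_rot = rope(k_rope, start_pos=0)        # keys cover positions 0..s
```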

[Step 4. Computing the final Q, K, V and MHA] Abusing notation again (here $\bold{S}_Q := \bold{Q}_{\text{nope}}$, $\bold{S}_K := \bold{K}_{\text{nope}}$, and $\bold{S}_V := \bold{V}$ from Step 2, while $\bold{R}_Q, \bold{R}_K$ come from Step 3; all are computed from $\bold{X}$), we concatenate the RoPE vectors with the up-projected variants to compute our final $q,k,v$:

$$ \begin{align*} q(\bold{X}) &\triangleq \text{concat}(\bold{S}_Q, \bold{R}_Q, \text{dim}=-1) \\ k(\bold{X}) &\triangleq \text{concat}(\bold{S}_K, \bold{R}_K, \text{dim}=-1) \\ v(\bold{X}) &\triangleq \bold{S}_V \\ \text{MLA}(\bold{X}) &\triangleq \text{MHA} \left(q(\bold{X}), k(\bold{X}), v(\bold{X}) \right) \end{align*} $$

Like in most standard MHA implementations, we then transpose these tensors to swap the sequence dimension and the head dimension. The final shapes are $(..., n_h, N, d_{n_h} + d_{\text{rope}})$ for $q(\bold{X}),k(\bold{X})$ and $(..., n_h, N, d_{n_h})$ for $v(\bold{X})$, where $N$ is 1 for the query and $s+1$ for the keys/values.
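
A quick sketch of that swap. The important detail is that it is a transpose of the sequence and head axes, not a flat reshape, so each head keeps its own rows:

```python
# The head/sequence swap is a transpose, not a flat reshape, so per-head rows stay aligned.
import torch

bs, kv_len, n_heads, head_dim = 2, 9, 4, 192
k = torch.randn(bs, kv_len, n_heads, head_dim)      # (bs, seq, heads, dim)

k_t = k.transpose(1, 2)                             # (bs, heads, seq, dim)
assert torch.equal(k_t[:, 3, 5], k[:, 5, 3])        # head 3 at position 5 is unchanged
```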


[A. If you forgot how MHA works] For completeness, the multi-head attention algorithm is defined as:

$$ \text{MHA}(\bold{Q},\bold{K},\bold{V}) = \text{softmax}\left(\frac{\bold{Q}\bold{K}^T}{\sqrt{d_k}}\right)\bold{V} $$

where all operations act on the sequence and last (hidden) dimensions; the head dimension effectively behaves like an extra batch dimension. Here, in the case of MLA, $d_k := d_{n_h}+ d_{\text{rope}}$.
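
Written out for the decode shapes, the formula above is only a few tensor ops. The sizes below are illustrative, and the value head dimension is allowed to differ from $d_k$:

```python
# The MHA formula above written out for the decode shapes (illustrative sizes); the head
# dimension behaves like an extra batch dimension, and d_v may differ from d_k.
import math
import torch
import torch.nn.functional as F

bs, n_heads, kv_len, d_k, d_v = 2, 4, 9, 192, 128
q = torch.randn(bs, n_heads, 1, d_k)
k = torch.randn(bs, n_heads, kv_len, d_k)
v = torch.randn(bs, n_heads, kv_len, d_v)

scores = q @ k.transpose(-1, -2) / math.sqrt(d_k)   # (bs, n_heads, 1, kv_len)
out = F.softmax(scores, dim=-1) @ v                 # (bs, n_heads, 1, d_v)
```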


All weight shapes can be inferred from the equations above based on the output shapes (none of the projections have biases), but if you're unsure, just refer to the code; it's a lot easier to read!

Remarks. I think the MHA / attention problem has been somewhat exhausted, which is fine, but it makes it hard to build a leaderboard problem out of it. It would be cool to see whether some minor, still practically useful adjustments to the algorithm can be optimized. DeepSeek in particular has proven that they can optimize their own algorithms quite well, but not many standard implementations exist online.


PyTorch Implementation of MLA Decoding

We modularize the KV cache, RoPE, and MLA modules for your convenience. All tensors and parameters here should be bfloat16.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class RoPE(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        theta = 10000 ** (-torch.arange(0, d_model//2, dtype=torch.bfloat16) / (d_model//2))
        self.register_buffer("theta", theta)

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
        seq_len = x.size(-2)
        d_model = x.size(-1)
        assert d_model == self.d_model
        seq_idx = torch.arange(start_pos, start_pos + seq_len, device=x.device)
        idx_theta = torch.einsum('s,d->sd', seq_idx, self.theta)
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=-1)
        cos = idx_theta2.cos().to(torch.bfloat16)
        sin = idx_theta2.sin().to(torch.bfloat16)
        return x * cos + self.rotate_half(x) * sin

class KVCache(nn.Module):
    def __init__(self, kv_cache_shape: tuple) -> None:
        super().__init__()
        self.register_buffer('data', torch.zeros(kv_cache_shape, dtype=torch.bfloat16))
        self.seq_len = 0
        self.zero()

    def zero(self) -> None:
        self.data.zero_()

    def forward(self, c_kv: torch.Tensor) -> torch.Tensor:
        assert self.seq_len + c_kv.size(1) <= self.data.size(1), "KV Cache Exceeded"

        self.data = self.data.to(c_kv.dtype)
        self.data[
            :, self.seq_len : self.seq_len + c_kv.size(1), :
        ] = c_kv
        self.seq_len += c_kv.size(1)

        return self.data[:, :self.seq_len], self.seq_len

# Config is expected to provide: dim, n_heads, q_lora_rank, kv_lora_rank,
# qk_nope_head_dim, qk_rope_head_dim, and v_head_dim.
class MLA(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.dim = config.dim
        self.n_heads = config.n_heads
        self.q_lora_rank = config.q_lora_rank
        self.kv_lora_rank = config.kv_lora_rank
        self.nope_head_dim = config.qk_nope_head_dim
        self.rope_head_dim = config.qk_rope_head_dim
        self.v_head_dim = config.v_head_dim
        # Down-projection matrices
        self.Q_proj_down = nn.Linear(self.dim, self.q_lora_rank, bias=False)
        self.KV_proj_down = nn.Linear(self.dim, self.kv_lora_rank + self.rope_head_dim, bias=False)

        # Up-projection and rope projection matrices
        self.Q_proj_up = nn.Linear(self.q_lora_rank, (self.nope_head_dim + self.rope_head_dim) * self.n_heads, bias=False)
        self.KV_proj_up = nn.Linear(self.kv_lora_rank, (self.nope_head_dim + self.v_head_dim) * self.n_heads, bias=False)

        # RoPE on half embeddings
        self.q_rope = RoPE(self.rope_head_dim)
        self.k_rope = RoPE(self.rope_head_dim)

        # Output projection
        self.wo = nn.Linear(self.v_head_dim * self.n_heads, self.dim, bias=False)
        self.eps = 1e-6
   
    def forward(self, x: torch.Tensor, kv_cache: KVCache) -> torch.Tensor:
        # seq_len = 1 always here
        batch_size, seq_len, model_dim = x.size()

        ################################################################################
        #                 Step 1: Handle down-projection + KV cache                    #
        ################################################################################
        q_lora = self.Q_proj_down(x)
        kv_lora = self.KV_proj_down(x)
        kv_lora, kv_len = kv_cache(kv_lora)
        query_pos = kv_len - 1

        ################################################################################
        #                  Step 2: Up-project and prepare NoPE + RoPE                  #
        ################################################################################

        # Handle queries Q first
        q_nope_and_rope = self.Q_proj_up(q_lora).view(
            batch_size, seq_len, self.n_heads, self.nope_head_dim + self.rope_head_dim)
        q_nope, q_rope = torch.split(q_nope_and_rope, [self.nope_head_dim, self.rope_head_dim], dim=-1)

        # Handle keys and values K/V. V does not need RoPE
        kv_nope, k_rope = torch.split(kv_lora, [self.kv_lora_rank, self.rope_head_dim], dim=-1)
        kv_nope = self.KV_proj_up(kv_nope).view(
            batch_size, kv_len, self.n_heads, self.nope_head_dim + self.v_head_dim)
        k_nope, v = torch.split(kv_nope, [self.nope_head_dim, self.v_head_dim], dim=-1)

        ################################################################################
        #                    Step 3: Handle RoPE Stream                                #
        ################################################################################

        # Compute RoPE for queries and combine with no-RoPE part
        q_rope = q_rope.reshape(batch_size, self.n_heads, seq_len, self.rope_head_dim)
        q_rope = self.q_rope(q_rope, start_pos=query_pos)
        q_rope = q_rope.reshape(batch_size, seq_len, self.n_heads, self.rope_head_dim)
        q = torch.concat([q_nope, q_rope], dim=-1)

        # Compute RoPE for keys and combine with no-RoPE part. The key RoPE slice is
        # shared across heads: rotate it per cached position (0..kv_len-1), then
        # broadcast it to every head.
        k_rope = self.k_rope(k_rope)                        # (bs, kv_len, rope_head_dim)
        k_rope = k_rope[:, :, None, :].expand(-1, -1, self.n_heads, -1)
        k = torch.concat([k_nope, k_rope], dim=-1)
                
        ################################################################################
        #                        Compute Multi-head Attention                          #
        ################################################################################
        # Swap the sequence and head dimensions; a transpose (rather than a flat reshape)
        # keeps each head's data aligned with the right positions.
        q = q.transpose(1, 2)   # (bs, n_heads, 1, nope_head_dim + rope_head_dim)
        k = k.transpose(1, 2)   # (bs, n_heads, kv_len, nope_head_dim + rope_head_dim)
        v = v.transpose(1, 2)   # (bs, n_heads, kv_len, v_head_dim)
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.rope_head_dim + self.nope_head_dim)
        attn = F.softmax(scores, dim=-1).to(torch.bfloat16)
        y = torch.matmul(attn, v).view(batch_size, -1)
        y = self.wo(y)

        return y, kv_cache
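
For reference, here is a hypothetical end-to-end usage sketch, assuming the classes above are in scope. Config is not defined in this excerpt, so a minimal dataclass is assumed; q_lora_rank = 1536 and qk_nope_head_dim = 128 follow DeepSeek-V3 and are not part of the constraint list below. The batch size and prefill length are kept small here; the benchmark values are listed under Problem Constraints.

```python
# Hypothetical usage sketch; Config below is an assumed stand-in, not the official one.
from dataclasses import dataclass

import torch


@dataclass
class Config:
    dim: int = 7168
    n_heads: int = 128
    q_lora_rank: int = 1536          # assumption: DeepSeek-V3 value, not in the constraint list
    kv_lora_rank: int = 512
    qk_nope_head_dim: int = 128      # assumption: DeepSeek-V3 value, not in the constraint list
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128


config = Config()
bs, prefill = 4, 64                  # benchmark uses bs=128 and prefill in {512, 2048, 4096, 6144}

model = MLA(config).to(torch.bfloat16)
kv_cache = KVCache((bs, prefill + 1, config.kv_lora_rank + config.qk_rope_head_dim))

# Pretend the prefill latents are already in the cache.
kv_cache(torch.randn(bs, prefill, config.kv_lora_rank + config.qk_rope_head_dim, dtype=torch.bfloat16))

x = torch.randn(bs, 1, config.dim, dtype=torch.bfloat16)  # the single token being decoded
y, kv_cache = model(x, kv_cache)                          # y has shape (bs, dim)
```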

Problem Constraints:

We follow sizes from DeepSeek-V3 / R1 and primarily vary the sequence length of the KV cache through:

  1. bs::[128] # batch size
  2. prefill::[512, 2048, 4096, 6144] # KV cache length
  3. sq::[1] # we only consider decoding
  4. dim::7168 # hidden size of DeepSeek-V3
  5. kv_lora_rank::[512] # KV LoRA rank of DeepSeek-V3
  6. qk_rope_head_dim::[64] # RoPE embedding dimension per head
  7. v_head_dim::[128] # dim of V embeddings per head
  8. n_heads::[128] # number of attention heads

</aside>