Shortened URL: **https://tinyurl.com/amd-comp-mla**
<aside> Problem 3: Decoding Multi-head Latent Attention (MLA) on a single MI300
The multi-head latent attention (MLA) operator introduced in DeepSeek-V2 is a modified version of the attention algorithm that down-projects the keys / values to store an efficient KV-cache, then up-projects them for use in the attention operator. We will be implementing the inference-only version for decoding that includes the down-projected RoPE embeddings and down-projected Q values without adding YaRN or other optimizations.
[💡Tip]: We will describe this layer in terms of modular functions to make it easier to digest, and we encourage drawing on techniques like Flash Attention when building your kernel. You could even start by calling FlashAttention in PyTorch!
All inputs and outputs are in the torch.bfloat16 format. The differences between MLA and standard MHA attention are pretty nice from an implementation standpoint: the major differences come in how the $\bold{Q},\bold{K},\bold{V}$ vectors are computed, but the attention computation itself is identical. There are two main changes to multi-head attention that we will implement.
We are working in the decoding setting, where the input “sequence” is a single token. In other words, the query will have sequence length 1, but the keys / values will have sequence length equal to the KV cache length + 1. We also require you to update and return the KV cache after this operation is applied; the updated cache is just the old KV cache with the KV latent for the current token appended.
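As a concrete starting point (in the spirit of the tip above), here is a minimal sketch of the decode-shaped attention call using PyTorch's scaled_dot_product_attention as a stand-in for Flash Attention; the shapes below are illustrative, not the competition sizes:

```python
import torch
import torch.nn.functional as F

# Decode step: the query covers a single new token, keys/values cover the
# cached tokens plus that token. Sizes here are toy values.
bs, n_h, d_head, s = 2, 8, 64, 16          # batch, heads, head dim, KV cache length
q = torch.randn(bs, n_h, 1, d_head)        # N = 1
k = torch.randn(bs, n_h, s + 1, d_head)    # cache + current token
v = torch.randn(bs, n_h, s + 1, d_head)

# No causal mask is needed: the single query may attend to every cached position.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 8, 1, 64])
```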
Formally, given a sequence of token embeddings $\bold{X} \in \mathbb{R}^{N \times d}$, we compute
$$ \text{MLA}(\bold{X}) \triangleq \text{MHA} \left(q(\bold{X}), k(\bold{X}), v(\bold{X}) \right) $$
The difference between MLA and MHA is how these mappings $q(\cdot),k(\cdot),v(\cdot)$ are computed. Let $n_h$ be the number of attention heads and $d_{n_h} := d_{model} / n_h$. This will be important later for the shape of the $\bold{Q},\bold{K},\bold{V}$ vectors as they are passed into the multi-head attention operator. By definition (excluding batch dimension),
$$ q(\bold{X}), k(\bold{X}), v(\bold{X}) \quad \text{ all have shape } (...,N,n_h,d_{n_h}) $$
We will write this out in the order of how $\bold{X}$ is transformed to make it easier for the reader. In our kernel, $N=1$ because we are writing a decoding kernel.
[Step 1. Down-project to Low Rank and add KV Cache] We first down-project (without biases) and add the stored KV cache latents. Remember that the purpose of down-projecting is to store a low rank KV latent:
$$ \begin{align*} \bold{D}_{Q} &\triangleq \bold{W}^{q}_{\text{down}} \bold{X} \in \mathbb{R}^{1 \times (d_{q} * n_h)} \\ \bold{D}^{(1)}_{KV} &\triangleq \bold{W}^{kv}_{\text{down}} \bold{X} \in \mathbb{R}^{1 \times ((d_{kv} + d_{rope}) * n_h)} \end{align*} $$
We then concatenate the stored KV cache embeddings $\bold{D}_{KV\text{-cached}} \in \mathbb{R}^{s \times ((d_{kv} + d_{rope}) * n_h)}$:
$$ \bold{D}_{KV} \triangleq \bold{D}_{KV\text{-cached}} :: \bold{D}^{(1)}_{KV} \in \mathbb{R}^{(s+1) \times ((d_{kv} + d_{rope}) * n_h)} $$
In this problem, you will also have to return the stored KV cache, so make sure to update it. These vectors will be transformed by both an up-projection and a RoPE projection, which you will see below. In other words, steps 2 & 3 are executed in parallel streams w.r.t. $\bold{D}_Q, \bold{D}_{KV}$.
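To make Step 1 concrete, here is a minimal sketch under assumed toy sizes (the names and dimensions are placeholders, not the harness's):

```python
import torch
import torch.nn as nn

# Illustrative sizes only; d_q_total / d_kv_total stand in for the projection widths.
bs, d_model, d_q_total, d_kv_total, s = 2, 128, 48, 40, 16

W_q_down = nn.Linear(d_model, d_q_total, bias=False, dtype=torch.bfloat16)
W_kv_down = nn.Linear(d_model, d_kv_total, bias=False, dtype=torch.bfloat16)

x = torch.randn(bs, 1, d_model, dtype=torch.bfloat16)   # one decoded token
d_q = W_q_down(x)                                       # (bs, 1, d_q_total)
d_kv_new = W_kv_down(x)                                 # (bs, 1, d_kv_total)

# Concatenate with the cached latents along the sequence dimension; the updated
# cache is what gets handed back to the caller.
d_kv_cached = torch.zeros(bs, s, d_kv_total, dtype=torch.bfloat16)   # stands in for the stored cache
d_kv = torch.cat([d_kv_cached, d_kv_new], dim=1)                     # (bs, s + 1, d_kv_total)
```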
[Step 2. (NoPE + Prepare RoPE) Up-projecting the Low-rank vectors] These down-projected embeddings are what is normally stored as the KV cache; in this problem you only need to update and return the cache object you are given rather than manage its storage yourself. We continue by up-projecting these vectors as follows:
$$ \begin{align*} \bold{Q}_{\text{nope}} :: \bold{Q}_{\text{rope}} &\triangleq \bold{W}^{q}_{\text{up}} \bold{D}_{Q} \in \mathbb{R}^{1 \times ((d_{nope} + d_{rope}) * n_h)} \end{align*} $$
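A sketch of this up-projection and split under assumed toy sizes (all names and dimensions here are illustrative):

```python
import torch
import torch.nn as nn

# Illustrative sizes; d_q is the query latent width, d_nope/d_rope the per-head widths.
bs, d_q, n_h, d_nope, d_rope = 2, 48, 4, 32, 16

W_q_up = nn.Linear(d_q, (d_nope + d_rope) * n_h, bias=False, dtype=torch.bfloat16)
d_q_latent = torch.randn(bs, 1, d_q, dtype=torch.bfloat16)

q_up = W_q_up(d_q_latent).view(bs, 1, n_h, d_nope + d_rope)    # per-head layout
q_nope, q_rope = torch.split(q_up, [d_nope, d_rope], dim=-1)   # the "::" split
print(q_nope.shape, q_rope.shape)   # (2, 1, 4, 32) and (2, 1, 4, 16)
```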
The keys / values are transformed slightly differently. Firstly, the values $\bold{V}$ do not have a RoPE stream. Secondly, $\bold{K}_{\text{rope}}$ does not get up-projected; this is an artifact of our notation, which tries to use as few weight matrices as possible. We first split up our KV as:
$$ \bold{D}_{KV\text{-nope}} :: \bold{K}_{\text{rope}} = \bold{D}_{KV} $$
We then compute the nope embeddings for $\bold{K},\bold{V}$ as follows:
$$ \begin{align*} \bold{K}_{\text{nope}} :: \bold{V}_{\text{nope}} &\triangleq \bold{W}^{kv}_{\text{up}} \bold{D}_{KV\text{-nope}} \in \mathbb{R}^{(s+1) \times ((d_{k} + d_{v}) * n_h)} \end{align*} $$
If the notation above is confusing, $::$ is the concatenation operator, which we use to show that we will split the tensor along the last dimension to get two separate tensors. We do this so we only need one weight matrix, but it is equivalent to having separate nope and rope weight matrices instead of a single $\bold{W}_{up}$ (it is easy to verify this; see the sketch below). Each of these tensors is then reshaped for MHA into the form $\mathbb{R}^{bs \times N \times n_h \times d_{n_h}}$.
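Here is a quick numerical check of that equivalence (toy sizes), comparing one fused projection whose output is split against two separate projections built from the corresponding row blocks of the fused weight:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, d_a, d_b = 8, 6, 4

W_fused = nn.Linear(d_in, d_a + d_b, bias=False)
W_a = nn.Linear(d_in, d_a, bias=False)
W_b = nn.Linear(d_in, d_b, bias=False)
with torch.no_grad():
    W_a.weight.copy_(W_fused.weight[:d_a])   # rows of the fused weight = output slices
    W_b.weight.copy_(W_fused.weight[d_a:])

x = torch.randn(3, d_in)
a, b = torch.split(W_fused(x), [d_a, d_b], dim=-1)
assert torch.allclose(a, W_a(x)) and torch.allclose(b, W_b(x))
```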
[Step 3. (RoPE) Preparing RoPE Embeddings] MLA computes RoPE (Su et al., 2021) embeddings in a unique way, applying them only to part of each embedding. We apply this to the RoPE streams we computed earlier, $\bold{Q}_{\text{rope}}$ and $\bold{K}_{\text{rope}}$.
We can then compute the RoPE embedding transformation (note that the query starts at position $s$, which is the size of the KV cache, i.e. the position of the token we are decoding):
$$ \begin{align*} \bold{R}_{Q} &\triangleq \text{RoPE}(\bold{Q}_{\text{rope}}, \text{pos}=s) \in \mathbb{R}^{1 \times (n_h * d_{rotate})} \\ \bold{R}_{K} &\triangleq \text{RoPE}(\bold{K}_{\text{rope}}) \in \mathbb{R}^{(s+1) \times (n_h * d_{rotate})} \end{align*} $$
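A minimal sketch of this partial-RoPE step, using a standalone helper that mirrors the rotate-half convention of the RoPE module given below (helper name and sizes are illustrative):

```python
import torch

def rope_rotate(x: torch.Tensor, pos: int, base: float = 10000.0) -> torch.Tensor:
    """Rotate the last dimension of x as if it sat at absolute position `pos`."""
    d = x.size(-1)
    theta = base ** (-torch.arange(0, d // 2, dtype=torch.float32) / (d // 2))
    angles = pos * torch.cat([theta, theta])        # (d,)
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat([-x2, x1], dim=-1)
    return x * angles.cos() + rotated * angles.sin()

# Only the d_rope slice of each query head is rotated; the nope part is untouched.
n_h, d_rope, s = 4, 16, 10
q_rope = torch.randn(n_h, 1, d_rope)
r_q = rope_rotate(q_rope, pos=s)    # the decoded token sits at position s
print(r_q.shape)                    # torch.Size([4, 1, 16])
```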
[Step 4. Computing the final Q, K, V and MHA] Again abusing notation (here $\bold{S}_Q, \bold{S}_K, \bold{S}_V$ denote the nope embeddings $\bold{Q}_{\text{nope}}, \bold{K}_{\text{nope}}, \bold{V}_{\text{nope}}$ and $\bold{R}_Q, \bold{R}_K$ the RoPE embeddings, all computed from $\bold{X}$), we concatenate the RoPE vectors with the up-projected variants to compute our final $q,k,v$:
$$ \begin{align*} q(\bold{X}) &\triangleq \text{concat}(\bold{S}_Q, \bold{R}_Q, \text{dim}=-1) \\ k(\bold{X}) &\triangleq \text{concat}(\bold{S}_K, \bold{R}_K, \text{dim}=-1) \\ v(\bold{X}) &\triangleq \bold{S}_V \\ \text{MLA}(\bold{X}) &\triangleq \text{MHA} \left(q(\bold{X}), k(\bold{X}), v(\bold{X}) \right) \end{align*} $$
Like in most standard MHA implementations, we have to transpose these matrices to swap the sequence dimension and the head dimension. So the final shapes should be $(..., n_h, N, d_{n_h} + d_{\text{rotate}})$ for $q(\bold{X}),k(\bold{X})$ and $(..., n_h, N, d_{n_h})$ for $v(\bold{X})$.
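Putting Step 4 together under toy sizes (here S_* and R_* are stand-ins for the nope and RoPE parts computed above):

```python
import torch

bs, N, n_h, d_nope, d_rot, kv_len, d_v = 2, 1, 4, 32, 16, 11, 32
S_Q, R_Q = torch.randn(bs, N, n_h, d_nope), torch.randn(bs, N, n_h, d_rot)
S_K, R_K = torch.randn(bs, kv_len, n_h, d_nope), torch.randn(bs, kv_len, n_h, d_rot)
S_V = torch.randn(bs, kv_len, n_h, d_v)

# Concatenate nope and RoPE parts per head, then swap the sequence and head dims.
q = torch.cat([S_Q, R_Q], dim=-1).transpose(1, 2)   # (bs, n_h, N, d_nope + d_rot)
k = torch.cat([S_K, R_K], dim=-1).transpose(1, 2)   # (bs, n_h, kv_len, d_nope + d_rot)
v = S_V.transpose(1, 2)                             # (bs, n_h, kv_len, d_v)
```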
[A. If you forgot how MHA works] For completeness' sake, the multi-head attention algorithm is defined as:
$$ \text{MHA}(\bold{Q},\bold{K},\bold{V}) = \text{softmax}\left(\frac{\bold{Q}\bold{K}^T}{\sqrt{d_k}}\right)\bold{V} $$
where all operations act on the sequence and last (hidden) dimensions, i.e. the head dimension effectively behaves like an extra batch dimension. Here, in the case of MLA, $d_k := d_{n_h} + d_{\text{rotate}}$.
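For example, with the head dimension treated as part of the batch, the decode-time MHA reduces to two batched matmuls and a softmax (toy sizes):

```python
import math
import torch

bs, n_h, kv_len, d_k, d_v = 2, 4, 16, 48, 32
q = torch.randn(bs, n_h, 1, d_k)        # single decoded token
k = torch.randn(bs, n_h, kv_len, d_k)
v = torch.randn(bs, n_h, kv_len, d_v)

# Both matmuls are batched over (bs, n_h); only sequence and hidden dims interact.
scores = q @ k.transpose(-1, -2) / math.sqrt(d_k)   # (bs, n_h, 1, kv_len)
out = scores.softmax(dim=-1) @ v                    # (bs, n_h, 1, d_v)
```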
All weights and biases can be inferred from the equations above based on the output shapes, but if you’re unsure just refer to the code — it’s a lot easier to read!
Remarks. I think the MHA / attention problem has sort of been exhausted, which is fine, but it makes it hard to build a leaderboard problem on top of it. It would be cool to see whether minor adjustments to the algorithm, ones that are still practically useful, can be optimized. DeepSeek in particular has proven that they can optimize their own algorithms pretty well, but not many standard implementations exist online.
We provide the KV cache, RoPE, and MLA as separate modules for your convenience. All datatypes here should be torch.bfloat16.
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class RoPE(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        theta = 10000 ** (-torch.arange(0, d_model // 2, dtype=torch.bfloat16) / (d_model // 2))
        self.register_buffer("theta", theta)

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
        seq_len = x.size(-2)
        d_model = x.size(-1)
        assert d_model == self.d_model
        # Rotation angles for absolute positions start_pos .. start_pos + seq_len - 1
        seq_idx = torch.arange(start_pos, start_pos + seq_len, device=x.device)
        idx_theta = torch.einsum('s,d->sd', seq_idx, self.theta)
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=-1)
        cos = idx_theta2.cos().to(torch.bfloat16)
        sin = idx_theta2.sin().to(torch.bfloat16)
        return x * cos + self.rotate_half(x) * sin


class KVCache(nn.Module):
    def __init__(self, kv_cache_shape: tuple) -> None:
        super().__init__()
        self.register_buffer('data', torch.zeros(kv_cache_shape, dtype=torch.bfloat16))
        self.seq_len = 0
        self.zero()

    def zero(self) -> None:
        self.data.zero_()

    def forward(self, c_kv: torch.Tensor) -> tuple[torch.Tensor, int]:
        # Append the new latents after the current cache contents and return the valid prefix.
        assert self.seq_len + c_kv.size(1) <= self.data.size(1), "KV Cache Exceeded"
        self.data = self.data.to(c_kv.dtype)
        self.data[:, self.seq_len : self.seq_len + c_kv.size(1), :] = c_kv
        self.seq_len += c_kv.size(1)
        return self.data[:, :self.seq_len], self.seq_len


class MLA(nn.Module):
    # `config` is provided by the evaluation harness (see the problem constraints below).
    def __init__(self, config: Config):
        super().__init__()
        self.dim = config.dim
        self.n_heads = config.n_heads
        self.q_lora_rank = config.q_lora_rank
        self.kv_lora_rank = config.kv_lora_rank
        self.nope_head_dim = config.qk_nope_head_dim
        self.rope_head_dim = config.qk_rope_head_dim
        self.v_head_dim = config.v_head_dim
        # Down-projection matrices
        self.Q_proj_down = nn.Linear(self.dim, self.q_lora_rank, bias=False)
        self.KV_proj_down = nn.Linear(self.dim, self.kv_lora_rank + self.rope_head_dim, bias=False)
        # Up-projection and rope projection matrices
        self.Q_proj_up = nn.Linear(self.q_lora_rank, (self.nope_head_dim + self.rope_head_dim) * self.n_heads, bias=False)
        self.KV_proj_up = nn.Linear(self.kv_lora_rank, (self.nope_head_dim + self.v_head_dim) * self.n_heads, bias=False)
        # RoPE on half embeddings
        self.q_rope = RoPE(self.rope_head_dim)
        self.k_rope = RoPE(self.rope_head_dim)
        # Output projection
        self.wo = nn.Linear(self.v_head_dim * self.n_heads, self.dim, bias=False)
        self.eps = 1e-6

    def forward(self, x: torch.Tensor, kv_cache: KVCache) -> tuple[torch.Tensor, KVCache]:
        # seq_len = 1 always here
        batch_size, seq_len, model_dim = x.size()

        ################################################################################
        # Step 1: Handle down-projection + KV cache                                    #
        ################################################################################
        q_lora = self.Q_proj_down(x)
        kv_lora = self.KV_proj_down(x)
        kv_lora, kv_len = kv_cache(kv_lora)
        query_pos = kv_len - 1

        ################################################################################
        # Step 2: Up-project and prepare NoPE + RoPE                                   #
        ################################################################################
        # Handle queries Q first
        q_nope_and_rope = self.Q_proj_up(q_lora).view(
            batch_size, seq_len, self.n_heads, self.nope_head_dim + self.rope_head_dim)
        q_nope, q_rope = torch.split(q_nope_and_rope, [self.nope_head_dim, self.rope_head_dim], dim=-1)
        # Handle keys and values K/V. V does not need RoPE
        kv_nope, k_rope = torch.split(kv_lora, [self.kv_lora_rank, self.rope_head_dim], dim=-1)
        kv_nope = self.KV_proj_up(kv_nope).view(
            batch_size, kv_len, self.n_heads, self.nope_head_dim + self.v_head_dim)
        k_nope, v = torch.split(kv_nope, [self.nope_head_dim, self.v_head_dim], dim=-1)

        ################################################################################
        # Step 3: Handle RoPE Stream                                                   #
        ################################################################################
        # Compute RoPE for queries and combine with no-RoPE part
        q_rope = q_rope.reshape(batch_size, self.n_heads, seq_len, self.rope_head_dim)
        q_rope = self.q_rope(q_rope, start_pos=query_pos)
        q_rope = q_rope.reshape(batch_size, seq_len, self.n_heads, self.rope_head_dim)
        q = torch.concat([q_nope, q_rope], dim=-1)
        # Compute RoPE for keys and combine with no-RoPE part
        k_rope = k_rope[:, :, None, :]
        k_rope = self.k_rope(k_rope).expand(-1, -1, self.n_heads, -1)
        k = torch.concat([k_nope, k_rope], dim=-1)

        ################################################################################
        # Compute Multi-head Attention                                                 #
        ################################################################################
        q = q.reshape(batch_size, self.n_heads, 1, self.rope_head_dim + self.nope_head_dim)
        k = k.reshape(batch_size, self.n_heads, kv_len, self.rope_head_dim + self.nope_head_dim)
        v = v.reshape(batch_size, self.n_heads, -1, self.v_head_dim)
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.rope_head_dim + self.nope_head_dim)
        attn = F.softmax(scores, dim=-1).to(torch.bfloat16)
        y = torch.matmul(attn, v).view(batch_size, -1)
        y = self.wo(y)
        return y, kv_cache
```
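A minimal smoke test of the modules above, assuming a hypothetical stand-in Config with toy sizes (the grading harness supplies its own Config with the DeepSeek-V3 values listed below; the stand-in must exist before the MLA class body is executed, since its __init__ annotation refers to Config):

```python
from dataclasses import dataclass
import torch

# Hypothetical stand-in for the harness Config, with small illustrative sizes.
@dataclass
class Config:
    dim: int = 256
    n_heads: int = 4
    q_lora_rank: int = 64
    kv_lora_rank: int = 64
    qk_nope_head_dim: int = 32
    qk_rope_head_dim: int = 16
    v_head_dim: int = 32

cfg = Config()
bs, max_seq = 2, 32
mla = MLA(cfg).to(torch.bfloat16)
cache = KVCache((bs, max_seq, cfg.kv_lora_rank + cfg.qk_rope_head_dim))

# Decode a few tokens one at a time, re-using and updating the same cache.
for _ in range(3):
    x = torch.randn(bs, 1, cfg.dim, dtype=torch.bfloat16)
    y, cache = mla(x, cache)
print(y.shape, cache.seq_len)   # torch.Size([2, 256]) 3
```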
Problem Constraints:
We follow sizes from DeepSeek-V3 / R1 and primarily vary the sequence length of the KV cache:
0. bs::[128] # batch size
1. prefill::[512, 2048, 4096, 6144] # as kv length
2. sq::[1] # we only consider decoding
3. dim::[7168] # hidden size of deepseek v3
4. kv_lora_rank::[512] # kv lora rank of deepseek v3
5. qk_rope_head_dim::[64] # rope embedding dimension
6. v_head_dim::[128] # dim of v embeddings per head
7. n_heads::[128] # num of attn heads
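For a rough sense of scale, the latent KV cache implied by these sizes (stored once per token and shared across heads, as in the KVCache buffer above) at the largest prefill:

```python
# Per-token latent: kv_lora_rank + qk_rope_head_dim bfloat16 entries (2 bytes each).
bs, prefill, kv_lora_rank, rope_dim = 128, 6144, 512, 64
bytes_per_entry = 2

latent_cache_bytes = bs * (prefill + 1) * (kv_lora_rank + rope_dim) * bytes_per_entry
print(f"latent KV cache: {latent_cache_bytes / 2**30:.2f} GiB")   # ~0.84 GiB
```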
</aside>