<aside> Problem 2: Inference-only Mixture of Experts (MoE) Layer on a single MI300

In the post-GPT-4 era of LLMs, a common model design replaces the dense linear layer / FFN that follows the attention layer with a mixture-of-experts (MoE) module. While these modules are extremely useful for scaling model parameters while keeping forward passes sparse, efficient kernel implementations of them are currently lacking. For this problem, you will be implementing this kernel on a single MI300.

MoE... on a single GPU…? More seasoned engineers may find it strange that this problem asks for MoE on a single device, since it is typically implemented with multiple kernels across multiple devices or even nodes (see the 64-way parallel MoE in DeepSeek-V2). This is a valid concern, but your task here is to optimize it on a single device!

[💡Tip]: We will describe this layer in terms of modular functions to make it easier to digest — we also encourage writing kernels for these modular functions first before attempting to write the entire layer as a kernel!


The Mixture of Experts Layer, Informally Explained

I received feedback that many people might not like math-heavy notation, so this short explanation plus the code later in this document might be enough for you to understand the problem.

You will be implementing a Mixture-of-Experts (MoE) layer on a single device, with no residual connections and a softmax weighting function for routing tokens. For some folks, this explanation might be enough to get started. If that’s you, skip ahead to the PyTorch implementation below. If not, keep reading 📖!

Inputs. The MoE takes a sequence of token embeddings $x_1,x_2,...,x_T$, which are just vectors. Each token embedding independently gets “routed” to different experts, which we explain below in both words and a simple drawing. It is much easier to think about one token embedding at a time, because the MoE layer processes them all independently.

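The MoE. The MoE layer has a set of experts and shared experts. Each of these “experts / shared experts” is functionally the same (they’re written using an Expert class), but they have their own independent weights and are assigned tokens differently: every token is processed by every shared expert, while each token is processed by only $K$ of the $N$ non-shared experts, chosen by the router described below.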

Visually, this process looks something like the following:

Example of MoE layer for two tokens. In this example, there is 1 shared expert and N experts. The router (with K=2) assigns token 1 to experts 2 and N with weights 0.15 and 0.7 respectively, and it assigns token 2 to experts 1 and 2 with weights 0.43 and 0.34 respectively. Finally, the outputs are summed. This process is repeated for the other tokens, but the router will assign different weights and active experts.


The Router. The router first projects the input token to an $N$-dimensional vector, where $N$ is the number of non-shared experts, applies a softmax to turn it into a probability distribution over experts, then zeros out all but the top $K$ entries (these are the active experts). These non-zero values are also the scalar weights that multiply the outputs of the active experts.
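A minimal NumPy sketch of the router for a single token, assuming a router weight matrix `W_g` of shape `(N, d)` (the layout of `nn.Linear(d_hidden, n_routed_experts).weight` in the reference code below). The helper name `route_token` is hypothetical; it returns the sparse form (expert indices plus their softmax weights) rather than a length-$N$ vector with zeros.

```python
import numpy as np


def route_token(x: np.ndarray, W_g: np.ndarray, k: int):
    """Return (expert_indices, expert_weights) for one token x of shape (d,).

    W_g has shape (N, d): one row of router weights per non-shared expert.
    """
    logits = W_g @ x                      # (N,) one score per expert
    z = np.exp(logits - logits.max())     # subtract the max for numerical stability
    probs = z / z.sum()                   # softmax over the N experts
    top = np.argsort(probs)[::-1][:k]     # indices of the K largest probabilities
    return top, probs[top]                # active experts and their gate weights


# Toy example: d = N = 4 and W_g = identity, so the logits are just x.
x = np.array([0.0, 1.0, 2.0, 3.0])
idx, w = route_token(x, np.eye(4), k=2)
print(idx, w)  # experts [3 2] with weights of roughly 0.644 and 0.237
```

Note that the two selected weights sum to less than 1: the softmax is taken over all $N$ experts before the top-$K$ selection.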


The Mixture of Experts Layer, Formally Explained

[💡Tip]: Recall that the FFN / MOE layer acts **over the hidden dimension of each token independently**. In other words, you can treat the sequence dimension like a batch dimension in traditional deep learning operators; this is very different from attention, which acts over the sequence dimension! In this problem, you will be given a sequence of tokens $\bold{X} \in \mathbb{R}^{S \times d}$ (in practice, a batch of $B$ such sequences) and asked to output the result of applying the MOE layer to each token independently, so in the definition below we only need to consider one token.
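Because each token is processed independently, a kernel can flatten the batch and sequence dimensions into a single token dimension of size $T = B \cdot S$, work on a `(T, d)` matrix, and reshape back at the end. A small NumPy check of this idea, where `per_token` is a hypothetical stand-in for any row-wise map (such as the MoE layer):

```python
import numpy as np


def per_token(tokens: np.ndarray, W: np.ndarray) -> np.ndarray:
    """A stand-in row-wise map: each token (row) is transformed independently."""
    return np.tanh(tokens @ W.T)


rng = np.random.default_rng(0)
B, S, d = 2, 5, 4
X = rng.standard_normal((B, S, d))
W = rng.standard_normal((d, d))

# Flatten (B, S) into one token axis, apply the map once, then restore the shape.
flat_out = per_token(X.reshape(B * S, d), W).reshape(B, S, d)

# Reference: apply the map one token at a time.
loop_out = np.stack([
    np.stack([per_token(X[b, t][None, :], W)[0] for t in range(S)])
    for b in range(B)
])
print(np.allclose(flat_out, loop_out))  # True
```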

Formally, given a single token $x \in \mathbb{R}^{d}$ (treated as a column vector), an MOE layer with $N$ total non-shared experts, $K$ active non-shared experts per token, and $K_s$ shared experts is defined as:

$$ \text{MOE}(x) \triangleq \sum_{i=1}^{K_s} f^s_i(x) + \sum_{i=1}^{N} g_i(x) \cdot f_i(x) $$

There are two learnable functions, the experts $f$ and the gates $g$ that we define formally below:

Each expert $f_i(\cdot)$ or $f^{s}_i(\cdot)$ is functionally identical, just parameterized independently (the superscript $s$ marks a shared expert, so shared experts can be indexed separately from the non-shared ones). They are defined as

$$ f_i(x) \triangleq \bold{W_i}^{\text{down}} \left( \sigma \left( \bold{W_i}^{\text{gate}}x \right) \odot \bold{W_i}^{\text{up}}x \right) $$

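where $\sigma$ is the SiLU activation $\sigma(z) = z \cdot \text{sigmoid}(z)$ (`nn.SiLU` in the code below), $\odot$ is the elementwise product, $\bold{W_i}^{\text{gate}}, \bold{W_i}^{\text{up}} \in \mathbb{R}^{d_{\text{expert}} \times d}$, and $\bold{W_i}^{\text{down}} \in \mathbb{R}^{d \times d_{\text{expert}}}$.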
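A NumPy sketch of one expert applied to a single token, assuming weights in the `nn.Linear` layout used by the reference code below (`W_gate` and `W_up` of shape `(d_expert, d)`, `W_down` of shape `(d, d_expert)`) and $\sigma$ = SiLU as in `Expert.act_fn`. The function names are illustrative only.

```python
import numpy as np


def silu(z: np.ndarray) -> np.ndarray:
    """SiLU / swish activation: z * sigmoid(z)."""
    return z / (1.0 + np.exp(-z))


def expert_forward(x, W_gate, W_up, W_down):
    """f(x) = W_down (silu(W_gate x) * W_up x) for a token x of shape (d,)."""
    h = silu(W_gate @ x) * (W_up @ x)   # (d_expert,) elementwise product
    return W_down @ h                   # project back to (d,)


rng = np.random.default_rng(0)
d, d_expert = 8, 16
x = rng.standard_normal(d)
W_gate = rng.standard_normal((d_expert, d))
W_up = rng.standard_normal((d_expert, d))
W_down = rng.standard_normal((d, d_expert))
y = expert_forward(x, W_gate, W_up, W_down)
print(y.shape)  # (8,)
```

For a `(T, d)` matrix of tokens stored as rows, the same computation becomes `(silu(X @ W_gate.T) * (X @ W_up.T)) @ W_down.T`, i.e. three GEMMs plus an elementwise epilogue.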

Each gate $g_i(\cdot)$ is functionally identical, just parameterized independently, and only $K$ of the $N$ gates can be non-zero at once! Let $\text{top}_K(\mathcal{S})$ denote the $K$ “highest” elements of $\mathcal{S}$, where $\mathcal{S}$ is a set of $N$ non-negative reals. In layman’s terms, it’s just the $K$ biggest elements. Then (abusing a bit of notation here):

$$ \begin{align} y_i &\triangleq \bold{W}^{g}_i x \\
s_i &\triangleq \text{softmax} \left(y\right)_i = \frac{e^{y_i}}{\sum_{j=1}^{N} e^{y_j}} \\
g_i(x) &\triangleq s_i \quad \text{ if } \quad s_i \in \text{top}_K\left(\left\{s_j \mid 1 \leq j \leq N\right\}\right) \quad \text{ else } 0 \end{align} $$

If you don’t like math, it’s super simple. Gate $i$ is active only if its projected score $y_i$ is among the $K$ highest, in which case its value is the softmax score $s_i$; otherwise it is 0. The only weights here are the rows $\bold{W}^{g}_i$ of the router matrix $\bold{W}^{g} \in \mathbb{R}^{N \times d}$ (`W_g` in `MoEGate` below).
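Putting the expert and gate definitions together for a single token, the NumPy sketch below (random weights, illustrative helper names) evaluates $\text{MOE}(x)$ literally, summing over all $N$ experts with the dense gates of Equation 3, and checks that this matches evaluating only the $K$ experts whose gates are non-zero. The second form is the one worth implementing, since it skips $N - K$ expert evaluations.

```python
import numpy as np


def silu(z):
    return z / (1.0 + np.exp(-z))


def expert(x, W):
    """W = (W_gate, W_up, W_down); computes W_down (silu(W_gate x) * W_up x)."""
    W_gate, W_up, W_down = W
    return W_down @ (silu(W_gate @ x) * (W_up @ x))


def dense_gates(x, W_g, k):
    """Equation 3 as a dense length-N vector: softmax scores, zeroed outside the top K."""
    y = W_g @ x
    s = np.exp(y - y.max())
    s /= s.sum()
    g = np.zeros_like(s)
    top = np.argsort(s)[-k:]        # indices of the K largest scores
    g[top] = s[top]
    return g


def random_expert(rng, d, d_e):
    return (rng.standard_normal((d_e, d)), rng.standard_normal((d_e, d)),
            rng.standard_normal((d, d_e)))


rng = np.random.default_rng(0)
d, d_e, N, K, K_s = 8, 4, 6, 2, 1
experts = [random_expert(rng, d, d_e) for _ in range(N)]
shared = [random_expert(rng, d, d_e) for _ in range(K_s)]
W_g = rng.standard_normal((N, d))
x = rng.standard_normal(d)

g = dense_gates(x, W_g, K)
shared_out = sum(expert(x, W) for W in shared)
# Literal formula: sum over all N experts, most of which have g_i = 0.
moe_dense = shared_out + sum(g[i] * expert(x, experts[i]) for i in range(N))
# Sparse evaluation: only the K experts with non-zero gates are computed.
moe_sparse = shared_out + sum(g[i] * expert(x, experts[i]) for i in np.flatnonzero(g))
print(np.count_nonzero(g), np.allclose(moe_dense, moe_sparse))  # 2 True
```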

Some tips for coding 👩🏻‍💻. The tricky part in the final implementation is that the above analysis applies to a single token. When given a sequence, different tokens might activate different experts, so you’ll have to think of a clever way to do this efficiently. The implementation we provide also handles gates using torch.topk, which produces Equation 3 above in sparse form: the $K$ non-zero scores for each token and their corresponding expert indices. You do not have to use this representation unless you want to.
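One common way to handle this (not necessarily the best choice on MI300) is to invert the routing: flatten the `(T, K)` routing table into $T \cdot K$ (token, expert) pairs, sort the pairs by expert, run each expert once on the contiguous group of tokens routed to it, and scatter-add the weighted results back. The NumPy sketch below assumes `topk_indices` and `topk_scores` of shape `(T, K)` as returned by `MoEGate` (after flattening batch and sequence), and a hypothetical `expert_fn(e, tokens)` that applies expert `e` to an `(n, d)` batch of tokens.

```python
import numpy as np


def grouped_moe(x, topk_indices, topk_scores, expert_fn, n_experts):
    """Routed-expert output for tokens x of shape (T, d), evaluated one expert at a time."""
    T, K = topk_indices.shape
    flat_expert = topk_indices.reshape(-1)             # (T*K,) expert id of each (token, slot) pair
    flat_token = np.repeat(np.arange(T), K)            # (T*K,) token id of each pair
    flat_score = topk_scores.reshape(-1)               # (T*K,) gate weight of each pair

    order = np.argsort(flat_expert, kind="stable")     # permutation grouping pairs by expert id
    counts = np.bincount(flat_expert, minlength=n_experts)
    offsets = np.concatenate(([0], np.cumsum(counts)))  # group boundaries within `order`

    out = np.zeros_like(x)
    for e in range(n_experts):
        sel = order[offsets[e]:offsets[e + 1]]         # pairs routed to expert e
        if sel.size == 0:
            continue                                   # expert received no tokens
        tok = flat_token[sel]
        y = expert_fn(e, x[tok])                       # one batched call per expert
        np.add.at(out, tok, flat_score[sel, None] * y)  # scatter-add weighted outputs
    return out


# Check against a per-token loop, using toy linear "experts" (hypothetical).
rng = np.random.default_rng(0)
T, d, N, K = 10, 4, 5, 2
x = rng.standard_normal((T, d))
W = rng.standard_normal((N, d, d))


def expert_fn(e, tokens):
    return tokens @ W[e].T


scores = rng.random((T, N))
topk_indices = np.argsort(scores, axis=1)[:, -K:]
topk_scores = np.take_along_axis(scores, topk_indices, axis=1)

ref = np.zeros_like(x)
for t in range(T):
    for k in range(K):
        e = topk_indices[t, k]
        ref[t] += topk_scores[t, k] * expert_fn(e, x[t][None, :])[0]

out = grouped_moe(x, topk_indices, topk_scores, expert_fn, N)
print(np.allclose(out, ref))  # True
```

On a GPU, the per-expert Python loop would typically be replaced by a grouped or batched GEMM over the expert groups, with the sort and scatter done by separate (or fused) kernels.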

Remark. The astute reader will notice that because the softmax function is monotonically increasing, you don’t even need to compute the softmax to determine which gates are active. The only reason it’s necessary (new research topic…?) is to weight the active experts according to a probability distribution, but because we keep only the top $K$ values, the resulting weights no longer sum to 1, so they aren’t even a probability distribution!
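A quick NumPy check of this remark, using random logits for one token with the problem's $N = 256$ and $K = 8$: the top-$K$ set is the same whether it is taken from the raw logits or from the softmax, and the $K$ kept weights sum to less than 1. A kernel could therefore select experts from the logits directly, while still needing the softmax denominator over all $N$ logits to compute the weights.

```python
import numpy as np


def softmax(y):
    z = np.exp(y - y.max())
    return z / z.sum()


rng = np.random.default_rng(0)
N, K = 256, 8
logits = rng.standard_normal(N)
probs = softmax(logits)

# Softmax preserves the ordering of the logits, so the top-K sets agree.
top_from_logits = set(np.argsort(logits)[-K:].tolist())
top_from_probs = set(np.argsort(probs)[-K:].tolist())
print(top_from_logits == top_from_probs)  # True

# The K surviving weights are positive but no longer sum to 1.
kept = probs[np.argsort(probs)[-K:]]
print(kept.sum() < 1.0)  # True
```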


PyTorch Implementation of Mixture of Experts

We modularize the experts, gates, and the full model for readability. The routed-expert computation is deliberately inefficient (a Python loop over every token and each of its selected experts), but it makes clear how tokens are routed and weighted for an entire sequence according to the expressions above.

```python
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn


class Expert(nn.Module):
    def __init__(self, config: Dict, d_shared_expert: Optional[int] = None):
        super().__init__()
        self.config = config
        self.act_fn = nn.SiLU()
        self.d_hidden: int = config["d_hidden"]
        self.d_expert: int = config["d_expert"] if d_shared_expert is None else d_shared_expert

        self.W_gate = nn.Linear(self.d_hidden, self.d_expert, bias=False)
        self.W_up = nn.Linear(self.d_hidden, self.d_expert, bias=False)
        self.W_down = nn.Linear(self.d_expert, self.d_hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # f(x) = W_down(silu(W_gate x) * W_up x), applied over the last (hidden) dimension
        gate = self.act_fn(self.W_gate(x))
        out = self.W_down(gate * self.W_up(x))
        return out


class MoEGate(nn.Module):
    def __init__(self, config: Dict):
        super().__init__()
        self.top_k: int = config["n_experts_per_token"]
        self.num_experts: int = config["n_routed_experts"]
        self.d_hidden: int = config["d_hidden"]

        self.W_g = nn.Linear(self.d_hidden, self.num_experts, bias=False)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        logits = self.W_g(x)  # (..., n_routed_experts)
        scores = logits.softmax(dim=-1)
        # Sparse form of the gates: the K largest scores per token and their expert indices
        topk_scores, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        return topk_indices, topk_scores


class MoE(nn.Module):
    def __init__(self, config: Dict):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            Expert(config)
            for _ in range(config["n_routed_experts"])
        ])
        self.gating_network = MoEGate(config)

        # Shortcut to computing the sum of shared experts: n_shared_experts experts of
        # width d_expert are equivalent to a single expert of width d_expert * n_shared_experts
        shared_expert_dim = config["d_expert"] * config["n_shared_experts"]
        self.shared_expert = Expert(config=config, d_shared_expert=shared_expert_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, seq_len, d_hidden)
        # Shared experts process every token (no residual connection is added)
        shared_output = self.shared_expert(x)

        # Compute expert routing weights, each of shape (batch_size, seq_len, n_experts_per_token)
        expert_indices, expert_scores = self.gating_network(x)

        # Compute routed expert outputs one token and one selected expert at a time
        routed_output = torch.zeros_like(x)
        batch_size, seq_len, _ = x.shape
        for b in range(batch_size):
            for t in range(seq_len):
                token = x[b, t].unsqueeze(0)  # (1, hidden_dim)
                for k in range(self.config["n_experts_per_token"]):
                    expert_id = int(expert_indices[b, t, k])
                    weight = expert_scores[b, t, k]
                    expert = self.experts[expert_id]

                    expert_output = expert(token).squeeze(0)  # (hidden_dim,)
                    routed_output[b, t] += expert_output * weight

        return routed_output + shared_output
```
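The shared-expert shortcut in `MoE.__init__` relies on a simple identity: because SiLU and $\odot$ act elementwise, $n$ experts of width $d_{\text{expert}}$ sum to one expert whose `W_gate`/`W_up` rows are stacked and whose `W_down` columns are concatenated. A NumPy check of that identity, assuming the separate experts' weights are laid out as in `nn.Linear` (names are illustrative):

```python
import numpy as np


def silu(z):
    return z / (1.0 + np.exp(-z))


def expert(x, W_gate, W_up, W_down):
    return W_down @ (silu(W_gate @ x) * (W_up @ x))


rng = np.random.default_rng(0)
d, d_e, n_shared = 8, 4, 3
gates = [rng.standard_normal((d_e, d)) for _ in range(n_shared)]
ups = [rng.standard_normal((d_e, d)) for _ in range(n_shared)]
downs = [rng.standard_normal((d, d_e)) for _ in range(n_shared)]
x = rng.standard_normal(d)

# Sum of n_shared separate experts.
summed = sum(expert(x, g, u, w) for g, u, w in zip(gates, ups, downs))
# One wide expert: stack gate/up rows, concatenate down-projection columns.
wide = expert(x, np.vstack(gates), np.vstack(ups), np.hstack(downs))
print(np.allclose(summed, wide))  # True
```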

```python
### Input Data ###
# batch size B, sequence length S, hidden dimension d
X = torch.randn(B, S, d)
O = torch.zeros((B, S, d))

### Your example implementation ###
def kernel(data: Tuple[torch.Tensor, ...],
           weights: Dict[str, torch.Tensor]):
    moe = MoE(config)  # config: the problem configuration dict (see Problem Constraints)
    moe.replace_weights(weights)  # weight-loading helper, not defined in the classes above
    O = moe(data)
    return O
```

Problem Constraints:

We consider all possible combinations of the following configurations:

"batch_size": {1,2,4}
"seq_len": {512, 1024, 2048, 4096, 8192}
"d_hidden": 7168
"d_expert": 2048
"n_routed_experts": 256
"n_shared_experts": 1
"n_experts_per_token": 8
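Some back-of-the-envelope arithmetic for these shapes may help when planning a kernel. The sketch below counts parameters and matmul FLOPs per token, and the average number of tokens per routed expert under the optimistic assumption of perfectly uniform routing. The 2-bytes-per-parameter figure is an assumption about the weight dtype, not part of the problem statement.

```python
# Problem dimensions from the constraints above.
d_hidden, d_expert = 7168, 2048
n_routed, n_shared, top_k = 256, 1, 8

params_per_expert = 3 * d_hidden * d_expert             # W_gate, W_up, W_down
routed_params = n_routed * params_per_expert
shared_params = 3 * d_hidden * (d_expert * n_shared)    # one wide shared expert
router_params = d_hidden * n_routed                     # W_g
total_params = routed_params + shared_params + router_params

# Matmul FLOPs per token (2 FLOPs per multiply-add); activation and softmax ignored.
flops_per_token = 2 * (top_k * params_per_expert + shared_params + router_params)


def avg_tokens_per_expert(batch_size: int, seq_len: int) -> float:
    """Average tokens each routed expert receives, assuming perfectly uniform routing."""
    return batch_size * seq_len * top_k / n_routed


BYTES_PER_PARAM = 2  # assumption: 16-bit (fp16/bf16) weights
print(params_per_expert, total_params, flops_per_token)  # 44040192 11320164352 796393472
print(total_params * BYTES_PER_PARAM / 1e9, "GB of weights at the assumed dtype")
print(avg_tokens_per_expert(1, 512), avg_tokens_per_expert(4, 8192))  # 16.0 1024.0
```

With about 16 tokens per expert at the smallest configuration and about 1024 at the largest, the per-expert matmuls range from very skinny (likely dominated by loading expert weights) to reasonably large, so a single tiling strategy may not suit every configuration.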

</aside>