TensorRT-LLM/tests/_torch/helpers.py
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F


def reference_moe_torch(x: torch.Tensor, router_logits: torch.Tensor,
                        top_k: int,
                        weights: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Unfused PyTorch reference for a top-k MoE forward pass.

    Each expert is a SwiGLU-style MLP whose parameters are looked up in
    `weights` under the keys "<expert_id>.w1.weight", "<expert_id>.w2.weight"
    and "<expert_id>.w3.weight".
    """
    num_experts = router_logits.shape[-1]
    routing_weights = nn.functional.softmax(router_logits,
                                            dim=1,
                                            dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights,
                                                   top_k,
                                                   dim=-1)
    # renormalize the top-k probabilities so they sum to 1 per token
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    # cast back to the input dtype
    routing_weights = routing_weights.to(x.dtype)
    results = torch.zeros_like(x)

    # naive looping over experts
    for expert_id in range(num_experts):
        # tokens routed to this expert and the rank it holds in their top-k
        batch_idx, nth_expert = torch.where(selected_experts == expert_id)
        w1_weight = weights[f"{expert_id}.w1.weight"]
        w2_weight = weights[f"{expert_id}.w2.weight"]
        w3_weight = weights[f"{expert_id}.w3.weight"]
        expert_inputs = x[batch_idx]
        # SwiGLU expert MLP: silu(x @ w1^T) * (x @ w3^T), projected down by w2
        output = (F.silu(expert_inputs @ w1_weight.t()) *
                  (expert_inputs @ w3_weight.t())) @ w2_weight.t()
        results[batch_idx] += routing_weights[batch_idx, nth_expert,
                                              None] * output

    return results.view_as(x)
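

# --------------------------------------------------------------------------
# Usage sketch (not part of the original helper): a minimal, self-contained
# example of how reference_moe_torch could be called, using randomly
# initialized expert weights and hypothetical tensor shapes. The real tests
# would populate `weights` from the MoE module under test instead.
if __name__ == "__main__":
    torch.manual_seed(0)
    num_tokens, hidden_size, intermediate_size = 4, 8, 16
    num_experts, top_k = 4, 2

    x = torch.randn(num_tokens, hidden_size)
    router_logits = torch.randn(num_tokens, num_experts)

    weights: Dict[str, torch.Tensor] = {}
    for expert_id in range(num_experts):
        # w1/w3 project hidden -> intermediate, w2 projects back down
        weights[f"{expert_id}.w1.weight"] = torch.randn(intermediate_size,
                                                        hidden_size)
        weights[f"{expert_id}.w3.weight"] = torch.randn(intermediate_size,
                                                        hidden_size)
        weights[f"{expert_id}.w2.weight"] = torch.randn(hidden_size,
                                                        intermediate_size)

    out = reference_moe_torch(x, router_logits, top_k, weights)
    assert out.shape == x.shape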