# TensorRT-LLM: tests/_torch/test_fused_moe.py
import os
import sys
from typing import Dict, List, Optional

import pytest
import torch
import torch.nn as nn

from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.modules.fused_moe import FusedMoE
from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.util import skip_pre_ampere, skip_pre_blackwell, skip_pre_hopper
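

# These tests compare the fused MoE module (FusedMoE) against a per-expert
# GatedMLP reference (RefGatedMLPFusedMoE, defined at the bottom of this
# file) in plain fp16/bf16, FP8 per-tensor, and NVFP4 block-scaled modes.
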
@skip_pre_ampere
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_moe(dtype):
    SEQ_LEN = 8
    HIDDEN_SIZE = 64
    INTERMEDIATE_SIZE = 32
    NUM_EXPERTS = 3
    TOP_K = 2
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda()
    router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=dtype).cuda()
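
    # Per-expert weights keyed "<expert_id>.w1/w2/w3.weight": w1 and w3 feed
    # the fused gate_up projection and w2 the down projection (see
    # RefGatedMLPFusedMoE.load_weights below).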
    weights = {}
    for expert_id in range(NUM_EXPERTS):
        w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
                                dtype=dtype).cuda()
        w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE),
                                dtype=dtype).cuda()
        w3_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
                                dtype=dtype).cuda()
        weights[f"{expert_id}.w1.weight"] = w1_weight
        weights[f"{expert_id}.w2.weight"] = w2_weight
        weights[f"{expert_id}.w3.weight"] = w3_weight

    fused_moe = FusedMoE(num_experts=NUM_EXPERTS,
                         top_k=TOP_K,
                         hidden_size=HIDDEN_SIZE,
                         intermediate_size=INTERMEDIATE_SIZE,
                         dtype=dtype,
                         reduce_results=False,
                         model_config=ModelConfig())
    fused_moe.load_weights([weights])
    fused_moe.cuda()

    assert not fused_moe.has_been_profiled
    with torch.inference_mode():
        output = fused_moe.forward(x, router_logits)
    assert fused_moe.has_been_profiled

    # torch run
    ref_fused_moe = RefGatedMLPFusedMoE(num_experts=NUM_EXPERTS,
                                        top_k=TOP_K,
                                        hidden_size=HIDDEN_SIZE,
                                        intermediate_size=INTERMEDIATE_SIZE,
                                        dtype=dtype,
                                        model_config=ModelConfig())
    ref_fused_moe.load_weights([weights])
    ref_fused_moe.cuda()
    with torch.inference_mode():
        ref_output = ref_fused_moe.forward(x, router_logits)
    # with torch.inference_mode():
    #     ref_output = reference_moe_torch(x, router_logits, TOP_K, weights)

    # compare
    torch.cuda.synchronize()
    torch.testing.assert_close(output, ref_output, rtol=0.2, atol=0.2)


@skip_pre_hopper
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_moe_fp8(dtype):
    SEQ_LEN = 4
    HIDDEN_SIZE = 64
    INTERMEDIATE_SIZE = 32
    NUM_EXPERTS = 3
    TOP_K = 2
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda()
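    # Per-tensor FP8 scale for the activations; the same scale is reused below
    # as the input_scale of w1, w2 and w3 for every expert.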
    _, x_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(x)
    x_scale = x_scale.float().squeeze()
    router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=dtype).cuda()

    weights = {}
    for expert_id in range(NUM_EXPERTS):
        w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
                                dtype=dtype).cuda()
        w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE),
                                dtype=dtype).cuda()
        w3_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
                                dtype=dtype).cuda()

        w1_weight_fp8, w1_weight_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(
            w1_weight)
        w1_weight_fp8 = w1_weight_fp8.view(torch.float8_e4m3fn).cuda()

        w2_weight_fp8, w2_weight_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(
            w2_weight)
        w2_weight_fp8 = w2_weight_fp8.view(torch.float8_e4m3fn).cuda()

        w3_weight_fp8, w3_weight_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(
            w3_weight)
        w3_weight_fp8 = w3_weight_fp8.view(torch.float8_e4m3fn).cuda()

        w1_input_scale = x_scale.cuda()
        w2_input_scale = x_scale.cuda()
        w3_input_scale = x_scale.cuda()

        weights[f"{expert_id}.w1.weight"] = w1_weight_fp8
        weights[f"{expert_id}.w2.weight"] = w2_weight_fp8
        weights[f"{expert_id}.w3.weight"] = w3_weight_fp8
        weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale.float()
        weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale.float()
        weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale.float()
        weights[f"{expert_id}.w1.input_scale"] = w1_input_scale
        weights[f"{expert_id}.w2.input_scale"] = w2_input_scale
        weights[f"{expert_id}.w3.input_scale"] = w3_input_scale

    quant_config = QuantConfig(quant_algo=QuantAlgo.FP8)
    fused_moe = FusedMoE(num_experts=NUM_EXPERTS,
                         top_k=TOP_K,
                         hidden_size=HIDDEN_SIZE,
                         intermediate_size=INTERMEDIATE_SIZE,
                         dtype=dtype,
                         reduce_results=False,
                         model_config=ModelConfig(quant_config=quant_config))
    fused_moe.cuda()
    fused_moe.load_weights([weights])

    assert not fused_moe.has_been_profiled
    with torch.inference_mode():
        output = fused_moe.forward(x, router_logits)
    assert fused_moe.has_been_profiled

    ref_fused_moe = RefGatedMLPFusedMoE(
        num_experts=NUM_EXPERTS,
        top_k=TOP_K,
        hidden_size=HIDDEN_SIZE,
        intermediate_size=INTERMEDIATE_SIZE,
        dtype=dtype,
        model_config=ModelConfig(quant_config=quant_config))
    ref_fused_moe.load_weights([weights])
    ref_fused_moe.cuda()
    with torch.inference_mode():
        ref_output = ref_fused_moe.forward(x, router_logits)

    # compare
    torch.cuda.synchronize()
    torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)


@skip_pre_blackwell
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_moe_nvfp4(dtype):
    SCALING_VECTOR_SIZE = 16
    SEQ_LEN = 4
    HIDDEN_SIZE = 128
    INTERMEDIATE_SIZE = 128
    NUM_EXPERTS = 3
    TOP_K = 2
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda()
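    # NVFP4 global scales: 448 is the max FP8 E4M3 value (the block-scale
    # dtype) and 6 is the max FP4 E2M1 value, so (448 * 6) / amax maps each
    # tensor's absolute maximum into the representable range.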
    x_sf_global = (448 * 6) / x.abs().max().float()
    router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=dtype).cuda()

    weights = {}
    for expert_id in range(NUM_EXPERTS):
        w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
                                dtype=dtype).cuda()
        w1_sf_global = (448 * 6) / w1_weight.abs().max().float()
        w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE),
                                dtype=dtype).cuda()
        w2_sf_global = (448 * 6) / w2_weight.abs().max().float()
        w3_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
                                dtype=dtype).cuda()
        w3_sf_global = (448 * 6) / w3_weight.abs().max().float()
        w3_w1_global = min(
            w1_sf_global,
            w3_sf_global)  # w3 global and w1 global must be the same

        w1_weight_nvfp4, w1_sf_block = torch.ops.trtllm.fp4_quantize(
            w1_weight, w3_w1_global, SCALING_VECTOR_SIZE, False)
        w1_sf_block_unswizzled = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(
            w1_sf_block.cpu().view(INTERMEDIATE_SIZE, -1))

        w2_weight_nvfp4, w2_sf_block = torch.ops.trtllm.fp4_quantize(
            w2_weight, w2_sf_global, SCALING_VECTOR_SIZE, False)
        w2_sf_block_unswizzled = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(
            w2_sf_block.cpu().view(HIDDEN_SIZE, -1))

        w3_weight_nvfp4, w3_sf_block = torch.ops.trtllm.fp4_quantize(
            w3_weight, w3_w1_global, SCALING_VECTOR_SIZE, False)
        w3_sf_block_unswizzled = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(
            w3_sf_block.cpu().view(INTERMEDIATE_SIZE, -1))

        w1_input_scale = x_sf_global.cuda()
        w2_input_scale = x_sf_global.cuda()
        w3_input_scale = x_sf_global.cuda()

        weights[f"{expert_id}.w1.weight"] = w1_weight_nvfp4
        weights[f"{expert_id}.w2.weight"] = w2_weight_nvfp4
        weights[f"{expert_id}.w3.weight"] = w3_weight_nvfp4
        weights[f"{expert_id}.w1.weight_scale"] = w1_sf_block_unswizzled.view(
            torch.float8_e4m3fn).cuda()
        weights[f"{expert_id}.w2.weight_scale"] = w2_sf_block_unswizzled.view(
            torch.float8_e4m3fn).cuda()
        weights[f"{expert_id}.w3.weight_scale"] = w3_sf_block_unswizzled.view(
            torch.float8_e4m3fn).cuda()
        weights[f"{expert_id}.w1.input_scale"] = 1.0 / w1_input_scale
        weights[f"{expert_id}.w2.input_scale"] = 1.0 / w2_input_scale
        weights[f"{expert_id}.w3.input_scale"] = 1.0 / w3_input_scale
        weights[f"{expert_id}.w1.weight_scale_2"] = 1.0 / w3_w1_global
        weights[f"{expert_id}.w2.weight_scale_2"] = 1.0 / w2_sf_global
        weights[f"{expert_id}.w3.weight_scale_2"] = 1.0 / w3_w1_global

    quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4)
    fused_moe = FusedMoE(num_experts=NUM_EXPERTS,
                         top_k=TOP_K,
                         hidden_size=HIDDEN_SIZE,
                         intermediate_size=INTERMEDIATE_SIZE,
                         dtype=dtype,
                         reduce_results=False,
                         model_config=ModelConfig(quant_config=quant_config))
    fused_moe.load_weights([weights])
    fused_moe.cuda()

    assert not fused_moe.has_been_profiled
    with torch.inference_mode():
        output = fused_moe.forward(x, router_logits)
    assert fused_moe.has_been_profiled

    ref_fused_moe = RefGatedMLPFusedMoE(
        num_experts=NUM_EXPERTS,
        top_k=TOP_K,
        hidden_size=HIDDEN_SIZE,
        intermediate_size=INTERMEDIATE_SIZE,
        dtype=dtype,
        model_config=ModelConfig(quant_config=quant_config))
    ref_fused_moe.load_weights([weights])
    ref_fused_moe.cuda()
    with torch.inference_mode():
        ref_output = ref_fused_moe.forward(x, router_logits)

    # compare
    torch.cuda.synchronize()
    torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
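

# Reference implementation: routes tokens with softmax + top-k over the router
# logits and runs one GatedMLP per expert, so the fused kernel output can be
# checked against plain PyTorch modules.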
class RefGatedMLPFusedMoE(nn.Module):

    def __init__(self,
                 num_experts: int,
                 top_k: int,
                 hidden_size: int,
                 intermediate_size: int,
                 dtype: Optional[torch.dtype] = None,
                 model_config: ModelConfig = ModelConfig()):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.dtype = dtype
        self.quant_config = model_config.quant_config
        self.experts = nn.ModuleList([
            GatedMLP(
                hidden_size=self.hidden_size,
                intermediate_size=self.intermediate_size,
                bias=False,
                dtype=self.dtype,
                config=model_config,
            ) for _ in range(self.num_experts)
        ])

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
        assert hidden_states.shape[-1] == self.hidden_size
        hidden_states = hidden_states.view(-1, self.hidden_size)

        # softmax over experts, then keep the top-k weights and renormalize
        routing_weights = nn.functional.softmax(router_logits,
                                                dim=1,
                                                dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights,
                                                       self.top_k,
                                                       dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(hidden_states.shape,
                                          dtype=hidden_states.dtype,
                                          device=hidden_states.device)
        for expert_id in range(self.num_experts):
            batch_idx, nth_expert = torch.where(selected_experts == expert_id)
            expert_inputs = hidden_states[batch_idx]
            output = self.experts[expert_id](expert_inputs)
            final_hidden_states[batch_idx] += routing_weights[batch_idx,
                                                              nth_expert,
                                                              None] * output
        final_hidden_states = final_hidden_states.reshape(hidden_states.shape)
        return final_hidden_states

    def load_weights(self, weights: List[Dict]):
        assert len(weights) == 1
        weights = weights[0]
        for expert in range(self.num_experts):
            gate_up_proj_weights = [{}, {}]
            down_proj_weights = [{}]
            gate_up_proj_weights[0]['weight'] = weights[f"{expert}.w1.weight"]
            gate_up_proj_weights[1]['weight'] = weights[f"{expert}.w3.weight"]
            down_proj_weights[0]['weight'] = weights[f"{expert}.w2.weight"]
            if self.quant_config and self.quant_config.quant_algo == QuantAlgo.FP8:
                gate_up_proj_weights[0]['weight_scale'] = weights[
                    f"{expert}.w1.weight_scale"]
                gate_up_proj_weights[1]['weight_scale'] = weights[
                    f"{expert}.w3.weight_scale"]
                down_proj_weights[0]['weight_scale'] = weights[
                    f"{expert}.w2.weight_scale"]
                gate_up_proj_weights[0]['input_scale'] = weights[
                    f"{expert}.w1.input_scale"]
                gate_up_proj_weights[1]['input_scale'] = weights[
                    f"{expert}.w3.input_scale"]
                down_proj_weights[0]['input_scale'] = weights[
                    f"{expert}.w2.input_scale"]
            elif self.quant_config and self.quant_config.quant_algo == QuantAlgo.NVFP4:
                gate_up_proj_weights[0]['weight_scale'] = weights[
                    f"{expert}.w1.weight_scale"]
                gate_up_proj_weights[1]['weight_scale'] = weights[
                    f"{expert}.w3.weight_scale"]
                down_proj_weights[0]['weight_scale'] = weights[
                    f"{expert}.w2.weight_scale"]
                gate_up_proj_weights[0]['input_scale'] = weights[
                    f"{expert}.w1.input_scale"]
                gate_up_proj_weights[1]['input_scale'] = weights[
                    f"{expert}.w3.input_scale"]
                down_proj_weights[0]['input_scale'] = weights[
                    f"{expert}.w2.input_scale"]
                gate_up_proj_weights[0]['weight_scale_2'] = weights[
                    f"{expert}.w1.weight_scale_2"]
                gate_up_proj_weights[1]['weight_scale_2'] = weights[
                    f"{expert}.w3.weight_scale_2"]
                down_proj_weights[0]['weight_scale_2'] = weights[
                    f"{expert}.w2.weight_scale_2"]
            self.experts[expert].gate_up_proj.load_weights(
                gate_up_proj_weights)
            self.experts[expert].down_proj.load_weights(down_proj_weights)
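

# A minimal way to run these tests locally (assuming a CUDA-enabled
# tensorrt_llm build and a GPU that satisfies the skip markers) is, for
# example:
#   pytest tests/_torch/test_fused_moe.py -k fused_moe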