[OMNIML-2336][feat] add W4A8 NVFP4 FP8 fused moe (#7968)

Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
sychen52 2025-09-30 23:39:33 -07:00 committed by GitHub
parent b77f19f4ff
commit ba8abeab10
6 changed files with 312 additions and 26 deletions

View File

@ -41,6 +41,7 @@ def get_moe_cls(
quant_config.quant_mode.has_fp8_block_scales()
or quant_config.quant_mode.has_nvfp4()
or quant_config.quant_mode.has_w4a16_mxfp4()
or quant_config.quant_mode.has_w4a8_nvfp4_fp8()
or quant_config.quant_mode.has_w4a8_mxfp4_fp8()
or quant_config.quant_mode.has_w4a8_mxfp4_mxfp8()):
return TRTLLMGenFusedMoE
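
In practice this branch is reached through the model's quantization config. A minimal sketch of the intended dispatch, not taken from the commit (import paths are assumptions; the enum value and the has_w4a8_nvfp4_fp8() query both appear in this diff and in the unit test further down):

# Sketch only: exercise the new quant-mode query that get_moe_cls now checks.
# Import paths are assumed; QuantAlgo.W4A8_NVFP4_FP8 is used by the test below.
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo

quant_config = QuantConfig(quant_algo=QuantAlgo.W4A8_NVFP4_FP8)
assert quant_config.quant_mode.has_w4a8_nvfp4_fp8()
# With this mode set, get_moe_cls falls into the branch above and returns TRTLLMGenFusedMoE.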

View File

@ -15,6 +15,7 @@ from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
NVFP4TRTLLMGenFusedMoEMethod,
W4A8MXFP4FP8TRTLLMGenFusedMoEMethod,
W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod,
W4A8NVFP4FP8TRTLLMGenFusedMoEMethod,
W4A16MXFP4TRTLLMGenFusedMoEMethod)
from .routing import BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod
@ -111,7 +112,7 @@ class TRTLLMGenFusedMoE(MoE):
def _check_configs(self):
assert self.has_deepseek_fp8_block_scales \
- or self.has_nvfp4 or self.has_w4a16_mxfp4 \
+ or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \
or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes."
if self.bias or self.swiglu_alpha is not None or self.swiglu_beta is not None or self.swiglu_limit is not None:
@ -125,6 +126,8 @@ class TRTLLMGenFusedMoE(MoE):
return NVFP4TRTLLMGenFusedMoEMethod()
elif self.quant_config.layer_quant_mode.has_w4a16_mxfp4():
return W4A16MXFP4TRTLLMGenFusedMoEMethod()
elif self.quant_config.layer_quant_mode.has_w4a8_nvfp4_fp8():
return W4A8NVFP4FP8TRTLLMGenFusedMoEMethod()
elif self.quant_config.layer_quant_mode.has_w4a8_mxfp4_fp8():
return W4A8MXFP4FP8TRTLLMGenFusedMoEMethod()
elif self.quant_config.layer_quant_mode.has_w4a8_mxfp4_mxfp8():
@ -147,8 +150,8 @@ class TRTLLMGenFusedMoE(MoE):
self._weights_created = True
self._check_configs()
# TODO: FIX this.
- if (self.has_w4a16_mxfp4 or self.has_w4a8_mxfp4_fp8
+ if (self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8
+ or self.has_w4a8_mxfp4_fp8
or self.has_w4a8_mxfp4_mxfp8) and not self.bias:
self.w3_w1_bias = nn.Parameter(torch.zeros(
(self.w3_w1_weight.shape[0], self.w3_w1_weight.shape[1]),
@ -378,6 +381,46 @@ class TRTLLMGenFusedMoE(MoE):
)
final_hidden_states = final_hidden_states[:, :self.
hidden_size].contiguous()
elif self.has_w4a8_nvfp4_fp8:
if not run_post_quant_allgather:
hidden_states_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
x, 1.0 / self.fc31_input_scale)
else:
hidden_states_fp8 = x
outputs = torch.ops.trtllm.fp8_fp4_block_scale_moe_runner(
router_logits,
routing_bias,
hidden_states_fp8,
self.w3_w1_weight,
self.w3_w1_weight_scale.view(torch.float8_e4m3fn),
self.w2_weight,
self.w2_weight_scale.view(torch.float8_e4m3fn),
self.fc31_scale_c.data,
self.fc31_alpha.data,
self.fc2_alpha.data,
self.num_slots,
top_k,
n_group,
topk_group,
self.intermediate_size_per_partition,
self.
slot_start, # local_expert_start; use ep_rank if stride!=1
self.expert_size_per_partition, # local_expert_size
routed_scaling_factor,
self.routing_method.routing_method_type,
do_finalize=do_finalize,
act_type=0,
topk_ids=token_selected_experts,
topk_weights=token_final_scales,
)
if not do_finalize:
assert not self.reduce_results, "reduce_results must be False when do_finalize is False"
return outputs
else:
final_hidden_states = outputs[0]
elif self.has_w4a8_mxfp4_fp8:
pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
if not run_post_quant_allgather:
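
When run_post_quant_allgather is false, the activations are quantized to FP8 per tensor right before the fused kernel, with 1.0 / self.fc31_input_scale as the quantization scale. A minimal sketch of what that per-tensor E4M3 quantization amounts to, assuming multiply-by-scale-and-saturate semantics for the op (the real op is a fused CUDA kernel and also handles the scale bookkeeping):

# Sketch only: assumed semantics of torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor
# (scale, saturate to the E4M3 range, cast).
import torch

def quantize_e4m3_per_tensor(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    return (x.float() * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)

# In the branch above, scale = 1.0 / fc31_input_scale, where fc31_input_scale comes from
# the checkpoint's per-expert input_scale entries (stored as amax / 448 in the test below).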

View File

@ -301,6 +301,12 @@ class MoE(nn.Module):
return self.quant_config is not None and self.quant_config.layer_quant_mode.has_nvfp4(
)
@property
def has_w4a8_nvfp4_fp8(self):
assert self._weights_created
return self.quant_config is not None and self.quant_config.layer_quant_mode.has_w4a8_nvfp4_fp8(
)
@property
def has_w4a8_mxfp4_fp8(self):
assert self._weights_created

View File

@ -96,7 +96,7 @@ def trtllmgen_maybe_get_cached_w3_w1_permute_indices(
torch.Tensor],
epilogue_tile_m: int,
num_elts_per_sf: Union[None, int] = None) -> torch.Tensor:
- key = (dst_w3_w1_weight.shape, "w31")
+ key = (dst_w3_w1_weight.shape, "w31", int(num_elts_per_sf or -1))
if key not in cache_permute_indices:
# Get permute indices and chain them together
permute0 = get_reorder_rows_for_gated_act_gemm_row_indices(
@ -122,7 +122,7 @@ def trtllmgen_maybe_get_cached_w2_permute_indices(
torch.Tensor],
epilogue_tile_m: int,
num_elts_per_sf: Union[None, int] = None) -> torch.Tensor:
- key = (dst_w2_weight.shape, "w2")
+ key = (dst_w2_weight.shape, "w2", int(num_elts_per_sf or -1))
if key not in cache_permute_indices:
if num_elts_per_sf is None:
permute_indices = (get_shuffle_matrix_a_row_indices(
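
Both permute-index caches above now fold num_elts_per_sf into their key, so scale-factor layouts computed for 16-element NVFP4 blocks and for the 32-element blocks used by the W4A8 NVFP4 FP8 path no longer collide on tensor shape alone. A minimal sketch of the keying behavior (names hypothetical, not the real cache):

# Sketch only: two granularities with the same tensor shape must get separate cache entries.
cache_permute_indices = {}

def lookup(shape, tag, num_elts_per_sf=None):
    key = (shape, tag, int(num_elts_per_sf or -1))  # -1 stands in for "no block scaling"
    if key not in cache_permute_indices:
        cache_permute_indices[key] = ("indices for", key)  # placeholder for the real computation
    return cache_permute_indices[key]

a = lookup((256, 128), "w31", num_elts_per_sf=16)  # plain NVFP4
b = lookup((256, 128), "w31", num_elts_per_sf=32)  # W4A8 NVFP4 FP8
assert a != b  # before this change both calls would have shared one cached result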
@ -1478,11 +1478,15 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
Base class for NVFP4 fused MoE methods for all backends.
"""
- def create_weights(self, module: torch.nn.Module, weight_dtype,
- weight_vec_size, block_scales_dtype,
- block_scales_vec_size):
+ def create_weights(self,
+ module: torch.nn.Module,
+ weight_dtype,
+ weight_vec_size,
+ block_scales_dtype,
+ block_scales_vec_size,
+ scaling_vector_size=16):
- module.scaling_vector_size = 16
+ module.scaling_vector_size = scaling_vector_size
# Divide by 16 because we use int64 to pack 16 fp4 values
w3_w1_weight_shape = (module.expert_size_per_partition,
module.intermediate_size_per_partition * 2,
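
create_weights now accepts scaling_vector_size so a subclass can request a different scale-block granularity: plain NVFP4 keeps the default of 16 values per scale factor, while the W4A8 NVFP4 FP8 method below passes 32. A rough arithmetic check under that assumption, using the sizes from the unit test further down:

# Sketch only: number of E4M3 block scale factors per weight row along the K dimension.
HIDDEN_SIZE = 512               # as in test_fused_moe_w4a8_nvfp4_fp8 below
assert HIDDEN_SIZE // 16 == 32  # NVFP4 default: 32 scale factors per row
assert HIDDEN_SIZE // 32 == 16  # W4A8 NVFP4 FP8 (scaling_vector_size=32): 16 per row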
@ -1893,9 +1897,12 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
non_blocking=True)
def load_expert_w3_w1_weight_scale_nvfp4(
- self, module: torch.nn.Module, w1_weight_scale: torch.Tensor,
+ self,
+ module: torch.nn.Module,
+ w1_weight_scale: torch.Tensor,
w3_weight_scale: torch.Tensor,
- dst_w3_w1_weight_scale: torch.Tensor):
+ dst_w3_w1_weight_scale: torch.Tensor,
+ num_elts_per_sf: int = 16):
device = dst_w3_w1_weight_scale.device
assert device.type == "cuda"
w1_weight_scale = load_weight_shard(w1_weight_scale,
@ -1933,7 +1940,7 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
dst_w3_w1_weight_scale.view(float4_sf_dtype),
self._cache_permute_indices,
epilogue_tile_m,
- num_elts_per_sf=16)
+ num_elts_per_sf=num_elts_per_sf)
# Shuffle the weight according to permute indices
w3_w1_weight_scale = torch.ops.trtllm.shuffle_matrix(
@ -1949,9 +1956,11 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
processed_w3_w1_weight_scale.view(
self.block_scales_dtype).reshape(orig_shape))
- def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
+ def load_expert_w2_weight_scale_nvfp4(self,
+ module: torch.nn.Module,
w2_weight_scale: torch.Tensor,
- dst_w2_weight_scale: torch.Tensor):
+ dst_w2_weight_scale: torch.Tensor,
+ num_elts_per_sf: int = 16):
device = dst_w2_weight_scale.device
assert device.type == "cuda"
w2_weight_scale = load_weight_shard(w2_weight_scale,
@ -1976,7 +1985,7 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
dst_w2_weight_scale.view(float4_sf_dtype),
self._cache_permute_indices,
epilogue_tile_m,
- num_elts_per_sf=16)
+ num_elts_per_sf=num_elts_per_sf)
# Shuffle the weight according to permute indices
w_shuffled = torch.ops.trtllm.shuffle_matrix(
@ -1998,6 +2007,56 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
non_blocking=True)
class W4A8NVFP4FP8TRTLLMGenFusedMoEMethod(NVFP4TRTLLMGenFusedMoEMethod):
def create_weights(self, module: torch.nn.Module):
weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
block_scales_vec_size = 1
NVFP4FusedMoEMethod.create_weights(self, module, self.weight_dtype,
weight_vec_size,
self.block_scales_dtype,
block_scales_vec_size, 32)
fc31_scale_c = nn.Parameter(torch.ones(module.expert_size_per_partition,
dtype=torch.float32),
requires_grad=False)
module.register_parameter("fc31_scale_c", fc31_scale_c)
self.setup_quant_scales(module)
def load_expert_w3_w1_weight_scale_nvfp4(
self, module: torch.nn.Module, w1_weight_scale: torch.Tensor,
w3_weight_scale: torch.Tensor,
dst_w3_w1_weight_scale: torch.Tensor):
return super().load_expert_w3_w1_weight_scale_nvfp4(
module, w1_weight_scale, w3_weight_scale, dst_w3_w1_weight_scale,
32)
def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
w2_weight_scale: torch.Tensor,
dst_w2_weight_scale: torch.Tensor):
return super().load_expert_w2_weight_scale_nvfp4(
module, w2_weight_scale, dst_w2_weight_scale, 32)
def load_all_fp4_weight_scales_and_alphas(
self, module: torch.nn.Module, weights: Dict,
load_expert_ids: List[int], dst_w3_w1_weight_scale: torch.Tensor,
dst_w2_weight_scale: torch.Tensor, dst_fc31_alpha: torch.Tensor,
dst_fc2_alpha: torch.Tensor):
super().load_all_fp4_weight_scales_and_alphas(
module, weights, load_expert_ids, dst_w3_w1_weight_scale,
dst_w2_weight_scale, dst_fc31_alpha, dst_fc2_alpha)
# The kernel we use will convert nvfp4 to e4m3 before matmul,
# so the range of the scale factor can only be [0,448/6].
dst_w3_w1_weight_scale.copy_((dst_w3_w1_weight_scale.to(torch.float32) /
6.0).to(torch.float8_e4m3fn))
dst_w2_weight_scale.copy_((dst_w2_weight_scale.to(torch.float32) /
6.0).to(torch.float8_e4m3fn))
dst_fc31_alpha.copy_(dst_fc31_alpha * 6.0)
dst_fc2_alpha.copy_(dst_fc2_alpha * 6.0)
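
The comment above is the key numerical detail of this path: because the kernel up-converts the NVFP4 (E2M1) weights to E4M3 before the matmul, the per-block scale factors must fit in [0, 448/6], so the loader divides them by 6 and compensates by multiplying the per-expert alphas by 6, which leaves the dequantized product unchanged (up to the extra E4M3 rounding of the rescaled scales). A minimal numeric check of that identity, not part of the commit:

# Sketch only: folding 6 out of the block scales and into alpha is value-preserving.
import torch

FP4_MAX = 6.0    # largest E2M1 magnitude
FP8_MAX = 448.0  # largest E4M3 magnitude

sf = torch.rand(8) * FP8_MAX  # per-block scale factors as loaded (E4M3 range)
alpha = torch.rand(8)         # per-expert dequantization alphas

assert torch.allclose(sf * alpha, (sf / FP4_MAX) * (alpha * FP4_MAX))
assert (sf / FP4_MAX).max() <= FP8_MAX / FP4_MAX  # rescaled block scales fit [0, 448/6]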
def _get_weight_alignment(weight_alignment, scaling_vector_size, tp_size,
shard_dim_size):

View File

@ -31,8 +31,9 @@ from tensorrt_llm._torch.modules.fused_moe.interface import MoEWeightLoadingMode
# isort and yapf will fight against each other here, so we disable isort
# isort: off
from tensorrt_llm._torch.modules.fused_moe import (
- BaseMoeRoutingMethod, CutlassFusedMoE, DefaultMoeRoutingMethod,
- RenormalizeMoeRoutingMethod, TritonFusedMoE, create_moe, WideEPMoE)
+ BaseMoeRoutingMethod, CutlassFusedMoE, TRTLLMGenFusedMoE,
+ DefaultMoeRoutingMethod, RenormalizeMoeRoutingMethod, TritonFusedMoE,
+ create_moe, WideEPMoE)
# isort: on
from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
IS_TRITON_KERNELS_AVAILABLE
@ -1423,6 +1424,131 @@ def test_fused_moe_nvfp4(dtype):
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.15)
@skip_pre_blackwell
def test_fused_moe_w4a8_nvfp4_fp8():
dtype = torch.bfloat16
mapping = Mapping()
mapping.rank = mpi_rank()
with torch.device(f'cuda:{mapping.rank}'):
SCALING_VECTOR_SIZE = 32
SEQ_LEN = 4
HIDDEN_SIZE = 512
INTERMEDIATE_SIZE = 512
NUM_EXPERTS = 4
TOP_K = 2
routing_method = RenormalizeMoeRoutingMethod(top_k=TOP_K)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype, device="cuda")
x_sf_global = 448 / x.abs().max().float()
router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS),
dtype=dtype,
device="cuda")
weights = {}
for expert_id in range(NUM_EXPERTS):
w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
dtype=torch.float32,
device="cpu")
w1_sf_global = (448) / w1_weight.abs().max().float()
w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE),
dtype=torch.float32,
device="cpu")
w2_sf_global = (448) / w2_weight.abs().max().float()
w3_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
dtype=torch.float32,
device="cpu")
w3_sf_global = (448) / w3_weight.abs().max().float()
w3_w1_global = min(
w1_sf_global,
w3_sf_global) # w3 global and w1 global must be the same
w1_weight_nvfp4, w1_sf_block, _ = torch.ops.tensorrt_llm.float_to_e2m1_and_ufp8sf_scale(
w1_weight * w3_w1_global, SCALING_VECTOR_SIZE, 1, False)
w1_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
w1_sf_block.view(INTERMEDIATE_SIZE, -1))
w2_weight_nvfp4, w2_sf_block, _ = torch.ops.tensorrt_llm.float_to_e2m1_and_ufp8sf_scale(
w2_weight * w2_sf_global, SCALING_VECTOR_SIZE, 1, False)
w2_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
w2_sf_block.view(HIDDEN_SIZE, -1))
w3_weight_nvfp4, w3_sf_block, _ = torch.ops.tensorrt_llm.float_to_e2m1_and_ufp8sf_scale(
w3_weight * w3_w1_global, SCALING_VECTOR_SIZE, 1, False)
w3_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
w3_sf_block.view(INTERMEDIATE_SIZE, -1))
w1_weight_nvfp4 = w1_weight_nvfp4.cuda()
w1_sf_block_unswizzled = w1_sf_block_unswizzled.cuda()
w2_weight_nvfp4 = w2_weight_nvfp4.cuda()
w2_sf_block_unswizzled = w2_sf_block_unswizzled.cuda()
w3_weight_nvfp4 = w3_weight_nvfp4.cuda()
w3_sf_block_unswizzled = w3_sf_block_unswizzled.cuda()
w1_input_scale = x_sf_global.cuda()
w2_input_scale = x_sf_global.cuda()
w3_input_scale = x_sf_global.cuda()
weights[f"{expert_id}.w1.weight"] = w1_weight_nvfp4
weights[f"{expert_id}.w2.weight"] = w2_weight_nvfp4
weights[f"{expert_id}.w3.weight"] = w3_weight_nvfp4
weights[
f"{expert_id}.w1.weight_scale"] = w1_sf_block_unswizzled.view(
torch.float8_e4m3fn).cuda()
weights[
f"{expert_id}.w2.weight_scale"] = w2_sf_block_unswizzled.view(
torch.float8_e4m3fn).cuda()
weights[
f"{expert_id}.w3.weight_scale"] = w3_sf_block_unswizzled.view(
torch.float8_e4m3fn).cuda()
weights[f"{expert_id}.w1.input_scale"] = 1.0 / w1_input_scale
weights[f"{expert_id}.w2.input_scale"] = 1.0 / w2_input_scale
weights[f"{expert_id}.w3.input_scale"] = 1.0 / w3_input_scale
weights[f"{expert_id}.w1.weight_scale_2"] = 1.0 / w3_w1_global
weights[f"{expert_id}.w2.weight_scale_2"] = 1.0 / w2_sf_global
weights[f"{expert_id}.w3.weight_scale_2"] = 1.0 / w3_w1_global
quant_config = QuantConfig(quant_algo=QuantAlgo.W4A8_NVFP4_FP8)
fused_moe = TRTLLMGenFusedMoE(
num_experts=NUM_EXPERTS,
routing_method=routing_method,
hidden_size=HIDDEN_SIZE,
intermediate_size=INTERMEDIATE_SIZE,
dtype=dtype,
reduce_results=False,
model_config=ModelConfig(quant_config=quant_config))
fused_moe.load_weights([weights])
fused_moe.cuda()
# Evaluate the outputs on a variant sequence length to cover all possible keys in Autotuner cache
ref_fused_moe = RefGatedMLPFusedMoE(
num_experts=NUM_EXPERTS,
routing_method=routing_method,
hidden_size=HIDDEN_SIZE,
intermediate_size=INTERMEDIATE_SIZE,
dtype=dtype,
model_config=ModelConfig(quant_config=quant_config))
ref_fused_moe.load_weights([weights])
ref_fused_moe.cuda()
AutoTuner.get().clear_cache()
with torch.inference_mode(), autotune():
fused_moe.forward(x, router_logits)
with torch.inference_mode():
output = fused_moe.forward(x, router_logits)
ref_output = ref_fused_moe.forward(x, router_logits)
# compare
torch.cuda.synchronize()
torch.testing.assert_close(output, ref_output, rtol=1e-1, atol=0.5)
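
The test follows the usual two-level NVFP4 scaling scheme: a per-tensor global scale of 448 / amax, per-block E4M3 scale factors, and checkpoint entries (weight_scale_2, input_scale) that store the reciprocals of the global scales. A minimal sketch of the dequantization identity this relies on, with the E2M1 packing simplified to plain floats so the round trip is exact:

# Sketch only: two-level NVFP4 scaling for one 32-element block (SCALING_VECTOR_SIZE = 32).
import torch

w = torch.randn(32)                    # one block of a weight row
global_scale = 448.0 / w.abs().max()   # analogous to w3_w1_global / w2_sf_global above
w_scaled = w * global_scale
sf_block = w_scaled.abs().max() / 6.0  # one E4M3 scale factor for the block
codes = w_scaled / sf_block            # values now within the E2M1 range [-6, 6]

# Dequantize with weight_scale_2 = 1 / global_scale, as stored in the weights dict above.
w_rec = codes * sf_block * (1.0 / global_scale)
assert torch.allclose(w, w_rec, atol=1e-5)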
@skip_neither_ada_nor_hopper_unittest
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
@ -1689,7 +1815,7 @@ def test_fused_moe_w4afp8(dtype, weight_loading_mode):
@skip_pre_blackwell
@pytest.mark.parametrize("moe_backend", ["TRTLLM", "CUTLASS"])
@pytest.mark.parametrize("bias", [True, False])
- def test_fused_moe_mxfp4_mxpf8(moe_backend, bias):
+ def test_fused_moe_mxfp4_mxfp8(moe_backend, bias):
SCALING_VECTOR_SIZE = 32
dtype = torch.bfloat16
SEQ_LEN = 128
@ -2262,7 +2388,8 @@ class RefGatedMLPFusedMoE(nn.Module):
f"{expert}.w3.input_scale"]
down_proj_weights[0]['input_scale'] = weights[
f"{expert}.w2.input_scale"]
- elif self.quant_config and self.quant_config.quant_algo == QuantAlgo.NVFP4:
+ elif self.quant_config and self.quant_config.quant_algo in (
+ QuantAlgo.NVFP4, QuantAlgo.W4A8_NVFP4_FP8):
gate_up_proj_weights[0]['weight_scale'] = weights[
f"{expert}.w1.weight_scale"]
gate_up_proj_weights[1]['weight_scale'] = weights[

View File

@ -1069,13 +1069,36 @@ class TestMoeFp4:
routing_info,
use_autotune=True,
use_topk_as_input=False)
if intermediate_size >= 256:
self.run_moe_fp8_fp4_test(num_tokens,
hidden_size,
intermediate_size,
routing_info,
use_autotune=True,
use_topk_as_input=False)
@pytest.mark.parametrize("num_tokens", [1])
@pytest.mark.parametrize("hidden_size", [1024])
@pytest.mark.parametrize("intermediate_size", [384])
@pytest.mark.parametrize(
"routing_info",
[
pytest.param(
{
"num_experts": 72,
"top_k": 6,
"padding": 8,
"n_groups": 1,
"top_k_groups": 1,
"routed_scaling": 2.5,
"has_routing_bias": True,
"routing_method_type": RoutingMethodType.DeepSeekV3
},
id="RoutingDSlite"),
],
)
def test_autotune_fp8_fp4(self, num_tokens, hidden_size, intermediate_size,
routing_info):
self.run_moe_fp8_fp4_test(num_tokens,
hidden_size,
intermediate_size,
routing_info,
use_autotune=True,
use_topk_as_input=False)
@pytest.mark.parametrize("num_tokens", [1, 150])
@pytest.mark.parametrize("hidden_size", [1024])
@ -1120,6 +1143,33 @@ class TestMoeFp4:
routing_info,
use_autotune=False,
use_topk_as_input=use_topk_as_input)
@pytest.mark.parametrize("num_tokens", [1])
@pytest.mark.parametrize("hidden_size", [1024])
@pytest.mark.parametrize("intermediate_size", [1024])
@pytest.mark.parametrize(
"routing_info",
[
pytest.param(
{
"num_experts": 128,
"top_k": 4,
"padding": 8,
"n_groups": None,
"top_k_groups": None,
"routed_scaling": None,
"has_routing_bias": False,
"routing_method_type": RoutingMethodType.Renormalize
},
id="RoutingRenormalize_topk_4"),
],
)
@pytest.mark.parametrize("use_topk_as_input", [False, True],
ids=["use_score_as_input", "use_topk_as_input"])
def test_no_autotune_fp8_fp4(self, num_tokens, hidden_size,
intermediate_size, routing_info,
use_topk_as_input):
self.run_moe_fp8_fp4_test(num_tokens,
hidden_size,
intermediate_size,