Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[OMNIML-2336][feat] add W4A8 NVFP4 FP8 fused moe (#7968)
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
commit ba8abeab10 (parent b77f19f4ff)
@@ -41,6 +41,7 @@ def get_moe_cls(
             quant_config.quant_mode.has_fp8_block_scales()
             or quant_config.quant_mode.has_nvfp4()
             or quant_config.quant_mode.has_w4a16_mxfp4()
+            or quant_config.quant_mode.has_w4a8_nvfp4_fp8()
             or quant_config.quant_mode.has_w4a8_mxfp4_fp8()
             or quant_config.quant_mode.has_w4a8_mxfp4_mxfp8()):
         return TRTLLMGenFusedMoE
@@ -15,6 +15,7 @@ from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
                            NVFP4TRTLLMGenFusedMoEMethod,
                            W4A8MXFP4FP8TRTLLMGenFusedMoEMethod,
                            W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod,
+                           W4A8NVFP4FP8TRTLLMGenFusedMoEMethod,
                            W4A16MXFP4TRTLLMGenFusedMoEMethod)
 from .routing import BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod
@@ -111,7 +112,7 @@ class TRTLLMGenFusedMoE(MoE):

     def _check_configs(self):
         assert self.has_deepseek_fp8_block_scales \
-            or self.has_nvfp4 or self.has_w4a16_mxfp4 \
+            or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \
             or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes."

         if self.bias or self.swiglu_alpha is not None or self.swiglu_beta is not None or self.swiglu_limit is not None:
@@ -125,6 +126,8 @@ class TRTLLMGenFusedMoE(MoE):
             return NVFP4TRTLLMGenFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_w4a16_mxfp4():
             return W4A16MXFP4TRTLLMGenFusedMoEMethod()
+        elif self.quant_config.layer_quant_mode.has_w4a8_nvfp4_fp8():
+            return W4A8NVFP4FP8TRTLLMGenFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_w4a8_mxfp4_fp8():
            return W4A8MXFP4FP8TRTLLMGenFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_w4a8_mxfp4_mxfp8():
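For orientation, here is a minimal usage sketch of how the new mode is reached; it is not part of the diff, it simply mirrors the unit test added later in this commit, and the QuantConfig/QuantAlgo import paths are assumptions based on where those classes normally live in TensorRT-LLM.

# Hedged sketch: selecting W4A8_NVFP4_FP8 routes MoE creation to TRTLLMGenFusedMoE
# (see get_moe_cls above) and to W4A8NVFP4FP8TRTLLMGenFusedMoEMethod in _get_quant_method.
from tensorrt_llm.models.modeling_utils import QuantConfig  # assumed import path
from tensorrt_llm.quantization.mode import QuantAlgo  # assumed import path

quant_config = QuantConfig(quant_algo=QuantAlgo.W4A8_NVFP4_FP8)
assert quant_config.quant_mode.has_w4a8_nvfp4_fp8()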
@@ -147,8 +150,8 @@ class TRTLLMGenFusedMoE(MoE):
         self._weights_created = True
         self._check_configs()

         # TODO: FIX this.
-        if (self.has_w4a16_mxfp4 or self.has_w4a8_mxfp4_fp8
+        if (self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8
+                or self.has_w4a8_mxfp4_fp8
                 or self.has_w4a8_mxfp4_mxfp8) and not self.bias:
             self.w3_w1_bias = nn.Parameter(torch.zeros(
                 (self.w3_w1_weight.shape[0], self.w3_w1_weight.shape[1]),
@@ -378,6 +381,46 @@ class TRTLLMGenFusedMoE(MoE):
             )
             final_hidden_states = final_hidden_states[:, :self.hidden_size].contiguous()
+        elif self.has_w4a8_nvfp4_fp8:
+            if not run_post_quant_allgather:
+                hidden_states_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
+                    x, 1.0 / self.fc31_input_scale)
+            else:
+                hidden_states_fp8 = x
+
+            outputs = torch.ops.trtllm.fp8_fp4_block_scale_moe_runner(
+                router_logits,
+                routing_bias,
+                hidden_states_fp8,
+                self.w3_w1_weight,
+                self.w3_w1_weight_scale.view(torch.float8_e4m3fn),
+                self.w2_weight,
+                self.w2_weight_scale.view(torch.float8_e4m3fn),
+                self.fc31_scale_c.data,
+                self.fc31_alpha.data,
+                self.fc2_alpha.data,
+                self.num_slots,
+                top_k,
+                n_group,
+                topk_group,
+                self.intermediate_size_per_partition,
+                self.slot_start,  # local_expert_start; use ep_rank if stride!=1
+                self.expert_size_per_partition,  # local_expert_size
+                routed_scaling_factor,
+                self.routing_method.routing_method_type,
+                do_finalize=do_finalize,
+                act_type=0,
+                topk_ids=token_selected_experts,
+                topk_weights=token_final_scales,
+            )
+
+            if not do_finalize:
+                assert not self.reduce_results, "reduce_results must be False when do_finalize is False"
+                return outputs
+            else:
+                final_hidden_states = outputs[0]
         elif self.has_w4a8_mxfp4_fp8:
             pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
             if not run_post_quant_allgather:
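The activation side of this new branch is a per-tensor static FP8 quantization. As a rough mental model, the call above behaves roughly like the following hedged stand-in (the real op lives in the TensorRT-LLM C++ extension and its exact rounding and return values may differ):

import torch

def static_quantize_e4m3_per_tensor_ref(x: torch.Tensor, scale: torch.Tensor):
    # scale corresponds to 1.0 / fc31_input_scale in the branch above;
    # e4m3 represents magnitudes up to 448, so clamp before the cast.
    x_scaled = (x.float() * scale).clamp(-448.0, 448.0)
    return x_scaled.to(torch.float8_e4m3fn), scale

x = torch.randn(4, 512)
input_scale = x.abs().float().max() / 448.0  # per-tensor calibration scale
x_fp8, _ = static_quantize_e4m3_per_tensor_ref(x, 1.0 / input_scale)
assert x_fp8.dtype == torch.float8_e4m3fn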
@@ -301,6 +301,12 @@ class MoE(nn.Module):
         return self.quant_config is not None and self.quant_config.layer_quant_mode.has_nvfp4(
         )

+    @property
+    def has_w4a8_nvfp4_fp8(self):
+        assert self._weights_created
+        return self.quant_config is not None and self.quant_config.layer_quant_mode.has_w4a8_nvfp4_fp8(
+        )
+
     @property
     def has_w4a8_mxfp4_fp8(self):
         assert self._weights_created
@@ -96,7 +96,7 @@ def trtllmgen_maybe_get_cached_w3_w1_permute_indices(
                                               torch.Tensor],
         epilogue_tile_m: int,
         num_elts_per_sf: Union[None, int] = None) -> torch.Tensor:
-    key = (dst_w3_w1_weight.shape, "w31")
+    key = (dst_w3_w1_weight.shape, "w31", int(num_elts_per_sf or -1))
     if key not in cache_permute_indices:
         # Get permute indices and chain them together
         permute0 = get_reorder_rows_for_gated_act_gemm_row_indices(
@@ -122,7 +122,7 @@ def trtllmgen_maybe_get_cached_w2_permute_indices(
                                               torch.Tensor],
         epilogue_tile_m: int,
         num_elts_per_sf: Union[None, int] = None) -> torch.Tensor:
-    key = (dst_w2_weight.shape, "w2")
+    key = (dst_w2_weight.shape, "w2", int(num_elts_per_sf or -1))
     if key not in cache_permute_indices:
         if num_elts_per_sf is None:
             permute_indices = (get_shuffle_matrix_a_row_indices(
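The only change in these two helpers is that num_elts_per_sf now participates in the cache key, so permute indices computed for 16-element scale-factor blocks (plain NVFP4) can no longer be reused by mistake for the 32-element blocks the new W4A8 NVFP4 FP8 path requests. A tiny illustration of the key expression (shape value is made up):

# Before: both layouts collapsed onto (shape, "w2"); after: they stay distinct.
shape = (256, 64)  # hypothetical dst_w2_weight.shape
old_key = (shape, "w2")
new_key_16 = (shape, "w2", int(16 or -1))      # -> (shape, "w2", 16)
new_key_32 = (shape, "w2", int(32 or -1))      # -> (shape, "w2", 32)
new_key_none = (shape, "w2", int(None or -1))  # -> (shape, "w2", -1)
assert new_key_16 != new_key_32 != new_key_none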
@@ -1478,11 +1478,15 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
     Base class for NVFP4 fused MoE methods for all backends.
     """

-    def create_weights(self, module: torch.nn.Module, weight_dtype,
-                       weight_vec_size, block_scales_dtype,
-                       block_scales_vec_size):
+    def create_weights(self,
+                       module: torch.nn.Module,
+                       weight_dtype,
+                       weight_vec_size,
+                       block_scales_dtype,
+                       block_scales_vec_size,
+                       scaling_vector_size=16):

-        module.scaling_vector_size = 16
+        module.scaling_vector_size = scaling_vector_size
         # Divide by 16 because we use int64 to pack 16 fp4 values
         w3_w1_weight_shape = (module.expert_size_per_partition,
                               module.intermediate_size_per_partition * 2,
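The packing comment in this hunk is worth a one-line sanity check: weights are stored in packed integer containers, and the vector size is just the container bit-width divided by 4. The new W4A8 class added further down computes the same quantity from self.weight_dtype; assuming a byte-wide container there (an assumption, not stated in the diff), the formula gives 2 FP4 values per element.

import torch
# FP4 values packed per container element = container bits // 4.
assert torch.iinfo(torch.int64).bits // 4 == 16  # the int64 packing noted in the comment
assert torch.iinfo(torch.uint8).bits // 4 == 2   # a byte-wide container would hold two FP4 values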
@@ -1893,9 +1897,12 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
                                       non_blocking=True)

     def load_expert_w3_w1_weight_scale_nvfp4(
-            self, module: torch.nn.Module, w1_weight_scale: torch.Tensor,
+            self,
+            module: torch.nn.Module,
+            w1_weight_scale: torch.Tensor,
             w3_weight_scale: torch.Tensor,
-            dst_w3_w1_weight_scale: torch.Tensor):
+            dst_w3_w1_weight_scale: torch.Tensor,
+            num_elts_per_sf: int = 16):
         device = dst_w3_w1_weight_scale.device
         assert device.type == "cuda"
         w1_weight_scale = load_weight_shard(w1_weight_scale,
@@ -1933,7 +1940,7 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
             dst_w3_w1_weight_scale.view(float4_sf_dtype),
             self._cache_permute_indices,
             epilogue_tile_m,
-            num_elts_per_sf=16)
+            num_elts_per_sf=num_elts_per_sf)

         # Shuffle the weight according to permute indices
         w3_w1_weight_scale = torch.ops.trtllm.shuffle_matrix(
@@ -1949,9 +1956,11 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
             processed_w3_w1_weight_scale.view(
                 self.block_scales_dtype).reshape(orig_shape))

-    def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
+    def load_expert_w2_weight_scale_nvfp4(self,
+                                          module: torch.nn.Module,
                                           w2_weight_scale: torch.Tensor,
-                                          dst_w2_weight_scale: torch.Tensor):
+                                          dst_w2_weight_scale: torch.Tensor,
+                                          num_elts_per_sf: int = 16):
         device = dst_w2_weight_scale.device
         assert device.type == "cuda"
         w2_weight_scale = load_weight_shard(w2_weight_scale,
@@ -1976,7 +1985,7 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
             dst_w2_weight_scale.view(float4_sf_dtype),
             self._cache_permute_indices,
             epilogue_tile_m,
-            num_elts_per_sf=16)
+            num_elts_per_sf=num_elts_per_sf)

         # Shuffle the weight according to permute indices
         w_shuffled = torch.ops.trtllm.shuffle_matrix(
@@ -1998,6 +2007,56 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
                                       non_blocking=True)


+class W4A8NVFP4FP8TRTLLMGenFusedMoEMethod(NVFP4TRTLLMGenFusedMoEMethod):
+
+    def create_weights(self, module: torch.nn.Module):
+        weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
+        block_scales_vec_size = 1
+
+        NVFP4FusedMoEMethod.create_weights(self, module, self.weight_dtype,
+                                           weight_vec_size,
+                                           self.block_scales_dtype,
+                                           block_scales_vec_size, 32)
+
+        fc31_scale_c = nn.Parameter(torch.ones(module.expert_size_per_partition,
+                                               dtype=torch.float32),
+                                    requires_grad=False)
+        module.register_parameter("fc31_scale_c", fc31_scale_c)
+
+        self.setup_quant_scales(module)
+
+    def load_expert_w3_w1_weight_scale_nvfp4(
+            self, module: torch.nn.Module, w1_weight_scale: torch.Tensor,
+            w3_weight_scale: torch.Tensor,
+            dst_w3_w1_weight_scale: torch.Tensor):
+        return super().load_expert_w3_w1_weight_scale_nvfp4(
+            module, w1_weight_scale, w3_weight_scale, dst_w3_w1_weight_scale,
+            32)
+
+    def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
+                                          w2_weight_scale: torch.Tensor,
+                                          dst_w2_weight_scale: torch.Tensor):
+        return super().load_expert_w2_weight_scale_nvfp4(
+            module, w2_weight_scale, dst_w2_weight_scale, 32)
+
+    def load_all_fp4_weight_scales_and_alphas(
+            self, module: torch.nn.Module, weights: Dict,
+            load_expert_ids: List[int], dst_w3_w1_weight_scale: torch.Tensor,
+            dst_w2_weight_scale: torch.Tensor, dst_fc31_alpha: torch.Tensor,
+            dst_fc2_alpha: torch.Tensor):
+        super().load_all_fp4_weight_scales_and_alphas(
+            module, weights, load_expert_ids, dst_w3_w1_weight_scale,
+            dst_w2_weight_scale, dst_fc31_alpha, dst_fc2_alpha)
+        # The kernel we use will convert nvfp4 to e4m3 before matmul,
+        # so the range of the scale factor can only be [0,448/6].
+        dst_w3_w1_weight_scale.copy_((dst_w3_w1_weight_scale.to(torch.float32) /
+                                      6.0).to(torch.float8_e4m3fn))
+        dst_w2_weight_scale.copy_((dst_w2_weight_scale.to(torch.float32) /
+                                   6.0).to(torch.float8_e4m3fn))
+        dst_fc31_alpha.copy_(dst_fc31_alpha * 6.0)
+        dst_fc2_alpha.copy_(dst_fc2_alpha * 6.0)
+
+
 def _get_weight_alignment(weight_alignment, scaling_vector_size, tp_size,
                           shard_dim_size):
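The rescaling at the end of load_all_fp4_weight_scales_and_alphas is pure bookkeeping: the kernel re-expands NVFP4 (e2m1, largest magnitude 6) weights to e4m3 (largest magnitude 448) before the matmul, so a per-block scale s is only safe if s * 6 <= 448, i.e. s <= 448/6. Dividing the stored scales by 6 and multiplying the per-layer alpha by 6 keeps the effective dequantization factor alpha * s unchanged. A quick numeric check with made-up values:

E4M3_MAX = 448.0
E2M1_MAX = 6.0  # largest NVFP4 (e2m1) magnitude

scale, alpha = 80.0, 0.01  # hypothetical per-block scale and per-layer alpha
assert scale > E4M3_MAX / E2M1_MAX  # would overflow after the nvfp4 -> e4m3 expand
scale_adj, alpha_adj = scale / 6.0, alpha * 6.0
assert scale_adj <= E4M3_MAX / E2M1_MAX  # now inside [0, 448/6]
assert abs(scale_adj * alpha_adj - scale * alpha) < 1e-12  # product preserved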
@@ -31,8 +31,9 @@ from tensorrt_llm._torch.modules.fused_moe.interface import MoEWeightLoadingMode
 # isort and yapf will fight against each other here, so we disable isort
 # isort: off
 from tensorrt_llm._torch.modules.fused_moe import (
-    BaseMoeRoutingMethod, CutlassFusedMoE, DefaultMoeRoutingMethod,
-    RenormalizeMoeRoutingMethod, TritonFusedMoE, create_moe, WideEPMoE)
+    BaseMoeRoutingMethod, CutlassFusedMoE, TRTLLMGenFusedMoE,
+    DefaultMoeRoutingMethod, RenormalizeMoeRoutingMethod, TritonFusedMoE,
+    create_moe, WideEPMoE)
 # isort: on
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
     IS_TRITON_KERNELS_AVAILABLE
@@ -1423,6 +1424,131 @@ def test_fused_moe_nvfp4(dtype):
     torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.15)


+@skip_pre_blackwell
+def test_fused_moe_w4a8_nvfp4_fp8():
+    dtype = torch.bfloat16
+    mapping = Mapping()
+    mapping.rank = mpi_rank()
+
+    with torch.device(f'cuda:{mapping.rank}'):
+        SCALING_VECTOR_SIZE = 32
+
+        SEQ_LEN = 4
+        HIDDEN_SIZE = 512
+        INTERMEDIATE_SIZE = 512
+        NUM_EXPERTS = 4
+        TOP_K = 2
+        routing_method = RenormalizeMoeRoutingMethod(top_k=TOP_K)
+        torch.manual_seed(0)
+        torch.cuda.manual_seed(0)
+        x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype, device="cuda")
+        x_sf_global = 448 / x.abs().max().float()
+        router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS),
+                                    dtype=dtype,
+                                    device="cuda")
+
+        weights = {}
+        for expert_id in range(NUM_EXPERTS):
+            w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
+                                    dtype=torch.float32,
+                                    device="cpu")
+            w1_sf_global = (448) / w1_weight.abs().max().float()
+
+            w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE),
+                                    dtype=torch.float32,
+                                    device="cpu")
+            w2_sf_global = (448) / w2_weight.abs().max().float()
+
+            w3_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
+                                    dtype=torch.float32,
+                                    device="cpu")
+            w3_sf_global = (448) / w3_weight.abs().max().float()
+
+            w3_w1_global = min(
+                w1_sf_global,
+                w3_sf_global)  # w3 global and w1 global must be the same
+
+            w1_weight_nvfp4, w1_sf_block, _ = torch.ops.tensorrt_llm.float_to_e2m1_and_ufp8sf_scale(
+                w1_weight * w3_w1_global, SCALING_VECTOR_SIZE, 1, False)
+            w1_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
+                w1_sf_block.view(INTERMEDIATE_SIZE, -1))
+
+            w2_weight_nvfp4, w2_sf_block, _ = torch.ops.tensorrt_llm.float_to_e2m1_and_ufp8sf_scale(
+                w2_weight * w2_sf_global, SCALING_VECTOR_SIZE, 1, False)
+            w2_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
+                w2_sf_block.view(HIDDEN_SIZE, -1))
+
+            w3_weight_nvfp4, w3_sf_block, _ = torch.ops.tensorrt_llm.float_to_e2m1_and_ufp8sf_scale(
+                w3_weight * w3_w1_global, SCALING_VECTOR_SIZE, 1, False)
+            w3_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
+                w3_sf_block.view(INTERMEDIATE_SIZE, -1))
+
+            w1_weight_nvfp4 = w1_weight_nvfp4.cuda()
+            w1_sf_block_unswizzled = w1_sf_block_unswizzled.cuda()
+            w2_weight_nvfp4 = w2_weight_nvfp4.cuda()
+            w2_sf_block_unswizzled = w2_sf_block_unswizzled.cuda()
+            w3_weight_nvfp4 = w3_weight_nvfp4.cuda()
+            w3_sf_block_unswizzled = w3_sf_block_unswizzled.cuda()
+
+            w1_input_scale = x_sf_global.cuda()
+            w2_input_scale = x_sf_global.cuda()
+            w3_input_scale = x_sf_global.cuda()
+
+            weights[f"{expert_id}.w1.weight"] = w1_weight_nvfp4
+            weights[f"{expert_id}.w2.weight"] = w2_weight_nvfp4
+            weights[f"{expert_id}.w3.weight"] = w3_weight_nvfp4
+            weights[
+                f"{expert_id}.w1.weight_scale"] = w1_sf_block_unswizzled.view(
+                    torch.float8_e4m3fn).cuda()
+            weights[
+                f"{expert_id}.w2.weight_scale"] = w2_sf_block_unswizzled.view(
+                    torch.float8_e4m3fn).cuda()
+            weights[
+                f"{expert_id}.w3.weight_scale"] = w3_sf_block_unswizzled.view(
+                    torch.float8_e4m3fn).cuda()
+            weights[f"{expert_id}.w1.input_scale"] = 1.0 / w1_input_scale
+            weights[f"{expert_id}.w2.input_scale"] = 1.0 / w2_input_scale
+            weights[f"{expert_id}.w3.input_scale"] = 1.0 / w3_input_scale
+            weights[f"{expert_id}.w1.weight_scale_2"] = 1.0 / w3_w1_global
+            weights[f"{expert_id}.w2.weight_scale_2"] = 1.0 / w2_sf_global
+            weights[f"{expert_id}.w3.weight_scale_2"] = 1.0 / w3_w1_global
+
+        quant_config = QuantConfig(quant_algo=QuantAlgo.W4A8_NVFP4_FP8)
+        fused_moe = TRTLLMGenFusedMoE(
+            num_experts=NUM_EXPERTS,
+            routing_method=routing_method,
+            hidden_size=HIDDEN_SIZE,
+            intermediate_size=INTERMEDIATE_SIZE,
+            dtype=dtype,
+            reduce_results=False,
+            model_config=ModelConfig(quant_config=quant_config))
+        fused_moe.load_weights([weights])
+        fused_moe.cuda()
+
+        # Evaluate the outputs on a variant sequence length to cover all possible keys in Autotuner cache
+        ref_fused_moe = RefGatedMLPFusedMoE(
+            num_experts=NUM_EXPERTS,
+            routing_method=routing_method,
+            hidden_size=HIDDEN_SIZE,
+            intermediate_size=INTERMEDIATE_SIZE,
+            dtype=dtype,
+            model_config=ModelConfig(quant_config=quant_config))
+        ref_fused_moe.load_weights([weights])
+        ref_fused_moe.cuda()
+
+        AutoTuner.get().clear_cache()
+        with torch.inference_mode(), autotune():
+            fused_moe.forward(x, router_logits)
+
+        with torch.inference_mode():
+            output = fused_moe.forward(x, router_logits)
+            ref_output = ref_fused_moe.forward(x, router_logits)
+
+        # compare
+        torch.cuda.synchronize()
+        torch.testing.assert_close(output, ref_output, rtol=1e-1, atol=0.5)
+
+
 @skip_neither_ada_nor_hopper_unittest
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize(
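One detail worth spelling out in the test above is the scale bookkeeping the weights dict follows: each tensor is quantized against a global scale of 448 / amax, and the reciprocal of that scale is what gets stored (weight_scale_2, input_scale), so dequantization recovers the original range. A hedged, self-contained check of that identity:

import torch

w = torch.randn(512, 512)
w_sf_global = 448.0 / w.abs().max()  # maps amax(w) onto the e4m3 maximum
weight_scale_2 = 1.0 / w_sf_global   # the value written into the weights dict
# Scaling up for quantization and back down with weight_scale_2 is (numerically) the identity.
torch.testing.assert_close((w * w_sf_global) * weight_scale_2, w)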
@@ -1689,7 +1815,7 @@ def test_fused_moe_w4afp8(dtype, weight_loading_mode):
 @skip_pre_blackwell
 @pytest.mark.parametrize("moe_backend", ["TRTLLM", "CUTLASS"])
 @pytest.mark.parametrize("bias", [True, False])
-def test_fused_moe_mxfp4_mxpf8(moe_backend, bias):
+def test_fused_moe_mxfp4_mxfp8(moe_backend, bias):
     SCALING_VECTOR_SIZE = 32
     dtype = torch.bfloat16
     SEQ_LEN = 128
@@ -2262,7 +2388,8 @@ class RefGatedMLPFusedMoE(nn.Module):
                     f"{expert}.w3.input_scale"]
                 down_proj_weights[0]['input_scale'] = weights[
                     f"{expert}.w2.input_scale"]
-            elif self.quant_config and self.quant_config.quant_algo == QuantAlgo.NVFP4:
+            elif self.quant_config and self.quant_config.quant_algo in (
+                    QuantAlgo.NVFP4, QuantAlgo.W4A8_NVFP4_FP8):
                 gate_up_proj_weights[0]['weight_scale'] = weights[
                     f"{expert}.w1.weight_scale"]
                 gate_up_proj_weights[1]['weight_scale'] = weights[
@@ -1069,13 +1069,36 @@ class TestMoeFp4:
                                  routing_info,
                                  use_autotune=True,
                                  use_topk_as_input=False)
+        if intermediate_size >= 256:
+            self.run_moe_fp8_fp4_test(num_tokens,
+                                      hidden_size,
+                                      intermediate_size,
+                                      routing_info,
+                                      use_autotune=True,
+                                      use_topk_as_input=False)
+
+    @pytest.mark.parametrize("num_tokens", [1])
+    @pytest.mark.parametrize("hidden_size", [1024])
+    @pytest.mark.parametrize("intermediate_size", [384])
+    @pytest.mark.parametrize(
+        "routing_info",
+        [
+            pytest.param(
+                {
+                    "num_experts": 72,
+                    "top_k": 6,
+                    "padding": 8,
+                    "n_groups": 1,
+                    "top_k_groups": 1,
+                    "routed_scaling": 2.5,
+                    "has_routing_bias": True,
+                    "routing_method_type": RoutingMethodType.DeepSeekV3
+                },
+                id="RoutingDSlite"),
+        ],
+    )
+    def test_autotune_fp8_fp4(self, num_tokens, hidden_size, intermediate_size,
+                              routing_info):
+
+        self.run_moe_fp8_fp4_test(num_tokens,
+                                  hidden_size,
+                                  intermediate_size,
+                                  routing_info,
+                                  use_autotune=True,
+                                  use_topk_as_input=False)

     @pytest.mark.parametrize("num_tokens", [1, 150])
     @pytest.mark.parametrize("hidden_size", [1024])
@@ -1120,6 +1143,33 @@ class TestMoeFp4:
                                  routing_info,
                                  use_autotune=False,
                                  use_topk_as_input=use_topk_as_input)

+    @pytest.mark.parametrize("num_tokens", [1])
+    @pytest.mark.parametrize("hidden_size", [1024])
+    @pytest.mark.parametrize("intermediate_size", [1024])
+    @pytest.mark.parametrize(
+        "routing_info",
+        [
+            pytest.param(
+                {
+                    "num_experts": 128,
+                    "top_k": 4,
+                    "padding": 8,
+                    "n_groups": None,
+                    "top_k_groups": None,
+                    "routed_scaling": None,
+                    "has_routing_bias": False,
+                    "routing_method_type": RoutingMethodType.Renormalize
+                },
+                id="RoutingRenormalize_topk_4"),
+        ],
+    )
+    @pytest.mark.parametrize("use_topk_as_input", [False, True],
+                             ids=["use_score_as_input", "use_topk_as_input"])
+    def test_no_autotune_fp8_fp4(self, num_tokens, hidden_size,
+                                 intermediate_size, routing_info,
+                                 use_topk_as_input):
+
+        self.run_moe_fp8_fp4_test(num_tokens,
+                                  hidden_size,
+                                  intermediate_size,