diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
index 18e9c7cc98..249aadc04e 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -528,31 +528,44 @@ class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
             load_expert_ids: List[int], dst_w3_w1_weight_scale: torch.Tensor,
             dst_w2_weight_scale: torch.Tensor, device):
         for local_slot_id, expert_id in enumerate(load_expert_ids):
-            w3_scale = load_weight_shard(
-                weights[f"{expert_id}.w3.weight_scale_inv"],
-                module.tp_size,
-                module.tp_rank,
-                TensorParallelMode.COLUMN,
-                device=device)
-            dst_w3_w1_weight_scale[local_slot_id][:dst_w3_w1_weight_scale.
-                                                  shape[-2] //
-                                                  2].copy_(w3_scale)
-            w1_scale = load_weight_shard(
-                weights[f"{expert_id}.w1.weight_scale_inv"],
-                module.tp_size,
-                module.tp_rank,
-                TensorParallelMode.COLUMN,
-                device=device)
-            dst_w3_w1_weight_scale[local_slot_id][dst_w3_w1_weight_scale.
-                                                  shape[-2] //
-                                                  2:].copy_(w1_scale)
-            w2_scale = load_weight_shard(
-                weights[f"{expert_id}.w2.weight_scale_inv"],
-                module.tp_size,
-                module.tp_rank,
-                TensorParallelMode.ROW,
-                device=device)
-            dst_w2_weight_scale[local_slot_id].copy_(w2_scale)
+            if module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+                w3_scale = weights['gate_up_proj_weight_scale'][
+                    expert_id].transpose(0, 1).contiguous()
+                w1_scale = None
+                w2_scale = weights['down_proj_weight_scale'][
+                    expert_id].transpose(0, 1).contiguous()
+            elif module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
+                w3_scale = weights[f"{expert_id}.w3.weight_scale_inv"]
+                w1_scale = weights[f"{expert_id}.w1.weight_scale_inv"]
+                w2_scale = weights[f"{expert_id}.w2.weight_scale_inv"]
+            else:
+                raise NotImplementedError(
+                    f"Unknown weight loading mode in MoE: {module.weight_loading_mode}"
+                )
+
+            w3_w1_scale_shard = load_weight_shard(w3_scale,
+                                                  module.tp_size,
+                                                  module.tp_rank,
+                                                  TensorParallelMode.COLUMN,
+                                                  device=device)
+
+            if w1_scale is not None:
+                w1_scale_shard = load_weight_shard(w1_scale,
+                                                   module.tp_size,
+                                                   module.tp_rank,
+                                                   TensorParallelMode.COLUMN,
+                                                   device=device)
+                w3_w1_scale_shard = torch.cat(
+                    [w3_w1_scale_shard, w1_scale_shard], dim=-2)
+
+            dst_w3_w1_weight_scale[local_slot_id].copy_(w3_w1_scale_shard)
+
+            w2_scale_shard = load_weight_shard(w2_scale,
+                                               module.tp_size,
+                                               module.tp_rank,
+                                               TensorParallelMode.ROW,
+                                               device=device)
+            dst_w2_weight_scale[local_slot_id].copy_(w2_scale_shard)
 
     def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
         self.load_expert_all_weight_scale_fp8_block_scale(
diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py
index 51a7758d28..a72ad4c7b6 100644
--- a/tests/unittest/_torch/modules/test_fused_moe.py
+++ b/tests/unittest/_torch/modules/test_fused_moe.py
@@ -30,6 +30,7 @@ from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import \
     DeepGemmFusedMoE
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_wide_ep import \
     AlltoallMethodType
+from tensorrt_llm._torch.modules.fused_moe.interface import MoEWeightLoadingMode
 from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
 from tensorrt_llm._utils import mpi_rank
 from tensorrt_llm.mapping import Mapping
@@ -561,13 +562,14 @@ def test_fused_moe_fp8_blockwise_deepgemm(dtype,
 
 @skip_non_hopper_unittest
 @pytest.mark.parametrize(
-    "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls",
+    "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls, WeightLoadingMode",
     product(
         [torch.bfloat16],
         [72],
         [128, 256, 384, 512, 1024, 2048, 4096, 8192],
         [2560],
         [DefaultMoeRoutingMethod],
+        [MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.FUSED_GATE_UP_PROJ],
     ),
 )
 def test_fused_moe_fp8_blockwise(dtype,
@@ -575,6 +577,7 @@ def test_fused_moe_fp8_blockwise(dtype,
                                  seq_len,
                                  hidden_size,
                                  RoutingMethodCls,
+                                 WeightLoadingMode,
                                  mapping=None):
     SEQ_LEN = seq_len
     HIDDEN_SIZE = hidden_size
@@ -600,6 +603,13 @@ def test_fused_moe_fp8_blockwise(dtype,
                          device="cuda")
 
     weights = {}
+
+    if WeightLoadingMode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+        weights['gate_up_proj'] = {}
+        weights['down_proj'] = {}
+        weights['gate_up_proj_weight_scale'] = {}
+        weights['down_proj_weight_scale'] = {}
+
     for expert_id in range(NUM_EXPERTS):
         w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
                                 dtype=dtype,
@@ -626,13 +636,26 @@ def test_fused_moe_fp8_blockwise(dtype,
         weights[f"{expert_id}.w1.weight"] = w1_weight_fp8
         weights[f"{expert_id}.w2.weight"] = w2_weight_fp8
         weights[f"{expert_id}.w3.weight"] = w3_weight_fp8
-        weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale
-        weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale
-        weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale
         weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale
         weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale
         weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale
 
+        if WeightLoadingMode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+            weights['gate_up_proj'][expert_id] = torch.cat(
+                [w3_weight_fp8, w1_weight_fp8],
+                dim=-2).transpose(0, 1).contiguous()
+            weights['down_proj'][expert_id] = w2_weight_fp8.transpose(
+                0, 1).contiguous()
+            weights['gate_up_proj_weight_scale'][expert_id] = torch.cat(
+                [w3_weight_scale, w1_weight_scale],
+                dim=-2).transpose(0, 1).contiguous()
+            weights['down_proj_weight_scale'][
+                expert_id] = w2_weight_scale.transpose(0, 1).contiguous()
+        elif WeightLoadingMode == MoEWeightLoadingMode.VANILLA:
+            weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale
+            weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale
+            weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale
+
     quant_config = QuantConfig(quant_algo=QuantAlgo.FP8_BLOCK_SCALES)
 
     fused_moe = CuteDslFusedMoE(
@@ -643,6 +666,7 @@ def test_fused_moe_fp8_blockwise(dtype,
         dtype=dtype,
         reduce_results=True,
         model_config=ModelConfig(quant_config=quant_config, mapping=mapping),
+        weight_loading_mode=WeightLoadingMode,
     )
     fused_moe.cuda()
     fused_moe.load_weights([weights])
@@ -655,6 +679,7 @@ def test_fused_moe_fp8_blockwise(dtype,
         dtype=dtype,
         reduce_results=True,
         model_config=ModelConfig(quant_config=quant_config, mapping=mapping),
+        weight_loading_mode=WeightLoadingMode,
     )
     fused_moe_origin.cuda()
     fused_moe_origin.load_weights([weights])