[None][feat] Add support for fused gate_up_proj scales for FP8 blockwise (#6496)

Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
Aurelien Chartier authored 2025-08-05 11:22:32 -07:00; committed by GitHub
parent 46df8712c8
commit 6da95f29a9
2 changed files with 67 additions and 29 deletions

@@ -528,31 +528,44 @@ class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
load_expert_ids: List[int], dst_w3_w1_weight_scale: torch.Tensor,
dst_w2_weight_scale: torch.Tensor, device):
for local_slot_id, expert_id in enumerate(load_expert_ids):
w3_scale = load_weight_shard(
weights[f"{expert_id}.w3.weight_scale_inv"],
module.tp_size,
module.tp_rank,
TensorParallelMode.COLUMN,
device=device)
dst_w3_w1_weight_scale[local_slot_id][:dst_w3_w1_weight_scale.
shape[-2] //
2].copy_(w3_scale)
w1_scale = load_weight_shard(
weights[f"{expert_id}.w1.weight_scale_inv"],
module.tp_size,
module.tp_rank,
TensorParallelMode.COLUMN,
device=device)
dst_w3_w1_weight_scale[local_slot_id][dst_w3_w1_weight_scale.
shape[-2] //
2:].copy_(w1_scale)
w2_scale = load_weight_shard(
weights[f"{expert_id}.w2.weight_scale_inv"],
module.tp_size,
module.tp_rank,
TensorParallelMode.ROW,
device=device)
dst_w2_weight_scale[local_slot_id].copy_(w2_scale)
if module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
w3_scale = weights['gate_up_proj_weight_scale'][
expert_id].transpose(0, 1).contiguous()
w1_scale = None
w2_scale = weights['down_proj_weight_scale'][
expert_id].transpose(0, 1).contiguous()
elif module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
w3_scale = weights[f"{expert_id}.w3.weight_scale_inv"]
w1_scale = weights[f"{expert_id}.w1.weight_scale_inv"]
w2_scale = weights[f"{expert_id}.w2.weight_scale_inv"]
else:
raise NotImplementedError(
f"Unknown weight loading mode in MoE: {module.weight_loading_mode}"
)
w3_w1_scale_shard = load_weight_shard(w3_scale,
module.tp_size,
module.tp_rank,
TensorParallelMode.COLUMN,
device=device)
if w1_scale is not None:
w1_scale_shard = load_weight_shard(w1_scale,
module.tp_size,
module.tp_rank,
TensorParallelMode.COLUMN,
device=device)
w3_w1_scale_shard = torch.cat(
[w3_w1_scale_shard, w1_scale_shard], dim=-2)
dst_w3_w1_weight_scale[local_slot_id].copy_(w3_w1_scale_shard)
w2_scale_shard = load_weight_shard(w2_scale,
module.tp_size,
module.tp_rank,
TensorParallelMode.ROW,
device=device)
dst_w2_weight_scale[local_slot_id].copy_(w2_scale_shard)
def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
self.load_expert_all_weight_scale_fp8_block_scale(
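
As a standalone sketch in plain torch (not TRT-LLM code, and with made-up block-scale sizes), this is the layout relationship the new FUSED_GATE_UP_PROJ branch relies on: the checkpoint stores one fused, transposed scale tensor per expert, and transposing it back on load yields the same stacked [w3; w1] layout that the VANILLA branch assembles by concatenating the two per-projection shards along dim=-2. Tensor-parallel sharding via load_weight_shard is deliberately left out here.

import torch

# Made-up block counts for one expert; real values depend on the model's
# intermediate/hidden sizes and the FP8 block size.
intermediate_blocks, hidden_blocks = 12, 20

w3_scale = torch.rand(intermediate_blocks, hidden_blocks)
w1_scale = torch.rand(intermediate_blocks, hidden_blocks)

# What a fused checkpoint holds: w3 stacked on w1, then transposed
# (mirrors how the test below builds 'gate_up_proj_weight_scale').
gate_up_proj_weight_scale = torch.cat(
    [w3_scale, w1_scale], dim=-2).transpose(0, 1).contiguous()

# What the new branch does on load: transpose back ...
w3_w1_scale = gate_up_proj_weight_scale.transpose(0, 1).contiguous()

# ... which recovers the stacked layout the VANILLA path builds with
# torch.cat([w3_shard, w1_shard], dim=-2) before the final copy_().
assert torch.equal(w3_w1_scale[:intermediate_blocks], w3_scale)
assert torch.equal(w3_w1_scale[intermediate_blocks:], w1_scale)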

@@ -30,6 +30,7 @@ from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import \
DeepGemmFusedMoE
from tensorrt_llm._torch.modules.fused_moe.fused_moe_wide_ep import \
AlltoallMethodType
from tensorrt_llm._torch.modules.fused_moe.interface import MoEWeightLoadingMode
from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
from tensorrt_llm._utils import mpi_rank
from tensorrt_llm.mapping import Mapping
@@ -561,13 +562,14 @@ def test_fused_moe_fp8_blockwise_deepgemm(dtype,
@skip_non_hopper_unittest
@pytest.mark.parametrize(
"dtype, num_experts, seq_len, hidden_size, RoutingMethodCls",
"dtype, num_experts, seq_len, hidden_size, RoutingMethodCls, WeightLoadingMode",
product(
[torch.bfloat16],
[72],
[128, 256, 384, 512, 1024, 2048, 4096, 8192],
[2560],
[DefaultMoeRoutingMethod],
[MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.FUSED_GATE_UP_PROJ],
),
)
def test_fused_moe_fp8_blockwise(dtype,
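
For scale, the new WeightLoadingMode axis simply crosses the existing grid with the two loading modes, doubling the number of generated cases; a quick standalone check with the values copied from the decorator above (class names replaced by placeholder strings here):

from itertools import product

import torch

cases = list(
    product(
        [torch.bfloat16],                              # dtype
        [72],                                          # num_experts
        [128, 256, 384, 512, 1024, 2048, 4096, 8192],  # seq_len
        [2560],                                        # hidden_size
        ["DefaultMoeRoutingMethod"],                   # RoutingMethodCls
        ["VANILLA", "FUSED_GATE_UP_PROJ"],             # WeightLoadingMode
    ))
assert len(cases) == 16  # 8 seq_len values x 2 loading modes
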
@@ -575,6 +577,7 @@ def test_fused_moe_fp8_blockwise(dtype,
seq_len,
hidden_size,
RoutingMethodCls,
WeightLoadingMode,
mapping=None):
SEQ_LEN = seq_len
HIDDEN_SIZE = hidden_size
@@ -600,6 +603,13 @@ def test_fused_moe_fp8_blockwise(dtype,
device="cuda")
weights = {}
if WeightLoadingMode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
weights['gate_up_proj'] = {}
weights['down_proj'] = {}
weights['gate_up_proj_weight_scale'] = {}
weights['down_proj_weight_scale'] = {}
for expert_id in range(NUM_EXPERTS):
w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
dtype=dtype,
@@ -626,13 +636,26 @@ def test_fused_moe_fp8_blockwise(dtype,
weights[f"{expert_id}.w1.weight"] = w1_weight_fp8
weights[f"{expert_id}.w2.weight"] = w2_weight_fp8
weights[f"{expert_id}.w3.weight"] = w3_weight_fp8
weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale
weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale
weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale
weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale
weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale
weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale
if WeightLoadingMode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
weights['gate_up_proj'][expert_id] = torch.cat(
[w3_weight_fp8, w1_weight_fp8],
dim=-2).transpose(0, 1).contiguous()
weights['down_proj'][expert_id] = w2_weight_fp8.transpose(
0, 1).contiguous()
weights['gate_up_proj_weight_scale'][expert_id] = torch.cat(
[w3_weight_scale, w1_weight_scale],
dim=-2).transpose(0, 1).contiguous()
weights['down_proj_weight_scale'][
expert_id] = w2_weight_scale.transpose(0, 1).contiguous()
elif WeightLoadingMode == MoEWeightLoadingMode.VANILLA:
weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale
weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale
weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale
quant_config = QuantConfig(quant_algo=QuantAlgo.FP8_BLOCK_SCALES)
fused_moe = CuteDslFusedMoE(
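
The FP8 block quantization that produces w{1,2,3}_weight_fp8 and the matching *_weight_scale tensors happens between these two hunks and is not shown; as a rough standalone stand-in (assuming 128x128 blocks and weight dimensions divisible by 128, which may not match the test's actual helper), per-block scaling looks like this:

import torch

def block_quant_fp8(w: torch.Tensor, block: int = 128):
    # Split an [N, K] weight into block x block tiles, take one amax per
    # tile, and return the FP8 weight plus an [N // block, K // block]
    # scale that dequantizes it (fp8_tile * scale ~= original tile).
    n, k = w.shape
    tiles = w.reshape(n // block, block, k // block, block)
    amax = tiles.abs().amax(dim=(1, 3)).clamp(min=1e-4)
    scale = (amax / torch.finfo(torch.float8_e4m3fn).max).float()
    w_fp8 = (tiles / scale[:, None, :, None]).to(torch.float8_e4m3fn)
    return w_fp8.reshape(n, k), scale

Under that assumption, an [INTERMEDIATE_SIZE, HIDDEN_SIZE] w1/w3 weight gets an [INTERMEDIATE_SIZE // 128, HIDDEN_SIZE // 128] scale, which is the shape the cat/transpose calls above fold into the fused gate_up_proj_weight_scale entry.
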
@@ -643,6 +666,7 @@ def test_fused_moe_fp8_blockwise(dtype,
dtype=dtype,
reduce_results=True,
model_config=ModelConfig(quant_config=quant_config, mapping=mapping),
weight_loading_mode=WeightLoadingMode,
)
fused_moe.cuda()
fused_moe.load_weights([weights])
@@ -655,6 +679,7 @@ def test_fused_moe_fp8_blockwise(dtype,
dtype=dtype,
reduce_results=True,
model_config=ModelConfig(quant_config=quant_config, mapping=mapping),
weight_loading_mode=WeightLoadingMode,
)
fused_moe_origin.cuda()
fused_moe_origin.load_weights([weights])