Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[None][feat] Add support for fused gate_up_proj scales for FP8 blockwise (#6496)
Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
parent 46df8712c8
commit 6da95f29a9
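Context for the change: some checkpoints store the MoE gate and up projections as a single fused gate_up_proj tensor, and with FP8 block-wise quantization they likewise carry one fused scale tensor per expert instead of separate w1/w3 weight_scale_inv tensors. The loader change below teaches DeepSeekFP8BlockScalesFusedMoEMethod to split that fused scale back into its two halves. A minimal layout sketch in plain PyTorch, with made-up sizes and 128×128 quantization blocks assumed; the tensor manipulations mirror the test added further down.

# Sketch only: how a fused gate_up_proj scale tensor relates to the separate
# w3/w1 block scales. Sizes are hypothetical; 128x128 FP8 blocks are assumed.
import torch

HIDDEN_SIZE, INTERMEDIATE_SIZE, BLOCK = 2560, 1536, 128

# One scale per 128x128 weight block, per projection.
w3_scale = torch.rand(INTERMEDIATE_SIZE // BLOCK, HIDDEN_SIZE // BLOCK)
w1_scale = torch.rand(INTERMEDIATE_SIZE // BLOCK, HIDDEN_SIZE // BLOCK)

# Fused checkpoints store both halves in one transposed tensor
# (the same construction the unit test uses for 'gate_up_proj_weight_scale').
fused = torch.cat([w3_scale, w1_scale], dim=-2).transpose(0, 1).contiguous()

# The loader undoes the transpose; the first half along dim -2 is w3,
# the second half is w1.
unfused = fused.transpose(0, 1).contiguous()
half = unfused.shape[-2] // 2
assert torch.equal(unfused[:half], w3_scale)
assert torch.equal(unfused[half:], w1_scale)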
@@ -528,31 +528,44 @@ class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
             load_expert_ids: List[int], dst_w3_w1_weight_scale: torch.Tensor,
             dst_w2_weight_scale: torch.Tensor, device):
         for local_slot_id, expert_id in enumerate(load_expert_ids):
-            w3_scale = load_weight_shard(
-                weights[f"{expert_id}.w3.weight_scale_inv"],
-                module.tp_size,
-                module.tp_rank,
-                TensorParallelMode.COLUMN,
-                device=device)
-            dst_w3_w1_weight_scale[local_slot_id][:dst_w3_w1_weight_scale.
-                                                  shape[-2] //
-                                                  2].copy_(w3_scale)
-            w1_scale = load_weight_shard(
-                weights[f"{expert_id}.w1.weight_scale_inv"],
-                module.tp_size,
-                module.tp_rank,
-                TensorParallelMode.COLUMN,
-                device=device)
-            dst_w3_w1_weight_scale[local_slot_id][dst_w3_w1_weight_scale.
-                                                  shape[-2] //
-                                                  2:].copy_(w1_scale)
-            w2_scale = load_weight_shard(
-                weights[f"{expert_id}.w2.weight_scale_inv"],
-                module.tp_size,
-                module.tp_rank,
-                TensorParallelMode.ROW,
-                device=device)
-            dst_w2_weight_scale[local_slot_id].copy_(w2_scale)
+            if module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+                w3_scale = weights['gate_up_proj_weight_scale'][
+                    expert_id].transpose(0, 1).contiguous()
+                w1_scale = None
+                w2_scale = weights['down_proj_weight_scale'][
+                    expert_id].transpose(0, 1).contiguous()
+            elif module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
+                w3_scale = weights[f"{expert_id}.w3.weight_scale_inv"]
+                w1_scale = weights[f"{expert_id}.w1.weight_scale_inv"]
+                w2_scale = weights[f"{expert_id}.w2.weight_scale_inv"]
+            else:
+                raise NotImplementedError(
+                    f"Unknown weight loading mode in MoE: {module.weight_loading_mode}"
+                )
+
+            w3_w1_scale_shard = load_weight_shard(w3_scale,
+                                                  module.tp_size,
+                                                  module.tp_rank,
+                                                  TensorParallelMode.COLUMN,
+                                                  device=device)
+
+            if w1_scale is not None:
+                w1_scale_shard = load_weight_shard(w1_scale,
+                                                   module.tp_size,
+                                                   module.tp_rank,
+                                                   TensorParallelMode.COLUMN,
+                                                   device=device)
+                w3_w1_scale_shard = torch.cat(
+                    [w3_w1_scale_shard, w1_scale_shard], dim=-2)
+
+            dst_w3_w1_weight_scale[local_slot_id].copy_(w3_w1_scale_shard)
+
+            w2_scale_shard = load_weight_shard(w2_scale,
+                                               module.tp_size,
+                                               module.tp_rank,
+                                               TensorParallelMode.ROW,
+                                               device=device)
+            dst_w2_weight_scale[local_slot_id].copy_(w2_scale_shard)

     def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
         self.load_expert_all_weight_scale_fp8_block_scale(
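In both branches above, dst_w3_w1_weight_scale[local_slot_id] ends up holding the w3 and w1 scale shards stacked along dim -2, and dst_w2_weight_scale[local_slot_id] receives a row-sharded w2 scale. Below is a simplified sketch of that data movement under 2-way tensor parallelism. The shard helper is a stand-in for load_weight_shard and assumes COLUMN mode splits the first (output-block) dimension and ROW mode the last (input-block) dimension; that convention and the shapes are assumptions, not the TensorRT-LLM implementation.

# Illustration only: simplified stand-in for load_weight_shard.
from enum import Enum

import torch


class TPMode(Enum):
    COLUMN = 0
    ROW = 1


def shard(tensor: torch.Tensor, tp_size: int, tp_rank: int, mode: TPMode):
    # COLUMN: split the output-block dim; ROW: split the input-block dim.
    dim = 0 if mode == TPMode.COLUMN else -1
    return torch.chunk(tensor, tp_size, dim=dim)[tp_rank]


# Hypothetical per-expert block-scale shapes: (output blocks, input blocks).
w3_scale = torch.rand(12, 20)
w1_scale = torch.rand(12, 20)
w2_scale = torch.rand(20, 12)
tp_size, tp_rank = 2, 0

# VANILLA path: shard each half separately, then stack along dim -2 so the
# slot buffer holds [w3_scale_shard ; w1_scale_shard].
w3_w1_scale_shard = torch.cat([
    shard(w3_scale, tp_size, tp_rank, TPMode.COLUMN),
    shard(w1_scale, tp_size, tp_rank, TPMode.COLUMN),
], dim=-2)
w2_scale_shard = shard(w2_scale, tp_size, tp_rank, TPMode.ROW)

print(w3_w1_scale_shard.shape, w2_scale_shard.shape)  # (12, 20) (20, 6)

In the FUSED_GATE_UP_PROJ branch the loader instead applies a single COLUMN-mode shard to the already-concatenated scale tensor recovered from gate_up_proj_weight_scale.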
@@ -30,6 +30,7 @@ from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import \
     DeepGemmFusedMoE
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_wide_ep import \
     AlltoallMethodType
+from tensorrt_llm._torch.modules.fused_moe.interface import MoEWeightLoadingMode
 from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
 from tensorrt_llm._utils import mpi_rank
 from tensorrt_llm.mapping import Mapping
@@ -561,13 +562,14 @@ def test_fused_moe_fp8_blockwise_deepgemm(dtype,

 @skip_non_hopper_unittest
 @pytest.mark.parametrize(
-    "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls",
+    "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls, WeightLoadingMode",
     product(
         [torch.bfloat16],
         [72],
         [128, 256, 384, 512, 1024, 2048, 4096, 8192],
         [2560],
         [DefaultMoeRoutingMethod],
+        [MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.FUSED_GATE_UP_PROJ],
     ),
 )
 def test_fused_moe_fp8_blockwise(dtype,
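For orientation, product(...) hands pytest.mark.parametrize the full cross product of the argument lists, so the new WeightLoadingMode axis doubles the generated cases (1 × 1 × 8 × 1 × 1 × 2 = 16). A stand-alone illustration, with plain strings in place of the enum values:

# Rough equivalent of what the parametrize decorator receives.
from itertools import product

seq_lens = [128, 256, 384, 512, 1024, 2048, 4096, 8192]
modes = ["VANILLA", "FUSED_GATE_UP_PROJ"]  # stand-ins for MoEWeightLoadingMode
cases = list(product(["bfloat16"], [72], seq_lens, [2560], ["default"], modes))
print(len(cases))  # 16 parametrized test cases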
@@ -575,6 +577,7 @@ def test_fused_moe_fp8_blockwise(dtype,
                                  seq_len,
                                  hidden_size,
                                  RoutingMethodCls,
+                                 WeightLoadingMode,
                                  mapping=None):
     SEQ_LEN = seq_len
     HIDDEN_SIZE = hidden_size
@@ -600,6 +603,13 @@ def test_fused_moe_fp8_blockwise(dtype,
                          device="cuda")

     weights = {}
+
+    if WeightLoadingMode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+        weights['gate_up_proj'] = {}
+        weights['down_proj'] = {}
+        weights['gate_up_proj_weight_scale'] = {}
+        weights['down_proj_weight_scale'] = {}
+
     for expert_id in range(NUM_EXPERTS):
         w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
                                 dtype=dtype,
@@ -626,13 +636,26 @@ def test_fused_moe_fp8_blockwise(dtype,
         weights[f"{expert_id}.w1.weight"] = w1_weight_fp8
         weights[f"{expert_id}.w2.weight"] = w2_weight_fp8
         weights[f"{expert_id}.w3.weight"] = w3_weight_fp8
-        weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale
-        weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale
-        weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale
         weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale
         weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale
         weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale

+        if WeightLoadingMode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+            weights['gate_up_proj'][expert_id] = torch.cat(
+                [w3_weight_fp8, w1_weight_fp8],
+                dim=-2).transpose(0, 1).contiguous()
+            weights['down_proj'][expert_id] = w2_weight_fp8.transpose(
+                0, 1).contiguous()
+            weights['gate_up_proj_weight_scale'][expert_id] = torch.cat(
+                [w3_weight_scale, w1_weight_scale],
+                dim=-2).transpose(0, 1).contiguous()
+            weights['down_proj_weight_scale'][
+                expert_id] = w2_weight_scale.transpose(0, 1).contiguous()
+        elif WeightLoadingMode == MoEWeightLoadingMode.VANILLA:
+            weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale
+            weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale
+            weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale
+
     quant_config = QuantConfig(quant_algo=QuantAlgo.FP8_BLOCK_SCALES)

     fused_moe = CuteDslFusedMoE(
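The w1/w2/w3_weight_scale tensors above come from FP8 block-wise quantization of the bf16 weights; the quantization helper itself is outside this diff. Below is a rough, assumption-laden sketch of 128×128 block quantization, only to show why an (N, K) weight carries an (N/128, K/128) scale tensor; it is not the helper this test calls.

# Illustrative 128x128 block-wise FP8 quantization. Not the test's helper.
import torch

def quantize_blockwise_fp8(w: torch.Tensor, block: int = 128):
    n, k = w.shape  # assumes n and k are multiples of `block`
    scales = torch.empty(n // block, k // block, dtype=torch.float32)
    q = torch.empty_like(w, dtype=torch.float8_e4m3fn)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    for i in range(0, n, block):
        for j in range(0, k, block):
            blk = w[i:i + block, j:j + block]
            scale = blk.abs().max().clamp(min=1e-12) / fp8_max
            scales[i // block, j // block] = scale
            q[i:i + block, j:j + block] = (blk / scale).to(torch.float8_e4m3fn)
    # `scales` plays the role of the per-block dequantization factor that the
    # checkpoint exposes as weight_scale / weight_scale_inv.
    return q, scales

w = torch.randn(256, 512)
w_fp8, w_scale = quantize_blockwise_fp8(w)
print(w_fp8.dtype, w_scale.shape)  # torch.float8_e4m3fn torch.Size([2, 4])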
@@ -643,6 +666,7 @@ def test_fused_moe_fp8_blockwise(dtype,
         dtype=dtype,
         reduce_results=True,
         model_config=ModelConfig(quant_config=quant_config, mapping=mapping),
+        weight_loading_mode=WeightLoadingMode,
     )
     fused_moe.cuda()
     fused_moe.load_weights([weights])
@@ -655,6 +679,7 @@ def test_fused_moe_fp8_blockwise(dtype,
         dtype=dtype,
         reduce_results=True,
         model_config=ModelConfig(quant_config=quant_config, mapping=mapping),
+        weight_loading_mode=WeightLoadingMode,
     )
     fused_moe_origin.cuda()
     fused_moe_origin.load_weights([weights])