mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[None][feat] Add alltoall to trtllm-gen MoE backend. (#8481)
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
This commit is contained in:
parent ab4b9966b2
commit ebb62e17d8
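The change routes each token to the expert-parallel (EP) rank that owns its selected experts via an MNNVL alltoall, runs the fused TRTLLM-Gen MoE kernel on the received tokens, and combines the partial results back onto the source ranks before the usual reduce-scatter/all-reduce. The snippet below is a minimal single-process sketch of that dispatch/compute/combine contract only; the rank count, toy expert function, and contiguous expert placement are illustrative assumptions, not the MnnvlMoe implementation used in this diff.

import torch

ep_size, tokens_per_rank, top_k, hidden, num_experts = 2, 3, 2, 4, 4
experts_per_rank = num_experts // ep_size  # assumed contiguous expert placement

torch.manual_seed(0)
x = torch.randn(ep_size, tokens_per_rank, hidden)            # activations per source rank
expert_ids = torch.randint(num_experts, (ep_size, tokens_per_rank, top_k))
scales = torch.softmax(torch.randn(ep_size, tokens_per_rank, top_k), -1)

def toy_expert(e, t):                      # stand-in for the fused MoE kernel's expert math
    return (1.0 + e) * t

combined = torch.zeros_like(x)
for rank in range(ep_size):                # each EP rank owns a contiguous slice of experts
    owned = range(rank * experts_per_rank, (rank + 1) * experts_per_rank)
    for src in range(ep_size):             # "dispatch": tokens whose experts live on this rank
        for tok in range(tokens_per_rank):
            for k in range(top_k):
                e = int(expert_ids[src, tok, k])
                if e in owned:             # compute on the owner, then "combine" back to src
                    combined[src, tok] += scales[src, tok, k] * toy_expert(e, x[src, tok])

print(combined.shape)                      # [ep_size, tokens_per_rank, hidden], as after combine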
@@ -1,8 +1,11 @@
import os
from functools import cached_property
from typing import Dict, List, Optional, Union

import torch
from torch import nn

from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe
from tensorrt_llm._utils import get_sm_version

from ...custom_ops.trtllm_gen_custom_ops import \
@@ -106,10 +109,28 @@ class TRTLLMGenFusedMoE(MoE):
        assert len(
            self.initial_local_expert_ids) == self.expert_size_per_partition

        self.alltoall_workspace = None
        self.alltoall_prepare_workspace = None
        if self.enable_alltoall:
            MnnvlMemory.initialize()
            self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
                model_config.mapping)
            self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
                model_config.mapping)

        self._weights_created = False
        if not model_config.skip_create_weights_in_init:
            self.create_weights()

    @cached_property
    def enable_alltoall(self):
        mapping = self.mapping
        routing_experts = self.routing_method.experts_per_token
        return (mapping.moe_ep_size > routing_experts
                and mapping.enable_attention_dp and mapping.tp_size > 1
                and os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") != "1"
                and MnnvlMemory.supports_mnnvl())

    def _check_configs(self):
        assert self.has_deepseek_fp8_block_scales \
            or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \
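For reference, the enable_alltoall gate above turns on the MNNVL alltoall path only when expert parallelism is wider than the number of experts selected per token, attention data parallelism is enabled with tp_size > 1, the TRTLLM_MOE_DISABLE_ALLTOALLV escape hatch is unset, and MNNVL memory is supported. A standalone sketch of the same predicate, with a stand-in mapping namespace and an assumed mnnvl_supported flag replacing MnnvlMemory.supports_mnnvl():

import os
from types import SimpleNamespace

def should_use_alltoall(mapping, experts_per_token, mnnvl_supported):
    # Mirrors the enable_alltoall property above; the mapping fields are stand-ins.
    return (mapping.moe_ep_size > experts_per_token
            and mapping.enable_attention_dp and mapping.tp_size > 1
            and os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") != "1"
            and mnnvl_supported)

mapping = SimpleNamespace(moe_ep_size=16, enable_attention_dp=True, tp_size=16)
print(should_use_alltoall(mapping, experts_per_token=8, mnnvl_supported=True))   # True
print(should_use_alltoall(mapping, experts_per_token=32, mnnvl_supported=True))  # False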
@@ -175,6 +196,48 @@ class TRTLLMGenFusedMoE(MoE):
    def post_load_weights(self):
        self.quant_method.post_load_weights(self)

    def _quantize_for_post_quant_comm(self, x):
        """Quantize inputs prior to post-communication (alltoall/allgather).
        Returns: (x, x_sf, x_row, x_col)
        """
        x_row = x.shape[0]
        x_col = x.shape[1]
        x_sf = None
        if self.has_w4a8_mxfp4_fp8:
            x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
                x, self.fc31_input_dequant[0])
            x_row, x_col = x.shape[0], x.shape[1]
        elif self.has_nvfp4:
            if isinstance(x, Fp4QuantizedTensor):
                assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication"
                x_row = x.shape[0]
                x_col = x.shape[1] * 2
                x, x_sf = x.fp4_tensor, x.scaling_factor
            else:
                x_row = x.shape[0]
                x_col = x.shape[1]
                x, x_sf = torch.ops.trtllm.fp4_quantize(
                    x, self.fc31_input_scale, self.scaling_vector_size, False,
                    False)
        elif self.has_w4a8_mxfp4_mxfp8:
            x, x_sf = torch.ops.trtllm.mxfp8_quantize(
                x, False, alignment=self.quant_method.weight_alignment)
            x_row, x_col = x.shape[0], x.shape[1]
        elif self.has_deepseek_fp8_block_scales:
            # No change required before communication
            pass
        elif self.has_w4a16_mxfp4:
            pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
            x = torch.nn.functional.pad(x, (0, pad_size))
        elif self.has_w4a8_nvfp4_fp8:
            x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
                x, 1.0 / self.fc31_input_scale)
        else:
            raise ValueError(
                f"unsupported quantization mode for post communication: {self.quant_config.quant_mode}"
            )
        return x, x_sf, x_row, x_col

    def forward_impl(
        self,
        x: Union[torch.Tensor, Fp4QuantizedTensor],
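A note on the tuple the _quantize_for_post_quant_comm helper above returns: x_row and x_col describe the logical activation shape after quantization. In the NVFP4 branches two FP4 values are packed per stored uint8, so x_col is the packed width times two, and x_sf carries one block scale per scaling_vector_size elements of each row; callers later view x_sf as [x_row, ceil_div(x_col, scaling_vector_size)] before communicating it. A small arithmetic sketch, with a local ceil_div standing in for the helper the module imports and example sizes chosen arbitrarily:

def ceil_div(a, b):             # local stand-in for the module's ceil_div helper
    return -(-a // b)

x_row, packed_cols, scaling_vector_size = 128, 1024, 16
x_col = packed_cols * 2         # two fp4 values per stored uint8 byte
sf_per_row = ceil_div(x_col, scaling_vector_size)
print(x_col, sf_per_row)        # 2048 logical elements, 128 block scales per row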
@@ -202,55 +265,80 @@ class TRTLLMGenFusedMoE(MoE):
        topk_group = None
        routed_scaling_factor = None

        run_post_quant_allgather = self.use_dp and self.parallel_size > 1
        run_post_quant_allgather = (self.use_dp and self.parallel_size > 1
                                    and not self.enable_alltoall)
        post_quant_comm = run_post_quant_allgather or self.enable_alltoall

        x_sf = None
        token_selected_experts = None
        token_final_scales = None
        x_row = x.shape[0]
        x_col = x.shape[1]
        if run_post_quant_allgather:
            # apply routing
        token_count = x.shape[0]
        alltoall_info = None

        if post_quant_comm:
            token_selected_experts, token_final_scales = self.routing_method.apply(
                router_logits)
            token_final_scales = token_final_scales.to(torch.bfloat16)
            assert token_final_scales.dtype == torch.bfloat16
            assert token_selected_experts.dtype == torch.int32
            # quantize inputs
            if self.has_w4a8_mxfp4_fp8:
                x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
                    x, self.fc31_input_dequant[0])
                # Update x_row and x_col to the padded shape
                x_row, x_col = x.shape[0], x.shape[1]
            elif self.has_nvfp4:
                if isinstance(x, Fp4QuantizedTensor):
                    assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication"
                    x_row = x.shape[0]
                    # note: we use uint8 to store 2 fp4 values
                    x_col = x.shape[1] * 2
                    x, x_sf = x.fp4_tensor, x.scaling_factor
                else:
                    x_row = x.shape[0]
                    x_col = x.shape[1]
                    x, x_sf = torch.ops.trtllm.fp4_quantize(
                        x, self.fc31_input_scale, self.scaling_vector_size,
                        False, False)
            elif self.has_w4a8_mxfp4_mxfp8:
                x, x_sf = torch.ops.trtllm.mxfp8_quantize(
                    x, False, alignment=self.quant_method.weight_alignment)
                # Update x_row and x_col to the padded shape
                x_row, x_col = x.shape[0], x.shape[1]
            elif self.has_deepseek_fp8_block_scales:
                pass
            elif self.has_w4a16_mxfp4:
                pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
                x = torch.nn.functional.pad(x, (0, pad_size))
            else:
                raise ValueError(
                    f"unsupported quantization mode with run_post_quant_allgather: {self.quant_config.quant_mode}"
                )
            token_selected_experts = token_selected_experts.to(torch.int32)
            if token_final_scales is not None:
                token_final_scales = token_final_scales.to(torch.bfloat16)

            #allgather for attention DP
            x, x_sf, x_row, x_col = self._quantize_for_post_quant_comm(x)

        if self.enable_alltoall:
            assert all_rank_num_tokens is not None, "all_rank_num_tokens required for alltoall"

            max_num_token = max(
                all_rank_num_tokens) if all_rank_num_tokens else token_count

            if token_final_scales is None:
                token_final_scales = torch.ones_like(token_selected_experts,
                                                     dtype=torch.float32)
            else:
                token_final_scales = token_final_scales.to(torch.float32)

            assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
            alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
                token_selected_experts,
                None,
                self.alltoall_prepare_workspace,
                max_num_token,
                self.ep_rank,
                self.ep_size,
                self.num_experts,
                self.num_slots,
                top_k,
            )

            if x_sf is not None:
                x_sf = x_sf.view(x_row, ceil_div(x_col,
                                                 self.scaling_vector_size))

            x, x_sf, token_selected_experts, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv(
                [x, x_sf, token_selected_experts, token_final_scales],
                alltoall_info,
                self.alltoall_workspace,
                self.ep_rank,
                self.ep_size,
            )

            torch.ops.trtllm.memset_expert_ids(
                token_selected_experts,
                alltoall_info.recv_rank_count_cumsum,
                max_num_token,
                top_k,
                self.num_slots,
                self.ep_size,
            )

            if x_sf is not None:
                x_sf = x_sf.flatten()

            if token_final_scales is not None:
                token_final_scales = token_final_scales.to(torch.bfloat16)

        elif run_post_quant_allgather:
            if x_sf is not None:
                x_sf = x_sf.view(x_row, ceil_div(x_col,
                                                 self.scaling_vector_size))
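The reshape around the alltoall above keeps quantized activations and their block scale factors consistent: x_sf is viewed as [x_row, ceil_div(x_col, self.scaling_vector_size)] so the communication permutes whole per-token rows, and it is flattened back afterwards. A toy illustration of that idea; the permutation and sizes are made up and stand in for the MnnvlMoe exchange:

import torch

tokens, groups = 4, 3
x_sf_flat = torch.arange(tokens * groups)
x_sf_2d = x_sf_flat.view(tokens, groups)     # one row of block scales per token
perm = torch.tensor([2, 0, 3, 1])            # stand-in for the alltoall send order
shuffled = x_sf_2d[perm].flatten()           # flattened again after communication
print(shuffled.tolist())                     # [6, 7, 8, 0, 1, 2, 9, 10, 11, 3, 4, 5]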
@@ -265,6 +353,9 @@ class TRTLLMGenFusedMoE(MoE):
            if x_sf is not None:
                x_sf = x_sf.flatten()

        router_logits_arg = router_logits if not post_quant_comm else None
        routing_bias_arg = routing_bias if not post_quant_comm else None

        # TODO: since routing kernel is integrated into moe_runner for fp8,
        # here we just route the I/Os for moe_runner
        if self.has_deepseek_fp8_block_scales:
@@ -272,8 +363,8 @@ class TRTLLMGenFusedMoE(MoE):
            x_val, x_scale = torch.ops.trtllm.fp8_quantize_1x128(x)

            final_hidden_states = torch.ops.trtllm.fp8_block_scale_moe_runner(
                router_logits if not run_post_quant_allgather else None,
                routing_bias if not run_post_quant_allgather else None,
                router_logits_arg,
                routing_bias_arg,
                x_val,
                x_scale,
                self.w3_w1_weight,
@@ -297,7 +388,7 @@ class TRTLLMGenFusedMoE(MoE):
            scale_factor_use_ue8m0 = False
            is_scale_factor_swizzled = False  # use linear layout here

            if not run_post_quant_allgather:
            if not post_quant_comm:
                hidden_states_fp4, hidden_states_scale_linear_fp4 = (
                    torch.ops.trtllm.fp4_quantize(
                        x,
@@ -310,8 +401,8 @@ class TRTLLMGenFusedMoE(MoE):
                hidden_states_fp4, hidden_states_scale_linear_fp4 = x, x_sf

            outputs = torch.ops.trtllm.fp4_block_scale_moe_runner(
                router_logits if not run_post_quant_allgather else None,
                routing_bias if not run_post_quant_allgather else None,
                router_logits_arg,
                routing_bias_arg,
                hidden_states_fp4,
                hidden_states_scale_linear_fp4.view(torch.float8_e4m3fn),
                self.w3_w1_weight,
@@ -343,7 +434,7 @@ class TRTLLMGenFusedMoE(MoE):
            final_hidden_states = outputs[0]
        elif self.has_w4a16_mxfp4:
            assert x.dtype == torch.bfloat16
            if not run_post_quant_allgather:
            if not post_quant_comm:
                pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
                x = torch.nn.functional.pad(x, (0, pad_size))
            else:
@@ -352,8 +443,8 @@ class TRTLLMGenFusedMoE(MoE):
            intermediate_size_per_partition_padded = self.w3_w1_weight.shape[
                -2] // 2
            final_hidden_states = torch.ops.trtllm.bf16_mxe2m1_block_scale_moe_runner(
                router_logits if not run_post_quant_allgather else None,
                routing_bias if not run_post_quant_allgather else None,
                router_logits_arg,
                routing_bias_arg,
                x,
                self.w3_w1_weight,
                self.w3_w1_weight_scale,
@@ -383,15 +474,15 @@ class TRTLLMGenFusedMoE(MoE):
                                                      hidden_size].contiguous()
        elif self.has_w4a8_nvfp4_fp8:

            if not run_post_quant_allgather:
            if not post_quant_comm:
                hidden_states_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
                    x, 1.0 / self.fc31_input_scale)
            else:
                hidden_states_fp8 = x

            outputs = torch.ops.trtllm.fp8_fp4_block_scale_moe_runner(
                router_logits,
                routing_bias,
                router_logits_arg,
                routing_bias_arg,
                hidden_states_fp8,
                self.w3_w1_weight,
                self.w3_w1_weight_scale.view(torch.float8_e4m3fn),
@@ -423,7 +514,7 @@ class TRTLLMGenFusedMoE(MoE):
            final_hidden_states = outputs[0]
        elif self.has_w4a8_mxfp4_fp8:
            pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
            if not run_post_quant_allgather:
            if not post_quant_comm:
                x = torch.nn.functional.pad(x, (0, pad_size))
                x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
                    x, self.fc31_input_gate_dequant[0])
@@ -433,8 +524,8 @@ class TRTLLMGenFusedMoE(MoE):
                -2] // 2

            final_hidden_states = torch.ops.trtllm.e4m3_mxe2m1_block_scale_moe_runner(
                router_logits if not run_post_quant_allgather else None,
                routing_bias if not run_post_quant_allgather else None,
                router_logits_arg,
                routing_bias_arg,
                x,
                self.w3_w1_weight,
                self.w3_w1_weight_scale,
@@ -466,7 +557,7 @@ class TRTLLMGenFusedMoE(MoE):
            final_hidden_states = final_hidden_states[:, :self.
                                                      hidden_size].contiguous()
        elif self.has_w4a8_mxfp4_mxfp8:
            if not run_post_quant_allgather:
            if not post_quant_comm:
                # TRTLLM-Gen uses linear SF layout for the mxfp8 input.
                mxfp8_x, sf = torch.ops.trtllm.mxfp8_quantize(
                    x, False, alignment=self.quant_method.weight_alignment)
@@ -477,8 +568,8 @@ class TRTLLMGenFusedMoE(MoE):
                -2] // 2

            final_hidden_states = torch.ops.trtllm.mxe4m3_mxe2m1_block_scale_moe_runner(
                router_logits if not run_post_quant_allgather else None,
                routing_bias if not run_post_quant_allgather else None,
                router_logits_arg,
                routing_bias_arg,
                mxfp8_x,
                sf,
                self.w3_w1_weight,
@@ -511,6 +602,18 @@ class TRTLLMGenFusedMoE(MoE):
                "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_mxfp8 and w4a8_mxfp4_fp8 dtypes."
            )

        # Combine results if using alltoall
        if self.enable_alltoall and alltoall_info is not None:
            final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
                final_hidden_states,
                alltoall_info,
                self.alltoall_workspace,
                ep_rank=self.ep_rank,
                ep_size=self.ep_size,
                top_k=top_k,
                token_count=token_count,
            )

        final_hidden_states = self.reducescatter_or_allreduce(
            final_hidden_states,
            all_rank_num_tokens=all_rank_num_tokens,