Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-04 18:21:52 +08:00)
[None][fix] Fix enable_alltoall passed to CutlassFusedMoE (#11016)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Parent: 24ac86c485
Commit: 34a730aaf7
@@ -206,7 +206,11 @@ if args.run_type == "GEN":
         max(1, 20480 // ctx_seq_len_q),
     )
     ctx_attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8)
-    with mock.patch.dict(os.environ, {"TRTLLM_FORCE_ALLTOALL_METHOD": "NotEnabled"}, clear=False):
+    with mock.patch.dict(
+        os.environ,
+        {"TRTLLM_FORCE_ALLTOALL_METHOD": "NotEnabled", "TRTLLM_FORCE_COMM_METHOD": "ALLGATHER"},
+        clear=False,
+    ):
         ctx_runner = Runner(
             args.model,
             mapping,
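For context, the test above forces a specific communication path by overriding environment variables only for the duration of the `with` block. Below is a minimal, self-contained sketch of the same pattern; the `Runner`, `args`, and `mapping` objects from the test are not reproduced here, and `read_comm_method` is a hypothetical stand-in for the production code that consults these variables.

```python
import os
from unittest import mock


def read_comm_method() -> str:
    # Hypothetical stand-in for code that inspects the TRTLLM_FORCE_*
    # environment variables to pick a communication method.
    if os.environ.get("TRTLLM_FORCE_ALLTOALL_METHOD") == "NotEnabled":
        return os.environ.get("TRTLLM_FORCE_COMM_METHOD", "ALLGATHER")
    return "ALLTOALL"


# patch.dict with clear=False overlays the given keys on top of the existing
# environment and restores the original values when the block exits.
with mock.patch.dict(
    os.environ,
    {"TRTLLM_FORCE_ALLTOALL_METHOD": "NotEnabled", "TRTLLM_FORCE_COMM_METHOD": "ALLGATHER"},
    clear=False,
):
    assert read_comm_method() == "ALLGATHER"
```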
@@ -313,7 +313,7 @@ class ConfigurableMoE(MoE):
         2. Validates if current AllToAll strategy can be used for given workload
         3. Falls back to AllGather if current strategy cannot be used (logs info message)

-        After calling this method, use _is_using_alltoall() to check which method is active.
+        After calling this method, use enable_alltoall to check which method is active.

         Args:
             all_rank_num_tokens: Token counts per rank
@@ -348,23 +348,6 @@ class ConfigurableMoE(MoE):
             # Switch to AllGather (always works)
             self.comm = AllGatherReduceScatter(mapping=self.mapping)

-    def _is_using_alltoall(self) -> bool:
-        """
-        Check if current communication strategy uses alltoall
-
-        Returns:
-            True: Strategy uses alltoall (NVLINK, DeepEP, etc.)
-            False: Strategy uses allgather (AllGatherReduceScatter or None)
-
-        Note: Can be called anytime. If comm is None, returns False (no alltoall).
-        Typically called after determine_communication_method() to get accurate result.
-        """
-        if self.comm is None:
-            return False  # No strategy means no alltoall
-
-        # AllGather uses allgather, all others use alltoall
-        return not isinstance(self.comm, AllGatherReduceScatter)
-
     def _create_comm_strategy_auto(self) -> Communication:
         """
         Auto-create the best communication strategy based on hardware and configuration
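The removed helper derived the alltoall flag from the type of `self.comm`; after this change the class relies on its `enable_alltoall` attribute instead. A minimal sketch of the two equivalent checks, assuming simplified placeholder classes for `Communication` and `AllGatherReduceScatter` (the real classes live elsewhere in the codebase):

```python
from typing import Optional


class Communication:  # placeholder for the real base class
    pass


class AllGatherReduceScatter(Communication):  # placeholder
    def __init__(self, mapping=None):
        self.mapping = mapping


class MoEStub:
    """Illustrates the old helper next to the attribute that replaces it."""

    def __init__(self, comm: Optional[Communication], enable_alltoall: bool):
        self.comm = comm
        self.enable_alltoall = enable_alltoall

    def _is_using_alltoall(self) -> bool:
        # Old behavior: no strategy or AllGather means "not alltoall".
        if self.comm is None:
            return False
        return not isinstance(self.comm, AllGatherReduceScatter)


moe = MoEStub(comm=AllGatherReduceScatter(), enable_alltoall=False)
assert moe._is_using_alltoall() == moe.enable_alltoall  # both report False here
```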
@@ -810,11 +793,7 @@ class ConfigurableMoE(MoE):

         Same as original implementation - chunking logic is backend-agnostic

-        Note: use_all_to_all is determined internally via _is_using_alltoall()
-
         """
-        # Determine if using alltoall
-        use_all_to_all = self._is_using_alltoall()
         # ========== Chunk preparation ==========
         if self.use_dp:
             # When using DP: need all ranks' token counts for reducescatter
@@ -828,7 +807,7 @@ class ConfigurableMoE(MoE):
             chunk_size_list = all_rank_chunk_size_list[self.rank]

             # For alltoall, replace 0 with 1 (avoid empty tensor)
-            if use_all_to_all:
+            if self.enable_alltoall:
                 all_rank_num_tokens_list = [
                     [1 if val == 0 else val for val in val_list]
                     for val_list in all_rank_num_tokens_list
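To make the intent of the list comprehension above concrete, here is a standalone example of the zero-to-one substitution that avoids zero-sized chunks when alltoall is enabled; the nested-list shape mirrors `all_rank_num_tokens_list`, but the values are made up.

```python
# Per-chunk token counts for each rank; zeros would produce empty tensors.
all_rank_num_tokens_list = [
    [8, 0, 4],  # chunk 0: rank 1 has no tokens
    [0, 2, 6],  # chunk 1: rank 0 has no tokens
]

# Clamp zero counts to one so every rank contributes a non-empty tensor
# to the alltoall exchange.
padded = [
    [1 if val == 0 else val for val in val_list]
    for val_list in all_rank_num_tokens_list
]

assert padded == [[8, 1, 4], [1, 2, 6]]
```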
@@ -842,7 +821,7 @@ class ConfigurableMoE(MoE):
         router_logits_list = router_logits.split(chunk_size_list)

         # Determine if we need multiple streams for overlapped execution
-        use_multi_stream = not use_all_to_all and self.aux_stream is not None
+        use_multi_stream = not self.enable_alltoall and self.aux_stream is not None

         # ========== Setup auxiliary stream ==========
         if use_multi_stream:
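The `use_multi_stream` flag gates overlapping chunk processing on an auxiliary CUDA stream when alltoall is not used. A small, self-contained illustration of that overlap pattern with plain PyTorch streams; the chunk tensors and the `work` function are invented for the example and do not reflect the real per-chunk MoE computation.

```python
import torch

if torch.cuda.is_available():
    main_stream = torch.cuda.current_stream()
    aux_stream = torch.cuda.Stream()

    chunks = [torch.randn(1024, 1024, device="cuda") for _ in range(4)]

    def work(t: torch.Tensor) -> torch.Tensor:
        # Invented stand-in for the per-chunk computation.
        return t @ t

    results = []
    for i, chunk in enumerate(chunks):
        # Alternate chunks between the two streams so kernels can overlap.
        stream = aux_stream if i % 2 else main_stream
        if stream is aux_stream:
            stream.wait_stream(main_stream)  # aux work starts after inputs are ready
        with torch.cuda.stream(stream):
            results.append(work(chunk))

    # Rejoin: the main stream waits for the auxiliary stream before using results.
    main_stream.wait_stream(aux_stream)
    torch.cuda.synchronize()
```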
@@ -1086,7 +1065,8 @@ class ConfigurableMoE(MoE):
         # Only the non-alltoall case is considered for profiling in the warmup phase.
         # Therefore, to get the correct tactics during the actual inference, the inputs to the tuner
         # should be the same as when not using alltoall.
-        if self._is_using_alltoall():
+        kwargs["enable_alltoall"] = self.enable_alltoall
+        if self.enable_alltoall:
             if all_rank_num_tokens is not None:
                 kwargs["tuner_num_tokens"] = sum(all_rank_num_tokens)
             else:
@@ -1094,9 +1074,6 @@ class ConfigurableMoE(MoE):
                     x.shape[0] * self.mapping.tp_size if x is not None else None
                 )
             kwargs["tuner_top_k"] = self.routing_method.top_k
-        else:
-            kwargs["tuner_num_tokens"] = None
-            kwargs["tuner_top_k"] = None

         # Get moe_output for NVLinkOneSided backend
         kwargs["moe_output"] = self._get_nvlink_onesided_moe_output(
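The warmup profiler only sees the non-alltoall path, so when alltoall is enabled the tuner inputs are reshaped to match that case, and the explicit `else` branch is dropped because the backend's run signature already defaults the tuner kwargs to `None`. A hedged sketch of that post-change branch structure; the function and parameter names here are illustrative, not the module's API.

```python
from typing import Optional


def build_tuner_kwargs(
    enable_alltoall: bool,
    all_rank_num_tokens: Optional[list],
    local_num_tokens: Optional[int],
    tp_size: int,
    top_k: int,
) -> dict:
    """Illustrative only: mirrors the post-change branch structure in the diff."""
    kwargs = {"enable_alltoall": enable_alltoall}
    if enable_alltoall:
        # Profile as if every rank's tokens were gathered locally, matching
        # what the warmup (non-alltoall) phase saw.
        if all_rank_num_tokens is not None:
            kwargs["tuner_num_tokens"] = sum(all_rank_num_tokens)
        else:
            kwargs["tuner_num_tokens"] = (
                local_num_tokens * tp_size if local_num_tokens is not None else None
            )
        kwargs["tuner_top_k"] = top_k
    # No else branch: the backend defaults tuner_num_tokens and tuner_top_k to None.
    return kwargs


# Example: three ranks contribute 5, 0, and 3 tokens respectively.
kw = build_tuner_kwargs(True, [5, 0, 3], None, tp_size=4, top_k=2)
assert kw["tuner_num_tokens"] == 8 and kw["tuner_top_k"] == 2
```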
@@ -411,6 +411,7 @@ class CutlassFusedMoE(MoE):
         tuner_num_tokens: Optional[int] = None,
         tuner_top_k: Optional[int] = None,
         moe_output: Optional[torch.Tensor] = None,
+        enable_alltoall: Optional[bool] = None,
     ) -> torch.Tensor:
         """
         Run MoE computation with Cutlass backend.
@@ -429,6 +430,7 @@ class CutlassFusedMoE(MoE):
             tuner_num_tokens: Number of tokens for profiling tuner (optional)
             tuner_top_k: Top-k value for profiling tuner (optional)
             moe_output: Pre-allocated output buffer (optional)
+            enable_alltoall: Whether alltoall communication is enabled (optional). If None, defaults to self.enable_alltoall.

         Returns:
             final_hidden_states: Output tensor from MoE computation
@@ -441,6 +443,9 @@ class CutlassFusedMoE(MoE):
         elif self.has_w4a16_mxfp4:
             weight_dtype = torch.uint8

+        if enable_alltoall is None:
+            enable_alltoall = self.enable_alltoall
+
         result = torch.ops.trtllm.fused_moe(
             x,
             token_selected_experts,
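The new parameter follows a common Optional-with-fallback pattern: a caller such as ConfigurableMoE can pass its own decision explicitly, while existing callers that omit the argument keep the previous behavior of reading the instance attribute. A minimal sketch of the pattern in isolation; the class and method below are stand-ins, not the real CutlassFusedMoE API.

```python
from typing import Optional


class Backend:
    """Stand-in class showing the Optional-parameter fallback used in the diff."""

    def __init__(self, enable_alltoall: bool):
        self.enable_alltoall = enable_alltoall

    def run(self, enable_alltoall: Optional[bool] = None) -> bool:
        # None means "not specified": fall back to the instance-level setting.
        if enable_alltoall is None:
            enable_alltoall = self.enable_alltoall
        return enable_alltoall


backend = Backend(enable_alltoall=True)
assert backend.run() is True                         # old callers: attribute is used
assert backend.run(enable_alltoall=False) is False   # new callers override per call
```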
@@ -462,7 +467,7 @@ class CutlassFusedMoE(MoE):
             ep_rank=self.ep_rank,
             cluster_size=self.cluster_size,
             cluster_rank=self.cluster_rank,
-            enable_alltoall=self.enable_alltoall,
+            enable_alltoall=enable_alltoall,
             use_deepseek_fp8_block_scale=self.has_deepseek_fp8_block_scales,
             use_w4_group_scaling=self.has_w4afp8 or self.has_w4a16_mxfp4,
             use_int8_woq_per_channel=self.has_int8_woq_per_channel,