feat: large-scale EP(part 7: DeepEP integration) (#4792)

Signed-off-by: Tailing Yuan <yuantailing@gmail.com>
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
This commit is contained in:
Tailing Yuan 2025-06-14 19:12:38 +08:00 committed by GitHub
parent 443b2eb51f
commit 0b60da2c45
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 610 additions and 88 deletions

View File

@ -1,7 +1,7 @@
version: "3.9"
services:
tensorrt_llm-dev:
image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420
image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506111045-4792
network_mode: host
ipc: host

View File

@ -72,6 +72,10 @@ RUN bash ./install_pytorch.sh $TORCH_INSTALL_TYPE && rm install_pytorch.sh
RUN pip3 uninstall -y opencv && rm -rf /usr/local/lib/python3*/dist-packages/cv2/
RUN pip3 install opencv-python-headless --force-reinstall --no-deps --no-cache-dir
# Install DeepEP
COPY docker/common/install_deep_ep.sh install_deep_ep.sh
RUN bash ./install_deep_ep.sh && rm install_deep_ep.sh
# WARs against security issues inherited from pytorch:25.04
# * https://github.com/advisories/GHSA-vqfr-h8mv-ghfj
# * https://github.com/advisories/GHSA-7cx3-6m66-7c5m

View File

@ -0,0 +1,47 @@
#!/bin/bash
set -euxo pipefail
GITHUB_URL=${GITHUB_MIRROR:-https://github.com}
DEEP_EP_COMMIT=2b266cf6452134f993ab0fcb3ef2d5de7683c561
if [ "$(. /etc/os-release && echo $ID)" == "rocky" ]; then
echo "Skipping DeepEP installation in the Rocky distribution."
exit 0
fi
libmlx5_dir=$(dirname $(ldconfig -p | grep libmlx5.so.1 | head -n1 | awk '{print $NF}'))
export NVCC_APPEND_FLAGS="--threads 4"
# Custom NVSHMEM
curl -fsSL https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz | tar xz
pushd nvshmem_src
curl -fsSL $GITHUB_URL/deepseek-ai/DeepEP/raw/$DEEP_EP_COMMIT/third-party/nvshmem.patch | patch -p1
sed "s/TRANSPORT_VERSION_MAJOR 3/TRANSPORT_VERSION_MAJOR 103/" -i src/CMakeLists.txt
ln -s libmlx5.so.1 "$libmlx5_dir/libmlx5.so"
cmake -S . -B build \
-DCMAKE_INSTALL_PREFIX=/opt/custom_nvshmem \
-DGDRCOPY_HOME=/usr/include \
-DNVSHMEM_SHMEM_SUPPORT=0 \
-DNVSHMEM_UCX_SUPPORT=0 \
-DNVSHMEM_USE_NCCL=0 \
-DNVSHMEM_MPI_SUPPORT=0 \
-DNVSHMEM_IBGDA_SUPPORT=1 \
-DNVSHMEM_PMIX_SUPPORT=0 \
-DNVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
-DNVSHMEM_USE_GDRCOPY=1 \
-DCMAKE_CUDA_ARCHITECTURES="90-real;100-real;120-real" \
-DNVSHMEM_BUILD_TESTS=0 \
-DNVSHMEM_BUILD_EXAMPLES=0
cmake --build build -j`nproc`
make -C build install
popd
# DeepEP
curl -fsSL $GITHUB_URL/deepseek-ai/DeepEP/archive/$DEEP_EP_COMMIT.tar.gz | tar xz
TORCH_CUDA_ARCH_LIST="9.0;10.0;12.0" NVSHMEM_DIR=/opt/custom_nvshmem pip install -v --no-cache-dir ./DeepEP-$DEEP_EP_COMMIT
# Clean up
rm -r nvshmem_src
rm "$libmlx5_dir/libmlx5.so"
rm -r DeepEP-$DEEP_EP_COMMIT

View File

@ -28,10 +28,10 @@ UPLOAD_PATH = env.uploadPath ? env.uploadPath : "sw-tensorrt-generic/llm-artifac
// Container configuration
// available tags can be found in: https://urm.nvidia.com/artifactory/sw-tensorrt-docker/tensorrt-llm/
// [base_image_name]-[arch]-[os](-[python_version])-[trt_version]-[torch_install_type]-[stage]-[date]-[mr_id]
LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420"
LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420"
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202506021004-9420"
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202506021004-9420"
LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506111045-4792"
LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506111045-4792"
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202506111045-4792"
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202506111045-4792"
// TODO: Move common variables to an unified location
BUILD_CORES_REQUEST = "8"

View File

@ -1,7 +1,7 @@
import java.lang.InterruptedException
DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420"
DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506111045-4792"
def createKubernetesPodConfig(image, arch = "amd64")
{

View File

@ -84,6 +84,9 @@ class ModelConfig(Generic[TConfig]):
# If true, enable min-latency mode. Currently only used for Llama4.
enable_min_latency: bool = False
# Allow models to select op according to whether CUDA Graphs are used.
use_cuda_graph: bool = False
extra_attrs: Dict = field(default_factory=dict, repr=False, init=False)
_frozen: bool = field(default=False, init=False, repr=False)

View File

@ -38,7 +38,6 @@ from torch import nn
from tqdm import tqdm
from transformers import PretrainedConfig
from tensorrt_llm._mnnvl_utils import MnnvlMemory
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.llmapi.utils import enable_llm_debug
from tensorrt_llm.mapping import Mapping
@ -413,10 +412,6 @@ class Deepseekv3MoE(nn.Module):
config = model_config.pretrained_config
self.top_k = top_k
self.use_dp = model_config.mapping.enable_attention_dp
self.enable_alltoall = Deepseekv3MoE.should_enable_alltoall(
model_config, top_k)
if self.enable_alltoall:
MnnvlMemory.initialize()
self.gate = DeepseekV3Gate(
hidden_size,
num_experts,
@ -439,7 +434,6 @@ class Deepseekv3MoE(nn.Module):
model_config=model_config,
override_quant_config=override_quant_config,
aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap],
enable_alltoall=self.enable_alltoall,
layer_idx=layer_idx)
self.mapping = model_config.mapping
@ -505,25 +499,6 @@ class Deepseekv3MoE(nn.Module):
return shared_tp_size, shared_output_scale
@staticmethod
def should_enable_alltoall(model_config: ModelConfig, top_k: int) -> bool:
if not model_config.mapping.enable_attention_dp:
return False
if model_config.mapping.tp_size == 1:
return False
if not MnnvlMemory.supports_mnnvl():
return False
if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
return False
if model_config.mapping.moe_ep_size <= top_k:
return False
return True
def compute_routed_output(self, hidden_states, hidden_states_fp4,
all_rank_num_tokens, do_finalize):
# max-throughput
@ -531,7 +506,7 @@ class Deepseekv3MoE(nn.Module):
if self.use_dp and self.mapping.tp_size > 1:
# FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
# to reduce allreduce BW
if disable_fp4_allgather() and not self.enable_alltoall:
if disable_fp4_allgather() and not self.experts.enable_alltoall:
hidden_states = allgather(hidden_states,
self.mapping,
dim=0,

View File

@ -6,8 +6,6 @@ from torch import nn
from tqdm import tqdm
from transformers import Qwen3MoeConfig
from tensorrt_llm._mnnvl_utils import MnnvlMemory
from ..attention_backend import AttentionMetadata
from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
allgather)
@ -91,10 +89,6 @@ class Qwen3MoE(nn.Module):
self.mapping = model_config.mapping
self.allreduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_strategy)
self.enable_alltoall = Qwen3MoE.should_enable_alltoall(
model_config, self.top_k)
if self.enable_alltoall:
MnnvlMemory.initialize()
self.gate = Qwen3Gate(
hidden_size=self.hidden_dim,
@ -117,25 +111,6 @@ class Qwen3MoE(nn.Module):
model_config=model_config,
)
@staticmethod
def should_enable_alltoall(model_config: ModelConfig, top_k: int) -> bool:
if not model_config.mapping.enable_attention_dp:
return False
if model_config.mapping.tp_size == 1:
return False
if not MnnvlMemory.supports_mnnvl():
return False
if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
return False
if model_config.mapping.moe_ep_size <= top_k:
return False
return True
def forward(
self,
hidden_states: torch.Tensor,
@ -151,7 +126,7 @@ class Qwen3MoE(nn.Module):
if self.enable_attention_dp and self.mapping.tp_size > 1:
# FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
# to reduce allreduce BW
if disable_fp4_allgather() and not self.enable_alltoall:
if disable_fp4_allgather() and not self.experts.enable_alltoall:
hidden_states = allgather(hidden_states,
self.mapping,
dim=0,

View File

@ -52,7 +52,6 @@ def create_moe(
aux_stream: Optional[torch.cuda.Stream] = None,
weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA,
apply_router_weight_on_input: bool = False,
enable_alltoall: bool = False,
layer_idx: Optional[int] = None,
) -> MoE:
moe_cls = get_moe_cls(model_config, override_quant_config)
@ -63,7 +62,6 @@ def create_moe(
if moe_cls == TRTLLMGenFusedMoE:
assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in TRTLLMGenFusedMoE."
assert not enable_alltoall, "enable_alltoall is not supported in TRTLLMGenFusedMoE."
return moe_cls(
routing_method=routing_method,
@ -88,12 +86,10 @@ def create_moe(
aux_stream=aux_stream,
weight_loading_mode=weight_loading_mode,
apply_router_weight_on_input=apply_router_weight_on_input,
enable_alltoall=enable_alltoall,
layer_idx=layer_idx,
)
elif moe_cls == VanillaMoE:
assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in VanillaMoE."
assert not enable_alltoall, "enable_alltoall is not supported in VanillaMoE."
return moe_cls(
routing_method=routing_method,

View File

@ -0,0 +1,209 @@
# Adapted from
# https://github.com/deepseek-ai/DeepEP/blob/aae9fa9a6dd0fec2a723fbb85ec4b22460fab670/README.md
import weakref
from typing import List, Tuple, Union
import torch
from tensorrt_llm._utils import local_mpi_size, mpi_comm
from tensorrt_llm.mapping import Mapping
try:
from deep_ep import Buffer
deep_ep_installed = True
except ModuleNotFoundError:
deep_ep_installed = False
class VariableLengthBuffer:
""" A wrapper of deep_ep.Buffer that accepts future size change
"""
def __init__(self, mapping: Mapping):
self.comm = mpi_comm().Split(mapping.pp_rank, mapping.moe_ep_rank)
self.buffer = None
def __del__(self):
self.comm.Free()
def reserve(self, hidden_size: int, hidden_dtype: torch.dtype):
""" Ensure the buffer capacity is large enough.
Reserve is a collective operation that requires all EP ranks to be sync
"""
# NOTES: you may also replace `get_*_config` with your auto-tuned results via all the tests
num_nvl_bytes, num_rdma_bytes = 0, 0
hidden_bytes = hidden_size * max(hidden_dtype.itemsize,
torch.bfloat16.itemsize)
world_size = self.comm.Get_size()
for config in (Buffer.get_dispatch_config(world_size),
Buffer.get_combine_config(world_size)):
num_nvl_bytes = max(
config.get_nvl_buffer_size_hint(hidden_bytes, world_size),
num_nvl_bytes)
num_rdma_bytes = max(
config.get_rdma_buffer_size_hint(hidden_bytes, world_size),
num_rdma_bytes)
# Allocate a buffer if not existed or not enough buffer size
if self.buffer is None or self.buffer.num_nvl_bytes < num_nvl_bytes or self.buffer.num_rdma_bytes < num_rdma_bytes:
if self.buffer is not None:
num_nvl_bytes = max(num_nvl_bytes, self.buffer.num_nvl_bytes)
num_rdma_bytes = max(num_rdma_bytes, self.buffer.num_rdma_bytes)
del self.buffer # Destruct before Construct
self.buffer = Buffer(None,
num_nvl_bytes,
num_rdma_bytes,
num_nvl_peers=local_mpi_size(),
comm=self.comm)
def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
topk_idx: torch.Tensor, topk_weights: torch.Tensor,
num_experts: int) -> \
Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple]:
# NOTES: an optional `previous_event` means a CUDA event captured that you want to make it as a dependency
# of the dispatch kernel, it may be useful with communication-computation overlap. For more information, please
# refer to the docs of `Buffer.dispatch`
# Calculate layout before actual dispatch
num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = \
self.buffer.get_dispatch_layout(topk_idx, num_experts)
assert event.event is None
# Do MoE dispatch
# NOTES: the CPU will wait for GPU's signal to arrive, so this is not compatible with CUDA graph
# For more advanced usages, please refer to the docs of the `dispatch` function
recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
self.buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert)
assert event.event is None
# For event management, please refer to the docs of the `EventOverlap` class
return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle
def combine(self, x: torch.Tensor, handle: Tuple) -> torch.Tensor:
# Do MoE combine
# For more advanced usages, please refer to the docs of the `combine` function
combined_x, _, event = self.buffer.combine(x, handle)
assert event.event is None
# For event management, please refer to the docs of the `EventOverlap` class
return combined_x
class VariableLengthLowLatencyBuffer:
""" A wrapper of deep_ep.Buffer that accepts future size change
"""
def __init__(self, mapping: Mapping):
self.comm = mpi_comm().Split(mapping.pp_rank, mapping.moe_ep_rank)
self.buffer = None
self.num_max_dispatch_tokens_per_rank = None
def __del__(self):
self.comm.Free()
def reserve(self, num_max_dispatch_tokens_per_rank: int, hidden_size: int,
num_experts: int):
""" Ensure the buffer capacity is large enough.
Reserve is a collective operation that requires all EP ranks to be sync
"""
# NOTES: the low-latency mode will consume much more space than the normal mode
# So we recommend that `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
world_size = self.comm.Get_size()
num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank, hidden_size, world_size,
num_experts)
# Allocate a buffer if not existed or not enough buffer size
if self.buffer is None or self.buffer.num_rdma_bytes < num_rdma_bytes:
# NOTES: for best performance, the QP number **must** be equal to the number of the local experts
assert num_experts % world_size == 0
del self.buffer # Destruct before Construct
self.buffer = Buffer(None,
0,
num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_experts // world_size,
comm=self.comm)
def low_latency_dispatch(self, hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
num_max_dispatch_tokens_per_rank: int,
num_experts: int):
if self.num_max_dispatch_tokens_per_rank is None:
self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
if num_max_dispatch_tokens_per_rank != self.num_max_dispatch_tokens_per_rank:
raise NotImplementedError(
"There are issues if `low_latency_dispatch` calls use different `num_max_dispatch_tokens_per_rank` values"
)
# Do MoE dispatch, compatible with CUDA graph (but you may restore some buffer status once you replay)
recv_hidden_states, recv_expert_count, handle, event, hook = \
self.buffer.low_latency_dispatch(hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts, use_fp8=False)
assert event.event is None
assert hook is None
# NOTES: the actual tensor will not be received only if you call `hook()`,
# it is useful for double-batch overlapping, but **without any SM occupation**
# If you don't want to overlap, please set `return_recv_hook=False`
# Later, you can use our GEMM library to do the computation with this specific format
return recv_hidden_states, recv_expert_count, handle
def low_latency_combine(self, hidden_states: torch.Tensor,
topk_idx: torch.Tensor, topk_weights: torch.Tensor,
handle: Tuple):
# Do MoE combine, compatible with CUDA graph (but you may restore some buffer status once you replay)
combined_hidden_states, event, hook = \
self.buffer.low_latency_combine(hidden_states, topk_idx, topk_weights, handle)
assert event.event is None
assert hook is None
# NOTES: the same behavior as described in the dispatch kernel
return combined_hidden_states
class BufferPool:
""" A pool that allocates buffers on demand.
Although the pool interface allows creating multiple buffers, the
current version of DeepEP supports at most one `deep_ep.Buffer` at a
time. Please ensure that all references to `VariableLengthBuffer` are
released before getting another buffer.
"""
def __init__(self):
self.buffers: Map[Mapping,
weakref.ReferenceType[VariableLengthBuffer]] = {}
self.low_latency_buffers: Map[
Mapping,
weakref.ReferenceType[VariableLengthLowLatencyBuffer]] = {}
def get_buffer(self, mapping: Mapping) -> VariableLengthBuffer:
""" Get_buffer is a collective operation that requires all ranks to be sync
"""
if mapping in self.buffers and self.buffers[mapping]() is not None:
buffer = self.buffers[mapping]()
else:
buffer = VariableLengthBuffer(mapping)
self.buffers[mapping] = weakref.ref(buffer)
return buffer
def get_low_latency_buffer(
self, mapping: Mapping) -> VariableLengthLowLatencyBuffer:
""" Get_low_latency_buffer is a collective operation that requires all ranks to be sync
"""
if mapping in self.low_latency_buffers and self.low_latency_buffers[
mapping]() is not None:
buffer = self.low_latency_buffers[mapping]()
else:
buffer = VariableLengthLowLatencyBuffer(mapping)
self.low_latency_buffers[mapping] = weakref.ref(buffer)
return buffer
# The default pool
# You may create own pools for better resource management.
buffer_pool = BufferPool()

View File

@ -1,16 +1,19 @@
import os
from enum import IntEnum
from typing import Dict, List, Optional, Tuple, Union
import torch
from tensorrt_llm._mnnvl_utils import MnnvlMoe, MoEAlltoallInfo
from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
from tensorrt_llm._utils import logger
from tensorrt_llm.mapping import Mapping
from ...distributed import allgather, reducescatter
from ...expert_statistic import ExpertStatistic
from ...model_config import ModelConfig
from ...utils import (EventType, Fp4QuantizedTensor, disable_fp4_allgather,
reswizzle_sf, swizzle_sf, unswizzle_sf)
from .deep_ep_utils import buffer_pool, deep_ep_installed
from .interface import MoE
from .moe_load_balancer import get_moe_load_balancer
from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
@ -20,6 +23,18 @@ from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
from .routing import BaseMoeRoutingMethod
# The type of alltoall method
class AlltoallMethodType(IntEnum):
# Not available
NotEnabled = 0
# MNNVL
MNNVL = 1
# DeepEP intranode or internode: no CUDA Graphs support, IBGDA is required by internode
DeepEP = 2
# DeepEP low latency: CUDA Graphs are supported, IBGDA is required
DeepEPLowLatency = 3
class CutlassFusedMoE(MoE):
"""
Fused Mixture of Experts (MoE) Layer with performance tuning.
@ -33,7 +48,6 @@ class CutlassFusedMoE(MoE):
dtype (Optional[torch.dtype]): Data type for the weights.
reduce_results (bool): Whether to reduce the results across devices.
model_config (ModelConfig): Configuration object for the model.
enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
MoE torch custom op:
In min-latency mode:
@ -82,7 +96,6 @@ class CutlassFusedMoE(MoE):
weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
VANILLA,
apply_router_weight_on_input: bool = False,
enable_alltoall: bool = False,
layer_idx: Optional[int] = None,
):
@ -176,7 +189,12 @@ class CutlassFusedMoE(MoE):
self.has_been_profiled = False
self.has_been_profiled_min_latency = False
self.enable_alltoall = enable_alltoall
self.alltoall_method_type = self.select_alltoall_method_type(
model_config.mapping, routing_method.experts_per_token, dtype,
model_config.use_cuda_graph)
logger.info_once(
f"CutlassFusedMoE selects alltoall_method_type {self.alltoall_method_type!r}",
key="alltoall_method_type")
self.use_postquant_alltoall = False
if self.enable_alltoall:
assert self.use_dp and self.parallel_size > 1,\
@ -185,8 +203,25 @@ class CutlassFusedMoE(MoE):
self.use_postquant_alltoall = (os.environ.get(
"TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1")
== "1") and qm.has_nvfp4()
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
model_config.mapping) if enable_alltoall else None
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
MnnvlMemory.initialize()
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
model_config.mapping)
elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
self.deep_ep_buffer = buffer_pool.get_buffer(
model_config.mapping)
self.deep_ep_buffer.reserve(hidden_size, dtype)
elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
self.deep_ep_max_num_tokens = min(model_config.max_num_tokens,
self.moe_max_num_tokens)
self.deep_ep_buffer = buffer_pool.get_low_latency_buffer(
model_config.mapping)
self.deep_ep_buffer.reserve(self.deep_ep_max_num_tokens,
hidden_size, self.num_slots)
else:
raise NotImplementedError(
f"Not available alltoall method type: {alltoall_method_type!r}"
)
# If True, the router weight will be multiplied on the input rather than at the end of FC2
self.apply_router_weight_on_input = apply_router_weight_on_input
@ -215,12 +250,48 @@ class CutlassFusedMoE(MoE):
f"unsupported quantization mode: {self.quant_config.quant_mode}"
)
@staticmethod
def select_alltoall_method_type(mapping: Mapping, top_k: int,
dtype: torch.dtype,
use_cuda_graph: bool) -> AlltoallMethodType:
if not mapping.enable_attention_dp:
return AlltoallMethodType.NotEnabled
if mapping.tp_size == 1:
return AlltoallMethodType.NotEnabled
if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
return AlltoallMethodType.NotEnabled
if mapping.moe_ep_size <= top_k:
return AlltoallMethodType.NotEnabled
if MnnvlMemory.supports_mnnvl():
return AlltoallMethodType.MNNVL
if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "0") == "1":
if deep_ep_installed and dtype == torch.bfloat16:
if use_cuda_graph:
# Here we can only choose DeepEPLowLatency since only this method supports CUDA Graphs.
return AlltoallMethodType.DeepEPLowLatency
else:
# Here we can choose DeepEP or DeepEPLowLatency if both are available. Now DeepEP is faster.
return AlltoallMethodType.DeepEP
return AlltoallMethodType.NotEnabled
@property
def has_w4afp8(self):
assert self._weights_created
return self.quant_config and self.quant_config.quant_mode.is_int4_weight_only_per_group(
)
@property
def enable_alltoall(self):
""" enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
"""
return self.alltoall_method_type != AlltoallMethodType.NotEnabled
def _get_quant_method(self):
if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
exclude_kv_cache=True):
@ -311,10 +382,6 @@ class CutlassFusedMoE(MoE):
# TODO: remove this once we have correct fusedmoe kernel ready
token_final_scales = None
token_count = x.shape[0]
alltoall_info = None
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
) and is_first_call:
self.layer_load_balancer.maybe_cudagraph_done_wait()
@ -334,13 +401,58 @@ class CutlassFusedMoE(MoE):
ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
token_selected_experts_for_statistic = token_selected_experts if need_statistic else None
if self.enable_alltoall:
x, token_selected_slots, token_final_scales, token_selected_experts_for_statistic, alltoall_info = \
self.alltoall_prepare_maybe_dispatch(all_rank_num_tokens,
x,
token_selected_slots,
token_final_scales,
token_selected_experts_for_statistic)
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
token_count = x.shape[0]
alltoall_info = None
x, token_selected_slots, token_final_scales, token_selected_experts_for_statistic, alltoall_info = \
self.alltoall_prepare_maybe_dispatch(all_rank_num_tokens,
x,
token_selected_slots,
token_final_scales,
token_selected_experts_for_statistic)
elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
if not self.use_postquant_alltoall:
x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
self.deep_ep_buffer.dispatch(x, token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
if not self.use_postquant_alltoall:
deep_ep_topk_idx = token_selected_slots.to(torch.int64)
deep_ep_topk_weights = token_final_scales
x, recv_expert_count, deep_ep_handle = \
self.deep_ep_buffer.low_latency_dispatch(x, deep_ep_topk_idx, self.deep_ep_max_num_tokens, self.num_slots)
# x shape: [#local experts, #max recv tokens, hidden_size]
# recv_expert_count shape: [#local experts]
# Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
# TODO: remove the adapter by changing `torch.ops.trtllm.fused_moe` API
mask = torch.arange(
x.shape[1], dtype=torch.int32, device=x.device).expand(
x.shape[0],
x.shape[1]) < recv_expert_count.unsqueeze(1)
token_selected_slots = torch.full(
(x.shape[0], x.shape[1], self.routing_method.top_k),
self.num_slots,
dtype=torch.int32,
device=x.device)
token_selected_slots[:, :, 0] = torch.where(
mask,
torch.arange(
x.shape[0] * self.mapping.moe_ep_rank,
x.shape[0] * (self.mapping.moe_ep_rank + 1),
dtype=torch.int32,
device=x.device).unsqueeze(1), self.num_slots)
x = x.view(x.shape[0] * x.shape[1], x.shape[2])
token_selected_slots = token_selected_slots.view(
x.shape[0], self.routing_method.top_k)
token_final_scales = torch.ones_like(
token_selected_slots, dtype=token_final_scales.dtype)
else:
raise NotImplementedError(
f"Not available alltoall method type: {alltoall_method_type!r}"
)
x_sf = None
if self.has_any_quant:
if self.has_fp8_qdq:
@ -414,8 +526,56 @@ class CutlassFusedMoE(MoE):
quant_scales = self.quant_scales
if self.use_postquant_alltoall:
x, x_sf = self.alltoall_postquant_dispatch(x, x_sf, x_row, x_col,
alltoall_info)
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
x, x_sf = self.alltoall_postquant_dispatch(
x, x_sf, x_row, x_col, alltoall_info)
elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
if x_sf is not None:
if self.has_nvfp4:
x_sf = unswizzle_sf(x_sf, x_row, x_col,
self.scaling_vector_size)
# Adapter between `x_sf` and DeepEP
# TODO: remove the adapter by adding dtype support to DeepEP
x_sf_dtype = x_sf.dtype
x_sf = x_sf.view(torch.float32)
(x, x_sf), recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
self.deep_ep_buffer.dispatch((x, x_sf), token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
if x_sf is not None:
x_sf = x_sf.view(x_sf_dtype)
if self.has_nvfp4:
x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2,
self.scaling_vector_size)
elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
raise NotImplementedError(
"Not implemented postquant for DeepEPLowLatency, please set TRTLLM_MOE_POST_QUANT_ALLTOALLV=0"
)
else:
raise NotImplementedError(
f"Not available alltoall method type: {alltoall_method_type!r}"
)
if self.enable_alltoall:
# Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
# TODO: remove the adapter by changing APIs
if self.alltoall_method_type == AlltoallMethodType.DeepEP:
token_selected_slots = recv_topk_idx.to(torch.int32)
mask = token_selected_slots == -1
token_selected_slots += self.expert_size_per_partition * self.mapping.moe_ep_rank
token_selected_slots[mask] = self.num_slots
num_recv_token_is_zero = x.shape[0] == 0
if x.shape[0] == 0:
x = torch.zeros((1, x.shape[1]),
dtype=x.dtype,
device=x.device)
token_selected_slots = torch.full(
(1, token_selected_slots.shape[1]),
self.num_slots,
dtype=token_selected_slots.dtype,
device=token_selected_slots.device)
token_final_scales = torch.ones(
(1, token_final_scales.shape[1]),
dtype=token_final_scales.dtype,
device=token_final_scales.device)
final_hidden_states = torch.ops.trtllm.fused_moe(
x,
@ -452,9 +612,25 @@ class CutlassFusedMoE(MoE):
final_hidden_states = final_hidden_states[0]
if self.enable_alltoall:
final_hidden_states = self.alltoall_combine(final_hidden_states,
alltoall_info,
token_count)
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
final_hidden_states = self.alltoall_combine(
final_hidden_states, alltoall_info, token_count)
elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
if num_recv_token_is_zero:
final_hidden_states = final_hidden_states[:0]
final_hidden_states = self.deep_ep_buffer.combine(
final_hidden_states, deep_ep_handle)
elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
final_hidden_states = self.deep_ep_buffer.low_latency_combine(
final_hidden_states.view(
self.expert_size_per_partition,
self.deep_ep_max_num_tokens * self.mapping.moe_ep_size,
final_hidden_states.shape[1]), deep_ep_topk_idx,
deep_ep_topk_weights, deep_ep_handle)
else:
raise NotImplementedError(
f"Not available alltoall method type: {alltoall_method_type!r}"
)
if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
) and is_last_call:

View File

@ -23,7 +23,6 @@ class TRTLLMGenFusedMoE(MoE):
dtype (Optional[torch.dtype]): Data type for the weights.
reduce_results (bool): Whether to reduce the results across devices.
model_config (ModelConfig): Configuration object for the model.
enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
MoE torch custom op:
Only support min-latency mode now (SM100 Blackwell only).

View File

@ -89,8 +89,6 @@ class VanillaMoE(nn.ModuleList):
if model_config.moe_max_num_tokens
is not None else max_num_tokens)
self.enable_alltoall = False
self._weights_created = False
if not model_config.skip_create_weights_in_init:
self.create_weights()
@ -458,7 +456,7 @@ class VanillaMoE(nn.ModuleList):
use_dp_padding: Optional[bool] = None,
):
outputs = inputs
if self.parallel_size > 1 and not self.enable_alltoall:
if self.parallel_size > 1:
if self.use_dp:
outputs = reducescatter(
inputs,

View File

@ -27,7 +27,6 @@ class MoE(nn.Module):
dtype (Optional[torch.dtype]): Data type for the weights.
reduce_results (bool): Whether to reduce the results across devices.
model_config (ModelConfig): Configuration object for the model.
enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
"""
def __init__(
@ -123,3 +122,9 @@ class MoE(nn.Module):
assert self._weights_created
return self.quant_config is not None and self.quant_config.layer_quant_mode.has_nvfp4(
)
@property
def enable_alltoall(self):
""" enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
"""
return False

View File

@ -1003,6 +1003,7 @@ class PyTorchModelEngine(ModelEngine):
checkpoint_dir,
trust_remote_code=True,
enable_min_latency=self.pytorch_backend_config.enable_min_latency,
use_cuda_graph=self.pytorch_backend_config.use_cuda_graph,
spec_config=self.spec_config,
max_num_tokens=max_num_tokens,
moe_max_num_tokens=moe_max_num_tokens,

View File

@ -75,6 +75,9 @@ class Logger(metaclass=Singleton):
self._polygraphy_logger.module_severity = severity_map[
min_severity][2]
# For log_once
self._appeared_keys = set()
if invalid_severity:
self.warning(
f"Requested log level {environ_severity} is invalid. Using '{self.DEFAULT_LEVEL}' instead"
@ -109,23 +112,44 @@ class Logger(metaclass=Singleton):
parts.extend(map(str, msg))
self._func_wrapper(severity)(" ".join(parts))
def log_once(self, severity, *msg, key):
if key not in self._appeared_keys:
self._appeared_keys.add(key)
self.log(severity, *msg)
def critical(self, *msg):
self.log(self.INTERNAL_ERROR, *msg)
def critical_once(self, *msg, key):
self.log_once(self.INTERNAL_ERROR, *msg, key=key)
fatal = critical
fatal_once = critical_once
def error(self, *msg):
self.log(self.ERROR, *msg)
def error_once(self, *msg, key):
self.log_once(self.ERROR, *msg, key=key)
def warning(self, *msg):
self.log(self.WARNING, *msg)
def warning_once(self, *msg, key):
self.log_once(self.WARNING, *msg, key=key)
def info(self, *msg):
self.log(self.INFO, *msg)
def info_once(self, *msg, key):
self.log_once(self.INFO, *msg, key=key)
def debug(self, *msg):
self.log(self.VERBOSE, *msg)
def debug_once(self, *msg, key):
self.log_once(self.VERBOSE, *msg, key=key)
@property
def level(self) -> str:
return self._min_severity

View File

@ -54,6 +54,8 @@ l0_dgx_h100:
auto_trigger: deepseek
tests:
- unittest/_torch/multi_gpu_modeling -k "deepseek"
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEP]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEPLowLatency]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]

View File

@ -2,6 +2,7 @@ import pickle
import sys
from itertools import product
from typing import Dict, List, Optional
from unittest import mock
import cloudpickle
import pytest
@ -19,6 +20,8 @@ from tensorrt_llm._torch.modules.fused_moe import (BaseMoeRoutingMethod,
DefaultMoeRoutingMethod,
RenormalizeMoeRoutingMethod,
VanillaMoE)
from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import \
AlltoallMethodType
from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
from tensorrt_llm._utils import mpi_rank
from tensorrt_llm.mapping import Mapping
@ -123,6 +126,111 @@ def test_fused_moe_multi_gpu(moe_cls, ep_size):
assert r is None
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="needs 4 GPUs to run this test")
@pytest.mark.parametrize("alltoall_method_type", [
AlltoallMethodType.MNNVL, AlltoallMethodType.DeepEP,
AlltoallMethodType.DeepEPLowLatency
],
ids=lambda s: s.name)
def test_fused_moe_alltoall(alltoall_method_type):
world_size = 4
dtype = torch.bfloat16
HIDDEN_SIZE = 2560
INTERMEDIATE_SIZE = 1536
NUM_EXPERTS = 72
TOP_K = 6
MAX_NUM_TOKENS = 2048
def per_rank_test_fused_moe_alltoall(job_id):
routing_method = DefaultMoeRoutingMethod(top_k=TOP_K)
mapping = Mapping(world_size=world_size,
rank=mpi_rank(),
tp_size=world_size,
moe_ep_size=world_size,
moe_tp_size=1,
enable_attention_dp=True)
torch.cuda.set_device(mapping.rank)
torch.manual_seed(mapping.rank)
weights = {}
for expert_id in range(NUM_EXPERTS):
w1_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
dtype=dtype)
w2_weight = torch.empty((HIDDEN_SIZE, INTERMEDIATE_SIZE),
dtype=dtype)
w3_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
dtype=dtype)
torch.nn.init.xavier_uniform_(w1_weight)
torch.nn.init.xavier_uniform_(w2_weight)
torch.nn.init.xavier_uniform_(w3_weight)
weights[f"{expert_id}.w1.weight"] = w1_weight
weights[f"{expert_id}.w2.weight"] = w2_weight
weights[f"{expert_id}.w3.weight"] = w3_weight
with mock.patch.object(CutlassFusedMoE,
"select_alltoall_method_type",
return_value=alltoall_method_type):
alltoall_model = CutlassFusedMoE(
num_experts=NUM_EXPERTS,
routing_method=routing_method,
hidden_size=HIDDEN_SIZE,
intermediate_size=INTERMEDIATE_SIZE,
dtype=dtype,
reduce_results=True,
model_config=ModelConfig(mapping=mapping,
max_num_tokens=MAX_NUM_TOKENS),
)
alltoall_model.to("cuda")
alltoall_model.load_weights([weights])
with mock.patch.object(CutlassFusedMoE,
"select_alltoall_method_type",
return_value=AlltoallMethodType.NotEnabled):
ref_model = CutlassFusedMoE(
num_experts=NUM_EXPERTS,
routing_method=routing_method,
hidden_size=HIDDEN_SIZE,
intermediate_size=INTERMEDIATE_SIZE,
dtype=dtype,
reduce_results=True,
model_config=ModelConfig(mapping=mapping,
max_num_tokens=MAX_NUM_TOKENS),
)
ref_model.to("cuda")
ref_model.load_weights([weights])
# Evaluate the outputs on a variant sequence length to verify the robustness of alltoall methods
m = MAX_NUM_TOKENS
while m >= 1:
x = torch.randn((m, HIDDEN_SIZE), dtype=dtype).cuda()
router_logits = torch.randn((m, NUM_EXPERTS), dtype=dtype).cuda()
all_rank_num_tokens = [m] * mapping.world_size
with torch.inference_mode():
output = alltoall_model.forward(
x,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=False)
ref_output = ref_model.forward(
x,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=False)
# Evaluate outputs
torch.testing.assert_close(output,
ref_output,
rtol=0.05,
atol=0.003)
m //= 2
with MPIPoolExecutor(max_workers=world_size) as executor:
results = executor.map(per_rank_test_fused_moe_alltoall,
range(world_size))
for r in results:
assert r is None
@skip_pre_hopper
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_moe_fp8(dtype):