feat: large-scale EP(part 7: DeepEP integration) (#4792)

Signed-off-by: Tailing Yuan <yuantailing@gmail.com>
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Authored by Tailing Yuan on 2025-06-14 19:12:38 +08:00; committed by GitHub
parent 443b2eb51f
commit 0b60da2c45
18 changed files with 610 additions and 88 deletions

View File

@@ -1,7 +1,7 @@
 version: "3.9"
 services:
   tensorrt_llm-dev:
-    image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420
+    image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506111045-4792
     network_mode: host
     ipc: host

View File

@@ -72,6 +72,10 @@ RUN bash ./install_pytorch.sh $TORCH_INSTALL_TYPE && rm install_pytorch.sh
 RUN pip3 uninstall -y opencv && rm -rf /usr/local/lib/python3*/dist-packages/cv2/
 RUN pip3 install opencv-python-headless --force-reinstall --no-deps --no-cache-dir
+# Install DeepEP
+COPY docker/common/install_deep_ep.sh install_deep_ep.sh
+RUN bash ./install_deep_ep.sh && rm install_deep_ep.sh
 # WARs against security issues inherited from pytorch:25.04
 # * https://github.com/advisories/GHSA-vqfr-h8mv-ghfj
 # * https://github.com/advisories/GHSA-7cx3-6m66-7c5m

View File

@@ -0,0 +1,47 @@
#!/bin/bash
set -euxo pipefail
GITHUB_URL=${GITHUB_MIRROR:-https://github.com}
DEEP_EP_COMMIT=2b266cf6452134f993ab0fcb3ef2d5de7683c561
if [ "$(. /etc/os-release && echo $ID)" == "rocky" ]; then
echo "Skipping DeepEP installation in the Rocky distribution."
exit 0
fi
libmlx5_dir=$(dirname $(ldconfig -p | grep libmlx5.so.1 | head -n1 | awk '{print $NF}'))
export NVCC_APPEND_FLAGS="--threads 4"
# Custom NVSHMEM
curl -fsSL https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz | tar xz
pushd nvshmem_src
curl -fsSL $GITHUB_URL/deepseek-ai/DeepEP/raw/$DEEP_EP_COMMIT/third-party/nvshmem.patch | patch -p1
sed "s/TRANSPORT_VERSION_MAJOR 3/TRANSPORT_VERSION_MAJOR 103/" -i src/CMakeLists.txt
ln -s libmlx5.so.1 "$libmlx5_dir/libmlx5.so"
cmake -S . -B build \
    -DCMAKE_INSTALL_PREFIX=/opt/custom_nvshmem \
    -DGDRCOPY_HOME=/usr/include \
    -DNVSHMEM_SHMEM_SUPPORT=0 \
    -DNVSHMEM_UCX_SUPPORT=0 \
    -DNVSHMEM_USE_NCCL=0 \
    -DNVSHMEM_MPI_SUPPORT=0 \
    -DNVSHMEM_IBGDA_SUPPORT=1 \
    -DNVSHMEM_PMIX_SUPPORT=0 \
    -DNVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
    -DNVSHMEM_USE_GDRCOPY=1 \
    -DCMAKE_CUDA_ARCHITECTURES="90-real;100-real;120-real" \
    -DNVSHMEM_BUILD_TESTS=0 \
    -DNVSHMEM_BUILD_EXAMPLES=0
cmake --build build -j`nproc`
make -C build install
popd
# DeepEP
curl -fsSL $GITHUB_URL/deepseek-ai/DeepEP/archive/$DEEP_EP_COMMIT.tar.gz | tar xz
TORCH_CUDA_ARCH_LIST="9.0;10.0;12.0" NVSHMEM_DIR=/opt/custom_nvshmem pip install -v --no-cache-dir ./DeepEP-$DEEP_EP_COMMIT
# Clean up
rm -r nvshmem_src
rm "$libmlx5_dir/libmlx5.so"
rm -r DeepEP-$DEEP_EP_COMMIT
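
A quick way to check the resulting image is a plain import of the extension; this is a sketch for verification only, not part of the installer:

# Sanity check for the image built with the script above (illustrative, not part of the commit).
# A missing custom NVSHMEM or a skipped install (e.g. on Rocky) surfaces here as ImportError.
import deep_ep
print(deep_ep.Buffer)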

View File

@@ -28,10 +28,10 @@ UPLOAD_PATH = env.uploadPath ? env.uploadPath : "sw-tensorrt-generic/llm-artifac
 // Container configuration
 // available tags can be found in: https://urm.nvidia.com/artifactory/sw-tensorrt-docker/tensorrt-llm/
 // [base_image_name]-[arch]-[os](-[python_version])-[trt_version]-[torch_install_type]-[stage]-[date]-[mr_id]
-LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420"
-LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420"
-LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202506021004-9420"
-LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202506021004-9420"
+LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506111045-4792"
+LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506111045-4792"
+LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202506111045-4792"
+LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202506111045-4792"
 // TODO: Move common variables to an unified location
 BUILD_CORES_REQUEST = "8"

View File

@@ -1,7 +1,7 @@
 import java.lang.InterruptedException
 
-DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420"
+DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506111045-4792"
 
 def createKubernetesPodConfig(image, arch = "amd64")
 {

View File

@@ -84,6 +84,9 @@ class ModelConfig(Generic[TConfig]):
     # If true, enable min-latency mode. Currently only used for Llama4.
     enable_min_latency: bool = False
 
+    # Allow models to select op according to whether CUDA Graphs are used.
+    use_cuda_graph: bool = False
+
     extra_attrs: Dict = field(default_factory=dict, repr=False, init=False)
     _frozen: bool = field(default=False, init=False, repr=False)
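
The new field lets modules pick an implementation that is safe to capture into CUDA Graphs. A minimal sketch of that pattern with a simplified stand-in config (only the use_cuda_graph field mirrors the real dataclass; the helper name is illustrative):

from dataclasses import dataclass

@dataclass
class ToyModelConfig:  # simplified stand-in for ModelConfig
    use_cuda_graph: bool = False

def pick_deep_ep_variant(cfg: ToyModelConfig) -> str:
    # Only the DeepEP low-latency kernels can be captured into a CUDA graph,
    # so graph mode forces that variant (mirrors the selection logic below).
    return "DeepEPLowLatency" if cfg.use_cuda_graph else "DeepEP"

assert pick_deep_ep_variant(ToyModelConfig(use_cuda_graph=True)) == "DeepEPLowLatency"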

View File

@@ -38,7 +38,6 @@ from torch import nn
 from tqdm import tqdm
 from transformers import PretrainedConfig
 
-from tensorrt_llm._mnnvl_utils import MnnvlMemory
 from tensorrt_llm.functional import PositionEmbeddingType
 from tensorrt_llm.llmapi.utils import enable_llm_debug
 from tensorrt_llm.mapping import Mapping
@@ -413,10 +412,6 @@ class Deepseekv3MoE(nn.Module):
         config = model_config.pretrained_config
         self.top_k = top_k
         self.use_dp = model_config.mapping.enable_attention_dp
-        self.enable_alltoall = Deepseekv3MoE.should_enable_alltoall(
-            model_config, top_k)
-        if self.enable_alltoall:
-            MnnvlMemory.initialize()
         self.gate = DeepseekV3Gate(
             hidden_size,
             num_experts,
@@ -439,7 +434,6 @@ class Deepseekv3MoE(nn.Module):
             model_config=model_config,
             override_quant_config=override_quant_config,
             aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap],
-            enable_alltoall=self.enable_alltoall,
             layer_idx=layer_idx)
 
         self.mapping = model_config.mapping
@@ -505,25 +499,6 @@ class Deepseekv3MoE(nn.Module):
         return shared_tp_size, shared_output_scale
 
-    @staticmethod
-    def should_enable_alltoall(model_config: ModelConfig, top_k: int) -> bool:
-        if not model_config.mapping.enable_attention_dp:
-            return False
-        if model_config.mapping.tp_size == 1:
-            return False
-        if not MnnvlMemory.supports_mnnvl():
-            return False
-        if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
-            return False
-        if model_config.mapping.moe_ep_size <= top_k:
-            return False
-        return True
-
     def compute_routed_output(self, hidden_states, hidden_states_fp4,
                               all_rank_num_tokens, do_finalize):
         # max-throughput
@@ -531,7 +506,7 @@ class Deepseekv3MoE(nn.Module):
         if self.use_dp and self.mapping.tp_size > 1:
             # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
             # to reduce allreduce BW
-            if disable_fp4_allgather() and not self.enable_alltoall:
+            if disable_fp4_allgather() and not self.experts.enable_alltoall:
                 hidden_states = allgather(hidden_states,
                                           self.mapping,
                                           dim=0,

View File

@@ -6,8 +6,6 @@ from torch import nn
 from tqdm import tqdm
 from transformers import Qwen3MoeConfig
 
-from tensorrt_llm._mnnvl_utils import MnnvlMemory
 from ..attention_backend import AttentionMetadata
 from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
                            allgather)
@@ -91,10 +89,6 @@ class Qwen3MoE(nn.Module):
         self.mapping = model_config.mapping
         self.allreduce = AllReduce(mapping=model_config.mapping,
                                    strategy=model_config.allreduce_strategy)
-        self.enable_alltoall = Qwen3MoE.should_enable_alltoall(
-            model_config, self.top_k)
-        if self.enable_alltoall:
-            MnnvlMemory.initialize()
 
         self.gate = Qwen3Gate(
             hidden_size=self.hidden_dim,
@@ -117,25 +111,6 @@ class Qwen3MoE(nn.Module):
             model_config=model_config,
         )
 
-    @staticmethod
-    def should_enable_alltoall(model_config: ModelConfig, top_k: int) -> bool:
-        if not model_config.mapping.enable_attention_dp:
-            return False
-        if model_config.mapping.tp_size == 1:
-            return False
-        if not MnnvlMemory.supports_mnnvl():
-            return False
-        if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
-            return False
-        if model_config.mapping.moe_ep_size <= top_k:
-            return False
-        return True
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -151,7 +126,7 @@ class Qwen3MoE(nn.Module):
         if self.enable_attention_dp and self.mapping.tp_size > 1:
             # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
             # to reduce allreduce BW
-            if disable_fp4_allgather() and not self.enable_alltoall:
+            if disable_fp4_allgather() and not self.experts.enable_alltoall:
                 hidden_states = allgather(hidden_states,
                                           self.mapping,
                                           dim=0,

View File

@@ -52,7 +52,6 @@ def create_moe(
     aux_stream: Optional[torch.cuda.Stream] = None,
     weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA,
     apply_router_weight_on_input: bool = False,
-    enable_alltoall: bool = False,
     layer_idx: Optional[int] = None,
 ) -> MoE:
     moe_cls = get_moe_cls(model_config, override_quant_config)
@@ -63,7 +62,6 @@ def create_moe(
     if moe_cls == TRTLLMGenFusedMoE:
         assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in TRTLLMGenFusedMoE."
-        assert not enable_alltoall, "enable_alltoall is not supported in TRTLLMGenFusedMoE."
 
         return moe_cls(
             routing_method=routing_method,
@@ -88,12 +86,10 @@ def create_moe(
             aux_stream=aux_stream,
             weight_loading_mode=weight_loading_mode,
             apply_router_weight_on_input=apply_router_weight_on_input,
-            enable_alltoall=enable_alltoall,
             layer_idx=layer_idx,
         )
     elif moe_cls == VanillaMoE:
         assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in VanillaMoE."
-        assert not enable_alltoall, "enable_alltoall is not supported in VanillaMoE."
 
         return moe_cls(
             routing_method=routing_method,

View File

@@ -0,0 +1,209 @@
# Adapted from
# https://github.com/deepseek-ai/DeepEP/blob/aae9fa9a6dd0fec2a723fbb85ec4b22460fab670/README.md
import weakref
from typing import List, Tuple, Union

import torch

from tensorrt_llm._utils import local_mpi_size, mpi_comm
from tensorrt_llm.mapping import Mapping

try:
    from deep_ep import Buffer
    deep_ep_installed = True
except ModuleNotFoundError:
    deep_ep_installed = False


class VariableLengthBuffer:
    """ A wrapper of deep_ep.Buffer that accepts future size change
    """

    def __init__(self, mapping: Mapping):
        self.comm = mpi_comm().Split(mapping.pp_rank, mapping.moe_ep_rank)
        self.buffer = None

    def __del__(self):
        self.comm.Free()

    def reserve(self, hidden_size: int, hidden_dtype: torch.dtype):
        """ Ensure the buffer capacity is large enough.
        Reserve is a collective operation that requires all EP ranks to be sync
        """
        # NOTES: you may also replace `get_*_config` with your auto-tuned results via all the tests
        num_nvl_bytes, num_rdma_bytes = 0, 0
        hidden_bytes = hidden_size * max(hidden_dtype.itemsize,
                                         torch.bfloat16.itemsize)
        world_size = self.comm.Get_size()
        for config in (Buffer.get_dispatch_config(world_size),
                       Buffer.get_combine_config(world_size)):
            num_nvl_bytes = max(
                config.get_nvl_buffer_size_hint(hidden_bytes, world_size),
                num_nvl_bytes)
            num_rdma_bytes = max(
                config.get_rdma_buffer_size_hint(hidden_bytes, world_size),
                num_rdma_bytes)

        # Allocate a buffer if not existed or not enough buffer size
        if self.buffer is None or self.buffer.num_nvl_bytes < num_nvl_bytes or self.buffer.num_rdma_bytes < num_rdma_bytes:
            if self.buffer is not None:
                num_nvl_bytes = max(num_nvl_bytes, self.buffer.num_nvl_bytes)
                num_rdma_bytes = max(num_rdma_bytes, self.buffer.num_rdma_bytes)
            del self.buffer  # Destruct before Construct
            self.buffer = Buffer(None,
                                 num_nvl_bytes,
                                 num_rdma_bytes,
                                 num_nvl_peers=local_mpi_size(),
                                 comm=self.comm)

    def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                 topk_idx: torch.Tensor, topk_weights: torch.Tensor,
                 num_experts: int) -> \
            Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple]:
        # NOTES: an optional `previous_event` means a CUDA event captured that you want to make it as a dependency
        # of the dispatch kernel, it may be useful with communication-computation overlap. For more information, please
        # refer to the docs of `Buffer.dispatch`

        # Calculate layout before actual dispatch
        num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = \
            self.buffer.get_dispatch_layout(topk_idx, num_experts)
        assert event.event is None

        # Do MoE dispatch
        # NOTES: the CPU will wait for GPU's signal to arrive, so this is not compatible with CUDA graph
        # For more advanced usages, please refer to the docs of the `dispatch` function
        recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
            self.buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
                                 num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
                                 is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert)
        assert event.event is None

        # For event management, please refer to the docs of the `EventOverlap` class
        return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle

    def combine(self, x: torch.Tensor, handle: Tuple) -> torch.Tensor:
        # Do MoE combine
        # For more advanced usages, please refer to the docs of the `combine` function
        combined_x, _, event = self.buffer.combine(x, handle)
        assert event.event is None

        # For event management, please refer to the docs of the `EventOverlap` class
        return combined_x


class VariableLengthLowLatencyBuffer:
    """ A wrapper of deep_ep.Buffer that accepts future size change
    """

    def __init__(self, mapping: Mapping):
        self.comm = mpi_comm().Split(mapping.pp_rank, mapping.moe_ep_rank)
        self.buffer = None
        self.num_max_dispatch_tokens_per_rank = None

    def __del__(self):
        self.comm.Free()

    def reserve(self, num_max_dispatch_tokens_per_rank: int, hidden_size: int,
                num_experts: int):
        """ Ensure the buffer capacity is large enough.
        Reserve is a collective operation that requires all EP ranks to be sync
        """
        # NOTES: the low-latency mode will consume much more space than the normal mode
        # So we recommend that `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
        world_size = self.comm.Get_size()
        num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
            num_max_dispatch_tokens_per_rank, hidden_size, world_size,
            num_experts)

        # Allocate a buffer if not existed or not enough buffer size
        if self.buffer is None or self.buffer.num_rdma_bytes < num_rdma_bytes:
            # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
            assert num_experts % world_size == 0
            del self.buffer  # Destruct before Construct
            self.buffer = Buffer(None,
                                 0,
                                 num_rdma_bytes,
                                 low_latency_mode=True,
                                 num_qps_per_rank=num_experts // world_size,
                                 comm=self.comm)

    def low_latency_dispatch(self, hidden_states: torch.Tensor,
                             topk_idx: torch.Tensor,
                             num_max_dispatch_tokens_per_rank: int,
                             num_experts: int):
        if self.num_max_dispatch_tokens_per_rank is None:
            self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
        if num_max_dispatch_tokens_per_rank != self.num_max_dispatch_tokens_per_rank:
            raise NotImplementedError(
                "There are issues if `low_latency_dispatch` calls use different `num_max_dispatch_tokens_per_rank` values"
            )

        # Do MoE dispatch, compatible with CUDA graph (but you may restore some buffer status once you replay)
        recv_hidden_states, recv_expert_count, handle, event, hook = \
            self.buffer.low_latency_dispatch(hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts, use_fp8=False)
        assert event.event is None
        assert hook is None

        # NOTES: the actual tensor will not be received only if you call `hook()`,
        # it is useful for double-batch overlapping, but **without any SM occupation**
        # If you don't want to overlap, please set `return_recv_hook=False`
        # Later, you can use our GEMM library to do the computation with this specific format
        return recv_hidden_states, recv_expert_count, handle

    def low_latency_combine(self, hidden_states: torch.Tensor,
                            topk_idx: torch.Tensor, topk_weights: torch.Tensor,
                            handle: Tuple):
        # Do MoE combine, compatible with CUDA graph (but you may restore some buffer status once you replay)
        combined_hidden_states, event, hook = \
            self.buffer.low_latency_combine(hidden_states, topk_idx, topk_weights, handle)
        assert event.event is None
        assert hook is None

        # NOTES: the same behavior as described in the dispatch kernel
        return combined_hidden_states


class BufferPool:
    """ A pool that allocates buffers on demand.
    Although the pool interface allows creating multiple buffers, the
    current version of DeepEP supports at most one `deep_ep.Buffer` at a
    time. Please ensure that all references to `VariableLengthBuffer` are
    released before getting another buffer.
    """

    def __init__(self):
        self.buffers: Map[Mapping,
                          weakref.ReferenceType[VariableLengthBuffer]] = {}
        self.low_latency_buffers: Map[
            Mapping,
            weakref.ReferenceType[VariableLengthLowLatencyBuffer]] = {}

    def get_buffer(self, mapping: Mapping) -> VariableLengthBuffer:
        """ Get_buffer is a collective operation that requires all ranks to be sync
        """
        if mapping in self.buffers and self.buffers[mapping]() is not None:
            buffer = self.buffers[mapping]()
        else:
            buffer = VariableLengthBuffer(mapping)
            self.buffers[mapping] = weakref.ref(buffer)
        return buffer

    def get_low_latency_buffer(
            self, mapping: Mapping) -> VariableLengthLowLatencyBuffer:
        """ Get_low_latency_buffer is a collective operation that requires all ranks to be sync
        """
        if mapping in self.low_latency_buffers and self.low_latency_buffers[
                mapping]() is not None:
            buffer = self.low_latency_buffers[mapping]()
        else:
            buffer = VariableLengthLowLatencyBuffer(mapping)
            self.low_latency_buffers[mapping] = weakref.ref(buffer)
        return buffer


# The default pool
# You may create own pools for better resource management.
buffer_pool = BufferPool()
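
A rough usage sketch of the normal-mode buffer follows. It assumes an MPI launch where every EP rank runs the same calls (e.g. 4 processes, one GPU each), and it assumes the module is importable as tensorrt_llm._torch.modules.fused_moe.deep_ep_utils, since the file path is not shown on this page; shapes and sizes are illustrative:

import torch
from tensorrt_llm._utils import mpi_rank
from tensorrt_llm.mapping import Mapping
from tensorrt_llm._torch.modules.fused_moe.deep_ep_utils import buffer_pool  # assumed path

mapping = Mapping(world_size=4, rank=mpi_rank(), tp_size=4,
                  moe_ep_size=4, moe_tp_size=1, enable_attention_dp=True)
torch.cuda.set_device(mapping.rank)

buf = buffer_pool.get_buffer(mapping)                       # collective across EP ranks
buf.reserve(hidden_size=2560, hidden_dtype=torch.bfloat16)  # also collective

x = torch.randn(128, 2560, dtype=torch.bfloat16, device="cuda")
topk_idx = torch.randint(0, 72, (128, 6), dtype=torch.int64, device="cuda")
topk_weights = torch.rand(128, 6, dtype=torch.float32, device="cuda")

# Route each token to the ranks that own its selected experts.
recv_x, recv_idx, recv_w, counts, handle = buf.dispatch(x, topk_idx, topk_weights, num_experts=72)
expert_out = recv_x  # placeholder for the local expert computation
combined = buf.combine(expert_out, handle)  # send results back to the owning ranks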

View File

@@ -1,16 +1,19 @@
 import os
+from enum import IntEnum
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 
-from tensorrt_llm._mnnvl_utils import MnnvlMoe, MoEAlltoallInfo
+from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
 from tensorrt_llm._utils import logger
+from tensorrt_llm.mapping import Mapping
 
 from ...distributed import allgather, reducescatter
 from ...expert_statistic import ExpertStatistic
 from ...model_config import ModelConfig
 from ...utils import (EventType, Fp4QuantizedTensor, disable_fp4_allgather,
                       reswizzle_sf, swizzle_sf, unswizzle_sf)
+from .deep_ep_utils import buffer_pool, deep_ep_installed
 from .interface import MoE
 from .moe_load_balancer import get_moe_load_balancer
 from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
@@ -20,6 +23,18 @@ from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
 from .routing import BaseMoeRoutingMethod
 
+# The type of alltoall method
+class AlltoallMethodType(IntEnum):
+    # Not available
+    NotEnabled = 0
+    # MNNVL
+    MNNVL = 1
+    # DeepEP intranode or internode: no CUDA Graphs support, IBGDA is required by internode
+    DeepEP = 2
+    # DeepEP low latency: CUDA Graphs are supported, IBGDA is required
+    DeepEPLowLatency = 3
+
 class CutlassFusedMoE(MoE):
     """
     Fused Mixture of Experts (MoE) Layer with performance tuning.
@@ -33,7 +48,6 @@ class CutlassFusedMoE(MoE):
         dtype (Optional[torch.dtype]): Data type for the weights.
         reduce_results (bool): Whether to reduce the results across devices.
         model_config (ModelConfig): Configuration object for the model.
-        enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
 
     MoE torch custom op:
       In min-latency mode:
@@ -82,7 +96,6 @@ class CutlassFusedMoE(MoE):
         weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
         VANILLA,
         apply_router_weight_on_input: bool = False,
-        enable_alltoall: bool = False,
         layer_idx: Optional[int] = None,
     ):
@@ -176,7 +189,12 @@ class CutlassFusedMoE(MoE):
         self.has_been_profiled = False
         self.has_been_profiled_min_latency = False
 
-        self.enable_alltoall = enable_alltoall
+        self.alltoall_method_type = self.select_alltoall_method_type(
+            model_config.mapping, routing_method.experts_per_token, dtype,
+            model_config.use_cuda_graph)
+        logger.info_once(
+            f"CutlassFusedMoE selects alltoall_method_type {self.alltoall_method_type!r}",
+            key="alltoall_method_type")
         self.use_postquant_alltoall = False
         if self.enable_alltoall:
             assert self.use_dp and self.parallel_size > 1,\
@@ -185,8 +203,25 @@ class CutlassFusedMoE(MoE):
             self.use_postquant_alltoall = (os.environ.get(
                 "TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1")
                                            == "1") and qm.has_nvfp4()
-        self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
-            model_config.mapping) if enable_alltoall else None
+            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+                MnnvlMemory.initialize()
+                self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
+                    model_config.mapping)
+            elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
+                self.deep_ep_buffer = buffer_pool.get_buffer(
+                    model_config.mapping)
+                self.deep_ep_buffer.reserve(hidden_size, dtype)
+            elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+                self.deep_ep_max_num_tokens = min(model_config.max_num_tokens,
+                                                  self.moe_max_num_tokens)
+                self.deep_ep_buffer = buffer_pool.get_low_latency_buffer(
+                    model_config.mapping)
+                self.deep_ep_buffer.reserve(self.deep_ep_max_num_tokens,
+                                            hidden_size, self.num_slots)
+            else:
+                raise NotImplementedError(
+                    f"Not available alltoall method type: {alltoall_method_type!r}"
+                )
 
         # If True, the router weight will be multiplied on the input rather than at the end of FC2
         self.apply_router_weight_on_input = apply_router_weight_on_input
@@ -215,12 +250,48 @@ class CutlassFusedMoE(MoE):
                 f"unsupported quantization mode: {self.quant_config.quant_mode}"
             )
 
+    @staticmethod
+    def select_alltoall_method_type(mapping: Mapping, top_k: int,
+                                    dtype: torch.dtype,
+                                    use_cuda_graph: bool) -> AlltoallMethodType:
+        if not mapping.enable_attention_dp:
+            return AlltoallMethodType.NotEnabled
+        if mapping.tp_size == 1:
+            return AlltoallMethodType.NotEnabled
+        if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
+            return AlltoallMethodType.NotEnabled
+        if mapping.moe_ep_size <= top_k:
+            return AlltoallMethodType.NotEnabled
+        if MnnvlMemory.supports_mnnvl():
+            return AlltoallMethodType.MNNVL
+        if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "0") == "1":
+            if deep_ep_installed and dtype == torch.bfloat16:
+                if use_cuda_graph:
+                    # Here we can only choose DeepEPLowLatency since only this method supports CUDA Graphs.
+                    return AlltoallMethodType.DeepEPLowLatency
+                else:
+                    # Here we can choose DeepEP or DeepEPLowLatency if both are available. Now DeepEP is faster.
+                    return AlltoallMethodType.DeepEP
+        return AlltoallMethodType.NotEnabled
+
     @property
     def has_w4afp8(self):
         assert self._weights_created
         return self.quant_config and self.quant_config.quant_mode.is_int4_weight_only_per_group(
         )
 
+    @property
+    def enable_alltoall(self):
+        """ enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
+        """
+        return self.alltoall_method_type != AlltoallMethodType.NotEnabled
+
     def _get_quant_method(self):
         if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
                 exclude_kv_cache=True):
@@ -311,10 +382,6 @@ class CutlassFusedMoE(MoE):
             # TODO: remove this once we have correct fusedmoe kernel ready
             token_final_scales = None
 
-        token_count = x.shape[0]
-        alltoall_info = None
-
         if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
         ) and is_first_call:
             self.layer_load_balancer.maybe_cudagraph_done_wait()
@@ -334,13 +401,58 @@ class CutlassFusedMoE(MoE):
         ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
         token_selected_experts_for_statistic = token_selected_experts if need_statistic else None
 
         if self.enable_alltoall:
-            x, token_selected_slots, token_final_scales, token_selected_experts_for_statistic, alltoall_info = \
-                self.alltoall_prepare_maybe_dispatch(all_rank_num_tokens,
-                                                     x,
-                                                     token_selected_slots,
-                                                     token_final_scales,
-                                                     token_selected_experts_for_statistic)
+            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+                token_count = x.shape[0]
+                alltoall_info = None
+                x, token_selected_slots, token_final_scales, token_selected_experts_for_statistic, alltoall_info = \
+                    self.alltoall_prepare_maybe_dispatch(all_rank_num_tokens,
+                                                         x,
+                                                         token_selected_slots,
+                                                         token_final_scales,
+                                                         token_selected_experts_for_statistic)
+            elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
+                if not self.use_postquant_alltoall:
+                    x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
+                        self.deep_ep_buffer.dispatch(x, token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
+            elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+                if not self.use_postquant_alltoall:
+                    deep_ep_topk_idx = token_selected_slots.to(torch.int64)
+                    deep_ep_topk_weights = token_final_scales
+                    x, recv_expert_count, deep_ep_handle = \
+                        self.deep_ep_buffer.low_latency_dispatch(x, deep_ep_topk_idx, self.deep_ep_max_num_tokens, self.num_slots)
+                    # x shape: [#local experts, #max recv tokens, hidden_size]
+                    # recv_expert_count shape: [#local experts]
+                    # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
+                    # TODO: remove the adapter by changing `torch.ops.trtllm.fused_moe` API
+                    mask = torch.arange(
+                        x.shape[1], dtype=torch.int32, device=x.device).expand(
+                            x.shape[0],
+                            x.shape[1]) < recv_expert_count.unsqueeze(1)
+                    token_selected_slots = torch.full(
+                        (x.shape[0], x.shape[1], self.routing_method.top_k),
+                        self.num_slots,
+                        dtype=torch.int32,
+                        device=x.device)
+                    token_selected_slots[:, :, 0] = torch.where(
+                        mask,
+                        torch.arange(
+                            x.shape[0] * self.mapping.moe_ep_rank,
+                            x.shape[0] * (self.mapping.moe_ep_rank + 1),
+                            dtype=torch.int32,
+                            device=x.device).unsqueeze(1), self.num_slots)
+                    x = x.view(x.shape[0] * x.shape[1], x.shape[2])
+                    token_selected_slots = token_selected_slots.view(
+                        x.shape[0], self.routing_method.top_k)
+                    token_final_scales = torch.ones_like(
+                        token_selected_slots, dtype=token_final_scales.dtype)
+            else:
+                raise NotImplementedError(
+                    f"Not available alltoall method type: {alltoall_method_type!r}"
+                )
 
         x_sf = None
         if self.has_any_quant:
             if self.has_fp8_qdq:
@@ -414,8 +526,56 @@ class CutlassFusedMoE(MoE):
             quant_scales = self.quant_scales
 
         if self.use_postquant_alltoall:
-            x, x_sf = self.alltoall_postquant_dispatch(x, x_sf, x_row, x_col,
-                                                       alltoall_info)
+            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+                x, x_sf = self.alltoall_postquant_dispatch(
+                    x, x_sf, x_row, x_col, alltoall_info)
+            elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
+                if x_sf is not None:
+                    if self.has_nvfp4:
+                        x_sf = unswizzle_sf(x_sf, x_row, x_col,
+                                            self.scaling_vector_size)
+                    # Adapter between `x_sf` and DeepEP
+                    # TODO: remove the adapter by adding dtype support to DeepEP
+                    x_sf_dtype = x_sf.dtype
+                    x_sf = x_sf.view(torch.float32)
+                (x, x_sf), recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
+                    self.deep_ep_buffer.dispatch((x, x_sf), token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
+                if x_sf is not None:
+                    x_sf = x_sf.view(x_sf_dtype)
+                    if self.has_nvfp4:
+                        x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2,
+                                          self.scaling_vector_size)
+            elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+                raise NotImplementedError(
+                    "Not implemented postquant for DeepEPLowLatency, please set TRTLLM_MOE_POST_QUANT_ALLTOALLV=0"
+                )
+            else:
+                raise NotImplementedError(
+                    f"Not available alltoall method type: {alltoall_method_type!r}"
+                )
+
+        if self.enable_alltoall:
+            # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
+            # TODO: remove the adapter by changing APIs
+            if self.alltoall_method_type == AlltoallMethodType.DeepEP:
+                token_selected_slots = recv_topk_idx.to(torch.int32)
+                mask = token_selected_slots == -1
+                token_selected_slots += self.expert_size_per_partition * self.mapping.moe_ep_rank
+                token_selected_slots[mask] = self.num_slots
+                num_recv_token_is_zero = x.shape[0] == 0
+                if x.shape[0] == 0:
+                    x = torch.zeros((1, x.shape[1]),
+                                    dtype=x.dtype,
+                                    device=x.device)
+                    token_selected_slots = torch.full(
+                        (1, token_selected_slots.shape[1]),
+                        self.num_slots,
+                        dtype=token_selected_slots.dtype,
+                        device=token_selected_slots.device)
+                    token_final_scales = torch.ones(
+                        (1, token_final_scales.shape[1]),
+                        dtype=token_final_scales.dtype,
+                        device=token_final_scales.device)
 
         final_hidden_states = torch.ops.trtllm.fused_moe(
             x,
@@ -452,9 +612,25 @@ class CutlassFusedMoE(MoE):
             final_hidden_states = final_hidden_states[0]
 
         if self.enable_alltoall:
-            final_hidden_states = self.alltoall_combine(final_hidden_states,
-                                                        alltoall_info,
-                                                        token_count)
+            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+                final_hidden_states = self.alltoall_combine(
+                    final_hidden_states, alltoall_info, token_count)
+            elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
+                if num_recv_token_is_zero:
+                    final_hidden_states = final_hidden_states[:0]
+                final_hidden_states = self.deep_ep_buffer.combine(
+                    final_hidden_states, deep_ep_handle)
+            elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+                final_hidden_states = self.deep_ep_buffer.low_latency_combine(
+                    final_hidden_states.view(
+                        self.expert_size_per_partition,
+                        self.deep_ep_max_num_tokens * self.mapping.moe_ep_size,
+                        final_hidden_states.shape[1]), deep_ep_topk_idx,
+                    deep_ep_topk_weights, deep_ep_handle)
+            else:
+                raise NotImplementedError(
+                    f"Not available alltoall method type: {alltoall_method_type!r}"
+                )
 
         if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
         ) and is_last_call:
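
The precedence implemented by select_alltoall_method_type above is: attention DP with a real EP group is a prerequisite, MNNVL wins whenever it is supported, and DeepEP is only tried when opted in via TRTLLM_CAN_USE_DEEP_EP=1 with bf16 activations, where CUDA Graphs force the low-latency variant. A self-contained restatement of that decision table (plain arguments replace the real Mapping and runtime probes; sketch only):

import os
from enum import IntEnum

class AlltoallMethodType(IntEnum):
    NotEnabled = 0
    MNNVL = 1
    DeepEP = 2
    DeepEPLowLatency = 3

def select(attention_dp: bool, tp_size: int, moe_ep_size: int, top_k: int,
           mnnvl_supported: bool, deep_ep_available: bool, bf16: bool,
           use_cuda_graph: bool) -> AlltoallMethodType:
    if not attention_dp or tp_size == 1:
        return AlltoallMethodType.NotEnabled
    if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
        return AlltoallMethodType.NotEnabled
    if moe_ep_size <= top_k:
        return AlltoallMethodType.NotEnabled
    if mnnvl_supported:
        return AlltoallMethodType.MNNVL
    if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "0") == "1" and deep_ep_available and bf16:
        # Only the low-latency kernels can be captured into CUDA Graphs.
        return AlltoallMethodType.DeepEPLowLatency if use_cuda_graph else AlltoallMethodType.DeepEP
    return AlltoallMethodType.NotEnabled

os.environ.setdefault("TRTLLM_CAN_USE_DEEP_EP", "1")  # opt in for this example
# 8-way EP, top-k 6, no MNNVL, DeepEP installed, bf16, no CUDA graphs -> DeepEP
print(select(True, 8, 8, 6, False, True, True, False).name)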

View File

@@ -23,7 +23,6 @@ class TRTLLMGenFusedMoE(MoE):
         dtype (Optional[torch.dtype]): Data type for the weights.
         reduce_results (bool): Whether to reduce the results across devices.
         model_config (ModelConfig): Configuration object for the model.
-        enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
 
     MoE torch custom op:
       Only support min-latency mode now (SM100 Blackwell only).

View File

@@ -89,8 +89,6 @@ class VanillaMoE(nn.ModuleList):
                                   if model_config.moe_max_num_tokens
                                   is not None else max_num_tokens)
 
-        self.enable_alltoall = False
-
         self._weights_created = False
         if not model_config.skip_create_weights_in_init:
             self.create_weights()
@@ -458,7 +456,7 @@ class VanillaMoE(nn.ModuleList):
         use_dp_padding: Optional[bool] = None,
     ):
         outputs = inputs
-        if self.parallel_size > 1 and not self.enable_alltoall:
+        if self.parallel_size > 1:
             if self.use_dp:
                 outputs = reducescatter(
                     inputs,

View File

@@ -27,7 +27,6 @@ class MoE(nn.Module):
         dtype (Optional[torch.dtype]): Data type for the weights.
         reduce_results (bool): Whether to reduce the results across devices.
         model_config (ModelConfig): Configuration object for the model.
-        enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
     """
 
     def __init__(
@@ -123,3 +122,9 @@ class MoE(nn.Module):
         assert self._weights_created
         return self.quant_config is not None and self.quant_config.layer_quant_mode.has_nvfp4(
         )
+
+    @property
+    def enable_alltoall(self):
+        """ enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
+        """
+        return False
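
The refactor replaces the old constructor flag with a read-only property: the base class reports False, and CutlassFusedMoE derives the value from the backend it selected, so model code such as the allgather fallback above simply reads experts.enable_alltoall. A toy version of the pattern (names simplified):

from enum import IntEnum

class AlltoallMethodType(IntEnum):
    NotEnabled = 0
    MNNVL = 1
    DeepEP = 2
    DeepEPLowLatency = 3

class BaseMoE:
    @property
    def enable_alltoall(self) -> bool:
        return False  # default: allgather/reducescatter path

class CutlassLikeMoE(BaseMoE):
    def __init__(self, method: AlltoallMethodType):
        self.alltoall_method_type = method

    @property
    def enable_alltoall(self) -> bool:
        return self.alltoall_method_type != AlltoallMethodType.NotEnabled

assert not BaseMoE().enable_alltoall
assert CutlassLikeMoE(AlltoallMethodType.DeepEP).enable_alltoall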

View File

@@ -1003,6 +1003,7 @@ class PyTorchModelEngine(ModelEngine):
             checkpoint_dir,
             trust_remote_code=True,
             enable_min_latency=self.pytorch_backend_config.enable_min_latency,
+            use_cuda_graph=self.pytorch_backend_config.use_cuda_graph,
             spec_config=self.spec_config,
             max_num_tokens=max_num_tokens,
             moe_max_num_tokens=moe_max_num_tokens,

View File

@@ -75,6 +75,9 @@ class Logger(metaclass=Singleton):
             self._polygraphy_logger.module_severity = severity_map[
                 min_severity][2]
 
+        # For log_once
+        self._appeared_keys = set()
+
         if invalid_severity:
             self.warning(
                 f"Requested log level {environ_severity} is invalid. Using '{self.DEFAULT_LEVEL}' instead"
@@ -109,23 +112,44 @@ class Logger(metaclass=Singleton):
             parts.extend(map(str, msg))
         self._func_wrapper(severity)(" ".join(parts))
 
+    def log_once(self, severity, *msg, key):
+        if key not in self._appeared_keys:
+            self._appeared_keys.add(key)
+            self.log(severity, *msg)
+
     def critical(self, *msg):
         self.log(self.INTERNAL_ERROR, *msg)
 
+    def critical_once(self, *msg, key):
+        self.log_once(self.INTERNAL_ERROR, *msg, key=key)
+
     fatal = critical
+    fatal_once = critical_once
 
     def error(self, *msg):
         self.log(self.ERROR, *msg)
 
+    def error_once(self, *msg, key):
+        self.log_once(self.ERROR, *msg, key=key)
+
     def warning(self, *msg):
         self.log(self.WARNING, *msg)
 
+    def warning_once(self, *msg, key):
+        self.log_once(self.WARNING, *msg, key=key)
+
     def info(self, *msg):
         self.log(self.INFO, *msg)
 
+    def info_once(self, *msg, key):
+        self.log_once(self.INFO, *msg, key=key)
+
     def debug(self, *msg):
         self.log(self.VERBOSE, *msg)
 
+    def debug_once(self, *msg, key):
+        self.log_once(self.VERBOSE, *msg, key=key)
+
     @property
     def level(self) -> str:
         return self._min_severity
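
The keyed *_once helpers log a message only the first time a key is seen, which is how CutlassFusedMoE above reports the chosen alltoall backend exactly once per process. A small usage sketch:

from tensorrt_llm._utils import logger

for step in range(3):
    logger.info_once("alltoall backend: DeepEP", key="alltoall_method_type")  # printed once
    logger.warning_once(f"step {step} is slow", key=f"slow_step_{step}")      # distinct keys, printed each time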

View File

@@ -54,6 +54,8 @@ l0_dgx_h100:
     auto_trigger: deepseek
   tests:
   - unittest/_torch/multi_gpu_modeling -k "deepseek"
+  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEP]
+  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEPLowLatency]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
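
The two new entries address the parametrized cases by their ids (the ids=lambda s: s.name in the test below). A quick way to run just these cases locally, assuming a 4-GPU node and that the interpreter is started from the repository's tests directory (a path assumption, mirroring the list entries above):

import pytest

pytest.main([
    "unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEP]",
    "unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEPLowLatency]",
])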

View File

@@ -2,6 +2,7 @@ import pickle
 import sys
 from itertools import product
 from typing import Dict, List, Optional
+from unittest import mock
 
 import cloudpickle
 import pytest
@@ -19,6 +20,8 @@ from tensorrt_llm._torch.modules.fused_moe import (BaseMoeRoutingMethod,
                                                    DefaultMoeRoutingMethod,
                                                    RenormalizeMoeRoutingMethod,
                                                    VanillaMoE)
+from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import \
+    AlltoallMethodType
 from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
 from tensorrt_llm._utils import mpi_rank
 from tensorrt_llm.mapping import Mapping
@@ -123,6 +126,111 @@ def test_fused_moe_multi_gpu(moe_cls, ep_size):
         assert r is None


+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="needs 4 GPUs to run this test")
+@pytest.mark.parametrize("alltoall_method_type", [
+    AlltoallMethodType.MNNVL, AlltoallMethodType.DeepEP,
+    AlltoallMethodType.DeepEPLowLatency
+],
+                         ids=lambda s: s.name)
+def test_fused_moe_alltoall(alltoall_method_type):
+    world_size = 4
+    dtype = torch.bfloat16
+    HIDDEN_SIZE = 2560
+    INTERMEDIATE_SIZE = 1536
+    NUM_EXPERTS = 72
+    TOP_K = 6
+    MAX_NUM_TOKENS = 2048
+
+    def per_rank_test_fused_moe_alltoall(job_id):
+        routing_method = DefaultMoeRoutingMethod(top_k=TOP_K)
+        mapping = Mapping(world_size=world_size,
+                          rank=mpi_rank(),
+                          tp_size=world_size,
+                          moe_ep_size=world_size,
+                          moe_tp_size=1,
+                          enable_attention_dp=True)
+        torch.cuda.set_device(mapping.rank)
+        torch.manual_seed(mapping.rank)
+
+        weights = {}
+        for expert_id in range(NUM_EXPERTS):
+            w1_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
+                                    dtype=dtype)
+            w2_weight = torch.empty((HIDDEN_SIZE, INTERMEDIATE_SIZE),
+                                    dtype=dtype)
+            w3_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
+                                    dtype=dtype)
+            torch.nn.init.xavier_uniform_(w1_weight)
+            torch.nn.init.xavier_uniform_(w2_weight)
+            torch.nn.init.xavier_uniform_(w3_weight)
+            weights[f"{expert_id}.w1.weight"] = w1_weight
+            weights[f"{expert_id}.w2.weight"] = w2_weight
+            weights[f"{expert_id}.w3.weight"] = w3_weight
+
+        with mock.patch.object(CutlassFusedMoE,
+                               "select_alltoall_method_type",
+                               return_value=alltoall_method_type):
+            alltoall_model = CutlassFusedMoE(
+                num_experts=NUM_EXPERTS,
+                routing_method=routing_method,
+                hidden_size=HIDDEN_SIZE,
+                intermediate_size=INTERMEDIATE_SIZE,
+                dtype=dtype,
+                reduce_results=True,
+                model_config=ModelConfig(mapping=mapping,
+                                         max_num_tokens=MAX_NUM_TOKENS),
+            )
+        alltoall_model.to("cuda")
+        alltoall_model.load_weights([weights])
+
+        with mock.patch.object(CutlassFusedMoE,
+                               "select_alltoall_method_type",
+                               return_value=AlltoallMethodType.NotEnabled):
+            ref_model = CutlassFusedMoE(
+                num_experts=NUM_EXPERTS,
+                routing_method=routing_method,
+                hidden_size=HIDDEN_SIZE,
+                intermediate_size=INTERMEDIATE_SIZE,
+                dtype=dtype,
+                reduce_results=True,
+                model_config=ModelConfig(mapping=mapping,
+                                         max_num_tokens=MAX_NUM_TOKENS),
+            )
+        ref_model.to("cuda")
+        ref_model.load_weights([weights])
+
+        # Evaluate the outputs on a variant sequence length to verify the robustness of alltoall methods
+        m = MAX_NUM_TOKENS
+        while m >= 1:
+            x = torch.randn((m, HIDDEN_SIZE), dtype=dtype).cuda()
+            router_logits = torch.randn((m, NUM_EXPERTS), dtype=dtype).cuda()
+
+            all_rank_num_tokens = [m] * mapping.world_size
+            with torch.inference_mode():
+                output = alltoall_model.forward(
+                    x,
+                    router_logits,
+                    all_rank_num_tokens=all_rank_num_tokens,
+                    use_dp_padding=False)
+                ref_output = ref_model.forward(
+                    x,
+                    router_logits,
+                    all_rank_num_tokens=all_rank_num_tokens,
+                    use_dp_padding=False)
+
+            # Evaluate outputs
+            torch.testing.assert_close(output,
+                                       ref_output,
+                                       rtol=0.05,
+                                       atol=0.003)
+            m //= 2
+
+    with MPIPoolExecutor(max_workers=world_size) as executor:
+        results = executor.map(per_rank_test_fused_moe_alltoall,
+                               range(world_size))
+        for r in results:
+            assert r is None
+
+
 @skip_pre_hopper
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 def test_fused_moe_fp8(dtype):