mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
feat: large-scale EP(part 7: DeepEP integration) (#4792)
Signed-off-by: Tailing Yuan <yuantailing@gmail.com> Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
This commit is contained in: parent 443b2eb51f, commit 0b60da2c45
@@ -1,7 +1,7 @@
version: "3.9"
services:
  tensorrt_llm-dev:
    image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420
    image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506111045-4792
    network_mode: host
    ipc: host
@@ -72,6 +72,10 @@ RUN bash ./install_pytorch.sh $TORCH_INSTALL_TYPE && rm install_pytorch.sh
RUN pip3 uninstall -y opencv && rm -rf /usr/local/lib/python3*/dist-packages/cv2/
RUN pip3 install opencv-python-headless --force-reinstall --no-deps --no-cache-dir

# Install DeepEP
COPY docker/common/install_deep_ep.sh install_deep_ep.sh
RUN bash ./install_deep_ep.sh && rm install_deep_ep.sh

# WARs against security issues inherited from pytorch:25.04
# * https://github.com/advisories/GHSA-vqfr-h8mv-ghfj
# * https://github.com/advisories/GHSA-7cx3-6m66-7c5m
docker/common/install_deep_ep.sh (new file, 47 lines)
@@ -0,0 +1,47 @@
#!/bin/bash

set -euxo pipefail

GITHUB_URL=${GITHUB_MIRROR:-https://github.com}
DEEP_EP_COMMIT=2b266cf6452134f993ab0fcb3ef2d5de7683c561

if [ "$(. /etc/os-release && echo $ID)" == "rocky" ]; then
    echo "Skipping DeepEP installation in the Rocky distribution."
    exit 0
fi
libmlx5_dir=$(dirname $(ldconfig -p | grep libmlx5.so.1 | head -n1 | awk '{print $NF}'))

export NVCC_APPEND_FLAGS="--threads 4"

# Custom NVSHMEM
curl -fsSL https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz | tar xz
pushd nvshmem_src
curl -fsSL $GITHUB_URL/deepseek-ai/DeepEP/raw/$DEEP_EP_COMMIT/third-party/nvshmem.patch | patch -p1
sed "s/TRANSPORT_VERSION_MAJOR 3/TRANSPORT_VERSION_MAJOR 103/" -i src/CMakeLists.txt
ln -s libmlx5.so.1 "$libmlx5_dir/libmlx5.so"
cmake -S . -B build \
    -DCMAKE_INSTALL_PREFIX=/opt/custom_nvshmem \
    -DGDRCOPY_HOME=/usr/include \
    -DNVSHMEM_SHMEM_SUPPORT=0 \
    -DNVSHMEM_UCX_SUPPORT=0 \
    -DNVSHMEM_USE_NCCL=0 \
    -DNVSHMEM_MPI_SUPPORT=0 \
    -DNVSHMEM_IBGDA_SUPPORT=1 \
    -DNVSHMEM_PMIX_SUPPORT=0 \
    -DNVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
    -DNVSHMEM_USE_GDRCOPY=1 \
    -DCMAKE_CUDA_ARCHITECTURES="90-real;100-real;120-real" \
    -DNVSHMEM_BUILD_TESTS=0 \
    -DNVSHMEM_BUILD_EXAMPLES=0
cmake --build build -j`nproc`
make -C build install
popd

# DeepEP
curl -fsSL $GITHUB_URL/deepseek-ai/DeepEP/archive/$DEEP_EP_COMMIT.tar.gz | tar xz
TORCH_CUDA_ARCH_LIST="9.0;10.0;12.0" NVSHMEM_DIR=/opt/custom_nvshmem pip install -v --no-cache-dir ./DeepEP-$DEEP_EP_COMMIT

# Clean up
rm -r nvshmem_src
rm "$libmlx5_dir/libmlx5.so"
rm -r DeepEP-$DEEP_EP_COMMIT
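For context, a minimal sanity check (an illustration, not part of this commit) that can be run inside the freshly built image to confirm the DeepEP wheel installed by the script above is importable:

# Hypothetical post-install check; assumes it runs inside the image built from this Dockerfile.
import importlib

import torch


def check_deep_ep_install() -> bool:
    """Return True when deep_ep imports cleanly and a CUDA device is visible."""
    try:
        deep_ep = importlib.import_module("deep_ep")
    except ModuleNotFoundError:
        # Expected on Rocky-based images, where install_deep_ep.sh exits early.
        print("deep_ep is not installed")
        return False
    print(f"deep_ep imported from {deep_ep.__file__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    return torch.cuda.is_available()


if __name__ == "__main__":
    check_deep_ep_install()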
@@ -28,10 +28,10 @@ UPLOAD_PATH = env.uploadPath ? env.uploadPath : "sw-tensorrt-generic/llm-artifac
// Container configuration
// available tags can be found in: https://urm.nvidia.com/artifactory/sw-tensorrt-docker/tensorrt-llm/
// [base_image_name]-[arch]-[os](-[python_version])-[trt_version]-[torch_install_type]-[stage]-[date]-[mr_id]
LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420"
LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420"
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202506021004-9420"
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202506021004-9420"
LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506111045-4792"
LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506111045-4792"
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202506111045-4792"
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202506111045-4792"

// TODO: Move common variables to an unified location
BUILD_CORES_REQUEST = "8"
@@ -1,7 +1,7 @@

import java.lang.InterruptedException

DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420"
DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506111045-4792"

def createKubernetesPodConfig(image, arch = "amd64")
{
@@ -84,6 +84,9 @@ class ModelConfig(Generic[TConfig]):
    # If true, enable min-latency mode. Currently only used for Llama4.
    enable_min_latency: bool = False

    # Allow models to select op according to whether CUDA Graphs are used.
    use_cuda_graph: bool = False

    extra_attrs: Dict = field(default_factory=dict, repr=False, init=False)

    _frozen: bool = field(default=False, init=False, repr=False)
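The new `use_cuda_graph` field lets backend modules choose kernels and communication paths that are compatible with CUDA Graphs at construction time. A small sketch of how the flag flows (the import path is assumed from this diff; the other ModelConfig fields are omitted and rely on their defaults):

# Sketch only: ModelConfig is a dataclass, so the flag can be set directly.
from tensorrt_llm._torch.model_config import ModelConfig  # path assumed

cfg = ModelConfig(use_cuda_graph=True)

# PyTorchModelEngine (see the hunk near the end of this diff) forwards
# pytorch_backend_config.use_cuda_graph into this field, and
# CutlassFusedMoE.select_alltoall_method_type() reads it to prefer the
# CUDA-Graph-compatible DeepEPLowLatency path when DeepEP is enabled.
print(cfg.use_cuda_graph)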
@@ -38,7 +38,6 @@ from torch import nn
from tqdm import tqdm
from transformers import PretrainedConfig

from tensorrt_llm._mnnvl_utils import MnnvlMemory
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.llmapi.utils import enable_llm_debug
from tensorrt_llm.mapping import Mapping
@@ -413,10 +412,6 @@ class Deepseekv3MoE(nn.Module):
        config = model_config.pretrained_config
        self.top_k = top_k
        self.use_dp = model_config.mapping.enable_attention_dp
        self.enable_alltoall = Deepseekv3MoE.should_enable_alltoall(
            model_config, top_k)
        if self.enable_alltoall:
            MnnvlMemory.initialize()
        self.gate = DeepseekV3Gate(
            hidden_size,
            num_experts,
@@ -439,7 +434,6 @@ class Deepseekv3MoE(nn.Module):
            model_config=model_config,
            override_quant_config=override_quant_config,
            aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap],
            enable_alltoall=self.enable_alltoall,
            layer_idx=layer_idx)

        self.mapping = model_config.mapping
@@ -505,25 +499,6 @@ class Deepseekv3MoE(nn.Module):

        return shared_tp_size, shared_output_scale

    @staticmethod
    def should_enable_alltoall(model_config: ModelConfig, top_k: int) -> bool:
        if not model_config.mapping.enable_attention_dp:
            return False

        if model_config.mapping.tp_size == 1:
            return False

        if not MnnvlMemory.supports_mnnvl():
            return False

        if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
            return False

        if model_config.mapping.moe_ep_size <= top_k:
            return False

        return True

    def compute_routed_output(self, hidden_states, hidden_states_fp4,
                              all_rank_num_tokens, do_finalize):
        # max-throughput
@@ -531,7 +506,7 @@ class Deepseekv3MoE(nn.Module):
        if self.use_dp and self.mapping.tp_size > 1:
            # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
            # to reduce allreduce BW
            if disable_fp4_allgather() and not self.enable_alltoall:
            if disable_fp4_allgather() and not self.experts.enable_alltoall:
                hidden_states = allgather(hidden_states,
                                          self.mapping,
                                          dim=0,
@@ -6,8 +6,6 @@ from torch import nn
from tqdm import tqdm
from transformers import Qwen3MoeConfig

from tensorrt_llm._mnnvl_utils import MnnvlMemory

from ..attention_backend import AttentionMetadata
from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
                           allgather)
@@ -91,10 +89,6 @@ class Qwen3MoE(nn.Module):
        self.mapping = model_config.mapping
        self.allreduce = AllReduce(mapping=model_config.mapping,
                                   strategy=model_config.allreduce_strategy)
        self.enable_alltoall = Qwen3MoE.should_enable_alltoall(
            model_config, self.top_k)
        if self.enable_alltoall:
            MnnvlMemory.initialize()

        self.gate = Qwen3Gate(
            hidden_size=self.hidden_dim,
@@ -117,25 +111,6 @@ class Qwen3MoE(nn.Module):
            model_config=model_config,
        )

    @staticmethod
    def should_enable_alltoall(model_config: ModelConfig, top_k: int) -> bool:
        if not model_config.mapping.enable_attention_dp:
            return False

        if model_config.mapping.tp_size == 1:
            return False

        if not MnnvlMemory.supports_mnnvl():
            return False

        if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
            return False

        if model_config.mapping.moe_ep_size <= top_k:
            return False

        return True

    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -151,7 +126,7 @@ class Qwen3MoE(nn.Module):
        if self.enable_attention_dp and self.mapping.tp_size > 1:
            # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
            # to reduce allreduce BW
            if disable_fp4_allgather() and not self.enable_alltoall:
            if disable_fp4_allgather() and not self.experts.enable_alltoall:
                hidden_states = allgather(hidden_states,
                                          self.mapping,
                                          dim=0,
@@ -52,7 +52,6 @@ def create_moe(
    aux_stream: Optional[torch.cuda.Stream] = None,
    weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA,
    apply_router_weight_on_input: bool = False,
    enable_alltoall: bool = False,
    layer_idx: Optional[int] = None,
) -> MoE:
    moe_cls = get_moe_cls(model_config, override_quant_config)
@@ -63,7 +62,6 @@ def create_moe(

    if moe_cls == TRTLLMGenFusedMoE:
        assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in TRTLLMGenFusedMoE."
        assert not enable_alltoall, "enable_alltoall is not supported in TRTLLMGenFusedMoE."

        return moe_cls(
            routing_method=routing_method,
@@ -88,12 +86,10 @@ def create_moe(
            aux_stream=aux_stream,
            weight_loading_mode=weight_loading_mode,
            apply_router_weight_on_input=apply_router_weight_on_input,
            enable_alltoall=enable_alltoall,
            layer_idx=layer_idx,
        )
    elif moe_cls == VanillaMoE:
        assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in VanillaMoE."
        assert not enable_alltoall, "enable_alltoall is not supported in VanillaMoE."

        return moe_cls(
            routing_method=routing_method,
tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py (new file, 209 lines)
@@ -0,0 +1,209 @@
# Adapted from
# https://github.com/deepseek-ai/DeepEP/blob/aae9fa9a6dd0fec2a723fbb85ec4b22460fab670/README.md
import weakref
from typing import List, Tuple, Union

import torch

from tensorrt_llm._utils import local_mpi_size, mpi_comm
from tensorrt_llm.mapping import Mapping

try:
    from deep_ep import Buffer
    deep_ep_installed = True
except ModuleNotFoundError:
    deep_ep_installed = False


class VariableLengthBuffer:
    """ A wrapper of deep_ep.Buffer that accepts future size change
    """

    def __init__(self, mapping: Mapping):
        self.comm = mpi_comm().Split(mapping.pp_rank, mapping.moe_ep_rank)
        self.buffer = None

    def __del__(self):
        self.comm.Free()

    def reserve(self, hidden_size: int, hidden_dtype: torch.dtype):
        """ Ensure the buffer capacity is large enough.

        Reserve is a collective operation that requires all EP ranks to be sync
        """
        # NOTES: you may also replace `get_*_config` with your auto-tuned results via all the tests
        num_nvl_bytes, num_rdma_bytes = 0, 0
        hidden_bytes = hidden_size * max(hidden_dtype.itemsize,
                                         torch.bfloat16.itemsize)
        world_size = self.comm.Get_size()
        for config in (Buffer.get_dispatch_config(world_size),
                       Buffer.get_combine_config(world_size)):
            num_nvl_bytes = max(
                config.get_nvl_buffer_size_hint(hidden_bytes, world_size),
                num_nvl_bytes)
            num_rdma_bytes = max(
                config.get_rdma_buffer_size_hint(hidden_bytes, world_size),
                num_rdma_bytes)

        # Allocate a buffer if not existed or not enough buffer size
        if self.buffer is None or self.buffer.num_nvl_bytes < num_nvl_bytes or self.buffer.num_rdma_bytes < num_rdma_bytes:
            if self.buffer is not None:
                num_nvl_bytes = max(num_nvl_bytes, self.buffer.num_nvl_bytes)
                num_rdma_bytes = max(num_rdma_bytes, self.buffer.num_rdma_bytes)
            del self.buffer  # Destruct before Construct
            self.buffer = Buffer(None,
                                 num_nvl_bytes,
                                 num_rdma_bytes,
                                 num_nvl_peers=local_mpi_size(),
                                 comm=self.comm)

    def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                 topk_idx: torch.Tensor, topk_weights: torch.Tensor,
                 num_experts: int) -> \
            Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple]:
        # NOTES: an optional `previous_event` means a CUDA event captured that you want to make it as a dependency
        # of the dispatch kernel, it may be useful with communication-computation overlap. For more information, please
        # refer to the docs of `Buffer.dispatch`

        # Calculate layout before actual dispatch
        num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = \
            self.buffer.get_dispatch_layout(topk_idx, num_experts)
        assert event.event is None

        # Do MoE dispatch
        # NOTES: the CPU will wait for GPU's signal to arrive, so this is not compatible with CUDA graph
        # For more advanced usages, please refer to the docs of the `dispatch` function
        recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
            self.buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
                                 num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
                                 is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert)
        assert event.event is None

        # For event management, please refer to the docs of the `EventOverlap` class
        return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle

    def combine(self, x: torch.Tensor, handle: Tuple) -> torch.Tensor:
        # Do MoE combine
        # For more advanced usages, please refer to the docs of the `combine` function
        combined_x, _, event = self.buffer.combine(x, handle)
        assert event.event is None

        # For event management, please refer to the docs of the `EventOverlap` class
        return combined_x


class VariableLengthLowLatencyBuffer:
    """ A wrapper of deep_ep.Buffer that accepts future size change
    """

    def __init__(self, mapping: Mapping):
        self.comm = mpi_comm().Split(mapping.pp_rank, mapping.moe_ep_rank)
        self.buffer = None
        self.num_max_dispatch_tokens_per_rank = None

    def __del__(self):
        self.comm.Free()

    def reserve(self, num_max_dispatch_tokens_per_rank: int, hidden_size: int,
                num_experts: int):
        """ Ensure the buffer capacity is large enough.

        Reserve is a collective operation that requires all EP ranks to be sync
        """
        # NOTES: the low-latency mode will consume much more space than the normal mode
        # So we recommend that `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
        world_size = self.comm.Get_size()
        num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
            num_max_dispatch_tokens_per_rank, hidden_size, world_size,
            num_experts)

        # Allocate a buffer if not existed or not enough buffer size
        if self.buffer is None or self.buffer.num_rdma_bytes < num_rdma_bytes:
            # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
            assert num_experts % world_size == 0
            del self.buffer  # Destruct before Construct
            self.buffer = Buffer(None,
                                 0,
                                 num_rdma_bytes,
                                 low_latency_mode=True,
                                 num_qps_per_rank=num_experts // world_size,
                                 comm=self.comm)

    def low_latency_dispatch(self, hidden_states: torch.Tensor,
                             topk_idx: torch.Tensor,
                             num_max_dispatch_tokens_per_rank: int,
                             num_experts: int):
        if self.num_max_dispatch_tokens_per_rank is None:
            self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
        if num_max_dispatch_tokens_per_rank != self.num_max_dispatch_tokens_per_rank:
            raise NotImplementedError(
                "There are issues if `low_latency_dispatch` calls use different `num_max_dispatch_tokens_per_rank` values"
            )

        # Do MoE dispatch, compatible with CUDA graph (but you may restore some buffer status once you replay)
        recv_hidden_states, recv_expert_count, handle, event, hook = \
            self.buffer.low_latency_dispatch(hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts, use_fp8=False)
        assert event.event is None
        assert hook is None

        # NOTES: the actual tensor will not be received only if you call `hook()`,
        # it is useful for double-batch overlapping, but **without any SM occupation**
        # If you don't want to overlap, please set `return_recv_hook=False`
        # Later, you can use our GEMM library to do the computation with this specific format
        return recv_hidden_states, recv_expert_count, handle

    def low_latency_combine(self, hidden_states: torch.Tensor,
                            topk_idx: torch.Tensor, topk_weights: torch.Tensor,
                            handle: Tuple):
        # Do MoE combine, compatible with CUDA graph (but you may restore some buffer status once you replay)
        combined_hidden_states, event, hook = \
            self.buffer.low_latency_combine(hidden_states, topk_idx, topk_weights, handle)
        assert event.event is None
        assert hook is None

        # NOTES: the same behavior as described in the dispatch kernel
        return combined_hidden_states


class BufferPool:
    """ A pool that allocates buffers on demand.

    Although the pool interface allows creating multiple buffers, the
    current version of DeepEP supports at most one `deep_ep.Buffer` at a
    time. Please ensure that all references to `VariableLengthBuffer` are
    released before getting another buffer.
    """

    def __init__(self):
        self.buffers: Map[Mapping,
                          weakref.ReferenceType[VariableLengthBuffer]] = {}
        self.low_latency_buffers: Map[
            Mapping,
            weakref.ReferenceType[VariableLengthLowLatencyBuffer]] = {}

    def get_buffer(self, mapping: Mapping) -> VariableLengthBuffer:
        """ Get_buffer is a collective operation that requires all ranks to be sync
        """
        if mapping in self.buffers and self.buffers[mapping]() is not None:
            buffer = self.buffers[mapping]()
        else:
            buffer = VariableLengthBuffer(mapping)
            self.buffers[mapping] = weakref.ref(buffer)
        return buffer

    def get_low_latency_buffer(
            self, mapping: Mapping) -> VariableLengthLowLatencyBuffer:
        """ Get_low_latency_buffer is a collective operation that requires all ranks to be sync
        """
        if mapping in self.low_latency_buffers and self.low_latency_buffers[
                mapping]() is not None:
            buffer = self.low_latency_buffers[mapping]()
        else:
            buffer = VariableLengthLowLatencyBuffer(mapping)
            self.low_latency_buffers[mapping] = weakref.ref(buffer)
        return buffer


# The default pool
# You may create own pools for better resource management.
buffer_pool = BufferPool()
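To make the wrappers above concrete, here is a hedged usage sketch of the normal (non-low-latency) path: one buffer per EP group obtained from the default pool, a capacity reservation, then dispatch/combine around the expert computation. The Mapping arguments and tensor shapes are illustrative assumptions; every EP rank must execute the collective calls in the same order.

# Sketch, assuming an MPI launch with 4 ranks on one node and bf16 activations.
import torch

from tensorrt_llm._torch.modules.fused_moe.deep_ep_utils import buffer_pool
from tensorrt_llm._utils import mpi_rank
from tensorrt_llm.mapping import Mapping

num_tokens, hidden_size, num_experts, top_k = 16, 2560, 72, 6
mapping = Mapping(world_size=4, rank=mpi_rank(), tp_size=4,
                  moe_ep_size=4, moe_tp_size=1, enable_attention_dp=True)
torch.cuda.set_device(mapping.rank)

x = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")
topk_idx = torch.randint(0, num_experts, (num_tokens, top_k), dtype=torch.int64, device="cuda")
topk_weights = torch.rand(num_tokens, top_k, dtype=torch.float32, device="cuda")

buffer = buffer_pool.get_buffer(mapping)                              # collective across EP ranks
buffer.reserve(hidden_size=hidden_size, hidden_dtype=torch.bfloat16)  # collective, may (re)allocate

recv_x, recv_topk_idx, recv_topk_weights, recv_counts, handle = buffer.dispatch(
    x, topk_idx, topk_weights, num_experts)
expert_output = recv_x  # placeholder for the local expert computation
combined = buffer.combine(expert_output, handle)  # sends weighted results back to the source ranks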
@@ -1,16 +1,19 @@
import os
from enum import IntEnum
from typing import Dict, List, Optional, Tuple, Union

import torch

from tensorrt_llm._mnnvl_utils import MnnvlMoe, MoEAlltoallInfo
from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
from tensorrt_llm._utils import logger
from tensorrt_llm.mapping import Mapping

from ...distributed import allgather, reducescatter
from ...expert_statistic import ExpertStatistic
from ...model_config import ModelConfig
from ...utils import (EventType, Fp4QuantizedTensor, disable_fp4_allgather,
                      reswizzle_sf, swizzle_sf, unswizzle_sf)
from .deep_ep_utils import buffer_pool, deep_ep_installed
from .interface import MoE
from .moe_load_balancer import get_moe_load_balancer
from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
@@ -20,6 +23,18 @@ from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
from .routing import BaseMoeRoutingMethod


# The type of alltoall method
class AlltoallMethodType(IntEnum):
    # Not available
    NotEnabled = 0
    # MNNVL
    MNNVL = 1
    # DeepEP intranode or internode: no CUDA Graphs support, IBGDA is required by internode
    DeepEP = 2
    # DeepEP low latency: CUDA Graphs are supported, IBGDA is required
    DeepEPLowLatency = 3


class CutlassFusedMoE(MoE):
    """
    Fused Mixture of Experts (MoE) Layer with performance tuning.
@@ -33,7 +48,6 @@ class CutlassFusedMoE(MoE):
        dtype (Optional[torch.dtype]): Data type for the weights.
        reduce_results (bool): Whether to reduce the results across devices.
        model_config (ModelConfig): Configuration object for the model.
        enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter

    MoE torch custom op:
      In min-latency mode:
@@ -82,7 +96,6 @@ class CutlassFusedMoE(MoE):
        weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
        VANILLA,
        apply_router_weight_on_input: bool = False,
        enable_alltoall: bool = False,
        layer_idx: Optional[int] = None,
    ):
@@ -176,7 +189,12 @@ class CutlassFusedMoE(MoE):
        self.has_been_profiled = False
        self.has_been_profiled_min_latency = False

        self.enable_alltoall = enable_alltoall
        self.alltoall_method_type = self.select_alltoall_method_type(
            model_config.mapping, routing_method.experts_per_token, dtype,
            model_config.use_cuda_graph)
        logger.info_once(
            f"CutlassFusedMoE selects alltoall_method_type {self.alltoall_method_type!r}",
            key="alltoall_method_type")
        self.use_postquant_alltoall = False
        if self.enable_alltoall:
            assert self.use_dp and self.parallel_size > 1,\
@@ -185,8 +203,25 @@ class CutlassFusedMoE(MoE):
            self.use_postquant_alltoall = (os.environ.get(
                "TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1")
                                           == "1") and qm.has_nvfp4()
        self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
            model_config.mapping) if enable_alltoall else None
            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                MnnvlMemory.initialize()
                self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
                    model_config.mapping)
            elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
                self.deep_ep_buffer = buffer_pool.get_buffer(
                    model_config.mapping)
                self.deep_ep_buffer.reserve(hidden_size, dtype)
            elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
                self.deep_ep_max_num_tokens = min(model_config.max_num_tokens,
                                                  self.moe_max_num_tokens)
                self.deep_ep_buffer = buffer_pool.get_low_latency_buffer(
                    model_config.mapping)
                self.deep_ep_buffer.reserve(self.deep_ep_max_num_tokens,
                                            hidden_size, self.num_slots)
            else:
                raise NotImplementedError(
                    f"Not available alltoall method type: {alltoall_method_type!r}"
                )

        # If True, the router weight will be multiplied on the input rather than at the end of FC2
        self.apply_router_weight_on_input = apply_router_weight_on_input
@@ -215,12 +250,48 @@ class CutlassFusedMoE(MoE):
                f"unsupported quantization mode: {self.quant_config.quant_mode}"
            )

    @staticmethod
    def select_alltoall_method_type(mapping: Mapping, top_k: int,
                                    dtype: torch.dtype,
                                    use_cuda_graph: bool) -> AlltoallMethodType:
        if not mapping.enable_attention_dp:
            return AlltoallMethodType.NotEnabled

        if mapping.tp_size == 1:
            return AlltoallMethodType.NotEnabled

        if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
            return AlltoallMethodType.NotEnabled

        if mapping.moe_ep_size <= top_k:
            return AlltoallMethodType.NotEnabled

        if MnnvlMemory.supports_mnnvl():
            return AlltoallMethodType.MNNVL

        if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "0") == "1":
            if deep_ep_installed and dtype == torch.bfloat16:
                if use_cuda_graph:
                    # Here we can only choose DeepEPLowLatency since only this method supports CUDA Graphs.
                    return AlltoallMethodType.DeepEPLowLatency
                else:
                    # Here we can choose DeepEP or DeepEPLowLatency if both are available. Now DeepEP is faster.
                    return AlltoallMethodType.DeepEP

        return AlltoallMethodType.NotEnabled

    @property
    def has_w4afp8(self):
        assert self._weights_created
        return self.quant_config and self.quant_config.quant_mode.is_int4_weight_only_per_group(
        )

    @property
    def enable_alltoall(self):
        """ enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
        """
        return self.alltoall_method_type != AlltoallMethodType.NotEnabled

    def _get_quant_method(self):
        if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
                exclude_kv_cache=True):
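The selection above is also gated by environment variables. A configuration sketch (the values are illustrative assumptions; the variable names are exactly the ones the method and the post-quant path read):

# Illustrative only; assumes attention DP is on, moe_ep_size > top_k, bf16 activations,
# and that the DeepEP wheel from install_deep_ep.sh is present.
import os

os.environ["TRTLLM_MOE_DISABLE_ALLTOALLV"] = "0"     # "1" forces NotEnabled (allgather path)
os.environ["TRTLLM_CAN_USE_DEEP_EP"] = "1"           # opt in to DeepEP when MNNVL is unavailable
os.environ["TRTLLM_MOE_POST_QUANT_ALLTOALLV"] = "1"  # dispatch after quantization when NVFP4 is used

# Resulting precedence, as implemented above: MNNVL if supported, otherwise DeepEP
# (or DeepEPLowLatency when CUDA Graphs are enabled), otherwise NotEnabled.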
@@ -311,10 +382,6 @@ class CutlassFusedMoE(MoE):
            # TODO: remove this once we have correct fusedmoe kernel ready
            token_final_scales = None

        token_count = x.shape[0]

        alltoall_info = None

        if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
        ) and is_first_call:
            self.layer_load_balancer.maybe_cudagraph_done_wait()
@@ -334,13 +401,58 @@ class CutlassFusedMoE(MoE):
        ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)

        token_selected_experts_for_statistic = token_selected_experts if need_statistic else None

        if self.enable_alltoall:
            x, token_selected_slots, token_final_scales, token_selected_experts_for_statistic, alltoall_info = \
                self.alltoall_prepare_maybe_dispatch(all_rank_num_tokens,
                                                     x,
                                                     token_selected_slots,
                                                     token_final_scales,
                                                     token_selected_experts_for_statistic)
            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                token_count = x.shape[0]
                alltoall_info = None
                x, token_selected_slots, token_final_scales, token_selected_experts_for_statistic, alltoall_info = \
                    self.alltoall_prepare_maybe_dispatch(all_rank_num_tokens,
                                                         x,
                                                         token_selected_slots,
                                                         token_final_scales,
                                                         token_selected_experts_for_statistic)
            elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
                if not self.use_postquant_alltoall:
                    x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
                        self.deep_ep_buffer.dispatch(x, token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
            elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
                if not self.use_postquant_alltoall:
                    deep_ep_topk_idx = token_selected_slots.to(torch.int64)
                    deep_ep_topk_weights = token_final_scales
                    x, recv_expert_count, deep_ep_handle = \
                        self.deep_ep_buffer.low_latency_dispatch(x, deep_ep_topk_idx, self.deep_ep_max_num_tokens, self.num_slots)
                    # x shape: [#local experts, #max recv tokens, hidden_size]
                    # recv_expert_count shape: [#local experts]

                    # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
                    # TODO: remove the adapter by changing `torch.ops.trtllm.fused_moe` API
                    mask = torch.arange(
                        x.shape[1], dtype=torch.int32, device=x.device).expand(
                            x.shape[0],
                            x.shape[1]) < recv_expert_count.unsqueeze(1)
                    token_selected_slots = torch.full(
                        (x.shape[0], x.shape[1], self.routing_method.top_k),
                        self.num_slots,
                        dtype=torch.int32,
                        device=x.device)
                    token_selected_slots[:, :, 0] = torch.where(
                        mask,
                        torch.arange(
                            x.shape[0] * self.mapping.moe_ep_rank,
                            x.shape[0] * (self.mapping.moe_ep_rank + 1),
                            dtype=torch.int32,
                            device=x.device).unsqueeze(1), self.num_slots)
                    x = x.view(x.shape[0] * x.shape[1], x.shape[2])
                    token_selected_slots = token_selected_slots.view(
                        x.shape[0], self.routing_method.top_k)
                    token_final_scales = torch.ones_like(
                        token_selected_slots, dtype=token_final_scales.dtype)
            else:
                raise NotImplementedError(
                    f"Not available alltoall method type: {alltoall_method_type!r}"
                )

        x_sf = None
        if self.has_any_quant:
            if self.has_fp8_qdq:
@@ -414,8 +526,56 @@ class CutlassFusedMoE(MoE):
            quant_scales = self.quant_scales

        if self.use_postquant_alltoall:
            x, x_sf = self.alltoall_postquant_dispatch(x, x_sf, x_row, x_col,
                                                       alltoall_info)
            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                x, x_sf = self.alltoall_postquant_dispatch(
                    x, x_sf, x_row, x_col, alltoall_info)
            elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
                if x_sf is not None:
                    if self.has_nvfp4:
                        x_sf = unswizzle_sf(x_sf, x_row, x_col,
                                            self.scaling_vector_size)
                    # Adapter between `x_sf` and DeepEP
                    # TODO: remove the adapter by adding dtype support to DeepEP
                    x_sf_dtype = x_sf.dtype
                    x_sf = x_sf.view(torch.float32)
                (x, x_sf), recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
                    self.deep_ep_buffer.dispatch((x, x_sf), token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
                if x_sf is not None:
                    x_sf = x_sf.view(x_sf_dtype)
                    if self.has_nvfp4:
                        x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2,
                                          self.scaling_vector_size)
            elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
                raise NotImplementedError(
                    "Not implemented postquant for DeepEPLowLatency, please set TRTLLM_MOE_POST_QUANT_ALLTOALLV=0"
                )
            else:
                raise NotImplementedError(
                    f"Not available alltoall method type: {alltoall_method_type!r}"
                )

        if self.enable_alltoall:
            # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
            # TODO: remove the adapter by changing APIs
            if self.alltoall_method_type == AlltoallMethodType.DeepEP:
                token_selected_slots = recv_topk_idx.to(torch.int32)
                mask = token_selected_slots == -1
                token_selected_slots += self.expert_size_per_partition * self.mapping.moe_ep_rank
                token_selected_slots[mask] = self.num_slots
                num_recv_token_is_zero = x.shape[0] == 0
                if x.shape[0] == 0:
                    x = torch.zeros((1, x.shape[1]),
                                    dtype=x.dtype,
                                    device=x.device)
                    token_selected_slots = torch.full(
                        (1, token_selected_slots.shape[1]),
                        self.num_slots,
                        dtype=token_selected_slots.dtype,
                        device=token_selected_slots.device)
                    token_final_scales = torch.ones(
                        (1, token_final_scales.shape[1]),
                        dtype=token_final_scales.dtype,
                        device=token_final_scales.device)

        final_hidden_states = torch.ops.trtllm.fused_moe(
            x,
@@ -452,9 +612,25 @@ class CutlassFusedMoE(MoE):
            final_hidden_states = final_hidden_states[0]

        if self.enable_alltoall:
            final_hidden_states = self.alltoall_combine(final_hidden_states,
                                                        alltoall_info,
                                                        token_count)
            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                final_hidden_states = self.alltoall_combine(
                    final_hidden_states, alltoall_info, token_count)
            elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
                if num_recv_token_is_zero:
                    final_hidden_states = final_hidden_states[:0]
                final_hidden_states = self.deep_ep_buffer.combine(
                    final_hidden_states, deep_ep_handle)
            elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
                final_hidden_states = self.deep_ep_buffer.low_latency_combine(
                    final_hidden_states.view(
                        self.expert_size_per_partition,
                        self.deep_ep_max_num_tokens * self.mapping.moe_ep_size,
                        final_hidden_states.shape[1]), deep_ep_topk_idx,
                    deep_ep_topk_weights, deep_ep_handle)
            else:
                raise NotImplementedError(
                    f"Not available alltoall method type: {alltoall_method_type!r}"
                )

        if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
        ) and is_last_call:
@@ -23,7 +23,6 @@ class TRTLLMGenFusedMoE(MoE):
        dtype (Optional[torch.dtype]): Data type for the weights.
        reduce_results (bool): Whether to reduce the results across devices.
        model_config (ModelConfig): Configuration object for the model.
        enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter

    MoE torch custom op:
      Only support min-latency mode now (SM100 Blackwell only).
@@ -89,8 +89,6 @@ class VanillaMoE(nn.ModuleList):
                                   if model_config.moe_max_num_tokens
                                   is not None else max_num_tokens)

        self.enable_alltoall = False

        self._weights_created = False
        if not model_config.skip_create_weights_in_init:
            self.create_weights()
@@ -458,7 +456,7 @@ class VanillaMoE(nn.ModuleList):
        use_dp_padding: Optional[bool] = None,
    ):
        outputs = inputs
        if self.parallel_size > 1 and not self.enable_alltoall:
        if self.parallel_size > 1:
            if self.use_dp:
                outputs = reducescatter(
                    inputs,
@@ -27,7 +27,6 @@ class MoE(nn.Module):
        dtype (Optional[torch.dtype]): Data type for the weights.
        reduce_results (bool): Whether to reduce the results across devices.
        model_config (ModelConfig): Configuration object for the model.
        enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
    """

    def __init__(
@@ -123,3 +122,9 @@ class MoE(nn.Module):
        assert self._weights_created
        return self.quant_config is not None and self.quant_config.layer_quant_mode.has_nvfp4(
        )

    @property
    def enable_alltoall(self):
        """ enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
        """
        return False
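`enable_alltoall` is now a read-only property rather than a constructor argument, so callers such as `create_moe` and the model files no longer pass the decision in. A sketch of the override pattern (MyFusedMoE is a hypothetical subclass; the import paths are assumed from the relative imports in this diff; CutlassFusedMoE in this commit does the same thing via `alltoall_method_type`):

# Hypothetical subclass illustrating the property override.
from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import AlltoallMethodType
from tensorrt_llm._torch.modules.fused_moe.interface import MoE


class MyFusedMoE(MoE):

    def __init__(self, *args, alltoall_method_type=AlltoallMethodType.NotEnabled, **kwargs):
        super().__init__(*args, **kwargs)
        self.alltoall_method_type = alltoall_method_type

    @property
    def enable_alltoall(self):
        # Derived from the selected backend instead of being passed in by callers.
        return self.alltoall_method_type != AlltoallMethodType.NotEnabled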
@@ -1003,6 +1003,7 @@ class PyTorchModelEngine(ModelEngine):
            checkpoint_dir,
            trust_remote_code=True,
            enable_min_latency=self.pytorch_backend_config.enable_min_latency,
            use_cuda_graph=self.pytorch_backend_config.use_cuda_graph,
            spec_config=self.spec_config,
            max_num_tokens=max_num_tokens,
            moe_max_num_tokens=moe_max_num_tokens,
@@ -75,6 +75,9 @@ class Logger(metaclass=Singleton):
            self._polygraphy_logger.module_severity = severity_map[
                min_severity][2]

        # For log_once
        self._appeared_keys = set()

        if invalid_severity:
            self.warning(
                f"Requested log level {environ_severity} is invalid. Using '{self.DEFAULT_LEVEL}' instead"
@@ -109,23 +112,44 @@ class Logger(metaclass=Singleton):
        parts.extend(map(str, msg))
        self._func_wrapper(severity)(" ".join(parts))

    def log_once(self, severity, *msg, key):
        if key not in self._appeared_keys:
            self._appeared_keys.add(key)
            self.log(severity, *msg)

    def critical(self, *msg):
        self.log(self.INTERNAL_ERROR, *msg)

    def critical_once(self, *msg, key):
        self.log_once(self.INTERNAL_ERROR, *msg, key=key)

    fatal = critical
    fatal_once = critical_once

    def error(self, *msg):
        self.log(self.ERROR, *msg)

    def error_once(self, *msg, key):
        self.log_once(self.ERROR, *msg, key=key)

    def warning(self, *msg):
        self.log(self.WARNING, *msg)

    def warning_once(self, *msg, key):
        self.log_once(self.WARNING, *msg, key=key)

    def info(self, *msg):
        self.log(self.INFO, *msg)

    def info_once(self, *msg, key):
        self.log_once(self.INFO, *msg, key=key)

    def debug(self, *msg):
        self.log(self.VERBOSE, *msg)

    def debug_once(self, *msg, key):
        self.log_once(self.VERBOSE, *msg, key=key)

    @property
    def level(self) -> str:
        return self._min_severity
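The new `*_once` helpers deduplicate by `key`, which is how CutlassFusedMoE reports its selected all-to-all backend exactly once per process. A small usage sketch (the import mirrors the fused_moe_cutlass hunk above; the log text is illustrative):

# Usage sketch for the de-duplicating log helpers.
from tensorrt_llm._utils import logger

for _ in range(1000):
    # Emitted a single time; later calls with the same key are dropped by log_once().
    logger.info_once("CutlassFusedMoE selects alltoall_method_type DeepEP",
                     key="alltoall_method_type")

# A different key is tracked independently and still logs.
logger.warning_once("DeepEP internode requires IBGDA", key="deep_ep_ibgda")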
@@ -54,6 +54,8 @@ l0_dgx_h100:
  auto_trigger: deepseek
  tests:
  - unittest/_torch/multi_gpu_modeling -k "deepseek"
  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEP]
  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEPLowLatency]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
@@ -2,6 +2,7 @@ import pickle
import sys
from itertools import product
from typing import Dict, List, Optional
from unittest import mock

import cloudpickle
import pytest
@@ -19,6 +20,8 @@ from tensorrt_llm._torch.modules.fused_moe import (BaseMoeRoutingMethod,
                                                    DefaultMoeRoutingMethod,
                                                    RenormalizeMoeRoutingMethod,
                                                    VanillaMoE)
from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import \
    AlltoallMethodType
from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
from tensorrt_llm._utils import mpi_rank
from tensorrt_llm.mapping import Mapping
@@ -123,6 +126,111 @@ def test_fused_moe_multi_gpu(moe_cls, ep_size):
        assert r is None


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="needs 4 GPUs to run this test")
@pytest.mark.parametrize("alltoall_method_type", [
    AlltoallMethodType.MNNVL, AlltoallMethodType.DeepEP,
    AlltoallMethodType.DeepEPLowLatency
],
                         ids=lambda s: s.name)
def test_fused_moe_alltoall(alltoall_method_type):
    world_size = 4
    dtype = torch.bfloat16
    HIDDEN_SIZE = 2560
    INTERMEDIATE_SIZE = 1536
    NUM_EXPERTS = 72
    TOP_K = 6
    MAX_NUM_TOKENS = 2048

    def per_rank_test_fused_moe_alltoall(job_id):
        routing_method = DefaultMoeRoutingMethod(top_k=TOP_K)
        mapping = Mapping(world_size=world_size,
                          rank=mpi_rank(),
                          tp_size=world_size,
                          moe_ep_size=world_size,
                          moe_tp_size=1,
                          enable_attention_dp=True)
        torch.cuda.set_device(mapping.rank)
        torch.manual_seed(mapping.rank)

        weights = {}
        for expert_id in range(NUM_EXPERTS):
            w1_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
                                    dtype=dtype)
            w2_weight = torch.empty((HIDDEN_SIZE, INTERMEDIATE_SIZE),
                                    dtype=dtype)
            w3_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
                                    dtype=dtype)
            torch.nn.init.xavier_uniform_(w1_weight)
            torch.nn.init.xavier_uniform_(w2_weight)
            torch.nn.init.xavier_uniform_(w3_weight)
            weights[f"{expert_id}.w1.weight"] = w1_weight
            weights[f"{expert_id}.w2.weight"] = w2_weight
            weights[f"{expert_id}.w3.weight"] = w3_weight
        with mock.patch.object(CutlassFusedMoE,
                               "select_alltoall_method_type",
                               return_value=alltoall_method_type):
            alltoall_model = CutlassFusedMoE(
                num_experts=NUM_EXPERTS,
                routing_method=routing_method,
                hidden_size=HIDDEN_SIZE,
                intermediate_size=INTERMEDIATE_SIZE,
                dtype=dtype,
                reduce_results=True,
                model_config=ModelConfig(mapping=mapping,
                                         max_num_tokens=MAX_NUM_TOKENS),
            )
            alltoall_model.to("cuda")
            alltoall_model.load_weights([weights])
        with mock.patch.object(CutlassFusedMoE,
                               "select_alltoall_method_type",
                               return_value=AlltoallMethodType.NotEnabled):
            ref_model = CutlassFusedMoE(
                num_experts=NUM_EXPERTS,
                routing_method=routing_method,
                hidden_size=HIDDEN_SIZE,
                intermediate_size=INTERMEDIATE_SIZE,
                dtype=dtype,
                reduce_results=True,
                model_config=ModelConfig(mapping=mapping,
                                         max_num_tokens=MAX_NUM_TOKENS),
            )
            ref_model.to("cuda")
            ref_model.load_weights([weights])

        # Evaluate the outputs on a variant sequence length to verify the robustness of alltoall methods
        m = MAX_NUM_TOKENS
        while m >= 1:
            x = torch.randn((m, HIDDEN_SIZE), dtype=dtype).cuda()
            router_logits = torch.randn((m, NUM_EXPERTS), dtype=dtype).cuda()
            all_rank_num_tokens = [m] * mapping.world_size

            with torch.inference_mode():
                output = alltoall_model.forward(
                    x,
                    router_logits,
                    all_rank_num_tokens=all_rank_num_tokens,
                    use_dp_padding=False)
                ref_output = ref_model.forward(
                    x,
                    router_logits,
                    all_rank_num_tokens=all_rank_num_tokens,
                    use_dp_padding=False)

            # Evaluate outputs
            torch.testing.assert_close(output,
                                       ref_output,
                                       rtol=0.05,
                                       atol=0.003)
            m //= 2

    with MPIPoolExecutor(max_workers=world_size) as executor:
        results = executor.map(per_rank_test_fused_moe_alltoall,
                               range(world_size))
        for r in results:
            assert r is None


@skip_pre_hopper
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_moe_fp8(dtype):