Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

feat: large-scale EP (part 7: DeepEP integration) (#4792)

Signed-off-by: Tailing Yuan <yuantailing@gmail.com>
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>

Parent: 443b2eb51f
Commit: 0b60da2c45
@@ -1,7 +1,7 @@
 version: "3.9"
 services:
   tensorrt_llm-dev:
-    image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420
+    image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506111045-4792
     network_mode: host
     ipc: host

@@ -72,6 +72,10 @@ RUN bash ./install_pytorch.sh $TORCH_INSTALL_TYPE && rm install_pytorch.sh
 RUN pip3 uninstall -y opencv && rm -rf /usr/local/lib/python3*/dist-packages/cv2/
 RUN pip3 install opencv-python-headless --force-reinstall --no-deps --no-cache-dir

+# Install DeepEP
+COPY docker/common/install_deep_ep.sh install_deep_ep.sh
+RUN bash ./install_deep_ep.sh && rm install_deep_ep.sh
+
 # WARs against security issues inherited from pytorch:25.04
 # * https://github.com/advisories/GHSA-vqfr-h8mv-ghfj
 # * https://github.com/advisories/GHSA-7cx3-6m66-7c5m
docker/common/install_deep_ep.sh (new file, 47 lines)
@@ -0,0 +1,47 @@
+#!/bin/bash
+
+set -euxo pipefail
+
+GITHUB_URL=${GITHUB_MIRROR:-https://github.com}
+DEEP_EP_COMMIT=2b266cf6452134f993ab0fcb3ef2d5de7683c561
+
+if [ "$(. /etc/os-release && echo $ID)" == "rocky" ]; then
+    echo "Skipping DeepEP installation in the Rocky distribution."
+    exit 0
+fi
+libmlx5_dir=$(dirname $(ldconfig -p | grep libmlx5.so.1 | head -n1 | awk '{print $NF}'))
+
+export NVCC_APPEND_FLAGS="--threads 4"
+
+# Custom NVSHMEM
+curl -fsSL https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz | tar xz
+pushd nvshmem_src
+curl -fsSL $GITHUB_URL/deepseek-ai/DeepEP/raw/$DEEP_EP_COMMIT/third-party/nvshmem.patch | patch -p1
+sed "s/TRANSPORT_VERSION_MAJOR 3/TRANSPORT_VERSION_MAJOR 103/" -i src/CMakeLists.txt
+ln -s libmlx5.so.1 "$libmlx5_dir/libmlx5.so"
+cmake -S . -B build \
+    -DCMAKE_INSTALL_PREFIX=/opt/custom_nvshmem \
+    -DGDRCOPY_HOME=/usr/include \
+    -DNVSHMEM_SHMEM_SUPPORT=0 \
+    -DNVSHMEM_UCX_SUPPORT=0 \
+    -DNVSHMEM_USE_NCCL=0 \
+    -DNVSHMEM_MPI_SUPPORT=0 \
+    -DNVSHMEM_IBGDA_SUPPORT=1 \
+    -DNVSHMEM_PMIX_SUPPORT=0 \
+    -DNVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
+    -DNVSHMEM_USE_GDRCOPY=1 \
+    -DCMAKE_CUDA_ARCHITECTURES="90-real;100-real;120-real" \
+    -DNVSHMEM_BUILD_TESTS=0 \
+    -DNVSHMEM_BUILD_EXAMPLES=0
+cmake --build build -j`nproc`
+make -C build install
+popd
+
+# DeepEP
+curl -fsSL $GITHUB_URL/deepseek-ai/DeepEP/archive/$DEEP_EP_COMMIT.tar.gz | tar xz
+TORCH_CUDA_ARCH_LIST="9.0;10.0;12.0" NVSHMEM_DIR=/opt/custom_nvshmem pip install -v --no-cache-dir ./DeepEP-$DEEP_EP_COMMIT
+
+# Clean up
+rm -r nvshmem_src
+rm "$libmlx5_dir/libmlx5.so"
+rm -r DeepEP-$DEEP_EP_COMMIT
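A quick way to sanity-check an image built with this script is to confirm that the DeepEP extension imports; the snippet below is an illustrative check and not part of the commit.

    # Sketch: verify the DeepEP install inside the container.
    try:
        from deep_ep import Buffer  # backed by the custom NVSHMEM built above
        print("DeepEP available, Buffer =", Buffer)
    except ModuleNotFoundError:
        print("DeepEP not installed (expected on Rocky Linux images)")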
@@ -28,10 +28,10 @@ UPLOAD_PATH = env.uploadPath ? env.uploadPath : "sw-tensorrt-generic/llm-artifac
 // Container configuration
 // available tags can be found in: https://urm.nvidia.com/artifactory/sw-tensorrt-docker/tensorrt-llm/
 // [base_image_name]-[arch]-[os](-[python_version])-[trt_version]-[torch_install_type]-[stage]-[date]-[mr_id]
-LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420"
-LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420"
-LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202506021004-9420"
-LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202506021004-9420"
+LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506111045-4792"
+LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506111045-4792"
+LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202506111045-4792"
+LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202506111045-4792"

 // TODO: Move common variables to an unified location
 BUILD_CORES_REQUEST = "8"
@@ -1,7 +1,7 @@

 import java.lang.InterruptedException

-DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420"
+DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506111045-4792"

 def createKubernetesPodConfig(image, arch = "amd64")
 {
@@ -84,6 +84,9 @@ class ModelConfig(Generic[TConfig]):
     # If true, enable min-latency mode. Currently only used for Llama4.
     enable_min_latency: bool = False

+    # Allow models to select op according to whether CUDA Graphs are used.
+    use_cuda_graph: bool = False
+
     extra_attrs: Dict = field(default_factory=dict, repr=False, init=False)

     _frozen: bool = field(default=False, init=False, repr=False)
@@ -38,7 +38,6 @@ from torch import nn
 from tqdm import tqdm
 from transformers import PretrainedConfig

-from tensorrt_llm._mnnvl_utils import MnnvlMemory
 from tensorrt_llm.functional import PositionEmbeddingType
 from tensorrt_llm.llmapi.utils import enable_llm_debug
 from tensorrt_llm.mapping import Mapping
@@ -413,10 +412,6 @@ class Deepseekv3MoE(nn.Module):
         config = model_config.pretrained_config
         self.top_k = top_k
         self.use_dp = model_config.mapping.enable_attention_dp
-        self.enable_alltoall = Deepseekv3MoE.should_enable_alltoall(
-            model_config, top_k)
-        if self.enable_alltoall:
-            MnnvlMemory.initialize()
         self.gate = DeepseekV3Gate(
             hidden_size,
             num_experts,
@@ -439,7 +434,6 @@ class Deepseekv3MoE(nn.Module):
             model_config=model_config,
             override_quant_config=override_quant_config,
             aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap],
-            enable_alltoall=self.enable_alltoall,
             layer_idx=layer_idx)

         self.mapping = model_config.mapping
@@ -505,25 +499,6 @@ class Deepseekv3MoE(nn.Module):

         return shared_tp_size, shared_output_scale

-    @staticmethod
-    def should_enable_alltoall(model_config: ModelConfig, top_k: int) -> bool:
-        if not model_config.mapping.enable_attention_dp:
-            return False
-
-        if model_config.mapping.tp_size == 1:
-            return False
-
-        if not MnnvlMemory.supports_mnnvl():
-            return False
-
-        if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
-            return False
-
-        if model_config.mapping.moe_ep_size <= top_k:
-            return False
-
-        return True
-
     def compute_routed_output(self, hidden_states, hidden_states_fp4,
                               all_rank_num_tokens, do_finalize):
         # max-throughput
@@ -531,7 +506,7 @@ class Deepseekv3MoE(nn.Module):
         if self.use_dp and self.mapping.tp_size > 1:
             # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
             # to reduce allreduce BW
-            if disable_fp4_allgather() and not self.enable_alltoall:
+            if disable_fp4_allgather() and not self.experts.enable_alltoall:
                 hidden_states = allgather(hidden_states,
                                           self.mapping,
                                           dim=0,
@@ -6,8 +6,6 @@ from torch import nn
 from tqdm import tqdm
 from transformers import Qwen3MoeConfig

-from tensorrt_llm._mnnvl_utils import MnnvlMemory
-
 from ..attention_backend import AttentionMetadata
 from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
                            allgather)
@@ -91,10 +89,6 @@ class Qwen3MoE(nn.Module):
         self.mapping = model_config.mapping
         self.allreduce = AllReduce(mapping=model_config.mapping,
                                    strategy=model_config.allreduce_strategy)
-        self.enable_alltoall = Qwen3MoE.should_enable_alltoall(
-            model_config, self.top_k)
-        if self.enable_alltoall:
-            MnnvlMemory.initialize()

         self.gate = Qwen3Gate(
             hidden_size=self.hidden_dim,
@@ -117,25 +111,6 @@ class Qwen3MoE(nn.Module):
             model_config=model_config,
         )

-    @staticmethod
-    def should_enable_alltoall(model_config: ModelConfig, top_k: int) -> bool:
-        if not model_config.mapping.enable_attention_dp:
-            return False
-
-        if model_config.mapping.tp_size == 1:
-            return False
-
-        if not MnnvlMemory.supports_mnnvl():
-            return False
-
-        if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
-            return False
-
-        if model_config.mapping.moe_ep_size <= top_k:
-            return False
-
-        return True
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -151,7 +126,7 @@ class Qwen3MoE(nn.Module):
         if self.enable_attention_dp and self.mapping.tp_size > 1:
             # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
             # to reduce allreduce BW
-            if disable_fp4_allgather() and not self.enable_alltoall:
+            if disable_fp4_allgather() and not self.experts.enable_alltoall:
                 hidden_states = allgather(hidden_states,
                                           self.mapping,
                                           dim=0,
@@ -52,7 +52,6 @@ def create_moe(
     aux_stream: Optional[torch.cuda.Stream] = None,
     weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA,
     apply_router_weight_on_input: bool = False,
-    enable_alltoall: bool = False,
     layer_idx: Optional[int] = None,
 ) -> MoE:
     moe_cls = get_moe_cls(model_config, override_quant_config)
@@ -63,7 +62,6 @@ def create_moe(

     if moe_cls == TRTLLMGenFusedMoE:
         assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in TRTLLMGenFusedMoE."
-        assert not enable_alltoall, "enable_alltoall is not supported in TRTLLMGenFusedMoE."

         return moe_cls(
             routing_method=routing_method,
@@ -88,12 +86,10 @@ def create_moe(
             aux_stream=aux_stream,
             weight_loading_mode=weight_loading_mode,
             apply_router_weight_on_input=apply_router_weight_on_input,
-            enable_alltoall=enable_alltoall,
             layer_idx=layer_idx,
         )
     elif moe_cls == VanillaMoE:
         assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in VanillaMoE."
-        assert not enable_alltoall, "enable_alltoall is not supported in VanillaMoE."

         return moe_cls(
             routing_method=routing_method,
tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py (new file, 209 lines)
@@ -0,0 +1,209 @@
+# Adapted from
+# https://github.com/deepseek-ai/DeepEP/blob/aae9fa9a6dd0fec2a723fbb85ec4b22460fab670/README.md
+import weakref
+from typing import List, Tuple, Union
+
+import torch
+
+from tensorrt_llm._utils import local_mpi_size, mpi_comm
+from tensorrt_llm.mapping import Mapping
+
+try:
+    from deep_ep import Buffer
+    deep_ep_installed = True
+except ModuleNotFoundError:
+    deep_ep_installed = False
+
+
+class VariableLengthBuffer:
+    """ A wrapper of deep_ep.Buffer that accepts future size change
+    """
+
+    def __init__(self, mapping: Mapping):
+        self.comm = mpi_comm().Split(mapping.pp_rank, mapping.moe_ep_rank)
+        self.buffer = None
+
+    def __del__(self):
+        self.comm.Free()
+
+    def reserve(self, hidden_size: int, hidden_dtype: torch.dtype):
+        """ Ensure the buffer capacity is large enough.
+
+        Reserve is a collective operation that requires all EP ranks to be sync
+        """
+        # NOTES: you may also replace `get_*_config` with your auto-tuned results via all the tests
+        num_nvl_bytes, num_rdma_bytes = 0, 0
+        hidden_bytes = hidden_size * max(hidden_dtype.itemsize,
+                                         torch.bfloat16.itemsize)
+        world_size = self.comm.Get_size()
+        for config in (Buffer.get_dispatch_config(world_size),
+                       Buffer.get_combine_config(world_size)):
+            num_nvl_bytes = max(
+                config.get_nvl_buffer_size_hint(hidden_bytes, world_size),
+                num_nvl_bytes)
+            num_rdma_bytes = max(
+                config.get_rdma_buffer_size_hint(hidden_bytes, world_size),
+                num_rdma_bytes)
+
+        # Allocate a buffer if not existed or not enough buffer size
+        if self.buffer is None or self.buffer.num_nvl_bytes < num_nvl_bytes or self.buffer.num_rdma_bytes < num_rdma_bytes:
+            if self.buffer is not None:
+                num_nvl_bytes = max(num_nvl_bytes, self.buffer.num_nvl_bytes)
+                num_rdma_bytes = max(num_rdma_bytes, self.buffer.num_rdma_bytes)
+            del self.buffer  # Destruct before Construct
+            self.buffer = Buffer(None,
+                                 num_nvl_bytes,
+                                 num_rdma_bytes,
+                                 num_nvl_peers=local_mpi_size(),
+                                 comm=self.comm)
+
+    def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+                 topk_idx: torch.Tensor, topk_weights: torch.Tensor,
+                 num_experts: int) -> \
+            Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple]:
+        # NOTES: an optional `previous_event` means a CUDA event captured that you want to make it as a dependency
+        # of the dispatch kernel, it may be useful with communication-computation overlap. For more information, please
+        # refer to the docs of `Buffer.dispatch`
+
+        # Calculate layout before actual dispatch
+        num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = \
+            self.buffer.get_dispatch_layout(topk_idx, num_experts)
+        assert event.event is None
+
+        # Do MoE dispatch
+        # NOTES: the CPU will wait for GPU's signal to arrive, so this is not compatible with CUDA graph
+        # For more advanced usages, please refer to the docs of the `dispatch` function
+        recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
+            self.buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
+                                 num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+                                 is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert)
+        assert event.event is None
+
+        # For event management, please refer to the docs of the `EventOverlap` class
+        return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle
+
+    def combine(self, x: torch.Tensor, handle: Tuple) -> torch.Tensor:
+        # Do MoE combine
+        # For more advanced usages, please refer to the docs of the `combine` function
+        combined_x, _, event = self.buffer.combine(x, handle)
+        assert event.event is None
+
+        # For event management, please refer to the docs of the `EventOverlap` class
+        return combined_x
+
+
+class VariableLengthLowLatencyBuffer:
+    """ A wrapper of deep_ep.Buffer that accepts future size change
+    """
+
+    def __init__(self, mapping: Mapping):
+        self.comm = mpi_comm().Split(mapping.pp_rank, mapping.moe_ep_rank)
+        self.buffer = None
+        self.num_max_dispatch_tokens_per_rank = None
+
+    def __del__(self):
+        self.comm.Free()
+
+    def reserve(self, num_max_dispatch_tokens_per_rank: int, hidden_size: int,
+                num_experts: int):
+        """ Ensure the buffer capacity is large enough.
+
+        Reserve is a collective operation that requires all EP ranks to be sync
+        """
+        # NOTES: the low-latency mode will consume much more space than the normal mode
+        # So we recommend that `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
+        world_size = self.comm.Get_size()
+        num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
+            num_max_dispatch_tokens_per_rank, hidden_size, world_size,
+            num_experts)
+
+        # Allocate a buffer if not existed or not enough buffer size
+        if self.buffer is None or self.buffer.num_rdma_bytes < num_rdma_bytes:
+            # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
+            assert num_experts % world_size == 0
+            del self.buffer  # Destruct before Construct
+            self.buffer = Buffer(None,
+                                 0,
+                                 num_rdma_bytes,
+                                 low_latency_mode=True,
+                                 num_qps_per_rank=num_experts // world_size,
+                                 comm=self.comm)
+
+    def low_latency_dispatch(self, hidden_states: torch.Tensor,
+                             topk_idx: torch.Tensor,
+                             num_max_dispatch_tokens_per_rank: int,
+                             num_experts: int):
+        if self.num_max_dispatch_tokens_per_rank is None:
+            self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
+        if num_max_dispatch_tokens_per_rank != self.num_max_dispatch_tokens_per_rank:
+            raise NotImplementedError(
+                "There are issues if `low_latency_dispatch` calls use different `num_max_dispatch_tokens_per_rank` values"
+            )
+
+        # Do MoE dispatch, compatible with CUDA graph (but you may restore some buffer status once you replay)
+        recv_hidden_states, recv_expert_count, handle, event, hook = \
+            self.buffer.low_latency_dispatch(hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts, use_fp8=False)
+        assert event.event is None
+        assert hook is None
+
+        # NOTES: the actual tensor will not be received only if you call `hook()`,
+        # it is useful for double-batch overlapping, but **without any SM occupation**
+        # If you don't want to overlap, please set `return_recv_hook=False`
+        # Later, you can use our GEMM library to do the computation with this specific format
+        return recv_hidden_states, recv_expert_count, handle
+
+    def low_latency_combine(self, hidden_states: torch.Tensor,
+                            topk_idx: torch.Tensor, topk_weights: torch.Tensor,
+                            handle: Tuple):
+        # Do MoE combine, compatible with CUDA graph (but you may restore some buffer status once you replay)
+        combined_hidden_states, event, hook = \
+            self.buffer.low_latency_combine(hidden_states, topk_idx, topk_weights, handle)
+        assert event.event is None
+        assert hook is None
+
+        # NOTES: the same behavior as described in the dispatch kernel
+        return combined_hidden_states
+
+
+class BufferPool:
+    """ A pool that allocates buffers on demand.
+
+    Although the pool interface allows creating multiple buffers, the
+    current version of DeepEP supports at most one `deep_ep.Buffer` at a
+    time. Please ensure that all references to `VariableLengthBuffer` are
+    released before getting another buffer.
+    """
+
+    def __init__(self):
+        self.buffers: Map[Mapping,
+                          weakref.ReferenceType[VariableLengthBuffer]] = {}
+        self.low_latency_buffers: Map[
+            Mapping,
+            weakref.ReferenceType[VariableLengthLowLatencyBuffer]] = {}
+
+    def get_buffer(self, mapping: Mapping) -> VariableLengthBuffer:
+        """ Get_buffer is a collective operation that requires all ranks to be sync
+        """
+        if mapping in self.buffers and self.buffers[mapping]() is not None:
+            buffer = self.buffers[mapping]()
+        else:
+            buffer = VariableLengthBuffer(mapping)
+            self.buffers[mapping] = weakref.ref(buffer)
+        return buffer
+
+    def get_low_latency_buffer(
+            self, mapping: Mapping) -> VariableLengthLowLatencyBuffer:
+        """ Get_low_latency_buffer is a collective operation that requires all ranks to be sync
+        """
+        if mapping in self.low_latency_buffers and self.low_latency_buffers[
+                mapping]() is not None:
+            buffer = self.low_latency_buffers[mapping]()
+        else:
+            buffer = VariableLengthLowLatencyBuffer(mapping)
+            self.low_latency_buffers[mapping] = weakref.ref(buffer)
+        return buffer
+
+
+# The default pool
+# You may create own pools for better resource management.
+buffer_pool = BufferPool()
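For orientation, here is a hypothetical single-process usage sketch of the pool added above; it assumes a working DeepEP install, an initialized MPI environment, and one GPU per rank, and it is not part of the commit.

    # Sketch: reserve a buffer, dispatch tokens, run experts, combine results.
    import torch
    from tensorrt_llm.mapping import Mapping
    from tensorrt_llm._torch.modules.fused_moe.deep_ep_utils import buffer_pool

    mapping = Mapping(world_size=1, rank=0, tp_size=1, moe_ep_size=1, moe_tp_size=1)
    buf = buffer_pool.get_buffer(mapping)                  # collective across EP ranks
    buf.reserve(hidden_size=4096, hidden_dtype=torch.bfloat16)

    x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
    topk_idx = torch.zeros(8, 2, dtype=torch.int64, device="cuda")
    topk_weights = torch.ones(8, 2, device="cuda")

    recv_x, recv_idx, recv_w, counts, handle = buf.dispatch(
        x, topk_idx, topk_weights, num_experts=16)         # tokens routed to this rank
    out = buf.combine(recv_x, handle)                      # expert outputs routed back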
@@ -1,16 +1,19 @@
 import os
+from enum import IntEnum
 from typing import Dict, List, Optional, Tuple, Union

 import torch

-from tensorrt_llm._mnnvl_utils import MnnvlMoe, MoEAlltoallInfo
+from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
 from tensorrt_llm._utils import logger
+from tensorrt_llm.mapping import Mapping

 from ...distributed import allgather, reducescatter
 from ...expert_statistic import ExpertStatistic
 from ...model_config import ModelConfig
 from ...utils import (EventType, Fp4QuantizedTensor, disable_fp4_allgather,
                       reswizzle_sf, swizzle_sf, unswizzle_sf)
+from .deep_ep_utils import buffer_pool, deep_ep_installed
 from .interface import MoE
 from .moe_load_balancer import get_moe_load_balancer
 from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
@@ -20,6 +23,18 @@ from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
 from .routing import BaseMoeRoutingMethod


+# The type of alltoall method
+class AlltoallMethodType(IntEnum):
+    # Not available
+    NotEnabled = 0
+    # MNNVL
+    MNNVL = 1
+    # DeepEP intranode or internode: no CUDA Graphs support, IBGDA is required by internode
+    DeepEP = 2
+    # DeepEP low latency: CUDA Graphs are supported, IBGDA is required
+    DeepEPLowLatency = 3
+
+
 class CutlassFusedMoE(MoE):
     """
     Fused Mixture of Experts (MoE) Layer with performance tuning.
@@ -33,7 +48,6 @@ class CutlassFusedMoE(MoE):
         dtype (Optional[torch.dtype]): Data type for the weights.
         reduce_results (bool): Whether to reduce the results across devices.
         model_config (ModelConfig): Configuration object for the model.
-        enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter

     MoE torch custom op:
       In min-latency mode:
@@ -82,7 +96,6 @@ class CutlassFusedMoE(MoE):
         weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
         VANILLA,
         apply_router_weight_on_input: bool = False,
-        enable_alltoall: bool = False,
         layer_idx: Optional[int] = None,
     ):

@@ -176,7 +189,12 @@ class CutlassFusedMoE(MoE):
         self.has_been_profiled = False
         self.has_been_profiled_min_latency = False

-        self.enable_alltoall = enable_alltoall
+        self.alltoall_method_type = self.select_alltoall_method_type(
+            model_config.mapping, routing_method.experts_per_token, dtype,
+            model_config.use_cuda_graph)
+        logger.info_once(
+            f"CutlassFusedMoE selects alltoall_method_type {self.alltoall_method_type!r}",
+            key="alltoall_method_type")
         self.use_postquant_alltoall = False
         if self.enable_alltoall:
             assert self.use_dp and self.parallel_size > 1,\
@@ -185,8 +203,25 @@ class CutlassFusedMoE(MoE):
             self.use_postquant_alltoall = (os.environ.get(
                 "TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1")
                                            == "1") and qm.has_nvfp4()
-            self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
-                model_config.mapping) if enable_alltoall else None
+            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+                MnnvlMemory.initialize()
+                self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
+                    model_config.mapping)
+            elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
+                self.deep_ep_buffer = buffer_pool.get_buffer(
+                    model_config.mapping)
+                self.deep_ep_buffer.reserve(hidden_size, dtype)
+            elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+                self.deep_ep_max_num_tokens = min(model_config.max_num_tokens,
+                                                  self.moe_max_num_tokens)
+                self.deep_ep_buffer = buffer_pool.get_low_latency_buffer(
+                    model_config.mapping)
+                self.deep_ep_buffer.reserve(self.deep_ep_max_num_tokens,
+                                            hidden_size, self.num_slots)
+            else:
+                raise NotImplementedError(
+                    f"Not available alltoall method type: {alltoall_method_type!r}"
+                )

         # If True, the router weight will be multiplied on the input rather than at the end of FC2
         self.apply_router_weight_on_input = apply_router_weight_on_input
@@ -215,12 +250,48 @@ class CutlassFusedMoE(MoE):
                 f"unsupported quantization mode: {self.quant_config.quant_mode}"
             )

+    @staticmethod
+    def select_alltoall_method_type(mapping: Mapping, top_k: int,
+                                    dtype: torch.dtype,
+                                    use_cuda_graph: bool) -> AlltoallMethodType:
+        if not mapping.enable_attention_dp:
+            return AlltoallMethodType.NotEnabled
+
+        if mapping.tp_size == 1:
+            return AlltoallMethodType.NotEnabled
+
+        if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
+            return AlltoallMethodType.NotEnabled
+
+        if mapping.moe_ep_size <= top_k:
+            return AlltoallMethodType.NotEnabled
+
+        if MnnvlMemory.supports_mnnvl():
+            return AlltoallMethodType.MNNVL
+
+        if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "0") == "1":
+            if deep_ep_installed and dtype == torch.bfloat16:
+                if use_cuda_graph:
+                    # Here we can only choose DeepEPLowLatency since only this method supports CUDA Graphs.
+                    return AlltoallMethodType.DeepEPLowLatency
+                else:
+                    # Here we can choose DeepEP or DeepEPLowLatency if both are available. Now DeepEP is faster.
+                    return AlltoallMethodType.DeepEP
+
+        return AlltoallMethodType.NotEnabled
+
     @property
     def has_w4afp8(self):
         assert self._weights_created
         return self.quant_config and self.quant_config.quant_mode.is_int4_weight_only_per_group(
         )

+    @property
+    def enable_alltoall(self):
+        """ enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
+        """
+        return self.alltoall_method_type != AlltoallMethodType.NotEnabled
+
     def _get_quant_method(self):
         if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
                 exclude_kv_cache=True):
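The selection above is easiest to read as a priority list: attention DP, a real TP group, and more EP ranks than top-k are hard requirements; MNNVL wins when available; DeepEP is only considered when opted in via TRTLLM_CAN_USE_DEEP_EP=1, installed, and running in bfloat16, with the low-latency variant chosen whenever CUDA Graphs are in play. A standalone restatement of that order (a sketch with hypothetical names, not the production method):

    # Sketch of the same priority order, decoupled from Mapping/MnnvlMemory.
    import os
    from enum import IntEnum

    class Method(IntEnum):
        NotEnabled = 0
        MNNVL = 1
        DeepEP = 2
        DeepEPLowLatency = 3

    def select(attention_dp, tp_size, ep_size, top_k, mnnvl_ok, deep_ep_ok,
               is_bf16, use_cuda_graph):
        if not attention_dp or tp_size == 1 or ep_size <= top_k:
            return Method.NotEnabled
        if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
            return Method.NotEnabled
        if mnnvl_ok:
            return Method.MNNVL
        if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "0") == "1" and deep_ep_ok and is_bf16:
            return Method.DeepEPLowLatency if use_cuda_graph else Method.DeepEP
        return Method.NotEnabled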
@@ -311,10 +382,6 @@ class CutlassFusedMoE(MoE):
         # TODO: remove this once we have correct fusedmoe kernel ready
         token_final_scales = None

-        token_count = x.shape[0]
-
-        alltoall_info = None
-
         if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
         ) and is_first_call:
             self.layer_load_balancer.maybe_cudagraph_done_wait()
@@ -334,13 +401,58 @@ class CutlassFusedMoE(MoE):
         ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)

         token_selected_experts_for_statistic = token_selected_experts if need_statistic else None

         if self.enable_alltoall:
-            x, token_selected_slots, token_final_scales, token_selected_experts_for_statistic, alltoall_info = \
-                self.alltoall_prepare_maybe_dispatch(all_rank_num_tokens,
-                                                     x,
-                                                     token_selected_slots,
-                                                     token_final_scales,
-                                                     token_selected_experts_for_statistic)
+            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+                token_count = x.shape[0]
+                alltoall_info = None
+                x, token_selected_slots, token_final_scales, token_selected_experts_for_statistic, alltoall_info = \
+                    self.alltoall_prepare_maybe_dispatch(all_rank_num_tokens,
+                                                         x,
+                                                         token_selected_slots,
+                                                         token_final_scales,
+                                                         token_selected_experts_for_statistic)
+            elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
+                if not self.use_postquant_alltoall:
+                    x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
+                        self.deep_ep_buffer.dispatch(x, token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
+            elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+                if not self.use_postquant_alltoall:
+                    deep_ep_topk_idx = token_selected_slots.to(torch.int64)
+                    deep_ep_topk_weights = token_final_scales
+                    x, recv_expert_count, deep_ep_handle = \
+                        self.deep_ep_buffer.low_latency_dispatch(x, deep_ep_topk_idx, self.deep_ep_max_num_tokens, self.num_slots)
+                    # x shape: [#local experts, #max recv tokens, hidden_size]
+                    # recv_expert_count shape: [#local experts]
+
+                    # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
+                    # TODO: remove the adapter by changing `torch.ops.trtllm.fused_moe` API
+                    mask = torch.arange(
+                        x.shape[1], dtype=torch.int32, device=x.device).expand(
+                            x.shape[0],
+                            x.shape[1]) < recv_expert_count.unsqueeze(1)
+                    token_selected_slots = torch.full(
+                        (x.shape[0], x.shape[1], self.routing_method.top_k),
+                        self.num_slots,
+                        dtype=torch.int32,
+                        device=x.device)
+                    token_selected_slots[:, :, 0] = torch.where(
+                        mask,
+                        torch.arange(
+                            x.shape[0] * self.mapping.moe_ep_rank,
+                            x.shape[0] * (self.mapping.moe_ep_rank + 1),
+                            dtype=torch.int32,
+                            device=x.device).unsqueeze(1), self.num_slots)
+                    x = x.view(x.shape[0] * x.shape[1], x.shape[2])
+                    token_selected_slots = token_selected_slots.view(
+                        x.shape[0], self.routing_method.top_k)
+                    token_final_scales = torch.ones_like(
+                        token_selected_slots, dtype=token_final_scales.dtype)
+            else:
+                raise NotImplementedError(
+                    f"Not available alltoall method type: {alltoall_method_type!r}"
+                )

         x_sf = None
         if self.has_any_quant:
             if self.has_fp8_qdq:
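To see what the DeepEP low-latency adapter does with its padded [#local experts, #max recv tokens, hidden] layout, here is a small CPU-only illustration with made-up sizes; it mirrors the mask/arange logic above but is not the production code.

    # Toy adapter example: 2 local experts on EP rank 1, up to 3 received tokens each.
    import torch

    num_slots, ep_rank, top_k = 8, 1, 1
    x = torch.zeros(2, 3, 4)                   # [local experts, max recv tokens, hidden]
    recv_expert_count = torch.tensor([2, 1])   # valid tokens per local expert

    mask = torch.arange(x.shape[1], dtype=torch.int32).expand(
        x.shape[0], x.shape[1]) < recv_expert_count.unsqueeze(1)
    slots = torch.full((x.shape[0], x.shape[1], top_k), num_slots, dtype=torch.int32)
    slots[:, :, 0] = torch.where(
        mask,
        torch.arange(x.shape[0] * ep_rank, x.shape[0] * (ep_rank + 1),
                     dtype=torch.int32).unsqueeze(1),
        num_slots)
    print(slots[:, :, 0])
    # tensor([[2, 2, 8],
    #         [3, 8, 8]], dtype=torch.int32)
    # Valid positions point at this rank's global expert slots (2 and 3); padded
    # positions point at the dummy slot `num_slots`, so fused_moe ignores them.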
@@ -414,8 +526,56 @@ class CutlassFusedMoE(MoE):
             quant_scales = self.quant_scales

         if self.use_postquant_alltoall:
-            x, x_sf = self.alltoall_postquant_dispatch(x, x_sf, x_row, x_col,
-                                                       alltoall_info)
+            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+                x, x_sf = self.alltoall_postquant_dispatch(
+                    x, x_sf, x_row, x_col, alltoall_info)
+            elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
+                if x_sf is not None:
+                    if self.has_nvfp4:
+                        x_sf = unswizzle_sf(x_sf, x_row, x_col,
+                                            self.scaling_vector_size)
+                    # Adapter between `x_sf` and DeepEP
+                    # TODO: remove the adapter by adding dtype support to DeepEP
+                    x_sf_dtype = x_sf.dtype
+                    x_sf = x_sf.view(torch.float32)
+                (x, x_sf), recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
+                    self.deep_ep_buffer.dispatch((x, x_sf), token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
+                if x_sf is not None:
+                    x_sf = x_sf.view(x_sf_dtype)
+                    if self.has_nvfp4:
+                        x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2,
+                                          self.scaling_vector_size)
+            elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+                raise NotImplementedError(
+                    "Not implemented postquant for DeepEPLowLatency, please set TRTLLM_MOE_POST_QUANT_ALLTOALLV=0"
+                )
+            else:
+                raise NotImplementedError(
+                    f"Not available alltoall method type: {alltoall_method_type!r}"
+                )
+
+        if self.enable_alltoall:
+            # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
+            # TODO: remove the adapter by changing APIs
+            if self.alltoall_method_type == AlltoallMethodType.DeepEP:
+                token_selected_slots = recv_topk_idx.to(torch.int32)
+                mask = token_selected_slots == -1
+                token_selected_slots += self.expert_size_per_partition * self.mapping.moe_ep_rank
+                token_selected_slots[mask] = self.num_slots
+                num_recv_token_is_zero = x.shape[0] == 0
+                if x.shape[0] == 0:
+                    x = torch.zeros((1, x.shape[1]),
+                                    dtype=x.dtype,
+                                    device=x.device)
+                    token_selected_slots = torch.full(
+                        (1, token_selected_slots.shape[1]),
+                        self.num_slots,
+                        dtype=token_selected_slots.dtype,
+                        device=token_selected_slots.device)
+                    token_final_scales = torch.ones(
+                        (1, token_final_scales.shape[1]),
+                        dtype=token_final_scales.dtype,
+                        device=token_final_scales.device)

 final_hidden_states = torch.ops.trtllm.fused_moe(
             x,
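The `x_sf.view(torch.float32)` round trip above exists because DeepEP's dispatch is assumed to move only a few payload dtypes; reinterpreting the scale-factor bytes as float32 lets them ride along and be restored bit-exactly afterwards. A toy illustration with made-up shapes and a uint8 stand-in for the real scale-factor dtype:

    # Toy dtype round trip mirroring the x_sf workaround (not production code).
    import torch

    x_sf = torch.randint(0, 256, (16, 32), dtype=torch.uint8)  # stand-in scale factors
    as_f32 = x_sf.view(torch.float32)    # reinterpret 4 bytes per element, no copy
    restored = as_f32.view(torch.uint8)  # ...and back, bit-exact
    assert restored.shape == x_sf.shape and torch.equal(restored, x_sf)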
@@ -452,9 +612,25 @@ class CutlassFusedMoE(MoE):
             final_hidden_states = final_hidden_states[0]

         if self.enable_alltoall:
-            final_hidden_states = self.alltoall_combine(final_hidden_states,
-                                                        alltoall_info,
-                                                        token_count)
+            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+                final_hidden_states = self.alltoall_combine(
+                    final_hidden_states, alltoall_info, token_count)
+            elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
+                if num_recv_token_is_zero:
+                    final_hidden_states = final_hidden_states[:0]
+                final_hidden_states = self.deep_ep_buffer.combine(
+                    final_hidden_states, deep_ep_handle)
+            elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+                final_hidden_states = self.deep_ep_buffer.low_latency_combine(
+                    final_hidden_states.view(
+                        self.expert_size_per_partition,
+                        self.deep_ep_max_num_tokens * self.mapping.moe_ep_size,
+                        final_hidden_states.shape[1]), deep_ep_topk_idx,
+                    deep_ep_topk_weights, deep_ep_handle)
+            else:
+                raise NotImplementedError(
+                    f"Not available alltoall method type: {alltoall_method_type!r}"
+                )

         if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
         ) and is_last_call:
@@ -23,7 +23,6 @@ class TRTLLMGenFusedMoE(MoE):
         dtype (Optional[torch.dtype]): Data type for the weights.
         reduce_results (bool): Whether to reduce the results across devices.
         model_config (ModelConfig): Configuration object for the model.
-        enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter

     MoE torch custom op:
       Only support min-latency mode now (SM100 Blackwell only).
@@ -89,8 +89,6 @@ class VanillaMoE(nn.ModuleList):
                                    if model_config.moe_max_num_tokens
                                    is not None else max_num_tokens)

-        self.enable_alltoall = False
-
         self._weights_created = False
         if not model_config.skip_create_weights_in_init:
             self.create_weights()
@@ -458,7 +456,7 @@ class VanillaMoE(nn.ModuleList):
         use_dp_padding: Optional[bool] = None,
     ):
         outputs = inputs
-        if self.parallel_size > 1 and not self.enable_alltoall:
+        if self.parallel_size > 1:
             if self.use_dp:
                 outputs = reducescatter(
                     inputs,
@@ -27,7 +27,6 @@ class MoE(nn.Module):
         dtype (Optional[torch.dtype]): Data type for the weights.
         reduce_results (bool): Whether to reduce the results across devices.
         model_config (ModelConfig): Configuration object for the model.
-        enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
     """

     def __init__(
@@ -123,3 +122,9 @@ class MoE(nn.Module):
         assert self._weights_created
         return self.quant_config is not None and self.quant_config.layer_quant_mode.has_nvfp4(
         )
+
+    @property
+    def enable_alltoall(self):
+        """ enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
+        """
+        return False
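The constructor flag has effectively become a derived property: the base MoE reports False, while CutlassFusedMoE overrides it from the selected method type, so model code such as DeepSeek-V3 and Qwen3 can simply read experts.enable_alltoall. A simplified sketch of the pattern with stand-in classes (not the real ones):

    # Sketch: enable_alltoall as a derived property instead of a ctor argument.
    from enum import IntEnum

    class AlltoallMethodType(IntEnum):
        NotEnabled = 0
        MNNVL = 1

    class BaseMoE:
        @property
        def enable_alltoall(self) -> bool:
            return False                      # default allgather/reducescatter path

    class AlltoallCapableMoE(BaseMoE):
        def __init__(self, method: AlltoallMethodType):
            self.alltoall_method_type = method

        @property
        def enable_alltoall(self) -> bool:
            return self.alltoall_method_type != AlltoallMethodType.NotEnabled

    assert not BaseMoE().enable_alltoall
    assert AlltoallCapableMoE(AlltoallMethodType.MNNVL).enable_alltoall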
@@ -1003,6 +1003,7 @@ class PyTorchModelEngine(ModelEngine):
             checkpoint_dir,
             trust_remote_code=True,
             enable_min_latency=self.pytorch_backend_config.enable_min_latency,
+            use_cuda_graph=self.pytorch_backend_config.use_cuda_graph,
             spec_config=self.spec_config,
             max_num_tokens=max_num_tokens,
             moe_max_num_tokens=moe_max_num_tokens,
@@ -75,6 +75,9 @@ class Logger(metaclass=Singleton):
             self._polygraphy_logger.module_severity = severity_map[
                 min_severity][2]

+        # For log_once
+        self._appeared_keys = set()
+
         if invalid_severity:
             self.warning(
                 f"Requested log level {environ_severity} is invalid. Using '{self.DEFAULT_LEVEL}' instead"
@@ -109,23 +112,44 @@ class Logger(metaclass=Singleton):
         parts.extend(map(str, msg))
         self._func_wrapper(severity)(" ".join(parts))

+    def log_once(self, severity, *msg, key):
+        if key not in self._appeared_keys:
+            self._appeared_keys.add(key)
+            self.log(severity, *msg)
+
     def critical(self, *msg):
         self.log(self.INTERNAL_ERROR, *msg)

+    def critical_once(self, *msg, key):
+        self.log_once(self.INTERNAL_ERROR, *msg, key=key)
+
     fatal = critical
+    fatal_once = critical_once

     def error(self, *msg):
         self.log(self.ERROR, *msg)

+    def error_once(self, *msg, key):
+        self.log_once(self.ERROR, *msg, key=key)
+
     def warning(self, *msg):
         self.log(self.WARNING, *msg)

+    def warning_once(self, *msg, key):
+        self.log_once(self.WARNING, *msg, key=key)
+
     def info(self, *msg):
         self.log(self.INFO, *msg)

+    def info_once(self, *msg, key):
+        self.log_once(self.INFO, *msg, key=key)
+
     def debug(self, *msg):
         self.log(self.VERBOSE, *msg)

+    def debug_once(self, *msg, key):
+        self.log_once(self.VERBOSE, *msg, key=key)
+
     @property
     def level(self) -> str:
         return self._min_severity
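The new *_once helpers are what let CutlassFusedMoE announce its alltoall method exactly once even though the module is constructed per layer. A minimal usage sketch, assuming the usual module-level logger singleton:

    # Keyed one-shot logging: only the first call with a given key is emitted.
    from tensorrt_llm.logger import logger

    for layer_idx in range(32):
        logger.info_once("CutlassFusedMoE selects alltoall_method_type DeepEP",
                         key="alltoall_method_type")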
@@ -54,6 +54,8 @@ l0_dgx_h100:
       auto_trigger: deepseek
       tests:
       - unittest/_torch/multi_gpu_modeling -k "deepseek"
+      - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEP]
+      - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEPLowLatency]
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
@@ -2,6 +2,7 @@ import pickle
 import sys
 from itertools import product
 from typing import Dict, List, Optional
+from unittest import mock

 import cloudpickle
 import pytest
@@ -19,6 +20,8 @@ from tensorrt_llm._torch.modules.fused_moe import (BaseMoeRoutingMethod,
                                                    DefaultMoeRoutingMethod,
                                                    RenormalizeMoeRoutingMethod,
                                                    VanillaMoE)
+from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import \
+    AlltoallMethodType
 from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
 from tensorrt_llm._utils import mpi_rank
 from tensorrt_llm.mapping import Mapping
@@ -123,6 +126,111 @@ def test_fused_moe_multi_gpu(moe_cls, ep_size):
         assert r is None


+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="needs 4 GPUs to run this test")
+@pytest.mark.parametrize("alltoall_method_type", [
+    AlltoallMethodType.MNNVL, AlltoallMethodType.DeepEP,
+    AlltoallMethodType.DeepEPLowLatency
+],
+                         ids=lambda s: s.name)
+def test_fused_moe_alltoall(alltoall_method_type):
+    world_size = 4
+    dtype = torch.bfloat16
+    HIDDEN_SIZE = 2560
+    INTERMEDIATE_SIZE = 1536
+    NUM_EXPERTS = 72
+    TOP_K = 6
+    MAX_NUM_TOKENS = 2048
+
+    def per_rank_test_fused_moe_alltoall(job_id):
+        routing_method = DefaultMoeRoutingMethod(top_k=TOP_K)
+        mapping = Mapping(world_size=world_size,
+                          rank=mpi_rank(),
+                          tp_size=world_size,
+                          moe_ep_size=world_size,
+                          moe_tp_size=1,
+                          enable_attention_dp=True)
+        torch.cuda.set_device(mapping.rank)
+        torch.manual_seed(mapping.rank)
+
+        weights = {}
+        for expert_id in range(NUM_EXPERTS):
+            w1_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
+                                    dtype=dtype)
+            w2_weight = torch.empty((HIDDEN_SIZE, INTERMEDIATE_SIZE),
+                                    dtype=dtype)
+            w3_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
+                                    dtype=dtype)
+            torch.nn.init.xavier_uniform_(w1_weight)
+            torch.nn.init.xavier_uniform_(w2_weight)
+            torch.nn.init.xavier_uniform_(w3_weight)
+            weights[f"{expert_id}.w1.weight"] = w1_weight
+            weights[f"{expert_id}.w2.weight"] = w2_weight
+            weights[f"{expert_id}.w3.weight"] = w3_weight
+        with mock.patch.object(CutlassFusedMoE,
+                               "select_alltoall_method_type",
+                               return_value=alltoall_method_type):
+            alltoall_model = CutlassFusedMoE(
+                num_experts=NUM_EXPERTS,
+                routing_method=routing_method,
+                hidden_size=HIDDEN_SIZE,
+                intermediate_size=INTERMEDIATE_SIZE,
+                dtype=dtype,
+                reduce_results=True,
+                model_config=ModelConfig(mapping=mapping,
+                                         max_num_tokens=MAX_NUM_TOKENS),
+            )
+        alltoall_model.to("cuda")
+        alltoall_model.load_weights([weights])
+        with mock.patch.object(CutlassFusedMoE,
+                               "select_alltoall_method_type",
+                               return_value=AlltoallMethodType.NotEnabled):
+            ref_model = CutlassFusedMoE(
+                num_experts=NUM_EXPERTS,
+                routing_method=routing_method,
+                hidden_size=HIDDEN_SIZE,
+                intermediate_size=INTERMEDIATE_SIZE,
+                dtype=dtype,
+                reduce_results=True,
+                model_config=ModelConfig(mapping=mapping,
+                                         max_num_tokens=MAX_NUM_TOKENS),
+            )
+        ref_model.to("cuda")
+        ref_model.load_weights([weights])
+
+        # Evaluate the outputs on a variant sequence length to verify the robustness of alltoall methods
+        m = MAX_NUM_TOKENS
+        while m >= 1:
+            x = torch.randn((m, HIDDEN_SIZE), dtype=dtype).cuda()
+            router_logits = torch.randn((m, NUM_EXPERTS), dtype=dtype).cuda()
+            all_rank_num_tokens = [m] * mapping.world_size
+
+            with torch.inference_mode():
+                output = alltoall_model.forward(
+                    x,
+                    router_logits,
+                    all_rank_num_tokens=all_rank_num_tokens,
+                    use_dp_padding=False)
+                ref_output = ref_model.forward(
+                    x,
+                    router_logits,
+                    all_rank_num_tokens=all_rank_num_tokens,
+                    use_dp_padding=False)
+
+            # Evaluate outputs
+            torch.testing.assert_close(output,
+                                       ref_output,
+                                       rtol=0.05,
+                                       atol=0.003)
+            m //= 2
+
+    with MPIPoolExecutor(max_workers=world_size) as executor:
+        results = executor.map(per_rank_test_fused_moe_alltoall,
+                               range(world_size))
+        for r in results:
+            assert r is None
+
+
 @skip_pre_hopper
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 def test_fused_moe_fp8(dtype):