Mirror of https://github.com/NVIDIA/TensorRT-LLM.git

Merge 1fee67097d into 6df2c8a074
Commit 1a2a8b7641
@@ -493,7 +493,12 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
#if !DISABLE_SYNC_FOR_PROFILING
uint32_t expected_value = *ptrs.flag_val;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
// .acquire and .release qualifiers for fence instruction require sm_90 or higher.
asm volatile("fence.release.sys;");
#else
asm volatile("fence.acq_rel.sys;");
#endif
#pragma unroll 1 // No unroll as one iter is typically enough
for (int target_rank = lane_id; target_rank < ep_size; target_rank += warpSize)
{
@@ -525,7 +530,6 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
flag_set = flag_value == expected_value;
} while (!flag_set);
}
// asm volatile("fence.acquire.sys;");
#endif
}
}
@@ -1018,7 +1022,6 @@ __global__ void moeA2ACombineKernel(

if (blockIdx.x == 0)
{
// asm volatile("fence.release.sys;");
#pragma unroll 1 // No unroll
for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize)
{
@@ -1050,7 +1053,12 @@ __global__ void moeA2ACombineKernel(
flag_set = flag_value == expected_value;
} while (!flag_set);
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
// .acquire and .release qualifiers for fence instruction require sm_90 or higher.
asm volatile("fence.acquire.sys;");
#else
asm volatile("fence.acq_rel.sys;");
#endif
}
__syncthreads();
#endif
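The two hunks above form a release/acquire handshake: the dispatch side publishes its payload, issues a release fence, and bumps a per-rank flag, while the combine side spins on that flag and issues an acquire fence before reading. Below is a minimal host-side analogue in Python threading, only to illustrate the ordering protocol; it does not model GPU memory fences or the sm_90 `.acquire`/`.release` qualifiers (the `Condition` lock stands in for the fences), and all names are illustrative.

```python
import threading

class FlagMailbox:
    """Producer publishes a payload, then flips a flag to an expected value;
    the consumer waits for the flag before touching the payload."""

    def __init__(self) -> None:
        self.payload = None
        self.flag_val = 0
        self._cv = threading.Condition()

    def send(self, data, expected_value: int) -> None:
        with self._cv:  # "release": the payload is written before the flag becomes visible
            self.payload = data
            self.flag_val = expected_value
            self._cv.notify_all()

    def receive(self, expected_value: int):
        with self._cv:  # "acquire": the flag is checked before the payload is read
            self._cv.wait_for(lambda: self.flag_val == expected_value)
            return self.payload

if __name__ == "__main__":
    box = FlagMailbox()
    producer = threading.Thread(target=box.send, args=({"tokens": [1, 2, 3]}, 1))
    producer.start()
    print(box.receive(expected_value=1))  # returns only after the flag matches
    producer.join()
```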
@@ -1,8 +1,8 @@
# Multi-stage Dockerfile
ARG BASE_IMAGE=nvcr.io/nvidia/pytorch
ARG TRITON_IMAGE=nvcr.io/nvidia/tritonserver
ARG BASE_TAG=25.10-py3
ARG TRITON_BASE_TAG=25.10-py3
ARG BASE_TAG=25.12-py3
ARG TRITON_BASE_TAG=25.12-py3
ARG DEVEL_IMAGE=devel

FROM ${BASE_IMAGE}:${BASE_TAG} AS base
@@ -147,6 +147,7 @@ RUN --mount=type=cache,target=/root/.cache/pip --mount=type=bind,from=wheel,sour
pip install /tmp/wheel/tensorrt_llm*.whl

COPY README.md ./
COPY --from=wheel /src/tensorrt_llm/build/tensorrt_llm*.whl ./
COPY docs docs
COPY cpp/include include
@@ -202,17 +202,16 @@ jenkins-rockylinux8_%: PYTHON_VERSION_TAG_ID = $(if $(findstring 3.12,${PYTHON_V
jenkins-rockylinux8_%: IMAGE_WITH_TAG = $(shell . ../jenkins/current_image_tags.properties && echo $$LLM_ROCKYLINUX8_${PYTHON_VERSION_TAG_ID}_DOCKER_IMAGE)
jenkins-rockylinux8_%: STAGE = tritondevel
jenkins-rockylinux8_%: BASE_IMAGE = nvcr.io/nvidia/cuda
# [TODO] Update to NVIDIA CUDA 13.0.2 when it's available
jenkins-rockylinux8_%: BASE_TAG = 13.0.1-devel-rockylinux8
jenkins-rockylinux8_%: BASE_TAG = 13.1.0-devel-rockylinux8

rockylinux8_%: STAGE = tritondevel
rockylinux8_%: BASE_IMAGE = nvcr.io/nvidia/cuda
rockylinux8_%: BASE_TAG = 13.0.1-devel-rockylinux8
rockylinux8_%: BASE_TAG = 13.1.0-devel-rockylinux8

# For x86_64 and aarch64
ubuntu22_%: STAGE = tritondevel
ubuntu22_%: BASE_IMAGE = nvcr.io/nvidia/cuda
ubuntu22_%: BASE_TAG = 13.0.1-devel-ubuntu22.04
ubuntu22_%: BASE_TAG = 13.1.0-devel-ubuntu22.04

trtllm_%: STAGE = release
trtllm_%: PUSH_TO_STAGING := 0
@@ -5,7 +5,7 @@ set -ex
# This script is used for reinstalling CUDA on Rocky Linux 8 with the run file.
# CUDA version is usually aligned with the latest NGC CUDA image tag.
# Only use when public CUDA image is not ready.
CUDA_VER="13.0.2_580.95.05"
CUDA_VER="13.1.0_590.44.01"
CUDA_VER_SHORT="${CUDA_VER%_*}"

NVCC_VERSION_OUTPUT=$(nvcc --version)
@@ -5,7 +5,7 @@ set -ex
if [ -n "${GITHUB_MIRROR}" ]; then
export PIP_INDEX_URL="https://urm.nvidia.com/artifactory/api/pypi/pypi-remote/simple"
fi
pip3 install polygraphy==0.49.9
pip3 install polygraphy==0.49.26

# Clean up pip cache and temporary files
pip3 cache purge
@@ -4,8 +4,8 @@ set -ex

# Use latest stable version from https://pypi.org/project/torch/#history
# and closest to the version specified in
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-25-10.html#rel-25-10
TORCH_VERSION="2.9.0"
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-25-12.html#rel-25-12
TORCH_VERSION="2.9.1"
SYSTEM_ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"')

prepare_environment() {
@@ -2,20 +2,20 @@

set -ex

TRT_VER="10.13.3.9"
TRT_VER="10.14.1.48"
# Align with the pre-installed cuDNN / cuBLAS / NCCL versions from
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-25-10.html#rel-25-10
CUDA_VER="13.0" # 13.0.2
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-25-12.html#rel-25-12
CUDA_VER="13.1" # 13.1.0
# Keep the installation for cuDNN if users want to install PyTorch with source codes.
# PyTorch 2.x can compile with cuDNN v9.
CUDNN_VER="9.14.0.64-1"
NCCL_VER="2.27.7-1+cuda13.0"
CUBLAS_VER="13.1.0.3-1"
CUDNN_VER="9.17.0.29-1"
NCCL_VER="2.28.9-1+cuda13.0"
CUBLAS_VER="13.2.0.9-1"
# Align with the pre-installed CUDA / NVCC / NVRTC versions from
# https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
NVRTC_VER="13.0.88-1"
CUDA_RUNTIME="13.0.96-1"
CUDA_DRIVER_VERSION="580.95.05-1.el8"
NVRTC_VER="13.1.80-1"
CUDA_RUNTIME="13.1.80-1"
CUDA_DRIVER_VERSION="590.44.01-1.el8"

for i in "$@"; do
case $i in
@@ -118,7 +118,12 @@ install_rockylinux_requirements() {
install_tensorrt() {
PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")

TRT_CUDA_VERSION=${CUDA_VER}
# No CUDA 13.1 version for TensorRT yet. Use CUDA 13.0 package instead.
if [ "$CUDA_VER" = "13.1" ]; then
TRT_CUDA_VERSION="13.0"
fi
TRT_VER_SHORT=$(echo $TRT_VER | cut -d. -f1-3)

if [ -z "$RELEASE_URL_TRT" ];then
@@ -83,19 +83,19 @@ def BUILD_CONFIGS = [
(WHEEL_EXTRA_ARGS) : "--extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --extra-cmake-vars MOONCAKE_ROOT=/usr/local/Mooncake",
(TARNAME) : "TensorRT-LLM-GH200.tar.gz",
(WHEEL_ARCHS): "90-real;100-real;103-real;120-real",
(BUILD_JOBS_FOR_CONFIG): "4", // TODO: Remove after fix the build OOM issue on SBSA
(BUILD_JOBS_FOR_CONFIG): "8", // TODO: Remove after fix the build OOM issue on SBSA
],
(CONFIG_LINUX_AARCH64_PYBIND): [
(WHEEL_EXTRA_ARGS) : "--binding_type pybind --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --extra-cmake-vars MOONCAKE_ROOT=/usr/local/Mooncake",
(TARNAME) : "pybind-TensorRT-LLM-GH200.tar.gz",
(WHEEL_ARCHS): "90-real;100-real;103-real;120-real",
(BUILD_JOBS_FOR_CONFIG): "4", // TODO: Remove after fix the build OOM issue on SBSA
(BUILD_JOBS_FOR_CONFIG): "8", // TODO: Remove after fix the build OOM issue on SBSA
],
(CONFIG_LINUX_AARCH64_LLVM) : [
(WHEEL_EXTRA_ARGS) : "--extra-cmake-vars WARNING_IS_ERROR=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_CUDA_HOST_COMPILER=clang -DCMAKE_LINKER_TYPE=LLD",
(TARNAME) : "llvm-TensorRT-LLM-GH200.tar.gz",
(WHEEL_ARCHS): "90-real;100-real;103-real;120-real",
(BUILD_JOBS_FOR_CONFIG): "4", // TODO: Remove after fix the build OOM issue on SBSA
(BUILD_JOBS_FOR_CONFIG): "8", // TODO: Remove after fix the build OOM issue on SBSA
],
]
@@ -39,7 +39,7 @@ LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = env.wheelDockerImagePy310
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = env.wheelDockerImagePy312

// DLFW torch image
DLFW_IMAGE = "urm.nvidia.com/docker/nvidia/pytorch:25.10-py3"
DLFW_IMAGE = "urm.nvidia.com/docker/nvidia/pytorch:25.12-py3"

//Ubuntu base image
UBUNTU_22_04_IMAGE = "urm.nvidia.com/docker/ubuntu:22.04"
@@ -316,6 +316,11 @@ def processShardTestList(llmSrc, testDBList, splitId, splits, perfMode=false) {
foundRunningLine = true
return false // Don't include the "Running" line itself
}
// Stop collecting when we hit the warnings/errors summary separator
if (foundRunningLine && line.contains('======================')) {
foundRunningLine = false // Stop collecting
return false
}

def hasDoubleColon = line.contains('::')
def shouldInclude = foundRunningLine && hasDoubleColon
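The hunk above filters shard output: collection starts after the "Running ..." marker line, stops at the "======" summary separator, and keeps only lines that contain '::'. A small standalone sketch of the same rule in Python (the pipeline code itself is Groovy; the exact "Running" detection and the sample lines below are assumptions for illustration):

```python
def collect_test_ids(lines):
    """Collect test node IDs printed between a "Running ..." marker line and the
    "======" summary separator; only lines containing '::' are kept."""
    collecting = False
    selected = []
    for line in lines:
        if not collecting and line.startswith("Running"):
            collecting = True  # don't include the "Running" line itself
            continue
        if collecting and "======================" in line:
            collecting = False  # stop at the warnings/errors summary separator
            continue
        if collecting and "::" in line:
            selected.append(line.strip())
    return selected

sample_output = [
    "Running 2 items in this shard:",
    "accuracy/test_llm_api_pytorch.py::TestModel::test_case",
    "unittest/executor/test_rpc_proxy.py::test_basic",
    "====================== warnings summary ======================",
    "not/collected/after_separator.py::test_ignored",
]
print(collect_test_ids(sample_output))  # prints only the two node IDs
```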
@@ -3389,7 +3394,7 @@ def launchTestJobs(pipeline, testFilter)
// Python version and OS for sanity check
x86SanityCheckConfigs = [
"PY312-DLFW": [
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE,
LLM_DOCKER_IMAGE, // Workaround ABI incompatibilities between PyTorch 2.9.1 and 2.10.0a0
"B200_PCIe",
X86_64_TRIPLE,
false,
@@ -3418,15 +3423,16 @@ def launchTestJobs(pipeline, testFilter)
]

aarch64SanityCheckConfigs = [
/* //Disable PY312-UB2404 temporarily since lack of official PyTorch for CUDA 13.1.
"PY312-UB2404": [
LLM_DOCKER_IMAGE,
"GH200",
AARCH64_TRIPLE,
false,
"",
UBUNTU_24_04_IMAGE,
true, // Extra PyTorch CUDA 13.0 install
],
DLFW_IMAGE,
false, // Extra PyTorch CUDA 13.0 install
],*/
"PY312-DLFW": [
LLM_DOCKER_IMAGE,
"GH200",
@@ -3524,7 +3530,7 @@ def launchTestJobs(pipeline, testFilter)
def platform = cpu_arch == X86_64_TRIPLE ? "x86_64" : "sbsa"
trtllm_utils.llmExecStepWithRetry(pipeline, script: "wget https://developer.download.nvidia.com/compute/cuda/repos/${ubuntu_version}/${platform}/cuda-keyring_1.1-1_all.deb")
trtllm_utils.llmExecStepWithRetry(pipeline, script: "dpkg -i cuda-keyring_1.1-1_all.deb")
trtllm_utils.llmExecStepWithRetry(pipeline, script: "apt-get update && apt-get install -y cuda-toolkit-13-0")
trtllm_utils.llmExecStepWithRetry(pipeline, script: "apt-get update && apt-get install -y cuda-toolkit-13-1")
}
// Extra PyTorch CUDA 13.0 install for all bare-metal environments (Default PyTorch is for CUDA 12.8)
if (values[6]) {
@@ -3532,9 +3538,9 @@ def launchTestJobs(pipeline, testFilter)
// Use internal mirror instead of https://download.pytorch.org/whl/cu130 for better network stability.
// PyTorch CUDA 13.0 package and torchvision package can be installed as expected.
if (k8s_arch == "amd64") {
trtllm_utils.llmExecStepWithRetry(pipeline, script: "pip3 install torch==2.9.0+cu130 torchvision==0.24.0+cu130 --extra-index-url https://urm.nvidia.com/artifactory/api/pypi/pytorch-cu128-remote/simple")
trtllm_utils.llmExecStepWithRetry(pipeline, script: "pip3 install torch==2.9.1+cu130 torchvision==0.24.1+cu130 --extra-index-url https://urm.nvidia.com/artifactory/api/pypi/pytorch-cu128-remote/simple")
} else {
trtllm_utils.llmExecStepWithRetry(pipeline, script: "pip3 install torch==2.9.0+cu130 torchvision==0.24.0 --extra-index-url https://urm.nvidia.com/artifactory/api/pypi/pytorch-cu128-remote/simple")
trtllm_utils.llmExecStepWithRetry(pipeline, script: "pip3 install torch==2.9.1+cu130 torchvision==0.24.1 --extra-index-url https://urm.nvidia.com/artifactory/api/pypi/pytorch-cu128-remote/simple")
}
}
@@ -13,7 +13,7 @@
# images are adopted from PostMerge pipelines, the abbreviated commit hash is used instead.
IMAGE_NAME=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm

LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.10-py3-x86_64-ubuntu24.04-trt10.13.3.9-skip-tritondevel-202512241744-10055
LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.10-py3-aarch64-ubuntu24.04-trt10.13.3.9-skip-tritondevel-202512241744-10055
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.0.2-devel-rocky8-x86_64-rocky8-py310-trt10.13.3.9-skip-tritondevel-202512241744-10055
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.0.2-devel-rocky8-x86_64-rocky8-py312-trt10.13.3.9-skip-tritondevel-202512241744-10055
LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.12-py3-x86_64-ubuntu24.04-trt10.14.1.48-skip-tritondevel-202601011103-9818
LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.12-py3-aarch64-ubuntu24.04-trt10.14.1.48-skip-tritondevel-202601011103-9818
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.0-devel-rocky8-x86_64-rocky8-py310-trt10.14.1.48-skip-tritondevel-202601011103-9818
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.0-devel-rocky8-x86_64-rocky8-py312-trt10.14.1.48-skip-tritondevel-202601011103-9818
@@ -19,13 +19,14 @@ pandas
h5py==3.12.1
StrEnum
sentencepiece>=0.1.99
tensorrt~=10.13.3
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-25-10.html#rel-25-10 uses 2.9.0a0.
torch>=2.9.0a0,<=2.9.0
tensorrt~=10.14.1
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-25-12.html#rel-25-12 uses 2.10.0a0.
torch>=2.9.1,<=2.10.0a0
torchvision
nvidia-modelopt[torch]~=0.37.0
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-25-10.html#rel-25-10 uses 2.27.7
nvidia-nccl-cu13==2.27.7
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-25-12.html#rel-25-12 uses 2.28.9
# torch 2.9.1+cu130 depends on nvidia-nccl-cu13==2.27.7
nvidia-nccl-cu13>=2.27.7,<=2.28.9
nvidia-cuda-nvrtc
transformers==4.57.1
prometheus_client
@@ -65,7 +66,7 @@ ninja
etcd3 @ git+https://github.com/kragniz/python-etcd3.git@e58a899579ba416449c4e225b61f039457c8072a
blake3
soundfile
triton==3.5.0
triton==3.5.1
tiktoken
blobfile
openai-harmony==0.0.4
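The widened pins above (torch>=2.9.1,<=2.10.0a0, nvidia-nccl-cu13>=2.27.7,<=2.28.9, triton==3.5.1) balance the wheel's torch dependency against what the 25.12 DLFW container ships. A small optional sketch, not part of the change, for checking a local environment against these ranges; it assumes the `packaging` helper library is installed:

```python
from importlib.metadata import PackageNotFoundError, version

from packaging.specifiers import SpecifierSet
from packaging.version import Version

# The specifier strings mirror the pins in the requirements hunk above.
PINS = {
    "torch": SpecifierSet(">=2.9.1,<=2.10.0a0"),
    "nvidia-nccl-cu13": SpecifierSet(">=2.27.7,<=2.28.9"),
    "triton": SpecifierSet("==3.5.1"),
}

for name, spec in PINS.items():
    try:
        installed = Version(version(name))
    except PackageNotFoundError:
        print(f"{name}: not installed")
        continue
    # prereleases=True so bounds like <=2.10.0a0 behave as written.
    ok = spec.contains(installed, prereleases=True)
    status = "OK" if ok else f"outside {spec}"
    print(f"{name} {installed}: {status}")
```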
@@ -257,6 +257,13 @@ def fused_moe(
[gemm_tactic_1, gemm_tactic_2], activation_type,
unpadded_hidden_size, tuner_num_tokens, out_tensor)

# When out_tensor is provided, the result is written in-place to out_tensor.
# Return empty list to avoid aliasing constraint violation in PyTorch 2.9.1+
# (custom op output cannot be the same tensor as input).
# Callers should use out_tensor directly when they provide it.
if out_tensor is not None and not min_latency_mode:
return []

return output if min_latency_mode else [output]
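The comments above describe the pattern this change (and the runner and caller changes below) adopts: when the caller hands in a preallocated buffer, the op mutates that buffer and returns nothing that aliases an input, because PyTorch's custom-op checks reject outputs that alias inputs unless the schema declares it. A minimal sketch of the same idea, assuming a recent PyTorch (2.4+) where `torch.library.custom_op` is available; the op name and schema are illustrative, not the TensorRT-LLM ops (which return an empty list or empty tensor instead of None):

```python
import torch

# The op writes into the caller-provided buffer and returns None, so nothing it
# returns can alias an input tensor; `mutates_args` declares the in-place write.
@torch.library.custom_op("demo::scaled_copy", mutates_args=("out",))
def scaled_copy(x: torch.Tensor, scale: float, out: torch.Tensor) -> None:
    out.copy_(x * scale)

x = torch.randn(4)
buf = torch.empty_like(x)
scaled_copy(x, 2.0, buf)  # caller keeps using `buf`, mirroring the out_tensor path above
assert torch.allclose(buf, 2.0 * x)
```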
@@ -1102,9 +1102,17 @@ def mxe4m3_mxe2m1_block_scale_moe_runner(
0] = routing_logits # replace dummy routing logits with actual routing logits
input_tensors[-2] = topk_weights # replace dummy topk_weights with actual
input_tensors[-1] = topk_ids # replace dummy topk_ids with actual
return kernel_runner(input_tensors,
tactic=[-1, -1] if best_tactic == -1 else best_tactic,
output=output)
result = kernel_runner(
input_tensors,
tactic=[-1, -1] if best_tactic == -1 else best_tactic,
output=output)
# When output is provided, the result is written in-place to output.
# Return empty tensor to avoid aliasing constraint violation in PyTorch 2.9.1+
# (custom op output cannot be the same tensor as input).
# Callers should use output directly when they provide it.
if output is not None:
return torch.empty(0, device=result.device, dtype=result.dtype)
return result

@dataclass(frozen=True)
@@ -433,7 +433,7 @@ class CutlassFusedMoE(MoE):
elif self.has_w4a16_mxfp4:
weight_dtype = torch.uint8

final_hidden_states = torch.ops.trtllm.fused_moe(
result = torch.ops.trtllm.fused_moe(
x,
token_selected_experts,
token_final_scales,
@@ -468,10 +468,13 @@ class CutlassFusedMoE(MoE):
unpadded_hidden_size=self.unpadded_hidden_size,
out_tensor=moe_output,
)
# Custom op requires all inputs are in the same type.
# Only in cutlass_min_latency_mode, the output is a list of tensors.
# Otherwise, the output should be unpacked as a single tensor.
final_hidden_states = final_hidden_states[0]
# When moe_output is provided, the result is written in-place and
# fused_moe returns empty list to avoid aliasing constraint violation.
# Otherwise, unpack the single tensor from the returned list.
if moe_output is not None:
final_hidden_states = moe_output
else:
final_hidden_states = result[0]

return final_hidden_states
@@ -610,7 +610,7 @@ class TRTLLMGenFusedMoE(MoE):
intermediate_size_per_partition_padded = self.w3_w1_weight.shape[
-2] // 2

final_hidden_states = torch.ops.trtllm.mxe4m3_mxe2m1_block_scale_moe_runner(
result = torch.ops.trtllm.mxe4m3_mxe2m1_block_scale_moe_runner(
router_logits,
routing_bias,
x,
@@ -640,6 +640,10 @@ class TRTLLMGenFusedMoE(MoE):
token_selected_experts,
output=moe_output,
)

# When output is provided, use it directly as the result
# (custom op returns empty tensor to avoid PyTorch aliasing constraints)
final_hidden_states = moe_output if moe_output is not None else result
else:
raise NotImplementedError(
"TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_mxfp8 and w4a8_mxfp4_fp8 dtypes."
@@ -306,6 +306,9 @@ full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp
examples/test_ray.py::test_ray_disaggregated_serving[tp2] SKIP (https://nvbugs/5612502)
unittest/executor/test_rpc_proxy.py SKIP (https://nvbugs/5605741)
unittest/executor/test_rpc_worker.py SKIP (https://nvbugs/5605741)
cpp/test_e2e.py::test_model[-redrafter-86] SKIP (https://nvbugs/5761642)
unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py::test_deep_gemm_in_subprocess[env2] SKIP (https://nvbugs/5766853)
test_e2e.py::test_openai_responses SKIP (https://nvbugs/5804146)
triton_server/test_triton.py::test_gpt_gather_logits[gpt-gather-logits] SKIP (https://nvbugs/5766960)
stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-GUARANTEED_NO_EVICT-pytorch-stress-test] SKIP (https://nvbugs/5766952)
full:sm89/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=2] SKIP (https://nvbugs/5596337)