[https://nvbugs/5467062][fix] pass logitsPostProcessorBatched by reference (#7173)

Signed-off-by: qqiao <qqiao@nvidia.com>
Signed-off-by: Yanchao Lu <yanchaol@nvidia.com>
Co-authored-by: Emma Qiao <qqiao@nvidia.com>
This commit is contained in:
brb-nv 2025-08-25 14:59:05 -07:00 committed by GitHub
parent 37da222ad2
commit 3e27029c1c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 86 additions and 29 deletions

View File

@ -9,7 +9,7 @@ TensorRT-LLM
[![python](https://img.shields.io/badge/python-3.10-green)](https://www.python.org/downloads/release/python-31012/)
[![cuda](https://img.shields.io/badge/cuda-12.9.1-green)](https://developer.nvidia.com/cuda-downloads)
[![trt](https://img.shields.io/badge/TRT-10.11.0-green)](https://developer.nvidia.com/tensorrt)
[![version](https://img.shields.io/badge/release-0.21.1-green)](./tensorrt_llm/version.py)
[![version](https://img.shields.io/badge/release-0.21.2-green)](./tensorrt_llm/version.py)
[![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)
[Architecture](./docs/source/torch/arch_overview.md)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Performance](./docs/source/performance/perf-overview.md)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Documentation](./docs/source/)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap)

View File

@ -46,7 +46,7 @@ public:
bool operator()(RequestVector const& contextRequests, RequestVector const& generationRequests,
bool replicateLogitsPostProcessor, std::vector<batch_manager::LlmRequest::TensorPtr>& seqSlotLogits,
runtime::WorldConfig const& worldConfig, runtime::TllmRuntime& runtime,
std::optional<LogitsPostProcessorBatched> logitsPostProcessorBatched = std::nullopt) const;
std::optional<LogitsPostProcessorBatched> const& logitsPostProcessorBatched = std::nullopt) const;
};
} // namespace tensorrt_llm::batch_manager

View File

@ -1,7 +1,12 @@
import subprocess
import pytest
from cuda import cuda, nvrtc
try:
from cuda.bindings import driver as cuda
from cuda.bindings import nvrtc
except ImportError:
from cuda import cuda, nvrtc
def ASSERT_DRV(err):

View File

@ -37,7 +37,7 @@ using SizeType32 = tensorrt_llm::runtime::SizeType32;
bool LogitsPostProcessor::operator()(RequestVector const& contextRequests, RequestVector const& generationRequests,
bool replicateLogitsPostProcessor, std::vector<TensorPtr>& seqSlotLogits, tr::WorldConfig const& worldConfig,
tr::TllmRuntime& runtime, std::optional<LogitsPostProcessorBatched> logitsPostProcessorBatched) const
tr::TllmRuntime& runtime, std::optional<LogitsPostProcessorBatched> const& logitsPostProcessorBatched) const
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(LogitsPostProcessor);

View File

@ -1,3 +1,3 @@
tensorrt_llm==0.21.1
tensorrt_llm==0.21.2
evaluate~=0.4.1
rouge_score~=0.1.2

View File

@ -1791,8 +1791,8 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
fullSet = parallelJobs.keySet()
x86SlurmTestConfigs = [
"RTXPro6000-PyTorch-[Post-Merge]-1": ["rtx-pro-6000", "l0_rtx_pro_6000", 1, 1],
"DGX_B200-4_GPUs-PyTorch-[Post-Merge]-1": ["b200-4-gpus", "l0_dgx_b200", 1, 1, 4],
"RTXPro6000-PyTorch-Post-Merge-1": ["rtx-pro-6000", "l0_rtx_pro_6000", 1, 1],
"DGX_B200-4_GPUs-PyTorch-Post-Merge-1": ["b200-x4", "l0_dgx_b200", 1, 1, 4],
]
fullSet += x86SlurmTestConfigs.keySet()
@ -1818,8 +1818,8 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
fullSet += SBSATestConfigs.keySet()
SBSASlurmTestConfigs = [
"GB200-4_GPUs-PyTorch-1": ["gb200-4-gpus", "l0_gb200", 1, 1, 4],
"GB200-4_GPUs-PyTorch-[Post-Merge]-1": ["gb200-4-gpus", "l0_gb200", 1, 1, 4],
"GB200-4_GPUs-PyTorch-1": ["gb200-x4", "l0_gb200", 1, 1, 4],
"GB200-4_GPUs-PyTorch-Post-Merge-1": ["gb200-x4", "l0_gb200", 1, 1, 4],
]
fullSet += SBSASlurmTestConfigs.keySet()

View File

@ -3,7 +3,7 @@
accelerate>=0.25.0
build
colored
cuda-python # Do not override the custom version of cuda-python installed in the NGC PyTorch image.
cuda-python>=12,<13
diffusers>=0.27.0
lark
mpi4py
@ -58,3 +58,4 @@ ninja
etcd3
blake3
llguidance==0.7.29
triton==3.3.1; platform_machine == "x86_64"

View File

@ -17,17 +17,20 @@ import struct
import sys
from typing import List, Tuple
from cuda import cuda, cudart
from cuda.cudart import cudaError_t
try:
from cuda.bindings import driver as cuda
from cuda.bindings import runtime as cudart
except ImportError:
from cuda import cuda, cudart
from ._utils import mpi_comm
from .logger import logger
from .mapping import Mapping
def _raise_if_error(error: cudaError_t | cuda.CUresult):
if isinstance(error, cudaError_t):
if error != cudaError_t.cudaSuccess:
def _raise_if_error(error: cudart.cudaError_t | cuda.CUresult):
if isinstance(error, cudart.cudaError_t):
if error != cudart.cudaError_t.cudaSuccess:
raise RuntimeError(f"CUDA Runtime API error: {repr(error)}")
if isinstance(error, cuda.CUresult):
if error != cuda.CUresult.CUDA_SUCCESS:

View File

@ -18,7 +18,11 @@ from dataclasses import dataclass
import pynvml
import torch
from cuda import cuda
try:
from cuda.bindings import driver as cuda
except ImportError:
from cuda import cuda
from ._dlpack_utils import pack_strided_memory
from ._utils import mpi_comm

View File

@ -5,7 +5,11 @@ from typing import Dict, Tuple, Union
import pynvml
import torch
from cuda import cudart
try:
from cuda.bindings import runtime as cudart
except ImportError:
from cuda import cudart
from tensorrt_llm._utils import DictConversion
from tensorrt_llm.logger import logger

View File

@ -29,7 +29,10 @@ import numpy as np
import torch
import tensorrt as trt
# isort: on
from cuda import cudart
try:
from cuda.bindings import runtime as cudart
except ImportError:
from cuda import cudart
from tensorrt_llm.runtime.memory_pools.memory_pools_allocator import \
MemoryPoolsAllocator

View File

@ -13,7 +13,12 @@ import math
from typing import Optional, Tuple
import torch.nn.functional as F
from cuda import cudart
try:
from cuda.bindings import runtime as cudart
except ImportError:
from cuda import cudart
from huggingface_hub import hf_hub_download
from PIL import Image, UnidentifiedImageError
from safetensors import safe_open

View File

@ -12,4 +12,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.21.1"
__version__ = "0.21.2"

View File

@ -24,7 +24,11 @@ import sys
import psutil
import pynvml
from cuda import cuda
try:
from cuda.bindings import driver as cuda
except ImportError:
from cuda import cuda
# Logger
logger = logging.getLogger(__name__)

View File

@ -18,7 +18,10 @@ from argparse import ArgumentParser
# isort: off
import torch
# isort: on
from cuda import cuda, cudart
try:
from cuda.bindings import runtime as cudart
except ImportError:
from cuda import cudart
import tensorrt_llm as tllm
from tensorrt_llm import Mapping, Tensor

View File

@ -7,7 +7,11 @@ import time
import traceback
import tensorrt as trt
from cuda import cudart
try:
from cuda.bindings import runtime as cudart
except ImportError:
from cuda import cudart
import tensorrt_llm
from tensorrt_llm import (AutoConfig, AutoModelForCausalLM, BuildConfig,

View File

@ -34,7 +34,10 @@ MPI.pickle.__init__(
def run_single_rank(dtype, strategy, message_size):
import numpy as np
import torch
from cuda import cuda
try:
from cuda.bindings import driver as cuda
except ImportError:
from cuda import cuda
import tensorrt_llm
from tensorrt_llm._torch.distributed import AllReduce, AllReduceStrategy

View File

@ -21,7 +21,10 @@ import pytest
import torch
# isort: on
from cuda import cudart
try:
from cuda.bindings import runtime as cudart
except ImportError:
from cuda import cudart
from parameterized import parameterized
from utils.util import create_session, run_session, unittest_name_func

View File

@ -21,7 +21,10 @@ import pytest
import torch
# isort: on
from cuda import cudart
try:
from cuda.bindings import runtime as cudart
except ImportError:
from cuda import cudart
from parameterized import parameterized
from utils.util import create_session, run_session, unittest_name_func

View File

@ -21,7 +21,10 @@ import pytest
import torch
# isort: on
from cuda import cudart
try:
from cuda.bindings import runtime as cudart
except ImportError:
from cuda import cudart
from parameterized import parameterized
from utils.util import create_session, run_session, unittest_name_func

View File

@ -21,7 +21,10 @@ import pytest
import torch
# isort: on
from cuda import cudart
try:
from cuda.bindings import runtime as cudart
except ImportError:
from cuda import cudart
from parameterized import parameterized
from utils.util import create_session, run_session, unittest_name_func

View File

@ -7,7 +7,13 @@ import pynvml
import pytest
import tensorrt as trt
import torch
from cuda import cuda, nvrtc
try:
from cuda.bindings import driver as cuda
from cuda.bindings import nvrtc
except ImportError:
from cuda import cuda, nvrtc
from parameterized import parameterized
import tensorrt_llm