# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import faulthandler
import math
import os
import time
import unittest
from contextlib import contextmanager
from dataclasses import dataclass
from difflib import SequenceMatcher
from pathlib import Path
from typing import Any, Generator

import psutil
import pynvml
import pytest
import tensorrt as trt
import torch

try:
    from cuda.bindings import driver as cuda
    from cuda.bindings import nvrtc
except ImportError:
    from cuda import cuda, nvrtc

from parameterized import parameterized

import tensorrt_llm
from tensorrt_llm._torch.hostfunc import hostfunc
from tensorrt_llm._utils import (mpi_disabled, torch_dtype_to_trt,
                                 trt_dtype_to_torch)
from tensorrt_llm.llmapi.utils import get_total_gpu_memory
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.runtime import Session, TensorInfo


def ASSERT_DRV(err):
    if isinstance(err, cuda.CUresult):
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError('Cuda Error: {}'.format(err))
    elif isinstance(err, nvrtc.nvrtcResult):
        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise RuntimeError('Nvrtc Error: {}'.format(err))
    else:
        raise RuntimeError('Unknown error type: {}'.format(err))


# ref: https://github.com/NVIDIA/cuda-python/blob/main/examples/extra/jit_program_test.py
def getSMVersion():
    # Init
    err, = cuda.cuInit(0)
    ASSERT_DRV(err)

    # Device
    err, cuDevice = cuda.cuDeviceGet(0)
    ASSERT_DRV(err)

    # Get target architecture
    err, sm_major = cuda.cuDeviceGetAttribute(
        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
        cuDevice)
    ASSERT_DRV(err)
    err, sm_minor = cuda.cuDeviceGetAttribute(
        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
        cuDevice)
    ASSERT_DRV(err)

    return sm_major * 10 + sm_minor


def getCUDAVersion():
    import subprocess

    try:
        cuda_version = subprocess.run(['nvcc', '--version'],
                                      stdout=subprocess.PIPE,
                                      universal_newlines=True)
        output = cuda_version.stdout.split()
        release_version = output[-4].replace(',', '.').split('.')
        # Encode major.minor as a single integer, e.g. CUDA 12.4 -> 1204.
        return int(release_version[0]) * 100 + int(release_version[1])
    except Exception as e:
        print(f"Error getting CUDA version: {e}")


def isSM100Family():
    sm = getSMVersion()
    return sm == 100 or sm == 103


skip_pre_ada = pytest.mark.skipif(
    getSMVersion() < 89,
    reason="This test is not supported in pre-Ada architecture")
skip_pre_hopper = pytest.mark.skipif(
    getSMVersion() < 90,
    reason="This test is not supported in pre-Hopper architecture")
skip_pre_blackwell = pytest.mark.skipif(
    getSMVersion() < 100,
    reason="This test is not supported in pre-Blackwell architecture")
skip_blackwell = pytest.mark.skipif(
    getSMVersion() == 100 or getSMVersion() == 103,
    reason="This test is not supported in Blackwell architecture")
skip_blackwell_geforce = pytest.mark.skipif(
    getSMVersion() == 120, reason="This test is not supported on SM 120")

# If used together with @parameterized, we have to use unittest.skipIf instead of pytest.mark.skipif
skip_pre_ada_unittest = unittest.skipIf(
    getSMVersion() < 89 or (getSMVersion() == 89 and getCUDAVersion() < 1204),
    reason=
    "This test is not supported in pre-Ada architecture, and for Ada we require cuda version >= 12.4"
)
skip_pre_hopper_unittest = unittest.skipIf(
    getSMVersion() < 90,
    reason="This test is not supported in pre-Hopper architecture")
skip_pre_blackwell_unittest = unittest.skipIf(
    getSMVersion() < 100,
    reason="This test is not supported in pre-Blackwell architecture")
skip_non_ada_unittest = unittest.skipIf(
    getSMVersion() != 89,
    reason="This test is only supported in Ada architecture")
skip_non_hopper_unittest = unittest.skipIf(
    getSMVersion() != 90,
    reason="This test is only supported in Hopper architecture")
skip_neither_ada_nor_hopper_unittest = unittest.skipIf(
    getSMVersion() != 90 and getSMVersion() != 89,
    reason="This test is only supported in Ada or Hopper architecture")

IGNORE_ARCH = os.environ.get('TLLM_TEST_IGNORE_ARCH', False)
force_ampere = pytest.mark.skipif(
    (not IGNORE_ARCH) and (getSMVersion() < 80 or getSMVersion() > 89),
    reason="This test is only enabled in Ampere architecture")
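

# Hedged usage sketch (illustrative, not part of the original utilities): the
# pytest marks above (and their unittest.skipIf counterparts) are plain
# decorators, so a test opts in by simply decorating itself. The function body
# below is an assumption for demonstration only.
@skip_pre_hopper
def _example_hopper_only_test():
    # Only runs on SM >= 90; otherwise pytest reports the test as skipped.
    assert getSMVersion() >= 90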


def is_bf16(dtype):
    return dtype == 'bfloat16' or dtype == 'bf16' or dtype == torch.bfloat16


def skip_fp8_pre_ada(use_fp8):
    if use_fp8 and getSMVersion() < 89:
        pytest.skip("FP8 is not supported on pre-Ada architectures")


def skip_blackwell_for_fmha_tests(context_fmha_type, head_size):
    if isSM100Family() and (head_size not in [32, 64, 128] and
                            context_fmha_type != ContextFMHAType.disabled):
        pytest.skip(
            "Context FMHA only supports head sizes [32, 64, 128] currently on blackwell."
        )


def skip_fp4_pre_blackwell(use_fp4):
    if use_fp4 and getSMVersion() < 100:
        pytest.skip("FP4 is not supported on pre-Blackwell architectures")


def skip_bf16_fp32_accum(dtype, context_fmha_type):
    if is_bf16(dtype) and \
            context_fmha_type == ContextFMHAType.enabled_with_fp32_acc:
        pytest.skip(
            "bfloat16 Context FMHA will always accumulate on FP32, so it has been tested with ContextFMHAType.enabled"
        )


def skip_num_gpus_less_than(num_gpus: int):
    return pytest.mark.skipif(
        torch.cuda.device_count() < num_gpus,
        reason=f"The test needs at least {num_gpus} GPUs, skipping")


skip_single_gpu = skip_num_gpus_less_than(2)


def compose_decorator(*decorators):

    def composed_decorator(f):
        for dec in reversed(decorators):
            f = dec(f)
        return f

    return composed_decorator


pytest.mark.gpu2 = compose_decorator(skip_single_gpu, pytest.mark.gpu2)
pytest.mark.gpu4 = compose_decorator(skip_num_gpus_less_than(4),
                                     pytest.mark.gpu4)
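

# Hedged usage sketch (illustrative): after the re-binding above, applying
# @pytest.mark.gpu2 both tags a test and auto-skips it on hosts with fewer
# than 2 GPUs. The function body below is an assumption for demonstration.
@pytest.mark.gpu2
def _example_two_gpu_test():
    assert torch.cuda.device_count() >= 2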


def skip_gpu_memory_less_than(required_memory: int):
    memory = get_total_gpu_memory(0)
    return pytest.mark.skipif(
        required_memory > memory,
        reason=
        f'Not enough GPU memory for this test (wanted {required_memory}, have {memory})'
    )


skip_gpu_memory_less_than_40gb = skip_gpu_memory_less_than(40 * 1000 * 1000 *
                                                           1000)

skip_gpu_memory_less_than_80gb = skip_gpu_memory_less_than(80 * 1000 * 1000 *
                                                           1000)

skip_gpu_memory_less_than_138gb = skip_gpu_memory_less_than(138 * 1000 * 1000 *
                                                            1000)


def modelopt_installed():
    try:
        # isort: off
        import modelopt.torch.quantization as atq  # NOQA
        from modelopt.torch.export import export_tensorrt_llm_checkpoint  # NOQA
        # isort: on
        return True
    except Exception:
        return False


skip_no_modelopt = unittest.skipIf(not modelopt_installed(),
                                   reason="Modelopt is not installed")


def check_nvlink():
    """Check whether NVLink is active on GPU 0."""
    try:
        pynvml.nvmlInit()
        device_count = pynvml.nvmlDeviceGetCount()
        if device_count > 0:
            handle = pynvml.nvmlDeviceGetHandleByIndex(0)
            try:
                link_state = pynvml.nvmlDeviceGetNvLinkState(handle, 0)
                print(f"Link state is {link_state}")

            except pynvml.NVMLError as error:
                print(
                    f"Device does not seem to support NVLink or there's an issue: {error}"
                )
        pynvml.nvmlShutdown()
        return True

    except pynvml.NVMLError as error:
        print(f"Error initializing NVML or other NVML error: {error}")
        return False


skip_nvlink_inactive = unittest.skipIf(not check_nvlink(),
                                       reason="nvlink is inactive.")


# This name function makes the names of unit tests generated by
# @parameterized.expand include the values of all parameters.
def unittest_name_func(testcase_func, param_num, param):
    expand_params = lambda params: '_'.join([
        expand_params(x) if isinstance(x, (list, tuple)) else str(x)
        for x in params
    ])
    name = expand_params(param.args)
    return "%s_%s" % (
        testcase_func.__name__,
        parameterized.to_safe_name(name),
    )
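

# Hedged usage sketch (illustrative): wiring unittest_name_func into
# @parameterized.expand so that the generated test names embed the parameter
# values, e.g. test_example_16_float16. The class and values are assumptions.
class _ExampleNamedParams(unittest.TestCase):

    @parameterized.expand([(16, 'float16'), (32, 'bfloat16')],
                          name_func=unittest_name_func)
    def test_example(self, head_size, dtype):
        self.assertGreater(head_size, 0)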


def set_input_shape(profile,
                    inp: tensorrt_llm.Tensor,
                    shape: tuple,
                    data: torch.Tensor = None):
    set_input_shapes(profile, inp, shape, shape, shape, data)
    return


def set_input_shapes(profile,
                     inp: tensorrt_llm.Tensor,
                     min_shape: tuple,
                     opt_shape: tuple,
                     max_shape: tuple,
                     data: torch.Tensor = None):
    if inp.trt_tensor.is_shape_tensor:
        # For shape tensors, TensorRT expects the full tensor (on CPU), not just its shape
        assert data is not None, f"For shape tensor {inp.name}, TensorRT needs the tensor value."
        assert str(data.device) == "cpu", f"Shape tensor's data needs to be on CPU " \
            f"(device found={data.device}) for both updating the profile and for execution."
        np_data = data.flatten().numpy()
        profile.set_shape_input(inp.name, np_data, np_data, np_data)
        return
    profile.set_shape(inp.name, min_shape, opt_shape, max_shape)
    return


def create_session(builder,
                   network,
                   precision="float32",
                   int8=False,
                   fp8=False,
                   memory_pool_limit=None,
                   optimization_profiles=[],
                   quant_mode=QuantMode(0)):
    """
    Creates an engine and a tensorrt_llm.runtime.Session for that engine.

    Args:
        builder: a tensorrt_llm.Builder object
        network: a tensorrt_llm.Network object
        precision: the precision of the network, one of ["float32", "float16", "bfloat16"]
        int8, fp8, quant_mode, ...: builder flags such as int8, fp8, etc.
    Returns:
        session: a tensorrt_llm.runtime.Session
    """
    builder_config = builder.create_builder_config(precision=precision,
                                                   int8=int8,
                                                   fp8=fp8,
                                                   quant_mode=quant_mode)
    # Some tests require setting a memory pool limit to avoid OOM
    if memory_pool_limit is not None:
        builder_config.trt_builder_config.set_memory_pool_limit(
            trt.MemoryPoolType.WORKSPACE, memory_pool_limit)
    # Some tests include shape tensors, so the optimization profiles need to be fed in explicitly
    if len(optimization_profiles) > 0:
        for profile in optimization_profiles:
            builder_config.trt_builder_config.add_optimization_profile(profile)
    # Disable TF32 for accuracy in testing.
    builder_config.trt_builder_config.clear_flag(trt.BuilderFlag.TF32)
    engine = builder.build_engine(network, builder_config)
    assert engine is not None, "Failed to build engine"
    session = Session.from_serialized_engine(engine)
    return session


def run_session(session: Session,
                inputs,
                outputs={},
                override_shapes={},
                override_types={}):
    """
    A Session object requires both input and output bindings to be passed in.
    For test convenience, this helper infers the output shapes automatically,
    so only the inputs are required. It is similar to
    tensorrt_llm.runtime.Session._debug_run and Polygraphy's runner.infer.
    NOTES:
        1. The outputs dictionary is required for outputs whose shapes cannot be inferred.
           Tensors provided in this dictionary take priority.
        2. `override_shapes` can be used to force some input tensors' shapes to differ from the passed tensors.
           Required for zero-volume tensors since torch.Tensor.data_ptr() is nullptr for such tensors.
        3. `override_types` can be used to force some input tensors' types to differ from the passed tensors.
           Required for zero-volume tensors since torch.Tensor.data_ptr() is nullptr for such tensors.
    """

    # Prepare output tensors.
    output_info = session.infer_shapes([
        TensorInfo(
            name,
            torch_dtype_to_trt(tensor.dtype if name not in
                               override_types else override_types[name]),
            tensor.shape
            if name not in override_shapes else override_shapes[name])
        for name, tensor in inputs.items()
    ])

    def create_torch(t):
        if t.dtype == trt.fp4:
            # FP4 outputs are packed two elements per byte, so halve the last dim.
            shape = list(t.shape)
            shape[-1] = shape[-1] // 2
            return torch.empty(tuple(shape), dtype=torch.uint8, device='cuda')
        else:
            return torch.empty(tuple(t.shape),
                               dtype=trt_dtype_to_torch(t.dtype),
                               device='cuda')

    outputs = {
        t.name: create_torch(t) if t.name not in outputs else outputs[t.name]
        for t in output_info
    }

    # Execute model inference
    stream = torch.cuda.current_stream()
    ok = session.run(inputs=inputs, outputs=outputs, stream=stream.cuda_stream)
    assert ok, 'Engine execution failed'
    stream.synchronize()

    return outputs
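

# Hedged usage sketch (illustrative, not taken from a specific test): a minimal
# end-to-end flow that builds a tiny network, creates a session with
# create_session, and executes it with run_session. The tensor name 'x', the
# output name 'out', the shapes, and the relu op are assumptions.
def _example_create_and_run_session():
    builder = tensorrt_llm.Builder()
    network = builder.create_network()
    with tensorrt_llm.net_guard(network):
        x = tensorrt_llm.Tensor(name='x', shape=(2, 4), dtype=trt.float32)
        out = tensorrt_llm.functional.relu(x)
        out.mark_output('out')

    session = create_session(builder, network, precision="float32")
    outputs = run_session(session, {'x': torch.rand(2, 4, device='cuda')})
    return outputs['out']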


def similarity_score(a, b):
    """Return the similarity ratio between a and b."""
    return SequenceMatcher(None, a, b).ratio()


def similar(a, b, threshold=0.8):
    """Return True if a and b are similar enough (ratio >= threshold)."""
    return similarity_score(a, b) >= threshold
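

# Hedged usage sketch (illustrative): `similar` is handy for fuzzy comparison
# of generated text against a reference; the strings below are assumptions.
def _example_similar():
    assert similar("The quick brown fox", "The quick brown fox jumps")
    assert not similar("apples", "integrals")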


def get_project_root(test_file: str) -> Path:
    return next(p for p in Path(test_file).resolve().parents
                if (p / 'tests').is_dir() and (p / "tensorrt_llm").is_dir())


@contextmanager
def default_dtype(dtype: torch.dtype):
    cur_default = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore the previous default dtype even if the guarded scope raises.
        torch.set_default_dtype(cur_default)
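

# Hedged usage sketch (illustrative): default_dtype temporarily switches the
# global torch default dtype and restores it afterwards.
def _example_default_dtype():
    before = torch.get_default_dtype()
    with default_dtype(torch.float64):
        assert torch.empty(2, 2).dtype == torch.float64
    assert torch.get_default_dtype() == before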


def woq_assert_near_eq(ref, act, wTypeId):
    # match the scale in cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp
    if wTypeId == 1:
        bits_in_type = 8
    else:
        bits_in_type = 4
    quant_range_scale = 1.0 / float(1 << (bits_in_type - 1))

    max_val = torch.max(abs(ref)).item()
    atol = (max_val * quant_range_scale) * 1.5  # allow for rounding
    torch.testing.assert_close(ref, act, atol=atol, rtol=1e-7)


def woq_groupwise_gt_matmul(mat1, ref_torch_weights, bias=None):
    ref = torch.matmul(mat1, ref_torch_weights)
    if bias is not None:
        ref += bias
    return ref


def flatten_list_generator(
        nested_list: list[Any]) -> Generator[Any, None, None]:
    if not isinstance(nested_list, list):
        yield nested_list
    else:
        for item in nested_list:
            yield from flatten_list_generator(item)


def flatten_list(nested_list: list[Any]) -> list[Any]:
    return list(flatten_list_generator(nested_list))


def duplicate_list_to_length(list: list[Any], target_length: int) -> list[Any]:
    if target_length < len(list):
        return list[:target_length]
    duplicated_list = list * (target_length // len(list))
    remain = target_length % len(list)
    if remain != 0:
        duplicated_list += list[:remain]
    return duplicated_list
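

# Hedged usage sketch (illustrative): expected behavior of the list helpers
# above; the values are assumptions for demonstration.
def _example_list_helpers():
    assert flatten_list([1, [2, [3, 4]]]) == [1, 2, 3, 4]
    assert duplicate_list_to_length([1, 2, 3], 5) == [1, 2, 3, 1, 2]
    assert duplicate_list_to_length([1, 2, 3], 2) == [1, 2]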


# Check that at most a bounded fraction of elements differs beyond the
# tolerance: at least `percent` of the elements must be within atol + rtol * |b|.
def check_accuracy(a, b, atol, rtol, percent):
    assert a.shape == b.shape
    assert a.dtype == b.dtype
    a = a.to(torch.float32)
    b = b.to(torch.float32)
    left = torch.abs(a - b)
    right = atol + rtol * torch.abs(b)
    count = torch.sum(left > right)
    mismatch_percent = count / a.numel()
    if not (mismatch_percent < 1 - percent):
        raise Exception("Mismatch percentage is %f for rtol %f" %
                        (mismatch_percent, rtol))
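

# Hedged usage sketch (illustrative): check_accuracy tolerates a bounded
# fraction of outliers; the tensors and tolerances below are assumptions.
def _example_check_accuracy():
    ref = torch.zeros(1000)
    approx = ref.clone()
    approx[:5] += 1.0  # 0.5% of the elements are far off
    # Passes because at least 99% of the elements are within atol/rtol.
    check_accuracy(ref, approx, atol=1e-3, rtol=1e-3, percent=0.99)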


skip_ray = pytest.mark.skipif(
    mpi_disabled(), reason="This test is skipped for the Ray orchestrator.")


@dataclass(kw_only=True)
class DeviceSleepCtl:
    _cancellation_requested: bool = False

    @property
    def cancellation_requested(self):
        return self._cancellation_requested

    def cancel(self):
        self._cancellation_requested = True


@hostfunc
def device_sleep(
    duration_s: float,
    *,
    ctl: DeviceSleepCtl,
    spin_s: float = 0.1,
):
    spin_iters = math.ceil(duration_s / spin_s)
    for _ in range(spin_iters):
        if ctl.cancellation_requested:
            break
        time.sleep(spin_s)


@contextmanager
def assert_no_cuda_sync(
        sync_timeout_s: float = 5, ) -> Generator[None, None, None]:
    """Check that the guarded scope does not synchronize with the CUDA stream."""
    # NB: This implementation only assumes that the CUDA operations performed
    # in the guarded scope use the currently selected CUDA stream. This
    # should also cover custom Torch ops as well as non-Torch kernels.
    #
    # Python's faulthandler is used to provide tracebacks which can help
    # pin-pointing synchronizing code. This is combined with PyTorch's
    # less general (instrumentation based) tooling, which is expected
    # to provide improved error reporting for those issues which it can
    # detect.

    sleep_finished_event = torch.cuda.Event()
    scope_finished_event = torch.cuda.Event()

    torch.cuda.synchronize()
    sleep_ctl = DeviceSleepCtl()
    faulthandler.dump_traceback_later(sync_timeout_s)
    device_sleep(2 * sync_timeout_s, ctl=sleep_ctl)  # cancelled below
    sleep_finished_event.record()
    torch_debug_mode_orig = torch.cuda.get_sync_debug_mode()
    torch.cuda.set_sync_debug_mode("error")
    try:
        yield None
    finally:
        torch.cuda.set_sync_debug_mode(torch_debug_mode_orig)
        scope_finished_event.record()

        # The guarded scope must finish before the device sleep does;
        # otherwise it must have synchronized with the stream.
        assert not sleep_finished_event.query(
        ), "code in the guarded scope should return quickly, i.e. without synchronizing"
        faulthandler.cancel_dump_traceback_later()

        sleep_ctl.cancel()
        scope_finished_event.synchronize()
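

# Hedged usage sketch (illustrative): work inside assert_no_cuda_sync must only
# enqueue asynchronous operations on the current stream; anything that blocks
# on the GPU (e.g. tensor.item()) is expected to trip the check.
def _example_assert_no_cuda_sync():
    x = torch.ones(1024, device='cuda')
    with assert_no_cuda_sync():
        y = x * 2  # asynchronous kernel launch: allowed
        # y.item() here would synchronize and make the check fail
    return y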


_pynvmlInited = False


def get_current_process_gpu_memory(include_subprocess: bool = False) -> int:
    """
    Returns the GPU memory usage of the current process in bytes.
    """
    global _pynvmlInited
    if not _pynvmlInited:
        pynvml.nvmlInit()
        _pynvmlInited = True

    # Get the current process ID (and optionally those of its children)
    targets = [os.getpid()]
    if include_subprocess:
        targets.extend(
            p.pid for p in psutil.Process(targets[0]).children(recursive=True))
    targets = frozenset(targets)

    # Get device handle for GPU 0
    device_handle = pynvml.nvmlDeviceGetHandleByIndex(0)

    # Get running processes
    processes = pynvml.nvmlDeviceGetComputeRunningProcesses(device_handle)

    # Sum the memory used by the target process(es)
    return sum(process.usedGpuMemory for process in processes
               if process.pid in targets)
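

# Hedged usage sketch (illustrative): query how many bytes this process (and,
# optionally, its children) currently uses on GPU 0.
def _example_memory_usage():
    torch.ones(1, device='cuda')  # ensure a CUDA context exists
    used = get_current_process_gpu_memory(include_subprocess=True)
    assert used > 0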