# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gc
import inspect
import json
import math
import struct
import weakref
from contextlib import contextmanager
from dataclasses import asdict
from enum import EnumMeta
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Union
import numpy as np
from cuda import cuda
from packaging import version
# isort: off
import torch
import tensorrt as trt
# isort: on
from tensorrt_llm.bindings import DataType, GptJsonConfig
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE

# numpy doesn't know bfloat16, so define abstract binary types instead
np_bfloat16 = np.dtype('V2', metadata={"dtype": "bfloat16"})
np_float8 = np.dtype('V1', metadata={"dtype": "float8"})


def torch_to_numpy(x: torch.Tensor):
    assert isinstance(x, torch.Tensor), \
        f'x must be a torch.Tensor object, but got {type(x)}.'
    if x.dtype == torch.bfloat16:
        return x.view(torch.int16).detach().cpu().numpy().view(np_bfloat16)
    elif x.dtype == torch.float8_e4m3fn:
        return x.view(torch.int8).detach().cpu().numpy().view(np_float8)
    else:
        return x.detach().cpu().numpy()


def numpy_to_torch(x):
    if x.dtype == np_bfloat16:
        return torch.from_numpy(x.view(np.int16)).view(torch.bfloat16)
    elif x.dtype == np_float8:
        return torch.from_numpy(x.view(np.int8)).view(torch.float8_e4m3fn)
    else:
        return torch.from_numpy(x)
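

# Illustrative round trip (not part of the module API): the bf16/fp8 paths
# reinterpret raw bits rather than converting values, so nothing is lost:
#   t = torch.randn(2, 3, dtype=torch.bfloat16)
#   a = torch_to_numpy(t)    # ndarray tagged with the abstract np_bfloat16
#   t2 = numpy_to_torch(a)   # viewed back as torch.bfloat16, same bits
#   assert torch.equal(t, t2)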


def numpy_to_dtype(x, dtype: str):
    if str_dtype_to_np(dtype) == x.dtype:
        return x
    if x.dtype not in [np_bfloat16, np_float8
                       ] and dtype not in ['bfloat16', 'fp8']:
        return x.astype(str_dtype_to_np(dtype))
    else:
        return torch_to_numpy(numpy_to_torch(x).to(str_dtype_to_torch(dtype)))


fp32_array = partial(np.array, dtype=np.float32)
fp16_array = partial(np.array, dtype=np.float16)
int32_array = partial(np.array, dtype=np.int32)
int64_array = partial(np.array, dtype=np.int64)
bool_array = partial(np.array, dtype=np.bool_)


def dims_array(x):
    # Probe whether this TRT build accepts int64 dimensions; older versions
    # raise TypeError and require int32 dims instead.
    is_int64_dims = True
    try:
        trt.Dims([np.iinfo(np.int64).max])
    except TypeError:
        is_int64_dims = False
    return int64_array(x) if is_int64_dims else int32_array(x)


def bf16_array(x):
    x = torch.tensor(x, dtype=torch.bfloat16)
    x = torch_to_numpy(x)
    return x


def numpy_array(data, trt_dtype):
    # convenience wrapper because numpy doesn't support bf16 yet
    if trt_dtype == trt.bfloat16:
        return bf16_array(data)
    return np.array(data, trt_dtype_to_np(trt_dtype))


def copy_torch_to_numpy(x: torch.Tensor, ndarray: np.ndarray):
    if x.dtype == torch.bfloat16:
        torch.from_numpy(ndarray.view(np.int16)).copy_(x.view(torch.int16))
    elif x.dtype == torch.float8_e4m3fn:
        torch.from_numpy(ndarray.view(np.int8)).copy_(x.view(torch.int8))
    else:
        torch.from_numpy(ndarray).copy_(x)
    return ndarray


# ref: https://github.com/NVIDIA/cuda-python/blob/main/examples/extra/jit_program_test.py
def get_sm_version():
    # Init
    err, = cuda.cuInit(0)
    assert err == cuda.CUresult.CUDA_SUCCESS, f"Cuda Error: {err}"

    # Device
    err, cuDevice = cuda.cuDeviceGet(0)
    assert err == cuda.CUresult.CUDA_SUCCESS, f"Cuda Error: {err}"

    # Get target architecture
    err, sm_major = cuda.cuDeviceGetAttribute(
        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
        cuDevice)
    assert err == cuda.CUresult.CUDA_SUCCESS, f"Cuda Error: {err}"
    err, sm_minor = cuda.cuDeviceGetAttribute(
        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
        cuDevice)
    assert err == cuda.CUresult.CUDA_SUCCESS, f"Cuda Error: {err}"

    return sm_major * 10 + sm_minor
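

# Example: an H100 (compute capability 9.0) yields 90, so callers can gate
# architecture-specific features with checks like `get_sm_version() >= 90`.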


def trt_version():
    return trt.__version__


def trt_gte(major: int, minor: int = 0):
    """
    Check if the TRT version is greater than or equal to major.minor.
    """
    trt_ver = version.parse(trt_version())
    # Compare (major, minor) as a pair; comparing the components independently
    # would wrongly reject e.g. TRT 10.0 when 9.2 is required.
    return (trt_ver.major, trt_ver.minor) >= (major, minor)


def torch_version():
    return torch.__version__


_str_to_np_dict = dict(
    float16=np.float16,
    float32=np.float32,
    int64=np.int64,
    int32=np.int32,
    int8=np.int8,
    bool=np.bool_,
    bfloat16=np_bfloat16,
    fp8=np_float8,
)


def str_dtype_to_np(dtype):
    ret = _str_to_np_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_str_to_torch_dtype_dict = dict(
    bfloat16=torch.bfloat16,
    float16=torch.float16,
    float32=torch.float32,
    int64=torch.int64,
    int32=torch.int32,
    int8=torch.int8,
    bool=torch.bool,
    fp8=torch.float8_e4m3fn,
)


def str_dtype_to_torch(dtype):
    ret = _str_to_torch_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_str_to_binding_dtype_dict = dict(
    bfloat16=DataType.BF16,
    float16=DataType.HALF,
    float32=DataType.FLOAT,
    int64=DataType.INT64,
    int32=DataType.INT32,
    int8=DataType.INT8,
    bool=DataType.BOOL,
    fp8=DataType.FP8,
)


def str_dtype_to_binding(dtype):
    ret = _str_to_binding_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_torch_dtype_to_str_dict = {v: k for k, v in _str_to_torch_dtype_dict.items()}


def torch_dtype_to_str(dtype):
    return _torch_dtype_to_str_dict[dtype]


_str_to_trt_dtype_dict = dict(float16=trt.float16,
                              float32=trt.float32,
                              int64=trt.int64,
                              int32=trt.int32,
                              int8=trt.int8,
                              bool=trt.bool,
                              bfloat16=trt.bfloat16,
                              fp8=trt.fp8,
                              nvfp4=trt.fp4)


def str_dtype_to_trt(dtype):
    ret = _str_to_trt_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_trt_to_str_dtype_dict = {v: k for k, v in _str_to_trt_dtype_dict.items()}


def trt_dtype_to_str(dtype: trt.DataType) -> str:
    assert isinstance(dtype, trt.DataType)
    return _trt_to_str_dtype_dict[dtype]


_np_to_trt_dtype_dict = {
    np.int8: trt.int8,
    np.int32: trt.int32,
    np.int64: trt.int64,
    np.float16: trt.float16,
    np.float32: trt.float32,
    np.bool_: trt.bool,
    # hash of np.dtype('int32') != np.int32
    np.dtype('int8'): trt.int8,
    np.dtype('int32'): trt.int32,
    np.dtype('int64'): trt.int64,
    np.dtype('float16'): trt.float16,
    np.dtype('float32'): trt.float32,
    np.dtype('bool'): trt.bool,
    np_bfloat16: trt.bfloat16,
    np_float8: trt.fp8,
}


def np_dtype_to_trt(dtype):
    ret = _np_to_trt_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_trt_to_np_dtype_dict = {
    trt.int8: np.int8,
    trt.int32: np.int32,
    trt.int64: np.int64,
    trt.float16: np.float16,
    trt.float32: np.float32,
    trt.bool: np.bool_,
    trt.bfloat16: np_bfloat16,
    trt.fp8: np_float8,
}


def trt_dtype_to_np(dtype):
    ret = _trt_to_np_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_torch_to_np_dtype_dict = {
    torch.bool: np.bool_,
    torch.uint8: np.uint8,
    torch.int8: np.int8,
    torch.int16: np.int16,
    torch.int32: np.int32,
    torch.int64: np.int64,
    torch.float16: np.float16,
    torch.bfloat16: np_bfloat16,
    torch.float8_e4m3fn: np_float8,
    torch.float32: np.float32,
    torch.float64: np.float64,
    torch.complex64: np.complex64,
    torch.complex128: np.complex128,
}


def torch_dtype_to_np(dtype):
    ret = _torch_to_np_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_np_to_torch_dtype_dict = {
    np.bool_: torch.bool,
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
    np.float16: torch.float16,
    np_bfloat16: torch.bfloat16,
    np_float8: torch.float8_e4m3fn,
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.complex64: torch.complex64,
    np.complex128: torch.complex128,
}


def np_dtype_to_torch(dtype):
    ret = _np_to_torch_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_trt_to_torch_dtype_dict = {
    trt.float16: torch.float16,
    trt.float32: torch.float32,
    trt.int64: torch.int64,
    trt.int32: torch.int32,
    trt.int8: torch.int8,
    trt.bool: torch.bool,
    trt.bfloat16: torch.bfloat16,
    trt.fp8: torch.float8_e4m3fn,
}


def trt_dtype_to_torch(dtype):
    ret = _trt_to_torch_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


def is_same_dtype(type_a: Union[str, trt.DataType],
                  type_b: Union[str, trt.DataType]) -> bool:
    if isinstance(type_a, str):
        type_a = str_dtype_to_trt(type_a)
    if isinstance(type_b, str):
        type_b = str_dtype_to_trt(type_b)
    return type_a == type_b


_torch_to_trt_dtype_dict = {
    torch.float16: trt.float16,
    torch.float32: trt.float32,
    torch.int64: trt.int64,
    torch.int32: trt.int32,
    torch.int8: trt.int8,
    torch.float8_e4m3fn: trt.fp8,
    torch.qint8: trt.int8,
    torch.bool: trt.bool,
    torch.bfloat16: trt.bfloat16
}


def torch_dtype_to_trt(dtype):
    ret = _torch_to_trt_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_torch_dtype_to_np_typestr_dict = {
    torch.float16: "<f2",
    torch.float32: "<f4",
    torch.int64: "<i8",
    torch.int32: "<i4",
    torch.int8: "|i1",
    torch.float8_e4m3fn: "<f1",
    torch.qint8: "|u1",
    torch.bool: "|b1",
    torch.bfloat16: "<f2",
}


def torch_dtype_to_np_typestr(dtype):
    ret = _torch_dtype_to_np_typestr_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret
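

# Illustrative round trips: the mapping helpers compose, e.g.
#   trt_dtype_to_torch(str_dtype_to_trt('bfloat16')) is torch.bfloat16
#   np_dtype_to_trt(trt_dtype_to_np(trt.fp8)) is trt.fp8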


def dim_to_trt_axes(dim):
    """Converts torch dim, or tuple of dims to a tensorrt axes bitmask"""
    if not isinstance(dim, tuple):
        dim = (dim, )

    # create axes bitmask for reduce layer
    axes = 0
    for d in dim:
        axes |= 1 << d

    return axes


def trt_axes_to_dim(axes: int) -> List[int]:
    """Converts tensorrt axes bitmask to dims"""
    dim = []
    for i in range(32):
        if axes & (1 << i):
            dim.append(i)
    return dim
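

# Example: dim_to_trt_axes((0, 2)) == 0b101 == 5 selects axes 0 and 2 for a
# reduce layer, and trt_axes_to_dim(5) recovers [0, 2].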


def dim_resolve_negative(dim, ndim):
    if not isinstance(dim, tuple):
        dim = (dim, )
    pos = []
    for d in dim:
        if d < 0:
            d = ndim + d
        pos.append(d)
    return tuple(pos)


# mpi4py only exports MPI_COMM_TYPE_SHARED, so we define OMPI_COMM_TYPE_HOST here
OMPI_COMM_TYPE_HOST = 9


def mpi_comm():
    from mpi4py import MPI
    return MPI.COMM_WORLD


def mpi_rank():
    return mpi_comm().Get_rank() if ENABLE_MULTI_DEVICE else 0


def mpi_world_size():
    return mpi_comm().Get_size() if ENABLE_MULTI_DEVICE else 1


def mpi_barrier():
    if ENABLE_MULTI_DEVICE:
        mpi_comm().Barrier()


def mpi_broadcast(obj, root=0):
    return mpi_comm().bcast(obj, root) if ENABLE_MULTI_DEVICE else obj


def pad_vocab_size(vocab_size, tp_size):
    return int(math.ceil(vocab_size / tp_size) * tp_size)
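

# Example: pad_vocab_size(32000, 3) == 32001, the smallest multiple of the
# tensor-parallel size that can hold the vocabulary (10667 rows per rank).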


def to_dict(obj):
    return copy.deepcopy(obj.__dict__)


def to_json_string(obj):
    if not isinstance(obj, dict):
        obj = to_dict(obj)
    return json.dumps(obj, indent=2, sort_keys=True) + "\n"


def to_json_file(obj, json_file_path):
    with open(json_file_path, "w", encoding="utf-8") as writer:
        writer.write(to_json_string(obj))


def numpy_fp32_to_bf16(src):
    # Numpy doesn't support the bfloat16 type, so convert float32 to bfloat16
    # manually and tag the result with the abstract np_bfloat16 dtype.
    original_shape = src.shape
    src = src.flatten()
    src = np.ascontiguousarray(src)
    assert src.dtype == np.float32
    dst = np.empty_like(src, dtype=np.uint16)
    for i in range(len(dst)):
        # bfloat16 keeps the upper 16 bits of the float32 representation
        # (bytes 2 and 3 in little-endian order).
        packed = struct.pack('<f', src[i])
        dst[i] = struct.unpack('<H', struct.pack('BB', packed[2],
                                                 packed[3]))[0]
    return dst.reshape(original_shape).view(np_bfloat16)
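

# Worked example: float32 1.0 is 0x3F800000; keeping the upper 16 bits gives
# 0x3F80, which is exactly 1.0 in bfloat16:
#   numpy_fp32_to_bf16(np.array([1.0], dtype=np.float32))
#   # -> one element holding bits 0x3F80, viewed as np_bfloat16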


_extra_attrs_by_object: Dict[int, Dict[str, Any]] = {}


def get_extra_attr(obj, attr_name):
    if id(obj) not in _extra_attrs_by_object:
        return None
    extra_attrs = _extra_attrs_by_object[id(obj)]
    return extra_attrs.get(attr_name)


def _clean_extra_attrs(obj_id):
    if obj_id in _extra_attrs_by_object:
        del _extra_attrs_by_object[obj_id]


def set_extra_attr(obj, attr_name, value):
    if id(obj) not in _extra_attrs_by_object:
        _extra_attrs_by_object[id(obj)] = {}
        weakref.finalize(obj, _clean_extra_attrs, id(obj))
    _extra_attrs_by_object[id(obj)][attr_name] = value


def has_extra_attr(obj, attr_name):
    if id(obj) not in _extra_attrs_by_object:
        return False
    return attr_name in _extra_attrs_by_object[id(obj)]
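

# Illustrative usage (hypothetical object): attach metadata to instances whose
# classes forbid new attributes; weakref.finalize clears the side-table entry
# when the object is collected, so a recycled id() cannot leak stale attrs:
#   engine = SomeEngine()                 # hypothetical class
#   set_extra_attr(engine, "profile", 0)
#   assert has_extra_attr(engine, "profile")
#   assert get_extra_attr(engine, "profile") == 0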


def set_obj_attrs(
    obj: torch.Tensor,
    obj_attrs: Optional[Dict[str, Any]],
):
    """Set attributes on an object.

    This method will not overwrite existing attributes.
    """
    if obj_attrs is None:
        return
    for key, value in obj_attrs.items():
        assert not hasattr(
            obj, key), (f"Overwriting existing tensor attribute: {key}")
        setattr(obj, key, value)


def get_init_params(obj, cls=None):
    """
    Get the values of all parameters declared in the object's __init__.
    If cls is provided, use cls's __init__ signature as a filter.
    """
    names = None
    if cls is not None:
        names = set(list(inspect.signature(cls.__init__).parameters)[1:])
    return {
        name: getattr(obj, name)
        for name in list(inspect.signature(obj.__class__.__init__).parameters)
        [1:] if names is None or name in names
    }
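

# Illustrative example (hypothetical classes): for
#   class Base:        def __init__(self, a): ...
#   class Child(Base): def __init__(self, a, b): ...
# an instance c = Child(1, 2) that stores attributes a and b gives
#   get_init_params(c)        -> {"a": 1, "b": 2}
#   get_init_params(c, Base)  -> {"a": 1}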


def release_gc():
    ''' Explicitly and immediately release memory held by PyTorch and the
    Python garbage collector.

    Useful when state might otherwise be kept alive after its variables are
    deleted.
    '''
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


class DictConversion:

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        obj = cls()
        fields = obj.__dataclass_fields__
        for key, value in config.items():
            assert hasattr(obj, key)
            field_cls = fields[key].type
            if (isinstance(field_cls, type)
                    and issubclass(field_cls, DictConversion)
                    and isinstance(value, dict)):
                value = field_cls.from_dict(value)
            setattr(obj, key, value)
        return obj

    def to_dict(self):
        return asdict(self)

    @classmethod
    def from_json_file(cls, file):
        with open(file) as f:
            return cls.from_dict(json.load(f))

    def set_defaults(self, **kwargs):
        for key, default in kwargs.items():
            value = getattr(self, key)
            if (value is None
                    or (isinstance(value, (list, dict)) and len(value) == 0)):
                setattr(self, key, default)
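

# Illustrative usage (hypothetical dataclass): DictConversion is meant to be
# mixed into dataclasses so configs can round-trip through dicts and JSON:
#   @dataclass
#   class LoraConfig(DictConversion):
#       lora_dir: Optional[str] = None
#       max_lora_rank: int = 64
#   cfg = LoraConfig.from_dict({"max_lora_rank": 32})
#   cfg.set_defaults(lora_dir="/tmp/lora")  # fills only unset/empty fields
#   cfg.to_dict()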


class BaseEnumMeta(EnumMeta):

    def __contains__(cls, item):
        try:
            cls(item)
        except ValueError:
            return False
        return True


def supports_inflight_batching(engine_dir):
    config_path = Path(engine_dir) / "config.json"
    json_config = GptJsonConfig.parse_file(config_path)
    model_config = json_config.model_config
    return model_config.supports_inflight_batching


class QuantModeWrapper:

    def __init__(self, objs):
        self.objs = objs

    def __getattr__(self, name):

        def method_wrapper(*args, **kwargs):
            result = False
            for obj in self.objs:
                attr = getattr(obj, name)
                if callable(attr):
                    result = result | attr(*args, **kwargs)
            return result

        return method_wrapper

    def __repr__(self):
        return f"QuantModeWrapper: ({self.objs})"

    def __str__(self):
        obj_strs = [str(obj) for obj in self.objs]
        return f"[{', '.join(obj_strs)}]"

    def __getitem__(self, index):
        return self.objs[index]


@contextmanager
def nvtx_range(msg):
    torch.cuda.nvtx.range_push(msg)
    try:
        yield
    finally:
        torch.cuda.nvtx.range_pop()
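

# Illustrative usage: annotate a region for Nsight profiling; the range is
# popped even if the body raises:
#   with nvtx_range("generation_step"):
#       step()    # hypothetical workload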


def volume(d: Sequence[int]):
    return np.prod(d)


class TensorWrapper:
    """
    A wrapper that turns a raw data pointer into a tensor-like object. It is
    compatible with OpenAI Triton kernels and can be converted to a
    `torch.Tensor` without copying.
    """

    def __init__(
        self,
        data_ptr: int,
        dtype: Union[torch.dtype, str, np.dtype, trt.DataType],
        shape: Sequence[int],
    ):
        self._data_ptr = data_ptr
        self.dtype = dtype
        self.shape = shape

    def data_ptr(self):
        return self._data_ptr

    @property
    def dtype(self):
        return self._dtype

    @property
    def shape(self):
        return getattr(self, "_shape", None)

    @dtype.setter
    def dtype(self, dtype: Union[torch.dtype, str, np.dtype, trt.DataType]):
        if isinstance(dtype, torch.dtype):
            self._dtype = dtype
        elif isinstance(dtype, str):
            self._dtype = str_dtype_to_torch(dtype)
        elif isinstance(dtype, np.dtype):
            self._dtype = np_dtype_to_torch(dtype)
        elif isinstance(dtype, trt.DataType):
            self._dtype = trt_dtype_to_torch(dtype)
        else:
            raise TypeError(f"Unsupported dtype: {dtype}")

    @shape.setter
    def shape(self, shape: Sequence[int]):
        self._shape = tuple(int(i) for i in shape)

    def numel(self):
        return volume(self.shape)

    @property
    def __cuda_array_interface__(self):
        return {
            "shape": self.shape,
            "typestr": torch_dtype_to_np_typestr(self.dtype),
            "data": (self.data_ptr() if self.numel() > 0 else 0, False),
            "version": 3,
        }

    @staticmethod
    def from_trt_desc(desc: trt.PluginTensorDesc, pointer: int):
        return TensorWrapper(pointer, trt_dtype_to_torch(desc.type), desc.dims)


def convert_to_torch_tensor(
        tensor: Union[TensorWrapper, torch.Tensor]) -> torch.Tensor:
    """
    Convert a `TensorWrapper` to a `torch.Tensor` (no-op for tensors).
    """
    if isinstance(tensor, torch.Tensor):
        return tensor
    return torch.as_tensor(tensor).view(tensor.dtype)
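

# Illustrative usage (hypothetical pointer and shape): wrap a raw device
# pointer, e.g. one handed to a TRT plugin, and view it as a torch.Tensor
# without copying; torch.as_tensor consumes the __cuda_array_interface__
# exposed above:
#   wrapped = TensorWrapper(dev_ptr, trt.float16, (batch, hidden))
#   t = convert_to_torch_tensor(wrapped)   # zero-copy torch.float16 view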