# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gc
import inspect
import json
import linecache
import math
import os
import struct
import trace
import weakref
from contextlib import contextmanager
from dataclasses import asdict
from enum import EnumMeta
from functools import lru_cache, partial, wraps
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Union

import numpy as np
import nvtx
from mpi4py import MPI
from mpi4py.util import pkl5
from packaging import version

# isort: off
import torch
import tensorrt as trt
# isort: on

from tensorrt_llm.bindings import DataType, GptJsonConfig
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
from tensorrt_llm.logger import logger

# numpy doesn't know bfloat16, define abstract binary type instead
np_bfloat16 = np.dtype('V2', metadata={"dtype": "bfloat16"})
np_float8 = np.dtype('V1', metadata={"dtype": "float8"})


def torch_to_numpy(x: torch.Tensor):
    assert isinstance(x, torch.Tensor), \
        f'x must be a torch.Tensor object, but got {type(x)}.'
    if x.dtype == torch.bfloat16:
        return x.view(torch.int16).detach().cpu().numpy().view(np_bfloat16)
    elif x.dtype == torch.float8_e4m3fn:
        return x.view(torch.int8).detach().cpu().numpy().view(np_float8)
    else:
        return x.detach().cpu().numpy()


def numpy_to_torch(x):
    if x.dtype == np_bfloat16:
        return torch.from_numpy(x.view(np.int16)).view(torch.bfloat16)
    elif x.dtype == np_float8:
        return torch.from_numpy(x.view(np.int8)).view(torch.float8_e4m3fn)
    else:
        return torch.from_numpy(x)


def numpy_to_dtype(x, dtype: str):
    if str_dtype_to_np(dtype) == x.dtype:
        return x
    if x.dtype not in [np_bfloat16, np_float8
                       ] and dtype not in ['bfloat16', 'fp8']:
        return x.astype(str_dtype_to_np(dtype))
    else:
        return torch_to_numpy(numpy_to_torch(x).to(str_dtype_to_torch(dtype)))


fp32_array = partial(np.array, dtype=np.float32)
fp16_array = partial(np.array, dtype=np.float16)
int32_array = partial(np.array, dtype=np.int32)
int64_array = partial(np.array, dtype=np.int64)
bool_array = partial(np.array, dtype=np.bool_)


def dims_array(x):
    is_int64_dims = True
    try:
        trt.Dims([np.iinfo(np.int64).max])
    except TypeError:
        is_int64_dims = False
    return int64_array(x) if is_int64_dims else int32_array(x)


def bf16_array(x):
    x = torch.tensor(x, dtype=torch.bfloat16)
    x = torch_to_numpy(x)
    return x


def numpy_array(data, trt_dtype):
    # convenience wrapper since numpy does not support bf16 yet
    if trt_dtype == trt.bfloat16:
        return bf16_array(data)
    return np.array(data, trt_dtype_to_np(trt_dtype))


def copy_torch_to_numpy(x: torch.Tensor, ndarray: np.ndarray):
    if x.dtype == torch.bfloat16:
        torch.from_numpy(ndarray.view(np.int16)).copy_(x.view(torch.int16))
    elif x.dtype == torch.float8_e4m3fn:
        torch.from_numpy(ndarray.view(np.int8)).copy_(x.view(torch.int8))
    else:
        torch.from_numpy(ndarray).copy_(x)
    return ndarray
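# Illustrative usage (a sketch, not part of the module API): because numpy has
# no native bfloat16/fp8, torch_to_numpy() returns raw-byte arrays tagged with
# the abstract np_bfloat16 / np_float8 dtypes defined above, and
# numpy_to_torch() reinterprets them back into real torch dtypes.
#
#   >>> t = torch.randn(2, 2, dtype=torch.bfloat16)
#   >>> a = torch_to_numpy(t)
#   >>> a.dtype == np_bfloat16
#   True
#   >>> numpy_to_torch(a).dtype
#   torch.bfloat16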
0): """ Check if TRT version is greater than or equal to major.minor """ trt_ver = version.parse(trt_version()) return trt_ver.major >= major and trt_ver.minor >= minor def torch_version(): return torch.__version__ _str_to_np_dict = dict( float16=np.float16, float32=np.float32, int64=np.int64, int32=np.int32, int8=np.int8, bool=np.bool_, bfloat16=np_bfloat16, fp8=np_float8, ) def str_dtype_to_np(dtype): ret = _str_to_np_dict.get(dtype) assert ret is not None, f'Unsupported dtype: {dtype}' return ret _str_to_torch_dtype_dict = dict( bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, int64=torch.int64, int32=torch.int32, int8=torch.int8, bool=torch.bool, fp8=torch.float8_e4m3fn, ) def str_dtype_to_torch(dtype): ret = _str_to_torch_dtype_dict.get(dtype) assert ret is not None, f'Unsupported dtype: {dtype}' return ret _str_to_binding_dtype_dict = dict( bfloat16=DataType.BF16, float16=DataType.HALF, float32=DataType.FLOAT, int64=DataType.INT64, int32=DataType.INT32, int8=DataType.INT8, bool=DataType.BOOL, fp8=DataType.FP8, ) _binding_to_str_dtype = {v: k for k, v in _str_to_binding_dtype_dict.items()} _binding_dtype_size = { DataType.INT64: 8, DataType.FLOAT: 4, DataType.INT32: 4, DataType.BF16: 2, DataType.HALF: 2, DataType.BOOL: 1, DataType.FP8: 1, DataType.INT8: 1, DataType.UINT8: 1, } def binding_to_str_dtype(binding_dtype) -> str: ret = _binding_to_str_dtype.get(binding_dtype) assert ret is not None, f'Unsupported binding dtype: {binding_dtype}' return ret def binding_dtype_size(dtype: DataType): return _binding_dtype_size[dtype] def str_dtype_to_binding(dtype): ret = _str_to_binding_dtype_dict.get(dtype) assert ret is not None, f'Unsupported dtype: {dtype}' return ret _torch_dtype_to_str_dict = {v: k for k, v in _str_to_torch_dtype_dict.items()} def torch_dtype_to_str(dtype): return _torch_dtype_to_str_dict[dtype] _str_to_trt_dtype_dict = dict(float16=trt.float16, float32=trt.float32, int64=trt.int64, int32=trt.int32, int8=trt.int8, bool=trt.bool, bfloat16=trt.bfloat16, fp8=trt.fp8, nvfp4=trt.fp4) def str_dtype_to_trt(dtype): if dtype == "fp4": # Special handling for FP4 since CI's trt version is not recent enough. 
_trt_to_str_dtype_dict = {v: k for k, v in _str_to_trt_dtype_dict.items()}


def trt_dtype_to_str(dtype: trt.DataType) -> str:
    assert isinstance(dtype, trt.DataType)
    return _trt_to_str_dtype_dict[dtype]


_np_to_trt_dtype_dict = {
    np.int8: trt.int8,
    np.int32: trt.int32,
    np.int64: trt.int64,
    np.float16: trt.float16,
    np.float32: trt.float32,
    np.bool_: trt.bool,

    # hash of np.dtype('int32') != np.int32
    np.dtype('int8'): trt.int8,
    np.dtype('int32'): trt.int32,
    np.dtype('int64'): trt.int64,
    np.dtype('float16'): trt.float16,
    np.dtype('float32'): trt.float32,
    np.dtype('bool'): trt.bool,
    np_bfloat16: trt.bfloat16,
    np_float8: trt.fp8,
}


def np_dtype_to_trt(dtype):
    ret = _np_to_trt_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_trt_to_np_dtype_dict = {
    trt.int8: np.int8,
    trt.int32: np.int32,
    trt.int64: np.int64,
    trt.float16: np.float16,
    trt.float32: np.float32,
    trt.bool: np.bool_,
    trt.bfloat16: np_bfloat16,
    trt.fp8: np_float8,
}


def trt_dtype_to_np(dtype):
    ret = _trt_to_np_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_torch_to_np_dtype_dict = {
    torch.bool: np.bool_,
    torch.uint8: np.uint8,
    torch.int8: np.int8,
    torch.int16: np.int16,
    torch.int32: np.int32,
    torch.int64: np.int64,
    torch.float16: np.float16,
    torch.bfloat16: np_bfloat16,
    torch.float8_e4m3fn: np_float8,
    torch.float32: np.float32,
    torch.float64: np.float64,
    torch.complex64: np.complex64,
    torch.complex128: np.complex128,
}


def torch_dtype_to_np(dtype):
    ret = _torch_to_np_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_np_to_torch_dtype_dict = {
    np.bool_: torch.bool,
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
    np.float16: torch.float16,
    np_bfloat16: torch.bfloat16,
    np_float8: torch.float8_e4m3fn,
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.complex64: torch.complex64,
    np.complex128: torch.complex128,
}


def np_dtype_to_torch(dtype):
    ret = _np_to_torch_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_trt_to_torch_dtype_dict = {
    trt.float16: torch.float16,
    trt.float32: torch.float32,
    trt.int64: torch.int64,
    trt.int32: torch.int32,
    trt.int8: torch.int8,
    trt.bool: torch.bool,
    trt.bfloat16: torch.bfloat16,
    trt.fp8: torch.float8_e4m3fn,
}


def trt_dtype_to_torch(dtype):
    ret = _trt_to_torch_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


def is_same_dtype(type_a: Union[str, trt.DataType],
                  type_b: Union[str, trt.DataType]) -> bool:
    if isinstance(type_a, str):
        type_a = str_dtype_to_trt(type_a)

    if isinstance(type_b, str):
        type_b = str_dtype_to_trt(type_b)

    return type_a == type_b


_torch_to_trt_dtype_dict = {
    torch.float16: trt.float16,
    torch.float32: trt.float32,
    torch.int64: trt.int64,
    torch.int32: trt.int32,
    torch.int8: trt.int8,
    torch.float8_e4m3fn: trt.fp8,
    torch.qint8: trt.int8,
    torch.bool: trt.bool,
    torch.bfloat16: trt.bfloat16
}


def torch_dtype_to_trt(dtype):
    ret = _torch_to_trt_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_torch_to_binding_dtype_dict = {
    torch.float16: DataType.HALF,
    torch.float32: DataType.FLOAT,
    torch.int64: DataType.INT64,
    torch.int32: DataType.INT32,
    torch.int8: DataType.INT8,
    torch.float8_e4m3fn: DataType.FP8,
    torch.qint8: DataType.INT8,
    torch.bool: DataType.BOOL,
    torch.bfloat16: DataType.BF16
}


def torch_dtype_to_binding(dtype):
    ret = _torch_to_binding_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret
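# Illustrative usage (a sketch): is_same_dtype() normalizes string names to
# trt.DataType before comparing, so string and enum spellings of the same dtype
# compare equal, and the *_to_* tables above round-trip cleanly.
#
#   >>> is_same_dtype('float16', trt.float16)
#   True
#   >>> trt_dtype_to_str(torch_dtype_to_trt(torch.bfloat16))
#   'bfloat16'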
_torch_dtype_to_np_typestr_dict = {
    torch.float16: "<f2",
    torch.float32: "<f4",
    torch.int64: "<i8",
    torch.int32: "<i4",
    torch.int8: "|i1",
    torch.bool: "|b1",
    # numpy has no bfloat16/fp8 typestr; expose a same-width interchange type
    # and let the consumer reinterpret the bits (see convert_to_torch_tensor).
    torch.bfloat16: "<f2",
    torch.float8_e4m3fn: "|u1",
}


def trt_axes_to_dim(axes: int) -> List[int]:
    """Converts tensorrt axes bitmask to dims"""
    dim = []
    for i in range(32):
        if axes & (1 << i):
            dim.append(i)
    return dim


def dim_resolve_negative(dim, ndim):
    if not isinstance(dim, tuple):
        dim = (dim, )
    pos = []
    for d in dim:
        if d < 0:
            d = ndim + d
        pos.append(d)
    return tuple(pos)


# mpi4py only exports MPI_COMM_TYPE_SHARED, so we define OMPI_COMM_TYPE_HOST here
OMPI_COMM_TYPE_HOST = 9

comm = pkl5.Intracomm(MPI.COMM_WORLD)


def set_mpi_comm(new_comm):
    global comm
    comm = new_comm


def mpi_comm():
    return comm


local_comm = mpi_comm().Split_type(split_type=OMPI_COMM_TYPE_HOST)


def local_mpi_comm():
    return local_comm


def mpi_rank():
    return mpi_comm().Get_rank() if ENABLE_MULTI_DEVICE else 0


def global_mpi_rank():
    return MPI.COMM_WORLD.Get_rank() if ENABLE_MULTI_DEVICE else 0


def global_mpi_size():
    return MPI.COMM_WORLD.Get_size() if ENABLE_MULTI_DEVICE else 1


def mpi_world_size():
    return mpi_comm().Get_size() if ENABLE_MULTI_DEVICE else 1


def local_mpi_rank():
    return local_comm.Get_rank() if ENABLE_MULTI_DEVICE else 0


def local_mpi_size():
    return local_comm.Get_size() if ENABLE_MULTI_DEVICE else 1


def default_gpus_per_node():
    num_gpus = torch.cuda.device_count()
    num_ranks = local_mpi_size()
    assert num_gpus > 0, "No GPU found on the node"
    if num_ranks > num_gpus:
        logger.warning(f"{num_ranks} MPI ranks will share {num_gpus} GPUs.")
    return min(num_ranks, num_gpus)


def mpi_barrier():
    if ENABLE_MULTI_DEVICE:
        mpi_comm().Barrier()


def local_mpi_barrier():
    if ENABLE_MULTI_DEVICE:
        local_comm.Barrier()


def mpi_broadcast(obj, root=0):
    return mpi_comm().bcast(obj, root) if is_multi_device_enable() else obj


def mpi_allgather(obj):
    return mpi_comm().allgather(obj) if ENABLE_MULTI_DEVICE else obj


def mpi_isend(buf, dest, tag=0):
    # non-blocking send of buf-like objects (e.g. numpy array);
    # returns a request handle if ENABLE_MULTI_DEVICE, else None
    if ENABLE_MULTI_DEVICE:
        return mpi_comm().Isend(buf, dest, tag=tag)
    return None


def mpi_send(buf, dest, tag=0):
    # blocking send of buf-like objects (e.g. numpy array); returns None
    if ENABLE_MULTI_DEVICE:
        mpi_comm().Send(buf, dest, tag=tag)
    return None


def mpi_recv(buf, source, tag):
    # recv into a buf-like object (e.g. numpy array)
    if ENABLE_MULTI_DEVICE:
        return mpi_comm().Recv(buf, source, tag=tag)
    return None
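# Illustrative usage (a sketch): the mpi_* wrappers degrade to single-process
# no-ops when multi-device support is unavailable, so callers do not need to
# guard every collective themselves.
#
#   >>> payload = {"rank": mpi_rank()}
#   >>> payload = mpi_broadcast(payload, root=0)  # no-op on a single GPU
#   >>> gathered = mpi_allgather(payload)         # returns `payload` unchanged
#   >>> mpi_barrier()                             # no-op unless multi-device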
def mpi_send_object(obj, dest, tag=0):
    if ENABLE_MULTI_DEVICE:
        mpi_comm().send(obj, dest=dest, tag=tag)


def mpi_isend_object(obj, dest, tag=0):
    if ENABLE_MULTI_DEVICE:
        return mpi_comm().isend(obj, dest=dest, tag=tag)
    return None


def mpi_recv_object(source, tag):
    if ENABLE_MULTI_DEVICE:
        return mpi_comm().recv(source=source, tag=tag)
    return None


def pad_vocab_size(vocab_size, tp_size):
    return int(math.ceil(vocab_size / tp_size) * tp_size)


def to_dict(obj):
    return copy.deepcopy(obj.__dict__)


def to_json_string(obj):
    if not isinstance(obj, dict):
        obj = to_dict(obj)
    return json.dumps(obj, indent=2, sort_keys=True) + "\n"


def to_json_file(obj, json_file_path):
    with open(json_file_path, "w", encoding="utf-8") as writer:
        writer.write(to_json_string(obj))


def numpy_fp32_to_bf16(src):
    # Numpy doesn't support bfloat16 type
    # Convert float32 to bfloat16 manually and assign with bf16 abstract type
    original_shape = src.shape
    src = src.flatten()
    src = np.ascontiguousarray(src)
    assert src.dtype == np.float32
    dst = np.empty_like(src, dtype=np.uint16)
    for i in range(len(dst)):
        # bfloat16 keeps the upper 16 bits of the little-endian float32 pattern
        bytes = struct.pack('<f', src[i])
        dst[i] = struct.unpack('<H', bytes[2:4])[0]
    return dst.reshape(original_shape).view(np_bfloat16)


class TensorWrapper:
    """
    Wraps an externally owned device buffer (raw pointer, dtype and shape) so it
    can be adopted by torch through the CUDA array interface protocol.
    """

    def __init__(self, pointer: int, dtype: torch.dtype, shape: Sequence[int]):
        self._pointer = pointer
        self.dtype = dtype
        self.shape = tuple(shape)

    def data_ptr(self) -> int:
        return self._pointer

    @property
    def __cuda_array_interface__(self):
        return {
            "shape": self.shape,
            "typestr": _torch_dtype_to_np_typestr_dict[self.dtype],
            "data":
            (self.data_ptr() if math.prod(self.shape) > 0 else 0, False),
            "version": 3,
        }

    @staticmethod
    def from_trt_desc(desc: trt.PluginTensorDesc, pointer: int):
        return TensorWrapper(pointer, trt_dtype_to_torch(desc.type), desc.dims)


def convert_to_torch_tensor(
        tensor: Union[TensorWrapper, torch.Tensor]) -> torch.Tensor:
    """
    Convert a `TensorWrapper` to a torch.Tensor without copying the data.
    """
    if isinstance(tensor, torch.Tensor):
        return tensor
    old_ptr = tensor.data_ptr()
    new_tensor = torch.as_tensor(tensor).view(tensor.dtype)
    new_ptr = new_tensor.data_ptr()
    if old_ptr != new_ptr:
        raise RuntimeError(
            "Data pointer mismatch after converting to torch.Tensor")
    return new_tensor


class KVCacheEventSerializer:

    @classmethod
    def get_event_serialize_func(cls, event_type):
        return {
            "KVCacheCreatedData": cls._created_to_json,
            "KVCacheStoredData": cls._stored_to_json,
            "KVCacheStoredBlockData": cls._stored_block_to_json,
            "KVCacheRemovedData": cls._removed_to_json,
            "KVCacheUpdatedData": cls._updated_to_json,
        }.get(event_type, None)

    @classmethod
    def serialize(cls, events):
        if events is None:
            return None

        if not isinstance(events, list):
            return cls.to_json_str(events)

        return [cls.to_json_str(event) for event in events]

    @classmethod
    def to_json_str(cls, event):
        if event is None:
            return {}

        event_type = type(event.data).__name__
        event_serialize_func = cls.get_event_serialize_func(event_type)
        if event_serialize_func is None:
            raise ValueError(f"Unknown KVCache event data type: {event_type}")

        return {
            "event_id": event.event_id,
            "data": event_serialize_func(event.data),
            "window_size": event.window_size
        }

    @staticmethod
    def _created_to_json(data):
        return {
            "type": "created",
            "num_blocks_per_cache_level": data.num_blocks_per_cache_level
        }

    @staticmethod
    def _stored_to_json(data):
        return {
            "type": "stored",
            "parent_hash": data.parent_hash,
            "blocks": [
                KVCacheEventSerializer._stored_block_to_json(block)
                for block in data.blocks
            ]
        }

    @staticmethod
    def _stored_block_to_json(data):
        return {
            "type": "stored_block",
            "block_hash": data.block_hash,
            "tokens": [
                KVCacheEventSerializer._unique_tokens_to_json(token)
                for token in data.tokens
            ],
            # "lora_id": data.lora_id,  # TODO (shreyasm): enable serialization of lora_id
            "cache_level": data.cache_level,
            "priority": data.priority
        }

    @staticmethod
    def _removed_to_json(data):
        return {"type": "removed", "block_hashes": data.block_hashes}
    @staticmethod
    def _updated_to_json(data):
        return {
            "type": "updated",
            "block_hash": data.block_hash,
            "cache_level":
            KVCacheEventSerializer._event_diff_to_json(data.cache_level),
            "priority":
            KVCacheEventSerializer._event_diff_to_json(data.priority)
        }

    @staticmethod
    def _event_diff_to_json(data):
        return {
            "type": "event_diff",
            "new_value": data.new_value,
            "old_value": data.old_value
        }

    @staticmethod
    def _unique_tokens_to_json(data):
        return {
            "type": "unique_token",
            "token_id": data.token_id,
            "token_extra_id": data.token_extra_id
        }


def is_multi_device_enable():
    """
    Check whether we are actually running on multiple GPUs with the
    ENABLE_MULTI_DEVICE flag set, so broadcast calls can be skipped on a
    single GPU.
    Issue: https://github.com/NVIDIA/TensorRT-LLM/issues/5927

    ENABLE_MULTI_DEVICE is true by default when building tensorrt-llm, so the
    number of local devices has to be checked as well.
    """
    return local_mpi_size() > 1
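# Illustrative usage (a sketch; the event source below is hypothetical):
# KVCacheEventSerializer.serialize() accepts either a single KV-cache event or a
# list of events and emits JSON-friendly dicts whose "data.type" field reflects
# the event kind ("created", "stored", "removed", "updated").
#
#   >>> events = executor.get_latest_kv_cache_events()  # hypothetical accessor
#   >>> KVCacheEventSerializer.serialize(events)
#   [{'event_id': 0, 'data': {'type': 'created', ...}, 'window_size': ...}, ...]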