TensorRT-LLMs/tensorrt_llm/_utils.py
Kaiyu Xie 8dd9c91470
Update TensorRT-LLM (#539)
* Update TensorRT-LLM

---------

Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
2023-12-04 18:06:59 +08:00

275 lines
6.7 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import math
import struct
from functools import partial
from pathlib import Path, PosixPath
import numpy as np
# isort: off
import torch
import tensorrt as trt
# isort: on
# numpy doesn't know bfloat16, define abstract binary type instead:
# a 2-byte void dtype ('V2') carrying metadata that tags it as bfloat16.
np_bfloat16 = np.dtype('V2', metadata={"dtype": "bfloat16"})
def torch_to_numpy(x: torch.Tensor):
    """Convert a torch tensor to a numpy array.

    bfloat16 tensors are reinterpreted as int16, converted to numpy, and
    then viewed as the module's ``np_bfloat16`` placeholder dtype. Every
    other dtype goes through ``detach().cpu().numpy()`` directly.

    Raises:
        AssertionError: if ``x`` is not a ``torch.Tensor``.
    """
    assert isinstance(x, torch.Tensor), \
        f'x must be a torch.Tensor object, but got {type(x)}.'
    if x.dtype == torch.bfloat16:
        # Reinterpret the 16-bit payload as int16 so numpy can hold it.
        raw_bits = x.view(torch.int16).detach().cpu().numpy()
        return raw_bits.view(np_bfloat16)
    return x.detach().cpu().numpy()
def numpy_to_torch(x):
    """Convert a numpy array to a torch tensor.

    Arrays whose dtype equals ``np_bfloat16`` are reinterpreted as int16,
    turned into a tensor, and then viewed as ``torch.bfloat16``. Any other
    array is passed straight to ``torch.tensor``.
    """
    if x.dtype == np_bfloat16:
        # Carry the raw 16-bit payload through int16, then relabel it.
        as_int16 = torch.tensor(x.view(np.int16))
        return as_int16.view(torch.bfloat16)
    return torch.tensor(x)
def numpy_to_dtype(x, dtype: str):
    """Cast a numpy array to the dtype named by ``dtype``.

    Handles the ``np_bfloat16`` placeholder dtype, which numpy cannot cast
    natively, by round-tripping through torch.

    Args:
        x: source numpy array (its dtype may be ``np_bfloat16``).
        dtype: target dtype name, e.g. ``'float16'`` or ``'bfloat16'``.

    Returns:
        The converted array. When source and target are both bfloat16 the
        input array itself is returned.

    Raises:
        AssertionError: if ``dtype`` is not a name known to the lookup
            helpers used for the conversion.
    """
    if x.dtype == np_bfloat16:
        # BF16 --> non-BF16 or BF16
        if dtype == 'bfloat16':
            return x
        # The result must be returned; previously it was computed and
        # dropped, so this path yielded None.
        return torch_to_numpy(
            numpy_to_torch(x).to(str_dtype_to_torch(dtype)))

    # non-BF16 types --> non-BF16 or BF16
    if dtype == 'bfloat16':
        return torch_to_numpy(torch.from_numpy(x).to(torch.bfloat16))
    return x.astype(str_dtype_to_np(dtype))
# Shorthand array constructors: np.array with the dtype argument pre-bound.
fp32_array = partial(np.array, dtype=np.float32)
fp16_array = partial(np.array, dtype=np.float16)
int32_array = partial(np.array, dtype=np.int32)
def bf16_array(x):
    """Build a bfloat16 numpy array from ``x``.

    The data is first materialized as a ``torch.bfloat16`` tensor and then
    handed to ``torch_to_numpy`` for the conversion to numpy.
    """
    bf16_tensor = torch.tensor(x, dtype=torch.bfloat16)
    return torch_to_numpy(bf16_tensor)
def trt_version():
    """Return the ``__version__`` attribute of the imported tensorrt module."""
    version = trt.__version__
    return version
def torch_version():
    """Return the ``__version__`` attribute of the imported torch module."""
    version = torch.__version__
    return version
# Lookup table: dtype name -> numpy dtype (bfloat16 uses the placeholder).
_str_to_np_dict = {
    'float16': np.float16,
    'float32': np.float32,
    'int32': np.int32,
    'bfloat16': np_bfloat16,
}


def str_dtype_to_np(dtype):
    """Map a dtype name to its numpy dtype.

    Raises:
        AssertionError: if ``dtype`` is not a key of ``_str_to_np_dict``.
    """
    np_type = _str_to_np_dict.get(dtype)
    assert np_type is not None, f'Unsupported dtype: {dtype}'
    return np_type
# Lookup table: dtype name -> torch dtype.
_str_to_torch_dtype_dict = {
    'bfloat16': torch.bfloat16,
    'float16': torch.float16,
    'float32': torch.float32,
    'int32': torch.int32,
    'int8': torch.int8,
}


def str_dtype_to_torch(dtype):
    """Map a dtype name to its torch dtype.

    Raises:
        AssertionError: if ``dtype`` is not a key of
            ``_str_to_torch_dtype_dict``.
    """
    torch_type = _str_to_torch_dtype_dict.get(dtype)
    assert torch_type is not None, f'Unsupported dtype: {dtype}'
    return torch_type
# Lookup table: dtype name -> TensorRT dtype.
_str_to_trt_dtype_dict = {
    'float16': trt.float16,
    'float32': trt.float32,
    'int64': trt.int64,
    'int32': trt.int32,
    'int8': trt.int8,
    'bool': trt.bool,
    'bfloat16': trt.bfloat16,
    'fp8': trt.fp8,
}


def str_dtype_to_trt(dtype):
    """Map a dtype name to its TensorRT dtype.

    Raises:
        AssertionError: if ``dtype`` is not a key of
            ``_str_to_trt_dtype_dict``.
    """
    trt_type = _str_to_trt_dtype_dict.get(dtype)
    assert trt_type is not None, f'Unsupported dtype: {dtype}'
    return trt_type
# Lookup table: numpy dtype -> TensorRT dtype. Both the scalar types and
# the corresponding np.dtype instances are registered as keys.
_np_to_trt_dtype_dict = {
    # numpy scalar types
    np.int8: trt.int8,
    np.int32: trt.int32,
    np.float16: trt.float16,
    np.float32: trt.float32,
    # hash of np.dtype('int32') != np.int32
    np.dtype('int8'): trt.int8,
    np.dtype('int32'): trt.int32,
    np.dtype('float16'): trt.float16,
    np.dtype('float32'): trt.float32,
    # placeholder bfloat16 dtype and bool
    np_bfloat16: trt.bfloat16,
    np.bool_: trt.bool,
}


def np_dtype_to_trt(dtype):
    """Map a numpy dtype (scalar type or ``np.dtype``) to a TensorRT dtype.

    Raises:
        AssertionError: if ``dtype`` is not a key of
            ``_np_to_trt_dtype_dict``.
    """
    trt_type = _np_to_trt_dtype_dict.get(dtype)
    assert trt_type is not None, f'Unsupported dtype: {dtype}'
    return trt_type
# Lookup table: TensorRT dtype -> numpy dtype (bfloat16 uses the placeholder).
_trt_to_np_dtype_dict = {
    trt.int8: np.int8,
    trt.int32: np.int32,
    trt.float16: np.float16,
    trt.float32: np.float32,
    trt.bool: np.bool_,
    trt.bfloat16: np_bfloat16,
}


def trt_dtype_to_np(dtype):
    """Map a TensorRT dtype to its numpy dtype.

    Raises:
        AssertionError: if ``dtype`` is not a key of
            ``_trt_to_np_dtype_dict``.
    """
    np_type = _trt_to_np_dtype_dict.get(dtype)
    assert np_type is not None, f'Unsupported dtype: {dtype}'
    return np_type
# Lookup table: torch dtype -> numpy dtype.
_torch_to_np_dtype_dict = {
    torch.float16: np.float16,
    torch.float32: np.float32,
}


def torch_dtype_to_np(dtype):
    """Map a torch dtype to its numpy dtype.

    Raises:
        AssertionError: if ``dtype`` is not a key of
            ``_torch_to_np_dtype_dict``.
    """
    np_type = _torch_to_np_dtype_dict.get(dtype)
    assert np_type is not None, f'Unsupported dtype: {dtype}'
    return np_type
# Lookup table: TensorRT dtype -> torch dtype.
_trt_to_torch_dtype_dict = {
    trt.float16: torch.float16,
    trt.float32: torch.float32,
    trt.int32: torch.int32,
    trt.int8: torch.int8,
    trt.bfloat16: torch.bfloat16,
}


def trt_dtype_to_torch(dtype):
    """Map a TensorRT dtype to its torch dtype.

    Raises:
        AssertionError: if ``dtype`` is not a key of
            ``_trt_to_torch_dtype_dict``.
    """
    torch_type = _trt_to_torch_dtype_dict.get(dtype)
    assert torch_type is not None, f'Unsupported dtype: {dtype}'
    return torch_type
def dim_to_trt_axes(dim):
    """Converts torch dim, or tuple of dims to a tensorrt axes bitmask

    A non-tuple ``dim`` is treated as a one-element tuple. Bit ``d`` of the
    returned integer is set for every ``d`` in the tuple; an empty tuple
    yields 0.
    """
    dims = dim if isinstance(dim, tuple) else (dim, )

    # create axes bitmask for reduce layer
    bitmask = 0
    for axis in dims:
        bitmask |= 1 << axis
    return bitmask
def dim_resolve_negative(dim, ndim):
    """Replace negative dims with ``ndim + d`` and return them as a tuple.

    A non-tuple ``dim`` is treated as a one-element tuple. Non-negative
    entries are passed through unchanged.
    """
    dims = dim if isinstance(dim, tuple) else (dim, )
    return tuple(ndim + d if d < 0 else d for d in dims)
def mpi_comm():
    """Return ``MPI.COMM_WORLD`` from mpi4py.

    mpi4py is imported inside the function, so importing this module does
    not itself import mpi4py.
    """
    from mpi4py import MPI
    world = MPI.COMM_WORLD
    return world
def mpi_rank():
    """Return ``Get_rank()`` of the communicator given by ``mpi_comm()``."""
    comm = mpi_comm()
    return comm.Get_rank()
def mpi_world_size():
    """Return ``Get_size()`` of the communicator given by ``mpi_comm()``."""
    comm = mpi_comm()
    return comm.Get_size()
def pad_vocab_size(vocab_size, tp_size):
    """Round ``vocab_size`` up to the nearest multiple of ``tp_size``.

    Args:
        vocab_size: the size to pad.
        tp_size: the divisor the result must be a multiple of.

    Returns:
        The smallest multiple of ``tp_size`` that is >= ``vocab_size``,
        as an ``int``.

    Raises:
        ZeroDivisionError: if ``tp_size`` is 0.
    """
    # Ceiling division via floor division of the negated value. This stays
    # exact for arbitrarily large integers, whereas going through float
    # true division can round and return a value below vocab_size.
    num_chunks = -(-vocab_size // tp_size)
    return int(num_chunks * tp_size)
def to_dict(obj):
    """Return a deep copy of ``obj.__dict__``."""
    attributes = obj.__dict__
    return copy.deepcopy(attributes)
def to_json_string(obj):
    """Serialize ``obj`` to a JSON string terminated by a newline.

    A ``dict`` is serialized as is; any other object is first converted
    with ``to_dict``. The output uses an indent of 2 and sorted keys.
    """
    payload = obj if isinstance(obj, dict) else to_dict(obj)
    serialized = json.dumps(payload, indent=2, sort_keys=True)
    return serialized + "\n"
def to_json_file(obj, json_file_path):
    """Write the ``to_json_string`` form of ``obj`` to ``json_file_path``.

    The file is opened in text write mode with UTF-8 encoding.
    """
    with open(json_file_path, "w", encoding="utf-8") as json_file:
        # Serialization happens after the file has been opened, as before.
        serialized = to_json_string(obj)
        json_file.write(serialized)
def numpy_fp32_to_bf16(src):
    """Convert a float32 numpy array to the ``np_bfloat16`` placeholder dtype.

    Numpy doesn't support bfloat16, so the conversion is done on the raw
    bits: each float32 is truncated to its upper 16 bits (no rounding) and
    the result is viewed with the bf16 abstract type.

    Args:
        src: numpy array of dtype float32, any shape.

    Returns:
        A new array with the same shape as ``src`` and dtype ``np_bfloat16``.

    Raises:
        AssertionError: if ``src`` is not float32.
    """
    original_shape = src.shape
    flat = np.ascontiguousarray(src.flatten())
    assert flat.dtype == np.float32
    # Reinterpret each float32 as a uint32 and keep the high half. This is
    # the same value the former per-element struct pack/unpack produced,
    # without a Python-level loop over every element.
    upper_bits = (flat.view(np.uint32) >> 16).astype(np.uint16)
    return upper_bits.reshape(original_shape).view(np_bfloat16)
def fromfile(dir_path, name, shape=None, dtype=None):
    """Load ``dir_path / name`` with ``np.fromfile`` if the file exists.

    Args:
        dir_path: directory containing the file (path object or string).
        name: file name inside ``dir_path``.
        shape: optional shape the loaded array is reshaped to.
        dtype: dtype passed to ``np.fromfile``.

    Returns:
        The loaded (and optionally reshaped) array, or ``None`` when the
        file does not exist.
    """
    if dtype is None:
        # NOTE(review): `np_dtype` is not defined in the visible part of
        # this module; unless it is provided elsewhere, calling with
        # dtype=None raises NameError here. Confirm the intended default.
        dtype = np_dtype

    base_dir = dir_path if isinstance(dir_path, PosixPath) else Path(dir_path)
    file_path = base_dir / name

    if not Path(file_path).exists():
        return None

    data = np.fromfile(file_path, dtype=dtype)
    return data if shape is None else data.reshape(shape)