# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import math
import struct
import weakref
from functools import partial
from pathlib import Path, PosixPath
from typing import Any, Dict, List

import numpy as np

# isort: off
import torch
import tensorrt as trt
# isort: on

# numpy doesn't know bfloat16, define abstract binary type instead
np_bfloat16 = np.dtype('V2', metadata={"dtype": "bfloat16"})


def torch_to_numpy(x: torch.Tensor):
    assert isinstance(x, torch.Tensor), \
        f'x must be a torch.Tensor object, but got {type(x)}.'
    if x.dtype != torch.bfloat16:
        return x.detach().cpu().numpy()
    # bfloat16 tensors are exported bit-for-bit via int16 and tagged with the
    # abstract np_bfloat16 dtype defined above.
    return x.view(torch.int16).detach().cpu().numpy().view(np_bfloat16)


def numpy_to_torch(x):
    if x.dtype != np_bfloat16:
        return torch.tensor(x)
    return torch.tensor(x.view(np.int16)).view(torch.bfloat16)


def numpy_to_dtype(x, dtype: str):
    """Casts a numpy array to the named dtype, routing through torch when
    bfloat16 is involved on either side."""
    if x.dtype == np_bfloat16:
        # BF16 --> non-BF16 or BF16
        if dtype != 'bfloat16':
            return torch_to_numpy(
                numpy_to_torch(x).to(str_dtype_to_torch(dtype)))
        else:
            return x
    else:
        # non-BF16 types --> non-BF16 or BF16
        if dtype != 'bfloat16':
            return x.astype(str_dtype_to_np(dtype))
        else:
            return torch_to_numpy(torch.from_numpy(x).to(torch.bfloat16))


fp32_array = partial(np.array, dtype=np.float32)
fp16_array = partial(np.array, dtype=np.float16)
int32_array = partial(np.array, dtype=np.int32)


def bf16_array(x):
    x = torch.tensor(x, dtype=torch.bfloat16)
    x = torch_to_numpy(x)
    return x


def copy_torch_to_numpy(x: torch.Tensor, ndarray: np.ndarray):
    """Copies tensor x into the preallocated ndarray in place."""
    if x.dtype != torch.bfloat16:
        torch.from_numpy(ndarray).copy_(x)
        return ndarray
    torch.from_numpy(ndarray.view(np.int16)).copy_(x.view(torch.int16))
    return ndarray


def trt_version():
    return trt.__version__


def torch_version():
    return torch.__version__


_str_to_np_dict = dict(
    float16=np.float16,
    float32=np.float32,
    int64=np.int64,
    int32=np.int32,
    int8=np.int8,
    bool=np.bool_,
    bfloat16=np_bfloat16,
)


def str_dtype_to_np(dtype):
    ret = _str_to_np_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_str_to_torch_dtype_dict = dict(
    bfloat16=torch.bfloat16,
    float16=torch.float16,
    float32=torch.float32,
    int64=torch.int64,
    int32=torch.int32,
    int8=torch.int8,
    bool=torch.bool,
)


def str_dtype_to_torch(dtype):
    ret = _str_to_torch_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_str_to_trt_dtype_dict = dict(float16=trt.float16,
                              float32=trt.float32,
                              int64=trt.int64,
                              int32=trt.int32,
                              int8=trt.int8,
                              bool=trt.bool,
                              bfloat16=trt.bfloat16,
                              fp8=trt.fp8)


def str_dtype_to_trt(dtype):
    ret = _str_to_trt_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_trt_to_str_dtype_dict = {v: k for k, v in _str_to_trt_dtype_dict.items()}


def trt_dtype_to_str(dtype: trt.DataType) -> str:
    assert isinstance(dtype, trt.DataType)
    return _trt_to_str_dtype_dict[dtype]
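# Example usage (an illustrative sketch, not part of the module's API): the
# string-keyed helpers above map one dtype name consistently across numpy,
# torch, and TensorRT. Every name referenced below is defined in this file.
#
#   str_dtype_to_np('float16')      # -> np.float16
#   str_dtype_to_torch('bfloat16')  # -> torch.bfloat16
#   str_dtype_to_trt('fp8')         # -> trt.fp8
#   trt_dtype_to_str(trt.float32)   # -> 'float32'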
_np_to_trt_dtype_dict = {
    np.int8: trt.int8,
    np.int32: trt.int32,
    np.int64: trt.int64,
    np.float16: trt.float16,
    np.float32: trt.float32,
    np.bool_: trt.bool,

    # hash of np.dtype('int32') != np.int32, so both spellings must be keys
    np.dtype('int8'): trt.int8,
    np.dtype('int32'): trt.int32,
    np.dtype('int64'): trt.int64,
    np.dtype('float16'): trt.float16,
    np.dtype('float32'): trt.float32,
    np.dtype('bool'): trt.bool,

    np_bfloat16: trt.bfloat16,
}


def np_dtype_to_trt(dtype):
    ret = _np_to_trt_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_trt_to_np_dtype_dict = {
    trt.int8: np.int8,
    trt.int32: np.int32,
    trt.int64: np.int64,
    trt.float16: np.float16,
    trt.float32: np.float32,
    trt.bool: np.bool_,
    trt.bfloat16: np_bfloat16,
}


def trt_dtype_to_np(dtype):
    ret = _trt_to_np_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_torch_to_np_dtype_dict = {
    torch.bool: np.bool_,
    torch.uint8: np.uint8,
    torch.int8: np.int8,
    torch.int16: np.int16,
    torch.int32: np.int32,
    torch.int64: np.int64,
    torch.float16: np.float16,
    torch.bfloat16: np_bfloat16,
    torch.float32: np.float32,
    torch.float64: np.float64,
    torch.complex64: np.complex64,
    torch.complex128: np.complex128,
}


def torch_dtype_to_np(dtype):
    ret = _torch_to_np_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_trt_to_torch_dtype_dict = {
    trt.float16: torch.float16,
    trt.float32: torch.float32,
    trt.int64: torch.int64,
    trt.int32: torch.int32,
    trt.int8: torch.int8,
    trt.bool: torch.bool,
    trt.bfloat16: torch.bfloat16,
}


def trt_dtype_to_torch(dtype):
    ret = _trt_to_torch_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


def dim_to_trt_axes(dim):
    """Converts a torch dim, or tuple of dims, to a TensorRT axes bitmask."""
    if not isinstance(dim, tuple):
        dim = (dim, )

    # create axes bitmask for reduce layer
    axes = 0
    for d in dim:
        axes |= 1 << d

    return axes


def trt_axes_to_dim(axes: int) -> List[int]:
    """Converts a TensorRT axes bitmask back to a list of dims."""
    dim = []
    for i in range(32):
        if axes & (1 << i):
            dim.append(i)
    return dim


def dim_resolve_negative(dim, ndim):
    """Resolves negative dims (e.g. -1) to their positive equivalents."""
    if not isinstance(dim, tuple):
        dim = (dim, )
    pos = []
    for d in dim:
        if d < 0:
            d = ndim + d
        pos.append(d)
    return tuple(pos)


def mpi_comm():
    from mpi4py import MPI
    return MPI.COMM_WORLD


def mpi_rank():
    return mpi_comm().Get_rank()


def mpi_world_size():
    return mpi_comm().Get_size()


def mpi_barrier():
    mpi_comm().Barrier()


def pad_vocab_size(vocab_size, tp_size):
    """Pads vocab_size up to the nearest multiple of tp_size."""
    return int(math.ceil(vocab_size / tp_size) * tp_size)


def to_dict(obj):
    return copy.deepcopy(obj.__dict__)


def to_json_string(obj):
    if not isinstance(obj, dict):
        obj = to_dict(obj)
    return json.dumps(obj, indent=2, sort_keys=True) + "\n"


def to_json_file(obj, json_file_path):
    with open(json_file_path, "w", encoding="utf-8") as writer:
        writer.write(to_json_string(obj))


def numpy_fp32_to_bf16(src):
    # Numpy doesn't support bfloat16 type
    # Convert float32 to bfloat16 manually and assign with bf16 abstract type
    original_shape = src.shape
    src = src.flatten()
    src = np.ascontiguousarray(src)

    assert src.dtype == np.float32
    dst = np.empty_like(src, dtype=np.uint16)
    for i in range(len(dst)):
        # Keep the upper two bytes of the little-endian float32 encoding,
        # i.e. truncate the mantissa to obtain the bfloat16 bit pattern.
        bytes = struct.pack('<f', src[i])
        dst[i] = struct.unpack('<H', bytes[2:])[0]
    return dst.view(np_bfloat16).reshape(original_shape)
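
# A minimal smoke test of the bfloat16 helpers. This guarded block is an
# illustrative sketch, not part of the library; it assumes torch and numpy
# are importable and exercises only functions defined above. The fp32 test
# values are chosen to be exactly representable in bfloat16, so truncation
# in numpy_fp32_to_bf16 and round-to-nearest in torch agree.
if __name__ == '__main__':
    t = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.bfloat16)
    a = torch_to_numpy(t)
    # The numpy array carries the abstract 'V2' dtype tagged as bfloat16.
    assert a.dtype == np_bfloat16 and a.shape == (2, 2)
    # Round-tripping back to torch preserves the values bit-exactly.
    assert torch.equal(numpy_to_torch(a), t)

    src = np.array([1.0, -2.0, 0.5], dtype=np.float32)
    dst = numpy_fp32_to_bf16(src)
    assert torch.equal(numpy_to_torch(dst),
                       torch.tensor(src).to(torch.bfloat16))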