# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import math
import struct
from functools import partial
from pathlib import Path, PosixPath

import numpy as np

# isort: off
import torch
import tensorrt as trt
# isort: on

# numpy doesn't know bfloat16, define abstract binary type instead
np_bfloat16 = np.dtype('V2', metadata={"dtype": "bfloat16"})


def torch_to_numpy(x: torch.Tensor):
    assert isinstance(x, torch.Tensor), \
        f'x must be a torch.Tensor object, but got {type(x)}.'
    if x.dtype != torch.bfloat16:
        return x.detach().cpu().numpy()
    return x.view(torch.int16).detach().cpu().numpy().view(np_bfloat16)


def numpy_to_torch(x):
    if x.dtype != np_bfloat16:
        return torch.tensor(x)
    return torch.tensor(x.view(np.int16)).view(torch.bfloat16)


def numpy_to_dtype(x, dtype: str):
    if x.dtype == np_bfloat16:
        # BF16 --> non-BF16 or BF16
        if dtype != 'bfloat16':
            return torch_to_numpy(
                numpy_to_torch(x).to(str_dtype_to_torch(dtype)))
        else:
            return x
    else:
        # non-BF16 types --> non-BF16 or BF16
        if dtype != 'bfloat16':
            return x.astype(str_dtype_to_np(dtype))
        else:
            return torch_to_numpy(torch.from_numpy(x).to(torch.bfloat16))


fp32_array = partial(np.array, dtype=np.float32)
fp16_array = partial(np.array, dtype=np.float16)
int32_array = partial(np.array, dtype=np.int32)


def bf16_array(x):
    x = torch.tensor(x, dtype=torch.bfloat16)
    x = torch_to_numpy(x)
    return x


def trt_version():
    return trt.__version__


def torch_version():
    return torch.__version__


_str_to_np_dict = dict(
    float16=np.float16,
    float32=np.float32,
    int32=np.int32,
    bfloat16=np_bfloat16,
)


def str_dtype_to_np(dtype):
    ret = _str_to_np_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_str_to_torch_dtype_dict = dict(
    bfloat16=torch.bfloat16,
    float16=torch.float16,
    float32=torch.float32,
    int32=torch.int32,
    int8=torch.int8,
)


def str_dtype_to_torch(dtype):
    ret = _str_to_torch_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_str_to_trt_dtype_dict = dict(float16=trt.float16,
                              float32=trt.float32,
                              int64=trt.int64,
                              int32=trt.int32,
                              int8=trt.int8,
                              bool=trt.bool,
                              bfloat16=trt.bfloat16,
                              fp8=trt.fp8)


def str_dtype_to_trt(dtype):
    ret = _str_to_trt_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_np_to_trt_dtype_dict = {
    np.int8: trt.int8,
    np.int32: trt.int32,
    np.float16: trt.float16,
    np.float32: trt.float32,

    # hash of np.dtype('int32') != np.int32
    np.dtype('int8'): trt.int8,
    np.dtype('int32'): trt.int32,
    np.dtype('float16'): trt.float16,
    np.dtype('float32'): trt.float32,
    np_bfloat16: trt.bfloat16,
    np.bool_: trt.bool,
}


def np_dtype_to_trt(dtype):
    ret = _np_to_trt_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_trt_to_np_dtype_dict = {
    trt.int8: np.int8,
    trt.int32: np.int32,
    trt.float16: np.float16,
    trt.float32: np.float32,
    trt.bool: np.bool_,
    trt.bfloat16: np_bfloat16,
}
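# Illustrative sketch of the dtype mapping helpers above (values are what one
# would expect from a TensorRT build exposing these dtypes; names used here
# are only the helpers defined in this module):
#
#   str_dtype_to_trt('float16')         # -> trt.float16
#   np_dtype_to_trt(np.dtype('int32'))  # -> trt.int32
#   str_dtype_to_np('bfloat16')         # -> the abstract np_bfloat16 ('V2') dtype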
def trt_dtype_to_np(dtype):
    ret = _trt_to_np_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_torch_to_np_dtype_dict = {
    torch.float16: np.float16,
    torch.float32: np.float32,
}


def torch_dtype_to_np(dtype):
    ret = _torch_to_np_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_trt_to_torch_dtype_dict = {
    trt.float16: torch.float16,
    trt.float32: torch.float32,
    trt.int32: torch.int32,
    trt.int8: torch.int8,
    trt.bfloat16: torch.bfloat16
}


def trt_dtype_to_torch(dtype):
    ret = _trt_to_torch_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


def dim_to_trt_axes(dim):
    """Converts torch dim, or tuple of dims to a tensorrt axes bitmask"""
    if not isinstance(dim, tuple):
        dim = (dim, )

    # create axes bitmask for reduce layer
    axes = 0
    for d in dim:
        axes |= 1 << d

    return axes


def dim_resolve_negative(dim, ndim):
    if not isinstance(dim, tuple):
        dim = (dim, )
    pos = []
    for d in dim:
        if d < 0:
            d = ndim + d
        pos.append(d)
    return tuple(pos)


def mpi_comm():
    from mpi4py import MPI
    return MPI.COMM_WORLD


def mpi_rank():
    return mpi_comm().Get_rank()


def mpi_world_size():
    return mpi_comm().Get_size()


def pad_vocab_size(vocab_size, tp_size):
    return int(math.ceil(vocab_size / tp_size) * tp_size)


def to_dict(obj):
    return copy.deepcopy(obj.__dict__)


def to_json_string(obj):
    if not isinstance(obj, dict):
        obj = to_dict(obj)
    return json.dumps(obj, indent=2, sort_keys=True) + "\n"


def to_json_file(obj, json_file_path):
    with open(json_file_path, "w", encoding="utf-8") as writer:
        writer.write(to_json_string(obj))


def numpy_fp32_to_bf16(src):
    # Numpy doesn't support bfloat16 type
    # Convert float32 to bfloat16 manually and assign with bf16 abstract type
    original_shape = src.shape
    src = src.flatten()
    src = np.ascontiguousarray(src)
    assert src.dtype == np.float32
    dst = np.empty_like(src, dtype=np.uint16)
    for i in range(len(dst)):
        # Keep the high-order two bytes of each little-endian float32 word,
        # i.e. truncate the mantissa to obtain the bfloat16 bit pattern.
        bytes = struct.pack('<f', src[i])
        dst[i] = struct.unpack('<H', bytes[2:])[0]
    return dst.reshape(original_shape).view(np_bfloat16)
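# Minimal usage sketch of the bfloat16 helpers above. Values and variable
# names are illustrative only; it assumes torch and numpy are installed and
# runs only when this module is executed directly.
if __name__ == '__main__':
    t = torch.randn(2, 3).to(torch.bfloat16)
    a = torch_to_numpy(t)  # numpy array tagged with the abstract np_bfloat16 dtype
    assert a.dtype == np_bfloat16
    assert torch.equal(numpy_to_torch(a), t)  # bit-exact round trip

    f32 = np.ones((2, 2), dtype=np.float32)
    bf = numpy_fp32_to_bf16(f32)  # manual float32 -> bfloat16 truncation
    back = numpy_to_dtype(bf, 'float32')  # widen back to float32 via torch
    assert np.allclose(back, f32)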