mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
fix: https://nvbugs/5234033 enable starcoder trt-flow with transformers 4.51.3. Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
275 lines
11 KiB
Python
275 lines
11 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import math
|
|
import weakref
|
|
from typing import Optional, Sequence, Union
|
|
|
|
import numpy as np
|
|
|
|
# isort: off
|
|
import torch
|
|
import tensorrt as trt
|
|
# isort: on
|
|
|
|
from ._common import default_net
|
|
from ._utils import (copy_torch_to_numpy, np_dtype_to_trt, str_dtype_to_trt,
|
|
torch_to_numpy, trt_dtype_to_np, trt_dtype_to_torch)
|
|
from .functional import Tensor, constant
|
|
from .logger import logger
|
|
from .network import Network
|
|
|
|
|
|
class Parameter:
    """A weight or buffer that is turned into a network tensor on demand.

    A ``Parameter`` holds either a host value (always stored as an
    ``np.ndarray``; torch tensors are converted on assignment) or only a
    shape/dtype. When it is used in a network it becomes either:

    * a managed tensor (``get_managed_tensor``) when the parameter was created
      with ``prefer_managed=True`` and the network's
      ``plugin_config.manage_weights`` is enabled, or
    * a constant tensor (``get_constant_tensor``) otherwise.

    The tensor created for a network is cached through
    ``network.get_parameter_tensor`` / ``network.set_parameter_tensor``.
    """

    # dtype used (with a warning) when the caller does not pass one.
    _DEFAULT_DTYPE = trt.DataType.FLOAT

    def __init__(self,
                 value: Optional[Union[np.ndarray, torch.Tensor]] = None,
                 shape: Sequence[int] = None,
                 dtype: Union[str, trt.DataType] = None,
                 is_buffer: bool = False,
                 prefer_managed=False):
        """Create a parameter from a value, or from a shape when no value is given.

        Args:
            value: Initial value. Torch tensors are converted to numpy with
                ``_regularize_value``. When None, ``shape`` is required and the
                value stays uninitialized until ``raw_value`` or the ``value``
                setter is used.
            shape: Shape used when ``value`` is None; must be a list or tuple.
                Ignored when ``value`` is given (the value's shape is used).
            dtype: TensorRT dtype or its string name. Falls back to
                ``_DEFAULT_DTYPE`` with a warning when omitted.
            is_buffer: Stored as ``self.is_buffer``; not interpreted here.
            prefer_managed: Ask for a managed tensor instead of a constant when
                the network enables weight management (see ``is_managed``).
        """
        if dtype is None:
            logger.warning(
                f'Parameter dtype is None, using default dtype: {self._DEFAULT_DTYPE}, it is recommended to always specify dtype explicitly'
            )
        dtype = self._DEFAULT_DTYPE if dtype is None else dtype
        if isinstance(dtype, str):
            dtype = str_dtype_to_trt(dtype)
        self._dtype: trt.DataType = dtype
        if value is None:
            assert isinstance(shape, (
                list,
                tuple)), f"shape must be list or tuple, receive {(type(shape))}"
            self._shape = tuple(shape)
            self._value = None
        else:
            self._value = self._regularize_value(value)
            # Take the shape from the converted array so it always matches the
            # stored data (for a sharded DTensor the stored data is the local
            # shard, whose shape can differ from the DTensor's global shape).
            self._shape = self._value.shape
        self.is_buffer = is_buffer
        self._prefer_managed = prefer_managed
        # Cached tensor and a weak reference to the network it belongs to.
        self._tensor: Tensor = None
        self._network: weakref.ref = None
        self._name = None
        self.need_transpose = False

    @property
    def shape(self):
        """Declared shape of the parameter."""
        return self._shape

    @property
    def dtype(self):
        """TensorRT dtype of the parameter."""
        return self._dtype

    @property
    def name(self):
        """Name assigned when a tensor was created or ``set_name`` was called."""
        return self._name

    def _create_managed_tensor(self, network) -> Tensor:
        """Create a named managed tensor for this parameter in ``network``.

        The name is derived from the current number of network inputs. When
        there is no usable host value (None, or an ndarray that is not
        C-contiguous), a fresh C-contiguous buffer of ``self._shape`` /
        ``self._dtype`` replaces it and is registered with the network via
        ``_register_unfilled_weights`` together with the previous value.
        """
        num = len(network._inputs)
        self._name = f"managed_constant_{num}"

        if self._value is None or (isinstance(self._value, np.ndarray)
                                   and not self._value.flags['C_CONTIGUOUS']):
            value_old = self._value
            # The replacement buffer is shaped by self._shape; the old value
            # (possibly None) is handed to the network alongside it —
            # presumably so the network fills the buffer later (see Network).
            self._value = np.empty(self._shape, trt_dtype_to_np(self._dtype))
            network._register_unfilled_weights(self._name, self._value,
                                               value_old)
        return Tensor(name=self._name, dtype=self._dtype, shape=self._shape)

    def get_managed_tensor(self, network: Network) -> Tensor:
        """Return the managed tensor for ``network``, creating it on first use.

        The tensor is looked up in (and, when newly created, registered with)
        the network's parameter-tensor cache whenever the network differs from
        the one seen last.
        """
        if self._network is None or self._network() != network:
            self._network = weakref.ref(network)
            self._tensor = network.get_parameter_tensor(self)
            if self._tensor is None:
                self._tensor = self._create_managed_tensor(network)
                network.set_parameter_tensor(self, self._tensor)
        return self._tensor

    def _create_constant_tensor(self) -> Tensor:
        """Build a constant tensor in the default network for this parameter.

        * C-contiguous ndarray value: wrapped directly with ``constant`` and
          ``self._value`` is replaced by the resulting tensor. For FP4/FP8
          parameters held in an integer container (uint8/int8/int32/int64),
          the parameter's dtype and shape are passed to ``constant`` so the
          data is reinterpreted as that type.
        * None or non-contiguous ndarray: an empty placeholder array is used
          to build the constant and is registered as unfilled weights, keyed
          by the producing layer's name, together with the current value.
        """
        if (self._value is not None and isinstance(self._value, np.ndarray)
                and self._value.flags['C_CONTIGUOUS']):
            lower_type = None
            lower_shape = None
            # workaround for reinterpreted data type
            dtype = self._value.dtype
            if (self.dtype == trt.fp4 or self.dtype
                    == trt.fp8) and (dtype == np.uint8 or dtype == np.int8
                                     or dtype == np.int32 or dtype == np.int64):
                lower_type = self.dtype
                lower_shape = self.shape

            self._value = constant(self._value, lower_type, lower_shape)
            return self._value
        elif self._value is None or isinstance(self._value, np.ndarray):
            if self._dtype == trt.fp4:
                # The placeholder is int64, so each element covers 16 4-bit
                # values of the last dimension.
                shape = list(self._shape)
                assert shape[
                    -1] % 16 == 0, "For FP4, the last dimension of the shape should be multiple of 16"
                shape[-1] = shape[-1] // 16
                dtype = np.int64
            else:
                shape = self._shape
                dtype = trt_dtype_to_np(self._dtype)
            ndarray = np.empty(shape, dtype)
            tensor = constant(ndarray, self._dtype, self._shape)
            default_net()._register_unfilled_weights(tensor.producer.name,
                                                     ndarray, self._value)
            return tensor
        # NOTE(review): if self._value is already a Tensor (a constant built
        # earlier, possibly for another network) neither branch applies and
        # None is returned; callers then fail on `.producer`.

    def get_constant_tensor(self, network: Network) -> Tensor:
        """Return the constant tensor for ``network``, creating it on first use.

        On creation the parameter's name is set to the producing layer's name
        and the tensor is registered in the network's parameter-tensor cache.
        """
        if self._network is None or self._network() != network:
            self._network = weakref.ref(network)
            self._tensor = network.get_parameter_tensor(self)
            if self._tensor is None:
                self._tensor = self._create_constant_tensor()
                self._name = self._tensor.producer.name
                network.set_parameter_tensor(self, self._tensor)
        return self._tensor

    def get_tensor(self, network) -> Tensor:
        """Return the managed or constant tensor, depending on ``is_managed``."""
        if self.is_managed(network):
            return self.get_managed_tensor(network)
        else:
            return self.get_constant_tensor(network)

    def is_managed(self, network):
        """Whether this parameter becomes a managed tensor in ``network``.

        ``network`` defaults to ``default_net()`` when None. True only when
        the parameter prefers management and the network's
        ``plugin_config.manage_weights`` is set.
        """
        if network is None:
            network = default_net()
        return self._prefer_managed and network.plugin_config.manage_weights

    @property
    def value(self) -> Tensor:
        """Network tensor for this parameter in ``default_net()``."""
        return self.get_tensor(default_net())

    @value.setter
    def value(self, v: Union[np.ndarray, torch.Tensor]):
        """Replace the host value, validating shape and reconciling dtype.

        A 0-d value is reshaped to the declared shape when every declared
        dimension is 1. FP4 values are checked against the packed last
        dimension; all other values must match the declared shape exactly.
        FP4/FP8 parameters keep their dtype when given integer-container data;
        otherwise the parameter adopts the value's dtype (warning on change).
        """
        v = self._regularize_value(v)

        if v.shape != self.shape and v.ndim == 0 and max(self.shape) == 1:
            # convert the scalar into a tensor which each dim is 1.
            v = v.reshape(self.shape)

        if self.dtype == trt.fp4:
            # Two FP4 values per byte, grouped into elements of v's itemsize.
            assert v.shape[:-1] == self.shape[:-1] and v.shape[-1] == self.shape[-1] // 2 // v.dtype.itemsize, \
                f'For FP4, the shape of the value should be the same as the original shape, ' \
                f'except the last dimension should be half of the original shape. ' \
                f'Updated: {v.shape}, original: {self.shape}'
        else:
            assert v.shape == self.shape, \
                f'The value updated is not the same shape as the original. ' \
                f'Updated: {v.shape}, original: {self.shape}'
        if (self.dtype == trt.fp4 or self.dtype
                == trt.fp8) and (v.dtype == np.int8 or v.dtype == np.uint8
                                 or v.dtype == np.int32 or v.dtype == np.int64):
            # Integer container for FP4/FP8 data: keep the declared dtype.
            pass
        else:
            dtype = np_dtype_to_trt(v.dtype)
            if self.dtype != dtype:
                logger.warning(
                    f"Parameter was initialized as {self.dtype} but set to {dtype}"
                )
                self._dtype = dtype
        self._value = v

    @classmethod
    def xavier_init(cls, weights: np.ndarray):
        """Fill ``weights`` in place with random values generated on CUDA.

        2-D arrays use the Xavier/Glorot uniform bound
        ``sqrt(6) / sqrt(shape[0] + shape[1])``; other ranks use 0.1.
        INT8/INT32/INT64 draw integers scaled by the dtype's magnitude and
        clamped to the representable range; FP8 and other dtypes draw uniform
        floats in ``[-v_range, v_range)``.
        """
        shape = weights.shape
        dtype = np_dtype_to_trt(weights.dtype)
        if len(shape) == 2:
            # Xavier initialization see https://paperswithcode.com/method/xavier-initialization
            v_range = math.sqrt(6) / math.sqrt(shape[0] + shape[1])
        else:
            v_range = 0.1

        if dtype in (trt.DataType.INT8, trt.DataType.INT32,
                     trt.DataType.INT64):
            range_map = {
                trt.DataType.INT8: 128,
                trt.DataType.INT32: 2**31,
                trt.DataType.INT64: 2**63
            }
            dtype_range = range_map[dtype]
            # v_range exceeds 1 for small 2-D shapes (e.g. sqrt(3) for 1x1),
            # which would push the bounds outside the integer dtype and make
            # torch.randint fail. Clamp to the representable range; the "- 1"
            # also keeps the INT64 bound inside a signed 64-bit integer.
            upper = min(math.ceil(dtype_range * v_range), dtype_range - 1)
            # value ~ U{-upper, ..., upper - 1}
            value = torch.randint(-upper,
                                  upper,
                                  shape,
                                  dtype=trt_dtype_to_torch(dtype),
                                  device='cuda')
        elif dtype == trt.DataType.FP8:
            value = torch.rand(shape, device='cuda') * 2 - 1
            # value ~ U[-v_range, v_range]
            value = value * v_range
            value = value.to(trt_dtype_to_torch(dtype))
        else:
            value = torch.rand(
                shape, dtype=trt_dtype_to_torch(dtype), device='cuda') * 2 - 1
            # value ~ U[-v_range, v_range]
            value = value * v_range

        copy_torch_to_numpy(value, weights)

    def is_inited(self) -> bool:
        """Whether a value (host array or built tensor) has been set."""
        return self._value is not None

    @property
    def raw_value(self) -> np.ndarray:
        """Host ndarray for this parameter, randomly initialized if unset.

        Fails the assertion once the value has been replaced by a network
        tensor (see ``_create_constant_tensor``).
        """
        if self._value is None:
            dtype = trt_dtype_to_np(self.dtype)
            self._value = np.empty(self.shape, dtype)
            Parameter.xavier_init(self._value)
        assert isinstance(
            self._value, np.ndarray
        ), "Must be np.ndarray. Proper usage: get parameter.raw_value before getting parameter.value"
        return self._value

    def set_value_or_dummy(self, v: Union[np.ndarray, torch.Tensor]):
        """Set ``v`` as the value, or an uninitialized array if shapes differ.

        When ``v``'s shape does not match the declared shape, an ``np.empty``
        array of the declared shape and dtype is stored instead.
        """
        v = self._regularize_value(v)
        if v.shape != self._shape:
            self.value = np.empty(self._shape, trt_dtype_to_np(self._dtype))
            return

        self.value = v

    def set_name(self, name: str, network):
        """Record ``name`` and apply it to the weights in ``network``.

        Managed parameters rename their tensor and return True; otherwise the
        result of ``trt_network.set_weights_name`` is returned.
        """
        self._name = name
        if self.is_managed(network):
            self._get_weights(network).name = name
            return True
        else:
            return network.trt_network.set_weights_name(
                self._get_weights(network), name)

    def _get_weights(self, network) -> Optional[Union[trt.Weights, Tensor]]:
        """Return this parameter's cached tensor (managed) or constant weights.

        Returns None when no tensor has been created in ``network`` yet.
        """
        # The annotation uses Optional/Union rather than `A | B | None`:
        # annotations here are evaluated at definition time, and `|` between
        # classes requires Python 3.10+.
        tensor = network.get_parameter_tensor(self)
        if self.is_managed(network):
            return tensor
        elif tensor is not None:
            # Downcast the generic producer layer object so that the
            # constant layer's `weights` attribute is accessible.
            tensor.producer.__class__ = trt.IConstantLayer
            return tensor.producer.weights
        else:
            return None

    @staticmethod
    def _dtensor_class():
        """Return torch's public ``DTensor`` class, or None if unavailable."""
        # Imported explicitly: `torch.distributed.tensor.DTensor` is not
        # present on every torch version/build, and attribute access on an
        # unimported or missing submodule raises AttributeError.
        try:
            from torch.distributed.tensor import DTensor
        except ImportError:
            return None
        return DTensor

    def _regularize_value(self, value):
        """Convert ``value`` to a host ``np.ndarray``.

        ``np.ndarray`` is returned unchanged. A torch ``DTensor`` is reduced
        to its local shard first; every torch tensor is then converted with
        ``torch_to_numpy``.

        Raises:
            TypeError: if ``value`` is neither an ndarray nor a torch tensor.
        """
        if isinstance(value, np.ndarray):
            return value

        if isinstance(value, torch.Tensor):
            dtensor_cls = self._dtensor_class()
            if dtensor_cls is not None and isinstance(value, dtensor_cls):
                # Convert the local shard through the same path as plain
                # tensors; presumably torch_to_numpy also covers dtypes that
                # `.numpy()` rejects (e.g. bfloat16) — confirm in _utils.
                value = value.to_local()
            return torch_to_numpy(value)

        raise TypeError(
            f'Expected numpy.ndarray or torch.Tensor, got {type(value)}')
|