TensorRT-LLMs/tensorrt_llm/parameter.py
石晓伟 59f41c067d
Update TensorRT-LLM (#708)
* Update TensorRT-LLM

* update

* Bump version to 0.7.0
2023-12-20 16:38:28 +08:00

104 lines
3.9 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Sequence, Union
import numpy as np
# isort: off
import torch
import tensorrt as trt
# isort: on
from ._utils import str_dtype_to_trt, torch_to_numpy, trt_dtype_to_torch
from .functional import Tensor, constant
from .logger import logger
class Parameter:
    """A network weight that is held host-side until it is first used.

    The weights are stored as a ``numpy.ndarray`` (torch tensors are
    converted on assignment). The first read of :attr:`value` wraps the
    array with ``constant(...)`` and caches the resulting :class:`Tensor`
    in its place; from then on :attr:`raw_value` is unavailable until a new
    value is assigned.
    """

    # dtype used when the caller does not pass one (a warning is logged).
    _DEFAULT_DTYPE = trt.DataType.FLOAT

    def __init__(self,
                 value: Optional[Union[np.ndarray, torch.Tensor]] = None,
                 shape: Optional[Sequence[int]] = None,
                 dtype: Optional[Union[str, trt.DataType]] = None):
        """Create a parameter from ``value``, or randomly initialise one.

        Args:
            value: Initial weights. When None, random weights of ``shape``
                and ``dtype`` are generated on the CUDA device.
            shape: Required (as a list or tuple) when ``value`` is None.
            dtype: TensorRT dtype or its string name; defaults to
                ``_DEFAULT_DTYPE`` with a warning. It is only consumed when
                generating random weights - a provided ``value`` keeps its
                own dtype.

        Raises:
            AssertionError: if ``value`` is None and ``shape`` is not a
                list or tuple.
            TypeError: if ``value`` is neither a numpy array nor a torch
                tensor.
        """
        if dtype is None:
            logger.warning(
                f'Parameter dtype is None, using default dtype: {self._DEFAULT_DTYPE}, it is recommended to always specify dtype explicitly'
            )
            dtype = self._DEFAULT_DTYPE
        if isinstance(dtype, str):
            dtype = str_dtype_to_trt(dtype)
        if value is None:
            assert isinstance(shape, (list, tuple)), \
                f'shape must be a list or tuple when value is None, got {type(shape)}'
            value = self._random_value(shape, dtype)
        self._value = self._regularize_value(value)

    @staticmethod
    def _random_value(shape: Sequence[int],
                      dtype: trt.DataType) -> torch.Tensor:
        """Generate random CUDA weights of ``shape`` and ``dtype``.

        2-D shapes use the Xavier/Glorot uniform bound
        ``sqrt(6) / sqrt(shape[0] + shape[1])``; other ranks use a fixed
        bound of 0.1.
        """
        if len(shape) == 2:
            # Xavier initialization see https://paperswithcode.com/method/xavier-initialization
            v_range = math.sqrt(6) / math.sqrt(shape[0] + shape[1])
        else:
            v_range = 0.1
        torch_dtype = trt_dtype_to_torch(dtype)
        if dtype == trt.DataType.INT8:
            # Scale the bound onto the int8 grid, but never past 128:
            # torch.randint rejects bounds that do not fit the target dtype,
            # and small 2-D shapes (shape[0] + shape[1] < 6) give v_range > 1.
            upper = min(math.ceil(128 * v_range), 128)
            # value ~ U{-upper, ..., upper - 1} (randint excludes `high`)
            return torch.randint(-upper,
                                 upper,
                                 tuple(shape),
                                 dtype=torch_dtype,
                                 device='cuda')
        # torch.rand samples U[0, 1); `* 2 - 1` maps it onto U[-1, 1).
        # It must be a uniform sample: the Xavier bound above is the bound
        # of a uniform distribution, and a normal sample would be shifted
        # to mean -1 and unbounded.
        value = torch.rand(tuple(shape), dtype=torch_dtype,
                           device='cuda') * 2 - 1
        # value ~ U[-v_range, v_range)
        return value * v_range

    @property
    def value(self) -> Tensor:
        """The parameter as a network :class:`Tensor`.

        On first access the stored ndarray is wrapped with ``constant`` and
        the resulting Tensor replaces it, so later reads return that same
        object.
        """
        if isinstance(self._value, np.ndarray):
            self._value = constant(self._value)
        return self._value

    @value.setter
    def value(self, v: Union[np.ndarray, torch.Tensor]):
        """Replace the stored weights with ``v`` (converted to an ndarray).

        Raises AssertionError if the shape differs from the current value;
        a dtype difference only logs a warning.
        """
        # NOTE(review): if `value` was already read, self._value is a Tensor
        # here and shape/dtype are compared against the Tensor's attributes
        # rather than an ndarray's - confirm that comparison is meaningful.
        v = self._regularize_value(v)
        assert v.shape == self._value.shape, \
            f'The value updated is not the same shape as the original. ' \
            f'Updated: {v.shape}, original: {self._value.shape}'
        if self._value.dtype != v.dtype:
            logger.warning(
                f"Parameter was initialized as {self._value.dtype} but set to {v.dtype}"
            )
        self._value = v

    @property
    def raw_value(self) -> np.ndarray:
        """The stored ndarray.

        Only available while the value is still an ndarray, i.e. before
        :attr:`value` is read (or after a new value has been assigned).
        """
        assert isinstance(
            self._value, np.ndarray
        ), "Must be np.ndarray. Proper usage: get parameter.raw_value before getting parameter.value"
        return self._value

    def _get_weights(self) -> Optional[trt.Weights]:
        """Return the weights of the layer that produced the cached Tensor.

        Returns None while the value has not yet been materialised through
        :attr:`value`.
        """
        # NOTE(review): presumably `producer` is the constant layer created
        # in `value`; confirm against `functional.constant`.
        if isinstance(self._value, Tensor):
            return self._value.producer.weights
        return None

    def _regularize_value(self, value):
        """Return ``value`` as a numpy array.

        ndarrays are returned unchanged; torch tensors go through
        ``torch_to_numpy``. Anything else raises TypeError.
        """
        if isinstance(value, np.ndarray):
            return value
        elif isinstance(value, torch.Tensor):
            return torch_to_numpy(value)
        raise TypeError(
            f'Expected numpy.ndarray or torch.Tensor, got {type(value)}')