# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Dict, Optional

import numpy as np
import torch

from tensorrt_llm.models.modeling_utils import PretrainedConfig

from ..._utils import numpy_to_torch
from ...logger import logger
from ...mapping import Mapping

Weights = Dict[str, torch.Tensor]

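# Illustrative note (not part of the original file): a Weights dict maps
# TensorRT-LLM parameter names to tensors, e.g.
#
#   example_weights: Weights = {
#       "transformer.layers.0.attention.qkv.weight": torch.empty(768, 256),
#       "transformer.layers.0.mlp.fc.weight": torch.empty(1024, 256),
#   }
#
# The quantized Gemma checkpoint read by load_from_fp8_gemma() below is an .npz
# archive keyed by names such as
# "_np:layers:0:attention:dense:weights_scaling_factor".
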
def quantize_fp8_weights(weights: Weights, num_layers: int,
                         mapping: Mapping) -> Weights:
    """Compute per-tensor FP8 weight scaling factors for the attention and MLP
    projections of the layers owned by this pipeline-parallel rank.

    Only the scaling factors are returned; the rescaled weights are computed
    but not propagated to the caller.
    """

    def get_scaling_factor(weight):
        # 448.0 is the largest value representable in FP8 E4M3, so scaling by
        # 448 / amax maps the weight's maximum onto the FP8 dynamic range.
        amax = weight.max()
        scale = 448.0 / amax
        return scale

    layers_range = mapping.pp_layers(num_layers)
    scaling_factors = {}
    scaled_weights = {}
    trt_llm_prefix = "transformer.layers"
    for l in layers_range:
        # Scale the attention and MLP projection weights of layer l.
        for name in [
                "attention.qkv", "attention.dense", "mlp.fc", "mlp.gate",
                "mlp.proj"
        ]:
            trt_llm_name = ".".join((trt_llm_prefix, str(l), name, "weight"))
            scale_name = ".".join(
                (trt_llm_prefix, str(l), name, "weights_scaling_factor"))
            weight = weights[trt_llm_name].float()
            dtype = weights[trt_llm_name].dtype
            scale = get_scaling_factor(weight)
            scaled_weights[trt_llm_name] = (weight *
                                            scale).to(dtype).contiguous()
            # Store the inverse scale under the weights_scaling_factor name.
            scaling_factors[scale_name] = numpy_to_torch(
                np.asarray([1 / scale]).astype(np.float32))
    return scaling_factors

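# Hedged usage sketch (illustrative, not from the original file): build toy FP16
# weights for a single layer on one rank and compute their FP8 scaling factors.
# The names "one_rank" and "toy" and the 1-GPU Mapping arguments are assumptions.
#
#   one_rank = Mapping(world_size=1, rank=0, tp_size=1, pp_size=1)
#   toy = {
#       f"transformer.layers.0.{proj}.weight": torch.rand(64, 64, dtype=torch.float16)
#       for proj in ("attention.qkv", "attention.dense", "mlp.fc", "mlp.gate", "mlp.proj")
#   }
#   sf = quantize_fp8_weights(toy, num_layers=1, mapping=one_rank)
#   # sf["transformer.layers.0.mlp.fc.weights_scaling_factor"] is roughly max(weight) / 448.
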
def load_from_fp8_gemma(quant_ckpt_path: Optional[str], num_layers: int,
                        mapping: Mapping, fp8_kv_cache: bool,
                        weight_scales: Weights):
    """
    Get the FP8 scaling factors from a quantized Gemma checkpoint (.npz),
    falling back to dummy scaling factors of 1.0 when no checkpoint is available.
    """
    fake_fp8_sf_dt = torch.float32

    if quant_ckpt_path is not None and os.path.isfile(quant_ckpt_path):
        fp8_gemma = np.load(quant_ckpt_path)
    else:
        fp8_gemma = None
        logger.info(
            "There is no quantized checkpoint; using dummy fp8 scaling factors instead."
        )
    weights = {}

    def get_fp8_gemma(name: str) -> np.ndarray:
        # Read a scaling factor from the checkpoint, or return 1.0 if none was loaded.
        if fp8_gemma is not None:
            return fp8_gemma[name]
        else:
            return torch.tensor([1.0], dtype=fake_fp8_sf_dt).numpy()

    layers_range = mapping.pp_layers(num_layers)
    for l in layers_range:
        prefix = f'_np:layers:{l}'
        tllm_prex = f'transformer.layers.{l-layers_range[0]}'

        # The fused QKV projection uses the maximum of the per-projection
        # (q, k, v) scaling factors.
        weights[f'{tllm_prex}.attention.qkv.activation_scaling_factor'] = max(
            get_fp8_gemma(
                f'{prefix}:attention:qkv:q:activation_scaling_factor'),
            get_fp8_gemma(
                f'{prefix}:attention:qkv:k:activation_scaling_factor'),
            get_fp8_gemma(
                f'{prefix}:attention:qkv:v:activation_scaling_factor'))
        weights[f'{tllm_prex}.attention.qkv.weights_scaling_factor'] = max(
            get_fp8_gemma(f'{prefix}:attention:qkv:q:weights_scaling_factor'),
            get_fp8_gemma(f'{prefix}:attention:qkv:k:weights_scaling_factor'),
            get_fp8_gemma(f'{prefix}:attention:qkv:v:weights_scaling_factor'))
        weights[
            f'{tllm_prex}.attention.dense.activation_scaling_factor'] = get_fp8_gemma(
                f'{prefix}:attention:dense:activation_scaling_factor')
        weights[
            f'{tllm_prex}.attention.dense.weights_scaling_factor'] = get_fp8_gemma(
                f'{prefix}:attention:dense:weights_scaling_factor')

        weights[
            f'{tllm_prex}.mlp.fc.activation_scaling_factor'] = get_fp8_gemma(
                f'{prefix}:mlp:fc:activation_scaling_factor')
        weights[f'{tllm_prex}.mlp.fc.weights_scaling_factor'] = get_fp8_gemma(
            f'{prefix}:mlp:fc:weights_scaling_factor')

        weights[
            f'{tllm_prex}.mlp.gate.activation_scaling_factor'] = get_fp8_gemma(
                f'{prefix}:mlp:gate:activation_scaling_factor')
        weights[f'{tllm_prex}.mlp.gate.weights_scaling_factor'] = get_fp8_gemma(
            f'{prefix}:mlp:gate:weights_scaling_factor')

        weights[
            f'{tllm_prex}.mlp.proj.activation_scaling_factor'] = get_fp8_gemma(
                f'{prefix}:mlp:proj:activation_scaling_factor')
        weights[f'{tllm_prex}.mlp.proj.weights_scaling_factor'] = get_fp8_gemma(
            f'{prefix}:mlp:proj:weights_scaling_factor')

        if fp8_kv_cache:
            # Not calibrating KV cache; use a scaling factor of 1.0.
            scaling_factor = 1.0
            weights[
                f'{tllm_prex}.attention.kv_cache_scaling_factor'] = torch.tensor(
                    [scaling_factor], dtype=fake_fp8_sf_dt).numpy()
    if fp8_gemma is None:
        weights.update(weight_scales)

    # Convert any remaining numpy arrays to torch tensors.
    for key in weights:
        if isinstance(weights[key], np.ndarray):
            weights[key] = numpy_to_torch(weights[key])
    return weights

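# Hedged usage sketch (illustrative, not from the original file): loading dummy FP8
# scaling factors when no quantized checkpoint exists on disk. "one_rank" and "toy"
# are the assumed placeholders from the sketch above quantize_fp8_weights's caller.
#
#   fp8_scales = quantize_fp8_weights(toy, num_layers=1, mapping=one_rank)
#   scales = load_from_fp8_gemma(quant_ckpt_path=None,
#                                num_layers=1,
#                                mapping=one_rank,
#                                fp8_kv_cache=True,
#                                weight_scales=fp8_scales)
#   # Every "*_scaling_factor" entry is a 1-element torch tensor; with no checkpoint,
#   # the activation scales default to 1.0 and the weight scales come from fp8_scales.
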
def dummy_weights_awq(weights: Weights, precision: str,
                      trt_llm_config: PretrainedConfig, group_size: int):
    """Replace the attention and MLP projection weights with dummy INT4-AWQ
    packed weights and placeholder scaling factors."""
    packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4
    use_fp8_kv_cache = trt_llm_config.quant_mode.has_fp8_kv_cache()
    use_int8_kv_cache = trt_llm_config.quant_mode.has_int8_kv_cache()
    num_layers = trt_llm_config.num_hidden_layers
    for name in list(weights):
        if any([
                _name in name for _name in [
                    'mlp.proj.weight', 'mlp.gate.weight', 'mlp.fc.weight',
                    'attention.qkv.weight', 'attention.dense.weight'
                ]
        ]):
            print("Processing:", name)
            weight = np.ascontiguousarray(weights[name].T)
            in_dim, out_dim = weight.shape
            # Per-tensor scale so the largest weight maps to the INT4 maximum (7).
            scale = np.amax(weight) / 7
            weights_scaling_factor = np.ones([out_dim, in_dim // group_size
                                              ]) * scale.astype(np.float32)
            # Quantize and clamp to the signed INT4 range [-8, 7].
            weight_smoothed = (weight.astype(np.float32) / scale).astype(
                np.int8)
            weight_smoothed[weight_smoothed < -8] = -8
            weight_smoothed[weight_smoothed > 7] = 7
            prequant_scaling_factor = np.ones([in_dim], dtype=weight.dtype)
            # Pack the int8-stored INT4 values two per byte.
            weights[name] = packer(
                torch.from_numpy(weight_smoothed)).T.contiguous().numpy()
            weights[name.replace(
                'weight', 'prequant_scaling_factor')] = prequant_scaling_factor
            weights[name.replace(
                'weight',
                'weights_scaling_factor')] = weights_scaling_factor.astype(
                    weight.dtype)
            if precision == "w4a8_awq":
                alpha = np.array([1], dtype=np.float32)
                weights[name.replace('weight', 'alpha')] = alpha
    if use_fp8_kv_cache or use_int8_kv_cache:
        # Placeholder KV-cache scaling factors of 1.0 for every layer.
        for l in range(num_layers):
            t = np.array([1], dtype=np.float32)
            weights[
                f"transformer.layers.{l}.attention.kv_cache_scaling_factor"] = t

    return weights
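

# Hedged usage sketch (illustrative, not from the original file): generating dummy
# INT4-AWQ weights from an existing checkpoint dict. "dense_weights" and
# "gemma_config" are assumed placeholders (a Weights dict and a PretrainedConfig
# whose quant_mode enables AWQ); group_size=128 is only an assumed common choice.
#
#   awq_weights = dummy_weights_awq(weights=dense_weights,
#                                   precision="w4a8_awq",
#                                   trt_llm_config=gemma_config,
#                                   group_size=128)
#   # Each matched "*.weight" is replaced by a packed INT4 tensor, with added
#   # "*.prequant_scaling_factor" and "*.weights_scaling_factor" entries.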