TensorRT-LLM/tensorrt_llm/models/gemma/weight.py

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Dict, Optional

import numpy as np
import torch

from tensorrt_llm.models.modeling_utils import PretrainedConfig

from ..._utils import numpy_to_torch
from ...logger import logger
from ...mapping import Mapping

Weights = Dict[str, torch.Tensor]


def quantize_fp8_weights(weights: Weights, num_layers: int,
                         mapping: Mapping) -> Weights:
    """Compute per-tensor FP8 scaling factors for the linear-layer weights.

    Only the scaling factors are returned; the scaled weight copies built
    below are not part of the return value.
    """

    def get_scaling_factor(weight):
        # 448.0 is the largest finite value representable in FP8 E4M3.
        amax = weight.max()
        scale = 448.0 / amax
        return scale

    layers_range = mapping.pp_layers(num_layers)
    scaling_factors = {}
    scaled_weights = {}
    trt_llm_prefix = "transformer.layers"
    for l in layers_range:
        # Per-layer linear weights that receive an FP8 weight scale.
        for name in [
                "attention.qkv", "attention.dense", "mlp.fc", "mlp.gate",
                "mlp.proj"
        ]:
            trt_llm_name = ".".join((trt_llm_prefix, str(l), name, "weight"))
            scale_name = ".".join(
                (trt_llm_prefix, str(l), name, "weights_scaling_factor"))
            weight = weights[trt_llm_name].float()
            dtype = weights[trt_llm_name].dtype
            scale = get_scaling_factor(weight)
            scaled_weights[trt_llm_name] = (weight *
                                            scale).to(dtype).contiguous()
            scaling_factors[scale_name] = numpy_to_torch(
                np.asarray([1 / scale]).astype(np.float32))
    return scaling_factors
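

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not taken from this module's callers):
# assuming `weights` maps TRT-LLM parameter names to torch tensors for every
# layer handled by this pipeline-parallel rank, the per-tensor FP8 scales
# returned above would typically be merged back into the same checkpoint dict:
#
#     scales = quantize_fp8_weights(weights, num_hidden_layers, mapping)
#     weights.update(scales)
# ---------------------------------------------------------------------------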


def load_from_fp8_gemma(quant_ckpt_path: Optional[str], num_layers: int,
                        mapping: Mapping, fp8_kv_cache: bool,
                        weight_scales: Weights):
    """
    Get the FP8 scaling factors, either from a quantized checkpoint or as
    dummy (1.0) placeholders when no checkpoint is available.
    """
    fake_fp8_sf_dt = torch.float32
    if quant_ckpt_path is not None and os.path.isfile(quant_ckpt_path):
        fp8_gemma = np.load(quant_ckpt_path)
    else:
        fp8_gemma = None
        logger.info(
            "No quantized checkpoint was provided; using dummy FP8 scaling factors instead."
        )
    weights = {}

    def get_fp8_gemma(name: str) -> np.ndarray:
        if fp8_gemma is not None:
            return fp8_gemma[name]
        else:
            return torch.tensor([1.0], dtype=fake_fp8_sf_dt).numpy()
    layers_range = mapping.pp_layers(num_layers)
    for l in layers_range:
        prefix = f'_np:layers:{l}'
        tllm_prex = f'transformer.layers.{l - layers_range[0]}'
        weights[f'{tllm_prex}.attention.qkv.activation_scaling_factor'] = max(
            get_fp8_gemma(
                f'{prefix}:attention:qkv:q:activation_scaling_factor'),
            get_fp8_gemma(
                f'{prefix}:attention:qkv:k:activation_scaling_factor'),
            get_fp8_gemma(
                f'{prefix}:attention:qkv:v:activation_scaling_factor'))
        weights[f'{tllm_prex}.attention.qkv.weights_scaling_factor'] = max(
            get_fp8_gemma(f'{prefix}:attention:qkv:q:weights_scaling_factor'),
            get_fp8_gemma(f'{prefix}:attention:qkv:k:weights_scaling_factor'),
            get_fp8_gemma(f'{prefix}:attention:qkv:v:weights_scaling_factor'))
        weights[
            f'{tllm_prex}.attention.dense.activation_scaling_factor'] = get_fp8_gemma(
                f'{prefix}:attention:dense:activation_scaling_factor')
        weights[
            f'{tllm_prex}.attention.dense.weights_scaling_factor'] = get_fp8_gemma(
                f'{prefix}:attention:dense:weights_scaling_factor')
        weights[
            f'{tllm_prex}.mlp.fc.activation_scaling_factor'] = get_fp8_gemma(
                f'{prefix}:mlp:fc:activation_scaling_factor')
        weights[f'{tllm_prex}.mlp.fc.weights_scaling_factor'] = get_fp8_gemma(
            f'{prefix}:mlp:fc:weights_scaling_factor')
        weights[
            f'{tllm_prex}.mlp.gate.activation_scaling_factor'] = get_fp8_gemma(
                f'{prefix}:mlp:gate:activation_scaling_factor')
        weights[f'{tllm_prex}.mlp.gate.weights_scaling_factor'] = get_fp8_gemma(
            f'{prefix}:mlp:gate:weights_scaling_factor')
        weights[
            f'{tllm_prex}.mlp.proj.activation_scaling_factor'] = get_fp8_gemma(
                f'{prefix}:mlp:proj:activation_scaling_factor')
        weights[f'{tllm_prex}.mlp.proj.weights_scaling_factor'] = get_fp8_gemma(
            f'{prefix}:mlp:proj:weights_scaling_factor')
        if fp8_kv_cache:
            # Not calibrating KV cache.
            scaling_factor = 1.0
            weights[
                f'{tllm_prex}.attention.kv_cache_scaling_factor'] = torch.tensor(
                    [scaling_factor], dtype=fake_fp8_sf_dt).numpy()

    if fp8_gemma is None:
        weights.update(weight_scales)

    for key in weights:
        if isinstance(weights[key], np.ndarray):
            weights[key] = numpy_to_torch(weights[key])
    return weights
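

# ---------------------------------------------------------------------------
# Expected checkpoint layout (inferred from the key names used above, not from
# a separate spec): `quant_ckpt_path` is anything np.load can open, e.g. an
# .npz archive, whose entries are single-element float32 arrays keyed like
#
#     _np:layers:0:attention:qkv:q:activation_scaling_factor
#     _np:layers:0:attention:dense:weights_scaling_factor
#     _np:layers:0:mlp:fc:activation_scaling_factor
#
# When no such file is given, every factor falls back to 1.0, and any
# `weight_scales` passed by the caller (e.g. produced by quantize_fp8_weights
# above) override the dummy weight scales.
# ---------------------------------------------------------------------------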


def dummy_weights_awq(weights: Weights, precision: str,
                      trt_llm_config: PretrainedConfig, group_size: int):
    """Replace the linear-layer weights with dummy packed-int4 AWQ weights and
    matching placeholder scaling factors."""
    packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4
    use_fp8_kv_cache = trt_llm_config.quant_mode.has_fp8_kv_cache()
    use_int8_kv_cache = trt_llm_config.quant_mode.has_int8_kv_cache()
    num_layers = trt_llm_config.num_hidden_layers
    for name in list(weights):
        if any([
                _name in name for _name in [
                    'mlp.proj.weight', 'mlp.gate.weight', 'mlp.fc.weight',
                    'attention.qkv.weight', 'attention.dense.weight'
                ]
        ]):
            logger.info(f"Processing: {name}")
            weight = np.ascontiguousarray(weights[name].T)
            in_dim, out_dim = weight.shape
            # One per-tensor scale that maps the maximum weight onto the int4
            # extreme value 7; good enough for dummy (uncalibrated) weights.
            scale = np.amax(weight) / 7
            weights_scaling_factor = np.ones([out_dim, in_dim // group_size
                                              ]) * scale.astype(np.float32)
            weight_smoothed = (weight.astype(np.float32) / scale).astype(
                np.int8)
            # Clamp to the signed 4-bit range [-8, 7].
            weight_smoothed[weight_smoothed < -8] = -8
            weight_smoothed[weight_smoothed > 7] = 7
            prequant_scaling_factor = np.ones([in_dim], dtype=weight.dtype)
            weights[name] = packer(
                torch.from_numpy(weight_smoothed)).T.contiguous().numpy()
            weights[name.replace(
                'weight', 'prequant_scaling_factor')] = prequant_scaling_factor
            weights[name.replace(
                'weight',
                'weights_scaling_factor')] = weights_scaling_factor.astype(
                    weight.dtype)
            if precision == "w4a8_awq":
                alpha = np.array([1], dtype=np.float32)
                weights[name.replace('weight', 'alpha')] = alpha
    if use_fp8_kv_cache or use_int8_kv_cache:
        for l in range(num_layers):
            t = np.array([1], dtype=np.float32)
            weights[
                f"transformer.layers.{l}.attention.kv_cache_scaling_factor"] = t
    return weights
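

# ---------------------------------------------------------------------------
# Minimal sketch of the int4 mapping applied in dummy_weights_awq (illustrative
# only; the real path additionally packs the int8-held values into packed int4
# storage via torch.ops.trtllm.pack_int8_tensor_to_packed_int4 and transposes
# back):
#
#     w = np.ascontiguousarray(weights[name].T)       # [in_dim, out_dim]
#     scale = np.amax(w) / 7                          # map the max onto +7
#     q = (w.astype(np.float32) / scale).astype(np.int8)
#     q = np.clip(q, -8, 7)                           # signed 4-bit range
#
# Because these are placeholder weights, a single per-tensor scale is simply
# broadcast over every [out_dim, in_dim // group_size] group instead of being
# calibrated per group.
# ---------------------------------------------------------------------------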