TensorRT-LLMs/tensorrt_llm/models/quantized/quant.py

# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union

import numpy as np

from ...layers import ColumnLinear, RowLinear
from ...models import GPTJForCausalLM, GPTLMHeadModel, LLaMAForCausalLM
from ...quantization import QuantMode

# isort: off
from ...quantization.layers import (
    SmoothQuantAttention, SmoothQuantGatedMLP, SmoothQuantLayerNorm,
    SmoothQuantMLP, SmoothQuantRmsNorm, WeightOnlyGroupwiseQuantColumnLinear,
    WeightOnlyGroupwiseQuantRowLinear, WeightOnlyQuantColumnLinear,
    WeightOnlyQuantRowLinear)
# isort: on


def _smooth_quantize_gpt(model, quant_mode):
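    """Replace the layer norms, attention, and MLP of every GPT decoder layer
    with their SmoothQuant counterparts, then record `quant_mode` on the model."""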
    assert quant_mode.has_act_and_weight_quant()
    for layer in model.layers:
        assert hasattr(layer,
                       "input_layernorm"), "The layer has no input_layernorm"
        layer.input_layernorm = SmoothQuantLayerNorm(
            normalized_shape=layer.hidden_size,
            dtype=layer.dtype,
            quant_mode=quant_mode)
        assert hasattr(layer, "attention"), "The layer has no attention"
        layer.attention = SmoothQuantAttention(
            layer.hidden_size,
            num_attention_heads=layer.num_attention_heads,
            max_position_embeddings=layer.max_position_embeddings,
            num_layers=layer.num_layers,
            apply_query_key_layer_scaling=layer.apply_query_key_layer_scaling,
            dtype=layer.dtype,
            attention_mask_type=layer.attention_mask_type,
            position_embedding_type=layer.position_embedding_type,
            tp_group=layer.tp_group,
            tp_size=layer.tp_size,
            quant_mode=quant_mode)
        assert hasattr(layer, "mlp"), "The layer has no mlp"
        layer.mlp = SmoothQuantMLP(hidden_size=layer.hidden_size,
                                   ffn_hidden_size=layer.hidden_size * 4,
                                   hidden_act=layer.hidden_act,
                                   dtype=layer.dtype,
                                   tp_group=layer.tp_group,
                                   tp_size=layer.tp_size,
                                   quant_mode=quant_mode)
        assert hasattr(layer,
                       "post_layernorm"), "The layer has no post_layernorm"
        layer.post_layernorm = SmoothQuantLayerNorm(
            normalized_shape=layer.hidden_size,
            dtype=layer.dtype,
            quant_mode=quant_mode)

    setattr(model, 'quant_mode', quant_mode)
    return model


def _smooth_quantize_llama(model, quant_mode):
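    """Replace the RMS norms, attention, and gated MLP of every LLaMA decoder
    layer with their SmoothQuant counterparts, then record `quant_mode` on the model."""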
    assert quant_mode.has_act_and_weight_quant()
    for layer in model.layers:
        assert hasattr(layer,
                       "input_layernorm"), "The layer has no input_layernorm"
        layer.input_layernorm = SmoothQuantRmsNorm(
            normalized_shape=layer.hidden_size,
            dtype=layer.dtype,
            quant_mode=quant_mode)
        assert hasattr(layer, "attention"), "The layer has no attention"
        layer.attention = SmoothQuantAttention(
            layer.hidden_size,
            num_attention_heads=layer.num_attention_heads,
            num_kv_heads=layer.num_kv_heads,
            max_position_embeddings=layer.max_position_embeddings,
            num_layers=model.num_layers,
            dtype=layer.dtype,
            attention_mask_type=layer.attention_mask_type,
            position_embedding_type=layer.position_embedding_type,
            tp_group=layer.tp_group,
            tp_size=layer.tp_size,
            quant_mode=quant_mode,
            bias=False)
        assert hasattr(layer, "mlp"), "The layer has no mlp"
        layer.mlp = SmoothQuantGatedMLP(hidden_size=model.hidden_size,
                                        ffn_hidden_size=layer.mlp_hidden_size,
                                        hidden_act=layer.hidden_act,
                                        dtype=layer.dtype,
                                        tp_group=layer.tp_group,
                                        tp_size=layer.tp_size,
                                        quant_mode=quant_mode,
                                        bias=False)
        assert hasattr(
            layer, "post_layernorm"), "The layer has no post_layernorm"
        layer.post_layernorm = SmoothQuantRmsNorm(
            normalized_shape=layer.hidden_size,
            dtype=layer.dtype,
            quant_mode=quant_mode)

    setattr(model, 'quant_mode', quant_mode)
    return model


def smooth_quantize(model, quant_mode):
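    """Dispatch SmoothQuant conversion based on the model class."""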
    assert isinstance(model, GPTLMHeadModel) or isinstance(model, LLaMAForCausalLM), \
        "Only GPTLMHeadModel and LLaMAForCausalLM are well tested now"
    if isinstance(model, GPTLMHeadModel):
        return _smooth_quantize_gpt(model, quant_mode)
    elif isinstance(model, LLaMAForCausalLM):
        return _smooth_quantize_llama(model, quant_mode)
    else:
        assert False, f"Model {type(model).__name__} is not supported by SmoothQuant yet"


def weight_only_quantize(model,
                         quant_mode,
                         exclude_modules=None,
                         current_key_name=None):
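    """Recursively swap ColumnLinear/RowLinear modules for their weight-only
    quantized variants, skipping any module whose name matches
    `exclude_modules` (`lm_head` by default)."""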
    assert quant_mode.is_weight_only()

    exclude_modules = ['lm_head'
                       ] if exclude_modules is None else exclude_modules

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if len(list(module.children())) > 0:
            weight_only_quantize(module, quant_mode, exclude_modules,
                                 current_key_name)

        if isinstance(module, ColumnLinear) and name not in exclude_modules:
            if not any(key in '.'.join(current_key_name)
                       for key in exclude_modules):
                model._modules[name] = WeightOnlyQuantColumnLinear(
                    in_features=module.in_features,
                    out_features=module.out_features * module.tp_size,
                    bias=module.bias is not None,
                    dtype=module.dtype,
                    tp_group=module.tp_group,
                    tp_size=module.tp_size,
                    gather_output=module.gather_output,
                    quant_mode=quant_mode)
        elif isinstance(module, RowLinear) and name not in exclude_modules:
            if not any(key in '.'.join(current_key_name)
                       for key in exclude_modules):
                model._modules[name] = WeightOnlyQuantRowLinear(
                    in_features=module.in_features * module.tp_size,
                    out_features=module.out_features,
                    bias=module.bias is not None,
                    dtype=module.dtype,
                    tp_group=module.tp_group,
                    tp_size=module.tp_size,
                    quant_mode=quant_mode)
        current_key_name.pop(-1)

    setattr(model, 'quant_mode', quant_mode)
    return model


def weight_only_groupwise_quantize(model,
                                   quant_mode,
                                   group_size=128,
                                   pre_quant_scale=False,
                                   zero=False,
                                   exclude_modules=None,
                                   current_key_name=None):
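    """Recursively swap ColumnLinear/RowLinear modules (except excluded ones)
    for groupwise weight-only quantized variants configured with `group_size`,
    `pre_quant_scale`, and `zero`."""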
    exclude_modules = ['lm_head'
                       ] if exclude_modules is None else exclude_modules

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if len(list(module.children())) > 0:
            weight_only_groupwise_quantize(module, quant_mode, group_size,
                                           pre_quant_scale, zero,
                                           exclude_modules, current_key_name)

        if isinstance(module, ColumnLinear) and name not in exclude_modules:
            if not any(key in '.'.join(current_key_name)
                       for key in exclude_modules):
                model._modules[name] = WeightOnlyGroupwiseQuantColumnLinear(
                    in_features=module.in_features,
                    out_features=module.out_features * module.tp_size,
                    group_size=group_size,
                    pre_quant_scale=pre_quant_scale,
                    zero=zero,
                    bias=module.bias is not None,
                    dtype=module.dtype,
                    tp_group=module.tp_group,
                    tp_size=module.tp_size,
                    gather_output=module.gather_output)
        elif isinstance(module, RowLinear) and name not in exclude_modules:
            if not any(key in '.'.join(current_key_name)
                       for key in exclude_modules):
                model._modules[name] = WeightOnlyGroupwiseQuantRowLinear(
                    in_features=module.in_features * module.tp_size,
                    out_features=module.out_features,
                    group_size=group_size,
                    pre_quant_scale=pre_quant_scale,
                    zero=zero,
                    bias=module.bias is not None,
                    dtype=module.dtype,
                    tp_group=module.tp_group,
                    tp_size=module.tp_size)
        current_key_name.pop(-1)

    setattr(model, 'quant_mode', quant_mode)
    return model


def get_dummy_quant_scales(num_layers):
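    """Return placeholder per-layer scaling factors used when no calibrated
    scales are provided."""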
    return {
        'lm_head_act': 0.99,
        'lm_head_weights': 0.99,
        'fc_act': [0.99 for _ in range(num_layers)],
        'fc_weights': [0.99 for _ in range(num_layers)],
        'gate_act': [0.99 for _ in range(num_layers)],
        'gate_weights': [0.99 for _ in range(num_layers)],
        'proj_act': [0.99 for _ in range(num_layers)],
        'proj_weights': [0.99 for _ in range(num_layers)],
        'qkv_act': [0.99 for _ in range(num_layers)],
        'qkv_weights': [0.99 for _ in range(num_layers)],
        'qkv_output': [5.0 for _ in range(num_layers)],
        'dense_act': [0.99 for _ in range(num_layers)],
        'dense_weights': [0.99 for _ in range(num_layers)],
    }


def _quantize_layer(layer, layer_idx, quant_mode, quant_scales):
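    """Write the per-layer activation/weight scaling factors (and, when the FP8
    KV cache is enabled, the KV quantization scales) into one decoder layer."""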
    assert hasattr(layer, "mlp"), "The layer has no mlp"

    layer.mlp.fc.activation_scaling_factor.value = np.array(
        [quant_scales['fc_act'][layer_idx]], dtype=np.float32)
    layer.mlp.fc.weights_scaling_factor.value = np.array(
        [quant_scales['fc_weights'][layer_idx]], dtype=np.float32)
    layer.mlp.proj.activation_scaling_factor.value = np.array(
        [quant_scales['proj_act'][layer_idx]], dtype=np.float32)
    layer.mlp.proj.weights_scaling_factor.value = np.array(
        [quant_scales['proj_weights'][layer_idx]], dtype=np.float32)

    if hasattr(layer.mlp, 'gate'):
        layer.mlp.gate.activation_scaling_factor.value = np.array(
            [quant_scales['gate_act'][layer_idx]], dtype=np.float32)
        layer.mlp.gate.weights_scaling_factor.value = np.array(
            [quant_scales['gate_weights'][layer_idx]], dtype=np.float32)

    assert hasattr(layer, "attention"), "The layer has no attention"
    layer.attention.qkv.activation_scaling_factor.value = np.array(
        [quant_scales['qkv_act'][layer_idx]], dtype=np.float32)
    layer.attention.qkv.weights_scaling_factor.value = np.array(
        [quant_scales['qkv_weights'][layer_idx]], dtype=np.float32)

    if quant_mode.has_fp8_kv_cache():
        layer.attention.kv_orig_quant_scale.value = np.array(
            [quant_scales['qkv_output'][layer_idx]], dtype=np.float32)
        layer.attention.kv_quant_orig_scale.value = np.array(
            [1.0 / quant_scales['qkv_output'][layer_idx]], dtype=np.float32)

    layer.attention.dense.activation_scaling_factor.value = np.array(
        [quant_scales['dense_act'][layer_idx]], dtype=np.float32)
    layer.attention.dense.weights_scaling_factor.value = np.array(
        [quant_scales['dense_weights'][layer_idx]], dtype=np.float32)

    return layer


def _default_fp8_quantize(model: Union[GPTLMHeadModel, LLaMAForCausalLM,
                                       GPTJForCausalLM],
                          quant_mode: QuantMode,
                          quant_scales: dict = None):
"""
Quantize all linear layers (i.e., MLP, Attention QKV/Dense) and KV cache IO with dummy scales
This is used by benchmark script and therefore is intentionally decoupled from AMMO toolkit
"""
if quant_scales is None:
quant_scales = get_dummy_quant_scales(model._num_layers)
assert model.quant_mode == quant_mode, "Quant setting not consistent with model init setting"
use_fp8_qdq = quant_mode.has_fp8_qdq()
assert use_fp8_qdq
for layer_idx, layer in enumerate(model.layers):
layer = _quantize_layer(layer, layer_idx, quant_mode, quant_scales)
return model
def fp8_quantize(model, quant_mode: QuantMode, quant_scales: dict = None):
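    """Apply the default FP8 quantization to the supported model classes."""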
    if isinstance(model, LLaMAForCausalLM):
        return _default_fp8_quantize(model, quant_mode, quant_scales)
    elif isinstance(model, GPTJForCausalLM):
        return _default_fp8_quantize(model, quant_mode, quant_scales)
    elif isinstance(model, GPTLMHeadModel):
        return _default_fp8_quantize(model, quant_mode, quant_scales)
    else:
        assert False, f"Model {type(model).__name__} is not supported by fp8_quantize yet"