# tensorrt_llm/models/modeling_utils.py
import copy
import dataclasses
import json
import os
from typing import Dict, List, Optional, Union
import numpy as np
import safetensors
import torch
from .._common import default_net
from .._utils import (numpy_to_torch, str_dtype_to_torch, str_dtype_to_trt,
trt_dtype_to_torch)
from ..functional import PositionEmbeddingType, Tensor, gather_last_token_logits
from ..layers import (AttentionParams, Embedding, FusedGatedMLP, GatedMLP,
KeyValueCacheParams, LoraParams, PromptTuningEmbedding)
from ..layers.attention import Attention, BertAttention
from ..layers.linear import ColumnLinear, Linear, RowLinear
from ..layers.lora import Lora
from ..logger import logger
from ..mapping import Mapping
from ..module import Module, ModuleList
from ..quantization import QuantMode
from ..quantization.layers import FP8Linear
from ..quantization.mode import (FP8, W4A8_AWQ, W4A16, W4A16_AWQ,
W8A8_SQ_PLUGIN_LIST, W8A16)
from ..quantization.quantize import quantize
from .generation_mixin import GenerationMixin
WEIGHT_LOADER_MODELS = {"PhiForCausalLM"}
@dataclasses.dataclass
class QuantizationConfig:
'''Serializable quantization configuration class, part of the PretrainedConfig
'''
quant_algo: Optional[str] = None
kv_cache_quant_algo: Optional[str] = None
group_size: Optional[int] = 128
has_zero_point: Optional[bool] = False
pre_quant_scale: Optional[bool] = False
exclude_modules: Optional[List[str]] = None
@property
def use_plugin_sq(self):
return self.quant_algo in W8A8_SQ_PLUGIN_LIST
def quant_algo_to_ammo_qformat(self):
from ..quantization import mode as quant_algo
#"fp8", "int8_sq", "int4_awq", "w4a8_awq", "int8_wo", "int4_wo",
algo_to_ammo_map = {
quant_algo.W8A16: "int8_wo",
quant_algo.W4A16: "int4_wo",
quant_algo.W4A16_AWQ: "int4_awq",
quant_algo.W4A8_AWQ: 'w4a8_awq',
quant_algo.FP8: 'fp8',
quant_algo.W4A16_GPTQ: None,
quant_algo.W8A8_SQ_PER_CHANNEL: 'int8_sq',
quant_algo.W8A8_SQ_PER_TENSOR_PLUGIN: None,
quant_algo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN: None,
quant_algo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN: None,
quant_algo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN: None,
None: 'full_prec'
}
assert self.quant_algo in algo_to_ammo_map
qformat = algo_to_ammo_map[self.quant_algo]
assert qformat is not None, "This quantization algorithm is not handled by AMMO; quant_algo_to_ammo_qformat should not be called for it"
return qformat
def asdict(self):
return dataclasses.asdict(self)
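# Illustrative use of QuantizationConfig (the concrete values below are just an
# example, not defaults recommended by this module): an INT4-AWQ configuration
# maps to the "int4_awq" AMMO qformat.
#
#   quant_config = QuantizationConfig(quant_algo=W4A16_AWQ, group_size=128)
#   assert quant_config.quant_algo_to_ammo_qformat() == 'int4_awq'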
def default_weight_loader(mapping: Mapping, param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
"""Default weight loader."""
param.value = loaded_weight
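# A parameter may also carry its own `weight_loader` attribute, which
# PretrainedModel.load uses for architectures listed in WEIGHT_LOADER_MODELS.
# A minimal sketch of such a loader, assuming a hypothetical column-split
# sharding scheme (not part of this module):
#
#   def column_split_loader(mapping, param, loaded_weight):
#       shard = torch.chunk(loaded_weight, mapping.tp_size, dim=0)[mapping.tp_rank]
#       param.value = shard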
class PretrainedConfig:
def __init__(self,
architecture: str,
dtype: str,
logits_dtype: str,
vocab_size: int,
max_position_embeddings: int,
hidden_size: int,
num_hidden_layers: int,
num_attention_heads: int,
num_key_value_heads: int,
hidden_act: str,
intermediate_size: int,
norm_epsilon: float,
position_embedding_type: str,
world_size: int,
tp_size: int,
pp_size: int,
quantization: Union[QuantizationConfig, dict],
use_parallel_embedding: bool = False,
embedding_sharding_dim: int = 0,
share_embedding_table: bool = False,
head_size: int = None,
**kwargs):
self.architecture = architecture
self.dtype = dtype
self.logits_dtype = logits_dtype
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_size = hidden_size // num_attention_heads if head_size is None else head_size
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.norm_epsilon = norm_epsilon
self.position_embedding_type = PositionEmbeddingType.from_string(
position_embedding_type)
self.use_parallel_embedding = use_parallel_embedding
self.embedding_sharding_dim = embedding_sharding_dim
self.share_embedding_table = share_embedding_table
self.mapping = Mapping(world_size=world_size,
tp_size=tp_size,
pp_size=pp_size)
if isinstance(quantization, dict):
self.quantization = dataclasses.replace(QuantizationConfig(),
**quantization)
else:
assert isinstance(
quantization, QuantizationConfig
), f"Expecting type of QuantizationConfig, found {type(quantization)}"
self.quantization = quantization
# Ideally only one of quant_mode and quantization should be kept as the
# single source of truth; both are kept for now since many code paths
# still rely on quant_mode.
self.quant_mode = QuantMode.from_quant_algo(
self.quantization.quant_algo, self.quantization.kv_cache_quant_algo)
self.kv_dtype = self.dtype
if self.quant_mode.has_int8_kv_cache():
self.kv_dtype = 'int8'
elif self.quant_mode.has_fp8_kv_cache():
self.kv_dtype = 'fp8'
for key, value in kwargs.items():
setattr(self, key, value)
def set_if_not_exist(self, key, value):
if not hasattr(self, key):
setattr(self, key, value)
@classmethod
def from_dict(cls, config):
config = copy.deepcopy(
config
) # the many config.pop() calls below would otherwise mutate the caller's dict
architecture = config.pop('architecture')
dtype = config.pop('dtype')
vocab_size = config.pop('vocab_size')
hidden_size = config.pop('hidden_size')
num_hidden_layers = config.pop('num_hidden_layers')
num_attention_heads = config.pop('num_attention_heads')
hidden_act = config.pop('hidden_act')
norm_epsilon = config.pop('norm_epsilon', 1e-5)
position_embedding_type = config.pop('position_embedding_type',
'learned_absolute')
logits_dtype = config.pop('logits_dtype', 'float32')
num_key_value_heads = config.pop('num_key_value_heads',
num_attention_heads)
intermediate_size = config.pop('intermediate_size', None)
max_position_embeddings = config.pop('max_position_embeddings', None)
use_parallel_embedding = config.pop('use_parallel_embedding', False)
embedding_sharding_dim = config.pop('embedding_sharding_dim', 0)
share_embedding_table = config.pop('share_embedding_table', False)
mapping = config.pop('mapping', {
'world_size': 1,
'tp_size': 1,
'pp_size': 1
})
world_size = mapping.get('world_size', 1)
tp_size = mapping.get('tp_size', 1)
pp_size = mapping.get('pp_size', 1)
if share_embedding_table and tp_size > 1:
if (not use_parallel_embedding) or (use_parallel_embedding and
embedding_sharding_dim == 1):
raise NotImplementedError(
"For multiple-processes cases, sharing the embedding table must set" \
"use_parallel_embedding=True and embedding_sharding_dim=0"
)
quant_config = QuantizationConfig()
if 'quantization' in config:
# override fields of the default quantization object with the given dict; the user may specify only a subset of the fields
quant_config_from_user = config.pop('quantization')
if isinstance(quant_config_from_user, dict):
quant_config = dataclasses.replace(quant_config,
**quant_config_from_user)
# allow the user to pass a QuantizationConfig object directly
else:
assert isinstance(quant_config_from_user, QuantizationConfig)
quant_config = quant_config_from_user
return cls(architecture, dtype, logits_dtype, vocab_size,
max_position_embeddings, hidden_size, num_hidden_layers,
num_attention_heads, num_key_value_heads, hidden_act,
intermediate_size, norm_epsilon, position_embedding_type,
world_size, tp_size, pp_size, quant_config,
use_parallel_embedding, embedding_sharding_dim,
share_embedding_table, **config)
@classmethod
def from_json_file(cls, config_file: str):
with open(config_file) as f:
config = json.load(f)
return cls.from_dict(config)
def to_dict(self):
output = copy.deepcopy(self.__dict__)
output['position_embedding_type'] = str(self.position_embedding_type)
output['mapping'] = {
'world_size': self.mapping.world_size,
'tp_size': self.mapping.tp_size,
'pp_size': self.mapping.pp_size,
}
output.pop('quant_mode')
output['quantization'] = dataclasses.asdict(self.quantization)
return output
def set_rank(self, rank):
self.mapping = Mapping(self.mapping.world_size,
rank=rank,
tp_size=self.mapping.tp_size,
pp_size=self.mapping.pp_size)
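# Illustrative construction of a PretrainedConfig from a plain dict, as a
# checkpoint converter might produce (all field values are placeholders):
#
#   config = PretrainedConfig.from_dict({
#       'architecture': 'LlamaForCausalLM',
#       'dtype': 'float16',
#       'vocab_size': 32000,
#       'hidden_size': 4096,
#       'num_hidden_layers': 32,
#       'num_attention_heads': 32,
#       'hidden_act': 'silu',
#       'mapping': {'world_size': 2, 'tp_size': 2, 'pp_size': 1},
#       'quantization': {'quant_algo': None},
#   })
#   config.set_rank(0)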
class DecoderLayerList(ModuleList):
def __init__(self, cls, config):
self.layer_list = config.mapping.pp_layers(config.num_hidden_layers)
super().__init__([cls(config, idx) for idx in self.layer_list])
def forward(self,
hidden_states,
use_cache=False,
attention_mask=None,
kv_cache_params=None,
attention_params=None,
lora_params=None,
medusa_position_offsets=None,
medusa_packed_mask=None):
kv_cache_params.fill_none_tensor_list(len(self.layer_list))
if use_cache:
presents = []
for layer_idx, (layer, past) in enumerate(
zip(self, kv_cache_params.past_key_value)):
lora_layer_params = None
if lora_params is not None and lora_params.lora_ranks is not None:
lora_layer_params = lora_params.get_layer_params(layer_idx)
kwargs = {}
if lora_layer_params is not None:
kwargs['lora_layer_params'] = lora_layer_params
if medusa_position_offsets is not None:
kwargs['medusa_position_offsets'] = medusa_position_offsets
if medusa_packed_mask is not None:
kwargs['medusa_packed_mask'] = medusa_packed_mask
hidden_states = layer(
hidden_states,
use_cache=use_cache,
attention_mask=attention_mask,
kv_cache_params=KeyValueCacheParams(
past_key_value=[past],
host_past_key_value_lengths=kv_cache_params.
host_past_key_value_lengths,
host_max_attention_window_sizes=kv_cache_params.
host_max_attention_window_sizes,
host_sink_token_length=kv_cache_params.
host_sink_token_length,
kv_cache_block_pointers=kv_cache_params.
kv_cache_block_pointers,
host_kv_cache_block_pointers=kv_cache_params.
host_kv_cache_block_pointers,
cache_indirection=kv_cache_params.cache_indirection),
attention_params=attention_params,
**kwargs)
if use_cache:
presents.append(hidden_states[1])
hidden_states = hidden_states[0]
if use_cache:
return hidden_states, presents
return hidden_states
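# PostInitCaller makes every PretrainedModel call __post_init__ right after
# construction; for PretrainedModel this is where quantize() swaps plain layers
# for their quantized counterparts according to config.quantization.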
class PostInitCaller(type):
def __call__(cls, *args, **kwargs):
obj = type.__call__(cls, *args, **kwargs)
obj.__post_init__()
return obj
class PretrainedModel(Module, GenerationMixin, metaclass=PostInitCaller):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config
def __post_init__(self):
quantize(self, self.config.quant_mode,
**dataclasses.asdict(self.config.quantization))
def check_config(self, config):
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
@classmethod
def from_config(cls, config: PretrainedConfig):
return cls(config)
@classmethod
def from_checkpoint(cls,
ckpt_dir: str,
rank: int = 0,
config: PretrainedConfig = None):
if config is None:
config = PretrainedConfig.from_json_file(
os.path.join(ckpt_dir, 'config.json'))
config.set_rank(rank)
model = cls.from_config(config)
weights = {}
with safetensors.safe_open(os.path.join(ckpt_dir,
f'rank{rank}.safetensors'),
framework='pt',
device='cpu') as f:
for key in f.keys():
weights[key] = f.get_tensor(key)
preprocess_weights(weights, config)
model.load(weights)
return model
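# Illustrative round trip (directory and model class are placeholders): a
# checkpoint directory written by save_checkpoint, containing config.json and
# one rank<i>.safetensors file per rank, can be reloaded with:
#
#   model = SomeDecoderModelForCausalLM.from_checkpoint('/path/to/ckpt', rank=0)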
def load(self, weights):
expected_names = set([name for name, param in self.named_parameters()])
provided_names = set(weights.keys())
assert expected_names.issubset(
provided_names
), f"Expected but not provided tensors:{expected_names.difference(provided_names)}"
if self.config.architecture in WEIGHT_LOADER_MODELS:
mapping = self.config.mapping
for name, param in self.named_parameters():
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(mapping, param, weights[name])
else:
for name, param in self.named_parameters():
try:
param.value = weights[name]
except Exception as e:
raise RuntimeError(
f"Encountered error '{e}' while loading parameter '{name}'")
def load_partial_weights(self, weights: dict):
params = {name: param for name, param in self.named_parameters()}
mapping = self.config.mapping
for k, v in weights.items():
if k in params.keys():
param = params[k]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(mapping, param, v)
elif mapping.pp_size == 1:
logger.warning(f"Provided but not expected tensors: {k}")
def save_checkpoint(self, output_dir, save_config=True):
# multiple ranks may share the same config.json, so save_config lets the caller avoid writing config.json from every rank
rank = self.config.mapping.rank
weights = {
name: numpy_to_torch(param.raw_value)
for name, param in self.named_parameters()
}
from safetensors.torch import save_file
save_file(weights, os.path.join(output_dir, f'rank{rank}.safetensors'))
if save_config:
with open(os.path.join(output_dir, 'config.json'), 'w') as f:
json.dump(self.config.to_dict(), f, indent=4)
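# Resulting checkpoint layout (illustrative):
#   output_dir/config.json          # written once when save_config=True
#   output_dir/rank0.safetensors    # one weights file per rank
#   output_dir/rank1.safetensors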
def prepare_inputs(self,
max_batch_size,
max_input_len,
max_seq_len,
use_cache,
max_beam_width: int = 1,
max_num_tokens: int = None,
prompt_embedding_table_size: int = 0,
position_encoding_2d: bool = False,
max_draft_len: int = 0,
gather_context_logits: bool = False,
gather_generation_logits: bool = False,
lora_target_modules: List[str] = None):
'''@brief: Prepare input tensors for the model; the given sizes are used to determine the
ranges of the dynamic dimensions when using TRT dynamic shapes.
@return: a dict of inputs that can be fed into self.forward()
'''
# Prepare inputs
remove_input_padding = default_net().plugin_config.remove_input_padding
use_gpt_attention_plugin = default_net(
).plugin_config.gpt_attention_plugin
use_gemm_plugin = default_net().plugin_config.gemm_plugin
paged_kv_cache = default_net().plugin_config.paged_kv_cache
tokens_per_block = default_net().plugin_config.tokens_per_block
use_custom_all_reduce = default_net(
).plugin_config.use_custom_all_reduce
use_lora_plugin = default_net().plugin_config.lora_plugin
multiple_profiles = default_net().plugin_config.multiple_profiles
model_inputs = self.prepare_basic_inputs(
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
max_input_len=max_input_len,
max_seq_len=max_seq_len,
num_kv_heads=self.config.num_key_value_heads,
head_size=self.config.head_size,
num_layers=self.config.num_hidden_layers,
kv_dtype=str_dtype_to_trt(self.config.kv_dtype),
remove_input_padding=remove_input_padding,
use_gpt_attention_plugin=use_gpt_attention_plugin,
use_gemm_plugin=use_gemm_plugin,
paged_kv_cache=paged_kv_cache,
tokens_per_block=tokens_per_block,
num_heads=self.config.num_attention_heads,
max_num_tokens=max_num_tokens,
dtype=str_dtype_to_trt(self.config.dtype),
prompt_embedding_table_size=prompt_embedding_table_size,
position_encoding_2d=position_encoding_2d,
mapping=self.config.mapping,
gather_context_logits=gather_context_logits,
gather_generation_logits=gather_generation_logits,
use_custom_all_reduce=use_custom_all_reduce,
use_lora_plugin=use_lora_plugin,
max_draft_len=max_draft_len,
lora_target_modules=lora_target_modules,
multiple_profiles=multiple_profiles)
result = {
'input_ids':
model_inputs['input_ids'],
'position_ids':
model_inputs['position_ids'],
'use_cache':
True,
'last_token_ids':
model_inputs['last_token_ids'],
'attention_mask':
model_inputs['attention_mask'],
'kv_cache_params':
KeyValueCacheParams(
past_key_value=model_inputs['past_key_value'],
host_past_key_value_lengths=model_inputs[
'host_past_key_value_lengths'],
host_max_attention_window_sizes=model_inputs[
'host_max_attention_window_sizes'],
host_sink_token_length=model_inputs['host_sink_token_length'],
kv_cache_block_pointers=model_inputs['kv_cache_block_pointers'],
host_kv_cache_block_pointers=model_inputs[
'host_kv_cache_block_pointers'],
cache_indirection=model_inputs['cache_indirection'],
),
'attention_params':
AttentionParams(
sequence_length=model_inputs['sequence_length'],
context_lengths=model_inputs['context_lengths'],
host_context_lengths=model_inputs['host_context_lengths'],
max_context_length=max_input_len,
host_request_types=model_inputs['host_request_types'])
}
if prompt_embedding_table_size > 0:
result['prompt_embedding_table'] = model_inputs[
'prompt_embedding_table']
result['prompt_tasks'] = model_inputs['tasks']
result['prompt_vocab_size'] = model_inputs['prompt_vocab_size']
if model_inputs['hidden_states_input'] is not None:
result['hidden_states'] = model_inputs['hidden_states_input']
if use_lora_plugin:
result['lora_params'] = LoraParams(
model_inputs['lora_ranks'],
model_inputs['lora_weights_pointers'],
host_context_lengths=model_inputs['host_context_lengths'],
max_context_length=max_input_len,
host_request_types=model_inputs['host_request_types'])
return result
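# Typical build-time flow (a sketch; builder/network setup is omitted and the
# sizes are placeholders): the dict returned by prepare_inputs is fed straight
# into forward() to trace the TensorRT network.
#
#   inputs = model.prepare_inputs(max_batch_size=8, max_input_len=1024,
#                                 max_seq_len=2048, use_cache=True)
#   model(**inputs)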
@classmethod
def quantize(
cls,
hf_model_dir,
output_dir,
quant_config: QuantizationConfig,
*,
dtype='float16',
mapping: Optional[Mapping] = None,
calib_batches=512,
calib_batch_size=1,
random_seed=1234,
tokenizer_max_seq_length=2048,
):
if mapping is None: # single gpu
mapping = Mapping()
ammo_qformat = quant_config.quant_algo_to_ammo_qformat()
kv_cache_dtype = quant_config.kv_cache_quant_algo
assert ammo_qformat is not None
from ..quantization import quantize_and_export
hf_model_dir = str(
hf_model_dir) # quantize_and_export cannot handle pathlib.Path objects
quantize_and_export(
model_dir=hf_model_dir,
dtype=dtype,
device='cuda',
qformat=ammo_qformat,
kv_cache_dtype=kv_cache_dtype,
calib_size=calib_batches,
batch_size=calib_batch_size,
output_dir=output_dir,
tp_size=mapping.tp_size,
pp_size=mapping.pp_size,
seed=random_seed,
max_seq_length=tokenizer_max_seq_length,
awq_block_size=quant_config.group_size,
)
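# Illustrative offline quantization call (paths and the model class are
# placeholders); it writes an AMMO-quantized checkpoint that from_checkpoint
# can consume afterwards:
#
#   SomeModelForCausalLM.quantize('/path/to/hf_model', '/path/to/tllm_ckpt',
#                                 quant_config=QuantizationConfig(quant_algo=FP8))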
class DecoderModelForCausalLM(PretrainedModel):
def __init__(self, config, transformer, lm_head):
super().__init__(config)
self.transformer = transformer
self.lm_head = lm_head
def forward(self,
input_ids: Tensor,
position_ids=None,
use_cache=False,
last_token_ids=None,
attention_mask=None,
kv_cache_params=None,
attention_params=None,
hidden_states=None,
prompt_embedding_table: Optional[Tensor] = None,
prompt_tasks: Optional[Tensor] = None,
prompt_vocab_size: Optional[Tensor] = None,
lora_params=None,
medusa_position_offsets=None,
medusa_packed_mask=None):
kwargs = {
'input_ids': input_ids,
'position_ids': position_ids,
'use_cache': use_cache,
'attention_mask': attention_mask,
'kv_cache_params': kv_cache_params,
'attention_params': attention_params,
}
if lora_params is not None:
kwargs['lora_params'] = lora_params
if hidden_states is not None:
kwargs['hidden_states'] = hidden_states
if prompt_embedding_table is not None:
kwargs['prompt_embedding_table'] = prompt_embedding_table
if prompt_tasks is not None:
kwargs['prompt_tasks'] = prompt_tasks
if prompt_vocab_size is not None:
kwargs['prompt_vocab_size'] = prompt_vocab_size
if medusa_position_offsets is not None:
kwargs['medusa_position_offsets'] = medusa_position_offsets
if medusa_packed_mask is not None:
kwargs['medusa_packed_mask'] = medusa_packed_mask
hidden_states = self.transformer.forward(**kwargs)
if use_cache:
hidden_states, presents = hidden_states
if self.config.mapping.is_last_pp_rank():
hidden_states = gather_last_token_logits(
hidden_states, last_token_ids,
default_net().plugin_config.remove_input_padding)
# [batch_size, hidden_size] -> [batch_size, vocab_size]
lm_logits = self.lm_head(hidden_states)
lm_logits.mark_output('logits', self.config.logits_dtype)
else:
hidden_states.mark_output('hidden_states_output', self.config.dtype)
if use_cache and not default_net().plugin_config.paged_kv_cache:
for i, present in zip(
self.config.mapping.pp_layers(
self.config.num_hidden_layers), presents):
present.mark_output(f'present_key_value_{i}',
self.config.kv_dtype)
if self.config.mapping.is_last_pp_rank():
return (lm_logits, presents, hidden_states)
return (hidden_states, presents)
else:
if self.config.mapping.is_last_pp_rank():
return lm_logits, hidden_states
return hidden_states
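# The helpers below rewrite an already-constructed model in place:
# fuse_gate_mlp merges the gate and fc GEMMs of GatedMLP layers,
# unfuse_qkv_gemm splits the fused QKV GEMM back into q/k/v (used by auto
# parallel), set_prompt_tuning swaps the vocab embedding for a
# PromptTuningEmbedding, and add_lora attaches Lora modules. optimize_model
# dispatches to them based on its flags.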
def fuse_gate_mlp(model):
for layer in model.transformer.layers:
if not hasattr(layer, 'mlp'):
continue
quant_algo = model.config.quantization.quant_algo
if isinstance(layer.mlp, GatedMLP):
fused_layer = FusedGatedMLP(
hidden_size=layer.mlp.hidden_size,
ffn_hidden_size=layer.mlp.ffn_hidden_size,
hidden_act=layer.mlp.hidden_act,
bias=layer.mlp.bias,
dtype=layer.mlp.dtype,
tp_group=layer.mlp.tp_group,
tp_size=layer.mlp.tp_size,
quant_mode=layer.mlp.quant_mode)
if quant_algo == FP8:
if isinstance(layer.mlp.dtype, str):
dtype = str_dtype_to_torch(layer.mlp.dtype)
else:
dtype = trt_dtype_to_torch(layer.mlp.dtype)
# dequantize
gate_weight = numpy_to_torch(
layer.mlp.gate.weight.raw_value).to(dtype) * numpy_to_torch(
layer.mlp.gate.weights_scaling_factor.raw_value)
fc_weight = numpy_to_torch(
layer.mlp.fc.weight.raw_value).to(dtype) * numpy_to_torch(
layer.mlp.fc.weights_scaling_factor.raw_value)
# concat
fused_weight = torch.cat([gate_weight, fc_weight], dim=0)
# quantize
fused_weight_scaling_factor = numpy_to_torch(
max(
layer.mlp.gate.weights_scaling_factor.raw_value,
layer.mlp.fc.weights_scaling_factor.raw_value,
))
fused_weight = (fused_weight / fused_weight_scaling_factor).to(
torch.float8_e4m3fn)
fused_layer.fused_fc.weight.value = fused_weight
fused_layer.fused_fc.weights_scaling_factor.value = fused_weight_scaling_factor
fused_layer.fused_fc.activation_scaling_factor.value = \
max(layer.mlp.gate.activation_scaling_factor.raw_value,
layer.mlp.fc.activation_scaling_factor.raw_value
)
elif quant_algo is None:
fused_layer.fused_fc.weight.value = np.concatenate([
layer.mlp.gate.weight.raw_value,
layer.mlp.fc.weight.raw_value
],
axis=0)
if layer.mlp.bias:
fused_layer.fused_fc.bias.value = np.concatenate([
layer.mlp.gate.bias.raw_value,
layer.mlp.fc.bias.raw_value
],
axis=0)
else:
raise ValueError(f'Unsupported quant algo: {quant_algo}')
fused_layer.proj = layer.mlp.proj
layer.mlp = fused_layer
return model
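# Note on the FP8 branch above: each weight is dequantized with its own scale,
# the two weights are concatenated, and the result is re-quantized with the max
# of the two scales so that a single scaling factor covers the fused GEMM.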
def unfuse_qkv_gemm(model):
for name, layer in model.named_modules(remove_duplicate=True):
if isinstance(layer, Attention) and not layer.cross_attention:
assert layer.tp_size == 1, "please disable manual TP when enabling auto parallel"
if layer.unfuse_qkv_gemm:
continue
layer.unfuse_qkv_gemm = True
linear_class = FP8Linear if layer.use_fp8_qdq else ColumnLinear
q = linear_class(layer.hidden_size,
layer.attention_hidden_size,
bias=layer.bias,
dtype=layer.dtype,
gather_output=False)
k = linear_class(layer.hidden_size,
layer.num_attention_kv_heads *
layer.attention_head_size,
bias=layer.bias,
dtype=layer.dtype,
gather_output=False)
v = linear_class(layer.hidden_size,
layer.num_attention_kv_heads *
layer.attention_head_size,
bias=layer.bias,
dtype=layer.dtype,
gather_output=False)
if layer.qkv.weight.is_inited():
qkv_weight = layer.qkv.weight.raw_value
weights = np.split(qkv_weight, [
q.out_features,
q.out_features + k.out_features,
])
for gemm, weight in zip([q, k, v], weights):
gemm.weight.value = weight
if layer.qkv.bias is not None and layer.qkv.bias.is_inited():
qkv_bias = layer.qkv.bias.raw_value
biases = np.split(qkv_bias, [
q.out_features,
q.out_features + k.out_features,
])
for gemm, bias in zip([q, k, v], biases):
gemm.bias.value = bias
for name, parameter in layer.qkv._parameters.items():
if name not in ["weight", "bias"]:
for gemm in [q, k, v]:
setattr(gemm, name, parameter)
layer.q = q
layer.k = k
layer.v = v
layer.qkv = None
return model
def set_prompt_tuning(model):
if isinstance(model.transformer.vocab_embedding, Embedding):
embedding = model.transformer.vocab_embedding
model.transformer.vocab_embedding = PromptTuningEmbedding(
num_embeddings=embedding.num_embeddings,
embedding_dim=embedding.embedding_dim,
dtype=embedding.dtype,
tp_size=embedding.tp_size,
tp_group=embedding.tp_group,
sharding_dim=embedding.sharding_dim,
tp_rank=embedding.tp_rank)
model.transformer.vocab_embedding.weight.value = embedding.weight.raw_value
return model
def add_lora(model, max_lora_rank: Optional[int]):
for name, layer in model.named_modules(remove_duplicate=True):
max_rank = max_lora_rank
if isinstance(layer, (Attention, BertAttention)):
if max_rank is None:
max_rank = min(
layer.hidden_size,
layer.num_attention_heads * layer.attention_head_size,
layer.num_attention_kv_heads * layer.attention_head_size)
layer.qkv_lora = Lora(
in_hidden_size=layer.hidden_size,
out_hidden_sizes=[
layer.num_attention_heads * layer.attention_head_size,
layer.num_attention_kv_heads * layer.attention_head_size,
layer.num_attention_kv_heads * layer.attention_head_size
],
max_low_rank=max_rank,
)
if isinstance(layer, (Linear, RowLinear)):
if max_rank is None:
max_rank = min(layer.in_features, layer.out_features)
layer.lora = Lora(
in_hidden_size=layer.in_features,
out_hidden_sizes=[layer.out_features],
max_low_rank=max_rank,
)
if isinstance(layer, FusedGatedMLP):
if max_rank is None:
max_rank = min(layer.hidden_size,
layer.ffn_hidden_size // layer.tp_size)
layer.mlp_in_lora = Lora(
in_hidden_size=layer.hidden_size,
out_hidden_sizes=[
layer.ffn_hidden_size // layer.tp_size,
layer.ffn_hidden_size // layer.tp_size
],
max_low_rank=max_rank,
)
return model
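# Illustrative use of optimize_model after building the model (the flags depend
# on the build configuration; the values here are examples only):
#
#   model = optimize_model(model, use_fused_mlp=True, use_lora=True,
#                          max_lora_rank=64)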
def optimize_model(
model,
use_fused_mlp=False,
use_unfused_qkv_gemm=False,
use_prompt_tuning=False,
use_lora=False,
max_lora_rank=None,
):
if use_fused_mlp:
model = fuse_gate_mlp(model)
if use_unfused_qkv_gemm:
model = unfuse_qkv_gemm(model)
if use_prompt_tuning:
model = set_prompt_tuning(model)
if use_lora:
model = add_lora(model, max_lora_rank)
return model
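# preprocess_weights rewrites checkpoint tensors in place into the layouts the
# runtime kernels expect (packing INT4/INT8 weight-only weights, viewing FP8
# weights as float8_e4m3fn, AWQ scale handling, etc.); from_checkpoint calls it
# before PretrainedModel.load.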
def preprocess_weights(
weights: Dict[str, torch.Tensor],
model_config: PretrainedConfig) -> None:
quant_algo = model_config.quantization.quant_algo
kv_cache_quant_algo = model_config.quantization.kv_cache_quant_algo
# AWQ (W4A16_AWQ and W4A8_AWQ)
if quant_algo == W4A8_AWQ or quant_algo == W4A16_AWQ:
preprocessor = torch.ops.trtllm.preprocess_weights_for_mixed_gemm
for name, param in weights.items():
if name.endswith('weight') and param.dtype == torch.int8:
dtype = torch.float16
if model_config.dtype == "bfloat16":
dtype = torch.bfloat16
weights[name] = preprocessor(param.T.contiguous(),
torch.quint4x2).view(dtype)
if name.endswith('weights_scaling_factor'):
weights[name] = param.T.contiguous().to(
str_dtype_to_torch(model_config.dtype))
if name.endswith('prequant_scaling_factor'):
weights[name] = param.reshape(1, -1)
if model_config.mapping.tp_rank > 0:
if name.endswith('attention.dense.bias') or name.endswith(
'mlp.proj.bias'):
weights[name] = torch.zeros_like(param)
if quant_algo == W4A8_AWQ:
for name in list(weights):
if name.endswith('weights_scaling_factor'):
activation_scaling_factor = weights.pop(
name.replace('weights_scaling_factor',
'activation_scaling_factor'))
weights_scaling_factor_2 = weights.pop(
name.replace('weights_scaling_factor',
'weights_scaling_factor_2'))
weights[name] /= weights_scaling_factor_2
weights[name.replace(
'weights_scaling_factor',
'prequant_scaling_factor')] /= activation_scaling_factor
weights[name.replace(
'weights_scaling_factor', 'alpha'
)] = activation_scaling_factor * weights_scaling_factor_2
# FP8
elif quant_algo == FP8:
for name, param in weights.items():
if name.endswith('weight') and param.dtype == torch.int8:
weights[name] = param.view(torch.float8_e4m3fn)
# lm_head is not quantized to FP8
if "lm_head.weight" in weights:
assert weights['lm_head.weight'].dtype == str_dtype_to_torch(
model_config.dtype)
weights.pop('lm_head.weights_scaling_factor', None)
weights.pop('lm_head.activation_scaling_factor', None)
# Weight only 4bit
elif quant_algo == W4A16:
for name in list(weights):
if any([
_name in name for _name in [
'qkv.weight', 'dense.weight', 'fc.weight',
'proj.weight', 'gate.weight'
]
]) and weights[name].dtype != torch.int8:
processed_torch_weights, torch_weight_scales = \
torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
weights[name].t().contiguous(), torch.quint4x2)
weights[name] = processed_torch_weights
weights[name.replace(
'.weight', '.per_channel_scale')] = torch_weight_scales
# Weight only 8bit
elif quant_algo == W8A16:
for name in list(weights):
if any([
_name in name for _name in [
'qkv.weight', 'dense.weight', 'fc.weight',
'proj.weight', 'gate.weight'
]
]) and weights[name].dtype != torch.int8:
processed_torch_weights, torch_weight_scales = \
torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
weights[name].t().contiguous(), torch.int8)
weights[name] = processed_torch_weights
weights[name.replace(
'.weight', '.per_channel_scale')] = torch_weight_scales
# FP8 kv_cache_scaling_factor is always 1.0
if kv_cache_quant_algo == FP8:
for name, param in weights.items():
if name.endswith('kv_cache_scaling_factor'):
weights[name] = torch.tensor([1.0], dtype=torch.float32)
# Add zero layer_norm biases when they are missing from the checkpoint (for MPT).
if model_config.architecture == 'MPTForCausalLM':
update_dict = {}
for name, param in weights.items():
if 'input_layernorm.weight' in name and name.replace(
'weight', 'bias') not in weights:
update_dict[name.replace('weight',
'bias')] = torch.zeros_like(param)
if 'post_layernorm.weight' in name and name.replace(
'weight', 'bias') not in weights:
update_dict[name.replace('weight',
'bias')] = torch.zeros_like(param)
if 'ln_f.weight' in name and name.replace('weight',
'bias') not in weights:
update_dict[name.replace('weight',
'bias')] = torch.zeros_like(param)
weights.update(update_dict)
# Parallel-block RowLinear layers should not duplicate the bias across TP ranks.
elif model_config.architecture == 'GPTJForCausalLM':
if model_config.mapping.tp_rank > 0:
for name, param in weights.items():
if 'attention.dense.bias' in name or 'mlp.proj.bias' in name:
weights[name] = torch.zeros_like(param)