TensorRT-LLMs/tensorrt_llm/models/modeling_utils.py
Kaiyu Xie 4bb65f216f
Update TensorRT-LLM (#1274)
* Update TensorRT-LLM

---------

Co-authored-by: meghagarwal <16129366+megha95@users.noreply.github.com>
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
2024-03-12 18:15:52 +08:00

703 lines
29 KiB
Python

import copy
import dataclasses
import json
import os
from typing import List, Optional, Union
import numpy as np
import safetensors
import torch
from .._common import default_net
from .._utils import (numpy_to_torch, str_dtype_to_torch, str_dtype_to_trt,
trt_dtype_to_torch)
from ..functional import PositionEmbeddingType, Tensor, gather_last_token_logits
from ..layers import (AttentionParams, FusedGatedMLP, GatedMLP,
KeyValueCacheParams, LoraParams)
from ..layers.attention import Attention
from ..layers.linear import ColumnLinear
from ..logger import logger
from ..mapping import Mapping
from ..module import Module, ModuleList
from ..quantization import QuantMode
from ..quantization.layers import FP8Linear
from ..quantization.mode import W8A8_SQ_PLUGIN_LIST
from ..quantization.quantize import quantize
from .generation_mixin import GenerationMixin
# Architecture names for which PretrainedModel.load() routes every parameter
# through its ``weight_loader`` attribute (falling back to
# default_weight_loader) instead of assigning ``param.value`` directly.
WEIGHT_LOADER_MODELS = {"PhiForCausalLM"}
@dataclasses.dataclass
class QuantizationConfig:
    '''Quantization settings stored as a plain dataclass.

    Instances are embedded in PretrainedConfig and converted to/from dicts
    with ``dataclasses.asdict`` / ``dataclasses.replace``, so the field
    names are part of the serialized format.
    '''
    # Name of the weight/activation quantization algorithm, None when unset.
    quant_algo: Optional[str] = None
    # Name of the KV-cache quantization algorithm, None when unset.
    kv_cache_quant_algo: Optional[str] = None
    group_size: Optional[int] = 128
    has_zero_point: Optional[bool] = False
    pre_quant_scale: Optional[bool] = False
    exclude_modules: Optional[List[str]] = None

    @property
    def use_plugin_sq(self):
        '''True when ``quant_algo`` is one of W8A8_SQ_PLUGIN_LIST.'''
        return self.quant_algo in W8A8_SQ_PLUGIN_LIST
def default_weight_loader(
    mapping: Mapping,
    param: torch.Tensor,
    loaded_weight: torch.Tensor,
) -> None:
    """Fallback loader: store ``loaded_weight`` on ``param`` unchanged.

    ``mapping`` is accepted so the signature matches the per-parameter
    ``weight_loader`` hooks; it is not used here.
    """
    param.value = loaded_weight
class PretrainedConfig:
    '''Model hyper-parameters, parallel mapping and quantization settings.

    The object can be built directly, from a dict (``from_dict``) or from a
    JSON file (``from_json_file``), and turned back into a dict with
    ``to_dict``.
    '''

    def __init__(self,
                 architecture: str,
                 dtype: str,
                 logits_dtype: str,
                 vocab_size: int,
                 max_position_embeddings: int,
                 hidden_size: int,
                 num_hidden_layers: int,
                 num_attention_heads: int,
                 num_key_value_heads: int,
                 hidden_act: str,
                 intermediate_size: int,
                 norm_epsilon: float,
                 position_embedding_type: str,
                 world_size: int,
                 tp_size: int,
                 pp_size: int,
                 quantization: Union[QuantizationConfig, dict],
                 use_prompt_tuning: bool = False,
                 use_parallel_embedding: bool = False,
                 embedding_sharding_dim: int = 0,
                 share_embedding_table: bool = False,
                 max_lora_rank: int = 64,
                 head_size: int = None,
                 **kwargs):
        '''Store the given settings as attributes.

        ``quantization`` may be a QuantizationConfig or a dict holding a
        subset of its fields; a dict is merged over the defaults.
        ``head_size`` defaults to ``hidden_size // num_attention_heads``.
        Any extra keyword argument is stored as an attribute of the same
        name.
        '''
        self.architecture = architecture
        self.dtype = dtype
        self.logits_dtype = logits_dtype
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_size = hidden_size // num_attention_heads if head_size is None else head_size
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.norm_epsilon = norm_epsilon
        self.position_embedding_type = PositionEmbeddingType.from_string(
            position_embedding_type)
        self.use_prompt_tuning = use_prompt_tuning
        self.use_parallel_embedding = use_parallel_embedding
        self.embedding_sharding_dim = embedding_sharding_dim
        self.share_embedding_table = share_embedding_table
        self.mapping = Mapping(world_size=world_size,
                               tp_size=tp_size,
                               pp_size=pp_size)

        # Normalise `quantization` BEFORE reading its fields: a dict has no
        # `.quant_algo` attribute, so deriving quant_mode first would fail
        # for the dict form this constructor advertises.
        if isinstance(quantization, dict):
            quantization = dataclasses.replace(QuantizationConfig(),
                                               **quantization)
        else:
            assert isinstance(
                quantization, QuantizationConfig
            ), f"Expecting type of QuantizationConfig, found {type(quantization)}"
        self.quantization = quantization

        # Ideally only one of quant_mode and quantization would be kept as
        # the single source of truth; both stay because many code paths
        # still use quant_mode.
        self.quant_mode = QuantMode.from_quant_algo(
            quantization.quant_algo, quantization.kv_cache_quant_algo)

        # The KV cache uses the model dtype unless quantized.
        self.kv_dtype = self.dtype
        self.max_lora_rank = max_lora_rank
        if self.quant_mode.has_int8_kv_cache():
            self.kv_dtype = 'int8'
        elif self.quant_mode.has_fp8_kv_cache():
            self.kv_dtype = 'fp8'

        # Model-specific extras; an AttributeError from setattr propagates.
        for key, value in kwargs.items():
            setattr(self, key, value)

    def set_if_not_exist(self, key, value):
        '''Set attribute ``key`` to ``value`` only if it is not already set.'''
        if not hasattr(self, key):
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, config):
        '''Build a config from a dict without modifying the caller's dict.

        Required keys: architecture, dtype, vocab_size, hidden_size,
        num_hidden_layers, num_attention_heads, hidden_act. Other known keys
        fall back to defaults; unknown keys are forwarded to the constructor
        as keyword arguments.

        Raises:
            KeyError: a required key is missing.
            NotImplementedError: share_embedding_table is requested with
                tp_size > 1 but without use_parallel_embedding=True and
                embedding_sharding_dim=0.
        '''
        # Many pops below; work on a private copy so the function has no
        # side effects on the argument.
        config = copy.deepcopy(config)

        architecture = config.pop('architecture')
        dtype = config.pop('dtype')
        vocab_size = config.pop('vocab_size')
        hidden_size = config.pop('hidden_size')
        num_hidden_layers = config.pop('num_hidden_layers')
        num_attention_heads = config.pop('num_attention_heads')
        hidden_act = config.pop('hidden_act')
        norm_epsilon = config.pop('norm_epsilon', 1e-5)
        position_embedding_type = config.pop('position_embedding_type',
                                             'learned_absolute')
        logits_dtype = config.pop('logits_dtype', 'float32')
        num_key_value_heads = config.pop('num_key_value_heads',
                                         num_attention_heads)
        intermediate_size = config.pop('intermediate_size', None)
        max_position_embeddings = config.pop('max_position_embeddings', None)
        use_prompt_tuning = config.pop('use_prompt_tuning', False)
        use_parallel_embedding = config.pop('use_parallel_embedding', False)
        embedding_sharding_dim = config.pop('embedding_sharding_dim', 0)
        share_embedding_table = config.pop('share_embedding_table', False)

        mapping = config.pop('mapping', {
            'world_size': 1,
            'tp_size': 1,
            'pp_size': 1
        })
        world_size = mapping.get('world_size', 1)
        tp_size = mapping.get('tp_size', 1)
        pp_size = mapping.get('pp_size', 1)

        if share_embedding_table and tp_size > 1:
            if (not use_parallel_embedding) or embedding_sharding_dim == 1:
                raise NotImplementedError(
                    "For multiple-processes cases, sharing the embedding table must set "
                    "use_parallel_embedding=True and embedding_sharding_dim=0")

        quant_config = QuantizationConfig()
        if 'quantization' in config:
            quant_config_from_user = config.pop('quantization')
            if isinstance(quant_config_from_user, dict):
                # A dict overrides the defaults field by field, so a partial
                # set of fields is allowed.
                quant_config = dataclasses.replace(quant_config,
                                                   **quant_config_from_user)
            else:
                # A ready-made QuantizationConfig is also accepted.
                assert isinstance(quant_config_from_user, QuantizationConfig)
                quant_config = quant_config_from_user

        max_lora_rank = config.pop('max_lora_rank', 64)

        return cls(architecture, dtype, logits_dtype, vocab_size,
                   max_position_embeddings, hidden_size, num_hidden_layers,
                   num_attention_heads, num_key_value_heads, hidden_act,
                   intermediate_size, norm_epsilon, position_embedding_type,
                   world_size, tp_size, pp_size, quant_config,
                   use_prompt_tuning, use_parallel_embedding,
                   embedding_sharding_dim, share_embedding_table, max_lora_rank,
                   **config)

    @classmethod
    def from_json_file(cls, config_file: str):
        '''Load ``config_file`` as JSON and build a config via ``from_dict``.'''
        with open(config_file) as f:
            config = json.load(f)
        # Dispatch through cls so a subclass gets an instance of itself.
        return cls.from_dict(config)

    def to_dict(self):
        '''Return a deep copy of the attributes as a dict.

        ``position_embedding_type`` is stringified, ``mapping`` is reduced
        to its world/tp/pp sizes, ``quantization`` becomes a plain dict and
        the derived ``quant_mode`` is dropped.
        '''
        output = copy.deepcopy(self.__dict__)
        output['position_embedding_type'] = str(self.position_embedding_type)
        output['mapping'] = {
            'world_size': self.mapping.world_size,
            'tp_size': self.mapping.tp_size,
            'pp_size': self.mapping.pp_size,
        }
        output.pop('quant_mode')
        output['quantization'] = dataclasses.asdict(self.quantization)
        return output

    def set_rank(self, rank):
        '''Replace ``self.mapping`` with one for ``rank``, keeping the sizes.'''
        self.mapping = Mapping(self.mapping.world_size,
                               rank=rank,
                               tp_size=self.mapping.tp_size,
                               pp_size=self.mapping.pp_size)
class DecoderLayerList(ModuleList):
    '''The decoder layers owned by the current pipeline-parallel rank.'''

    def __init__(self, cls, config):
        '''Instantiate ``cls(config, idx)`` for each index returned by
        ``config.mapping.pp_layers(config.num_hidden_layers)``.'''
        self.layer_list = config.mapping.pp_layers(config.num_hidden_layers)
        super().__init__([cls(config, idx) for idx in self.layer_list])

    def forward(self,
                hidden_states,
                use_cache=False,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                lora_params=None,
                medusa_position_offsets=None,
                medusa_packed_mask=None):
        '''Feed ``hidden_states`` through every layer in order.

        Returns ``(hidden_states, presents)`` when ``use_cache`` is truthy,
        where ``presents`` collects element 1 of each layer's output;
        otherwise returns only the final ``hidden_states``.
        '''
        kv_cache_params.fill_none_tensor_list(len(self.layer_list))

        # Only populated (and only read) when use_cache is truthy.
        presents = [] if use_cache else None

        # The medusa arguments are identical for every layer.
        medusa_kwargs = {}
        if medusa_position_offsets is not None:
            medusa_kwargs['medusa_position_offsets'] = medusa_position_offsets
        if medusa_packed_mask is not None:
            medusa_kwargs['medusa_packed_mask'] = medusa_packed_mask

        layers_with_past = zip(self, kv_cache_params.past_key_value)
        for layer_idx, (layer, past) in enumerate(layers_with_past):
            extra_kwargs = {}
            if lora_params is not None and lora_params.lora_ranks is not None:
                lora_layer_params = lora_params.get_layer_params(layer_idx)
                if lora_layer_params is not None:
                    extra_kwargs['lora_layer_params'] = lora_layer_params
            extra_kwargs.update(medusa_kwargs)

            # Each layer sees only its own past tensor; the remaining cache
            # fields are passed through unchanged.
            layer_kv_cache_params = KeyValueCacheParams(
                past_key_value=[past],
                host_past_key_value_lengths=kv_cache_params.
                host_past_key_value_lengths,
                host_max_attention_window_sizes=kv_cache_params.
                host_max_attention_window_sizes,
                host_sink_token_length=kv_cache_params.host_sink_token_length,
                kv_cache_block_pointers=kv_cache_params.
                kv_cache_block_pointers,
                host_kv_cache_block_pointers=kv_cache_params.
                host_kv_cache_block_pointers,
                cache_indirection=kv_cache_params.cache_indirection)

            hidden_states = layer(hidden_states,
                                  use_cache=use_cache,
                                  attention_mask=attention_mask,
                                  kv_cache_params=layer_kv_cache_params,
                                  attention_params=attention_params,
                                  **extra_kwargs)

            if use_cache:
                presents.append(hidden_states[1])
                hidden_states = hidden_states[0]

        if use_cache:
            return hidden_states, presents
        return hidden_states
class PostInitCaller(type):
    '''Metaclass that calls ``__post_init__()`` on every new instance.

    The hook runs after regular construction (``__new__`` / ``__init__``)
    has finished, and the instance is then returned to the caller.
    '''

    def __call__(cls, *args, **kwargs):
        instance = type.__call__(cls, *args, **kwargs)
        instance.__post_init__()
        return instance
class PretrainedModel(Module, GenerationMixin, metaclass=PostInitCaller):
    '''Base class for models built from a PretrainedConfig.

    Provides construction from a config or checkpoint directory, weight
    loading and saving, and preparation of the network inputs. Through the
    PostInitCaller metaclass, ``__post_init__`` runs after ``__init__``.
    '''

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config

    def __post_init__(self):
        '''Apply ``quantize`` to the model using the config's quantization.'''
        quantize(self, self.config.quant_mode,
                 **dataclasses.asdict(self.config.quantization))

    def check_config(self, config):
        '''Hook for subclasses; the base implementation always raises.'''
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        '''Construct the model from ``config``.'''
        return cls(config)

    @classmethod
    def from_checkpoint(cls,
                        ckpt_dir: str,
                        rank: int = 0,
                        config: PretrainedConfig = None):
        '''Build a model and load ``rank{rank}.safetensors`` from ``ckpt_dir``.

        When ``config`` is None it is read from ``ckpt_dir/config.json``.
        '''
        if config is None:
            config = PretrainedConfig.from_json_file(
                os.path.join(ckpt_dir, 'config.json'))
            # NOTE(review): rank is applied only to a config loaded here; a
            # caller-supplied config is assumed to already carry its rank -
            # confirm against callers.
            config.set_rank(rank)

        model = cls.from_config(config)

        weights = {}
        with safetensors.safe_open(os.path.join(ckpt_dir,
                                                f'rank{rank}.safetensors'),
                                   framework='pt',
                                   device='cpu') as f:
            for key in f.keys():
                weights[key] = f.get_tensor(key)

        model.load(weights)
        return model

    def load(self, weights):
        '''Assign ``weights[name]`` to every named parameter.

        Every parameter name must be present in ``weights``; extra entries
        are ignored. For architectures in WEIGHT_LOADER_MODELS the
        parameter's ``weight_loader`` attribute is used when present.

        Raises:
            AssertionError: a parameter has no entry in ``weights``.
            RuntimeError: assigning a weight failed (direct-assignment path);
                the original exception is attached as the cause.
        '''
        expected_names = {name for name, _ in self.named_parameters()}
        provided_names = set(weights)
        assert expected_names.issubset(
            provided_names
        ), f"Expected but not provided tensors:{expected_names.difference(provided_names)}"

        if self.config.architecture in WEIGHT_LOADER_MODELS:
            mapping = self.config.mapping
            for name, param in self.named_parameters():
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(mapping, param, weights[name])
        else:
            for name, param in self.named_parameters():
                try:
                    param.value = weights[name]
                except Exception as e:
                    # Name the failing parameter and keep the original
                    # exception as the explicit cause.
                    raise RuntimeError(
                        f"Encounter error '{e}' for parameter '{name}'"
                    ) from e

    def load_partial_weights(self, weights: dict):
        '''Load the entries of ``weights`` that match a parameter name.

        Unmatched entries are skipped; they are reported with a warning only
        when ``pp_size == 1``.
        '''
        params = dict(self.named_parameters())
        mapping = self.config.mapping

        for k, v in weights.items():
            if k in params:
                param = params[k]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(mapping, param, v)
            elif mapping.pp_size == 1:
                logger.warning(f"Provided but not expected tensors: {k}")

    def save_checkpoint(self, output_dir, save_config=True):
        '''Write ``rank{rank}.safetensors`` and optionally ``config.json``.

        Multiple ranks could share the same config.json, so ``save_config``
        lets the caller avoid writing it from every rank.
        '''
        rank = self.config.mapping.rank
        weights = {
            name: numpy_to_torch(param.raw_value)
            for name, param in self.named_parameters()
        }
        from safetensors.torch import save_file
        save_file(weights, os.path.join(output_dir, f'rank{rank}.safetensors'))

        if save_config:
            with open(os.path.join(output_dir, 'config.json'), 'w') as f:
                json.dump(self.config.to_dict(), f, indent=4)

    def prepare_inputs(self,
                       max_batch_size,
                       max_input_len,
                       max_seq_len,
                       use_cache,
                       max_beam_width: int = 1,
                       max_num_tokens: int = None,
                       prompt_embedding_table_size: int = 0,
                       position_encoding_2d: bool = False,
                       max_draft_len: int = 0,
                       gather_context_logits: bool = False,
                       gather_generation_logits: bool = False,
                       lora_target_modules: List[str] = None):
        '''@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the
            ranges of the dimensions of when using TRT dynamic shapes.

            @return: a dict of values which can be fed into the self.forward()
        '''
        # All plugin switches come from the same plugin config; read it once.
        plugin_config = default_net().plugin_config
        remove_input_padding = plugin_config.remove_input_padding
        use_gpt_attention_plugin = plugin_config.gpt_attention_plugin
        use_gemm_plugin = plugin_config.gemm_plugin
        paged_kv_cache = plugin_config.paged_kv_cache
        tokens_per_block = plugin_config.tokens_per_block
        use_custom_all_reduce = plugin_config.use_custom_all_reduce
        use_lora_plugin = plugin_config.lora_plugin

        model_inputs = self.prepare_basic_inputs(
            max_batch_size=max_batch_size,
            max_beam_width=max_beam_width,
            max_input_len=max_input_len,
            max_seq_len=max_seq_len,
            num_kv_heads=self.config.num_key_value_heads,
            head_size=self.config.head_size,
            num_layers=self.config.num_hidden_layers,
            kv_dtype=str_dtype_to_trt(self.config.kv_dtype),
            remove_input_padding=remove_input_padding,
            use_gpt_attention_plugin=use_gpt_attention_plugin,
            use_gemm_plugin=use_gemm_plugin,
            paged_kv_cache=paged_kv_cache,
            tokens_per_block=tokens_per_block,
            num_heads=self.config.num_attention_heads,
            max_num_tokens=max_num_tokens,
            dtype=str_dtype_to_trt(self.config.dtype),
            prompt_embedding_table_size=prompt_embedding_table_size,
            position_encoding_2d=position_encoding_2d,
            mapping=self.config.mapping,
            gather_context_logits=gather_context_logits,
            gather_generation_logits=gather_generation_logits,
            use_custom_all_reduce=use_custom_all_reduce,
            use_lora_plugin=use_lora_plugin,
            max_draft_len=max_draft_len,
            lora_target_modules=lora_target_modules)

        result = {
            'input_ids':
            model_inputs['input_ids'],
            'position_ids':
            model_inputs['position_ids'],
            # NOTE(review): the use_cache argument is not consulted; the
            # returned dict always requests caching.
            'use_cache':
            True,
            'last_token_ids':
            model_inputs['last_token_ids'],
            'attention_mask':
            model_inputs['attention_mask'],
            'kv_cache_params':
            KeyValueCacheParams(
                past_key_value=model_inputs['past_key_value'],
                host_past_key_value_lengths=model_inputs[
                    'host_past_key_value_lengths'],
                host_max_attention_window_sizes=model_inputs[
                    'host_max_attention_window_sizes'],
                host_sink_token_length=model_inputs['host_sink_token_length'],
                kv_cache_block_pointers=model_inputs['kv_cache_block_pointers'],
                host_kv_cache_block_pointers=model_inputs[
                    'host_kv_cache_block_pointers'],
                cache_indirection=model_inputs['cache_indirection'],
            ),
            'attention_params':
            AttentionParams(
                sequence_length=model_inputs['sequence_length'],
                context_lengths=model_inputs['context_lengths'],
                host_context_lengths=model_inputs['host_context_lengths'],
                max_context_length=max_input_len,
                host_request_types=model_inputs['host_request_types'])
        }

        if prompt_embedding_table_size > 0:
            result['prompt_embedding_table'] = model_inputs[
                'prompt_embedding_table']
            result['prompt_tasks'] = model_inputs['tasks']
            result['prompt_vocab_size'] = model_inputs['prompt_vocab_size']

        if model_inputs['hidden_states_input'] is not None:
            result['hidden_states'] = model_inputs['hidden_states_input']

        if use_lora_plugin:
            result['lora_params'] = LoraParams(
                model_inputs['lora_ranks'],
                model_inputs['lora_weights_pointers'],
                host_context_lengths=model_inputs['host_context_lengths'],
                max_context_length=max_input_len,
                host_request_types=model_inputs['host_request_types'])

        return result
class DecoderModelForCausalLM(PretrainedModel):
    '''A decoder ``transformer`` combined with an ``lm_head``.'''

    def __init__(self, config, transformer, lm_head):
        super().__init__(config)
        self.transformer = transformer
        self.lm_head = lm_head

    def forward(self,
                input_ids: Tensor,
                position_ids=None,
                use_cache=False,
                last_token_ids=None,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                hidden_states=None,
                prompt_embedding_table: Optional[Tensor] = None,
                prompt_tasks: Optional[Tensor] = None,
                prompt_vocab_size: Optional[Tensor] = None,
                lora_params=None,
                medusa_position_offsets=None,
                medusa_packed_mask=None):
        '''Run the transformer and, on the last pipeline rank, the LM head.

        Marks ``logits`` (last pipeline rank) or ``hidden_states_output``
        (other ranks) as network outputs, plus one ``present_key_value_{i}``
        per local layer when caching without the paged KV cache.

        Returns, depending on the branch taken:
            caching without paged KV cache:
                ``(lm_logits, presents, hidden_states)`` on the last
                pipeline rank, else ``(hidden_states, presents)``;
            otherwise:
                ``(lm_logits, hidden_states)`` on the last pipeline rank,
                else ``hidden_states``.
        '''
        transformer_kwargs = {
            'input_ids': input_ids,
            'position_ids': position_ids,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
            'kv_cache_params': kv_cache_params,
            'attention_params': attention_params,
        }
        # Optional arguments are forwarded only when supplied.
        optional_kwargs = {
            'lora_params': lora_params,
            'hidden_states': hidden_states,
            'prompt_embedding_table': prompt_embedding_table,
            'prompt_tasks': prompt_tasks,
            'prompt_vocab_size': prompt_vocab_size,
            'medusa_position_offsets': medusa_position_offsets,
            'medusa_packed_mask': medusa_packed_mask,
        }
        for key, value in optional_kwargs.items():
            if value is not None:
                transformer_kwargs[key] = value

        outputs = self.transformer.forward(**transformer_kwargs)
        if use_cache:
            hidden_states, presents = outputs
        else:
            hidden_states = outputs

        on_last_pp_rank = self.config.mapping.is_last_pp_rank()
        if on_last_pp_rank:
            hidden_states = gather_last_token_logits(
                hidden_states, last_token_ids,
                default_net().plugin_config.remove_input_padding)

            # [batch_size, hidden_size] -> [batch_size, vocab_size]
            lm_logits = self.lm_head(hidden_states)
            lm_logits.mark_output('logits', self.config.logits_dtype)
        else:
            hidden_states.mark_output('hidden_states_output', self.config.dtype)

        if use_cache and not default_net().plugin_config.paged_kv_cache:
            local_layers = self.config.mapping.pp_layers(
                self.config.num_hidden_layers)
            for i, present in zip(local_layers, presents):
                present.mark_output(f'present_key_value_{i}',
                                    self.config.kv_dtype)
            if on_last_pp_rank:
                return (lm_logits, presents, hidden_states)
            return (hidden_states, presents)

        if on_last_pp_rank:
            return lm_logits, hidden_states
        return hidden_states
def fuse_gate_mlp(model):
    '''Replace each GatedMLP in ``model.transformer.layers`` by a FusedGatedMLP.

    The gate and fc weights (and, when unquantized, biases) are concatenated
    along axis 0 into ``fused_fc``; ``proj`` is reused as-is. Layers without
    an ``mlp`` attribute, or whose ``mlp`` is not a GatedMLP, are untouched.

    Returns the same ``model`` object, modified in place.

    Raises:
        ValueError: ``quant_algo`` is neither ``None`` nor ``'FP8'``.
    '''
    for layer in model.transformer.layers:
        if not hasattr(layer, 'mlp'):
            continue

        quant_algo = model.config.quantization.quant_algo
        mlp = layer.mlp
        if not isinstance(mlp, GatedMLP):
            continue

        fused_layer = FusedGatedMLP(hidden_size=mlp.hidden_size,
                                    ffn_hidden_size=mlp.ffn_hidden_size,
                                    hidden_act=mlp.hidden_act,
                                    bias=mlp.bias,
                                    dtype=mlp.dtype,
                                    tp_group=mlp.tp_group,
                                    tp_size=mlp.tp_size,
                                    quant_mode=mlp.quant_mode,
                                    max_lora_rank=mlp.max_lora_rank)
        gate, fc = mlp.gate, mlp.fc
        fused_fc = fused_layer.fused_fc

        if quant_algo == 'FP8':
            if isinstance(mlp.dtype, str):
                dtype = str_dtype_to_torch(mlp.dtype)
            else:
                dtype = trt_dtype_to_torch(mlp.dtype)

            # dequantize each half with its own scaling factor
            gate_weight = numpy_to_torch(gate.weight.raw_value).to(dtype)
            gate_weight = gate_weight * numpy_to_torch(
                gate.weights_scaling_factor.raw_value)
            fc_weight = numpy_to_torch(fc.weight.raw_value).to(dtype)
            fc_weight = fc_weight * numpy_to_torch(
                fc.weights_scaling_factor.raw_value)

            # concat
            fused_weight = torch.cat([gate_weight, fc_weight], dim=0)

            # quantize again with the larger of the two scaling factors
            fused_weight_scaling_factor = numpy_to_torch(
                max(
                    gate.weights_scaling_factor.raw_value,
                    fc.weights_scaling_factor.raw_value,
                ))
            fused_weight = (fused_weight / fused_weight_scaling_factor).to(
                torch.float8_e4m3fn)

            fused_fc.weight.value = fused_weight
            fused_fc.weights_scaling_factor.value = fused_weight_scaling_factor
            fused_fc.activation_scaling_factor.value = max(
                gate.activation_scaling_factor.raw_value,
                fc.activation_scaling_factor.raw_value)
        elif quant_algo is None:
            fused_fc.weight.value = np.concatenate(
                [gate.weight.raw_value, fc.weight.raw_value], axis=0)
            if mlp.bias:
                fused_fc.bias.value = np.concatenate(
                    [gate.bias.raw_value, fc.bias.raw_value], axis=0)
        else:
            raise ValueError(f'Unsupported quant algo: {quant_algo}')

        fused_layer.proj = mlp.proj
        layer.mlp = fused_layer

    return model
def unfuse_qkv_gemm(model):
    '''Split the fused ``qkv`` linear of each self-attention layer into
    separate ``q``, ``k`` and ``v`` linears.

    Cross-attention layers and layers already flagged ``unfuse_qkv_gemm``
    are skipped. Initialized weights and biases are split between the new
    linears; any other parameter of ``qkv`` is set on all three. Afterwards
    ``layer.qkv`` is set to None.

    Returns the same ``model`` object, modified in place.
    '''
    for _, layer in model.named_modules(remove_duplicate=True):
        if not isinstance(layer, Attention) or layer.cross_attention:
            continue
        assert layer.tp_size == 1, "please disable manual tp when enable auto parallel"
        if layer.unfuse_qkv_gemm:
            continue
        layer.unfuse_qkv_gemm = True

        linear_class = FP8Linear if layer.use_fp8_qdq else ColumnLinear

        def make_linear(out_features):
            # q, k and v differ only in their output width.
            return linear_class(layer.hidden_size,
                                out_features,
                                bias=layer.bias,
                                dtype=layer.dtype,
                                gather_output=False)

        q = make_linear(layer.attention_hidden_size)
        k = make_linear(layer.num_attention_kv_heads *
                        layer.attention_head_size)
        v = make_linear(layer.num_attention_kv_heads *
                        layer.attention_head_size)
        gemms = [q, k, v]

        fused = layer.qkv
        if fused.weight.is_inited():
            split_points = [q.out_features, q.out_features + k.out_features]
            pieces = np.split(fused.weight.raw_value, split_points)
            for gemm, piece in zip(gemms, pieces):
                gemm.weight.value = piece

        if fused.bias is not None and fused.bias.is_inited():
            split_points = [q.out_features, q.out_features + k.out_features]
            pieces = np.split(fused.bias.raw_value, split_points)
            for gemm, piece in zip(gemms, pieces):
                gemm.bias.value = piece

        # Remaining parameters of the fused linear go to all three linears.
        for param_name, parameter in fused._parameters.items():
            if param_name in ("weight", "bias"):
                continue
            for gemm in gemms:
                setattr(gemm, param_name, parameter)

        layer.q = q
        layer.k = k
        layer.v = v
        layer.qkv = None

    return model
def optimize_model(model, use_fused_mlp=False, use_unfused_qkv_gemm=False):
    '''Apply the requested model rewrites and return the resulting model.

    ``use_fused_mlp`` runs fuse_gate_mlp, then ``use_unfused_qkv_gemm`` runs
    unfuse_qkv_gemm. With both flags falsy the model is returned unchanged.
    '''
    optimized = model
    if use_fused_mlp:
        optimized = fuse_gate_mlp(optimized)
    if use_unfused_qkv_gemm:
        optimized = unfuse_qkv_gemm(optimized)
    return optimized