TensorRT-LLMs/tensorrt_llm/models/modeling_utils.py
Kaiyu Xie 9dbc5b38ba
Update TensorRT-LLM (#1891)
* Update TensorRT-LLM

---------

Co-authored-by: Marks101 <markus.schnoes@gmx.de>
Co-authored-by: lkm2835 <lkm2835@gmail.com>
2024-07-04 14:37:19 +08:00

1195 lines
49 KiB
Python

import argparse
import copy
import dataclasses
import json
import os
from enum import IntFlag, auto
from functools import cached_property
from typing import Dict, List, Optional, Union
import numpy as np
import safetensors
import torch
from .._common import default_net
from .._utils import (get_init_params, numpy_to_torch, release_gc,
str_dtype_to_torch, str_dtype_to_trt, trt_dtype_to_torch)
from ..functional import PositionEmbeddingType, Tensor, gather_last_token_logits
from ..layers import (AttentionParams, Embedding, FusedGatedMLP, FusedRgLru,
GatedMLP, KeyValueCacheParams, LoraParams,
PromptTuningEmbedding, RgLru)
from ..layers.attention import Attention, BertAttention
from ..layers.linear import ColumnLinear, Linear, RowLinear
from ..layers.lora import Lora
from ..layers.moe import MOE, MoeOOTB
from ..logger import logger
from ..mapping import Mapping
from ..module import Module, ModuleList
from ..parameter import Parameter
from ..quantization import QuantMode
from ..quantization.layers import (WeightOnlyGroupwiseQuantLinear,
WeightOnlyGroupwiseQuantRowLinear,
WeightOnlyQuantLinear,
WeightOnlyQuantRowLinear)
from ..quantization.mode import W8A8_SQ_PLUGIN_LIST, QuantAlgo
from ..top_model_mixin import TopModelMixin
from .convert_utils import weight_only_quantize_dict
from .generation_mixin import GenerationMixin
# Architectures whose parameters are populated through per-parameter
# `weight_loader` hooks. For these, PretrainedModel.from_checkpoint reads
# 'rank0.safetensors' regardless of the configured rank.
WEIGHT_LOADER_MODELS = {"PhiForCausalLM"}
class SpeculativeDecodingMode(IntFlag):
    """Flag values identifying the speculative decoding strategy."""
    # [WARNING] KEEP BELOW DEFINITION IN SYNC WITH cpp/tensorrt_llm/runtime/speculativeDecodingMode.h
    NONE = auto()
    DRAFT_TOKENS_EXTERNAL = auto()
    MEDUSA = auto()
    LOOKAHEAD_DECODING = auto()
    EXPLICIT_DRAFT_TOKENS = auto()

    @staticmethod
    def from_arguments(args: argparse.Namespace):
        """Translate `args.speculative_decoding_mode` into a flag value.

        Args:
            args: Parsed arguments carrying a `speculative_decoding_mode`
                attribute that is either None or one of the supported names.

        Returns:
            The matching SpeculativeDecodingMode member (NONE for None).

        Raises:
            AssertionError: If the mode name is not recognized.
        """
        mode = args.speculative_decoding_mode
        if mode is None:
            return SpeculativeDecodingMode.NONE
        elif mode == "draft_tokens_external":
            return SpeculativeDecodingMode.DRAFT_TOKENS_EXTERNAL
        elif mode == "medusa":
            return SpeculativeDecodingMode.MEDUSA
        elif mode == "lookahead_decoding":
            return SpeculativeDecodingMode.LOOKAHEAD_DECODING
        elif mode == "explicit_draft_tokens":
            return SpeculativeDecodingMode.EXPLICIT_DRAFT_TOKENS
        # Raise explicitly instead of `assert False`: assert statements are
        # removed under `python -O`, which would make this return None.
        raise AssertionError(f"Unknown speculative_decoding_mode {mode}")
@dataclasses.dataclass
class QuantConfig:
    '''Serializable quantization configuration class, part of the PretrainedConfig
    '''
    # Weight/activation quantization algorithm; None means no quantization.
    quant_algo: Optional[QuantAlgo] = None
    # Quantization algorithm for the KV cache.
    kv_cache_quant_algo: Optional[QuantAlgo] = None
    # Forwarded as `awq_block_size` by PretrainedModel.quantize.
    group_size: Optional[int] = 128
    smoothquant_val: Optional[float] = None
    has_zero_point: Optional[bool] = False
    pre_quant_scale: Optional[bool] = False
    exclude_modules: Optional[List[str]] = None

    @property
    def use_plugin_sq(self):
        '''Whether `quant_algo` is one of the algorithms in W8A8_SQ_PLUGIN_LIST.'''
        return self.quant_algo in W8A8_SQ_PLUGIN_LIST

    @cached_property
    def quant_mode(self) -> QuantMode:
        '''QuantMode derived from both algorithms; computed once per instance.'''
        return QuantMode.from_quant_algo(self.quant_algo,
                                         self.kv_cache_quant_algo)

    def quant_algo_to_modelopt_qformat(self):
        '''Return the Modelopt qformat name for `quant_algo` ('full_prec' when unset).'''
        modelopt_names = {
            QuantAlgo.W8A16: "int8_wo",
            QuantAlgo.W4A16: "int4_wo",
            QuantAlgo.W4A16_AWQ: "int4_awq",
            QuantAlgo.W4A8_AWQ: 'w4a8_awq',
            QuantAlgo.FP8: 'fp8',
            QuantAlgo.W8A8_SQ_PER_CHANNEL: 'int8_sq',
        }
        if self.quant_algo is None:
            return 'full_prec'
        assert self.quant_algo in modelopt_names, f"We don't use Modelopt for quantization algorithm {self.quant_algo}, you probably shall not call this"
        return modelopt_names[self.quant_algo]

    @classmethod
    def from_dict(cls, config: dict):
        '''Build an instance from a mapping of field names to values.'''
        return cls(**config)

    def to_dict(self):
        '''Return the dataclass fields as a plain dict.'''
        return dataclasses.asdict(self)
def default_weight_loader(mapping: Mapping, param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    """Default weight loader: store `loaded_weight` on `param.value` unchanged.

    `mapping` is not used here; it is part of the signature shared with the
    per-parameter `weight_loader` hooks.
    """
    setattr(param, "value", loaded_weight)
def save_checkpoint(output_dir: str, config: dict, weights: dict) -> None:
    """ Checkpoint saver for weight loader.

    Writes `config` to 'config.json' and `weights` to 'rank0.safetensors',
    both inside `output_dir`.
    """
    config_path = os.path.join(output_dir, 'config.json')
    weights_path = os.path.join(output_dir, 'rank0.safetensors')
    with open(config_path, 'w') as config_file:
        json.dump(config, config_file, indent=4)
    safetensors.torch.save_file(weights, weights_path)
class PretrainedConfig:
    """Model configuration: architecture, dtypes, sizes, mapping and quantization.

    Keyword arguments that are not named parameters of `__init__` are stored
    as attributes on the instance, with a warning logged for each.
    """

    def __init__(self,
                 *,
                 architecture: str,
                 dtype: str,
                 hidden_size: int,
                 num_hidden_layers: int,
                 num_attention_heads: int,
                 vocab_size: Optional[int] = None,
                 hidden_act: str = 'gelu',
                 logits_dtype: str = 'float32',
                 norm_epsilon: float = 1e-5,
                 position_embedding_type: Union[
                     PositionEmbeddingType,
                     str] = PositionEmbeddingType.learned_absolute,
                 max_position_embeddings: Optional[int] = None,
                 num_key_value_heads: Optional[int] = None,
                 intermediate_size: Optional[int] = None,
                 mapping: Optional[Union[Mapping, dict]] = None,
                 quantization: Optional[Union[QuantConfig, dict]] = None,
                 use_parallel_embedding: bool = False,
                 embedding_sharding_dim: int = 0,
                 share_embedding_table: bool = False,
                 head_size: Optional[int] = None,
                 qk_layernorm: bool = False,
                 **kwargs):
        """Store the configuration, filling in derived defaults.

        Defaults applied when the argument is None:
            num_key_value_heads -> num_attention_heads
            intermediate_size   -> hidden_size * 4
            head_size           -> hidden_size // num_attention_heads
            mapping             -> Mapping()
            quantization        -> QuantConfig()

        `position_embedding_type`, `mapping` and `quantization` may also be
        given in their serialized form (str / dict / dict respectively).

        Raises:
            NotImplementedError: If `share_embedding_table` is combined with
                tensor parallelism without `use_parallel_embedding=True` and
                `embedding_sharding_dim=0`, or with pipeline parallelism.
        """
        self.architecture = architecture
        self.dtype = dtype
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.logits_dtype = logits_dtype
        self.norm_epsilon = norm_epsilon

        if isinstance(position_embedding_type, str):
            position_embedding_type = PositionEmbeddingType.from_string(
                position_embedding_type)
        assert isinstance(position_embedding_type, PositionEmbeddingType)
        self.position_embedding_type = position_embedding_type

        self.max_position_embeddings = max_position_embeddings

        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        if intermediate_size is None:
            intermediate_size = hidden_size * 4
        self.intermediate_size = intermediate_size

        if mapping is None:
            mapping = Mapping()
        elif isinstance(mapping, dict):
            mapping = Mapping.from_dict(mapping)
        assert isinstance(mapping, Mapping)
        self.mapping = mapping

        if quantization is None:
            quantization = QuantConfig()
        elif isinstance(quantization, dict):
            quantization = QuantConfig.from_dict(quantization)
        assert isinstance(quantization, QuantConfig)
        self.quantization = quantization

        self.use_parallel_embedding = use_parallel_embedding
        self.embedding_sharding_dim = embedding_sharding_dim
        self.share_embedding_table = share_embedding_table

        if share_embedding_table and mapping.tp_size > 1:
            if not use_parallel_embedding or embedding_sharding_dim == 1:
                raise NotImplementedError(
                    "For tensor parallelism, sharing the embedding table must set "
                    "use_parallel_embedding=True and embedding_sharding_dim=0")
        if share_embedding_table and mapping.pp_size > 1:
            raise NotImplementedError(
                "Embedding table cannot be shared for pipeline parallelism")

        if head_size is None:
            head_size = hidden_size // num_attention_heads
        self.head_size = head_size
        self.qk_layernorm = qk_layernorm

        # Extra, model-specific settings are kept as plain attributes.
        for key, value in kwargs.items():
            setattr(self, key, value)
            logger.warning(
                f"Implicitly setting {self.__class__.__name__}.{key} = {value}"
            )

    @property
    def kv_dtype(self):
        """KV cache dtype: 'int8' or 'fp8' when so quantized, otherwise `dtype`."""
        if self.quant_mode.has_int8_kv_cache():
            return 'int8'
        elif self.quant_mode.has_fp8_kv_cache():
            return 'fp8'
        else:
            return self.dtype

    def set_if_not_exist(self, key, value):
        """Set attribute `key` to `value` only if it is not already present."""
        if not hasattr(self, key):
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, config: dict):
        """Build a config from a dict, using the architecture's `config_class` if any."""
        # Maybe we need AutoConfig for this
        from . import MODEL_MAP
        model_cls = MODEL_MAP[config['architecture']]
        config_cls = getattr(model_cls, 'config_class', cls)
        return config_cls(**config)

    def to_dict(self):
        """Return a serializable deep copy of the config (mapping without 'rank')."""
        output = copy.deepcopy(self.__dict__)
        output['position_embedding_type'] = str(self.position_embedding_type)
        output['mapping'] = self.mapping.to_dict()
        output['mapping'].pop('rank')
        output['quantization'] = self.quantization.to_dict()
        return output

    @classmethod
    def from_json_file(cls, config_file: str):
        """Load a config from a JSON file."""
        with open(config_file) as f:
            config = json.load(f)
        return cls.from_dict(config)

    @classmethod
    def from_checkpoint(cls, ckpt_dir: str):
        """Load the config from 'config.json' inside `ckpt_dir`."""
        return cls.from_json_file(os.path.join(ckpt_dir, 'config.json'))

    def to_json_file(self, config_file: str):
        """Write `to_dict()` to `config_file` as indented JSON."""
        with open(config_file, 'w') as f:
            json.dump(self.to_dict(), f, indent=4)

    @property
    def quant_mode(self):
        """Shortcut for `self.quantization.quant_mode`."""
        return self.quantization.quant_mode

    def set_rank(self, rank):
        """Replace `self.mapping` with an equivalent Mapping for `rank`."""
        self.mapping = Mapping(self.mapping.world_size,
                               rank=rank,
                               tp_size=self.mapping.tp_size,
                               pp_size=self.mapping.pp_size,
                               moe_tp_size=self.mapping.moe_tp_size,
                               moe_ep_size=self.mapping.moe_ep_size,
                               gpus_per_node=self.mapping.gpus_per_node)
class DecoderLayerList(ModuleList):
    """The decoder layers assigned to this pipeline-parallel rank.

    `layer_list` holds the layer indices returned by
    `config.mapping.pp_layers(...)`; one `cls(config, idx)` module is built
    for each of them.
    """

    def __init__(self, cls, config):
        self.num_hidden_layers = config.num_hidden_layers
        self.layer_list = config.mapping.pp_layers(config.num_hidden_layers)
        super().__init__([cls(config, idx) for idx in self.layer_list])

    def forward(self,
                hidden_states,
                use_cache=False,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                position_ids=None,
                lora_params=None,
                spec_decoding_params=None):
        """Run the local layers in order.

        Returns:
            `hidden_states` when `use_cache` is False, otherwise a tuple
            `(hidden_states, presents)` where `presents` collects the second
            output of every layer.
        """
        num_local_layers = len(self.layer_list)
        kv_cache_params.fill_none_tensor_list(num_local_layers)

        if use_cache:
            presents = []

        for layer_idx, (layer, past) in enumerate(
                zip(self, kv_cache_params.past_key_value)):

            lora_layer_params = None
            if lora_params is not None and lora_params.lora_ranks is not None:
                lora_layer_params = lora_params.get_layer_params(layer_idx)

            # Only forward the optional arguments that were actually supplied.
            kwargs = {}
            if position_ids is not None:
                kwargs['position_ids'] = position_ids
            if lora_layer_params is not None:
                kwargs['lora_layer_params'] = lora_layer_params
            if spec_decoding_params is not None:
                kwargs['spec_decoding_params'] = spec_decoding_params

            if default_net().plugin_config.reduce_fusion:
                # `layer_idx` is a local position in this ModuleList, while
                # `self.layer_list` holds the layer ids from pp_layers().
                # Compare against the local count so the last local layer
                # never indexes past the end of `self`.
                if layer_idx < num_local_layers - 1:
                    next_layernorm = self[layer_idx + 1].input_layernorm
                    kwargs['next_layer_input_layernorm_args'] = (
                        next_layernorm.weight.value, next_layernorm.eps)
                else:
                    kwargs['next_layer_input_layernorm_args'] = None

            hidden_states = layer(
                hidden_states,
                use_cache=use_cache,
                attention_mask=attention_mask,
                kv_cache_params=KeyValueCacheParams(
                    past_key_value=[past],
                    host_past_key_value_lengths=kv_cache_params.
                    host_past_key_value_lengths,
                    host_max_attention_window_sizes=kv_cache_params.
                    host_max_attention_window_sizes,
                    host_sink_token_length=kv_cache_params.
                    host_sink_token_length,
                    kv_cache_block_offsets=kv_cache_params.
                    kv_cache_block_offsets,
                    host_kv_cache_block_offsets=kv_cache_params.
                    host_kv_cache_block_offsets,
                    host_kv_cache_pool_pointers=kv_cache_params.
                    host_kv_cache_pool_pointers,
                    cache_indirection=kv_cache_params.cache_indirection),
                attention_params=attention_params,
                **kwargs)

            if use_cache:
                # With caching enabled the layer returns (hidden_states, present).
                presents.append(hidden_states[1])
                hidden_states = hidden_states[0]

        if use_cache:
            return hidden_states, presents
        return hidden_states
class PostInitCaller(type):
    """Metaclass that calls `__post_init__()` on every newly built instance."""

    def __call__(cls, *args, **kwargs):
        # Normal construction first (__new__ + __init__), then the hook.
        instance = type.__call__(cls, *args, **kwargs)
        instance.__post_init__()
        return instance
class PretrainedModel(Module,
                      GenerationMixin,
                      TopModelMixin,
                      metaclass=PostInitCaller):
    """Base model class that owns a PretrainedConfig and handles checkpoints.

    Because of the PostInitCaller metaclass, `__post_init__` runs right after
    `__init__` and applies quantization and the embedding passes.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config

    def __post_init__(self):
        from ..quantization.quantize import quantize

        quantize(self, self.config.quantization)

        # Currently, use_parallel_embedding and share_embedding_table must be enabled before weight loading;
        # otherwise, the model will be inconsistent with the weights loaded from checkpoint.
        optimize_model(
            self,
            use_parallel_embedding=self.config.use_parallel_embedding,
            share_embedding_table=self.config.share_embedding_table,
        )

    def release(self):
        """Call release_gc()."""
        release_gc()

    def __del__(self):
        self.release()

    def check_config(self, config):
        """Hook for subclasses; the base implementation always raises."""
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        """Construct the model from `config` without loading weights."""
        return cls(config)

    @classmethod
    def from_checkpoint(cls,
                        ckpt_dir: str,
                        rank: Optional[int] = None,
                        config: Optional[PretrainedConfig] = None):
        """Build a model and load its weights from `ckpt_dir`.

        Args:
            ckpt_dir: Directory holding 'config.json' and 'rank*.safetensors'.
            rank: If given, `config.set_rank(rank)` is applied first.
            config: Optional pre-built config; read from 'config.json' if None.

        Returns:
            The constructed model with weights loaded.
        """
        if config is None:
            config = PretrainedConfig.from_json_file(
                os.path.join(ckpt_dir, 'config.json'))

        if rank is not None:
            config.set_rank(rank)

        if config.architecture in WEIGHT_LOADER_MODELS:
            weights_path = os.path.join(ckpt_dir, 'rank0.safetensors')
        else:
            rank = config.mapping.rank
            weights_path = os.path.join(ckpt_dir, f'rank{rank}.safetensors')

        assert os.path.isfile(
            weights_path), f"Checkpoint file not found: {weights_path}"
        weights = safetensors.torch.load_file(weights_path)

        is_checkpoint_pruned = getattr(config, 'is_pruned', False)
        preprocess_weights(weights, config, from_pruned=is_checkpoint_pruned)
        model = cls(config)
        model.load(weights, from_pruned=is_checkpoint_pruned)
        return model

    def load(self, weights, from_pruned=False):
        """Assign `weights` (name -> tensor) to the model parameters.

        Raises:
            RuntimeError: If a parameter that is not yet initialized has no
                entry in `weights`, or if assigning a value fails.
        """
        expected_names = set()
        required_names = set()
        for name, param in self.named_parameters():
            expected_names.add(name)
            if not param.is_inited():
                required_names.add(name)

        provided_names = set(weights.keys())
        if not required_names.issubset(provided_names):
            raise RuntimeError(
                f"Required but not provided tensors:{required_names.difference(provided_names)}"
            )
        if not provided_names.issubset(expected_names):
            logger.warning(
                f"Provided but not expected tensors: {provided_names.difference(expected_names)}"
            )

        if self.config.architecture in WEIGHT_LOADER_MODELS:
            mapping = self.config.mapping
            for name, param in self.named_parameters():
                if name not in provided_names:
                    continue
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                if from_pruned and param._shape != weights[name].shape:
                    # Shape mismatch in a pruned checkpoint: feed an
                    # uninitialized tensor of the parameter's own shape.
                    dummy_weight = torch.empty(param._shape,
                                               dtype=trt_dtype_to_torch(
                                                   param._dtype))
                    weight_loader(mapping, param, dummy_weight)
                else:
                    weight_loader(mapping, param, weights[name])
        else:
            for name, param in self.named_parameters():
                if name not in provided_names:
                    continue
                if not from_pruned:
                    try:
                        param.value = weights[name]
                    except Exception as e:
                        # Chain explicitly so the original failure is kept
                        # as the cause.
                        raise RuntimeError(
                            f"Encounter error '{e}' for parameter '{name}'"
                        ) from e
                else:
                    param.set_value_or_dummy(weights[name])

    def load_partial_weights(self, weights: dict):
        """Load the entries of `weights` that match a parameter name.

        Unmatched names are reported with a warning, but only when
        `pp_size == 1`.
        """
        params = dict(self.named_parameters())
        mapping = self.config.mapping

        for k, v in weights.items():
            if k in params:
                param = params[k]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(mapping, param, v)
            elif mapping.pp_size == 1:
                logger.warning(f"Provided but not expected tensors: {k}")

    def save_checkpoint(self, output_dir, save_config=True):
        """Write this rank's parameters to 'rank{rank}.safetensors'.

        'config.json' is also written unless `save_config` is False.
        """
        # multiple ranks could share same config.json, so adding a save_config parameter to let user avoiding writing config.json in all ranks
        rank = self.config.mapping.rank
        weights = {
            name: numpy_to_torch(param.raw_value)
            for name, param in self.named_parameters()
        }
        safetensors.torch.save_file(
            weights, os.path.join(output_dir, f'rank{rank}.safetensors'))
        if save_config:
            self.config.to_json_file(os.path.join(output_dir, 'config.json'))

    def prepare_inputs(self,
                       max_batch_size,
                       max_input_len,
                       max_seq_len,
                       max_num_tokens,
                       use_cache,
                       max_beam_width: int = 1,
                       opt_num_tokens: int = None,
                       prompt_embedding_table_size: int = 0,
                       position_encoding_2d: bool = False,
                       max_draft_len: int = 0,
                       speculative_decoding_draft_tokens_external: bool = False,
                       gather_context_logits: bool = False,
                       gather_generation_logits: bool = False,
                       lora_target_modules: List[str] = None,
                       opt_batch_size: int = 0):
        '''@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the
            ranges of the dimensions of when using TRT dynamic shapes.

            @return: a list contains values which can be fed into the self.forward()
        '''
        # NOTE(review): the `use_cache` argument is not consulted below; the
        # returned dict always sets 'use_cache' to True.

        # Prepare inputs
        plugin_config = default_net().plugin_config
        remove_input_padding = plugin_config.remove_input_padding
        use_gpt_attention_plugin = plugin_config.gpt_attention_plugin
        use_gemm_plugin = plugin_config.gemm_plugin
        paged_kv_cache = plugin_config.paged_kv_cache
        tokens_per_block = plugin_config.tokens_per_block
        use_custom_all_reduce = plugin_config.use_custom_all_reduce
        use_lora_plugin = plugin_config.lora_plugin
        multiple_profiles = plugin_config.multiple_profiles
        streamingllm = plugin_config.streamingllm

        model_inputs = self.prepare_basic_inputs(
            max_batch_size=max_batch_size,
            max_beam_width=max_beam_width,
            max_input_len=max_input_len,
            max_seq_len=max_seq_len,
            hidden_size=self.config.hidden_size,
            num_kv_heads=self.config.num_key_value_heads,
            head_size=self.config.head_size,
            num_layers=self.config.num_hidden_layers,
            kv_dtype=str_dtype_to_trt(self.config.kv_dtype),
            remove_input_padding=remove_input_padding,
            use_gpt_attention_plugin=use_gpt_attention_plugin,
            use_gemm_plugin=use_gemm_plugin,
            paged_kv_cache=paged_kv_cache,
            tokens_per_block=tokens_per_block,
            num_heads=self.config.num_attention_heads,
            max_num_tokens=max_num_tokens,
            opt_num_tokens=opt_num_tokens,
            dtype=str_dtype_to_trt(self.config.dtype),
            prompt_embedding_table_size=prompt_embedding_table_size,
            position_encoding_2d=position_encoding_2d,
            mapping=self.config.mapping,
            gather_context_logits=gather_context_logits,
            gather_generation_logits=gather_generation_logits,
            use_custom_all_reduce=use_custom_all_reduce,
            use_lora_plugin=use_lora_plugin,
            max_draft_len=max_draft_len,
            speculative_decoding_draft_tokens_external=
            speculative_decoding_draft_tokens_external,
            lora_target_modules=lora_target_modules,
            multiple_profiles=multiple_profiles,
            streamingllm=streamingllm,
            opt_batch_size=opt_batch_size)

        result = {
            'input_ids':
            model_inputs['input_ids'],
            'position_ids':
            model_inputs['position_ids'],
            'use_cache':
            True,
            'last_token_ids':
            model_inputs['last_token_ids'],
            'attention_mask':
            model_inputs['attention_mask'],
            'kv_cache_params':
            KeyValueCacheParams(
                past_key_value=model_inputs['past_key_value'],
                host_past_key_value_lengths=model_inputs[
                    'host_past_key_value_lengths'],
                host_max_attention_window_sizes=model_inputs[
                    'host_max_attention_window_sizes'],
                host_sink_token_length=model_inputs['host_sink_token_length'],
                kv_cache_block_offsets=model_inputs['kv_cache_block_offsets'],
                host_kv_cache_block_offsets=model_inputs[
                    'host_kv_cache_block_offsets'],
                host_kv_cache_pool_pointers=model_inputs[
                    'host_kv_cache_pool_pointers'],
                cache_indirection=model_inputs['cache_indirection'],
            ),
            'attention_params':
            AttentionParams(
                sequence_length=model_inputs['sequence_length'],
                context_lengths=model_inputs['context_lengths'],
                host_context_lengths=model_inputs['host_context_lengths'],
                max_context_length=max_input_len,
                host_request_types=model_inputs['host_request_types'])
        }

        if prompt_embedding_table_size > 0:
            result['prompt_embedding_table'] = model_inputs[
                'prompt_embedding_table']
            result['prompt_tasks'] = model_inputs['tasks']
            result['prompt_vocab_size'] = model_inputs['prompt_vocab_size']
        if model_inputs['hidden_states_input'] is not None:
            result['hidden_states'] = model_inputs['hidden_states_input']
        if use_lora_plugin:
            result['lora_params'] = LoraParams(
                model_inputs['lora_ranks'],
                model_inputs['lora_weights_pointers'],
                host_context_lengths=model_inputs['host_context_lengths'],
                max_context_length=max_input_len,
                host_request_types=model_inputs['host_request_types'])
        if model_inputs['spec_decoding_params'] is not None:
            result['spec_decoding_params'] = model_inputs[
                'spec_decoding_params']

        return result

    @classmethod
    def quantize(
        cls,
        hf_model_dir: str,
        output_dir: str,
        dtype: str = 'float16',
        mapping: Optional[Mapping] = None,
        quant_config: Optional[QuantConfig] = None,
        *,
        calib_dataset='cnn_dailymail',
        calib_batches=512,
        calib_batch_size=1,
        calib_max_seq_length=512,
        random_seed=1234,
        tokenizer_max_seq_length=2048,
    ):
        """Quantize the model in `hf_model_dir` and export it to `output_dir`.

        Delegates to `quantize_and_export`. `mapping` defaults to `Mapping()`
        and `quant_config` to `QuantConfig()`, whose qformat is 'full_prec'.

        Raises:
            NotImplementedError: If `mapping.moe_ep_size > 1`.
        """
        if mapping is None:  # single gpu
            mapping = Mapping()
        if quant_config is None:
            # The parameter defaults to None; previously that crashed with an
            # AttributeError on the first attribute access below.
            quant_config = QuantConfig()
        if mapping.moe_ep_size > 1:
            raise NotImplementedError(
                "Quantization for expert parallelism is not supported")
        modelopt_qformat = quant_config.quant_algo_to_modelopt_qformat()
        kv_cache_dtype = quant_config.kv_cache_quant_algo
        assert modelopt_qformat is not None
        from ..quantization import quantize_and_export
        hf_model_dir = str(
            hf_model_dir)  # quantize_and_export has some code can not take Path
        quantize_and_export(
            model_dir=hf_model_dir,
            device='cuda',
            calib_dataset=calib_dataset,
            dtype=dtype,
            qformat=modelopt_qformat,
            kv_cache_dtype=kv_cache_dtype,
            calib_size=calib_batches,
            batch_size=calib_batch_size,
            calib_max_seq_length=calib_max_seq_length,
            awq_block_size=quant_config.group_size,
            output_dir=output_dir,
            tp_size=mapping.tp_size,
            pp_size=mapping.pp_size,
            seed=random_seed,
            tokenizer_max_seq_length=tokenizer_max_seq_length,
        )
class DecoderModelForCausalLM(PretrainedModel):
    """Causal LM wrapper: a `transformer` backbone followed by an `lm_head`."""

    def __init__(self, config: PretrainedConfig, transformer, lm_head):
        super().__init__(config)
        self.transformer = transformer
        self.lm_head = lm_head
        self.mup_width_multiplier = getattr(config, 'mup_width_multiplier',
                                            None)

    def forward(self,
                input_ids: Tensor,
                position_ids=None,
                use_cache=False,
                last_token_ids=None,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                hidden_states=None,
                prompt_embedding_table: Optional[Tensor] = None,
                prompt_tasks: Optional[Tensor] = None,
                prompt_vocab_size: Optional[Tensor] = None,
                lora_params=None,
                spec_decoding_params=None):
        """Run the backbone and, on the last pipeline rank, the LM head.

        Marks 'logits' (last rank) or 'hidden_states_output' (other ranks) as
        outputs, plus 'present_key_value_{i}' when caching without paged KV
        cache. The returned tuple depends on `use_cache`, the paged-KV
        setting and whether this is the last pipeline rank.
        """
        kwargs = {
            'input_ids': input_ids,
            'position_ids': position_ids,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
            'kv_cache_params': kv_cache_params,
            'attention_params': attention_params,
        }
        # Optional inputs are forwarded only when supplied.
        optional_inputs = (
            ('lora_params', lora_params),
            ('hidden_states', hidden_states),
            ('prompt_embedding_table', prompt_embedding_table),
            ('prompt_tasks', prompt_tasks),
            ('prompt_vocab_size', prompt_vocab_size),
            ('spec_decoding_params', spec_decoding_params),
        )
        for key, value in optional_inputs:
            if value is not None:
                kwargs[key] = value

        hidden_states = self.transformer.forward(**kwargs)

        if use_cache:
            hidden_states, presents = hidden_states

        if self.config.mapping.is_last_pp_rank():
            hidden_states = gather_last_token_logits(
                hidden_states, last_token_ids,
                default_net().plugin_config.remove_input_padding)

            # [batch_size, hidden_size] -> [batch_size, vocab_size]
            lm_logits = self.lm_head(hidden_states)
            if hasattr(self.config, 'output_multiplier_scale'):
                lm_logits *= self.config.output_multiplier_scale
            if self.mup_width_multiplier is not None:
                lm_logits = lm_logits / self.mup_width_multiplier
            lm_logits.mark_output('logits', self.config.logits_dtype)
        else:
            hidden_states.mark_output('hidden_states_output', self.config.dtype)

        if use_cache and not default_net().plugin_config.paged_kv_cache:
            local_layers = self.config.mapping.pp_layers(
                self.config.num_hidden_layers)
            for i, present in zip(local_layers, presents):
                present.mark_output(f'present_key_value_{i}',
                                    self.config.kv_dtype)
            if self.config.mapping.is_last_pp_rank():
                return (lm_logits, presents, hidden_states)
            return (hidden_states, presents)

        if self.config.mapping.is_last_pp_rank():
            return lm_logits, hidden_states
        return hidden_states
def fuse_gate_mlp(
    model: PretrainedModel,
    gemm_swiglu_plugin_dtype: Optional[str] = None,
) -> PretrainedModel:
    """Replace every GatedMLP in `model` with a FusedGatedMLP, in place.

    The gate and fc weights are concatenated along axis 0 into `fused_fc`.
    Only unquantized models and FP8 are handled; for FP8 the two weights are
    dequantized, concatenated and re-quantized with the larger of the two
    scaling factors.

    Raises:
        ValueError: For any other quantization algorithm.
    """
    from ..quantization.quantize import fp8_quantize

    quant_algo = model.config.quantization.quant_algo

    for name, mlp, parent in model.named_modules_with_parent():
        if not isinstance(mlp, GatedMLP):
            continue

        init_params = get_init_params(mlp)
        init_params["inner_layernorm"] = mlp.inner_layernorm is not None
        fused_layer = FusedGatedMLP(**init_params)
        fused_fc = None

        if quant_algo == QuantAlgo.FP8:
            fused_layer = fp8_quantize(fused_layer, model.config.quantization)
            fused_fc = fused_layer.fused_fc

            if isinstance(mlp.dtype, str):
                dtype = str_dtype_to_torch(mlp.dtype)
            else:
                dtype = trt_dtype_to_torch(mlp.dtype)

            # dequantize
            gate_weight = numpy_to_torch(
                mlp.gate.weight.raw_value).to(dtype) * numpy_to_torch(
                    mlp.gate.weights_scaling_factor.raw_value)
            fc_weight = numpy_to_torch(
                mlp.fc.weight.raw_value).to(dtype) * numpy_to_torch(
                    mlp.fc.weights_scaling_factor.raw_value)

            # concat
            fused_weight = torch.cat([gate_weight, fc_weight], dim=0)

            # quantize
            fused_weight_scaling_factor = numpy_to_torch(
                max(
                    mlp.gate.weights_scaling_factor.raw_value,
                    mlp.fc.weights_scaling_factor.raw_value,
                ))
            fused_weight = (fused_weight / fused_weight_scaling_factor).to(
                torch.float8_e4m3fn)

            if gemm_swiglu_plugin_dtype == 'fp8':
                # gemm_swiglu_plugin needs (k, n) weights
                # but weights should still be k-major for fp8
                kn_shape = (fused_fc.in_features, fused_fc.out_features)
                fused_fc.weight = Parameter(shape=kn_shape, dtype='fp8')
                fused_fc.weight.value = fused_weight.view(
                    fused_fc.in_features, fused_fc.out_features)
            else:
                fused_fc.weight.value = fused_weight

            fused_fc.weights_scaling_factor.value = fused_weight_scaling_factor
            fused_fc.activation_scaling_factor.value = max(
                mlp.gate.activation_scaling_factor.raw_value,
                mlp.fc.activation_scaling_factor.raw_value,
            )
        elif quant_algo is None:
            fused_fc = fused_layer.fused_fc
            fused_fc.weight.value = np.concatenate(
                [
                    mlp.gate.weight.raw_value,
                    mlp.fc.weight.raw_value,
                ],
                axis=0,
            )
            if mlp.bias:
                fused_fc.bias.value = np.concatenate(
                    [mlp.gate.bias.raw_value, mlp.fc.bias.raw_value], axis=0)
        else:
            raise ValueError(f'Unsupported quant algo: {quant_algo}')

        # The projection and inner layernorm modules are reused as-is.
        fused_layer.proj = mlp.proj
        fused_layer.inner_layernorm = mlp.inner_layernorm

        mlp_name = name.rsplit('.', 1)[-1]
        setattr(parent, mlp_name, fused_layer)

    return model
def _split_qkv(fused, q_out, k_out, axis):
    """Split a fused QKV array into its Q, K and V parts along `axis`.

    The split points scale `fused.shape[axis]` by the share of Q and of Q+K
    in the total output features.
    """
    total_out = q_out + 2 * k_out  # K and V have the same output width
    length = fused.shape[axis]
    return np.split(fused, [
        length * q_out // total_out,
        length * (q_out + k_out) // total_out,
    ],
                    axis=axis)


def unfuse_qkv_gemm(model: PretrainedModel) -> PretrainedModel:
    '''Split all the models' Attention layer's QKV GEMM into 3 GEMMs layer.q layer.k, layer.v and return the changed model

    Cross-attention layers and layers whose `qkv` is already None are left
    untouched. Initialized weights and biases are split across the new
    GEMMs; every other parameter of `qkv` is shared by all three.
    '''
    from ..quantization.quantize import quantize

    for _, layer in model.named_modules():
        if not isinstance(layer, Attention) or layer.cross_attention:
            continue
        assert layer.tp_size == 1, "please disable manual tp when enable auto parallel"
        if layer.qkv is None:
            continue

        qkv_params = get_init_params(layer.qkv, ColumnLinear)
        qkv_params["bias"] = qkv_params["bias"] is not None
        qkv_params["strict_dtype"] = qkv_params.get(
            "strict_dtype") is not None

        q_features = (layer.tp_size * layer.num_attention_heads *
                      layer.attention_head_size)
        kv_features = (layer.tp_size * layer.num_attention_kv_heads *
                       layer.attention_head_size)
        q = ColumnLinear(**{**qkv_params, "out_features": q_features})
        k = ColumnLinear(**{**qkv_params, "out_features": kv_features})
        v = ColumnLinear(**{**qkv_params, "out_features": kv_features})
        q = quantize(q, model.config.quantization)
        k = quantize(k, model.config.quantization)
        v = quantize(v, model.config.quantization)

        # Weight-only quantized layers are split along axis 1, all other
        # layers along axis 0.
        if isinstance(layer.qkv, (
                WeightOnlyQuantLinear,
                WeightOnlyQuantRowLinear,
                WeightOnlyGroupwiseQuantLinear,
                WeightOnlyGroupwiseQuantRowLinear,
        )):
            out_dim = 1
        else:
            out_dim = 0

        if layer.qkv.weight.is_inited():
            weights = _split_qkv(layer.qkv.weight.raw_value, q.out_features,
                                 k.out_features, out_dim)
            for gemm, weight in zip([q, k, v], weights):
                gemm.weight.value = weight

        if layer.qkv.bias is not None and layer.qkv.bias.is_inited():
            qkv_bias = layer.qkv.bias.raw_value
            # Clamp the axis to the bias rank: a 1-D bias combined with
            # out_dim == 1 used to raise IndexError on `shape[1]`.
            bias_axis = min(out_dim, qkv_bias.ndim - 1)
            biases = _split_qkv(qkv_bias, q.out_features, k.out_features,
                                bias_axis)
            for gemm, bias in zip([q, k, v], biases):
                gemm.bias.value = bias

        # Remaining parameters are attached unchanged to each new GEMM.
        for param_name, parameter in layer.qkv._parameters.items():
            if param_name not in ["weight", "bias"]:
                for gemm in [q, k, v]:
                    setattr(gemm, param_name, parameter)

        layer.q = q
        layer.k = k
        layer.v = v
        layer.qkv = None

    return model
def fuse_rg_lru(model: PretrainedModel) -> PretrainedModel:
    """Replace every RgLru in `model` with a FusedRgLru, in place.

    The input-gate and recurrent-gate weights and biases are concatenated
    along the last axis into the fused `gate`; `recurrent_param` is copied.
    """
    for name, rg_lru, parent in model.named_modules_with_parent():
        if not isinstance(rg_lru, RgLru):
            continue

        fused_layer = FusedRgLru(**get_init_params(rg_lru))

        gate_weights = [
            rg_lru.input_gate.weight.raw_value,
            rg_lru.recurrent_gate.weight.raw_value,
        ]
        fused_layer.gate.weight.value = np.concatenate(gate_weights, axis=-1)

        gate_biases = [
            rg_lru.input_gate.bias.raw_value,
            rg_lru.recurrent_gate.bias.raw_value,
        ]
        fused_layer.gate.bias.value = np.concatenate(gate_biases, axis=-1)

        fused_layer.recurrent_param.value = rg_lru.recurrent_param.raw_value

        rg_lru_name = name.rsplit('.', 1)[-1]
        setattr(parent, rg_lru_name, fused_layer)

    return model
def set_prompt_tuning(model: PretrainedModel) -> PretrainedModel:
    '''Replace the given models embedding layer with a PromptTuningEmbedding layer in-place, return the changed model
        Pre-conditions: vocab_embedding exists
        Post-conditions: isinstance(vocab_embedding, PromptTuningEmbedding)
    '''
    for name, embedding, parent in model.named_modules_with_parent():
        is_vocab_embedding = name.rsplit('.', 1)[-1] == "vocab_embedding"
        if not (is_vocab_embedding and isinstance(embedding, Embedding)):
            continue
        # Rebuild with the same constructor arguments and carry the weight over.
        init_params = get_init_params(embedding)
        ptuning_embedding = PromptTuningEmbedding(**init_params)
        ptuning_embedding.weight.value = embedding.weight.raw_value
        parent.vocab_embedding = ptuning_embedding
    return model
def add_lora(model: PretrainedModel,
             max_lora_rank: Optional[int]) -> PretrainedModel:
    ''' Add lora layers to the Attention/BertAttention/Linear/RowLinear/FusedGatedMLP layers to the given model, return the changed model

    When `max_lora_rank` is None, the rank used for a layer is the smallest
    of that layer's input and output sizes.
    '''
    for _, layer in model.named_modules():
        max_rank = max_lora_rank

        if isinstance(layer, (Attention, BertAttention)):
            q_size = layer.num_attention_heads * layer.attention_head_size
            kv_size = layer.num_attention_kv_heads * layer.attention_head_size
            if max_rank is None:
                max_rank = min(layer.hidden_size, q_size, kv_size)
            layer.qkv_lora = Lora(
                in_hidden_size=layer.hidden_size,
                out_hidden_sizes=[q_size, kv_size, kv_size],
                max_low_rank=max_rank,
            )

        if isinstance(layer, (Linear, RowLinear)):
            if max_rank is None:
                max_rank = min(layer.in_features, layer.out_features)
            layer.lora = Lora(
                in_hidden_size=layer.in_features,
                out_hidden_sizes=[layer.out_features],
                max_low_rank=max_rank,
            )

        if isinstance(layer, FusedGatedMLP):
            ffn_per_rank = layer.ffn_hidden_size // layer.tp_size
            if max_rank is None:
                max_rank = min(layer.hidden_size, ffn_per_rank)
            layer.lora = Lora(
                in_hidden_size=layer.hidden_size,
                out_hidden_sizes=[ffn_per_rank, ffn_per_rank],
                max_low_rank=max_rank,
            )

    return model
def to_ootb_moe(model: PretrainedModel) -> PretrainedModel:
    ''' Use OOTB MoE instead of MoE plugin, return the changed model
    '''
    for name, layer, parent in model.named_modules_with_parent():
        if not isinstance(layer, MOE):
            continue
        # Convert via MOE.to() and swap the result in under the same name.
        converted = layer.to(MoeOOTB, model.config.quantization)
        setattr(parent, name.rsplit('.', 1)[-1], converted)
    return model
def parallelize_embedding(model: PretrainedModel) -> PretrainedModel:
    """Rebuild every Embedding that has no `tp_group` with the model's TP settings.

    The replacement is constructed from the original's init params plus
    `tp_group`, `tp_size`, `tp_rank` and `sharding_dim` taken from
    `model.config`. No weights are copied here.
    """
    for name, embedding, parent in model.named_modules_with_parent():
        layer_name = name.rsplit('.', 1)[-1]
        if not isinstance(embedding, Embedding):
            continue
        if embedding.tp_group is not None:
            continue

        mapping = model.config.mapping
        init_params = get_init_params(embedding)
        init_params["tp_group"] = mapping.tp_group
        init_params["tp_size"] = mapping.tp_size
        init_params["tp_rank"] = mapping.tp_rank
        init_params["sharding_dim"] = model.config.embedding_sharding_dim

        replacement = embedding.__class__(**init_params)
        setattr(parent, layer_name, replacement)
    return model
def share_embedding(model: PretrainedModel) -> PretrainedModel:
    """Make `lm_head` reuse the weight of `vocab_embedding`, if both exist.

    The first modules whose leaf names are 'lm_head' and 'vocab_embedding'
    are used. A non-None `per_token_scale` on the embedding is also assigned
    to `lm_head.per_channel_scale`.
    """
    lm_head = None
    vocab_embedding = None
    for name, layer in model.named_modules():
        leaf_name = name.rsplit('.', 1)[-1]
        if leaf_name == "lm_head":
            lm_head = layer
        if leaf_name == "vocab_embedding":
            vocab_embedding = layer
        if lm_head is not None and vocab_embedding is not None:
            break

    if lm_head is None or vocab_embedding is None:
        return model

    lm_head.weight = vocab_embedding.weight
    has_scale = hasattr(vocab_embedding, "per_token_scale")
    if has_scale and vocab_embedding.per_token_scale is not None:
        lm_head.per_channel_scale = vocab_embedding.per_token_scale
    return model
def set_fp8_context_fhma(model: PretrainedModel) -> PretrainedModel:
    """Attach `attention_output_orig_quant_scale` to every Attention layer.

    The value is the reciprocal of the dense layer's activation scaling
    factor, stored as float32.
    """
    for _, layer in model.named_modules():
        if not isinstance(layer, Attention):
            continue
        activation_scale = layer.dense.activation_scaling_factor.raw_value
        scale = [1.0] / activation_scale
        layer.attention_output_orig_quant_scale = Parameter(
            value=scale.astype(np.float32))
    return model
def optimize_model(
    model: PretrainedModel,
    use_parallel_embedding: bool = False,
    share_embedding_table: bool = False,
    use_ootb_moe: bool = False,
    use_fused_mlp: bool = False,
    gemm_swiglu_plugin_dtype: Optional[str] = None,
    use_fused_rg_lru: bool = False,
    use_unfused_qkv_gemm: bool = False,
    use_prompt_tuning: bool = False,
    use_lora: bool = False,
    max_lora_rank: Optional[int] = None,
    use_fp8_context_fmha: bool = False,
) -> PretrainedModel:
    """
    Apply the enabled optimization passes to the model and return the result.

    Some passes depend on others, so the passes are always executed in the
    order in which their flags appear in the signature. Each `use_*` /
    `share_embedding_table` flag enables one pass; `gemm_swiglu_plugin_dtype`
    is forwarded to the fused-MLP pass and `max_lora_rank` to the LoRA pass.
    """
    # (flag, pass) pairs in execution order. The lambdas keep the lookup of
    # each pass function lazy: it is only resolved when the pass actually runs.
    ordered_passes = (
        # before weight loading
        (use_parallel_embedding, lambda m: parallelize_embedding(m)),
        (share_embedding_table, lambda m: share_embedding(m)),
        # After weight loading
        (use_ootb_moe, lambda m: to_ootb_moe(m)),
        (use_fused_mlp, lambda m: fuse_gate_mlp(m, gemm_swiglu_plugin_dtype)),
        (use_fused_rg_lru, lambda m: fuse_rg_lru(m)),
        (use_unfused_qkv_gemm, lambda m: unfuse_qkv_gemm(m)),
        (use_prompt_tuning, lambda m: set_prompt_tuning(m)),
        (use_lora, lambda m: add_lora(m, max_lora_rank)),
        (use_fp8_context_fmha, lambda m: set_fp8_context_fhma(m)),
    )
    for enabled, apply_pass in ordered_passes:
        if enabled:
            model = apply_pass(model)
    return model
def preprocess_weights(weights: Dict[str, torch.Tensor],
                       model_config: PretrainedConfig,
                       from_pruned=False) -> None:
    """This function in-place modifies weights and model_config, making them compatible with each other.

    Note: Typically, it should be called before model creation and weight loading. For example,
        preprocess_weights(weights, model_config)
        model = XXXForCausalLM(model_config)
        model.load(weights)

    Processing steps, selected by `model_config.quantization`:
      * W4A8_AWQ / W4A16_AWQ: int8-stored `*weight` tensors are run through
        `preprocess_weights_for_mixed_gemm`, scaling factors are transposed or
        reshaped, and on `tp_rank > 0` the `attention.dense.bias` /
        `mlp.proj.bias` tensors are zeroed. For W4A8_AWQ the activation and
        second-level weight scaling factors are additionally folded into
        `weights_scaling_factor`, `prequant_scaling_factor` and a new `alpha`
        entry.
      * FP8: int8-stored `*weight` tensors are reinterpreted as
        `float8_e4m3fn`; the `lm_head` scaling factors are dropped.
      * W4A16 / W8A16: delegated to `weight_only_quantize_dict`.
      * FP8 KV cache: every `*kv_cache_scaling_factor` is set to 1.0.
      * GPT-J: on `tp_rank > 0` the row-linear biases are zeroed.
    Finally `check_share_embedding` is called on the result.

    Args:
        weights: Mapping from weight name to tensor; modified in place.
        model_config: Model configuration; may be modified by
            `check_share_embedding`.
        from_pruned: When true, empty tensors (`numel() == 0`) are skipped in
            the AWQ pass.
    """
    quant_algo = model_config.quantization.quant_algo
    kv_cache_quant_algo = model_config.quantization.kv_cache_quant_algo

    # INT4_AWQ
    if quant_algo == QuantAlgo.W4A8_AWQ or quant_algo == QuantAlgo.W4A16_AWQ:
        preprocessor = torch.ops.trtllm.preprocess_weights_for_mixed_gemm
        if quant_algo == QuantAlgo.W4A8_AWQ:
            activation_type = torch.float8_e4m3fn
        elif quant_algo == QuantAlgo.W4A16_AWQ:
            activation_type = torch.float16
        # Only existing keys are reassigned here, so iterating the live
        # `items()` view is safe (the dict size never changes).
        for name, param in weights.items():
            if from_pruned and param.numel() == 0:
                continue
            if name.endswith('weight') and param.dtype == torch.int8:
                dtype = torch.float16
                if model_config.dtype == "bfloat16":
                    dtype = torch.bfloat16
                weights[name] = preprocessor(param.T.contiguous(),
                                             torch.quint4x2,
                                             activation_type).view(dtype)
            if name.endswith('weights_scaling_factor'):
                weights[name] = param.T.contiguous().to(
                    str_dtype_to_torch(model_config.dtype))
            if name.endswith('prequant_scaling_factor'):
                weights[name] = param.reshape(1, -1)
            if model_config.mapping.tp_rank > 0:
                if name.endswith(('attention.dense.bias', 'mlp.proj.bias')):
                    weights[name] = torch.zeros_like(param)
        if quant_algo == QuantAlgo.W4A8_AWQ:
            # Keys are popped and added below, so iterate over a snapshot.
            for name in list(weights):
                if name.endswith('weights_scaling_factor'):
                    activation_scaling_factor = weights.pop(
                        name.replace('weights_scaling_factor',
                                     'activation_scaling_factor'))
                    weights_scaling_factor_2 = weights.pop(
                        name.replace('weights_scaling_factor',
                                     'weights_scaling_factor_2'))
                    weights[name] /= weights_scaling_factor_2
                    weights[name.replace(
                        'weights_scaling_factor',
                        'prequant_scaling_factor')] /= activation_scaling_factor
                    weights[name.replace(
                        'weights_scaling_factor', 'alpha'
                    )] = activation_scaling_factor * weights_scaling_factor_2
    # FP8
    elif quant_algo == QuantAlgo.FP8:
        for name, param in weights.items():
            if name.endswith('weight') and param.dtype == torch.int8:
                weights[name] = param.view(torch.float8_e4m3fn)
        # lm_head is not quantized to FP8
        if "lm_head.weight" in weights:
            assert weights['lm_head.weight'].dtype == str_dtype_to_torch(
                model_config.dtype)
            weights.pop('lm_head.weights_scaling_factor', None)
            weights.pop('lm_head.activation_scaling_factor', None)
    elif quant_algo in [QuantAlgo.W4A16, QuantAlgo.W8A16]:
        # NOTE(review): the result is only rebound locally; callers see the
        # quantized tensors only if the helper mutates `weights` in place --
        # confirm against `weight_only_quantize_dict`.
        weights = weight_only_quantize_dict(weights=weights,
                                            quant_algo=quant_algo,
                                            plugin=True)

    # FP8 kv_cache_scaling_factor is always 1.0
    if kv_cache_quant_algo == QuantAlgo.FP8:
        for name, param in weights.items():
            if name.endswith('kv_cache_scaling_factor'):
                weights[name] = torch.tensor([1.0], dtype=torch.float32)

    # Parallel block rowlinear should not have duplicate bias.
    # This must be an independent `if`: as an `elif` of the KV-cache check
    # above, the GPT-J bias fix-up was skipped whenever the KV cache was FP8.
    if model_config.architecture == 'GPTJForCausalLM':
        if model_config.mapping.tp_rank > 0:
            for name, param in weights.items():
                if 'attention.dense.bias' in name or 'mlp.proj.bias' in name:
                    weights[name] = torch.zeros_like(param)

    # For share_embedding_table
    check_share_embedding(weights, model_config)
def check_share_embedding(weights: Dict[str, torch.Tensor],
                          model_config: PretrainedConfig):
    """Reconcile `share_embedding_table` with the weights actually provided.

    Nothing happens unless `model_config.share_embedding_table` is truthy and
    both `lm_head.weight` and `transformer.vocab_embedding.weight` are present
    in `weights`. If the two tensors differ in any element, a warning is
    logged and `model_config.share_embedding_table` is set to False;
    otherwise the redundant `lm_head.weight` entry is removed from `weights`.
    """
    if not model_config.share_embedding_table:
        return

    head_key = "lm_head.weight"
    embedding_key = "transformer.vocab_embedding.weight"
    if head_key not in weights or embedding_key not in weights:
        return

    # Any non-zero element in the difference means the tensors are not equal.
    difference = weights[head_key] - weights[embedding_key]
    if not difference.any():
        weights.pop(head_key)
        return

    logger.warning(
        "lm_head.weight and transformer.vocab_embedding.weight are not identical, "
        "share_embedding_table cannot be enabled; setting share_embedding_table=False."
    )
    model_config.share_embedding_table = False