TensorRT-LLMs/tensorrt_llm/models/convert_utils.py
Kaiyu Xie 9dbc5b38ba
Update TensorRT-LLM (#1891)
* Update TensorRT-LLM

---------

Co-authored-by: Marks101 <markus.schnoes@gmx.de>
Co-authored-by: lkm2835 <lkm2835@gmail.com>
2024-07-04 14:37:19 +08:00

241 lines
8.0 KiB
Python

import re
from pathlib import Path
from typing import Dict, Optional, Union
import torch
from datasets import load_dataset
from ..quantization import QuantAlgo
def split(v, tp_size, idx, dim=0):
if tp_size == 1:
return v
if len(v.shape) == 1:
return torch.chunk(v, tp_size)[idx].contiguous()
else:
return torch.chunk(v, tp_size, dim=dim)[idx].clone()
def split_qkv_tp(v, n_head, n_hidden, tensor_parallel, rank):
"""
Splits the QKV matrix according to tensor parallelism
"""
v = v.reshape(3, n_hidden, n_hidden)
split_v = split(v, tensor_parallel, rank, dim=1)
split_v = split_v.reshape(3 * (n_hidden // tensor_parallel), n_hidden)
return split_v.clone()
def split_qkv_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
"""
Splits the QKV bias according to tensor parallelism
"""
v = v.reshape(3, n_hidden)
split_v = split(v, tensor_parallel, rank, dim=1)
split_v = split_v.reshape(3 * (n_hidden // tensor_parallel))
return split_v.clone()
def split_matrix_tp(v, tensor_parallel, rank, dim):
return split(v, tensor_parallel, rank, dim=dim)
def weight_only_quantize(weight: torch.Tensor,
quant_algo: str,
plugin: bool = True):
assert quant_algo in [QuantAlgo.W4A16, QuantAlgo.W8A16
], f'unsupported quant algo: {quant_algo}'
if quant_algo == QuantAlgo.W4A16:
assert plugin, 'W4A16 is only supported with plugin'
if weight.dim() > 2:
v = weight.transpose(-1, -2)
else:
v = weight.t()
t = torch.quint4x2 if quant_algo == QuantAlgo.W4A16 else torch.int8
processed_torch_weights, torch_weight_scales = \
torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
v.contiguous(), t)
if plugin:
return processed_torch_weights, torch_weight_scales
else:
return v, torch_weight_scales
def weight_only_quantize_dict(weights: Dict[str, torch.Tensor],
quant_algo: str,
quant_weights=[
'qkv.weight', 'dense.weight', 'fc.weight',
'proj.weight', 'gate.weight'
],
exclude_weights=['shared_expert_gate.weight'],
plugin: bool = True):
if quant_algo not in [QuantAlgo.W4A16, QuantAlgo.W8A16]:
return weights
for name in list(weights):
if any([_name in name for _name in exclude_weights]):
continue
if any([_name in name for _name in quant_weights
]) and weights[name].dtype != torch.int8:
quant_weight, quant_scale = weight_only_quantize(
weight=weights[name], quant_algo=quant_algo, plugin=plugin)
weights[name] = quant_weight
weights[name.replace('.weight', '.per_channel_scale')] = quant_scale
return weights
def load_state_dict(
file_path: Union[str, Path],
dtype: Optional[torch.dtype] = None,
device: Optional[Union[str, torch.device]] = None,
) -> Dict[str, torch.Tensor]:
""" Load weights from model file.
`safetensors` or `pytorch binary` is supported.
Args:
file_path: model file path, ends with .bin or .safetensors.
dtype: torch.dtype, data type.
device: torch device like, optional. If None, load to cpu.
Returns:
Weights as state dict.
"""
file_path = Path(file_path)
if dtype is not None:
assert isinstance(dtype, torch.dtype)
if device is None:
device = 'cpu'
model_params = {}
if file_path.suffix == '.safetensors':
# load from safetensors file
from safetensors import safe_open
with safe_open(file_path, framework='pt', device=device) as f:
for name in f.keys():
tensor = f.get_tensor(name)
if dtype is not None:
tensor = tensor.to(dtype)
model_params[name] = tensor
elif file_path.suffix == '.bin':
# load from pytorch bin file
state_dict = torch.load(file_path, map_location=device)
for name in state_dict:
tensor = state_dict[name]
if dtype is not None:
tensor = tensor.to(dtype)
model_params[name] = tensor
else:
raise NotImplementedError(
f'Support .safetensors or .bin files, but got {str(file_path)}')
return model_params
def get_model_path(
model_dir: Union[str, Path],
name: Optional[str] = None,
) -> Optional[str]:
""" Get model path from model directory.
`safetensors` or `pytorch binary` is supported.
Args:
model_dir: model directory.
name: model file name without suffix.
Returns:
Full model path.
"""
model_dir = Path(model_dir)
if name is not None:
if (model_dir / f"{name}.safetensors").exists():
return str(model_dir / f"{name}.safetensors")
elif (model_dir / f"{name}.bin").exists():
return str(model_dir / f"{name}.bin")
else:
return None
else:
model_files = list(model_dir.glob('*.safetensors'))
if len(model_files) > 0:
assert len(
model_files
) == 1, f"find multiple safetensors files in {model_dir}, please specify one"
return str(model_files[0])
model_files = list(model_dir.glob('*.bin'))
if len(model_files) > 0:
assert len(
model_files
) == 1, f"find multiple bin files in {model_dir}, please specify one"
return str(model_files[0])
return None
def retrieved_layer_index_from_name(name: str) -> Optional[int]:
# This method is a hacky function to retrieve the layer index from
# HF model. Most of HF models have similar naming convention but
# please check carefully before applying if this method works well
# on your target model.
res = re.search(r'\d+', name)
return int(res.group()) if res is not None else res
def iterate_shard_files(model_dir: Union[Path, str],
rank: int,
progress_bar: bool = True):
model_dir = Path(model_dir)
# '.bin' or '.safetensors'. In case that both exist, '.safetensor'
# files will be loaded first.
shard_files = list(model_dir.glob('*.safetensors'))
if not shard_files:
# The model checkpoint is stored in .bin file.
shard_files = list(model_dir.glob('*.bin'))
if not shard_files:
raise RuntimeError(
f"Could not find any .safetensors or .bin files in {model_dir}")
try:
import tqdm
if progress_bar:
# Show a progress bar per rank.
desc = f'Rank [{rank}] Loading weights'
shard_files = tqdm.tqdm(shard_files, desc=desc, position=rank)
except ImportError:
pass
for shard_file in shard_files:
yield shard_file
def has_safetensors(model_dir: str):
return len(list(Path(model_dir).glob('*.safetensors'))) > 0
DEFAULT_HF_DATASET_META = {
'ccdv/cnn_dailymail': ('3.0.0', 'train', 'article'),
'cnn_dailymail': ('3.0.0', 'train', 'article'),
'lambada': (None, 'validation', 'text'),
}
def load_calib_dataset(dataset_name_or_dir: str,
config_name: Optional[str] = None,
split: Optional[str] = None,
key: Optional[str] = None,
trust_remote_code=True,
**kwargs):
if config_name is None:
for name, meta in DEFAULT_HF_DATASET_META.items():
if name in dataset_name_or_dir:
if config_name is None:
config_name = meta[0]
if split is None:
split = meta[1]
if key is None:
key = meta[2]
break
dataset = load_dataset(dataset_name_or_dir,
name=config_name,
split=split,
**kwargs)
return dataset[key]