import fnmatch
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
from datasets import load_dataset

from .._utils import torch_dtype_to_str
from ..logger import logger
from ..quantization import QuantAlgo


def infer_dtype(dtype: str,
                source_dtype: Optional[Union[str, torch.dtype]] = None) -> str:
    if dtype == 'auto':
        if source_dtype is None:
            dtype = 'float16'
        elif isinstance(source_dtype, str):
            dtype = source_dtype
        elif isinstance(source_dtype, torch.dtype):
            dtype = torch_dtype_to_str(source_dtype)
        if dtype == 'float32':
            dtype = 'float16'
        logger.info(f"Specified dtype 'auto'; inferred dtype {dtype!r}.")
        return dtype
    elif dtype in ('float16', 'fp16'):
        return 'float16'
    elif dtype in ('bfloat16', 'bf16'):
        return 'bfloat16'
    elif dtype in ('float32', 'fp32'):
        return 'float32'
    else:
        raise ValueError(f"Unexpected dtype value {dtype}.")


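# Illustrative usage sketch, added for documentation only (not part of the
# upstream module API). It shows that 'auto' never resolves to float32: a
# float32 source dtype is intentionally downgraded to float16, matching the
# logic above. The function name is hypothetical.
def _example_infer_dtype():
    assert infer_dtype('bf16') == 'bfloat16'
    assert infer_dtype('auto', torch.bfloat16) == 'bfloat16'
    # float32 sources are mapped to float16 under 'auto'.
    assert infer_dtype('auto', torch.float32) == 'float16'

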
def split(v, tp_size, idx, dim=0):
    if tp_size == 1:
        return v
    if len(v.shape) == 1:
        return torch.chunk(v, tp_size)[idx].contiguous()
    else:
        return torch.chunk(v, tp_size, dim=dim)[idx].contiguous()


def split_qkv_tp(v, n_head, n_hidden, tensor_parallel, rank):
    """
    Splits the QKV matrix according to tensor parallelism
    """
    v = v.reshape(3, n_hidden, n_hidden)
    split_v = split(v, tensor_parallel, rank, dim=1)
    split_v = split_v.reshape(3 * (n_hidden // tensor_parallel), n_hidden)
    return split_v.contiguous()


def split_qkv_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
    """
    Splits the QKV bias according to tensor parallelism
    """
    v = v.reshape(3, n_hidden)
    split_v = split(v, tensor_parallel, rank, dim=1)
    split_v = split_v.reshape(3 * (n_hidden // tensor_parallel))
    return split_v.contiguous()


def split_matrix_tp(v, tensor_parallel, rank, dim):
    return split(v, tensor_parallel, rank, dim=dim)


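# Illustrative sketch of tensor-parallel QKV splitting, added for
# documentation only (the function name and sizes are made up). A fused QKV
# weight of shape (3 * n_hidden, n_hidden) is split so that each rank keeps a
# (3 * n_hidden / tp_size, n_hidden) slice, with Q, K and V sliced
# consistently along the output dimension.
def _example_split_qkv_tp():
    n_head, n_hidden, tp_size = 8, 64, 2
    qkv_weight = torch.randn(3 * n_hidden, n_hidden)
    qkv_bias = torch.randn(3 * n_hidden)
    for rank in range(tp_size):
        w = split_qkv_tp(qkv_weight, n_head, n_hidden, tp_size, rank)
        b = split_qkv_bias_tp(qkv_bias, n_head, n_hidden, tp_size, rank)
        assert w.shape == (3 * n_hidden // tp_size, n_hidden)
        assert b.shape == (3 * n_hidden // tp_size, )

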
def weight_only_quantize(weight: torch.Tensor,
                         quant_algo: str,
                         plugin: bool = True):
    assert quant_algo in [QuantAlgo.W4A16, QuantAlgo.W8A16
                          ], f'unsupported quant algo: {quant_algo}'
    if quant_algo == QuantAlgo.W4A16:
        assert plugin, 'W4A16 is only supported with plugin'
    if weight.dim() > 2:
        v = weight.transpose(-1, -2)
    else:
        v = weight.t()
    t = torch.quint4x2 if quant_algo == QuantAlgo.W4A16 else torch.int8
    processed_torch_weights, torch_weight_scales = \
        torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
            v.contiguous(), t)
    if plugin:
        return processed_torch_weights, torch_weight_scales
    else:
        return v, torch_weight_scales


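# Illustrative sketch, added for documentation only. The shapes and the
# function name are made up, and running it requires TensorRT-LLM's compiled
# custom ops, since weight_only_quantize() dispatches to
# torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix.
def _example_weight_only_quantize():
    # W8A16: an fp16 weight is packed into int8 plus per-channel scales that
    # the weight-only GEMM plugin consumes at runtime.
    weight = torch.randn(128, 256, dtype=torch.float16)
    qweight, scales = weight_only_quantize(weight, QuantAlgo.W8A16)
    return qweight, scales

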
def get_weight(params: Dict[str, torch.Tensor], prefix: str,
               dtype: torch.dtype) -> torch.Tensor:
    if f'{prefix}.weight' not in params:
        return None
    return params[f'{prefix}.weight'].to(dtype).detach().cpu().contiguous()


def get_bias(params: Dict[str, torch.Tensor], prefix: str,
             dtype: torch.dtype) -> torch.Tensor:
    if f'{prefix}.bias' not in params:
        return None
    return params[f'{prefix}.bias'].to(dtype).detach().cpu().contiguous()


def get_weight_and_bias(params: Dict[str, torch.Tensor], prefix: str,
                        dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
    return get_weight(params, prefix, dtype), get_bias(params, prefix, dtype)


def dup_kv_weight(v, num_head, tp_size):
    assert tp_size % num_head == 0
    reps = tp_size // num_head
    head_size = v.shape[0] // num_head
    v = v.reshape(num_head, head_size,
                  -1)[:, None, :, :].expand(num_head, reps, head_size,
                                            v.shape[1])
    return v.reshape(num_head * reps * head_size, -1).clone().detach()


def dup_kv_bias(v, num_head, tp_size):
    assert tp_size % num_head == 0
    reps = tp_size // num_head
    head_size = v.shape[0] // num_head
    v = v.reshape(num_head,
                  head_size)[:, None, :].expand(num_head, reps, head_size)
    return v.reshape(num_head * reps * head_size).clone().detach()


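# Illustrative sketch, added for documentation only (function name and sizes
# are made up). When the tensor-parallel size exceeds the number of KV heads
# (GQA/MQA checkpoints), each KV head is replicated so that every rank still
# receives a full head after the usual TP split.
def _example_dup_kv_weight():
    num_kv_heads, tp_size, head_size, hidden = 2, 4, 16, 64
    kv_weight = torch.randn(num_kv_heads * head_size, hidden)
    kv_bias = torch.randn(num_kv_heads * head_size)
    dup_w = dup_kv_weight(kv_weight, num_kv_heads, tp_size)
    dup_b = dup_kv_bias(kv_bias, num_kv_heads, tp_size)
    # Each head is repeated tp_size // num_kv_heads times.
    assert dup_w.shape == (tp_size * head_size, hidden)
    assert dup_b.shape == (tp_size * head_size, )

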
def weight_only_quantize_dict(weights: Dict[str, torch.Tensor],
                              quant_algo: str,
                              quant_weights=[
                                  'qkv.weight', 'dense.weight', 'fc.weight',
                                  'proj.weight', 'gate.weight'
                              ],
                              exclude_modules=None,
                              plugin: bool = True):
    if quant_algo not in [QuantAlgo.W4A16, QuantAlgo.W8A16]:
        return weights
    if exclude_modules is None:
        exclude_modules = ['*shared_expert_gate.weight']
    for name in list(weights):
        is_excluded = False
        for exclude_module in exclude_modules:
            if fnmatch.fnmatchcase(name, exclude_module):
                is_excluded = True
                break
        if not is_excluded and any([_name in name for _name in quant_weights
                                    ]) and weights[name].dtype != torch.int8:
            quant_weight, quant_scale = weight_only_quantize(
                weight=weights[name], quant_algo=quant_algo, plugin=plugin)
            weights[name] = quant_weight
            weights[name.replace('.weight', '.per_channel_scale')] = quant_scale
    return weights


def load_state_dict(
    file_path: Union[str, Path],
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[str, torch.device]] = None,
) -> Dict[str, torch.Tensor]:
    """ Load weights from a model file.

    Both `safetensors` and PyTorch binary (`.bin`) checkpoints are supported.
    Args:
        file_path: model file path, ending with .bin or .safetensors.
        dtype: optional torch.dtype to cast the tensors to.
        device: optional torch device to load onto. If None, load to cpu.
    Returns:
        Weights as a state dict.
    """
    file_path = Path(file_path)
    if dtype is not None:
        assert isinstance(dtype, torch.dtype)

    if device is None:
        device = 'cpu'

    model_params = {}
    if file_path.suffix == '.safetensors':
        # load from safetensors file
        from safetensors import safe_open
        with safe_open(file_path, framework='pt', device=device) as f:
            for name in f.keys():
                tensor = f.get_tensor(name)
                if dtype is not None:
                    tensor = tensor.to(dtype)
                model_params[name] = tensor
    elif file_path.suffix == '.bin':
        # load from pytorch bin file
        state_dict = torch.load(file_path, map_location=device)
        for name in state_dict:
            tensor = state_dict[name]
            if dtype is not None:
                tensor = tensor.to(dtype)
            model_params[name] = tensor
    else:
        raise NotImplementedError(
            f'Only .safetensors and .bin files are supported, but got {str(file_path)}'
        )
    return model_params


def get_model_path(
    model_dir: Union[str, Path],
    name: Optional[str] = None,
) -> Optional[str]:
    """ Get the model file path from a model directory.

    Both `safetensors` and PyTorch binary (`.bin`) checkpoints are supported.
    Args:
        model_dir: model directory.
        name: model file name without suffix.
    Returns:
        Full model path, or None if no matching file is found.
    """
    model_dir = Path(model_dir)
    if name is not None:
        if (model_dir / f"{name}.safetensors").exists():
            return str(model_dir / f"{name}.safetensors")
        elif (model_dir / f"{name}.bin").exists():
            return str(model_dir / f"{name}.bin")
        else:
            return None
    else:
        model_files = list(model_dir.glob('*.safetensors'))
        if len(model_files) > 0:
            assert len(
                model_files
            ) == 1, f"found multiple safetensors files in {model_dir}, please specify one"
            return str(model_files[0])
        model_files = list(model_dir.glob('*.bin'))
        if len(model_files) > 0:
            assert len(
                model_files
            ) == 1, f"found multiple bin files in {model_dir}, please specify one"
            return str(model_files[0])
        return None


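# Illustrative usage sketch, added for documentation only (the function name
# and checkpoint path are hypothetical). It resolves a single-file checkpoint
# inside a directory and loads it as a state dict, preferring .safetensors
# over .bin as get_model_path() does.
def _example_load_checkpoint(checkpoint_dir: str = '/path/to/hf_model'):
    model_path = get_model_path(checkpoint_dir)
    if model_path is None:
        raise FileNotFoundError(
            f'No .safetensors or .bin file found in {checkpoint_dir}')
    weights = load_state_dict(model_path, dtype=torch.float16, device='cpu')
    return weights

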
def retrieved_layer_index_from_name(name: str) -> Optional[int]:
    # This is a hacky helper that retrieves the layer index from an HF
    # parameter name. Most HF models follow a similar naming convention, but
    # please check carefully that it works on your target model before
    # relying on it.
    res = re.search(r'\d+', name)
    return int(res.group()) if res is not None else res


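# Illustrative sketch, added for documentation only (function name is made
# up). The helper simply returns the first integer found in the parameter
# name, so it works for names such as
# 'model.layers.12.self_attn.q_proj.weight' but would mis-fire on names where
# another number appears before the layer index.
def _example_retrieved_layer_index():
    assert retrieved_layer_index_from_name(
        'model.layers.12.self_attn.q_proj.weight') == 12
    assert retrieved_layer_index_from_name('lm_head.weight') is None

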
def iterate_shard_files(model_dir: Union[Path, str],
                        rank: int,
                        progress_bar: bool = True):
    model_dir = Path(model_dir)

    # Shards are stored as '.safetensors' or '.bin' files. If both formats
    # exist, only the '.safetensors' files are used.
    shard_files = list(model_dir.glob('*.safetensors'))
    if not shard_files:
        # The model checkpoint is stored in .bin files.
        shard_files = list(model_dir.glob('*.bin'))
    if not shard_files:
        raise RuntimeError(
            f"Could not find any .safetensors or .bin files in {model_dir}")

    try:
        import tqdm
        if progress_bar:
            # Show a progress bar per rank.
            desc = f'Rank [{rank}] Loading weights'
            shard_files = tqdm.tqdm(shard_files, desc=desc, position=rank)

    except ImportError:
        pass

    for shard_file in shard_files:
        yield shard_file


def has_safetensors(model_dir: str):
    return len(list(Path(model_dir).glob('*.safetensors'))) > 0


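# Illustrative sketch, added for documentation only (the function name and
# directory are hypothetical). It merges all shards of a checkpoint into one
# state dict by reusing load_state_dict() on every shard file.
def _example_load_sharded_checkpoint(model_dir: str = '/path/to/hf_model',
                                     rank: int = 0):
    weights = {}
    for shard_file in iterate_shard_files(model_dir, rank, progress_bar=False):
        weights.update(load_state_dict(shard_file, dtype=torch.float16))
    return weights

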
DEFAULT_HF_DATASET_META = {
    'ccdv/cnn_dailymail': ('3.0.0', 'train', 'article'),
    'cnn_dailymail': ('3.0.0', 'train', 'article'),
    'lambada': (None, 'validation', 'text'),
    '': (None, 'train', 'text'),  # Default value in HF
}


def load_calib_dataset(dataset_name_or_dir: str,
                       config_name: Optional[str] = None,
                       split: Optional[str] = None,
                       key: Optional[str] = None,
                       trust_remote_code=True,
                       **kwargs):
    if config_name is None:
        for name, meta in DEFAULT_HF_DATASET_META.items():
            if name in dataset_name_or_dir:
                if config_name is None:
                    config_name = meta[0]
                if split is None:
                    split = meta[1]
                if key is None:
                    key = meta[2]
                break

    dataset = load_dataset(dataset_name_or_dir,
                           name=config_name,
                           split=split,
                           trust_remote_code=trust_remote_code,
                           **kwargs)
    return dataset[key]


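# Illustrative usage sketch, added for documentation only (the function name
# is hypothetical; downloading the dataset requires network access and the
# `datasets` package). For 'ccdv/cnn_dailymail' the defaults in
# DEFAULT_HF_DATASET_META resolve to config '3.0.0', split 'train' and text
# column 'article'.
def _example_load_calib_dataset(num_samples: int = 512):
    texts = load_calib_dataset('ccdv/cnn_dailymail')
    return texts[:num_samples]

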
@torch.no_grad()
def apply_smoothing(
        scales: torch.Tensor,
        gemm_weights: Union[torch.Tensor, List[torch.Tensor]],
        layernorm_weights: Optional[Union[torch.Tensor,
                                          List[torch.Tensor]]] = None,
        layernorm_bias: Optional[Union[torch.Tensor,
                                       List[torch.Tensor]]] = None,
        dtype: torch.dtype = torch.float32,
        layernorm_1p: bool = False):
    if not isinstance(gemm_weights, list):
        gemm_weights = [gemm_weights]

    if layernorm_weights is not None:
        assert layernorm_weights.numel() == scales.numel()
        layernorm_weights.div_(scales).to(dtype)
    if layernorm_bias is not None:
        assert layernorm_bias.numel() == scales.numel()
        layernorm_bias.div_(scales).to(dtype)
    if layernorm_1p:
        layernorm_weights += (1 / scales) - 1

    for gemm in gemm_weights:
        gemm.mul_(scales.view(1, -1)).to(dtype)


@torch.no_grad()
def smooth_gemm(gemm_weights,
                act_scales,
                layernorm_weights=None,
                layernorm_bias=None,
                alpha: Optional[float] = 0.5,
                weight_scales=None):
    if not isinstance(gemm_weights, list):
        gemm_weights = [gemm_weights]
    orig_dtype = gemm_weights[0].dtype

    for gemm in gemm_weights:
        # gemm_weights are expected to be transposed
        assert gemm.shape[1] == act_scales.numel()

    if weight_scales is None:
        weight_scales = torch.cat(
            [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights],
            dim=0)
        weight_scales = weight_scales.max(dim=0)[0]
    weight_scales = weight_scales.to(float).clamp(min=1e-5)
    scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) /
              weight_scales.pow(1 - alpha)).clamp(min=1e-5)

    apply_smoothing(scales, gemm_weights, layernorm_weights, layernorm_bias,
                    orig_dtype)

    return scales


@torch.no_grad()
def smooth_gemm_fc1_gate(fc1_weights,
                         gate_weights,
                         act_scales,
                         layernorm_weights=None,
                         layernorm_bias=None,
                         alpha=0.5,
                         weight_scales=None):
    gemm_weights = []
    if not isinstance(fc1_weights, list):
        fc1_weights = [fc1_weights]
    if not isinstance(gate_weights, list):
        gate_weights = [gate_weights]

    for i in range(len(fc1_weights)):
        gemm_weight = torch.cat([fc1_weights[i], gate_weights[i]], dim=0)
        gemm_weights.append(gemm_weight)

    orig_dtype = gemm_weights[0].dtype

    for gemm in gemm_weights:
        # gemm_weights are expected to be transposed
        assert gemm.shape[1] == act_scales.numel()

    if weight_scales is None:
        weight_scales = torch.cat(
            [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights],
            dim=0)
        weight_scales = weight_scales.max(dim=0)[0]
    weight_scales = weight_scales.to(float).clamp(min=1e-5)
    scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) /
              weight_scales.pow(1 - alpha)).clamp(min=1e-5)

    apply_smoothing(scales, fc1_weights + gate_weights, layernorm_weights,
                    layernorm_bias, orig_dtype)

    return scales


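# Illustrative SmoothQuant-style sketch, added for documentation only (the
# function name, shapes and values are made up). smooth_gemm() migrates the
# quantization difficulty from activations into the weights: the preceding
# layernorm weight is divided by `scales` while the GEMM weight columns are
# multiplied by `scales`, so the layer's pre-quantization output is
# mathematically unchanged.
def _example_smooth_gemm():
    hidden, out_features = 16, 32
    weight = torch.randn(out_features, hidden)  # transposed GEMM weight
    ln_weight = torch.ones(hidden)
    act_scales = torch.rand(hidden) + 0.5  # calibrated per-channel |x| maxima
    x = torch.randn(4, hidden)

    y_ref = (x * ln_weight) @ weight.t()
    scales = smooth_gemm(weight, act_scales, layernorm_weights=ln_weight,
                         alpha=0.5)
    # After smoothing, the rescaled activations times the rescaled weight
    # reproduce the original output up to floating-point error.
    y_smoothed = (x * ln_weight) @ weight.t()
    assert torch.allclose(y_ref, y_smoothed, atol=1e-4)
    return scales

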
def generate_int8(
    weights: torch.Tensor,
    act_range: Dict[str, torch.Tensor],
    is_qkv: bool = False,
    multi_query_mode: bool = False,
):
    """
    This function has two purposes:
      - compute quantized weights, scaled either per-tensor or per-column
      - compute scaling factors

    Depending on the GEMM API (CUTLASS/CUBLAS) the required scaling factors differ.
    CUTLASS uses two sets of scaling factors. One for the activation X, one for the weight W.
    CUBLAS only has one (we can't do per-row scaling). So we must provide a pre-multiplied scaling factor.

    Here is the list of what we need (T means per-tensor, C per-column):
      - scale_x_orig_quant puts fp activation into the quantized range (i.e. [-128, 127] for int8). Used before the GEMM. (T)
      - scale_y_quant_orig puts quantized activation into the fp range. Used if the GEMM outputs int8. (T)
      - scale_w_quant_orig puts weights from quant range to fp range (used with CUTLASS) (T, C)
      - scale_y_accum_quant puts the GEMM result (XW) from accumulation range (int32)
        to quant range (int8) (used for CUBLAS) (T, C)

    Note that we don't do anything special about row-parallel GEMM. Theoretically, we could have per-GPU scaling factors too,
    but then the model would change depending on the number of GPUs used.

    For QKV projection, the behavior is special. Even if we have a single matrix to perform QKV projection, we consider it
    as three different matrices: Q, K, and V. So per-tensor actually means one scaling factor for each of Q, K and V.
    For our GEMM implementation to respect this behavior, we use per-column mode and replicate values along columns.
    """

    # compute weight scaling factors for fp->int8 and int8->fp
    if is_qkv and not multi_query_mode:
        scale_w_orig_quant_t = 127. / act_range["w"].reshape(3, -1).max(
            dim=-1, keepdims=True)[0]
        scale_w_orig_quant_c = 127. / act_range["w"].reshape(3, -1)
    elif is_qkv and multi_query_mode:
        hidden_dim = weights.shape[0]
        local_dim = act_range["w"].shape[0]
        kv_dim = (local_dim - hidden_dim) // 2
        scale_w_q = act_range["w"][0:hidden_dim]
        scale_w_k = act_range["w"][hidden_dim:hidden_dim + kv_dim]
        scale_w_v = act_range["w"][-kv_dim:]

        scale_w_qkv_t = torch.concat([
            scale_w_q.max(dim=0, keepdim=True)[0],
            scale_w_k.max(dim=0, keepdim=True)[0],
            scale_w_v.max(dim=0, keepdim=True)[0]
        ])

        scale_w_orig_quant_t = 127. / scale_w_qkv_t
        scale_w_orig_quant_c = 127. / act_range["w"]
    else:
        scale_w_orig_quant_t = 127. / act_range["w"].max()
        scale_w_orig_quant_c = 127. / act_range["w"]
    scale_w_quant_orig_t = 1.0 / scale_w_orig_quant_t
    scale_w_quant_orig_c = 1.0 / scale_w_orig_quant_c

    scale_w_orig_quant_c = scale_w_orig_quant_c.to(torch.float32)
    scale_w_orig_quant_t = scale_w_orig_quant_t.to(torch.float32)

    # compute the rest of the needed scaling factors
    scale_x_orig_quant_t = 127. / act_range["x"].max()
    scale_y_orig_quant_t = 127. / act_range["y"].max()
    scale_y_quant_orig_t = act_range["y"].max() / 127.
    scale_y_accum_quant_t = scale_y_orig_quant_t / (scale_x_orig_quant_t *
                                                    scale_w_orig_quant_t)
    scale_y_accum_quant_c = scale_y_orig_quant_t / (scale_x_orig_quant_t *
                                                    scale_w_orig_quant_c)
    if is_qkv and not multi_query_mode:
        scale_y_accum_quant_t = torch.broadcast_to(scale_y_accum_quant_t,
                                                   scale_w_orig_quant_c.shape)
        scale_w_quant_orig_t = torch.broadcast_to(scale_w_quant_orig_t,
                                                  scale_w_orig_quant_c.shape)
    if is_qkv and multi_query_mode:
        scale_q_y_accum_t = torch.broadcast_to(scale_y_accum_quant_t[0],
                                               scale_w_q.shape)
        scale_k_y_accum_t = torch.broadcast_to(scale_y_accum_quant_t[1],
                                               scale_w_k.shape)
        scale_v_y_accum_t = torch.broadcast_to(scale_y_accum_quant_t[2],
                                               scale_w_v.shape)
        scale_y_accum_quant_t = torch.concat(
            [scale_q_y_accum_t, scale_k_y_accum_t, scale_v_y_accum_t])
        scale_w_quant_orig_t = torch.concat([
            torch.broadcast_to(scale_w_quant_orig_t[0], scale_w_q.shape),
            torch.broadcast_to(scale_w_quant_orig_t[1], scale_w_k.shape),
            torch.broadcast_to(scale_w_quant_orig_t[2], scale_w_v.shape)
        ])

    to_i8 = lambda x: x.round().clip(-127, 127).to(torch.int8)

    if is_qkv and multi_query_mode:
        if weights.device != scale_w_quant_orig_t.device:
            scale_w_quant_orig_t = scale_w_quant_orig_t.to(weights.device)
        weight_int8 = to_i8(weights / scale_w_quant_orig_t)
    else:
        if weights.device != scale_w_orig_quant_t.device:
            scale_w_orig_quant_t = scale_w_orig_quant_t.to(weights.device)
        weight_int8 = to_i8(weights * scale_w_orig_quant_t)
    if weights.device != scale_w_orig_quant_c.device:
        scale_w_orig_quant_c = scale_w_orig_quant_c.to(weights.device)

    return {
        "weight.int8": weight_int8,
        "weight.int8.col": to_i8(weights * scale_w_orig_quant_c),
        "scale_x_orig_quant": scale_x_orig_quant_t.to(torch.float32),
        "scale_w_quant_orig": scale_w_quant_orig_t.to(torch.float32),
        "scale_w_quant_orig.col": scale_w_quant_orig_c.to(torch.float32),
        "scale_y_accum_quant": scale_y_accum_quant_t.to(torch.float32),
        "scale_y_accum_quant.col": scale_y_accum_quant_c.to(torch.float32),
        "scale_y_quant_orig": scale_y_quant_orig_t.to(torch.float32),
    }


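# Illustrative worked example, added for documentation only. The function name
# is made up and the way act_range is populated here is a simplification: in
# the real calibration flow these statistics come from captured activation and
# weight ranges. It shows that dequantizing the per-column int8 weights with
# scale_w_quant_orig.col approximately recovers the original fp weights.
def _example_generate_int8():
    in_dim, out_dim = 16, 32
    weight = torch.randn(in_dim, out_dim)  # GEMM layout: X @ W
    act_range = {
        "w": weight.abs().max(dim=0)[0],  # per-output-channel weight amax
        "x": torch.rand(in_dim) + 0.5,  # calibrated activation amax
        "y": torch.rand(out_dim) + 0.5,  # calibrated output amax
    }
    int8_weights = generate_int8(weight, act_range)
    dequant = int8_weights["weight.int8.col"].float() * int8_weights[
        "scale_w_quant_orig.col"]
    # The round-trip error is at most half a quantization step per column.
    assert torch.allclose(dequant, weight,
                          atol=(weight.abs().max() / 127).item())
    return int8_weights

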
def get_tllm_linear_weight(weight,
                           prefix,
                           bias=None,
                           use_weight_only=False,
                           plugin_weight_only_quant_type=torch.int8,
                           dtype='float32',
                           use_gemm_woq_plugin=False,
                           use_fp8_rowwise=False,
                           weight_scale=None,
                           clamp_value=[-1200.0, 1200],
                           tp_rank=0,
                           postfix='weight',
                           quant_scale_name=None):
    results = {}
    if use_weight_only:
        if weight_scale:
            logger.error(
                "Weight only doesn't support loading scales from the weights.")
        if weight.dim() > 2:
            v = weight.transpose(1, 2).contiguous()
        else:
            v = weight.t().contiguous()
        processed_torch_weights, torch_weight_scales = \
            torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
                v.cpu(), plugin_weight_only_quant_type)
        if not use_gemm_woq_plugin:
            results[prefix + postfix] = v.to(dtype)
        else:
            results[prefix + postfix] = processed_torch_weights
        if quant_scale_name is not None:
            results[quant_scale_name] = torch_weight_scales
        else:
            results[prefix + 'per_channel_scale'] = torch_weight_scales
    elif use_fp8_rowwise:
        if weight_scale is not None:
            assert weight.dtype == torch.float8_e4m3fn, "weight data type must be torch.float8_e4m3fn"
            results[prefix + postfix] = weight
            torch_weight_scales = weight_scale.to(torch.float32)
        else:
            processed_torch_weights, torch_weight_scales = fp8_per_channel_quant_weight_gpu(
                weight, clamp_value)
            results[prefix + postfix] = processed_torch_weights
            torch_weight_scales = torch_weight_scales.to(torch.float32)

        if quant_scale_name is not None:
            results[quant_scale_name] = torch_weight_scales
        else:
            results[prefix + 'per_channel_scale'] = torch_weight_scales
    else:
        results[prefix + postfix] = weight

    if bias is not None:
        results[prefix + 'bias'] = bias

    return results


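# Illustrative usage sketch, added for documentation only (the function name
# and prefix string are hypothetical). Without weight-only or FP8 rowwise
# quantization the helper simply keys the weight (and optional bias) under the
# TRT-LLM parameter names derived from the prefix.
def _example_get_tllm_linear_weight():
    weight = torch.randn(32, 16, dtype=torch.float16)
    bias = torch.randn(32, dtype=torch.float16)
    results = get_tllm_linear_weight(weight,
                                     'transformer.layers.0.attention.dense.',
                                     bias=bias)
    assert set(results) == {
        'transformer.layers.0.attention.dense.weight',
        'transformer.layers.0.attention.dense.bias',
    }
    return results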