import copy
import functools
import time
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.pytorch_utils import Conv1D

from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.convert_utils import (dup_kv_weight, generate_int8,
                                               smooth_gemm,
                                               smooth_gemm_fc1_gate, split,
                                               split_matrix_tp, split_qkv_tp)


def get_tllm_linear_weight(weight,
                           prefix,
                           bias=None,
                           use_weight_only=False,
                           plugin_weight_only_quant_type=torch.int8,
                           postfix='weight'):
    results = {}
    if use_weight_only:
        v = weight.t().contiguous().cpu()
        processed_torch_weights, torch_weight_scales = \
            torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
                v, plugin_weight_only_quant_type)
        results[prefix + postfix] = processed_torch_weights
        results[prefix + 'per_channel_scale'] = torch_weight_scales
    else:
        results[prefix + postfix] = weight.contiguous()

    if bias is not None:
        results[prefix + 'bias'] = bias

    return results


def load_medusa_hf(medusa_path: str,
                   num_medusa_heads: int,
                   num_medusa_layers: int,
                   mapping=Mapping(),
                   dtype='float32',
                   use_weight_only=False,
                   plugin_weight_only_quant_type=None,
                   is_modelopt_ckpt=False):
    logger.info("Loading Medusa heads' weights ...")
    if is_modelopt_ckpt:
        from safetensors.torch import load_file
        state_dict = {}
        for filename in sorted(Path(medusa_path).glob("*.safetensors")):
            print(f"Loading the weights of Medusa heads from {filename}")
            state_dict.update(load_file(filename))
    else:
        is_ckpt_safetensors = False
        ckpt_file = Path(medusa_path) / "medusa_lm_head.pt"
        if not ckpt_file.exists():
            ckpt_file = Path(medusa_path) / "medusa_lm_head.safetensors"
            is_ckpt_safetensors = True

        if is_ckpt_safetensors:
            logger.info("Safetensors Found ...")
            from safetensors.torch import load_file
            state_dict = load_file(ckpt_file)
        else:
            state_dict = torch.load(ckpt_file, map_location="cpu")

    torch_dtype = str_dtype_to_torch(dtype)
    weights = {}
    prefix = "medusa_heads." if is_modelopt_ckpt else ""

    for h in range(num_medusa_heads):
        for l in range(num_medusa_layers):
            w = state_dict[f"{prefix}{h}.{l}.linear.weight"].clone().to(
                torch_dtype)

            split_v = split(w, mapping.tp_size, mapping.tp_rank)
            weights.update(
                get_tllm_linear_weight(
                    split_v, f'medusa_heads.{h}.medusa_layers.{l}.linear.',
                    None, use_weight_only, plugin_weight_only_quant_type))

            b = state_dict[f"{prefix}{h}.{l}.linear.bias"].clone().to(
                torch_dtype)
            weights['medusa_heads.{}.medusa_layers.{}.linear.bias'.format(
                h, l)] = split(b, mapping.tp_size, mapping.tp_rank)

        lm = state_dict[f"{prefix}{h}.{num_medusa_layers}.weight"].clone().to(
            torch_dtype)  # LM Head
        weights['medusa_heads.{}.lm_head.weight'.format(h)] = split(
            lm, mapping.tp_size, mapping.tp_rank)

        # scaling factors
        if is_modelopt_ckpt:
            scaling_dtype = str_dtype_to_torch("float32")
            weights[f'medusa_heads.{h}.medusa_layers.{l}.linear.activation_scaling_factor'] = \
                state_dict[f"{prefix}{h}.{l}.linear.input_scale"].clone().to(scaling_dtype)
            weights[f'medusa_heads.{h}.medusa_layers.{l}.linear.weights_scaling_factor'] = \
                state_dict[f"{prefix}{h}.{l}.linear.weight_scale"].clone().to(scaling_dtype)
            weights['medusa_heads.{}.lm_head.activation_scaling_factor'.format(h)] = \
                state_dict[f"{prefix}{h}.{num_medusa_layers}.input_scale"].clone().to(scaling_dtype)
            weights['medusa_heads.{}.lm_head.weights_scaling_factor'.format(h)] = \
                state_dict[f"{prefix}{h}.{num_medusa_layers}.weight_scale"].clone().to(scaling_dtype)

    return weights

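
# Illustrative usage sketch (not part of the original module): how the Medusa
# head weights loaded above might be requested for a single-GPU, fp16 build.
# The checkpoint path and the head/layer counts below are hypothetical.
def _example_load_medusa_heads():
    medusa_weights = load_medusa_hf(
        medusa_path="/path/to/medusa_checkpoint",  # hypothetical directory
        num_medusa_heads=4,
        num_medusa_layers=1,
        mapping=Mapping(world_size=1, rank=0, tp_size=1),
        dtype="float16")
    # Keys follow the TRT-LLM naming used above, e.g.
    # 'medusa_heads.0.medusa_layers.0.linear.weight' and
    # 'medusa_heads.0.lm_head.weight'.
    return medusa_weights
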

@torch.no_grad()
def smooth_llama_model(model, scales, alpha, llama_qkv_para, llama_smoother):
    # Smooth the activation and weights with smoother = $\diag{s}$
    for name, module in model.named_modules():
        if not isinstance(module, LlamaDecoderLayer):
            continue
        # qkv_proj
        layer_name_q = name + ".self_attn.q_proj"
        layer_name_k = name + ".self_attn.k_proj"
        layer_name_v = name + ".self_attn.v_proj"
        layer_name_qkv = name + ".self_attn.qkv_proj"

        weight = torch.cat([
            module.self_attn.q_proj.weight, module.self_attn.k_proj.weight,
            module.self_attn.v_proj.weight
        ],
                           dim=0)

        smoother = smooth_gemm(weight, scales[layer_name_q]["x"],
                               module.input_layernorm.weight, None, alpha)

        scales[layer_name_qkv]["x"] = scales[layer_name_q]["x"] / smoother
        scales[layer_name_qkv]["w"] = weight.abs().max(dim=1)[0]
        scales[layer_name_qkv]["y"] = torch.cat([
            scales[layer_name_q]["y"], scales[layer_name_k]["y"],
            scales[layer_name_v]["y"]
        ],
                                                dim=0)

        # see transpose_weights function
        llama_qkv_para[layer_name_qkv] = weight.transpose(0, 1)

        # =================================================================
        layer_name = name + ".self_attn.o_proj"
        smoother = smooth_gemm(module.self_attn.o_proj.weight,
                               scales[layer_name]["x"], None, None, alpha)
        llama_smoother[layer_name] = smoother.float()

        scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
        scales[layer_name]["w"] = module.self_attn.o_proj.weight.abs().max(
            dim=1)[0]

        # =================================================================
        fc1_layer_name = name + ".mlp.gate_proj"
        gate_layer_name = name + ".mlp.up_proj"

        smoother = smooth_gemm_fc1_gate(module.mlp.gate_proj.weight,
                                        module.mlp.up_proj.weight,
                                        scales[fc1_layer_name]["x"],
                                        module.post_attention_layernorm.weight,
                                        None, alpha)

        scales[fc1_layer_name]["x"] = scales[fc1_layer_name]["x"] / smoother
        scales[fc1_layer_name]["w"] = module.mlp.gate_proj.weight.abs().max(
            dim=1)[0]

        scales[gate_layer_name]["x"] = scales[gate_layer_name]["x"] / smoother
        scales[gate_layer_name]["w"] = module.mlp.up_proj.weight.abs().max(
            dim=1)[0]

        # =================================================================
        layer_name = name + ".mlp.down_proj"
        smoother = smooth_gemm(module.mlp.down_proj.weight,
                               scales[layer_name]["x"], None, None, alpha)
        llama_smoother[layer_name] = smoother.float()
        scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
        scales[layer_name]["w"] = module.mlp.down_proj.weight.abs().max(
            dim=1)[0]

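
# Illustrative sketch of the per-channel smoothing idea applied above
# (SmoothQuant, https://arxiv.org/abs/2211.10438): a factor s is chosen per
# input channel so that activation outliers are migrated into the weights,
#   s_j = max|x_j|^alpha / max|w_j|^(1 - alpha),
# after which x / s and w * s are easier to quantize. This is only a toy
# re-derivation; the actual smooth_gemm / smooth_gemm_fc1_gate helpers live in
# tensorrt_llm.models.convert_utils and also fold s into the preceding
# LayerNorm weight.
def _example_smoothing_factor(weight: torch.Tensor,
                              act_abs_max: torch.Tensor,
                              alpha: float = 0.5) -> torch.Tensor:
    # weight: [out_features, in_features], act_abs_max: [in_features]
    w_abs_max = weight.abs().max(dim=0)[0].clamp(min=1e-5)
    s = (act_abs_max.clamp(min=1e-5)**alpha) / (w_abs_max**(1 - alpha))
    return s  # divide activations by s, multiply weight columns by s
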
@torch.no_grad()
def capture_activation_range(model,
                             tokenizer,
                             dataset,
                             num_samples=512,
                             seq_len=512):
    model.eval()
    device = next(model.parameters()).device
    act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None})

    tokenizer.pad_token = tokenizer.eos_token

    def stat_tensor(name, tensor, act_scales, key):
        hidden_dim = tensor.shape[-1]
        tensor = tensor.view(-1, hidden_dim).abs().detach()
        coming_max = torch.max(tensor, dim=0)[0].float()

        if act_scales[name][key] is None:
            act_scales[name][key] = coming_max
        else:
            act_scales[name][key] = torch.max(act_scales[name][key],
                                              coming_max)

    def stat_input_hook(m, x, y, name):
        if isinstance(x, tuple):
            x = x[0]
        stat_tensor(name, x, act_scales, "x")
        stat_tensor(name, y, act_scales, "y")
        if act_scales[name]["w"] is None:
            act_scales[name]["w"] = m.weight.abs().clip(1e-8,
                                                        None).max(dim=1)[0]

    hooks = []
    for name, m in model.named_modules():
        if isinstance(m, nn.Linear) or isinstance(m, Conv1D):
            hooks.append(
                m.register_forward_hook(
                    functools.partial(stat_input_hook, name=name)))

    for i in tqdm(range(num_samples), desc="calibrating model"):
        datapoint = dataset[i:i + 1]
        line = copy.copy(datapoint)
        line[0] = line[0] + ' TL;DR: '
        line[0] = line[0].strip()
        line[0] = line[0].replace(" n't", "n't")
        input_ids = tokenizer(line,
                              return_tensors="pt",
                              max_length=seq_len,
                              padding=True,
                              truncation=True).input_ids.to(device)
        model(input_ids)

    for h in hooks:
        h.remove()

    return act_scales


def get_weight(config, prefix, dtype):
    if config[prefix + '.weight'].dtype != dtype:
        config[prefix + '.weight'].data = config[prefix + '.weight'].to(dtype)
    return config[prefix + '.weight']


def get_bias(config, prefix, dtype):
    if config[prefix + '.bias'].dtype != dtype:
        config[prefix + '.bias'].data = config[prefix + '.bias'].to(dtype)
    return config[prefix + '.bias']


def get_weight_and_bias(config, prefix, dtype):
    return get_weight(config, prefix, dtype), get_bias(config, prefix, dtype)

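
# Illustrative calibration sketch (hypothetical model/tokenizer/text names):
# capture_activation_range expects a list-like of raw text samples and returns
# the per-layer "x"/"y"/"w" max-abs statistics consumed by smooth_llama_model
# above and by convert_hf_llama (act_range=...) below.
def _example_calibration(hf_model, tokenizer, calib_texts):
    act_range = capture_activation_range(
        hf_model,
        tokenizer,
        calib_texts,
        num_samples=min(512, len(calib_texts)),
        seq_len=512)
    return act_range
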
def get_tllm_linear_sq_weight(vals,
                              prefix,
                              shape,
                              tensor_parallel,
                              is_qkv=False,
                              per_token=False,
                              per_channel=False,
                              last_prefix=None,
                              bias=None,
                              smoother_value=None,
                              smoother_shape=None,
                              rank=0,
                              cat_dim=0,
                              multi_query_mode=False):
    results = {}

    def multi_query_split(data, local_dim, head_size, tp_size, cur_rank):
        q, k, v = np.split(data, [local_dim, local_dim + head_size], axis=-1)
        q_split = np.split(q, tp_size, axis=-1)
        k_split = np.split(k, tp_size, axis=-1)
        v_split = np.split(v, tp_size, axis=-1)
        return [
            np.concatenate((q_split[ii], k_split[ii], v_split[ii]), axis=-1)
            for ii in range(tp_size)
        ][cur_rank]

    col_shape = shape if (is_qkv or per_channel) else [1, 1]

    if per_token:
        original_weights = vals["weight.int8.col"]
        local_dim = original_weights.shape[0]
        head_size = (original_weights.shape[1] - local_dim) // 2

        if multi_query_mode:
            cur_weights = multi_query_split(original_weights, local_dim,
                                            head_size, tensor_parallel, rank)
        else:
            cur_weights = np.split(original_weights,
                                   tensor_parallel,
                                   axis=cat_dim)[rank]
        if is_qkv:
            hidden_dim = cur_weights.shape[0]
            cur_weights = cur_weights.reshape(hidden_dim, -1)
        results[prefix +
                'weight'] = torch.from_numpy(cur_weights).t().contiguous()
        if smoother_value is None:
            results[last_prefix] = torch.from_numpy(
                np.array([1.0], dtype=np.float32))

        if smoother_value is None:
            if multi_query_mode:
                cur_per_channel_value = multi_query_split(
                    vals["scale_w_quant_orig.col"], local_dim, head_size,
                    tensor_parallel, rank)
            else:
                cur_per_channel_value = np.split(
                    vals["scale_w_quant_orig.col"],
                    tensor_parallel,
                    axis=cat_dim)[rank]
        else:
            cur_per_channel_value = vals["scale_w_quant_orig.col"]

        results[prefix + 'per_channel_scale'] = torch.from_numpy(
            np.array(cur_per_channel_value,
                     dtype=np.float32).reshape(col_shape)).contiguous()
    else:
        original_weights = np.array(vals["weight.int8"])
        cur_weights = np.split(original_weights, tensor_parallel,
                               axis=cat_dim)[rank]

        if is_qkv:
            hidden_dim = cur_weights.shape[0]
            cur_weights = cur_weights.reshape(hidden_dim, -1)
        results[prefix +
                'weight'] = torch.from_numpy(cur_weights).t().contiguous()

        cur_per_channel_value = vals["scale_y_accum_quant"]

        results[prefix + 'per_channel_scale'] = torch.from_numpy(
            np.array([cur_per_channel_value],
                     dtype=np.float32).reshape(col_shape)).contiguous()

        results[last_prefix] = torch.from_numpy(
            np.array([vals['scale_x_orig_quant']],
                     dtype=np.float32)).contiguous()

        results[prefix + 'act_scale'] = torch.from_numpy(
            np.array([[vals["scale_y_quant_orig"]]],
                     dtype=np.float32)).contiguous()

    if smoother_value is not None:
        cur_smoother_value = np.split(smoother_value,
                                      tensor_parallel,
                                      axis=cat_dim)[rank]
        results[prefix + 'smoother'] = cur_smoother_value.reshape(
            smoother_shape).contiguous().to(torch.float32)

    if bias is not None:
        results[prefix + 'bias'] = bias

    return results

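
# Illustrative sketch (not part of the original module) of the key layout that
# get_tllm_linear_sq_weight produces for the per-tensor (per_token=False)
# path, using dummy int8 data and made-up scales. Real inputs come from
# generate_int8(); the int8 weight array is assumed here to be laid out as
# [in_features, out_features], mirroring the transposed weights that
# convert_hf_llama below passes to generate_int8.
def _example_sq_key_layout():
    dummy_vals = {
        "weight.int8": np.zeros((16, 32), dtype=np.int8),
        "scale_y_accum_quant": 0.02,
        "scale_x_orig_quant": 0.01,
        "scale_y_quant_orig": 50.0,
    }
    results = get_tllm_linear_sq_weight(
        dummy_vals,
        'transformer.layers.0.mlp.fc.',
        [1, 32],
        tensor_parallel=1,
        is_qkv=False,
        per_token=False,
        per_channel=False,
        last_prefix='transformer.layers.0.post_layernorm.scale_to_int',
        rank=0,
        cat_dim=-1)
    # Produces '...mlp.fc.weight', '...mlp.fc.per_channel_scale',
    # '...post_layernorm.scale_to_int' and '...mlp.fc.act_scale'.
    return results
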
def convert_hf_llama(hf_model,
                     mapping,
                     rank=0,
                     dtype='float32',
                     use_parallel_embedding=False,
                     sharding_dim=0,
                     use_weight_only=False,
                     plugin_weight_only_quant_type=torch.int8,
                     use_smooth_quant=False,
                     per_channel=False,
                     per_token=False,
                     int8_kv_cache=False,
                     act_range=[],
                     qkv_para=[],
                     smoother=[],
                     lora_config=None):

    weights = {}
    tik = time.time()
    tensor_parallel = mapping.tp_size
    model_params = dict(hf_model.named_parameters())
    dtype = getattr(torch, dtype)
    num_attention_heads = hf_model.config.num_attention_heads
    hidden_size = hf_model.config.hidden_size
    intermediate_size = hf_model.config.intermediate_size
    num_key_value_heads = hf_model.config.num_key_value_heads
    mha_mode = (num_key_value_heads == num_attention_heads)

    num_hidden_layers = hf_model.config.num_hidden_layers
    layers_range = mapping.pp_layers(num_hidden_layers)
    for l in layers_range:
        layer_idx = l - layers_range[0]
        prefix = f'model.layers.{l}.'
        tllm_prex = f'transformer.layers.{layer_idx}.'

        q_weight = get_weight(model_params, prefix + 'self_attn.q_proj', dtype)
        k_weight = get_weight(model_params, prefix + 'self_attn.k_proj', dtype)
        v_weight = get_weight(model_params, prefix + 'self_attn.v_proj', dtype)

        if not mha_mode:
            head_size = hidden_size // num_attention_heads
            if num_key_value_heads < tensor_parallel:
                # duplicate the KV heads up to tensor_parallel
                k_weight = dup_kv_weight(k_weight, num_key_value_heads,
                                         tensor_parallel)
                v_weight = dup_kv_weight(v_weight, num_key_value_heads,
                                         tensor_parallel)
            assert (k_weight.shape[0] % (mapping.tp_size * head_size)) == 0
            assert (v_weight.shape[0] % (mapping.tp_size * head_size)) == 0

            wq = split(q_weight, mapping.tp_size, mapping.tp_rank)
            wk = split(k_weight, mapping.tp_size, mapping.tp_rank)
            wv = split(v_weight, mapping.tp_size, mapping.tp_rank)

            split_v = torch.concat((wq, wk, wv))
        else:
            qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)

            split_v = split_qkv_tp(qkv_weight, num_attention_heads,
                                   hidden_size, tensor_parallel,
                                   mapping.tp_rank)

        if use_smooth_quant:
            qkv_weight = qkv_para[prefix + 'self_attn.qkv_proj']

            if not mha_mode:
                hidden_size = qkv_weight.shape[0]
                local_dim = hidden_size
                head_size = (qkv_weight.shape[-1] - local_dim) // 2
                qkv_weight = qkv_weight.reshape(hidden_size,
                                                local_dim + 2 * head_size)
            else:
                qkv_weight = qkv_weight.reshape(hidden_size, 3, hidden_size)

            int8_weights = generate_int8(qkv_weight,
                                         act_range.get(prefix +
                                                       'self_attn.qkv_proj'),
                                         is_qkv=True,
                                         multi_query_mode=bool(not mha_mode))

            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    tllm_prex + 'attention.qkv.', [
                        1, 3 * hidden_size // tensor_parallel if mha_mode else
                        hidden_size // tensor_parallel +
                        (hidden_size // num_key_value_heads) //
                        tensor_parallel * 2
                    ],
                    tensor_parallel,
                    is_qkv=True,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=tllm_prex + 'input_layernorm.scale_to_int',
                    smoother_value=None,
                    smoother_shape=None,
                    rank=mapping.tp_rank,
                    cat_dim=-1,
                    multi_query_mode=bool(not mha_mode)))
        else:
            weights.update(
                get_tllm_linear_weight(split_v, tllm_prex + 'attention.qkv.',
                                       None, use_weight_only,
                                       plugin_weight_only_quant_type))

        if int8_kv_cache:
            qkv_y = torch.cat([
                act_range.get(prefix + 'self_attn.q_proj')["y"],
                act_range.get(prefix + 'self_attn.k_proj')["y"],
                act_range.get(prefix + 'self_attn.v_proj')["y"]
            ],
                              dim=0)

            int8_kv_scales = qkv_y.max() / 127.
            kv_cache_weights = {}
            kv_cache_weights[
                tllm_prex +
                'attention.kv_cache_scaling_factor'] = int8_kv_scales.reshape(
                    [1])
            # Merge the per-layer KV-cache scale into the output dict so it is
            # actually written to the converted checkpoint.
            weights.update(kv_cache_weights)

        attn_dense_weight = get_weight(model_params,
                                       prefix + 'self_attn.o_proj', dtype)
        split_v = split_matrix_tp(attn_dense_weight,
                                  tensor_parallel,
                                  mapping.tp_rank,
                                  dim=1)
        if use_smooth_quant:
            attn_dense_weight = attn_dense_weight.t()
            int8_weights = generate_int8(
                attn_dense_weight, act_range.get(prefix + 'self_attn.o_proj'))
            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    tllm_prex + 'attention.dense.', [1, hidden_size],
                    tensor_parallel,
                    is_qkv=False,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=tllm_prex +
                    'attention.quantization_scaling_factor',
                    smoother_value=smoother[(prefix + 'self_attn.o_proj')],
                    smoother_shape=[1, hidden_size // tensor_parallel],
                    rank=mapping.tp_rank,
                    cat_dim=0))
        else:
            weights.update(
                get_tllm_linear_weight(split_v,
                                       tllm_prex + 'attention.dense.', None,
                                       use_weight_only,
                                       plugin_weight_only_quant_type))

        mlp_gate_weight = get_weight(model_params, prefix + 'mlp.up_proj',
                                     dtype)
        split_v = split_matrix_tp(mlp_gate_weight,
                                  tensor_parallel,
                                  mapping.tp_rank,
                                  dim=0)
        if use_smooth_quant:
            mlp_gate_weight = mlp_gate_weight.t()
            int8_weights = generate_int8(mlp_gate_weight,
                                         act_range.get(prefix + 'mlp.up_proj'))
            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    tllm_prex + 'mlp.gate.',
                    [1, intermediate_size // tensor_parallel],
                    tensor_parallel,
                    is_qkv=False,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=tllm_prex + 'post_layernorm.scale_to_int',
                    smoother_value=None,
                    smoother_shape=None,
                    rank=mapping.tp_rank,
                    cat_dim=-1))
        else:
            weights.update(
                get_tllm_linear_weight(split_v, tllm_prex + 'mlp.gate.', None,
                                       use_weight_only,
                                       plugin_weight_only_quant_type))

        mlp_fc_weight = get_weight(model_params, prefix + 'mlp.gate_proj',
                                   dtype)
        split_v = split_matrix_tp(mlp_fc_weight,
                                  tensor_parallel,
                                  mapping.tp_rank,
                                  dim=0)
        if use_smooth_quant:
            mlp_fc_weight = mlp_fc_weight.t()  # verified
            int8_weights = generate_int8(
                mlp_fc_weight, act_range.get(prefix + 'mlp.gate_proj'))
            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    tllm_prex + 'mlp.fc.',
                    [1, intermediate_size // tensor_parallel],
                    tensor_parallel,
                    is_qkv=False,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=tllm_prex + 'post_layernorm.scale_to_int',
                    smoother_value=None,
                    smoother_shape=None,
                    rank=mapping.tp_rank,
                    cat_dim=-1))
        else:
            weights.update(
                get_tllm_linear_weight(split_v, tllm_prex + 'mlp.fc.', None,
                                       use_weight_only,
                                       plugin_weight_only_quant_type))

        mlp_proj_weight = get_weight(model_params, prefix + 'mlp.down_proj',
                                     dtype)
        split_v = split_matrix_tp(mlp_proj_weight,
                                  tensor_parallel,
                                  mapping.tp_rank,
                                  dim=1)
        if use_smooth_quant:
            mlp_proj_weight = mlp_proj_weight.t()
            int8_weights = generate_int8(
                mlp_proj_weight, act_range.get(prefix + 'mlp.down_proj'))
            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    tllm_prex + 'mlp.proj.', [1, hidden_size],
                    tensor_parallel,
                    is_qkv=False,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=tllm_prex + 'mlp.quantization_scaling_factor',
                    smoother_value=smoother[prefix + 'mlp.down_proj'],
                    smoother_shape=[1, intermediate_size // tensor_parallel],
                    rank=mapping.tp_rank,
                    cat_dim=0))
        else:
            weights.update(
                get_tllm_linear_weight(split_v, tllm_prex + 'mlp.proj.', None,
                                       use_weight_only,
                                       plugin_weight_only_quant_type))

        # Layer norms do not use tensor parallelism
        input_ln_weight = get_weight(model_params,
                                     prefix + 'input_layernorm', dtype)
        weights[tllm_prex + 'input_layernorm.weight'] = input_ln_weight

        post_ln_weight = get_weight(model_params,
                                    prefix + 'post_attention_layernorm',
                                    dtype)
        weights[tllm_prex + 'post_layernorm.weight'] = post_ln_weight

    v = get_weight(model_params, 'model.embed_tokens', dtype)

    if hf_model.config.tie_word_embeddings:
        # lm_head.weight has the same weights as embedding
        if mapping.is_last_pp_rank():
            weights['lm_head.weight'] = split(v, mapping.tp_size,
                                              mapping.tp_rank)

    if use_parallel_embedding:
        v = split_matrix_tp(v,
                            mapping.tp_size,
                            mapping.tp_rank,
                            dim=sharding_dim)

    if mapping.is_first_pp_rank():
        weights['transformer.vocab_embedding.weight'] = v

    lm_head_weights = get_weight(model_params, 'lm_head', dtype)

    if mapping.is_last_pp_rank():
        weights['lm_head.weight'] = split_matrix_tp(lm_head_weights,
                                                    tensor_parallel,
                                                    mapping.tp_rank,
                                                    dim=0)

    ln_f_w = get_weight(model_params, 'model.norm', dtype)
    weights['transformer.ln_f.weight'] = ln_f_w

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Weights loaded. Total time: {t}')
    return weights
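

# Illustrative end-to-end driver (a sketch only, with hypothetical paths,
# placeholder calibration data, and arbitrary argument choices): load an HF
# LLaMA model, calibrate activation ranges, apply smoothing, then convert the
# base model to TRT-LLM weights with SmoothQuant enabled.
def _example_convert_with_smoothquant():
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_dir = "/path/to/llama-hf"  # hypothetical
    hf_model = AutoModelForCausalLM.from_pretrained(
        model_dir, torch_dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    calib_texts = ["Example calibration sentence."] * 32  # placeholder data
    act_range = capture_activation_range(hf_model,
                                         tokenizer,
                                         calib_texts,
                                         num_samples=len(calib_texts))

    qkv_para, smoother = {}, {}
    smooth_llama_model(hf_model, act_range, 0.5, qkv_para, smoother)

    mapping = Mapping(world_size=1, rank=0, tp_size=1, pp_size=1)
    weights = convert_hf_llama(hf_model,
                               mapping,
                               dtype='float16',
                               use_smooth_quant=True,
                               per_channel=True,
                               per_token=True,
                               act_range=act_range,
                               qkv_para=qkv_para,
                               smoother=smoother)
    return weights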