mirror of https://github.com/NVIDIA/TensorRT-LLM.git · synced 2026-01-13 22:18:36 +08:00
325 lines · 13 KiB · Python
# autoflake: skip_file
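"""Convert a DeepSeek-style FP8 block-scaled Hugging Face checkpoint into a
mixed-precision checkpoint: routed MoE expert weights are re-quantized to INT4
(W4A8) with per-128-column group scales, while attention, dense MLP, and
shared-expert layers keep their FP8 block scales. Per-expert activation input
scales are derived from ModelOpt amax calibration files. The work can be split
across several processes with --parts/--rank; only rank 0 writes the index and
config files.
"""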

import argparse
import json
import os
import re
import shutil

import torch
from safetensors.torch import safe_open, save_file
from tqdm import tqdm

# imported for its side effect of registering the torch.ops.trtllm custom ops
# (e.g. pack_int8_tensor_to_packed_int4) used below
import tensorrt_llm


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir',
                        type=str,
                        required=True,
                        help='HF checkpoint path')
    parser.add_argument('--output_dir',
                        type=str,
                        required=True,
                        help='Save path')
    parser.add_argument(
        '--act_scales',
        type=str,
        required=True,
        help=
        'ModelOpt calibrated checkpoint dir or extracted safetensors for activation scales'
    )
    parser.add_argument('--parts',
                        type=int,
                        default=1,
                        help='divide all safetensors into this many parts')
    parser.add_argument('--rank',
                        type=int,
                        default=0,
                        help='which part to quantize')
    args = parser.parse_args()
    return args

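
# Example invocation (script name and paths are placeholders), splitting the
# layers into 4 parts and processing part 0 in this process:
#
#   python quantize_to_w4a8.py \
#       --model_dir /path/to/DeepSeek-V3-fp8 \
#       --output_dir /path/to/DeepSeek-V3-w4a8 \
#       --act_scales /path/to/modelopt_amax_dir \
#       --parts 4 --rank 0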

def load_and_preprocess_state_dict(modelopt_state_root, world_size=8):
    state_dict_list = []
    # load amax from state dict
    for rank in range(world_size):
        amax_file = f"{modelopt_state_root}/amax_dict_rank{rank}-mp{world_size}.pt"
        if os.path.exists(amax_file):
            state_dict_list.append(torch.load(amax_file, map_location="cuda:0"))
        else:
            print(f"WARNING: amax file not found: {amax_file}")

    if not state_dict_list:
        print("ERROR: No amax files loaded!")
        return {}

    # calculate the max across all TP ranks
    merged_state_dict = state_dict_list[0]
    for rank in range(1, len(state_dict_list)):
        for key, amax in state_dict_list[rank].items():
            if key in merged_state_dict:
                amax = torch.max(amax.to(0), merged_state_dict[key].to(0))
            merged_state_dict[key] = amax.to(0)

    mapping = {
        "ffn.shared_experts.w1": "mlp.shared_experts.gate_proj",
        "ffn.shared_experts.w2": "mlp.shared_experts.down_proj",
        "ffn.shared_experts.w3": "mlp.shared_experts.up_proj",
        "ffn.shared_experts": "mlp.shared_experts",
        "ffn.w1": "mlp.gate_proj",
        "ffn.w2": "mlp.down_proj",
        "ffn.w3": "mlp.up_proj",
        "head": "lm_head",
        "attn": "self_attn",
    }
    new_dict = {}
    for k, v in merged_state_dict.items():
        new_key = k.replace("layers", "model.layers")
        for original_pattern, replace_pattern in mapping.items():
            new_key = new_key.replace(original_pattern, replace_pattern)
        # ffn.experts.xx.w1/w2/w3 -> mlp.experts.xx.gate_proj/down_proj/up_proj
        new_key = re.sub(r"ffn\.experts\.(\d+)\.w1",
                         r"mlp.experts.\1.gate_proj", new_key)
        new_key = re.sub(r"ffn\.experts\.(\d+)\.w2",
                         r"mlp.experts.\1.down_proj", new_key)
        new_key = re.sub(r"ffn\.experts\.(\d+)\.w3",
                         r"mlp.experts.\1.up_proj", new_key)
        new_dict[new_key] = v

    merged_state_dict.clear()
    merged_state_dict.update(new_dict)

    # set amax for modules to be fused and make sure they share the same input
    for key, amax in merged_state_dict.items():
        if "up_proj" in key:
            gate_proj_key = key.replace("up_proj", "gate_proj")
            if "weight_quantizer" in key:
                fused_amax = torch.max(amax, merged_state_dict[gate_proj_key])
                merged_state_dict[key] = fused_amax
                merged_state_dict[gate_proj_key] = fused_amax
            elif "input_quantizer" in key:
                assert amax == merged_state_dict[gate_proj_key]
            else:
                raise NotImplementedError

    return merged_state_dict

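
# Sketch of the renaming performed above (illustrative, not executed): a
# calibration key such as
#     "layers.5.ffn.experts.17.w1.input_quantizer._amax"
# becomes
#     "model.layers.5.mlp.experts.17.gate_proj.input_quantizer._amax",
# and the gate_proj/up_proj weight-quantizer amax values are fused with
# torch.max so the later gate/up projection fusion shares a single scale.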

def get_scales_from_amax(start_layer, end_layer, renamed_state_dict):
    weight_name_dict = {"gate_proj": 1, "down_proj": 2, "up_proj": 3}
    scales = {}
    for layer_idx in range(start_layer, end_layer):
        amax_keys_per_layer = [
            x for x in renamed_state_dict.keys()
            if (x.startswith(f'model.layers.{layer_idx}.mlp.experts.')
                and x.endswith(".input_quantizer._amax"))
        ]
        for k in amax_keys_per_layer:
            expert_idx = int(k.split('.')[5])
            weight_idx = weight_name_dict[k.split('.')[6]]
            val = renamed_state_dict[k]
            scales[
                f'model.layers.{layer_idx}.mlp.experts.{expert_idx}.w{weight_idx}.input_scale'] = val.unsqueeze(
                    0) / 448

    return scales

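
# Note on the division by 448 above: 448 is the largest finite value of the
# float8_e4m3fn format, so input_scale = amax / 448 maps the observed
# activation range onto the FP8 dynamic range (e.g. an observed amax of 8.96
# gives an input_scale of 0.02).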

def quantize_fp8_block_scale_to_int4(fp8_tensor, fp8_scale):
    group_size = 128
    # dequantize the 128x128 FP8 blocks back to full precision
    blocked_tensor = fp8_tensor.view(fp8_tensor.shape[0] // 128, 128,
                                     fp8_tensor.shape[1] // 128,
                                     128).to(torch.float32)
    dequant_tensor = (blocked_tensor *
                      fp8_scale.unsqueeze(1).unsqueeze(3)).view(
                          fp8_tensor.shape[0],
                          fp8_tensor.shape[1] // group_size,
                          group_size).to(torch.bfloat16).to(torch.float32)
    # re-quantize to symmetric INT4 with one scale per 128-element group
    scale_tensor = torch.abs(dequant_tensor).max(dim=2).values / 7
    quant_tensor = torch.clamp(torch.round(
        (dequant_tensor / scale_tensor.unsqueeze(-1))),
                               min=-8,
                               max=7)
    quant_tensor = quant_tensor.to(torch.int8)
    return quant_tensor.view(fp8_tensor.shape), scale_tensor

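
# Shape sketch (illustrative, not executed): for a weight of shape (256, 512)
# with FP8 block scales of shape (2, 4), the function returns
#     quant_tensor: int8, shape (256, 512), values in [-8, 7]
#     scale_tensor: float32, shape (256, 4), one scale per row per 128-column group
# quant_tensor is later packed two INT4 values per byte via
# torch.ops.trtllm.pack_int8_tensor_to_packed_int4.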

def main(args):
    model_dir = args.model_dir
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    if torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(args.rank % num_gpus)

    model_index_file = os.path.join(model_dir, "model.safetensors.index.json")
    with open(model_index_file, "r") as f:
        model_index = json.load(f)
    weight_map = model_index["weight_map"]

    # open every referenced safetensors shard once
    processed_files = {}
    for tensor_name in list(weight_map.keys()):
        if tensor_name not in weight_map:
            continue
        file_name = weight_map[tensor_name]
        if file_name in processed_files:
            continue
        processed_files[file_name] = safe_open(os.path.join(
            model_dir, file_name),
                                               "pt",
                                               device="cuda")

    with open(os.path.join(model_dir, "config.json"), 'r') as file:
        config = json.load(file)

    num_layer = config['num_hidden_layers']
    part_layer = (num_layer + args.parts - 1) // args.parts
    start_layer = args.rank * part_layer
    end_layer = min(num_layer, args.rank * part_layer + part_layer)

    def get_tensor(name):
        if name not in weight_map:
            return None
        ff = weight_map[name]
        safetensors_loader = processed_files[ff]
        return safetensors_loader.get_tensor(name).cuda()

    def get_file_name(layer):
        rank = layer // part_layer
        return "model-%05d-of-%05d.safetensors" % (rank, args.parts)

    new_safetensors = {}
    new_json = {}
    new_json['weight_map'] = {}
    new_json['metadata'] = {}
    for key in tqdm(list(weight_map.keys())):
        if "mlp.experts" in key and (key.endswith("weight")
                                     or key.endswith("weight_scale_inv")):
            if key.endswith("weight_scale_inv"):
                continue
            if args.rank == 0:
                layer = int(key.split(".")[2])
                new_json['weight_map'][key] = get_file_name(layer)
                new_json['weight_map'][key.replace(
                    "weight", "weight_scale_inv")] = get_file_name(layer)
            if int(key.split(".")[2]) < start_layer or int(
                    key.split(".")[2]) >= end_layer:
                continue
            fp8_tensor = get_tensor(key)
            fp8_scale = get_tensor(key.replace("weight", "weight_scale_inv"))
            quant_tensor, scale_tensor = quantize_fp8_block_scale_to_int4(
                fp8_tensor, fp8_scale)

            packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4
            packed_tensor = packer(quant_tensor.cpu().contiguous())
            new_safetensors.update({key: packed_tensor})
            new_safetensors.update({
                key.replace("weight", "weight_scale_inv"):
                scale_tensor.contiguous()
            })
        else:
            name = key.split(".")
            if args.rank == 0:
                if len(name) < 3 or not name[2].isdigit():
                    new_safetensors.update({key: get_tensor(key)})
                    new_json['weight_map'][key] = get_file_name(0)
                    continue

                file_name = get_file_name(int(name[2]))
                new_json['weight_map'][key] = file_name

            if len(name) < 3 or not name[2].isdigit() or (int(
                    name[2]) < start_layer or int(name[2]) >= end_layer):
                continue
            new_safetensors.update({key: get_tensor(key)})

    # Process activation scales for all ranks
    if os.path.isdir(args.act_scales):
        # Extract activation scales
        renamed_state_dict = load_and_preprocess_state_dict(
            modelopt_state_root=args.act_scales, world_size=8)
        scales = get_scales_from_amax(start_layer=start_layer,
                                      end_layer=end_layer,
                                      renamed_state_dict=renamed_state_dict)
        new_safetensors.update(scales)

    if args.rank == 0:
        if not os.path.isdir(args.act_scales):
            input_scales = safe_open(args.act_scales, "pt")
            for k in input_scales.keys():
                new_safetensors.update({k: input_scales.get_tensor(k)})
                new_json['weight_map'][k] = args.act_scales.split("/")[-1]

        file_name = get_file_name(start_layer)
        print(f'saving to {file_name}...')
        save_file(new_safetensors, os.path.join(output_dir, file_name))
        with open(os.path.join(output_dir, "model.safetensors.index.json"),
                  "w") as f:
            json.dump(new_json, f)

        names = [
            "configuration_deepseek.py", "generation_config.json",
            "modeling_deepseek.py", "tokenizer.json", "tokenizer_config.json"
        ]
        for name in names:
            shutil.copy(os.path.join(model_dir, name), output_dir)
        if os.path.isdir(args.act_scales):
            shutil.copytree(args.act_scales, output_dir, dirs_exist_ok=True)
        else:
            shutil.copy(args.act_scales, output_dir)

        # config.json
        del config['quantization_config']
        with open(os.path.join(output_dir, "config.json"), 'w') as file:
            json.dump(config, file, indent=4)

        # quant_cfg.json
        attn_names = ["fused_a", "q_b_proj", "kv_b_proj", "o_proj"]
        mlp_names = ["gate_up_proj", "down_proj"]
        fp8_block_scale = {"quant_algo": "FP8_BLOCK_SCALES"}
        w4a8_awq = {"quant_algo": "W4A8_AWQ"}
        quant_cfg = {}
        quant_cfg["quant_algo"] = "MIXED_PRECISION"
        quant_cfg["kv_cache_quant_algo"] = None
        quant_cfg["quantized_layers"] = {}
        # layers 0-2 use dense MLPs (kept FP8); later layers use routed experts (W4A8)
        for l in range(num_layer):
            prefix = f"model.layers.{l}"
            for n1 in attn_names:
                quant_cfg["quantized_layers"][
                    f"{prefix}.self_attn.{n1}"] = fp8_block_scale
            for n2 in mlp_names:
                quant_cfg["quantized_layers"][
                    f"{prefix}.mlp.shared_experts.{n2}"] = fp8_block_scale
            if l < 3:
                for n3 in mlp_names:
                    quant_cfg["quantized_layers"][
                        f"{prefix}.mlp.{n3}"] = fp8_block_scale
            else:
                quant_cfg["quantized_layers"][
                    f"{prefix}.mlp.experts"] = w4a8_awq
        with open(os.path.join(output_dir, "quant_cfg.json"), 'w') as file:
            json.dump(quant_cfg, file, indent=4)

        # hf_quant_config.json
        hf_quant_config = {}
        hf_quant_config['quantization'] = {}
        hf_quant_config['quantization']["quant_algo"] = "MIXED_PRECISION"
        hf_quant_config['quantization']["kv_cache_quant_algo"] = None
        with open(os.path.join(output_dir, "hf_quant_config.json"),
                  'w') as file:
            json.dump(hf_quant_config, file, indent=4)
    else:
        file_name = get_file_name(start_layer)
        print(f'saving to {file_name}...')
        save_file(new_safetensors, os.path.join(output_dir, file_name))

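
# Only rank 0 writes model.safetensors.index.json, copies the tokenizer/config
# files, and emits quant_cfg.json and hf_quant_config.json; every rank (rank 0
# included) saves the safetensors shard covering its own slice of layers.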

if __name__ == "__main__":
    args = parse_args()
    main(args)