# TensorRT-LLM/examples/redrafter/convert_checkpoint.py
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
import json
import os
import traceback
from argparse import Namespace
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, Optional

import safetensors
import safetensors.torch  # save_file lives in the torch submodule
import torch
from transformers.models.auto import AutoModel

import tensorrt_llm.models.modeling_utils
from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import PretrainedConfig
from tensorrt_llm.models.llama.convert import (get_weight, get_weight_and_bias,
                                               split)

BASE_MODEL_TLLM_WEIGHT_PREFIX = "base_model."
DRAFTER_TLLM_WEIGHT_PREFIX = "drafter."

# To add support for a new base model in ReDrafter:
# 1. Add the base model's tensorrt_llm class name mapping in `REDRAFTER_MAP` below
# 2. Create a new ReDrafter class in `tensorrt_llm/redrafter/models.py` by inheriting from `ReDrafterMixin`
# 3. Add the new ReDrafter class to the model registry in `tensorrt_llm/models/__init__.py`
#
# Example:
# REDRAFTER_MAP = {
# "QWenForCausalLM": "ReDrafterForQWenLM",
# "Qwen2ForCausalLM": "ReDrafterForQWenLM",
# "LlamaForCausalLM": "ReDrafterForLLaMALM"
# }
REDRAFTER_MAP = {
"QWenForCausalLM": "ReDrafterForQWenLM",
"Qwen2ForCausalLM": "ReDrafterForQWenLM",
"LlamaForCausalLM": "ReDrafterForLLaMALM"
}
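
# A minimal invocation sketch (paths and sizes here are illustrative
# placeholders, not shipped artifacts), exercising the flags defined in
# parse_arguments() below:
#
#   python convert_checkpoint.py \
#       --base_model_checkpoint_dir ./tllm_base_ckpt \
#       --drafter_model_dir ./drafter_hf \
#       --tp_size 2 \
#       --dtype float16 \
#       --output_dir ./tllm_redrafter_ckpt
#
# --base_model_checkpoint_dir points at an already-converted TensorRT LLM base
# checkpoint (config.json plus rank<i>.safetensors); --drafter_model_dir holds
# the drafter's config.json and model.safetensors (or model.pt).
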
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--base_model_checkpoint_dir",
type=str,
default=None,
required=True)
parser.add_argument("--drafter_model_dir",
type=str,
default=None,
required=True)
parser.add_argument("--tp_size",
type=int,
default=1,
help="N-way tensor parallelism size")
parser.add_argument("--dtype",
type=str,
default="float16",
choices=["float32", "bfloat16", "float16"])
parser.add_argument("--storage-type",
"-t",
type=str,
default="fp32",
choices=["fp32", "fp16"])
parser.add_argument("--load_model_on_cpu", action="store_true")
parser.add_argument(
"--use_parallel_embedding",
action="store_true",
default=False,
help="By default embedding parallelism is disabled.",
)
parser.add_argument(
"--embedding_sharding_dim",
type=int,
default=0,
choices=[0, 1],
help=
"By default the embedding lookup table is sharded along vocab dimension (=0). "
"To shard it along hidden dimension, set embedding_sharding_dim=1"
"Note: embedding sharing is only enabled when embedding_sharding_dim = 0",
)
parser.add_argument(
"--output_dir",
type=str,
default="tllm_checkpoint",
help="The path to save the TensorRT LLM checkpoint",
)
parser.add_argument(
"--workers",
type=int,
default=1,
help="The number of workers for converting checkpoint in parallel",
)
parser.add_argument(
"--dense_context_fmha",
default=False,
action="store_true",
help=
"Enable dense fmha in context phase, otherwise sliding window attention."
"If dense_context_fmha=False, the sliding window size is the max attention window size.",
)
parser.add_argument(
"--redrafter_draft_len_per_beam",
type=int,
default=5,
help=
"Number of times that the Recurrent Drafter runs the beam search to generate draft"
"candidates. Note that this draft_len does not include the first true/guaranteed token.",
)
parser.add_argument(
"--redrafter_num_beams",
type=int,
default=5,
help="Number of beam search candidates to keep during the Recurrent"
"Drafter beam search iterations.",
)
parser.add_argument(
"--redrafter_no_greedy_search",
action="store_false",
default=True,
dest="redrafter_greedy_search",
help=
"Whether Redrafter will use the token with the highest probability from lm_head"
"output or randomly sampled from the probability distribution.",
)
    return parser.parse_args()


def hf_drafter(
    hf_model: Namespace,  # TODO: DrafterModel, once ReDrafter is in Transformers
mapping: Mapping,
dtype: torch.dtype = torch.float32,
additional_tllm_prefix: str = "",
) -> Dict[str, torch.Tensor]:
"""
Possible tensor names for Drafter checkpoints:
input_proj.weight
input_proj.bias
lm_head.0.linear.weight
lm_head.0.linear.bias
lm_head.1.linear.weight
lm_head.1.linear.bias
lm_head.2.weight
rnn_u.weight
rnn_u.bias
rnn_w.weight
OR
input_projs.weight
input_projs.bias
lm_heads.0.linear.weight
lm_heads.0.linear.bias
lm_heads.1.linear.weight
lm_heads.1.linear.bias
lm_heads.2.weight
OR
0.0.linear.weight
0.0.linear.bias
0.1.linear.weight
0.1.linear.bias
0.2.weight
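
    Whichever variant is present, the converter emits TensorRT LLM weight
    names of the following shape (the "drafter." prefix is supplied by the
    caller via additional_tllm_prefix):
        input_proj.weight / input_proj.bias
        layers.<i>.linear.weight / layers.<i>.linear.bias
        lm_head.weight
        rnn_u.weight / rnn_u.bias / rnn_w.weight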
"""

    def get_weight_and_bias_with_multiple_possible_names(
            model_params, dtype, names_to_try, bias=True):
        # Drafter checkpoints come in several naming variants (see the
        # docstring above), so try each candidate name until one resolves.
        w, b = None, None
        for name in names_to_try:
            try:
                if bias:
                    w, b = get_weight_and_bias(model_params, name, dtype)
                else:
                    w = get_weight(model_params, name, dtype)
                break
            except Exception:
                continue
        if not bias:
            return w
        return w, b

weights = {}
# TODO: When ReDrafter is added to Transformers
# model_params = dict(hf_model.named_parameters())
model_params = dict(hf_model.named_parameters)
    # The input projection is only present when exit_dim != 2 * hidden_size.
    if hf_model.config.hidden_size * 2 != hf_model.config.exit_dim:
input_proj_weight, input_proj_bias = get_weight_and_bias_with_multiple_possible_names(
model_params, dtype, ["input_proj", "input_projs"])
weights[f"{additional_tllm_prefix}input_proj.weight"] = split(
input_proj_weight, mapping.tp_size, mapping.tp_rank, dim=0)
weights[f"{additional_tllm_prefix}input_proj.bias"] = split(
input_proj_bias, mapping.tp_size, mapping.tp_rank, dim=0)
for layer_idx in range(hf_model.config.num_draft_layers):
layer_weight, layer_bias = get_weight_and_bias_with_multiple_possible_names(
model_params, dtype, [
f"lm_head.{layer_idx}.linear", f"lm_heads.{layer_idx}.linear",
f"0.{layer_idx}.linear"
])
weights[
f"{additional_tllm_prefix}layers.{layer_idx}.linear.weight"] = split(
layer_weight, mapping.tp_size, mapping.tp_rank, dim=0)
weights[
f"{additional_tllm_prefix}layers.{layer_idx}.linear.bias"] = split(
layer_bias, mapping.tp_size, mapping.tp_rank, dim=0)
last_layer_weight = get_weight_and_bias_with_multiple_possible_names(
model_params,
dtype, [
f"lm_head.{hf_model.config.num_draft_layers}",
f"lm_heads.{hf_model.config.num_draft_layers}",
f"0.{hf_model.config.num_draft_layers}"
],
bias=False)
weights[f"{additional_tllm_prefix}lm_head.weight"] = split(
last_layer_weight, mapping.tp_size, mapping.tp_rank, dim=0)
if hf_model.config.rnn:
# rnn_u has both weight and bias
rnn_u_weight, rnn_u_bias = get_weight_and_bias(model_params, "rnn_u",
dtype)
weights[f"{additional_tllm_prefix}rnn_u.weight"] = split(
rnn_u_weight, mapping.tp_size, mapping.tp_rank, dim=0)
weights[f"{additional_tllm_prefix}rnn_u.bias"] = split(rnn_u_bias,
mapping.tp_size,
mapping.tp_rank,
dim=0)
# rnn_w only has weight
rnn_w_weight = get_weight(model_params, "rnn_w", dtype)
weights[f"{additional_tllm_prefix}rnn_w.weight"] = split(
rnn_w_weight, mapping.tp_size, mapping.tp_rank, dim=0)
    return weights


def hf_redrafter_config(
tllm_base_model_config: tensorrt_llm.models.modeling_utils.PretrainedConfig,
drafter_config: Namespace, # DrafterConfig
redrafter_num_beams: int,
redrafter_draft_len_per_beam: int,
redrafter_greedy_search: bool,
) -> tensorrt_llm.models.modeling_utils.PretrainedConfig:
tllm_config = copy.deepcopy(tllm_base_model_config)
tllm_config.base_model_architecture = tllm_config.architecture
tllm_config.architecture = REDRAFTER_MAP[tllm_config.architecture]
setattr(tllm_config, "redrafter_num_layers",
drafter_config.num_draft_layers)
setattr(tllm_config, "redrafter_hidden_size", drafter_config.hidden_size)
setattr(tllm_config, "redrafter_exit_dim", drafter_config.exit_dim)
setattr(tllm_config, "redrafter_is_rnn", drafter_config.rnn)
    # These three settings look like runtime parameters, but the TensorRT LLM
    # implementation requires them at engine build time because TensorRT must
    # unroll loops with a fixed number of iterations.
setattr(tllm_config, "redrafter_num_beams", redrafter_num_beams)
setattr(tllm_config, "redrafter_draft_len_per_beam",
redrafter_draft_len_per_beam)
setattr(tllm_config, "redrafter_greedy_search", redrafter_greedy_search)
    # Exclude the ReDrafter weights from any quantization.
if hasattr(tllm_config,
"quantization") and tllm_config.quantization is not None:
# If quantization is an object/namespace, handle it accordingly
if getattr(tllm_config.quantization, "exclude_modules",
None) is not None:
redrafter_exclude_modules = [
'drafter', 'drafter.layers', 'drafter.lm_head'
]
num_redrafter_layers = tllm_config.redrafter_num_layers
if tllm_config.redrafter_is_rnn:
redrafter_exclude_modules += ['drafter.rnn_u', 'drafter.rnn_w']
for lyrnum in range(num_redrafter_layers):
redrafter_exclude_modules += [
f'drafter.layers.{lyrnum}',
f'drafter.layers.{lyrnum}.linear'
]
tllm_config.quantization.exclude_modules += redrafter_exclude_modules
    return tllm_config


def convert_and_save(
rank: int,
tp_size: int,
base_model_checkpoint_dir: str,
hf_drafter_model: Optional[AutoModel],
dtype: str,
use_parallel_embedding: bool,
embedding_sharding_dim: int,
output_dir: str,
) -> None:
mapping = Mapping(
world_size=tp_size,
rank=rank,
tp_size=tp_size,
)
# Load and prepare weights
    state_dict_path = os.path.join(base_model_checkpoint_dir,
                                   f'rank{rank}.safetensors')
    with safetensors.safe_open(state_dict_path, framework="pt") as weights_safe:
        weights = {k: weights_safe.get_tensor(k) for k in weights_safe.keys()}
    if hf_drafter_model is not None:
        drafter_weights = hf_drafter(
            hf_drafter_model,
            mapping,
            dtype=str_dtype_to_torch(dtype),
            additional_tllm_prefix=DRAFTER_TLLM_WEIGHT_PREFIX,
        )
        weights.update(drafter_weights)
safetensors.torch.save_file(
        weights, os.path.join(output_dir, f"rank{rank}.safetensors"))


def multi_worker_convert_and_save(
workers: int,
tp_size: int,
base_model_checkpoint_dir: str,
hf_drafter_model: Optional[AutoModel],
dtype: str,
use_parallel_embedding: bool,
embedding_sharding_dim: int,
output_dir: str,
) -> None:
with ThreadPoolExecutor(max_workers=workers) as p:
futures = [
p.submit(
convert_and_save,
rank,
tp_size,
base_model_checkpoint_dir,
hf_drafter_model,
dtype,
use_parallel_embedding,
embedding_sharding_dim,
output_dir,
) for rank in range(tp_size)
]
exceptions = []
for future in as_completed(futures):
try:
future.result()
except Exception as e:
traceback.print_exc()
exceptions.append(e)
        assert not exceptions, (
            "Checkpoint conversion failed, please check the error log.")


def create_and_save_config(args):
mapping = Mapping(
world_size=args.tp_size,
tp_size=args.tp_size,
pp_size=1,
)
base_checkpoint_dir = args.base_model_checkpoint_dir
config_path = os.path.join(base_checkpoint_dir, 'config.json')
model_config = PretrainedConfig.from_json_file(config_path)
    tllm_model_config = copy.deepcopy(model_config)
    drafter_hf_config = None
    if args.drafter_model_dir:
# TODO: When ReDrafter is added to Transformers
# drafter_hf_config = AutoConfig.from_pretrained(args.drafter_model_dir)
with open(Path(args.drafter_model_dir, "config.json")) as fp:
drafter_hf_config = Namespace(**json.load(fp))
tllm_model_config = hf_redrafter_config(
tllm_base_model_config=tllm_model_config,
drafter_config=drafter_hf_config,
redrafter_num_beams=args.redrafter_num_beams,
redrafter_draft_len_per_beam=args.redrafter_draft_len_per_beam,
redrafter_greedy_search=args.redrafter_greedy_search,
)
with open(os.path.join(args.output_dir, "config.json"), "w") as f:
json.dump(tllm_model_config.to_dict(), f, indent=4)
    return drafter_hf_config


def main():
args = parse_arguments()
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
drafter_hf_config = create_and_save_config(args)
base_checkpoint_dir = args.base_model_checkpoint_dir
hf_drafter_model: Optional[AutoModel] = None
if args.drafter_model_dir:
# TODO: When ReDrafter is added to Transformers
# hf_drafter_model = AutoModel.from_pretrained(
# args.drafter_model_dir,
# torch_dtype="auto",
# )
        ckpt_file = Path(args.drafter_model_dir, "model.safetensors")
        if not ckpt_file.exists():
            ckpt_file = Path(args.drafter_model_dir, "model.pt")
        print(f"Loading drafter from {ckpt_file}")
        if ckpt_file.suffix == ".safetensors":
            drafter_ckpt = {}
            with safetensors.safe_open(ckpt_file, framework="pt",
                                       device="cpu") as f:
                for key in f.keys():
                    drafter_ckpt[key] = f.get_tensor(key)
        else:
            drafter_ckpt = torch.load(ckpt_file, map_location='cpu')
hf_drafter_model = Namespace(**{
"named_parameters": drafter_ckpt,
"config": drafter_hf_config
})
multi_worker_convert_and_save(
args.workers,
args.tp_size,
base_checkpoint_dir,
hf_drafter_model,
args.dtype,
args.use_parallel_embedding,
args.embedding_sharding_dim,
args.output_dir,
)
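    # At this point, output_dir should contain the converted checkpoint:
    #   config.json (written by create_and_save_config) plus one
    #   rank<i>.safetensors shard per tensor-parallel rank.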


if __name__ == "__main__":
main()