# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
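"""Convert a ReDrafter (Recurrent Drafter) drafter checkpoint and merge it with
an existing TensorRT-LLM base-model checkpoint, producing a single TensorRT-LLM
checkpoint for speculative decoding.

The base model must already be in TensorRT-LLM checkpoint format (a
``config.json`` plus one ``rank<N>.safetensors`` file per rank); the drafter is
read from a directory containing ``config.json`` and either
``model.safetensors`` or ``model.pt``. Drafter tensors are stored under the
``drafter.`` prefix in the merged checkpoint.
"""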
import argparse
import copy
import json
import os
import traceback
from argparse import Namespace
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, Optional

import safetensors
import safetensors.torch  # needed for safetensors.torch.save_file below
import torch
from transformers.models.auto import AutoModel

import tensorrt_llm.models.modeling_utils
from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import PretrainedConfig
from tensorrt_llm.models.llama.convert import (get_weight, get_weight_and_bias,
                                               split)

BASE_MODEL_TLLM_WEIGHT_PREFIX = "base_model."
DRAFTER_TLLM_WEIGHT_PREFIX = "drafter."

# To add support for a new base model in ReDrafter:
# 1. Add the base model's tensorrt_llm class name mapping in `REDRAFTER_MAP` below
# 2. Create a new ReDrafter class in `tensorrt_llm/redrafter/models.py` by inheriting from `ReDrafterMixin`
# 3. Add the new ReDrafter class to the model registry in `tensorrt_llm/models/__init__.py`
#
# Example:
# REDRAFTER_MAP = {
#     "QWenForCausalLM": "ReDrafterForQWenLM",
#     "Qwen2ForCausalLM": "ReDrafterForQWenLM",
#     "LlamaForCausalLM": "ReDrafterForLLaMALM"
# }

REDRAFTER_MAP = {
    "QWenForCausalLM": "ReDrafterForQWenLM",
    "Qwen2ForCausalLM": "ReDrafterForQWenLM",
    "LlamaForCausalLM": "ReDrafterForLLaMALM"
}


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model_checkpoint_dir",
                        type=str,
                        default=None,
                        required=True)
    parser.add_argument("--drafter_model_dir",
                        type=str,
                        default=None,
                        required=True)

    parser.add_argument("--tp_size",
                        type=int,
                        default=1,
                        help="N-way tensor parallelism size")
    parser.add_argument("--dtype",
                        type=str,
                        default="float16",
                        choices=["float32", "bfloat16", "float16"])

    parser.add_argument("--storage-type",
                        "-t",
                        type=str,
                        default="fp32",
                        choices=["fp32", "fp16"])
    parser.add_argument("--load_model_on_cpu", action="store_true")
    parser.add_argument(
        "--use_parallel_embedding",
        action="store_true",
        default=False,
        help="By default embedding parallelism is disabled.",
    )
    parser.add_argument(
        "--embedding_sharding_dim",
        type=int,
        default=0,
        choices=[0, 1],
        help=
        "By default the embedding lookup table is sharded along the vocab dimension (=0). "
        "To shard it along the hidden dimension, set embedding_sharding_dim=1. "
        "Note: embedding sharing is only enabled when embedding_sharding_dim = 0.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="tllm_checkpoint",
        help="The path to save the TensorRT LLM checkpoint",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=1,
        help="The number of workers for converting the checkpoint in parallel",
    )
    parser.add_argument(
        "--dense_context_fmha",
        default=False,
        action="store_true",
        help=
        "Enable dense fmha in the context phase; otherwise sliding window attention is used. "
        "If dense_context_fmha=False, the sliding window size is the max attention window size.",
    )

    parser.add_argument(
        "--redrafter_draft_len_per_beam",
        type=int,
        default=5,
        help=
        "Number of times that the Recurrent Drafter runs the beam search to generate draft "
        "candidates. Note that this draft_len does not include the first true/guaranteed token.",
    )
    parser.add_argument(
        "--redrafter_num_beams",
        type=int,
        default=5,
        help="Number of beam search candidates to keep during the Recurrent "
        "Drafter beam search iterations.",
    )
    parser.add_argument(
        "--redrafter_no_greedy_search",
        action="store_false",
        default=True,
        dest="redrafter_greedy_search",
        help=
        "Whether ReDrafter uses the token with the highest probability from the lm_head "
        "output or randomly samples from the probability distribution.",
    )

    return parser.parse_args()


def hf_drafter(
    hf_model: Namespace,  # TODO: use DrafterModel once ReDrafter is added to Transformers
    mapping: Mapping,
    dtype: torch.dtype = torch.float32,
    additional_tllm_prefix: str = "",
) -> Dict[str, torch.Tensor]:
    """
    Possible tensor names for Drafter checkpoints:
        input_proj.weight
        input_proj.bias
        lm_head.0.linear.weight
        lm_head.0.linear.bias
        lm_head.1.linear.weight
        lm_head.1.linear.bias
        lm_head.2.weight
        rnn_u.weight
        rnn_u.bias
        rnn_w.weight

    OR

        input_projs.weight
        input_projs.bias
        lm_heads.0.linear.weight
        lm_heads.0.linear.bias
        lm_heads.1.linear.weight
        lm_heads.1.linear.bias
        lm_heads.2.weight

    OR

        0.0.linear.weight
        0.0.linear.bias
        0.1.linear.weight
        0.1.linear.bias
        0.2.weight
    """

    def get_weight_and_bias_with_multiple_possible_names(
            model_params, dtype, names_to_try, bias=True):
        # Drafter checkpoints use several naming schemes (see the docstring
        # above), so try each candidate name until one is found.
        w, b = None, None
        for name in names_to_try:
            try:
                if bias:
                    w, b = get_weight_and_bias(model_params, name, dtype)
                else:
                    w = get_weight(model_params, name, dtype)
                break
            except Exception:
                pass
        if not bias:
            return w
        return w, b

    weights = {}
    # TODO: When ReDrafter is added to Transformers, use:
    # model_params = dict(hf_model.named_parameters())
    model_params = dict(hf_model.named_parameters)

    if hf_model.config.hidden_size * 2 != hf_model.config.exit_dim:
        input_proj_weight, input_proj_bias = get_weight_and_bias_with_multiple_possible_names(
            model_params, dtype, ["input_proj", "input_projs"])
        weights[f"{additional_tllm_prefix}input_proj.weight"] = split(
            input_proj_weight, mapping.tp_size, mapping.tp_rank, dim=0)
        weights[f"{additional_tllm_prefix}input_proj.bias"] = split(
            input_proj_bias, mapping.tp_size, mapping.tp_rank, dim=0)

    for layer_idx in range(hf_model.config.num_draft_layers):
        layer_weight, layer_bias = get_weight_and_bias_with_multiple_possible_names(
            model_params, dtype, [
                f"lm_head.{layer_idx}.linear", f"lm_heads.{layer_idx}.linear",
                f"0.{layer_idx}.linear"
            ])
        weights[
            f"{additional_tllm_prefix}layers.{layer_idx}.linear.weight"] = split(
                layer_weight, mapping.tp_size, mapping.tp_rank, dim=0)
        weights[
            f"{additional_tllm_prefix}layers.{layer_idx}.linear.bias"] = split(
                layer_bias, mapping.tp_size, mapping.tp_rank, dim=0)

    last_layer_weight = get_weight_and_bias_with_multiple_possible_names(
        model_params,
        dtype, [
            f"lm_head.{hf_model.config.num_draft_layers}",
            f"lm_heads.{hf_model.config.num_draft_layers}",
            f"0.{hf_model.config.num_draft_layers}"
        ],
        bias=False)
    weights[f"{additional_tllm_prefix}lm_head.weight"] = split(
        last_layer_weight, mapping.tp_size, mapping.tp_rank, dim=0)

    if hf_model.config.rnn:
        # rnn_u has both weight and bias
        rnn_u_weight, rnn_u_bias = get_weight_and_bias(model_params, "rnn_u",
                                                       dtype)
        weights[f"{additional_tllm_prefix}rnn_u.weight"] = split(
            rnn_u_weight, mapping.tp_size, mapping.tp_rank, dim=0)
        weights[f"{additional_tllm_prefix}rnn_u.bias"] = split(rnn_u_bias,
                                                               mapping.tp_size,
                                                               mapping.tp_rank,
                                                               dim=0)

        # rnn_w only has a weight
        rnn_w_weight = get_weight(model_params, "rnn_w", dtype)
        weights[f"{additional_tllm_prefix}rnn_w.weight"] = split(
            rnn_w_weight, mapping.tp_size, mapping.tp_rank, dim=0)

    return weights


def hf_redrafter_config(
    tllm_base_model_config: tensorrt_llm.models.modeling_utils.PretrainedConfig,
    drafter_config: Namespace,  # TODO: DrafterConfig
    redrafter_num_beams: int,
    redrafter_draft_len_per_beam: int,
    redrafter_greedy_search: bool,
) -> tensorrt_llm.models.modeling_utils.PretrainedConfig:
    tllm_config = copy.deepcopy(tllm_base_model_config)

    tllm_config.base_model_architecture = tllm_config.architecture
    tllm_config.architecture = REDRAFTER_MAP[tllm_config.architecture]
    setattr(tllm_config, "redrafter_num_layers",
            drafter_config.num_draft_layers)
    setattr(tllm_config, "redrafter_hidden_size", drafter_config.hidden_size)
    setattr(tllm_config, "redrafter_exit_dim", drafter_config.exit_dim)
    setattr(tllm_config, "redrafter_is_rnn", drafter_config.rnn)

    # These three configs look like runtime parameters, but in the TensorRT-LLM
    # implementation they must be provided at engine build time because
    # TensorRT needs to unroll loops with a fixed number of iterations.
    setattr(tllm_config, "redrafter_num_beams", redrafter_num_beams)
    setattr(tllm_config, "redrafter_draft_len_per_beam",
            redrafter_draft_len_per_beam)
    setattr(tllm_config, "redrafter_greedy_search", redrafter_greedy_search)

    # Exclude the redrafter weights from any quantization
    if hasattr(tllm_config,
               "quantization") and tllm_config.quantization is not None:
        # Only extend exclude_modules if the base config already defines it.
        if getattr(tllm_config.quantization, "exclude_modules",
                   None) is not None:
            redrafter_exclude_modules = [
                'drafter', 'drafter.layers', 'drafter.lm_head'
            ]
            num_redrafter_layers = tllm_config.redrafter_num_layers
            if tllm_config.redrafter_is_rnn:
                redrafter_exclude_modules += ['drafter.rnn_u', 'drafter.rnn_w']
            for lyrnum in range(num_redrafter_layers):
                redrafter_exclude_modules += [
                    f'drafter.layers.{lyrnum}',
                    f'drafter.layers.{lyrnum}.linear'
                ]
            tllm_config.quantization.exclude_modules += redrafter_exclude_modules

    return tllm_config


def convert_and_save(
    rank: int,
    tp_size: int,
    base_model_checkpoint_dir: str,
    hf_drafter_model: Optional[AutoModel],
    dtype: str,
    use_parallel_embedding: bool,
    embedding_sharding_dim: int,
    output_dir: str,
) -> None:

    mapping = Mapping(
        world_size=tp_size,
        rank=rank,
        tp_size=tp_size,
    )

    # Load the pre-converted base-model weights for this rank.
    state_dict_path = os.path.join(base_model_checkpoint_dir,
                                   f'rank{rank}.safetensors')
    weights_safe = safetensors.safe_open(state_dict_path, framework="pt")
    weights = {k: weights_safe.get_tensor(k) for k in weights_safe.keys()}

    # Convert the drafter weights and merge them into the same rank file.
    if hf_drafter_model is not None:
        drafter_weights = hf_drafter(
            hf_drafter_model,
            mapping,
            dtype=str_dtype_to_torch(dtype),
            additional_tllm_prefix=DRAFTER_TLLM_WEIGHT_PREFIX,
        )
        weights.update(drafter_weights)

    safetensors.torch.save_file(
        weights, os.path.join(output_dir, f"rank{rank}.safetensors"))


def multi_worker_convert_and_save(
    workers: int,
    tp_size: int,
    base_model_checkpoint_dir: str,
    hf_drafter_model: Optional[AutoModel],
    dtype: str,
    use_parallel_embedding: bool,
    embedding_sharding_dim: int,
    output_dir: str,
) -> None:
    with ThreadPoolExecutor(max_workers=workers) as p:
        futures = [
            p.submit(
                convert_and_save,
                rank,
                tp_size,
                base_model_checkpoint_dir,
                hf_drafter_model,
                dtype,
                use_parallel_embedding,
                embedding_sharding_dim,
                output_dir,
            ) for rank in range(tp_size)
        ]
        exceptions = []
        for future in as_completed(futures):
            try:
                future.result()
            except Exception as e:
                traceback.print_exc()
                exceptions.append(e)
        assert len(
            exceptions
        ) == 0, "Checkpoint conversion failed, please check error log."


def create_and_save_config(args):
    mapping = Mapping(
        world_size=args.tp_size,
        tp_size=args.tp_size,
        pp_size=1,
    )

    base_checkpoint_dir = args.base_model_checkpoint_dir
    config_path = os.path.join(base_checkpoint_dir, 'config.json')
    model_config = PretrainedConfig.from_json_file(config_path)
    tllm_model_config = copy.deepcopy(model_config)

    drafter_hf_config = None
    if args.drafter_model_dir:
        # TODO: When ReDrafter is added to Transformers, use:
        # drafter_hf_config = AutoConfig.from_pretrained(args.drafter_model_dir)
        with open(Path(args.drafter_model_dir, "config.json")) as fp:
            drafter_hf_config = Namespace(**json.load(fp))
        tllm_model_config = hf_redrafter_config(
            tllm_base_model_config=tllm_model_config,
            drafter_config=drafter_hf_config,
            redrafter_num_beams=args.redrafter_num_beams,
            redrafter_draft_len_per_beam=args.redrafter_draft_len_per_beam,
            redrafter_greedy_search=args.redrafter_greedy_search,
        )

    with open(os.path.join(args.output_dir, "config.json"), "w") as f:
        json.dump(tllm_model_config.to_dict(), f, indent=4)
    return drafter_hf_config


def main():
    args = parse_arguments()
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    drafter_hf_config = create_and_save_config(args)
    base_checkpoint_dir = args.base_model_checkpoint_dir

    hf_drafter_model: Optional[AutoModel] = None
    if args.drafter_model_dir:
        # TODO: When ReDrafter is added to Transformers, use:
        # hf_drafter_model = AutoModel.from_pretrained(
        #     args.drafter_model_dir,
        #     torch_dtype="auto",
        # )
        ckpt_file = Path(args.drafter_model_dir, "model.safetensors")
        if not ckpt_file.exists():
            ckpt_file = Path(args.drafter_model_dir, "model.pt")
        print(f"Loading drafter from {ckpt_file}")
        if str(ckpt_file).endswith(".safetensors"):
            drafter_ckpt = {}
            with safetensors.safe_open(ckpt_file, framework="pt",
                                       device="cpu") as f:
                for key in f.keys():
                    drafter_ckpt[key] = f.get_tensor(key)
        else:
            drafter_ckpt = torch.load(ckpt_file, map_location='cpu')
        # Wrap the raw state dict in a Namespace that mimics the interface
        # expected by hf_drafter (named_parameters + config).
        hf_drafter_model = Namespace(**{
            "named_parameters": drafter_ckpt,
            "config": drafter_hf_config
        })

    multi_worker_convert_and_save(
        args.workers,
        args.tp_size,
        base_checkpoint_dir,
        hf_drafter_model,
        args.dtype,
        args.use_parallel_embedding,
        args.embedding_sharding_dim,
        args.output_dir,
    )


if __name__ == "__main__":
    main()
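# Illustrative invocation (a sketch, not part of the original script): the
# script filename and all paths below are placeholders; only the flags are
# defined by parse_arguments() above.
#
#   python convert_checkpoint.py \
#       --base_model_checkpoint_dir ./base_model_tllm_ckpt \
#       --drafter_model_dir ./redrafter_drafter_hf \
#       --dtype float16 \
#       --tp_size 2 \
#       --workers 2 \
#       --output_dir ./redrafter_tllm_ckpt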