# Source: TensorRT-LLMs/tensorrt_llm/commands/refit.py
# Retrieved: 2025-10-28 09:17:26 -07:00 (181 lines, 7.0 KiB, Python)
'''
Script that refits TRT-LLM engine(s) with weights in a TRT-LLM checkpoint.
'''
import argparse
import json
import os
import re
import shutil
import time
from pathlib import Path
import tensorrt as trt
from tensorrt_llm._common import _is_building
from tensorrt_llm._utils import np_dtype_to_trt
from tensorrt_llm.builder import EngineConfig, optimize_model_with_config
from tensorrt_llm.models import MODEL_MAP, PretrainedConfig
from ..logger import logger
ENGINE_RE = re.compile('rank(\d+).engine')
def _elapsed_str(tik: float, tok: float) -> str:
    """Format the wall-clock span between *tik* and *tok* as HH:MM:SS."""
    return time.strftime('%H:%M:%S', time.gmtime(tok - tik))


@_is_building
def refit_engine(engine_path: str, refit_engine_dir: str, checkpoint_dir: str,
                 engine_config: EngineConfig, fixed_weights_names: list):
    """Refit one rank's serialized TRT engine with weights from a checkpoint.

    Loops through all weights in the model and does a textual match between
    checkpoint weight names and engine weight names; every refittable engine
    weight must be supplied by the checkpoint.

    Args:
        engine_path: Path to a per-rank engine file named ``rank<N>.engine``.
        refit_engine_dir: Directory where the refitted engine is written
            (same basename as ``engine_path``).
        checkpoint_dir: TRT-LLM checkpoint directory holding the new weights.
        engine_config: Parsed ``config.json`` of the engine directory.
        fixed_weights_names: Weight names that are hardcoded in the model
            (not read from the checkpoint) and therefore expected to be
            non-refittable; these are excluded from the warning below.

    Raises:
        RuntimeError: If any refittable engine weight is missing from the
            checkpoint-loaded model.
        AssertionError: If a weight is uninitialized or a refit call fails.
    """
    # Rank is encoded in the filename, e.g. "rank3.engine" -> 3.
    rank = int(ENGINE_RE.fullmatch(os.path.basename(engine_path)).group(1))

    # Deserialize the engine from disk.
    tik = time.time()
    with open(engine_path,
              "rb") as f, trt.Runtime(logger.trt_logger) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
    logger.info(f'Load TRT engine time: {_elapsed_str(tik, time.time())}')

    refitter = trt.Refitter(engine, logger.trt_logger)
    refittable_weights = set(refitter.get_all_weights())

    # Load the model for this rank from the checkpoint.
    tik = time.time()
    rank_config = PretrainedConfig.from_dict(
        engine_config.pretrained_config.to_dict())
    rank_config.set_rank(rank)
    architecture = rank_config.architecture
    assert architecture in MODEL_MAP, \
        f"Unsupported model architecture: {architecture}"
    model_cls = MODEL_MAP[architecture]
    model = model_cls.from_checkpoint(checkpoint_dir, config=rank_config)
    logger.info(f'Load checkpoint(s) time: {_elapsed_str(tik, time.time())}')

    # Weights are preprocessed during model optimization, so replay the same
    # optimization the builder applied; deep-copy so the shared build config
    # is not mutated across ranks.
    tik = time.time()
    build_config = engine_config.build_config.model_copy(deep=True)
    optimize_model_with_config(model, build_config)
    logger.info(f'Preprocess weights time: {_elapsed_str(tik, time.time())}')

    # Refit engine: hand each matching named weight's host buffer to TRT.
    tik = time.time()
    refitted_weights = set()
    for name, buf in model.named_parameters():
        if name in refittable_weights:
            # NOTE(review): `is_inited` is read as an attribute here but
            # called as a method in main() -- confirm which is correct.
            assert buf.is_inited, f"Failed because weight `{name}` is not initialized in model."
            # Keep a reference to the numpy buffer; trt.Weights only stores
            # the raw pointer, so the array must stay alive until refit.
            weight = buf._value
            weights_value = trt.Weights(np_dtype_to_trt(weight.dtype),
                                        weight.ctypes.data, weight.size)
            assert refitter.set_named_weights(
                name, weights_value), f'Failed to refit weight: `{name}`'
            refitted_weights.add(name)
        else:
            if name not in fixed_weights_names:
                logger.warning(
                    f"model weights `{name}` (shape: {buf._value.shape}) is not refittable, this means that we might not be able to update the engine using fine-tuned checkpoint!"
                )
    # Validate all refittable weights are provided.
    if len(refitted_weights) != len(refittable_weights):
        raise RuntimeError(
            f'Missing refittable weights {refittable_weights.difference(refitted_weights)} from {checkpoint_dir}'
        )
    assert refitter.refit_cuda_engine(), 'Failed to refit engine.'
    logger.info(
        f'Execute GPU refit graph time: {_elapsed_str(tik, time.time())}')

    # Serialize the refitted engine, explicitly keeping weights inline.
    tik = time.time()
    refit_engine_path = os.path.join(refit_engine_dir,
                                     os.path.basename(engine_path))
    with open(refit_engine_path, 'wb') as f:
        logger.info(f'\nWriting refitted engine to `{refit_engine_path}`')
        s_config = engine.create_serialization_config()
        # Clear EXCLUDE_WEIGHTS so the new weights are embedded in the plan.
        s_config.flags &= ~(1 << int(trt.SerializationFlag.EXCLUDE_WEIGHTS))
        f.write(engine.serialize_with_config(s_config))
    logger.info(
        f'Write TRT engine to disk time: {_elapsed_str(tik, time.time())}')

    del refitter
def refit(engine_dir: str, checkpoint_dir: str, engine_config: EngineConfig,
          output_dir: str, fixed_weights_names: list):
    """Refit every rank engine in *engine_dir*, writing results to *output_dir*.

    The engine config is copied unchanged alongside the refitted engines.
    """
    os.makedirs(output_dir, exist_ok=True)
    shutil.copyfile(os.path.join(engine_dir, 'config.json'),
                    os.path.join(output_dir, 'config.json'))
    for engine_path in Path(engine_dir).glob('*.engine'):
        refit_engine(engine_path, output_dir, checkpoint_dir, engine_config,
                     fixed_weights_names)
def main():
    """CLI entry point: validate arguments and refit all engines in a directory."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--engine_dir',
        type=str,
        default=None,
        help=
        'Path to trt-llm engines. These engines must have been built from a pruned checkpoint, or otherwise be refittable.'
    )
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default=None,
                        help='Path to checkpoint containing desired weights')
    parser.add_argument('--output_dir',
                        type=str,
                        default=None,
                        help="Output path of the refit model")
    parser.add_argument('--log_level', type=str, default='info')
    args = parser.parse_args()
    logger.set_level(args.log_level)

    # Fail fast with a clear message on missing/invalid paths.
    if args.engine_dir is None or not Path(args.engine_dir).exists():
        raise RuntimeError(
            f'Please supply a valid --engine_dir (found `{args.engine_dir}`)')
    if args.checkpoint_dir is None or not Path(args.checkpoint_dir).exists():
        raise RuntimeError(
            f'Please supply a valid --checkpoint_dir (found `{args.checkpoint_dir}`)'
        )
    # Without this check, a missing --output_dir would surface later as an
    # opaque TypeError inside os.path.normpath(None).
    if args.output_dir is None:
        raise RuntimeError(
            f'Please supply a valid --output_dir (found `{args.output_dir}`)')

    engine_config = EngineConfig.from_json_file(
        os.path.join(args.engine_dir, 'config.json'))
    with open(os.path.join(args.checkpoint_dir, 'config.json'), 'r') as f:
        checkpoint_config = json.load(f)

    # The checkpoint must target the same architecture the engine was built for.
    engine_arch = engine_config.pretrained_config.architecture
    checkpoint_arch = checkpoint_config['architecture']
    if engine_arch != checkpoint_arch:
        raise RuntimeError(
            f'Engine Architecture and Checkpoint Architecture do not match. ' +
            f'Engine Architecture: `{engine_arch}`, Checkpoint Architecture: `{checkpoint_arch}`'
        )

    # The fixed weights are not read from checkpoint, they are hardcoded buffer
    # from the model itself. These values remain constant across different
    # fine-tuned checkpoints.
    fixed_wts_in_model = []
    model_cls = MODEL_MAP[engine_arch]
    model = model_cls.from_config(engine_config.pretrained_config)
    for name, param in model.named_parameters():
        # NOTE(review): `is_inited` is called as a method here but read as an
        # attribute in refit_engine() -- confirm which is correct.
        if param.is_inited():
            fixed_wts_in_model.append(name)

    refit(engine_dir=os.path.normpath(args.engine_dir),
          checkpoint_dir=os.path.normpath(args.checkpoint_dir),
          engine_config=engine_config,
          output_dir=os.path.normpath(args.output_dir),
          fixed_weights_names=fixed_wts_in_model)
# Script entry point: refit TRT-LLM engines with weights from a checkpoint.
if __name__ == '__main__':
    main()