mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
* Update TensorRT-LLM --------- Co-authored-by: Timur Abishev <abishev.timur@gmail.com> Co-authored-by: MahmoudAshraf97 <hassouna97.ma@gmail.com> Co-authored-by: Saeyoon Oh <saeyoon.oh@furiosa.ai> Co-authored-by: hattizai <hattizai@gmail.com>
123 lines
3.8 KiB
Python
123 lines
3.8 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
'''
|
|
Script that prunes TRT-LLM checkpoints.
|
|
'''
|
|
import argparse
import json
import os
from pathlib import Path
from typing import Any, Dict

import safetensors
import torch
from safetensors.torch import save_file

from tensorrt_llm.logger import logger
from tensorrt_llm.models import MODEL_MAP, PretrainedConfig
|
|
|
|
# Names of every model architecture registered in MODEL_MAP.
SUPPORTED_MODELS = list(MODEL_MAP)

# Substrings of checkpoint tensor keys; a tensor whose key contains one of
# these is emptied by the pruner even when `--prune_all` is not given.
PRUNABLE_WEIGHTS = [
    'attention.qkv.weight',
    'attention.proj.weight',
    'mlp.fc.weight',
    'mlp.proj.weight',
    'mlp.gate.weight',
]
|
|
|
|
|
|
def can_prune(key: str) -> bool:
    '''Tell whether the tensor stored under `key` is a prunable weight.

    Args:
        key: Tensor name as stored in the checkpoint file.

    Returns:
        True if any entry of PRUNABLE_WEIGHTS occurs as a substring of `key`,
        False otherwise.
    '''
    return any(pattern in key for pattern in PRUNABLE_WEIGHTS)
|
|
|
|
|
|
def load_config(config_path: Path) -> Dict[str, Any]:
    '''Read a JSON config file.

    Args:
        config_path: Location of the JSON file to read.

    Returns:
        The parsed JSON content, or an empty dict when `config_path` does not
        exist.
    '''
    if not config_path.exists():
        return {}

    # JSON text is UTF-8; be explicit so the result does not depend on the
    # platform's locale encoding.
    with config_path.open('r', encoding='utf-8') as f:
        return json.load(f)
|
|
|
|
|
|
def prune_and_save(ckpt_dir: str, out_dir: str, prune_all: bool):
    '''Write a pruned copy of the checkpoint in `ckpt_dir` to `out_dir`.

    For each rank in `range(model_config.mapping.world_size)`, every tensor in
    `rank{rank}.safetensors` is either copied unchanged or replaced by an
    empty tensor of the same dtype, and the result is saved under the same
    file name in `out_dir`. Finally the checkpoint's `config.json` is written
    to `out_dir` with `is_pruned` set to True.

    Args:
        ckpt_dir: Directory containing `config.json` and the per-rank
            `rank{N}.safetensors` files.
        out_dir: Directory to write the pruned checkpoint to; created if it
            does not exist.
        prune_all: When True every tensor is emptied; otherwise only tensors
            for which `can_prune(key)` is true.

    Raises:
        RuntimeError: If the checkpoint's architecture is not in MODEL_MAP.
    '''
    logger.info(f'Checkpoint Dir: {ckpt_dir}, Out Dir: {out_dir}')
    model_config = PretrainedConfig.from_json_file(
        os.path.join(ckpt_dir, 'config.json'))

    architecture = model_config.architecture
    if architecture not in MODEL_MAP:
        raise RuntimeError(f'Unsupported model architecture: {architecture}')

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(out_dir, exist_ok=True)

    for rank in range(model_config.mapping.world_size):
        rank_file = f'rank{rank}.safetensors'
        pruned_weights = {}
        with safetensors.safe_open(os.path.join(ckpt_dir, rank_file),
                                   framework='pt',
                                   device='cpu') as f:
            for key in f.keys():
                tensor = f.get_tensor(key)
                if prune_all or can_prune(key):
                    # Keep the key and dtype but drop the data.
                    pruned_weights[key] = torch.tensor([], dtype=tensor.dtype)
                else:
                    pruned_weights[key] = tensor

        save_file(pruned_weights, os.path.join(out_dir, rank_file))

    # Read the source config BEFORE opening the destination for writing:
    # opening with 'w' truncates the file, so if `out_dir` is the checkpoint
    # directory itself the config would be emptied before it could be read.
    config = load_config(Path(ckpt_dir, 'config.json'))
    config['is_pruned'] = True
    with open(str(Path(out_dir, 'config.json')), 'w') as f:
        json.dump(config, f)
|
|
|
|
|
|
def main():
    '''Command-line entry point of the checkpoint pruner.

    Parses `--checkpoint_dir`, `--prune_all` and `--out_dir`, then calls
    `prune_and_save` with absolute paths. When `--out_dir` is omitted the
    output directory is a sibling of the checkpoint directory named
    `<checkpoint name>.pruned`.

    Raises:
        RuntimeError: If `--checkpoint_dir` is not supplied.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint_dir', type=str, default=None)
    parser.add_argument('--prune_all',
                        default=False,
                        action='store_true',
                        help='Remove all weights in the checkpoint')
    parser.add_argument(
        '--out_dir',
        type=str,
        default=None,
        help=
        'Path to write pruned checkpoint. Defaults to the same directory append with `.pruned`'
    )
    args = parser.parse_args()

    if args.checkpoint_dir is None:
        raise RuntimeError(
            "No `--checkpoint_dir` supplied to checkpoint pruner.")

    # Normalise first: a relative spelling such as `.` has an empty final
    # component (with_name() raises ValueError) and `some/dir/..` has `..` as
    # its final component, so the default output name must be derived from the
    # absolute, collapsed path.
    ckpt_dir = os.path.abspath(args.checkpoint_dir)

    if args.out_dir is None:
        ckpt_path = Path(ckpt_dir)
        out_dir = str(ckpt_path.with_name(ckpt_path.name + '.pruned'))
    else:
        out_dir = os.path.abspath(args.out_dir)

    prune_and_save(ckpt_dir, out_dir, args.prune_all)
|
|
|
|
|
|
# Run the command-line interface only when executed as a script, so that
# importing this module has no side effects.
if __name__ == '__main__':
    main()
|