# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import logging
import os
import pathlib
import tarfile
import typing

import torch
import yaml
from transformers import GPT2Config

from utils.convert import cpu_map_location, gpu_map_location

LOGGER = logging.getLogger(__name__)


def nemo_to_gpt_config(nemo_model_config, vocab_size, eos_id, bos_id):
    """Build a transformers GPT2Config from a NeMo GPT model config dict."""
    conversion_dict = {
        "activation_function": "activation",
        "layer_norm_epsilon": "layernorm_epsilon",
        "n_embd": "hidden_size",
        "n_head": "num_attention_heads",
        "n_layer": "num_layers",
        "n_positions": "max_position_embeddings",
        "rotary_pct": "rotary_percentage",
        "bias": "bias",
        "intermediate_size": "ffn_hidden_size",
    }

    # Only map the NeMo keys that are actually present in the config.
    kwargs = {
        key: nemo_model_config[value]
        for key, value in conversion_dict.items() if value in nemo_model_config
    }
    kwargs["vocab_size"] = vocab_size
    kwargs["eos_token_id"] = eos_id
    kwargs["bos_token_id"] = bos_id

    return GPT2Config(**kwargs)


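# Illustrative usage sketch (comments only, not executed); the config values
# below are made-up examples rather than taken from a real NeMo checkpoint:
#
#   nemo_cfg = {"hidden_size": 768, "num_attention_heads": 12,
#               "num_layers": 12, "max_position_embeddings": 2048,
#               "activation": "gelu", "layernorm_epsilon": 1e-5}
#   gpt2_cfg = nemo_to_gpt_config(nemo_cfg, vocab_size=50257,
#                                 eos_id=50256, bos_id=50256)
#   # gpt2_cfg.n_embd == 768, gpt2_cfg.n_layer == 12

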
def add_special_tokens_to_tokenizer(tokenizer):

    # Add cls, sep and mask tokens to the tokenizer if they are not already
    # attributes of it.
    if not hasattr(tokenizer, 'cls_token'):
        tokenizer.add_special_tokens({'cls_token': '<cls>'})
    if not hasattr(tokenizer.tokenizer, 'sep_id'):
        tokenizer.add_special_tokens({'sep_token': '<sep>'})
    if not hasattr(tokenizer.tokenizer, 'mask_id'):
        tokenizer.add_special_tokens({'mask_token': '<mask>'})

    # bos, eos, pad and unk may be present in the provided spm .model file;
    # if they are, use them.
    if not hasattr(tokenizer, 'pad_token'):
        if hasattr(tokenizer.tokenizer,
                   'pad_id') and tokenizer.tokenizer.pad_id() > 0:
            tokenizer.pad_token = tokenizer.tokenizer.id_to_piece(
                tokenizer.tokenizer.pad_id())
        else:
            tokenizer.add_special_tokens({'pad_token': '<pad>'})
    else:
        tokenizer.add_special_tokens({'pad_token': '<pad>'})

    if not hasattr(tokenizer, 'bos_token'):
        if hasattr(tokenizer.tokenizer,
                   'bos_id') and tokenizer.tokenizer.bos_id() > 0:
            tokenizer.bos_token = tokenizer.tokenizer.id_to_piece(
                tokenizer.tokenizer.bos_id())
        else:
            tokenizer.add_special_tokens({'bos_token': '<bos>'})
    else:
        tokenizer.add_special_tokens({'bos_token': '<s>'})

    if not hasattr(tokenizer, 'eos_token'):
        if hasattr(tokenizer.tokenizer,
                   'eos_id') and tokenizer.tokenizer.eos_id() > 0:
            tokenizer.eos_token = tokenizer.tokenizer.id_to_piece(
                tokenizer.tokenizer.eos_id())
        else:
            tokenizer.add_special_tokens({'eos_token': '<eos>'})
    else:
        tokenizer.add_special_tokens({'eos_token': '</s>'})


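# Illustrative usage sketch (comments only, not executed); it assumes
# `tokenizer` is a NeMo tokenizer wrapper whose `.tokenizer` attribute is a
# sentencepiece-style object exposing pad_id()/bos_id()/eos_id():
#
#   add_special_tokens_to_tokenizer(tokenizer)
#   # Afterwards tokenizer.pad_token, tokenizer.bos_token and
#   # tokenizer.eos_token are all defined.

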
def unpack_nemo_ckpt(
    nemo_archive_path: typing.Union[str, pathlib.Path],
    out_dir_path: typing.Union[str, pathlib.Path],
):
    """Extract a .nemo archive into out_dir_path and return that directory."""
    nemo_archive_path = pathlib.Path(nemo_archive_path)
    if not nemo_archive_path.exists():
        raise FileNotFoundError(f"{nemo_archive_path} does not exist")

    # A .nemo file is a tar archive that may or may not be gzip-compressed.
    for tar_mode in ["r:", "r:gz"]:
        try:
            with tarfile.open(nemo_archive_path, mode=tar_mode) as tar_file:

                def is_within_directory(directory, target):
                    abs_directory = os.path.abspath(directory)
                    abs_target = os.path.abspath(target)

                    prefix = os.path.commonprefix([abs_directory, abs_target])

                    return prefix == abs_directory

                def safe_members(tar_file):
                    # Guard against path-traversal entries in the archive.
                    members = []
                    for member in tar_file.getmembers():
                        member_path = os.path.join(out_dir_path, member.name)
                        if not is_within_directory(out_dir_path, member_path):
                            raise Exception(
                                "Attempted Path Traversal in Tar File")
                        members.append(member)
                    return members

                tar_file.extractall(out_dir_path,
                                    members=safe_members(tar_file),
                                    numeric_owner=False)

                return out_dir_path
        except tarfile.ReadError:
            pass

    raise RuntimeError(f"Could not unpack {nemo_archive_path}")


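# Illustrative usage sketch (comments only, not executed); the paths are
# placeholders:
#
#   unpacked = unpack_nemo_ckpt("megatron_gpt.nemo", "/tmp/unpacked_nemo")
#   # The output directory now holds model_config.yaml, tokenizer artifacts
#   # and the model weight checkpoints used by the helpers below.

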
def extract_layers_with_prefix(model_, prefix):
    """Return state-dict entries whose keys contain `prefix`, with the leading prefix stripped."""
    length_to_trim = len(prefix)
    model_state = model_.get("state_dict", model_)
    return {
        key[length_to_trim:]: model_state[key]
        for key in model_state.keys() if prefix in key
    }


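# Illustrative usage sketch (comments only, not executed); the key name is a
# toy example:
#
#   ckpt = {"state_dict": {"model.language_model.embedding.word_embeddings.weight": torch.zeros(1)}}
#   layers = extract_layers_with_prefix(ckpt, "model.language_model.")
#   # layers == {"embedding.word_embeddings.weight": tensor([0.])}

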
class UnpackedNemoCheckpointDir:
    """Convenience wrapper around a directory produced by unpack_nemo_ckpt."""

    def __init__(self,
                 checkpoints_dir: typing.Union[str, pathlib.Path],
                 load_checkpoints_to_cpu: bool = False):
        self._checkpoints_dir = pathlib.Path(checkpoints_dir)
        self._load_checkpoints_to_cpu = load_checkpoints_to_cpu

    @property
    @functools.lru_cache
    def model_config(self):
        model_config = None

        model_config_filename = "model_config.yaml"
        model_configs_paths = list(
            self._checkpoints_dir.rglob(model_config_filename))
        if model_configs_paths:
            if len(model_configs_paths) > 1:
                raise RuntimeError(
                    f"There is more than one {model_config_filename} "
                    f"in {self._checkpoints_dir}: {', '.join(map(lambda p: p.as_posix(), model_configs_paths))}"
                )
            model_config_path = model_configs_paths[0]
            LOGGER.debug("Loading model config from %s", model_config_path)
            with model_config_path.open("r") as model_config_file:
                model_config = yaml.load(model_config_file,
                                         Loader=yaml.SafeLoader)
        else:
            LOGGER.debug("Searching model config in checkpoints")
            # Try to obtain the config from a checkpoint file instead.
            checkpoint_name = self.checkpoint_name
            checkpoints_paths = sorted(
                self._checkpoints_dir.rglob(checkpoint_name))
            if checkpoints_paths:
                # Assume that the parallel-rank-0 checkpoint has the model
                # config embedded.
                checkpoint_path = checkpoints_paths[0]

                map_location_fn = cpu_map_location if self._load_checkpoints_to_cpu else gpu_map_location

                model_00 = torch.load(checkpoint_path,
                                      map_location=map_location_fn)
                if "hyper_parameters" in model_00 and "cfg" in model_00[
                        "hyper_parameters"]:
                    model_config = model_00["hyper_parameters"]["cfg"]
                    LOGGER.debug("Loaded model config from checkpoint %s",
                                 checkpoint_path)
                else:
                    LOGGER.debug("Could not find model config in checkpoint %s",
                                 checkpoint_path)

                del model_00

        if model_config is None:
            LOGGER.warning(
                "Could not find checkpoint with NeMo model config in %s",
                self._checkpoints_dir)

        LOGGER.debug("Loaded model config %s", model_config)

        return model_config

    @property
    def checkpoints_dir(self):
        return self._checkpoints_dir

    def get_checkpoints_paths(self,
                              tensor_model_parallel_size=1,
                              pipeline_model_parallel_size=1):
        """
        Injects tensor/pipeline model parallel ranks into the filepath.
        Does nothing if not using model parallelism.
        """

        checkpoint_path_without_rank = self.checkpoints_dir / self.checkpoint_name

        def _inject_parallel_ranks(tp_rank, pp_rank):
            if tensor_model_parallel_size > 1 or pipeline_model_parallel_size > 1:
                if pipeline_model_parallel_size is None or pipeline_model_parallel_size == 1:
                    checkpoint_path = (checkpoint_path_without_rank.parent /
                                       f"mp_rank_{tp_rank:02d}" /
                                       checkpoint_path_without_rank.name)
                else:
                    checkpoint_path = (
                        checkpoint_path_without_rank.parent /
                        f"tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}" /
                        checkpoint_path_without_rank.name)
                return checkpoint_path
            else:
                return checkpoint_path_without_rank

        return [[
            _inject_parallel_ranks(tp_rank=tp_rank, pp_rank=pp_rank)
            for pp_rank in range(pipeline_model_parallel_size)
        ] for tp_rank in range(tensor_model_parallel_size)]

    @property
    @functools.lru_cache
    def checkpoint_name(self):
        patterns = [
            "model_weights.ckpt",  # older megatron checkpoints
            "*last.ckpt",  # newer format of checkpoints
        ]
        for pattern in patterns:
            model_files = sorted(list(self._checkpoints_dir.rglob(pattern)))
            if model_files:
                return model_files[0].name

        raise ValueError(
            f"Could not find checkpoint files in {self._checkpoints_dir}")

    @functools.lru_cache
    def get_tokenizer_file_path(self, tokenizer_key, file_key,
                                default_filename_pattern):
        model_config = self.model_config
        file_property = None
        if tokenizer_key in model_config and file_key in model_config[
                tokenizer_key]:
            file_property = model_config[tokenizer_key][file_key]
        elif file_key in model_config:
            file_property = model_config[file_key]

        LOGGER.debug("model_config[%s][%s]=%s", tokenizer_key, file_key,
                     file_property)

        if file_property and file_property.startswith("nemo:"):
            filename = file_property.split("nemo:")[1]
            filename_pattern = f"*{filename}"
        elif file_property and file_property.startswith("/artifacts/"):
            filename = pathlib.Path(file_property).name
            filename_pattern = f"*{filename}"
        elif file_property is None or file_property == "None":
            filename_pattern = None
        else:
            filename_pattern = default_filename_pattern
            LOGGER.warning(
                f"Tokenizer file from config: {tokenizer_key}.{file_key}={file_property} "
                f"looks like an unsupported path. Pattern {filename_pattern} will be used."
            )

        file_path = None
        if filename_pattern is not None:
            files_paths = list(self._checkpoints_dir.glob(filename_pattern))
            if files_paths:
                assert len(files_paths) == 1
                file_path = files_paths[0]

        return file_path


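# Illustrative usage sketch (comments only, not executed); the parallel sizes,
# the tokenizer config keys and the "*.model" default pattern are assumptions
# for the example:
#
#   ckpt_dir = UnpackedNemoCheckpointDir("/tmp/unpacked_nemo",
#                                        load_checkpoints_to_cpu=True)
#   nemo_cfg = ckpt_dir.model_config
#   paths = ckpt_dir.get_checkpoints_paths(tensor_model_parallel_size=2,
#                                          pipeline_model_parallel_size=1)
#   # paths[tp_rank][pp_rank] -> .../mp_rank_00/model_weights.ckpt, ...
#   tokenizer_model = ckpt_dir.get_tokenizer_file_path(
#       "tokenizer", "model", "*.model")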