Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
feat: run mmlu and summarize without engine_dir. (#4056)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
parent: aa980dc92f
commit: 2cfcdbefee
@@ -56,7 +56,7 @@ from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
                           AutoModelForSeq2SeqLM, AutoTokenizer,
                           GenerationConfig)
 from utils import (add_common_args, load_tokenizer, prepare_enc_dec_inputs,
-                   read_model_name)
+                   read_is_enc_dec, read_model_name)
 
 import tensorrt_llm
 from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelRunner
@@ -399,15 +399,14 @@ def main():
     cat_cors = {cat: [] for cat in get_categories()}
 
     # different handling if encoder-decoder models
-    is_enc_dec = {'encoder', 'decoder'}.issubset({
-        name
-        for name in os.listdir(args.engine_dir)
-        if os.path.isdir(os.path.join(args.engine_dir, name))
-    })
+    is_enc_dec = read_is_enc_dec(
+        args.engine_dir if not args.test_hf else args.hf_model_dir,
+        args.test_hf)
 
     model_name, model_version = read_model_name(
-        args.engine_dir if not is_enc_dec else os.path.
-        join(args.engine_dir, 'encoder'))
+        (args.engine_dir if not is_enc_dec else os.path.join(
+            args.engine_dir, 'encoder'))
+        if not args.test_hf else args.hf_model_dir, args.test_hf)
 
     tokenizer, pad_id, end_id = load_tokenizer(
         tokenizer_dir=args.tokenizer_dir,
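For readability, the new directory selection in the MMLU script's main() can be unpacked as below. This is a minimal sketch of the selection logic only, not code from the commit; the helper name resolve_config_dir is hypothetical, and it assumes the same inputs the script already has (engine_dir, hf_model_dir, test_hf) plus the is_enc_dec flag computed just above.

import os


def resolve_config_dir(engine_dir: str, hf_model_dir: str, test_hf: bool,
                       is_enc_dec: bool) -> str:
    # When evaluating the HF model, read metadata from the HF checkpoint;
    # otherwise read it from the engine directory, descending into the
    # 'encoder' sub-engine for encoder-decoder builds.
    if test_hf:
        return hf_model_dir
    if is_enc_dec:
        return os.path.join(engine_dir, 'encoder')
    return engine_dir

With this helper, the call would read model_name, model_version = read_model_name(resolve_config_dir(...), args.test_hf), which is equivalent to the nested conditional added above.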
@@ -47,7 +47,8 @@ def main(args):
 
     test_hf = args.test_hf and runtime_rank == 0  # only run hf on rank 0
     test_trt_llm = args.test_trt_llm
-    model_name, model_version = read_model_name(args.engine_dir)
+    model_name, model_version = read_model_name(
+        args.engine_dir if not test_hf else args.hf_model_dir, test_hf)
     if args.hf_model_dir is None:
         logger.warning(
             "hf_model_dir is not specified. Try to infer from model_name, but this may be incorrect."
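To make the HF branch concrete: with is_hf=True, read_model_name (shown in the utils hunk below) reads the architecture and model_type directly from the checkpoint's config.json instead of the engine config, which is what lets the summarize script run without an engine directory. The following self-contained sketch mirrors only that branch on a throwaway config; the values are illustrative, not from the commit.

import json
import tempfile
from pathlib import Path


def read_model_name_hf(model_dir: str):
    # Stand-in for the is_hf=True branch of read_model_name.
    with open(Path(model_dir) / "config.json", 'r') as f:
        config = json.load(f)
    return config['architectures'][0], config.get('model_type', None)


with tempfile.TemporaryDirectory() as d:
    # Toy HF-style config.json; a real checkpoint has many more fields.
    (Path(d) / "config.json").write_text(
        json.dumps({"architectures": ["LlamaForCausalLM"],
                    "model_type": "llama"}))
    print(read_model_name_hf(d))  # ('LlamaForCausalLM', 'llama')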
@@ -82,12 +82,30 @@ def read_decoder_start_token_id(engine_dir):
     return config['pretrained_config']['decoder_start_token_id']
 
 
-def read_model_name(engine_dir: str):
-    engine_version = get_engine_version(engine_dir)
-
-    with open(Path(engine_dir) / "config.json", 'r') as f:
-        config = json.load(f)
-
+def read_is_enc_dec(engine_dir: str, is_hf: bool = False):
+    if is_hf:
+        with open(Path(engine_dir) / "config.json", 'r') as f:
+            config = json.load(f)
+        is_enc_dec = config.get('is_encoder_decoder', False)
+    else:
+        is_enc_dec = {'encoder', 'decoder'}.issubset({
+            name
+            for name in os.listdir(engine_dir)
+            if os.path.isdir(os.path.join(engine_dir, name))
+        })
+    return is_enc_dec
+
+
+def read_model_name(engine_dir: str, is_hf: bool = False):
+    with open(Path(engine_dir) / "config.json", 'r') as f:
+        config = json.load(f)
+
+    if is_hf:
+        model_arch = config['architectures'][0]
+        model_version = config.get('model_type', None)
+        return model_arch, model_version
+
+    engine_version = get_engine_version(engine_dir)
     if engine_version is None:
         return config['builder_config']['name'], None
 
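A similar sketch for read_is_enc_dec, written as a local mirror rather than an import from the examples' utils module: with is_hf=True it trusts the is_encoder_decoder field in config.json, and otherwise it checks for encoder/ and decoder/ sub-engine directories on disk. Directory contents and config values below are illustrative.

import json
import os
import tempfile
from pathlib import Path


def read_is_enc_dec(engine_dir: str, is_hf: bool = False) -> bool:
    # Local mirror of the new helper, for illustration only.
    if is_hf:
        with open(Path(engine_dir) / "config.json", 'r') as f:
            config = json.load(f)
        return config.get('is_encoder_decoder', False)
    return {'encoder', 'decoder'}.issubset({
        name
        for name in os.listdir(engine_dir)
        if os.path.isdir(os.path.join(engine_dir, name))
    })


with tempfile.TemporaryDirectory() as hf_dir, \
        tempfile.TemporaryDirectory() as engine_dir:
    # HF-style checkpoint: the flag lives in config.json (e.g. T5, BART).
    (Path(hf_dir) / "config.json").write_text(
        json.dumps({"is_encoder_decoder": True}))
    # Engine-style layout: encoder-decoder builds produce two sub-engines.
    os.makedirs(os.path.join(engine_dir, "encoder"))
    os.makedirs(os.path.join(engine_dir, "decoder"))
    print(read_is_enc_dec(hf_dir, is_hf=True))       # True
    print(read_is_enc_dec(engine_dir, is_hf=False))  # True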