feat: run mmlu and summarize without engine_dir. (#4056)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
This commit is contained in:
yuxianq 2025-05-05 19:35:07 +08:00 committed by GitHub
parent aa980dc92f
commit 2cfcdbefee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 29 additions and 11 deletions

View File

@@ -56,7 +56,7 @@ from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
AutoModelForSeq2SeqLM, AutoTokenizer,
GenerationConfig)
from utils import (add_common_args, load_tokenizer, prepare_enc_dec_inputs,
read_model_name)
read_is_enc_dec, read_model_name)
import tensorrt_llm
from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelRunner
@@ -399,15 +399,14 @@ def main():
cat_cors = {cat: [] for cat in get_categories()}
# different handling if encoder-decoder models
is_enc_dec = {'encoder', 'decoder'}.issubset({
name
for name in os.listdir(args.engine_dir)
if os.path.isdir(os.path.join(args.engine_dir, name))
})
is_enc_dec = read_is_enc_dec(
args.engine_dir if not args.test_hf else args.hf_model_dir,
args.test_hf)
model_name, model_version = read_model_name(
args.engine_dir if not is_enc_dec else os.path.
join(args.engine_dir, 'encoder'))
(args.engine_dir if not is_enc_dec else os.path.join(
args.engine_dir, 'encoder'))
if not args.test_hf else args.hf_model_dir, args.test_hf)
tokenizer, pad_id, end_id = load_tokenizer(
tokenizer_dir=args.tokenizer_dir,

View File

@@ -47,7 +47,8 @@ def main(args):
test_hf = args.test_hf and runtime_rank == 0 # only run hf on rank 0
test_trt_llm = args.test_trt_llm
model_name, model_version = read_model_name(args.engine_dir)
model_name, model_version = read_model_name(
args.engine_dir if not test_hf else args.hf_model_dir, test_hf)
if args.hf_model_dir is None:
logger.warning(
"hf_model_dir is not specified. Try to infer from model_name, but this may be incorrect."

View File

@@ -82,12 +82,30 @@ def read_decoder_start_token_id(engine_dir):
return config['pretrained_config']['decoder_start_token_id']
def read_model_name(engine_dir: str):
engine_version = get_engine_version(engine_dir)
def read_is_enc_dec(engine_dir: str, is_hf: bool = False):
    """Determine whether the model at *engine_dir* is an encoder-decoder model.

    For a Hugging Face checkpoint (``is_hf=True``), the flag is read from the
    ``is_encoder_decoder`` field of ``config.json`` (defaulting to False when
    absent). For a TRT-LLM engine directory, it is inferred from the presence
    of both an ``encoder`` and a ``decoder`` subdirectory.
    """
    if is_hf:
        # HF checkpoints declare the flag explicitly in their config.
        with open(Path(engine_dir) / "config.json", 'r') as f:
            config = json.load(f)
        return config.get('is_encoder_decoder', False)
    # Engine dirs for enc-dec models are laid out as encoder/ + decoder/.
    subdirs = {
        entry
        for entry in os.listdir(engine_dir)
        if os.path.isdir(os.path.join(engine_dir, entry))
    }
    return {'encoder', 'decoder'} <= subdirs
def read_model_name(engine_dir: str, is_hf: bool = False):
with open(Path(engine_dir) / "config.json", 'r') as f:
config = json.load(f)
if is_hf:
model_arch = config['architectures'][0]
model_version = config.get('model_type', None)
return model_arch, model_version
engine_version = get_engine_version(engine_dir)
if engine_version is None:
return config['builder_config']['name'], None