feat: run mmlu and summarize without engine_dir. (#4056)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
yuxianq 2025-05-05 19:35:07 +08:00 committed by GitHub
parent aa980dc92f
commit 2cfcdbefee
3 changed files with 29 additions and 11 deletions


@@ -56,7 +56,7 @@ from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
                           AutoModelForSeq2SeqLM, AutoTokenizer,
                           GenerationConfig)
 from utils import (add_common_args, load_tokenizer, prepare_enc_dec_inputs,
-                   read_model_name)
+                   read_is_enc_dec, read_model_name)
 
 import tensorrt_llm
 from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelRunner
@@ -399,15 +399,14 @@ def main():
     cat_cors = {cat: [] for cat in get_categories()}
 
     # different handling if encoder-decoder models
-    is_enc_dec = {'encoder', 'decoder'}.issubset({
-        name
-        for name in os.listdir(args.engine_dir)
-        if os.path.isdir(os.path.join(args.engine_dir, name))
-    })
+    is_enc_dec = read_is_enc_dec(
+        args.engine_dir if not args.test_hf else args.hf_model_dir,
+        args.test_hf)
 
     model_name, model_version = read_model_name(
-        args.engine_dir if not is_enc_dec else os.path.
-        join(args.engine_dir, 'encoder'))
+        (args.engine_dir if not is_enc_dec else os.path.join(
+            args.engine_dir, 'encoder'))
+        if not args.test_hf else args.hf_model_dir, args.test_hf)
 
     tokenizer, pad_id, end_id = load_tokenizer(
         tokenizer_dir=args.tokenizer_dir,
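For reference, a minimal sketch (not part of the commit) of the directory-selection logic this hunk introduces: when args.test_hf is set, both the encoder/decoder probe and the model-name lookup read from args.hf_model_dir instead of args.engine_dir. The resolve_* helpers below are hypothetical names used only for illustration.

# Illustration only: mirrors the selection logic added above.
# resolve_* are hypothetical helpers, not part of the commit.
import os

def resolve_enc_dec_probe_dir(engine_dir, hf_model_dir, test_hf):
    # read_is_enc_dec() inspects this directory.
    return hf_model_dir if test_hf else engine_dir

def resolve_model_name_dir(engine_dir, hf_model_dir, test_hf, is_enc_dec):
    # Enc-dec engines keep their config under the 'encoder' subdirectory;
    # an HF checkout has a single top-level config.json.
    if test_hf:
        return hf_model_dir
    return os.path.join(engine_dir, 'encoder') if is_enc_dec else engine_dir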


@@ -47,7 +47,8 @@ def main(args):
     test_hf = args.test_hf and runtime_rank == 0  # only run hf on rank 0
     test_trt_llm = args.test_trt_llm
 
-    model_name, model_version = read_model_name(args.engine_dir)
+    model_name, model_version = read_model_name(
+        args.engine_dir if not test_hf else args.hf_model_dir, test_hf)
     if args.hf_model_dir is None:
         logger.warning(
             "hf_model_dir is not specified. Try to infer from model_name, but this may be incorrect."


@@ -82,12 +82,30 @@ def read_decoder_start_token_id(engine_dir):
     return config['pretrained_config']['decoder_start_token_id']
 
 
-def read_model_name(engine_dir: str):
-    engine_version = get_engine_version(engine_dir)
+def read_is_enc_dec(engine_dir: str, is_hf: bool = False):
+    if is_hf:
+        with open(Path(engine_dir) / "config.json", 'r') as f:
+            config = json.load(f)
+        is_enc_dec = config.get('is_encoder_decoder', False)
+    else:
+        is_enc_dec = {'encoder', 'decoder'}.issubset({
+            name
+            for name in os.listdir(engine_dir)
+            if os.path.isdir(os.path.join(engine_dir, name))
+        })
+    return is_enc_dec
+
+
+def read_model_name(engine_dir: str, is_hf: bool = False):
     with open(Path(engine_dir) / "config.json", 'r') as f:
         config = json.load(f)
 
+    if is_hf:
+        model_arch = config['architectures'][0]
+        model_version = config.get('model_type', None)
+        return model_arch, model_version
+
+    engine_version = get_engine_version(engine_dir)
     if engine_version is None:
         return config['builder_config']['name'], None
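The new utils helpers can also be exercised on their own, with no TensorRT engine directory at all. A hedged usage sketch, assuming an HF model directory that contains config.json (the path below is a placeholder, not taken from the diff):

# Sketch of the call pattern mmlu.py and summarize.py now use with --test_hf.
from utils import read_is_enc_dec, read_model_name

hf_model_dir = "/path/to/hf_model"  # placeholder

# With is_hf=True both helpers read the HF config.json,
# so no engine_dir is required.
is_enc_dec = read_is_enc_dec(hf_model_dir, is_hf=True)
model_name, model_version = read_model_name(hf_model_dir, is_hf=True)
print(model_name, model_version, is_enc_dec)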