mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
feat: run mmlu and summarize without engine_dir. (#4056)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
This commit is contained in:
parent
aa980dc92f
commit
2cfcdbefee
@ -56,7 +56,7 @@ from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
|
|||||||
AutoModelForSeq2SeqLM, AutoTokenizer,
|
AutoModelForSeq2SeqLM, AutoTokenizer,
|
||||||
GenerationConfig)
|
GenerationConfig)
|
||||||
from utils import (add_common_args, load_tokenizer, prepare_enc_dec_inputs,
|
from utils import (add_common_args, load_tokenizer, prepare_enc_dec_inputs,
|
||||||
read_model_name)
|
read_is_enc_dec, read_model_name)
|
||||||
|
|
||||||
import tensorrt_llm
|
import tensorrt_llm
|
||||||
from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelRunner
|
from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelRunner
|
||||||
@ -399,15 +399,14 @@ def main():
|
|||||||
cat_cors = {cat: [] for cat in get_categories()}
|
cat_cors = {cat: [] for cat in get_categories()}
|
||||||
|
|
||||||
# different handling if encoder-decoder models
|
# different handling if encoder-decoder models
|
||||||
is_enc_dec = {'encoder', 'decoder'}.issubset({
|
is_enc_dec = read_is_enc_dec(
|
||||||
name
|
args.engine_dir if not args.test_hf else args.hf_model_dir,
|
||||||
for name in os.listdir(args.engine_dir)
|
args.test_hf)
|
||||||
if os.path.isdir(os.path.join(args.engine_dir, name))
|
|
||||||
})
|
|
||||||
|
|
||||||
model_name, model_version = read_model_name(
|
model_name, model_version = read_model_name(
|
||||||
args.engine_dir if not is_enc_dec else os.path.
|
(args.engine_dir if not is_enc_dec else os.path.join(
|
||||||
join(args.engine_dir, 'encoder'))
|
args.engine_dir, 'encoder'))
|
||||||
|
if not args.test_hf else args.hf_model_dir, args.test_hf)
|
||||||
|
|
||||||
tokenizer, pad_id, end_id = load_tokenizer(
|
tokenizer, pad_id, end_id = load_tokenizer(
|
||||||
tokenizer_dir=args.tokenizer_dir,
|
tokenizer_dir=args.tokenizer_dir,
|
||||||
|
|||||||
@ -47,7 +47,8 @@ def main(args):
|
|||||||
|
|
||||||
test_hf = args.test_hf and runtime_rank == 0 # only run hf on rank 0
|
test_hf = args.test_hf and runtime_rank == 0 # only run hf on rank 0
|
||||||
test_trt_llm = args.test_trt_llm
|
test_trt_llm = args.test_trt_llm
|
||||||
model_name, model_version = read_model_name(args.engine_dir)
|
model_name, model_version = read_model_name(
|
||||||
|
args.engine_dir if not test_hf else args.hf_model_dir, test_hf)
|
||||||
if args.hf_model_dir is None:
|
if args.hf_model_dir is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"hf_model_dir is not specified. Try to infer from model_name, but this may be incorrect."
|
"hf_model_dir is not specified. Try to infer from model_name, but this may be incorrect."
|
||||||
|
|||||||
@ -82,12 +82,30 @@ def read_decoder_start_token_id(engine_dir):
|
|||||||
return config['pretrained_config']['decoder_start_token_id']
|
return config['pretrained_config']['decoder_start_token_id']
|
||||||
|
|
||||||
|
|
||||||
def read_model_name(engine_dir: str):
|
def read_is_enc_dec(engine_dir: str, is_hf: bool = False):
|
||||||
engine_version = get_engine_version(engine_dir)
|
if is_hf:
|
||||||
|
with open(Path(engine_dir) / "config.json", 'r') as f:
|
||||||
|
config = json.load(f)
|
||||||
|
is_enc_dec = config.get('is_encoder_decoder', False)
|
||||||
|
else:
|
||||||
|
is_enc_dec = {'encoder', 'decoder'}.issubset({
|
||||||
|
name
|
||||||
|
for name in os.listdir(engine_dir)
|
||||||
|
if os.path.isdir(os.path.join(engine_dir, name))
|
||||||
|
})
|
||||||
|
return is_enc_dec
|
||||||
|
|
||||||
|
|
||||||
|
def read_model_name(engine_dir: str, is_hf: bool = False):
|
||||||
with open(Path(engine_dir) / "config.json", 'r') as f:
|
with open(Path(engine_dir) / "config.json", 'r') as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
|
if is_hf:
|
||||||
|
model_arch = config['architectures'][0]
|
||||||
|
model_version = config.get('model_type', None)
|
||||||
|
return model_arch, model_version
|
||||||
|
|
||||||
|
engine_version = get_engine_version(engine_dir)
|
||||||
if engine_version is None:
|
if engine_version is None:
|
||||||
return config['builder_config']['name'], None
|
return config['builder_config']['name'], None
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user