mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-31 00:01:22 +08:00
Co-authored-by: DreamGenX <x@dreamgen.com> Co-authored-by: Ace-RR <78812427+Ace-RR@users.noreply.github.com> Co-authored-by: bprus <39293131+bprus@users.noreply.github.com> Co-authored-by: janpetrov <janpetrov@icloud.com>
39 lines
1.1 KiB
Python
39 lines
1.1 KiB
Python
#!/usr/bin/env python3
|
|
import os
from typing import Optional

import click

from tensorrt_llm.hlapi import LLM, ModelConfig, SamplingParams
|
|
|
|
|
|
@click.command()
@click.option("--model_dir", type=str, required=True)
@click.option("--tp_size", type=int, required=True)
@click.option("--engine_dir", type=str, default=None)
@click.option("--prompt", type=str, default=None)
def main(model_dir: str, tp_size: int, engine_dir: Optional[str],
         prompt: Optional[str]):
    """Build an LLM from a model directory and run a short generation.

    Args:
        model_dir: Path handed to ``ModelConfig``.
        tp_size: Tensor-parallel size written to
            ``config.parallel_config.tp_size``.
        engine_dir: Optional directory passed to ``llm.save``. Saving is
            skipped when it resolves to the same location as ``model_dir``.
        prompt: Optional text prompt generated after the fixed token-id
            prompt.
    """
    config = ModelConfig(model_dir)
    config.parallel_config.tp_size = tp_size

    llm = LLM(config)

    # Only save when the target is a different directory from the source.
    # realpath (unlike abspath) also resolves symlinks, so an alias of
    # model_dir is not mistaken for a separate destination.
    if engine_dir is not None and os.path.realpath(
            engine_dir) != os.path.realpath(model_dir):
        llm.save(engine_dir)

    # NOTE(review): end_id=-1 presumably matches no real token, so generation
    # runs for the full max_new_tokens; confirm against SamplingParams.
    sampling_params = SamplingParams(max_new_tokens=10, end_id=-1)

    # For the intentional-failure test, a simple fixed token-id prompt is
    # always generated first so the LLM starts up even without --prompt.
    prompt_token_ids = [45, 12, 13]
    for output in llm.generate([prompt_token_ids],
                               sampling_params=sampling_params):
        print(output)

    if prompt is not None:
        for output in llm.generate([prompt], sampling_params=sampling_params):
            print(output)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: the click command `main` reads its options from the
    # command line.
    main()
|