TensorRT-LLM/tests/hlapi/run_llm.py
Kaiyu Xie f430a4b447
Update TensorRT-LLM (#1688)
* Update TensorRT-LLM

---------

Co-authored-by: IbrahimAmin <ibrahimamin532@gmail.com>
Co-authored-by: Fabian Joswig <fjosw@users.noreply.github.com>
Co-authored-by: Pzzzzz <hello-cd.plus@hotmail.com>
Co-authored-by: CoderHam <hemant@cohere.com>
Co-authored-by: Konstantin Lopuhin <kostia.lopuhin@gmail.com>
2024-05-28 20:07:49 +08:00

31 lines
819 B
Python

#!/usr/bin/env python3
import os
from typing import Optional

import click

from tensorrt_llm.hlapi import LLM, ModelConfig, SamplingConfig
@click.command()
@click.option("--model_dir", type=str, required=True)
@click.option("--tp_size", type=int, required=True)
@click.option("--engine_dir", type=str, default=None)
def main(model_dir: str, tp_size: int, engine_dir: Optional[str]):
    """Build (or load) an LLM from *model_dir* and run a short smoke-test generation.

    Args:
        model_dir: Path to the model (or prebuilt engine) directory.
        tp_size: Tensor-parallel size to configure for the model.
        engine_dir: Optional directory to save the built engine to.
            Skipped when it resolves to the same path as ``model_dir``.
    """
    config = ModelConfig(model_dir)
    config.parallel_config.tp_size = tp_size

    llm = LLM(config)

    # Only save when a distinct target was given — avoid dumping the
    # engine on top of the source model directory.
    if engine_dir is not None and os.path.abspath(
            engine_dir) != os.path.abspath(model_dir):
        llm.save(engine_dir)

    # A tiny hard-coded token-id prompt is enough for a smoke test.
    prompt = [45, 12, 13]
    sampling_config = SamplingConfig(max_new_tokens=10, end_id=-1)
    for output in llm.generate([prompt], sampling_config=sampling_config):
        print(output)
# Run the CLI only when executed as a script; importing this module is a no-op.
if __name__ == "__main__":
    main()