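"""Example: generate text with the TensorRT-LLM PyTorch backend via the LLM API.

The script builds an LLM from command-line arguments (build limits, parallel
mapping, KV-cache policy, optional speculative decoding), runs a batch of
prompts, and prints the generated text.

Illustrative invocation (script name and model path are placeholders, not
values defined in this file):

    python example.py --model_dir /path/to/hf_checkpoint --tp_size 2 \
        --use_cuda_graph --spec_decode_algo MTP --spec_decode_nextn 2
"""
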
import argparse

from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.llmapi import (EagleDecodingConfig, KvCacheConfig,
                                 MTPDecodingConfig)

example_prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


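# All runtime configuration for this example is exposed as CLI flags; see
# setup_llm() for how each flag group is consumed.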
def add_llm_args(parser):
    parser.add_argument('--model_dir',
                        type=str,
                        required=True,
                        help="Model checkpoint directory.")
    parser.add_argument("--prompt",
                        type=str,
                        nargs="+",
                        help="A single or a list of text prompts.")

    # Build config
    parser.add_argument("--max_seq_len",
                        type=int,
                        default=None,
                        help="The maximum sequence length.")
    parser.add_argument("--max_batch_size",
                        type=int,
                        default=2048,
                        help="The maximum batch size.")
    parser.add_argument(
        "--max_num_tokens",
        type=int,
        default=8192,
        help="The maximum total tokens (context + generation) "
        "across all sequences in a batch.")

    # Parallelism
    parser.add_argument('--attention_backend',
                        type=str,
                        default='TRTLLM',
                        choices=[
                            'VANILLA', 'TRTLLM', 'FLASHINFER',
                            'FLASHINFER_STAR_ATTENTION'
                        ])
    parser.add_argument('--enable_attention_dp',
                        default=False,
                        action='store_true')
    parser.add_argument('--tp_size', type=int, default=1)
    parser.add_argument('--pp_size', type=int, default=1)
    parser.add_argument('--moe_ep_size', type=int, default=-1)
    parser.add_argument('--moe_tp_size', type=int, default=-1)

    # KV cache
    parser.add_argument('--kv_cache_dtype', type=str, default='auto')
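    # Note: default=True with action='store_false', so passing this flag
    # disables KV-cache block reuse.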
    parser.add_argument('--kv_cache_enable_block_reuse',
                        default=True,
                        action='store_false')
    parser.add_argument("--kv_cache_fraction", type=float, default=None)

    # Runtime
    parser.add_argument('--enable_overlap_scheduler',
                        default=False,
                        action='store_true')
    parser.add_argument('--enable_chunked_prefill',
                        default=False,
                        action='store_true')
    parser.add_argument('--use_cuda_graph', default=False, action='store_true')
    parser.add_argument('--print_iter_log',
                        default=False,
                        action='store_true',
                        help='Print iteration logs during execution')

    # Sampling
    parser.add_argument("--max_tokens", type=int, default=64)
    parser.add_argument("--temperature", type=float, default=None)
    parser.add_argument("--top_k", type=int, default=None)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument('--load_format', type=str, default='auto')

    # Speculative decoding
    parser.add_argument('--spec_decode_algo', type=str, default=None)
    parser.add_argument('--spec_decode_nextn', type=int, default=1)
    parser.add_argument('--eagle_model_dir', type=str, default=None)

    return parser


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Run LLM models with the PyTorch workflow.")
    parser = add_llm_args(parser)
    args = parser.parse_args()
    return args


def setup_llm(args):
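    # PyTorchConfig collects the PyTorch-backend runtime knobs: overlap
    # scheduler, KV-cache dtype, attention backend, CUDA graphs, weight
    # loading format, and iteration logging.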
    pytorch_config = PyTorchConfig(
        enable_overlap_scheduler=args.enable_overlap_scheduler,
        kv_cache_dtype=args.kv_cache_dtype,
        attn_backend=args.attention_backend,
        use_cuda_graph=args.use_cuda_graph,
        load_format=args.load_format,
        print_iter_log=args.print_iter_log,
    )

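    # KvCacheConfig: free_gpu_memory_fraction bounds how much of the free GPU
    # memory the paged KV cache may claim; block reuse lets requests share
    # cached blocks for common prefixes.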
    kv_cache_config = KvCacheConfig(
        enable_block_reuse=args.kv_cache_enable_block_reuse,
        free_gpu_memory_fraction=args.kv_cache_fraction,
    )

    spec_decode_algo = (args.spec_decode_algo.upper()
                        if args.spec_decode_algo is not None else None)

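    # Speculative decoding: --spec_decode_nextn maps to num_nextn_predict_layers
    # for MTP and to max_draft_len for EAGLE3; any other value leaves
    # speculative decoding disabled.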
    if spec_decode_algo == 'MTP':
        spec_config = MTPDecodingConfig(
            num_nextn_predict_layers=args.spec_decode_nextn)
    elif spec_decode_algo == "EAGLE3":
        spec_config = EagleDecodingConfig(
            max_draft_len=args.spec_decode_nextn,
            pytorch_eagle_weights_path=args.eagle_model_dir)
    else:
        spec_config = None

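    # The LLM constructor is where everything comes together: build limits,
    # parallel mapping (TP/PP/MoE), the PyTorch backend config, KV-cache
    # policy, and the optional speculative-decoding config.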
    llm = LLM(model=args.model_dir,
              max_seq_len=args.max_seq_len,
              max_batch_size=args.max_batch_size,
              max_num_tokens=args.max_num_tokens,
              pytorch_backend_config=pytorch_config,
              kv_cache_config=kv_cache_config,
              tensor_parallel_size=args.tp_size,
              pipeline_parallel_size=args.pp_size,
              enable_attention_dp=args.enable_attention_dp,
              moe_expert_parallel_size=args.moe_ep_size,
              moe_tensor_parallel_size=args.moe_tp_size,
              enable_chunked_prefill=args.enable_chunked_prefill,
              speculative_config=spec_config)

    sampling_params = SamplingParams(
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
    )
    return llm, sampling_params


def main():
    args = parse_arguments()
    prompts = args.prompt if args.prompt else example_prompts

    llm, sampling_params = setup_llm(args)
    outputs = llm.generate(prompts, sampling_params)

    for i, output in enumerate(outputs):
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == '__main__':
    main()