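"""Example script for the TensorRT-LLM LLM API (`tensorrt_llm.llmapi`) using the
PyTorch backend.

Builds an `LLM` instance from command-line arguments (parallelism, KV cache,
CUDA graphs, speculative decoding, sampling) and prints the generated text for
the supplied prompts, or for a few built-in example prompts.

Illustrative invocation (the script name and checkpoint path are placeholders):

    python <this_script>.py --model_dir <checkpoint_dir> --prompt "Hello, my name is"
"""
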
import argparse
import json
import time

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (AttentionDpConfig, AutoDecodingConfig,
                                 CudaGraphConfig, DraftTargetDecodingConfig,
                                 EagleDecodingConfig, KvCacheConfig, MoeConfig,
                                 MTPDecodingConfig, NGramDecodingConfig,
                                 TorchCompileConfig)

example_prompts = [
    "Hello, my name is",
    "The capital of France is",
    "The future of AI is",
]


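# Command-line options, grouped below as: model/checkpoint, build limits,
# parallelism, KV cache, runtime, sampling, speculative decoding, relaxed
# acceptance, and HF/output flags.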
def add_llm_args(parser):
    parser.add_argument('--model_dir',
                        type=str,
                        required=True,
                        help="Model checkpoint directory.")
    parser.add_argument("--prompt",
                        type=str,
                        nargs="+",
                        help="A single text prompt or a list of prompts.")
    parser.add_argument('--checkpoint_format',
                        type=str,
                        default=None,
                        choices=["HF", "mistral"],
                        help="Model checkpoint format.")

    # Build config
    parser.add_argument("--max_seq_len",
                        type=int,
                        default=None,
                        help="The maximum sequence length.")
    parser.add_argument("--max_batch_size",
                        type=int,
                        default=2048,
                        help="The maximum batch size.")
    parser.add_argument(
        "--max_num_tokens",
        type=int,
        default=8192,
        help=
        "The maximum total tokens (context + generation) across all sequences in a batch."
    )

    # Parallelism
    parser.add_argument('--attention_backend',
                        type=str,
                        default='TRTLLM',
                        choices=[
                            'VANILLA', 'TRTLLM', 'FLASHINFER',
                            'FLASHINFER_STAR_ATTENTION'
                        ])
    parser.add_argument('--moe_backend',
                        type=str,
                        default='CUTLASS',
                        choices=[
                            'CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP',
                            'DEEPGEMM', 'CUTEDSL', 'TRITON'
                        ])
    parser.add_argument('--enable_attention_dp',
                        default=False,
                        action='store_true')
    parser.add_argument('--attention_dp_enable_balance',
                        default=False,
                        action='store_true')
    parser.add_argument('--attention_dp_time_out_iters', type=int, default=0)
    parser.add_argument('--attention_dp_batching_wait_iters',
                        type=int,
                        default=0)
    parser.add_argument('--sampler_type',
                        default="auto",
                        choices=["auto", "TorchSampler", "TRTLLMSampler"])
    parser.add_argument('--tp_size', type=int, default=1)
    parser.add_argument('--pp_size', type=int, default=1)
    parser.add_argument('--orchestrator_type',
                        type=str,
                        default=None,
                        choices=[None, 'rpc', 'ray'],
                        help='Orchestrator type for multi-GPU execution')
    parser.add_argument('--moe_ep_size', type=int, default=-1)
    parser.add_argument('--moe_tp_size', type=int, default=-1)
    parser.add_argument('--moe_cluster_size', type=int, default=-1)
    parser.add_argument(
        '--use_low_precision_moe_combine',
        default=False,
        action='store_true',
        help='Use low precision combine in MoE (only for NVFP4 quantization)')

    # KV cache
    parser.add_argument('--kv_cache_dtype', type=str, default='auto')
    parser.add_argument('--disable_kv_cache_reuse',
                        default=False,
                        action='store_true')
    parser.add_argument("--tokens_per_block", type=int, default=32)
    parser.add_argument('--log_kv_cache_events',
                        default=False,
                        action='store_true')

    # Runtime
    parser.add_argument('--disable_overlap_scheduler',
                        default=False,
                        action='store_true')
    parser.add_argument('--enable_chunked_prefill',
                        default=False,
                        action='store_true')
    parser.add_argument('--use_cuda_graph', default=False, action='store_true')
    parser.add_argument('--cuda_graph_padding_enabled',
                        default=False,
                        action='store_true')
    parser.add_argument('--cuda_graph_batch_sizes',
                        nargs='+',
                        type=int,
                        default=None)
    parser.add_argument('--print_iter_log',
                        default=False,
                        action='store_true',
                        help='Print iteration logs during execution')
    parser.add_argument('--use_torch_compile',
                        default=False,
                        action='store_true',
                        help='Use torch.compile to optimize the model')
    parser.add_argument('--use_piecewise_cuda_graph',
                        default=False,
                        action='store_true',
                        help='Use piecewise CUDA graph to optimize the model')
    parser.add_argument('--apply_chat_template',
                        default=False,
                        action='store_true')

    # Sampling
    parser.add_argument("--max_tokens", type=int, default=64)
    parser.add_argument("--temperature", type=float, default=None)
    parser.add_argument("--top_k", type=int, default=None)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument('--load_format', type=str, default='auto')
    parser.add_argument('--n', type=int, default=1)
    parser.add_argument('--best_of', type=int, default=None)
    parser.add_argument('--max_beam_width', type=int, default=1)

    # Speculative decoding
    parser.add_argument('--spec_decode_algo', type=str, default=None)
    parser.add_argument('--spec_decode_max_draft_len', type=int, default=1)
    parser.add_argument('--draft_model_dir', type=str, default=None)
    parser.add_argument('--max_matching_ngram_size', type=int, default=5)
    parser.add_argument('--use_one_model', default=False, action='store_true')
    parser.add_argument('--eagle_choices', type=str, default=None)
    parser.add_argument('--use_dynamic_tree',
                        default=False,
                        action='store_true')
    parser.add_argument('--dynamic_tree_max_topK', type=int, default=None)
    parser.add_argument('--allow_advanced_sampling',
                        default=False,
                        action='store_true')
    parser.add_argument('--eagle3_model_arch',
                        type=str,
                        default="llama3",
                        choices=["llama3", "mistral_large3"],
                        help="The model architecture of the eagle3 model.")

    # Relaxed acceptance
    parser.add_argument('--use_relaxed_acceptance_for_thinking',
                        default=False,
                        action='store_true')
    parser.add_argument('--relaxed_topk', type=int, default=1)
    parser.add_argument('--relaxed_delta', type=float, default=0.)

    # HF
    parser.add_argument('--trust_remote_code',
                        default=False,
                        action='store_true')
    parser.add_argument('--return_context_logits',
                        default=False,
                        action='store_true')
    parser.add_argument('--return_generation_logits',
                        default=False,
                        action='store_true')
    parser.add_argument('--prompt_logprobs', default=False, action='store_true')
    parser.add_argument('--logprobs', default=False, action='store_true')

    parser.add_argument('--additional_model_outputs',
                        type=str,
                        default=None,
                        nargs='+')

    return parser


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Run LLM models with the PyTorch workflow.")
    parser = add_llm_args(parser)
    parser.add_argument("--kv_cache_fraction", type=float, default=0.9)
    args = parser.parse_args()
    return args


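# Construct the LLM engine and SamplingParams from the parsed arguments.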
def setup_llm(args, **kwargs):
    kv_cache_config = KvCacheConfig(
        enable_block_reuse=not args.disable_kv_cache_reuse,
        free_gpu_memory_fraction=args.kv_cache_fraction,
        dtype=args.kv_cache_dtype,
        tokens_per_block=args.tokens_per_block,
        event_buffer_max_size=1024 if args.log_kv_cache_events else 0)

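    # Map --spec_decode_algo (case-insensitive) to the corresponding speculative
    # decoding config; if no algorithm is given, speculation stays disabled.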
    spec_decode_algo = args.spec_decode_algo.upper(
    ) if args.spec_decode_algo is not None else None

    if spec_decode_algo == 'MTP':
        if not args.use_one_model:
            print("Running MTP Eagle in two-model style.")
        spec_config = MTPDecodingConfig(
            num_nextn_predict_layers=args.spec_decode_max_draft_len,
            use_relaxed_acceptance_for_thinking=args.
            use_relaxed_acceptance_for_thinking,
            relaxed_topk=args.relaxed_topk,
            relaxed_delta=args.relaxed_delta,
            mtp_eagle_one_model=args.use_one_model,
            speculative_model_dir=args.model_dir)
    elif spec_decode_algo == "EAGLE3":
        spec_config = EagleDecodingConfig(
            max_draft_len=args.spec_decode_max_draft_len,
            speculative_model_dir=args.draft_model_dir,
            eagle3_one_model=args.use_one_model,
            eagle_choices=args.eagle_choices,
            use_dynamic_tree=args.use_dynamic_tree,
            dynamic_tree_max_topK=args.dynamic_tree_max_topK,
            allow_advanced_sampling=args.allow_advanced_sampling,
            eagle3_model_arch=args.eagle3_model_arch)
    elif spec_decode_algo == "DRAFT_TARGET":
        spec_config = DraftTargetDecodingConfig(
            max_draft_len=args.spec_decode_max_draft_len,
            speculative_model_dir=args.draft_model_dir)
    elif spec_decode_algo == "NGRAM":
        spec_config = NGramDecodingConfig(
            max_draft_len=args.spec_decode_max_draft_len,
            max_matching_ngram_size=args.max_matching_ngram_size,
            is_keep_all=True,
            is_use_oldest=True,
            is_public_pool=True,
        )
    elif spec_decode_algo == "AUTO":
        spec_config = AutoDecodingConfig()
    else:
        spec_config = None

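    # CUDA graphs are captured only when --use_cuda_graph is set; batch sizes and
    # padding behavior come from the corresponding CLI flags.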
    cuda_graph_config = CudaGraphConfig(
        batch_sizes=args.cuda_graph_batch_sizes,
        enable_padding=args.cuda_graph_padding_enabled,
    ) if args.use_cuda_graph else None

    attention_dp_config = AttentionDpConfig(
        enable_balance=args.attention_dp_enable_balance,
        timeout_iters=args.attention_dp_time_out_iters,
        batching_wait_iters=args.attention_dp_batching_wait_iters,
    )

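    # Build the LLM with the PyTorch backend; any extra **kwargs from the caller
    # are forwarded to the LLM constructor unchanged.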
    llm = LLM(
        model=args.model_dir,
        backend='pytorch',
        checkpoint_format=args.checkpoint_format,
        disable_overlap_scheduler=args.disable_overlap_scheduler,
        kv_cache_config=kv_cache_config,
        attn_backend=args.attention_backend,
        cuda_graph_config=cuda_graph_config,
        load_format=args.load_format,
        print_iter_log=args.print_iter_log,
        enable_iter_perf_stats=args.print_iter_log,
        torch_compile_config=TorchCompileConfig(
            enable_fullgraph=args.use_torch_compile,
            enable_inductor=args.use_torch_compile,
            enable_piecewise_cuda_graph=args.use_piecewise_cuda_graph)
        if args.use_torch_compile else None,
        moe_config=MoeConfig(
            backend=args.moe_backend,
            use_low_precision_moe_combine=args.use_low_precision_moe_combine),
        sampler_type=args.sampler_type,
        max_seq_len=args.max_seq_len,
        max_batch_size=args.max_batch_size,
        max_num_tokens=args.max_num_tokens,
        enable_attention_dp=args.enable_attention_dp,
        attention_dp_config=attention_dp_config,
        tensor_parallel_size=args.tp_size,
        pipeline_parallel_size=args.pp_size,
        moe_expert_parallel_size=args.moe_ep_size,
        moe_tensor_parallel_size=args.moe_tp_size,
        moe_cluster_parallel_size=args.moe_cluster_size,
        enable_chunked_prefill=args.enable_chunked_prefill,
        speculative_config=spec_config,
        trust_remote_code=args.trust_remote_code,
        gather_generation_logits=args.return_generation_logits,
        max_beam_width=args.max_beam_width,
        orchestrator_type=args.orchestrator_type,
        **kwargs)

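    # Reconcile n / best_of with the beam width: under beam search, n defaults to
    # the beam width and best_of must not exceed it; in sampling mode, best_of
    # must be at least n.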
    use_beam_search = args.max_beam_width > 1
    best_of = args.best_of or args.n
    if use_beam_search:
        if args.n == 1 and args.best_of is None:
            # Default to returning one sequence per beam.
            args.n = args.max_beam_width
            best_of = args.max_beam_width
        assert best_of <= args.max_beam_width, f"beam width: {best_of}, should be less than or equal to max_beam_width: {args.max_beam_width}"

    assert best_of >= args.n, f"In sampling mode, best_of value: {best_of} should be greater than or equal to n: {args.n}"

    sampling_params = SamplingParams(
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        return_context_logits=args.return_context_logits,
        return_generation_logits=args.return_generation_logits,
        logprobs=args.logprobs,
        prompt_logprobs=args.prompt_logprobs,
        n=args.n,
        best_of=best_of,
        use_beam_search=use_beam_search,
        additional_model_outputs=args.additional_model_outputs)
    return llm, sampling_params


def main():
    args = parse_arguments()
    prompts = args.prompt if args.prompt else example_prompts

    llm, sampling_params = setup_llm(args)
    new_prompts = []
    if args.apply_chat_template:
        for prompt in prompts:
            messages = [{"role": "user", "content": f"{prompt}"}]
            new_prompts.append(
                llm.tokenizer.apply_chat_template(messages,
                                                  tokenize=False,
                                                  add_generation_prompt=True))
        prompts = new_prompts
    outputs = llm.generate(prompts, sampling_params)

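    # Print every generated sequence, along with any logits, logprobs, or
    # additional model outputs requested on the command line.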
    for i, output in enumerate(outputs):
        prompt = output.prompt
        for sequence_idx, sequence in enumerate(output.outputs):
            generated_text = sequence.text
            # Only print the sequence index when multiple sequences are
            # generated (beam search or n > 1).
            sequence_id_text = f"[{sequence_idx}]" if args.max_beam_width > 1 or args.n > 1 else ""
            print(
                f"[{i}]{sequence_id_text} Prompt: {prompt!r}, Generated text: {generated_text!r}"
            )
            if args.return_context_logits:
                print(
                    f"[{i}]{sequence_id_text} Context logits: {output.context_logits}"
                )
            if args.return_generation_logits:
                print(
                    f"[{i}]{sequence_id_text} Generation logits: {sequence.generation_logits}"
                )
            if args.prompt_logprobs:
                print(
                    f"[{i}]{sequence_id_text} Prompt logprobs: {sequence.prompt_logprobs}"
                )
            if args.logprobs:
                print(f"[{i}]{sequence_id_text} Logprobs: {sequence.logprobs}")

            if args.additional_model_outputs:
                for output_name in args.additional_model_outputs:
                    if sequence.additional_context_outputs:
                        print(
                            f"[{i}]{sequence_id_text} Context {output_name}: {sequence.additional_context_outputs[output_name]}"
                        )
                    print(
                        f"[{i}]{sequence_id_text} Generation {output_name}: {sequence.additional_generation_outputs[output_name]}"
                    )

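    # Optionally dump KV cache events as JSON (enabled via --log_kv_cache_events).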
    if args.log_kv_cache_events:
        time.sleep(1)  # Wait for events to be dispatched
        events = llm.get_kv_cache_events(5)
        print("=== KV_CACHE_EVENTS_START ===")
        print(json.dumps(events, indent=2))
        print("=== KV_CACHE_EVENTS_END ===")


if __name__ == '__main__':
    main()