# TensorRT-LLMs/examples/llm-api/quickstart_advanced.py
# Yi Zhang 0306c0f12c
# [TRTLLM-9766][feat] Integration of the KVCacheManager V2 to TRTLLM Runtime (#10659)
# Signed-off-by: yizhang-nv <187001205+yizhang-nv@users.noreply.github.com>
# 2026-02-02 14:29:02 +08:00
#
# 401 lines
# 17 KiB
# Python

import argparse
import json
import time
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (AttentionDpConfig, AutoDecodingConfig,
CudaGraphConfig, DraftTargetDecodingConfig,
Eagle3DecodingConfig, KvCacheConfig, MoeConfig,
MTPDecodingConfig, NGramDecodingConfig,
TorchCompileConfig)
# Fallback prompts used by main() when no --prompt is supplied on the CLI.
example_prompts = [
    "Hello, my name is",
    "The capital of France is",
    "The future of AI is",
]
def add_llm_args(parser):
    """Register the quickstart command-line options on ``parser``.

    Options are added in place, in the order listed below (build config,
    parallelism, KV cache, runtime, sampling, speculative decoding, relaxed
    acceptance, HF / extra outputs).

    Args:
        parser: An ``argparse.ArgumentParser`` (or any object exposing a
            compatible ``add_argument``) to extend.

    Returns:
        The same ``parser`` object, so calls can be chained.
    """
    add = parser.add_argument

    def flag(name, **extra):
        # Boolean switch: False unless the option appears on the command line.
        add(name, default=False, action='store_true', **extra)

    # Build config
    add('--model_dir', type=str, required=True,
        help="Model checkpoint directory.")
    add('--prompt', type=str, nargs='+',
        help="A single or a list of text prompts.")
    add('--checkpoint_format', type=str, default=None,
        choices=["HF", "mistral"], help="Model checkpoint format.")
    add('--max_seq_len', type=int, default=None,
        help="The maximum sequence length.")
    add('--max_batch_size', type=int, default=2048,
        help="The maximum batch size.")
    add('--max_num_tokens', type=int, default=8192,
        help=("The maximum total tokens (context + generation) "
              "across all sequences in a batch."))

    # Parallelism / backends
    add('--attention_backend', type=str, default='TRTLLM',
        choices=['VANILLA', 'TRTLLM', 'FLASHINFER',
                 'FLASHINFER_STAR_ATTENTION'])
    add('--moe_backend', type=str, default='AUTO',
        choices=['AUTO', 'CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP',
                 'DEEPGEMM', 'CUTEDSL', 'TRITON'],
        help=("MoE backend to use. AUTO selects default backend based on "
              "model. It currently doesn't always give the best choice for "
              "all scenarios. The capabilities of auto selection will be "
              "improved in future releases."))
    flag('--enable_attention_dp')
    flag('--attention_dp_enable_balance')
    add('--attention_dp_time_out_iters', type=int, default=0)
    add('--attention_dp_batching_wait_iters', type=int, default=0)
    add('--sampler_type', default="auto",
        choices=["auto", "TorchSampler", "TRTLLMSampler"])
    add('--tp_size', type=int, default=1)
    add('--pp_size', type=int, default=1)
    add('--orchestrator_type', type=str, default=None,
        choices=[None, 'rpc', 'ray'],
        help='Orchestrator type for multi-GPU execution')
    add('--moe_ep_size', type=int, default=-1)
    add('--moe_tp_size', type=int, default=-1)
    add('--moe_cluster_size', type=int, default=-1)
    flag('--use_low_precision_moe_combine',
         help='Use low precision combine in MoE (only for NVFP4 quantization)')

    # KV cache
    add('--kv_cache_dtype', type=str, default='auto')
    flag('--disable_kv_cache_reuse')
    add('--tokens_per_block', type=int, default=32)
    add('--mamba_ssm_cache_dtype', type=str, default='bfloat16',
        choices=['auto', 'float16', 'bfloat16', 'float32'],
        help='Data type for Mamba SSM cache.')
    flag('--log_kv_cache_events')
    flag('--use_kv_cache_manager_v2',
         help='Use KVCacheManagerV2 for KV cache management (PyTorch backend).')

    # Runtime
    flag('--disable_overlap_scheduler')
    flag('--enable_chunked_prefill')
    flag('--use_cuda_graph')
    flag('--cuda_graph_padding_enabled')
    add('--cuda_graph_batch_sizes', nargs='+', type=int, default=None)
    flag('--print_iter_log', help='Print iteration logs during execution')
    flag('--use_torch_compile',
         help='Use torch.compile to optimize the model')
    flag('--use_piecewise_cuda_graph',
         help='Use piecewise CUDA graph to optimize the model')
    flag('--apply_chat_template')

    # Sampling
    add('--max_tokens', type=int, default=64)
    add('--temperature', type=float, default=None)
    add('--top_k', type=int, default=None)
    add('--top_p', type=float, default=None)
    add('--load_format', type=str, default='auto')
    add('--n', type=int, default=1)
    add('--best_of', type=int, default=None)
    add('--max_beam_width', type=int, default=1)

    # Speculative decoding
    add('--spec_decode_algo', type=str, default=None)
    add('--spec_decode_max_draft_len', type=int, default=1)
    add('--draft_model_dir', type=str, default=None)
    add('--max_matching_ngram_size', type=int, default=5)
    flag('--use_one_model')
    add('--eagle_choices', type=str, default=None)
    flag('--use_dynamic_tree')
    add('--dynamic_tree_max_topK', type=int, default=None)
    flag('--allow_advanced_sampling')
    add('--eagle3_model_arch', type=str, default="llama3",
        choices=["llama3", "mistral_large3"],
        help="The model architecture of the eagle3 model.")

    # Relaxed acceptance
    flag('--use_relaxed_acceptance_for_thinking')
    add('--relaxed_topk', type=int, default=1)
    add('--relaxed_delta', type=float, default=0.0)

    # HF / extra outputs
    flag('--trust_remote_code')
    flag('--return_context_logits')
    flag('--return_generation_logits')
    flag('--prompt_logprobs')
    flag('--logprobs')
    add('--additional_model_outputs', type=str, default=None, nargs='+')

    return parser
def parse_arguments():
    """Parse the quickstart command line from ``sys.argv``.

    Builds a parser with every option registered by ``add_llm_args`` and
    adds ``--kv_cache_fraction`` (float, default 0.9) on top.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = add_llm_args(
        argparse.ArgumentParser(
            description="LLM models with the PyTorch workflow."))
    parser.add_argument("--kv_cache_fraction", type=float, default=0.9)
    return parser.parse_args()
def _build_spec_config(args):
    """Return the speculative-decoding config selected by ``--spec_decode_algo``.

    The algorithm name is upper-cased before matching, so it is
    case-insensitive. ``None`` -- and currently any unrecognized name --
    yields ``None``, i.e. no speculative decoding.
    """
    if args.spec_decode_algo is None:
        return None
    algo = args.spec_decode_algo.upper()

    if algo == 'MTP':
        if not args.use_one_model:
            print("Running MTP eagle with two model style.")
        # MTP uses the main model directory as its speculative model.
        return MTPDecodingConfig(
            num_nextn_predict_layers=args.spec_decode_max_draft_len,
            use_relaxed_acceptance_for_thinking=args.
            use_relaxed_acceptance_for_thinking,
            relaxed_topk=args.relaxed_topk,
            relaxed_delta=args.relaxed_delta,
            mtp_eagle_one_model=args.use_one_model,
            speculative_model=args.model_dir)
    if algo == "EAGLE3":
        return Eagle3DecodingConfig(
            max_draft_len=args.spec_decode_max_draft_len,
            speculative_model=args.draft_model_dir,
            eagle3_one_model=args.use_one_model,
            eagle_choices=args.eagle_choices,
            use_dynamic_tree=args.use_dynamic_tree,
            dynamic_tree_max_topK=args.dynamic_tree_max_topK,
            allow_advanced_sampling=args.allow_advanced_sampling,
            eagle3_model_arch=args.eagle3_model_arch)
    if algo == "DRAFT_TARGET":
        return DraftTargetDecodingConfig(
            max_draft_len=args.spec_decode_max_draft_len,
            speculative_model=args.draft_model_dir)
    if algo == "NGRAM":
        return NGramDecodingConfig(
            max_draft_len=args.spec_decode_max_draft_len,
            max_matching_ngram_size=args.max_matching_ngram_size,
            is_keep_all=True,
            is_use_oldest=True,
            is_public_pool=True,
        )
    if algo == "AUTO":
        return AutoDecodingConfig()
    # NOTE(review): unknown names silently disable speculative decoding;
    # consider rejecting them with a clear error instead.
    return None


def _build_sampling_params(args):
    """Resolve ``n``/``best_of`` for sampling or beam search and build SamplingParams.

    Beam search is enabled when ``args.max_beam_width > 1``. In that mode, if
    ``args.n`` is 1 and ``args.best_of`` is unset, ``args.n`` is rewritten in
    place to ``args.max_beam_width`` so that every beam is returned; callers
    reading ``args.n`` afterwards see the updated value.

    Raises:
        AssertionError: (when assertions are enabled) if ``best_of`` exceeds
            the beam width in beam-search mode, or if ``best_of < n``.
    """
    use_beam_search = args.max_beam_width > 1
    if use_beam_search and args.n == 1 and args.best_of is None:
        args.n = args.max_beam_width
    # best_of must be derived only after n is final. Computing it earlier left
    # best_of == 1 while n == max_beam_width, so the default beam-search
    # invocation always failed the best_of >= n check below.
    best_of = args.best_of or args.n
    if use_beam_search:
        assert best_of <= args.max_beam_width, f"beam width: {best_of}, should be less or equal to max_beam_width: {args.max_beam_width}"
    assert best_of >= args.n, f"In sampling mode best_of value: {best_of} should be greater than or equal to n: {args.n}"

    return SamplingParams(
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        return_context_logits=args.return_context_logits,
        return_generation_logits=args.return_generation_logits,
        logprobs=args.logprobs,
        prompt_logprobs=args.prompt_logprobs,
        n=args.n,
        best_of=best_of,
        use_beam_search=use_beam_search,
        additional_model_outputs=args.additional_model_outputs)


def setup_llm(args, **kwargs):
    """Construct the PyTorch-backend ``LLM`` and matching ``SamplingParams``.

    Args:
        args: Parsed CLI namespace with the options from ``add_llm_args`` plus
            ``kv_cache_fraction``. ``args.n`` may be updated in place when
            beam search is enabled (see ``_build_sampling_params``).
        **kwargs: Extra keyword arguments forwarded verbatim to ``LLM``.

    Returns:
        A ``(llm, sampling_params)`` tuple.

    Raises:
        AssertionError: (when assertions are enabled) on inconsistent
            ``n`` / ``best_of`` / ``max_beam_width`` settings.
    """
    # Resolve and validate sampling settings first: this is cheap, and it
    # reports misconfiguration before the LLM is constructed.
    sampling_params = _build_sampling_params(args)

    kv_cache_config = KvCacheConfig(
        enable_block_reuse=not args.disable_kv_cache_reuse,
        free_gpu_memory_fraction=args.kv_cache_fraction,
        dtype=args.kv_cache_dtype,
        tokens_per_block=args.tokens_per_block,
        use_kv_cache_manager_v2=args.use_kv_cache_manager_v2,
        mamba_ssm_cache_dtype=args.mamba_ssm_cache_dtype,
        # A non-zero event buffer is requested only with --log_kv_cache_events.
        event_buffer_max_size=1024 if args.log_kv_cache_events else 0)

    spec_config = _build_spec_config(args)

    cuda_graph_config = CudaGraphConfig(
        batch_sizes=args.cuda_graph_batch_sizes,
        enable_padding=args.cuda_graph_padding_enabled,
    ) if args.use_cuda_graph else None

    torch_compile_config = TorchCompileConfig(
        enable_fullgraph=args.use_torch_compile,
        enable_inductor=args.use_torch_compile,
        enable_piecewise_cuda_graph=args.use_piecewise_cuda_graph,
    ) if args.use_torch_compile else None

    attention_dp_config = AttentionDpConfig(
        enable_balance=args.attention_dp_enable_balance,
        timeout_iters=args.attention_dp_time_out_iters,
        batching_wait_iters=args.attention_dp_batching_wait_iters,
    )

    llm = LLM(
        model=args.model_dir,
        backend='pytorch',
        checkpoint_format=args.checkpoint_format,
        disable_overlap_scheduler=args.disable_overlap_scheduler,
        kv_cache_config=kv_cache_config,
        attn_backend=args.attention_backend,
        cuda_graph_config=cuda_graph_config,
        load_format=args.load_format,
        print_iter_log=args.print_iter_log,
        enable_iter_perf_stats=args.print_iter_log,
        torch_compile_config=torch_compile_config,
        moe_config=MoeConfig(
            backend=args.moe_backend,
            use_low_precision_moe_combine=args.use_low_precision_moe_combine),
        sampler_type=args.sampler_type,
        max_seq_len=args.max_seq_len,
        max_batch_size=args.max_batch_size,
        max_num_tokens=args.max_num_tokens,
        enable_attention_dp=args.enable_attention_dp,
        attention_dp_config=attention_dp_config,
        tensor_parallel_size=args.tp_size,
        pipeline_parallel_size=args.pp_size,
        moe_expert_parallel_size=args.moe_ep_size,
        moe_tensor_parallel_size=args.moe_tp_size,
        moe_cluster_parallel_size=args.moe_cluster_size,
        enable_chunked_prefill=args.enable_chunked_prefill,
        speculative_config=spec_config,
        trust_remote_code=args.trust_remote_code,
        gather_generation_logits=args.return_generation_logits,
        max_beam_width=args.max_beam_width,
        orchestrator_type=args.orchestrator_type,
        **kwargs)

    return llm, sampling_params
def main():
    """Parse CLI args, build the LLM, run generation, and print the results.

    Prompts come from ``--prompt``, falling back to ``example_prompts``. With
    ``--apply_chat_template``, each prompt is wrapped as a single user message
    and rendered through the tokenizer's chat template before generation.
    Optional outputs (logits, logprobs, additional model outputs, KV-cache
    events) are printed only when the corresponding flags are set.
    """
    args = parse_arguments()
    prompts = args.prompt or example_prompts
    llm, sampling_params = setup_llm(args)

    if args.apply_chat_template:
        prompts = [
            llm.tokenizer.apply_chat_template(
                [{"role": "user", "content": f"{prompt}"}],
                tokenize=False,
                add_generation_prompt=True) for prompt in prompts
        ]

    outputs = llm.generate(prompts, sampling_params)

    # Per-sequence indices are shown only when a prompt can yield several
    # sequences (beam search or n > 1).
    show_seq_idx = args.max_beam_width > 1 or args.n > 1

    for i, output in enumerate(outputs):
        prompt = output.prompt
        for seq_idx, seq in enumerate(output.outputs):
            tag = f"[{i}]" + (f"[{seq_idx}]" if show_seq_idx else "")
            print(f"{tag} Prompt: {prompt!r}, Generated text: {seq.text!r}")

            if args.return_context_logits:
                print(f"{tag} Context logits: {output.context_logits}")
            if args.return_generation_logits:
                print(f"{tag} Generation logits: {seq.generation_logits}")
            if args.prompt_logprobs:
                print(f"{tag} Prompt logprobs: {seq.prompt_logprobs}")
            if args.logprobs:
                print(f"{tag} Logprobs: {seq.logprobs}")

            for name in (args.additional_model_outputs or []):
                if seq.additional_context_outputs:
                    print(
                        f"{tag} Context {name}: {seq.additional_context_outputs[name]}"
                    )
                print(
                    f"{tag} Generation {name}: {seq.additional_generation_outputs[name]}"
                )

    if args.log_kv_cache_events:
        # Give the runtime a moment to dispatch events before polling them.
        time.sleep(1)
        # NOTE(review): the meaning of the argument 5 (presumably a timeout)
        # is defined by LLM.get_kv_cache_events -- confirm there.
        events = llm.get_kv_cache_events(5)
        print("=== KV_CACHE_EVENTS_START ===")
        print(json.dumps(events, indent=2))
        print("=== KV_CACHE_EVENTS_END ===")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import.
    main()