"""Command-line example that runs text generation through the TensorRT-LLM
``LLM`` API with the PyTorch backend.

The script builds an argument parser, turns the parsed options into an
``LLM`` instance plus ``SamplingParams``, generates completions for the given
(or default) prompts and prints the results.
"""
import argparse
import json
import time

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (AttentionDpConfig, AutoDecodingConfig,
                                 CudaGraphConfig, DraftTargetDecodingConfig,
                                 Eagle3DecodingConfig, KvCacheConfig, MoeConfig,
                                 MTPDecodingConfig, NGramDecodingConfig,
                                 TorchCompileConfig)

# Prompts used by main() when --prompt is not supplied.
example_prompts = [
    "Hello, my name is",
    "The capital of France is",
    "The future of AI is",
]


def add_llm_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the LLM-related command-line options on ``parser``.

    Note that ``--kv_cache_fraction`` is NOT registered here (it is added by
    ``parse_arguments``), although ``setup_llm`` reads it.

    Args:
        parser: The parser to extend in place.

    Returns:
        The same ``parser`` object, to allow call chaining.
    """
    parser.add_argument('--model_dir',
                        type=str,
                        required=True,
                        help="Model checkpoint directory.")
    parser.add_argument("--prompt",
                        type=str,
                        nargs="+",
                        help="A single or a list of text prompts.")
    parser.add_argument('--checkpoint_format',
                        type=str,
                        default=None,
                        choices=["HF", "mistral"],
                        help="Model checkpoint format.")

    # Build config
    parser.add_argument("--max_seq_len",
                        type=int,
                        default=None,
                        help="The maximum sequence length.")
    parser.add_argument("--max_batch_size",
                        type=int,
                        default=2048,
                        help="The maximum batch size.")
    parser.add_argument(
        "--max_num_tokens",
        type=int,
        default=8192,
        help=
        "The maximum total tokens (context + generation) across all sequences in a batch."
    )

    # Backend selection
    parser.add_argument('--attention_backend',
                        type=str,
                        default='TRTLLM',
                        choices=[
                            'VANILLA', 'TRTLLM', 'FLASHINFER',
                            'FLASHINFER_STAR_ATTENTION'
                        ])
    parser.add_argument(
        '--moe_backend',
        type=str,
        default='AUTO',
        choices=[
            'AUTO', 'CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP', 'DEEPGEMM',
            'CUTEDSL', 'TRITON'
        ],
        help=
        'MoE backend to use. AUTO selects default backend based on model. It currently doesn\'t always give the best choice for all scenarios. The capabilities of auto selection will be improved in future releases.'
    )

    # Parallelism
    parser.add_argument('--enable_attention_dp',
                        default=False,
                        action='store_true')
    parser.add_argument('--attention_dp_enable_balance',
                        default=False,
                        action='store_true')
    parser.add_argument('--attention_dp_time_out_iters', type=int, default=0)
    parser.add_argument('--attention_dp_batching_wait_iters',
                        type=int,
                        default=0)
    parser.add_argument('--sampler_type',
                        default="auto",
                        choices=["auto", "TorchSampler", "TRTLLMSampler"])
    parser.add_argument('--tp_size', type=int, default=1)
    parser.add_argument('--pp_size', type=int, default=1)
    parser.add_argument('--orchestrator_type',
                        type=str,
                        default=None,
                        choices=[None, 'rpc', 'ray'],
                        help='Orchestrator type for multi-GPU execution')
    parser.add_argument('--moe_ep_size', type=int, default=-1)
    parser.add_argument('--moe_tp_size', type=int, default=-1)
    parser.add_argument('--moe_cluster_size', type=int, default=-1)
    parser.add_argument(
        '--use_low_precision_moe_combine',
        default=False,
        action='store_true',
        help='Use low precision combine in MoE (only for NVFP4 quantization)')

    # KV cache
    parser.add_argument('--kv_cache_dtype', type=str, default='auto')
    parser.add_argument('--disable_kv_cache_reuse',
                        default=False,
                        action='store_true')
    parser.add_argument("--tokens_per_block", type=int, default=32)
    parser.add_argument('--mamba_ssm_cache_dtype',
                        type=str,
                        default='bfloat16',
                        choices=['auto', 'float16', 'bfloat16', 'float32'],
                        help='Data type for Mamba SSM cache.')
    parser.add_argument('--log_kv_cache_events',
                        default=False,
                        action='store_true')
    parser.add_argument(
        '--use_kv_cache_manager_v2',
        default=False,
        action='store_true',
        help='Use KVCacheManagerV2 for KV cache management (PyTorch backend).',
    )

    # Runtime
    parser.add_argument('--disable_overlap_scheduler',
                        default=False,
                        action='store_true')
    parser.add_argument('--enable_chunked_prefill',
                        default=False,
                        action='store_true')
    parser.add_argument('--use_cuda_graph', default=False, action='store_true')
    parser.add_argument('--cuda_graph_padding_enabled',
                        default=False,
                        action='store_true')
    parser.add_argument('--cuda_graph_batch_sizes',
                        nargs='+',
                        type=int,
                        default=None)
    parser.add_argument('--print_iter_log',
                        default=False,
                        action='store_true',
                        help='Print iteration logs during execution')
    parser.add_argument('--use_torch_compile',
                        default=False,
                        action='store_true',
                        help='Use torch.compile to optimize the model')
    parser.add_argument('--use_piecewise_cuda_graph',
                        default=False,
                        action='store_true',
                        help='Use piecewise CUDA graph to optimize the model')
    parser.add_argument('--apply_chat_template',
                        default=False,
                        action='store_true')

    # Sampling
    parser.add_argument("--max_tokens", type=int, default=64)
    parser.add_argument("--temperature", type=float, default=None)
    parser.add_argument("--top_k", type=int, default=None)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument('--load_format', type=str, default='auto')
    parser.add_argument('--n', type=int, default=1)
    parser.add_argument('--best_of', type=int, default=None)
    parser.add_argument('--max_beam_width', type=int, default=1)

    # Speculative decoding
    parser.add_argument('--spec_decode_algo', type=str, default=None)
    parser.add_argument('--spec_decode_max_draft_len', type=int, default=1)
    parser.add_argument('--draft_model_dir', type=str, default=None)
    parser.add_argument('--max_matching_ngram_size', type=int, default=5)
    parser.add_argument('--use_one_model', default=False, action='store_true')
    parser.add_argument('--eagle_choices', type=str, default=None)
    parser.add_argument('--use_dynamic_tree',
                        default=False,
                        action='store_true')
    parser.add_argument('--dynamic_tree_max_topK', type=int, default=None)
    parser.add_argument('--allow_advanced_sampling',
                        default=False,
                        action='store_true')
    parser.add_argument('--eagle3_model_arch',
                        type=str,
                        default="llama3",
                        choices=["llama3", "mistral_large3"],
                        help="The model architecture of the eagle3 model.")

    # Relaxed acceptance
    parser.add_argument('--use_relaxed_acceptance_for_thinking',
                        default=False,
                        action='store_true')
    parser.add_argument('--relaxed_topk', type=int, default=1)
    parser.add_argument('--relaxed_delta', type=float, default=0.)

    # HF
    parser.add_argument('--trust_remote_code',
                        default=False,
                        action='store_true')
    parser.add_argument('--return_context_logits',
                        default=False,
                        action='store_true')
    parser.add_argument('--return_generation_logits',
                        default=False,
                        action='store_true')
    parser.add_argument('--prompt_logprobs',
                        default=False,
                        action='store_true')
    parser.add_argument('--logprobs', default=False, action='store_true')
    parser.add_argument('--additional_model_outputs',
                        type=str,
                        default=None,
                        nargs='+')
    return parser


def parse_arguments() -> argparse.Namespace:
    """Parse ``sys.argv`` with the options of ``add_llm_args`` plus
    ``--kv_cache_fraction``.

    Returns:
        The parsed argument namespace.
    """
    parser = argparse.ArgumentParser(
        description="LLM models with the PyTorch workflow.")
    parser = add_llm_args(parser)
    parser.add_argument("--kv_cache_fraction", type=float, default=0.9)
    args = parser.parse_args()
    return args


def setup_llm(args, **kwargs):
    """Build the ``LLM`` instance and ``SamplingParams`` described by ``args``.

    Args:
        args: Namespace carrying the options registered by ``add_llm_args``
            and, additionally, ``kv_cache_fraction`` (registered by
            ``parse_arguments``). When beam search is enabled
            (``max_beam_width > 1``) and both ``n`` and ``best_of`` were left
            at their defaults, ``args.n`` is overwritten with
            ``args.max_beam_width``.
        **kwargs: Extra keyword arguments forwarded unchanged to ``LLM``.

    Returns:
        A ``(llm, sampling_params)`` tuple.

    Raises:
        AssertionError: If ``best_of`` exceeds ``max_beam_width`` in beam
            search mode, or if ``best_of`` is smaller than ``n``.
    """
    kv_cache_config = KvCacheConfig(
        enable_block_reuse=not args.disable_kv_cache_reuse,
        free_gpu_memory_fraction=args.kv_cache_fraction,
        dtype=args.kv_cache_dtype,
        tokens_per_block=args.tokens_per_block,
        use_kv_cache_manager_v2=args.use_kv_cache_manager_v2,
        mamba_ssm_cache_dtype=args.mamba_ssm_cache_dtype,
        # A non-zero event buffer is only requested when events will be read.
        event_buffer_max_size=1024 if args.log_kv_cache_events else 0)

    # The algorithm name is matched case-insensitively.
    spec_decode_algo = args.spec_decode_algo.upper(
    ) if args.spec_decode_algo is not None else None

    if spec_decode_algo == 'MTP':
        if not args.use_one_model:
            print("Running MTP eagle with two model style.")
        spec_config = MTPDecodingConfig(
            num_nextn_predict_layers=args.spec_decode_max_draft_len,
            use_relaxed_acceptance_for_thinking=args.
            use_relaxed_acceptance_for_thinking,
            relaxed_topk=args.relaxed_topk,
            relaxed_delta=args.relaxed_delta,
            mtp_eagle_one_model=args.use_one_model,
            speculative_model=args.model_dir)
    elif spec_decode_algo == "EAGLE3":
        spec_config = Eagle3DecodingConfig(
            max_draft_len=args.spec_decode_max_draft_len,
            speculative_model=args.draft_model_dir,
            eagle3_one_model=args.use_one_model,
            eagle_choices=args.eagle_choices,
            use_dynamic_tree=args.use_dynamic_tree,
            dynamic_tree_max_topK=args.dynamic_tree_max_topK,
            allow_advanced_sampling=args.allow_advanced_sampling,
            eagle3_model_arch=args.eagle3_model_arch)
    elif spec_decode_algo == "DRAFT_TARGET":
        spec_config = DraftTargetDecodingConfig(
            max_draft_len=args.spec_decode_max_draft_len,
            speculative_model=args.draft_model_dir)
    elif spec_decode_algo == "NGRAM":
        spec_config = NGramDecodingConfig(
            max_draft_len=args.spec_decode_max_draft_len,
            max_matching_ngram_size=args.max_matching_ngram_size,
            is_keep_all=True,
            is_use_oldest=True,
            is_public_pool=True,
        )
    elif spec_decode_algo == "AUTO":
        spec_config = AutoDecodingConfig()
    else:
        # No algorithm given, or an unrecognized name: run without
        # speculative decoding.
        spec_config = None

    cuda_graph_config = CudaGraphConfig(
        batch_sizes=args.cuda_graph_batch_sizes,
        enable_padding=args.cuda_graph_padding_enabled,
    ) if args.use_cuda_graph else None

    attention_dp_config = AttentionDpConfig(
        enable_balance=args.attention_dp_enable_balance,
        timeout_iters=args.attention_dp_time_out_iters,
        batching_wait_iters=args.attention_dp_batching_wait_iters,
    )

    # --use_piecewise_cuda_graph only takes effect together with
    # --use_torch_compile; otherwise no torch.compile config is passed.
    torch_compile_config = TorchCompileConfig(
        enable_fullgraph=args.use_torch_compile,
        enable_inductor=args.use_torch_compile,
        enable_piecewise_cuda_graph=args.use_piecewise_cuda_graph
    ) if args.use_torch_compile else None

    llm = LLM(
        model=args.model_dir,
        backend='pytorch',
        checkpoint_format=args.checkpoint_format,
        disable_overlap_scheduler=args.disable_overlap_scheduler,
        kv_cache_config=kv_cache_config,
        attn_backend=args.attention_backend,
        cuda_graph_config=cuda_graph_config,
        load_format=args.load_format,
        print_iter_log=args.print_iter_log,
        enable_iter_perf_stats=args.print_iter_log,
        torch_compile_config=torch_compile_config,
        moe_config=MoeConfig(
            backend=args.moe_backend,
            use_low_precision_moe_combine=args.use_low_precision_moe_combine),
        sampler_type=args.sampler_type,
        max_seq_len=args.max_seq_len,
        max_batch_size=args.max_batch_size,
        max_num_tokens=args.max_num_tokens,
        enable_attention_dp=args.enable_attention_dp,
        attention_dp_config=attention_dp_config,
        tensor_parallel_size=args.tp_size,
        pipeline_parallel_size=args.pp_size,
        moe_expert_parallel_size=args.moe_ep_size,
        moe_tensor_parallel_size=args.moe_tp_size,
        moe_cluster_parallel_size=args.moe_cluster_size,
        enable_chunked_prefill=args.enable_chunked_prefill,
        speculative_config=spec_config,
        trust_remote_code=args.trust_remote_code,
        gather_generation_logits=args.return_generation_logits,
        max_beam_width=args.max_beam_width,
        orchestrator_type=args.orchestrator_type,
        **kwargs)

    use_beam_search = args.max_beam_width > 1
    if use_beam_search and args.n == 1 and args.best_of is None:
        # Neither n nor best_of was chosen explicitly: return every beam.
        args.n = args.max_beam_width

    # best_of must be derived AFTER n has been finalized. Computing it from
    # the pre-adjustment n left best_of at 1 while n became max_beam_width,
    # which made the `best_of >= n` check below fail for default beam search.
    best_of = args.best_of or args.n

    if use_beam_search:
        assert best_of <= args.max_beam_width, f"beam width: {best_of}, should be less or equal to max_beam_width: {args.max_beam_width}"

    assert best_of >= args.n, f"In sampling mode best_of value: {best_of} should be greater than or equal to n: {args.n}"

    sampling_params = SamplingParams(
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        return_context_logits=args.return_context_logits,
        return_generation_logits=args.return_generation_logits,
        logprobs=args.logprobs,
        prompt_logprobs=args.prompt_logprobs,
        n=args.n,
        best_of=best_of,
        use_beam_search=use_beam_search,
        additional_model_outputs=args.additional_model_outputs)
    return llm, sampling_params


def main():
    """Parse the command line, generate completions and print the results.

    Depending on the flags, context/generation logits, logprobs, additional
    model outputs and KV cache events are printed as well.
    """
    args = parse_arguments()
    prompts = args.prompt if args.prompt else example_prompts

    llm, sampling_params = setup_llm(args)

    if args.apply_chat_template:
        # Wrap each raw prompt as a single user message and render it to text
        # with the tokenizer's chat template.
        prompts = [
            llm.tokenizer.apply_chat_template(
                [{
                    "role": "user",
                    "content": f"{prompt}"
                }],
                tokenize=False,
                add_generation_prompt=True) for prompt in prompts
        ]

    outputs = llm.generate(prompts, sampling_params)

    for i, output in enumerate(outputs):
        prompt = output.prompt
        for sequence_idx, sequence in enumerate(output.outputs):
            generated_text = sequence.text
            # Skip printing the beam_idx if no beam search was used
            sequence_id_text = f"[{sequence_idx}]" if args.max_beam_width > 1 or args.n > 1 else ""
            print(
                f"[{i}]{sequence_id_text} Prompt: {prompt!r}, Generated text: {generated_text!r}"
            )
            if args.return_context_logits:
                print(
                    f"[{i}]{sequence_id_text} Context logits: {output.context_logits}"
                )
            if args.return_generation_logits:
                print(
                    f"[{i}]{sequence_id_text} Generation logits: {sequence.generation_logits}"
                )
            if args.prompt_logprobs:
                print(
                    f"[{i}]{sequence_id_text} Prompt logprobs: {sequence.prompt_logprobs}"
                )
            if args.logprobs:
                print(f"[{i}]{sequence_id_text} Logprobs: {sequence.logprobs}")

            if args.additional_model_outputs:
                for output_name in args.additional_model_outputs:
                    # Context outputs are printed only when present;
                    # generation outputs are always printed.
                    if sequence.additional_context_outputs:
                        print(
                            f"[{i}]{sequence_id_text} Context {output_name}: {sequence.additional_context_outputs[output_name]}"
                        )
                    print(
                        f"[{i}]{sequence_id_text} Generation {output_name}: {sequence.additional_generation_outputs[output_name]}"
                    )

    if args.log_kv_cache_events:
        time.sleep(1)  # Wait for events to be dispatched
        events = llm.get_kv_cache_events(5)
        print("=== KV_CACHE_EVENTS_START ===")
        print(json.dumps(events, indent=2))
        print("=== KV_CACHE_EVENTS_END ===")


if __name__ == '__main__':
    main()