mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-05 02:31:33 +08:00
394 lines
16 KiB
Python
394 lines
16 KiB
Python
import argparse
|
|
import json
|
|
import time
|
|
|
|
from tensorrt_llm import LLM, SamplingParams
|
|
from tensorrt_llm.llmapi import (AttentionDpConfig, AutoDecodingConfig,
|
|
CudaGraphConfig, DraftTargetDecodingConfig,
|
|
Eagle3DecodingConfig, KvCacheConfig, MoeConfig,
|
|
MTPDecodingConfig, NGramDecodingConfig,
|
|
TorchCompileConfig)
|
|
|
|
# Prompts used by main() when no --prompt is supplied on the command line.
example_prompts = [
    "Hello, my name is",
    "The capital of France is",
    "The future of AI is",
]
|
|
|
|
|
|
def add_llm_args(parser):
    """Register this example's command-line options on *parser*.

    Options are grouped into model/build, parallelism, KV cache, runtime,
    sampling, speculative decoding, relaxed acceptance and HF/output
    sections. ``--kv_cache_fraction`` is not added here; callers add it.

    Args:
        parser: An ``argparse.ArgumentParser`` (or compatible) to extend.

    Returns:
        The same *parser*, to allow chaining.
    """
    # Model / checkpoint
    parser.add_argument('--model_dir',
                        type=str,
                        required=True,
                        help="Model checkpoint directory.")
    parser.add_argument("--prompt",
                        type=str,
                        nargs="+",
                        help="A single or a list of text prompts.")
    parser.add_argument('--checkpoint_format',
                        type=str,
                        default=None,
                        choices=["HF", "mistral"],
                        help="Model checkpoint format.")

    # Build config
    parser.add_argument("--max_seq_len",
                        type=int,
                        default=None,
                        help="The maximum sequence length.")
    parser.add_argument("--max_batch_size",
                        type=int,
                        default=2048,
                        help="The maximum batch size.")
    parser.add_argument(
        "--max_num_tokens",
        type=int,
        default=8192,
        help=
        "The maximum total tokens (context + generation) across all sequences in a batch."
    )

    # Parallelism
    parser.add_argument('--attention_backend',
                        type=str,
                        default='TRTLLM',
                        choices=[
                            'VANILLA', 'TRTLLM', 'FLASHINFER',
                            'FLASHINFER_STAR_ATTENTION'
                        ],
                        help='Attention backend to use.')
    parser.add_argument(
        '--moe_backend',
        type=str,
        default='AUTO',
        choices=[
            'AUTO', 'CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP', 'DEEPGEMM',
            'CUTEDSL', 'TRITON'
        ],
        help=
        'MoE backend to use. AUTO selects default backend based on model. It currently doesn\'t always give the best choice for all scenarios. The capabilities of auto selection will be improved in future releases.'
    )
    parser.add_argument('--enable_attention_dp',
                        default=False,
                        action='store_true',
                        help='Enable attention data parallelism.')
    parser.add_argument('--attention_dp_enable_balance',
                        default=False,
                        action='store_true',
                        help='Enable attention DP load balancing.')
    parser.add_argument('--attention_dp_time_out_iters',
                        type=int,
                        default=0,
                        help='Attention DP balancing timeout, in iterations.')
    parser.add_argument('--attention_dp_batching_wait_iters',
                        type=int,
                        default=0,
                        help='Attention DP batching wait, in iterations.')
    parser.add_argument('--sampler_type',
                        default="auto",
                        choices=["auto", "TorchSampler", "TRTLLMSampler"],
                        help='Sampler implementation to use.')
    parser.add_argument('--tp_size',
                        type=int,
                        default=1,
                        help='Tensor parallel size.')
    parser.add_argument('--pp_size',
                        type=int,
                        default=1,
                        help='Pipeline parallel size.')
    # argparse only validates strings given on the command line, so a None
    # entry in `choices` could never match; omitting the flag keeps the
    # default of None.
    parser.add_argument('--orchestrator_type',
                        type=str,
                        default=None,
                        choices=['rpc', 'ray'],
                        help='Orchestrator type for multi-GPU execution')
    parser.add_argument('--moe_ep_size',
                        type=int,
                        default=-1,
                        help='MoE expert parallel size.')
    parser.add_argument('--moe_tp_size',
                        type=int,
                        default=-1,
                        help='MoE tensor parallel size.')
    parser.add_argument('--moe_cluster_size',
                        type=int,
                        default=-1,
                        help='MoE cluster parallel size.')
    parser.add_argument(
        '--use_low_precision_moe_combine',
        default=False,
        action='store_true',
        help='Use low precision combine in MoE (only for NVFP4 quantization)')

    # KV cache
    parser.add_argument('--kv_cache_dtype',
                        type=str,
                        default='auto',
                        help='KV cache data type.')
    parser.add_argument('--disable_kv_cache_reuse',
                        default=False,
                        action='store_true',
                        help='Disable KV cache block reuse.')
    parser.add_argument("--tokens_per_block",
                        type=int,
                        default=32,
                        help='Number of tokens per KV cache block.')
    parser.add_argument('--mamba_ssm_cache_dtype',
                        type=str,
                        default='bfloat16',
                        choices=['auto', 'float16', 'bfloat16', 'float32'],
                        help='Data type for Mamba SSM cache.')
    parser.add_argument('--log_kv_cache_events',
                        default=False,
                        action='store_true',
                        help='Record KV cache events and print them.')

    # Runtime
    parser.add_argument('--disable_overlap_scheduler',
                        default=False,
                        action='store_true',
                        help='Disable the overlap scheduler.')
    parser.add_argument('--enable_chunked_prefill',
                        default=False,
                        action='store_true',
                        help='Enable chunked prefill.')
    parser.add_argument('--use_cuda_graph',
                        default=False,
                        action='store_true',
                        help='Enable CUDA graphs.')
    parser.add_argument(
        '--cuda_graph_padding_enabled',
        default=False,
        action='store_true',
        help='Enable CUDA graph batch padding (with --use_cuda_graph).')
    parser.add_argument(
        '--cuda_graph_batch_sizes',
        nargs='+',
        type=int,
        default=None,
        help='Batch sizes to capture CUDA graphs for (with --use_cuda_graph).')
    parser.add_argument('--print_iter_log',
                        default=False,
                        action='store_true',
                        help='Print iteration logs during execution')
    parser.add_argument('--use_torch_compile',
                        default=False,
                        action='store_true',
                        help='Use torch.compile to optimize the model')
    parser.add_argument('--use_piecewise_cuda_graph',
                        default=False,
                        action='store_true',
                        help='Use piecewise CUDA graph to optimize the model')
    parser.add_argument(
        '--apply_chat_template',
        default=False,
        action='store_true',
        help="Wrap each prompt in the tokenizer's chat template first.")

    # Sampling
    parser.add_argument("--max_tokens",
                        type=int,
                        default=64,
                        help='Maximum number of tokens to generate.')
    parser.add_argument("--temperature",
                        type=float,
                        default=None,
                        help='Sampling temperature (unset if omitted).')
    parser.add_argument("--top_k",
                        type=int,
                        default=None,
                        help='Top-k sampling (unset if omitted).')
    parser.add_argument("--top_p",
                        type=float,
                        default=None,
                        help='Top-p sampling (unset if omitted).')
    parser.add_argument('--load_format',
                        type=str,
                        default='auto',
                        help='Weight loading format passed to the LLM.')
    parser.add_argument('--n',
                        type=int,
                        default=1,
                        help='Number of sequences to return per prompt.')
    parser.add_argument(
        '--best_of',
        type=int,
        default=None,
        help='Number of sequences to generate per prompt (defaults to --n).')
    parser.add_argument('--max_beam_width',
                        type=int,
                        default=1,
                        help='Maximum beam width; values > 1 use beam search.')

    # Speculative decoding
    parser.add_argument(
        '--spec_decode_algo',
        type=str,
        default=None,
        help='Speculative decoding algorithm, case-insensitive: '
        'MTP, EAGLE3, DRAFT_TARGET, NGRAM or AUTO.')
    parser.add_argument(
        '--spec_decode_max_draft_len',
        type=int,
        default=1,
        help='Maximum draft length (number of next-n layers for MTP).')
    parser.add_argument('--draft_model_dir',
                        type=str,
                        default=None,
                        help='Draft model directory (EAGLE3 / DRAFT_TARGET).')
    parser.add_argument('--max_matching_ngram_size',
                        type=int,
                        default=5,
                        help='Maximum n-gram size to match (NGRAM).')
    parser.add_argument('--use_one_model',
                        default=False,
                        action='store_true',
                        help='Use the one-model variant (MTP / EAGLE3).')
    parser.add_argument('--eagle_choices',
                        type=str,
                        default=None,
                        help='EAGLE3 eagle_choices setting.')
    parser.add_argument('--use_dynamic_tree',
                        default=False,
                        action='store_true',
                        help='Use a dynamic draft tree (EAGLE3).')
    parser.add_argument('--dynamic_tree_max_topK',
                        type=int,
                        default=None,
                        help='Max top-K of the dynamic tree (EAGLE3).')
    parser.add_argument('--allow_advanced_sampling',
                        default=False,
                        action='store_true',
                        help='Allow advanced sampling (EAGLE3).')
    parser.add_argument('--eagle3_model_arch',
                        type=str,
                        default="llama3",
                        choices=["llama3", "mistral_large3"],
                        help="The model architecture of the eagle3 model.")

    # Relaxed acceptance (MTP)
    parser.add_argument('--use_relaxed_acceptance_for_thinking',
                        default=False,
                        action='store_true',
                        help='Use relaxed acceptance for thinking (MTP).')
    parser.add_argument('--relaxed_topk',
                        type=int,
                        default=1,
                        help='Relaxed acceptance top-k (MTP).')
    parser.add_argument('--relaxed_delta',
                        type=float,
                        default=0.,
                        help='Relaxed acceptance delta (MTP).')

    # HF / outputs
    parser.add_argument('--trust_remote_code',
                        default=False,
                        action='store_true',
                        help='Trust remote code when loading the model.')
    parser.add_argument('--return_context_logits',
                        default=False,
                        action='store_true',
                        help='Return and print context logits.')
    parser.add_argument('--return_generation_logits',
                        default=False,
                        action='store_true',
                        help='Return and print generation logits.')
    parser.add_argument('--prompt_logprobs',
                        default=False,
                        action='store_true',
                        help='Return and print prompt logprobs.')
    parser.add_argument('--logprobs',
                        default=False,
                        action='store_true',
                        help='Return and print generation logprobs.')

    parser.add_argument(
        '--additional_model_outputs',
        type=str,
        default=None,
        nargs='+',
        help='Names of additional model outputs to return and print.')

    return parser
|
|
|
|
|
|
def parse_arguments():
    """Parse this script's command line.

    Extends the shared options from ``add_llm_args`` with
    ``--kv_cache_fraction`` (default 0.9), which ``setup_llm`` reads as the
    KV cache's free GPU memory fraction.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    cli = add_llm_args(
        argparse.ArgumentParser(
            description="LLM models with the PyTorch workflow."))
    cli.add_argument("--kv_cache_fraction", type=float, default=0.9)
    return cli.parse_args()
|
|
|
|
|
|
def _build_spec_config(args):
    """Return the speculative-decoding config selected by ``--spec_decode_algo``.

    The algorithm name is matched case-insensitively. Returns None when no
    algorithm was requested, and also (after printing a warning) when the name
    is not recognized, so a typo does not silently disable speculation.
    """
    if args.spec_decode_algo is None:
        return None
    algo = args.spec_decode_algo.upper()

    if algo == 'MTP':
        if not args.use_one_model:
            print("Running MTP eagle with two model style.")
        return MTPDecodingConfig(
            num_nextn_predict_layers=args.spec_decode_max_draft_len,
            use_relaxed_acceptance_for_thinking=args.
            use_relaxed_acceptance_for_thinking,
            relaxed_topk=args.relaxed_topk,
            relaxed_delta=args.relaxed_delta,
            mtp_eagle_one_model=args.use_one_model,
            # MTP uses the target checkpoint itself as the speculative model.
            speculative_model=args.model_dir)
    if algo == "EAGLE3":
        return Eagle3DecodingConfig(
            max_draft_len=args.spec_decode_max_draft_len,
            speculative_model=args.draft_model_dir,
            eagle3_one_model=args.use_one_model,
            eagle_choices=args.eagle_choices,
            use_dynamic_tree=args.use_dynamic_tree,
            dynamic_tree_max_topK=args.dynamic_tree_max_topK,
            allow_advanced_sampling=args.allow_advanced_sampling,
            eagle3_model_arch=args.eagle3_model_arch)
    if algo == "DRAFT_TARGET":
        return DraftTargetDecodingConfig(
            max_draft_len=args.spec_decode_max_draft_len,
            speculative_model=args.draft_model_dir)
    if algo == "NGRAM":
        return NGramDecodingConfig(
            max_draft_len=args.spec_decode_max_draft_len,
            max_matching_ngram_size=args.max_matching_ngram_size,
            is_keep_all=True,
            is_use_oldest=True,
            is_public_pool=True,
        )
    if algo == "AUTO":
        return AutoDecodingConfig()

    print(f"Unrecognized --spec_decode_algo {args.spec_decode_algo!r}; "
          "running without speculative decoding.")
    return None


def setup_llm(args, **kwargs):
    """Build an ``LLM`` and matching ``SamplingParams`` from parsed CLI args.

    Args:
        args: Namespace produced by ``parse_arguments()``. ``args.n`` is
            updated in place when beam search is enabled and neither ``--n``
            nor ``--best_of`` was given, so that every beam is returned.
        **kwargs: Extra keyword arguments forwarded unchanged to ``LLM``.

    Returns:
        A ``(llm, sampling_params)`` tuple.

    Raises:
        AssertionError: under beam search if ``best_of`` exceeds
            ``max_beam_width``, or in any mode if ``best_of`` is smaller
            than ``n``.
    """
    # Resolve and validate the sampling options before the (expensive) model
    # load, so an invalid combination fails immediately.
    use_beam_search = args.max_beam_width > 1
    if use_beam_search and args.n == 1 and args.best_of is None:
        # Nothing requested explicitly: return every beam.
        args.n = args.max_beam_width
    # best_of must be resolved *after* n may have been widened above.
    # Resolving it first left best_of at 1 while n became the beam width,
    # which always tripped the best_of >= n check.
    best_of = args.best_of or args.n
    if use_beam_search:
        assert best_of <= args.max_beam_width, f"beam width: {best_of}, should be less or equal to max_beam_width: {args.max_beam_width}"
    assert best_of >= args.n, f"In sampling mode best_of value: {best_of} should be greater than or equal to n: {args.n}"

    kv_cache_config = KvCacheConfig(
        enable_block_reuse=not args.disable_kv_cache_reuse,
        free_gpu_memory_fraction=args.kv_cache_fraction,
        dtype=args.kv_cache_dtype,
        tokens_per_block=args.tokens_per_block,
        mamba_ssm_cache_dtype=args.mamba_ssm_cache_dtype,
        # The event buffer is only sized when events will be read in main().
        event_buffer_max_size=1024 if args.log_kv_cache_events else 0)

    spec_config = _build_spec_config(args)

    cuda_graph_config = CudaGraphConfig(
        batch_sizes=args.cuda_graph_batch_sizes,
        enable_padding=args.cuda_graph_padding_enabled,
    ) if args.use_cuda_graph else None

    attention_dp_config = AttentionDpConfig(
        enable_balance=args.attention_dp_enable_balance,
        timeout_iters=args.attention_dp_time_out_iters,
        batching_wait_iters=args.attention_dp_batching_wait_iters,
    )

    # Note: --use_piecewise_cuda_graph only takes effect together with
    # --use_torch_compile, because no compile config is built otherwise.
    torch_compile_config = TorchCompileConfig(
        enable_fullgraph=args.use_torch_compile,
        enable_inductor=args.use_torch_compile,
        enable_piecewise_cuda_graph=args.use_piecewise_cuda_graph,
    ) if args.use_torch_compile else None

    moe_config = MoeConfig(
        backend=args.moe_backend,
        use_low_precision_moe_combine=args.use_low_precision_moe_combine)

    llm = LLM(
        model=args.model_dir,
        backend='pytorch',
        checkpoint_format=args.checkpoint_format,
        disable_overlap_scheduler=args.disable_overlap_scheduler,
        kv_cache_config=kv_cache_config,
        attn_backend=args.attention_backend,
        cuda_graph_config=cuda_graph_config,
        load_format=args.load_format,
        print_iter_log=args.print_iter_log,
        enable_iter_perf_stats=args.print_iter_log,
        torch_compile_config=torch_compile_config,
        moe_config=moe_config,
        sampler_type=args.sampler_type,
        max_seq_len=args.max_seq_len,
        max_batch_size=args.max_batch_size,
        max_num_tokens=args.max_num_tokens,
        enable_attention_dp=args.enable_attention_dp,
        attention_dp_config=attention_dp_config,
        tensor_parallel_size=args.tp_size,
        pipeline_parallel_size=args.pp_size,
        moe_expert_parallel_size=args.moe_ep_size,
        moe_tensor_parallel_size=args.moe_tp_size,
        moe_cluster_parallel_size=args.moe_cluster_size,
        enable_chunked_prefill=args.enable_chunked_prefill,
        speculative_config=spec_config,
        trust_remote_code=args.trust_remote_code,
        gather_generation_logits=args.return_generation_logits,
        max_beam_width=args.max_beam_width,
        orchestrator_type=args.orchestrator_type,
        **kwargs)

    sampling_params = SamplingParams(
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        return_context_logits=args.return_context_logits,
        return_generation_logits=args.return_generation_logits,
        logprobs=args.logprobs,
        prompt_logprobs=args.prompt_logprobs,
        n=args.n,
        best_of=best_of,
        use_beam_search=use_beam_search,
        additional_model_outputs=args.additional_model_outputs)
    return llm, sampling_params
|
|
|
|
|
|
def main():
    """Generate completions for the CLI prompts and print the results.

    Uses ``--prompt`` when given, otherwise ``example_prompts``. Logits,
    logprobs, additional model outputs and KV cache events are printed only
    when the corresponding flag is set.
    """
    args = parse_arguments()
    prompts = args.prompt if args.prompt else example_prompts

    llm, sampling_params = setup_llm(args)
    if args.apply_chat_template:
        # Render each raw prompt as a single user turn, as text, with the
        # generation prompt appended.
        prompts = [
            llm.tokenizer.apply_chat_template(
                [{"role": "user", "content": f"{prompt}"}],
                tokenize=False,
                add_generation_prompt=True) for prompt in prompts
        ]
    outputs = llm.generate(prompts, sampling_params)

    # Tag each line with the sequence index only when a prompt can yield more
    # than one sequence (beam search or n > 1). setup_llm may already have
    # raised args.n to the beam width.
    multi_sequence = args.max_beam_width > 1 or args.n > 1
    for i, output in enumerate(outputs):
        prompt = output.prompt
        for sequence_idx, sequence in enumerate(output.outputs):
            generated_text = sequence.text
            sequence_id_text = f"[{sequence_idx}]" if multi_sequence else ""
            print(
                f"[{i}]{sequence_id_text} Prompt: {prompt!r}, Generated text: {generated_text!r}"
            )
            if args.return_context_logits:
                print(
                    f"[{i}]{sequence_id_text} Context logits: {output.context_logits}"
                )
            if args.return_generation_logits:
                print(
                    f"[{i}]{sequence_id_text} Generation logits: {sequence.generation_logits}"
                )
            if args.prompt_logprobs:
                print(
                    f"[{i}]{sequence_id_text} Prompt logprobs: {sequence.prompt_logprobs}"
                )
            if args.logprobs:
                print(f"[{i}]{sequence_id_text} Logprobs: {sequence.logprobs}")

            if args.additional_model_outputs:
                for output_name in args.additional_model_outputs:
                    if sequence.additional_context_outputs:
                        print(
                            f"[{i}]{sequence_id_text} Context {output_name}: {sequence.additional_context_outputs[output_name]}"
                        )
                    # Guarded like the context outputs above. NOTE(review):
                    # presumably this can also be unset; indexing None would
                    # raise TypeError.
                    if sequence.additional_generation_outputs:
                        print(
                            f"[{i}]{sequence_id_text} Generation {output_name}: {sequence.additional_generation_outputs[output_name]}"
                        )

    if args.log_kv_cache_events:
        time.sleep(1)  # Wait for events to be dispatched
        events = llm.get_kv_cache_events(5)
        print("=== KV_CACHE_EVENTS_START ===")
        print(json.dumps(events, indent=2))
        print("=== KV_CACHE_EVENTS_END ===")


if __name__ == '__main__':
    main()
|