TensorRT-LLMs/examples/pytorch/quickstart_advanced.py
hlu1 31624b079a
feat: [Deepseek] Add trtllm-gen MOE FP4 MOE backend (#3387)
* Add TRT-LLM Gen MOE to Deepseek

fix fused moe rebase bug.

Fix atol in test_fp4_gemm_quantize.py

fix fused moe rebase bug.

Fix FusedMoe.

Disable 2nd routing kernel preexit

Bump routing reduction to fp32

Disable PDL for fc1

[DEBUG] Lift token limit to 16k

[Bugfix] Token limit to 16k + fp32 routing + tanh

Make fp8 tileN 8

Fix FP8 MoE + Remove redundent temp output for FP4

[FP8-only] Avoid wasting CTAs for activation kernel

fix: unblock FP8 weightloading with trtllm-gen

Remove max_token limit for trtllm-gen path

perf: avoid type-conversion and fill_ from aten

Minor fix

Signed-off-by: Hao Lu <haolu@nvidia.com>

* Fix rebase issues

Signed-off-by: Hao Lu <haolu@nvidia.com>

* Fix compile issue

Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>

* CI clean

Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>

---------

Signed-off-by: Hao Lu <haolu@nvidia.com>
Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
Co-authored-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
2025-04-21 10:01:33 +08:00

172 lines
6.1 KiB
Python

import argparse
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.llmapi import (EagleDecodingConfig, KvCacheConfig,
MTPDecodingConfig)
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
def add_llm_args(parser):
parser.add_argument('--model_dir',
type=str,
required=True,
help="Model checkpoint directory.")
parser.add_argument("--prompt",
type=str,
nargs="+",
help="A single or a list of text prompts.")
# Build config
parser.add_argument("--max_seq_len",
type=int,
default=None,
help="The maximum sequence length.")
parser.add_argument("--max_batch_size",
type=int,
default=2048,
help="The maximum batch size.")
parser.add_argument(
"--max_num_tokens",
type=int,
default=8192,
help=
"The maximum total tokens (context + generation) across all sequences in a batch."
)
# Parallelism
parser.add_argument('--attention_backend',
type=str,
default='TRTLLM',
choices=[
'VANILLA', 'TRTLLM', 'FLASHINFER',
'FLASHINFER_STAR_ATTENTION'
])
parser.add_argument('--moe_backend',
type=str,
default='CUTLASS',
choices=['CUTLASS', 'TRTLLM'])
parser.add_argument('--enable_attention_dp',
default=False,
action='store_true')
parser.add_argument('--tp_size', type=int, default=1)
parser.add_argument('--pp_size', type=int, default=1)
parser.add_argument('--moe_ep_size', type=int, default=-1)
parser.add_argument('--moe_tp_size', type=int, default=-1)
# KV cache
parser.add_argument('--kv_cache_dtype', type=str, default='auto')
parser.add_argument('--kv_cache_enable_block_reuse',
default=True,
action='store_false')
parser.add_argument("--kv_cache_fraction", type=float, default=None)
# Runtime
parser.add_argument('--enable_overlap_scheduler',
default=False,
action='store_true')
parser.add_argument('--enable_chunked_prefill',
default=False,
action='store_true')
parser.add_argument('--use_cuda_graph', default=False, action='store_true')
parser.add_argument('--print_iter_log',
default=False,
action='store_true',
help='Print iteration logs during execution')
# Sampling
parser.add_argument("--max_tokens", type=int, default=64)
parser.add_argument("--temperature", type=float, default=None)
parser.add_argument("--top_k", type=int, default=None)
parser.add_argument("--top_p", type=float, default=None)
parser.add_argument('--load_format', type=str, default='auto')
# Speculative decoding
parser.add_argument('--spec_decode_algo', type=str, default=None)
parser.add_argument('--spec_decode_nextn', type=int, default=1)
parser.add_argument('--eagle_model_dir', type=str, default=None)
return parser
def parse_arguments():
parser = argparse.ArgumentParser(
description="LLM models with the PyTorch workflow.")
parser = add_llm_args(parser)
args = parser.parse_args()
return args
def setup_llm(args):
pytorch_config = PyTorchConfig(
enable_overlap_scheduler=args.enable_overlap_scheduler,
kv_cache_dtype=args.kv_cache_dtype,
attn_backend=args.attention_backend,
use_cuda_graph=args.use_cuda_graph,
load_format=args.load_format,
print_iter_log=args.print_iter_log,
moe_backend=args.moe_backend)
kv_cache_config = KvCacheConfig(
enable_block_reuse=args.kv_cache_enable_block_reuse,
free_gpu_memory_fraction=args.kv_cache_fraction,
)
spec_decode_algo = args.spec_decode_algo.upper(
) if args.spec_decode_algo is not None else None
if spec_decode_algo == 'MTP':
spec_config = MTPDecodingConfig(
num_nextn_predict_layers=args.spec_decode_nextn)
elif spec_decode_algo == "EAGLE3":
spec_config = EagleDecodingConfig(
max_draft_len=args.spec_decode_nextn,
pytorch_eagle_weights_path=args.eagle_model_dir)
else:
spec_config = None
llm = LLM(model=args.model_dir,
max_seq_len=args.max_seq_len,
max_batch_size=args.max_batch_size,
max_num_tokens=args.max_num_tokens,
pytorch_backend_config=pytorch_config,
kv_cache_config=kv_cache_config,
tensor_parallel_size=args.tp_size,
pipeline_parallel_size=args.pp_size,
enable_attention_dp=args.enable_attention_dp,
moe_expert_parallel_size=args.moe_ep_size,
moe_tensor_parallel_size=args.moe_tp_size,
enable_chunked_prefill=args.enable_chunked_prefill,
speculative_config=spec_config)
sampling_params = SamplingParams(
max_tokens=args.max_tokens,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
)
return llm, sampling_params
def main():
args = parse_arguments()
prompts = args.prompt if args.prompt else example_prompts
llm, sampling_params = setup_llm(args)
outputs = llm.generate(prompts, sampling_params)
for i, output in enumerate(outputs):
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}")
if __name__ == '__main__':
main()