TensorRT-LLM/tests/unittest/_torch/speculative/test_ngram.py

import os
import sys
import unittest

import pytest
import torch

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig,
                                 NGramDecodingConfig)

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.llm_data import llm_models_root
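

# This test checks that NGram speculative decoding reproduces the outputs of a
# plain (non-speculative) run of the same model, prompts, and sampling settings.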
# TODO: add disable_overlap_scheduler=False
@pytest.mark.parametrize(
    "disable_overlap_scheduler,use_cuda_graph,attn_backend",
    [[True, False, "TRTLLM"], [True, True, "TRTLLM"],
     [True, False, "FLASHINFER"]])
@pytest.mark.high_cuda_memory
def test_llama_ngram(disable_overlap_scheduler: bool, use_cuda_graph: bool,
                     attn_backend: str):
    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    if total_mem_gb < 20:
        pytest.skip("Not enough memory to load target model")
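
    # Small test configuration: batch size 2, up to 4 draft tokens per step,
    # and a KV cache capped at 8192 tokens.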
    max_batch_size = 2
    max_draft_len = 4
    kv_cache_config = KvCacheConfig(enable_block_reuse=False, max_tokens=8192)
    cuda_graph_config = CudaGraphConfig(
        batch_sizes=[1]) if use_cuda_graph else None
    llm_common_config = dict(
        model=llm_models_root() / "llama-3.1-model" / "Meta-Llama-3.1-8B",
        backend='pytorch',
        attn_backend=attn_backend,
        disable_overlap_scheduler=disable_overlap_scheduler,
        cuda_graph_config=cuda_graph_config,
        max_batch_size=max_batch_size,
        kv_cache_config=kv_cache_config,
        max_num_tokens=2048,
    )
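    # NGram speculative decoding config: draft up to max_draft_len tokens per
    # step by matching n-grams of size up to 2 against previously seen tokens,
    # using a pool shared across requests.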
    spec_config = NGramDecodingConfig(
        max_draft_len=max_draft_len,
        max_matching_ngram_size=2,
        is_keep_all=True,
        is_use_oldest=True,
        is_public_pool=True,
    )
    prompts = [
        "The capital of France is",
        "The president of the United States is",
    ]
    sampling_params = SamplingParams(max_tokens=32, ignore_eos=True)
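
    # Generate with NGram speculative decoding enabled.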
    llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
    results_spec = llm_spec.generate(prompts, sampling_params)
    generated_text_spec = [result.outputs[0].text for result in results_spec]
    llm_spec.shutdown()
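
    # Generate the reference outputs with the same settings but without
    # speculative decoding.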
    llm_ref = LLM(**llm_common_config)
    results_ref = llm_ref.generate(prompts, sampling_params)
    generated_text_ref = [result.outputs[0].text for result in results_ref]
    llm_ref.shutdown()
    for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
        # The spec decode algorithm currently guarantees identical results
        assert text_spec == text_ref


if __name__ == "__main__":
    unittest.main()