TensorRT-LLMs/tests/unittest/_torch/test_trtllm_decoder.py
import json
from pathlib import Path
import pytest
from utils.llm_data import llm_models_root
from utils.util import similar
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.llmapi import KvCacheConfig as TRT_KvCacheConfig


# A test case of mmlu_llama from lm_eval
@pytest.fixture(scope="module")
def test_case():
    with open(Path(__file__).parent / "test_overlap_scheduler_input.json") as f:
        return json.load(f)
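
# Note: the JSON fixture above only needs to provide the keys read in
# test_trtllm_decoder below: "max_new_tokens", "temperature", "top_p", "stop_words".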


@pytest.fixture(scope="module")
def model_path():
    return llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"


def create_llm(model_dir):
    """Create an LLM that uses CUDA graphs and the TensorRT-LLM decoder."""
    pytorch_config = PyTorchConfig(use_cuda_graph=True,
                                   enable_trtllm_decoder=True)
    trt_kv_cache_config = TRT_KvCacheConfig(enable_block_reuse=False)
    return LLM(
        model=str(model_dir),
        tensor_parallel_size=1,
        trust_remote_code=True,
        enable_chunked_prefill=True,
        pytorch_backend_config=pytorch_config,
        kv_cache_config=trt_kv_cache_config,
        # Only one request longer than max_num_tokens is required to test chunked prefill
        max_num_tokens=128,
    )


def test_trtllm_decoder(model_path, test_case):
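    """Generate with the TensorRT-LLM decoder enabled and compare against expected completions."""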
    prompts = [
        "Magellan and Elcano lead the first",
        "The capital of France is",
        "The capital of Bolivia is",
    ]
    expected_outputs = [["circumnavigation of the world."], ["Paris."],
                        ["La Paz."]]

    # Test configuration
    max_new_tokens = test_case["max_new_tokens"]
    temperature = test_case["temperature"]
    top_p = test_case["top_p"]
    stop_words = test_case["stop_words"]

    sampling_config = SamplingParams(max_tokens=max_new_tokens,
                                     n=1,
                                     stop=stop_words,
                                     temperature=temperature,
                                     top_p=top_p)

    # Run generation with the TensorRT-LLM decoder enabled
    llm = create_llm(model_path)
    outputs = llm.generate(prompts,
                           sampling_params=sampling_config,
                           use_tqdm=True)
    texts = [[completion.text for completion in request_output.outputs]
             for request_output in outputs]
    llm.shutdown()

    # Keep only the text before the first "\n\n"; texts is a list of lists of strings
    texts = [[text.split('\n\n')[0] for text in request_output]
             for request_output in texts]

    # Verify the generated outputs match the expected completions
    for text, expected in zip(texts, expected_outputs):
        assert similar(text, expected)
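

# A minimal way to run only this test, assuming the repo's usual pytest setup and
# that llm_models_root() can locate the model cache (commonly via an environment
# variable such as LLM_MODELS_ROOT; that variable name is an assumption here):
#
#   pytest tests/unittest/_torch/test_trtllm_decoder.py -q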