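"""Test the TRT-LLM decoder path of the PyTorch backend LLM.

Generates completions for a few short prompts on TinyLlama with CUDA graphs,
chunked prefill, and the TRT-LLM decoder enabled, then checks the outputs
against expected completions.
"""
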
import json
from pathlib import Path

import pytest
from utils.llm_data import llm_models_root
from utils.util import similar

from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.llmapi import KvCacheConfig as TRT_KvCacheConfig


# A test case of mmlu_llama from lm_eval
@pytest.fixture(scope="module")
def test_case():
    with open(Path(__file__).parent / "test_overlap_scheduler_input.json") as f:
        return json.load(f)


@pytest.fixture(scope="module")
def model_path():
    return llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"


def create_llm(model_dir):
    """Create an LLM with CUDA graphs and the TRT-LLM decoder enabled."""
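    # Enable CUDA graphs and the TRT-LLM decoder, the code path this test exercises.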
    pytorch_config = PyTorchConfig(use_cuda_graph=True,
                                   enable_trtllm_decoder=True)

    trt_kv_cache_config = TRT_KvCacheConfig(enable_block_reuse=False)

    return LLM(
        model=str(model_dir),
        tensor_parallel_size=1,
        trust_remote_code=True,
        enable_chunked_prefill=True,
        pytorch_backend_config=pytorch_config,
        kv_cache_config=trt_kv_cache_config,
        # Only one request longer than max_num_tokens is required to test chunked prefill
        max_num_tokens=128,
    )


def test_trtllm_decoder(model_path, test_case):
    prompts = [
        "Magellan and Elcano lead the first",
        "The capital of France is",
        "The capital of Bolivia is",
    ]

    expected_outputs = [["circumnavigation of the world."], ["Paris."],
                        ["La Paz."]]

    # Test configuration
    max_new_tokens = test_case["max_new_tokens"]
    temperature = test_case["temperature"]
    top_p = test_case["top_p"]
    stop_words = test_case["stop_words"]

    sampling_config = SamplingParams(max_tokens=max_new_tokens,
                                     n=1,
                                     stop=stop_words,
                                     temperature=temperature,
                                     top_p=top_p)

    # Create the LLM with the TRT-LLM decoder enabled and generate completions
    llm = create_llm(model_path)
    outputs = llm.generate(prompts,
                           sampling_params=sampling_config,
                           use_tqdm=True)
    texts = [[completion.text for completion in request_output.outputs]
             for request_output in outputs]
    llm.shutdown()

    # Keep only the text before the first "\n\n"; texts is a list of lists of strings
    texts = [[text.split('\n\n')[0] for text in request_output]
             for request_output in texts]

    # Verify each output matches the expected completion
    for text, expected in zip(texts, expected_outputs):
        assert similar(text, expected)
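

# A minimal sketch of how this test might be invoked locally, assuming pytest is
# installed and llm_models_root() resolves to a directory containing
# llama-models-v2/TinyLlama-1.1B-Chat-v1.0:
#
#   pytest -sv -k test_trtllm_decoder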