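"""Tests for the `tensorrt_llm.bindings.executor` Python bindings.

Covers executor construction and shutdown, request enqueueing and response
polling, and the configuration classes exposed by the bindings (sampling,
output, KV cache, scheduler, decoding, and executor configs).
"""
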
import datetime
import json
import os as _os
import pickle
import random
import sys as _sys
import time
import typing as tp
from pathlib import Path

import numpy as np
import pytest
import torch
from binding_test_utils import *
from pydantic import BaseModel

import tensorrt_llm.bindings.executor as trtllm
import tensorrt_llm.version as trtllm_version
from tensorrt_llm.models.modeling_utils import PretrainedConfig

_sys.path.append(_os.path.join(_os.path.dirname(__file__), '..'))
import inspect

from utils.cpp_paths import *
from utils.llm_data import llm_models_root
from utils.util import skip_pre_ampere


@pytest.fixture
def model_files(llm_root: Path, resource_path: Path, results_data_path: Path):
    # Model engines and expected outputs need to be generated.
    if not results_data_path.exists():
        model_cache = llm_models_root()
        model_cache_arg = ["--model_cache", str(model_cache)
                           ] if model_cache is not None else []
        prepare_model_tests(llm_root, resource_path, "gpt", model_cache_arg)


@pytest.fixture
def lora_config_paths(llm_root: Path, resource_path: Path,
                      lora_config_path: Path):
    if not lora_config_path.exists():
        prepare_lora_configs(llm_root, resource_path, lora_config_path)
    return (lora_config_path / "source.npy", lora_config_path / "config.npy")


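# In non-streaming mode the prompt is echoed back in the output unless
# exclude_input_from_output is set, so the expected count includes the prompt.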
def get_expected_num_tokens(prompt_len, max_tokens, streaming,
                            exclude_input_from_output):
    if not streaming and not exclude_input_from_output:
        return prompt_len + max_tokens
    return max_tokens


@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_executor_valid_ctor(model_files, model_path):
    executor_config = trtllm.ExecutorConfig(1)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)


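# The executor can also be constructed from an in-memory engine buffer plus the
# serialized JSON config, instead of an engine directory.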
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_executor_from_memory(model_files, model_path):
    executor_config = trtllm.ExecutorConfig(1)
    engine_buffer = open(model_path / "rank0.engine", mode="rb").read()
    json_config_str = open(model_path / "config.json", 'r').read()
    executor = trtllm.Executor(engine_buffer, json_config_str,
                               trtllm.ModelType.DECODER_ONLY, executor_config)


def test_executor_invalid_ctor():
    executor_config = trtllm.ExecutorConfig(1)
    invalid_path = "Bla"
    try:
        executor = trtllm.Executor(invalid_path, trtllm.ModelType.DECODER_ONLY,
                                   executor_config)
        assert False, "Expected an error"
    except Exception as e:
        assert "File does not exist" in str(e)


@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_shutdown(model_files, model_path):

    beam_width = 1
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    # Create the request
    max_tokens = 5
    input_tokens = [1, 2, 3, 4]
    request = trtllm.Request(input_tokens,
                             max_tokens=max_tokens,
                             streaming=False,
                             sampling_config=trtllm.SamplingConfig())

    # Enqueue the request
    assert executor.can_enqueue_requests() == True
    req_id = executor.enqueue_request(request)

    executor.shutdown()
    assert executor.can_enqueue_requests() == False

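    # After shutdown, every executor API call must raise.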
    with pytest.raises(Exception):
        executor.enqueue_request(request)
    with pytest.raises(Exception):
        executor.await_responses()
    with pytest.raises(Exception):
        executor.get_latest_iteration_stats()
    with pytest.raises(Exception):
        executor.get_latest_request_stats()
    with pytest.raises(Exception):
        executor.get_latest_debug_tensors()
    with pytest.raises(Exception):
        executor.cancel_request(req_id)
    with pytest.raises(Exception):
        executor.get_num_responses_ready(req_id)


@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_embedding_bias(model_files, model_path):
    streaming = False
    exclude_input_from_output = False
    output_config = trtllm.OutputConfig()
    output_config.exclude_input_from_output = exclude_input_from_output

    # Create executor
    beam_width = 1
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    # Create the request
    max_tokens = 5
    input_tokens = [1, 2, 3, 4]
    # Set embedding bias so "biased_output" is always picked
    biased_output = 10
    vocab_size_padded = 50257
    embedding_bias = torch.zeros(vocab_size_padded)
    embedding_bias[biased_output] = torch.finfo(torch.float32).max
    request = trtllm.Request(input_tokens,
                             max_tokens=max_tokens,
                             streaming=streaming,
                             sampling_config=trtllm.SamplingConfig(),
                             output_config=output_config,
                             embedding_bias=embedding_bias)

    # Enqueue the request
    request_id = executor.enqueue_request(request)

    # Get the new tokens
    tokens = []
    done = False
    i = 0
    max_wait_ms = 10000
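    # Poll for responses, waiting up to 1 ms per iteration for at most
    # max_wait_ms iterations.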
    while not done and i < max_wait_ms:
        wait_time = datetime.timedelta(milliseconds=1)
        responses = executor.await_responses(request_id, wait_time)
        for response in responses:
            assert not response.has_error(
            ), f"Request id {request_id} failed with err {response.error_msg}"
            result = response.result
            done = result.is_final
            new_tokens = result.output_token_ids[beam_width - 1]
            tokens.extend(new_tokens)
        i += 1
    assert i < max_wait_ms
    assert len(tokens) == get_expected_num_tokens(
        len(input_tokens), max_tokens, streaming,
        exclude_input_from_output), f"{request_id}"
    # All generated tokens should equal biased_output
    assert tokens[-max_tokens:] == [biased_output] * max_tokens


@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize("exclude_input_from_output", [False])
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_single_request(streaming: bool, exclude_input_from_output: bool,
                        model_files, model_path):
    output_config = trtllm.OutputConfig()
    output_config.exclude_input_from_output = exclude_input_from_output

    # Create executor
    beam_width = 1
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    # Create the request
    max_tokens = 5
    input_tokens = [1, 2, 3, 4]
    request = trtllm.Request(input_tokens,
                             max_tokens=max_tokens,
                             streaming=streaming,
                             sampling_config=trtllm.SamplingConfig(),
                             output_config=output_config)

    # Enqueue the request
    request_id = executor.enqueue_request(request)

    # Get the new tokens
    tokens = []
    done = False
    i = 0
    max_wait_ms = 10000
    while not done and i < max_wait_ms:
        wait_time = datetime.timedelta(milliseconds=1)
        responses = executor.await_responses(request_id, wait_time)
        for response in responses:
            assert not response.has_error(
            ), f"Request id {request_id} failed with err {response.error_msg}"
            result = response.result
            done = result.is_final
            new_tokens = result.output_token_ids[beam_width - 1]
            tokens.extend(new_tokens)
        i += 1
    assert i < max_wait_ms
    assert len(tokens) == get_expected_num_tokens(
        len(input_tokens), max_tokens, streaming,
        exclude_input_from_output), f"{request_id}"

    executor.get_latest_iteration_stats()
    executor.get_latest_request_stats()
    executor.get_latest_debug_tensors()


@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_single_request_lora(model_files, model_path_lora, lora_config_paths):
    streaming = False
    exclude_input_from_output = False
    output_config = trtllm.OutputConfig()
    output_config.exclude_input_from_output = exclude_input_from_output

    # Create executor
    beam_width = 1

    peft_cache_config = trtllm.PeftCacheConfig(num_put_workers=4,
                                               num_ensure_workers=4)
    executor_config = trtllm.ExecutorConfig(1,
                                            peft_cache_config=peft_cache_config)
    executor = trtllm.Executor(model_path_lora, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    # Create the request
    max_tokens = 5
    input_tokens = [1, 2, 3, 4]
    lora_weights = torch.tensor(np.load(lora_config_paths[0])).half()
    lora_config = torch.tensor(np.load(lora_config_paths[1]))
    request = trtllm.Request(input_tokens,
                             max_tokens=max_tokens,
                             streaming=streaming,
                             sampling_config=trtllm.SamplingConfig(),
                             output_config=output_config,
                             lora_config=trtllm.LoraConfig(
                                 0, lora_weights, lora_config))

    # Enqueue the request
    request_id = executor.enqueue_request(request)

    # Get the new tokens
    tokens = []
    done = False
    i = 0
    max_wait_ms = 10000
    while not done and i < max_wait_ms:
        wait_time = datetime.timedelta(milliseconds=1)
        responses = executor.await_responses(request_id, wait_time)
        for response in responses:
            assert not response.has_error(
            ), f"Request id {request_id} failed with err {response.error_msg}"
            result = response.result
            done = result.is_final
            new_tokens = result.output_token_ids[beam_width - 1]
            tokens.extend(new_tokens)
        i += 1
    assert i < max_wait_ms
    assert len(tokens) == get_expected_num_tokens(
        len(input_tokens), max_tokens, streaming,
        exclude_input_from_output), f"{request_id}"


@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize("exclude_input_from_output", [False])
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_multi_request(streaming: bool, exclude_input_from_output: bool,
                       model_files, model_path):
    output_config = trtllm.OutputConfig()
    output_config.exclude_input_from_output = exclude_input_from_output

    # Create executor
    beam_width = 1
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    num_requests = 20
    max_prompt_len = 20
    max_max_tokens = 20
    end_id = -1

    # Enqueue the requests
    tokens = {}
    expected_num_tokens = {}
    for i in range(num_requests):
        prompt_len = random.randint(1, max_prompt_len)
        max_tokens = random.randint(1, max_max_tokens)
        input_tokens = [1] * prompt_len

        # Some requests have num_return_sequences > 1.
        num_return_sequences = 2 if i % 5 == 1 else 1

        request = trtllm.Request(input_tokens,
                                 max_tokens=max_tokens,
                                 streaming=streaming,
                                 sampling_config=trtllm.SamplingConfig(
                                     num_return_sequences=num_return_sequences),
                                 output_config=output_config,
                                 end_id=end_id)
        request_id = executor.enqueue_request(request)
        tokens[request_id] = [
            [] for _ in range(request.sampling_config.num_return_sequences)
        ]
        expected_num_tokens[request_id] = get_expected_num_tokens(
            prompt_len, max_tokens, streaming, exclude_input_from_output)

    # Get the new tokens for each request
    num_finished = 0
    i = 0
    num_responses = 0
    max_wait_ms = 10000
    while num_finished < num_requests and i < max_wait_ms:
        wait_time = datetime.timedelta(milliseconds=1)
        responses = executor.await_responses(wait_time)
        for response in responses:
            num_responses += 1
            assert not response.has_error(
            ), f"Request id {response.request_id} failed with err {response.error_msg}"
            result = response.result
            num_finished += result.is_final
            new_tokens = result.output_token_ids[beam_width - 1]
            tokens[response.request_id][result.sequence_index].extend(
                new_tokens)
        i += 1
    assert i < max_wait_ms

    for request_id in expected_num_tokens:
        for actual_tokens in tokens[request_id]:
            assert len(actual_tokens) == expected_num_tokens[request_id]


@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize("exclude_input_from_output", [False])
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_multi_request_with_ids(streaming: bool,
                                exclude_input_from_output: bool, model_files,
                                model_path):
    output_config = trtllm.OutputConfig()
    output_config.exclude_input_from_output = exclude_input_from_output

    # Create executor
    beam_width = 1
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    num_requests = 20
    max_prompt_len = 20
    max_max_tokens = 20
    end_id = -1

    # Enqueue the requests
    tokens = {}
    expected_num_tokens = {}
    for i in range(num_requests):
        prompt_len = random.randint(1, max_prompt_len)
        max_tokens = random.randint(1, max_max_tokens)
        input_tokens = [1] * prompt_len
        num_return_sequences = 2 if i % 5 == 1 else 1

        request = trtllm.Request(input_tokens,
                                 max_tokens=max_tokens,
                                 streaming=streaming,
                                 sampling_config=trtllm.SamplingConfig(
                                     num_return_sequences=num_return_sequences),
                                 output_config=output_config,
                                 end_id=end_id)
        request_id = executor.enqueue_request(request)
        tokens[request_id] = [
            [] for _ in range(request.sampling_config.num_return_sequences)
        ]
        expected_num_tokens[request_id] = get_expected_num_tokens(
            prompt_len, max_tokens, streaming, exclude_input_from_output)

    # Get the new tokens for each request
    num_finished = 0
    i = 0
    num_responses = 0
    max_wait_ms = 10000
    while num_finished < num_requests and i < max_wait_ms:
        wait_time = datetime.timedelta(milliseconds=1)
        id_responses = executor.await_responses(list(tokens.keys()), wait_time)
        for responses in id_responses:
            for response in responses:
                num_responses += 1
                # Allow a response with an error only if await_responses processed a terminated request id
                if response.has_error():
                    terminated_request_error = "ReqId " + str(
                        response.request_id
                    ) + " has already been processed and was terminated."
                    assert response.error_msg == terminated_request_error, (
                        f"Request id {response.request_id} failed with err "
                        f"{response.error_msg}")
                else:
                    result = response.result
                    num_finished += result.is_final
                    new_tokens = result.output_token_ids[beam_width - 1]
                    tokens[response.request_id][result.sequence_index].extend(
                        new_tokens)
        i += 1
    assert i < max_wait_ms

    for request_id in expected_num_tokens:
        for seq_idx, actual_tokens in enumerate(tokens[request_id]):
            assert len(actual_tokens) == expected_num_tokens[request_id]


@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize("exclude_input_from_output", [False])
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_get_num_responses_ready(streaming: bool,
                                 exclude_input_from_output: bool, model_files,
                                 model_path):
    output_config = trtllm.OutputConfig()
    output_config.exclude_input_from_output = exclude_input_from_output

    # Create executor
    executor_config = trtllm.ExecutorConfig(1)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    max_prompt_len = 20
    max_max_tokens = 20

    # Enqueue the requests
    num_requests = random.randint(1, 50)
    num_expected_responses = 0
    req_num_expected_responses = {}
    for i in range(num_requests):
        prompt_len = random.randint(1, max_prompt_len)
        max_tokens = random.randint(1, max_max_tokens)
        num_return_sequences = 2 if i % 5 == 1 else 1

        request = trtllm.Request([1] * prompt_len,
                                 max_tokens=max_tokens,
                                 streaming=streaming,
                                 sampling_config=trtllm.SamplingConfig(
                                     num_return_sequences=num_return_sequences),
                                 output_config=output_config)
        request_id = executor.enqueue_request(request)
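        # Streaming requests yield one response per generated token and
        # sequence; non-streaming requests yield one final response per sequence.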
        req_num_expected_responses[request_id] = (
            (max_tokens if streaming else 1) * num_return_sequences)
        num_expected_responses += req_num_expected_responses[request_id]

    i = 0
    num_ready = 0
    max_wait_ms = 10000
    while num_ready < num_expected_responses and i < max_wait_ms:
        num_ready = 0
        for request_id in req_num_expected_responses:
            num_ready += executor.get_num_responses_ready(request_id)
        time.sleep(0.001)
        i += 1
    assert i < max_wait_ms

    for request_id in req_num_expected_responses:
        num_ready = executor.get_num_responses_ready(request_id)
        assert num_ready == req_num_expected_responses[request_id]
    assert executor.get_num_responses_ready() == num_expected_responses


@pytest.mark.parametrize("batching_type", [trtllm.BatchingType.INFLIGHT])
@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize("beam_width", [1])
@pytest.mark.parametrize("compute_log_probs", [False, True])
@pytest.mark.parametrize("exclude_input_from_output", [False])
@pytest.mark.parametrize("return_context_logits", [False, True])
@pytest.mark.parametrize("return_generation_logits", [False, True])
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_token_comparison(batching_type: trtllm.BatchingType, streaming: bool,
                          beam_width: int, compute_log_probs: bool,
                          exclude_input_from_output: bool,
                          return_context_logits: bool,
                          return_generation_logits: bool, model_files,
                          model_path, model_path_return_logits, input_data_path,
                          results_data_path, results_data_path_beam_width_2):
    if streaming and beam_width > 1:
        pytest.skip("Test does not support streaming with beam search")

    vocab_size_padded = 50257
    pad_id = 50256
    remove_input = not exclude_input_from_output and not streaming

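    # The reference data are padded int32 .npy arrays: inputs are padded with
    # pad_id, and the expected-output file holds one row per (input, beam) pair.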
    def load_test_data(input_path, results_path):
        # Inputs
        assert input_path.is_file()
        given_input = np.load(input_path).astype("int32")
        input_shape = given_input.shape
        assert len(input_shape) == 2
        max_input_length = input_shape[1]
        given_input_lengths = sequence_lengths(given_input, pad_id)
        assert np.all(given_input_lengths <= max_input_length)
        # Expected results
        assert results_path.is_file()
        expected_outputs = np.load(results_path).astype("int32")
        output_shape = expected_outputs.shape
        assert len(output_shape) == 2
        assert input_shape[0] * beam_width == output_shape[0]
        max_seq_length = output_shape[1]
        max_tokens = max_seq_length - max_input_length

        end_ids = [pad_id for _ in range(len(given_input_lengths))]
        expected_lengths = []
        for i in range(len(given_input_lengths)):
            expected_lengths.append([
                given_input_lengths[i] + max_tokens for _ in range(beam_width)
            ])

        test_data = {
            "expected_output_ids": expected_outputs,
            "expected_output_lengths": expected_lengths,
            "max_seq_length": max_seq_length,
            "end_ids": end_ids
        }
        return given_input, given_input_lengths, max_input_length, test_data

    def validate_results_shapes(result, input_length, max_output_len,
                                beam_tokens):
        if compute_log_probs:
            assert result.cum_log_probs is not None
            assert result.log_probs is not None
            assert len(result.cum_log_probs) == beam_width
            assert len(result.log_probs) == beam_width
            for beam in range(beam_width):
                expected_len = len(
                    beam_tokens[beam]) - (input_length if remove_input else 0)
                assert len(result.log_probs[beam]) == expected_len
        else:
            assert result.cum_log_probs is None
            assert result.log_probs is None
        if return_context_logits:
            assert result.context_logits is not None
            assert len(result.context_logits.shape) == 2
            assert list(result.context_logits.shape) == [
                input_length, vocab_size_padded
            ]
        else:
            assert result.context_logits is None
        if return_generation_logits:
            assert len(result.generation_logits.shape) == 3
            if streaming:
                assert list(result.generation_logits.shape) == [
                    max_output_len, beam_width, vocab_size_padded
                ] or list(result.generation_logits.shape) == [
                    1, beam_width, vocab_size_padded
                ]
            else:
                assert list(result.generation_logits.shape) == [
                    beam_width, max_output_len, vocab_size_padded
                ]

    def verify_output(beam_tokens, test_data, given_input_lengths):
        for batch_id, seq_tokens in beam_tokens.items():
            input_length = given_input_lengths[batch_id]
            end_id = test_data["end_ids"][batch_id]
            for tokens in seq_tokens:
                for beam in range(beam_width):
                    predicted_tokens = tokens[beam]
                    if remove_input:
                        predicted_tokens = predicted_tokens[input_length:]
                    expected_length = test_data["expected_output_lengths"][
                        batch_id][beam] - input_length
                    assert len(predicted_tokens) == expected_length
                    expected_tokens = test_data["expected_output_ids"][
                        batch_id * beam_width + beam][input_length:]
                    for i in range(len(predicted_tokens)):
                        if expected_tokens[i] == end_id:
                            break
                        assert predicted_tokens[i] == expected_tokens[i], \
                            f"Predicted: {predicted_tokens} vs Expected: {expected_tokens}"

    output_config = trtllm.OutputConfig()
    output_config.exclude_input_from_output = exclude_input_from_output
    output_config.return_log_probs = compute_log_probs
    output_config.return_generation_logits = return_generation_logits
    output_config.return_context_logits = return_context_logits

    kv_cache_config = trtllm.KvCacheConfig(False, free_gpu_memory_fraction=0.5)
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor_config.batching_type = batching_type
    executor_config.kv_cache_config = kv_cache_config

    if return_context_logits or return_generation_logits:
        model_path = model_path_return_logits
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    # Load test data
    results_path = results_data_path if beam_width == 1 else results_data_path_beam_width_2
    given_input, given_input_lengths, max_input_length, test_data = load_test_data(
        input_data_path, results_path)

    # Create requests from input data
    num_requests = len(given_input_lengths)
    requests = []
    req_max_tokens = []

    for i in range(num_requests):
        input_len = given_input_lengths[i]
        max_tokens = test_data["max_seq_length"] - max_input_length
        req_max_tokens.append(max_tokens)
        req_tokens = given_input[i][:input_len]
        num_return_sequences = 2 if i % 5 == 1 else 1
        request = trtllm.Request(req_tokens,
                                 max_tokens=max_tokens,
                                 streaming=streaming,
                                 sampling_config=trtllm.SamplingConfig(
                                     beam_width,
                                     num_return_sequences=num_return_sequences),
                                 output_config=output_config,
                                 end_id=-1)
        requests.append(request)

    req_ids = executor.enqueue_requests(requests)

    req_to_batch_id = {req_ids[i]: i for i in range(len(requests))}
    tokens = {
        i: [[[] for _ in range(beam_width)]
            for _ in range(req.sampling_config.num_return_sequences)]
        for i, req in enumerate(requests)
    }

    num_finished = 0
    i = 0
    num_responses = 0
    max_wait_ms = 10000
    while num_finished < num_requests and i < max_wait_ms:
        wait_time = datetime.timedelta(milliseconds=1)
        responses = executor.await_responses(wait_time)
        for response in responses:
            num_responses += 1
            assert not response.has_error(
            ), f"Request id {response.request_id} failed with err {response.error_msg}"
            result = response.result
            num_finished += result.is_final

            batch_id = req_to_batch_id[response.request_id]
            for beam in range(beam_width):
                new_tokens = result.output_token_ids[beam]
                tokens[batch_id][result.sequence_index][beam] += new_tokens

            validate_results_shapes(result, given_input_lengths[batch_id],
                                    req_max_tokens[batch_id],
                                    tokens[batch_id][result.sequence_index])
        i += 1
    assert i < max_wait_ms
    verify_output(tokens, test_data, given_input_lengths)


@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize("beam_width", [1])
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_finish_reason(streaming: bool, beam_width: int, model_files,
                       model_path):
    if streaming and beam_width > 1:
        pytest.skip("Test does not support streaming with beam search")
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               trtllm.ExecutorConfig(beam_width))
    requests = [
        # Finish due to length.
        trtllm.Request([1, 2, 3, 4],
                       max_tokens=5,
                       streaming=streaming,
                       sampling_config=trtllm.SamplingConfig(beam_width)),
        # Finish due to end id.
        trtllm.Request([1, 2, 3, 4],
                       max_tokens=5,
                       streaming=streaming,
                       sampling_config=trtllm.SamplingConfig(beam_width),
                       end_id=4),
        # Finish due to stop word.
        trtllm.Request([1, 2, 3, 4],
                       max_tokens=5,
                       streaming=streaming,
                       sampling_config=trtllm.SamplingConfig(beam_width),
                       stop_words=[[4, 2]]),
    ]
    req_ids = executor.enqueue_requests(requests)
    req_to_batch_id = {req_ids[i]: i for i in range(len(requests))}

    num_finished = 0
    i = 0
    num_responses = 0
    max_wait_ms = 10000
    while num_finished < len(requests) and i < max_wait_ms:
        wait_time = datetime.timedelta(milliseconds=1)
        responses = executor.await_responses(wait_time)
        for response in responses:
            num_responses += 1
            assert not response.has_error(
            ), f"Request id {response.request_id} failed with err {response.error_msg}"
            result = response.result
            num_finished += result.is_final
            batch_id = req_to_batch_id[response.request_id]

            # Non-final results should have "NOT_FINISHED". Revise this when streaming + beam_width > 1 is enabled.
            if not result.is_final:
                assert all([
                    r == trtllm.FinishReason.NOT_FINISHED
                    for r in result.finish_reasons
                ])
            # Check that the finish reason is correct.
            elif batch_id == 0:
                assert all([
                    r == trtllm.FinishReason.LENGTH
                    for r in result.finish_reasons
                ])
            elif batch_id == 1:
                assert all([
                    r == trtllm.FinishReason.END_ID
                    for r in result.finish_reasons
                ])
            elif batch_id == 2:
                assert all([
                    r == trtllm.FinishReason.STOP_WORDS
                    for r in result.finish_reasons
                ])
        i += 1
    assert i < max_wait_ms


@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_gpt_executor_timed_out(model_files, model_path):
    beam_width = 1
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    # No requests enqueued, expect no responses
    num_responses_ready = executor.get_num_responses_ready()
    assert num_responses_ready == 0

    wait_time = datetime.timedelta(milliseconds=10)
    responses = executor.await_responses(wait_time)
    assert len(responses) == 0


@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_single_request_invalid_inputs(model_files, model_path):
    streaming = True
    beam_width = 1
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    max_tokens = 5
    input_tokens = [1, 2, 3, 4]
    request = trtllm.Request(input_tokens,
                             max_tokens=max_tokens,
                             streaming=streaming)
    # Invalid embedding bias shape
    embedding_bias = torch.ones(1)
    request.embedding_bias = embedding_bias
    expected_error_msg = "embedding bias shape is not as expected"

    request_id = executor.enqueue_request(request)

    done = False
    i = 0
    max_wait_ms = 10000
    while not done and i < max_wait_ms:
        wait_time = datetime.timedelta(milliseconds=1)
        responses = executor.await_responses(request_id, wait_time)
        for response in responses:
            assert response.has_error(), "Expected an error"
            assert expected_error_msg in response.error_msg
            done = True
        i += 1
    assert done


def test_sampling_config():
    beam_width = 1
    kwargs = {
        "top_k": 2,
        "top_p": 1.0,
        "top_p_min": 1.0,
        "top_p_reset_ids": 3,
        "top_p_decay": 1.0,
        "seed": 7,
        "temperature": 1.0,
        "min_tokens": 4,
        "beam_search_diversity_rate": 1.0,
        "repetition_penalty": 1.0,
        "presence_penalty": 1.0,
        "frequency_penalty": 1.0,
        "length_penalty": 1.0,
        "early_stopping": 5,
        "num_return_sequences": 2,
    }
    config = trtllm.SamplingConfig(beam_width, **kwargs)
    for k, v in kwargs.items():
        assert getattr(config, k) == v
    del config

    config = trtllm.SamplingConfig(beam_width)
    assert config.beam_width == beam_width
    for k in kwargs:
        assert getattr(config, k) is None


def test_sampling_config_deprecated_args():
    # random_seed -> seed
    config = trtllm.SamplingConfig(seed=1)
    assert config.seed == 1
    assert config.random_seed == 1
    config = trtllm.SamplingConfig(random_seed=2)
    assert config.seed == 2
    assert config.random_seed == 2
    config = trtllm.SamplingConfig(seed=3, random_seed=4)
    assert config.seed == 3
    assert config.random_seed == 3

    # min_length -> min_tokens
    config = trtllm.SamplingConfig(min_tokens=1)
    assert config.min_tokens == 1
    assert config.min_length == 1
    config = trtllm.SamplingConfig(min_length=2)
    assert config.min_tokens == 2
    assert config.min_length == 2
    config = trtllm.SamplingConfig(min_tokens=3, min_length=4)
    assert config.min_tokens == 3
    assert config.min_length == 3


def test_output_config():
    config = trtllm.OutputConfig()
    assert config.return_log_probs == False
    assert config.return_context_logits == False
    assert config.return_generation_logits == False
    assert config.exclude_input_from_output == False
    assert config.return_encoder_output == False
    assert config.return_perf_metrics == False
    assert config.additional_model_outputs == None

    config = trtllm.OutputConfig(
        True, False, True, False, True, False,
        list([trtllm.AdditionalModelOutput("topKLogits", True)]))
    assert config.return_log_probs == True
    assert config.return_context_logits == False
    assert config.return_generation_logits == True
    assert config.exclude_input_from_output == False
    assert config.return_encoder_output == True
    assert config.return_perf_metrics == False
    assert len(config.additional_model_outputs) == 1
    additional_model_output = config.additional_model_outputs[0]
    assert additional_model_output.name == "topKLogits"
    assert additional_model_output.gather_context == True


def test_external_draft_tokens_config():
    tokens = [1, 2, 3]
    config = trtllm.ExternalDraftTokensConfig(tokens)
    assert config.tokens == tokens
    assert config.logits is None
    assert config.acceptance_threshold is None
    del config

    logits = torch.ones(3, 1)
    acceptance_threshold = 1.0
    fast_logits = False
    config = trtllm.ExternalDraftTokensConfig(tokens, logits,
                                              acceptance_threshold, fast_logits)
    assert config.tokens == tokens
    assert (config.logits == logits).all()
    assert config.acceptance_threshold == acceptance_threshold
    assert config.fast_logits == fast_logits


def test_prompt_tuning_config():
    embedding_table = torch.ones(100, 64)
    config = trtllm.PromptTuningConfig(embedding_table)
    assert (config.embedding_table == embedding_table).all()


def test_mrope_config():
    mrope_rotary_cos_sin = torch.ones(1, 4194304)
    mrope_position_deltas = torch.tensor([-50])
    config = trtllm.MropeConfig(mrope_rotary_cos_sin, mrope_position_deltas)
    assert (config.mrope_rotary_cos_sin == mrope_rotary_cos_sin).all()
    assert (config.mrope_position_deltas == mrope_position_deltas).all()


def test_lora_config():
    task_id = 1
    lora_config = trtllm.LoraConfig(task_id)
    assert lora_config.task_id == task_id
    assert lora_config.weights is None
    assert lora_config.config is None

    task_id = 2
    weights = torch.ones(1, 2)
    config = torch.ones(1, 2, dtype=torch.int32)
    lora_config = trtllm.LoraConfig(task_id, weights, config)
    assert lora_config.task_id == task_id
    assert (lora_config.weights == weights).all()
    assert (lora_config.config == config).all()


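# shutdown() must wake up a thread that is blocked in await_responses() with no
# timeout; otherwise the join() at the end of this test would hang.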
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_wakeup(model_files, model_path):
    import threading

    def resp_thread(stop_signal: threading.Event, executor: trtllm.Executor):
        while not stop_signal.is_set():
            timeout = None
            responses = executor.await_responses(timeout=timeout)
            if stop_signal.is_set():
                return
            for response in responses:
                response.result.output_token_ids

    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               trtllm.ExecutorConfig())
    stop_signal = threading.Event()
    thread = threading.Thread(target=resp_thread, args=(stop_signal, executor))
    thread.start()
    request = trtllm.Request(input_token_ids=[1, 2, 3, 4],
                             max_tokens=5,
                             streaming=True)
    executor.enqueue_request(request)
    time.sleep(2)
    stop_signal.set()
    executor.shutdown()
    thread.join()
    assert not thread.is_alive()


def test_guided_decoding_params():
    guided_decoding_params = trtllm.GuidedDecodingParams(
        trtllm.GuidedDecodingParams.GuideType.JSON)
    assert guided_decoding_params.guide_type == trtllm.GuidedDecodingParams.GuideType.JSON

    class Answer(BaseModel):
        answer: int

    json_schema = json.dumps(Answer.model_json_schema())
    guided_decoding_params = trtllm.GuidedDecodingParams(
        trtllm.GuidedDecodingParams.GuideType.JSON_SCHEMA, guide=json_schema)
    assert guided_decoding_params.guide_type == trtllm.GuidedDecodingParams.GuideType.JSON_SCHEMA
    assert guided_decoding_params.guide == json_schema

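    # JSON_SCHEMA requires an accompanying guide string; omitting it should raise.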
    with pytest.raises(Exception):
        trtllm.GuidedDecodingParams(
            trtllm.GuidedDecodingParams.GuideType.JSON_SCHEMA)

    regex = r"\d+"
    guided_decoding_params = trtllm.GuidedDecodingParams(
        trtllm.GuidedDecodingParams.GuideType.REGEX, guide=regex)
    assert guided_decoding_params.guide_type == trtllm.GuidedDecodingParams.GuideType.REGEX
    assert guided_decoding_params.guide == regex

    ebnf_grammar = "root ::= [0-9]+"
    guided_decoding_params = trtllm.GuidedDecodingParams(
        trtllm.GuidedDecodingParams.GuideType.EBNF_GRAMMAR, guide=ebnf_grammar)
    assert guided_decoding_params.guide_type == trtllm.GuidedDecodingParams.GuideType.EBNF_GRAMMAR
    assert guided_decoding_params.guide == ebnf_grammar


def test_request():
    kwargs = {
        "input_token_ids": [1, 2, 3],
        "max_tokens": 1,
        "streaming": False,
        "sampling_config": trtllm.SamplingConfig(),
        "output_config": trtllm.OutputConfig(),
        "end_id": -1,
        "pad_id": -2,
        "bad_words": [[4, 5, 6]],
        "stop_words": [[7, 8, 9]],
        "embedding_bias": torch.ones(1),
        "external_draft_tokens_config":
        trtllm.ExternalDraftTokensConfig([1, 2, 3]),
        "prompt_tuning_config": trtllm.PromptTuningConfig(torch.ones(100, 64)),
        "lora_config": trtllm.LoraConfig(1),
        "logits_post_processor_name": "my_logits_pp",
        "client_id": 1234,
    }
    request = trtllm.Request(**kwargs)
    for k, v in kwargs.items():
        if "config" not in k:
            assert getattr(request, k) == v
    assert isinstance(request.sampling_config, trtllm.SamplingConfig)
    assert isinstance(request.output_config, trtllm.OutputConfig)
    assert isinstance(request.external_draft_tokens_config,
                      trtllm.ExternalDraftTokensConfig)
    assert request.external_draft_tokens_config.tokens == [1, 2, 3]
    assert isinstance(request.prompt_tuning_config, trtllm.PromptTuningConfig)
    assert (request.prompt_tuning_config.embedding_table == torch.ones(
        100, 64)).all()
    assert isinstance(request.lora_config, trtllm.LoraConfig)


def test_request_deprecated_args():
    # max_new_tokens -> max_tokens
    request = trtllm.Request([1, 2, 3], max_tokens=10)
    assert request.max_tokens == 10
    assert request.max_new_tokens == 10
    request = trtllm.Request([1, 2, 3], max_new_tokens=20)
    assert request.max_tokens == 20
    assert request.max_new_tokens == 20
    request = trtllm.Request([1, 2, 3], max_tokens=30, max_new_tokens=40)
    assert request.max_tokens == 30
    assert request.max_new_tokens == 30


def test_spec_dec_fast_logits_info():
    fast_logits_info = trtllm.SpeculativeDecodingFastLogitsInfo()
    fast_logits_info.draft_request_id = 3
    fast_logits_info.draft_participant_id = 5
    assert fast_logits_info.draft_request_id == 3
    assert fast_logits_info.draft_participant_id == 5


def test_result():
    result = trtllm.Result()
    result.is_final = True
    result.output_token_ids = [[1, 2, 3]]
    result.cum_log_probs = [1.0, 2.0, 3.0]
    result.log_probs = [[1.0, 2.0, 3.0]]
    result.context_logits = torch.ones(3, 100)
    result.generation_logits = torch.ones(1, 3, 100)
    result.encoder_output = torch.ones(1, 1)
    result.finish_reasons = [trtllm.FinishReason.LENGTH]
    result.sequence_index = 1
    result.is_sequence_final = True
    result.additional_outputs = [
        trtllm.AdditionalOutput("topKLogits", torch.ones(1, 4, 100))
    ]
    assert result.is_final is True
    assert result.output_token_ids == [[1, 2, 3]]
    assert result.cum_log_probs == [1.0, 2.0, 3.0]
    assert result.log_probs == [[1.0, 2.0, 3.0]]
    assert (result.context_logits == torch.ones(3, 100)).all()
    assert (result.generation_logits == torch.ones(1, 3, 100)).all()
    assert (result.encoder_output == torch.ones(1, 1)).all()
    assert result.finish_reasons == [trtllm.FinishReason.LENGTH]
    assert result.sequence_index == 1
    assert result.is_sequence_final is True
    assert len(result.additional_outputs) == 1
    additional_output = result.additional_outputs[0]
    assert additional_output.name == "topKLogits"
    assert (additional_output.output == torch.ones(1, 4, 100)).all()


def test_response():
    request_id = 0
    error_msg = "error"
    response = trtllm.Response(request_id, error_msg)
    assert response.request_id == request_id
    assert response.has_error()
    assert response.error_msg == error_msg

    result = trtllm.Result()
    result.is_final = True
    result.output_token_ids = [[1, 2, 3]]
    request_id = 1
    response = trtllm.Response(request_id, result)
    assert response.request_id == request_id
    assert not response.has_error()
    assert response.result.is_final
    assert response.result.output_token_ids == [[1, 2, 3]]


def test_scheduler_config():
    capacity_scheduler_policy = trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
    config = trtllm.SchedulerConfig()
    assert config.capacity_scheduler_policy == capacity_scheduler_policy
    assert config.context_chunking_policy == None

    capacity_scheduler_policy = trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION
    config = trtllm.SchedulerConfig(capacity_scheduler_policy)
    assert config.capacity_scheduler_policy == capacity_scheduler_policy
    assert config.context_chunking_policy == None

    capacity_scheduler_policy = trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
    config = trtllm.SchedulerConfig(capacity_scheduler_policy)
    assert config.capacity_scheduler_policy == capacity_scheduler_policy
    assert config.context_chunking_policy == None

    capacity_scheduler_policy = trtllm.CapacitySchedulerPolicy.STATIC_BATCH
    config = trtllm.SchedulerConfig(capacity_scheduler_policy)
    assert config.capacity_scheduler_policy == capacity_scheduler_policy
    assert config.context_chunking_policy == None

    context_chunking_policy = trtllm.ContextChunkingPolicy.FIRST_COME_FIRST_SERVED
    config = trtllm.SchedulerConfig(capacity_scheduler_policy,
                                    context_chunking_policy)
    assert config.capacity_scheduler_policy == capacity_scheduler_policy
    assert config.context_chunking_policy == context_chunking_policy

    dynamic_batch_config = trtllm.DynamicBatchConfig(True, True, 128)
    config = trtllm.SchedulerConfig(capacity_scheduler_policy,
                                    context_chunking_policy,
                                    dynamic_batch_config)
    assert config.capacity_scheduler_policy == capacity_scheduler_policy
    assert config.context_chunking_policy == context_chunking_policy
    assert config.dynamic_batch_config.enable_batch_size_tuning == True
    assert config.dynamic_batch_config.enable_max_num_tokens_tuning == True
    assert config.dynamic_batch_config.dynamic_batch_moving_average_window == 128


def test_kv_cache_config():
    config = trtllm.KvCacheConfig()
    assert config.enable_block_reuse == True
    assert config.max_tokens is None
    assert config.max_attention_window is None
    assert config.sink_token_length is None
    assert config.free_gpu_memory_fraction is None
    assert config.cross_kv_cache_fraction is None
    assert config.host_cache_size is None
    assert config.onboard_blocks == True
    assert config.secondary_offload_min_priority == None
    assert config.event_buffer_max_size == 0

    config.enable_block_reuse = False
    config.max_tokens = 1
    config.max_attention_window = [2]
    config.sink_token_length = 3
    config.free_gpu_memory_fraction = 0.5
    config.cross_kv_cache_fraction = 0.5
    config.host_cache_size = 4
    config.onboard_blocks = False
    config.secondary_offload_min_priority = 50
    config.event_buffer_max_size = 1024
    assert config.enable_block_reuse == False
    assert config.max_tokens == 1
    assert config.max_attention_window == [2]
    assert config.sink_token_length == 3
    assert config.free_gpu_memory_fraction == 0.5
    assert config.cross_kv_cache_fraction == 0.5
    assert config.host_cache_size == 4
    assert config.onboard_blocks == False
    assert config.secondary_offload_min_priority == 50
    assert config.event_buffer_max_size == 1024

    kwargs = {
        "enable_block_reuse": True,
        "max_tokens": 3,
        "max_attention_window": [10],
        "sink_token_length": 2,
        "free_gpu_memory_fraction": 0.5,
        "cross_kv_cache_fraction": 0.5,
        "host_cache_size": 1024,
        "onboard_blocks": False,
        "event_buffer_max_size": 2048
    }
    config = trtllm.KvCacheConfig(**kwargs)
    for k, v in kwargs.items():
        assert getattr(config, k) == v

    config = trtllm.KvCacheConfig(**kwargs)
    max_attention_window, sink_token_length = config.max_attention_window, config.sink_token_length
    runtime_defaults = trtllm.RuntimeDefaults(
        max_attention_window=max_attention_window + [1],
        sink_token_length=sink_token_length + 1)

    config.fill_empty_fields_from_runtime_defaults(runtime_defaults)
    assert config.max_attention_window == max_attention_window, "runtime defaults shouldn't override existing values"
    assert config.sink_token_length == sink_token_length, "runtime defaults shouldn't override existing values"

    config = trtllm.KvCacheConfig(**{
        **kwargs, "max_attention_window": None,
        "sink_token_length": None
    })
    config.fill_empty_fields_from_runtime_defaults(runtime_defaults)
    assert config.max_attention_window == runtime_defaults.max_attention_window, "runtime defaults should apply to nonexistent values"
    assert config.sink_token_length == runtime_defaults.sink_token_length, "runtime defaults should apply to nonexistent values"

    config = trtllm.KvCacheConfig(**kwargs, runtime_defaults=runtime_defaults)
    setter_config = trtllm.KvCacheConfig(**kwargs)
    setter_config.fill_empty_fields_from_runtime_defaults(runtime_defaults)
    for k in kwargs.keys():
        assert getattr(config, k) == getattr(
            setter_config, k
        ), "passing runtime_defaults to the constructor or setting it manually should be equivalent"


def test_kv_cache_retention_config():

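    # TokenRangeRetentionConfig takes a token start/end, a priority, and an
    # optional duration; overlapping ranges are rejected (see the last case).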
    TokenRangeRetentionConfig = trtllm.KvCacheRetentionConfig.TokenRangeRetentionConfig

    config = trtllm.KvCacheRetentionConfig(
        [TokenRangeRetentionConfig(0, 2, 30, datetime.timedelta(seconds=30))],
        80)
    assert len(config.token_range_retention_configs) == 1
    assert config.token_range_retention_configs[0].token_start == 0
    assert config.token_range_retention_configs[0].token_end == 2
    assert config.token_range_retention_configs[0].priority == 30
    assert config.token_range_retention_configs[
        0].duration_ms == datetime.timedelta(seconds=30)
    assert config.decode_retention_priority == 80
    assert config.decode_duration_ms is None

    config = trtllm.KvCacheRetentionConfig([
        TokenRangeRetentionConfig(0, 64, 80),
        TokenRangeRetentionConfig(64, 100, 10)
    ], 10, datetime.timedelta(milliseconds=30000))

    assert len(config.token_range_retention_configs) == 2
    assert config.token_range_retention_configs[0].token_start == 0
    assert config.token_range_retention_configs[0].token_end == 64
    assert config.token_range_retention_configs[0].priority == 80
    assert config.token_range_retention_configs[0].duration_ms is None

    assert config.token_range_retention_configs[1].token_start == 64
    assert config.token_range_retention_configs[1].token_end == 100
    assert config.token_range_retention_configs[1].priority == 10
    assert config.token_range_retention_configs[1].duration_ms is None

    assert config.decode_retention_priority == 10
    assert config.decode_duration_ms == datetime.timedelta(seconds=30)

    with pytest.raises(Exception):
        # Invalid token ranges
        trtllm.KvCacheRetentionConfig([
            TokenRangeRetentionConfig(0, 64, 10),
            TokenRangeRetentionConfig(32, 128, 50)
        ], 50)


def test_lookahead_decoding_config():
    config = trtllm.LookaheadDecodingConfig(3, 5, 7)
    assert config.max_window_size == 3
    assert config.max_ngram_size == 5
    assert config.max_verification_set_size == 7

    config = trtllm.LookaheadDecodingConfig(5, 10, 3)
    assert config.max_window_size == 5
    assert config.max_ngram_size == 10
    assert config.max_verification_set_size == 3

    kwargs = {
        "max_window_size": 5,
        "max_ngram_size": 3,
        "max_verification_set_size": 7,
    }

    config = trtllm.LookaheadDecodingConfig(**kwargs)
    for k, v in kwargs.items():
        assert getattr(config, k) == v


def test_eagle_config():
    config = trtllm.EagleConfig([[0, 0], [0, 1]], False, 0.5)
    assert config.eagle_choices == [[0, 0], [0, 1]]
    assert config.greedy_sampling == False
    assert config.posterior_threshold == 0.5

    config = trtllm.EagleConfig([[0, 0], [0, 1, 0]], True)
    assert config.eagle_choices == [[0, 0], [0, 1, 0]]
    assert config.greedy_sampling == True
    assert config.posterior_threshold == None

    config = trtllm.EagleConfig(None, True, 0.5)
    assert config.eagle_choices == None
    assert config.greedy_sampling == True
    assert config.posterior_threshold == 0.5

    kwargs = {
        "eagle_choices": [[0, 0], [0, 1], [0, 2]],
        "greedy_sampling": True,
        "posterior_threshold": 0.5
    }

    config = trtllm.EagleConfig(**kwargs)
    for k, v in kwargs.items():
        assert getattr(config, k) == v


def test_decoding_mode():
    mode = trtllm.DecodingMode.Auto()
    assert mode.isAuto()

    mode = trtllm.DecodingMode.TopK()
    assert mode.isTopK()

    mode = trtllm.DecodingMode.TopP()
    assert mode.isTopP()

    mode = trtllm.DecodingMode.TopKTopP()
    assert mode.isTopKandTopP()

    mode = trtllm.DecodingMode.BeamSearch()
    assert mode.isBeamSearch()

    mode = trtllm.DecodingMode.Medusa()
    assert mode.isMedusa()

    mode = trtllm.DecodingMode.Lookahead()
    assert mode.isLookahead()

    mode = trtllm.DecodingMode.ExplicitDraftTokens()
    assert mode.isExplicitDraftTokens()

    mode = trtllm.DecodingMode.Eagle()
    assert mode.isEagle()


def test_speculative_decoding_config():
    config = trtllm.DecodingConfig()
    assert config.decoding_mode is None
    assert config.lookahead_decoding_config is None
    assert config.medusa_choices is None
    assert config.eagle_config is None

    config = trtllm.DecodingConfig()
    config.decoding_mode = trtllm.DecodingMode.TopKTopP()
    assert config.decoding_mode.isTopKandTopP()
    assert config.lookahead_decoding_config == None
    assert config.medusa_choices == None
    assert config.eagle_config is None

    config = trtllm.DecodingConfig()
    la_decoding_config = trtllm.LookaheadDecodingConfig(3, 5, 7)
    config.lookahead_decoding_config = la_decoding_config

    assert config.decoding_mode.isLookahead()
    assert config.lookahead_decoding_config.max_ngram_size == la_decoding_config.max_ngram_size
    assert config.lookahead_decoding_config.max_window_size == la_decoding_config.max_window_size
    assert config.lookahead_decoding_config.max_verification_set_size == la_decoding_config.max_verification_set_size
    assert config.medusa_choices == None
    assert config.eagle_config is None

    config = trtllm.DecodingConfig()
    config.medusa_choices = [[0, 0], [0, 1]]

    assert config.decoding_mode.isMedusa()
    assert config.lookahead_decoding_config == None
    assert config.medusa_choices == [[0, 0], [0, 1]]
    assert config.eagle_config is None

    config = trtllm.DecodingConfig()
    config.eagle_config = trtllm.EagleConfig([[0, 0], [0, 1]])

    assert config.decoding_mode.isEagle()
    assert config.lookahead_decoding_config == None
    assert config.medusa_choices == None
    assert config.eagle_config is not None
    assert config.eagle_config.eagle_choices == [[0, 0], [0, 1]]


def test_logits_post_processor_config():
    config = trtllm.LogitsPostProcessorConfig()
    assert config.processor_map == None
    assert config.processor_batched == None
    assert config.replicate == True

    kwargs = {
        "processor_map": {
            "test_pp": None
        },
        "processor_batched": None,
        "replicate": False
    }
    config = trtllm.LogitsPostProcessorConfig(**kwargs)
    for k, v in kwargs.items():
        assert getattr(config, k) == v


def test_guided_decoding_config():
    encoded_vocab = ["eos", "a", "b", "c", "d"]
    tokenizer_str = None
    stop_token_ids = [0]
    guided_decoding_config = trtllm.GuidedDecodingConfig(
        backend=trtllm.GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR,
        encoded_vocab=encoded_vocab,
        tokenizer_str=tokenizer_str,
        stop_token_ids=stop_token_ids)
    assert guided_decoding_config.backend == trtllm.GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR
    assert guided_decoding_config.encoded_vocab == encoded_vocab
    assert guided_decoding_config.tokenizer_str == tokenizer_str
    assert guided_decoding_config.stop_token_ids == stop_token_ids


def test_executor_config():
    config = trtllm.ExecutorConfig()
    assert config.max_beam_width == 1
    assert config.max_batch_size is None
    assert config.max_num_tokens is None
    assert isinstance(config.scheduler_config, trtllm.SchedulerConfig)
    assert isinstance(config.kv_cache_config, trtllm.KvCacheConfig)
    assert config.enable_chunked_context == False
    assert config.normalize_log_probs == True
    assert config.iter_stats_max_iterations == 1000
    assert config.batching_type == trtllm.BatchingType.INFLIGHT
    assert config.parallel_config is None
    assert isinstance(config.peft_cache_config, trtllm.PeftCacheConfig)
    assert config.logits_post_processor_config is None
    assert config.decoding_config is None
    assert config.debug_config is None
    assert config.recv_poll_period_ms == 0
    assert config.max_seq_idle_microseconds == 180000000
    assert config.spec_dec_config is None
    assert config.guided_decoding_config is None

    kwargs = {
        "max_beam_width":
        2,
        "max_batch_size":
        8,
        "max_num_tokens":
        128,
        "scheduler_config":
        trtllm.SchedulerConfig(trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION),
        "kv_cache_config":
        trtllm.KvCacheConfig(),
        "enable_chunked_context":
        True,
        "normalize_log_probs":
        False,
        "iter_stats_max_iterations":
        100,
        "batching_type":
        trtllm.BatchingType.STATIC,
        "parallel_config":
        trtllm.ParallelConfig(),
        "peft_cache_config":
        trtllm.PeftCacheConfig(10),
        "logits_post_processor_config":
        trtllm.LogitsPostProcessorConfig(),
        "decoding_config":
        trtllm.DecodingConfig(trtllm.DecodingMode.TopKTopP()),
        "extended_runtime_perf_knob_config":
        trtllm.ExtendedRuntimePerfKnobConfig(multi_block_mode=True),
        "debug_config":
        trtllm.DebugConfig(debug_input_tensors=True,
                           debug_output_tensors=True,
                           debug_tensor_names=["test"]),
        "recv_poll_period_ms":
        50,
        "max_seq_idle_microseconds":
        240 * 1000 * 1000,
        "spec_dec_config":
        trtllm.SpeculativeDecodingConfig(fast_logits=True),
        "guided_decoding_config":
        trtllm.GuidedDecodingConfig(
            trtllm.GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR,
            encoded_vocab=["eos", "a", "b", "c", "d"]),
        "additional_output_names": ["topKLogits"]
    }
    config = trtllm.ExecutorConfig(**kwargs)
    for k, v in kwargs.items():
        if "config" not in k:
            assert getattr(config, k) == v
    assert isinstance(config.scheduler_config, trtllm.SchedulerConfig)
    assert config.scheduler_config.capacity_scheduler_policy == trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION
    assert isinstance(config.kv_cache_config, trtllm.KvCacheConfig)
    assert isinstance(config.parallel_config, trtllm.ParallelConfig)
    assert isinstance(config.peft_cache_config, trtllm.PeftCacheConfig)
    assert config.extended_runtime_perf_knob_config.multi_block_mode == True
    assert isinstance(config.debug_config, trtllm.DebugConfig)
    assert isinstance(config.logits_post_processor_config,
                      trtllm.LogitsPostProcessorConfig)
    assert isinstance(config.spec_dec_config, trtllm.SpeculativeDecodingConfig)
    assert isinstance(config.guided_decoding_config,
                      trtllm.GuidedDecodingConfig)
    assert isinstance(config.additional_output_names, list)
    assert len(config.additional_output_names) == 1
    assert config.additional_output_names[0] == "topKLogits"


def test_parallel_config():
|
|
comm_type = trtllm.CommunicationType.MPI
|
|
comm_mode = trtllm.CommunicationMode.LEADER
|
|
device_ids = [0, 1, 2, 3]
|
|
participant_ids = [4, 5, 6, 7]
|
|
parallel_config = trtllm.ParallelConfig(comm_type, comm_mode, device_ids,
|
|
participant_ids)
|
|
assert parallel_config.communication_type == comm_type
|
|
assert parallel_config.communication_mode == comm_mode
|
|
assert parallel_config.device_ids == device_ids
|
|
assert parallel_config.participant_ids == participant_ids
|
|
|
|
comm_mode = trtllm.CommunicationMode.ORCHESTRATOR
|
|
#Dummy path to worker executable
|
|
worker_path = _os.path.abspath(__file__)
|
|
orchestrator_config = trtllm.OrchestratorConfig(True, str(worker_path),
|
|
None, True)
|
|
parallel_config = trtllm.ParallelConfig(comm_type, comm_mode, device_ids,
|
|
participant_ids,
|
|
orchestrator_config)
|
|
assert parallel_config.communication_mode == comm_mode
|
|
assert parallel_config.orchestrator_config.is_orchestrator == True
|
|
assert parallel_config.orchestrator_config.worker_executable_path == str(
|
|
worker_path)
|
|
assert parallel_config.orchestrator_config.spawn_processes == True
|
|
|
|
|
|
def test_peft_cache_config():
    num_host_module_layer = 1
    num_device_module_layer = 2
    optimal_adapter_size = 3
    max_adapter_size = 4
    num_put_workers = 5
    num_ensure_workers = 6
    num_copy_streams = 7
    max_pages_per_block_host = 8
    max_pages_per_block_device = 9
    device_cache_percent = 0.9
    host_cache_size = 1024
    peft_cache_config = trtllm.PeftCacheConfig(
        num_host_module_layer, num_device_module_layer, optimal_adapter_size,
        max_adapter_size, num_put_workers, num_ensure_workers, num_copy_streams,
        max_pages_per_block_host, max_pages_per_block_device,
        device_cache_percent, host_cache_size)

    assert peft_cache_config.num_host_module_layer == num_host_module_layer
    assert peft_cache_config.num_device_module_layer == num_device_module_layer
    assert peft_cache_config.optimal_adapter_size == optimal_adapter_size
    assert peft_cache_config.max_adapter_size == max_adapter_size
    assert peft_cache_config.num_put_workers == num_put_workers
    assert peft_cache_config.num_ensure_workers == num_ensure_workers
    assert peft_cache_config.num_copy_streams == num_copy_streams
    assert peft_cache_config.max_pages_per_block_host == max_pages_per_block_host
    assert peft_cache_config.max_pages_per_block_device == max_pages_per_block_device
    assert np.isclose(peft_cache_config.device_cache_percent,
                      device_cache_percent)
    assert peft_cache_config.host_cache_size == host_cache_size


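# The executor round-trip tests below all follow the same polling pattern:
# enqueue a request, then repeatedly call await_responses() with a short
# timeout until the final response arrives. The helper below is a condensed,
# illustrative sketch of that pattern; it is not used by the tests, and the
# name `_collect_tokens` and the `max_polls` budget are our own choices rather
# than part of the bindings.
def _collect_tokens(executor: trtllm.Executor,
                    request_id: int,
                    max_polls: int = 10000) -> tp.List[int]:
    tokens: tp.List[int] = []
    done = False
    polls = 0
    while not done and polls < max_polls:
        # Poll with a 1 ms timeout, mirroring the tests below.
        responses = executor.await_responses(
            request_id, datetime.timedelta(milliseconds=1))
        for response in responses:
            assert not response.has_error(), response.error_msg
            result = response.result
            done = result.is_final
            # Beam 0 tokens; the tests run with beam_width == 1.
            tokens.extend(result.output_token_ids[0])
        polls += 1
    return tokens

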
@skip_pre_ampere # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_logits_post_processor(model_files, model_path):

    # Define the logits post-processor callback
    def logits_post_processor(req_id: int, logits: torch.Tensor,
                              ids: tp.List[tp.List[int]], stream_ptr: int,
                              client_id: tp.Optional[int]):
        assert client_id == 123
        with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)):
            logits[:] = float("-inf")
            logits[..., 42] = 0

    # Create executor
    beam_width = 1
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor_config.logits_post_processor_config = trtllm.LogitsPostProcessorConfig(
        {"my_logits_pp": logits_post_processor})
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    # Create the request
    max_tokens = 5
    input_tokens = [1, 2, 3, 4]
    request = trtllm.Request(input_tokens,
                             max_tokens=max_tokens,
                             streaming=False,
                             client_id=123)
    request.logits_post_processor_name = "my_logits_pp"

    # Enqueue the request
    request_id = executor.enqueue_request(request)

    # Get the new tokens
    tokens = []
    done = False
    i = 0
    max_wait_ms = 10000
    while not done and i < max_wait_ms:
        wait_time = datetime.timedelta(milliseconds=1)
        responses = executor.await_responses(request_id, wait_time)
        for response in responses:
            assert not response.has_error(
            ), f"Request id {request_id} failed with err {response.error_msg}"
            result = response.result
            done = result.is_final
            new_tokens = result.output_token_ids[beam_width - 1]
            tokens.extend(new_tokens)
        i += 1
    assert i < max_wait_ms
    assert len(tokens) == get_expected_num_tokens(len(input_tokens), max_tokens,
                                                  False, False), f"{request_id}"

    # check that all output tokens are 42
    print(tokens)
    assert tokens[-max_tokens:] == [42] * max_tokens


@skip_pre_ampere # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_logits_post_processor_batched(model_files, model_path):

    # Define the logits post-processor callback
    def logits_post_processor_batched(
            req_id_batch: tp.List[int], logits_batch: tp.List[torch.Tensor],
            ids_batch: tp.List[tp.List[tp.List[int]]], stream_ptr: int,
            client_id_batch: tp.List[tp.Optional[int]]):
        for client_id in client_id_batch:
            assert client_id == 123
        with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)):
            for logits in logits_batch:
                logits[:] = float("-inf")
                logits[..., 42] = 0

    # Create executor
    beam_width = 1
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor_config.logits_post_processor_config = trtllm.LogitsPostProcessorConfig(
        None, logits_post_processor_batched)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    # Create the request
    max_tokens = 5
    input_tokens = [1, 2, 3, 4]
    request = trtllm.Request(input_tokens,
                             max_tokens=max_tokens,
                             streaming=False,
                             client_id=123)
    request.logits_post_processor_name = request.BATCHED_POST_PROCESSOR_NAME

    batch_size = 4
    # Enqueue the requests
    request_ids = []
    for _ in range(batch_size):
        request_id = executor.enqueue_request(request)
        request_ids.append(request_id)

    # Get the new tokens
    tokens = {req_id: [] for req_id in request_ids}
    num_finished = 0
    i = 0
    max_wait_ms = 10000
    while num_finished < len(request_ids) and i < max_wait_ms:
        responses = executor.await_responses(datetime.timedelta(milliseconds=1))
        for response in responses:
            req_id = response.request_id
            assert not response.has_error(
            ), f"Request id {req_id} failed with err {response.error_msg}"
            result = response.result
            num_finished += 1 if result.is_final else 0
            new_tokens = result.output_token_ids[beam_width - 1]
            tokens[req_id].extend(new_tokens)
        # Count polling iterations so the max_wait_ms guard below can trip.
        i += 1
    assert i < max_wait_ms

    expected_num_tokens = get_expected_num_tokens(len(input_tokens), max_tokens,
                                                  False, False)
    for req_id in request_ids:
        assert len(tokens[req_id]) == expected_num_tokens, f"{req_id}"


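# A small factory sketch distilled from the two tests above: it builds a
# per-request logits post-processor that forces every generation step to emit
# `token_id`. It is illustrative only (nothing below uses it); the factory
# name is ours, but the callback signature matches the one registered via
# LogitsPostProcessorConfig in the tests.
def make_force_token_post_processor(token_id: int):

    def _post_processor(req_id: int, logits: torch.Tensor,
                        ids: tp.List[tp.List[int]], stream_ptr: int,
                        client_id: tp.Optional[int]) -> None:
        # Apply the in-place logits edit on the stream handed over by the executor.
        with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)):
            logits[:] = float("-inf")
            logits[..., token_id] = 0

    return _post_processor

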
def test_kv_event_stream(model_path):

    beam_width = 1
    executor_config = trtllm.ExecutorConfig(
        beam_width,
        kv_cache_config=trtllm.KvCacheConfig(True,
                                             4 * 64,
                                             event_buffer_max_size=1024,
                                             host_cache_size=3000000))

    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    cache_manager = executor.get_kv_cache_event_manager()

    events = cache_manager.get_latest_events()

    assert len(events) == 1
    assert isinstance(events[0], trtllm.kv_cache.KVCacheEvent)
    assert events[0].event_id == 0
    assert isinstance(events[0].data, trtllm.kv_cache.KVCacheCreatedData)

    for req in range(2):
        input_tokens = list(range(req, req + 127))
        request = trtllm.Request(input_tokens,
                                 max_tokens=5,
                                 streaming=False,
                                 sampling_config=trtllm.SamplingConfig())

        id = executor.enqueue_request(request)

        responses = executor.await_responses(id)

        for response in responses:
            assert not response.has_error()
            if response.result.is_final:
                time.sleep(0.1)
                events = cache_manager.get_latest_events(
                    datetime.timedelta(milliseconds=100))

                if req == 0:
                    assert events[0].event_id == 1
                    assert isinstance(events[0].data,
                                      trtllm.kv_cache.KVCacheStoredData)
                    assert events[0].data.parent_hash is None
                    assert len(events[0].data.blocks) == 1

                    assert events[1].data.parent_hash == events[0].data.blocks[
                        0].block_hash
                    assert len(events[1].data.blocks) == 2
                else:
                    # Swap a block to secondary
                    assert isinstance(events[0].data,
                                      trtllm.kv_cache.KVCacheUpdatedData)
                    assert events[0].data.cache_level.old_value == 0
                    assert events[0].data.cache_level.new_value == 1
                    # Store the filled context block
                    assert isinstance(events[1].data,
                                      trtllm.kv_cache.KVCacheStoredData)
                    assert len(events[1].data.blocks) == 1
                    assert events[1].data.parent_hash is None
                    # Swap another block to secondary
                    assert isinstance(events[2].data,
                                      trtllm.kv_cache.KVCacheUpdatedData)
                    assert events[2].data.cache_level.old_value == 0
                    assert events[2].data.cache_level.new_value == 1
                    assert isinstance(events[2].data.cache_level,
                                      trtllm.kv_cache.KVCacheEventDiffInt)
                    # Remove the first block in secondary
                    assert isinstance(events[3].data,
                                      trtllm.kv_cache.KVCacheRemovedData)
                    assert len(events[3].data.block_hashes) == 1
                    assert events[3].data.block_hashes[0] == events[
                        0].data.block_hash
                    # Store the second context block and the decode block
                    assert isinstance(events[4].data,
                                      trtllm.kv_cache.KVCacheStoredData)
                    assert len(events[4].data.blocks) == 2
                    assert events[4].data.parent_hash == events[1].data.blocks[
                        0].block_hash


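# Illustrative helper (not called by any test): drain whatever KV cache events
# are currently buffered on an executor, using the same 100 ms polling timeout
# as the test above. The helper name and return convention are our own.
def _drain_kv_cache_events(executor: trtllm.Executor) -> tp.List:
    manager = executor.get_kv_cache_event_manager()
    drained: tp.List = []
    while True:
        batch = manager.get_latest_events(datetime.timedelta(milliseconds=100))
        if not batch:
            return drained
        drained.extend(batch)

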
@pytest.mark.parametrize("streaming", [False, True])
def test_request_perf_metrics(streaming: bool, model_path):

    # Create executor
    beam_width = 1
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    # Create request
    max_tokens = 5
    input_tokens = [1, 2, 3, 4]
    output_config = trtllm.OutputConfig(return_perf_metrics=True)
    request = trtllm.Request(input_tokens,
                             max_tokens=max_tokens,
                             streaming=streaming,
                             output_config=output_config)

    # Enqueue the request
    request_id = executor.enqueue_request(request)

    def check_perf_metrics(perf_metrics, done, response_id):
        assert perf_metrics is not None

        timing_metrics = perf_metrics.timing_metrics
        assert timing_metrics.arrival_time < timing_metrics.first_scheduled_time
        assert timing_metrics.first_scheduled_time < timing_metrics.first_token_time
        if done:
            assert timing_metrics.first_token_time < timing_metrics.last_token_time
        else:
            assert timing_metrics.last_token_time == datetime.timedelta(0)

        kv_cache_metrics = perf_metrics.kv_cache_metrics
        assert kv_cache_metrics.num_total_allocated_blocks == 1
        assert kv_cache_metrics.num_new_allocated_blocks == 1
        assert kv_cache_metrics.num_reused_blocks == 0
        assert kv_cache_metrics.num_missed_blocks == 1
        assert kv_cache_metrics.kv_cache_hit_rate == 0

        assert perf_metrics.first_iter == 0
        if done:
            assert perf_metrics.iter == (max_tokens - 1)
            assert perf_metrics.last_iter == max_tokens - 1
        else:
            assert perf_metrics.iter == response_id
            assert perf_metrics.last_iter is None

    # Get the new tokens
    tokens = []
    done = False
    i = 0
    max_wait_ms = 10000
    response_id = 0
    while not done and i < max_wait_ms:
        wait_time = datetime.timedelta(milliseconds=1)
        responses = executor.await_responses(request_id, wait_time)
        for response in responses:
            assert not response.has_error(
            ), f"Request id {request_id} failed with err {response.error_msg}"
            result = response.result
            done = result.is_final
            check_perf_metrics(result.request_perf_metrics, done, response_id)
            new_tokens = result.output_token_ids[beam_width - 1]
            tokens.extend(new_tokens)
            response_id += 1
        i += 1
    assert i < max_wait_ms
    assert len(tokens) == get_expected_num_tokens(
        len(input_tokens),
        max_tokens,
        streaming=streaming,
        exclude_input_from_output=False), f"{request_id}"


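# Note on the perf-metrics test above: request-level metrics are opt-in and
# are only populated because the request was built with
# OutputConfig(return_perf_metrics=True); without that flag the field is
# presumably left unset.

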
def test_kv_event_stream_timeout(model_path):

    beam_width = 1
    executor_config = trtllm.ExecutorConfig(
        beam_width,
        kv_cache_config=trtllm.KvCacheConfig(True,
                                             4 * 64,
                                             event_buffer_max_size=1024))

    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    cache_manager = executor.get_kv_cache_event_manager()

    events = cache_manager.get_latest_events()
    assert len(events) == 1

    start = datetime.datetime.now()
    events = cache_manager.get_latest_events(datetime.timedelta(seconds=1))
    end = datetime.datetime.now()
    # Make sure that it actually waited
    assert abs(end - start) > datetime.timedelta(milliseconds=900)
    assert len(events) == 0


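# The stats objects exercised below expose snake_case attributes in Python,
# while to_json_str() serializes them under camelCase keys (for example
# iter_latency_ms -> "iterLatencyMS"); the following tests pin that mapping.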
def test_iteration_stats():
    stats = trtllm.IterationStats()
    stats.timestamp = "01:23:56"
    stats.iter = 1
    stats.iter_latency_ms = 100
    stats.num_active_requests = 2
    stats.num_queued_requests = 10
    stats.max_num_active_requests = 3
    stats.gpu_mem_usage = 1024
    stats.cpu_mem_usage = 2048
    stats.pinned_mem_usage = 4096
    stats_json = json.loads(stats.to_json_str())
    assert stats_json["timestamp"] == stats.timestamp
    assert stats_json["iter"] == stats.iter
    assert stats_json["iterLatencyMS"] == stats.iter_latency_ms
    assert stats_json[
        "newActiveRequestsQueueLatencyMS"] == stats.new_active_requests_queue_latency_ms
    assert stats_json["numActiveRequests"] == stats.num_active_requests
    assert stats_json["numQueuedRequests"] == stats.num_queued_requests
    assert stats_json["numCompletedRequests"] == stats.num_completed_requests
    assert stats_json["maxNumActiveRequests"] == stats.max_num_active_requests
    assert stats_json["gpuMemUsage"] == stats.gpu_mem_usage
    assert stats_json["cpuMemUsage"] == stats.cpu_mem_usage
    assert stats_json["pinnedMemUsage"] == stats.pinned_mem_usage
    assert stats_json["kvCacheStats"] is None
    assert stats_json["staticBatchingStats"] is None
    assert stats_json["inflightBatchingStats"] is None


def test_request_stats():
    stats = trtllm.RequestStats()
    stats.id = 1
    stats.stage = trtllm.RequestStage.CONTEXT_IN_PROGRESS
    stats.context_prefill_position = 2
    stats.num_generated_tokens = 3
    stats.avg_num_decoded_tokens_per_iter = 2.5
    stats.scheduled = True
    stats.paused = False
    stats_json = json.loads(stats.to_json_str())
    assert stats_json["id"] == stats.id
    assert stats_json["stage"] == "CONTEXT_IN_PROGRESS"
    assert stats_json[
        "contextPrefillPosition"] == stats.context_prefill_position
    assert stats_json["numGeneratedTokens"] == stats.num_generated_tokens
    assert stats_json[
        "avgNumDecodedTokensPerIter"] == stats.avg_num_decoded_tokens_per_iter
    assert stats_json["scheduled"] == stats.scheduled
    assert stats_json["paused"] == stats.paused
    assert stats_json["disServingStats"] is None


def test_request_stats_per_iteration():
    stats = trtllm.RequestStatsPerIteration()
    stats.iter = 1
    req_stat = trtllm.RequestStats()
    req_stat.id = 1
    stats.request_stats = [req_stat]
    stats_json = json.loads(stats.to_json_str())
    assert stats_json["iter"] == 1
    assert stats_json["requestStats"][0]["id"] == 1


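# The *_pickle tests below check that config and parameter objects survive a
# pickle round trip with their fields intact; picklability matters when these
# objects have to cross process boundaries (for example in orchestrator/worker
# setups).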
def test_scheduler_config_pickle():
    policy = trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION
    config = trtllm.SchedulerConfig(policy)
    config_str = pickle.dumps(config)
    config_copy = pickle.loads(config_str)
    assert config.capacity_scheduler_policy == config_copy.capacity_scheduler_policy


def test_kv_cache_config_pickle():
    config = trtllm.KvCacheConfig()
    config.enable_block_reuse = True
    config.free_gpu_memory_fraction = 0.3
    config.cross_kv_cache_fraction = 0.5
    config.event_buffer_max_size = 1024
    config_copy = pickle.loads(pickle.dumps(config))
    assert config.enable_block_reuse == config_copy.enable_block_reuse
    assert config.max_tokens == config_copy.max_tokens
    assert config.max_attention_window == config_copy.max_attention_window
    assert config.sink_token_length == config_copy.sink_token_length
    assert config.free_gpu_memory_fraction == config_copy.free_gpu_memory_fraction
    assert config.cross_kv_cache_fraction == config_copy.cross_kv_cache_fraction
    assert config.host_cache_size == config_copy.host_cache_size
    assert config.onboard_blocks == config_copy.onboard_blocks
    assert config.event_buffer_max_size == config_copy.event_buffer_max_size


def test_peft_cache_config_pickle():
    config = trtllm.PeftCacheConfig(1, 2, 3, 4, 5, 6, 7, 8, 9, 0.9, 1024)
    config_copy = pickle.loads(pickle.dumps(config))
    assert config.num_host_module_layer == config_copy.num_host_module_layer
    assert config.num_device_module_layer == config_copy.num_device_module_layer
    assert config.optimal_adapter_size == config_copy.optimal_adapter_size
    assert config.max_adapter_size == config_copy.max_adapter_size
    assert config.num_put_workers == config_copy.num_put_workers
    assert config.num_ensure_workers == config_copy.num_ensure_workers
    assert config.num_copy_streams == config_copy.num_copy_streams
    assert config.max_pages_per_block_host == config_copy.max_pages_per_block_host
    assert config.max_pages_per_block_device == config_copy.max_pages_per_block_device
    assert config.device_cache_percent == config_copy.device_cache_percent
    assert config.host_cache_size == config_copy.host_cache_size


def test_decoding_config_pickle():
    config = trtllm.DecodingConfig(
        decoding_mode=trtllm.DecodingMode.BeamSearch())
    config_copy = pickle.loads(pickle.dumps(config))
    assert config_copy.decoding_mode.isBeamSearch
    assert config.lookahead_decoding_config == config_copy.lookahead_decoding_config
    assert config.medusa_choices == config_copy.medusa_choices


def test_debug_config_pickle():
    config = trtllm.DebugConfig(debug_input_tensors=True,
                                debug_output_tensors=True,
                                debug_tensor_names=["test"],
                                debug_tensors_max_iterations=5)
    config_copy = pickle.loads(pickle.dumps(config))
    assert config.debug_input_tensors == config_copy.debug_input_tensors
    assert config.debug_output_tensors == config_copy.debug_output_tensors
    assert config.debug_tensor_names == config_copy.debug_tensor_names
    assert config.debug_tensors_max_iterations == config_copy.debug_tensors_max_iterations


def test_logits_post_processor_config_pickle():
    kwargs = {
        "processor_map": {
            "test_pp": None
        },
        "processor_batched": None,
        "replicate": False
    }
    config = trtllm.LogitsPostProcessorConfig(**kwargs)
    config_copy = pickle.loads(pickle.dumps(config))
    for k in kwargs:
        assert getattr(config, k) == getattr(config_copy, k)


def test_guided_decoding_params_pickle():

    class Answer(BaseModel):
        answer: int

    json_schema = json.dumps(Answer.model_json_schema())
    params = trtllm.GuidedDecodingParams(
        trtllm.GuidedDecodingParams.GuideType.JSON_SCHEMA, guide=json_schema)
    params_copy = pickle.loads(pickle.dumps(params))
    assert params_copy.guide_type == params.guide_type
    assert params_copy.guide == params.guide


def test_guided_decoding_config_pickle():
    encoded_vocab = ["eos", "a", "b", "c", "d"]
    tokenizer_str = None
    stop_token_ids = [0]
    config = trtllm.GuidedDecodingConfig(
        backend=trtllm.GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR,
        encoded_vocab=encoded_vocab,
        tokenizer_str=tokenizer_str,
        stop_token_ids=stop_token_ids)
    config_copy = pickle.loads(pickle.dumps(config))
    assert config_copy.backend == config.backend
    assert config_copy.encoded_vocab == config.encoded_vocab
    assert config_copy.tokenizer_str == config.tokenizer_str
    assert config_copy.stop_token_ids == config.stop_token_ids


def test_executor_config_pickle():
    beam_width = 2
    config = trtllm.ExecutorConfig(beam_width)

    kwargs = {
        "max_beam_width":
        2,
        "max_batch_size":
        8,
        "max_num_tokens":
        128,
        "scheduler_config":
        trtllm.SchedulerConfig(trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION),
        "kv_cache_config":
        trtllm.KvCacheConfig(enable_block_reuse=True),
        "enable_chunked_context":
        True,
        "normalize_log_probs":
        False,
        "iter_stats_max_iterations":
        100,
        "batching_type":
        trtllm.BatchingType.STATIC,
        "parallel_config":
        trtllm.ParallelConfig(),
        "peft_cache_config":
        trtllm.PeftCacheConfig(10),
        "logits_post_processor_config":
        trtllm.LogitsPostProcessorConfig(),
        "decoding_config":
        trtllm.DecodingConfig(trtllm.DecodingMode.TopKTopP()),
        "extended_runtime_perf_knob_config":
        trtllm.ExtendedRuntimePerfKnobConfig(multi_block_mode=True),
        "debug_config":
        trtllm.DebugConfig(debug_input_tensors=True,
                           debug_output_tensors=True,
                           debug_tensor_names=["test"]),
        "recv_poll_period_ms":
        50,
        "max_seq_idle_microseconds":
        240 * 1000 * 1000,
        "spec_dec_config":
        trtllm.SpeculativeDecodingConfig(fast_logits=True)
    }
    config = trtllm.ExecutorConfig(**kwargs)
    for k, v in kwargs.items():
        if "config" not in k:
            assert getattr(config, k) == v

    config.backend = 'pytorch'

    pickle.dumps(config)
    config_copy = pickle.loads(pickle.dumps(config))
    assert config.max_beam_width == config_copy.max_beam_width
    assert config.max_batch_size == config_copy.max_batch_size
    assert config.max_num_tokens == config_copy.max_num_tokens
    assert config.scheduler_config.capacity_scheduler_policy == config_copy.scheduler_config.capacity_scheduler_policy
    assert config.kv_cache_config.enable_block_reuse == config_copy.kv_cache_config.enable_block_reuse
    assert config.enable_chunked_context == config_copy.enable_chunked_context
    assert config.normalize_log_probs == config_copy.normalize_log_probs
    assert config.iter_stats_max_iterations == config_copy.iter_stats_max_iterations
    assert config.batching_type == config_copy.batching_type
    assert config.parallel_config.communication_type == config_copy.parallel_config.communication_type
    assert config.peft_cache_config.num_host_module_layer == config_copy.peft_cache_config.num_host_module_layer
    assert config_copy.decoding_config.decoding_mode.isTopKandTopP
    assert config.extended_runtime_perf_knob_config.multi_block_mode == config_copy.extended_runtime_perf_knob_config.multi_block_mode
    assert config.debug_config.debug_input_tensors == config_copy.debug_config.debug_input_tensors
    assert config.max_seq_idle_microseconds == config_copy.max_seq_idle_microseconds
    assert config.backend == config_copy.backend
    assert config.spec_dec_config.fast_logits == config_copy.spec_dec_config.fast_logits


def test_return_full_tokens():
    max_tokens = 5
    input_tokens = [1, 2, 3, 4]
    request = trtllm.Request(input_tokens,
                             max_tokens=max_tokens,
                             streaming=False,
                             sampling_config=trtllm.SamplingConfig())
    request.return_all_generated_tokens = True
    assert request.return_all_generated_tokens == True
    request.return_all_generated_tokens = False
    assert request.return_all_generated_tokens == False


def test_getters_return_references():
    config = trtllm.ExecutorConfig()
    # Make sure kv_cache_config is a reference. Returning a value
    # will lead to the very confusing behavior of this set statement
    # not working.
    config.kv_cache_config.max_tokens = 42
    assert config.kv_cache_config.max_tokens == 42


def test_allotted_time_ms():
    allotted_time = datetime.timedelta(milliseconds=2)
    input_tokens = [1, 2, 3, 4]

    max_new_tokens = 5
    request = trtllm.Request(input_tokens, max_tokens=max_new_tokens)

    request.allotted_time_ms = allotted_time

    assert request.allotted_time_ms == datetime.timedelta(milliseconds=2)


def test_executor_version():
    assert trtllm.__version__ == trtllm_version.__version__


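# Helper for test_runtime_defaults below: enumerate the public, non-callable
# attributes and properties exposed by a bound class, so the test can compare
# them against the fields expected on trtllm.RuntimeDefaults.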
def get_all_field_names_of_class(cls: type) -> list[str]:
    return [
        name for name, obj in inspect.getmembers(cls)
        if isinstance(obj, property) or (
            not callable(obj) and not name.startswith('__'))
    ]


def test_runtime_defaults():
    full_runtime_defaults: dict[str, tp.Any] = json.loads("""{
        "max_attention_window": [1, 2],
        "sink_token_length": 4
    }""")
    all_field_names = set(full_runtime_defaults)

    assert set(
        get_all_field_names_of_class(trtllm.RuntimeDefaults)
    ) == all_field_names, "Expected fields of runtime_defaults to match actual data"

    msg = """\
Rather than create a `from_dict` on top of the bound class, \
we rely on being able to directly provide the dict created from raw json as kwargs to `RuntimeDefaults.` \
See: `PretrainedConfig.__init__()`"""

    assert PretrainedConfig.create_runtime_defaults(
        full_runtime_defaults) is not None, msg

    default_runtime_defaults = trtllm.RuntimeDefaults()
    for key in all_field_names:
        assert getattr(default_runtime_defaults, key) == None


def test_DynamicBatchConfig_pickle():
    config = trtllm.DynamicBatchConfig(enable_batch_size_tuning=True,
                                       enable_max_num_tokens_tuning=True,
                                       dynamic_batch_moving_average_window=128)
    config_copy = pickle.loads(pickle.dumps(config))
    assert config.enable_batch_size_tuning == config_copy.enable_batch_size_tuning