TensorRT-LLMs/tests/bindings/test_executor_bindings.py
Sharan Chetlur 258c7540c0 open source 09df54c0cc99354a60bbc0303e3e8ea33a96bef0 (#2725)
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>

open source f8c0381a2bc50ee2739c3d8c2be481b31e5f00bd (#2736)

Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>

Add note for blackwell (#2742)

Update the docs to workaround the extra-index-url issue (#2744)

update README.md (#2751)

Fix github io pages (#2761)

Update
2025-02-11 02:21:51 +00:00

2158 lines
86 KiB
Python

import datetime
import json
import os as _os
import pickle
import random
import sys as _sys
import time
import typing as tp
from pathlib import Path
import numpy as np
import pytest
import torch
from binding_test_utils import *
from pydantic import BaseModel
import tensorrt_llm.bindings.executor as trtllm
import tensorrt_llm.version as trtllm_version
from tensorrt_llm.models.modeling_utils import PretrainedConfig
_sys.path.append(_os.path.join(_os.path.dirname(__file__), '..'))
import inspect
from utils.cpp_paths import *
from utils.llm_data import llm_models_root
from utils.util import skip_pre_ampere
@pytest.fixture
def model_files(llm_root: Path, resource_path: Path, results_data_path: Path):
    """Make sure the GPT test engines and reference outputs exist on disk.

    ``results_data_path`` acts as the "already generated" marker: when it is
    present nothing happens. Otherwise ``prepare_model_tests`` is run for the
    "gpt" model, forwarding ``--model_cache <dir>`` only when
    ``llm_models_root()`` returned a directory.
    """
    if results_data_path.exists():
        return
    cache_dir = llm_models_root()
    extra_args = [] if cache_dir is None else ["--model_cache", str(cache_dir)]
    prepare_model_tests(llm_root, resource_path, "gpt", extra_args)
@pytest.fixture
def lora_config_paths(llm_root: Path, resource_path: Path,
                      lora_config_path: Path):
    """Return ``(weights_npy, config_npy)`` paths for the test LoRA adapter.

    The ``.npy`` files are produced by ``prepare_lora_configs`` the first time
    the fixture runs, i.e. when ``lora_config_path`` does not exist yet.
    """
    if not lora_config_path.exists():
        prepare_lora_configs(llm_root, resource_path, lora_config_path)
    weights_file = lora_config_path / "source.npy"
    config_file = lora_config_path / "config.npy"
    return weights_file, config_file
def get_expected_num_tokens(prompt_len, max_tokens, streaming,
                            exclude_input_from_output):
    """Return the number of output token ids a completed request should carry.

    The prompt is counted only for non-streaming requests that do not set
    ``exclude_input_from_output``. In every other case the count is just the
    ``max_tokens`` generated tokens.
    """
    includes_prompt = not (streaming or exclude_input_from_output)
    if includes_prompt:
        return prompt_len + max_tokens
    return max_tokens
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_executor_valid_ctor(model_files, model_path):
    """Constructing an Executor from an engine directory must not raise."""
    beam_width = 1
    config = trtllm.ExecutorConfig(beam_width)
    executor = trtllm.Executor(model_path,
                               trtllm.ModelType.DECODER_ONLY,
                               config)
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_executor_from_memory(model_files, model_path):
    """Constructing an Executor from an in-memory engine must not raise.

    The serialized engine (``rank0.engine``) and its ``config.json`` are read
    into memory and handed to the ``Executor`` constructor instead of a path.
    """
    executor_config = trtllm.ExecutorConfig(1)
    # read_bytes/read_text open and close the file deterministically, so no
    # file handle is left for the garbage collector to clean up.
    engine_buffer = (model_path / "rank0.engine").read_bytes()
    # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default.
    json_config_str = (model_path / "config.json").read_text(encoding="utf-8")
    executor = trtllm.Executor(engine_buffer, json_config_str,
                               trtllm.ModelType.DECODER_ONLY, executor_config)
def test_executor_invalid_ctor():
    """Constructing an Executor from a nonexistent engine path must raise.

    The error message is expected to mention "File does not exist".
    """
    executor_config = trtllm.ExecutorConfig(1)
    invalid_path = "Bla"
    # pytest.raises instead of try / `assert False` / `except Exception`: the
    # AssertionError from `assert False` is itself an Exception and would be
    # caught by that handler, producing a misleading failure message.
    with pytest.raises(Exception, match="File does not exist"):
        trtllm.Executor(invalid_path, trtllm.ModelType.DECODER_ONLY,
                        executor_config)
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_shutdown(model_files, model_path):
    """After ``shutdown()`` the executor refuses work and every query API raises.

    A request is enqueued first, so that ``cancel_request`` and
    ``get_num_responses_ready`` are also exercised with a real request id.
    """
    beam_width = 1
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)
    # Create the request
    max_tokens = 5
    input_tokens = [1, 2, 3, 4]
    request = trtllm.Request(input_tokens,
                             max_tokens=max_tokens,
                             streaming=False,
                             sampling_config=trtllm.SamplingConfig())
    # Enqueue the request
    assert executor.can_enqueue_requests()
    req_id = executor.enqueue_request(request)
    executor.shutdown()
    assert not executor.can_enqueue_requests()
    # Each entry point gets its own `with` so a failure names the exact call.
    with pytest.raises(Exception):
        executor.enqueue_request(request)
    with pytest.raises(Exception):
        executor.await_responses()
    with pytest.raises(Exception):
        executor.get_latest_iteration_stats()
    with pytest.raises(Exception):
        executor.get_latest_request_stats()
    with pytest.raises(Exception):
        executor.get_latest_debug_tensors()
    with pytest.raises(Exception):
        executor.cancel_request(req_id)
    with pytest.raises(Exception):
        executor.get_num_responses_ready(req_id)
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_embedding_bias(model_files, model_path):
    """A huge embedding bias on one vocab entry forces every generated token to it.

    The returned token count must match ``get_expected_num_tokens``, and the
    last ``max_tokens`` tokens must all be the biased id.
    """
    streaming = False
    exclude_input_from_output = False
    beam_width = 1
    max_tokens = 5
    prompt = [1, 2, 3, 4]
    forced_token = 10
    vocab_size_padded = 50257

    out_cfg = trtllm.OutputConfig()
    out_cfg.exclude_input_from_output = exclude_input_from_output

    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               trtllm.ExecutorConfig(beam_width))

    # Bias a single vocab entry by the largest finite float32 so that it is
    # always picked.
    bias = torch.zeros(vocab_size_padded)
    bias[forced_token] = torch.finfo(torch.float32).max

    request = trtllm.Request(prompt,
                             max_tokens=max_tokens,
                             streaming=streaming,
                             sampling_config=trtllm.SamplingConfig(),
                             output_config=out_cfg,
                             embedding_bias=bias)
    request_id = executor.enqueue_request(request)

    # Poll until the final response arrives or the poll budget is spent.
    collected = []
    finished = False
    polls = 0
    max_polls = 10000
    poll_interval = datetime.timedelta(milliseconds=1)
    while polls < max_polls and not finished:
        for response in executor.await_responses(request_id, poll_interval):
            assert not response.has_error(
            ), f"Request id {request_id} failed with err {response.error_msg}"
            result = response.result
            finished = result.is_final
            collected.extend(result.output_token_ids[beam_width - 1])
        polls += 1
    assert polls < max_polls

    expected_count = get_expected_num_tokens(len(prompt), max_tokens,
                                             streaming,
                                             exclude_input_from_output)
    assert len(collected) == expected_count, f"{request_id}"
    assert collected[-max_tokens:] == [forced_token] * max_tokens
@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize("exclude_input_from_output", [False])
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_single_request(streaming: bool, exclude_input_from_output: bool,
                        model_files, model_path):
    """A single request returns the token count from ``get_expected_num_tokens``.

    Afterwards the stats and debug-tensor getters are called as a smoke check;
    their contents are not asserted.
    """
    beam_width = 1
    max_tokens = 5
    prompt = [1, 2, 3, 4]

    out_cfg = trtllm.OutputConfig()
    out_cfg.exclude_input_from_output = exclude_input_from_output
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               trtllm.ExecutorConfig(beam_width))

    request = trtllm.Request(prompt,
                             max_tokens=max_tokens,
                             streaming=streaming,
                             sampling_config=trtllm.SamplingConfig(),
                             output_config=out_cfg)
    request_id = executor.enqueue_request(request)

    # Poll until the final response arrives or the poll budget is spent.
    collected = []
    finished = False
    polls = 0
    max_polls = 10000
    poll_interval = datetime.timedelta(milliseconds=1)
    while polls < max_polls and not finished:
        for response in executor.await_responses(request_id, poll_interval):
            assert not response.has_error(
            ), f"Request id {request_id} failed with err {response.error_msg}"
            result = response.result
            finished = result.is_final
            collected.extend(result.output_token_ids[beam_width - 1])
        polls += 1
    assert polls < max_polls

    expected_count = get_expected_num_tokens(len(prompt), max_tokens,
                                             streaming,
                                             exclude_input_from_output)
    assert len(collected) == expected_count, f"{request_id}"

    executor.get_latest_iteration_stats()
    executor.get_latest_request_stats()
    executor.get_latest_debug_tensors()
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_single_request_lora(model_files, model_path_lora, lora_config_paths):
    """A request carrying LoRA weights and config completes with the expected count.

    The adapter tensors are loaded from the ``.npy`` files returned by the
    ``lora_config_paths`` fixture and attached via ``trtllm.LoraConfig`` with
    task id 0. Weights are converted to fp16 with ``.half()``.
    """
    streaming = False
    exclude_input_from_output = False
    output_config = trtllm.OutputConfig()
    output_config.exclude_input_from_output = exclude_input_from_output
    # Create executor. The same `beam_width` feeds the executor config and the
    # `output_token_ids[beam_width - 1]` index below, so the two cannot drift.
    beam_width = 1
    peft_cache_config = trtllm.PeftCacheConfig(num_put_workers=4,
                                               num_ensure_workers=4)
    executor_config = trtllm.ExecutorConfig(beam_width,
                                            peft_cache_config=peft_cache_config)
    executor = trtllm.Executor(model_path_lora, trtllm.ModelType.DECODER_ONLY,
                               executor_config)
    # Create the request
    max_tokens = 5
    input_tokens = [1, 2, 3, 4]
    lora_weights = torch.tensor(np.load(lora_config_paths[0])).half()
    lora_config = torch.tensor(np.load(lora_config_paths[1]))
    request = trtllm.Request(input_tokens,
                             max_tokens=max_tokens,
                             streaming=streaming,
                             sampling_config=trtllm.SamplingConfig(),
                             output_config=output_config,
                             lora_config=trtllm.LoraConfig(
                                 0, lora_weights, lora_config))
    request_id = executor.enqueue_request(request)

    # Poll until the final response arrives or the poll budget is spent.
    tokens = []
    done = False
    polls = 0
    max_polls = 10000
    poll_interval = datetime.timedelta(milliseconds=1)
    while not done and polls < max_polls:
        for response in executor.await_responses(request_id, poll_interval):
            assert not response.has_error(
            ), f"Request id {request_id} failed with err {response.error_msg}"
            result = response.result
            done = result.is_final
            tokens.extend(result.output_token_ids[beam_width - 1])
        polls += 1
    assert polls < max_polls
    assert len(tokens) == get_expected_num_tokens(
        len(input_tokens), max_tokens, streaming,
        exclude_input_from_output), f"{request_id}"
@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize("exclude_input_from_output", [False])
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_multi_request(streaming: bool, exclude_input_from_output: bool,
                       model_files, model_path):
    """Concurrent requests of random length each return the expected token count.

    Requests with ``index % 5 == 1`` ask for two return sequences; every
    sequence is checked independently against ``get_expected_num_tokens``.
    Responses for all requests are collected with ``await_responses(timeout)``.
    """
    output_config = trtllm.OutputConfig()
    output_config.exclude_input_from_output = exclude_input_from_output
    # Create executor
    beam_width = 1
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)
    num_requests = 20
    max_prompt_len = 20
    max_max_tokens = 20
    # NOTE(review): -1 presumably never matches a generated id, so each
    # request runs to max_tokens (the exact-count check below relies on it).
    end_id = -1
    # request id -> one token list per return sequence
    tokens = {}
    expected_num_tokens = {}
    for req_idx in range(num_requests):
        prompt_len = random.randint(1, max_prompt_len)
        max_tokens = random.randint(1, max_max_tokens)
        input_tokens = [1] * prompt_len
        # Some requests have num_return_sequences > 1.
        num_return_sequences = 2 if req_idx % 5 == 1 else 1
        request = trtllm.Request(input_tokens,
                                 max_tokens=max_tokens,
                                 streaming=streaming,
                                 sampling_config=trtllm.SamplingConfig(
                                     num_return_sequences=num_return_sequences),
                                 output_config=output_config,
                                 end_id=end_id)
        request_id = executor.enqueue_request(request)
        tokens[request_id] = [
            [] for _ in range(request.sampling_config.num_return_sequences)
        ]
        expected_num_tokens[request_id] = get_expected_num_tokens(
            prompt_len, max_tokens, streaming, exclude_input_from_output)

    # Poll until every request has reported a final result or the poll
    # budget is spent.
    num_finished = 0
    polls = 0
    max_polls = 10000
    poll_interval = datetime.timedelta(milliseconds=1)
    while num_finished < num_requests and polls < max_polls:
        for response in executor.await_responses(poll_interval):
            assert not response.has_error(
            ), f"Request id {response.request_id} failed with err {response.error_msg}"
            result = response.result
            num_finished += result.is_final
            tokens[response.request_id][result.sequence_index].extend(
                result.output_token_ids[beam_width - 1])
        polls += 1
    assert polls < max_polls

    for request_id, expected in expected_num_tokens.items():
        for seq_idx, actual_tokens in enumerate(tokens[request_id]):
            assert len(actual_tokens) == expected, (
                f"Request id {request_id} sequence {seq_idx}: got "
                f"{len(actual_tokens)} tokens, expected {expected}")
@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize("exclude_input_from_output", [False])
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_multi_request_with_ids(streaming: bool,
                                exclude_input_from_output: bool, model_files,
                                model_path):
    """Like ``test_multi_request`` but polls with ``await_responses(ids, timeout)``.

    Because every known id is awaited on each poll, ids of requests that have
    already finished may come back with a "has already been processed and was
    terminated" error. That specific error is tolerated; any other error fails.
    """
    output_config = trtllm.OutputConfig()
    output_config.exclude_input_from_output = exclude_input_from_output
    # Create executor
    beam_width = 1
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)
    num_requests = 20
    max_prompt_len = 20
    max_max_tokens = 20
    end_id = -1
    # request id -> one token list per return sequence
    tokens = {}
    expected_num_tokens = {}
    for req_idx in range(num_requests):
        prompt_len = random.randint(1, max_prompt_len)
        max_tokens = random.randint(1, max_max_tokens)
        input_tokens = [1] * prompt_len
        num_return_sequences = 2 if req_idx % 5 == 1 else 1
        request = trtllm.Request(input_tokens,
                                 max_tokens=max_tokens,
                                 streaming=streaming,
                                 sampling_config=trtllm.SamplingConfig(
                                     num_return_sequences=num_return_sequences),
                                 output_config=output_config,
                                 end_id=end_id)
        request_id = executor.enqueue_request(request)
        tokens[request_id] = [
            [] for _ in range(request.sampling_config.num_return_sequences)
        ]
        expected_num_tokens[request_id] = get_expected_num_tokens(
            prompt_len, max_tokens, streaming, exclude_input_from_output)

    # Poll until every request has reported a final result or the poll
    # budget is spent.
    num_finished = 0
    polls = 0
    max_polls = 10000
    poll_interval = datetime.timedelta(milliseconds=1)
    while num_finished < num_requests and polls < max_polls:
        id_responses = executor.await_responses(list(tokens.keys()),
                                                poll_interval)
        for responses in id_responses:
            for response in responses:
                if response.has_error():
                    # Only the "already terminated" error for an awaited id
                    # is acceptable.
                    terminated_request_error = (
                        f"ReqId {response.request_id} has already been "
                        "processed and was terminated.")
                    assert response.error_msg == terminated_request_error, (
                        f"Request id {response.request_id} failed with err "
                        f"{response.error_msg}")
                    continue
                result = response.result
                num_finished += result.is_final
                tokens[response.request_id][result.sequence_index].extend(
                    result.output_token_ids[beam_width - 1])
        polls += 1
    assert polls < max_polls

    for request_id, expected in expected_num_tokens.items():
        for seq_idx, actual_tokens in enumerate(tokens[request_id]):
            assert len(actual_tokens) == expected, (
                f"Request id {request_id} sequence {seq_idx}: got "
                f"{len(actual_tokens)} tokens, expected {expected}")
@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize("exclude_input_from_output", [False])
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_get_num_responses_ready(streaming: bool,
                                 exclude_input_from_output: bool, model_files,
                                 model_path):
    """``get_num_responses_ready`` counts per request id and in total.

    The test expects one response per generated token per return sequence
    when streaming, and one response per return sequence otherwise. Nothing
    is consumed, so the counts stay stable once all work is done.
    """
    out_cfg = trtllm.OutputConfig()
    out_cfg.exclude_input_from_output = exclude_input_from_output
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               trtllm.ExecutorConfig(1))
    max_prompt_len = 20
    max_max_tokens = 20

    expected_per_request = {}
    for idx in range(random.randint(1, 50)):
        prompt_len = random.randint(1, max_prompt_len)
        max_tokens = random.randint(1, max_max_tokens)
        n_seqs = 2 if idx % 5 == 1 else 1
        request = trtllm.Request([1] * prompt_len,
                                 max_tokens=max_tokens,
                                 streaming=streaming,
                                 sampling_config=trtllm.SamplingConfig(
                                     num_return_sequences=n_seqs),
                                 output_config=out_cfg)
        rid = executor.enqueue_request(request)
        per_seq = max_tokens if streaming else 1
        expected_per_request[rid] = per_seq * n_seqs
    total_expected = sum(expected_per_request.values())

    # Spin until every expected response is ready or the poll budget is spent.
    polls = 0
    ready = 0
    max_polls = 10000
    while ready < total_expected and polls < max_polls:
        ready = sum(
            executor.get_num_responses_ready(rid)
            for rid in expected_per_request)
        time.sleep(0.001)
        polls += 1
    assert polls < max_polls

    for rid, expected in expected_per_request.items():
        assert executor.get_num_responses_ready(rid) == expected
    assert executor.get_num_responses_ready() == total_expected
@pytest.mark.parametrize("batching_type", [trtllm.BatchingType.INFLIGHT])
@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize("beam_width", [1])
@pytest.mark.parametrize("compute_log_probs", [False, True])
@pytest.mark.parametrize("exclude_input_from_output", [False])
@pytest.mark.parametrize("return_context_logits", [False, True])
@pytest.mark.parametrize("return_generation_logits", [False, True])
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_token_comparison(batching_type: trtllm.BatchingType, streaming: bool,
                          beam_width: int, compute_log_probs: bool,
                          exclude_input_from_output: bool,
                          return_context_logits: bool,
                          return_generation_logits: bool, model_files,
                          model_path, model_path_return_logits, input_data_path,
                          results_data_path, results_data_path_beam_width_2):
    """Compare generated tokens against stored reference outputs.

    Inputs come from ``input_data_path`` and reference ids from
    ``results_data_path`` (or the beam-width-2 variant). Every response is
    also checked for the shapes of log probs and context/generation logits,
    according to the requested ``OutputConfig`` flags.
    """
    if streaming and beam_width > 1:
        pytest.skip("Test does not support streaming with beam search")

    # NOTE(review): these look like GPT-2's padded vocab size and pad id;
    # confirm against the engine's config.
    vocab_size_padded = 50257
    pad_id = 50256
    # Non-streaming outputs that don't exclude the input echo the prompt, so
    # the prompt has to be stripped before comparing against the references.
    remove_input = not exclude_input_from_output and not streaming

    def load_test_data(input_path, results_path):
        """Load padded 2-D inputs and references and derive expected lengths.

        Returns ``(given_input, given_input_lengths, max_input_length,
        test_data)``. Reference rows are laid out as
        ``batch_id * beam_width + beam``.
        """
        # Inputs
        assert input_path.is_file()
        given_input = np.load(input_path).astype("int32")
        input_shape = given_input.shape
        assert len(input_shape) == 2
        max_input_length = input_shape[1]
        given_input_lengths = sequence_lengths(given_input, pad_id)
        assert np.all(given_input_lengths <= max_input_length)
        # Expected results
        assert results_path.is_file()
        expected_outputs = np.load(results_path).astype("int32")
        output_shape = expected_outputs.shape
        assert len(output_shape) == 2
        assert input_shape[0] * beam_width == output_shape[0]
        max_seq_length = output_shape[1]
        # Every request generates the same number of new tokens: the padded
        # reference width minus the padded input width.
        max_tokens = max_seq_length - max_input_length

        end_ids = [pad_id for _ in range(len(given_input_lengths))]
        expected_lengths = []
        for i in range(len(given_input_lengths)):
            expected_lengths.append([
                given_input_lengths[i] + max_tokens for _ in range(beam_width)
            ])

        test_data = {
            "expected_output_ids": expected_outputs,
            "expected_output_lengths": expected_lengths,
            "max_seq_length": max_seq_length,
            "end_ids": end_ids
        }
        return given_input, given_input_lengths, max_input_length, test_data

    def validate_results_shapes(result, input_length, max_output_len,
                                beam_tokens):
        """Check log-prob/logit fields are present or absent and correctly shaped.

        ``beam_tokens`` is the per-beam token list accumulated so far for this
        sequence; log-prob lengths are compared against it.
        """
        if compute_log_probs:
            assert result.cum_log_probs is not None
            assert result.log_probs is not None
            assert len(result.cum_log_probs) == beam_width
            assert len(result.log_probs) == beam_width
            for beam in range(beam_width):
                expected_len = len(
                    beam_tokens[beam]) - (input_length if remove_input else 0)
                assert len(result.log_probs[beam]) == expected_len
        else:
            assert result.cum_log_probs is None
            assert result.log_probs is None
        if return_context_logits:
            assert result.context_logits is not None
            assert len(result.context_logits.shape) == 2
            assert list(result.context_logits.shape) == [
                input_length, vocab_size_padded
            ]
        else:
            assert result.context_logits is None
        if return_generation_logits:
            assert len(result.generation_logits.shape) == 3
            # The accepted layout differs by mode: streaming puts the token
            # axis first (full length or a single step), non-streaming puts
            # the beam axis first.
            if streaming:
                assert list(result.generation_logits.shape) == [
                    max_output_len, beam_width, vocab_size_padded
                ] or list(result.generation_logits.shape) == [
                    1, beam_width, vocab_size_padded
                ]
            else:
                assert list(result.generation_logits.shape) == [
                    beam_width, max_output_len, vocab_size_padded
                ]

    def verify_output(beam_tokens, test_data, given_input_lengths):
        """Compare accumulated tokens with the references.

        ``beam_tokens`` maps ``batch_id -> [sequence][beam] -> token list``.
        Lengths must match exactly; token ids are compared up to the first
        ``end_id`` found in the reference.
        """
        for batch_id, seq_tokens in beam_tokens.items():
            input_length = given_input_lengths[batch_id]
            end_id = test_data["end_ids"][batch_id]
            for tokens in seq_tokens:
                for beam in range(beam_width):
                    predicted_tokens = tokens[beam]
                    if remove_input:
                        predicted_tokens = predicted_tokens[input_length:]
                    expected_length = test_data["expected_output_lengths"][
                        batch_id][beam] - input_length
                    assert len(predicted_tokens) == expected_length
                    expected_tokens = test_data["expected_output_ids"][
                        batch_id * beam_width + beam][input_length:]
                    for i in range(len(predicted_tokens)):
                        if expected_tokens[i] == end_id:
                            break
                        assert predicted_tokens[i] == expected_tokens[i], \
                            f"Predicted: {predicted_tokens} vs Expected: {expected_tokens}"

    output_config = trtllm.OutputConfig()
    output_config.exclude_input_from_output = exclude_input_from_output
    output_config.return_log_probs = compute_log_probs
    output_config.return_generation_logits = return_generation_logits
    output_config.return_context_logits = return_context_logits

    kv_cache_config = trtllm.KvCacheConfig(False, free_gpu_memory_fraction=0.5)
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor_config.batching_type = batching_type
    executor_config.kv_cache_config = kv_cache_config

    # Logits can only be returned by the engine variant built for it.
    if return_context_logits or return_generation_logits:
        model_path = model_path_return_logits
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    # Load test data
    results_path = results_data_path if beam_width == 1 else results_data_path_beam_width_2
    given_input, given_input_lengths, max_input_length, test_data = load_test_data(
        input_data_path, results_path)

    # Create requests from input data
    num_requests = len(given_input_lengths)
    requests = []
    req_max_tokens = []

    for i in range(num_requests):
        input_len = given_input_lengths[i]
        max_tokens = test_data["max_seq_length"] - max_input_length
        req_max_tokens.append(max_tokens)
        # Strip the padding so each request carries only its real prompt.
        req_tokens = given_input[i][:input_len]
        # Some requests ask for two return sequences.
        num_return_sequences = 2 if i % 5 == 1 else 1
        # NOTE(review): end_id=-1 presumably disables early stopping; the
        # references are instead cut at pad_id inside verify_output.
        request = trtllm.Request(req_tokens,
                                 max_tokens=max_tokens,
                                 streaming=streaming,
                                 sampling_config=trtllm.SamplingConfig(
                                     beam_width,
                                     num_return_sequences=num_return_sequences),
                                 output_config=output_config,
                                 end_id=-1)
        requests.append(request)

    req_ids = executor.enqueue_requests(requests)

    req_to_batch_id = {req_ids[i]: i for i in range(len(requests))}
    # batch_id -> [sequence_index][beam] -> accumulated token ids
    tokens = {
        i: [[[] for _ in range(beam_width)]
            for _ in range(req.sampling_config.num_return_sequences)]
        for i, req in enumerate(requests)
    }

    num_finished = 0
    i = 0
    num_responses = 0  # counted for debugging only; not asserted
    max_wait_ms = 10000
    # `max_wait_ms` bounds poll iterations; each await waits up to 1 ms.
    while num_finished < num_requests and i < max_wait_ms:
        wait_time = datetime.timedelta(milliseconds=1)
        responses = executor.await_responses(wait_time)
        for response in responses:
            num_responses += 1
            assert not response.has_error(
            ), f"Request id {response.request_id} failed with err {response.error_msg}"
            result = response.result
            num_finished += result.is_final
            batch_id = req_to_batch_id[response.request_id]
            for beam in range(beam_width):
                new_tokens = result.output_token_ids[beam]
                tokens[batch_id][result.sequence_index][beam] += new_tokens
            # Shapes are validated on every response, streaming or not.
            validate_results_shapes(result, given_input_lengths[batch_id],
                                    req_max_tokens[batch_id],
                                    tokens[batch_id][result.sequence_index])
        i += 1
    assert i < max_wait_ms
    verify_output(tokens, test_data, given_input_lengths)
@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize("beam_width", [1])
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_finish_reason(streaming: bool, beam_width: int, model_files,
                       model_path):
    """Each request reports the FinishReason that matches why it stopped.

    Non-final responses must report NOT_FINISHED. Final responses must report
    LENGTH, END_ID or STOP_WORDS for the three requests, in that order.
    """
    if streaming and beam_width > 1:
        pytest.skip("Test does not support streaming with beam search")
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               trtllm.ExecutorConfig(beam_width))
    requests = [
        # Finish due to length.
        trtllm.Request([1, 2, 3, 4],
                       max_tokens=5,
                       streaming=streaming,
                       sampling_config=trtllm.SamplingConfig(beam_width)),
        # Finish due to end id.
        trtllm.Request([1, 2, 3, 4],
                       max_tokens=5,
                       streaming=streaming,
                       sampling_config=trtllm.SamplingConfig(beam_width),
                       end_id=4),
        # Finish due to stop word.
        trtllm.Request([1, 2, 3, 4],
                       max_tokens=5,
                       streaming=streaming,
                       sampling_config=trtllm.SamplingConfig(beam_width),
                       stop_words=[[4, 2]]),
    ]
    # Expected final reason for each entry of `requests`, by position.
    expected_final_reasons = [
        trtllm.FinishReason.LENGTH,
        trtllm.FinishReason.END_ID,
        trtllm.FinishReason.STOP_WORDS,
    ]
    req_ids = executor.enqueue_requests(requests)
    req_to_batch_id = {req_ids[i]: i for i in range(len(requests))}

    num_finished = 0
    polls = 0
    max_polls = 10000
    poll_interval = datetime.timedelta(milliseconds=1)
    while num_finished < len(requests) and polls < max_polls:
        for response in executor.await_responses(poll_interval):
            assert not response.has_error(
            ), f"Request id {response.request_id} failed with err {response.error_msg}"
            result = response.result
            num_finished += result.is_final
            batch_id = req_to_batch_id[response.request_id]
            # Non-final results should have "NOT_FINISHED". Revise this when
            # streaming + beam_width > 1 is enabled.
            expected = (expected_final_reasons[batch_id] if result.is_final
                        else trtllm.FinishReason.NOT_FINISHED)
            assert all(r == expected for r in result.finish_reasons), (
                f"Request {batch_id}: expected {expected}, got "
                f"{result.finish_reasons}")
        polls += 1
    assert polls < max_polls
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_gpt_executor_timed_out(model_files, model_path):
    """With nothing enqueued, nothing is ready and a short await returns no responses."""
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               trtllm.ExecutorConfig(1))
    assert executor.get_num_responses_ready() == 0
    timeout = datetime.timedelta(milliseconds=10)
    received = executor.await_responses(timeout)
    assert len(received) == 0
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_single_request_invalid_inputs(model_files, model_path):
    """A request whose embedding bias has the wrong shape yields an error response."""
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               trtllm.ExecutorConfig(1))
    request = trtllm.Request([1, 2, 3, 4], max_tokens=5, streaming=True)
    # A one-element bias does not have the shape the executor expects.
    request.embedding_bias = torch.ones(1)
    expected_error_msg = "embedding bias shape is not as expected"
    request_id = executor.enqueue_request(request)

    poll_interval = datetime.timedelta(milliseconds=1)
    got_error = False
    for _ in range(10000):
        for response in executor.await_responses(request_id, poll_interval):
            assert response.has_error(), "Expected an error"
            assert expected_error_msg in response.error_msg
            got_error = True
        if got_error:
            break
    assert got_error
def test_sampling_config():
    """Each optional SamplingConfig field round-trips and otherwise defaults to None."""
    beam_width = 1
    expected = dict(
        top_k=2,
        top_p=1.0,
        top_p_min=1.0,
        top_p_reset_ids=3,
        top_p_decay=1.0,
        seed=7,
        temperature=1.0,
        min_tokens=4,
        beam_search_diversity_rate=1.0,
        repetition_penalty=1.0,
        presence_penalty=1.0,
        frequency_penalty=1.0,
        length_penalty=1.0,
        early_stopping=5,
        num_return_sequences=2,
    )
    config = trtllm.SamplingConfig(beam_width, **expected)
    for name, value in expected.items():
        assert getattr(config, name) == value
    del config

    defaults = trtllm.SamplingConfig(beam_width)
    assert defaults.beam_width == beam_width
    for name in expected:
        assert getattr(defaults, name) is None
def test_sampling_config_deprecated_args():
    """Deprecated SamplingConfig aliases mirror their replacements; the new name wins."""
    # (current name, deprecated alias)
    renamed = (("seed", "random_seed"), ("min_tokens", "min_length"))
    for new_name, old_name in renamed:
        cases = [
            ({new_name: 1}, 1),
            ({old_name: 2}, 2),
            # When both are given, the current name takes precedence.
            ({new_name: 3, old_name: 4}, 3),
        ]
        for kwargs, expected in cases:
            config = trtllm.SamplingConfig(**kwargs)
            assert getattr(config, new_name) == expected
            assert getattr(config, old_name) == expected
def test_output_config():
    """OutputConfig flags default to off and positional arguments map to fields in order.

    The positional order exercised is: return_log_probs, return_context_logits,
    return_generation_logits, exclude_input_from_output, return_encoder_output,
    return_perf_metrics, additional_model_outputs.
    """
    # Identity checks (`is False` / `is None`) rather than `==` comparisons:
    # they are the idiomatic singleton tests and also reject e.g. a None
    # default where a bool is expected.
    config = trtllm.OutputConfig()
    assert config.return_log_probs is False
    assert config.return_context_logits is False
    assert config.return_generation_logits is False
    assert config.exclude_input_from_output is False
    assert config.return_encoder_output is False
    assert config.return_perf_metrics is False
    assert config.additional_model_outputs is None

    config = trtllm.OutputConfig(
        True, False, True, False, True, False,
        [trtllm.AdditionalModelOutput("topKLogits", True)])
    assert config.return_log_probs is True
    assert config.return_context_logits is False
    assert config.return_generation_logits is True
    assert config.exclude_input_from_output is False
    assert config.return_encoder_output is True
    assert config.return_perf_metrics is False
    assert len(config.additional_model_outputs) == 1
    additional_model_output = config.additional_model_outputs[0]
    assert additional_model_output.name == "topKLogits"
    assert additional_model_output.gather_context is True
def test_external_draft_tokens_config():
    """ExternalDraftTokensConfig stores tokens plus optional logits, threshold and fast flag."""
    draft_tokens = [1, 2, 3]

    minimal = trtllm.ExternalDraftTokensConfig(draft_tokens)
    assert minimal.tokens == draft_tokens
    assert minimal.logits is None
    assert minimal.acceptance_threshold is None
    del minimal

    draft_logits = torch.ones(3, 1)
    threshold = 1.0
    use_fast_logits = False
    full = trtllm.ExternalDraftTokensConfig(draft_tokens, draft_logits,
                                            threshold, use_fast_logits)
    assert full.tokens == draft_tokens
    assert (full.logits == draft_logits).all()
    assert full.acceptance_threshold == threshold
    assert full.fast_logits == use_fast_logits
def test_prompt_tuning_config():
    """PromptTuningConfig exposes the embedding table it was built with."""
    table = torch.ones(100, 64)
    prompt_cfg = trtllm.PromptTuningConfig(table)
    stored = prompt_cfg.embedding_table
    assert (stored == table).all()
def test_mrope_config():
    """MropeConfig exposes the rotary cos/sin and position-delta tensors it was given."""
    cos_sin = torch.ones(1, 4194304)
    deltas = torch.tensor([-50])
    mrope = trtllm.MropeConfig(cos_sin, deltas)
    assert (mrope.mrope_rotary_cos_sin == cos_sin).all()
    assert (mrope.mrope_position_deltas == deltas).all()
def test_lora_config():
    """LoraConfig accepts a bare task id, or a task id with weights and config tensors."""
    cfg = trtllm.LoraConfig(1)
    assert cfg.task_id == 1
    assert cfg.weights is None
    assert cfg.config is None

    weight_tensor = torch.ones(1, 2)
    config_tensor = torch.ones(1, 2, dtype=torch.int32)
    cfg = trtllm.LoraConfig(2, weight_tensor, config_tensor)
    assert cfg.task_id == 2
    assert (cfg.weights == weight_tensor).all()
    assert (cfg.config == config_tensor).all()
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_wakeup(model_files, model_path):
    """A thread blocked in ``await_responses(timeout=None)`` exits after ``shutdown()``.

    A background thread drains responses with no timeout while a streaming
    request runs. The main thread then sets a stop flag and shuts the executor
    down, and the waiter must finish within a bounded time.
    """
    import threading

    def resp_thread(stop_signal: threading.Event, executor: trtllm.Executor):
        while not stop_signal.is_set():
            timeout = None
            responses = executor.await_responses(timeout=timeout)
            if stop_signal.is_set():
                return
            for response in responses:
                response.result.output_token_ids

    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               trtllm.ExecutorConfig())
    stop_signal = threading.Event()
    # daemon=True so a waiter that never wakes cannot block interpreter exit.
    thread = threading.Thread(target=resp_thread,
                              args=(stop_signal, executor),
                              daemon=True)
    thread.start()
    request = trtllm.Request(input_token_ids=[1, 2, 3, 4],
                             max_tokens=5,
                             streaming=True)
    executor.enqueue_request(request)
    time.sleep(2)
    stop_signal.set()
    executor.shutdown()
    # Bounded join: an unbounded join() would hang forever on a missed wakeup,
    # so the is_alive() check could never fail.
    join_timeout_s = 60
    thread.join(timeout=join_timeout_s)
    assert not thread.is_alive(), (
        f"response thread still blocked {join_timeout_s}s after shutdown()")
def test_guided_decoding_params():
    """GuidedDecodingParams stores guide type and guide; JSON_SCHEMA requires a guide."""
    GuideType = trtllm.GuidedDecodingParams.GuideType

    params = trtllm.GuidedDecodingParams(GuideType.JSON)
    assert params.guide_type == GuideType.JSON

    class Answer(BaseModel):
        answer: int

    schema_text = json.dumps(Answer.model_json_schema())
    params = trtllm.GuidedDecodingParams(GuideType.JSON_SCHEMA,
                                         guide=schema_text)
    assert params.guide_type == GuideType.JSON_SCHEMA
    assert params.guide == schema_text

    # A schema guide type without a guide is rejected.
    with pytest.raises(Exception):
        trtllm.GuidedDecodingParams(GuideType.JSON_SCHEMA)

    for guide_type, guide in (
        (GuideType.REGEX, r"\d+"),
        (GuideType.EBNF_GRAMMAR, "root ::= [0-9]+"),
    ):
        params = trtllm.GuidedDecodingParams(guide_type, guide=guide)
        assert params.guide_type == guide_type
        assert params.guide == guide
def test_request():
    """Construct a Request from keyword arguments and verify what it stores."""
    fields = dict(
        input_token_ids=[1, 2, 3],
        max_tokens=1,
        streaming=False,
        sampling_config=trtllm.SamplingConfig(),
        output_config=trtllm.OutputConfig(),
        end_id=-1,
        pad_id=-2,
        bad_words=[[4, 5, 6]],
        stop_words=[[7, 8, 9]],
        embedding_bias=torch.ones(1),
        external_draft_tokens_config=trtllm.ExternalDraftTokensConfig(
            [1, 2, 3]),
        prompt_tuning_config=trtllm.PromptTuningConfig(torch.ones(100, 64)),
        lora_config=trtllm.LoraConfig(1),
        logits_post_processor_name="my_logits_pp",
        client_id=1234,
    )
    request = trtllm.Request(**fields)

    # Plain values are compared directly; *_config objects are checked
    # by type and selected contents further down.
    for name in (n for n in fields if "config" not in n):
        assert getattr(request, name) == fields[name]

    assert isinstance(request.sampling_config, trtllm.SamplingConfig)
    assert isinstance(request.output_config, trtllm.OutputConfig)

    draft_cfg = request.external_draft_tokens_config
    assert isinstance(draft_cfg, trtllm.ExternalDraftTokensConfig)
    assert draft_cfg.tokens == [1, 2, 3]

    tuning_cfg = request.prompt_tuning_config
    assert isinstance(tuning_cfg, trtllm.PromptTuningConfig)
    assert (tuning_cfg.embedding_table == torch.ones(100, 64)).all()

    assert isinstance(request.lora_config, trtllm.LoraConfig)
def test_request_deprecated_args():
    """The deprecated ``max_new_tokens`` alias mirrors ``max_tokens``; when
    both are passed, ``max_tokens`` takes precedence."""
    cases = (
        ({"max_tokens": 10}, 10),
        ({"max_new_tokens": 20}, 20),
        ({"max_tokens": 30, "max_new_tokens": 40}, 30),
    )
    for extra_kwargs, expected in cases:
        req = trtllm.Request([1, 2, 3], **extra_kwargs)
        assert req.max_tokens == expected
        assert req.max_new_tokens == expected
def test_spec_dec_fast_logits_info():
    """SpeculativeDecodingFastLogitsInfo fields are writable and readable."""
    info = trtllm.SpeculativeDecodingFastLogitsInfo()
    expected = {"draft_request_id": 3, "draft_participant_id": 5}
    for attr, value in expected.items():
        setattr(info, attr, value)
    for attr, value in expected.items():
        assert getattr(info, attr) == value
def test_result():
    """Populate every public field of a Result and read each one back."""
    res = trtllm.Result()
    res.is_final = True
    res.output_token_ids = [[1, 2, 3]]
    res.cum_log_probs = [1.0, 2.0, 3.0]
    res.log_probs = [[1.0, 2.0, 3.0]]
    res.context_logits = torch.ones(3, 100)
    res.generation_logits = torch.ones(1, 3, 100)
    res.encoder_output = torch.ones(1, 1)
    res.finish_reasons = [trtllm.FinishReason.LENGTH]
    res.sequence_index = 1
    res.is_sequence_final = True
    res.additional_outputs = [
        trtllm.AdditionalOutput("topKLogits", torch.ones(1, 4, 100))
    ]

    assert res.is_final is True
    assert res.is_sequence_final is True
    assert res.output_token_ids == [[1, 2, 3]]
    assert res.cum_log_probs == [1.0, 2.0, 3.0]
    assert res.log_probs == [[1.0, 2.0, 3.0]]
    assert res.finish_reasons == [trtllm.FinishReason.LENGTH]
    assert res.sequence_index == 1

    # Tensor fields compare element-wise against all-ones tensors.
    tensor_fields = (
        (res.context_logits, (3, 100)),
        (res.generation_logits, (1, 3, 100)),
        (res.encoder_output, (1, 1)),
    )
    for actual, shape in tensor_fields:
        assert (actual == torch.ones(*shape)).all()

    assert len(res.additional_outputs) == 1
    extra = res.additional_outputs[0]
    assert extra.name == "topKLogits"
    assert (extra.output == torch.ones(1, 4, 100)).all()
def test_response():
    """A Response built from an error message reports the error; one built
    from a Result exposes that result."""
    failed = trtllm.Response(0, "error")
    assert failed.request_id == 0
    assert failed.has_error()
    assert failed.error_msg == "error"

    payload = trtllm.Result()
    payload.is_final = True
    payload.output_token_ids = [[1, 2, 3]]

    succeeded = trtllm.Response(1, payload)
    assert succeeded.request_id == 1
    assert not succeeded.has_error()
    assert succeeded.result.is_final
    assert succeeded.result.output_token_ids == [[1, 2, 3]]
def test_scheduler_config():
    """SchedulerConfig defaults, per-policy construction, and construction with
    a context-chunking policy and a DynamicBatchConfig."""
    Policy = trtllm.CapacitySchedulerPolicy

    # The default capacity policy is GUARANTEED_NO_EVICT with no chunking.
    config = trtllm.SchedulerConfig()
    assert config.capacity_scheduler_policy == Policy.GUARANTEED_NO_EVICT
    assert config.context_chunking_policy is None

    for capacity_scheduler_policy in (Policy.MAX_UTILIZATION,
                                      Policy.GUARANTEED_NO_EVICT,
                                      Policy.STATIC_BATCH):
        config = trtllm.SchedulerConfig(capacity_scheduler_policy)
        assert config.capacity_scheduler_policy == capacity_scheduler_policy
        assert config.context_chunking_policy is None

    capacity_scheduler_policy = Policy.STATIC_BATCH
    context_chunking_policy = trtllm.ContextChunkingPolicy.FIRST_COME_FIRST_SERVED
    config = trtllm.SchedulerConfig(capacity_scheduler_policy,
                                    context_chunking_policy)
    assert config.capacity_scheduler_policy == capacity_scheduler_policy
    assert config.context_chunking_policy == context_chunking_policy

    dynamic_batch_config = trtllm.DynamicBatchConfig(True, True, 128)
    config = trtllm.SchedulerConfig(capacity_scheduler_policy,
                                    context_chunking_policy,
                                    dynamic_batch_config)
    assert config.capacity_scheduler_policy == capacity_scheduler_policy
    assert config.context_chunking_policy == context_chunking_policy
    assert config.dynamic_batch_config.enable_batch_size_tuning is True
    assert config.dynamic_batch_config.enable_max_num_tokens_tuning is True
    assert config.dynamic_batch_config.dynamic_batch_moving_average_window == 128
def test_kv_cache_config():
    """KvCacheConfig defaults, setters, keyword construction, and how
    RuntimeDefaults fill in only the fields that are still unset."""
    config = trtllm.KvCacheConfig()
    assert config.enable_block_reuse is True
    for unset_field in ("max_tokens", "max_attention_window",
                        "sink_token_length", "free_gpu_memory_fraction",
                        "cross_kv_cache_fraction", "host_cache_size",
                        "secondary_offload_min_priority"):
        assert getattr(config, unset_field) is None, unset_field
    assert config.onboard_blocks is True
    assert config.event_buffer_max_size == 0

    # Every field is writable and reads back what was written.
    new_values = {
        "enable_block_reuse": False,
        "max_tokens": 1,
        "max_attention_window": [2],
        "sink_token_length": 3,
        "free_gpu_memory_fraction": 0.5,
        "cross_kv_cache_fraction": 0.5,
        "host_cache_size": 4,
        "onboard_blocks": False,
        "secondary_offload_min_priority": 50,
        "event_buffer_max_size": 1024,
    }
    for attr, value in new_values.items():
        setattr(config, attr, value)
    for attr, value in new_values.items():
        assert getattr(config, attr) == value, attr

    kwargs = {
        "enable_block_reuse": True,
        "max_tokens": 3,
        "max_attention_window": [10],
        "sink_token_length": 2,
        "free_gpu_memory_fraction": 0.5,
        "cross_kv_cache_fraction": 0.5,
        "host_cache_size": 1024,
        "onboard_blocks": False,
        "event_buffer_max_size": 2048
    }
    config = trtllm.KvCacheConfig(**kwargs)
    for k, v in kwargs.items():
        assert getattr(config, k) == v

    # Runtime defaults must not override values that were explicitly set.
    config = trtllm.KvCacheConfig(**kwargs)
    max_attention_window = config.max_attention_window
    sink_token_length = config.sink_token_length
    runtime_defaults = trtllm.RuntimeDefaults(
        max_attention_window=max_attention_window + [1],
        sink_token_length=sink_token_length + 1)
    config.fill_empty_fields_from_runtime_defaults(runtime_defaults)
    assert config.max_attention_window == max_attention_window, "runtime defaults shouldn't override existing values"
    assert config.sink_token_length == sink_token_length, "runtime defaults shouldn't override existing values"

    # Runtime defaults do fill fields that were left unset.
    config = trtllm.KvCacheConfig(**{
        **kwargs, "max_attention_window": None,
        "sink_token_length": None
    })
    config.fill_empty_fields_from_runtime_defaults(runtime_defaults)
    assert config.max_attention_window == runtime_defaults.max_attention_window, "runtime defaults should apply to non existent values"
    assert config.sink_token_length == runtime_defaults.sink_token_length, "runtime defaults should apply to non existent values"

    # Passing runtime_defaults to the constructor is equivalent to filling
    # them in afterwards.
    config = trtllm.KvCacheConfig(**kwargs, runtime_defaults=runtime_defaults)
    setter_config = trtllm.KvCacheConfig(**kwargs)
    setter_config.fill_empty_fields_from_runtime_defaults(runtime_defaults)
    for k in kwargs:
        assert getattr(config, k) == getattr(
            setter_config, k
        ), "passing runtime_defaults to the constructor or settings it manually should be equivalent"
def test_kv_cache_retention_config():
    """KvCacheRetentionConfig stores its token-range configs and decode
    settings, and rejects overlapping token ranges."""
    TokenRange = trtllm.KvCacheRetentionConfig.TokenRangeRetentionConfig

    def check_range(rng, start, end, priority, duration):
        assert rng.token_start == start
        assert rng.token_end == end
        assert rng.priority == priority
        if duration is None:
            assert rng.duration_ms is None
        else:
            assert rng.duration_ms == duration

    retention = trtllm.KvCacheRetentionConfig(
        [TokenRange(0, 2, 30, datetime.timedelta(seconds=30))], 80)
    ranges = retention.token_range_retention_configs
    assert len(ranges) == 1
    check_range(ranges[0], 0, 2, 30, datetime.timedelta(seconds=30))
    assert retention.decode_retention_priority == 80
    assert retention.decode_duration_ms is None

    retention = trtllm.KvCacheRetentionConfig(
        [TokenRange(0, 64, 80), TokenRange(64, 100, 10)], 10,
        datetime.timedelta(milliseconds=30000))
    ranges = retention.token_range_retention_configs
    assert len(ranges) == 2
    check_range(ranges[0], 0, 64, 80, None)
    check_range(ranges[1], 64, 100, 10, None)
    assert retention.decode_retention_priority == 10
    assert retention.decode_duration_ms == datetime.timedelta(seconds=30)

    # Overlapping token ranges are invalid.
    with pytest.raises(Exception):
        trtllm.KvCacheRetentionConfig(
            [TokenRange(0, 64, 10), TokenRange(32, 128, 50)], 50)
def test_lookahead_decoding_config():
    """LookaheadDecodingConfig stores window, n-gram and verification-set
    sizes, whether they are passed positionally or by keyword."""
    for window, ngram, verification in ((3, 5, 7), (5, 10, 3)):
        cfg = trtllm.LookaheadDecodingConfig(window, ngram, verification)
        assert cfg.max_window_size == window
        assert cfg.max_ngram_size == ngram
        assert cfg.max_verification_set_size == verification

    by_name = dict(max_window_size=5,
                   max_ngram_size=3,
                   max_verification_set_size=7)
    cfg = trtllm.LookaheadDecodingConfig(**by_name)
    for attr, expected in by_name.items():
        assert getattr(cfg, attr) == expected
def test_eagle_config():
    """EagleConfig stores choices, greedy flag and posterior threshold; both
    choices and threshold are optional."""
    config = trtllm.EagleConfig([[0, 0], [0, 1]], False, 0.5)
    assert config.eagle_choices == [[0, 0], [0, 1]]
    assert config.greedy_sampling is False
    assert config.posterior_threshold == 0.5

    # Threshold omitted -> unset.
    config = trtllm.EagleConfig([[0, 0], [0, 1, 0]], True)
    assert config.eagle_choices == [[0, 0], [0, 1, 0]]
    assert config.greedy_sampling is True
    assert config.posterior_threshold is None

    # Choices explicitly None -> unset.
    config = trtllm.EagleConfig(None, True, 0.5)
    assert config.eagle_choices is None
    assert config.greedy_sampling is True
    assert config.posterior_threshold == 0.5

    kwargs = {
        "eagle_choices": [[0, 0], [0, 1], [0, 2]],
        "greedy_sampling": True,
        "posterior_threshold": 0.5
    }
    config = trtllm.EagleConfig(**kwargs)
    for k, v in kwargs.items():
        assert getattr(config, k) == v
def test_decoding_mode():
    """Each DecodingMode factory yields a mode whose matching predicate holds."""
    factory_to_predicate = (
        ("Auto", "isAuto"),
        ("TopK", "isTopK"),
        ("TopP", "isTopP"),
        ("TopKTopP", "isTopKandTopP"),
        ("BeamSearch", "isBeamSearch"),
        ("Medusa", "isMedusa"),
        ("Lookahead", "isLookahead"),
        ("ExplicitDraftTokens", "isExplicitDraftTokens"),
        ("Eagle", "isEagle"),
    )
    for factory_name, predicate_name in factory_to_predicate:
        mode = getattr(trtllm.DecodingMode, factory_name)()
        assert getattr(mode, predicate_name)()
def test_speculative_decoding_config():
    """DecodingConfig: setting one speculative-decoding option selects the
    matching decoding mode and leaves the other options unset."""
    config = trtllm.DecodingConfig()
    assert config.decoding_mode is None
    assert config.lookahead_decoding_config is None
    assert config.medusa_choices is None
    assert config.eagle_config is None

    config = trtllm.DecodingConfig()
    config.decoding_mode = trtllm.DecodingMode.TopKTopP()
    assert config.decoding_mode.isTopKandTopP()
    assert config.lookahead_decoding_config is None
    assert config.medusa_choices is None
    assert config.eagle_config is None

    # Assigning a lookahead config also switches the decoding mode.
    config = trtllm.DecodingConfig()
    la_decoding_config = trtllm.LookaheadDecodingConfig(3, 5, 7)
    config.lookahead_decoding_config = la_decoding_config
    assert config.decoding_mode.isLookahead()
    assert config.lookahead_decoding_config.max_ngram_size == la_decoding_config.max_ngram_size
    assert config.lookahead_decoding_config.max_window_size == la_decoding_config.max_window_size
    assert config.lookahead_decoding_config.max_verification_set_size == la_decoding_config.max_verification_set_size
    assert config.medusa_choices is None
    assert config.eagle_config is None

    # Assigning Medusa choices switches the mode to Medusa.
    config = trtllm.DecodingConfig()
    config.medusa_choices = [[0, 0], [0, 1]]
    assert config.decoding_mode.isMedusa()
    assert config.lookahead_decoding_config is None
    assert config.medusa_choices == [[0, 0], [0, 1]]
    assert config.eagle_config is None

    # Assigning an Eagle config switches the mode to Eagle.
    config = trtllm.DecodingConfig()
    config.eagle_config = trtllm.EagleConfig([[0, 0], [0, 1]])
    assert config.decoding_mode.isEagle()
    assert config.lookahead_decoding_config is None
    assert config.medusa_choices is None
    assert config.eagle_config is not None
    assert config.eagle_config.eagle_choices == [[0, 0], [0, 1]]
def test_logits_post_processor_config():
    """LogitsPostProcessorConfig defaults and keyword construction."""
    config = trtllm.LogitsPostProcessorConfig()
    assert config.processor_map is None
    assert config.processor_batched is None
    assert config.replicate is True

    kwargs = {
        "processor_map": {
            "test_pp": None
        },
        "processor_batched": None,
        "replicate": False
    }
    config = trtllm.LogitsPostProcessorConfig(**kwargs)
    for k, v in kwargs.items():
        assert getattr(config, k) == v
def test_guided_decoding_config():
    """GuidedDecodingConfig stores backend, vocab, tokenizer and stop ids."""
    Backend = trtllm.GuidedDecodingConfig.GuidedDecodingBackend
    expected = {
        "backend": Backend.XGRAMMAR,
        "encoded_vocab": ["eos", "a", "b", "c", "d"],
        "tokenizer_str": None,
        "stop_token_ids": [0],
    }
    cfg = trtllm.GuidedDecodingConfig(**expected)
    for field, value in expected.items():
        assert getattr(cfg, field) == value
def test_executor_config():
    """ExecutorConfig defaults, then construction from keyword arguments."""
    config = trtllm.ExecutorConfig()
    assert config.max_beam_width == 1
    assert config.max_batch_size is None
    assert config.max_num_tokens is None
    assert isinstance(config.scheduler_config, trtllm.SchedulerConfig)
    assert isinstance(config.kv_cache_config, trtllm.KvCacheConfig)
    assert config.enable_chunked_context is False
    assert config.normalize_log_probs is True
    assert config.iter_stats_max_iterations == 1000
    assert config.batching_type == trtllm.BatchingType.INFLIGHT
    assert config.parallel_config is None
    assert isinstance(config.peft_cache_config, trtllm.PeftCacheConfig)
    assert config.logits_post_processor_config is None
    assert config.decoding_config is None
    assert config.debug_config is None
    assert config.recv_poll_period_ms == 0
    assert config.max_seq_idle_microseconds == 180000000
    assert config.spec_dec_config is None
    assert config.guided_decoding_config is None

    kwargs = {
        "max_beam_width": 2,
        "max_batch_size": 8,
        "max_num_tokens": 128,
        "scheduler_config":
        trtllm.SchedulerConfig(trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION),
        "kv_cache_config": trtllm.KvCacheConfig(),
        "enable_chunked_context": True,
        "normalize_log_probs": False,
        "iter_stats_max_iterations": 100,
        "batching_type": trtllm.BatchingType.STATIC,
        "parallel_config": trtllm.ParallelConfig(),
        "peft_cache_config": trtllm.PeftCacheConfig(10),
        "logits_post_processor_config": trtllm.LogitsPostProcessorConfig(),
        "decoding_config":
        trtllm.DecodingConfig(trtllm.DecodingMode.TopKTopP()),
        "extended_runtime_perf_knob_config":
        trtllm.ExtendedRuntimePerfKnobConfig(multi_block_mode=True),
        "debug_config":
        trtllm.DebugConfig(debug_input_tensors=True,
                           debug_output_tensors=True,
                           debug_tensor_names=["test"]),
        "recv_poll_period_ms": 50,
        "max_seq_idle_microseconds": 240 * 1000 * 1000,
        "spec_dec_config": trtllm.SpeculativeDecodingConfig(fast_logits=True),
        "guided_decoding_config":
        trtllm.GuidedDecodingConfig(
            trtllm.GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR,
            encoded_vocab=["eos", "a", "b", "c", "d"]),
        "additional_output_names": ["topKLogits"]
    }
    config = trtllm.ExecutorConfig(**kwargs)
    # Plain values must round-trip by equality; *_config objects are checked
    # by type and selected fields below instead.
    for k, v in kwargs.items():
        if "config" not in k:
            assert getattr(config, k) == v
    assert isinstance(config.scheduler_config, trtllm.SchedulerConfig)
    assert config.scheduler_config.capacity_scheduler_policy == trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION
    assert isinstance(config.kv_cache_config, trtllm.KvCacheConfig)
    assert isinstance(config.parallel_config, trtllm.ParallelConfig)
    assert isinstance(config.peft_cache_config, trtllm.PeftCacheConfig)
    assert config.extended_runtime_perf_knob_config.multi_block_mode is True
    assert isinstance(config.debug_config, trtllm.DebugConfig)
    assert isinstance(config.logits_post_processor_config,
                      trtllm.LogitsPostProcessorConfig)
    assert isinstance(config.spec_dec_config, trtllm.SpeculativeDecodingConfig)
    assert isinstance(config.guided_decoding_config,
                      trtllm.GuidedDecodingConfig)
    assert isinstance(config.additional_output_names, list)
    assert len(config.additional_output_names) == 1
    assert config.additional_output_names[0] == "topKLogits"
def test_parallel_config():
    """ParallelConfig in LEADER mode, then in ORCHESTRATOR mode with an
    OrchestratorConfig attached."""
    comm_type = trtllm.CommunicationType.MPI
    comm_mode = trtllm.CommunicationMode.LEADER
    device_ids = [0, 1, 2, 3]
    participant_ids = [4, 5, 6, 7]
    parallel_config = trtllm.ParallelConfig(comm_type, comm_mode, device_ids,
                                            participant_ids)
    assert parallel_config.communication_type == comm_type
    assert parallel_config.communication_mode == comm_mode
    assert parallel_config.device_ids == device_ids
    assert parallel_config.participant_ids == participant_ids

    comm_mode = trtllm.CommunicationMode.ORCHESTRATOR
    # Dummy path to a worker executable; this test only checks it is stored.
    # abspath of the str __file__ is already a str.
    worker_path = _os.path.abspath(__file__)
    # Positional args: is_orchestrator, worker path, (third arg left as None —
    # NOTE(review): meaning not shown here), spawn_processes.
    orchestrator_config = trtllm.OrchestratorConfig(True, worker_path, None,
                                                    True)
    parallel_config = trtllm.ParallelConfig(comm_type, comm_mode, device_ids,
                                            participant_ids,
                                            orchestrator_config)
    assert parallel_config.communication_mode == comm_mode
    assert parallel_config.orchestrator_config.is_orchestrator is True
    assert parallel_config.orchestrator_config.worker_executable_path == worker_path
    assert parallel_config.orchestrator_config.spawn_processes is True
def test_peft_cache_config():
    """PeftCacheConfig built positionally exposes each value under its name."""
    # Insertion order matches the constructor's positional argument order.
    expected = {
        "num_host_module_layer": 1,
        "num_device_module_layer": 2,
        "optimal_adapter_size": 3,
        "max_adapter_size": 4,
        "num_put_workers": 5,
        "num_ensure_workers": 6,
        "num_copy_streams": 7,
        "max_pages_per_block_host": 8,
        "max_pages_per_block_device": 9,
        "device_cache_percent": 0.9,
        "host_cache_size": 1024,
    }
    cfg = trtllm.PeftCacheConfig(*expected.values())
    for name, value in expected.items():
        actual = getattr(cfg, name)
        if name == "device_cache_percent":
            # Float value: compared approximately (presumably the binding
            # does not store it at full double precision).
            assert np.isclose(actual, value)
        else:
            assert actual == value
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_logits_post_processor(model_files, model_path):
    """A named logits post-processor that masks every logit except token 42
    must force every generated token to be 42."""

    def logits_post_processor(req_id: int, logits: torch.Tensor,
                              ids: tp.List[tp.List[int]], stream_ptr: int,
                              client_id: tp.Optional[int]):
        assert client_id == 123
        # Work on the stream handed in by the executor.
        with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)):
            logits[:] = float("-inf")
            logits[..., 42] = 0

    beam_width = 1
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor_config.logits_post_processor_config = trtllm.LogitsPostProcessorConfig(
        {"my_logits_pp": logits_post_processor})
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    max_tokens = 5
    input_tokens = [1, 2, 3, 4]
    request = trtllm.Request(input_tokens,
                             max_tokens=max_tokens,
                             streaming=False,
                             client_id=123)
    # Route this request to the processor registered above.
    request.logits_post_processor_name = "my_logits_pp"
    request_id = executor.enqueue_request(request)

    # Poll with 1 ms timeouts; `i` counts polls and bounds the total wait.
    tokens = []
    done = False
    i = 0
    max_wait_ms = 10000
    while not done and i < max_wait_ms:
        wait_time = datetime.timedelta(milliseconds=1)
        responses = executor.await_responses(request_id, wait_time)
        for response in responses:
            assert not response.has_error(
            ), f"Request id {request_id} failed with err {response.error_msg}"
            result = response.result
            done = result.is_final
            new_tokens = result.output_token_ids[beam_width - 1]
            tokens.extend(new_tokens)
        i += 1
    assert i < max_wait_ms
    assert len(tokens) == get_expected_num_tokens(len(input_tokens), max_tokens,
                                                  False, False), f"{request_id}"
    # All generated (trailing) tokens must be 42; the full list is shown on
    # failure instead of being printed unconditionally.
    assert tokens[-max_tokens:] == [42] * max_tokens, tokens
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_logits_post_processor_batched(model_files, model_path):
    """The batched logits post-processor masks every logit except token 42 for
    each request in the batch; every request must produce the expected number
    of tokens, with all generated tokens equal to 42."""

    def logits_post_processor_batched(
            req_id_batch: tp.List[int], logits_batch: tp.List[torch.Tensor],
            ids_batch: tp.List[tp.List[tp.List[int]]], stream_ptr: int,
            client_id_batch: tp.List[tp.Optional[int]]):
        for client_id in client_id_batch:
            assert client_id == 123
        # Work on the stream handed in by the executor.
        with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)):
            for logits in logits_batch:
                logits[:] = float("-inf")
                logits[..., 42] = 0

    beam_width = 1
    executor_config = trtllm.ExecutorConfig(beam_width)
    # No named processors; only the batched callback is registered.
    executor_config.logits_post_processor_config = trtllm.LogitsPostProcessorConfig(
        None, logits_post_processor_batched)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    max_tokens = 5
    input_tokens = [1, 2, 3, 4]
    request = trtllm.Request(input_tokens,
                             max_tokens=max_tokens,
                             streaming=False,
                             client_id=123)
    # Route this request to the batched post-processor.
    request.logits_post_processor_name = request.BATCHED_POST_PROCESSOR_NAME

    batch_size = 4
    request_ids = [
        executor.enqueue_request(request) for _ in range(batch_size)
    ]

    # Poll with 1 ms timeouts; `i` counts polls and bounds the total wait.
    tokens = {req_id: [] for req_id in request_ids}
    num_finished = 0
    i = 0
    max_wait_ms = 10000
    while num_finished < len(request_ids) and i < max_wait_ms:
        responses = executor.await_responses(datetime.timedelta(milliseconds=1))
        for response in responses:
            req_id = response.request_id
            assert not response.has_error(
            ), f"Request id {req_id} failed with err {response.error_msg}"
            result = response.result
            if result.is_final:
                num_finished += 1
            new_tokens = result.output_token_ids[beam_width - 1]
            tokens[req_id].extend(new_tokens)
        # Must advance, otherwise the timeout bound never triggers and the
        # loop spins forever if a request never finishes.
        i += 1
    assert i < max_wait_ms, "timed out waiting for batched responses"

    expected_num_tokens = get_expected_num_tokens(len(input_tokens), max_tokens,
                                                  False, False)
    for req_id in request_ids:
        assert len(tokens[req_id]) == expected_num_tokens, f"{req_id}"
        # Same check as the non-batched test: only token 42 was unmasked.
        assert tokens[req_id][-max_tokens:] == [42] * max_tokens, f"{req_id}"
def test_kv_event_stream(model_path):
    """Check the KV-cache events reported while two overlapping prompts are
    served one after the other.

    The first event is the cache-creation event. After each request finishes,
    the newest events are fetched and checked against the expected
    store / move-to-secondary / remove sequence.
    """
    beam_width = 1
    executor_config = trtllm.ExecutorConfig(
        beam_width,
        # Positional args presumably (enable_block_reuse, max_tokens) — TODO
        # confirm; a host cache is configured so blocks can move to a
        # secondary level.
        kv_cache_config=trtllm.KvCacheConfig(True,
                                             4 * 64,
                                             event_buffer_max_size=1024,
                                             host_cache_size=3000000))
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)
    cache_manager = executor.get_kv_cache_event_manager()

    # Right after construction the only event is the cache-creation event.
    events = cache_manager.get_latest_events()
    assert len(events) == 1
    assert isinstance(events[0], trtllm.kv_cache.KVCacheEvent)
    assert events[0].event_id == 0
    assert isinstance(events[0].data, trtllm.kv_cache.KVCacheCreatedData)

    for req in range(2):
        # Prompts [0, 127) and [1, 128) share the tokens 1..126.
        input_tokens = list(range(req, req + 127))
        request = trtllm.Request(input_tokens,
                                 max_tokens=5,
                                 streaming=False,
                                 sampling_config=trtllm.SamplingConfig())
        request_id = executor.enqueue_request(request)
        # No timeout: block until the (non-streaming) response arrives.
        responses = executor.await_responses(request_id)
        saw_final = False
        for response in responses:
            assert not response.has_error()
            if response.result.is_final:
                saw_final = True
                # Presumably gives the event manager time to publish events.
                time.sleep(0.1)
                events = cache_manager.get_latest_events(
                    datetime.timedelta(milliseconds=100))
                if req == 0:
                    assert events[0].event_id == 1
                    assert isinstance(events[0].data,
                                      trtllm.kv_cache.KVCacheStoredData)
                    assert events[0].data.parent_hash is None
                    assert len(events[0].data.blocks) == 1
                    assert events[1].data.parent_hash == events[0].data.blocks[
                        0].block_hash
                    assert len(events[1].data.blocks) == 2
                else:
                    # Swap a block to secondary
                    assert isinstance(events[0].data,
                                      trtllm.kv_cache.KVCacheUpdatedData)
                    assert events[0].data.cache_level.old_value == 0
                    assert events[0].data.cache_level.new_value == 1
                    # Store the filled context block
                    assert isinstance(events[1].data,
                                      trtllm.kv_cache.KVCacheStoredData)
                    assert len(events[1].data.blocks) == 1
                    assert events[1].data.parent_hash is None
                    # Swap another block to secondary
                    assert isinstance(events[2].data,
                                      trtllm.kv_cache.KVCacheUpdatedData)
                    assert events[2].data.cache_level.old_value == 0
                    assert events[2].data.cache_level.new_value == 1
                    assert isinstance(events[2].data.cache_level,
                                      trtllm.kv_cache.KVCacheEventDiffInt)
                    # Remove the first block in secondary
                    assert isinstance(events[3].data,
                                      trtllm.kv_cache.KVCacheRemovedData)
                    assert len(events[3].data.block_hashes) == 1
                    assert events[3].data.block_hashes[0] == events[
                        0].data.block_hash
                    # Store the second context block and the decode block
                    assert isinstance(events[4].data,
                                      trtllm.kv_cache.KVCacheStoredData)
                    assert len(events[4].data.blocks) == 2
                    assert events[4].data.parent_hash == events[1].data.blocks[
                        0].block_hash
        # Without a final response every event check above would be skipped.
        assert saw_final, f"Request {request_id} produced no final response"
@pytest.mark.parametrize("streaming", [False, True])
def test_request_perf_metrics(streaming: bool, model_path):
    """Every response to a request with ``return_perf_metrics`` carries
    consistent timing, KV-cache and iteration metrics, in both streaming and
    non-streaming mode."""
    beam_width = 1
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               trtllm.ExecutorConfig(beam_width))

    max_tokens = 5
    input_tokens = [1, 2, 3, 4]
    request = trtllm.Request(
        input_tokens,
        max_tokens=max_tokens,
        streaming=streaming,
        output_config=trtllm.OutputConfig(return_perf_metrics=True))
    request_id = executor.enqueue_request(request)

    def check_perf_metrics(metrics, is_final, response_idx):
        assert metrics is not None

        timing = metrics.timing_metrics
        assert timing.arrival_time < timing.first_scheduled_time
        assert timing.first_scheduled_time < timing.first_token_time
        if is_final:
            assert timing.first_token_time < timing.last_token_time
        else:
            # The last-token timestamp is only filled in on the final response.
            assert timing.last_token_time == datetime.timedelta(0)

        kv = metrics.kv_cache_metrics
        assert kv.num_total_allocated_blocks == 1
        assert kv.num_new_allocated_blocks == 1
        assert kv.num_reused_blocks == 0
        assert kv.num_missed_blocks == 1
        assert kv.kv_cache_hit_rate == 0

        assert metrics.first_iter == 0
        if is_final:
            assert metrics.iter == (max_tokens - 1)
            assert metrics.last_iter == max_tokens - 1
        else:
            assert metrics.iter == response_idx
            assert metrics.last_iter is None

    # Poll with 1 ms timeouts; `polls` bounds the total wait.
    tokens = []
    done = False
    polls = 0
    max_wait_ms = 10000
    response_idx = 0
    while not done and polls < max_wait_ms:
        for response in executor.await_responses(
                request_id, datetime.timedelta(milliseconds=1)):
            assert not response.has_error(
            ), f"Request id {request_id} failed with err {response.error_msg}"
            result = response.result
            done = result.is_final
            check_perf_metrics(result.request_perf_metrics, done, response_idx)
            tokens.extend(result.output_token_ids[beam_width - 1])
            response_idx += 1
        polls += 1

    assert polls < max_wait_ms
    assert len(tokens) == get_expected_num_tokens(
        len(input_tokens),
        max_tokens,
        streaming=streaming,
        exclude_input_from_output=False), f"{request_id}"
def test_kv_event_stream_timeout(model_path):
    """``get_latest_events`` with a timeout and no pending events must wait
    roughly that long and then return an empty list."""
    beam_width = 1
    executor_config = trtllm.ExecutorConfig(
        beam_width,
        kv_cache_config=trtllm.KvCacheConfig(True,
                                             4 * 64,
                                             event_buffer_max_size=1024))
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)
    cache_manager = executor.get_kv_cache_event_manager()

    # Drain the initial cache-creation event.
    events = cache_manager.get_latest_events()
    assert len(events) == 1

    # Monotonic clock: wall-clock time can jump and skew the measurement.
    start = time.monotonic()
    events = cache_manager.get_latest_events(datetime.timedelta(seconds=1))
    elapsed_s = time.monotonic() - start

    # Make sure that it actually waited.
    assert elapsed_s > 0.9
    assert len(events) == 0
def test_iteration_stats():
    """IterationStats serializes its fields under camelCase JSON keys, and
    the unset nested stats serialize as null."""
    stats = trtllm.IterationStats()
    assigned = {
        "timestamp": "01:23:56",
        "iter": 1,
        "iter_latency_ms": 100,
        "num_active_requests": 2,
        "num_queued_requests": 10,
        "max_num_active_requests": 3,
        "gpu_mem_usage": 1024,
        "cpu_mem_usage": 2048,
        "pinned_mem_usage": 4096,
    }
    for attr, value in assigned.items():
        setattr(stats, attr, value)

    payload = json.loads(stats.to_json_str())
    json_key_to_attr = {
        "timestamp": "timestamp",
        "iter": "iter",
        "iterLatencyMS": "iter_latency_ms",
        "newActiveRequestsQueueLatencyMS":
        "new_active_requests_queue_latency_ms",
        "numActiveRequests": "num_active_requests",
        "numQueuedRequests": "num_queued_requests",
        "numCompletedRequests": "num_completed_requests",
        "maxNumActiveRequests": "max_num_active_requests",
        "gpuMemUsage": "gpu_mem_usage",
        "cpuMemUsage": "cpu_mem_usage",
        "pinnedMemUsage": "pinned_mem_usage",
    }
    for key, attr in json_key_to_attr.items():
        assert payload[key] == getattr(stats, attr)
    for key in ("kvCacheStats", "staticBatchingStats",
                "inflightBatchingStats"):
        assert payload[key] is None
def test_request_stats():
    """RequestStats serializes its fields under camelCase JSON keys; the stage
    enum is written by name."""
    stats = trtllm.RequestStats()
    stats.id = 1
    stats.stage = trtllm.RequestStage.CONTEXT_IN_PROGRESS
    stats.context_prefill_position = 2
    stats.num_generated_tokens = 3
    stats.avg_num_decoded_tokens_per_iter = 2.5
    stats.scheduled = True
    stats.paused = False

    payload = json.loads(stats.to_json_str())
    assert payload["id"] == stats.id
    assert payload["stage"] == "CONTEXT_IN_PROGRESS"
    json_key_to_attr = {
        "contextPrefillPosition": "context_prefill_position",
        "numGeneratedTokens": "num_generated_tokens",
        "avgNumDecodedTokensPerIter": "avg_num_decoded_tokens_per_iter",
        "scheduled": "scheduled",
        "paused": "paused",
    }
    for key, attr in json_key_to_attr.items():
        assert payload[key] == getattr(stats, attr)
    assert payload["disServingStats"] is None
def test_request_stats_per_iteration():
    """RequestStatsPerIteration serializes its iteration and nested stats."""
    single = trtllm.RequestStats()
    single.id = 1

    per_iter = trtllm.RequestStatsPerIteration()
    per_iter.iter = 1
    per_iter.request_stats = [single]

    payload = json.loads(per_iter.to_json_str())
    assert payload["iter"] == 1
    assert payload["requestStats"][0]["id"] == 1
def test_scheduler_config_pickle():
    """SchedulerConfig keeps its capacity policy across a pickle round trip."""
    original = trtllm.SchedulerConfig(
        trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION)
    restored = pickle.loads(pickle.dumps(original))
    assert original.capacity_scheduler_policy == restored.capacity_scheduler_policy
def test_kv_cache_config_pickle():
    """KvCacheConfig fields, both explicitly set and defaulted, survive pickling."""
    config = trtllm.KvCacheConfig()
    overrides = {
        "enable_block_reuse": True,
        "free_gpu_memory_fraction": 0.3,
        "cross_kv_cache_fraction": 0.5,
        "event_buffer_max_size": 1024,
    }
    for field_name, field_value in overrides.items():
        setattr(config, field_name, field_value)

    restored = pickle.loads(pickle.dumps(config))

    # Covers the fields overridden above as well as ones left at their defaults.
    checked_fields = (
        "enable_block_reuse",
        "max_tokens",
        "max_attention_window",
        "sink_token_length",
        "free_gpu_memory_fraction",
        "cross_kv_cache_fraction",
        "host_cache_size",
        "onboard_blocks",
        "event_buffer_max_size",
    )
    for field_name in checked_fields:
        assert getattr(config, field_name) == getattr(restored, field_name)
def test_peft_cache_config_pickle():
    """A positionally-constructed PeftCacheConfig must round-trip through pickle."""
    ctor_args = (1, 2, 3, 4, 5, 6, 7, 8, 9, 0.9, 1024)
    original = trtllm.PeftCacheConfig(*ctor_args)
    restored = pickle.loads(pickle.dumps(original))

    for attr in [
            "num_host_module_layer",
            "num_device_module_layer",
            "optimal_adapter_size",
            "max_adapter_size",
            "num_put_workers",
            "num_ensure_workers",
            "num_copy_streams",
            "max_pages_per_block_host",
            "max_pages_per_block_device",
            "device_cache_percent",
            "host_cache_size",
    ]:
        assert getattr(original, attr) == getattr(restored, attr)
def test_decoding_config_pickle():
    """DecodingConfig must keep its decoding mode and optional sub-configs across pickling."""
    config = trtllm.DecodingConfig(
        decoding_mode=trtllm.DecodingMode.BeamSearch())
    config_copy = pickle.loads(pickle.dumps(config))
    # isBeamSearch is a method. Without the call, the bound-method object is
    # always truthy and this assertion could never fail.
    assert config_copy.decoding_mode.isBeamSearch()
    assert config.lookahead_decoding_config == config_copy.lookahead_decoding_config
    assert config.medusa_choices == config_copy.medusa_choices
def test_debug_config_pickle():
    """DebugConfig must round-trip through pickle with every constructor field preserved."""
    settings = dict(debug_input_tensors=True,
                    debug_output_tensors=True,
                    debug_tensor_names=["test"],
                    debug_tensors_max_iterations=5)
    config = trtllm.DebugConfig(**settings)
    restored = pickle.loads(pickle.dumps(config))
    for setting_name in settings:
        assert getattr(config, setting_name) == getattr(restored, setting_name)
def test_logits_post_processor_config_pickle():
    """LogitsPostProcessorConfig must survive pickling field-for-field."""
    init_kwargs = dict(processor_map={"test_pp": None},
                       processor_batched=None,
                       replicate=False)
    original = trtllm.LogitsPostProcessorConfig(**init_kwargs)
    restored = pickle.loads(pickle.dumps(original))
    for attr_name in init_kwargs.keys():
        before = getattr(original, attr_name)
        after = getattr(restored, attr_name)
        assert before == after
def test_guided_decoding_params_pickle():
    """GuidedDecodingParams must keep its guide type and JSON schema across pickling."""

    class Answer(BaseModel):
        answer: int

    schema_text = json.dumps(Answer.model_json_schema())
    guide_kind = trtllm.GuidedDecodingParams.GuideType.JSON_SCHEMA
    original = trtllm.GuidedDecodingParams(guide_kind, guide=schema_text)
    restored = pickle.loads(pickle.dumps(original))
    for attr_name in ("guide_type", "guide"):
        assert getattr(restored, attr_name) == getattr(original, attr_name)
def test_guided_decoding_config_pickle():
    """GuidedDecodingConfig with the XGrammar backend must round-trip through pickle."""
    init_kwargs = {
        "backend": trtllm.GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR,
        "encoded_vocab": ["eos", "a", "b", "c", "d"],
        "tokenizer_str": None,
        "stop_token_ids": [0],
    }
    config = trtllm.GuidedDecodingConfig(**init_kwargs)
    restored = pickle.loads(pickle.dumps(config))
    for attr_name in init_kwargs:
        assert getattr(restored, attr_name) == getattr(config, attr_name)
def test_executor_config_pickle():
    """ExecutorConfig, including its nested sub-configs, must survive pickling.

    Non-config constructor arguments are first checked on the live object.
    The pickled copy is then compared against the original, with each nested
    ``*_config`` spot-checked through one representative attribute.
    """
    kwargs = {
        "max_beam_width": 2,
        "max_batch_size": 8,
        "max_num_tokens": 128,
        "scheduler_config":
        trtllm.SchedulerConfig(trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION),
        "kv_cache_config": trtllm.KvCacheConfig(enable_block_reuse=True),
        "enable_chunked_context": True,
        "normalize_log_probs": False,
        "iter_stats_max_iterations": 100,
        "batching_type": trtllm.BatchingType.STATIC,
        "parallel_config": trtllm.ParallelConfig(),
        "peft_cache_config": trtllm.PeftCacheConfig(10),
        "logits_post_processor_config": trtllm.LogitsPostProcessorConfig(),
        "decoding_config":
        trtllm.DecodingConfig(trtllm.DecodingMode.TopKTopP()),
        "extended_runtime_perf_knob_config":
        trtllm.ExtendedRuntimePerfKnobConfig(multi_block_mode=True),
        "debug_config":
        trtllm.DebugConfig(debug_input_tensors=True,
                           debug_output_tensors=True,
                           debug_tensor_names=["test"]),
        "recv_poll_period_ms": 50,
        "max_seq_idle_microseconds": 240 * 1000 * 1000,
        "spec_dec_config": trtllm.SpeculativeDecodingConfig(fast_logits=True),
    }
    config = trtllm.ExecutorConfig(**kwargs)
    for k, v in kwargs.items():
        # Nested *_config objects are compared below via representative
        # attributes. Presumably the bound config classes do not provide
        # value equality, so they are not compared directly here.
        if "config" not in k:
            assert getattr(config, k) == v
    config.backend = 'pytorch'

    config_copy = pickle.loads(pickle.dumps(config))

    assert config.max_beam_width == config_copy.max_beam_width
    assert config.max_batch_size == config_copy.max_batch_size
    assert config.max_num_tokens == config_copy.max_num_tokens
    assert config.scheduler_config.capacity_scheduler_policy == config_copy.scheduler_config.capacity_scheduler_policy
    assert config.kv_cache_config.enable_block_reuse == config_copy.kv_cache_config.enable_block_reuse
    assert config.enable_chunked_context == config_copy.enable_chunked_context
    assert config.normalize_log_probs == config_copy.normalize_log_probs
    assert config.iter_stats_max_iterations == config_copy.iter_stats_max_iterations
    assert config.batching_type == config_copy.batching_type
    assert config.parallel_config.communication_type == config_copy.parallel_config.communication_type
    assert config.peft_cache_config.num_host_module_layer == config_copy.peft_cache_config.num_host_module_layer
    # isTopKandTopP is a method. Without the call, the bound-method object is
    # always truthy and this assertion could never fail.
    assert config_copy.decoding_config.decoding_mode.isTopKandTopP()
    assert config.extended_runtime_perf_knob_config.multi_block_mode == config_copy.extended_runtime_perf_knob_config.multi_block_mode
    assert config.debug_config.debug_input_tensors == config_copy.debug_config.debug_input_tensors
    assert config.max_seq_idle_microseconds == config_copy.max_seq_idle_microseconds
    assert config.backend == config_copy.backend
    assert config.spec_dec_config.fast_logits == config_copy.spec_dec_config.fast_logits
def test_return_full_tokens():
    """Request.return_all_generated_tokens can be toggled on and then off again."""
    max_tokens = 5
    input_tokens = [1, 2, 3, 4]
    request = trtllm.Request(input_tokens,
                             max_tokens=max_tokens,
                             streaming=False,
                             sampling_config=trtllm.SamplingConfig())
    request.return_all_generated_tokens = True
    # Identity checks (PEP 8, ruff E712) also confirm a genuine bool comes back.
    assert request.return_all_generated_tokens is True
    request.return_all_generated_tokens = False
    assert request.return_all_generated_tokens is False
def test_getters_return_references():
    """Nested config getters on ExecutorConfig must hand back references, not copies."""
    executor_config = trtllm.ExecutorConfig()
    # If the getter returned a copy, the write below would land on a temporary
    # and be silently lost, and the read-back would not observe 42.
    kv_cache_cfg = executor_config.kv_cache_config
    kv_cache_cfg.max_tokens = 42
    assert executor_config.kv_cache_config.max_tokens == 42
def test_allotted_time_ms():
    """Request.allotted_time_ms accepts and returns a datetime.timedelta."""
    budget = datetime.timedelta(milliseconds=2)
    request = trtllm.Request([1, 2, 3, 4], max_tokens=5)
    request.allotted_time_ms = budget
    assert request.allotted_time_ms == datetime.timedelta(milliseconds=2)
def test_executor_version():
    """The executor bindings must report the same version as the Python package."""
    bindings_version = trtllm.__version__
    package_version = trtllm_version.__version__
    assert bindings_version == package_version
def get_all_field_names_of_class(cls: type) -> list[str]:
    """Return the names of the data-like members of ``cls``.

    A member counts as a field when it is a ``property``, or when it is
    neither callable nor a dunder (``__name``-prefixed) attribute. Names are
    returned in ``inspect.getmembers`` order, which is sorted by name.
    """
    field_names: list[str] = []
    for member_name, member in inspect.getmembers(cls):
        if isinstance(member, property):
            field_names.append(member_name)
        elif not callable(member) and not member_name.startswith('__'):
            field_names.append(member_name)
    return field_names
def test_runtime_defaults():
    """RuntimeDefaults exposes exactly the expected fields, all None by default.

    Also checks that a dict parsed from raw JSON can be handed directly to
    ``PretrainedConfig.create_runtime_defaults``.
    """
    full_runtime_defaults: dict[str, tp.Any] = json.loads("""{
        "max_attention_window": [1, 2],
        "sink_token_length": 4
    }""")
    all_field_names = set(full_runtime_defaults)
    assert set(
        get_all_field_names_of_class(trtllm.RuntimeDefaults)
    ) == all_field_names, "Expected fields of runtime_defaults to match actual data"
    msg = """\
Rather than create a `from_dict` on top of the bound class, \
we rely on being able to directly provide the dict created from raw json as kwargs to `RuntimeDefaults.` \
See: `PretrainedConfig.__init__()`"""
    assert PretrainedConfig.create_runtime_defaults(
        full_runtime_defaults) is not None, msg
    default_runtime_defaults = trtllm.RuntimeDefaults()
    for key in all_field_names:
        # Identity comparison with the None singleton (PEP 8, ruff E711).
        assert getattr(default_runtime_defaults, key) is None
def test_DynamicBatchConfig_pickle():
config = trtllm.DynamicBatchConfig(enable_batch_size_tuning=True,
enable_max_num_tokens_tuning=True,
dynamic_batch_moving_average_window=128)
config_copy = pickle.loads(pickle.dumps(config))
assert config.enable_batch_size_tuning == config_copy.enable_batch_size_tuning