# TensorRT-LLM: tests/bindings/test_executor_bindings.py

import datetime
import os
import random
import sys
import time
import typing as tp
from pathlib import Path

import numpy as np
import pytest
import torch
from binding_test_utils import *

import tensorrt_llm.bindings.executor as trtllm

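# Make the shared test utilities under tests/ importable.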
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.util import skip_pre_ampere


@pytest.fixture
def model_path(engine_path):
    return engine_path / "gpt2/fp16-plugin-packed-paged/tp1-pp1-gpu"


@pytest.fixture
def model_path_return_logits(engine_path):
    return engine_path / "gpt2/fp16-plugin-packed-paged-gather/tp1-pp1-gpu"


@pytest.fixture
def input_data_path(data_path):
    return data_path / "input_tokens.npy"


@pytest.fixture(scope="module")
def results_data_path(data_path: Path) -> Path:
    return data_path / "gpt2/sampling/output_tokens_fp16_plugin_packed_paged_tp1_pp1.npy"


@pytest.fixture(scope="module")
def results_data_path_beam_width_2(data_path: Path) -> Path:
    return data_path / "gpt2/beam_search_2/output_tokens_fp16_plugin_packed_paged_tp1_pp1.npy"


@pytest.fixture
def model_files(llm_root: Path, resource_path: Path, llm_model_root,
                results_data_path):
    # Model engines and expected outputs need to be generated.
    if not results_data_path.exists():
        model_cache_arg = ["--model_cache", str(llm_model_root)
                           ] if llm_model_root is not None else []
        prepare_model_tests(llm_root, resource_path, "gpt", model_cache_arg)


def get_expected_num_tokens(prompt_len, max_new_tokens, streaming,
                            exclude_input_from_output):
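    # Non-streaming requests echo the prompt back unless it is explicitly
    # excluded, so the expected total is prompt length plus new tokens.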
    if not streaming and not exclude_input_from_output:
        return prompt_len + max_new_tokens
    return max_new_tokens


@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_executor_valid_ctor(model_files, model_path):
    executor_config = trtllm.ExecutorConfig(1)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)


@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_executor_from_memory(model_files, model_path):
    executor_config = trtllm.ExecutorConfig(1)
    engine_buffer = (model_path / "rank0.engine").read_bytes()
    json_config_str = (model_path / "config.json").read_text()
    executor = trtllm.Executor(engine_buffer, json_config_str,
                               trtllm.ModelType.DECODER_ONLY, executor_config)


def test_executor_invalid_ctor():
    executor_config = trtllm.ExecutorConfig(1)
    invalid_path = "Bla"
    with pytest.raises(Exception, match="File does not exist"):
        trtllm.Executor(invalid_path, trtllm.ModelType.DECODER_ONLY,
                        executor_config)


@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_embedding_bias(model_files, model_path):
    streaming = False
    exclude_input_from_output = False
    output_config = trtllm.OutputConfig()
    output_config.exclude_input_from_output = exclude_input_from_output

    # Create executor
    beam_width = 1
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    # Create the request
    max_new_tokens = 5
    input_tokens = [1, 2, 3, 4]
    # Set embedding bias so "biased_output" is always picked
    biased_output = 10
    vocab_size_padded = 50257
    embedding_bias = torch.zeros(vocab_size_padded)
    embedding_bias[biased_output] = torch.finfo(torch.float32).max
    request = trtllm.Request(input_tokens,
                             max_new_tokens,
                             streaming,
                             trtllm.SamplingConfig(),
                             output_config,
                             embedding_bias=embedding_bias)

    # Enqueue the request
    request_id = executor.enqueue_request(request)

    # Get the new tokens
    tokens = []
    done = False
    i = 0
    max_wait_ms = 10000
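    # Poll in 1 ms steps; the loop counter doubles as a coarse timeout so a
    # hung request cannot block the test indefinitely.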
    while not done and i < max_wait_ms:
        wait_time = datetime.timedelta(milliseconds=1)
        responses = executor.await_responses(request_id, wait_time)
        for response in responses:
            assert not response.has_error(
            ), f"Request id {request_id} failed with err {response.error_msg}"
            result = response.result
            done = result.is_final
            new_tokens = result.output_token_ids[beam_width - 1]
            tokens.extend(new_tokens)
        i += 1
    assert i < max_wait_ms

    assert len(tokens) == get_expected_num_tokens(
        len(input_tokens), max_new_tokens, streaming,
        exclude_input_from_output), f"{request_id}"

    # All generated tokens should equal biased_output
    assert tokens[-max_new_tokens:] == [biased_output] * max_new_tokens


@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize("exclude_input_from_output", [False])
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_single_request(streaming: bool, exclude_input_from_output: bool,
                        model_files, model_path):
    output_config = trtllm.OutputConfig()
    output_config.exclude_input_from_output = exclude_input_from_output

    # Create executor
    beam_width = 1
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    # Create the request
    max_new_tokens = 5
    input_tokens = [1, 2, 3, 4]
    request = trtllm.Request(input_tokens, max_new_tokens, streaming,
                             trtllm.SamplingConfig(), output_config)

    # Enqueue the request
    request_id = executor.enqueue_request(request)

    # Get the new tokens
    tokens = []
    done = False
    i = 0
    max_wait_ms = 10000
    while not done and i < max_wait_ms:
        wait_time = datetime.timedelta(milliseconds=1)
        responses = executor.await_responses(request_id, wait_time)
        for response in responses:
            assert not response.has_error(
            ), f"Request id {request_id} failed with err {response.error_msg}"
            result = response.result
            done = result.is_final
            new_tokens = result.output_token_ids[beam_width - 1]
            tokens.extend(new_tokens)
        i += 1
    assert i < max_wait_ms

    assert len(tokens) == get_expected_num_tokens(
        len(input_tokens), max_new_tokens, streaming,
        exclude_input_from_output), f"{request_id}"

    executor.get_latest_iteration_stats()
    executor.get_latest_request_stats()


@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize("exclude_input_from_output", [False])
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_multi_request(streaming: bool, exclude_input_from_output: bool,
                       model_files, model_path):
    output_config = trtllm.OutputConfig()
    output_config.exclude_input_from_output = exclude_input_from_output

    # Create executor
    beam_width = 1
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    num_requests = 20
    max_prompt_len = 20
    max_max_new_tokens = 20
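    # end_id = -1 never matches a generated token, so every request runs for
    # its full max_new_tokens.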
    end_id = -1

    # Enqueue the requests
    tokens = {}
    expected_num_tokens = {}
    for i in range(num_requests):
        prompt_len = random.randint(1, max_prompt_len)
        max_new_tokens = random.randint(1, max_max_new_tokens)
        input_tokens = [1] * prompt_len
        request = trtllm.Request(input_tokens, max_new_tokens, streaming,
                                 trtllm.SamplingConfig(), output_config,
                                 end_id)
        request_id = executor.enqueue_request(request)
        tokens[request_id] = []
        expected_num_tokens[request_id] = get_expected_num_tokens(
            prompt_len, max_new_tokens, streaming, exclude_input_from_output)

    # Get the new tokens for each request
    num_finished = 0
    i = 0
    num_responses = 0
    max_wait_ms = 10000
    while num_finished < num_requests and i < max_wait_ms:
        wait_time = datetime.timedelta(milliseconds=1)
        responses = executor.await_responses(None, wait_time)
        for response in responses:
            num_responses += 1
            assert not response.has_error(
            ), f"Request id {response.request_id} failed with err {response.error_msg}"
            result = response.result
            num_finished += result.is_final
            new_tokens = result.output_token_ids[beam_width - 1]
            tokens[response.request_id].extend(new_tokens)
        i += 1
    assert i < max_wait_ms

    for request_id in expected_num_tokens:
        assert len(tokens[request_id]) == expected_num_tokens[request_id]


@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize("exclude_input_from_output", [False])
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_get_num_responses_ready(streaming: bool,
                                 exclude_input_from_output: bool, model_files,
                                 model_path):
    output_config = trtllm.OutputConfig()
    output_config.exclude_input_from_output = exclude_input_from_output

    # Create executor
    executor_config = trtllm.ExecutorConfig(1)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    max_prompt_len = 20
    max_max_new_tokens = 20

    # Enqueue the requests
    num_requests = random.randint(1, 50)
    num_expected_responses = 0
    req_num_expected_responses = {}
    for i in range(num_requests):
        prompt_len = random.randint(1, max_prompt_len)
        max_new_tokens = random.randint(1, max_max_new_tokens)
        request = trtllm.Request([1] * prompt_len, max_new_tokens, streaming,
                                 trtllm.SamplingConfig(), output_config)
        request_id = executor.enqueue_request(request)
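        # Streaming yields one response per generated token; otherwise a
        # single final response is produced per request.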
        req_num_expected_responses[
            request_id] = max_new_tokens if streaming else 1
        num_expected_responses += req_num_expected_responses[request_id]

    i = 0
    num_ready = 0
    max_wait_ms = 10000
    while num_ready < num_expected_responses and i < max_wait_ms:
        num_ready = 0
        for request_id in req_num_expected_responses:
            num_ready += executor.get_num_responses_ready(request_id)
        time.sleep(0.001)
        i += 1
    assert i < max_wait_ms

    for request_id in req_num_expected_responses:
        num_ready = executor.get_num_responses_ready(request_id)
        assert num_ready == req_num_expected_responses[request_id]
    assert executor.get_num_responses_ready() == num_expected_responses


@pytest.mark.parametrize("batching_type", [trtllm.BatchingType.INFLIGHT])
@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize("beam_width", [1])
@pytest.mark.parametrize("compute_log_probs", [False, True])
@pytest.mark.parametrize("exclude_input_from_output", [False])
@pytest.mark.parametrize("return_context_logits", [False, True])
@pytest.mark.parametrize("return_generation_logits", [False, True])
@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_token_comparison(batching_type: trtllm.BatchingType, streaming: bool,
                          beam_width: int, compute_log_probs: bool,
                          exclude_input_from_output: bool,
                          return_context_logits: bool,
                          return_generation_logits: bool, model_files,
                          model_path, model_path_return_logits,
                          input_data_path, results_data_path,
                          results_data_path_beam_width_2):
    if streaming and beam_width > 1:
        pytest.skip("Test does not support streaming with beam search")

    vocab_size_padded = 50257
    pad_id = 50256
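    # When the prompt is echoed back in the output (non-streaming and not
    # excluded), it has to be stripped before comparing with the references.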
    remove_input = not exclude_input_from_output and not streaming

    def load_test_data(input_path, results_path):
        # Inputs
        assert input_path.is_file()
        given_input = np.load(input_path).astype("int32")
        input_shape = given_input.shape
        assert len(input_shape) == 2
        max_input_length = input_shape[1]
        given_input_lengths = sequence_lengths(given_input, pad_id)
        assert np.all(given_input_lengths <= max_input_length)

        # Expected results
        assert results_path.is_file()
        expected_outputs = np.load(results_path).astype("int32")
        output_shape = expected_outputs.shape
        assert len(output_shape) == 2
        assert input_shape[0] * beam_width == output_shape[0]
        max_seq_length = output_shape[1]
        max_new_tokens = max_seq_length - max_input_length

        end_ids = [pad_id for _ in range(len(given_input_lengths))]
        expected_lengths = []
        for i in range(len(given_input_lengths)):
            expected_lengths.append([
                given_input_lengths[i] + max_new_tokens
                for _ in range(beam_width)
            ])

        test_data = {
            "expected_output_ids": expected_outputs,
            "expected_output_lengths": expected_lengths,
            "max_seq_length": max_seq_length,
            "end_ids": end_ids
        }
        return given_input, given_input_lengths, max_input_length, test_data

    def validate_results_shapes(result, input_length, max_output_len,
                                beam_tokens):
        if compute_log_probs:
            assert result.cum_log_probs is not None
            assert result.log_probs is not None
            assert len(result.cum_log_probs) == beam_width
            assert len(result.log_probs) == beam_width
            for beam in range(beam_width):
                expected_len = len(
                    beam_tokens[beam]) - (input_length if remove_input else 0)
                assert len(result.log_probs[beam]) == expected_len
        else:
            assert result.cum_log_probs is None
            assert result.log_probs is None
        if return_context_logits:
            assert result.context_logits is not None
            assert len(result.context_logits.shape) == 2
            assert list(result.context_logits.shape) == [
                input_length, vocab_size_padded
            ]
        else:
            assert result.context_logits is None
        if return_generation_logits:
            assert len(result.generation_logits.shape) == 3
            assert list(result.generation_logits.shape) == [
                beam_width, max_output_len, vocab_size_padded
            ]

    def verify_output(beam_tokens, test_data, given_input_lengths):
        for batch_id, tokens in beam_tokens.items():
            input_length = given_input_lengths[batch_id]
            end_id = test_data["end_ids"][batch_id]
            for beam in range(beam_width):
                predicted_tokens = tokens[beam]
                if remove_input:
                    predicted_tokens = predicted_tokens[input_length:]
                expected_length = test_data["expected_output_lengths"][
                    batch_id][beam] - input_length
                assert len(predicted_tokens) == expected_length
                expected_tokens = test_data["expected_output_ids"][
                    batch_id * beam_width + beam][input_length:]
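                # Compare token by token, stopping at the first end_id in the
                # reference since everything after it is padding.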
                for i in range(len(predicted_tokens)):
                    if expected_tokens[i] == end_id:
                        break
                    assert predicted_tokens[i] == expected_tokens[i], \
                        f"Predicted: {predicted_tokens} vs Expected: {expected_tokens}"

    output_config = trtllm.OutputConfig()
    output_config.exclude_input_from_output = exclude_input_from_output
    output_config.return_log_probs = compute_log_probs
    output_config.return_generation_logits = return_generation_logits
    output_config.return_context_logits = return_context_logits
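    # KvCacheConfig(False, ...) disables block reuse and caps the KV cache at
    # half of the free GPU memory.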
    kv_cache_config = trtllm.KvCacheConfig(False, free_gpu_memory_fraction=0.5)
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor_config.batching_type = batching_type
    executor_config.kv_cache_config = kv_cache_config
    if return_context_logits or return_generation_logits:
        model_path = model_path_return_logits
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    # Load test data
    results_path = results_data_path if beam_width == 1 else results_data_path_beam_width_2
    given_input, given_input_lengths, max_input_length, test_data = load_test_data(
        input_data_path, results_path)

    # Create requests from input data
    num_requests = len(given_input_lengths)
    requests = []
    req_max_new_tokens = []
    for i in range(num_requests):
        input_len = given_input_lengths[i]
        max_new_tokens = test_data["max_seq_length"] - max_input_length
        req_max_new_tokens.append(max_new_tokens)
        req_tokens = given_input[i][:input_len]
        requests.append(
            trtllm.Request(req_tokens,
                           max_new_tokens,
                           streaming,
                           trtllm.SamplingConfig(beam_width),
                           output_config,
                           end_id=-1))

    req_ids = executor.enqueue_requests(requests)
    req_to_batch_id = {req_ids[i]: i for i in range(len(requests))}
    tokens = {i: [[] for _ in range(beam_width)] for i in range(len(requests))}

    num_finished = 0
    i = 0
    num_responses = 0
    max_wait_ms = 10000
    while num_finished < num_requests and i < max_wait_ms:
        wait_time = datetime.timedelta(milliseconds=1)
        responses = executor.await_responses(None, wait_time)
        for response in responses:
            num_responses += 1
            assert not response.has_error(
            ), f"Request id {response.request_id} failed with err {response.error_msg}"
            result = response.result
            num_finished += result.is_final
            batch_id = req_to_batch_id[response.request_id]
            for beam in range(beam_width):
                new_tokens = result.output_token_ids[beam]
                tokens[batch_id][beam] += new_tokens
            validate_results_shapes(result, given_input_lengths[batch_id],
                                    req_max_new_tokens[batch_id],
                                    tokens[batch_id])
        i += 1
    assert i < max_wait_ms
    verify_output(tokens, test_data, given_input_lengths)


@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_gpt_executor_timed_out(model_files, model_path):
    beam_width = 1
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    # No requests enqueued, expect no responses
    num_responses_ready = executor.get_num_responses_ready()
    assert num_responses_ready == 0

    wait_time = datetime.timedelta(milliseconds=10)
    responses = executor.await_responses(None, wait_time)
    assert len(responses) == 0


@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_single_request_invalid_inputs(model_files, model_path):
    streaming = True
    beam_width = 1
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    max_new_tokens = 5
    input_tokens = [1, 2, 3, 4]
    request = trtllm.Request(input_tokens, max_new_tokens, streaming)

    # Invalid embedding bias shape
    embedding_bias = torch.ones(1)
    request.embedding_bias = embedding_bias
    expected_error_msg = "embedding bias shape is not as expected"

    request_id = executor.enqueue_request(request)

    done = False
    i = 0
    max_wait_ms = 10000
    while not done and i < max_wait_ms:
        wait_time = datetime.timedelta(milliseconds=1)
        responses = executor.await_responses(request_id, wait_time)
        for response in responses:
            assert response.has_error(), "Expected an error"
            assert expected_error_msg in response.error_msg
            done = True
        i += 1
    assert done


def test_sampling_config():
    beam_width = 1
    kwargs = {
        "top_k": 2,
        "top_p": 1.0,
        "top_p_min": 1.0,
        "top_p_reset_ids": 3,
        "top_p_decay": 1.0,
        "random_seed": 7,
        "temperature": 1.0,
        "min_length": 4,
        "beam_search_diversity_rate": 1.0,
        "repetition_penalty": 1.0,
        "presence_penalty": 1.0,
        "frequency_penalty": 1.0,
        "length_penalty": 1.0,
        "early_stopping": 5
    }
    config = trtllm.SamplingConfig(beam_width, **kwargs)
    for k, v in kwargs.items():
        assert getattr(config, k) == v
    del config

    config = trtllm.SamplingConfig(beam_width)
    assert config.beam_width == beam_width
    for k in kwargs:
        assert getattr(config, k) is None


def test_output_config():
    config = trtllm.OutputConfig()
    assert config.return_log_probs == False
    assert config.return_context_logits == False
    assert config.return_generation_logits == False
    assert config.exclude_input_from_output == False

    config = trtllm.OutputConfig(True, False, True, False)
    assert config.return_log_probs == True
    assert config.return_context_logits == False
    assert config.return_generation_logits == True
    assert config.exclude_input_from_output == False


def test_speculative_decoding_config():
    tokens = [1, 2, 3]
    config = trtllm.SpeculativeDecodingConfig(tokens)
    assert config.tokens == tokens
    assert config.logits is None
    assert config.acceptance_threshold is None
    del config

    logits = torch.ones(3, 1)
    acceptance_threshold = 1.0
    config = trtllm.SpeculativeDecodingConfig(tokens, logits,
                                              acceptance_threshold)
    assert config.tokens == tokens
    assert (config.logits == logits).all()
    assert config.acceptance_threshold == acceptance_threshold


def test_prompt_tuning_config():
    embedding_table = torch.ones(100, 64)
    config = trtllm.PromptTuningConfig(embedding_table)
    assert (config.embedding_table == embedding_table).all()


def test_lora_config():
    task_id = 1
    lora_config = trtllm.LoraConfig(task_id)
    assert lora_config.task_id == task_id
    assert lora_config.weights is None
    assert lora_config.config is None

    task_id = 2
    weights = torch.ones(1, 2)
    config = torch.ones(1, 2, dtype=torch.int32)
    lora_config = trtllm.LoraConfig(task_id, weights, config)
    assert lora_config.task_id == task_id
    assert (lora_config.weights == weights).all()
    assert (lora_config.config == config).all()


def test_request():
    kwargs = {
        "input_token_ids": [1, 2, 3],
        "max_new_tokens": 1,
        "streaming": False,
        "sampling_config": trtllm.SamplingConfig(),
        "output_config": trtllm.OutputConfig(),
        "end_id": -1,
        "pad_id": -2,
        "bad_words": [[4, 5, 6]],
        "stop_words": [[7, 8, 9]],
        "embedding_bias": torch.ones(1),
        "speculative_decoding_config":
        trtllm.SpeculativeDecodingConfig([1, 2, 3]),
        "prompt_tuning_config": trtllm.PromptTuningConfig(torch.ones(100, 64)),
        "lora_config": trtllm.LoraConfig(1)
    }
    request = trtllm.Request(**kwargs)
    for k, v in kwargs.items():
        if "config" not in k:
            assert getattr(request, k) == v
    assert isinstance(request.sampling_config, trtllm.SamplingConfig)
    assert isinstance(request.output_config, trtllm.OutputConfig)
    assert isinstance(request.speculative_decoding_config,
                      trtllm.SpeculativeDecodingConfig)
    assert request.speculative_decoding_config.tokens == [1, 2, 3]
    assert isinstance(request.prompt_tuning_config, trtllm.PromptTuningConfig)
    assert (request.prompt_tuning_config.embedding_table == torch.ones(
        100, 64)).all()
    assert isinstance(request.lora_config, trtllm.LoraConfig)


def test_result():
    result = trtllm.Result()
    result.is_final = True
    result.output_token_ids = [[1, 2, 3]]
    result.cum_log_probs = [1.0, 2.0, 3.0]
    result.log_probs = [[1.0, 2.0, 3.0]]
    result.context_logits = torch.ones(3, 100)
    result.generation_logits = torch.ones(1, 3, 100)

    assert result.is_final == True
    assert result.output_token_ids == [[1, 2, 3]]
    assert result.cum_log_probs == [1.0, 2.0, 3.0]
    assert result.log_probs == [[1.0, 2.0, 3.0]]
    assert (result.context_logits == torch.ones(3, 100)).all()
    assert (result.generation_logits == torch.ones(1, 3, 100)).all()


def test_response():
    request_id = 0
    error_msg = "error"
    response = trtllm.Response(request_id, error_msg)
    assert response.request_id == request_id
    assert response.has_error()
    assert response.error_msg == error_msg

    result = trtllm.Result()
    result.is_final = True
    result.output_token_ids = [[1, 2, 3]]
    request_id = 1
    response = trtllm.Response(request_id, result)
    assert response.request_id == request_id
    assert not response.has_error()
    assert response.result.is_final
    assert response.result.output_token_ids == [[1, 2, 3]]


def test_scheduler_config():
    policy = trtllm.SchedulerPolicy.MAX_UTILIZATION
    config = trtllm.SchedulerConfig(policy)
    assert config.policy == policy

    policy = trtllm.SchedulerPolicy.GUARANTEED_NO_EVICT
    config = trtllm.SchedulerConfig(policy)
    assert config.policy == policy


def test_kv_cache_config():
    config = trtllm.KvCacheConfig()
    assert config.enable_block_reuse == False
    assert config.max_tokens is None
    assert config.max_attention_window is None
    assert config.sink_token_length is None
    assert config.free_gpu_memory_fraction is None
    assert config.host_cache_size is None
    assert config.onboard_blocks == True

    kwargs = {
        "enable_block_reuse": True,
        "max_tokens": 3,
        "max_attention_window": 10,
        "sink_token_length": 2,
        "free_gpu_memory_fraction": 0.5,
        "host_cache_size": 1024,
        "onboard_blocks": False,
    }
    config = trtllm.KvCacheConfig(**kwargs)
    for k, v in kwargs.items():
        assert getattr(config, k) == v


def test_executor_config():
    config = trtllm.ExecutorConfig()
    assert config.max_beam_width == 1
    assert isinstance(config.scheduler_config, trtllm.SchedulerConfig)
    assert isinstance(config.kv_cache_config, trtllm.KvCacheConfig)
    assert config.enable_chunked_context == False
    assert config.normalize_log_probs == True
    assert config.iter_stats_max_iterations == 1000
    assert config.batching_type == trtllm.BatchingType.INFLIGHT
    assert config.parallel_config is None
    assert isinstance(config.peft_cache_config, trtllm.PeftCacheConfig)
    assert config.logits_post_processor_map is None
    assert config.medusa_choices is None

    kwargs = {
        "max_beam_width": 2,
        "scheduler_config":
        trtllm.SchedulerConfig(trtllm.SchedulerPolicy.MAX_UTILIZATION),
        "kv_cache_config": trtllm.KvCacheConfig(),
        "enable_chunked_context": True,
        "normalize_log_probs": False,
        "iter_stats_max_iterations": 100,
        "batching_type": trtllm.BatchingType.STATIC,
        "parallel_config": trtllm.ParallelConfig(),
        "peft_cache_config": trtllm.PeftCacheConfig(10),
        "logits_post_processor_map": {},
        "medusa_choices": [[1, 2, 3]],
    }
    config = trtllm.ExecutorConfig(**kwargs)
    for k, v in kwargs.items():
        if "config" not in k:
            assert getattr(config, k) == v
    assert isinstance(config.scheduler_config, trtllm.SchedulerConfig)
    assert config.scheduler_config.policy == trtllm.SchedulerPolicy.MAX_UTILIZATION
    assert isinstance(config.kv_cache_config, trtllm.KvCacheConfig)
    assert isinstance(config.parallel_config, trtllm.ParallelConfig)
    assert isinstance(config.peft_cache_config, trtllm.PeftCacheConfig)


def test_parallel_config():
    comm_type = trtllm.CommunicationType.MPI
    comm_mode = trtllm.CommunicationMode.LEADER
    device_ids = [0, 1, 2, 3]
    participant_ids = [4, 5, 6, 7]
    parallel_config = trtllm.ParallelConfig(comm_type, comm_mode, device_ids,
                                            participant_ids)
    assert parallel_config.communication_type == comm_type
    assert parallel_config.communication_mode == comm_mode
    assert parallel_config.device_ids == device_ids
    assert parallel_config.participant_ids == participant_ids


def test_peft_cache_config():
    num_host_module_layer = 1
    num_device_module_layer = 2
    optimal_adapter_size = 3
    max_adapter_size = 4
    num_put_workers = 5
    num_ensure_workers = 6
    num_copy_streams = 7
    max_pages_per_block_host = 8
    max_pages_per_block_device = 9
    device_cache_percent = 0.9
    host_cache_size = 1024
    peft_cache_config = trtllm.PeftCacheConfig(
        num_host_module_layer, num_device_module_layer, optimal_adapter_size,
        max_adapter_size, num_put_workers, num_ensure_workers,
        num_copy_streams, max_pages_per_block_host, max_pages_per_block_device,
        device_cache_percent, host_cache_size)

    assert peft_cache_config.num_host_module_layer == num_host_module_layer
    assert peft_cache_config.num_device_module_layer == num_device_module_layer
    assert peft_cache_config.optimal_adapter_size == optimal_adapter_size
    assert peft_cache_config.max_adapter_size == max_adapter_size
    assert peft_cache_config.num_put_workers == num_put_workers
    assert peft_cache_config.num_ensure_workers == num_ensure_workers
    assert peft_cache_config.num_copy_streams == num_copy_streams
    assert peft_cache_config.max_pages_per_block_host == max_pages_per_block_host
    assert peft_cache_config.max_pages_per_block_device == max_pages_per_block_device
    assert np.isclose(peft_cache_config.device_cache_percent,
                      device_cache_percent)
    assert peft_cache_config.host_cache_size == host_cache_size


@skip_pre_ampere  # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture
def test_logits_post_processor(model_files, model_path):

    # Define the logits post-processor callback
    def logits_post_processor(req_id: int, logits: torch.Tensor,
                              ids: tp.List[tp.List[int]], stream_ptr: int):
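        # Work on the executor's CUDA stream and force token 42 by setting
        # every other logit to -inf.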
        with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)):
            logits[:] = float("-inf")
            logits[..., 42] = 0

    # Create executor
    beam_width = 1
    executor_config = trtllm.ExecutorConfig(beam_width)
    executor_config.logits_post_processor_map = {
        "my_logits_pp": logits_post_processor
    }
    executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    # Create the request
    max_new_tokens = 5
    input_tokens = [1, 2, 3, 4]
    request = trtllm.Request(input_tokens, max_new_tokens, False)
    request.logits_post_processor_name = "my_logits_pp"

    # Enqueue the request
    request_id = executor.enqueue_request(request)

    # Get the new tokens
    tokens = []
    done = False
    i = 0
    max_wait_ms = 10000
    while not done and i < max_wait_ms:
        wait_time = datetime.timedelta(milliseconds=1)
        responses = executor.await_responses(request_id, wait_time)
        for response in responses:
            assert not response.has_error(
            ), f"Request id {request_id} failed with err {response.error_msg}"
            result = response.result
            done = result.is_final
            new_tokens = result.output_token_ids[beam_width - 1]
            tokens.extend(new_tokens)
        i += 1
    assert i < max_wait_ms

    assert len(tokens) == get_expected_num_tokens(len(input_tokens),
                                                  max_new_tokens, False,
                                                  False), f"{request_id}"

    # Check that all output tokens are 42
    print(tokens)
    assert tokens[-max_new_tokens:] == [42] * max_new_tokens