# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import sys
from dataclasses import dataclass
from typing import Dict, List, Union
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
import torch
# Mock out triton_python_backend_utils (pb_utils) before importing the model;
# the real module is only available inside the Triton server.
sys.modules["triton_python_backend_utils"] = MagicMock()
# Requires PYTHONPATH=../inflight_batcher_llm/tensorrt_llm/1/ so that "model" below resolves to the backend's model.py.
from model import *
import tensorrt_llm.bindings.executor as trtllm
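# Lightweight stand-ins for the pb_utils objects (Tensor, InferenceResponse,
# TritonError, InferenceRequest) so request/response conversion can be tested
# without a running Triton server.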
@dataclass
class MockTritonTensor:
_name: str
_tensor: Union[np.ndarray, torch.Tensor]
def name(self) -> str:
return self._name
def as_numpy(self) -> np.ndarray:
if self.is_cpu():
return self._tensor
else:
return self._tensor.as_numpy()
def is_cpu(self) -> bool:
if isinstance(self._tensor, np.ndarray):
return True
else:
return False
@dataclass
class MockTritonError:
message: str
@dataclass
class MockTritonResponse:
tensors: Dict[str, MockTritonTensor]
error: MockTritonError
def __init__(self,
output_tensors: List[MockTritonTensor],
error: MockTritonError = None):
self.tensors = {}
for tensor in output_tensors:
self.tensors[tensor.name()] = tensor
self.error = error
def output_tensors(self):
return self.tensors.values()
def has_error(self):
return self.error is not None
@dataclass
class MockTritonRequest:
tensors: Dict[str, MockTritonTensor]
def get_input_tensor_by_name(self, name: str) -> MockTritonTensor:
return self.tensors[name] if name in self.tensors else None
def get_response_sender(self):
return None
def mock_pb_utils_get_input_tensor_by_name_side_effect(
request: MockTritonRequest, name: str) -> MockTritonTensor:
return request.get_input_tensor_by_name(name)
def make_mock_triton_request(
tensors: Dict[str, np.ndarray]) -> MockTritonRequest:
return MockTritonRequest({
k: MockTritonTensor(k, np.array(v))
for k, v in tensors.items()
})
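# Autouse fixture: patch the pb_utils symbols used by model.py with the mocks
# defined above for every test in this module.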
@pytest.fixture(autouse=True)
def apply_patches():
patch("model.pb_utils.Tensor", new=MockTritonTensor).start()
patch("model.pb_utils.InferenceResponse", new=MockTritonResponse).start()
patch("model.pb_utils.TritonError", new=MockTritonError).start()
patch("model.pb_utils.InferenceRequest", new=MockTritonRequest).start()
patch("model.pb_utils.get_input_tensor_by_name",
new=mock_pb_utils_get_input_tensor_by_name_side_effect).start()
patch("model.pb_utils.TritonModelException", new=Exception).start()
@pytest.fixture
def triton_request() -> MockTritonRequest:
inputs = {
"input_ids": [[28524, 287, 5093, 12]],
"request_output_len": [16],
"streaming": [True],
"end_id": [50256],
"pad_id": [50256],
"stop_words_list": [[[14480, 326, 262, 1171], [1, 4, -1, -1]]],
"bad_words_list": [[[24044, 76, 1230], [2, 3, -1]]],
"embedding_bias":
np.array([[0., 0., 0.]], dtype=np.float32),
"beam_width": [2],
"runtime_top_k": [1],
"runtime_top_p": [0.],
"seed": [4],
"temperature": [1.],
"min_tokens": [3],
"repetition_penalty": [1.0],
"presence_penalty": [2.0],
"frequency_penalty": [4.0],
"len_penalty": [8.0],
"runtime_top_p_min": [1.0],
"runtime_top_p_reset_ids": [1],
"runtime_top_p_decay": [1.0],
"beam_search_diversity_rate": [1.0],
"early_stopping": [True],
"return_log_probs":
True,
"return_context_logits":
True,
"return_generation_logits":
True,
"draft_input_ids": [[0, 1]],
"draft_logits":
np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=np.float32),
"draft_acceptance_threshold":
1.0,
"prompt_embedding_table":
np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=np.float16),
"lora_task_id": [1],
"lora_weights":
np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=np.float16),
"lora_config":
np.array([[[1, 2, 3], [4, 5, 6]]], dtype=np.int32),
"retention_token_range_starts":
np.array([[0, 100]], dtype=np.int32),
"retention_token_range_ends":
np.array([[100, 200]], dtype=np.int32),
"retention_token_range_priorities":
np.array([[100, 50]], dtype=np.int32),
"prompt_vocab_size": [2],
}
return make_mock_triton_request(inputs)
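# Batched variant of triton_request: two requests in one batch, each with its
# own per-request values.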
@pytest.fixture
def batched_triton_request() -> MockTritonRequest:
inputs = {
"input_ids": [[28524, 287, 5093, 12], [1, 2, 3, 4]],
"input_lengths": [4, 2],
"request_output_len": [16, 3],
"streaming": [True, False],
"end_id": [50256, 50257],
"pad_id": [50256, 50257],
"stop_words_list": [[[14480, 326, 262, 1171], [1, 4, -1, -1]],
[[66, 77, -1, -1], [1, 2, -1, -1]]],
"bad_words_list": [[[24044, 76, 1230], [2, 3, -1]],
[[88, 99, 111], [1, 3, -1]]],
"embedding_bias":
np.array([[0., 0., 0.], [1., 1., 1.]], dtype=np.float32),
"beam_width": [2, 3],
"runtime_top_k": [1, 2],
"runtime_top_p": [0., 1.],
"seed": [4, 7],
"temperature": [1., 0.5],
"min_tokens": [3, 10],
"repetition_penalty": [1.0, 1.1],
"presence_penalty": [2.0, 2.1],
"frequency_penalty": [4.0, 4.1],
"len_penalty": [8.0, 8.1],
"runtime_top_p_min": [1.0, 0.5],
"runtime_top_p_reset_ids": [1, 3],
"runtime_top_p_decay": [1.0, 0.1],
"beam_search_diversity_rate": [1.0, 0.7],
"early_stopping": [True, False],
"return_log_probs": [True, False],
"return_context_logits": [True, False],
"return_generation_logits": [True, False],
"draft_input_ids": [[0, 1], [2, 3]],
"draft_logits":
np.array([[[1.0, 2.0], [3.0, 4.0]], [[1.1, 2.1], [3.1, 4.1]]],
dtype=np.float32),
"draft_acceptance_threshold": [1.0, 0.5],
"prompt_embedding_table":
np.array([[[1.0, 2.0], [3.0, 4.0]], [[2.0, 3.0], [4.0, 5.0]]],
dtype=np.float16),
"lora_task_id": [1, 2],
"lora_weights":
np.array([[[1.0, 2.0], [3.0, 4.0]], [[3.0, 4.0], [5.0, 6.0]]],
dtype=np.float16),
"lora_config":
np.array([[[1, 2, 3], [4, 5, 6]], [[11, 12, 13], [14, 15, 16]]],
dtype=np.int32),
"retention_token_range_starts":
np.array([[0, 100], [0, 200]], dtype=np.int32),
"retention_token_range_ends":
np.array([[100, 200], [200, 300]], dtype=np.int32),
"retention_token_range_priorities":
np.array([[100, 50], [0, 0]], dtype=np.int32),
"prompt_vocab_size": [2],
}
return make_mock_triton_request(inputs)
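# Smallest valid request: only input_ids and request_output_len are set.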
@pytest.fixture
def triton_request_minimal() -> MockTritonRequest:
inputs = {
"input_ids": [[28524, 287, 5093, 12]],
"request_output_len": [[16]],
}
return MockTritonRequest({
k: MockTritonTensor(k, np.array(v))
for k, v in inputs.items()
})
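# Executor responses for the convert_response tests: a complete final result,
# a minimal non-final result, and an error response.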
@pytest.fixture
def trtllm_response() -> trtllm.Response:
result = trtllm.Result()
result.is_final = True
result.output_token_ids = [[1, 2, 3]]
result.cum_log_probs = [1]
result.log_probs = [[1, 3]]
result.context_logits = torch.ones(3, 10)
result.generation_logits = torch.ones(1, 5, 10)
return trtllm.Response(0, result)
@pytest.fixture
def trtllm_response_minimal() -> trtllm.Response:
result = trtllm.Result()
result.is_final = False
result.output_token_ids = [[1, 2, 3]]
return trtllm.Response(0, result)
@pytest.fixture
def trtllm_response_error() -> trtllm.Response:
return trtllm.Response(0, "internal error")
def test_get_input_tensor_by_name(triton_request: MockTritonRequest):
assert (get_input_tensor_by_name(triton_request, "input_ids") == np.array(
[[28524, 287, 5093, 12]])).all()
assert get_input_tensor_by_name(triton_request, "no_value") is None
def test_get_input_scalar_by_name(triton_request: MockTritonRequest):
assert get_input_scalar_by_name(triton_request, "request_output_len") == 16
assert get_input_scalar_by_name(triton_request, "streaming") == True
assert get_input_scalar_by_name(triton_request, "end_id") == 50256
assert get_input_scalar_by_name(triton_request, "pad_id") == 50256
assert get_input_scalar_by_name(triton_request, "beam_width") == 2
assert get_input_scalar_by_name(triton_request, "runtime_top_k") == 1
assert get_input_scalar_by_name(triton_request, "runtime_top_p") == 0.
assert get_input_scalar_by_name(triton_request, "temperature") == 1.
def test_read_parameter_as_type():
assert read_parameter_as_type("", "name") is None
assert read_parameter_as_type("", "name", int) is None
assert read_parameter_as_type("", "name", float) is None
assert read_parameter_as_type("", "name", bool) is None
assert read_parameter_as_type("${unfilled_parameter}", "name") is None
assert read_parameter_as_type("foo", "name", int) is None
assert read_parameter_as_type("string_value", "name") == "string_value"
assert read_parameter_as_type("4", "name", int) == 4
assert read_parameter_as_type("0.5", "name", float) == 0.5
assert read_parameter_as_type("1", "name", bool) == True
assert read_parameter_as_type("true", "name", bool) == True
assert read_parameter_as_type("True", "name", bool) == True
assert read_parameter_as_type("0", "name", bool) == False
assert read_parameter_as_type("false", "name", bool) == False
assert read_parameter_as_type("False", "name", bool) == False
def test_get_parameter():
model_config = {"parameters": {"max_beam_width": {"string_value": "1"}}}
assert get_parameter(model_config, "max_beam_width", int) == 1
assert get_parameter(model_config, "gpt_model_type", str) is None
def test_convert_word_list():
assert convert_word_list(None) is None
assert convert_word_list(np.array([[[], []]])) == []
assert convert_word_list(
np.array([[[14480, 326, 262, 1171], [1, 4, -1,
-1]]])) == [[14480],
[326, 262, 1171]]
assert convert_word_list(np.array([[[24044, 76, 1230],
[2, 3, -1]]])) == [[24044, 76], [1230]]
assert convert_word_list(np.array([[[326, 262, 1230],
[3, -1, -1]]])) == [[326, 262, 1230]]
for bad_format in [
np.array([]),
np.array([[]]),
np.array([[[]]]),
np.array([[[1], [2], [3]]]),
np.array([[[262], [5]]]),
]:
with pytest.raises(Exception, match="Invalid format for word list"):
convert_word_list(bad_format)
def test_parse_medusa_choices():
assert parse_medusa_choices("{0, 0, 0}, {0, 1}") == [[0, 0, 0], [0, 1]]
for bad_format in [
"{{}",
"{",
"{{}",
"}",
"{0, 1, 2",
"0, 1, 2",
"{0, 1, 2}, {\"foo\"}",
]:
with pytest.raises(Exception,
match="Invalid format for medusa_choices"):
parse_medusa_choices(bad_format)
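# Shared assertions for the fully-populated request (triton_request, and the
# first entry of batched_triton_request) after conversion to a trtllm.Request.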
def check_converted_request(converted):
assert isinstance(converted, trtllm.Request)
assert converted.input_token_ids == [28524, 287, 5093, 12]
assert converted.max_tokens == 16
assert converted.streaming == True
assert converted.end_id == 50256
assert converted.pad_id == 50256
assert converted.stop_words == [[14480], [326, 262, 1171]]
assert converted.bad_words == [[24044, 76], [1230]]
assert (converted.embedding_bias == torch.tensor([0., 0., 0.])).all()
assert converted.logits_post_processor_name is None
assert isinstance(converted.external_draft_tokens_config,
trtllm.ExternalDraftTokensConfig)
assert converted.external_draft_tokens_config.tokens == [0, 1]
assert (converted.external_draft_tokens_config.logits == torch.tensor(
[[1.0, 2.0], [3.0, 4.0]])).all()
assert converted.external_draft_tokens_config.acceptance_threshold == 1.0
assert isinstance(converted.prompt_tuning_config, trtllm.PromptTuningConfig)
assert (converted.prompt_tuning_config.embedding_table == torch.tensor(
[[1.0, 2.0], [3.0, 4.0]])).all()
assert isinstance(converted.lora_config, trtllm.LoraConfig)
assert converted.lora_config.task_id == 1
assert (converted.lora_config.weights == torch.tensor([[1.0, 2.0],
[3.0, 4.0]])).all()
assert (converted.lora_config.config == torch.tensor([[1, 2, 3],
[4, 5, 6]])).all()
assert isinstance(converted.kv_cache_retention_config,
trtllm.KvCacheRetentionConfig)
assert len(
converted.kv_cache_retention_config.token_range_retention_configs)
assert converted.kv_cache_retention_config.token_range_retention_configs[
0].token_start == 0
assert converted.kv_cache_retention_config.token_range_retention_configs[
0].token_end == 100
assert converted.kv_cache_retention_config.token_range_retention_configs[
0].priority == 100
assert converted.sampling_config.beam_width == 2
assert converted.sampling_config.top_k == 1
assert converted.sampling_config.top_p is None
assert converted.sampling_config.top_p_min == 1.0
assert converted.sampling_config.top_p_reset_ids == 1
assert converted.sampling_config.top_p_decay == 1.0
assert converted.sampling_config.seed == 4
assert converted.sampling_config.temperature == 1.0
assert converted.sampling_config.min_tokens == 3
assert converted.sampling_config.beam_search_diversity_rate == 1.0
assert converted.sampling_config.repetition_penalty == 1.0
assert converted.sampling_config.presence_penalty == 2.0
assert converted.sampling_config.frequency_penalty == 4.0
assert converted.sampling_config.length_penalty == 8.0
assert converted.sampling_config.early_stopping == True
assert converted.output_config.return_log_probs == True
assert converted.output_config.return_context_logits == True
assert converted.output_config.return_generation_logits == True
assert converted.output_config.exclude_input_from_output == True
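# A batched request should be split into one trtllm.Request per batch entry,
# each carrying its own per-request values.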
def test_convert_batched_request(batched_triton_request: MockTritonRequest):
converted_reqs = convert_request(batched_triton_request,
exclude_input_from_output=True,
decoupled=True)
assert len(converted_reqs) == 2
converted0 = converted_reqs[0]
check_converted_request(converted0)
converted = converted_reqs[1]
assert isinstance(converted, trtllm.Request)
assert converted.input_token_ids == [1, 2]
assert converted.max_tokens == 3
assert converted.streaming == False
assert converted.end_id == 50257
assert converted.pad_id == 50257
assert converted.stop_words == [[66], [77]]
assert converted.bad_words == [[88], [99, 111]]
assert (converted.embedding_bias == torch.tensor([1., 1., 1.])).all()
assert converted.logits_post_processor_name is None
assert isinstance(converted.external_draft_tokens_config,
trtllm.ExternalDraftTokensConfig)
assert converted.external_draft_tokens_config.tokens == [2, 3]
assert (converted.external_draft_tokens_config.logits == torch.tensor(
[[1.1, 2.1], [3.1, 4.1]])).all()
assert converted.external_draft_tokens_config.acceptance_threshold == 0.5
assert isinstance(converted.prompt_tuning_config, trtllm.PromptTuningConfig)
print(converted.prompt_tuning_config.embedding_table)
assert (converted.prompt_tuning_config.embedding_table == torch.tensor(
[[2.0, 3.0], [4.0, 5.0]])).all()
assert isinstance(converted.lora_config, trtllm.LoraConfig)
assert converted.lora_config.task_id == 2
assert (converted.lora_config.weights == torch.tensor([[3.0, 4.0],
[5.0, 6.0]])).all()
assert (converted.lora_config.config == torch.tensor([[11, 12, 13],
[14, 15, 16]])).all()
assert converted.sampling_config.beam_width == 3
assert converted.sampling_config.top_k == 2
assert converted.sampling_config.top_p == 1.
assert converted.sampling_config.top_p_min == 0.5
assert converted.sampling_config.top_p_reset_ids == 3
assert converted.sampling_config.top_p_decay == pytest.approx(0.1)
assert converted.sampling_config.seed == 7
assert converted.sampling_config.temperature == 0.5
assert converted.sampling_config.min_tokens == 10
assert converted.sampling_config.beam_search_diversity_rate == pytest.approx(
0.7)
assert converted.sampling_config.repetition_penalty == pytest.approx(1.1)
assert converted.sampling_config.presence_penalty == pytest.approx(2.1)
assert converted.sampling_config.frequency_penalty == pytest.approx(4.1)
assert converted.sampling_config.length_penalty == pytest.approx(8.1)
assert converted.sampling_config.early_stopping == False
assert converted.output_config.return_log_probs == False
assert converted.output_config.return_context_logits == False
assert converted.output_config.return_generation_logits == False
assert converted.output_config.exclude_input_from_output == True
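# A single (non-batched) request converts to exactly one trtllm.Request.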
def test_convert_request(triton_request: MockTritonRequest):
converted_reqs = convert_request(triton_request,
exclude_input_from_output=True,
decoupled=True)
assert len(converted_reqs) == 1
converted = converted_reqs[0]
check_converted_request(converted)
def test_convert_request_minimal(triton_request_minimal: MockTritonRequest):
converted_reqs = convert_request(triton_request_minimal,
exclude_input_from_output=False,
decoupled=False)
assert len(converted_reqs) == 1
converted = converted_reqs[0]
assert converted.input_token_ids == [28524, 287, 5093, 12]
assert converted.max_tokens == 16
assert converted.streaming == False
assert converted.end_id is None
assert converted.pad_id is None
assert converted.stop_words is None
assert converted.bad_words is None
assert converted.embedding_bias is None
assert converted.logits_post_processor_name is None
assert converted.external_draft_tokens_config is None
assert converted.prompt_tuning_config is None
assert converted.lora_config is None
assert converted.kv_cache_retention_config is None
assert converted.sampling_config.beam_width == 1
assert converted.sampling_config.top_k is None
assert converted.sampling_config.top_p is None
assert converted.sampling_config.top_p_min is None
assert converted.sampling_config.top_p_reset_ids is None
assert converted.sampling_config.top_p_decay is None
assert converted.sampling_config.seed is None
assert converted.sampling_config.temperature is None
assert converted.sampling_config.min_tokens is None
assert converted.sampling_config.beam_search_diversity_rate is None
assert converted.sampling_config.repetition_penalty is None
assert converted.sampling_config.presence_penalty is None
assert converted.sampling_config.frequency_penalty is None
assert converted.sampling_config.length_penalty is None
assert converted.sampling_config.early_stopping is None
assert converted.output_config.return_log_probs == False
assert converted.output_config.return_context_logits == False
assert converted.output_config.return_generation_logits == False
assert converted.output_config.exclude_input_from_output == False
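# KV-cache retention inputs must line up: missing priorities or mismatched
# range lengths should raise, while consistent ones should not.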
def test_kv_cache_retention_config_invalid():
def check_retention_config(d: Dict[str, np.ndarray], is_valid: bool):
req = make_mock_triton_request(d)
if is_valid:
get_kv_cache_retention_config_from_request(req, 1, 0)
else:
with pytest.raises(RuntimeError):
get_kv_cache_retention_config_from_request(req, 1, 0)
check_retention_config(
{
"retention_token_range_starts": np.array([[0, 100]]),
"retention_token_range_ends": np.array([[100, 200]])
}, False)
check_retention_config(
{
"retention_token_range_starts": np.array([[0, 100]]),
"retention_token_range_ends": np.array([[100, 200]]),
"retention_token_range_priorities": np.array([[100]])
}, False)
check_retention_config(
{
"retention_token_range_starts": np.array([[0, 100]]),
"retention_token_range_ends": np.array([[100, 200]]),
"retention_token_range_priorities": np.array([[50, 50]])
}, True)
check_retention_config(
{
"retention_token_range_starts": np.array([[0, 100]]),
"retention_token_range_ends": np.array([[100, 200]]),
"retention_token_range_priorities": np.array([[50, 50]]),
"retention_token_range_durations_ms": np.array([[100]])
}, False)
check_retention_config(
{
"retention_token_range_starts": np.array([[0, 100]]),
"retention_token_range_ends": np.array([[100, 200]]),
"retention_token_range_priorities": np.array([[50, 50]]),
"retention_token_range_durations_ms": np.array([[100, 50]])
}, True)
check_retention_config(
{
"retention_token_range_starts": np.array([[0]]),
"retention_token_range_ends": np.array([[-1]]),
"retention_token_range_priorities": np.array([[50]]),
"retention_token_range_durations_ms": np.array([[1000]])
}, True)
# The request-level lookahead config needs to be tested together with the executor-level lookahead config.
def test_request_lookahead_config():
def check_request_lookahead_config(request_config: Dict[str, str],
executor_config, is_valid: bool):
req = make_mock_triton_request(request_config)
if is_valid:
get_lookahead_decoding_config_from_request(req,
executor_config,
batch_size=1,
batch_index=0)
else:
with pytest.raises(Exception):
get_lookahead_decoding_config_from_request(req,
executor_config,
batch_size=1,
batch_index=0)
# When request and executor lookahead_config are set correctly
check_request_lookahead_config(
{
"lookahead_window_size": np.array([[3]], dtype=np.int32),
"lookahead_ngram_size": np.array([[3]], dtype=np.int32),
"lookahead_verification_set_size": np.array([[3]], dtype=np.int32),
}, trtllm.LookaheadDecodingConfig(3, 3, 3), True)
# When request lookahead_config is not specified
check_request_lookahead_config({}, trtllm.LookaheadDecodingConfig(3, 3, 3),
True)
# When request lookahead_config is incomplete
check_request_lookahead_config(
{
"lookahead_window_size": np.array([[3]], dtype=np.int32),
}, trtllm.LookaheadDecodingConfig(3, 3, 3), False)
# When request lookahead_config is incomplete
check_request_lookahead_config(
{
"lookahead_window_size": np.array([[3]], dtype=np.int32),
"lookahead_ngram_size": np.array([[3]], dtype=np.int32),
}, trtllm.LookaheadDecodingConfig(3, 3, 3), False)
# When request lookahead_config is set while executor_lookahead_config is None
check_request_lookahead_config(
{
"lookahead_window_size": np.array([[3]], dtype=np.int32),
"lookahead_ngram_size": np.array([[3]], dtype=np.int32),
"lookahead_verification_set_size": np.array([[3]], dtype=np.int32),
}, None, False)
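# Requests missing required inputs, or streaming without decoupled mode,
# should be rejected with descriptive errors.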
def test_convert_request_invalid():
with pytest.raises(Exception, match="A value is required for input_ids"):
no_input_ids = MockTritonRequest({
"request_output_len":
MockTritonTensor("request_output_len", np.array([[128]]))
})
convert_request(no_input_ids, False, False)
with pytest.raises(Exception, match="Invalid format for input_ids"):
bad_input_ids = MockTritonRequest(
{"input_ids": MockTritonTensor("input_ids", np.array([]))})
convert_request(bad_input_ids, False, False)
with pytest.raises(Exception,
match="A value is required for request_output_len"):
no_output_len = MockTritonRequest(
{"input_ids": MockTritonTensor("input_ids", np.array([[1, 2, 3]]))})
convert_request(no_output_len, False, False)
with pytest.raises(Exception,
match="Streaming is only supported in decoupled mode."):
streaming_non_decoupled = MockTritonRequest({
"input_ids":
MockTritonTensor("input_ids", np.array([[1, 2, 3]])),
"request_output_len":
MockTritonTensor("request_output_len", np.array([[128]])),
"streaming":
MockTritonTensor("streaming", np.array([[True]])),
})
convert_request(streaming_non_decoupled, False, False)
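# convert_response should copy token ids, log probs, logits and batch_index
# into the corresponding Triton output tensors.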
def test_convert_response(trtllm_response: trtllm.Response):
batch_index = 2
batch_size = 3
num_return_sequences = 1
response, is_final, output_length = convert_response(
trtllm_response, batch_index, batch_size, num_return_sequences)
assert is_final == True
assert (response.tensors["output_ids"].as_numpy() == np.array([[1, 2,
3]])).all()
assert (response.tensors["sequence_length"].as_numpy() == np.array(
[[3]])).all()
assert (response.tensors["cum_log_probs"].as_numpy() == np.array([1])).all()
assert (response.tensors["output_log_probs"].as_numpy() == np.array(
[[1, 3]])).all()
assert (response.tensors["context_logits"].as_numpy() == np.ones(
(3, 10), dtype=np.float32)).all()
assert (response.tensors["generation_logits"].as_numpy() == np.ones(
(1, 5, 10), dtype=np.float32)).all()
assert (response.tensors["batch_index"].as_numpy() == np.array(
[[batch_index]])).all()
def test_convert_response_minimal(trtllm_response_minimal: trtllm.Response):
batch_index = 2
batch_size = 3
num_return_sequences = 1
response, is_final, output_length = convert_response(
trtllm_response_minimal, batch_index, batch_size, num_return_sequences)
assert is_final == False
assert (response.tensors["output_ids"].as_numpy() == np.array([[1, 2,
3]])).all()
assert (response.tensors["sequence_length"].as_numpy() == np.array(
[[3]])).all()
assert "cum_log_probs" not in response.tensors
assert "output_log_probs" not in response.tensors
assert "output_log_probs" not in response.tensors
assert "context_logits" not in response.tensors
assert "generation_logits" not in response.tensors
assert (response.tensors["batch_index"].as_numpy() == np.array(
[[batch_index]])).all()
def test_convert_response_error(trtllm_response_error: trtllm.Response):
batch_index = 2
batch_size = 3
num_return_sequences = 1
response, is_final, output_length = convert_response(
trtllm_response_error, batch_index, batch_size, num_return_sequences)
assert is_final == True
assert response.has_error() and response.error.message == "internal error"
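# String-to-enum conversions for the batch scheduler policy, batching type and
# decoding mode parameters.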
def test_convert_scheduler_policy():
assert convert_scheduler_policy(
"max_utilization") == trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION
assert convert_scheduler_policy(
"guaranteed_no_evict"
) == trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
with pytest.raises(
Exception,
match="batch_scheduler_policy value of 'other' is not supported"):
convert_scheduler_policy("other")
def test_convert_batching_type():
assert convert_batching_type(
"inflight_fused_batching") == trtllm.BatchingType.INFLIGHT
assert convert_batching_type(
"inflight_batching") == trtllm.BatchingType.INFLIGHT
assert convert_batching_type("v1") == trtllm.BatchingType.STATIC
with pytest.raises(
Exception,
match="gpt_model_type value of 'other' is not supported"):
convert_batching_type("other")
def test_convert_decoding_mode():
assert convert_decoding_mode(None) is None
assert convert_decoding_mode("auto").isAuto()
assert convert_decoding_mode("top_k").isTopK()
assert convert_decoding_mode("top_p").isTopP()
assert convert_decoding_mode("top_k_top_p").isTopKandTopP()
assert convert_decoding_mode("beam_search").isBeamSearch()
assert convert_decoding_mode("medusa").isMedusa()
assert convert_decoding_mode("redrafter").isExplicitDraftTokens()
assert convert_decoding_mode("lookahead").isLookahead()
assert convert_decoding_mode("eagle").isEagle()
with pytest.raises(Exception,
match="decoding_mode value of 'other' is not supported"):
convert_decoding_mode("other")
@pytest.fixture
def model_config() -> Dict:
config = {
"max_beam_width": "2",
"enable_chunked_context": "true",
"normalize_log_probs": "false",
"gpt_model_type": "inflight_batching",
"medusa_choices": "{1, 2, 3, 4}, {5, 6, 7}",
"decoding_mode": "medusa",
"batch_scheduler_policy": "max_utilization",
"enable_kv_cache_reuse": "false",
"max_tokens_in_paged_kv_cache": "1",
"max_attention_window_size": "2",
"sink_token_length": "3",
"kv_cache_free_gpu_mem_fraction": "0.5",
"cross_kv_cache_fraction": "0.5",
"kv_cache_host_memory_bytes": "4",
"kv_cache_onboard_blocks": "false",
"gpu_device_ids": "0,1,2,3",
"executor_worker_path": str(os.path.abspath(__file__)),
"lora_cache_optimal_adapter_size": "1",
"lora_cache_max_adapter_size": "2",
"lora_cache_gpu_memory_fraction": "0.5",
"lora_cache_host_memory_bytes": "4",
"lora_prefetch_dir": "",
"enable_context_fmha_fp32_acc": "true"
}
return {"parameters": {k: {"string_value": v} for k, v in config.items()}}
def test_get_executor_config(model_config: Dict):
os.environ["TRTLLM_ORCHESTRATOR"] = "0"
config = TritonPythonModel().get_executor_config(model_config)
assert config.max_beam_width == 2
assert config.enable_chunked_context == True
assert config.normalize_log_probs == False
assert config.batching_type == trtllm.BatchingType.INFLIGHT
assert config.decoding_config.medusa_choices == [[1, 2, 3, 4], [5, 6, 7]]
assert config.decoding_config.decoding_mode.isMedusa()
assert config.scheduler_config.capacity_scheduler_policy == trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION
assert config.kv_cache_config.enable_block_reuse == False
assert config.kv_cache_config.max_tokens == 1
assert config.kv_cache_config.max_attention_window == [2]
assert config.kv_cache_config.sink_token_length == 3
assert config.kv_cache_config.free_gpu_memory_fraction == 0.5
assert config.kv_cache_config.cross_kv_cache_fraction == 0.5
assert config.kv_cache_config.host_cache_size == 4
assert config.kv_cache_config.onboard_blocks == False
assert config.parallel_config.device_ids == [0, 1, 2, 3]
assert config.parallel_config.orchestrator_config is None
assert config.peft_cache_config.optimal_adapter_size == 1
assert config.peft_cache_config.max_adapter_size == 2
assert config.peft_cache_config.device_cache_percent == 0.5
assert config.peft_cache_config.host_cache_size == 4
assert config.iter_stats_max_iterations == 1000
assert config.request_stats_max_iterations == 0
assert config.logits_post_processor_config is None
assert config.extended_runtime_perf_knob_config.enable_context_fmha_fp32_acc == True
assert config.extended_runtime_perf_knob_config.multi_block_mode == True
del os.environ["TRTLLM_ORCHESTRATOR"]
def test_get_executor_config_orchestrator_mode(model_config: Dict):
os.environ["TRTLLM_ORCHESTRATOR"] = "1"
config = TritonPythonModel().get_executor_config(model_config)
assert config.parallel_config.device_ids == [0, 1, 2, 3]
assert config.parallel_config.orchestrator_config.is_orchestrator == True
assert config.parallel_config.orchestrator_config.worker_executable_path == str(
os.path.abspath(__file__))
del os.environ["TRTLLM_ORCHESTRATOR"]
def test_get_executor_config_minimal():
if "TRTLLM_ORCHESTRATOR" in os.environ:
del os.environ["TRTLLM_ORCHESTRATOR"]
config = TritonPythonModel().get_executor_config({"parameters": {}})
assert config.max_beam_width == 1
assert config.enable_chunked_context == False
assert config.normalize_log_probs == True
assert config.batching_type == trtllm.BatchingType.INFLIGHT
assert config.decoding_config.decoding_mode is None
assert config.decoding_config.medusa_choices is None
assert config.decoding_config.eagle_config is None
assert config.decoding_config.lookahead_decoding_config is None
assert config.scheduler_config.capacity_scheduler_policy == trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
assert config.kv_cache_config.enable_block_reuse == True
assert config.kv_cache_config.max_tokens is None
assert config.kv_cache_config.max_attention_window is None
assert config.kv_cache_config.sink_token_length is None
assert config.kv_cache_config.free_gpu_memory_fraction is None
assert config.kv_cache_config.cross_kv_cache_fraction is None
assert config.kv_cache_config.host_cache_size is None
assert config.kv_cache_config.onboard_blocks == True
assert config.parallel_config is None
assert config.peft_cache_config.optimal_adapter_size == 8
assert config.peft_cache_config.max_adapter_size == 64
assert config.peft_cache_config.device_cache_percent is None
assert config.peft_cache_config.host_cache_size is None
assert config.iter_stats_max_iterations == 1000
assert config.request_stats_max_iterations == 0
assert config.logits_post_processor_config is None
assert config.extended_runtime_perf_knob_config.enable_context_fmha_fp32_acc == False
assert config.extended_runtime_perf_knob_config.multi_block_mode == True
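# Timestamps in "MM-DD-YYYY HH:MM:SS.ffffff" form are converted to Unix epoch
# seconds (the expected values here assume UTC).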
def test_convert_timestamp_to_seconds():
assert convert_timestamp_to_seconds("01-01-1970 00:00:00.000000") == 0
assert convert_timestamp_to_seconds(
"05-17-2024 23:28:39.000000") == 1715988519