# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import sys
from dataclasses import dataclass
from typing import Dict, List, Union
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
import torch

# Mock pb_utils
sys.modules["triton_python_backend_utils"] = MagicMock()

# Use PYTHONPATH=../inflight_batcher_llm/tensorrt_llm/1/
from model import *

import tensorrt_llm.bindings.executor as trtllm


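# Minimal stand-ins for the Triton Python backend (pb_utils) objects that
# model.py manipulates: tensors, errors, inference responses and requests.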
@dataclass
class MockTritonTensor:
    _name: str
    _tensor: Union[np.ndarray, torch.Tensor]

    def name(self) -> str:
        return self._name

    def as_numpy(self) -> np.ndarray:
        if self.is_cpu():
            return self._tensor
        else:
            return self._tensor.as_numpy()

    def is_cpu(self) -> bool:
        if isinstance(self._tensor, np.ndarray):
            return True
        else:
            return False

    def to_dlpack(self):
        if self.is_cpu():
            return self._tensor.__dlpack__()
        else:
            return self._tensor.to_dlpack()


@dataclass
class MockTritonError:
    message: str


@dataclass
class MockTritonResponse:
    tensors: Dict[str, MockTritonTensor]
    error: MockTritonError

    def __init__(self,
                 output_tensors: List[MockTritonTensor],
                 error: MockTritonError = None):
        self.tensors = {}
        for tensor in output_tensors:
            self.tensors[tensor.name()] = tensor
        self.error = error

    def output_tensors(self):
        return self.tensors.values()

    def has_error(self):
        return self.error is not None


@dataclass
class MockTritonRequest:
    tensors: Dict[str, MockTritonTensor]

    def get_input_tensor_by_name(self, name: str) -> MockTritonTensor:
        return self.tensors[name] if name in self.tensors else None

    def get_response_sender(self):
        return None


def mock_pb_utils_get_input_tensor_by_name_side_effect(
        request: MockTritonRequest, name: str) -> MockTritonTensor:
    return request.get_input_tensor_by_name(name)


def make_mock_triton_request(
        tensors: Dict[str, np.ndarray]) -> MockTritonRequest:
    return MockTritonRequest({
        k: MockTritonTensor(k, np.array(v))
        for k, v in tensors.items()
    })


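# Autouse fixture: route the pb_utils symbols used by model.py to the mocks
# above for every test.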
@pytest.fixture(autouse=True)
def apply_patches():
    patch("model.pb_utils.Tensor", new=MockTritonTensor).start()
    patch("model.pb_utils.InferenceResponse", new=MockTritonResponse).start()
    patch("model.pb_utils.TritonError", new=MockTritonError).start()
    patch("model.pb_utils.InferenceRequest", new=MockTritonRequest).start()
    patch("model.pb_utils.get_input_tensor_by_name",
          new=mock_pb_utils_get_input_tensor_by_name_side_effect).start()
    patch("model.pb_utils.TritonModelException", new=Exception).start()


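# A fully-populated, streaming, single-entry request covering every optional
# input tensor.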
@pytest.fixture
def triton_request() -> MockTritonRequest:
    inputs = {
        "input_ids": [[28524, 287, 5093, 12]],
        "request_output_len": [16],
        "streaming": [True],
        "end_id": [50256],
        "pad_id": [50256],
        "stop_words_list": [[[14480, 326, 262, 1171], [1, 4, -1, -1]]],
        "bad_words_list": [[[24044, 76, 1230], [2, 3, -1]]],
        "embedding_bias": np.array([[0., 0., 0.]], dtype=np.float32),
        "beam_width": [2],
        "runtime_top_k": [1],
        "runtime_top_p": [0.],
        "seed": [4],
        "temperature": [1.],
        "min_tokens": [3],
        "repetition_penalty": [1.0],
        "presence_penalty": [2.0],
        "frequency_penalty": [4.0],
        "len_penalty": [8.0],
        "runtime_top_p_min": [1.0],
        "runtime_top_p_reset_ids": [1],
        "runtime_top_p_decay": [1.0],
        "beam_search_diversity_rate": [1.0],
        "early_stopping": [True],
        "return_log_probs": True,
        "return_context_logits": True,
        "return_generation_logits": True,
        "draft_input_ids": [[0, 1]],
        "draft_logits":
        np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=np.float32),
        "draft_acceptance_threshold": 1.0,
        "prompt_embedding_table":
        np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=np.float16),
        "lora_task_id": [1],
        "lora_weights":
        np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=np.float16),
        "lora_config":
        np.array([[[1, 2, 3], [4, 5, 6]]], dtype=np.int32),
        "retention_token_range_starts":
        np.array([[0, 100]], dtype=np.int32),
        "retention_token_range_ends":
        np.array([[100, 200]], dtype=np.int32),
        "retention_token_range_priorities":
        np.array([[100, 50]], dtype=np.int32),
        "prompt_vocab_size": [2],
    }
    return make_mock_triton_request(inputs)


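# The same inputs batched: two entries, with different values per request.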
@pytest.fixture
def batched_triton_request() -> MockTritonRequest:
    inputs = {
        "input_ids": [[28524, 287, 5093, 12], [1, 2, 3, 4]],
        "input_lengths": [4, 2],
        "request_output_len": [16, 3],
        "streaming": [True, False],
        "end_id": [50256, 50257],
        "pad_id": [50256, 50257],
        "stop_words_list": [[[14480, 326, 262, 1171], [1, 4, -1, -1]],
                            [[66, 77, -1, -1], [1, 2, -1, -1]]],
        "bad_words_list": [[[24044, 76, 1230], [2, 3, -1]],
                           [[88, 99, 111], [1, 3, -1]]],
        "embedding_bias":
        np.array([[0., 0., 0.], [1., 1., 1.]], dtype=np.float32),
        "beam_width": [2, 3],
        "runtime_top_k": [1, 2],
        "runtime_top_p": [0., 1.],
        "seed": [4, 7],
        "temperature": [1., 0.5],
        "min_tokens": [3, 10],
        "repetition_penalty": [1.0, 1.1],
        "presence_penalty": [2.0, 2.1],
        "frequency_penalty": [4.0, 4.1],
        "len_penalty": [8.0, 8.1],
        "runtime_top_p_min": [1.0, 0.5],
        "runtime_top_p_reset_ids": [1, 3],
        "runtime_top_p_decay": [1.0, 0.1],
        "beam_search_diversity_rate": [1.0, 0.7],
        "early_stopping": [True, False],
        "return_log_probs": [True, False],
        "return_context_logits": [True, False],
        "return_generation_logits": [True, False],
        "draft_input_ids": [[0, 1], [2, 3]],
        "draft_logits":
        np.array([[[1.0, 2.0], [3.0, 4.0]], [[1.1, 2.1], [3.1, 4.1]]],
                 dtype=np.float32),
        "draft_acceptance_threshold": [1.0, 0.5],
        "prompt_embedding_table":
        np.array([[[1.0, 2.0], [3.0, 4.0]], [[2.0, 3.0], [4.0, 5.0]]],
                 dtype=np.float16),
        "lora_task_id": [1, 2],
        "lora_weights":
        np.array([[[1.0, 2.0], [3.0, 4.0]], [[3.0, 4.0], [5.0, 6.0]]],
                 dtype=np.float16),
        "lora_config":
        np.array([[[1, 2, 3], [4, 5, 6]], [[11, 12, 13], [14, 15, 16]]],
                 dtype=np.int32),
        "retention_token_range_starts":
        np.array([[0, 100], [0, 200]], dtype=np.int32),
        "retention_token_range_ends":
        np.array([[100, 200], [200, 300]], dtype=np.int32),
        "retention_token_range_priorities":
        np.array([[100, 50], [0, 0]], dtype=np.int32),
        "prompt_vocab_size": [2],
    }
    return make_mock_triton_request(inputs)


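# Only the two required inputs; everything else falls back to defaults.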
@pytest.fixture
def triton_request_minimal() -> MockTritonRequest:
    inputs = {
        "input_ids": [[28524, 287, 5093, 12]],
        "request_output_len": [[16]],
    }
    return MockTritonRequest({
        k: MockTritonTensor(k, np.array(v))
        for k, v in inputs.items()
    })


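# Executor responses for the convert_response tests: a full final result, a
# minimal non-final one, and an error.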
@pytest.fixture
def trtllm_response() -> trtllm.Response:
    result = trtllm.Result()
    result.is_final = True
    result.output_token_ids = [[1, 2, 3]]
    result.cum_log_probs = [1]
    result.log_probs = [[1, 3]]
    result.context_logits = torch.ones(3, 10)
    result.generation_logits = torch.ones(1, 5, 10)
    return trtllm.Response(0, result)


@pytest.fixture
def trtllm_response_minimal() -> trtllm.Response:
    result = trtllm.Result()
    result.is_final = False
    result.output_token_ids = [[1, 2, 3]]
    return trtllm.Response(0, result)


@pytest.fixture
def trtllm_response_error() -> trtllm.Response:
    return trtllm.Response(0, "internal error")


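# Tests for the small parsing helpers: request tensors and scalars, config
# parameters, stop/bad word lists, and medusa choices.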
def test_get_input_tensor_by_name(triton_request: MockTritonRequest):
    assert (get_input_tensor_by_name(triton_request, "input_ids") == np.array(
        [[28524, 287, 5093, 12]])).all()
    assert get_input_tensor_by_name(triton_request, "no_value") is None


def test_get_input_scalar_by_name(triton_request: MockTritonRequest):
    assert get_input_scalar_by_name(triton_request, "request_output_len") == 16
    assert get_input_scalar_by_name(triton_request, "streaming") == True
    assert get_input_scalar_by_name(triton_request, "end_id") == 50256
    assert get_input_scalar_by_name(triton_request, "pad_id") == 50256
    assert get_input_scalar_by_name(triton_request, "beam_width") == 2
    assert get_input_scalar_by_name(triton_request, "runtime_top_k") == 1
    assert get_input_scalar_by_name(triton_request, "runtime_top_p") == 0.
    assert get_input_scalar_by_name(triton_request, "temperature") == 1.


def test_read_parameter_as_type():
    assert read_parameter_as_type("", "name") is None
    assert read_parameter_as_type("", "name", int) is None
    assert read_parameter_as_type("", "name", float) is None
    assert read_parameter_as_type("", "name", bool) is None
    assert read_parameter_as_type("${unfilled_parameter}", "name") is None
    assert read_parameter_as_type("foo", "name", int) is None
    assert read_parameter_as_type("string_value", "name") == "string_value"
    assert read_parameter_as_type("4", "name", int) == 4
    assert read_parameter_as_type("0.5", "name", float) == 0.5
    assert read_parameter_as_type("1", "name", bool) == True
    assert read_parameter_as_type("true", "name", bool) == True
    assert read_parameter_as_type("True", "name", bool) == True
    assert read_parameter_as_type("0", "name", bool) == False
    assert read_parameter_as_type("false", "name", bool) == False
    assert read_parameter_as_type("False", "name", bool) == False


def test_get_parameter():
    model_config = {"parameters": {"max_beam_width": {"string_value": "1"}}}
    assert get_parameter(model_config, "max_beam_width", int) == 1
    assert get_parameter(model_config, "gpt_model_type", str) is None


def test_convert_word_list():
    assert convert_word_list(None) is None
    assert convert_word_list(np.array([[[], []]])) == []
    assert convert_word_list(
        np.array([[[14480, 326, 262, 1171],
                   [1, 4, -1, -1]]])) == [[14480], [326, 262, 1171]]
    assert convert_word_list(np.array([[[24044, 76, 1230],
                                        [2, 3, -1]]])) == [[24044, 76], [1230]]
    assert convert_word_list(np.array([[[326, 262, 1230],
                                        [3, -1, -1]]])) == [[326, 262, 1230]]
    for bad_format in [
            np.array([]),
            np.array([[]]),
            np.array([[[]]]),
            np.array([[[1], [2], [3]]]),
            np.array([[[262], [5]]]),
    ]:
        with pytest.raises(Exception, match="Invalid format for word list"):
            convert_word_list(bad_format)


def test_parse_medusa_choices():
    assert parse_medusa_choices("{0, 0, 0}, {0, 1}") == [[0, 0, 0], [0, 1]]
    for bad_format in [
            "{{}",
            "{",
            "{{}",
            "}",
            "{0, 1, 2",
            "0, 1, 2",
            "{0, 1, 2}, {\"foo\"}",
    ]:
        with pytest.raises(Exception,
                           match="Invalid format for medusa_choices"):
            parse_medusa_choices(bad_format)


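# Assertions shared by test_convert_request and test_convert_batched_request
# for the fully-populated triton_request fixture.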
def check_converted_request(converted):
    assert isinstance(converted, trtllm.Request)
    assert converted.input_token_ids == [28524, 287, 5093, 12]
    assert converted.max_tokens == 16
    assert converted.streaming == True
    assert converted.end_id == 50256
    assert converted.pad_id == 50256
    assert converted.stop_words == [[14480], [326, 262, 1171]]
    assert converted.bad_words == [[24044, 76], [1230]]
    assert (converted.embedding_bias == torch.tensor([0., 0., 0.])).all()
    assert converted.logits_post_processor_name is None

    assert isinstance(converted.external_draft_tokens_config,
                      trtllm.ExternalDraftTokensConfig)
    assert converted.external_draft_tokens_config.tokens == [0, 1]
    assert (converted.external_draft_tokens_config.logits == torch.tensor(
        [[1.0, 2.0], [3.0, 4.0]])).all()
    assert converted.external_draft_tokens_config.acceptance_threshold == 1.0

    assert isinstance(converted.prompt_tuning_config,
                      trtllm.PromptTuningConfig)
    assert (converted.prompt_tuning_config.embedding_table == torch.tensor(
        [[1.0, 2.0], [3.0, 4.0]])).all()

    assert isinstance(converted.lora_config, trtllm.LoraConfig)
    assert converted.lora_config.task_id == 1
    assert (converted.lora_config.weights == torch.tensor([[1.0, 2.0],
                                                           [3.0, 4.0]])).all()
    assert (converted.lora_config.config == torch.tensor([[1, 2, 3],
                                                          [4, 5, 6]])).all()

    assert isinstance(converted.kv_cache_retention_config,
                      trtllm.KvCacheRetentionConfig)
    assert len(
        converted.kv_cache_retention_config.token_range_retention_configs)
    assert converted.kv_cache_retention_config.token_range_retention_configs[
        0].token_start == 0
    assert converted.kv_cache_retention_config.token_range_retention_configs[
        0].token_end == 100
    assert converted.kv_cache_retention_config.token_range_retention_configs[
        0].priority == 100

    assert converted.sampling_config.beam_width == 2
    assert converted.sampling_config.top_k == 1
    assert converted.sampling_config.top_p is None
    assert converted.sampling_config.top_p_min == 1.0
    assert converted.sampling_config.top_p_reset_ids == 1
    assert converted.sampling_config.top_p_decay == 1.0
    assert converted.sampling_config.seed == 4
    assert converted.sampling_config.temperature == 1.0
    assert converted.sampling_config.min_tokens == 3
    assert converted.sampling_config.beam_search_diversity_rate == 1.0
    assert converted.sampling_config.repetition_penalty == 1.0
    assert converted.sampling_config.presence_penalty == 2.0
    assert converted.sampling_config.frequency_penalty == 4.0
    assert converted.sampling_config.length_penalty == 8.0
    assert converted.sampling_config.early_stopping == True

    assert converted.output_config.return_log_probs == True
    assert converted.output_config.return_context_logits == True
    assert converted.output_config.return_generation_logits == True
    assert converted.output_config.exclude_input_from_output == True


def test_convert_batched_request(batched_triton_request: MockTritonRequest):
    converted_reqs = convert_request(batched_triton_request,
                                     exclude_input_from_output=True,
                                     decoupled=True)
    assert len(converted_reqs) == 2
    converted0 = converted_reqs[0]
    check_converted_request(converted0)

    converted = converted_reqs[1]

    assert isinstance(converted, trtllm.Request)
    assert converted.input_token_ids == [1, 2]
    assert converted.max_tokens == 3
    assert converted.streaming == False
    assert converted.end_id == 50257
    assert converted.pad_id == 50257
    assert converted.stop_words == [[66], [77]]
    assert converted.bad_words == [[88], [99, 111]]
    assert (converted.embedding_bias == torch.tensor([1., 1., 1.])).all()
    assert converted.logits_post_processor_name is None

    assert isinstance(converted.external_draft_tokens_config,
                      trtllm.ExternalDraftTokensConfig)
    assert converted.external_draft_tokens_config.tokens == [2, 3]
    assert (converted.external_draft_tokens_config.logits == torch.tensor(
        [[1.1, 2.1], [3.1, 4.1]])).all()
    assert converted.external_draft_tokens_config.acceptance_threshold == 0.5

    assert isinstance(converted.prompt_tuning_config,
                      trtllm.PromptTuningConfig)
    print(converted.prompt_tuning_config.embedding_table)
    assert (converted.prompt_tuning_config.embedding_table == torch.tensor(
        [[2.0, 3.0], [4.0, 5.0]])).all()

    assert isinstance(converted.lora_config, trtllm.LoraConfig)
    assert converted.lora_config.task_id == 2
    assert (converted.lora_config.weights == torch.tensor([[3.0, 4.0],
                                                           [5.0, 6.0]])).all()
    assert (converted.lora_config.config == torch.tensor([[11, 12, 13],
                                                          [14, 15, 16]])).all()

    assert converted.sampling_config.beam_width == 3
    assert converted.sampling_config.top_k == 2
    assert converted.sampling_config.top_p == 1.
    assert converted.sampling_config.top_p_min == 0.5
    assert converted.sampling_config.top_p_reset_ids == 3
    assert converted.sampling_config.top_p_decay == pytest.approx(0.1)
    assert converted.sampling_config.seed == 7
    assert converted.sampling_config.temperature == 0.5
    assert converted.sampling_config.min_tokens == 10
    assert converted.sampling_config.beam_search_diversity_rate == pytest.approx(
        0.7)
    assert converted.sampling_config.repetition_penalty == pytest.approx(1.1)
    assert converted.sampling_config.presence_penalty == pytest.approx(2.1)
    assert converted.sampling_config.frequency_penalty == pytest.approx(4.1)
    assert converted.sampling_config.length_penalty == pytest.approx(8.1)
    assert converted.sampling_config.early_stopping == False

    assert converted.output_config.return_log_probs == False
    assert converted.output_config.return_context_logits == False
    assert converted.output_config.return_generation_logits == False
    assert converted.output_config.exclude_input_from_output == True


def test_convert_request(triton_request: MockTritonRequest):
    converted_reqs = convert_request(triton_request,
                                     exclude_input_from_output=True,
                                     decoupled=True)
    assert len(converted_reqs) == 1
    converted = converted_reqs[0]
    check_converted_request(converted)


def test_convert_request_minimal(triton_request_minimal: MockTritonRequest):
    converted_reqs = convert_request(triton_request_minimal,
                                     exclude_input_from_output=False,
                                     decoupled=False)
    assert len(converted_reqs) == 1
    converted = converted_reqs[0]
    assert converted.input_token_ids == [28524, 287, 5093, 12]
    assert converted.max_tokens == 16
    assert converted.streaming == False
    assert converted.end_id is None
    assert converted.pad_id is None
    assert converted.stop_words is None
    assert converted.bad_words is None
    assert converted.embedding_bias is None
    assert converted.logits_post_processor_name is None
    assert converted.external_draft_tokens_config is None
    assert converted.prompt_tuning_config is None
    assert converted.lora_config is None
    assert converted.kv_cache_retention_config is None

    assert converted.sampling_config.beam_width == 1
    assert converted.sampling_config.top_k is None
    assert converted.sampling_config.top_p is None
    assert converted.sampling_config.top_p_min is None
    assert converted.sampling_config.top_p_reset_ids is None
    assert converted.sampling_config.top_p_decay is None
    assert converted.sampling_config.seed is None
    assert converted.sampling_config.temperature is None
    assert converted.sampling_config.min_tokens is None
    assert converted.sampling_config.beam_search_diversity_rate is None
    assert converted.sampling_config.repetition_penalty is None
    assert converted.sampling_config.presence_penalty is None
    assert converted.sampling_config.frequency_penalty is None
    assert converted.sampling_config.length_penalty is None
    assert converted.sampling_config.early_stopping is None

    assert converted.output_config.return_log_probs == False
    assert converted.output_config.return_context_logits == False
    assert converted.output_config.return_generation_logits == False
    assert converted.output_config.exclude_input_from_output == False


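# Priorities are required and, together with the optional durations, must
# match the number of retention token ranges; mismatches raise RuntimeError.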
def test_kv_cache_retention_config_invalid():

    def check_retention_config(d: Dict[str, np.ndarray], is_valid: bool):
        req = make_mock_triton_request(d)
        if is_valid:
            get_kv_cache_retention_config_from_request(req, 1, 0)
        else:
            with pytest.raises(RuntimeError):
                get_kv_cache_retention_config_from_request(req, 1, 0)

    check_retention_config(
        {
            "retention_token_range_starts": np.array([[0, 100]]),
            "retention_token_range_ends": np.array([[100, 200]])
        }, False)

    check_retention_config(
        {
            "retention_token_range_starts": np.array([[0, 100]]),
            "retention_token_range_ends": np.array([[100, 200]]),
            "retention_token_range_priorities": np.array([[100]])
        }, False)

    check_retention_config(
        {
            "retention_token_range_starts": np.array([[0, 100]]),
            "retention_token_range_ends": np.array([[100, 200]]),
            "retention_token_range_priorities": np.array([[50, 50]])
        }, True)

    check_retention_config(
        {
            "retention_token_range_starts": np.array([[0, 100]]),
            "retention_token_range_ends": np.array([[100, 200]]),
            "retention_token_range_priorities": np.array([[50, 50]]),
            "retention_token_range_durations_ms": np.array([[100]])
        }, False)

    check_retention_config(
        {
            "retention_token_range_starts": np.array([[0, 100]]),
            "retention_token_range_ends": np.array([[100, 200]]),
            "retention_token_range_priorities": np.array([[50, 50]]),
            "retention_token_range_durations_ms": np.array([[100, 50]])
        }, True)

    check_retention_config(
        {
            "retention_token_range_starts": np.array([[0]]),
            "retention_token_range_ends": np.array([[-1]]),
            "retention_token_range_priorities": np.array([[50]]),
            "retention_token_range_durations_ms": np.array([[1000]])
        }, True)


# Lookahead decoding: a per-request config is only accepted when it is
# complete and an executor-level lookahead config is also set.
def test_request_lookahead_config():

    def check_request_lookahead_config(request_config: Dict[str, np.ndarray],
                                       executor_config, is_valid: bool):
        req = make_mock_triton_request(request_config)

        if is_valid:
            get_lookahead_decoding_config_from_request(req,
                                                       executor_config,
                                                       batch_size=1,
                                                       batch_index=0)
        else:
            with pytest.raises(Exception):
                get_lookahead_decoding_config_from_request(req,
                                                           executor_config,
                                                           batch_size=1,
                                                           batch_index=0)

    # When request and executor lookahead_config are set correctly
    check_request_lookahead_config(
        {
            "lookahead_window_size": np.array([[3]], dtype=np.int32),
            "lookahead_ngram_size": np.array([[3]], dtype=np.int32),
            "lookahead_verification_set_size": np.array([[3]], dtype=np.int32),
        }, trtllm.LookaheadDecodingConfig(3, 3, 3), True)

    # When the request lookahead_config is not specified
    check_request_lookahead_config({}, trtllm.LookaheadDecodingConfig(3, 3, 3),
                                   True)

    # When the request lookahead_config is incomplete
    check_request_lookahead_config(
        {
            "lookahead_window_size": np.array([[3]], dtype=np.int32),
        }, trtllm.LookaheadDecodingConfig(3, 3, 3), False)

    # When the request lookahead_config is incomplete
    check_request_lookahead_config(
        {
            "lookahead_window_size": np.array([[3]], dtype=np.int32),
            "lookahead_ngram_size": np.array([[3]], dtype=np.int32),
        }, trtllm.LookaheadDecodingConfig(3, 3, 3), False)

    # When the request lookahead_config is set while the executor lookahead
    # config is None
    check_request_lookahead_config(
        {
            "lookahead_window_size": np.array([[3]], dtype=np.int32),
            "lookahead_ngram_size": np.array([[3]], dtype=np.int32),
            "lookahead_verification_set_size": np.array([[3]], dtype=np.int32),
        }, None, False)


def test_convert_request_invalid():
    with pytest.raises(Exception, match="A value is required for input_ids"):
        no_input_ids = MockTritonRequest({
            "request_output_len":
            MockTritonTensor("request_output_len", np.array([[128]]))
        })
        convert_request(no_input_ids, False, False)
    with pytest.raises(Exception, match="Invalid format for input_ids"):
        bad_input_ids = MockTritonRequest(
            {"input_ids": MockTritonTensor("input_ids", np.array([]))})
        convert_request(bad_input_ids, False, False)
    with pytest.raises(Exception,
                       match="A value is required for request_output_len"):
        no_output_len = MockTritonRequest(
            {"input_ids": MockTritonTensor("input_ids", np.array([[1, 2,
                                                                    3]]))})
        convert_request(no_output_len, False, False)
    with pytest.raises(Exception,
                       match="Streaming is only supported in decoupled mode."):
        streaming_non_decoupled = MockTritonRequest({
            "input_ids":
            MockTritonTensor("input_ids", np.array([[1, 2, 3]])),
            "request_output_len":
            MockTritonTensor("request_output_len", np.array([[128]])),
            "streaming":
            MockTritonTensor("streaming", np.array([[True]])),
        })
        convert_request(streaming_non_decoupled, False, False)


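# convert_response translates an executor Response into a Triton inference
# response; executor errors are propagated on the response.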
def test_convert_response(trtllm_response: trtllm.Response):
    batch_index = 2
    batch_size = 3
    num_return_sequences = 1
    response, is_final, output_length = convert_response(
        trtllm_response, batch_index, batch_size, num_return_sequences)
    assert is_final == True
    assert (response.tensors["output_ids"].as_numpy() == np.array(
        [[1, 2, 3]])).all()
    assert (response.tensors["sequence_length"].as_numpy() == np.array(
        [[3]])).all()
    assert (response.tensors["cum_log_probs"].as_numpy() == np.array([1])).all()
    assert (response.tensors["output_log_probs"].as_numpy() == np.array(
        [[1, 3]])).all()
    assert (response.tensors["context_logits"].as_numpy() == np.ones(
        (3, 10), dtype=np.float32)).all()
    assert (response.tensors["generation_logits"].as_numpy() == np.ones(
        (1, 5, 10), dtype=np.float32)).all()
    assert (response.tensors["batch_index"].as_numpy() == np.array(
        [[batch_index]])).all()


def test_convert_response_minimal(trtllm_response_minimal: trtllm.Response):
    batch_index = 2
    batch_size = 3
    num_return_sequences = 1
    response, is_final, output_length = convert_response(
        trtllm_response_minimal, batch_index, batch_size, num_return_sequences)
    assert is_final == False
    assert (response.tensors["output_ids"].as_numpy() == np.array(
        [[1, 2, 3]])).all()
    assert (response.tensors["sequence_length"].as_numpy() == np.array(
        [[3]])).all()
    assert "cum_log_probs" not in response.tensors
    assert "output_log_probs" not in response.tensors
    assert "context_logits" not in response.tensors
    assert "generation_logits" not in response.tensors
    assert (response.tensors["batch_index"].as_numpy() == np.array(
        [[batch_index]])).all()


def test_convert_response_error(trtllm_response_error: trtllm.Response):
    batch_index = 2
    batch_size = 3
    num_return_sequences = 1
    response, is_final, output_length = convert_response(
        trtllm_response_error, batch_index, batch_size, num_return_sequences)
    assert is_final == True
    assert response.has_error() and response.error.message == "internal error"


def test_convert_scheduler_policy():
    assert convert_scheduler_policy(
        "max_utilization") == trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION
    assert convert_scheduler_policy(
        "guaranteed_no_evict"
    ) == trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
    with pytest.raises(
            Exception,
            match="batch_scheduler_policy value of 'other' is not supported"):
        convert_scheduler_policy("other")


def test_convert_batching_type():
    assert convert_batching_type(
        "inflight_fused_batching") == trtllm.BatchingType.INFLIGHT
    assert convert_batching_type(
        "inflight_batching") == trtllm.BatchingType.INFLIGHT
    assert convert_batching_type("v1") == trtllm.BatchingType.STATIC
    with pytest.raises(
            Exception,
            match="gpt_model_type value of 'other' is not supported"):
        convert_batching_type("other")


def test_convert_decoding_mode():
    assert convert_decoding_mode(None) is None
    assert convert_decoding_mode("auto").isAuto()
    assert convert_decoding_mode("top_k").isTopK()
    assert convert_decoding_mode("top_p").isTopP()
    assert convert_decoding_mode("top_k_top_p").isTopKandTopP()
    assert convert_decoding_mode("beam_search").isBeamSearch()
    assert convert_decoding_mode("medusa").isMedusa()
    assert convert_decoding_mode("redrafter").isExplicitDraftTokens()
    assert convert_decoding_mode("lookahead").isLookahead()
    assert convert_decoding_mode("eagle").isEagle()
    with pytest.raises(Exception,
                       match="decoding_mode value of 'other' is not supported"):
        convert_decoding_mode("other")


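# Model config parameters as Triton passes them: every value is a string and
# is parsed by get_executor_config.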
@pytest.fixture
def model_config() -> Dict:
    config = {
        "max_beam_width": "2",
        "enable_chunked_context": "true",
        "normalize_log_probs": "false",
        "gpt_model_type": "inflight_batching",
        "medusa_choices": "{1, 2, 3, 4}, {5, 6, 7}",
        "decoding_mode": "medusa",
        "batch_scheduler_policy": "max_utilization",
        "enable_kv_cache_reuse": "false",
        "max_tokens_in_paged_kv_cache": "1",
        "max_attention_window_size": "2",
        "sink_token_length": "3",
        "kv_cache_free_gpu_mem_fraction": "0.5",
        "cross_kv_cache_fraction": "0.5",
        "kv_cache_host_memory_bytes": "4",
        "kv_cache_onboard_blocks": "false",
        "gpu_device_ids": "0,1,2,3",
        "executor_worker_path": str(os.path.abspath(__file__)),
        "lora_cache_optimal_adapter_size": "1",
        "lora_cache_max_adapter_size": "2",
        "lora_cache_gpu_memory_fraction": "0.5",
        "lora_cache_host_memory_bytes": "4",
        "lora_prefetch_dir": "",
        "enable_context_fmha_fp32_acc": "true"
    }
    return {"parameters": {k: {"string_value": v} for k, v in config.items()}}


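# get_executor_config: built from the parameters above, with orchestrator mode
# toggled via TRTLLM_ORCHESTRATOR, and library defaults when no parameters are
# given.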
def test_get_executor_config(model_config: Dict):
    os.environ["TRTLLM_ORCHESTRATOR"] = "0"
    config = TritonPythonModel().get_executor_config(model_config)
    assert config.max_beam_width == 2
    assert config.enable_chunked_context == True
    assert config.normalize_log_probs == False
    assert config.batching_type == trtllm.BatchingType.INFLIGHT
    assert config.decoding_config.medusa_choices == [[1, 2, 3, 4], [5, 6, 7]]
    assert config.decoding_config.decoding_mode.isMedusa()
    assert config.scheduler_config.capacity_scheduler_policy == trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION
    assert config.kv_cache_config.enable_block_reuse == False
    assert config.kv_cache_config.max_tokens == 1
    assert config.kv_cache_config.max_attention_window == [2]
    assert config.kv_cache_config.sink_token_length == 3
    assert config.kv_cache_config.free_gpu_memory_fraction == 0.5
    assert config.kv_cache_config.cross_kv_cache_fraction == 0.5
    assert config.kv_cache_config.host_cache_size == 4
    assert config.kv_cache_config.onboard_blocks == False
    assert config.parallel_config.device_ids == [0, 1, 2, 3]
    assert config.parallel_config.orchestrator_config is None
    assert config.peft_cache_config.optimal_adapter_size == 1
    assert config.peft_cache_config.max_adapter_size == 2
    assert config.peft_cache_config.device_cache_percent == 0.5
    assert config.peft_cache_config.host_cache_size == 4
    assert config.iter_stats_max_iterations == 1000
    assert config.request_stats_max_iterations == 0
    assert config.logits_post_processor_config is None
    assert config.extended_runtime_perf_knob_config.enable_context_fmha_fp32_acc == True
    assert config.extended_runtime_perf_knob_config.multi_block_mode == True
    del os.environ["TRTLLM_ORCHESTRATOR"]


def test_get_executor_config_orchestrator_mode(model_config: Dict):
    os.environ["TRTLLM_ORCHESTRATOR"] = "1"
    config = TritonPythonModel().get_executor_config(model_config)
    assert config.parallel_config.device_ids == [0, 1, 2, 3]
    assert config.parallel_config.orchestrator_config.is_orchestrator == True
    assert config.parallel_config.orchestrator_config.worker_executable_path == str(
        os.path.abspath(__file__))
    del os.environ["TRTLLM_ORCHESTRATOR"]


def test_get_executor_config_minimal():
    if "TRTLLM_ORCHESTRATOR" in os.environ:
        del os.environ["TRTLLM_ORCHESTRATOR"]
    config = TritonPythonModel().get_executor_config({"parameters": {}})
    assert config.max_beam_width == 1
    assert config.enable_chunked_context == False
    assert config.normalize_log_probs == True
    assert config.batching_type == trtllm.BatchingType.INFLIGHT
    assert config.decoding_config.decoding_mode is None
    assert config.decoding_config.medusa_choices is None
    assert config.decoding_config.eagle_config is None
    assert config.decoding_config.lookahead_decoding_config is None
    assert config.scheduler_config.capacity_scheduler_policy == trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
    assert config.kv_cache_config.enable_block_reuse == True
    assert config.kv_cache_config.max_tokens is None
    assert config.kv_cache_config.max_attention_window is None
    assert config.kv_cache_config.sink_token_length is None
    assert config.kv_cache_config.free_gpu_memory_fraction is None
    assert config.kv_cache_config.cross_kv_cache_fraction is None
    assert config.kv_cache_config.host_cache_size is None
    assert config.kv_cache_config.onboard_blocks == True
    assert config.parallel_config is None
    assert config.peft_cache_config.optimal_adapter_size == 8
    assert config.peft_cache_config.max_adapter_size == 64
    assert config.peft_cache_config.device_cache_percent is None
    assert config.peft_cache_config.host_cache_size is None
    assert config.iter_stats_max_iterations == 1000
    assert config.request_stats_max_iterations == 0
    assert config.logits_post_processor_config is None
    assert config.extended_runtime_perf_knob_config.enable_context_fmha_fp32_acc == False
    assert config.extended_runtime_perf_knob_config.multi_block_mode == True


def test_convert_timestamp_to_seconds():
    assert convert_timestamp_to_seconds("01-01-1970 00:00:00.000000") == 0
    assert convert_timestamp_to_seconds(
        "05-17-2024 23:28:39.000000") == 1715988519