# NOTE: mirrored from https://github.com/NVIDIA/TensorRT-LLM.git
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for gRPC server components."""
|
|
|
|
import pytest
|
|
|
|
from tensorrt_llm.grpc import trtllm_service_pb2 as pb2
|
|
from tensorrt_llm.grpc.grpc_request_manager import (
|
|
create_disaggregated_params_from_proto,
|
|
create_lora_request_from_proto,
|
|
create_sampling_params_from_proto,
|
|
)
|
|
|
|
pytestmark = pytest.mark.threadleak(enabled=False)
|
|
|
|
|
|
class TestSamplingParamsConversion:
    """Tests for proto SamplingConfig/OutputConfig -> SamplingParams conversion."""

    def test_basic_sampling_config(self):
        """Basic fields (temperature/top_k/top_p) survive the conversion."""
        proto_config = pb2.SamplingConfig(
            beam_width=1,
            num_return_sequences=1,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
        )
        output_config = pb2.OutputConfig()

        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
        )

        assert params.max_tokens == 100
        # NOTE(review): exact float equality assumes these proto fields round-trip
        # without float32 truncation — confirm against the .proto field types.
        assert params.temperature == 0.7
        assert params.top_k == 50
        assert params.top_p == 0.9

    def test_beam_search_config(self):
        """Beam-search fields map onto beam_width / n / length_penalty."""
        proto_config = pb2.SamplingConfig(
            beam_width=4,
            num_return_sequences=2,
            length_penalty=1.2,
            early_stopping=1,
        )
        output_config = pb2.OutputConfig()

        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=50,
        )

        assert params.beam_width == 4
        # num_return_sequences is exposed as `n` on SamplingParams.
        assert params.n == 2
        assert params.length_penalty == 1.2

    def test_penalties_config(self):
        """Repetition/presence/frequency penalties are carried through."""
        proto_config = pb2.SamplingConfig(
            repetition_penalty=1.1,
            presence_penalty=0.5,
            frequency_penalty=0.3,
        )
        output_config = pb2.OutputConfig()

        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
        )

        assert params.repetition_penalty == 1.1
        assert params.presence_penalty == 0.5
        assert params.frequency_penalty == 0.3

    def test_logprobs_config(self):
        """logprobs/prompt_logprobs come from OutputConfig, not SamplingConfig."""
        proto_config = pb2.SamplingConfig()
        output_config = pb2.OutputConfig(
            logprobs=5,
            prompt_logprobs=3,
        )

        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
        )

        assert params.logprobs == 5
        assert params.prompt_logprobs == 3

    def test_guided_decoding_json_schema(self):
        """A JSON-schema guide populates guided_decoding_params.json_schema."""
        proto_config = pb2.SamplingConfig()
        output_config = pb2.OutputConfig()
        guided_decoding = pb2.GuidedDecodingParams(
            guide_type=pb2.GuidedDecodingParams.GUIDE_TYPE_JSON_SCHEMA,
            guide='{"type": "object", "properties": {"name": {"type": "string"}}}',
        )

        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
            guided_decoding=guided_decoding,
        )

        assert params.guided_decoding_params is not None
        assert params.guided_decoding_params.json_schema is not None

    def test_guided_decoding_regex(self):
        """A regex guide populates guided_decoding_params.regex."""
        proto_config = pb2.SamplingConfig()
        output_config = pb2.OutputConfig()
        guided_decoding = pb2.GuidedDecodingParams(
            guide_type=pb2.GuidedDecodingParams.GUIDE_TYPE_REGEX,
            guide=r"\d{3}-\d{4}",
        )

        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
            guided_decoding=guided_decoding,
        )

        assert params.guided_decoding_params is not None
        assert params.guided_decoding_params.regex is not None
class TestLoraRequestConversion:
    """Tests for proto LoraConfig -> LoRARequest conversion."""

    def test_basic_lora_config(self):
        """A populated LoraConfig yields a request carrying the task_id."""
        lora_config = pb2.LoraConfig(task_id=123)

        request = create_lora_request_from_proto(lora_config)

        assert request is not None
        assert request.task_id == 123

    def test_none_lora_config(self):
        """A missing (None) LoraConfig converts to None, not an empty request."""
        request = create_lora_request_from_proto(None)

        assert request is None
class TestDisaggregatedParamsConversion:
    """Tests for proto DisaggregatedParams -> DisaggregatedParams conversion."""

    def test_context_only_request(self):
        """Context-only request preserves the ctx_request_id."""
        proto_params = pb2.DisaggregatedParams(
            request_type=pb2.DisaggregatedParams.REQUEST_TYPE_CONTEXT_ONLY,
            ctx_request_id="ctx-123",
        )

        params = create_disaggregated_params_from_proto(proto_params)

        assert params is not None
        assert params.ctx_request_id == "ctx-123"

    def test_generation_only_request(self):
        """Generation-only request converts without error."""
        proto_params = pb2.DisaggregatedParams(
            request_type=pb2.DisaggregatedParams.REQUEST_TYPE_GENERATION_ONLY,
            ctx_request_id="gen-456",
        )

        params = create_disaggregated_params_from_proto(proto_params)

        # NOTE(review): only non-None is checked here; consider also asserting
        # the request type / ctx_request_id once the mapping is confirmed.
        assert params is not None

    def test_none_params(self):
        """A missing (None) proto converts to None."""
        params = create_disaggregated_params_from_proto(None)

        assert params is None
class TestProtoMessages:
    """Smoke tests that the generated proto messages expose the expected fields."""

    def test_generate_request_structure(self):
        """GenerateRequest round-trips its nested tokenized/sampling fields."""
        request = pb2.GenerateRequest(
            request_id="test-123",
            tokenized=pb2.TokenizedInput(
                input_token_ids=[1, 2, 3, 4, 5],
                original_text="Hello world",
            ),
            sampling_config=pb2.SamplingConfig(temperature=0.8),
            max_tokens=50,
            streaming=True,
        )

        assert request.request_id == "test-123"
        # Repeated proto fields compare via list() conversion.
        assert list(request.tokenized.input_token_ids) == [1, 2, 3, 4, 5]
        assert request.tokenized.original_text == "Hello world"
        # NOTE(review): exact float equality assumes the temperature field is not
        # truncated to float32 by the proto — confirm against the .proto definition.
        assert request.sampling_config.temperature == 0.8
        assert request.max_tokens == 50
        assert request.streaming is True

    def test_generate_response_chunk(self):
        """GenerateResponse with a streaming chunk payload."""
        response = pb2.GenerateResponse(
            request_id="test-123",
            chunk=pb2.GenerateStreamChunk(
                token_ids=[10, 11, 12],
                sequence_index=0,
                prompt_tokens=5,
                completion_tokens=3,
            ),
        )

        assert response.request_id == "test-123"
        assert list(response.chunk.token_ids) == [10, 11, 12]
        assert response.chunk.prompt_tokens == 5
        assert response.chunk.completion_tokens == 3

    def test_generate_response_complete(self):
        """GenerateResponse with a final completion payload."""
        response = pb2.GenerateResponse(
            request_id="test-123",
            complete=pb2.GenerateComplete(
                output_token_ids=[10, 11, 12, 13],
                finish_reason="stop",
                prompt_tokens=5,
                completion_tokens=4,
            ),
        )

        assert response.request_id == "test-123"
        assert list(response.complete.output_token_ids) == [10, 11, 12, 13]
        assert response.complete.finish_reason == "stop"

    def test_health_check_messages(self):
        """HealthCheck request/response messages construct cleanly."""
        _request = pb2.HealthCheckRequest()  # noqa: F841 - verify message construction
        response = pb2.HealthCheckResponse(status="healthy")

        assert response.status == "healthy"

    def test_model_info_response(self):
        """GetModelInfoResponse exposes model-id and size-limit fields."""
        response = pb2.GetModelInfoResponse(
            model_id="meta-llama/Llama-2-7b",
            max_input_len=4096,
            max_seq_len=8192,
            vocab_size=32000,
        )

        assert response.model_id == "meta-llama/Llama-2-7b"
        assert response.max_input_len == 4096
        assert response.max_seq_len == 8192
        assert response.vocab_size == 32000

    def test_server_info_response(self):
        """GetServerInfoResponse exposes version/backend/parallelism fields."""
        response = pb2.GetServerInfoResponse(
            version="0.17.0",
            backend="tensorrt-llm",
            tensor_parallel_size=2,
            pipeline_parallel_size=1,
            world_size=2,
        )

        assert response.version == "0.17.0"
        assert response.backend == "tensorrt-llm"
        assert response.tensor_parallel_size == 2
        assert response.world_size == 2

    def test_embed_messages(self):
        """EmbedRequest/EmbedResponse round-trip their core fields."""
        request = pb2.EmbedRequest(
            request_id="embed-123",
            tokenized=pb2.TokenizedInput(input_token_ids=[1, 2, 3]),
        )
        response = pb2.EmbedResponse(
            request_id="embed-123",
            embedding=[0.1, 0.2, 0.3, 0.4],
            prompt_tokens=3,
        )

        assert request.request_id == "embed-123"
        assert response.request_id == "embed-123"
        # NOTE(review): exact equality on the embedding assumes these values are
        # not truncated to float32 by the proto field type — confirm in the .proto.
        assert list(response.embedding) == [0.1, 0.2, 0.3, 0.4]
        assert response.prompt_tokens == 3

    def test_abort_messages(self):
        """AbortRequest/AbortResponse round-trip their core fields."""
        request = pb2.AbortRequest(request_id="abort-123")
        response = pb2.AbortResponse(success=True, message="Request aborted")

        assert request.request_id == "abort-123"
        assert response.success is True
        assert response.message == "Request aborted"