TensorRT-LLMs/tests/unittest/llmapi/test_grpc.py
2026-02-05 05:00:29 -05:00

715 lines
24 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for gRPC server components."""
import asyncio
import os
import sys
import pytest
import torch
from tensorrt_llm.grpc import trtllm_service_pb2 as pb2
from tensorrt_llm.grpc.grpc_request_manager import (
GrpcRequestManager,
create_disaggregated_params_from_proto,
create_lora_request_from_proto,
create_sampling_params_from_proto,
)
from tensorrt_llm.grpc.grpc_servicer import TrtllmServiceServicer
from tensorrt_llm.llmapi import KvCacheConfig
# isort: off
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..")
from utils.llm_data import llm_models_root
# isort: on
# Marker for tests that need a CUDA device; availability is probed once at import time.
skip_no_gpu = pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="GPU not available",
)
# Module-wide marker: disable thread-leak checking for every test in this file
# (presumably because the module-scoped LLM fixture keeps worker threads alive
# across tests — confirm).
pytestmark = pytest.mark.threadleak(enabled=False)
class TestSamplingParamsConversion:
    """Tests for proto to SamplingParams conversion.

    Floating-point fields are compared with ``pytest.approx``. If the proto
    declares them as ``float`` (single precision), a literal such as ``0.7``
    does not round-trip bit-exactly through the message, and exact ``==``
    comparisons against the Python (double) literal would be fragile.
    """

    def test_basic_sampling_config(self):
        """Test basic sampling config conversion."""
        proto_config = pb2.SamplingConfig(
            beam_width=1,
            num_return_sequences=1,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
        )
        output_config = pb2.OutputConfig()
        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
        )
        assert params.max_tokens == 100
        assert params.temperature == pytest.approx(0.7)
        assert params.top_k == 50
        assert params.top_p == pytest.approx(0.9)

    def test_beam_search_config(self):
        """Test beam search configuration."""
        proto_config = pb2.SamplingConfig(
            beam_width=4,
            num_return_sequences=2,
            length_penalty=1.2,
            early_stopping=1,
        )
        output_config = pb2.OutputConfig()
        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=50,
        )
        # beam_width > 1 switches to beam search with best_of == beam_width.
        assert params.use_beam_search is True
        assert params.best_of == 4
        assert params.n == 2
        assert params.length_penalty == pytest.approx(1.2)

    def test_penalties_config(self):
        """Test penalty parameters conversion."""
        proto_config = pb2.SamplingConfig(
            repetition_penalty=1.1,
            presence_penalty=0.5,
            frequency_penalty=0.3,
        )
        output_config = pb2.OutputConfig()
        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
        )
        assert params.repetition_penalty == pytest.approx(1.1)
        assert params.presence_penalty == pytest.approx(0.5)
        assert params.frequency_penalty == pytest.approx(0.3)

    def test_logprobs_config(self):
        """Test logprobs configuration."""
        proto_config = pb2.SamplingConfig()
        output_config = pb2.OutputConfig(
            logprobs=5,
            prompt_logprobs=3,
        )
        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
        )
        assert params.logprobs == 5
        assert params.prompt_logprobs == 3

    def test_guided_decoding_json_schema(self):
        """Test guided decoding with JSON schema: the schema text is forwarded verbatim."""
        schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
        proto_config = pb2.SamplingConfig()
        output_config = pb2.OutputConfig()
        guided_decoding = pb2.GuidedDecodingParams(
            guide_type=pb2.GuidedDecodingParams.GUIDE_TYPE_JSON_SCHEMA,
            guide=schema,
        )
        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
            guided_decoding=guided_decoding,
        )
        assert params.guided_decoding is not None
        assert params.guided_decoding.json == schema

    def test_guided_decoding_regex(self):
        """Test guided decoding with regex: the pattern is forwarded verbatim."""
        pattern = r"\d{3}-\d{4}"
        proto_config = pb2.SamplingConfig()
        output_config = pb2.OutputConfig()
        guided_decoding = pb2.GuidedDecodingParams(
            guide_type=pb2.GuidedDecodingParams.GUIDE_TYPE_REGEX,
            guide=pattern,
        )
        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
            guided_decoding=guided_decoding,
        )
        assert params.guided_decoding is not None
        assert params.guided_decoding.regex == pattern
class TestLoraRequestConversion:
    """Tests for proto to LoRARequest conversion."""

    def test_basic_lora_config(self):
        """A LoraConfig carrying a task id yields a LoRA request with that id."""
        converted = create_lora_request_from_proto(pb2.LoraConfig(task_id=123))
        assert converted is not None
        assert converted.task_id == 123

    def test_none_lora_config(self):
        """An absent LoraConfig maps to no LoRA request at all."""
        assert create_lora_request_from_proto(None) is None
class TestDisaggregatedParamsConversion:
    """Tests for proto to DisaggregatedParams conversion."""

    def test_context_only_request(self):
        """A context-only request converts and keeps its ctx_request_id."""
        proto_cls = pb2.DisaggregatedParams
        converted = create_disaggregated_params_from_proto(
            proto_cls(
                request_type=proto_cls.REQUEST_TYPE_CONTEXT_ONLY,
                ctx_request_id="ctx-123",
            )
        )
        assert converted is not None
        assert converted.ctx_request_id == "ctx-123"

    def test_generation_only_request(self):
        """A generation-only request converts to a non-empty result."""
        proto_cls = pb2.DisaggregatedParams
        converted = create_disaggregated_params_from_proto(
            proto_cls(
                request_type=proto_cls.REQUEST_TYPE_GENERATION_ONLY,
                ctx_request_id="gen-456",
            )
        )
        assert converted is not None

    def test_none_params(self):
        """Absent disaggregated params convert to None."""
        assert create_disaggregated_params_from_proto(None) is None
class TestProtoMessages:
    """Tests for proto message structure.

    Floating-point fields are compared with ``pytest.approx`` so the checks
    hold whether the proto stores them in single or double precision.
    """

    def test_generate_request_structure(self):
        """Test GenerateRequest message structure."""
        request = pb2.GenerateRequest(
            request_id="test-123",
            tokenized=pb2.TokenizedInput(
                input_token_ids=[1, 2, 3, 4, 5],
                original_text="Hello world",
            ),
            sampling_config=pb2.SamplingConfig(temperature=0.8),
            max_tokens=50,
            streaming=True,
        )
        assert request.request_id == "test-123"
        assert list(request.tokenized.input_token_ids) == [1, 2, 3, 4, 5]
        assert request.tokenized.original_text == "Hello world"
        assert request.sampling_config.temperature == pytest.approx(0.8)
        assert request.max_tokens == 50
        assert request.streaming is True

    def test_generate_response_chunk(self):
        """Test GenerateResponse with chunk."""
        response = pb2.GenerateResponse(
            request_id="test-123",
            chunk=pb2.GenerateStreamChunk(
                token_ids=[10, 11, 12],
                sequence_index=0,
                prompt_tokens=5,
                completion_tokens=3,
            ),
        )
        assert response.request_id == "test-123"
        assert list(response.chunk.token_ids) == [10, 11, 12]
        assert response.chunk.prompt_tokens == 5
        assert response.chunk.completion_tokens == 3

    def test_generate_response_complete(self):
        """Test GenerateResponse with complete."""
        response = pb2.GenerateResponse(
            request_id="test-123",
            complete=pb2.GenerateComplete(
                output_token_ids=[10, 11, 12, 13],
                finish_reason="stop",
                prompt_tokens=5,
                completion_tokens=4,
            ),
        )
        assert response.request_id == "test-123"
        assert list(response.complete.output_token_ids) == [10, 11, 12, 13]
        assert response.complete.finish_reason == "stop"

    def test_health_check_messages(self):
        """Test HealthCheck messages."""
        # A freshly constructed request has no fields set, so it serializes to nothing.
        assert pb2.HealthCheckRequest().ByteSize() == 0
        response = pb2.HealthCheckResponse(status="healthy")
        assert response.status == "healthy"

    def test_model_info_response(self):
        """Test GetModelInfoResponse message."""
        response = pb2.GetModelInfoResponse(
            model_id="meta-llama/Llama-2-7b",
            max_input_len=4096,
            max_seq_len=8192,
            vocab_size=32000,
        )
        assert response.model_id == "meta-llama/Llama-2-7b"
        assert response.max_input_len == 4096
        assert response.max_seq_len == 8192
        assert response.vocab_size == 32000

    def test_server_info_response(self):
        """Test GetServerInfoResponse message."""
        response = pb2.GetServerInfoResponse(
            version="0.17.0",
            backend="tensorrt-llm",
            tensor_parallel_size=2,
            pipeline_parallel_size=1,
            world_size=2,
        )
        assert response.version == "0.17.0"
        assert response.backend == "tensorrt-llm"
        assert response.tensor_parallel_size == 2
        assert response.world_size == 2

    def test_embed_messages(self):
        """Test Embed request and response messages."""
        request = pb2.EmbedRequest(
            request_id="embed-123",
            tokenized=pb2.TokenizedInput(input_token_ids=[1, 2, 3]),
        )
        response = pb2.EmbedResponse(
            request_id="embed-123",
            embedding=[0.1, 0.2, 0.3, 0.4],
            prompt_tokens=3,
        )
        assert request.request_id == "embed-123"
        assert response.request_id == "embed-123"
        assert list(response.embedding) == pytest.approx([0.1, 0.2, 0.3, 0.4])
        assert response.prompt_tokens == 3

    def test_abort_messages(self):
        """Test Abort request and response messages."""
        request = pb2.AbortRequest(request_id="abort-123")
        response = pb2.AbortResponse(success=True, message="Request aborted")
        assert request.request_id == "abort-123"
        assert response.success is True
        assert response.message == "Request aborted"
# ============================================================================
# Comprehensive SamplingParams conversion test
# ============================================================================
class TestComprehensiveSamplingParamsConversion:
    """Comprehensive test covering all proto fields for SamplingParams conversion.

    Ensures the proto version stays consistent with the actual SamplingParams.
    """

    @staticmethod
    def _convert_guided(guide_type, guide):
        """Convert an otherwise-empty request that carries a single guided-decoding spec."""
        return create_sampling_params_from_proto(
            proto_config=pb2.SamplingConfig(),
            output_config=pb2.OutputConfig(),
            max_tokens=100,
            guided_decoding=pb2.GuidedDecodingParams(guide_type=guide_type, guide=guide),
        )

    def test_all_sampling_config_fields(self):
        """Test conversion of ALL SamplingConfig fields to SamplingParams.

        Sets every proto field to a non-default value and verifies the
        corresponding SamplingParams field is correctly mapped.
        """
        proto_config = pb2.SamplingConfig(
            beam_width=4,
            num_return_sequences=2,
            top_k=40,
            top_p=0.95,
            top_p_min=0.01,
            top_p_reset_ids=5,
            top_p_decay=0.99,
            seed=42,
            temperature=0.8,
            min_tokens=10,
            beam_search_diversity_rate=0.5,
            repetition_penalty=1.2,
            presence_penalty=0.6,
            frequency_penalty=0.4,
            length_penalty=1.1,
            early_stopping=1,
            no_repeat_ngram_size=3,
            min_p=0.05,
        )
        output_config = pb2.OutputConfig(
            logprobs=5,
            prompt_logprobs=3,
            return_context_logits=True,
            return_generation_logits=True,
            exclude_input_from_output=True,
        )
        stop_words = [
            pb2.TokenSequence(token_ids=[50256]),
            pb2.TokenSequence(token_ids=[50257, 50258]),
        ]
        bad_words = [
            pb2.TokenSequence(token_ids=[100, 101]),
        ]
        embedding_bias = [0.0] * 10 + [1.5, -1.5]
        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=256,
            end_id=50256,
            pad_id=50257,
            stop_words=stop_words,
            bad_words=bad_words,
            embedding_bias=embedding_bias,
        )
        # Beam search fields
        assert params.use_beam_search is True
        assert params.best_of == 4
        assert params.n == 2
        # Sampling fields
        assert params.top_k == 40
        assert params.top_p == pytest.approx(0.95)
        assert params.top_p_min == pytest.approx(0.01)
        assert params.top_p_reset_ids == 5
        assert params.top_p_decay == pytest.approx(0.99)
        assert params.seed == 42
        assert params.temperature == pytest.approx(0.8)
        assert params.min_tokens == 10
        assert params.min_p == pytest.approx(0.05)
        # Beam search specific
        assert params.beam_search_diversity_rate == pytest.approx(0.5)
        assert params.length_penalty == pytest.approx(1.1)
        assert params.early_stopping == 1
        assert params.no_repeat_ngram_size == 3
        # Penalties
        assert params.repetition_penalty == pytest.approx(1.2)
        assert params.presence_penalty == pytest.approx(0.6)
        assert params.frequency_penalty == pytest.approx(0.4)
        # OutputConfig fields
        assert params.logprobs == 5
        assert params.prompt_logprobs == 3
        assert params.return_context_logits is True
        assert params.return_generation_logits is True
        assert params.exclude_input_from_output is True
        # Other params
        assert params.max_tokens == 256
        assert params.end_id == 50256
        assert params.pad_id == 50257
        assert params.detokenize is False  # key optimization
        # Stop/bad words (set as pre-tokenized word IDs)
        assert params._stop_word_ids == [[50256], [50257, 50258]]
        assert params._bad_word_ids == [[100, 101]]
        # Embedding bias converted to torch.Tensor: check length AND values.
        # torch.as_tensor accepts either a tensor or a plain sequence, so the
        # value check does not depend on the exact container type.
        assert params.embedding_bias is not None
        assert len(params.embedding_bias) == 12
        assert torch.as_tensor(params.embedding_bias).tolist() == pytest.approx(embedding_bias)

    def test_end_id_minus_one_sets_ignore_eos(self):
        """Test that end_id=-1 correctly sets ignore_eos=True."""
        proto_config = pb2.SamplingConfig(temperature=0.7)
        output_config = pb2.OutputConfig()
        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
            end_id=-1,
        )
        assert params.end_id == -1
        assert params.ignore_eos is True

    def test_defaults_when_fields_unset(self):
        """Test that sensible defaults are applied for unset proto fields.

        Proto optional fields default to unset, but the conversion function
        applies safety defaults for temperature, top_p, and repetition_penalty.
        """
        proto_config = pb2.SamplingConfig()
        output_config = pb2.OutputConfig()
        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
        )
        assert params.temperature == 1.0  # default safety guard
        assert params.top_p == 1.0  # default safety guard
        assert params.repetition_penalty == 1.0  # default = no penalty
        assert params.detokenize is False

    def test_guided_decoding_all_types(self):
        """Test all guided decoding types map to correct GuidedDecodingParams fields."""
        guide_types = pb2.GuidedDecodingParams

        # JSON (object mode): the guide text is irrelevant, only the flag is set.
        params = self._convert_guided(guide_types.GUIDE_TYPE_JSON, "{}")
        assert params.guided_decoding is not None
        assert params.guided_decoding.json_object is True

        # JSON Schema
        schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
        params = self._convert_guided(guide_types.GUIDE_TYPE_JSON_SCHEMA, schema)
        assert params.guided_decoding is not None
        assert params.guided_decoding.json == schema

        # Regex
        pattern = r"\d{3}-\d{4}"
        params = self._convert_guided(guide_types.GUIDE_TYPE_REGEX, pattern)
        assert params.guided_decoding is not None
        assert params.guided_decoding.regex == pattern

        # EBNF Grammar
        grammar = 'root ::= "hello" | "world"'
        params = self._convert_guided(guide_types.GUIDE_TYPE_EBNF_GRAMMAR, grammar)
        assert params.guided_decoding is not None
        assert params.guided_decoding.grammar == grammar
# ============================================================================
# End-to-end gRPC service tests (with real model)
# ============================================================================
default_model_name = "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"


def get_model_path(model_name):
    """Resolve the location the end-to-end tests load the model from.

    A non-empty ``LLM_ENGINE_DIR`` environment variable wins outright;
    otherwise ``model_name`` is joined onto ``llm_models_root()`` and
    returned as a string.
    """
    override = os.environ.get("LLM_ENGINE_DIR")
    if not override:
        return str(llm_models_root() / model_name)
    return override
@pytest.fixture(scope="module")
def grpc_service():
    """Create a real LLM, request manager, and servicer for e2e testing.

    Uses TinyLlama-1.1B for minimal GPU resource usage.
    Shared across all tests in this module.

    Yields:
        ``(llm, tokenizer, request_manager, servicer)``.
    """
    from tensorrt_llm._tensorrt_engine import LLM

    model_path = get_model_path(default_model_name)
    llm = LLM(
        model=model_path,
        kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4),
        fast_build=True,
    )
    # Once the LLM exists it must be shut down no matter what: if wiring up the
    # gRPC layer below raises, pytest never reaches the post-yield teardown, so
    # without the finally the engine would stay alive for the rest of the session.
    try:
        tokenizer = llm.tokenizer
        request_manager = GrpcRequestManager(llm)
        servicer = TrtllmServiceServicer(request_manager, model_path=model_path)
        yield llm, tokenizer, request_manager, servicer
    finally:
        llm.shutdown()
def _run_async(coro):
"""Helper to run async code in sync tests."""
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(coro)
finally:
loop.close()
class _MockContext:
"""Minimal mock for grpc.aio.ServicerContext."""
def cancelled(self):
return False
@skip_no_gpu
class TestGrpcServiceEndToEnd:
    """End-to-end tests for the gRPC service flow.

    Tests the full pipeline: gRPC request -> servicer -> request manager -> LLM -> response.
    Uses TinyLlama-1.1B for minimal GPU resource usage.
    """

    @staticmethod
    def _collect_generate(servicer, request):
        """Drive the streaming ``servicer.Generate`` RPC to completion and return all responses."""

        async def collect():
            return [resp async for resp in servicer.Generate(request, _MockContext())]

        return _run_async(collect())

    @staticmethod
    def _call_unary(rpc, request):
        """Await a unary servicer method with a mock context and return its response."""

        async def call():
            return await rpc(request, _MockContext())

        return _run_async(call())

    def test_generate_non_streaming(self, grpc_service):
        """Test non-streaming generation returns a complete response with token IDs."""
        _, tokenizer, _, servicer = grpc_service
        prompt_token_ids = tokenizer.encode("A B C")
        request = pb2.GenerateRequest(
            request_id="e2e-non-stream",
            tokenized=pb2.TokenizedInput(input_token_ids=prompt_token_ids),
            sampling_config=pb2.SamplingConfig(temperature=0.0),
            max_tokens=8,
            streaming=False,
        )
        responses = self._collect_generate(servicer, request)
        completes = [r for r in responses if r.HasField("complete")]
        assert len(completes) == 1
        resp = completes[0]
        assert resp.request_id == "e2e-non-stream"
        assert len(resp.complete.output_token_ids) > 0
        assert resp.complete.prompt_tokens == len(prompt_token_ids)
        assert resp.complete.completion_tokens == len(resp.complete.output_token_ids)
        assert resp.complete.finish_reason in ("stop", "length")

    def test_generate_streaming(self, grpc_service):
        """Test streaming generation returns delta chunks followed by a complete response."""
        _, tokenizer, _, servicer = grpc_service
        prompt_token_ids = tokenizer.encode("A B C")
        request = pb2.GenerateRequest(
            request_id="e2e-stream",
            tokenized=pb2.TokenizedInput(input_token_ids=prompt_token_ids),
            sampling_config=pb2.SamplingConfig(temperature=0.0),
            max_tokens=8,
            streaming=True,
        )
        responses = self._collect_generate(servicer, request)
        chunks = [r for r in responses if r.HasField("chunk")]
        completes = [r for r in responses if r.HasField("complete")]
        # Should have at least one streaming chunk, each carrying delta tokens
        assert len(chunks) >= 1
        for chunk_resp in chunks:
            assert len(chunk_resp.chunk.token_ids) > 0
        # Concatenating the deltas must reproduce the final output exactly
        all_streamed_tokens = [tok for c in chunks for tok in c.chunk.token_ids]
        assert len(completes) == 1
        complete_tokens = list(completes[0].complete.output_token_ids)
        assert all_streamed_tokens == complete_tokens

    def test_health_check(self, grpc_service):
        """Test HealthCheck RPC returns healthy status."""
        servicer = grpc_service[3]
        response = self._call_unary(servicer.HealthCheck, pb2.HealthCheckRequest())
        assert response.status == "OK"

    def test_abort_nonexistent_request(self, grpc_service):
        """Test aborting a request that doesn't exist returns failure."""
        servicer = grpc_service[3]
        response = self._call_unary(servicer.Abort, pb2.AbortRequest(request_id="nonexistent"))
        assert response.success is False

    def test_get_model_info(self, grpc_service):
        """Test GetModelInfo RPC returns model metadata."""
        servicer = grpc_service[3]
        response = self._call_unary(servicer.GetModelInfo, pb2.GetModelInfoRequest())
        assert response.vocab_size > 0

    def test_get_server_info(self, grpc_service):
        """Test GetServerInfo RPC returns server metadata."""
        servicer = grpc_service[3]
        response = self._call_unary(servicer.GetServerInfo, pb2.GetServerInfoRequest())
        assert response.backend == "tensorrt-llm"
        assert response.world_size >= 1