Signed-off-by: Chang Su <chang.s.su@oracle.com>
commit 9601b17459 (parent d9b936be94)
@@ -265,7 +265,8 @@ def create_sampling_params_from_proto(

     # Beam search / sampling
     if proto_config.beam_width > 1:
-        kwargs["beam_width"] = proto_config.beam_width
+        kwargs["use_beam_search"] = True
+        kwargs["best_of"] = proto_config.beam_width
     if proto_config.num_return_sequences > 0:
         kwargs["n"] = proto_config.num_return_sequences
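For reference, a minimal standalone sketch of what the new mapping produces for a beam-search request (values are illustrative; field names are taken from the hunk above):

kwargs = {}
beam_width, num_return_sequences = 4, 2
if beam_width > 1:
    kwargs["use_beam_search"] = True
    kwargs["best_of"] = beam_width
if num_return_sequences > 0:
    kwargs["n"] = num_return_sequences
assert kwargs == {"use_beam_search": True, "best_of": 4, "n": 2}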
@@ -289,7 +290,7 @@ def create_sampling_params_from_proto(

     # Seed for reproducibility
     if proto_config.HasField("seed"):
-        kwargs["random_seed"] = proto_config.seed
+        kwargs["seed"] = proto_config.seed

     # Min/max tokens
     if proto_config.HasField("min_tokens"):
@@ -336,11 +337,10 @@ def create_sampling_params_from_proto(
     if output_config.exclude_input_from_output:
         kwargs["exclude_input_from_output"] = True

-    # Stop sequences (as token ID lists)
-    if stop_words:
-        kwargs["stop_words"] = [list(seq.token_ids) for seq in stop_words]
-    if bad_words:
-        kwargs["bad_words"] = [list(seq.token_ids) for seq in bad_words]
+    # Pre-tokenized stop/bad word sequences (set after construction since
+    # SamplingParams._stop_word_ids/_bad_word_ids are init=False fields)
+    stop_word_ids = [list(seq.token_ids) for seq in stop_words] if stop_words else None
+    bad_word_ids = [list(seq.token_ids) for seq in bad_words] if bad_words else None

     # Embedding bias
     if embedding_bias:
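The new code only collects the token-ID lists here and assigns them after SamplingParams is constructed; that is the usual pattern for dataclass fields declared with init=False, which the generated __init__ refuses to accept. A minimal, self-contained sketch of the pattern (hypothetical Params class, not the real SamplingParams):

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Params:  # hypothetical stand-in for SamplingParams
    n: int = 1
    _stop_word_ids: Optional[List[List[int]]] = field(default=None, init=False)


p = Params(n=2)               # Params(n=2, _stop_word_ids=[[50256]]) would raise TypeError
p._stop_word_ids = [[50256]]  # init=False fields must be set after construction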
@@ -353,15 +353,24 @@ def create_sampling_params_from_proto(

     if guide_type == pb2.GuidedDecodingParams.GUIDE_TYPE_JSON:
         # json_object=True for JSON validation without schema constraint
-        kwargs["guided_decoding_params"] = GuidedDecodingParams(json_object=True)
+        kwargs["guided_decoding"] = GuidedDecodingParams(json_object=True)
     elif guide_type == pb2.GuidedDecodingParams.GUIDE_TYPE_JSON_SCHEMA:
-        kwargs["guided_decoding_params"] = GuidedDecodingParams(json_schema=guide_content)
+        kwargs["guided_decoding"] = GuidedDecodingParams(json=guide_content)
     elif guide_type == pb2.GuidedDecodingParams.GUIDE_TYPE_REGEX:
-        kwargs["guided_decoding_params"] = GuidedDecodingParams(regex=guide_content)
+        kwargs["guided_decoding"] = GuidedDecodingParams(regex=guide_content)
     elif guide_type == pb2.GuidedDecodingParams.GUIDE_TYPE_EBNF_GRAMMAR:
-        kwargs["guided_decoding_params"] = GuidedDecodingParams(grammar=guide_content)
+        kwargs["guided_decoding"] = GuidedDecodingParams(grammar=guide_content)

-    return SamplingParams(**kwargs)
+    params = SamplingParams(**kwargs)
+
+    # Set pre-tokenized stop/bad word IDs directly (these come pre-tokenized
+    # from the router, so we bypass the tokenizer-based setup path)
+    if stop_word_ids:
+        params._stop_word_ids = stop_word_ids
+    if bad_word_ids:
+        params._bad_word_ids = bad_word_ids
+
+    return params


 def create_lora_request_from_proto(
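The renamed kwarg ("guided_decoding") and the keyword choices (json_object, json, regex, grammar) match the assertions in the tests below. As a reading aid only, the same branch expressed as a lookup table — a sketch that assumes the import paths shown, not the code in this commit:

from tensorrt_llm.grpc import trtllm_service_pb2 as pb2  # as imported in the tests below
from tensorrt_llm.llmapi import GuidedDecodingParams     # assumed import path

# GUIDE_TYPE_JSON carries no payload, so it maps to json_object=True instead of
# a keyword that receives guide_content.
_GUIDE_KWARG = {
    pb2.GuidedDecodingParams.GUIDE_TYPE_JSON_SCHEMA: "json",
    pb2.GuidedDecodingParams.GUIDE_TYPE_REGEX: "regex",
    pb2.GuidedDecodingParams.GUIDE_TYPE_EBNF_GRAMMAR: "grammar",
}


def to_guided_decoding(guide_type, guide_content):
    if guide_type == pb2.GuidedDecodingParams.GUIDE_TYPE_JSON:
        return GuidedDecodingParams(json_object=True)
    if guide_type in _GUIDE_KWARG:
        return GuidedDecodingParams(**{_GUIDE_KWARG[guide_type]: guide_content})
    return None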
@@ -14,14 +14,30 @@
# limitations under the License.
"""Unit tests for gRPC server components."""

import asyncio
import os
import sys

import pytest
import torch

from tensorrt_llm.grpc import trtllm_service_pb2 as pb2
from tensorrt_llm.grpc.grpc_request_manager import (
    GrpcRequestManager,
    create_disaggregated_params_from_proto,
    create_lora_request_from_proto,
    create_sampling_params_from_proto,
)
from tensorrt_llm.grpc.grpc_servicer import TrtllmServiceServicer
from tensorrt_llm.llmapi import KvCacheConfig

# isort: off
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..")
from utils.llm_data import llm_models_root

# isort: on

skip_no_gpu = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available")

pytestmark = pytest.mark.threadleak(enabled=False)
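An illustration of how the two markers combine (hypothetical test, not part of this module): pytestmark applies to every test in the file, while skip_no_gpu is opted into per class or function, as the end-to-end class further below does.

@skip_no_gpu                   # skipped entirely on hosts without CUDA
def test_cuda_is_visible():    # hypothetical example, for illustration only
    assert torch.cuda.is_available()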
@@ -67,7 +83,8 @@ class TestSamplingParamsConversion:
             max_tokens=50,
         )

-        assert params.beam_width == 4
+        assert params.use_beam_search is True
+        assert params.best_of == 4
         assert params.n == 2
         assert params.length_penalty == 1.2

@@ -123,8 +140,8 @@ class TestSamplingParamsConversion:
             guided_decoding=guided_decoding,
         )

-        assert params.guided_decoding_params is not None
-        assert params.guided_decoding_params.json_schema is not None
+        assert params.guided_decoding is not None
+        assert params.guided_decoding.json is not None

     def test_guided_decoding_regex(self):
         """Test guided decoding with regex."""

@@ -142,8 +159,8 @@ class TestSamplingParamsConversion:
             guided_decoding=guided_decoding,
         )

-        assert params.guided_decoding_params is not None
-        assert params.guided_decoding_params.regex is not None
+        assert params.guided_decoding is not None
+        assert params.guided_decoding.regex is not None


 class TestLoraRequestConversion:
@@ -313,3 +330,385 @@ class TestProtoMessages:
        assert request.request_id == "abort-123"
        assert response.success is True
        assert response.message == "Request aborted"


# ============================================================================
# Comprehensive SamplingParams conversion test
# ============================================================================

class TestComprehensiveSamplingParamsConversion:
    """Comprehensive test covering all proto fields for SamplingParams conversion.

    Ensures the proto version stays consistent with the actual SamplingParams.
    """
    def test_all_sampling_config_fields(self):
        """Test conversion of ALL SamplingConfig fields to SamplingParams.

        Sets every proto field to a non-default value and verifies the
        corresponding SamplingParams field is correctly mapped.
        """
        proto_config = pb2.SamplingConfig(
            beam_width=4,
            num_return_sequences=2,
            top_k=40,
            top_p=0.95,
            top_p_min=0.01,
            top_p_reset_ids=5,
            top_p_decay=0.99,
            seed=42,
            temperature=0.8,
            min_tokens=10,
            beam_search_diversity_rate=0.5,
            repetition_penalty=1.2,
            presence_penalty=0.6,
            frequency_penalty=0.4,
            length_penalty=1.1,
            early_stopping=1,
            no_repeat_ngram_size=3,
            min_p=0.05,
        )
        output_config = pb2.OutputConfig(
            logprobs=5,
            prompt_logprobs=3,
            return_context_logits=True,
            return_generation_logits=True,
            exclude_input_from_output=True,
        )
        stop_words = [
            pb2.TokenSequence(token_ids=[50256]),
            pb2.TokenSequence(token_ids=[50257, 50258]),
        ]
        bad_words = [
            pb2.TokenSequence(token_ids=[100, 101]),
        ]
        embedding_bias = [0.0] * 10 + [1.5, -1.5]

        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=256,
            end_id=50256,
            pad_id=50257,
            stop_words=stop_words,
            bad_words=bad_words,
            embedding_bias=embedding_bias,
        )

        # Beam search fields
        assert params.use_beam_search is True
        assert params.best_of == 4
        assert params.n == 2

        # Sampling fields
        assert params.top_k == 40
        assert params.top_p == pytest.approx(0.95)
        assert params.top_p_min == pytest.approx(0.01)
        assert params.top_p_reset_ids == 5
        assert params.top_p_decay == pytest.approx(0.99)
        assert params.seed == 42
        assert params.temperature == pytest.approx(0.8)
        assert params.min_tokens == 10
        assert params.min_p == pytest.approx(0.05)

        # Beam search specific
        assert params.beam_search_diversity_rate == pytest.approx(0.5)
        assert params.length_penalty == pytest.approx(1.1)
        assert params.early_stopping == 1
        assert params.no_repeat_ngram_size == 3

        # Penalties
        assert params.repetition_penalty == pytest.approx(1.2)
        assert params.presence_penalty == pytest.approx(0.6)
        assert params.frequency_penalty == pytest.approx(0.4)

        # OutputConfig fields
        assert params.logprobs == 5
        assert params.prompt_logprobs == 3
        assert params.return_context_logits is True
        assert params.return_generation_logits is True
        assert params.exclude_input_from_output is True

        # Other params
        assert params.max_tokens == 256
        assert params.end_id == 50256
        assert params.pad_id == 50257
        assert params.detokenize is False  # key optimization

        # Stop/bad words (set as pre-tokenized word IDs)
        assert params._stop_word_ids == [[50256], [50257, 50258]]
        assert params._bad_word_ids == [[100, 101]]

        # Embedding bias converted to torch.Tensor
        assert params.embedding_bias is not None
        assert len(params.embedding_bias) == 12

    def test_end_id_minus_one_sets_ignore_eos(self):
        """Test that end_id=-1 correctly sets ignore_eos=True."""
        proto_config = pb2.SamplingConfig(temperature=0.7)
        output_config = pb2.OutputConfig()

        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
            end_id=-1,
        )

        assert params.end_id == -1
        assert params.ignore_eos is True

    def test_defaults_when_fields_unset(self):
        """Test that sensible defaults are applied for unset proto fields.

        Proto optional fields default to unset, but the conversion function
        applies safety defaults for temperature, top_p, and repetition_penalty.
        """
        proto_config = pb2.SamplingConfig()
        output_config = pb2.OutputConfig()

        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
        )

        assert params.temperature == 1.0  # default safety guard
        assert params.top_p == 1.0  # default safety guard
        assert params.repetition_penalty == 1.0  # default = no penalty
        assert params.detokenize is False

    def test_guided_decoding_all_types(self):
        """Test all guided decoding types map to correct GuidedDecodingParams fields."""
        proto_config = pb2.SamplingConfig()
        output_config = pb2.OutputConfig()

        # JSON (object mode)
        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
            guided_decoding=pb2.GuidedDecodingParams(
                guide_type=pb2.GuidedDecodingParams.GUIDE_TYPE_JSON,
                guide="{}",
            ),
        )
        assert params.guided_decoding is not None
        assert params.guided_decoding.json_object is True

        # JSON Schema
        schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
            guided_decoding=pb2.GuidedDecodingParams(
                guide_type=pb2.GuidedDecodingParams.GUIDE_TYPE_JSON_SCHEMA,
                guide=schema,
            ),
        )
        assert params.guided_decoding is not None
        assert params.guided_decoding.json == schema

        # Regex
        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
            guided_decoding=pb2.GuidedDecodingParams(
                guide_type=pb2.GuidedDecodingParams.GUIDE_TYPE_REGEX,
                guide=r"\d{3}-\d{4}",
            ),
        )
        assert params.guided_decoding is not None
        assert params.guided_decoding.regex == r"\d{3}-\d{4}"

        # EBNF Grammar
        grammar = 'root ::= "hello" | "world"'
        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
            guided_decoding=pb2.GuidedDecodingParams(
                guide_type=pb2.GuidedDecodingParams.GUIDE_TYPE_EBNF_GRAMMAR,
                guide=grammar,
            ),
        )
        assert params.guided_decoding is not None
        assert params.guided_decoding.grammar == grammar

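The embedding-bias assertions in the comprehensive test above only check presence and length; the converter presumably builds a torch tensor from the proto's float list. A minimal, self-contained sketch of that assumption (not code from this commit):

import torch

embedding_bias = [0.0] * 10 + [1.5, -1.5]
bias_tensor = torch.tensor(embedding_bias, dtype=torch.float32)  # presumed conversion
assert bias_tensor.shape == (12,)
assert len(bias_tensor) == 12  # consistent with the length check in the test
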
# ============================================================================
# End-to-end gRPC service tests (with real model)
# ============================================================================

default_model_name = "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"


def get_model_path(model_name):
    engine_dir = os.environ.get("LLM_ENGINE_DIR", None)
    if engine_dir:
        return engine_dir
    return str(llm_models_root() / model_name)

@pytest.fixture(scope="module")
def grpc_service():
    """Create a real LLM, request manager, and servicer for e2e testing.

    Uses TinyLlama-1.1B for minimal GPU resource usage.
    Shared across all tests in this module.
    """
    from tensorrt_llm._tensorrt_engine import LLM

    model_path = get_model_path(default_model_name)
    llm = LLM(
        model=model_path,
        kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4),
        fast_build=True,
    )
    tokenizer = llm.tokenizer

    request_manager = GrpcRequestManager(llm)
    servicer = TrtllmServiceServicer(request_manager, model_path=model_path)

    yield llm, tokenizer, request_manager, servicer

    llm.shutdown()

def _run_async(coro):
    """Helper to run async code in sync tests."""
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(coro)
    finally:
        loop.close()

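Usage sketch for the helper (hypothetical coroutine); asyncio.run would behave much the same here, since it also creates and closes a fresh event loop per call:

async def _double(x):  # hypothetical coroutine, for illustration only
    return 2 * x


assert _run_async(_double(21)) == 42
assert asyncio.run(_double(21)) == 42  # near-equivalent alternative
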
class _MockContext:
    """Minimal mock for grpc.aio.ServicerContext."""

    def cancelled(self):
        return False

@skip_no_gpu
class TestGrpcServiceEndToEnd:
    """End-to-end tests for the gRPC service flow.

    Tests the full pipeline: gRPC request -> servicer -> request manager -> LLM -> response.
    Uses TinyLlama-1.1B for minimal GPU resource usage.
    """

    def test_generate_non_streaming(self, grpc_service):
        """Test non-streaming generation returns a complete response with token IDs."""
        llm, tokenizer, request_manager, servicer = grpc_service
        prompt_token_ids = tokenizer.encode("A B C")

        request = pb2.GenerateRequest(
            request_id="e2e-non-stream",
            tokenized=pb2.TokenizedInput(input_token_ids=prompt_token_ids),
            sampling_config=pb2.SamplingConfig(temperature=0.0),
            max_tokens=8,
            streaming=False,
        )

        async def run():
            responses = []
            async for resp in servicer.Generate(request, _MockContext()):
                responses.append(resp)
            return responses

        responses = _run_async(run())

        completes = [r for r in responses if r.HasField("complete")]
        assert len(completes) == 1

        resp = completes[0]
        assert resp.request_id == "e2e-non-stream"
        assert len(resp.complete.output_token_ids) > 0
        assert resp.complete.prompt_tokens == len(prompt_token_ids)
        assert resp.complete.completion_tokens == len(resp.complete.output_token_ids)
        assert resp.complete.finish_reason in ("stop", "length")

    def test_generate_streaming(self, grpc_service):
        """Test streaming generation returns delta chunks followed by a complete response."""
        llm, tokenizer, request_manager, servicer = grpc_service
        prompt_token_ids = tokenizer.encode("A B C")

        request = pb2.GenerateRequest(
            request_id="e2e-stream",
            tokenized=pb2.TokenizedInput(input_token_ids=prompt_token_ids),
            sampling_config=pb2.SamplingConfig(temperature=0.0),
            max_tokens=8,
            streaming=True,
        )

        async def run():
            responses = []
            async for resp in servicer.Generate(request, _MockContext()):
                responses.append(resp)
            return responses

        responses = _run_async(run())

        chunks = [r for r in responses if r.HasField("chunk")]
        completes = [r for r in responses if r.HasField("complete")]

        # Should have at least one streaming chunk
        assert len(chunks) >= 1
        # Each chunk should have delta tokens
        for chunk_resp in chunks:
            assert len(chunk_resp.chunk.token_ids) > 0

        # Reassemble all delta tokens and verify they match the complete response
        all_streamed_tokens = []
        for chunk_resp in chunks:
            all_streamed_tokens.extend(chunk_resp.chunk.token_ids)

        assert len(completes) == 1
        complete_tokens = list(completes[0].complete.output_token_ids)
        assert all_streamed_tokens == complete_tokens

    def test_health_check(self, grpc_service):
        """Test HealthCheck RPC returns healthy status."""
        _, _, _, servicer = grpc_service

        async def run():
            return await servicer.HealthCheck(pb2.HealthCheckRequest(), _MockContext())

        response = _run_async(run())
        assert response.status == "OK"

    def test_abort_nonexistent_request(self, grpc_service):
        """Test aborting a request that doesn't exist returns failure."""
        _, _, _, servicer = grpc_service

        async def run():
            return await servicer.Abort(pb2.AbortRequest(request_id="nonexistent"), _MockContext())

        response = _run_async(run())
        assert response.success is False

    def test_get_model_info(self, grpc_service):
        """Test GetModelInfo RPC returns model metadata."""
        _, _, _, servicer = grpc_service

        async def run():
            return await servicer.GetModelInfo(pb2.GetModelInfoRequest(), _MockContext())

        response = _run_async(run())
        assert response.vocab_size > 0

    def test_get_server_info(self, grpc_service):
        """Test GetServerInfo RPC returns server metadata."""
        _, _, _, servicer = grpc_service

        async def run():
            return await servicer.GetServerInfo(pb2.GetServerInfoRequest(), _MockContext())

        response = _run_async(run())
        assert response.backend == "tensorrt-llm"
        assert response.world_size >= 1
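The tests above exercise the servicer in-process; over the wire, a client would reach the same Generate RPC through the generated stub. A minimal sketch, assuming the generated stub module is named trtllm_service_pb2_grpc and exposes a TrtllmServiceStub for this service (both names are assumptions; only the message fields below appear in this diff):

import asyncio

import grpc

from tensorrt_llm.grpc import trtllm_service_pb2 as pb2
from tensorrt_llm.grpc import trtllm_service_pb2_grpc as pb2_grpc  # assumed module name


async def generate_once(target, token_ids):
    """Send one non-streaming Generate request and return the output token IDs."""
    async with grpc.aio.insecure_channel(target) as channel:
        stub = pb2_grpc.TrtllmServiceStub(channel)  # assumed stub class name
        request = pb2.GenerateRequest(
            request_id="client-example",
            tokenized=pb2.TokenizedInput(input_token_ids=token_ids),
            sampling_config=pb2.SamplingConfig(temperature=0.0),
            max_tokens=8,
            streaming=False,
        )
        async for resp in stub.Generate(request):  # server-streaming RPC
            if resp.HasField("complete"):
                return list(resp.complete.output_token_ids)
    return []


# Example invocation (assumes a server is listening on this address):
# asyncio.run(generate_once("localhost:50051", [1, 2, 3]))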