From 9601b17459c3cd007cfc6287a3d118e22d1371e1 Mon Sep 17 00:00:00 2001
From: Chang Su
Date: Thu, 5 Feb 2026 02:00:29 -0800
Subject: [PATCH] [#11037][fix] Fix proto-to-SamplingParams conversion bugs
 and add gRPC tests (#11292)

Signed-off-by: Chang Su
---
 tensorrt_llm/grpc/grpc_request_manager.py |  33 +-
 tests/unittest/llmapi/test_grpc.py        | 409 +++++++++++++++++++++-
 2 files changed, 425 insertions(+), 17 deletions(-)
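
Note on the request-manager fix (a minimal sketch, values illustrative; the
names are those used in the diff below): beam_width > 1 now turns on
use_beam_search/best_of rather than passing beam_width through, seed maps to
SamplingParams.seed, and pre-tokenized stop sequences land on the private
_stop_word_ids field after construction:

    params = create_sampling_params_from_proto(
        proto_config=pb2.SamplingConfig(beam_width=4, seed=42),
        output_config=pb2.OutputConfig(),
        max_tokens=32,
        stop_words=[pb2.TokenSequence(token_ids=[50256])],
    )
    assert params.use_beam_search and params.best_of == 4
    assert params.seed == 42
    assert params._stop_word_ids == [[50256]]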

diff --git a/tensorrt_llm/grpc/grpc_request_manager.py b/tensorrt_llm/grpc/grpc_request_manager.py
index 63bfccb45e..c18af48ba2 100644
--- a/tensorrt_llm/grpc/grpc_request_manager.py
+++ b/tensorrt_llm/grpc/grpc_request_manager.py
@@ -265,7 +265,8 @@ def create_sampling_params_from_proto(
 
     # Beam search / sampling
     if proto_config.beam_width > 1:
-        kwargs["beam_width"] = proto_config.beam_width
+        kwargs["use_beam_search"] = True
+        kwargs["best_of"] = proto_config.beam_width
 
     if proto_config.num_return_sequences > 0:
         kwargs["n"] = proto_config.num_return_sequences
@@ -289,7 +290,7 @@
 
     # Seed for reproducibility
     if proto_config.HasField("seed"):
-        kwargs["random_seed"] = proto_config.seed
+        kwargs["seed"] = proto_config.seed
 
     # Min/max tokens
     if proto_config.HasField("min_tokens"):
@@ -336,11 +337,10 @@
     if output_config.exclude_input_from_output:
         kwargs["exclude_input_from_output"] = True
 
-    # Stop sequences (as token ID lists)
-    if stop_words:
-        kwargs["stop_words"] = [list(seq.token_ids) for seq in stop_words]
-    if bad_words:
-        kwargs["bad_words"] = [list(seq.token_ids) for seq in bad_words]
+    # Pre-tokenized stop/bad word sequences (set after construction since
+    # SamplingParams._stop_word_ids/_bad_word_ids are init=False fields)
+    stop_word_ids = [list(seq.token_ids) for seq in stop_words] if stop_words else None
+    bad_word_ids = [list(seq.token_ids) for seq in bad_words] if bad_words else None
 
     # Embedding bias
     if embedding_bias:
@@ -353,15 +353,24 @@
 
         if guide_type == pb2.GuidedDecodingParams.GUIDE_TYPE_JSON:
             # json_object=True for JSON validation without schema constraint
-            kwargs["guided_decoding_params"] = GuidedDecodingParams(json_object=True)
+            kwargs["guided_decoding"] = GuidedDecodingParams(json_object=True)
         elif guide_type == pb2.GuidedDecodingParams.GUIDE_TYPE_JSON_SCHEMA:
-            kwargs["guided_decoding_params"] = GuidedDecodingParams(json_schema=guide_content)
+            kwargs["guided_decoding"] = GuidedDecodingParams(json=guide_content)
         elif guide_type == pb2.GuidedDecodingParams.GUIDE_TYPE_REGEX:
-            kwargs["guided_decoding_params"] = GuidedDecodingParams(regex=guide_content)
+            kwargs["guided_decoding"] = GuidedDecodingParams(regex=guide_content)
         elif guide_type == pb2.GuidedDecodingParams.GUIDE_TYPE_EBNF_GRAMMAR:
-            kwargs["guided_decoding_params"] = GuidedDecodingParams(grammar=guide_content)
+            kwargs["guided_decoding"] = GuidedDecodingParams(grammar=guide_content)
 
-    return SamplingParams(**kwargs)
+    params = SamplingParams(**kwargs)
+
+    # Set pre-tokenized stop/bad word IDs directly (these come pre-tokenized
+    # from the router, so we bypass the tokenizer-based setup path)
+    if stop_word_ids:
+        params._stop_word_ids = stop_word_ids
+    if bad_word_ids:
+        params._bad_word_ids = bad_word_ids
+
+    return params
 
 
 def create_lora_request_from_proto(
diff --git a/tests/unittest/llmapi/test_grpc.py b/tests/unittest/llmapi/test_grpc.py
index d766900cfd..08d712f6ab 100644
--- a/tests/unittest/llmapi/test_grpc.py
+++ b/tests/unittest/llmapi/test_grpc.py
@@ -14,14 +14,30 @@
 # limitations under the License.
 """Unit tests for gRPC server components."""
 
+import asyncio
+import os
+import sys
+
 import pytest
+import torch
 
 from tensorrt_llm.grpc import trtllm_service_pb2 as pb2
 from tensorrt_llm.grpc.grpc_request_manager import (
+    GrpcRequestManager,
     create_disaggregated_params_from_proto,
     create_lora_request_from_proto,
     create_sampling_params_from_proto,
 )
+from tensorrt_llm.grpc.grpc_servicer import TrtllmServiceServicer
+from tensorrt_llm.llmapi import KvCacheConfig
+
+# isort: off
+sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..")
+from utils.llm_data import llm_models_root
+
+# isort: on
+
+skip_no_gpu = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available")
 
 pytestmark = pytest.mark.threadleak(enabled=False)
 
@@ -67,7 +83,8 @@ class TestSamplingParamsConversion:
             max_tokens=50,
         )
 
-        assert params.beam_width == 4
+        assert params.use_beam_search is True
+        assert params.best_of == 4
         assert params.n == 2
         assert params.length_penalty == 1.2
 
@@ -123,8 +140,8 @@ class TestSamplingParamsConversion:
             guided_decoding=guided_decoding,
         )
 
-        assert params.guided_decoding_params is not None
-        assert params.guided_decoding_params.json_schema is not None
+        assert params.guided_decoding is not None
+        assert params.guided_decoding.json is not None
 
     def test_guided_decoding_regex(self):
         """Test guided decoding with regex."""
@@ -142,8 +159,8 @@ class TestSamplingParamsConversion:
             guided_decoding=guided_decoding,
        )
 
-        assert params.guided_decoding_params is not None
-        assert params.guided_decoding_params.regex is not None
+        assert params.guided_decoding is not None
+        assert params.guided_decoding.regex is not None
 
 
 class TestLoraRequestConversion:
@@ -313,3 +330,385 @@ class TestProtoMessages:
         assert request.request_id == "abort-123"
         assert response.success is True
         assert response.message == "Request aborted"
+
+
+# ============================================================================
+# Comprehensive SamplingParams conversion test
+# ============================================================================
+
+
+class TestComprehensiveSamplingParamsConversion:
+    """Comprehensive test covering all proto fields for SamplingParams conversion.
+
+    Ensures the proto schema stays consistent with the actual SamplingParams API.
+    """
+
+    def test_all_sampling_config_fields(self):
+        """Test conversion of ALL SamplingConfig fields to SamplingParams.
+
+        Sets every proto field to a non-default value and verifies that the
+        corresponding SamplingParams field is correctly mapped.
+        """
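+        # Rough field map exercised here, as defined by the conversion
+        # function: beam_width -> use_beam_search/best_of,
+        # num_return_sequences -> n, seed -> seed (not random_seed), and
+        # pre-tokenized stop/bad words -> _stop_word_ids / _bad_word_ids.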
+ """ + proto_config = pb2.SamplingConfig( + beam_width=4, + num_return_sequences=2, + top_k=40, + top_p=0.95, + top_p_min=0.01, + top_p_reset_ids=5, + top_p_decay=0.99, + seed=42, + temperature=0.8, + min_tokens=10, + beam_search_diversity_rate=0.5, + repetition_penalty=1.2, + presence_penalty=0.6, + frequency_penalty=0.4, + length_penalty=1.1, + early_stopping=1, + no_repeat_ngram_size=3, + min_p=0.05, + ) + output_config = pb2.OutputConfig( + logprobs=5, + prompt_logprobs=3, + return_context_logits=True, + return_generation_logits=True, + exclude_input_from_output=True, + ) + stop_words = [ + pb2.TokenSequence(token_ids=[50256]), + pb2.TokenSequence(token_ids=[50257, 50258]), + ] + bad_words = [ + pb2.TokenSequence(token_ids=[100, 101]), + ] + embedding_bias = [0.0] * 10 + [1.5, -1.5] + + params = create_sampling_params_from_proto( + proto_config=proto_config, + output_config=output_config, + max_tokens=256, + end_id=50256, + pad_id=50257, + stop_words=stop_words, + bad_words=bad_words, + embedding_bias=embedding_bias, + ) + + # Beam search fields + assert params.use_beam_search is True + assert params.best_of == 4 + assert params.n == 2 + + # Sampling fields + assert params.top_k == 40 + assert params.top_p == pytest.approx(0.95) + assert params.top_p_min == pytest.approx(0.01) + assert params.top_p_reset_ids == 5 + assert params.top_p_decay == pytest.approx(0.99) + assert params.seed == 42 + assert params.temperature == pytest.approx(0.8) + assert params.min_tokens == 10 + assert params.min_p == pytest.approx(0.05) + + # Beam search specific + assert params.beam_search_diversity_rate == pytest.approx(0.5) + assert params.length_penalty == pytest.approx(1.1) + assert params.early_stopping == 1 + assert params.no_repeat_ngram_size == 3 + + # Penalties + assert params.repetition_penalty == pytest.approx(1.2) + assert params.presence_penalty == pytest.approx(0.6) + assert params.frequency_penalty == pytest.approx(0.4) + + # OutputConfig fields + assert params.logprobs == 5 + assert params.prompt_logprobs == 3 + assert params.return_context_logits is True + assert params.return_generation_logits is True + assert params.exclude_input_from_output is True + + # Other params + assert params.max_tokens == 256 + assert params.end_id == 50256 + assert params.pad_id == 50257 + assert params.detokenize is False # key optimization + + # Stop/bad words (set as pre-tokenized word IDs) + assert params._stop_word_ids == [[50256], [50257, 50258]] + assert params._bad_word_ids == [[100, 101]] + + # Embedding bias converted to torch.Tensor + assert params.embedding_bias is not None + assert len(params.embedding_bias) == 12 + + def test_end_id_minus_one_sets_ignore_eos(self): + """Test that end_id=-1 correctly sets ignore_eos=True.""" + proto_config = pb2.SamplingConfig(temperature=0.7) + output_config = pb2.OutputConfig() + + params = create_sampling_params_from_proto( + proto_config=proto_config, + output_config=output_config, + max_tokens=100, + end_id=-1, + ) + + assert params.end_id == -1 + assert params.ignore_eos is True + + def test_defaults_when_fields_unset(self): + """Test that sensible defaults are applied for unset proto fields. + + Proto optional fields default to unset, but the conversion function + applies safety defaults for temperature, top_p, and repetition_penalty. 
+ """ + proto_config = pb2.SamplingConfig() + output_config = pb2.OutputConfig() + + params = create_sampling_params_from_proto( + proto_config=proto_config, + output_config=output_config, + max_tokens=100, + ) + + assert params.temperature == 1.0 # default safety guard + assert params.top_p == 1.0 # default safety guard + assert params.repetition_penalty == 1.0 # default = no penalty + assert params.detokenize is False + + def test_guided_decoding_all_types(self): + """Test all guided decoding types map to correct GuidedDecodingParams fields.""" + proto_config = pb2.SamplingConfig() + output_config = pb2.OutputConfig() + + # JSON (object mode) + params = create_sampling_params_from_proto( + proto_config=proto_config, + output_config=output_config, + max_tokens=100, + guided_decoding=pb2.GuidedDecodingParams( + guide_type=pb2.GuidedDecodingParams.GUIDE_TYPE_JSON, + guide="{}", + ), + ) + assert params.guided_decoding is not None + assert params.guided_decoding.json_object is True + + # JSON Schema + schema = '{"type": "object", "properties": {"name": {"type": "string"}}}' + params = create_sampling_params_from_proto( + proto_config=proto_config, + output_config=output_config, + max_tokens=100, + guided_decoding=pb2.GuidedDecodingParams( + guide_type=pb2.GuidedDecodingParams.GUIDE_TYPE_JSON_SCHEMA, + guide=schema, + ), + ) + assert params.guided_decoding is not None + assert params.guided_decoding.json == schema + + # Regex + params = create_sampling_params_from_proto( + proto_config=proto_config, + output_config=output_config, + max_tokens=100, + guided_decoding=pb2.GuidedDecodingParams( + guide_type=pb2.GuidedDecodingParams.GUIDE_TYPE_REGEX, + guide=r"\d{3}-\d{4}", + ), + ) + assert params.guided_decoding is not None + assert params.guided_decoding.regex == r"\d{3}-\d{4}" + + # EBNF Grammar + grammar = 'root ::= "hello" | "world"' + params = create_sampling_params_from_proto( + proto_config=proto_config, + output_config=output_config, + max_tokens=100, + guided_decoding=pb2.GuidedDecodingParams( + guide_type=pb2.GuidedDecodingParams.GUIDE_TYPE_EBNF_GRAMMAR, + guide=grammar, + ), + ) + assert params.guided_decoding is not None + assert params.guided_decoding.grammar == grammar + + +# ============================================================================ +# End-to-end gRPC service tests (with real model) +# ============================================================================ + +default_model_name = "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" + + +def get_model_path(model_name): + engine_dir = os.environ.get("LLM_ENGINE_DIR", None) + if engine_dir: + return engine_dir + return str(llm_models_root() / model_name) + + +@pytest.fixture(scope="module") +def grpc_service(): + """Create a real LLM, request manager, and servicer for e2e testing. + + Uses TinyLlama-1.1B for minimal GPU resource usage. + Shared across all tests in this module. 
+ """ + from tensorrt_llm._tensorrt_engine import LLM + + model_path = get_model_path(default_model_name) + llm = LLM( + model=model_path, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + fast_build=True, + ) + tokenizer = llm.tokenizer + + request_manager = GrpcRequestManager(llm) + servicer = TrtllmServiceServicer(request_manager, model_path=model_path) + + yield llm, tokenizer, request_manager, servicer + + llm.shutdown() + + +def _run_async(coro): + """Helper to run async code in sync tests.""" + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + +class _MockContext: + """Minimal mock for grpc.aio.ServicerContext.""" + + def cancelled(self): + return False + + +@skip_no_gpu +class TestGrpcServiceEndToEnd: + """End-to-end tests for the gRPC service flow. + + Tests the full pipeline: gRPC request -> servicer -> request manager -> LLM -> response. + Uses TinyLlama-1.1B for minimal GPU resource usage. + """ + + def test_generate_non_streaming(self, grpc_service): + """Test non-streaming generation returns a complete response with token IDs.""" + llm, tokenizer, request_manager, servicer = grpc_service + prompt_token_ids = tokenizer.encode("A B C") + + request = pb2.GenerateRequest( + request_id="e2e-non-stream", + tokenized=pb2.TokenizedInput(input_token_ids=prompt_token_ids), + sampling_config=pb2.SamplingConfig(temperature=0.0), + max_tokens=8, + streaming=False, + ) + + async def run(): + responses = [] + async for resp in servicer.Generate(request, _MockContext()): + responses.append(resp) + return responses + + responses = _run_async(run()) + + completes = [r for r in responses if r.HasField("complete")] + assert len(completes) == 1 + + resp = completes[0] + assert resp.request_id == "e2e-non-stream" + assert len(resp.complete.output_token_ids) > 0 + assert resp.complete.prompt_tokens == len(prompt_token_ids) + assert resp.complete.completion_tokens == len(resp.complete.output_token_ids) + assert resp.complete.finish_reason in ("stop", "length") + + def test_generate_streaming(self, grpc_service): + """Test streaming generation returns delta chunks followed by a complete response.""" + llm, tokenizer, request_manager, servicer = grpc_service + prompt_token_ids = tokenizer.encode("A B C") + + request = pb2.GenerateRequest( + request_id="e2e-stream", + tokenized=pb2.TokenizedInput(input_token_ids=prompt_token_ids), + sampling_config=pb2.SamplingConfig(temperature=0.0), + max_tokens=8, + streaming=True, + ) + + async def run(): + responses = [] + async for resp in servicer.Generate(request, _MockContext()): + responses.append(resp) + return responses + + responses = _run_async(run()) + + chunks = [r for r in responses if r.HasField("chunk")] + completes = [r for r in responses if r.HasField("complete")] + + # Should have at least one streaming chunk + assert len(chunks) >= 1 + # Each chunk should have delta tokens + for chunk_resp in chunks: + assert len(chunk_resp.chunk.token_ids) > 0 + + # Reassemble all delta tokens and verify they match the complete response + all_streamed_tokens = [] + for chunk_resp in chunks: + all_streamed_tokens.extend(chunk_resp.chunk.token_ids) + + assert len(completes) == 1 + complete_tokens = list(completes[0].complete.output_token_ids) + assert all_streamed_tokens == complete_tokens + + def test_health_check(self, grpc_service): + """Test HealthCheck RPC returns healthy status.""" + _, _, _, servicer = grpc_service + + async def run(): + return await 
+
+    def test_health_check(self, grpc_service):
+        """Test HealthCheck RPC returns healthy status."""
+        _, _, _, servicer = grpc_service
+
+        async def run():
+            return await servicer.HealthCheck(pb2.HealthCheckRequest(), _MockContext())
+
+        response = _run_async(run())
+        assert response.status == "OK"
+
+    def test_abort_nonexistent_request(self, grpc_service):
+        """Test aborting a request that doesn't exist returns failure."""
+        _, _, _, servicer = grpc_service
+
+        async def run():
+            return await servicer.Abort(pb2.AbortRequest(request_id="nonexistent"), _MockContext())
+
+        response = _run_async(run())
+        assert response.success is False
+
+    def test_get_model_info(self, grpc_service):
+        """Test GetModelInfo RPC returns model metadata."""
+        _, _, _, servicer = grpc_service
+
+        async def run():
+            return await servicer.GetModelInfo(pb2.GetModelInfoRequest(), _MockContext())
+
+        response = _run_async(run())
+        assert response.vocab_size > 0
+
+    def test_get_server_info(self, grpc_service):
+        """Test GetServerInfo RPC returns server metadata."""
+        _, _, _, servicer = grpc_service
+
+        async def run():
+            return await servicer.GetServerInfo(pb2.GetServerInfoRequest(), _MockContext())
+
+        response = _run_async(run())
+        assert response.backend == "tensorrt-llm"
+        assert response.world_size >= 1
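
How to run (note, not part of the patch): the conversion tests are CPU-only,
while TestGrpcServiceEndToEnd is gated by `skip_no_gpu`. One way to run just
the CPU-safe subset (the -k filter is an example):

    pytest tests/unittest/llmapi/test_grpc.py -k "not EndToEnd"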