diff --git a/setup.py b/setup.py index 795d668999..29f1f532ac 100644 --- a/setup.py +++ b/setup.py @@ -18,9 +18,52 @@ from pathlib import Path from typing import List from setuptools import find_packages, setup +from setuptools.command.build_py import build_py from setuptools.dist import Distribution +class BuildPyWithProtoCompile(build_py): + """Custom build_py command that compiles protobuf files before building.""" + + def run(self): + self.compile_grpc_protos() + super().run() + + def compile_grpc_protos(self): + """Compile gRPC protobuf files if the proto file exists.""" + grpc_dir = Path(__file__).parent / "tensorrt_llm" / "grpc" + proto_file = grpc_dir / "trtllm_service.proto" + compile_script = grpc_dir / "compile_protos.py" + + if not proto_file.exists(): + return + + # Check if pb2 files need to be generated + pb2_file = grpc_dir / "trtllm_service_pb2.py" + pb2_grpc_file = grpc_dir / "trtllm_service_pb2_grpc.py" + + # Regenerate if pb2 files don't exist or are older than proto file + needs_compile = (not pb2_file.exists() or not pb2_grpc_file.exists() or + pb2_file.stat().st_mtime < proto_file.stat().st_mtime) + + if needs_compile and compile_script.exists(): + import subprocess + import sys + + print("Compiling gRPC protobuf files...") + try: + subprocess.run( + [sys.executable, str(compile_script)], + check=True, + cwd=str(grpc_dir.parent.parent), + ) + print("gRPC protobuf compilation successful") + except subprocess.CalledProcessError as e: + print(f"Warning: Failed to compile gRPC protos: {e}") + except Exception as e: + print(f"Warning: gRPC proto compilation skipped: {e}") + + def parse_requirements(filename: os.PathLike): with open(filename) as f: requirements = f.read().splitlines() @@ -374,6 +417,7 @@ packages += find_packages(include=["triton_kernels", "triton_kernels.*"]) setup( name='tensorrt_llm', version=get_version(), + cmdclass={'build_py': BuildPyWithProtoCompile}, description= ('TensorRT LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and supports ' 'state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs.' diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index 02bfbf5f02..76cbde9646 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -233,6 +233,124 @@ def launch_server( asyncio.run(server(host, port, sockets=[s])) +def launch_grpc_server(host: str, port: int, llm_args: dict): + """ + Launch a gRPC server for TensorRT-LLM. + + This provides a high-performance gRPC interface designed for external routers + (e.g., sgl-router) using pre-tokenized input and raw token ID output. 
+ + Args: + host: Host to bind to + port: Port to bind to + llm_args: Arguments for LLM initialization (from get_llm_args) + """ + import grpc + + try: + from grpc_reflection.v1alpha import reflection + REFLECTION_AVAILABLE = True + except ImportError: + REFLECTION_AVAILABLE = False + + from tensorrt_llm.grpc import trtllm_service_pb2, trtllm_service_pb2_grpc + from tensorrt_llm.grpc.grpc_request_manager import GrpcRequestManager + from tensorrt_llm.grpc.grpc_servicer import TrtllmServiceServicer + + async def serve_grpc_async(): + logger.info("Initializing TensorRT-LLM gRPC server...") + + backend = llm_args.get("backend") + model_path = llm_args.get("model", "") + + if backend == "pytorch": + llm_args.pop("build_config", None) + llm = PyTorchLLM(**llm_args) + elif backend == "_autodeploy": + from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM + llm_args.pop("build_config", None) + llm = AutoDeployLLM(**llm_args) + elif backend == "tensorrt" or backend == "trt": + llm_args.pop("backend") + llm = LLM(**llm_args) + else: + raise click.BadParameter( + f"{backend} is not a known backend, check help for available options.", + param_hint="backend") + + logger.info("Model loaded successfully") + + # Create request manager + request_manager = GrpcRequestManager(llm) + + # Create servicer + servicer = TrtllmServiceServicer(request_manager, model_path=model_path) + + # Create gRPC server + server = grpc.aio.server( + options=[ + ("grpc.max_send_message_length", -1), # Unlimited + ("grpc.max_receive_message_length", -1), # Unlimited + ("grpc.keepalive_time_ms", 30000), # 30s keepalive + ("grpc.keepalive_timeout_ms", 10000), # 10s timeout + ], ) + + # Add servicer to server + trtllm_service_pb2_grpc.add_TrtllmServiceServicer_to_server( + servicer, server) + + # Enable reflection for grpcurl and other tools + if REFLECTION_AVAILABLE: + service_names = ( + trtllm_service_pb2.DESCRIPTOR.services_by_name["TrtllmService"]. + full_name, + reflection.SERVICE_NAME, + ) + reflection.enable_server_reflection(service_names, server) + logger.info("gRPC reflection enabled") + + # Bind to address + address = f"{host}:{port}" + server.add_insecure_port(address) + + # Start server + await server.start() + logger.info(f"TensorRT-LLM gRPC server started on {address}") + logger.info("Server is ready to accept requests") + + # Handle shutdown signals + loop = asyncio.get_running_loop() + stop_event = asyncio.Event() + + def signal_handler(): + logger.info("Received shutdown signal") + stop_event.set() + + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, signal_handler) + + # Serve until shutdown signal + try: + await stop_event.wait() + except KeyboardInterrupt: + logger.info("Interrupted by user") + finally: + logger.info("Shutting down TensorRT-LLM gRPC server...") + + # Stop gRPC server + await server.stop(grace=5.0) + logger.info("gRPC server stopped") + + # Shutdown LLM + if hasattr(llm, "shutdown"): + llm.shutdown() + logger.info("LLM engine stopped") + + logger.info("Shutdown complete") + + asyncio.run(serve_grpc_async()) + + def launch_mm_encoder_server( host: str, port: int, @@ -476,6 +594,12 @@ class ChoiceWithAlias(click.Choice): "Specify a custom chat template. " "Can be a file path or one-liner template string", "prototype")) +@click.option( + "--grpc", + is_flag=True, + default=False, + help="Run gRPC server instead of OpenAI HTTP server. 
" + "gRPC server accepts pre-tokenized requests and returns raw token IDs.") def serve( model: str, tokenizer: Optional[str], custom_tokenizer: Optional[str], host: str, port: int, log_level: str, backend: str, max_beam_width: int, @@ -491,8 +615,9 @@ def serve( fail_fast_on_attention_window_too_large: bool, otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool, disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str], - custom_module_dirs: list[Path], chat_template: Optional[str]): - """Running an OpenAI API compatible server + custom_module_dirs: list[Path], chat_template: Optional[str], + grpc: bool): + """Running an OpenAI API compatible server (or gRPC server with --grpc flag) MODEL: model name | HF checkpoint path | TensorRT engine path """ @@ -569,9 +694,29 @@ def serve( multimodal_server_config = MultimodalServerConfig( media_io_kwargs=parsed_media_io_kwargs) - launch_server(host, port, llm_args, tool_parser, chat_template, - metadata_server_cfg, server_role, disagg_cluster_config, - multimodal_server_config) + + if grpc: + # gRPC mode: launch gRPC server instead of OpenAI HTTP server + # Check for unsupported arguments that are silently ignored in gRPC mode + unsupported_args = { + "tool_parser": tool_parser, + "chat_template": chat_template, + "metadata_server_config_file": metadata_server_config_file, + "server_role": server_role, + "disagg_cluster_config": disagg_cluster_config, + } + for name, value in unsupported_args.items(): + if value is not None: + raise ValueError( + f"Argument '{name}' is not supported when running in gRPC mode. " + f"The gRPC server is designed for use with external routers that handle " + f"these features (e.g., tool parsing, chat templates).") + launch_grpc_server(host, port, llm_args) + else: + # Default: launch OpenAI HTTP server + launch_server(host, port, llm_args, tool_parser, chat_template, + metadata_server_cfg, server_role, disagg_cluster_config, + multimodal_server_config) @click.command("mm_embedding_serve") diff --git a/tensorrt_llm/grpc/__init__.py b/tensorrt_llm/grpc/__init__.py new file mode 100644 index 0000000000..646d681492 --- /dev/null +++ b/tensorrt_llm/grpc/__init__.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""TensorRT-LLM gRPC module for high-performance communication with external routers. + +This module provides a gRPC server interface that accepts pre-tokenized requests +and returns raw token IDs, enabling efficient binary communication with Rust-based +routers like sgl-router. 
+ +Key Features: +- Pre-tokenized input (no Python tokenization overhead) +- Raw token ID output (no Python detokenization overhead) +- Streaming support with delta tokens +- Full sampling parameter support +- Guided decoding (JSON schema, regex, grammar) +- LoRA and prompt tuning support +- Disaggregated inference support + +Usage: + python -m tensorrt_llm.commands.serve /path/to/model \ + --grpc \ + --host 0.0.0.0 \ + --port 50051 +""" + +from pathlib import Path + +# Module directory for proto files +GRPC_MODULE_DIR = Path(__file__).parent + +# Proto file path +PROTO_FILE = GRPC_MODULE_DIR / "trtllm_service.proto" + +# Try to import generated protobuf modules +try: + from . import trtllm_service_pb2, trtllm_service_pb2_grpc + + PROTOS_AVAILABLE = True +except ImportError: + PROTOS_AVAILABLE = False + trtllm_service_pb2 = None + trtllm_service_pb2_grpc = None + + +def compile_protos(): + """Compile protobuf files to generate Python modules. + + Run this function if the generated *_pb2.py files are missing. + Alternatively, run: python tensorrt_llm/grpc/compile_protos.py + """ + from .compile_protos import main as compile_main + + compile_main() + + +def ensure_protos_available(): + """Ensure protobuf modules are available, compiling if necessary.""" + global PROTOS_AVAILABLE, trtllm_service_pb2, trtllm_service_pb2_grpc + + if not PROTOS_AVAILABLE: + raise ImportError( + "gRPC protobuf modules are not available. " + "Please generate them by running: " + "python tensorrt_llm/grpc/compile_protos.py" + ) + + +# Try to import request manager +try: + from .grpc_request_manager import ( + GrpcRequestManager, + create_disaggregated_params_from_proto, + create_lora_request_from_proto, + create_sampling_params_from_proto, + ) + + REQUEST_MANAGER_AVAILABLE = True +except ImportError: + REQUEST_MANAGER_AVAILABLE = False + GrpcRequestManager = None + create_sampling_params_from_proto = None + create_lora_request_from_proto = None + create_disaggregated_params_from_proto = None + +# Try to import servicer +try: + from .grpc_servicer import TrtllmServiceServicer + + SERVICER_AVAILABLE = True +except ImportError: + SERVICER_AVAILABLE = False + TrtllmServiceServicer = None + +__all__ = [ + "GRPC_MODULE_DIR", + "PROTO_FILE", + "PROTOS_AVAILABLE", + "REQUEST_MANAGER_AVAILABLE", + "SERVICER_AVAILABLE", + "compile_protos", + "ensure_protos_available", + "trtllm_service_pb2", + "trtllm_service_pb2_grpc", + "GrpcRequestManager", + "TrtllmServiceServicer", + "create_sampling_params_from_proto", + "create_lora_request_from_proto", + "create_disaggregated_params_from_proto", +] diff --git a/tensorrt_llm/grpc/compile_protos.py b/tensorrt_llm/grpc/compile_protos.py new file mode 100644 index 0000000000..db8f7eb1fd --- /dev/null +++ b/tensorrt_llm/grpc/compile_protos.py @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Compile Protocol Buffer definitions for TensorRT-LLM gRPC server. + +This script generates Python bindings from the trtllm_service.proto file. + +Usage: + python -m tensorrt_llm.grpc.compile_protos + +Or directly: + python compile_protos.py + +Requirements: + pip install grpcio-tools +""" + +import subprocess # nosec B404 +import sys +from pathlib import Path + + +def compile_protos(proto_dir: Path = None, output_dir: Path = None) -> bool: + """Compile proto files to Python. + + Args: + proto_dir: Directory containing .proto files. Defaults to this script's directory. + output_dir: Directory for generated Python files. Defaults to proto_dir. + + Returns: + True if compilation succeeded, False otherwise. + """ + if proto_dir is None: + proto_dir = Path(__file__).parent + if output_dir is None: + output_dir = proto_dir + + proto_file = proto_dir / "trtllm_service.proto" + + if not proto_file.exists(): + print(f"Error: Proto file not found: {proto_file}") + return False + + # Check for grpcio-tools + try: + from grpc_tools import protoc + except ImportError: + print("grpcio-tools not installed. Installing...") + subprocess.check_call([sys.executable, "-m", "pip", "install", "grpcio-tools"]) + from grpc_tools import protoc + + # Compile proto file + print(f"Compiling {proto_file}...") + + result = protoc.main( + [ + "grpc_tools.protoc", + f"--proto_path={proto_dir}", + f"--python_out={output_dir}", + f"--grpc_python_out={output_dir}", + str(proto_file), + ] + ) + + if result != 0: + print(f"Error: protoc failed with code {result}") + return False + + # Fix imports in generated files (grpc_tools generates absolute imports) + pb2_file = output_dir / "trtllm_service_pb2.py" + pb2_grpc_file = output_dir / "trtllm_service_pb2_grpc.py" + + if pb2_grpc_file.exists(): + content = pb2_grpc_file.read_text() + # Fix import to use relative import + content = content.replace( + "import trtllm_service_pb2 as trtllm__service__pb2", + "from . import trtllm_service_pb2 as trtllm__service__pb2", + ) + pb2_grpc_file.write_text(content) + print(f"Fixed imports in {pb2_grpc_file}") + + print("Generated files:") + print(f" - {pb2_file}") + print(f" - {pb2_grpc_file}") + print("Proto compilation successful!") + + return True + + +def verify_generated_files(output_dir: Path = None) -> bool: + """Verify that generated files can be imported. + + Args: + output_dir: Directory containing generated files. + + Returns: + True if files can be imported, False otherwise. + """ + if output_dir is None: + output_dir = Path(__file__).parent + + pb2_file = output_dir / "trtllm_service_pb2.py" + pb2_grpc_file = output_dir / "trtllm_service_pb2_grpc.py" + + if not pb2_file.exists() or not pb2_grpc_file.exists(): + print("Generated files not found. 
Run compile_protos() first.") + return False + + # Try to import + import importlib.util + + try: + spec = importlib.util.spec_from_file_location("trtllm_service_pb2", pb2_file) + pb2_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(pb2_module) + print(f"Successfully imported {pb2_file.name}") + + # Verify key message types exist + assert hasattr(pb2_module, "GenerateRequest") + assert hasattr(pb2_module, "GenerateResponse") + assert hasattr(pb2_module, "SamplingConfig") + assert hasattr(pb2_module, "OutputConfig") + print(" - GenerateRequest, GenerateResponse, SamplingConfig, OutputConfig OK") + + except Exception as e: + print(f"Error importing {pb2_file.name}: {e}") + return False + + print("Verification successful!") + return True + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Compile TensorRT-LLM gRPC protos") + parser.add_argument( + "--proto-dir", type=Path, default=None, help="Directory containing .proto files" + ) + parser.add_argument( + "--output-dir", type=Path, default=None, help="Directory for generated Python files" + ) + parser.add_argument( + "--verify", action="store_true", help="Verify generated files after compilation" + ) + + args = parser.parse_args() + + success = compile_protos(args.proto_dir, args.output_dir) + + if success and args.verify: + success = verify_generated_files(args.output_dir) + + sys.exit(0 if success else 1) diff --git a/tensorrt_llm/grpc/grpc_request_manager.py b/tensorrt_llm/grpc/grpc_request_manager.py new file mode 100644 index 0000000000..63bfccb45e --- /dev/null +++ b/tensorrt_llm/grpc/grpc_request_manager.py @@ -0,0 +1,420 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""gRPC Request Manager for TensorRT-LLM. + +Manages request lifecycle for gRPC requests, converting between protobuf +and TensorRT-LLM types. Designed for high-performance communication with +external routers (e.g., sgl-router) using pre-tokenized input. + +Key optimization: Sets detokenize=False in SamplingParams to skip +detokenization and return token IDs only. +""" + +import asyncio +import traceback +from collections.abc import AsyncGenerator +from typing import Any, Dict, List, Optional, Tuple + +from tensorrt_llm.disaggregated_params import DisaggregatedParams +from tensorrt_llm.executor.request import LoRARequest, PromptAdapterRequest +from tensorrt_llm.executor.result import GenerationResult +from tensorrt_llm.llmapi.llm_utils import KvCacheRetentionConfig +from tensorrt_llm.logger import logger +from tensorrt_llm.sampling_params import GuidedDecodingParams, SamplingParams + +from . import trtllm_service_pb2 as pb2 + + +class GrpcRequestManager: + """Manages gRPC request lifecycle for TensorRT-LLM. + + Responsibilities: + - Convert protobuf requests to TensorRT-LLM types + - Set detokenize=False in SamplingParams (key optimization!) 
+ - Submit requests to LLM.generate_async() + - Stream token IDs (not text) back to gRPC clients + - Handle abort/cancel operations + + This is modeled after vLLM's GrpcRequestManager but adapted for TensorRT-LLM's + GenerationResult async iterator pattern. + """ + + def __init__(self, llm: Any): + """Initialize the request manager. + + Args: + llm: The TensorRT-LLM LLM instance (tensorrt_llm.LLM or tensorrt_llm._tensorrt_engine.LLM) + """ + self.llm = llm + # Track active requests: request_id -> GenerationResult + self._rid_to_result: Dict[str, GenerationResult] = {} + + logger.info("GrpcRequestManager initialized") + + async def generate( + self, + request_id: str, + prompt_token_ids: List[int], + sampling_params: SamplingParams, + streaming: bool = True, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None, + disaggregated_params: Optional[DisaggregatedParams] = None, + ) -> AsyncGenerator[GenerationResult, None]: + """Submit a generation request and stream outputs. + + Args: + request_id: Unique request identifier (for tracking/abort) + prompt_token_ids: Pre-tokenized input from Rust router + sampling_params: Sampling parameters (with detokenize=False!) + streaming: Whether to stream results + lora_request: Optional LoRA adapter request + prompt_adapter_request: Optional prompt adapter request + kv_cache_retention_config: KV cache retention config + disaggregated_params: Disaggregated inference params + + Yields: + GenerationResult objects containing token IDs (text will be empty + because detokenize=False) + """ + try: + # Submit to LLM.generate_async which returns a GenerationResult + # that is an async iterator + gen_result = self.llm.generate_async( + {"prompt_token_ids": prompt_token_ids}, + sampling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + streaming=streaming, + kv_cache_retention_config=kv_cache_retention_config, + disaggregated_params=disaggregated_params, + ) + + # Track the result for potential abort + self._rid_to_result[request_id] = gen_result + + # Iterate over the async generator + # GenerationResult implements __aiter__ and __anext__ + async for result in gen_result: + yield result + + if result.finished: + break + + except asyncio.CancelledError: + logger.info(f"Request {request_id} cancelled by client") + await self.abort(request_id) + raise + except Exception as e: + logger.error(f"Error in generate for {request_id}: {e}") + raise + finally: + # Cleanup tracking + self._rid_to_result.pop(request_id, None) + + async def abort(self, request_id: str) -> bool: + """Abort a running request. + + Args: + request_id: The request ID to abort + + Returns: + True if request was found and aborted, False otherwise + """ + gen_result = self._rid_to_result.get(request_id) + + if gen_result is None: + logger.debug(f"Abort: request {request_id} not found (may have already completed)") + return False + + try: + # GenerationResult has an abort() method + gen_result.abort() + self._rid_to_result.pop(request_id, None) + logger.info(f"Request {request_id} aborted") + return True + except Exception as e: + logger.error(f"Error aborting request {request_id}: {e}") + self._rid_to_result.pop(request_id, None) + return False + + async def health_check(self) -> Tuple[bool, str]: + """Check if the engine is healthy. 
+ + Returns: + Tuple of (is_healthy, message) + """ + try: + if self.llm is None: + return False, "LLM not initialized" + + # Check if executor is available and not shutdown + if hasattr(self.llm, "_executor"): + if self.llm._executor is None or self.llm._executor.is_shutdown(): + return False, "Executor is shutdown" + + return True, "OK" + except Exception as e: + logger.error(f"Health check error: {e}") + return False, f"Error: {e}" + + def get_model_config(self) -> Dict[str, Any]: + """Get model configuration information. + + Returns: + Dictionary with model config details + """ + config = { + "model_path": "", + "is_generation": True, + "max_context_length": 0, + "max_seq_len": 0, + "vocab_size": 0, + "supports_vision": False, + } + + try: + # Try to get model path + if hasattr(self.llm, "args"): + if hasattr(self.llm.args, "model"): + config["model_path"] = str(self.llm.args.model) + + # Try to get tokenizer info + if hasattr(self.llm, "tokenizer") and self.llm.tokenizer is not None: + if hasattr(self.llm.tokenizer, "vocab_size"): + config["vocab_size"] = self.llm.tokenizer.vocab_size + + # Try to get max context length from various sources + if hasattr(self.llm, "args") and self.llm.args is not None: + args = self.llm.args + # Try max_input_len first (input context) + if hasattr(args, "max_input_len") and args.max_input_len: + config["max_context_length"] = args.max_input_len + # Try max_seq_len (total sequence including output) + if hasattr(args, "max_seq_len") and args.max_seq_len: + config["max_seq_len"] = args.max_seq_len + + # Check for multimodal support + if hasattr(self.llm, "input_processor"): + processor_name = type(self.llm.input_processor).__name__ + config["supports_vision"] = processor_name != "DefaultInputProcessor" + + except Exception as e: + logger.warning( + f"Error getting model config: {type(e).__name__}: {e}\n{traceback.format_exc()}" + ) + + return config + + def get_num_unfinished_requests(self) -> int: + """Get the number of currently running requests. + + Returns: + Number of unfinished requests + """ + return len(self._rid_to_result) + + +def create_sampling_params_from_proto( + proto_config: pb2.SamplingConfig, + output_config: pb2.OutputConfig, + max_tokens: int, + end_id: Optional[int] = None, + pad_id: Optional[int] = None, + bad_words: Optional[List[pb2.TokenSequence]] = None, + stop_words: Optional[List[pb2.TokenSequence]] = None, + guided_decoding: Optional[pb2.GuidedDecodingParams] = None, + embedding_bias: Optional[List[float]] = None, +) -> SamplingParams: + """Convert protobuf configuration to TensorRT-LLM SamplingParams. + + Args: + proto_config: Protobuf SamplingConfig message + output_config: Protobuf OutputConfig message + max_tokens: Maximum tokens to generate + end_id: End-of-sequence token ID + pad_id: Padding token ID + bad_words: Bad word token sequences + stop_words: Stop word token sequences + guided_decoding: Guided decoding parameters + embedding_bias: Embedding bias tensor + + Returns: + TensorRT-LLM SamplingParams with detokenize=False + """ + # Build kwargs for SamplingParams + # KEY OPTIMIZATION: detokenize=False skips Python detokenization! 
+ kwargs = { + "max_tokens": max_tokens, + "detokenize": False, + } + + # Beam search / sampling + if proto_config.beam_width > 1: + kwargs["beam_width"] = proto_config.beam_width + if proto_config.num_return_sequences > 0: + kwargs["n"] = proto_config.num_return_sequences + + # Temperature and sampling parameters (with sensible defaults as safety guard) + kwargs["temperature"] = ( + proto_config.temperature if proto_config.HasField("temperature") else 1.0 + ) + kwargs["top_p"] = proto_config.top_p if proto_config.HasField("top_p") else 1.0 + if proto_config.HasField("top_k"): + kwargs["top_k"] = proto_config.top_k + if proto_config.HasField("min_p"): + kwargs["min_p"] = proto_config.min_p + + # Top-P decay parameters + if proto_config.HasField("top_p_min"): + kwargs["top_p_min"] = proto_config.top_p_min + if proto_config.HasField("top_p_reset_ids"): + kwargs["top_p_reset_ids"] = proto_config.top_p_reset_ids + if proto_config.HasField("top_p_decay"): + kwargs["top_p_decay"] = proto_config.top_p_decay + + # Seed for reproducibility + if proto_config.HasField("seed"): + kwargs["random_seed"] = proto_config.seed + + # Min/max tokens + if proto_config.HasField("min_tokens"): + kwargs["min_tokens"] = proto_config.min_tokens + + # Penalties (repetition_penalty defaults to 1.0 = no penalty) + kwargs["repetition_penalty"] = ( + proto_config.repetition_penalty if proto_config.HasField("repetition_penalty") else 1.0 + ) + if proto_config.HasField("presence_penalty"): + kwargs["presence_penalty"] = proto_config.presence_penalty + if proto_config.HasField("frequency_penalty"): + kwargs["frequency_penalty"] = proto_config.frequency_penalty + + # Beam search parameters + if proto_config.HasField("beam_search_diversity_rate"): + kwargs["beam_search_diversity_rate"] = proto_config.beam_search_diversity_rate + if proto_config.HasField("length_penalty"): + kwargs["length_penalty"] = proto_config.length_penalty + if proto_config.HasField("early_stopping"): + kwargs["early_stopping"] = proto_config.early_stopping + + # N-gram blocking + if proto_config.HasField("no_repeat_ngram_size"): + kwargs["no_repeat_ngram_size"] = proto_config.no_repeat_ngram_size + + # End/pad tokens + if end_id is not None: + kwargs["end_id"] = end_id + if end_id == -1: + kwargs["ignore_eos"] = True + if pad_id is not None: + kwargs["pad_id"] = pad_id + + # Output configuration - logprobs + if output_config.HasField("logprobs"): + kwargs["logprobs"] = output_config.logprobs + if output_config.HasField("prompt_logprobs"): + kwargs["prompt_logprobs"] = output_config.prompt_logprobs + if output_config.return_context_logits: + kwargs["return_context_logits"] = True + if output_config.return_generation_logits: + kwargs["return_generation_logits"] = True + if output_config.exclude_input_from_output: + kwargs["exclude_input_from_output"] = True + + # Stop sequences (as token ID lists) + if stop_words: + kwargs["stop_words"] = [list(seq.token_ids) for seq in stop_words] + if bad_words: + kwargs["bad_words"] = [list(seq.token_ids) for seq in bad_words] + + # Embedding bias + if embedding_bias: + kwargs["embedding_bias"] = list(embedding_bias) + + # Guided decoding + if guided_decoding and guided_decoding.guide: + guide_type = guided_decoding.guide_type + guide_content = guided_decoding.guide + + if guide_type == pb2.GuidedDecodingParams.GUIDE_TYPE_JSON: + # json_object=True for JSON validation without schema constraint + kwargs["guided_decoding_params"] = GuidedDecodingParams(json_object=True) + elif guide_type == 
pb2.GuidedDecodingParams.GUIDE_TYPE_JSON_SCHEMA: + kwargs["guided_decoding_params"] = GuidedDecodingParams(json_schema=guide_content) + elif guide_type == pb2.GuidedDecodingParams.GUIDE_TYPE_REGEX: + kwargs["guided_decoding_params"] = GuidedDecodingParams(regex=guide_content) + elif guide_type == pb2.GuidedDecodingParams.GUIDE_TYPE_EBNF_GRAMMAR: + kwargs["guided_decoding_params"] = GuidedDecodingParams(grammar=guide_content) + + return SamplingParams(**kwargs) + + +def create_lora_request_from_proto( + proto_config: Optional[pb2.LoraConfig], +) -> Optional[LoRARequest]: + """Convert protobuf LoraConfig to TensorRT-LLM LoRARequest. + + Args: + proto_config: Protobuf LoraConfig message + + Returns: + LoRARequest or None + """ + if proto_config is None or proto_config.task_id == 0: + return None + + return LoRARequest( + lora_name=f"lora_{proto_config.task_id}", + lora_int_id=proto_config.task_id, + ) + + +def create_disaggregated_params_from_proto( + proto_config: Optional[pb2.DisaggregatedParams], +) -> Optional[DisaggregatedParams]: + """Convert protobuf DisaggregatedParams to TensorRT-LLM DisaggregatedParams. + + Args: + proto_config: Protobuf DisaggregatedParams message + + Returns: + DisaggregatedParams or None + """ + if proto_config is None: + return None + + request_type_map = { + pb2.DisaggregatedParams.REQUEST_TYPE_CONTEXT_AND_GENERATION: "context_and_generation", + pb2.DisaggregatedParams.REQUEST_TYPE_CONTEXT_ONLY: "context_only", + pb2.DisaggregatedParams.REQUEST_TYPE_GENERATION_ONLY: "generation_only", + } + + request_type = request_type_map.get(proto_config.request_type, "context_and_generation") + + params = DisaggregatedParams(request_type=request_type) + + if proto_config.ctx_request_id: + params.ctx_request_id = proto_config.ctx_request_id + + if proto_config.HasField("context_phase_params"): + ctx_params = proto_config.context_phase_params + params.first_gen_token_id = ctx_params.first_gen_token_id + if ctx_params.kv_cache_blocks: + params.kv_cache_blocks = ctx_params.kv_cache_blocks + + return params diff --git a/tensorrt_llm/grpc/grpc_servicer.py b/tensorrt_llm/grpc/grpc_servicer.py new file mode 100644 index 0000000000..4ad8addd80 --- /dev/null +++ b/tensorrt_llm/grpc/grpc_servicer.py @@ -0,0 +1,545 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""gRPC Servicer for TensorRT-LLM. + +Implements the TrtllmService gRPC service for high-performance communication +with external routers (e.g., sgl-router) using pre-tokenized input. +""" + +import asyncio +import time +from collections.abc import AsyncGenerator +from typing import List, Union + +import grpc + +from tensorrt_llm.executor.result import Logprob, TokenLogprobs +from tensorrt_llm.logger import logger + +from . 
import trtllm_service_pb2, trtllm_service_pb2_grpc +from .grpc_request_manager import ( + GrpcRequestManager, + create_disaggregated_params_from_proto, + create_lora_request_from_proto, + create_sampling_params_from_proto, +) + + +class TrtllmServiceServicer(trtllm_service_pb2_grpc.TrtllmServiceServicer): + """gRPC servicer implementing the TrtLlmEngine service. + + Handles RPCs: + - Generate: Streaming text generation + - Embed: Embeddings (for embedding models) + - HealthCheck: Health probe + - Abort: Cancel a request + - GetModelInfo: Model metadata + - GetServerInfo: Server state + """ + + def __init__(self, request_manager: GrpcRequestManager, model_path: str = ""): + """Initialize the servicer. + + Args: + request_manager: The GrpcRequestManager instance + model_path: Path to the model (for metadata) + """ + self.request_manager = request_manager + self.model_path = model_path + self._start_time = time.time() + logger.info("TrtllmServiceServicer initialized") + + async def Generate( + self, + request: trtllm_service_pb2.GenerateRequest, + context: grpc.aio.ServicerContext, + ) -> AsyncGenerator[trtllm_service_pb2.GenerateResponse, None]: + """Handle streaming generation requests. + + Args: + request: The GenerateRequest protobuf + context: gRPC context + + Yields: + GenerateResponse protobuf messages (streaming) + """ + request_id = request.request_id + logger.info(f"Generate request {request_id} received") + + try: + # Extract tokenized input (required) + if not request.HasField("tokenized"): + yield self._error_response( + request_id, + "Missing tokenized input", + "INVALID_REQUEST", + 400, + ) + return + + prompt_token_ids = list(request.tokenized.input_token_ids) + + # Build sampling params with detokenize=False (key optimization!) + sampling_params = create_sampling_params_from_proto( + proto_config=request.sampling_config, + output_config=request.output_config, + max_tokens=request.max_tokens, + end_id=request.end_id if request.HasField("end_id") else None, + pad_id=request.pad_id if request.HasField("pad_id") else None, + bad_words=list(request.bad_words) if request.bad_words else None, + stop_words=list(request.stop_words) if request.stop_words else None, + guided_decoding=request.guided_decoding + if request.HasField("guided_decoding") + else None, + embedding_bias=list(request.embedding_bias) if request.embedding_bias else None, + ) + + # Build LoRA request if present + lora_request = create_lora_request_from_proto( + request.lora_config if request.HasField("lora_config") else None + ) + + # Build disaggregated params if present + disaggregated_params = create_disaggregated_params_from_proto( + request.disaggregated_params if request.HasField("disaggregated_params") else None + ) + + # Track tokens sent per sequence index to avoid duplicates + # TRT-LLM's token_ids_diff doesn't clear between iterations for n>1 + sent_token_counts: dict[int, int] = {} + + # Submit to request manager and stream outputs + # The request manager now yields GenerationResult objects + async for gen_result in self.request_manager.generate( + request_id=request_id, + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + streaming=request.streaming, + lora_request=lora_request, + disaggregated_params=disaggregated_params, + ): + # Check if client disconnected + if context.cancelled(): + logger.info(f"Client disconnected for {request_id}") + await self.request_manager.abort(request_id) + return + + # Convert GenerationResult to protobuf response + if request.streaming: + for 
chunk_response in self._chunk_responses( + request_id, gen_result, prompt_token_ids, sent_token_counts + ): + yield chunk_response + + # Send complete responses when finished (one per sequence for n>1) + if gen_result.finished: + for complete_response in self._complete_responses( + request_id, gen_result, prompt_token_ids + ): + yield complete_response + + except asyncio.CancelledError: + logger.info(f"Request {request_id} cancelled") + await self.request_manager.abort(request_id) + raise + except Exception as e: + logger.error(f"Error in Generate for {request_id}: {e}") + yield self._error_response( + request_id, + str(e), + "INTERNAL_ERROR", + 500, + ) + + async def Embed( + self, + request: trtllm_service_pb2.EmbedRequest, + context: grpc.aio.ServicerContext, + ) -> trtllm_service_pb2.EmbedResponse: + """Handle embedding requests. + + Args: + request: The EmbedRequest protobuf + context: gRPC context + + Returns: + EmbedResponse protobuf + """ + logger.warning("Embed RPC not yet implemented") + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Embed RPC not yet implemented") + return trtllm_service_pb2.EmbedResponse( + request_id=request.request_id, + embedding=[], + prompt_tokens=0, + ) + + async def HealthCheck( + self, + request: trtllm_service_pb2.HealthCheckRequest, + context: grpc.aio.ServicerContext, + ) -> trtllm_service_pb2.HealthCheckResponse: + """Handle health check requests. + + Args: + request: The HealthCheckRequest protobuf + context: gRPC context + + Returns: + HealthCheckResponse protobuf + """ + is_healthy, message = await self.request_manager.health_check() + logger.info(f"HealthCheck: healthy={is_healthy}, message={message}") + + return trtllm_service_pb2.HealthCheckResponse( + status=message, + ) + + async def Abort( + self, + request: trtllm_service_pb2.AbortRequest, + context: grpc.aio.ServicerContext, + ) -> trtllm_service_pb2.AbortResponse: + """Handle abort requests. + + Args: + request: The AbortRequest protobuf + context: gRPC context + + Returns: + AbortResponse protobuf + """ + request_id = request.request_id + logger.info(f"Abort request for {request_id}") + + success = await self.request_manager.abort(request_id) + + return trtllm_service_pb2.AbortResponse( + success=success, + message=f"Request {request_id} {'aborted' if success else 'not found'}", + ) + + async def GetModelInfo( + self, + request: trtllm_service_pb2.GetModelInfoRequest, + context: grpc.aio.ServicerContext, + ) -> trtllm_service_pb2.GetModelInfoResponse: + """Handle model info requests. + + Args: + request: The GetModelInfoRequest protobuf + context: gRPC context + + Returns: + GetModelInfoResponse protobuf + """ + model_config = self.request_manager.get_model_config() + + return trtllm_service_pb2.GetModelInfoResponse( + model_id=self.model_path or model_config.get("model_path", ""), + max_input_len=model_config.get("max_context_length", 0), + max_seq_len=model_config.get("max_seq_len", 0) + or model_config.get("max_context_length", 0), + vocab_size=model_config.get("vocab_size", 0), + ) + + async def GetServerInfo( + self, + request: trtllm_service_pb2.GetServerInfoRequest, + context: grpc.aio.ServicerContext, + ) -> trtllm_service_pb2.GetServerInfoResponse: + """Handle server info requests. 
+ + Args: + request: The GetServerInfoRequest protobuf + context: gRPC context + + Returns: + GetServerInfoResponse protobuf + """ + try: + import tensorrt_llm + + version = getattr(tensorrt_llm, "__version__", "unknown") + except Exception: + version = "unknown" + + # Try to get parallelism info from LLM args + tp_size = 1 + pp_size = 1 + world_size = 1 + + try: + llm = self.request_manager.llm + if hasattr(llm, "args") and llm.args is not None: + args = llm.args + if hasattr(args, "tensor_parallel_size") and args.tensor_parallel_size: + tp_size = args.tensor_parallel_size + if hasattr(args, "pipeline_parallel_size") and args.pipeline_parallel_size: + pp_size = args.pipeline_parallel_size + world_size = tp_size * pp_size + except Exception as e: + logger.debug(f"Could not get parallelism info: {e}") + + return trtllm_service_pb2.GetServerInfoResponse( + version=version, + backend="tensorrt-llm", + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + context_parallel_size=1, # Context parallelism is separate from TP/PP + world_size=world_size, + ) + + # ========== Helper methods ========== + + def _convert_logprobs_to_proto( + self, + token_ids: List[int], + logprobs: Union[TokenLogprobs, List[float], None], + ) -> List[trtllm_service_pb2.TokenLogprob]: + """Convert TRT-LLM logprobs to protobuf TokenLogprob messages. + + Handles both formats: + - List[float]: Simple logprobs (one per token) + - TokenLogprobs (list[dict[int, Logprob]]): Top-k logprobs per position + + Args: + token_ids: The sampled token IDs + logprobs: Logprobs from TRT-LLM (can be List[float] or TokenLogprobs) + + Returns: + List of TokenLogprob proto messages + """ + if not logprobs or not token_ids: + return [] + + result = [] + for i, token_id in enumerate(token_ids): + if i >= len(logprobs): + break + + lp = logprobs[i] + + if isinstance(lp, dict): + # TokenLogprobs format: dict[int, Logprob] + # Each entry maps token_id -> Logprob(logprob, rank) + token_logprob = trtllm_service_pb2.TokenLogprob( + token_id=token_id, + logprob=lp[token_id].logprob if token_id in lp else 0.0, + ) + # Add top logprobs (all entries in the dict) + for tid, logprob_obj in lp.items(): + if isinstance(logprob_obj, Logprob): + token_logprob.top_logprobs.append( + trtllm_service_pb2.TopLogprob( + token_id=tid, + logprob=logprob_obj.logprob, + ) + ) + result.append(token_logprob) + elif isinstance(lp, (int, float)): + # Simple float logprob + token_logprob = trtllm_service_pb2.TokenLogprob( + token_id=token_id, + logprob=float(lp), + ) + result.append(token_logprob) + + return result + + def _chunk_responses( + self, + request_id: str, + gen_result, + prompt_token_ids: list, + sent_token_counts: dict[int, int], + ) -> List[trtllm_service_pb2.GenerateResponse]: + """Build streaming chunk responses from GenerationResult. + + Uses cumulative token_ids and tracks sent position to compute true deltas. + TRT-LLM's token_ids_diff doesn't clear between iterations for n>1, so we + compute deltas ourselves. 
+ + Args: + request_id: The request ID + gen_result: TensorRT-LLM GenerationResult + prompt_token_ids: Original prompt tokens + sent_token_counts: Dict tracking tokens already sent per sequence index + + Returns: + List of GenerateResponse with chunk field set (one per output) + """ + responses = [] + cached_tokens = gen_result.cached_tokens if hasattr(gen_result, "cached_tokens") else 0 + + if not gen_result.outputs: + # No outputs yet, return empty chunk + responses.append( + trtllm_service_pb2.GenerateResponse( + request_id=request_id, + chunk=trtllm_service_pb2.GenerateStreamChunk( + token_ids=[], + prompt_tokens=len(prompt_token_ids), + completion_tokens=0, + cached_tokens=cached_tokens, + ), + ) + ) + return responses + + # Process all outputs (for n>1 support) + for completion in gen_result.outputs: + index = completion.index + # Use cumulative token_ids and compute delta ourselves + # because token_ids_diff doesn't clear between iterations for n>1 + all_tokens = list(completion.token_ids) if completion.token_ids else [] + sent_count = sent_token_counts.get(index, 0) + delta_tokens = all_tokens[sent_count:] + + # Skip if no new tokens for this sequence + if not delta_tokens: + continue + + # Update sent count + sent_token_counts[index] = len(all_tokens) + + chunk = trtllm_service_pb2.GenerateStreamChunk( + token_ids=delta_tokens, + sequence_index=completion.index, + prompt_tokens=len(prompt_token_ids), + completion_tokens=len(completion.token_ids) if completion.token_ids else 0, + cached_tokens=cached_tokens, + ) + + # Add logprobs if available + # Note: We compute delta logprobs ourselves since logprobs_diff has same issue as token_ids_diff + if completion.logprobs: + all_logprobs = completion.logprobs + delta_logprobs = all_logprobs[sent_count:] if sent_count < len(all_logprobs) else [] + proto_logprobs = self._convert_logprobs_to_proto(delta_tokens, delta_logprobs) + chunk.logprobs.extend(proto_logprobs) + + responses.append( + trtllm_service_pb2.GenerateResponse( + request_id=request_id, + chunk=chunk, + ) + ) + + return responses + + def _complete_responses( + self, + request_id: str, + gen_result, + prompt_token_ids: list, + ) -> List[trtllm_service_pb2.GenerateResponse]: + """Build final completion responses from GenerationResult. + + For n>1, returns one response per output sequence. 
+ + Args: + request_id: The request ID + gen_result: TensorRT-LLM GenerationResult (finished=True) + prompt_token_ids: Original prompt tokens + + Returns: + List of GenerateResponse with complete field set (one per output) + """ + responses = [] + cached_tokens = gen_result.cached_tokens if hasattr(gen_result, "cached_tokens") else 0 + + if not gen_result.outputs: + # No outputs, return error response + responses.append( + trtllm_service_pb2.GenerateResponse( + request_id=request_id, + complete=trtllm_service_pb2.GenerateComplete( + output_token_ids=[], + finish_reason="error", + prompt_tokens=len(prompt_token_ids), + completion_tokens=0, + cached_tokens=0, + ), + ) + ) + return responses + + # Process all outputs (for n>1 support) + for completion in gen_result.outputs: + output_tokens = list(completion.token_ids) if completion.token_ids else [] + + complete = trtllm_service_pb2.GenerateComplete( + output_token_ids=output_tokens, + sequence_index=completion.index, + finish_reason=completion.finish_reason or "stop", + prompt_tokens=len(prompt_token_ids), + completion_tokens=len(output_tokens), + cached_tokens=cached_tokens, + ) + + # Add stop reason if available + if hasattr(completion, "stop_reason") and completion.stop_reason: + complete.stop_reason = str(completion.stop_reason) + + # Add generation logprobs if available + if completion.logprobs: + proto_logprobs = self._convert_logprobs_to_proto(output_tokens, completion.logprobs) + complete.logprobs.extend(proto_logprobs) + + # Add prompt logprobs if available + if hasattr(completion, "prompt_logprobs") and completion.prompt_logprobs: + # For prompt logprobs, we use the prompt_token_ids + proto_prompt_logprobs = self._convert_logprobs_to_proto( + prompt_token_ids, completion.prompt_logprobs + ) + complete.prompt_logprobs.extend(proto_prompt_logprobs) + + responses.append( + trtllm_service_pb2.GenerateResponse( + request_id=request_id, + complete=complete, + ) + ) + + return responses + + def _error_response( + self, + request_id: str, + message: str, + error_type: str, + code: int, + ) -> trtllm_service_pb2.GenerateResponse: + """Build an error response. + + Args: + request_id: The request ID + message: Error message + error_type: Error type string + code: Error code + + Returns: + GenerateResponse with error field set + """ + return trtllm_service_pb2.GenerateResponse( + request_id=request_id, + error=trtllm_service_pb2.GenerateError( + message=message, + type=error_type, + code=code, + ), + ) diff --git a/tensorrt_llm/grpc/trtllm_service.proto b/tensorrt_llm/grpc/trtllm_service.proto new file mode 100644 index 0000000000..427226e288 --- /dev/null +++ b/tensorrt_llm/grpc/trtllm_service.proto @@ -0,0 +1,511 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +syntax = "proto3"; +package trtllm; + +// TensorRT-LLM gRPC Service +// +// This service provides high-performance inference for LLMs using TensorRT-LLM. +// It accepts pre-tokenized requests and returns raw token IDs, enabling efficient +// binary communication with external routers (e.g., sgl-router). +// +// Key Design Principles: +// - Token IDs only: No text in requests or responses (router handles tokenization) +// - Streaming delta mode: Chunks contain only new tokens since last chunk +// - Full feature support: All TensorRT-LLM capabilities exposed + +service TrtllmService { + // Generate tokens from pre-tokenized input + // Returns a stream of responses containing token IDs + rpc Generate(GenerateRequest) returns (stream GenerateResponse); + + // Generate embeddings from pre-tokenized input (for embedding models) + rpc Embed(EmbedRequest) returns (EmbedResponse); + + // Health check endpoint + rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse); + + // Abort a running generation request + rpc Abort(AbortRequest) returns (AbortResponse); + + // Get model information (vocab size, max lengths, etc.) + rpc GetModelInfo(GetModelInfoRequest) returns (GetModelInfoResponse); + + // Get server information (version, parallelism, etc.) + rpc GetServerInfo(GetServerInfoRequest) returns (GetServerInfoResponse); +} + +// ============================================================================ +// Generate Request +// ============================================================================ + +message GenerateRequest { + // Unique request identifier (assigned by client/router) + string request_id = 1; + + // Pre-tokenized input (REQUIRED - router tokenizes) + TokenizedInput tokenized = 2; + + // Sampling configuration + SamplingConfig sampling_config = 3; + + // Output configuration + OutputConfig output_config = 4; + + // Maximum tokens to generate (REQUIRED) + uint32 max_tokens = 5; + + // Enable streaming mode + bool streaming = 6; + + // End-of-sequence token ID (optional, set to -1 to ignore EOS) + optional int32 end_id = 7; + + // Padding token ID + optional int32 pad_id = 8; + + // Bad word token sequences (generation redirects if these are produced) + repeated TokenSequence bad_words = 9; + + // Stop word token sequences (generation stops when these are produced) + repeated TokenSequence stop_words = 10; + + // Guided decoding parameters (JSON schema, regex, grammar) + optional GuidedDecodingParams guided_decoding = 11; + + // Embedding bias tensor (vocab_size floats, optional) + repeated float embedding_bias = 12; + + // LoRA adapter configuration + optional LoraConfig lora_config = 13; + + // Prompt tuning configuration + optional PromptTuningConfig prompt_tuning_config = 14; + + // Multimodal input (for VLMs) + optional MultimodalInput multimodal_input = 15; + + // KV cache retention configuration + optional KvCacheRetentionConfig kv_cache_retention = 16; + + // Disaggregated inference parameters + optional DisaggregatedParams disaggregated_params = 17; + + // Lookahead decoding configuration + optional LookaheadConfig lookahead_config = 18; + + // Cache salt ID for cache hashing + optional int64 cache_salt_id = 19; + + // Request arrival time (unix timestamp for metrics) + optional double arrival_time = 20; +} + +// Tokenized input from router +message TokenizedInput { + // Original text (for debugging/logging only, not used for generation) + string original_text = 1; + + // Pre-tokenized input token IDs (REQUIRED) + repeated uint32 input_token_ids = 2; + + // 
Query token IDs for VLM star attention (optional) + repeated uint32 query_token_ids = 3; +} + +// Sequence of token IDs (for stop/bad words) +message TokenSequence { + repeated uint32 token_ids = 1; +} + +// ============================================================================ +// Sampling Configuration +// Maps to tensorrt_llm.bindings.executor.SamplingConfig +// ============================================================================ + +message SamplingConfig { + // Beam width (1 for sampling, >1 for beam search) + int32 beam_width = 1; + + // Number of sequences to return + uint32 num_return_sequences = 2; + + // Top-K sampling (0 = disabled, considers all tokens) + optional int32 top_k = 3; + + // Top-P (nucleus) sampling threshold + optional float top_p = 4; + + // Top-P minimum threshold for decay + optional float top_p_min = 5; + + // Top-P reset token IDs + optional int32 top_p_reset_ids = 6; + + // Top-P decay factor + optional float top_p_decay = 7; + + // Random seed for reproducibility + optional uint64 seed = 8; + + // Temperature for sampling (0 = greedy, higher = more random) + optional float temperature = 9; + + // Minimum tokens to generate before stopping + optional uint32 min_tokens = 10; + + // Beam search diversity rate + optional float beam_search_diversity_rate = 11; + + // Repetition penalty (>1 discourages, <1 encourages repetition) + optional float repetition_penalty = 12; + + // Presence penalty (penalizes tokens that have appeared) + optional float presence_penalty = 13; + + // Frequency penalty (penalizes based on frequency of appearance) + optional float frequency_penalty = 14; + + // Number of prompt tokens to ignore for penalties + optional int32 prompt_ignore_length = 15; + + // Length penalty for beam search + optional float length_penalty = 16; + + // Early stopping for beam search + optional int32 early_stopping = 17; + + // No repeat n-gram size + optional int32 no_repeat_ngram_size = 18; + + // Min-P sampling threshold + optional float min_p = 19; + + // Variable beam width array for beam search + repeated int32 beam_width_array = 20; +} + +// ============================================================================ +// Output Configuration +// Maps to tensorrt_llm.bindings.executor.OutputConfig +// ============================================================================ + +message OutputConfig { + // Number of top log probabilities to return per output token + optional int32 logprobs = 1; + + // Number of top log probabilities to return per prompt token + optional int32 prompt_logprobs = 2; + + // Return context logits tensor (large, use with caution) + bool return_context_logits = 3; + + // Return generation logits tensor (large, use with caution) + bool return_generation_logits = 4; + + // Exclude input tokens from output (set to true to enable; TRT-LLM defaults to true internally) + bool exclude_input_from_output = 5; + + // Return encoder output (for encoder-decoder models) + bool return_encoder_output = 6; + + // Return performance metrics + bool return_perf_metrics = 7; +} + +// ============================================================================ +// Guided Decoding +// ============================================================================ + +message GuidedDecodingParams { + // Guide type enumeration + enum GuideType { + GUIDE_TYPE_UNSPECIFIED = 0; + GUIDE_TYPE_JSON = 1; // JSON format (any valid JSON) + GUIDE_TYPE_JSON_SCHEMA = 2; // JSON with schema constraint + GUIDE_TYPE_REGEX = 3; // Regular expression 
constraint + GUIDE_TYPE_EBNF_GRAMMAR = 4; // EBNF grammar constraint + GUIDE_TYPE_STRUCTURAL_TAG = 5; // Structural tag (xgrammar backend) + } + + GuideType guide_type = 1; + + // Guide content (schema string, regex pattern, or grammar definition) + string guide = 2; +} + +// ============================================================================ +// LoRA Configuration +// ============================================================================ + +message LoraConfig { + // LoRA task/adapter ID + int64 task_id = 1; + + // LoRA weights (serialized tensor, optional if already cached) + optional bytes weights = 2; + + // LoRA config as JSON string + optional string config_json = 3; +} + +// ============================================================================ +// Prompt Tuning Configuration +// ============================================================================ + +message PromptTuningConfig { + // Embedding table (serialized tensor) + bytes embedding_table = 1; +} + +// ============================================================================ +// Multimodal Input (for Vision-Language Models) +// ============================================================================ + +message MultimodalInput { + // Multimodal content hashes for caching + repeated int64 multimodal_hashes = 1; + + // Positions in input where multimodal content is inserted + repeated int32 multimodal_positions = 2; + + // Lengths of multimodal content at each position + repeated int32 multimodal_lengths = 3; +} + +// ============================================================================ +// KV Cache Retention +// ============================================================================ + +message KvCacheRetentionConfig { + // Retention policy name + string policy = 1; + + // Additional configuration as JSON string + string config_json = 2; +} + +// ============================================================================ +// Disaggregated Inference +// ============================================================================ + +message DisaggregatedParams { + // Request type for disaggregated inference + enum RequestType { + REQUEST_TYPE_CONTEXT_AND_GENERATION = 0; // Normal full request + REQUEST_TYPE_CONTEXT_ONLY = 1; // Prefill only + REQUEST_TYPE_GENERATION_ONLY = 2; // Decode only + } + + RequestType request_type = 1; + + // Context request ID (links context and generation phases) + string ctx_request_id = 2; + + // Context phase parameters (for generation_only requests) + optional ContextPhaseParams context_phase_params = 3; +} + +message ContextPhaseParams { + // First generated token ID from context phase + uint32 first_gen_token_id = 1; + + // KV cache block pointers (serialized) + bytes kv_cache_blocks = 2; +} + +// ============================================================================ +// Lookahead Decoding +// ============================================================================ + +message LookaheadConfig { + int32 max_window_size = 1; + int32 max_ngram_size = 2; + int32 max_verification_set_size = 3; +} + +// ============================================================================ +// Generate Response +// ============================================================================ + +message GenerateResponse { + // Request ID echo + string request_id = 1; + + // Response type (oneof ensures exactly one is set) + oneof response { + GenerateStreamChunk chunk = 2; // Streaming delta + GenerateComplete complete = 3; // Final response + GenerateError error 
+  }
+}
+
+// Streaming chunk containing delta tokens (new tokens since last chunk)
+message GenerateStreamChunk {
+  // NEW token IDs only (delta from previous chunk)
+  repeated uint32 token_ids = 1;
+
+  // Beam/sequence index (for beam_width > 1 or n > 1)
+  uint32 sequence_index = 2;
+
+  // Token counts for usage tracking
+  uint32 prompt_tokens = 3;
+  uint32 completion_tokens = 4;
+  uint32 cached_tokens = 5;
+
+  // Log probabilities for this chunk's tokens (if requested)
+  repeated TokenLogprob logprobs = 6;
+}
+
+// Final/complete response with all output tokens
+message GenerateComplete {
+  // All output token IDs (cumulative, not delta)
+  repeated uint32 output_token_ids = 1;
+
+  // Beam/sequence index
+  uint32 sequence_index = 2;
+
+  // Finish reason: "stop", "length", "stop_word"
+  string finish_reason = 3;
+
+  // Specific stop reason (stop word/token that triggered stop)
+  optional string stop_reason = 4;
+
+  // Token counts for usage tracking
+  uint32 prompt_tokens = 5;
+  uint32 completion_tokens = 6;
+  uint32 cached_tokens = 7;
+
+  // Generation log probabilities (if requested)
+  repeated TokenLogprob logprobs = 8;
+
+  // Prompt log probabilities (if requested)
+  repeated TokenLogprob prompt_logprobs = 9;
+
+  // Performance metrics (if requested)
+  optional PerfMetrics perf_metrics = 10;
+
+  // Context logits (if requested) - serialized float tensor
+  optional bytes context_logits = 11;
+
+  // Generation logits (if requested) - serialized float tensor
+  optional bytes generation_logits = 12;
+}
+
+// Token log probability information
+message TokenLogprob {
+  uint32 token_id = 1;
+  float logprob = 2;
+
+  // Top alternative tokens and their log probabilities
+  repeated TopLogprob top_logprobs = 3;
+}
+
+message TopLogprob {
+  uint32 token_id = 1;
+  float logprob = 2;
+}
+
+// Performance metrics for request
+message PerfMetrics {
+  double arrival_time = 1;             // When request arrived
+  double first_scheduled_time = 2;     // When first scheduled
+  double first_token_time = 3;         // Time to first token (TTFT)
+  double last_token_time = 4;          // When last token generated
+  double kv_cache_transfer_start = 5;  // KV cache transfer start (disagg)
+  double kv_cache_transfer_end = 6;    // KV cache transfer end (disagg)
+  int64 kv_cache_size = 7;             // KV cache size in bytes
+}
+
+// Error response
+message GenerateError {
+  string message = 1;  // Human-readable error message
+  string type = 2;     // Error type (e.g., "InvalidRequest", "InternalError")
+  int32 code = 3;      // Error code
+}
+
+// ============================================================================
+// Embed Request/Response (for embedding models)
+// ============================================================================
+
+message EmbedRequest {
+  string request_id = 1;
+  TokenizedInput tokenized = 2;
+}
+
+message EmbedResponse {
+  string request_id = 1;
+  repeated float embedding = 2;  // Embedding vector
+  uint32 prompt_tokens = 3;
+}
+
+// ============================================================================
+// Health Check
+// ============================================================================
+
+message HealthCheckRequest {}
+
+message HealthCheckResponse {
+  string status = 1;  // "OK" or error description
+}
+
+// ============================================================================
+// Abort Request
+// ============================================================================
+
+message AbortRequest {
+  string request_id = 1;
+}
+
+message AbortResponse {
+  bool success = 1;
+  string message = 2;
+}
+
+// ============================================================================
+// Model Info
+// ============================================================================
+
+message GetModelInfoRequest {}
+
+message GetModelInfoResponse {
+  string model_id = 1;       // Model identifier/path
+  int32 max_input_len = 2;   // Maximum input length
+  int32 max_seq_len = 3;     // Maximum sequence length (input + output)
+  int32 max_batch_size = 4;  // Maximum batch size
+  int32 vocab_size = 5;      // Vocabulary size
+  int32 hidden_size = 6;     // Hidden dimension
+  int32 num_layers = 7;      // Number of transformer layers
+  int32 num_heads = 8;       // Number of attention heads
+
+  // Supported features
+  repeated string supported_features = 9;  // e.g., "lora", "guided_decoding"
+}
+
+// ============================================================================
+// Server Info
+// ============================================================================
+
+message GetServerInfoRequest {}
+
+message GetServerInfoResponse {
+  string version = 1;                // TensorRT-LLM version
+  string backend = 2;                // "tensorrt" or "pytorch"
+  int32 tensor_parallel_size = 3;    // TP size
+  int32 pipeline_parallel_size = 4;  // PP size
+  int32 context_parallel_size = 5;   // CP size
+  int32 world_size = 6;              // Total world size
+}
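For orientation, the sketch below shows how an external router might drive these messages from Python. It is not part of the patch: the RPC name `Generate`, the port, and the token IDs are placeholders (the authoritative service definition with the real RPC names appears earlier in the proto file). What it illustrates is the streaming contract documented above: each GenerateStreamChunk carries only new (delta) token IDs, the final GenerateComplete carries the cumulative list, and the `response` oneof guarantees exactly one arm per message.

    import grpc

    from tensorrt_llm.grpc import trtllm_service_pb2 as pb2
    from tensorrt_llm.grpc import trtllm_service_pb2_grpc as pb2_grpc


    def stream_generate(host="localhost", port=50051):
        # Placeholder address; point this at the running gRPC server.
        channel = grpc.insecure_channel(f"{host}:{port}")
        stub = pb2_grpc.TrtllmServiceStub(channel)

        request = pb2.GenerateRequest(
            request_id="example-1",
            tokenized=pb2.TokenizedInput(input_token_ids=[1, 2, 3, 4, 5]),
            sampling_config=pb2.SamplingConfig(temperature=0.7, top_p=0.9),
            max_tokens=64,
            streaming=True,
        )

        output_ids = []
        # Assumes a server-streaming RPC named `Generate`; check the service
        # block in the proto for the actual name and signature.
        for response in stub.Generate(request):
            arm = response.WhichOneof("response")
            if arm == "chunk":
                # Chunks carry only the NEW token IDs since the last chunk.
                output_ids.extend(response.chunk.token_ids)
            elif arm == "complete":
                # The final message carries the cumulative token list.
                output_ids = list(response.complete.output_token_ids)
                print("finish_reason:", response.complete.finish_reason)
            elif arm == "error":
                raise RuntimeError(
                    f"{response.error.type}: {response.error.message}")
        return output_ids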
diff --git a/tests/unittest/llmapi/test_grpc.py b/tests/unittest/llmapi/test_grpc.py
new file mode 100644
index 0000000000..d766900cfd
--- /dev/null
+++ b/tests/unittest/llmapi/test_grpc.py
@@ -0,0 +1,315 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests for gRPC server components."""
+
+import pytest
+
+from tensorrt_llm.grpc import trtllm_service_pb2 as pb2
+from tensorrt_llm.grpc.grpc_request_manager import (
+    create_disaggregated_params_from_proto,
+    create_lora_request_from_proto,
+    create_sampling_params_from_proto,
+)
+
+pytestmark = pytest.mark.threadleak(enabled=False)
+
+
+class TestSamplingParamsConversion:
+    """Tests for proto to SamplingParams conversion."""
+
+    def test_basic_sampling_config(self):
+        """Test basic sampling config conversion."""
+        proto_config = pb2.SamplingConfig(
+            beam_width=1,
+            num_return_sequences=1,
+            temperature=0.7,
+            top_k=50,
+            top_p=0.9,
+        )
+        output_config = pb2.OutputConfig()
+
+        params = create_sampling_params_from_proto(
+            proto_config=proto_config,
+            output_config=output_config,
+            max_tokens=100,
+        )
+
+        assert params.max_tokens == 100
+        assert params.temperature == 0.7
+        assert params.top_k == 50
+        assert params.top_p == 0.9
+
+    def test_beam_search_config(self):
+        """Test beam search configuration."""
+        proto_config = pb2.SamplingConfig(
+            beam_width=4,
+            num_return_sequences=2,
+            length_penalty=1.2,
+            early_stopping=1,
+        )
+        output_config = pb2.OutputConfig()
+
+        params = create_sampling_params_from_proto(
+            proto_config=proto_config,
+            output_config=output_config,
+            max_tokens=50,
+        )
+
+        assert params.beam_width == 4
+        assert params.n == 2
+        assert params.length_penalty == 1.2
+
+    def test_penalties_config(self):
+        """Test penalty parameters conversion."""
+        proto_config = pb2.SamplingConfig(
+            repetition_penalty=1.1,
+            presence_penalty=0.5,
+            frequency_penalty=0.3,
+        )
+        output_config = pb2.OutputConfig()
+
+        params = create_sampling_params_from_proto(
+            proto_config=proto_config,
+            output_config=output_config,
+            max_tokens=100,
+        )
+
+        assert params.repetition_penalty == 1.1
+        assert params.presence_penalty == 0.5
+        assert params.frequency_penalty == 0.3
+
+    def test_logprobs_config(self):
+        """Test logprobs configuration."""
+        proto_config = pb2.SamplingConfig()
+        output_config = pb2.OutputConfig(
+            logprobs=5,
+            prompt_logprobs=3,
+        )
+
+        params = create_sampling_params_from_proto(
+            proto_config=proto_config,
+            output_config=output_config,
+            max_tokens=100,
+        )
+
+        assert params.logprobs == 5
+        assert params.prompt_logprobs == 3
+
+    def test_guided_decoding_json_schema(self):
+        """Test guided decoding with JSON schema."""
+        proto_config = pb2.SamplingConfig()
+        output_config = pb2.OutputConfig()
+        guided_decoding = pb2.GuidedDecodingParams(
+            guide_type=pb2.GuidedDecodingParams.GUIDE_TYPE_JSON_SCHEMA,
+            guide='{"type": "object", "properties": {"name": {"type": "string"}}}',
+        )
+
+        params = create_sampling_params_from_proto(
+            proto_config=proto_config,
+            output_config=output_config,
+            max_tokens=100,
+            guided_decoding=guided_decoding,
+        )
+
+        assert params.guided_decoding_params is not None
+        assert params.guided_decoding_params.json_schema is not None
+
+    def test_guided_decoding_regex(self):
+        """Test guided decoding with regex."""
+        proto_config = pb2.SamplingConfig()
+        output_config = pb2.OutputConfig()
+        guided_decoding = pb2.GuidedDecodingParams(
+            guide_type=pb2.GuidedDecodingParams.GUIDE_TYPE_REGEX,
+            guide=r"\d{3}-\d{4}",
+        )
+
+        params = create_sampling_params_from_proto(
+            proto_config=proto_config,
+            output_config=output_config,
+            max_tokens=100,
+            guided_decoding=guided_decoding,
+        )
+
+        assert params.guided_decoding_params is not None
+        assert params.guided_decoding_params.regex is not None
+
+
+class TestLoraRequestConversion:
+    """Tests for proto to LoRARequest conversion."""
+
+    def test_basic_lora_config(self):
+        """Test basic LoRA config conversion."""
+        lora_config = pb2.LoraConfig(task_id=123)
+
+        request = create_lora_request_from_proto(lora_config)
+
+        assert request is not None
+        assert request.task_id == 123
+
+    def test_none_lora_config(self):
+        """Test None LoRA config returns None."""
+        request = create_lora_request_from_proto(None)
+        assert request is None
+
+
+class TestDisaggregatedParamsConversion:
+    """Tests for proto to DisaggregatedParams conversion."""
+
+    def test_context_only_request(self):
+        """Test context-only disaggregated request."""
+        proto_params = pb2.DisaggregatedParams(
+            request_type=pb2.DisaggregatedParams.REQUEST_TYPE_CONTEXT_ONLY,
+            ctx_request_id="ctx-123",
+        )
+
+        params = create_disaggregated_params_from_proto(proto_params)
+
+        assert params is not None
+        assert params.ctx_request_id == "ctx-123"
+
+    def test_generation_only_request(self):
+        """Test generation-only disaggregated request."""
+        proto_params = pb2.DisaggregatedParams(
+            request_type=pb2.DisaggregatedParams.REQUEST_TYPE_GENERATION_ONLY,
+            ctx_request_id="gen-456",
+        )
+
+        params = create_disaggregated_params_from_proto(proto_params)
+
+        assert params is not None
+
+    def test_none_params(self):
+        """Test None disaggregated params returns None."""
+        params = create_disaggregated_params_from_proto(None)
+        assert params is None
+
+
+class TestProtoMessages:
+    """Tests for proto message structure."""
+
+    def test_generate_request_structure(self):
+        """Test GenerateRequest message structure."""
+        request = pb2.GenerateRequest(
+            request_id="test-123",
+            tokenized=pb2.TokenizedInput(
+                input_token_ids=[1, 2, 3, 4, 5],
+                original_text="Hello world",
+            ),
+            sampling_config=pb2.SamplingConfig(temperature=0.8),
+            max_tokens=50,
+            streaming=True,
+        )
+
+        assert request.request_id == "test-123"
+        assert list(request.tokenized.input_token_ids) == [1, 2, 3, 4, 5]
+        assert request.tokenized.original_text == "Hello world"
+        assert request.sampling_config.temperature == 0.8
+        assert request.max_tokens == 50
+        assert request.streaming is True
+
+    def test_generate_response_chunk(self):
+        """Test GenerateResponse with chunk."""
+        response = pb2.GenerateResponse(
+            request_id="test-123",
+            chunk=pb2.GenerateStreamChunk(
+                token_ids=[10, 11, 12],
+                sequence_index=0,
+                prompt_tokens=5,
+                completion_tokens=3,
+            ),
+        )
+
+        assert response.request_id == "test-123"
+        assert list(response.chunk.token_ids) == [10, 11, 12]
+        assert response.chunk.prompt_tokens == 5
+        assert response.chunk.completion_tokens == 3
+
+    def test_generate_response_complete(self):
+        """Test GenerateResponse with complete."""
+        response = pb2.GenerateResponse(
+            request_id="test-123",
+            complete=pb2.GenerateComplete(
+                output_token_ids=[10, 11, 12, 13],
+                finish_reason="stop",
+                prompt_tokens=5,
+                completion_tokens=4,
+            ),
+        )
+
+        assert response.request_id == "test-123"
+        assert list(response.complete.output_token_ids) == [10, 11, 12, 13]
+        assert response.complete.finish_reason == "stop"
+
+    def test_health_check_messages(self):
+        """Test HealthCheck messages."""
+        _request = pb2.HealthCheckRequest()  # noqa: F841 - verify message construction
+        response = pb2.HealthCheckResponse(status="healthy")
+
+        assert response.status == "healthy"
+
+    def test_model_info_response(self):
+        """Test GetModelInfoResponse message."""
+        response = pb2.GetModelInfoResponse(
+            model_id="meta-llama/Llama-2-7b",
+            max_input_len=4096,
+            max_seq_len=8192,
+            vocab_size=32000,
+        )
+
+        assert response.model_id == "meta-llama/Llama-2-7b"
== "meta-llama/Llama-2-7b" + assert response.max_input_len == 4096 + assert response.max_seq_len == 8192 + assert response.vocab_size == 32000 + + def test_server_info_response(self): + """Test GetServerInfoResponse message.""" + response = pb2.GetServerInfoResponse( + version="0.17.0", + backend="tensorrt-llm", + tensor_parallel_size=2, + pipeline_parallel_size=1, + world_size=2, + ) + + assert response.version == "0.17.0" + assert response.backend == "tensorrt-llm" + assert response.tensor_parallel_size == 2 + assert response.world_size == 2 + + def test_embed_messages(self): + """Test Embed request and response messages.""" + request = pb2.EmbedRequest( + request_id="embed-123", + tokenized=pb2.TokenizedInput(input_token_ids=[1, 2, 3]), + ) + response = pb2.EmbedResponse( + request_id="embed-123", + embedding=[0.1, 0.2, 0.3, 0.4], + prompt_tokens=3, + ) + + assert request.request_id == "embed-123" + assert response.request_id == "embed-123" + assert list(response.embedding) == [0.1, 0.2, 0.3, 0.4] + assert response.prompt_tokens == 3 + + def test_abort_messages(self): + """Test Abort request and response messages.""" + request = pb2.AbortRequest(request_id="abort-123") + response = pb2.AbortResponse(success=True, message="Request aborted") + + assert request.request_id == "abort-123" + assert response.success is True + assert response.message == "Request aborted"