[None][feat] Add gRPC server for high-performance external router integration (#11037)

Signed-off-by: Chang Su <chang.s.su@oracle.com>
Chang Su 2026-01-29 15:48:27 -08:00 committed by GitHub
parent e033929221
commit dbad94715b
GPG Key ID: B5690EEEBB952194
8 changed files with 2273 additions and 5 deletions

View File

@@ -18,9 +18,52 @@ from pathlib import Path
from typing import List
from setuptools import find_packages, setup
from setuptools.command.build_py import build_py
from setuptools.dist import Distribution
class BuildPyWithProtoCompile(build_py):
"""Custom build_py command that compiles protobuf files before building."""
def run(self):
self.compile_grpc_protos()
super().run()
def compile_grpc_protos(self):
"""Compile gRPC protobuf files if the proto file exists."""
grpc_dir = Path(__file__).parent / "tensorrt_llm" / "grpc"
proto_file = grpc_dir / "trtllm_service.proto"
compile_script = grpc_dir / "compile_protos.py"
if not proto_file.exists():
return
# Check if pb2 files need to be generated
pb2_file = grpc_dir / "trtllm_service_pb2.py"
pb2_grpc_file = grpc_dir / "trtllm_service_pb2_grpc.py"
# Regenerate if pb2 files don't exist or are older than proto file
needs_compile = (not pb2_file.exists() or not pb2_grpc_file.exists() or
pb2_file.stat().st_mtime < proto_file.stat().st_mtime)
if needs_compile and compile_script.exists():
import subprocess
import sys
print("Compiling gRPC protobuf files...")
try:
subprocess.run(
[sys.executable, str(compile_script)],
check=True,
cwd=str(grpc_dir.parent.parent),
)
print("gRPC protobuf compilation successful")
except subprocess.CalledProcessError as e:
print(f"Warning: Failed to compile gRPC protos: {e}")
except Exception as e:
print(f"Warning: gRPC proto compilation skipped: {e}")
def parse_requirements(filename: os.PathLike):
with open(filename) as f:
requirements = f.read().splitlines()
@@ -374,6 +417,7 @@ packages += find_packages(include=["triton_kernels", "triton_kernels.*"])
setup(
name='tensorrt_llm',
version=get_version(),
cmdclass={'build_py': BuildPyWithProtoCompile},
description=
('TensorRT LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and supports '
'state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs.'

View File

@@ -233,6 +233,124 @@ def launch_server(
asyncio.run(server(host, port, sockets=[s]))
def launch_grpc_server(host: str, port: int, llm_args: dict):
"""
Launch a gRPC server for TensorRT-LLM.
This provides a high-performance gRPC interface designed for external routers
(e.g., sgl-router) using pre-tokenized input and raw token ID output.
Args:
host: Host to bind to
port: Port to bind to
llm_args: Arguments for LLM initialization (from get_llm_args)
"""
import grpc
try:
from grpc_reflection.v1alpha import reflection
REFLECTION_AVAILABLE = True
except ImportError:
REFLECTION_AVAILABLE = False
from tensorrt_llm.grpc import trtllm_service_pb2, trtllm_service_pb2_grpc
from tensorrt_llm.grpc.grpc_request_manager import GrpcRequestManager
from tensorrt_llm.grpc.grpc_servicer import TrtllmServiceServicer
async def serve_grpc_async():
logger.info("Initializing TensorRT-LLM gRPC server...")
backend = llm_args.get("backend")
model_path = llm_args.get("model", "")
if backend == "pytorch":
llm_args.pop("build_config", None)
llm = PyTorchLLM(**llm_args)
elif backend == "_autodeploy":
from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
llm_args.pop("build_config", None)
llm = AutoDeployLLM(**llm_args)
elif backend == "tensorrt" or backend == "trt":
llm_args.pop("backend")
llm = LLM(**llm_args)
else:
raise click.BadParameter(
f"{backend} is not a known backend, check help for available options.",
param_hint="backend")
logger.info("Model loaded successfully")
# Create request manager
request_manager = GrpcRequestManager(llm)
# Create servicer
servicer = TrtllmServiceServicer(request_manager, model_path=model_path)
# Create gRPC server
server = grpc.aio.server(
options=[
("grpc.max_send_message_length", -1), # Unlimited
("grpc.max_receive_message_length", -1), # Unlimited
("grpc.keepalive_time_ms", 30000), # 30s keepalive
("grpc.keepalive_timeout_ms", 10000), # 10s timeout
], )
# Add servicer to server
trtllm_service_pb2_grpc.add_TrtllmServiceServicer_to_server(
servicer, server)
# Enable reflection for grpcurl and other tools
if REFLECTION_AVAILABLE:
service_names = (
trtllm_service_pb2.DESCRIPTOR.services_by_name["TrtllmService"].
full_name,
reflection.SERVICE_NAME,
)
reflection.enable_server_reflection(service_names, server)
logger.info("gRPC reflection enabled")
# Bind to address
address = f"{host}:{port}"
server.add_insecure_port(address)
# Start server
await server.start()
logger.info(f"TensorRT-LLM gRPC server started on {address}")
logger.info("Server is ready to accept requests")
# Handle shutdown signals
loop = asyncio.get_running_loop()
stop_event = asyncio.Event()
def signal_handler():
logger.info("Received shutdown signal")
stop_event.set()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
# Serve until shutdown signal
try:
await stop_event.wait()
except KeyboardInterrupt:
logger.info("Interrupted by user")
finally:
logger.info("Shutting down TensorRT-LLM gRPC server...")
# Stop gRPC server
await server.stop(grace=5.0)
logger.info("gRPC server stopped")
# Shutdown LLM
if hasattr(llm, "shutdown"):
llm.shutdown()
logger.info("LLM engine stopped")
logger.info("Shutdown complete")
asyncio.run(serve_grpc_async())
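A minimal sketch of driving launch_grpc_server directly, assuming the PyTorch backend and a placeholder checkpoint path; in practice the serve command builds a much richer llm_args dict via get_llm_args:
# Hedged sketch: programmatic launch with a hand-built llm_args dict.
llm_args = {
    "model": "/path/to/hf/checkpoint",  # placeholder path
    "backend": "pytorch",               # selects PyTorchLLM in the branch above
}
launch_grpc_server(host="0.0.0.0", port=50051, llm_args=llm_args)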
def launch_mm_encoder_server(
host: str,
port: int,
@@ -476,6 +594,12 @@ class ChoiceWithAlias(click.Choice):
"Specify a custom chat template. "
"Can be a file path or one-liner template string",
"prototype"))
@click.option(
"--grpc",
is_flag=True,
default=False,
help="Run gRPC server instead of OpenAI HTTP server. "
"gRPC server accepts pre-tokenized requests and returns raw token IDs.")
def serve(
model: str, tokenizer: Optional[str], custom_tokenizer: Optional[str],
host: str, port: int, log_level: str, backend: str, max_beam_width: int,
@@ -491,8 +615,9 @@ def serve(
fail_fast_on_attention_window_too_large: bool,
otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool,
disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str],
custom_module_dirs: list[Path], chat_template: Optional[str]):
"""Running an OpenAI API compatible server
custom_module_dirs: list[Path], chat_template: Optional[str],
grpc: bool):
"""Running an OpenAI API compatible server (or gRPC server with --grpc flag)
MODEL: model name | HF checkpoint path | TensorRT engine path
"""
@@ -569,9 +694,29 @@ def serve(
multimodal_server_config = MultimodalServerConfig(
media_io_kwargs=parsed_media_io_kwargs)
launch_server(host, port, llm_args, tool_parser, chat_template,
metadata_server_cfg, server_role, disagg_cluster_config,
multimodal_server_config)
if grpc:
# gRPC mode: launch gRPC server instead of OpenAI HTTP server
# Reject arguments the gRPC server does not support rather than silently ignoring them
unsupported_args = {
"tool_parser": tool_parser,
"chat_template": chat_template,
"metadata_server_config_file": metadata_server_config_file,
"server_role": server_role,
"disagg_cluster_config": disagg_cluster_config,
}
for name, value in unsupported_args.items():
if value is not None:
raise ValueError(
f"Argument '{name}' is not supported when running in gRPC mode. "
f"The gRPC server is designed for use with external routers that handle "
f"these features (e.g., tool parsing, chat templates).")
launch_grpc_server(host, port, llm_args)
else:
# Default: launch OpenAI HTTP server
launch_server(host, port, llm_args, tool_parser, chat_template,
metadata_server_cfg, server_role, disagg_cluster_config,
multimodal_server_config)
@click.command("mm_embedding_serve")

View File

@@ -0,0 +1,121 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""TensorRT-LLM gRPC module for high-performance communication with external routers.
This module provides a gRPC server interface that accepts pre-tokenized requests
and returns raw token IDs, enabling efficient binary communication with Rust-based
routers like sgl-router.
Key Features:
- Pre-tokenized input (no Python tokenization overhead)
- Raw token ID output (no Python detokenization overhead)
- Streaming support with delta tokens
- Full sampling parameter support
- Guided decoding (JSON schema, regex, grammar)
- LoRA and prompt tuning support
- Disaggregated inference support
Usage:
python -m tensorrt_llm.commands.serve /path/to/model \
--grpc \
--host 0.0.0.0 \
--port 50051
"""
from pathlib import Path
# Module directory for proto files
GRPC_MODULE_DIR = Path(__file__).parent
# Proto file path
PROTO_FILE = GRPC_MODULE_DIR / "trtllm_service.proto"
# Try to import generated protobuf modules
try:
from . import trtllm_service_pb2, trtllm_service_pb2_grpc
PROTOS_AVAILABLE = True
except ImportError:
PROTOS_AVAILABLE = False
trtllm_service_pb2 = None
trtllm_service_pb2_grpc = None
def compile_protos():
"""Compile protobuf files to generate Python modules.
Run this function if the generated *_pb2.py files are missing.
Alternatively, run: python tensorrt_llm/grpc/compile_protos.py
"""
from .compile_protos import main as compile_main
compile_main()
def ensure_protos_available():
"""Ensure protobuf modules are available, compiling if necessary."""
global PROTOS_AVAILABLE, trtllm_service_pb2, trtllm_service_pb2_grpc
if not PROTOS_AVAILABLE:
raise ImportError(
"gRPC protobuf modules are not available. "
"Please generate them by running: "
"python tensorrt_llm/grpc/compile_protos.py"
)
# Try to import request manager
try:
from .grpc_request_manager import (
GrpcRequestManager,
create_disaggregated_params_from_proto,
create_lora_request_from_proto,
create_sampling_params_from_proto,
)
REQUEST_MANAGER_AVAILABLE = True
except ImportError:
REQUEST_MANAGER_AVAILABLE = False
GrpcRequestManager = None
create_sampling_params_from_proto = None
create_lora_request_from_proto = None
create_disaggregated_params_from_proto = None
# Try to import servicer
try:
from .grpc_servicer import TrtllmServiceServicer
SERVICER_AVAILABLE = True
except ImportError:
SERVICER_AVAILABLE = False
TrtllmServiceServicer = None
__all__ = [
"GRPC_MODULE_DIR",
"PROTO_FILE",
"PROTOS_AVAILABLE",
"REQUEST_MANAGER_AVAILABLE",
"SERVICER_AVAILABLE",
"compile_protos",
"ensure_protos_available",
"trtllm_service_pb2",
"trtllm_service_pb2_grpc",
"GrpcRequestManager",
"TrtllmServiceServicer",
"create_sampling_params_from_proto",
"create_lora_request_from_proto",
"create_disaggregated_params_from_proto",
]
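A short usage sketch of the availability flags and helpers exported above (names as defined in this module):
# Hedged sketch: check availability before constructing the server components.
import tensorrt_llm.grpc as trtllm_grpc

trtllm_grpc.ensure_protos_available()  # raises ImportError with instructions if pb2 files are missing
if trtllm_grpc.REQUEST_MANAGER_AVAILABLE and trtllm_grpc.SERVICER_AVAILABLE:
    manager_cls = trtllm_grpc.GrpcRequestManager
    servicer_cls = trtllm_grpc.TrtllmServiceServicer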

View File

@@ -0,0 +1,167 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compile Protocol Buffer definitions for TensorRT-LLM gRPC server.
This script generates Python bindings from the trtllm_service.proto file.
Usage:
python -m tensorrt_llm.grpc.compile_protos
Or directly:
python compile_protos.py
Requirements:
pip install grpcio-tools
"""
import subprocess # nosec B404
import sys
from pathlib import Path
def compile_protos(proto_dir: Path = None, output_dir: Path = None) -> bool:
"""Compile proto files to Python.
Args:
proto_dir: Directory containing .proto files. Defaults to this script's directory.
output_dir: Directory for generated Python files. Defaults to proto_dir.
Returns:
True if compilation succeeded, False otherwise.
"""
if proto_dir is None:
proto_dir = Path(__file__).parent
if output_dir is None:
output_dir = proto_dir
proto_file = proto_dir / "trtllm_service.proto"
if not proto_file.exists():
print(f"Error: Proto file not found: {proto_file}")
return False
# Check for grpcio-tools
try:
from grpc_tools import protoc
except ImportError:
print("grpcio-tools not installed. Installing...")
subprocess.check_call([sys.executable, "-m", "pip", "install", "grpcio-tools"])
from grpc_tools import protoc
# Compile proto file
print(f"Compiling {proto_file}...")
result = protoc.main(
[
"grpc_tools.protoc",
f"--proto_path={proto_dir}",
f"--python_out={output_dir}",
f"--grpc_python_out={output_dir}",
str(proto_file),
]
)
if result != 0:
print(f"Error: protoc failed with code {result}")
return False
# Fix imports in generated files (grpc_tools generates absolute imports)
pb2_file = output_dir / "trtllm_service_pb2.py"
pb2_grpc_file = output_dir / "trtllm_service_pb2_grpc.py"
if pb2_grpc_file.exists():
content = pb2_grpc_file.read_text()
# Fix import to use relative import
content = content.replace(
"import trtllm_service_pb2 as trtllm__service__pb2",
"from . import trtllm_service_pb2 as trtllm__service__pb2",
)
pb2_grpc_file.write_text(content)
print(f"Fixed imports in {pb2_grpc_file}")
print("Generated files:")
print(f" - {pb2_file}")
print(f" - {pb2_grpc_file}")
print("Proto compilation successful!")
return True
def verify_generated_files(output_dir: Path = None) -> bool:
"""Verify that generated files can be imported.
Args:
output_dir: Directory containing generated files.
Returns:
True if files can be imported, False otherwise.
"""
if output_dir is None:
output_dir = Path(__file__).parent
pb2_file = output_dir / "trtllm_service_pb2.py"
pb2_grpc_file = output_dir / "trtllm_service_pb2_grpc.py"
if not pb2_file.exists() or not pb2_grpc_file.exists():
print("Generated files not found. Run compile_protos() first.")
return False
# Try to import
import importlib.util
try:
spec = importlib.util.spec_from_file_location("trtllm_service_pb2", pb2_file)
pb2_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(pb2_module)
print(f"Successfully imported {pb2_file.name}")
# Verify key message types exist
assert hasattr(pb2_module, "GenerateRequest")
assert hasattr(pb2_module, "GenerateResponse")
assert hasattr(pb2_module, "SamplingConfig")
assert hasattr(pb2_module, "OutputConfig")
print(" - GenerateRequest, GenerateResponse, SamplingConfig, OutputConfig OK")
except Exception as e:
print(f"Error importing {pb2_file.name}: {e}")
return False
print("Verification successful!")
return True
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Compile TensorRT-LLM gRPC protos")
parser.add_argument(
"--proto-dir", type=Path, default=None, help="Directory containing .proto files"
)
parser.add_argument(
"--output-dir", type=Path, default=None, help="Directory for generated Python files"
)
parser.add_argument(
"--verify", action="store_true", help="Verify generated files after compilation"
)
args = parser.parse_args()
success = compile_protos(args.proto_dir, args.output_dir)
if success and args.verify:
success = verify_generated_files(args.output_dir)
sys.exit(0 if success else 1)
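A minimal sketch of using these helpers programmatically instead of the CLI entry point, assuming the default in-tree layout:
# Hedged sketch: compile and verify from Python rather than the command line.
from pathlib import Path
from tensorrt_llm.grpc.compile_protos import compile_protos, verify_generated_files

grpc_dir = Path("tensorrt_llm") / "grpc"
if compile_protos(proto_dir=grpc_dir, output_dir=grpc_dir):
    verify_generated_files(output_dir=grpc_dir)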

View File

@@ -0,0 +1,420 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""gRPC Request Manager for TensorRT-LLM.
Manages request lifecycle for gRPC requests, converting between protobuf
and TensorRT-LLM types. Designed for high-performance communication with
external routers (e.g., sgl-router) using pre-tokenized input.
Key optimization: Sets detokenize=False in SamplingParams to skip
detokenization and return token IDs only.
"""
import asyncio
import traceback
from collections.abc import AsyncGenerator
from typing import Any, Dict, List, Optional, Tuple
from tensorrt_llm.disaggregated_params import DisaggregatedParams
from tensorrt_llm.executor.request import LoRARequest, PromptAdapterRequest
from tensorrt_llm.executor.result import GenerationResult
from tensorrt_llm.llmapi.llm_utils import KvCacheRetentionConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.sampling_params import GuidedDecodingParams, SamplingParams
from . import trtllm_service_pb2 as pb2
class GrpcRequestManager:
"""Manages gRPC request lifecycle for TensorRT-LLM.
Responsibilities:
- Convert protobuf requests to TensorRT-LLM types
- Set detokenize=False in SamplingParams (key optimization!)
- Submit requests to LLM.generate_async()
- Stream token IDs (not text) back to gRPC clients
- Handle abort/cancel operations
This is modeled after vLLM's GrpcRequestManager but adapted for TensorRT-LLM's
GenerationResult async iterator pattern.
"""
def __init__(self, llm: Any):
"""Initialize the request manager.
Args:
llm: The TensorRT-LLM LLM instance (tensorrt_llm.LLM or tensorrt_llm._tensorrt_engine.LLM)
"""
self.llm = llm
# Track active requests: request_id -> GenerationResult
self._rid_to_result: Dict[str, GenerationResult] = {}
logger.info("GrpcRequestManager initialized")
async def generate(
self,
request_id: str,
prompt_token_ids: List[int],
sampling_params: SamplingParams,
streaming: bool = True,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None,
disaggregated_params: Optional[DisaggregatedParams] = None,
) -> AsyncGenerator[GenerationResult, None]:
"""Submit a generation request and stream outputs.
Args:
request_id: Unique request identifier (for tracking/abort)
prompt_token_ids: Pre-tokenized input from Rust router
sampling_params: Sampling parameters (with detokenize=False!)
streaming: Whether to stream results
lora_request: Optional LoRA adapter request
prompt_adapter_request: Optional prompt adapter request
kv_cache_retention_config: KV cache retention config
disaggregated_params: Disaggregated inference params
Yields:
GenerationResult objects containing token IDs (text will be empty
because detokenize=False)
"""
try:
# Submit to LLM.generate_async which returns a GenerationResult
# that is an async iterator
gen_result = self.llm.generate_async(
{"prompt_token_ids": prompt_token_ids},
sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
streaming=streaming,
kv_cache_retention_config=kv_cache_retention_config,
disaggregated_params=disaggregated_params,
)
# Track the result for potential abort
self._rid_to_result[request_id] = gen_result
# Iterate over the async generator
# GenerationResult implements __aiter__ and __anext__
async for result in gen_result:
yield result
if result.finished:
break
except asyncio.CancelledError:
logger.info(f"Request {request_id} cancelled by client")
await self.abort(request_id)
raise
except Exception as e:
logger.error(f"Error in generate for {request_id}: {e}")
raise
finally:
# Cleanup tracking
self._rid_to_result.pop(request_id, None)
async def abort(self, request_id: str) -> bool:
"""Abort a running request.
Args:
request_id: The request ID to abort
Returns:
True if request was found and aborted, False otherwise
"""
gen_result = self._rid_to_result.get(request_id)
if gen_result is None:
logger.debug(f"Abort: request {request_id} not found (may have already completed)")
return False
try:
# GenerationResult has an abort() method
gen_result.abort()
self._rid_to_result.pop(request_id, None)
logger.info(f"Request {request_id} aborted")
return True
except Exception as e:
logger.error(f"Error aborting request {request_id}: {e}")
self._rid_to_result.pop(request_id, None)
return False
async def health_check(self) -> Tuple[bool, str]:
"""Check if the engine is healthy.
Returns:
Tuple of (is_healthy, message)
"""
try:
if self.llm is None:
return False, "LLM not initialized"
# Check if executor is available and not shutdown
if hasattr(self.llm, "_executor"):
if self.llm._executor is None or self.llm._executor.is_shutdown():
return False, "Executor is shutdown"
return True, "OK"
except Exception as e:
logger.error(f"Health check error: {e}")
return False, f"Error: {e}"
def get_model_config(self) -> Dict[str, Any]:
"""Get model configuration information.
Returns:
Dictionary with model config details
"""
config = {
"model_path": "",
"is_generation": True,
"max_context_length": 0,
"max_seq_len": 0,
"vocab_size": 0,
"supports_vision": False,
}
try:
# Try to get model path
if hasattr(self.llm, "args"):
if hasattr(self.llm.args, "model"):
config["model_path"] = str(self.llm.args.model)
# Try to get tokenizer info
if hasattr(self.llm, "tokenizer") and self.llm.tokenizer is not None:
if hasattr(self.llm.tokenizer, "vocab_size"):
config["vocab_size"] = self.llm.tokenizer.vocab_size
# Try to get max context length from various sources
if hasattr(self.llm, "args") and self.llm.args is not None:
args = self.llm.args
# Try max_input_len first (input context)
if hasattr(args, "max_input_len") and args.max_input_len:
config["max_context_length"] = args.max_input_len
# Try max_seq_len (total sequence including output)
if hasattr(args, "max_seq_len") and args.max_seq_len:
config["max_seq_len"] = args.max_seq_len
# Check for multimodal support
if hasattr(self.llm, "input_processor"):
processor_name = type(self.llm.input_processor).__name__
config["supports_vision"] = processor_name != "DefaultInputProcessor"
except Exception as e:
logger.warning(
f"Error getting model config: {type(e).__name__}: {e}\n{traceback.format_exc()}"
)
return config
def get_num_unfinished_requests(self) -> int:
"""Get the number of currently running requests.
Returns:
Number of unfinished requests
"""
return len(self._rid_to_result)
def create_sampling_params_from_proto(
proto_config: pb2.SamplingConfig,
output_config: pb2.OutputConfig,
max_tokens: int,
end_id: Optional[int] = None,
pad_id: Optional[int] = None,
bad_words: Optional[List[pb2.TokenSequence]] = None,
stop_words: Optional[List[pb2.TokenSequence]] = None,
guided_decoding: Optional[pb2.GuidedDecodingParams] = None,
embedding_bias: Optional[List[float]] = None,
) -> SamplingParams:
"""Convert protobuf configuration to TensorRT-LLM SamplingParams.
Args:
proto_config: Protobuf SamplingConfig message
output_config: Protobuf OutputConfig message
max_tokens: Maximum tokens to generate
end_id: End-of-sequence token ID
pad_id: Padding token ID
bad_words: Bad word token sequences
stop_words: Stop word token sequences
guided_decoding: Guided decoding parameters
embedding_bias: Embedding bias tensor
Returns:
TensorRT-LLM SamplingParams with detokenize=False
"""
# Build kwargs for SamplingParams
# KEY OPTIMIZATION: detokenize=False skips Python detokenization!
kwargs = {
"max_tokens": max_tokens,
"detokenize": False,
}
# Beam search / sampling
if proto_config.beam_width > 1:
kwargs["beam_width"] = proto_config.beam_width
if proto_config.num_return_sequences > 0:
kwargs["n"] = proto_config.num_return_sequences
# Temperature and sampling parameters (with sensible defaults as safety guard)
kwargs["temperature"] = (
proto_config.temperature if proto_config.HasField("temperature") else 1.0
)
kwargs["top_p"] = proto_config.top_p if proto_config.HasField("top_p") else 1.0
if proto_config.HasField("top_k"):
kwargs["top_k"] = proto_config.top_k
if proto_config.HasField("min_p"):
kwargs["min_p"] = proto_config.min_p
# Top-P decay parameters
if proto_config.HasField("top_p_min"):
kwargs["top_p_min"] = proto_config.top_p_min
if proto_config.HasField("top_p_reset_ids"):
kwargs["top_p_reset_ids"] = proto_config.top_p_reset_ids
if proto_config.HasField("top_p_decay"):
kwargs["top_p_decay"] = proto_config.top_p_decay
# Seed for reproducibility
if proto_config.HasField("seed"):
kwargs["random_seed"] = proto_config.seed
# Min/max tokens
if proto_config.HasField("min_tokens"):
kwargs["min_tokens"] = proto_config.min_tokens
# Penalties (repetition_penalty defaults to 1.0 = no penalty)
kwargs["repetition_penalty"] = (
proto_config.repetition_penalty if proto_config.HasField("repetition_penalty") else 1.0
)
if proto_config.HasField("presence_penalty"):
kwargs["presence_penalty"] = proto_config.presence_penalty
if proto_config.HasField("frequency_penalty"):
kwargs["frequency_penalty"] = proto_config.frequency_penalty
# Beam search parameters
if proto_config.HasField("beam_search_diversity_rate"):
kwargs["beam_search_diversity_rate"] = proto_config.beam_search_diversity_rate
if proto_config.HasField("length_penalty"):
kwargs["length_penalty"] = proto_config.length_penalty
if proto_config.HasField("early_stopping"):
kwargs["early_stopping"] = proto_config.early_stopping
# N-gram blocking
if proto_config.HasField("no_repeat_ngram_size"):
kwargs["no_repeat_ngram_size"] = proto_config.no_repeat_ngram_size
# End/pad tokens
if end_id is not None:
kwargs["end_id"] = end_id
if end_id == -1:
kwargs["ignore_eos"] = True
if pad_id is not None:
kwargs["pad_id"] = pad_id
# Output configuration - logprobs
if output_config.HasField("logprobs"):
kwargs["logprobs"] = output_config.logprobs
if output_config.HasField("prompt_logprobs"):
kwargs["prompt_logprobs"] = output_config.prompt_logprobs
if output_config.return_context_logits:
kwargs["return_context_logits"] = True
if output_config.return_generation_logits:
kwargs["return_generation_logits"] = True
if output_config.exclude_input_from_output:
kwargs["exclude_input_from_output"] = True
# Stop sequences (as token ID lists)
if stop_words:
kwargs["stop_words"] = [list(seq.token_ids) for seq in stop_words]
if bad_words:
kwargs["bad_words"] = [list(seq.token_ids) for seq in bad_words]
# Embedding bias
if embedding_bias:
kwargs["embedding_bias"] = list(embedding_bias)
# Guided decoding
if guided_decoding and guided_decoding.guide:
guide_type = guided_decoding.guide_type
guide_content = guided_decoding.guide
if guide_type == pb2.GuidedDecodingParams.GUIDE_TYPE_JSON:
# json_object=True for JSON validation without schema constraint
kwargs["guided_decoding_params"] = GuidedDecodingParams(json_object=True)
elif guide_type == pb2.GuidedDecodingParams.GUIDE_TYPE_JSON_SCHEMA:
kwargs["guided_decoding_params"] = GuidedDecodingParams(json_schema=guide_content)
elif guide_type == pb2.GuidedDecodingParams.GUIDE_TYPE_REGEX:
kwargs["guided_decoding_params"] = GuidedDecodingParams(regex=guide_content)
elif guide_type == pb2.GuidedDecodingParams.GUIDE_TYPE_EBNF_GRAMMAR:
kwargs["guided_decoding_params"] = GuidedDecodingParams(grammar=guide_content)
return SamplingParams(**kwargs)
def create_lora_request_from_proto(
proto_config: Optional[pb2.LoraConfig],
) -> Optional[LoRARequest]:
"""Convert protobuf LoraConfig to TensorRT-LLM LoRARequest.
Args:
proto_config: Protobuf LoraConfig message
Returns:
LoRARequest or None
"""
if proto_config is None or proto_config.task_id == 0:
return None
return LoRARequest(
lora_name=f"lora_{proto_config.task_id}",
lora_int_id=proto_config.task_id,
)
def create_disaggregated_params_from_proto(
proto_config: Optional[pb2.DisaggregatedParams],
) -> Optional[DisaggregatedParams]:
"""Convert protobuf DisaggregatedParams to TensorRT-LLM DisaggregatedParams.
Args:
proto_config: Protobuf DisaggregatedParams message
Returns:
DisaggregatedParams or None
"""
if proto_config is None:
return None
request_type_map = {
pb2.DisaggregatedParams.REQUEST_TYPE_CONTEXT_AND_GENERATION: "context_and_generation",
pb2.DisaggregatedParams.REQUEST_TYPE_CONTEXT_ONLY: "context_only",
pb2.DisaggregatedParams.REQUEST_TYPE_GENERATION_ONLY: "generation_only",
}
request_type = request_type_map.get(proto_config.request_type, "context_and_generation")
params = DisaggregatedParams(request_type=request_type)
if proto_config.ctx_request_id:
params.ctx_request_id = proto_config.ctx_request_id
if proto_config.HasField("context_phase_params"):
ctx_params = proto_config.context_phase_params
params.first_gen_token_id = ctx_params.first_gen_token_id
if ctx_params.kv_cache_blocks:
params.kv_cache_blocks = ctx_params.kv_cache_blocks
return params
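A minimal sketch of the conversion path, assuming the generated pb2 modules are importable; the illustrative values exercise only fields defined above, and the helper always sets detokenize=False:
# Hedged sketch: build proto configs and convert them to SamplingParams.
from tensorrt_llm.grpc import trtllm_service_pb2 as pb2
from tensorrt_llm.grpc.grpc_request_manager import create_sampling_params_from_proto

sampling = pb2.SamplingConfig(beam_width=1, temperature=0.7, top_p=0.95)
output = pb2.OutputConfig(logprobs=5)
params = create_sampling_params_from_proto(
    proto_config=sampling,
    output_config=output,
    max_tokens=64,
    end_id=2,  # illustrative EOS token ID
)
assert params.detokenize is False  # the key optimization described above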

View File

@@ -0,0 +1,545 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""gRPC Servicer for TensorRT-LLM.
Implements the TrtllmService gRPC service for high-performance communication
with external routers (e.g., sgl-router) using pre-tokenized input.
"""
import asyncio
import time
from collections.abc import AsyncGenerator
from typing import List, Union
import grpc
from tensorrt_llm.executor.result import Logprob, TokenLogprobs
from tensorrt_llm.logger import logger
from . import trtllm_service_pb2, trtllm_service_pb2_grpc
from .grpc_request_manager import (
GrpcRequestManager,
create_disaggregated_params_from_proto,
create_lora_request_from_proto,
create_sampling_params_from_proto,
)
class TrtllmServiceServicer(trtllm_service_pb2_grpc.TrtllmServiceServicer):
"""gRPC servicer implementing the TrtLlmEngine service.
Handles RPCs:
- Generate: Streaming text generation
- Embed: Embeddings (for embedding models)
- HealthCheck: Health probe
- Abort: Cancel a request
- GetModelInfo: Model metadata
- GetServerInfo: Server state
"""
def __init__(self, request_manager: GrpcRequestManager, model_path: str = ""):
"""Initialize the servicer.
Args:
request_manager: The GrpcRequestManager instance
model_path: Path to the model (for metadata)
"""
self.request_manager = request_manager
self.model_path = model_path
self._start_time = time.time()
logger.info("TrtllmServiceServicer initialized")
async def Generate(
self,
request: trtllm_service_pb2.GenerateRequest,
context: grpc.aio.ServicerContext,
) -> AsyncGenerator[trtllm_service_pb2.GenerateResponse, None]:
"""Handle streaming generation requests.
Args:
request: The GenerateRequest protobuf
context: gRPC context
Yields:
GenerateResponse protobuf messages (streaming)
"""
request_id = request.request_id
logger.info(f"Generate request {request_id} received")
try:
# Extract tokenized input (required)
if not request.HasField("tokenized"):
yield self._error_response(
request_id,
"Missing tokenized input",
"INVALID_REQUEST",
400,
)
return
prompt_token_ids = list(request.tokenized.input_token_ids)
# Build sampling params with detokenize=False (key optimization!)
sampling_params = create_sampling_params_from_proto(
proto_config=request.sampling_config,
output_config=request.output_config,
max_tokens=request.max_tokens,
end_id=request.end_id if request.HasField("end_id") else None,
pad_id=request.pad_id if request.HasField("pad_id") else None,
bad_words=list(request.bad_words) if request.bad_words else None,
stop_words=list(request.stop_words) if request.stop_words else None,
guided_decoding=request.guided_decoding
if request.HasField("guided_decoding")
else None,
embedding_bias=list(request.embedding_bias) if request.embedding_bias else None,
)
# Build LoRA request if present
lora_request = create_lora_request_from_proto(
request.lora_config if request.HasField("lora_config") else None
)
# Build disaggregated params if present
disaggregated_params = create_disaggregated_params_from_proto(
request.disaggregated_params if request.HasField("disaggregated_params") else None
)
# Track tokens sent per sequence index to avoid duplicates
# TRT-LLM's token_ids_diff doesn't clear between iterations for n>1
sent_token_counts: dict[int, int] = {}
# Submit to request manager and stream outputs
# The request manager now yields GenerationResult objects
async for gen_result in self.request_manager.generate(
request_id=request_id,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
streaming=request.streaming,
lora_request=lora_request,
disaggregated_params=disaggregated_params,
):
# Check if client disconnected
if context.cancelled():
logger.info(f"Client disconnected for {request_id}")
await self.request_manager.abort(request_id)
return
# Convert GenerationResult to protobuf response
if request.streaming:
for chunk_response in self._chunk_responses(
request_id, gen_result, prompt_token_ids, sent_token_counts
):
yield chunk_response
# Send complete responses when finished (one per sequence for n>1)
if gen_result.finished:
for complete_response in self._complete_responses(
request_id, gen_result, prompt_token_ids
):
yield complete_response
except asyncio.CancelledError:
logger.info(f"Request {request_id} cancelled")
await self.request_manager.abort(request_id)
raise
except Exception as e:
logger.error(f"Error in Generate for {request_id}: {e}")
yield self._error_response(
request_id,
str(e),
"INTERNAL_ERROR",
500,
)
async def Embed(
self,
request: trtllm_service_pb2.EmbedRequest,
context: grpc.aio.ServicerContext,
) -> trtllm_service_pb2.EmbedResponse:
"""Handle embedding requests.
Args:
request: The EmbedRequest protobuf
context: gRPC context
Returns:
EmbedResponse protobuf
"""
logger.warning("Embed RPC not yet implemented")
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details("Embed RPC not yet implemented")
return trtllm_service_pb2.EmbedResponse(
request_id=request.request_id,
embedding=[],
prompt_tokens=0,
)
async def HealthCheck(
self,
request: trtllm_service_pb2.HealthCheckRequest,
context: grpc.aio.ServicerContext,
) -> trtllm_service_pb2.HealthCheckResponse:
"""Handle health check requests.
Args:
request: The HealthCheckRequest protobuf
context: gRPC context
Returns:
HealthCheckResponse protobuf
"""
is_healthy, message = await self.request_manager.health_check()
logger.info(f"HealthCheck: healthy={is_healthy}, message={message}")
return trtllm_service_pb2.HealthCheckResponse(
status=message,
)
async def Abort(
self,
request: trtllm_service_pb2.AbortRequest,
context: grpc.aio.ServicerContext,
) -> trtllm_service_pb2.AbortResponse:
"""Handle abort requests.
Args:
request: The AbortRequest protobuf
context: gRPC context
Returns:
AbortResponse protobuf
"""
request_id = request.request_id
logger.info(f"Abort request for {request_id}")
success = await self.request_manager.abort(request_id)
return trtllm_service_pb2.AbortResponse(
success=success,
message=f"Request {request_id} {'aborted' if success else 'not found'}",
)
async def GetModelInfo(
self,
request: trtllm_service_pb2.GetModelInfoRequest,
context: grpc.aio.ServicerContext,
) -> trtllm_service_pb2.GetModelInfoResponse:
"""Handle model info requests.
Args:
request: The GetModelInfoRequest protobuf
context: gRPC context
Returns:
GetModelInfoResponse protobuf
"""
model_config = self.request_manager.get_model_config()
return trtllm_service_pb2.GetModelInfoResponse(
model_id=self.model_path or model_config.get("model_path", ""),
max_input_len=model_config.get("max_context_length", 0),
max_seq_len=model_config.get("max_seq_len", 0)
or model_config.get("max_context_length", 0),
vocab_size=model_config.get("vocab_size", 0),
)
async def GetServerInfo(
self,
request: trtllm_service_pb2.GetServerInfoRequest,
context: grpc.aio.ServicerContext,
) -> trtllm_service_pb2.GetServerInfoResponse:
"""Handle server info requests.
Args:
request: The GetServerInfoRequest protobuf
context: gRPC context
Returns:
GetServerInfoResponse protobuf
"""
try:
import tensorrt_llm
version = getattr(tensorrt_llm, "__version__", "unknown")
except Exception:
version = "unknown"
# Try to get parallelism info from LLM args
tp_size = 1
pp_size = 1
world_size = 1
try:
llm = self.request_manager.llm
if hasattr(llm, "args") and llm.args is not None:
args = llm.args
if hasattr(args, "tensor_parallel_size") and args.tensor_parallel_size:
tp_size = args.tensor_parallel_size
if hasattr(args, "pipeline_parallel_size") and args.pipeline_parallel_size:
pp_size = args.pipeline_parallel_size
world_size = tp_size * pp_size
except Exception as e:
logger.debug(f"Could not get parallelism info: {e}")
return trtllm_service_pb2.GetServerInfoResponse(
version=version,
backend="tensorrt-llm",
tensor_parallel_size=tp_size,
pipeline_parallel_size=pp_size,
context_parallel_size=1, # Context parallelism is separate from TP/PP
world_size=world_size,
)
# ========== Helper methods ==========
def _convert_logprobs_to_proto(
self,
token_ids: List[int],
logprobs: Union[TokenLogprobs, List[float], None],
) -> List[trtllm_service_pb2.TokenLogprob]:
"""Convert TRT-LLM logprobs to protobuf TokenLogprob messages.
Handles both formats:
- List[float]: Simple logprobs (one per token)
- TokenLogprobs (list[dict[int, Logprob]]): Top-k logprobs per position
Args:
token_ids: The sampled token IDs
logprobs: Logprobs from TRT-LLM (can be List[float] or TokenLogprobs)
Returns:
List of TokenLogprob proto messages
"""
if not logprobs or not token_ids:
return []
result = []
for i, token_id in enumerate(token_ids):
if i >= len(logprobs):
break
lp = logprobs[i]
if isinstance(lp, dict):
# TokenLogprobs format: dict[int, Logprob]
# Each entry maps token_id -> Logprob(logprob, rank)
token_logprob = trtllm_service_pb2.TokenLogprob(
token_id=token_id,
logprob=lp[token_id].logprob if token_id in lp else 0.0,
)
# Add top logprobs (all entries in the dict)
for tid, logprob_obj in lp.items():
if isinstance(logprob_obj, Logprob):
token_logprob.top_logprobs.append(
trtllm_service_pb2.TopLogprob(
token_id=tid,
logprob=logprob_obj.logprob,
)
)
result.append(token_logprob)
elif isinstance(lp, (int, float)):
# Simple float logprob
token_logprob = trtllm_service_pb2.TokenLogprob(
token_id=token_id,
logprob=float(lp),
)
result.append(token_logprob)
return result
def _chunk_responses(
self,
request_id: str,
gen_result,
prompt_token_ids: list,
sent_token_counts: dict[int, int],
) -> List[trtllm_service_pb2.GenerateResponse]:
"""Build streaming chunk responses from GenerationResult.
Uses cumulative token_ids and tracks sent position to compute true deltas.
TRT-LLM's token_ids_diff doesn't clear between iterations for n>1, so we
compute deltas ourselves.
Args:
request_id: The request ID
gen_result: TensorRT-LLM GenerationResult
prompt_token_ids: Original prompt tokens
sent_token_counts: Dict tracking tokens already sent per sequence index
Returns:
List of GenerateResponse with chunk field set (one per output)
"""
responses = []
cached_tokens = gen_result.cached_tokens if hasattr(gen_result, "cached_tokens") else 0
if not gen_result.outputs:
# No outputs yet, return empty chunk
responses.append(
trtllm_service_pb2.GenerateResponse(
request_id=request_id,
chunk=trtllm_service_pb2.GenerateStreamChunk(
token_ids=[],
prompt_tokens=len(prompt_token_ids),
completion_tokens=0,
cached_tokens=cached_tokens,
),
)
)
return responses
# Process all outputs (for n>1 support)
for completion in gen_result.outputs:
index = completion.index
# Use cumulative token_ids and compute delta ourselves
# because token_ids_diff doesn't clear between iterations for n>1
all_tokens = list(completion.token_ids) if completion.token_ids else []
sent_count = sent_token_counts.get(index, 0)
delta_tokens = all_tokens[sent_count:]
# Skip if no new tokens for this sequence
if not delta_tokens:
continue
# Update sent count
sent_token_counts[index] = len(all_tokens)
chunk = trtllm_service_pb2.GenerateStreamChunk(
token_ids=delta_tokens,
sequence_index=completion.index,
prompt_tokens=len(prompt_token_ids),
completion_tokens=len(completion.token_ids) if completion.token_ids else 0,
cached_tokens=cached_tokens,
)
# Add logprobs if available
# Note: We compute delta logprobs ourselves since logprobs_diff has the same issue as token_ids_diff
if completion.logprobs:
all_logprobs = completion.logprobs
delta_logprobs = all_logprobs[sent_count:] if sent_count < len(all_logprobs) else []
proto_logprobs = self._convert_logprobs_to_proto(delta_tokens, delta_logprobs)
chunk.logprobs.extend(proto_logprobs)
responses.append(
trtllm_service_pb2.GenerateResponse(
request_id=request_id,
chunk=chunk,
)
)
return responses
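A short worked trace of this bookkeeping with hypothetical token IDs (n=2):
# iteration 1: outputs[0].token_ids == [11, 12]     -> delta [11, 12], sent_token_counts == {0: 2}
#              outputs[1].token_ids == [21]         -> delta [21],     sent_token_counts == {0: 2, 1: 1}
# iteration 2: outputs[0].token_ids == [11, 12, 13] -> delta [13],     sent_token_counts == {0: 3, 1: 1}
#              outputs[1].token_ids == [21]         -> no new tokens, no chunk emitted for index 1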
def _complete_responses(
self,
request_id: str,
gen_result,
prompt_token_ids: list,
) -> List[trtllm_service_pb2.GenerateResponse]:
"""Build final completion responses from GenerationResult.
For n>1, returns one response per output sequence.
Args:
request_id: The request ID
gen_result: TensorRT-LLM GenerationResult (finished=True)
prompt_token_ids: Original prompt tokens
Returns:
List of GenerateResponse with complete field set (one per output)
"""
responses = []
cached_tokens = gen_result.cached_tokens if hasattr(gen_result, "cached_tokens") else 0
if not gen_result.outputs:
# No outputs, return error response
responses.append(
trtllm_service_pb2.GenerateResponse(
request_id=request_id,
complete=trtllm_service_pb2.GenerateComplete(
output_token_ids=[],
finish_reason="error",
prompt_tokens=len(prompt_token_ids),
completion_tokens=0,
cached_tokens=0,
),
)
)
return responses
# Process all outputs (for n>1 support)
for completion in gen_result.outputs:
output_tokens = list(completion.token_ids) if completion.token_ids else []
complete = trtllm_service_pb2.GenerateComplete(
output_token_ids=output_tokens,
sequence_index=completion.index,
finish_reason=completion.finish_reason or "stop",
prompt_tokens=len(prompt_token_ids),
completion_tokens=len(output_tokens),
cached_tokens=cached_tokens,
)
# Add stop reason if available
if hasattr(completion, "stop_reason") and completion.stop_reason:
complete.stop_reason = str(completion.stop_reason)
# Add generation logprobs if available
if completion.logprobs:
proto_logprobs = self._convert_logprobs_to_proto(output_tokens, completion.logprobs)
complete.logprobs.extend(proto_logprobs)
# Add prompt logprobs if available
if hasattr(completion, "prompt_logprobs") and completion.prompt_logprobs:
# For prompt logprobs, we use the prompt_token_ids
proto_prompt_logprobs = self._convert_logprobs_to_proto(
prompt_token_ids, completion.prompt_logprobs
)
complete.prompt_logprobs.extend(proto_prompt_logprobs)
responses.append(
trtllm_service_pb2.GenerateResponse(
request_id=request_id,
complete=complete,
)
)
return responses
def _error_response(
self,
request_id: str,
message: str,
error_type: str,
code: int,
) -> trtllm_service_pb2.GenerateResponse:
"""Build an error response.
Args:
request_id: The request ID
message: Error message
error_type: Error type string
code: Error code
Returns:
GenerateResponse with error field set
"""
return trtllm_service_pb2.GenerateResponse(
request_id=request_id,
error=trtllm_service_pb2.GenerateError(
message=message,
type=error_type,
code=code,
),
)
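A minimal sketch of a streaming client for this servicer, assuming the server from this commit is listening on localhost:50051 and the pb2 modules have been generated; token IDs are illustrative:
# Hedged sketch: pre-tokenized streaming request via the generated stub.
import asyncio
import grpc
from tensorrt_llm.grpc import trtllm_service_pb2 as pb2, trtllm_service_pb2_grpc as pb2_grpc

async def main():
    async with grpc.aio.insecure_channel("localhost:50051") as channel:
        stub = pb2_grpc.TrtllmServiceStub(channel)
        request = pb2.GenerateRequest(
            request_id="req-0",
            tokenized=pb2.TokenizedInput(input_token_ids=[1, 2, 3, 4]),  # pre-tokenized prompt
            sampling_config=pb2.SamplingConfig(temperature=0.8, top_p=0.95),
            max_tokens=32,
            streaming=True,
        )
        async for response in stub.Generate(request):
            kind = response.WhichOneof("response")
            if kind == "chunk":
                print("delta tokens:", list(response.chunk.token_ids))
            elif kind == "complete":
                print("finish_reason:", response.complete.finish_reason)
            elif kind == "error":
                raise RuntimeError(response.error.message)

asyncio.run(main())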

View File

@@ -0,0 +1,511 @@
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package trtllm;
// TensorRT-LLM gRPC Service
//
// This service provides high-performance inference for LLMs using TensorRT-LLM.
// It accepts pre-tokenized requests and returns raw token IDs, enabling efficient
// binary communication with external routers (e.g., sgl-router).
//
// Key Design Principles:
// - Token IDs only: No text in requests or responses (router handles tokenization)
// - Streaming delta mode: Chunks contain only new tokens since last chunk
// - Full feature support: All TensorRT-LLM capabilities exposed
service TrtllmService {
// Generate tokens from pre-tokenized input
// Returns a stream of responses containing token IDs
rpc Generate(GenerateRequest) returns (stream GenerateResponse);
// Generate embeddings from pre-tokenized input (for embedding models)
rpc Embed(EmbedRequest) returns (EmbedResponse);
// Health check endpoint
rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse);
// Abort a running generation request
rpc Abort(AbortRequest) returns (AbortResponse);
// Get model information (vocab size, max lengths, etc.)
rpc GetModelInfo(GetModelInfoRequest) returns (GetModelInfoResponse);
// Get server information (version, parallelism, etc.)
rpc GetServerInfo(GetServerInfoRequest) returns (GetServerInfoResponse);
}
// ============================================================================
// Generate Request
// ============================================================================
message GenerateRequest {
// Unique request identifier (assigned by client/router)
string request_id = 1;
// Pre-tokenized input (REQUIRED - router tokenizes)
TokenizedInput tokenized = 2;
// Sampling configuration
SamplingConfig sampling_config = 3;
// Output configuration
OutputConfig output_config = 4;
// Maximum tokens to generate (REQUIRED)
uint32 max_tokens = 5;
// Enable streaming mode
bool streaming = 6;
// End-of-sequence token ID (optional, set to -1 to ignore EOS)
optional int32 end_id = 7;
// Padding token ID
optional int32 pad_id = 8;
// Bad word token sequences (generation redirects if these are produced)
repeated TokenSequence bad_words = 9;
// Stop word token sequences (generation stops when these are produced)
repeated TokenSequence stop_words = 10;
// Guided decoding parameters (JSON schema, regex, grammar)
optional GuidedDecodingParams guided_decoding = 11;
// Embedding bias tensor (vocab_size floats, optional)
repeated float embedding_bias = 12;
// LoRA adapter configuration
optional LoraConfig lora_config = 13;
// Prompt tuning configuration
optional PromptTuningConfig prompt_tuning_config = 14;
// Multimodal input (for VLMs)
optional MultimodalInput multimodal_input = 15;
// KV cache retention configuration
optional KvCacheRetentionConfig kv_cache_retention = 16;
// Disaggregated inference parameters
optional DisaggregatedParams disaggregated_params = 17;
// Lookahead decoding configuration
optional LookaheadConfig lookahead_config = 18;
// Cache salt ID for cache hashing
optional int64 cache_salt_id = 19;
// Request arrival time (unix timestamp for metrics)
optional double arrival_time = 20;
}
// Tokenized input from router
message TokenizedInput {
// Original text (for debugging/logging only, not used for generation)
string original_text = 1;
// Pre-tokenized input token IDs (REQUIRED)
repeated uint32 input_token_ids = 2;
// Query token IDs for VLM star attention (optional)
repeated uint32 query_token_ids = 3;
}
// Sequence of token IDs (for stop/bad words)
message TokenSequence {
repeated uint32 token_ids = 1;
}
// ============================================================================
// Sampling Configuration
// Maps to tensorrt_llm.bindings.executor.SamplingConfig
// ============================================================================
message SamplingConfig {
// Beam width (1 for sampling, >1 for beam search)
int32 beam_width = 1;
// Number of sequences to return
uint32 num_return_sequences = 2;
// Top-K sampling (0 = disabled, considers all tokens)
optional int32 top_k = 3;
// Top-P (nucleus) sampling threshold
optional float top_p = 4;
// Top-P minimum threshold for decay
optional float top_p_min = 5;
// Top-P reset token IDs
optional int32 top_p_reset_ids = 6;
// Top-P decay factor
optional float top_p_decay = 7;
// Random seed for reproducibility
optional uint64 seed = 8;
// Temperature for sampling (0 = greedy, higher = more random)
optional float temperature = 9;
// Minimum tokens to generate before stopping
optional uint32 min_tokens = 10;
// Beam search diversity rate
optional float beam_search_diversity_rate = 11;
// Repetition penalty (>1 discourages, <1 encourages repetition)
optional float repetition_penalty = 12;
// Presence penalty (penalizes tokens that have appeared)
optional float presence_penalty = 13;
// Frequency penalty (penalizes based on frequency of appearance)
optional float frequency_penalty = 14;
// Number of prompt tokens to ignore for penalties
optional int32 prompt_ignore_length = 15;
// Length penalty for beam search
optional float length_penalty = 16;
// Early stopping for beam search
optional int32 early_stopping = 17;
// No repeat n-gram size
optional int32 no_repeat_ngram_size = 18;
// Min-P sampling threshold
optional float min_p = 19;
// Variable beam width array for beam search
repeated int32 beam_width_array = 20;
}
// ============================================================================
// Output Configuration
// Maps to tensorrt_llm.bindings.executor.OutputConfig
// ============================================================================
message OutputConfig {
// Number of top log probabilities to return per output token
optional int32 logprobs = 1;
// Number of top log probabilities to return per prompt token
optional int32 prompt_logprobs = 2;
// Return context logits tensor (large, use with caution)
bool return_context_logits = 3;
// Return generation logits tensor (large, use with caution)
bool return_generation_logits = 4;
// Exclude input tokens from output (set to true to enable; TRT-LLM defaults to true internally)
bool exclude_input_from_output = 5;
// Return encoder output (for encoder-decoder models)
bool return_encoder_output = 6;
// Return performance metrics
bool return_perf_metrics = 7;
}
// ============================================================================
// Guided Decoding
// ============================================================================
message GuidedDecodingParams {
// Guide type enumeration
enum GuideType {
GUIDE_TYPE_UNSPECIFIED = 0;
GUIDE_TYPE_JSON = 1; // JSON format (any valid JSON)
GUIDE_TYPE_JSON_SCHEMA = 2; // JSON with schema constraint
GUIDE_TYPE_REGEX = 3; // Regular expression constraint
GUIDE_TYPE_EBNF_GRAMMAR = 4; // EBNF grammar constraint
GUIDE_TYPE_STRUCTURAL_TAG = 5; // Structural tag (xgrammar backend)
}
GuideType guide_type = 1;
// Guide content (schema string, regex pattern, or grammar definition)
string guide = 2;
}
// ============================================================================
// LoRA Configuration
// ============================================================================
message LoraConfig {
// LoRA task/adapter ID
int64 task_id = 1;
// LoRA weights (serialized tensor, optional if already cached)
optional bytes weights = 2;
// LoRA config as JSON string
optional string config_json = 3;
}
// ============================================================================
// Prompt Tuning Configuration
// ============================================================================
message PromptTuningConfig {
// Embedding table (serialized tensor)
bytes embedding_table = 1;
}
// ============================================================================
// Multimodal Input (for Vision-Language Models)
// ============================================================================
message MultimodalInput {
// Multimodal content hashes for caching
repeated int64 multimodal_hashes = 1;
// Positions in input where multimodal content is inserted
repeated int32 multimodal_positions = 2;
// Lengths of multimodal content at each position
repeated int32 multimodal_lengths = 3;
}
// ============================================================================
// KV Cache Retention
// ============================================================================
message KvCacheRetentionConfig {
// Retention policy name
string policy = 1;
// Additional configuration as JSON string
string config_json = 2;
}
// ============================================================================
// Disaggregated Inference
// ============================================================================
message DisaggregatedParams {
// Request type for disaggregated inference
enum RequestType {
REQUEST_TYPE_CONTEXT_AND_GENERATION = 0; // Normal full request
REQUEST_TYPE_CONTEXT_ONLY = 1; // Prefill only
REQUEST_TYPE_GENERATION_ONLY = 2; // Decode only
}
RequestType request_type = 1;
// Context request ID (links context and generation phases)
string ctx_request_id = 2;
// Context phase parameters (for generation_only requests)
optional ContextPhaseParams context_phase_params = 3;
}
message ContextPhaseParams {
// First generated token ID from context phase
uint32 first_gen_token_id = 1;
// KV cache block pointers (serialized)
bytes kv_cache_blocks = 2;
}
// ============================================================================
// Lookahead Decoding
// ============================================================================
message LookaheadConfig {
int32 max_window_size = 1;
int32 max_ngram_size = 2;
int32 max_verification_set_size = 3;
}
// ============================================================================
// Generate Response
// ============================================================================
message GenerateResponse {
// Request ID echo
string request_id = 1;
// Response type (oneof ensures exactly one is set)
oneof response {
GenerateStreamChunk chunk = 2; // Streaming delta
GenerateComplete complete = 3; // Final response
GenerateError error = 4; // Error response
}
}
// Streaming chunk containing delta tokens (new tokens since last chunk)
message GenerateStreamChunk {
// NEW token IDs only (delta from previous chunk)
repeated uint32 token_ids = 1;
// Beam/sequence index (for beam_width > 1 or n > 1)
uint32 sequence_index = 2;
// Token counts for usage tracking
uint32 prompt_tokens = 3;
uint32 completion_tokens = 4;
uint32 cached_tokens = 5;
// Log probabilities for this chunk's tokens (if requested)
repeated TokenLogprob logprobs = 6;
}
// Final/complete response with all output tokens
message GenerateComplete {
// All output token IDs (cumulative, not delta)
repeated uint32 output_token_ids = 1;
// Beam/sequence index
uint32 sequence_index = 2;
// Finish reason: "stop", "length", "stop_word"
string finish_reason = 3;
// Specific stop reason (stop word/token that triggered stop)
optional string stop_reason = 4;
// Token counts for usage tracking
uint32 prompt_tokens = 5;
uint32 completion_tokens = 6;
uint32 cached_tokens = 7;
// Generation log probabilities (if requested)
repeated TokenLogprob logprobs = 8;
// Prompt log probabilities (if requested)
repeated TokenLogprob prompt_logprobs = 9;
// Performance metrics (if requested)
optional PerfMetrics perf_metrics = 10;
// Context logits (if requested) - serialized float tensor
optional bytes context_logits = 11;
// Generation logits (if requested) - serialized float tensor
optional bytes generation_logits = 12;
}
// Token log probability information
message TokenLogprob {
uint32 token_id = 1;
float logprob = 2;
// Top alternative tokens and their log probabilities
repeated TopLogprob top_logprobs = 3;
}
message TopLogprob {
uint32 token_id = 1;
float logprob = 2;
}
// Performance metrics for request
message PerfMetrics {
double arrival_time = 1; // When request arrived
double first_scheduled_time = 2; // When first scheduled
double first_token_time = 3; // Time to first token (TTFT)
double last_token_time = 4; // When last token generated
double kv_cache_transfer_start = 5; // KV cache transfer start (disagg)
double kv_cache_transfer_end = 6; // KV cache transfer end (disagg)
int64 kv_cache_size = 7; // KV cache size in bytes
}
// Error response
message GenerateError {
string message = 1; // Human-readable error message
string type = 2; // Error type (e.g., "InvalidRequest", "InternalError")
int32 code = 3; // Error code
}
// ============================================================================
// Embed Request/Response (for embedding models)
// ============================================================================
message EmbedRequest {
string request_id = 1;
TokenizedInput tokenized = 2;
}
message EmbedResponse {
string request_id = 1;
repeated float embedding = 2; // Embedding vector
uint32 prompt_tokens = 3;
}
// ============================================================================
// Health Check
// ============================================================================
message HealthCheckRequest {}
message HealthCheckResponse {
string status = 1; // "OK" or error description
}
// ============================================================================
// Abort Request
// ============================================================================
message AbortRequest {
string request_id = 1;
}
message AbortResponse {
bool success = 1;
string message = 2;
}
// ============================================================================
// Model Info
// ============================================================================
message GetModelInfoRequest {}
message GetModelInfoResponse {
string model_id = 1; // Model identifier/path
int32 max_input_len = 2; // Maximum input length
int32 max_seq_len = 3; // Maximum sequence length (input + output)
int32 max_batch_size = 4; // Maximum batch size
int32 vocab_size = 5; // Vocabulary size
int32 hidden_size = 6; // Hidden dimension
int32 num_layers = 7; // Number of transformer layers
int32 num_heads = 8; // Number of attention heads
// Supported features
repeated string supported_features = 9; // e.g., "lora", "guided_decoding"
}
// ============================================================================
// Server Info
// ============================================================================
message GetServerInfoRequest {}
message GetServerInfoResponse {
string version = 1; // TensorRT-LLM version
string backend = 2; // "tensorrt" or "pytorch"
int32 tensor_parallel_size = 3; // TP size
int32 pipeline_parallel_size = 4; // PP size
int32 context_parallel_size = 5; // CP size
int32 world_size = 6; // Total world size
}
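
Because GenerateResponse wraps its payload in a oneof, a client sees exactly one of chunk, complete, or error per message, and streaming chunks carry only delta tokens that must be accumulated. The following is a minimal client-side sketch of that dispatch using only the generated pb2 messages defined above; how the response iterator is obtained (for example, from the service's streaming Generate RPC on the generated stub) is assumed and not shown.

# Sketch: accumulate a streamed generation from GenerateResponse messages.
# `responses` is assumed to be any iterable of GenerateResponse, e.g. the
# iterator returned by the service's streaming Generate RPC.
from typing import Iterable, List

from tensorrt_llm.grpc import trtllm_service_pb2 as pb2


def collect_stream(responses: Iterable[pb2.GenerateResponse]) -> List[int]:
    token_ids: List[int] = []
    for response in responses:
        kind = response.WhichOneof("response")
        if kind == "chunk":
            # Chunks carry only the NEW tokens since the previous chunk.
            token_ids.extend(response.chunk.token_ids)
        elif kind == "complete":
            # The final message carries the cumulative output token IDs,
            # so it can be returned as the authoritative result.
            return list(response.complete.output_token_ids)
        elif kind == "error":
            raise RuntimeError(
                f"{response.error.type} ({response.error.code}): "
                f"{response.error.message}")
    return token_ids
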

View File

@ -0,0 +1,315 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for gRPC server components."""
import pytest
from tensorrt_llm.grpc import trtllm_service_pb2 as pb2
from tensorrt_llm.grpc.grpc_request_manager import (
create_disaggregated_params_from_proto,
create_lora_request_from_proto,
create_sampling_params_from_proto,
)
pytestmark = pytest.mark.threadleak(enabled=False)
class TestSamplingParamsConversion:
"""Tests for proto to SamplingParams conversion."""
def test_basic_sampling_config(self):
"""Test basic sampling config conversion."""
proto_config = pb2.SamplingConfig(
beam_width=1,
num_return_sequences=1,
temperature=0.7,
top_k=50,
top_p=0.9,
)
output_config = pb2.OutputConfig()
params = create_sampling_params_from_proto(
proto_config=proto_config,
output_config=output_config,
max_tokens=100,
)
assert params.max_tokens == 100
assert params.temperature == 0.7
assert params.top_k == 50
assert params.top_p == 0.9
def test_beam_search_config(self):
"""Test beam search configuration."""
proto_config = pb2.SamplingConfig(
beam_width=4,
num_return_sequences=2,
length_penalty=1.2,
early_stopping=1,
)
output_config = pb2.OutputConfig()
params = create_sampling_params_from_proto(
proto_config=proto_config,
output_config=output_config,
max_tokens=50,
)
assert params.beam_width == 4
assert params.n == 2
assert params.length_penalty == 1.2
def test_penalties_config(self):
"""Test penalty parameters conversion."""
proto_config = pb2.SamplingConfig(
repetition_penalty=1.1,
presence_penalty=0.5,
frequency_penalty=0.3,
)
output_config = pb2.OutputConfig()
params = create_sampling_params_from_proto(
proto_config=proto_config,
output_config=output_config,
max_tokens=100,
)
assert params.repetition_penalty == 1.1
assert params.presence_penalty == 0.5
assert params.frequency_penalty == 0.3
def test_logprobs_config(self):
"""Test logprobs configuration."""
proto_config = pb2.SamplingConfig()
output_config = pb2.OutputConfig(
logprobs=5,
prompt_logprobs=3,
)
params = create_sampling_params_from_proto(
proto_config=proto_config,
output_config=output_config,
max_tokens=100,
)
assert params.logprobs == 5
assert params.prompt_logprobs == 3
def test_guided_decoding_json_schema(self):
"""Test guided decoding with JSON schema."""
proto_config = pb2.SamplingConfig()
output_config = pb2.OutputConfig()
guided_decoding = pb2.GuidedDecodingParams(
guide_type=pb2.GuidedDecodingParams.GUIDE_TYPE_JSON_SCHEMA,
guide='{"type": "object", "properties": {"name": {"type": "string"}}}',
)
params = create_sampling_params_from_proto(
proto_config=proto_config,
output_config=output_config,
max_tokens=100,
guided_decoding=guided_decoding,
)
assert params.guided_decoding_params is not None
assert params.guided_decoding_params.json_schema is not None
def test_guided_decoding_regex(self):
"""Test guided decoding with regex."""
proto_config = pb2.SamplingConfig()
output_config = pb2.OutputConfig()
guided_decoding = pb2.GuidedDecodingParams(
guide_type=pb2.GuidedDecodingParams.GUIDE_TYPE_REGEX,
guide=r"\d{3}-\d{4}",
)
params = create_sampling_params_from_proto(
proto_config=proto_config,
output_config=output_config,
max_tokens=100,
guided_decoding=guided_decoding,
)
assert params.guided_decoding_params is not None
assert params.guided_decoding_params.regex is not None
class TestLoraRequestConversion:
"""Tests for proto to LoRARequest conversion."""
def test_basic_lora_config(self):
"""Test basic LoRA config conversion."""
lora_config = pb2.LoraConfig(task_id=123)
request = create_lora_request_from_proto(lora_config)
assert request is not None
assert request.task_id == 123
def test_none_lora_config(self):
"""Test None LoRA config returns None."""
request = create_lora_request_from_proto(None)
assert request is None
class TestDisaggregatedParamsConversion:
"""Tests for proto to DisaggregatedParams conversion."""
def test_context_only_request(self):
"""Test context-only disaggregated request."""
proto_params = pb2.DisaggregatedParams(
request_type=pb2.DisaggregatedParams.REQUEST_TYPE_CONTEXT_ONLY,
ctx_request_id="ctx-123",
)
params = create_disaggregated_params_from_proto(proto_params)
assert params is not None
assert params.ctx_request_id == "ctx-123"
def test_generation_only_request(self):
"""Test generation-only disaggregated request."""
proto_params = pb2.DisaggregatedParams(
request_type=pb2.DisaggregatedParams.REQUEST_TYPE_GENERATION_ONLY,
ctx_request_id="gen-456",
)
params = create_disaggregated_params_from_proto(proto_params)
assert params is not None
def test_none_params(self):
"""Test None disaggregated params returns None."""
params = create_disaggregated_params_from_proto(None)
assert params is None
class TestProtoMessages:
"""Tests for proto message structure."""
def test_generate_request_structure(self):
"""Test GenerateRequest message structure."""
request = pb2.GenerateRequest(
request_id="test-123",
tokenized=pb2.TokenizedInput(
input_token_ids=[1, 2, 3, 4, 5],
original_text="Hello world",
),
sampling_config=pb2.SamplingConfig(temperature=0.8),
max_tokens=50,
streaming=True,
)
assert request.request_id == "test-123"
assert list(request.tokenized.input_token_ids) == [1, 2, 3, 4, 5]
assert request.tokenized.original_text == "Hello world"
assert request.sampling_config.temperature == 0.8
assert request.max_tokens == 50
assert request.streaming is True
def test_generate_response_chunk(self):
"""Test GenerateResponse with chunk."""
response = pb2.GenerateResponse(
request_id="test-123",
chunk=pb2.GenerateStreamChunk(
token_ids=[10, 11, 12],
sequence_index=0,
prompt_tokens=5,
completion_tokens=3,
),
)
assert response.request_id == "test-123"
assert list(response.chunk.token_ids) == [10, 11, 12]
assert response.chunk.prompt_tokens == 5
assert response.chunk.completion_tokens == 3
def test_generate_response_complete(self):
"""Test GenerateResponse with complete."""
response = pb2.GenerateResponse(
request_id="test-123",
complete=pb2.GenerateComplete(
output_token_ids=[10, 11, 12, 13],
finish_reason="stop",
prompt_tokens=5,
completion_tokens=4,
),
)
assert response.request_id == "test-123"
assert list(response.complete.output_token_ids) == [10, 11, 12, 13]
assert response.complete.finish_reason == "stop"
def test_health_check_messages(self):
"""Test HealthCheck messages."""
_request = pb2.HealthCheckRequest() # noqa: F841 - verify message construction
response = pb2.HealthCheckResponse(status="healthy")
assert response.status == "healthy"
def test_model_info_response(self):
"""Test GetModelInfoResponse message."""
response = pb2.GetModelInfoResponse(
model_id="meta-llama/Llama-2-7b",
max_input_len=4096,
max_seq_len=8192,
vocab_size=32000,
)
assert response.model_id == "meta-llama/Llama-2-7b"
assert response.max_input_len == 4096
assert response.max_seq_len == 8192
assert response.vocab_size == 32000
def test_server_info_response(self):
"""Test GetServerInfoResponse message."""
response = pb2.GetServerInfoResponse(
version="0.17.0",
backend="tensorrt-llm",
tensor_parallel_size=2,
pipeline_parallel_size=1,
world_size=2,
)
assert response.version == "0.17.0"
assert response.backend == "tensorrt-llm"
assert response.tensor_parallel_size == 2
assert response.world_size == 2
def test_embed_messages(self):
"""Test Embed request and response messages."""
request = pb2.EmbedRequest(
request_id="embed-123",
tokenized=pb2.TokenizedInput(input_token_ids=[1, 2, 3]),
)
response = pb2.EmbedResponse(
request_id="embed-123",
embedding=[0.1, 0.2, 0.3, 0.4],
prompt_tokens=3,
)
assert request.request_id == "embed-123"
assert response.request_id == "embed-123"
assert list(response.embedding) == [0.1, 0.2, 0.3, 0.4]
assert response.prompt_tokens == 3
def test_abort_messages(self):
"""Test Abort request and response messages."""
request = pb2.AbortRequest(request_id="abort-123")
response = pb2.AbortResponse(success=True, message="Request aborted")
assert request.request_id == "abort-123"
assert response.success is True
assert response.message == "Request aborted"
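
The suite above does not exercise GenerateError or PerfMetrics. A hedged sketch of what such checks could look like follows; it is illustrative only (not part of the committed test file) and assumes the PerfMetrics timestamps are comparable values in seconds on a common clock, so that TTFT can be derived as first_token_time - arrival_time.

# Illustrative additions only; mirrors the suite's imports for self-containment.
import pytest

from tensorrt_llm.grpc import trtllm_service_pb2 as pb2


def test_generate_response_error():
    """GenerateResponse with error populates exactly one oneof branch."""
    response = pb2.GenerateResponse(
        request_id="test-err",
        error=pb2.GenerateError(
            message="bad request", type="InvalidRequest", code=400),
    )
    assert response.WhichOneof("response") == "error"
    assert response.error.type == "InvalidRequest"
    assert response.error.code == 400


def test_perf_metrics_latency_derivation():
    """Derive TTFT and end-to-end latency from PerfMetrics timestamps."""
    metrics = pb2.PerfMetrics(
        arrival_time=100.0,
        first_scheduled_time=100.5,
        first_token_time=101.0,
        last_token_time=103.0,
    )
    # Assumption: timestamps share a clock and are expressed in seconds.
    ttft = metrics.first_token_time - metrics.arrival_time
    e2e = metrics.last_token_time - metrics.arrival_time
    assert ttft == pytest.approx(1.0)
    assert e2e == pytest.approx(3.0)
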