[None][feat] Add gRPC server for high-performance external router integration (#11037)
Signed-off-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
parent e033929221
commit dbad94715b

setup.py (44 changed lines)
@@ -18,9 +18,52 @@ from pathlib import Path
from typing import List

from setuptools import find_packages, setup
from setuptools.command.build_py import build_py
from setuptools.dist import Distribution


class BuildPyWithProtoCompile(build_py):
    """Custom build_py command that compiles protobuf files before building."""

    def run(self):
        self.compile_grpc_protos()
        super().run()

    def compile_grpc_protos(self):
        """Compile gRPC protobuf files if the proto file exists."""
        grpc_dir = Path(__file__).parent / "tensorrt_llm" / "grpc"
        proto_file = grpc_dir / "trtllm_service.proto"
        compile_script = grpc_dir / "compile_protos.py"

        if not proto_file.exists():
            return

        # Check if pb2 files need to be generated
        pb2_file = grpc_dir / "trtllm_service_pb2.py"
        pb2_grpc_file = grpc_dir / "trtllm_service_pb2_grpc.py"

        # Regenerate if pb2 files don't exist or are older than the proto file
        needs_compile = (not pb2_file.exists() or not pb2_grpc_file.exists() or
                         pb2_file.stat().st_mtime < proto_file.stat().st_mtime)

        if needs_compile and compile_script.exists():
            import subprocess
            import sys

            print("Compiling gRPC protobuf files...")
            try:
                subprocess.run(
                    [sys.executable, str(compile_script)],
                    check=True,
                    cwd=str(grpc_dir.parent.parent),
                )
                print("gRPC protobuf compilation successful")
            except subprocess.CalledProcessError as e:
                print(f"Warning: Failed to compile gRPC protos: {e}")
            except Exception as e:
                print(f"Warning: gRPC proto compilation skipped: {e}")


def parse_requirements(filename: os.PathLike):
    with open(filename) as f:
        requirements = f.read().splitlines()
@@ -374,6 +417,7 @@ packages += find_packages(include=["triton_kernels", "triton_kernels.*"])
setup(
    name='tensorrt_llm',
    version=get_version(),
    cmdclass={'build_py': BuildPyWithProtoCompile},
    description=
    ('TensorRT LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and supports '
     'state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs.'
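For local development, the same regeneration can be triggered by hand. A minimal sketch (not part of this diff), assuming a source checkout with the proto file and compile script in place; it mirrors what the build hook above does:

# Regenerate the gRPC stubs outside of a package build (sketch).
import subprocess
import sys
from pathlib import Path

repo_root = Path(".")  # assumption: run from the repository root
compile_script = repo_root / "tensorrt_llm" / "grpc" / "compile_protos.py"
subprocess.run([sys.executable, str(compile_script)], check=True, cwd=repo_root)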
@@ -233,6 +233,124 @@ def launch_server(
    asyncio.run(server(host, port, sockets=[s]))


def launch_grpc_server(host: str, port: int, llm_args: dict):
    """
    Launch a gRPC server for TensorRT-LLM.

    This provides a high-performance gRPC interface designed for external routers
    (e.g., sgl-router) using pre-tokenized input and raw token ID output.

    Args:
        host: Host to bind to
        port: Port to bind to
        llm_args: Arguments for LLM initialization (from get_llm_args)
    """
    import grpc

    try:
        from grpc_reflection.v1alpha import reflection
        REFLECTION_AVAILABLE = True
    except ImportError:
        REFLECTION_AVAILABLE = False

    from tensorrt_llm.grpc import trtllm_service_pb2, trtllm_service_pb2_grpc
    from tensorrt_llm.grpc.grpc_request_manager import GrpcRequestManager
    from tensorrt_llm.grpc.grpc_servicer import TrtllmServiceServicer

    async def serve_grpc_async():
        logger.info("Initializing TensorRT-LLM gRPC server...")

        backend = llm_args.get("backend")
        model_path = llm_args.get("model", "")

        if backend == "pytorch":
            llm_args.pop("build_config", None)
            llm = PyTorchLLM(**llm_args)
        elif backend == "_autodeploy":
            from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
            llm_args.pop("build_config", None)
            llm = AutoDeployLLM(**llm_args)
        elif backend == "tensorrt" or backend == "trt":
            llm_args.pop("backend")
            llm = LLM(**llm_args)
        else:
            raise click.BadParameter(
                f"{backend} is not a known backend, check help for available options.",
                param_hint="backend")

        logger.info("Model loaded successfully")

        # Create request manager
        request_manager = GrpcRequestManager(llm)

        # Create servicer
        servicer = TrtllmServiceServicer(request_manager, model_path=model_path)

        # Create gRPC server
        server = grpc.aio.server(
            options=[
                ("grpc.max_send_message_length", -1),  # Unlimited
                ("grpc.max_receive_message_length", -1),  # Unlimited
                ("grpc.keepalive_time_ms", 30000),  # 30s keepalive
                ("grpc.keepalive_timeout_ms", 10000),  # 10s timeout
            ], )

        # Add servicer to server
        trtllm_service_pb2_grpc.add_TrtllmServiceServicer_to_server(
            servicer, server)

        # Enable reflection for grpcurl and other tools
        if REFLECTION_AVAILABLE:
            service_names = (
                trtllm_service_pb2.DESCRIPTOR.services_by_name["TrtllmService"].
                full_name,
                reflection.SERVICE_NAME,
            )
            reflection.enable_server_reflection(service_names, server)
            logger.info("gRPC reflection enabled")

        # Bind to address
        address = f"{host}:{port}"
        server.add_insecure_port(address)

        # Start server
        await server.start()
        logger.info(f"TensorRT-LLM gRPC server started on {address}")
        logger.info("Server is ready to accept requests")

        # Handle shutdown signals
        loop = asyncio.get_running_loop()
        stop_event = asyncio.Event()

        def signal_handler():
            logger.info("Received shutdown signal")
            stop_event.set()

        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, signal_handler)

        # Serve until shutdown signal
        try:
            await stop_event.wait()
        except KeyboardInterrupt:
            logger.info("Interrupted by user")
        finally:
            logger.info("Shutting down TensorRT-LLM gRPC server...")

            # Stop gRPC server
            await server.stop(grace=5.0)
            logger.info("gRPC server stopped")

            # Shutdown LLM
            if hasattr(llm, "shutdown"):
                llm.shutdown()
                logger.info("LLM engine stopped")

            logger.info("Shutdown complete")

    asyncio.run(serve_grpc_async())


def launch_mm_encoder_server(
        host: str,
        port: int,
@@ -476,6 +594,12 @@ class ChoiceWithAlias(click.Choice):
    "Specify a custom chat template. "
    "Can be a file path or one-liner template string",
    "prototype"))
@click.option(
    "--grpc",
    is_flag=True,
    default=False,
    help="Run gRPC server instead of OpenAI HTTP server. "
    "gRPC server accepts pre-tokenized requests and returns raw token IDs.")
def serve(
        model: str, tokenizer: Optional[str], custom_tokenizer: Optional[str],
        host: str, port: int, log_level: str, backend: str, max_beam_width: int,
@@ -491,8 +615,9 @@ def serve(
        fail_fast_on_attention_window_too_large: bool,
        otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool,
        disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str],
        custom_module_dirs: list[Path], chat_template: Optional[str]):
    """Running an OpenAI API compatible server
        custom_module_dirs: list[Path], chat_template: Optional[str],
        grpc: bool):
    """Running an OpenAI API compatible server (or gRPC server with --grpc flag)

    MODEL: model name | HF checkpoint path | TensorRT engine path
    """
@@ -569,9 +694,29 @@ def serve(

    multimodal_server_config = MultimodalServerConfig(
        media_io_kwargs=parsed_media_io_kwargs)
    launch_server(host, port, llm_args, tool_parser, chat_template,
                  metadata_server_cfg, server_role, disagg_cluster_config,
                  multimodal_server_config)

    if grpc:
        # gRPC mode: launch gRPC server instead of OpenAI HTTP server
        # Check for unsupported arguments that are silently ignored in gRPC mode
        unsupported_args = {
            "tool_parser": tool_parser,
            "chat_template": chat_template,
            "metadata_server_config_file": metadata_server_config_file,
            "server_role": server_role,
            "disagg_cluster_config": disagg_cluster_config,
        }
        for name, value in unsupported_args.items():
            if value is not None:
                raise ValueError(
                    f"Argument '{name}' is not supported when running in gRPC mode. "
                    f"The gRPC server is designed for use with external routers that handle "
                    f"these features (e.g., tool parsing, chat templates).")
        launch_grpc_server(host, port, llm_args)
    else:
        # Default: launch OpenAI HTTP server
        launch_server(host, port, llm_args, tool_parser, chat_template,
                      metadata_server_cfg, server_role, disagg_cluster_config,
                      multimodal_server_config)


@click.command("mm_embedding_serve")
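For context, a minimal client sketch (not part of this diff) that exercises the Generate RPC started by launch_grpc_server above. It assumes the protobuf stubs have been generated and that the caller already holds token IDs from an external tokenizer or router; the token IDs and address below are placeholders:

# Minimal gRPC client sketch (assumes generated stubs and a running server).
import asyncio

import grpc

from tensorrt_llm.grpc import trtllm_service_pb2 as pb2
from tensorrt_llm.grpc import trtllm_service_pb2_grpc as pb2_grpc


async def main():
    async with grpc.aio.insecure_channel("localhost:50051") as channel:
        stub = pb2_grpc.TrtllmServiceStub(channel)
        request = pb2.GenerateRequest(
            request_id="demo-1",
            # Pre-tokenized input; these token IDs are placeholders.
            tokenized=pb2.TokenizedInput(input_token_ids=[1, 2, 3, 4]),
            sampling_config=pb2.SamplingConfig(beam_width=1),
            max_tokens=16,
            streaming=True,
        )
        # Generate is a server-streaming RPC; responses carry token IDs only.
        async for response in stub.Generate(request):
            if response.HasField("chunk"):
                print("delta tokens:", list(response.chunk.token_ids))
            elif response.HasField("complete"):
                print("finish reason:", response.complete.finish_reason)
            elif response.HasField("error"):
                print("error:", response.error.message)


asyncio.run(main())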
tensorrt_llm/grpc/__init__.py (new file, 121 lines)
@@ -0,0 +1,121 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""TensorRT-LLM gRPC module for high-performance communication with external routers.

This module provides a gRPC server interface that accepts pre-tokenized requests
and returns raw token IDs, enabling efficient binary communication with Rust-based
routers like sgl-router.

Key Features:
- Pre-tokenized input (no Python tokenization overhead)
- Raw token ID output (no Python detokenization overhead)
- Streaming support with delta tokens
- Full sampling parameter support
- Guided decoding (JSON schema, regex, grammar)
- LoRA and prompt tuning support
- Disaggregated inference support

Usage:
    python -m tensorrt_llm.commands.serve /path/to/model \
        --grpc \
        --host 0.0.0.0 \
        --port 50051
"""

from pathlib import Path

# Module directory for proto files
GRPC_MODULE_DIR = Path(__file__).parent

# Proto file path
PROTO_FILE = GRPC_MODULE_DIR / "trtllm_service.proto"

# Try to import generated protobuf modules
try:
    from . import trtllm_service_pb2, trtllm_service_pb2_grpc

    PROTOS_AVAILABLE = True
except ImportError:
    PROTOS_AVAILABLE = False
    trtllm_service_pb2 = None
    trtllm_service_pb2_grpc = None


def compile_protos():
    """Compile protobuf files to generate Python modules.

    Run this function if the generated *_pb2.py files are missing.
    Alternatively, run: python tensorrt_llm/grpc/compile_protos.py
    """
    # compile_protos.py exposes compile_protos(), not main(); import it under
    # an alias so this wrapper resolves correctly.
    from .compile_protos import compile_protos as _compile_protos

    _compile_protos()


def ensure_protos_available():
    """Ensure protobuf modules are available, compiling if necessary."""
    global PROTOS_AVAILABLE, trtllm_service_pb2, trtllm_service_pb2_grpc

    if not PROTOS_AVAILABLE:
        raise ImportError(
            "gRPC protobuf modules are not available. "
            "Please generate them by running: "
            "python tensorrt_llm/grpc/compile_protos.py"
        )


# Try to import request manager
try:
    from .grpc_request_manager import (
        GrpcRequestManager,
        create_disaggregated_params_from_proto,
        create_lora_request_from_proto,
        create_sampling_params_from_proto,
    )

    REQUEST_MANAGER_AVAILABLE = True
except ImportError:
    REQUEST_MANAGER_AVAILABLE = False
    GrpcRequestManager = None
    create_sampling_params_from_proto = None
    create_lora_request_from_proto = None
    create_disaggregated_params_from_proto = None

# Try to import servicer
try:
    from .grpc_servicer import TrtllmServiceServicer

    SERVICER_AVAILABLE = True
except ImportError:
    SERVICER_AVAILABLE = False
    TrtllmServiceServicer = None

__all__ = [
    "GRPC_MODULE_DIR",
    "PROTO_FILE",
    "PROTOS_AVAILABLE",
    "REQUEST_MANAGER_AVAILABLE",
    "SERVICER_AVAILABLE",
    "compile_protos",
    "ensure_protos_available",
    "trtllm_service_pb2",
    "trtllm_service_pb2_grpc",
    "GrpcRequestManager",
    "TrtllmServiceServicer",
    "create_sampling_params_from_proto",
    "create_lora_request_from_proto",
    "create_disaggregated_params_from_proto",
]
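A short sketch of how downstream code might use the availability flags exported above (illustrative only):

# Guard optional gRPC support at the call site (sketch).
from tensorrt_llm.grpc import PROTOS_AVAILABLE, ensure_protos_available

if PROTOS_AVAILABLE:
    from tensorrt_llm.grpc import trtllm_service_pb2 as pb2
else:
    # Raises ImportError with instructions to run compile_protos.py.
    ensure_protos_available()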
tensorrt_llm/grpc/compile_protos.py (new file, 167 lines)
@@ -0,0 +1,167 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Compile Protocol Buffer definitions for TensorRT-LLM gRPC server.

This script generates Python bindings from the trtllm_service.proto file.

Usage:
    python -m tensorrt_llm.grpc.compile_protos

Or directly:
    python compile_protos.py

Requirements:
    pip install grpcio-tools
"""

import subprocess  # nosec B404
import sys
from pathlib import Path


def compile_protos(proto_dir: Path = None, output_dir: Path = None) -> bool:
    """Compile proto files to Python.

    Args:
        proto_dir: Directory containing .proto files. Defaults to this script's directory.
        output_dir: Directory for generated Python files. Defaults to proto_dir.

    Returns:
        True if compilation succeeded, False otherwise.
    """
    if proto_dir is None:
        proto_dir = Path(__file__).parent
    if output_dir is None:
        output_dir = proto_dir

    proto_file = proto_dir / "trtllm_service.proto"

    if not proto_file.exists():
        print(f"Error: Proto file not found: {proto_file}")
        return False

    # Check for grpcio-tools
    try:
        from grpc_tools import protoc
    except ImportError:
        print("grpcio-tools not installed. Installing...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "grpcio-tools"])
        from grpc_tools import protoc

    # Compile proto file
    print(f"Compiling {proto_file}...")

    result = protoc.main(
        [
            "grpc_tools.protoc",
            f"--proto_path={proto_dir}",
            f"--python_out={output_dir}",
            f"--grpc_python_out={output_dir}",
            str(proto_file),
        ]
    )

    if result != 0:
        print(f"Error: protoc failed with code {result}")
        return False

    # Fix imports in generated files (grpc_tools generates absolute imports)
    pb2_file = output_dir / "trtllm_service_pb2.py"
    pb2_grpc_file = output_dir / "trtllm_service_pb2_grpc.py"

    if pb2_grpc_file.exists():
        content = pb2_grpc_file.read_text()
        # Fix import to use relative import
        content = content.replace(
            "import trtllm_service_pb2 as trtllm__service__pb2",
            "from . import trtllm_service_pb2 as trtllm__service__pb2",
        )
        pb2_grpc_file.write_text(content)
        print(f"Fixed imports in {pb2_grpc_file}")

    print("Generated files:")
    print(f"  - {pb2_file}")
    print(f"  - {pb2_grpc_file}")
    print("Proto compilation successful!")

    return True


def verify_generated_files(output_dir: Path = None) -> bool:
    """Verify that generated files can be imported.

    Args:
        output_dir: Directory containing generated files.

    Returns:
        True if files can be imported, False otherwise.
    """
    if output_dir is None:
        output_dir = Path(__file__).parent

    pb2_file = output_dir / "trtllm_service_pb2.py"
    pb2_grpc_file = output_dir / "trtllm_service_pb2_grpc.py"

    if not pb2_file.exists() or not pb2_grpc_file.exists():
        print("Generated files not found. Run compile_protos() first.")
        return False

    # Try to import
    import importlib.util

    try:
        spec = importlib.util.spec_from_file_location("trtllm_service_pb2", pb2_file)
        pb2_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(pb2_module)
        print(f"Successfully imported {pb2_file.name}")

        # Verify key message types exist
        assert hasattr(pb2_module, "GenerateRequest")
        assert hasattr(pb2_module, "GenerateResponse")
        assert hasattr(pb2_module, "SamplingConfig")
        assert hasattr(pb2_module, "OutputConfig")
        print("  - GenerateRequest, GenerateResponse, SamplingConfig, OutputConfig OK")

    except Exception as e:
        print(f"Error importing {pb2_file.name}: {e}")
        return False

    print("Verification successful!")
    return True


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Compile TensorRT-LLM gRPC protos")
    parser.add_argument(
        "--proto-dir", type=Path, default=None, help="Directory containing .proto files"
    )
    parser.add_argument(
        "--output-dir", type=Path, default=None, help="Directory for generated Python files"
    )
    parser.add_argument(
        "--verify", action="store_true", help="Verify generated files after compilation"
    )

    args = parser.parse_args()

    success = compile_protos(args.proto_dir, args.output_dir)

    if success and args.verify:
        success = verify_generated_files(args.output_dir)

    sys.exit(0 if success else 1)
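The script can also be driven programmatically; a small sketch (not part of this diff), assuming the current directory is the repository root:

# Regenerate and verify the stubs from Python (sketch).
from pathlib import Path

from tensorrt_llm.grpc.compile_protos import compile_protos, verify_generated_files

grpc_dir = Path("tensorrt_llm/grpc")  # assumption: run from a source checkout
if compile_protos(proto_dir=grpc_dir, output_dir=grpc_dir):
    verify_generated_files(grpc_dir)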
tensorrt_llm/grpc/grpc_request_manager.py (new file, 420 lines)
@@ -0,0 +1,420 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""gRPC Request Manager for TensorRT-LLM.

Manages request lifecycle for gRPC requests, converting between protobuf
and TensorRT-LLM types. Designed for high-performance communication with
external routers (e.g., sgl-router) using pre-tokenized input.

Key optimization: Sets detokenize=False in SamplingParams to skip
detokenization and return token IDs only.
"""

import asyncio
import traceback
from collections.abc import AsyncGenerator
from typing import Any, Dict, List, Optional, Tuple

from tensorrt_llm.disaggregated_params import DisaggregatedParams
from tensorrt_llm.executor.request import LoRARequest, PromptAdapterRequest
from tensorrt_llm.executor.result import GenerationResult
from tensorrt_llm.llmapi.llm_utils import KvCacheRetentionConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.sampling_params import GuidedDecodingParams, SamplingParams

from . import trtllm_service_pb2 as pb2


class GrpcRequestManager:
    """Manages gRPC request lifecycle for TensorRT-LLM.

    Responsibilities:
    - Convert protobuf requests to TensorRT-LLM types
    - Set detokenize=False in SamplingParams (key optimization!)
    - Submit requests to LLM.generate_async()
    - Stream token IDs (not text) back to gRPC clients
    - Handle abort/cancel operations

    This is modeled after vLLM's GrpcRequestManager but adapted for TensorRT-LLM's
    GenerationResult async iterator pattern.
    """

    def __init__(self, llm: Any):
        """Initialize the request manager.

        Args:
            llm: The TensorRT-LLM LLM instance (tensorrt_llm.LLM or tensorrt_llm._tensorrt_engine.LLM)
        """
        self.llm = llm
        # Track active requests: request_id -> GenerationResult
        self._rid_to_result: Dict[str, GenerationResult] = {}

        logger.info("GrpcRequestManager initialized")

    async def generate(
        self,
        request_id: str,
        prompt_token_ids: List[int],
        sampling_params: SamplingParams,
        streaming: bool = True,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None,
        disaggregated_params: Optional[DisaggregatedParams] = None,
    ) -> AsyncGenerator[GenerationResult, None]:
        """Submit a generation request and stream outputs.

        Args:
            request_id: Unique request identifier (for tracking/abort)
            prompt_token_ids: Pre-tokenized input from Rust router
            sampling_params: Sampling parameters (with detokenize=False!)
            streaming: Whether to stream results
            lora_request: Optional LoRA adapter request
            prompt_adapter_request: Optional prompt adapter request
            kv_cache_retention_config: KV cache retention config
            disaggregated_params: Disaggregated inference params

        Yields:
            GenerationResult objects containing token IDs (text will be empty
            because detokenize=False)
        """
        try:
            # Submit to LLM.generate_async which returns a GenerationResult
            # that is an async iterator
            gen_result = self.llm.generate_async(
                {"prompt_token_ids": prompt_token_ids},
                sampling_params,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
                streaming=streaming,
                kv_cache_retention_config=kv_cache_retention_config,
                disaggregated_params=disaggregated_params,
            )

            # Track the result for potential abort
            self._rid_to_result[request_id] = gen_result

            # Iterate over the async generator
            # GenerationResult implements __aiter__ and __anext__
            async for result in gen_result:
                yield result

                if result.finished:
                    break

        except asyncio.CancelledError:
            logger.info(f"Request {request_id} cancelled by client")
            await self.abort(request_id)
            raise
        except Exception as e:
            logger.error(f"Error in generate for {request_id}: {e}")
            raise
        finally:
            # Cleanup tracking
            self._rid_to_result.pop(request_id, None)

    async def abort(self, request_id: str) -> bool:
        """Abort a running request.

        Args:
            request_id: The request ID to abort

        Returns:
            True if request was found and aborted, False otherwise
        """
        gen_result = self._rid_to_result.get(request_id)

        if gen_result is None:
            logger.debug(f"Abort: request {request_id} not found (may have already completed)")
            return False

        try:
            # GenerationResult has an abort() method
            gen_result.abort()
            self._rid_to_result.pop(request_id, None)
            logger.info(f"Request {request_id} aborted")
            return True
        except Exception as e:
            logger.error(f"Error aborting request {request_id}: {e}")
            self._rid_to_result.pop(request_id, None)
            return False

    async def health_check(self) -> Tuple[bool, str]:
        """Check if the engine is healthy.

        Returns:
            Tuple of (is_healthy, message)
        """
        try:
            if self.llm is None:
                return False, "LLM not initialized"

            # Check if executor is available and not shutdown
            if hasattr(self.llm, "_executor"):
                if self.llm._executor is None or self.llm._executor.is_shutdown():
                    return False, "Executor is shutdown"

            return True, "OK"
        except Exception as e:
            logger.error(f"Health check error: {e}")
            return False, f"Error: {e}"

    def get_model_config(self) -> Dict[str, Any]:
        """Get model configuration information.

        Returns:
            Dictionary with model config details
        """
        config = {
            "model_path": "",
            "is_generation": True,
            "max_context_length": 0,
            "max_seq_len": 0,
            "vocab_size": 0,
            "supports_vision": False,
        }

        try:
            # Try to get model path
            if hasattr(self.llm, "args"):
                if hasattr(self.llm.args, "model"):
                    config["model_path"] = str(self.llm.args.model)

            # Try to get tokenizer info
            if hasattr(self.llm, "tokenizer") and self.llm.tokenizer is not None:
                if hasattr(self.llm.tokenizer, "vocab_size"):
                    config["vocab_size"] = self.llm.tokenizer.vocab_size

            # Try to get max context length from various sources
            if hasattr(self.llm, "args") and self.llm.args is not None:
                args = self.llm.args
                # Try max_input_len first (input context)
                if hasattr(args, "max_input_len") and args.max_input_len:
                    config["max_context_length"] = args.max_input_len
                # Try max_seq_len (total sequence including output)
                if hasattr(args, "max_seq_len") and args.max_seq_len:
                    config["max_seq_len"] = args.max_seq_len

            # Check for multimodal support
            if hasattr(self.llm, "input_processor"):
                processor_name = type(self.llm.input_processor).__name__
                config["supports_vision"] = processor_name != "DefaultInputProcessor"

        except Exception as e:
            logger.warning(
                f"Error getting model config: {type(e).__name__}: {e}\n{traceback.format_exc()}"
            )

        return config

    def get_num_unfinished_requests(self) -> int:
        """Get the number of currently running requests.

        Returns:
            Number of unfinished requests
        """
        return len(self._rid_to_result)


def create_sampling_params_from_proto(
    proto_config: pb2.SamplingConfig,
    output_config: pb2.OutputConfig,
    max_tokens: int,
    end_id: Optional[int] = None,
    pad_id: Optional[int] = None,
    bad_words: Optional[List[pb2.TokenSequence]] = None,
    stop_words: Optional[List[pb2.TokenSequence]] = None,
    guided_decoding: Optional[pb2.GuidedDecodingParams] = None,
    embedding_bias: Optional[List[float]] = None,
) -> SamplingParams:
    """Convert protobuf configuration to TensorRT-LLM SamplingParams.

    Args:
        proto_config: Protobuf SamplingConfig message
        output_config: Protobuf OutputConfig message
        max_tokens: Maximum tokens to generate
        end_id: End-of-sequence token ID
        pad_id: Padding token ID
        bad_words: Bad word token sequences
        stop_words: Stop word token sequences
        guided_decoding: Guided decoding parameters
        embedding_bias: Embedding bias tensor

    Returns:
        TensorRT-LLM SamplingParams with detokenize=False
    """
    # Build kwargs for SamplingParams
    # KEY OPTIMIZATION: detokenize=False skips Python detokenization!
    kwargs = {
        "max_tokens": max_tokens,
        "detokenize": False,
    }

    # Beam search / sampling
    if proto_config.beam_width > 1:
        kwargs["beam_width"] = proto_config.beam_width
    if proto_config.num_return_sequences > 0:
        kwargs["n"] = proto_config.num_return_sequences

    # Temperature and sampling parameters (with sensible defaults as safety guard)
    kwargs["temperature"] = (
        proto_config.temperature if proto_config.HasField("temperature") else 1.0
    )
    kwargs["top_p"] = proto_config.top_p if proto_config.HasField("top_p") else 1.0
    if proto_config.HasField("top_k"):
        kwargs["top_k"] = proto_config.top_k
    if proto_config.HasField("min_p"):
        kwargs["min_p"] = proto_config.min_p

    # Top-P decay parameters
    if proto_config.HasField("top_p_min"):
        kwargs["top_p_min"] = proto_config.top_p_min
    if proto_config.HasField("top_p_reset_ids"):
        kwargs["top_p_reset_ids"] = proto_config.top_p_reset_ids
    if proto_config.HasField("top_p_decay"):
        kwargs["top_p_decay"] = proto_config.top_p_decay

    # Seed for reproducibility
    if proto_config.HasField("seed"):
        kwargs["random_seed"] = proto_config.seed

    # Min/max tokens
    if proto_config.HasField("min_tokens"):
        kwargs["min_tokens"] = proto_config.min_tokens

    # Penalties (repetition_penalty defaults to 1.0 = no penalty)
    kwargs["repetition_penalty"] = (
        proto_config.repetition_penalty if proto_config.HasField("repetition_penalty") else 1.0
    )
    if proto_config.HasField("presence_penalty"):
        kwargs["presence_penalty"] = proto_config.presence_penalty
    if proto_config.HasField("frequency_penalty"):
        kwargs["frequency_penalty"] = proto_config.frequency_penalty

    # Beam search parameters
    if proto_config.HasField("beam_search_diversity_rate"):
        kwargs["beam_search_diversity_rate"] = proto_config.beam_search_diversity_rate
    if proto_config.HasField("length_penalty"):
        kwargs["length_penalty"] = proto_config.length_penalty
    if proto_config.HasField("early_stopping"):
        kwargs["early_stopping"] = proto_config.early_stopping

    # N-gram blocking
    if proto_config.HasField("no_repeat_ngram_size"):
        kwargs["no_repeat_ngram_size"] = proto_config.no_repeat_ngram_size

    # End/pad tokens
    if end_id is not None:
        kwargs["end_id"] = end_id
        if end_id == -1:
            kwargs["ignore_eos"] = True
    if pad_id is not None:
        kwargs["pad_id"] = pad_id

    # Output configuration - logprobs
    if output_config.HasField("logprobs"):
        kwargs["logprobs"] = output_config.logprobs
    if output_config.HasField("prompt_logprobs"):
        kwargs["prompt_logprobs"] = output_config.prompt_logprobs
    if output_config.return_context_logits:
        kwargs["return_context_logits"] = True
    if output_config.return_generation_logits:
        kwargs["return_generation_logits"] = True
    if output_config.exclude_input_from_output:
        kwargs["exclude_input_from_output"] = True

    # Stop sequences (as token ID lists)
    if stop_words:
        kwargs["stop_words"] = [list(seq.token_ids) for seq in stop_words]
    if bad_words:
        kwargs["bad_words"] = [list(seq.token_ids) for seq in bad_words]

    # Embedding bias
    if embedding_bias:
        kwargs["embedding_bias"] = list(embedding_bias)

    # Guided decoding
    if guided_decoding and guided_decoding.guide:
        guide_type = guided_decoding.guide_type
        guide_content = guided_decoding.guide

        if guide_type == pb2.GuidedDecodingParams.GUIDE_TYPE_JSON:
            # json_object=True for JSON validation without schema constraint
            kwargs["guided_decoding_params"] = GuidedDecodingParams(json_object=True)
        elif guide_type == pb2.GuidedDecodingParams.GUIDE_TYPE_JSON_SCHEMA:
            kwargs["guided_decoding_params"] = GuidedDecodingParams(json_schema=guide_content)
        elif guide_type == pb2.GuidedDecodingParams.GUIDE_TYPE_REGEX:
            kwargs["guided_decoding_params"] = GuidedDecodingParams(regex=guide_content)
        elif guide_type == pb2.GuidedDecodingParams.GUIDE_TYPE_EBNF_GRAMMAR:
            kwargs["guided_decoding_params"] = GuidedDecodingParams(grammar=guide_content)

    return SamplingParams(**kwargs)


def create_lora_request_from_proto(
    proto_config: Optional[pb2.LoraConfig],
) -> Optional[LoRARequest]:
    """Convert protobuf LoraConfig to TensorRT-LLM LoRARequest.

    Args:
        proto_config: Protobuf LoraConfig message

    Returns:
        LoRARequest or None
    """
    if proto_config is None or proto_config.task_id == 0:
        return None

    return LoRARequest(
        lora_name=f"lora_{proto_config.task_id}",
        lora_int_id=proto_config.task_id,
    )


def create_disaggregated_params_from_proto(
    proto_config: Optional[pb2.DisaggregatedParams],
) -> Optional[DisaggregatedParams]:
    """Convert protobuf DisaggregatedParams to TensorRT-LLM DisaggregatedParams.

    Args:
        proto_config: Protobuf DisaggregatedParams message

    Returns:
        DisaggregatedParams or None
    """
    if proto_config is None:
        return None

    request_type_map = {
        pb2.DisaggregatedParams.REQUEST_TYPE_CONTEXT_AND_GENERATION: "context_and_generation",
        pb2.DisaggregatedParams.REQUEST_TYPE_CONTEXT_ONLY: "context_only",
        pb2.DisaggregatedParams.REQUEST_TYPE_GENERATION_ONLY: "generation_only",
    }

    request_type = request_type_map.get(proto_config.request_type, "context_and_generation")

    params = DisaggregatedParams(request_type=request_type)

    if proto_config.ctx_request_id:
        params.ctx_request_id = proto_config.ctx_request_id

    if proto_config.HasField("context_phase_params"):
        ctx_params = proto_config.context_phase_params
        params.first_gen_token_id = ctx_params.first_gen_token_id
        if ctx_params.kv_cache_blocks:
            params.kv_cache_blocks = ctx_params.kv_cache_blocks

    return params
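To illustrate the conversion path, a sketch (not part of this diff) of building SamplingParams from protobuf settings the way the servicer does; the field values are placeholders:

# Convert protobuf sampling settings into SamplingParams (sketch).
from tensorrt_llm.grpc import trtllm_service_pb2 as pb2
from tensorrt_llm.grpc.grpc_request_manager import create_sampling_params_from_proto

proto_sampling = pb2.SamplingConfig(beam_width=1, temperature=0.7, top_p=0.95)
proto_output = pb2.OutputConfig()

sampling_params = create_sampling_params_from_proto(
    proto_config=proto_sampling,
    output_config=proto_output,
    max_tokens=64,
)
# The helper sets detokenize=False, so results carry token IDs only.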
tensorrt_llm/grpc/grpc_servicer.py (new file, 545 lines)
@@ -0,0 +1,545 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""gRPC Servicer for TensorRT-LLM.
|
||||
|
||||
Implements the TrtllmService gRPC service for high-performance communication
|
||||
with external routers (e.g., sgl-router) using pre-tokenized input.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import List, Union
|
||||
|
||||
import grpc
|
||||
|
||||
from tensorrt_llm.executor.result import Logprob, TokenLogprobs
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
from . import trtllm_service_pb2, trtllm_service_pb2_grpc
|
||||
from .grpc_request_manager import (
|
||||
GrpcRequestManager,
|
||||
create_disaggregated_params_from_proto,
|
||||
create_lora_request_from_proto,
|
||||
create_sampling_params_from_proto,
|
||||
)
|
||||
|
||||
|
||||
class TrtllmServiceServicer(trtllm_service_pb2_grpc.TrtllmServiceServicer):
|
||||
"""gRPC servicer implementing the TrtLlmEngine service.
|
||||
|
||||
Handles RPCs:
|
||||
- Generate: Streaming text generation
|
||||
- Embed: Embeddings (for embedding models)
|
||||
- HealthCheck: Health probe
|
||||
- Abort: Cancel a request
|
||||
- GetModelInfo: Model metadata
|
||||
- GetServerInfo: Server state
|
||||
"""
|
||||
|
||||
def __init__(self, request_manager: GrpcRequestManager, model_path: str = ""):
|
||||
"""Initialize the servicer.
|
||||
|
||||
Args:
|
||||
request_manager: The GrpcRequestManager instance
|
||||
model_path: Path to the model (for metadata)
|
||||
"""
|
||||
self.request_manager = request_manager
|
||||
self.model_path = model_path
|
||||
self._start_time = time.time()
|
||||
logger.info("TrtllmServiceServicer initialized")
|
||||
|
||||
async def Generate(
|
||||
self,
|
||||
request: trtllm_service_pb2.GenerateRequest,
|
||||
context: grpc.aio.ServicerContext,
|
||||
) -> AsyncGenerator[trtllm_service_pb2.GenerateResponse, None]:
|
||||
"""Handle streaming generation requests.
|
||||
|
||||
Args:
|
||||
request: The GenerateRequest protobuf
|
||||
context: gRPC context
|
||||
|
||||
Yields:
|
||||
GenerateResponse protobuf messages (streaming)
|
||||
"""
|
||||
request_id = request.request_id
|
||||
logger.info(f"Generate request {request_id} received")
|
||||
|
||||
try:
|
||||
# Extract tokenized input (required)
|
||||
if not request.HasField("tokenized"):
|
||||
yield self._error_response(
|
||||
request_id,
|
||||
"Missing tokenized input",
|
||||
"INVALID_REQUEST",
|
||||
400,
|
||||
)
|
||||
return
|
||||
|
||||
prompt_token_ids = list(request.tokenized.input_token_ids)
|
||||
|
||||
# Build sampling params with detokenize=False (key optimization!)
|
||||
sampling_params = create_sampling_params_from_proto(
|
||||
proto_config=request.sampling_config,
|
||||
output_config=request.output_config,
|
||||
max_tokens=request.max_tokens,
|
||||
end_id=request.end_id if request.HasField("end_id") else None,
|
||||
pad_id=request.pad_id if request.HasField("pad_id") else None,
|
||||
bad_words=list(request.bad_words) if request.bad_words else None,
|
||||
stop_words=list(request.stop_words) if request.stop_words else None,
|
||||
guided_decoding=request.guided_decoding
|
||||
if request.HasField("guided_decoding")
|
||||
else None,
|
||||
embedding_bias=list(request.embedding_bias) if request.embedding_bias else None,
|
||||
)
|
||||
|
||||
# Build LoRA request if present
|
||||
lora_request = create_lora_request_from_proto(
|
||||
request.lora_config if request.HasField("lora_config") else None
|
||||
)
|
||||
|
||||
# Build disaggregated params if present
|
||||
disaggregated_params = create_disaggregated_params_from_proto(
|
||||
request.disaggregated_params if request.HasField("disaggregated_params") else None
|
||||
)
|
||||
|
||||
# Track tokens sent per sequence index to avoid duplicates
|
||||
# TRT-LLM's token_ids_diff doesn't clear between iterations for n>1
|
||||
sent_token_counts: dict[int, int] = {}
|
||||
|
||||
# Submit to request manager and stream outputs
|
||||
# The request manager now yields GenerationResult objects
|
||||
async for gen_result in self.request_manager.generate(
|
||||
request_id=request_id,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
streaming=request.streaming,
|
||||
lora_request=lora_request,
|
||||
disaggregated_params=disaggregated_params,
|
||||
):
|
||||
# Check if client disconnected
|
||||
if context.cancelled():
|
||||
logger.info(f"Client disconnected for {request_id}")
|
||||
await self.request_manager.abort(request_id)
|
||||
return
|
||||
|
||||
# Convert GenerationResult to protobuf response
|
||||
if request.streaming:
|
||||
for chunk_response in self._chunk_responses(
|
||||
request_id, gen_result, prompt_token_ids, sent_token_counts
|
||||
):
|
||||
yield chunk_response
|
||||
|
||||
# Send complete responses when finished (one per sequence for n>1)
|
||||
if gen_result.finished:
|
||||
for complete_response in self._complete_responses(
|
||||
request_id, gen_result, prompt_token_ids
|
||||
):
|
||||
yield complete_response
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Request {request_id} cancelled")
|
||||
await self.request_manager.abort(request_id)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Generate for {request_id}: {e}")
|
||||
yield self._error_response(
|
||||
request_id,
|
||||
str(e),
|
||||
"INTERNAL_ERROR",
|
||||
500,
|
||||
)
|
||||
|
||||
async def Embed(
|
||||
self,
|
||||
request: trtllm_service_pb2.EmbedRequest,
|
||||
context: grpc.aio.ServicerContext,
|
||||
) -> trtllm_service_pb2.EmbedResponse:
|
||||
"""Handle embedding requests.
|
||||
|
||||
Args:
|
||||
request: The EmbedRequest protobuf
|
||||
context: gRPC context
|
||||
|
||||
Returns:
|
||||
EmbedResponse protobuf
|
||||
"""
|
||||
logger.warning("Embed RPC not yet implemented")
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details("Embed RPC not yet implemented")
|
||||
return trtllm_service_pb2.EmbedResponse(
|
||||
request_id=request.request_id,
|
||||
embedding=[],
|
||||
prompt_tokens=0,
|
||||
)
|
||||
|
||||
async def HealthCheck(
|
||||
self,
|
||||
request: trtllm_service_pb2.HealthCheckRequest,
|
||||
context: grpc.aio.ServicerContext,
|
||||
) -> trtllm_service_pb2.HealthCheckResponse:
|
||||
"""Handle health check requests.
|
||||
|
||||
Args:
|
||||
request: The HealthCheckRequest protobuf
|
||||
context: gRPC context
|
||||
|
||||
Returns:
|
||||
HealthCheckResponse protobuf
|
||||
"""
|
||||
is_healthy, message = await self.request_manager.health_check()
|
||||
logger.info(f"HealthCheck: healthy={is_healthy}, message={message}")
|
||||
|
||||
return trtllm_service_pb2.HealthCheckResponse(
|
||||
status=message,
|
||||
)
|
||||
|
||||
async def Abort(
|
||||
self,
|
||||
request: trtllm_service_pb2.AbortRequest,
|
||||
context: grpc.aio.ServicerContext,
|
||||
) -> trtllm_service_pb2.AbortResponse:
|
||||
"""Handle abort requests.
|
||||
|
||||
Args:
|
||||
request: The AbortRequest protobuf
|
||||
context: gRPC context
|
||||
|
||||
Returns:
|
||||
AbortResponse protobuf
|
||||
"""
|
||||
request_id = request.request_id
|
||||
logger.info(f"Abort request for {request_id}")
|
||||
|
||||
success = await self.request_manager.abort(request_id)
|
||||
|
||||
return trtllm_service_pb2.AbortResponse(
|
||||
success=success,
|
||||
message=f"Request {request_id} {'aborted' if success else 'not found'}",
|
||||
)
|
||||
|
||||
async def GetModelInfo(
|
||||
self,
|
||||
request: trtllm_service_pb2.GetModelInfoRequest,
|
||||
context: grpc.aio.ServicerContext,
|
||||
) -> trtllm_service_pb2.GetModelInfoResponse:
|
||||
"""Handle model info requests.
|
||||
|
||||
Args:
|
||||
request: The GetModelInfoRequest protobuf
|
||||
context: gRPC context
|
||||
|
||||
Returns:
|
||||
GetModelInfoResponse protobuf
|
||||
"""
|
||||
model_config = self.request_manager.get_model_config()
|
||||
|
||||
return trtllm_service_pb2.GetModelInfoResponse(
|
||||
model_id=self.model_path or model_config.get("model_path", ""),
|
||||
max_input_len=model_config.get("max_context_length", 0),
|
||||
max_seq_len=model_config.get("max_seq_len", 0)
|
||||
or model_config.get("max_context_length", 0),
|
||||
vocab_size=model_config.get("vocab_size", 0),
|
||||
)
|
||||
|
||||
async def GetServerInfo(
|
||||
self,
|
||||
request: trtllm_service_pb2.GetServerInfoRequest,
|
||||
context: grpc.aio.ServicerContext,
|
||||
) -> trtllm_service_pb2.GetServerInfoResponse:
|
||||
"""Handle server info requests.
|
||||
|
||||
Args:
|
||||
request: The GetServerInfoRequest protobuf
|
||||
context: gRPC context
|
||||
|
||||
Returns:
|
||||
GetServerInfoResponse protobuf
|
||||
"""
|
||||
try:
|
||||
import tensorrt_llm
|
||||
|
||||
version = getattr(tensorrt_llm, "__version__", "unknown")
|
||||
except Exception:
|
||||
version = "unknown"
|
||||
|
||||
# Try to get parallelism info from LLM args
|
||||
tp_size = 1
|
||||
pp_size = 1
|
||||
world_size = 1
|
||||
|
||||
try:
|
||||
llm = self.request_manager.llm
|
||||
if hasattr(llm, "args") and llm.args is not None:
|
||||
args = llm.args
|
||||
if hasattr(args, "tensor_parallel_size") and args.tensor_parallel_size:
|
||||
tp_size = args.tensor_parallel_size
|
||||
if hasattr(args, "pipeline_parallel_size") and args.pipeline_parallel_size:
|
||||
pp_size = args.pipeline_parallel_size
|
||||
world_size = tp_size * pp_size
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get parallelism info: {e}")
|
||||
|
||||
return trtllm_service_pb2.GetServerInfoResponse(
|
||||
version=version,
|
||||
backend="tensorrt-llm",
|
||||
tensor_parallel_size=tp_size,
|
||||
pipeline_parallel_size=pp_size,
|
||||
context_parallel_size=1, # Context parallelism is separate from TP/PP
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
# ========== Helper methods ==========
|
||||
|
||||
def _convert_logprobs_to_proto(
|
||||
self,
|
||||
token_ids: List[int],
|
||||
logprobs: Union[TokenLogprobs, List[float], None],
|
||||
) -> List[trtllm_service_pb2.TokenLogprob]:
|
||||
"""Convert TRT-LLM logprobs to protobuf TokenLogprob messages.
|
||||
|
||||
Handles both formats:
|
||||
- List[float]: Simple logprobs (one per token)
|
||||
- TokenLogprobs (list[dict[int, Logprob]]): Top-k logprobs per position
|
||||
|
||||
Args:
|
||||
token_ids: The sampled token IDs
|
||||
logprobs: Logprobs from TRT-LLM (can be List[float] or TokenLogprobs)
|
||||
|
||||
Returns:
|
||||
List of TokenLogprob proto messages
|
||||
"""
|
||||
if not logprobs or not token_ids:
|
||||
return []
|
||||
|
||||
result = []
|
||||
for i, token_id in enumerate(token_ids):
|
||||
if i >= len(logprobs):
|
||||
break
|
||||
|
||||
lp = logprobs[i]
|
||||
|
||||
if isinstance(lp, dict):
|
||||
# TokenLogprobs format: dict[int, Logprob]
|
||||
# Each entry maps token_id -> Logprob(logprob, rank)
|
||||
token_logprob = trtllm_service_pb2.TokenLogprob(
|
||||
token_id=token_id,
|
||||
logprob=lp[token_id].logprob if token_id in lp else 0.0,
|
||||
)
|
||||
# Add top logprobs (all entries in the dict)
|
||||
for tid, logprob_obj in lp.items():
|
||||
if isinstance(logprob_obj, Logprob):
|
||||
token_logprob.top_logprobs.append(
|
||||
trtllm_service_pb2.TopLogprob(
|
||||
token_id=tid,
|
||||
logprob=logprob_obj.logprob,
|
||||
)
|
||||
)
|
||||
result.append(token_logprob)
|
||||
elif isinstance(lp, (int, float)):
|
||||
# Simple float logprob
|
||||
token_logprob = trtllm_service_pb2.TokenLogprob(
|
||||
token_id=token_id,
|
||||
logprob=float(lp),
|
||||
)
|
||||
result.append(token_logprob)
|
||||
|
||||
return result
|
||||
|
||||
def _chunk_responses(
|
||||
self,
|
||||
request_id: str,
|
||||
gen_result,
|
||||
prompt_token_ids: list,
|
||||
sent_token_counts: dict[int, int],
|
||||
) -> List[trtllm_service_pb2.GenerateResponse]:
|
||||
"""Build streaming chunk responses from GenerationResult.
|
||||
|
||||
Uses cumulative token_ids and tracks sent position to compute true deltas.
|
||||
TRT-LLM's token_ids_diff doesn't clear between iterations for n>1, so we
|
||||
compute deltas ourselves.
|
||||
|
||||
Args:
|
||||
request_id: The request ID
|
||||
gen_result: TensorRT-LLM GenerationResult
|
||||
prompt_token_ids: Original prompt tokens
|
||||
sent_token_counts: Dict tracking tokens already sent per sequence index
|
||||
|
||||
Returns:
|
||||
List of GenerateResponse with chunk field set (one per output)
|
||||
"""
|
||||
responses = []
|
||||
cached_tokens = gen_result.cached_tokens if hasattr(gen_result, "cached_tokens") else 0
|
||||
|
||||
if not gen_result.outputs:
|
||||
# No outputs yet, return empty chunk
|
||||
responses.append(
|
||||
trtllm_service_pb2.GenerateResponse(
|
||||
request_id=request_id,
|
||||
chunk=trtllm_service_pb2.GenerateStreamChunk(
|
||||
token_ids=[],
|
||||
prompt_tokens=len(prompt_token_ids),
|
||||
completion_tokens=0,
|
||||
cached_tokens=cached_tokens,
|
||||
),
|
||||
)
|
||||
)
|
||||
return responses
|
||||
|
||||
# Process all outputs (for n>1 support)
|
||||
for completion in gen_result.outputs:
|
||||
index = completion.index
|
||||
# Use cumulative token_ids and compute delta ourselves
|
||||
# because token_ids_diff doesn't clear between iterations for n>1
|
||||
all_tokens = list(completion.token_ids) if completion.token_ids else []
|
||||
sent_count = sent_token_counts.get(index, 0)
|
||||
delta_tokens = all_tokens[sent_count:]
|
||||
|
||||
# Skip if no new tokens for this sequence
|
||||
if not delta_tokens:
|
||||
continue
|
||||
|
||||
# Update sent count
|
||||
sent_token_counts[index] = len(all_tokens)
|
||||
|
||||
chunk = trtllm_service_pb2.GenerateStreamChunk(
|
||||
token_ids=delta_tokens,
|
||||
sequence_index=completion.index,
|
||||
prompt_tokens=len(prompt_token_ids),
|
||||
completion_tokens=len(completion.token_ids) if completion.token_ids else 0,
|
||||
cached_tokens=cached_tokens,
|
||||
)
|
||||
|
||||
# Add logprobs if available
|
||||
# Note: We compute delta logprobs ourselves since logprobs_diff has same issue as token_ids_diff
|
||||
if completion.logprobs:
|
||||
all_logprobs = completion.logprobs
|
||||
delta_logprobs = all_logprobs[sent_count:] if sent_count < len(all_logprobs) else []
|
||||
proto_logprobs = self._convert_logprobs_to_proto(delta_tokens, delta_logprobs)
|
||||
chunk.logprobs.extend(proto_logprobs)
|
||||
|
||||
responses.append(
|
||||
trtllm_service_pb2.GenerateResponse(
|
||||
request_id=request_id,
|
||||
chunk=chunk,
|
||||
)
|
||||
)
|
||||
|
||||
return responses
|
||||
|
||||
def _complete_responses(
|
||||
self,
|
||||
request_id: str,
|
||||
gen_result,
|
||||
prompt_token_ids: list,
|
||||
) -> List[trtllm_service_pb2.GenerateResponse]:
|
||||
"""Build final completion responses from GenerationResult.
|
||||
|
||||
For n>1, returns one response per output sequence.
|
||||
|
||||
Args:
|
||||
request_id: The request ID
|
||||
gen_result: TensorRT-LLM GenerationResult (finished=True)
|
||||
prompt_token_ids: Original prompt tokens
|
||||
|
||||
Returns:
|
||||
List of GenerateResponse with complete field set (one per output)
|
||||
"""
|
||||
responses = []
|
||||
cached_tokens = gen_result.cached_tokens if hasattr(gen_result, "cached_tokens") else 0
|
||||
|
||||
if not gen_result.outputs:
|
||||
# No outputs, return error response
|
||||
responses.append(
|
||||
trtllm_service_pb2.GenerateResponse(
|
||||
request_id=request_id,
|
||||
complete=trtllm_service_pb2.GenerateComplete(
|
||||
output_token_ids=[],
|
||||
finish_reason="error",
|
||||
prompt_tokens=len(prompt_token_ids),
|
||||
completion_tokens=0,
|
||||
cached_tokens=0,
|
||||
),
|
||||
)
|
||||
)
|
||||
return responses
|
||||
|
||||
# Process all outputs (for n>1 support)
|
||||
for completion in gen_result.outputs:
|
||||
output_tokens = list(completion.token_ids) if completion.token_ids else []
|
||||
|
||||
complete = trtllm_service_pb2.GenerateComplete(
|
||||
output_token_ids=output_tokens,
|
||||
sequence_index=completion.index,
|
||||
finish_reason=completion.finish_reason or "stop",
|
||||
prompt_tokens=len(prompt_token_ids),
|
||||
completion_tokens=len(output_tokens),
|
||||
cached_tokens=cached_tokens,
|
||||
)
|
||||
|
||||
# Add stop reason if available
|
||||
if hasattr(completion, "stop_reason") and completion.stop_reason:
|
||||
complete.stop_reason = str(completion.stop_reason)
|
||||
|
||||
# Add generation logprobs if available
|
||||
if completion.logprobs:
|
||||
proto_logprobs = self._convert_logprobs_to_proto(output_tokens, completion.logprobs)
|
||||
complete.logprobs.extend(proto_logprobs)
|
||||
|
||||
# Add prompt logprobs if available
|
||||
if hasattr(completion, "prompt_logprobs") and completion.prompt_logprobs:
|
||||
# For prompt logprobs, we use the prompt_token_ids
|
||||
proto_prompt_logprobs = self._convert_logprobs_to_proto(
|
||||
prompt_token_ids, completion.prompt_logprobs
|
||||
)
|
||||
complete.prompt_logprobs.extend(proto_prompt_logprobs)
|
||||
|
||||
responses.append(
|
||||
trtllm_service_pb2.GenerateResponse(
|
||||
request_id=request_id,
|
||||
complete=complete,
|
||||
)
|
||||
)
|
||||
|
||||
return responses
|
||||
|
||||
def _error_response(
|
||||
self,
|
||||
request_id: str,
|
||||
message: str,
|
||||
error_type: str,
|
||||
code: int,
|
||||
) -> trtllm_service_pb2.GenerateResponse:
|
||||
"""Build an error response.
|
||||
|
||||
Args:
|
||||
request_id: The request ID
|
||||
message: Error message
|
||||
error_type: Error type string
|
||||
code: Error code
|
||||
|
||||
Returns:
|
||||
GenerateResponse with error field set
|
||||
"""
|
||||
return trtllm_service_pb2.GenerateResponse(
|
||||
request_id=request_id,
|
||||
error=trtllm_service_pb2.GenerateError(
|
||||
message=message,
|
||||
type=error_type,
|
||||
code=code,
|
||||
),
|
||||
)
|
||||
tensorrt_llm/grpc/trtllm_service.proto (new file, 511 lines)
@@ -0,0 +1,511 @@
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto3";
|
||||
package trtllm;
|
||||
|
||||
// TensorRT-LLM gRPC Service
|
||||
//
|
||||
// This service provides high-performance inference for LLMs using TensorRT-LLM.
|
||||
// It accepts pre-tokenized requests and returns raw token IDs, enabling efficient
|
||||
// binary communication with external routers (e.g., sgl-router).
|
||||
//
|
||||
// Key Design Principles:
|
||||
// - Token IDs only: No text in requests or responses (router handles tokenization)
|
||||
// - Streaming delta mode: Chunks contain only new tokens since last chunk
|
||||
// - Full feature support: All TensorRT-LLM capabilities exposed
|
||||
|
||||
service TrtllmService {
|
||||
// Generate tokens from pre-tokenized input
|
||||
// Returns a stream of responses containing token IDs
|
||||
rpc Generate(GenerateRequest) returns (stream GenerateResponse);
|
||||
|
||||
// Generate embeddings from pre-tokenized input (for embedding models)
|
||||
rpc Embed(EmbedRequest) returns (EmbedResponse);
|
||||
|
||||
// Health check endpoint
|
||||
rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse);
|
||||
|
||||
// Abort a running generation request
|
||||
rpc Abort(AbortRequest) returns (AbortResponse);
|
||||
|
||||
// Get model information (vocab size, max lengths, etc.)
|
||||
rpc GetModelInfo(GetModelInfoRequest) returns (GetModelInfoResponse);
|
||||
|
||||
// Get server information (version, parallelism, etc.)
|
||||
rpc GetServerInfo(GetServerInfoRequest) returns (GetServerInfoResponse);
|
||||
}
|
||||
|
||||
// ============================================================================
// Generate Request
// ============================================================================

message GenerateRequest {
  // Unique request identifier (assigned by client/router)
  string request_id = 1;

  // Pre-tokenized input (REQUIRED - router tokenizes)
  TokenizedInput tokenized = 2;

  // Sampling configuration
  SamplingConfig sampling_config = 3;

  // Output configuration
  OutputConfig output_config = 4;

  // Maximum tokens to generate (REQUIRED)
  uint32 max_tokens = 5;

  // Enable streaming mode
  bool streaming = 6;

  // End-of-sequence token ID (optional, set to -1 to ignore EOS)
  optional int32 end_id = 7;

  // Padding token ID
  optional int32 pad_id = 8;

  // Bad word token sequences (generation redirects if these are produced)
  repeated TokenSequence bad_words = 9;

  // Stop word token sequences (generation stops when these are produced)
  repeated TokenSequence stop_words = 10;

  // Guided decoding parameters (JSON schema, regex, grammar)
  optional GuidedDecodingParams guided_decoding = 11;

  // Embedding bias tensor (vocab_size floats, optional)
  repeated float embedding_bias = 12;

  // LoRA adapter configuration
  optional LoraConfig lora_config = 13;

  // Prompt tuning configuration
  optional PromptTuningConfig prompt_tuning_config = 14;

  // Multimodal input (for VLMs)
  optional MultimodalInput multimodal_input = 15;

  // KV cache retention configuration
  optional KvCacheRetentionConfig kv_cache_retention = 16;

  // Disaggregated inference parameters
  optional DisaggregatedParams disaggregated_params = 17;

  // Lookahead decoding configuration
  optional LookaheadConfig lookahead_config = 18;

  // Cache salt ID for cache hashing
  optional int64 cache_salt_id = 19;

  // Request arrival time (unix timestamp for metrics)
  optional double arrival_time = 20;
}

// Tokenized input from router
message TokenizedInput {
  // Original text (for debugging/logging only, not used for generation)
  string original_text = 1;

  // Pre-tokenized input token IDs (REQUIRED)
  repeated uint32 input_token_ids = 2;

  // Query token IDs for VLM star attention (optional)
  repeated uint32 query_token_ids = 3;
}

// Sequence of token IDs (for stop/bad words)
message TokenSequence {
  repeated uint32 token_ids = 1;
}

// ============================================================================
// Sampling Configuration
// Maps to tensorrt_llm.bindings.executor.SamplingConfig
// ============================================================================

message SamplingConfig {
  // Beam width (1 for sampling, >1 for beam search)
  int32 beam_width = 1;

  // Number of sequences to return
  uint32 num_return_sequences = 2;

  // Top-K sampling (0 = disabled, considers all tokens)
  optional int32 top_k = 3;

  // Top-P (nucleus) sampling threshold
  optional float top_p = 4;

  // Top-P minimum threshold for decay
  optional float top_p_min = 5;

  // Top-P reset token IDs
  optional int32 top_p_reset_ids = 6;

  // Top-P decay factor
  optional float top_p_decay = 7;

  // Random seed for reproducibility
  optional uint64 seed = 8;

  // Temperature for sampling (0 = greedy, higher = more random)
  optional float temperature = 9;

  // Minimum tokens to generate before stopping
  optional uint32 min_tokens = 10;

  // Beam search diversity rate
  optional float beam_search_diversity_rate = 11;

  // Repetition penalty (>1 discourages, <1 encourages repetition)
  optional float repetition_penalty = 12;

  // Presence penalty (penalizes tokens that have appeared)
  optional float presence_penalty = 13;

  // Frequency penalty (penalizes based on frequency of appearance)
  optional float frequency_penalty = 14;

  // Number of prompt tokens to ignore for penalties
  optional int32 prompt_ignore_length = 15;

  // Length penalty for beam search
  optional float length_penalty = 16;

  // Early stopping for beam search
  optional int32 early_stopping = 17;

  // No repeat n-gram size
  optional int32 no_repeat_ngram_size = 18;

  // Min-P sampling threshold
  optional float min_p = 19;

  // Variable beam width array for beam search
  repeated int32 beam_width_array = 20;
}

// ============================================================================
// Output Configuration
// Maps to tensorrt_llm.bindings.executor.OutputConfig
// ============================================================================

message OutputConfig {
  // Number of top log probabilities to return per output token
  optional int32 logprobs = 1;

  // Number of top log probabilities to return per prompt token
  optional int32 prompt_logprobs = 2;

  // Return context logits tensor (large, use with caution)
  bool return_context_logits = 3;

  // Return generation logits tensor (large, use with caution)
  bool return_generation_logits = 4;

  // Exclude input tokens from output (set to true to enable; TRT-LLM defaults to true internally)
  bool exclude_input_from_output = 5;

  // Return encoder output (for encoder-decoder models)
  bool return_encoder_output = 6;

  // Return performance metrics
  bool return_perf_metrics = 7;
}

// ============================================================================
// Guided Decoding
// ============================================================================

message GuidedDecodingParams {
  // Guide type enumeration
  enum GuideType {
    GUIDE_TYPE_UNSPECIFIED = 0;
    GUIDE_TYPE_JSON = 1;            // JSON format (any valid JSON)
    GUIDE_TYPE_JSON_SCHEMA = 2;     // JSON with schema constraint
    GUIDE_TYPE_REGEX = 3;           // Regular expression constraint
    GUIDE_TYPE_EBNF_GRAMMAR = 4;    // EBNF grammar constraint
    GUIDE_TYPE_STRUCTURAL_TAG = 5;  // Structural tag (xgrammar backend)
  }

  GuideType guide_type = 1;

  // Guide content (schema string, regex pattern, or grammar definition)
  string guide = 2;
}

// ============================================================================
// LoRA Configuration
// ============================================================================

message LoraConfig {
  // LoRA task/adapter ID
  int64 task_id = 1;

  // LoRA weights (serialized tensor, optional if already cached)
  optional bytes weights = 2;

  // LoRA config as JSON string
  optional string config_json = 3;
}

// ============================================================================
// Prompt Tuning Configuration
// ============================================================================

message PromptTuningConfig {
  // Embedding table (serialized tensor)
  bytes embedding_table = 1;
}

// ============================================================================
// Multimodal Input (for Vision-Language Models)
// ============================================================================

message MultimodalInput {
  // Multimodal content hashes for caching
  repeated int64 multimodal_hashes = 1;

  // Positions in input where multimodal content is inserted
  repeated int32 multimodal_positions = 2;

  // Lengths of multimodal content at each position
  repeated int32 multimodal_lengths = 3;
}

// ============================================================================
// KV Cache Retention
// ============================================================================

message KvCacheRetentionConfig {
  // Retention policy name
  string policy = 1;

  // Additional configuration as JSON string
  string config_json = 2;
}

// ============================================================================
// Disaggregated Inference
// ============================================================================

message DisaggregatedParams {
  // Request type for disaggregated inference
  enum RequestType {
    REQUEST_TYPE_CONTEXT_AND_GENERATION = 0;  // Normal full request
    REQUEST_TYPE_CONTEXT_ONLY = 1;            // Prefill only
    REQUEST_TYPE_GENERATION_ONLY = 2;         // Decode only
  }

  RequestType request_type = 1;

  // Context request ID (links context and generation phases)
  string ctx_request_id = 2;

  // Context phase parameters (for generation_only requests)
  optional ContextPhaseParams context_phase_params = 3;
}

message ContextPhaseParams {
  // First generated token ID from context phase
  uint32 first_gen_token_id = 1;

  // KV cache block pointers (serialized)
  bytes kv_cache_blocks = 2;
}

// ============================================================================
// Lookahead Decoding
// ============================================================================

message LookaheadConfig {
  int32 max_window_size = 1;
  int32 max_ngram_size = 2;
  int32 max_verification_set_size = 3;
}

// ============================================================================
// Generate Response
// ============================================================================

message GenerateResponse {
  // Request ID echo
  string request_id = 1;

  // Response type (oneof ensures exactly one is set)
  oneof response {
    GenerateStreamChunk chunk = 2;  // Streaming delta
    GenerateComplete complete = 3;  // Final response
    GenerateError error = 4;        // Error response
  }
}

// Streaming chunk containing delta tokens (new tokens since last chunk)
message GenerateStreamChunk {
  // NEW token IDs only (delta from previous chunk)
  repeated uint32 token_ids = 1;

  // Beam/sequence index (for beam_width > 1 or n > 1)
  uint32 sequence_index = 2;

  // Token counts for usage tracking
  uint32 prompt_tokens = 3;
  uint32 completion_tokens = 4;
  uint32 cached_tokens = 5;

  // Log probabilities for this chunk's tokens (if requested)
  repeated TokenLogprob logprobs = 6;
}

// Final/complete response with all output tokens
message GenerateComplete {
  // All output token IDs (cumulative, not delta)
  repeated uint32 output_token_ids = 1;

  // Beam/sequence index
  uint32 sequence_index = 2;

  // Finish reason: "stop", "length", "stop_word"
  string finish_reason = 3;

  // Specific stop reason (stop word/token that triggered stop)
  optional string stop_reason = 4;

  // Token counts for usage tracking
  uint32 prompt_tokens = 5;
  uint32 completion_tokens = 6;
  uint32 cached_tokens = 7;

  // Generation log probabilities (if requested)
  repeated TokenLogprob logprobs = 8;

  // Prompt log probabilities (if requested)
  repeated TokenLogprob prompt_logprobs = 9;

  // Performance metrics (if requested)
  optional PerfMetrics perf_metrics = 10;

  // Context logits (if requested) - serialized float tensor
  optional bytes context_logits = 11;

  // Generation logits (if requested) - serialized float tensor
  optional bytes generation_logits = 12;
}

// Token log probability information
message TokenLogprob {
  uint32 token_id = 1;
  float logprob = 2;

  // Top alternative tokens and their log probabilities
  repeated TopLogprob top_logprobs = 3;
}

message TopLogprob {
  uint32 token_id = 1;
  float logprob = 2;
}

// Performance metrics for request
message PerfMetrics {
  double arrival_time = 1;             // When request arrived
  double first_scheduled_time = 2;     // When first scheduled
  double first_token_time = 3;         // Time to first token (TTFT)
  double last_token_time = 4;          // When last token generated
  double kv_cache_transfer_start = 5;  // KV cache transfer start (disagg)
  double kv_cache_transfer_end = 6;    // KV cache transfer end (disagg)
  int64 kv_cache_size = 7;             // KV cache size in bytes
}

// Error response
message GenerateError {
  string message = 1;  // Human-readable error message
  string type = 2;     // Error type (e.g., "InvalidRequest", "InternalError")
  int32 code = 3;      // Error code
}

// ============================================================================
// Embed Request/Response (for embedding models)
// ============================================================================

message EmbedRequest {
  string request_id = 1;
  TokenizedInput tokenized = 2;
}

message EmbedResponse {
  string request_id = 1;
  repeated float embedding = 2;  // Embedding vector
  uint32 prompt_tokens = 3;
}

// ============================================================================
// Health Check
// ============================================================================

message HealthCheckRequest {}

message HealthCheckResponse {
  string status = 1;  // "OK" or error description
}

// ============================================================================
// Abort Request
// ============================================================================

message AbortRequest {
  string request_id = 1;
}

message AbortResponse {
  bool success = 1;
  string message = 2;
}

// ============================================================================
// Model Info
// ============================================================================

message GetModelInfoRequest {}

message GetModelInfoResponse {
  string model_id = 1;        // Model identifier/path
  int32 max_input_len = 2;    // Maximum input length
  int32 max_seq_len = 3;      // Maximum sequence length (input + output)
  int32 max_batch_size = 4;   // Maximum batch size
  int32 vocab_size = 5;       // Vocabulary size
  int32 hidden_size = 6;      // Hidden dimension
  int32 num_layers = 7;       // Number of transformer layers
  int32 num_heads = 8;        // Number of attention heads

  // Supported features
  repeated string supported_features = 9;  // e.g., "lora", "guided_decoding"
}

// ============================================================================
// Server Info
// ============================================================================

message GetServerInfoRequest {}

message GetServerInfoResponse {
  string version = 1;                // TensorRT-LLM version
  string backend = 2;                // "tensorrt" or "pytorch"
  int32 tensor_parallel_size = 3;    // TP size
  int32 pipeline_parallel_size = 4;  // PP size
  int32 context_parallel_size = 5;   // CP size
  int32 world_size = 6;              // Total world size
}
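For orientation, a minimal client sketch against this service. It assumes the generated stubs are importable as tensorrt_llm.grpc.trtllm_service_pb2 / trtllm_service_pb2_grpc (the stub class name TrtllmServiceStub follows the standard protoc gRPC naming) and that a server is listening on localhost:50051; the endpoint and token IDs are placeholder assumptions, not part of this diff. The client sends pre-tokenized input and dispatches on the response oneof.

import grpc

from tensorrt_llm.grpc import trtllm_service_pb2 as pb2
from tensorrt_llm.grpc import trtllm_service_pb2_grpc as pb2_grpc

# Assumed endpoint; substitute whatever host/port the gRPC server was launched with.
channel = grpc.insecure_channel("localhost:50051")
stub = pb2_grpc.TrtllmServiceStub(channel)

request = pb2.GenerateRequest(
    request_id="demo-1",
    tokenized=pb2.TokenizedInput(input_token_ids=[1, 2, 3, 4]),  # placeholder token IDs
    sampling_config=pb2.SamplingConfig(temperature=0.8, top_p=0.9),
    max_tokens=32,
    streaming=True,
)

# Generate is server-streaming; each message sets exactly one field of the oneof.
for response in stub.Generate(request):
    kind = response.WhichOneof("response")
    if kind == "chunk":
        print("delta tokens:", list(response.chunk.token_ids))
    elif kind == "complete":
        print("finish_reason:", response.complete.finish_reason)
    elif kind == "error":
        raise RuntimeError(f"{response.error.type}: {response.error.message}")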
315  tests/unittest/llmapi/test_grpc.py  Normal file
@ -0,0 +1,315 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for gRPC server components."""

import pytest

from tensorrt_llm.grpc import trtllm_service_pb2 as pb2
from tensorrt_llm.grpc.grpc_request_manager import (
    create_disaggregated_params_from_proto,
    create_lora_request_from_proto,
    create_sampling_params_from_proto,
)

pytestmark = pytest.mark.threadleak(enabled=False)


class TestSamplingParamsConversion:
    """Tests for proto to SamplingParams conversion."""

    def test_basic_sampling_config(self):
        """Test basic sampling config conversion."""
        proto_config = pb2.SamplingConfig(
            beam_width=1,
            num_return_sequences=1,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
        )
        output_config = pb2.OutputConfig()

        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
        )

        assert params.max_tokens == 100
        assert params.temperature == 0.7
        assert params.top_k == 50
        assert params.top_p == 0.9

    def test_beam_search_config(self):
        """Test beam search configuration."""
        proto_config = pb2.SamplingConfig(
            beam_width=4,
            num_return_sequences=2,
            length_penalty=1.2,
            early_stopping=1,
        )
        output_config = pb2.OutputConfig()

        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=50,
        )

        assert params.beam_width == 4
        assert params.n == 2
        assert params.length_penalty == 1.2

    def test_penalties_config(self):
        """Test penalty parameters conversion."""
        proto_config = pb2.SamplingConfig(
            repetition_penalty=1.1,
            presence_penalty=0.5,
            frequency_penalty=0.3,
        )
        output_config = pb2.OutputConfig()

        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
        )

        assert params.repetition_penalty == 1.1
        assert params.presence_penalty == 0.5
        assert params.frequency_penalty == 0.3

    def test_logprobs_config(self):
        """Test logprobs configuration."""
        proto_config = pb2.SamplingConfig()
        output_config = pb2.OutputConfig(
            logprobs=5,
            prompt_logprobs=3,
        )

        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
        )

        assert params.logprobs == 5
        assert params.prompt_logprobs == 3

    def test_guided_decoding_json_schema(self):
        """Test guided decoding with JSON schema."""
        proto_config = pb2.SamplingConfig()
        output_config = pb2.OutputConfig()
        guided_decoding = pb2.GuidedDecodingParams(
            guide_type=pb2.GuidedDecodingParams.GUIDE_TYPE_JSON_SCHEMA,
            guide='{"type": "object", "properties": {"name": {"type": "string"}}}',
        )

        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
            guided_decoding=guided_decoding,
        )

        assert params.guided_decoding_params is not None
        assert params.guided_decoding_params.json_schema is not None

    def test_guided_decoding_regex(self):
        """Test guided decoding with regex."""
        proto_config = pb2.SamplingConfig()
        output_config = pb2.OutputConfig()
        guided_decoding = pb2.GuidedDecodingParams(
            guide_type=pb2.GuidedDecodingParams.GUIDE_TYPE_REGEX,
            guide=r"\d{3}-\d{4}",
        )

        params = create_sampling_params_from_proto(
            proto_config=proto_config,
            output_config=output_config,
            max_tokens=100,
            guided_decoding=guided_decoding,
        )

        assert params.guided_decoding_params is not None
        assert params.guided_decoding_params.regex is not None


class TestLoraRequestConversion:
    """Tests for proto to LoRARequest conversion."""

    def test_basic_lora_config(self):
        """Test basic LoRA config conversion."""
        lora_config = pb2.LoraConfig(task_id=123)

        request = create_lora_request_from_proto(lora_config)

        assert request is not None
        assert request.task_id == 123

    def test_none_lora_config(self):
        """Test None LoRA config returns None."""
        request = create_lora_request_from_proto(None)
        assert request is None


class TestDisaggregatedParamsConversion:
    """Tests for proto to DisaggregatedParams conversion."""

    def test_context_only_request(self):
        """Test context-only disaggregated request."""
        proto_params = pb2.DisaggregatedParams(
            request_type=pb2.DisaggregatedParams.REQUEST_TYPE_CONTEXT_ONLY,
            ctx_request_id="ctx-123",
        )

        params = create_disaggregated_params_from_proto(proto_params)

        assert params is not None
        assert params.ctx_request_id == "ctx-123"

    def test_generation_only_request(self):
        """Test generation-only disaggregated request."""
        proto_params = pb2.DisaggregatedParams(
            request_type=pb2.DisaggregatedParams.REQUEST_TYPE_GENERATION_ONLY,
            ctx_request_id="gen-456",
        )

        params = create_disaggregated_params_from_proto(proto_params)

        assert params is not None

    def test_none_params(self):
        """Test None disaggregated params returns None."""
        params = create_disaggregated_params_from_proto(None)
        assert params is None


class TestProtoMessages:
    """Tests for proto message structure."""

    def test_generate_request_structure(self):
        """Test GenerateRequest message structure."""
        request = pb2.GenerateRequest(
            request_id="test-123",
            tokenized=pb2.TokenizedInput(
                input_token_ids=[1, 2, 3, 4, 5],
                original_text="Hello world",
            ),
            sampling_config=pb2.SamplingConfig(temperature=0.8),
            max_tokens=50,
            streaming=True,
        )

        assert request.request_id == "test-123"
        assert list(request.tokenized.input_token_ids) == [1, 2, 3, 4, 5]
        assert request.tokenized.original_text == "Hello world"
        assert request.sampling_config.temperature == 0.8
        assert request.max_tokens == 50
        assert request.streaming is True

    def test_generate_response_chunk(self):
        """Test GenerateResponse with chunk."""
        response = pb2.GenerateResponse(
            request_id="test-123",
            chunk=pb2.GenerateStreamChunk(
                token_ids=[10, 11, 12],
                sequence_index=0,
                prompt_tokens=5,
                completion_tokens=3,
            ),
        )

        assert response.request_id == "test-123"
        assert list(response.chunk.token_ids) == [10, 11, 12]
        assert response.chunk.prompt_tokens == 5
        assert response.chunk.completion_tokens == 3

    def test_generate_response_complete(self):
        """Test GenerateResponse with complete."""
        response = pb2.GenerateResponse(
            request_id="test-123",
            complete=pb2.GenerateComplete(
                output_token_ids=[10, 11, 12, 13],
                finish_reason="stop",
                prompt_tokens=5,
                completion_tokens=4,
            ),
        )

        assert response.request_id == "test-123"
        assert list(response.complete.output_token_ids) == [10, 11, 12, 13]
        assert response.complete.finish_reason == "stop"

    def test_health_check_messages(self):
        """Test HealthCheck messages."""
        _request = pb2.HealthCheckRequest()  # noqa: F841 - verify message construction
        response = pb2.HealthCheckResponse(status="healthy")

        assert response.status == "healthy"

    def test_model_info_response(self):
        """Test GetModelInfoResponse message."""
        response = pb2.GetModelInfoResponse(
            model_id="meta-llama/Llama-2-7b",
            max_input_len=4096,
            max_seq_len=8192,
            vocab_size=32000,
        )

        assert response.model_id == "meta-llama/Llama-2-7b"
        assert response.max_input_len == 4096
        assert response.max_seq_len == 8192
        assert response.vocab_size == 32000

    def test_server_info_response(self):
        """Test GetServerInfoResponse message."""
        response = pb2.GetServerInfoResponse(
            version="0.17.0",
            backend="tensorrt-llm",
            tensor_parallel_size=2,
            pipeline_parallel_size=1,
            world_size=2,
        )

        assert response.version == "0.17.0"
        assert response.backend == "tensorrt-llm"
        assert response.tensor_parallel_size == 2
        assert response.world_size == 2

    def test_embed_messages(self):
        """Test Embed request and response messages."""
        request = pb2.EmbedRequest(
            request_id="embed-123",
            tokenized=pb2.TokenizedInput(input_token_ids=[1, 2, 3]),
        )
        response = pb2.EmbedResponse(
            request_id="embed-123",
            embedding=[0.1, 0.2, 0.3, 0.4],
            prompt_tokens=3,
        )

        assert request.request_id == "embed-123"
        assert response.request_id == "embed-123"
        assert list(response.embedding) == [0.1, 0.2, 0.3, 0.4]
        assert response.prompt_tokens == 3

    def test_abort_messages(self):
        """Test Abort request and response messages."""
        request = pb2.AbortRequest(request_id="abort-123")
        response = pb2.AbortResponse(success=True, message="Request aborted")

        assert request.request_id == "abort-123"
        assert response.success is True
        assert response.message == "Request aborted"
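These tests exercise only the proto messages and the conversion helpers, so they can be run on their own once the pb2 modules have been generated. A minimal invocation sketch (the path matches this diff; running from the repository root is assumed):

import pytest

# Run just the gRPC unit tests in verbose mode.
raise SystemExit(pytest.main(["tests/unittest/llmapi/test_grpc.py", "-v"]))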