TensorRT-LLMs/tensorrt_llm/grpc/grpc_servicer.py

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""gRPC Servicer for TensorRT-LLM.
Implements the TrtllmService gRPC service for high-performance communication
with external routers (e.g., sgl-router) using pre-tokenized input.
"""
import asyncio
import time
from collections.abc import AsyncGenerator
from typing import List, Union
import grpc
from tensorrt_llm.executor.result import Logprob, TokenLogprobs
from tensorrt_llm.logger import logger
from . import trtllm_service_pb2, trtllm_service_pb2_grpc
from .grpc_request_manager import (
GrpcRequestManager,
create_disaggregated_params_from_proto,
create_lora_request_from_proto,
create_sampling_params_from_proto,
)
class TrtllmServiceServicer(trtllm_service_pb2_grpc.TrtllmServiceServicer):
"""gRPC servicer implementing the TrtLlmEngine service.
Handles RPCs:
- Generate: Streaming text generation
- Embed: Embeddings (for embedding models)
- HealthCheck: Health probe
- Abort: Cancel a request
- GetModelInfo: Model metadata
- GetServerInfo: Server state
"""
def __init__(self, request_manager: GrpcRequestManager, model_path: str = ""):
"""Initialize the servicer.
Args:
request_manager: The GrpcRequestManager instance
model_path: Path to the model (for metadata)
"""
self.request_manager = request_manager
self.model_path = model_path
self._start_time = time.time()
logger.info("TrtllmServiceServicer initialized")
async def Generate(
self,
request: trtllm_service_pb2.GenerateRequest,
context: grpc.aio.ServicerContext,
) -> AsyncGenerator[trtllm_service_pb2.GenerateResponse, None]:
"""Handle streaming generation requests.
Args:
request: The GenerateRequest protobuf
context: gRPC context
Yields:
GenerateResponse protobuf messages (streaming)
"""
request_id = request.request_id
logger.info(f"Generate request {request_id} received")
try:
# Extract tokenized input (required)
if not request.HasField("tokenized"):
yield self._error_response(
request_id,
"Missing tokenized input",
"INVALID_REQUEST",
400,
)
return
prompt_token_ids = list(request.tokenized.input_token_ids)
# Build sampling params with detokenize=False (key optimization: responses carry raw token IDs, so no detokenization work is done in the worker)
sampling_params = create_sampling_params_from_proto(
proto_config=request.sampling_config,
output_config=request.output_config,
max_tokens=request.max_tokens,
end_id=request.end_id if request.HasField("end_id") else None,
pad_id=request.pad_id if request.HasField("pad_id") else None,
bad_words=list(request.bad_words) if request.bad_words else None,
stop_words=list(request.stop_words) if request.stop_words else None,
guided_decoding=request.guided_decoding
if request.HasField("guided_decoding")
else None,
embedding_bias=list(request.embedding_bias) if request.embedding_bias else None,
)
# Build LoRA request if present
lora_request = create_lora_request_from_proto(
request.lora_config if request.HasField("lora_config") else None
)
# Build disaggregated params if present
disaggregated_params = create_disaggregated_params_from_proto(
request.disaggregated_params if request.HasField("disaggregated_params") else None
)
# Track tokens sent per sequence index to avoid duplicates
# TRT-LLM's token_ids_diff doesn't clear between iterations for n>1
sent_token_counts: dict[int, int] = {}
# Submit to the request manager and stream outputs;
# it yields GenerationResult objects as they become available
async for gen_result in self.request_manager.generate(
request_id=request_id,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
streaming=request.streaming,
lora_request=lora_request,
disaggregated_params=disaggregated_params,
):
# Check if client disconnected
if context.cancelled():
logger.info(f"Client disconnected for {request_id}")
await self.request_manager.abort(request_id)
return
# Convert GenerationResult to protobuf response
if request.streaming:
for chunk_response in self._chunk_responses(
request_id, gen_result, prompt_token_ids, sent_token_counts
):
yield chunk_response
# Send complete responses when finished (one per sequence for n>1)
if gen_result.finished:
for complete_response in self._complete_responses(
request_id, gen_result, prompt_token_ids
):
yield complete_response
except asyncio.CancelledError:
logger.info(f"Request {request_id} cancelled")
await self.request_manager.abort(request_id)
raise
except Exception as e:
logger.error(f"Error in Generate for {request_id}: {e}")
yield self._error_response(
request_id,
str(e),
"INTERNAL_ERROR",
500,
)
async def Embed(
self,
request: trtllm_service_pb2.EmbedRequest,
context: grpc.aio.ServicerContext,
) -> trtllm_service_pb2.EmbedResponse:
"""Handle embedding requests.
Args:
request: The EmbedRequest protobuf
context: gRPC context
Returns:
EmbedResponse protobuf
"""
logger.warning("Embed RPC not yet implemented")
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details("Embed RPC not yet implemented")
return trtllm_service_pb2.EmbedResponse(
request_id=request.request_id,
embedding=[],
prompt_tokens=0,
)
async def HealthCheck(
self,
request: trtllm_service_pb2.HealthCheckRequest,
context: grpc.aio.ServicerContext,
) -> trtllm_service_pb2.HealthCheckResponse:
"""Handle health check requests.
Args:
request: The HealthCheckRequest protobuf
context: gRPC context
Returns:
HealthCheckResponse protobuf
"""
is_healthy, message = await self.request_manager.health_check()
logger.info(f"HealthCheck: healthy={is_healthy}, message={message}")
return trtllm_service_pb2.HealthCheckResponse(
status=message,
)
async def Abort(
self,
request: trtllm_service_pb2.AbortRequest,
context: grpc.aio.ServicerContext,
) -> trtllm_service_pb2.AbortResponse:
"""Handle abort requests.
Args:
request: The AbortRequest protobuf
context: gRPC context
Returns:
AbortResponse protobuf
"""
request_id = request.request_id
logger.info(f"Abort request for {request_id}")
success = await self.request_manager.abort(request_id)
return trtllm_service_pb2.AbortResponse(
success=success,
message=f"Request {request_id} {'aborted' if success else 'not found'}",
)
async def GetModelInfo(
self,
request: trtllm_service_pb2.GetModelInfoRequest,
context: grpc.aio.ServicerContext,
) -> trtllm_service_pb2.GetModelInfoResponse:
"""Handle model info requests.
Args:
request: The GetModelInfoRequest protobuf
context: gRPC context
Returns:
GetModelInfoResponse protobuf
"""
model_config = self.request_manager.get_model_config()
return trtllm_service_pb2.GetModelInfoResponse(
model_id=self.model_path or model_config.get("model_path", ""),
max_input_len=model_config.get("max_context_length", 0),
max_seq_len=model_config.get("max_seq_len", 0)
or model_config.get("max_context_length", 0),
vocab_size=model_config.get("vocab_size", 0),
)
async def GetServerInfo(
self,
request: trtllm_service_pb2.GetServerInfoRequest,
context: grpc.aio.ServicerContext,
) -> trtllm_service_pb2.GetServerInfoResponse:
"""Handle server info requests.
Args:
request: The GetServerInfoRequest protobuf
context: gRPC context
Returns:
GetServerInfoResponse protobuf
"""
try:
import tensorrt_llm
version = getattr(tensorrt_llm, "__version__", "unknown")
except Exception:
version = "unknown"
# Try to get parallelism info from LLM args
tp_size = 1
pp_size = 1
world_size = 1
try:
llm = self.request_manager.llm
if hasattr(llm, "args") and llm.args is not None:
args = llm.args
if hasattr(args, "tensor_parallel_size") and args.tensor_parallel_size:
tp_size = args.tensor_parallel_size
if hasattr(args, "pipeline_parallel_size") and args.pipeline_parallel_size:
pp_size = args.pipeline_parallel_size
world_size = tp_size * pp_size
except Exception as e:
logger.debug(f"Could not get parallelism info: {e}")
return trtllm_service_pb2.GetServerInfoResponse(
version=version,
backend="tensorrt-llm",
tensor_parallel_size=tp_size,
pipeline_parallel_size=pp_size,
context_parallel_size=1, # Context parallelism is separate from TP/PP
world_size=world_size,
)
# ========== Helper methods ==========
def _convert_logprobs_to_proto(
self,
token_ids: List[int],
logprobs: Union[TokenLogprobs, List[float], None],
) -> List[trtllm_service_pb2.TokenLogprob]:
"""Convert TRT-LLM logprobs to protobuf TokenLogprob messages.
Handles both formats:
- List[float]: Simple logprobs (one per token)
- TokenLogprobs (list[dict[int, Logprob]]): Top-k logprobs per position
Args:
token_ids: The sampled token IDs
logprobs: Logprobs from TRT-LLM (can be List[float] or TokenLogprobs)
Returns:
List of TokenLogprob proto messages
"""
if not logprobs or not token_ids:
return []
result = []
for i, token_id in enumerate(token_ids):
if i >= len(logprobs):
break
lp = logprobs[i]
if isinstance(lp, dict):
# TokenLogprobs format: dict[int, Logprob]
# Each entry maps token_id -> Logprob(logprob, rank)
token_logprob = trtllm_service_pb2.TokenLogprob(
token_id=token_id,
logprob=lp[token_id].logprob if token_id in lp else 0.0,
)
# Add top logprobs (all entries in the dict)
for tid, logprob_obj in lp.items():
if isinstance(logprob_obj, Logprob):
token_logprob.top_logprobs.append(
trtllm_service_pb2.TopLogprob(
token_id=tid,
logprob=logprob_obj.logprob,
)
)
result.append(token_logprob)
elif isinstance(lp, (int, float)):
# Simple float logprob
token_logprob = trtllm_service_pb2.TokenLogprob(
token_id=token_id,
logprob=float(lp),
)
result.append(token_logprob)
return result
def _chunk_responses(
self,
request_id: str,
gen_result,
prompt_token_ids: list,
sent_token_counts: dict[int, int],
) -> List[trtllm_service_pb2.GenerateResponse]:
"""Build streaming chunk responses from GenerationResult.
Uses cumulative token_ids and tracks sent position to compute true deltas.
TRT-LLM's token_ids_diff doesn't clear between iterations for n>1, so we
compute deltas ourselves.
Args:
request_id: The request ID
gen_result: TensorRT-LLM GenerationResult
prompt_token_ids: Original prompt tokens
sent_token_counts: Dict tracking tokens already sent per sequence index
Returns:
List of GenerateResponse with chunk field set (one per output)
"""
responses = []
cached_tokens = getattr(gen_result, "cached_tokens", 0)
if not gen_result.outputs:
# No outputs yet, return empty chunk
responses.append(
trtllm_service_pb2.GenerateResponse(
request_id=request_id,
chunk=trtllm_service_pb2.GenerateStreamChunk(
token_ids=[],
prompt_tokens=len(prompt_token_ids),
completion_tokens=0,
cached_tokens=cached_tokens,
),
)
)
return responses
# Process all outputs (for n>1 support)
for completion in gen_result.outputs:
index = completion.index
# Use cumulative token_ids and compute delta ourselves
# because token_ids_diff doesn't clear between iterations for n>1
all_tokens = list(completion.token_ids) if completion.token_ids else []
sent_count = sent_token_counts.get(index, 0)
delta_tokens = all_tokens[sent_count:]
# Skip if no new tokens for this sequence
if not delta_tokens:
continue
# Update sent count
sent_token_counts[index] = len(all_tokens)
chunk = trtllm_service_pb2.GenerateStreamChunk(
token_ids=delta_tokens,
sequence_index=completion.index,
prompt_tokens=len(prompt_token_ids),
completion_tokens=len(all_tokens),
cached_tokens=cached_tokens,
)
# Add logprobs if available
# Note: we compute delta logprobs ourselves since logprobs_diff has the same issue as token_ids_diff
if completion.logprobs:
all_logprobs = completion.logprobs
delta_logprobs = all_logprobs[sent_count:] if sent_count < len(all_logprobs) else []
proto_logprobs = self._convert_logprobs_to_proto(delta_tokens, delta_logprobs)
chunk.logprobs.extend(proto_logprobs)
responses.append(
trtllm_service_pb2.GenerateResponse(
request_id=request_id,
chunk=chunk,
)
)
return responses
def _complete_responses(
self,
request_id: str,
gen_result,
prompt_token_ids: list,
) -> List[trtllm_service_pb2.GenerateResponse]:
"""Build final completion responses from GenerationResult.
For n>1, returns one response per output sequence.
Args:
request_id: The request ID
gen_result: TensorRT-LLM GenerationResult (finished=True)
prompt_token_ids: Original prompt tokens
Returns:
List of GenerateResponse with complete field set (one per output)
"""
responses = []
cached_tokens = getattr(gen_result, "cached_tokens", 0)
if not gen_result.outputs:
# No outputs, return error response
responses.append(
trtllm_service_pb2.GenerateResponse(
request_id=request_id,
complete=trtllm_service_pb2.GenerateComplete(
output_token_ids=[],
finish_reason="error",
prompt_tokens=len(prompt_token_ids),
completion_tokens=0,
cached_tokens=0,
),
)
)
return responses
# Process all outputs (for n>1 support)
for completion in gen_result.outputs:
output_tokens = list(completion.token_ids) if completion.token_ids else []
complete = trtllm_service_pb2.GenerateComplete(
output_token_ids=output_tokens,
sequence_index=completion.index,
finish_reason=completion.finish_reason or "stop",
prompt_tokens=len(prompt_token_ids),
completion_tokens=len(output_tokens),
cached_tokens=cached_tokens,
)
# Add stop reason if available
if hasattr(completion, "stop_reason") and completion.stop_reason:
complete.stop_reason = str(completion.stop_reason)
# Add generation logprobs if available
if completion.logprobs:
proto_logprobs = self._convert_logprobs_to_proto(output_tokens, completion.logprobs)
complete.logprobs.extend(proto_logprobs)
# Add prompt logprobs if available
if hasattr(completion, "prompt_logprobs") and completion.prompt_logprobs:
# For prompt logprobs, we use the prompt_token_ids
proto_prompt_logprobs = self._convert_logprobs_to_proto(
prompt_token_ids, completion.prompt_logprobs
)
complete.prompt_logprobs.extend(proto_prompt_logprobs)
responses.append(
trtllm_service_pb2.GenerateResponse(
request_id=request_id,
complete=complete,
)
)
return responses
def _error_response(
self,
request_id: str,
message: str,
error_type: str,
code: int,
) -> trtllm_service_pb2.GenerateResponse:
"""Build an error response.
Args:
request_id: The request ID
message: Error message
error_type: Error type string
code: Error code
Returns:
GenerateResponse with error field set
"""
return trtllm_service_pb2.GenerateResponse(
request_id=request_id,
error=trtllm_service_pb2.GenerateError(
message=message,
type=error_type,
code=code,
),
)