# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""gRPC Servicer for TensorRT-LLM.

Implements the TrtllmService gRPC service for high-performance communication
with external routers (e.g., sgl-router) using pre-tokenized input.
"""

import asyncio
import time
from collections.abc import AsyncGenerator
from typing import List, Union

import grpc

from tensorrt_llm.executor.result import Logprob, TokenLogprobs
from tensorrt_llm.logger import logger

from . import trtllm_service_pb2, trtllm_service_pb2_grpc
from .grpc_request_manager import (
    GrpcRequestManager,
    create_disaggregated_params_from_proto,
    create_lora_request_from_proto,
    create_sampling_params_from_proto,
)
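
# Example client-side call (illustrative sketch only, not part of the service
# implementation): the stub name `TrtllmServiceStub` is assumed from standard
# gRPC Python codegen for this service, and the address and token IDs are
# placeholders.
#
#     async def _example_generate(token_ids: list[int]) -> None:
#         async with grpc.aio.insecure_channel("localhost:50051") as channel:
#             stub = trtllm_service_pb2_grpc.TrtllmServiceStub(channel)
#             request = trtllm_service_pb2.GenerateRequest(
#                 request_id="req-0",
#                 max_tokens=64,
#                 streaming=True,
#             )
#             request.tokenized.input_token_ids.extend(token_ids)
#             async for response in stub.Generate(request):
#                 if response.HasField("chunk"):
#                     print("delta tokens:", list(response.chunk.token_ids))
#                 elif response.HasField("complete"):
#                     print("finish reason:", response.complete.finish_reason)
#                 elif response.HasField("error"):
#                     raise RuntimeError(response.error.message)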


class TrtllmServiceServicer(trtllm_service_pb2_grpc.TrtllmServiceServicer):
    """gRPC servicer implementing the TrtllmService service.

    Handles RPCs:
    - Generate: Streaming text generation
    - Embed: Embeddings (for embedding models)
    - HealthCheck: Health probe
    - Abort: Cancel a request
    - GetModelInfo: Model metadata
    - GetServerInfo: Server state
    """

    def __init__(self, request_manager: GrpcRequestManager, model_path: str = ""):
        """Initialize the servicer.

        Args:
            request_manager: The GrpcRequestManager instance
            model_path: Path to the model (for metadata)
        """
        self.request_manager = request_manager
        self.model_path = model_path
        self._start_time = time.time()
        logger.info("TrtllmServiceServicer initialized")
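
    # Example server wiring (illustrative sketch only): the registration helper
    # `add_TrtllmServiceServicer_to_server` is assumed from standard gRPC Python
    # codegen for this service, and the port is a placeholder.
    #
    #     async def _example_serve(manager: GrpcRequestManager) -> None:
    #         server = grpc.aio.server()
    #         trtllm_service_pb2_grpc.add_TrtllmServiceServicer_to_server(
    #             TrtllmServiceServicer(manager), server
    #         )
    #         server.add_insecure_port("[::]:50051")
    #         await server.start()
    #         await server.wait_for_termination()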

    async def Generate(
        self,
        request: trtllm_service_pb2.GenerateRequest,
        context: grpc.aio.ServicerContext,
    ) -> AsyncGenerator[trtllm_service_pb2.GenerateResponse, None]:
        """Handle streaming generation requests.

        Args:
            request: The GenerateRequest protobuf
            context: gRPC context

        Yields:
            GenerateResponse protobuf messages (streaming)
        """
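        # Typical response sequence for a streaming request (sketch):
        #   chunk(token_ids=[...]) ... chunk(token_ids=[...])    # incremental deltas
        #   complete(output_token_ids=[...], finish_reason=...)  # one per sequence for n>1
        # Non-streaming requests skip the chunks and yield only the complete
        # responses; failures yield a single error response instead.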
        request_id = request.request_id
        logger.info(f"Generate request {request_id} received")

        try:
            # Extract tokenized input (required)
            if not request.HasField("tokenized"):
                yield self._error_response(
                    request_id,
                    "Missing tokenized input",
                    "INVALID_REQUEST",
                    400,
                )
                return

            prompt_token_ids = list(request.tokenized.input_token_ids)

            # Build sampling params with detokenize=False (key optimization!)
            sampling_params = create_sampling_params_from_proto(
                proto_config=request.sampling_config,
                output_config=request.output_config,
                max_tokens=request.max_tokens,
                end_id=request.end_id if request.HasField("end_id") else None,
                pad_id=request.pad_id if request.HasField("pad_id") else None,
                bad_words=list(request.bad_words) if request.bad_words else None,
                stop_words=list(request.stop_words) if request.stop_words else None,
                guided_decoding=request.guided_decoding
                if request.HasField("guided_decoding")
                else None,
                embedding_bias=list(request.embedding_bias) if request.embedding_bias else None,
            )

            # Build LoRA request if present
            lora_request = create_lora_request_from_proto(
                request.lora_config if request.HasField("lora_config") else None
            )

            # Build disaggregated params if present
            disaggregated_params = create_disaggregated_params_from_proto(
                request.disaggregated_params if request.HasField("disaggregated_params") else None
            )

            # Track tokens sent per sequence index to avoid duplicates
            # TRT-LLM's token_ids_diff doesn't clear between iterations for n>1
            sent_token_counts: dict[int, int] = {}

            # Submit to request manager and stream outputs
            # The request manager now yields GenerationResult objects
            async for gen_result in self.request_manager.generate(
                request_id=request_id,
                prompt_token_ids=prompt_token_ids,
                sampling_params=sampling_params,
                streaming=request.streaming,
                lora_request=lora_request,
                disaggregated_params=disaggregated_params,
            ):
                # Check if client disconnected
                if context.cancelled():
                    logger.info(f"Client disconnected for {request_id}")
                    await self.request_manager.abort(request_id)
                    return

                # Convert GenerationResult to protobuf response
                if request.streaming:
                    for chunk_response in self._chunk_responses(
                        request_id, gen_result, prompt_token_ids, sent_token_counts
                    ):
                        yield chunk_response

                # Send complete responses when finished (one per sequence for n>1)
                if gen_result.finished:
                    for complete_response in self._complete_responses(
                        request_id, gen_result, prompt_token_ids
                    ):
                        yield complete_response

        except asyncio.CancelledError:
            logger.info(f"Request {request_id} cancelled")
            await self.request_manager.abort(request_id)
            raise
        except Exception as e:
            logger.error(f"Error in Generate for {request_id}: {e}")
            yield self._error_response(
                request_id,
                str(e),
                "INTERNAL_ERROR",
                500,
            )

    async def Embed(
        self,
        request: trtllm_service_pb2.EmbedRequest,
        context: grpc.aio.ServicerContext,
    ) -> trtllm_service_pb2.EmbedResponse:
        """Handle embedding requests.

        Args:
            request: The EmbedRequest protobuf
            context: gRPC context

        Returns:
            EmbedResponse protobuf
        """
        logger.warning("Embed RPC not yet implemented")
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details("Embed RPC not yet implemented")
        return trtllm_service_pb2.EmbedResponse(
            request_id=request.request_id,
            embedding=[],
            prompt_tokens=0,
        )

    async def HealthCheck(
        self,
        request: trtllm_service_pb2.HealthCheckRequest,
        context: grpc.aio.ServicerContext,
    ) -> trtllm_service_pb2.HealthCheckResponse:
        """Handle health check requests.

        Args:
            request: The HealthCheckRequest protobuf
            context: gRPC context

        Returns:
            HealthCheckResponse protobuf
        """
        is_healthy, message = await self.request_manager.health_check()
        logger.info(f"HealthCheck: healthy={is_healthy}, message={message}")

        return trtllm_service_pb2.HealthCheckResponse(
            status=message,
        )

    async def Abort(
        self,
        request: trtllm_service_pb2.AbortRequest,
        context: grpc.aio.ServicerContext,
    ) -> trtllm_service_pb2.AbortResponse:
        """Handle abort requests.

        Args:
            request: The AbortRequest protobuf
            context: gRPC context

        Returns:
            AbortResponse protobuf
        """
        request_id = request.request_id
        logger.info(f"Abort request for {request_id}")

        success = await self.request_manager.abort(request_id)

        return trtllm_service_pb2.AbortResponse(
            success=success,
            message=f"Request {request_id} {'aborted' if success else 'not found'}",
        )

    async def GetModelInfo(
        self,
        request: trtllm_service_pb2.GetModelInfoRequest,
        context: grpc.aio.ServicerContext,
    ) -> trtllm_service_pb2.GetModelInfoResponse:
        """Handle model info requests.

        Args:
            request: The GetModelInfoRequest protobuf
            context: gRPC context

        Returns:
            GetModelInfoResponse protobuf
        """
        model_config = self.request_manager.get_model_config()

        return trtllm_service_pb2.GetModelInfoResponse(
            model_id=self.model_path or model_config.get("model_path", ""),
            max_input_len=model_config.get("max_context_length", 0),
            max_seq_len=model_config.get("max_seq_len", 0)
            or model_config.get("max_context_length", 0),
            vocab_size=model_config.get("vocab_size", 0),
        )

    async def GetServerInfo(
        self,
        request: trtllm_service_pb2.GetServerInfoRequest,
        context: grpc.aio.ServicerContext,
    ) -> trtllm_service_pb2.GetServerInfoResponse:
        """Handle server info requests.

        Args:
            request: The GetServerInfoRequest protobuf
            context: gRPC context

        Returns:
            GetServerInfoResponse protobuf
        """
        try:
            import tensorrt_llm

            version = getattr(tensorrt_llm, "__version__", "unknown")
        except Exception:
            version = "unknown"

        # Try to get parallelism info from LLM args
        tp_size = 1
        pp_size = 1
        world_size = 1

        try:
            llm = self.request_manager.llm
            if hasattr(llm, "args") and llm.args is not None:
                args = llm.args
                if hasattr(args, "tensor_parallel_size") and args.tensor_parallel_size:
                    tp_size = args.tensor_parallel_size
                if hasattr(args, "pipeline_parallel_size") and args.pipeline_parallel_size:
                    pp_size = args.pipeline_parallel_size
                world_size = tp_size * pp_size
        except Exception as e:
            logger.debug(f"Could not get parallelism info: {e}")

        return trtllm_service_pb2.GetServerInfoResponse(
            version=version,
            backend="tensorrt-llm",
            tensor_parallel_size=tp_size,
            pipeline_parallel_size=pp_size,
            context_parallel_size=1,  # Context parallelism is separate from TP/PP
            world_size=world_size,
        )

    # ========== Helper methods ==========

    def _convert_logprobs_to_proto(
        self,
        token_ids: List[int],
        logprobs: Union[TokenLogprobs, List[float], None],
    ) -> List[trtllm_service_pb2.TokenLogprob]:
        """Convert TRT-LLM logprobs to protobuf TokenLogprob messages.

        Handles both formats:
        - List[float]: Simple logprobs (one per token)
        - TokenLogprobs (list[dict[int, Logprob]]): Top-k logprobs per position

        Args:
            token_ids: The sampled token IDs
            logprobs: Logprobs from TRT-LLM (can be List[float] or TokenLogprobs)

        Returns:
            List of TokenLogprob proto messages
        """
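        # Illustrative input shapes (sketch; token IDs and values are placeholders):
        #   simple:  logprobs = [-0.11, -2.30]                  # one float per sampled token
        #   top-k:   logprobs = [{42: Logprob(logprob=-0.11, rank=1),
        #                         7: Logprob(logprob=-1.90, rank=2)}, ...]
        # If a top-k dict is missing the sampled token, the conversion below
        # falls back to logprob=0.0 for that position.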
        if not logprobs or not token_ids:
            return []

        result = []
        for i, token_id in enumerate(token_ids):
            if i >= len(logprobs):
                break

            lp = logprobs[i]

            if isinstance(lp, dict):
                # TokenLogprobs format: dict[int, Logprob]
                # Each entry maps token_id -> Logprob(logprob, rank)
                token_logprob = trtllm_service_pb2.TokenLogprob(
                    token_id=token_id,
                    logprob=lp[token_id].logprob if token_id in lp else 0.0,
                )
                # Add top logprobs (all entries in the dict)
                for tid, logprob_obj in lp.items():
                    if isinstance(logprob_obj, Logprob):
                        token_logprob.top_logprobs.append(
                            trtllm_service_pb2.TopLogprob(
                                token_id=tid,
                                logprob=logprob_obj.logprob,
                            )
                        )
                result.append(token_logprob)
            elif isinstance(lp, (int, float)):
                # Simple float logprob
                token_logprob = trtllm_service_pb2.TokenLogprob(
                    token_id=token_id,
                    logprob=float(lp),
                )
                result.append(token_logprob)

        return result

    def _chunk_responses(
        self,
        request_id: str,
        gen_result,
        prompt_token_ids: list,
        sent_token_counts: dict[int, int],
    ) -> List[trtllm_service_pb2.GenerateResponse]:
        """Build streaming chunk responses from GenerationResult.

        Uses cumulative token_ids and tracks sent position to compute true deltas.
        TRT-LLM's token_ids_diff doesn't clear between iterations for n>1, so we
        compute deltas ourselves.

        Args:
            request_id: The request ID
            gen_result: TensorRT-LLM GenerationResult
            prompt_token_ids: Original prompt tokens
            sent_token_counts: Dict tracking tokens already sent per sequence index

        Returns:
            List of GenerateResponse with chunk field set (one per output)
        """
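        # Delta computation across two iterations for sequence index 0 (sketch):
        #   iteration 1: cumulative token_ids = [5, 9]     -> delta [5, 9], sent_token_counts[0] = 2
        #   iteration 2: cumulative token_ids = [5, 9, 11] -> delta [11],   sent_token_counts[0] = 3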
        responses = []
        cached_tokens = gen_result.cached_tokens if hasattr(gen_result, "cached_tokens") else 0

        if not gen_result.outputs:
            # No outputs yet, return empty chunk
            responses.append(
                trtllm_service_pb2.GenerateResponse(
                    request_id=request_id,
                    chunk=trtllm_service_pb2.GenerateStreamChunk(
                        token_ids=[],
                        prompt_tokens=len(prompt_token_ids),
                        completion_tokens=0,
                        cached_tokens=cached_tokens,
                    ),
                )
            )
            return responses

        # Process all outputs (for n>1 support)
        for completion in gen_result.outputs:
            index = completion.index
            # Use cumulative token_ids and compute delta ourselves
            # because token_ids_diff doesn't clear between iterations for n>1
            all_tokens = list(completion.token_ids) if completion.token_ids else []
            sent_count = sent_token_counts.get(index, 0)
            delta_tokens = all_tokens[sent_count:]

            # Skip if no new tokens for this sequence
            if not delta_tokens:
                continue

            # Update sent count
            sent_token_counts[index] = len(all_tokens)

            chunk = trtllm_service_pb2.GenerateStreamChunk(
                token_ids=delta_tokens,
                sequence_index=completion.index,
                prompt_tokens=len(prompt_token_ids),
                completion_tokens=len(completion.token_ids) if completion.token_ids else 0,
                cached_tokens=cached_tokens,
            )

            # Add logprobs if available
            # Note: We compute delta logprobs ourselves since logprobs_diff has same issue as token_ids_diff
            if completion.logprobs:
                all_logprobs = completion.logprobs
                delta_logprobs = all_logprobs[sent_count:] if sent_count < len(all_logprobs) else []
                proto_logprobs = self._convert_logprobs_to_proto(delta_tokens, delta_logprobs)
                chunk.logprobs.extend(proto_logprobs)

            responses.append(
                trtllm_service_pb2.GenerateResponse(
                    request_id=request_id,
                    chunk=chunk,
                )
            )

        return responses

    def _complete_responses(
        self,
        request_id: str,
        gen_result,
        prompt_token_ids: list,
    ) -> List[trtllm_service_pb2.GenerateResponse]:
        """Build final completion responses from GenerationResult.

        For n>1, returns one response per output sequence.

        Args:
            request_id: The request ID
            gen_result: TensorRT-LLM GenerationResult (finished=True)
            prompt_token_ids: Original prompt tokens

        Returns:
            List of GenerateResponse with complete field set (one per output)
        """
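        # For example (sketch; values are placeholders), a finished result with
        # n=2 yields two responses:
        #   complete(sequence_index=0, output_token_ids=[...], finish_reason="stop")
        #   complete(sequence_index=1, output_token_ids=[...], finish_reason="stop")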
        responses = []
        cached_tokens = gen_result.cached_tokens if hasattr(gen_result, "cached_tokens") else 0

        if not gen_result.outputs:
            # No outputs, return error response
            responses.append(
                trtllm_service_pb2.GenerateResponse(
                    request_id=request_id,
                    complete=trtllm_service_pb2.GenerateComplete(
                        output_token_ids=[],
                        finish_reason="error",
                        prompt_tokens=len(prompt_token_ids),
                        completion_tokens=0,
                        cached_tokens=0,
                    ),
                )
            )
            return responses

        # Process all outputs (for n>1 support)
        for completion in gen_result.outputs:
            output_tokens = list(completion.token_ids) if completion.token_ids else []

            complete = trtllm_service_pb2.GenerateComplete(
                output_token_ids=output_tokens,
                sequence_index=completion.index,
                finish_reason=completion.finish_reason or "stop",
                prompt_tokens=len(prompt_token_ids),
                completion_tokens=len(output_tokens),
                cached_tokens=cached_tokens,
            )

            # Add stop reason if available
            if hasattr(completion, "stop_reason") and completion.stop_reason:
                complete.stop_reason = str(completion.stop_reason)

            # Add generation logprobs if available
            if completion.logprobs:
                proto_logprobs = self._convert_logprobs_to_proto(output_tokens, completion.logprobs)
                complete.logprobs.extend(proto_logprobs)

            # Add prompt logprobs if available
            if hasattr(completion, "prompt_logprobs") and completion.prompt_logprobs:
                # For prompt logprobs, we use the prompt_token_ids
                proto_prompt_logprobs = self._convert_logprobs_to_proto(
                    prompt_token_ids, completion.prompt_logprobs
                )
                complete.prompt_logprobs.extend(proto_prompt_logprobs)

            responses.append(
                trtllm_service_pb2.GenerateResponse(
                    request_id=request_id,
                    complete=complete,
                )
            )

        return responses

    def _error_response(
        self,
        request_id: str,
        message: str,
        error_type: str,
        code: int,
    ) -> trtllm_service_pb2.GenerateResponse:
        """Build an error response.

        Args:
            request_id: The request ID
            message: Error message
            error_type: Error type string
            code: Error code

        Returns:
            GenerateResponse with error field set
        """
        return trtllm_service_pb2.GenerateResponse(
            request_id=request_id,
            error=trtllm_service_pb2.GenerateError(
                message=message,
                type=error_type,
                code=code,
            ),
        )