// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";

package trtllm;

// TensorRT-LLM gRPC Service
//
// This service provides high-performance inference for LLMs using TensorRT-LLM.
// It accepts pre-tokenized requests and returns raw token IDs, enabling efficient
// binary communication with external routers (e.g., sgl-router).
//
// Key Design Principles:
// - Token IDs only: No text in requests or responses (router handles tokenization)
// - Streaming delta mode: Chunks contain only new tokens since last chunk
// - Full feature support: All TensorRT-LLM capabilities exposed
service TrtllmService {
  // Generate tokens from pre-tokenized input
  // Returns a stream of responses containing token IDs
  rpc Generate(GenerateRequest) returns (stream GenerateResponse);
  // Generate embeddings from pre-tokenized input (for embedding models)
  rpc Embed(EmbedRequest) returns (EmbedResponse);
  // Health check endpoint
  rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse);
  // Abort a running generation request
  rpc Abort(AbortRequest) returns (AbortResponse);
  // Get model information (vocab size, max lengths, etc.)
  rpc GetModelInfo(GetModelInfoRequest) returns (GetModelInfoResponse);
  // Get server information (version, parallelism, etc.)
  rpc GetServerInfo(GetServerInfoRequest) returns (GetServerInfoResponse);
}

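// A minimal Python client sketch for the streaming Generate RPC. This is an
// illustration only: the module names trtllm_service_pb2 / trtllm_service_pb2_grpc
// and the server address are assumptions that depend on your protoc invocation
// and deployment.
//
//   import grpc
//   import trtllm_service_pb2 as pb
//   import trtllm_service_pb2_grpc as pb_grpc
//
//   channel = grpc.insecure_channel("localhost:50051")
//   stub = pb_grpc.TrtllmServiceStub(channel)
//   request = pb.GenerateRequest(
//       request_id="req-1",
//       tokenized=pb.TokenizedInput(input_token_ids=[101, 2023, 2003]),
//       sampling_config=pb.SamplingConfig(beam_width=1, temperature=0.8),
//       max_tokens=64,
//       streaming=True,
//   )
//   for response in stub.Generate(request):
//       if response.HasField("chunk"):
//           print(list(response.chunk.token_ids))  # delta tokens only
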
// ============================================================================
// Generate Request
// ============================================================================
message GenerateRequest {
  // Unique request identifier (assigned by client/router)
  string request_id = 1;
  // Pre-tokenized input (REQUIRED - router tokenizes)
  TokenizedInput tokenized = 2;
  // Sampling configuration
  SamplingConfig sampling_config = 3;
  // Output configuration
  OutputConfig output_config = 4;
  // Maximum tokens to generate (REQUIRED)
  uint32 max_tokens = 5;
  // Enable streaming mode
  bool streaming = 6;
  // End-of-sequence token ID (optional, set to -1 to ignore EOS)
  optional int32 end_id = 7;
  // Padding token ID
  optional int32 pad_id = 8;
  // Bad word token sequences (the model is prevented from producing these)
  repeated TokenSequence bad_words = 9;
  // Stop word token sequences (generation stops when one is produced)
  repeated TokenSequence stop_words = 10;
  // Guided decoding parameters (JSON schema, regex, grammar)
  optional GuidedDecodingParams guided_decoding = 11;
  // Embedding bias tensor (vocab_size floats, optional)
  repeated float embedding_bias = 12;
  // LoRA adapter configuration
  optional LoraConfig lora_config = 13;
  // Prompt tuning configuration
  optional PromptTuningConfig prompt_tuning_config = 14;
  // Multimodal input (for VLMs)
  optional MultimodalInput multimodal_input = 15;
  // KV cache retention configuration
  optional KvCacheRetentionConfig kv_cache_retention = 16;
  // Disaggregated inference parameters
  optional DisaggregatedParams disaggregated_params = 17;
  // Lookahead decoding configuration
  optional LookaheadConfig lookahead_config = 18;
  // Cache salt ID used in KV cache hashing
  optional int64 cache_salt_id = 19;
  // Request arrival time (Unix timestamp, used for metrics)
  optional double arrival_time = 20;
}

// Tokenized input from router
message TokenizedInput {
  // Original text (for debugging/logging only, not used for generation)
  string original_text = 1;
  // Pre-tokenized input token IDs (REQUIRED)
  repeated uint32 input_token_ids = 2;
  // Query token IDs for VLM star attention (optional)
  repeated uint32 query_token_ids = 3;
}

// Sequence of token IDs (for stop/bad words)
message TokenSequence {
  repeated uint32 token_ids = 1;
}

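// For example, to stop generation at token 13 and ban the pair [50256, 50256]
// (a Python sketch; token values are arbitrary, and `pb` is the generated
// module from the client sketch above):
//
//   request.stop_words.add(token_ids=[13])
//   request.bad_words.add(token_ids=[50256, 50256])
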
// ============================================================================
// Sampling Configuration
// Maps to tensorrt_llm.bindings.executor.SamplingConfig
// ============================================================================
message SamplingConfig {
  // Beam width (1 for sampling, >1 for beam search)
  int32 beam_width = 1;
  // Number of sequences to return
  uint32 num_return_sequences = 2;
  // Top-K sampling (0 = disabled, considers all tokens)
  optional int32 top_k = 3;
  // Top-P (nucleus) sampling threshold
  optional float top_p = 4;
  // Top-P minimum threshold for decay
  optional float top_p_min = 5;
  // Token ID that resets the top-P decay
  optional int32 top_p_reset_ids = 6;
  // Top-P decay factor
  optional float top_p_decay = 7;
  // Random seed for reproducibility
  optional uint64 seed = 8;
  // Temperature for sampling (0 = greedy, higher = more random)
  optional float temperature = 9;
  // Minimum number of tokens to generate before stopping is allowed
  optional uint32 min_tokens = 10;
  // Beam search diversity rate
  optional float beam_search_diversity_rate = 11;
  // Repetition penalty (>1 discourages, <1 encourages repetition)
  optional float repetition_penalty = 12;
  // Presence penalty (penalizes tokens that have already appeared)
  optional float presence_penalty = 13;
  // Frequency penalty (penalizes tokens by how often they have appeared)
  optional float frequency_penalty = 14;
  // Number of prompt tokens to ignore when applying penalties
  optional int32 prompt_ignore_length = 15;
  // Length penalty for beam search
  optional float length_penalty = 16;
  // Early stopping for beam search
  optional int32 early_stopping = 17;
  // No-repeat n-gram size
  optional int32 no_repeat_ngram_size = 18;
  // Min-P sampling threshold
  optional float min_p = 19;
  // Variable beam width array for beam search
  repeated int32 beam_width_array = 20;
}

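// Two illustrative configurations (Python sketch; values are examples, not
// recommendations):
//
//   greedy = pb.SamplingConfig(beam_width=1, num_return_sequences=1, top_k=1)
//   nucleus = pb.SamplingConfig(
//       beam_width=1,
//       num_return_sequences=1,
//       temperature=0.8,
//       top_p=0.95,
//       seed=42,
//   )
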
// ============================================================================
// Output Configuration
// Maps to tensorrt_llm.bindings.executor.OutputConfig
// ============================================================================
message OutputConfig {
  // Number of top log probabilities to return per output token
  optional int32 logprobs = 1;
  // Number of top log probabilities to return per prompt token
  optional int32 prompt_logprobs = 2;
  // Return context logits tensor (large, use with caution)
  bool return_context_logits = 3;
  // Return generation logits tensor (large, use with caution)
  bool return_generation_logits = 4;
  // Exclude input tokens from the output. Note: proto3 defaults this to
  // false, while TRT-LLM's internal default is true.
  bool exclude_input_from_output = 5;
  // Return encoder output (for encoder-decoder models)
  bool return_encoder_output = 6;
  // Return performance metrics
  bool return_perf_metrics = 7;
}

// ============================================================================
// Guided Decoding
// ============================================================================
message GuidedDecodingParams {
  // Guide type enumeration
  enum GuideType {
    GUIDE_TYPE_UNSPECIFIED = 0;
    GUIDE_TYPE_JSON = 1;            // JSON format (any valid JSON)
    GUIDE_TYPE_JSON_SCHEMA = 2;     // JSON with schema constraint
    GUIDE_TYPE_REGEX = 3;           // Regular expression constraint
    GUIDE_TYPE_EBNF_GRAMMAR = 4;    // EBNF grammar constraint
    GUIDE_TYPE_STRUCTURAL_TAG = 5;  // Structural tag (xgrammar backend)
  }
  GuideType guide_type = 1;
  // Guide content (schema string, regex pattern, or grammar definition)
  string guide = 2;
}

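// Example: constrain output to a JSON object with an integer "age" field
// (Python sketch; the schema is an arbitrary illustration):
//
//   import json
//   schema = {"type": "object",
//             "properties": {"age": {"type": "integer"}},
//             "required": ["age"]}
//   guided = pb.GuidedDecodingParams(
//       guide_type=pb.GuidedDecodingParams.GUIDE_TYPE_JSON_SCHEMA,
//       guide=json.dumps(schema),
//   )
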
// ============================================================================
// LoRA Configuration
// ============================================================================
message LoraConfig {
  // LoRA task/adapter ID
  int64 task_id = 1;
  // LoRA weights (serialized tensor, optional if already cached)
  optional bytes weights = 2;
  // LoRA config as JSON string
  optional string config_json = 3;
}

// ============================================================================
// Prompt Tuning Configuration
// ============================================================================
message PromptTuningConfig {
  // Embedding table (serialized tensor)
  bytes embedding_table = 1;
}

// ============================================================================
// Multimodal Input (for Vision-Language Models)
// ============================================================================
message MultimodalInput {
  // Multimodal content hashes for caching
  repeated int64 multimodal_hashes = 1;
  // Positions in input where multimodal content is inserted
  repeated int32 multimodal_positions = 2;
  // Lengths of multimodal content at each position
  repeated int32 multimodal_lengths = 3;
}

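// Positions and lengths are parallel arrays; reading hashes the same way
// (one entry per multimodal item) is an assumption. Under that reading, a
// single image embedded at token offset 5 and occupying 576 positions might
// look like this in Python (hash value is arbitrary):
//
//   mm = pb.MultimodalInput(
//       multimodal_hashes=[1234567890],   # assumed: one hash per item
//       multimodal_positions=[5],
//       multimodal_lengths=[576],
//   )
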
// ============================================================================
// KV Cache Retention
// ============================================================================
message KvCacheRetentionConfig {
  // Retention policy name
  string policy = 1;
  // Additional configuration as JSON string
  string config_json = 2;
}

// ============================================================================
// Disaggregated Inference
// ============================================================================
message DisaggregatedParams {
  // Request type for disaggregated inference
  enum RequestType {
    REQUEST_TYPE_CONTEXT_AND_GENERATION = 0;  // Normal full request
    REQUEST_TYPE_CONTEXT_ONLY = 1;            // Prefill only
    REQUEST_TYPE_GENERATION_ONLY = 2;         // Decode only
  }
  RequestType request_type = 1;
  // Context request ID (links context and generation phases)
  string ctx_request_id = 2;
  // Context phase parameters (for generation-only requests)
  optional ContextPhaseParams context_phase_params = 3;
}

message ContextPhaseParams {
  // First generated token ID from the context phase
  uint32 first_gen_token_id = 1;
  // KV cache block pointers (serialized)
  bytes kv_cache_blocks = 2;
}

// ============================================================================
// Lookahead Decoding
// ============================================================================
message LookaheadConfig {
  // Maximum lookahead window size
  int32 max_window_size = 1;
  // Maximum n-gram size
  int32 max_ngram_size = 2;
  // Maximum size of the verification candidate set
  int32 max_verification_set_size = 3;
}

// ============================================================================
// Generate Response
// ============================================================================
message GenerateResponse {
  // Request ID echo
  string request_id = 1;
  // Response type (oneof ensures exactly one is set)
  oneof response {
    GenerateStreamChunk chunk = 2;  // Streaming delta
    GenerateComplete complete = 3;  // Final response
    GenerateError error = 4;        // Error response
  }
}

// Streaming chunk containing delta tokens (new tokens since last chunk)
message GenerateStreamChunk {
  // NEW token IDs only (delta from the previous chunk)
  repeated uint32 token_ids = 1;
  // Beam/sequence index (for beam_width > 1 or n > 1)
  uint32 sequence_index = 2;
  // Token counts for usage tracking
  uint32 prompt_tokens = 3;
  uint32 completion_tokens = 4;
  uint32 cached_tokens = 5;
  // Log probabilities for this chunk's tokens (if requested)
  repeated TokenLogprob logprobs = 6;
}

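// Because chunks are deltas, a client reconstructs the full output by
// concatenating them. A Python sketch, continuing the client example above:
//
//   output_ids = []
//   for response in stub.Generate(request):
//       kind = response.WhichOneof("response")
//       if kind == "chunk":
//           output_ids.extend(response.chunk.token_ids)
//       elif kind == "complete":
//           # complete.output_token_ids is cumulative, not a delta
//           output_ids = list(response.complete.output_token_ids)
//       elif kind == "error":
//           raise RuntimeError(response.error.message)
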
// Final/complete response with all output tokens
message GenerateComplete {
  // All output token IDs (cumulative, not delta)
  repeated uint32 output_token_ids = 1;
  // Beam/sequence index
  uint32 sequence_index = 2;
  // Finish reason: "stop", "length", "stop_word"
  string finish_reason = 3;
  // Specific stop reason (stop word/token that triggered stop)
  optional string stop_reason = 4;
  // Token counts for usage tracking
  uint32 prompt_tokens = 5;
  uint32 completion_tokens = 6;
  uint32 cached_tokens = 7;
  // Generation log probabilities (if requested)
  repeated TokenLogprob logprobs = 8;
  // Prompt log probabilities (if requested)
  repeated TokenLogprob prompt_logprobs = 9;
  // Performance metrics (if requested)
  optional PerfMetrics perf_metrics = 10;
  // Context logits (if requested) - serialized float tensor
  optional bytes context_logits = 11;
  // Generation logits (if requested) - serialized float tensor
  optional bytes generation_logits = 12;
}

// Token log probability information
message TokenLogprob {
  uint32 token_id = 1;
  float logprob = 2;
  // Top alternative tokens and their log probabilities
  repeated TopLogprob top_logprobs = 3;
}

message TopLogprob {
  uint32 token_id = 1;
  float logprob = 2;
}

// Performance metrics for request
message PerfMetrics {
  double arrival_time = 1;             // When the request arrived
  double first_scheduled_time = 2;     // When the request was first scheduled
  double first_token_time = 3;         // When the first token was generated
  double last_token_time = 4;          // When the last token was generated
  double kv_cache_transfer_start = 5;  // KV cache transfer start (disagg)
  double kv_cache_transfer_end = 6;    // KV cache transfer end (disagg)
  int64 kv_cache_size = 7;             // KV cache size in bytes
}

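// Assuming all timestamps share one clock, the usual latencies fall out by
// subtraction (Python sketch; requires return_perf_metrics in OutputConfig):
//
//   m = response.complete.perf_metrics
//   ttft = m.first_token_time - m.arrival_time  # time to first token
//   queue_delay = m.first_scheduled_time - m.arrival_time
//   e2e_latency = m.last_token_time - m.arrival_time
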
// Error response
message GenerateError {
  string message = 1;  // Human-readable error message
  string type = 2;     // Error type (e.g., "InvalidRequest", "InternalError")
  int32 code = 3;      // Error code
}

// ============================================================================
// Embed Request/Response (for embedding models)
// ============================================================================
message EmbedRequest {
  string request_id = 1;
  TokenizedInput tokenized = 2;
}

message EmbedResponse {
  string request_id = 1;
  repeated float embedding = 2;  // Embedding vector
  uint32 prompt_tokens = 3;
}

// ============================================================================
// Health Check
// ============================================================================
message HealthCheckRequest {}

message HealthCheckResponse {
  string status = 1;  // "OK" or an error description
}

// ============================================================================
// Abort Request
// ============================================================================
message AbortRequest {
  string request_id = 1;
}

message AbortResponse {
  bool success = 1;
  string message = 2;
}

// ============================================================================
// Model Info
// ============================================================================
message GetModelInfoRequest {}

message GetModelInfoResponse {
  string model_id = 1;       // Model identifier/path
  int32 max_input_len = 2;   // Maximum input length
  int32 max_seq_len = 3;     // Maximum sequence length (input + output)
  int32 max_batch_size = 4;  // Maximum batch size
  int32 vocab_size = 5;      // Vocabulary size
  int32 hidden_size = 6;     // Hidden dimension
  int32 num_layers = 7;      // Number of transformer layers
  int32 num_heads = 8;       // Number of attention heads
  // Supported features, e.g. "lora", "guided_decoding"
  repeated string supported_features = 9;
}

// ============================================================================
// Server Info
// ============================================================================
message GetServerInfoRequest {}

message GetServerInfoResponse {
  string version = 1;                // TensorRT-LLM version
  string backend = 2;                // "tensorrt" or "pytorch"
  int32 tensor_parallel_size = 3;    // TP size
  int32 pipeline_parallel_size = 4;  // PP size
  int32 context_parallel_size = 5;   // CP size
  int32 world_size = 6;              // Total world size
}
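
// A consistency check a client might perform (Python sketch; assumes
// world_size is the product of the parallelism degrees, which is the usual
// convention but is not stated by this schema):
//
//   info = stub.GetServerInfo(pb.GetServerInfoRequest())
//   expected = (info.tensor_parallel_size
//               * info.pipeline_parallel_size
//               * info.context_parallel_size)
//   assert info.world_size == expected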