// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package trtllm;
// TensorRT-LLM gRPC Service
//
// This service provides high-performance inference for LLMs using TensorRT-LLM.
// It accepts pre-tokenized requests and returns raw token IDs, enabling efficient
// binary communication with external routers (e.g., sgl-router).
//
// Key Design Principles:
// - Token IDs only: No text in requests or responses (router handles tokenization)
// - Streaming delta mode: Chunks contain only new tokens since the last chunk
// - Full feature support: All TensorRT-LLM capabilities exposed

service TrtllmService {
  // Generate tokens from pre-tokenized input.
  // Returns a stream of responses containing token IDs.
  rpc Generate(GenerateRequest) returns (stream GenerateResponse);

  // Generate embeddings from pre-tokenized input (for embedding models).
  rpc Embed(EmbedRequest) returns (EmbedResponse);

  // Health check endpoint.
  rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse);

  // Abort a running generation request.
  rpc Abort(AbortRequest) returns (AbortResponse);

  // Get model information (vocab size, max lengths, etc.).
  rpc GetModelInfo(GetModelInfoRequest) returns (GetModelInfoResponse);

  // Get server information (version, parallelism, etc.).
  rpc GetServerInfo(GetServerInfoRequest) returns (GetServerInfoResponse);
}
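
// Illustrative Generate exchange (the request ID and token IDs below are
// assumed values for the sketch, not taken from the source):
//
//   --> Generate(GenerateRequest{ request_id: "req-1", streaming: true, ... })
//   <-- GenerateResponse{ request_id: "req-1", chunk: { token_ids: [11, 22] } }
//   <-- GenerateResponse{ request_id: "req-1", chunk: { token_ids: [33] } }
//   <-- GenerateResponse{ request_id: "req-1", complete: { output_token_ids: [11, 22, 33] } }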
// ============================================================================
// Generate Request
// ============================================================================

message GenerateRequest {
  // Unique request identifier (assigned by client/router)
  string request_id = 1;

  // Pre-tokenized input (REQUIRED - router tokenizes)
  TokenizedInput tokenized = 2;

  // Sampling configuration
  SamplingConfig sampling_config = 3;

  // Output configuration
  OutputConfig output_config = 4;

  // Maximum tokens to generate (REQUIRED)
  uint32 max_tokens = 5;

  // Enable streaming mode
  bool streaming = 6;

  // End-of-sequence token ID (optional; set to -1 to ignore EOS)
  optional int32 end_id = 7;

  // Padding token ID
  optional int32 pad_id = 8;

  // Bad word token sequences (the sampler prevents these sequences from
  // appearing in the output)
  repeated TokenSequence bad_words = 9;

  // Stop word token sequences (generation stops when one of these is produced)
  repeated TokenSequence stop_words = 10;

  // Guided decoding parameters (JSON schema, regex, grammar)
  optional GuidedDecodingParams guided_decoding = 11;

  // Embedding bias tensor (vocab_size floats, optional)
  repeated float embedding_bias = 12;

  // LoRA adapter configuration
  optional LoraConfig lora_config = 13;

  // Prompt tuning configuration
  optional PromptTuningConfig prompt_tuning_config = 14;

  // Multimodal input (for VLMs)
  optional MultimodalInput multimodal_input = 15;

  // KV cache retention configuration
  optional KvCacheRetentionConfig kv_cache_retention = 16;

  // Disaggregated inference parameters
  optional DisaggregatedParams disaggregated_params = 17;

  // Lookahead decoding configuration
  optional LookaheadConfig lookahead_config = 18;

  // Cache salt ID for cache hashing
  optional int64 cache_salt_id = 19;

  // Request arrival time (Unix timestamp, for metrics)
  optional double arrival_time = 20;
}
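
// A minimal streaming request, sketched in proto text format (the token IDs
// and sampling values below are illustrative assumptions):
//
//   request_id: "req-1"
//   tokenized { input_token_ids: [1, 15043, 3186] }
//   sampling_config { beam_width: 1, temperature: 0.7, top_p: 0.9 }
//   max_tokens: 128
//   streaming: true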
// Tokenized input from the router
message TokenizedInput {
  // Original text (for debugging/logging only; not used for generation)
  string original_text = 1;

  // Pre-tokenized input token IDs (REQUIRED)
  repeated uint32 input_token_ids = 2;

  // Query token IDs for VLM star attention (optional)
  repeated uint32 query_token_ids = 3;
}

// Sequence of token IDs (for stop/bad words); see the example below.
message TokenSequence {
  repeated uint32 token_ids = 1;
}
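
// For example, if the stop phrase "\n\nUser:" tokenizes to [271, 1502, 25]
// (an assumed tokenization), the router would send:
//
//   stop_words { token_ids: [271, 1502, 25] }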
// ============================================================================
// Sampling Configuration
// Maps to tensorrt_llm.bindings.executor.SamplingConfig
// ============================================================================

message SamplingConfig {
  // Beam width (1 for sampling, >1 for beam search)
  int32 beam_width = 1;

  // Number of sequences to return
  uint32 num_return_sequences = 2;

  // Top-K sampling (0 = disabled, i.e., all tokens are considered)
  optional int32 top_k = 3;

  // Top-P (nucleus) sampling threshold
  optional float top_p = 4;

  // Top-P minimum threshold for decay
  optional float top_p_min = 5;

  // Top-P reset token ID
  optional int32 top_p_reset_ids = 6;

  // Top-P decay factor
  optional float top_p_decay = 7;

  // Random seed for reproducibility
  optional uint64 seed = 8;

  // Temperature for sampling (0 = greedy; higher = more random)
  optional float temperature = 9;

  // Minimum tokens to generate before stopping
  optional uint32 min_tokens = 10;

  // Beam search diversity rate
  optional float beam_search_diversity_rate = 11;

  // Repetition penalty (>1 discourages, <1 encourages repetition)
  optional float repetition_penalty = 12;

  // Presence penalty (penalizes tokens that have already appeared)
  optional float presence_penalty = 13;

  // Frequency penalty (penalizes tokens based on how often they have appeared)
  optional float frequency_penalty = 14;

  // Number of prompt tokens to ignore when applying penalties
  optional int32 prompt_ignore_length = 15;

  // Length penalty for beam search
  optional float length_penalty = 16;

  // Early stopping for beam search
  optional int32 early_stopping = 17;

  // No-repeat n-gram size
  optional int32 no_repeat_ngram_size = 18;

  // Min-P sampling threshold
  optional float min_p = 19;

  // Variable beam width array for beam search
  repeated int32 beam_width_array = 20;
}
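
// Two illustrative configurations (values are assumptions): greedy decoding
// and seeded nucleus sampling.
//
//   sampling_config { beam_width: 1, temperature: 0.0 }
//   sampling_config { beam_width: 1, temperature: 0.8, top_p: 0.95, top_k: 40, seed: 42 }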
// ============================================================================
// Output Configuration
// Maps to tensorrt_llm.bindings.executor.OutputConfig
// ============================================================================

message OutputConfig {
  // Number of top log probabilities to return per output token
  optional int32 logprobs = 1;

  // Number of top log probabilities to return per prompt token
  optional int32 prompt_logprobs = 2;

  // Return the context logits tensor (large; use with caution)
  bool return_context_logits = 3;

  // Return the generation logits tensor (large; use with caution)
  bool return_generation_logits = 4;

  // Exclude input tokens from the output (set to true to enable; note that
  // TRT-LLM itself defaults to excluding them internally)
  bool exclude_input_from_output = 5;

  // Return encoder output (for encoder-decoder models)
  bool return_encoder_output = 6;

  // Return performance metrics
  bool return_perf_metrics = 7;
}
// ============================================================================
// Guided Decoding
// ============================================================================

message GuidedDecodingParams {
  // Guide type enumeration
  enum GuideType {
    GUIDE_TYPE_UNSPECIFIED = 0;
    GUIDE_TYPE_JSON = 1;            // JSON format (any valid JSON)
    GUIDE_TYPE_JSON_SCHEMA = 2;     // JSON constrained by a schema
    GUIDE_TYPE_REGEX = 3;           // Regular expression constraint
    GUIDE_TYPE_EBNF_GRAMMAR = 4;    // EBNF grammar constraint
    GUIDE_TYPE_STRUCTURAL_TAG = 5;  // Structural tag (xgrammar backend)
  }

  GuideType guide_type = 1;

  // Guide content (schema string, regex pattern, or grammar definition)
  string guide = 2;
}
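
// For example, to constrain output to a JSON object with a single string
// field "name" (the schema below is an illustrative assumption):
//
//   guided_decoding {
//     guide_type: GUIDE_TYPE_JSON_SCHEMA
//     guide: "{\"type\": \"object\", \"properties\": {\"name\": {\"type\": \"string\"}}}"
//   }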
// ============================================================================
// LoRA Configuration
// ============================================================================

message LoraConfig {
  // LoRA task/adapter ID
  int64 task_id = 1;

  // LoRA weights (serialized tensor; optional if already cached)
  optional bytes weights = 2;

  // LoRA config as a JSON string
  optional string config_json = 3;
}

// ============================================================================
// Prompt Tuning Configuration
// ============================================================================

message PromptTuningConfig {
  // Embedding table (serialized tensor)
  bytes embedding_table = 1;
}
// ============================================================================
// Multimodal Input (for Vision-Language Models)
// ============================================================================

message MultimodalInput {
  // Multimodal content hashes for caching
  repeated int64 multimodal_hashes = 1;

  // Positions in the input where multimodal content is inserted
  repeated int32 multimodal_positions = 2;

  // Lengths of the multimodal content at each position
  repeated int32 multimodal_lengths = 3;
}
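
// The three arrays above are parallel: entry i describes one multimodal span.
// E.g., with assumed values, an image whose 576 feature tokens start at
// offset 5 of the input sequence:
//
//   multimodal_input {
//     multimodal_hashes: [1234567890]
//     multimodal_positions: [5]
//     multimodal_lengths: [576]
//   }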
// ============================================================================
// KV Cache Retention
// ============================================================================

message KvCacheRetentionConfig {
  // Retention policy name
  string policy = 1;

  // Additional configuration as a JSON string
  string config_json = 2;
}
// ============================================================================
// Disaggregated Inference
// ============================================================================

message DisaggregatedParams {
  // Request type for disaggregated inference
  enum RequestType {
    REQUEST_TYPE_CONTEXT_AND_GENERATION = 0;  // Normal full request
    REQUEST_TYPE_CONTEXT_ONLY = 1;            // Prefill only
    REQUEST_TYPE_GENERATION_ONLY = 2;         // Decode only
  }

  RequestType request_type = 1;

  // Context request ID (links the context and generation phases)
  string ctx_request_id = 2;

  // Context phase parameters (for GENERATION_ONLY requests)
  optional ContextPhaseParams context_phase_params = 3;
}

message ContextPhaseParams {
  // First generated token ID from the context phase
  uint32 first_gen_token_id = 1;

  // KV cache block pointers (serialized)
  bytes kv_cache_blocks = 2;
}
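
// Sketch of a disaggregated flow (IDs and values are assumptions): the router
// first sends a CONTEXT_ONLY request to a prefill server, then forwards the
// returned context-phase state in a GENERATION_ONLY request to a decode server.
//
//   # To the prefill server:
//   disaggregated_params { request_type: REQUEST_TYPE_CONTEXT_ONLY, ctx_request_id: "ctx-1" }
//
//   # To the decode server:
//   disaggregated_params {
//     request_type: REQUEST_TYPE_GENERATION_ONLY
//     ctx_request_id: "ctx-1"
//     context_phase_params { first_gen_token_id: 42, kv_cache_blocks: "..." }
//   }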
// ============================================================================
// Lookahead Decoding
// ============================================================================

message LookaheadConfig {
  // Maximum lookahead window size
  int32 max_window_size = 1;
  // Maximum n-gram size used for lookahead
  int32 max_ngram_size = 2;
  // Maximum size of the verification set
  int32 max_verification_set_size = 3;
}
// ============================================================================
// Generate Response
// ============================================================================

message GenerateResponse {
  // Echo of the request ID
  string request_id = 1;

  // Response type (the oneof ensures exactly one is set)
  oneof response {
    GenerateStreamChunk chunk = 2;  // Streaming delta
    GenerateComplete complete = 3;  // Final response
    GenerateError error = 4;        // Error response
  }
}
// Streaming chunk containing delta tokens (only the new tokens since the
// last chunk); see the illustration below.
message GenerateStreamChunk {
  // NEW token IDs only (delta from the previous chunk)
  repeated uint32 token_ids = 1;

  // Beam/sequence index (for beam_width > 1 or n > 1)
  uint32 sequence_index = 2;

  // Token counts for usage tracking
  uint32 prompt_tokens = 3;
  uint32 completion_tokens = 4;
  uint32 cached_tokens = 5;

  // Log probabilities for this chunk's tokens (if requested)
  repeated TokenLogprob logprobs = 6;
}
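
// Delta semantics, illustrated with assumed token IDs: for a final output of
// [11, 22, 33, 44], a stream might carry
//
//   chunk { token_ids: [11, 22] }   // first delta
//   chunk { token_ids: [33] }       // only the new token
//   chunk { token_ids: [44] }
//   complete { output_token_ids: [11, 22, 33, 44] }   // cumulative, not a delta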
// Final/complete response with all output tokens
message GenerateComplete {
  // All output token IDs (cumulative, not a delta)
  repeated uint32 output_token_ids = 1;

  // Beam/sequence index
  uint32 sequence_index = 2;

  // Finish reason: "stop", "length", or "stop_word"
  string finish_reason = 3;

  // Specific stop reason (the stop word/token that triggered the stop)
  optional string stop_reason = 4;

  // Token counts for usage tracking
  uint32 prompt_tokens = 5;
  uint32 completion_tokens = 6;
  uint32 cached_tokens = 7;

  // Generation log probabilities (if requested)
  repeated TokenLogprob logprobs = 8;

  // Prompt log probabilities (if requested)
  repeated TokenLogprob prompt_logprobs = 9;

  // Performance metrics (if requested)
  optional PerfMetrics perf_metrics = 10;

  // Context logits (if requested), as a serialized float tensor
  optional bytes context_logits = 11;

  // Generation logits (if requested), as a serialized float tensor
  optional bytes generation_logits = 12;
}
// Token log probability information
message TokenLogprob {
  uint32 token_id = 1;
  float logprob = 2;

  // Top alternative tokens and their log probabilities
  repeated TopLogprob top_logprobs = 3;
}

message TopLogprob {
  uint32 token_id = 1;
  float logprob = 2;
}
// Performance metrics for a request
message PerfMetrics {
  double arrival_time = 1;             // When the request arrived
  double first_scheduled_time = 2;     // When the request was first scheduled
  double first_token_time = 3;         // When the first token was generated (basis for TTFT)
  double last_token_time = 4;          // When the last token was generated
  double kv_cache_transfer_start = 5;  // KV cache transfer start (disaggregated)
  double kv_cache_transfer_end = 6;    // KV cache transfer end (disaggregated)
  int64 kv_cache_size = 7;             // KV cache size in bytes
}
// Error response
message GenerateError {
  string message = 1;  // Human-readable error message
  string type = 2;     // Error type (e.g., "InvalidRequest", "InternalError")
  int32 code = 3;      // Error code
}

// ============================================================================
// Embed Request/Response (for embedding models)
// ============================================================================

message EmbedRequest {
  string request_id = 1;
  TokenizedInput tokenized = 2;
}

message EmbedResponse {
  string request_id = 1;
  repeated float embedding = 2;  // Embedding vector
  uint32 prompt_tokens = 3;
}
// ============================================================================
// Health Check
// ============================================================================

message HealthCheckRequest {}

message HealthCheckResponse {
  string status = 1;  // "OK" or an error description
}

// ============================================================================
// Abort Request
// ============================================================================

message AbortRequest {
  string request_id = 1;
}

message AbortResponse {
  bool success = 1;
  string message = 2;
}
// ============================================================================
// Model Info
// ============================================================================

message GetModelInfoRequest {}

message GetModelInfoResponse {
  string model_id = 1;       // Model identifier/path
  int32 max_input_len = 2;   // Maximum input length
  int32 max_seq_len = 3;     // Maximum sequence length (input + output)
  int32 max_batch_size = 4;  // Maximum batch size
  int32 vocab_size = 5;      // Vocabulary size
  int32 hidden_size = 6;     // Hidden dimension
  int32 num_layers = 7;      // Number of transformer layers
  int32 num_heads = 8;       // Number of attention heads

  // Supported features
  repeated string supported_features = 9;  // e.g., "lora", "guided_decoding"
}

// ============================================================================
// Server Info
// ============================================================================

message GetServerInfoRequest {}

message GetServerInfoResponse {
  string version = 1;                // TensorRT-LLM version
  string backend = 2;                // "tensorrt" or "pytorch"
  int32 tensor_parallel_size = 3;    // TP size
  int32 pipeline_parallel_size = 4;  // PP size
  int32 context_parallel_size = 5;   // CP size
  int32 world_size = 6;              // Total world size
}