[TRTLLM-10388][feat] Support logprobs for Completions API (#10809)

Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
Pengyun Lin 2026-01-22 21:25:24 +08:00 committed by GitHub
parent 9beb971827
commit 5e34112b27
5 changed files with 230 additions and 11 deletions

View File

@@ -376,7 +376,10 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params
def to_sampling_params(self, vocab_size: int = 32000) -> SamplingParams:
def to_sampling_params(self,
vocab_size: int = 32000,
gather_generation_logits: bool = False,
backend: Optional[str] = None) -> SamplingParams:
sampling_params = SamplingParams(
best_of=self.best_of,
frequency_penalty=self.frequency_penalty,
@@ -416,17 +419,26 @@ class CompletionRequest(OpenAIBaseModel):
# completion-extra-params
add_special_tokens=self.add_special_tokens,
# TODO: migrate to use logprobs and prompt_logprobs
_return_log_probs=bool(self.logprobs),
)
if self.logprobs:
if backend == "pytorch":
sampling_params.logprobs = self.logprobs
else:
if gather_generation_logits:
sampling_params.logprobs = self.logprobs
elif self.logprobs > 1:
raise ValueError(
"`logprobs` must be 1 or `gather_generation_logits` must be `True` to use `logprobs` > 1"
)
else:
sampling_params._return_log_probs = True
return sampling_params
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if data.get("logprobs"):
raise ValueError("logprobs is not supported")
if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
raise ValueError("logprobs must be positive or zero")
return data
@model_validator(mode="before")
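Taken together, a minimal sketch of how the new `to_sampling_params` arguments steer the conversion, assuming `CompletionRequest` is importable from `tensorrt_llm.serve.openai_protocol` (model name and prompt are placeholders):

from tensorrt_llm.serve.openai_protocol import CompletionRequest

request = CompletionRequest(model="dummy-model",
                            prompt="Hello, my name is",
                            logprobs=2)

# PyTorch backend: logprobs are attached to SamplingParams directly.
params = request.to_sampling_params(backend="pytorch")
assert params.logprobs == 2

# Other backends: logprobs > 1 additionally requires gather_generation_logits,
# otherwise the conversion raises the ValueError shown above.
params = request.to_sampling_params(gather_generation_logits=True)
assert params.logprobs == 2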

View File

@@ -794,7 +794,9 @@ class OpenAIServer:
# Pass the tokenizer vocabulary size so ``logit_bias`` can be
# expanded into an embedding bias tensor in the sampler.
sampling_params = request.to_sampling_params(
vocab_size=self.tokenizer.tokenizer.vocab_size)
vocab_size=self.tokenizer.tokenizer.vocab_size,
gather_generation_logits=self.llm.args.gather_generation_logits,
backend=self.llm.args.backend)
# TODO: better way to enable metrics
if len(os.getenv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH", "")) > 0:
sampling_params.return_perf_metrics = True
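With the server forwarding `gather_generation_logits` and `backend` from the LLM arguments, clients can request per-token log probabilities through the standard OpenAI SDK. A minimal client-side sketch against a running `trtllm-serve` endpoint (base URL, API key, and model name are placeholders):

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
completion = client.completions.create(
    model="dummy-model",
    prompt="Hello, my name is",
    max_tokens=5,
    temperature=0.0,
    logprobs=1,
)
choice = completion.choices[0]
print(choice.logprobs.tokens)          # decoded tokens
print(choice.logprobs.token_logprobs)  # log probability of each sampled token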

View File

@@ -26,8 +26,8 @@ from .openai_protocol import (ChatCompletionLogProbs,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatCompletionToolsParam, ChatMessage,
CompletionRequest, CompletionResponse,
CompletionResponseChoice,
CompletionLogProbs, CompletionRequest,
CompletionResponse, CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse, DeltaFunctionCall,
DeltaMessage, DeltaToolCall, FunctionCall,
@@ -394,6 +394,7 @@ class CompletionPostprocArgs(PostprocArgs):
prompt_idx: int = 0
detokenize: bool = True
prompt: Optional[str] = None
return_logprobs: bool = False
stream_options: Optional[StreamOptions] = None
@classmethod
@@ -404,9 +405,43 @@ class CompletionPostprocArgs(PostprocArgs):
num_choices=request.n if request.n else 1,
stream_options=request.stream_options,
detokenize=request.detokenize,
return_logprobs=bool(request.logprobs),
)
def create_completion_logprobs(token_ids: List[int],
tokenizer: TransformersTokenizer,
logprobs: List[float] | TokenLogprobs,
initial_offset: int = 0) -> CompletionLogProbs:
assert len(token_ids) == len(logprobs), \
"token_ids and logprobs have different lengths"
text_offset = []
token_logprobs = []
top_logprobs_list = []
tokens = []
for token_id, logprob in zip(token_ids, logprobs):
if isinstance(logprob, dict):
token_logprobs.append(max(logprob[token_id].logprob, -9999.0))
top_logprobs_list.append({
tokenizer.decode(tid):
max(lp.logprob, -9999.0)
for tid, lp in logprob.items()
})
else:
token_logprobs.append(max(logprob, -9999.0))
token = tokenizer.decode(token_id)
if len(text_offset) == 0:
text_offset.append(initial_offset)
else:
text_offset.append(text_offset[-1] + len(token))
tokens.append(token)
return CompletionLogProbs(text_offset=text_offset,
token_logprobs=token_logprobs,
tokens=tokens,
top_logprobs=top_logprobs_list)
@nvtx_range_debug("completion_stream_post_processor")
def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase,
args: CompletionPostprocArgs) -> List[str]:
@@ -433,6 +468,12 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase,
'avg_decoded_tokens_per_iter',
None),
)
if args.return_logprobs:
logprobs = output.logprobs_diff
token_ids = output.token_ids_diff
choice.logprobs = create_completion_logprobs(
token_ids, args.tokenizer, logprobs, output._last_text_len)
chunk = CompletionStreamResponse(model=args.model, choices=[choice])
if include_continuous_usage:
chunk.usage = UsageInfo(prompt_tokens=prompt_tokens,
@@ -488,6 +529,11 @@ def completion_response_post_processor(
'avg_decoded_tokens_per_iter',
None),
)
if args.return_logprobs:
logprobs = output.logprobs
token_ids = output.token_ids
choice.logprobs = create_completion_logprobs(
token_ids, args.tokenizer, logprobs)
completion_tokens += output.length
choices.append(choice)
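To make the helper's behavior concrete, a toy sketch of `create_completion_logprobs` using stand-ins for the tokenizer and the per-token logprob records (the import path is assumed to be `tensorrt_llm.serve.postprocess_handlers`; `FakeTokenizer` and `FakeLogprob` are illustrative stand-ins, not real classes):

from collections import namedtuple

from tensorrt_llm.serve.postprocess_handlers import create_completion_logprobs

FakeLogprob = namedtuple("FakeLogprob", ["logprob"])  # stand-in exposing .logprob


class FakeTokenizer:  # stand-in exposing only .decode(token_id)
    vocab = {1: "Hello", 2: ",", 3: "Hi"}

    def decode(self, token_id):
        return self.vocab[token_id]


# Per-token dicts (the logprobs > 1 path) populate top_logprobs as well;
# plain floats (the _return_log_probs path) fill token_logprobs only.
out = create_completion_logprobs(
    token_ids=[1, 2],
    tokenizer=FakeTokenizer(),
    logprobs=[
        {1: FakeLogprob(-0.1), 3: FakeLogprob(-2.3)},
        {2: FakeLogprob(-0.5), 3: FakeLogprob(-1.7)},
    ],
)
assert out.tokens == ["Hello", ","]
assert out.token_logprobs == [-0.1, -0.5]
assert out.top_logprobs[0] == {"Hello": -0.1, "Hi": -2.3}
assert out.text_offset[0] == 0  # offsets start at initial_offset (default 0)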

View File

@@ -3,8 +3,10 @@
from typing import List
import numpy as np
import openai
import pytest
import yaml
from ..test_llm import get_model_path
from .openai_server import RemoteOpenAIServer
@@ -22,6 +24,23 @@ def backend(request):
return request.param
@pytest.fixture(scope="module")
def temp_extra_llm_api_options_file(tmp_path_factory):
extra_llm_api_options_dict = {
"enable_chunked_prefill": False,
"gather_generation_logits": True,
"kv_cache_config": {
"enable_block_reuse": False,
}
}
temp_file_path = tmp_path_factory.mktemp(
"config") / "extra_llm_api_options.yaml"
with open(temp_file_path, 'w') as f:
yaml.dump(extra_llm_api_options_dict, f)
return temp_file_path
@pytest.fixture(scope="module",
params=[0, 2],
ids=["disable_processpool", "enable_processpool"])
@@ -30,12 +49,16 @@ def num_postprocess_workers(request):
@pytest.fixture(scope="module")
def server(model_name: str, backend: str, num_postprocess_workers: int):
def server(model_name: str, backend: str, num_postprocess_workers: int,
temp_extra_llm_api_options_file: str):
model_path = get_model_path(model_name)
args = ["--backend", f"{backend}"]
args.extend(["--kv_cache_free_gpu_memory_fraction",
"0.2"]) # for co-existence with other servers
args.extend(["--num_postprocess_workers", f"{num_postprocess_workers}"])
if backend == "trt":
args.extend(
["--extra_llm_api_options", temp_extra_llm_api_options_file])
with RemoteOpenAIServer(model_path, args) as remote_server:
yield remote_server
@@ -433,6 +456,107 @@ async def test_completion_with_invalid_logit_bias(
await invalid_logit_bias_helper(async_client, model_name, 'completions')
def test_completion_logprobs(client: openai.OpenAI, model_name: str,
backend: str, num_postprocess_workers: int):
"""Test completion with logprobs enabled (non-streaming)."""
if backend == "trt" and num_postprocess_workers > 0:
pytest.skip("Logprobs is not supported in TRT processpool mode")
prompt = "Hello, my name is"
completion = client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
logprobs=1,
)
choice = completion.choices[0]
assert choice.logprobs is not None
# Verify logprobs structure
logprobs = choice.logprobs
assert logprobs.tokens is not None
assert logprobs.token_logprobs is not None
assert logprobs.text_offset is not None
# Verify lengths match
assert len(logprobs.tokens) == len(logprobs.token_logprobs)
assert len(logprobs.tokens) == len(logprobs.text_offset)
assert len(logprobs.tokens) > 0
# Verify logprobs values are valid (negative or zero for log probabilities)
for token_logprob in logprobs.token_logprobs:
assert token_logprob is not None
assert token_logprob <= 0
# Verify text_offset is monotonically increasing
for i in range(1, len(logprobs.text_offset)):
assert logprobs.text_offset[i] >= logprobs.text_offset[i - 1]
# Verify tokens are non-empty strings
for token in logprobs.tokens:
assert isinstance(token, str)
@pytest.mark.asyncio(loop_scope="module")
async def test_completion_logprobs_streaming(async_client: openai.AsyncOpenAI,
backend: str, model_name: str,
num_postprocess_workers: int):
"""Test completion with logprobs enabled (streaming)."""
if backend == "trt" and num_postprocess_workers > 0:
pytest.skip("Logprobs is not supported in TRT processpool mode")
prompt = "Hello, my name is"
# First get non-streaming result for comparison
single_completion = await async_client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
logprobs=1,
)
single_logprobs = single_completion.choices[0].logprobs
assert single_logprobs is not None
# Now test streaming
stream = await async_client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
logprobs=2,
stream=True,
)
all_tokens: List[str] = []
all_token_logprobs: List[float] = []
async for chunk in stream:
choice = chunk.choices[0]
if choice.logprobs is not None:
if choice.logprobs.tokens:
all_tokens.extend(choice.logprobs.tokens)
if choice.logprobs.token_logprobs:
all_token_logprobs.extend(choice.logprobs.token_logprobs)
# Verify streaming logprobs match non-streaming
assert all_tokens == single_logprobs.tokens
assert len(all_token_logprobs) == len(single_logprobs.token_logprobs)
# Compare logprobs values (should be close)
all_token_logprobs_arr = np.array(all_token_logprobs)
single_token_logprobs_arr = np.array(single_logprobs.token_logprobs)
assert np.allclose(all_token_logprobs_arr, single_token_logprobs_arr)
# Verify all logprobs are valid
for logprob in all_token_logprobs:
assert logprob is not None
assert logprob <= 0
def test_completion_cached_tokens(client: openai.OpenAI, model_name: str,
backend: str):
if backend == "trt":
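For the TRT backend, the logprobs tests above depend on the fixture enabling `gather_generation_logits`; a minimal sketch of producing the same config outside pytest (the output path, model path, and the exact `trtllm-serve` invocation in the comment are assumptions):

import yaml

# Enable generation-logit gathering so the TRT backend accepts logprobs > 1
# (mirrors the temp_extra_llm_api_options_file fixture above).
extra_options = {"gather_generation_logits": True}
with open("extra_llm_api_options.yaml", "w") as f:
    yaml.dump(extra_options, f)

# Then launch the server with this file, e.g. (assumed CLI shape):
#   trtllm-serve <model> --backend trt --extra_llm_api_options extra_llm_api_options.yaml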

View File

@@ -64,7 +64,7 @@ async def test_chat_completion_top5_logprobs(async_client: openai.AsyncOpenAI,
# Test top_logprobs
chat_completion = await async_client.chat.completions.create(
model=model_name,
messages=messages,
messages=messages, # type: ignore[arg-type]
max_completion_tokens=10,
temperature=0.0,
logprobs=True,
@@ -81,3 +81,38 @@
assert logprob_content.bytes is not None
assert logprob_content.top_logprobs is not None
assert len(logprob_content.top_logprobs) == 5
@pytest.mark.asyncio(loop_scope="module")
async def test_completion_top5_logprobs(async_client: openai.AsyncOpenAI,
model_name: str):
prompt = "Hello, my name is"
completion = await async_client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
logprobs=5,
extra_body={
"ignore_eos": True,
})
choice = completion.choices[0]
logprobs = choice.logprobs
assert logprobs is not None
assert logprobs.tokens is not None
assert logprobs.token_logprobs is not None
assert logprobs.top_logprobs is not None
assert len(logprobs.tokens) == len(logprobs.token_logprobs) == len(
logprobs.top_logprobs)
assert len(logprobs.tokens) > 0
for token, token_logprob, token_top_logprobs in zip(logprobs.tokens,
logprobs.token_logprobs,
logprobs.top_logprobs):
assert token is not None
assert token_logprob is not None
assert token_logprob <= 0
assert token_top_logprobs is not None
assert len(token_top_logprobs) == 5
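For reference, the `logprobs` object these assertions walk over follows the classic Completions API layout; an illustrative fragment written as a Python literal (token strings and values are made up to show the shape only):

logprobs_example = {
    "tokens": [" Alice", " Smith"],
    "token_logprobs": [-0.42, -1.17],
    "text_offset": [18, 24],
    "top_logprobs": [
        {" Alice": -0.42, " Anna": -1.88, " Bob": -2.05, " John": -2.31, " Max": -2.60},
        {" Smith": -1.17, " Lee": -1.34, ",": -2.02, " and": -2.45, " the": -2.71},
    ],
}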