From 5e34112b27ef53d8dcd1b1811defb16a8fd72035 Mon Sep 17 00:00:00 2001 From: Pengyun Lin <81065165+LinPoly@users.noreply.github.com> Date: Thu, 22 Jan 2026 21:25:24 +0800 Subject: [PATCH] [TRTLLM-10388][feat] Support logprobs for Completions API (#10809) Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com> --- tensorrt_llm/serve/openai_protocol.py | 24 +++- tensorrt_llm/serve/openai_server.py | 4 +- tensorrt_llm/serve/postprocess_handlers.py | 50 ++++++- .../llmapi/apps/_test_openai_completions.py | 126 +++++++++++++++++- .../apps/_test_trtllm_serve_top_logprobs.py | 37 ++++- 5 files changed, 230 insertions(+), 11 deletions(-) diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py index 61f8d8548e..a909212024 100644 --- a/tensorrt_llm/serve/openai_protocol.py +++ b/tensorrt_llm/serve/openai_protocol.py @@ -376,7 +376,10 @@ class CompletionRequest(OpenAIBaseModel): # doc: end-completion-extra-params - def to_sampling_params(self, vocab_size: int = 32000) -> SamplingParams: + def to_sampling_params(self, + vocab_size: int = 32000, + gather_generation_logits: bool = False, + backend: Optional[str] = None) -> SamplingParams: sampling_params = SamplingParams( best_of=self.best_of, frequency_penalty=self.frequency_penalty, @@ -416,17 +419,26 @@ class CompletionRequest(OpenAIBaseModel): # completion-extra-params add_special_tokens=self.add_special_tokens, - - # TODO: migrate to use logprobs and prompt_logprobs - _return_log_probs=bool(self.logprobs), ) + if self.logprobs: + if backend == "pytorch": + sampling_params.logprobs = self.logprobs + else: + if gather_generation_logits: + sampling_params.logprobs = self.logprobs + elif self.logprobs > 1: + raise ValueError( + "`logprobs` must be 1 or `gather_generation_logits` must be `True` to use `logprobs` > 1" + ) + else: + sampling_params._return_log_probs = True return sampling_params @model_validator(mode="before") @classmethod def check_logprobs(cls, data): - if data.get("logprobs"): - raise ValueError("logprobs is not supported") + if (logprobs := data.get("logprobs")) is not None and logprobs < 0: + raise ValueError("logprobs must be positive or zero") return data @model_validator(mode="before") diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 4c441237a6..524c7440ab 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -794,7 +794,9 @@ class OpenAIServer: # Pass the tokenizer vocabulary size so ``logit_bias`` can be # expanded into an embedding bias tensor in the sampler. 
sampling_params = request.to_sampling_params( - vocab_size=self.tokenizer.tokenizer.vocab_size) + vocab_size=self.tokenizer.tokenizer.vocab_size, + gather_generation_logits=self.llm.args.gather_generation_logits, + backend=self.llm.args.backend) # TODO: better way to enable metrics if len(os.getenv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH", "")) > 0: sampling_params.return_perf_metrics = True diff --git a/tensorrt_llm/serve/postprocess_handlers.py b/tensorrt_llm/serve/postprocess_handlers.py index 01ffb648e2..38c6c93563 100644 --- a/tensorrt_llm/serve/postprocess_handlers.py +++ b/tensorrt_llm/serve/postprocess_handlers.py @@ -26,8 +26,8 @@ from .openai_protocol import (ChatCompletionLogProbs, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatCompletionToolsParam, ChatMessage, - CompletionRequest, CompletionResponse, - CompletionResponseChoice, + CompletionLogProbs, CompletionRequest, + CompletionResponse, CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, DeltaFunctionCall, DeltaMessage, DeltaToolCall, FunctionCall, @@ -394,6 +394,7 @@ class CompletionPostprocArgs(PostprocArgs): prompt_idx: int = 0 detokenize: bool = True prompt: Optional[str] = None + return_logprobs: bool = False stream_options: Optional[StreamOptions] = None @classmethod @@ -404,9 +405,43 @@ class CompletionPostprocArgs(PostprocArgs): num_choices=request.n if request.n else 1, stream_options=request.stream_options, detokenize=request.detokenize, + return_logprobs=bool(request.logprobs), ) +def create_completion_logprobs(token_ids: List[int], + tokenizer: TransformersTokenizer, + logprobs: List[float] | TokenLogprobs, + initial_offset: int = 0) -> CompletionLogProbs: + assert len(token_ids) == len(logprobs), \ + "token_ids and logprobs have different lengths" + text_offset = [] + token_logprobs = [] + top_logprobs_list = [] + tokens = [] + for token_id, logprob in zip(token_ids, logprobs): + if isinstance(logprob, dict): + token_logprobs.append(max(logprob[token_id].logprob, -9999.0)) + top_logprobs_list.append({ + tokenizer.decode(tid): + max(lp.logprob, -9999.0) + for tid, lp in logprob.items() + }) + else: + token_logprobs.append(max(logprob, -9999.0)) + + token = tokenizer.decode(token_id) + if len(text_offset) == 0: + text_offset.append(initial_offset) + else: + text_offset.append(text_offset[-1] + len(token)) + tokens.append(token) + return CompletionLogProbs(text_offset=text_offset, + token_logprobs=token_logprobs, + tokens=tokens, + top_logprobs=top_logprobs_list) + + @nvtx_range_debug("completion_stream_post_processor") def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args: CompletionPostprocArgs) -> List[str]: @@ -433,6 +468,12 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, 'avg_decoded_tokens_per_iter', None), ) + if args.return_logprobs: + logprobs = output.logprobs_diff + token_ids = output.token_ids_diff + choice.logprobs = create_completion_logprobs( + token_ids, args.tokenizer, logprobs, output._last_text_len) + chunk = CompletionStreamResponse(model=args.model, choices=[choice]) if include_continuous_usage: chunk.usage = UsageInfo(prompt_tokens=prompt_tokens, @@ -488,6 +529,11 @@ def completion_response_post_processor( 'avg_decoded_tokens_per_iter', None), ) + if args.return_logprobs: + logprobs = output.logprobs + token_ids = output.token_ids + choice.logprobs = create_completion_logprobs( + token_ids, args.tokenizer, logprobs) completion_tokens += output.length choices.append(choice) diff --git 
a/tests/unittest/llmapi/apps/_test_openai_completions.py b/tests/unittest/llmapi/apps/_test_openai_completions.py index e3e374ec1c..03f08a5c6b 100644 --- a/tests/unittest/llmapi/apps/_test_openai_completions.py +++ b/tests/unittest/llmapi/apps/_test_openai_completions.py @@ -3,8 +3,10 @@ from typing import List +import numpy as np import openai import pytest +import yaml from ..test_llm import get_model_path from .openai_server import RemoteOpenAIServer @@ -22,6 +24,23 @@ def backend(request): return request.param +@pytest.fixture(scope="module") +def temp_extra_llm_api_options_file(tmp_path_factory): + extra_llm_api_options_dict = { + "enable_chunked_prefill": False, + "gather_generation_logits": True, + "kv_cache_config": { + "enable_block_reuse": False, + } + } + + temp_file_path = tmp_path_factory.mktemp( + "config") / "extra_llm_api_options.yaml" + with open(temp_file_path, 'w') as f: + yaml.dump(extra_llm_api_options_dict, f) + return temp_file_path + + @pytest.fixture(scope="module", params=[0, 2], ids=["disable_processpool", "enable_processpool"]) @@ -30,12 +49,16 @@ def num_postprocess_workers(request): @pytest.fixture(scope="module") -def server(model_name: str, backend: str, num_postprocess_workers: int): +def server(model_name: str, backend: str, num_postprocess_workers: int, + temp_extra_llm_api_options_file: str): model_path = get_model_path(model_name) args = ["--backend", f"{backend}"] args.extend(["--kv_cache_free_gpu_memory_fraction", "0.2"]) # for co-existence with other servers args.extend(["--num_postprocess_workers", f"{num_postprocess_workers}"]) + if backend == "trt": + args.extend( + ["--extra_llm_api_options", temp_extra_llm_api_options_file]) with RemoteOpenAIServer(model_path, args) as remote_server: yield remote_server @@ -433,6 +456,107 @@ async def test_completion_with_invalid_logit_bias( await invalid_logit_bias_helper(async_client, model_name, 'completions') +def test_completion_logprobs(client: openai.OpenAI, model_name: str, + backend: str, num_postprocess_workers: int): + """Test completion with logprobs enabled (non-streaming).""" + if backend == "trt" and num_postprocess_workers > 0: + pytest.skip("Logprobs is not supported in TRT processpool mode") + + prompt = "Hello, my name is" + + completion = client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + logprobs=1, + ) + + choice = completion.choices[0] + assert choice.logprobs is not None + + # Verify logprobs structure + logprobs = choice.logprobs + assert logprobs.tokens is not None + assert logprobs.token_logprobs is not None + assert logprobs.text_offset is not None + + # Verify lengths match + assert len(logprobs.tokens) == len(logprobs.token_logprobs) + assert len(logprobs.tokens) == len(logprobs.text_offset) + assert len(logprobs.tokens) > 0 + + # Verify logprobs values are valid (negative or zero for log probabilities) + for token_logprob in logprobs.token_logprobs: + assert token_logprob is not None + assert token_logprob <= 0 + + # Verify text_offset is monotonically increasing + for i in range(1, len(logprobs.text_offset)): + assert logprobs.text_offset[i] >= logprobs.text_offset[i - 1] + + # Verify tokens are non-empty strings + for token in logprobs.tokens: + assert isinstance(token, str) + + +@pytest.mark.asyncio(loop_scope="module") +async def test_completion_logprobs_streaming(async_client: openai.AsyncOpenAI, + backend: str, model_name: str, + num_postprocess_workers: int): + """Test completion with logprobs enabled (streaming).""" + if 
backend == "trt" and num_postprocess_workers > 0: + pytest.skip("Logprobs is not supported in TRT processpool mode") + + prompt = "Hello, my name is" + + # First get non-streaming result for comparison + single_completion = await async_client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + logprobs=1, + ) + single_logprobs = single_completion.choices[0].logprobs + assert single_logprobs is not None + + # Now test streaming + stream = await async_client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + logprobs=2, + stream=True, + ) + + all_tokens: List[str] = [] + all_token_logprobs: List[float] = [] + + async for chunk in stream: + choice = chunk.choices[0] + if choice.logprobs is not None: + if choice.logprobs.tokens: + all_tokens.extend(choice.logprobs.tokens) + if choice.logprobs.token_logprobs: + all_token_logprobs.extend(choice.logprobs.token_logprobs) + + # Verify streaming logprobs match non-streaming + assert all_tokens == single_logprobs.tokens + assert len(all_token_logprobs) == len(single_logprobs.token_logprobs) + + # Compare logprobs values (should be close) + all_token_logprobs_arr = np.array(all_token_logprobs) + single_token_logprobs_arr = np.array(single_logprobs.token_logprobs) + assert np.allclose(all_token_logprobs_arr, single_token_logprobs_arr) + + # Verify all logprobs are valid + for logprob in all_token_logprobs: + assert logprob is not None + assert logprob <= 0 + + def test_completion_cached_tokens(client: openai.OpenAI, model_name: str, backend: str): if backend == "trt": diff --git a/tests/unittest/llmapi/apps/_test_trtllm_serve_top_logprobs.py b/tests/unittest/llmapi/apps/_test_trtllm_serve_top_logprobs.py index dc95ecf292..aeb7be4a74 100644 --- a/tests/unittest/llmapi/apps/_test_trtllm_serve_top_logprobs.py +++ b/tests/unittest/llmapi/apps/_test_trtllm_serve_top_logprobs.py @@ -64,7 +64,7 @@ async def test_chat_completion_top5_logprobs(async_client: openai.AsyncOpenAI, # Test top_logprobs chat_completion = await async_client.chat.completions.create( model=model_name, - messages=messages, + messages=messages, # type: ignore[arg-type] max_completion_tokens=10, temperature=0.0, logprobs=True, @@ -81,3 +81,38 @@ async def test_chat_completion_top5_logprobs(async_client: openai.AsyncOpenAI, assert logprob_content.bytes is not None assert logprob_content.top_logprobs is not None assert len(logprob_content.top_logprobs) == 5 + + +@pytest.mark.asyncio(loop_scope="module") +async def test_completion_top5_logprobs(async_client: openai.AsyncOpenAI, + model_name: str): + prompt = "Hello, my name is" + + completion = await async_client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + logprobs=5, + extra_body={ + "ignore_eos": True, + }) + + choice = completion.choices[0] + logprobs = choice.logprobs + assert logprobs is not None + assert logprobs.tokens is not None + assert logprobs.token_logprobs is not None + assert logprobs.top_logprobs is not None + + assert len(logprobs.tokens) == len(logprobs.token_logprobs) == len( + logprobs.top_logprobs) + assert len(logprobs.tokens) > 0 + + for token, token_logprob, token_top_logprobs in zip(logprobs.tokens, + logprobs.token_logprobs, + logprobs.top_logprobs): + assert token is not None + assert token_logprob is not None + assert token_logprob <= 0 + assert token_top_logprobs is not None + assert len(token_top_logprobs) == 5