Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-05 02:31:33 +08:00
[TRTLLM-10388][feat] Support logprobs for Completions API (#10809)
Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
Parent: 9beb971827
Commit: 5e34112b27
@@ -376,7 +376,10 @@ class CompletionRequest(OpenAIBaseModel):
     # doc: end-completion-extra-params

-    def to_sampling_params(self, vocab_size: int = 32000) -> SamplingParams:
+    def to_sampling_params(self,
+                           vocab_size: int = 32000,
+                           gather_generation_logits: bool = False,
+                           backend: Optional[str] = None) -> SamplingParams:
         sampling_params = SamplingParams(
             best_of=self.best_of,
             frequency_penalty=self.frequency_penalty,

@@ -416,17 +419,26 @@ class CompletionRequest(OpenAIBaseModel):
             # completion-extra-params
             add_special_tokens=self.add_special_tokens,
-            # TODO: migrate to use logprobs and prompt_logprobs
-            _return_log_probs=bool(self.logprobs),
         )
+        if self.logprobs:
+            if backend == "pytorch":
+                sampling_params.logprobs = self.logprobs
+            else:
+                if gather_generation_logits:
+                    sampling_params.logprobs = self.logprobs
+                elif self.logprobs > 1:
+                    raise ValueError(
+                        "`logprobs` must be 1 or `gather_generation_logits` must be `True` to use `logprobs` > 1"
+                    )
+                else:
+                    sampling_params._return_log_probs = True
         return sampling_params

     @model_validator(mode="before")
     @classmethod
     def check_logprobs(cls, data):
-        if data.get("logprobs"):
-            raise ValueError("logprobs is not supported")
+        if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
+            raise ValueError("logprobs must be positive or zero")
         return data

     @model_validator(mode="before")

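For context, a minimal client-side sketch of the gating above (not part of the commit): on the PyTorch backend any positive `logprobs` value is honoured directly, while on the TRT backend `logprobs > 1` is only allowed when `gather_generation_logits` is enabled. The base URL, model name, and the error class surfaced to the client below are assumptions for illustration only.

    import openai

    # Hypothetical endpoint/model; adjust to your own deployment.
    client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

    # logprobs=1 is accepted on every backend: PyTorch sets
    # SamplingParams.logprobs, TRT falls back to _return_log_probs=True.
    ok = client.completions.create(model="my-model",
                                   prompt="Hello, my name is",
                                   max_tokens=4,
                                   logprobs=1)
    print(ok.choices[0].logprobs.token_logprobs)

    # logprobs>1 on the TRT backend without gather_generation_logits=True is
    # rejected by the new ValueError; how that surfaces over HTTP (4xx vs 5xx)
    # depends on the server's error handling.
    try:
        client.completions.create(model="my-model",
                                  prompt="Hello, my name is",
                                  max_tokens=4,
                                  logprobs=5)
    except openai.OpenAIError as exc:
        print(f"rejected: {exc}")
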
@@ -794,7 +794,9 @@ class OpenAIServer:
             # Pass the tokenizer vocabulary size so ``logit_bias`` can be
             # expanded into an embedding bias tensor in the sampler.
             sampling_params = request.to_sampling_params(
-                vocab_size=self.tokenizer.tokenizer.vocab_size)
+                vocab_size=self.tokenizer.tokenizer.vocab_size,
+                gather_generation_logits=self.llm.args.gather_generation_logits,
+                backend=self.llm.args.backend)
             # TODO: better way to enable metrics
             if len(os.getenv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH", "")) > 0:
                 sampling_params.return_perf_metrics = True

@@ -26,8 +26,8 @@ from .openai_protocol import (ChatCompletionLogProbs,
                               ChatCompletionResponseStreamChoice,
                               ChatCompletionStreamResponse,
                               ChatCompletionToolsParam, ChatMessage,
-                              CompletionRequest, CompletionResponse,
-                              CompletionResponseChoice,
+                              CompletionLogProbs, CompletionRequest,
+                              CompletionResponse, CompletionResponseChoice,
                               CompletionResponseStreamChoice,
                               CompletionStreamResponse, DeltaFunctionCall,
                               DeltaMessage, DeltaToolCall, FunctionCall,

@@ -394,6 +394,7 @@ class CompletionPostprocArgs(PostprocArgs):
     prompt_idx: int = 0
     detokenize: bool = True
     prompt: Optional[str] = None
+    return_logprobs: bool = False
     stream_options: Optional[StreamOptions] = None

     @classmethod

@@ -404,9 +405,43 @@ class CompletionPostprocArgs(PostprocArgs):
             num_choices=request.n if request.n else 1,
             stream_options=request.stream_options,
             detokenize=request.detokenize,
+            return_logprobs=bool(request.logprobs),
         )


+def create_completion_logprobs(token_ids: List[int],
+                               tokenizer: TransformersTokenizer,
+                               logprobs: List[float] | TokenLogprobs,
+                               initial_offset: int = 0) -> CompletionLogProbs:
+    assert len(token_ids) == len(logprobs), \
+        "token_ids and logprobs have different lengths"
+    text_offset = []
+    token_logprobs = []
+    top_logprobs_list = []
+    tokens = []
+    for token_id, logprob in zip(token_ids, logprobs):
+        if isinstance(logprob, dict):
+            token_logprobs.append(max(logprob[token_id].logprob, -9999.0))
+            top_logprobs_list.append({
+                tokenizer.decode(tid):
+                max(lp.logprob, -9999.0)
+                for tid, lp in logprob.items()
+            })
+        else:
+            token_logprobs.append(max(logprob, -9999.0))
+
+        token = tokenizer.decode(token_id)
+        if len(text_offset) == 0:
+            text_offset.append(initial_offset)
+        else:
+            text_offset.append(text_offset[-1] + len(token))
+        tokens.append(token)
+    return CompletionLogProbs(text_offset=text_offset,
+                              token_logprobs=token_logprobs,
+                              tokens=tokens,
+                              top_logprobs=top_logprobs_list)
+
+
 @nvtx_range_debug("completion_stream_post_processor")
 def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase,
                                      args: CompletionPostprocArgs) -> List[str]:

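A small self-contained sketch of the two logprob shapes the new helper accepts (illustrative only; FakeTokenizer and FakeLogprob stand in for the real TransformersTokenizer and the TokenLogprobs entries). A flat list of floats comes from the TRT `_return_log_probs` fallback and fills only `token_logprobs`; a list of per-token dicts comes from the PyTorch top-k path and additionally fills `top_logprobs` with decoded candidate strings, both clamped at -9999.0.

    from dataclasses import dataclass


    @dataclass
    class FakeLogprob:
        logprob: float


    class FakeTokenizer:
        _vocab = {7: "Hello", 11: ",", 13: " world"}

        def decode(self, token_id):
            return self._vocab[token_id]


    def summarize(token_ids, logprobs, tokenizer=FakeTokenizer()):
        # Mirrors the per-token handling in create_completion_logprobs above.
        token_logprobs, top_logprobs = [], []
        for token_id, logprob in zip(token_ids, logprobs):
            if isinstance(logprob, dict):
                token_logprobs.append(max(logprob[token_id].logprob, -9999.0))
                top_logprobs.append({
                    tokenizer.decode(tid): max(lp.logprob, -9999.0)
                    for tid, lp in logprob.items()
                })
            else:
                token_logprobs.append(max(logprob, -9999.0))
        return token_logprobs, top_logprobs


    # Flat floats (TRT fallback): no top_logprobs.
    print(summarize([7, 11], [-0.1, -float("inf")]))
    # -> ([-0.1, -9999.0], [])

    # Per-token dicts (PyTorch, SamplingParams.logprobs > 0): top_logprobs filled.
    print(summarize([7, 11], [
        {7: FakeLogprob(-0.1), 13: FakeLogprob(-2.3)},
        {11: FakeLogprob(-1.7), 13: FakeLogprob(-1.9)},
    ]))
    # -> ([-0.1, -1.7], [{'Hello': -0.1, ' world': -2.3}, {',': -1.7, ' world': -1.9}])
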
@@ -433,6 +468,12 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase,
                 'avg_decoded_tokens_per_iter',
                 None),
         )
+        if args.return_logprobs:
+            logprobs = output.logprobs_diff
+            token_ids = output.token_ids_diff
+            choice.logprobs = create_completion_logprobs(
+                token_ids, args.tokenizer, logprobs, output._last_text_len)
+
         chunk = CompletionStreamResponse(model=args.model, choices=[choice])
         if include_continuous_usage:
             chunk.usage = UsageInfo(prompt_tokens=prompt_tokens,

@@ -488,6 +529,11 @@ def completion_response_post_processor(
                 'avg_decoded_tokens_per_iter',
                 None),
         )
+        if args.return_logprobs:
+            logprobs = output.logprobs
+            token_ids = output.token_ids
+            choice.logprobs = create_completion_logprobs(
+                token_ids, args.tokenizer, logprobs)

         completion_tokens += output.length
         choices.append(choice)

@@ -3,8 +3,10 @@
 from typing import List

+import numpy as np
 import openai
 import pytest
+import yaml

 from ..test_llm import get_model_path
 from .openai_server import RemoteOpenAIServer

@@ -22,6 +24,23 @@ def backend(request):
     return request.param


+@pytest.fixture(scope="module")
+def temp_extra_llm_api_options_file(tmp_path_factory):
+    extra_llm_api_options_dict = {
+        "enable_chunked_prefill": False,
+        "gather_generation_logits": True,
+        "kv_cache_config": {
+            "enable_block_reuse": False,
+        }
+    }
+
+    temp_file_path = tmp_path_factory.mktemp(
+        "config") / "extra_llm_api_options.yaml"
+    with open(temp_file_path, 'w') as f:
+        yaml.dump(extra_llm_api_options_dict, f)
+    return temp_file_path
+
+
 @pytest.fixture(scope="module",
                 params=[0, 2],
                 ids=["disable_processpool", "enable_processpool"])

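Outside of pytest, the same options file can be written by hand and handed to the server so the TRT backend can return top-k logprobs. The sketch below assumes the trtllm-serve entry point and a placeholder model path; the flag names mirror what the fixture passes through RemoteOpenAIServer.

    import subprocess

    import yaml

    # Same options the fixture writes: gather_generation_logits must be on
    # for the TRT backend to honour logprobs > 1.
    options = {
        "enable_chunked_prefill": False,
        "gather_generation_logits": True,
        "kv_cache_config": {"enable_block_reuse": False},
    }
    with open("extra_llm_api_options.yaml", "w") as f:
        yaml.dump(options, f)

    # Placeholder model path; adjust to your checkpoint.
    subprocess.run([
        "trtllm-serve", "/path/to/model",
        "--backend", "trt",
        "--extra_llm_api_options", "extra_llm_api_options.yaml",
    ], check=True)
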
@@ -30,12 +49,16 @@ def num_postprocess_workers(request):
     return request.param


 @pytest.fixture(scope="module")
-def server(model_name: str, backend: str, num_postprocess_workers: int):
+def server(model_name: str, backend: str, num_postprocess_workers: int,
+           temp_extra_llm_api_options_file: str):
     model_path = get_model_path(model_name)
     args = ["--backend", f"{backend}"]
     args.extend(["--kv_cache_free_gpu_memory_fraction",
                  "0.2"])  # for co-existence with other servers
     args.extend(["--num_postprocess_workers", f"{num_postprocess_workers}"])
+    if backend == "trt":
+        args.extend(
+            ["--extra_llm_api_options", temp_extra_llm_api_options_file])
     with RemoteOpenAIServer(model_path, args) as remote_server:
         yield remote_server

@@ -433,6 +456,107 @@ async def test_completion_with_invalid_logit_bias(
     await invalid_logit_bias_helper(async_client, model_name, 'completions')


+def test_completion_logprobs(client: openai.OpenAI, model_name: str,
+                             backend: str, num_postprocess_workers: int):
+    """Test completion with logprobs enabled (non-streaming)."""
+    if backend == "trt" and num_postprocess_workers > 0:
+        pytest.skip("Logprobs is not supported in TRT processpool mode")
+
+    prompt = "Hello, my name is"
+
+    completion = client.completions.create(
+        model=model_name,
+        prompt=prompt,
+        max_tokens=5,
+        temperature=0.0,
+        logprobs=1,
+    )
+
+    choice = completion.choices[0]
+    assert choice.logprobs is not None
+
+    # Verify logprobs structure
+    logprobs = choice.logprobs
+    assert logprobs.tokens is not None
+    assert logprobs.token_logprobs is not None
+    assert logprobs.text_offset is not None
+
+    # Verify lengths match
+    assert len(logprobs.tokens) == len(logprobs.token_logprobs)
+    assert len(logprobs.tokens) == len(logprobs.text_offset)
+    assert len(logprobs.tokens) > 0
+
+    # Verify logprobs values are valid (negative or zero for log probabilities)
+    for token_logprob in logprobs.token_logprobs:
+        assert token_logprob is not None
+        assert token_logprob <= 0
+
+    # Verify text_offset is monotonically increasing
+    for i in range(1, len(logprobs.text_offset)):
+        assert logprobs.text_offset[i] >= logprobs.text_offset[i - 1]
+
+    # Verify tokens are non-empty strings
+    for token in logprobs.tokens:
+        assert isinstance(token, str)
+
+
+@pytest.mark.asyncio(loop_scope="module")
+async def test_completion_logprobs_streaming(async_client: openai.AsyncOpenAI,
+                                             backend: str, model_name: str,
+                                             num_postprocess_workers: int):
+    """Test completion with logprobs enabled (streaming)."""
+    if backend == "trt" and num_postprocess_workers > 0:
+        pytest.skip("Logprobs is not supported in TRT processpool mode")
+
+    prompt = "Hello, my name is"
+
+    # First get non-streaming result for comparison
+    single_completion = await async_client.completions.create(
+        model=model_name,
+        prompt=prompt,
+        max_tokens=5,
+        temperature=0.0,
+        logprobs=1,
+    )
+    single_logprobs = single_completion.choices[0].logprobs
+    assert single_logprobs is not None
+
+    # Now test streaming
+    stream = await async_client.completions.create(
+        model=model_name,
+        prompt=prompt,
+        max_tokens=5,
+        temperature=0.0,
+        logprobs=2,
+        stream=True,
+    )
+
+    all_tokens: List[str] = []
+    all_token_logprobs: List[float] = []
+
+    async for chunk in stream:
+        choice = chunk.choices[0]
+        if choice.logprobs is not None:
+            if choice.logprobs.tokens:
+                all_tokens.extend(choice.logprobs.tokens)
+            if choice.logprobs.token_logprobs:
+                all_token_logprobs.extend(choice.logprobs.token_logprobs)
+
+    # Verify streaming logprobs match non-streaming
+    assert all_tokens == single_logprobs.tokens
+    assert len(all_token_logprobs) == len(single_logprobs.token_logprobs)
+
+    # Compare logprobs values (should be close)
+    all_token_logprobs_arr = np.array(all_token_logprobs)
+    single_token_logprobs_arr = np.array(single_logprobs.token_logprobs)
+    assert np.allclose(all_token_logprobs_arr, single_token_logprobs_arr)
+
+    # Verify all logprobs are valid
+    for logprob in all_token_logprobs:
+        assert logprob is not None
+        assert logprob <= 0
+
+
 def test_completion_cached_tokens(client: openai.OpenAI, model_name: str,
                                   backend: str):
     if backend == "trt":

@@ -64,7 +64,7 @@ async def test_chat_completion_top5_logprobs(async_client: openai.AsyncOpenAI,
     # Test top_logprobs
     chat_completion = await async_client.chat.completions.create(
         model=model_name,
-        messages=messages,
+        messages=messages,  # type: ignore[arg-type]
         max_completion_tokens=10,
         temperature=0.0,
         logprobs=True,

@@ -81,3 +81,38 @@
         assert logprob_content.bytes is not None
         assert logprob_content.top_logprobs is not None
         assert len(logprob_content.top_logprobs) == 5
+
+
+@pytest.mark.asyncio(loop_scope="module")
+async def test_completion_top5_logprobs(async_client: openai.AsyncOpenAI,
+                                        model_name: str):
+    prompt = "Hello, my name is"
+
+    completion = await async_client.completions.create(model=model_name,
+                                                        prompt=prompt,
+                                                        max_tokens=5,
+                                                        temperature=0.0,
+                                                        logprobs=5,
+                                                        extra_body={
+                                                            "ignore_eos": True,
+                                                        })
+
+    choice = completion.choices[0]
+    logprobs = choice.logprobs
+    assert logprobs is not None
+    assert logprobs.tokens is not None
+    assert logprobs.token_logprobs is not None
+    assert logprobs.top_logprobs is not None
+
+    assert len(logprobs.tokens) == len(logprobs.token_logprobs) == len(
+        logprobs.top_logprobs)
+    assert len(logprobs.tokens) > 0
+
+    for token, token_logprob, token_top_logprobs in zip(logprobs.tokens,
+                                                        logprobs.token_logprobs,
+                                                        logprobs.top_logprobs):
+        assert token is not None
+        assert token_logprob is not None
+        assert token_logprob <= 0
+        assert token_top_logprobs is not None
+        assert len(token_top_logprobs) == 5

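For reference, the logprobs object these tests inspect follows the legacy OpenAI Completions layout built by create_completion_logprobs; an illustrative payload with made-up values for a request with logprobs=2 looks roughly like:

    # Illustrative values only; the field names match CompletionLogProbs.
    example_logprobs = {
        "tokens": [" Anna", " and"],
        "token_logprobs": [-0.42, -1.31],
        "text_offset": [18, 23],
        "top_logprobs": [
            {" Anna": -0.42, " John": -1.05},
            {" and": -1.31, ",": -1.48},
        ],
    }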