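"""Post-processing helpers for TensorRT-LLM's OpenAI-compatible serving layer.

These functions turn generation results into OpenAI chat/completions
responses, either as non-streaming payloads or as lists of server-sent-event
(SSE) chunks. Request-derived state is carried in the ``*PostprocArgs``
dataclasses so the processors can run inside post-processing workers.
"""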
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Tuple, Union

from .._utils import nvtx_range_debug
from ..executor import (DetokenizedGenerationResultBase, GenerationResult,
                        GenerationResultBase)
from ..executor.postproc_worker import PostprocArgs
from ..executor.result import Logprob, TokenLogprobs
from ..llmapi.reasoning_parser import (BaseReasoningParser,
                                       ReasoningParserFactory)
from ..llmapi.tokenizer import TransformersTokenizer

# yapf: disable
from .harmony_adapter import (handle_non_streaming_response,
                              handle_streaming_response)
from .openai_protocol import (ChatCompletionLogProbs,
                              ChatCompletionLogProbsContent,
                              ChatCompletionNamedToolChoiceParam,
                              ChatCompletionRequest, ChatCompletionResponse,
                              ChatCompletionResponseChoice,
                              ChatCompletionResponseStreamChoice,
                              ChatCompletionStreamResponse,
                              ChatCompletionToolsParam, ChatMessage,
                              CompletionRequest, CompletionResponse,
                              CompletionResponseChoice,
                              CompletionResponseStreamChoice,
                              CompletionStreamResponse, DeltaMessage,
                              FunctionCall, PromptTokensDetails, StreamOptions,
                              ToolCall, UsageInfo, to_disaggregated_params)

# yapf: enable

@dataclass(kw_only=True)
class ChatPostprocArgs(PostprocArgs):
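    """Per-request state for post-processing chat completion outputs."""
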
    echo: bool = False
    role: Optional[str] = None
    model: Optional[str] = None
    num_choices: int = 1
    tools: Optional[List[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[Literal["none"],
                                ChatCompletionNamedToolChoiceParam]] = "none"
    return_logprobs: bool = False
    top_logprobs: bool = False
    stream_options: Optional[StreamOptions] = None
    last_message_content: Optional[str] = None
    reasoning_parser: Optional[str] = None
    reasoning_parser_dict: dict[int, BaseReasoningParser] = field(
        default_factory=dict)

    @classmethod
    def from_request(cls, request: ChatCompletionRequest):
        return cls(
            echo=request.echo,
            role="assistant"
            if request.add_generation_prompt else request.messages[-1]["role"],
            model=request.model,
            num_choices=request.n if request.n else 1,
            tools=request.tools,
            tool_choice=request.tool_choice,
            stream_options=request.stream_options,
            return_logprobs=bool(request.logprobs),
            top_logprobs=bool(request.top_logprobs),
        )


def create_logprobs(token_ids: List[int], tokenizer: TransformersTokenizer,
                    logprobs: List[float] | TokenLogprobs,
                    top_logprobs: bool) -> ChatCompletionLogProbs:
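    """Build OpenAI-style logprob content for the given tokens.

    ``logprobs`` is either a flat list of floats (one per token) or, when top
    logprobs were requested, a list of ``{token_id: Logprob}`` dicts. Logprob
    values are clamped to a minimum of -9999.0.
    """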
    assert len(token_ids) == len(logprobs), \
        "token_ids and logprobs have different lengths"
    content: List[ChatCompletionLogProbsContent] = []
    for token_id, logprob in zip(token_ids, logprobs):
        logprob: float | dict[int, Logprob]
        token = tokenizer.decode(token_id)
        chat_logprob = ChatCompletionLogProbsContent(
            token=token,
            bytes=list(token.encode("utf-8", errors="replace")),
        )
        if isinstance(logprob, dict):
            if token_id in logprob:
                chat_logprob.logprob = max(logprob[token_id].logprob, -9999.0)
            if top_logprobs:
                chat_logprob.top_logprobs = [
                    ChatCompletionLogProbsContent(
                        token=(tk := tokenizer.decode(tid)),
                        logprob=max(top_lp.logprob, -9999.0),
                        bytes=list(tk.encode("utf-8", errors="replace")))
                    for tid, top_lp in logprob.items()
                ]
        else:
            chat_logprob.logprob = max(logprob, -9999.0)
        content.append(chat_logprob)
    chat_logprobs = ChatCompletionLogProbs(content=content)
    return chat_logprobs


def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
                           streaming: bool) -> Tuple[str, str]:
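    """Split generated text into (content, reasoning_content).

    A reasoning parser is created lazily per output index and cached in
    ``args.reasoning_parser_dict`` so streaming deltas reuse the same parser
    state. Without a configured parser the text is returned unchanged with an
    empty reasoning part.
    """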
    reasoning_parser = None
    if args.reasoning_parser is not None:
        if output_index not in args.reasoning_parser_dict:
            args.reasoning_parser_dict[
                output_index] = ReasoningParserFactory.create_reasoning_parser(
                    args.reasoning_parser)
        reasoning_parser = args.reasoning_parser_dict[output_index]

    if reasoning_parser is not None:
        if not streaming:
            result = reasoning_parser.parse(text)
        else:
            result = reasoning_parser.parse_delta(text)
        content, reasoning_content = result.content, result.reasoning_content
    else:
        content, reasoning_content = text, ""

    return content, reasoning_content


@nvtx_range_debug("chat_stream_post_processor")
|
|
def chat_stream_post_processor(rsp: GenerationResultBase,
|
|
args: ChatPostprocArgs) -> List[str]:
|
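    """Convert a generation result into OpenAI chat-completion SSE chunks.

    On the first iteration this emits the role (and optional echo) chunks,
    then one delta chunk per output; when the request is done and
    ``stream_options.include_usage`` is set, a final usage-only chunk is
    appended. Each entry in the returned list is a complete ``data: ...``
    SSE message.
    """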

    def yield_first_chat(num_tokens: int,
                         idx: int,
                         role: str | None = None,
                         content: str | None = None):
        choice_data = ChatCompletionResponseStreamChoice(index=idx,
                                                         delta=DeltaMessage(
                                                             role=role,
                                                             content=content),
                                                         finish_reason=None)
        chunk = ChatCompletionStreamResponse(choices=[choice_data],
                                             model=args.model)
        if include_continuous_usage:
            chunk.usage = UsageInfo(
                prompt_tokens=num_tokens,
                total_tokens=num_tokens,
                completion_tokens=0,
                prompt_tokens_details=PromptTokensDetails(
                    cached_tokens=rsp.cached_tokens),
            )
        data = chunk.model_dump_json(exclude_none=True)
        return data

    res: List[str] = []
    finish_reason_sent = [False] * args.num_choices
    prompt_tokens = args.num_prompt_tokens
    if stream_option := args.stream_options:
        include_usage = stream_option.include_usage
        include_continuous_usage = include_usage and stream_option.continuous_usage_stats
    else:
        include_usage = False
        include_continuous_usage = False
    if args.first_iteration:
        for i in range(args.num_choices):
            res.append(
                f"data: {yield_first_chat(prompt_tokens, i, role=args.role)}\n\n"
            )
            if args.echo and args.last_message_content:
                res.append(
                    f"data: {yield_first_chat(prompt_tokens, i, content=args.last_message_content)}\n\n"
                )
        args.first_iteration = False

    for output in rsp.outputs:
        i = output.index

        if finish_reason_sent[i]:
            continue

        delta_text = output.text_diff

        delta_text, reasoning_delta_text = apply_reasoning_parser(
            args, i, delta_text, True)

        if args.tool_choice and type(
                args.tool_choice) is ChatCompletionNamedToolChoiceParam:
            delta_message = DeltaMessage(tool_calls=[
                ToolCall(function=FunctionCall(
                    name=args.tool_choice.function.name, arguments=delta_text))
            ])
        else:
            delta_message = DeltaMessage(content=delta_text,
                                         reasoning_content=reasoning_delta_text)

        choice = ChatCompletionResponseStreamChoice(
            index=i,
            delta=delta_message,
            finish_reason=None,
            avg_decoded_tokens_per_iter=getattr(rsp,
                                                'avg_decoded_tokens_per_iter',
                                                None))
        if args.return_logprobs:
            logprobs = output.logprobs_diff
            token_ids = output.token_ids_diff
            choice.logprobs = create_logprobs(token_ids, args.tokenizer,
                                              logprobs, args.top_logprobs)
        if output.finish_reason is not None:
            choice.finish_reason = output.finish_reason
            choice.stop_reason = output.stop_reason
            finish_reason_sent[i] = True
        chunk = ChatCompletionStreamResponse(choices=[choice], model=args.model)
        if include_continuous_usage:
            chunk.usage = UsageInfo(prompt_tokens=prompt_tokens,
                                    completion_tokens=output.length,
                                    total_tokens=output.length + prompt_tokens,
                                    prompt_tokens_details=PromptTokensDetails(
                                        cached_tokens=rsp.cached_tokens))
        data = chunk.model_dump_json(exclude_none=True)
        res.append(f"data: {data}\n\n")

    if include_usage and rsp._done:
        completion_tokens = sum(output.length for output in rsp.outputs)
        final_usage = UsageInfo(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            prompt_tokens_details=PromptTokensDetails(
                cached_tokens=rsp.cached_tokens),
        )

        final_usage_chunk = ChatCompletionStreamResponse(choices=[],
                                                         model=args.model,
                                                         usage=final_usage)
        final_usage_data = final_usage_chunk.model_dump_json()
        res.append(f"data: {final_usage_data}\n\n")
    return res


@nvtx_range_debug("chat_response_post_processor")
|
|
def chat_response_post_processor(
|
|
rsp: GenerationResultBase,
|
|
args: ChatPostprocArgs) -> ChatCompletionResponse:
|
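    """Convert a finished generation result into a ChatCompletionResponse.

    Applies the optional reasoning parser and named tool choice handling per
    output, attaches logprobs when requested, and aggregates token usage.
    """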
    choices: List[ChatCompletionResponseChoice] = []
    role = args.role
    for output in rsp.outputs:
        text, reasoning_text = apply_reasoning_parser(args, output.index,
                                                      output.text, False)

        if args.tool_choice and isinstance(args.tool_choice,
                                           ChatCompletionNamedToolChoiceParam):
            message = ChatMessage(
                role=role,
                content="",
                tool_calls=[
                    ToolCall(function=FunctionCall(
                        name=args.tool_choice.function.name, arguments=text))
                ])
        else:
            message = ChatMessage(role=role,
                                  content=text,
                                  reasoning_content=reasoning_text)
        disaggregated_params = to_disaggregated_params(
            output.disaggregated_params)
        choice = ChatCompletionResponseChoice(
            index=output.index,
            message=message,
            finish_reason=output.finish_reason,
            stop_reason=output.stop_reason,
            disaggregated_params=disaggregated_params,
            avg_decoded_tokens_per_iter=getattr(rsp,
                                                'avg_decoded_tokens_per_iter',
                                                None),
        )

        if args.return_logprobs:
            choice.logprobs = create_logprobs(output.token_ids, args.tokenizer,
                                              output.logprobs,
                                              args.top_logprobs)
        choices.append(choice)

    if args.echo and args.last_message_content:
        for choice in choices:
            full_message = args.last_message_content + choice.message.content
            choice.message.content = full_message

    num_prompt_tokens = args.num_prompt_tokens
    num_generated_tokens = sum(len(output.token_ids) for output in rsp.outputs)
    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
        prompt_tokens_details=PromptTokensDetails(
            cached_tokens=rsp.cached_tokens),
    )
    response = ChatCompletionResponse(
        model=args.model,
        choices=choices,
        usage=usage,
    )
    return response


@dataclass(kw_only=True)
class CompletionPostprocArgs(PostprocArgs):
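    """Per-request state for post-processing text completion outputs."""
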
    echo: bool = False
    model: Optional[str] = None
    num_choices: int = 1
    prompt_idx: int = 0
    detokenize: bool = True
    prompt: Optional[str] = None
    stream_options: Optional[StreamOptions] = None

    @classmethod
    def from_request(cls, request: CompletionRequest):
        return cls(
            echo=request.echo,
            model=request.model,
            num_choices=request.n if request.n else 1,
            stream_options=request.stream_options,
            detokenize=request.detokenize,
        )


@nvtx_range_debug("completion_stream_post_processor")
|
|
def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase,
|
|
args: CompletionPostprocArgs) -> List[str]:
|
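    """Convert a generation result into OpenAI completion SSE chunks.

    Each entry in the returned list is a complete ``data: ...`` SSE message;
    a final usage-only chunk is appended once the request finishes when
    ``stream_options.include_usage`` is set.
    """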
    res: List[str] = []
    prompt_tokens = args.num_prompt_tokens
    if stream_option := args.stream_options:
        include_usage = stream_option.include_usage
        include_continuous_usage = include_usage and stream_option.continuous_usage_stats
    else:
        include_usage = False
        include_continuous_usage = False

    for output in rsp.outputs:
        delta_text = output.text_diff
        if args.echo and args.first_iteration:
            delta_text = args.prompt + delta_text
        choice = CompletionResponseStreamChoice(
            index=args.prompt_idx * args.num_choices + output.index,
            text=delta_text if args.detokenize else "",
            token_ids=None if args.detokenize else output.token_ids_diff,
            finish_reason=output.finish_reason,
            stop_reason=output.stop_reason,
            avg_decoded_tokens_per_iter=getattr(rsp,
                                                'avg_decoded_tokens_per_iter',
                                                None),
        )
        chunk = CompletionStreamResponse(model=args.model, choices=[choice])
        if include_continuous_usage:
            chunk.usage = UsageInfo(prompt_tokens=prompt_tokens,
                                    completion_tokens=output.length,
                                    total_tokens=output.length + prompt_tokens,
                                    prompt_tokens_details=PromptTokensDetails(
                                        cached_tokens=rsp.cached_tokens))
        data = chunk.model_dump_json(exclude_unset=False)
        res.append(f"data: {data}\n\n")

    if include_usage and rsp._done:
        completion_tokens = sum(output.length for output in rsp.outputs)
        final_usage = UsageInfo(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            prompt_tokens_details=PromptTokensDetails(
                cached_tokens=rsp.cached_tokens),
        )

        final_usage_chunk = ChatCompletionStreamResponse(choices=[],
                                                         model=args.model,
                                                         usage=final_usage)
        final_usage_data = final_usage_chunk.model_dump_json()
        res.append(f"data: {final_usage_data}\n\n")
    args.first_iteration = False
    return res


@nvtx_range_debug("completion_response_post_processor")
|
|
def completion_response_post_processor(
|
|
rsp: GenerationResult,
|
|
args: CompletionPostprocArgs) -> CompletionResponse:
|
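    """Convert a finished generation result into a CompletionResponse."""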
    prompt_tokens = args.num_prompt_tokens
    completion_tokens = 0
    choices = []
    for output in rsp.outputs:
        text = output.text
        if args.echo:
            text = args.prompt + text
        disaggregated_params = to_disaggregated_params(
            output.disaggregated_params)
        choice = CompletionResponseChoice(
            text=text if args.detokenize else "",
            token_ids=None if args.detokenize else output.token_ids,
            index=args.prompt_idx * args.num_choices + output.index,
            disaggregated_params=disaggregated_params,
            context_logits=None
            if rsp.context_logits is None else rsp.context_logits.tolist(),
            stop_reason=output.stop_reason,
            finish_reason=output.finish_reason,
            avg_decoded_tokens_per_iter=getattr(rsp,
                                                'avg_decoded_tokens_per_iter',
                                                None),
        )

        completion_tokens += output.length
        choices.append(choice)

    usage = UsageInfo(prompt_tokens=prompt_tokens,
                      completion_tokens=completion_tokens,
                      total_tokens=completion_tokens + prompt_tokens,
                      prompt_tokens_details=PromptTokensDetails(
                          cached_tokens=rsp.cached_tokens))
    response = CompletionResponse(choices=choices,
                                  model=args.model,
                                  usage=usage)
    return response


@dataclass(kw_only=True)
class ChatCompletionPostprocArgs(PostprocArgs):
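    """Per-request state for the Harmony-based chat post-processors."""
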
    model: str
    tools: Optional[List[ChatCompletionToolsParam]]
    tool_choice: Optional[Union[Literal["none", "auto"],
                                ChatCompletionNamedToolChoiceParam]]
    request_id: Optional[int] = None

    @classmethod
    def from_request(cls, request: ChatCompletionRequest):
        return cls(
            model=request.model,
            tools=request.tools,
            tool_choice=request.tool_choice,
        )


@nvtx_range_debug("chat_harmony_post_processor")
|
|
def chat_harmony_post_processor(
|
|
rsp: GenerationResult,
|
|
args: ChatCompletionPostprocArgs) -> ChatCompletionResponse:
|
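    """Build a non-streaming chat response via the Harmony adapter."""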
    response = handle_non_streaming_response(
        tools=args.tools,
        tool_choice=args.tool_choice,
        outputs=rsp.outputs,
        model=args.model,
        num_prompt_tokens=args.num_prompt_tokens,
    )
    return response


@nvtx_range_debug("chat_harmony_streaming_post_processor")
|
|
def chat_harmony_streaming_post_processor(
|
|
rsp: GenerationResult, args: ChatCompletionPostprocArgs) -> List[str]:
|
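    """Build streaming chat response chunks via the Harmony adapter."""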
    response = handle_streaming_response(
        tools=args.tools,
        tool_choice=args.tool_choice,
        result=rsp,
        model=args.model,
        request_id=args.request_id,
        done=rsp._done,
        num_prompt_tokens=args.num_prompt_tokens,
    )
    return response