mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[None][feat] Update reasoning parser for nano-v3 (#9944)
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
parent 9e7182b603
commit 3230fbe79a
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Type
+from typing import Any, Optional, Type


 @dataclass
@@ -109,15 +109,28 @@ class ReasoningParserFactory:
     parsers: dict[str, Type[BaseReasoningParser]] = {
         "deepseek-r1": DeepSeekR1Parser,
         "qwen3": DeepSeekR1Parser,
+        "nano-v3": DeepSeekR1Parser,
     }

     @staticmethod
-    def create_reasoning_parser(reasoning_parser: str) -> BaseReasoningParser:
+    def create_reasoning_parser(
+            reasoning_parser: str,
+            chat_template_kwargs: Optional[dict[str, Any]] = None
+    ) -> BaseReasoningParser:
         try:
             reasoning_parser_class = ReasoningParserFactory.parsers[
                 reasoning_parser.lower()]
             if reasoning_parser == "deepseek-r1":
                 return reasoning_parser_class(reasoning_at_start=True)
+            elif reasoning_parser == "nano-v3":
+                # Note: if the model runs with reasoning enabled (the default), `reasoning_at_start` should be True so the leading output is parsed into `reasoning_content`.
+                # If reasoning is disabled, `reasoning_at_start` should be False so the output is parsed into the `content` field.
+                is_reasoning_model = True
+                if isinstance(chat_template_kwargs, dict):
+                    is_reasoning_model = chat_template_kwargs.get(
+                        "enable_thinking", True)
+                return reasoning_parser_class(
+                    reasoning_at_start=is_reasoning_model)
             return reasoning_parser_class()
         except KeyError as e:
             raise ValueError(
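
With this change, "nano-v3" reuses DeepSeekR1Parser but decides whether to start in reasoning mode from the request's chat_template_kwargs. A minimal usage sketch of the factory follows; the import path and the </think> tag value are assumptions for illustration, while the parser name, methods, and result fields come from this diff and the tests further below:

# Illustrative sketch: import path assumed, not taken from this diff.
from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory

# Default (reasoning enabled): leading text is treated as reasoning.
parser = ReasoningParserFactory.create_reasoning_parser("nano-v3")
res = parser.parse("hidden chain of thought </think> final answer")
# res.reasoning_content == "hidden chain of thought "
# res.content == " final answer"

# Reasoning disabled via chat_template_kwargs: everything is plain content.
parser = ReasoningParserFactory.create_reasoning_parser(
    "nano-v3", {"enable_thinking": False})
res = parser.parse("final answer only")
# res.content == "final answer only", res.reasoning_content == ""
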
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import List, Literal, Optional, Tuple, Union
+from typing import Any, List, Literal, Optional, Tuple, Union

 from .._utils import nvtx_range_debug
 from ..executor import (DetokenizedGenerationResultBase, GenerationResult,
@@ -55,6 +55,7 @@ class ChatPostprocArgs(PostprocArgs):
     tool_parser_dict: dict[int, BaseToolParser] = field(default_factory=dict)
     has_tool_call: dict[int, bool] = field(default_factory=dict)
     tool_call_id_type: str = "random"
+    chat_template_kwargs: Optional[dict[str, Any]] = None

     @classmethod
     def from_request(cls, request: ChatCompletionRequest):
@@ -69,6 +70,7 @@ class ChatPostprocArgs(PostprocArgs):
             stream_options=request.stream_options,
             return_logprobs=bool(request.logprobs),
             top_logprobs=bool(request.top_logprobs),
+            chat_template_kwargs=request.chat_template_kwargs,
         )


@@ -108,9 +110,10 @@ def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
     reasoning_parser = None
     if args.reasoning_parser is not None:
         if output_index not in args.reasoning_parser_dict:
+            chat_template_kwargs = getattr(args, "chat_template_kwargs", None)
             args.reasoning_parser_dict[
                 output_index] = ReasoningParserFactory.create_reasoning_parser(
-                    args.reasoning_parser)
+                    args.reasoning_parser, chat_template_kwargs)
         reasoning_parser = args.reasoning_parser_dict[output_index]

     if reasoning_parser is not None:
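
The parser is created once per output index and cached in reasoning_parser_dict because parse_delta is stateful: across streamed chunks it must remember whether the end-of-thinking tag has already been seen. A rough sketch of the streaming behavior the new tests below exercise (the </think> tag value is assumed):

# Streaming sketch: one parser instance per output stream, reused for every delta.
parser = ReasoningParserFactory.create_reasoning_parser("nano-v3", None)
for chunk in ["thinking...", "</think>", "answer"]:
    res = parser.parse_delta(chunk)
    # chunk 1 -> reasoning_content "thinking...", content ""
    # chunk 2 -> both empty (the closing tag itself is swallowed)
    # chunk 3 -> content "answer", reasoning_content ""
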
@@ -501,6 +504,7 @@ class ChatCompletionPostprocArgs(PostprocArgs):
     tool_choice: Optional[Union[Literal["none", "auto"],
                                 ChatCompletionNamedToolChoiceParam]]
     request_id: Optional[int] = None
+    chat_template_kwargs: Optional[dict[str, Any]] = None

     @classmethod
     def from_request(cls, request: ChatCompletionRequest):
@@ -508,6 +512,7 @@ class ChatCompletionPostprocArgs(PostprocArgs):
             model=request.model,
             tools=request.tools,
             tool_choice=request.tool_choice,
+            chat_template_kwargs=request.chat_template_kwargs,
         )


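
End to end, the request's chat_template_kwargs now rides along into both postprocessing argument classes, so a client that disables thinking in the chat template also gets matching parser behavior. Roughly, with a hypothetical OpenAI-compatible request body (only the chat_template_kwargs handling is shown in this diff):

# Hypothetical request payload, for illustration only.
request_body = {
    "model": "nano-v3",
    "messages": [{"role": "user", "content": "What is 2 + 2?"}],
    "chat_template_kwargs": {"enable_thinking": False},
}
# With enable_thinking=False the server builds the parser with
# reasoning_at_start=False, so the reply is returned in `content`
# instead of `reasoning_content`.
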
@@ -71,3 +71,81 @@ def test_qwen3_reasoning_parser_stream(delta_texts: list, content: list,
         result = reasoning_parser.parse_delta(delta_text)
         assert result.content == content[i]
         assert result.reasoning_content == reasoning_context[i]
+
+
+@pytest.mark.parametrize(
+    ("text", "content", "reasoning_context", "chat_template_kwargs"),
+    [
+        ("a b", "", "a b", None),
+        (f"{R1_END} a b", " a b", "", None),
+        (f"a {R1_END} b", " b", "a ", None),
+        (f"a b {R1_END}", "", "a b ", None),
+        (f"{R1_START} a {R1_END} b", " b", f"{R1_START} a ", None),
+        # All without reasoning_context.
+        ("a b", "a b", "", {
+            "enable_thinking": False
+        }),
+        (f"{R1_END} a b", f"{R1_END} a b", "", {
+            "enable_thinking": False
+        }),
+        (f"a {R1_END} b", f"a {R1_END} b", "", {
+            "enable_thinking": False
+        }),
+        (f"a b {R1_END}", f"a b {R1_END}", "", {
+            "enable_thinking": False
+        }),
+    ])
+def test_nano_v3_reasoning_parser(text: str, content: str,
+                                  reasoning_context: str,
+                                  chat_template_kwargs: dict):
+    reasoning_parser = ReasoningParserFactory.create_reasoning_parser(
+        "nano-v3", chat_template_kwargs)
+    result = reasoning_parser.parse(text)
+    print(f"text: {text}, result: {result}")
+    assert result.content == content
+    assert result.reasoning_content == reasoning_context
+
+
+@pytest.mark.parametrize(
+    ("delta_texts", "content", "reasoning_context", "chat_template_kwargs"),
+    [
+        (["a", "b"], ["", ""], ["a", "b"], None),
+        ([R1_END, "a", "b"], ["", "a", "b"], ["", "", ""], None),
+        (["a", R1_END, "b"], ["", "", "b"], ["a", "", ""], None),
+        (["a", "b", R1_END], ["", "", ""], ["a", "b", ""], None),
+        (["a", f"l{R1_END}", "b"], ["", "", "b"], ["a", "l", ""], None),
+        (["a", f"l{R1_END}r", "b"], ["", "r", "b"], ["a", "l", ""], None),
+        (["a", f"{R1_END}r", "b"], ["", "r", "b"], ["a", "", ""], None),
+        # All without reasoning_context.
+        (["a", "b"], ["a", "b"], ["", ""], {
+            "enable_thinking": False
+        }),
+        ([R1_END, "a", "b"], ["", f"{R1_END}a", "b"], ["", "", ""], {
+            "enable_thinking": False
+        }),
+        (["a", R1_END, "b"], ["a", "", f"{R1_END}b"], ["", "", ""], {
+            "enable_thinking": False
+        }),
+        (["a", "b", R1_END], ["a", "b", ""], ["", "", ""], {
+            "enable_thinking": False
+        }),
+        (["a", f"l{R1_END}", "b"], ["a", f"l{R1_END}", "b"], ["", "", ""], {
+            "enable_thinking": False
+        }),
+        (["a", f"l{R1_END}r", "b"], ["a", f"l{R1_END}r", "b"], ["", "", ""], {
+            "enable_thinking": False
+        }),
+        (["a", f"{R1_END}r", "b"], ["a", f"{R1_END}r", "b"], ["", "", ""], {
+            "enable_thinking": False
+        }),
+    ])
+def test_nano_v3_reasoning_parser_stream(delta_texts: list, content: list,
+                                         reasoning_context: list,
+                                         chat_template_kwargs: dict):
+    reasoning_parser = ReasoningParserFactory.create_reasoning_parser(
+        "nano-v3", chat_template_kwargs)
+    for i, delta_text in enumerate(delta_texts):
+        result = reasoning_parser.parse_delta(delta_text)
+        print(f"delta_text: {delta_text}, result: {result}")
+        assert result.content == content[i]
+        assert result.reasoning_content == reasoning_context[i]
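
The new cases can be selected by test name with pytest's -k filter; the test module path is not shown in this diff, so run it from wherever the existing qwen3 reasoning-parser tests live:

pytest -q -k "nano_v3_reasoning_parser"
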