[TRTLLM-10154][feat] Enable guided decoding with reasoning parsers (#10890)

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Enwei Zhu 2026-01-22 14:14:28 +08:00 committed by GitHub
parent 895bb94b3d
commit be4a431ffd
9 changed files with 130 additions and 29 deletions

View File

@ -37,6 +37,7 @@ from typing_extensions import Annotated, Required, TypeAlias, TypedDict
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams
from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory
def _logit_bias_to_embedding_bias(logit_bias: Optional[Dict[str, float]],
@ -192,40 +193,113 @@ class CompletionStreamResponse(OpenAIBaseModel):
def _response_format_to_guided_decoding_params(
response_format: Optional[ResponseFormat]
response_format: Optional[ResponseFormat],
reasoning_parser: Optional[str] = None,
) -> Optional[GuidedDecodingParams]:
if response_format is None:
return None
guided_decoding_params = None
elif response_format.type == "text":
return None
guided_decoding_params = None
elif response_format.type == "json":
if response_format.schema is None:
raise ValueError(
"The 'schema' field is required when response_format.type is 'json'."
f"response_format.schema is required for response_format.type == {response_format.type!r}, but got None."
)
return GuidedDecodingParams(json=response_format.schema)
guided_decoding_params = GuidedDecodingParams(
json=response_format.schema)
elif response_format.type == "json_schema":
if response_format.json_schema is None:
raise ValueError(
"The 'json_schema' field is required when response_format.type is 'json_schema'."
f"response_format.json_schema is required for response_format.type == {response_format.type!r}, but got None."
)
return GuidedDecodingParams(json=response_format.json_schema)
guided_decoding_params = GuidedDecodingParams(
json=response_format.json_schema)
elif response_format.type == "json_object":
return GuidedDecodingParams(json_object=True)
guided_decoding_params = GuidedDecodingParams(json_object=True)
elif response_format.type == "regex":
return GuidedDecodingParams(regex=response_format.regex)
if response_format.regex is None:
raise ValueError(
f"response_format.regex is required for response_format.type == {response_format.type!r}, but got None."
)
guided_decoding_params = GuidedDecodingParams(
regex=response_format.regex)
elif response_format.type == "ebnf":
return GuidedDecodingParams(grammar=response_format.ebnf)
if response_format.ebnf is None:
raise ValueError(
f"response_format.ebnf is required for response_format.type == {response_format.type!r}, but got None."
)
guided_decoding_params = GuidedDecodingParams(
grammar=response_format.ebnf)
elif response_format.type == "structural_tag":
return GuidedDecodingParams(
guided_decoding_params = GuidedDecodingParams(
structural_tag=response_format.model_dump_json(by_alias=True,
exclude_none=True))
else:
raise ValueError(f"Unsupported response format: {response_format.type}")
if guided_decoding_params is None or reasoning_parser is None:
return guided_decoding_params
if guided_decoding_params.structural_tag is not None:
return guided_decoding_params
# Adapt guided_decoding_params for reasoning parser
if guided_decoding_params.json is not None:
content = {
"type": "json_schema",
"json_schema": guided_decoding_params.json
}
elif guided_decoding_params.json_object:
content = {"type": "json_schema", "json_schema": {"type": "object"}}
elif guided_decoding_params.regex is not None:
content = {"type": "regex", "pattern": guided_decoding_params.regex}
elif guided_decoding_params.grammar is not None:
content = {"type": "grammar", "grammar": guided_decoding_params.grammar}
if reasoning_parser == "gpt_oss":
# Trigger user constraint by final channel
stag_format = {
"type":
"triggered_tags",
"triggers": ["<|start|>assistant<|channel|>final<|message|>"],
"tags": [
{
"begin": "<|start|>assistant<|channel|>final<|message|>",
"content": content,
"end": "",
},
],
"stop_after_first":
True,
}
else:
# Force thinking and then trigger user constraint
parser = ReasoningParserFactory.create_reasoning_parser(
reasoning_parser)
stag_format = {
"type":
"sequence",
"elements": [
{
"type": "tag",
"begin": parser.reasoning_start,
"content": {
"type": "any_text"
},
"end": parser.reasoning_end,
},
content,
],
}
stag_format = ResponseFormat(type="structural_tag", format=stag_format)
return GuidedDecodingParams(structural_tag=stag_format.model_dump_json(
by_alias=True, exclude_none=True))
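
For reference, a minimal sketch (illustration only, not part of the diff) of the structural-tag payloads this adaptation produces for a json_object constraint. The <think>/</think> delimiters are an assumption standing in for a generic parser's reasoning_start/reasoning_end; the Harmony final-channel trigger for gpt_oss is taken verbatim from the code above. Either dict is then wrapped in ResponseFormat(type="structural_tag", format=...) and serialized into GuidedDecodingParams.structural_tag.

# Illustration only: structural tag for a non-gpt_oss reasoning parser.
# ASSUMPTION: the parser's reasoning_start/reasoning_end are "<think>"/"</think>".
stag_sequence = {
    "type": "sequence",
    "elements": [
        {
            # Free-form reasoning delimited by the parser's markers.
            "type": "tag",
            "begin": "<think>",
            "content": {"type": "any_text"},
            "end": "</think>",
        },
        # The user's constraint applies only to the final answer.
        {"type": "json_schema", "json_schema": {"type": "object"}},
    ],
}

# Illustration only: structural tag for gpt_oss, triggered by the Harmony
# final channel so the analysis (reasoning) channel stays unconstrained.
stag_gpt_oss = {
    "type": "triggered_tags",
    "triggers": ["<|start|>assistant<|channel|>final<|message|>"],
    "tags": [{
        "begin": "<|start|>assistant<|channel|>final<|message|>",
        "content": {"type": "json_schema", "json_schema": {"type": "object"}},
        "end": "",
    }],
    "stop_after_first": True,
}
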
def _response_format_text_config_to_guided_decoding_params(
text_format: Optional[ResponseFormatTextConfig]
text_format: Optional[ResponseFormatTextConfig],
reasoning_parser: Optional[str] = None,
) -> Optional[GuidedDecodingParams]:
if text_format is None:
return None
@ -233,7 +307,8 @@ def _response_format_text_config_to_guided_decoding_params(
resp_format = ResponseFormat(type=text_format.type,
json_schema=getattr(text_format, "schema_",
None))
return _response_format_to_guided_decoding_params(resp_format)
return _response_format_to_guided_decoding_params(
resp_format, reasoning_parser=reasoning_parser)
class CompletionRequest(OpenAIBaseModel):
@ -650,6 +725,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
def to_sampling_params(self,
vocab_size: int = 32000,
gather_generation_logits: bool = False,
reasoning_parser: Optional[str] = None,
backend: Optional[str] = None) -> SamplingParams:
sampling_params = SamplingParams(
frequency_penalty=self.frequency_penalty,
@ -680,7 +756,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
spaces_between_special_tokens=self.spaces_between_special_tokens,
truncate_prompt_tokens=self.truncate_prompt_tokens,
guided_decoding=_response_format_to_guided_decoding_params(
self.response_format),
self.response_format, reasoning_parser=reasoning_parser),
# logits_bias
embedding_bias=_logit_bias_to_embedding_bias(
@ -810,6 +886,7 @@ class ResponsesRequest(OpenAIBaseModel):
def to_sampling_params(
self,
default_sampling_params: Optional[dict] = None,
reasoning_parser: Optional[str] = None,
) -> SamplingParams:
max_tokens = None
if self.max_output_tokens is not None:
@ -828,7 +905,7 @@ class ResponsesRequest(OpenAIBaseModel):
guided_decoding = None
if self.text is not None and self.text.format is not None:
guided_decoding = _response_format_text_config_to_guided_decoding_params(
self.text.format)
self.text.format, reasoning_parser=reasoning_parser)
return SamplingParams(
temperature=temperature,

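A hypothetical usage sketch (model name, message, and parser name are placeholders, not from this diff) showing how the new reasoning_parser argument changes the guided-decoding constraint attached to the sampling params:

# Hypothetical sketch: a JSON-constrained chat request adapted for a reasoning model.
request = ChatCompletionRequest(
    model="DeepSeek-R1",  # placeholder model name
    messages=[{"role": "user", "content": "Answer with a JSON object."}],
    response_format={"type": "json_object"},
)

# Without a reasoning parser, the constraint is a plain json_object grammar.
plain = request.to_sampling_params()
assert plain.guided_decoding.json_object

# With a reasoning parser, the constraint is wrapped in a structural tag that
# leaves the reasoning section unconstrained and applies only to the final answer.
adapted = request.to_sampling_params(
    reasoning_parser="deepseek-r1")  # use a name registered with ReasoningParserFactory
assert adapted.guided_decoding.structural_tag is not None
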
View File

@ -540,6 +540,7 @@ class OpenAIServer:
sampling_params = request.to_sampling_params(
vocab_size=self.tokenizer.tokenizer.vocab_size,
gather_generation_logits=self.llm.args.gather_generation_logits,
reasoning_parser=self.llm.args.reasoning_parser,
backend=self.llm.args.backend)
postproc_args = ChatPostprocArgs.from_request(request)
disaggregated_params = to_llm_disaggregated_params(request.disaggregated_params)
@ -916,7 +917,8 @@ class OpenAIServer:
request.stop_token_ids = harmony_stop_tokens
sampling_params = request.to_sampling_params(
vocab_size=self.tokenizer.tokenizer.vocab_size)
vocab_size=self.tokenizer.tokenizer.vocab_size,
reasoning_parser="gpt_oss")
sampling_params.detokenize = False # Harmony adapter handles detokenization
postproc_args = ChatCompletionPostprocArgs.from_request(request)
@ -1018,6 +1020,7 @@ class OpenAIServer:
tokenizer=self.tokenizer if not self.use_harmony else None,
model_config=self.model_config if not self.use_harmony else None,
processor=self.processor if not self.use_harmony else None,
reasoning_parser=self.llm.args.reasoning_parser if not self.use_harmony else "gpt_oss",
)
streaming_processor = None
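
On the client side nothing changes: a standard OpenAI request with response_format works against a server whose llmapi options configure a reasoning parser. A minimal sketch with the OpenAI Python SDK, assuming a locally running trtllm-serve endpoint (URL, model name, and API key are placeholders):

# Hypothetical client-side sketch; only the final answer is constrained to valid JSON,
# while the model's reasoning section is generated freely.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-used")
resp = client.chat.completions.create(
    model="DeepSeek-R1",  # placeholder model name
    messages=[{"role": "user", "content": "Give the capital of France as JSON."}],
    response_format={"type": "json_object"},
)
print(resp.choices[0].message.content)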

View File

@ -867,13 +867,16 @@ async def request_preprocess(
tokenizer: Optional[Union[TransformersTokenizer, TokenizerBase]] = None,
model_config: Optional[PretrainedConfig] = None,
processor: Optional[AutoProcessor] = None,
reasoning_parser: Optional[str] = None,
) -> tuple[list[int], SamplingParams]:
sampling_params = request.to_sampling_params(
default_sampling_params={
"stop_token_ids":
get_harmony_adapter().get_stop_tokens() if use_harmony else []
})
},
reasoning_parser=reasoning_parser,
)
prev_response_id = request.previous_response_id

View File

@ -831,7 +831,7 @@
"test_e2e.py::test_mistral_e2e[use_py_session-remove_input_padding--]": 157.39577213302255,
"test_e2e.py::test_mistral_large_hidden_vocab_size": 81.36711680702865,
"test_e2e.py::test_openai_chat_example": 876.1966922096908,
"test_e2e.py::test_openai_chat_guided_decoding": 55.12449237401597,
"test_e2e.py::test_openai_chat_guided_decoding[meta-llama/Llama-3.1-8B-Instruct]": 55.12449237401597,
"test_e2e.py::test_openai_chat_harmony": 1162.7252594940364,
"test_e2e.py::test_openai_chat_multimodal_example": 215.8254322744906,
"test_e2e.py::test_openai_consistent_chat": 0.0001894170418381691,

View File

@ -1736,11 +1736,14 @@ def test_openai_mmencoder_example(llm_root, llm_venv):
str(test_root / "_test_openai_mmencoder.py")])
def test_openai_chat_guided_decoding(llm_root, llm_venv):
@pytest.mark.parametrize(
"model_name", ["meta-llama/Llama-3.1-8B-Instruct", "openai/gpt-oss-120b"])
def test_openai_chat_guided_decoding(llm_root, llm_venv, model_name: str):
test_root = unittest_path() / "llmapi" / "apps"
llm_venv.run_cmd([
"-m", "pytest",
str(test_root / "_test_openai_chat_guided_decoding.py")
str(test_root / "_test_openai_chat_guided_decoding.py"), "-k",
model_name
])

View File

@ -346,7 +346,8 @@ test_e2e.py::test_mistral_e2e[use_py_session---]
test_e2e.py::test_qwen_e2e_cpprunner_large_new_tokens[DeepSeek-R1-Distill-Qwen-1.5B-DeepSeek-R1-Distill-Qwen-1.5B]
test_e2e.py::test_openai_multi_chat_example
test_e2e.py::test_openai_consistent_chat
test_e2e.py::test_openai_chat_guided_decoding
test_e2e.py::test_openai_chat_guided_decoding[meta-llama/Llama-3.1-8B-Instruct]
test_e2e.py::test_openai_chat_guided_decoding[openai/gpt-oss-120b]
test_e2e.py::test_openai_chat_harmony
test_e2e.py::test_trtllm_benchmark_serving[llama-3.1-model/Meta-Llama-3.1-8B]
test_e2e.py::test_trtllm_benchmark_serving[gpt_oss/gpt-oss-20b]

View File

@ -67,6 +67,7 @@ l0_b200:
- test_e2e.py::test_ptp_quickstart_advanced_eagle3[Llama-3.1-8b-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct-EAGLE3-LLaMA3.1-Instruct-8B]
- test_e2e.py::test_ptp_quickstart_advanced_ngram[Llama-3.1-8B-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct]
- test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
- test_e2e.py::test_openai_chat_guided_decoding[openai/gpt-oss-120b]
- unittest/_torch/attention
- unittest/_torch/compilation
- unittest/_torch/debugger

View File

@ -110,7 +110,7 @@ l0_h100:
- test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]
- test_e2e.py::test_openai_chat_harmony
- test_e2e.py::test_openai_responses
- test_e2e.py::test_openai_chat_guided_decoding
- test_e2e.py::test_openai_chat_guided_decoding[meta-llama/Llama-3.1-8B-Instruct]
- test_e2e.py::test_trtllm_benchmark_serving[llama-3.1-model/Meta-Llama-3.1-8B]
- condition:
ranges:

View File

@ -16,17 +16,28 @@ from .openai_server import RemoteOpenAIServer
pytestmark = pytest.mark.threadleak(enabled=False)
@pytest.fixture(scope="module")
def model_name():
return "llama-3.1-model/Llama-3.1-8B-Instruct"
@pytest.fixture(
scope="module",
params=["meta-llama/Llama-3.1-8B-Instruct", "openai/gpt-oss-120b"])
def model_name(request):
return request.param
@pytest.fixture(scope="module")
def temp_extra_llm_api_options_file():
def temp_extra_llm_api_options_file(model_name: str):
temp_dir = tempfile.gettempdir()
temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
try:
extra_llm_api_options_dict = {"guided_decoding_backend": "xgrammar"}
if model_name == "openai/gpt-oss-120b":
extra_llm_api_options_dict["speculative_config"] = {
"decoding_type":
"Eagle",
"max_draft_len":
3,
"speculative_model_dir":
get_model_path("gpt_oss/gpt-oss-120b-Eagle3"),
}
with open(temp_file_path, 'w') as f:
yaml.dump(extra_llm_api_options_dict, f)
@ -39,11 +50,13 @@ def temp_extra_llm_api_options_file():
@pytest.fixture(scope="module")
def server(model_name: str, temp_extra_llm_api_options_file: str):
model_path = get_model_path(model_name)
if model_name == "meta-llama/Llama-3.1-8B-Instruct":
model_path = get_model_path("llama-3.1-model/Llama-3.1-8B-Instruct")
elif model_name == "openai/gpt-oss-120b":
model_path = get_model_path("gpt_oss/gpt-oss-120b")
# Use small max_batch_size/max_seq_len/max_num_tokens to avoid OOM on A10/A30 GPUs.
args = [
"--max_batch_size=8", "--max_seq_len=1024", "--max_num_tokens=1024",
"--max_batch_size=8", "--max_seq_len=4096", "--max_num_tokens=4096",
f"--extra_llm_api_options={temp_extra_llm_api_options_file}"
]
with RemoteOpenAIServer(model_path, args) as remote_server: