[TRTLLM-10154][feat] Enable guided decoding with reasoning parsers (#10890)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
This commit is contained in:
parent 895bb94b3d
commit be4a431ffd
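This change threads the server-side reasoning parser into guided decoding, so a schema/regex/grammar constraint is enforced on the final answer rather than on the model's reasoning text. Below is a minimal client-side sketch of the kind of request this enables, roughly what the e2e test in this commit exercises; the endpoint URL, port, and model name are assumptions, not part of this commit.

# Hedged sketch: send a guided-decoding chat request to a trtllm-serve
# OpenAI-compatible endpoint hosting a reasoning model.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")  # assumed local server
resp = client.chat.completions.create(
    model="openai/gpt-oss-120b",  # assumption: any served reasoning model
    messages=[{"role": "user", "content": "Return the capital of France as JSON."}],
    # Maps to GuidedDecodingParams(json_object=True) on the server side.
    response_format={"type": "json_object"},
)
print(resp.choices[0].message.content)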
@@ -37,6 +37,7 @@ from typing_extensions import Annotated, Required, TypeAlias, TypedDict
 from tensorrt_llm.executor.request import LoRARequest
 from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
 from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams
+from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory
 
 
 def _logit_bias_to_embedding_bias(logit_bias: Optional[Dict[str, float]],
@@ -192,40 +193,113 @@ class CompletionStreamResponse(OpenAIBaseModel):
 
 
 def _response_format_to_guided_decoding_params(
-        response_format: Optional[ResponseFormat]
+        response_format: Optional[ResponseFormat],
+        reasoning_parser: Optional[str] = None,
 ) -> Optional[GuidedDecodingParams]:
     if response_format is None:
-        return None
+        guided_decoding_params = None
     elif response_format.type == "text":
-        return None
+        guided_decoding_params = None
     elif response_format.type == "json":
         if response_format.schema is None:
             raise ValueError(
-                "The 'schema' field is required when response_format.type is 'json'."
+                f"response_format.schema is required for response_format.type == {response_format.type!r}, but got None."
             )
-        return GuidedDecodingParams(json=response_format.schema)
+        guided_decoding_params = GuidedDecodingParams(
+            json=response_format.schema)
     elif response_format.type == "json_schema":
         if response_format.json_schema is None:
             raise ValueError(
-                "The 'json_schema' field is required when response_format.type is 'json_schema'."
+                f"response_format.json_schema is required for response_format.type == {response_format.type!r}, but got None."
             )
-        return GuidedDecodingParams(json=response_format.json_schema)
+        guided_decoding_params = GuidedDecodingParams(
+            json=response_format.json_schema)
     elif response_format.type == "json_object":
-        return GuidedDecodingParams(json_object=True)
+        guided_decoding_params = GuidedDecodingParams(json_object=True)
     elif response_format.type == "regex":
-        return GuidedDecodingParams(regex=response_format.regex)
+        if response_format.regex is None:
+            raise ValueError(
+                f"response_format.regex is required for response_format.type == {response_format.type!r}, but got None."
+            )
+        guided_decoding_params = GuidedDecodingParams(
+            regex=response_format.regex)
     elif response_format.type == "ebnf":
-        return GuidedDecodingParams(grammar=response_format.ebnf)
+        if response_format.ebnf is None:
+            raise ValueError(
+                f"response_format.ebnf is required for response_format.type == {response_format.type!r}, but got None."
+            )
+        guided_decoding_params = GuidedDecodingParams(
+            grammar=response_format.ebnf)
     elif response_format.type == "structural_tag":
-        return GuidedDecodingParams(
+        guided_decoding_params = GuidedDecodingParams(
             structural_tag=response_format.model_dump_json(by_alias=True,
                                                            exclude_none=True))
     else:
         raise ValueError(f"Unsupported response format: {response_format.type}")
 
+    if guided_decoding_params is None or reasoning_parser is None:
+        return guided_decoding_params
+
+    if guided_decoding_params.structural_tag is not None:
+        return guided_decoding_params
+
+    # Adapt guided_decoding_params for reasoning parser
+    if guided_decoding_params.json is not None:
+        content = {
+            "type": "json_schema",
+            "json_schema": guided_decoding_params.json
+        }
+    elif guided_decoding_params.json_object:
+        content = {"type": "json_schema", "json_schema": {"type": "object"}}
+    elif guided_decoding_params.regex is not None:
+        content = {"type": "regex", "pattern": guided_decoding_params.regex}
+    elif guided_decoding_params.grammar is not None:
+        content = {"type": "grammar", "grammar": guided_decoding_params.grammar}
+
+    if reasoning_parser == "gpt_oss":
+        # Trigger user constraint by final channel
+        stag_format = {
+            "type":
+            "triggered_tags",
+            "triggers": ["<|start|>assistant<|channel|>final<|message|>"],
+            "tags": [
+                {
+                    "begin": "<|start|>assistant<|channel|>final<|message|>",
+                    "content": content,
+                    "end": "",
+                },
+            ],
+            "stop_after_first":
+            True,
+        }
+    else:
+        # Force thinking and then trigger user constraint
+        parser = ReasoningParserFactory.create_reasoning_parser(
+            reasoning_parser)
+        stag_format = {
+            "type":
+            "sequence",
+            "elements": [
+                {
+                    "type": "tag",
+                    "begin": parser.reasoning_start,
+                    "content": {
+                        "type": "any_text"
+                    },
+                    "end": parser.reasoning_end,
+                },
+                content,
+            ],
+        }
+
+    stag_format = ResponseFormat(type="structural_tag", format=stag_format)
+    return GuidedDecodingParams(structural_tag=stag_format.model_dump_json(
+        by_alias=True, exclude_none=True))
+
 
 def _response_format_text_config_to_guided_decoding_params(
-        text_format: Optional[ResponseFormatTextConfig]
+        text_format: Optional[ResponseFormatTextConfig],
+        reasoning_parser: Optional[str] = None,
 ) -> Optional[GuidedDecodingParams]:
     if text_format is None:
         return None
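For reference, the adapted constraint above is expressed as a structural tag for the guided-decoding backend. The following standalone sketch shows what the non-gpt_oss branch produces; the "<think>"/"</think>" markers are assumptions standing in for parser.reasoning_start and parser.reasoning_end, and the inner constraint uses the JSON-schema case as an example.

# Illustration only: the structural-tag payload built by the function above for a
# generic reasoning parser. Marker strings are assumptions; real values come from
# ReasoningParserFactory.create_reasoning_parser(reasoning_parser).
import json

content = {
    "type": "json_schema",
    "json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
}
stag_format = {
    "type": "sequence",
    "elements": [
        {
            # Free-form reasoning is forced between the parser's start/end markers.
            "type": "tag",
            "begin": "<think>",
            "content": {"type": "any_text"},
            "end": "</think>",
        },
        # The user's constraint then applies only to the final answer.
        content,
    ],
}
print(json.dumps(stag_format, indent=2))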
@@ -233,7 +307,8 @@ def _response_format_text_config_to_guided_decoding_params(
     resp_format = ResponseFormat(type=text_format.type,
                                  json_schema=getattr(text_format, "schema_",
                                                      None))
-    return _response_format_to_guided_decoding_params(resp_format)
+    return _response_format_to_guided_decoding_params(
+        resp_format, reasoning_parser=reasoning_parser)
 
 
 class CompletionRequest(OpenAIBaseModel):
@@ -650,6 +725,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     def to_sampling_params(self,
                            vocab_size: int = 32000,
                            gather_generation_logits: bool = False,
+                           reasoning_parser: Optional[str] = None,
                            backend: Optional[str] = None) -> SamplingParams:
         sampling_params = SamplingParams(
             frequency_penalty=self.frequency_penalty,
@@ -680,7 +756,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
             guided_decoding=_response_format_to_guided_decoding_params(
-                self.response_format),
+                self.response_format, reasoning_parser=reasoning_parser),
 
             # logits_bias
             embedding_bias=_logit_bias_to_embedding_bias(
@@ -810,6 +886,7 @@ class ResponsesRequest(OpenAIBaseModel):
     def to_sampling_params(
         self,
         default_sampling_params: Optional[dict] = None,
+        reasoning_parser: Optional[str] = None,
     ) -> SamplingParams:
         max_tokens = None
         if self.max_output_tokens is not None:
@@ -828,7 +905,7 @@ class ResponsesRequest(OpenAIBaseModel):
         guided_decoding = None
         if self.text is not None and self.text.format is not None:
             guided_decoding = _response_format_text_config_to_guided_decoding_params(
-                self.text.format)
+                self.text.format, reasoning_parser=reasoning_parser)
 
         return SamplingParams(
             temperature=temperature,
@@ -540,6 +540,7 @@ class OpenAIServer:
         sampling_params = request.to_sampling_params(
             vocab_size=self.tokenizer.tokenizer.vocab_size,
             gather_generation_logits=self.llm.args.gather_generation_logits,
+            reasoning_parser=self.llm.args.reasoning_parser,
             backend=self.llm.args.backend)
         postproc_args = ChatPostprocArgs.from_request(request)
         disaggregated_params = to_llm_disaggregated_params(request.disaggregated_params)
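For the standard chat path, the parser comes from the LLM arguments (self.llm.args.reasoning_parser), so enabling the feature is a deployment-side setting. A hedged sketch of an extra-options file that combines a guided-decoding backend with a reasoning parser follows; only guided_decoding_backend appears in the test config of this commit, and the "deepseek-r1" parser name is an assumption about what ReasoningParserFactory recognizes.

# Hedged sketch: writing an extra_llm_api_options file for trtllm-serve that
# enables a guided-decoding backend and a reasoning parser.
import yaml

extra_llm_api_options = {
    "guided_decoding_backend": "xgrammar",  # backend used by the e2e test below
    "reasoning_parser": "deepseek-r1",      # assumption: a name ReasoningParserFactory accepts
}
with open("extra_llm_api_options.yaml", "w") as f:
    yaml.dump(extra_llm_api_options, f)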
@@ -916,7 +917,8 @@ class OpenAIServer:
             request.stop_token_ids = harmony_stop_tokens
 
         sampling_params = request.to_sampling_params(
-            vocab_size=self.tokenizer.tokenizer.vocab_size)
+            vocab_size=self.tokenizer.tokenizer.vocab_size,
+            reasoning_parser="gpt_oss")
         sampling_params.detokenize = False  # Harmony adapter handles detokenization
 
         postproc_args = ChatCompletionPostprocArgs.from_request(request)
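The Harmony path above hardcodes reasoning_parser="gpt_oss", which selects the triggered_tags form built earlier: the constraint is deferred until the model opens its final channel. A standalone sketch is shown below; the regex content is an arbitrary example of what a request's response_format might map to, while the channel markers are copied from the change above.

# Illustration only: the structural tag used when reasoning_parser == "gpt_oss".
import json

content = {"type": "regex", "pattern": r"\d+"}
stag_format = {
    "type": "triggered_tags",
    "triggers": ["<|start|>assistant<|channel|>final<|message|>"],
    "tags": [
        {
            "begin": "<|start|>assistant<|channel|>final<|message|>",
            "content": content,
            "end": "",
        },
    ],
    # Stop constraining after the first final-channel message.
    "stop_after_first": True,
}
print(json.dumps(stag_format, indent=2))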
@@ -1018,6 +1020,7 @@ class OpenAIServer:
             tokenizer=self.tokenizer if not self.use_harmony else None,
             model_config=self.model_config if not self.use_harmony else None,
             processor=self.processor if not self.use_harmony else None,
+            reasoning_parser=self.llm.args.reasoning_parser if not self.use_harmony else "gpt_oss",
         )
 
         streaming_processor = None
@@ -867,13 +867,16 @@ async def request_preprocess(
     tokenizer: Optional[Union[TransformersTokenizer, TokenizerBase]] = None,
     model_config: Optional[PretrainedConfig] = None,
     processor: Optional[AutoProcessor] = None,
+    reasoning_parser: Optional[str] = None,
 ) -> tuple[list[int], SamplingParams]:
 
     sampling_params = request.to_sampling_params(
         default_sampling_params={
             "stop_token_ids":
             get_harmony_adapter().get_stop_tokens() if use_harmony else []
-        })
+        },
+        reasoning_parser=reasoning_parser,
+    )
 
     prev_response_id = request.previous_response_id
 
@@ -831,7 +831,7 @@
     "test_e2e.py::test_mistral_e2e[use_py_session-remove_input_padding--]": 157.39577213302255,
     "test_e2e.py::test_mistral_large_hidden_vocab_size": 81.36711680702865,
     "test_e2e.py::test_openai_chat_example": 876.1966922096908,
-    "test_e2e.py::test_openai_chat_guided_decoding": 55.12449237401597,
+    "test_e2e.py::test_openai_chat_guided_decoding[meta-llama/Llama-3.1-8B-Instruct]": 55.12449237401597,
     "test_e2e.py::test_openai_chat_harmony": 1162.7252594940364,
    "test_e2e.py::test_openai_chat_multimodal_example": 215.8254322744906,
     "test_e2e.py::test_openai_consistent_chat": 0.0001894170418381691,
@@ -1736,11 +1736,14 @@ def test_openai_mmencoder_example(llm_root, llm_venv):
                      str(test_root / "_test_openai_mmencoder.py")])
 
 
-def test_openai_chat_guided_decoding(llm_root, llm_venv):
+@pytest.mark.parametrize(
+    "model_name", ["meta-llama/Llama-3.1-8B-Instruct", "openai/gpt-oss-120b"])
+def test_openai_chat_guided_decoding(llm_root, llm_venv, model_name: str):
     test_root = unittest_path() / "llmapi" / "apps"
     llm_venv.run_cmd([
         "-m", "pytest",
-        str(test_root / "_test_openai_chat_guided_decoding.py")
+        str(test_root / "_test_openai_chat_guided_decoding.py"), "-k",
+        model_name
     ])
 
 
@@ -346,7 +346,8 @@ test_e2e.py::test_mistral_e2e[use_py_session---]
 test_e2e.py::test_qwen_e2e_cpprunner_large_new_tokens[DeepSeek-R1-Distill-Qwen-1.5B-DeepSeek-R1-Distill-Qwen-1.5B]
 test_e2e.py::test_openai_multi_chat_example
 test_e2e.py::test_openai_consistent_chat
-test_e2e.py::test_openai_chat_guided_decoding
+test_e2e.py::test_openai_chat_guided_decoding[meta-llama/Llama-3.1-8B-Instruct]
+test_e2e.py::test_openai_chat_guided_decoding[openai/gpt-oss-120b]
 test_e2e.py::test_openai_chat_harmony
 test_e2e.py::test_trtllm_benchmark_serving[llama-3.1-model/Meta-Llama-3.1-8B]
 test_e2e.py::test_trtllm_benchmark_serving[gpt_oss/gpt-oss-20b]
@@ -67,6 +67,7 @@ l0_b200:
   - test_e2e.py::test_ptp_quickstart_advanced_eagle3[Llama-3.1-8b-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct-EAGLE3-LLaMA3.1-Instruct-8B]
   - test_e2e.py::test_ptp_quickstart_advanced_ngram[Llama-3.1-8B-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct]
   - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
+  - test_e2e.py::test_openai_chat_guided_decoding[openai/gpt-oss-120b]
   - unittest/_torch/attention
   - unittest/_torch/compilation
   - unittest/_torch/debugger
@@ -110,7 +110,7 @@ l0_h100:
   - test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]
   - test_e2e.py::test_openai_chat_harmony
   - test_e2e.py::test_openai_responses
-  - test_e2e.py::test_openai_chat_guided_decoding
+  - test_e2e.py::test_openai_chat_guided_decoding[meta-llama/Llama-3.1-8B-Instruct]
   - test_e2e.py::test_trtllm_benchmark_serving[llama-3.1-model/Meta-Llama-3.1-8B]
   - condition:
       ranges:
@@ -16,17 +16,28 @@ from .openai_server import RemoteOpenAIServer
 pytestmark = pytest.mark.threadleak(enabled=False)
 
 
-@pytest.fixture(scope="module")
-def model_name():
-    return "llama-3.1-model/Llama-3.1-8B-Instruct"
+@pytest.fixture(
+    scope="module",
+    params=["meta-llama/Llama-3.1-8B-Instruct", "openai/gpt-oss-120b"])
+def model_name(request):
+    return request.param
 
 
 @pytest.fixture(scope="module")
-def temp_extra_llm_api_options_file():
+def temp_extra_llm_api_options_file(model_name: str):
     temp_dir = tempfile.gettempdir()
     temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
     try:
         extra_llm_api_options_dict = {"guided_decoding_backend": "xgrammar"}
+        if model_name == "openai/gpt-oss-120b":
+            extra_llm_api_options_dict["speculative_config"] = {
+                "decoding_type":
+                "Eagle",
+                "max_draft_len":
+                3,
+                "speculative_model_dir":
+                get_model_path("gpt_oss/gpt-oss-120b-Eagle3"),
+            }
 
         with open(temp_file_path, 'w') as f:
             yaml.dump(extra_llm_api_options_dict, f)
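For reference, the gpt-oss branch of the fixture above ends up writing a file along these lines. The model directory is resolved by get_model_path() in the real test; a placeholder path is used here.

# Illustration only: approximate contents of extra_llm_api_options.yaml for the
# gpt-oss case, using PyYAML's default block style (keys sorted alphabetically).
import yaml

extra_llm_api_options_dict = {
    "guided_decoding_backend": "xgrammar",
    "speculative_config": {
        "decoding_type": "Eagle",
        "max_draft_len": 3,
        "speculative_model_dir": "/models/gpt-oss-120b-Eagle3",  # placeholder path
    },
}
print(yaml.dump(extra_llm_api_options_dict))
# guided_decoding_backend: xgrammar
# speculative_config:
#   decoding_type: Eagle
#   max_draft_len: 3
#   speculative_model_dir: /models/gpt-oss-120b-Eagle3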
@@ -39,11 +50,13 @@ def temp_extra_llm_api_options_file():
 
 @pytest.fixture(scope="module")
 def server(model_name: str, temp_extra_llm_api_options_file: str):
-    model_path = get_model_path(model_name)
+    if model_name == "meta-llama/Llama-3.1-8B-Instruct":
+        model_path = get_model_path("llama-3.1-model/Llama-3.1-8B-Instruct")
+    elif model_name == "openai/gpt-oss-120b":
+        model_path = get_model_path("gpt_oss/gpt-oss-120b")
 
     # Use small max_batch_size/max_seq_len/max_num_tokens to avoid OOM on A10/A30 GPUs.
     args = [
-        "--max_batch_size=8", "--max_seq_len=1024", "--max_num_tokens=1024",
+        "--max_batch_size=8", "--max_seq_len=4096", "--max_num_tokens=4096",
         f"--extra_llm_api_options={temp_extra_llm_api_options_file}"
     ]
     with RemoteOpenAIServer(model_path, args) as remote_server: