[TRTLLM-8209][feat] Support new structural tag API (upgrade XGrammar to 0.1.25) (#7893)

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Enwei Zhu 2025-09-23 09:10:09 +08:00 committed by GitHub
parent d471655242
commit 8330d5363a
7 changed files with 42 additions and 62 deletions

3rdparty/xgrammar vendored

@@ -1 +1 @@
-Subproject commit 774867ce410ce1c5c7c10011cc0a9e20bd7894e2
+Subproject commit e4e816f5f0fe39f5b1601a17a4552307fa3b70ff

View File

@@ -109,34 +109,20 @@ void GuidedDecoder::build(ScheduledRequests const& scheduledRequests)
         }
         case executor::GuidedDecodingParams::GuideType::kREGEX:
         {
-            auto const& grammar = xgrammar::Grammar::FromRegex(guide.value());
-            mXGrammarMatchers.at(seqSlot)
-                = std::make_shared<xgrammar::GrammarMatcher>(mXGrammarCompiler->CompileGrammar(grammar));
+            mXGrammarMatchers.at(seqSlot) = std::make_shared<xgrammar::GrammarMatcher>(
+                mXGrammarCompiler->CompileRegex(guide.value()));
             break;
         }
         case executor::GuidedDecodingParams::GuideType::kEBNF_GRAMMAR:
         {
-            auto const& grammar = xgrammar::Grammar::FromEBNF(guide.value());
-            mXGrammarMatchers.at(seqSlot)
-                = std::make_shared<xgrammar::GrammarMatcher>(mXGrammarCompiler->CompileGrammar(grammar));
+            mXGrammarMatchers.at(seqSlot) = std::make_shared<xgrammar::GrammarMatcher>(
+                mXGrammarCompiler->CompileGrammar(guide.value()));
             break;
         }
         case executor::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG:
        {
-            auto const& structuralTagParametersJson = nlohmann::json::parse(guide.value());
-            auto const& structuralTagItemsJson
-                = structuralTagParametersJson.at("structures").template get<std::vector<nlohmann::json>>();
-            std::vector<xgrammar::StructuralTagItem> structuralTagItems;
-            for (auto const& s : structuralTagItemsJson)
-            {
-                structuralTagItems.emplace_back(
-                    xgrammar::StructuralTagItem{s.at("begin").template get<std::string>(),
-                        s.at("schema").dump(), s.at("end").template get<std::string>()});
-            }
-            auto const& triggers
-                = structuralTagParametersJson.at("triggers").template get<std::vector<std::string>>();
             mXGrammarMatchers.at(seqSlot) = std::make_shared<xgrammar::GrammarMatcher>(
-                mXGrammarCompiler->CompileStructuralTag(structuralTagItems, triggers));
+                mXGrammarCompiler->CompileStructuralTag(guide.value()));
             break;
         }
         default:
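
With XGrammar 0.1.25, CompileStructuralTag receives the guide string unchanged, so the guide must already be the library's structural-tag JSON rather than the old {"structures", "triggers"} shape that was hand-parsed with nlohmann::json above. Below is a minimal sketch of such a guide string, mirroring the "triggered_tags" payload used by the updated OpenAI test at the end of this diff; how the serving layer assembles the string is outside this hunk, so treat the top-level wrapper and the example schema as assumptions.

import json

# Sketch of a structural-tag guide in the new XGrammar 0.1.25 JSON format.
# Field names follow the updated OpenAI-protocol test below; the tool name
# and parameter schema are purely illustrative.
structural_tag_guide = json.dumps({
    "type": "structural_tag",
    "format": {
        "type": "triggered_tags",
        "triggers": ["<function="],
        "tags": [{
            "begin": "<function=get_current_weather>",
            "content": {
                "type": "json_schema",
                "json_schema": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
            "end": "</function>",
        }],
    },
})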

View File

@@ -60,7 +60,7 @@ patchelf==0.18.0
 einops
 flashinfer-python>=0.3.0
 opencv-python-headless
-xgrammar==0.1.21
+xgrammar==0.1.25
 llguidance==0.7.29
 jsonschema
 backoff

View File

@@ -1,4 +1,3 @@
-import json
 import os
 from abc import ABC, abstractmethod
@@ -102,24 +101,13 @@ class XGrammarMatcherFactory(GrammarMatcherFactory):
                 compiled_grammar = self._xgrammar_compiler.compile_json_schema(
                     guide)
             case GuidedDecodingParams.GuideType.REGEX:
-                grammar = xgrammar.Grammar.from_regex(guide)
-                compiled_grammar = self._xgrammar_compiler.compile_grammar(
-                    grammar)
+                compiled_grammar = self._xgrammar_compiler.compile_regex(guide)
             case GuidedDecodingParams.GuideType.EBNF_GRAMMAR:
-                grammar = xgrammar.Grammar.from_ebnf(guide)
                 compiled_grammar = self._xgrammar_compiler.compile_grammar(
-                    grammar)
+                    guide)
             case GuidedDecodingParams.GuideType.STRUCTURAL_TAG:
-                structural_tag_parameters = json.loads(guide)
-                structures = structural_tag_parameters["structures"]
-                structures = [
-                    xgrammar.StructuralTagItem(begin=s["begin"],
-                                               schema=json.dumps(s["schema"]),
-                                               end=s["end"]) for s in structures
-                ]
-                triggers = structural_tag_parameters["triggers"]
                 compiled_grammar = self._xgrammar_compiler.compile_structural_tag(
-                    structures, triggers)
+                    guide)
             case _:
                 raise ValueError(f"Unsupported guide type: {guide_type}.")
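
The Python factory mirrors the C++ change: REGEX guides go through compile_regex, EBNF text is passed straight to compile_grammar, and structural-tag guides are forwarded as the raw JSON string, so the json import and the StructuralTagItem construction are gone. A rough standalone sketch of these compiler entry points follows; the GrammarCompiler/TokenizerInfo setup follows xgrammar's usual pattern and, like the model name, is an assumption rather than part of this diff.

import xgrammar
from transformers import AutoTokenizer

# Assumed setup: any HuggingFace tokenizer works; the model name is illustrative.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer_info = xgrammar.TokenizerInfo.from_huggingface(tokenizer)
compiler = xgrammar.GrammarCompiler(tokenizer_info)

# JSON-schema guides (unchanged by this PR).
json_grammar = compiler.compile_json_schema('{"type": "object"}')

# REGEX guides: compiled directly, no intermediate xgrammar.Grammar object.
regex_grammar = compiler.compile_regex(r"\d{4}-\d{2}-\d{2}")

# EBNF guides: the grammar text is passed as-is.
ebnf_grammar = compiler.compile_grammar('root ::= "yes" | "no"')

# Structural-tag guides: compiler.compile_structural_tag(guide) with a JSON
# string like the one sketched after the C++ hunk above.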

View File

@@ -6,6 +6,7 @@ import uuid
 from typing import Any, Dict, List, Literal, Optional, Union

 import torch
+import xgrammar
 from openai.types.chat import ChatCompletionAssistantMessageParam
 from openai.types.chat import \
     ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
@@ -85,18 +86,11 @@ class ModelList(OpenAIBaseModel):
     data: List[ModelCard] = Field(default_factory=list)


-class StructuralTag(OpenAIBaseModel):
-    begin: str
-    schema_: Optional[dict[str, Any]] = Field(alias="schema")
-    end: str


 class ResponseFormat(OpenAIBaseModel):
     # type must be one of "text", "json", "json_object", or "structural_tag"
     type: Literal["text", "json", "json_object", "structural_tag"]
     schema: Optional[dict] = None
-    structures: Optional[List[StructuralTag]] = None
-    triggers: Optional[List[str]] = None
+    format: Optional[xgrammar.structural_tag.Format] = None


 class DisaggregatedParams(OpenAIBaseModel):
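
With the local StructuralTag model removed, the "format" body of a structural_tag request is validated by xgrammar's own pydantic definitions, so requests must use the nested shape instead of top-level "structures"/"triggers". Below is a hedged sketch of that validation using pydantic's TypeAdapter; it assumes xgrammar.structural_tag.Format is a pydantic-compatible type, which its use as a ResponseFormat field implies.

import xgrammar
from pydantic import TypeAdapter, ValidationError

format_adapter = TypeAdapter(xgrammar.structural_tag.Format)

# Accepted: the new nested "triggered_tags" shape (the schema is illustrative).
format_adapter.validate_python({
    "type": "triggered_tags",
    "triggers": ["<tool>"],
    "tags": [{
        "begin": "<tool>",
        "content": {"type": "json_schema", "json_schema": {"type": "object"}},
        "end": "</tool>",
    }],
})

# Rejected: the old flat shape no longer validates.
try:
    format_adapter.validate_python({"structures": [], "triggers": ["<tool>"]})
except ValidationError as err:
    print(f"old-style payload rejected: {err}")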

View File

@@ -21,7 +21,7 @@ def test_logits_bitmask(batch_size: int, vocab_size: int, stride: int,
                               size=(batch_size, vocab_size),
                               dtype=torch.bool,
                               device="cuda")
-    bitmask = xgrammar.testing._bool_mask_to_bitmask(bool_mask)
+    bitmask = xgrammar.testing.bool_mask_to_bitmask(bool_mask)
     token_mask = None
     if stride > 1:
         token_mask = torch.arange(batch_size, dtype=torch.int32,
@@ -53,7 +53,7 @@ def test_logits_bitmask_with_d2t(batch_size: int, vocab_size: int, stride: int,
                               size=(batch_size, vocab_size),
                               dtype=torch.bool,
                               device="cuda")
-    bitmask = xgrammar.testing._bool_mask_to_bitmask(bool_mask)
+    bitmask = xgrammar.testing.bool_mask_to_bitmask(bool_mask)
     token_mask = None
     if stride > 1:
         token_mask = torch.arange(batch_size, dtype=torch.int32,
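
These tests switch to the public bool_mask_to_bitmask helper that replaces the underscore-prefixed name in xgrammar 0.1.25. The helper packs a (batch, vocab) boolean allow-mask into the int32 bitmask layout used by the token-masking kernels, i.e. ceil(vocab / 32) words per row. A small usage sketch on CPU tensors follows; apply_token_bitmask_inplace is xgrammar's helper for applying the mask and is an assumption here, not part of this diff.

import math

import torch
import xgrammar
from xgrammar.testing import bool_mask_to_bitmask

batch_size, vocab_size = 2, 100
bool_mask = torch.rand(batch_size, vocab_size) < 0.5  # True = token allowed

bitmask = bool_mask_to_bitmask(bool_mask)
assert bitmask.dtype == torch.int32
assert bitmask.shape == (batch_size, math.ceil(vocab_size / 32))

# Disallowed tokens get their logits pushed to -inf before sampling.
logits = torch.randn(batch_size, vocab_size)
xgrammar.apply_token_bitmask_inplace(logits, bitmask)
assert torch.isinf(logits[~bool_mask]).all()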

View File

@@ -162,22 +162,34 @@ You are a helpful assistant."""
         messages=messages,
         max_completion_tokens=256,
         response_format={
-            "type":
-            "structural_tag",
-            "structures": [
-                {
-                    "begin": "<function=get_current_weather>",
-                    "schema":
-                    tool_get_current_weather["function"]["parameters"],
-                    "end": "</function>",
-                },
-                {
-                    "begin": "<function=get_current_date>",
-                    "schema": tool_get_current_date["function"]["parameters"],
-                    "end": "</function>",
-                },
-            ],
-            "triggers": ["<function="],
+            "type": "structural_tag",
+            "format": {
+                "type":
+                "triggered_tags",
+                "triggers": ["<function="],
+                "tags": [
+                    {
+                        "begin": "<function=get_current_weather>",
+                        "content": {
+                            "type":
+                            "json_schema",
+                            "json_schema":
+                            tool_get_current_weather["function"]["parameters"]
+                        },
+                        "end": "</function>",
+                    },
+                    {
+                        "begin": "<function=get_current_date>",
+                        "content": {
+                            "type":
+                            "json_schema",
+                            "json_schema":
+                            tool_get_current_date["function"]["parameters"]
+                        },
+                        "end": "</function>",
+                    },
+                ],
+            },
         },
     )
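
For clients migrating existing requests, the mapping from the old payload to the new one is mechanical: each entry of "structures" becomes a tag whose schema moves under a "content" block of type "json_schema", and "triggers" moves inside a "format" object of type "triggered_tags". A hedged helper sketch follows; the function is illustrative and not part of this PR.

from typing import Any, Dict


def convert_legacy_structural_tag(old: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a pre-0.1.25 {"structures", "triggers"} response_format
    into the new {"format": {"type": "triggered_tags", ...}} shape."""
    return {
        "type": "structural_tag",
        "format": {
            "type": "triggered_tags",
            "triggers": old["triggers"],
            "tags": [{
                "begin": s["begin"],
                "content": {"type": "json_schema", "json_schema": s["schema"]},
                "end": s["end"],
            } for s in old["structures"]],
        },
    }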