Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[TRTLLM-8209][feat] Support new structural tag API (upgrade XGrammar to 0.1.25) (#7893)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
parent d471655242
commit 8330d5363a
3rdparty/xgrammar (vendored)
@@ -1 +1 @@
-Subproject commit 774867ce410ce1c5c7c10011cc0a9e20bd7894e2
+Subproject commit e4e816f5f0fe39f5b1601a17a4552307fa3b70ff
@@ -109,34 +109,20 @@ void GuidedDecoder::build(ScheduledRequests const& scheduledRequests)
         }
         case executor::GuidedDecodingParams::GuideType::kREGEX:
         {
-            auto const& grammar = xgrammar::Grammar::FromRegex(guide.value());
-            mXGrammarMatchers.at(seqSlot)
-                = std::make_shared<xgrammar::GrammarMatcher>(mXGrammarCompiler->CompileGrammar(grammar));
+            mXGrammarMatchers.at(seqSlot) = std::make_shared<xgrammar::GrammarMatcher>(
+                mXGrammarCompiler->CompileRegex(guide.value()));
             break;
         }
         case executor::GuidedDecodingParams::GuideType::kEBNF_GRAMMAR:
         {
-            auto const& grammar = xgrammar::Grammar::FromEBNF(guide.value());
-            mXGrammarMatchers.at(seqSlot)
-                = std::make_shared<xgrammar::GrammarMatcher>(mXGrammarCompiler->CompileGrammar(grammar));
+            mXGrammarMatchers.at(seqSlot) = std::make_shared<xgrammar::GrammarMatcher>(
+                mXGrammarCompiler->CompileGrammar(guide.value()));
             break;
         }
         case executor::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG:
         {
-            auto const& structuralTagParametersJson = nlohmann::json::parse(guide.value());
-            auto const& structuralTagItemsJson
-                = structuralTagParametersJson.at("structures").template get<std::vector<nlohmann::json>>();
-            std::vector<xgrammar::StructuralTagItem> structuralTagItems;
-            for (auto const& s : structuralTagItemsJson)
-            {
-                structuralTagItems.emplace_back(
-                    xgrammar::StructuralTagItem{s.at("begin").template get<std::string>(),
-                        s.at("schema").dump(), s.at("end").template get<std::string>()});
-            }
-            auto const& triggers
-                = structuralTagParametersJson.at("triggers").template get<std::vector<std::string>>();
             mXGrammarMatchers.at(seqSlot) = std::make_shared<xgrammar::GrammarMatcher>(
-                mXGrammarCompiler->CompileStructuralTag(structuralTagItems, triggers));
+                mXGrammarCompiler->CompileStructuralTag(guide.value()));
             break;
         }
         default:
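For orientation, the new one-call structural-tag path looks roughly like the Python sketch below. This is a minimal sketch, not code from the patch: the model name is a placeholder, and the exact guide JSON that TensorRT-LLM forwards to CompileStructuralTag is an assumption here, modeled on the triggered_tags shape used by the updated serving test at the end of this diff.

import json

import xgrammar
from transformers import AutoTokenizer

# Illustrative setup: a GrammarCompiler built from a HuggingFace tokenizer.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
compiler = xgrammar.GrammarCompiler(xgrammar.TokenizerInfo.from_huggingface(tokenizer))

# Assumed guide payload in the new structural-tag style (triggered_tags),
# mirroring the updated serving test below; the schema is shortened for brevity.
guide = json.dumps({
    "type": "structural_tag",
    "format": {
        "type": "triggered_tags",
        "triggers": ["<function="],
        "tags": [{
            "begin": "<function=get_current_weather>",
            "content": {"type": "json_schema", "json_schema": {"type": "object"}},
            "end": "</function>",
        }],
    },
})

# XGrammar 0.1.25 compiles the raw JSON string directly; the manual
# StructuralTagItem/triggers unpacking removed above is no longer needed.
matcher = xgrammar.GrammarMatcher(compiler.compile_structural_tag(guide))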
@@ -60,7 +60,7 @@ patchelf==0.18.0
 einops
 flashinfer-python>=0.3.0
 opencv-python-headless
-xgrammar==0.1.21
+xgrammar==0.1.25
 llguidance==0.7.29
 jsonschema
 backoff
@@ -1,4 +1,3 @@
-import json
 import os
 from abc import ABC, abstractmethod
@@ -102,24 +101,13 @@ class XGrammarMatcherFactory(GrammarMatcherFactory):
                 compiled_grammar = self._xgrammar_compiler.compile_json_schema(
                     guide)
             case GuidedDecodingParams.GuideType.REGEX:
-                grammar = xgrammar.Grammar.from_regex(guide)
-                compiled_grammar = self._xgrammar_compiler.compile_grammar(
-                    grammar)
+                compiled_grammar = self._xgrammar_compiler.compile_regex(guide)
             case GuidedDecodingParams.GuideType.EBNF_GRAMMAR:
-                grammar = xgrammar.Grammar.from_ebnf(guide)
                 compiled_grammar = self._xgrammar_compiler.compile_grammar(
-                    grammar)
+                    guide)
             case GuidedDecodingParams.GuideType.STRUCTURAL_TAG:
-                structural_tag_parameters = json.loads(guide)
-                structures = structural_tag_parameters["structures"]
-                structures = [
-                    xgrammar.StructuralTagItem(begin=s["begin"],
-                                               schema=json.dumps(s["schema"]),
-                                               end=s["end"]) for s in structures
-                ]
-                triggers = structural_tag_parameters["triggers"]
                 compiled_grammar = self._xgrammar_compiler.compile_structural_tag(
-                    structures, triggers)
+                    guide)
             case _:
                 raise ValueError(f"Unsupported guide type: {guide_type}.")
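As a reference for the simplified dispatch, a self-contained sketch of how the compiler entry points are now driven per guide type is shown below. This is a minimal sketch under assumptions: the model name, the compile_guide helper, and the example guide strings are illustrative and not part of the patch.

import xgrammar
from transformers import AutoTokenizer

# Illustrative setup, mirroring what XGrammarMatcherFactory holds internally.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
tokenizer_info = xgrammar.TokenizerInfo.from_huggingface(tokenizer)
compiler = xgrammar.GrammarCompiler(tokenizer_info)


def compile_guide(guide_type: str, guide: str) -> xgrammar.CompiledGrammar:
    # Every guide is now a plain string; no intermediate xgrammar.Grammar objects.
    match guide_type:
        case "json_schema":
            return compiler.compile_json_schema(guide)
        case "regex":
            return compiler.compile_regex(guide)
        case "ebnf":
            return compiler.compile_grammar(guide)
        case "structural_tag":
            return compiler.compile_structural_tag(guide)
        case _:
            raise ValueError(f"Unsupported guide type: {guide_type}.")


matcher = xgrammar.GrammarMatcher(compile_guide("regex", r"\d{4}-\d{2}-\d{2}"))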
@@ -6,6 +6,7 @@ import uuid
 from typing import Any, Dict, List, Literal, Optional, Union
 
 import torch
+import xgrammar
 from openai.types.chat import ChatCompletionAssistantMessageParam
 from openai.types.chat import \
     ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
@@ -85,18 +86,11 @@ class ModelList(OpenAIBaseModel):
     data: List[ModelCard] = Field(default_factory=list)
 
 
-class StructuralTag(OpenAIBaseModel):
-    begin: str
-    schema_: Optional[dict[str, Any]] = Field(alias="schema")
-    end: str
-
-
 class ResponseFormat(OpenAIBaseModel):
     # type must be one of "text", "json", "json_object", or "structural_tag"
     type: Literal["text", "json", "json_object", "structural_tag"]
     schema: Optional[dict] = None
-    structures: Optional[List[StructuralTag]] = None
-    triggers: Optional[List[str]] = None
+    format: Optional[xgrammar.structural_tag.Format] = None
 
 
 class DisaggregatedParams(OpenAIBaseModel):
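Since the protocol now types the field with xgrammar's own structural-tag models, a request-side format dict can be checked against that type before it reaches the server. A minimal sketch, assuming xgrammar.structural_tag.Format is a pydantic-validatable union (which its use as a ResponseFormat field above suggests); the tag and schema are illustrative.

import xgrammar
from pydantic import TypeAdapter

# Validate a triggered_tags format dict against xgrammar's Format type.
format_adapter = TypeAdapter(xgrammar.structural_tag.Format)
fmt = format_adapter.validate_python({
    "type": "triggered_tags",
    "triggers": ["<function="],
    "tags": [{
        "begin": "<function=get_current_weather>",
        "content": {"type": "json_schema", "json_schema": {"type": "object"}},
        "end": "</function>",
    }],
})
print(type(fmt).__name__)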
@@ -21,7 +21,7 @@ def test_logits_bitmask(batch_size: int, vocab_size: int, stride: int,
                               size=(batch_size, vocab_size),
                               dtype=torch.bool,
                               device="cuda")
-    bitmask = xgrammar.testing._bool_mask_to_bitmask(bool_mask)
+    bitmask = xgrammar.testing.bool_mask_to_bitmask(bool_mask)
     token_mask = None
     if stride > 1:
         token_mask = torch.arange(batch_size, dtype=torch.int32,
@@ -53,7 +53,7 @@ def test_logits_bitmask_with_d2t(batch_size: int, vocab_size: int, stride: int,
                               size=(batch_size, vocab_size),
                               dtype=torch.bool,
                               device="cuda")
-    bitmask = xgrammar.testing._bool_mask_to_bitmask(bool_mask)
+    bitmask = xgrammar.testing.bool_mask_to_bitmask(bool_mask)
     token_mask = None
     if stride > 1:
         token_mask = torch.arange(batch_size, dtype=torch.int32,
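The renamed helper is now used as a public part of xgrammar.testing in 0.1.25. For context, a minimal usage sketch is shown below; it assumes a CUDA device is available and the batch/vocab shapes are illustrative.

import torch
import xgrammar

# Build a boolean allow-mask and pack it into the bitmask layout xgrammar expects.
bool_mask = torch.randint(0, 2, size=(2, 32000), dtype=torch.bool, device="cuda")
bitmask = xgrammar.testing.bool_mask_to_bitmask(bool_mask)

# Disallowed tokens get their logits masked out in place before sampling.
logits = torch.randn(2, 32000, device="cuda")
xgrammar.apply_token_bitmask_inplace(logits, bitmask)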
@@ -162,22 +162,34 @@ You are a helpful assistant."""
         messages=messages,
         max_completion_tokens=256,
         response_format={
-            "type":
-            "structural_tag",
-            "structures": [
-                {
-                    "begin": "<function=get_current_weather>",
-                    "schema":
-                    tool_get_current_weather["function"]["parameters"],
-                    "end": "</function>",
-                },
-                {
-                    "begin": "<function=get_current_date>",
-                    "schema": tool_get_current_date["function"]["parameters"],
-                    "end": "</function>",
-                },
-            ],
-            "triggers": ["<function="],
+            "type": "structural_tag",
+            "format": {
+                "type":
+                "triggered_tags",
+                "triggers": ["<function="],
+                "tags": [
+                    {
+                        "begin": "<function=get_current_weather>",
+                        "content": {
+                            "type":
+                            "json_schema",
+                            "json_schema":
+                            tool_get_current_weather["function"]["parameters"]
+                        },
+                        "end": "</function>",
+                    },
+                    {
+                        "begin": "<function=get_current_date>",
+                        "content": {
+                            "type":
+                            "json_schema",
+                            "json_schema":
+                            tool_get_current_date["function"]["parameters"]
+                        },
+                        "end": "</function>",
+                    },
+                ],
+            },
         },
     )