test: Add LLGuidance test and refine guided decoding (#5348)

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Enwei Zhu 2025-06-25 14:12:56 +08:00 committed by GitHub
parent 76da7fed86
commit fc7a81ceb0
16 changed files with 51 additions and 26 deletions

View File

@ -35,6 +35,8 @@ GuidedDecoder::GuidedDecoder(executor::GuidedDecodingConfig const& guidedDecodin
, mLogitsDtype{logitsDtype}
, mCopyBufferManager{std::make_shared<CudaStream>()}
{
TLLM_CHECK_WITH_INFO(mGuidedDecodingBackend != executor::GuidedDecodingConfig::GuidedDecodingBackend::kLLGUIDANCE,
"LLGuidance is not supported for guided decoding in C++ runtime.");
if (mGuidedDecodingBackend == executor::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR)
{
mXGrammarMatchers.resize(mMaxNumSequences);

View File

@ -1621,6 +1621,9 @@ std::tuple<Executor::Impl::RequestList, double> Executor::Impl::fetchNewRequests
TLLM_CHECK_WITH_INFO(mModel->hasGuidedDecoder(),
"Request is specified with GuidedDecodingParams, but GuidedDecoder is not setup. Please "
"provide a valid GuidedDecodingConfig to setup GuidedDecoder.");
TLLM_CHECK_WITH_INFO(newReq->getGuidedDecodingParams()->getGuideType()
!= executor::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG,
"Structural tag is not supported for guided decoding in C++ Executor.");
}
if (mModel->getWorldConfig().isLastPipelineParallelRank() && newReq->hasAdditionalOutputs())

View File

@ -10,8 +10,7 @@ We provide a CLI tool `trtllm-eval` for evaluating model accuracy. It shares the
pip install -r requirements.txt
# Evaluate Llama-3.1-8B-Instruct on MMLU
wget https://people.eecs.berkeley.edu/~hendrycks/data.tar && tar -xf data.tar
trtllm-eval --model meta-llama/Llama-3.1-8B-Instruct mmlu --dataset_path data
trtllm-eval --model meta-llama/Llama-3.1-8B-Instruct mmlu
# Evaluate Llama-3.1-8B-Instruct on GSM8K
trtllm-eval --model meta-llama/Llama-3.1-8B-Instruct gsm8k

View File

@ -96,7 +96,7 @@ class XGrammarMatcherFactory(GrammarMatcherFactory):
compiled_grammar = self._xgrammar_compiler.compile_structural_tag(
structures, triggers)
case _:
raise ValueError(f"Unrecognized guide type: {guide_type}.")
raise ValueError(f"Unsupported guide type: {guide_type}.")
matcher = xgrammar.GrammarMatcher(compiled_grammar)
return XGrammarMatcher(matcher)
@ -167,7 +167,7 @@ class LLGuidanceMatcherFactory(GrammarMatcherFactory):
# provide Lark-formatted grammar instead of standard EBNF.
grammar = llguidance.LLMatcher.grammar_from_lark(guide)
case _:
raise ValueError(f"Unrecognized guide type: {guide_type}.")
raise ValueError(f"Unsupported guide type: {guide_type}.")
matcher = llguidance.LLMatcher(self._tokenizer, grammar)
if matcher.is_error():
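The comment above is the user-visible catch with the llguidance backend: a `grammar` guide must be written in Lark syntax, whereas xgrammar consumes GBNF-style EBNF. A minimal sketch of the same trivial constraint in both formats, passed through the same GuidedDecodingParams field; the grammar strings and import path are illustrative assumptions, not taken from this commit.

```python
from tensorrt_llm.sampling_params import GuidedDecodingParams

# xgrammar consumes a GBNF-style EBNF grammar rooted at `root`.
ebnf_yes_no = 'root ::= "yes" | "no"'

# llguidance consumes a Lark grammar rooted at `start`
# (see LLMatcher.grammar_from_lark above).
lark_yes_no = 'start: "yes" | "no"'

# Same field either way; which string is valid depends on the backend
# the LLM was constructed with.
params_for_xgrammar = GuidedDecodingParams(grammar=ebnf_yes_no)
params_for_llguidance = GuidedDecodingParams(grammar=lark_yes_no)
```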

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from typing import Iterable, List, Optional, Union
import click
@ -56,8 +57,13 @@ class JsonModeEval(Evaluator):
for i, sample in enumerate(self.data):
if i >= self.num_samples:
break
schema = sample["schema"]
if os.environ.get("TRTLLM_XGUIDANCE_LENIENT") == "1":
schema = json.loads(schema)
schema["x-guidance"] = {"lenient": True}
schema = json.dumps(schema)
sampling_args = {
"guided_decoding": GuidedDecodingParams(json=sample["schema"])
"guided_decoding": GuidedDecodingParams(json=schema)
}
yield sample["prompt"], sampling_args, sample["completion"]
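The env-var branch above only rewrites the schema string before it is handed to GuidedDecodingParams. A minimal stand-alone sketch of that transformation, with a made-up schema, showing what the lenient hint looks like once injected:

```python
import json

# Made-up example of a dataset schema string (sample["schema"] above).
schema = json.dumps({
    "type": "object",
    "properties": {"name": {"type": "string"}},
    "required": ["name"],
})

# Mirrors the TRTLLM_XGUIDANCE_LENIENT == "1" branch: parse the schema,
# attach the x-guidance lenient hint, and re-serialize it.
obj = json.loads(schema)
obj["x-guidance"] = {"lenient": True}
schema = json.dumps(obj)

print(schema)
# ... "required": ["name"], "x-guidance": {"lenient": true}}
```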

View File

@ -879,8 +879,11 @@ class BaseLlmArgs(BaseModel):
enable_chunked_prefill: bool = Field(default=False,
description="Enable chunked prefill.")
guided_decoding_backend: Optional[str] = Field(
default=None, description="Guided decoding backend.")
guided_decoding_backend: Optional[Literal["xgrammar", "llguidance"]] = Field(
default=None,
description=
"Guided decoding backend. llguidance is supported in PyTorch backend only."
)
batched_logits_processor: Optional[object] = Field(
default=None,
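Narrowing the field to a Literal makes backend selection an explicit, validated choice at construction time. A minimal sketch, assuming a placeholder model path and the LLM class imported as in the tests below; per the new description, "llguidance" is only valid together with the PyTorch backend.

```python
from tensorrt_llm import LLM

# Placeholder model path (any supported HF checkpoint or local directory).
model = "meta-llama/Llama-3.1-8B-Instruct"

# xgrammar is accepted by both backends.
llm_xgrammar = LLM(model, guided_decoding_backend="xgrammar")

# llguidance is accepted by the Literal, but only the PyTorch backend
# supports it; the C++ runtime rejects it via the GuidedDecoder check above.
llm_llguidance = LLM(model, guided_decoding_backend="llguidance")

# Anything else now fails Pydantic validation instead of being deferred:
# LLM(model, guided_decoding_backend="outlines")  # -> validation error
```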

View File

@ -19,7 +19,7 @@ class GuidedDecodingParams:
regex (str, optional): The generated text is amenable to the user-specified regular expression. Defaults to None.
grammar (str, optional): The generated text is amenable to the user-specified extended Backus-Naur form (EBNF) grammar. Defaults to None.
json_object (bool): If True, the generated text is amenable to json format. Defaults to False.
structural_tag (str, optional): The generated text is amenable to the user-specified structural tag. Defaults to None.
structural_tag (str, optional): The generated text is amenable to the user-specified structural tag. Structural tag is supported by xgrammar in PyTorch backend only. Defaults to None.
""" # noqa: E501
json: Optional[Union[str, BaseModel, dict]] = None
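The docstring enumerates the alternative guide types; one is meant to be set per request. A minimal sketch of the most common one, a JSON-schema guide, wired into sampling the same way the JSON-mode evaluator above builds its sampling_args; the schema, max_tokens value, and SamplingParams wrapper are illustrative assumptions.

```python
from tensorrt_llm.sampling_params import GuidedDecodingParams, SamplingParams

# Illustrative schema: constrain the reply to {"answer": "<string>"}.
answer_schema = (
    '{"type": "object",'
    ' "properties": {"answer": {"type": "string"}},'
    ' "required": ["answer"]}'
)

# Pick one of json / regex / grammar / json_object / structural_tag.
guided = GuidedDecodingParams(json=answer_schema)
sampling_params = SamplingParams(max_tokens=64, guided_decoding=guided)

# With an LLM instance configured with a guided decoding backend:
# outputs = llm.generate(["Reply in JSON: what is 2 + 2?"], sampling_params)
```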

View File

@ -58,16 +58,18 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
def test_guided_decoding(self):
llm = LLM(self.MODEL_PATH, guided_decoding_backend="xgrammar")
@pytest.mark.parametrize("backend", ["xgrammar"])
def test_guided_decoding(self, backend: str):
llm = LLM(self.MODEL_PATH, guided_decoding_backend=backend)
with llm:
task = JsonModeEval(self.MODEL_NAME)
task.evaluate(llm)
@pytest.mark.skip_less_device(4)
def test_guided_decoding_4gpus(self):
@pytest.mark.parametrize("backend", ["xgrammar"])
def test_guided_decoding_4gpus(self, backend: str):
llm = LLM(self.MODEL_PATH,
guided_decoding_backend="xgrammar",
guided_decoding_backend=backend,
tensor_parallel_size=2,
pipeline_parallel_size=2)
with llm:

View File

@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
from tensorrt_llm import LLM
@ -277,9 +279,11 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
def test_guided_decoding(self):
@pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
def test_guided_decoding(self, backend: str, mocker):
mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
llm = LLM(self.MODEL_PATH,
guided_decoding_backend="xgrammar",
guided_decoding_backend=backend,
disable_overlap_scheduler=True,
use_cuda_graph=True)
with llm:
@ -287,9 +291,11 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
task.evaluate(llm)
@pytest.mark.skip_less_device(4)
def test_guided_decoding_4gpus(self):
@pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
def test_guided_decoding_4gpus(self, backend: str, mocker):
mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
llm = LLM(self.MODEL_PATH,
guided_decoding_backend="xgrammar",
guided_decoding_backend=backend,
disable_overlap_scheduler=True,
use_cuda_graph=True,
tensor_parallel_size=2,

View File

@ -418,8 +418,8 @@ accuracy/test_llm_api.py::TestQwen2_7BInstruct::test_weight_only
accuracy/test_cli_flow.py::TestQwen2_7BInstruct::test_int4_awq_prequantized
accuracy/test_cli_flow.py::TestQwen2_57B_A14B::test_tp4
accuracy/test_cli_flow.py::TestQwen2_57B_A14B::test_tp2pp2
accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding
accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus
accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar]
accuracy/test_llm_api.py::TestQwen2_5_1_5BInstruct::test_auto_dtype
accuracy/test_llm_api.py::TestQwen2_5_1_5BInstruct::test_weight_only
accuracy/test_llm_api.py::TestLlama3_1_8B::test_fp8_rowwise
@ -440,8 +440,10 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_llm_sampler
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_ngram
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar]
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[llguidance]
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
accuracy/test_cli_flow.py::TestLlama3_3_70BInstruct::test_fp8_prequantized_tp4

View File

@ -46,7 +46,7 @@ l0_a100:
- accuracy/test_cli_flow.py::TestLlama3_2_1B::test_smooth_quant_ootb_manage_weights
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_int8_gptq
- accuracy/test_cli_flow.py::TestQwen2_7BInstruct::test_int4_awq_prequantized
- accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding
- accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
- condition:
ranges:
system_gpu_count:

View File

@ -31,7 +31,7 @@ l0_dgx_h100:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv=True-attn_backend=TRTLLM-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar]
- disaggregated/test_disaggregated.py::test_disaggregated_multi_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_multi_gpu_with_mpirun_trt_backend[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0]

View File

@ -92,6 +92,7 @@ l0_dgx_h200:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv=True-attn_backend=FLASHINFER-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[llguidance]
- condition:
ranges:
system_gpu_count:
@ -120,7 +121,7 @@ l0_dgx_h200:
- accuracy/test_llm_api.py::TestQwen2_7BInstruct::test_tp2
- accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_cp2
- accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_tp2cp2
- accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus
- accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar]
- examples/test_llama.py::test_llm_llama_long_alpaca_8gpu_summary[pg64317-tp4pp2-nb:4]
- examples/test_llama.py::test_llm_llama_v2_lora_benchmark_2gpu[chinese_lora-llama-v2-13b-hf]
- examples/test_mixtral.py::test_llm_mixtral_moe_plugin_lora_4gpus[Mixtral-8x7B-v0.1-chinese-mixtral-lora]

View File

@ -29,7 +29,7 @@ l0_h100:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=FLASHINFER-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=TRTLLM] TIMEOUT (90)
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=FLASHINFER] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=False-attn_backend=TRTLLM-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=False-attn_backend=FLASHINFER-torch_compile=False]
@ -182,6 +182,7 @@ l0_h100:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=none-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
- condition:
ranges:
system_gpu_count:

View File

@ -433,7 +433,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backe
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5349343)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5349343)
full:B200/test_e2e.py::test_ptp_quickstart_advanced_deepseek_multi_nodes[DeepSeek-R1/DeepSeek-R1-0528-FP4] SKIP (https://nvbugs/5344688)
accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus SKIP (https://nvbugs/5346443)
accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar] SKIP (https://nvbugs/5346443)
test_e2e.py::test_openai_reasoning SKIP (https://nvbugs/5355091)
test_e2e.py::test_openai_misc_example SKIP (https://nvbugs/5355091)
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=False] SKIP (https://nvbugs/5354956)

View File

@ -62,7 +62,7 @@ methods:
annotation: Optional[tensorrt_llm.sampling_params.BatchedLogitsProcessor]
default: null
guided_decoding_backend:
annotation: Optional[str]
annotation: Optional[Literal["xgrammar", "llguidance"]]
default: null
# Quantization and calibration
quant_config: