diff --git a/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp b/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp
index edf8917d27..871a33e3ee 100644
--- a/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp
+++ b/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp
@@ -35,6 +35,8 @@ GuidedDecoder::GuidedDecoder(executor::GuidedDecodingConfig const& guidedDecodin
     , mLogitsDtype{logitsDtype}
     , mCopyBufferManager{std::make_shared<CudaStream>()}
 {
+    TLLM_CHECK_WITH_INFO(mGuidedDecodingBackend != executor::GuidedDecodingConfig::GuidedDecodingBackend::kLLGUIDANCE,
+        "LLGuidance is not supported for guided decoding in the C++ runtime.");
     if (mGuidedDecodingBackend == executor::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR)
     {
         mXGrammarMatchers.resize(mMaxNumSequences);
diff --git a/cpp/tensorrt_llm/executor/executorImpl.cpp b/cpp/tensorrt_llm/executor/executorImpl.cpp
index c2118bbf77..b9df58f9a7 100644
--- a/cpp/tensorrt_llm/executor/executorImpl.cpp
+++ b/cpp/tensorrt_llm/executor/executorImpl.cpp
@@ -1621,6 +1621,9 @@ std::tuple Executor::Impl::fetchNewRequests
             TLLM_CHECK_WITH_INFO(mModel->hasGuidedDecoder(),
                 "Request is specified with GuidedDecodingParams, but GuidedDecoder is not setup. Please "
                 "provide a valid GuidedDecodingConfig to setup GuidedDecoder.");
+            TLLM_CHECK_WITH_INFO(newReq->getGuidedDecodingParams()->getGuideType()
+                    != executor::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG,
+                "Structural tag is not supported for guided decoding in the C++ Executor.");
         }

         if (mModel->getWorldConfig().isLastPipelineParallelRank() && newReq->hasAdditionalOutputs())
diff --git a/examples/trtllm-eval/README.md b/examples/trtllm-eval/README.md
index c06310c388..217e292981 100644
--- a/examples/trtllm-eval/README.md
+++ b/examples/trtllm-eval/README.md
@@ -10,8 +10,7 @@ We provide a CLI tool `trtllm-eval` for evaluating model accuracy. It shares the
 pip install -r requirements.txt

 # Evaluate Llama-3.1-8B-Instruct on MMLU
-wget https://people.eecs.berkeley.edu/~hendrycks/data.tar && tar -xf data.tar
-trtllm-eval --model meta-llama/Llama-3.1-8B-Instruct mmlu --dataset_path data
+trtllm-eval --model meta-llama/Llama-3.1-8B-Instruct mmlu

 # Evaluate Llama-3.1-8B-Instruct on GSM8K
 trtllm-eval --model meta-llama/Llama-3.1-8B-Instruct gsm8k
diff --git a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py
index 536caab6d3..23b0c729c0 100644
--- a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py
+++ b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py
@@ -96,7 +96,7 @@ class XGrammarMatcherFactory(GrammarMatcherFactory):
                 compiled_grammar = self._xgrammar_compiler.compile_structural_tag(
                     structures, triggers)
             case _:
-                raise ValueError(f"Unrecognized guide type: {guide_type}.")
+                raise ValueError(f"Unsupported guide type: {guide_type}.")

         matcher = xgrammar.GrammarMatcher(compiled_grammar)
         return XGrammarMatcher(matcher)
@@ -167,7 +167,7 @@ class LLGuidanceMatcherFactory(GrammarMatcherFactory):
                 # provide Lark-formatted grammar instead of standard EBNF.
                 grammar = llguidance.LLMatcher.grammar_from_lark(guide)
             case _:
-                raise ValueError(f"Unrecognized guide type: {guide_type}.")
+                raise ValueError(f"Unsupported guide type: {guide_type}.")

         matcher = llguidance.LLMatcher(self._tokenizer, grammar)
         if matcher.is_error():
diff --git a/tensorrt_llm/evaluate/json_mode_eval.py b/tensorrt_llm/evaluate/json_mode_eval.py
index 69c41699cd..1854488bce 100644
--- a/tensorrt_llm/evaluate/json_mode_eval.py
+++ b/tensorrt_llm/evaluate/json_mode_eval.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+import os
 from typing import Iterable, List, Optional, Union

 import click
@@ -56,8 +57,13 @@ class JsonModeEval(Evaluator):
         for i, sample in enumerate(self.data):
             if i >= self.num_samples:
                 break
+            schema = sample["schema"]
+            if os.environ.get("TRTLLM_XGUIDANCE_LENIENT") == "1":
+                schema = json.loads(schema)
+                schema["x-guidance"] = {"lenient": True}
+                schema = json.dumps(schema)
             sampling_args = {
-                "guided_decoding": GuidedDecodingParams(json=sample["schema"])
+                "guided_decoding": GuidedDecodingParams(json=schema)
             }
             yield sample["prompt"], sampling_args, sample["completion"]
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 636740d599..a0c2124960 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -879,8 +879,11 @@ class BaseLlmArgs(BaseModel):
     enable_chunked_prefill: bool = Field(default=False,
                                          description="Enable chunked prefill.")

-    guided_decoding_backend: Optional[str] = Field(
-        default=None, description="Guided decoding backend.")
+    guided_decoding_backend: Optional[Literal["xgrammar", "llguidance"]] = Field(
+        default=None,
+        description=
+        "Guided decoding backend. llguidance is supported in the PyTorch backend only."
+    )

     batched_logits_processor: Optional[object] = Field(
         default=None,
diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py
index 04442a2e1f..42ccf02f13 100644
--- a/tensorrt_llm/sampling_params.py
+++ b/tensorrt_llm/sampling_params.py
@@ -19,7 +19,7 @@ class GuidedDecodingParams:
         regex (str, optional): The generated text is amenable to the user-specified regular expression. Defaults to None.
         grammar (str, optional): The generated text is amenable to the user-specified extended Backus-Naur form (EBNF) grammar. Defaults to None.
         json_object (bool): If True, the generated text is amenable to json format. Defaults to False.
-        structural_tag (str, optional): The generated text is amenable to the user-specified structural tag. Defaults to None.
+        structural_tag (str, optional): The generated text is amenable to the user-specified structural tag. Structural tag is supported by xgrammar in the PyTorch backend only. Defaults to None.
""" # noqa: E501 json: Optional[Union[str, BaseModel, dict]] = None diff --git a/tests/integration/defs/accuracy/test_llm_api.py b/tests/integration/defs/accuracy/test_llm_api.py index 8cdb49cf56..576b56b817 100644 --- a/tests/integration/defs/accuracy/test_llm_api.py +++ b/tests/integration/defs/accuracy/test_llm_api.py @@ -58,16 +58,18 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) - def test_guided_decoding(self): - llm = LLM(self.MODEL_PATH, guided_decoding_backend="xgrammar") + @pytest.mark.parametrize("backend", ["xgrammar"]) + def test_guided_decoding(self, backend: str): + llm = LLM(self.MODEL_PATH, guided_decoding_backend=backend) with llm: task = JsonModeEval(self.MODEL_NAME) task.evaluate(llm) @pytest.mark.skip_less_device(4) - def test_guided_decoding_4gpus(self): + @pytest.mark.parametrize("backend", ["xgrammar"]) + def test_guided_decoding_4gpus(self, backend: str): llm = LLM(self.MODEL_PATH, - guided_decoding_backend="xgrammar", + guided_decoding_backend=backend, tensor_parallel_size=2, pipeline_parallel_size=2) with llm: diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index c4bfad8a8e..9b4d47c947 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os + import pytest from tensorrt_llm import LLM @@ -277,9 +279,11 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): task = MMLU(self.MODEL_NAME) task.evaluate(llm) - def test_guided_decoding(self): + @pytest.mark.parametrize("backend", ["xgrammar", "llguidance"]) + def test_guided_decoding(self, backend: str, mocker): + mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"}) llm = LLM(self.MODEL_PATH, - guided_decoding_backend="xgrammar", + guided_decoding_backend=backend, disable_overlap_scheduler=True, use_cuda_graph=True) with llm: @@ -287,9 +291,11 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): task.evaluate(llm) @pytest.mark.skip_less_device(4) - def test_guided_decoding_4gpus(self): + @pytest.mark.parametrize("backend", ["xgrammar", "llguidance"]) + def test_guided_decoding_4gpus(self, backend: str, mocker): + mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"}) llm = LLM(self.MODEL_PATH, - guided_decoding_backend="xgrammar", + guided_decoding_backend=backend, disable_overlap_scheduler=True, use_cuda_graph=True, tensor_parallel_size=2, diff --git a/tests/integration/test_lists/qa/examples_test_list.txt b/tests/integration/test_lists/qa/examples_test_list.txt index bf8a86411a..4d0d0da24e 100644 --- a/tests/integration/test_lists/qa/examples_test_list.txt +++ b/tests/integration/test_lists/qa/examples_test_list.txt @@ -418,8 +418,8 @@ accuracy/test_llm_api.py::TestQwen2_7BInstruct::test_weight_only accuracy/test_cli_flow.py::TestQwen2_7BInstruct::test_int4_awq_prequantized accuracy/test_cli_flow.py::TestQwen2_57B_A14B::test_tp4 accuracy/test_cli_flow.py::TestQwen2_57B_A14B::test_tp2pp2 -accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding -accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus +accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar] 
+accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar]
 accuracy/test_llm_api.py::TestQwen2_5_1_5BInstruct::test_auto_dtype
 accuracy/test_llm_api.py::TestQwen2_5_1_5BInstruct::test_weight_only
 accuracy/test_llm_api.py::TestLlama3_1_8B::test_fp8_rowwise
@@ -440,8 +440,10 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_llm_sampler
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_ngram
-accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding
-accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[llguidance]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
 accuracy/test_cli_flow.py::TestLlama3_3_70BInstruct::test_fp8_prequantized_tp4
diff --git a/tests/integration/test_lists/test-db/l0_a100.yml b/tests/integration/test_lists/test-db/l0_a100.yml
index 5f152ad9ce..3cf77d6067 100644
--- a/tests/integration/test_lists/test-db/l0_a100.yml
+++ b/tests/integration/test_lists/test-db/l0_a100.yml
@@ -46,7 +46,7 @@ l0_a100:
   - accuracy/test_cli_flow.py::TestLlama3_2_1B::test_smooth_quant_ootb_manage_weights
   - accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_int8_gptq
   - accuracy/test_cli_flow.py::TestQwen2_7BInstruct::test_int4_awq_prequantized
-  - accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding
+  - accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
 - condition:
     ranges:
       system_gpu_count:
diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml
index 569e07377a..8c874e82f3 100644
--- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -31,7 +31,7 @@ l0_dgx_h100:
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv=True-attn_backend=TRTLLM-torch_compile=True]
-  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus
+  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar]
   - disaggregated/test_disaggregated.py::test_disaggregated_multi_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_disaggregated.py::test_disaggregated_multi_gpu_with_mpirun_trt_backend[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0]
diff --git a/tests/integration/test_lists/test-db/l0_dgx_h200.yml b/tests/integration/test_lists/test-db/l0_dgx_h200.yml
index 6fd0b9875f..657ab4475a 100644
--- a/tests/integration/test_lists/test-db/l0_dgx_h200.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_h200.yml
@@ -92,6 +92,7 @@ l0_dgx_h200:
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv=True-attn_backend=FLASHINFER-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=False]
+  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[llguidance]
 - condition:
     ranges:
       system_gpu_count:
@@ -120,7 +121,7 @@ l0_dgx_h200:
   - accuracy/test_llm_api.py::TestQwen2_7BInstruct::test_tp2
   - accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_cp2
   - accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_tp2cp2
-  - accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus
+  - accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar]
   - examples/test_llama.py::test_llm_llama_long_alpaca_8gpu_summary[pg64317-tp4pp2-nb:4]
   - examples/test_llama.py::test_llm_llama_v2_lora_benchmark_2gpu[chinese_lora-llama-v2-13b-hf]
   - examples/test_mixtral.py::test_llm_mixtral_moe_plugin_lora_4gpus[Mixtral-8x7B-v0.1-chinese-mixtral-lora]
diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml
index bb6af71122..42c8dbb8bb 100644
--- a/tests/integration/test_lists/test-db/l0_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_h100.yml
@@ -29,7 +29,7 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=FLASHINFER-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=TRTLLM] TIMEOUT (90)
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=FLASHINFER] TIMEOUT (60)
-  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding
+  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=False-attn_backend=TRTLLM-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=False-attn_backend=FLASHINFER-torch_compile=False]
@@ -182,6 +182,7 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=none-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency]
+  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
 - condition:
     ranges:
       system_gpu_count:
diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index 6b69817e3b..11ee02bdf7 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -433,7 +433,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backe
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5349343)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5349343)
 full:B200/test_e2e.py::test_ptp_quickstart_advanced_deepseek_multi_nodes[DeepSeek-R1/DeepSeek-R1-0528-FP4] SKIP (https://nvbugs/5344688)
-accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus SKIP (https://nvbugs/5346443)
+accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar] SKIP (https://nvbugs/5346443)
 test_e2e.py::test_openai_reasoning SKIP (https://nvbugs/5355091)
 test_e2e.py::test_openai_misc_example SKIP (https://nvbugs/5355091)
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=False] SKIP (https://nvbugs/5354956)
diff --git a/tests/unittest/api_stability/references_committed/llm.yaml b/tests/unittest/api_stability/references_committed/llm.yaml
index a30e62645f..7d26c9093e 100644
--- a/tests/unittest/api_stability/references_committed/llm.yaml
+++ b/tests/unittest/api_stability/references_committed/llm.yaml
@@ -62,7 +62,7 @@ methods:
         annotation: Optional[tensorrt_llm.sampling_params.BatchedLogitsProcessor]
         default: null
       guided_decoding_backend:
-        annotation: Optional[str]
+        annotation: Optional[Literal["xgrammar", "llguidance"]]
         default: null
       # Quantization and calibration
       quant_config:
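
Taken together, these changes make the guided-decoding backend selectable at LLM construction time and route per-request constraints through GuidedDecodingParams. The sketch below shows the intended end-to-end usage: the model path, prompt, and JSON schema are placeholders, and the SamplingParams wiring follows the existing LLM API rather than anything introduced in this diff.

# Minimal usage sketch. "xgrammar" is accepted by both runtimes; "llguidance"
# is rejected by the C++ runtime per the TLLM_CHECK added above, and structural
# tags are supported by xgrammar on the PyTorch backend only.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.sampling_params import GuidedDecodingParams

llm = LLM("meta-llama/Llama-3.1-8B-Instruct",
          guided_decoding_backend="llguidance")  # or "xgrammar"

# Constrain generation to a JSON schema, as JsonModeEval does above.
schema = ('{"type": "object", "properties": {"name": {"type": "string"}}, '
          '"required": ["name"]}')
sampling_params = SamplingParams(
    guided_decoding=GuidedDecodingParams(json=schema))

with llm:
    output = llm.generate("Describe a person in JSON.", sampling_params)
    print(output.outputs[0].text)  # output text is constrained to the schema

Note that when TRTLLM_XGUIDANCE_LENIENT=1 is set, JsonModeEval additionally injects "x-guidance": {"lenient": true} into each schema before building GuidedDecodingParams; the parametrized tests enable this via mocker.patch.dict.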