[TRTLLM-9157][doc] Guided decoding doc improvement (#9359)

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Signed-off-by: Mike Iovine <miovine@nvidia.com>
Enwei Zhu 2025-11-27 14:14:43 +08:00 committed by Mike Iovine
parent 0915c4e3a1
commit b46e78e263
3 changed files with 585 additions and 37 deletions


@@ -0,0 +1,583 @@
# Guided Decoding
Guided decoding (also known as constrained decoding or structured generation) guarantees that LLM outputs conform to a user-specified grammar (e.g., a JSON schema, a [regular expression](https://en.wikipedia.org/wiki/Regular_expression) or an [EBNF](https://en.wikipedia.org/wiki/Extended_Backus%E2%80%93Naur_form) grammar).
TensorRT LLM supports two grammar backends:
* [XGrammar](https://github.com/mlc-ai/xgrammar/blob/v0.1.21/python/xgrammar/matcher.py#L341-L350): Supports JSON schema, regular expression, EBNF, and [structural tag](https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html).
* [LLGuidance](https://github.com/guidance-ai/llguidance/blob/v1.1.1/python/llguidance/_lib.pyi#L363-L366): Supports JSON schema, regular expression, and EBNF.
## Online API: `trtllm-serve`
If you are using `trtllm-serve`, enable guided decoding by setting `guided_decoding_backend` to `xgrammar` or `llguidance` in a YAML configuration file and passing that file to `--extra_llm_api_options`. For example:
```bash
cat > extra_llm_api_options.yaml <<EOF
guided_decoding_backend: xgrammar
EOF
trtllm-serve nvidia/Llama-3.1-8B-Instruct-FP8 --extra_llm_api_options extra_llm_api_options.yaml
```
You should see a log like the following, which indicates that the grammar backend was successfully enabled.
```txt
......
[TRT-LLM] [I] Guided decoder initialized with backend: GuidedDecodingBackend.XGRAMMAR
......
```
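Once the server is up, a quick probe against the OpenAI-compatible endpoint confirms it is reachable before you send guided requests. A minimal sketch; it assumes the default `localhost:8000` address used throughout this page:
```python
from openai import OpenAI

# Sanity check: list the served models to confirm the endpoint is live.
client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="tensorrt_llm",
)
print([model.id for model in client.models.list()])
```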
### JSON Schema
Define a JSON schema and pass it to `response_format` when creating the OpenAI chat completion request. Alternatively, the JSON schema can be generated from a [pydantic](https://docs.pydantic.dev/latest/) model, as sketched after the example below.
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8000/v1",
api_key="tensorrt_llm",
)
json_schema = {
"type": "object",
"properties": {
"name": {
"type": "string",
"pattern": "^[\\w]+$"
},
"population": {
"type": "integer"
},
},
"required": ["name", "population"],
}
messages = [
{
"role": "system",
"content": "You are a helpful assistant.",
},
{
"role": "user",
"content": "Give me the information of the capital of France in the JSON format.",
},
]
chat_completion = client.chat.completions.create(
model="nvidia/Llama-3.1-8B-Instruct-FP8",
messages=messages,
max_completion_tokens=256,
response_format={
"type": "json",
"schema": json_schema
},
)
message = chat_completion.choices[0].message
print(message.content)
```
The output would look like:
```txt
{
"name": "Paris",
"population": 2145200
}
```
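As mentioned above, the JSON schema can also be generated from a pydantic model instead of written by hand. A minimal sketch (the `CityInfo` model is illustrative, not part of any API):
```python
from pydantic import BaseModel, Field

class CityInfo(BaseModel):
    name: str = Field(pattern=r"^[\w]+$")
    population: int

# Produces a schema equivalent to the hand-written one above (plus a title field);
# pass it to response_format in the same way.
json_schema = CityInfo.model_json_schema()
```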
### Regular expression
Define a regular expression and pass it to `response_format` when creating the OpenAI chat completion request.
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8000/v1",
api_key="tensorrt_llm",
)
messages = [
{
"role": "system",
"content": "You are a helpful assistant.",
},
{
"role": "user",
"content": "What is the capital of France?",
},
]
chat_completion = client.chat.completions.create(
model="nvidia/Llama-3.1-8B-Instruct-FP8",
messages=messages,
max_completion_tokens=256,
response_format={
"type": "regex",
"regex": "(Paris|London)"
},
)
message = chat_completion.choices[0].message
print(message.content)
```
The output would look like:
```txt
Paris
```
### EBNF grammar
Define an EBNF grammar and pass it to `response_format` when creating the OpenAI chat completion request.
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8000/v1",
api_key="tensorrt_llm",
)
ebnf_grammar = """root ::= description
city ::= "London" | "Paris" | "Berlin" | "Rome"
description ::= city " is " status
status ::= "the capital of " country
country ::= "England" | "France" | "Germany" | "Italy"
"""
messages = [
{
"role": "system",
"content": "You are a helpful geography bot."
},
{
"role": "user",
"content": "Give me the information of the capital of France.",
},
]
chat_completion = client.chat.completions.create(
model="nvidia/Llama-3.1-8B-Instruct-FP8",
messages=messages,
max_completion_tokens=256,
response_format={
"type": "ebnf",
"ebnf": ebnf_grammar
},
)
message = chat_completion.choices[0].message
print(message.content)
```
The output would look like:
```txt
Paris is the capital of France
```
### Structural tag
Define a structural tag and pass it to `response_format` when creating the OpenAI chat completion request.
Structural tags are supported by the `xgrammar` backend only. They are a powerful and flexible way to express LLM output constraints; see [structural tag usage](https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html) for a comprehensive tutorial. Below is an example of function calling with a customized function call format for `Llama-3.1-8B-Instruct`.
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8000/v1",
api_key="tensorrt_llm",
)
tool_get_current_weather = {
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to find the weather for, e.g. 'San Francisco'",
},
"state": {
"type": "string",
"description": "the two-letter abbreviation for the state that the city is in, e.g. 'CA' which would mean 'California'",
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["city", "state", "unit"],
},
},
}
tool_get_current_date = {
"type": "function",
"function": {
"name": "get_current_date",
"description": "Get the current date and time for a given timezone",
"parameters": {
"type": "object",
"properties": {
"timezone": {
"type": "string",
"description": "The timezone to fetch the current date and time for, e.g. 'America/New_York'",
}
},
"required": ["timezone"],
},
},
}
system_prompt = f"""# Tool Instructions
- Always execute python code in messages that you share.
- When looking for real-time information, use relevant functions if available, else fall back to brave_search
You have access to the following functions:
Use the function 'get_current_weather' to: Get the current weather in a given location
{tool_get_current_weather["function"]}
Use the function 'get_current_date' to: Get the current date and time for a given timezone
{tool_get_current_date["function"]}
If you choose to call a function, ONLY reply in the following format:
<{{start_tag}}={{function_name}}>{{parameters}}{{end_tag}}
where
start_tag => `<function`
parameters => a JSON dict with the function argument name as key and function argument value as value.
end_tag => `</function>`
Here is an example,
<function=example_function_name>{{"example_name": "example_value"}}</function>
Reminder:
- Function calls MUST follow the specified format
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- Always add your sources when using search results to answer the user query
You are a helpful assistant."""
user_prompt = "You are in New York. Please get the current date and time, and the weather."
messages = [
{
"role": "system",
"content": system_prompt,
},
{
"role": "user",
"content": user_prompt,
},
]
chat_completion = client.chat.completions.create(
model="nvidia/Llama-3.1-8B-Instruct-FP8",
messages=messages,
max_completion_tokens=256,
response_format={
"type": "structural_tag",
"format": {
"type": "triggered_tags",
"triggers": ["<function="],
"tags": [
{
"begin": "<function=get_current_weather>",
"content": {
"type": "json_schema",
"json_schema": tool_get_current_weather["function"]["parameters"]
},
"end": "</function>",
},
{
"begin": "<function=get_current_date>",
"content": {
"type": "json_schema",
"json_schema": tool_get_current_date["function"]["parameters"]
},
"end": "</function>",
},
],
},
},
)
message = chat_completion.choices[0].message
print(message.content)
```
The output would look like:
```txt
<function=get_current_date>{"timezone": "America/New_York"}</function>
<function=get_current_weather>{"city": "New York", "state": "NY", "unit": "fahrenheit"}</function>
```
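Guided decoding constrains the output format only; extracting and executing the function calls remains application code. A minimal parsing sketch for the tag format defined above (the regular expression and the dispatch logic are illustrative):
```python
import json
import re

# Extract every "<function=NAME>{...}</function>" call from the response.
calls = re.findall(r"<function=(\w+)>(\{.*?\})</function>", message.content)
for name, arguments in calls:
    # The guided decoder guarantees the arguments are valid JSON matching the
    # corresponding tool schema, so json.loads is safe here.
    print(name, json.loads(arguments))
```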
## Offline API: LLM API
If you are using the LLM API, enable guided decoding by setting `guided_decoding_backend` to `xgrammar` or `llguidance` when creating the LLM instance. For example:
```python
from tensorrt_llm import LLM
llm = LLM("nvidia/Llama-3.1-8B-Instruct-FP8", guided_decoding_backend="xgrammar")
```
### JSON Schema
Create a `GuidedDecodingParams` with the `json` field set to a JSON schema, use it to create a `SamplingParams`, and then pass that to `llm.generate` or `llm.generate_async`. Alternatively, the JSON schema can be generated from a [pydantic](https://docs.pydantic.dev/latest/) model, as sketched after the example below.
```python
from tensorrt_llm import LLM
from tensorrt_llm.sampling_params import SamplingParams, GuidedDecodingParams
if __name__ == "__main__":
llm = LLM("nvidia/Llama-3.1-8B-Instruct-FP8", guided_decoding_backend="xgrammar")
json_schema = {
"type": "object",
"properties": {
"name": {
"type": "string",
"pattern": "^[\\w]+$"
},
"population": {
"type": "integer"
},
},
"required": ["name", "population"],
}
messages = [
{
"role": "system",
"content": "You are a helpful assistant.",
},
{
"role": "user",
"content": "Give me the information of the capital of France in the JSON format.",
},
]
prompt = llm.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
output = llm.generate(
prompt,
sampling_params=SamplingParams(max_tokens=256, guided_decoding=GuidedDecodingParams(json=json_schema)),
)
print(output.outputs[0].text)
```
The output would look like:
```txt
{
"name": "Paris",
"population": 2145206
}
```
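As in the online case, a pydantic model can replace the hand-written schema, and it additionally parses and validates the generated JSON for free. A minimal sketch (the `CityInfo` model is illustrative):
```python
from pydantic import BaseModel, Field

from tensorrt_llm.sampling_params import GuidedDecodingParams, SamplingParams

class CityInfo(BaseModel):
    name: str = Field(pattern=r"^[\w]+$")
    population: int

sampling_params = SamplingParams(
    max_tokens=256,
    guided_decoding=GuidedDecodingParams(json=CityInfo.model_json_schema()),
)
# After llm.generate(...), parse the guaranteed-valid JSON back into the model:
# city = CityInfo.model_validate_json(output.outputs[0].text)
```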
### Regular expression
Create a `GuidedDecodingParams` with the `regex` field set to a regular expression, use it to create a `SamplingParams`, and then pass that to `llm.generate` or `llm.generate_async` (an async variant is sketched after the example below).
```python
from tensorrt_llm import LLM
from tensorrt_llm.sampling_params import SamplingParams, GuidedDecodingParams
if __name__ == "__main__":
llm = LLM("nvidia/Llama-3.1-8B-Instruct-FP8", guided_decoding_backend="xgrammar")
messages = [
{
"role": "system",
"content": "You are a helpful assistant.",
},
{
"role": "user",
"content": "What is the capital of France?",
},
]
prompt = llm.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
output = llm.generate(
prompt,
sampling_params=SamplingParams(max_tokens=256, guided_decoding=GuidedDecodingParams(regex="(Paris|London)")),
)
print(output.outputs[0].text)
```
The output would look like:
```txt
Paris
```
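The same sampling parameters work with `llm.generate_async`, which submits the request without blocking. A minimal sketch reusing `llm` and `prompt` from the example above, and assuming the returned `GenerationResult` exposes the usual `result()` accessor:
```python
from tensorrt_llm.sampling_params import GuidedDecodingParams, SamplingParams

sampling_params = SamplingParams(
    max_tokens=256,
    guided_decoding=GuidedDecodingParams(regex="(Paris|London)"),
)
# Submit without blocking; collect the result when needed.
future = llm.generate_async(prompt, sampling_params=sampling_params)
output = future.result()
print(output.outputs[0].text)
```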
### EBNF grammar
Create a `GuidedDecodingParams` with the `grammar` field set to an EBNF grammar, use it to create a `SamplingParams`, and then pass that to `llm.generate` or `llm.generate_async`. A sketch for checking that the grammar compiles follows the example below.
```python
from tensorrt_llm import LLM
from tensorrt_llm.sampling_params import SamplingParams, GuidedDecodingParams
if __name__ == "__main__":
llm = LLM("nvidia/Llama-3.1-8B-Instruct-FP8", guided_decoding_backend="xgrammar")
ebnf_grammar = """root ::= description
city ::= "London" | "Paris" | "Berlin" | "Rome"
description ::= city " is " status
status ::= "the capital of " country
country ::= "England" | "France" | "Germany" | "Italy"
"""
messages = [
{
"role": "system",
"content": "You are a helpful geography bot."
},
{
"role": "user",
"content": "Give me the information of the capital of France.",
},
]
prompt = llm.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
output = llm.generate(
prompt,
sampling_params=SamplingParams(max_tokens=256, guided_decoding=GuidedDecodingParams(grammar=ebnf_grammar)),
)
print(output.outputs[0].text)
```
The output would look like:
```txt
Paris is the capital of France
```
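For non-trivial grammars, it can help to verify that the grammar compiles before launching the LLM. A hedged sketch using the `xgrammar` package directly (assumes `xgrammar` is installed; `Grammar.from_ebnf` raises on malformed grammars):
```python
import xgrammar

# Fails fast with a descriptive error if the EBNF grammar is malformed.
grammar = xgrammar.Grammar.from_ebnf(ebnf_grammar)
print(grammar)
```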
### Structural tag
Create a `GuidedDecodingParams` with the `structural_tag` field set to a structural tag JSON string, use it to create a `SamplingParams`, and then pass that to `llm.generate` or `llm.generate_async`.
Structural tags are supported by the `xgrammar` backend only. They are a powerful and flexible way to express LLM output constraints; see [structural tag usage](https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html) for a comprehensive tutorial. Below is an example of function calling with a customized function call format for `Llama-3.1-8B-Instruct`.
```python
import json
from tensorrt_llm import LLM
from tensorrt_llm.sampling_params import SamplingParams, GuidedDecodingParams
if __name__ == "__main__":
llm = LLM("nvidia/Llama-3.1-8B-Instruct-FP8", guided_decoding_backend="xgrammar")
tool_get_current_weather = {
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to find the weather for, e.g. 'San Francisco'",
},
"state": {
"type": "string",
"description": "the two-letter abbreviation for the state that the city is in, e.g. 'CA' which would mean 'California'",
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["city", "state", "unit"],
},
},
}
tool_get_current_date = {
"type": "function",
"function": {
"name": "get_current_date",
"description": "Get the current date and time for a given timezone",
"parameters": {
"type": "object",
"properties": {
"timezone": {
"type": "string",
"description": "The timezone to fetch the current date and time for, e.g. 'America/New_York'",
}
},
"required": ["timezone"],
},
},
}
system_prompt = f"""# Tool Instructions
- Always execute python code in messages that you share.
- When looking for real-time information, use relevant functions if available, else fall back to brave_search
You have access to the following functions:
Use the function 'get_current_weather' to: Get the current weather in a given location
{tool_get_current_weather["function"]}
Use the function 'get_current_date' to: Get the current date and time for a given timezone
{tool_get_current_date["function"]}
If you choose to call a function, ONLY reply in the following format:
<{{start_tag}}={{function_name}}>{{parameters}}{{end_tag}}
where
start_tag => `<function`
parameters => a JSON dict with the function argument name as key and function argument value as value.
end_tag => `</function>`
Here is an example,
<function=example_function_name>{{"example_name": "example_value"}}</function>
Reminder:
- Function calls MUST follow the specified format
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- Always add your sources when using search results to answer the user query
You are a helpful assistant."""
user_prompt = "You are in New York. Please get the current date and time, and the weather."
structural_tag = {
"type": "structural_tag",
"format": {
"type": "triggered_tags",
"triggers": ["<function="],
"tags": [
{
"begin": "<function=get_current_weather>",
"content": {
"type": "json_schema",
"json_schema": tool_get_current_weather["function"]["parameters"]
},
"end": "</function>",
},
{
"begin": "<function=get_current_date>",
"content": {
"type": "json_schema",
"json_schema": tool_get_current_date["function"]["parameters"]
},
"end": "</function>",
},
],
},
}
messages = [
{
"role": "system",
"content": system_prompt,
},
{
"role": "user",
"content": user_prompt,
},
]
prompt = llm.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
output = llm.generate(
prompt,
sampling_params=SamplingParams(max_tokens=256, guided_decoding=GuidedDecodingParams(structural_tag=json.dumps(structural_tag))),
)
print(output.outputs[0].text)
```
The output would look like:
```txt
<function=get_current_date>{"timezone": "America/New_York"}</function>
<function=get_current_weather>{"city": "New York", "state": "NY", "unit": "fahrenheit"}</function>
```


@@ -1,5 +1,5 @@
# Sampling
-The PyTorch backend supports most of the sampling features that are supported on the C++ backend, such as temperature, top-k and top-p sampling, beam search, stop words, bad words, penalty, context and generation logits, log probability, guided decoding and logits processors
+The PyTorch backend supports most of the sampling features that are supported on the C++ backend, such as temperature, top-k and top-p sampling, beam search, stop words, bad words, penalty, context and generation logits, log probability and logits processors
## General usage
@@ -60,42 +60,6 @@ llm.generate(["Hello, my name is",
"Hello, my name is"], sampling_params)
```
-## Guided decoding
-Guided decoding controls the generation outputs to conform to pre-defined structured formats, ensuring outputs follow specific schemas or patterns.
-The PyTorch backend supports guided decoding with the XGrammar and Low-level Guidance (llguidance) backends and the following formats:
-- JSON schema
-- JSON object
-- Regular expressions
-- Extended Backus-Naur form (EBNF) grammar
-- Structural tags
-To enable guided decoding, you must:
-1. Set the `guided_decoding_backend` parameter to `'xgrammar'` or `'llguidance'` in the `LLM` class
-2. Create a [`GuidedDecodingParams`](source:tensorrt_llm/sampling_params.py#L14) object with the desired format specification
-   * Note: Depending on the type of format, a different parameter needs to be chosen to construct the object (`json`, `regex`, `grammar`, `structural_tag`).
-3. Pass the `GuidedDecodingParams` object to the `guided_decoding` parameter of the `SamplingParams` object
-The following example demonstrates guided decoding with a JSON schema:
-```python
-from tensorrt_llm import LLM, SamplingParams
-from tensorrt_llm.llmapi import GuidedDecodingParams
-llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8',
-          guided_decoding_backend='xgrammar')
-structure = '{"title": "Example JSON", "type": "object", "properties": {...}}'
-guided_decoding_params = GuidedDecodingParams(json=structure)
-sampling_params = SamplingParams(
-    guided_decoding=guided_decoding_params,
-)
-llm.generate("Generate a JSON response", sampling_params)
-```
-You can find a more detailed example on guided decoding [here](source:examples/llm-api/llm_guided_decoding.py).
## Logits processor
Logits processors allow you to modify the logits produced by the network before sampling, enabling custom generation behavior and constraints.


@@ -71,6 +71,7 @@ Welcome to TensorRT LLM's Documentation!
features/quantization.md
features/sampling.md
features/additional-outputs.md
+features/guided-decoding.md
features/speculative-decoding.md
features/checkpoint-loading.md
features/auto_deploy/auto-deploy.md