mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[TRTLLM-9157][doc] Guided decoding doc improvement (#9359)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com> Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com> Signed-off-by: Mike Iovine <miovine@nvidia.com>
This commit is contained in:
parent
0915c4e3a1
commit
b46e78e263
583
docs/source/features/guided-decoding.md
Normal file
583
docs/source/features/guided-decoding.md
Normal file
@ -0,0 +1,583 @@
|
||||
# Guided Decoding
|
||||
|
||||
Guided decoding (or interchangeably constrained decoding, structured generation) guarantees that the LLM outputs are amenable to a user-specified grammar (e.g., JSON schema, [regular expression](https://en.wikipedia.org/wiki/Regular_expression) or [EBNF](https://en.wikipedia.org/wiki/Extended_Backus%E2%80%93Naur_form) grammar).
|
||||
|
||||
TensorRT LLM supports two grammar backends:
|
||||
* [XGrammar](https://github.com/mlc-ai/xgrammar/blob/v0.1.21/python/xgrammar/matcher.py#L341-L350): Supports JSON schema, regular expression, EBNF and [structural tag](https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html).
|
||||
* [LLGuidance](https://github.com/guidance-ai/llguidance/blob/v1.1.1/python/llguidance/_lib.pyi#L363-L366): Supports JSON schema, regular expression, EBNF.
|
||||
|
||||
|
||||
## Online API: `trtllm-serve`
|
||||
|
||||
If you are using `trtllm-serve`, enable guided decoding by specifying `guided_decoding_backend` with `xgrammar` or `llguidance` in the YAML configuration file, and pass it to `--extra_llm_api_options`. For example,
|
||||
|
||||
```bash
|
||||
cat > extra_llm_api_options.yaml <<EOF
|
||||
guided_decoding_backend: xgrammar
|
||||
EOF
|
||||
|
||||
trtllm-serve nvidia/Llama-3.1-8B-Instruct-FP8 --extra_llm_api_options extra_llm_api_options.yaml
|
||||
```
|
||||
|
||||
You should see a log like the following, which indicates the grammar backend is successfully enabled.
|
||||
|
||||
```txt
|
||||
......
|
||||
[TRT-LLM] [I] Guided decoder initialized with backend: GuidedDecodingBackend.XGRAMMAR
|
||||
......
|
||||
```
|
||||
|
||||
### JSON Schema
|
||||
|
||||
Define a JSON schema and pass it to `response_format` when creating the OpenAI chat completion request. Alternatively, the JSON schema can be created using [pydantic](https://docs.pydantic.dev/latest/).
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8000/v1",
|
||||
api_key="tensorrt_llm",
|
||||
)
|
||||
|
||||
json_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"pattern": "^[\\w]+$"
|
||||
},
|
||||
"population": {
|
||||
"type": "integer"
|
||||
},
|
||||
},
|
||||
"required": ["name", "population"],
|
||||
}
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Give me the information of the capital of France in the JSON format.",
|
||||
},
|
||||
]
|
||||
chat_completion = client.chat.completions.create(
|
||||
model="nvidia/Llama-3.1-8B-Instruct-FP8",
|
||||
messages=messages,
|
||||
max_completion_tokens=256,
|
||||
response_format={
|
||||
"type": "json",
|
||||
"schema": json_schema
|
||||
},
|
||||
)
|
||||
|
||||
message = chat_completion.choices[0].message
|
||||
print(message.content)
|
||||
```
|
||||
|
||||
The output would look like:
|
||||
```txt
|
||||
{
|
||||
"name": "Paris",
|
||||
"population": 2145200
|
||||
}
|
||||
```
|
||||
|
||||
### Regular expression
|
||||
|
||||
Define a regular expression and pass it to `response_format` when creating the OpenAI chat completion request.
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8000/v1",
|
||||
api_key="tensorrt_llm",
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the capital of France?",
|
||||
},
|
||||
]
|
||||
chat_completion = client.chat.completions.create(
|
||||
model="nvidia/Llama-3.1-8B-Instruct-FP8",
|
||||
messages=messages,
|
||||
max_completion_tokens=256,
|
||||
response_format={
|
||||
"type": "regex",
|
||||
"regex": "(Paris|London)"
|
||||
},
|
||||
)
|
||||
|
||||
message = chat_completion.choices[0].message
|
||||
print(message.content)
|
||||
```
|
||||
|
||||
The output would look like:
|
||||
```txt
|
||||
Paris
|
||||
```
|
||||
|
||||
### EBNF grammar
|
||||
|
||||
Define an EBNF grammar and pass it to `response_format` when creating the OpenAI chat completion request.
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8000/v1",
|
||||
api_key="tensorrt_llm",
|
||||
)
|
||||
|
||||
ebnf_grammar = """root ::= description
|
||||
city ::= "London" | "Paris" | "Berlin" | "Rome"
|
||||
description ::= city " is " status
|
||||
status ::= "the capital of " country
|
||||
country ::= "England" | "France" | "Germany" | "Italy"
|
||||
"""
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful geography bot."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Give me the information of the capital of France.",
|
||||
},
|
||||
]
|
||||
chat_completion = client.chat.completions.create(
|
||||
model="nvidia/Llama-3.1-8B-Instruct-FP8",
|
||||
messages=messages,
|
||||
max_completion_tokens=256,
|
||||
response_format={
|
||||
"type": "ebnf",
|
||||
"ebnf": ebnf_grammar
|
||||
},
|
||||
)
|
||||
|
||||
message = chat_completion.choices[0].message
|
||||
print(message.content)
|
||||
```
|
||||
|
||||
The output would look like:
|
||||
```txt
|
||||
Paris is the capital of France
|
||||
```
|
||||
|
||||
### Structural tag
|
||||
|
||||
Define a structural tag and pass it to `response_format` when creating the OpenAI chat completion request.
|
||||
|
||||
Structural tag is supported by `xgrammar` backend only. It is a powerful and flexible tool to represent the LLM output constraints. Please see [structural tag usage](https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html) for a comprehensive tutorial. Below is an example of function calling with customized function call format for `Llama-3.1-8B-Instruct`.
|
||||
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8000/v1",
|
||||
api_key="tensorrt_llm",
|
||||
)
|
||||
|
||||
tool_get_current_weather = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city to find the weather for, e.g. 'San Francisco'",
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"description": "the two-letter abbreviation for the state that the city is in, e.g. 'CA' which would mean 'California'",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["city", "state", "unit"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tool_get_current_date = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_date",
|
||||
"description": "Get the current date and time for a given timezone",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timezone": {
|
||||
"type": "string",
|
||||
"description": "The timezone to fetch the current date and time for, e.g. 'America/New_York'",
|
||||
}
|
||||
},
|
||||
"required": ["timezone"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
system_prompt = f"""# Tool Instructions
|
||||
- Always execute python code in messages that you share.
|
||||
- When looking for real time information use relevant functions if available else fallback to brave_search
|
||||
You have access to the following functions:
|
||||
Use the function 'get_current_weather' to: Get the current weather in a given location
|
||||
{tool_get_current_weather["function"]}
|
||||
Use the function 'get_current_date' to: Get the current date and time for a given timezone
|
||||
{tool_get_current_date["function"]}
|
||||
If a you choose to call a function ONLY reply in the following format:
|
||||
<{{start_tag}}={{function_name}}>{{parameters}}{{end_tag}}
|
||||
where
|
||||
start_tag => `<function`
|
||||
parameters => a JSON dict with the function argument name as key and function argument value as value.
|
||||
end_tag => `</function>`
|
||||
Here is an example,
|
||||
<function=example_function_name>{{"example_name": "example_value"}}</function>
|
||||
Reminder:
|
||||
- Function calls MUST follow the specified format
|
||||
- Required parameters MUST be specified
|
||||
- Only call one function at a time
|
||||
- Put the entire function call reply on one line
|
||||
- Always add your sources when using search results to answer the user query
|
||||
You are a helpful assistant."""
|
||||
user_prompt = "You are in New York. Please get the current date and time, and the weather."
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
},
|
||||
]
|
||||
|
||||
chat_completion = client.chat.completions.create(
|
||||
model="nvidia/Llama-3.1-8B-Instruct-FP8",
|
||||
messages=messages,
|
||||
max_completion_tokens=256,
|
||||
response_format={
|
||||
"type": "structural_tag",
|
||||
"format": {
|
||||
"type": "triggered_tags",
|
||||
"triggers": ["<function="],
|
||||
"tags": [
|
||||
{
|
||||
"begin": "<function=get_current_weather>",
|
||||
"content": {
|
||||
"type": "json_schema",
|
||||
"json_schema": tool_get_current_weather["function"]["parameters"]
|
||||
},
|
||||
"end": "</function>",
|
||||
},
|
||||
{
|
||||
"begin": "<function=get_current_date>",
|
||||
"content": {
|
||||
"type": "json_schema",
|
||||
"json_schema": tool_get_current_date["function"]["parameters"]
|
||||
},
|
||||
"end": "</function>",
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
message = chat_completion.choices[0].message
|
||||
print(message.content)
|
||||
```
|
||||
|
||||
The output would look like:
|
||||
```txt
|
||||
<function=get_current_date>{"timezone": "America/New_York"}</function>
|
||||
<function=get_current_weather>{"city": "New York", "state": "NY", "unit": "fahrenheit"}</function>
|
||||
```
|
||||
|
||||
|
||||
## Offline API: LLM API
|
||||
|
||||
If you are using LLM API, enable guided decoding by specifying `guided_decoding_backend` with `xgrammar` or `llguidance` when creating the LLM instance. For example,
|
||||
|
||||
```python
|
||||
from tensorrt_llm import LLM
|
||||
|
||||
llm = LLM("nvidia/Llama-3.1-8B-Instruct-FP8", guided_decoding_backend="xgrammar")
|
||||
```
|
||||
|
||||
### JSON Schema
|
||||
|
||||
Create a `GuidedDecodingParams` with the `json` field specified with a JSON schema, use it to create `SamplingParams`, and then pass to `llm.generate` or `llm.generate_async`. Alternatively, the JSON schema can be created using [pydantic](https://docs.pydantic.dev/latest/).
|
||||
|
||||
```python
|
||||
from tensorrt_llm import LLM
|
||||
from tensorrt_llm.sampling_params import SamplingParams, GuidedDecodingParams
|
||||
|
||||
if __name__ == "__main__":
|
||||
llm = LLM("nvidia/Llama-3.1-8B-Instruct-FP8", guided_decoding_backend="xgrammar")
|
||||
|
||||
json_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"pattern": "^[\\w]+$"
|
||||
},
|
||||
"population": {
|
||||
"type": "integer"
|
||||
},
|
||||
},
|
||||
"required": ["name", "population"],
|
||||
}
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Give me the information of the capital of France in the JSON format.",
|
||||
},
|
||||
]
|
||||
prompt = llm.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
|
||||
output = llm.generate(
|
||||
prompt,
|
||||
sampling_params=SamplingParams(max_tokens=256, guided_decoding=GuidedDecodingParams(json=json_schema)),
|
||||
)
|
||||
print(output.outputs[0].text)
|
||||
```
|
||||
|
||||
The output would look like:
|
||||
```txt
|
||||
{
|
||||
"name": "Paris",
|
||||
"population": 2145206
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
### Regular expression
|
||||
|
||||
Create a `GuidedDecodingParams` with the `regex` field specified with a regular expression, use it to create `SamplingParams`, and then pass to `llm.generate` or `llm.generate_async`.
|
||||
|
||||
```python
|
||||
from tensorrt_llm import LLM
|
||||
from tensorrt_llm.sampling_params import SamplingParams, GuidedDecodingParams
|
||||
|
||||
if __name__ == "__main__":
|
||||
llm = LLM("nvidia/Llama-3.1-8B-Instruct-FP8", guided_decoding_backend="xgrammar")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the capital of France?",
|
||||
},
|
||||
]
|
||||
prompt = llm.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
|
||||
output = llm.generate(
|
||||
prompt,
|
||||
sampling_params=SamplingParams(max_tokens=256, guided_decoding=GuidedDecodingParams(regex="(Paris|London)")),
|
||||
)
|
||||
print(output.outputs[0].text)
|
||||
```
|
||||
|
||||
The output would look like:
|
||||
```txt
|
||||
Paris
|
||||
```
|
||||
|
||||
### EBNF grammar
|
||||
|
||||
Create a `GuidedDecodingParams` with the `grammar` field specified with an EBNF grammar, use it to create `SamplingParams`, and then pass to `llm.generate` or `llm.generate_async`.
|
||||
|
||||
```python
|
||||
from tensorrt_llm import LLM
|
||||
from tensorrt_llm.sampling_params import SamplingParams, GuidedDecodingParams
|
||||
|
||||
if __name__ == "__main__":
|
||||
llm = LLM("nvidia/Llama-3.1-8B-Instruct-FP8", guided_decoding_backend="xgrammar")
|
||||
|
||||
ebnf_grammar = """root ::= description
|
||||
city ::= "London" | "Paris" | "Berlin" | "Rome"
|
||||
description ::= city " is " status
|
||||
status ::= "the capital of " country
|
||||
country ::= "England" | "France" | "Germany" | "Italy"
|
||||
"""
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful geography bot."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Give me the information of the capital of France.",
|
||||
},
|
||||
]
|
||||
prompt = llm.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
|
||||
output = llm.generate(
|
||||
prompt,
|
||||
sampling_params=SamplingParams(max_tokens=256, guided_decoding=GuidedDecodingParams(grammar=ebnf_grammar)),
|
||||
)
|
||||
print(output.outputs[0].text)
|
||||
```
|
||||
|
||||
The output would look like:
|
||||
```txt
|
||||
Paris is the capital of France
|
||||
```
|
||||
|
||||
### Structural tag
|
||||
|
||||
Create a `GuidedDecodingParams` with the `structural_tag` field specified with a structural tag string, use it to create `SamplingParams`, and then pass to `llm.generate` or `llm.generate_async`.
|
||||
|
||||
Structural tag is supported by `xgrammar` backend only. It is a powerful and flexible tool to represent the LLM output constraints. Please see [structural tag usage](https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html) for a comprehensive tutorial. Below is an example of function calling with customized function call format for `Llama-3.1-8B-Instruct`.
|
||||
|
||||
```python
|
||||
import json
|
||||
from tensorrt_llm import LLM
|
||||
from tensorrt_llm.sampling_params import SamplingParams, GuidedDecodingParams
|
||||
|
||||
if __name__ == "__main__":
|
||||
llm = LLM("nvidia/Llama-3.1-8B-Instruct-FP8", guided_decoding_backend="xgrammar")
|
||||
|
||||
tool_get_current_weather = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city to find the weather for, e.g. 'San Francisco'",
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"description": "the two-letter abbreviation for the state that the city is in, e.g. 'CA' which would mean 'California'",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["city", "state", "unit"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tool_get_current_date = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_date",
|
||||
"description": "Get the current date and time for a given timezone",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timezone": {
|
||||
"type": "string",
|
||||
"description": "The timezone to fetch the current date and time for, e.g. 'America/New_York'",
|
||||
}
|
||||
},
|
||||
"required": ["timezone"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
system_prompt = f"""# Tool Instructions
|
||||
- Always execute python code in messages that you share.
|
||||
- When looking for real time information use relevant functions if available else fallback to brave_search
|
||||
You have access to the following functions:
|
||||
Use the function 'get_current_weather' to: Get the current weather in a given location
|
||||
{tool_get_current_weather["function"]}
|
||||
Use the function 'get_current_date' to: Get the current date and time for a given timezone
|
||||
{tool_get_current_date["function"]}
|
||||
If a you choose to call a function ONLY reply in the following format:
|
||||
<{{start_tag}}={{function_name}}>{{parameters}}{{end_tag}}
|
||||
where
|
||||
start_tag => `<function`
|
||||
parameters => a JSON dict with the function argument name as key and function argument value as value.
|
||||
end_tag => `</function>`
|
||||
Here is an example,
|
||||
<function=example_function_name>{{"example_name": "example_value"}}</function>
|
||||
Reminder:
|
||||
- Function calls MUST follow the specified format
|
||||
- Required parameters MUST be specified
|
||||
- Only call one function at a time
|
||||
- Put the entire function call reply on one line
|
||||
- Always add your sources when using search results to answer the user query
|
||||
You are a helpful assistant."""
|
||||
user_prompt = "You are in New York. Please get the current date and time, and the weather."
|
||||
structural_tag = {
|
||||
"type": "structural_tag",
|
||||
"format": {
|
||||
"type": "triggered_tags",
|
||||
"triggers": ["<function="],
|
||||
"tags": [
|
||||
{
|
||||
"begin": "<function=get_current_weather>",
|
||||
"content": {
|
||||
"type": "json_schema",
|
||||
"json_schema": tool_get_current_weather["function"]["parameters"]
|
||||
},
|
||||
"end": "</function>",
|
||||
},
|
||||
{
|
||||
"begin": "<function=get_current_date>",
|
||||
"content": {
|
||||
"type": "json_schema",
|
||||
"json_schema": tool_get_current_date["function"]["parameters"]
|
||||
},
|
||||
"end": "</function>",
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
},
|
||||
]
|
||||
prompt = llm.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
|
||||
output = llm.generate(
|
||||
prompt,
|
||||
sampling_params=SamplingParams(max_tokens=256, guided_decoding=GuidedDecodingParams(structural_tag=json.dumps(structural_tag))),
|
||||
)
|
||||
print(output.outputs[0].text)
|
||||
```
|
||||
|
||||
The output would look like:
|
||||
```txt
|
||||
<function=get_current_date>{"timezone": "America/New_York"}</function>
|
||||
<function=get_current_weather>{"city": "New York", "state": "NY", "unit": "fahrenheit"}</function>
|
||||
```
|
||||
@ -1,5 +1,5 @@
|
||||
# Sampling
|
||||
The PyTorch backend supports most of the sampling features that are supported on the C++ backend, such as temperature, top-k and top-p sampling, beam search, stop words, bad words, penalty, context and generation logits, log probability, guided decoding and logits processors
|
||||
The PyTorch backend supports most of the sampling features that are supported on the C++ backend, such as temperature, top-k and top-p sampling, beam search, stop words, bad words, penalty, context and generation logits, log probability and logits processors
|
||||
|
||||
## General usage
|
||||
|
||||
@ -60,42 +60,6 @@ llm.generate(["Hello, my name is",
|
||||
"Hello, my name is"], sampling_params)
|
||||
```
|
||||
|
||||
## Guided decoding
|
||||
|
||||
Guided decoding controls the generation outputs to conform to pre-defined structured formats, ensuring outputs follow specific schemas or patterns.
|
||||
|
||||
The PyTorch backend supports guided decoding with the XGrammar and Low-level Guidance (llguidance) backends and the following formats:
|
||||
- JSON schema
|
||||
- JSON object
|
||||
- Regular expressions
|
||||
- Extended Backus-Naur form (EBNF) grammar
|
||||
- Structural tags
|
||||
|
||||
To enable guided decoding, you must:
|
||||
|
||||
1. Set the `guided_decoding_backend` parameter to `'xgrammar'` or `'llguidance'` in the `LLM` class
|
||||
2. Create a [`GuidedDecodingParams`](source:tensorrt_llm/sampling_params.py#L14) object with the desired format specification
|
||||
* Note: Depending on the type of format, a different parameter needs to be chosen to construct the object (`json`, `regex`, `grammar`, `structural_tag`).
|
||||
3. Pass the `GuidedDecodingParams` object to the `guided_decoding` parameter of the `SamplingParams` object
|
||||
|
||||
The following example demonstrates guided decoding with a JSON schema:
|
||||
|
||||
```python
|
||||
from tensorrt_llm import LLM, SamplingParams
|
||||
from tensorrt_llm.llmapi import GuidedDecodingParams
|
||||
|
||||
llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8',
|
||||
guided_decoding_backend='xgrammar')
|
||||
structure = '{"title": "Example JSON", "type": "object", "properties": {...}}'
|
||||
guided_decoding_params = GuidedDecodingParams(json=structure)
|
||||
sampling_params = SamplingParams(
|
||||
guided_decoding=guided_decoding_params,
|
||||
)
|
||||
llm.generate("Generate a JSON response", sampling_params)
|
||||
```
|
||||
|
||||
You can find a more detailed example on guided decoding [here](source:examples/llm-api/llm_guided_decoding.py).
|
||||
|
||||
## Logits processor
|
||||
|
||||
Logits processors allow you to modify the logits produced by the network before sampling, enabling custom generation behavior and constraints.
|
||||
|
||||
@ -71,6 +71,7 @@ Welcome to TensorRT LLM's Documentation!
|
||||
features/quantization.md
|
||||
features/sampling.md
|
||||
features/additional-outputs.md
|
||||
features/guided-decoding.md
|
||||
features/speculative-decoding.md
|
||||
features/checkpoint-loading.md
|
||||
features/auto_deploy/auto-deploy.md
|
||||
|
||||
Loading…
Reference in New Issue
Block a user