[TRTLLM-6409][feat] Enable guided decoding with speculative decoding (part 1: two-model engine) (#6300)

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
This commit is contained in:
Enwei Zhu 2025-08-07 17:53:48 +08:00 committed by GitHub
parent c23e8e7b05
commit 1b9781e8e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 389 additions and 115 deletions

View File

@ -2027,7 +2027,7 @@ private:
// Scatter the input tokens to the other beams
mTokens = BeamTokens(mSamplingConfig.beamWidth, inputTokens);
mLastTokens = VecTokens(mSamplingConfig.beamWidth);
mLastTokens = VecTokens(mSamplingConfig.beamWidth, inputTokens.back());
// Init mUniqueTokens
VecUniqueTokens uniqueTokens{inputTokens.size()};

View File

@ -8,11 +8,11 @@
| Disaggregated Serving | Yes | Yes | Yes | --- | | | | | | | | | | |
| Chunked Prefill | Yes | Yes | Yes | Untested | --- | | | | | | | | | |
| MTP | Yes | Yes | Yes | Yes | Untested | --- | | | | | | | | |
| EAGLE-3(One Model Engine) | Yes | Yes | Yes | No | Yes | No | --- | | | | | | | |
| EAGLE-3(Two Model Engine) | NO | Yes | Yes | No | Yes | No | No | --- | | | | | | |
| EAGLE-3(One Model Engine) | Yes | Yes | Yes | No | Yes | No | --- | | | | | | | |
| EAGLE-3(Two Model Engine) | NO | Yes | Yes | No | Yes | No | No | --- | | | | | | |
| Torch Sampler | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | --- | | | | | |
| TLLM C++ Sampler | Yes | Yes | Yes | Yes | Yes | No | No | No | No | --- | | | | |
| KV Cache Reuse | Yes | Yes | Yes | Untested | Yes | Untested | Yes | No | Yes | Yes | --- | | | |
| Slide Window Attention | Yes | Yes | Yes | Untested | No | Untested | Untested | Untested | Yes | Yes | WIP | --- | | |
| Logits Post Processor | No | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | --- | |
| Guided Decoding | Yes | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | Yes | --- |
| Slide Window Attention | Yes | Yes | Yes | Untested | No | Untested | Untested | Untested | Yes | Yes | WIP | --- | | |
| Logits Post Processor | No | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | --- | |
| Guided Decoding | Yes | Yes | Yes | No | Yes | No | No | Yes | Yes | Yes | Yes | Yes | Yes | --- |

View File

@ -52,6 +52,8 @@ einops
flashinfer-python==0.2.5
opencv-python-headless
xgrammar==0.1.21
llguidance==0.7.29
jsonschema
backoff
nvtx
matplotlib # FIXME: this is added to make nvtx happy
@ -59,7 +61,6 @@ meson
ninja
etcd3
blake3
llguidance==0.7.29
soundfile
triton==3.3.1; platform_machine == "x86_64"
tiktoken

View File

@ -16,11 +16,19 @@ class GrammarMatcher(ABC):
def accept_token(self, token_id: int) -> bool:
pass
@abstractmethod
def rollback(self, num_tokens: int) -> None:
pass
@abstractmethod
def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor,
index: int) -> None:
pass
@abstractmethod
def is_terminated(self) -> bool:
pass
class GrammarMatcherFactory(ABC):
@ -39,15 +47,23 @@ class XGrammarMatcher(GrammarMatcher):
def accept_token(self, token_id: int) -> bool:
return self._matcher.accept_token(token_id)
def rollback(self, num_tokens: int) -> None:
self._matcher.rollback(num_tokens)
def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor,
index: int) -> None:
self._matcher.fill_next_token_bitmask(next_token_bitmask, index)
def is_terminated(self) -> bool:
return self._matcher.is_terminated()
class XGrammarMatcherFactory(GrammarMatcherFactory):
def __init__(self, guided_decoding_config: GuidedDecodingConfig,
vocab_size_padded: int):
def __init__(self,
guided_decoding_config: GuidedDecodingConfig,
vocab_size_padded: int,
max_num_draft_tokens: int = 0):
super().__init__()
vocab_type = xgrammar.VocabType.RAW
add_prefix_space = False
@ -72,6 +88,7 @@ class XGrammarMatcherFactory(GrammarMatcherFactory):
cache_enabled=True,
cache_limit_bytes=cache_limit_bytes,
)
self.max_num_draft_tokens = max_num_draft_tokens
def create(self,
guided_decoding_params: GuidedDecodingParams) -> XGrammarMatcher:
@ -106,20 +123,38 @@ class XGrammarMatcherFactory(GrammarMatcherFactory):
case _:
raise ValueError(f"Unsupported guide type: {guide_type}.")
matcher = xgrammar.GrammarMatcher(compiled_grammar)
matcher = xgrammar.GrammarMatcher(
compiled_grammar, max_rollback_tokens=self.max_num_draft_tokens)
return XGrammarMatcher(matcher)
class LLGuidanceMatcher(GrammarMatcher):
def __init__(self, matcher: llguidance.LLMatcher):
def __init__(self, matcher: llguidance.LLMatcher, eos_token: int):
super().__init__()
self._matcher = matcher
self._eos_token = eos_token
self._is_terminated = False
def accept_token(self, token_id: int) -> bool:
result = self._matcher.consume_token(token_id)
if self._matcher.is_stopped():
# Accept EOS token only if the matcher is stopped.
if token_id == self._eos_token:
self._is_terminated = True
return True
else:
return False
num_accepted = self._matcher.try_consume_tokens([token_id])
self._check_err()
return num_accepted > 0
def rollback(self, num_tokens: int) -> None:
if self._is_terminated:
self._is_terminated = False
num_tokens -= 1
self._matcher.rollback(num_tokens)
self._check_err()
return result
def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor,
index: int) -> None:
@ -127,6 +162,9 @@ class LLGuidanceMatcher(GrammarMatcher):
next_token_bitmask, index)
self._check_err()
def is_terminated(self) -> bool:
return self._is_terminated
def _check_err(self) -> None:
if self._matcher.is_error():
raise ValueError(
@ -181,4 +219,4 @@ class LLGuidanceMatcherFactory(GrammarMatcherFactory):
if matcher.is_error():
raise ValueError(f"LLGuidance matcher error: {matcher.get_error()}")
return LLGuidanceMatcher(matcher)
return LLGuidanceMatcher(matcher, self._tokenizer.eos_token)
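The extended GrammarMatcher interface (accept_token, rollback, fill_next_token_bitmask, is_terminated) is what lets a matcher advance speculatively over draft tokens and then undo the rejected suffix. A minimal self-contained sketch of that contract, using a hypothetical toy matcher rather than the xgrammar or llguidance backends:

from typing import List

class ToyMatcher:
    """Hypothetical matcher: the 'grammar' accepts only even token ids."""

    def __init__(self):
        self._accepted: List[int] = []

    def accept_token(self, token_id: int) -> bool:
        if token_id % 2 != 0:
            return False
        self._accepted.append(token_id)
        return True

    def rollback(self, num_tokens: int) -> None:
        if num_tokens > 0:
            del self._accepted[-num_tokens:]

    def is_terminated(self) -> bool:
        return False

matcher = ToyMatcher()
num_advanced = 0
for tid in [2, 4, 7, 8]:          # draft tokens; 7 violates the toy grammar
    if not matcher.accept_token(tid):
        break
    num_advanced += 1             # advanced over 2 and 4

num_accepted = 1                  # suppose verification kept only the first draft token
matcher.rollback(num_advanced - num_accepted)
assert matcher._accepted == [2]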

View File

@ -5,19 +5,25 @@ import torch
from ..._utils import nvtx_range
from ...bindings.executor import GuidedDecodingConfig
from ...logger import logger
from .grammar_matcher import (GrammarMatcher, GrammarMatcherFactory,
LLGuidanceMatcherFactory, XGrammarMatcherFactory)
from .llm_request import LlmRequest
from .scheduler import ScheduledRequests
class GuidedDecoder:
bitmask_dtype = torch.int32
def __init__(self, guided_decoding_config: GuidedDecodingConfig,
max_num_sequences: int, vocab_size_padded: int):
def __init__(self,
guided_decoding_config: GuidedDecodingConfig,
max_num_sequences: int,
vocab_size_padded: int,
max_num_draft_tokens: int = 0):
self.guided_decoding_backend = guided_decoding_config.backend
self.max_num_sequences = max_num_sequences
self.vocab_size_padded = vocab_size_padded
self.max_num_draft_tokens = max_num_draft_tokens
self.grammar_matcher_factory: Optional[GrammarMatcherFactory] = None
self.grammar_matchers: List[
@ -25,71 +31,216 @@ class GuidedDecoder:
if self.guided_decoding_backend == GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR:
self.grammar_matcher_factory = XGrammarMatcherFactory(
guided_decoding_config, vocab_size_padded)
guided_decoding_config,
vocab_size_padded,
max_num_draft_tokens=max_num_draft_tokens)
elif self.guided_decoding_backend == GuidedDecodingConfig.GuidedDecodingBackend.LLGUIDANCE:
self.grammar_matcher_factory = LLGuidanceMatcherFactory(
guided_decoding_config, vocab_size_padded)
else:
raise ValueError(
f"invalid guided_decoding_backend: {self.guided_decoding_backend}"
f"invalid guided decoding backend: {self.guided_decoding_backend}"
)
logger.info(
f"Guided decoder initialized with backend: {self.guided_decoding_backend}"
)
self.bitmask = torch.empty(self.max_num_sequences,
self.max_num_draft_tokens + 1,
self.bitmask_size,
dtype=self.bitmask_dtype,
device='cuda')
self.bitmask_host = torch.empty(self.max_num_sequences,
self.max_num_draft_tokens + 1,
self.bitmask_size,
dtype=self.bitmask_dtype,
pin_memory=True)
# The number of tokens accepted by the grammar matcher in a build step.
self.num_advanced_tokens: List[int] = [0] * self.max_num_sequences
# The number of tokens with filled bitmask in a build step.
self.num_guided_tokens: List[int] = [0] * self.max_num_sequences
# The accumulated number of tokens accepted by the grammar matcher in a drafting loop.
self.num_advanced_draft_tokens: List[int] = [0] * self.max_num_sequences
# Whether guided drafting is terminated because of unacceptable draft tokens.
self.is_draft_terminated: List[bool] = [False] * self.max_num_sequences
self._stream = torch.cuda.Stream()
@property
def bitmask_size(self) -> int:
return math.ceil(self.vocab_size_padded / 32)
def _is_matcher_init(self, llm_req: LlmRequest) -> bool:
if llm_req.guided_decoding_params is None:
return False
if llm_req.py_is_draft:
return False
# The request is in the last chunk of a context forward step.
return llm_req.is_context_init_state and llm_req.is_last_context_chunk
def _is_matcher_in_progress(self, llm_req: LlmRequest) -> bool:
if llm_req.guided_decoding_params is None:
return False
if llm_req.py_is_draft:
return True
# The request is in a generation forward step.
return llm_req.is_generation_in_progress_state
@torch.inference_mode()
@nvtx_range("GuidedDecoder.build")
def build(self, scheduled_requests: ScheduledRequests) -> None:
for llm_req in scheduled_requests.all_requests():
if llm_req.guided_decoding_params is None:
continue
slot = llm_req.py_seq_slot
if llm_req.is_context_init_state and llm_req.context_current_position == llm_req.prepopulated_prompt_len:
self.grammar_matchers[
slot] = self.grammar_matcher_factory.create(
llm_req.guided_decoding_params)
"""Build the bitmask for requests with guided decoding enabled.
Specifically, this method:
- builds or advances the grammar matcher for context and generation requests, respectively;
- calls the grammar matcher to fill the bitmask on the CPU;
- asynchronously copies the bitmask to the GPU.
"""
for llm_req in scheduled_requests.all_requests():
slot: int = llm_req.py_target_seq_slot if llm_req.py_is_draft else llm_req.py_seq_slot
self.num_advanced_tokens[slot] = 0
self.num_guided_tokens[slot] = 0
if self._is_matcher_init(llm_req):
matcher = self.grammar_matcher_factory.create(
llm_req.guided_decoding_params)
self.grammar_matchers[slot] = matcher
elif self._is_matcher_in_progress(llm_req):
matcher = self.grammar_matchers[slot]
# The last new token must be acceptable unless the matcher is terminated in a drafting loop.
if llm_req.py_is_draft and (matcher.is_terminated()
or self.is_draft_terminated[slot]):
continue
last_new_token = llm_req.get_last_tokens(0)
accepted = matcher.accept_token(last_new_token)
if not accepted:
if llm_req.py_is_draft:
self.is_draft_terminated[slot] = True
logger.debug(
f"Draft request {llm_req.py_request_id} failed to accept last new token: {last_new_token}."
)
continue
# TODO: Make this an error response.
raise ValueError(
f"Request {llm_req.py_request_id} failed to accept last new token: {last_new_token}."
)
elif llm_req.is_generation_in_progress_state:
# The request is in a generation forward step.
# Currently, guided decoding is not supported with beam search.
self.grammar_matchers[slot].accept_token(
llm_req.get_last_tokens(0))
else:
continue
# Fill the bitmask on host and asynchronously copy to device.
self.grammar_matchers[slot].fill_next_token_bitmask(
self.bitmask_host, slot)
with torch.cuda.stream(self._stream):
self.bitmask[slot].copy_(self.bitmask_host[slot],
non_blocking=True)
self.num_advanced_tokens[slot] += 1
if not matcher.is_terminated():
matcher.fill_next_token_bitmask(self.bitmask_host[slot], 0)
self.num_guided_tokens[slot] += 1
# Process draft tokens
for i, tid in enumerate(llm_req.py_draft_tokens, 1):
accepted = matcher.accept_token(tid)
if not accepted:
break
self.num_advanced_tokens[slot] += 1
if matcher.is_terminated():
break
matcher.fill_next_token_bitmask(self.bitmask_host[slot], i)
self.num_guided_tokens[slot] += 1
if llm_req.py_is_draft:
assert len(llm_req.py_draft_tokens) == 0
self.num_advanced_draft_tokens[
slot] += self.num_advanced_tokens[slot]
if (num_guided_tokens := self.num_guided_tokens[slot]) > 0:
with torch.cuda.stream(self._stream):
self.bitmask[slot, :num_guided_tokens].copy_(
self.bitmask_host[slot, :num_guided_tokens],
non_blocking=True)
@torch.inference_mode()
@nvtx_range("GuidedDecoder.execute")
def execute(self, scheduled_requests: ScheduledRequests,
logits: torch.Tensor) -> None:
assert logits.size(0) == len(scheduled_requests.context_requests) + len(
scheduled_requests.generation_requests)
def execute(self,
scheduled_requests: ScheduledRequests,
logits: torch.Tensor,
d2t: Optional[torch.Tensor] = None) -> None:
"""Apply the bitmask to the corresponding logits for requests with guided decoding enabled.
This method modifies the logits tensor in place so that any tokens that violate the grammar constraints are masked out.
"""
torch.cuda.current_stream().wait_stream(self._stream)
# TODO: Fuse index_copy and index_select to logits_bitmask.
if d2t is not None:
draft_logits = logits
d2t_mapping = d2t + torch.arange(d2t.size(0), device=d2t.device)
logits = torch.empty(draft_logits.size(0),
self.vocab_size_padded,
dtype=draft_logits.dtype,
device=draft_logits.device)
logits.index_copy_(-1, d2t_mapping, draft_logits)
batched_logits, batched_bitmask = [], []
for i, llm_req in enumerate(scheduled_requests.all_requests()):
if llm_req.guided_decoding_params is None:
continue
if llm_req.is_context_init_state and not llm_req.is_last_context_chunk:
continue
batched_logits.append(logits[i])
batched_bitmask.append(self.bitmask[llm_req.py_seq_slot])
offset = 0
for llm_req in scheduled_requests.all_requests():
slot: int = llm_req.py_target_seq_slot if llm_req.py_is_draft else llm_req.py_seq_slot
for i in range(self.num_guided_tokens[slot]):
batched_logits.append(logits[offset + i])
batched_bitmask.append(self.bitmask[slot, i])
offset += len(llm_req.py_draft_tokens) + 1
assert offset == logits.size(0)
if len(batched_logits) > 0:
torch.ops.trtllm.logits_bitmask(batched_logits, batched_bitmask)
if d2t is not None:
torch.index_select(logits, -1, d2t_mapping, out=draft_logits)
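The d2t handling above exists because a draft model (e.g. an EAGLE-3 head) can use a reduced vocabulary: d2t holds per-position offsets from draft-vocabulary ids to target-vocabulary ids, while the bitmask is defined over the padded target vocabulary. A toy illustration of the scatter/mask/gather round trip, with made-up sizes and offsets:

import torch

draft_vocab, target_vocab = 4, 8
draft_logits = torch.randn(2, draft_vocab)        # [num_tokens, draft_vocab]
d2t = torch.tensor([0, 2, 3, 3])                  # hypothetical offsets
d2t_mapping = d2t + torch.arange(draft_vocab)     # draft id i -> target id

full = torch.empty(2, target_vocab)
full.index_copy_(-1, d2t_mapping, draft_logits)   # scatter into target vocab
# ... torch.ops.trtllm.logits_bitmask would mask `full` in place here ...
torch.index_select(full, -1, d2t_mapping, out=draft_logits)  # gather back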
@nvtx_range("GuidedDecoder.rollback_rejected_tokens")
def rollback_rejected_tokens(self,
scheduled_requests: ScheduledRequests) -> None:
"""Rollback the grammar matcher for rejected tokens.
This method should be called:
- after the verification (so that the accepted tokens are ready) and
- before the first guided decoding build of the next drafting loop.
"""
if self.max_num_draft_tokens <= 0:
return
for llm_req in scheduled_requests.all_requests():
assert not llm_req.py_is_draft
slot: int = llm_req.py_seq_slot
if self.num_advanced_tokens[slot] <= 0:
continue
# Roll back the grammar matcher to the last accepted token.
num_rollback_tokens = self.num_advanced_tokens[slot] - (
1 + llm_req.py_num_accepted_draft_tokens)
# TODO: Make this an error response.
if num_rollback_tokens < 0:
raise ValueError(
f"Failed to rollback: num_advanced_tokens={self.num_advanced_tokens[slot]}, num_accepted_draft_tokens={llm_req.py_num_accepted_draft_tokens}, num_rollback_tokens={num_rollback_tokens}"
)
self.grammar_matchers[slot].rollback(num_rollback_tokens)
@nvtx_range("GuidedDecoder.rollback_draft_tokens")
def rollback_draft_tokens(self,
scheduled_requests: ScheduledRequests) -> None:
"""Rollback the grammar matcher for draft tokens.
This method should be called:
- after the the drafting loop and
- before the guided decoding build of the target model.
"""
if self.max_num_draft_tokens <= 0:
return
for llm_req in scheduled_requests.all_requests():
assert not llm_req.py_is_draft
slot: int = llm_req.py_seq_slot
if self.num_advanced_draft_tokens[slot] <= 0:
continue
self.grammar_matchers[slot].rollback(
self.num_advanced_draft_tokens[slot])
# Reset the drafting states.
self.num_advanced_draft_tokens[slot] = 0
self.is_draft_terminated[slot] = False
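The rollback bookkeeping can be made concrete with a small worked example (numbers are illustrative only): in one target build step the matcher advances over the previous new token plus the draft tokens it can accept; after verification, the matcher must be rolled back past the rejected suffix so it lands on the last accepted token.

# Illustrative arithmetic for rollback_rejected_tokens:
num_advanced_tokens = 4            # 1 new token + 3 draft tokens advanced in build()
num_accepted_draft_tokens = 1      # draft tokens kept by verification
num_rollback_tokens = num_advanced_tokens - (1 + num_accepted_draft_tokens)
assert num_rollback_tokens == 2    # the two rejected draft tokens are undone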

View File

@ -281,6 +281,8 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
llm_request: Optional[
tensorrt_llm.bindings.internal.batch_manager.LlmRequest] = None,
is_draft: bool = False,
seq_slot: Optional[int] = None,
target_seq_slot: Optional[int] = None,
**kwargs):
self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
@ -309,6 +311,7 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
self.py_orig_prompt_len = self.orig_prompt_len
self.py_max_new_tokens = self.max_new_tokens
self.py_batch_idx = None
self.py_draft_pages_allocated = 0
self.py_rewind_len = 0
self.py_draft_tokens = [] if self.draft_tokens is None else self.draft_tokens
self.py_last_context_chunk = (None, None)
@ -326,7 +329,8 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
self.py_return_generation_logits = return_generation_logits
self.py_return_logits_device_memory = return_logits_device_memory
self.py_is_draft = is_draft
self.py_seq_slot = None
self.py_seq_slot = seq_slot
self.py_target_seq_slot = target_seq_slot
# TODO: remove this when use DynamicDecodeOp in pytorch flow.
# currently, keep py_stop_words_list as python list, rather than tensor.

View File

@ -453,7 +453,7 @@ class PyTorchModelEngine(ModelEngine):
'type'] == 'mrope'
except Exception:
pass
logger.info(f"Detected use_mrope: {use_mrope}")
logger.debug(f"Detected use_mrope: {use_mrope}")
return use_mrope
@property

View File

@ -893,7 +893,8 @@ class PyExecutor:
f'{len(scheduled_batch.generation_requests)} generation requests')
return scheduled_batch, iter_stats
def _execute_guided_decoder(self, scheduled_batch, logits):
def _execute_guided_decoder(self, scheduled_batch: ScheduledRequests,
logits: torch.Tensor):
if self.guided_decoder is not None:
self.guided_decoder.build(scheduled_batch)
self.guided_decoder.execute(scheduled_batch, logits)
@ -931,6 +932,9 @@ class PyExecutor:
self.resource_manager.prepare_resources(scheduled_batch)
if self.drafter is not None and self.use_spec_decode:
if self.guided_decoder is not None:
self.guided_decoder.rollback_rejected_tokens(
scheduled_batch)
self.drafter.prepare_draft_tokens(
scheduled_batch, self.resource_manager)
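For reference, the per-iteration ordering with a two-model drafter is sketched below. This is a simplified, hypothetical outline rather than the actual PyExecutor code; the objects are assumed to expose the interfaces shown in the diffs above.

def executor_iteration(batch, model_engine, sampler, drafter, guided_decoder):
    # 1. Drafting: first undo matcher advances for draft tokens rejected in the
    #    previous iteration, then run the drafting loop (which calls
    #    guided_decoder.build/execute per draft step and rollback_draft_tokens
    #    at the end).
    if drafter is not None:
        if guided_decoder is not None:
            guided_decoder.rollback_rejected_tokens(batch)
        drafter.prepare_draft_tokens(batch)

    # 2. Target forward: constrain the target logits with bitmask rows covering
    #    the new token plus every draft position.
    outputs = model_engine.forward(batch)
    if guided_decoder is not None:
        guided_decoder.build(batch)
        guided_decoder.execute(batch, outputs["logits"])

    # 3. Sampling/verification decides how many draft tokens are accepted; the
    #    rejected suffix is rolled back at the start of the next iteration.
    return sampler.sample_async(batch, outputs)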

View File

@ -33,6 +33,7 @@ from .py_executor import PyExecutor
class _ExecutorCreationStage(enum.Enum):
SAMPLER = "Sampler"
DRAFTER = "Drafter"
GUIDED_DECODER = "Guided decoder"
INIT_KV_CACHE = "Initial KV cache (temporary for KV cache size estimation)"
INIT_EXTRA_RESOURCES = "Additional executor resources (temporary for KV cache size estimation)"
MODEL_EXTRA = "Model resources created during usage"
@ -326,21 +327,28 @@ def create_py_executor(
else:
ctx_chunk_config = None
with mem_monitor.observe_creation_stage(
_ExecutorCreationStage.GUIDED_DECODER):
guided_decoder: Optional[GuidedDecoder] = None
if executor_config.guided_decoding_config is not None:
if spec_config is not None and not has_spec_drafter:
raise ValueError(
"Guided decoding is only supported with speculative decoding that has a dedicated drafter (two-model engine)."
)
if mapping.is_last_pp_rank():
max_num_draft_tokens = 0
if spec_config is not None:
max_num_draft_tokens = spec_config.max_draft_len
guided_decoder = GuidedDecoder(
executor_config.guided_decoding_config,
executor_config.max_batch_size,
model_engine.model.vocab_size_padded,
max_num_draft_tokens=max_num_draft_tokens)
with mem_monitor.observe_creation_stage(_ExecutorCreationStage.SAMPLER):
sampler = instantiate_sampler(model_engine, executor_config,
pytorch_backend_config, mapping)
guided_decoder: Optional[GuidedDecoder] = None
if executor_config.guided_decoding_config is not None:
if spec_config is not None:
raise ValueError(
"Guided decoding is not supported with speculative decoding.")
if mapping.is_last_pp_rank():
guided_decoder = GuidedDecoder(
executor_config.guided_decoding_config,
executor_config.max_batch_size,
model_engine.model.vocab_size_padded)
resources = {}
estimating_kv_cache = False
kv_cache_creator = None
@ -368,8 +376,11 @@ def create_py_executor(
# Drafter for speculative decoding
with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER):
drafter = get_spec_drafter(model_engine, draft_model_engine, sampler,
spec_resource_manager)
drafter = get_spec_drafter(model_engine,
draft_model_engine,
sampler,
spec_resource_manager=spec_resource_manager,
guided_decoder=guided_decoder)
with mem_monitor.observe_creation_stage(
_ExecutorCreationStage.INIT_EXTRA_RESOURCES

View File

@ -8,8 +8,9 @@ import torch
from tensorrt_llm._utils import nvtx_range
from tensorrt_llm.logger import logger
from ..pyexecutor.guided_decoder import GuidedDecoder
from ..pyexecutor.llm_request import (LlmRequest, LlmRequestState,
SamplingConfig, get_draft_token_length)
get_draft_token_length)
from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager
from ..pyexecutor.sampler import Sampler, SampleState, TorchSampler
from ..pyexecutor.scheduler import ScheduledRequests
@ -45,6 +46,7 @@ class ModelDrafter(Drafter):
draft_seq_slot_manager: SeqSlotManager,
sampler: Sampler,
spec_resource_manager: Optional[BaseResourceManager] = None,
guided_decoder: Optional[GuidedDecoder] = None,
):
# Validate required parameters
if draft_model_engine is None:
@ -65,17 +67,18 @@ class ModelDrafter(Drafter):
self._request_draft_logits = False
if isinstance(sampler, TorchSampler):
self._request_draft_logits = sampler.enable_mixed_sampler
self.guided_decoder = guided_decoder
def _create_draft_request(self, request_id: int, max_new_tokens: int,
input_tokens: Optional[List],
sampling_config: SamplingConfig,
return_perf_metrics: bool) -> LlmRequest:
def _create_draft_request(self, request: LlmRequest,
input_tokens: Optional[List]) -> LlmRequest:
"""Create a draft request with common parameters."""
return LlmRequest(request_id=request_id,
max_new_tokens=max_new_tokens,
input_tokens=input_tokens,
sampling_config=sampling_config,
return_perf_metrics=return_perf_metrics,
return LlmRequest(input_tokens=input_tokens,
request_id=request.py_request_id,
max_new_tokens=request.py_max_new_tokens,
sampling_config=request.sampling_config,
guided_decoding_params=request.guided_decoding_params,
target_seq_slot=request.py_seq_slot,
return_perf_metrics=request.return_perf_metrics,
is_streaming=False,
is_draft=True,
return_generation_logits=self._request_draft_logits)
@ -96,11 +99,7 @@ class ModelDrafter(Drafter):
def _create_context_request(self, request: LlmRequest,
input_tokens: Any) -> LlmRequest:
"""Create a context request for first-time drafting."""
new_request = self._create_draft_request(request.py_request_id,
request.py_max_new_tokens,
input_tokens,
request.sampling_config,
request.return_perf_metrics)
new_request = self._create_draft_request(request, input_tokens)
begin_compute, end_compute = request.py_last_context_chunk
if begin_compute is not None:
@ -111,13 +110,7 @@ class ModelDrafter(Drafter):
def _create_generation_request(self, request: LlmRequest,
input_tokens: Any) -> LlmRequest:
"""Create a generation request when no tokens were accepted."""
new_request = self._create_draft_request(request.py_request_id,
request.py_max_new_tokens,
input_tokens[:-1],
request.sampling_config,
request.return_perf_metrics)
# Explicitly add the last token so get_last_tokens() returns the right value
new_request.add_new_token(input_tokens[-1], 0)
new_request = self._create_draft_request(request, input_tokens)
new_request.state = LlmRequestState.GENERATION_IN_PROGRESS
return new_request
@ -128,11 +121,7 @@ class ModelDrafter(Drafter):
Create a chunked context request for accepted tokens.
Only applicable if the draft model needs to recompute KV cache for accepted tokens (e.g. eagle 3)
"""
new_request = self._create_draft_request(request.py_request_id,
request.py_max_new_tokens,
input_tokens,
request.sampling_config,
request.return_perf_metrics)
new_request = self._create_draft_request(request, input_tokens)
new_request.context_chunk_size = num_accepted_tokens + 1
new_request.context_current_position = len(
input_tokens) - num_accepted_tokens - 1
@ -144,7 +133,7 @@ class ModelDrafter(Drafter):
num_draft_tokens, num_accepted_tokens = self._initialize_draft_tokens(
request)
input_tokens = get_draft_model_prompt(self.spec_config.spec_dec_mode,
request.get_tokens()[0])
request.get_tokens(0))
# First time seeing this request - context request
if request.max_beam_num_tokens - 1 == request.py_prompt_len:
@ -206,7 +195,7 @@ class ModelDrafter(Drafter):
# We hit this path if we're doing chunked prefill. The target model processed
# a prefill chunk on the last iteration. Now, we need to fill in the KV cache
# for the draft model too.
all_tokens = request.get_tokens()[0]
all_tokens = request.get_tokens(0)
input_tokens = get_draft_model_prompt(
self.spec_config.spec_dec_mode, all_tokens)
@ -329,6 +318,14 @@ class ModelDrafter(Drafter):
req.py_draft_tokens.extend(
0 for _ in range(max_draft_tokens - num_draft_tokens))
def _execute_guided_decoder(self,
scheduled_batch: ScheduledRequests,
logits: torch.Tensor,
d2t: Optional[torch.Tensor] = None):
if self.guided_decoder is not None:
self.guided_decoder.build(scheduled_batch)
self.guided_decoder.execute(scheduled_batch, logits, d2t=d2t)
@nvtx_range("prepare_draft_tokens")
def prepare_draft_tokens(
self,
@ -363,6 +360,9 @@ class ModelDrafter(Drafter):
# Initial forward pass
outputs = self._forward_draft_model(draft_batch, resource_manager)
self._execute_guided_decoder(draft_batch,
outputs['logits'],
d2t=outputs.get('d2t'))
sample_state = self._sample_async(draft_batch, outputs)
previous_batch = sample_state
@ -380,10 +380,14 @@ class ModelDrafter(Drafter):
outputs = self._forward_draft_model(draft_batch,
resource_manager,
previous_batch)
if previous_batch is not None:
self._update_requests(previous_batch)
self._execute_guided_decoder(draft_batch,
outputs['logits'],
d2t=outputs.get('d2t'))
sample_state = self._sample_async(draft_batch, outputs)
self._update_request_states(draft_batch)
if previous_batch is not None:
self._update_requests(previous_batch)
new_requests = self._process_decoded_tokens(
previous_batch.scheduled_requests,
req_id_to_old_request)
@ -399,6 +403,9 @@ class ModelDrafter(Drafter):
req_id_to_old_request)
self._pad_to_max_draft_tokens(scheduled_requests)
if self.guided_decoder is not None:
self.guided_decoder.rollback_draft_tokens(scheduled_requests)
except Exception as e:
traceback.print_exc()
error_msg = str(e)
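A small sketch (hypothetical slot numbers) of why the draft request now carries target_seq_slot: draft requests get their own sequence slots from the drafter's SeqSlotManager, but both the draft and target requests must update the same grammar matcher, so GuidedDecoder resolves py_target_seq_slot for draft requests.

class Req:
    # Stand-in for the LlmRequest fields that GuidedDecoder reads.
    def __init__(self, seq_slot, is_draft=False, target_seq_slot=None):
        self.py_seq_slot = seq_slot
        self.py_is_draft = is_draft
        self.py_target_seq_slot = target_seq_slot

target_req = Req(seq_slot=5)
draft_req = Req(seq_slot=0, is_draft=True,
                target_seq_slot=target_req.py_seq_slot)

def resolve_slot(req):
    return req.py_target_seq_slot if req.py_is_draft else req.py_seq_slot

assert resolve_slot(target_req) == resolve_slot(draft_req) == 5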

View File

@ -1,11 +1,12 @@
from itertools import chain
from typing import Optional
from ordered_set import OrderedSet
from tensorrt_llm.llmapi import NGramDecodingConfig
from tensorrt_llm.logger import logger
from ..pyexecutor.llm_request import *
from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager
from ..pyexecutor.scheduler import ScheduledRequests
from .drafter import Drafter
@ -181,6 +182,7 @@ class NGramDrafter(Drafter):
if self.spec_config.is_auto_heuristic and len(
scheduled_requests.all_requests()) > 32:
return
# Sort by request_id when py_batch_idx is None as a fallback.
# This happens in the disagg case: for a set of new requests, we draft
# before forward_step, so py_batch_idx is not assigned.
@ -190,7 +192,7 @@ class NGramDrafter(Drafter):
(r.py_batch_idx is None, r.py_batch_idx or r.request_id),
):
# Add new token to a copy of the generated tokens to find new draft tokens
prefix = list(request.get_tokens()[0]) # Get a copy
prefix = list(request.get_tokens(0)) # Get a copy
# Generate draft tokens
draft_tokens = self.spec_resource_manager.get_draft_tokens(

View File

@ -1,7 +1,9 @@
from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler
from tensorrt_llm._torch.speculative.interface import SpecMetadata
from typing import Optional
from ..pyexecutor.guided_decoder import GuidedDecoder
from ..pyexecutor.sampler import TorchSampler
from ..pyexecutor.seq_slot_manager import SeqSlotManager
from ..speculative.interface import SpecMetadata
from .eagle3 import (Eagle3OneModelSampler, Eagle3OneModelSpecMetadata,
Eagle3OneModelWorker, Eagle3ResourceManager,
Eagle3SpecMetadata)
@ -114,8 +116,11 @@ def get_spec_decoder(sampler_args: TorchSampler.Args,
f"Unsupported speculative decoding mode: {spec_config.spec_dec_mode}")
def get_spec_drafter(model_engine, draft_model_engine, sampler,
spec_resource_manager):
def get_spec_drafter(model_engine,
draft_model_engine,
sampler,
spec_resource_manager,
guided_decoder: Optional[GuidedDecoder] = None):
spec_config = model_engine.spec_config
if spec_config is None:
return None
@ -126,10 +131,13 @@ def get_spec_drafter(model_engine, draft_model_engine, sampler,
max_num_requests = model_engine.batch_size
if spec_config.spec_dec_mode.is_draft_target(
) or spec_config.spec_dec_mode.is_eagle3():
return ModelDrafter(spec_config, draft_model_engine,
return ModelDrafter(spec_config,
draft_model_engine,
spec_config.max_draft_len,
SeqSlotManager(max_num_requests), sampler,
spec_resource_manager)
SeqSlotManager(max_num_requests),
sampler,
spec_resource_manager=spec_resource_manager,
guided_decoder=guided_decoder)
if spec_config.spec_dec_mode.is_ngram():
return NGramDrafter(spec_config, spec_resource_manager)

View File

@ -18,6 +18,7 @@ from typing import Iterable, List, Optional, Union
import click
import datasets
import jsonschema
import numpy as np
from .. import LLM as PyTorchLLM
@ -65,23 +66,30 @@ class JsonModeEval(Evaluator):
sampling_args = {
"guided_decoding": GuidedDecodingParams(json=schema)
}
yield sample["prompt"], sampling_args, sample["completion"]
yield sample["prompt"], sampling_args, sample["completion"], sample[
"schema"]
def compute_score(self, outputs: List[RequestOutput],
references: List[str]) -> float:
all_corrections = []
for output, ref in zip(outputs, references):
def compute_score(self, outputs: List[RequestOutput], references: List[str],
schemas: List[str]) -> float:
all_corrections, all_grammar_corrections = [], []
for output, ref, schema in zip(outputs, references, schemas):
try:
output_json = json.loads(output.outputs[0].text)
except json.JSONDecodeError:
jsonschema.validate(output_json, json.loads(schema))
except (json.JSONDecodeError, jsonschema.ValidationError):
all_corrections.append(False)
all_grammar_corrections.append(False)
continue
ref_json = json.loads(ref)
all_corrections.append(output_json == ref_json)
all_corrections.append(output_json == json.loads(ref))
all_grammar_corrections.append(True)
acc = np.mean(all_corrections) * 100
logger.info(
f"JSON Mode Eval accuracy: {acc:.2f} ({len(all_corrections)})")
grammar_acc = np.mean(all_grammar_corrections) * 100
logger.info(
f"JSON Mode Eval grammar accuracy: {grammar_acc:.2f} ({len(all_grammar_corrections)})"
)
return acc
@click.command("json_mode_eval")
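The new grammar-accuracy metric counts an output as correct when it parses as JSON and validates against the per-sample schema, even if it does not exactly match the reference completion. A tiny illustration with a made-up schema:

import json
import jsonschema

schema = {"type": "object",
          "properties": {"a": {"type": "integer"}},
          "required": ["a"]}
output_text = '{"a": 3}'
jsonschema.validate(json.loads(output_text), schema)  # no exception -> grammar-correct
# Exact-match accuracy would additionally require json.loads(output_text) == reference.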

View File

@ -1,2 +1,6 @@
meta-llama/Llama-3.1-8B-Instruct:
- accuracy: 74.00
- spec_dec_algo: Eagle
accuracy: 74.00
- spec_dec_algo: NGram
accuracy: 74.00

View File

@ -304,9 +304,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
@pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
def test_guided_decoding(self, backend: str, mocker):
mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
llm = LLM(self.MODEL_PATH,
guided_decoding_backend=backend,
cuda_graph_config=CudaGraphConfig())
llm = LLM(self.MODEL_PATH, guided_decoding_backend=backend)
with llm:
task = JsonModeEval(self.MODEL_NAME)
task.evaluate(llm)
@ -318,12 +316,46 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
with LLM(self.MODEL_PATH,
guided_decoding_backend=backend,
cuda_graph_config=CudaGraphConfig(),
tensor_parallel_size=2,
pipeline_parallel_size=2) as llm:
task = JsonModeEval(self.MODEL_NAME)
task.evaluate(llm)
@skip_pre_hopper
@pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
def test_guided_decoding_with_eagle3(self, backend: str, mocker):
mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
spec_config = EagleDecodingConfig(
max_draft_len=3,
speculative_model_dir=
f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
eagle3_one_model=False)
llm = LLM(self.MODEL_PATH,
guided_decoding_backend=backend,
kv_cache_config=kv_cache_config,
speculative_config=spec_config,
disable_overlap_scheduler=True)
with llm:
task = JsonModeEval(self.MODEL_NAME)
task.evaluate(llm)
@skip_pre_hopper
@pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
def test_guided_decoding_with_ngram(self, backend: str, mocker):
mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
spec_config = NGramDecodingConfig(max_draft_len=3,
max_matching_ngram_size=3)
llm = LLM(self.MODEL_PATH,
guided_decoding_backend=backend,
kv_cache_config=kv_cache_config,
speculative_config=spec_config,
disable_overlap_scheduler=True)
with llm:
task = JsonModeEval(self.MODEL_NAME)
task.evaluate(llm)
class TestLlama3_2_1B(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.2-1B"

View File

@ -31,6 +31,7 @@ l0_h100:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=TRTLLM] TIMEOUT (90)
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=False-attn_backend=TRTLLM-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
@ -209,6 +210,9 @@ l0_h100:
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[xgrammar]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[llguidance]
- test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True]
- test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True]
- condition: