[TRTLLM-10325][feat] Refactor speculative decoding workers (#10768)

Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com>
Guiju Zhang 2026-01-21 10:05:29 -08:00 committed by GitHub
parent f91ea37a13
commit 8cf8fbbe16
3 changed files with 229 additions and 124 deletions

View File

@@ -377,15 +377,14 @@ class Eagle3OneModelWorker(SpecWorkerBase):
raw_logits = logits
if self.guided_decoder is not None:
self.guided_decoder.execute(logits)
self._execute_guided_decoder_if_present(logits)
# Sample and accept tokens
accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens(
logits, attn_metadata, spec_metadata)
# Save the old attn_metadata and spec_metadata
attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda")
self._prepare_attn_metadata_for_spec_dec(attn_metadata)
# Prepare inputs for the 1st draft model forward
position_ids = position_ids.squeeze(0)
@@ -479,18 +478,15 @@ class Eagle3OneModelWorker(SpecWorkerBase):
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
# restore attn_metadata to support cuda graph
attn_metadata.restore_from_spec_dec()
attn_metadata.on_update()
self._restore_attn_metadata_from_spec_dec(attn_metadata)
# restore all_rank_num_tokens for attention DP
if original_all_rank_num_tokens is not None:
attn_metadata.all_rank_num_tokens = original_all_rank_num_tokens
# prepare next new tokens to support overlap scheduler
next_new_tokens = accepted_tokens[
spec_metadata.batch_indices_cuda[:batch_size],
num_accepted_tokens - 1].unsqueeze(1)
next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens],
dim=1)
next_new_tokens = self._prepare_next_new_tokens(
accepted_tokens, next_draft_tokens,
spec_metadata.batch_indices_cuda, batch_size, num_accepted_tokens)
attn_metadata.use_spec_decoding = True
@@ -512,39 +508,13 @@ class Eagle3OneModelWorker(SpecWorkerBase):
num_contexts = attn_metadata.num_contexts
num_gens = batch_size - num_contexts
if logits.dim() == 1:
logits = logits.unsqueeze(0)
# The return buffer
accepted_tokens = torch.empty((batch_size, (self.max_draft_len + 1)),
dtype=torch.int,
device=logits.device)
num_accepted_tokens = torch.ones(batch_size,
dtype=torch.int,
device=logits.device)
# Sample tokens using per-request sampling parameters
target_tokens = self._sample_tokens_for_batch(logits, spec_metadata,
num_contexts, batch_size)
# context
accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
# generation
gen_target_tokens = target_tokens[num_contexts:].reshape(
num_gens, self.max_draft_len + 1)
accepted_tokens[num_contexts:, :] = gen_target_tokens
# Reshape draft tokens for base implementation
draft_tokens = spec_metadata.draft_tokens.reshape(
num_gens, self.max_draft_len)
num_accepted_tokens[num_contexts:] += torch.cumprod(
(draft_tokens == gen_target_tokens[:, :self.max_draft_len]).int(),
dim=-1).sum(1)
# Check for environment variable override
if self.force_num_accepted_tokens != 0:
# total tokens per iteration = accepted draft tokens + 1 target token
force_total_tokens = min(self.force_num_accepted_tokens + 1,
self.max_draft_len + 1)
num_accepted_tokens[num_contexts:] = force_total_tokens
return accepted_tokens, num_accepted_tokens
# Use base implementation for strict acceptance
return self._sample_and_accept_draft_tokens_base(
logits, draft_tokens, num_contexts, batch_size, spec_metadata)
def draft_decoder(
self,
@@ -570,15 +540,8 @@ class Eagle3OneModelWorker(SpecWorkerBase):
# Note: using greedy for draft tokens is a bit easier to implement and
# faster. It doesn't affect the final output and seems to have a negligible
# impact on AR.
draft_tokens = torch.argmax(logits, dim=-1)
# Apply d2t (offsets between draft model dictionary and main model dictionary).
if (d2t := getattr(draft_model.model, "d2t", None)) is not None:
draft_tokens = d2t[draft_tokens] + draft_tokens
draft_tokens = draft_tokens.type(torch.int32)
return draft_tokens
d2t = getattr(draft_model.model, "d2t", None)
return self._draft_sampler_greedy(logits, d2t)
def prepare_1st_drafter_inputs(
self,
@@ -601,14 +564,9 @@ class Eagle3OneModelWorker(SpecWorkerBase):
hidden_states = draft_model.apply_eagle3_fc(hidden_states)
# context
input_ctx_ids = input_ids[:attn_metadata.num_ctx_tokens]
input_ids_ctx = torch.empty_like(input_ctx_ids,
dtype=torch.int32,
device="cuda")
input_ids_ctx[:-1].copy_(input_ctx_ids[1:])
input_ids_ctx[
spec_metadata.
gather_ids[:num_contexts]] = accepted_tokens[:num_contexts, 0]
input_ids_ctx = self._prepare_context_input_ids(
input_ids, attn_metadata.num_ctx_tokens, spec_metadata.gather_ids,
accepted_tokens, num_contexts)
# generation
input_ids_gen = accepted_tokens[num_contexts:, :].flatten()

View File

@@ -423,6 +423,177 @@ class SpecWorkerBase(nn.Module, ABC):
self.guided_decoder = guided_decoder
return True
def _prepare_attn_metadata_for_spec_dec(self, attn_metadata):
"""
Prepare attention metadata before speculative decoding draft token generation.
Saves current state for later restoration.
"""
attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda")
def _restore_attn_metadata_from_spec_dec(self, attn_metadata):
"""
Restore attention metadata after speculative decoding draft token generation.
"""
attn_metadata.restore_from_spec_dec()
attn_metadata.on_update()
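The snapshot/restore pair above is meant to bracket the draft-generation loop in each worker. A hypothetical context-manager wrapper, not part of this change (the `worker` argument and the loop body are assumed), makes the intended pairing explicit:

from contextlib import contextmanager

# Hypothetical convenience wrapper, not in this commit: guarantees that the
# metadata snapshot taken before drafting is restored even if drafting fails.
@contextmanager
def spec_dec_metadata(worker, attn_metadata):
    worker._prepare_attn_metadata_for_spec_dec(attn_metadata)        # snapshot _seq_lens
    try:
        yield attn_metadata
    finally:
        worker._restore_attn_metadata_from_spec_dec(attn_metadata)   # restore + on_update()

# Usage sketch inside a worker's forward pass:
#     with spec_dec_metadata(self, attn_metadata):
#         ...run the draft model and collect draft tokens...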
def _apply_force_accepted_tokens(self, num_accepted_tokens, num_contexts):
"""
Apply forced number of accepted tokens if environment variable is set.
This is used for testing and debugging.
Args:
num_accepted_tokens: Tensor of shape [batch_size] with current accepted counts
num_contexts: Number of context (prefill) requests
Returns:
Modified num_accepted_tokens tensor
Note:
For MTPWorker, self.max_draft_len equals num_nextn_predict_layers (mtp_num_modules).
For Eagle3OneModelWorker, self.max_draft_len equals spec_config.max_draft_len.
"""
if self.force_num_accepted_tokens != 0:
# total tokens per iteration = accepted draft tokens + 1 target token
force_total_tokens = min(self.force_num_accepted_tokens + 1,
self.max_draft_len + 1)
num_accepted_tokens[num_contexts:] = force_total_tokens
return num_accepted_tokens
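A toy sketch of the clamping arithmetic with made-up values: with max_draft_len = 3, a forced value of 5 is clamped to 4 total tokens (3 draft tokens plus the guaranteed target token), and only generation rows are overwritten:

import torch

# Illustrative values only; the names mirror the worker's fields.
max_draft_len, force_num_accepted_tokens, num_contexts = 3, 5, 1
num_accepted_tokens = torch.tensor([1, 2, 3])   # one context row + two generation rows
force_total_tokens = min(force_num_accepted_tokens + 1, max_draft_len + 1)
num_accepted_tokens[num_contexts:] = force_total_tokens
print(num_accepted_tokens)   # tensor([1, 4, 4]) -- context row untouched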
def _sample_and_accept_draft_tokens_base(
self,
logits: torch.Tensor,
draft_tokens: torch.Tensor,
num_contexts: int,
batch_size: int,
spec_metadata,
):
"""
Base implementation for sampling and accepting draft tokens.
Uses strict acceptance (token equality with cumulative product).
This is the common logic shared between Eagle3 and MTP (when relaxed
acceptance is disabled).
Args:
logits: [num_tokens, vocab_size] - Target model logits
draft_tokens: [num_gens, max_draft_len] - Previously predicted draft tokens
num_contexts: Number of context requests
batch_size: Total number of requests
spec_metadata: Speculative decoding metadata
Returns:
accepted_tokens: [batch_size, max_draft_len + 1] - Accepted tokens
num_accepted_tokens: [batch_size] - Number of accepted tokens per request
"""
num_gens = batch_size - num_contexts
if logits.dim() == 1:
logits = logits.unsqueeze(0)
# Allocate return buffers
accepted_tokens = torch.empty((batch_size, self.max_draft_len + 1),
dtype=torch.int,
device=logits.device)
num_accepted_tokens = torch.ones(batch_size,
dtype=torch.int,
device=logits.device)
# Sample tokens using per-request sampling parameters
target_tokens = self._sample_tokens_for_batch(logits, spec_metadata,
num_contexts, batch_size)
# Context requests: only accept the sampled token (no draft tokens yet)
accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
# Generation requests: verify draft tokens against target tokens
gen_target_tokens = target_tokens[num_contexts:].reshape(
num_gens, self.max_draft_len + 1)
accepted_tokens[num_contexts:, :] = gen_target_tokens
# Compare draft tokens with target tokens using cumulative product
# Counts consecutive matches from the start
num_accepted_tokens[num_contexts:] += torch.cumprod(
(draft_tokens == gen_target_tokens[:, :self.max_draft_len]).int(),
dim=-1).sum(1)
# Apply force override if set
num_accepted_tokens = self._apply_force_accepted_tokens(
num_accepted_tokens, num_contexts)
return accepted_tokens, num_accepted_tokens
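A standalone sketch of the strict-acceptance rule (hypothetical helper name and values): a draft token counts only if it matches the target token and every earlier draft matched, which is exactly what the cumulative product over the match mask computes:

import torch

def count_accepted(draft_tokens: torch.Tensor,
                   gen_target_tokens: torch.Tensor) -> torch.Tensor:
    # draft_tokens:      [num_gens, max_draft_len]
    # gen_target_tokens: [num_gens, max_draft_len + 1]
    matches = (draft_tokens == gen_target_tokens[:, :draft_tokens.shape[1]]).int()
    # cumprod zeroes out everything after the first mismatch; +1 for the
    # target token that is always accepted.
    return 1 + torch.cumprod(matches, dim=-1).sum(dim=1)

# Drafts [5, 7, 9] vs targets [5, 7, 2, 4]: the first two drafts match, so
# 1 target token + 2 accepted drafts = 3 tokens are kept.
print(count_accepted(torch.tensor([[5, 7, 9]]),
                     torch.tensor([[5, 7, 2, 4]])))   # tensor([3])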
def _draft_sampler_greedy(self, logits: torch.Tensor, d2t=None):
"""
Simple greedy draft token sampling using argmax.
Args:
logits: [num_tokens, vocab_size] - Draft model logits
d2t: Optional dictionary offset tensor for vocab mapping
Returns:
draft_tokens: [num_tokens] - Sampled draft token ids (int32)
"""
draft_tokens = torch.argmax(logits, dim=-1)
# Apply d2t (offsets between draft and target model dictionaries)
if d2t is not None:
draft_tokens = d2t[draft_tokens] + draft_tokens
return draft_tokens.type(torch.int32)
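A minimal sketch of the d2t lookup with made-up values, assuming d2t[i] stores the offset that maps draft-vocab id i to its target-vocab id:

import torch

def map_draft_to_target(logits: torch.Tensor, d2t: torch.Tensor) -> torch.Tensor:
    draft_tokens = torch.argmax(logits, dim=-1)              # greedy over the draft vocab
    return (d2t[draft_tokens] + draft_tokens).to(torch.int32)

# Draft id 2 wins the argmax and maps to target id 2 + d2t[2] = 12.
d2t = torch.tensor([0, 0, 10, 10])
logits = torch.tensor([[0.1, 0.2, 0.9, 0.3]])
print(map_draft_to_target(logits, d2t))   # tensor([12], dtype=torch.int32)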
def _execute_guided_decoder_if_present(self, logits):
"""Execute guided decoder on target model logits if available."""
if self.guided_decoder is not None:
self.guided_decoder.execute(logits)
def _prepare_next_new_tokens(self, accepted_tokens, next_draft_tokens,
batch_indices_cuda, batch_size,
num_accepted_tokens):
"""
Prepare next_new_tokens for overlap scheduler support.
Args:
accepted_tokens: [batch_size, max_draft_len + 1] - Accepted tokens
next_draft_tokens: [batch_size, max_draft_len] - Predicted draft tokens
batch_indices_cuda: Batch indices tensor
batch_size: Number of requests
num_accepted_tokens: [batch_size] - Number of accepted tokens per request
Returns:
next_new_tokens: [batch_size, max_draft_len + 1] - Input tokens for next iteration
"""
next_new_tokens = accepted_tokens[batch_indices_cuda[:batch_size],
num_accepted_tokens - 1].unsqueeze(1)
next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens],
dim=1)
return next_new_tokens
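An illustrative example with made-up tensors: for each request, take its last accepted token and prepend it to the freshly drafted tokens so the next iteration starts from the right position:

import torch

accepted_tokens = torch.tensor([[11, 12, 13],    # request 0 accepted 2 tokens
                                [21, 22, 23]])   # request 1 accepted 1 token
num_accepted_tokens = torch.tensor([2, 1])
next_draft_tokens = torch.tensor([[101, 102],
                                  [201, 202]])
batch_indices = torch.arange(accepted_tokens.shape[0])

last_accepted = accepted_tokens[batch_indices, num_accepted_tokens - 1].unsqueeze(1)
next_new_tokens = torch.cat([last_accepted, next_draft_tokens], dim=1)
print(next_new_tokens)
# tensor([[ 12, 101, 102],
#         [ 21, 201, 202]])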
def _prepare_context_input_ids(self, input_ids, num_ctx_tokens, gather_ids,
accepted_tokens, num_contexts):
"""
Prepare context input IDs for draft model forward.
Shifts input IDs left by 1 and places the first accepted token at gather positions.
Args:
input_ids: Original input IDs tensor
num_ctx_tokens: Number of context tokens
gather_ids: Indices for placing accepted tokens (last token positions)
accepted_tokens: [batch_size, max_draft_len + 1] - Accepted tokens
num_contexts: Number of context requests
Returns:
input_ids_ctx: Prepared context input IDs
"""
input_prompt_ids = input_ids[:num_ctx_tokens]
input_ids_ctx = torch.empty_like(input_prompt_ids,
dtype=torch.int32,
device="cuda")
input_ids_ctx[:-1].copy_(input_prompt_ids[1:])
input_ids_ctx[
gather_ids[:num_contexts]] = accepted_tokens[:num_contexts, 0]
return input_ids_ctx
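A CPU-only sketch of the shift-and-scatter with made-up shapes: two packed context requests of lengths 3 and 2, whose last-token positions are given by gather_ids. Every position receives the next prompt token, and each request's last position receives the newly accepted token:

import torch

input_ids = torch.tensor([4, 5, 6, 7, 8], dtype=torch.int32)   # [t0 t1 t2 | t0 t1]
gather_ids = torch.tensor([2, 4])               # last-token index of each request
accepted_tokens = torch.tensor([[60, 0],        # column 0 holds the sampled token
                                [80, 0]])

input_ids_ctx = torch.empty_like(input_ids)
input_ids_ctx[:-1].copy_(input_ids[1:])                           # shift left by one
input_ids_ctx[gather_ids] = accepted_tokens[:, 0].to(torch.int32)
print(input_ids_ctx)   # tensor([ 5,  6, 60,  8, 80], dtype=torch.int32)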
def _sample_tokens_for_batch(
self,
logits: torch.Tensor,

View File

@@ -473,8 +473,7 @@ class MTPWorker(SpecWorkerBase):
raw_logits = logits
if self.guided_decoder is not None:
self.guided_decoder.execute(logits)
self._execute_guided_decoder_if_present(logits)
# Sample and verify draft tokens
accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens(
@@ -541,14 +540,12 @@ class MTPWorker(SpecWorkerBase):
# restore attn metadata
if attn_metadata is not None:
self.restore_attn_metadata(attn_metadata=attn_metadata)
self._restore_attn_metadata_from_spec_dec(attn_metadata)
# prepare next new tokens to support overlap scheduler
next_new_tokens = accepted_tokens[
spec_metadata.batch_indices_cuda[:batch_size],
num_accepted_tokens - 1].unsqueeze(1)
next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens],
dim=1)
next_new_tokens = self._prepare_next_new_tokens(
accepted_tokens, next_draft_tokens,
spec_metadata.batch_indices_cuda, batch_size, num_accepted_tokens)
return {
'logits': raw_logits,
@@ -845,6 +842,10 @@ class MTPWorker(SpecWorkerBase):
self.spec_config.begin_thinking_phase_token,
self.spec_config.end_thinking_phase_token)
# Apply force override for relaxed acceptance path
num_accepted_tokens = self._apply_force_accepted_tokens(
num_accepted_tokens, num_contexts)
# Strict acceptance
else:
if self.is_thop:
@@ -856,36 +857,25 @@ class MTPWorker(SpecWorkerBase):
accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_sampling_and_accepted_draft_tokens_op(
logits, spec_metadata.draft_tokens, target_tokens_cache,
mtp_num_modules, batch_size, num_contexts, logits.shape[-1])
# Apply force override for THOP path
num_accepted_tokens = self._apply_force_accepted_tokens(
num_accepted_tokens, num_contexts)
else:
target_tokens = self._sample_tokens_for_batch(
logits, spec_metadata, num_contexts, batch_size)
# context
accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
# generation
gen_target_tokens = target_tokens[num_contexts:].reshape(
num_gens, mtp_num_modules + 1)
accepted_tokens[num_contexts:, :] = gen_target_tokens
# Reshape draft tokens for base implementation
draft_tokens = spec_metadata.draft_tokens.reshape(
num_gens, mtp_num_modules)
num_accepted_tokens[num_contexts:] += torch.cumprod(
(draft_tokens == gen_target_tokens[:, :mtp_num_modules]
).int(),
dim=-1).sum(1)
# Check for environment variable override
if self.force_num_accepted_tokens != 0:
# total tokens per iteration = accepted draft tokens + 1 target token
force_total_tokens = min(self.force_num_accepted_tokens + 1,
mtp_num_modules + 1)
num_accepted_tokens[num_contexts:] = force_total_tokens
# Use base implementation for strict acceptance
accepted_tokens, num_accepted_tokens = self._sample_and_accept_draft_tokens_base(
logits, draft_tokens, num_contexts, batch_size,
spec_metadata)
return accepted_tokens, num_accepted_tokens
def change_attn_metadata(self, num_accepted_tokens: torch.Tensor,
attn_metadata: AttentionMetadata):
attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda")
self._prepare_attn_metadata_for_spec_dec(attn_metadata)
batch_size = attn_metadata.num_seqs
mtp_num_modules = self.spec_config.num_nextn_predict_layers
@@ -908,10 +898,6 @@ class MTPWorker(SpecWorkerBase):
attn_metadata.kv_cache_params.num_cached_tokens_per_seq[
i] -= mtp_num_modules + 1 - num_accepted_tokens[i].item()
def restore_attn_metadata(self, attn_metadata: AttentionMetadata):
attn_metadata.restore_from_spec_dec()
attn_metadata.on_update()
def prepare_drafter_inputs(
self,
input_ids: torch.IntTensor,
@@ -1015,13 +1001,9 @@ class MTPWorker(SpecWorkerBase):
# context
if num_contexts > 0:
hidden_states_ctx = hidden_states[:num_ctx_tokens, :]
input_prompt_ids = input_ids[:num_ctx_tokens]
input_ids_ctx = torch.empty_like(input_prompt_ids,
dtype=torch.int32,
device="cuda")
input_ids_ctx[:-1].copy_(input_prompt_ids[1:])
input_ids_ctx[last_tokens_idx[:num_contexts]] = \
accepted_tokens[:num_contexts, 0]
input_ids_ctx = self._prepare_context_input_ids(
input_ids, num_ctx_tokens, last_tokens_idx, accepted_tokens,
num_contexts)
return_input_ids_list.append(input_ids_ctx)
return_hidden_states_list.append(hidden_states_ctx)
# generation
@@ -1137,7 +1119,7 @@ class MTPWorker(SpecWorkerBase):
draft_tokens = self.get_draft_tokens_from_gathered(sliced_gathered)
else:
# Simple argmax if no TP or no model config
draft_tokens = torch.argmax(logits, dim=-1).type(torch.int32)
draft_tokens = self._draft_sampler_greedy(logits)
return draft_tokens
@@ -1176,15 +1158,14 @@ class MTPEagleWorker(MTPWorker):
raw_logits = logits
if self.guided_decoder is not None:
self.guided_decoder.execute(logits)
self._execute_guided_decoder_if_present(logits)
# Sample and verify draft tokens
accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens(
input_ids, logits, spec_metadata, attn_metadata)
# Save the old attn_metadata and spec_metadata
attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda")
self._prepare_attn_metadata_for_spec_dec(attn_metadata)
# Prepare inputs for the 1st MTP layer
@torch.compile(options={"max-autotune": True})
@@ -1323,22 +1304,9 @@ class MTPEagleWorker(MTPWorker):
}
# restore attn_metadata to support cuda graph
attn_metadata.restore_from_spec_dec()
attn_metadata.on_update()
self._restore_attn_metadata_from_spec_dec(attn_metadata)
@torch.compile(options={"max-autotune": True})
def prepare_next_tokens(next_draft_tokens, accepted_tokens,
spec_metadata, batch_size, num_accepted_tokens):
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
# prepare next new tokens to support overlap scheduler
next_new_tokens = accepted_tokens[
spec_metadata.batch_indices_cuda[:batch_size],
num_accepted_tokens - 1].unsqueeze(1)
next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens],
dim=1)
return next_draft_tokens, next_new_tokens
next_draft_tokens, next_new_tokens = prepare_next_tokens(
next_draft_tokens, next_new_tokens = self._prepare_next_tokens(
next_draft_tokens, accepted_tokens, spec_metadata, batch_size,
num_accepted_tokens)
@@ -1350,6 +1318,18 @@ class MTPEagleWorker(MTPWorker):
'next_new_tokens': next_new_tokens
}
@torch.compile(options={"max-autotune": True})
def _prepare_next_tokens(self, next_draft_tokens, accepted_tokens,
spec_metadata, batch_size, num_accepted_tokens):
"""
Stack draft tokens and prepare next_new_tokens for overlap scheduler.
"""
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
next_new_tokens = self._prepare_next_new_tokens(
accepted_tokens, next_draft_tokens,
spec_metadata.batch_indices_cuda, batch_size, num_accepted_tokens)
return next_draft_tokens, next_new_tokens
@torch.compile(options={"max-autotune": True})
def prepare_drafter_inputs(
self,
@@ -1364,13 +1344,9 @@ class MTPEagleWorker(MTPWorker):
num_contexts = attn_metadata.num_contexts
# context
input_prompt_ids = input_ids[:attn_metadata.num_ctx_tokens]
input_ids_ctx = torch.empty_like(input_prompt_ids,
dtype=torch.int32,
device="cuda")
input_ids_ctx[:-1].copy_(input_prompt_ids[1:])
input_ids_ctx[
last_tokens_idx[:num_contexts]] = accepted_tokens[:num_contexts, 0]
input_ids_ctx = self._prepare_context_input_ids(
input_ids, attn_metadata.num_ctx_tokens, last_tokens_idx,
accepted_tokens, num_contexts)
# generation
input_ids_gen = accepted_tokens[num_contexts:, :].flatten()