From 8cf8fbbe1619442a174f6e7dddfd0bb3a381d8b5 Mon Sep 17 00:00:00 2001 From: Guiju Zhang <7135567+cascade812@users.noreply.github.com> Date: Wed, 21 Jan 2026 10:05:29 -0800 Subject: [PATCH] [TRTLLM-10325][feat] Refactor speculative decoding workers (#10768) Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com> --- tensorrt_llm/_torch/speculative/eagle3.py | 74 ++------ tensorrt_llm/_torch/speculative/interface.py | 171 +++++++++++++++++++ tensorrt_llm/_torch/speculative/mtp.py | 108 +++++------- 3 files changed, 229 insertions(+), 124 deletions(-) diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 6fa7fba858..fab4fb7b65 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -377,15 +377,14 @@ class Eagle3OneModelWorker(SpecWorkerBase): raw_logits = logits - if self.guided_decoder is not None: - self.guided_decoder.execute(logits) + self._execute_guided_decoder_if_present(logits) # Sample and accept tokens accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens( logits, attn_metadata, spec_metadata) # Save the old attn_metadata and spec_metadata - attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda") + self._prepare_attn_metadata_for_spec_dec(attn_metadata) # Prepare inputs for the 1st draft model forward position_ids = position_ids.squeeze(0) @@ -479,18 +478,15 @@ class Eagle3OneModelWorker(SpecWorkerBase): next_draft_tokens = torch.stack(next_draft_tokens, dim=1) # restore attn_metadata to support cuda graph - attn_metadata.restore_from_spec_dec() - attn_metadata.on_update() + self._restore_attn_metadata_from_spec_dec(attn_metadata) # restore all_rank_num_tokens for attention DP if original_all_rank_num_tokens is not None: attn_metadata.all_rank_num_tokens = original_all_rank_num_tokens # prepare next new tokens to support overlap scheduler - next_new_tokens = accepted_tokens[ - spec_metadata.batch_indices_cuda[:batch_size], - num_accepted_tokens - 1].unsqueeze(1) - next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens], - dim=1) + next_new_tokens = self._prepare_next_new_tokens( + accepted_tokens, next_draft_tokens, + spec_metadata.batch_indices_cuda, batch_size, num_accepted_tokens) attn_metadata.use_spec_decoding = True @@ -512,39 +508,13 @@ class Eagle3OneModelWorker(SpecWorkerBase): num_contexts = attn_metadata.num_contexts num_gens = batch_size - num_contexts - if logits.dim() == 1: - logits = logits.unsqueeze(0) - - # The return buffer - accepted_tokens = torch.empty((batch_size, (self.max_draft_len + 1)), - dtype=torch.int, - device=logits.device) - num_accepted_tokens = torch.ones(batch_size, - dtype=torch.int, - device=logits.device) - - # Sample tokens using per-request sampling parameters - target_tokens = self._sample_tokens_for_batch(logits, spec_metadata, - num_contexts, batch_size) - # context - accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts] - - # generation - gen_target_tokens = target_tokens[num_contexts:].reshape( - num_gens, self.max_draft_len + 1) - accepted_tokens[num_contexts:, :] = gen_target_tokens + # Reshape draft tokens for base implementation draft_tokens = spec_metadata.draft_tokens.reshape( num_gens, self.max_draft_len) - num_accepted_tokens[num_contexts:] += torch.cumprod( - (draft_tokens == gen_target_tokens[:, :self.max_draft_len]).int(), - dim=-1).sum(1) - # Check for environment variable override - if self.force_num_accepted_tokens != 0: - # total tokens per iteration = 
accepted draft tokens + 1 target token - force_total_tokens = min(self.force_num_accepted_tokens + 1, - self.max_draft_len + 1) - num_accepted_tokens[num_contexts:] = force_total_tokens - return accepted_tokens, num_accepted_tokens + + # Use base implementation for strict acceptance + return self._sample_and_accept_draft_tokens_base( + logits, draft_tokens, num_contexts, batch_size, spec_metadata) def draft_decoder( self, @@ -570,15 +540,8 @@ class Eagle3OneModelWorker(SpecWorkerBase): # Note: using greedy for draft tokens is a bit easier to implement and # faster. It doesn't affect the final output and seems to have a negligible # impact on AR. - draft_tokens = torch.argmax(logits, dim=-1) - - # Apply d2t (offsets between draft model dictionary and main model dictionary). - if (d2t := getattr(draft_model.model, "d2t", None)) is not None: - draft_tokens = d2t[draft_tokens] + draft_tokens - - draft_tokens = draft_tokens.type(torch.int32) - - return draft_tokens + d2t = getattr(draft_model.model, "d2t", None) + return self._draft_sampler_greedy(logits, d2t) def prepare_1st_drafter_inputs( self, @@ -601,14 +564,9 @@ class Eagle3OneModelWorker(SpecWorkerBase): hidden_states = draft_model.apply_eagle3_fc(hidden_states) # context - input_ctx_ids = input_ids[:attn_metadata.num_ctx_tokens] - input_ids_ctx = torch.empty_like(input_ctx_ids, - dtype=torch.int32, - device="cuda") - input_ids_ctx[:-1].copy_(input_ctx_ids[1:]) - input_ids_ctx[ - spec_metadata. - gather_ids[:num_contexts]] = accepted_tokens[:num_contexts, 0] + input_ids_ctx = self._prepare_context_input_ids( + input_ids, attn_metadata.num_ctx_tokens, spec_metadata.gather_ids, + accepted_tokens, num_contexts) # generation input_ids_gen = accepted_tokens[num_contexts:, :].flatten() diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index f4908c8b79..6c52fa813e 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -423,6 +423,177 @@ class SpecWorkerBase(nn.Module, ABC): self.guided_decoder = guided_decoder return True + def _prepare_attn_metadata_for_spec_dec(self, attn_metadata): + """ + Prepare attention metadata before speculative decoding draft token generation. + Saves current state for later restoration. + """ + attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda") + + def _restore_attn_metadata_from_spec_dec(self, attn_metadata): + """ + Restore attention metadata after speculative decoding draft token generation. + """ + attn_metadata.restore_from_spec_dec() + attn_metadata.on_update() + + def _apply_force_accepted_tokens(self, num_accepted_tokens, num_contexts): + """ + Apply forced number of accepted tokens if environment variable is set. + This is used for testing and debugging. + + Args: + num_accepted_tokens: Tensor of shape [batch_size] with current accepted counts + num_contexts: Number of context (prefill) requests + + Returns: + Modified num_accepted_tokens tensor + + Note: + For MTPWorker, self.max_draft_len equals num_nextn_predict_layers (mtp_num_modules). + For Eagle3OneModelWorker, self.max_draft_len equals spec_config.max_draft_len. 
+ """ + if self.force_num_accepted_tokens != 0: + # total tokens per iteration = accepted draft tokens + 1 target token + force_total_tokens = min(self.force_num_accepted_tokens + 1, + self.max_draft_len + 1) + num_accepted_tokens[num_contexts:] = force_total_tokens + return num_accepted_tokens + + def _sample_and_accept_draft_tokens_base( + self, + logits: torch.Tensor, + draft_tokens: torch.Tensor, + num_contexts: int, + batch_size: int, + spec_metadata, + ): + """ + Base implementation for sampling and accepting draft tokens. + Uses strict acceptance (token equality with cumulative product). + + This is the common logic shared between Eagle3 and MTP (when relaxed + acceptance is disabled). + + Args: + logits: [num_tokens, vocab_size] - Target model logits + draft_tokens: [num_gens, max_draft_len] - Previously predicted draft tokens + num_contexts: Number of context requests + batch_size: Total number of requests + spec_metadata: Speculative decoding metadata + + Returns: + accepted_tokens: [batch_size, max_draft_len + 1] - Accepted tokens + num_accepted_tokens: [batch_size] - Number of accepted tokens per request + """ + num_gens = batch_size - num_contexts + + if logits.dim() == 1: + logits = logits.unsqueeze(0) + + # Allocate return buffers + accepted_tokens = torch.empty((batch_size, self.max_draft_len + 1), + dtype=torch.int, + device=logits.device) + num_accepted_tokens = torch.ones(batch_size, + dtype=torch.int, + device=logits.device) + + # Sample tokens using per-request sampling parameters + target_tokens = self._sample_tokens_for_batch(logits, spec_metadata, + num_contexts, batch_size) + + # Context requests: only accept the sampled token (no draft tokens yet) + accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts] + + # Generation requests: verify draft tokens against target tokens + gen_target_tokens = target_tokens[num_contexts:].reshape( + num_gens, self.max_draft_len + 1) + accepted_tokens[num_contexts:, :] = gen_target_tokens + + # Compare draft tokens with target tokens using cumulative product + # Counts consecutive matches from the start + num_accepted_tokens[num_contexts:] += torch.cumprod( + (draft_tokens == gen_target_tokens[:, :self.max_draft_len]).int(), + dim=-1).sum(1) + + # Apply force override if set + num_accepted_tokens = self._apply_force_accepted_tokens( + num_accepted_tokens, num_contexts) + + return accepted_tokens, num_accepted_tokens + + def _draft_sampler_greedy(self, logits: torch.Tensor, d2t=None): + """ + Simple greedy draft token sampling using argmax. + + Args: + logits: [num_tokens, vocab_size] - Draft model logits + d2t: Optional dictionary offset tensor for vocab mapping + + Returns: + draft_tokens: [num_tokens] - Sampled draft token ids (int32) + """ + draft_tokens = torch.argmax(logits, dim=-1) + + # Apply d2t (offsets between draft and target model dictionaries) + if d2t is not None: + draft_tokens = d2t[draft_tokens] + draft_tokens + + return draft_tokens.type(torch.int32) + + def _execute_guided_decoder_if_present(self, logits): + """Execute guided decoder on target model logits if available.""" + if self.guided_decoder is not None: + self.guided_decoder.execute(logits) + + def _prepare_next_new_tokens(self, accepted_tokens, next_draft_tokens, + batch_indices_cuda, batch_size, + num_accepted_tokens): + """ + Prepare next_new_tokens for overlap scheduler support. 
+ + Args: + accepted_tokens: [batch_size, max_draft_len + 1] - Accepted tokens + next_draft_tokens: [batch_size, max_draft_len] - Predicted draft tokens + batch_indices_cuda: Batch indices tensor + batch_size: Number of requests + num_accepted_tokens: [batch_size] - Number of accepted tokens per request + + Returns: + next_new_tokens: [batch_size, max_draft_len + 1] - Input tokens for next iteration + """ + next_new_tokens = accepted_tokens[batch_indices_cuda[:batch_size], + num_accepted_tokens - 1].unsqueeze(1) + next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens], + dim=1) + return next_new_tokens + + def _prepare_context_input_ids(self, input_ids, num_ctx_tokens, gather_ids, + accepted_tokens, num_contexts): + """ + Prepare context input IDs for draft model forward. + Shifts input IDs left by 1 and places the first accepted token at gather positions. + + Args: + input_ids: Original input IDs tensor + num_ctx_tokens: Number of context tokens + gather_ids: Indices for placing accepted tokens (last token positions) + accepted_tokens: [batch_size, max_draft_len + 1] - Accepted tokens + num_contexts: Number of context requests + + Returns: + input_ids_ctx: Prepared context input IDs + """ + input_prompt_ids = input_ids[:num_ctx_tokens] + input_ids_ctx = torch.empty_like(input_prompt_ids, + dtype=torch.int32, + device="cuda") + input_ids_ctx[:-1].copy_(input_prompt_ids[1:]) + input_ids_ctx[ + gather_ids[:num_contexts]] = accepted_tokens[:num_contexts, 0] + return input_ids_ctx + def _sample_tokens_for_batch( self, logits: torch.Tensor, diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index 9e7313c1e3..f2388b9851 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -473,8 +473,7 @@ class MTPWorker(SpecWorkerBase): raw_logits = logits - if self.guided_decoder is not None: - self.guided_decoder.execute(logits) + self._execute_guided_decoder_if_present(logits) # Sample and verify draft tokens accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens( @@ -541,14 +540,12 @@ class MTPWorker(SpecWorkerBase): # restore attn metadata if attn_metadata is not None: - self.restore_attn_metadata(attn_metadata=attn_metadata) + self._restore_attn_metadata_from_spec_dec(attn_metadata) # prepare next new tokens to support overlap scheduler - next_new_tokens = accepted_tokens[ - spec_metadata.batch_indices_cuda[:batch_size], - num_accepted_tokens - 1].unsqueeze(1) - next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens], - dim=1) + next_new_tokens = self._prepare_next_new_tokens( + accepted_tokens, next_draft_tokens, + spec_metadata.batch_indices_cuda, batch_size, num_accepted_tokens) return { 'logits': raw_logits, @@ -845,6 +842,10 @@ class MTPWorker(SpecWorkerBase): self.spec_config.begin_thinking_phase_token, self.spec_config.end_thinking_phase_token) + # Apply force override for relaxed acceptance path + num_accepted_tokens = self._apply_force_accepted_tokens( + num_accepted_tokens, num_contexts) + # Strict acceptance else: if self.is_thop: @@ -856,36 +857,25 @@ class MTPWorker(SpecWorkerBase): accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_sampling_and_accepted_draft_tokens_op( logits, spec_metadata.draft_tokens, target_tokens_cache, mtp_num_modules, batch_size, num_contexts, logits.shape[-1]) + + # Apply force override for THOP path + num_accepted_tokens = self._apply_force_accepted_tokens( + num_accepted_tokens, num_contexts) else: - target_tokens = 
self._sample_tokens_for_batch( - logits, spec_metadata, num_contexts, batch_size) - - # context - accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts] - - # generation - gen_target_tokens = target_tokens[num_contexts:].reshape( - num_gens, mtp_num_modules + 1) - accepted_tokens[num_contexts:, :] = gen_target_tokens + # Reshape draft tokens for base implementation draft_tokens = spec_metadata.draft_tokens.reshape( num_gens, mtp_num_modules) - num_accepted_tokens[num_contexts:] += torch.cumprod( - (draft_tokens == gen_target_tokens[:, :mtp_num_modules] - ).int(), - dim=-1).sum(1) - # Check for environment variable override - if self.force_num_accepted_tokens != 0: - # total tokens per iteration = accepted draft tokens + 1 target token - force_total_tokens = min(self.force_num_accepted_tokens + 1, - mtp_num_modules + 1) - num_accepted_tokens[num_contexts:] = force_total_tokens + # Use base implementation for strict acceptance + accepted_tokens, num_accepted_tokens = self._sample_and_accept_draft_tokens_base( + logits, draft_tokens, num_contexts, batch_size, + spec_metadata) return accepted_tokens, num_accepted_tokens def change_attn_metadata(self, num_accepted_tokens: torch.Tensor, attn_metadata: AttentionMetadata): - attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda") + self._prepare_attn_metadata_for_spec_dec(attn_metadata) batch_size = attn_metadata.num_seqs mtp_num_modules = self.spec_config.num_nextn_predict_layers @@ -908,10 +898,6 @@ class MTPWorker(SpecWorkerBase): attn_metadata.kv_cache_params.num_cached_tokens_per_seq[ i] -= mtp_num_modules + 1 - num_accepted_tokens[i].item() - def restore_attn_metadata(self, attn_metadata: AttentionMetadata): - attn_metadata.restore_from_spec_dec() - attn_metadata.on_update() - def prepare_drafter_inputs( self, input_ids: torch.IntTensor, @@ -1015,13 +1001,9 @@ class MTPWorker(SpecWorkerBase): # context if num_contexts > 0: hidden_states_ctx = hidden_states[:num_ctx_tokens, :] - input_prompt_ids = input_ids[:num_ctx_tokens] - input_ids_ctx = torch.empty_like(input_prompt_ids, - dtype=torch.int32, - device="cuda") - input_ids_ctx[:-1].copy_(input_prompt_ids[1:]) - input_ids_ctx[last_tokens_idx[:num_contexts]] = \ - accepted_tokens[:num_contexts, 0] + input_ids_ctx = self._prepare_context_input_ids( + input_ids, num_ctx_tokens, last_tokens_idx, accepted_tokens, + num_contexts) return_input_ids_list.append(input_ids_ctx) return_hidden_states_list.append(hidden_states_ctx) # generation @@ -1137,7 +1119,7 @@ class MTPWorker(SpecWorkerBase): draft_tokens = self.get_draft_tokens_from_gathered(sliced_gathered) else: # Simple argmax if no TP or no model config - draft_tokens = torch.argmax(logits, dim=-1).type(torch.int32) + draft_tokens = self._draft_sampler_greedy(logits) return draft_tokens @@ -1176,15 +1158,14 @@ class MTPEagleWorker(MTPWorker): raw_logits = logits - if self.guided_decoder is not None: - self.guided_decoder.execute(logits) + self._execute_guided_decoder_if_present(logits) # Sample and verify draft tokens accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens( input_ids, logits, spec_metadata, attn_metadata) # Save the old attn_metadata and spec_metadata - attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda") + self._prepare_attn_metadata_for_spec_dec(attn_metadata) # Prepare inputs for the 1st MTP layer @torch.compile(options={"max-autotune": True}) @@ -1323,22 +1304,9 @@ class MTPEagleWorker(MTPWorker): } # restore attn_metadata to support cuda graph - 
attn_metadata.restore_from_spec_dec() - attn_metadata.on_update() + self._restore_attn_metadata_from_spec_dec(attn_metadata) - @torch.compile(options={"max-autotune": True}) - def prepare_next_tokens(next_draft_tokens, accepted_tokens, - spec_metadata, batch_size, num_accepted_tokens): - next_draft_tokens = torch.stack(next_draft_tokens, dim=1) - # prepare next new tokens to support overlap scheduler - next_new_tokens = accepted_tokens[ - spec_metadata.batch_indices_cuda[:batch_size], - num_accepted_tokens - 1].unsqueeze(1) - next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens], - dim=1) - return next_draft_tokens, next_new_tokens - - next_draft_tokens, next_new_tokens = prepare_next_tokens( + next_draft_tokens, next_new_tokens = self._prepare_next_tokens( next_draft_tokens, accepted_tokens, spec_metadata, batch_size, num_accepted_tokens) @@ -1350,6 +1318,18 @@ class MTPEagleWorker(MTPWorker): 'next_new_tokens': next_new_tokens } + @torch.compile(options={"max-autotune": True}) + def _prepare_next_tokens(self, next_draft_tokens, accepted_tokens, + spec_metadata, batch_size, num_accepted_tokens): + """ + Stack draft tokens and prepare next_new_tokens for overlap scheduler. + """ + next_draft_tokens = torch.stack(next_draft_tokens, dim=1) + next_new_tokens = self._prepare_next_new_tokens( + accepted_tokens, next_draft_tokens, + spec_metadata.batch_indices_cuda, batch_size, num_accepted_tokens) + return next_draft_tokens, next_new_tokens + @torch.compile(options={"max-autotune": True}) def prepare_drafter_inputs( self, @@ -1364,13 +1344,9 @@ class MTPEagleWorker(MTPWorker): num_contexts = attn_metadata.num_contexts # context - input_prompt_ids = input_ids[:attn_metadata.num_ctx_tokens] - input_ids_ctx = torch.empty_like(input_prompt_ids, - dtype=torch.int32, - device="cuda") - input_ids_ctx[:-1].copy_(input_prompt_ids[1:]) - input_ids_ctx[ - last_tokens_idx[:num_contexts]] = accepted_tokens[:num_contexts, 0] + input_ids_ctx = self._prepare_context_input_ids( + input_ids, attn_metadata.num_ctx_tokens, last_tokens_idx, + accepted_tokens, num_contexts) # generation input_ids_gen = accepted_tokens[num_contexts:, :].flatten()
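
Illustration (not part of the patch): the strict-acceptance rule that _sample_and_accept_draft_tokens_base centralizes counts consecutive draft/target matches with a cumulative product. A minimal standalone sketch with invented token values and shapes; the real helper also samples target_tokens per request and fills the context rows.

import torch

max_draft_len = 3
# Previously drafted tokens for two generation requests: [num_gens, max_draft_len].
draft_tokens = torch.tensor([[11, 12, 13],
                             [21, 99, 23]])
# Target-model tokens at the same positions plus one bonus token: [num_gens, max_draft_len + 1].
gen_target_tokens = torch.tensor([[11, 12, 13, 14],
                                  [21, 22, 23, 24]])

# Count consecutive matches from the left; the target model's own sample is
# always accepted, hence the added 1.
matches = (draft_tokens == gen_target_tokens[:, :max_draft_len]).int()
num_accepted_tokens = 1 + torch.cumprod(matches, dim=-1).sum(dim=-1)
print(num_accepted_tokens)  # tensor([4, 2]): all drafts accepted vs. only the first draft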
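
The shared greedy draft sampler (_draft_sampler_greedy) is an argmax followed by an optional d2t offset lookup that maps draft-vocabulary ids into the target vocabulary. A minimal sketch; the logits shape and the identity d2t table below are illustrative only.

import torch

def draft_sampler_greedy(logits, d2t=None):
    # Greedy pick over the draft-model logits.
    draft_tokens = torch.argmax(logits, dim=-1)
    if d2t is not None:
        # d2t stores per-id offsets, so the mapped id is offset + draft id.
        draft_tokens = d2t[draft_tokens] + draft_tokens
    return draft_tokens.type(torch.int32)

logits = torch.randn(4, 32)               # [num_tokens, draft_vocab_size]
d2t = torch.zeros(32, dtype=torch.long)   # identity mapping, for the example only
print(draft_sampler_greedy(logits, d2t))  # four int32 token ids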
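
_prepare_context_input_ids shifts the prompt ids left by one (the drafter predicts token t+1 from token t) and writes each context request's first accepted token at its last-token position, indexed by gather_ids / last_tokens_idx. A toy sketch with two three-token prompts; all values are invented.

import torch

input_ids = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.int32)  # prompts [1, 2, 3] and [4, 5, 6]
gather_ids = torch.tensor([2, 5])                                 # last-token index of each prompt
first_accepted = torch.tensor([30, 60], dtype=torch.int32)        # accepted_tokens[:num_contexts, 0]

input_ids_ctx = torch.empty_like(input_ids)
input_ids_ctx[:-1].copy_(input_ids[1:])     # shift left by one position
input_ids_ctx[gather_ids] = first_accepted  # place the newly accepted tokens at sequence ends
print(input_ids_ctx)                        # tensor([ 2,  3, 30,  5,  6, 60], dtype=torch.int32)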
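
_prepare_next_new_tokens builds the [batch_size, max_draft_len + 1] input for the next iteration by gathering the last accepted token of each request and prepending it to the freshly drafted tokens. A small sketch with invented values standing in for accepted_tokens, num_accepted_tokens and next_draft_tokens.

import torch

accepted_tokens = torch.tensor([[10, 11, 0],
                                [20, 21, 22]], dtype=torch.int32)  # [batch, max_draft_len + 1]
num_accepted_tokens = torch.tensor([2, 3])                         # accepted counts per request
next_draft_tokens = torch.tensor([[101, 102],
                                  [201, 202]], dtype=torch.int32)  # [batch, max_draft_len]
batch_indices = torch.arange(accepted_tokens.shape[0])

last_accepted = accepted_tokens[batch_indices, num_accepted_tokens - 1].unsqueeze(1)
next_new_tokens = torch.concat([last_accepted, next_draft_tokens], dim=1)
print(next_new_tokens)  # tensor([[ 11, 101, 102], [ 22, 201, 202]], dtype=torch.int32)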