Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-05 02:31:33 +08:00)
[TRTLLM-10325][feat] Refactor speculative decoding workers (#10768)
Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com>
parent f91ea37a13
commit 8cf8fbbe16
@@ -377,15 +377,14 @@ class Eagle3OneModelWorker(SpecWorkerBase):

         raw_logits = logits

-        if self.guided_decoder is not None:
-            self.guided_decoder.execute(logits)
+        self._execute_guided_decoder_if_present(logits)

         # Sample and accept tokens
         accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens(
             logits, attn_metadata, spec_metadata)

         # Save the old attn_metadata and spec_metadata
-        attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda")
+        self._prepare_attn_metadata_for_spec_dec(attn_metadata)

         # Prepare inputs for the 1st draft model forward
         position_ids = position_ids.squeeze(0)
@@ -479,18 +478,15 @@ class Eagle3OneModelWorker(SpecWorkerBase):
         next_draft_tokens = torch.stack(next_draft_tokens, dim=1)

         # restore attn_metadata to support cuda graph
-        attn_metadata.restore_from_spec_dec()
-        attn_metadata.on_update()
+        self._restore_attn_metadata_from_spec_dec(attn_metadata)
         # restore all_rank_num_tokens for attention DP
         if original_all_rank_num_tokens is not None:
             attn_metadata.all_rank_num_tokens = original_all_rank_num_tokens

         # prepare next new tokens to support overlap scheduler
-        next_new_tokens = accepted_tokens[
-            spec_metadata.batch_indices_cuda[:batch_size],
-            num_accepted_tokens - 1].unsqueeze(1)
-        next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens],
-                                       dim=1)
+        next_new_tokens = self._prepare_next_new_tokens(
+            accepted_tokens, next_draft_tokens,
+            spec_metadata.batch_indices_cuda, batch_size, num_accepted_tokens)

         attn_metadata.use_spec_decoding = True
@@ -512,39 +508,13 @@ class Eagle3OneModelWorker(SpecWorkerBase):
         num_contexts = attn_metadata.num_contexts
         num_gens = batch_size - num_contexts

-        if logits.dim() == 1:
-            logits = logits.unsqueeze(0)
-
-        # The return buffer
-        accepted_tokens = torch.empty((batch_size, (self.max_draft_len + 1)),
-                                      dtype=torch.int,
-                                      device=logits.device)
-        num_accepted_tokens = torch.ones(batch_size,
-                                         dtype=torch.int,
-                                         device=logits.device)
-
-        # Sample tokens using per-request sampling parameters
-        target_tokens = self._sample_tokens_for_batch(logits, spec_metadata,
-                                                      num_contexts, batch_size)
-        # context
-        accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
-
-        # generation
-        gen_target_tokens = target_tokens[num_contexts:].reshape(
-            num_gens, self.max_draft_len + 1)
-        accepted_tokens[num_contexts:, :] = gen_target_tokens
+        # Reshape draft tokens for base implementation
         draft_tokens = spec_metadata.draft_tokens.reshape(
             num_gens, self.max_draft_len)
-        num_accepted_tokens[num_contexts:] += torch.cumprod(
-            (draft_tokens == gen_target_tokens[:, :self.max_draft_len]).int(),
-            dim=-1).sum(1)
-        # Check for environment variable override
-        if self.force_num_accepted_tokens != 0:
-            # total tokens per iteration = accepted draft tokens + 1 target token
-            force_total_tokens = min(self.force_num_accepted_tokens + 1,
-                                     self.max_draft_len + 1)
-            num_accepted_tokens[num_contexts:] = force_total_tokens
-        return accepted_tokens, num_accepted_tokens
+
+        # Use base implementation for strict acceptance
+        return self._sample_and_accept_draft_tokens_base(
+            logits, draft_tokens, num_contexts, batch_size, spec_metadata)
@@ -570,15 +540,8 @@ class Eagle3OneModelWorker(SpecWorkerBase):
         # Note: using greedy for draft tokens is a bit easier to implement and
         # faster. It doesn't affect the final output and seems to have a negligible
         # impact on AR.
-        draft_tokens = torch.argmax(logits, dim=-1)
-
-        # Apply d2t (offsets between draft model dictionary and main model dictionary).
-        if (d2t := getattr(draft_model.model, "d2t", None)) is not None:
-            draft_tokens = d2t[draft_tokens] + draft_tokens
-
-        draft_tokens = draft_tokens.type(torch.int32)
-
-        return draft_tokens
+        d2t = getattr(draft_model.model, "d2t", None)
+        return self._draft_sampler_greedy(logits, d2t)

     def prepare_1st_drafter_inputs(
         self,
@@ -601,14 +564,9 @@ class Eagle3OneModelWorker(SpecWorkerBase):
         hidden_states = draft_model.apply_eagle3_fc(hidden_states)

         # context
-        input_ctx_ids = input_ids[:attn_metadata.num_ctx_tokens]
-        input_ids_ctx = torch.empty_like(input_ctx_ids,
-                                         dtype=torch.int32,
-                                         device="cuda")
-        input_ids_ctx[:-1].copy_(input_ctx_ids[1:])
-        input_ids_ctx[
-            spec_metadata.
-            gather_ids[:num_contexts]] = accepted_tokens[:num_contexts, 0]
+        input_ids_ctx = self._prepare_context_input_ids(
+            input_ids, attn_metadata.num_ctx_tokens, spec_metadata.gather_ids,
+            accepted_tokens, num_contexts)

         # generation
         input_ids_gen = accepted_tokens[num_contexts:, :].flatten()
@@ -423,6 +423,177 @@ class SpecWorkerBase(nn.Module, ABC):
         self.guided_decoder = guided_decoder
         return True

+    def _prepare_attn_metadata_for_spec_dec(self, attn_metadata):
+        """
+        Prepare attention metadata before speculative decoding draft token generation.
+        Saves current state for later restoration.
+        """
+        attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda")
+
+    def _restore_attn_metadata_from_spec_dec(self, attn_metadata):
+        """
+        Restore attention metadata after speculative decoding draft token generation.
+        """
+        attn_metadata.restore_from_spec_dec()
+        attn_metadata.on_update()
+
+    def _apply_force_accepted_tokens(self, num_accepted_tokens, num_contexts):
+        """
+        Apply forced number of accepted tokens if environment variable is set.
+        This is used for testing and debugging.
+
+        Args:
+            num_accepted_tokens: Tensor of shape [batch_size] with current accepted counts
+            num_contexts: Number of context (prefill) requests
+
+        Returns:
+            Modified num_accepted_tokens tensor
+
+        Note:
+            For MTPWorker, self.max_draft_len equals num_nextn_predict_layers (mtp_num_modules).
+            For Eagle3OneModelWorker, self.max_draft_len equals spec_config.max_draft_len.
+        """
+        if self.force_num_accepted_tokens != 0:
+            # total tokens per iteration = accepted draft tokens + 1 target token
+            force_total_tokens = min(self.force_num_accepted_tokens + 1,
+                                     self.max_draft_len + 1)
+            num_accepted_tokens[num_contexts:] = force_total_tokens
+        return num_accepted_tokens
+
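For illustration only (not part of this commit): a minimal CPU sketch of the override arithmetic in _apply_force_accepted_tokens, with made-up values for force_num_accepted_tokens, max_draft_len, and the batch.

import torch

# Made-up values: 3 draft tokens per step, override asks for 2 accepted drafts,
# and the first request in the batch is still a context (prefill) request.
max_draft_len = 3
force_num_accepted_tokens = 2
num_contexts = 1
num_accepted_tokens = torch.tensor([1, 1, 4], dtype=torch.int)  # [batch_size]

if force_num_accepted_tokens != 0:
    # total tokens per iteration = accepted draft tokens + 1 target token,
    # clamped to the maximum possible max_draft_len + 1
    force_total_tokens = min(force_num_accepted_tokens + 1, max_draft_len + 1)  # -> 3
    num_accepted_tokens[num_contexts:] = force_total_tokens

print(num_accepted_tokens)  # tensor([1, 3, 3]) - context request untouched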
+    def _sample_and_accept_draft_tokens_base(
+        self,
+        logits: torch.Tensor,
+        draft_tokens: torch.Tensor,
+        num_contexts: int,
+        batch_size: int,
+        spec_metadata,
+    ):
+        """
+        Base implementation for sampling and accepting draft tokens.
+        Uses strict acceptance (token equality with cumulative product).
+
+        This is the common logic shared between Eagle3 and MTP (when relaxed
+        acceptance is disabled).
+
+        Args:
+            logits: [num_tokens, vocab_size] - Target model logits
+            draft_tokens: [num_gens, max_draft_len] - Previously predicted draft tokens
+            num_contexts: Number of context requests
+            batch_size: Total number of requests
+            spec_metadata: Speculative decoding metadata
+
+        Returns:
+            accepted_tokens: [batch_size, max_draft_len + 1] - Accepted tokens
+            num_accepted_tokens: [batch_size] - Number of accepted tokens per request
+        """
+        num_gens = batch_size - num_contexts
+
+        if logits.dim() == 1:
+            logits = logits.unsqueeze(0)
+
+        # Allocate return buffers
+        accepted_tokens = torch.empty((batch_size, self.max_draft_len + 1),
+                                      dtype=torch.int,
+                                      device=logits.device)
+        num_accepted_tokens = torch.ones(batch_size,
+                                         dtype=torch.int,
+                                         device=logits.device)
+
+        # Sample tokens using per-request sampling parameters
+        target_tokens = self._sample_tokens_for_batch(logits, spec_metadata,
+                                                      num_contexts, batch_size)
+
+        # Context requests: only accept the sampled token (no draft tokens yet)
+        accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
+
+        # Generation requests: verify draft tokens against target tokens
+        gen_target_tokens = target_tokens[num_contexts:].reshape(
+            num_gens, self.max_draft_len + 1)
+        accepted_tokens[num_contexts:, :] = gen_target_tokens
+
+        # Compare draft tokens with target tokens using cumulative product
+        # Counts consecutive matches from the start
+        num_accepted_tokens[num_contexts:] += torch.cumprod(
+            (draft_tokens == gen_target_tokens[:, :self.max_draft_len]).int(),
+            dim=-1).sum(1)
+
+        # Apply force override if set
+        num_accepted_tokens = self._apply_force_accepted_tokens(
+            num_accepted_tokens, num_contexts)
+
+        return accepted_tokens, num_accepted_tokens
+
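For illustration only (not part of this commit): how the strict-acceptance count in _sample_and_accept_draft_tokens_base works. The cumulative product over the match mask zeroes everything after the first mismatch, so the row sum is exactly the number of draft tokens accepted from the front. Token ids and max_draft_len below are made up.

import torch

max_draft_len = 4
draft_tokens = torch.tensor([[11, 42, 7, 99]])            # [num_gens, max_draft_len]
gen_target_tokens = torch.tensor([[11, 42, 8, 99, 5]])    # [num_gens, max_draft_len + 1]

matches = (draft_tokens == gen_target_tokens[:, :max_draft_len]).int()  # [[1, 1, 0, 1]]
accepted_drafts = torch.cumprod(matches, dim=-1).sum(1)                 # [1, 1, 0, 0] -> 2

# Every request always keeps the one token sampled from the target model,
# hence the initial torch.ones() and the "+=" in the helper.
num_accepted_tokens = 1 + accepted_drafts                               # tensor([3])
print(num_accepted_tokens)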
+    def _draft_sampler_greedy(self, logits: torch.Tensor, d2t=None):
+        """
+        Simple greedy draft token sampling using argmax.
+
+        Args:
+            logits: [num_tokens, vocab_size] - Draft model logits
+            d2t: Optional dictionary offset tensor for vocab mapping
+
+        Returns:
+            draft_tokens: [num_tokens] - Sampled draft token ids (int32)
+        """
+        draft_tokens = torch.argmax(logits, dim=-1)
+
+        # Apply d2t (offsets between draft and target model dictionaries)
+        if d2t is not None:
+            draft_tokens = d2t[draft_tokens] + draft_tokens
+
+        return draft_tokens.type(torch.int32)
+
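For illustration only (not part of this commit): a toy example of the d2t remapping used by _draft_sampler_greedy. d2t stores, for each draft-vocabulary id, the offset to the corresponding id in the target vocabulary, so lookup plus addition converts greedy draft ids into target-vocabulary ids. The 4-entry vocabulary and logits are made up.

import torch

d2t = torch.tensor([0, 3, 3, 7])   # offset per draft-vocab id (made-up 4-entry vocab)

draft_logits = torch.tensor([[0.1, 2.0, 0.3, 0.2],
                             [0.0, 0.1, 0.2, 5.0]])
draft_tokens = torch.argmax(draft_logits, dim=-1)   # tensor([1, 3]) in the draft vocab
draft_tokens = d2t[draft_tokens] + draft_tokens     # tensor([4, 10]) in the target vocab
print(draft_tokens.type(torch.int32))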
+    def _execute_guided_decoder_if_present(self, logits):
+        """Execute guided decoder on target model logits if available."""
+        if self.guided_decoder is not None:
+            self.guided_decoder.execute(logits)
+
+    def _prepare_next_new_tokens(self, accepted_tokens, next_draft_tokens,
+                                 batch_indices_cuda, batch_size,
+                                 num_accepted_tokens):
+        """
+        Prepare next_new_tokens for overlap scheduler support.
+
+        Args:
+            accepted_tokens: [batch_size, max_draft_len + 1] - Accepted tokens
+            next_draft_tokens: [batch_size, max_draft_len] - Predicted draft tokens
+            batch_indices_cuda: Batch indices tensor
+            batch_size: Number of requests
+            num_accepted_tokens: [batch_size] - Number of accepted tokens per request
+
+        Returns:
+            next_new_tokens: [batch_size, max_draft_len + 1] - Input tokens for next iteration
+        """
+        next_new_tokens = accepted_tokens[batch_indices_cuda[:batch_size],
+                                          num_accepted_tokens - 1].unsqueeze(1)
+        next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens],
+                                       dim=1)
+        return next_new_tokens
+
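For illustration only (not part of this commit): the advanced indexing in _prepare_next_new_tokens picks each request's last accepted token (column num_accepted_tokens - 1) and prepends it to that request's newly drafted tokens. Shapes and token ids below are made up; torch.arange stands in for batch_indices_cuda.

import torch

batch_size, max_draft_len = 2, 3
accepted_tokens = torch.tensor([[10, 11, 12, 13],
                                [20, 21, 22, 23]])        # [batch_size, max_draft_len + 1]
next_draft_tokens = torch.tensor([[101, 102, 103],
                                  [201, 202, 203]])       # [batch_size, max_draft_len]
num_accepted_tokens = torch.tensor([2, 4])                # per-request accepted counts
batch_indices = torch.arange(batch_size)                  # stand-in for batch_indices_cuda

next_new_tokens = accepted_tokens[batch_indices[:batch_size],
                                  num_accepted_tokens - 1].unsqueeze(1)   # [[11], [23]]
next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens], dim=1)
print(next_new_tokens)  # [[ 11, 101, 102, 103], [ 23, 201, 202, 203]]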
+    def _prepare_context_input_ids(self, input_ids, num_ctx_tokens, gather_ids,
+                                   accepted_tokens, num_contexts):
+        """
+        Prepare context input IDs for draft model forward.
+        Shifts input IDs left by 1 and places the first accepted token at gather positions.
+
+        Args:
+            input_ids: Original input IDs tensor
+            num_ctx_tokens: Number of context tokens
+            gather_ids: Indices for placing accepted tokens (last token positions)
+            accepted_tokens: [batch_size, max_draft_len + 1] - Accepted tokens
+            num_contexts: Number of context requests
+
+        Returns:
+            input_ids_ctx: Prepared context input IDs
+        """
+        input_prompt_ids = input_ids[:num_ctx_tokens]
+        input_ids_ctx = torch.empty_like(input_prompt_ids,
+                                         dtype=torch.int32,
+                                         device="cuda")
+        input_ids_ctx[:-1].copy_(input_prompt_ids[1:])
+        input_ids_ctx[
+            gather_ids[:num_contexts]] = accepted_tokens[:num_contexts, 0]
+        return input_ids_ctx
+
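For illustration only (not part of this commit): _prepare_context_input_ids feeds the draft model the prompt shifted left by one position, with the freshly accepted target token written at each prompt's last position (the gather index). A CPU sketch with a single context request and made-up ids; the real helper allocates on "cuda".

import torch

input_ids = torch.tensor([5, 6, 7, 8], dtype=torch.int32)    # one prompt, made-up ids
num_ctx_tokens = 4
gather_ids = torch.tensor([3])                                # last-token position of the prompt
accepted_tokens = torch.tensor([[42, 0]], dtype=torch.int32)  # column 0: token just accepted
num_contexts = 1

input_prompt_ids = input_ids[:num_ctx_tokens]
input_ids_ctx = torch.empty_like(input_prompt_ids)            # device="cuda" in the real helper
input_ids_ctx[:-1].copy_(input_prompt_ids[1:])                # shift left by one: [6, 7, 8, ?]
input_ids_ctx[gather_ids[:num_contexts]] = accepted_tokens[:num_contexts, 0]
print(input_ids_ctx)                                          # tensor([ 6,  7,  8, 42])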
     def _sample_tokens_for_batch(
         self,
         logits: torch.Tensor,
@@ -473,8 +473,7 @@ class MTPWorker(SpecWorkerBase):

         raw_logits = logits

-        if self.guided_decoder is not None:
-            self.guided_decoder.execute(logits)
+        self._execute_guided_decoder_if_present(logits)

         # Sample and verify draft tokens
         accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens(
@@ -541,14 +540,12 @@ class MTPWorker(SpecWorkerBase):

         # restore attn metadata
         if attn_metadata is not None:
-            self.restore_attn_metadata(attn_metadata=attn_metadata)
+            self._restore_attn_metadata_from_spec_dec(attn_metadata)

         # prepare next new tokens to support overlap scheduler
-        next_new_tokens = accepted_tokens[
-            spec_metadata.batch_indices_cuda[:batch_size],
-            num_accepted_tokens - 1].unsqueeze(1)
-        next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens],
-                                       dim=1)
+        next_new_tokens = self._prepare_next_new_tokens(
+            accepted_tokens, next_draft_tokens,
+            spec_metadata.batch_indices_cuda, batch_size, num_accepted_tokens)

         return {
             'logits': raw_logits,
@@ -845,6 +842,10 @@ class MTPWorker(SpecWorkerBase):
                 self.spec_config.begin_thinking_phase_token,
                 self.spec_config.end_thinking_phase_token)

+            # Apply force override for relaxed acceptance path
+            num_accepted_tokens = self._apply_force_accepted_tokens(
+                num_accepted_tokens, num_contexts)
+
         # Strict acceptance
         else:
             if self.is_thop:
@@ -856,36 +857,25 @@ class MTPWorker(SpecWorkerBase):
                 accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_sampling_and_accepted_draft_tokens_op(
                     logits, spec_metadata.draft_tokens, target_tokens_cache,
                     mtp_num_modules, batch_size, num_contexts, logits.shape[-1])
+
+                # Apply force override for THOP path
+                num_accepted_tokens = self._apply_force_accepted_tokens(
+                    num_accepted_tokens, num_contexts)
             else:
-                target_tokens = self._sample_tokens_for_batch(
-                    logits, spec_metadata, num_contexts, batch_size)
-
-                # context
-                accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
-
-                # generation
-                gen_target_tokens = target_tokens[num_contexts:].reshape(
-                    num_gens, mtp_num_modules + 1)
-                accepted_tokens[num_contexts:, :] = gen_target_tokens
+                # Reshape draft tokens for base implementation
                 draft_tokens = spec_metadata.draft_tokens.reshape(
                     num_gens, mtp_num_modules)
-                num_accepted_tokens[num_contexts:] += torch.cumprod(
-                    (draft_tokens == gen_target_tokens[:, :mtp_num_modules]
-                     ).int(),
-                    dim=-1).sum(1)
-
-                # Check for environment variable override
-                if self.force_num_accepted_tokens != 0:
-                    # total tokens per iteration = accepted draft tokens + 1 target token
-                    force_total_tokens = min(self.force_num_accepted_tokens + 1,
-                                             mtp_num_modules + 1)
-                    num_accepted_tokens[num_contexts:] = force_total_tokens
+                # Use base implementation for strict acceptance
+                accepted_tokens, num_accepted_tokens = self._sample_and_accept_draft_tokens_base(
+                    logits, draft_tokens, num_contexts, batch_size,
+                    spec_metadata)

         return accepted_tokens, num_accepted_tokens

     def change_attn_metadata(self, num_accepted_tokens: torch.Tensor,
                              attn_metadata: AttentionMetadata):
-        attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda")
+        self._prepare_attn_metadata_for_spec_dec(attn_metadata)
         batch_size = attn_metadata.num_seqs
         mtp_num_modules = self.spec_config.num_nextn_predict_layers
@@ -908,10 +898,6 @@ class MTPWorker(SpecWorkerBase):
                 attn_metadata.kv_cache_params.num_cached_tokens_per_seq[
                     i] -= mtp_num_modules + 1 - num_accepted_tokens[i].item()

-    def restore_attn_metadata(self, attn_metadata: AttentionMetadata):
-        attn_metadata.restore_from_spec_dec()
-        attn_metadata.on_update()
-
     def prepare_drafter_inputs(
         self,
         input_ids: torch.IntTensor,
@@ -1015,13 +1001,9 @@ class MTPWorker(SpecWorkerBase):
         # context
         if num_contexts > 0:
             hidden_states_ctx = hidden_states[:num_ctx_tokens, :]
-            input_prompt_ids = input_ids[:num_ctx_tokens]
-            input_ids_ctx = torch.empty_like(input_prompt_ids,
-                                             dtype=torch.int32,
-                                             device="cuda")
-            input_ids_ctx[:-1].copy_(input_prompt_ids[1:])
-            input_ids_ctx[last_tokens_idx[:num_contexts]] = \
-                accepted_tokens[:num_contexts, 0]
+            input_ids_ctx = self._prepare_context_input_ids(
+                input_ids, num_ctx_tokens, last_tokens_idx, accepted_tokens,
+                num_contexts)
             return_input_ids_list.append(input_ids_ctx)
             return_hidden_states_list.append(hidden_states_ctx)
         # generation
@@ -1137,7 +1119,7 @@ class MTPWorker(SpecWorkerBase):
             draft_tokens = self.get_draft_tokens_from_gathered(sliced_gathered)
         else:
             # Simple argmax if no TP or no model config
-            draft_tokens = torch.argmax(logits, dim=-1).type(torch.int32)
+            draft_tokens = self._draft_sampler_greedy(logits)

         return draft_tokens
@@ -1176,15 +1158,14 @@ class MTPEagleWorker(MTPWorker):

         raw_logits = logits

-        if self.guided_decoder is not None:
-            self.guided_decoder.execute(logits)
+        self._execute_guided_decoder_if_present(logits)

         # Sample and verify draft tokens
         accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens(
             input_ids, logits, spec_metadata, attn_metadata)

         # Save the old attn_metadata and spec_metadata
-        attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda")
+        self._prepare_attn_metadata_for_spec_dec(attn_metadata)

         # Prepare inputs for the 1st MTP layer
         @torch.compile(options={"max-autotune": True})
@@ -1323,22 +1304,9 @@ class MTPEagleWorker(MTPWorker):
             }

         # restore attn_metadata to support cuda graph
-        attn_metadata.restore_from_spec_dec()
-        attn_metadata.on_update()
+        self._restore_attn_metadata_from_spec_dec(attn_metadata)

-        @torch.compile(options={"max-autotune": True})
-        def prepare_next_tokens(next_draft_tokens, accepted_tokens,
-                                spec_metadata, batch_size, num_accepted_tokens):
-            next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
-            # prepare next new tokens to support overlap scheduler
-            next_new_tokens = accepted_tokens[
-                spec_metadata.batch_indices_cuda[:batch_size],
-                num_accepted_tokens - 1].unsqueeze(1)
-            next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens],
-                                           dim=1)
-            return next_draft_tokens, next_new_tokens
-
-        next_draft_tokens, next_new_tokens = prepare_next_tokens(
+        next_draft_tokens, next_new_tokens = self._prepare_next_tokens(
             next_draft_tokens, accepted_tokens, spec_metadata, batch_size,
             num_accepted_tokens)
@@ -1350,6 +1318,18 @@ class MTPEagleWorker(MTPWorker):
             'next_new_tokens': next_new_tokens
         }

+    @torch.compile(options={"max-autotune": True})
+    def _prepare_next_tokens(self, next_draft_tokens, accepted_tokens,
+                             spec_metadata, batch_size, num_accepted_tokens):
+        """
+        Stack draft tokens and prepare next_new_tokens for overlap scheduler.
+        """
+        next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
+        next_new_tokens = self._prepare_next_new_tokens(
+            accepted_tokens, next_draft_tokens,
+            spec_metadata.batch_indices_cuda, batch_size, num_accepted_tokens)
+        return next_draft_tokens, next_new_tokens
+
     @torch.compile(options={"max-autotune": True})
     def prepare_drafter_inputs(
         self,
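For illustration only (not part of this commit): next_draft_tokens reaches _prepare_next_tokens as a Python list with one [batch_size] tensor per MTP drafting step, and torch.stack(..., dim=1) converts it to the [batch_size, max_draft_len] layout that _prepare_next_new_tokens and the overlap scheduler expect. Values below are made up.

import torch

# Three drafting steps, batch of two requests (made-up token ids).
per_step_draft_tokens = [
    torch.tensor([101, 201]),
    torch.tensor([102, 202]),
    torch.tensor([103, 203]),
]
next_draft_tokens = torch.stack(per_step_draft_tokens, dim=1)
print(next_draft_tokens.shape)  # torch.Size([2, 3]) -> [batch_size, max_draft_len]
print(next_draft_tokens)        # [[101, 102, 103], [201, 202, 203]]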
@@ -1364,13 +1344,9 @@ class MTPEagleWorker(MTPWorker):
         num_contexts = attn_metadata.num_contexts

         # context
-        input_prompt_ids = input_ids[:attn_metadata.num_ctx_tokens]
-        input_ids_ctx = torch.empty_like(input_prompt_ids,
-                                         dtype=torch.int32,
-                                         device="cuda")
-        input_ids_ctx[:-1].copy_(input_prompt_ids[1:])
-        input_ids_ctx[
-            last_tokens_idx[:num_contexts]] = accepted_tokens[:num_contexts, 0]
+        input_ids_ctx = self._prepare_context_input_ids(
+            input_ids, attn_metadata.num_ctx_tokens, last_tokens_idx,
+            accepted_tokens, num_contexts)

         # generation
         input_ids_gen = accepted_tokens[num_contexts:, :].flatten()