[TRTLLM-9962][feat] Some optimizations for two-model spec dec (#10208)

Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
Ziyi Xiong 2025-12-28 12:52:04 +08:00, committed by GitHub
parent ae6d5766ed
commit c59aa8bec5
4 changed files with 73 additions and 41 deletions


@@ -895,8 +895,6 @@ class PyTorchModelEngine(ModelEngine):
             return None
         num_extra_decoding_steps = self._get_num_extra_decoding_steps()
-        if num_extra_decoding_steps > 0:
-            return None  # Disable autotuning for fused drafting loops for now.
         if num_gen_requests > self.batch_size:
             return None
@@ -909,7 +907,10 @@ class PyTorchModelEngine(ModelEngine):
         ctx_requests = []
         gen_requests = []
-        max_seq_len = self.max_seq_len - 1
+        # For drafting loops, reduce max_seq_len to leave room for extra decoding steps.
+        max_seq_len = self.max_seq_len - 1 - num_extra_decoding_steps
+        if max_seq_len < 1:
+            return None  # Not enough sequence length for drafting loop.
         num_full_seqs = 0
         num_left_over_tokens = 0
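
Note: together with the previous hunk, this re-enables autotuning warmup for fused drafting loops. Rather than bailing out whenever num_extra_decoding_steps > 0, the warmup request is shortened so the tokens the drafting loop appends in place still fit. A minimal sketch of the bound, with made-up numbers:

# Hypothetical values, purely to illustrate the bound computed above.
max_model_seq_len = 2048
num_extra_decoding_steps = 3  # tokens the fused drafting loop appends in place

# One slot is reserved for the next sampled token, plus one per extra step.
warmup_max_seq_len = max_model_seq_len - 1 - num_extra_decoding_steps
assert warmup_max_seq_len >= 1, "not enough sequence length for the drafting loop"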
@@ -954,7 +955,8 @@ class PyTorchModelEngine(ModelEngine):
             token_nums=ctx_token_nums,
             is_gen=False,
             max_num_draft_tokens=self.runtime_draft_len,
-            use_mrope=self.use_mrope)
+            use_mrope=self.use_mrope,
+            num_extra_decoding_steps=num_extra_decoding_steps)
         if spec_resource_manager is not None:
             spec_resource_manager.add_dummy_requests(
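
The extra-steps count is also threaded into dummy-request creation, presumably so warmup (CUDA-graph capture and autotuning) reserves the same KV-cache footprint as a real drafting-loop request would. A rough sketch of the idea, using a hypothetical DummyRequest type (the actual request class and its fields differ):

from dataclasses import dataclass

@dataclass
class DummyRequest:
    # Hypothetical stand-in for the engine's dummy warmup request.
    token_num: int
    num_extra_decoding_steps: int = 0

    @property
    def kv_tokens_needed(self) -> int:
        # Reserve room for the tokens the fused drafting loop appends.
        return self.token_num + self.num_extra_decoding_steps

req = DummyRequest(token_num=128, num_extra_decoding_steps=3)
assert req.kv_tokens_needed == 131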
@@ -1546,7 +1548,6 @@ class PyTorchModelEngine(ModelEngine):
         return lora_params

-    @torch.compile(options={"max-autotune": True})
     def _update_draft_input_tensors(self,
                                     num_accepted_tokens_device: torch.Tensor,
                                     new_tokens_device: torch.Tensor,
@@ -1671,7 +1672,6 @@ class PyTorchModelEngine(ModelEngine):
         return inputs, self.gather_ids_cuda[:num_generation_tokens]

-    @torch.compile(options={"max-autotune": True})
     def _update_target_input_tensors(
             self, num_accepted_tokens_device: torch.Tensor,
             new_tokens_device: torch.Tensor,


@@ -1708,7 +1708,6 @@ class PyExecutor:
         self.iter_counter += 1

     @nvtx_range("_accept_draft_tokens")
-    @torch.compile(options={"max-autotune": True})
     def _accept_draft_tokens(
             self, scheduled_batch: ScheduledRequests,
             target_outputs: SampleStateTensors,
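
This hunk, like the two above and prepare_for_generation below, drops @torch.compile(options={"max-autotune": True}) from small per-iteration helpers. A plausible reading: for tiny tensor-update functions, Dynamo's per-call guard checks can cost more than the fused kernel saves. A micro-benchmark sketch along those lines (tiny_update is a stand-in, not the real helper; results vary by machine):

import time

import torch

def tiny_update(buf: torch.Tensor, idx: torch.Tensor, vals: torch.Tensor) -> None:
    # Small scatter-style update, similar in spirit to the helpers above.
    buf[idx] = vals

compiled_update = torch.compile(tiny_update, options={"max-autotune": True})

buf = torch.zeros(1024)
idx = torch.arange(64)
vals = torch.ones(64)
compiled_update(buf, idx, vals)  # compile once, outside the timing loop

for name, fn in (("eager", tiny_update), ("compiled", compiled_update)):
    start = time.perf_counter()
    for _ in range(1000):
        fn(buf, idx, vals)
    print(f"{name}: {(time.perf_counter() - start) * 1e3:.2f} ms")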


@@ -120,24 +120,27 @@ class LinearDraftingLoopWrapper(BaseDraftingLoopWrapper):
         new_draft_tokens = [self.sample(logits)]
         draft_logits = [logits]
-        with save_metadata_state(attn_metadata, spec_metadata):
-            batch_size = attn_metadata.num_seqs
+        if self.max_draft_len > 1:
+            is_eagle3 = isinstance(spec_metadata, Eagle3SpecMetadata)
+            with save_metadata_state(attn_metadata, spec_metadata):
+                batch_size = attn_metadata.num_seqs
-            new_position_ids = self.prepare_for_generation(
-                attn_metadata, spec_metadata, position_ids)
-            for i in range(self.max_draft_len - 1):
-                logits = self.draft_model.forward(
-                    input_ids=new_draft_tokens[-1],
-                    position_ids=new_position_ids,
-                    attn_metadata=attn_metadata,
-                    spec_metadata=spec_metadata)
-                new_draft_tokens.append(self.sample(logits))
-                draft_logits.append(logits)
-                new_position_ids += 1
-                attn_metadata.kv_lens_cuda[:batch_size] += 1
-                if i == 0 and isinstance(spec_metadata, Eagle3SpecMetadata):
-                    spec_metadata.hidden_states_read_indices[:batch_size].copy_(
-                        spec_metadata.hidden_states_write_indices[:batch_size])
+                new_position_ids = self.prepare_for_generation(
+                    attn_metadata, spec_metadata, position_ids)
+                for i in range(self.max_draft_len - 1):
+                    logits = self.draft_model.forward(
+                        input_ids=new_draft_tokens[-1],
+                        position_ids=new_position_ids,
+                        attn_metadata=attn_metadata,
+                        spec_metadata=spec_metadata)
+                    new_draft_tokens.append(self.sample(logits))
+                    draft_logits.append(logits)
+                    new_position_ids += 1
+                    attn_metadata.kv_lens_cuda[:batch_size] += 1
+                    if i == 0 and is_eagle3:
+                        spec_metadata.hidden_states_read_indices[:batch_size].copy_(
+                            spec_metadata.hidden_states_write_indices[:batch_size])
         return {
             "new_draft_tokens": torch.stack(new_draft_tokens),
@@ -153,7 +156,6 @@ class LinearDraftingLoopWrapper(BaseDraftingLoopWrapper):
         return tokens

-    @torch.compile(options={'max-autotune': True})
     def prepare_for_generation(self, attn_metadata: AttentionMetadata,
                                spec_metadata: SpecMetadata,
                                position_ids: torch.Tensor) -> torch.Tensor:


@@ -576,22 +576,53 @@ class ModelDrafter(Drafter):
         if target_inputs.next_draft_tokens is None:
             return
         if draft_tensors is not None:
-            for req_idx, request in enumerate(draft_batch.all_requests()):
-                target_req = self.req_id_to_old_request[request.py_request_id]
-                if target_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
-                    # Skip prefill requests
-                    continue
-                # Get the index of the draft/target tokens in the device tensor
-                draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
-                target_idx = target_req.py_seq_slot
-                target_inputs.new_tokens[draft_position + 1:draft_position +
-                                         draft_length + 1, target_idx,
-                                         0] = draft_tensors[0:draft_length,
-                                                            draft_idx]
-                target_inputs.next_draft_tokens[
-                    target_idx, draft_position:draft_position +
-                    draft_length] = draft_tensors[0:draft_length, draft_idx]
+            draft_indices = []
+            target_indices = []
+            for req_idx, request in enumerate(draft_batch.all_requests()):
+                target_req = self.req_id_to_old_request[request.py_request_id]
+                if target_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
+                    # Skip prefill requests
+                    continue
+                # Get the index of the draft/target tokens in the device tensor
+                draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
+                target_idx = target_req.py_seq_slot
+                draft_indices.append(draft_idx)
+                target_indices.append(target_idx)
+
+            if len(draft_indices) == 0:
+                return
+
+            device = draft_tensors.device
+            # Create index tensors
+            draft_indices_tensor = torch.tensor(draft_indices,
+                                                dtype=torch.long,
+                                                pin_memory=True).to(
+                                                    device, non_blocking=True)
+            target_indices_tensor = torch.tensor(target_indices,
+                                                 dtype=torch.long,
+                                                 pin_memory=True).to(
+                                                     device, non_blocking=True)
+
+            # Pre-slice draft tensors: [draft_length, batch_size]
+            draft_slice = draft_tensors[0:draft_length]
+            # Gather all source data at once using a single index_select kernel
+            # Result shape: [draft_length, num_requests]
+            gathered = draft_slice.index_select(1, draft_indices_tensor).to(
+                torch.int32)
+
+            # Scatter to new_tokens using advanced indexing (single kernel)
+            # Shape: [draft_length, num_requests] -> [seq_len, batch_size, beam_width]
+            target_inputs.new_tokens[draft_position + 1:draft_position +
+                                     draft_length + 1, target_indices_tensor,
+                                     0] = gathered
+            # Scatter to next_draft_tokens using advanced indexing (single kernel)
+            # Shape: [num_requests, draft_length] -> [batch_size, max_draft_len]
+            target_inputs.next_draft_tokens[target_indices_tensor,
+                                            draft_position:draft_position +
+                                            draft_length] = gathered.t()

     def _setup_draft_batch_and_resources(
             self,
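
The core of this last hunk: a per-request Python loop (two sliced assignments per request, each launching its own kernels) becomes one pinned-memory index upload, one index_select gather, and two advanced-indexing scatters, so the kernel-launch count no longer scales with batch size. A minimal CPU sketch with made-up shapes (slot indices and sizes are illustrative only):

import torch

# Made-up sizes: 4 draft tokens, 8 sequence slots, 3 in-progress requests.
draft_len, num_slots = 4, 8
draft_tensors = torch.arange(draft_len * num_slots).reshape(draft_len, num_slots)
new_tokens = torch.zeros(draft_len + 1, num_slots, 1, dtype=torch.int32)
next_draft_tokens = torch.zeros(num_slots, draft_len, dtype=torch.int32)

draft_idx = torch.tensor([0, 1, 2], dtype=torch.long)   # slots in the draft batch
target_idx = torch.tensor([2, 5, 7], dtype=torch.long)  # slots in the target batch

# One gather kernel instead of one slice per request ...
gathered = draft_tensors[:draft_len].index_select(1, draft_idx).to(torch.int32)
# ... and one scatter kernel per destination tensor.
new_tokens[1:draft_len + 1, target_idx, 0] = gathered
next_draft_tokens[target_idx, :draft_len] = gathered.t()

assert new_tokens[1, 2, 0] == draft_tensors[0, 0]  # request 0 landed in slot 2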