[TRTLLM-9962][feat] Some optimizations for two-model spec dec (#10208)

Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
This commit is contained in:
Ziyi Xiong 2025-12-28 12:52:04 +08:00 committed by GitHub
parent ae6d5766ed
commit c59aa8bec5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 73 additions and 41 deletions

View File

@ -895,8 +895,6 @@ class PyTorchModelEngine(ModelEngine):
return None return None
num_extra_decoding_steps = self._get_num_extra_decoding_steps() num_extra_decoding_steps = self._get_num_extra_decoding_steps()
if num_extra_decoding_steps > 0:
return None # Disable autotuning for fused drafting loops for now.
if num_gen_requests > self.batch_size: if num_gen_requests > self.batch_size:
return None return None
@ -909,7 +907,10 @@ class PyTorchModelEngine(ModelEngine):
ctx_requests = [] ctx_requests = []
gen_requests = [] gen_requests = []
max_seq_len = self.max_seq_len - 1 # For drafting loops, reduce max_seq_len to leave room for extra decoding steps
max_seq_len = self.max_seq_len - 1 - num_extra_decoding_steps
if max_seq_len < 1:
return None # Not enough sequence length for drafting loop
num_full_seqs = 0 num_full_seqs = 0
num_left_over_tokens = 0 num_left_over_tokens = 0
@ -954,7 +955,8 @@ class PyTorchModelEngine(ModelEngine):
token_nums=ctx_token_nums, token_nums=ctx_token_nums,
is_gen=False, is_gen=False,
max_num_draft_tokens=self.runtime_draft_len, max_num_draft_tokens=self.runtime_draft_len,
use_mrope=self.use_mrope) use_mrope=self.use_mrope,
num_extra_decoding_steps=num_extra_decoding_steps)
if spec_resource_manager is not None: if spec_resource_manager is not None:
spec_resource_manager.add_dummy_requests( spec_resource_manager.add_dummy_requests(
@ -1546,7 +1548,6 @@ class PyTorchModelEngine(ModelEngine):
return lora_params return lora_params
@torch.compile(options={"max-autotune": True})
def _update_draft_input_tensors(self, def _update_draft_input_tensors(self,
num_accepted_tokens_device: torch.Tensor, num_accepted_tokens_device: torch.Tensor,
new_tokens_device: torch.Tensor, new_tokens_device: torch.Tensor,
@ -1671,7 +1672,6 @@ class PyTorchModelEngine(ModelEngine):
return inputs, self.gather_ids_cuda[:num_generation_tokens] return inputs, self.gather_ids_cuda[:num_generation_tokens]
@torch.compile(options={"max-autotune": True})
def _update_target_input_tensors( def _update_target_input_tensors(
self, num_accepted_tokens_device: torch.Tensor, self, num_accepted_tokens_device: torch.Tensor,
new_tokens_device: torch.Tensor, new_tokens_device: torch.Tensor,

View File

@ -1708,7 +1708,6 @@ class PyExecutor:
self.iter_counter += 1 self.iter_counter += 1
@nvtx_range("_accept_draft_tokens") @nvtx_range("_accept_draft_tokens")
@torch.compile(options={"max-autotune": True})
def _accept_draft_tokens( def _accept_draft_tokens(
self, scheduled_batch: ScheduledRequests, self, scheduled_batch: ScheduledRequests,
target_outputs: SampleStateTensors, target_outputs: SampleStateTensors,

View File

@ -120,24 +120,27 @@ class LinearDraftingLoopWrapper(BaseDraftingLoopWrapper):
new_draft_tokens = [self.sample(logits)] new_draft_tokens = [self.sample(logits)]
draft_logits = [logits] draft_logits = [logits]
with save_metadata_state(attn_metadata, spec_metadata): if self.max_draft_len > 1:
batch_size = attn_metadata.num_seqs is_eagle3 = isinstance(spec_metadata, Eagle3SpecMetadata)
with save_metadata_state(attn_metadata, spec_metadata):
batch_size = attn_metadata.num_seqs
new_position_ids = self.prepare_for_generation( new_position_ids = self.prepare_for_generation(
attn_metadata, spec_metadata, position_ids) attn_metadata, spec_metadata, position_ids)
for i in range(self.max_draft_len - 1): for i in range(self.max_draft_len - 1):
logits = self.draft_model.forward( logits = self.draft_model.forward(
input_ids=new_draft_tokens[-1], input_ids=new_draft_tokens[-1],
position_ids=new_position_ids, position_ids=new_position_ids,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
spec_metadata=spec_metadata) spec_metadata=spec_metadata)
new_draft_tokens.append(self.sample(logits)) new_draft_tokens.append(self.sample(logits))
draft_logits.append(logits) draft_logits.append(logits)
new_position_ids += 1 new_position_ids += 1
attn_metadata.kv_lens_cuda[:batch_size] += 1 attn_metadata.kv_lens_cuda[:batch_size] += 1
if i == 0 and isinstance(spec_metadata, Eagle3SpecMetadata): if i == 0 and is_eagle3:
spec_metadata.hidden_states_read_indices[:batch_size].copy_( spec_metadata.hidden_states_read_indices[:batch_size].copy_(
spec_metadata.hidden_states_write_indices[:batch_size]) spec_metadata.
hidden_states_write_indices[:batch_size])
return { return {
"new_draft_tokens": torch.stack(new_draft_tokens), "new_draft_tokens": torch.stack(new_draft_tokens),
@ -153,7 +156,6 @@ class LinearDraftingLoopWrapper(BaseDraftingLoopWrapper):
return tokens return tokens
@torch.compile(options={'max-autotune': True})
def prepare_for_generation(self, attn_metadata: AttentionMetadata, def prepare_for_generation(self, attn_metadata: AttentionMetadata,
spec_metadata: SpecMetadata, spec_metadata: SpecMetadata,
position_ids: torch.Tensor) -> torch.Tensor: position_ids: torch.Tensor) -> torch.Tensor:

View File

@ -576,22 +576,53 @@ class ModelDrafter(Drafter):
if target_inputs.next_draft_tokens is None: if target_inputs.next_draft_tokens is None:
return return
if draft_tensors is not None: draft_indices = []
for req_idx, request in enumerate(draft_batch.all_requests()): target_indices = []
target_req = self.req_id_to_old_request[request.py_request_id] for req_idx, request in enumerate(draft_batch.all_requests()):
if target_req.state != LlmRequestState.GENERATION_IN_PROGRESS: target_req = self.req_id_to_old_request[request.py_request_id]
# Skip prefill requests if target_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
continue # Skip prefill requests
# Get the index of the draft/target tokens in the device tensor continue
draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot # Get the index of the draft/target tokens in the device tensor
target_idx = target_req.py_seq_slot draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
target_inputs.new_tokens[draft_position + 1:draft_position + target_idx = target_req.py_seq_slot
draft_length + 1, target_idx, draft_indices.append(draft_idx)
0] = draft_tensors[0:draft_length, target_indices.append(target_idx)
draft_idx]
target_inputs.next_draft_tokens[ if len(draft_indices) == 0:
target_idx, draft_position:draft_position + return
draft_length] = draft_tensors[0:draft_length, draft_idx]
device = draft_tensors.device
# Create index tensors
draft_indices_tensor = torch.tensor(draft_indices,
dtype=torch.long,
pin_memory=True).to(
device, non_blocking=True)
target_indices_tensor = torch.tensor(target_indices,
dtype=torch.long,
pin_memory=True).to(
device, non_blocking=True)
# Pre-slice draft tensors: [draft_length, batch_size]
draft_slice = draft_tensors[0:draft_length]
# Gather all source data at once using single index_select kernel
# Result shape: [draft_length, num_requests]
gathered = draft_slice.index_select(1, draft_indices_tensor).to(
torch.int32)
# Scatter to new_tokens using advanced indexing (single kernel)
# Shape: [draft_length, num_requests] -> [seq_len, batch_size, beam_width]
target_inputs.new_tokens[draft_position + 1:draft_position +
draft_length + 1, target_indices_tensor,
0] = gathered
# Scatter to next_draft_tokens using advanced indexing (single kernel)
# Shape: [num_requests, draft_length] -> [batch_size, max_draft_len]
target_inputs.next_draft_tokens[target_indices_tensor,
draft_position:draft_position +
draft_length] = gathered.t()
def _setup_draft_batch_and_resources( def _setup_draft_batch_and_resources(
self, self,