Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[TRTLLM-9962][feat] Some optimizations for two-model spec dec (#10208)
Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
This commit is contained in:
parent ae6d5766ed
commit c59aa8bec5
@@ -895,8 +895,6 @@ class PyTorchModelEngine(ModelEngine):
            return None

        num_extra_decoding_steps = self._get_num_extra_decoding_steps()
        if num_extra_decoding_steps > 0:
            return None  # Disable autotuning for fused drafting loops for now.

        if num_gen_requests > self.batch_size:
            return None
@@ -909,7 +907,10 @@ class PyTorchModelEngine(ModelEngine):
        ctx_requests = []
        gen_requests = []

        max_seq_len = self.max_seq_len - 1
        # For drafting loops, reduce max_seq_len to leave room for extra decoding steps
        max_seq_len = self.max_seq_len - 1 - num_extra_decoding_steps
        if max_seq_len < 1:
            return None  # Not enough sequence length for drafting loop
        num_full_seqs = 0
        num_left_over_tokens = 0

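The hunk above turns the warm-up sequence-length cap into a budget: each extra decoding step of a fused drafting loop consumes one position, so the dummy warm-up requests must fit inside what is left, and the warm-up bails out when nothing fits. A minimal sketch of that budget with a hypothetical helper name and example numbers (not repo code):

# Hypothetical illustration of the capacity check above; not TensorRT-LLM code.
def warmup_seq_len_budget(max_seq_len: int, num_extra_decoding_steps: int) -> int | None:
    """Longest dummy sequence the warm-up may build, or None if nothing fits."""
    # One position is reserved for the newly generated token, and each extra
    # decoding step of the fused drafting loop consumes one more position.
    budget = max_seq_len - 1 - num_extra_decoding_steps
    return budget if budget >= 1 else None

print(warmup_seq_len_budget(8192, 3))  # 8188
print(warmup_seq_len_budget(4, 3))     # None: not enough room for the drafting loop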
@@ -954,7 +955,8 @@ class PyTorchModelEngine(ModelEngine):
                token_nums=ctx_token_nums,
                is_gen=False,
                max_num_draft_tokens=self.runtime_draft_len,
                use_mrope=self.use_mrope)
                use_mrope=self.use_mrope,
                num_extra_decoding_steps=num_extra_decoding_steps)

            if spec_resource_manager is not None:
                spec_resource_manager.add_dummy_requests(
@@ -1546,7 +1548,6 @@ class PyTorchModelEngine(ModelEngine):

        return lora_params

    @torch.compile(options={"max-autotune": True})
    def _update_draft_input_tensors(self,
                                    num_accepted_tokens_device: torch.Tensor,
                                    new_tokens_device: torch.Tensor,
@@ -1671,7 +1672,6 @@ class PyTorchModelEngine(ModelEngine):

        return inputs, self.gather_ids_cuda[:num_generation_tokens]

    @torch.compile(options={"max-autotune": True})
    def _update_target_input_tensors(
            self, num_accepted_tokens_device: torch.Tensor,
            new_tokens_device: torch.Tensor,

@@ -1708,7 +1708,6 @@ class PyExecutor:
            self.iter_counter += 1

    @nvtx_range("_accept_draft_tokens")
    @torch.compile(options={"max-autotune": True})
    def _accept_draft_tokens(
            self, scheduled_batch: ScheduledRequests,
            target_outputs: SampleStateTensors,

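The helpers touched in the hunks above (`_update_draft_input_tensors`, `_update_target_input_tensors`, `_accept_draft_tokens`) are small device-side tensor updates decorated with `@torch.compile(options={"max-autotune": True})`. As a generic, standalone illustration of that pattern (not the repo's code): a compiled update keeps all index arithmetic on the GPU so the pointwise ops can fuse into a single kernel instead of many tiny launches per step.

# Standalone sketch of a max-autotune-compiled tensor update; names are illustrative.
import torch

@torch.compile(options={"max-autotune": True})
def advance_positions(position_ids: torch.Tensor,
                      kv_lens: torch.Tensor,
                      num_accepted: torch.Tensor) -> torch.Tensor:
    # All operands stay on the device; no .item() syncs inside the compiled region.
    kv_lens += num_accepted              # in-place, analogous to the kv_lens_cuda updates
    return position_ids + num_accepted

if torch.cuda.is_available():
    pos = torch.arange(4, device="cuda")
    kv = torch.full((4,), 10, device="cuda")
    acc = torch.tensor([1, 2, 0, 3], device="cuda")
    print(advance_positions(pos, kv, acc), kv)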
@@ -120,24 +120,27 @@ class LinearDraftingLoopWrapper(BaseDraftingLoopWrapper):

        new_draft_tokens = [self.sample(logits)]
        draft_logits = [logits]
        with save_metadata_state(attn_metadata, spec_metadata):
            batch_size = attn_metadata.num_seqs
        if self.max_draft_len > 1:
            is_eagle3 = isinstance(spec_metadata, Eagle3SpecMetadata)
            with save_metadata_state(attn_metadata, spec_metadata):
                batch_size = attn_metadata.num_seqs

            new_position_ids = self.prepare_for_generation(
                attn_metadata, spec_metadata, position_ids)
            for i in range(self.max_draft_len - 1):
                logits = self.draft_model.forward(
                    input_ids=new_draft_tokens[-1],
                    position_ids=new_position_ids,
                    attn_metadata=attn_metadata,
                    spec_metadata=spec_metadata)
                new_draft_tokens.append(self.sample(logits))
                draft_logits.append(logits)
                new_position_ids += 1
                attn_metadata.kv_lens_cuda[:batch_size] += 1
                if i == 0 and isinstance(spec_metadata, Eagle3SpecMetadata):
                    spec_metadata.hidden_states_read_indices[:batch_size].copy_(
                        spec_metadata.hidden_states_write_indices[:batch_size])
                new_position_ids = self.prepare_for_generation(
                    attn_metadata, spec_metadata, position_ids)
                for i in range(self.max_draft_len - 1):
                    logits = self.draft_model.forward(
                        input_ids=new_draft_tokens[-1],
                        position_ids=new_position_ids,
                        attn_metadata=attn_metadata,
                        spec_metadata=spec_metadata)
                    new_draft_tokens.append(self.sample(logits))
                    draft_logits.append(logits)
                    new_position_ids += 1
                    attn_metadata.kv_lens_cuda[:batch_size] += 1
                    if i == 0 and is_eagle3:
                        spec_metadata.hidden_states_read_indices[:batch_size].copy_(
                            spec_metadata.hidden_states_write_indices[:batch_size])

        return {
            "new_draft_tokens": torch.stack(new_draft_tokens),
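Two small changes in the hunk above: the generation loop is guarded by `if self.max_draft_len > 1:` so a single-draft-token configuration skips it entirely, and the `isinstance(spec_metadata, Eagle3SpecMetadata)` test is hoisted out of the loop into `is_eagle3`. A generic sketch of the same restructuring, with placeholder class and function names (not repo code):

# Placeholder types/names; only the control-flow restructuring mirrors the hunk above.
class Eagle3LikeMetadata:
    pass

def run_drafting_steps(max_draft_len: int, spec_metadata, step) -> None:
    if max_draft_len > 1:                                            # skip the loop for a single draft token
        is_eagle3 = isinstance(spec_metadata, Eagle3LikeMetadata)    # hoisted out of the loop
        for i in range(max_draft_len - 1):
            step(i)
            if i == 0 and is_eagle3:                                 # cheap bool test per iteration
                print("redirect hidden-state reads to the freshly written slots")

run_drafting_steps(3, Eagle3LikeMetadata(), step=lambda i: None)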
@@ -153,7 +156,6 @@ class LinearDraftingLoopWrapper(BaseDraftingLoopWrapper):

        return tokens

    @torch.compile(options={'max-autotune': True})
    def prepare_for_generation(self, attn_metadata: AttentionMetadata,
                               spec_metadata: SpecMetadata,
                               position_ids: torch.Tensor) -> torch.Tensor:

@@ -576,22 +576,53 @@ class ModelDrafter(Drafter):
        if target_inputs.next_draft_tokens is None:
            return

        if draft_tensors is not None:
            for req_idx, request in enumerate(draft_batch.all_requests()):
                target_req = self.req_id_to_old_request[request.py_request_id]
                if target_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
                    # Skip prefill requests
                    continue
                # Get the index of the draft/target tokens in the device tensor
                draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
                target_idx = target_req.py_seq_slot
                target_inputs.new_tokens[draft_position + 1:draft_position +
                                         draft_length + 1, target_idx,
                                         0] = draft_tensors[0:draft_length,
                                                            draft_idx]
                target_inputs.next_draft_tokens[
                    target_idx, draft_position:draft_position +
                    draft_length] = draft_tensors[0:draft_length, draft_idx]
            draft_indices = []
            target_indices = []
            for req_idx, request in enumerate(draft_batch.all_requests()):
                target_req = self.req_id_to_old_request[request.py_request_id]
                if target_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
                    # Skip prefill requests
                    continue
                # Get the index of the draft/target tokens in the device tensor
                draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
                target_idx = target_req.py_seq_slot
                draft_indices.append(draft_idx)
                target_indices.append(target_idx)

            if len(draft_indices) == 0:
                return

            device = draft_tensors.device

            # Create index tensors
            draft_indices_tensor = torch.tensor(draft_indices,
                                                dtype=torch.long,
                                                pin_memory=True).to(
                                                    device, non_blocking=True)
            target_indices_tensor = torch.tensor(target_indices,
                                                 dtype=torch.long,
                                                 pin_memory=True).to(
                                                     device, non_blocking=True)

            # Pre-slice draft tensors: [draft_length, batch_size]
            draft_slice = draft_tensors[0:draft_length]

            # Gather all source data at once using single index_select kernel
            # Result shape: [draft_length, num_requests]
            gathered = draft_slice.index_select(1, draft_indices_tensor).to(
                torch.int32)

            # Scatter to new_tokens using advanced indexing (single kernel)
            # Shape: [draft_length, num_requests] -> [seq_len, batch_size, beam_width]
            target_inputs.new_tokens[draft_position + 1:draft_position +
                                     draft_length + 1, target_indices_tensor,
                                     0] = gathered

            # Scatter to next_draft_tokens using advanced indexing (single kernel)
            # Shape: [num_requests, draft_length] -> [batch_size, max_draft_len]
            target_inputs.next_draft_tokens[target_indices_tensor,
                                            draft_position:draft_position +
                                            draft_length] = gathered.t()

    def _setup_draft_batch_and_resources(
            self,
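The hunk above replaces the per-request Python loop, which issued two small device copies per request, with one `index_select` gather plus two advanced-indexing scatters driven by index tensors built once on the host (pinned memory, so the `non_blocking=True` host-to-device copy can overlap). A toy sketch (not repo code) checking that the batched form writes the same values as the loop:

# Toy check (not repo code): batched gather/scatter matches the per-request loop.
import torch

draft_length, num_slots = 3, 6
draft_tensors = torch.arange(draft_length * num_slots).reshape(draft_length, num_slots)
draft_indices, target_indices = [4, 1, 5], [0, 2, 3]    # source/destination slots
draft_position = 0

# Loop version: one small copy per request.
new_tokens_loop = torch.zeros(draft_length + 1, num_slots, 1, dtype=torch.int64)
for d, t in zip(draft_indices, target_indices):
    new_tokens_loop[draft_position + 1:draft_position + draft_length + 1, t,
                    0] = draft_tensors[0:draft_length, d]

# Batched version: one index_select gather, one advanced-indexing scatter.
new_tokens_vec = torch.zeros_like(new_tokens_loop)
d_idx = torch.tensor(draft_indices, dtype=torch.long)
t_idx = torch.tensor(target_indices, dtype=torch.long)
gathered = draft_tensors[0:draft_length].index_select(1, d_idx)   # [draft_length, num_requests]
new_tokens_vec[draft_position + 1:draft_position + draft_length + 1, t_idx, 0] = gathered

# Same idea for the [batch, draft_len] layout: transpose the gathered block once.
next_draft_vec = torch.zeros(num_slots, draft_length, dtype=torch.int64)
next_draft_vec[t_idx, draft_position:draft_position + draft_length] = gathered.t()

assert torch.equal(new_tokens_loop, new_tokens_vec)
print("batched scatter matches the per-request loop")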