mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
1227 lines
52 KiB
Python
from dataclasses import dataclass
from typing import List, Optional

import torch
from torch import nn

from tensorrt_llm.bindings.executor import FinishReason

from ..attention_backend import AttentionMetadata
from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
from ..pyexecutor.sampler import SampleState, SampleStateTensors, TorchSampler
from ..pyexecutor.scheduler import ScheduledRequests
from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode


@dataclass(frozen=True, kw_only=True)
|
|
class SampleStateTensorsMTP(SampleStateTensors):
|
|
new_tokens_lens: torch.Tensor
|
|
next_draft_tokens: torch.Tensor
|
|
|
|
|
|
@dataclass(frozen=True, kw_only=True)
|
|
class SampleStateMTP(SampleState):
|
|
device: SampleStateTensorsMTP
|
|
host: SampleStateTensorsMTP
|
|
|
|
|
|
@dataclass
|
|
class MTPConfig(SpecConfig):
|
|
"""
|
|
Configuration for MTP.
|
|
"""
|
|
    # The name of the speculative decoding method.
    spec_dec_name = "MTP"
    # The number of MTP modules.
    num_nextn_predict_layers: int = 1
    # The maximum batch size.
    max_batch_size: int = 8

    # Whether to use relaxed acceptance during the thinking phase of a reasoning model.
    use_relaxed_acceptance_for_thinking: bool = False
    # The top-N tokens are sampled from the logits to obtain a candidate set.
    relaxed_topk: int = 1
    # The threshold used to further filter the candidate set:
    # tokens whose log-probability gap from the top-1 token exceeds relaxed_delta are filtered out.
    relaxed_delta: float = 0.

    # TODO: Hard-coded for DeepSeek R1.
    # When <think> is encountered, the thinking phase starts.
    # When </think> is encountered, the thinking phase ends.
    # <think> [thinking phase] </think> [real output]
    BEGIN_THINKING_PHASE_TOKEN: int = 128798
    END_THINKING_PHASE_TOKEN: int = 128799
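
    # Illustrative sketch (editor assumption, not part of the config): relaxed
    # acceptance builds a candidate set from the target logits and accepts a draft
    # token if it falls inside that set. With a hypothetical 1-D `logits` tensor
    # and a `draft_token` id:
    #
    #   probs = torch.softmax(logits, dim=-1)
    #   topk_p, topk_ids = torch.topk(probs, k=relaxed_topk)
    #   # keep candidates whose log-prob gap to the top-1 token is within relaxed_delta
    #   keep = torch.log(topk_p[0]) - torch.log(topk_p) <= relaxed_delta
    #   accepted = draft_token in topk_ids[keep].tolist()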
|
|
|
|
def __post_init__(self) -> None:
|
|
self.spec_dec_mode = SpeculativeDecodingMode.from_string(
|
|
self.spec_dec_name)
|
|
self.max_draft_tokens = self.num_nextn_predict_layers
|
|
|
|
def update_from_model_config(self, model_config):
|
|
assert self.num_nextn_predict_layers > 0
|
|
if model_config.num_nextn_predict_layers == 1:
|
|
self.spec_dec_mode = SpeculativeDecodingMode.MTP_EAGLE
|
|
self.num_extra_kv_tokens = self.num_nextn_predict_layers - 1
|
|
|
|
|
|
class MTPHiddenStatesManager(BaseResourceManager):
|
|
|
|
def __init__(self, config: MTPConfig, dtype: torch.dtype, hidden_size: int,
|
|
max_num_requests: int):
|
|
self.dtype = dtype
|
|
self.num_nextn_predict_layers = config.num_nextn_predict_layers
|
|
self.hidden_size = hidden_size
|
|
self.max_num_requests = max_num_requests
|
|
self.use_relaxed_acceptance_for_thinking = config.use_relaxed_acceptance_for_thinking
|
|
self.slot_manager = SlotManager(max_num_requests)
|
|
|
|
        # The golden token's hidden state is always generated after the target model forward pass.
|
|
self.mtp_past_hidden_states_pool = torch.zeros(
|
|
(max_num_requests, self.num_nextn_predict_layers, self.hidden_size),
|
|
device='cuda',
|
|
dtype=self.dtype,
|
|
)
|
|
self.mtp_past_tokens_pool = torch.zeros(
|
|
(max_num_requests, self.num_nextn_predict_layers),
|
|
device='cuda',
|
|
dtype=torch.int,
|
|
)
|
|
if self.use_relaxed_acceptance_for_thinking:
|
|
# The relaxed_delta for relaxed acceptance
|
|
self.mtp_relaxed_delta_pool = torch.zeros(
|
|
(self.max_num_requests),
|
|
dtype=torch.float,
|
|
device='cuda',
|
|
)
|
|
|
|
def prepare_resources(self, scheduled_batch: ScheduledRequests):
|
|
context_batch = scheduled_batch.context_requests
|
|
# allocate hidden state tensors
|
|
for req in context_batch:
|
|
if req.is_first_context_chunk():
|
|
slot_id = self.slot_manager.add_slot(req.request_id)
|
|
if self.use_relaxed_acceptance_for_thinking:
|
|
self.mtp_relaxed_delta_pool[slot_id] = 0.
|
|
|
|
def update_resources(self, scheduled_batch: ScheduledRequests):
|
|
pass
|
|
|
|
def free_resources(self, request: LlmRequest):
|
|
free_slot_id = self.slot_manager.get_slot(request.request_id)
|
|
if self.use_relaxed_acceptance_for_thinking:
|
|
self.mtp_relaxed_delta_pool[free_slot_id] = 0.
|
|
self.slot_manager.remove_slot(request.request_id)
|
|
|
|
def add_dummy_requests(self, request_ids: List[int]):
|
|
for rid in request_ids:
|
|
self.slot_manager.add_slot(rid)
|
|
|
|
def shutdown(self):
|
|
pass
|
|
|
|
def get_max_resource_count(self) -> int:
|
|
return self.max_num_requests
|
|
|
|
def get_needed_resource_to_completion(self, request: LlmRequest):
|
|
return 0
|
|
|
|
|
|
@dataclass
|
|
class MTPSpecMetadata(SpecMetadata):
|
|
"""
|
|
Metadata for MTP.
|
|
"""
|
|
# The number of MTP modules in the model
|
|
    mtp_num_modules: int = 1
|
|
# The hidden states manager for MTP
|
|
mtp_hidden_states_manager: Optional[MTPHiddenStatesManager] = None
|
|
# The slot ids for each request.
|
|
slot_ids: Optional[torch.Tensor] = None
|
|
    # The indices of the batch inputs.
|
|
batch_indices_cuda: Optional[torch.Tensor] = None
|
|
|
|
def __post_init__(self) -> None:
|
|
if self.mtp_hidden_states_manager is not None:
|
|
# mtp_hidden_states_ptrs is a pointer tensor
|
|
self.mtp_hidden_states_ptrs = torch.empty(
|
|
[self.max_num_requests],
|
|
dtype=torch.int64,
|
|
device='cuda',
|
|
)
|
|
|
|
self.mtp_past_tokens_ptrs = torch.empty(
|
|
[self.max_num_requests],
|
|
dtype=torch.int64,
|
|
device='cuda',
|
|
)
|
|
self.slot_ids = torch.empty(
|
|
[self.max_num_requests],
|
|
dtype=torch.long,
|
|
device='cuda',
|
|
)
|
|
self.batch_indices_cuda = torch.empty(
|
|
[self.max_num_requests],
|
|
dtype=torch.int,
|
|
device='cuda',
|
|
)
|
|
|
|
def prepare(self):
|
|
assert self.request_ids is not None
|
|
num_seqs = len(self.request_ids)
|
|
        # update batch indices
|
|
batch_indices = torch.arange(num_seqs,
|
|
dtype=torch.int,
|
|
device='cpu',
|
|
pin_memory=True)
|
|
self.batch_indices_cuda[:num_seqs].copy_(batch_indices,
|
|
non_blocking=True)
|
|
        # MTP modules need a different number of input tokens in the generation phase
|
|
if self.spec_dec_mode.is_mtp_eagle():
|
|
self.num_tokens -= (self.num_generations) * self.mtp_num_modules
|
|
else:
|
|
self.num_tokens -= self.num_generations
|
|
|
|
if self.mtp_hidden_states_manager is not None: # MTP vanilla or use relaxed acceptance
|
|
mtp_slot_ids = []
|
|
for rid in self.request_ids:
|
|
slot_id = self.mtp_hidden_states_manager.slot_manager.get_slot(
|
|
rid)
|
|
mtp_slot_ids.append(slot_id)
|
|
|
|
# MTP Vanilla: Update mtp hidden states and past tokens
|
|
if self.spec_dec_mode.is_mtp():
|
|
mtp_hidden_states_ptrs = []
|
|
mtp_past_tokens_ptrs = []
|
|
for slot_id in mtp_slot_ids:
|
|
mtp_hidden_states_ptrs.append(
|
|
self.mtp_hidden_states_manager.
|
|
mtp_past_hidden_states_pool[slot_id].data_ptr())
|
|
mtp_past_tokens_ptrs.append(
|
|
self.mtp_hidden_states_manager.
|
|
mtp_past_tokens_pool[slot_id].data_ptr())
|
|
mtp_hidden_states_ptrs = torch.tensor(mtp_hidden_states_ptrs,
|
|
dtype=torch.int64,
|
|
pin_memory=True)
|
|
mtp_past_tokens_ptrs = torch.tensor(mtp_past_tokens_ptrs,
|
|
dtype=torch.int64,
|
|
pin_memory=True)
|
|
self.mtp_hidden_states_ptrs[:num_seqs].copy_(
|
|
mtp_hidden_states_ptrs, non_blocking=True)
|
|
self.mtp_past_tokens_ptrs[:num_seqs].copy_(mtp_past_tokens_ptrs,
|
|
non_blocking=True)
|
|
mtp_slot_ids = torch.tensor(mtp_slot_ids,
|
|
dtype=torch.int,
|
|
pin_memory=True)
|
|
self.slot_ids[:num_seqs].copy_(mtp_slot_ids, non_blocking=True)
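        # Editor note (sketch, assumption): the *_ptrs tensors above hold raw CUDA
        # addresses (Tensor.data_ptr()) of each request's row in the per-slot pools,
        # so the custom MTP ops can read/write those rows in place without first
        # gathering the pools into batch order. Roughly:
        #
        #   row = pool[slot_id]           # a view into the per-slot pool
        #   ptrs[i] = row.data_ptr()      # 64-bit address handed to the kernel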
|
|
|
|
|
|
class MTPSampler(TorchSampler):
|
|
"""
|
|
MTP sampler.
|
|
"""
|
|
|
|
def __init__(self, max_seq_len: int, config: MTPConfig):
|
|
super().__init__(max_seq_len, False)
|
|
self.mapping = None
|
|
self.draft_len = config.num_nextn_predict_layers
|
|
|
|
def _draft_meet_max_token_stop_criteria(self, request: LlmRequest,
|
|
num_tokens: int, beam_idx: int):
|
|
if self._meet_max_token_stop_criteria(request,
|
|
num_tokens + self.draft_len):
|
|
request.state = LlmRequestState.GENERATION_COMPLETE
|
|
request.set_finished_reason(FinishReason.LENGTH, beam_idx)
|
|
|
|
def update_requests(self, state: SampleStateMTP) -> None:
|
|
assert isinstance(state, SampleStateMTP)
|
|
|
|
state.sampler_event.synchronize()
|
|
new_tokens_list = state.host.new_tokens.tolist()
|
|
new_tokens_lens_list = state.host.new_tokens_lens.tolist()
|
|
next_draft_tokens_list = state.host.next_draft_tokens.tolist()
|
|
|
|
idx = 0
|
|
beam_idx = 0
|
|
for request in state.scheduled_requests.context_requests:
|
|
assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler"
|
|
assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler"
|
|
assert not request.py_return_log_probs, "return_log_probs not implemented for MTPSampler"
|
|
if request.get_context_remaining_length() != 0:
|
|
idx += 1
|
|
continue
|
|
|
|
if request.state != LlmRequestState.GENERATION_COMPLETE:
|
|
new_token = new_tokens_list[idx][0]
|
|
num_tokens = request.add_new_token(new_token, beam_idx)
|
|
should_stop = self._handle_stop_criteria(
|
|
request, new_token, num_tokens, beam_idx)
|
|
if self._draft_meet_max_token_stop_criteria(
|
|
request, num_tokens, beam_idx):
|
|
should_stop = True
|
|
request.py_draft_tokens = next_draft_tokens_list[idx]
|
|
request.py_decoding_iter += 1
|
|
idx += 1
|
|
|
|
# skip the results of cuda graph dummy requests
|
|
if idx == 0:
|
|
num_cuda_graph_dummy_requests = len(new_tokens_list) - len(
|
|
state.scheduled_requests.generation_requests)
|
|
idx += num_cuda_graph_dummy_requests
|
|
|
|
for request in state.scheduled_requests.generation_requests:
|
|
assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler"
|
|
assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler"
|
|
assert not request.py_return_log_probs, "return_log_probs not implemented for MTPSampler"
|
|
if request.state != LlmRequestState.GENERATION_COMPLETE:
|
|
new_tokens = new_tokens_list[idx]
|
|
num_new_tokens = new_tokens_lens_list[idx]
|
|
should_stop = False
|
|
for i in range(num_new_tokens):
|
|
new_token = new_tokens[i]
|
|
num_tokens = request.add_new_token(new_token, beam_idx)
|
|
should_stop = self._handle_stop_criteria(
|
|
request, new_token, num_tokens, beam_idx)
|
|
if should_stop:
|
|
break
|
|
if self._draft_meet_max_token_stop_criteria(
|
|
request, num_tokens, beam_idx):
|
|
should_stop = True
|
|
request.py_draft_tokens = next_draft_tokens_list[idx]
|
|
request.py_rewind_len = self.draft_len - (num_new_tokens - 1)
|
|
request.py_decoding_iter += 1
|
|
idx += 1
|
|
|
|
def sample_async(self, scheduled_requests: ScheduledRequests,
|
|
model_outputs) -> SampleStateMTP:
|
|
# new_tokens_device: all of the accepted tokens, device tensor
|
|
# new_tokens_lens_device: the accepted lengths, device tensor
|
|
# next_draft_tokens_device: predicted draft tokens, device tensor
|
|
# next_new_tokens_device: input tokens for the next iteration, device tensor
|
|
new_tokens_device = model_outputs['new_tokens']
|
|
new_tokens_lens_device = model_outputs['new_tokens_lens']
|
|
next_draft_tokens_device = model_outputs['next_draft_tokens']
|
|
next_new_tokens_device = model_outputs['next_new_tokens']
|
|
|
|
device = SampleStateTensorsMTP(
|
|
new_tokens=next_new_tokens_device,
|
|
new_tokens_lens=new_tokens_lens_device,
|
|
next_draft_tokens=next_draft_tokens_device,
|
|
)
|
|
host = SampleStateTensorsMTP(
|
|
new_tokens=new_tokens_device.to('cpu', non_blocking=True),
|
|
new_tokens_lens=new_tokens_lens_device.to('cpu', non_blocking=True),
|
|
next_draft_tokens=next_draft_tokens_device.to('cpu',
|
|
non_blocking=True),
|
|
)
|
|
sampler_event = torch.cuda.Event()
|
|
sampler_event.record()
|
|
# add dummy draft tokens to context requests to prepare kv cache in advance
|
|
# with the max draft token length
|
|
for request in scheduled_requests.context_requests:
|
|
request.py_draft_tokens = [1] * self.draft_len
|
|
return SampleStateMTP(scheduled_requests=scheduled_requests,
|
|
logits=model_outputs['logits'],
|
|
device=device,
|
|
host=host,
|
|
sampler_event=sampler_event)
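        # Editor note (sketch of the pattern used above): the device->host copies are
        # issued with non_blocking=True and a CUDA event is recorded afterwards.
        # Consumers must synchronize on that event before reading the host tensors,
        # which is what update_requests() does:
        #
        #   state = sampler.sample_async(scheduled_requests, model_outputs)
        #   ...  # overlap CPU work with the in-flight copies
        #   state.sampler_event.synchronize()
        #   tokens = state.host.new_tokens.tolist()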
|
|
|
|
|
|
class MTPWorker(nn.Module):
|
|
|
|
def __init__(self, spec_config: MTPConfig):
|
|
super().__init__()
|
|
self.spec_config = spec_config
|
|
self.is_thop = True
|
|
|
|
def forward(
|
|
self,
|
|
input_ids,
|
|
position_ids,
|
|
hidden_states,
|
|
logits,
|
|
lm_head,
|
|
embed_tokens,
|
|
attn_metadata,
|
|
spec_metadata,
|
|
mtp_layers,
|
|
):
|
|
batch_size = attn_metadata.num_seqs
|
|
|
|
# Sample and verify draft tokens
|
|
raw_logits = logits
|
|
accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens(
|
|
input_ids, logits, spec_metadata, attn_metadata)
|
|
|
|
# Update MTP past hidden states
|
|
self.update_mtp_hidden_states(input_ids=input_ids,
|
|
target_model_hidden_states=hidden_states,
|
|
num_accepted_tokens=num_accepted_tokens,
|
|
accepted_tokens=accepted_tokens,
|
|
spec_metadata=spec_metadata,
|
|
attn_metadata=attn_metadata)
|
|
|
|
# Predict draft tokens
|
|
next_draft_tokens = []
|
|
        # Not used by the first MTP layer; just a placeholder to avoid a NoneType error
|
|
previous_layer_draft_tokens = torch.empty(1,
|
|
dtype=torch.int,
|
|
device='cpu')
|
|
for mtp_layer_idx, mtp_layer in enumerate(mtp_layers):
|
|
draft_inputs = self.prepare_drafter_inputs(
|
|
mtp_layer_idx=mtp_layer_idx,
|
|
input_ids=input_ids,
|
|
position_ids=position_ids,
|
|
previous_layer_hidden_states=hidden_states,
|
|
previous_layer_draft_tokens=previous_layer_draft_tokens,
|
|
num_accepted_tokens=num_accepted_tokens,
|
|
spec_metadata=spec_metadata,
|
|
attn_metadata=attn_metadata)
|
|
hidden_states, logits = mtp_layer(lm_head=lm_head,
|
|
embed_tokens=embed_tokens,
|
|
**draft_inputs)
|
|
previous_layer_draft_tokens = self.draft_sampler(logits)
|
|
next_draft_tokens.append(previous_layer_draft_tokens)
|
|
|
|
input_ids = draft_inputs["input_ids"]
|
|
position_ids = draft_inputs["position_ids"]
|
|
attn_metadata = draft_inputs["attn_metadata"]
|
|
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
|
|
        if attn_metadata is not None and attn_metadata.is_cuda_graph:
|
|
self.restore_attn_metadata(attn_metadata=attn_metadata)
|
|
|
|
# prepare next new tokens to support overlap scheduler
|
|
next_new_tokens = accepted_tokens[
|
|
spec_metadata.batch_indices_cuda[:batch_size],
|
|
num_accepted_tokens - 1].unsqueeze(1)
|
|
next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens],
|
|
dim=1)
|
|
|
|
return {
|
|
'logits': raw_logits,
|
|
'new_tokens': accepted_tokens,
|
|
'new_tokens_lens': num_accepted_tokens,
|
|
'next_draft_tokens': next_draft_tokens,
|
|
'next_new_tokens': next_new_tokens
|
|
}
|
|
|
|
def skip_forward(
|
|
self,
|
|
input_ids,
|
|
position_ids,
|
|
hidden_states,
|
|
logits,
|
|
lm_head,
|
|
embed_tokens,
|
|
attn_metadata,
|
|
spec_metadata,
|
|
mtp_layers,
|
|
):
|
|
batch_size = attn_metadata.num_seqs
|
|
mtp_num_modules = self.spec_config.num_nextn_predict_layers
|
|
accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
|
|
dtype=torch.int,
|
|
device=logits.device)
|
|
num_accepted_tokens = torch.ones(batch_size,
|
|
dtype=torch.int,
|
|
device=logits.device)
|
|
next_draft_tokens = torch.empty((batch_size, mtp_num_modules),
|
|
dtype=torch.int,
|
|
device=logits.device)
|
|
next_new_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
|
|
dtype=torch.int,
|
|
device=logits.device)
|
|
return {
|
|
'logits': logits,
|
|
'new_tokens': accepted_tokens,
|
|
'new_tokens_lens': num_accepted_tokens,
|
|
'next_draft_tokens': next_draft_tokens,
|
|
'next_new_tokens': next_new_tokens
|
|
}
|
|
|
|
def update_mtp_hidden_states(
|
|
self,
|
|
input_ids: torch.LongTensor,
|
|
target_model_hidden_states: torch.Tensor,
|
|
num_accepted_tokens: torch.Tensor,
|
|
accepted_tokens: torch.Tensor,
|
|
spec_metadata: MTPSpecMetadata,
|
|
attn_metadata: AttentionMetadata,
|
|
):
|
|
'''
|
|
        Update the past hidden states and past tokens in spec_metadata based on
        the newly accepted tokens and the historical hidden states.
        These past hidden states and past tokens will be used in the MTP modules.
        Also update the seq_len and kv_lens in the attention metadata.
|
|
|
|
Args:
|
|
input_ids: torch.LongTensor
|
|
[num_tokens]
|
|
The input ids of all requests. Flatten.
|
|
|
|
target_model_hidden_states: torch.Tensor
|
|
[num_tokens, hidden_size]
|
|
Target model's hidden states.
|
|
|
|
num_accepted_tokens: torch.Tensor
|
|
[batch_size]
|
|
Number of accepted tokens per request.
|
|
|
|
accepted_tokens: torch.Tensor
|
|
[batch_size, max_draft_tokens + 1]
|
|
Accepted token ids. Flattened.
|
|
|
|
spec_metadata: MTPSpecMetadata
|
|
MTP speculative decoding metadata
|
|
|
|
attn_metadata: AttentionMetadata
|
|
Attention metadata
|
|
|
|
Returns:
|
|
None
|
|
|
|
Example:
|
|
Assume there are 3 MTP layers
|
|
Notation:
|
|
- H_t: token t's hidden state, generated by the target model
|
|
- h_t: token t's hidden state, generated by the draft model
|
|
|
|
Prompt: ABCD
|
|
|
|
Context phase:
|
|
Target model:
|
|
- input tokens: ABCD + []
|
|
- sampling tokens: E
|
|
- accepted tokens: E
|
|
- KV cache: ABCD
|
|
- hidden states: H_A, H_B, H_C, H_D
|
|
Draft model:
|
|
MTP1:
|
|
                        - input tokens: BCDE  # For a context request, prompt[2: -1] + the newly generated golden token is the input.
|
|
- input hidden states: H_A, H_B, H_C, H_D # Therefore, the input and output tokens/hidden_states will have the dimension of `prompt_length`.
|
|
- KV cache: () + BCDE '()' means historical KV cache
|
|
- output hidden states: h_B, h_C, h_D, h_E
|
|
- output next draft token: F
|
|
MTP2:
|
|
- input token: CDEF # Also with the `prompt_length`.
|
|
- input hidden states: h_B, h_C, h_D, h_E
|
|
- KV cache: () + CDEF
|
|
- output hidden states: h_C, h_D, h_E, h_F
|
|
- output next draft token: G
|
|
MTP3:
|
|
- input tokens: DEFG
|
|
- input hidden states: h_C, h_D, h_E, h_F
|
|
- KV cache: () + DEFG
|
|
- output hidden states: h_D, h_E, h_F, h_G
|
|
- output next draft token: H
|
|
After 3 MTP layers:
|
|
- input tokens: BCDE
|
|
- new generated draft tokens: FGH
|
|
|
|
Generation phase 1: accept partial draft tokens
|
|
Target model:
|
|
- input tokens: E + FGH
|
|
- sampling tokens: FGXY
|
|
- accepted tokens: FGX
|
|
- KV cache: (ABCD) + EFGH (H's KV cache is useless)
|
|
- hidden states: H_E, H_F, H_G, H_H (H_H is useless)
|
|
Draft model:
|
|
                    MTP1:
                        - input tokens: FGX  # For a generation request, `mtp_num_modules` + 1 tokens are used as input.
|
|
- input hidden states: H_E, H_F, H_G
|
|
- KV cache: (BCDE) + FGX
|
|
- output hidden states: h_F, h_G, h_X
|
|
- output next draft token: N
|
|
                    MTP2:
|
|
- input tokens: GXN
|
|
- input hidden states: h_F, h_G, h_X
|
|
- KV cache: (CDEF) + GXN
|
|
- output hidden states: h_G, h_X, h_N
|
|
- output next draft token: O
|
|
                    MTP3:
|
|
- input tokens: XNO
|
|
- input hidden states: h_G, h_X, h_N
|
|
- KV cache: (DEFG) + XNO
|
|
- output hidden states: h_X, h_N, h_O
|
|
- output next draft token: P
|
|
After 3 MTP layers:
|
|
- input tokens: FGX
|
|
- new generated draft tokens: NOP
|
|
|
|
            Generation phase 2: accept no draft tokens
|
|
Target model:
|
|
- input tokens: X + NOP
|
|
- sampling tokens: KMZY
|
|
- accepted tokens: K
|
|
- KV cache: (ABCDEFG) + NOP (NOP's KV cache is useless)
|
|
                    - hidden states: H_X, H_N, H_O, H_P (H_N, H_O, H_P are useless)
|
|
Draft model:
|
|
MTP1:
|
|
- input tokens: GXK
|
|
- input hidden states: H_F, H_G, H_X
|
|
                        - KV cache: (BCDE + FGX) + GXK
|
|
- output hidden states: h_G, h_X, h_K
|
|
- output next draft token: U
|
|
MTP2:
|
|
- input tokens: XKU
|
|
- input hidden states: h_G, h_X, h_K
|
|
- KV cache: (CDEF + GXN) + XKU
|
|
- output hidden states: h_X, h_K, h_U
|
|
- output next draft token: V
|
|
MTP3:
|
|
- input tokens: KUV
|
|
- input hidden states: h_X, h_K, h_U
|
|
- KV cache: (DEFG + XNO) + KUV
|
|
- output hidden states: h_K, h_U, h_V
|
|
- output next draft token: Q
|
|
After 3 MTP layers:
|
|
- input tokens: GXK
|
|
- new generated draft tokens: UVQ
|
|
'''
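        # Editor sketch of the per-request update performed below (torch fallback
        # path); names are illustrative:
        #
        #   cat_tokens = torch.cat((past_tokens, accepted_tokens), dim=0)
        #   new_past_tokens = cat_tokens[-mtp_num_modules:]      # keep the newest window
        #   cat_hidden = torch.cat((past_hidden, new_hidden), dim=0)
        #   new_past_hidden = cat_hidden[-mtp_num_modules:]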
|
|
batch_size = attn_metadata.num_seqs
|
|
num_contexts = attn_metadata.num_contexts
|
|
seq_lens = attn_metadata.seq_lens_cuda
|
|
hidden_size = target_model_hidden_states.shape[-1]
|
|
mtp_num_modules = self.spec_config.num_nextn_predict_layers
|
|
|
|
if self.is_thop:
|
|
_, _ = torch.ops.trtllm.mtp_update_hidden_states_op(
|
|
input_ids, seq_lens, target_model_hidden_states,
|
|
spec_metadata.mtp_hidden_states_ptrs,
|
|
spec_metadata.mtp_past_tokens_ptrs, num_accepted_tokens,
|
|
accepted_tokens, mtp_num_modules, batch_size, num_contexts,
|
|
hidden_size)
|
|
else:
|
|
assert len(spec_metadata.request_ids) == batch_size
|
|
mtp_past_hidden_states_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool
|
|
mtp_past_tokens_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool
|
|
|
|
input_ids_offset = 0
|
|
|
|
for bix in range(batch_size):
|
|
slot_id_cuda = spec_metadata.slot_ids[bix]
|
|
                # [num_nextn_predict_layers, hidden_size]
|
|
mtp_hidden_states = torch.index_select(
|
|
mtp_past_hidden_states_pool, 0, slot_id_cuda).squeeze(0)
|
|
# [num_nextn_predict_layers]
|
|
mtp_tokens = torch.index_select(mtp_past_tokens_pool, 0,
|
|
slot_id_cuda).squeeze(0)
|
|
cur_accepted_len = num_accepted_tokens[bix]
|
|
cur_seq_len = seq_lens[bix]
|
|
|
|
# Update MTP tokens
|
|
cur_accepted_tokens = accepted_tokens[bix, 0:cur_accepted_len]
|
|
if bix < num_contexts:
|
|
past_input_ids = input_ids[
|
|
input_ids_offset:input_ids_offset + cur_seq_len]
|
|
else:
|
|
past_input_ids = mtp_tokens
|
|
|
|
cat_past_tokens = torch.cat(
|
|
(past_input_ids, cur_accepted_tokens), dim=0)
|
|
# shape: [mtp_num_modules]
|
|
new_mtp_past_tokens = cat_past_tokens[
|
|
-mtp_num_modules:,
|
|
]
|
|
# Update the buffer, but keep the pointer unchanged.
|
|
mtp_past_tokens_pool.index_copy_(
|
|
0, slot_id_cuda, new_mtp_past_tokens.unsqueeze(0))
|
|
|
|
# Update MTP hidden states
|
|
past_hidden_states = mtp_hidden_states
|
|
# For context, we need to slice prompt length
|
|
# For generation, we only need to slice accepted length
|
|
num_slice_tokens = cur_seq_len if bix < num_contexts else cur_accepted_len
|
|
accepted_tokens_hidden_states = target_model_hidden_states[
|
|
input_ids_offset:input_ids_offset + num_slice_tokens]
|
|
cat_hidden_states = torch.cat(
|
|
(past_hidden_states, accepted_tokens_hidden_states), dim=0)
|
|
                # shape: [mtp_num_modules, hidden_size]
|
|
new_mtp_hidden_states = cat_hidden_states[
|
|
-mtp_num_modules:,
|
|
]
|
|
# Update the buffer, but keep the pointer unchanged.
|
|
mtp_past_hidden_states_pool.index_copy_(
|
|
0, slot_id_cuda, new_mtp_hidden_states.unsqueeze(0))
|
|
|
|
# Update offset
|
|
input_ids_offset += cur_seq_len
|
|
|
|
def sample_and_accept_draft_tokens(
|
|
self,
|
|
input_ids: torch.LongTensor,
|
|
logits: torch.Tensor,
|
|
spec_metadata: MTPSpecMetadata,
|
|
attn_metadata: AttentionMetadata,
|
|
):
|
|
'''
|
|
Takes input logits and samples golden token + predictions from draft tokens.
|
|
Runs acceptance algorithm to accept draft tokens.
|
|
        Currently only greedy sampling is supported. All decoding is done using
        top-1 sampling, and token equality is used for acceptance.
|
|
|
|
Args:
|
|
input_ids: torch.LongTensor
|
|
[num_tokens]
|
|
The input ids of all requests. Flatten.
|
|
|
|
logits: torch.Tensor
|
|
[num_tokens, vocab_size]
|
|
Logits produced by the target model.
|
|
|
|
spec_metadata: MTPSpecMetadata
|
|
MTP speculative decoding metadata
|
|
|
|
attn_metadata: AttentionMetadata
|
|
Attention metadata
|
|
|
|
Returns:
|
|
accepted_tokens: torch.Tensor
|
|
[batch_size, (max_draft_tokens + 1)]
|
|
Accepted token ids. Flattened.
|
|
|
|
num_accepted_tokens: torch.Tensor
|
|
[batch_size]
|
|
Number of accepted tokens per request.
|
|
|
|
Example:
|
|
Assume there are 3 MTP layers
|
|
Prompt: ABCD
|
|
|
|
Context phase:
|
|
Target model:
|
|
- input tokens: ABCD + []
|
|
- sampling tokens: E
|
|
- accepted tokens: E
|
|
Draft model:
|
|
- input tokens: BCDE
|
|
- new generated draft tokens: FGH
|
|
Current sequence: ABCD E`FGH -> Whitespace separates tokens produced by each phase
|
|
-> Backtick separates accepted and draft tokens
|
|
|
|
Generation phase 1:
|
|
Target model:
|
|
- input tokens: E + FGH
|
|
- sampling tokens: FGXY -> Sample with E's logit and get 'F'; Sample with F's logit, ...
|
|
                    - accepted tokens: FGX  -> 'X' will be treated as the accepted token
|
|
Draft model:
|
|
- input tokens: FGX
|
|
- new generated draft tokens: PQR
|
|
Current sequence: ABCD EFG X`PQR
|
|
|
|
Generation phase 2:
|
|
Target model:
|
|
- input tokens: X + PQR
|
|
- sampling tokens: PYST
|
|
                    - accepted tokens: PY
|
|
Draft model:
|
|
- input tokens: PY
|
|
- new generated draft tokens: UVW
|
|
Current sequence: ABCD EFG XP Y`UVW
|
|
'''
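        # Editor sketch of the strict (greedy) acceptance rule implemented below,
        # with illustrative values: the number of accepted tokens is 1 (the golden
        # token) plus the length of the matching prefix between the draft tokens
        # and the target model's greedy samples.
        #
        #   draft  = torch.tensor([[7, 3, 9]])           # mtp_num_modules = 3
        #   target = torch.tensor([[7, 3, 5, 2]])        # argmax of target logits
        #   matches = (draft == target[:, :3]).int()     # [[1, 1, 0]]
        #   num_accepted = 1 + torch.cumprod(matches, dim=-1).sum(1)   # tensor([3])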
|
|
|
|
batch_size = attn_metadata.num_seqs
|
|
num_contexts = attn_metadata.num_contexts
|
|
num_gens = batch_size - num_contexts
|
|
mtp_num_modules = self.spec_config.num_nextn_predict_layers
|
|
|
|
if logits.dim() == 1:
|
|
logits = logits.unsqueeze(0)
|
|
|
|
# The return buffer
|
|
accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
|
|
dtype=torch.int,
|
|
device=logits.device)
|
|
num_accepted_tokens = torch.ones(batch_size,
|
|
dtype=torch.int,
|
|
device=logits.device)
|
|
if self.spec_config.use_relaxed_acceptance_for_thinking:
|
|
mtp_relaxed_delta_pool = spec_metadata.mtp_hidden_states_manager.mtp_relaxed_delta_pool
|
|
|
|
# context
|
|
con_logits = logits[:num_contexts]
|
|
con_target_tokens = torch.argmax(con_logits, dim=-1)
|
|
accepted_tokens[:num_contexts, 0] = con_target_tokens[:num_contexts]
|
|
last_tokens_idx = torch.cumsum(
|
|
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
|
|
ctx_input_ids = input_ids[:attn_metadata.num_ctx_tokens]
|
|
ctx_is_think = (ctx_input_ids ==
|
|
self.spec_config.BEGIN_THINKING_PHASE_TOKEN).int()
|
|
ctx_is_think_cumsum = torch.cumsum(ctx_is_think, dim=0)
|
|
ctx_last_cumsum = ctx_is_think_cumsum[
|
|
last_tokens_idx[:num_contexts]]
|
|
ctx_think_tokens_num = torch.diff(
|
|
ctx_last_cumsum,
|
|
dim=0,
|
|
prepend=torch.zeros(1,
|
|
dtype=torch.int,
|
|
device=ctx_last_cumsum.device))
|
|
|
|
ctx_delta = (ctx_think_tokens_num
|
|
>= 1).int() * self.spec_config.relaxed_delta
|
|
ctx_slot_ids = spec_metadata.slot_ids[:num_contexts]
|
|
mtp_relaxed_delta_pool.index_copy_(0, ctx_slot_ids, ctx_delta)
|
|
|
|
# generation
|
|
gen_logits = logits[num_contexts:]
|
|
gen_logprobs = torch.softmax(gen_logits, dim=-1)
|
|
|
|
topk_value, topk_indices = torch.topk(
|
|
gen_logprobs, k=self.spec_config.relaxed_topk, dim=-1)
|
|
# [num_gens, mtp_num_modules + 1, relaxed_topk]
|
|
topk_indices = topk_indices.reshape(num_gens, mtp_num_modules + 1,
|
|
self.spec_config.relaxed_topk)
|
|
topk_value = topk_value.reshape(num_gens, mtp_num_modules + 1,
|
|
self.spec_config.relaxed_topk)
|
|
|
|
# [num_gens, mtp_num_modules]
|
|
draft_tokens = spec_metadata.draft_tokens.reshape(
|
|
num_gens, mtp_num_modules)
|
|
|
|
accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_relaxed_acceptance_op(
|
|
spec_metadata.slot_ids, topk_value, topk_indices, draft_tokens,
|
|
mtp_relaxed_delta_pool, num_accepted_tokens, accepted_tokens,
|
|
mtp_num_modules, batch_size, num_contexts,
|
|
self.spec_config.relaxed_topk, self.spec_config.relaxed_delta,
|
|
self.spec_config.BEGIN_THINKING_PHASE_TOKEN,
|
|
self.spec_config.END_THINKING_PHASE_TOKEN)
|
|
|
|
# Strict acceptance
|
|
else:
|
|
if self.is_thop:
|
|
# Temporary buffer
|
|
target_tokens_cache = torch.zeros(batch_size *
|
|
(mtp_num_modules + 1),
|
|
dtype=torch.int,
|
|
device=logits.device)
|
|
accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_sampling_and_accepted_draft_tokens_op(
|
|
logits, spec_metadata.draft_tokens, target_tokens_cache,
|
|
mtp_num_modules, batch_size, num_contexts, logits.shape[-1])
|
|
else:
|
|
# Do greedy sampling for the input logits
|
|
target_tokens = torch.argmax(logits, dim=-1)
|
|
|
|
# context
|
|
accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
|
|
|
|
# generation
|
|
gen_target_tokens = target_tokens[num_contexts:].reshape(
|
|
num_gens, mtp_num_modules + 1)
|
|
accepted_tokens[num_contexts:, :] = gen_target_tokens
|
|
draft_tokens = spec_metadata.draft_tokens.reshape(
|
|
num_gens, mtp_num_modules)
|
|
num_accepted_tokens[num_contexts:] += torch.cumprod(
|
|
(draft_tokens == gen_target_tokens[:, :mtp_num_modules]
|
|
).int(),
|
|
dim=-1).sum(1)
|
|
|
|
return accepted_tokens, num_accepted_tokens
|
|
|
|
def change_attn_metadata(self, num_accepted_tokens: torch.Tensor,
|
|
attn_metadata: AttentionMetadata):
|
|
batch_size = attn_metadata.num_seqs
|
|
mtp_num_modules = self.spec_config.num_nextn_predict_layers
|
|
|
|
num_contexts = attn_metadata.num_contexts
|
|
attn_metadata._seq_lens[num_contexts:batch_size] -= 1
|
|
attn_metadata._seq_lens_cuda[num_contexts:batch_size] -= 1
|
|
attn_metadata.on_update()
|
|
|
|
if hasattr(attn_metadata, 'kv_lens_cuda'):
|
|
# Note that it's important to not free the seq_lens_cuda
|
|
# buffer once the graph has been captured also - this will invalidate
|
|
# the graph and force an expensive recapture.
|
|
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
|
|
mtp_num_modules + 1 -
|
|
num_accepted_tokens[num_contexts:batch_size])
|
|
|
|
if attn_metadata.kv_cache_params is not None and not attn_metadata.is_cuda_graph:
|
|
for i in range(num_contexts, batch_size):
|
|
# used for vanilla MLA, list on cpu
|
|
attn_metadata.kv_cache_params.num_cached_tokens_per_seq[
|
|
i] -= mtp_num_modules + 1 - num_accepted_tokens[i].item()
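        # Editor note (illustrative arithmetic): during verification a generation
        # request appended mtp_num_modules + 1 tokens to the KV cache, but only
        # num_accepted_tokens of them are kept. With mtp_num_modules = 3 and
        # 2 accepted tokens, the KV length is rewound by 3 + 1 - 2 = 2 entries.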
|
|
|
|
def restore_attn_metadata(self, attn_metadata: AttentionMetadata):
|
|
batch_size = attn_metadata.num_seqs
|
|
num_contexts = attn_metadata.num_contexts
|
|
attn_metadata._seq_lens[num_contexts:batch_size] += 1
|
|
attn_metadata._seq_lens_cuda[num_contexts:batch_size] += 1
|
|
attn_metadata.on_update()
|
|
|
|
def prepare_drafter_inputs(
|
|
self,
|
|
mtp_layer_idx: int,
|
|
input_ids: torch.LongTensor,
|
|
position_ids: torch.LongTensor,
|
|
previous_layer_hidden_states: torch.Tensor,
|
|
previous_layer_draft_tokens: torch.Tensor,
|
|
num_accepted_tokens: torch.Tensor,
|
|
spec_metadata: MTPSpecMetadata,
|
|
attn_metadata: AttentionMetadata,
|
|
):
|
|
'''
|
|
        Prepare the inputs of the draft model.
|
|
|
|
Args:
|
|
mtp_layer_idx: int
|
|
The index number of the current MTP layer.
|
|
|
|
input_ids: torch.LongTensor
|
|
[num_tokens]
|
|
The input ids of all requests. Flatten.
|
|
When mtp_layer_idx == 0: num_tokens = sum(all prompts) + num_generation * (mtp_num_modules + 1)
|
|
When mtp_layer_idx > 0: num_tokens = sum(all prompts) + num_generation * (mtp_num_modules)
|
|
|
|
position_ids: torch.LongTensor
|
|
[1][num_tokens]
|
|
The position id of all requests. Flatten.
|
|
|
|
previous_layer_hidden_states: torch.Tensor
|
|
[num_tokens, hidden_size]
|
|
Target model's hidden states.
|
|
|
|
previous_layer_draft_tokens: torch.Tensor
|
|
[batch_size]
|
|
                Previous layer's draft tokens.
|
|
|
|
num_accepted_tokens: torch.Tensor
|
|
[batch_size]
|
|
Number of accepted draft tokens. Will be used for the first MTP layer.
|
|
|
|
spec_metadata: MTPSpecMetadata
|
|
MTP speculative decoding metadata
|
|
|
|
attn_metadata: AttentionMetadata
|
|
Attention metadata
|
|
|
|
Returns: draft_inputs
|
|
input_ids: torch.Tensor
|
|
[num_tokens]
|
|
The new input ids of all requests. Flatten.
|
|
num_tokens = sum(all prompts) + num_generation * (mtp_num_modules)
|
|
|
|
position_ids: torch.Tensor
|
|
                [1][num_tokens]
|
|
The new position ids of all requests. Flatten.
|
|
Directly use the input position ids.
|
|
|
|
hidden_states: torch.Tensor
|
|
[num_tokens][hidden_size]
|
|
Continuous hidden states buffer.
|
|
|
|
attn_metadata: AttentionMetadata
|
|
Attention metadata
|
|
|
|
'''
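        # Editor sketch of the first-MTP-layer input construction for a context
        # request (torch fallback path below); names are illustrative:
        #
        #   # shift the prompt left by one and append the newly accepted golden token
        #   shifted_ids = torch.cat((prompt_tokens[1:], golden_token.view(1)), dim=0)
        #   # the input hidden states are the target model's prompt hidden states, unshifted
        #   draft_hidden = prompt_hidden_states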
|
|
batch_size = attn_metadata.num_seqs
|
|
num_contexts = attn_metadata.num_contexts
|
|
num_gens = batch_size - num_contexts
|
|
mtp_past_hidden_states_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool
|
|
mtp_past_tokens_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool
|
|
mtp_num_modules = self.spec_config.num_nextn_predict_layers
|
|
|
|
if self.is_thop:
|
|
# Temporary buffer
|
|
hidden_size = previous_layer_hidden_states.shape[-1]
|
|
|
|
# generation requests' golden tokens
|
|
num_tokens = input_ids.shape[
|
|
0] - num_gens if mtp_layer_idx == 0 else input_ids.shape[0]
|
|
return_input_ids = torch.empty(num_tokens,
|
|
dtype=torch.int,
|
|
device="cuda")
|
|
|
|
if (mtp_layer_idx == 0):
|
|
return_hidden_states = torch.empty(
|
|
(num_tokens, hidden_size),
|
|
dtype=previous_layer_hidden_states.dtype,
|
|
device="cuda")
|
|
else:
|
|
return_hidden_states = torch.empty(
|
|
1, dtype=previous_layer_hidden_states.dtype,
|
|
device="cuda") # Useless, placeholder
|
|
|
|
(return_input_ids, return_hidden_states
|
|
) = torch.ops.trtllm.mtp_prepare_drafter_inputs_op(
|
|
input_ids, attn_metadata.seq_lens_cuda,
|
|
spec_metadata.mtp_hidden_states_ptrs,
|
|
spec_metadata.mtp_past_tokens_ptrs,
|
|
previous_layer_hidden_states, previous_layer_draft_tokens,
|
|
return_input_ids, return_hidden_states, mtp_num_modules,
|
|
mtp_layer_idx, batch_size, num_contexts, hidden_size)
|
|
|
|
else:
|
|
return_input_ids_list = []
|
|
return_hidden_states_list = []
|
|
if mtp_layer_idx == 0: # The first MTP layer
|
|
input_ids_offset = 0
|
|
for bix in range(batch_size):
|
|
slot_id_cuda = spec_metadata.slot_ids[bix]
|
|
cur_seq_len = attn_metadata.seq_lens_cuda[bix]
|
|
past_tokens = torch.index_select(mtp_past_tokens_pool, 0,
|
|
slot_id_cuda).squeeze(0)
|
|
past_hidden_states = torch.index_select(
|
|
mtp_past_hidden_states_pool, 0, slot_id_cuda).squeeze(0)
|
|
|
|
if bix < num_contexts:
|
|
# Context request
|
|
# MTP past tokens
|
|
                        # CUDA graph should not run this part since there are context requests
|
|
prompt_tokens = input_ids[
|
|
input_ids_offset:input_ids_offset + cur_seq_len]
|
|
cat_tensor = torch.cat(
|
|
(prompt_tokens[1:], past_tokens[-1:]), dim=0)
|
|
return_input_ids_list.append(cat_tensor)
|
|
|
|
# MTP past hidden states
|
|
prompt_hidden_states = previous_layer_hidden_states[
|
|
input_ids_offset:input_ids_offset + cur_seq_len]
|
|
return_hidden_states_list.append(prompt_hidden_states)
|
|
else:
|
|
# Generation request
|
|
# Directly append
|
|
return_input_ids_list.append(past_tokens)
|
|
return_hidden_states_list.append(past_hidden_states)
|
|
|
|
# Update offset
|
|
input_ids_offset += cur_seq_len
|
|
|
|
# Concat into a continuous buffer
|
|
return_input_ids = torch.cat(return_input_ids_list, dim=0)
|
|
return_hidden_states = torch.cat(return_hidden_states_list,
|
|
dim=0)
|
|
else:
|
|
                # This else branch should be compatible with CUDA graph capture
|
|
input_ids_offset = 0
|
|
for bix in range(batch_size):
|
|
                    # For generation requests, 'cur_seq_len' has already been updated to 'num_nextn_predict_layers'.
|
|
cur_seq_len = attn_metadata.seq_lens_cuda[bix]
|
|
|
|
                    # The 'input_ids' come from the previous layer
|
|
previous_layer_tokens = input_ids[
|
|
input_ids_offset:input_ids_offset + cur_seq_len]
|
|
|
|
# MTP past tokens
|
|
previous_draft_tokens = previous_layer_draft_tokens[bix:(
|
|
bix + 1)]
|
|
cat_tensor = torch.cat(
|
|
(previous_layer_tokens, previous_draft_tokens), dim=0)
|
|
return_input_ids_list.append(cat_tensor[1:])
|
|
|
|
# Update offset
|
|
input_ids_offset += cur_seq_len
|
|
|
|
return_input_ids = torch.cat(return_input_ids_list, dim=0)
|
|
# Directly use previous_layer_hidden_states as this layer's input hidden states
|
|
return_hidden_states = previous_layer_hidden_states
|
|
|
|
if mtp_layer_idx == 0 and attn_metadata is not None:
|
|
self.change_attn_metadata(num_accepted_tokens, attn_metadata)
|
|
|
|
return {
|
|
"input_ids": return_input_ids,
|
|
"position_ids": position_ids,
|
|
"hidden_states": return_hidden_states,
|
|
"attn_metadata": attn_metadata,
|
|
"spec_metadata": spec_metadata,
|
|
}
|
|
|
|
def draft_sampler(
|
|
self,
|
|
logits: torch.Tensor,
|
|
):
|
|
'''
|
|
Sampling draft tokens.
|
|
|
|
Args:
|
|
logits: torch.Tensor
|
|
[num_tokens, vocab_size]
|
|
Logits produced by the draft model.
|
|
|
|
Returns:
|
|
draft_tokens: torch.Tensor
|
|
[batch_size * max_draft_tokens]
|
|
Draft token ids. Flattened.
|
|
'''
|
|
|
|
draft_tokens = torch.argmax(logits, dim=-1).type(torch.int32)
|
|
return draft_tokens
|
|
|
|
|
|
class MTPEagleWorker(MTPWorker):
|
|
|
|
def __init__(self, spec_config: MTPConfig):
|
|
super().__init__(spec_config)
|
|
self.mtp_num_modules = spec_config.num_nextn_predict_layers
|
|
|
|
def forward(
|
|
self,
|
|
input_ids,
|
|
position_ids,
|
|
hidden_states,
|
|
logits,
|
|
lm_head,
|
|
embed_tokens,
|
|
attn_metadata,
|
|
spec_metadata,
|
|
mtp_layers,
|
|
):
|
|
batch_size = attn_metadata.num_seqs
|
|
num_contexts = attn_metadata.num_contexts
|
|
|
|
# Sample and verify draft tokens
|
|
raw_logits = logits
|
|
accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens(
|
|
input_ids, logits, spec_metadata, attn_metadata)
|
|
|
|
# Save the old attn_metadata and spec_metadata
|
|
if attn_metadata.is_cuda_graph:
|
|
seq_len = attn_metadata._seq_lens[:batch_size].clone()
|
|
seq_len_cuda = attn_metadata._seq_lens_cuda[:batch_size].clone()
|
|
|
|
# Prepare inputs for the 1st MTP layer
|
|
position_ids = position_ids.squeeze(0)
|
|
inputs = self.prepare_drafter_inputs(
|
|
input_ids=input_ids,
|
|
position_ids=position_ids,
|
|
hidden_states=hidden_states,
|
|
accepted_tokens=accepted_tokens,
|
|
num_accepted_tokens=num_accepted_tokens,
|
|
attn_metadata=attn_metadata,
|
|
spec_metadata=spec_metadata)
|
|
|
|
# Predict draft tokens
|
|
next_draft_tokens = []
|
|
for i in range(self.mtp_num_modules):
|
|
hidden_states, logits = mtp_layers[0](lm_head=lm_head,
|
|
embed_tokens=embed_tokens,
|
|
**inputs)
|
|
new_draft_token = self.draft_sampler(logits)
|
|
next_draft_tokens.append(new_draft_token)
|
|
# update inputs
|
|
last_tokens = torch.cumsum(
|
|
attn_metadata.seq_lens_cuda,
|
|
dim=0,
|
|
dtype=torch.long,
|
|
) - 1
|
|
position_ids = inputs["position_ids"][last_tokens] + 1
|
|
hidden_states = hidden_states[last_tokens]
|
|
attn_metadata._seq_lens[:attn_metadata.num_contexts].fill_(1)
|
|
attn_metadata._seq_lens_cuda[:attn_metadata.num_contexts].fill_(1)
|
|
attn_metadata.on_update()
|
|
            # cannot run generation if there is no kv cache
|
|
if inputs["attn_metadata"].kv_cache_manager is not None:
|
|
attn_metadata.host_request_types[:attn_metadata.
|
|
num_contexts].fill_(1)
|
|
attn_metadata.num_contexts = 0
|
|
if i == 0 and num_contexts > 0 and attn_metadata.enable_flash_mla:
|
|
reorder_block_ids_per_seq = torch.cat([
|
|
attn_metadata.
|
|
kv_block_ids_per_seq[num_contexts:batch_size],
|
|
attn_metadata.kv_block_ids_per_seq[:num_contexts]
|
|
])
|
|
attn_metadata.block_ids_per_seq[:batch_size, :].copy_(
|
|
reorder_block_ids_per_seq, non_blocking=True)
|
|
if hasattr(attn_metadata, 'kv_lens_cuda'):
|
|
attn_metadata.kv_lens_cuda[:batch_size] += 1
|
|
# support attention dp
|
|
if spec_metadata.all_rank_num_tokens is not None:
|
|
spec_metadata.all_rank_num_tokens = spec_metadata.all_rank_num_seqs
|
|
inputs = {
|
|
"input_ids": new_draft_token,
|
|
"position_ids": position_ids,
|
|
"hidden_states": hidden_states,
|
|
"attn_metadata": attn_metadata,
|
|
"spec_metadata": spec_metadata,
|
|
}
|
|
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
|
|
|
|
# restore attn_metadata to support cuda graph
|
|
if attn_metadata.is_cuda_graph:
|
|
attn_metadata._seq_lens[:batch_size].copy_(seq_len)
|
|
attn_metadata._seq_lens_cuda[:batch_size].copy_(seq_len_cuda)
|
|
attn_metadata.on_update()
|
|
|
|
# prepare next new tokens to support overlap scheduler
|
|
next_new_tokens = accepted_tokens[
|
|
spec_metadata.batch_indices_cuda[:batch_size],
|
|
num_accepted_tokens - 1].unsqueeze(1)
|
|
next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens],
|
|
dim=1)
|
|
|
|
return {
|
|
'logits': raw_logits,
|
|
'new_tokens': accepted_tokens,
|
|
'new_tokens_lens': num_accepted_tokens,
|
|
'next_draft_tokens': next_draft_tokens,
|
|
'next_new_tokens': next_new_tokens
|
|
}
|
|
|
|
def prepare_drafter_inputs(
|
|
self,
|
|
input_ids: torch.LongTensor,
|
|
position_ids: torch.LongTensor,
|
|
hidden_states: torch.Tensor,
|
|
accepted_tokens: torch.Tensor,
|
|
num_accepted_tokens: torch.Tensor,
|
|
attn_metadata: AttentionMetadata,
|
|
spec_metadata: MTPSpecMetadata,
|
|
):
|
|
batch_size = attn_metadata.num_seqs
|
|
num_contexts = attn_metadata.num_contexts
|
|
num_gens = batch_size - num_contexts
|
|
num_ctx_tokens = attn_metadata.num_ctx_tokens
|
|
hidden_size = hidden_states.shape[1]
|
|
last_tokens_idx = torch.cumsum(
|
|
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
|
|
|
|
# context
|
|
hidden_states_ctx = hidden_states[:attn_metadata.num_ctx_tokens, :]
|
|
input_ctx_ids = input_ids[:attn_metadata.num_ctx_tokens]
|
|
input_ids_ctx = torch.empty_like(input_ctx_ids,
|
|
dtype=torch.int32,
|
|
device="cuda")
|
|
input_ids_ctx[:-1].copy_(input_ctx_ids[1:])
|
|
input_ids_ctx[
|
|
last_tokens_idx[:num_contexts]] = accepted_tokens[:num_contexts, 0]
|
|
position_ids_ctx = position_ids[:num_ctx_tokens]
|
|
|
|
# generation
|
|
gen_batch_idx = spec_metadata.batch_indices_cuda[:num_gens]
|
|
gen_token_idx = num_accepted_tokens[num_contexts:] - 1
|
|
hidden_states_gen = hidden_states[attn_metadata.num_ctx_tokens:, :]
|
|
hidden_states_gen = hidden_states_gen.reshape(num_gens,
|
|
self.mtp_num_modules + 1,
|
|
hidden_size)
|
|
hidden_states_gen = hidden_states_gen[gen_batch_idx, gen_token_idx, :]
|
|
accepted_tokens_gen = accepted_tokens[num_contexts:, :]
|
|
input_ids_gen = accepted_tokens_gen[gen_batch_idx, gen_token_idx]
|
|
position_ids_gen = position_ids[num_ctx_tokens:].reshape(
|
|
num_gens, self.mtp_num_modules + 1)
|
|
position_ids_gen = position_ids_gen[gen_batch_idx, gen_token_idx]
|
|
|
|
# get draft inputs
|
|
input_ids = torch.concat([input_ids_ctx, input_ids_gen], dim=0)
|
|
hidden_states = torch.concat([hidden_states_ctx, hidden_states_gen],
|
|
dim=0)
|
|
position_ids = torch.concat([position_ids_ctx, position_ids_gen], dim=0)
|
|
|
|
# change attn_metadata
|
|
attn_metadata._seq_lens[num_contexts:batch_size].fill_(1)
|
|
attn_metadata._seq_lens_cuda[num_contexts:batch_size].fill_(1)
|
|
attn_metadata.on_update()
|
|
if hasattr(attn_metadata, 'kv_lens_cuda'):
|
|
# Note that it's important to not free the seq_lens_cuda
|
|
# buffer once the graph has been captured also - this will invalidate
|
|
# the graph and force an expensive recapture.
|
|
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
|
|
self.mtp_num_modules + 1 - num_accepted_tokens[num_contexts:])
|
|
|
|
return {
|
|
"input_ids": input_ids,
|
|
"position_ids": position_ids,
|
|
"hidden_states": hidden_states,
|
|
"attn_metadata": attn_metadata,
|
|
"spec_metadata": spec_metadata,
|
|
}
|