from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, List, Optional
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import nn
|
|
|
|
from tensorrt_llm.mapping import Mapping
|
|
|
|
from ..attention_backend import AttentionMetadata
|
|
from ..distributed.ops import allgather
|
|
from ..model_config import ModelConfig
|
|
from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
|
|
from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
|
|
from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
|
|
from ..pyexecutor.sampler import (DEFAULT_BEAM_IDX, SampleState,
|
|
SampleStateTensors, TorchSampler, add_token,
|
|
int_tensor)
|
|
from ..pyexecutor.scheduler import ScheduledRequests
|
|
from .interface import SpecMetadata, get_force_num_accepted_tokens
|
|
|
|
if TYPE_CHECKING:
|
|
from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig
|
|
|
|
|
|
@dataclass(kw_only=True)
|
|
class SampleStateTensorsMTP(SampleStateTensors):
|
|
new_tokens_lens: torch.Tensor
|
|
next_draft_tokens: torch.Tensor
|
|
|
|
|
|
@dataclass(kw_only=True)
|
|
class SampleStateMTP(SampleState):
|
|
device: SampleStateTensorsMTP
|
|
host: SampleStateTensorsMTP
|
|
|
|
|
|
class MTPHiddenStatesManager(BaseResourceManager):
|
|
|
|
def __init__(self, config: "MTPDecodingConfig", dtype: torch.dtype,
|
|
hidden_size: int, max_num_requests: int):
|
|
self.dtype = dtype
|
|
self.num_nextn_predict_layers = config.num_nextn_predict_layers
|
|
self.hidden_size = hidden_size
|
|
self.max_num_requests = max_num_requests
|
|
self.use_relaxed_acceptance_for_thinking = config.use_relaxed_acceptance_for_thinking
|
|
self.slot_manager = SlotManager(max_num_requests)
|
|
|
|
        # Since the golden token's hidden state is always generated after the target model forward
|
|
self.mtp_past_hidden_states_pool = torch.zeros(
|
|
(max_num_requests, self.num_nextn_predict_layers, self.hidden_size),
|
|
device='cuda',
|
|
dtype=self.dtype,
|
|
)
|
|
self.mtp_past_tokens_pool = torch.zeros(
|
|
(max_num_requests, self.num_nextn_predict_layers),
|
|
device='cuda',
|
|
dtype=torch.int,
|
|
)
|
|
if self.use_relaxed_acceptance_for_thinking:
|
|
# The relaxed_delta for relaxed acceptance
|
|
self.mtp_relaxed_delta_pool = torch.zeros(
|
|
(self.max_num_requests),
|
|
dtype=torch.float,
|
|
device='cuda',
|
|
)
|
|
|
|
def prepare_resources(self, scheduled_batch: ScheduledRequests):
|
|
context_batch = scheduled_batch.context_requests
|
|
# allocate hidden state tensors
|
|
for req in context_batch:
|
|
if req.is_first_context_chunk:
|
|
slot_id = self.slot_manager.add_slot(req.request_id)
|
|
if self.use_relaxed_acceptance_for_thinking:
|
|
self.mtp_relaxed_delta_pool[slot_id].copy_(
|
|
0, non_blocking=True)
|
|
|
|
def update_resources(self, scheduled_batch: ScheduledRequests):
|
|
pass
|
|
|
|
def free_resources(self, request: LlmRequest):
|
|
free_slot_id = self.slot_manager.get_slot(request.request_id)
|
|
if self.use_relaxed_acceptance_for_thinking:
|
|
self.mtp_relaxed_delta_pool[free_slot_id].copy_(0,
|
|
non_blocking=True)
|
|
self.slot_manager.remove_slot(request.request_id)
|
|
|
|
def add_dummy_requests(self, request_ids: List[int]):
|
|
for rid in request_ids:
|
|
self.slot_manager.add_slot(rid)
|
|
|
|
def shutdown(self):
|
|
self.slot_manager.shutdown()
|
|
|
|
def get_max_resource_count(self) -> int:
|
|
return self.max_num_requests
|
|
|
|
def get_needed_resource_to_completion(self, request: LlmRequest):
|
|
return 0
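

def _sketch_slot_indexed_pools():
    # Illustrative sketch, not used by the runtime: MTPHiddenStatesManager above
    # keeps one row per request slot, so a request's MTP state lives at
    # pool[slot_id] and is overwritten in place between iterations. All sizes
    # and the slot id below are made up for demonstration.
    max_num_requests, num_nextn_predict_layers, hidden_size = 4, 2, 8
    past_tokens = torch.zeros((max_num_requests, num_nextn_predict_layers),
                              dtype=torch.int)
    past_hidden = torch.zeros(
        (max_num_requests, num_nextn_predict_layers, hidden_size))

    slot_id = 1  # pretend SlotManager.add_slot() handed out slot 1
    # After an iteration, the newest tokens / hidden states replace the row.
    past_tokens[slot_id] = torch.tensor([101, 102], dtype=torch.int)
    past_hidden[slot_id] = torch.randn(num_nextn_predict_layers, hidden_size)
    return past_tokens, past_hidden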
|
|
|
|
|
|
@dataclass
|
|
class MTPSpecMetadata(SpecMetadata):
|
|
"""
|
|
Metadata for MTP.
|
|
"""
|
|
# The number of MTP modules in the model
|
|
mtp_num_modules: int = 1
|
|
# The hidden states manager for MTP
|
|
mtp_hidden_states_manager: Optional[MTPHiddenStatesManager] = None
|
|
# The slot ids for each request.
|
|
slot_ids: Optional[torch.Tensor] = None
|
|
    # The indices of the batched inputs
|
|
batch_indices_cuda: Optional[torch.Tensor] = None
|
|
    # The number of sequences for the speculative model/layer on each rank
|
|
_all_rank_num_seqs: Optional[List[int]] = None
|
|
    # This is used for attention dp in the MTP Eagle worker. The number of input
    # tokens varies between the 1st draft forward and subsequent ones. To support
    # CUDA graph, we use this tensor to store the number of input tokens for the
    # subsequent draft forwards.
|
|
subseq_all_rank_num_tokens: Optional[List[int]] = None
|
|
|
|
def __post_init__(self) -> None:
|
|
if self.mtp_hidden_states_manager is not None:
|
|
# mtp_hidden_states_ptrs is a pointer tensor
|
|
self.mtp_hidden_states_ptrs = torch.empty(
|
|
[self.max_num_requests],
|
|
dtype=torch.int64,
|
|
device='cuda',
|
|
)
|
|
|
|
self.mtp_past_tokens_ptrs = torch.empty(
|
|
[self.max_num_requests],
|
|
dtype=torch.int64,
|
|
device='cuda',
|
|
)
|
|
self.slot_ids = torch.empty(
|
|
[self.max_num_requests],
|
|
dtype=torch.long,
|
|
device='cuda',
|
|
)
|
|
self.batch_indices_cuda = torch.empty(
|
|
[self.max_num_requests],
|
|
dtype=torch.int,
|
|
device='cuda',
|
|
)
|
|
self.draft_token_indices_cuda = torch.arange(
|
|
self.mtp_num_modules,
|
|
device='cuda',
|
|
)
|
|
|
|
@property
|
|
def all_rank_num_seqs(self):
|
|
return self._all_rank_num_seqs
|
|
|
|
@all_rank_num_seqs.setter
|
|
def all_rank_num_seqs(self, value: List[int]):
|
|
self._all_rank_num_seqs = value
|
|
if self.spec_dec_mode.is_mtp_eagle_one_model():
|
|
self.subseq_all_rank_num_tokens = value
|
|
|
|
def prepare(self):
|
|
assert self.request_ids is not None
|
|
num_seqs = len(self.request_ids)
|
|
        # update batch indices
|
|
batch_indices = torch.arange(num_seqs,
|
|
dtype=torch.int,
|
|
device='cpu',
|
|
pin_memory=True)
|
|
self.batch_indices_cuda[:num_seqs].copy_(batch_indices,
|
|
non_blocking=True)
|
|
        # The MTP vanilla worker uses max_draft_len input tokens in the generation
        # phase, while the MTP Eagle worker uses (max_draft_len + 1) input tokens in
        # the 1st draft forward and only one input token in the following draft
        # forwards. This num_tokens is used to set the all_rank_num_tokens for
        # attention dp.
|
|
if not self.spec_dec_mode.is_mtp_eagle_one_model():
|
|
self.num_tokens -= self.num_generations
|
|
|
|
if self.mtp_hidden_states_manager is not None: # MTP vanilla or use relaxed acceptance
|
|
mtp_slot_ids = []
|
|
for rid in self.request_ids:
|
|
slot_id = self.mtp_hidden_states_manager.slot_manager.get_slot(
|
|
rid)
|
|
mtp_slot_ids.append(slot_id)
|
|
|
|
# MTP Vanilla: Update mtp hidden states and past tokens
|
|
if self.spec_dec_mode.is_mtp_one_model():
|
|
mtp_hidden_states_ptrs = []
|
|
mtp_past_tokens_ptrs = []
|
|
for slot_id in mtp_slot_ids:
|
|
mtp_hidden_states_ptrs.append(
|
|
self.mtp_hidden_states_manager.
|
|
mtp_past_hidden_states_pool[slot_id].data_ptr())
|
|
mtp_past_tokens_ptrs.append(
|
|
self.mtp_hidden_states_manager.
|
|
mtp_past_tokens_pool[slot_id].data_ptr())
|
|
mtp_hidden_states_ptrs = torch.tensor(mtp_hidden_states_ptrs,
|
|
dtype=torch.int64,
|
|
pin_memory=True)
|
|
mtp_past_tokens_ptrs = torch.tensor(mtp_past_tokens_ptrs,
|
|
dtype=torch.int64,
|
|
pin_memory=True)
|
|
self.mtp_hidden_states_ptrs[:num_seqs].copy_(
|
|
mtp_hidden_states_ptrs, non_blocking=True)
|
|
self.mtp_past_tokens_ptrs[:num_seqs].copy_(mtp_past_tokens_ptrs,
|
|
non_blocking=True)
|
|
mtp_slot_ids = torch.tensor(mtp_slot_ids,
|
|
dtype=torch.int,
|
|
pin_memory=True)
|
|
self.slot_ids[:num_seqs].copy_(mtp_slot_ids, non_blocking=True)
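

def _sketch_pointer_table(pool: torch.Tensor, slot_ids: List[int]) -> torch.Tensor:
    # Illustrative sketch, not used by the runtime: prepare() above hands the
    # per-slot pool rows to the THOP kernels as a device-side table of raw
    # pointers. It gathers data_ptr() values into a pinned host tensor and
    # copies them asynchronously into a preallocated CUDA int64 buffer.
    # `pool` is assumed to be a CUDA tensor such as mtp_past_tokens_pool.
    host_ptrs = torch.tensor([pool[s].data_ptr() for s in slot_ids],
                             dtype=torch.int64,
                             pin_memory=True)
    device_ptrs = torch.empty(len(slot_ids), dtype=torch.int64, device='cuda')
    device_ptrs.copy_(host_ptrs, non_blocking=True)
    return device_ptrs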
|
|
|
|
|
|
class MTPSampler(TorchSampler):
|
|
"""
|
|
MTP sampler.
|
|
"""
|
|
|
|
SampleState = SampleStateMTP
|
|
|
|
@dataclass(frozen=True, kw_only=True)
|
|
class Store(TorchSampler.Store):
|
|
new_tokens: torch.Tensor
|
|
next_new_tokens: torch.Tensor
|
|
next_draft_tokens: torch.Tensor
|
|
new_tokens_lens: torch.Tensor
|
|
max_total_draft_tokens: torch.Tensor
|
|
finish_reasons: None = None # Necessary to satisfy the interface of TorchSampler.Store
|
|
|
|
def __post_init__(self):
|
|
pass # finish_reasons has no size to compare against new_tokens in MTPSampler
|
|
|
|
def __init__(self, args: TorchSampler.Args, *, nextn: int):
|
|
self.mapping = None
|
|
self.draft_len = nextn
|
|
self.max_seq_len = args.max_seq_len
|
|
|
|
seq_slots = args.max_num_sequences
|
|
max_tokens = args.max_total_draft_tokens + 1
|
|
self.max_beam_width = args.max_beam_width
|
|
assert self.max_beam_width == 1, "beam width must be 1 for MTP"
|
|
|
|
self.store = self.Store(
|
|
new_tokens=int_tensor((max_tokens, seq_slots, self.max_beam_width)),
|
|
next_new_tokens=int_tensor(
|
|
(max_tokens, seq_slots, self.max_beam_width)),
|
|
next_draft_tokens=int_tensor(
|
|
(seq_slots, args.max_total_draft_tokens)),
|
|
new_tokens_lens=int_tensor((seq_slots, )),
|
|
max_total_draft_tokens=int_tensor(
|
|
(seq_slots, args.max_total_draft_tokens)),
|
|
)
|
|
|
|
def _request_common_handling(self, request: LlmRequest,
|
|
next_draft_tokens: list[list[int]]):
|
|
assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler"
|
|
assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler"
|
|
assert not request.py_return_log_probs, "return_log_probs not implemented for MTPSampler"
|
|
request.py_draft_tokens = next_draft_tokens[request.py_seq_slot]
|
|
request.py_decoding_iter += 1
|
|
|
|
def update_requests(
|
|
self,
|
|
state: SampleStateMTP,
|
|
resource_manager: Optional[BaseResourceManager] = None) -> None:
|
|
        # resource_manager will not be used in this function
|
|
assert isinstance(state, SampleStateMTP)
|
|
|
|
state.sampler_event.synchronize()
|
|
new_tokens = state.host.new_tokens.tolist()
|
|
new_tokens_lens_list = state.host.new_tokens_lens.tolist()
|
|
next_draft_tokens_list = state.host.next_draft_tokens.tolist()
|
|
beam_idx = DEFAULT_BEAM_IDX
|
|
for req in state.scheduled_requests.context_requests:
|
|
if req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0:
|
|
continue
|
|
new_token = add_token(req, new_tokens, beam_idx=beam_idx)
|
|
TorchSampler._handle_stop_criteria(req,
|
|
new_token,
|
|
max_seq_len=self.max_seq_len,
|
|
beam_idx=beam_idx)
|
|
self._request_common_handling(req, next_draft_tokens_list)
|
|
|
|
for req in state.scheduled_requests.generation_requests:
|
|
if req.state == LlmRequestState.GENERATION_COMPLETE:
|
|
continue
|
|
num_new_tokens = new_tokens_lens_list[req.py_seq_slot]
|
|
for i in range(num_new_tokens):
|
|
new_token = add_token(req,
|
|
new_tokens,
|
|
beam_idx=beam_idx,
|
|
step=i)
|
|
if TorchSampler._handle_stop_criteria(
|
|
req,
|
|
new_token,
|
|
max_seq_len=self.max_seq_len,
|
|
beam_idx=beam_idx):
|
|
break
|
|
req.py_num_accepted_draft_tokens = num_new_tokens - 1
|
|
req.py_rewind_len = self.draft_len - req.py_num_accepted_draft_tokens
|
|
self._request_common_handling(req, next_draft_tokens_list)
|
|
|
|
def sample_async(
|
|
self, scheduled_requests: ScheduledRequests,
|
|
outputs: dict[str, torch.Tensor],
|
|
num_context_logits_prefix_sum: list[int]) -> SampleStateMTP:
|
|
# new_tokens_device: accepted tokens, device tensor, shape: batch_size, nextn + 1
|
|
# new_tokens_lens_device: accepted lengths, device tensor, shape: batch_size
|
|
# next_draft_tokens_device: predicted draft tokens, device tensor, shape: batch_size, nextn
|
|
# next_new_tokens_device: input tokens for the next iteration, device tensor, shape: batch_size, nextn + 1
|
|
|
|
requests = scheduled_requests.all_requests()
|
|
slots = torch.as_tensor([r.py_seq_slot for r in requests])
|
|
slots = slots.to(device="cuda", non_blocking=True)
|
|
|
|
o_new_tokens = outputs['new_tokens'][:len(requests)]
|
|
o_new_tokens_lens = outputs['new_tokens_lens'][:len(requests)]
|
|
o_next_draft_tokens = outputs['next_draft_tokens'][:len(requests)]
|
|
o_next_new_tokens = outputs['next_new_tokens'][:len(requests)]
|
|
|
|
new_tokens = self.store.new_tokens
|
|
next_new_tokens = self.store.next_new_tokens
|
|
new_tokens_lens = self.store.new_tokens_lens
|
|
next_draft_tokens = self.store.next_draft_tokens
|
|
|
|
new_tokens.squeeze(-1).T.index_copy_(0, slots, o_new_tokens)
|
|
next_new_tokens.squeeze(-1).T.index_copy_(0, slots, o_next_new_tokens)
|
|
new_tokens_lens.index_copy_(0, slots, o_new_tokens_lens)
|
|
next_draft_tokens.index_copy_(0, slots, o_next_draft_tokens)
|
|
|
|
device = SampleStateTensorsMTP(
|
|
new_tokens=next_new_tokens,
|
|
new_tokens_lens=new_tokens_lens,
|
|
next_draft_tokens=next_draft_tokens,
|
|
)
|
|
host = SampleStateTensorsMTP(
|
|
new_tokens=new_tokens.to('cpu', non_blocking=True),
|
|
new_tokens_lens=new_tokens_lens.to('cpu', non_blocking=True),
|
|
next_draft_tokens=next_draft_tokens.to('cpu', non_blocking=True),
|
|
)
|
|
sampler_event = torch.cuda.Event()
|
|
sampler_event.record()
|
|
# add dummy draft tokens to context requests to prepare kv cache in advance
|
|
# with the max draft token length
|
|
for request in scheduled_requests.context_requests:
|
|
request.py_draft_tokens = [1] * self.draft_len
|
|
return SampleStateMTP(scheduled_requests=scheduled_requests,
|
|
device=device,
|
|
host=host,
|
|
sampler_event=sampler_event)
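

def _sketch_seq_slot_scatter():
    # Illustrative sketch, not used by the runtime: sample_async above scatters
    # per-request outputs (ordered by batch position) into the persistent store
    # (ordered by seq slot) via index_copy_ on a transposed view. Shapes are
    # made up: 3 slots, nextn + 1 = 2 tokens per step, beam width 1.
    store_new_tokens = torch.zeros((2, 3, 1), dtype=torch.int)  # [max_tokens, seq_slots, beam]
    slots = torch.tensor([2, 0])  # the two requests sit in slots 2 and 0
    outputs = torch.tensor([[11, 12], [21, 22]], dtype=torch.int)  # [num_requests, max_tokens]
    store_new_tokens.squeeze(-1).T.index_copy_(0, slots, outputs)
    # Now store_new_tokens[:, 2, 0] == [11, 12] and store_new_tokens[:, 0, 0] == [21, 22].
    return store_new_tokens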
|
|
|
|
|
|
class MTPWorker(nn.Module):
|
|
|
|
def __init__(self, spec_config: "MTPDecodingConfig", model_config=None):
|
|
super().__init__()
|
|
self.spec_config = spec_config
|
|
self.model_config = model_config
|
|
self.is_thop = False
|
|
self.guided_decoder: Optional[CapturableGuidedDecoder] = None
|
|
self.force_num_accepted_tokens = get_force_num_accepted_tokens()
|
|
|
|
def forward(
|
|
self,
|
|
input_ids,
|
|
position_ids,
|
|
hidden_states,
|
|
logits,
|
|
attn_metadata,
|
|
spec_metadata,
|
|
draft_model,
|
|
):
|
|
'''
|
|
Example:
|
|
Assume there are 3 MTP layers
|
|
Notation:
|
|
- H_t: token t's hidden state, generated by the target model
|
|
- h_t: token t's hidden state, generated by the draft model
|
|
|
|
Prompt: ABCD
|
|
|
|
Context phase:
|
|
Target model:
|
|
- input tokens: ABCD + []
|
|
- sampling tokens: E
|
|
- accepted tokens: E
|
|
- KV cache: ABCD
|
|
- hidden states: H_A, H_B, H_C, H_D
|
|
Draft model:
|
|
MTP1:
|
|
                    # For context requests, prompt[1:] + the newly generated golden token is the input.
|
|
- input tokens: BCDE
|
|
- input hidden states: H_A, H_B, H_C, H_D
|
|
# '()' means historical KV cache
|
|
- KV cache: () + BCDE
|
|
- output hidden states: h_B, h_C, h_D, h_E
|
|
- output next draft token: F
|
|
MTP2:
|
|
- input token: CDEF
|
|
- input hidden states: H_B, H_C, H_D, h_E
|
|
- KV cache: () + CDEF
|
|
- output hidden states: h_C, h_D, h_E, h_F
|
|
- output next draft token: G
|
|
MTP3:
|
|
- input tokens: DEFG
|
|
- input hidden states: H_C, H_D, h_E, h_F
|
|
- KV cache: () + DEFG
|
|
- output hidden states: h_D, h_E, h_F, h_G
|
|
- output next draft token: H
|
|
After 3 MTP layers:
|
|
- new generated draft tokens: FGH
|
|
|
|
Generation phase 1: accept partial draft tokens
|
|
Target model:
|
|
- input tokens: E + FGH
|
|
- sampling tokens: FGXY
|
|
- accepted tokens: FGX
|
|
- KV cache: (ABCD) + EFGH (H's KV cache is invalid)
|
|
- hidden states: H_E, H_F, H_G, H_H (H_H is invalid)
|
|
Draft model:
|
|
MTP1:
|
|
                    # For generation requests, `mtp_num_modules` tokens will be used as input.
|
|
- input tokens: FGX
|
|
- input hidden states: H_E, H_F, H_G
|
|
- KV cache: (BCDE) + FGX
|
|
- output hidden states: h_F, h_G, h_X
|
|
- output next draft token: N
|
|
MTP2:
|
|
- input tokens: GXN
|
|
- input hidden states: H_F, H_G, h_X
|
|
- KV cache: (CDEF) + GXN
|
|
- output hidden states: h_G, h_X, h_N
|
|
- output next draft token: O
|
|
MTP3:
|
|
- input tokens: XNO
|
|
- input hidden states: H_G, H_X, h_N
|
|
- KV cache: (DEFG) + XNO
|
|
- output hidden states: h_X, h_N, h_O
|
|
- output next draft token: P
|
|
After 3 MTP layers:
|
|
- new generated draft tokens: NOP
|
|
|
|
        Generation phase 2: accept no draft tokens
|
|
Target model:
|
|
- input tokens: X + NOP
|
|
- sampling tokens: KMZY
|
|
- accepted tokens: K
|
|
- KV cache: (ABCDEFG) + NOP (NOP's KV cache is invalid)
|
|
                - hidden states: H_X, H_N, H_O, H_P (H_N, H_O, H_P are invalid)
|
|
Draft model:
|
|
MTP1:
|
|
- input tokens: GXK
|
|
- input hidden states: H_F, H_G, H_X
|
|
- KV cache: (BCDE + F) + GXK
|
|
- output hidden states: h_G, h_X, h_K
|
|
- output next draft token: U
|
|
MTP2:
|
|
- input tokens: XKU
|
|
- input hidden states: H_G, H_X, h_K
|
|
- KV cache: (CDEF + G) + XKU
|
|
- output hidden states: h_X, h_K, h_U
|
|
- output next draft token: V
|
|
MTP3:
|
|
- input tokens: KUV
|
|
- input hidden states: H_X, h_K, h_U
|
|
- KV cache: (DEFG + X) + KUV
|
|
- output hidden states: h_K, h_U, h_V
|
|
- output next draft token: Q
|
|
After 3 MTP layers:
|
|
- new generated draft tokens: UVQ
|
|
'''
|
|
|
|
batch_size = attn_metadata.num_seqs
|
|
|
|
raw_logits = logits
|
|
|
|
if self.guided_decoder is not None:
|
|
self.guided_decoder.execute(logits)
|
|
|
|
# Sample and verify draft tokens
|
|
accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens(
|
|
input_ids, logits, spec_metadata, attn_metadata)
|
|
|
|
# Update MTP past hidden states
|
|
self.update_mtp_hidden_states(input_ids=input_ids,
|
|
hidden_states=hidden_states,
|
|
num_accepted_tokens=num_accepted_tokens,
|
|
spec_metadata=spec_metadata,
|
|
attn_metadata=attn_metadata)
|
|
|
|
# prepare draft layer inputs
|
|
position_ids = position_ids.squeeze(0)
|
|
draft_inputs = self.prepare_drafter_inputs(
|
|
input_ids=input_ids,
|
|
position_ids=position_ids,
|
|
hidden_states=hidden_states,
|
|
accepted_tokens=accepted_tokens,
|
|
num_accepted_tokens=num_accepted_tokens,
|
|
spec_metadata=spec_metadata,
|
|
attn_metadata=attn_metadata)
|
|
|
|
# update attn metadata
|
|
if attn_metadata is not None:
|
|
self.change_attn_metadata(num_accepted_tokens, attn_metadata)
|
|
draft_inputs.update(attn_metadata=attn_metadata)
|
|
|
|
# Run MTP layers to predict draft tokens
|
|
next_draft_tokens = []
|
|
last_tokens_idx = torch.cumsum(
|
|
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
|
|
for i, mtp_layer in enumerate(draft_model.mtp_layers):
|
|
if self.guided_decoder is not None:
|
|
new_tokens = draft_inputs['input_ids'][last_tokens_idx]
|
|
self.guided_decoder.add_draft_batch(new_tokens,
|
|
num_accepted_tokens,
|
|
draft_step=i)
|
|
|
|
hidden_states = mtp_layer(embed_tokens=draft_model.embed_tokens,
|
|
**draft_inputs)
|
|
logits = mtp_layer.shared_head(hidden_states, draft_model.lm_head,
|
|
attn_metadata).float()
|
|
if self.guided_decoder is not None:
|
|
self.guided_decoder.execute_draft_batch(logits, draft_step=i)
|
|
|
|
new_draft_token = self.draft_sampler(logits)
|
|
next_draft_tokens.append(new_draft_token)
|
|
# shift input_ids and hidden_states
|
|
input_ids = draft_inputs["input_ids"]
|
|
input_ids[:-1] = input_ids[1:].clone()
|
|
input_ids[last_tokens_idx] = new_draft_token
|
|
draft_hidden_states = draft_inputs["hidden_states"]
|
|
draft_hidden_states[:-1] = draft_hidden_states[1:].clone()
|
|
draft_hidden_states[last_tokens_idx] = hidden_states[
|
|
last_tokens_idx, :]
|
|
draft_inputs = {
|
|
"input_ids": input_ids,
|
|
"position_ids": draft_inputs["position_ids"],
|
|
"hidden_states": draft_hidden_states,
|
|
"attn_metadata": draft_inputs["attn_metadata"],
|
|
}
|
|
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
|
|
|
|
# restore attn metadata
|
|
if attn_metadata is not None:
|
|
self.restore_attn_metadata(attn_metadata=attn_metadata)
|
|
|
|
# prepare next new tokens to support overlap scheduler
|
|
next_new_tokens = accepted_tokens[
|
|
spec_metadata.batch_indices_cuda[:batch_size],
|
|
num_accepted_tokens - 1].unsqueeze(1)
|
|
next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens],
|
|
dim=1)
|
|
|
|
return {
|
|
'logits': raw_logits,
|
|
'new_tokens': accepted_tokens,
|
|
'new_tokens_lens': num_accepted_tokens,
|
|
'next_draft_tokens': next_draft_tokens,
|
|
'next_new_tokens': next_new_tokens
|
|
}
|
|
|
|
def skip_forward(
|
|
self,
|
|
input_ids,
|
|
position_ids,
|
|
hidden_states,
|
|
logits,
|
|
attn_metadata,
|
|
spec_metadata,
|
|
draft_model,
|
|
):
|
|
batch_size = attn_metadata.num_seqs
|
|
mtp_num_modules = self.spec_config.num_nextn_predict_layers
|
|
accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
|
|
dtype=torch.int,
|
|
device=logits.device)
|
|
num_accepted_tokens = torch.ones(batch_size,
|
|
dtype=torch.int,
|
|
device=logits.device)
|
|
next_draft_tokens = torch.empty((batch_size, mtp_num_modules),
|
|
dtype=torch.int,
|
|
device=logits.device)
|
|
next_new_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
|
|
dtype=torch.int,
|
|
device=logits.device)
|
|
return {
|
|
'logits': logits,
|
|
'new_tokens': accepted_tokens,
|
|
'new_tokens_lens': num_accepted_tokens,
|
|
'next_draft_tokens': next_draft_tokens,
|
|
'next_new_tokens': next_new_tokens
|
|
}
|
|
|
|
def update_mtp_hidden_states(
|
|
self,
|
|
input_ids: torch.IntTensor,
|
|
hidden_states: torch.Tensor,
|
|
num_accepted_tokens: torch.Tensor,
|
|
spec_metadata: MTPSpecMetadata,
|
|
attn_metadata: AttentionMetadata,
|
|
):
|
|
'''
|
|
        Update the past hidden states and past tokens in spec_metadata based on
        the newly accepted tokens and historical hidden states.
        These past hidden states and past tokens will be used in the MTP modules.
|
|
|
|
Args:
|
|
input_ids: torch.IntTensor
|
|
[num_tokens]
|
|
The input ids of all requests. Flatten.
|
|
|
|
hidden_states: torch.Tensor
|
|
[num_tokens, hidden_size]
|
|
Target model's hidden states.
|
|
|
|
num_accepted_tokens: torch.Tensor
|
|
[batch_size]
|
|
Number of accepted tokens per request.
|
|
|
|
spec_metadata: MTPSpecMetadata
|
|
MTP speculative decoding metadata
|
|
|
|
attn_metadata: AttentionMetadata
|
|
Attention metadata
|
|
|
|
Returns:
|
|
None
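
        Example (illustrative, with made-up numbers) of the nested
        unpack_sequence helper below: a packed [num_tokens, D] tensor is
        scattered into a zero-padded [num_seqs, max_len, D] tensor using a
        sequence-length mask.

            packed = torch.tensor([[1.], [2.], [3.]])  # seq 0 -> [1, 2], seq 1 -> [3]
            seq_lens = torch.tensor([2, 1])
            unpack_sequence(packed, seq_lens, seq_lens)
            # tensor([[[1.], [2.]],
            #         [[3.], [0.]]])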
|
|
'''
|
|
|
|
def unpack_sequence(packed_seq_cuda, seq_lens_cuda, seq_lens_cpu):
|
|
# max_length is used as tensor shape, so it should be from host;
|
|
# otherwise, an implicit D2H copy will be triggered.
|
|
max_length = seq_lens_cpu.max().item()
|
|
num_sequences = seq_lens_cuda.shape[0]
|
|
# initialize a zero tensor to store the result
|
|
result = torch.zeros(
|
|
(num_sequences, max_length, packed_seq_cuda.shape[1]),
|
|
dtype=packed_seq_cuda.dtype,
|
|
device=packed_seq_cuda.device)
|
|
# get mask
|
|
seq_indices = torch.arange(
|
|
max_length, device=seq_lens_cuda.device).unsqueeze(0).expand(
|
|
num_sequences, -1)
|
|
mask = seq_indices < seq_lens_cuda.unsqueeze(1)
|
|
# unpack
|
|
result[mask] = packed_seq_cuda
|
|
return result
|
|
|
|
batch_size = attn_metadata.num_seqs
|
|
num_contexts = attn_metadata.num_contexts
|
|
num_ctx_tokens = attn_metadata.num_ctx_tokens
|
|
num_gens = batch_size - num_contexts
|
|
seq_lens = attn_metadata.seq_lens_cuda
|
|
seq_lens_cpu = attn_metadata.seq_lens
|
|
hidden_size = hidden_states.shape[-1]
|
|
mtp_num_modules = self.spec_config.num_nextn_predict_layers
|
|
|
|
if self.is_thop:
|
|
_, _ = torch.ops.trtllm.mtp_update_hidden_states_op(
|
|
input_ids, seq_lens, hidden_states,
|
|
spec_metadata.mtp_hidden_states_ptrs,
|
|
spec_metadata.mtp_past_tokens_ptrs, num_accepted_tokens,
|
|
mtp_num_modules, batch_size, num_contexts, hidden_size)
|
|
else:
|
|
assert len(spec_metadata.request_ids) == batch_size
|
|
mtp_past_hidden_states_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool
|
|
mtp_past_tokens_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool
|
|
|
|
slot_ids = spec_metadata.slot_ids[:batch_size]
|
|
mtp_tokens = mtp_past_tokens_pool[slot_ids]
|
|
mtp_hidden_states = mtp_past_hidden_states_pool[slot_ids]
|
|
|
|
new_mtp_past_tokens, new_mtp_past_hidden_states = [], []
|
|
# context
|
|
if num_contexts > 0:
|
|
seq_lens_ctx = seq_lens[:num_contexts]
|
|
seq_lens_ctx_cpu = seq_lens_cpu[:num_contexts]
|
|
unpacked_input_ids_ctx = unpack_sequence(
|
|
input_ids[:num_ctx_tokens].unsqueeze(1), seq_lens_ctx,
|
|
seq_lens_ctx_cpu).squeeze(2)
|
|
unpacked_hidden_states_ctx = unpack_sequence(
|
|
hidden_states[:num_ctx_tokens], seq_lens_ctx,
|
|
seq_lens_ctx_cpu)
|
|
cat_tokens_ctx = torch.cat(
|
|
(mtp_tokens[:num_contexts], unpacked_input_ids_ctx), dim=1)
|
|
cat_hidden_states_ctx = torch.cat(
|
|
(mtp_hidden_states[:num_contexts],
|
|
unpacked_hidden_states_ctx),
|
|
dim=1)
|
|
ctx_batch_idx = spec_metadata.batch_indices_cuda[:num_contexts]
|
|
row_indices_ctx = ctx_batch_idx.unsqueeze(1).expand(
|
|
-1, mtp_num_modules)
|
|
col_indices_ctx = (seq_lens_ctx.unsqueeze(1) +
|
|
spec_metadata.draft_token_indices_cuda)
|
|
new_mtp_past_tokens.append(cat_tokens_ctx[row_indices_ctx,
|
|
col_indices_ctx])
|
|
new_mtp_past_hidden_states.append(
|
|
cat_hidden_states_ctx[row_indices_ctx, col_indices_ctx, :])
|
|
|
|
# generation
|
|
if num_gens > 0:
|
|
unpacked_input_ids_gen = input_ids[num_ctx_tokens:].reshape(
|
|
num_gens, mtp_num_modules + 1).int()
|
|
hidden_states_gen = hidden_states[num_ctx_tokens:, :]
|
|
unpacked_hidden_states_gen = hidden_states_gen.reshape(
|
|
num_gens, mtp_num_modules + 1, hidden_size)
|
|
cat_tokens_gen = torch.cat(
|
|
(mtp_tokens[num_contexts:], unpacked_input_ids_gen), dim=1)
|
|
cat_hidden_states_gen = torch.cat(
|
|
(mtp_hidden_states[num_contexts:],
|
|
unpacked_hidden_states_gen),
|
|
dim=1)
|
|
gen_batch_idx = spec_metadata.batch_indices_cuda[:num_gens]
|
|
row_indices_gen = gen_batch_idx.unsqueeze(1).expand(
|
|
-1, mtp_num_modules)
|
|
col_indices_gen = (
|
|
num_accepted_tokens[num_contexts:].unsqueeze(1) +
|
|
spec_metadata.draft_token_indices_cuda)
|
|
new_mtp_past_tokens.append(cat_tokens_gen[row_indices_gen,
|
|
col_indices_gen])
|
|
new_mtp_past_hidden_states.append(
|
|
cat_hidden_states_gen[row_indices_gen, col_indices_gen, :])
|
|
|
|
# update past tokens and hidden states
|
|
new_mtp_past_tokens = torch.cat(new_mtp_past_tokens, dim=0)
|
|
new_mtp_past_hidden_states = torch.cat(new_mtp_past_hidden_states,
|
|
dim=0)
|
|
mtp_past_tokens_pool.index_copy_(0, slot_ids, new_mtp_past_tokens)
|
|
mtp_past_hidden_states_pool.index_copy_(0, slot_ids,
|
|
new_mtp_past_hidden_states)
|
|
|
|
@torch.compile(options={"max-autotune": True})
|
|
def topk_kernel(self, gen_logprobs, num_gens, mtp_num_modules,
|
|
spec_metadata):
|
|
topk_value, topk_indices = torch.topk(gen_logprobs,
|
|
k=self.spec_config.relaxed_topk,
|
|
dim=-1)
|
|
topk_indices = topk_indices.reshape(num_gens, mtp_num_modules + 1,
|
|
self.spec_config.relaxed_topk)
|
|
topk_value = topk_value.reshape(num_gens, mtp_num_modules + 1,
|
|
self.spec_config.relaxed_topk)
|
|
draft_tokens = spec_metadata.draft_tokens.reshape(
|
|
num_gens, mtp_num_modules)
|
|
return topk_value, topk_indices, draft_tokens
|
|
|
|
@torch.compile(options={"max-autotune": True})
|
|
def process_generation_logits(self, logits, num_contexts):
|
|
gen_logits = logits[num_contexts:]
|
|
gen_logprobs = torch.softmax(gen_logits, dim=-1)
|
|
return gen_logprobs
|
|
|
|
def sample_and_accept_draft_tokens(
|
|
self,
|
|
input_ids: torch.IntTensor,
|
|
logits: torch.Tensor,
|
|
spec_metadata: MTPSpecMetadata,
|
|
attn_metadata: AttentionMetadata,
|
|
):
|
|
'''
|
|
Takes input logits and samples golden token + predictions from draft tokens.
|
|
Runs acceptance algorithm to accept draft tokens.
|
|
        Currently only greedy sampling is supported. All decoding is done using
        Top-1 and token equality is used for acceptance.
|
|
|
|
Args:
|
|
input_ids: torch.IntTensor
|
|
[num_tokens]
|
|
The input ids of all requests. Flatten.
|
|
|
|
logits: torch.Tensor
|
|
[num_tokens, vocab_size]
|
|
Logits produced by the target model.
|
|
|
|
spec_metadata: MTPSpecMetadata
|
|
MTP speculative decoding metadata
|
|
|
|
attn_metadata: AttentionMetadata
|
|
Attention metadata
|
|
|
|
Returns:
|
|
accepted_tokens: torch.Tensor
|
|
[batch_size, (max_draft_len + 1)]
|
|
Accepted token ids. Flattened.
|
|
|
|
num_accepted_tokens: torch.Tensor
|
|
[batch_size]
|
|
Number of accepted tokens per request.
|
|
|
|
Example:
|
|
Assume there are 3 MTP layers
|
|
Prompt: ABCD
|
|
|
|
Context phase:
|
|
Target model:
|
|
- input tokens: ABCD + []
|
|
- sampling tokens: E
|
|
- accepted tokens: E
|
|
Draft model:
|
|
- input tokens: BCDE
|
|
- new generated draft tokens: FGH
|
|
Current sequence: ABCD E`FGH -> Whitespace separates tokens produced by each phase
|
|
-> Backtick separates accepted and draft tokens
|
|
|
|
Generation phase 1:
|
|
Target model:
|
|
- input tokens: E + FGH
|
|
- sampling tokens: FGXY -> Sample with E's logit and get 'F'; Sample with F's logit, ...
|
|
                    - accepted tokens: FGX -> 'X' will be treated as the accepted token
|
|
Draft model:
|
|
- input tokens: FGX
|
|
- new generated draft tokens: PQR
|
|
Current sequence: ABCD EFG X`PQR
|
|
|
|
Generation phase 2:
|
|
Target model:
|
|
- input tokens: X + PQR
|
|
- sampling tokens: PYST
|
|
- accepted token: PY
|
|
Draft model:
|
|
- input tokens: PY
|
|
- new generated draft tokens: UVW
|
|
Current sequence: ABCD EFG XP Y`UVW
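
            A minimal sketch of the strict (greedy) acceptance rule used in the
            non-THOP path below, with made-up token ids:

                draft = torch.tensor([[5, 7, 9]])         # nextn = 3 draft tokens
                target = torch.tensor([[5, 7, 2, 4]])     # nextn + 1 greedy samples
                matches = (draft == target[:, :3]).int()  # [[1, 1, 0]]
                num_accepted = 1 + torch.cumprod(matches, dim=-1).sum(1)  # tensor([3])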
|
|
'''
|
|
|
|
batch_size = attn_metadata.num_seqs
|
|
num_contexts = attn_metadata.num_contexts
|
|
num_gens = batch_size - num_contexts
|
|
mtp_num_modules = self.spec_config.num_nextn_predict_layers
|
|
|
|
if logits.dim() == 1:
|
|
logits = logits.unsqueeze(0)
|
|
|
|
# The return buffer
|
|
if self.spec_config.use_relaxed_acceptance_for_thinking or not self.is_thop:
|
|
accepted_tokens = torch.ones((batch_size, (mtp_num_modules + 1)),
|
|
dtype=torch.int,
|
|
device=logits.device)
|
|
num_accepted_tokens = torch.ones(batch_size,
|
|
dtype=torch.int,
|
|
device=logits.device)
|
|
if self.spec_config.use_relaxed_acceptance_for_thinking:
|
|
mtp_relaxed_delta_pool = spec_metadata.mtp_hidden_states_manager.mtp_relaxed_delta_pool
|
|
|
|
# context
|
|
con_logits = logits[:num_contexts]
|
|
con_target_tokens = torch.argmax(con_logits, dim=-1)
|
|
accepted_tokens[:num_contexts, 0] = con_target_tokens[:num_contexts]
|
|
last_tokens_idx = torch.cumsum(
|
|
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
|
|
ctx_input_ids = input_ids[:attn_metadata.num_ctx_tokens]
|
|
ctx_is_think = (ctx_input_ids ==
|
|
self.spec_config.begin_thinking_phase_token).int()
|
|
ctx_is_think_cumsum = torch.cumsum(ctx_is_think, dim=0)
|
|
ctx_last_cumsum = ctx_is_think_cumsum[
|
|
last_tokens_idx[:num_contexts]]
|
|
ctx_think_tokens_num = torch.diff(
|
|
ctx_last_cumsum,
|
|
dim=0,
|
|
prepend=torch.zeros(1,
|
|
dtype=torch.int,
|
|
device=ctx_last_cumsum.device))
|
|
|
|
ctx_delta = (ctx_think_tokens_num
|
|
>= 1).int() * self.spec_config.relaxed_delta
|
|
ctx_slot_ids = spec_metadata.slot_ids[:num_contexts]
|
|
mtp_relaxed_delta_pool.index_copy_(0, ctx_slot_ids, ctx_delta)
|
|
|
|
# generation
|
|
gen_logprobs = self.process_generation_logits(logits, num_contexts)
|
|
topk_value, topk_indices, draft_tokens = self.topk_kernel(
|
|
gen_logprobs, num_gens, mtp_num_modules, spec_metadata)
|
|
|
|
accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_relaxed_acceptance_op(
|
|
spec_metadata.slot_ids, topk_value, topk_indices, draft_tokens,
|
|
mtp_relaxed_delta_pool, num_accepted_tokens, accepted_tokens,
|
|
mtp_num_modules, batch_size, num_contexts,
|
|
self.spec_config.relaxed_topk, self.spec_config.relaxed_delta,
|
|
self.spec_config.begin_thinking_phase_token,
|
|
self.spec_config.end_thinking_phase_token)
|
|
|
|
# Strict acceptance
|
|
else:
|
|
if self.is_thop:
|
|
# Temporary buffer
|
|
target_tokens_cache = torch.zeros(batch_size *
|
|
(mtp_num_modules + 1),
|
|
dtype=torch.int,
|
|
device=logits.device)
|
|
accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_sampling_and_accepted_draft_tokens_op(
|
|
logits, spec_metadata.draft_tokens, target_tokens_cache,
|
|
mtp_num_modules, batch_size, num_contexts, logits.shape[-1])
|
|
else:
|
|
# Do greedy sampling for the input logits
|
|
target_tokens = torch.argmax(logits, dim=-1)
|
|
|
|
# context
|
|
accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
|
|
|
|
# generation
|
|
gen_target_tokens = target_tokens[num_contexts:].reshape(
|
|
num_gens, mtp_num_modules + 1)
|
|
accepted_tokens[num_contexts:, :] = gen_target_tokens
|
|
draft_tokens = spec_metadata.draft_tokens.reshape(
|
|
num_gens, mtp_num_modules)
|
|
num_accepted_tokens[num_contexts:] += torch.cumprod(
|
|
(draft_tokens == gen_target_tokens[:, :mtp_num_modules]
|
|
).int(),
|
|
dim=-1).sum(1)
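                # cumprod zeroes out every position after the first mismatch, so
                # the sum counts the longest prefix of draft tokens that matches
                # the greedy target tokens.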
|
|
|
|
# Check for environment variable override
|
|
if self.force_num_accepted_tokens != 0:
|
|
# total tokens per iteration = accepted draft tokens + 1 target token
|
|
force_total_tokens = min(self.force_num_accepted_tokens + 1,
|
|
mtp_num_modules + 1)
|
|
num_accepted_tokens[num_contexts:] = force_total_tokens
|
|
|
|
return accepted_tokens, num_accepted_tokens
|
|
|
|
def change_attn_metadata(self, num_accepted_tokens: torch.Tensor,
|
|
attn_metadata: AttentionMetadata):
|
|
attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda")
|
|
batch_size = attn_metadata.num_seqs
|
|
mtp_num_modules = self.spec_config.num_nextn_predict_layers
|
|
|
|
num_contexts = attn_metadata.num_contexts
|
|
attn_metadata._seq_lens[num_contexts:batch_size] -= 1
|
|
attn_metadata._seq_lens_cuda[num_contexts:batch_size] -= 1
|
|
attn_metadata.on_update()
|
|
|
|
if hasattr(attn_metadata, 'kv_lens_cuda'):
|
|
            # Note that it's important not to free the seq_lens_cuda buffer once
            # the graph has been captured - doing so would invalidate the graph
            # and force an expensive recapture.
|
|
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
|
|
mtp_num_modules + 1 -
|
|
num_accepted_tokens[num_contexts:batch_size])
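            # e.g. with mtp_num_modules = 3 and num_accepted_tokens = 2, the last
            # (3 + 1 - 2) = 2 positions written in the previous iteration were
            # rejected, so their KV entries are rolled back.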
|
|
|
|
if attn_metadata.kv_cache_params is not None and not attn_metadata.is_cuda_graph:
|
|
for i in range(num_contexts, batch_size):
|
|
# used for vanilla MLA, list on cpu
|
|
attn_metadata.kv_cache_params.num_cached_tokens_per_seq[
|
|
i] -= mtp_num_modules + 1 - num_accepted_tokens[i].item()
|
|
|
|
def restore_attn_metadata(self, attn_metadata: AttentionMetadata):
|
|
attn_metadata.restore_from_spec_dec()
|
|
attn_metadata.on_update()
|
|
|
|
def prepare_drafter_inputs(
|
|
self,
|
|
input_ids: torch.IntTensor,
|
|
position_ids: torch.IntTensor,
|
|
hidden_states: torch.Tensor,
|
|
accepted_tokens: torch.Tensor,
|
|
num_accepted_tokens: torch.Tensor,
|
|
spec_metadata: MTPSpecMetadata,
|
|
attn_metadata: AttentionMetadata,
|
|
):
|
|
'''
|
|
        Prepare the inputs of the draft model.
|
|
|
|
Args:
|
|
input_ids: torch.IntTensor
|
|
[num_tokens]
|
|
The input ids of all requests. Flatten.
|
|
num_tokens = sum(all prompts) + num_generation * (mtp_num_modules + 1)
|
|
|
|
position_ids: torch.IntTensor
|
|
[1][num_tokens]
|
|
The position id of all requests. Flatten.
|
|
|
|
hidden_states: torch.Tensor
|
|
[num_tokens, hidden_size]
|
|
Target model's hidden states.
|
|
|
|
accepted_tokens: torch.Tensor
|
|
[batch_size, max_draft_len + 1]
|
|
Accepted token ids. Flattened.
|
|
|
|
num_accepted_tokens: torch.Tensor
|
|
[batch_size]
|
|
Number of accepted draft tokens. Will be used for the first MTP layer.
|
|
|
|
spec_metadata: MTPSpecMetadata
|
|
MTP speculative decoding metadata
|
|
|
|
attn_metadata: AttentionMetadata
|
|
Attention metadata
|
|
|
|
Returns: draft_inputs
|
|
input_ids: torch.Tensor
|
|
[num_tokens]
|
|
The new input ids of all requests. Flatten.
|
|
num_tokens = sum(all prompts) + num_generation * (mtp_num_modules)
|
|
|
|
position_ids: torch.Tensor
|
|
                [1][num_tokens]
|
|
The new position ids of all requests. Flatten.
|
|
Directly use the input position ids.
|
|
|
|
hidden_states: torch.Tensor
|
|
[num_tokens][hidden_size]
|
|
Continuous hidden states buffer.
|
|
|
|
attn_metadata: AttentionMetadata
|
|
Attention metadata
|
|
|
|
spec_metadata: MTPSpecMetadata
|
|
MTP speculative decoding metadata
|
|
|
|
'''
|
|
batch_size = attn_metadata.num_seqs
|
|
num_contexts = attn_metadata.num_contexts
|
|
num_ctx_tokens = attn_metadata.num_ctx_tokens
|
|
num_gens = batch_size - num_contexts
|
|
mtp_past_hidden_states_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool
|
|
mtp_past_tokens_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool
|
|
mtp_num_modules = self.spec_config.num_nextn_predict_layers
|
|
|
|
if self.is_thop:
|
|
# Temporary buffer
|
|
hidden_size = hidden_states.shape[-1]
|
|
|
|
# generation requests' golden tokens
|
|
num_tokens = input_ids.shape[0] - num_gens
|
|
return_input_ids = torch.empty(num_tokens,
|
|
dtype=torch.int,
|
|
device="cuda")
|
|
|
|
return_hidden_states = torch.empty((num_tokens, hidden_size),
|
|
dtype=hidden_states.dtype,
|
|
device="cuda")
|
|
|
|
(return_input_ids, return_hidden_states
|
|
) = torch.ops.trtllm.mtp_prepare_drafter_inputs_op(
|
|
input_ids, attn_metadata.seq_lens_cuda,
|
|
spec_metadata.mtp_hidden_states_ptrs,
|
|
spec_metadata.mtp_past_tokens_ptrs, hidden_states,
|
|
accepted_tokens, num_accepted_tokens, return_input_ids,
|
|
return_hidden_states, mtp_num_modules, batch_size,
|
|
num_contexts, hidden_size)
|
|
|
|
else:
|
|
return_input_ids_list = []
|
|
return_hidden_states_list = []
|
|
# Calculate cumulative sequence lengths for indexing
|
|
last_tokens_idx = torch.cumsum(
|
|
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
|
|
# context
|
|
if num_contexts > 0:
|
|
hidden_states_ctx = hidden_states[:num_ctx_tokens, :]
|
|
input_prompt_ids = input_ids[:num_ctx_tokens]
|
|
input_ids_ctx = torch.empty_like(input_prompt_ids,
|
|
dtype=torch.int32,
|
|
device="cuda")
|
|
input_ids_ctx[:-1].copy_(input_prompt_ids[1:])
|
|
input_ids_ctx[last_tokens_idx[:num_contexts]] = \
|
|
accepted_tokens[:num_contexts, 0]
|
|
return_input_ids_list.append(input_ids_ctx)
|
|
return_hidden_states_list.append(hidden_states_ctx)
|
|
# generation
|
|
if num_gens > 0:
|
|
slot_ids = spec_metadata.slot_ids[num_contexts:batch_size]
|
|
gen_batch_idx = spec_metadata.batch_indices_cuda[:num_gens]
|
|
gen_token_idx = num_accepted_tokens[num_contexts:] - 1
|
|
accepted_tokens_gen = accepted_tokens[num_contexts:, :]
|
|
input_ids_gen = accepted_tokens_gen[gen_batch_idx,
|
|
gen_token_idx].unsqueeze(1)
|
|
input_ids_gen = torch.concat(
|
|
[mtp_past_tokens_pool[slot_ids][:, 1:], input_ids_gen],
|
|
dim=1)
|
|
hidden_states_gen = mtp_past_hidden_states_pool[
|
|
slot_ids].flatten(0, 1)
|
|
return_input_ids_list.append(input_ids_gen.flatten(0, 1))
|
|
return_hidden_states_list.append(hidden_states_gen)
|
|
# Concatenate into continuous buffers
|
|
return_input_ids = torch.concat(return_input_ids_list, dim=0)
|
|
return_hidden_states = torch.concat(return_hidden_states_list,
|
|
dim=0)
|
|
|
|
# update position_ids
|
|
position_ids_list = []
|
|
if num_contexts > 0:
|
|
position_ids_list.append(position_ids[:num_ctx_tokens])
|
|
if num_gens > 0:
|
|
position_ids_gen = position_ids[num_ctx_tokens:].reshape(
|
|
num_gens, mtp_num_modules + 1)[:, -mtp_num_modules:]
|
|
position_ids_gen = position_ids_gen - (
|
|
1 + mtp_num_modules -
|
|
num_accepted_tokens[num_contexts:].unsqueeze(1))
|
|
position_ids_list.append(position_ids_gen.flatten())
|
|
return_position_ids = torch.concat(position_ids_list, dim=-1)
|
|
|
|
return {
|
|
"input_ids": return_input_ids,
|
|
"position_ids": return_position_ids,
|
|
"hidden_states": return_hidden_states,
|
|
"attn_metadata": attn_metadata,
|
|
}
|
|
|
|
@torch.compile(options={"max-autotune": True})
|
|
def get_local_max_and_combined(self, logits, mapping_lm_tp=None):
|
|
local_max_values, local_argmax = torch.max(logits, dim=-1, keepdim=True)
|
|
# Adjust indices based on TP rank and size
|
|
vocab_per_rank = logits.shape[-1]
|
|
mapping_lm_tp = mapping_lm_tp if mapping_lm_tp is not None else self.model_config.mapping
|
|
max_index_per_rank = local_argmax.type(
|
|
torch.int32) + (mapping_lm_tp.tp_rank * vocab_per_rank)
|
|
# Use torch.stack and flatten instead of view+cat to avoid torch.compile issues
|
|
# Convert both to float32 to ensure consistent dtype
|
|
max_index_per_rank_float = max_index_per_rank.float()
|
|
local_max_values_float32 = local_max_values.float()
|
|
|
|
# Stack and flatten to get interleaved layout: [idx0, val0, idx1, val1, ...]
|
|
combined = torch.stack(
|
|
[max_index_per_rank_float, local_max_values_float32],
|
|
dim=-1).flatten(-2)
|
|
return combined
|
|
|
|
@torch.compile(options={"max-autotune": True})
|
|
def get_draft_tokens_from_gathered(self, gathered):
|
|
gathered_indices_float = gathered[..., 0::2] # Even positions: indices
|
|
gathered_values_float = gathered[..., 1::2] # Odd positions: values
|
|
|
|
# Find the rank with maximum value
|
|
max_indices = torch.argmax(gathered_values_float, dim=-1, keepdim=True)
|
|
|
|
# Get the corresponding token indices and convert back to int32
|
|
draft_tokens = torch.gather(gathered_indices_float, -1,
|
|
max_indices).squeeze(-1).type(torch.int32)
|
|
return draft_tokens
|
|
|
|
def draft_sampler(
|
|
self,
|
|
logits: torch.Tensor,
|
|
mapping_lm_head_tp: Mapping = None,
|
|
):
|
|
'''
|
|
Sampling draft tokens.
|
|
|
|
Args:
|
|
logits: torch.Tensor
|
|
[num_tokens, vocab_size]
|
|
Logits produced by the draft model.
|
|
|
|
Returns:
|
|
draft_tokens: torch.Tensor
|
|
[batch_size * max_draft_len]
|
|
Draft token ids. Flattened.
|
|
'''
|
|
if (self.model_config is not None
|
|
and hasattr(self.model_config, 'mapping')
|
|
and self.model_config.mapping.tp_size
|
|
> 1) and not (self.model_config.mapping.enable_attention_dp):
|
|
combined = self.get_local_max_and_combined(logits)
|
|
gathered = allgather(combined, self.model_config.mapping, dim=-1)
|
|
draft_tokens = self.get_draft_tokens_from_gathered(gathered)
|
|
elif (self.model_config is not None
|
|
and hasattr(self.model_config, 'mapping')
|
|
and self.model_config.mapping.tp_size
|
|
> 1) and self.model_config.mapping.enable_lm_head_tp_in_adp:
|
|
# For ADP + LM head TP mode, we need to find the global argmax across all TP ranks
|
|
combined = self.get_local_max_and_combined(logits,
|
|
mapping_lm_head_tp)
|
|
gathered = allgather(combined, mapping_lm_head_tp, dim=-1)
|
|
batch_size = logits.shape[0]
|
|
local_batch_size = batch_size // mapping_lm_head_tp.tp_size
|
|
gathered = gathered.view(mapping_lm_head_tp.tp_size,
|
|
local_batch_size, -1)
|
|
sliced_gathered = gathered[mapping_lm_head_tp.tp_rank]
|
|
draft_tokens = self.get_draft_tokens_from_gathered(sliced_gathered)
|
|
else:
|
|
# Simple argmax if no TP or no model config
|
|
draft_tokens = torch.argmax(logits, dim=-1).type(torch.int32)
|
|
|
|
return draft_tokens
|
|
|
|
def set_guided_decoder(self,
|
|
guided_decoder: CapturableGuidedDecoder) -> bool:
|
|
self.guided_decoder = guided_decoder
|
|
return True
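

def _sketch_tp_argmax():
    # Illustrative sketch, not used by the runtime: the TP path of draft_sampler
    # above interleaves each rank's (local argmax index, local max value) pairs,
    # allgathers them, and keeps the index from the rank that owns the globally
    # largest value. The allgather is simulated here by concatenating two fake
    # per-rank shards of an 8-token vocabulary.
    logits = torch.tensor([[0.1, 0.2, 0.9, 0.3, 0.8, 0.4, 0.0, 0.5]])
    shards = logits.chunk(2, dim=-1)  # rank 0 owns ids 0..3, rank 1 owns ids 4..7
    combined = []
    for tp_rank, shard in enumerate(shards):
        local_max, local_argmax = torch.max(shard, dim=-1, keepdim=True)
        global_index = local_argmax.float() + tp_rank * shard.shape[-1]
        combined.append(
            torch.stack([global_index, local_max], dim=-1).flatten(-2))
    gathered = torch.cat(combined, dim=-1)  # what allgather would return
    indices = gathered[..., 0::2]           # even positions: global indices
    values = gathered[..., 1::2]            # odd positions: local max values
    best_rank = torch.argmax(values, dim=-1, keepdim=True)
    draft_token = torch.gather(indices, -1, best_rank).squeeze(-1).int()
    assert draft_token.item() == 2          # global argmax of the full vocab
    return draft_token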
|
|
|
|
|
|
class MTPEagleWorker(MTPWorker):
|
|
|
|
def __init__(self,
|
|
spec_config: "MTPDecodingConfig",
|
|
model_config: Optional[ModelConfig] = None):
|
|
super().__init__(spec_config, model_config)
|
|
self.model_config = model_config
|
|
self.mtp_num_modules = spec_config.num_nextn_predict_layers
|
|
|
|
@torch.compile(options={"max-autotune": True})
|
|
def update_draft_tokens(self, next_draft_tokens, new_draft_token,
|
|
hidden_states, gather_ids, inputs):
|
|
next_draft_tokens.append(new_draft_token)
|
|
# update inputs
|
|
hidden_states = hidden_states[gather_ids]
|
|
position_ids = inputs["position_ids"][gather_ids] + 1
|
|
return hidden_states, position_ids
|
|
|
|
def forward(
|
|
self,
|
|
input_ids,
|
|
position_ids,
|
|
hidden_states,
|
|
logits,
|
|
attn_metadata,
|
|
spec_metadata,
|
|
draft_model,
|
|
):
|
|
batch_size = attn_metadata.num_seqs
|
|
num_contexts = attn_metadata.num_contexts
|
|
num_gens = batch_size - num_contexts
|
|
|
|
raw_logits = logits
|
|
|
|
if self.guided_decoder is not None:
|
|
self.guided_decoder.execute(logits)
|
|
|
|
# Sample and verify draft tokens
|
|
accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens(
|
|
input_ids, logits, spec_metadata, attn_metadata)
|
|
|
|
# Save the old attn_metadata and spec_metadata
|
|
attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda")
|
|
|
|
# Prepare inputs for the 1st MTP layer
|
|
@torch.compile(options={"max-autotune": True})
|
|
def prepare_position_ids_and_last_tokens(position_ids, attn_metadata):
|
|
position_ids = position_ids.squeeze(0)
|
|
last_tokens_idx = torch.cumsum(
|
|
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
|
|
return position_ids, last_tokens_idx
|
|
|
|
position_ids, last_tokens_idx = prepare_position_ids_and_last_tokens(
|
|
position_ids, attn_metadata)
|
|
inputs = self.prepare_drafter_inputs(input_ids=input_ids,
|
|
position_ids=position_ids,
|
|
last_tokens_idx=last_tokens_idx,
|
|
hidden_states=hidden_states,
|
|
accepted_tokens=accepted_tokens,
|
|
attn_metadata=attn_metadata,
|
|
spec_metadata=spec_metadata)
|
|
|
|
# Predict draft tokens
|
|
next_draft_tokens = []
|
|
for i in range(self.mtp_num_modules):
|
|
if i == 0:
|
|
hidden_states = draft_model.mtp_layers[0](
|
|
embed_tokens=draft_model.embed_tokens,
|
|
all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
|
|
**inputs)
|
|
start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] *
|
|
(self.mtp_num_modules + 1)).long()
|
|
gather_ids_gen = (start_ids_gen +
|
|
num_accepted_tokens[num_contexts:] - 1 +
|
|
attn_metadata.num_ctx_tokens)
|
|
gather_ids = torch.concat(
|
|
[last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
|
|
else:
|
|
hidden_states = draft_model.mtp_layers[0](
|
|
embed_tokens=draft_model.embed_tokens,
|
|
all_rank_num_tokens=spec_metadata.
|
|
subseq_all_rank_num_tokens,
|
|
**inputs)
|
|
                # All of the seq_lens are 1, so use batch_indices_cuda as gather_ids
|
|
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
|
|
|
|
if self.guided_decoder is not None:
|
|
new_tokens = inputs["input_ids"][gather_ids]
|
|
self.guided_decoder.add_draft_batch(new_tokens,
|
|
num_accepted_tokens,
|
|
draft_step=i)
|
|
if self.model_config.mapping.enable_attention_dp and \
|
|
getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
|
|
hidden_states_gathered = hidden_states[gather_ids]
|
|
token_count = hidden_states_gathered.view(
|
|
-1, hidden_states_gathered.shape[-1]).shape[0]
|
|
max_num_requests = spec_metadata.max_num_requests
|
|
pad_len = max_num_requests - token_count
|
|
if pad_len > 0:
|
|
padded_hidden_states = F.pad(hidden_states_gathered.view(
|
|
-1, hidden_states_gathered.shape[-1]),
|
|
(0, 0, 0, pad_len),
|
|
mode="constant",
|
|
value=0)
|
|
elif pad_len == 0:
|
|
padded_hidden_states = hidden_states_gathered.view(
|
|
-1, hidden_states_gathered.shape[-1])
|
|
else:
|
|
                    raise ValueError(
                        "In MTPEagleWorker.forward(), token_count > max_num_requests, which is not supported"
                    )
|
|
logits = draft_model.mtp_layers[0].shared_head(
|
|
padded_hidden_states, draft_model.lm_head, attn_metadata,
|
|
True)
|
|
else:
|
|
logits = draft_model.mtp_layers[0].shared_head(
|
|
hidden_states[gather_ids], draft_model.lm_head,
|
|
attn_metadata, True)
|
|
if self.guided_decoder is not None:
|
|
self.guided_decoder.execute_draft_batch(logits, draft_step=i)
|
|
|
|
if self.model_config.mapping.enable_attention_dp and \
|
|
getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
|
|
mapping_lm_head_tp = draft_model.mtp_layers[
|
|
0].shared_head.mapping_lm_head_tp
|
|
new_draft_token = self.draft_sampler(logits, mapping_lm_head_tp)
|
|
new_draft_token = new_draft_token[:token_count]
|
|
else:
|
|
new_draft_token = self.draft_sampler(logits)
|
|
|
|
hidden_states, position_ids = self.update_draft_tokens(
|
|
next_draft_tokens, new_draft_token, hidden_states, gather_ids,
|
|
inputs)
|
|
# update attn_metadata
|
|
if i == 0:
|
|
attn_metadata._seq_lens[:batch_size].fill_(1)
|
|
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
|
|
attn_metadata.on_update()
|
|
                # cannot run generation if there is no kv cache
|
|
has_kv_cache = inputs[
|
|
"attn_metadata"].kv_cache_manager is not None
|
|
if has_kv_cache:
|
|
attn_metadata.host_request_types[:attn_metadata.
|
|
num_contexts].fill_(1)
|
|
attn_metadata.num_contexts = 0
|
|
# update kv_lens_cuda
|
|
if hasattr(attn_metadata, 'kv_lens_cuda'):
|
|
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
|
|
self.mtp_num_modules -
|
|
num_accepted_tokens[num_contexts:])
|
|
attn_metadata.kv_lens_cuda[:num_contexts] += 1
|
|
# update metadata for flash mla
|
|
if has_kv_cache and num_contexts > 0 and attn_metadata.enable_flash_mla:
|
|
reorder_block_ids_per_seq = torch.cat([
|
|
attn_metadata.
|
|
kv_block_ids_per_seq[num_contexts:batch_size],
|
|
attn_metadata.kv_block_ids_per_seq[:num_contexts]
|
|
])
|
|
attn_metadata.block_ids_per_seq[:batch_size, :].copy_(
|
|
reorder_block_ids_per_seq, non_blocking=True)
|
|
# update metadata
|
|
# some attention metadata needs to be updated when changing seq_lens/kv_lens
|
|
attn_metadata.update_for_spec_dec()
|
|
elif hasattr(attn_metadata, 'kv_lens_cuda'):
|
|
|
|
@torch.compile(options={"max-autotune": True})
|
|
def update_kv_lens(kv_lens_cuda, batch_size):
|
|
kv_lens_cuda[:batch_size] += 1
|
|
|
|
update_kv_lens(attn_metadata.kv_lens_cuda, batch_size)
|
|
# update metadata
|
|
# some attention metadata needs to be updated when changing kv_lens
|
|
attn_metadata.update_for_spec_dec()
|
|
inputs = {
|
|
"input_ids": new_draft_token,
|
|
"position_ids": position_ids,
|
|
"hidden_states": hidden_states,
|
|
"attn_metadata": attn_metadata,
|
|
}
|
|
|
|
# restore attn_metadata to support cuda graph
|
|
attn_metadata.restore_from_spec_dec()
|
|
attn_metadata.on_update()
|
|
|
|
@torch.compile(options={"max-autotune": True})
|
|
def prepare_next_tokens(next_draft_tokens, accepted_tokens,
|
|
spec_metadata, batch_size, num_accepted_tokens):
|
|
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
|
|
# prepare next new tokens to support overlap scheduler
|
|
next_new_tokens = accepted_tokens[
|
|
spec_metadata.batch_indices_cuda[:batch_size],
|
|
num_accepted_tokens - 1].unsqueeze(1)
|
|
next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens],
|
|
dim=1)
|
|
return next_draft_tokens, next_new_tokens
|
|
|
|
next_draft_tokens, next_new_tokens = prepare_next_tokens(
|
|
next_draft_tokens, accepted_tokens, spec_metadata, batch_size,
|
|
num_accepted_tokens)
|
|
|
|
return {
|
|
'logits': raw_logits,
|
|
'new_tokens': accepted_tokens,
|
|
'new_tokens_lens': num_accepted_tokens,
|
|
'next_draft_tokens': next_draft_tokens,
|
|
'next_new_tokens': next_new_tokens
|
|
}
|
|
|
|
@torch.compile(options={"max-autotune": True})
|
|
def prepare_drafter_inputs(
|
|
self,
|
|
input_ids: torch.IntTensor,
|
|
position_ids: torch.IntTensor,
|
|
last_tokens_idx: torch.LongTensor,
|
|
hidden_states: torch.Tensor,
|
|
accepted_tokens: torch.Tensor,
|
|
attn_metadata: AttentionMetadata,
|
|
spec_metadata: MTPSpecMetadata,
|
|
):
|
|
num_contexts = attn_metadata.num_contexts
|
|
|
|
# context
|
|
input_prompt_ids = input_ids[:attn_metadata.num_ctx_tokens]
|
|
input_ids_ctx = torch.empty_like(input_prompt_ids,
|
|
dtype=torch.int32,
|
|
device="cuda")
|
|
input_ids_ctx[:-1].copy_(input_prompt_ids[1:])
|
|
input_ids_ctx[
|
|
last_tokens_idx[:num_contexts]] = accepted_tokens[:num_contexts, 0]
|
|
|
|
# generation
|
|
input_ids_gen = accepted_tokens[num_contexts:, :].flatten()
|
|
|
|
# get draft inputs
|
|
input_ids = torch.concat([input_ids_ctx, input_ids_gen], dim=0)
|
|
|
|
return {
|
|
"input_ids": input_ids,
|
|
"position_ids": position_ids,
|
|
"hidden_states": hidden_states,
|
|
"attn_metadata": attn_metadata,
|
|
}
|
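

def _sketch_eagle_gather_ids():
    # Illustrative sketch, not used by the runtime: how MTPEagleWorker.forward
    # picks the rows to sample from after the first draft forward. Made-up
    # batch: one context request with a 4-token prompt followed by two
    # generation requests, each contributing (mtp_num_modules + 1) = 4 tokens.
    mtp_num_modules = 3
    num_contexts = 1
    num_ctx_tokens = 4
    seq_lens = torch.tensor([4, 4, 4])
    num_accepted_tokens = torch.tensor([1, 2, 1])  # [context, gen0, gen1]
    batch_indices = torch.arange(2)                # generation requests only

    last_tokens_idx = torch.cumsum(seq_lens, dim=0) - 1    # [3, 7, 11]
    start_ids_gen = batch_indices * (mtp_num_modules + 1)  # [0, 4]
    gather_ids_gen = (start_ids_gen + num_accepted_tokens[num_contexts:] - 1 +
                      num_ctx_tokens)                      # [5, 8]
    gather_ids = torch.concat([last_tokens_idx[:num_contexts], gather_ids_gen])
    # -> tensor([3, 5, 8]): the prompt's last token plus each generation
    #    request's last accepted token.
    return gather_ids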