from dataclasses import dataclass from typing import TYPE_CHECKING, List, Optional import torch import torch.nn.functional as F from tensorrt_llm.mapping import Mapping from ..attention_backend import AttentionMetadata from ..distributed.ops import allgather from ..model_config import ModelConfig from ..pyexecutor.llm_request import LlmRequest, LlmRequestState from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager from ..pyexecutor.sampler import (DEFAULT_BEAM_IDX, SampleState, SampleStateTensors, TorchSampler, add_token, int_tensor) from ..pyexecutor.scheduler import ScheduledRequests from .interface import SpecMetadata, SpecWorkerBase if TYPE_CHECKING: from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig @dataclass(kw_only=True) class SampleStateTensorsMTP(SampleStateTensors): new_tokens_lens: torch.Tensor next_draft_tokens: torch.Tensor @dataclass(kw_only=True) class SampleStateMTP(SampleState): device: SampleStateTensorsMTP host: SampleStateTensorsMTP class MTPHiddenStatesManager(BaseResourceManager): def __init__(self, config: "MTPDecodingConfig", dtype: torch.dtype, hidden_size: int, max_num_requests: int): self.dtype = dtype self.num_nextn_predict_layers = config.num_nextn_predict_layers self.hidden_size = hidden_size self.max_num_requests = max_num_requests self.use_relaxed_acceptance_for_thinking = config.use_relaxed_acceptance_for_thinking self.slot_manager = SlotManager(max_num_requests) # Since golden token's hidden state will always be generated after target model self.mtp_past_hidden_states_pool = torch.zeros( (max_num_requests, self.num_nextn_predict_layers, self.hidden_size), device='cuda', dtype=self.dtype, ) self.mtp_past_tokens_pool = torch.zeros( (max_num_requests, self.num_nextn_predict_layers), device='cuda', dtype=torch.int, ) if self.use_relaxed_acceptance_for_thinking: # The relaxed_delta for relaxed acceptance self.mtp_relaxed_delta_pool = torch.zeros( (self.max_num_requests), dtype=torch.float, device='cuda', ) def prepare_resources(self, scheduled_batch: ScheduledRequests): context_batch = scheduled_batch.context_requests # allocate hidden state tensors for req in context_batch: if req.is_first_context_chunk: slot_id = self.slot_manager.add_slot(req.request_id) if self.use_relaxed_acceptance_for_thinking: self.mtp_relaxed_delta_pool[slot_id].copy_( 0, non_blocking=True) def update_resources(self, scheduled_batch: ScheduledRequests): pass def free_resources(self, request: LlmRequest): free_slot_id = self.slot_manager.get_slot(request.request_id) if self.use_relaxed_acceptance_for_thinking: self.mtp_relaxed_delta_pool[free_slot_id].copy_(0, non_blocking=True) self.slot_manager.remove_slot(request.request_id) def add_dummy_requests(self, request_ids: List[int]): for rid in request_ids: self.slot_manager.add_slot(rid) def shutdown(self): self.slot_manager.shutdown() def get_max_resource_count(self) -> int: return self.max_num_requests def get_needed_resource_to_completion(self, request: LlmRequest): return 0 @dataclass class MTPSpecMetadata(SpecMetadata): """ Metadata for MTP. """ # The number of MTP modules in the model mtp_num_modules: int = 1 # The hidden states manager for MTP mtp_hidden_states_manager: Optional[MTPHiddenStatesManager] = None # The slot ids for each request. slot_ids: Optional[torch.Tensor] = None # The index of the batche inputs batch_indices_cuda: Optional[torch.Tensor] = None # The number of sequences for speculative model/layer of different rank _all_rank_num_seqs: Optional[List[int]] = None # This is used for attention dp in the MTP Eagle worker. The numbers of input # tokens varies between the 1st draft forward and subsequent ones. To support # CUDA graph, we use this tensor to store the number of input tokens for the # subsequence draft forward. subseq_all_rank_num_tokens: Optional[List[int]] = None def __post_init__(self) -> None: if self.mtp_hidden_states_manager is not None: # mtp_hidden_states_ptrs is a pointer tensor self.mtp_hidden_states_ptrs = torch.empty( [self.max_num_requests], dtype=torch.int64, device='cuda', ) self.mtp_past_tokens_ptrs = torch.empty( [self.max_num_requests], dtype=torch.int64, device='cuda', ) self.slot_ids = torch.empty( [self.max_num_requests], dtype=torch.long, device='cuda', ) self.batch_indices_cuda = torch.empty( [self.max_num_requests], dtype=torch.int, device='cuda', ) self.draft_token_indices_cuda = torch.arange( self.mtp_num_modules, device='cuda', ) @property def all_rank_num_seqs(self): return self._all_rank_num_seqs @all_rank_num_seqs.setter def all_rank_num_seqs(self, value: List[int]): self._all_rank_num_seqs = value if self.spec_dec_mode.is_mtp_eagle_one_model(): self.subseq_all_rank_num_tokens = value def prepare(self): assert self.request_ids is not None num_seqs = len(self.request_ids) # update batch indeices batch_indices = torch.arange(num_seqs, dtype=torch.int, device='cpu', pin_memory=True) self.batch_indices_cuda[:num_seqs].copy_(batch_indices, non_blocking=True) # MTP vanilla worker uses total max_draft_len input tokens in generation phase, # while MTP Eagle worker uses (max_draft_len + 1) input tokens in the 1st draft # forward and only one input token in the following draft forward. # This num_tokens is used to set the all_rank_num_tokens for attention dp. if not self.spec_dec_mode.is_mtp_eagle_one_model(): self.num_tokens -= self.num_generations if self.mtp_hidden_states_manager is not None: # MTP vanilla or use relaxed acceptance mtp_slot_ids = [] for rid in self.request_ids: slot_id = self.mtp_hidden_states_manager.slot_manager.get_slot( rid) mtp_slot_ids.append(slot_id) # MTP Vanilla: Update mtp hidden states and past tokens if self.spec_dec_mode.is_mtp_one_model(): mtp_hidden_states_ptrs = [] mtp_past_tokens_ptrs = [] for slot_id in mtp_slot_ids: mtp_hidden_states_ptrs.append( self.mtp_hidden_states_manager. mtp_past_hidden_states_pool[slot_id].data_ptr()) mtp_past_tokens_ptrs.append( self.mtp_hidden_states_manager. mtp_past_tokens_pool[slot_id].data_ptr()) mtp_hidden_states_ptrs = torch.tensor(mtp_hidden_states_ptrs, dtype=torch.int64, pin_memory=True) mtp_past_tokens_ptrs = torch.tensor(mtp_past_tokens_ptrs, dtype=torch.int64, pin_memory=True) self.mtp_hidden_states_ptrs[:num_seqs].copy_( mtp_hidden_states_ptrs, non_blocking=True) self.mtp_past_tokens_ptrs[:num_seqs].copy_(mtp_past_tokens_ptrs, non_blocking=True) mtp_slot_ids = torch.tensor(mtp_slot_ids, dtype=torch.int, pin_memory=True) self.slot_ids[:num_seqs].copy_(mtp_slot_ids, non_blocking=True) class MTPSampler(TorchSampler): """ MTP sampler. """ SampleState = SampleStateMTP @dataclass(frozen=True, kw_only=True) class Store(TorchSampler.Store): new_tokens: torch.Tensor next_new_tokens: torch.Tensor next_draft_tokens: torch.Tensor new_tokens_lens: torch.Tensor max_total_draft_tokens: torch.Tensor finish_reasons: None = None # Necessary to satisfy the interface of TorchSampler.Store def __post_init__(self): pass # finish_reasons has no size to compare against new_tokens in MTPSampler def __init__(self, args: TorchSampler.Args, *, nextn: int): self.mapping = None self.draft_len = nextn self.max_seq_len = args.max_seq_len seq_slots = args.max_num_sequences max_tokens = args.max_total_draft_tokens + 1 self.max_beam_width = args.max_beam_width assert self.max_beam_width == 1, "beam width must be 1 for MTP" self.store = self.Store( new_tokens=int_tensor((max_tokens, seq_slots, self.max_beam_width)), next_new_tokens=int_tensor( (max_tokens, seq_slots, self.max_beam_width)), next_draft_tokens=int_tensor( (seq_slots, args.max_total_draft_tokens)), new_tokens_lens=int_tensor((seq_slots, )), max_total_draft_tokens=int_tensor( (seq_slots, args.max_total_draft_tokens)), ) def _request_common_handling(self, request: LlmRequest, next_draft_tokens: list[list[int]]): assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler" assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler" assert not request.py_return_log_probs, "return_log_probs not implemented for MTPSampler" request.py_draft_tokens = next_draft_tokens[request.py_seq_slot] request.py_decoding_iter += 1 def update_requests( self, state: SampleStateMTP, resource_manager: Optional[BaseResourceManager] = None) -> None: # resource_manager will be not be used in this function assert isinstance(state, SampleStateMTP) state.sampler_event.synchronize() new_tokens = state.host.new_tokens.tolist() new_tokens_lens_list = state.host.new_tokens_lens.tolist() next_draft_tokens_list = state.host.next_draft_tokens.tolist() beam_idx = DEFAULT_BEAM_IDX for req in state.scheduled_requests.context_requests: if req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0: continue new_token = add_token(req, new_tokens, beam_idx=beam_idx) TorchSampler._handle_stop_criteria(req, new_token, max_seq_len=self.max_seq_len, beam_idx=beam_idx) self._request_common_handling(req, next_draft_tokens_list) for req in state.scheduled_requests.generation_requests: if req.state == LlmRequestState.GENERATION_COMPLETE: continue num_new_tokens = new_tokens_lens_list[req.py_seq_slot] for i in range(num_new_tokens): new_token = add_token(req, new_tokens, beam_idx=beam_idx, step=i) if TorchSampler._handle_stop_criteria( req, new_token, max_seq_len=self.max_seq_len, beam_idx=beam_idx): break req.py_num_accepted_draft_tokens = num_new_tokens - 1 req.py_rewind_len = self.draft_len - req.py_num_accepted_draft_tokens self._request_common_handling(req, next_draft_tokens_list) def sample_async( self, scheduled_requests: ScheduledRequests, outputs: dict[str, torch.Tensor], num_context_logits_prefix_sum: list[int]) -> SampleStateMTP: # new_tokens_device: accepted tokens, device tensor, shape: batch_size, nextn + 1 # new_tokens_lens_device: accepted lengths, device tensor, shape: batch_size # next_draft_tokens_device: predicted draft tokens, device tensor, shape: batch_size, nextn # next_new_tokens_device: input tokens for the next iteration, device tensor, shape: batch_size, nextn + 1 requests = scheduled_requests.all_requests() slots = torch.as_tensor([r.py_seq_slot for r in requests]) slots = slots.to(device="cuda", non_blocking=True) o_new_tokens = outputs['new_tokens'][:len(requests)] o_new_tokens_lens = outputs['new_tokens_lens'][:len(requests)] o_next_draft_tokens = outputs['next_draft_tokens'][:len(requests)] o_next_new_tokens = outputs['next_new_tokens'][:len(requests)] new_tokens = self.store.new_tokens next_new_tokens = self.store.next_new_tokens new_tokens_lens = self.store.new_tokens_lens next_draft_tokens = self.store.next_draft_tokens new_tokens.squeeze(-1).T.index_copy_(0, slots, o_new_tokens) next_new_tokens.squeeze(-1).T.index_copy_(0, slots, o_next_new_tokens) new_tokens_lens.index_copy_(0, slots, o_new_tokens_lens) next_draft_tokens.index_copy_(0, slots, o_next_draft_tokens) device = SampleStateTensorsMTP( new_tokens=next_new_tokens, new_tokens_lens=new_tokens_lens, next_draft_tokens=next_draft_tokens, ) host = SampleStateTensorsMTP( new_tokens=new_tokens.to('cpu', non_blocking=True), new_tokens_lens=new_tokens_lens.to('cpu', non_blocking=True), next_draft_tokens=next_draft_tokens.to('cpu', non_blocking=True), ) sampler_event = torch.cuda.Event() sampler_event.record() # add dummy draft tokens to context requests to prepare kv cache in advance # with the max draft token length for request in scheduled_requests.context_requests: request.py_draft_tokens = [1] * self.draft_len return SampleStateMTP(scheduled_requests=scheduled_requests, device=device, host=host, sampler_event=sampler_event) class MTPWorker(SpecWorkerBase): def __init__(self, spec_config: "MTPDecodingConfig", model_config=None): super().__init__() self.spec_config = spec_config self.model_config = model_config self.is_thop = False @property def max_draft_len(self) -> int: return self.spec_config.num_nextn_predict_layers def forward( self, input_ids, position_ids, hidden_states, logits, attn_metadata, spec_metadata, draft_model, ): ''' Example: Assume there are 3 MTP layers Notation: - H_t: token t's hidden state, generated by the target model - h_t: token t's hidden state, generated by the draft model Prompt: ABCD Context phase: Target model: - input tokens: ABCD + [] - sampling tokens: E - accepted tokens: E - KV cache: ABCD - hidden states: H_A, H_B, H_C, H_D Draft model: MTP1: # For context request, prompt[1:] + new generated goloden token is the input. - input tokens: BCDE - input hidden states: H_A, H_B, H_C, H_D # '()' means historical KV cache - KV cache: () + BCDE - output hidden states: h_B, h_C, h_D, h_E - output next draft token: F MTP2: - input token: CDEF - input hidden states: H_B, H_C, H_D, h_E - KV cache: () + CDEF - output hidden states: h_C, h_D, h_E, h_F - output next draft token: G MTP3: - input tokens: DEFG - input hidden states: H_C, H_D, h_E, h_F - KV cache: () + DEFG - output hidden states: h_D, h_E, h_F, h_G - output next draft token: H After 3 MTP layers: - new generated draft tokens: FGH Generation phase 1: accept partial draft tokens Target model: - input tokens: E + FGH - sampling tokens: FGXY - accepted tokens: FGX - KV cache: (ABCD) + EFGH (H's KV cache is invalid) - hidden states: H_E, H_F, H_G, H_H (H_H is invalid) Draft model: MTP1: # For generation request, `mtp_num_modules` of tokens will be used as input. - input tokens: FGX - input hidden states: H_E, H_F, H_G - KV cache: (BCDE) + FGX - output hidden states: h_F, h_G, h_X - output next draft token: N MTP2: - input tokens: GXN - input hidden states: H_F, H_G, h_X - KV cache: (CDEF) + GXN - output hidden states: h_G, h_X, h_N - output next draft token: O MTP3: - input tokens: XNO - input hidden states: H_G, H_X, h_N - KV cache: (DEFG) + XNO - output hidden states: h_X, h_N, h_O - output next draft token: P After 3 MTP layers: - new generated draft tokens: NOP Generation 2: accept none draft tokens Target model: - input tokens: X + NOP - sampling tokens: KMZY - accepted tokens: K - KV cache: (ABCDEFG) + NOP (NOP's KV cache is invalid) - hidden states: H_X, H_N, H_O, H_P (H_N, H_O, H_P is invalid) Draft model: MTP1: - input tokens: GXK - input hidden states: H_F, H_G, H_X - KV cache: (BCDE + F) + GXK - output hidden states: h_G, h_X, h_K - output next draft token: U MTP2: - input tokens: XKU - input hidden states: H_G, H_X, h_K - KV cache: (CDEF + G) + XKU - output hidden states: h_X, h_K, h_U - output next draft token: V MTP3: - input tokens: KUV - input hidden states: H_X, h_K, h_U - KV cache: (DEFG + X) + KUV - output hidden states: h_K, h_U, h_V - output next draft token: Q After 3 MTP layers: - new generated draft tokens: UVQ ''' batch_size = attn_metadata.num_seqs raw_logits = logits if self.guided_decoder is not None: self.guided_decoder.execute(logits) # Sample and verify draft tokens accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens( input_ids, logits, spec_metadata, attn_metadata) # Update MTP past hidden states self.update_mtp_hidden_states(input_ids=input_ids, hidden_states=hidden_states, num_accepted_tokens=num_accepted_tokens, spec_metadata=spec_metadata, attn_metadata=attn_metadata) # prepare draft layer inputs position_ids = position_ids.squeeze(0) draft_inputs = self.prepare_drafter_inputs( input_ids=input_ids, position_ids=position_ids, hidden_states=hidden_states, accepted_tokens=accepted_tokens, num_accepted_tokens=num_accepted_tokens, spec_metadata=spec_metadata, attn_metadata=attn_metadata) # update attn metadata if attn_metadata is not None: self.change_attn_metadata(num_accepted_tokens, attn_metadata) draft_inputs.update(attn_metadata=attn_metadata) # Run MTP layers to predict draft tokens next_draft_tokens = [] last_tokens_idx = torch.cumsum( attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1 for i, mtp_layer in enumerate(draft_model.mtp_layers): if self.guided_decoder is not None: new_tokens = draft_inputs['input_ids'][last_tokens_idx] self.guided_decoder.add_draft_batch(new_tokens, num_accepted_tokens, draft_step=i) hidden_states = mtp_layer(embed_tokens=draft_model.embed_tokens, **draft_inputs) logits = mtp_layer.shared_head(hidden_states, draft_model.lm_head, attn_metadata).float() if self.guided_decoder is not None: self.guided_decoder.execute_draft_batch(logits, draft_step=i) new_draft_token = self.draft_sampler(logits) next_draft_tokens.append(new_draft_token) # shift input_ids and hidden_states input_ids = draft_inputs["input_ids"] input_ids[:-1] = input_ids[1:].clone() input_ids[last_tokens_idx] = new_draft_token draft_hidden_states = draft_inputs["hidden_states"] draft_hidden_states[:-1] = draft_hidden_states[1:].clone() draft_hidden_states[last_tokens_idx] = hidden_states[ last_tokens_idx, :] draft_inputs = { "input_ids": input_ids, "position_ids": draft_inputs["position_ids"], "hidden_states": draft_hidden_states, "attn_metadata": draft_inputs["attn_metadata"], } next_draft_tokens = torch.stack(next_draft_tokens, dim=1) # restore attn metadata if attn_metadata is not None: self.restore_attn_metadata(attn_metadata=attn_metadata) # prepare next new tokens to support overlap scheduler next_new_tokens = accepted_tokens[ spec_metadata.batch_indices_cuda[:batch_size], num_accepted_tokens - 1].unsqueeze(1) next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens], dim=1) return { 'logits': raw_logits, 'new_tokens': accepted_tokens, 'new_tokens_lens': num_accepted_tokens, 'next_draft_tokens': next_draft_tokens, 'next_new_tokens': next_new_tokens } def skip_forward( self, input_ids, position_ids, hidden_states, logits, attn_metadata, spec_metadata, draft_model, ): batch_size = attn_metadata.num_seqs mtp_num_modules = self.spec_config.num_nextn_predict_layers accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)), dtype=torch.int, device=logits.device) num_accepted_tokens = torch.ones(batch_size, dtype=torch.int, device=logits.device) next_draft_tokens = torch.empty((batch_size, mtp_num_modules), dtype=torch.int, device=logits.device) next_new_tokens = torch.empty((batch_size, (mtp_num_modules + 1)), dtype=torch.int, device=logits.device) return { 'logits': logits, 'new_tokens': accepted_tokens, 'new_tokens_lens': num_accepted_tokens, 'next_draft_tokens': next_draft_tokens, 'next_new_tokens': next_new_tokens } def update_mtp_hidden_states( self, input_ids: torch.IntTensor, hidden_states: torch.Tensor, num_accepted_tokens: torch.Tensor, spec_metadata: MTPSpecMetadata, attn_metadata: AttentionMetadata, ): ''' Update the past hidden states and past tokens in spec_metadata base on the newly accepted tokens and historical hidden states. These past hidden states and past tokens will be use in MTP module. Args: input_ids: torch.IntTensor [num_tokens] The input ids of all requests. Flatten. hidden_states: torch.Tensor [num_tokens, hidden_size] Target model's hidden states. num_accepted_tokens: torch.Tensor [batch_size] Number of accepted tokens per request. spec_metadata: MTPSpecMetadata MTP speculative decoding metadata attn_metadata: AttentionMetadata Attention metadata Returns: None ''' def unpack_sequence(packed_seq_cuda, seq_lens_cuda, seq_lens_cpu): # max_length is used as tensor shape, so it should be from host; # otherwise, an implicit D2H copy will be triggered. max_length = seq_lens_cpu.max().item() num_sequences = seq_lens_cuda.shape[0] # initialize a zero tensor to store the result result = torch.zeros( (num_sequences, max_length, packed_seq_cuda.shape[1]), dtype=packed_seq_cuda.dtype, device=packed_seq_cuda.device) # get mask seq_indices = torch.arange( max_length, device=seq_lens_cuda.device).unsqueeze(0).expand( num_sequences, -1) mask = seq_indices < seq_lens_cuda.unsqueeze(1) # unpack result[mask] = packed_seq_cuda return result batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts num_ctx_tokens = attn_metadata.num_ctx_tokens num_gens = batch_size - num_contexts seq_lens = attn_metadata.seq_lens_cuda seq_lens_cpu = attn_metadata.seq_lens hidden_size = hidden_states.shape[-1] mtp_num_modules = self.spec_config.num_nextn_predict_layers if self.is_thop: _, _ = torch.ops.trtllm.mtp_update_hidden_states_op( input_ids, seq_lens, hidden_states, spec_metadata.mtp_hidden_states_ptrs, spec_metadata.mtp_past_tokens_ptrs, num_accepted_tokens, mtp_num_modules, batch_size, num_contexts, hidden_size) else: assert len(spec_metadata.request_ids) == batch_size mtp_past_hidden_states_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool mtp_past_tokens_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool slot_ids = spec_metadata.slot_ids[:batch_size] mtp_tokens = mtp_past_tokens_pool[slot_ids] mtp_hidden_states = mtp_past_hidden_states_pool[slot_ids] new_mtp_past_tokens, new_mtp_past_hidden_states = [], [] # context if num_contexts > 0: seq_lens_ctx = seq_lens[:num_contexts] seq_lens_ctx_cpu = seq_lens_cpu[:num_contexts] unpacked_input_ids_ctx = unpack_sequence( input_ids[:num_ctx_tokens].unsqueeze(1), seq_lens_ctx, seq_lens_ctx_cpu).squeeze(2) unpacked_hidden_states_ctx = unpack_sequence( hidden_states[:num_ctx_tokens], seq_lens_ctx, seq_lens_ctx_cpu) cat_tokens_ctx = torch.cat( (mtp_tokens[:num_contexts], unpacked_input_ids_ctx), dim=1) cat_hidden_states_ctx = torch.cat( (mtp_hidden_states[:num_contexts], unpacked_hidden_states_ctx), dim=1) ctx_batch_idx = spec_metadata.batch_indices_cuda[:num_contexts] row_indices_ctx = ctx_batch_idx.unsqueeze(1).expand( -1, mtp_num_modules) col_indices_ctx = (seq_lens_ctx.unsqueeze(1) + spec_metadata.draft_token_indices_cuda) new_mtp_past_tokens.append(cat_tokens_ctx[row_indices_ctx, col_indices_ctx]) new_mtp_past_hidden_states.append( cat_hidden_states_ctx[row_indices_ctx, col_indices_ctx, :]) # generation if num_gens > 0: unpacked_input_ids_gen = input_ids[num_ctx_tokens:].reshape( num_gens, mtp_num_modules + 1).int() hidden_states_gen = hidden_states[num_ctx_tokens:, :] unpacked_hidden_states_gen = hidden_states_gen.reshape( num_gens, mtp_num_modules + 1, hidden_size) cat_tokens_gen = torch.cat( (mtp_tokens[num_contexts:], unpacked_input_ids_gen), dim=1) cat_hidden_states_gen = torch.cat( (mtp_hidden_states[num_contexts:], unpacked_hidden_states_gen), dim=1) gen_batch_idx = spec_metadata.batch_indices_cuda[:num_gens] row_indices_gen = gen_batch_idx.unsqueeze(1).expand( -1, mtp_num_modules) col_indices_gen = ( num_accepted_tokens[num_contexts:].unsqueeze(1) + spec_metadata.draft_token_indices_cuda) new_mtp_past_tokens.append(cat_tokens_gen[row_indices_gen, col_indices_gen]) new_mtp_past_hidden_states.append( cat_hidden_states_gen[row_indices_gen, col_indices_gen, :]) # update past tokens and hidden states new_mtp_past_tokens = torch.cat(new_mtp_past_tokens, dim=0) new_mtp_past_hidden_states = torch.cat(new_mtp_past_hidden_states, dim=0) mtp_past_tokens_pool.index_copy_(0, slot_ids, new_mtp_past_tokens) mtp_past_hidden_states_pool.index_copy_(0, slot_ids, new_mtp_past_hidden_states) @torch.compile(options={"max-autotune": True}) def topk_kernel(self, gen_logprobs, num_gens, mtp_num_modules, spec_metadata): topk_value, topk_indices = torch.topk(gen_logprobs, k=self.spec_config.relaxed_topk, dim=-1) topk_indices = topk_indices.reshape(num_gens, mtp_num_modules + 1, self.spec_config.relaxed_topk) topk_value = topk_value.reshape(num_gens, mtp_num_modules + 1, self.spec_config.relaxed_topk) draft_tokens = spec_metadata.draft_tokens.reshape( num_gens, mtp_num_modules) return topk_value, topk_indices, draft_tokens @torch.compile(options={"max-autotune": True}) def process_generation_logits(self, logits, num_contexts): gen_logits = logits[num_contexts:] gen_logprobs = torch.softmax(gen_logits, dim=-1) return gen_logprobs def sample_and_accept_draft_tokens( self, input_ids: torch.IntTensor, logits: torch.Tensor, spec_metadata: MTPSpecMetadata, attn_metadata: AttentionMetadata, ): ''' Takes input logits and samples golden token + predictions from draft tokens. Runs acceptance algorithm to accept draft tokens. Currently only support greedy sampling. All decoding is done using Top1 and token equality is used for acceptance. Args: input_ids: torch.IntTensor [num_tokens] The input ids of all requests. Flatten. logits: torch.Tensor [num_tokens, vocab_size] Logits produced by the target model. spec_metadata: MTPSpecMetadata MTP speculative decoding metadata attn_metadata: AttentionMetadata Attention metadata Returns: accepted_tokens: torch.Tensor [batch_size, (max_draft_len + 1)] Accepted token ids. Flattened. num_accepted_tokens: torch.Tensor [batch_size] Number of accepted tokens per request. Example: Assume there are 3 MTP layers Prompt: ABCD Context phase: Target model: - input tokens: ABCD + [] - sampling tokens: E - accepted tokens: E Draft model: - input tokens: BCDE - new generated draft tokens: FGH Current sequence: ABCD E`FGH -> Whitespace separates tokens produced by each phase -> Backtick separates accepted and draft tokens Generation phase 1: Target model: - input tokens: E + FGH - sampling tokens: FGXY -> Sample with E's logit and get 'F'; Sample with F's logit, ... - accepted tokens: FGX -> 'X' will be treat as the accepted token Draft model: - input tokens: FGX - new generated draft tokens: PQR Current sequence: ABCD EFG X`PQR Generation phase 2: Target model: - input tokens: X + PQR - sampling tokens: PYST - accepted token: PY Draft model: - input tokens: PY - new generated draft tokens: UVW Current sequence: ABCD EFG XP Y`UVW ''' batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts num_gens = batch_size - num_contexts mtp_num_modules = self.spec_config.num_nextn_predict_layers if logits.dim() == 1: logits = logits.unsqueeze(0) # The return buffer if self.spec_config.use_relaxed_acceptance_for_thinking or not self.is_thop: accepted_tokens = torch.ones((batch_size, (mtp_num_modules + 1)), dtype=torch.int, device=logits.device) num_accepted_tokens = torch.ones(batch_size, dtype=torch.int, device=logits.device) if self.spec_config.use_relaxed_acceptance_for_thinking: mtp_relaxed_delta_pool = spec_metadata.mtp_hidden_states_manager.mtp_relaxed_delta_pool # context con_logits = logits[:num_contexts] con_target_tokens = torch.argmax(con_logits, dim=-1) accepted_tokens[:num_contexts, 0] = con_target_tokens[:num_contexts] last_tokens_idx = torch.cumsum( attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1 ctx_input_ids = input_ids[:attn_metadata.num_ctx_tokens] ctx_is_think = (ctx_input_ids == self.spec_config.begin_thinking_phase_token).int() ctx_is_think_cumsum = torch.cumsum(ctx_is_think, dim=0) ctx_last_cumsum = ctx_is_think_cumsum[ last_tokens_idx[:num_contexts]] ctx_think_tokens_num = torch.diff( ctx_last_cumsum, dim=0, prepend=torch.zeros(1, dtype=torch.int, device=ctx_last_cumsum.device)) ctx_delta = (ctx_think_tokens_num >= 1).int() * self.spec_config.relaxed_delta ctx_slot_ids = spec_metadata.slot_ids[:num_contexts] mtp_relaxed_delta_pool.index_copy_(0, ctx_slot_ids, ctx_delta) # generation gen_logprobs = self.process_generation_logits(logits, num_contexts) topk_value, topk_indices, draft_tokens = self.topk_kernel( gen_logprobs, num_gens, mtp_num_modules, spec_metadata) accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_relaxed_acceptance_op( spec_metadata.slot_ids, topk_value, topk_indices, draft_tokens, mtp_relaxed_delta_pool, num_accepted_tokens, accepted_tokens, mtp_num_modules, batch_size, num_contexts, self.spec_config.relaxed_topk, self.spec_config.relaxed_delta, self.spec_config.begin_thinking_phase_token, self.spec_config.end_thinking_phase_token) # Strict acceptance else: if self.is_thop: # Temporary buffer target_tokens_cache = torch.zeros(batch_size * (mtp_num_modules + 1), dtype=torch.int, device=logits.device) accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_sampling_and_accepted_draft_tokens_op( logits, spec_metadata.draft_tokens, target_tokens_cache, mtp_num_modules, batch_size, num_contexts, logits.shape[-1]) else: target_tokens = self._sample_tokens_for_batch( logits, spec_metadata, num_contexts, batch_size) # context accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts] # generation gen_target_tokens = target_tokens[num_contexts:].reshape( num_gens, mtp_num_modules + 1) accepted_tokens[num_contexts:, :] = gen_target_tokens draft_tokens = spec_metadata.draft_tokens.reshape( num_gens, mtp_num_modules) num_accepted_tokens[num_contexts:] += torch.cumprod( (draft_tokens == gen_target_tokens[:, :mtp_num_modules] ).int(), dim=-1).sum(1) # Check for environment variable override if self.force_num_accepted_tokens != 0: # total tokens per iteration = accepted draft tokens + 1 target token force_total_tokens = min(self.force_num_accepted_tokens + 1, mtp_num_modules + 1) num_accepted_tokens[num_contexts:] = force_total_tokens return accepted_tokens, num_accepted_tokens def change_attn_metadata(self, num_accepted_tokens: torch.Tensor, attn_metadata: AttentionMetadata): attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda") batch_size = attn_metadata.num_seqs mtp_num_modules = self.spec_config.num_nextn_predict_layers num_contexts = attn_metadata.num_contexts attn_metadata._seq_lens[num_contexts:batch_size] -= 1 attn_metadata._seq_lens_cuda[num_contexts:batch_size] -= 1 attn_metadata.on_update() if hasattr(attn_metadata, 'kv_lens_cuda'): # Note that it's important to not free the seq_lens_cuda # buffer once the graph has been captured also - this will invalidate # the graph and force an expensive recapture. attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= ( mtp_num_modules + 1 - num_accepted_tokens[num_contexts:batch_size]) if attn_metadata.kv_cache_params is not None and not attn_metadata.is_cuda_graph: for i in range(num_contexts, batch_size): # used for vanilla MLA, list on cpu attn_metadata.kv_cache_params.num_cached_tokens_per_seq[ i] -= mtp_num_modules + 1 - num_accepted_tokens[i].item() def restore_attn_metadata(self, attn_metadata: AttentionMetadata): attn_metadata.restore_from_spec_dec() attn_metadata.on_update() def prepare_drafter_inputs( self, input_ids: torch.IntTensor, position_ids: torch.IntTensor, hidden_states: torch.Tensor, accepted_tokens: torch.Tensor, num_accepted_tokens: torch.Tensor, spec_metadata: MTPSpecMetadata, attn_metadata: AttentionMetadata, ): ''' Parepare the input of the draft model. Args: input_ids: torch.IntTensor [num_tokens] The input ids of all requests. Flatten. num_tokens = sum(all prompts) + num_generation * (mtp_num_modules + 1) position_ids: torch.IntTensor [1][num_tokens] The position id of all requests. Flatten. hidden_states: torch.Tensor [num_tokens, hidden_size] Target model's hidden states. accepted_tokens: torch.Tensor [batch_size, max_draft_len + 1] Accepted token ids. Flattened. num_accepted_tokens: torch.Tensor [batch_size] Number of accepted draft tokens. Will be used for the first MTP layer. spec_metadata: MTPSpecMetadata MTP speculative decoding metadata attn_metadata: AttentionMetadata Attention metadata Returns: draft_inputs input_ids: torch.Tensor [num_tokens] The new input ids of all requests. Flatten. num_tokens = sum(all prompts) + num_generation * (mtp_num_modules) position_ids: torch.Tensor [1][[num_tokens]] The new position ids of all requests. Flatten. Directly use the input position ids. hidden_states: torch.Tensor [num_tokens][hidden_size] Continuous hidden states buffer. attn_metadata: AttentionMetadata Attention metadata spec_metadata: MTPSpecMetadata MTP speculative decoding metadata ''' batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts num_ctx_tokens = attn_metadata.num_ctx_tokens num_gens = batch_size - num_contexts mtp_past_hidden_states_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool mtp_past_tokens_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool mtp_num_modules = self.spec_config.num_nextn_predict_layers if self.is_thop: # Temporary buffer hidden_size = hidden_states.shape[-1] # generation requests' golden tokens num_tokens = input_ids.shape[0] - num_gens return_input_ids = torch.empty(num_tokens, dtype=torch.int, device="cuda") return_hidden_states = torch.empty((num_tokens, hidden_size), dtype=hidden_states.dtype, device="cuda") (return_input_ids, return_hidden_states ) = torch.ops.trtllm.mtp_prepare_drafter_inputs_op( input_ids, attn_metadata.seq_lens_cuda, spec_metadata.mtp_hidden_states_ptrs, spec_metadata.mtp_past_tokens_ptrs, hidden_states, accepted_tokens, num_accepted_tokens, return_input_ids, return_hidden_states, mtp_num_modules, batch_size, num_contexts, hidden_size) else: return_input_ids_list = [] return_hidden_states_list = [] # Calculate cumulative sequence lengths for indexing last_tokens_idx = torch.cumsum( attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1 # context if num_contexts > 0: hidden_states_ctx = hidden_states[:num_ctx_tokens, :] input_prompt_ids = input_ids[:num_ctx_tokens] input_ids_ctx = torch.empty_like(input_prompt_ids, dtype=torch.int32, device="cuda") input_ids_ctx[:-1].copy_(input_prompt_ids[1:]) input_ids_ctx[last_tokens_idx[:num_contexts]] = \ accepted_tokens[:num_contexts, 0] return_input_ids_list.append(input_ids_ctx) return_hidden_states_list.append(hidden_states_ctx) # generation if num_gens > 0: slot_ids = spec_metadata.slot_ids[num_contexts:batch_size] gen_batch_idx = spec_metadata.batch_indices_cuda[:num_gens] gen_token_idx = num_accepted_tokens[num_contexts:] - 1 accepted_tokens_gen = accepted_tokens[num_contexts:, :] input_ids_gen = accepted_tokens_gen[gen_batch_idx, gen_token_idx].unsqueeze(1) input_ids_gen = torch.concat( [mtp_past_tokens_pool[slot_ids][:, 1:], input_ids_gen], dim=1) hidden_states_gen = mtp_past_hidden_states_pool[ slot_ids].flatten(0, 1) return_input_ids_list.append(input_ids_gen.flatten(0, 1)) return_hidden_states_list.append(hidden_states_gen) # Concatenate into continuous buffers return_input_ids = torch.concat(return_input_ids_list, dim=0) return_hidden_states = torch.concat(return_hidden_states_list, dim=0) # update position_ids position_ids_list = [] if num_contexts > 0: position_ids_list.append(position_ids[:num_ctx_tokens]) if num_gens > 0: position_ids_gen = position_ids[num_ctx_tokens:].reshape( num_gens, mtp_num_modules + 1)[:, -mtp_num_modules:] position_ids_gen = position_ids_gen - ( 1 + mtp_num_modules - num_accepted_tokens[num_contexts:].unsqueeze(1)) position_ids_list.append(position_ids_gen.flatten()) return_position_ids = torch.concat(position_ids_list, dim=-1) return { "input_ids": return_input_ids, "position_ids": return_position_ids, "hidden_states": return_hidden_states, "attn_metadata": attn_metadata, } @torch.compile(options={"max-autotune": True}) def get_local_max_and_combined(self, logits, mapping_lm_tp=None): local_max_values, local_argmax = torch.max(logits, dim=-1, keepdim=True) # Adjust indices based on TP rank and size vocab_per_rank = logits.shape[-1] mapping_lm_tp = mapping_lm_tp if mapping_lm_tp is not None else self.model_config.mapping max_index_per_rank = local_argmax.type( torch.int32) + (mapping_lm_tp.tp_rank * vocab_per_rank) # Use torch.stack and flatten instead of view+cat to avoid torch.compile issues # Convert both to float32 to ensure consistent dtype max_index_per_rank_float = max_index_per_rank.float() local_max_values_float32 = local_max_values.float() # Stack and flatten to get interleaved layout: [idx0, val0, idx1, val1, ...] combined = torch.stack( [max_index_per_rank_float, local_max_values_float32], dim=-1).flatten(-2) return combined @torch.compile(options={"max-autotune": True}) def get_draft_tokens_from_gathered(self, gathered): gathered_indices_float = gathered[..., 0::2] # Even positions: indices gathered_values_float = gathered[..., 1::2] # Odd positions: values # Find the rank with maximum value max_indices = torch.argmax(gathered_values_float, dim=-1, keepdim=True) # Get the corresponding token indices and convert back to int32 draft_tokens = torch.gather(gathered_indices_float, -1, max_indices).squeeze(-1).type(torch.int32) return draft_tokens def draft_sampler( self, logits: torch.Tensor, mapping_lm_head_tp: Mapping = None, ): ''' Sampling draft tokens. Args: logits: torch.Tensor [num_tokens, vocab_size] Logits produced by the draft model. Returns: draft_tokens: torch.Tensor [batch_size * max_draft_len] Draft token ids. Flattened. ''' if (self.model_config is not None and hasattr(self.model_config, 'mapping') and self.model_config.mapping.tp_size > 1) and not (self.model_config.mapping.enable_attention_dp): combined = self.get_local_max_and_combined(logits) gathered = allgather(combined, self.model_config.mapping, dim=-1) draft_tokens = self.get_draft_tokens_from_gathered(gathered) elif (self.model_config is not None and hasattr(self.model_config, 'mapping') and self.model_config.mapping.tp_size > 1) and self.model_config.mapping.enable_lm_head_tp_in_adp: # For ADP + LM head TP mode, we need to find the global argmax across all TP ranks combined = self.get_local_max_and_combined(logits, mapping_lm_head_tp) gathered = allgather(combined, mapping_lm_head_tp, dim=-1) batch_size = logits.shape[0] local_batch_size = batch_size // mapping_lm_head_tp.tp_size gathered = gathered.view(mapping_lm_head_tp.tp_size, local_batch_size, -1) sliced_gathered = gathered[mapping_lm_head_tp.tp_rank] draft_tokens = self.get_draft_tokens_from_gathered(sliced_gathered) else: # Simple argmax if no TP or no model config draft_tokens = torch.argmax(logits, dim=-1).type(torch.int32) return draft_tokens class MTPEagleWorker(MTPWorker): def __init__(self, spec_config: "MTPDecodingConfig", model_config: Optional[ModelConfig] = None): super().__init__(spec_config, model_config) self.model_config = model_config self.mtp_num_modules = spec_config.num_nextn_predict_layers @torch.compile(options={"max-autotune": True}) def update_draft_tokens(self, next_draft_tokens, new_draft_token, hidden_states, gather_ids, inputs): next_draft_tokens.append(new_draft_token) # update inputs hidden_states = hidden_states[gather_ids] position_ids = inputs["position_ids"][gather_ids] + 1 return hidden_states, position_ids def forward( self, input_ids, position_ids, hidden_states, logits, attn_metadata, spec_metadata, draft_model, ): batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts num_gens = batch_size - num_contexts raw_logits = logits if self.guided_decoder is not None: self.guided_decoder.execute(logits) # Sample and verify draft tokens accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens( input_ids, logits, spec_metadata, attn_metadata) # Save the old attn_metadata and spec_metadata attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda") # Prepare inputs for the 1st MTP layer @torch.compile(options={"max-autotune": True}) def prepare_position_ids_and_last_tokens(position_ids, attn_metadata): position_ids = position_ids.squeeze(0) last_tokens_idx = torch.cumsum( attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1 return position_ids, last_tokens_idx position_ids, last_tokens_idx = prepare_position_ids_and_last_tokens( position_ids, attn_metadata) inputs = self.prepare_drafter_inputs(input_ids=input_ids, position_ids=position_ids, last_tokens_idx=last_tokens_idx, hidden_states=hidden_states, accepted_tokens=accepted_tokens, attn_metadata=attn_metadata, spec_metadata=spec_metadata) # Predict draft tokens next_draft_tokens = [] for i in range(self.mtp_num_modules): if i == 0: hidden_states = draft_model.mtp_layers[0]( embed_tokens=draft_model.embed_tokens, all_rank_num_tokens=spec_metadata.all_rank_num_tokens, **inputs) start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] * (self.mtp_num_modules + 1)).long() gather_ids_gen = (start_ids_gen + num_accepted_tokens[num_contexts:] - 1 + attn_metadata.num_ctx_tokens) gather_ids = torch.concat( [last_tokens_idx[:num_contexts], gather_ids_gen], dim=0) else: hidden_states = draft_model.mtp_layers[0]( embed_tokens=draft_model.embed_tokens, all_rank_num_tokens=spec_metadata. subseq_all_rank_num_tokens, **inputs) # All of the seq_len are 1, use batch_indices_cuda as gather_ids gather_ids = spec_metadata.batch_indices_cuda[:batch_size] if self.guided_decoder is not None: new_tokens = inputs["input_ids"][gather_ids] self.guided_decoder.add_draft_batch(new_tokens, num_accepted_tokens, draft_step=i) if self.model_config.mapping.enable_attention_dp and \ getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False): hidden_states_gathered = hidden_states[gather_ids] token_count = hidden_states_gathered.view( -1, hidden_states_gathered.shape[-1]).shape[0] max_num_requests = spec_metadata.max_num_requests pad_len = max_num_requests - token_count if pad_len > 0: padded_hidden_states = F.pad(hidden_states_gathered.view( -1, hidden_states_gathered.shape[-1]), (0, 0, 0, pad_len), mode="constant", value=0) elif pad_len == 0: padded_hidden_states = hidden_states_gathered.view( -1, hidden_states_gathered.shape[-1]) else: raise ValueError( f"In MTPEagleWorker.forward(), token_count < max_num_requests, which is not supported" ) logits = draft_model.mtp_layers[0].shared_head( padded_hidden_states, draft_model.lm_head, attn_metadata, True) else: logits = draft_model.mtp_layers[0].shared_head( hidden_states[gather_ids], draft_model.lm_head, attn_metadata, True) if self.guided_decoder is not None: self.guided_decoder.execute_draft_batch(logits, draft_step=i) if self.model_config.mapping.enable_attention_dp and \ getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False): mapping_lm_head_tp = draft_model.mtp_layers[ 0].shared_head.mapping_lm_head_tp new_draft_token = self.draft_sampler(logits, mapping_lm_head_tp) new_draft_token = new_draft_token[:token_count] else: new_draft_token = self.draft_sampler(logits) hidden_states, position_ids = self.update_draft_tokens( next_draft_tokens, new_draft_token, hidden_states, gather_ids, inputs) # update attn_metadata if i == 0: attn_metadata._seq_lens[:batch_size].fill_(1) attn_metadata._seq_lens_cuda[:batch_size].fill_(1) attn_metadata.on_update() # cannot run generation if their is no kv cache has_kv_cache = inputs[ "attn_metadata"].kv_cache_manager is not None if has_kv_cache: attn_metadata.host_request_types[:attn_metadata. num_contexts].fill_(1) attn_metadata.num_contexts = 0 # update kv_lens_cuda if hasattr(attn_metadata, 'kv_lens_cuda'): attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= ( self.mtp_num_modules - num_accepted_tokens[num_contexts:]) attn_metadata.kv_lens_cuda[:num_contexts] += 1 # update metadata for flash mla if has_kv_cache and num_contexts > 0 and attn_metadata.enable_flash_mla: reorder_block_ids_per_seq = torch.cat([ attn_metadata. kv_block_ids_per_seq[num_contexts:batch_size], attn_metadata.kv_block_ids_per_seq[:num_contexts] ]) attn_metadata.block_ids_per_seq[:batch_size, :].copy_( reorder_block_ids_per_seq, non_blocking=True) # update metadata # some attention metadata needs to be updated when changing seq_lens/kv_lens attn_metadata.update_for_spec_dec() elif hasattr(attn_metadata, 'kv_lens_cuda'): @torch.compile(options={"max-autotune": True}) def update_kv_lens(kv_lens_cuda, batch_size): kv_lens_cuda[:batch_size] += 1 update_kv_lens(attn_metadata.kv_lens_cuda, batch_size) # update metadata # some attention metadata needs to be updated when changing kv_lens attn_metadata.update_for_spec_dec() inputs = { "input_ids": new_draft_token, "position_ids": position_ids, "hidden_states": hidden_states, "attn_metadata": attn_metadata, } # restore attn_metadata to support cuda graph attn_metadata.restore_from_spec_dec() attn_metadata.on_update() @torch.compile(options={"max-autotune": True}) def prepare_next_tokens(next_draft_tokens, accepted_tokens, spec_metadata, batch_size, num_accepted_tokens): next_draft_tokens = torch.stack(next_draft_tokens, dim=1) # prepare next new tokens to support overlap scheduler next_new_tokens = accepted_tokens[ spec_metadata.batch_indices_cuda[:batch_size], num_accepted_tokens - 1].unsqueeze(1) next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens], dim=1) return next_draft_tokens, next_new_tokens next_draft_tokens, next_new_tokens = prepare_next_tokens( next_draft_tokens, accepted_tokens, spec_metadata, batch_size, num_accepted_tokens) return { 'logits': raw_logits, 'new_tokens': accepted_tokens, 'new_tokens_lens': num_accepted_tokens, 'next_draft_tokens': next_draft_tokens, 'next_new_tokens': next_new_tokens } @torch.compile(options={"max-autotune": True}) def prepare_drafter_inputs( self, input_ids: torch.IntTensor, position_ids: torch.IntTensor, last_tokens_idx: torch.LongTensor, hidden_states: torch.Tensor, accepted_tokens: torch.Tensor, attn_metadata: AttentionMetadata, spec_metadata: MTPSpecMetadata, ): num_contexts = attn_metadata.num_contexts # context input_prompt_ids = input_ids[:attn_metadata.num_ctx_tokens] input_ids_ctx = torch.empty_like(input_prompt_ids, dtype=torch.int32, device="cuda") input_ids_ctx[:-1].copy_(input_prompt_ids[1:]) input_ids_ctx[ last_tokens_idx[:num_contexts]] = accepted_tokens[:num_contexts, 0] # generation input_ids_gen = accepted_tokens[num_contexts:, :].flatten() # get draft inputs input_ids = torch.concat([input_ids_ctx, input_ids_gen], dim=0) return { "input_ids": input_ids, "position_ids": position_ids, "hidden_states": hidden_states, "attn_metadata": attn_metadata, }