Mtp optimizations round1 (#5689)

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Co-authored-by: Kefeng-Duan <176893526+Kefeng-Duan@users.noreply.github.com>
ameynaik-hub 2025-07-25 10:48:27 -07:00 committed by GitHub
parent 7bff341553
commit 1e5e71aa42
4 changed files with 191 additions and 56 deletions

View File

@@ -131,11 +131,21 @@ class DeepseekV3MTPHead(nn.Module):
def __init__(self, model_config: ModelConfig[PretrainedConfig]):
super().__init__()
config = model_config.pretrained_config
self.model_config = model_config
self.norm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
@torch.compile(options={"max-autotune": True})
def get_last_token_states(self, hidden_states, attn_metadata):
last_tokens = torch.cumsum(
attn_metadata.seq_lens_cuda,
dim=0,
dtype=torch.long,
) - 1
return hidden_states[last_tokens]
def forward(self,
hidden_states: torch.Tensor,
lm_head: Linear,
@@ -143,16 +153,16 @@ class DeepseekV3MTPHead(nn.Module):
return_context_logits: bool = False) -> torch.Tensor:
if not return_context_logits:
if attn_metadata is not None:
last_tokens = torch.cumsum(
attn_metadata.seq_lens_cuda,
dim=0,
dtype=torch.long,
) - 1
hidden_states = hidden_states[last_tokens]
hidden_states = self.get_last_token_states(
hidden_states, attn_metadata)
else:
hidden_states = hidden_states[-1].unsqueeze(0)
if not (self.model_config.mapping.enable_attention_dp):
lm_head.gather_output = False
logits = lm_head(hidden_states)
if not (self.model_config.mapping.enable_attention_dp):
lm_head.gather_output = True
return logits
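
Note: the new helper packs the last-token selection (cumsum of per-request sequence lengths, minus one, then a gather) into a torch.compile'd method so these small indexing ops fuse into one kernel instead of several tiny launches per decode step. A standalone sketch of the indexing, with toy names that are not from the PR:

import torch

def gather_last_tokens(packed_hidden, seq_lens):
    # cumsum gives each request's end offset in the packed token layout;
    # subtracting 1 turns it into the flat index of that request's last token
    last_idx = torch.cumsum(seq_lens, dim=0, dtype=torch.long) - 1
    return packed_hidden[last_idx]

toy_seq_lens = torch.tensor([3, 1, 2])                   # three packed requests
packed_hidden = torch.arange(6).float().unsqueeze(-1)    # tokens 0..5, hidden_size 1
print(gather_last_tokens(packed_hidden, toy_seq_lens))   # rows 2, 3 and 5
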
@@ -903,6 +913,12 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
self.num_shared_experts = config.n_shared_experts
self.top_k = config.num_experts_per_tok
self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
self.event_dict = {
key: torch.cuda.Event()
for key in [EventType.Main, EventType.MoeShared]
}
self.enorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
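
Note: the Main/MoeShared events pair with the aux stream so that, in the forward change below, the embedding norm and the hidden-state norm can run concurrently through maybe_execute_in_parallel. A rough sketch of that overlap pattern, shaped like the call site; this is an assumption, not TensorRT-LLM's actual helper:

import torch

def run_in_parallel(fn_a, fn_b, event_a, event_b, aux_stream):
    main_stream = torch.cuda.current_stream()
    event_a.record(main_stream)            # inputs for both branches are ready here
    out_a = fn_a()                         # branch A stays on the main stream
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(event_a)     # don't start before the inputs exist
        out_b = fn_b()                     # branch B overlaps on the aux stream
        event_b.record(aux_stream)
    main_stream.wait_event(event_b)        # rejoin before either output is consumed
    return out_a, out_b
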
@@ -910,15 +926,27 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
self.hnorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.eh_proj = Linear(
config.hidden_size * 2,
config.hidden_size,
bias=False,
dtype=config.torch_dtype,
skip_create_weights_in_init=model_config.
skip_create_weights_in_init,
)
if model_config.mapping.enable_attention_dp:
self.eh_proj = Linear(
config.hidden_size * 2,
config.hidden_size,
bias=False,
dtype=config.torch_dtype,
skip_create_weights_in_init=model_config.
skip_create_weights_in_init,
)
else:
self.eh_proj = Linear(
config.hidden_size * 2,
config.hidden_size,
bias=False,
dtype=config.torch_dtype,
tensor_parallel_mode=TensorParallelMode.ROW,
mapping=model_config.mapping,
reduce_output=True,
skip_create_weights_in_init=model_config.
skip_create_weights_in_init,
)
self.shared_head = DeepseekV3MTPHead(model_config)
@@ -934,9 +962,26 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
inputs_embeds = self.enorm(embed_tokens(input_ids))
hidden_states = self.hnorm(hidden_states)
def norm_embeds():
return self.enorm(embed_tokens(input_ids))  # embedding
def norm_hidden():
return self.hnorm(hidden_states)
inputs_embeds, hidden_states = maybe_execute_in_parallel(
norm_embeds,
norm_hidden,
self.event_dict[EventType.Main],
self.event_dict[EventType.MoeShared],
self.aux_stream,
)
hidden_states = torch.concat([inputs_embeds, hidden_states], dim=-1)
# Split hidden_states columnwise based on TP
tp_size = self.model_config.mapping.tp_size
tp_rank = self.model_config.mapping.tp_rank
if tp_size > 1 and not (self.model_config.mapping.enable_attention_dp):
hidden_states = torch.chunk(hidden_states, tp_size, dim=-1)[tp_rank]
hidden_states = self.eh_proj(hidden_states)
# Input layer norm
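
Note: with attention DP off, eh_proj is now a row-parallel Linear: the weight is sharded along its input dimension and the partial products are all-reduced (reduce_output=True), so each TP rank only needs its own slice of the concatenated [embeds, hidden] tensor, which is exactly what the torch.chunk above supplies. A single-process sketch (shapes assumed, sum() standing in for the all-reduce) checking that this matches the unsharded projection:

import torch

hidden_size, tp_size = 8, 2
w = torch.randn(hidden_size, 2 * hidden_size)      # full eh_proj weight (out, in)
x = torch.randn(4, 2 * hidden_size)                # concat([inputs_embeds, hidden_states], dim=-1)

full = x @ w.t()                                   # what a non-parallel Linear computes

partials = []
for tp_rank in range(tp_size):
    x_rank = torch.chunk(x, tp_size, dim=-1)[tp_rank]   # the slice each rank keeps
    w_rank = torch.chunk(w, tp_size, dim=-1)[tp_rank]   # matching input-dim shard of the weight
    partials.append(x_rank @ w_rank.t())
print(torch.allclose(full, sum(partials), atol=1e-5))   # True: chunk + row-parallel == full matmul
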
@@ -1074,7 +1119,8 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
self.model.aux_stream_dict)
self.model.layers.append(mtp_layer)
self.epilogue.append(mtp_layer)
self.mtp_worker = MTPEagleWorker(model_config.spec_config)
self.mtp_worker = MTPEagleWorker(model_config.spec_config,
model_config)
else:
mtp_layers = nn.ModuleList([
DeepseekV3MTP(model_config,
@@ -1084,7 +1130,8 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
])
self.model.layers.extend(mtp_layers)
self.epilogue.extend(mtp_layers)
self.mtp_worker = MTPWorker(model_config.spec_config)
self.mtp_worker = MTPWorker(model_config.spec_config,
model_config)
# modify the QuantConfig to support duplicated mtp layers
if model_config.quant_config.exclude_modules is not None:
extend_exclude_modules = []

View File

@@ -359,6 +359,7 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
self.draft_model = get_draft_model(model_config, draft_config)
self.spec_worker = get_spec_worker(model_config.spec_config,
model_config,
model_config.mapping)
def forward(

View File

@@ -1,10 +1,12 @@
from dataclasses import dataclass
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional
import torch
from torch import nn
from ..attention_backend import AttentionMetadata
from ..distributed.ops import allgather
from ..model_config import ModelConfig
from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
from ..pyexecutor.sampler import (SampleState, SampleStateTensors, TorchSampler,
@@ -12,6 +14,9 @@ from ..pyexecutor.sampler import (SampleState, SampleStateTensors, TorchSampler,
from ..pyexecutor.scheduler import ScheduledRequests
from .interface import SpecMetadata
if TYPE_CHECKING:
from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig
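
Note: the TYPE_CHECKING guard keeps MTPDecodingConfig visible to type checkers without importing tensorrt_llm.llmapi.llm_args at runtime; the string annotations on the worker constructors below rely on it. A minimal illustration with a made-up module name:

from typing import TYPE_CHECKING

if TYPE_CHECKING:                           # evaluated by type checkers only, never at runtime
    from heavy_module import HeavyConfig    # hypothetical import that would be costly or circular

def build(config: "HeavyConfig") -> None:   # string annotation, resolved lazily by the checker
    ...
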
@dataclass(kw_only=True)
class SampleStateTensorsMTP(SampleStateTensors):
@@ -311,9 +316,10 @@ class MTPSampler(TorchSampler):
class MTPWorker(nn.Module):
def __init__(self, spec_config: "MTPDecodingConfig"):
def __init__(self, spec_config: "MTPDecodingConfig", model_config=None):
super().__init__()
self.spec_config = spec_config
self.model_config = model_config
self.is_thop = False
def forward(
@@ -670,6 +676,26 @@ class MTPWorker(nn.Module):
mtp_past_hidden_states_pool.index_copy_(0, slot_ids,
new_mtp_past_hidden_states)
@torch.compile(options={"max-autotune": True})
def topk_kernel(self, gen_logprobs, num_gens, mtp_num_modules,
spec_metadata):
topk_value, topk_indices = torch.topk(gen_logprobs,
k=self.spec_config.relaxed_topk,
dim=-1)
topk_indices = topk_indices.reshape(num_gens, mtp_num_modules + 1,
self.spec_config.relaxed_topk)
topk_value = topk_value.reshape(num_gens, mtp_num_modules + 1,
self.spec_config.relaxed_topk)
draft_tokens = spec_metadata.draft_tokens.reshape(
num_gens, mtp_num_modules)
return topk_value, topk_indices, draft_tokens
@torch.compile(options={"max-autotune": True})
def process_generation_logits(self, logits, num_contexts):
gen_logits = logits[num_contexts:]
gen_logprobs = torch.softmax(gen_logits, dim=-1)
return gen_logprobs
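
Note: process_generation_logits and topk_kernel only repackage the softmax and top-k prep into compiled helpers; the acceptance decision itself stays in the fused torch.ops.trtllm.mtp_relaxed_acceptance_op call further down. Conceptually, relaxed acceptance keeps a drafted token if it lands in the target model's top-k at that position (the real op also applies the relaxed-delta threshold tracked in mtp_relaxed_delta_pool, which this sketch omits). A pure-PyTorch sketch of the top-k membership part only:

import torch

def relaxed_accept_count(topk_indices, draft_tokens):
    # topk_indices: [num_gens, mtp_num_modules + 1, relaxed_topk]
    # draft_tokens: [num_gens, mtp_num_modules]
    in_topk = (topk_indices[:, :-1, :] == draft_tokens.unsqueeze(-1)).any(dim=-1)
    accepted = torch.cumprod(in_topk.int(), dim=-1)    # accept left-to-right, stop at first miss
    return accepted.sum(dim=-1) + 1                    # +1 for the token sampled this step
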
def sample_and_accept_draft_tokens(
self,
input_ids: torch.IntTensor,
@@ -787,20 +813,9 @@ class MTPWorker(nn.Module):
mtp_relaxed_delta_pool.index_copy_(0, ctx_slot_ids, ctx_delta)
# generation
gen_logits = logits[num_contexts:]
gen_logprobs = torch.softmax(gen_logits, dim=-1)
topk_value, topk_indices = torch.topk(
gen_logprobs, k=self.spec_config.relaxed_topk, dim=-1)
# [num_gens, mtp_num_modules + 1, relaxed_topk]
topk_indices = topk_indices.reshape(num_gens, mtp_num_modules + 1,
self.spec_config.relaxed_topk)
topk_value = topk_value.reshape(num_gens, mtp_num_modules + 1,
self.spec_config.relaxed_topk)
# [num_gens, mtp_num_modules]
draft_tokens = spec_metadata.draft_tokens.reshape(
num_gens, mtp_num_modules)
gen_logprobs = self.process_generation_logits(logits, num_contexts)
topk_value, topk_indices, draft_tokens = self.topk_kernel(
gen_logprobs, num_gens, mtp_num_modules, spec_metadata)
accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_relaxed_acceptance_op(
spec_metadata.slot_ids, topk_value, topk_indices, draft_tokens,
@@ -1024,6 +1039,37 @@ class MTPWorker(nn.Module):
"attn_metadata": attn_metadata,
}
@torch.compile(options={"max-autotune": True})
def get_local_max_and_combined(self, logits):
local_max_values, local_argmax = torch.max(logits, dim=-1, keepdim=True)
# Adjust indices based on TP rank and size
vocab_per_rank = logits.shape[-1]
max_index_per_rank = local_argmax.type(
torch.int32) + (self.model_config.mapping.tp_rank * vocab_per_rank)
# Use torch.stack and flatten instead of view+cat to avoid torch.compile issues
# Convert both to float32 to ensure consistent dtype
max_index_per_rank_float = max_index_per_rank.float()
local_max_values_float32 = local_max_values.float()
# Stack and flatten to get interleaved layout: [idx0, val0, idx1, val1, ...]
combined = torch.stack(
[max_index_per_rank_float, local_max_values_float32],
dim=-1).flatten(-2)
return combined
@torch.compile(options={"max-autotune": True})
def get_draft_tokens_from_gathered(self, gathered):
gathered_indices_float = gathered[..., 0::2] # Even positions: indices
gathered_values_float = gathered[..., 1::2] # Odd positions: values
# Find the rank with maximum value
max_indices = torch.argmax(gathered_values_float, dim=-1, keepdim=True)
# Get the corresponding token indices and convert back to int32
draft_tokens = torch.gather(gathered_indices_float, -1,
max_indices).squeeze(-1).type(torch.int32)
return draft_tokens
def draft_sampler(
self,
logits: torch.Tensor,
@@ -1041,17 +1087,38 @@ class MTPWorker(nn.Module):
[batch_size * max_draft_len]
Draft token ids. Flattened.
'''
if (self.model_config is not None
and hasattr(self.model_config, 'mapping')
and self.model_config.mapping.tp_size
> 1) and not (self.model_config.mapping.enable_attention_dp):
combined = self.get_local_max_and_combined(logits)
gathered = allgather(combined, self.model_config.mapping, dim=-1)
draft_tokens = self.get_draft_tokens_from_gathered(gathered)
else:
# Simple argmax if no TP or no model config
draft_tokens = torch.argmax(logits, dim=-1).type(torch.int32)
draft_tokens = torch.argmax(logits, dim=-1).type(torch.int32)
return draft_tokens
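
Note: when attention DP is off, the MTP head leaves lm_head.gather_output disabled, so each rank's logits cover only its vocabulary shard. The two compiled helpers above recover a global greedy token by offsetting each rank's local argmax into the full vocabulary, interleaving (index, value) pairs, all-gathering them, and picking the index whose value wins. A single-process sketch (the rank loop plus torch.cat emulates the allgather) showing this equals a full-vocab argmax:

import torch

tp_size, vocab_per_rank = 4, 8
logits = torch.randn(5, tp_size * vocab_per_rank)        # full-vocab logits for 5 draft positions

combined_per_rank = []
for tp_rank in range(tp_size):
    shard = logits[:, tp_rank * vocab_per_rank:(tp_rank + 1) * vocab_per_rank]
    local_max, local_argmax = torch.max(shard, dim=-1, keepdim=True)
    global_idx = local_argmax.float() + tp_rank * vocab_per_rank
    # interleaved [idx, val] layout, as in get_local_max_and_combined
    combined_per_rank.append(torch.stack([global_idx, local_max.float()], dim=-1).flatten(-2))
gathered = torch.cat(combined_per_rank, dim=-1)          # stands in for allgather(..., dim=-1)

idx, val = gathered[..., 0::2], gathered[..., 1::2]
winner = torch.argmax(val, dim=-1, keepdim=True)         # rank holding the global max
draft_tokens = torch.gather(idx, -1, winner).squeeze(-1).long()
print(torch.equal(draft_tokens, torch.argmax(logits, dim=-1)))   # True
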
class MTPEagleWorker(MTPWorker):
def __init__(self, spec_config: "MTPDecodingConfig"):
super().__init__(spec_config)
def __init__(self,
spec_config: "MTPDecodingConfig",
model_config: Optional[ModelConfig] = None):
super().__init__(spec_config, model_config)
self.model_config = model_config
self.mtp_num_modules = spec_config.num_nextn_predict_layers
@torch.compile(options={"max-autotune": True})
def update_draft_tokens(self, next_draft_tokens, new_draft_token,
hidden_states, gather_ids, inputs):
next_draft_tokens.append(new_draft_token)
# update inputs
hidden_states = hidden_states[gather_ids]
position_ids = inputs["position_ids"][gather_ids] + 1
return hidden_states, position_ids
def forward(
self,
input_ids,
@@ -1079,9 +1146,15 @@ class MTPEagleWorker(MTPWorker):
seq_len_cuda = attn_metadata._seq_lens_cuda[:batch_size].clone()
# Prepare inputs for the 1st MTP layer
position_ids = position_ids.squeeze(0)
last_tokens_idx = torch.cumsum(
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
@torch.compile(options={"max-autotune": True})
def prepare_position_ids_and_last_tokens(position_ids, attn_metadata):
position_ids = position_ids.squeeze(0)
last_tokens_idx = torch.cumsum(
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
return position_ids, last_tokens_idx
position_ids, last_tokens_idx = prepare_position_ids_and_last_tokens(
position_ids, attn_metadata)
inputs = self.prepare_drafter_inputs(input_ids=input_ids,
position_ids=position_ids,
last_tokens_idx=last_tokens_idx,
@@ -1122,10 +1195,10 @@ class MTPEagleWorker(MTPWorker):
logits = mtp_layers[0].shared_head(hidden_states[gather_ids],
lm_head, attn_metadata, True)
new_draft_token = self.draft_sampler(logits)
next_draft_tokens.append(new_draft_token)
# update inputs
hidden_states = hidden_states[gather_ids]
position_ids = inputs["position_ids"][gather_ids] + 1
hidden_states, position_ids = self.update_draft_tokens(
next_draft_tokens, new_draft_token, hidden_states, gather_ids,
inputs)
# update attn_metadata
if i == 0:
attn_metadata._seq_lens[:batch_size].fill_(1)
@@ -1154,14 +1227,18 @@ class MTPEagleWorker(MTPWorker):
attn_metadata.block_ids_per_seq[:batch_size, :].copy_(
reorder_block_ids_per_seq, non_blocking=True)
elif hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[:batch_size] += 1
@torch.compile(options={"max-autotune": True})
def update_kv_lens(kv_lens_cuda, batch_size):
kv_lens_cuda[:batch_size] += 1
update_kv_lens(attn_metadata.kv_lens_cuda, batch_size)
inputs = {
"input_ids": new_draft_token,
"position_ids": position_ids,
"hidden_states": hidden_states,
"attn_metadata": attn_metadata,
}
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
# restore attn_metadata to support cuda graph
if attn_metadata.is_cuda_graph:
@@ -1169,12 +1246,21 @@ class MTPEagleWorker(MTPWorker):
attn_metadata._seq_lens_cuda[:batch_size].copy_(seq_len_cuda)
attn_metadata.on_update()
# prepare next new tokens to support overlap scheduler
next_new_tokens = accepted_tokens[
spec_metadata.batch_indices_cuda[:batch_size],
num_accepted_tokens - 1].unsqueeze(1)
next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens],
dim=1)
@torch.compile(options={"max-autotune": True})
def prepare_next_tokens(next_draft_tokens, accepted_tokens,
spec_metadata, batch_size, num_accepted_tokens):
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
# prepare next new tokens to support overlap scheduler
next_new_tokens = accepted_tokens[
spec_metadata.batch_indices_cuda[:batch_size],
num_accepted_tokens - 1].unsqueeze(1)
next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens],
dim=1)
return next_draft_tokens, next_new_tokens
next_draft_tokens, next_new_tokens = prepare_next_tokens(
next_draft_tokens, accepted_tokens, spec_metadata, batch_size,
num_accepted_tokens)
return {
'logits': raw_logits,
@@ -1184,6 +1270,7 @@ class MTPEagleWorker(MTPWorker):
'next_new_tokens': next_new_tokens
}
@torch.compile(options={"max-autotune": True})
def prepare_drafter_inputs(
self,
input_ids: torch.IntTensor,

View File

@@ -145,11 +145,11 @@ def get_num_spec_layers(spec_config):
return 0
def get_spec_worker(spec_config, mapping):
def get_spec_worker(spec_config, model_config, mapping):
if spec_config.spec_dec_mode.is_mtp():
return MTPWorker(spec_config)
return MTPWorker(spec_config, model_config)
if spec_config.spec_dec_mode.is_mtp_eagle():
return MTPEagleWorker(spec_config)
return MTPEagleWorker(spec_config, model_config)
if spec_config.spec_dec_mode.is_eagle3_one_model():
return Eagle3OneModelWorker(spec_config, mapping)
return None