Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
MTP optimizations round 1 (#5689)
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Co-authored-by: Kefeng-Duan <176893526+Kefeng-Duan@users.noreply.github.com>
parent 7bff341553
commit 1e5e71aa42
@@ -131,11 +131,21 @@ class DeepseekV3MTPHead(nn.Module):
     def __init__(self, model_config: ModelConfig[PretrainedConfig]):
         super().__init__()
         config = model_config.pretrained_config
+        self.model_config = model_config
 
         self.norm = RMSNorm(hidden_size=config.hidden_size,
                             eps=config.rms_norm_eps,
                             dtype=config.torch_dtype)
 
+    @torch.compile(options={"max-autotune": True})
+    def get_last_token_states(self, hidden_states, attn_metadata):
+        last_tokens = torch.cumsum(
+            attn_metadata.seq_lens_cuda,
+            dim=0,
+            dtype=torch.long,
+        ) - 1
+        return hidden_states[last_tokens]
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 lm_head: Linear,
@@ -143,16 +153,16 @@ class DeepseekV3MTPHead(nn.Module):
                 return_context_logits: bool = False) -> torch.Tensor:
         if not return_context_logits:
             if attn_metadata is not None:
-                last_tokens = torch.cumsum(
-                    attn_metadata.seq_lens_cuda,
-                    dim=0,
-                    dtype=torch.long,
-                ) - 1
-                hidden_states = hidden_states[last_tokens]
+                hidden_states = self.get_last_token_states(
+                    hidden_states, attn_metadata)
             else:
                 hidden_states = hidden_states[-1].unsqueeze(0)
 
+        if not (self.model_config.mapping.enable_attention_dp):
+            lm_head.gather_output = False
         logits = lm_head(hidden_states)
+        if not (self.model_config.mapping.enable_attention_dp):
+            lm_head.gather_output = True
         return logits
 
 
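For reference, get_last_token_states gathers the hidden state of each sequence's final token from a packed batch: a cumulative sum of the per-sequence lengths gives the end offsets, and subtracting one turns them into last-token indices. Below is a minimal standalone sketch of the same indexing (tensor and function names here are illustrative, not from the repository); hoisting the chain into a torch.compile-decorated method presumably lets the cumsum/subtract/index ops fuse into fewer kernel launches.

import torch

def last_token_states(hidden_states: torch.Tensor,
                      seq_lens: torch.Tensor) -> torch.Tensor:
    # hidden_states: [total_tokens, hidden], sequences packed back to back
    # seq_lens:      [num_seqs]
    # cumsum yields each sequence's end offset; -1 gives its last-token index
    last_tokens = torch.cumsum(seq_lens, dim=0, dtype=torch.long) - 1
    return hidden_states[last_tokens]

# Three sequences of lengths 2, 3 and 1 packed into 6 token rows.
hidden = torch.arange(6, dtype=torch.float32).unsqueeze(-1)  # [6, 1]
lens = torch.tensor([2, 3, 1])
print(last_token_states(hidden, lens).squeeze(-1))  # tensor([1., 4., 5.])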
@@ -903,6 +913,12 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
         self.num_shared_experts = config.n_shared_experts
         self.top_k = config.num_experts_per_tok
 
+        self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
+        self.event_dict = {
+            key: torch.cuda.Event()
+            for key in [EventType.Main, EventType.MoeShared]
+        }
+
         self.enorm = RMSNorm(hidden_size=config.hidden_size,
                              eps=config.rms_norm_eps,
                              dtype=config.torch_dtype)
@@ -910,15 +926,27 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
         self.hnorm = RMSNorm(hidden_size=config.hidden_size,
                              eps=config.rms_norm_eps,
                              dtype=config.torch_dtype)
 
-        self.eh_proj = Linear(
-            config.hidden_size * 2,
-            config.hidden_size,
-            bias=False,
-            dtype=config.torch_dtype,
-            skip_create_weights_in_init=model_config.
-            skip_create_weights_in_init,
-        )
+        if model_config.mapping.enable_attention_dp:
+            self.eh_proj = Linear(
+                config.hidden_size * 2,
+                config.hidden_size,
+                bias=False,
+                dtype=config.torch_dtype,
+                skip_create_weights_in_init=model_config.
+                skip_create_weights_in_init,
+            )
+        else:
+            self.eh_proj = Linear(
+                config.hidden_size * 2,
+                config.hidden_size,
+                bias=False,
+                dtype=config.torch_dtype,
+                tensor_parallel_mode=TensorParallelMode.ROW,
+                mapping=model_config.mapping,
+                reduce_output=True,
+                skip_create_weights_in_init=model_config.
+                skip_create_weights_in_init,
+            )
 
         self.shared_head = DeepseekV3MTPHead(model_config)
 
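In the non-attention-DP branch, eh_proj becomes a ROW-parallel Linear with reduce_output=True: each TP rank holds a slice of the 2*hidden input dimension and the partial outputs are summed across ranks, with the forward hunk below feeding it the matching torch.chunk slice. A single-process sketch of why that matches the full projection, assuming the ROW mode splits the weight along its input dimension (sizes and weights here are synthetic):

import torch

torch.manual_seed(0)
tp_size, hidden = 4, 16
x = torch.randn(3, 2 * hidden)        # concat of [embeds, hidden] per token
w = torch.randn(2 * hidden, hidden)   # eh_proj weight (input dim = 2 * hidden)

ref = x @ w                           # single-GPU reference

partials = []
for rank in range(tp_size):
    x_rank = torch.chunk(x, tp_size, dim=-1)[rank]   # slice fed by forward()
    w_rank = torch.chunk(w, tp_size, dim=0)[rank]    # matching weight rows
    partials.append(x_rank @ w_rank)                 # local partial product

out = torch.stack(partials).sum(0)    # reduce_output=True ~ all-reduce (sum)
print(torch.allclose(ref, out, atol=1e-4))           # True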
@@ -934,9 +962,26 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
-        inputs_embeds = self.enorm(embed_tokens(input_ids))
-        hidden_states = self.hnorm(hidden_states)
+        def norm_embeds():
+            return self.enorm(embed_tokens(input_ids))  #emdedding
+
+        def norm_hidden():
+            return self.hnorm(hidden_states)
+
+        inputs_embeds, hidden_states = maybe_execute_in_parallel(
+            norm_embeds,
+            norm_hidden,
+            self.event_dict[EventType.Main],
+            self.event_dict[EventType.MoeShared],
+            self.aux_stream,
+        )
         hidden_states = torch.concat([inputs_embeds, hidden_states], dim=-1)
+        # Split hidden_states columnwise based on TP
+        tp_size = self.model_config.mapping.tp_size
+        tp_rank = self.model_config.mapping.tp_rank
+
+        if tp_size > 1 and not (self.model_config.mapping.enable_attention_dp):
+            hidden_states = torch.chunk(hidden_states, tp_size, dim=-1)[tp_rank]
         hidden_states = self.eh_proj(hidden_states)
 
         # Input layer norm
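The two norms are independent (enorm reads the embeddings, hnorm the incoming hidden states), so routing one of them through the MoeShared auxiliary stream lets the two small kernels overlap instead of running back to back. A rough sketch of the two-stream/event pattern this relies on; the helper below is an illustration of the usual recipe, not the repository's maybe_execute_in_parallel, and it simply runs sequentially when CUDA is unavailable:

import torch

def run_in_parallel(fn_main, fn_aux, event_main, event_aux, aux_stream):
    if aux_stream is None:                 # no CUDA: run sequentially
        return fn_main(), fn_aux()
    event_main.record()                    # mark where the main stream is
    out_main = fn_main()                   # main-stream branch
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(event_main)  # start only after prior main work
        out_aux = fn_aux()                 # aux-stream branch, overlaps fn_main
        event_aux.record()
    torch.cuda.current_stream().wait_event(event_aux)  # rejoin before use
    return out_main, out_aux

dev = "cuda" if torch.cuda.is_available() else "cpu"
aux = torch.cuda.Stream() if dev == "cuda" else None
ev0 = torch.cuda.Event() if dev == "cuda" else None
ev1 = torch.cuda.Event() if dev == "cuda" else None

a = torch.randn(8, 64, device=dev)
b = torch.randn(8, 64, device=dev)
rms = lambda t: t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + 1e-6)
out_a, out_b = run_in_parallel(lambda: rms(a), lambda: rms(b), ev0, ev1, aux)
print(out_a.shape, out_b.shape)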
@@ -1074,7 +1119,8 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
                                         self.model.aux_stream_dict)
             self.model.layers.append(mtp_layer)
             self.epilogue.append(mtp_layer)
-            self.mtp_worker = MTPEagleWorker(model_config.spec_config)
+            self.mtp_worker = MTPEagleWorker(model_config.spec_config,
+                                             model_config)
         else:
             mtp_layers = nn.ModuleList([
                 DeepseekV3MTP(model_config,
@@ -1084,7 +1130,8 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
             ])
             self.model.layers.extend(mtp_layers)
             self.epilogue.extend(mtp_layers)
-            self.mtp_worker = MTPWorker(model_config.spec_config)
+            self.mtp_worker = MTPWorker(model_config.spec_config,
+                                        model_config)
             # modify the QuantConfig to support duplicated mtp layers
             if model_config.quant_config.exclude_modules is not None:
                 extend_exclude_modules = []
@@ -359,6 +359,7 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
 
         self.draft_model = get_draft_model(model_config, draft_config)
         self.spec_worker = get_spec_worker(model_config.spec_config,
+                                           model_config,
                                            model_config.mapping)
 
     def forward(
@@ -1,10 +1,12 @@
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 from torch import nn
 
 from ..attention_backend import AttentionMetadata
+from ..distributed.ops import allgather
+from ..model_config import ModelConfig
 from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
 from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
 from ..pyexecutor.sampler import (SampleState, SampleStateTensors, TorchSampler,
@@ -12,6 +14,9 @@ from ..pyexecutor.sampler import (SampleState, SampleStateTensors, TorchSampler,
 from ..pyexecutor.scheduler import ScheduledRequests
 from .interface import SpecMetadata
 
+if TYPE_CHECKING:
+    from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig
+
 
 @dataclass(kw_only=True)
 class SampleStateTensorsMTP(SampleStateTensors):
@@ -311,9 +316,10 @@ class MTPSampler(TorchSampler):
 
 class MTPWorker(nn.Module):
 
-    def __init__(self, spec_config: "MTPDecodingConfig"):
+    def __init__(self, spec_config: "MTPDecodingConfig", model_config=None):
         super().__init__()
         self.spec_config = spec_config
+        self.model_config = model_config
         self.is_thop = False
 
     def forward(
@@ -670,6 +676,26 @@ class MTPWorker(nn.Module):
         mtp_past_hidden_states_pool.index_copy_(0, slot_ids,
                                                 new_mtp_past_hidden_states)
 
+    @torch.compile(options={"max-autotune": True})
+    def topk_kernel(self, gen_logprobs, num_gens, mtp_num_modules,
+                    spec_metadata):
+        topk_value, topk_indices = torch.topk(gen_logprobs,
+                                              k=self.spec_config.relaxed_topk,
+                                              dim=-1)
+        topk_indices = topk_indices.reshape(num_gens, mtp_num_modules + 1,
+                                            self.spec_config.relaxed_topk)
+        topk_value = topk_value.reshape(num_gens, mtp_num_modules + 1,
+                                        self.spec_config.relaxed_topk)
+        draft_tokens = spec_metadata.draft_tokens.reshape(
+            num_gens, mtp_num_modules)
+        return topk_value, topk_indices, draft_tokens
+
+    @torch.compile(options={"max-autotune": True})
+    def process_generation_logits(self, logits, num_contexts):
+        gen_logits = logits[num_contexts:]
+        gen_logprobs = torch.softmax(gen_logits, dim=-1)
+        return gen_logprobs
+
     def sample_and_accept_draft_tokens(
         self,
         input_ids: torch.IntTensor,
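topk_kernel and process_generation_logits hoist the softmax / top-k / reshape chain used by relaxed acceptance into torch.compile-decorated methods so the small elementwise and reshape ops can be fused. Generation logits arrive with one row per (request, module) position, i.e. num_gens * (mtp_num_modules + 1) rows, and the top relaxed_topk candidates are kept per position. A toy sketch of the same reshaping with made-up sizes:

import torch

num_gens, mtp_num_modules, relaxed_topk, vocab = 2, 3, 4, 32

# One row of logits per (request, module) position.
gen_logits = torch.randn(num_gens * (mtp_num_modules + 1), vocab)
gen_logprobs = torch.softmax(gen_logits, dim=-1)

topk_value, topk_indices = torch.topk(gen_logprobs, k=relaxed_topk, dim=-1)
# Regroup per request: [num_gens, mtp_num_modules + 1, relaxed_topk]
topk_value = topk_value.reshape(num_gens, mtp_num_modules + 1, relaxed_topk)
topk_indices = topk_indices.reshape(num_gens, mtp_num_modules + 1, relaxed_topk)

# Draft tokens come in flattened as [num_gens * mtp_num_modules].
draft_tokens = torch.randint(vocab, (num_gens * mtp_num_modules,))
draft_tokens = draft_tokens.reshape(num_gens, mtp_num_modules)
print(topk_value.shape, topk_indices.shape, draft_tokens.shape)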
@@ -787,20 +813,9 @@ class MTPWorker(nn.Module):
             mtp_relaxed_delta_pool.index_copy_(0, ctx_slot_ids, ctx_delta)
 
             # generation
-            gen_logits = logits[num_contexts:]
-            gen_logprobs = torch.softmax(gen_logits, dim=-1)
-
-            topk_value, topk_indices = torch.topk(
-                gen_logprobs, k=self.spec_config.relaxed_topk, dim=-1)
-            # [num_gens, mtp_num_modules + 1, relaxed_topk]
-            topk_indices = topk_indices.reshape(num_gens, mtp_num_modules + 1,
-                                                self.spec_config.relaxed_topk)
-            topk_value = topk_value.reshape(num_gens, mtp_num_modules + 1,
-                                            self.spec_config.relaxed_topk)
-
-            # [num_gens, mtp_num_modules]
-            draft_tokens = spec_metadata.draft_tokens.reshape(
-                num_gens, mtp_num_modules)
+            gen_logprobs = self.process_generation_logits(logits, num_contexts)
+            topk_value, topk_indices, draft_tokens = self.topk_kernel(
+                gen_logprobs, num_gens, mtp_num_modules, spec_metadata)
 
             accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_relaxed_acceptance_op(
                 spec_metadata.slot_ids, topk_value, topk_indices, draft_tokens,
@@ -1024,6 +1039,37 @@ class MTPWorker(nn.Module):
             "attn_metadata": attn_metadata,
         }
 
+    @torch.compile(options={"max-autotune": True})
+    def get_local_max_and_combined(self, logits):
+        local_max_values, local_argmax = torch.max(logits, dim=-1, keepdim=True)
+        # Adjust indices based on TP rank and size
+        vocab_per_rank = logits.shape[-1]
+        max_index_per_rank = local_argmax.type(
+            torch.int32) + (self.model_config.mapping.tp_rank * vocab_per_rank)
+        # Use torch.stack and flatten instead of view+cat to avoid torch.compile issues
+        # Convert both to float32 to ensure consistent dtype
+        max_index_per_rank_float = max_index_per_rank.float()
+        local_max_values_float32 = local_max_values.float()
+
+        # Stack and flatten to get interleaved layout: [idx0, val0, idx1, val1, ...]
+        combined = torch.stack(
+            [max_index_per_rank_float, local_max_values_float32],
+            dim=-1).flatten(-2)
+        return combined
+
+    @torch.compile(options={"max-autotune": True})
+    def get_draft_tokens_from_gathered(self, gathered):
+        gathered_indices_float = gathered[..., 0::2]  # Even positions: indices
+        gathered_values_float = gathered[..., 1::2]  # Odd positions: values
+
+        # Find the rank with maximum value
+        max_indices = torch.argmax(gathered_values_float, dim=-1, keepdim=True)
+
+        # Get the corresponding token indices and convert back to int32
+        draft_tokens = torch.gather(gathered_indices_float, -1,
+                                    max_indices).squeeze(-1).type(torch.int32)
+        return draft_tokens
+
     def draft_sampler(
         self,
         logits: torch.Tensor,
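get_local_max_and_combined and get_draft_tokens_from_gathered implement a TP-wide argmax without gathering full-vocabulary logits: each rank finds its local maximum, converts the local argmax into a global vocabulary id using its tp_rank offset, packs (index, value) pairs interleaved in float32, and after the allgather the winning token is the index whose value is largest across ranks; the draft_sampler hunk below wires these together. A single-process simulation of that logic (the loop over ranks stands in for the real allgather):

import torch

tp_size, vocab_total, batch = 4, 32, 3
vocab_per_rank = vocab_total // tp_size
full_logits = torch.randn(batch, vocab_total)

combined_per_rank = []
for tp_rank in range(tp_size):
    logits = full_logits[:, tp_rank * vocab_per_rank:(tp_rank + 1) * vocab_per_rank]
    local_max, local_argmax = torch.max(logits, dim=-1, keepdim=True)
    # Shift local indices into the global vocabulary.
    global_idx = local_argmax.int() + tp_rank * vocab_per_rank
    # Interleave [idx, val] in float32 so one tensor can be all-gathered.
    combined_per_rank.append(
        torch.stack([global_idx.float(), local_max.float()], dim=-1).flatten(-2))

# allgather(..., dim=-1) would produce this concatenation on every rank.
gathered = torch.cat(combined_per_rank, dim=-1)
indices = gathered[..., 0::2]   # even slots: global token ids
values = gathered[..., 1::2]    # odd slots: local max values
best_rank = torch.argmax(values, dim=-1, keepdim=True)
draft_tokens = torch.gather(indices, -1, best_rank).squeeze(-1).int()

print(torch.equal(draft_tokens, torch.argmax(full_logits, dim=-1).int()))  # True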
@@ -1041,17 +1087,38 @@ class MTPWorker(nn.Module):
                 [batch_size * max_draft_len]
                 Draft token ids. Flattened.
         '''
+        if (self.model_config is not None
+                and hasattr(self.model_config, 'mapping')
+                and self.model_config.mapping.tp_size
+                > 1) and not (self.model_config.mapping.enable_attention_dp):
+            combined = self.get_local_max_and_combined(logits)
+            gathered = allgather(combined, self.model_config.mapping, dim=-1)
+            draft_tokens = self.get_draft_tokens_from_gathered(gathered)
+        else:
+            # Simple argmax if no TP or no model config
+            draft_tokens = torch.argmax(logits, dim=-1).type(torch.int32)
 
-        draft_tokens = torch.argmax(logits, dim=-1).type(torch.int32)
         return draft_tokens
 
 
 class MTPEagleWorker(MTPWorker):
 
-    def __init__(self, spec_config: "MTPDecodingConfig"):
-        super().__init__(spec_config)
+    def __init__(self,
+                 spec_config: "MTPDecodingConfig",
+                 model_config: Optional[ModelConfig] = None):
+        super().__init__(spec_config, model_config)
+        self.model_config = model_config
         self.mtp_num_modules = spec_config.num_nextn_predict_layers
 
+    @torch.compile(options={"max-autotune": True})
+    def update_draft_tokens(self, next_draft_tokens, new_draft_token,
+                            hidden_states, gather_ids, inputs):
+        next_draft_tokens.append(new_draft_token)
+        # update inputs
+        hidden_states = hidden_states[gather_ids]
+        position_ids = inputs["position_ids"][gather_ids] + 1
+        return hidden_states, position_ids
+
     def forward(
         self,
         input_ids,
@@ -1079,9 +1146,15 @@ class MTPEagleWorker(MTPWorker):
             seq_len_cuda = attn_metadata._seq_lens_cuda[:batch_size].clone()
 
         # Prepare inputs for the 1st MTP layer
-        position_ids = position_ids.squeeze(0)
-        last_tokens_idx = torch.cumsum(
-            attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
+        @torch.compile(options={"max-autotune": True})
+        def prepare_position_ids_and_last_tokens(position_ids, attn_metadata):
+            position_ids = position_ids.squeeze(0)
+            last_tokens_idx = torch.cumsum(
+                attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
+            return position_ids, last_tokens_idx
+
+        position_ids, last_tokens_idx = prepare_position_ids_and_last_tokens(
+            position_ids, attn_metadata)
         inputs = self.prepare_drafter_inputs(input_ids=input_ids,
                                              position_ids=position_ids,
                                              last_tokens_idx=last_tokens_idx,
@@ -1122,10 +1195,10 @@ class MTPEagleWorker(MTPWorker):
                 logits = mtp_layers[0].shared_head(hidden_states[gather_ids],
                                                    lm_head, attn_metadata, True)
             new_draft_token = self.draft_sampler(logits)
-            next_draft_tokens.append(new_draft_token)
-            # update inputs
-            hidden_states = hidden_states[gather_ids]
-            position_ids = inputs["position_ids"][gather_ids] + 1
+
+            hidden_states, position_ids = self.update_draft_tokens(
+                next_draft_tokens, new_draft_token, hidden_states, gather_ids,
+                inputs)
             # update attn_metadata
             if i == 0:
                 attn_metadata._seq_lens[:batch_size].fill_(1)
@@ -1154,14 +1227,18 @@ class MTPEagleWorker(MTPWorker):
                     attn_metadata.block_ids_per_seq[:batch_size, :].copy_(
                         reorder_block_ids_per_seq, non_blocking=True)
                 elif hasattr(attn_metadata, 'kv_lens_cuda'):
-                    attn_metadata.kv_lens_cuda[:batch_size] += 1
+
+                    @torch.compile(options={"max-autotune": True})
+                    def update_kv_lens(kv_lens_cuda, batch_size):
+                        kv_lens_cuda[:batch_size] += 1
+
+                    update_kv_lens(attn_metadata.kv_lens_cuda, batch_size)
             inputs = {
                 "input_ids": new_draft_token,
                 "position_ids": position_ids,
                 "hidden_states": hidden_states,
                 "attn_metadata": attn_metadata,
             }
-        next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
 
         # restore attn_metadata to support cuda graph
         if attn_metadata.is_cuda_graph:
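The one-line kv-length bump is now wrapped in a torch.compile-decorated closure so it can be captured alongside the other compiled pieces. In-place slice updates on an argument survive compilation and remain visible to the caller, as this small standalone sketch shows (synthetic tensor; assumes a PyTorch build where torch.compile/inductor works on the current device):

import torch

@torch.compile(options={"max-autotune": True})
def bump_kv_lens(kv_lens: torch.Tensor, batch_size: int) -> None:
    # In-place slice increment; the mutation is visible to the caller.
    kv_lens[:batch_size] += 1

kv_lens = torch.tensor([5, 7, 9, 0])
bump_kv_lens(kv_lens, 3)
print(kv_lens)  # tensor([ 6,  8, 10,  0])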
@@ -1169,12 +1246,21 @@ class MTPEagleWorker(MTPWorker):
             attn_metadata._seq_lens_cuda[:batch_size].copy_(seq_len_cuda)
             attn_metadata.on_update()
 
-        # prepare next new tokens to support overlap scheduler
-        next_new_tokens = accepted_tokens[
-            spec_metadata.batch_indices_cuda[:batch_size],
-            num_accepted_tokens - 1].unsqueeze(1)
-        next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens],
-                                       dim=1)
+        @torch.compile(options={"max-autotune": True})
+        def prepare_next_tokens(next_draft_tokens, accepted_tokens,
+                                spec_metadata, batch_size, num_accepted_tokens):
+            next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
+            # prepare next new tokens to support overlap scheduler
+            next_new_tokens = accepted_tokens[
+                spec_metadata.batch_indices_cuda[:batch_size],
+                num_accepted_tokens - 1].unsqueeze(1)
+            next_new_tokens = torch.concat([next_new_tokens, next_draft_tokens],
+                                           dim=1)
+            return next_draft_tokens, next_new_tokens
+
+        next_draft_tokens, next_new_tokens = prepare_next_tokens(
+            next_draft_tokens, accepted_tokens, spec_metadata, batch_size,
+            num_accepted_tokens)
 
         return {
             'logits': raw_logits,
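prepare_next_tokens stacks the per-step draft tokens into a [batch_size, num_steps] tensor and, for the overlap scheduler, prepends each request's last accepted token, selected with batch_indices_cuda and num_accepted_tokens. A shape-level sketch of that gather/concat with synthetic tensors (all sizes are illustrative):

import torch

batch_size, steps, max_accepted = 3, 2, 4
accepted_tokens = torch.arange(batch_size * max_accepted).reshape(batch_size, max_accepted)
num_accepted_tokens = torch.tensor([1, 3, 2])          # per-request accepted count
batch_indices = torch.arange(batch_size)
next_draft_tokens = [torch.full((batch_size,), 100 + s) for s in range(steps)]

draft = torch.stack(next_draft_tokens, dim=1)                          # [batch, steps]
last_accepted = accepted_tokens[batch_indices,
                                num_accepted_tokens - 1].unsqueeze(1)  # [batch, 1]
next_new_tokens = torch.concat([last_accepted, draft], dim=1)          # [batch, steps + 1]
print(next_new_tokens)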
@@ -1184,6 +1270,7 @@ class MTPEagleWorker(MTPWorker):
             'next_new_tokens': next_new_tokens
         }
 
+    @torch.compile(options={"max-autotune": True})
     def prepare_drafter_inputs(
         self,
         input_ids: torch.IntTensor,
@@ -145,11 +145,11 @@ def get_num_spec_layers(spec_config):
     return 0
 
 
-def get_spec_worker(spec_config, mapping):
+def get_spec_worker(spec_config, model_config, mapping):
     if spec_config.spec_dec_mode.is_mtp():
-        return MTPWorker(spec_config)
+        return MTPWorker(spec_config, model_config)
     if spec_config.spec_dec_mode.is_mtp_eagle():
-        return MTPEagleWorker(spec_config)
+        return MTPEagleWorker(spec_config, model_config)
     if spec_config.spec_dec_mode.is_eagle3_one_model():
         return Eagle3OneModelWorker(spec_config, mapping)
     return None