feat: Optionally split MoE inputs into chunks to reduce GPU memory usage (#3104)

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Co-authored-by: raccoonliukai <raccoonliu@tencent.com>
Jinyang Yuan 2025-04-01 16:07:02 +08:00 committed by GitHub
parent 727d78e785
commit 992d513bc6
7 changed files with 180 additions and 89 deletions

View File

@ -23,6 +23,8 @@ class ModelConfig(Generic[TConfig]):
quant_config_dict: Optional[Dict[str, QuantConfig]] = None
skip_create_weights: bool = False
is_generation: bool = True
max_num_tokens: int = 8192
moe_max_num_tokens: Optional[int] = None
attn_backend: str = 'TRTLLM'
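These two new fields drive the chunking logic in FusedMoE further down: a moe_max_num_tokens left at None falls back to max_num_tokens, scaled by the world size when attention DP gathers tokens from every rank before the MoE. A minimal sketch of that fallback, using a hypothetical helper name and the nesting read from the flattened FusedMoE hunk below:

from typing import Optional

def resolve_moe_max_num_tokens(max_num_tokens: int,
                               moe_max_num_tokens: Optional[int],
                               use_dp: bool,
                               world_size: int) -> int:
    # Hypothetical helper mirroring the fallback in FusedMoE.__init__ (see the fused-MoE hunks below).
    if moe_max_num_tokens is not None:
        return moe_max_num_tokens
    budget = max_num_tokens
    if use_dp:
        # Under attention DP the MoE sees tokens gathered from all ranks,
        # so the default budget is scaled accordingly.
        budget *= world_size
    return budget

assert resolve_moe_max_num_tokens(8192, None, use_dp=False, world_size=1) == 8192
assert resolve_moe_max_num_tokens(8192, None, use_dp=True, world_size=4) == 32768
assert resolve_moe_max_num_tokens(8192, 2048, use_dp=True, world_size=4) == 2048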

View File

@ -15,8 +15,7 @@ from tensorrt_llm.llmapi.utils import enable_llm_debug
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
DeepseekAllReduce, ParallelConfig, allgather,
reducescatter)
DeepseekAllReduce, ParallelConfig, allgather)
from ..model_config import ModelConfig
from ..models.modeling_utils import MissingLayer, ModelConfig, support_pp
from ..modules.attention import MLA
@ -29,7 +28,8 @@ from ..modules.rotary_embedding import RotaryEmbedding
from ..pipeline_interface import PipelineInterface
from ..pyexecutor.cuda_graph_runner import is_graph_capturing
from ..speculative import MTPEagleWorker, MTPSpecMetadata, MTPWorker
from ..utils import Fp4QuantizedTensor, disable_fp4_allgather
from ..utils import (AuxStreamType, EventType, Fp4QuantizedTensor,
disable_fp4_allgather)
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
EagerFusionConfig, register_auto_model)
@ -258,9 +258,8 @@ class Deepseekv3MoE(nn.Module):
hidden_size: int,
intermediate_size: int,
shared_expert_intermediate_size: int,
aux_stream: torch.cuda.Stream,
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
dtype: Optional[torch.dtype] = None,
tune_max_num_tokens: int = 8192,
model_config: ModelConfig = ModelConfig()):
from ..distributed import AllReduce
@ -284,8 +283,8 @@ class Deepseekv3MoE(nn.Module):
dtype=dtype,
reduce_results=
False,  # In both the low-latency and attention-DP scenarios, FusedMoE does not need to do the allreduce inside the op.
tune_max_num_tokens=tune_max_num_tokens,
model_config=model_config)
model_config=model_config,
aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap])
self.shared_output_scale = None
if self.use_dp:
@ -314,18 +313,11 @@ class Deepseekv3MoE(nn.Module):
pipeline_parallel_size=model_config.mapping.pp_size,
parallel_rank=model_config.mapping.rank)
self.all_reduce = AllReduce(self.parallel_config)
self.aux_stream = aux_stream
self.moe_event = [torch.cuda.Event(), torch.cuda.Event()]
def reduce_scatter(self, input_tensor, all_rank_num_tokens):
world_size = self.parallel_config.tensor_parallel_size
rank = self.parallel_config.tensor_parallel_rank
if world_size == 1:
return input_tensor
dst_tensor = input_tensor
outputs = reducescatter(dst_tensor, self.parallel_config, scatter_dim=0)
depad_tensors = outputs[:all_rank_num_tokens[rank]]
return depad_tensors
self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
self.event_dict = {
key: torch.cuda.Event()
for key in [EventType.Main, EventType.MoeShared]
}
def compute_routed_output(self, hidden_states, all_rank_num_tokens):
if self.use_dp and self.parallel_config.tensor_parallel_size > 1:
@ -338,10 +330,8 @@ class Deepseekv3MoE(nn.Module):
self.parallel_config,
gather_dim=0)
router_logits = self.gate(hidden_states)
routed_output = self.experts(hidden_states, router_logits)
if self.use_dp:
routed_output = self.reduce_scatter(routed_output,
all_rank_num_tokens)
routed_output = self.experts(hidden_states, router_logits,
all_rank_num_tokens)
return routed_output
def forward(
@ -354,17 +344,17 @@ class Deepseekv3MoE(nn.Module):
# This design mainly targets the low-latency use case; it needs improvement for the max-throughput use case.
do_multi_stream = is_graph_capturing()
if do_multi_stream:
self.moe_event[0].record()
self.event_dict[EventType.Main].record()
shared_output = self.shared_experts(hidden_states)
if self.shared_output_scale is not None:
shared_output *= self.shared_output_scale
if do_multi_stream:
with torch.cuda.stream(self.aux_stream):
self.moe_event[0].wait()
self.event_dict[EventType.Main].wait()
routed_output = self.compute_routed_output(
hidden_states, all_rank_num_tokens)
self.moe_event[1].record()
self.moe_event[1].wait()
self.event_dict[EventType.MoeShared].record()
self.event_dict[EventType.MoeShared].wait()
else:
routed_output = self.compute_routed_output(hidden_states,
all_rank_num_tokens)
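The record/wait pairs above (taken only while CUDA graphs are being captured) implement the usual two-stream overlap idiom: the main stream publishes hidden_states via an event, the auxiliary stream waits on it before running the routed experts, and the main stream waits on the auxiliary stream's completion event before combining the two outputs. A standalone sketch of the pattern with stand-in work functions rather than the real expert kernels:

import torch

def overlapped_moe(x: torch.Tensor, shared_fn, routed_fn,
                   aux_stream: torch.cuda.Stream) -> torch.Tensor:
    # Events are used purely for cross-stream ordering, like self.event_dict above.
    main_ready = torch.cuda.Event()
    aux_done = torch.cuda.Event()

    main_ready.record()                    # x is ready on the main (current) stream
    shared = shared_fn(x)                  # shared experts stay on the main stream
    with torch.cuda.stream(aux_stream):
        main_ready.wait()                  # aux stream must not read x too early
        routed = routed_fn(x)              # routed experts run concurrently with shared_fn
        aux_done.record()
    aux_done.wait()                        # main stream waits before consuming `routed`
    return shared + routed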
@ -381,7 +371,8 @@ class Deepseekv3MoE(nn.Module):
class DeepseekV3DecoderLayer(DecoderLayer):
def __init__(self, model_config: ModelConfig[PretrainedConfig],
layer_idx: int, aux_stream: torch.cuda.Stream):
layer_idx: int, aux_stream_dict: Dict[AuxStreamType,
torch.cuda.Stream]):
super().__init__()
config = model_config.pretrained_config
self.hidden_size = config.hidden_size
@ -390,9 +381,10 @@ class DeepseekV3DecoderLayer(DecoderLayer):
self.num_shared_experts = config.n_shared_experts
self.top_k = config.num_experts_per_tok
self.self_attn = DeepseekV3Attention(model_config,
layer_idx=layer_idx,
aux_stream=aux_stream)
self.self_attn = DeepseekV3Attention(
model_config,
layer_idx=layer_idx,
aux_stream=aux_stream_dict[AuxStreamType.Attention])
self.fusion_config = EagerFusionConfig()
self.enable_attention_dp = model_config.mapping.enable_attention_dp
self.mlp_tp_size = model_config.mapping.tp_size
@ -420,9 +412,8 @@ class DeepseekV3DecoderLayer(DecoderLayer):
shared_expert_intermediate_size=self.moe_intermediate_size *
self.num_shared_experts,
dtype=config.torch_dtype,
tune_max_num_tokens=config.max_position_embeddings // 4,
model_config=model_config,
aux_stream=aux_stream)
aux_stream_dict=aux_stream_dict)
else:
if self.enable_attention_dp:
self.mlp_tp_size = 1
@ -606,8 +597,9 @@ class DeepseekV3DecoderLayer(DecoderLayer):
class DeepseekV3MTP(DeepseekV3DecoderLayer):
def __init__(self, model_config: ModelConfig[PretrainedConfig],
layer_idx: int, aux_stream: torch.cuda.Stream):
super().__init__(model_config, layer_idx, aux_stream)
layer_idx: int, aux_stream_dict: Dict[AuxStreamType,
torch.cuda.Stream]):
super().__init__(model_config, layer_idx, aux_stream_dict)
config = model_config.pretrained_config
self.hidden_dim = config.hidden_size
self.moe_intermediate_size = config.moe_intermediate_size
@ -735,14 +727,21 @@ class DeepseekV3Model(DecoderModel):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers
self.aux_stream = torch.cuda.Stream()
self.aux_stream_dict = {
key: torch.cuda.Stream()
for key in [
AuxStreamType.Attention, AuxStreamType.MoeShared,
AuxStreamType.MoeChunkingOverlap
]
}
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
dtype=config.torch_dtype)
self.layers = nn.ModuleList([
DeepseekV3DecoderLayer(model_config, layer_idx, self.aux_stream)
DeepseekV3DecoderLayer(model_config, layer_idx,
self.aux_stream_dict)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(hidden_size=config.hidden_size,
@ -842,7 +841,7 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
assert ckpt_nextn > 0, "There are no MTP modules in the checkpoint."
if ckpt_nextn == 1:
mtp_layer = DeepseekV3MTP(model_config, self.num_hidden_layers,
self.model.aux_stream)
self.model.aux_stream_dict)
self.model.layers.append(mtp_layer)
self.mtp_worker = MTPEagleWorker(model_config.spec_config)
else:
@ -851,7 +850,7 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
mtp_layers = nn.ModuleList([
DeepseekV3MTP(model_config,
layer_idx + self.num_hidden_layers,
self.model.aux_stream)
self.model.aux_stream_dict)
for layer_idx in range(model_nextn)
])
self.model.layers.extend(mtp_layers)

View File

@ -8,7 +8,7 @@ from tensorrt_llm.functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..distributed import ParallelConfig, allgather
from ..distributed import ParallelConfig
from ..model_config import ModelConfig
from ..models.modeling_utils import ModelConfig
from ..modules.attention import Attention
@ -23,13 +23,18 @@ from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
class MixtralMoE(nn.Module):
def __init__(self, model_config: ModelConfig[PretrainedConfig]):
def __init__(
self,
model_config: ModelConfig[PretrainedConfig],
aux_stream: torch.cuda.Stream,
):
super().__init__()
config = model_config.pretrained_config
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
self.enable_attention_dp = model_config.mapping.enable_attention_dp
# moe gate (linear layer) only runs in half/full precision for now
self.gate = Linear(self.hidden_dim,
@ -45,19 +50,26 @@ class MixtralMoE(nn.Module):
routing_method=RenormalizeMoeRoutingMethod(top_k=self.top_k),
hidden_size=self.hidden_dim,
intermediate_size=self.ffn_dim,
aux_stream=aux_stream,
dtype=config.torch_dtype,
reduce_results=reduce_results,
tune_max_num_tokens=config.max_position_embeddings // 4,
model_config=model_config)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
assert hidden_states.shape[-1] == self.hidden_dim
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
all_rank_num_tokens = attn_metadata.all_rank_num_tokens
if self.enable_attention_dp and len(all_rank_num_tokens) > 1:
max_num_token = max(all_rank_num_tokens)
hidden_states = torch.nn.functional.pad(
hidden_states,
(0, 0, 0, max_num_token - hidden_states.shape[0]))
router_logits = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, router_logits)
return final_hidden_states.view(orig_shape), router_logits
final_hidden_states = self.experts(hidden_states, router_logits,
all_rank_num_tokens)
return final_hidden_states
class MixtralRotaryEmbedding(RotaryEmbedding):
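In MixtralMoE.forward above, attention DP now pads each rank's tokens up to the largest per-rank count before the gate; the matching slice back to the local length happens at the end of FusedMoE.forward (last hunk of the fused-MoE file below). A small sketch of that pad/slice round trip with made-up token counts:

import torch
import torch.nn.functional as F

all_rank_num_tokens = [5, 7, 6]          # hypothetical per-rank token counts
rank = 0
hidden = torch.randn(all_rank_num_tokens[rank], 16)

# Pad the token dimension so every rank contributes the same number of rows to the collectives.
max_num_token = max(all_rank_num_tokens)
padded = F.pad(hidden, (0, 0, 0, max_num_token - hidden.shape[0]))
assert padded.shape == (max_num_token, 16)

# ... allgather, experts, and reduce-scatter operate on the padded tensor ...

# Afterwards only this rank's real tokens are kept.
depadded = padded[:all_rank_num_tokens[rank]]
assert torch.equal(depadded, hidden)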
@ -105,14 +117,14 @@ class MixtralAttention(Attention):
class MixtralDecoderLayer(DecoderLayer):
def __init__(self, model_config: ModelConfig[PretrainedConfig],
layer_idx: int):
layer_idx: int, aux_stream: torch.cuda.Stream):
super().__init__()
config = model_config.pretrained_config
self.hidden_size = config.hidden_size
self.self_attn = MixtralAttention(model_config, layer_idx=layer_idx)
self.block_sparse_moe = MixtralMoE(model_config)
self.block_sparse_moe = MixtralMoE(model_config, aux_stream)
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
@ -121,7 +133,6 @@ class MixtralDecoderLayer(DecoderLayer):
self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.enable_attention_dp = model_config.mapping.enable_attention_dp
# TODO: add pipeline parallel config
self.parallel_config = ParallelConfig(
tensor_parallel_rank=model_config.mapping.tp_rank,
@ -129,26 +140,6 @@ class MixtralDecoderLayer(DecoderLayer):
gpus_per_node=model_config.mapping.gpus_per_node)
self.layer_idx = layer_idx
def all_gather(self, input_tensor, attn_metadata):
rank = self.parallel_config.tensor_parallel_rank
world_size = self.parallel_config.tensor_parallel_size
all_rank_num_tokens = attn_metadata.all_rank_num_tokens
max_num_token = max(all_rank_num_tokens)
if world_size == 1:
return input_tensor, 0, max_num_token
pad_tensor = torch.nn.functional.pad(
input_tensor, (0, 0, 0, max_num_token - input_tensor.shape[0]))
outputs = allgather(pad_tensor, self.parallel_config, gather_dim=0)
depad_tensors = torch.concat([
outputs[i * max_num_token:i * max_num_token +
all_rank_num_tokens[i]] for i in range(world_size)
])
cur_rank_start = 0 if rank == 0 else sum(all_rank_num_tokens[:rank])
cur_rank_end = cur_rank_start + all_rank_num_tokens[rank]
return depad_tensors, cur_rank_start, cur_rank_end
def forward(
self,
position_ids: torch.LongTensor,
@ -175,12 +166,7 @@ class MixtralDecoderLayer(DecoderLayer):
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if self.enable_attention_dp:
hidden_states, cur_rank_start, cur_rank_end = self.all_gather(
hidden_states, attn_metadata)
hidden_states, _router_logits = self.block_sparse_moe(hidden_states)
if self.enable_attention_dp:
hidden_states = hidden_states[cur_rank_start:cur_rank_end]
hidden_states = self.block_sparse_moe(hidden_states, attn_metadata)
return hidden_states, residual
@ -191,6 +177,7 @@ class MixtralModel(DecoderModel):
config = model_config.pretrained_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.aux_stream = torch.cuda.Stream()
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@ -198,7 +185,7 @@ class MixtralModel(DecoderModel):
dtype=config.torch_dtype)
self.layers = nn.ModuleList([
MixtralDecoderLayer(model_config, layer_idx)
MixtralDecoderLayer(model_config, layer_idx, self.aux_stream)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(hidden_size=config.hidden_size,

View File

@ -5,9 +5,10 @@ import torch
from torch import nn
from ...quantization.utils.fp4_utils import float4_sf_dtype
from ..distributed import allgather
from ..distributed import allgather, reducescatter
from ..model_config import ModelConfig
from ..utils import disable_fp4_allgather, is_torch_compiling, reswizzle_sf
from ..utils import (EventType, disable_fp4_allgather, is_torch_compiling,
reswizzle_sf)
from .linear import ParallelConfig, TensorParallelMode, load_weight_shard
# The declarations aligns with moe_kernels.h
@ -198,9 +199,9 @@ class FusedMoE(nn.Module):
top_k (int): Number of top experts to select for each input token.
hidden_size (int): Size of the hidden state.
intermediate_size (int): Size of the intermediate state.
aux_stream (torch.cuda.Stream): Auxiliary CUDA stream to overlap chunks.
dtype (Optional[torch.dtype]): Data type for the weights.
reduce_results (bool): Whether to reduce the results across devices.
tune_max_num_tokens (int): Maximum number of tokens for performance tuning.
model_config (ModelConfig): Configuration object for the model.
"""
@ -213,8 +214,8 @@ class FusedMoE(nn.Module):
intermediate_size: int,
dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
tune_max_num_tokens: int = 8192,
model_config: ModelConfig = ModelConfig(),
aux_stream: torch.cuda.Stream = torch.cuda.Stream(),
):
from ..distributed import AllReduce
@ -224,6 +225,12 @@ class FusedMoE(nn.Module):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.aux_stream = aux_stream
self.event_dict = {
key: torch.cuda.Event()
for key in [EventType.Main, EventType.MoeChunkingOverlap]
}
self.dtype = dtype
self.reduce_results = reduce_results
# could be modified later
@ -256,7 +263,18 @@ class FusedMoE(nn.Module):
self.expert_start + self.expert_size_per_partition,
self.num_experts)
self.tune_max_num_tokens = tune_max_num_tokens
self.moe_max_num_tokens = model_config.moe_max_num_tokens
if self.moe_max_num_tokens is None:
self.moe_max_num_tokens = model_config.max_num_tokens
if self.use_dp:
self.moe_max_num_tokens *= model_config.mapping.world_size
# The profiler converges on the same best tactic when the number of tokens is large enough.
# To avoid long profiling time, the max number of tokens used in the profiling is capped to
# around 16k tokens per expert, which is well into the compute bound domain.
self.tune_max_num_tokens = min(
self.moe_max_num_tokens,
16384 * num_experts / routing_method.get_experts_per_token(),
)
self.has_been_profiled = False
self._weights_created = False
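The cap above keeps autotuning cheap: once each expert already sees roughly 16k tokens during profiling, larger token counts select the same tactic, so there is no point profiling beyond that. For example, with a hypothetical 8-expert, top-2 configuration the cap is 16384 * 8 / 2 = 65536 tokens, and a smaller moe_max_num_tokens simply wins the min. The same arithmetic as a sketch:

def capped_tune_tokens(moe_max_num_tokens: int,
                       num_experts: int,
                       experts_per_token: int) -> int:
    # Profiling is capped at ~16k tokens per expert; beyond that the best tactic stops changing.
    return int(min(moe_max_num_tokens, 16384 * num_experts / experts_per_token))

assert capped_tune_tokens(131072, num_experts=8, experts_per_token=2) == 65536
assert capped_tune_tokens(16384, num_experts=8, experts_per_token=2) == 16384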
@ -484,7 +502,18 @@ class FusedMoE(nn.Module):
outputs.append(output)
return outputs
def forward(
def reducescatter_or_allreduce(self, inputs):
outputs = inputs
if self.parallel_size > 1:
if self.use_dp:
outputs = reducescatter(inputs,
self.parallel_config,
scatter_dim=0)
elif self.reduce_results:
outputs = self.all_reduce(inputs)
return outputs
def forward_chunk(
self,
x: torch.Tensor,
router_logits: torch.Tensor,
@ -570,11 +599,65 @@ class FusedMoE(nn.Module):
use_fp8_block_scaling=use_fp8_block_scaling,
)
if self.reduce_results and self.parallel_size > 1:
final_hidden_states = self.all_reduce(final_hidden_states)
return final_hidden_states
def forward(
self,
x: torch.Tensor,
router_logits: torch.Tensor,
all_rank_num_tokens: Optional[List[int]] = None,
) -> torch.Tensor:
max_chunk_size = self.moe_max_num_tokens
if self.use_dp:
assert all_rank_num_tokens is not None
if not disable_fp4_allgather():
max_chunk_size //= len(all_rank_num_tokens)
num_chunks = (x.shape[0] + max_chunk_size - 1) // max_chunk_size
if num_chunks == 1:
outputs = self.forward_chunk(x, router_logits)
outputs = self.reducescatter_or_allreduce(outputs)
else:
val_div = x.shape[0] // num_chunks
val_mod = x.shape[0] % num_chunks
chunk_size_list = [val_div + 1
] * val_mod + [val_div] * (num_chunks - val_mod)
x_list = x.split(chunk_size_list)
router_logits_list = router_logits.split(chunk_size_list)
outputs_list = []
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()
# Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap
for idx_chunk, (x, router_logits) in enumerate(
zip(x_list, router_logits_list)):
if idx_chunk % 2 == 0:
with torch.cuda.stream(self.aux_stream):
outputs = self.forward_chunk(x, router_logits)
if idx_chunk > 0:
outputs_list[-1] = self.reducescatter_or_allreduce(
outputs_list[-1])
else:
outputs = self.forward_chunk(x, router_logits)
with torch.cuda.stream(self.aux_stream):
outputs_list[-1] = self.reducescatter_or_allreduce(
outputs_list[-1])
outputs_list.append(outputs)
if num_chunks % 2 == 0:
outputs_list[-1] = self.reducescatter_or_allreduce(
outputs_list[-1])
else:
with torch.cuda.stream(self.aux_stream):
outputs_list[-1] = self.reducescatter_or_allreduce(
outputs_list[-1])
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.MoeChunkingOverlap].record()
self.event_dict[EventType.MoeChunkingOverlap].wait()
outputs = torch.cat(outputs_list)
if self.use_dp:
rank = self.parallel_config.tensor_parallel_rank
outputs = outputs[:all_rank_num_tokens[rank]]
return outputs
def load_weights(self, weights: List[Dict]):
assert self._weights_created
assert len(weights) == 1
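The new forward above is the core of this change: an input larger than moe_max_num_tokens (divided by the number of DP ranks when the pre-MoE allgather is enabled) is split into near-equal chunks, even-numbered chunks run on the auxiliary stream and odd-numbered ones on the main stream, and each chunk's reduce-scatter/all-reduce is deferred by one iteration so it overlaps with the next chunk's expert GEMMs. A CPU-only sketch of just the splitting arithmetic:

import torch

def split_into_chunks(x: torch.Tensor, max_chunk_size: int):
    # Same arithmetic as FusedMoE.forward: ceil-divide into num_chunks, then spread the
    # remainder over the first chunks so the sizes differ by at most one token.
    num_chunks = (x.shape[0] + max_chunk_size - 1) // max_chunk_size
    val_div, val_mod = divmod(x.shape[0], num_chunks)
    chunk_size_list = [val_div + 1] * val_mod + [val_div] * (num_chunks - val_mod)
    return x.split(chunk_size_list)

x = torch.randn(10000, 8)
chunks = split_into_chunks(x, max_chunk_size=4096)
assert [c.shape[0] for c in chunks] == [3334, 3333, 3333]
assert sum(c.shape[0] for c in chunks) == x.shape[0]

The ping-pong between self.aux_stream and the current stream, together with issuing reducescatter_or_allreduce for chunk i only after chunk i+1 has been launched, is what hides the communication behind compute; the final torch.cat and the slice to all_rank_num_tokens[rank] restore the unchunked, unpadded output.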

View File

@ -46,6 +46,10 @@ class PyTorchConfig:
# This is usually a net win for performance.
cuda_graph_padding_enabled: bool = False
enable_overlap_scheduler: bool = False
max_num_tokens: int = 8192
# If set, at most moe_max_num_tokens tokens will be sent to torch.ops.trtllm.fused_moe at the same time.
# If the number of tokens exceeds moe_max_num_tokens, the input tensors will be split into chunks and a for loop will be used.
moe_max_num_tokens: Optional[int] = None
attn_backend: str = 'TRTLLM'
# If true, will iterate over sampling_params of each request and use the
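From a user's perspective the feature is enabled purely through these two PyTorchConfig fields; a hedged sketch of how they might be set (the import path and surrounding setup are assumptions for illustration and are not part of this diff):

# NOTE: the import path below is an assumption; this diff only defines the new fields.
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

config = PyTorchConfig(
    max_num_tokens=8192,       # overall token budget per forward pass
    moe_max_num_tokens=2048,   # send at most 2048 tokens to fused_moe at a time
)
# Batches with more than 2048 tokens are then split into chunks inside FusedMoE.forward,
# trading some kernel-launch overhead for a lower activation-memory peak.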

View File

@ -257,6 +257,8 @@ class PyTorchModelEngine(ModelEngine):
mapping=self.mapping,
attn_backend=attn_backend,
load_format=pytorch_backend_config.load_format,
max_num_tokens=max_num_tokens,
moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens,
)
if self.pytorch_backend_config.enable_layerwise_nvtx_marker:
layerwise_nvtx_marker = LayerwiseNvtxMarker()
@ -732,11 +734,13 @@ class PyTorchModelEngine(ModelEngine):
torch.cuda.empty_cache()
def _load_model(self, checkpoint_dir: str, load_format: LoadFormat,
**kwargs):
max_num_tokens: int, moe_max_num_tokens: int, **kwargs):
config = ModelConfig.from_pretrained(checkpoint_dir,
trust_remote_code=True,
**kwargs)
config.spec_config = self.spec_config
config.max_num_tokens = max_num_tokens
config.moe_max_num_tokens = moe_max_num_tokens
validate_and_set_kv_cache_quant(
config, self.pytorch_backend_config.kv_cache_dtype)

View File

@ -1,5 +1,6 @@
import os
from dataclasses import dataclass
from enum import Enum
from typing import List
import torch
@ -10,6 +11,17 @@ from .pipeline_interface import PipelineInterface
is_torch_compiling_flag = False
aux_stream_name_list = ['Attention', 'MoeShared', 'MoeChunkingOverlap']
AuxStreamType = Enum(
'AuxStreamType',
aux_stream_name_list,
)
EventType = Enum(
'EventType',
['Main', *aux_stream_name_list],
start=0,
)
def set_torch_compiling(enable: bool):
global is_torch_compiling_flag
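Because EventType is built with the functional Enum API and start=0, its members number Main=0, Attention=1, MoeShared=2, MoeChunkingOverlap=3, while AuxStreamType keeps only the three auxiliary stream names (1-based by default). A quick check of that behavior:

from enum import Enum

aux_stream_name_list = ['Attention', 'MoeShared', 'MoeChunkingOverlap']
AuxStreamType = Enum('AuxStreamType', aux_stream_name_list)
EventType = Enum('EventType', ['Main', *aux_stream_name_list], start=0)

assert EventType.Main.value == 0                 # start=0 shifts the default numbering
assert EventType.MoeChunkingOverlap.value == 3
assert AuxStreamType.Attention.value == 1        # functional API is 1-based by default
assert [m.name for m in AuxStreamType] == aux_stream_name_list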