[None][feat] Support EPLB in Qwen3 MoE (#7443)

Signed-off-by: Gabriel Wu <13583761+lucifer1004@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Gabriel Wu 2025-09-19 16:45:35 +08:00, committed by GitHub
parent 0ac51487f4
commit 0e72e8f7e6

@@ -78,7 +78,7 @@ class Qwen3MoE(nn.Module):
     def __init__(
         self,
         model_config: ModelConfig[Qwen3MoeConfig],
-        aux_stream: torch.cuda.Stream,
+        aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
         layer_idx: Optional[int] = None,
     ):
         super().__init__()
@@ -108,7 +108,7 @@ class Qwen3MoE(nn.Module):
             routing_method=self.gate.routing_method,
             hidden_size=self.hidden_dim,
             intermediate_size=self.moe_intermediate_size,
-            aux_stream_dict={AuxStreamType.MoeChunkingOverlap: aux_stream},
+            aux_stream_dict=aux_stream_dict,
             dtype=config.torch_dtype,
             reduce_results=False,
             model_config=model_config,
@@ -160,7 +160,8 @@ class Qwen3MoE(nn.Module):
 class Qwen3MoEDecoderLayer(DecoderLayer):
 
     def __init__(self, model_config: ModelConfig[Qwen3MoeConfig],
-                 layer_idx: int, aux_stream: torch.cuda.Stream):
+                 layer_idx: int, aux_stream_dict: Dict[AuxStreamType,
+                                                       torch.cuda.Stream]):
         super().__init__()
         self.model_config = model_config
         config = model_config.pretrained_config
@@ -171,7 +172,7 @@ class Qwen3MoEDecoderLayer(DecoderLayer):
         self.mapping = model_config.mapping
         self.enable_attention_dp = self.mapping.enable_attention_dp
 
-        self.mlp = Qwen3MoE(model_config, aux_stream, layer_idx=layer_idx)
+        self.mlp = Qwen3MoE(model_config, aux_stream_dict, layer_idx=layer_idx)
 
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
@@ -302,7 +303,10 @@ class Qwen3MoEModel(DecoderModel):
     def __init__(self, model_config: ModelConfig[Qwen3MoeConfig]):
         super().__init__(model_config)
         config = self.model_config
-        self.aux_stream = torch.cuda.Stream()
+        self.aux_stream_dict = {
+            AuxStreamType.MoeChunkingOverlap: torch.cuda.Stream(),
+            AuxStreamType.MoeBalancer: torch.cuda.Stream(),
+        }
         self.preload_weight_modules = []
         if config.moe_backend == "TRTLLM":
             self.preload_weight_modules = [
@@ -332,7 +336,7 @@ class Qwen3MoEModel(DecoderModel):
             Qwen3MoEDecoderLayer(
                 model_config,
                 layer_idx,
-                self.aux_stream,
+                self.aux_stream_dict,
             ) for layer_idx in range(config.pretrained_config.num_hidden_layers)
         ])
         self.norm = RMSNorm(
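
For orientation, here is a minimal standalone sketch of the pattern this diff adopts: the model owns a dict of purpose-keyed auxiliary CUDA streams instead of a single aux_stream, so MoE chunking overlap and the EPLB (expert-parallel load balancer) can each run on a dedicated side stream. The AuxStreamType members mirror the diff, but the enum definition, helper names, and balancer callback below are illustrative stand-ins, not TensorRT-LLM's actual API.

# Minimal sketch (assumed names; not the real TensorRT-LLM modules) of the
# per-purpose auxiliary stream dict introduced by this commit.
import enum
from typing import Callable, Dict

import torch


class AuxStreamType(enum.Enum):
    # Stand-in for the enum the diff references; the real one lives in
    # TensorRT-LLM and may carry more members.
    MoeChunkingOverlap = "moe_chunking_overlap"
    MoeBalancer = "moe_balancer"


def make_aux_stream_dict() -> Dict[AuxStreamType, torch.cuda.Stream]:
    # One dedicated side stream per auxiliary task, created once at model
    # construction (mirroring Qwen3MoEModel.__init__) and shared by every
    # decoder layer instead of a single aux_stream.
    return {
        AuxStreamType.MoeChunkingOverlap: torch.cuda.Stream(),
        AuxStreamType.MoeBalancer: torch.cuda.Stream(),
    }


def run_on_balancer_stream(
        aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
        balancer_step: Callable[[], None]) -> None:
    # Hypothetical helper: enqueue an EPLB statistics/rebalance step on its
    # own stream so it can overlap with MoE compute on the default stream.
    with torch.cuda.stream(aux_stream_dict[AuxStreamType.MoeBalancer]):
        balancer_step()


if __name__ == "__main__":
    if torch.cuda.is_available():
        streams = make_aux_stream_dict()
        run_on_balancer_stream(
            streams, lambda: torch.randn(1024, device="cuda").sum())
        torch.cuda.synchronize()

Keying streams by purpose rather than threading a single aux_stream through every constructor lets a new auxiliary task (here, the EPLB balancer) be added by extending the dict, which is why the diff widens the signatures from aux_stream to aux_stream_dict.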