Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[None][feat] Support EPLB in Qwen3 MoE (#7443)
Signed-off-by: Gabriel Wu <13583761+lucifer1004@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
parent 0ac51487f4
commit 0e72e8f7e6
@@ -78,7 +78,7 @@ class Qwen3MoE(nn.Module):
     def __init__(
         self,
         model_config: ModelConfig[Qwen3MoeConfig],
-        aux_stream: torch.cuda.Stream,
+        aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
         layer_idx: Optional[int] = None,
     ):
         super().__init__()
@@ -108,7 +108,7 @@ class Qwen3MoE(nn.Module):
             routing_method=self.gate.routing_method,
             hidden_size=self.hidden_dim,
             intermediate_size=self.moe_intermediate_size,
-            aux_stream_dict={AuxStreamType.MoeChunkingOverlap: aux_stream},
+            aux_stream_dict=aux_stream_dict,
             dtype=config.torch_dtype,
             reduce_results=False,
             model_config=model_config,
@@ -160,7 +160,8 @@ class Qwen3MoE(nn.Module):
 class Qwen3MoEDecoderLayer(DecoderLayer):
 
     def __init__(self, model_config: ModelConfig[Qwen3MoeConfig],
-                 layer_idx: int, aux_stream: torch.cuda.Stream):
+                 layer_idx: int, aux_stream_dict: Dict[AuxStreamType,
+                                                       torch.cuda.Stream]):
         super().__init__()
         self.model_config = model_config
         config = model_config.pretrained_config
@@ -171,7 +172,7 @@ class Qwen3MoEDecoderLayer(DecoderLayer):
         self.mapping = model_config.mapping
         self.enable_attention_dp = self.mapping.enable_attention_dp
 
-        self.mlp = Qwen3MoE(model_config, aux_stream, layer_idx=layer_idx)
+        self.mlp = Qwen3MoE(model_config, aux_stream_dict, layer_idx=layer_idx)
 
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
@@ -302,7 +303,10 @@ class Qwen3MoEModel(DecoderModel):
     def __init__(self, model_config: ModelConfig[Qwen3MoeConfig]):
         super().__init__(model_config)
         config = self.model_config
-        self.aux_stream = torch.cuda.Stream()
+        self.aux_stream_dict = {
+            AuxStreamType.MoeChunkingOverlap: torch.cuda.Stream(),
+            AuxStreamType.MoeBalancer: torch.cuda.Stream(),
+        }
         self.preload_weight_modules = []
         if config.moe_backend == "TRTLLM":
             self.preload_weight_modules = [
@@ -332,7 +336,7 @@ class Qwen3MoEModel(DecoderModel):
             Qwen3MoEDecoderLayer(
                 model_config,
                 layer_idx,
-                self.aux_stream,
+                self.aux_stream_dict,
             ) for layer_idx in range(config.pretrained_config.num_hidden_layers)
         ])
         self.norm = RMSNorm(
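For context, auxiliary streams like these are usually consumed with the standard PyTorch side-stream pattern below. This is a generic illustration, not code from this commit; update_expert_placement is a hypothetical stand-in for the EPLB work.

# Generic PyTorch aux-stream pattern (illustration only, not from this diff).
balancer_stream = aux_stream_dict[AuxStreamType.MoeBalancer]
# Make the side stream wait for work already queued on the main stream.
balancer_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(balancer_stream):
    update_expert_placement()  # hypothetical EPLB step, runs off the main stream
# Re-join before anything on the main stream depends on the result.
torch.cuda.current_stream().wait_stream(balancer_stream)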