diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py
index 943c2e2ed2..f6d1e504e4 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py
@@ -73,8 +73,10 @@ class Qwen3MoE(nn.Module):
                 hidden_states,
                 (0, 0, 0, max_num_token - hidden_states.shape[0]))
         router_logits = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states, router_logits,
-                                           all_rank_num_tokens)
+        final_hidden_states = self.experts(
+            hidden_states,
+            router_logits,
+            all_rank_num_tokens=all_rank_num_tokens)
 
         return final_hidden_states.view(orig_shape)
 
@@ -183,14 +185,23 @@ class Qwen3MoEModel(DecoderModel):
         self.padding_idx = config.pretrained_config.pad_token_id
         self.aux_stream = torch.cuda.Stream()
 
-        self.embed_tokens = Embedding(
-            config.pretrained_config.vocab_size,
-            config.pretrained_config.hidden_size,
-            dtype=config.pretrained_config.torch_dtype,
-            mapping=config.mapping,
-            tensor_parallel_mode=TensorParallelMode.COLUMN,
-            gather_output=True,
-        )
+        if model_config.mapping.enable_attention_dp:
+            # When attention_dp is enabled, we cannot do all_reduce since
+            # the problem sizes of different ranks are different.
+            # So, we don't do parallelism here.
+            self.embed_tokens = nn.Embedding(
+                config.pretrained_config.vocab_size,
+                config.pretrained_config.hidden_size,
+                dtype=config.pretrained_config.torch_dtype)
+        else:
+            self.embed_tokens = Embedding(
+                config.pretrained_config.vocab_size,
+                config.pretrained_config.hidden_size,
+                dtype=config.pretrained_config.torch_dtype,
+                mapping=config.mapping,
+                tensor_parallel_mode=TensorParallelMode.COLUMN,
+                gather_output=True,
+            )
         self.layers = nn.ModuleList([
             Qwen3MoEDecoderLayer(
                 model_config,
@@ -250,6 +261,8 @@ class Qwen3MoeForCausalLM(DecoderModelForCausalLM[Qwen3MoEModel,
 
     def load_weights(self, weights: Dict):
         tp_size = self.model_config.mapping.tp_size
+        enable_attention_dp = self.model_config.mapping.enable_attention_dp
+
         head_dim = getattr(
             self.config, "head_dim",
             self.config.hidden_size // self.config.num_attention_heads)
@@ -288,7 +301,8 @@ class Qwen3MoeForCausalLM(DecoderModelForCausalLM[Qwen3MoEModel,
                     k: (duplicate_kv_weight(
                         weight=v[:],
                         head_dim=head_dim,
-                        tensor_parallel_size=tp_size)
+                        tensor_parallel_size=tp_size
+                        if not enable_attention_dp else 1)
                      if k in tensors_need_duplication else v)
                     for k, v in fw.items()
                 }
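
For orientation, the embedding selection introduced in Qwen3MoEModel above can be sketched in isolation. The following is a minimal, self-contained illustration rather than the TensorRT-LLM implementation: SimpleConfig and build_embed_tokens are hypothetical names invented for this sketch, and a plain nn.Embedding stands in for TensorRT-LLM's column-parallel Embedding in the non-attention-DP branch.

# Minimal sketch of the embed_tokens selection; not the TensorRT-LLM code.
import torch
import torch.nn as nn
from dataclasses import dataclass


@dataclass
class SimpleConfig:
    # Hypothetical stand-in for config.pretrained_config / config.mapping.
    vocab_size: int = 151936
    hidden_size: int = 2048
    torch_dtype: torch.dtype = torch.bfloat16
    enable_attention_dp: bool = True


def build_embed_tokens(cfg: SimpleConfig) -> nn.Module:
    if cfg.enable_attention_dp:
        # With attention DP, each rank processes a different number of tokens,
        # so the all_reduce a tensor-parallel embedding relies on cannot be
        # issued; every rank keeps a full, replicated embedding table instead.
        return nn.Embedding(cfg.vocab_size, cfg.hidden_size,
                            dtype=cfg.torch_dtype)
    # Without attention DP, a column-parallel embedding with gathered output
    # would be built here; a replicated nn.Embedding stands in for it so the
    # sketch stays runnable on a single device.
    return nn.Embedding(cfg.vocab_size, cfg.hidden_size, dtype=cfg.torch_dtype)


if __name__ == "__main__":
    embed = build_embed_tokens(SimpleConfig())
    token_ids = torch.randint(0, 151936, (2, 8))
    print(embed(token_ids).shape)  # torch.Size([2, 8, 2048])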