[TRTLLM-5147][Qwen3] fix: fix bug of attention dp on qwen3_moe model (#4141)

* fix bug of attention dp on qwen3

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>

* fix pre-commit changes

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>

* fix bug of attention dp 8

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>

---------

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
Author: bhsueh_NV
Date: 2025-05-09 09:29:39 +08:00
Committed by: GitHub
Parent: 836c142e1b
Commit: 700d09ab65

@@ -73,8 +73,10 @@ class Qwen3MoE(nn.Module):
                 hidden_states,
                 (0, 0, 0, max_num_token - hidden_states.shape[0]))
         router_logits = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states, router_logits,
-                                           all_rank_num_tokens)
+        final_hidden_states = self.experts(
+            hidden_states,
+            router_logits,
+            all_rank_num_tokens=all_rank_num_tokens)
         return final_hidden_states.view(orig_shape)
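
The switch to a keyword argument guards against a positional mismatch: if the experts module's forward accepts other optional parameters between router_logits and all_rank_num_tokens, a positional call binds the per-rank token counts to the wrong parameter and the attention-DP path silently receives None. A minimal sketch of that failure mode (ToyExperts and its do_finalize parameter are hypothetical, not the real fused-MoE signature):

from typing import List, Optional

import torch


class ToyExperts(torch.nn.Module):
    """Hypothetical MoE wrapper whose forward has extra optional parameters."""

    def forward(self,
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
                do_finalize: bool = True,
                all_rank_num_tokens: Optional[List[int]] = None):
        # The attention-DP path needs the per-rank token counts for padding.
        if all_rank_num_tokens is None:
            print("attention-DP token counts were not received")
        return hidden_states


experts = ToyExperts()
x = torch.randn(4, 8)
logits = torch.randn(4, 2)
tokens_per_rank = [4, 6, 3, 5]

# Positional call: the list silently binds to do_finalize and
# all_rank_num_tokens stays None, so the DP branch is skipped.
experts(x, logits, tokens_per_rank)

# Keyword call, as in the fix: always reaches the intended parameter.
experts(x, logits, all_rank_num_tokens=tokens_per_rank)
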
@@ -183,14 +185,23 @@ class Qwen3MoEModel(DecoderModel):
         self.padding_idx = config.pretrained_config.pad_token_id
         self.aux_stream = torch.cuda.Stream()
-        self.embed_tokens = Embedding(
-            config.pretrained_config.vocab_size,
-            config.pretrained_config.hidden_size,
-            dtype=config.pretrained_config.torch_dtype,
-            mapping=config.mapping,
-            tensor_parallel_mode=TensorParallelMode.COLUMN,
-            gather_output=True,
-        )
+        if model_config.mapping.enable_attention_dp:
+            # When attention_dp is enabled, we cannot do all_reduce since
+            # the problem size of different ranks are different.
+            # So, we don't do parallelism here.
+            self.embed_tokens = nn.Embedding(
+                config.pretrained_config.vocab_size,
+                config.pretrained_config.hidden_size,
+                dtype=config.pretrained_config.torch_dtype)
+        else:
+            self.embed_tokens = Embedding(
+                config.pretrained_config.vocab_size,
+                config.pretrained_config.hidden_size,
+                dtype=config.pretrained_config.torch_dtype,
+                mapping=config.mapping,
+                tensor_parallel_mode=TensorParallelMode.COLUMN,
+                gather_output=True,
+            )
         self.layers = nn.ModuleList([
             Qwen3MoEDecoderLayer(
                 model_config,
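
The comment in the new branch carries the reasoning of the fix: the tensor-parallel Embedding finishes with a cross-rank collective (all_reduce/all_gather), and a collective only works when every rank contributes a tensor of the same shape. Under attention DP each rank serves its own batch, so per-rank token counts differ; a plain nn.Embedding keeps the lookup purely local. A rough sketch of the shape problem with made-up token counts (the numbers are illustrative, not from the commit):

import torch

# Hypothetical per-rank token counts under attention DP: each rank builds its
# own batch, so the leading dimension differs from rank to rank.
all_rank_num_tokens = [13, 7, 21, 9]
hidden_size = 64

per_rank_embeddings = [torch.randn(n, hidden_size) for n in all_rank_num_tokens]

# A tensor-parallel embedding would have to all_reduce/all_gather these
# outputs, but collectives require identical shapes on every rank.
shapes = {tuple(t.shape) for t in per_rank_embeddings}
assert len(shapes) > 1  # several different shapes -> cannot be combined as-is

# The MoE path above handles the same issue by padding every rank to
# max_num_token before calling the experts; the embedding instead stays local.
max_num_token = max(all_rank_num_tokens)
padded = [
    torch.nn.functional.pad(t, (0, 0, 0, max_num_token - t.shape[0]))
    for t in per_rank_embeddings
]
assert all(t.shape == (max_num_token, hidden_size) for t in padded)
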
@@ -250,6 +261,8 @@ class Qwen3MoeForCausalLM(DecoderModelForCausalLM[Qwen3MoEModel,
     def load_weights(self, weights: Dict):
         tp_size = self.model_config.mapping.tp_size
+        enable_attention_dp = self.model_config.mapping.enable_attention_dp
         head_dim = getattr(
             self.config, "head_dim",
             self.config.hidden_size // self.config.num_attention_heads)
@@ -288,7 +301,8 @@ class Qwen3MoeForCausalLM(DecoderModelForCausalLM[Qwen3MoEModel,
                 k: (duplicate_kv_weight(
                     weight=v[:],
                     head_dim=head_dim,
-                    tensor_parallel_size=tp_size)
+                    tensor_parallel_size=tp_size
+                    if not enable_attention_dp else 1)
                     if k in tensors_need_duplication else v)
                 for k, v in fw.items()
             }
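
The ternary binds as tensor_parallel_size=(tp_size if not enable_attention_dp else 1): under attention DP the attention weights stay unsharded on every rank, so KV heads never need to be replicated to fill TP shards. A toy stand-in for that decision (toy_duplicate_kv_weight is hypothetical; the real duplicate_kv_weight helper may behave differently):

import torch


def toy_duplicate_kv_weight(weight: torch.Tensor, head_dim: int,
                            tensor_parallel_size: int) -> torch.Tensor:
    """Hypothetical: repeat KV heads so each TP shard gets at least one."""
    num_kv_heads = weight.shape[0] // head_dim
    if num_kv_heads >= tensor_parallel_size:
        return weight  # enough heads already, nothing to duplicate
    reps = tensor_parallel_size // num_kv_heads
    return (weight.reshape(num_kv_heads, head_dim, -1)
            .repeat_interleave(reps, dim=0)
            .reshape(num_kv_heads * reps * head_dim, -1))


tp_size, enable_attention_dp = 8, True
kv_weight = torch.randn(4 * 128, 2048)  # 4 KV heads, head_dim = 128

# Same conditional as in the diff: attention DP keeps weights unsharded,
# so duplication is computed as if the TP size were 1.
out = toy_duplicate_kv_weight(
    weight=kv_weight,
    head_dim=128,
    tensor_parallel_size=tp_size if not enable_attention_dp else 1)
assert out.shape == kv_weight.shape  # no duplication when attention DP is on
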