Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[TRTLLM-5147][Qwen3] fix: fix bug of attention dp on qwen3_moe model (#4141)
* fix bug of attention dp on qwen3
* fix pre-commit changes
* fix bug of attention dp 8

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
This commit is contained in:
parent 836c142e1b
commit 700d09ab65
@@ -73,8 +73,10 @@ class Qwen3MoE(nn.Module):
                 hidden_states,
                 (0, 0, 0, max_num_token - hidden_states.shape[0]))
         router_logits = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states, router_logits,
-                                           all_rank_num_tokens)
+        final_hidden_states = self.experts(
+            hidden_states,
+            router_logits,
+            all_rank_num_tokens=all_rank_num_tokens)
 
         return final_hidden_states.view(orig_shape)
 
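Note on the hunk above: the only functional change is that all_rank_num_tokens is now passed to self.experts by keyword instead of as the third positional argument. A minimal sketch of why that matters, assuming (hypothetically, this signature is not shown in the diff) that the experts' forward accepts extra optional parameters between router_logits and all_rank_num_tokens:

# Hypothetical signature for illustration only.
def experts_forward(hidden_states, router_logits, min_latency_mode=False,
                    all_rank_num_tokens=None):
    return all_rank_num_tokens

# Old call shape: the third positional argument binds to min_latency_mode,
# so all_rank_num_tokens silently stays None.
assert experts_forward("h", "logits", [3, 5, 2]) is None

# New call shape: the keyword pins the argument to the intended parameter.
assert experts_forward("h", "logits", all_rank_num_tokens=[3, 5, 2]) == [3, 5, 2]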
@@ -183,14 +185,23 @@ class Qwen3MoEModel(DecoderModel):
         self.padding_idx = config.pretrained_config.pad_token_id
         self.aux_stream = torch.cuda.Stream()
 
-        self.embed_tokens = Embedding(
-            config.pretrained_config.vocab_size,
-            config.pretrained_config.hidden_size,
-            dtype=config.pretrained_config.torch_dtype,
-            mapping=config.mapping,
-            tensor_parallel_mode=TensorParallelMode.COLUMN,
-            gather_output=True,
-        )
+        if model_config.mapping.enable_attention_dp:
+            # When attention_dp is enabled, we cannot do all_reduce since
+            # the problem size of different ranks are different.
+            # So, we don't do parallelism here.
+            self.embed_tokens = nn.Embedding(
+                config.pretrained_config.vocab_size,
+                config.pretrained_config.hidden_size,
+                dtype=config.pretrained_config.torch_dtype)
+        else:
+            self.embed_tokens = Embedding(
+                config.pretrained_config.vocab_size,
+                config.pretrained_config.hidden_size,
+                dtype=config.pretrained_config.torch_dtype,
+                mapping=config.mapping,
+                tensor_parallel_mode=TensorParallelMode.COLUMN,
+                gather_output=True,
+            )
         self.layers = nn.ModuleList([
             Qwen3MoEDecoderLayer(
                 model_config,
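Note on the hunk above: with attention DP each rank processes a different number of tokens, so the all_reduce/all_gather that the tensor-parallel Embedding issues after the lookup would have mismatched problem sizes across ranks; the fix instead replicates a plain torch.nn.Embedding on every rank. A minimal single-process sketch (sizes are made up, not read from the Qwen3 config):

import torch
import torch.nn as nn

vocab_size, hidden_size = 32000, 1024       # illustrative values only
all_rank_num_tokens = [7, 3]                # attention DP: per-rank token counts differ

# Replicated embedding (the enable_attention_dp branch): each rank looks up its
# own tokens locally; no collective is needed, so unequal row counts are fine.
embed = nn.Embedding(vocab_size, hidden_size)
for rank, num_tokens in enumerate(all_rank_num_tokens):
    local_ids = torch.randint(0, vocab_size, (num_tokens,))
    local_hidden = embed(local_ids)
    assert local_hidden.shape == (num_tokens, hidden_size)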
@@ -250,6 +261,8 @@ class Qwen3MoeForCausalLM(DecoderModelForCausalLM[Qwen3MoEModel,
 
     def load_weights(self, weights: Dict):
         tp_size = self.model_config.mapping.tp_size
+        enable_attention_dp = self.model_config.mapping.enable_attention_dp
+
         head_dim = getattr(
             self.config, "head_dim",
             self.config.hidden_size // self.config.num_attention_heads)
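Note on the hunk above: enable_attention_dp is read once here so the weight-loading loop further down can consult it, while head_dim keeps its existing fallback of hidden_size // num_attention_heads when the checkpoint config does not define it explicitly. A tiny sketch of that fallback with made-up numbers:

class _Cfg:                      # stand-in config without an explicit head_dim
    hidden_size = 4096
    num_attention_heads = 32

cfg = _Cfg()
head_dim = getattr(cfg, "head_dim",
                   cfg.hidden_size // cfg.num_attention_heads)
assert head_dim == 128           # falls back to 4096 // 32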
@@ -288,7 +301,8 @@ class Qwen3MoeForCausalLM(DecoderModelForCausalLM[Qwen3MoEModel,
                     k: (duplicate_kv_weight(
                         weight=v[:],
                         head_dim=head_dim,
-                        tensor_parallel_size=tp_size)
+                        tensor_parallel_size=tp_size
+                        if not enable_attention_dp else 1)
                     if k in tensors_need_duplication else v)
                     for k, v in fw.items()
                 }
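Note on the hunk above: with attention DP the attention weights are kept whole on each rank rather than TP-sharded, so KV heads no longer need to be duplicated to cover tp_size ranks; passing tensor_parallel_size=1 makes the duplication a no-op. A simplified stand-in for duplicate_kv_weight (not the library implementation) that shows the effect:

import torch

def duplicate_kv_weight_sketch(weight: torch.Tensor, head_dim: int,
                               tensor_parallel_size: int) -> torch.Tensor:
    """Simplified sketch: repeat KV heads so every TP rank gets at least one."""
    num_kv_heads = weight.shape[0] // head_dim
    if num_kv_heads >= tensor_parallel_size:
        return weight                       # enough heads already (or TP disabled)
    reps = tensor_parallel_size // num_kv_heads
    return (weight.reshape(num_kv_heads, head_dim, -1)
                  .repeat_interleave(reps, dim=0)
                  .reshape(num_kv_heads * reps * head_dim, -1))

w = torch.randn(4 * 128, 2048)                                     # 4 KV heads, head_dim=128
assert duplicate_kv_weight_sketch(w, 128, 8).shape[0] == 8 * 128   # TP=8: heads duplicated
assert duplicate_kv_weight_sketch(w, 128, 1).shape == w.shape      # attention DP: no-op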