commit 89e81fc442
Guoming Zhang, 2026-01-13 21:36:15 +08:00, committed by GitHub
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
6 changed files with 110 additions and 66 deletions


@ -38,7 +38,7 @@ Note: Support for other models may vary. Features marked "N/A" are not applicable
| `DeepseekV3ForCausalLM` | Yes | Yes | Yes | Yes | Yes [^1] | Yes | No | No | Yes | Yes | Yes [^2] | N/A | Yes | Yes |
| `DeepseekV32ForCausalLM` | Yes | Yes | Yes | Yes | Yes | Yes | No | No | Yes | Yes | Yes | N/A | Yes | Yes |
| `Qwen3MoeForCausalLM` | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | N/A | Yes | Yes |
- | `Qwen3NextForCausalLM` [^3] | Yes | Yes | No | Untested | Yes | No | No | No | Yes | Yes | No | No | Untested | Untested |
+ | `Qwen3NextForCausalLM` [^3] | Yes | Yes | Yes | Untested | Yes | No | No | No | Yes | Yes | No | No | Untested | Untested |
| `Llama4ForConditionalGeneration` | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes | Yes | Yes | Untested | N/A | Yes | Yes |
| `GptOssForCausalLM` | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes [^4] | Yes | Yes | Yes | N/A | Yes | Yes |


@ -57,6 +57,8 @@ class Qwen3NextHfWeightMapper(Qwen2MoeHfWeightMapper):
tp_size = self.config.mapping.tp_size
tp_rank = self.config.mapping.tp_rank
+ if self.config.mapping.enable_attention_dp:
+     tp_size = 1
# linear_num_value_heads = config.linear_num_value_heads
# linear_num_key_heads = config.linear_num_key_heads
# linear_key_head_dim = config.linear_key_head_dim
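The commented-out lines above hint at the intent of the new branch: under attention data parallelism the gated-delta-net weights are not sharded by head, so the mapper treats the split as if `tp_size` were 1. A minimal, self-contained sketch of that kind of head-wise split — the function name, head counts, and tensor layout are illustrative assumptions, not the mapper's actual code:

```python
import torch


def split_linear_attn_heads(weight: torch.Tensor, num_heads: int, head_dim: int,
                            tp_size: int, tp_rank: int) -> torch.Tensor:
    """Return this rank's shard of a [num_heads * head_dim, in_features] weight.

    With attention DP the caller passes tp_size=1, so every rank keeps the
    whole tensor and no per-head slicing happens.
    """
    assert num_heads % tp_size == 0, "heads must divide evenly across TP ranks"
    heads_per_rank = num_heads // tp_size
    per_head = weight.reshape(num_heads, head_dim, -1)
    shard = per_head[tp_rank * heads_per_rank:(tp_rank + 1) * heads_per_rank]
    return shard.reshape(heads_per_rank * head_dim, -1)


# With attention DP (tp_size=1) the "shard" is simply the original tensor.
w = torch.randn(32 * 128, 2048)
assert torch.equal(split_linear_attn_heads(w, 32, 128, tp_size=1, tp_rank=0), w)
```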


@ -31,19 +31,19 @@ from tensorrt_llm._torch.modules.fla.chunk import chunk_gated_delta_rule
from tensorrt_llm._torch.modules.fla.fused_sigmoid_gating_recurrent import \
fused_sigmoid_gating_delta_rule_update
from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.mapping import Mapping
from ..attention_backend import AttentionMetadata
from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
-                            MoEAllReduce, MoEAllReduceParams, allgather)
+                            MoEAllReduce, MoEAllReduceParams)
from ..model_config import ModelConfig
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import (BaseMoeRoutingMethod,
RenormalizeMoeRoutingMethod,
RenormalizeNaiveMoeRoutingMethod,
-                                  RoutingMethodType, TRTLLMGenFusedMoE,
-                                  create_moe)
+                                  RoutingMethodType, create_moe)
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import Linear, TensorParallelMode
from ..modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
@ -133,8 +133,14 @@ class Qwen3NextSparseMoeBlock(nn.Module):
self.top_k = config.num_experts_per_tok
self.enable_attention_dp = model_config.mapping.enable_attention_dp
self.mapping = model_config.mapping
- self.allreduce = AllReduce(mapping=model_config.mapping,
-                            strategy=model_config.allreduce_strategy)
+ # Only create allreduce when not using attention DP
+ if not self.enable_attention_dp and self.mapping.tp_size > 1:
+     self.allreduce = AllReduce(mapping=model_config.mapping,
+                                strategy=model_config.allreduce_strategy)
+ else:
+     self.allreduce = None
self.aux_stream = aux_stream
self.gate = Qwen3NextGate(
@ -165,14 +171,18 @@ class Qwen3NextSparseMoeBlock(nn.Module):
bias=config.mlp_bias if hasattr(config, 'mlp_bias') else False,
dtype=config.torch_dtype,
config=model_config,
-     reduce_output=False,
- )
+     reduce_output=self.enable_attention_dp and self.mapping.tp_size > 1,
+     overridden_tp_size=1 if self.enable_attention_dp else None)
- self.shared_expert_gate = Linear(self.hidden_dim,
-                                  1,
-                                  bias=False,
-                                  dtype=config.torch_dtype,
-                                  quant_config=None)
+ self.shared_expert_gate = Linear(
+     self.hidden_dim,
+     1,
+     bias=False,
+     dtype=config.torch_dtype,
+     allreduce_strategy=model_config.allreduce_strategy,
+     reduce_output=not self.enable_attention_dp
+     and self.mapping.tp_size > 1,
+     quant_config=None)
self.event_dict = {
key: torch.cuda.Event()
@ -192,27 +202,32 @@ class Qwen3NextSparseMoeBlock(nn.Module):
use_dp_padding = False
all_rank_num_tokens = attn_metadata.all_rank_num_tokens
+ if self.enable_attention_dp and self.mapping.tp_size > 1 and get_sm_version(
+ ) == 120:
+     use_dp_padding = True
+     hidden_states = torch.nn.functional.pad(
+         hidden_states,
+         (0, 0, 0, max(all_rank_num_tokens) - hidden_states.shape[0]))
if not do_finalize:
# TODO: support do_finalize == False
raise NotImplementedError(
"do_finalize == False is not supported yet")
- if self.enable_attention_dp and self.mapping.tp_size > 1:
-     if isinstance(self.experts, TRTLLMGenFusedMoE):
-         hidden_states = allgather(hidden_states,
-                                   self.mapping,
-                                   dim=0,
-                                   sizes=all_rank_num_tokens)
def _compute_routed_output():
router_logits = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states,
router_logits,
output_dtype=hidden_states.dtype,
all_rank_num_tokens=all_rank_num_tokens,
final_all_reduce_params=AllReduceParams(
enable_allreduce=not self.enable_attention_dp
and self.mapping.tp_size > 1),
use_dp_padding=use_dp_padding,
do_finalize=do_finalize,
)
return final_hidden_states
def _compute_shared_output():
@ -418,9 +433,8 @@ class Qwen3NextGatedDeltaNet(nn.Module):
enable_attention_dp=model_config.mapping.enable_attention_dp,
)
self.mapping = mapping
- self.attn_tp_rank = mapping.tp_rank
- self.attn_tp_size = mapping.tp_size
+ self.attn_tp_size = 1 if model_config.mapping.enable_attention_dp else mapping.tp_size
self.hidden_size = config.hidden_size
self.num_v_heads = config.linear_num_value_heads
self.num_k_heads = config.linear_num_key_heads
@ -517,7 +531,8 @@ class Qwen3NextGatedDeltaNet(nn.Module):
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=model_config.get_quant_config(),
- reduce_output=True,
+ reduce_output=not self.model_config.mapping.enable_attention_dp
+ and self.model_config.mapping.tp_size > 1,
skip_create_weights_in_init=model_config.
skip_create_weights_in_init,
allreduce_strategy=model_config.allreduce_strategy,
@ -828,7 +843,6 @@ class Qwen3NextGatedDeltaNet(nn.Module):
attn_out = self.norm(attn_out, z)
attn_out = attn_out.reshape(z_shape_og)
attn_out = attn_out.reshape(*attn_out.shape[:-2], -1)
output = self.out_proj(attn_out, all_reduce_params=all_reduce_params)
return output
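For orientation, the `reduce_output` change above means the row-parallel `out_proj` only issues its tensor-parallel all-reduce in the plain TP case; under attention DP each rank already owns the complete activation for its own tokens. A rough sketch of that decision using `torch.distributed` as a stand-in for TRT-LLM's AllReduce wrapper (illustrative only, not the module's implementation):

```python
import torch
import torch.distributed as dist


def row_parallel_out_proj(partial_in: torch.Tensor, weight_shard: torch.Tensor,
                          tp_size: int, enable_attention_dp: bool) -> torch.Tensor:
    """Row-parallel projection: multiply by the local weight shard, then
    optionally sum the partial results across TP ranks."""
    out = partial_in @ weight_shard.t()
    reduce_output = not enable_attention_dp and tp_size > 1
    if reduce_output and dist.is_initialized():
        dist.all_reduce(out, op=dist.ReduceOp.SUM)  # classic TP epilogue
    # Under attention DP the weights were never sharded, so the local product
    # is already the complete output and no communication is needed.
    return out
```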
@ -865,8 +879,13 @@ class Qwen3NextLinearDecoderLayer(DecoderLayer):
use_gemma=True)
self.layer_idx = layer_idx
- self.allreduce = AllReduce(mapping=model_config.mapping,
-                            strategy=model_config.allreduce_strategy)
+ # Only create allreduce when not using attention DP
+ if not self.enable_attention_dp and model_config.mapping.tp_size > 1:
+     self.allreduce = AllReduce(mapping=model_config.mapping,
+                                strategy=model_config.allreduce_strategy)
+ else:
+     self.allreduce = None
self.next_layer_layernorm: RMSNorm = None
self.fusion_config = EagerFusionConfig()
@ -875,15 +894,14 @@ class Qwen3NextLinearDecoderLayer(DecoderLayer):
"TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "1") == "0"
self.enable_fusion &= not self.enable_attention_dp
- # has_tp = self.mapping.has_tp()
+ has_tp = self.mapping.has_tp()
has_pp = self.mapping.has_pp()
- # self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
- self.fusion_config.PRE_MOE_FUSION = False  # the fusion kernel does not support gemmaNorm yet
- self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION and not has_pp
- self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION
-                                or self.mapping.tp_size == 1
+ self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
+ self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION and not has_pp and self.enable_attention_dp
+ self.disable_attn_allreduce = (self.mapping.tp_size == 1
or self.enable_attention_dp)
self.moe_allreduce = MoEAllReduce(mapping=model_config.mapping)
def forward(
@ -908,8 +926,7 @@ class Qwen3NextLinearDecoderLayer(DecoderLayer):
hidden_states,
attn_metadata,
all_reduce_params=AllReduceParams(
-     enable_allreduce=not (self.fusion_config.PRE_MOE_FUSION
-                           or self.mapping.tp_size == 1)),
+     enable_allreduce=not self.disable_attn_allreduce),
**kwargs)
if self.fusion_config.PRE_MOE_FUSION:
hidden_states, residual = self.allreduce(
@ -919,8 +936,7 @@ class Qwen3NextLinearDecoderLayer(DecoderLayer):
residual=residual,
norm_weight=self.post_attention_layernorm.weight,
eps=self.post_attention_layernorm.variance_epsilon,
-     enable_allreduce=not (self.fusion_config.PRE_MOE_FUSION
-                           or self.mapping.tp_size == 1),
+     enable_allreduce=not self.disable_attn_allreduce,
))
else:
# No fusion
@ -928,9 +944,9 @@ class Qwen3NextLinearDecoderLayer(DecoderLayer):
hidden_states, residual)
# Note: this fusion pattern is only supported for TRTLLM-nvfp4 backend now
- do_finalize = not (hidden_states.shape[0]
+ do_finalize = not (self.fusion_config.POST_MOE_FUSION
+                    and hidden_states.shape[0]
<= self.moe_allreduce.max_token
- and self.fusion_config.POST_MOE_FUSION
and self.model_config.moe_backend == 'TRTLLM'
and self.mlp.experts.has_nvfp4)
@ -942,6 +958,7 @@ class Qwen3NextLinearDecoderLayer(DecoderLayer):
or self.mapping.tp_size == 1)),
do_finalize=do_finalize,
)
if self.fusion_config.POST_MOE_FUSION:
if do_finalize:
hidden_states, residual = self.allreduce(
@ -986,12 +1003,13 @@ class Qwen3NextLinearDecoderLayer(DecoderLayer):
class Qwen3NextAttention(Qwen3Attention):
def __init__(self, model_config: ModelConfig[Qwen3NextConfig],
- layer_idx: int, fuse_qk_norm_rope: bool):
+ layer_idx: int, fuse_qk_norm_rope: bool, reduce_output: bool):
super().__init__(model_config,
layer_idx,
fuse_qk_norm_rope=fuse_qk_norm_rope,
attn_output_gate=True,
- use_gemma_rms_norm=True)
+ use_gemma_rms_norm=True,
+ reduce_output=reduce_output)
class Qwen3NextFullAttentionDecoderLayer(DecoderLayer):
@ -1002,13 +1020,16 @@ class Qwen3NextFullAttentionDecoderLayer(DecoderLayer):
self.model_config = model_config
config = model_config.pretrained_config
+ self.mapping = model_config.mapping
+ self.enable_attention_dp = self.mapping.enable_attention_dp
self.self_attn = Qwen3NextAttention(
model_config,
layer_idx=layer_idx,
fuse_qk_norm_rope=False,
+     reduce_output=not self.enable_attention_dp
+     and self.mapping.tp_size > 1,
)
- self.mapping = model_config.mapping
- self.enable_attention_dp = self.mapping.enable_attention_dp
self.mlp = Qwen3NextSparseMoeBlock(model_config,
aux_stream,
@ -1025,8 +1046,13 @@ class Qwen3NextFullAttentionDecoderLayer(DecoderLayer):
use_gemma=True)
self.layer_idx = layer_idx
- self.allreduce = AllReduce(mapping=model_config.mapping,
-                            strategy=model_config.allreduce_strategy)
+ # Only create allreduce when not using attention DP
+ if not self.enable_attention_dp and model_config.mapping.tp_size > 1:
+     self.allreduce = AllReduce(mapping=model_config.mapping,
+                                strategy=model_config.allreduce_strategy)
+ else:
+     self.allreduce = None
self.next_layer_layernorm: RMSNorm = None
self.fusion_config = EagerFusionConfig()
@ -1034,14 +1060,13 @@ class Qwen3NextFullAttentionDecoderLayer(DecoderLayer):
"TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "0") == "0"
self.enable_fusion &= not self.enable_attention_dp
- # has_tp = self.mapping.has_tp()
+ has_tp = self.mapping.has_tp()
has_pp = self.mapping.has_pp()
- # self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
- self.fusion_config.PRE_MOE_FUSION = False
- self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION and not has_pp
- self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION
-                                or self.mapping.tp_size == 1
+ self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
+ self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION and not has_pp and self.enable_attention_dp
+ self.disable_attn_allreduce = (self.mapping.tp_size == 1
or self.enable_attention_dp)
self.moe_allreduce = MoEAllReduce(mapping=model_config.mapping)
@ -1072,7 +1097,7 @@ class Qwen3NextFullAttentionDecoderLayer(DecoderLayer):
**kwargs,
)
- if self.fusion_config.PRE_MOE_FUSION:
+ if self.fusion_config.PRE_MOE_FUSION and self.enable_attention_dp:
hidden_states, residual = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
@ -1092,7 +1117,6 @@ class Qwen3NextFullAttentionDecoderLayer(DecoderLayer):
and self.fusion_config.POST_MOE_FUSION
and self.model_config.moe_backend == 'TRTLLM'
and self.mlp.experts.has_nvfp4)
hidden_states = self.mlp(
hidden_states,
attn_metadata,
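Stepping back, the flag changes in both decoder-layer classes reduce to a small decision table: AllReduce modules and attention-side all-reduces exist only for plain TP (tp_size > 1 without attention DP), while the pre/post-MoE fusion paths are gated on fusion being enabled and TP being present. A hedged distillation of that logic as plain Python — the flag names mirror the module, but the helper itself is illustrative, not the layer's actual code:

```python
from dataclasses import dataclass


@dataclass
class LayerCommPlan:
    create_allreduce: bool   # whether an AllReduce module is built at all
    attn_allreduce: bool     # all-reduce after (linear) attention
    pre_moe_fusion: bool     # fused allreduce+RMSNorm before the MoE block
    post_moe_fusion: bool    # fused allreduce+RMSNorm after the MoE block


def plan_layer_comm(tp_size: int, enable_attention_dp: bool,
                    enable_fusion: bool, has_pp: bool) -> LayerCommPlan:
    has_tp = tp_size > 1
    enable_fusion = enable_fusion and not enable_attention_dp
    pre_moe_fusion = enable_fusion and has_tp
    post_moe_fusion = pre_moe_fusion and not has_pp and enable_attention_dp
    disable_attn_allreduce = tp_size == 1 or enable_attention_dp
    return LayerCommPlan(
        create_allreduce=not enable_attention_dp and has_tp,
        attn_allreduce=not disable_attn_allreduce,
        pre_moe_fusion=pre_moe_fusion,
        post_moe_fusion=post_moe_fusion,
    )


# Plain TP4: all-reduces on, fusion follows the env toggle.
print(plan_layer_comm(tp_size=4, enable_attention_dp=False, enable_fusion=True, has_pp=False))
# TP4 with attention DP: no AllReduce objects and no fusion kernels.
print(plan_layer_comm(tp_size=4, enable_attention_dp=True, enable_fusion=True, has_pp=False))
```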


@ -45,7 +45,7 @@ class MambaCacheManager(BaseResourceManager):
self.mamba_ssm_cache_dtype = ssm_cache_dtype
# get tp size
- tp_size = mapping.tp_size if not mapping.enable_attention_dp else 1
+ tp_size = 1 if mapping.enable_attention_dp else mapping.tp_size
# derive mamba parameters for conv and ssm states
d_inner = head_dim * num_heads
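The cache manager applies the same rule when sizing the per-rank conv/SSM state buffers: under attention DP the heads are not split, so each rank keeps the full inner width. A minimal sketch of the shard arithmetic (variable names follow the snippet above; the helper is an illustration, not the manager's exact code):

```python
def per_rank_d_inner(head_dim: int, num_heads: int, tp_size: int,
                     enable_attention_dp: bool) -> int:
    """Per-rank inner width used to size the Mamba/GDN state caches."""
    effective_tp = 1 if enable_attention_dp else tp_size
    assert num_heads % effective_tp == 0
    return head_dim * (num_heads // effective_tp)


# TP4 without attention DP shards the heads; with attention DP each rank
# keeps the full d_inner = head_dim * num_heads.
assert per_rank_d_inner(128, 32, tp_size=4, enable_attention_dp=False) == 128 * 8
assert per_rank_d_inner(128, 32, tp_size=4, enable_attention_dp=True) == 128 * 32
```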


@ -4833,18 +4833,21 @@ class TestQwen3NextInstruct(LlmapiAccuracyTestHarness):
@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize(
"tp_size,pp_size,ep_size,cuda_graph,overlap_scheduler",
"tp_size,pp_size,ep_size,cuda_graph,overlap_scheduler,attention_dp",
[
(4, 1, 4, True, True),
(4, 1, 4, True, True, False),
(4, 1, 4, True, True, True),
],
ids=[
"tp4ep4_cudagraph_overlap",
"tp4ep4_cudagraph_overlap_adp_off",
"tp4ep4_cudagraph_overlap_adp_on",
],
)
def test_bf16_4gpu(self, tp_size, pp_size, ep_size, cuda_graph,
- overlap_scheduler, mocker):
+ overlap_scheduler, attention_dp, mocker):
model_path = f"{self.MODEL_PATH}/Qwen3-Next-80B-A3B-Instruct"
- kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6,
+ kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8,
enable_block_reuse=False)
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
@ -4857,9 +4860,11 @@ class TestQwen3NextInstruct(LlmapiAccuracyTestHarness):
model_path,
tensor_parallel_size=tp_size,
max_num_tokens=16384,
+ moe_config=MoeConfig(backend="CUTLASS"),
pipeline_parallel_size=pp_size,
moe_expert_parallel_size=ep_size,
kv_cache_config=kv_cache_config,
+ enable_attention_dp=attention_dp,
**pytorch_config,
) as llm:
task = MMLU(self.MODEL_NAME)
@ -4875,19 +4880,26 @@ class TestQwen3NextInstruct(LlmapiAccuracyTestHarness):
@pytest.mark.parametrize("moe_backend", ["CUTLASS", "TRTLLM"],
ids=["cutlass", "trtllm"])
@pytest.mark.parametrize(
"tp_size,pp_size,ep_size,cuda_graph,overlap_scheduler",
[(1, 1, 1, True, True), (4, 1, 1, True, True), (4, 1, 4, True, True),
(4, 1, 4, False, False)],
ids=["tp1", "tp4ep1", "tp4ep4", "no_cuda_graph_overlap"])
"tp_size,pp_size,ep_size,cuda_graph,overlap_scheduler,attention_dp", [
(1, 1, 1, True, True, False),
(4, 1, 1, True, True, False),
(4, 1, 4, True, True, True),
(4, 1, 4, True, True, False),
(4, 1, 4, False, False, False),
],
ids=[
"tp1", "tp4ep1", "tp4ep4_adp_on", "tp4ep4_adp_off",
"no_cuda_graph_overlap"
])
def test_nvfp4(self, moe_backend, tp_size, pp_size, ep_size, cuda_graph,
- overlap_scheduler, mocker):
+ overlap_scheduler, attention_dp, mocker):
model_path = f"{self.MODEL_PATH}/qwen3-next-80b-instruct-nvfp4-ptq-fp8kv"
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6,
enable_block_reuse=False)
pytorch_config = dict(disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig(
- max_batch_size=512, enable_padding=True)
+ max_batch_size=512, enable_padding=False)
if cuda_graph else None)
moe_config = MoeConfig(backend=moe_backend)
@ -4897,6 +4909,7 @@ class TestQwen3NextInstruct(LlmapiAccuracyTestHarness):
pipeline_parallel_size=pp_size,
moe_expert_parallel_size=ep_size,
kv_cache_config=kv_cache_config,
+ enable_attention_dp=attention_dp,
**pytorch_config,
moe_config=moe_config) as llm:
task = MMLU(self.MODEL_NAME)
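The new attention-DP variants can be selected by their parametrize ids. A hedged example of a local invocation (the test path is assumed relative to the integration-test root, and the run needs 4 GPUs plus the referenced checkpoints):

```python
import pytest

# Run only the attention-DP-enabled ids of the NVFP4 test, e.g.
# test_nvfp4[tp4ep4_adp_on-cutlass] and test_nvfp4[tp4ep4_adp_on-trtllm].
pytest.main([
    "accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4",
    "-k", "adp_on",
])
```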


@ -66,13 +66,18 @@ l0_gb200_multi_gpus:
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-overlap_scheduler]
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-one_model-overlap_scheduler]
- accuracy/test_llm_api_pytorch.py::TestQwen3NextThinking::test_auto_dtype[tp4ep4]
-   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_bf16_4gpu[tp4ep4_cudagraph_overlap]
+   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_bf16_4gpu[tp4ep4_cudagraph_overlap_adp_off]
+   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_bf16_4gpu[tp4ep4_cudagraph_overlap_adp_on]
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp1-cutlass]
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep1-cutlass]
-   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-cutlass]
+   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4_adp_on-cutlass]
+   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4_adp_off-cutlass]
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[no_cuda_graph_overlap-cutlass]
-   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-trtllm]
- accuracy/test_llm_api_pytorch_multimodal.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] TIMEOUT (90)
+   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4_adp_on-trtllm]
+   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4_adp_off-trtllm]
- condition:
ranges:
system_gpu_count: