[Fix][nvbug 5401163][nvbug 5404726][Qwen3] Fix bug of MoE on tp > 1 with trtllm moe backend (#6235)

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
This commit is contained in:
bhsueh_NV 2025-07-24 21:47:37 +08:00 committed by GitHub
parent 0cc1f8c03d
commit 7b6aadc800
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 36 additions and 8 deletions

View File

@ -309,6 +309,13 @@ class Qwen3MoEModel(DecoderModel):
super().__init__(model_config)
config = self.model_config
self.aux_stream = torch.cuda.Stream()
self.preload_weight_modules = []
if config.moe_backend == "TRTLLM":
self.preload_weight_modules = [
"experts",
"routing_method",
"all_reduce",
]
if model_config.mapping.enable_attention_dp:
# When attention_dp is enabled, we cannot do all_reduce since
@ -381,6 +388,7 @@ class Qwen3MoeForCausalLM(SpecDecOneEngineForCausalLM[Qwen3MoEModel,
Qwen3MoEModel(model_config),
model_config,
)
self.preload_weight_modules = self.model.preload_weight_modules
def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper):
super().load_weights(weights, weight_mapper)

View File

@ -865,7 +865,7 @@ def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM],
skip_modules: List[str] = [],
params_map: Optional[Dict[str, str]] = None,
preload_weight_modules: Optional[List[str]] = None):
# TODO: remove preload_weight_modules - it is a workaround for min-latency llama4 model loading where
# TODO: remove preload_weight_modules - it is a workaround for min-latency llama4 and Qwen3 model loading where
# we need some order in the module loading. Once this is resolved, we can remove this workaround.
weight_mapper.add_skip_modules(skip_modules)
if params_map is not None:

View File

@ -77,6 +77,8 @@ Qwen3/Qwen3-30B-A3B:
- quant_algo: NVFP4
kv_cache_quant_algo: FP8
accuracy: 83.43
- spec_dec_algo: Eagle
accuracy: 83.43
Qwen3/Qwen3-235B-A22B:
- quant_algo: FP8
kv_cache_quant_algo: FP8

View File

@ -1756,6 +1756,31 @@ class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
def test_eagle3(self):
pytorch_config = dict(
disable_overlap_scheduler=True,
cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 3, 4, 8]),
)
kv_cache_config = KvCacheConfig(enable_block_reuse=False)
eagle_model_dir = f"{llm_models_root()}/Qwen3/Qwen3-30B-eagle3"
target_model_dir = f"{llm_models_root()}/Qwen3/Qwen3-30B-A3B"
draft_len = 1
spec_config = EagleDecodingConfig(max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=True)
llm = LLM(model=target_model_dir,
**pytorch_config,
kv_cache_config=kv_cache_config,
speculative_config=spec_config,
max_seq_len=8192)
with llm:
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
class TestQwen3_32B(LlmapiAccuracyTestHarness):
MODEL_NAME = "Qwen3/Qwen3-32B"
@ -1822,10 +1847,6 @@ class TestQwen3_235B_A22B(LlmapiAccuracyTestHarness):
)
def test_nvfp4(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
overlap_scheduler, moe_backend):
if moe_backend == "TRTLLM":
pytest.skip(
"TRTLLM moe backend has accuracy issues: https://nvbugspro.nvidia.com/bug/5404726"
)
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,

View File

@ -391,7 +391,6 @@ examples/test_llama.py::test_llm_llama_v3_1_2nodes_8gpus[llama-3.1-8b-disable_fp
test_e2e.py::test_openai_multinodes_chat_tp16pp1 SKIP (https://nvbugs/5112075)
examples/test_qwen.py::test_llm_hf_qwen_quantization_1gpu[qwen2_vl_7b_instruct-fp8-bfloat16] SKIP (https://nvbugs/5322488)
accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype SKIP (https://nvbugs/5234043)
full:B200/accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm] SKIP (https://nvbugs/5401163)
examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:8-cpp_e2e:True-nb:1] SKIP (https://nvbugs/5360086)
examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoder] SKIP (https://nvbugs/5355128)
examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoderplus] SKIP (https://nvbugs/5355128)
@ -422,8 +421,6 @@ triton_server/test_triton_llm.py::test_llava_onevision[test_video-False-1---Fals
triton_server/test_triton.py::test_cpp_unit_tests[cpp-unit-tests] SKIP (https://nvbugs/5401088)
accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype SKIP (https://nvbugs/5401114)
test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True] SKIP (https://nvbugs/5401114)
accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_trtllm] SKIP (https://nvbugs/5401163)
accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm] SKIP (https://nvbugs/5401163)
examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-recurrentgemma-2b-use_paged_cache-int4_awq-float16-enable_attn_plugin-enable_gemm_plugin] SKIP (https://nvbugs/5401233)
examples/test_recurrentgemma.py::test_llm_recurrentgemma_2gpu[recurrentgemma-2b] SKIP (https://nvbugs/5401233)
examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5401156)