Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

[TRTLLM-9603][feat] Enable ConfigurableMoE test in the CI (#9645)

Parent commit: 4da0e1473c
This commit:  8e27ce7084
@@ -170,18 +170,23 @@ class ConfigurableMoE(MoE):
        # ConfigurableMoE's super().__init__() was called with real layer_idx and initialized load balancer.
        # Backend was created with init_load_balancer=False and without_comm=True to avoid
        # duplicate initialization. Now sync all attributes from ConfigurableMoE to backend.
        self.backend.aux_stream_dict = self.aux_stream_dict
        self.backend.layer_idx = self.layer_idx
        self.backend.layer_idx_str = self.layer_idx_str
        self.backend.num_slots = self.num_slots
        self.backend.layer_load_balancer = self.layer_load_balancer
        self.backend.repeat_count = self.repeat_count
        self.backend.repeat_idx = self.repeat_idx
        self.backend.initial_local_expert_ids = self.initial_local_expert_ids
        self.backend.initial_global_assignments = self.initial_global_assignments
        self.backend.slot_start = self.slot_start
        self.backend.slot_end = self.slot_end
        self.backend.expert_size_per_partition = self.expert_size_per_partition
        if self.backend is not None:
            # Add a check to WAR the issue that the backend is none during torch.compile
            assert not torch.compiler.is_compiling(), (
                "Backend should not be none if not in torch.compile"
            )
            self.backend.aux_stream_dict = self.aux_stream_dict
            self.backend.layer_idx = self.layer_idx
            self.backend.layer_idx_str = self.layer_idx_str
            self.backend.num_slots = self.num_slots
            self.backend.layer_load_balancer = self.layer_load_balancer
            self.backend.repeat_count = self.repeat_count
            self.backend.repeat_idx = self.repeat_idx
            self.backend.initial_local_expert_ids = self.initial_local_expert_ids
            self.backend.initial_global_assignments = self.initial_global_assignments
            self.backend.slot_start = self.slot_start
            self.backend.slot_end = self.slot_end
            self.backend.expert_size_per_partition = self.expert_size_per_partition

        # Create weights here, because the backend needs the layer_load_balancer info to create weights
        model_config._frozen = False
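For orientation, a simplified sketch (not part of the diff; Wrapper and FakeBackend are hypothetical stand-ins) of the pattern the comments above describe: the wrapper owns the load-balancer state and mirrors it onto a backend that was created without its own initialization, skipping the mirror step when the backend is absent.

# Simplified sketch, not part of the diff: sync-and-guard pattern used by
# ConfigurableMoE above, with hypothetical class names.
class FakeBackend:
    pass


class Wrapper:
    def __init__(self, backend, layer_idx, num_slots):
        self.layer_idx = layer_idx
        self.num_slots = num_slots
        self.backend = backend
        if self.backend is not None:
            # Mirror wrapper-owned state onto the backend exactly once,
            # instead of letting the backend initialize it a second time.
            self.backend.layer_idx = self.layer_idx
            self.backend.num_slots = self.num_slots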
@@ -13,9 +13,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

import pytest
import torch
from mpi4py.futures import MPIPoolExecutor


def patch_mpi_pool_session_for_env(mocker, env_vars: dict):
    """
    Patch MpiPoolSession._start_mpi_pool to propagate environment variables to MPI child processes.

    Uses MPIPoolExecutor's built-in `env` parameter instead of `initializer` to avoid
    segfault issues during process cleanup (UCX memory cache conflicts with PyTorch
    tensor cleanup during Py_FinalizeEx).

    Args:
        mocker: pytest-mock mocker fixture
        env_vars: Dictionary of environment variable name -> value to propagate
    """
    from tensorrt_llm.llmapi.mpi_session import MpiPoolSession

    def patched_start_mpi_pool(self):
        assert not self.mpi_pool, 'MPI session already started'
        self.mpi_pool = MPIPoolExecutor(max_workers=self.n_workers,
                                        path=sys.path,
                                        env=env_vars)

    mocker.patch.object(MpiPoolSession, '_start_mpi_pool',
                        patched_start_mpi_pool)


from defs.conftest import get_sm_version, is_sm_100f

from tensorrt_llm import LLM
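For context, a minimal standalone sketch (not part of the diff) of why the `env` argument is enough here: MPIPoolExecutor applies it to os.environ inside each spawned worker. `report_env` is a hypothetical helper, and the snippet assumes mpi4py is installed and the script runs in an MPI-enabled environment.

# Minimal sketch, not part of the diff: MPIPoolExecutor's `env` argument updates
# os.environ in each worker process, so the patched _start_mpi_pool above can
# hand ENABLE_CONFIGURABLE_MOE to children without an initializer callback.
import os
import sys

from mpi4py.futures import MPIPoolExecutor


def report_env():
    # Hypothetical helper; runs inside an MPI worker process.
    return os.environ.get("ENABLE_CONFIGURABLE_MOE", "<unset>")


if __name__ == "__main__":
    with MPIPoolExecutor(max_workers=2, path=sys.path,
                         env={"ENABLE_CONFIGURABLE_MOE": "1"}) as pool:
        futures = [pool.submit(report_env) for _ in range(2)]
        print([f.result() for f in futures])  # expected: ['1', '1']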
@@ -1830,9 +1858,24 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
                             ids=["tp4", "ep4", "tp2pp2", "pp4"])
    @parametrize_with_ids("mtp_nextn", [0, 2])
    @parametrize_with_ids("moe_backend", ["CUTLASS", "TRTLLM", "CUTEDSL"])
    @pytest.mark.parametrize("enable_configurable_moe", [0, 1],
                             ids=lambda x: ""
                             if x == 0 else "enable_configurable_moe")
    def test_nvfp4_4gpus(self, fp8kv, attention_dp, cuda_graph,
                         overlap_scheduler, tp_size, pp_size, ep_size,
                         torch_compile, mtp_nextn, moe_backend):
                         torch_compile, mtp_nextn, moe_backend,
                         enable_configurable_moe, mocker):
        # Handle ENABLE_CONFIGURABLE_MOE environment variable
        if enable_configurable_moe == 1 and moe_backend != "TRTLLM":
            pytest.skip(
                f"ENABLE_CONFIGURABLE_MOE=1 is only supported with TRTLLM backend, "
                f"current backend is {moe_backend}")

        # Patch MpiPoolSession to propagate env vars to MPI worker processes
        env_value = "1" if enable_configurable_moe == 1 and moe_backend == "TRTLLM" else "0"
        patch_mpi_pool_session_for_env(mocker,
                                       {"ENABLE_CONFIGURABLE_MOE": env_value})

        if moe_backend == "TRTLLM" and (get_sm_version() == 120
                                        or get_sm_version() == 121):
            pytest.skip(
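For illustration, a minimal sketch (not part of the diff): the empty-string ID returned for `enable_configurable_moe=0` is what produces test IDs with a leading dash, which the conftest hooks further down in this commit convert back and forth.

# Minimal sketch, not part of the diff: effect of the empty-string parametrize ID
# on collected test IDs (parameter names mirror the decorators above).
import pytest


@pytest.mark.parametrize("moe_backend", ["TRTLLM"])
@pytest.mark.parametrize("enable_configurable_moe", [0, 1],
                         ids=lambda x: ""
                         if x == 0 else "enable_configurable_moe")
def test_id_demo(enable_configurable_moe, moe_backend):
    pass


# `pytest --collect-only -q` would list:
#   test_id_demo[-TRTLLM]                           <- enable_configurable_moe=0
#   test_id_demo[enable_configurable_moe-TRTLLM]    <- enable_configurable_moe=1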
@@ -3452,9 +3495,23 @@ class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
                             ids=["latency", "ep2", "ep4"])
    @pytest.mark.parametrize("activation_dtype", ["static_fp8", "mxfp8"],
                             ids=["fp8", "mxfp8"])
    @pytest.mark.parametrize("enable_configurable_moe", [0, 1],
                             ids=lambda x: ""
                             if x == 0 else "enable_configurable_moe")
    def test_w4a8_mxfp4(self, moe_backend, tp_size, pp_size, ep_size,
                        attention_dp, cuda_graph, overlap_scheduler,
                        activation_dtype):
                        activation_dtype, enable_configurable_moe, mocker):
        # Handle ENABLE_CONFIGURABLE_MOE environment variable
        if enable_configurable_moe == 1 and moe_backend != "TRTLLM":
            pytest.skip(
                f"ENABLE_CONFIGURABLE_MOE=1 is only supported with TRTLLM backend, "
                f"current backend is {moe_backend}")

        # Patch MpiPoolSession to propagate env vars to MPI worker processes
        env_value = "1" if enable_configurable_moe == 1 and moe_backend == "TRTLLM" else "0"
        patch_mpi_pool_session_for_env(mocker,
                                       {"ENABLE_CONFIGURABLE_MOE": env_value})

        if moe_backend == "TRITON":
            if not IS_TRITON_KERNELS_AVAILABLE:
                pytest.skip("TRITON moe backend is not available.")
@@ -3906,9 +3963,23 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
            (4, 1, 4, True, True, True),
        ],
        ids=["tp4", "ep4", "dp4"])
    @pytest.mark.parametrize("enable_configurable_moe", [0, 1],
                             ids=lambda x: ""
                             if x == 0 else "enable_configurable_moe")
    def test_w4_4gpus(self, kv_cache_dtype, moe_backend, tp_size, pp_size,
                      ep_size, attention_dp, cuda_graph, overlap_scheduler,
                      mocker):
                      enable_configurable_moe, mocker):
        # Handle ENABLE_CONFIGURABLE_MOE environment variable
        if enable_configurable_moe == 1 and moe_backend != "TRTLLM":
            pytest.skip(
                f"ENABLE_CONFIGURABLE_MOE=1 is only supported with TRTLLM backend, "
                f"current backend is {moe_backend}")

        # Patch MpiPoolSession to propagate env vars to MPI worker processes
        env_value = "1" if enable_configurable_moe == 1 and moe_backend == "TRTLLM" else "0"
        patch_mpi_pool_session_for_env(mocker,
                                       {"ENABLE_CONFIGURABLE_MOE": env_value})

        if moe_backend == "TRITON":
            if not IS_TRITON_KERNELS_AVAILABLE:
                pytest.skip("Triton kernels are not available")
@@ -3925,7 +3996,8 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):

        pytorch_config = dict(
            disable_overlap_scheduler=not overlap_scheduler,
            cuda_graph_config=CudaGraphConfig() if cuda_graph else None)
            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
            moe_config=MoeConfig(backend=moe_backend))

        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7,
                                        dtype=kv_cache_dtype)
@@ -3939,8 +4011,7 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
                  max_seq_len=max_seq_len,
                  max_batch_size=720,
                  **pytorch_config,
                  enable_attention_dp=attention_dp,
                  moe_config=MoeConfig(backend=moe_backend))
                  enable_attention_dp=attention_dp)

        with llm:
            model_name = "GPT-OSS/120B-MXFP4"
@@ -4252,8 +4323,17 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
    @pytest.mark.parametrize(
        "kv_cache_dtype",
        ["auto", pytest.param("fp8", marks=skip_pre_blackwell)])
    def test_w4_4gpus_online_eplb(self, kv_cache_dtype, mocker):
    @pytest.mark.parametrize("enable_configurable_moe", [0, 1],
                             ids=lambda x: ""
                             if x == 0 else "enable_configurable_moe")
    def test_w4_4gpus_online_eplb(self, kv_cache_dtype, enable_configurable_moe,
                                  mocker):
        """Test GPTOSS with online expert parallel load balancer using TRTLLM backend and attention DP."""
        # Patch MpiPoolSession to propagate env vars to MPI worker processes
        env_value = "1" if enable_configurable_moe == 1 else "0"
        patch_mpi_pool_session_for_env(mocker,
                                       {"ENABLE_CONFIGURABLE_MOE": env_value})

        mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
        mocker.patch.dict(GSM8K.EVALUATE_KWARGS,
                          {"scores_filter": "exact_match,flexible-extract"})
@@ -2209,6 +2209,94 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):
    metafunc.parametrize("case", uts, ids=lambda x: x)


# Test cases that use enable_configurable_moe parameter and need ID conversion
TESTS_WITH_CONFIGURABLE_MOE = [
    "TestDeepSeekV3Lite::test_nvfp4_4gpus",
    "TestGPTOSS::test_w4_4gpus",
    "TestGPTOSS::test_w4_4gpus_online_eplb",
    "TestQwen3_30B_A3B::test_w4a8_mxfp4",
]


def _convert_clean_to_original_moe_test_id(test_id):
    """Convert clean MoE test ID back to original format for pytest collection.

    Example: "test_llm_api_pytorch.py::test_foo[param]" -> "test_llm_api_pytorch.py::test_foo[-param]"

    This is needed because the `enable_configurable_moe` parameter uses empty string
    as ID when value is 0, resulting in test IDs like "test_foo[-param]".
    We clean these up in pytest_collection_modifyitems, but pytest filters tests
    during collection using the original IDs. So when user runs with clean test name,
    we need to convert it back to match the original.
    """
    if "test_llm_api_pytorch.py" not in test_id:
        return test_id

    # Match pattern like "test_name[params]" and add leading dash after "["
    # But only if params don't already start with "-" or "enable_configurable_moe"
    match = re.search(r"\[([^\]]+)\]", test_id)
    if match:
        params = match.group(1)
        # Skip if already has leading dash or starts with enable_configurable_moe
        if not params.startswith("-") and not params.startswith(
                "enable_configurable_moe"):
            # Add leading dash to params
            new_params = "-" + params
            test_id = test_id.replace(f"[{params}]", f"[{new_params}]")

    return test_id


def pytest_sessionstart(session):
    """Convert clean MoE test IDs in config.args to original format for collection.

    This is needed because pytest filters tests during collection using original IDs.
    When user runs with clean test name, we convert it back to match the original.
    """
    args = session.config.args
    for i, arg in enumerate(args):
        if "test_llm_api_pytorch.py" in arg and "[" in arg:
            # Only apply conversion to specific tests that use enable_configurable_moe
            should_convert = any(test_name in arg
                                 for test_name in TESTS_WITH_CONFIGURABLE_MOE)
            if should_convert:
                args[i] = _convert_clean_to_original_moe_test_id(arg)


def _clean_moe_test_ids(items):
    """Clean up test IDs by removing leading/trailing dashes from parameter IDs.

    This is needed because `enable_configurable_moe` parameter can be empty,
    resulting in ugly test IDs like "test_foo[-True]" or "test_foo[--abc]".
    We clean these up to "test_foo[True]" or "test_foo[abc]" so that:
    1. Test names in waive files and test lists remain unchanged
    2. Test reports look cleaner
    """
    for item in items:
        if "test_llm_api_pytorch.py" in item.nodeid and "[" in item.nodeid:
            # Only apply cleanup to specific tests that use enable_configurable_moe
            should_cleanup = any(test_name in item.nodeid
                                 for test_name in TESTS_WITH_CONFIGURABLE_MOE)
            if should_cleanup:
                original_nodeid = item.nodeid
                original_name = item.name
                nodeid = item.nodeid
                name = item.name

                # Clean up leading/trailing dashes in nodeid
                nodeid = nodeid.replace("[-", "[")
                nodeid = nodeid.replace("-]", "]")

                # Clean up leading/trailing dashes in name
                name = name.replace("[-", "[")
                name = name.replace("-]", "]")

                if nodeid != original_nodeid:
                    item._nodeid = nodeid
                if name != original_name:
                    item.name = name


@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_collection_modifyitems(session, config, items):
    testlist_path = config.getoption("--test-list")
@@ -2217,6 +2305,10 @@ def pytest_collection_modifyitems(session, config, items):
    perf_test = config.getoption("--perf")
    test_model_suites = config.getoption("--test-model-suites")

    # TODO Once the MoE refactor is complete, this should be removed.
    # This is a temporary WAR to minimize the impact of the MoE refactor on the existing test lists.
    _clean_moe_test_ids(items)

    if perf_test:
        global ALL_PYTEST_ITEMS
        ALL_PYTEST_ITEMS = None
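For illustration, a short sketch (not part of the diff; the test ID below is made up but follows the docstring examples) of the round trip the two hooks above implement for the enable_configurable_moe=0 case.

# Illustrative sketch, not part of the diff: round trip between the "clean" ID
# used in test lists and the ID pytest actually generates (leading dash from
# the empty enable_configurable_moe ID).
clean_id = "test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-fp8]"

# Collection phase (pytest_sessionstart): map the clean name back to the
# generated form so pytest can find the test.
original_id = _convert_clean_to_original_moe_test_id(clean_id)
assert original_id == "test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[-tp4-trtllm-fp8]"

# Reporting phase (_clean_moe_test_ids): strip the leading dash again, so waive
# files and test lists keep using the clean form.
assert original_id.replace("[-", "[") == clean_id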
@@ -17,6 +17,10 @@ l0_dgx_b200:
  tests:
  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency]
  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[MNNVL]
  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_nvfp4[enable_configurable_moe-TRTLLM-dtype1]
  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4a8_nvfp4_fp8[enable_configurable_moe-TRTLLM]
  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_mxfp4_mxfp8[enable_configurable_moe-True-8-64-TRTLLM]
  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_wfp4a16[enable_configurable_moe-TRTLLM-2880-dtype0]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
@@ -158,6 +162,8 @@ l0_dgx_b200:
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[enable_configurable_moe-moe_backend=TRTLLM-mtp_nextn=0-tp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[enable_configurable_moe-moe_backend=TRTLLM-mtp_nextn=0-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
@@ -191,12 +197,16 @@ l0_dgx_b200:
  - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp8_chunked_prefill[tp4ep4-cuda_graph=True]
  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=0]
  - accuracy/test_disaggregated_serving.py::TestQwen3_30B_A3B::test_mixed_ctx_gen_model[ctxpp2gentp2]
  - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[enable_configurable_moe-tp4-trtllm-fp8]
  - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[enable_configurable_moe-ep4-trtllm-fp8]
  - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[enable_configurable_moe-dp4-trtllm-fp8]
  - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-cutlass-auto]
  - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-triton-auto]
  - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-trtllm-auto]
  - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-trtllm-fp8]
  - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto]
  - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-triton-auto]
  - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus_online_eplb[enable_configurable_moe-fp8]
  - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-one_model-overlap_scheduler]
  - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-one_model-no_overlap_scheduler]
  - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-two_model-no_overlap_scheduler]
tests/unittest/_torch/modules/conftest.py (new file, 118 lines)
@@ -0,0 +1,118 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TEMPORARY FILE - Will be removed after MoE refactor is complete.
#
# Background:
# The `enable_configurable_moe` parameter is a temporary measure during the MoE
# refactor. The old and new MoE flows will coexist for a period of time. To avoid
# large-scale changes to the existing test lists, we handle the test ID cleanup
# here. Once the refactor is complete and all tests use ConfigurableMoE by default,
# this file will no longer be needed and should be deleted.
#
# Two-phase approach:
# 1. pytest_sessionstart: Convert clean test names in CLI args back to original
#    format so pytest can find tests during collection.
# 2. pytest_collection_modifyitems: Clean up the collected test IDs for display
#    and waive matching.
import re

# Test functions that use enable_configurable_moe parameter and need ID conversion
TESTS_WITH_CONFIGURABLE_MOE = [
    "test_fused_moe_nvfp4",
    "test_fused_moe_mxfp4_mxfp8",
    "test_fused_moe_w4a8_nvfp4_fp8",
    "test_fused_moe_wfp4a16",
]


def _convert_clean_to_original_moe_test_id(test_id):
    """Convert clean MoE test ID back to original format for pytest collection.

    Example: "test_fused_moe.py::test_foo[TRTLLM-dtype0]" -> "test_fused_moe.py::test_foo[-TRTLLM-dtype0]"

    This is needed because the `enable_configurable_moe` parameter uses empty string
    as ID when value is 0, resulting in test IDs like "test_foo[-TRTLLM-dtype0]".
    We clean these up in pytest_collection_modifyitems, but pytest filters tests
    during collection using the original IDs. So when user runs with clean test name,
    we need to convert it back to match the original.
    """
    if "test_fused_moe.py" not in test_id:
        return test_id

    # Match pattern like "test_name[params]" and add leading dash after "["
    # But only if params don't already start with "-" or "enable_configurable_moe"
    match = re.search(r"\[([^\]]+)\]", test_id)
    if match:
        params = match.group(1)
        # Skip if already has leading dash or starts with enable_configurable_moe
        if not params.startswith("-") and not params.startswith("enable_configurable_moe"):
            # Add leading dash to params
            new_params = "-" + params
            test_id = test_id.replace(f"[{params}]", f"[{new_params}]")

    return test_id


def pytest_sessionstart(session):
    """Convert clean MoE test IDs in config.args to original format for collection.

    This is needed because pytest filters tests during collection using original IDs.
    When user runs with clean test name, we convert it back to match the original.
    """
    args = session.config.args
    for i, arg in enumerate(args):
        if "test_fused_moe.py" in arg and "[" in arg:
            # Only apply conversion to specific tests that use enable_configurable_moe
            should_convert = any(test_name in arg for test_name in TESTS_WITH_CONFIGURABLE_MOE)
            if should_convert:
                args[i] = _convert_clean_to_original_moe_test_id(arg)


def pytest_collection_modifyitems(items):
    """Clean up test IDs by removing leading/trailing dashes from parameter IDs.

    This is needed because `enable_configurable_moe` parameter can be empty,
    resulting in ugly test IDs like "test_foo[-True]" or "test_foo[--abc]".
    We clean these up to "test_foo[True]" or "test_foo[abc]" so that:
    1. Test names in waive files and test lists remain unchanged
    2. Test reports look cleaner

    This runs BEFORE the global conftest applies waives (due to hookwrapper).
    """
    for item in items:
        if "test_fused_moe.py" in item.nodeid and "[" in item.nodeid:
            # Only apply cleanup to specific tests that use enable_configurable_moe
            should_cleanup = any(
                test_name in item.nodeid for test_name in TESTS_WITH_CONFIGURABLE_MOE
            )
            if should_cleanup:
                original_nodeid = item.nodeid
                original_name = item.name
                nodeid = item.nodeid
                name = item.name

                # Clean up leading/trailing dashes in nodeid
                nodeid = nodeid.replace("[-", "[")
                nodeid = nodeid.replace("-]", "]")

                # Clean up leading/trailing dashes in name
                name = name.replace("[-", "[")
                name = name.replace("-]", "]")

                if nodeid != original_nodeid:
                    item._nodeid = nodeid
                if name != original_name:
                    item.name = name
@@ -1356,7 +1356,20 @@ def test_fused_moe_fp8_blockwise_cute_dsl_multi_gpu(ep_size, routing_method,
@pytest.mark.parametrize("moe_backend", [
    pytest.param("TRTLLM", marks=skip_blackwell_geforce), "CUTLASS", "CUTEDSL"
])
def test_fused_moe_nvfp4(dtype, moe_backend):
@pytest.mark.parametrize("enable_configurable_moe", [0, 1],
                         ids=lambda x: ""
                         if x == 0 else "enable_configurable_moe")
def test_fused_moe_nvfp4(dtype, moe_backend, enable_configurable_moe, mocker):

    if enable_configurable_moe == 1 and moe_backend != "TRTLLM":
        pytest.skip("ENABLE_CONFIGURABLE_MOE=1, only TRTLLM backend is enabled")

    mocker.patch.dict(
        os.environ, {
            "ENABLE_CONFIGURABLE_MOE":
            "1"
            if enable_configurable_moe == 1 and moe_backend == "TRTLLM" else "0"
        })

    if moe_backend == "TRTLLM" and dtype == torch.float16:
        pytest.skip("TRTLLM NVFP4 MoE backend does not support float16 yet")
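A minimal sketch (not part of the diff, assuming pytest-mock is installed) of why the unit tests use `mocker.patch.dict` here: the environment change is scoped to the test and undone automatically at teardown, so each case can flip ENABLE_CONFIGURABLE_MOE without leaking state into other tests in the same worker.

# Minimal sketch, not part of the diff: mocker.patch.dict scopes the change to
# the test and restores os.environ afterwards.
import os


def test_env_is_scoped(mocker):
    mocker.patch.dict(os.environ, {"ENABLE_CONFIGURABLE_MOE": "1"})
    assert os.environ["ENABLE_CONFIGURABLE_MOE"] == "1"
    # After the test finishes, pytest-mock undoes the patch, so later tests see
    # the original environment again.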
@@ -1515,7 +1528,20 @@ def test_fused_moe_nvfp4(dtype, moe_backend):
@pytest.mark.parametrize(
    "moe_backend",
    [pytest.param("TRTLLM", marks=skip_blackwell_geforce), "CUTLASS"])
def test_fused_moe_w4a8_nvfp4_fp8(moe_backend):
@pytest.mark.parametrize("enable_configurable_moe", [0, 1],
                         ids=lambda x: ""
                         if x == 0 else "enable_configurable_moe")
def test_fused_moe_w4a8_nvfp4_fp8(moe_backend, enable_configurable_moe, mocker):
    if enable_configurable_moe == 1 and moe_backend != "TRTLLM":
        pytest.skip("ENABLE_CONFIGURABLE_MOE=1, only TRTLLM backend is enabled")

    mocker.patch.dict(
        os.environ, {
            "ENABLE_CONFIGURABLE_MOE":
            "1"
            if enable_configurable_moe == 1 and moe_backend == "TRTLLM" else "0"
        })

    dtype = torch.bfloat16
    mapping = Mapping()
    mapping.rank = mpi_rank()
@@ -1930,7 +1956,21 @@ def test_fused_moe_w4afp8(dtype, weight_loading_mode):
@pytest.mark.parametrize("hidden_unpadded", [64, 192, 256])
@pytest.mark.parametrize("seq_len", [8, 128])
@pytest.mark.parametrize("bias", [True, False])
def test_fused_moe_mxfp4_mxfp8(moe_backend, hidden_unpadded, seq_len, bias):
@pytest.mark.parametrize("enable_configurable_moe", [0, 1],
                         ids=lambda x: ""
                         if x == 0 else "enable_configurable_moe")
def test_fused_moe_mxfp4_mxfp8(moe_backend, hidden_unpadded, seq_len, bias,
                               enable_configurable_moe, mocker):

    if enable_configurable_moe == 1 and moe_backend != "TRTLLM":
        pytest.skip("ENABLE_CONFIGURABLE_MOE=1, only TRTLLM backend is enabled")

    mocker.patch.dict(
        os.environ, {
            "ENABLE_CONFIGURABLE_MOE":
            "1"
            if enable_configurable_moe == 1 and moe_backend == "TRTLLM" else "0"
        })

    if moe_backend == "CUTLASS" and hidden_unpadded % 128 != 0:
        pytest.skip()
@@ -2191,7 +2231,21 @@ def test_fused_moe_mxfp4_mxfp8(moe_backend, hidden_unpadded, seq_len, bias):
        marks=[skip_pre_hopper, skip_blackwell, skip_blackwell_geforce]),
    ],
)
def test_fused_moe_wfp4a16(dtype, hidden_size, moe_backend):
@pytest.mark.parametrize("enable_configurable_moe", [0, 1],
                         ids=lambda x: ""
                         if x == 0 else "enable_configurable_moe")
def test_fused_moe_wfp4a16(dtype, hidden_size, moe_backend,
                           enable_configurable_moe, mocker):

    if enable_configurable_moe == 1 and moe_backend != "TRTLLM":
        pytest.skip("ENABLE_CONFIGURABLE_MOE=1, only TRTLLM backend is enabled")

    mocker.patch.dict(
        os.environ, {
            "ENABLE_CONFIGURABLE_MOE":
            "1"
            if enable_configurable_moe == 1 and moe_backend == "TRTLLM" else "0"
        })

    mapping = Mapping()
    mapping.rank = mpi_rank()