[TRTLLM-10065][feat] Add accuracy tests for super-v3 with multiple-gpus (#10234)

Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
Wanli Jiang · 2026-01-05 09:41:49 +08:00
commit da0830670a (parent 82c1ba84a7)
6 changed files with 115 additions and 2 deletions

View File

@@ -200,18 +200,26 @@ class NemotronHMOE(nn.Module):
        )
        # Set up the latent projection layers.
        # These layers should NOT be TP-sharded, to ensure the MoE receives
        # the full latent representation. They are replicated across all GPUs.
        if self.use_latent_moe:
            self.fc1_latent_proj = Linear(
                in_features=self.hidden_size,
                out_features=self.moe_hidden_size,
                bias=self.mlp_bias,
                dtype=config.torch_dtype,
                quant_config=model_config.get_quant_config(),
                skip_create_weights_in_init=model_config.
                skip_create_weights_in_init,
            )
            self.fc2_latent_proj = Linear(
                in_features=self.moe_hidden_size,
                out_features=self.hidden_size,
                bias=self.mlp_bias,
                dtype=config.torch_dtype,
                quant_config=model_config.get_quant_config(),
                skip_create_weights_in_init=model_config.
                skip_create_weights_in_init,
            )
        else:
            self.fc1_latent_proj = None
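The replication called out in the comment above is the crux of this change: the MoE router scores each token using the whole latent vector, so a tensor-parallel shard of `fc1_latent_proj` would leave each rank with only a slice of it. A minimal sketch of the difference in plain PyTorch (the sizes and the `tp` factor are illustrative assumptions, and `nn.Linear` stands in for the TRT-LLM `Linear` wrapper):

```python
import torch
import torch.nn as nn

hidden_size, moe_hidden_size, tp = 1024, 512, 4

# Replicated projection: every rank holds the full weight, so every rank
# produces the complete latent vector the MoE router needs.
replicated = nn.Linear(hidden_size, moe_hidden_size)

# Column-sharded (TP) projection: each rank holds 1/tp of the output
# features and would see only a slice of the latent representation.
shard = nn.Linear(hidden_size, moe_hidden_size // tp)

x = torch.randn(2, hidden_size)
print(replicated(x).shape)  # torch.Size([2, 512]) -> full latent per rank
print(shard(x).shape)       # torch.Size([2, 128]) -> partial latent per rank
```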

View File

@@ -299,7 +299,10 @@ mistral/Mistral-Large-3-675B:
   - spec_dec_algo: Eagle
     accuracy: 86.1
 nvidia/Nemotron-Super-V3:
-  - accuracy: 84.38
+  - accuracy: 83.74
+  - quant_algo: NVFP4
+    kv_cache_quant_algo: FP8
+    accuracy: 80.85
 nvidia/Nemotron-3-Nano:
   - accuracy: 69.37
   - quant_algo: FP8
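For context on how these reference files are shaped: each model maps to a list of result entries, where a bare `accuracy` is the default (auto-dtype) threshold and entries carrying `quant_algo`/`kv_cache_quant_algo` pin thresholds for specific quantized configs. A hedged sketch of how such an entry could be selected (`find_reference` is a hypothetical helper, not the harness's actual lookup code; requires PyYAML):

```python
import yaml

def find_reference(entries, quant_algo=None, kv_cache_quant_algo=None):
    """Pick the accuracy threshold matching a quant config (hypothetical)."""
    for entry in entries:
        if (entry.get("quant_algo") == quant_algo
                and entry.get("kv_cache_quant_algo") == kv_cache_quant_algo):
            return entry["accuracy"]
    raise KeyError("no matching reference entry")

refs = yaml.safe_load("""
nvidia/Nemotron-Super-V3:
  - accuracy: 83.74
  - quant_algo: NVFP4
    kv_cache_quant_algo: FP8
    accuracy: 80.85
""")

entries = refs["nvidia/Nemotron-Super-V3"]
print(find_reference(entries))                  # 83.74 (auto dtype)
print(find_reference(entries, "NVFP4", "FP8"))  # 80.85 (quantized config)
```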

View File

@@ -349,7 +349,10 @@ mistral/Mistral-Large-3-675B:
   - spec_dec_algo: Eagle
     accuracy: 87.54
 nvidia/Nemotron-Super-V3:
-  - accuracy: 79.41
+  - accuracy: 81.07
+  - quant_algo: NVFP4
+    kv_cache_quant_algo: FP8
+    accuracy: 77.56
 nvidia/Nemotron-3-Nano:
   - accuracy: 73.85
   - quant_algo: FP8

View File

@@ -5091,3 +5091,84 @@ class TestNemotronV3Nano(LlmapiAccuracyTestHarness):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm,
                           extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
+
+
+class TestNemotronV3Super(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "nvidia/Nemotron-Super-V3"
+    # No thinking mode for now.
+    EXTRA_EVALUATOR_KWARGS = dict(chat_template_kwargs=dict(
+        enable_thinking=False))
+
+    @pytest.mark.skip_less_device_memory(64000)
+    @pytest.mark.skip_less_mpi_world_size(4)
+    @pytest.mark.parametrize(
+        "tp_size, ep_size, attention_dp, overlap_scheduler, cuda_graph",
+        [
+            (4, 4, False, True, True),
+            (4, 1, False, False, True),
+            (4, 4, True, False, True),
+            (4, 1, True, True, True),
+            (4, 4, False, True, False),
+            (4, 1, False, False, False),
+            (4, 4, True, False, False),
+            (4, 1, True, True, False),
+        ],
+    )
+    def test_auto_dtype_4gpus(self, tp_size, ep_size, attention_dp,
+                              overlap_scheduler, cuda_graph):
+        if attention_dp:
+            pytest.skip(
+                "Attention DP is not supported for Nemotron-Super-V3 yet")
+        kv_cache_config = KvCacheConfig(enable_block_reuse=False,
+                                        mamba_ssm_cache_dtype="float32")
+        pytorch_config = dict(disable_overlap_scheduler=not overlap_scheduler,
+                              cuda_graph_config=CudaGraphConfig(
+                                  max_batch_size=512, enable_padding=True)
+                              if cuda_graph else None)
+        with LLM(
+                f"{llm_models_root()}/Nemotron-Super-3-120B-A12B-dev",
+                kv_cache_config=kv_cache_config,
+                max_batch_size=32,
+                tensor_parallel_size=tp_size,
+                moe_expert_parallel_size=ep_size,
+                enable_attention_dp=attention_dp,
+                **pytorch_config,
+        ) as llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm,
+                          extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm,
+                          extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
+
+    @skip_pre_blackwell
+    @pytest.mark.skip_less_mpi_world_size(8)
+    def test_nvfp4_8gpus(self):
+        # Use this test to track the best-performing config; the optimized
+        # config is still under investigation, so this is a placeholder.
+        with LLM(
+                f"{llm_models_root()}/Nemotron-SuperV3-phase1-mtp-nvfp4-fp8kv",
+                kv_cache_config=KvCacheConfig(
+                    enable_block_reuse=False,
+                    mamba_ssm_cache_dtype="float16",
+                    free_gpu_memory_fraction=0.5,
+                ),
+                max_batch_size=32,
+                tensor_parallel_size=8,
+                moe_expert_parallel_size=8,
+                pipeline_parallel_size=1,
+                enable_attention_dp=False,
+                cuda_graph_config=CudaGraphConfig(max_batch_size=32,
+                                                  enable_padding=True),
+                disable_overlap_scheduler=False,
+                moe_config=MoeConfig(backend="CUTLASS"),
+        ) as llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm,
+                          extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm,
+                          extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
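The bracketed suffixes in the test lists below are pytest's default parametrize IDs: the tuple values joined with "-", in the order declared in `@pytest.mark.parametrize` above. A small sketch of the mapping (plain Python, assuming pytest's default ID generation):

```python
# Map a collected test ID suffix like "4-4-False-True-True" back to the
# parametrize tuple declared in test_auto_dtype_4gpus.
params = ("tp_size", "ep_size", "attention_dp", "overlap_scheduler",
          "cuda_graph")
case = (4, 4, False, True, True)

test_id = "-".join(str(v) for v in case)
print(test_id)  # 4-4-False-True-True -> test_auto_dtype_4gpus[4-4-False-True-True]
print(dict(zip(params, case)))
# {'tp_size': 4, 'ep_size': 4, 'attention_dp': False,
#  'overlap_scheduler': True, 'cuda_graph': True}
```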

View File

@@ -666,6 +666,15 @@ accuracy/test_llm_api_pytorch.py::TestMistralNemo12B::test_auto_dtype_tp2
 accuracy/test_llm_api_pytorch.py::TestSeedOss_36B::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestNemotronV3Nano::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestNemotronV3Nano::test_fp8
+accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-False-True-True]
+accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-True]
+accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-False-True-False]
+accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-False-False-False]
+accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-False-False]
+accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-False]
+accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-False-False-True]
+accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-False-True]
+accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus
 accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_VL_7B::test_auto_dtype
 accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_5_VL_7B::test_auto_dtype

View File

@@ -21,6 +21,8 @@ l0_dgx_b200:
 - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4a8_nvfp4_fp8[enable_configurable_moe-TRTLLM]
 - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_mxfp4_mxfp8[enable_configurable_moe-True-8-64-TRTLLM]
 - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_wfp4a16[enable_configurable_moe-TRTLLM-2880-dtype0]
+- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-False-True-True]
+- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-True]
 - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm-torch_compile=True]
 - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_trtllm-torch_compile=False]
 - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=False]
@@ -80,6 +82,7 @@ l0_dgx_b200:
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline] TIMEOUT (60)
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_mtp1] TIMEOUT (60)
 - accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype[False] TIMEOUT (60)
+- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus TIMEOUT (60)
 - condition:
     ranges:
       system_gpu_count:
@@ -148,6 +151,12 @@ l0_dgx_b200:
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTEDSL-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_guided_decoding_4gpus[xgrammar-mtp_nextn=0]
+- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-False-True-False]
+- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-False-False-False]
+- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-False-False]
+- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-False]
+- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-False-False-True]
+- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-False-True]
 - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm-torch_compile=False]
 - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=False]
 - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=True]