[TRTLLM-10065][feat] Add accuracy tests for super-v3 with multiple-gpus (#10234)
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
commit da0830670a (parent 82c1ba84a7)
@@ -200,18 +200,26 @@ class NemotronHMOE(nn.Module):
        )

        # Set up the latent projection layers.
        # These layers should NOT be TP-sharded, to ensure the MoE receives
        # the full latent representation. They are replicated across all GPUs.
        if self.use_latent_moe:
            self.fc1_latent_proj = Linear(
                in_features=self.hidden_size,
                out_features=self.moe_hidden_size,
                bias=self.mlp_bias,
                dtype=config.torch_dtype,
                quant_config=model_config.get_quant_config(),
                skip_create_weights_in_init=model_config.
                skip_create_weights_in_init,
            )
            self.fc2_latent_proj = Linear(
                in_features=self.moe_hidden_size,
                out_features=self.hidden_size,
                bias=self.mlp_bias,
                dtype=config.torch_dtype,
                quant_config=model_config.get_quant_config(),
                skip_create_weights_in_init=model_config.
                skip_create_weights_in_init,
            )
        else:
            self.fc1_latent_proj = None
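For context on the constraint in the comment above: the two projections bracket the expert computation, mapping activations from hidden_size into the MoE latent width and back, so TP-sharding them would hand each rank only a slice of the latent vector. The forward pass is not part of this diff, so the following is a minimal sketch of that path; apart from the fc1_latent_proj/fc2_latent_proj names, everything here is illustrative.

import torch
import torch.nn as nn


class LatentMoESketch(nn.Module):
    """Sketch of a latent-projected MoE block: project in, route through
    the experts, project back. `experts` stands in for the real
    (TP/EP-sharded) MoE; the two Linear layers stay replicated."""

    def __init__(self, hidden_size: int, moe_hidden_size: int,
                 experts: nn.Module):
        super().__init__()
        # Replicated on every rank, matching the comment in the diff above.
        self.fc1_latent_proj = nn.Linear(hidden_size, moe_hidden_size)
        self.fc2_latent_proj = nn.Linear(moe_hidden_size, hidden_size)
        self.experts = experts

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        latent = self.fc1_latent_proj(hidden_states)  # -> moe_hidden_size
        latent = self.experts(latent)  # routed expert MLPs, sharded elsewhere
        return self.fc2_latent_proj(latent)  # -> hidden_size

With a trivial experts module (e.g. nn.Identity()), a [batch, seq, hidden_size] tensor round-trips through the latent space with its shape unchanged.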
@@ -299,7 +299,10 @@ mistral/Mistral-Large-3-675B:
  - spec_dec_algo: Eagle
    accuracy: 86.1
nvidia/Nemotron-Super-V3:
  - accuracy: 84.38
  - accuracy: 83.74
  - quant_algo: NVFP4
    kv_cache_quant_algo: FP8
    accuracy: 80.85
nvidia/Nemotron-3-Nano:
  - accuracy: 69.37
  - quant_algo: FP8
@@ -349,7 +349,10 @@ mistral/Mistral-Large-3-675B:
  - spec_dec_algo: Eagle
    accuracy: 87.54
nvidia/Nemotron-Super-V3:
  - accuracy: 79.41
  - accuracy: 81.07
  - quant_algo: NVFP4
    kv_cache_quant_algo: FP8
    accuracy: 77.56
nvidia/Nemotron-3-Nano:
  - accuracy: 73.85
  - quant_algo: FP8
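The two hunks above extend the per-task reference tables (by their values and the tasks the new tests run, presumably MMLU and GSM8K respectively). Each model maps to a list of entries: an entry carrying only accuracy is the unquantized reference, while quant_algo and kv_cache_quant_algo select a quantized variant. A hypothetical lookup under that assumed schema; the helper and file name below are illustrative, not code from this commit:

import yaml  # PyYAML


def find_reference_accuracy(path, model, quant_algo=None,
                            kv_cache_quant_algo=None):
    """Illustrative: pick the entry whose quant signature matches;
    entries without quant keys act as the default reference."""
    with open(path) as f:
        refs = yaml.safe_load(f)
    for entry in refs[model]:
        if (entry.get("quant_algo") == quant_algo
                and entry.get("kv_cache_quant_algo") == kv_cache_quant_algo):
            return entry["accuracy"]
    raise KeyError(f"no reference for {model} / {quant_algo}")


# The NVFP4 + FP8-KV entry added above would resolve to 80.85:
# find_reference_accuracy("mmlu.yaml", "nvidia/Nemotron-Super-V3",
#                         quant_algo="NVFP4", kv_cache_quant_algo="FP8")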
@@ -5091,3 +5091,84 @@ class TestNemotronV3Nano(LlmapiAccuracyTestHarness):
        task = GSM8K(self.MODEL_NAME)
        task.evaluate(llm,
                      extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)


class TestNemotronV3Super(LlmapiAccuracyTestHarness):
    MODEL_NAME = "nvidia/Nemotron-Super-V3"
    # No thinking mode for now.
    EXTRA_EVALUATOR_KWARGS = dict(chat_template_kwargs=dict(
        enable_thinking=False))

    @pytest.mark.skip_less_device_memory(64000)
    @pytest.mark.skip_less_mpi_world_size(4)
    @pytest.mark.parametrize(
        "tp_size, ep_size, attention_dp, overlap_scheduler, cuda_graph",
        [
            (4, 4, False, True, True),
            (4, 1, False, False, True),
            (4, 4, True, False, True),
            (4, 1, True, True, True),
            (4, 4, False, True, False),
            (4, 1, False, False, False),
            (4, 4, True, False, False),
            (4, 1, True, True, False),
        ],
    )
    def test_auto_dtype_4gpus(self, tp_size, ep_size, attention_dp,
                              overlap_scheduler, cuda_graph):
        if attention_dp:
            pytest.skip(
                "Attention DP is not supported for Nemotron-3-Super yet")

        kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                        mamba_ssm_cache_dtype="float32")
        pytorch_config = dict(disable_overlap_scheduler=not overlap_scheduler,
                              cuda_graph_config=CudaGraphConfig(
                                  max_batch_size=512, enable_padding=True)
                              if cuda_graph else None)

        with LLM(
                f"{llm_models_root()}/Nemotron-Super-3-120B-A12B-dev",
                kv_cache_config=kv_cache_config,
                max_batch_size=32,
                tensor_parallel_size=tp_size,
                moe_expert_parallel_size=ep_size,
                enable_attention_dp=attention_dp,
                **pytorch_config,
        ) as llm:
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm,
                          extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm,
                          extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)

    @skip_pre_blackwell
    @pytest.mark.skip_less_mpi_world_size(8)
    def test_nvfp4_8gpus(self):
        # Use this test to track the best-performing config.
        # The optimal config is still under investigation, so this
        # test serves as a placeholder for now.
        with LLM(
                f"{llm_models_root()}/Nemotron-SuperV3-phase1-mtp-nvfp4-fp8kv",
                kv_cache_config=KvCacheConfig(
                    enable_block_reuse=False,
                    mamba_ssm_cache_dtype="float16",
                    free_gpu_memory_fraction=0.5,
                ),
                max_batch_size=32,
                tensor_parallel_size=8,
                moe_expert_parallel_size=8,
                pipeline_parallel_size=1,
                enable_attention_dp=False,
                cuda_graph_config=CudaGraphConfig(max_batch_size=32,
                                                  enable_padding=True),
                disable_overlap_scheduler=False,
                moe_config=MoeConfig(backend="CUTLASS"),
        ) as llm:
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm,
                          extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm,
                          extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
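The bracketed IDs in the test lists below, e.g. [4-4-False-True-True], are pytest's rendering of the parameter tuples above in the same order: tp_size-ep_size-attention_dp-overlap_scheduler-cuda_graph. To try one variant outside the harness, here is a minimal standalone sketch against the public tensorrt_llm LLM API, mirroring the (4, 4, False, True, True) case; the checkpoint path and prompt are placeholders:

# Sketch of the (tp=4, ep=4, no attention-DP, overlap scheduler on,
# CUDA graphs on) variant of test_auto_dtype_4gpus; needs 4 GPUs.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig

llm = LLM(
    "/path/to/Nemotron-Super-3-120B-A12B-dev",  # placeholder checkpoint path
    kv_cache_config=KvCacheConfig(enable_block_reuse=False,
                                  mamba_ssm_cache_dtype="float32"),
    max_batch_size=32,
    tensor_parallel_size=4,
    moe_expert_parallel_size=4,
    enable_attention_dp=False,
    disable_overlap_scheduler=False,  # overlap_scheduler=True in the test
    cuda_graph_config=CudaGraphConfig(max_batch_size=512,
                                      enable_padding=True),
)
outputs = llm.generate(["The capital of France is"])
print(outputs[0].outputs[0].text)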
@@ -666,6 +666,15 @@ accuracy/test_llm_api_pytorch.py::TestMistralNemo12B::test_auto_dtype_tp2
accuracy/test_llm_api_pytorch.py::TestSeedOss_36B::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestNemotronV3Nano::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestNemotronV3Nano::test_fp8
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-False-True-True]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-True]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-False-True-False]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-False-False-False]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-False-False]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-False]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-False-False-True]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-False-True]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus

accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_VL_7B::test_auto_dtype
accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_5_VL_7B::test_auto_dtype
@@ -21,6 +21,8 @@ l0_dgx_b200:
  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4a8_nvfp4_fp8[enable_configurable_moe-TRTLLM]
  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_mxfp4_mxfp8[enable_configurable_moe-True-8-64-TRTLLM]
  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_wfp4a16[enable_configurable_moe-TRTLLM-2880-dtype0]
  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-False-True-True]
  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-True]
  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm-torch_compile=True]
  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_trtllm-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=False]
@@ -80,6 +82,7 @@ l0_dgx_b200:
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline] TIMEOUT (60)
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_mtp1] TIMEOUT (60)
  - accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype[False] TIMEOUT (60)
  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus TIMEOUT (60)
- condition:
    ranges:
      system_gpu_count:
@@ -148,6 +151,12 @@ l0_dgx_b200:
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTEDSL-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_guided_decoding_4gpus[xgrammar-mtp_nextn=0]
  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-False-True-False]
  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-False-False-False]
  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-False-False]
  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-False]
  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-False-False-True]
  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-False-True]
  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=True]