[#10056][chore] AutoDeploy: Enable Nemo SuperV3 accuracy test (#10308)

Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Gal Hubara-Agam 2026-01-02 11:20:19 +02:00 committed by GitHub
parent 5e0e48144f
commit f3dd6da080
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 38 additions and 9 deletions

View File

@@ -250,14 +250,23 @@ class NemotronHBlock(nn.Module):
 # Copied from transformers.models.nemotron.modeling_nemotron Nemotron->NemotronH
 class NemotronHMLP(nn.Module):
-    def __init__(self, config, layer_idx: int, intermediate_size: Optional[int] = None):
+    def __init__(
+        self,
+        config,
+        layer_idx: int,
+        intermediate_size: Optional[int] = None,
+        is_expert: bool = False,
+    ):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
         self.hidden_size = config.hidden_size
         self.intermediate_size = intermediate_size or config.intermediate_size
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+        # Use latent size for expert MLPs if provided by config (required for SuperV3)
+        use_latent_size = (getattr(self.config, "moe_latent_size", None) is not None) and is_expert
+        input_size = self.config.moe_latent_size if use_latent_size else self.hidden_size
+        self.up_proj = nn.Linear(input_size, self.intermediate_size, bias=config.mlp_bias)
+        self.down_proj = nn.Linear(self.intermediate_size, input_size, bias=config.mlp_bias)
         self.act_fn = ACT2FN[config.mlp_hidden_act]
 
     def forward(self, x):
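
With this change, an expert MLP (is_expert=True) sizes its up/down projections to config.moe_latent_size instead of config.hidden_size whenever that field is set. A minimal sketch of the resulting layer shapes, assuming a toy config (the SimpleNamespace and its numbers are illustrative only, not the SuperV3 values):

    from types import SimpleNamespace

    import torch.nn as nn

    # Illustrative config values only; the real ones come from the model config.
    cfg = SimpleNamespace(hidden_size=4096, intermediate_size=14336,
                          moe_intermediate_size=2048, moe_latent_size=1024,
                          mlp_bias=False)

    shared_in = cfg.hidden_size       # non-expert MLP stays in hidden space
    expert_in = cfg.moe_latent_size   # expert MLP (is_expert=True) uses latent space

    shared_up = nn.Linear(shared_in, cfg.intermediate_size, bias=cfg.mlp_bias)
    expert_up = nn.Linear(expert_in, cfg.moe_intermediate_size, bias=cfg.mlp_bias)
    print(shared_up.weight.shape)  # torch.Size([14336, 4096])
    print(expert_up.weight.shape)  # torch.Size([2048, 1024])
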
@@ -271,7 +280,10 @@ class NemotronHMOE(nn.Module):
         self.experts = nn.ModuleList(
             [
                 NemotronHMLP(
-                    config, intermediate_size=config.moe_intermediate_size, layer_idx=layer_idx
+                    config,
+                    layer_idx=layer_idx,
+                    intermediate_size=config.moe_intermediate_size,
+                    is_expert=True,
                 )
                 for _ in range(config.n_routed_experts)
             ]
@@ -281,7 +293,19 @@
             config=config,
             intermediate_size=config.moe_shared_expert_intermediate_size,
             layer_idx=layer_idx,
+            is_expert=False,
         )
+        # Add latent projections when using latent MoE (required for SuperV3)
+        if getattr(config, "moe_latent_size", None) is not None:
+            self.fc1_latent_proj = nn.Linear(
+                config.hidden_size, config.moe_latent_size, bias=config.mlp_bias
+            )
+            self.fc2_latent_proj = nn.Linear(
+                config.moe_latent_size, config.hidden_size, bias=config.mlp_bias
+            )
+        else:
+            self.fc1_latent_proj = nn.Identity()
+            self.fc2_latent_proj = nn.Identity()
 
     def forward(self, hidden_states: torch.Tensor):
         residuals = hidden_states
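
The commit only changes the constructors, so NemotronHMOE.forward is not shown in this diff. Purely as a hypothetical sketch of where the new projections would sit (a naive top-1 router stands in for the real routing logic; this is not the code from this commit):

    import torch
    import torch.nn as nn

    def latent_moe_forward_sketch(hidden_states: torch.Tensor,
                                  router: nn.Linear,
                                  experts: nn.ModuleList,
                                  shared_experts: nn.Module,
                                  fc1_latent_proj: nn.Module,
                                  fc2_latent_proj: nn.Module) -> torch.Tensor:
        # Shared expert sees full hidden-size activations.
        shared_out = shared_experts(hidden_states)

        # Routed experts see latent-size activations: project down, route, project back.
        # When moe_latent_size is unset, both projections are nn.Identity, so this
        # degenerates to an ordinary MoE over hidden-size activations.
        latent = fc1_latent_proj(hidden_states)
        top1 = router(hidden_states).argmax(dim=-1)  # naive top-1 routing, illustration only
        routed = torch.zeros_like(latent)
        for idx, expert in enumerate(experts):
            mask = top1 == idx
            if mask.any():
                routed[mask] = expert(latent[mask])
        return shared_out + fc2_latent_proj(routed)

Only the placement of fc1_latent_proj/fc2_latent_proj relative to the shared and routed experts is the point here; the actual router, top-k selection, and normalization in NemotronHMOE differ.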

View File

@@ -235,7 +235,7 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
 class TestNemotronSuperV3(LlmapiAccuracyTestHarness):
     MODEL_NAME = "nvidia/Nemotron-Super-V3"
-    MODEL_PATH_BF16 = "/scratch/models/super-v3-iter_0440000/hf"  # add to llm_models_root? I don't have permissions
+    MODEL_PATH_BF16 = f"{llm_models_root()}/Nemotron-Super-3-120B-A12B-dev"
 
     def get_default_kwargs(self):
         return {
@@ -264,15 +264,15 @@ class TestNemotronSuperV3(LlmapiAccuracyTestHarness):
                               n=beam_width,
                               use_beam_search=beam_width > 1)
 
-    @pytest.mark.skip_less_device_memory(
-        32000)  # might need to require more memory
-    @pytest.mark.skip_less_device(8)
+    # 180GB works, might be able to go lower
+    @pytest.mark.skip_less_device_memory(180000)
+    @pytest.mark.skip_less_device(4)
     def test_bf16(self):
         kwargs = self.get_default_kwargs()
         sampling_params = self.get_default_sampling_params()
         with AutoDeployLLM(model=self.MODEL_PATH_BF16,
                            tokenizer=self.MODEL_PATH_BF16,
-                           world_size=8,
+                           world_size=4,
                            **kwargs) as llm:
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm, sampling_params=sampling_params)

View File

@@ -28,6 +28,8 @@ l0_dgx_b200:
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_adp_lmtp_tp4]
+  # ------------- AutoDeploy tests ---------------
+  - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16
 - condition:
     ranges:
       system_gpu_count:

View File

@@ -124,6 +124,8 @@ l0_dgx_h100:
   - disaggregated/test_auto_scaling.py::test_worker_restart[http-load_balancing]
   - disaggregated/test_auto_scaling.py::test_minimal_instances[http-round_robin]
   - disaggregated/test_auto_scaling.py::test_disagg_server_restart[http-round_robin]
+  # ------------- AutoDeploy tests ---------------
+  - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16
 - condition:
     ranges:
       system_gpu_count:

View File

@@ -134,6 +134,7 @@ l0_dgx_h200:
   # ------------- AutoDeploy tests ---------------
   - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-4]
   - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16
+  - accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_bf16
 - condition:
     ranges:
       system_gpu_count: