Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
parent 5e0e48144f
commit f3dd6da080
@@ -250,14 +250,23 @@ class NemotronHBlock(nn.Module):
 
 # Copied from transformers.models.nemotron.modeling_nemotron Nemotron->NemotronH
 class NemotronHMLP(nn.Module):
-    def __init__(self, config, layer_idx: int, intermediate_size: Optional[int] = None):
+    def __init__(
+        self,
+        config,
+        layer_idx: int,
+        intermediate_size: Optional[int] = None,
+        is_expert: bool = False,
+    ):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
         self.hidden_size = config.hidden_size
         self.intermediate_size = intermediate_size or config.intermediate_size
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+        # Use latent size for expert MLPs if provided by config (required for SuperV3)
+        use_latent_size = (getattr(self.config, "moe_latent_size", None) is not None) and is_expert
+        input_size = self.config.moe_latent_size if use_latent_size else self.hidden_size
+        self.up_proj = nn.Linear(input_size, self.intermediate_size, bias=config.mlp_bias)
+        self.down_proj = nn.Linear(self.intermediate_size, input_size, bias=config.mlp_bias)
         self.act_fn = ACT2FN[config.mlp_hidden_act]
 
     def forward(self, x):
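The hunk above only touches `NemotronHMLP.__init__`. The sketch below isolates the new input-size selection so it can be run standalone; the `LatentMLP` name, the `SimpleNamespace` config, and the fixed GELU activation (standing in for the `ACT2FN` lookup) are illustrative assumptions, not part of the diff.

```python
# Minimal sketch of the input-size selection above: expert MLPs read/write the
# smaller latent space when config.moe_latent_size is set, the shared MLP stays
# in the hidden space. Config fields mirror the diff; everything else is a stand-in.
from types import SimpleNamespace

import torch
import torch.nn as nn


class LatentMLP(nn.Module):
    def __init__(self, config, intermediate_size=None, is_expert=False):
        super().__init__()
        intermediate_size = intermediate_size or config.intermediate_size
        use_latent = getattr(config, "moe_latent_size", None) is not None and is_expert
        input_size = config.moe_latent_size if use_latent else config.hidden_size
        self.up_proj = nn.Linear(input_size, intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(intermediate_size, input_size, bias=config.mlp_bias)
        self.act_fn = nn.GELU()  # stand-in for ACT2FN[config.mlp_hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.up_proj(x)))


config = SimpleNamespace(hidden_size=4096, intermediate_size=1024,
                         moe_latent_size=512, mlp_bias=False)
expert = LatentMLP(config, intermediate_size=1024, is_expert=True)
shared = LatentMLP(config, is_expert=False)
print(expert(torch.randn(2, 512)).shape)   # experts operate in the 512-d latent space
print(shared(torch.randn(2, 4096)).shape)  # the shared MLP stays in the 4096-d hidden space
```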
@@ -271,7 +280,10 @@ class NemotronHMOE(nn.Module):
         self.experts = nn.ModuleList(
             [
                 NemotronHMLP(
-                    config, intermediate_size=config.moe_intermediate_size, layer_idx=layer_idx
+                    config,
+                    layer_idx=layer_idx,
+                    intermediate_size=config.moe_intermediate_size,
+                    is_expert=True,
                 )
                 for _ in range(config.n_routed_experts)
             ]
@@ -281,7 +293,19 @@ class NemotronHMOE(nn.Module):
             config=config,
             intermediate_size=config.moe_shared_expert_intermediate_size,
             layer_idx=layer_idx,
+            is_expert=False,
         )
+        # Add latent projections when using latent MoE (required for SuperV3)
+        if getattr(config, "moe_latent_size", None) is not None:
+            self.fc1_latent_proj = nn.Linear(
+                config.hidden_size, config.moe_latent_size, bias=config.mlp_bias
+            )
+            self.fc2_latent_proj = nn.Linear(
+                config.moe_latent_size, config.hidden_size, bias=config.mlp_bias
+            )
+        else:
+            self.fc1_latent_proj = nn.Identity()
+            self.fc2_latent_proj = nn.Identity()
 
     def forward(self, hidden_states: torch.Tensor):
         residuals = hidden_states
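The projections registered above are only constructed here; `NemotronHMOE.forward` is not part of this hunk. The sketch below shows one plausible way such latent projections bracket the routed experts (project hidden → latent, run the experts in the latent space, project back), with a hypothetical top-1 router; when `moe_latent_size` is unset, both projections are `nn.Identity()` and the experts run directly in the hidden space.

```python
# Hedged sketch of a latent MoE layer; the routing scheme (top-1) and expert
# shapes are illustrative assumptions, not the actual NemotronHMOE.forward.
import torch
import torch.nn as nn


class LatentMoESketch(nn.Module):
    def __init__(self, hidden_size=4096, latent_size=512, n_experts=4):
        super().__init__()
        self.fc1_latent_proj = nn.Linear(hidden_size, latent_size, bias=False)
        self.fc2_latent_proj = nn.Linear(latent_size, hidden_size, bias=False)
        self.router = nn.Linear(hidden_size, n_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(latent_size, 1024), nn.GELU(), nn.Linear(1024, latent_size))
            for _ in range(n_experts)
        )

    def forward(self, hidden_states):
        # Route in the full hidden space, but run each expert in the latent space.
        top1 = self.router(hidden_states).argmax(dim=-1)
        latent = self.fc1_latent_proj(hidden_states)
        out = torch.zeros_like(latent)
        for i, expert in enumerate(self.experts):
            mask = top1 == i
            if mask.any():
                out[mask] = expert(latent[mask])
        return self.fc2_latent_proj(out)


print(LatentMoESketch()(torch.randn(8, 4096)).shape)  # torch.Size([8, 4096])
```

This also explains why the shared expert is built with `is_expert=False`: it keeps operating in the full hidden space alongside the residual path, while only the routed experts shrink to the latent width.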
@@ -235,7 +235,7 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
 
 class TestNemotronSuperV3(LlmapiAccuracyTestHarness):
     MODEL_NAME = "nvidia/Nemotron-Super-V3"
-    MODEL_PATH_BF16 = "/scratch/models/super-v3-iter_0440000/hf"  # add to llm_models_root? I don't have permissions
+    MODEL_PATH_BF16 = f"{llm_models_root()}/Nemotron-Super-3-120B-A12B-dev"
 
     def get_default_kwargs(self):
         return {
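For context, `llm_models_root()` is the test helper the suite already uses to locate shared model checkpoints; its real implementation is not part of this diff. A minimal stand-in that mimics the pattern (the environment-variable name and default path below are hypothetical) looks like:

```python
# Hypothetical stand-in for llm_models_root(); the real helper lives in the
# test utilities and is not shown in this diff.
import os


def llm_models_root() -> str:
    return os.environ.get("LLM_MODELS_ROOT", "/scratch/llm-models")


MODEL_PATH_BF16 = f"{llm_models_root()}/Nemotron-Super-3-120B-A12B-dev"
print(MODEL_PATH_BF16)
```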
@@ -264,15 +264,15 @@ class TestNemotronSuperV3(LlmapiAccuracyTestHarness):
                               n=beam_width,
                               use_beam_search=beam_width > 1)
 
-    @pytest.mark.skip_less_device_memory(
-        32000)  # might need to require more memory
-    @pytest.mark.skip_less_device(8)
+    # 180GB works, might be able to go lower
+    @pytest.mark.skip_less_device_memory(180000)
+    @pytest.mark.skip_less_device(4)
     def test_bf16(self):
         kwargs = self.get_default_kwargs()
         sampling_params = self.get_default_sampling_params()
         with AutoDeployLLM(model=self.MODEL_PATH_BF16,
                            tokenizer=self.MODEL_PATH_BF16,
-                           world_size=8,
+                           world_size=4,
                            **kwargs) as llm:
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm, sampling_params=sampling_params)
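A rough sanity check on the new resource marks (`skip_less_device_memory(180000)`, `skip_less_device(4)`, `world_size=4`), assuming the ~120B total parameters implied by the `Nemotron-Super-3-120B-A12B-dev` directory name and bf16 weights:

```python
# Back-of-the-envelope memory estimate; the 120B parameter count is inferred
# from the model directory name, not stated in the diff.
params = 120e9                # total parameters (assumption)
bytes_per_param = 2           # bf16
weights_gb = params * bytes_per_param / 1e9
per_gpu_gb = weights_gb / 4   # world_size=4 in the updated test
print(f"~{weights_gb:.0f} GB of weights, ~{per_gpu_gb:.0f} GB per GPU before KV cache/activations")
# -> ~240 GB of weights, ~60 GB per GPU: four 180 GB devices leave plenty of headroom,
#    while the old 32 GB-per-device mark clearly would not have been enough.
```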
@@ -28,6 +28,8 @@ l0_dgx_b200:
       - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8]
       - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_adp_lmtp_tp4]
+      # ------------- AutoDeploy tests ---------------
+      - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16
   - condition:
       ranges:
         system_gpu_count:
@@ -124,6 +124,8 @@ l0_dgx_h100:
       - disaggregated/test_auto_scaling.py::test_worker_restart[http-load_balancing]
       - disaggregated/test_auto_scaling.py::test_minimal_instances[http-round_robin]
       - disaggregated/test_auto_scaling.py::test_disagg_server_restart[http-round_robin]
+      # ------------- AutoDeploy tests ---------------
+      - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16
   - condition:
       ranges:
         system_gpu_count:
@@ -134,6 +134,7 @@ l0_dgx_h200:
       # ------------- AutoDeploy tests ---------------
       - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-4]
       - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16
+      - accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_bf16
   - condition:
       ranges:
         system_gpu_count: