[#9230][refactor] Replace nemotron patches with custom model implementation (#9751)

[#9230][refactor] Replace nemotron patches with custom model implementation

* Why?

Patching for Nemotron H models had grown out of hand and made certain
optimizations more complex than they needed to be.

* What?

This commit finally gets rid of the patches and replaces them with the custom
model implementation in `modeling_nemotron_h.py`.

Closes #9230
Closes NvBug 5747867

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
William Zhang 2025-12-18 19:36:27 -08:00 committed by GitHub
parent 72c5480dfb
commit 478b6b20a1
13 changed files with 379 additions and 465 deletions
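
For orientation, a minimal sketch (not part of the diff itself) of the new flow, assuming that importing the `custom` models package is what triggers the module-level registration performed in `modeling_nemotron_h.py`:

# Importing the `custom` models package pulls in `modeling_nemotron_h`, whose
# module-level `register_custom_model_cls("NemotronHConfig", NemotronHForCausalLM)`
# call takes the place of the old monkey-patching of the remote HF modeling code.
from tensorrt_llm._torch.auto_deploy.models import custom  # noqa: F401
from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory

assert "NemotronHConfig" in AutoModelForCausalLMFactory._custom_model_mapping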

View File

@@ -2,6 +2,8 @@
import flashinfer
import torch
import torch.nn.functional as F
from einops import rearrange
from ...flashinfer_utils import get_env_enable_pdl
from ...modules.mamba.layernorm_gated import _layer_norm_fwd
@@ -159,3 +161,35 @@ def _triton_rmsnorm_gated_meta(
assert gate.shape == x.shape, "gate must match x shape"
return x.new_empty(x.shape, dtype=torch.float32)
# Forked from:
# https://github.com/state-spaces/mamba/blob/6b32be06d026e170b3fdaf3ae6282c5a6ff57b06/mamba_ssm/ops/triton/layernorm_gated.py
# NOTES:
# 1. At time of writing (09/25/2025), the nano nemotron v2 modeling code expects `mamba_ssm`
# to be installed so as to be able to make use of its grouped gated RMS norm operation.
# We therefore replace it with one that uses einops + pytorch.
def gated_rms_norm_ref(
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True
):
dtype = x.dtype
# N = x.shape[-1]
weight = weight.float()
bias = bias.float() if bias is not None else None
if upcast:
x = x.float()
z = z.float() if z is not None else z
if z is not None and not norm_before_gate:
x = x * F.silu(z)
if group_size is None:
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
else:
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
if bias is not None:
out = out + bias
if z is not None and norm_before_gate:
out *= F.silu(z)
return out.to(dtype)
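
A quick sanity-check sketch (mirroring the unit test updated later in this change set, not additional diff content) comparing the reference against the fused Triton op from the same module; a CUDA device and loose tolerances are assumed:

import torch

from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import (
    gated_rms_norm_ref,
    triton_rmsnorm_gated,
)

B, T, H, group = 2, 8, 1024, 512
x = torch.randn(B, T, H, dtype=torch.bfloat16, device="cuda")
z = torch.randn_like(x)  # gate tensor
w = torch.ones(H, dtype=torch.bfloat16, device="cuda")

y_ref = gated_rms_norm_ref(
    x, w, bias=None, z=z, eps=1e-5, group_size=group, norm_before_gate=False, upcast=True
)
# The Triton op currently returns fp32; cast back for an apples-to-apples comparison.
y_op = triton_rmsnorm_gated(x, w, z, 1e-5, group, False).to(x.dtype)
torch.testing.assert_close(y_op, y_ref, rtol=5e-2, atol=5e-2)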

View File

@@ -1,4 +1,2 @@
# TODO: When getting rid of the nemotron H patches, import `modeling_nemotron_h` here to ensure the
# custom model implementation is registered.
from . import custom, hf, nemotron_flash, patches
from .factory import *

View File

@@ -1 +1,8 @@
from .modeling_nemotron_flash import NemotronFlashForCausalLM, NemotronFlashPreTrainedTokenizerFast
from .modeling_nemotron_h import NemotronHForCausalLM
__all__ = (
"NemotronFlashForCausalLM",
"NemotronFlashPreTrainedTokenizerFast",
"NemotronHForCausalLM",
)

View File

@@ -25,17 +25,14 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN
from transformers.generation import GenerationMixin
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput
from tensorrt_llm._torch.auto_deploy.models.patches.nemotron_h import (
_nemotron_h_moe_forward,
_nemotron_h_topk_router_forward,
)
from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import gated_rms_norm_ref
from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
class MambaRMSNormGated(torch.nn.Module):
@@ -46,7 +43,7 @@ class MambaRMSNormGated(torch.nn.Module):
self.group_size = group_size
def forward(self, hidden_states, gate=None):
return _rms_norm_ref(
return gated_rms_norm_ref(
x=hidden_states,
weight=self.weight,
bias=None,
@@ -57,38 +54,6 @@ class MambaRMSNormGated(torch.nn.Module):
)
# Forked from:
# https://github.com/state-spaces/mamba/blob/6b32be06d026e170b3fdaf3ae6282c5a6ff57b06/mamba_ssm/ops/triton/layernorm_gated.py
# NOTES:
# 1. At time of writing (09/25/2025), the nano nemotron v2 modeling code expects `mamba_ssm`
# to be installed so as to be able to make use of its grouped gated RMS norm operation.
# We therefore replace it with one that uses einops + pytorch.
def _rms_norm_ref(
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True
):
dtype = x.dtype
# N = x.shape[-1]
weight = weight.float()
bias = bias.float() if bias is not None else None
if upcast:
x = x.float()
z = z.float() if z is not None else z
if z is not None and not norm_before_gate:
x = x * F.silu(z)
if group_size is None:
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
else:
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
if bias is not None:
out = out + bias
if z is not None and norm_before_gate:
out *= F.silu(z)
return out.to(dtype)
class NemotronHMamba2Mixer(nn.Module):
"""
Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
@@ -149,9 +114,9 @@ class NemotronHMamba2Mixer(nn.Module):
self.A_log._no_weight_decay = True
# Instead of recomputing `torch.exp(self.A_log.float())` on every forward pass, we will register a hook
# that sets this appropriately when loading weights.
# NOTE: we explicitly do NOT make this a `nn.Parameter` so that it does not appear in the state dict of
# this module, or an equivalent graph module trace from it.
self._minus_A = -A.float()
# NOTE: we explicitly register this as a non-persistent buffer so that it does not appear in the state dict of
# this module, or an equivalent graph module trace from it, but still gets included in e.g. `to()` calls.
self.register_buffer("_minus_A", -A.float(), persistent=False)
self.norm = MambaRMSNormGated(
self.intermediate_size,
eps=self.layer_norm_epsilon,
@@ -317,8 +282,43 @@ class NemotronHMOE(nn.Module):
layer_idx=layer_idx,
)
# TODO: inline code from `_nemotron_h_moe_forward` when removing patches.
forward = _nemotron_h_moe_forward
def forward(self, hidden_states: torch.Tensor):
residuals = hidden_states
orig_shape = hidden_states.shape
topk_indices, topk_weights = self.gate(hidden_states)
x_flat = hidden_states.view(-1, hidden_states.shape[-1])
# NOTE: So far we've seen that the dispatch order in eager code is the same as the node order in the exported
# graph.
# We dispatch shared expert first so that we can easily fork the execution of the routed experts
# (using the custom op below) to an auxiliary stream.
shared_out = self.shared_experts(residuals)
# Check if this is a latent MOE (has fc1_latent_proj and fc2_latent_proj)
has_latent_proj = hasattr(self, "fc1_latent_proj") and hasattr(self, "fc2_latent_proj")
if has_latent_proj:
# Latent MOE: project to latent space before routing
x_flat = self.fc1_latent_proj(x_flat)
# Route through experts (operates in latent space if latent MOE, full space otherwise)
out_flat = torch.ops.auto_deploy.torch_moe(
x_flat,
topk_indices,
topk_weights,
w1_weight=[e.up_proj.weight for e in self.experts],
w2_weight=[e.down_proj.weight for e in self.experts],
w3_weight=[],
act_fn="relu2",
mlp_style="mlp",
)
if has_latent_proj:
# Latent MOE: project back from latent space
out_flat = self.fc2_latent_proj(out_flat)
routed_out = out_flat.view(*orig_shape)
out = shared_out + routed_out
return out
class NemotronHTopkRouter(nn.Module):
@@ -339,22 +339,33 @@ class NemotronHTopkRouter(nn.Module):
"e_score_correction_bias", torch.zeros(self.n_routed_experts, dtype=torch.float32)
)
forward = _nemotron_h_topk_router_forward
def forward(self, hidden_states):
"""
Forward pass for NemotronHTopkRouter using the optimized noaux_tc_op kernel.
This replaces the original forward method, which used pure PyTorch operations,
with optimized CUDA kernels.
"""
hidden_states = hidden_states.view(-1, self.config.hidden_size)
if self.weight.dtype == torch.float32:
router_logits = F.linear(hidden_states.type(torch.float32), self.weight)
else:
router_logits = torch.ops.trtllm.dsv3_router_gemm_op(
hidden_states, self.weight.t(), bias=None, out_dtype=torch.float32
)
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
# Use the fused noaux_tc_op kernel which applies sigmoid internally
# and performs group-based top-k selection with normalization
topk_weights, topk_indices = torch.ops.trtllm.noaux_tc_op(
router_logits,
self.e_score_correction_bias,
self.n_group,
self.topk_group,
self.top_k,
self.routed_scaling_factor,
)
return topk_indices, topk_weights
class NemotronHAttention(nn.Module):
@@ -369,8 +380,23 @@ class NemotronHAttention(nn.Module):
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
if config.head_dim is not None:
self.head_dim = config.head_dim
# At some point during NemotronH development, what used to be called `attention_head_dim`
# was renamed to `head_dim`. Since no configuration class's code (nor the modeling code,
# for that matter) was ever upstreamed into `transformers`, we have to resort to the below
# hack in order to support multiple iterations of NemotronH models.
if hasattr(config, "head_dim"):
head_dim = config.head_dim
elif hasattr(config, "attention_head_dim"):
head_dim = config.attention_head_dim
else:
raise AttributeError(
"Expected either `head_dim` or `attention_head_dim` to be present in the config "
"class, found neither."
)
if head_dim is not None:
self.head_dim = head_dim
else:
self.head_dim = config.hidden_size // config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
@@ -594,7 +620,4 @@ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
return NemotronHCausalLMOutput(logits)
# TODO: uncomment after removing patches (and make sure it is imported in `__init__.py`).
# from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
#
# AutoModelForCausalLMFactory.register_custom_model_cls("NemotronHConfig", NemotronHForCausalLM)
AutoModelForCausalLMFactory.register_custom_model_cls("NemotronHConfig", NemotronHForCausalLM)
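
A side note on the `_minus_A` change in this file: it relies on standard `register_buffer(..., persistent=False)` semantics. A minimal, generic PyTorch sketch (not taken from this repo) of why that keeps the value out of the state dict while still following `.to()` moves:

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        A = torch.arange(1.0, 5.0)
        # Non-persistent buffer: excluded from state_dict(), so it never ends up in
        # checkpoints or in the state dict of an exported graph module, yet it still
        # follows .to() / .cuda() moves like any other buffer.
        self.register_buffer("_minus_A", -A, persistent=False)

m = Demo()
assert "_minus_A" not in m.state_dict()
m = m.to(torch.float64)
assert m._minus_A.dtype == torch.float64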

View File

@@ -1,200 +0,0 @@
import contextlib
import importlib.util
import sys
import types
from typing import Callable, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import AutoModelForCausalLM
from tensorrt_llm._torch.auto_deploy.models.patches.bamba import _bamba_mixer_torch_forward
# Forked from:
# https://github.com/state-spaces/mamba/blob/6b32be06d026e170b3fdaf3ae6282c5a6ff57b06/mamba_ssm/ops/triton/layernorm_gated.py
# NOTES:
# 1. At time of writing (09/25/2025), the nano nemotron v2 modeling code expects `mamba_ssm`
# to be installed so as to be able to make use of its grouped gated RMS norm operation.
# We therefore replace it with one that uses einops + pytorch.
def _rms_norm_ref(
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True
):
dtype = x.dtype
# N = x.shape[-1]
weight = weight.float()
bias = bias.float() if bias is not None else None
if upcast:
x = x.float()
z = z.float() if z is not None else z
if z is not None and not norm_before_gate:
x = x * F.silu(z)
if group_size is None:
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
else:
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
if bias is not None:
out = out + bias
if z is not None and norm_before_gate:
out *= F.silu(z)
return out.to(dtype)
# The original implementation looks at `cache_position[0]` to decide what to do which does not
# play well with export. Plus, we do not want it to be updated anyway.
def _nemotron_h_model_update_mamba_mask(self, attention_mask, cache_position):
return None
def _nemotron_h_model_update_causal_mask(self, attention_mask, input_tensor, cache_position):
# Force attention to use causal mode without explicit masks
return None
def _nemotron_h_block_forward(
self,
hidden_states,
cache_params=None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
):
device = hidden_states.device
with contextlib.ExitStack() as stack:
if device.type == "cuda":
stack.enter_context(torch.cuda.stream(torch.cuda.default_stream(device)))
# * Use torch.cuda.stream() to avoid NaN issues when using multiple GPUs
residual = hidden_states
hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
if self.block_type == "mamba":
hidden_states = self.mixer(
hidden_states, cache_params=cache_params, cache_position=cache_position
)
elif self.block_type == "attention":
hidden_states = self.mixer(hidden_states, cache_position=cache_position)
hidden_states = hidden_states[0]
elif self.block_type in ["mlp", "moe"]:
hidden_states = self.mixer(hidden_states)
else:
raise ValueError(f"Invalid block_type: {self.block_type}")
hidden_states = residual + hidden_states
return hidden_states
def _nemotron_h_topk_router_forward(self, hidden_states):
"""
Forward pass for NemotronHTopkRouter using the optimized noaux_tc_op kernel.
This replaces the original forward method which used pure PyTorch operations
with optimized CUDA kernels:
"""
hidden_states = hidden_states.view(-1, self.config.hidden_size)
if self.weight.dtype == torch.float32:
router_logits = F.linear(hidden_states.type(torch.float32), self.weight)
else:
router_logits = torch.ops.trtllm.dsv3_router_gemm_op(
hidden_states, self.weight.t(), bias=None, out_dtype=torch.float32
)
# Use the fused noaux_tc_op kernel which applies sigmoid internally
# and performs group-based top-k selection with normalization
topk_weights, topk_indices = torch.ops.trtllm.noaux_tc_op(
router_logits,
self.e_score_correction_bias,
self.n_group,
self.topk_group,
self.top_k,
self.routed_scaling_factor,
)
return topk_indices, topk_weights
# Note: we assume experts have no bias for now
def _nemotron_h_moe_forward(self, hidden_states: torch.Tensor):
"""
Uses NemotronH router (returns indices, weights) and dispatches through auto_deploy::torch_moe
with act_fn='relu2'. Handles both latent MOE and direct MOE architectures.
"""
residuals = hidden_states
orig_shape = hidden_states.shape
topk_indices, topk_weights = self.gate(hidden_states)
x_flat = hidden_states.view(-1, hidden_states.shape[-1])
# NOTE: So far we've seen that the dispatch order in eager code is the same as the node order in the exported graph.
# We dispatch shared expert first so that we can easily fork the execution of the routed experts
# (using the custom op below) to an auxiliary stream.
shared_out = self.shared_experts(residuals)
# Check if this is a latent MOE (has fc1_latent_proj and fc2_latent_proj)
has_latent_proj = hasattr(self, "fc1_latent_proj") and hasattr(self, "fc2_latent_proj")
if has_latent_proj:
# Latent MOE: project to latent space before routing
x_flat = self.fc1_latent_proj(x_flat)
# Route through experts (operates in latent space if latent MOE, full space otherwise)
out_flat = torch.ops.auto_deploy.torch_moe(
x_flat,
topk_indices,
topk_weights,
w1_weight=[e.up_proj.weight for e in self.experts],
w2_weight=[e.down_proj.weight for e in self.experts],
w3_weight=[],
act_fn="relu2",
mlp_style="mlp",
)
if has_latent_proj:
# Latent MOE: project back from latent space
out_flat = self.fc2_latent_proj(out_flat)
routed_out = out_flat.view(*orig_shape)
out = shared_out + routed_out
return out
_from_config_original = AutoModelForCausalLM.from_config
CUSTOM_MODULE_PATCHES: Dict[str, List[Tuple[str, Callable]]] = {
"NemotronHMamba2Mixer": [("forward", _bamba_mixer_torch_forward)],
"NemotronHModel": [
("_update_causal_mask", _nemotron_h_model_update_causal_mask),
("_update_mamba_mask", _nemotron_h_model_update_mamba_mask),
],
"NemotronHBlock": [("forward", _nemotron_h_block_forward)],
"NemotronHMOE": [("forward", _nemotron_h_moe_forward)],
"NemotronHTopkRouter": [("forward", _nemotron_h_topk_router_forward)],
}
def get_model_from_config_patched(config, **kwargs):
model = _from_config_original(config, **kwargs)
# Patch modules
for _, module in model.named_modules():
if (module_name := type(module).__name__) in CUSTOM_MODULE_PATCHES.keys():
patches = CUSTOM_MODULE_PATCHES[module_name]
for method_name, method_patch in patches:
setattr(module, method_name, types.MethodType(method_patch, module))
return model
# TODO: figure out how this can be incorporated into the export patch system
AutoModelForCausalLM.from_config = get_model_from_config_patched
# TODO: figure out how this can be incorporated into the export patch system
# Only patch if the module isn't available
_mamba_ssm_module = "mamba_ssm"
_mamba_ssm_submodule = f"{_mamba_ssm_module}.ops.triton.layernorm_gated"
if importlib.util.find_spec(_mamba_ssm_module) is None:
stub_mod = types.ModuleType(_mamba_ssm_submodule)
stub_mod.rmsnorm_fn = _rms_norm_ref
sys.modules[_mamba_ssm_submodule] = stub_mod

View File

@@ -123,7 +123,7 @@ class Quantization(BaseTransform):
cnt += 1
return gm, TransformInfo(
skipped=False, num_matches=cnt, is_clean=False, has_valid_shapes=True
skipped=False, num_matches=cnt, is_clean=False, has_valid_shapes=(cnt == 0)
)
def _insert_quantized_linear(

View File

@@ -7,8 +7,8 @@ import torch
from pydantic import Field
from torch.fx import GraphModule
from ...custom_ops.rms_norm import gated_rms_norm_ref
from ...models.factory import ModelFactory
from ...models.patches.nemotron_h import _rms_norm_ref
from ...shim.interface import CachedSequenceInterface
# It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher
@@ -225,7 +225,7 @@ def _gated_rmsnorm_pattern_ref(
eps: float = 1e-5,
group_size: int = 512,
) -> torch.Tensor:
y = _rms_norm_ref(
y = gated_rms_norm_ref(
x,
weight,
bias=None,

View File

@@ -460,10 +460,6 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backe
test_e2e.py::test_trtllm_serve_multimodal_example SKIP (https://nvbugs/5747920)
examples/test_whisper.py::test_llm_whisper_general[large-v3-disable_gemm_plugin-enable_attention_plugin-disable_weight_only-float16-nb:1-use_cpp_runtime] SKIP (https://nvbugs/5747930)
test_e2e.py::test_trtllm_serve_example SKIP (https://nvbugs/5747938)
unittest/_torch/auto_deploy/unit/singlegpu/models/test_nemotron_h_patches.py::test_nemotronh_moe_patch_forward[dtype0-2-6-nvidia/NVIDIA-Nemotron-Nano-31B-A3-v3] SKIP (https://nvbugs/5747867)
unittest/_torch/auto_deploy/unit/singlegpu/models/test_nemotron_h_patches.py::test_nemotronh_moe_patch_forward[dtype0-1-8-nvidia/NVIDIA-Nemotron-Nano-31B-A3-v3] SKIP (https://nvbugs/5747867)
unittest/_torch/auto_deploy/unit/singlegpu/models/test_nemotron_h_patches.py::test_nemotronh_moe_custom_implementation[dtype0-2-6-nvidia/NVIDIA-Nemotron-Nano-31B-A3-v3] SKIP (https://nvbugs/5747867)
unittest/_torch/auto_deploy/unit/singlegpu/models/test_nemotron_h_patches.py::test_nemotronh_moe_custom_implementation[dtype0-1-8-nvidia/NVIDIA-Nemotron-Nano-31B-A3-v3] SKIP (https://nvbugs/5747867)
unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py::test_build_ad[meta-llama/Llama-4-Scout-17B-16E-Instruct-llm_extra_args8] SKIP (https://nvbugs/5747878)
unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py::test_build_ad[meta-llama/Llama-4-Scout-17B-16E-Instruct-llm_extra_args9] SKIP (https://nvbugs/5747878)
triton_server/test_triton.py::test_opt[opt] SKIP (https://nvbugs/5739981)

View File

@@ -1,8 +1,10 @@
import pytest
import torch
import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401
from tensorrt_llm._torch.auto_deploy.models.patches.nemotron_h import _rms_norm_ref
from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import (
gated_rms_norm_ref,
triton_rmsnorm_gated,
)
@pytest.mark.skipif(
@@ -19,12 +21,12 @@ def test_custom_op_matches_ref(B, T, H, group, use_gate, dtype):
z = torch.randn_like(x) if use_gate else None
w = torch.ones(H, dtype=dtype, device=device)
y_ref = _rms_norm_ref(
y_ref = gated_rms_norm_ref(
x, w, bias=None, z=z, eps=1e-5, group_size=group, norm_before_gate=False, upcast=True
)
# Custom op (currently returns fp32). Cast it back to x.dtype for apples-to-apples with ref.
y_op_fp32 = torch.ops.auto_deploy.triton_rmsnorm_gated(x, w, z, 1e-5, group, False)
y_op_fp32 = triton_rmsnorm_gated(x, w, z, 1e-5, group, False)
y_op = y_op_fp32.to(x.dtype)
assert y_ref.dtype == x.dtype and y_op.dtype == x.dtype

View File

@@ -205,10 +205,12 @@ def test_custom_model_mapping_in_parent_does_not_affect_parent():
class Child(AutoModelForCausalLMFactory):
pass
parent_mapping = copy.copy(AutoModelForCausalLMFactory._custom_model_mapping)
custom_model_cls = MagicMock(spec=AutoModelForCausalLM)
custom_model_cls.configure_mock(_from_config=MagicMock(side_effect=MyError))
Child.register_custom_model_cls(
config_cls_name=FooConfig.__name__, custom_model_cls=custom_model_cls
)
assert AutoModelForCausalLMFactory._custom_model_mapping == {}
assert AutoModelForCausalLMFactory._custom_model_mapping == parent_mapping

View File

@@ -1,5 +1,3 @@
import copy
import pytest
import torch
from _model_test_utils import get_small_model_config
@@ -7,8 +5,6 @@ from torch.export import Dim
from tensorrt_llm._torch.auto_deploy.export import apply_export_patches, torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig
from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
from tensorrt_llm._torch.auto_deploy.models.modeling_nemotron_h import NemotronHForCausalLM
from tensorrt_llm._torch.auto_deploy.utils._graph import move_to_device
# NOTE: find example inputs with the same tokenization length to avoid seq concat.
@@ -16,37 +12,15 @@ EXAMPLE_INPUT2 = "Tiger is a cat with the following properties:"
EXAMPLE_INPUT2 = "Tiger is a cat with the following properties:"
@pytest.fixture
def setup_custom_model_cls_registry(request):
# TODO: remove all this when the patches in `bamba.py` and `nemotron_h.py` can be removed.
old_mapping = copy.copy(AutoModelForCausalLMFactory._custom_model_mapping)
AutoModelForCausalLMFactory._custom_model_mapping = {}
register_custom_model = request.node.callspec.params.get("register_custom_model", False)
if register_custom_model:
AutoModelForCausalLMFactory.register_custom_model_cls(
config_cls_name="NemotronHConfig",
custom_model_cls=NemotronHForCausalLM,
)
yield
AutoModelForCausalLMFactory._custom_model_mapping = old_mapping
@pytest.mark.parametrize(
"model_dir,run_verify_generation,register_custom_model",
"model_dir,run_verify_generation",
[
("ibm-ai-platform/Bamba-9B-v2", True, False),
# This tests the incumbent patching approach.
("nvidia/NVIDIA-Nemotron-Nano-12B-v2", True, False),
# This tests the new custom model implementation.
("nvidia/NVIDIA-Nemotron-Nano-12B-v2", True, True),
("ibm-ai-platform/Bamba-9B-v2", True),
],
)
def test_bamba_patches(
model_dir: str,
run_verify_generation: bool,
register_custom_model: bool,
setup_custom_model_cls_registry,
):
# NOTE: set to False if you want to locally test the full model.
use_small_config: bool = True
@@ -124,13 +98,14 @@ def test_bamba_patches(
move_to_device(gm, "cuda")
factory._to_maybe_random(model, "cuda")
model.load_state_dict(gm.state_dict())
gm.load_state_dict(model.state_dict())
else:
factory.load_or_random_init(model, device="cuda")
gm = _run_torch_export_to_gm()
move_to_device(gm, "cuda")
if run_verify_generation:
_verify_generation(factory, model, tokenizer)
_verify_generation(model, tokenizer)
# let's do a comparison of every state dict item between the model and the gm
torch.testing.assert_close(model.state_dict(), gm.state_dict(), rtol=0.0, atol=0.0)
@@ -157,7 +132,7 @@ def test_bamba_patches(
)
def _verify_generation(factory, model, tokenizer):
def _verify_generation(model, tokenizer):
print("====== WITHOUT PATCH ======")
_generate(tokenizer, model)
with apply_export_patches(patch_list=["bamba"]):

View File

@@ -0,0 +1,235 @@
import importlib.util
import sys
import types
from unittest import mock
import pytest
import torch
from _model_test_utils import get_small_model_config
from torch.export import Dim
from transformers import AutoConfig, AutoModelForCausalLM
from utils.llm_data import llm_models_root
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig
from tensorrt_llm._torch.auto_deploy.models.custom.modeling_nemotron_h import NemotronHForCausalLM
from tensorrt_llm._torch.auto_deploy.utils._graph import move_to_device
_BATCH_AND_SEQUENCE_TEST_CASES = ((2, 6), (1, 8))
@pytest.fixture(scope="function", autouse=True)
def set_seed():
torch.manual_seed(42)
@pytest.fixture(autouse=True)
def stub_mamba_ssm_if_missing():
"""Stub `mamba_ssm` package.
The `modeling_nemotron_h.py` code in all recent nemotron checkpoints has a hard dependency
on `mamba_ssm.ops.triton.layernorm_gated.rmsnorm_fn`. This fixture stubs it so that we can
at least get past the import stage of the remote modeling code.
"""
module = "mamba_ssm"
submodule = f"{module}.ops.triton.layernorm_gated"
if importlib.util.find_spec(module) is not None:
yield
return
stub_mod = types.ModuleType(submodule)
stub_mod.rmsnorm_fn = None
with mock.patch.dict(sys.modules, {submodule: stub_mod}):
yield
def _load_nemotron_moe_layer(model_name_or_path: str, custom_model_cls=None):
"""
Build a tiny NemotronH model (small dims, a few layers) and return the first NemotronHMOE module.
"""
cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
cfg.use_cache = False
cfg.torch_dtype = "bfloat16"
cfg.hidden_size = 32
cfg.intermediate_size = 64
cfg.moe_intermediate_size = 64
cfg.moe_shared_expert_intermediate_size = 64
cfg.mamba_head_dim = 40
cfg.mamba_num_heads = 4
cfg.n_groups = 2
cfg.num_attention_heads = 4
cfg.num_hidden_layers = 9
cfg.num_key_value_heads = 2
cfg.ssm_state_size = 32
if custom_model_cls is None:
model = AutoModelForCausalLM.from_config(cfg, trust_remote_code=True)
else:
model = custom_model_cls._from_config(cfg)
model.eval()
nemotron_moe = None
for _, mod in model.named_modules():
if type(mod).__name__ == "NemotronHMOE":
nemotron_moe = mod
break
if nemotron_moe is None:
raise RuntimeError("NemotronHMOE layer not found. Check your model id or config.")
_set_gate_weights(nemotron_moe)
return nemotron_moe
def _set_gate_weights(module):
# This helper function is necessary because the `weight` parameter of the `NemotronHTopkRouter`
# is initialized with `torch.empty` in the original model code, so no amount of random-seed
# setting has any effect on it. We therefore set it explicitly to ensure the
# reproducibility of the tests.
for _, mod in module.named_modules():
if type(mod).__name__ == "NemotronHTopkRouter":
if hasattr(mod, "weight"):
mod.weight = torch.nn.Parameter(torch.randn_like(mod.weight))
@pytest.mark.parametrize(
"model_name",
[
llm_models_root() / "NVIDIA-Nemotron-3-Nano-30B-A3B-BF16",
],
)
@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.no_grad()
def test_nemotronh_moe_custom_implementation(model_name, B, S, dtype):
device = "cuda"
module = _load_nemotron_moe_layer(model_name)
module.to(device)
H = module.config.hidden_size
x = torch.randn(B, S, H, device=device, dtype=dtype)
ref = module(x)
new_module = _load_nemotron_moe_layer(model_name, custom_model_cls=NemotronHForCausalLM)
new_module.to(device)
new_module.load_state_dict(module.state_dict())
test = new_module(x)
rtol = 0.05
atol = 0.05
torch.testing.assert_close(test, ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize(
"model_dir,model_on_meta_during_export",
[
("nvidia/NVIDIA-Nemotron-Nano-12B-v2", True),
("nvidia/NVIDIA-Nemotron-Nano-12B-v2", False),
],
)
def test_custom_model_implementation_can_be_exported(
model_dir: str,
model_on_meta_during_export: bool,
):
# NOTE: set to False if you want to locally test the full model.
use_small_config: bool = True
common_kwargs = {
"world_size": 0,
"runtime": "demollm",
"model_factory": "AutoModelForCausalLM",
"max_seq_len": 512,
"transforms": {
"insert_cached_attention": {"backend": "flashinfer"},
"compile_model": {"backend": "torch-simple"},
},
}
if use_small_config:
llm_args = get_small_model_config(model_dir, **common_kwargs)["args"]
else:
llm_args = {
"model": model_dir,
**common_kwargs,
"model_kwargs": {
"dtype": "bfloat16",
},
}
llm_args = AutoDeployConfig(**llm_args)
factory = llm_args.create_factory()
model = factory.build_model("meta")
tokenizer = factory.init_tokenizer()
# 1. Export wants min batch size of 2 (to avoid specialization during export).
# 2. Can't get `padding` / `truncation` to work without extra steps, so just use prompts
# with the same tokenized length so that the tokenizer does not complain when creating
# the tensor.
message = [
"Mamba is a snake with the following properties:",
"Tiger is a cat with the following properties:",
]
inputs = tokenizer(message, return_tensors="pt", return_token_type_ids=False).to("cuda")
input_ids = inputs["input_ids"]
position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).repeat(
input_ids.shape[0], 1
)
dynamic_shapes = (
{0: Dim("batch_size", min=0, max=8), 1: Dim("seq_len", min=0, max=512)},
{
0: Dim("batch_size", min=0, max=8),
1: Dim("seq_len", min=0, max=512),
},
)
def _run_torch_export_to_gm():
return torch_export_to_gm(
model,
args=tuple(),
kwargs={"input_ids": input_ids, "position_ids": position_ids},
dynamic_shapes=dynamic_shapes,
)
if model_on_meta_during_export:
gm = _run_torch_export_to_gm()
factory.load_or_random_init(gm, device="cuda")
move_to_device(gm, "cuda")
factory._to_maybe_random(model, "cuda")
# In order to ensure the `_minus_A` (non-persistent buffer) is correct, we need to run the
# model's load state pre/post hooks by loading the state dicts after initialization.
# NOTE: this is done under the hood by `torch_export_to_gm`, so we only need this in this
# `if` clause.
model.load_state_dict(gm.state_dict())
gm.load_state_dict(model.state_dict())
else:
factory.load_or_random_init(model, device="cuda")
gm = _run_torch_export_to_gm()
move_to_device(gm, "cuda")
# let's do a comparison of every state dict item between the model and the gm
torch.testing.assert_close(model.state_dict(), gm.state_dict(), rtol=0.0, atol=0.0)
torch.testing.assert_close(
dict(model.named_buffers()), dict(gm.named_buffers()), rtol=0.0, atol=0.0
)
with torch.inference_mode():
out_original = model(input_ids=input_ids, position_ids=position_ids)
out_gm = gm(input_ids=input_ids, position_ids=position_ids)
atol, rtol = 1e-3, 1e-3
torch.testing.assert_close(
out_gm,
out_original,
rtol=rtol,
atol=atol,
)

View File

@@ -1,158 +0,0 @@
import functools
import types
import pytest
import torch
from _model_test_utils import _hf_model_dir_or_hub_id
from transformers import AutoConfig
from tensorrt_llm._torch.auto_deploy.models.modeling_nemotron_h import NemotronHForCausalLM
from tensorrt_llm._torch.auto_deploy.models.patches.nemotron_h import (
_from_config_original,
_nemotron_h_moe_forward,
)
_BATCH_AND_SEQUENCE_TEST_CASES = ((2, 6), (1, 8))
@pytest.fixture(scope="function", autouse=True)
def set_seed():
torch.manual_seed(42)
def skip_on_no_hf_access(func):
"""Decorator for skipping tests that fail due to HF access issues.
This allows us to share the same test code for CI (where access may be restricted, especially for private
repositories) and locally.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except OSError as e:
if "not a valid model identifier" in str(e):
pytest.skip("Test skipped due to (no) HF access.")
raise
return wrapper
def _load_nemotron_moe_layer(model_name_or_path: str, custom_model_cls=None):
"""
Build a tiny NemotronH model (1 layer, small dims) and return the first NemotronHMOE module.
"""
cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
cfg.use_cache = False
cfg.torch_dtype = "bfloat16"
cfg.hidden_size = 32
cfg.intermediate_size = 64
cfg.moe_intermediate_size = 64
cfg.moe_shared_expert_intermediate_size = 64
cfg.mamba_head_dim = 40
cfg.mamba_num_heads = 4
cfg.n_groups = 2
cfg.num_attention_heads = 4
cfg.num_hidden_layers = 9
cfg.num_key_value_heads = 2
cfg.ssm_state_size = 32
if custom_model_cls is None:
model = _from_config_original(cfg, trust_remote_code=True)
else:
model = custom_model_cls._from_config(cfg)
model.eval()
nemotron_moe = None
for _, mod in model.named_modules():
if type(mod).__name__ == "NemotronHMOE":
nemotron_moe = mod
break
if nemotron_moe is None:
raise RuntimeError("NemotronHMOE layer not found. Check your model id or config.")
_set_gate_weights(nemotron_moe)
return nemotron_moe
def _set_gate_weights(module):
# This helper function is necessary because the `weight` parameter of the `NemotronHTopkRouter`
# is initialized as `torch.empty` in the original model code, which no manner of random seed
# setting will have any effect on. We therefore set it like the below to ensure the
# reproducibility of the tests.
for _, mod in module.named_modules():
if type(mod).__name__ == "NemotronHTopkRouter":
if hasattr(mod, "weight"):
mod.weight = torch.nn.Parameter(torch.randn_like(mod.weight))
@pytest.mark.parametrize(
"model_name",
[
_hf_model_dir_or_hub_id(
"NVIDIA-Nemotron-Nano-31B-A3-v3", "nvidia/NVIDIA-Nemotron-Nano-31B-A3-v3"
),
],
)
@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.no_grad()
@skip_on_no_hf_access
def test_nemotronh_moe_patch_forward(model_name, B, S, dtype):
device = "cuda"
module = _load_nemotron_moe_layer(model_name)
module.to(device)
H = module.config.hidden_size
x = torch.randn(B, S, H, device=device, dtype=dtype)
ref = module(x)
module.forward = types.MethodType(_nemotron_h_moe_forward, module)
test = module(x)
rtol = 0.05
atol = 0.05
torch.testing.assert_close(test, ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize(
"model_name",
[
_hf_model_dir_or_hub_id(
"NVIDIA-Nemotron-Nano-31B-A3-v3", "nvidia/NVIDIA-Nemotron-Nano-31B-A3-v3"
),
],
)
@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.no_grad()
@skip_on_no_hf_access
def test_nemotronh_moe_custom_implementation(model_name, B, S, dtype):
device = "cuda"
module = _load_nemotron_moe_layer(model_name)
module.to(device)
H = module.config.hidden_size
x = torch.randn(B, S, H, device=device, dtype=dtype)
ref = module(x)
new_module = _load_nemotron_moe_layer(model_name, custom_model_cls=NemotronHForCausalLM)
new_module.to(device)
new_module.load_state_dict(module.state_dict())
test = new_module(x)
rtol = 0.05
atol = 0.05
torch.testing.assert_close(test, ref, rtol=rtol, atol=atol)