chore: Rename nvsmall to nemotron nas (#3447)

* Rename nvsmall to Nemotron NAS

* Revert the nvsmall-to-nemotron_nas rename in paths used by tests that access llm_models_root/nvsmall/tests

* Add NemotronNAS to the PyTorch supported models table

Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com>
amitz-nv 2025-04-10 18:16:52 +03:00 committed by GitHub
parent af05749e90
commit a6a2ae6cc1
8 changed files with 78 additions and 75 deletions

View File

@@ -52,6 +52,7 @@ python3 quickstart_multimodal.py --model_dir Efficient-Large-Model/NVILA-8B --mo
 | `MixtralForCausalLM` | Mixtral | `mistralai/Mixtral-8x7B-v0.1` | L |
 | `MllamaForConditionalGeneration` | Llama 3.2 | `meta-llama/Llama-3.2-11B-Vision` | L |
 | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base` | L |
+| `NemotronNASForCausalLM` | NemotronNAS | `nvidia/Llama-3_3-Nemotron-Super-49B-v1` | L |
 | `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/Qwen2-7B-Instruct` | L |
 | `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B` | L |
 | `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B` | L |
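
The new table row documents the `DeciLMForCausalLM` architecture (registered below as `NemotronNASForCausalLM`) in the PyTorch-backend supported models list. A minimal usage sketch, assuming the high-level `tensorrt_llm.LLM` API from the quickstart; the `tensor_parallel_size` value is an illustrative assumption for a 49B checkpoint and is not part of this change:

# Hedged sketch only: exercises the newly documented NemotronNAS entry.
from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="nvidia/Llama-3_3-Nemotron-Super-49B-v1",
          tensor_parallel_size=4)  # assumption: a 49B checkpoint needs multiple GPUs
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)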

View File

@@ -8,7 +8,7 @@ from .modeling_llava_next import LlavaNextModel
 from .modeling_mamba_hybrid import MambaHybridForCausalLM
 from .modeling_mixtral import MixtralForCausalLM
 from .modeling_nemotron import NemotronForCausalLM
-from .modeling_nvsmall import NVSmallForCausalLM
+from .modeling_nemotron_nas import NemotronNASForCausalLM
 from .modeling_qwen import (Qwen2ForCausalLM, Qwen2ForProcessRewardModel,
                             Qwen2ForRewardModel)
 from .modeling_qwen2vl import Qwen2_5_VLModel, Qwen2VLModel
@@ -26,7 +26,7 @@ __all__ = [
     "MambaHybridForCausalLM",
     "MixtralForCausalLM",
     "NemotronForCausalLM",
-    "NVSmallForCausalLM",
+    "NemotronNASForCausalLM",
     "Qwen2ForCausalLM",
     "Qwen2ForProcessRewardModel",
     "Qwen2ForRewardModel",

View File

@@ -21,7 +21,7 @@ from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
                              register_auto_model)
-class NVSmallRotaryEmbedding(RotaryEmbedding):
+class NemotronNASRotaryEmbedding(RotaryEmbedding):
     def __init__(self, config: PretrainedConfig, layer_idx: int):
         if config.rope_scaling is not None:
@@ -69,7 +69,7 @@ def _create_linear_from_configs(model_config: ModelConfig[PretrainedConfig],
     )
-class NVSmallAttention(Attention):
+class NemotronNASAttention(Attention):
     def __init__(self, model_config: ModelConfig[PretrainedConfig],
                  layer_idx: int):
@@ -88,7 +88,7 @@ class NVSmallAttention(Attention):
             max_position_embeddings=config.max_position_embeddings,
             bias=False,
             pos_embd_params=pos_embd_params,
-            rotary_emb=NVSmallRotaryEmbedding(config, layer_idx=layer_idx),
+            rotary_emb=NemotronNASRotaryEmbedding(config, layer_idx=layer_idx),
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             config=model_config)
@@ -120,7 +120,7 @@ class LinearMLP(nn.Module):
         return self.linear_mlp(hidden_states)
-class NVSmallDecoderLayer(DecoderLayer):
+class NemotronNASDecoderLayer(DecoderLayer):
     def __init__(self, model_config: ModelConfig[PretrainedConfig],
                  block_config: Dict[str, Any], layer_idx: int):
@@ -134,8 +134,8 @@ class NVSmallDecoderLayer(DecoderLayer):
         if self.block_config.attention.replace_with_linear:
             self.self_attn = LinearAttention(model_config, config)
         else:
-            self.self_attn = NVSmallAttention(model_config=model_config,
-                                              layer_idx=layer_idx)
+            self.self_attn = NemotronNASAttention(model_config=model_config,
+                                                  layer_idx=layer_idx)
         if not self.block_config.ffn.no_op:
             self.post_attention_layernorm = RMSNorm(
                 hidden_size=config.hidden_size,
@@ -182,7 +182,7 @@ class NVSmallDecoderLayer(DecoderLayer):
         return hidden_states
-class NVSmallModel(DecoderModel):
+class NemotronNASModel(DecoderModel):
     def __init__(self, model_config):
         super().__init__(model_config)
@@ -200,7 +200,7 @@ class NVSmallModel(DecoderModel):
             dtype=config.torch_dtype,
         )
         self.layers = nn.ModuleList([
-            NVSmallDecoderLayer(model_config, block_config, layer_idx)
+            NemotronNASDecoderLayer(model_config, block_config, layer_idx)
             for layer_idx, block_config in enumerate(config.block_configs)
         ])
         self.norm = RMSNorm(hidden_size=config.hidden_size,
@@ -209,11 +209,11 @@ class NVSmallModel(DecoderModel):
 @register_auto_model("DeciLMForCausalLM")
-class NVSmallForCausalLM(DecoderModelForCausalLM[NVSmallModel,
-                                                 PretrainedConfig]):
+class NemotronNASForCausalLM(DecoderModelForCausalLM[NemotronNASModel,
+                                                     PretrainedConfig]):
     def __init__(self, model_config: ModelConfig[PretrainedConfig]):
-        super().__init__(NVSmallModel(model_config),
+        super().__init__(NemotronNASModel(model_config),
                          config=model_config,
                          hidden_size=model_config.pretrained_config.hidden_size,
                          vocab_size=model_config.pretrained_config.vocab_size)
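
For reference, the renamed class keeps the same construction path that the unit test below exercises. A minimal sketch, assuming the package-level re-export from the `__init__.py` change above and the same Hugging Face checkpoint id used in the test:

# Hedged sketch mirroring the test below: build the renamed model directly from a
# DeciLM-style Hugging Face config.
import torch
from transformers import AutoConfig

from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models import NemotronNASForCausalLM

hf_config = AutoConfig.from_pretrained("nvidia/Llama-3_1-Nemotron-51B-Instruct",
                                       trust_remote_code=True)
model_config = ModelConfig(pretrained_config=hf_config)
model = NemotronNASForCausalLM(model_config).to(hf_config.torch_dtype).to(
    torch.device("cuda"))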

View File

@@ -202,7 +202,7 @@
     "examples/test_multimodal.py::test_llm_multimodal_general[Qwen2-VL-7B-Instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:4]": 355.9665117710829,
     "examples/test_qwen.py::test_llm_qwen_single_gpu_summary[qwen2_7b_instruct-enable_paged_kv_cache-enable_remove_input_padding-enable_weight_only-enable_fmha]": 334.18761303275824,
     "examples/test_qwen2audio.py::test_llm_qwen2audio_single_gpu[qwen2_audio_7b_instruct]": 373.8418433815241,
-    "test_unittests.py::test_unittests_v2[unittest/_torch/modeling -k \"modeling_nvsmall\"]": 587.2612264081836,
+    "test_unittests.py::test_unittests_v2[unittest/_torch/modeling -k \"modeling_nemotron_nas\"]": 587.2612264081836,
     "test_unittests.py::test_unittests_v2[unittest/trt/model/test_gpt.py -k \"partition2\"]": 1086.5458996072412,
     "examples/test_draft_target_model.py::test_llm_draft_target_model_1gpu[streaming-gpt2-use_cpp_session-use_tokens-draft_len_4-float16-bs1]": 216.4487509690225,
     "examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-byt5-small-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1-disable_fp8]": 213.7413339074701,

View File

@@ -18,7 +18,7 @@ l0_a30:
   - unittest/_torch -k "modeling_llama"
   - unittest/_torch/modeling -k "modeling_mixtral"
   - unittest/_torch/modeling -k "modeling_mllama"
-  - unittest/_torch/modeling -k "modeling_nvsmall"
+  - unittest/_torch/modeling -k "modeling_nemotron_nas"
   - unittest/_torch/modeling -k "modeling_out_of_tree"
   - unittest/_torch/modeling -k "modeling_qwen"
   - unittest/_torch/modeling -k "modeling_qwen_moe"
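
The renamed `-k` filter can also be run locally. A hedged sketch using pytest's Python entry point, assuming the directory is given relative to the repository's tests root:

# Hedged sketch: run the renamed modeling tests with the same filter as the test lists above.
import pytest

pytest.main(["unittest/_torch/modeling", "-k", "modeling_nemotron_nas"])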

View File

@@ -16,7 +16,7 @@ l0_l40s:
   # ------------- PyTorch tests ---------------
   - unittest/_torch -k "not (modeling or multi_gpu or auto_deploy)"
   - unittest/_torch/modeling -k "modeling_mllama"
-  - unittest/_torch/modeling -k "modeling_nvsmall"
+  - unittest/_torch/modeling -k "modeling_nemotron_nas"
   - unittest/_torch/modeling -k "modeling_out_of_tree"
   - unittest/_torch/modeling -k "modeling_qwen"
   - unittest/_torch/modeling -k "modeling_vila"

View File

@@ -96,7 +96,7 @@ from utils.llm_data import llm_models_root
             pytest.mark.skip(reason="https://nvbugspro.nvidia.com/bug/5178508"),
         ],
     ),
-    # full NVSmall (Llama-3.1-Nemotron-51B) with torch-opt backend + simple runtime
+    # full NemotronNAS (Llama-3.1-Nemotron-51B) with torch-opt backend + simple runtime
     param_with_device_count(
         4,
         {

View File

@@ -13,12 +13,13 @@ import tensorrt_llm
 from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
 from tensorrt_llm._torch.metadata import KVCacheParams
 from tensorrt_llm._torch.model_config import ModelConfig
-from tensorrt_llm._torch.models.modeling_nvsmall import NVSmallForCausalLM
+from tensorrt_llm._torch.models.modeling_nemotron_nas import \
+    NemotronNASForCausalLM
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
 from tensorrt_llm.bindings.executor import KvCacheConfig
 from tensorrt_llm.mapping import Mapping
-NVSMALL_MINI_CONFIG = {
+NEMOTRON_NAS_MINI_CONFIG = {
     "architectures": ["DeciLMForCausalLM"],
     "attention_bias":
     False,
@@ -224,7 +225,8 @@ class Scenario:
         return f"backend:{self.backend.lower()}"
-def reduce_nvsmall_config(mem_for_full_model: int, config_dict: dict[str, Any]):
+def reduce_nemotron_nas_config(mem_for_full_model: int, config_dict: dict[str,
+                                                                          Any]):
     _, total_mem = torch.cuda.mem_get_info()
     # scale model down if gpu memory is low
     if total_mem < mem_for_full_model:
@@ -235,24 +237,24 @@ def reduce_nvsmall_config(mem_for_full_model: int, config_dict: dict[str, Any]):
         config_dict["block_configs"] = config_dict["block_configs"][:num_layers]
-class TestNVSmall(unittest.TestCase):
+class TestNemotronNAS(unittest.TestCase):
-    def test_nvsmall_sanity(self):
-        config_dict = deepcopy(NVSMALL_MINI_CONFIG)
+    def test_nemotron_nas_sanity(self):
+        config_dict = deepcopy(NEMOTRON_NAS_MINI_CONFIG)
         # 8B * sizeof(float16) plus some extra for activations
         mem_for_full_model = (2 + 1) * 8 * 2**(30)
-        reduce_nvsmall_config(mem_for_full_model, config_dict)
+        reduce_nemotron_nas_config(mem_for_full_model, config_dict)
         if config_dict["num_hidden_layers"] <= 0:
-            self.skipTest("Insufficient memory for a single NVSmall layer")
-        nvsmall_config = AutoConfig.from_pretrained(
+            self.skipTest("Insufficient memory for a single NemotronNAS layer")
+        nemotron_nas_config = AutoConfig.from_pretrained(
             "nvidia/Llama-3_1-Nemotron-51B-Instruct", trust_remote_code=True)
-        nvsmall_config = nvsmall_config.from_dict(config_dict)
+        nemotron_nas_config = nemotron_nas_config.from_dict(config_dict)
-        dtype = nvsmall_config.torch_dtype
+        dtype = nemotron_nas_config.torch_dtype
         device = torch.device('cuda')
-        model_config = ModelConfig(pretrained_config=nvsmall_config)
-        nvsmall = NVSmallForCausalLM(model_config).to(dtype).to(device)
+        model_config = ModelConfig(pretrained_config=nemotron_nas_config)
+        nemotron_nas = NemotronNASForCausalLM(model_config).to(dtype).to(device)
         input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
                                  dtype=torch.int,
@@ -272,10 +274,10 @@ class TestNVSmall(unittest.TestCase):
         kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
                                         tokens_per_block)
-        num_layers = nvsmall.config.num_hidden_layers
-        num_kv_heads = nvsmall.config.num_key_value_heads
-        num_heads = nvsmall.config.num_attention_heads
-        head_dim = nvsmall.config.hidden_size // num_heads
+        num_layers = nemotron_nas.config.num_hidden_layers
+        num_kv_heads = nemotron_nas.config.num_key_value_heads
+        num_heads = nemotron_nas.config.num_attention_heads
+        head_dim = nemotron_nas.config.hidden_size // num_heads
         max_seq_len = num_blocks * tokens_per_block
         context_sequence_lengths = [3, 2, 1]
@@ -329,18 +331,18 @@ class TestNVSmall(unittest.TestCase):
         with torch.inference_mode():
             attn_metadata.prepare()
-            logits = nvsmall.forward(input_ids=input_ids,
-                                     position_ids=position_ids,
-                                     attn_metadata=attn_metadata)
+            logits = nemotron_nas.forward(input_ids=input_ids,
+                                          position_ids=position_ids,
+                                          attn_metadata=attn_metadata)
         self.assertEqual(len(past_seen_tokens), logits.shape[0])
         with torch.inference_mode():
             attn_metadata.prepare()
-            logits = nvsmall.forward(input_ids=input_ids,
-                                     position_ids=position_ids,
-                                     attn_metadata=attn_metadata,
-                                     return_context_logits=True)
+            logits = nemotron_nas.forward(input_ids=input_ids,
+                                          position_ids=position_ids,
+                                          attn_metadata=attn_metadata,
+                                          return_context_logits=True)
         self.assertEqual(input_ids.shape, logits.shape[:-1])
         kv_cache_manager.shutdown()
@@ -352,7 +354,7 @@ class TestNVSmall(unittest.TestCase):
     ], lambda testcase_func, param_num, param:
                   f"{testcase_func.__name__}[{param.args[0]}]")
     @torch.no_grad()
-    def test_nvsmall_allclose_to_hf(self, scenario: Scenario) -> None:
+    def test_nemotron_nas_allclose_to_hf(self, scenario: Scenario) -> None:
         """
         Compare output to HF
         """
@@ -360,34 +362,34 @@ class TestNVSmall(unittest.TestCase):
         metadata_cls = get_attention_backend(backend).Metadata
         torch.random.manual_seed(0)
-        config_dict = deepcopy(NVSMALL_MINI_CONFIG)
+        config_dict = deepcopy(NEMOTRON_NAS_MINI_CONFIG)
         # 8B * sizeof(float16) plus some extra for activations
         # times 2, since we'll need 2 of these
         mem_for_full_model = (2 + 1) * 8 * 2**(30) * 4
-        reduce_nvsmall_config(mem_for_full_model, config_dict)
+        reduce_nemotron_nas_config(mem_for_full_model, config_dict)
         if config_dict["num_hidden_layers"] <= 0:
-            self.skipTest("Insufficient memory for a single NVSmall layer")
-        nvsmall_config = AutoConfig.from_pretrained(
+            self.skipTest("Insufficient memory for a single NemotronNAS layer")
+        nemotron_nas_config = AutoConfig.from_pretrained(
             "nvidia/Llama-3_1-Nemotron-51B-Instruct", trust_remote_code=True)
-        nvsmall_config = nvsmall_config.from_dict(config_dict)
-        dtype = nvsmall_config.torch_dtype
+        nemotron_nas_config = nemotron_nas_config.from_dict(config_dict)
+        dtype = nemotron_nas_config.torch_dtype
         device = torch.device('cuda')
-        hf_nvsmall = AutoModelForCausalLM.from_pretrained(
+        hf_nemotron_nas = AutoModelForCausalLM.from_pretrained(
             llm_models_root() / "nemotron-nas/Llama-3_1-Nemotron-51B-Instruct",
             trust_remote_code=True,
             device_map="meta")
-        hf_nvsmall = hf_nvsmall.__class__(nvsmall_config).to(dtype).to(
-            device).eval()
+        hf_nemotron_nas = hf_nemotron_nas.__class__(nemotron_nas_config).to(
+            dtype).to(device).eval()
         # This line populates the "variable" field in the NEED_SETUP_CACHE_CLASSES_MAPPING dict
-        hf_nvsmall._prepare_generation_config(None)
+        hf_nemotron_nas._prepare_generation_config(None)
         # And this line is the only way to access the only concrete Cache class DeciLMForCausalLM accepts
         VariableCache = NEED_SETUP_CACHE_CLASSES_MAPPING["variable"]
-        model_config = ModelConfig(pretrained_config=nvsmall_config,
+        model_config = ModelConfig(pretrained_config=nemotron_nas_config,
                                    attn_backend=backend)
-        nvsmall = NVSmallForCausalLM(model_config).to(dtype).to(device)
-        nvsmall.load_weights(hf_nvsmall.state_dict())
+        nemotron_nas = NemotronNASForCausalLM(model_config).to(dtype).to(device)
+        nemotron_nas.load_weights(hf_nemotron_nas.state_dict())
         num_blocks = 1
         tokens_per_block = 128
@@ -395,10 +397,10 @@ class TestNVSmall(unittest.TestCase):
         kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
                                         tokens_per_block)
-        num_layers = nvsmall.config.num_hidden_layers
-        num_kv_heads = nvsmall.config.num_key_value_heads
-        num_heads = nvsmall.config.num_attention_heads
-        head_dim = nvsmall.config.hidden_size // num_heads
+        num_layers = nemotron_nas.config.num_hidden_layers
+        num_kv_heads = nemotron_nas.config.num_key_value_heads
+        num_heads = nemotron_nas.config.num_attention_heads
+        head_dim = nemotron_nas.config.hidden_size // num_heads
         max_seq_len = num_blocks * tokens_per_block
         batch_size = 1
@@ -450,19 +452,19 @@ class TestNVSmall(unittest.TestCase):
         position_ids = [torch.arange(0, input_ids.size(-1))]
         position_ids = torch.cat(position_ids).unsqueeze(0).cuda()
-        # And, lastly, this is the simplest way of creating a Cache that `hf_nvsmall` will accept
-        past_key_values = VariableCache(config=nvsmall_config,
+        # And, lastly, this is the simplest way of creating a Cache that `hf_nemotron_nas` will accept
+        past_key_values = VariableCache(config=nemotron_nas_config,
                                         dtype=dtype,
                                         batch_size=1)
         with torch.inference_mode():
             attn_metadata.prepare()
-            logits = nvsmall.forward(input_ids=input_ids,
-                                     position_ids=position_ids,
-                                     attn_metadata=attn_metadata)
-            ref = hf_nvsmall.forward(input_ids=input_ids.unsqueeze(0),
-                                     position_ids=position_ids,
-                                     past_key_values=past_key_values,
-                                     use_cache=True)
+            logits = nemotron_nas.forward(input_ids=input_ids,
+                                          position_ids=position_ids,
+                                          attn_metadata=attn_metadata)
+            ref = hf_nemotron_nas.forward(input_ids=input_ids.unsqueeze(0),
+                                          position_ids=position_ids,
+                                          past_key_values=past_key_values,
+                                          use_cache=True)
         torch.testing.assert_close(logits,
                                    ref.logits[:, -1].float(),
@@ -495,13 +497,13 @@ class TestNVSmall(unittest.TestCase):
         gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
         with torch.inference_mode():
             attn_metadata.prepare()
-            logits = nvsmall.forward(input_ids=gen_input_ids,
-                                     position_ids=gen_position_ids,
-                                     attn_metadata=attn_metadata)
-            ref = hf_nvsmall.forward(input_ids=gen_input_ids.unsqueeze(0),
-                                     position_ids=gen_position_ids,
-                                     past_key_values=ref.past_key_values,
-                                     use_cache=True)
+            logits = nemotron_nas.forward(input_ids=gen_input_ids,
+                                          position_ids=gen_position_ids,
+                                          attn_metadata=attn_metadata)
+            ref = hf_nemotron_nas.forward(input_ids=gen_input_ids.unsqueeze(0),
+                                          position_ids=gen_position_ids,
+                                          past_key_values=ref.past_key_values,
+                                          use_cache=True)
         torch.testing.assert_close(logits,
                                    ref.logits[:, -1].float(),