Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
chore: Rename nvsmall to nemotron nas (#3447)
* Rename nvsmall to nemotron NAS
* Revert nvsmall to nemotron_nas rename in paths in tests that access llm_models_root/nvsmall/tests
* Add NemotronNAS to pytorch supported models table

Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com>
parent af05749e90
commit a6a2ae6cc1
@@ -52,6 +52,7 @@ python3 quickstart_multimodal.py --model_dir Efficient-Large-Model/NVILA-8B --mo
 | `MixtralForCausalLM` | Mixtral | `mistralai/Mixtral-8x7B-v0.1` | L |
 | `MllamaForConditionalGeneration` | Llama 3.2 | `meta-llama/Llama-3.2-11B-Vision` | L |
 | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base` | L |
+| `NemotronNASForCausalLM` | NemotronNAS | `nvidia/Llama-3_3-Nemotron-Super-49B-v1` | L |
 | `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/Qwen2-7B-Instruct` | L |
 | `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B` | L |
 | `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B` | L |
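With `NemotronNASForCausalLM` now listed in the PyTorch supported-models table, the checkpoint can be exercised through the high-level LLM API. The snippet below is a minimal sketch, not part of this commit; it assumes the standard `tensorrt_llm` LLM entry point, that `nvidia/Llama-3_3-Nemotron-Super-49B-v1` (or a local copy) is reachable, and that a 49B model is sharded across several GPUs.

```python
# Minimal sketch (not part of this commit): run the newly listed NemotronNAS
# checkpoint through the high-level LLM API. The model id comes from the
# supported-models table; the tensor-parallel size is an assumption.
from tensorrt_llm import LLM, SamplingParams


def main():
    llm = LLM(model="nvidia/Llama-3_3-Nemotron-Super-49B-v1",
              tensor_parallel_size=4)  # assumption: adjust for your GPUs
    params = SamplingParams(max_tokens=32)
    for out in llm.generate(["The capital of France is"], params):
        print(out.outputs[0].text)


if __name__ == "__main__":
    main()
```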
@@ -8,7 +8,7 @@ from .modeling_llava_next import LlavaNextModel
 from .modeling_mamba_hybrid import MambaHybridForCausalLM
 from .modeling_mixtral import MixtralForCausalLM
 from .modeling_nemotron import NemotronForCausalLM
-from .modeling_nvsmall import NVSmallForCausalLM
+from .modeling_nemotron_nas import NemotronNASForCausalLM
 from .modeling_qwen import (Qwen2ForCausalLM, Qwen2ForProcessRewardModel,
                             Qwen2ForRewardModel)
 from .modeling_qwen2vl import Qwen2_5_VLModel, Qwen2VLModel
@@ -26,7 +26,7 @@ __all__ = [
     "MambaHybridForCausalLM",
     "MixtralForCausalLM",
     "NemotronForCausalLM",
-    "NVSmallForCausalLM",
+    "NemotronNASForCausalLM",
     "Qwen2ForCausalLM",
     "Qwen2ForProcessRewardModel",
     "Qwen2ForRewardModel",
@@ -21,7 +21,7 @@ from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
                              register_auto_model)


-class NVSmallRotaryEmbedding(RotaryEmbedding):
+class NemotronNASRotaryEmbedding(RotaryEmbedding):

     def __init__(self, config: PretrainedConfig, layer_idx: int):
         if config.rope_scaling is not None:
@@ -69,7 +69,7 @@ def _create_linear_from_configs(model_config: ModelConfig[PretrainedConfig],
     )


-class NVSmallAttention(Attention):
+class NemotronNASAttention(Attention):

     def __init__(self, model_config: ModelConfig[PretrainedConfig],
                  layer_idx: int):
@@ -88,7 +88,7 @@ class NVSmallAttention(Attention):
             max_position_embeddings=config.max_position_embeddings,
             bias=False,
             pos_embd_params=pos_embd_params,
-            rotary_emb=NVSmallRotaryEmbedding(config, layer_idx=layer_idx),
+            rotary_emb=NemotronNASRotaryEmbedding(config, layer_idx=layer_idx),
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             config=model_config)
@@ -120,7 +120,7 @@ class LinearMLP(nn.Module):
         return self.linear_mlp(hidden_states)


-class NVSmallDecoderLayer(DecoderLayer):
+class NemotronNASDecoderLayer(DecoderLayer):

     def __init__(self, model_config: ModelConfig[PretrainedConfig],
                  block_config: Dict[str, Any], layer_idx: int):
@@ -134,8 +134,8 @@ class NVSmallDecoderLayer(DecoderLayer):
         if self.block_config.attention.replace_with_linear:
             self.self_attn = LinearAttention(model_config, config)
         else:
-            self.self_attn = NVSmallAttention(model_config=model_config,
-                                              layer_idx=layer_idx)
+            self.self_attn = NemotronNASAttention(model_config=model_config,
+                                                  layer_idx=layer_idx)
         if not self.block_config.ffn.no_op:
             self.post_attention_layernorm = RMSNorm(
                 hidden_size=config.hidden_size,
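The hunk above shows the NAS-specific branching: each layer's `block_config` decides whether attention is kept, replaced by a linear projection, or whether the FFN is a no-op. The following is an illustrative sketch of that per-layer dispatch pattern using hypothetical simplified stand-ins, not the actual TensorRT-LLM modules.

```python
# Illustrative sketch only: per-layer heterogeneity driven by a block config,
# mirroring the replace_with_linear / no_op branching shown in the diff.
# All classes here are hypothetical stand-ins, not TensorRT-LLM modules.
from dataclasses import dataclass, field

import torch.nn as nn


@dataclass
class AttentionConfig:
    replace_with_linear: bool = False


@dataclass
class FFNConfig:
    no_op: bool = False


@dataclass
class BlockConfig:
    attention: AttentionConfig = field(default_factory=AttentionConfig)
    ffn: FFNConfig = field(default_factory=FFNConfig)


class ToyDecoderLayer(nn.Module):

    def __init__(self, hidden_size: int, block_config: BlockConfig):
        super().__init__()
        # NAS decision 1: keep full attention or swap in a cheap linear layer.
        if block_config.attention.replace_with_linear:
            self.self_attn = nn.Linear(hidden_size, hidden_size, bias=False)
        else:
            self.self_attn = nn.MultiheadAttention(hidden_size, num_heads=8,
                                                   batch_first=True)
        # NAS decision 2: the FFN of a block may be skipped entirely.
        self.mlp = None if block_config.ffn.no_op else nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.SiLU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, x):
        if isinstance(self.self_attn, nn.Linear):
            x = x + self.self_attn(x)
        else:
            attn_out, _ = self.self_attn(x, x, x, need_weights=False)
            x = x + attn_out
        if self.mlp is not None:
            x = x + self.mlp(x)
        return x
```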
@@ -182,7 +182,7 @@ class NVSmallDecoderLayer(DecoderLayer):
         return hidden_states


-class NVSmallModel(DecoderModel):
+class NemotronNASModel(DecoderModel):

     def __init__(self, model_config):
         super().__init__(model_config)
@@ -200,7 +200,7 @@ class NVSmallModel(DecoderModel):
             dtype=config.torch_dtype,
         )
         self.layers = nn.ModuleList([
-            NVSmallDecoderLayer(model_config, block_config, layer_idx)
+            NemotronNASDecoderLayer(model_config, block_config, layer_idx)
             for layer_idx, block_config in enumerate(config.block_configs)
         ])
         self.norm = RMSNorm(hidden_size=config.hidden_size,
@@ -209,11 +209,11 @@ class NVSmallModel(DecoderModel):


 @register_auto_model("DeciLMForCausalLM")
-class NVSmallForCausalLM(DecoderModelForCausalLM[NVSmallModel,
-                                                 PretrainedConfig]):
+class NemotronNASForCausalLM(DecoderModelForCausalLM[NemotronNASModel,
+                                                     PretrainedConfig]):

     def __init__(self, model_config: ModelConfig[PretrainedConfig]):
-        super().__init__(NVSmallModel(model_config),
+        super().__init__(NemotronNASModel(model_config),
                          config=model_config,
                          hidden_size=model_config.pretrained_config.hidden_size,
                          vocab_size=model_config.pretrained_config.vocab_size)
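Note that the decorator argument stays `"DeciLMForCausalLM"`: checkpoints keep advertising that architecture string in their Hugging Face config, so only the implementing Python class is renamed and lookup is unaffected. As a rough illustration of how a decorator-based registry of this kind is commonly implemented (a generic sketch, not the actual `register_auto_model` code):

```python
# Generic sketch of a decorator-based model registry; hypothetical, not the
# actual TensorRT-LLM register_auto_model implementation.
from typing import Callable, Dict, Type

_MODEL_REGISTRY: Dict[str, Type] = {}


def register_auto_model(architecture: str) -> Callable[[Type], Type]:
    """Map an HF `architectures` entry (e.g. "DeciLMForCausalLM") to a class."""

    def decorator(cls: Type) -> Type:
        _MODEL_REGISTRY[architecture] = cls
        return cls

    return decorator


def get_model_class(hf_config) -> Type:
    # The checkpoint still reports "DeciLMForCausalLM", so renaming the class
    # (NVSmallForCausalLM -> NemotronNASForCausalLM) does not change lookup.
    return _MODEL_REGISTRY[hf_config.architectures[0]]
```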
@@ -202,7 +202,7 @@
     "examples/test_multimodal.py::test_llm_multimodal_general[Qwen2-VL-7B-Instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:4]": 355.9665117710829,
     "examples/test_qwen.py::test_llm_qwen_single_gpu_summary[qwen2_7b_instruct-enable_paged_kv_cache-enable_remove_input_padding-enable_weight_only-enable_fmha]": 334.18761303275824,
     "examples/test_qwen2audio.py::test_llm_qwen2audio_single_gpu[qwen2_audio_7b_instruct]": 373.8418433815241,
-    "test_unittests.py::test_unittests_v2[unittest/_torch/modeling -k \"modeling_nvsmall\"]": 587.2612264081836,
+    "test_unittests.py::test_unittests_v2[unittest/_torch/modeling -k \"modeling_nemotron_nas\"]": 587.2612264081836,
     "test_unittests.py::test_unittests_v2[unittest/trt/model/test_gpt.py -k \"partition2\"]": 1086.5458996072412,
     "examples/test_draft_target_model.py::test_llm_draft_target_model_1gpu[streaming-gpt2-use_cpp_session-use_tokens-draft_len_4-float16-bs1]": 216.4487509690225,
     "examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-byt5-small-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1-disable_fp8]": 213.7413339074701,
@@ -18,7 +18,7 @@ l0_a30:
   - unittest/_torch -k "modeling_llama"
   - unittest/_torch/modeling -k "modeling_mixtral"
   - unittest/_torch/modeling -k "modeling_mllama"
-  - unittest/_torch/modeling -k "modeling_nvsmall"
+  - unittest/_torch/modeling -k "modeling_nemotron_nas"
   - unittest/_torch/modeling -k "modeling_out_of_tree"
   - unittest/_torch/modeling -k "modeling_qwen"
   - unittest/_torch/modeling -k "modeling_qwen_moe"
@@ -16,7 +16,7 @@ l0_l40s:
   # ------------- PyTorch tests ---------------
   - unittest/_torch -k "not (modeling or multi_gpu or auto_deploy)"
   - unittest/_torch/modeling -k "modeling_mllama"
-  - unittest/_torch/modeling -k "modeling_nvsmall"
+  - unittest/_torch/modeling -k "modeling_nemotron_nas"
   - unittest/_torch/modeling -k "modeling_out_of_tree"
   - unittest/_torch/modeling -k "modeling_qwen"
   - unittest/_torch/modeling -k "modeling_vila"
@@ -96,7 +96,7 @@ from utils.llm_data import llm_models_root
                 pytest.mark.skip(reason="https://nvbugspro.nvidia.com/bug/5178508"),
             ],
         ),
-        # full NVSmall (Llama-3.1-Nemotron-51B) with torch-opt backend + simple runtime
+        # full NemotronNAS (Llama-3.1-Nemotron-51B) with torch-opt backend + simple runtime
         param_with_device_count(
             4,
             {
@@ -13,12 +13,13 @@ import tensorrt_llm
 from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
 from tensorrt_llm._torch.metadata import KVCacheParams
 from tensorrt_llm._torch.model_config import ModelConfig
-from tensorrt_llm._torch.models.modeling_nvsmall import NVSmallForCausalLM
+from tensorrt_llm._torch.models.modeling_nemotron_nas import \
+    NemotronNASForCausalLM
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
 from tensorrt_llm.bindings.executor import KvCacheConfig
 from tensorrt_llm.mapping import Mapping

-NVSMALL_MINI_CONFIG = {
+NEMOTRON_NAS_MINI_CONFIG = {
     "architectures": ["DeciLMForCausalLM"],
     "attention_bias":
     False,
@@ -224,7 +225,8 @@ class Scenario:
         return f"backend:{self.backend.lower()}"


-def reduce_nvsmall_config(mem_for_full_model: int, config_dict: dict[str, Any]):
+def reduce_nemotron_nas_config(mem_for_full_model: int, config_dict: dict[str,
+                                                                          Any]):
     _, total_mem = torch.cuda.mem_get_info()
     # scale model down if gpu memory is low
     if total_mem < mem_for_full_model:
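The renamed `reduce_nemotron_nas_config` helper shrinks the mini config when the GPU cannot hold the full model; only part of its body is visible in these hunks (the memory check above and the `block_configs` truncation below). The following is a hedged reconstruction of the idea; the exact scaling arithmetic in the repository may differ.

```python
# Hedged reconstruction (not repository code): if the GPU cannot fit the full
# model, keep only a prefix of the layers. NemotronNAS carries one block
# config per layer, so the block_configs list is truncated alongside
# num_hidden_layers, as the visible context line in the test shows.
from typing import Any, Dict

import torch


def reduce_config_for_gpu(mem_for_full_model: int,
                          config_dict: Dict[str, Any]) -> None:
    _, total_mem = torch.cuda.mem_get_info()
    if total_mem < mem_for_full_model:
        # Assumption: scale the layer count by the available-memory fraction.
        frac = total_mem / mem_for_full_model
        num_layers = int(config_dict["num_hidden_layers"] * frac)
        config_dict["num_hidden_layers"] = num_layers
        config_dict["block_configs"] = config_dict["block_configs"][:num_layers]
```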
@@ -235,24 +237,24 @@ def reduce_nvsmall_config(mem_for_full_model: int, config_dict: dict[str, Any]):
         config_dict["block_configs"] = config_dict["block_configs"][:num_layers]


-class TestNVSmall(unittest.TestCase):
+class TestNemotronNAS(unittest.TestCase):

-    def test_nvsmall_sanity(self):
-        config_dict = deepcopy(NVSMALL_MINI_CONFIG)
+    def test_nemotron_nas_sanity(self):
+        config_dict = deepcopy(NEMOTRON_NAS_MINI_CONFIG)
         # 8B * sizeof(float16) plus some extra for activations
         mem_for_full_model = (2 + 1) * 8 * 2**(30)
-        reduce_nvsmall_config(mem_for_full_model, config_dict)
+        reduce_nemotron_nas_config(mem_for_full_model, config_dict)
         if config_dict["num_hidden_layers"] <= 0:
-            self.skipTest("Insufficient memory for a single NVSmall layer")
-        nvsmall_config = AutoConfig.from_pretrained(
+            self.skipTest("Insufficient memory for a single NemotronNAS layer")
+        nemotron_nas_config = AutoConfig.from_pretrained(
             "nvidia/Llama-3_1-Nemotron-51B-Instruct", trust_remote_code=True)
-        nvsmall_config = nvsmall_config.from_dict(config_dict)
+        nemotron_nas_config = nemotron_nas_config.from_dict(config_dict)

-        dtype = nvsmall_config.torch_dtype
+        dtype = nemotron_nas_config.torch_dtype
         device = torch.device('cuda')

-        model_config = ModelConfig(pretrained_config=nvsmall_config)
-        nvsmall = NVSmallForCausalLM(model_config).to(dtype).to(device)
+        model_config = ModelConfig(pretrained_config=nemotron_nas_config)
+        nemotron_nas = NemotronNASForCausalLM(model_config).to(dtype).to(device)

         input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
                                  dtype=torch.int,
@@ -272,10 +274,10 @@ class TestNVSmall(unittest.TestCase):
         kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
                                         tokens_per_block)

-        num_layers = nvsmall.config.num_hidden_layers
-        num_kv_heads = nvsmall.config.num_key_value_heads
-        num_heads = nvsmall.config.num_attention_heads
-        head_dim = nvsmall.config.hidden_size // num_heads
+        num_layers = nemotron_nas.config.num_hidden_layers
+        num_kv_heads = nemotron_nas.config.num_key_value_heads
+        num_heads = nemotron_nas.config.num_attention_heads
+        head_dim = nemotron_nas.config.hidden_size // num_heads
         max_seq_len = num_blocks * tokens_per_block

         context_sequence_lengths = [3, 2, 1]
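The test sizes its KV cache from `num_blocks * tokens_per_block` and the head geometry pulled from the model config in the hunk above. As a back-of-the-envelope check (generic arithmetic, not repository code, and ignoring that NemotronNAS layers can have differing KV-head counts), the cache footprint is roughly `2 * num_layers * num_kv_heads * head_dim * max_tokens * bytes_per_element`:

```python
# Back-of-the-envelope KV-cache sizing using the same quantities the test reads
# from the config; illustrative arithmetic only, not repository code.
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   max_tokens: int, bytes_per_elem: int = 2) -> int:
    # Factor 2 accounts for storing both keys and values.
    return 2 * num_layers * num_kv_heads * head_dim * max_tokens * bytes_per_elem


# Example: 80 layers, 8 KV heads, head_dim 128, one 128-token block, fp16.
print(kv_cache_bytes(80, 8, 128, 128) / 2**20, "MiB")  # ~40 MiB
```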
@@ -329,18 +331,18 @@ class TestNVSmall(unittest.TestCase):

         with torch.inference_mode():
             attn_metadata.prepare()
-            logits = nvsmall.forward(input_ids=input_ids,
-                                     position_ids=position_ids,
-                                     attn_metadata=attn_metadata)
+            logits = nemotron_nas.forward(input_ids=input_ids,
+                                          position_ids=position_ids,
+                                          attn_metadata=attn_metadata)

         self.assertEqual(len(past_seen_tokens), logits.shape[0])

         with torch.inference_mode():
             attn_metadata.prepare()
-            logits = nvsmall.forward(input_ids=input_ids,
-                                     position_ids=position_ids,
-                                     attn_metadata=attn_metadata,
-                                     return_context_logits=True)
+            logits = nemotron_nas.forward(input_ids=input_ids,
+                                          position_ids=position_ids,
+                                          attn_metadata=attn_metadata,
+                                          return_context_logits=True)
         self.assertEqual(input_ids.shape, logits.shape[:-1])

         kv_cache_manager.shutdown()
@@ -352,7 +354,7 @@ class TestNVSmall(unittest.TestCase):
     ], lambda testcase_func, param_num, param:
         f"{testcase_func.__name__}[{param.args[0]}]")
     @torch.no_grad()
-    def test_nvsmall_allclose_to_hf(self, scenario: Scenario) -> None:
+    def test_nemotron_nas_allclose_to_hf(self, scenario: Scenario) -> None:
         """
         Compare output to HF
         """
@@ -360,34 +362,34 @@ class TestNVSmall(unittest.TestCase):
         metadata_cls = get_attention_backend(backend).Metadata

         torch.random.manual_seed(0)
-        config_dict = deepcopy(NVSMALL_MINI_CONFIG)
+        config_dict = deepcopy(NEMOTRON_NAS_MINI_CONFIG)
         # 8B * sizeof(float16) plus some extra for activations
         # times 2, since we'll need 2 of these
         mem_for_full_model = (2 + 1) * 8 * 2**(30) * 4
-        reduce_nvsmall_config(mem_for_full_model, config_dict)
+        reduce_nemotron_nas_config(mem_for_full_model, config_dict)
         if config_dict["num_hidden_layers"] <= 0:
-            self.skipTest("Insufficient memory for a single NVSmall layer")
-        nvsmall_config = AutoConfig.from_pretrained(
+            self.skipTest("Insufficient memory for a single NemotronNAS layer")
+        nemotron_nas_config = AutoConfig.from_pretrained(
             "nvidia/Llama-3_1-Nemotron-51B-Instruct", trust_remote_code=True)
-        nvsmall_config = nvsmall_config.from_dict(config_dict)
-        dtype = nvsmall_config.torch_dtype
+        nemotron_nas_config = nemotron_nas_config.from_dict(config_dict)
+        dtype = nemotron_nas_config.torch_dtype
         device = torch.device('cuda')

-        hf_nvsmall = AutoModelForCausalLM.from_pretrained(
+        hf_nemotron_nas = AutoModelForCausalLM.from_pretrained(
             llm_models_root() / "nemotron-nas/Llama-3_1-Nemotron-51B-Instruct",
             trust_remote_code=True,
             device_map="meta")
-        hf_nvsmall = hf_nvsmall.__class__(nvsmall_config).to(dtype).to(
-            device).eval()
+        hf_nemotron_nas = hf_nemotron_nas.__class__(nemotron_nas_config).to(
+            dtype).to(device).eval()
         # This line populates the "variable" field in the NEED_SETUP_CACHE_CLASSES_MAPPING dict
-        hf_nvsmall._prepare_generation_config(None)
+        hf_nemotron_nas._prepare_generation_config(None)
         # And this line is the only way to access the only concrete Cache class DeciLMForCausalLM accepts
         VariableCache = NEED_SETUP_CACHE_CLASSES_MAPPING["variable"]

-        model_config = ModelConfig(pretrained_config=nvsmall_config,
+        model_config = ModelConfig(pretrained_config=nemotron_nas_config,
                                    attn_backend=backend)
-        nvsmall = NVSmallForCausalLM(model_config).to(dtype).to(device)
-        nvsmall.load_weights(hf_nvsmall.state_dict())
+        nemotron_nas = NemotronNASForCausalLM(model_config).to(dtype).to(device)
+        nemotron_nas.load_weights(hf_nemotron_nas.state_dict())

         num_blocks = 1
         tokens_per_block = 128
@@ -395,10 +397,10 @@ class TestNVSmall(unittest.TestCase):
         kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
                                         tokens_per_block)

-        num_layers = nvsmall.config.num_hidden_layers
-        num_kv_heads = nvsmall.config.num_key_value_heads
-        num_heads = nvsmall.config.num_attention_heads
-        head_dim = nvsmall.config.hidden_size // num_heads
+        num_layers = nemotron_nas.config.num_hidden_layers
+        num_kv_heads = nemotron_nas.config.num_key_value_heads
+        num_heads = nemotron_nas.config.num_attention_heads
+        head_dim = nemotron_nas.config.hidden_size // num_heads
         max_seq_len = num_blocks * tokens_per_block
         batch_size = 1

@@ -450,19 +452,19 @@ class TestNVSmall(unittest.TestCase):

         position_ids = [torch.arange(0, input_ids.size(-1))]
         position_ids = torch.cat(position_ids).unsqueeze(0).cuda()
-        # And, lastly, this is the simplest way of creating a Cache that `hf_nvsmall` will accept
-        past_key_values = VariableCache(config=nvsmall_config,
+        # And, lastly, this is the simplest way of creating a Cache that `hf_nemotron_nas` will accept
+        past_key_values = VariableCache(config=nemotron_nas_config,
                                         dtype=dtype,
                                         batch_size=1)
         with torch.inference_mode():
             attn_metadata.prepare()
-            logits = nvsmall.forward(input_ids=input_ids,
-                                     position_ids=position_ids,
-                                     attn_metadata=attn_metadata)
-            ref = hf_nvsmall.forward(input_ids=input_ids.unsqueeze(0),
-                                     position_ids=position_ids,
-                                     past_key_values=past_key_values,
-                                     use_cache=True)
+            logits = nemotron_nas.forward(input_ids=input_ids,
+                                          position_ids=position_ids,
+                                          attn_metadata=attn_metadata)
+            ref = hf_nemotron_nas.forward(input_ids=input_ids.unsqueeze(0),
+                                          position_ids=position_ids,
+                                          past_key_values=past_key_values,
+                                          use_cache=True)

         torch.testing.assert_close(logits,
                                    ref.logits[:, -1].float(),
@@ -495,13 +497,13 @@ class TestNVSmall(unittest.TestCase):
         gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
         with torch.inference_mode():
             attn_metadata.prepare()
-            logits = nvsmall.forward(input_ids=gen_input_ids,
-                                     position_ids=gen_position_ids,
-                                     attn_metadata=attn_metadata)
-            ref = hf_nvsmall.forward(input_ids=gen_input_ids.unsqueeze(0),
-                                     position_ids=gen_position_ids,
-                                     past_key_values=ref.past_key_values,
-                                     use_cache=True)
+            logits = nemotron_nas.forward(input_ids=gen_input_ids,
+                                          position_ids=gen_position_ids,
+                                          attn_metadata=attn_metadata)
+            ref = hf_nemotron_nas.forward(input_ids=gen_input_ids.unsqueeze(0),
+                                          position_ids=gen_position_ids,
+                                          past_key_values=ref.past_key_values,
+                                          use_cache=True)

         torch.testing.assert_close(logits,
                                    ref.logits[:, -1].float(),