[feat] get params

This commit is contained in:
jingyaogong
2025-12-31 20:44:34 +08:00
parent eead9538b2
commit 288a1d7212
2 changed files with 23 additions and 7 deletions
+2 -2
View File
@@ -6,7 +6,7 @@ import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from model.model_lora import *
from trainer.trainer_utils import setup_seed
from trainer.trainer_utils import setup_seed, get_model_params
warnings.filterwarnings('ignore')
def init_model(args):
@@ -26,7 +26,7 @@ def init_model(args):
load_lora(model, f'./{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth')
else:
model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
print(f'MiniMind模型参数: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M(illion)')
get_model_params(model, model.config)
return model.eval().to(args.device), tokenizer
def main():