mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-06-06 00:04:50 +00:00
[feat] get params
This commit is contained in:
+2
-2
@@ -6,7 +6,7 @@ import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
|
||||
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||
from model.model_lora import *
|
||||
from trainer.trainer_utils import setup_seed
|
||||
from trainer.trainer_utils import setup_seed, get_model_params
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
def init_model(args):
|
||||
@@ -26,7 +26,7 @@ def init_model(args):
|
||||
load_lora(model, f'./{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth')
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
|
||||
print(f'MiniMind模型参数: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M(illion)')
|
||||
get_model_params(model, model.config)
|
||||
return model.eval().to(args.device), tokenizer
|
||||
|
||||
def main():
|
||||
|
||||
Reference in New Issue
Block a user