[feat] get params

jingyaogong 2025-12-31 20:44:34 +08:00
parent eead9538b2
commit 288a1d7212
2 changed files with 23 additions and 7 deletions


@@ -6,7 +6,7 @@ import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from model.model_lora import *
from trainer.trainer_utils import setup_seed
from trainer.trainer_utils import setup_seed, get_model_params
warnings.filterwarnings('ignore')
def init_model(args):
@@ -26,7 +26,7 @@ def init_model(args):
            load_lora(model, f'./{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth')
    else:
        model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
    print(f'MiniMind model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M(illion)')
    get_model_params(model, model.config)
    return model.eval().to(args.device), tokenizer

def main():

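The inline parameter print is replaced by a shared helper; a minimal usage sketch of the new reporting path (not part of the commit, config values are illustrative):

# Hypothetical usage sketch: build a small MiniMind model and let the shared
# helper report its size. Config values are illustrative, not from this commit.
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from trainer.trainer_utils import get_model_params

model = MiniMindForCausalLM(MiniMindConfig(hidden_size=512, num_hidden_layers=8))
get_model_params(model, model.config)  # logs 'Model Params: <total>M', plus '-A<active>M' for MoE configs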

@@ -1,8 +1,10 @@
"""
Collection of training utility functions
"""
import gc
import os
import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import random
import math
import numpy as np
@@ -10,7 +12,21 @@ import torch
import torch.distributed as dist
from torch.utils.data import Sampler
from transformers import AutoTokenizer
from model.model_minimind import MiniMindForCausalLM
from model.model_minimind_qwen3 import MiniMindForCausalLM

def get_model_params(model, config):
    total = sum(p.numel() for p in model.parameters()) / 1e6
    if getattr(config, 'use_moe', False):
        n_routed = getattr(config, 'n_routed_experts', getattr(config, 'num_experts', 0))
        n_active = getattr(config, 'num_experts_per_tok', 0)
        n_shared = getattr(config, 'n_shared_experts', 0)
        # parameters of a single routed / shared expert, in millions
        expert = sum(p.numel() for n, p in model.named_parameters() if 'mlp.experts.0.' in n) / 1e6
        shared_expert = sum(p.numel() for n, p in model.named_parameters() if 'mlp.shared_experts.0.' in n) / 1e6
        # non-expert parameters, then parameters actually used per token
        base = total - (expert * n_routed) - (shared_expert * n_shared)
        active = base + (expert * n_active) + (shared_expert * n_shared)
        Logger(f'Model Params: {total:.2f}M-A{active:.2f}M')
    else:
        Logger(f'Model Params: {total:.2f}M')
def is_main_process():
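A toy walk-through of the MoE branch above, with illustrative numbers (not measured from any real checkpoint):

# Illustrative numbers only: 8 routed experts of 4M params each, 2 active per
# token, plus 1 shared expert of 4M params, inside a 145M-parameter model.
total, expert, shared_expert = 145.0, 4.0, 4.0                 # in millions
n_routed, n_active, n_shared = 8, 2, 1

base = total - expert * n_routed - shared_expert * n_shared    # 145 - 32 - 4 = 109
active = base + expert * n_active + shared_expert * n_shared   # 109 + 8 + 4 = 121
print(f'Model Params: {total:.2f}M-A{active:.2f}M')            # Model Params: 145.00M-A121.00M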
@@ -111,7 +127,8 @@ def init_model(lm_config, from_weight='pretrain', tokenizer_path='../model', sav
        weights = torch.load(weight_path, map_location=device)
        model.load_state_dict(weights, strict=False)
    Logger(f'Loaded model trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
    get_model_params(model, lm_config)
    Logger(f'Trainable Params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f}M')
    return model.to(device), tokenizer
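The two new log lines separate the total (and active) size from the trainable size; a small illustration of why the numbers can differ (the toy model and freezing below are hypothetical, not from this repo):

import torch.nn as nn

# Hypothetical toy model: freeze the first layer, as LoRA or partial fine-tuning would.
m = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 512))
for p in m[0].parameters():
    p.requires_grad = False

total = sum(p.numel() for p in m.parameters()) / 1e6
trainable = sum(p.numel() for p in m.parameters() if p.requires_grad) / 1e6
print(f'Model Params: {total:.2f}M')           # 0.53M
print(f'Trainable Params: {trainable:.3f}M')   # 0.263M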
@@ -138,5 +155,4 @@ class SkipBatchSampler(Sampler):
    def __len__(self):
        total_batches = (len(self.sampler) + self.batch_size - 1) // self.batch_size
        return max(0, total_batches - self.skip_batches)
        return max(0, total_batches - self.skip_batches)
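For reference, the __len__ arithmetic with illustrative numbers:

# Illustrative: 10 samples, batch_size 4, 1 batch already consumed before resume.
n_samples, batch_size, skip_batches = 10, 4, 1
total_batches = (n_samples + batch_size - 1) // batch_size   # ceil(10 / 4) = 3
print(max(0, total_batches - skip_batches))                  # 2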