[feat] get params

This commit is contained in:
jingyaogong 2025-12-31 20:46:59 +08:00
parent 288a1d7212
commit 9236260a4a

View File

@ -12,7 +12,7 @@ import torch
import torch.distributed as dist
from torch.utils.data import Sampler
from transformers import AutoTokenizer
from model.model_minimind_qwen3 import MiniMindForCausalLM
from model.model_minimind import MiniMindForCausalLM
def get_model_params(model, config):
total = sum(p.numel() for p in model.parameters()) / 1e6