[update] params log

This commit is contained in:
jingyaogong 2026-01-07 23:08:45 +08:00
parent f55d4c32a0
commit df89069362

View File

@ -16,7 +16,6 @@ from model.model_minimind import MiniMindForCausalLM
def get_model_params(model, config):
total = sum(p.numel() for p in model.parameters()) / 1e6
if getattr(config, 'use_moe', False):
n_routed = getattr(config, 'n_routed_experts', getattr(config, 'num_experts', 0))
n_active = getattr(config, 'num_experts_per_tok', 0)
n_shared = getattr(config, 'n_shared_experts', 0)
@ -24,9 +23,8 @@ def get_model_params(model, config):
shared_expert = sum(p.numel() for n, p in model.named_parameters() if 'mlp.shared_experts.0.' in n) / 1e6
base = total - (expert * n_routed) - (shared_expert * n_shared)
active = base + (expert * n_active) + (shared_expert * n_shared)
Logger(f'Model Params: {total:.2f}M-A{active:.2f}M')
else:
Logger(f'Model Params: {total:.2f}M')
if active < total: Logger(f'Model Params: {total:.2f}M-A{active:.2f}M')
else: Logger(f'Model Params: {total:.2f}M')
def is_main_process():