mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-01-13 19:57:20 +08:00
[update] params log
This commit is contained in:
parent
f55d4c32a0
commit
df89069362
@ -16,7 +16,6 @@ from model.model_minimind import MiniMindForCausalLM
|
||||
|
||||
def get_model_params(model, config):
|
||||
total = sum(p.numel() for p in model.parameters()) / 1e6
|
||||
if getattr(config, 'use_moe', False):
|
||||
n_routed = getattr(config, 'n_routed_experts', getattr(config, 'num_experts', 0))
|
||||
n_active = getattr(config, 'num_experts_per_tok', 0)
|
||||
n_shared = getattr(config, 'n_shared_experts', 0)
|
||||
@ -24,9 +23,8 @@ def get_model_params(model, config):
|
||||
shared_expert = sum(p.numel() for n, p in model.named_parameters() if 'mlp.shared_experts.0.' in n) / 1e6
|
||||
base = total - (expert * n_routed) - (shared_expert * n_shared)
|
||||
active = base + (expert * n_active) + (shared_expert * n_shared)
|
||||
Logger(f'Model Params: {total:.2f}M-A{active:.2f}M')
|
||||
else:
|
||||
Logger(f'Model Params: {total:.2f}M')
|
||||
if active < total: Logger(f'Model Params: {total:.2f}M-A{active:.2f}M')
|
||||
else: Logger(f'Model Params: {total:.2f}M')
|
||||
|
||||
|
||||
def is_main_process():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user