mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-01-13 19:57:20 +08:00
[update] params log
This commit is contained in:
parent
f55d4c32a0
commit
df89069362
@ -16,7 +16,6 @@ from model.model_minimind import MiniMindForCausalLM
|
|||||||
|
|
||||||
def get_model_params(model, config):
|
def get_model_params(model, config):
|
||||||
total = sum(p.numel() for p in model.parameters()) / 1e6
|
total = sum(p.numel() for p in model.parameters()) / 1e6
|
||||||
if getattr(config, 'use_moe', False):
|
|
||||||
n_routed = getattr(config, 'n_routed_experts', getattr(config, 'num_experts', 0))
|
n_routed = getattr(config, 'n_routed_experts', getattr(config, 'num_experts', 0))
|
||||||
n_active = getattr(config, 'num_experts_per_tok', 0)
|
n_active = getattr(config, 'num_experts_per_tok', 0)
|
||||||
n_shared = getattr(config, 'n_shared_experts', 0)
|
n_shared = getattr(config, 'n_shared_experts', 0)
|
||||||
@ -24,9 +23,8 @@ def get_model_params(model, config):
|
|||||||
shared_expert = sum(p.numel() for n, p in model.named_parameters() if 'mlp.shared_experts.0.' in n) / 1e6
|
shared_expert = sum(p.numel() for n, p in model.named_parameters() if 'mlp.shared_experts.0.' in n) / 1e6
|
||||||
base = total - (expert * n_routed) - (shared_expert * n_shared)
|
base = total - (expert * n_routed) - (shared_expert * n_shared)
|
||||||
active = base + (expert * n_active) + (shared_expert * n_shared)
|
active = base + (expert * n_active) + (shared_expert * n_shared)
|
||||||
Logger(f'Model Params: {total:.2f}M-A{active:.2f}M')
|
if active < total: Logger(f'Model Params: {total:.2f}M-A{active:.2f}M')
|
||||||
else:
|
else: Logger(f'Model Params: {total:.2f}M')
|
||||||
Logger(f'Model Params: {total:.2f}M')
|
|
||||||
|
|
||||||
|
|
||||||
def is_main_process():
|
def is_main_process():
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user