[fix] lora moe

This commit is contained in:
jingyaogong
2026-04-09 16:36:48 +08:00
parent aa3e6affa1
commit 5351424bf0
2 changed files with 4 additions and 2 deletions
+2 -1
View File
@@ -127,6 +127,7 @@ def convert_json_to_jinja(json_file_path, output_path):
if __name__ == '__main__':
lm_config = MiniMindConfig(hidden_size=768, num_hidden_layers=8, max_seq_len=8192, use_moe=False)
# convert torch to transformers
torch_path = f"../out/full_sft_{lm_config.hidden_size}{'_moe' if lm_config.use_moe else ''}.pth"
transformers_path = '../minimind-3'
@@ -134,7 +135,7 @@ if __name__ == '__main__':
# # merge lora
# base_torch_path = f"../out/full_sft_{lm_config.hidden_size}{'_moe' if lm_config.use_moe else ''}.pth"
# lora_path = f"../out/lora_identity_{lm_config.hidden_size}.pth"
# lora_path = f"../out/lora_identity_{lm_config.hidden_size}{'_moe' if lm_config.use_moe else ''}.pth"
# merged_torch_path = f"../out/merge_identity_{lm_config.hidden_size}{'_moe' if lm_config.use_moe else ''}.pth"
# convert_merge_base_lora(base_torch_path, lora_path, merged_torch_path)
+2 -1
View File
@@ -58,7 +58,8 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
if (step % args.save_interval == 0 or step == iters) and is_main_process():
model.eval()
lora_save_path = f'{args.save_dir}/{args.lora_name}_{lm_config.hidden_size}.pth'
moe_suffix = '_moe' if lm_config.use_moe else ''
lora_save_path = f'{args.save_dir}/{args.lora_name}_{lm_config.hidden_size}{moe_suffix}.pth'
# LoRA只保存LoRA权重
save_lora(model, lora_save_path)
lm_checkpoint(lm_config, weight=args.lora_name, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')