mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-06-06 00:04:50 +00:00
[fix] lora moe
This commit is contained in:
@@ -127,6 +127,7 @@ def convert_json_to_jinja(json_file_path, output_path):
|
||||
|
||||
if __name__ == '__main__':
|
||||
lm_config = MiniMindConfig(hidden_size=768, num_hidden_layers=8, max_seq_len=8192, use_moe=False)
|
||||
|
||||
# convert torch to transformers
|
||||
torch_path = f"../out/full_sft_{lm_config.hidden_size}{'_moe' if lm_config.use_moe else ''}.pth"
|
||||
transformers_path = '../minimind-3'
|
||||
@@ -134,7 +135,7 @@ if __name__ == '__main__':
|
||||
|
||||
# # merge lora
|
||||
# base_torch_path = f"../out/full_sft_{lm_config.hidden_size}{'_moe' if lm_config.use_moe else ''}.pth"
|
||||
# lora_path = f"../out/lora_identity_{lm_config.hidden_size}.pth"
|
||||
# lora_path = f"../out/lora_identity_{lm_config.hidden_size}{'_moe' if lm_config.use_moe else ''}.pth"
|
||||
# merged_torch_path = f"../out/merge_identity_{lm_config.hidden_size}{'_moe' if lm_config.use_moe else ''}.pth"
|
||||
# convert_merge_base_lora(base_torch_path, lora_path, merged_torch_path)
|
||||
|
||||
|
||||
@@ -58,7 +58,8 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
|
||||
|
||||
if (step % args.save_interval == 0 or step == iters) and is_main_process():
|
||||
model.eval()
|
||||
lora_save_path = f'{args.save_dir}/{args.lora_name}_{lm_config.hidden_size}.pth'
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
lora_save_path = f'{args.save_dir}/{args.lora_name}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
# LoRA只保存LoRA权重
|
||||
save_lora(model, lora_save_path)
|
||||
lm_checkpoint(lm_config, weight=args.lora_name, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
|
||||
|
||||
Reference in New Issue
Block a user