mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-06-06 00:04:50 +00:00
[feat] update import
This commit is contained in:
+10
-11
@@ -8,6 +8,8 @@ import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import Sampler
|
||||
from transformers import AutoTokenizer
|
||||
from model.model_minimind import MiniMindForCausalLM
|
||||
|
||||
|
||||
def is_main_process():
|
||||
@@ -26,7 +28,7 @@ def get_lr(current_step, total_steps, lr):
|
||||
def init_distributed_mode():
|
||||
if int(os.environ.get("RANK", -1)) == -1:
|
||||
return 0 # 非DDP模式
|
||||
|
||||
|
||||
dist.init_process_group(backend="nccl")
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
torch.cuda.set_device(local_rank)
|
||||
@@ -47,7 +49,7 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoc
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}.pth'
|
||||
resume_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}_resume.pth'
|
||||
|
||||
|
||||
if model is not None:
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
|
||||
@@ -61,7 +63,7 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoc
|
||||
wandb_id = getattr(run, 'id', None) if run else None
|
||||
else:
|
||||
wandb_id = getattr(wandb, 'id', None)
|
||||
|
||||
|
||||
resume_data = {
|
||||
'model': state_dict,
|
||||
'optimizer': optimizer.state_dict(),
|
||||
@@ -79,7 +81,7 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoc
|
||||
resume_data[key] = value.state_dict()
|
||||
else:
|
||||
resume_data[key] = value
|
||||
|
||||
|
||||
resume_tmp = resume_path + '.tmp'
|
||||
torch.save(resume_data, resume_tmp)
|
||||
os.replace(resume_tmp, resume_path)
|
||||
@@ -96,18 +98,15 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoc
|
||||
|
||||
|
||||
def init_model(lm_config, from_weight='pretrain', tokenizer_path='../model', save_dir='../out', device='cuda'):
|
||||
from transformers import AutoTokenizer
|
||||
from model.model_minimind import MiniMindForCausalLM
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
model = MiniMindForCausalLM(lm_config)
|
||||
|
||||
|
||||
if from_weight!= 'none':
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
weight_path = f'{save_dir}/{from_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
weights = torch.load(weight_path, map_location=device)
|
||||
model.load_state_dict(weights, strict=False)
|
||||
|
||||
|
||||
Logger(f'所加载Model可训练参数:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
return model.to(device), tokenizer
|
||||
|
||||
@@ -117,7 +116,7 @@ class SkipBatchSampler(Sampler):
|
||||
self.sampler = sampler
|
||||
self.batch_size = batch_size
|
||||
self.skip_batches = skip_batches
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
batch = []
|
||||
skipped = 0
|
||||
@@ -132,7 +131,7 @@ class SkipBatchSampler(Sampler):
|
||||
batch = []
|
||||
if len(batch) > 0 and skipped >= self.skip_batches:
|
||||
yield batch
|
||||
|
||||
|
||||
def __len__(self):
|
||||
total_batches = (len(self.sampler) + self.batch_size - 1) // self.batch_size
|
||||
return max(0, total_batches - self.skip_batches)
|
||||
|
||||
Reference in New Issue
Block a user