[feat] update import

This commit is contained in:
jingyaogong
2025-10-31 23:45:55 +08:00
parent 8d71754e05
commit 0323815729
+10 -11
View File
@@ -8,6 +8,8 @@ import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Sampler
from transformers import AutoTokenizer
from model.model_minimind import MiniMindForCausalLM
def is_main_process():
@@ -26,7 +28,7 @@ def get_lr(current_step, total_steps, lr):
def init_distributed_mode():
if int(os.environ.get("RANK", -1)) == -1:
return 0 # 非DDP模式
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
@@ -47,7 +49,7 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoc
moe_path = '_moe' if lm_config.use_moe else ''
ckp_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}.pth'
resume_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}_resume.pth'
if model is not None:
from torch.nn.parallel import DistributedDataParallel
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
@@ -61,7 +63,7 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoc
wandb_id = getattr(run, 'id', None) if run else None
else:
wandb_id = getattr(wandb, 'id', None)
resume_data = {
'model': state_dict,
'optimizer': optimizer.state_dict(),
@@ -79,7 +81,7 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoc
resume_data[key] = value.state_dict()
else:
resume_data[key] = value
resume_tmp = resume_path + '.tmp'
torch.save(resume_data, resume_tmp)
os.replace(resume_tmp, resume_path)
@@ -96,18 +98,15 @@ def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoc
def init_model(lm_config, from_weight='pretrain', tokenizer_path='../model', save_dir='../out', device='cuda'):
from transformers import AutoTokenizer
from model.model_minimind import MiniMindForCausalLM
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
model = MiniMindForCausalLM(lm_config)
if from_weight!= 'none':
moe_suffix = '_moe' if lm_config.use_moe else ''
weight_path = f'{save_dir}/{from_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
weights = torch.load(weight_path, map_location=device)
model.load_state_dict(weights, strict=False)
Logger(f'所加载Model可训练参数:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
return model.to(device), tokenizer
@@ -117,7 +116,7 @@ class SkipBatchSampler(Sampler):
self.sampler = sampler
self.batch_size = batch_size
self.skip_batches = skip_batches
def __iter__(self):
batch = []
skipped = 0
@@ -132,7 +131,7 @@ class SkipBatchSampler(Sampler):
batch = []
if len(batch) > 0 and skipped >= self.skip_batches:
yield batch
def __len__(self):
total_batches = (len(self.sampler) + self.batch_size - 1) // self.batch_size
return max(0, total_batches - self.skip_batches)