mirror of https://github.com/jingyaogong/minimind.git
synced 2026-04-25 08:48:16 +08:00

[update] align loss

This commit is contained in:
parent e119db8478
commit c090b69c4d

Summary: loss handling is aligned across the repo. The datasets now return HF-style (input_ids, labels) pairs with -100 marking ignored positions, MiniMindForCausalLM.forward() computes the shifted cross-entropy internally when labels are passed, and the trainers consume res.loss (or a mask derived from labels) instead of hand-masked per-token CE.
@@ -16,8 +16,6 @@ class PretrainDataset(Dataset):
 
     def __getitem__(self, index):
         sample = self.samples[index]
-
-        # build the input text
         encoding = self.tokenizer(
             str(sample['text']),
             max_length=self.max_length,
@@ -26,12 +24,9 @@ class PretrainDataset(Dataset):
             return_tensors='pt'
         )
         input_ids = encoding.input_ids.squeeze()
-        loss_mask = (input_ids != self.tokenizer.pad_token_id)
-
-        X = torch.tensor(input_ids[:-1], dtype=torch.long)
-        Y = torch.tensor(input_ids[1:], dtype=torch.long)
-        loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)
-        return X, Y, loss_mask
+        labels = input_ids.clone()
+        labels[input_ids == self.tokenizer.pad_token_id] = -100
+        return input_ids, labels
 
 
 class SFTDataset(Dataset):
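With this change PretrainDataset emits equal-length (input_ids, labels) pairs in the HF convention, with padding marked as -100. A toy sketch (vocab size, pad id, and tensors are illustrative) of how such a pair feeds the shifted cross-entropy that the model hunk below adds to forward():

import torch
import torch.nn.functional as F

input_ids = torch.tensor([[5, 7, 2, 0, 0]])      # pad id assumed to be 0
labels = input_ids.clone()
labels[input_ids == 0] = -100                    # mirrors the dataset change

logits = torch.randn(1, 5, 10)                   # stand-in model output (B, T, V)
shift_logits = logits[..., :-1, :].contiguous()  # predict token t+1 from position t
shift_labels = labels[..., 1:].contiguous()
loss = F.cross_entropy(shift_logits.view(-1, 10), shift_labels.view(-1),
                       ignore_index=-100)        # pad positions contribute nothing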
@@ -56,8 +51,8 @@ class SFTDataset(Dataset):
             tools=tools
         )
 
-    def generate_loss_mask(self, input_ids):
-        loss_mask = [0] * len(input_ids)
+    def generate_labels(self, input_ids):
+        labels = [-100] * len(input_ids)
         i = 0
         while i < len(input_ids):
             if input_ids[i:i + len(self.bos_id)] == self.bos_id:
@@ -68,29 +63,24 @@ class SFTDataset(Dataset):
                         break
                     end += 1
                 for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
-                    loss_mask[j] = 1
+                    labels[j] = input_ids[j]
                 i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
             else:
                 i += 1
-        return loss_mask
+        return labels
 
     def __getitem__(self, index):
         sample = self.samples[index]
         prompt = self.create_chat_prompt(sample['conversations'])
         input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
         input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
-        loss_mask = self.generate_loss_mask(input_ids)
-
-        # build the training tensors
-        X = torch.tensor(input_ids[:-1], dtype=torch.long)
-        Y = torch.tensor(input_ids[1:], dtype=torch.long)
-        loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)  # align prediction positions
-        # # === print each token's mask ===
+        labels = self.generate_labels(input_ids)
+        # # === debug print ===
         # print(f"\n--- Sample {index} ---")
-        # for i, (x, y, m) in enumerate(zip(X, Y, loss_mask)):
-        #     print(f"{i:3d}: X={self.tokenizer.decode([x])!r:16s} ---> Y={self.tokenizer.decode([y])!r:16s} mask={m}")
-        # # ================================
-        return X, Y, loss_mask
+        # for i, (x, y) in enumerate(zip(input_ids[:-1], labels[1:])):
+        #     print(f"{i:3d}: X={self.tokenizer.decode([x])!r:16s} ---> Y={self.tokenizer.decode([input_ids[i+1]])!r:16s} label={y}")
+        # # ================
+        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)
 
 
 class DPODataset(Dataset):
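generate_labels replaces generate_loss_mask: instead of a 0/1 mask it writes target token ids directly into labels, leaving -100 everywhere outside the assistant spans delimited by bos_id/eos_id. A standalone re-creation with made-up single-token markers (the real ids come from the chat template), just to show the resulting shape:

def generate_labels(input_ids, bos_id, eos_id, max_length):
    # same loop bounds as the repo's method: labels run from the second
    # response token through one position past the eos marker
    labels = [-100] * len(input_ids)
    i = 0
    while i < len(input_ids):
        if input_ids[i:i + len(bos_id)] == bos_id:
            start = i + len(bos_id)
            end = start
            while end < len(input_ids):
                if input_ids[end:end + len(eos_id)] == eos_id:
                    break
                end += 1
            for j in range(start + 1, min(end + len(eos_id) + 1, max_length)):
                labels[j] = input_ids[j]
            i = end + len(eos_id) if end < len(input_ids) else len(input_ids)
        else:
            i += 1
    return labels

print(generate_labels([9, 1, 5, 6, 7, 2, 9, 9], bos_id=[1], eos_id=[2], max_length=8))
# -> [-100, -100, -100, 6, 7, 2, 9, -100]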
@@ -437,6 +437,7 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
     def forward(self,
                 input_ids: Optional[torch.Tensor] = None,
                 attention_mask: Optional[torch.Tensor] = None,
+                labels: Optional[torch.Tensor] = None,
                 past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                 use_cache: bool = False,
                 logits_to_keep: Union[int, torch.Tensor] = 0,
@@ -450,6 +451,13 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
         )
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.lm_head(hidden_states[:, slice_indices, :])
-        output = CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
+
+        loss = None
+        if labels is not None:
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100)
+
+        output = CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
         output.aux_loss = aux_loss
         return output
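The loss block added to forward() is the standard HF causal-LM recipe; factored out for illustration (the logic is verbatim from the hunk above, only the helper name is ours):

import torch
import torch.nn.functional as F

def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # predict token t+1 from position t; positions labeled -100 are ignored
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
                           shift_labels.view(-1), ignore_index=-100)

Because the shift happens inside the model, datasets can return unshifted (input_ids, labels) of equal length, which is exactly what the dataset hunks above switch to.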
@@ -42,36 +42,37 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
     teacher_model.eval()
     teacher_model.requires_grad_(False)
 
-    for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
-        X = X.to(args.device)
-        Y = Y.to(args.device)
-        loss_mask = loss_mask.to(args.device)
+    for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
+        input_ids = input_ids.to(args.device)
+        labels = labels.to(args.device)
+        loss_mask = (labels[..., 1:] != -100).float()
         lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
 
         # forward pass (student model)
        with autocast_ctx:
-            res = model(X)
-            student_logits = res.logits
+            res = model(input_ids)
+            student_logits = res.logits[..., :-1, :].contiguous()
 
         # teacher forward pass (eval & no_grad only)
         if teacher_model is not None:
             with torch.no_grad():
-                teacher_logits = teacher_model(X).logits
+                teacher_logits = teacher_model(input_ids).logits[..., :-1, :].contiguous()
                 vocab_size_student = student_logits.size(-1)
                 teacher_logits = teacher_logits[..., :vocab_size_student]
 
         # ========== compute losses ==========
         # 1) Ground-Truth CE Loss
+        shift_labels = labels[..., 1:].contiguous()
         loss_mask_flat = loss_mask.view(-1)
         ce_loss = F.cross_entropy(
             student_logits.view(-1, student_logits.size(-1)),
-            Y.view(-1),
-            ignore_index=0,
+            shift_labels.view(-1),
+            ignore_index=-100,
             reduction='none'
         )
-        ce_loss_raw = torch.sum(ce_loss * loss_mask_flat) / loss_mask_flat.sum()
+        ce_loss_raw = torch.sum(ce_loss * loss_mask_flat) / (loss_mask_flat.sum() + 1e-8)
         if lm_config_student.use_moe: ce_loss = ce_loss_raw + res.aux_loss
         else: ce_loss = ce_loss_raw
 
@@ -124,13 +125,12 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
             raw_model = model.module if isinstance(model, DistributedDataParallel) else model
             raw_model = getattr(raw_model, '_orig_mod', raw_model)
             state_dict = raw_model.state_dict()
-            state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
-            torch.save(state_dict, ckp)
+            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
             lm_checkpoint(lm_config_student, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
             model.train()
             del state_dict
 
-        del X, Y, loss_mask, res, student_logits, teacher_logits, ce_loss, distill_loss, loss
+        del input_ids, labels, loss_mask, res, student_logits, ce_loss, distill_loss, loss
 
 
 if __name__ == "__main__":
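The distillation loop no longer receives a loss_mask from the loader; it derives one from labels, and truncates both student and teacher logits to [..., :-1, :] so they line up with labels[..., 1:]. A toy shape check with random tensors (sizes are arbitrary); the + 1e-8 in the hunk guards this same division against a fully masked batch:

import torch
import torch.nn.functional as F

B, T, V = 2, 8, 32
logits = torch.randn(B, T, V)                      # stand-in student output
labels = torch.randint(0, V, (B, T))
labels[:, -3:] = -100                              # pretend trailing padding

student_logits = logits[..., :-1, :].contiguous()  # (B, T-1, V)
shift_labels = labels[..., 1:].contiguous()        # (B, T-1)
loss_mask_flat = (shift_labels != -100).float().view(-1)

ce = F.cross_entropy(student_logits.view(-1, V), shift_labels.view(-1),
                     ignore_index=-100, reduction='none')
ce_raw = torch.sum(ce * loss_mask_flat) / (loss_mask_flat.sum() + 1e-8)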
@@ -111,8 +111,7 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
             raw_model = model.module if isinstance(model, DistributedDataParallel) else model
             raw_model = getattr(raw_model, '_orig_mod', raw_model)
             state_dict = raw_model.state_dict()
-            state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
-            torch.save(state_dict, ckp)
+            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
             lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
             model.train()
             del state_dict
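This hunk (and the matching ones in the other trainers) builds the fp16 CPU copy inline inside torch.save, so state_dict keeps referring to the original full-precision dict that the later del state_dict releases. The shared pattern as a hypothetical helper (the name and the hasattr unwrap are ours; the scripts do this inline with an isinstance(..., DistributedDataParallel) check):

import torch

def save_fp16_checkpoint(model, path):
    raw = model.module if hasattr(model, 'module') else model  # unwrap DDP
    raw = getattr(raw, '_orig_mod', raw)                       # unwrap torch.compile
    state_dict = raw.state_dict()
    # cast to fp16 on CPU only inside the save call; state_dict stays intact
    torch.save({k: v.half().cpu() for k, v in state_dict.items()}, path)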
@@ -21,25 +21,17 @@ warnings.filterwarnings('ignore')
 
 
 def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
-    loss_fct = nn.CrossEntropyLoss(reduction='none')
     start_time = time.time()
-    for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
-        X = X.to(args.device)
-        Y = Y.to(args.device)
-        loss_mask = loss_mask.to(args.device)
+    for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
+        input_ids = input_ids.to(args.device)
+        labels = labels.to(args.device)
         lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
 
         with autocast_ctx:
-            res = model(X)
-            loss = loss_fct(
-                res.logits.view(-1, res.logits.size(-1)),
-                Y.view(-1)
-            ).view(Y.size())
-
-            logits_loss = (loss * loss_mask).sum() / loss_mask.sum()
-            loss = logits_loss + res.aux_loss
+            res = model(input_ids, labels=labels)
+            loss = res.loss + res.aux_loss
             loss = loss / args.accumulation_steps
 
         scaler.scale(loss).backward()
@@ -56,13 +48,11 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
         if step % args.log_interval == 0 or step == iters - 1:
             spend_time = time.time() - start_time
             current_loss = loss.item() * args.accumulation_steps
-            current_logits_loss = logits_loss.item()
-            current_aux_loss = res.aux_loss.item()
+            current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
+            current_logits_loss = current_loss - current_aux_loss
             current_lr = optimizer.param_groups[-1]['lr']
             eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
-
-            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')
+            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
             if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
 
         if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
@@ -72,14 +62,13 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
             raw_model = model.module if isinstance(model, DistributedDataParallel) else model
             raw_model = getattr(raw_model, '_orig_mod', raw_model)
             state_dict = raw_model.state_dict()
-            state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
-            torch.save(state_dict, ckp)
+            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
             lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
                           epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler)
             model.train()
             del state_dict
 
-        del X, Y, loss_mask, res, loss
+        del input_ids, labels, res, loss
 
 
 if __name__ == "__main__":
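This trainer now takes res.loss straight from model(input_ids, labels=labels) instead of masking a per-token CE by hand. The two paths agree whenever the old mask equals the set of positions whose label is not -100; a self-contained check with toy tensors:

import torch
import torch.nn.functional as F

B, T, V = 2, 6, 11
logits = torch.randn(B, T, V)
labels = torch.randint(0, V, (B, T))
labels[:, -2:] = -100                              # pretend padding

shift_logits = logits[..., :-1, :].reshape(-1, V)
shift_labels = labels[..., 1:].reshape(-1)
mask = (shift_labels != -100).float()

# old style: valid targets everywhere (clamp), then mask and average by hand
per_tok = F.cross_entropy(shift_logits, shift_labels.clamp(min=0), reduction='none')
manual = (per_tok * mask).sum() / mask.sum()
# new style: let ignore_index average over the non-ignored positions
builtin = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
assert torch.allclose(manual, builtin)

One consequence for logging: logits_loss is no longer computed separately, so the log line reconstructs it as current_loss - current_aux_loss.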
@@ -22,25 +22,17 @@ warnings.filterwarnings('ignore')
 
 
 def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
-    loss_fct = nn.CrossEntropyLoss(reduction='none')
     start_time = time.time()
-    for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
-        X = X.to(args.device)
-        Y = Y.to(args.device)
-        loss_mask = loss_mask.to(args.device)
+    for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
+        input_ids = input_ids.to(args.device)
+        labels = labels.to(args.device)
         lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
 
         with autocast_ctx:
-            res = model(X)
-            loss = loss_fct(
-                res.logits.view(-1, res.logits.size(-1)),
-                Y.view(-1)
-            ).view(Y.size())
-
-            logits_loss = (loss * loss_mask).sum() / loss_mask.sum()
-            loss = logits_loss + res.aux_loss
+            res = model(input_ids, labels=labels)
+            loss = res.loss + res.aux_loss
             loss = loss / args.accumulation_steps
 
         scaler.scale(loss).backward()
@@ -48,22 +40,18 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
         if (step + 1) % args.accumulation_steps == 0:
             scaler.unscale_(optimizer)
             torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
 
             scaler.step(optimizer)
             scaler.update()
 
             optimizer.zero_grad(set_to_none=True)
 
         if step % args.log_interval == 0 or step == iters - 1:
             spend_time = time.time() - start_time
             current_loss = loss.item() * args.accumulation_steps
-            current_logits_loss = logits_loss.item()
-            current_aux_loss = res.aux_loss.item()
+            current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
+            current_logits_loss = current_loss - current_aux_loss
             current_lr = optimizer.param_groups[-1]['lr']
             eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
-
-            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')
+            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
             if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
 
         if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
@@ -74,7 +62,7 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
             lm_checkpoint(lm_config, weight=args.lora_name, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
             model.train()
 
-        del X, Y, loss_mask, res, loss
+        del input_ids, labels, res, loss
 
 
 if __name__ == "__main__":
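Unchanged context in the LoRA hunk, but worth annotating: only the adapter tensors train here, so gradient clipping runs over lora_params rather than all of model.parameters(), and scaler.unscale_ must precede the clip so it sees true gradient magnitudes. The update step with comments (model, scaler, optimizer, args, and lora_params all come from the surrounding script):

scaler.scale(loss).backward()                  # grads accumulate, still AMP-scaled
if (step + 1) % args.accumulation_steps == 0:
    scaler.unscale_(optimizer)                 # back to true scale before clipping
    torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
    scaler.step(optimizer)                     # skips the update on inf/nan grads
    scaler.update()
    optimizer.zero_grad(set_to_none=True)      # free grad memory between updates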
@@ -21,25 +21,17 @@ warnings.filterwarnings('ignore')
 
 
 def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
-    loss_fct = nn.CrossEntropyLoss(reduction='none')
     start_time = time.time()
-    for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
-        X = X.to(args.device)
-        Y = Y.to(args.device)
-        loss_mask = loss_mask.to(args.device)
+    for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
+        input_ids = input_ids.to(args.device)
+        labels = labels.to(args.device)
         lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
 
         with autocast_ctx:
-            res = model(X)
-            loss = loss_fct(
-                res.logits.view(-1, res.logits.size(-1)),
-                Y.view(-1)
-            ).view(Y.size())
-
-            logits_loss = (loss * loss_mask).sum() / loss_mask.sum()
-            loss = logits_loss + res.aux_loss
+            res = model(input_ids, labels=labels)
+            loss = res.loss + res.aux_loss
            loss = loss / args.accumulation_steps
 
         scaler.scale(loss).backward()
@@ -56,13 +48,11 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
         if step % args.log_interval == 0 or step == iters - 1:
             spend_time = time.time() - start_time
             current_loss = loss.item() * args.accumulation_steps
-            current_logits_loss = logits_loss.item()
-            current_aux_loss = res.aux_loss.item()
+            current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
+            current_logits_loss = current_loss - current_aux_loss
             current_lr = optimizer.param_groups[-1]['lr']
             eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
-
-            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')
+            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
             if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
 
         if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
@@ -72,13 +62,12 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
             raw_model = model.module if isinstance(model, DistributedDataParallel) else model
             raw_model = getattr(raw_model, '_orig_mod', raw_model)
             state_dict = raw_model.state_dict()
-            state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
-            torch.save(state_dict, ckp)
+            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
             lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
             model.train()
             del state_dict
 
-        del X, Y, loss_mask, res, loss
+        del input_ids, labels, res, loss
 
 
 if __name__ == "__main__":
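Every one of these loops divides the loss by args.accumulation_steps before backward() and multiplies loss.item() back for logging. A tiny self-contained demonstration (toy one-parameter model) that the scaled micro-batch gradients sum to the full-batch gradient:

import torch

w = torch.tensor(1.0, requires_grad=True)
x = torch.tensor([1.0, 2.0, 3.0, 4.0])

loss_full = ((w * x) ** 2).mean()                 # one large batch
loss_full.backward()
g_full = w.grad.clone(); w.grad = None

accumulation_steps = 2
for chunk in x.chunk(accumulation_steps):         # two accumulated micro-batches
    loss = ((w * chunk) ** 2).mean() / accumulation_steps  # same scaling as the trainers
    loss.backward()

assert torch.allclose(g_full, w.grad)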
@@ -21,7 +21,6 @@ warnings.filterwarnings('ignore')
 
 
 def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=None):
-    # think-tag placeholders
     start_of_think_ids = tokenizer('<think>').input_ids
     end_of_think_ids = tokenizer('</think>').input_ids
     start_of_answer_ids = tokenizer('<answer>').input_ids
@@ -29,30 +28,28 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=
     loss_fct = nn.CrossEntropyLoss(reduction='none')
     start_time = time.time()
 
-    for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
-        X = X.to(args.device)
-        Y = Y.to(args.device)
-        loss_mask = loss_mask.to(args.device)
+    for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
+        input_ids = input_ids.to(args.device)
+        labels = labels.to(args.device)
         lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
 
         with autocast_ctx:
-            res = model(X)
-            loss = loss_fct(
-                res.logits.view(-1, res.logits.size(-1)),
-                Y.view(-1)
-            ).view(Y.size())
+            res = model(input_ids)
+            shift_logits = res.logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size())
 
             # extra weight at special tag positions (specific to reasoning distillation)
-            sp_ids = torch.isin(Y.view(-1),
+            loss_mask = (shift_labels != -100).float()
+            sp_ids = torch.isin(shift_labels.view(-1),
                                 torch.tensor(start_of_think_ids + end_of_think_ids
                                              + start_of_answer_ids + end_of_answer_ids
                                 ).to(args.device))
-            loss_mask = loss_mask.view(-1)
-            loss_mask_sum = loss_mask.sum()
-            loss_mask[sp_ids] = 10  # 10x weight on the think tags
-            loss_mask = loss_mask.view(Y.size())
+            loss_mask_flat = loss_mask.view(-1)
+            loss_mask_sum = loss_mask_flat.sum()
+            loss_mask_flat[sp_ids] = 10
+            loss_mask = loss_mask_flat.view(shift_labels.size())
             logits_loss = (loss * loss_mask).sum() / loss_mask_sum
             loss = logits_loss + res.aux_loss
             loss = loss / args.accumulation_steps
@@ -69,13 +66,11 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=
         if step % args.log_interval == 0 or step == iters - 1:
             spend_time = time.time() - start_time
             current_loss = loss.item() * args.accumulation_steps
+            current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
             current_logits_loss = logits_loss.item()
-            current_aux_loss = res.aux_loss.item()
             current_lr = optimizer.param_groups[-1]['lr']
             eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
-
-            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')
+            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
             if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
 
         if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
@@ -85,13 +80,12 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=
             raw_model = model.module if isinstance(model, DistributedDataParallel) else model
             raw_model = getattr(raw_model, '_orig_mod', raw_model)
             state_dict = raw_model.state_dict()
-            state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
-            torch.save(state_dict, ckp)
+            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
             lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
             model.train()
             del state_dict
 
-        del X, Y, loss_mask, res, loss
+        del input_ids, labels, res, loss
 
 
 if __name__ == "__main__":
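The reasoning-distillation loop keeps its manual mask because tag positions receive extra weight; it now builds that mask from shift_labels instead of the loader's loss_mask. A toy run of the re-weighting (tag ids and labels are made up); note that the denominator loss_mask_sum is captured before the upweight, exactly as in the hunk:

import torch

tag_ids = torch.tensor([3, 4])                  # pretend <think>/<answer> token ids
shift_labels = torch.tensor([[7, 3, 9, 4, -100]])

loss_mask = (shift_labels != -100).float()      # [[1., 1., 1., 1., 0.]]
sp_ids = torch.isin(shift_labels.view(-1), tag_ids)
loss_mask_flat = loss_mask.view(-1)
loss_mask_sum = loss_mask_flat.sum()            # 4.0, counted before upweighting
loss_mask_flat[sp_ids] = 10                     # -> [1., 10., 1., 10., 0.]
loss_mask = loss_mask_flat.view(shift_labels.size())
# the weighted CE is then (per_token_ce * loss_mask).sum() / loss_mask_sum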