diff --git a/dataset/lm_dataset.py b/dataset/lm_dataset.py
index a958303..21e321c 100644
--- a/dataset/lm_dataset.py
+++ b/dataset/lm_dataset.py
@@ -16,8 +16,6 @@ class PretrainDataset(Dataset):
 
     def __getitem__(self, index):
         sample = self.samples[index]
-
-        # Build the input text
         encoding = self.tokenizer(
             str(sample['text']),
             max_length=self.max_length,
@@ -26,12 +24,9 @@ class PretrainDataset(Dataset):
             return_tensors='pt'
         )
         input_ids = encoding.input_ids.squeeze()
-        loss_mask = (input_ids != self.tokenizer.pad_token_id)
-
-        X = torch.tensor(input_ids[:-1], dtype=torch.long)
-        Y = torch.tensor(input_ids[1:], dtype=torch.long)
-        loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)
-        return X, Y, loss_mask
+        labels = input_ids.clone()
+        labels[input_ids == self.tokenizer.pad_token_id] = -100
+        return input_ids, labels
 
 
 class SFTDataset(Dataset):
@@ -56,8 +51,8 @@ class SFTDataset(Dataset):
             tools=tools
         )
 
-    def generate_loss_mask(self, input_ids):
-        loss_mask = [0] * len(input_ids)
+    def generate_labels(self, input_ids):
+        labels = [-100] * len(input_ids)
         i = 0
         while i < len(input_ids):
             if input_ids[i:i + len(self.bos_id)] == self.bos_id:
@@ -68,29 +63,24 @@ class SFTDataset(Dataset):
                         break
                     end += 1
                 for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
-                    loss_mask[j] = 1
+                    labels[j] = input_ids[j]
                 i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
             else:
                 i += 1
-        return loss_mask
+        return labels
 
     def __getitem__(self, index):
         sample = self.samples[index]
         prompt = self.create_chat_prompt(sample['conversations'])
         input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
         input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
-        loss_mask = self.generate_loss_mask(input_ids)
-
-        # Build the training pair
-        X = torch.tensor(input_ids[:-1], dtype=torch.long)
-        Y = torch.tensor(input_ids[1:], dtype=torch.long)
-        loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)  # align prediction positions
-        # # === print each token's mask ===
+        labels = self.generate_labels(input_ids)
+        # # === debug printing ===
         # print(f"\n--- Sample {index} ---")
-        # for i, (x, y, m) in enumerate(zip(X, Y, loss_mask)):
-        #     print(f"{i:3d}: X={self.tokenizer.decode([x])!r:16s} ---> Y={self.tokenizer.decode([y])!r:16s} mask={m}")
-        # # ================================
-        return X, Y, loss_mask
+        # for i, (x, y) in enumerate(zip(input_ids[:-1], labels[1:])):
+        #     print(f"{i:3d}: X={self.tokenizer.decode([x])!r:16s} ---> Y={self.tokenizer.decode([input_ids[i+1]])!r:16s} label={y}")
+        # # ================
+        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)
 
 
 class DPODataset(Dataset):
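
Both datasets now emit Hugging Face-style (input_ids, labels) pairs of equal length, with -100 marking positions (padding, and non-assistant tokens in SFT) that must not contribute to the loss. A minimal, self-contained sketch of the convention; the pad id of 0 and vocabulary size of 8 are made-up illustration values:

import torch
import torch.nn.functional as F

# Toy batch: pad_token_id = 0, vocab_size = 8 (illustrative values only).
input_ids = torch.tensor([[5, 6, 7, 0, 0]])
labels = input_ids.clone()
labels[input_ids == 0] = -100            # padded positions carry no loss

logits = torch.randn(1, 5, 8)            # stand-in for model(input_ids).logits
loss = F.cross_entropy(
    logits[:, :-1, :].reshape(-1, 8),    # position t predicts token t + 1
    labels[:, 1:].reshape(-1),
    ignore_index=-100,                   # -100 targets are skipped by the mean
)
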
diff --git a/model/model_minimind.py b/model/model_minimind.py
index d826f82..b3910a8 100755
--- a/model/model_minimind.py
+++ b/model/model_minimind.py
@@ -437,6 +437,7 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
     def forward(self,
                 input_ids: Optional[torch.Tensor] = None,
                 attention_mask: Optional[torch.Tensor] = None,
+                labels: Optional[torch.Tensor] = None,
                 past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                 use_cache: bool = False,
                 logits_to_keep: Union[int, torch.Tensor] = 0,
@@ -450,6 +451,13 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
         )
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.lm_head(hidden_states[:, slice_indices, :])
-        output = CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
+
+        loss = None
+        if labels is not None:
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100)
+
+        output = CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
         output.aux_loss = aux_loss
         return output
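
With the shift done once inside forward, the datasets can return unshifted tensors; the old X = input_ids[:-1] / Y = input_ids[1:] alignment is no longer the caller's job. A self-contained sketch of the added loss block, assuming shapes logits [B, T, V] and labels [B, T]; note that labels should be paired with the default logits_to_keep, since a truncated logits sequence would no longer line up with the labels:

import torch
import torch.nn.functional as F

def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Next-token cross-entropy, ignoring positions labeled -100."""
    shift_logits = logits[..., :-1, :].contiguous()   # predictions for steps 0..T-2
    shift_labels = labels[..., 1:].contiguous()       # targets are the next tokens
    return F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
                           shift_labels.view(-1), ignore_index=-100)

# Smoke test with random tensors:
print(causal_lm_loss(torch.randn(2, 6, 10), torch.randint(0, 10, (2, 6))))
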
diff --git a/trainer/train_distillation.py b/trainer/train_distillation.py
index abae2c1..2b8afc3 100644
--- a/trainer/train_distillation.py
+++ b/trainer/train_distillation.py
@@ -42,36 +42,37 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
     teacher_model.eval()
     teacher_model.requires_grad_(False)
 
-    for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
-        X = X.to(args.device)
-        Y = Y.to(args.device)
-        loss_mask = loss_mask.to(args.device)
+    for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
+        input_ids = input_ids.to(args.device)
+        labels = labels.to(args.device)
+        loss_mask = (labels[..., 1:] != -100).float()
         lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
 
         # Forward pass (student model)
         with autocast_ctx:
-            res = model(X)
-            student_logits = res.logits
+            res = model(input_ids)
+            student_logits = res.logits[..., :-1, :].contiguous()
 
         # Teacher forward pass (eval & no_grad only)
         if teacher_model is not None:
             with torch.no_grad():
-                teacher_logits = teacher_model(X).logits
+                teacher_logits = teacher_model(input_ids).logits[..., :-1, :].contiguous()
                 vocab_size_student = student_logits.size(-1)
                 teacher_logits = teacher_logits[..., :vocab_size_student]
 
         # ========== Compute the losses ==========
         # 1) Ground-Truth CE Loss
+        shift_labels = labels[..., 1:].contiguous()
         loss_mask_flat = loss_mask.view(-1)
         ce_loss = F.cross_entropy(
             student_logits.view(-1, student_logits.size(-1)),
-            Y.view(-1),
-            ignore_index=0,
+            shift_labels.view(-1),
+            ignore_index=-100,
             reduction='none'
         )
-        ce_loss_raw = torch.sum(ce_loss * loss_mask_flat) / loss_mask_flat.sum()
+        ce_loss_raw = torch.sum(ce_loss * loss_mask_flat) / (loss_mask_flat.sum() + 1e-8)
         if lm_config_student.use_moe:
             ce_loss = ce_loss_raw + res.aux_loss
         else:
             ce_loss = ce_loss_raw
@@ -124,13 +125,12 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
             raw_model = model.module if isinstance(model, DistributedDataParallel) else model
             raw_model = getattr(raw_model, '_orig_mod', raw_model)
             state_dict = raw_model.state_dict()
-            state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
-            torch.save(state_dict, ckp)
+            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
             lm_checkpoint(lm_config_student, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
             model.train()
             del state_dict
 
-        del X, Y, loss_mask, res, student_logits, teacher_logits, ce_loss, distill_loss, loss
+        del input_ids, labels, loss_mask, res, student_logits, ce_loss, distill_loss, loss
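
The distillation trainer keeps a per-token reduction='none' cross-entropy so the CE and distillation terms can share one mask, so it rebuilds that mask from labels instead of using the model-internal res.loss; the + 1e-8 guards against a batch whose targets are all -100. The masked-mean pattern in isolation (the function name is illustrative):

import torch

def masked_mean(per_token_loss: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Average a per-token loss over positions whose shifted target is not -100."""
    mask = (labels[..., 1:] != -100).float().view(-1)
    return (per_token_loss.view(-1) * mask).sum() / (mask.sum() + 1e-8)  # eps: all-masked batch

Slicing both student and teacher logits with [..., :-1, :] keeps their positions aligned with the shifted targets, so the same mask serves both loss terms.
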
diff --git a/trainer/train_dpo.py b/trainer/train_dpo.py
index 52311ca..2f37851 100644
--- a/trainer/train_dpo.py
+++ b/trainer/train_dpo.py
@@ -111,8 +111,7 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
             raw_model = model.module if isinstance(model, DistributedDataParallel) else model
             raw_model = getattr(raw_model, '_orig_mod', raw_model)
             state_dict = raw_model.state_dict()
-            state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
-            torch.save(state_dict, ckp)
+            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
             lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
             model.train()
             del state_dict
diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py
index 3908213..fc68cd7 100644
--- a/trainer/train_full_sft.py
+++ b/trainer/train_full_sft.py
@@ -21,25 +21,17 @@ warnings.filterwarnings('ignore')
 
 
 def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
-    loss_fct = nn.CrossEntropyLoss(reduction='none')
     start_time = time.time()
-    for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
-        X = X.to(args.device)
-        Y = Y.to(args.device)
-        loss_mask = loss_mask.to(args.device)
+    for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
+        input_ids = input_ids.to(args.device)
+        labels = labels.to(args.device)
         lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
 
         with autocast_ctx:
-            res = model(X)
-            loss = loss_fct(
-                res.logits.view(-1, res.logits.size(-1)),
-                Y.view(-1)
-            ).view(Y.size())
-
-            logits_loss = (loss * loss_mask).sum() / loss_mask.sum()
-            loss = logits_loss + res.aux_loss
+            res = model(input_ids, labels=labels)
+            loss = res.loss + res.aux_loss
             loss = loss / args.accumulation_steps
 
         scaler.scale(loss).backward()
@@ -56,13 +48,11 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
         if step % args.log_interval == 0 or step == iters - 1:
             spend_time = time.time() - start_time
             current_loss = loss.item() * args.accumulation_steps
-            current_logits_loss = logits_loss.item()
-            current_aux_loss = res.aux_loss.item()
+            current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
+            current_logits_loss = current_loss - current_aux_loss
             current_lr = optimizer.param_groups[-1]['lr']
             eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
-
-            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')
-
+            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
             if wandb:
                 wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss,
                            "learning_rate": current_lr, "epoch_time": eta_min})
@@ -72,14 +62,13 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
             raw_model = model.module if isinstance(model, DistributedDataParallel) else model
             raw_model = getattr(raw_model, '_orig_mod', raw_model)
             state_dict = raw_model.state_dict()
-            state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
-            torch.save(state_dict, ckp)
+            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
             lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
                           epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler)
             model.train()
             del state_dict
 
-        del X, Y, loss_mask, res, loss
+        del input_ids, labels, res, loss
 
 
 if __name__ == "__main__":
diff --git a/trainer/train_lora.py b/trainer/train_lora.py
index 9f3792f..fdf83bc 100644
--- a/trainer/train_lora.py
+++ b/trainer/train_lora.py
@@ -22,25 +22,17 @@ warnings.filterwarnings('ignore')
 
 
 def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
-    loss_fct = nn.CrossEntropyLoss(reduction='none')
     start_time = time.time()
-    for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
-        X = X.to(args.device)
-        Y = Y.to(args.device)
-        loss_mask = loss_mask.to(args.device)
+    for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
+        input_ids = input_ids.to(args.device)
+        labels = labels.to(args.device)
         lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
 
         with autocast_ctx:
-            res = model(X)
-            loss = loss_fct(
-                res.logits.view(-1, res.logits.size(-1)),
-                Y.view(-1)
-            ).view(Y.size())
-
-            logits_loss = (loss * loss_mask).sum() / loss_mask.sum()
-            loss = logits_loss + res.aux_loss
+            res = model(input_ids, labels=labels)
+            loss = res.loss + res.aux_loss
             loss = loss / args.accumulation_steps
 
         scaler.scale(loss).backward()
@@ -48,22 +40,18 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
         if (step + 1) % args.accumulation_steps == 0:
             scaler.unscale_(optimizer)
             torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
-
             scaler.step(optimizer)
             scaler.update()
-
             optimizer.zero_grad(set_to_none=True)
 
         if step % args.log_interval == 0 or step == iters - 1:
             spend_time = time.time() - start_time
             current_loss = loss.item() * args.accumulation_steps
-            current_logits_loss = logits_loss.item()
-            current_aux_loss = res.aux_loss.item()
+            current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
+            current_logits_loss = current_loss - current_aux_loss
             current_lr = optimizer.param_groups[-1]['lr']
             eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
-
-            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')
-
+            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
             if wandb:
                 wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss,
                            "learning_rate": current_lr, "epoch_time": eta_min})
 
         if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
@@ -74,7 +62,7 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
             lm_checkpoint(lm_config, weight=args.lora_name, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
             model.train()
 
-        del X, Y, loss_mask, res, loss
+        del input_ids, labels, res, loss
 
 
 if __name__ == "__main__":
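
Since forward now returns the masked cross-entropy as res.loss, the SFT and LoRA loops no longer recompute it; for logging, the CE component is recovered by subtracting the auxiliary loss after undoing the gradient-accumulation scaling. The identity, spelled out with arbitrary values:

import torch

ce, aux, accumulation_steps = torch.tensor(2.5), torch.tensor(0.1), 8
loss = (ce + aux) / accumulation_steps            # the value backward() sees
current_loss = loss.item() * accumulation_steps   # undo the scaling: ce + aux
current_logits_loss = current_loss - aux.item()   # recovers ce (up to fp error)
assert abs(current_logits_loss - ce.item()) < 1e-6
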
diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py
index 4ebceeb..7f2a584 100644
--- a/trainer/train_pretrain.py
+++ b/trainer/train_pretrain.py
@@ -21,25 +21,17 @@ warnings.filterwarnings('ignore')
 
 
 def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
-    loss_fct = nn.CrossEntropyLoss(reduction='none')
     start_time = time.time()
-    for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
-        X = X.to(args.device)
-        Y = Y.to(args.device)
-        loss_mask = loss_mask.to(args.device)
+    for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
+        input_ids = input_ids.to(args.device)
+        labels = labels.to(args.device)
         lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
 
         with autocast_ctx:
-            res = model(X)
-            loss = loss_fct(
-                res.logits.view(-1, res.logits.size(-1)),
-                Y.view(-1)
-            ).view(Y.size())
-
-            logits_loss = (loss * loss_mask).sum() / loss_mask.sum()
-            loss = logits_loss + res.aux_loss
+            res = model(input_ids, labels=labels)
+            loss = res.loss + res.aux_loss
             loss = loss / args.accumulation_steps
 
         scaler.scale(loss).backward()
@@ -56,13 +48,11 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
         if step % args.log_interval == 0 or step == iters - 1:
             spend_time = time.time() - start_time
             current_loss = loss.item() * args.accumulation_steps
-            current_logits_loss = logits_loss.item()
-            current_aux_loss = res.aux_loss.item()
+            current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
+            current_logits_loss = current_loss - current_aux_loss
             current_lr = optimizer.param_groups[-1]['lr']
             eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
-
-            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')
-
+            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
             if wandb:
                 wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss,
                            "learning_rate": current_lr, "epoch_time": eta_min})
@@ -72,13 +62,12 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
             raw_model = model.module if isinstance(model, DistributedDataParallel) else model
             raw_model = getattr(raw_model, '_orig_mod', raw_model)
             state_dict = raw_model.state_dict()
-            state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
-            torch.save(state_dict, ckp)
+            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
             lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
             model.train()
             del state_dict
 
-        del X, Y, loss_mask, res, loss
+        del input_ids, labels, res, loss
 
 
 if __name__ == "__main__":
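
Checkpoint writing is the same across all trainers: the state dict is cast to half precision on CPU inside the torch.save call, so the temporary fp16 dict is released as soon as the save returns. The pattern in isolation (the Linear stands in for the unwrapped raw_model):

import torch
import torch.nn as nn

model = nn.Linear(8, 8)                  # stand-in for the unwrapped raw_model
state_dict = model.state_dict()
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, 'checkpoint.pt')

Saving fp16 halves the file size, which suffices for inference-style reloads; lm_checkpoint separately persists the full training state (optimizer, scaler, epoch, step).
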
diff --git a/trainer/train_reason.py b/trainer/train_reason.py
index fbdaf29..80c6879 100644
--- a/trainer/train_reason.py
+++ b/trainer/train_reason.py
@@ -21,7 +21,6 @@ warnings.filterwarnings('ignore')
 
 
 def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=None):
-    # Placeholders for the think tags
     start_of_think_ids = tokenizer('<think>').input_ids
     end_of_think_ids = tokenizer('</think>').input_ids
     start_of_answer_ids = tokenizer('<answer>').input_ids
@@ -29,30 +28,28 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=
     loss_fct = nn.CrossEntropyLoss(reduction='none')
     start_time = time.time()
-    for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
-        X = X.to(args.device)
-        Y = Y.to(args.device)
-        loss_mask = loss_mask.to(args.device)
+    for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
+        input_ids = input_ids.to(args.device)
+        labels = labels.to(args.device)
         lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
 
         with autocast_ctx:
-            res = model(X)
-            loss = loss_fct(
-                res.logits.view(-1, res.logits.size(-1)),
-                Y.view(-1)
-            ).view(Y.size())
+            res = model(input_ids)
+            shift_logits = res.logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size())
 
-            # Up-weight the special tag positions (specific to reasoning distillation)
-            sp_ids = torch.isin(Y.view(-1),
+            loss_mask = (shift_labels != -100).float()
+            sp_ids = torch.isin(shift_labels.view(-1),
                                 torch.tensor(start_of_think_ids + end_of_think_ids
                                              + start_of_answer_ids + end_of_answer_ids
                                              ).to(args.device))
-            loss_mask = loss_mask.view(-1)
-            loss_mask_sum = loss_mask.sum()
-            loss_mask[sp_ids] = 10  # 10x weight on the think tags
-            loss_mask = loss_mask.view(Y.size())
+            loss_mask_flat = loss_mask.view(-1)
+            loss_mask_sum = loss_mask_flat.sum()
+            loss_mask_flat[sp_ids] = 10
+            loss_mask = loss_mask_flat.view(shift_labels.size())
             logits_loss = (loss * loss_mask).sum() / loss_mask_sum
             loss = logits_loss + res.aux_loss
             loss = loss / args.accumulation_steps
@@ -69,13 +66,11 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=
         if step % args.log_interval == 0 or step == iters - 1:
             spend_time = time.time() - start_time
             current_loss = loss.item() * args.accumulation_steps
+            current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
             current_logits_loss = logits_loss.item()
-            current_aux_loss = res.aux_loss.item()
             current_lr = optimizer.param_groups[-1]['lr']
             eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
-
-            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')
-
+            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
             if wandb:
                 wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss,
                            "learning_rate": current_lr, "epoch_time": eta_min})
@@ -85,13 +80,12 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=
             raw_model = model.module if isinstance(model, DistributedDataParallel) else model
             raw_model = getattr(raw_model, '_orig_mod', raw_model)
             state_dict = raw_model.state_dict()
-            state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
-            torch.save(state_dict, ckp)
+            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
             lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
             model.train()
             del state_dict
 
-        del X, Y, loss_mask, res, loss
+        del input_ids, labels, res, loss
 
 
 if __name__ == "__main__":
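
The reasoning trainer cannot use the model-internal loss because it re-weights individual positions: the tag tokens (<think>, </think>, <answer>, </answer>) get 10x weight while the denominator keeps the unweighted token count, so the tag positions genuinely raise the objective rather than being renormalized away. The weighting step in isolation; the tag ids and labels below are made up:

import torch

special_ids = torch.tensor([10, 11, 12, 13])       # stand-ins for the tag token ids
shift_labels = torch.tensor([[-100, 10, 42, 11]])  # -100 = ignored position

loss_mask = (shift_labels != -100).float()
sp_ids = torch.isin(shift_labels.view(-1), special_ids)
loss_mask_flat = loss_mask.view(-1)
loss_mask_sum = loss_mask_flat.sum()               # denominator: unweighted count
loss_mask_flat[sp_ids] = 10                        # 10x weight on tag positions
loss_mask = loss_mask_flat.view(shift_labels.size())
# (per_token_loss * loss_mask).sum() / loss_mask_sum would follow here
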