diff --git a/dataset/lm_dataset.py b/dataset/lm_dataset.py
index a958303..21e321c 100644
--- a/dataset/lm_dataset.py
+++ b/dataset/lm_dataset.py
@@ -16,8 +16,6 @@ class PretrainDataset(Dataset):
def __getitem__(self, index):
sample = self.samples[index]
-
- # 构建输入文本
encoding = self.tokenizer(
str(sample['text']),
max_length=self.max_length,
@@ -26,12 +24,9 @@ class PretrainDataset(Dataset):
return_tensors='pt'
)
input_ids = encoding.input_ids.squeeze()
- loss_mask = (input_ids != self.tokenizer.pad_token_id)
-
- X = torch.tensor(input_ids[:-1], dtype=torch.long)
- Y = torch.tensor(input_ids[1:], dtype=torch.long)
- loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)
- return X, Y, loss_mask
+ labels = input_ids.clone()
+ labels[input_ids == self.tokenizer.pad_token_id] = -100
+ return input_ids, labels
class SFTDataset(Dataset):
@@ -56,8 +51,8 @@ class SFTDataset(Dataset):
tools=tools
)
- def generate_loss_mask(self, input_ids):
- loss_mask = [0] * len(input_ids)
+ def generate_labels(self, input_ids):
+ labels = [-100] * len(input_ids)
i = 0
while i < len(input_ids):
if input_ids[i:i + len(self.bos_id)] == self.bos_id:
@@ -68,29 +63,24 @@ class SFTDataset(Dataset):
break
end += 1
for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
- loss_mask[j] = 1
+ labels[j] = input_ids[j]
i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
else:
i += 1
- return loss_mask
+ return labels
def __getitem__(self, index):
sample = self.samples[index]
prompt = self.create_chat_prompt(sample['conversations'])
input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
- loss_mask = self.generate_loss_mask(input_ids)
-
- # 构建训练数据
- X = torch.tensor(input_ids[:-1], dtype=torch.long)
- Y = torch.tensor(input_ids[1:], dtype=torch.long)
- loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long) # 对齐预测位置
- # # === 打印每个token的掩码情况 ===
+ labels = self.generate_labels(input_ids)
+ # # === 调试打印 ===
# print(f"\n--- Sample {index} ---")
- # for i, (x, y, m) in enumerate(zip(X, Y, loss_mask)):
- # print(f"{i:3d}: X={self.tokenizer.decode([x])!r:16s} ---> Y={self.tokenizer.decode([y])!r:16s} mask={m}")
- # # ================================
- return X, Y, loss_mask
+ # for i, (x, y) in enumerate(zip(input_ids[:-1], labels[1:])):
+ # print(f"{i:3d}: X={self.tokenizer.decode([x])!r:16s} ---> Y={self.tokenizer.decode([input_ids[i+1]])!r:16s} label={y}")
+ # # ================
+ return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)
class DPODataset(Dataset):
diff --git a/model/model_minimind.py b/model/model_minimind.py
index d826f82..b3910a8 100755
--- a/model/model_minimind.py
+++ b/model/model_minimind.py
@@ -437,6 +437,7 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
def forward(self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
logits_to_keep: Union[int, torch.Tensor] = 0,
@@ -450,6 +451,13 @@ class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
- output = CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
+
+ loss = None
+ if labels is not None:
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100)
+
+ output = CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
output.aux_loss = aux_loss
return output
diff --git a/trainer/train_distillation.py b/trainer/train_distillation.py
index abae2c1..2b8afc3 100644
--- a/trainer/train_distillation.py
+++ b/trainer/train_distillation.py
@@ -42,36 +42,37 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
teacher_model.eval()
teacher_model.requires_grad_(False)
- for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
- X = X.to(args.device)
- Y = Y.to(args.device)
- loss_mask = loss_mask.to(args.device)
+ for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
+ input_ids = input_ids.to(args.device)
+ labels = labels.to(args.device)
+ loss_mask = (labels[..., 1:] != -100).float()
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# 前向传播(学生模型)
with autocast_ctx:
- res = model(X)
- student_logits = res.logits
+ res = model(input_ids)
+ student_logits = res.logits[..., :-1, :].contiguous()
# 教师模型前向传播(只在eval & no_grad)
if teacher_model is not None:
with torch.no_grad():
- teacher_logits = teacher_model(X).logits
+ teacher_logits = teacher_model(input_ids).logits[..., :-1, :].contiguous()
vocab_size_student = student_logits.size(-1)
teacher_logits = teacher_logits[..., :vocab_size_student]
# ========== 计算损失 ==========
# 1) Ground-Truth CE Loss
+ shift_labels = labels[..., 1:].contiguous()
loss_mask_flat = loss_mask.view(-1)
ce_loss = F.cross_entropy(
student_logits.view(-1, student_logits.size(-1)),
- Y.view(-1),
- ignore_index=0,
+ shift_labels.view(-1),
+ ignore_index=-100,
reduction='none'
)
- ce_loss_raw = torch.sum(ce_loss * loss_mask_flat) / loss_mask_flat.sum()
+ ce_loss_raw = torch.sum(ce_loss * loss_mask_flat) / (loss_mask_flat.sum() + 1e-8)
if lm_config_student.use_moe: ce_loss = ce_loss_raw + res.aux_loss
else: ce_loss = ce_loss_raw
@@ -124,13 +125,12 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
- state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
- torch.save(state_dict, ckp)
+ torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config_student, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
del state_dict
- del X, Y, loss_mask, res, student_logits, teacher_logits, ce_loss, distill_loss, loss
+ del input_ids, labels, loss_mask, res, student_logits, ce_loss, distill_loss, loss
if __name__ == "__main__":
diff --git a/trainer/train_dpo.py b/trainer/train_dpo.py
index 52311ca..2f37851 100644
--- a/trainer/train_dpo.py
+++ b/trainer/train_dpo.py
@@ -111,8 +111,7 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
- state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
- torch.save(state_dict, ckp)
+ torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
del state_dict
diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py
index 3908213..fc68cd7 100644
--- a/trainer/train_full_sft.py
+++ b/trainer/train_full_sft.py
@@ -21,25 +21,17 @@ warnings.filterwarnings('ignore')
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
- loss_fct = nn.CrossEntropyLoss(reduction='none')
start_time = time.time()
- for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
- X = X.to(args.device)
- Y = Y.to(args.device)
- loss_mask = loss_mask.to(args.device)
+ for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
+ input_ids = input_ids.to(args.device)
+ labels = labels.to(args.device)
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with autocast_ctx:
- res = model(X)
- loss = loss_fct(
- res.logits.view(-1, res.logits.size(-1)),
- Y.view(-1)
- ).view(Y.size())
-
- logits_loss = (loss * loss_mask).sum() / loss_mask.sum()
- loss = logits_loss + res.aux_loss
+ res = model(input_ids, labels=labels)
+ loss = res.loss + res.aux_loss
loss = loss / args.accumulation_steps
scaler.scale(loss).backward()
@@ -56,13 +48,11 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
- current_logits_loss = logits_loss.item()
- current_aux_loss = res.aux_loss.item()
+ current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
+ current_logits_loss = current_loss - current_aux_loss
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
-
- Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')
-
+ Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
@@ -72,14 +62,13 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
- state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
- torch.save(state_dict, ckp)
+ torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler)
model.train()
del state_dict
- del X, Y, loss_mask, res, loss
+ del input_ids, labels, res, loss
if __name__ == "__main__":
diff --git a/trainer/train_lora.py b/trainer/train_lora.py
index 9f3792f..fdf83bc 100644
--- a/trainer/train_lora.py
+++ b/trainer/train_lora.py
@@ -22,25 +22,17 @@ warnings.filterwarnings('ignore')
def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
- loss_fct = nn.CrossEntropyLoss(reduction='none')
start_time = time.time()
- for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
- X = X.to(args.device)
- Y = Y.to(args.device)
- loss_mask = loss_mask.to(args.device)
+ for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
+ input_ids = input_ids.to(args.device)
+ labels = labels.to(args.device)
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with autocast_ctx:
- res = model(X)
- loss = loss_fct(
- res.logits.view(-1, res.logits.size(-1)),
- Y.view(-1)
- ).view(Y.size())
-
- logits_loss = (loss * loss_mask).sum() / loss_mask.sum()
- loss = logits_loss + res.aux_loss
+ res = model(input_ids, labels=labels)
+ loss = res.loss + res.aux_loss
loss = loss / args.accumulation_steps
scaler.scale(loss).backward()
@@ -48,22 +40,18 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
if (step + 1) % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
-
scaler.step(optimizer)
scaler.update()
-
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
- current_logits_loss = logits_loss.item()
- current_aux_loss = res.aux_loss.item()
+ current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
+ current_logits_loss = current_loss - current_aux_loss
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
-
- Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')
-
+ Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
@@ -74,7 +62,7 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
lm_checkpoint(lm_config, weight=args.lora_name, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
- del X, Y, loss_mask, res, loss
+ del input_ids, labels, res, loss
if __name__ == "__main__":
diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py
index 4ebceeb..7f2a584 100644
--- a/trainer/train_pretrain.py
+++ b/trainer/train_pretrain.py
@@ -21,25 +21,17 @@ warnings.filterwarnings('ignore')
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
- loss_fct = nn.CrossEntropyLoss(reduction='none')
start_time = time.time()
- for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
- X = X.to(args.device)
- Y = Y.to(args.device)
- loss_mask = loss_mask.to(args.device)
+ for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
+ input_ids = input_ids.to(args.device)
+ labels = labels.to(args.device)
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with autocast_ctx:
- res = model(X)
- loss = loss_fct(
- res.logits.view(-1, res.logits.size(-1)),
- Y.view(-1)
- ).view(Y.size())
-
- logits_loss = (loss * loss_mask).sum() / loss_mask.sum()
- loss = logits_loss + res.aux_loss
+ res = model(input_ids, labels=labels)
+ loss = res.loss + res.aux_loss
loss = loss / args.accumulation_steps
scaler.scale(loss).backward()
@@ -56,13 +48,11 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
- current_logits_loss = logits_loss.item()
- current_aux_loss = res.aux_loss.item()
+ current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
+ current_logits_loss = current_loss - current_aux_loss
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
-
- Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')
-
+ Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
@@ -72,13 +62,12 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
- state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
- torch.save(state_dict, ckp)
+ torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
del state_dict
- del X, Y, loss_mask, res, loss
+ del input_ids, labels, res, loss
if __name__ == "__main__":
diff --git a/trainer/train_reason.py b/trainer/train_reason.py
index fbdaf29..80c6879 100644
--- a/trainer/train_reason.py
+++ b/trainer/train_reason.py
@@ -21,7 +21,6 @@ warnings.filterwarnings('ignore')
def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=None):
- # 思考标签占位符
    start_of_think_ids = tokenizer('<think>').input_ids
    end_of_think_ids = tokenizer('</think>').input_ids
    start_of_answer_ids = tokenizer('<answer>').input_ids
@@ -29,30 +28,28 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=
loss_fct = nn.CrossEntropyLoss(reduction='none')
start_time = time.time()
- for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
- X = X.to(args.device)
- Y = Y.to(args.device)
- loss_mask = loss_mask.to(args.device)
+ for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
+ input_ids = input_ids.to(args.device)
+ labels = labels.to(args.device)
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with autocast_ctx:
- res = model(X)
- loss = loss_fct(
- res.logits.view(-1, res.logits.size(-1)),
- Y.view(-1)
- ).view(Y.size())
+ res = model(input_ids)
+ shift_logits = res.logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size())
- # 特殊标签位置增加权重(推理蒸馏特有)
- sp_ids = torch.isin(Y.view(-1),
+ loss_mask = (shift_labels != -100).float()
+ sp_ids = torch.isin(shift_labels.view(-1),
torch.tensor(start_of_think_ids + end_of_think_ids
+ start_of_answer_ids + end_of_answer_ids
).to(args.device))
- loss_mask = loss_mask.view(-1)
- loss_mask_sum = loss_mask.sum()
- loss_mask[sp_ids] = 10 # 对思考标签增加10倍权重
- loss_mask = loss_mask.view(Y.size())
+ loss_mask_flat = loss_mask.view(-1)
+ loss_mask_sum = loss_mask_flat.sum()
+ loss_mask_flat[sp_ids] = 10
+ loss_mask = loss_mask_flat.view(shift_labels.size())
logits_loss = (loss * loss_mask).sum() / loss_mask_sum
loss = logits_loss + res.aux_loss
loss = loss / args.accumulation_steps
@@ -69,13 +66,11 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
+ current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
current_logits_loss = logits_loss.item()
- current_aux_loss = res.aux_loss.item()
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
-
- Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')
-
+ Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
@@ -85,13 +80,12 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
- state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
- torch.save(state_dict, ckp)
+ torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
del state_dict
- del X, Y, loss_mask, res, loss
+ del input_ids, labels, res, loss
if __name__ == "__main__":