Mirror of https://github.com/jingyaogong/minimind.git
(synced 2026-04-25 08:48:16 +08:00)
[update] harden training and inference reliability
Fix truncated SFT/DPO label masking and make gradient accumulation resume-safe to prevent unstable optimization. Also add explicit eval backend selection, plus regression checks for safer local validation.
This commit is contained in:
parent
25a7edcd6f
commit
d37bfa9d75
64
REGRESSION_CHECKS.md
Normal file
64
REGRESSION_CHECKS.md
Normal file
@ -0,0 +1,64 @@
|
||||
# Regression Checks for Recent Fixes
|
||||
|
||||
This file records lightweight checks for the recent high-priority fixes.
|
||||
|
||||
## 1) Syntax checks
|
||||
|
||||
```bash
|
||||
python -m py_compile dataset/lm_dataset.py trainer/train_pretrain.py trainer/train_full_sft.py trainer/train_lora.py trainer/train_distillation.py trainer/train_dpo.py trainer/train_grpo.py trainer/train_agent.py eval_llm.py
|
||||
```
|
||||
|
||||
## 2) Dataset label/mask boundary check (no eos after truncation)
|
||||
|
||||
```bash
|
||||
python - <<'PY'
|
||||
from dataset.lm_dataset import SFTDataset, DPODataset
|
||||
|
||||
s = SFTDataset.__new__(SFTDataset)
|
||||
s.bos_id = [11, 12]
|
||||
s.eos_id = [13]
|
||||
s.max_length = 10
|
||||
raw_ids = [99, 11, 12, 21, 22] # truncated without eos
|
||||
labels = s.generate_labels(raw_ids)
|
||||
pad_len = s.max_length - len(raw_ids)
|
||||
labels = labels + ([-100] * pad_len)
|
||||
print("sft_pad_labels_ok", labels[5:] == [-100] * pad_len)
|
||||
|
||||
d = DPODataset.__new__(DPODataset)
|
||||
d.bos_id = [11, 12]
|
||||
d.eos_id = [13]
|
||||
d.max_length = 10
|
||||
mask = d.generate_loss_mask(raw_ids)
|
||||
mask = mask + ([0] * pad_len)
|
||||
print("dpo_pad_mask_ok", mask[5:] == [0] * pad_len)
|
||||
PY
|
||||
```
|
||||
|
||||
Expected output:
|
||||
|
||||
```text
|
||||
sft_pad_labels_ok True
|
||||
dpo_pad_mask_ok True
|
||||
```
|
||||
|
||||
## 3) Eval backend argument check
|
||||
|
||||
```bash
|
||||
python eval_llm.py -h
|
||||
```
|
||||
|
||||
Expected: help output includes `--backend {auto,torch,hf}`.
|
||||
|
||||
## 4) Manual inference smoke tests
|
||||
|
||||
HF backend:
|
||||
|
||||
```bash
|
||||
python eval_llm.py --backend hf --load_from ./minimind-3 --max_new_tokens 64 --temperature 0.2 --top_p 0.8
|
||||
```
|
||||
|
||||
Torch backend:
|
||||
|
||||
```bash
|
||||
python eval_llm.py --backend torch --load_from model --weight full_sft --max_new_tokens 64 --temperature 0.2 --top_p 0.8
|
||||
```
|
||||
@ -92,11 +92,19 @@ class SFTDataset(Dataset):
|
||||
if input_ids[i:i + len(self.bos_id)] == self.bos_id:
|
||||
start = i + len(self.bos_id)
|
||||
end = start
|
||||
found_eos = False
|
||||
while end < len(input_ids):
|
||||
if input_ids[end:end + len(self.eos_id)] == self.eos_id:
|
||||
found_eos = True
|
||||
break
|
||||
end += 1
|
||||
for j in range(start, min(end + len(self.eos_id), self.max_length)):
|
||||
# If eos is not found (e.g. after truncation), only supervise real tokens,
|
||||
# never extend to padding positions.
|
||||
if found_eos:
|
||||
supervise_end = min(end + len(self.eos_id), len(input_ids))
|
||||
else:
|
||||
supervise_end = min(end, len(input_ids))
|
||||
for j in range(start, supervise_end):
|
||||
labels[j] = input_ids[j]
|
||||
i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
|
||||
else:
|
||||
@ -109,8 +117,10 @@ class SFTDataset(Dataset):
|
||||
prompt = self.create_chat_prompt(conversations)
|
||||
prompt = post_processing_chat(prompt)
|
||||
input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
|
||||
input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
|
||||
labels = self.generate_labels(input_ids)
|
||||
pad_len = self.max_length - len(input_ids)
|
||||
input_ids += [self.tokenizer.pad_token_id] * pad_len
|
||||
labels += [-100] * pad_len
|
||||
# # === 调试打印 ===
|
||||
# print(f"\n--- Sample {index} ---")
|
||||
# for i, (x, y) in enumerate(zip(input_ids[:-1], labels[1:])):
|
||||
@ -145,18 +155,21 @@ class DPODataset(Dataset):
|
||||
rejected, tokenize=False, add_generation_prompt=False
|
||||
)
|
||||
rejected_prompt = post_processing_chat(rejected_prompt)
|
||||
chosen_encoding = self.tokenizer(
|
||||
chosen_prompt, truncation=True, max_length=self.max_length, padding='max_length'
|
||||
)
|
||||
rejected_encoding = self.tokenizer(
|
||||
rejected_prompt, truncation=True, max_length=self.max_length, padding='max_length'
|
||||
)
|
||||
|
||||
chosen_input_ids = chosen_encoding['input_ids']
|
||||
chosen_input_ids = self.tokenizer(
|
||||
chosen_prompt, truncation=True, max_length=self.max_length, padding=False
|
||||
)['input_ids']
|
||||
chosen_loss_mask = self.generate_loss_mask(chosen_input_ids)
|
||||
chosen_pad_len = self.max_length - len(chosen_input_ids)
|
||||
chosen_input_ids += [self.padding] * chosen_pad_len
|
||||
chosen_loss_mask += [0] * chosen_pad_len
|
||||
|
||||
rejected_input_ids = rejected_encoding['input_ids']
|
||||
rejected_input_ids = self.tokenizer(
|
||||
rejected_prompt, truncation=True, max_length=self.max_length, padding=False
|
||||
)['input_ids']
|
||||
rejected_loss_mask = self.generate_loss_mask(rejected_input_ids)
|
||||
rejected_pad_len = self.max_length - len(rejected_input_ids)
|
||||
rejected_input_ids += [self.padding] * rejected_pad_len
|
||||
rejected_loss_mask += [0] * rejected_pad_len
|
||||
x_chosen = torch.tensor(chosen_input_ids[:-1], dtype=torch.long)
|
||||
y_chosen = torch.tensor(chosen_input_ids[1:], dtype=torch.long)
|
||||
mask_chosen = torch.tensor(chosen_loss_mask[1:], dtype=torch.long)
|
||||
@ -180,11 +193,17 @@ class DPODataset(Dataset):
|
||||
if input_ids[i:i + len(self.bos_id)] == self.bos_id:
|
||||
start = i + len(self.bos_id)
|
||||
end = start
|
||||
found_eos = False
|
||||
while end < len(input_ids):
|
||||
if input_ids[end:end + len(self.eos_id)] == self.eos_id:
|
||||
found_eos = True
|
||||
break
|
||||
end += 1
|
||||
for j in range(start, min(end + len(self.eos_id), self.max_length)):
|
||||
if found_eos:
|
||||
supervise_end = min(end + len(self.eos_id), len(input_ids))
|
||||
else:
|
||||
supervise_end = min(end, len(input_ids))
|
||||
for j in range(start, supervise_end):
|
||||
loss_mask[j] = 1
|
||||
i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
|
||||
else:
|
||||
|
||||
15
eval_llm.py
15
eval_llm.py
@ -2,6 +2,7 @@ import time
|
||||
import argparse
|
||||
import random
|
||||
import warnings
|
||||
import os
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
|
||||
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||
@ -11,7 +12,18 @@ warnings.filterwarnings('ignore')
|
||||
|
||||
def init_model(args):
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.load_from)
|
||||
if 'model' in args.load_from:
|
||||
use_torch_backend = args.backend == 'torch'
|
||||
if args.backend == 'auto':
|
||||
if args.load_from == 'model':
|
||||
use_torch_backend = True
|
||||
elif os.path.isdir(args.load_from):
|
||||
has_hf_config = os.path.exists(os.path.join(args.load_from, "config.json"))
|
||||
has_hf_weight = os.path.exists(os.path.join(args.load_from, "model.safetensors")) or os.path.exists(os.path.join(args.load_from, "pytorch_model.bin"))
|
||||
use_torch_backend = not (has_hf_config or has_hf_weight)
|
||||
else:
|
||||
use_torch_backend = False
|
||||
|
||||
if use_torch_backend:
|
||||
model = MiniMindForCausalLM(MiniMindConfig(
|
||||
hidden_size=args.hidden_size,
|
||||
num_hidden_layers=args.num_hidden_layers,
|
||||
@ -32,6 +44,7 @@ def init_model(args):
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="MiniMind模型推理与对话")
|
||||
parser.add_argument('--load_from', default='model', type=str, help="模型加载路径(model=原生torch权重,其他路径=transformers格式)")
|
||||
parser.add_argument('--backend', default='auto', choices=['auto', 'torch', 'hf'], help="模型加载后端(auto/torch/hf)")
|
||||
parser.add_argument('--save_dir', default='out', type=str, help="模型权重目录")
|
||||
parser.add_argument('--weight', default='full_sft', type=str, help="权重名称前缀(pretrain, full_sft, rlhf, reason, ppo_actor, grpo, spo)")
|
||||
parser.add_argument('--lora_weight', default='None', type=str, help="LoRA权重名称(None表示不使用,可选:lora_identity, lora_medical)")
|
||||
|
||||
@ -240,6 +240,7 @@ def calculate_rewards(prompts, completions, gt_batch, tools_batch, num_gen, rewa
|
||||
# ================================ 工具与 Reward = End ================================
|
||||
def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model=None, start_step=0, wandb=None, use_sglang=False):
|
||||
last_step = start_step
|
||||
accum_counter = start_step % args.accumulation_steps
|
||||
for step, batch in enumerate(loader, start=start_step + 1):
|
||||
messages_batch = batch['messages']
|
||||
tools_batch = batch['tools']
|
||||
@ -328,8 +329,9 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
|
||||
if valid_rows.any() else per_token_loss.sum() * 0.0)
|
||||
loss = (policy_loss + aux_loss) / args.accumulation_steps
|
||||
loss.backward()
|
||||
accum_counter += 1
|
||||
|
||||
if step % args.accumulation_steps == 0:
|
||||
if accum_counter % args.accumulation_steps == 0:
|
||||
if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
optimizer.step(); scheduler.step(); optimizer.zero_grad()
|
||||
if is_main_process() and step % args.save_interval == 0: rollout_engine.update_policy(model)
|
||||
@ -362,7 +364,7 @@ def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model
|
||||
del per_token_logps, ref_per_token_logps
|
||||
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
|
||||
|
||||
if last_step > start_step and last_step % args.accumulation_steps != 0:
|
||||
if last_step > start_step and accum_counter % args.accumulation_steps != 0:
|
||||
if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
optimizer.step(); scheduler.step(); optimizer.zero_grad()
|
||||
if is_main_process() and last_step % args.save_interval == 0: rollout_engine.update_policy(model)
|
||||
|
||||
@ -38,6 +38,7 @@ def distillation_loss(student_logits, teacher_logits, temperature=1.0, reduction
|
||||
def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_step=0, wandb=None, alpha=0.0, temperature=1.0):
|
||||
start_time = time.time()
|
||||
last_step = start_step
|
||||
accum_counter = start_step % args.accumulation_steps
|
||||
|
||||
if teacher_model is not None:
|
||||
teacher_model.eval()
|
||||
@ -92,8 +93,9 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
|
||||
loss = (alpha * ce_loss + (1 - alpha) * distill_loss) / args.accumulation_steps
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
accum_counter += 1
|
||||
|
||||
if step % args.accumulation_steps == 0:
|
||||
if accum_counter % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
scaler.step(optimizer)
|
||||
@ -134,7 +136,7 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
|
||||
|
||||
del input_ids, labels, loss_mask, res, student_logits, ce_loss, distill_loss, loss
|
||||
|
||||
if last_step > start_step and last_step % args.accumulation_steps != 0:
|
||||
if last_step > start_step and accum_counter % args.accumulation_steps != 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
scaler.step(optimizer)
|
||||
|
||||
@ -52,6 +52,7 @@ def dpo_loss(ref_log_probs, policy_log_probs, mask, beta):
|
||||
def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=None, beta=0.1):
|
||||
start_time = time.time()
|
||||
last_step = start_step
|
||||
accum_counter = start_step % args.accumulation_steps
|
||||
|
||||
for step, batch in enumerate(loader, start=start_step + 1):
|
||||
last_step = step
|
||||
@ -84,8 +85,9 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
accum_counter += 1
|
||||
|
||||
if step % args.accumulation_steps == 0:
|
||||
if accum_counter % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
scaler.step(optimizer)
|
||||
@ -119,7 +121,7 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
|
||||
del x_chosen, x_rejected, y_chosen, y_rejected, mask_chosen, mask_rejected, x, y, mask
|
||||
del ref_outputs, ref_logits, ref_log_probs, outputs, logits, policy_log_probs, loss
|
||||
|
||||
if last_step > start_step and last_step % args.accumulation_steps != 0:
|
||||
if last_step > start_step and accum_counter % args.accumulation_steps != 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
scaler.step(optimizer)
|
||||
|
||||
@ -23,6 +23,7 @@ warnings.filterwarnings('ignore')
|
||||
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||
start_time = time.time()
|
||||
last_step = start_step
|
||||
accum_counter = start_step % args.accumulation_steps
|
||||
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
|
||||
input_ids = input_ids.to(args.device)
|
||||
labels = labels.to(args.device)
|
||||
@ -37,8 +38,9 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
accum_counter += 1
|
||||
|
||||
if step % args.accumulation_steps == 0:
|
||||
if accum_counter % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
|
||||
@ -72,7 +74,7 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||
|
||||
del input_ids, labels, res, loss
|
||||
|
||||
if last_step > start_step and last_step % args.accumulation_steps != 0:
|
||||
if last_step > start_step and accum_counter % args.accumulation_steps != 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
scaler.step(optimizer)
|
||||
|
||||
@ -68,7 +68,10 @@ def calculate_rewards(prompts, responses, reward_model):
|
||||
|
||||
|
||||
def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model, start_step=0, wandb=None, use_sglang=False):
|
||||
last_step = start_step
|
||||
accum_counter = start_step % args.accumulation_steps
|
||||
for step, batch in enumerate(loader, start=start_step + 1):
|
||||
last_step = step
|
||||
prompts = batch['prompt'] # list[str], length B
|
||||
prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
|
||||
padding_side="left", add_special_tokens=False).to(args.device)
|
||||
@ -142,8 +145,9 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
|
||||
policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
|
||||
loss = (policy_loss + aux_loss) / args.accumulation_steps # scalar
|
||||
loss.backward()
|
||||
accum_counter += 1
|
||||
|
||||
if step % args.accumulation_steps == 0:
|
||||
if accum_counter % args.accumulation_steps == 0:
|
||||
if args.grad_clip > 0:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
optimizer.step()
|
||||
@ -190,7 +194,7 @@ def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_mod
|
||||
model.train()
|
||||
del state_dict
|
||||
|
||||
if step > start_step and step % args.accumulation_steps != 0:
|
||||
if last_step > start_step and accum_counter % args.accumulation_steps != 0:
|
||||
if args.grad_clip > 0:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
optimizer.step()
|
||||
|
||||
@ -24,6 +24,7 @@ warnings.filterwarnings('ignore')
|
||||
def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
|
||||
start_time = time.time()
|
||||
last_step = start_step
|
||||
accum_counter = start_step % args.accumulation_steps
|
||||
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
|
||||
input_ids = input_ids.to(args.device)
|
||||
labels = labels.to(args.device)
|
||||
@ -38,8 +39,9 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
accum_counter += 1
|
||||
|
||||
if step % args.accumulation_steps == 0:
|
||||
if accum_counter % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
|
||||
scaler.step(optimizer)
|
||||
@ -66,7 +68,7 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
|
||||
|
||||
del input_ids, labels, res, loss
|
||||
|
||||
if last_step > start_step and last_step % args.accumulation_steps != 0:
|
||||
if last_step > start_step and accum_counter % args.accumulation_steps != 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
|
||||
scaler.step(optimizer)
|
||||
|
||||
@ -23,22 +23,25 @@ warnings.filterwarnings('ignore')
|
||||
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||
start_time = time.time()
|
||||
last_step = start_step
|
||||
accum_counter = start_step % args.accumulation_steps
|
||||
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
|
||||
input_ids = input_ids.to(args.device)
|
||||
labels = labels.to(args.device)
|
||||
attention_mask = (input_ids != tokenizer.pad_token_id).long()
|
||||
last_step = step
|
||||
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
with autocast_ctx:
|
||||
res = model(input_ids, labels=labels)
|
||||
res = model(input_ids, attention_mask=attention_mask, labels=labels)
|
||||
loss = res.loss + res.aux_loss
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
accum_counter += 1
|
||||
|
||||
if step % args.accumulation_steps == 0:
|
||||
if accum_counter % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
|
||||
@ -69,9 +72,9 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||
model.train()
|
||||
del state_dict
|
||||
|
||||
del input_ids, labels, res, loss
|
||||
del input_ids, labels, attention_mask, res, loss
|
||||
|
||||
if last_step > start_step and last_step % args.accumulation_steps != 0:
|
||||
if last_step > start_step and accum_counter % args.accumulation_steps != 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
scaler.step(optimizer)
|
||||
@ -83,7 +86,7 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind Pretraining")
|
||||
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||
parser.add_argument('--save_weight', default='pretrain', type=str, help="保存权重的前缀名")
|
||||
parser.add_argument("--epochs", type=int, default=2, help="训练轮数")
|
||||
parser.add_argument("--epochs", type=int, default=3, help="训练轮数")
|
||||
parser.add_argument("--batch_size", type=int, default=32, help="batch size")
|
||||
parser.add_argument("--learning_rate", type=float, default=5e-4, help="初始学习率")
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||
@ -94,15 +97,15 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
|
||||
parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度")
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||
parser.add_argument('--num_hidden_layers', default=12, type=int, help="隐藏层数量")
|
||||
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
|
||||
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/pretrain_t2t_mini.jsonl", help="预训练数据路径")
|
||||
parser.add_argument('--from_weight', default='none', type=str, help="基于哪个权重训练,为none则从头开始")
|
||||
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
|
||||
parser.add_argument('--from_resume', default=1, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
|
||||
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb项目名")
|
||||
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
|
||||
parser.add_argument("--use_compile", default=1, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
|
||||
Loading…
Reference in New Issue
Block a user