mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-04-23 15:58:15 +08:00
[fix] model device
This commit is contained in:
parent
acd5925193
commit
1713c24114
@ -140,7 +140,7 @@ if __name__ == "__main__":
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 定义模型、数据、优化器 ==========
|
||||
model, tokenizer = init_model(lm_config, args.from_weight)
|
||||
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
|
||||
@ -184,9 +184,9 @@ if __name__ == "__main__":
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 定义学生和教师模型 ==========
|
||||
model, tokenizer = init_model(lm_config_student, args.from_student_weight)
|
||||
model, tokenizer = init_model(lm_config_student, args.from_student_weight, device=args.device)
|
||||
Logger(f'学生模型总参数量:{sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
|
||||
teacher_model, _ = init_model(lm_config_teacher, args.from_teacher_weight)
|
||||
teacher_model, _ = init_model(lm_config_teacher, args.from_teacher_weight, device=args.device)
|
||||
teacher_model.eval()
|
||||
teacher_model.requires_grad_(False)
|
||||
Logger(f'教师模型总参数量:{sum(p.numel() for p in teacher_model.parameters()) / 1e6:.3f} M')
|
||||
|
||||
@ -166,10 +166,10 @@ if __name__ == "__main__":
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 定义模型和参考模型 ==========
|
||||
model, tokenizer = init_model(lm_config, args.from_weight)
|
||||
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||
Logger(f'策略模型总参数量:{sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
|
||||
# 初始化参考模型(ref_model冻结)
|
||||
ref_model, _ = init_model(lm_config, args.from_weight)
|
||||
ref_model, _ = init_model(lm_config, args.from_weight, device=args.device)
|
||||
ref_model.eval()
|
||||
ref_model.requires_grad_(False)
|
||||
Logger(f'参考模型总参数量:{sum(p.numel() for p in ref_model.parameters()) / 1e6:.3f} M')
|
||||
|
||||
@ -128,7 +128,7 @@ if __name__ == "__main__":
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 定义模型、数据、优化器 ==========
|
||||
model, tokenizer = init_model(lm_config, args.from_weight)
|
||||
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
|
||||
@ -19,7 +19,7 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from transformers import AutoModel
|
||||
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||
from dataset.lm_dataset import RLAIFDataset
|
||||
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler
|
||||
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
@ -240,25 +240,17 @@ if __name__ == "__main__":
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 初始化模型和数据 ==========
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model/')
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
base_weight = "reason" if args.reasoning == 1 else "full_sft"
|
||||
ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
# Policy模型
|
||||
model = MiniMindForCausalLM(lm_config)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model = model.to(args.device)
|
||||
Logger(f'Policy模型总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M')
|
||||
model, tokenizer = init_model(lm_config, base_weight, device=args.device)
|
||||
# Reference模型
|
||||
ref_model = MiniMindForCausalLM(lm_config)
|
||||
ref_model.load_state_dict(state_dict, strict=False)
|
||||
ref_model.eval().requires_grad_(False)
|
||||
ref_model = ref_model.to(args.device)
|
||||
ref_model, _ = init_model(lm_config, base_weight, device=args.device)
|
||||
ref_model = ref_model.eval().requires_grad_(False)
|
||||
# Reward模型
|
||||
reward_model = AutoModel.from_pretrained(
|
||||
args.reward_model_path, device_map="cuda", torch_dtype=torch.float16, trust_remote_code=True
|
||||
).to(args.device).eval().requires_grad_(False)
|
||||
args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True
|
||||
)
|
||||
reward_model = reward_model.to(args.device).eval().requires_grad_(False)
|
||||
reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
|
||||
# 数据和优化器
|
||||
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
|
||||
|
||||
@ -123,7 +123,7 @@ if __name__ == "__main__":
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 定义模型、应用LoRA、冻结非LoRA参数 ==========
|
||||
model, tokenizer = init_model(lm_config, args.from_weight)
|
||||
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||
apply_lora(model)
|
||||
|
||||
# 统计参数
|
||||
|
||||
@ -20,7 +20,7 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from transformers import AutoModel
|
||||
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||
from dataset.lm_dataset import RLAIFDataset
|
||||
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler
|
||||
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
@ -290,33 +290,28 @@ if __name__ == "__main__":
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 初始化模型和数据 ==========
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model/', padding_side='left')
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
base_weight = "reason" if args.reasoning == 1 else "full_sft"
|
||||
# Actor模型
|
||||
actor_model, tokenizer = init_model(lm_config, base_weight, device=args.device)
|
||||
tokenizer.padding_side = 'left' # PPO需要左侧padding
|
||||
# Old Actor模型
|
||||
old_actor_model, _ = init_model(lm_config, base_weight, device=args.device)
|
||||
old_actor_model = old_actor_model.eval().requires_grad_(False)
|
||||
# Reference模型
|
||||
ref_model, _ = init_model(lm_config, base_weight, device=args.device)
|
||||
ref_model = ref_model.eval().requires_grad_(False)
|
||||
# Critic模型
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
# Actor模型
|
||||
actor_model = MiniMindForCausalLM(lm_config)
|
||||
actor_model.load_state_dict(state_dict, strict=False)
|
||||
actor_model = actor_model.to(args.device)
|
||||
Logger(f'Actor模型总参数量:{sum(p.numel() for p in actor_model.parameters() if p.requires_grad) / 1e6:.3f} M')
|
||||
# Old Actor模型
|
||||
old_actor_model = MiniMindForCausalLM(lm_config)
|
||||
old_actor_model.load_state_dict(state_dict, strict=False)
|
||||
old_actor_model = old_actor_model.eval().requires_grad_(False).to(args.device)
|
||||
# Reference模型
|
||||
ref_model = MiniMindForCausalLM(lm_config)
|
||||
ref_model.load_state_dict(state_dict, strict=False)
|
||||
ref_model = ref_model.eval().requires_grad_(False).to(args.device)
|
||||
# Critic模型
|
||||
critic_model = CriticModel(lm_config)
|
||||
critic_model.load_state_dict(state_dict, strict=False)
|
||||
critic_model = critic_model.to(args.device)
|
||||
Logger(f'Critic模型总参数量:{sum(p.numel() for p in critic_model.parameters() if p.requires_grad) / 1e6:.3f} M')
|
||||
# Reward模型
|
||||
reward_model = AutoModel.from_pretrained(
|
||||
args.reward_model_path, device_map="cuda", torch_dtype=torch.float32, trust_remote_code=True
|
||||
).to(args.device).eval().requires_grad_(False)
|
||||
args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True
|
||||
)
|
||||
reward_model = reward_model.to(args.device).eval().requires_grad_(False)
|
||||
reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
|
||||
# 数据和优化器
|
||||
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=(args.max_seq_len + args.max_gen_len))
|
||||
|
||||
@ -127,7 +127,7 @@ if __name__ == "__main__":
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 定义模型、数据、优化器 ==========
|
||||
model, tokenizer = init_model(lm_config, args.from_weight)
|
||||
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
|
||||
@ -19,7 +19,7 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from transformers import AutoModel
|
||||
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||
from dataset.lm_dataset import RLAIFDataset
|
||||
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler
|
||||
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
@ -287,25 +287,17 @@ if __name__ == "__main__":
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 初始化模型(Policy, Ref, Reward)和Value Tracker、数据 ==========
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model/')
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
base_weight = "reason" if args.reasoning == 1 else "full_sft"
|
||||
ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
# Policy模型
|
||||
model = MiniMindForCausalLM(lm_config)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model = model.to(args.device)
|
||||
Logger(f'Policy模型总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M')
|
||||
model, tokenizer = init_model(lm_config, base_weight, device=args.device)
|
||||
# Reference模型
|
||||
ref_model = MiniMindForCausalLM(lm_config)
|
||||
ref_model.load_state_dict(state_dict, strict=False)
|
||||
ref_model.eval().requires_grad_(False)
|
||||
ref_model = ref_model.to(args.device)
|
||||
ref_model, _ = init_model(lm_config, base_weight, device=args.device)
|
||||
ref_model = ref_model.eval().requires_grad_(False)
|
||||
# Reward模型
|
||||
reward_model = AutoModel.from_pretrained(
|
||||
args.reward_model_path, device_map="cuda", torch_dtype=torch.float16, trust_remote_code=True
|
||||
).to(args.device).eval().requires_grad_(False)
|
||||
args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True
|
||||
)
|
||||
reward_model = reward_model.to(args.device).eval().requires_grad_(False)
|
||||
reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
|
||||
# Value Tracker
|
||||
value_tracker = AutoAdaptiveValueTracker(rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user