This commit is contained in:
jingyaogong 2025-04-27 09:56:49 +08:00
parent 765c79c7a2
commit 29454c31af
3 changed files with 4 additions and 13 deletions

View File

@ -107,10 +107,10 @@ def main():
# MiniMind2-moe (145M)(hidden_size=640, num_hidden_layers=8, use_moe=True) # MiniMind2-moe (145M)(hidden_size=640, num_hidden_layers=8, use_moe=True)
# MiniMind2-Small (26M)(hidden_size=512, num_hidden_layers=8) # MiniMind2-Small (26M)(hidden_size=512, num_hidden_layers=8)
# MiniMind2 (104M)(hidden_size=768, num_hidden_layers=16) # MiniMind2 (104M)(hidden_size=768, num_hidden_layers=16)
parser.add_argument('--hidden_size', default=640, type=int) parser.add_argument('--hidden_size', default=512, type=int)
parser.add_argument('--num_hidden_layers', default=8, type=int) parser.add_argument('--num_hidden_layers', default=8, type=int)
parser.add_argument('--max_seq_len', default=8192, type=int) parser.add_argument('--max_seq_len', default=8192, type=int)
parser.add_argument('--use_moe', default=True, type=bool) parser.add_argument('--use_moe', default=False, type=bool)
# 携带历史对话上下文条数 # 携带历史对话上下文条数
# history_cnt需要设为偶数即【用户问题, 模型回答】为1组设置为0时即当前query不携带历史上文 # history_cnt需要设为偶数即【用户问题, 模型回答】为1组设置为0时即当前query不携带历史上文
# 模型未经过外推微调时在更长的上下文的chat_template时难免出现性能的明显退化因此需要注意此处设置 # 模型未经过外推微调时在更长的上下文的chat_template时难免出现性能的明显退化因此需要注意此处设置

View File

@ -170,7 +170,7 @@ class Attention(nn.Module):
repeat_kv(xv, self.n_rep).transpose(1, 2) repeat_kv(xv, self.n_rep).transpose(1, 2)
) )
if False and self.flash and seq_len != 1: if self.flash and seq_len != 1:
dropout_p = self.dropout if self.training else 0.0 dropout_p = self.dropout if self.training else 0.0
attn_mask = None attn_mask = None
if attention_mask is not None: if attention_mask is not None:

View File

@ -97,19 +97,10 @@ def init_model(lm_config):
tokenizer = AutoTokenizer.from_pretrained('../model') tokenizer = AutoTokenizer.from_pretrained('../model')
model = MiniMindForCausalLM(lm_config) model = MiniMindForCausalLM(lm_config)
moe_path = '_moe' if lm_config.use_moe else '' moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth' ckp = f'{args.save_dir}/pretrain_{lm_config.hidden_size}{moe_path}.pth'
state_dict = torch.load(ckp, map_location=args.device) state_dict = torch.load(ckp, map_location=args.device)
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=False)
# 冻结所有参数
for param in model.parameters():
param.requires_grad = False
# 只解冻注意力机制中的投影层参数
for name, param in model.named_parameters():
if any(proj in name for proj in ['q_proj', 'k_proj', 'v_proj', 'o_proj']):
param.requires_grad = True
Logger(f'LLM可训练总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万') Logger(f'LLM可训练总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
model = model.to(args.device) model = model.to(args.device)
return model, tokenizer return model, tokenizer