mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-05-01 11:48:14 +08:00
fix bugs
This commit is contained in:
parent
765c79c7a2
commit
29454c31af
@ -107,10 +107,10 @@ def main():
|
||||
# MiniMind2-moe (145M):(hidden_size=640, num_hidden_layers=8, use_moe=True)
|
||||
# MiniMind2-Small (26M):(hidden_size=512, num_hidden_layers=8)
|
||||
# MiniMind2 (104M):(hidden_size=768, num_hidden_layers=16)
|
||||
parser.add_argument('--hidden_size', default=640, type=int)
|
||||
parser.add_argument('--hidden_size', default=512, type=int)
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int)
|
||||
parser.add_argument('--max_seq_len', default=8192, type=int)
|
||||
parser.add_argument('--use_moe', default=True, type=bool)
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
# 携带历史对话上下文条数
|
||||
# history_cnt需要设为偶数,即【用户问题, 模型回答】为1组;设置为0时,即当前query不携带历史上文
|
||||
# 模型未经过外推微调时,在更长的上下文的chat_template时难免出现性能的明显退化,因此需要注意此处设置
|
||||
|
||||
@ -170,7 +170,7 @@ class Attention(nn.Module):
|
||||
repeat_kv(xv, self.n_rep).transpose(1, 2)
|
||||
)
|
||||
|
||||
if False and self.flash and seq_len != 1:
|
||||
if self.flash and seq_len != 1:
|
||||
dropout_p = self.dropout if self.training else 0.0
|
||||
attn_mask = None
|
||||
if attention_mask is not None:
|
||||
|
||||
@ -97,19 +97,10 @@ def init_model(lm_config):
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model')
|
||||
model = MiniMindForCausalLM(lm_config)
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
|
||||
ckp = f'{args.save_dir}/pretrain_{lm_config.hidden_size}{moe_path}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
# 冻结所有参数
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# 只解冻注意力机制中的投影层参数
|
||||
for name, param in model.named_parameters():
|
||||
if any(proj in name for proj in ['q_proj', 'k_proj', 'v_proj', 'o_proj']):
|
||||
param.requires_grad = True
|
||||
|
||||
Logger(f'LLM可训练总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
model = model.to(args.device)
|
||||
return model, tokenizer
|
||||
|
||||
Loading…
Reference in New Issue
Block a user