Mirror of https://github.com/jingyaogong/minimind.git (synced 2026-01-13 19:57:20 +08:00)
[feat] explicit left padding

commit 11b962da06
parent a9c56b20e9
@@ -123,8 +123,8 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
     for step, batch in enumerate(loader, start=start_step + 1):
         prompts = batch["prompt"]  # list[str], length B
         enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True,
-                        max_length=args.max_seq_len).to(args.device)  # input_ids: [B, P], attention_mask: [B, P]
-        prompt_lengths = torch.full((enc.input_ids.size(0),), enc.input_ids.shape[1], dtype=torch.long, device=enc.input_ids.device)  # [B]
+                        max_length=args.max_seq_len, padding_side="left").to(args.device)  # input_ids: [B, P], attention_mask: [B, P]
+        prompt_length = enc.input_ids.shape[1]

         with torch.no_grad():
             # a DDP-wrapped model needs .module to reach its generate method
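Side note (not part of the commit): with left padding, every row's real tokens end flush at the right edge, so the batch has a single shared prompt length and the per-row prompt_lengths tensor becomes unnecessary. A minimal sketch of the behavior, using "gpt2" purely as an example checkpoint:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint only
tokenizer.pad_token = tokenizer.eos_token          # gpt2 ships without a pad token

# padding_side as a call-time kwarg needs a recent transformers release;
# on older versions, set tokenizer.padding_side = "left" beforehand instead.
enc = tokenizer(["hi", "a much longer prompt"], return_tensors="pt",
                padding=True, truncation=True, max_length=32,
                padding_side="left")
print(enc.input_ids.shape)    # torch.Size([2, P]) -- one shared prompt length P
print(enc.attention_mask[0])  # e.g. tensor([0, 0, 0, 1]) -- pads sit on the left
prompt_length = enc.input_ids.shape[1]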
@@ -134,7 +134,7 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
                 max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
                 pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)  # [B, P+R]

-        responses_text = [tokenizer.decode(gen_out[i, prompt_lengths[i]:], skip_special_tokens=True) for i in range(len(prompts))]
+        responses_text = [tokenizer.decode(gen_out[i, prompt_length:], skip_special_tokens=True) for i in range(len(prompts))]
         rewards = calculate_rewards(prompts, responses_text, reward_model, reward_tokenizer)  # [B]

         full_mask = (gen_out != tokenizer.pad_token_id).long()  # [B, P+R]
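For illustration (again not in the commit): generation appends to the right of the left-padded prompt, so responses start at the same column in every row and one slice replaces the old per-row indexing. A toy sketch with made-up token ids, where 0 is the pad id:

import torch

# row 0: [pad, pad, 11, 12 | 21, 22, 23]
# row 1: [ 13, 14, 15, 16 | 24, 25, pad]   (row 1 stopped early)
gen_out = torch.tensor([[ 0,  0, 11, 12, 21, 22, 23],
                        [13, 14, 15, 16, 24, 25,  0]])
prompt_length = 4
responses = gen_out[:, prompt_length:]  # [B, R] -- same start column for every row
print(responses)  # tensor([[21, 22, 23], [24, 25,  0]])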
@@ -147,7 +147,7 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
             labels = gen_out[:, 1:].clone()  # [B, P+R-1]
             logp_tokens = F.log_softmax(logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1)  # [B, P+R-1]
             seq_len = gen_out.size(1) - 1
-            resp_mask = torch.arange(seq_len, device=gen_out.device).unsqueeze(0) >= prompt_lengths.unsqueeze(1)
+            resp_mask = torch.arange(seq_len, device=gen_out.device).unsqueeze(0) >= prompt_length - 1
             final_mask = resp_mask & (~labels.eq(tokenizer.pad_token_id))  # [B, P+R-1]
             actor_logp = (logp_tokens * final_mask).sum(dim=1)  # [B]

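A note on the "- 1" (my reading of the hunk, not stated in the commit message): labels = gen_out[:, 1:] shifts everything left by one, so the first response token gen_out[:, prompt_length] lands at label index prompt_length - 1; the old comparison against prompt_lengths (without the - 1) would have dropped it from the mask. A small check on the toy batch above:

import torch

gen_out = torch.tensor([[0, 0, 11, 12, 21, 22, 23]])  # pads | prompt | response
prompt_length = 4

labels = gen_out[:, 1:]  # shifted by one: [0, 11, 12, 21, 22, 23]
seq_len = gen_out.size(1) - 1
resp_mask = torch.arange(seq_len).unsqueeze(0) >= prompt_length - 1
print(labels[resp_mask])  # tensor([21, 22, 23]) -- all three response tokens
old_mask = torch.arange(seq_len).unsqueeze(0) >= prompt_length
print(labels[old_mask])   # tensor([22, 23]) -- the old mask lost the first one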
@@ -300,7 +300,6 @@ if __name__ == "__main__":
     base_weight = "reason" if args.reasoning == 1 else "full_sft"
     # Actor model
     actor_model, tokenizer = init_model(lm_config, base_weight, device=args.device)
-    tokenizer.padding_side = 'left'  # PPO needs left-side padding
     # Old actor model
     old_actor_model, _ = init_model(lm_config, base_weight, device=args.device)
     old_actor_model = old_actor_model.eval().requires_grad_(False)
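Context on the deleted line (my reading): moving padding_side to the call site replaces a global mutation of the tokenizer, so code paths that expect the default right padding are left alone. The per-call kwarg and the attribute are equivalent where both are supported; the kwarg form needs a recent transformers release. A sketch, again with "gpt2" as a stand-in checkpoint:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint only
tok.pad_token = tok.eos_token

# (a) global mutation -- what the deleted line did; affects every later call
tok.padding_side = "left"
a = tok(["hi", "a longer prompt"], return_tensors="pt", padding=True)

# (b) per-call override -- what the new call site does (recent transformers)
tok.padding_side = "right"  # restore the default first
b = tok(["hi", "a longer prompt"], return_tensors="pt", padding=True,
        padding_side="left")
assert (a.input_ids == b.input_ids).all()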