mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-05-01 11:48:14 +08:00
[fix] prompt length calculation
This commit is contained in:
parent
f3441b0078
commit
9c98cabc9a
@ -124,7 +124,7 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
     prompts = batch["prompt"]  # list[str], length B
     enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True,
                     max_length=args.max_seq_len).to(args.device)  # input_ids: [B, P], attention_mask: [B, P]
-    prompt_lengths = enc.attention_mask.sum(dim=1)  # [B]
+    prompt_lengths = torch.full((enc.input_ids.size(0),), enc.input_ids.shape[1], dtype=torch.long, device=enc.input_ids.device)  # [B]

     with torch.no_grad():
         # DDP 模型需要使用 .module 访问 generate 方法
Loading…
Reference in New Issue
Block a user