mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-04-25 08:48:16 +08:00
merge redundant forward passes for logps and aux_loss (in train_grpo.py)
This commit is contained in:
parent
349e74ec7b
commit
b113b494cb
@ -110,22 +110,36 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token
|
||||
|
||||
completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):] # [B*num_gen, R]
|
||||
|
||||
def get_per_token_logps(mdl, input_ids, n_keep):
|
||||
def get_per_token_logps_and_aux(mdl, input_ids, n_keep, use_moe=False):
    """Run ONE forward pass and return per-token log-probs plus the MoE aux loss.

    Merges what used to be two forward passes (one for logps, one for
    aux_loss) into a single call, which is the point of this refactor.

    Args:
        mdl: model callable; ``mdl(input_ids, logits_to_keep=k)`` returns an
            object with a ``.logits`` tensor of shape [B, k, vocab] and,
            for MoE models, optionally an ``.aux_loss`` attribute.
        input_ids: [B, T] token ids; only the last ``n_keep`` positions
            (the completion tokens) are scored.
        n_keep: number of trailing completion tokens to score.
        use_moe: when True, extract ``aux_loss`` from the model output.

    Returns:
        Tuple ``(per_token_logps, aux_loss)``: [B, n_keep] log-probabilities
        of the realized tokens, and a scalar aux loss (0.0 tensor when the
        model has none or ``use_moe`` is False).
    """
    # Tensors produced under torch.inference_mode() cannot participate in
    # autograd; clone them out of inference mode before the forward pass.
    if input_ids.is_inference():
        input_ids = input_ids.detach().clone()

    # Single forward pass serves both outputs (no redundant second pass).
    # Request n_keep + 1 positions because logits[t] predicts token t + 1.
    model_output = mdl(input_ids, logits_to_keep=n_keep + 1)
    logits = model_output.logits[:, :-1, :]  # [B, n_keep, vocab]

    # Targets are the last n_keep realized tokens; same inference-mode
    # escape hatch as above, since the slice shares input_ids' storage.
    target_ids = input_ids[:, -n_keep:]
    if target_ids.is_inference():
        target_ids = target_ids.detach().clone()

    # Vectorized replacement for the original per-row Python loop:
    # one log_softmax over the whole batch, then gather the log-prob of
    # each realized token along the vocab dim. Numerically identical,
    # avoids B separate softmax/gather calls and the final torch.stack.
    per_token_logps = torch.gather(
        logits.log_softmax(dim=-1), dim=2, index=target_ids.unsqueeze(-1)
    ).squeeze(-1)  # [B, n_keep]

    # Auxiliary (load-balancing) loss only exists for MoE models; fall back
    # to a 0.0 tensor on the right device in every other case.
    aux_loss = getattr(model_output, 'aux_loss', None) if use_moe else None
    if aux_loss is None:
        aux_loss = torch.tensor(0.0, device=input_ids.device)

    return per_token_logps, aux_loss
|
||||
|
||||
with autocast_ctx:
|
||||
per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1)) # [B*num_gen, R]
|
||||
res = model(outputs) if lm_config.use_moe else None
|
||||
aux_loss = res.aux_loss if res is not None else torch.tensor(0.0, device=args.device)
|
||||
per_token_logps, aux_loss = get_per_token_logps_and_aux(
|
||||
model,
|
||||
outputs,
|
||||
completion_ids.size(1),
|
||||
use_moe=lm_config.use_moe
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1)) # [B*num_gen, R]
|
||||
ref_per_token_logps, _ = get_per_token_logps_and_aux(ref_model, outputs, completion_ids.size(1)) # [B*num_gen, R]
|
||||
|
||||
completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
|
||||
rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device) # [B*num_gen]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user