[fix] bugs

This commit is contained in:
jingyaogong
2026-04-21 13:03:34 +08:00
parent 1718e9a44d
commit 5416a44471
9 changed files with 21 additions and 29 deletions
+15 -15
View File
@@ -1,4 +1,5 @@
"""Rollout Engine - 可插拔的推理引擎
"""
# To accelerate rollouts with SGLang, first launch the (transformers-format) model server with the following command:
python -m sglang.launch_server --model-path ./minimind-3 --attention-backend triton --host 0.0.0.0 --port 8998
"""
import os
@@ -9,7 +10,9 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import requests
import torch
import torch.distributed as dist
from abc import ABC, abstractmethod
from contextlib import nullcontext
from dataclasses import dataclass
from typing import List, Optional, Tuple
from torch import Tensor
@@ -65,8 +68,8 @@ class TorchRolloutEngine(RolloutEngine):
def rollout(self, prompt_ids: Tensor, attention_mask: Tensor, num_generations: int, max_new_tokens: int, temperature: float = 0.8) -> RolloutResult:
model = self.policy_model.module if isinstance(self.policy_model, DistributedDataParallel) else self.policy_model
with torch.no_grad():
ctx = self.autocast_ctx if self.autocast_ctx else nullcontext()
with torch.no_grad(), ctx:
output_ids = model.generate(
input_ids=prompt_ids,
attention_mask=attention_mask,
@@ -77,15 +80,9 @@ class TorchRolloutEngine(RolloutEngine):
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
) # [B*num_gen, P+R]
prompt_len = prompt_ids.size(1)
completion_ids = output_ids[:, prompt_len:] # [B*num_gen, R]
from contextlib import nullcontext
ctx = self.autocast_ctx if self.autocast_ctx else nullcontext()
with ctx:
prompt_len = prompt_ids.size(1)
completion_ids = output_ids[:, prompt_len:] # [B*num_gen, R]
per_token_logps = compute_per_token_logps(self.policy_model, output_ids, completion_ids.size(1))
completions = self.tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
return RolloutResult(output_ids, completion_ids, per_token_logps, completions)
@@ -143,6 +140,10 @@ class SGLangRolloutEngine(RolloutEngine):
elif isinstance(item, (int, float)):
logprobs.append(item)
# Keep exactly one logprob per completion token so per_token_logps lines up
# with completion_ids position by position.
if len(logprobs) < len(completion_ids):
    # Too few entries: left-pad with 0.0 so the available values stay
    # aligned with the trailing completion tokens.
    logprobs = [0.0] * (len(completion_ids) - len(logprobs)) + logprobs
elif len(logprobs) > len(completion_ids):
    # Too many entries: keep only the trailing ones. The start index is
    # computed explicitly because `logprobs[-0:]` is `logprobs[0:]`, which
    # would keep every entry when the completion is empty.
    logprobs = logprobs[len(logprobs) - len(completion_ids):]
prompt = all_input_ids[i]
full_output = prompt + completion_ids
all_output_ids.append(full_output)
@@ -153,7 +154,6 @@ class SGLangRolloutEngine(RolloutEngine):
device = prompt_ids.device
max_out_len = max(len(ids) for ids in all_output_ids)
max_comp_len = max(len(ids) for ids in all_completion_ids)
max_logp_len = max(len(lp) for lp in all_logprobs)
def pad_to_tensor(seqs, max_len, pad_val=0):
    """Right-pad every sequence in seqs to max_len with pad_val and return one tensor.

    The tensor is created on `device`, which is taken from the enclosing scope.
    Sequences that are already at least max_len long are left unchanged.
    """
    padded_rows = []
    for row in seqs:
        missing = max_len - len(row)
        # A non-positive count multiplies to an empty list, so nothing is appended.
        padded_rows.append(row + [pad_val] * missing)
    return torch.tensor(padded_rows, device=device)
@@ -161,17 +161,17 @@ class SGLangRolloutEngine(RolloutEngine):
return RolloutResult(
output_ids=pad_to_tensor(all_output_ids, max_out_len),
completion_ids=pad_to_tensor(all_completion_ids, max_comp_len),
per_token_logps=pad_to_tensor(all_logprobs, max_logp_len, pad_val=0.0),
per_token_logps=pad_to_tensor(all_logprobs, max_comp_len, pad_val=0.0),
completions=completions,
)
def update_policy(self, model: torch.nn.Module):
if dist.is_initialized() and dist.get_rank() != 0: return True
unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
unwrapped = getattr(unwrapped, '_orig_mod', unwrapped)
abs_path = os.path.abspath(self.shared_ckpt_path)
unwrapped.lm_head.weight = torch.nn.Parameter(unwrapped.lm_head.weight.clone())
state_dict = {k: v.detach().half().cpu() for k, v in unwrapped.state_dict().items()}
unwrapped.save_pretrained(abs_path, state_dict=state_dict, safe_serialization=False)
unwrapped.model.embed_tokens.weight = unwrapped.lm_head.weight
self.tokenizer.save_pretrained(abs_path)
resp = self.http.post(
f"{self.base_url}/update_weights_from_disk",
+2 -2
View File
@@ -403,7 +403,7 @@ if __name__ == "__main__":
parser.add_argument("--debug_interval", type=int, default=20, help="调试日志间隔")
parser.add_argument("--thinking_ratio", type=float, default=0.1, help="按概率开启thinking0.0~1.0")
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
parser.add_argument("--rollout_engine", type=str, default="sglang", choices=["torch", "sglang"], help="rollout引擎类型")
parser.add_argument("--rollout_engine", type=str, default="torch", choices=["torch", "sglang"], help="rollout引擎类型")
parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8998", help="SGLang服务器URL")
parser.add_argument("--sglang_model_path", type=str, default="../model", help="SGLang tokenizer路径")
parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_agent", help="SGLang共享存储路径")
@@ -467,8 +467,8 @@ if __name__ == "__main__":
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
rollout_engine.update_policy(model)
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
if is_main_process(): rollout_engine.update_policy(model)
-1
View File
@@ -226,7 +226,6 @@ if __name__ == "__main__":
model = torch.compile(model)
Logger('torch.compile enabled')
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. 开始训练 ==========
-1
View File
@@ -206,7 +206,6 @@ if __name__ == "__main__":
model = torch.compile(model)
Logger('torch.compile enabled')
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. 开始训练 ==========
-1
View File
@@ -151,7 +151,6 @@ if __name__ == "__main__":
model = torch.compile(model)
Logger('torch.compile enabled')
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. 开始训练 ==========
+2 -3
View File
@@ -236,8 +236,8 @@ if __name__ == "__main__":
parser.add_argument("--debug_mode", action="store_true", help="是否打印训练调试采样")
parser.add_argument("--debug_interval", type=int, default=20, help="debug模式下每隔多少step打印一次采样")
parser.add_argument("--thinking_ratio", type=float, default=0.9, help="按概率开启thinking0.0~1.0")
parser.add_argument("--rollout_engine", type=str, default="sglang", choices=["torch", "sglang"], help="rollout引擎类型")
parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8996", help="SGLang服务器URL")
parser.add_argument("--rollout_engine", type=str, default="torch", choices=["torch", "sglang"], help="rollout引擎类型")
parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8998", help="SGLang服务器URL")
parser.add_argument("--sglang_model_path", type=str, default="../model", help="SGLang tokenizer路径")
parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_grpo", help="SGLang共享存储路径")
args = parser.parse_args()
@@ -311,7 +311,6 @@ if __name__ == "__main__":
Logger('torch.compile enabled')
rollout_engine.update_policy(model)
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
if is_main_process(): rollout_engine.update_policy(model)
-1
View File
@@ -164,7 +164,6 @@ if __name__ == "__main__":
args.use_compile = 0
Logger('[LoRA] monkey-patch forward 与 torch.compile 不兼容,use_compile 已自动关闭')
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 9. 开始训练 ==========
+2 -4
View File
@@ -339,8 +339,8 @@ if __name__ == "__main__":
parser.add_argument("--debug_mode", action="store_true", help="是否打印训练调试采样")
parser.add_argument("--debug_interval", type=int, default=20, help="debug模式下每隔多少step打印一次采样")
parser.add_argument("--thinking_ratio", type=float, default=0.9, help="按概率开启thinking0.0~1.0")
parser.add_argument("--rollout_engine", type=str, default="sglang", choices=["torch", "sglang"], help="rollout引擎类型")
parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8997", help="SGLang服务器URL")
parser.add_argument("--rollout_engine", type=str, default="torch", choices=["torch", "sglang"], help="rollout引擎类型")
parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8998", help="SGLang服务器URL")
parser.add_argument("--sglang_model_path", type=str, default="../model", help="SGLang tokenizer路径")
parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_ppo", help="SGLang共享存储路径")
args = parser.parse_args()
@@ -421,8 +421,6 @@ if __name__ == "__main__":
Logger('torch.compile enabled')
rollout_engine.update_policy(actor_model)
if dist.is_initialized():
actor_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
critic_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank])
critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank])
if is_main_process(): rollout_engine.update_policy(actor_model)
-1
View File
@@ -150,7 +150,6 @@ if __name__ == "__main__":
model = torch.compile(model)
Logger('torch.compile enabled')
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. 开始训练 ==========