mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-06-06 00:04:50 +00:00
[fix] bugs
This commit is contained in:
+13
-13
@@ -1,4 +1,5 @@
|
||||
"""Rollout Engine - 可插拔的推理引擎
|
||||
"""
|
||||
# 如果使用sglang加速,需通过以下命令首先启动(transformers格式)模型:
|
||||
python -m sglang.launch_server --model-path ./minimind-3 --attention-backend triton --host 0.0.0.0 --port 8998
|
||||
"""
|
||||
import os
|
||||
@@ -9,7 +10,9 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
from torch import Tensor
|
||||
@@ -65,8 +68,8 @@ class TorchRolloutEngine(RolloutEngine):
|
||||
|
||||
def rollout(self, prompt_ids: Tensor, attention_mask: Tensor, num_generations: int, max_new_tokens: int, temperature: float = 0.8) -> RolloutResult:
|
||||
model = self.policy_model.module if isinstance(self.policy_model, DistributedDataParallel) else self.policy_model
|
||||
|
||||
with torch.no_grad():
|
||||
ctx = self.autocast_ctx if self.autocast_ctx else nullcontext()
|
||||
with torch.no_grad(), ctx:
|
||||
output_ids = model.generate(
|
||||
input_ids=prompt_ids,
|
||||
attention_mask=attention_mask,
|
||||
@@ -77,15 +80,9 @@ class TorchRolloutEngine(RolloutEngine):
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
) # [B*num_gen, P+R]
|
||||
|
||||
prompt_len = prompt_ids.size(1)
|
||||
completion_ids = output_ids[:, prompt_len:] # [B*num_gen, R]
|
||||
|
||||
from contextlib import nullcontext
|
||||
ctx = self.autocast_ctx if self.autocast_ctx else nullcontext()
|
||||
with ctx:
|
||||
per_token_logps = compute_per_token_logps(self.policy_model, output_ids, completion_ids.size(1))
|
||||
|
||||
completions = self.tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
|
||||
return RolloutResult(output_ids, completion_ids, per_token_logps, completions)
|
||||
|
||||
@@ -143,6 +140,10 @@ class SGLangRolloutEngine(RolloutEngine):
|
||||
elif isinstance(item, (int, float)):
|
||||
logprobs.append(item)
|
||||
|
||||
if len(logprobs) < len(completion_ids):
|
||||
logprobs = [0.0] * (len(completion_ids) - len(logprobs)) + logprobs
|
||||
elif len(logprobs) > len(completion_ids):
|
||||
logprobs = logprobs[-len(completion_ids):]
|
||||
prompt = all_input_ids[i]
|
||||
full_output = prompt + completion_ids
|
||||
all_output_ids.append(full_output)
|
||||
@@ -153,7 +154,6 @@ class SGLangRolloutEngine(RolloutEngine):
|
||||
device = prompt_ids.device
|
||||
max_out_len = max(len(ids) for ids in all_output_ids)
|
||||
max_comp_len = max(len(ids) for ids in all_completion_ids)
|
||||
max_logp_len = max(len(lp) for lp in all_logprobs)
|
||||
|
||||
def pad_to_tensor(seqs, max_len, pad_val=0):
|
||||
return torch.tensor([s + [pad_val] * (max_len - len(s)) for s in seqs], device=device)
|
||||
@@ -161,17 +161,17 @@ class SGLangRolloutEngine(RolloutEngine):
|
||||
return RolloutResult(
|
||||
output_ids=pad_to_tensor(all_output_ids, max_out_len),
|
||||
completion_ids=pad_to_tensor(all_completion_ids, max_comp_len),
|
||||
per_token_logps=pad_to_tensor(all_logprobs, max_logp_len, pad_val=0.0),
|
||||
per_token_logps=pad_to_tensor(all_logprobs, max_comp_len, pad_val=0.0),
|
||||
completions=completions,
|
||||
)
|
||||
|
||||
def update_policy(self, model: torch.nn.Module):
|
||||
if dist.is_initialized() and dist.get_rank() != 0: return True
|
||||
unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
|
||||
unwrapped = getattr(unwrapped, '_orig_mod', unwrapped)
|
||||
abs_path = os.path.abspath(self.shared_ckpt_path)
|
||||
unwrapped.lm_head.weight = torch.nn.Parameter(unwrapped.lm_head.weight.clone())
|
||||
state_dict = {k: v.detach().half().cpu() for k, v in unwrapped.state_dict().items()}
|
||||
unwrapped.save_pretrained(abs_path, state_dict=state_dict, safe_serialization=False)
|
||||
unwrapped.model.embed_tokens.weight = unwrapped.lm_head.weight
|
||||
self.tokenizer.save_pretrained(abs_path)
|
||||
resp = self.http.post(
|
||||
f"{self.base_url}/update_weights_from_disk",
|
||||
|
||||
@@ -403,7 +403,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--debug_interval", type=int, default=20, help="调试日志间隔")
|
||||
parser.add_argument("--thinking_ratio", type=float, default=0.1, help="按概率开启thinking(0.0~1.0)")
|
||||
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
|
||||
parser.add_argument("--rollout_engine", type=str, default="sglang", choices=["torch", "sglang"], help="rollout引擎类型")
|
||||
parser.add_argument("--rollout_engine", type=str, default="torch", choices=["torch", "sglang"], help="rollout引擎类型")
|
||||
parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8998", help="SGLang服务器URL")
|
||||
parser.add_argument("--sglang_model_path", type=str, default="../model", help="SGLang tokenizer路径")
|
||||
parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_agent", help="SGLang共享存储路径")
|
||||
@@ -467,8 +467,8 @@ if __name__ == "__main__":
|
||||
if args.use_compile == 1:
|
||||
model = torch.compile(model)
|
||||
Logger('torch.compile enabled')
|
||||
rollout_engine.update_policy(model)
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
if is_main_process(): rollout_engine.update_policy(model)
|
||||
|
||||
|
||||
@@ -226,7 +226,6 @@ if __name__ == "__main__":
|
||||
model = torch.compile(model)
|
||||
Logger('torch.compile enabled')
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
|
||||
@@ -206,7 +206,6 @@ if __name__ == "__main__":
|
||||
model = torch.compile(model)
|
||||
Logger('torch.compile enabled')
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
|
||||
@@ -151,7 +151,6 @@ if __name__ == "__main__":
|
||||
model = torch.compile(model)
|
||||
Logger('torch.compile enabled')
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
|
||||
@@ -236,8 +236,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--debug_mode", action="store_true", help="是否打印训练调试采样")
|
||||
parser.add_argument("--debug_interval", type=int, default=20, help="debug模式下每隔多少step打印一次采样")
|
||||
parser.add_argument("--thinking_ratio", type=float, default=0.9, help="按概率开启thinking(0.0~1.0)")
|
||||
parser.add_argument("--rollout_engine", type=str, default="sglang", choices=["torch", "sglang"], help="rollout引擎类型")
|
||||
parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8996", help="SGLang服务器URL")
|
||||
parser.add_argument("--rollout_engine", type=str, default="torch", choices=["torch", "sglang"], help="rollout引擎类型")
|
||||
parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8998", help="SGLang服务器URL")
|
||||
parser.add_argument("--sglang_model_path", type=str, default="../model", help="SGLang tokenizer路径")
|
||||
parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_grpo", help="SGLang共享存储路径")
|
||||
args = parser.parse_args()
|
||||
@@ -311,7 +311,6 @@ if __name__ == "__main__":
|
||||
Logger('torch.compile enabled')
|
||||
rollout_engine.update_policy(model)
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
if is_main_process(): rollout_engine.update_policy(model)
|
||||
|
||||
|
||||
@@ -164,7 +164,6 @@ if __name__ == "__main__":
|
||||
args.use_compile = 0
|
||||
Logger('[LoRA] monkey-patch forward 与 torch.compile 不兼容,use_compile 已自动关闭')
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 9. 开始训练 ==========
|
||||
|
||||
@@ -339,8 +339,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--debug_mode", action="store_true", help="是否打印训练调试采样")
|
||||
parser.add_argument("--debug_interval", type=int, default=20, help="debug模式下每隔多少step打印一次采样")
|
||||
parser.add_argument("--thinking_ratio", type=float, default=0.9, help="按概率开启thinking(0.0~1.0)")
|
||||
parser.add_argument("--rollout_engine", type=str, default="sglang", choices=["torch", "sglang"], help="rollout引擎类型")
|
||||
parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8997", help="SGLang服务器URL")
|
||||
parser.add_argument("--rollout_engine", type=str, default="torch", choices=["torch", "sglang"], help="rollout引擎类型")
|
||||
parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8998", help="SGLang服务器URL")
|
||||
parser.add_argument("--sglang_model_path", type=str, default="../model", help="SGLang tokenizer路径")
|
||||
parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_ppo", help="SGLang共享存储路径")
|
||||
args = parser.parse_args()
|
||||
@@ -421,8 +421,6 @@ if __name__ == "__main__":
|
||||
Logger('torch.compile enabled')
|
||||
rollout_engine.update_policy(actor_model)
|
||||
if dist.is_initialized():
|
||||
actor_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
critic_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank])
|
||||
critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank])
|
||||
if is_main_process(): rollout_engine.update_policy(actor_model)
|
||||
|
||||
@@ -150,7 +150,6 @@ if __name__ == "__main__":
|
||||
model = torch.compile(model)
|
||||
Logger('torch.compile enabled')
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
|
||||
Reference in New Issue
Block a user