diff --git a/trainer/rollout_engine.py b/trainer/rollout_engine.py index 029a44a..7ef2eae 100644 --- a/trainer/rollout_engine.py +++ b/trainer/rollout_engine.py @@ -1,4 +1,5 @@ -"""Rollout Engine - 可插拔的推理引擎 +""" +# 如果使用sglang加速,需通过以下命令首先启动(transformers格式)模型: python -m sglang.launch_server --model-path ./minimind-3 --attention-backend triton --host 0.0.0.0 --port 8998 """ import os @@ -9,7 +10,9 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) import requests import torch +import torch.distributed as dist from abc import ABC, abstractmethod +from contextlib import nullcontext from dataclasses import dataclass from typing import List, Optional, Tuple from torch import Tensor @@ -65,8 +68,8 @@ class TorchRolloutEngine(RolloutEngine): def rollout(self, prompt_ids: Tensor, attention_mask: Tensor, num_generations: int, max_new_tokens: int, temperature: float = 0.8) -> RolloutResult: model = self.policy_model.module if isinstance(self.policy_model, DistributedDataParallel) else self.policy_model - - with torch.no_grad(): + ctx = self.autocast_ctx if self.autocast_ctx else nullcontext() + with torch.no_grad(), ctx: output_ids = model.generate( input_ids=prompt_ids, attention_mask=attention_mask, @@ -77,15 +80,9 @@ class TorchRolloutEngine(RolloutEngine): pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id, ) # [B*num_gen, P+R] - - prompt_len = prompt_ids.size(1) - completion_ids = output_ids[:, prompt_len:] # [B*num_gen, R] - - from contextlib import nullcontext - ctx = self.autocast_ctx if self.autocast_ctx else nullcontext() - with ctx: + prompt_len = prompt_ids.size(1) + completion_ids = output_ids[:, prompt_len:] # [B*num_gen, R] per_token_logps = compute_per_token_logps(self.policy_model, output_ids, completion_ids.size(1)) - completions = self.tokenizer.batch_decode(completion_ids, skip_special_tokens=True) return RolloutResult(output_ids, completion_ids, per_token_logps, completions) @@ -143,6 +140,10 @@ class SGLangRolloutEngine(RolloutEngine): elif isinstance(item, (int, float)): logprobs.append(item) + if len(logprobs) < len(completion_ids): + logprobs = [0.0] * (len(completion_ids) - len(logprobs)) + logprobs + elif len(logprobs) > len(completion_ids): + logprobs = logprobs[-len(completion_ids):] prompt = all_input_ids[i] full_output = prompt + completion_ids all_output_ids.append(full_output) @@ -153,7 +154,6 @@ class SGLangRolloutEngine(RolloutEngine): device = prompt_ids.device max_out_len = max(len(ids) for ids in all_output_ids) max_comp_len = max(len(ids) for ids in all_completion_ids) - max_logp_len = max(len(lp) for lp in all_logprobs) def pad_to_tensor(seqs, max_len, pad_val=0): return torch.tensor([s + [pad_val] * (max_len - len(s)) for s in seqs], device=device) @@ -161,17 +161,17 @@ class SGLangRolloutEngine(RolloutEngine): return RolloutResult( output_ids=pad_to_tensor(all_output_ids, max_out_len), completion_ids=pad_to_tensor(all_completion_ids, max_comp_len), - per_token_logps=pad_to_tensor(all_logprobs, max_logp_len, pad_val=0.0), + per_token_logps=pad_to_tensor(all_logprobs, max_comp_len, pad_val=0.0), completions=completions, ) def update_policy(self, model: torch.nn.Module): + if dist.is_initialized() and dist.get_rank() != 0: return True unwrapped = model.module if isinstance(model, DistributedDataParallel) else model + unwrapped = getattr(unwrapped, '_orig_mod', unwrapped) abs_path = os.path.abspath(self.shared_ckpt_path) - unwrapped.lm_head.weight = torch.nn.Parameter(unwrapped.lm_head.weight.clone()) state_dict = {k: v.detach().half().cpu() for k, v in unwrapped.state_dict().items()} unwrapped.save_pretrained(abs_path, state_dict=state_dict, safe_serialization=False) - unwrapped.model.embed_tokens.weight = unwrapped.lm_head.weight self.tokenizer.save_pretrained(abs_path) resp = self.http.post( f"{self.base_url}/update_weights_from_disk", diff --git a/trainer/train_agent.py b/trainer/train_agent.py index 9a84b69..ff21d62 100644 --- a/trainer/train_agent.py +++ b/trainer/train_agent.py @@ -403,7 +403,7 @@ if __name__ == "__main__": parser.add_argument("--debug_interval", type=int, default=20, help="调试日志间隔") parser.add_argument("--thinking_ratio", type=float, default=0.1, help="按概率开启thinking(0.0~1.0)") parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径") - parser.add_argument("--rollout_engine", type=str, default="sglang", choices=["torch", "sglang"], help="rollout引擎类型") + parser.add_argument("--rollout_engine", type=str, default="torch", choices=["torch", "sglang"], help="rollout引擎类型") parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8998", help="SGLang服务器URL") parser.add_argument("--sglang_model_path", type=str, default="../model", help="SGLang tokenizer路径") parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_agent", help="SGLang共享存储路径") @@ -467,8 +467,8 @@ if __name__ == "__main__": if args.use_compile == 1: model = torch.compile(model) Logger('torch.compile enabled') + rollout_engine.update_policy(model) if dist.is_initialized(): - model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model = DistributedDataParallel(model, device_ids=[local_rank]) if is_main_process(): rollout_engine.update_policy(model) diff --git a/trainer/train_distillation.py b/trainer/train_distillation.py index 6b1b5b9..63b6025 100644 --- a/trainer/train_distillation.py +++ b/trainer/train_distillation.py @@ -226,7 +226,6 @@ if __name__ == "__main__": model = torch.compile(model) Logger('torch.compile enabled') if dist.is_initialized(): - model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model = DistributedDataParallel(model, device_ids=[local_rank]) # ========== 8. 开始训练 ========== diff --git a/trainer/train_dpo.py b/trainer/train_dpo.py index 1876e26..558d108 100644 --- a/trainer/train_dpo.py +++ b/trainer/train_dpo.py @@ -206,7 +206,6 @@ if __name__ == "__main__": model = torch.compile(model) Logger('torch.compile enabled') if dist.is_initialized(): - model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model = DistributedDataParallel(model, device_ids=[local_rank]) # ========== 8. 开始训练 ========== diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py index f685760..51c755b 100644 --- a/trainer/train_full_sft.py +++ b/trainer/train_full_sft.py @@ -151,7 +151,6 @@ if __name__ == "__main__": model = torch.compile(model) Logger('torch.compile enabled') if dist.is_initialized(): - model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model = DistributedDataParallel(model, device_ids=[local_rank]) # ========== 8. 开始训练 ========== diff --git a/trainer/train_grpo.py b/trainer/train_grpo.py index 2d514e2..196c201 100755 --- a/trainer/train_grpo.py +++ b/trainer/train_grpo.py @@ -236,8 +236,8 @@ if __name__ == "__main__": parser.add_argument("--debug_mode", action="store_true", help="是否打印训练调试采样") parser.add_argument("--debug_interval", type=int, default=20, help="debug模式下每隔多少step打印一次采样") parser.add_argument("--thinking_ratio", type=float, default=0.9, help="按概率开启thinking(0.0~1.0)") - parser.add_argument("--rollout_engine", type=str, default="sglang", choices=["torch", "sglang"], help="rollout引擎类型") - parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8996", help="SGLang服务器URL") + parser.add_argument("--rollout_engine", type=str, default="torch", choices=["torch", "sglang"], help="rollout引擎类型") + parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8998", help="SGLang服务器URL") parser.add_argument("--sglang_model_path", type=str, default="../model", help="SGLang tokenizer路径") parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_grpo", help="SGLang共享存储路径") args = parser.parse_args() @@ -311,7 +311,6 @@ if __name__ == "__main__": Logger('torch.compile enabled') rollout_engine.update_policy(model) if dist.is_initialized(): - model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model = DistributedDataParallel(model, device_ids=[local_rank]) if is_main_process(): rollout_engine.update_policy(model) diff --git a/trainer/train_lora.py b/trainer/train_lora.py index 4ea259a..974bf37 100644 --- a/trainer/train_lora.py +++ b/trainer/train_lora.py @@ -164,7 +164,6 @@ if __name__ == "__main__": args.use_compile = 0 Logger('[LoRA] monkey-patch forward 与 torch.compile 不兼容,use_compile 已自动关闭') if dist.is_initialized(): - model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model = DistributedDataParallel(model, device_ids=[local_rank]) # ========== 9. 开始训练 ========== diff --git a/trainer/train_ppo.py b/trainer/train_ppo.py index 27668ec..1faf74b 100644 --- a/trainer/train_ppo.py +++ b/trainer/train_ppo.py @@ -339,8 +339,8 @@ if __name__ == "__main__": parser.add_argument("--debug_mode", action="store_true", help="是否打印训练调试采样") parser.add_argument("--debug_interval", type=int, default=20, help="debug模式下每隔多少step打印一次采样") parser.add_argument("--thinking_ratio", type=float, default=0.9, help="按概率开启thinking(0.0~1.0)") - parser.add_argument("--rollout_engine", type=str, default="sglang", choices=["torch", "sglang"], help="rollout引擎类型") - parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8997", help="SGLang服务器URL") + parser.add_argument("--rollout_engine", type=str, default="torch", choices=["torch", "sglang"], help="rollout引擎类型") + parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8998", help="SGLang服务器URL") parser.add_argument("--sglang_model_path", type=str, default="../model", help="SGLang tokenizer路径") parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_ppo", help="SGLang共享存储路径") args = parser.parse_args() @@ -421,8 +421,6 @@ if __name__ == "__main__": Logger('torch.compile enabled') rollout_engine.update_policy(actor_model) if dist.is_initialized(): - actor_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} - critic_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank]) critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank]) if is_main_process(): rollout_engine.update_policy(actor_model) diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py index 40acdfb..5d33de4 100644 --- a/trainer/train_pretrain.py +++ b/trainer/train_pretrain.py @@ -150,7 +150,6 @@ if __name__ == "__main__": model = torch.compile(model) Logger('torch.compile enabled') if dist.is_initialized(): - model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model = DistributedDataParallel(model, device_ids=[local_rank]) # ========== 8. 开始训练 ==========