Files
minimind/trainer/rollout_engine.py
T
2026-03-25 23:57:45 +08:00

213 lines
8.7 KiB
Python

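"""Rollout Engine - pluggable inference backends (native PyTorch or an SGLang server) for RL rollouts.

Example command to launch the SGLang server used by the SGLang engine:
python -m sglang.launch_server --model-path ./minimind-3 --attention-backend triton --host 0.0.0.0 --port 8998
"""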
import os
import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import requests
import torch
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Tuple
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel
from transformers import AutoTokenizer
# ===== Per-token log-probabilities =====
def compute_per_token_logps(model, input_ids: Tensor, n_keep: int, attention_mask: Optional[Tensor] = None) -> Tensor:
    """Score the trailing ``n_keep`` tokens of every row of ``input_ids`` under ``model``.

    Args:
        model: causal LM, optionally wrapped in DistributedDataParallel (unwrapped before the
            forward call). It is called as
            ``model(input_ids, attention_mask=..., logits_to_keep=n_keep + 1)`` and must return
            an object exposing ``.logits``.
        input_ids: token ids indexed as ``[row, position]``.
        n_keep: number of trailing tokens per row to score.
        attention_mask: forwarded unchanged to the model.

    Returns:
        Tensor of shape ``[rows, n_keep]`` with ``log_softmax(logits)[token]`` for each scored
        token, or an empty ``[rows, 0]`` float32 tensor when ``n_keep <= 0``.
    """
    rows = input_ids.size(0)
    if n_keep <= 0:
        return input_ids.new_empty((rows, 0), dtype=torch.float32)

    core = model.module if isinstance(model, DistributedDataParallel) else model
    # Inference-mode tensors are restricted outside inference mode (e.g. they cannot be saved
    # for backward), so work on an ordinary copy instead.
    if input_ids.is_inference():
        input_ids = input_ids.detach().clone()

    # Ask for n_keep + 1 trailing steps and drop the last one: the logits at step t predict
    # token t + 1, so the remaining n_keep steps line up with the last n_keep tokens.
    # NOTE(review): assumes the model honours logits_to_keep (returns only that many steps).
    step_logits = core(input_ids, attention_mask=attention_mask, logits_to_keep=n_keep + 1).logits[:, :-1, :]
    targets = input_ids[:, -n_keep:]

    def score_row(row_logits: Tensor, row_targets: Tensor) -> Tensor:
        if row_targets.is_inference():
            row_targets = row_targets.detach().clone()
        log_probs = row_logits.log_softmax(dim=-1)
        return log_probs.gather(1, row_targets.unsqueeze(1)).squeeze(1)

    # Rows are processed one at a time, presumably to keep the [steps, vocab] log_softmax
    # buffer limited to a single row.
    return torch.stack([score_row(lg, tg) for lg, tg in zip(step_logits, targets)])
# ===== Rollout result container =====
@dataclass
class RolloutResult:
    """Everything produced by one ``RolloutEngine.rollout`` call.

    Fields are declared in positional-construction order.
    """

    # Token ids of every sampled sequence (prompt followed by completion).
    output_ids: Tensor
    # Only the generated (completion) part of each sequence.
    completion_ids: Tensor
    # Log-probability of each completion token, as computed/reported by the engine.
    per_token_logps: Tensor
    # Decoded completion text, one string per sampled sequence.
    completions: List[str]
# ===== Abstract base class for rollout engines =====
class RolloutEngine(ABC):
    """Interface every rollout backend implements.

    Subclasses provide :meth:`rollout` (sample completions for a batch of prompts) and
    :meth:`update_policy` (make subsequent rollouts use a given policy's weights).
    """

    # Concrete engines assign the tokenizer they use; the base class leaves it unset.
    tokenizer = None

    @abstractmethod
    def rollout(self, prompt_ids: Tensor, attention_mask: Tensor, num_generations: int, max_new_tokens: int, temperature: float = 0.8) -> RolloutResult:
        """Sample ``num_generations`` completions per prompt and return them as a RolloutResult."""

    @abstractmethod
    def update_policy(self, model: torch.nn.Module):
        """Make subsequent rollouts sample from ``model``'s current weights."""
# ===== Native PyTorch rollout engine =====
class TorchRolloutEngine(RolloutEngine):
    """Rollout engine that samples in-process with the policy model's ``generate`` method."""

    def __init__(self, policy_model: torch.nn.Module, tokenizer, device: str = "cuda", autocast_ctx=None):
        """
        Args:
            policy_model: policy to sample from; may be wrapped in DistributedDataParallel.
            tokenizer: supplies ``pad_token_id`` / ``eos_token_id`` and decodes completions.
            device: stored on the instance; not read by this class itself.
            autocast_ctx: optional context manager (e.g. an autocast context) entered while the
                sampled tokens are scored; ``None`` means no extra context.
        """
        self.policy_model = policy_model
        self.tokenizer = tokenizer
        self.device = device
        self.autocast_ctx = autocast_ctx

    def rollout(self, prompt_ids: Tensor, attention_mask: Tensor, num_generations: int, max_new_tokens: int, temperature: float = 0.8) -> RolloutResult:
        """Sample ``num_generations`` completions per prompt and score them with the same policy.

        ``output_ids`` is ``[B * num_generations, P + R]``. The generate call expands each prompt
        row ``num_generations`` times, so the completion occupies the last
        ``R = completion_ids.size(1)`` columns. ``per_token_logps`` is ``[B * num_generations, R]``.
        """
        from contextlib import nullcontext

        model = self.policy_model.module if isinstance(self.policy_model, DistributedDataParallel) else self.policy_model
        # Sampling and scoring both produce behaviour-policy data, so no autograd graph is built.
        # The SGLang engine returns plain tensors too; callers presumably treat these
        # log-probs as fixed "old" log-probs (verify against the trainer).
        with torch.no_grad():
            output_ids = model.generate(
                input_ids=prompt_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                num_return_sequences=num_generations,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )  # [B*num_gen, P+R]
            prompt_len = prompt_ids.size(1)
            completion_ids = output_ids[:, prompt_len:]  # [B*num_gen, R]
            score_ctx = self.autocast_ctx if self.autocast_ctx is not None else nullcontext()
            with score_ctx:
                per_token_logps = compute_per_token_logps(self.policy_model, output_ids, completion_ids.size(1))
        completions = self.tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
        return RolloutResult(output_ids, completion_ids, per_token_logps, completions)

    def update_policy(self, model: torch.nn.Module):
        """Point the engine at ``model``; later rollouts use it directly (no copy is made)."""
        self.policy_model = model
# ===== SGLang HTTP API rollout engine =====
class SGLangRolloutEngine(RolloutEngine):
    """Rollout engine that delegates sampling to a running SGLang server over HTTP.

    Weights are synced by writing a checkpoint to ``shared_ckpt_path`` and asking the server to
    reload it via ``/update_weights_from_disk``, so the server must be able to read that path.
    """

    def __init__(self, base_url: str, model_path: str, shared_ckpt_path: str = "./sglang_ckpt", timeout: int = 120):
        """
        Args:
            base_url: server root URL; a trailing '/' is stripped.
            model_path: tokenizer location passed to ``AutoTokenizer.from_pretrained``.
            shared_ckpt_path: directory the checkpoint is written to in :meth:`update_policy`.
            timeout: seconds for the ``/generate`` and ``/update_weights_from_disk`` requests.
        """
        self.base_url = base_url.rstrip('/')
        self.shared_ckpt_path = shared_ckpt_path
        self.timeout = timeout
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # HTTP client used for every server call.
        self.http = requests

    @staticmethod
    def _parse_logprobs(raw_logprobs) -> List[float]:
        """Pull the float log-prob out of each ``output_token_logprobs`` entry.

        Entries are either bare numbers or sequences whose first element is the log-prob;
        anything else is skipped.
        """
        logprobs = []
        for item in raw_logprobs:
            if isinstance(item, (list, tuple)) and len(item) >= 1:
                logprobs.append(item[0])
            elif isinstance(item, (int, float)):
                logprobs.append(item)
        return logprobs

    def rollout(self, prompt_ids: Tensor, attention_mask: Tensor, num_generations: int, max_new_tokens: int, temperature: float = 0.8) -> RolloutResult:
        """Sample ``num_generations`` completions per prompt through the ``/generate`` endpoint.

        The layout matches the in-process ``generate`` path. ``output_ids`` is each original
        (left-padded) prompt row, repeated ``num_generations`` times back-to-back, followed by its
        completion right-padded to the longest completion. The completion therefore always
        occupies the last ``completion_ids.size(1)`` columns. ``per_token_logps`` is aligned
        column-for-column with ``completion_ids``; padding positions hold 0.0.

        NOTE(review): whether the server's log-probs are pre- or post-temperature depends on
        SGLang's configuration; confirm they match how the trainer scores the same tokens.

        Raises:
            requests.HTTPError: the server answered with an error status.
            RuntimeError: the server returned a different number of results than requested.
        """
        device = prompt_ids.device
        # SGLang takes raw token lists, so drop the padded positions (attention_mask == 0).
        unpadded_prompts = [ids[mask.bool()].tolist() for ids, mask in zip(prompt_ids, attention_mask)]
        # Order p0, p0, ..., p1, p1, ... - the same order as repeat_interleave(dim=0) below.
        all_input_ids = [ids for ids in unpadded_prompts for _ in range(num_generations)]
        if not all_input_ids:
            # Nothing to sample: skip the HTTP call and return correctly-shaped empty tensors.
            return RolloutResult(
                output_ids=prompt_ids.new_empty((0, prompt_ids.size(1))),
                completion_ids=prompt_ids.new_empty((0, 0)),
                per_token_logps=torch.empty((0, 0), dtype=torch.float32, device=device),
                completions=[],
            )

        eos_id = self.tokenizer.eos_token_id
        payload = {
            "input_ids": all_input_ids,
            "sampling_params": {
                "temperature": temperature,
                "max_new_tokens": max_new_tokens,
                # `is not None` rather than truthiness so an EOS id of 0 is still honoured.
                "stop_token_ids": [eos_id] if eos_id is not None else [],
            },
            "return_logprob": True,
        }
        resp = self.http.post(f"{self.base_url}/generate", json=payload, timeout=self.timeout)
        resp.raise_for_status()
        results = resp.json()
        if not isinstance(results, list):
            results = [results]
        if len(results) != len(all_input_ids):
            raise RuntimeError(
                f"SGLang returned {len(results)} results for {len(all_input_ids)} requests"
            )

        all_completion_ids, all_logprobs, completions = [], [], []
        for result in results:
            meta = result.get("meta_info", {})
            completion_ids = list(meta.get("output_ids", result.get("output_ids", [])))
            all_completion_ids.append(completion_ids)
            all_logprobs.append(self._parse_logprobs(meta.get("output_token_logprobs", [])))
            completions.append(self.tokenizer.decode(completion_ids, skip_special_tokens=True))

        width = max(len(ids) for ids in all_completion_ids)
        # Pad completions the same way generate() does (pad_token_id), falling back to 0.
        pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0

        def pad_rows(rows, pad_val, dtype):
            # Right-pad (or truncate) every row to the completion width so rows stack cleanly.
            return torch.tensor(
                [row[:width] + [pad_val] * (width - len(row)) for row in rows],
                dtype=dtype,
                device=device,
            )

        completion_tensor = pad_rows(all_completion_ids, pad_id, torch.long)
        # Log-probs use the completion width (not their own max) to stay aligned with completion_ids.
        logp_tensor = pad_rows(all_logprobs, 0.0, torch.float32)
        prompt_tensor = prompt_ids.repeat_interleave(num_generations, dim=0)
        output_ids = torch.cat([prompt_tensor, completion_tensor.to(prompt_tensor.dtype)], dim=1)
        return RolloutResult(
            output_ids=output_ids,
            completion_ids=completion_tensor,
            per_token_logps=logp_tensor,
            completions=completions,
        )

    def update_policy(self, model: torch.nn.Module):
        """Save ``model``'s weights (fp16) plus the tokenizer to ``shared_ckpt_path`` and ask the
        server to reload them.

        Returns:
            True if the server answered HTTP 200; otherwise prints a warning and returns False.
            Connection errors from the HTTP client propagate.
        """
        unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
        abs_path = os.path.abspath(self.shared_ckpt_path)
        original_lm_head = unwrapped.lm_head.weight
        original_embed = unwrapped.model.embed_tokens.weight
        try:
            # Give lm_head its own Parameter only while the checkpoint is being written.
            unwrapped.lm_head.weight = torch.nn.Parameter(original_lm_head.clone())
            state_dict = {k: v.detach().half().cpu() for k, v in unwrapped.state_dict().items()}
            unwrapped.save_pretrained(abs_path, state_dict=state_dict, safe_serialization=False)
        finally:
            # Put the ORIGINAL Parameter objects back. The optimizer (and DDP, if used) hold
            # references to them, so leaving a fresh Parameter in the model would silently stop
            # these weights from being trained/synchronised. If the two were tied before, they
            # are tied again; if not, they stay untied.
            unwrapped.lm_head.weight = original_lm_head
            unwrapped.model.embed_tokens.weight = original_embed
        self.tokenizer.save_pretrained(abs_path)
        resp = self.http.post(
            f"{self.base_url}/update_weights_from_disk",
            json={"model_path": abs_path},
            timeout=self.timeout
        )
        if resp.status_code != 200: print(f"[SGLANG WARNING] update_weights 失败: {resp.status_code}, {resp.text}")
        return resp.status_code == 200

    def flush_cache(self) -> bool:
        """Ask the server to flush its cache; True if it answered HTTP 200."""
        resp = self.http.post(f"{self.base_url}/flush_cache", timeout=30)
        return resp.status_code == 200

    def health(self) -> bool:
        """Return True if ``/health`` answers HTTP 200 within 5 s, False on any request failure."""
        try:
            resp = self.http.get(f"{self.base_url}/health", timeout=5)
            return resp.status_code == 200
        except requests.RequestException:
            # Narrowed from a bare except so KeyboardInterrupt/SystemExit are no longer swallowed.
            return False
# ===== Factory =====
def create_rollout_engine(
    engine_type: str = "torch",
    policy_model: torch.nn.Module = None,
    tokenizer = None,
    device: str = "cuda",
    autocast_ctx = None,
    sglang_base_url: Optional[str] = None,
    sglang_model_path: Optional[str] = None,
    sglang_shared_path: Optional[str] = None,
) -> RolloutEngine:
    """Build the rollout engine selected by ``engine_type``.

    Args:
        engine_type: ``"torch"`` (in-process generate) or ``"sglang"`` (HTTP server).
        policy_model, tokenizer, device, autocast_ctx: used only by the torch engine.
        sglang_base_url, sglang_model_path: required for the sglang engine.
        sglang_shared_path: checkpoint directory for the sglang engine; ``None`` keeps that
            engine's own default instead of passing ``None`` through.

    Raises:
        ValueError: unknown ``engine_type``, or missing SGLang URL/model path.
    """
    if engine_type == "torch":
        return TorchRolloutEngine(policy_model, tokenizer, device, autocast_ctx)
    if engine_type == "sglang":
        if sglang_base_url is None or sglang_model_path is None:
            raise ValueError("sglang engine requires sglang_base_url and sglang_model_path")
        if sglang_shared_path is None:
            # Passing None would override the engine's default and crash later in os.path.abspath.
            return SGLangRolloutEngine(sglang_base_url, sglang_model_path)
        return SGLangRolloutEngine(sglang_base_url, sglang_model_path, sglang_shared_path)
    raise ValueError(f"不支持的引擎类型: {engine_type}")