mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-06-06 00:04:50 +00:00
218 lines
9.2 KiB
Python
218 lines
9.2 KiB
Python
"""
# To use SGLang acceleration, first launch the model (in transformers format) with the following command:
python -m sglang.launch_server --model-path ./minimind-3 --attention-backend triton --host 0.0.0.0 --port 8998
"""
|
|
import os
|
|
import sys
|
|
|
|
# Pin the package name to "trainer" for this module.
# NOTE(review): presumably so package-relative imports work when the file is run as a script — confirm.
__package__ = "trainer"
# Append the parent directory of this file's directory to sys.path.
# This must run before any import that relies on that directory being importable.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
|
|
|
import requests
|
|
import torch
|
|
import torch.distributed as dist
|
|
from abc import ABC, abstractmethod
|
|
from contextlib import nullcontext
|
|
from dataclasses import dataclass
|
|
from typing import List, Optional, Tuple
|
|
from torch import Tensor
|
|
from torch.nn.parallel import DistributedDataParallel
|
|
from transformers import AutoTokenizer
|
|
|
|
|
|
# ===== 计算每个 token 的 logprob =====
|
|
def compute_per_token_logps(model, input_ids: Tensor, n_keep: int, attention_mask: Optional[Tensor] = None) -> Tensor:
    """Return the log-probability of each of the last ``n_keep`` tokens of ``input_ids``.

    The model is called with ``logits_to_keep=n_keep + 1``; the final logit
    position is dropped and the remaining positions are paired, in order, with
    the last ``n_keep`` token ids of every row.

    Args:
        model: Callable model (optionally wrapped in ``DistributedDataParallel``)
            whose output exposes a ``.logits`` attribute.
        input_ids: 2-D tensor of token ids (indexed as ``[row, position]``).
        n_keep: Number of trailing tokens to score.
        attention_mask: Optional mask forwarded unchanged to the model.

    Returns:
        A tensor with one row per input row and ``n_keep`` columns. When
        ``n_keep <= 0`` an empty float32 tensor with zero columns is returned.
    """
    if n_keep <= 0:
        return input_ids.new_empty((input_ids.size(0), 0), dtype=torch.float32)

    # A DDP wrapper keeps the real network behind `.module`.
    net = model.module if isinstance(model, DistributedDataParallel) else model

    # Inference-mode tensors are swapped for ordinary copies before the forward pass.
    if input_ids.is_inference():
        input_ids = input_ids.detach().clone()

    outputs = net(input_ids, attention_mask=attention_mask, logits_to_keep=n_keep + 1)
    # Drop the last position so that n_keep logit rows remain per sequence.
    # NOTE(review): presumably position t scores token t + 1 (next-token alignment) — confirm.
    kept_logits = outputs.logits[:, :-1, :]
    target_ids = input_ids[:, -n_keep:]

    # Work one sequence at a time.
    # NOTE(review): presumably to bound the peak memory of log_softmax — confirm.
    rows = []
    for row_logits, row_ids in zip(kept_logits, target_ids):
        if row_ids.is_inference():
            row_ids = row_ids.detach().clone()
        row_logps = row_logits.log_softmax(dim=-1)
        rows.append(torch.gather(row_logps, 1, row_ids.unsqueeze(1)).squeeze(1))
    return torch.stack(rows)
|
|
|
|
|
|
# ===== Rollout 结果 =====
|
|
@dataclass
class RolloutResult:
    """Bundle of everything a rollout engine returns for one batch of generations."""

    # Token ids of the full sequences (prompt followed by completion).
    output_ids: Tensor
    # Token ids of the generated completions only.
    completion_ids: Tensor
    # One log-probability per completion token.
    per_token_logps: Tensor
    # Decoded completion texts, one entry per generated sequence.
    completions: List[str]
|
|
|
|
|
|
# ===== Rollout 引擎抽象基类 =====
|
|
class RolloutEngine(ABC):
    """Abstract interface shared by all rollout (generation) backends."""

    # Concrete engines assign their tokenizer to this attribute.
    tokenizer = None

    @abstractmethod
    def rollout(self, prompt_ids: Tensor, attention_mask: Tensor, num_generations: int, max_new_tokens: int, temperature: float = 0.8) -> RolloutResult:
        """Generate ``num_generations`` completions per prompt and return a ``RolloutResult``.

        Args:
            prompt_ids: Token ids of the prompts.
            attention_mask: Mask matching ``prompt_ids``.
            num_generations: Number of completions to sample for each prompt.
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Sampling temperature.
        """

    @abstractmethod
    def update_policy(self, model: torch.nn.Module):
        """Make the engine use ``model`` as the policy for subsequent rollouts."""
|
|
|
|
|
|
# ===== PyTorch 原生推理引擎 =====
|
|
class TorchRolloutEngine(RolloutEngine):
    """Rollout engine that samples directly from the in-process PyTorch policy model."""

    def __init__(self, policy_model: torch.nn.Module, tokenizer, device: str = "cuda", autocast_ctx=None):
        """Store the policy model, tokenizer, device name and optional autocast context."""
        self.policy_model = policy_model
        self.tokenizer = tokenizer
        self.device = device
        self.autocast_ctx = autocast_ctx

    def rollout(self, prompt_ids: Tensor, attention_mask: Tensor, num_generations: int, max_new_tokens: int, temperature: float = 0.8) -> RolloutResult:
        """Sample completions with ``generate`` and score them with the current policy.

        Args:
            prompt_ids: Prompt token ids; its second dimension is used as the prompt length.
            attention_mask: Mask forwarded to ``generate``.
            num_generations: Passed as ``num_return_sequences``.
            max_new_tokens: Passed as ``max_new_tokens``.
            temperature: Sampling temperature.

        Returns:
            A ``RolloutResult`` holding the full sequences, the completion slice,
            the per-token log-probabilities and the decoded completion strings.
        """
        policy = self.policy_model
        # `generate` is called on the raw module, not on the DDP wrapper.
        generator = policy.module if isinstance(policy, DistributedDataParallel) else policy
        amp_ctx = self.autocast_ctx or nullcontext()
        pad_id = self.tokenizer.pad_token_id
        generate_kwargs = dict(
            input_ids=prompt_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            num_return_sequences=num_generations,
            pad_token_id=pad_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        # NOTE(review): the scraped source lost its indentation; the log-prob pass is
        # assumed to run inside the no-grad/autocast block together with generation — confirm.
        with torch.no_grad(), amp_ctx:
            sequences = generator.generate(**generate_kwargs)  # [B*num_gen, P+R]
            # Everything after the prompt columns is the generated completion.
            generated = sequences[:, prompt_ids.size(1):]  # [B*num_gen, R]
            # Positions holding the pad token are masked out of the scoring pass.
            non_pad_mask = (sequences != pad_id).long()
            token_logps = compute_per_token_logps(
                self.policy_model, sequences, generated.size(1), attention_mask=non_pad_mask
            )

        texts = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
        return RolloutResult(sequences, generated, token_logps, texts)

    def update_policy(self, model: torch.nn.Module):
        """Replace the stored policy model with ``model``."""
        self.policy_model = model
|
|
|
|
|
|
# ===== SGLang HTTP API 推理引擎 =====
|
|
class SGLangRolloutEngine(RolloutEngine):
    """Rollout engine that generates completions through an SGLang server's HTTP API.

    Generation uses ``POST {base_url}/generate``. Weight updates are done by saving
    a checkpoint to ``shared_ckpt_path`` and asking the server to reload it through
    ``POST {base_url}/update_weights_from_disk``.
    """

    def __init__(self, base_url: str, model_path: str, shared_ckpt_path: str = "./sglang_ckpt", timeout: int = 120):
        """Configure the HTTP endpoint and load the tokenizer from ``model_path``.

        Args:
            base_url: Server base URL; trailing slashes are stripped.
            model_path: Path or name passed to ``AutoTokenizer.from_pretrained``.
            shared_ckpt_path: Directory used to hand checkpoints to the server.
            timeout: Timeout in seconds for generate / update requests.
        """
        self.base_url = base_url.rstrip('/')
        self.shared_ckpt_path = shared_ckpt_path
        self.timeout = timeout
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.http = requests

    @staticmethod
    def _align_logprobs(raw_logprobs, target_len: int) -> List[float]:
        """Flatten server logprob entries and force the result to ``target_len`` items."""
        logprobs = []
        for item in raw_logprobs:
            # Entries are either sequences whose first element is the logprob, or bare numbers.
            if isinstance(item, (list, tuple)) and len(item) >= 1:
                logprobs.append(item[0])
            elif isinstance(item, (int, float)):
                logprobs.append(item)
        missing = target_len - len(logprobs)
        if missing > 0:
            # Too few values: left-pad with zeros.
            return [0.0] * missing + logprobs
        # Too many values: keep the trailing ones. An explicit start index is used
        # because `logprobs[-0:]` would return the whole list for an empty completion.
        return logprobs[len(logprobs) - target_len:]

    def rollout(self, prompt_ids: Tensor, attention_mask: Tensor, num_generations: int, max_new_tokens: int, temperature: float = 0.8) -> RolloutResult:
        """Request ``num_generations`` completions per prompt from the server.

        Args:
            prompt_ids: Padded prompt token ids, one row per prompt.
            attention_mask: Mask selecting the non-padding tokens of each row.
            num_generations: Number of completions requested per prompt.
            max_new_tokens: Generation length limit sent to the server.
            temperature: Sampling temperature sent to the server.

        Returns:
            A ``RolloutResult`` whose tensors are right-padded to the longest
            sequence and placed on ``prompt_ids.device``.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        # Strip padding tokens and keep only the valid ones.
        input_ids_list = [ids[mask.bool()].tolist() for ids, mask in zip(prompt_ids, attention_mask)]
        # Repeat every prompt so the server returns num_generations samples for each.
        all_input_ids = [ids for ids in input_ids_list for _ in range(num_generations)]

        # Compare against None: token id 0 is a valid EOS id and must not be dropped.
        eos_id = self.tokenizer.eos_token_id
        payload = {
            "input_ids": all_input_ids,
            "sampling_params": {
                "temperature": temperature,
                "max_new_tokens": max_new_tokens,
                "stop_token_ids": [eos_id] if eos_id is not None else [],
            },
            "return_logprob": True,
        }

        resp = self.http.post(f"{self.base_url}/generate", json=payload, timeout=self.timeout)
        resp.raise_for_status()

        results = resp.json()
        if not isinstance(results, list):
            results = [results]

        all_output_ids, all_completion_ids, all_logprobs = [], [], []
        completions = []

        for i, result in enumerate(results):
            meta = result.get("meta_info", {})
            # Prefer the ids inside meta_info, falling back to the top-level field.
            completion_ids = meta.get("output_ids", result.get("output_ids", []))
            logprobs = self._align_logprobs(meta.get("output_token_logprobs", []), len(completion_ids))

            all_output_ids.append(all_input_ids[i] + completion_ids)
            all_completion_ids.append(completion_ids)
            all_logprobs.append(logprobs)
            completions.append(self.tokenizer.decode(completion_ids, skip_special_tokens=True))

        device = prompt_ids.device
        max_out_len = max(len(ids) for ids in all_output_ids)
        max_comp_len = max(len(ids) for ids in all_completion_ids)

        def pad_to_tensor(seqs, max_len, pad_val, dtype):
            # Right-pad every sequence to max_len. The dtype is explicit so that
            # all-empty rows still produce a tensor of the intended type.
            return torch.tensor([s + [pad_val] * (max_len - len(s)) for s in seqs], dtype=dtype, device=device)

        pad_id = self.tokenizer.pad_token_id
        return RolloutResult(
            output_ids=pad_to_tensor(all_output_ids, max_out_len, pad_id, torch.long),
            completion_ids=pad_to_tensor(all_completion_ids, max_comp_len, pad_id, torch.long),
            per_token_logps=pad_to_tensor(all_logprobs, max_comp_len, 0.0, torch.float32),
            completions=completions,
        )

    def update_policy(self, model: torch.nn.Module):
        """Save ``model`` to the shared checkpoint directory and ask the server to reload it.

        Only rank 0 (or the single process when ``torch.distributed`` is not
        initialised) writes the checkpoint and contacts the server; all ranks
        then meet at a barrier.

        Returns:
            ``True`` if the server answered 200 on the rank that sent the request.
            Ranks that do not send the request always return ``True``.
        """
        ok = True
        if not dist.is_initialized() or dist.get_rank() == 0:
            unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
            # Use `_orig_mod` when the attribute exists.
            # NOTE(review): presumably to unwrap a torch.compile wrapper — confirm.
            unwrapped = getattr(unwrapped, '_orig_mod', unwrapped)
            abs_path = os.path.abspath(self.shared_ckpt_path)
            # Export a detached fp16 CPU copy of every tensor in the state dict.
            state_dict = {k: v.detach().half().cpu() for k, v in unwrapped.state_dict().items()}
            unwrapped.save_pretrained(abs_path, state_dict=state_dict, safe_serialization=False)
            self.tokenizer.save_pretrained(abs_path)
            resp = self.http.post(
                f"{self.base_url}/update_weights_from_disk",
                json={"model_path": abs_path},
                timeout=self.timeout,
            )
            ok = resp.status_code == 200
            if not ok:
                print(f"[SGLANG WARNING] update_weights 失败: {resp.status_code}, {resp.text}")
        if dist.is_initialized():
            dist.barrier()
        return ok

    def flush_cache(self) -> bool:
        """POST to the server's ``/flush_cache`` endpoint; return ``True`` on HTTP 200."""
        resp = self.http.post(f"{self.base_url}/flush_cache", timeout=30)
        return resp.status_code == 200

    def health(self) -> bool:
        """Return ``True`` if ``GET /health`` answers 200, ``False`` on any request failure."""
        try:
            resp = self.http.get(f"{self.base_url}/health", timeout=5)
        except requests.RequestException:
            # Connection errors and timeouts mean "unhealthy". Unlike a bare `except:`,
            # KeyboardInterrupt and SystemExit still propagate.
            return False
        return resp.status_code == 200
|
|
|
|
|
|
# ===== 工厂函数 =====
|
|
def create_rollout_engine(
    engine_type: str = "torch",
    policy_model: Optional[torch.nn.Module] = None,
    tokenizer=None,
    device: str = "cuda",
    autocast_ctx=None,
    sglang_base_url: Optional[str] = None,
    sglang_model_path: Optional[str] = None,
    sglang_shared_path: Optional[str] = None,
) -> RolloutEngine:
    """Build the rollout engine selected by ``engine_type``.

    Args:
        engine_type: ``"torch"`` for ``TorchRolloutEngine`` or ``"sglang"`` for
            ``SGLangRolloutEngine``.
        policy_model: Policy model for the torch engine.
        tokenizer: Tokenizer for the torch engine.
        device: Device name for the torch engine.
        autocast_ctx: Optional autocast context for the torch engine.
        sglang_base_url: Base URL of the SGLang server.
        sglang_model_path: Path used by the SGLang engine to load its tokenizer.
        sglang_shared_path: Checkpoint hand-off directory for the SGLang engine.
            When ``None``, the engine's own default is used.

    Returns:
        The constructed engine.

    Raises:
        ValueError: If ``engine_type`` is not a supported value.
    """
    if engine_type == "torch":
        return TorchRolloutEngine(policy_model, tokenizer, device, autocast_ctx)
    if engine_type == "sglang":
        if sglang_shared_path is None:
            # Forwarding None would override the engine's default checkpoint
            # directory and later break os.path.abspath() in update_policy.
            return SGLangRolloutEngine(sglang_base_url, sglang_model_path)
        return SGLangRolloutEngine(sglang_base_url, sglang_model_path, sglang_shared_path)
    raise ValueError(f"不支持的引擎类型: {engine_type}")
|