[update] minimind new docs
Binary image updates under `docs/images/`: added LLM-structure.jpg, LLM-structure-moe.jpg, agent_rl_loss.jpg, agent_webui.jpg, benchmark_radar.jpg, grpo_loss.jpg, minimind-3.gif, ppo_loss.jpg, pretrain_loss.jpg, rl-structure.jpg, and sft_loss.jpg; several outdated images were removed or replaced.

## 📌 Introduction

**MiniMind** is a complete, open-source project for training ultra-small language models from scratch at minimal cost. Train a **64M** ChatBot in just **2 hours** for only **$3** on a single 3090 GPU!

- The **MiniMind** series is extremely lightweight; the smallest version is **1/2700** the size of GPT-3
- Complete implementation covering:
  - **Tokenizer training** with custom vocabulary
  - **Pretraining** (knowledge learning)
  - **Supervised Fine-Tuning (SFT)** (conversation patterns)
  - **LoRA fine-tuning** (parameter-efficient adaptation)
  - **Direct Preference Optimization (DPO)** (human preference alignment)
  - **RLAIF algorithms** (PPO/GRPO/CISPO reinforcement learning)
  - **Agentic RL** (multi-turn Tool-Use with delayed rewards)
  - **Knowledge distillation** (black-box & white-box)
  - **Tool Calling & Adaptive Thinking** (native template support)
  - **YaRN algorithm** (context length extrapolation)
- **Pure PyTorch implementation**: all core algorithms are implemented from scratch in native PyTorch, without relying on third-party abstractions
- **Educational value**: not only a full-stage open-source reproduction of large language models, but also a comprehensive tutorial for getting started with LLMs

## ✨ Key Highlights

- **Ultra-low cost**: a single 3090, 2 hours, and $3 to train a fully functional ChatBot from scratch
- **Complete pipeline**: Tokenizer → Pretraining → SFT → LoRA → DPO → PPO/GRPO/CISPO → Agentic RL
- **Latest algorithms**: implements cutting-edge techniques including GRPO, CISPO, Agentic RL, and YaRN
- **Education-friendly**: clean, well-documented code suitable for learning LLM principles
- **Ecosystem compatible**: seamless support for `transformers`, `llama.cpp`, `vllm`, `ollama`, `SGLang`, and `MNN`
- **Full capabilities**: multi-GPU training (DDP/DeepSpeed), training visualization (Wandb/SwanLab), and dynamic checkpoint management
- **Production-ready**: OpenAI API with Tool Calling & Adaptive Thinking for easy integration with FastGPT, Open-WebUI, Dify, etc.
- **Multimodal extension**: extended to vision with [MiniMind-V](https://github.com/jingyaogong/minimind-v)
## 📊 Model Series

### MiniMind-3 Series (Latest - 2026.03.20)

| Model | Parameters | Vocabulary | Layers | Hidden Dim | Context | Inference Memory |
|-------|-----------|------------|--------|-----------|---------|-----------------|
| MiniMind-3 | 64M | 6,400 | 8 | 768 | 2K | ~0.5 GB |
| MiniMind-3-MoE | 198M / A64M | 6,400 | 8 | 768 | 2K | ~1.0 GB |

## 📅 Latest Updates (2026-03-20)

🔥 **MiniMind-3 Release**: architecture aligned with Qwen3/Qwen3-MoE; Dense ~64M, MoE ~198M/A64M

- **Agentic RL**: new `train_agent.py` for multi-turn Tool-Use RL with GRPO/CISPO
- **RLAIF rollout engine**: decoupled training/inference for flexible backends (SGLang, etc.)
- **Tool Calling & Adaptive Thinking**: native template support with an `open_thinking` switch
- **OpenAI API**: `serve_openai_api.py` supports `reasoning_content`, `tool_calls`, and `open_thinking`
- **Tokenizer**: updated BPE + ByteLevel tokenizer with tool-call and thinking special tokens
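The tokenizer mentioned in the updates is BPE with byte-level pre-tokenization. As a rough illustration of what one BPE training step does (a toy pure-Python sketch, not the project's actual tokenizer code), each step merges the most frequent adjacent symbol pair in the corpus:

```python
from collections import Counter

def most_frequent_pair(words):
    # words: dict mapping a tuple of symbols (one word) to its corpus frequency
    pairs = Counter()
    for syms, freq in words.items():
        for a, b in zip(syms, syms[1:]):
            pairs[(a, b)] += freq
    return max(pairs, key=pairs.get)

def merge_pair(words, pair):
    # Replace every occurrence of the chosen pair with one merged symbol
    merged = {}
    for syms, freq in words.items():
        out, i = [], 0
        while i < len(syms):
            if i + 1 < len(syms) and (syms[i], syms[i + 1]) == pair:
                out.append(syms[i] + syms[i + 1])
                i += 2
            else:
                out.append(syms[i])
                i += 1
        merged[tuple(out)] = merged.get(tuple(out), 0) + freq
    return merged
```

Repeating this until the vocabulary reaches its budget (6,400 entries here, including the reserved special tokens) yields the merge table a BPE tokenizer applies at encode time.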
## 🎯 Project Contents

- Complete MiniMind-LLM architecture code (Dense + MoE models)
- Detailed Tokenizer training code
- Full training pipeline: Pretrain → SFT → LoRA → DPO → PPO/GRPO/CISPO → Agentic RL
- Tool Calling & Adaptive Thinking (native chat template support)
- High-quality, curated and deduplicated datasets for all stages
- Native PyTorch implementation of key algorithms with minimal third-party dependencies
- Multi-GPU training support (single-machine multi-card DDP, DeepSpeed, distributed clusters)
- Visualization with SwanLab
- Model evaluation on third-party benchmarks (C-Eval, CMMLU, OpenBookQA, etc.)
- YaRN algorithm for RoPE context length extrapolation
- OpenAI API server with `reasoning_content` / `tool_calls` / `open_thinking`
- Streamlit web UI for chat
- Full compatibility with community tools: llama.cpp, vllm, ollama, SGLang, MNN

## 🚀 Quick Navigation
Choose one option:
**From HuggingFace** (recommended for international users):

```bash
git clone https://huggingface.co/jingyaogong/minimind-3
```

**From ModelScope** (recommended for users in China):

```bash
git clone https://www.modelscope.cn/models/gongjy/minimind-3.git
```
### 3. Command-Line Chat

```bash
# Use a transformers-format model
python eval_llm.py --load_from ./minimind-3
```
**Weight Options** (`--weight` parameter):
- `pretrain`: pretrained model (text continuation)
- `full_sft`: SFT chat model (conversation)
- `dpo`: DPO model (preference optimization)
- `reason`: reasoning model (with thinking chains)
- `ppo_actor`, `grpo`: RLAIF models (trained with reinforcement learning)
- `agent`: Agentic RL model (multi-turn Tool-Use)

MiniMind is compatible with popular inference engines:

### Ollama (Easiest)

```bash
ollama run jingyaogong/minimind-3
```

### vLLM (Fastest)

```bash
vllm serve ./minimind-3/ --model-impl transformers --served-model-name "minimind" --port 8998

# Test with curl
curl http://localhost:8998/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "minimind",
    "messages": [{"role": "user", "content": "Hello"}]
  }'
```
### llama.cpp (CPU-Friendly)

```bash
# Convert to GGUF format (run inside the llama.cpp directory)
python convert_hf_to_gguf.py /path/to/minimind-3

# Quantize to reduce size
./build/bin/llama-quantize /path/to/model/xxxx.gguf /path/to/model/xxxx.q8.gguf Q8_0

# Run inference
./build/bin/llama-cli -m /path/to/model/xxxx.gguf
```
## 🔌 OpenAI API Server (For Integration)

Run MiniMind as an OpenAI API-compatible service:

```bash
cd scripts && python serve_openai_api.py
```

Test the API:

```bash
# In another terminal
cd scripts && python chat_api.py
```

**cURL Example**:
```bash
curl http://localhost:8998/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "minimind",
    "messages": [
      {"role": "user", "content": "Hello"}
    ],
    "temperature": 0.7,
    "max_tokens": 256,
    "stream": true,
    "open_thinking": true
  }'
```
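The same endpoint can be called from Python with only the standard library. This is a sketch that mirrors the cURL example, assuming the server from `serve_openai_api.py` is running on port 8998; `open_thinking` is MiniMind's non-standard extension field:

```python
import json
from urllib import request

# Request body mirroring the cURL example; "open_thinking" toggles the
# model's thinking chain (a MiniMind extension, not part of the OpenAI spec).
payload = {
    "model": "minimind",
    "messages": [{"role": "user", "content": "Hello"}],
    "temperature": 0.7,
    "max_tokens": 256,
    "stream": False,
    "open_thinking": True,
}

def chat(url="http://localhost:8998/v1/chat/completions"):
    req = request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with request.urlopen(req) as resp:
        body = json.loads(resp.read())
    # Standard OpenAI response shape: choices[0].message.content
    return body["choices"][0]["message"]["content"]
```

With `stream` set to `False` the server returns one JSON object; for streaming, the response arrives as server-sent events instead.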

This enables integration with:

| Use Case | Recommended Model | Memory | Speed |
|----------|------------------|--------|-------|
| Learning/Testing | MiniMind-3 (64M) | ~0.5 GB | Fast |
| Expert System (MoE) | MiniMind-3-MoE (198M/A64M) | ~1.0 GB | Dynamic |
| Tool-Use / Agent | MiniMind-3 Agent (64M) | ~0.5 GB | Fast |

## ⚡ Quick Test Results

**Model**: MiniMind-3 (64M parameters)

```text
Q: What is photosynthesis?
```

---

# docs/training.md

```
Tokenizer Training
        ↓
Pretraining (Learn knowledge)
        ↓
SFT (Learn conversation)
        ↓
┌──────────────┬────────────────┬────────────────────────┬──────────────┐
↓              ↓                ↓                        ↓
LoRA           DPO/RLHF         RLAIF (PPO/GRPO/CISPO)   Agentic RL
(Domain adapt) (Preference)     (Reinforcement Learn)    (Tool-Use RL)
```

## 💰 Training Costs (Single NVIDIA 3090)

| Model | Params | pretrain_t2t_mini | sft_t2t_mini | toolcall | RLAIF |
|-------|--------|-------------------|--------------|----------|-------|
| MiniMind-3 | 64M | ≈1.21h / ≈1.57¥ | ≈1.10h / ≈1.43¥ | ≈0.9h / ≈1.17¥ | ≈1.1h / ≈1.43¥ |
| MiniMind-3-MoE | 198M / A64M | ≈1.69h / ≈2.20¥ | ≈1.54h / ≈2.00¥ | ≈1.26h / ≈1.64¥ | ≈1.54h / ≈2.00¥ |

!!! success "Ultra-Fast Training"
    **Just ~2.3 hours + ¥3 = a functional ChatBot!**

    Use `pretrain_t2t_mini` + `sft_t2t_mini` for the fastest reproduction
## 📋 Data Preparation

Download the datasets into `./dataset/`:

```
./dataset/
├── pretrain_t2t.jsonl (3.2GB, full pretraining)
├── pretrain_t2t_mini.jsonl ✨ (0.8GB, quick pretraining)
├── sft_t2t.jsonl (14GB, full SFT)
├── sft_t2t_mini.jsonl ✨ (1.6GB, fastest SFT)
├── dpo.jsonl ✨ (55MB, DPO training)
├── rlaif.jsonl (20MB, RLAIF algorithms)
├── agent_rl.jsonl (Agentic RL data)
├── agent_rl_math.jsonl (Agentic RL math data)
├── lora_identity.jsonl (22.8KB, identity LoRA)
└── lora_medical.jsonl (34MB, medical-domain LoRA)
```

### 3. Data Formats

**Pretraining Data** (`pretrain_t2t_mini.jsonl`):
```json
{"text": "How to overcome procrastination? Overcoming procrastination is not easy, but these suggestions may help..."}
```
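Because each line is an independent JSON object, the file can be streamed without loading it all into memory. A minimal reader sketch (the project's own dataset class may differ):

```python
import json

def iter_pretrain_texts(path):
    """Yield the "text" field from each non-empty line of a pretrain JSONL file."""
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                yield json.loads(line)["text"]
```

A training dataset would wrap this iterator, tokenize each text, and chunk the token stream to `max_seq_len`.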

### Stage 1: Pretraining

```bash
# Multi-GPU with DeepSpeed
deepspeed --master_port 29500 --num_gpus=2 train_pretrain.py
```

**Output**: `./out/pretrain_*.pth`

**Training Duration**:
- MiniMind-3 (64M): ~1.21h
- MiniMind-3-MoE (198M/A64M): ~1.69h

!!! tip "Pretraining Tips"
    - Start with `pretrain_t2t_mini.jsonl` for quick results

### Stage 2: Supervised Fine-Tuning (SFT)

```bash
torchrun --nproc_per_node 2 train_full_sft.py
```

**Configuration**:
- Load the pretrained model from Stage 1
- Use an SFT dataset (`sft_t2t_mini.jsonl` or `sft_t2t.jsonl`)
- Adjust `max_seq_len` to match the training data

**Output**: `./out/full_sft_*.pth`

!!! warning "SFT Data Selection"
    - `sft_t2t_mini.jsonl`: fastest, ~1.6GB
    - `sft_t2t.jsonl`: full, ~14GB
### Stage 3: LoRA Fine-Tuning (Optional)

### Stage 4: Direct Preference Optimization (DPO)

```bash
torchrun --nproc_per_node 2 train_dpo.py
```
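The DPO objective itself fits in a few lines. Given sequence log-probabilities of the chosen and rejected responses under the policy and a frozen reference model, the per-example loss is `-log σ(β · margin)`. A pure-Python sketch of the standard formula (not the project's training code, which operates on batched tensors):

```python
import math

def dpo_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta=0.1):
    # margin: how much more the policy prefers "chosen" over "rejected",
    # measured relative to the frozen reference model
    margin = beta * ((pi_chosen - ref_chosen) - (pi_rejected - ref_rejected))
    return -math.log(1.0 / (1.0 + math.exp(-margin)))  # -log(sigmoid(margin))
```

At initialization the policy equals the reference, the margin is zero, and the loss starts at `log 2`; it shrinks as the policy learns to prefer the chosen response.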
### Stage 5: Reinforcement Learning from AI Feedback (RLAIF)

RLAIF is an advanced training approach that uses AI-generated rewards instead of human annotations. MiniMind implements multiple algorithms:
#### 5.1 PPO (Proximal Policy Optimization)

#### 5.2 GRPO (Group Relative Policy Optimization)

GRPO samples a group of completions per prompt and normalizes each reward within its group:

$$A_t = \frac{R - \mu_{group}}{\sigma_{group}}$$

**Training Duration**: ~1-3 hours
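The group-normalized advantage can be computed directly from the sampled group's rewards. A pure-Python sketch (the epsilon guard against zero variance is an assumption, not necessarily the project's exact handling):

```python
def group_advantages(rewards, eps=1e-6):
    """A_i = (R_i - mean) / (std + eps), normalized within one sampled group."""
    n = len(rewards)
    mu = sum(rewards) / n
    std = (sum((r - mu) ** 2 for r in rewards) / n) ** 0.5
    return [(r - mu) / (std + eps) for r in rewards]
```

Because the baseline is the group mean, GRPO needs no learned value network: above-average completions get positive advantages, below-average ones negative.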

#### 5.3 CISPO (Clipped Importance Sampling Policy Optimization)

CISPO fixes a long-standing issue in PPO/GRPO where clipped ratios kill gradient flow: it rewrites the policy term as "clipped weight × log-probability", so gradients still pass through even when the ratio is truncated.

CISPO is implemented as a loss variant of GRPO; just set `loss_type` to `cispo` in `train_grpo.py`:

```bash
# In train_grpo.py, set loss_type = 'cispo'
python train_grpo.py
```

**Algorithm**:

$$\mathcal{L}_{CISPO} = -\mathbb{E}\left[\min(r_t, \varepsilon_{max}) \cdot A_t \cdot \log \pi_\theta(a_t|s) - \beta \cdot \text{KL}_t\right]$$

**Characteristics**:
- Gradient flow is preserved even when the ratio is clipped
- Shares GRPO's group sampling and advantage computation
- Single network, memory efficient
- No separate training script needed
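For a single token, the CISPO policy term is easy to state: the clipped importance weight multiplies the log-probability and is treated as a constant in the backward pass, so the gradient through `log π` never vanishes. A pure-Python sketch (the `eps_max` value is illustrative):

```python
import math

def cispo_token_loss(logp_new, logp_old, advantage, eps_max=4.0, kl=0.0, beta=0.0):
    """One-token CISPO policy loss. In a real implementation the clipped
    importance weight is detached (stop-gradient), so gradients flow only
    through logp_new, even when the ratio is truncated at eps_max."""
    ratio = math.exp(logp_new - logp_old)
    weight = min(ratio, eps_max)          # clipped importance-sampling weight
    return -(weight * advantage * logp_new - beta * kl)
```

Contrast with PPO: once PPO's ratio is clipped, the token's gradient is exactly zero; here the clipped weight merely rescales a still-live `log π` term.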

#### 5.4 Agentic RL (Multi-turn Tool-Use)

Agentic RL trains the model to perform multi-turn tool calling with delayed rewards: the model emits tool-call actions, receives observations, and continues until the task completes.

```bash
python train_agent.py

# Multi-GPU
torchrun --nproc_per_node N train_agent.py
```

**Reward**:

$$R(\tau) = R_{\text{answer}} + R_{\text{tool}} + R_{\text{format}} + R_{\text{rm}} - R_{\text{unfinished}}$$

**Characteristics**:
- Multi-turn rollout with tool execution
- Delayed reward (scored after the full trajectory)
- Decoupled rollout engine for flexible inference backends
- Supports GRPO/CISPO loss variants

**Data**: `agent_rl.jsonl` / `agent_rl_math.jsonl` (with a `gt` field for answer verification)

**Output**: `./out/agent_*.pth`
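The delayed reward sums several components scored once the whole trajectory ends. A sketch with illustrative weights (the actual components and coefficients live in `train_agent.py` and may differ):

```python
def trajectory_reward(answer_correct, tool_ok, format_ok, rm_score, finished):
    """Score one multi-turn Tool-Use trajectory after it ends.
    The weights below are illustrative, not the project's actual values."""
    reward = 1.0 * float(answer_correct)   # R_answer: matches the `gt` field
    reward += 0.2 * float(tool_ok)         # R_tool: tool calls executed cleanly
    reward += 0.1 * float(format_ok)       # R_format: valid tool-call tags
    reward += rm_score                     # R_rm: reward-model score
    if not finished:
        reward -= 0.5                      # R_unfinished: truncation penalty
    return reward
```

Because the score arrives only at trajectory end, every token in the rollout shares the same (group-normalized) advantage, which is exactly why the GRPO/CISPO losses plug in unchanged.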

#### RLAIF Dataset Preparation

RLAIF algorithms use `rlaif.jsonl` (~20MB), downloaded with the other datasets into `./dataset/`.

The model generates completions during training, which are scored by a **Reward Model** (e.g., InternLM2-1.8B-Reward). Agentic RL uses `agent_rl.jsonl`, which carries an additional `gt` field for answer verification.

**Reward Model Setup**:

```bash
git clone https://huggingface.co/internlm/internlm2-1_8b-reward
```
#### RLAIF vs DPO Comparison

| Aspect | DPO | RLAIF (PPO/GRPO/CISPO) |
|--------|-----|------------------------|
| Training Type | Off-policy | On-policy |
| Data Freshness | Static pairs | Dynamic (generated) |
| Memory Usage | Lower | Higher |
| Best For | Preference refinement | Capability improvement |
### Stage 6: Knowledge Distillation (Optional)

**Purpose**: transfer knowledge from a larger teacher model to MiniMind

MiniMind supports both black-box and white-box distillation:

- **Black-box**: train on teacher-generated answers (equivalent to SFT on a stronger model's outputs)
- **White-box**: additionally fit the teacher's token distribution via a mixed CE + KL loss

```bash
python train_distillation.py

# Multi-GPU
torchrun --nproc_per_node N train_distillation.py
```

**Output**: Distilled model weights
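The white-box variant's mixed loss can be written down for a single token position: hard-label cross-entropy plus KL divergence from the teacher's distribution. A pure-Python sketch (`alpha` and the temperature are illustrative hyperparameters; real implementations work on batched logits):

```python
import math

def distill_loss(student_logits, teacher_probs, target_idx, alpha=0.5, temp=1.0):
    """Mixed CE + KL white-box distillation loss for one token position."""
    # softmax of student logits at the given temperature
    scaled = [z / temp for z in student_logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    student_probs = [e / total for e in exps]
    ce = -math.log(student_probs[target_idx])  # hard-label cross-entropy
    # KL(teacher || student) over the vocabulary
    kl = sum(p * math.log(p / q) for p, q in zip(teacher_probs, student_probs) if p > 0)
    return alpha * ce + (1 - alpha) * kl
```

Setting `alpha=1.0` recovers plain SFT on the hard labels; lowering it shifts weight onto matching the teacher's full distribution.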

## 🔧 Multi-GPU Training

## Model Evaluation

```bash
python eval_llm.py --weight full_sft

# DPO model with a LoRA adapter
python eval_llm.py --weight dpo --lora_weight lora_medical
```
### Evaluate RLAIF Models

```bash
# PPO model
python eval_llm.py --weight ppo_actor

# GRPO model
python eval_llm.py --weight grpo

# Agent model (Agentic RL)
python eval_toolcall.py --weight agent
```
### RoPE Length Extrapolation

### Model Configurations

| Config | MiniMind-3 | MiniMind-3-MoE |
|--------|-----------|----------------|
| Parameters | 64M | 198M / A64M |
| Hidden Dim | 768 | 768 |
| Layers | 8 | 8 |
| KV Heads | 2 | 2 |
| Q Heads | 8 | 8 |
| Vocab Size | 6,400 | 6,400 |
| Context Length | 2,048 | 2,048 |

### Modify Architecture

Edit `./model/LMConfig.py`:

```python
class LMConfig:
    hidden_size: int = 768
    num_layers: int = 8
    num_heads: int = 8
    num_kv_heads: int = 2
    # ... other configs
```

## References

- [YaRN Length Extrapolation](https://arxiv.org/abs/2309.00071)
- [PPO Algorithm](https://arxiv.org/abs/1707.06347)
- [GRPO (DeepSeek)](https://arxiv.org/pdf/2402.03300)
- [CISPO Algorithm](https://huggingface.co/papers/2506.13585)
- [DPO](https://arxiv.org/abs/2305.18290)

---