mirror of https://github.com/jingyaogong/minimind.git
synced 2026-05-01 19:58:15 +08:00

[feat] update docs

parent 941db7a5e6
commit f44ee7a1b0
BIN docs/images/train_grpo_512.png (new file, 214 KiB, binary not shown)
BIN docs/images/train_grpo_768.png (new file, 246 KiB, binary not shown)
BIN docs/images/train_ppo_512.png (new file, 246 KiB, binary not shown)
BIN docs/images/train_ppo_768.png (new file, 241 KiB, binary not shown)
BIN docs/images/train_spo_768.png (new file, 234 KiB, binary not shown)
131 docs/index.md
@@ -1,4 +1,4 @@
# <strong>Welcome to MiniMind!</strong>
# Welcome to MiniMind!

<figure markdown>


@@ -7,47 +7,118 @@

## 📌 Introduction

MiniMind is a super-small language model project trained completely from scratch, requiring **only $0.5 + 2 hours** to train a **26M** language model!
**MiniMind** is a complete, open-source project for training ultra-small language models from scratch at minimal cost. Train a **26M** ChatBot in just **2 hours** for only **$3** on a single 3090 GPU!

- The **MiniMind** series is extremely lightweight; the smallest version is **1/7000** the size of GPT-3
- The project open-sources the minimalist structure of large models, including:
    - Mixture of Experts (MoE)
    - Dataset cleaning
    - Pretraining
    - Supervised Fine-Tuning (SFT)
    - LoRA fine-tuning
    - Direct Preference Optimization (DPO)
    - Model distillation
- All core algorithm code is reconstructed from scratch in native PyTorch, without relying on third-party abstract interfaces
- This is not only a full-stage open-source reproduction of large language models, but also a tutorial for getting started with LLMs
- Complete implementation covering:
    - **Tokenizer training** with a custom vocabulary
    - **Pretraining** (knowledge learning)
    - **Supervised Fine-Tuning (SFT)** (conversation patterns)
    - **LoRA fine-tuning** (parameter-efficient adaptation)
    - **Direct Preference Optimization (DPO)** (human preference alignment)
    - **RLAIF algorithms** (PPO/GRPO/SPO reinforcement learning)
    - **Knowledge distillation** (compressing large-model knowledge)
    - **Model reasoning distillation** (DeepSeek-R1 style)
    - **YaRN algorithm** (context length extrapolation)
- **Pure PyTorch implementation**: all core algorithms are implemented from scratch in native PyTorch, without relying on third-party abstract interfaces
- **Educational value**: this is not only a full-stage open-source reproduction of large language models, but also a comprehensive tutorial for getting started with LLMs
- **Extended capabilities**: MiniMind now supports [MiniMind-V](https://github.com/jingyaogong/minimind-v) for vision multimodal tasks
!!! note "Training Cost"
    "2 hours" is based on NVIDIA 3090 hardware (single card) testing; "$0.5" refers to GPU server rental cost

!!! note "Training Cost & Time"
    "2 hours" is based on **NVIDIA 3090** hardware (single card) testing

    "$3" refers to GPU server rental cost

    With 8× RTX 4090 GPUs, training time can be compressed to **under 10 minutes**
## ✨ Key Features
## ✨ Key Highlights

- **Ultra-low cost**: Single 3090, 2 hours, $0.5 to train a ChatBot from scratch
- **Complete pipeline**: Covers the full Tokenizer, pretraining, SFT, LoRA, DPO, and distillation process
- **Education-friendly**: Clean code, suitable for learning LLM principles
- **Ecosystem compatible**: Supports `transformers`, `llama.cpp`, `vllm`, `ollama`, and other mainstream frameworks
- **Ultra-low cost**: Single 3090, 2 hours, $3 to train a fully functional ChatBot from scratch
- **Complete pipeline**: Tokenizer → Pretraining → SFT → LoRA → DPO/RLAIF → Distillation → Reasoning
- **Latest algorithms**: Implements cutting-edge techniques including GRPO, SPO, and YaRN
- **Education-friendly**: Clean, well-documented code suitable for learning LLM principles
- **Ecosystem compatible**: Seamless support for `transformers`, `trl`, `peft`, `llama.cpp`, `vllm`, `ollama`, and `Llama-Factory`
- **Full capabilities**: Supports multi-GPU training (DDP/DeepSpeed), training visualization (Wandb/SwanLab), and dynamic checkpoint management
- **Production-ready**: OpenAI API protocol support for easy integration with third-party UIs (FastGPT, Open-WebUI, etc.)
- **Multimodal extension**: Extended to vision with [MiniMind-V](https://github.com/jingyaogong/minimind-v)
## 📊 Model List
## 📊 Model Series

| Model (Size) | Inference Memory (Approx.) | Release |
|--------------|----------------------------|---------|
| MiniMind2-small (26M) | 0.5 GB | 2025.04.26 |
| MiniMind2-MoE (145M) | 1.0 GB | 2025.04.26 |
| MiniMind2 (104M) | 1.0 GB | 2025.04.26 |

### MiniMind2 Series (Latest - 2025.04.26)

| Model | Parameters | Vocabulary | Layers | Hidden Dim | Context | Inference Memory |
|-------|------------|------------|--------|------------|---------|------------------|
| MiniMind2-small | 26M | 6,400 | 8 | 512 | 2K | ~0.5 GB |
| MiniMind2-MoE | 145M | 6,400 | 8 | 640 | 2K | ~1.0 GB |
| MiniMind2 | 104M | 6,400 | 16 | 768 | 2K | ~1.0 GB |

### MiniMind-V1 Series (Legacy - 2024.09.01)

| Model | Parameters | Vocabulary | Layers | Hidden Dim | Context |
|-------|------------|------------|--------|------------|---------|
| minimind-v1-small | 26M | 6,400 | 8 | 512 | 2K |
| minimind-v1-moe | 104M | 6,400 | 8 | 512 | 2K |
| minimind-v1 | 108M | 6,400 | 16 | 768 | 2K |
## 📅 Latest Updates (2025-10-24)

🔥 **RLAIF Training Algorithms**: Native implementation of PPO, GRPO, and SPO

- **YaRN Algorithm**: RoPE length extrapolation for improved long-sequence handling
- **Adaptive Thinking**: Reasoning models support optional thinking chains
- **Full template support**: Tool calling and reasoning tags (`<tool_call>`, `<think>`, etc.)
- **Visualization**: Switched from WandB to [SwanLab](https://swanlab.cn/) (China-friendly)
- **Reasoning models**: Complete MiniMind-Reason series based on DeepSeek-R1 distillation
## 🎯 Project Contents

- Complete MiniMind-LLM architecture code (Dense + MoE models)
- Detailed Tokenizer training code
- Full training pipeline: Pretrain → SFT → LoRA → RLHF/RLAIF → Distillation
- High-quality, curated, and deduplicated datasets for all stages
- Native PyTorch implementation of key algorithms, with minimal third-party dependencies
- Multi-GPU training support (single-machine multi-card DDP, DeepSpeed, distributed clusters)
- Visualization with Wandb/SwanLab
- Model evaluation on third-party benchmarks (C-Eval, C-MMLU, OpenBookQA)
- YaRN algorithm for RoPE context length extrapolation
- OpenAI API protocol server for easy integration
- Streamlit web UI for chat
- Full compatibility with community tools: llama.cpp, vllm, ollama, Llama-Factory
- MiniMind-Reason models: complete open-source data + weights for reasoning distillation
## 🚀 Quick Navigation

- [Quick Start](quickstart.md) - Environment setup, model download, quick testing
- [Model Training](training.md) - Pretraining, SFT, LoRA, DPO training process
- **[Quick Start](quickstart.md)** - Environment setup, model download, quick testing
- **[Model Training](training.md)** - Pretraining, SFT, LoRA, RLHF, RLAIF, and reasoning training

## 🔗 Related Links
## 🔗 Links & Resources

**Project Repositories**:

- **GitHub**: [https://github.com/jingyaogong/minimind](https://github.com/jingyaogong/minimind)
- **HuggingFace**: [MiniMind Collection](https://huggingface.co/collections/jingyaogong/minimind-66caf8d999f5c7fa64f399e5)
- **ModelScope**: [MiniMind Models](https://www.modelscope.cn/profile/gongjy)
- **Online Demo**: [ModelScope Studio](https://www.modelscope.cn/studios/gongjy/MiniMind)
- **ModelScope**: [MiniMind Profile](https://www.modelscope.cn/profile/gongjy)

**Online Demos**:

- [ModelScope Studio - Standard Chat](https://www.modelscope.cn/studios/gongjy/MiniMind)
- [ModelScope Studio - Reasoning Model](https://www.modelscope.cn/studios/gongjy/MiniMind-Reasoning)
- [Bilibili Video Introduction](https://www.bilibili.com/video/BV12dHPeqE72/)

**Vision Extension**:

- [MiniMind-V](https://github.com/jingyaogong/minimind-v) - Multimodal vision-language models
## 💡 Why MiniMind?

The AI community is flooded with high-cost, complex frameworks that abstract away the fundamentals. MiniMind aims to democratize LLM learning by:

1. **Lowering the barrier**: No need for expensive GPUs or cloud services
2. **Understanding, not just using**: Learn every detail from tokenization to inference
3. **End-to-end learning**: Train from scratch, not just fine-tune existing models
4. **Code clarity**: Pure PyTorch implementations you can read and understand
5. **Practical results**: Get a working ChatBot with minimal resources

As we say: **"Building a Lego airplane is far more exciting than flying first class!"**

---

Next: [Get Started →](quickstart.md)

docs/quickstart.md
@@ -1,114 +1,279 @@
# Quick Start

This page will help you quickly get started with the MiniMind project.
Get MiniMind up and running in minutes!

## 📋 Requirements

### Hardware

- **GPU Memory**: 8GB minimum (24GB recommended for comfortable development)
- **Recommended GPU**: NVIDIA RTX 3090 (24GB)

### Software

- **Python**: 3.10+
- **PyTorch**: 1.12+
- **PyTorch**: 2.0+ (with CUDA 12.2+ for GPU support)
- **CUDA**: 12.2+ (optional, for GPU acceleration)
- **VRAM**: At least 8GB (24GB recommended)

!!! tip "Hardware Configuration Reference"
    - CPU: Intel i9-10980XE @ 3.00GHz
    - RAM: 128 GB
    - GPU: NVIDIA GeForce RTX 3090 (24GB)
    - **CPU**: Intel i9-10980XE @ 3.00GHz
    - **RAM**: 128 GB
    - **GPU**: NVIDIA GeForce RTX 3090 (24GB) × 8
    - **OS**: Ubuntu 20.04
    - **CUDA**: 12.2
    - **Python**: 3.10.16
## 🚀 Testing Existing Models

### 1. Clone the Project
## 🚀 Step 0: Clone the Repository

```bash
git clone https://github.com/jingyaogong/minimind.git
cd minimind
```

### 2. Install Dependencies
## 🎯 Section I: Testing Existing Models

### 1. Environment Setup

```bash
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
```

!!! warning "Torch CUDA Check"
!!! warning "Verify CUDA Support"
    After installation, verify that PyTorch can access CUDA:

    ```python
    import torch
    print(torch.cuda.is_available())
    print(torch.cuda.get_device_name(0))
    ```

    If it prints `False`, download the correct PyTorch build from the [PyTorch download index](https://download.pytorch.org/whl/torch_stable.html)
### 3. Download Model
### 2. Download Pretrained Models

Download pretrained models from HuggingFace or ModelScope. Choose one option:

**From HuggingFace** (recommended for international users):
```bash
git clone https://huggingface.co/jingyaogong/MiniMind2
```

**From ModelScope** (recommended for users in China):
```bash
git clone https://www.modelscope.cn/models/gongjy/MiniMind2.git
```

### 4. Command Line Q&A
### 3. Command-Line Chat

```bash
# load=0: load native PyTorch weights, load=1: load transformers-format model
python eval_model.py --load 1 --model_mode 2
```
### 5. Start WebUI (Optional)
**Model Modes**:

- `model_mode 0`: Pretrain model (word continuation)
- `model_mode 1`: SFT chat model (conversation)
- `model_mode 2`: RLHF model (refined responses; currently close to SFT for small models)
- `model_mode 3`: Reasoning model (with thinking chains)
- `model_mode 4/5`: RLAIF models (PPO/GRPO trained)

**Example Session**:

```text
👶: Hello, please introduce yourself.
🤖️: I am MiniMind, an AI assistant developed by Jingyao Gong.
    I use natural language processing and machine learning algorithms to interact with users.

👶: What's the capital of France?
🤖️: The capital of France is Paris, which is located in the north-central part of France.
    It is the largest city in France and serves as its political, economic, and cultural center.
```

### 4. Web UI Demo (Optional)

```bash
# Requires Python >= 3.10
pip install streamlit

cd scripts
streamlit run web_demo.py
```

Visit `http://localhost:8501` to use the web interface.
Visit `http://localhost:8501` to use the interactive web interface.
## 🔧 Third-party Inference Frameworks
### 5. RoPE Length Extrapolation with YaRN

MiniMind supports multiple mainstream inference frameworks:
Extend the context length beyond training with RoPE extrapolation:

### Ollama
```bash
python eval_model.py --inference_rope_scaling True
```

This enables the YaRN algorithm to handle sequences longer than the 2K training context, which is useful for processing documents and long conversations.
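The intuition behind this kind of RoPE scaling can be sketched independently of the project code: YaRN-style "NTK-by-parts" interpolation stretches RoPE's low-frequency channels (which encode long-range position) while leaving high-frequency channels intact. The sketch below is illustrative only; the ramp bounds `alpha`/`beta` and parameter names are assumptions, not MiniMind's actual implementation:

```python
import math

def rope_inv_freq(dim, base=10000.0):
    # Standard RoPE inverse frequencies, one per channel pair.
    return [base ** (-2.0 * i / dim) for i in range(dim // 2)]

def yarn_scaled_inv_freq(dim, scale=4.0, base=10000.0,
                         alpha=1.0, beta=32.0, orig_ctx=2048):
    # High-frequency channels keep their original rotation speed,
    # low-frequency channels are slowed ("interpolated") by `scale`,
    # with a linear ramp between the two regimes.
    scaled = []
    for f in rope_inv_freq(dim, base):
        wavelength = 2.0 * math.pi / f
        ratio = orig_ctx / wavelength  # rotations completed inside the old context
        gamma = min(1.0, max(0.0, (ratio - alpha) / (beta - alpha)))
        scaled.append(gamma * f + (1.0 - gamma) * f / scale)
    return scaled

# scale=4.0 would stretch a 2K-trained model toward an 8K context
freqs = yarn_scaled_inv_freq(dim=64, scale=4.0)
```

Channels that complete many rotations within the training context are left alone, so local token relationships are preserved; only the slow channels are stretched to cover the longer range.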
## 🔧 Third-Party Inference Frameworks

MiniMind is compatible with popular inference engines:

### Ollama (Easiest)

```bash
ollama run jingyaogong/minimind2
```

### vLLM
### vLLM (Fastest)

```bash
vllm serve ./MiniMind2/ --served-model-name "minimind"
vllm serve ./MiniMind2/ --served-model-name "minimind" --port 8000

# Test with curl
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "minimind",
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.7,
    "max_tokens": 512
  }'
```
### llama.cpp
### llama.cpp (CPU-Friendly)

```bash
# Convert model to GGUF format
python convert_hf_to_gguf.py ./MiniMind2/
python scripts/convert_model.py ./MiniMind2/ --output ./MiniMind2.gguf

# Quantize for size reduction
./build/bin/llama-quantize ./MiniMind2/MiniMind2-109M-F16.gguf ./Q4-MiniMind2.gguf Q4_K_M
./llama-quantize ./MiniMind2.gguf ./MiniMind2-Q4.gguf Q4_K_M

# Run inference
./build/bin/llama-cli -m ./Q4-MiniMind2.gguf --chat-template chatml
./llama-cli -m ./MiniMind2-Q4.gguf -p "Hello" -n 128
```
## 📝 Effect Testing
## 🔌 OpenAI API Server (For Integration)

Run MiniMind as an OpenAI API-compatible service:

```bash
python scripts/serve_openai_api.py
```

Test the API:

```bash
# In another terminal
python scripts/chat_openai_api.py
```

**cURL Example**:

```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "minimind",
    "messages": [
      {"role": "user", "content": "Explain machine learning in one sentence."}
    ],
    "temperature": 0.7,
    "max_tokens": 256,
    "stream": true
  }'
```

This enables integration with:

- [FastGPT](https://fastgpt.run/)
- [Open-WebUI](https://github.com/open-webui/open-webui)
- [Dify](https://dify.ai/)
- Any OpenAI API-compatible client
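The same request can be built from any language. A minimal Python sketch of the request body (mirroring the curl example above; the endpoint and port are taken from that example, and actually sending it, e.g. with `urllib.request` or an OpenAI-compatible client library, is left out):

```python
import json

def build_chat_request(prompt, model="minimind", temperature=0.7,
                       max_tokens=256, stream=False):
    # Same JSON body as the curl example; POST it to
    # http://localhost:8000/v1/chat/completions with
    # Content-Type: application/json.
    return {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stream": stream,
    }

payload = build_chat_request("Explain machine learning in one sentence.")
body = json.dumps(payload)  # this string is the POST body
```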
## 📊 Model Selection Guide

| Use Case | Recommended Model | Memory | Speed |
|----------|-------------------|--------|-------|
| Learning/Testing | MiniMind2-small (26M) | ~0.5 GB | Fastest |
| Balanced | MiniMind2 (104M) | ~1.0 GB | Fast |
| Expert System (MoE) | MiniMind2-MoE (145M) | ~1.0 GB | Dynamic |
| Reasoning/Complex | MiniMind-Reason (104M) | ~1.0 GB | Standard |

## ⚡ Quick Test Results

**Model**: MiniMind2 (104M parameters)

```text
👶: Hello, please introduce yourself.
🤖️: Hello! I'm MiniMind, an AI assistant developed by Jingyao Gong.
    I interact with users through natural language processing and algorithm training.

Q: What is photosynthesis?
A: Photosynthesis is a process in which plants convert light energy from the sun
   into chemical energy to produce glucose. This process occurs mainly in leaves
   and is essential for plant growth and survival.

👶: What is the highest mountain in the world?
🤖️: Mount Everest is the highest mountain in the world, located in the Himalayas,
    with an elevation of 8,848.86 meters (29,031.7 feet).

Q: Write a Python function to calculate Fibonacci numbers.
A: def fibonacci(n):
       if n <= 1:
           return n
       return fibonacci(n-1) + fibonacci(n-2)

   # For better performance, use dynamic programming:
   def fibonacci_dp(n):
       if n <= 1:
           return n
       dp = [0] * (n + 1)
       dp[1] = 1
       for i in range(2, n + 1):
           dp[i] = dp[i-1] + dp[i-2]
       return dp[n]

Q: 世界上最高的山峰是什么? (What is the highest mountain in the world?)
A: 珠穆朗玛峰(Mount Everest)是世界上最高的山峰,位于喜马拉雅山脉...
   (Mount Everest is the world's highest mountain, located in the Himalayas...)
```
## 🎯 Next Steps
## 🆘 Troubleshooting

- Check [Model Training](training.md) to learn how to train your own model from scratch
- Read the source code to understand LLM implementation principles

### Issue: CUDA Out of Memory

**Solution**:
```bash
# Reduce batch size
python eval_model.py --batch_size 1

# Or use CPU (slow but works)
python eval_model.py --device cpu
```

### Issue: Slow Inference

**Solutions**:

- Use vLLM or llama.cpp for faster inference
- Enable quantization (4-bit, 8-bit)
- Use a GPU instead of the CPU
- Reduce the `max_tokens` parameter

### Issue: Model Responses Are Poor Quality

**Possible Causes**:

- Using the pretrain model (`model_mode 0`) instead of SFT (`model_mode 1`)
- Model is undertrained - download the full checkpoint instead
- Input prompt is too short - provide more context

### Issue: Python/PyTorch Version Mismatch

**Solution**:
```bash
# Use conda for a clean environment
conda create -n minimind python=3.10
conda activate minimind
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu122
pip install -r requirements.txt
```
## 📖 Next Steps

- **[Model Training Guide](training.md)** - Train your own MiniMind from scratch
- **[Source Code](https://github.com/jingyaogong/minimind)** - Explore and learn the LLM implementation
- **[Inference Benchmarks](https://huggingface.co/collections/jingyaogong/minimind-66caf8d999f5c7fa64f399e5)** - See model performance comparisons

## 💡 Pro Tips

1. **GPU Memory Optimization**: Call `torch.cuda.empty_cache()` periodically
2. **Batch Processing**: For efficiency, process multiple prompts in batches
3. **Temperature Tuning**: Lower (0.3-0.7) = more consistent; higher (0.8-1.0) = more creative
4. **Prompt Engineering**: Better prompts → better results, even for small models
5. **Model Quantization**: Use 4-bit quantization to run on smaller GPUs
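The temperature knob in tip 3 is just a rescaling of the logits before softmax sampling. A minimal, self-contained sketch (plain Python, not the project's sampling code):

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    # Lower temperature sharpens the distribution (more consistent),
    # higher temperature flattens it (more creative).
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.5]
cool = softmax_with_temperature(logits, 0.5)  # peaked: top token dominates
hot = softmax_with_temperature(logits, 1.5)   # flatter: more diverse samples
```

At `temperature → 0` this approaches greedy decoding; at high temperature every token becomes nearly equally likely.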
---

Done! You're now ready to use MiniMind. Start with the steps above, then move on to [Model Training](training.md) to learn how to train your own models.

689 docs/training.md

@@ -1,155 +1,419 @@
# Model Training
# Model Training Guide

This page introduces how to train MiniMind language models from scratch.
Learn how to train MiniMind language models from scratch using pure PyTorch.

## 📊 Data Preparation
## 📊 Training Overview

### 1. Download Dataset
MiniMind implements a complete training pipeline:

Download datasets from [ModelScope](https://www.modelscope.cn/datasets/gongjy/minimind_dataset/files) or [HuggingFace](https://huggingface.co/datasets/jingyaogong/minimind_dataset).

Create a `./dataset` directory and place the data files in it:

```bash
./dataset/
├── pretrain_hq.jsonl (1.6GB, ✨Recommended)
├── sft_mini_512.jsonl (1.2GB, ✨Recommended)
├── sft_512.jsonl (7.5GB)
├── sft_1024.jsonl (5.6GB)
├── sft_2048.jsonl (9GB)
├── dpo.jsonl (909MB)
├── r1_mix_1024.jsonl (340MB)
└── lora_*.jsonl
```

```text
Tokenizer Training
        ↓
Pretraining (Learn knowledge)
        ↓
SFT (Learn conversation)
        ↓
┌──────────────┬─────────────┬──────────────────────┬──────────────┐
↓              ↓             ↓                      ↓
LoRA           DPO/RLHF      RLAIF (PPO/GRPO/SPO)   Distillation
(Domain adapt) (Preference)  (Reinforcement Learn)  (Reasoning)
```
!!! tip "Recommended Combination"
    Fastest reproduction: `pretrain_hq.jsonl` + `sft_mini_512.jsonl`

## 💰 Training Costs (Single NVIDIA 3090)

| Model | Dataset | Duration | Cost (RMB) | Quality |
|-------|---------|----------|------------|---------|
| MiniMind2-Small | pretrain_hq + sft_mini_512 | 2.1h | ≈3 | 😊😊 |
| MiniMind2-Small | Full dataset | 38h | ≈50 | 😊😊😊😊😊😊 |
| MiniMind2 | pretrain_hq + sft_mini_512 | 3.3h | ≈5 | 😊😊 |
| MiniMind2 | Full dataset | 122h | ≈160 | 😊😊😊😊😊😊😊 |

!!! success "Ultra-Fast Training"
    **Just 2.1 hours + $3 = a functional ChatBot!**

    **A single 3090 only needs 2 hours + $0.5!**
    Use `pretrain_hq.jsonl` + `sft_mini_512.jsonl` for the fastest reproduction
### 2. Data Format
## 📋 Data Preparation

**Pretrain Data** (`pretrain_hq.jsonl`):
### 1. Download Datasets

Download from [ModelScope](https://www.modelscope.cn/datasets/gongjy/minimind_dataset) or [HuggingFace](https://huggingface.co/datasets/jingyaogong/minimind_dataset):

```bash
mkdir -p dataset
cd dataset
# Download the required files here
```

### 2. Dataset Directory Structure

```
./dataset/
├── pretrain_hq.jsonl ✨ (1.6GB, required for pretraining)
├── sft_mini_512.jsonl ✨ (1.2GB, fastest SFT)
├── sft_512.jsonl (7.5GB, standard SFT)
├── sft_1024.jsonl (5.6GB, longer SFT)
├── sft_2048.jsonl (9GB, very long SFT)
├── dpo.jsonl (909MB, DPO training)
├── r1_mix_1024.jsonl (340MB, reasoning distillation)
├── rlaif-mini.jsonl (1MB, RLAIF algorithms)
├── lora_identity.jsonl (22.8KB, identity LoRA)
└── lora_medical.jsonl (34MB, medical domain LoRA)
```

### 3. Data Formats

**Pretraining Data** (`pretrain_hq.jsonl`):
```json
{"text": "How to overcome procrastination? Overcoming procrastination is not easy..."}
{"text": "How to overcome procrastination? Overcoming procrastination is not easy, but these suggestions may help..."}
```
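Since each line is one standalone JSON object, the file can be streamed without loading it into memory. A minimal sketch of such a reader (a hypothetical helper for illustration, not the project's dataset class):

```python
import io
import json

def iter_pretrain_texts(path_or_file):
    # Stream the `text` field from a pretrain-style .jsonl file,
    # one JSON object per line.
    f = (path_or_file if hasattr(path_or_file, "read")
         else open(path_or_file, encoding="utf-8"))
    with f:
        for line in f:
            line = line.strip()
            if line:
                yield json.loads(line)["text"]

# Works the same on a real file path or an in-memory sample:
sample = io.StringIO('{"text": "How to overcome procrastination? ..."}\n')
texts = list(iter_pretrain_texts(sample))
```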
**SFT Data** (`sft_*.jsonl`):
```json
{
    "conversations": [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hello!"}
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hello! How can I help?"},
        {"role": "user", "content": "Tell me a joke."},
        {"role": "assistant", "content": "Why did the scarecrow win an award? Because he was outstanding in his field!"}
    ]
}
```

## 🎯 Training Pipeline
**DPO Data** (`dpo.jsonl`):
```json
{
    "chosen": [
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": "2+2 equals 4."}
    ],
    "rejected": [
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": "2+2 equals 5."}
    ]
}
```
All training scripts are located in the `./trainer` directory.
**LoRA Domain Data** (`lora_*.jsonl`):
```json
{
    "conversations": [
        {"role": "user", "content": "What's the treatment for cervical spondylosis?"},
        {"role": "assistant", "content": "Cervical spondylosis treatment typically includes..."}
    ]
}
```

### 1. Pretraining
## 🎯 Complete Training Pipeline

The pretraining stage lets the model learn basic knowledge; the goal is to **learn word continuation**.
All training scripts are in the `./trainer` directory.

```bash
cd trainer
python train_pretrain.py

# Multi-GPU training
torchrun --nproc_per_node 2 train_pretrain.py
```

Output weights: `./out/pretrain_*.pth`
### Stage 1: Pretraining

!!! info "Training Duration"
    - MiniMind2-Small (26M): ~1.1h (single 3090)
    - MiniMind2 (104M): ~3.9h (single 3090)

### 2. Supervised Fine-Tuning (SFT)

The SFT stage teaches the model conversation patterns and adapts it to chat templates.
**Purpose**: Learn foundational knowledge (word continuation)

```bash
# Single GPU
python train_pretrain.py

# Multi-GPU (DDP)
torchrun --nproc_per_node 2 train_pretrain.py

# Multi-GPU (DeepSpeed)
deepspeed --master_port 29500 --num_gpus=2 train_pretrain.py
```

**Key Parameters**:

- `max_seq_len`: 512 (adjust based on GPU memory)
- `learning_rate`: 1e-4
- `epochs`: Adjust based on dataset size

**Output**: `./out/pretrain_*.pth`

**Training Duration**:

- MiniMind2-Small (26M): ~1.1h
- MiniMind2 (104M): ~3.9h

!!! tip "Pretraining Tips"
    - Start with `pretrain_hq.jsonl` for best results
    - Quality > quantity for pretraining data
    - Monitor the loss curve to detect overfitting
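The word-continuation objective monitored by that loss curve is ordinary next-token cross-entropy. A toy sketch of the quantity being minimized (plain Python with explicit probabilities, not the actual PyTorch trainer):

```python
import math

def next_token_loss(probs_per_step, target_ids):
    # Average cross-entropy of the probability the model assigned to each
    # ground-truth next token: L = -(1/T) * sum_t log p_t(target_t).
    losses = [-math.log(p[t]) for p, t in zip(probs_per_step, target_ids)]
    return sum(losses) / len(losses)

# Two steps over a 3-token toy vocabulary; the targets are ids 2 and 0.
probs = [[0.1, 0.2, 0.7],
         [0.8, 0.1, 0.1]]
loss = next_token_loss(probs, [2, 0])  # low because targets got high probability
```

In the real trainer the probabilities come from a softmax over the logits at every sequence position, but the scalar being plotted is exactly this average.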
### Stage 2: Supervised Fine-Tuning (SFT)

**Purpose**: Teach conversation patterns and chat templates

```bash
# Single GPU
python train_full_sft.py

# Multi-GPU
torchrun --nproc_per_node 2 train_full_sft.py
```

Output weights: `./out/full_sft_*.pth`
**Configuration**:

- Load the pretrained model from Stage 1
- Use an SFT dataset (`sft_mini_512.jsonl` or `sft_512.jsonl`)
- Adjust `max_seq_len` to match the training data

!!! info "Training Duration"
    - MiniMind2-Small: ~1h (using sft_mini_512)
    - MiniMind2: ~3.3h (using sft_mini_512)

**Output**: `./out/full_sft_*.pth`

### 3. LoRA Fine-tuning (Optional)

**Training Duration**:

- With sft_mini_512: 1-3 hours
- With the full sft_512: 20-25 hours

LoRA is a parameter-efficient fine-tuning method, suitable for domain adaptation.

!!! warning "SFT Data Selection"
    - `sft_mini_512.jsonl`: Fastest, ~1.2GB, 512 tokens max
    - `sft_512.jsonl`: Standard, ~7.5GB, 512 tokens max
    - `sft_1024.jsonl`: Longer, ~5.6GB, 1024 tokens max
    - `sft_2048.jsonl`: Extended, ~9GB, 2048 tokens max
### Stage 3: LoRA Fine-Tuning (Optional)

**Purpose**: Parameter-efficient domain adaptation

**Use Cases**:

- Medical Q&A knowledge
- Personal identity/self-awareness
- Proprietary domain knowledge

```bash
# Edit train_lora.py to set the correct dataset and base model
python train_lora.py

# Multi-GPU
torchrun --nproc_per_node 2 train_lora.py
```

**Output**: `./out/lora/lora_*.pth`
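What makes this parameter-efficient: LoRA freezes the base weight `W` and trains only a low-rank update `B @ A`, so the small `lora_*.pth` files above store just the factors. A dependency-free toy sketch of the forward pass (illustrative; the real implementation is a PyTorch module in the project code):

```python
def matmul(A, B):
    # Plain-Python matrix multiply for the toy example below.
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]

def lora_forward(x, W, A, B, alpha=16, r=2):
    # y = x @ (W + (alpha / r) * B @ A)^T, with W frozen (out x in) and only
    # the low-rank factors A (r x in) and B (out x r) trained.
    scale = alpha / r
    delta = matmul(B, A)  # out x in update, rank <= r
    W_eff = [[w + scale * d for w, d in zip(w_row, d_row)]
             for w_row, d_row in zip(W, delta)]
    return matmul(x, [list(col) for col in zip(*W_eff)])  # x @ W_eff^T

x = [[1.0, 2.0]]               # one toy input row
W = [[1.0, 0.0], [0.0, 1.0]]   # frozen base weight
A = [[1.0, 0.0]]               # trainable, r=1
B = [[1.0], [0.0]]             # trainable
y = lora_forward(x, W, A, B, alpha=16, r=1)
```

With rank `r` much smaller than the hidden dimension, the trainable parameter count drops from `out * in` to `r * (out + in)`.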
**Example 1: Medical Domain**

Prepare `dataset/lora_medical.jsonl`:
```json
{
    "conversations": [
        {"role": "user", "content": "What's the correct pillow height for cervical spondylosis?"},
        {"role": "assistant", "content": "For cervical spondylosis, pillow height should be..."}
    ]
}
```

Train:
```bash
# Modify train_lora.py: lora_name = 'medical'
python train_lora.py
```

**Use Cases**:

- Medical Q&A: use `lora_medical.jsonl`
- Self-awareness: use `lora_identity.jsonl`

**Example 2: Identity/Self-Awareness**

Output weights: `./out/lora/lora_*.pth`
Prepare `dataset/lora_identity.jsonl`:
```json
{
    "conversations": [
        {"role": "user", "content": "Who are you?"},
        {"role": "assistant", "content": "I am MiniMind..."}
    ]
}
```
### 4. DPO Reinforcement Learning (Optional)
### Stage 4: Direct Preference Optimization (DPO)

DPO is used to optimize model response quality so that it better aligns with human preferences.
**Purpose**: Align model responses with human preferences

DPO eliminates the need for a separate reward model by directly optimizing on preference pairs.
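Concretely, the standard DPO loss compares how much the policy prefers the chosen response over the rejected one, relative to a frozen reference model. A minimal scalar sketch (illustrative; the trainer applies this over batched sequence log-probabilities):

```python
import math

def dpo_loss(logp_chosen, logp_rejected, ref_chosen, ref_rejected, beta=0.1):
    # DPO: -log sigmoid(beta * ((logp_w - ref_w) - (logp_l - ref_l)))
    # Pushes the policy to prefer `chosen` over `rejected` relative to the
    # reference model, with no reward model in the loop.
    margin = beta * ((logp_chosen - ref_chosen) - (logp_rejected - ref_rejected))
    return -math.log(1.0 / (1.0 + math.exp(-margin)))

# Before any learning (policy == reference) the loss is exactly log 2.
baseline = dpo_loss(0.0, 0.0, 0.0, 0.0)
```

The `beta` coefficient plays the role of the KL-penalty strength in RLHF: larger values keep the policy closer to the reference model.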
|
||||
|
||||
```bash
|
||||
python train_dpo.py
|
||||
|
||||
# Multi-GPU
|
||||
torchrun --nproc_per_node 2 train_dpo.py
|
||||
```
|
||||
|
||||
Output weights: `./out/rlhf_*.pth`
|
||||
**Output**: `./out/rlhf_*.pth`
|
||||
|
||||
### 5. Reasoning Model Distillation (Optional)
|
||||
**Key Features**:
|
||||
- Off-policy training (reuse data across epochs)
|
||||
- No separate reward model needed
|
||||
- Better sample efficiency than PPO
|
||||
- Stable training convergence
|
||||
|
||||
Distill reasoning capabilities from DeepSeek-R1.
|
||||
**Training Duration**: ~1-3 hours
|
||||
|
||||
### Stage 5: Reinforcement Learning from AI Feedback (RLAIF)
RLAIF is an advanced training approach that uses AI-generated rewards instead of human annotations. MiniMind implements three modern algorithms:

#### 5.1 PPO (Proximal Policy Optimization)

A classical on-policy RL algorithm with proven stability.
```bash
python train_ppo.py

# Multi-GPU
torchrun --nproc_per_node 2 train_ppo.py
```

**Algorithm**:

$$\mathcal{L}_{PPO} = -\mathbb{E}\left[\min(r_t \cdot A_t, \text{clip}(r_t, 1-\varepsilon, 1+\varepsilon) \cdot A_t)\right] + \beta \cdot \mathbb{E}[\text{KL}]$$

**Characteristics**:

- Stable but slower reward improvement
- Requires both Actor and Critic networks
- High memory usage (1.5-2× a single network)
- Good for exploration
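The clipped surrogate term above can be sketched per token in plain Python (helper name is ours; the KL penalty term is omitted here):

```python
import math

def ppo_token_loss(logp_new, logp_old, advantage, eps=0.2):
    """Clipped PPO surrogate for a single token; eps is the clip range."""
    ratio = math.exp(logp_new - logp_old)           # r_t = pi_new / pi_old
    unclipped = ratio * advantage
    clipped = max(min(ratio, 1.0 + eps), 1.0 - eps) * advantage
    return -min(unclipped, clipped)                 # maximize surrogate -> minimize loss
```

The clip keeps any single update from moving the policy ratio far outside `[1-eps, 1+eps]`, which is the source of PPO's stability.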
**Output**: `./out/ppo_actor_*.pth`

**Training Duration**: ~1-3 hours

#### 5.2 GRPO (Group Relative Policy Optimization)
A modern algorithm used in DeepSeek-R1, with faster convergence.

```bash
python train_grpo.py

# Multi-GPU
torchrun --nproc_per_node 2 train_grpo.py
```

**Algorithm**:

$$\mathcal{L}_{GRPO} = -\mathbb{E}\left[r_t \cdot A_t - \beta \cdot \text{KL}_t\right]$$

where the advantage is normalized within each group of sampled completions:

$$A_t = \frac{R - \mu_{\text{group}}}{\sigma_{\text{group}}}$$
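The group normalization above replaces PPO's learned critic: several completions are sampled per prompt, and each one's reward is standardized against the group. A minimal sketch (function name is illustrative):

```python
def group_advantages(rewards, eps=1e-8):
    """Standardize rewards within one prompt's group of sampled completions."""
    n = len(rewards)
    mean = sum(rewards) / n
    std = (sum((r - mean) ** 2 for r in rewards) / n) ** 0.5
    return [(r - mean) / (std + eps) for r in rewards]
```

Advantages within a group always sum to ~0, so above-average completions are reinforced and below-average ones are suppressed without any value network.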
**Characteristics**:

- Single-network design (memory efficient)
- Faster reward improvement
- Group normalization removes bias
- Better convergence stability

**Output**: `./out/grpo_*.pth`

**Training Duration**: ~1-3 hours

#### 5.3 SPO (Single-stream Policy Optimization)

The newest algorithm (2025), addressing GRPO's degenerate-group problem.
```bash
python train_spo.py

# Multi-GPU
torchrun --nproc_per_node 2 train_spo.py
```

**Algorithm**:

$$\mathcal{L}_{SPO} = -\mathbb{E}\left[\log \pi_\theta(a_t|s) \cdot A_t - \beta \cdot \text{KL}_t\right]$$

with an adaptive baseline $B_t^{\text{adaptive}}$ tracked per input rather than computed from a sampled group.
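A scalar sketch of the objective: a REINFORCE-style loss with a KL penalty, where the baseline is tracked per input. The exponential-moving-average tracker below is one plausible form of $B_t^{\text{adaptive}}$; the paper's estimator differs in detail, and all names here are illustrative:

```python
def spo_token_loss(logp, reward, baseline, kl, beta=0.1):
    """Policy-gradient loss with an adaptive baseline and KL penalty."""
    advantage = reward - baseline
    return -(logp * advantage - beta * kl)

def update_baseline(baseline, reward, alpha=0.1):
    """EMA value tracker: one input -> one training sample, no group needed."""
    return (1.0 - alpha) * baseline + alpha * reward
```

Because the baseline is per input, a hard prompt whose sampled completions all fail does not produce a degenerate all-zero advantage group as in GRPO.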
**Characteristics**:

- No group dependency (1 input → 1 training sample)
- Adaptive value tracking
- Better handling of difficult examples
- Experimental on small models

**Output**: `./out/spo_*.pth`

**Training Duration**: ~1-3 hours

#### RLAIF Dataset Preparation

All RLAIF algorithms use `rlaif-mini.jsonl` (1MB, 10k examples):
The format is the same as SFT, but the assistant content is the placeholder `"无"` ("none"):

```json
{
"conversations": [
{"role": "user", "content": "Explain photosynthesis briefly."},
{"role": "assistant", "content": "无"}
]
}
```

The model generates completions during training, which are scored by a **Reward Model** (e.g., InternLM2-1.8B-Reward).
**Reward Model Setup**:

```bash
# Download the reward model into the parent directory
cd ../
git clone https://huggingface.co/internlm/internlm2-1_8b-reward

# Directory structure should be:
# project/
# ├── minimind/
# └── internlm2-1_8b-reward/
```
#### RLAIF vs DPO Comparison

| Aspect | DPO | RLAIF (PPO/GRPO/SPO) |
|--------|-----|----------------------|
| Training Type | Off-policy | On-policy |
| Data Freshness | Static pairs | Dynamic (generated) |
| Reward Source | Implicit | Explicit model |
| Convergence | Fast | Slower |
| Memory Usage | Lower | Higher |
| Best For | Preference refinement | Capability improvement |

### Stage 6: Reasoning Model Distillation
**Purpose**: Distill DeepSeek-R1-style reasoning into MiniMind

```bash
python train_distill_reason.py

# Multi-GPU
torchrun --nproc_per_node 2 train_distill_reason.py
```
**Data Format** (`r1_mix_1024.jsonl`):

```json
{
"conversations": [
{
"role": "user",
"content": "Solve: 5 + 3 = ?"
},
{
"role": "assistant",
"content": "<think>\nI need to add 5 and 3.\n5 + 3 = 8\n</think>\n<answer>\n5 + 3 = 8\n</answer>"
}
]
}
```

**Output**: `./out/reason_*.pth`
**Training Features**:

- Enforces `<think>` and `<answer>` tags
- Penalty loss for format violations
- Mixed data (reasoning + multi-turn + English)
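The tag enforcement above amounts to checking that each completion follows the `<think>…</think><answer>…</answer>` template and penalizing outputs that do not. A minimal checker sketch (the regex and function name are illustrative, not the repo's implementation):

```python
import re

# Template used by r1_mix_1024.jsonl: think block, then answer block.
THINK_ANSWER = re.compile(
    r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$", re.S
)

def format_ok(text):
    """True if the completion matches the <think>/<answer> template."""
    return bool(THINK_ANSWER.match(text.strip()))
```

During training, completions failing this check would receive the format-violation penalty mentioned above.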
## 🔧 Multi-GPU Training
|
||||
|
||||
### DDP Method
|
||||
### DDP (Distributed Data Parallel)
|
||||
|
||||
Best for single-machine multi-GPU:
|
||||
|
||||
```bash
|
||||
torchrun --nproc_per_node N train_xxx.py
|
||||
# N = number of GPUs
|
||||
```
|
||||
|
||||
### DeepSpeed Method
|
||||
### DeepSpeed
|
||||
|
||||
For advanced optimization:
|
||||
|
||||
```bash
|
||||
deepspeed --master_port 29500 --num_gpus=N train_xxx.py
|
||||
@ -157,30 +421,259 @@ deepspeed --master_port 29500 --num_gpus=N train_xxx.py
### Wandb Monitoring

Track training progress:

```bash
# Login first
wandb login

# Enable wandb logging
torchrun --nproc_per_node N train_xxx.py --use_wandb

# Or SwanLab (China-friendly alternative)
python train_xxx.py --use_wandb  # Automatically uses SwanLab if available
```
## 💰 Training Cost

Based on a single NVIDIA 3090:

| Dataset Combination | Duration | Cost | Effect |
|---------------------|----------|------|--------|
| pretrain_hq + sft_mini_512 | 2.1h | ≈$0.35 | 😊😊 Basic chat |
| Full dataset (MiniMind2-Small) | 38h | ≈$6.50 | 😊😊😊😊😊😊 Complete capabilities |
| Full dataset (MiniMind2) | 122h | ≈$20.80 | 😊😊😊😊😊😊😊😊 Best performance |

!!! success "Quick Reproduction"
    Using `pretrain_hq` + `sft_mini_512`, a single 3090 needs only **2 hours + $0.5** to train a ChatBot!

## 🧪 Model Testing

`--model_mode` selects the checkpoint family: 0 = pretrain, 1 = SFT, 2 = RLHF, 3 = reason, 4 = RLAIF.

### Evaluate Pretrain Model

```bash
python eval_model.py --model_mode 0
```

### Evaluate Chat Model

```bash
python eval_model.py --model_mode 1
```

### Evaluate with LoRA

```bash
python eval_model.py --lora_name 'lora_medical' --model_mode 1
```

### Evaluate Reasoning Model

```bash
python eval_model.py --model_mode 3
```

### Evaluate RLAIF Models

```bash
# PPO model
python eval_model.py --model_mode 4

# GRPO model
python eval_model.py --model_mode 4
```
### RoPE Length Extrapolation

Test with an extended context:

```bash
python eval_model.py --model_mode 1 --inference_rope_scaling True
```
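RoPE scaling works by stretching the rotary position angles so positions beyond the training length fall back into the range the model saw during training. YaRN (used by MiniMind) rescales frequencies non-uniformly; the sketch below shows the simpler linear position interpolation to convey the idea (all names are illustrative):

```python
def rope_angle(pos, dim_pair, head_dim=64, base=10000.0, scale=1.0):
    """Rotation angle for one (position, frequency-pair) under linear scaling.

    scale > 1 compresses positions, so position 8 at scale=4 gets the same
    angle that position 2 had at scale=1.
    """
    inv_freq = base ** (-2.0 * dim_pair / head_dim)
    return (pos / scale) * inv_freq
```

This is why extrapolation works: a 4× scale maps a 2048-token-trained model's angles onto an 8192-token context.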
## 📐 Model Architecture

### MiniMind Structure

**Decoder-Only Transformer** (similar to Llama3):

![LLM-structure](images/LLM-structure.png)

```
Input Tokens
      ↓
Token Embedding (6400 vocab)
      ↓
Rotary Embeddings (RoPE) [with YaRN for length extrapolation]
      ↓
[Transformer Blocks] ×N
  ├─ Attention (Grouped-Query: 8 Q heads / 2 KV heads)
  ├─ RMSNorm
  ├─ SwiGLU FFN [or MoE for MoE variant]
  └─ Residual Connections
      ↓
RMSNorm
      ↓
LM Head (→ 6400 vocab logits)
      ↓
Output Probabilities
```
### Model Configurations

| Config | MiniMind2-Small | MiniMind2 | MiniMind2-MoE |
|--------|-----------------|-----------|---------------|
| Parameters | 26M | 104M | 145M |
| Hidden Dim | 512 | 768 | 640 |
| Layers | 8 | 16 | 8 |
| KV Heads | 2 | 2 | 2 |
| Q Heads | 8 | 8 | 8 |
| Vocab Size | 6,400 | 6,400 | 6,400 |
| Context Length | 2,048 | 2,048 | 2,048 |
### Modify Architecture

Edit `./model/LMConfig.py`:

```python
class LMConfig:
    hidden_size: int = 768
    num_layers: int = 16
    num_heads: int = 8
    num_kv_heads: int = 2
    # ... other configs
```
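To see roughly how such a configuration maps to the parameter counts in the table above, here is a back-of-the-envelope estimator. It assumes grouped-query attention, a SwiGLU FFN of width ≈ 8/3·d, and tied embeddings, and it ignores norms and biases, so it only approximates the real counts:

```python
def approx_params(d, n_layers, n_heads=8, n_kv_heads=2, vocab=6400):
    """Rough dense-transformer parameter count for a MiniMind-like config."""
    head_dim = d // n_heads
    # Wq, Wk, Wv, Wo with GQA (K and V projected to n_kv_heads * head_dim)
    attn = d * d + 2 * d * (n_kv_heads * head_dim) + d * d
    ffn_hidden = int(8 * d / 3)
    ffn = 3 * d * ffn_hidden          # gate, up, down projections (SwiGLU)
    return vocab * d + n_layers * (attn + ffn)   # embeddings tied with LM head
```

Plugging in (512, 8) lands near the 26M figure for MiniMind2-Small, and (768, 16) near the 104M figure for MiniMind2.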
## 🔍 Training Tips & Best Practices

### Data Quality > Quantity

- High-quality pretraining data accelerates convergence
- `pretrain_hq.jsonl` is carefully curated for quality
- Consider data deduplication and cleaning
### Learning Rate Scheduling

Recommended schedule:

- Linear warmup, then decay
- Initial LR: 1e-4 to 5e-4
- Warmup steps: 10% of total
- Final LR: 10% of the initial LR
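The recommendations above can be sketched as a single schedule function. Cosine is one common choice for the decay phase; MiniMind's exact scheduler may differ, and the function name and defaults here are illustrative:

```python
import math

def lr_at(step, total_steps, base_lr=5e-4, warmup_frac=0.1, final_frac=0.1):
    """Linear warmup to base_lr, then cosine decay to final_frac * base_lr."""
    warmup = max(1, int(total_steps * warmup_frac))
    if step < warmup:
        return base_lr * (step + 1) / warmup          # linear warmup
    progress = (step - warmup) / max(1, total_steps - warmup)
    floor = base_lr * final_frac                      # final LR = 10% of initial
    return floor + 0.5 * (base_lr - floor) * (1 + math.cos(math.pi * progress))
```

The LR ramps to its peak over the first 10% of steps, then decays smoothly toward 10% of the peak by the end of training.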
### Batch Size & Sequence Length

Balance GPU memory against convergence:

- Pretraining: `max_seq_len=512`, `batch_size=32`
- SFT: `max_seq_len=512`, `batch_size=16`
- LoRA: `max_seq_len=512`, `batch_size=16`
### Memory Optimization

```bash
# Reduce batch size if OOM
python train_xxx.py --batch_size 8

# Or use gradient accumulation
python train_xxx.py --gradient_accumulation_steps 4
```
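Gradient accumulation trades steps for memory: each forward pass holds fewer activations while the optimizer still sees the same effective batch. The relationship (flag names follow the commands above; the helper itself is illustrative):

```python
def effective_batch(batch_size, grad_accum_steps=1, num_gpus=1):
    """Samples contributing to each optimizer update."""
    return batch_size * grad_accum_steps * num_gpus

# e.g. --batch_size 8 --gradient_accumulation_steps 4 on one GPU
# gives the same effective batch as --batch_size 32 without accumulation.
```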
### Checkpoint Management

- Saves every 100 steps by default
- Each new save overwrites the previous one
- Automatic backup before training
## 🚨 Common Issues & Solutions

### Issue: CUDA Out of Memory

```bash
# Solution 1: Reduce batch size
python train_xxx.py --batch_size 4

# Solution 2: Use gradient accumulation
python train_xxx.py --batch_size 16 --gradient_accumulation_steps 2

# Solution 3: Use a smaller model
# Edit the trainer script to use MiniMind2-Small instead
```
### Issue: Training Not Converging

Possible causes:

1. Learning rate too high or too low
2. Data quality issues
3. Model capacity mismatch

Solutions:

- Reduce the learning rate: `--learning_rate 1e-5`
- Check data format and quality
- Try a smaller model first
### Issue: Multi-GPU Sync Errors

Ensure that:

1. All GPUs are visible: `nvidia-smi`
2. The CUDA version is the same across all GPUs
3. Network connectivity exists for distributed training

Debug:

```bash
torchrun --nproc_per_node 2 train_xxx.py --debug
```
### Issue: Different Results Than Expected

Check that:

1. The random seed is set (reproducibility)
2. The correct model checkpoint is loaded
3. The correct dataset is being used
4. Hyperparameters match the reference
## 📈 Training Progression

Typical training curves:

```
Pretraining Loss: ↘↘↘ (steep decline, then plateau)
SFT Loss:         ↘   (steady decline)
PPO Reward:       ↗   (rising, may plateau)
GRPO Reward:      ↗↗  (faster rise, more stable)
```
## 🎓 Advanced Topics

### Custom Datasets

To create your own dataset:

- Use JSONL with a `conversations` list
- Each line is one training example
- Ensure consistent quality and format
### Model Quantization (Post-training)

For 4-bit quantization at inference time, use tools such as:

- llama.cpp (GGUF format)
- bitsandbytes (dynamic quantization)
- AutoGPTQ (static quantization)
### Model Merging

Merge the base model with LoRA weights using tools such as `peft` or llama.cpp.
## 📚 References

- [Scaling Laws for Neural Language Models](https://arxiv.org/pdf/2001.08361.pdf)
- [RoPE Position Embeddings](https://arxiv.org/abs/2104.09864)
- [YaRN Length Extrapolation](https://arxiv.org/abs/2309.00071)
- [PPO Algorithm](https://arxiv.org/abs/1707.06347)
- [GRPO (DeepSeek)](https://arxiv.org/pdf/2402.03300)
- [SPO Algorithm](https://arxiv.org/abs/2509.13232)
- [DPO](https://arxiv.org/abs/2305.18290)
---

**Next**: Deploy your trained model or explore [advanced inference options](quickstart.md#third-party-inference-frameworks)