Compare commits

...

No commits in common. "v2" and "master" have entirely different histories.
v2 ... master

76 changed files with 40487 additions and 1170 deletions

6
.gitignore vendored Normal file

@ -0,0 +1,6 @@
__pycache__/
*.pyc
.DS_Store
out
website/
docs-minimind/


@ -1,19 +0,0 @@
# Read the Docs configuration file
version: 2
# Build configuration
build:
os: ubuntu-22.04
tools:
python: "3.11"
# MkDocs configuration
mkdocs:
configuration: mkdocs.yml
fail_on_warning: false
# Python dependencies
python:
install:
- requirements: requirements.txt

128
CODE_OF_CONDUCT.md Normal file

@ -0,0 +1,128 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

201
LICENSE Normal file

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

1994
README.md

File diff suppressed because it is too large

1992
README_en.md Normal file

File diff suppressed because it is too large

0
dataset/__init__.py Normal file

5
dataset/dataset.md Executable file

@ -0,0 +1,5 @@
# MiniMind Datasets
Place all downloaded dataset files in this directory.

256
dataset/lm_dataset.py Normal file

@ -0,0 +1,256 @@
from torch.utils.data import Dataset
import torch
import json
import os
import random
from datasets import load_dataset, Features, Sequence, Value

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def pre_processing_chat(conversations, add_system_ratio=0.2):
    # keep tool-use data fully intact; no preprocessing
    if any(conv.get('tools') for conv in conversations):
        return conversations
    SYSTEM_PROMPTS = [
        "你是一个知识丰富的AI尽力为用户提供准确的信息。",
        "你是minimind一个小巧但有用的语言模型。",
        "你是一个专业的AI助手请提供有价值的回答。",
        "你是minimind请尽力帮助用户解决问题。",
        "你是一个可靠的AI请给出准确的回答。",
        "You are a helpful AI assistant.",
        "You are minimind, a lightweight intelligent assistant.",
        "You are a friendly chatbot. Please answer the user's questions carefully.",
        "You are a knowledgeable AI. Try your best to provide accurate information.",
        "You are minimind, a small but useful language model."
    ]
    # probabilistically prepend a system prompt
    if conversations[0].get('role') != 'system':
        if random.random() < add_system_ratio:
            return [{'role': 'system', 'content': random.choice(SYSTEM_PROMPTS)}] + conversations
    return conversations


def post_processing_chat(prompt_content, empty_think_ratio=0.2):
    # remove empty think tags with probability 1 - empty_think_ratio (80% by default)
    if '<think>\n\n</think>\n\n' in prompt_content and random.random() > empty_think_ratio:
        prompt_content = prompt_content.replace('<think>\n\n</think>\n\n', '')
    return prompt_content


class PretrainDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = load_dataset('json', data_files=data_path, split='train')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]
        tokens = self.tokenizer(str(sample['text']), add_special_tokens=False, max_length=self.max_length - 2, truncation=True).input_ids
        tokens = [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
        input_ids = tokens + [self.tokenizer.pad_token_id] * (self.max_length - len(tokens))
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        labels = input_ids.clone()
        labels[input_ids == self.tokenizer.pad_token_id] = -100
        return input_ids, labels


class SFTDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        features = Features({'conversations': [{'role': Value('string'), 'content': Value('string'), 'reasoning_content': Value('string'), 'tools': Value('string'), 'tool_calls': Value('string')}]})
        self.samples = load_dataset('json', data_files=jsonl_path, split='train', features=features)
        self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant\n', add_special_tokens=False).input_ids
        self.eos_id = tokenizer(f'{tokenizer.eos_token}\n', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def create_chat_prompt(self, conversations):
        messages = []
        tools = None
        for message in conversations:
            message = dict(message)
            if message.get("role") == "system" and message.get("tools"):
                tools = json.loads(message["tools"]) if isinstance(message["tools"], str) else message["tools"]
            if message.get("tool_calls") and isinstance(message["tool_calls"], str):
                message["tool_calls"] = json.loads(message["tool_calls"])
            messages.append(message)
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
            tools=tools
        )

    def generate_labels(self, input_ids):
        labels = [-100] * len(input_ids)
        i = 0
        while i < len(input_ids):
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                start = i + len(self.bos_id)
                end = start
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                for j in range(start, min(end + len(self.eos_id), self.max_length)):
                    labels[j] = input_ids[j]
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                i += 1
        return labels

    def __getitem__(self, index):
        sample = self.samples[index]
        conversations = pre_processing_chat(sample['conversations'])
        prompt = self.create_chat_prompt(conversations)
        prompt = post_processing_chat(prompt)
        input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
        input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
        labels = self.generate_labels(input_ids)
        # # === debug printing ===
        # print(f"\n--- Sample {index} ---")
        # for i, (x, y) in enumerate(zip(input_ids[:-1], labels[1:])):
        #     print(f"{i:3d}: X={self.tokenizer.decode([x])!r:16s} ---> Y={self.tokenizer.decode([input_ids[i+1]])!r:16s} label={y}")
        # # ======================
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)


class DPODataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=4096):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant\n', add_special_tokens=False).input_ids
        self.eos_id = tokenizer(f'{tokenizer.eos_token}\n', add_special_tokens=False).input_ids
        self.samples = load_dataset('json', data_files=file_path, split='train')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]
        chosen = sample['chosen']  # a list of {role, content} messages
        rejected = sample['rejected']  # same structure as chosen
        chosen_prompt = self.tokenizer.apply_chat_template(
            chosen, tokenize=False, add_generation_prompt=False
        )
        chosen_prompt = post_processing_chat(chosen_prompt)
        rejected_prompt = self.tokenizer.apply_chat_template(
            rejected, tokenize=False, add_generation_prompt=False
        )
        rejected_prompt = post_processing_chat(rejected_prompt)
        chosen_encoding = self.tokenizer(
            chosen_prompt, truncation=True, max_length=self.max_length, padding='max_length'
        )
        rejected_encoding = self.tokenizer(
            rejected_prompt, truncation=True, max_length=self.max_length, padding='max_length'
        )
        chosen_input_ids = chosen_encoding['input_ids']
        chosen_loss_mask = self.generate_loss_mask(chosen_input_ids)
        rejected_input_ids = rejected_encoding['input_ids']
        rejected_loss_mask = self.generate_loss_mask(rejected_input_ids)
        x_chosen = torch.tensor(chosen_input_ids[:-1], dtype=torch.long)
        y_chosen = torch.tensor(chosen_input_ids[1:], dtype=torch.long)
        mask_chosen = torch.tensor(chosen_loss_mask[1:], dtype=torch.long)
        x_rejected = torch.tensor(rejected_input_ids[:-1], dtype=torch.long)
        y_rejected = torch.tensor(rejected_input_ids[1:], dtype=torch.long)
        mask_rejected = torch.tensor(rejected_loss_mask[1:], dtype=torch.long)
        return {
            'x_chosen': x_chosen,
            'y_chosen': y_chosen,
            'mask_chosen': mask_chosen,
            'x_rejected': x_rejected,
            'y_rejected': y_rejected,
            'mask_rejected': mask_rejected
        }

    def generate_loss_mask(self, input_ids):
        loss_mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                start = i + len(self.bos_id)
                end = start
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                for j in range(start, min(end + len(self.eos_id), self.max_length)):
                    loss_mask[j] = 1
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                i += 1
        return loss_mask


class RLAIFDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=1024, thinking_ratio=0.5):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.thinking_ratio = thinking_ratio  # enable thinking with this probability
        self.samples = load_dataset('json', data_files=jsonl_path, split='train')
        self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def create_chat_prompt(self, conversations):
        conversations = pre_processing_chat(conversations)
        use_thinking = random.random() < self.thinking_ratio
        return self.tokenizer.apply_chat_template(
            conversations[:-1],
            tokenize=False,
            open_thinking=use_thinking,
            add_generation_prompt=True
        )

    def __getitem__(self, index):
        sample = self.samples[index]
        prompt = self.create_chat_prompt(sample['conversations'])
        return {
            'prompt': prompt,
            'answer': ""
        }


class AgentRLDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = []
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                self.samples.append(json.loads(line.strip()))

    def __len__(self):
        return len(self.samples)

    def parse_conversations(self, conversations):
        messages = []
        tools = None
        for message in conversations:
            message = dict(message)
            if message.get("role") == "system" and message.get("tools"):
                tools = json.loads(message["tools"]) if isinstance(message["tools"], str) else message["tools"]
            messages.append(message)
        return messages[:-1], tools

    def __getitem__(self, index):
        sample = self.samples[index]
        messages, tools = self.parse_conversations(sample['conversations'])
        return {'messages': messages, 'tools': tools, 'gt': sample['gt']}


if __name__ == "__main__":
    pass
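The assistant-span masking shared by `generate_labels` and `generate_loss_mask` can be illustrated standalone. A minimal sketch with toy integer tokens — the `bos_id`/`eos_id` marker sequences below are placeholders for the tokenized `{bos_token}assistant\n` / `{eos_token}\n` strings, not the real tokenizer's ids:

```python
# Sketch of the assistant-span label masking used by SFTDataset.generate_labels.
# Only tokens inside <bos-marker> ... <eos-marker> spans are supervised; the
# rest (user turns, padding) stay at -100 and are ignored by the loss.
def generate_labels(input_ids, bos_id, eos_id, max_length):
    labels = [-100] * len(input_ids)
    i = 0
    while i < len(input_ids):
        if input_ids[i:i + len(bos_id)] == bos_id:
            start = i + len(bos_id)
            end = start
            while end < len(input_ids):
                if input_ids[end:end + len(eos_id)] == eos_id:
                    break
                end += 1
            # supervise the assistant reply plus its closing eos marker
            for j in range(start, min(end + len(eos_id), max_length)):
                labels[j] = input_ids[j]
            i = end + len(eos_id) if end < len(input_ids) else len(input_ids)
        else:
            i += 1
    return labels

# user tokens [7, 8] stay masked; assistant reply [5, 6] and its eos [2] are kept
ids = [7, 8, 1, 9, 5, 6, 2, 0]  # [1, 9] = bos marker; [2] = eos marker; 0 = pad
print(generate_labels(ids, bos_id=[1, 9], eos_id=[2], max_length=8))
# → [-100, -100, -100, -100, 5, 6, 2, -100]
```

The same scan powers `DPODataset.generate_loss_mask`, which writes `1` instead of the token id.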

22 binary image files changed (diffs not shown); previous sizes ranged from 73 KiB to 3.8 MiB.


@ -1,124 +0,0 @@
# Welcome to MiniMind!
<figure markdown>
![logo](images/logo.png)
<figcaption><strong>"Simplicity is the ultimate sophistication"</strong></figcaption>
</figure>
## 📌 Introduction
**MiniMind** is a complete, open-source project for training ultra-small language models from scratch with minimal cost. Train a **26M** ChatBot in just **2 hours** with only **$3** on a single 3090 GPU!
- **MiniMind** series is extremely lightweight, the smallest version is **1/7000** the size of GPT-3
- Complete implementation covering:
- **Tokenizer training** with custom vocabulary
- **Pretraining** (knowledge learning)
- **Supervised Fine-Tuning (SFT)** (conversation patterns)
- **LoRA fine-tuning** (parameter-efficient adaptation)
- **Direct Preference Optimization (DPO)** (human preference alignment)
- **RLAIF algorithms** (PPO/GRPO/SPO - reinforcement learning)
- **Knowledge distillation** (compress large model knowledge)
- **Model reasoning distillation** (DeepSeek-R1 style)
- **YaRN algorithm** (context length extrapolation)
- **Pure PyTorch implementation**: All core algorithms are implemented from scratch using native PyTorch, without relying on third-party abstract interfaces
- **Educational value**: This is not only a full-stage open-source reproduction of large language models, but also a comprehensive tutorial for getting started with LLMs
- **Extended capabilities**: MiniMind now supports [MiniMind-V](https://github.com/jingyaogong/minimind-v) for vision multimodal tasks
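The DPO stage listed above aligns the policy with preferences against a frozen reference model. The per-pair loss can be sketched in a few lines — a minimal illustration using scalar sequence log-probabilities, not MiniMind's actual training code (which sums token log-probs under the loss masks produced by `DPODataset`):

```python
import math

# Sketch of the DPO loss for one (chosen, rejected) pair.
# logp_* are sequence log-probs under the policy; ref_* under the frozen reference.
def dpo_loss(logp_chosen, logp_rejected, ref_chosen, ref_rejected, beta=0.1):
    # implicit reward margin: how much more the policy prefers chosen over
    # rejected, relative to the reference model
    margin = (logp_chosen - ref_chosen) - (logp_rejected - ref_rejected)
    return -math.log(1.0 / (1.0 + math.exp(-beta * margin)))  # -log(sigmoid(beta * margin))

# a zero margin gives the chance-level loss log(2) ≈ 0.6931
print(round(dpo_loss(-10.0, -12.0, -10.0, -12.0), 4))  # → 0.6931
```

Minimizing this pushes the policy's preference margin above the reference's, without ever training an explicit reward model.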
!!! note "Training Cost & Time"
"2 hours" is based on **NVIDIA 3090** hardware (single card) testing
"$3" refers to GPU server rental cost
With 8× RTX 4090 GPUs, training time can be compressed to **under 10 minutes**
## ✨ Key Highlights
- **Ultra-low cost**: Single 3090, 2 hours, $3 to train a fully functional ChatBot from scratch
- **Complete pipeline**: Tokenizer → Pretraining → SFT → LoRA → DPO/RLAIF → Distillation → Reasoning
- **Latest algorithms**: Implements cutting-edge techniques including GRPO, SPO, and YaRN
- **Education-friendly**: Clean, well-documented code suitable for learning LLM principles
- **Ecosystem compatible**: Seamless support for `transformers`, `trl`, `peft`, `llama.cpp`, `vllm`, `ollama`, and `Llama-Factory`
- **Full capabilities**: Supports multi-GPU training (DDP/DeepSpeed), model visualization (Wandb/SwanLab), and dynamic checkpoint management
- **Production-ready**: OpenAI API protocol support for easy integration with third-party UIs (FastGPT, Open-WebUI, etc.)
- **Multimodal extension**: Extended to vision with [MiniMind-V](https://github.com/jingyaogong/minimind-v)
## 📊 Model Series
### MiniMind2 Series (Latest - 2025.04.26)
| Model | Parameters | Vocabulary | Layers | Hidden Dim | Context | Inference Memory |
|-------|-----------|------------|--------|-----------|---------|-----------------|
| MiniMind2-small | 26M | 6,400 | 8 | 512 | 2K | ~0.5 GB |
| MiniMind2-MoE | 145M | 6,400 | 8 | 640 | 2K | ~1.0 GB |
| MiniMind2 | 104M | 6,400 | 16 | 768 | 2K | ~1.0 GB |
### MiniMind-V1 Series (Legacy - 2024.09.01)
| Model | Parameters | Vocabulary | Layers | Hidden Dim | Context |
|-------|-----------|------------|--------|-----------|---------|
| minimind-v1-small | 26M | 6,400 | 8 | 512 | 2K |
| minimind-v1-moe | 104M | 6,400 | 8 | 512 | 2K |
| minimind-v1 | 108M | 6,400 | 16 | 768 | 2K |
## 📅 Latest Updates (2025-10-24)
- 🔥 **RLAIF Training Algorithms**: Native implementation of PPO, GRPO, and SPO
- **YaRN Algorithm**: RoPE length extrapolation for improved long-sequence handling
- **Adaptive Thinking**: Reasoning models support optional thinking chains
- **Full template support**: Tool calling and reasoning tags (`<tool_call>`, `<think>`, etc.)
- **Visualization**: Switched from WandB to [SwanLab](https://swanlab.cn/) (China-friendly)
- **Reasoning models**: Complete MiniMind-Reason series based on DeepSeek-R1 distillation
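The YaRN idea mentioned above blends two RoPE regimes per dimension: high-frequency dimensions keep their original frequency (extrapolation), low-frequency dimensions are position-interpolated by the context-scale factor, and a ramp bridges the two. The sketch below is a simplified illustration, not MiniMind's exact code — the ramp is expressed in rotations-per-context rather than YaRN's dimension-index form, and the attention-temperature rescaling is omitted:

```python
import math

# Hedged sketch of YaRN-style scaling of RoPE inverse frequencies.
# orig_ctx: original training context; scale: context extension factor.
def yarn_inv_freqs(dim, base=10000.0, orig_ctx=2048, scale=4.0,
                   beta_fast=32.0, beta_slow=1.0):
    out = []
    for i in range(dim // 2):
        inv = base ** (-2.0 * i / dim)
        # full rotations this dimension completes over the original context
        rotations = orig_ctx * inv / (2 * math.pi)
        if rotations >= beta_fast:      # high-frequency: extrapolate unchanged
            t = 1.0
        elif rotations <= beta_slow:    # low-frequency: position-interpolate
            t = 0.0
        else:                           # linear ramp between the two regimes
            t = (rotations - beta_slow) / (beta_fast - beta_slow)
        out.append(t * inv + (1 - t) * inv / scale)
    return out

freqs = yarn_inv_freqs(dim=64)
# dim 0 spins ~326 times per context, so it is left untouched;
# the slowest dimension is divided by the full scale factor of 4
```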
## 🎯 Project Contents
- Complete MiniMind-LLM architecture code (Dense + MoE models)
- Detailed Tokenizer training code
- Full training pipeline: Pretrain → SFT → LoRA → RLHF/RLAIF → Distillation
- High-quality, curated and deduplicated datasets at all stages
- Native PyTorch implementation of key algorithms, minimal third-party dependencies
- Multi-GPU training support (single-machine multi-card DDP, DeepSpeed, distributed clusters)
- Visualization with Wandb/SwanLab
- Model evaluation on third-party benchmarks (C-Eval, C-MMLU, OpenBookQA)
- YaRN algorithm for RoPE context length extrapolation
- OpenAI API protocol server for easy integration
- Streamlit web UI for chat
- Full compatibility with community tools: llama.cpp, vllm, ollama, Llama-Factory
- MiniMind-Reason models: Complete open-source data + weights for reasoning distillation
## 🚀 Quick Navigation
- **[Quick Start](quickstart.md)** - Environment setup, model download, quick testing
- **[Model Training](training.md)** - Pretraining, SFT, LoRA, RLHF, RLAIF, and reasoning training
## 🔗 Links & Resources
**Project Repositories**:
- **GitHub**: [https://github.com/jingyaogong/minimind](https://github.com/jingyaogong/minimind)
- **HuggingFace**: [MiniMind Collection](https://huggingface.co/collections/jingyaogong/minimind-66caf8d999f5c7fa64f399e5)
- **ModelScope**: [MiniMind Profile](https://www.modelscope.cn/profile/gongjy)
**Online Demos**:
- [ModelScope Studio - Standard Chat](https://www.modelscope.cn/studios/gongjy/MiniMind)
- [ModelScope Studio - Reasoning Model](https://www.modelscope.cn/studios/gongjy/MiniMind-Reasoning)
- [Bilibili Video Introduction](https://www.bilibili.com/video/BV12dHPeqE72/)
**Vision Extension**:
- [MiniMind-V](https://github.com/jingyaogong/minimind-v) - Multimodal vision language models
## 💡 Why MiniMind?
The AI community is flooded with high-cost, complex frameworks that abstract away the fundamentals. MiniMind aims to democratize LLM learning by:
1. **Lowering the barrier**: No need for expensive GPUs or cloud services
2. **Understanding, not just using**: Learn every detail from tokenization to inference
3. **End-to-end learning**: Train from scratch, not just fine-tune existing models
4. **Code clarity**: Pure PyTorch implementations you can read and understand
5. **Practical results**: Get a working ChatBot with minimal resources
As we say: **"Building a Lego airplane is far more exciting than flying first class!"**
---
Next: [Get Started →](quickstart.md)


@@ -1,279 +0,0 @@
# Quick Start
Get MiniMind up and running in minutes!
## 📋 Requirements
### Hardware
- **GPU Memory**: 8GB minimum (24GB recommended for comfortable development)
- **Recommended GPU**: NVIDIA RTX 3090 (24GB)
### Software
- **Python**: 3.10+
- **PyTorch**: 2.0+ (with CUDA 12.2+ for GPU support)
- **CUDA**: 12.2+ (optional, for GPU acceleration)
!!! tip "Hardware Configuration Reference"
    - **CPU**: Intel i9-10980XE @ 3.00GHz
    - **RAM**: 128 GB
    - **GPU**: NVIDIA GeForce RTX 3090 (24GB) × 8
    - **OS**: Ubuntu 20.04
    - **CUDA**: 12.2
    - **Python**: 3.10.16
## 🚀 Step 0: Clone the Repository
```bash
git clone https://github.com/jingyaogong/minimind.git
cd minimind
```
## 🎯 Section I: Testing Existing Models
### 1. Environment Setup
```bash
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
```
!!! warning "Verify CUDA Support"
    After installation, verify PyTorch can access CUDA:

    ```python
    import torch
    print(torch.cuda.is_available())
    print(torch.cuda.get_device_name(0))
    ```

    If it prints `False`, download the correct PyTorch build from [PyTorch Official](https://download.pytorch.org/whl/torch_stable.html)
### 2. Download Pretrained Models
Choose one option:
**From HuggingFace** (recommended for international users):
```bash
git clone https://huggingface.co/jingyaogong/MiniMind2
```
**From ModelScope** (recommended for China users):
```bash
git clone https://www.modelscope.cn/models/gongjy/MiniMind2.git
```
### 3. Command-Line Chat
```bash
# load=0: load PyTorch model, load=1: load transformers model
python eval_model.py --load 1 --model_mode 2
```
**Model Modes**:
- `model_mode 0`: Pretrain model (word continuation)
- `model_mode 1`: SFT Chat model (conversation)
- `model_mode 2`: RLHF model (refined responses, currently same as SFT for small models)
- `model_mode 3`: Reasoning model (with thinking chains)
- `model_mode 4/5`: RLAIF models (PPO/GRPO trained)
**Example Session**:
```text
👶: Hello, please introduce yourself.
🤖️: I am MiniMind, an AI assistant developed by Jingyao Gong.
I use natural language processing and machine learning algorithms to interact with users.
👶: What's the capital of France?
🤖️: The capital of France is Paris, which is located in the northern central part of France.
It is the largest city in France and serves as its political, economic, and cultural center.
```
### 4. Web UI Demo (Optional)
```bash
# Requires Python >= 3.10
pip install streamlit
cd scripts
streamlit run web_demo.py
```
Visit `http://localhost:8501` to use the interactive web interface.
### 5. RoPE Length Extrapolation with YaRN
Extend context length beyond training with RoPE extrapolation:
```bash
python eval_model.py --inference_rope_scaling
```
This enables the YaRN algorithm to handle sequences longer than the 2K training context, useful for processing documents and long conversations.
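The core of YaRN's "NTK-by-parts" idea can be sketched in a few lines: low-frequency RoPE dimensions (long wavelengths) are interpolated by the scale factor, high-frequency dimensions are left untouched, and a linear ramp blends the two regimes. The constants and function name below are illustrative, not MiniMind's actual implementation:

```python
import math

def yarn_inv_freq(dim=64, base=10000.0, scale=4.0, orig_ctx=2048,
                  beta_fast=32.0, beta_slow=1.0):
    """NTK-by-parts frequency scaling (illustrative constants).

    Dimensions whose wavelength fits well inside the training context are
    kept; dimensions whose wavelength exceeds it are interpolated by
    `scale`; a linear ramp blends the two regimes.
    """
    inv_freq = [base ** (-(2 * i) / dim) for i in range(dim // 2)]
    low, high = orig_ctx / beta_fast, orig_ctx / beta_slow
    scaled = []
    for f in inv_freq:
        wavelen = 2 * math.pi / f  # wavelength of this rotary dim, in tokens
        ramp = min(max((wavelen - low) / (high - low), 0.0), 1.0)
        scaled.append(f * (1 - ramp) + (f / scale) * ramp)
    return scaled
```

This keeps full positional resolution for local context while stretching the slow-rotating dimensions to cover the extended window.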
## 🔧 Third-Party Inference Frameworks
MiniMind is compatible with popular inference engines:
### Ollama (Easiest)
```bash
ollama run jingyaogong/minimind2
```
### vLLM (Fastest)
```bash
vllm serve ./MiniMind2/ --served-model-name "minimind" --port 8000
# Test with curl
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "minimind",
"messages": [{"role": "user", "content": "Hello!"}],
"temperature": 0.7,
"max_tokens": 512
}'
```
### llama.cpp (CPU-Friendly)
```bash
# Convert to GGUF format
python scripts/convert_model.py ./MiniMind2/ --output ./MiniMind2.gguf
# Quantize for size reduction
./llama-quantize ./MiniMind2.gguf ./MiniMind2-Q4.gguf Q4_K_M
# Run inference
./llama-cli -m ./MiniMind2-Q4.gguf -p "Hello" -n 128
```
## 🔌 OpenAI API Server (For Integration)
Run MiniMind as an OpenAI API-compatible service:
```bash
python scripts/serve_openai_api.py
```
Test the API:
```bash
# In another terminal
python scripts/chat_openai_api.py
```
**cURL Example**:
```bash
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "minimind",
"messages": [
{"role": "user", "content": "Explain machine learning in one sentence."}
],
"temperature": 0.7,
"max_tokens": 256,
"stream": true
}'
```
This enables integration with:
- [FastGPT](https://fastgpt.run/)
- [Open-WebUI](https://github.com/open-webui/open-webui)
- [Dify](https://dify.ai/)
- Any OpenAI API-compatible client
## 📊 Model Selection Guide
| Use Case | Recommended Model | Memory | Speed |
|----------|------------------|--------|-------|
| Learning/Testing | MiniMind2-small (26M) | ~0.5 GB | Fastest |
| Balanced | MiniMind2 (104M) | ~1.0 GB | Fast |
| Expert System (MoE) | MiniMind2-MoE (145M) | ~1.0 GB | Dynamic |
| Reasoning/Complex | MiniMind-Reason (104M) | ~1.0 GB | Standard |
## ⚡ Quick Test Results
**Model**: MiniMind2 (104M parameters)
```text
Q: What is photosynthesis?
A: Photosynthesis is a process in which plants convert light energy from the sun
into chemical energy to produce glucose. This process occurs mainly in leaves
and is essential for plant growth and survival.
Q: Write a Python function to calculate Fibonacci numbers.
A: def fibonacci(n):
if n <= 1:
return n
return fibonacci(n-1) + fibonacci(n-2)
# For better performance, use dynamic programming:
def fibonacci_dp(n):
dp = [0] * (n + 1)
for i in range(2, n + 1):
dp[i] = dp[i-1] + dp[i-2]
return dp[n]
Q: 世界上最高的山峰是什么? (What is the highest mountain?)
A: 珠穆朗玛峰Mount Everest是世界上最高的山峰位于喜马拉雅山脉...
(Mount Everest is the world's highest mountain, located in the Himalayas...)
```
## 🆘 Troubleshooting
### Issue: CUDA Out of Memory
**Solution**:
```bash
# Reduce batch size
python eval_model.py --batch_size 1
# Or use CPU (slow but works)
python eval_model.py --device cpu
```
### Issue: Slow Inference
**Solutions**:
- Use vLLM or llama.cpp for faster inference
- Enable quantization (4-bit, 8-bit)
- Use GPU instead of CPU
- Reduce `max_tokens` parameter
### Issue: Model Responses Are Poor Quality
**Possible Causes**:
- Using pretrain model (`model_mode 0`) instead of SFT (`model_mode 1`)
- Model is undertrained - download the full checkpoint instead
- Input prompt is too short - provide more context
### Issue: Python/PyTorch Version Mismatch
**Solution**:
```bash
# Use conda for clean environment
conda create -n minimind python=3.10
conda activate minimind
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu122
pip install -r requirements.txt
```
## 📖 Next Steps
- **[Model Training Guide](training.md)** - Train your own MiniMind from scratch
- **[Source Code](https://github.com/jingyaogong/minimind)** - Explore and learn LLM implementation
- **[Inference Benchmarks](https://huggingface.co/collections/jingyaogong/minimind-66caf8d999f5c7fa64f399e5)** - See model performance comparisons
## 💡 Pro Tips
1. **GPU Memory Optimization**: Use `torch.cuda.empty_cache()` periodically
2. **Batch Processing**: For efficiency, process multiple prompts in batches
3. **Temperature Tuning**: Lower (0.3-0.7) = more consistent, Higher (0.8-1.0) = more creative
4. **Prompt Engineering**: Better prompts → better results, even for small models
5. **Model Quantization**: Use 4-bit quantization to run on smaller GPUs
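The temperature tip above is just temperature-scaled softmax over the logits before sampling; a minimal sketch (names are illustrative):

```python
import math

def sample_probs(logits, temperature=0.7):
    """Temperature-scaled softmax (illustrative): lower temperature
    sharpens the distribution, higher temperature flattens it."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]
```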
---
Done! Now you're ready to use MiniMind. Start with the Quick Start, then move to [Model Training](training.md) to learn how to train your own models.


@@ -1,679 +0,0 @@
# Model Training Guide
Learn how to train MiniMind language models from scratch using pure PyTorch.
## 📊 Training Overview
MiniMind implements a complete training pipeline:
```
Tokenizer Training
        ↓
Pretraining (learn knowledge)
        ↓
SFT (learn conversation)
        ↓
 ┌──────────────┬──────────────┬────────────────────────┐
 ↓              ↓              ↓                        ↓
LoRA          DPO/RLHF     RLAIF (PPO/GRPO/SPO)   Distillation
(domain adapt) (preference) (reinforcement learn)  (reasoning)
```
## 💰 Training Costs (Single NVIDIA 3090)
| Model | Dataset | Duration | Cost (RMB) | Quality |
|-------|---------|----------|-----------|---------|
| MiniMind2-Small | pretrain_hq + sft_mini_512 | 2.1h | ≈3 | 😊😊 |
| MiniMind2-Small | Full dataset | 38h | ≈50 | 😊😊😊😊😊😊 |
| MiniMind2 | pretrain_hq + sft_mini_512 | 3.3h | ≈5 | 😊😊 |
| MiniMind2 | Full dataset | 122h | ≈160 | 😊😊😊😊😊😊😊 |
!!! success "Ultra-Fast Training"
    **Just 2.1 hours + ≈3 RMB = a functional ChatBot!**
    Use `pretrain_hq.jsonl` + `sft_mini_512.jsonl` for the fastest reproduction.
## 📋 Data Preparation
### 1. Download Datasets
Download from [ModelScope](https://www.modelscope.cn/datasets/gongjy/minimind_dataset) or [HuggingFace](https://huggingface.co/datasets/jingyaogong/minimind_dataset):
```bash
mkdir -p dataset
cd dataset
# Download required files
```
### 2. Dataset Directory Structure
```
./dataset/
├── pretrain_hq.jsonl ✨ (1.6GB, required for pretraining)
├── sft_mini_512.jsonl ✨ (1.2GB, fastest SFT)
├── sft_512.jsonl (7.5GB, standard SFT)
├── sft_1024.jsonl (5.6GB, longer SFT)
├── sft_2048.jsonl (9GB, very long SFT)
├── dpo.jsonl (909MB, DPO training)
├── r1_mix_1024.jsonl (340MB, reasoning distillation)
├── rlaif-mini.jsonl (1MB, RLAIF algorithms)
├── lora_identity.jsonl (22.8KB, identity LoRA)
└── lora_medical.jsonl (34MB, medical domain LoRA)
```
### 3. Data Formats
**Pretraining Data** (`pretrain_hq.jsonl`):
```json
{"text": "How to overcome procrastination? Overcoming procrastination is not easy, but these suggestions may help..."}
```
**SFT Data** (`sft_*.jsonl`):
```json
{
"conversations": [
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hello! How can I help?"},
{"role": "user", "content": "Tell me a joke."},
{"role": "assistant", "content": "Why did the scarecrow win an award? Because he was outstanding in his field!"}
]
}
```
**DPO Data** (`dpo.jsonl`):
```json
{
"chosen": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "2+2 equals 4."}
],
"rejected": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "2+2 equals 5."}
]
}
```
**LoRA Domain Data** (`lora_*.jsonl`):
```json
{
"conversations": [
{"role": "user", "content": "What's the treatment for cervical spondylosis?"},
{"role": "assistant", "content": "Cervical spondylosis treatment typically includes..."}
]
}
```
## 🎯 Complete Training Pipeline
All training scripts are in the `./trainer` directory.
```bash
cd trainer
```
### Stage 1: Pretraining
**Purpose**: Learn foundational knowledge (word continuation)
```bash
# Single GPU
python train_pretrain.py
# Multi-GPU (DDP)
torchrun --nproc_per_node 2 train_pretrain.py
# Multi-GPU (DeepSpeed)
deepspeed --master_port 29500 --num_gpus=2 train_pretrain.py
```
**Key Parameters**:
- `max_seq_len`: 512 (adjust based on GPU memory)
- `learning_rate`: 1e-4
- `epochs`: Adjust based on dataset size
**Output**: `./out/pretrain_*.pth`
**Training Duration**:
- MiniMind2-Small (26M): ~1.1h
- MiniMind2 (104M): ~3.9h
!!! tip "Pretraining Tips"
    - Start with `pretrain_hq.jsonl` for best results
    - Quality > quantity for pretraining data
    - Monitor the loss curve to detect overfitting
### Stage 2: Supervised Fine-Tuning (SFT)
**Purpose**: Teach conversation patterns and chat templates
```bash
# Single GPU
python train_full_sft.py
# Multi-GPU
torchrun --nproc_per_node 2 train_full_sft.py
```
**Configuration**:
- Load pretrained model from Stage 1
- Use SFT dataset (`sft_mini_512.jsonl` or `sft_512.jsonl`)
- Adjust `max_seq_len` to match training data
**Output**: `./out/full_sft_*.pth`
**Training Duration**:
- With sft_mini_512: 1-3 hours
- With full sft_512: 20-25 hours
!!! warning "SFT Data Selection"
    - `sft_mini_512.jsonl`: fastest, ~1.2GB, 512 tokens max
    - `sft_512.jsonl`: standard, ~7.5GB, 512 tokens max
    - `sft_1024.jsonl`: longer, ~5.6GB, 1024 tokens max
    - `sft_2048.jsonl`: extended, ~9GB, 2048 tokens max
### Stage 3: LoRA Fine-Tuning (Optional)
**Purpose**: Parameter-efficient domain adaptation
**Use Cases**:
- Medical Q&A knowledge
- Personal identity/self-awareness
- Proprietary domain knowledge
```bash
# Edit train_lora.py to set correct dataset and base model
python train_lora.py
# Multi-GPU
torchrun --nproc_per_node 2 train_lora.py
```
**Output**: `./out/lora/lora_*.pth`
**Example 1: Medical Domain**
Prepare `dataset/lora_medical.jsonl`:
```json
{
"conversations": [
{"role": "user", "content": "What's the correct pillow height for cervical spondylosis?"},
{"role": "assistant", "content": "For cervical spondylosis, pillow height should be..."}
]
}
```
Train:
```bash
# Modify train_lora.py: lora_name = 'medical'
python train_lora.py
```
**Example 2: Identity/Self-Awareness**
Prepare `dataset/lora_identity.jsonl`:
```json
{
"conversations": [
{"role": "user", "content": "Who are you?"},
{"role": "assistant", "content": "I am MiniMind..."}
]
}
```
### Stage 4: Direct Preference Optimization (DPO)
**Purpose**: Align model responses with human preferences
DPO eliminates the need for separate reward models by directly optimizing preference pairs.
```bash
python train_dpo.py
# Multi-GPU
torchrun --nproc_per_node 2 train_dpo.py
```
**Output**: `./out/rlhf_*.pth`
**Key Features**:
- Off-policy training (reuse data across epochs)
- No separate reward model needed
- Better sample efficiency than PPO
- Stable training convergence
**Training Duration**: ~1-3 hours
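For intuition, the DPO objective for a single preference pair can be sketched as follows; `beta` and the function name are illustrative, not the trainer's actual code:

```python
import math

def dpo_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta=0.1):
    """DPO loss for one preference pair (sketch; `beta` illustrative).

    Arguments are the summed log-probs of each response under the trained
    policy (pi_*) and the frozen reference model (ref_*).
    """
    margin = (pi_chosen - ref_chosen) - (pi_rejected - ref_rejected)
    return -math.log(1.0 / (1.0 + math.exp(-beta * margin)))  # -log sigmoid
```

The loss shrinks as the policy raises the chosen response's likelihood (relative to the reference) above the rejected one's, which is why no separate reward model is needed.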
### Stage 5: Reinforcement Learning from AI Feedback (RLAIF)
RLAIF is an advanced training approach using AI-generated rewards instead of human annotations. MiniMind implements three modern algorithms:
#### 5.1 PPO (Proximal Policy Optimization)
Classical on-policy RL algorithm with proven stability.
```bash
python train_ppo.py
# Multi-GPU
torchrun --nproc_per_node 2 train_ppo.py
```
**Algorithm**:
$$\mathcal{L}_{PPO} = -\mathbb{E}\left[\min(r_t \cdot A_t, \text{clip}(r_t, 1-\varepsilon, 1+\varepsilon) \cdot A_t)\right] + \beta \cdot \mathbb{E}[\text{KL}]$$
**Characteristics**:
- Stable but slower reward improvement
- Requires both Actor and Critic networks
- High memory usage (1.5-2× single network)
- Good for exploration
**Output**: `./out/ppo_actor_*.pth`
**Training Duration**: ~1-3 hours
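The clipped term of the objective above, for a single token, can be sketched as follows (the KL penalty is omitted; names are illustrative):

```python
def ppo_clip_loss(ratio, advantage, eps=0.2):
    """Clipped PPO policy loss for a single token (sketch).

    ratio = pi_new(a|s) / pi_old(a|s); the advantage comes from the
    Critic (e.g. via GAE). Clipping keeps the policy update bounded.
    """
    unclipped = ratio * advantage
    clipped = min(max(ratio, 1.0 - eps), 1.0 + eps) * advantage
    return -min(unclipped, clipped)
```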
#### 5.2 GRPO (Group Relative Policy Optimization)
Modern algorithm used in DeepSeek-R1, with faster convergence.
```bash
python train_grpo.py
# Multi-GPU
torchrun --nproc_per_node 2 train_grpo.py
```
**Algorithm**:
$$\mathcal{L}_{GRPO} = -\mathbb{E}\left[r_t \cdot A_t - \beta \cdot \text{KL}_t\right]$$
Where advantage is computed as:
$$A_t = \frac{R - \mu_{group}}{\sigma_{group}}$$
**Characteristics**:
- Single-network design (memory efficient)
- Faster reward improvement
- Group normalization removes bias
- Better convergence stability
**Output**: `./out/grpo_*.pth`
**Training Duration**: ~1-3 hours
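The group-normalized advantage above amounts to z-scoring the rewards of the completions sampled for the same prompt; a minimal sketch (names are illustrative):

```python
from statistics import mean, pstdev

def grpo_advantages(rewards, eps=1e-6):
    """Group-relative advantages (sketch): z-score each completion's
    reward within the group sampled for one prompt."""
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]
```

Because the baseline is the group mean, no Critic network is required, which is where the memory savings over PPO come from.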
#### 5.3 SPO (Single-stream Policy Optimization)
Newest algorithm (2025) addressing GRPO's degenerate group problem.
```bash
python train_spo.py
# Multi-GPU
torchrun --nproc_per_node 2 train_spo.py
```
**Algorithm**:
$$\mathcal{L}_{SPO} = -\mathbb{E}\left[\log \pi_\theta(a_t|s) \cdot A_t - \beta \cdot \text{KL}_t\right]$$
With adaptive baseline: $B_t^{adaptive}$
**Characteristics**:
- No group dependency (1 input → 1 training sample)
- Adaptive value tracking
- Better handling of difficult examples
- Experimental on small models
**Output**: `./out/spo_*.pth`
**Training Duration**: ~1-3 hours
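One simple way to realize an adaptive baseline is an exponential moving average of past rewards; this is a simplification of SPO's value tracker, with illustrative names:

```python
def spo_advantage(reward, baseline, alpha=0.1):
    """Advantage against an adaptive EMA baseline (simplified sketch of
    SPO's value tracker; `alpha` is illustrative).

    Returns (advantage, updated_baseline).
    """
    advantage = reward - baseline
    new_baseline = (1.0 - alpha) * baseline + alpha * reward
    return advantage, new_baseline
```

Each input yields exactly one training sample, so there is no group to degenerate when all completions for a prompt get the same reward.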
#### RLAIF Dataset Preparation
All RLAIF algorithms use `rlaif-mini.jsonl` (1MB, 10k examples):
The format is the same as SFT data, except the assistant content is the placeholder `"无"` (none):

```json
{
  "conversations": [
    {"role": "user", "content": "Explain photosynthesis briefly."},
    {"role": "assistant", "content": "无"}
  ]
}
```
The model generates completions during training, which are scored by a **Reward Model** (e.g., InternLM2-1.8B-Reward).
**Reward Model Setup**:
```bash
# Download reward model to parent directory
cd ../
git clone https://huggingface.co/internlm/internlm2-1_8b-reward
# Directory structure should be:
# project/
# ├── minimind/
# └── internlm2-1_8b-reward/
```
#### RLAIF vs DPO Comparison
| Aspect | DPO | RLAIF (PPO/GRPO/SPO) |
|--------|-----|---------------------|
| Training Type | Off-policy | On-policy |
| Data Freshness | Static pairs | Dynamic (generated) |
| Reward Source | Implicit | Explicit model |
| Convergence | Fast | Slower |
| Memory Usage | Lower | Higher |
| Best For | Preference refinement | Capability improvement |
### Stage 6: Reasoning Model Distillation
**Purpose**: Distill DeepSeek-R1-style reasoning into MiniMind
```bash
python train_distill_reason.py
# Multi-GPU
torchrun --nproc_per_node 2 train_distill_reason.py
```
**Data Format** (`r1_mix_1024.jsonl`):
```json
{
"conversations": [
{
"role": "user",
"content": "Solve: 5 + 3 = ?"
},
{
"role": "assistant",
"content": "<think>\nI need to add 5 and 3.\n5 + 3 = 8\n</think>\n<answer>\n5 + 3 = 8\n</answer>"
}
]
}
```
**Output**: `./out/reason_*.pth`
**Training Features**:
- Enforces `<think>` and `<answer>` tags
- Penalty loss for format violations
- Mixed data (reasoning + multi-turn + English)
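A sketch of the kind of format check such a penalty might rely on (the actual loss penalizes tag token positions; the pattern and names here are illustrative):

```python
import re

# illustrative pattern for the <think>/<answer> layout used in r1_mix_1024.jsonl
THINK_ANSWER = re.compile(
    r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$", re.S)

def follows_reason_format(text):
    """Return True if a response keeps the <think>/<answer> layout."""
    return bool(THINK_ANSWER.match(text.strip()))
```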
## 🔧 Multi-GPU Training
### DDP (Distributed Data Parallel)
Best for single-machine multi-GPU:
```bash
torchrun --nproc_per_node N train_xxx.py
# N = number of GPUs
```
### DeepSpeed
For advanced optimization:
```bash
deepspeed --master_port 29500 --num_gpus=N train_xxx.py
```
### Wandb Monitoring
Track training progress:
```bash
# Login first
wandb login
# Enable wandb logging
torchrun --nproc_per_node N train_xxx.py --use_wandb
# Or SwanLab (China-friendly alternative)
python train_xxx.py --use_wandb # Automatically uses SwanLab if available
```
## 🧪 Model Testing
### Evaluate Pretrain Model
```bash
python eval_model.py --model_mode 0
```
### Evaluate Chat Model
```bash
python eval_model.py --model_mode 1
```
### Evaluate with LoRA
```bash
python eval_model.py --lora_name 'lora_medical' --model_mode 1
```
### Evaluate Reasoning Model
```bash
python eval_model.py --model_mode 3
```
### Evaluate RLAIF Models
```bash
# PPO model
python eval_model.py --model_mode 4
# GRPO model
python eval_model.py --model_mode 5
```
### RoPE Length Extrapolation
Test with extended context:
```bash
python eval_model.py --model_mode 1 --inference_rope_scaling
```
## 📐 Model Architecture
### MiniMind Structure
**Decoder-Only Transformer** (similar to Llama3):
```
Input Tokens
      ↓
Token Embedding (6,400 vocab)
      ↓
Rotary Embeddings (RoPE) [with YaRN for length extrapolation]
      ↓
[Transformer Block] ×N
  ├─ Attention (Multi-Head)
  ├─ RMSNorm
  ├─ SwiGLU FFN [or MoE for the MoE variant]
  └─ Residual Connections
      ↓
RMSNorm
      ↓
LM Head (→ 6,400 vocab logits)
      ↓
Output Probabilities
```
### Model Configurations
| Config | MiniMind2-Small | MiniMind2 | MiniMind2-MoE |
|--------|-----------------|----------|---------------|
| Parameters | 26M | 104M | 145M |
| Hidden Dim | 512 | 768 | 640 |
| Layers | 8 | 16 | 8 |
| KV Heads | 2 | 2 | 2 |
| Q Heads | 8 | 8 | 8 |
| Vocab Size | 6,400 | 6,400 | 6,400 |
| Context Length | 2,048 | 2,048 | 2,048 |
### Modify Architecture
Edit `./model/LMConfig.py`:
```python
class LMConfig:
hidden_size: int = 768
num_layers: int = 16
num_heads: int = 8
num_kv_heads: int = 2
# ... other configs
```
## 🔍 Training Tips & Best Practices
### Data Quality > Quantity
- High-quality pretraining data accelerates convergence
- `pretrain_hq.jsonl` is carefully curated for quality
- Consider data deduplication and cleaning
### Learning Rate Scheduling
Recommended schedule:

- Linear warmup, then decay
- Initial LR: 1e-4 to 5e-4
- Warmup steps: ~10% of total steps
- Final LR: ~10% of the initial LR
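The schedule above can be sketched with cosine decay (the cosine shape is an assumption; names are illustrative):

```python
import math

def lr_at(step, total_steps, base_lr=1e-4, warmup_frac=0.1, final_frac=0.1):
    """Linear warmup, then cosine decay down to `final_frac` of the
    initial LR (sketch; the decay shape is an assumption)."""
    warmup = max(1, int(total_steps * warmup_frac))
    final_lr = base_lr * final_frac
    if step < warmup:
        return base_lr * (step + 1) / warmup  # linear warmup
    t = (step - warmup) / max(1, total_steps - warmup)  # progress in [0, 1]
    return final_lr + 0.5 * (base_lr - final_lr) * (1.0 + math.cos(math.pi * t))
```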
### Batch Size & Sequence Length
Balance GPU memory against convergence speed:

- Pretraining: `max_seq_len=512`, `batch_size=32`
- SFT: `max_seq_len=512`, `batch_size=16`
- LoRA: `max_seq_len=512`, `batch_size=16`
### Memory Optimization
```bash
# Reduce batch size if OOM
python train_xxx.py --batch_size 8
# Or use gradient accumulation
python train_xxx.py --gradient_accumulation_steps 4
```
### Checkpoint Management
- Saves every 100 steps by default
- Each new save overwrites the old one
- Automatic backup before training
## 🚨 Common Issues & Solutions
### Issue: CUDA Out of Memory
```bash
# Solution 1: Reduce batch size
python train_xxx.py --batch_size 4
# Solution 2: Use gradient accumulation
python train_xxx.py --batch_size 16 --gradient_accumulation_steps 2
# Solution 3: Use smaller model
# Edit trainer script to use MiniMind2-Small instead
```
### Issue: Training Not Converging
Possible causes:

1. Learning rate too high or too low
2. Data quality issues
3. Model capacity mismatch

Solutions:

- Reduce the learning rate (e.g. `--learning_rate 1e-5`)
- Check data format and quality
- Try a smaller model first
### Issue: Multi-GPU Sync Errors
Checklist:

1. All GPUs are visible: check `nvidia-smi`
2. The same CUDA version is installed across all GPUs
3. Network connectivity works for distributed training

To debug:

```bash
torchrun --nproc_per_node 2 train_xxx.py --debug
```
### Issue: Different Results Than Expected
Check that:

1. The random seed is set (reproducibility)
2. The correct model checkpoint is loaded
3. The correct dataset is being used
4. Hyperparameters match the reference
## 📈 Training Progression
Typical training curves:
```
Pretraining Loss: ↘↘↘ (steep decline, then plateau)
SFT Loss: ↘ (steady decline)
PPO Reward: ↗ (rising, may plateau)
GRPO Reward: ↗↗ (faster rise, more stable)
```
## 🎓 Advanced Topics
### Custom Datasets
Create your own dataset:
```python
# Format: JSONL with conversations list
# Each line is one training example
# Ensure consistent quality and format
```
### Model Quantization (Post-training)
```bash
# 4-bit quantization for inference
# Use tools like:
# - llama.cpp (gguf format)
# - bitsandbytes (dynamic quantization)
# - AutoGPTQ (static quantization)
```
### Model Merging
```python
# Merge base model + LoRA weights
# Use tools like: peft, llama.cpp
```
## 📚 References
- [Scaling Laws](https://arxiv.org/pdf/2001.08361.pdf)
- [RoPE Position Embeddings](https://arxiv.org/abs/2104.09864)
- [YaRN Length Extrapolation](https://arxiv.org/abs/2309.00071)
- [PPO Algorithm](https://arxiv.org/abs/1707.06347)
- [GRPO (DeepSeek)](https://arxiv.org/pdf/2402.03300)
- [SPO Algorithm](https://arxiv.org/abs/2509.13232)
- [DPO](https://arxiv.org/abs/2305.18290)
---
**Next**: Deploy your trained model or explore [advanced inference options](quickstart.md#third-party-inference-frameworks)

eval_llm.py Executable file
@@ -0,0 +1,94 @@
import time
import argparse
import random
import warnings
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from model.model_lora import *
from trainer.trainer_utils import setup_seed, get_model_params
warnings.filterwarnings('ignore')
def init_model(args):
    tokenizer = AutoTokenizer.from_pretrained(args.load_from)
    if 'model' in args.load_from:
        model = MiniMindForCausalLM(MiniMindConfig(
            hidden_size=args.hidden_size,
            num_hidden_layers=args.num_hidden_layers,
            use_moe=bool(args.use_moe),
            inference_rope_scaling=args.inference_rope_scaling
        ))
        moe_suffix = '_moe' if args.use_moe else ''
        ckp = f'./{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
        model.load_state_dict(torch.load(ckp, map_location=args.device), strict=True)
        if args.lora_weight != 'None':
            apply_lora(model)
            load_lora(model, f'./{args.save_dir}/{args.lora_weight}_{args.hidden_size}.pth')
    else:
        model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
    get_model_params(model, model.config)
    return model.half().eval().to(args.device), tokenizer


def main():
    parser = argparse.ArgumentParser(description="MiniMind inference and chat")
    parser.add_argument('--load_from', default='model', type=str, help="model source ('model' = native torch weights, any other path = transformers format)")
    parser.add_argument('--save_dir', default='out', type=str, help="directory holding model weights")
    parser.add_argument('--weight', default='full_sft', type=str, help="weight name prefix (pretrain, full_sft, rlhf, reason, ppo_actor, grpo, spo)")
    parser.add_argument('--lora_weight', default='None', type=str, help="LoRA weight name; 'None' = disabled (options: lora_identity, lora_medical)")
    parser.add_argument('--hidden_size', default=768, type=int, help="hidden dimension")
    parser.add_argument('--num_hidden_layers', default=8, type=int, help="number of hidden layers")
    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="use the MoE architecture (0 = no, 1 = yes)")
    parser.add_argument('--inference_rope_scaling', default=False, action='store_true', help="enable RoPE position extrapolation (4x; only addresses positional encoding)")
    parser.add_argument('--max_new_tokens', default=8192, type=int, help="max generation length (note: not the model's actual long-text ability)")
    parser.add_argument('--temperature', default=0.85, type=float, help="sampling temperature (0-1, higher = more random)")
    parser.add_argument('--top_p', default=0.95, type=float, help="nucleus sampling threshold (0-1)")
    parser.add_argument('--open_thinking', default=0, type=int, help="enable adaptive thinking (0 = no, 1 = yes)")
    parser.add_argument('--historys', default=0, type=int, help="number of history turns to carry (must be even; 0 = no history)")
    parser.add_argument('--show_speed', default=1, type=int, help="print decode speed (tokens/s)")
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="device to run on")
    args = parser.parse_args()

    prompts = [
        '你有什么特长?',  # What are your special skills?
        '为什么天空是蓝色的?',  # Why is the sky blue?
        '请用Python写一个计算斐波那契数列的函数',  # Write a Python function that computes the Fibonacci sequence
        '解释一下"光合作用"的基本过程',  # Explain the basic process of photosynthesis
        '如果明天下雨,我应该如何出门?',  # If it rains tomorrow, how should I get around?
        '比较一下猫和狗作为宠物的优缺点',  # Compare the pros and cons of cats vs. dogs as pets
        '解释什么是机器学习',  # Explain what machine learning is
        '推荐一些中国的美食'  # Recommend some Chinese food
    ]
    conversation = []
    model, tokenizer = init_model(args)
    input_mode = int(input('[0] automatic test\n[1] manual input\n'))
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    prompt_iter = prompts if input_mode == 0 else iter(lambda: input('💬: '), '')
    for prompt in prompt_iter:
        setup_seed(random.randint(0, 31415926))
        if input_mode == 0: print(f'💬: {prompt}')
        conversation = conversation[-args.historys:] if args.historys else []
        conversation.append({"role": "user", "content": prompt})
        if 'pretrain' in args.weight:
            inputs = tokenizer.bos_token + prompt
        else:
            inputs = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True, open_thinking=bool(args.open_thinking))
        inputs = tokenizer(inputs, return_tensors="pt", truncation=True).to(args.device)
        print('🧠: ', end='')
        st = time.time()
        generated_ids = model.generate(
            inputs=inputs["input_ids"], attention_mask=inputs["attention_mask"],
            max_new_tokens=args.max_new_tokens, do_sample=True, streamer=streamer,
            pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id,
            top_p=args.top_p, temperature=args.temperature, repetition_penalty=1
        )
        response = tokenizer.decode(generated_ids[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
        conversation.append({"role": "assistant", "content": response})
        gen_tokens = len(generated_ids[0]) - len(inputs["input_ids"][0])
        print(f'\n[Speed]: {gen_tokens / (time.time() - st):.2f} tokens/s\n\n') if args.show_speed else print('\n\n')


if __name__ == "__main__":
    main()

Binary image files changed (contents not shown):

- images/LLM-structure.jpg (new, 262 KiB)
- images/agent_rl_loss.jpg (new, 702 KiB)
- images/agent_webui.jpg (new, 124 KiB)
- images/benchmark_radar.jpg (new, 96 KiB)
- images/dataset.jpg (new, 123 KiB)
- images/grpo_loss.jpg (new, 590 KiB)
- images/minimind-3.gif (new, 5.7 MiB)
- images/ppo_loss.jpg (new, 601 KiB)
- images/pretrain_loss.jpg (new, 292 KiB)
- images/rl-structure.jpg (new, 231 KiB)
- images/rope_ppl.png (new, 79 KiB)
- images/sft_loss.jpg (new, 466 KiB)
- one new image (313 KiB) and five modified images (66, 495, 615, 178, and 150 KiB) whose filenames were not captured

@@ -1,66 +0,0 @@
site_name: MiniMind
site_description: MiniMind - 轻量级语言模型训练框架 / Lightweight Language Model Training Framework
site_author: jingyaogong
site_url: https://minimind.readthedocs.io/
# 搜索插件配置
plugins:
- search:
lang: en
# 主题配置
theme:
name: material
favicon: images/logo.png
icon:
logo: material/book-open-page-variant
palette:
# 浅色模式
- scheme: default
primary: white
accent: blue
toggle:
icon: material/brightness-7
name: 切换至深色模式
# 深色模式
- scheme: slate
primary: black
accent: blue
toggle:
icon: material/brightness-4
name: 切换至浅色模式
features:
# - navigation.instant # 与多语言切换不兼容,已禁用
- navigation.tracking # 锚点跟踪
- navigation.sections # 导航分组
- navigation.expand # 默认展开导航
- navigation.top # 返回顶部按钮
- search.suggest # 搜索建议
- search.highlight # 搜索高亮
- content.code.copy # 代码复制按钮
- toc.follow # 目录跟随
- toc.integrate # 目录集成到左侧边栏
language: en
# 导航结构
nav:
- Home: index.md
- Quick Start: quickstart.md
- Model Training: training.md
# Markdown 扩展
markdown_extensions:
- toc:
permalink: true
- admonition
- pymdownx.highlight:
anchor_linenums: true
- pymdownx.inlinehilite
- pymdownx.snippets
- pymdownx.superfences
- pymdownx.details
- pymdownx.tabbed:
alternate_style: true
- attr_list
- md_in_html

model/__init__.py Normal file

model/model_lora.py Normal file
@@ -0,0 +1,65 @@
import torch
from torch import nn


# Define the LoRA network structure
class LoRA(nn.Module):
    def __init__(self, in_features, out_features, rank):
        super().__init__()
        self.rank = rank  # LoRA rank; controls the size of the low-rank matrices
        self.A = nn.Linear(in_features, rank, bias=False)  # low-rank matrix A
        self.B = nn.Linear(rank, out_features, bias=False)  # low-rank matrix B
        # Matrix A: Gaussian initialization
        self.A.weight.data.normal_(mean=0.0, std=0.02)
        # Matrix B: zero initialization
        self.B.weight.data.zero_()

    def forward(self, x):
        return self.B(self.A(x))


def apply_lora(model, rank=16):
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and module.weight.shape[0] == module.weight.shape[1]:
            lora = LoRA(module.weight.shape[0], module.weight.shape[1], rank=rank).to(model.device)
            setattr(module, "lora", lora)
            original_forward = module.forward

            # Bind the original forward and the LoRA branch explicitly via default arguments
            def forward_with_lora(x, layer1=original_forward, layer2=lora):
                return layer1(x) + layer2(x)

            module.forward = forward_with_lora


def load_lora(model, path):
    state_dict = torch.load(path, map_location=model.device)
    state_dict = {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()}
    for name, module in model.named_modules():
        if hasattr(module, 'lora'):
            lora_state = {k.replace(f'{name}.lora.', ''): v for k, v in state_dict.items() if f'{name}.lora.' in k}
            module.lora.load_state_dict(lora_state)


def save_lora(model, path):
    raw_model = getattr(model, '_orig_mod', model)
    state_dict = {}
    for name, module in raw_model.named_modules():
        if hasattr(module, 'lora'):
            clean_name = name[7:] if name.startswith("module.") else name
            lora_state = {f'{clean_name}.lora.{k}': v.cpu().half() for k, v in module.lora.state_dict().items()}
            state_dict.update(lora_state)
    torch.save(state_dict, path)


def merge_lora(model, lora_path, save_path):
    load_lora(model, lora_path)
    raw_model = getattr(model, '_orig_mod', model)
    state_dict = {k: v.cpu().half() for k, v in raw_model.state_dict().items() if '.lora.' not in k}
    for name, module in raw_model.named_modules():
        if isinstance(module, nn.Linear) and '.lora.' not in name:
            state_dict[f'{name}.weight'] = module.weight.data.clone().cpu().half()
            if hasattr(module, 'lora'):
                state_dict[f'{name}.weight'] += (module.lora.B.weight.data @ module.lora.A.weight.data).cpu().half()
    torch.save(state_dict, save_path)
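The merge step above relies on the linearity of the LoRA branch: adding the composed weight `B @ A` to the frozen weight `W` reproduces base-plus-adapter exactly, which is why `merge_lora` can fold the adapter away. A minimal torch-free sketch of that identity, using plain nested lists (all names here are illustrative, not part of the repo):

```python
# Sketch of the LoRA merge identity W' = W + B @ A, with plain lists.

def matmul(X, Y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)]
            for row in X]

def matadd(X, Y):
    return [[a + b for a, b in zip(rx, ry)] for rx, ry in zip(X, Y)]

# Frozen base weight W (2x2), LoRA factors B (2x1) and A (1x2), rank = 1
W = [[1.0, 2.0], [3.0, 4.0]]
B = [[0.5], [1.0]]
A = [[2.0, 0.0]]

x = [[1.0], [1.0]]  # a column input vector

# Path 1: run base and adapter separately, then sum (what forward_with_lora does)
separate = matadd(matmul(W, x), matmul(matmul(B, A), x))

# Path 2: merge the weights first (what merge_lora does), then run once
merged = matmul(matadd(W, matmul(B, A)), x)

assert separate == merged  # both give [[4.0], [9.0]]
```

Because the two paths agree exactly, a merged checkpoint incurs no inference-time overhead for the adapter.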

287
model/model_minimind.py Executable file

@ -0,0 +1,287 @@
import math
import torch
import torch.nn.functional as F
from torch import nn
from transformers.activations import ACT2FN
from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
from transformers.modeling_outputs import MoeCausalLMOutputWithPast

# 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏
# MiniMind Config
# 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏


class MiniMindConfig(PretrainedConfig):
    model_type = "minimind"

    def __init__(self, hidden_size=768, num_hidden_layers=8, use_moe=False, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.use_moe = use_moe
        self.dropout = kwargs.get("dropout", 0.0)
        self.vocab_size = kwargs.get("vocab_size", 6400)
        self.bos_token_id = kwargs.get("bos_token_id", 1)
        self.eos_token_id = kwargs.get("eos_token_id", 2)
        self.flash_attn = kwargs.get("flash_attn", True)
        self.num_attention_heads = kwargs.get("num_attention_heads", 8)
        self.num_key_value_heads = kwargs.get("num_key_value_heads", 4)
        self.head_dim = kwargs.get("head_dim", self.hidden_size // self.num_attention_heads)
        self.hidden_act = kwargs.get("hidden_act", 'silu')
        self.intermediate_size = kwargs.get("intermediate_size", math.ceil(hidden_size * math.pi / 64) * 64)
        self.max_position_embeddings = kwargs.get("max_position_embeddings", 32768)
        self.rms_norm_eps = kwargs.get("rms_norm_eps", 1e-6)
        self.rope_theta = kwargs.get("rope_theta", 1e6)
        self.tie_word_embeddings = kwargs.get("tie_word_embeddings", True)
        self.inference_rope_scaling = kwargs.get("inference_rope_scaling", False)
        self.rope_scaling = {
            "beta_fast": 32,
            "beta_slow": 1,
            "factor": 16,
            "original_max_position_embeddings": 2048,
            "attention_factor": 1.0,
            "type": "yarn"
        } if self.inference_rope_scaling else None
        ### MoE-specific configs (ignored if use_moe = False)
        self.num_experts = kwargs.get("num_experts", 4)
        self.num_experts_per_tok = kwargs.get("num_experts_per_tok", 1)
        self.moe_intermediate_size = kwargs.get("moe_intermediate_size", self.intermediate_size)
        self.norm_topk_prob = kwargs.get("norm_topk_prob", True)
        self.router_aux_loss_coef = kwargs.get("router_aux_loss_coef", 5e-4)


# 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏
# MiniMind Model
# 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return (self.weight * self.norm(x.float())).type_as(x)


def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6, rope_scaling: dict = None):
    freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0
    if rope_scaling is not None:  # YaRN: f'(i) = f(i)((1-γ) + γ/s), where γ∈[0,1] is a linear ramp
        orig_max, factor, beta_fast, beta_slow, attn_factor = (
            rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 16),
            rope_scaling.get("beta_fast", 32.0), rope_scaling.get("beta_slow", 1.0), rope_scaling.get("attention_factor", 1.0)
        )
        if end / orig_max > 1.0:
            inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
            low, high = max(math.floor(inv_dim(beta_fast)), 0), min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
            ramp = torch.clamp((torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001), 0, 1)
            freqs = freqs * (1 - ramp + ramp / factor)
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) * attn_factor
    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) * attn_factor
    return freqs_cos, freqs_sin


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    def rotate_half(x):
        return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)

    q_embed = ((q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))).to(q.dtype)
    k_embed = ((k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))).to(k.dtype)
    return q_embed, k_embed


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    bs, slen, num_key_value_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (x[:, :, :, None, :].expand(bs, slen, num_key_value_heads, n_rep, head_dim).reshape(bs, slen, num_key_value_heads * n_rep, head_dim))


class Attention(nn.Module):
    def __init__(self, config: MiniMindConfig):
        super().__init__()
        self.num_key_value_heads = config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
        self.n_local_heads = config.num_attention_heads
        self.n_local_kv_heads = self.num_key_value_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = config.head_dim
        self.is_causal = True
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.dropout = config.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and config.flash_attn

    def forward(self, x, position_embeddings, past_key_value=None, use_cache=False, attention_mask=None):
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xq, xk = self.q_norm(xq), self.k_norm(xk)
        cos, sin = position_embeddings
        xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0], xk], dim=1)
            xv = torch.cat([past_key_value[1], xv], dim=1)
        past_kv = (xk, xv) if use_cache else None
        xq, xk, xv = (xq.transpose(1, 2), repeat_kv(xk, self.n_rep).transpose(1, 2), repeat_kv(xv, self.n_rep).transpose(1, 2))
        if self.flash and (seq_len > 1) and (not self.is_causal or past_key_value is None) and (attention_mask is None or torch.all(attention_mask == 1)):
            output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=self.is_causal)
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            if self.is_causal:
                scores[:, :, :, -seq_len:] += torch.full((seq_len, seq_len), float("-inf"), device=scores.device).triu(1)
            if attention_mask is not None:
                scores += (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * -1e9
            output = self.attn_dropout(F.softmax(scores.float(), dim=-1).type_as(xq)) @ xv
        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.o_proj(output))
        return output, past_kv


class FeedForward(nn.Module):
    def __init__(self, config: MiniMindConfig, intermediate_size: int = None):
        super().__init__()
        intermediate_size = intermediate_size or config.intermediate_size
        self.gate_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class MOEFeedForward(nn.Module):
    def __init__(self, config: MiniMindConfig):
        super().__init__()
        self.config = config
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.experts = nn.ModuleList([FeedForward(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.num_experts)])
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        batch_size, seq_len, hidden_dim = x.shape
        x_flat = x.view(-1, hidden_dim)
        scores = F.softmax(self.gate(x_flat), dim=-1)
        topk_weight, topk_idx = torch.topk(scores, k=self.config.num_experts_per_tok, dim=-1, sorted=False)
        if self.config.norm_topk_prob:
            topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)
        y = torch.zeros_like(x_flat)
        for i, expert in enumerate(self.experts):
            mask = (topk_idx == i)
            if mask.any():
                token_idx = mask.any(dim=-1).nonzero().flatten()
                weight = topk_weight[mask].view(-1, 1)
                y.index_add_(0, token_idx, (expert(x_flat[token_idx]) * weight).to(y.dtype))
            elif self.training:
                # keep unused experts in the autograd graph via a zero-valued contribution
                y[0, 0] += 0 * sum(p.sum() for p in expert.parameters())
        if self.training and self.config.router_aux_loss_coef > 0:
            load = F.one_hot(topk_idx, self.config.num_experts).float().mean(0)
            self.aux_loss = (load * scores.mean(0)).sum() * self.config.num_experts * self.config.router_aux_loss_coef
        else:
            self.aux_loss = scores.new_zeros(1).squeeze()
        return y.view(batch_size, seq_len, hidden_dim)


class MiniMindBlock(nn.Module):
    def __init__(self, layer_id: int, config: MiniMindConfig):
        super().__init__()
        self.self_attn = Attention(config)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = FeedForward(config) if not config.use_moe else MOEFeedForward(config)

    def forward(self, hidden_states, position_embeddings, past_key_value=None, use_cache=False, attention_mask=None):
        residual = hidden_states
        hidden_states, present_key_value = self.self_attn(
            self.input_layernorm(hidden_states), position_embeddings,
            past_key_value, use_cache, attention_mask
        )
        hidden_states = hidden_states + residual
        hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states, present_key_value


class MiniMindModel(nn.Module):
    def __init__(self, config: MiniMindConfig):
        super().__init__()
        self.config = config
        self.vocab_size, self.num_hidden_layers = config.vocab_size, config.num_hidden_layers
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)
        self.layers = nn.ModuleList([MiniMindBlock(l, config) for l in range(self.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        freqs_cos, freqs_sin = precompute_freqs_cis(dim=config.head_dim, end=config.max_position_embeddings, rope_base=config.rope_theta, rope_scaling=config.rope_scaling)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False, **kwargs):
        batch_size, seq_length = input_ids.shape
        if hasattr(past_key_values, 'layers'):
            past_key_values = None
        past_key_values = past_key_values or [None] * len(self.layers)
        start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
        hidden_states = self.dropout(self.embed_tokens(input_ids))
        # Recompute RoPE buffers lost during meta-device init (transformers>=5.x)
        if self.freqs_cos[0, 0] == 0:
            freqs_cos, freqs_sin = precompute_freqs_cis(dim=self.config.head_dim, end=self.config.max_position_embeddings, rope_base=self.config.rope_theta, rope_scaling=self.config.rope_scaling)
            self.freqs_cos, self.freqs_sin = freqs_cos.to(hidden_states.device), freqs_sin.to(hidden_states.device)
        position_embeddings = (self.freqs_cos[start_pos:start_pos + seq_length], self.freqs_sin[start_pos:start_pos + seq_length])
        presents = []
        for layer, past_key_value in zip(self.layers, past_key_values):
            hidden_states, present = layer(
                hidden_states,
                position_embeddings,
                past_key_value=past_key_value,
                use_cache=use_cache,
                attention_mask=attention_mask
            )
            presents.append(present)
        hidden_states = self.norm(hidden_states)
        aux_loss = sum([l.mlp.aux_loss for l in self.layers if isinstance(l.mlp, MOEFeedForward)], hidden_states.new_zeros(1).squeeze())
        return hidden_states, presents, aux_loss


class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = MiniMindConfig
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config: MiniMindConfig = None):
        self.config = config or MiniMindConfig()
        super().__init__(self.config)
        self.model = MiniMindModel(self.config)
        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
        if self.config.tie_word_embeddings:
            self.model.embed_tokens.weight = self.lm_head.weight
        self.post_init()

    def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False, logits_to_keep=0, labels=None, **kwargs):
        hidden_states, past_key_values, aux_loss = self.model(input_ids, attention_mask, past_key_values, use_cache, **kwargs)
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        loss = None
        if labels is not None:
            x, y = logits[..., :-1, :].contiguous(), labels[..., 1:].contiguous()
            loss = F.cross_entropy(x.view(-1, x.size(-1)), y.view(-1), ignore_index=-100)
        return MoeCausalLMOutputWithPast(loss=loss, aux_loss=aux_loss, logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)

    # https://github.com/jingyaogong/minimind/discussions/611
    @torch.inference_mode()
    def generate(self, inputs=None, attention_mask=None, max_new_tokens=8192, temperature=0.85, top_p=0.85, top_k=50, eos_token_id=2, streamer=None, use_cache=True, num_return_sequences=1, do_sample=True, repetition_penalty=1.0, **kwargs):
        input_ids = kwargs.pop("input_ids", inputs).repeat(num_return_sequences, 1)
        attention_mask = attention_mask.repeat(num_return_sequences, 1) if attention_mask is not None else None
        past_key_values = kwargs.pop("past_key_values", None)
        finished = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
        if streamer:
            streamer.put(input_ids.cpu())
        for _ in range(max_new_tokens):
            past_len = past_key_values[0][0].shape[1] if past_key_values else 0
            outputs = self.forward(input_ids[:, past_len:], attention_mask, past_key_values, use_cache=use_cache, **kwargs)
            attention_mask = torch.cat([attention_mask, attention_mask.new_ones(attention_mask.shape[0], 1)], -1) if attention_mask is not None else None
            logits = outputs.logits[:, -1, :] / temperature
            if repetition_penalty != 1.0:
                for i in range(input_ids.shape[0]):
                    logits[i, torch.unique(input_ids[i])] /= repetition_penalty
            if top_k > 0:
                logits[logits < torch.topk(logits, top_k)[0][..., -1, None]] = -float('inf')
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                mask = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) > top_p
                mask[..., 1:], mask[..., 0] = mask[..., :-1].clone(), 0
                logits[mask.scatter(1, sorted_indices, mask)] = -float('inf')
            next_token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1) if do_sample else torch.argmax(logits, dim=-1, keepdim=True)
            if eos_token_id is not None:
                next_token = torch.where(finished.unsqueeze(-1), next_token.new_full((next_token.shape[0], 1), eos_token_id), next_token)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            past_key_values = outputs.past_key_values if use_cache else None
            if streamer:
                streamer.put(next_token.cpu())
            if eos_token_id is not None:
                finished |= next_token.squeeze(-1).eq(eos_token_id)
                if finished.all():
                    break
        if streamer:
            streamer.end()
        if kwargs.get("return_kv"):
            return {'generated_ids': input_ids, 'past_kv': past_key_values}
        return input_ids
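The top-p filter inside `generate` keeps the smallest set of tokens whose cumulative softmax mass exceeds `top_p` and masks everything else to `-inf` before sampling; note the one-position shift (`mask[..., 1:] = mask[..., :-1]`) that guarantees the highest-probability token always survives. A torch-free sketch of the same cutoff rule on a plain list of logits (names here are illustrative):

```python
import math

def top_p_filter(logits, top_p):
    """Return logits with everything outside the nucleus set to -inf.

    Mirrors the masking rule in MiniMindForCausalLM.generate: sort by logit,
    take the cumulative softmax, and drop a token only if the cumulative mass
    *before* it already exceeds top_p (so the first token always survives).
    """
    order = sorted(range(len(logits)), key=lambda i: -logits[i])
    exps = [math.exp(logits[i]) for i in order]
    z = sum(exps)
    probs = [e / z for e in exps]
    out = list(logits)
    cum = 0.0
    for rank, idx in enumerate(order):
        if rank > 0 and cum > top_p:  # shifted-by-one mask, as in the torch code
            out[idx] = float('-inf')
        cum += probs[rank]
    return out

filtered = top_p_filter([2.0, 1.0, 0.1, -1.0], top_p=0.8)
# the two strongest tokens survive; the tail is masked to -inf
```

With these logits the top token alone carries ~0.64 of the mass, so the nucleus at `top_p=0.8` contains exactly the first two tokens.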

31191
model/tokenizer.json Normal file

File diff suppressed because it is too large

335
model/tokenizer_config.json Normal file

@ -0,0 +1,335 @@
{
"add_bos_token": false,
"add_eos_token": false,
"add_prefix_space": false,
"added_tokens_decoder": {
"0": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"1": {
"content": "<|im_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"2": {
"content": "<|im_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"3": {
"content": "<|object_ref_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"4": {
"content": "<|object_ref_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"5": {
"content": "<|box_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"6": {
"content": "<|box_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"7": {
"content": "<|quad_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"8": {
"content": "<|quad_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"9": {
"content": "<|vision_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"10": {
"content": "<|vision_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"11": {
"content": "<|vision_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"12": {
"content": "<|image_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"13": {
"content": "<|video_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"14": {
"content": "<|audio_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"15": {
"content": "<|audio_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"16": {
"content": "<|audio_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"17": {
"content": "<tts_pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"18": {
"content": "<tts_text_bos>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"19": {
"content": "<tts_text_eod>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"20": {
"content": "<tts_text_bos_single>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"21": {
"content": "<tool_call>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"22": {
"content": "</tool_call>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"23": {
"content": "<tool_response>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"24": {
"content": "</tool_response>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"25": {
"content": "<think>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"26": {
"content": "</think>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"27": {
"content": "<|buffer1|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"28": {
"content": "<|buffer2|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"29": {
"content": "<|buffer3|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"30": {
"content": "<|buffer4|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"31": {
"content": "<|buffer5|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"32": {
"content": "<|buffer6|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"33": {
"content": "<|buffer7|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"34": {
"content": "<|buffer8|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"35": {
"content": "<|buffer9|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
}
},
"additional_special_tokens": [
"<|im_start|>",
"<|im_end|>",
"<|object_ref_start|>",
"<|object_ref_end|>",
"<|box_start|>",
"<|box_end|>",
"<|quad_start|>",
"<|quad_end|>",
"<|vision_start|>",
"<|vision_end|>",
"<|vision_pad|>",
"<|image_pad|>",
"<|video_pad|>",
"<|audio_start|>",
"<|audio_end|>",
"<|audio_pad|>",
"<tts_pad>",
"<tts_text_bos>",
"<tts_text_eod>",
"<tts_text_bos_single>"
],
"bos_token": "<|im_start|>",
"clean_up_tokenization_spaces": false,
"eos_token": "<|im_end|>",
"legacy": true,
"model_max_length": 131072,
"pad_token": "<|endoftext|>",
"sp_model_kwargs": {},
"spaces_between_special_tokens": false,
"unk_token": "<|endoftext|>",
"image_token": "<|image_pad|>",
"audio_token": "<|audio_pad|>",
"video_token": "<|video_pad|>",
"vision_bos_token": "<|vision_start|>",
"vision_eos_token": "<|vision_end|>",
"audio_bos_token": "<|audio_start|>",
"audio_eos_token": "<|audio_end|>",
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = 
content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if true %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if open_thinking is defined and open_thinking is true %}\n {{- '<think>\\n' }}\n {%- else %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}",
"tokenizer_class": "PreTrainedTokenizerFast"
}
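The `chat_template` above renders conversations in ChatML style: every turn is wrapped as `<|im_start|>role\n...<|im_end|>\n`, and the generation prompt opens an assistant turn whose `<think>` block is left open when `open_thinking` is true, or emitted empty otherwise. A simplified torch-free sketch of that wrapping (it ignores tools and reasoning-content splitting, so treat it as an approximation of the full Jinja template, not a replacement for `apply_chat_template`):

```python
def render_chatml(messages, add_generation_prompt=True, open_thinking=False):
    """Approximate the tokenizer's ChatML chat_template for plain text turns."""
    out = []
    for m in messages:
        # each turn: <|im_start|>role\ncontent<|im_end|>\n
        out.append(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n")
    if add_generation_prompt:
        out.append("<|im_start|>assistant\n")
        # open_thinking=True leaves the <think> block open for the model to fill;
        # otherwise an empty <think></think> pair disables reasoning output
        out.append("<think>\n" if open_thinking else "<think>\n\n</think>\n\n")
    return "".join(out)

prompt = render_chatml([{"role": "user", "content": "hi"}])
```

This is the same `open_thinking` switch that `scripts/chat_api.py` toggles through `chat_template_kwargs`.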


@ -1,2 +1,32 @@
mkdocs>=1.5.0
mkdocs-material>=9.0.0
datasets==3.6.0
datasketch==1.6.4
Flask==3.0.3
Flask_Cors==4.0.0
jieba==0.42.1
jsonlines==4.0.0
marshmallow==3.22.0
# matplotlib==3.10.0
ngrok==1.4.0
nltk==3.8
numpy==1.26.4
openai==1.59.6
# peft==0.7.1
psutil==5.9.8
pydantic==2.11.5
rich==13.7.1
scikit_learn==1.5.1
sentence_transformers==2.3.1
simhash==2.1.2
tiktoken==0.10.0
transformers==4.57.6
jinja2==3.1.2
trl==0.13.0
ujson==5.1.0
wandb==0.18.3
streamlit==1.50.0
einops==0.8.1
swanlab==0.7.11
modelscope==1.30.0
# torch==2.6.0
# torchvision==0.21.0

40
scripts/chat_api.py Normal file

@ -0,0 +1,40 @@
from openai import OpenAI

client = OpenAI(
    api_key="sk-123",
    base_url="http://localhost:11434/v1"
)

stream = True
conversation_history_origin = []
conversation_history = conversation_history_origin.copy()
history_messages_num = 0  # must be even (whole Q+A pairs); 0 sends no conversation history

while True:
    query = input('[Q]: ')
    conversation_history.append({"role": "user", "content": query})
    response = client.chat.completions.create(
        model="minimind-local:latest",
        messages=conversation_history[-(history_messages_num or 1):],
        stream=stream,
        temperature=0.8,
        max_tokens=2048,
        top_p=0.8,
        extra_body={"chat_template_kwargs": {"open_thinking": True}, "reasoning_effort": "medium"}  # thinking switch
    )
    if not stream:
        assistant_res = response.choices[0].message.content
        print('[A]: ', assistant_res)
    else:
        print('[A]: ', end='', flush=True)
        assistant_res = ''
        for chunk in response:
            delta = chunk.choices[0].delta
            r = getattr(delta, 'reasoning_content', None) or ""
            c = delta.content or ""
            if r:
                print(f'\033[90m{r}\033[0m', end="", flush=True)
            if c:
                print(c, end="", flush=True)
                assistant_res += c
    conversation_history.append({"role": "assistant", "content": assistant_res})
    print('\n\n')
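The history window `conversation_history[-(history_messages_num or 1):]` is easy to misread: when `history_messages_num` is 0 the `or 1` kicks in and only the just-appended user message is sent, while a positive value keeps that many trailing messages (earlier Q+A turns plus the new question). A quick standalone check of the slicing, with made-up message contents:

```python
# Illustrative history: two completed Q+A pairs plus a freshly appended question
history = [
    {"role": "user", "content": "q1"},
    {"role": "assistant", "content": "a1"},
    {"role": "user", "content": "q2"},
    {"role": "assistant", "content": "a2"},
    {"role": "user", "content": "q3"},  # the just-appended new question
]

def window(history, history_messages_num):
    # Same expression as in scripts/chat_api.py
    return history[-(history_messages_num or 1):]

no_history = window(history, 0)    # only the new question is sent
with_history = window(history, 4)  # the last four messages are sent
```

So the parameter bounds the request size regardless of how long the session runs.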

144
scripts/convert_model.py Normal file

@ -0,0 +1,144 @@
import os
import sys
import json
__package__ = "scripts"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import torch
import transformers
import warnings
from transformers import AutoTokenizer, AutoModelForCausalLM, Qwen3Config, Qwen3ForCausalLM, Qwen3MoeConfig, Qwen3MoeForCausalLM
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from model.model_lora import apply_lora, merge_lora
warnings.filterwarnings('ignore', category=UserWarning)
def convert_torch2transformers_minimind(torch_path, transformers_path, dtype=torch.float16):
MiniMindConfig.register_for_auto_class()
MiniMindForCausalLM.register_for_auto_class("AutoModelForCausalLM")
lm_model = MiniMindForCausalLM(lm_config)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
state_dict = torch.load(torch_path, map_location=device)
lm_model.load_state_dict(state_dict, strict=False)
lm_model = lm_model.to(dtype) # 转换模型权重精度
model_params = sum(p.numel() for p in lm_model.parameters() if p.requires_grad)
print(f'模型参数: {model_params / 1e6} 百万 = {model_params / 1e9} B (Billion)')
lm_model.save_pretrained(transformers_path, safe_serialization=False)
tokenizer = AutoTokenizer.from_pretrained('../model/')
tokenizer.save_pretrained(transformers_path)
# ======= transformers-5.0的兼容低版本写法 =======
if int(transformers.__version__.split('.')[0]) >= 5:
tokenizer_config_path, config_path = os.path.join(transformers_path, "tokenizer_config.json"), os.path.join(transformers_path, "config.json")
json.dump({**json.load(open(tokenizer_config_path, 'r', encoding='utf-8')), "tokenizer_class": "PreTrainedTokenizerFast", "extra_special_tokens": {}}, open(tokenizer_config_path, 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
config = json.load(open(config_path, 'r', encoding='utf-8'))
config['rope_theta'] = lm_config.rope_theta; config['rope_scaling'] = None; del config['rope_parameters']
json.dump(config, open(config_path, 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
print(f"模型已保存为 Transformers-MiniMind 格式: {transformers_path}")
# QwenForCausalLM/LlamaForCausalLM结构兼容生态
def convert_torch2transformers(torch_path, transformers_path, dtype=torch.float16):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
state_dict = torch.load(torch_path, map_location=device)
common_config = {
"vocab_size": lm_config.vocab_size,
"hidden_size": lm_config.hidden_size,
"intermediate_size": lm_config.intermediate_size,
"num_hidden_layers": lm_config.num_hidden_layers,
"num_attention_heads": lm_config.num_attention_heads,
"num_key_value_heads": lm_config.num_key_value_heads,
"head_dim": lm_config.hidden_size // lm_config.num_attention_heads,
"max_position_embeddings": lm_config.max_position_embeddings,
"rms_norm_eps": lm_config.rms_norm_eps,
"rope_theta": lm_config.rope_theta,
"tie_word_embeddings": lm_config.tie_word_embeddings
}
if not lm_config.use_moe:
qwen_config = Qwen3Config(
**common_config,
use_sliding_window=False,
sliding_window=None
)
qwen_model = Qwen3ForCausalLM(qwen_config)
else:
qwen_config = Qwen3MoeConfig(
**common_config,
num_experts=lm_config.num_experts,
num_experts_per_tok=lm_config.num_experts_per_tok,
moe_intermediate_size=lm_config.moe_intermediate_size,
norm_topk_prob=lm_config.norm_topk_prob
)
qwen_model = Qwen3MoeForCausalLM(qwen_config)
# ======= Compatibility handling for transformers >= 5.0 =======
if int(transformers.__version__.split('.')[0]) >= 5:
new_sd = {k: v for k, v in state_dict.items() if 'experts.' not in k or 'gate.weight' in k}  # keep non-expert weights and the router gate
for l in range(lm_config.num_hidden_layers):
p = f'model.layers.{l}.mlp.experts'
gate = torch.stack([state_dict[f'{p}.{e}.gate_proj.weight'] for e in range(lm_config.num_experts)])
up = torch.stack([state_dict[f'{p}.{e}.up_proj.weight'] for e in range(lm_config.num_experts)])
new_sd[f'{p}.gate_up_proj'] = torch.cat([gate, up], dim=1)  # (num_experts, 2*moe_intermediate_size, hidden_size)
new_sd[f'{p}.down_proj'] = torch.stack([state_dict[f'{p}.{e}.down_proj.weight'] for e in range(lm_config.num_experts)])
state_dict = new_sd
qwen_model.load_state_dict(state_dict, strict=True)
qwen_model = qwen_model.to(dtype)  # cast model weights to the target precision
qwen_model.save_pretrained(transformers_path)
model_params = sum(p.numel() for p in qwen_model.parameters() if p.requires_grad)
print(f'Model parameters: {model_params / 1e6} million = {model_params / 1e9} B (Billion)')
tokenizer = AutoTokenizer.from_pretrained('../model/')
tokenizer.save_pretrained(transformers_path)
# ======= Compatibility handling for transformers >= 5.0 =======
if int(transformers.__version__.split('.')[0]) >= 5:
tokenizer_config_path, config_path = os.path.join(transformers_path, "tokenizer_config.json"), os.path.join(transformers_path, "config.json")
json.dump({**json.load(open(tokenizer_config_path, 'r', encoding='utf-8')), "tokenizer_class": "PreTrainedTokenizerFast", "extra_special_tokens": {}}, open(tokenizer_config_path, 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
config = json.load(open(config_path, 'r', encoding='utf-8'))
config['rope_theta'] = lm_config.rope_theta; config['rope_scaling'] = None; config.pop('rope_parameters', None)
json.dump(config, open(config_path, 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
print(f"Model saved in Transformers format: {transformers_path}")
def convert_transformers2torch(transformers_path, torch_path):
model = AutoModelForCausalLM.from_pretrained(transformers_path, trust_remote_code=True)
torch.save({k: v.cpu().half() for k, v in model.state_dict().items()}, torch_path)
print(f"Model saved in PyTorch format: {torch_path}")
def convert_merge_base_lora(base_torch_path, lora_path, merged_torch_path):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lm_model = MiniMindForCausalLM(lm_config).to(device)
state_dict = torch.load(base_torch_path, map_location=device)
lm_model.load_state_dict(state_dict, strict=False)
apply_lora(lm_model)
merge_lora(lm_model, lora_path, merged_torch_path)
print(f"LoRA merged and saved in base-model PyTorch format: {merged_torch_path}")
def convert_jinja_to_json(jinja_path):
with open(jinja_path, 'r') as f: template = f.read()
escaped = json.dumps(template)
print(f'"chat_template": {escaped}')
def convert_json_to_jinja(json_file_path, output_path):
with open(json_file_path, 'r') as f: config = json.load(f)
template = config['chat_template']
with open(output_path, 'w') as f: f.write(template)
print(f"Template saved as jinja file: {output_path}")
if __name__ == '__main__':
lm_config = MiniMindConfig(hidden_size=768, num_hidden_layers=8, max_seq_len=8192, use_moe=False)
# convert torch to transformers
torch_path = f"../out/full_sft_{lm_config.hidden_size}{'_moe' if lm_config.use_moe else ''}.pth"
transformers_path = '../minimind-3'
convert_torch2transformers(torch_path, transformers_path)
# # merge lora
# base_torch_path = f"../out/full_sft_{lm_config.hidden_size}{'_moe' if lm_config.use_moe else ''}.pth"
# lora_path = f"../out/lora_identity_{lm_config.hidden_size}{'_moe' if lm_config.use_moe else ''}.pth"
# merged_torch_path = f"../out/merge_identity_{lm_config.hidden_size}{'_moe' if lm_config.use_moe else ''}.pth"
# convert_merge_base_lora(base_torch_path, lora_path, merged_torch_path)
# convert_transformers2torch(transformers_path, torch_path)
# convert_json_to_jinja('../model/tokenizer_config.json', '../model/chat_template.jinja')
# convert_jinja_to_json('../model/chat_template.jinja')
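The two chat-template helpers above are essentially a JSON escape/unescape round trip. A minimal sketch (with a toy template standing in for the real `../model/chat_template.jinja`) shows both directions without touching any files:

```python
import json

# Toy stand-in for the real template in ../model/chat_template.jinja
template = "{% for m in messages %}<|im_start|>{{ m.role }}\n{{ m.content }}<|im_end|>\n{% endfor %}"

# convert_jinja_to_json direction: escape into a JSON string literal
# suitable for pasting into tokenizer_config.json
escaped = json.dumps(template)

# convert_json_to_jinja direction: read it back out of a config dict
config = json.loads('{"chat_template": ' + escaped + '}')
assert config["chat_template"] == template
```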

scripts/eval_toolcall.py Normal file
@ -0,0 +1,240 @@
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import re
import json
import time
import random
import argparse
import warnings
import torch
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from openai import OpenAI
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from trainer.trainer_utils import setup_seed, get_model_params
warnings.filterwarnings('ignore')
TOOLS = [
{"type": "function", "function": {"name": "calculate_math", "description": "计算数学表达式的结果,支持加减乘除、幂运算、开方等", "parameters": {"type": "object", "properties": {"expression": {"type": "string", "description": "数学表达式如123+456、2**10、sqrt(144)"}}, "required": ["expression"]}}},
{"type": "function", "function": {"name": "get_current_time", "description": "获取当前日期和时间,支持指定时区", "parameters": {"type": "object", "properties": {"timezone": {"type": "string", "description": "时区名称如Asia/Shanghai、America/New_York", "default": "Asia/Shanghai"}}, "required": []}}},
{"type": "function", "function": {"name": "random_number", "description": "生成指定范围内的随机数", "parameters": {"type": "object", "properties": {"min": {"type": "integer", "description": "最小值", "default": 0}, "max": {"type": "integer", "description": "最大值", "default": 100}}, "required": []}}},
{"type": "function", "function": {"name": "text_length", "description": "计算文本的字符数和单词数", "parameters": {"type": "object", "properties": {"text": {"type": "string", "description": "要统计的文本"}}, "required": ["text"]}}},
{"type": "function", "function": {"name": "unit_converter", "description": "进行单位换算,支持长度、重量、温度等", "parameters": {"type": "object", "properties": {"value": {"type": "number", "description": "要转换的数值"}, "from_unit": {"type": "string", "description": "源单位如km、miles、kg、pounds、celsius、fahrenheit"}, "to_unit": {"type": "string", "description": "目标单位"}}, "required": ["value", "from_unit", "to_unit"]}}},
{"type": "function", "function": {"name": "get_current_weather", "description": "获取指定城市的当前天气信息,包括温度、湿度和天气状况", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "城市名称如北京、上海、New York"}, "unit": {"type": "string", "description": "温度单位celsius或fahrenheit", "enum": ["celsius", "fahrenheit"], "default": "celsius"}}, "required": ["location"]}}},
{"type": "function", "function": {"name": "get_exchange_rate", "description": "查询两种货币之间的实时汇率", "parameters": {"type": "object", "properties": {"from_currency": {"type": "string", "description": "源货币代码如USD、CNY、EUR"}, "to_currency": {"type": "string", "description": "目标货币代码如USD、CNY、EUR"}}, "required": ["from_currency", "to_currency"]}}},
{"type": "function", "function": {"name": "translate_text", "description": "将文本翻译成目标语言", "parameters": {"type": "object", "properties": {"text": {"type": "string", "description": "要翻译的文本"}, "target_language": {"type": "string", "description": "目标语言如english、chinese、japanese、french"}}, "required": ["text", "target_language"]}}},
]
MOCK_RESULTS = {
"calculate_math": lambda args: {"result": str(eval(str(args.get("expression", "0")).replace("^", "**").replace("×", "*").replace("÷", "/").replace("−", "-").replace("²", "**2").replace("³", "**3").replace("(", "(").replace(")", ")")))},  # eval of model-supplied text: acceptable for a local mock, unsafe for untrusted input
"get_current_time": lambda args: {"datetime": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "timezone": args.get("timezone", "Asia/Shanghai")},
"random_number": lambda args: {"result": random.randint(int(args.get("min", 0)), int(args.get("max", 100)))},
"text_length": lambda args: {"characters": len(args.get("text", "")), "words": len(args.get("text", "").split())},
"unit_converter": lambda args: {"result": round(float(args.get("value", 0)) * 0.621371, 2), "from": f"{args.get('value', 0)} {args.get('from_unit', '')}", "to": args.get("to_unit", "")},
"get_current_weather": lambda args: {"city": args.get("location"), "temperature": "22°C", "humidity": "65%", "condition": "sunny"},
"get_exchange_rate": lambda args: {"from": args.get("from_currency", ""), "to": args.get("to_currency", ""), "rate": 7.15},
"translate_text": lambda args: {"translated": "hello world"},
}
TOOL_MAP = {t["function"]["name"]: t for t in TOOLS}
def get_tools(names):
return [TOOL_MAP[n] for n in names]
TEST_CASES = [
{"prompt": "帮我算一下 256 乘以 37 等于多少", "tools": ["calculate_math", "get_current_time"]},
{"prompt": "现在几点了?", "tools": ["get_current_time", "random_number"]},
{"prompt": "帮我把100公里换算成英里", "tools": ["unit_converter", "calculate_math"]},
{"prompt": "帮我生成一个1到1000的随机数然后计算它的平方", "tools": ["random_number", "calculate_math", "text_length"]},
{"prompt": "北京今天天气怎么样?", "tools": ["get_current_weather", "get_current_time"]},
{"prompt": "查一下美元兑人民币汇率", "tools": ["get_exchange_rate", "get_current_time"]},
{"prompt": "'你好世界'翻译成英文", "tools": ["translate_text", "text_length"]},
{"prompt": "What is the weather in Tokyo? Also convert 30 celsius to fahrenheit.", "tools": ["get_current_weather", "unit_converter", "get_current_time"]},
]
def init_model(args):
tokenizer = AutoTokenizer.from_pretrained(args.load_from)
if 'model' in args.load_from:
model = MiniMindForCausalLM(MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe)))
moe_suffix = '_moe' if args.use_moe else ''
ckp = f'./{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
model.load_state_dict(torch.load(ckp, map_location=args.device), strict=True)
else:
model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
get_model_params(model, model.config)
return model.half().eval().to(args.device), tokenizer
def parse_tool_calls(text):
matches = re.findall(r'<tool_call>(.*?)</tool_call>', text, re.DOTALL)
calls = []
for m in matches:
try:
calls.append(json.loads(m.strip()))
except Exception:
pass
return calls
def parse_tool_call_from_text(content):
pattern = r'<tool_call>\s*(\{.*?\})\s*</tool_call>'
matches = re.findall(pattern, content, re.DOTALL)
if not matches:
return None
tool_calls = []
for i, match in enumerate(matches):
try:
data = json.loads(match)
tool_calls.append({
"id": f"call_{i}",
"function": {"name": data.get("name", ""), "arguments": json.dumps(data.get("arguments", {}), ensure_ascii=False)}
})
except Exception:
pass
return tool_calls if tool_calls else None
def execute_tool(call, arguments=None):
name = call.get("name", "") if isinstance(call, dict) else call
try:
raw_args = call.get("arguments", {}) if isinstance(call, dict) else arguments
args = json.loads(raw_args) if isinstance(raw_args, str) else raw_args
except Exception:
args = {}
fn = MOCK_RESULTS.get(name)
if not fn:
return {"error": f"未知工具: {name}"}
try:
return fn(args)
except Exception as e:
return {"error": f"工具执行失败: {str(e)[:80]}"}
def generate(model, tokenizer, messages, tools, args):
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=tools, open_thinking=False)
inputs = tokenizer(input_text, return_tensors="pt", truncation=True).to(args.device)
st = time.time()
print('🧠: ', end='')
generated_ids = model.generate(
inputs["input_ids"], attention_mask=inputs["attention_mask"],
max_new_tokens=args.max_new_tokens, do_sample=True, streamer=streamer,
pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id,
top_p=args.top_p, temperature=args.temperature
)
response = tokenizer.decode(generated_ids[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
gen_tokens = len(generated_ids[0]) - len(inputs["input_ids"][0])
print(f'\n[Speed]: {gen_tokens / (time.time() - st):.2f} tokens/s') if args.show_speed else print()
return response
def chat_api(client, messages, tools, args, stream=True):
response = client.chat.completions.create(
model=args.api_model, messages=messages, tools=tools,
stream=stream, temperature=args.temperature,
max_tokens=8192, top_p=args.top_p
)
if not stream:
choice = response.choices[0]
content = choice.message.content or ""
tool_calls = choice.message.tool_calls
if not tool_calls:
tool_calls = parse_tool_call_from_text(content)
print(f'🧠: {content}')
return content, tool_calls
print('🧠: ', end='', flush=True)
content, tool_calls = "", None
for chunk in response:
delta = chunk.choices[0].delta
if delta.content:
print(delta.content, end="", flush=True)
content += delta.content
if delta.tool_calls:
if tool_calls is None:
tool_calls = []
for tc_chunk in delta.tool_calls:
idx = tc_chunk.index if tc_chunk.index is not None else len(tool_calls)
while len(tool_calls) <= idx:
tool_calls.append({
"id": "",
"function": {"name": "", "arguments": ""}
})
if tc_chunk.id:
tool_calls[idx]["id"] += tc_chunk.id
if tc_chunk.function:
if tc_chunk.function.name:
tool_calls[idx]["function"]["name"] += tc_chunk.function.name
if tc_chunk.function.arguments:
tool_calls[idx]["function"]["arguments"] += tc_chunk.function.arguments
print()
if not tool_calls:
tool_calls = parse_tool_call_from_text(content)
return content, tool_calls
def run_case(prompt, tools, args, model=None, tokenizer=None, client=None):
messages = [{"role": "user", "content": prompt}]
while True:
if args.backend == 'local':
content = generate(model, tokenizer, messages, tools, args)
tool_calls = parse_tool_calls(content)
else:
content, tool_calls = chat_api(client, messages, tools, args, stream=bool(args.stream))
if not tool_calls:
break
tool_calls = [{
"id": tc.id if hasattr(tc, 'id') else tc.get("id", ""),
"name": tc.function.name if hasattr(tc, 'function') else tc["function"]["name"],
"arguments": tc.function.arguments if hasattr(tc, 'function') else tc["function"]["arguments"]
} for tc in tool_calls] if args.backend == 'api' else tool_calls
messages.append({"role": "assistant", "content": content} if args.backend == 'local' else {"role": "assistant", "content": content, "tool_calls": [{"id": tc["id"], "type": "function", "function": {"name": tc["name"], "arguments": tc["arguments"]}} for tc in tool_calls]})
for tc in tool_calls:
name = tc["name"]
arguments = tc["arguments"]
print(f'📞 [Tool Calling]: {name} | args={arguments}')
result = execute_tool(tc if args.backend == 'local' else name, arguments)
print(f'✅ [Tool Called]: {json.dumps(result, ensure_ascii=False)}')
messages.append({"role": "tool", "content": json.dumps(result, ensure_ascii=False)} if args.backend == 'local' else {"role": "tool", "content": json.dumps(result, ensure_ascii=False), "tool_call_id": tc["id"]})
def main():
parser = argparse.ArgumentParser(description="MiniMind ToolCall evaluation")
parser.add_argument('--backend', default='local', choices=['local', 'api'], type=str, help="inference backend (local=local model, api=OpenAI-compatible API)")
parser.add_argument('--load_from', default='../model', type=str, help="model load path (model=native torch weights, other path=transformers format)")
parser.add_argument('--save_dir', default='../out', type=str, help="model weights directory")
parser.add_argument('--weight', default='full_sft', type=str, help="weight name prefix (pretrain, full_sft, rlhf, reason, ppo_actor, grpo, spo)")
parser.add_argument('--hidden_size', default=768, type=int, help="hidden size")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="number of hidden layers")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="use MoE architecture (0=no, 1=yes)")
parser.add_argument('--max_new_tokens', default=512, type=int, help="maximum generation length")
parser.add_argument('--temperature', default=0.9, type=float, help="sampling temperature (0-1; higher is more random)")
parser.add_argument('--top_p', default=0.9, type=float, help="nucleus sampling threshold (0-1)")
parser.add_argument('--show_speed', default=0, type=int, help="show decode speed (tokens/s)")
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="device to run on")
parser.add_argument('--api_base_url', default="http://localhost:11434/v1", type=str, help="base_url of the OpenAI-compatible API")
parser.add_argument('--api_key', default='sk-123', type=str, help="api_key for the OpenAI-compatible API")
parser.add_argument('--api_model', default='jingyaogong/minimind-3:latest', type=str, help="model name used in API requests")
parser.add_argument('--stream', default=1, type=int, help="stream output in API mode (0=no, 1=yes)")
args = parser.parse_args()
model = tokenizer = client = None
if args.backend == 'local': model, tokenizer = init_model(args)
else: client = OpenAI(api_key=args.api_key, base_url=args.api_base_url)
input_mode = int(input('[0] Auto test\n[1] Manual input\n'))
cases = [{"prompt": case["prompt"], "tools": get_tools(case["tools"]), "tool_names": case["tools"]} for case in TEST_CASES] if input_mode == 0 else iter(lambda: {"prompt": input('💬: '), "tools": TOOLS, "tool_names": [t["function"]["name"] for t in TOOLS]}, {"prompt": "", "tools": TOOLS, "tool_names": []})
for case in cases:
if not case["prompt"]: break
setup_seed(random.randint(0, 31415926))
if input_mode == 0:
print(f'📦 Available tools: {case["tool_names"]}\n')
print(f'💬: {case["prompt"]}')
run_case(case["prompt"], case["tools"], args, model=model, tokenizer=tokenizer, client=client)
print('\n' + '-' * 50 + '\n')
if __name__ == "__main__":
main()
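The `<tool_call>` protocol the script expects from the model can be exercised without loading any weights. This sketch duplicates `parse_tool_calls` from above and feeds it a hand-written completion of the kind `run_case` consumes:

```python
import re
import json

def parse_tool_calls(text):
    # Same logic as the script: grab JSON payloads between <tool_call> tags,
    # silently skipping any that fail to parse
    calls = []
    for m in re.findall(r'<tool_call>(.*?)</tool_call>', text, re.DOTALL):
        try:
            calls.append(json.loads(m.strip()))
        except Exception:
            pass
    return calls

# Hand-written completion mimicking the model's tool-call format
sample = ('好的,我来计算。\n'
          '<tool_call>\n{"name": "calculate_math", "arguments": {"expression": "256*37"}}\n</tool_call>')
calls = parse_tool_calls(sample)
assert calls == [{"name": "calculate_math", "arguments": {"expression": "256*37"}}]
```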

scripts/serve_openai_api.py Normal file
@ -0,0 +1,245 @@
import argparse
import json
import re
import os
import sys
__package__ = "scripts"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import time
import torch
import warnings
import uvicorn
from threading import Thread
from queue import Queue
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from model.model_lora import apply_lora, load_lora
warnings.filterwarnings('ignore')
app = FastAPI()
def init_model(args):
tokenizer = AutoTokenizer.from_pretrained(args.load_from)
if 'model' in args.load_from:
moe_suffix = '_moe' if args.use_moe else ''
ckp = f'../{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
model = MiniMindForCausalLM(MiniMindConfig(
hidden_size=args.hidden_size,
num_hidden_layers=args.num_hidden_layers,
max_seq_len=args.max_seq_len,
use_moe=bool(args.use_moe),
inference_rope_scaling=args.inference_rope_scaling
))
model.load_state_dict(torch.load(ckp, map_location=device), strict=True)
if args.lora_weight != 'None':
apply_lora(model)
load_lora(model, f'../{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth')
else:
model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
print(f'MiniMind model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M(illion)')
return model.half().eval().to(device), tokenizer
class ChatRequest(BaseModel):
model: str
messages: list
temperature: float = 0.7
top_p: float = 0.92
max_tokens: int = 8192
stream: bool = True
tools: list = []
open_thinking: bool = False
chat_template_kwargs: dict = None
def get_open_thinking(self) -> bool:
"""Support multiple ways of enabling thinking mode."""
if self.open_thinking:
return True
if self.chat_template_kwargs:
return self.chat_template_kwargs.get('open_thinking', False) or \
self.chat_template_kwargs.get('enable_thinking', False)
return False
class CustomStreamer(TextStreamer):
def __init__(self, tokenizer, queue):
super().__init__(tokenizer, skip_prompt=True, skip_special_tokens=True)
self.queue = queue
self.tokenizer = tokenizer
def on_finalized_text(self, text: str, stream_end: bool = False):
self.queue.put(text)
if stream_end:
self.queue.put(None)
def parse_response(text):
reasoning_content = None
think_match = re.search(r'<think>(.*?)</think>', text, re.DOTALL)
if think_match:
reasoning_content = think_match.group(1).strip()
text = re.sub(r'<think>.*?</think>\s*', '', text, flags=re.DOTALL)
elif '</think>' in text:
parts = text.split('</think>', 1)
reasoning_content = parts[0].strip()
text = parts[1].strip() if len(parts) > 1 else ''
tool_calls = []
for i, m in enumerate(re.findall(r'<tool_call>(.*?)</tool_call>', text, re.DOTALL)):
try:
call = json.loads(m.strip())
tool_calls.append({"id": f"call_{int(time.time())}_{i}", "type": "function", "function": {"name": call.get("name", ""), "arguments": json.dumps(call.get("arguments", {}), ensure_ascii=False)}})
except Exception:
pass
if tool_calls:
text = re.sub(r'<tool_call>.*?</tool_call>', '', text, flags=re.DOTALL)
return text.strip(), reasoning_content, tool_calls or None
def generate_stream_response(messages, temperature, top_p, max_tokens, tools=None, open_thinking=False):
try:
new_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=tools or None, open_thinking=open_thinking)[-max_tokens:]
inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)
queue = Queue()
streamer = CustomStreamer(tokenizer, queue)
def _generate():
model.generate(
inputs.input_ids,
max_new_tokens=max_tokens,
do_sample=True,
temperature=temperature,
top_p=top_p,
attention_mask=inputs.attention_mask,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
streamer=streamer
)
Thread(target=_generate).start()
full_text = ""
emitted = 0
thinking_ended = not bool(open_thinking)
while True:
text = queue.get()
if text is None:
break
full_text += text
if not thinking_ended:
pos = full_text.find('</think>')
if pos >= 0:
thinking_ended = True
new_r = full_text[emitted:pos]
if new_r:
yield json.dumps({"choices": [{"delta": {"reasoning_content": new_r}}]}, ensure_ascii=False)
emitted = pos + len('</think>')
after = full_text[emitted:].lstrip('\n')
emitted = len(full_text) - len(after)
if after:
yield json.dumps({"choices": [{"delta": {"content": after}}]}, ensure_ascii=False)
emitted = len(full_text)
else:
new_r = full_text[emitted:]
if new_r:
yield json.dumps({"choices": [{"delta": {"reasoning_content": new_r}}]}, ensure_ascii=False)
emitted = len(full_text)
else:
new_c = full_text[emitted:]
if new_c:
yield json.dumps({"choices": [{"delta": {"content": new_c}}]}, ensure_ascii=False)
emitted = len(full_text)
_, _, tool_calls = parse_response(full_text)
if tool_calls:
yield json.dumps({"choices": [{"delta": {"tool_calls": tool_calls}}]}, ensure_ascii=False)
yield json.dumps({"choices": [{"delta": {}, "finish_reason": "tool_calls" if tool_calls else "stop"}]}, ensure_ascii=False)
except Exception as e:
yield json.dumps({"error": str(e)})
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatRequest):
try:
if request.stream:
return StreamingResponse(
(f"data: {chunk}\n\n" for chunk in generate_stream_response(
messages=request.messages,
temperature=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens,
tools=request.tools,
open_thinking=request.get_open_thinking()
)),
media_type="text/event-stream"
)
else:
new_prompt = tokenizer.apply_chat_template(
request.messages,
tokenize=False,
add_generation_prompt=True,
tools=request.tools or None,
open_thinking=request.get_open_thinking()
)[-request.max_tokens:]
inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)
with torch.no_grad():
generated_ids = model.generate(
inputs["input_ids"],
max_length=inputs["input_ids"].shape[1] + request.max_tokens,
do_sample=True,
attention_mask=inputs["attention_mask"],
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
top_p=request.top_p,
temperature=request.temperature
)
answer = tokenizer.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
content, reasoning_content, tool_calls = parse_response(answer)
message = {"role": "assistant", "content": content}
if reasoning_content:
message["reasoning_content"] = reasoning_content
if tool_calls:
message["tool_calls"] = tool_calls
return {
"id": f"chatcmpl-{int(time.time())}",
"object": "chat.completion",
"created": int(time.time()),
"model": "minimind",
"choices": [
{
"index": 0,
"message": message,
"finish_reason": "tool_calls" if tool_calls else "stop"
}
]
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Server for MiniMind")
parser.add_argument('--load_from', default='../model', type=str, help="model load path (model=native torch weights, other path=transformers format)")
parser.add_argument('--save_dir', default='out', type=str, help="model weights directory")
parser.add_argument('--weight', default='full_sft', type=str, help="weight name prefix (pretrain, full_sft, dpo, reason, ppo_actor, grpo, spo)")
parser.add_argument('--lora_weight', default='None', type=str, help="LoRA weight name (None disables; options: lora_identity, lora_medical)")
parser.add_argument('--hidden_size', default=768, type=int, help="hidden size")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="number of hidden layers")
parser.add_argument('--max_seq_len', default=8192, type=int, help="maximum sequence length")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="use MoE architecture (0=no, 1=yes)")
parser.add_argument('--inference_rope_scaling', default=False, action='store_true', help="enable RoPE extrapolation (4x; only addresses positional encoding)")
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="device to run on")
args = parser.parse_args()
device = args.device
model, tokenizer = init_model(args)
uvicorn.run(app, host="0.0.0.0", port=8998)
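The server's `parse_response` can likewise be tested offline. This copy of the function (taken verbatim from the file above) is run on a hand-written completion containing both a `<think>` block and a tool call:

```python
import re
import json
import time

def parse_response(text):
    # Split out <think> reasoning and <tool_call> payloads, as in serve_openai_api.py
    reasoning_content = None
    think_match = re.search(r'<think>(.*?)</think>', text, re.DOTALL)
    if think_match:
        reasoning_content = think_match.group(1).strip()
        text = re.sub(r'<think>.*?</think>\s*', '', text, flags=re.DOTALL)
    elif '</think>' in text:
        parts = text.split('</think>', 1)
        reasoning_content = parts[0].strip()
        text = parts[1].strip() if len(parts) > 1 else ''
    tool_calls = []
    for i, m in enumerate(re.findall(r'<tool_call>(.*?)</tool_call>', text, re.DOTALL)):
        try:
            call = json.loads(m.strip())
            tool_calls.append({"id": f"call_{int(time.time())}_{i}", "type": "function",
                               "function": {"name": call.get("name", ""),
                                            "arguments": json.dumps(call.get("arguments", {}), ensure_ascii=False)}})
        except Exception:
            pass
    if tool_calls:
        text = re.sub(r'<tool_call>.*?</tool_call>', '', text, flags=re.DOTALL)
    return text.strip(), reasoning_content, tool_calls or None

raw = '<think>用户想知道天气</think>我来查询。<tool_call>{"name": "get_current_weather", "arguments": {"location": "北京"}}</tool_call>'
content, reasoning, tool_calls = parse_response(raw)
assert reasoning == '用户想知道天气'
assert content == '我来查询。'
assert tool_calls[0]['function']['name'] == 'get_current_weather'
```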

scripts/web_demo.py Normal file
@ -0,0 +1,420 @@
import random
import re
import json
import os
from threading import Thread
import torch
import numpy as np
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
st.set_page_config(page_title="MiniMind", initial_sidebar_state="collapsed")
st.markdown("""
<style>
/* Action button styles */
.stButton button {
border-radius: 50% !important; /* circular */
width: 32px !important; /* fixed width */
height: 32px !important; /* fixed height */
padding: 0 !important; /* remove inner padding */
background-color: transparent !important;
border: 1px solid #ddd !important;
display: flex !important;
align-items: center !important;
justify-content: center !important;
font-size: 14px !important;
color: #666 !important; /* softer color */
margin: 5px 10px 5px 0 !important; /* adjust button spacing */
}
.stButton button:hover {
border-color: #999 !important;
color: #333 !important;
background-color: #f5f5f5 !important;
}
.stMainBlockContainer > div:first-child {
margin-top: -50px !important;
}
.stApp > div:last-child {
margin-bottom: -35px !important;
}
/* Reset base button styles */
.stButton > button {
all: unset !important; /* reset all default styles */
box-sizing: border-box !important;
border-radius: 50% !important;
width: 18px !important;
height: 18px !important;
min-width: 18px !important;
min-height: 18px !important;
max-width: 18px !important;
max-height: 18px !important;
padding: 0 !important;
background-color: transparent !important;
border: 1px solid #ddd !important;
display: flex !important;
align-items: center !important;
justify-content: center !important;
font-size: 14px !important;
color: #888 !important;
cursor: pointer !important;
transition: all 0.2s ease !important;
margin: 0 2px !important; /* adjust spacing here */
}
</style>
""", unsafe_allow_html=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Multilingual UI text
LANG_TEXTS = {
'zh': {
'settings': '模型设定调整',
'history_rounds': '历史对话轮次',
'max_length': '最大生成长度',
'temperature': '温度',
'thinking': '思考',
'tools': '工具',
'language': '语言',
'send': '给 MiniMind 发送消息',
'disclaimer': 'AI 生成内容可能存在错误,请仔细核实',
'think_tip': '自适应思考目前多轮对话或Tool Call共存时思考不稳定',
'tool_select': '工具选择最多4个',
},
'en': {
'settings': 'Model Settings',
'history_rounds': 'History Rounds',
'max_length': 'Max Length',
'temperature': 'Temperature',
'thinking': 'Thinking',
'tools': 'Tools',
'language': 'Language',
'send': 'Send a message to MiniMind',
'disclaimer': 'AI-generated content may be inaccurate, please verify',
'think_tip': 'Adaptive thinking; may be unstable with multi-turn or Tool Call',
'tool_select': 'Tool Selection (max 4)',
}
}
def get_text(key):
lang = st.session_state.get('lang', 'en')
return LANG_TEXTS.get(lang, {}).get(key, LANG_TEXTS['zh'].get(key, key))
# Tool definitions
TOOLS = [
{"type": "function", "function": {"name": "calculate_math", "description": "计算数学表达式", "parameters": {"type": "object", "properties": {"expression": {"type": "string", "description": "数学表达式"}}, "required": ["expression"]}}},
{"type": "function", "function": {"name": "get_current_time", "description": "获取当前时间", "parameters": {"type": "object", "properties": {"timezone": {"type": "string", "default": "Asia/Shanghai"}}, "required": []}}},
{"type": "function", "function": {"name": "random_number", "description": "生成随机数", "parameters": {"type": "object", "properties": {"min": {"type": "integer"}, "max": {"type": "integer"}}, "required": ["min", "max"]}}},
{"type": "function", "function": {"name": "text_length", "description": "计算文本长度", "parameters": {"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]}}},
{"type": "function", "function": {"name": "unit_converter", "description": "单位转换", "parameters": {"type": "object", "properties": {"value": {"type": "number"}, "from_unit": {"type": "string"}, "to_unit": {"type": "string"}}, "required": ["value", "from_unit", "to_unit"]}}},
{"type": "function", "function": {"name": "get_current_weather", "description": "获取天气", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}}},
{"type": "function", "function": {"name": "get_exchange_rate", "description": "获取汇率", "parameters": {"type": "object", "properties": {"from_currency": {"type": "string"}, "to_currency": {"type": "string"}}, "required": ["from_currency", "to_currency"]}}},
{"type": "function", "function": {"name": "translate_text", "description": "翻译文本", "parameters": {"type": "object", "properties": {"text": {"type": "string"}, "target_lang": {"type": "string"}}, "required": ["text", "target_lang"]}}},
]
TOOL_SHORT_NAMES = {
'calculate_math': '数学', 'get_current_time': '时间', 'random_number': '随机',
'text_length': '字数', 'unit_converter': '单位', 'get_current_weather': '天气',
'get_exchange_rate': '汇率', 'translate_text': '翻译'
}
def execute_tool(tool_name, args):
import datetime
try:
if tool_name == 'calculate_math':
return {"result": eval(args.get('expression', '0'))}  # eval of model-supplied text: fine for a local demo, unsafe for untrusted input
elif tool_name == 'get_current_time':
tz = args.get('timezone', 'Asia/Shanghai')
return {"result": datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
elif tool_name == 'random_number':
return {"result": random.randint(args.get('min', 0), args.get('max', 100))}
elif tool_name == 'text_length':
return {"result": len(args.get('text', ''))}
elif tool_name == 'unit_converter':
return {"result": f"{args.get('value', 0)} {args.get('from_unit', '')} = ? {args.get('to_unit', '')}"}
elif tool_name == 'get_current_weather':
return {"result": f"{args.get('city', 'Unknown')}: 晴, 7~10°C"}
elif tool_name == 'get_exchange_rate':
return {"result": f"1 {args.get('from_currency', 'USD')} = 7.2 {args.get('to_currency', 'CNY')}"}
elif tool_name == 'translate_text':
return {"result": f"[翻译结果]: hello world"}
return {"result": "Unknown tool"}
except Exception as e:
return {"error": str(e)}
def process_assistant_content(content, is_streaming=False):
# Format <tool_call> tags into a styled display block
if '<tool_call>' in content:
def format_tool_call(match):
try:
tc = json.loads(match.group(1))
name = tc.get('name', 'unknown')
args = tc.get('arguments', {})
return f'<div style="background: rgba(80, 110, 150, 0.20); border: 1px solid rgba(140, 170, 210, 0.30); padding: 10px 12px; border-radius: 12px; margin: 6px 0;"><div style="font-size:12px;opacity:.75;display:block;margin:0 0 6px 0;line-height:1;">ToolCalling</div><div><b>{name}</b>: {json.dumps(args, ensure_ascii=False)}</div></div>'
except:
return match.group(0)
content = re.sub(r'<tool_call>(.*?)</tool_call>', format_tool_call, content, flags=re.DOTALL)
# While streaming with thinking enabled, put output into the collapsed block from the start
if is_streaming and st.session_state.get('enable_thinking', False) and '</think>' not in content and '<think>' not in content:
m = re.search(r'(\n\n(?:我是|您好|你好)[^\n]*)', content)
if m and m.start(1) > 5:
i = m.start(1)
think_part = content[:i]
answer_part = content[i:]
return f'<details open style="border-left: 2px solid #666; padding-left: 12px; margin: 8px 0;"><summary style="cursor: pointer; color: #888;">已思考</summary><div style="color: #aaa; font-size: 0.95em; margin-top: 8px; max-height: 100px; overflow-y: auto;">{think_part.strip()}</div></details>{answer_part}'
elif len(content) > 5:
return f'<details open style="border-left: 2px solid #666; padding-left: 12px; margin: 8px 0;"><summary style="cursor: pointer; color: #888;">思考中...</summary><div style="color: #aaa; font-size: 0.95em; margin-top: 8px; max-height: 100px; overflow-y: auto; display: flex; flex-direction: column-reverse;"><div style="margin-bottom: auto;">{content.strip().replace(chr(10), "<br>")}</div></div></details>'
if '<think>' in content and '</think>' in content:
def format_think(match):
think_content = match.group(2)
if think_content.replace('\n', '').strip():  # skip blocks that are only whitespace
return f'<details open style="border-left: 2px solid #666; padding-left: 12px; margin: 8px 0;"><summary style="cursor: pointer; color: #888;">已思考</summary><div style="color: #aaa; font-size: 0.95em; margin-top: 8px; max-height: 100px; overflow-y: auto;">{think_content.strip()}</div></details>'
return ''
content = re.sub(r'(<think>)(.*?)(</think>)', format_think, content, flags=re.DOTALL)
if '<think>' in content and '</think>' not in content:
def format_think_in_progress(match):
tc = match.group(1)
return f'<details open style="border-left: 2px solid #666; padding-left: 12px; margin: 8px 0;"><summary style="cursor: pointer; color: #888;">思考中...</summary><div style="color: #aaa; font-size: 0.95em; margin-top: 8px; max-height: 100px; overflow-y: auto; display: flex; flex-direction: column-reverse;"><div style="margin-bottom: auto;">{tc.strip().replace(chr(10), "<br>")}</div></div></details>'
content = re.sub(r'<think>(.*?)$', format_think_in_progress, content, flags=re.DOTALL)
if '<think>' not in content and '</think>' in content:
def format_think_no_start(match):
think_content = match.group(1)
if think_content.replace('\n', '').strip():
return f'<details open style="border-left: 2px solid #666; padding-left: 12px; margin: 8px 0;"><summary style="cursor: pointer; color: #888;">已思考</summary><div style="color: #aaa; font-size: 0.95em; margin-top: 8px; max-height: 100px; overflow-y: auto;">{think_content.strip()}</div></details>'
return ''
content = re.sub(r'(.*?)</think>', format_think_no_start, content, flags=re.DOTALL)
return content
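The branching above reduces to a handful of regex rewrites over `<think>` tags. A minimal standalone sketch of the closed-tag case (hypothetical helper name `fold_closed_think`, simplified HTML):

```python
import re

def fold_closed_think(content: str) -> str:
    # Rewrite each closed <think>...</think> block as a collapsible
    # <details> element; blocks that are only whitespace are dropped.
    def repl(m):
        body = m.group(1).strip()
        if not body:
            return ''
        return f'<details open><summary>Thought</summary>{body}</details>'
    return re.sub(r'<think>(.*?)</think>', repl, content, flags=re.DOTALL)

print(fold_closed_think('<think>plan steps</think>final answer'))
```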
@st.cache_resource
def load_model_tokenizer(model_path):
model = AutoModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True
)
model = model.half().eval().to(device)
return model, tokenizer
def clear_chat_messages():
st.session_state.pop("messages", None)
st.session_state.pop("chat_messages", None)
def init_chat_messages():
if "messages" in st.session_state:
for i, message in enumerate(st.session_state.messages):
if message["role"] == "assistant":
st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
else:
st.markdown(
f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: #3d4450; border-radius: 22px; color: white;">{message["content"]}</div></div>',
unsafe_allow_html=True)
else:
st.session_state.messages = []
st.session_state.chat_messages = []
return st.session_state.messages
def regenerate_answer(index):
st.session_state.messages.pop()
st.session_state.chat_messages.pop()
st.rerun()
# Dynamically scan for model directories
script_dir = os.path.dirname(os.path.abspath(__file__))
MODEL_PATHS = {}
for d in sorted(os.listdir(script_dir), reverse=True):
full_path = os.path.join(script_dir, d)
if os.path.isdir(full_path) and not d.startswith('.') and not d.startswith('_'):
index_json = os.path.join(full_path, 'model.safetensors.index.json')
if os.path.exists(index_json) or any(f.endswith(('.bin', '.safetensors', '.pt')) for f in os.listdir(full_path) if os.path.isfile(os.path.join(full_path, f))):
MODEL_PATHS[d] = [d, d]
if not MODEL_PATHS:
MODEL_PATHS = {"No models found": ["", "No models"]}
# Model selection
selected_model = st.sidebar.selectbox('Model', list(MODEL_PATHS.keys()), index=0)
model_path = MODEL_PATHS[selected_model][0]
slogan = f"我是 {MODEL_PATHS[selected_model][1]},有什么可以帮你的?" if st.session_state.get('lang', 'en') == 'zh' else f"I am {MODEL_PATHS[selected_model][1]}, how can I help you?"
st.sidebar.markdown('<hr style="margin: 12px 0 16px 0;">', unsafe_allow_html=True)
# Language selection
lang_options = {'中文': 'zh', 'English': 'en'}
current_lang = st.session_state.get('lang', 'en')
lang_index = 0 if current_lang == 'zh' else 1
lang_label = st.sidebar.radio('Language / 语言', list(lang_options.keys()), index=lang_index, horizontal=True)
if lang_options[lang_label] != current_lang:
st.session_state.lang = lang_options[lang_label]
st.rerun()
st.sidebar.markdown('<hr style="margin: 12px 0 16px 0;">', unsafe_allow_html=True)
# Parameter settings
st.session_state.history_chat_num = st.sidebar.slider(get_text('history_rounds'), 0, 8, 0, step=2)
st.session_state.max_new_tokens = st.sidebar.slider(get_text('max_length'), 256, 8192, 8192, step=1)
st.session_state.temperature = st.sidebar.slider(get_text('temperature'), 0.6, 1.2, 0.90, step=0.01)
st.sidebar.markdown('<hr style="margin: 12px 0 16px 0;">', unsafe_allow_html=True)
# Feature toggles
st.session_state.enable_thinking = st.sidebar.checkbox(get_text('thinking'), value=False, help=get_text('think_tip'))
st.session_state.selected_tools = []
with st.sidebar.expander(get_text('tools')):
st.caption(get_text('tool_select'))
selected_count = sum(1 for tool in TOOLS if st.session_state.get(f"tool_{tool['function']['name']}", False))
for tool in TOOLS:
name = tool['function']['name']
short_name = TOOL_SHORT_NAMES.get(name, name)
checked = st.checkbox(short_name, key=f"tool_{name}", disabled=(selected_count >= 4 and not st.session_state.get(f"tool_{name}", False)))
if checked and len(st.session_state.selected_tools) < 4:
st.session_state.selected_tools.append(name)
image_url = "https://www.modelscope.cn/api/v1/studio/gongjy/MiniMind/repo?Revision=master&FilePath=images%2Flogo2.png&View=true"
st.markdown(
f'<div style="display: flex; flex-direction: column; align-items: center; text-align: center; margin: 0; padding: 0;">'
'<div style="font-style: italic; font-weight: 900; margin: 0; padding-top: 4px; display: flex; align-items: center; justify-content: center; flex-wrap: wrap; width: 100%;">'
f'<img src="{image_url}" style="width: 40px; height: 40px; "> '
f'<span style="font-size: 26px; margin-left: 10px;">{slogan}</span>'
'</div>'
f'<span style="color: #bbb; font-style: italic; margin-top: 6px; margin-bottom: 10px;">{get_text("disclaimer")}</span>'
'</div>',
unsafe_allow_html=True
)
def setup_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def main():
model, tokenizer = load_model_tokenizer(model_path)
if "messages" not in st.session_state:
st.session_state.messages = []
st.session_state.chat_messages = []
messages = st.session_state.messages
for i, message in enumerate(messages):
if message["role"] == "assistant":
st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
else:
st.markdown(
f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: #3d4450; border-radius: 22px; color: white;">{message["content"]}</div></div>',
unsafe_allow_html=True)
prompt = st.chat_input(key="input", placeholder=get_text('send'))
if hasattr(st.session_state, 'regenerate') and st.session_state.regenerate:
prompt = st.session_state.last_user_message
regenerate_index = st.session_state.regenerate_index
delattr(st.session_state, 'regenerate')
delattr(st.session_state, 'last_user_message')
delattr(st.session_state, 'regenerate_index')
if prompt:
st.markdown(
f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: #3d4450; border-radius: 22px; color: white;">{prompt}</div></div>',
unsafe_allow_html=True)
messages.append({"role": "user", "content": prompt[-st.session_state.max_new_tokens:]})
st.session_state.chat_messages.append({"role": "user", "content": prompt[-st.session_state.max_new_tokens:]})
placeholder = st.empty()
random_seed = random.randint(0, 2 ** 32 - 1)
setup_seed(random_seed)
tools = [t for t in TOOLS if t['function']['name'] in st.session_state.get('selected_tools', [])] or None
sys_prompt = [] if tools else [{"role": "system", "content": "你是MiniMind一个乐于助人、知识渊博的AI助手。请用完整且友好的方式回答用户问题。"}]
st.session_state.chat_messages = sys_prompt + st.session_state.chat_messages[-(st.session_state.history_chat_num + 1):]
template_kwargs = {"tokenize": False, "add_generation_prompt": True}
if st.session_state.get('enable_thinking', False):
template_kwargs["open_thinking"] = True
if tools:
template_kwargs["tools"] = tools
new_prompt = tokenizer.apply_chat_template(st.session_state.chat_messages, **template_kwargs)
inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = {
"input_ids": inputs.input_ids,
"max_length": inputs.input_ids.shape[1] + st.session_state.max_new_tokens,
"num_return_sequences": 1,
"do_sample": True,
"attention_mask": inputs.attention_mask,
"pad_token_id": tokenizer.pad_token_id,
"eos_token_id": tokenizer.eos_token_id,
"temperature": st.session_state.temperature,
"top_p": 0.85,
"streamer": streamer,
}
Thread(target=model.generate, kwargs=generation_kwargs).start()
answer = ""
for new_text in streamer:
answer += new_text
placeholder.markdown(process_assistant_content(answer, is_streaming=True), unsafe_allow_html=True)
full_answer = answer
for _ in range(16):  # allow at most 16 tool-call rounds per reply
tool_calls = re.findall(r'<tool_call>(.*?)</tool_call>', answer, re.DOTALL)
if not tool_calls:
break
st.session_state.chat_messages.append({"role": "assistant", "content": answer})
tool_results = []
for tc_str in tool_calls:
try:
tc = json.loads(tc_str.strip())
result = execute_tool(tc.get('name', ''), tc.get('arguments', {}))
st.session_state.chat_messages.append({"role": "tool", "content": json.dumps(result, ensure_ascii=False)})
tool_results.append(f'<div style="background: rgba(90, 130, 110, 0.20); border: 1px solid rgba(150, 200, 170, 0.30); padding: 10px 12px; border-radius: 12px; margin: 6px 0;"><div style="font-size:12px;opacity:.75;display:block;margin:0 0 6px 0;line-height:1;">ToolCalled</div><div><b>{tc.get("name", "")}</b>: {json.dumps(result, ensure_ascii=False)}</div></div>')
except Exception:
pass
full_answer += "\n" + "\n".join(tool_results) + "\n"
placeholder.markdown(process_assistant_content(full_answer, is_streaming=True), unsafe_allow_html=True)
new_prompt = tokenizer.apply_chat_template(st.session_state.chat_messages, **template_kwargs)
inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs["input_ids"] = inputs.input_ids
generation_kwargs["attention_mask"] = inputs.attention_mask
generation_kwargs["max_length"] = inputs.input_ids.shape[1] + st.session_state.max_new_tokens
generation_kwargs["streamer"] = streamer
Thread(target=model.generate, kwargs=generation_kwargs).start()
answer = ""
for new_text in streamer:
answer += new_text
placeholder.markdown(process_assistant_content(full_answer + answer, is_streaming=True), unsafe_allow_html=True)
full_answer += answer
answer = full_answer
messages.append({"role": "assistant", "content": answer})
st.session_state.chat_messages.append({"role": "assistant", "content": answer})
if __name__ == "__main__":
main()

224
trainer/rollout_engine.py Normal file

@ -0,0 +1,224 @@
"""
# If using sglang for accelerated rollouts, first launch the transformers-format model with:
python -m sglang.launch_server --model-path ./minimind-3 --attention-backend triton --host 0.0.0.0 --port 8998
"""
import os
import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import requests
import torch
import torch.distributed as dist
from abc import ABC, abstractmethod
from contextlib import nullcontext
from dataclasses import dataclass
from typing import List, Optional, Tuple
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel
from transformers import AutoTokenizer
# ===== Compute per-token log-probabilities =====
def compute_per_token_logps(model, input_ids: Tensor, n_keep: int, attention_mask: Optional[Tensor] = None) -> Tensor:
if n_keep <= 0:
return input_ids.new_empty((input_ids.size(0), 0), dtype=torch.float32)
unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
input_ids = input_ids.detach().clone() if input_ids.is_inference() else input_ids
logits = unwrapped(input_ids, attention_mask=attention_mask, logits_to_keep=n_keep + 1).logits[:, :-1, :]
per_token_logps = []
for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
ids_row = ids_row.detach().clone() if ids_row.is_inference() else ids_row
per_token_logps.append(
torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1)
)
return torch.stack(per_token_logps)
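compute_per_token_logps gathers, for each kept token, the log_softmax of the logits at its position, i.e. log p(t) = logit[t] − logsumexp(logits). The same arithmetic in plain Python (illustrative only; the real code runs batched on tensors):

```python
import math

def per_token_logp(logits_row, token_id):
    # Numerically stable log_softmax(logits)[token_id]:
    # log p(t) = logit[t] - logsumexp(logits)
    m = max(logits_row)
    lse = m + math.log(sum(math.exp(x - m) for x in logits_row))
    return logits_row[token_id] - lse

# Uniform logits over a 4-token vocabulary: every token gets log(1/4)
print(per_token_logp([0.0, 0.0, 0.0, 0.0], 2))  # ≈ -1.3863
```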
# ===== Rollout result =====
@dataclass
class RolloutResult:
output_ids: Tensor
completion_ids: Tensor
per_token_logps: Tensor
completions: List[str]
prompt_lens: Tensor
completion_mask: Tensor
# ===== Abstract base class for rollout engines =====
class RolloutEngine(ABC):
tokenizer = None
@abstractmethod
def rollout(self, prompt_ids: Tensor, attention_mask: Tensor, num_generations: int, max_new_tokens: int, temperature: float = 0.8) -> RolloutResult:
pass
@abstractmethod
def update_policy(self, model: torch.nn.Module):
pass
# ===== Native PyTorch inference engine =====
class TorchRolloutEngine(RolloutEngine):
def __init__(self, policy_model: torch.nn.Module, tokenizer, device: str = "cuda", autocast_ctx=None):
self.policy_model = policy_model
self.tokenizer = tokenizer
self.device = device
self.autocast_ctx = autocast_ctx
def rollout(self, prompt_ids: Tensor, attention_mask: Tensor, num_generations: int, max_new_tokens: int, temperature: float = 0.8) -> RolloutResult:
model = self.policy_model.module if isinstance(self.policy_model, DistributedDataParallel) else self.policy_model
ctx = self.autocast_ctx if self.autocast_ctx else nullcontext()
with torch.no_grad(), ctx:
output_ids = model.generate(
input_ids=prompt_ids.repeat_interleave(num_generations, dim=0),
attention_mask=attention_mask.repeat_interleave(num_generations, dim=0),
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=temperature,
num_return_sequences=1,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
) # [B*num_gen, P+R]
prompt_len = prompt_ids.size(1)
completion_ids = output_ids[:, prompt_len:] # [B*num_gen, R]
full_mask = (output_ids != self.tokenizer.pad_token_id).long()
per_token_logps = compute_per_token_logps(self.policy_model, output_ids, completion_ids.size(1), attention_mask=full_mask)
completions = self.tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
return RolloutResult(output_ids, completion_ids, per_token_logps, completions,
prompt_ids.new_full((output_ids.size(0),), prompt_len),
attention_mask.new_ones(output_ids.size(0), completion_ids.size(1)))
def update_policy(self, model: torch.nn.Module):
self.policy_model = model
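rollout expands each prompt into num_generations consecutive rows via repeat_interleave, so downstream code can map row i back to prompt i // num_generations. A list-level sketch of that convention:

```python
def repeat_interleave_rows(prompts, num_gen):
    # Mirrors torch.Tensor.repeat_interleave(dim=0): each prompt is
    # repeated num_gen times consecutively, keeping generations grouped.
    return [p for p in prompts for _ in range(num_gen)]

rows = repeat_interleave_rows(['p0', 'p1'], 3)
print(rows)    # ['p0', 'p0', 'p0', 'p1', 'p1', 'p1']
print(4 // 3)  # row 4 maps back to prompt index 1
```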
# ===== SGLang HTTP API inference engine =====
class SGLangRolloutEngine(RolloutEngine):
def __init__(self, base_url: str, model_path: str, shared_ckpt_path: str = "./sglang_ckpt", timeout: int = 120):
self.base_url = base_url.rstrip('/')
self.shared_ckpt_path = shared_ckpt_path
self.timeout = timeout
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
self.http = requests
def rollout(self, prompt_ids: Tensor, attention_mask: Tensor, num_generations: int, max_new_tokens: int, temperature: float = 0.8) -> RolloutResult:
# Strip left padding, keeping only valid prompt tokens
input_ids_list = []
for ids, mask in zip(prompt_ids, attention_mask):
valid_ids = ids[mask.bool()].tolist()
input_ids_list.append(valid_ids)
all_input_ids = [ids for ids in input_ids_list for _ in range(num_generations)]
payload = {
"input_ids": all_input_ids,
"sampling_params": {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"stop_token_ids": [self.tokenizer.eos_token_id] if self.tokenizer.eos_token_id else [],
},
"return_logprob": True,
}
resp = self.http.post(f"{self.base_url}/generate", json=payload, timeout=self.timeout)
resp.raise_for_status()
results = resp.json()
if not isinstance(results, list):
results = [results]
all_output_ids, all_completion_ids, all_logprobs = [], [], []
completions = []
for i, result in enumerate(results):
meta = result.get("meta_info", {})
completion_ids = meta.get("output_ids", result.get("output_ids", []))
raw_logprobs = meta.get("output_token_logprobs", [])
logprobs = []
for item in raw_logprobs:
if isinstance(item, (list, tuple)) and len(item) >= 1:
logprobs.append(item[0])
elif isinstance(item, (int, float)):
logprobs.append(item)
if len(logprobs) < len(completion_ids):
logprobs = [0.0] * (len(completion_ids) - len(logprobs)) + logprobs
elif len(logprobs) > len(completion_ids):
logprobs = logprobs[-len(completion_ids):] if completion_ids else []
prompt = all_input_ids[i]
full_output = prompt + completion_ids
all_output_ids.append(full_output)
all_completion_ids.append(completion_ids)
all_logprobs.append(logprobs)
completions.append(self.tokenizer.decode(completion_ids, skip_special_tokens=True))
device = prompt_ids.device
max_comp_len = max(1, max(len(ids) for ids in all_completion_ids))
max_out_len = max(len(ids) for ids in all_input_ids) + max_comp_len
def pad_to_tensor(seqs, max_len, pad_val=0):
return torch.tensor([s + [pad_val] * (max_len - len(s)) for s in seqs], device=device)
pad_id = self.tokenizer.pad_token_id
return RolloutResult(
output_ids=pad_to_tensor(all_output_ids, max_out_len, pad_val=pad_id),
completion_ids=pad_to_tensor(all_completion_ids, max_comp_len, pad_val=pad_id),
per_token_logps=pad_to_tensor(all_logprobs, max_comp_len, pad_val=0.0),
completions=completions,
prompt_lens=torch.tensor([len(ids) for ids in all_input_ids], device=device),
completion_mask=torch.tensor([[1] * len(ids) + [0] * (max_comp_len - len(ids)) for ids in all_completion_ids], device=device),
)
def update_policy(self, model: torch.nn.Module):
ok = True
if not dist.is_initialized() or dist.get_rank() == 0:
try:
unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
unwrapped = getattr(unwrapped, '_orig_mod', unwrapped)
abs_path = os.path.abspath(self.shared_ckpt_path)
state_dict = {k: v.detach().half().cpu() for k, v in unwrapped.state_dict().items()}
unwrapped.save_pretrained(abs_path, state_dict=state_dict, safe_serialization=False)
self.tokenizer.save_pretrained(abs_path)
resp = self.http.post(f"{self.base_url}/update_weights_from_disk", json={"model_path": abs_path}, timeout=self.timeout)
if resp.status_code != 200: print(f"[SGLANG WARNING] update_weights failed: {resp.status_code}, {resp.text}")
ok = resp.status_code == 200
except Exception as e:
print(f"[SGLANG WARNING] update_weights exception: {e}"); ok = False
if dist.is_initialized():
ok_t = torch.tensor(int(ok), device=next(model.parameters()).device)
dist.broadcast(ok_t, src=0); dist.barrier(); ok = bool(ok_t.item())
if not ok: raise RuntimeError("SGLang update_policy failed")
return ok
def flush_cache(self) -> bool:
resp = self.http.post(f"{self.base_url}/flush_cache", timeout=30)
return resp.status_code == 200
def health(self) -> bool:
try:
resp = self.http.get(f"{self.base_url}/health", timeout=5)
return resp.status_code == 200
except Exception:
return False
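pad_to_tensor above right-pads ragged token/logprob lists into rectangles, and completion_mask marks real tokens with 1. The same idea in plain Python (hypothetical helper name `pad_ragged`):

```python
def pad_ragged(seqs, pad_val=0):
    # Right-pad variable-length rows to the longest row and return the
    # padded matrix plus a validity mask (1 = real token, 0 = padding).
    width = max(len(s) for s in seqs)
    padded = [s + [pad_val] * (width - len(s)) for s in seqs]
    mask = [[1] * len(s) + [0] * (width - len(s)) for s in seqs]
    return padded, mask

rows, mask = pad_ragged([[5, 6, 7], [8]], pad_val=0)
print(rows)  # [[5, 6, 7], [8, 0, 0]]
print(mask)  # [[1, 1, 1], [1, 0, 0]]
```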
# ===== Factory function =====
def create_rollout_engine(
engine_type: str = "torch",
policy_model: torch.nn.Module = None,
tokenizer = None,
device: str = "cuda",
autocast_ctx = None,
sglang_base_url: str = None,
sglang_model_path: str = None,
sglang_shared_path: str = None,
) -> RolloutEngine:
if engine_type == "torch":
return TorchRolloutEngine(policy_model, tokenizer, device, autocast_ctx)
elif engine_type == "sglang":
return SGLangRolloutEngine(sglang_base_url, sglang_model_path, sglang_shared_path)
else:
raise ValueError(f"Unsupported engine type: {engine_type}")

488
trainer/train_agent.py Normal file

@ -0,0 +1,488 @@
import os
import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import re
import gc
import json
import math
import random
import signal
import argparse
import warnings
import torch
import torch.nn.functional as F
import torch.distributed as dist
from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import AutoTokenizer
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import AgentRLDataset
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model, LMForRewardModel
from trainer.rollout_engine import create_rollout_engine, compute_per_token_logps
warnings.filterwarnings('ignore')
# ================================ Tools & Reward = Start ================================
def rep_penalty(text, n=3, cap=0.5):
toks = re.findall(r"\w+|[^\w\s]", text.lower())
grams = [tuple(toks[i:i + n]) for i in range(len(toks) - n + 1)]
return min(cap, (len(grams) - len(set(grams))) * cap * 2 / len(grams)) if grams else 0.0
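rep_penalty penalizes repeated word 3-grams: the duplicate fraction is scaled by 2 × cap and clipped at cap. Reproduced standalone with a usage example:

```python
import re

def rep_penalty(text, n=3, cap=0.5):
    # Tokenize into words/punctuation, count duplicate n-grams,
    # and return a penalty capped at `cap`.
    toks = re.findall(r"\w+|[^\w\s]", text.lower())
    grams = [tuple(toks[i:i + n]) for i in range(len(toks) - n + 1)]
    return min(cap, (len(grams) - len(set(grams))) * cap * 2 / len(grams)) if grams else 0.0

print(rep_penalty("the cat sat on the mat"))   # 0.0 (no repeated 3-grams)
print(rep_penalty("a b c a b c a b c a b c"))  # 0.5 (capped)
```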
# ======== Tool definitions ========
TOOLS = [
{"type": "function", "function": {"name": "calculate_math", "description": "计算数学表达式", "parameters": {"type": "object", "properties": {"expression": {"type": "string"}}, "required": ["expression"]}}},
{"type": "function", "function": {"name": "unit_converter", "description": "单位换算", "parameters": {"type": "object", "properties": {"value": {"type": "number"}, "from_unit": {"type": "string"}, "to_unit": {"type": "string"}}, "required": ["value", "from_unit", "to_unit"]}}},
{"type": "function", "function": {"name": "get_current_weather", "description": "获取天气", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}},
{"type": "function", "function": {"name": "get_current_time", "description": "获取时间", "parameters": {"type": "object", "properties": {"timezone": {"type": "string", "default": "Asia/Shanghai"}}, "required": []}}},
{"type": "function", "function": {"name": "get_exchange_rate", "description": "查询汇率", "parameters": {"type": "object", "properties": {"from_currency": {"type": "string"}, "to_currency": {"type": "string"}}, "required": ["from_currency", "to_currency"]}}},
{"type": "function", "function": {"name": "translate_text", "description": "翻译文本", "parameters": {"type": "object", "properties": {"text": {"type": "string"}, "target_language": {"type": "string"}}, "required": ["text", "target_language"]}}},
]
# ======== Mock data ========
WEATHER_DATA = {"北京": ("28°C", ""), "上海": ("15°C", "多云"), "广州": ("32°C", "闷热"), "深圳": ("30°C", ""), "杭州": ("22°C", ""), "成都": ("18°C", "小雨"), "武汉": ("25°C", "多云"), "南京": ("20°C", ""), "西安": ("16°C", "大风"), "重庆": ("26°C", ""), "Tokyo": ("12°C", ""), "New York": ("8°C", "多云"), "London": ("5°C", "小雨"), "Paris": ("10°C", ""), "Sydney": ("25°C", "晴朗")}
TIME_DATA = {"Asia/Shanghai": "2025-03-07 14:30:00", "America/New_York": "2025-03-07 01:30:00", "Europe/London": "2025-03-07 06:30:00", "Asia/Tokyo": "2025-03-07 15:30:00", "Europe/Paris": "2025-03-07 07:30:00", "Australia/Sydney": "2025-03-07 17:30:00"}
EXCHANGE_DATA = {("USD", "CNY"): 7.21, ("EUR", "CNY"): 7.85, ("GBP", "CNY"): 9.12, ("JPY", "CNY"): 0.048, ("USD", "EUR"): 0.92, ("USD", "GBP"): 0.79, ("CNY", "JPY"): 20.83, ("AUD", "CNY"): 4.72}
TRANSLATE_DATA = {("你好世界", "english"): "Hello World", ("Good morning", "chinese"): "早上好", ("今天天气真好", "english"): "The weather is nice today", ("I love programming", "chinese"): "我喜欢编程", ("机器学习很有趣", "english"): "Machine learning is interesting", ("Happy birthday", "chinese"): "生日快乐"}
UNIT_DATA = {"km_miles": 0.621371, "miles_km": 1.60934, "kg_pounds": 2.20462, "pounds_kg": 0.453592, "meters_feet": 3.28084, "feet_meters": 0.3048, "celsius_fahrenheit": 1.8, "fahrenheit_celsius": 0.5556}
# ======== Mock execution ========
MOCK_RESULTS = {
"calculate_math": lambda args: {"result": str(eval(str(args.get("expression", "0")).replace("^", "**").replace("×", "*").replace("÷", "/").replace("−", "-").replace("（", "(").replace("）", ")"), {"__builtins__": {}, "math": math}))},
"unit_converter": lambda args: {"result": round(float(args.get("value", 0)) * UNIT_DATA.get(f"{args.get('from_unit', '').lower()}_{args.get('to_unit', '').lower()}", 1), 4)},
"get_current_weather": lambda args: (lambda w: {"city": args.get("location"), "temperature": w[0], "humidity": "65%", "condition": w[1]})(WEATHER_DATA.get(args.get("location"), ("22°C", ""))),
"get_current_time": lambda args: {"datetime": TIME_DATA.get(args.get("timezone", "Asia/Shanghai"), "2025-03-07 14:30:00"), "timezone": args.get("timezone", "Asia/Shanghai")},
"get_exchange_rate": lambda args: {"from": args.get("from_currency"), "to": args.get("to_currency"), "rate": EXCHANGE_DATA.get((args.get("from_currency"), args.get("to_currency")), 1.0)},
"translate_text": lambda args: {"translated_text": TRANSLATE_DATA.get((args.get("text"), args.get("target_language")), args.get("text", ""))},
}
# ======== Argument validation ========
CHECK_ARGS = {
"calculate_math": lambda a: bool(a.get("expression")),
"unit_converter": lambda a: a.get("value") is not None and a.get("from_unit") and a.get("to_unit"),
"get_current_weather": lambda a: bool(a.get("location")),
"get_current_time": lambda a: True,
"get_exchange_rate": lambda a: bool(a.get("from_currency")) and bool(a.get("to_currency")),
"translate_text": lambda a: bool(a.get("text")) and bool(a.get("target_language")),
}
# ======== Tool-call parsing and execution ========
def parse_tool_calls(text):
calls = []
for m in re.findall(r'<tool_call>(.*?)</tool_call>', text, re.DOTALL):
try: calls.append(json.loads(m.strip()))
except json.JSONDecodeError: pass
return calls
def execute_tool(name, args):
fn = MOCK_RESULTS.get(name)
if not fn: return None
try:
# 1-second SIGALRM timeout guard (Unix-only) against slow mock tools
signal.signal(signal.SIGALRM, lambda *_: (_ for _ in ()).throw(TimeoutError()))
signal.alarm(1)
return fn(args)
except Exception:
return None
finally:
try: signal.alarm(0)
except Exception: pass
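parse_tool_calls pulls out every JSON payload wrapped in `<tool_call>` tags and silently skips malformed ones. A usage sketch with the same regex:

```python
import json
import re

def parse_tool_calls(text):
    calls = []
    for m in re.findall(r'<tool_call>(.*?)</tool_call>', text, re.DOTALL):
        try:
            calls.append(json.loads(m.strip()))
        except json.JSONDecodeError:
            pass  # malformed payloads are ignored
    return calls

sample = ('<tool_call>{"name": "get_current_time", "arguments": {}}</tool_call>'
          ' and <tool_call>not json</tool_call>')
print(parse_tool_calls(sample))  # [{'name': 'get_current_time', 'arguments': {}}]
```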
# ======== Multi-turn rollout ========
def rollout_single(rollout_engine, tokenizer, messages, tools, max_turns=3, max_new_tokens=256, thinking_ratio=0.5, device="cuda"):
all_outputs = []
prompt_ids = None
response_ids = []
response_mask = []
response_old_logps = []
final_context = ""
unfinished = False
open_thinking = random.random() < thinking_ratio
for turn in range(max_turns):
context = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=tools, open_thinking=open_thinking)
inputs = tokenizer(context, return_tensors="pt", add_special_tokens=False).to(device)
context_ids = inputs["input_ids"][0].tolist()
if prompt_ids is None:
prompt_ids = context_ids
rollout_result = rollout_engine.rollout(
prompt_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
num_generations=1,
max_new_tokens=max_new_tokens,
temperature=0.8,
)
new_ids = rollout_result.completion_ids[0].tolist()
new_logps = rollout_result.per_token_logps[0].tolist()
if len(new_ids) != len(new_logps): Logger(f"rollout token/logprob length mismatch: {len(new_ids)} vs {len(new_logps)}")
pairs = [(t, lp) for t, lp in zip(new_ids, new_logps) if t != tokenizer.pad_token_id and t != tokenizer.eos_token_id]
new_ids = [t for t, _ in pairs]
new_logps = [lp for _, lp in pairs]
new_text = rollout_result.completions[0]
all_outputs.append(new_text)
response_ids.extend(new_ids)
response_mask.extend([1] * len(new_ids))
response_old_logps.extend(new_logps)
final_context = context + new_text
calls = parse_tool_calls(new_text)
if not calls:
break
unfinished = turn == max_turns - 1
messages.append({"role": "assistant", "content": new_text})
for call in calls:
name, raw = call.get("name", ""), call.get("arguments", {})
if isinstance(raw, str):
try: raw = json.loads(raw)
except json.JSONDecodeError: raw = {}
result = execute_tool(name, raw)
result_str = (json.dumps(result, ensure_ascii=False) if result else '{"error": "tool not found"}')[:2048]  # cap length so huge tool outputs don't blow up the tokenizer
messages.append({"role": "tool", "content": result_str})
observe_context = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=not unfinished, tools=tools, open_thinking=open_thinking)
observe_ids = tokenizer(observe_context, return_tensors="pt", add_special_tokens=False)["input_ids"][0].tolist()
current_len = len(prompt_ids) + len(response_ids)
obs_delta = observe_ids[current_len:]
response_ids.extend(obs_delta)
response_mask.extend([0] * len(obs_delta))
response_old_logps.extend([0.0] * len(obs_delta))
final_context = observe_context
final_output = all_outputs[-1] if all_outputs else ""
prompt_ids = prompt_ids or []
return final_output, final_context, prompt_ids, response_ids, response_mask, response_old_logps, list(all_outputs), unfinished
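rollout_single builds one flat response sequence per episode: generated segments get mask 1 and carry their sampling logprobs, while injected tool-observation segments get mask 0 and placeholder logprobs, so only model-generated tokens enter the policy loss. The bookkeeping reduces to a sketch like this (hypothetical helper name `append_segment`):

```python
def append_segment(ids, mask, logps, new_ids, new_logps=None):
    # Generated tokens: mask 1 plus their sampling logprobs.
    # Observation tokens (tool results): mask 0 plus placeholder 0.0.
    ids.extend(new_ids)
    if new_logps is None:
        mask.extend([0] * len(new_ids))
        logps.extend([0.0] * len(new_ids))
    else:
        mask.extend([1] * len(new_ids))
        logps.extend(new_logps)

ids, mask, logps = [], [], []
append_segment(ids, mask, logps, [11, 12], [-0.1, -0.2])  # model turn
append_segment(ids, mask, logps, [99])                    # tool observation
print(mask)   # [1, 1, 0]
print(logps)  # [-0.1, -0.2, 0.0]
```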
def rollout_batch(rollout_engine, tokenizer, messages_batch, tools_batch, num_gen, max_turns=3, max_new_tokens=256, thinking_ratio=0.5, device="cuda"):
all_completions = []
all_contexts = []
all_prompt_ids = []
all_response_ids = []
all_response_masks = []
all_response_old_logps = []
all_turn_outputs = []
all_unfinished = []
for messages, tools in zip(messages_batch, tools_batch):
for _ in range(num_gen):
msgs_copy = [dict(m) for m in messages]
completion, context, prompt_ids, response_ids, response_mask, response_old_logps, turn_outputs, unfinished = rollout_single(rollout_engine, tokenizer, msgs_copy, tools, max_turns, max_new_tokens, thinking_ratio, device)
all_completions.append(completion)
all_contexts.append(context)
all_prompt_ids.append(prompt_ids)
all_response_ids.append(response_ids)
all_response_masks.append(response_mask)
all_response_old_logps.append(response_old_logps)
all_turn_outputs.append(turn_outputs)
all_unfinished.append(unfinished)
return all_completions, all_contexts, all_prompt_ids, all_response_ids, all_response_masks, all_response_old_logps, all_turn_outputs, all_unfinished
# ======== Reward computation ========
def validate_gt_in_text(text, gt_list):
text, text_num = str(text), str(text).replace(',', '')
nums = [float(x) for x in re.findall(r'(?<![\w.])[-+]?\d+(?:\.\d+)?(?![\w.])', text_num)]
return {g for g in gt_list if ((s := str(g).strip()) and s.lower() in text.lower()) or (re.fullmatch(r'[-+]?\d+(?:\.\d+)?', str(g).strip().replace(',', '')) and any(abs(float(str(g).strip().replace(',', '')) - n) < 1e-6 for n in nums))}
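validate_gt_in_text accepts a ground-truth entry if it appears as a case-insensitive substring of the text, or if it is numeric and matches any number in the text within 1e-6 (thousands commas stripped). Reproduced standalone with a usage example:

```python
import re

def validate_gt_in_text(text, gt_list):
    # Substring match (case-insensitive) OR numeric match within 1e-6.
    text, text_num = str(text), str(text).replace(',', '')
    nums = [float(x) for x in re.findall(r'(?<![\w.])[-+]?\d+(?:\.\d+)?(?![\w.])', text_num)]
    return {g for g in gt_list if ((s := str(g).strip()) and s.lower() in text.lower()) or (re.fullmatch(r'[-+]?\d+(?:\.\d+)?', str(g).strip().replace(',', '')) and any(abs(float(str(g).strip().replace(',', '')) - n) < 1e-6 for n in nums))}

print(validate_gt_in_text("The rate is 7.21 CNY", ["7.21", "cny", "8.0"]))
print(validate_gt_in_text("total: 1,000 items", ["1000"]))  # numeric match despite the comma
```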
def calculate_rewards(prompts, completions, gt_batch, tools_batch, num_gen, reward_model=None, device="cuda", turn_outputs_batch=None, unfinished_batch=None):
rewards = torch.zeros(len(completions), device=device)
for idx, response in enumerate(completions):
reward, answer = 0.0, response
sample_idx = idx // num_gen
tools = tools_batch[sample_idx]
turn_outputs = turn_outputs_batch[idx] if turn_outputs_batch is not None else [response]
unfinished = unfinished_batch[idx] if unfinished_batch is not None else False
turn_answers = [turn.split('</think>', 1)[-1].strip() if '</think>' in turn else turn.strip() for turn in turn_outputs]
answer = turn_answers[-1] if turn_answers else response.strip()
valid_names = {t['function']['name'] for t in tools} if tools else set()
tool_calls = []
for turn_answer in turn_answers: tool_calls.extend(parse_tool_calls(turn_answer))  # parse tool calls
reward -= 0.5 * sum(abs(turn.count('<tool_call>') - turn.count('</tool_call>')) for turn in turn_answers)  # penalize unbalanced tags
# -------- No tool calls: format + reward-model scoring --------
if not tool_calls:
reward += 0.5 if 5 <= len(response.strip()) <= 800 else -0.5  # length score
if '</think>' in response:
think, answer = response.split('</think>', 1)
reward += 1.0 if 20 <= len(think.strip()) <= 300 else -0.5  # thinking-length score
reward += 0.25 if response.count('</think>') == 1 else -0.25  # thinking-closure score
answer = answer.strip()
if reward_model is not None:
prompt = prompts[sample_idx]
pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
matches = re.findall(pattern, prompt, re.DOTALL)
messages = [{"role": role, "content": content.strip()} for role, content in matches]
score = reward_model.get_score(messages, answer)
reward += score  # reward-model score
reward -= rep_penalty(answer)
rewards[idx] = max(min(reward, 3.0), -3.0)  # clip total score
# -------- Tool calls present: execution-result reward --------
else:
gt = gt_batch[sample_idx]
valid_call_count = 0
for tool_call in tool_calls:
name, raw = tool_call.get("name", ""), tool_call.get("arguments", {})
if isinstance(raw, str):
try: raw = json.loads(raw)
                except json.JSONDecodeError: raw = {}
check = CHECK_ARGS.get(name)
valid_call_count += int(bool(name in valid_names and check and check(raw)))
            tool_gap = abs(valid_call_count - len(gt)) + max(0, len(tool_calls) - valid_call_count)  # gap vs. expected tool count
            reward += 0.5 if tool_gap == 0 else -0.5 * tool_gap  # tool-alignment score
final_text = "" if unfinished else (answer.split('</tool_call>')[-1] if '</tool_call>' in answer else answer)
verified = validate_gt_in_text(final_text, gt) if gt else set()
            if gt: reward += 2.5 * len(verified) / len(gt)  # ground-truth score
            if unfinished: reward -= 0.5  # penalty for unfinished rollout
reward -= rep_penalty(final_text if final_text else answer)
            rewards[idx] = max(min(reward, 3.0), -3.0)  # clip total reward
return rewards
# ================================ Tools & Reward: End ================================
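The raw scores returned by `calculate_rewards` are normalized per prompt group before they enter the policy loss (the `grouped_rewards` block in `rl_train_epoch`). A minimal stdlib-only sketch of that group-relative normalization, with `torch.std(unbiased=False)` replaced by a hand-rolled population standard deviation:

```python
import math

def group_advantages(rewards, num_generations, eps=1e-4):
    """Normalize rewards within each group of num_generations completions,
    mirroring the grouped_rewards logic in rl_train_epoch."""
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start:start + num_generations]
        mean = sum(group) / len(group)
        # population std, matching torch.std(dim=1, unbiased=False)
        std = math.sqrt(sum((r - mean) ** 2 for r in group) / len(group))
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages

# two prompts, two generations each: within a group, the better completion
# gets a positive advantage; a degenerate group (equal rewards) gets zeros
adv = group_advantages([1.0, 3.0, -2.0, -2.0], num_generations=2)
```

The `eps` term keeps degenerate groups (identical rewards) from dividing by zero; their advantages collapse to zero and contribute no gradient.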
def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model=None, start_step=0, wandb=None, use_sglang=False):
last_step = start_step
for step, batch in enumerate(loader, start=start_step + 1):
messages_batch = batch['messages']
tools_batch = batch['tools']
gt_batch = batch['gt']
last_step = step
with torch.no_grad():
completions, contexts, prompt_ids_batch, response_ids_batch, response_masks_batch, response_old_logps_batch, turn_outputs_batch, unfinished_batch = rollout_batch(rollout_engine, tokenizer, messages_batch, tools_batch, args.num_generations, max_turns=3, max_new_tokens=args.max_gen_len, thinking_ratio=args.thinking_ratio, device=args.device)
prompts = [tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True, tools=t) for m, t in zip(messages_batch, tools_batch)]
packed_samples = []
for p, r, m, old_lp in zip(prompt_ids_batch, response_ids_batch, response_masks_batch, response_old_logps_batch):
ids = p + r
mask = [0] * len(p) + m
old_logps = [0.0] * max(len(p) - 1, 0) + old_lp
if len(ids) > args.max_total_len:
ids = ids[-args.max_total_len:]
mask = mask[-args.max_total_len:]
old_logps = old_logps[-(len(ids) - 1):]
prompt_len = next((i for i, v in enumerate(mask) if v == 1), len(mask))
packed_samples.append((ids, mask, prompt_len, old_logps))
seq_lens = torch.tensor([len(ids) for ids, _, _, _ in packed_samples], device=args.device)
max_len = seq_lens.max().item()
input_ids = torch.tensor([ids + [tokenizer.pad_token_id] * (max_len - len(ids)) for ids, _, _, _ in packed_samples], device=args.device)
prompt_lens = torch.tensor([prompt_len for _, _, prompt_len, _ in packed_samples], device=args.device)
full_response_masks = torch.tensor([mask + [0] * (max_len - len(mask)) for _, mask, _, _ in packed_samples], device=args.device, dtype=torch.float32)
old_per_token_logps = torch.tensor([old_logps + [0.0] * ((max_len - 1) - len(old_logps)) for _, _, _, old_logps in packed_samples], device=args.device, dtype=torch.float32)
full_mask = (input_ids != tokenizer.pad_token_id).long()
model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
with autocast_ctx:
res = model_unwrapped(input_ids, attention_mask=full_mask)
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
logits = res.logits[:, :-1, :]
per_token_logps = F.log_softmax(logits, dim=-1).gather(2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
with torch.no_grad():
ref_per_token_logps = compute_per_token_logps(ref_model, input_ids, input_ids.size(1) - 1, attention_mask=full_mask)
completion_mask = full_response_masks[:, 1:]
is_eos = (input_ids[:, 1:] == tokenizer.eos_token_id) & completion_mask.bool()
eos_idx = torch.full((completion_mask.size(0),), completion_mask.size(1) - 1, device=args.device, dtype=torch.long)
has_eos = is_eos.any(dim=1)
eos_idx[has_eos] = is_eos.int().argmax(dim=1)[has_eos]
pos = torch.arange(completion_mask.size(1), device=args.device).unsqueeze(0)
completion_mask = completion_mask * (pos <= eos_idx.unsqueeze(1)).float()
token_counts = completion_mask.sum(dim=1)
valid_rows = token_counts > 0
rewards = calculate_rewards(prompts, completions, gt_batch, tools_batch, args.num_generations, reward_model, device=args.device, turn_outputs_batch=turn_outputs_batch, unfinished_batch=unfinished_batch)
if args.debug_mode and is_main_process() and step % args.debug_interval == 0:
for i in range(len(messages_batch)):
Logger(f"[DEBUG] step={step}, gt[{i}]: {repr(gt_batch[i])}")
Logger('-'*100)
for j in range(args.num_generations):
idx = i * args.num_generations + j
plen, slen = prompt_lens[idx].item(), seq_lens[idx].item()
Logger(f"{'=' * 30} [DEBUG] gen[{i}][{j}] CONTEXT_BEGIN {'=' * 30}")
Logger(contexts[idx])
Logger(f"{'=' * 31} [DEBUG] gen[{i}][{j}] CONTEXT_END {'=' * 31}")
Logger(f"[DEBUG] gen[{i}][{j}] prompt_len={plen}, seq_len={slen}")
tokens = input_ids[idx, plen:slen].tolist()
text = tokenizer.decode(tokens, skip_special_tokens=False)
Logger(f"{'=' * 28} [DEBUG] gen[{i}][{j}] COMPLETION_BEGIN [{plen}:{slen}] {'=' * 28}")
Logger(text)
Logger(f"{'=' * 29} [DEBUG] gen[{i}][{j}] COMPLETION_END {'=' * 29}")
Logger(f"[DEBUG] gen[{i}][{j}] reward={rewards[idx].item():.4f}")
Logger('='*100)
grouped_rewards = rewards.view(-1, args.num_generations)
mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations)
std_r = grouped_rewards.std(dim=1, unbiased=False).repeat_interleave(args.num_generations)
advantages = (rewards - mean_r) / (std_r + 1e-4)
kl_div = ref_per_token_logps - per_token_logps
per_token_kl = torch.exp(kl_div) - kl_div - 1
ratio = torch.exp(per_token_logps - old_per_token_logps)
if args.loss_type == "cispo":
clamped_ratio = torch.clamp(ratio, max=args.epsilon_high).detach()
per_token_loss = -(clamped_ratio * advantages.unsqueeze(1) * per_token_logps - args.beta * per_token_kl)
else:
clipped_ratio = torch.clamp(ratio, 1 - args.epsilon, 1 + args.epsilon)
per_token_loss1 = ratio * advantages.unsqueeze(1)
per_token_loss2 = clipped_ratio * advantages.unsqueeze(1)
per_token_loss = -(torch.min(per_token_loss1, per_token_loss2) - args.beta * per_token_kl)
policy_loss = (((per_token_loss * completion_mask).sum(dim=1)[valid_rows] / token_counts[valid_rows].clamp(min=1)).mean()
if valid_rows.any() else per_token_loss.sum() * 0.0)
loss = (policy_loss + aux_loss) / args.accumulation_steps
loss.backward()
if step % args.accumulation_steps == 0:
if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step(); scheduler.step(); optimizer.zero_grad()
if step % args.log_interval == 0 or step == iters:
pl = loss.item() * args.accumulation_steps
ar = rewards.mean().item()
al = token_counts.float().mean().item()
kl = ((ref_per_token_logps - per_token_logps) * completion_mask).sum().item() / max(token_counts.sum().item(), 1)
gs = grouped_rewards.std(dim=1, unbiased=False).mean().item()
am, ast = advantages.mean().item(), advantages.std().item()
lr = optimizer.param_groups[0]['lr']
Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}), Reward:{ar:.4f}, KL:{kl:.4f}, GrpStd:{gs:.4f}, AdvStd:{ast:.4f}, Loss:{pl:.4f}, AvgLen:{al:.2f}, AdvMean:{am:.4f}, LR:{lr:.8f}')
if wandb and is_main_process():
wandb.log({"reward":ar,"kl_ref":kl,"group_reward_std":gs,"advantages_std":ast,"policy_loss":pl,"avg_response_len":al,"advantages_mean":am,"learning_rate":lr})
if (step % args.save_interval == 0 or step == iters) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
model.train()
del state_dict
if step % args.save_interval == 0 or step == iters: rollout_engine.update_policy(model)
del per_token_logps, ref_per_token_logps
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
if last_step > start_step and last_step % args.accumulation_steps != 0:
if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step(); scheduler.step(); optimizer.zero_grad()
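The two objectives selected by `--loss_type` can be illustrated on a single token. This is a scalar sketch, not the batched torch implementation: the importance ratio is `exp(logp - old_logp)`; CISPO treats the clamped ratio as a detached constant so the gradient flows only through `logp`, while GRPO takes the pessimistic minimum of the unclipped and clipped surrogates.

```python
import math

def per_token_objective(logp, old_logp, advantage, kl, loss_type="cispo",
                        beta=0.1, epsilon=0.2, epsilon_high=5.0):
    """Scalar version of the per-token loss in rl_train_epoch."""
    ratio = math.exp(logp - old_logp)  # importance ratio pi_theta / pi_old
    if loss_type == "cispo":
        clamped = min(ratio, epsilon_high)  # .detach() in the torch version
        return -(clamped * advantage * logp - beta * kl)
    # GRPO / PPO-style clipped surrogate
    clipped = max(min(ratio, 1 + epsilon), 1 - epsilon)
    return -(min(ratio * advantage, clipped * advantage) - beta * kl)

# with ratio == 1 and no KL penalty, GRPO's surrogate is just the advantage...
grpo = per_token_objective(-1.0, -1.0, 2.0, 0.0, loss_type="grpo")
# ...while CISPO scales the token's own log-probability by the advantage
cispo = per_token_objective(-1.0, -1.0, 2.0, 0.0, loss_type="cispo")
```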
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Agent RL")
    parser.add_argument("--save_dir", type=str, default="../out", help="directory for saved models")
    parser.add_argument('--save_weight', default='agent', type=str, help="name of the saved weight")
    parser.add_argument("--epochs", type=int, default=1, help="number of training epochs")
    parser.add_argument("--batch_size", type=int, default=2, help="batch size")
    parser.add_argument("--learning_rate", type=float, default=3e-7, help="learning rate")
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="training device")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="dtype: bfloat16/float16")
    parser.add_argument("--num_workers", type=int, default=8, help="number of data-loading workers")
    parser.add_argument("--accumulation_steps", type=int, default=1, help="gradient accumulation steps")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="gradient clipping threshold")
    parser.add_argument("--log_interval", type=int, default=1, help="logging interval")
    parser.add_argument("--save_interval", type=int, default=10, help="model save interval")
    parser.add_argument('--hidden_size', default=768, type=int, help="model hidden size")
    parser.add_argument('--num_hidden_layers', default=8, type=int, help="number of model layers")
    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="whether to use MoE")
    parser.add_argument('--max_seq_len', default=1024, type=int, help="maximum sequence length")
    parser.add_argument("--max_gen_len", type=int, default=768, help="maximum generation length per call")
    parser.add_argument("--max_total_len", type=int, default=2500, help="upper bound on total sequence length on the training side")
    parser.add_argument("--data_path", type=str, default="../dataset/agent_rl.jsonl", help="training data path")
    parser.add_argument("--num_generations", type=int, default=4, help="number of generations per prompt")
    parser.add_argument("--beta", type=float, default=0.1, help="KL-divergence penalty coefficient")
    parser.add_argument("--loss_type", type=str, default="cispo", choices=["grpo", "cispo"], help="loss type")
    parser.add_argument("--epsilon", type=float, default=0.2, help="PPO clip epsilon for GRPO")
    parser.add_argument("--epsilon_high", type=float, default=5.0, help="upper bound on epsilon")
    parser.add_argument('--from_weight', default='full_sft', type=str, help="name of the pretrained weight to load")
    parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="whether to resume from checkpoint")
    parser.add_argument("--use_wandb", action="store_true", help="whether to log with wandb")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-Agent-RL", help="wandb project name")
    parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="whether to use torch.compile")
    parser.add_argument("--debug_mode", action="store_true", help="debug mode")
    parser.add_argument("--debug_interval", type=int, default=20, help="debug logging interval")
    parser.add_argument("--thinking_ratio", type=float, default=0.1, help="probability of enabling thinking (0.0~1.0)")
    parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="reward model path")
    parser.add_argument("--rollout_engine", type=str, default="torch", choices=["torch", "sglang"], help="rollout engine type")
    parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8998", help="SGLang server URL")
    parser.add_argument("--sglang_model_path", type=str, default="../model", help="SGLang tokenizer path")
    parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_agent", help="SGLang shared-storage path")
args = parser.parse_args()
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb.init(project=args.wandb_project, name=f"Agent-RL-E{args.epochs}-B{args.batch_size}-LR{args.learning_rate}", id=wandb_id, resume=resume)
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
ref_model, _ = init_model(lm_config, args.from_weight, device=args.device)
ref_model = ref_model.eval().requires_grad_(False)
reward_model = LMForRewardModel(args.reward_model_path, device=args.device, dtype=torch.float16)
Logger(f'Loaded reward model from {args.reward_model_path}')
    # Rollout engine
rollout_engine = create_rollout_engine(
engine_type=args.rollout_engine,
policy_model=model,
tokenizer=tokenizer,
device=args.device,
autocast_ctx=autocast_ctx,
sglang_base_url=args.sglang_base_url,
sglang_model_path=args.sglang_model_path,
sglang_shared_path=args.sglang_shared_path,
)
train_ds = AgentRLDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
def collate_fn(batch): return {'messages': [b['messages'] for b in batch], 'tools': [b['tools'] for b in batch], 'gt': [b['gt'] for b in batch]}
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler, collate_fn=collate_fn)
iters = len(loader_for_count)
total_optimizer_steps = math.ceil(iters / args.accumulation_steps) * args.epochs
scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scheduler.load_state_dict(ckp_data['scheduler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
rollout_engine.update_policy(model)
if dist.is_initialized():
model = DistributedDataParallel(model, device_ids=[local_rank])
rollout_engine.update_policy(model)
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True, collate_fn=collate_fn)
if skip > 0:
Logger(f'Epoch [{epoch+1}/{args.epochs}]: skip {start_step} steps')
rl_train_epoch(epoch, loader, len(loader) + skip, rollout_engine, ref_model, reward_model, start_step, wandb, use_sglang = (args.rollout_engine == "sglang"))
else:
rl_train_epoch(epoch, loader, len(loader), rollout_engine, ref_model, reward_model, 0, wandb, use_sglang = (args.rollout_engine == "sglang"))
if dist.is_initialized(): dist.destroy_process_group()

@@ -0,0 +1,245 @@
import os
import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import warnings
import torch
import torch.nn.functional as F
import torch.distributed as dist
from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import SFTDataset
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
warnings.filterwarnings('ignore')
def distillation_loss(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
with torch.no_grad():
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1).detach()
student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
kl = F.kl_div(
student_log_probs,
teacher_probs,
reduction=reduction
)
return (temperature ** 2) * kl
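On a single token distribution, `distillation_loss` above collapses to a temperature-softened KL(teacher || student) scaled by T². A stdlib-only sketch, assuming the `batchmean` reduction reduces to one plain KL term here:

```python
import math

def soft_kd_loss(student_logits, teacher_logits, temperature=1.5):
    """T^2 * KL(teacher || student) on temperature-softened distributions,
    the single-token analogue of distillation_loss."""
    def softmax(xs, t):
        exps = [math.exp(x / t) for x in xs]
        z = sum(exps)
        return [e / z for e in exps]
    p = softmax(teacher_logits, temperature)  # teacher probabilities (targets)
    q = softmax(student_logits, temperature)  # student probabilities
    kl = sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q))
    return temperature ** 2 * kl

# a perfectly matching student incurs zero loss; a mismatched one, positive
zero = soft_kd_loss([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])
pos = soft_kd_loss([3.0, 2.0, 1.0], [1.0, 2.0, 3.0])
```

The T² factor compensates for the gradient shrinkage introduced by softening, which is why `distillation_loss` multiplies by `temperature ** 2`.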
def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_step=0, wandb=None, alpha=0.0, temperature=1.0):
start_time = time.time()
last_step = start_step
if teacher_model is not None:
teacher_model.eval()
teacher_model.requires_grad_(False)
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
last_step = step
input_ids = input_ids.to(args.device)
labels = labels.to(args.device)
loss_mask = (labels[..., 1:] != -100).float()
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
        # Forward pass (student model)
with autocast_ctx:
res = model(input_ids)
student_logits = res.logits[..., :-1, :].contiguous()
        # Teacher forward pass: eval mode & no_grad only
if teacher_model is not None:
with torch.no_grad():
teacher_logits = teacher_model(input_ids).logits[..., :-1, :].contiguous()
vocab_size_student = student_logits.size(-1)
teacher_logits = teacher_logits[..., :vocab_size_student]
        # ========== Compute losses ==========
# 1) Ground-Truth CE Loss
shift_labels = labels[..., 1:].contiguous()
loss_mask_flat = loss_mask.view(-1)
ce_loss = F.cross_entropy(
student_logits.view(-1, student_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100,
reduction='none'
)
ce_loss_raw = torch.sum(ce_loss * loss_mask_flat) / (loss_mask_flat.sum() + 1e-8)
if lm_config_student.use_moe: ce_loss = ce_loss_raw + res.aux_loss
else: ce_loss = ce_loss_raw
# 2) Distillation Loss
if teacher_model is not None:
distill_loss = distillation_loss(
student_logits.view(-1, student_logits.size(-1))[loss_mask_flat == 1],
teacher_logits.view(-1, teacher_logits.size(-1))[loss_mask_flat == 1],
temperature=temperature
)
else:
distill_loss = torch.tensor(0.0, device=args.device)
        # 3) Total loss = alpha * CE + (1 - alpha) * Distill
loss = (alpha * ce_loss + (1 - alpha) * distill_loss) / args.accumulation_steps
scaler.scale(loss).backward()
if step % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0 or step == iters:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
current_ce_loss = ce_loss_raw.item()
current_aux_loss = res.aux_loss.item() if lm_config_student.use_moe else 0.0
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / max(step - start_step, 1) * (iters - step) // 60
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, ce: {current_ce_loss:.4f}, aux_loss: {current_aux_loss:.4f}, distill: {distill_loss.item():.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')
if wandb:
wandb.log({
"loss": current_loss,
"ce_loss": current_ce_loss,
"aux_loss": current_aux_loss,
"distill_loss": distill_loss.item() if teacher_model is not None else 0.0,
"learning_rate": current_lr,
"epoch_time": eta_min
})
if (step % args.save_interval == 0 or step == iters) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config_student.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config_student.hidden_size}{moe_suffix}.pth'
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config_student, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
del state_dict
del input_ids, labels, loss_mask, res, student_logits, ce_loss, distill_loss, loss
if last_step > start_step and last_step % args.accumulation_steps != 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if __name__ == "__main__":
    # Example setup: distill a dense student from an MoE teacher; a larger teacher_hidden_size model can likewise distill a smaller student_hidden_size one
parser = argparse.ArgumentParser(description="MiniMind Knowledge Distillation")
    parser.add_argument("--save_dir", type=str, default="../out", help="directory for saved models")
    parser.add_argument('--save_weight', default='full_dist', type=str, help="prefix for the saved weight")
    parser.add_argument("--epochs", type=int, default=6, help="number of training epochs")
    parser.add_argument("--batch_size", type=int, default=32, help="batch size")
    parser.add_argument("--learning_rate", type=float, default=5e-6, help="initial learning rate")
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="training device")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="mixed-precision dtype")
    parser.add_argument("--num_workers", type=int, default=8, help="number of data-loading workers")
    parser.add_argument("--accumulation_steps", type=int, default=1, help="gradient accumulation steps")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="gradient clipping threshold")
    parser.add_argument("--log_interval", type=int, default=100, help="logging interval")
    parser.add_argument("--save_interval", type=int, default=100, help="model save interval")
    parser.add_argument("--max_seq_len", type=int, default=340, help="maximum truncation length for training (Chinese: 1 token ≈ 1.5~1.7 characters)")
    parser.add_argument("--data_path", type=str, default="../dataset/sft_t2t_mini.jsonl", help="training data path")
    parser.add_argument('--student_hidden_size', default=768, type=int, help="student model hidden size")
    parser.add_argument('--student_num_layers', default=8, type=int, help="number of student model layers")
    parser.add_argument('--teacher_hidden_size', default=768, type=int, help="teacher model hidden size")
    parser.add_argument('--teacher_num_layers', default=8, type=int, help="number of teacher model layers")
    parser.add_argument('--student_use_moe', default=0, type=int, choices=[0, 1], help="whether the student uses MoE (0=no, 1=yes)")
    parser.add_argument('--teacher_use_moe', default=1, type=int, choices=[0, 1], help="whether the teacher uses MoE (0=no, 1=yes)")
    parser.add_argument('--from_student_weight', default='full_sft', type=str, help="weight the student model starts from")
    parser.add_argument('--from_teacher_weight', default='full_sft', type=str, help="weight the teacher model starts from")
    parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="auto-detect & resume training (0=no, 1=yes)")
    parser.add_argument('--alpha', default=0.5, type=float, help="CE loss weight; total loss = alpha*CE + (1-alpha)*KL")
    parser.add_argument('--temperature', default=1.5, type=float, help="distillation temperature (recommended range 1.0-2.0)")
    parser.add_argument("--use_wandb", action="store_true", help="whether to use wandb")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-Distillation", help="wandb project name")
    parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="whether to use torch.compile (0=no, 1=yes)")
args = parser.parse_args()
    # ========== 1. Initialize environment and random seed ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
    # ========== 2. Set up directories, model configs, and check for checkpoints ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config_student = MiniMindConfig(hidden_size=args.student_hidden_size, num_hidden_layers=args.student_num_layers, use_moe=bool(args.student_use_moe))
lm_config_teacher = MiniMindConfig(hidden_size=args.teacher_hidden_size, num_hidden_layers=args.teacher_num_layers, use_moe=bool(args.teacher_use_moe))
ckp_data = lm_checkpoint(lm_config_student, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
    # ========== 3. Set up mixed precision ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
    # ========== 4. Configure wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-Distill-S{args.student_hidden_size}T{args.teacher_hidden_size}-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
    # ========== 5. Define student and teacher models ==========
model, tokenizer = init_model(lm_config_student, args.from_student_weight, device=args.device)
    Logger(f'Student model total parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
teacher_model, _ = init_model(lm_config_teacher, args.from_teacher_weight, device=args.device)
teacher_model.eval()
teacher_model.requires_grad_(False)
    Logger(f'Teacher model total parameters: {sum(p.numel() for p in teacher_model.parameters()) / 1e6:.3f} M')
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
    # ========== 6. Restore state from checkpoint ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scaler.load_state_dict(ckp_data['scaler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
    # ========== 7. Compile and distributed wrapping ==========
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
if dist.is_initialized():
model = DistributedDataParallel(model, device_ids=[local_rank])
    # ========== 8. Start training ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
if skip > 0:
            Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping first {start_step} steps, starting from step {start_step + 1}')
train_epoch(epoch, loader, len(loader) + skip, teacher_model, lm_config_student, start_step, wandb, args.alpha, args.temperature)
else:
train_epoch(epoch, loader, len(loader), teacher_model, lm_config_student, 0, wandb, args.alpha, args.temperature)
    # ========== 9. Clean up distributed processes ==========
if dist.is_initialized(): dist.destroy_process_group()

225
trainer/train_dpo.py Normal file
@@ -0,0 +1,225 @@
import os
import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import warnings
import torch
import torch.nn.functional as F
import torch.distributed as dist
from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import DPODataset
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
warnings.filterwarnings('ignore')
def logits_to_log_probs(logits, labels):
# logits shape: (batch_size, seq_len, vocab_size)
# labels shape: (batch_size, seq_len)
# log_probs shape: (batch_size, seq_len)
log_probs = F.log_softmax(logits, dim=2)
log_probs_per_token = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(-1)
return log_probs_per_token
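`logits_to_log_probs` is a log-softmax followed by a gather of each label token's log-probability. A stdlib-only sketch for a single sequence (hypothetical helper name, lists standing in for tensors):

```python
import math

def per_token_label_logps(logits_row, labels_row):
    """For one sequence: log-softmax each position, then pick the label
    token's log-probability (the list analogue of torch.gather in
    logits_to_log_probs)."""
    out = []
    for logits, label in zip(logits_row, labels_row):
        log_z = math.log(sum(math.exp(x) for x in logits))  # log-normalizer
        out.append(logits[label] - log_z)
    return out

# position 0 has uniform logits, so its label gets log(1/2); position 1
# labels the less likely token, so its log-probability is lower still
lps = per_token_label_logps([[0.0, 0.0], [1.0, 0.0]], [0, 1])
```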
def dpo_loss(ref_log_probs, policy_log_probs, mask, beta):
    # ref_log_probs and policy_log_probs both have shape: (batch_size, seq_len)
ref_log_probs = (ref_log_probs * mask).sum(dim=1)
policy_log_probs = (policy_log_probs * mask).sum(dim=1)
    # split into chosen and rejected halves
batch_size = ref_log_probs.shape[0]
chosen_ref_log_probs = ref_log_probs[:batch_size // 2]
reject_ref_log_probs = ref_log_probs[batch_size // 2:]
chosen_policy_log_probs = policy_log_probs[:batch_size // 2]
reject_policy_log_probs = policy_log_probs[batch_size // 2:]
pi_logratios = chosen_policy_log_probs - reject_policy_log_probs
ref_logratios = chosen_ref_log_probs - reject_ref_log_probs
logits = pi_logratios - ref_logratios
loss = -F.logsigmoid(beta * logits)
return loss.mean()
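For a batch holding one chosen/rejected pair, `dpo_loss` reduces to the standard DPO objective: -log sigmoid(beta * ((pi_chosen - pi_rejected) - (ref_chosen - ref_rejected))). A stdlib-only sketch on sequence log-probabilities that are already summed over the mask:

```python
import math

def dpo_pair_loss(chosen_policy_lp, reject_policy_lp,
                  chosen_ref_lp, reject_ref_lp, beta=0.15):
    """Single-pair version of dpo_loss above."""
    pi_logratio = chosen_policy_lp - reject_policy_lp
    ref_logratio = chosen_ref_lp - reject_ref_lp
    logits = pi_logratio - ref_logratio
    return -math.log(1.0 / (1.0 + math.exp(-beta * logits)))  # -log(sigmoid)

# when the policy prefers chosen more strongly than the reference does, the
# loss drops below log(2); when they agree exactly, it sits at log(2)
better = dpo_pair_loss(-5.0, -9.0, -6.0, -7.0)
neutral = dpo_pair_loss(-5.0, -7.0, -6.0, -8.0)
```

Raising `beta` sharpens how hard the policy is pushed toward the preference margin; the `aux_loss` added in `train_epoch` is separate MoE load balancing.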
def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=None, beta=0.1):
start_time = time.time()
last_step = start_step
for step, batch in enumerate(loader, start=start_step + 1):
last_step = step
x_chosen = batch['x_chosen'].to(args.device)
x_rejected = batch['x_rejected'].to(args.device)
y_chosen = batch['y_chosen'].to(args.device)
y_rejected = batch['y_rejected'].to(args.device)
mask_chosen = batch['mask_chosen'].to(args.device)
mask_rejected = batch['mask_rejected'].to(args.device)
x = torch.cat([x_chosen, x_rejected], dim=0)
y = torch.cat([y_chosen, y_rejected], dim=0)
mask = torch.cat([mask_chosen, mask_rejected], dim=0)
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with autocast_ctx:
with torch.no_grad():
ref_outputs = ref_model(x)
ref_logits = ref_outputs.logits
ref_log_probs = logits_to_log_probs(ref_logits, y)
outputs = model(x)
logits = outputs.logits
policy_log_probs = logits_to_log_probs(logits, y)
dpo_loss_val = dpo_loss(ref_log_probs, policy_log_probs, mask, beta=beta)
loss = dpo_loss_val + outputs.aux_loss
loss = loss / args.accumulation_steps
scaler.scale(loss).backward()
if step % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0 or step == iters:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
current_dpo_loss = dpo_loss_val.item()
current_aux_loss = outputs.aux_loss.item()
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / max(step - start_step, 1) * (iters - step) // 60
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, dpo_loss: {current_dpo_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')
if wandb: wandb.log({"loss": current_loss, "dpo_loss": current_dpo_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
if (step % args.save_interval == 0 or step == iters) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
del state_dict
del x_chosen, x_rejected, y_chosen, y_rejected, mask_chosen, mask_rejected, x, y, mask
del ref_outputs, ref_logits, ref_log_probs, outputs, logits, policy_log_probs, loss
if last_step > start_step and last_step % args.accumulation_steps != 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind DPO (Direct Preference Optimization)")
    parser.add_argument("--save_dir", type=str, default="../out", help="directory for saved models")
    parser.add_argument('--save_weight', default='dpo', type=str, help="prefix for the saved weight")
    parser.add_argument("--epochs", type=int, default=1, help="number of training epochs")
    parser.add_argument("--batch_size", type=int, default=4, help="batch size")
    parser.add_argument("--learning_rate", type=float, default=4e-8, help="initial learning rate (<=5e-8 recommended to avoid catastrophic forgetting)")
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="training device")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="mixed-precision dtype")
    parser.add_argument("--num_workers", type=int, default=8, help="number of data-loading workers")
    parser.add_argument("--accumulation_steps", type=int, default=1, help="gradient accumulation steps")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="gradient clipping threshold")
    parser.add_argument("--log_interval", type=int, default=100, help="logging interval")
    parser.add_argument("--save_interval", type=int, default=100, help="model save interval")
    parser.add_argument('--hidden_size', default=768, type=int, help="hidden size")
    parser.add_argument('--num_hidden_layers', default=8, type=int, help="number of hidden layers")
    parser.add_argument('--max_seq_len', default=1024, type=int, help="maximum truncation length for training (Chinese: 1 token ≈ 1.5~1.7 characters)")
    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="whether to use the MoE architecture (0=no, 1=yes)")
    parser.add_argument("--data_path", type=str, default="../dataset/dpo.jsonl", help="DPO training data path")
    parser.add_argument('--from_weight', default='full_sft', type=str, help="weight to start training from")
    parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="auto-detect & resume training (0=no, 1=yes)")
    parser.add_argument('--beta', default=0.15, type=float, help="beta parameter in DPO")
    parser.add_argument("--use_wandb", action="store_true", help="whether to use wandb")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-DPO", help="wandb project name")
    parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="whether to use torch.compile (0=no, 1=yes)")
args = parser.parse_args()
# ========== 1. Initialize environment and random seed ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. Configure directories, model config, check checkpoint ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. Set up mixed precision ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. Configure wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. Define policy and reference models ==========
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
Logger(f'Total policy model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
# Initialize the reference model (ref_model, frozen)
ref_model, _ = init_model(lm_config, args.from_weight, device=args.device)
ref_model.eval()
ref_model.requires_grad_(False)
Logger(f'Total reference model parameters: {sum(p.numel() for p in ref_model.parameters()) / 1e6:.3f} M')
train_ds = DPODataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
# ========== 6. Restore state from checkpoint ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scaler.load_state_dict(ckp_data['scaler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. Compilation and distributed wrapping ==========
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
if dist.is_initialized():
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. Start training ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
if skip > 0:
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
train_epoch(epoch, loader, len(loader) + skip, ref_model, lm_config, start_step, wandb, args.beta)
else:
train_epoch(epoch, loader, len(loader), ref_model, lm_config, 0, wandb, args.beta)
# ========== 9. Clean up distributed processes ==========
if dist.is_initialized(): dist.destroy_process_group()
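The trainer above logs `dpo_loss_val` produced inside `train_epoch`. For reference, the objective it optimizes can be sketched in isolation; this is a minimal version assuming per-sequence log-probabilities already summed over response tokens, and the function and argument names here are illustrative rather than the trainer's actual internals:

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.15):
    """Direct Preference Optimization loss on per-sequence log-probs [B].

    beta matches the script's --beta default (0.15); the frozen reference
    model anchors the policy so it cannot drift arbitrarily far.
    """
    # Implicit rewards: log-ratio of policy vs. frozen reference
    chosen_rewards = policy_chosen_logps - ref_chosen_logps
    rejected_rewards = policy_rejected_logps - ref_rejected_logps
    # Maximize the margin between chosen and rejected implicit rewards
    return -F.logsigmoid(beta * (chosen_rewards - rejected_rewards)).mean()

# Toy check: a policy that prefers the chosen answer yields a lower loss
lo = dpo_loss(torch.tensor([-1.0]), torch.tensor([-3.0]),
              torch.tensor([-2.0]), torch.tensor([-2.0]))
hi = dpo_loss(torch.tensor([-3.0]), torch.tensor([-1.0]),
              torch.tensor([-2.0]), torch.tensor([-2.0]))
```

This also shows why the reference-model forward pass runs under `torch.no_grad()` in the loop above: only the policy terms need gradients.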

trainer/train_full_sft.py (new file, 170 lines)
@@ -0,0 +1,170 @@
import os
import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import warnings
import torch
import torch.distributed as dist
from contextlib import nullcontext
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import SFTDataset
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
warnings.filterwarnings('ignore')
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
start_time = time.time()
last_step = start_step
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
input_ids = input_ids.to(args.device)
labels = labels.to(args.device)
last_step = step
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with autocast_ctx:
res = model(input_ids, labels=labels)
loss = res.loss + res.aux_loss
loss = loss / args.accumulation_steps
scaler.scale(loss).backward()
if step % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0 or step == iters:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
current_logits_loss = current_loss - current_aux_loss
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / max(step - start_step, 1) * (iters - step) // 60
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
if (step % args.save_interval == 0 or step == iters) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler)
model.train()
del state_dict
del input_ids, labels, res, loss
if last_step > start_step and last_step % args.accumulation_steps != 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Full SFT")
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
parser.add_argument('--save_weight', default='full_sft', type=str, help="保存权重的前缀名")
parser.add_argument("--epochs", type=int, default=2, help="训练轮数")
parser.add_argument("--batch_size", type=int, default=16, help="batch size")
parser.add_argument("--learning_rate", type=float, default=1e-5, help="初始学习率")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=768, type=int, help="训练的最大截断长度中文1token≈1.5~1.7字符)")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument("--data_path", type=str, default="../dataset/sft_t2t_mini.jsonl", help="训练数据路径")
parser.add_argument('--from_weight', default='pretrain', type=str, help="基于哪个权重训练为none则不基于任何权重训练")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速0=否1=是)")
args = parser.parse_args()
# ========== 1. Initialize environment and random seed ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. Configure directories, model config, check checkpoint ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. Set up mixed precision ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. Configure wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. Define model, data, and optimizer ==========
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
# ========== 6. Restore state from checkpoint ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scaler.load_state_dict(ckp_data['scaler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. Compilation and distributed wrapping ==========
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
if dist.is_initialized():
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. Start training ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
if skip > 0:
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
train_epoch(epoch, loader, len(loader) + skip, start_step, wandb)
else:
train_epoch(epoch, loader, len(loader), 0, wandb)
# ========== 9. Clean up distributed processes ==========
if dist.is_initialized(): dist.destroy_process_group()
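The `accumulation_steps` handling in `train_epoch` above — dividing the loss by the accumulation count, stepping the optimizer every N micro-batches, and flushing any leftover gradients after the loop — can be shown with a minimal sketch. The model and data here are toys, and the scaler/autocast machinery is omitted for clarity:

```python
import torch
from torch import nn, optim

model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
accum_steps = 4
data = [(torch.randn(2, 4), torch.randn(2, 1)) for _ in range(10)]

last_step = 0
for step, (x, y) in enumerate(data, start=1):
    # Scale so the accumulated gradient equals the full-batch average
    loss = nn.functional.mse_loss(model(x), y) / accum_steps
    loss.backward()  # gradients accumulate across micro-batches
    last_step = step
    if step % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

# 10 % 4 != 0: flush the 2 leftover micro-batches, as the trainer does,
# so partial gradients are not silently dropped at epoch end
if last_step % accum_steps != 0:
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
```

Without the tail flush, the final `len(data) % accum_steps` micro-batches would contribute backward work but never an optimizer update.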

trainer/train_grpo.py (new executable file, 330 lines)
@@ -0,0 +1,330 @@
import os
import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import math
import re
import gc
import warnings
import torch
import torch.nn.functional as F
import torch.distributed as dist
from transformers import AutoTokenizer
from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import AutoModel
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import RLAIFDataset
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model, LMForRewardModel
from trainer.rollout_engine import create_rollout_engine
warnings.filterwarnings('ignore')
def rep_penalty(text, n=3, cap=0.5):
toks = re.findall(r"\w+|[^\w\s]", text.lower())
grams = [tuple(toks[i:i + n]) for i in range(len(toks) - n + 1)]
return min(cap, (len(grams) - len(set(grams))) * cap * 2 / len(grams)) if grams else 0.0
def calculate_rewards(prompts, responses, reward_model):
rewards = torch.zeros(len(responses), device=args.device)
with torch.no_grad():
reward_model_scores = []
batch_size = len(prompts)
for i in range(batch_size):
for j in range(args.num_generations):
response_idx = i * args.num_generations + j
response = responses[response_idx]
prompt = prompts[i]
pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
matches = re.findall(pattern, prompt, re.DOTALL)
messages = [{"role": role, "content": content.strip()} for role, content in matches]
answer = response
rewards[response_idx] += 0.5 if 20 <= len(response.strip()) <= 800 else -0.5
if '</think>' in response:
thinking_content, answer_content = response.split('</think>', 1)
rewards[response_idx] += 1.0 if 20 <= len(thinking_content.strip()) <= 300 else -0.5
rewards[response_idx] += 0.25 if response.count('</think>') == 1 else -0.25
answer = answer_content.strip()
rewards[response_idx] -= rep_penalty(answer)
score = reward_model.get_score(messages, answer)
reward_model_scores.append(score)
reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
rewards += reward_model_scores
return rewards
def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model, start_step=0, wandb=None, use_sglang=False):
for step, batch in enumerate(loader, start=start_step + 1):
prompts = batch['prompt'] # list[str], length B
prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
padding_side="left", add_special_tokens=False).to(args.device)
if args.max_seq_len:
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
rollout_result = rollout_engine.rollout(
prompt_ids=prompt_inputs["input_ids"],
attention_mask=prompt_inputs["attention_mask"],
num_generations=args.num_generations,
max_new_tokens=args.max_gen_len,
temperature=0.8,
)
outputs = rollout_result.output_ids
completion_ids = rollout_result.completion_ids
completions = rollout_result.completions
old_per_token_logps = rollout_result.per_token_logps.to(args.device)
prompt_lens = rollout_result.prompt_lens.to(args.device)
full_mask = (outputs != tokenizer.pad_token_id).long()
logp_pos = prompt_lens.unsqueeze(1) - 1 + torch.arange(completion_ids.size(1), device=args.device).unsqueeze(0)
model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
with autocast_ctx:
res = model_unwrapped(outputs, attention_mask=full_mask)
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
per_token_logps = F.log_softmax(res.logits[:, :-1, :], dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1).gather(1, logp_pos)
with torch.no_grad():
ref_per_token_logps = F.log_softmax(ref_model(outputs, attention_mask=full_mask).logits[:, :-1, :], dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1).gather(1, logp_pos)
rewards = calculate_rewards(prompts, completions, reward_model).to(args.device) # [B*num_gen]
if args.debug_mode and is_main_process() and step % args.debug_interval == 0:
for i in range(len(prompts)):
Logger(f"[DEBUG] step={step}, sample[{i}]")
Logger('-'*100)
Logger(f"{'=' * 30} [DEBUG] sample[{i}] CONTEXT_BEGIN {'=' * 30}")
Logger(prompts[i])
Logger(f"{'=' * 31} [DEBUG] sample[{i}] CONTEXT_END {'=' * 31}")
for j in range(args.num_generations):
idx = i * args.num_generations + j
Logger(f"{'=' * 28} [DEBUG] gen[{j}] RESPONSE_BEGIN {'=' * 28}")
Logger(completions[idx])
Logger(f"{'=' * 29} [DEBUG] gen[{j}] RESPONSE_END {'=' * 29}")
Logger(f"[DEBUG] gen[{j}] reward={rewards[idx].item():.4f}")
Logger('='*100)
grouped_rewards = rewards.view(-1, args.num_generations) # [B, num_gen]
mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations) # [B*num_gen]
std_r = grouped_rewards.std(dim=1, unbiased=False).repeat_interleave(args.num_generations) # [B*num_gen]
advantages = (rewards - mean_r) / (std_r + 1e-4) # [B*num_gen]
completion_pad_mask = rollout_result.completion_mask.to(args.device).bool()
is_eos = (completion_ids == tokenizer.eos_token_id) & completion_pad_mask # [B*num_gen, R]
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1) - 1, dtype=torch.long, device=args.device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
completion_mask = ((torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)) & completion_pad_mask).int() # [B*num_gen, R]
kl_div = ref_per_token_logps - per_token_logps
per_token_kl = torch.exp(kl_div) - kl_div - 1 # [B*num_gen, R]
ratio = torch.exp(per_token_logps - old_per_token_logps) # [B*num_gen, R]
if args.loss_type == "cispo":
clamped_ratio = torch.clamp(ratio, max=args.epsilon_high).detach()
per_token_loss = -(clamped_ratio * advantages.unsqueeze(1) * per_token_logps - args.beta * per_token_kl)
else:
clipped_ratio = torch.clamp(ratio, 1 - args.epsilon, 1 + args.epsilon)
per_token_loss1 = ratio * advantages.unsqueeze(1)
per_token_loss2 = clipped_ratio * advantages.unsqueeze(1)
per_token_loss = -(torch.min(per_token_loss1, per_token_loss2) - args.beta * per_token_kl)
policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1).clamp(min=1)).mean()
loss = (policy_loss + aux_loss) / args.accumulation_steps # scalar
loss.backward()
if step % args.accumulation_steps == 0:
if args.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if step % args.log_interval == 0 or step == iters:
policy_loss_val = loss.item() * args.accumulation_steps
current_aux_loss = aux_loss.item()
avg_reward_val = rewards.mean().item()
avg_len_val = completion_mask.sum(dim=1).float().mean().item()
kl_ref_val = ((ref_per_token_logps - per_token_logps) * completion_mask).sum().item() / max(completion_mask.sum().item(), 1)
advantages_mean_val = advantages.mean().item()
advantages_std_val = advantages.std().item()
current_lr = optimizer.param_groups[0]['lr']
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), '
f'Reward: {avg_reward_val:.4f}, KL_ref: {kl_ref_val:.4f}, '
f'Adv Std: {advantages_std_val:.4f}, Adv Mean: {advantages_mean_val:.4f}, '
f'Actor Loss: {policy_loss_val:.4f}, Avg Response Len: {avg_len_val:.2f}, Learning Rate: {current_lr:.8f}')
if wandb and is_main_process():
wandb.log({
"reward": avg_reward_val,
"kl_ref": kl_ref_val,
"advantages_std": advantages_std_val,
"advantages_mean": advantages_mean_val,
"policy_loss": policy_loss_val,
"avg_response_len": avg_len_val,
"learning_rate": current_lr
})
if (step % args.save_interval == 0 or step == iters) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
model.train()
del state_dict
if step % args.save_interval == 0 or step == iters: rollout_engine.update_policy(model)
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask, completion_pad_mask, prompt_lens, logp_pos
if step > start_step and step % args.accumulation_steps != 0:
if args.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind GRPO (Group Relative Policy Optimization)")
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
parser.add_argument('--save_weight', default='grpo', type=str, help="保存权重的前缀名")
parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
parser.add_argument("--batch_size", type=int, default=2, help="batch size")
parser.add_argument("--learning_rate", type=float, default=3e-7, help="初始学习率")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
parser.add_argument("--log_interval", type=int, default=1, help="日志打印间隔")
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构0=否1=是)")
parser.add_argument('--max_seq_len', default=768, type=int, help="Prompt最大长度")
parser.add_argument("--max_gen_len", type=int, default=1024, help="生成的最大长度")
parser.add_argument("--data_path", type=str, default="../dataset/rlaif.jsonl", help="RLAIF数据路径")
parser.add_argument("--num_generations", type=int, default=6, help="每个prompt生成的样本数")
parser.add_argument("--beta", type=float, default=0.1, help="KL惩罚系数")
parser.add_argument("--loss_type", type=str, default="cispo", choices=["grpo", "cispo"], help="loss类型")
parser.add_argument("--epsilon", type=float, default=0.2, help="GRPO的PPO clip epsilon")
parser.add_argument("--epsilon_high", type=float, default=5.0, help="epsilon上界")
parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练")
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训0=否1=是)")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-GRPO", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速0=否1=是)")
parser.add_argument("--debug_mode", action="store_true", help="是否打印训练调试采样")
parser.add_argument("--debug_interval", type=int, default=20, help="debug模式下每隔多少step打印一次采样")
parser.add_argument("--thinking_ratio", type=float, default=0.9, help="按概率开启thinking0.0~1.0")
parser.add_argument("--rollout_engine", type=str, default="torch", choices=["torch", "sglang"], help="rollout引擎类型")
parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8998", help="SGLang服务器URL")
parser.add_argument("--sglang_model_path", type=str, default="../model", help="SGLang tokenizer路径")
parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_grpo", help="SGLang共享存储路径")
args = parser.parse_args()
# ========== 1. Initialize environment and random seed ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. Configure directories, model config, check checkpoint ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. Set up mixed precision ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. Configure wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-GRPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. Initialize models and data ==========
base_weight = args.from_weight
# Policy model
model, tokenizer = init_model(lm_config, base_weight, device=args.device)
# Reference model
ref_model, _ = init_model(lm_config, base_weight, device=args.device)
ref_model = ref_model.eval().requires_grad_(False)
# Reward model
reward_model = LMForRewardModel(args.reward_model_path, device=args.device, dtype=torch.float16)
# Rollout engine (pluggable; handles policy inference only)
rollout_engine = create_rollout_engine(
engine_type=args.rollout_engine,
policy_model=model,
tokenizer=tokenizer,
device=args.device,
autocast_ctx=autocast_ctx,
sglang_base_url=args.sglang_base_url,
sglang_model_path=args.sglang_model_path,
sglang_shared_path=args.sglang_shared_path,
)
# Data and optimizer
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len, thinking_ratio=args.thinking_ratio)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
iters = len(loader_for_count)
total_optimizer_steps = math.ceil(iters / args.accumulation_steps) * args.epochs
scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
# ========== 6. Restore state from checkpoint ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scheduler.load_state_dict(ckp_data['scheduler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. Compilation and distributed wrapping ==========
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
rollout_engine.update_policy(model)
if dist.is_initialized():
model = DistributedDataParallel(model, device_ids=[local_rank])
rollout_engine.update_policy(model)
# ========== 8. Start training ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
if skip > 0:
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
grpo_train_epoch(epoch, loader, len(loader) + skip, rollout_engine, ref_model, reward_model, start_step, wandb, use_sglang = (args.rollout_engine == "sglang"))
else:
grpo_train_epoch(epoch, loader, len(loader), rollout_engine, ref_model, reward_model, 0, wandb, use_sglang = (args.rollout_engine == "sglang"))
# ========== 9. Clean up distributed processes ==========
if dist.is_initialized(): dist.destroy_process_group()
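The group-relative advantage step inside `grpo_train_epoch` — reshaping rewards to `[B, num_generations]` and normalizing each reward against its own group's mean and std — can be isolated into a small sketch. The `1e-4` epsilon matches the trainer and keeps degenerate groups, where every generation scores identically, at zero advantage:

```python
import torch

def group_relative_advantages(rewards, num_generations, eps=1e-4):
    """GRPO-style advantages: A = (r - group_mean) / (group_std + eps).

    `rewards` is a flat tensor of shape [B * num_generations], laid out
    group-by-group, exactly as calculate_rewards() produces above.
    """
    grouped = rewards.view(-1, num_generations)  # [B, G]
    mean_r = grouped.mean(dim=1).repeat_interleave(num_generations)
    std_r = grouped.std(dim=1, unbiased=False).repeat_interleave(num_generations)
    return (rewards - mean_r) / (std_r + eps)

# Two prompts, three generations each
rewards = torch.tensor([1.0, 2.0, 3.0, 10.0, 10.0, 10.0])
adv = group_relative_advantages(rewards, num_generations=3)
# First group: below/at/above its own mean; second group: all-equal
# rewards carry no learning signal, so its advantages are all zero.
```

Because the baseline is computed per prompt rather than from a learned critic, GRPO needs no value network; the group of sibling generations plays that role.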

trainer/train_lora.py (new file, 183 lines)
@@ -0,0 +1,183 @@
import os
import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import warnings
import torch
import torch.distributed as dist
from contextlib import nullcontext
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import SFTDataset
from model.model_lora import save_lora, apply_lora
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
warnings.filterwarnings('ignore')
def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
start_time = time.time()
last_step = start_step
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
input_ids = input_ids.to(args.device)
labels = labels.to(args.device)
last_step = step
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with autocast_ctx:
res = model(input_ids, labels=labels)
loss = res.loss + res.aux_loss
loss = loss / args.accumulation_steps
scaler.scale(loss).backward()
if step % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0 or step == iters:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
current_logits_loss = current_loss - current_aux_loss
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / max(step - start_step, 1) * (iters - step) // 60
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
if (step % args.save_interval == 0 or step == iters) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
lora_save_path = f'{args.save_dir}/{args.lora_name}_{lm_config.hidden_size}{moe_suffix}.pth'
# LoRA training saves only the LoRA weights
save_lora(model, lora_save_path)
lm_checkpoint(lm_config, weight=args.lora_name, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
del input_ids, labels, res, loss
if last_step > start_step and last_step % args.accumulation_steps != 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
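`get_lr` is imported from `trainer_utils` and its body is not part of this diff. Below is a minimal sketch of a cosine-decay schedule matching its `get_lr(current_step, total_steps, lr)` call signature; the exact formula is an assumption for illustration, not the repo's implementation:

```python
import math

def get_lr_sketch(current_step, total_steps, lr):
    # Hypothetical warmup-free cosine decay: starts near 1.1 * lr and
    # bottoms out at lr / 10 when current_step == total_steps.
    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
```

A floor of `lr / 10` keeps the learning rate from collapsing to zero at the end of training.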
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind LoRA Fine-tuning")
parser.add_argument("--save_dir", type=str, default="../out", help="model save directory")
parser.add_argument("--lora_name", type=str, default="lora_medical", help="LoRA weight name (e.g. lora_identity / lora_medical)")
parser.add_argument("--epochs", type=int, default=10, help="number of training epochs")
parser.add_argument("--batch_size", type=int, default=32, help="batch size")
parser.add_argument("--learning_rate", type=float, default=1e-4, help="initial learning rate")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="training device")
parser.add_argument("--dtype", type=str, default="bfloat16", help="mixed-precision dtype")
parser.add_argument("--num_workers", type=int, default=8, help="number of data-loading workers")
parser.add_argument("--accumulation_steps", type=int, default=1, help="gradient accumulation steps")
parser.add_argument("--grad_clip", type=float, default=1.0, help="gradient clipping threshold")
parser.add_argument("--log_interval", type=int, default=10, help="logging interval (steps)")
parser.add_argument("--save_interval", type=int, default=1000, help="checkpoint save interval (steps)")
parser.add_argument('--hidden_size', default=768, type=int, help="hidden size")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="number of hidden layers")
parser.add_argument('--max_seq_len', default=340, type=int, help="max truncation length for training (Chinese: 1 token ≈ 1.5~1.7 characters)")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="use MoE architecture (0=no, 1=yes)")
parser.add_argument("--data_path", type=str, default="../dataset/lora_medical.jsonl", help="LoRA training data path")
parser.add_argument('--from_weight', default='full_sft', type=str, help="base weight to train from (default full_sft)")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="auto-detect checkpoint & resume (0=no, 1=yes)")
parser.add_argument("--use_wandb", action="store_true", help="enable wandb logging")
parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="wandb project name")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="enable torch.compile (0=no, 1=yes)")
args = parser.parse_args()
# ========== 1. Initialize environment and random seed ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. Set up directories, model config, check checkpoint ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.lora_name, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. Set up mixed precision ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. Set up wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-LoRA-{args.lora_name}-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. Build model, apply LoRA, freeze non-LoRA params ==========
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
apply_lora(model)
# Parameter statistics
total_params = sum(p.numel() for p in model.parameters())
lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name)
Logger(f"Total LLM parameters: {total_params / 1e6:.3f} M")
Logger(f"LoRA parameters: {lora_params_count / 1e6:.3f} M")
Logger(f"LoRA parameter ratio: {lora_params_count / total_params * 100:.2f}%")
# Freeze non-LoRA parameters and collect the LoRA parameters
lora_params = []
for name, param in model.named_parameters():
if 'lora' in name:
param.requires_grad = True
lora_params.append(param)
else:
param.requires_grad = False
# ========== 6. Build dataset and optimizer ==========
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(lora_params, lr=args.learning_rate)
# ========== 7. Restore state from checkpoint ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'], strict=False)
optimizer.load_state_dict(ckp_data['optimizer'])
scaler.load_state_dict(ckp_data['scaler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 8. Compile and distributed wrapping ==========
if args.use_compile == 1:
args.use_compile = 0
Logger('[LoRA] the monkey-patched forward is incompatible with torch.compile; use_compile has been disabled automatically')
if dist.is_initialized():
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 9. Start training ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
if skip > 0:
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping first {start_step} steps, resuming from step {start_step + 1}')
train_epoch(epoch, loader, len(loader) + skip, lora_params, start_step, wandb)
else:
train_epoch(epoch, loader, len(loader), lora_params, 0, wandb)
# ========== 10. Clean up distributed process group ==========
if dist.is_initialized(): dist.destroy_process_group()

trainer/train_ppo.py (new file, 434 lines)
@@ -0,0 +1,434 @@
import os
import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import math
import re
import warnings
import torch
import torch.distributed as dist
import torch.nn.functional as F
from transformers import AutoTokenizer
from contextlib import nullcontext
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import CosineAnnealingLR
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import RLAIFDataset
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model, LMForRewardModel
from trainer.rollout_engine import create_rollout_engine
warnings.filterwarnings('ignore')
def rep_penalty(text, n=3, cap=0.5):
toks = re.findall(r"\w+|[^\w\s]", text.lower())
grams = [tuple(toks[i:i + n]) for i in range(len(toks) - n + 1)]
return min(cap, (len(grams) - len(set(grams))) * cap * 2 / len(grams)) if grams else 0.0
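The n-gram repetition penalty can be exercised on its own; the copy below mirrors the `rep_penalty` defined above, with illustrative inputs:

```python
import re

def rep_penalty(text, n=3, cap=0.5):
    # Penalize duplicated word n-grams, scaled by their fraction and capped at `cap`.
    toks = re.findall(r"\w+|[^\w\s]", text.lower())
    grams = [tuple(toks[i:i + n]) for i in range(len(toks) - n + 1)]
    return min(cap, (len(grams) - len(set(grams))) * cap * 2 / len(grams)) if grams else 0.0

print(rep_penalty("the quick brown fox jumps"))    # all trigrams distinct -> 0.0
print(rep_penalty("yes no yes no yes no yes no"))  # heavy repetition -> hits the 0.5 cap
```

A degenerate looping response is penalized up to 0.5, which is subtracted from the reward in `calculate_rewards`.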
# Custom critic model, inheriting from MiniMindForCausalLM
class CriticModel(MiniMindForCausalLM):
def __init__(self, params):
super().__init__(params)
# Replace lm_head with a linear layer that outputs a single scalar value
self.value_head = nn.Linear(params.hidden_size, 1)
def forward(self, input_ids=None, attention_mask=None, **kwargs):
# Get hidden states from the base model
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
hidden_states = self.model.norm(outputs[0])
# Estimate per-token values with value_head
values = self.value_head(hidden_states).squeeze(-1)
return values
def calculate_rewards(prompts, responses, reward_model):
rewards = torch.zeros(len(responses), device=args.device)
with torch.no_grad():
reward_model_scores = []
for i, (prompt, response) in enumerate(zip(prompts, responses)):
pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
matches = re.findall(pattern, prompt, re.DOTALL)
messages = [{"role": role, "content": content.strip()} for role, content in matches]
answer = response
rewards[i] += 0.5 if 20 <= len(response.strip()) <= 800 else -0.5
if '</think>' in response:
thinking_content, answer_content = response.split('</think>', 1)
rewards[i] += 1.0 if 20 <= len(thinking_content.strip()) <= 300 else -0.5
rewards[i] += 0.25 if response.count('</think>') == 1 else -0.25
answer = answer_content.strip()
rewards[i] -= rep_penalty(answer)
score = reward_model.get_score(messages, answer)
reward_model_scores.append(score)
reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
rewards += reward_model_scores
return rewards
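`calculate_rewards` first recovers structured chat messages from the flat ChatML prompt with a regex. The parsing step in isolation (the sample prompt strings are illustrative, not from the dataset):

```python
import re

pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
prompt = ("<|im_start|>system\nYou are MiniMind.<|im_end|>"
          "<|im_start|>user\nHello!<|im_end|>")
# re.DOTALL lets (.*?) span newlines inside a message body
matches = re.findall(pattern, prompt, re.DOTALL)
messages = [{"role": role, "content": content.strip()} for role, content in matches]
print(messages)
# [{'role': 'system', 'content': 'You are MiniMind.'}, {'role': 'user', 'content': 'Hello!'}]
```

The recovered `messages` list is what gets handed to the reward model's `get_score` together with the generated answer.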
def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_scheduler, critic_scheduler, reward_model, start_step=0, wandb=None, use_sglang=False):
actor_model.train()
critic_model.train()
grad_accum_step = 0
for step, batch in enumerate(loader, start=start_step + 1):
prompts = batch["prompt"] # list[str], length B
enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=args.max_seq_len,
padding_side="left").to(args.device) # input_ids: [B, P], attention_mask: [B, P]
rollout_result = rollout_engine.rollout(
prompt_ids=enc.input_ids,
attention_mask=enc.attention_mask,
num_generations=1,
max_new_tokens=args.max_gen_len,
temperature=0.8,
)
gen_out = rollout_result.output_ids
completion_ids = rollout_result.completion_ids
prompt_lens = rollout_result.prompt_lens.to(args.device)
responses_text = rollout_result.completions
old_resp_logp = rollout_result.per_token_logps.to(args.device)
rewards = calculate_rewards(prompts, responses_text, reward_model) # [B]
if args.debug_mode and is_main_process() and step % args.debug_interval == 0:
for i in range(len(prompts)):
Logger(f"[DEBUG] step={step}, sample[{i}]")
Logger('-'*100)
Logger(f"{'=' * 30} [DEBUG] sample[{i}] CONTEXT_BEGIN {'=' * 30}")
Logger(prompts[i])
Logger(f"{'=' * 31} [DEBUG] sample[{i}] CONTEXT_END {'=' * 31}")
Logger(f"[DEBUG] prompt_len={prompt_lens[i].item()}, response_len={len(responses_text[i])}")
Logger(f"{'=' * 28} [DEBUG] sample[{i}] RESPONSE_BEGIN {'=' * 28}")
Logger(responses_text[i])
Logger(f"{'=' * 29} [DEBUG] sample[{i}] RESPONSE_END {'=' * 29}")
Logger(f"[DEBUG] reward={rewards[i].item():.4f}")
Logger('='*100)
full_mask = (gen_out != tokenizer.pad_token_id).long() # [B, P+R]
labels = gen_out[:, 1:].clone() # [B, P+R-1]
B = len(prompts)
resp_labels = completion_ids
resp_idx = torch.arange(resp_labels.size(1), device=gen_out.device).unsqueeze(0)
logp_pos = prompt_lens.unsqueeze(1) - 1 + resp_idx
resp_pad_mask = rollout_result.completion_mask.to(args.device).bool()
resp_lengths = resp_pad_mask.sum(dim=1); valid_resp = resp_lengths > 0; eos_mask = resp_labels.eq(tokenizer.eos_token_id) & resp_pad_mask
has_eos = eos_mask.any(dim=1); eos_pos = torch.argmax(eos_mask.int(), dim=1)
resp_lengths = torch.where(has_eos, eos_pos + 1, resp_lengths).long().clamp(min=1)
resp_policy_mask = ((resp_idx < resp_lengths.unsqueeze(1)) & resp_pad_mask).float()
resp_value_mask = resp_policy_mask.clone()
with torch.no_grad(): # the rollout phase only needs inference for old_logp and old_values; cutting gradients here saves memory
critic_for_rollout = critic_model.module if isinstance(critic_model, DistributedDataParallel) else critic_model
values_seq = critic_for_rollout(input_ids=gen_out, attention_mask=full_mask)
old_resp_values = values_seq.gather(1, logp_pos) * resp_value_mask
ref_resp_logp = F.log_softmax(ref_model(input_ids=gen_out, attention_mask=full_mask).logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1).gather(1, logp_pos)
token_rewards = torch.zeros_like(old_resp_logp)
last_idx = resp_lengths - 1 # [B]
token_rewards[torch.arange(B, device=args.device)[valid_resp], last_idx[valid_resp]] += rewards[valid_resp] # add the external reward at the final response token
gen_len = old_resp_values.size(1); lastgaelam = torch.zeros(B, device=args.device); advs_rev = []
for t in reversed(range(gen_len)):
nv = old_resp_values[:, t + 1] if t < gen_len - 1 else 0.0
delta = token_rewards[:, t] + args.gamma * nv - old_resp_values[:, t]
lastgaelam = delta + args.gamma * args.lam * lastgaelam
advs_rev.append(lastgaelam)
advantages = torch.stack(advs_rev[::-1], dim=1) # [B, R]
returns = advantages + old_resp_values # [B, R]
adv_mean = (advantages * resp_policy_mask).sum() / resp_policy_mask.sum().clamp(min=1)
adv_var = ((advantages - adv_mean) ** 2 * resp_policy_mask).sum() / resp_policy_mask.sum().clamp(min=1)
advantages = (advantages - adv_mean) * torch.rsqrt(adv_var + 1e-8) * resp_policy_mask
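The reversed loop above is standard GAE (Generalized Advantage Estimation). Stripped of tensors and masks, the same recursion can be sketched with toy numbers (values chosen for illustration only):

```python
def gae(token_rewards, values, gamma=1.0, lam=0.95):
    # delta_t = r_t + gamma * V(t+1) - V(t); advantages accumulate backwards.
    lastgaelam, advs = 0.0, []
    for t in reversed(range(len(values))):
        nv = values[t + 1] if t < len(values) - 1 else 0.0
        delta = token_rewards[t] + gamma * nv - values[t]
        lastgaelam = delta + gamma * lam * lastgaelam
        advs.append(lastgaelam)
    return advs[::-1]

# Reward only at the final token, constant value estimate of 0.5:
advantages = gae([0.0, 0.0, 1.0], [0.5, 0.5, 0.5])
returns = [a + v for a, v in zip(advantages, [0.5, 0.5, 0.5])]
print(advantages)  # ≈ [0.45125, 0.475, 0.5]
```

With `gamma=1.0` (the script default) the terminal reward propagates to earlier tokens discounted only by `lam`, exactly as in the tensorized loop.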
mb_size = max(1, min(args.mini_batch_size, B))
stop_ppo = False
policy_loss_sum = 0.0
value_loss_sum = 0.0
kl_sum = 0.0
kl_ref_sum = 0.0
clipfrac_sum = 0.0
aux_loss_sum = 0.0
log_count = 0
actor_unwrapped = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
critic_unwrapped = critic_model.module if isinstance(critic_model, DistributedDataParallel) else critic_model
for ppo_epoch in range(args.ppo_update_iters):
if stop_ppo:
break
b_inds = torch.randperm(B, device=args.device)
for i in range(0, B, mb_size):
inds = b_inds[i:i + mb_size]
mb_values_seq = critic_unwrapped(input_ids=gen_out[inds], attention_mask=full_mask[inds])
mb_resp_values = mb_values_seq.gather(1, logp_pos[inds])
with autocast_ctx:
res = actor_unwrapped(input_ids=gen_out[inds], attention_mask=full_mask[inds])
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
mb_resp_logp = F.log_softmax(res.logits[:, :-1], dim=-1).gather(2, labels[inds].unsqueeze(-1)).squeeze(-1).gather(1, logp_pos[inds])
log_ratio = mb_resp_logp - old_resp_logp[inds]
approx_kl = (0.5 * (log_ratio ** 2) * resp_policy_mask[inds]).sum() / resp_policy_mask[inds].sum().clamp(min=1)
# Sync approx_kl across ranks; otherwise one rank may break while the others continue, deadlocking DDP
approx_kl_val = approx_kl.detach().clone()
if dist.is_initialized():
dist.all_reduce(approx_kl_val, op=dist.ReduceOp.AVG)
if approx_kl_val > args.early_stop_kl:
stop_ppo = True
ratio = torch.exp(log_ratio)
clipfrac = ((((ratio - 1.0).abs() > args.clip_epsilon).float() * resp_policy_mask[inds]).sum()
/ resp_policy_mask[inds].sum().clamp(min=1))
kl_ref_penalty = ((torch.exp(ref_resp_logp[inds] - mb_resp_logp) - (ref_resp_logp[inds] - mb_resp_logp) - 1.0)
* resp_policy_mask[inds]).sum() / resp_policy_mask[inds].sum().clamp(min=1)
policy_loss = ((torch.max(-advantages[inds] * ratio,
-advantages[inds] * torch.clamp(ratio, 1.0 - args.clip_epsilon, 1.0 + args.clip_epsilon))
* resp_policy_mask[inds]).sum() / resp_policy_mask[inds].sum().clamp(min=1)
+ args.kl_coef * kl_ref_penalty)
value_loss = 0.5 * (torch.max((mb_resp_values - returns[inds]) ** 2,
(torch.clamp(mb_resp_values, old_resp_values[inds] - args.cliprange_value,
old_resp_values[inds] + args.cliprange_value) - returns[inds]) ** 2)
* resp_value_mask[inds]).sum() / resp_value_mask[inds].sum().clamp(min=1)
kl = approx_kl_val
kl_ref = kl_ref_penalty.detach()
# On early stop the forward-backward loop must still complete, so zero out the loss instead of interrupting DDP communication
if stop_ppo:
loss = (policy_loss + args.vf_coef * value_loss + aux_loss) * 0.0
else:
loss = (policy_loss + args.vf_coef * value_loss + aux_loss) / args.accumulation_steps
loss.backward()
policy_loss_sum += policy_loss.item()
value_loss_sum += value_loss.item()
kl_sum += kl.item()
kl_ref_sum += kl_ref.item()
clipfrac_sum += clipfrac.item()
aux_loss_sum += aux_loss.item()
log_count += 1
grad_accum_step += 1
if grad_accum_step % args.accumulation_steps == 0:
clip_grad_norm_(actor_model.parameters(), args.grad_clip)
clip_grad_norm_(critic_model.parameters(), args.grad_clip)
actor_optimizer.step()
critic_optimizer.step()
actor_scheduler.step()
critic_scheduler.step()
actor_optimizer.zero_grad()
critic_optimizer.zero_grad()
if grad_accum_step % args.accumulation_steps != 0:
clip_grad_norm_(actor_model.parameters(), args.grad_clip)
clip_grad_norm_(critic_model.parameters(), args.grad_clip)
actor_optimizer.step()
critic_optimizer.step()
actor_scheduler.step()
critic_scheduler.step()
actor_optimizer.zero_grad()
critic_optimizer.zero_grad()
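The policy term assembled above is the standard PPO clipped surrogate. Reduced to a single token with scalar log-probs and advantage, it behaves like this sketch:

```python
import math

def ppo_policy_term(logp_new, logp_old, advantage, clip_epsilon=0.2):
    # max of the unclipped and clipped negative surrogates, as in the loss above.
    ratio = math.exp(logp_new - logp_old)
    clipped = min(max(ratio, 1.0 - clip_epsilon), 1.0 + clip_epsilon)
    return max(-advantage * ratio, -advantage * clipped)

# With a positive advantage, gains from pushing the ratio above 1+eps are clipped:
print(ppo_policy_term(logp_new=0.5, logp_old=0.0, advantage=1.0))  # -(1 + 0.2) = -1.2
```

Clipping caps how far one update can move the policy away from the rollout policy, while the separate `kl_coef` term anchors it to the frozen reference model.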
if step % args.save_interval == 0 or step == iters: rollout_engine.update_policy(actor_model)
if is_main_process():
critic_loss_val = value_loss_sum / max(log_count, 1)
reward_val = rewards.mean().item()
approx_kl_val = kl_sum / max(log_count, 1)
kl_ref_val = kl_ref_sum / max(log_count, 1)
clipfrac_val = clipfrac_sum / max(log_count, 1)
avg_len_val = resp_lengths.float().mean().item()
actor_lr, critic_lr = actor_optimizer.param_groups[0]['lr'], critic_optimizer.param_groups[0]['lr']
if wandb is not None:
wandb.log({
"reward": reward_val,
"kl_ref": kl_ref_val,
"approx_kl": approx_kl_val,
"clipfrac": clipfrac_val,
"critic_loss": critic_loss_val,
"avg_response_len": avg_len_val,
"actor_lr": actor_lr,
"critic_lr": critic_lr,
})
Logger(f"Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), "
f"Reward: {reward_val:.4f}, KL_ref: {kl_ref_val:.4f}, Approx KL: {approx_kl_val:.4f}, "
f"ClipFrac: {clipfrac_val:.4f}, Critic Loss: {critic_loss_val:.4f}, "
f"Avg Response Len: {avg_len_val:.2f}, Actor LR: {actor_lr:.8f}, Critic LR: {critic_lr:.8f}")
if (step % args.save_interval == 0 or step == iters) and is_main_process():
actor_model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
raw_actor = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
raw_actor = getattr(raw_actor, '_orig_mod', raw_actor)
actor_state = raw_actor.state_dict()
torch.save({k: v.half().cpu() for k, v in actor_state.items()}, ckp)
# Save full state (including the critic) via lm_checkpoint
lm_checkpoint(lm_config, weight=args.save_weight, model=actor_model, optimizer=actor_optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints',
scheduler=actor_scheduler, critic_model=critic_model,
critic_optimizer=critic_optimizer, critic_scheduler=critic_scheduler)
actor_model.train()
del actor_state
del enc, gen_out, completion_ids, responses_text, rewards, full_mask, values_seq, advantages
del labels, resp_labels, resp_idx, resp_pad_mask, valid_resp, eos_mask, has_eos, eos_pos, resp_lengths, resp_policy_mask, resp_value_mask, old_resp_logp, ref_resp_logp
del kl, kl_ref, policy_loss, value_loss, loss, token_rewards, returns, old_resp_values, prompt_lens, logp_pos
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind PPO (Proximal Policy Optimization)")
parser.add_argument("--save_dir", type=str, default="../out", help="model save directory")
parser.add_argument('--save_weight', default='ppo_actor', type=str, help="prefix for saved weights")
parser.add_argument("--epochs", type=int, default=1, help="number of training epochs")
parser.add_argument("--batch_size", type=int, default=2, help="batch size")
parser.add_argument("--learning_rate", type=float, default=3e-7, help="actor learning rate")
parser.add_argument("--critic_learning_rate", type=float, default=5e-7, help="critic learning rate")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="training device")
parser.add_argument("--dtype", type=str, default="bfloat16", help="mixed-precision dtype")
parser.add_argument("--num_workers", type=int, default=8, help="number of data-loading workers")
parser.add_argument("--accumulation_steps", type=int, default=1, help="gradient accumulation steps")
parser.add_argument("--grad_clip", type=float, default=1.0, help="gradient clipping threshold")
parser.add_argument("--log_interval", type=int, default=1, help="logging interval (steps)")
parser.add_argument("--save_interval", type=int, default=10, help="checkpoint save interval (steps)")
parser.add_argument('--hidden_size', default=768, type=int, help="hidden size")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="number of hidden layers")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="use MoE architecture (0=no, 1=yes)")
parser.add_argument('--max_seq_len', default=768, type=int, help="max prompt length")
parser.add_argument("--max_gen_len", type=int, default=1024, help="max generation length")
parser.add_argument("--data_path", type=str, default="../dataset/rlaif.jsonl", help="RLAIF data path")
parser.add_argument("--clip_epsilon", type=float, default=0.2, help="PPO clipping parameter")
parser.add_argument("--vf_coef", type=float, default=0.5, help="value function coefficient")
parser.add_argument("--kl_coef", type=float, default=0.02, help="KL divergence penalty coefficient")
parser.add_argument("--gamma", type=float, default=1.0, help="GAE discount factor")
parser.add_argument("--lam", type=float, default=0.95, help="GAE lambda")
parser.add_argument("--cliprange_value", type=float, default=0.2, help="value function clipping range")
parser.add_argument("--ppo_update_iters", type=int, default=2, help="number of update passes over the same rollout batch")
parser.add_argument("--early_stop_kl", type=float, default=0.25, help="KL threshold for PPO early stopping")
parser.add_argument("--mini_batch_size", type=int, default=2, help="PPO minibatch size per update")
parser.add_argument('--from_weight', default='full_sft', type=str, help="base weight to train from")
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="reward model path")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="auto-detect checkpoint & resume (0=no, 1=yes)")
parser.add_argument("--use_wandb", action="store_true", help="enable wandb logging")
parser.add_argument("--wandb_project", type=str, default="MiniMind-PPO", help="wandb project name")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="enable torch.compile (0=no, 1=yes)")
parser.add_argument("--debug_mode", action="store_true", help="print debug samples during training")
parser.add_argument("--debug_interval", type=int, default=20, help="steps between debug-sample prints in debug mode")
parser.add_argument("--thinking_ratio", type=float, default=0.9, help="probability of enabling thinking (0.0~1.0)")
parser.add_argument("--rollout_engine", type=str, default="torch", choices=["torch", "sglang"], help="rollout engine type")
parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8998", help="SGLang server URL")
parser.add_argument("--sglang_model_path", type=str, default="../model", help="SGLang tokenizer path")
parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_ppo", help="SGLang shared storage path")
args = parser.parse_args()
# ========== 1. Initialize environment and random seed ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. Set up directories, model config, check checkpoint ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. Set up mixed precision ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. Set up wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-PPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. Initialize models and data ==========
base_weight = args.from_weight
# Actor model
actor_model, tokenizer = init_model(lm_config, base_weight, device=args.device)
ref_model, _ = init_model(lm_config, base_weight, device=args.device)
ref_model = ref_model.eval().requires_grad_(False)
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
state_dict = torch.load(ckp, map_location=args.device)
critic_model = CriticModel(lm_config)
critic_model.load_state_dict(state_dict, strict=False)
critic_model = critic_model.to(args.device)
reward_model = LMForRewardModel(args.reward_model_path, device=args.device, dtype=torch.float16)
# Rollout engine
rollout_engine = create_rollout_engine(
engine_type=args.rollout_engine,
policy_model=actor_model,
tokenizer=tokenizer,
device=args.device,
autocast_ctx=autocast_ctx,
sglang_base_url=args.sglang_base_url,
sglang_model_path=args.sglang_model_path,
sglang_shared_path=args.sglang_shared_path,
)
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=(args.max_seq_len + args.max_gen_len), thinking_ratio=args.thinking_ratio)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
actor_optimizer = optim.AdamW(actor_model.parameters(), lr=args.learning_rate)
critic_optimizer = optim.AdamW(critic_model.parameters(), lr=args.critic_learning_rate)
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
iters = len(loader_for_count)
mb_factor = max(1, math.ceil(args.batch_size / args.mini_batch_size))
total_optimizer_steps = math.ceil(iters * args.epochs * args.ppo_update_iters * mb_factor / args.accumulation_steps)
actor_scheduler = CosineAnnealingLR(actor_optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
critic_scheduler = CosineAnnealingLR(critic_optimizer, T_max=total_optimizer_steps, eta_min=args.critic_learning_rate / 10)
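The scheduler horizon above counts every optimizer step: loader batches times epochs times PPO update passes times minibatches per batch, divided by accumulation. The same arithmetic as a standalone helper (the 500-batch epoch is a made-up example):

```python
import math

def total_optimizer_steps(iters, epochs, ppo_update_iters, batch_size, mini_batch_size, accumulation_steps):
    # Minibatches per rollout batch, then total gradient steps after accumulation.
    mb_factor = max(1, math.ceil(batch_size / mini_batch_size))
    return math.ceil(iters * epochs * ppo_update_iters * mb_factor / accumulation_steps)

# Script defaults (batch_size=2, mini_batch_size=2, ppo_update_iters=2) over a
# hypothetical 500-batch epoch:
print(total_optimizer_steps(500, 1, 2, 2, 2, 1))  # 1000
```

Undercounting this horizon would make `CosineAnnealingLR` reach `eta_min` too early, so the factor for PPO's inner update passes matters.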
start_epoch, start_step = 0, 0
if ckp_data:
actor_model.load_state_dict(ckp_data['model'])
critic_model.load_state_dict(ckp_data['critic_model'])
actor_optimizer.load_state_dict(ckp_data['optimizer'])
critic_optimizer.load_state_dict(ckp_data['critic_optimizer'])
actor_scheduler.load_state_dict(ckp_data['scheduler'])
critic_scheduler.load_state_dict(ckp_data['critic_scheduler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. Compile and distributed wrapping ==========
if args.use_compile == 1:
actor_model = torch.compile(actor_model)
Logger('torch.compile enabled')
rollout_engine.update_policy(actor_model)
if dist.is_initialized():
actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank])
critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank])
rollout_engine.update_policy(actor_model)
# ========== 8. Start training ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
if skip > 0:
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping first {start_step} steps, resuming from step {start_step + 1}')
ppo_train_epoch(epoch, loader, len(loader) + skip, rollout_engine, ref_model, actor_scheduler, critic_scheduler, reward_model, start_step, wandb, use_sglang = (args.rollout_engine == "sglang"))
else:
ppo_train_epoch(epoch, loader, len(loader), rollout_engine, ref_model, actor_scheduler, critic_scheduler, reward_model, 0, wandb, use_sglang = (args.rollout_engine == "sglang"))
# ========== 9. Clean up distributed process group ==========
if dist.is_initialized(): dist.destroy_process_group()

trainer/train_pretrain.py (new file, 169 lines)
@@ -0,0 +1,169 @@
import os
import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import warnings
import torch
import torch.distributed as dist
from contextlib import nullcontext
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import PretrainDataset
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
warnings.filterwarnings('ignore')
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
start_time = time.time()
last_step = start_step
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
input_ids = input_ids.to(args.device)
labels = labels.to(args.device)
last_step = step
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with autocast_ctx:
res = model(input_ids, labels=labels)
loss = res.loss + res.aux_loss
loss = loss / args.accumulation_steps
scaler.scale(loss).backward()
if step % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0 or step == iters:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
current_logits_loss = current_loss - current_aux_loss
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / max(step - start_step, 1) * (iters - step) // 60
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
if (step % args.save_interval == 0 or step == iters) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
del state_dict
del input_ids, labels, res, loss
if last_step > start_step and last_step % args.accumulation_steps != 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
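Both trainers use the same accumulation pattern: scale each micro-batch loss by `1/accumulation_steps` and step every N micro-batches, which reproduces the mean gradient of the larger effective batch. A scalar sketch with made-up per-micro-batch gradients:

```python
def accumulate(micro_grads, accumulation_steps):
    # Each micro-batch contributes grad / accumulation_steps before the step,
    # matching `loss = loss / args.accumulation_steps` before backward().
    grad = 0.0
    for g in micro_grads:
        grad += g / accumulation_steps
    return grad

micro_grads = [1.0, 3.0, 5.0, 7.0]
print(accumulate(micro_grads, 4))           # 4.0
print(sum(micro_grads) / len(micro_grads))  # same as the full-batch mean: 4.0
```

The trailing flush after the loop (lines above) handles the leftover micro-batches of an epoch whose count is not a multiple of `accumulation_steps`.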
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Pretraining")
parser.add_argument("--save_dir", type=str, default="../out", help="model save directory")
parser.add_argument('--save_weight', default='pretrain', type=str, help="prefix for saved weights")
parser.add_argument("--epochs", type=int, default=2, help="number of training epochs")
parser.add_argument("--batch_size", type=int, default=32, help="batch size")
parser.add_argument("--learning_rate", type=float, default=5e-4, help="initial learning rate")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="training device")
parser.add_argument("--dtype", type=str, default="bfloat16", help="mixed-precision dtype")
parser.add_argument("--num_workers", type=int, default=8, help="number of data-loading workers")
parser.add_argument("--accumulation_steps", type=int, default=8, help="gradient accumulation steps")
parser.add_argument("--grad_clip", type=float, default=1.0, help="gradient clipping threshold")
parser.add_argument("--log_interval", type=int, default=100, help="logging interval (steps)")
parser.add_argument("--save_interval", type=int, default=1000, help="checkpoint save interval (steps)")
parser.add_argument('--hidden_size', default=768, type=int, help="hidden size")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="number of hidden layers")
parser.add_argument('--max_seq_len', default=340, type=int, help="max truncation length for training (Chinese: 1 token ≈ 1.5~1.7 characters)")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="use MoE architecture (0=no, 1=yes)")
parser.add_argument("--data_path", type=str, default="../dataset/pretrain_t2t_mini.jsonl", help="pretraining data path")
parser.add_argument('--from_weight', default='none', type=str, help="base weight to train from ('none' = train from scratch)")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="auto-detect checkpoint & resume (0=no, 1=yes)")
parser.add_argument("--use_wandb", action="store_true", help="enable wandb logging")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb project name")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="enable torch.compile (0=no, 1=yes)")
args = parser.parse_args()
# ========== 1. Initialize environment and random seed ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. Set up directories, model config, check for checkpoint ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. Set up mixed precision ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. Configure wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. Define model, data, optimizer ==========
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
# ========== 6. Restore state from checkpoint ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scaler.load_state_dict(ckp_data['scaler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. Compile and distributed wrapping ==========
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
if dist.is_initialized():
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. Start training ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
if skip > 0:
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
train_epoch(epoch, loader, len(loader) + skip, start_step, wandb)
else:
train_epoch(epoch, loader, len(loader), 0, wandb)
# ========== 9. Clean up distributed process group ==========
if dist.is_initialized(): dist.destroy_process_group()

168 trainer/train_tokenizer.py Normal file
@ -0,0 +1,168 @@
# Note: re-training the tokenizer is not recommended. MiniMind already ships with one; this script is for learning and reference only. Models trained on different vocabularies produce completely incompatible outputs and reduce model reusability in the community.
import os
import json
from tokenizers import decoders, models, pre_tokenizers, trainers, Tokenizer
DATA_PATH = '../dataset/sft_t2t_mini.jsonl'
TOKENIZER_DIR = '../model_learn_tokenizer/'
VOCAB_SIZE = 6400
SPECIAL_TOKENS_NUM = 36
def get_texts(data_path):
with open(data_path, 'r', encoding='utf-8', errors='ignore') as f:
for i, line in enumerate(f):
if i >= 10000: break  # sample the first 10,000 lines for a quick test
try:
data = json.loads(line)
contents = [item.get('content') for item in data.get('conversations', []) if item.get('content')]
if contents:
yield "\n".join(contents)
except json.JSONDecodeError:
continue
def train_tokenizer(data_path, tokenizer_dir, vocab_size, special_tokens_num=SPECIAL_TOKENS_NUM):
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
special_tokens_list = [
"<|endoftext|>", "<|im_start|>", "<|im_end|>",
"<|object_ref_start|>", "<|object_ref_end|>", "<|box_start|>", "<|box_end|>", "<|quad_start|>", "<|quad_end|>",
"<|vision_start|>", "<|vision_end|>", "<|vision_pad|>", "<|image_pad|>", "<|video_pad|>",
"<|audio_start|>", "<|audio_end|>", "<|audio_pad|>", "<tts_pad>", "<tts_text_bos>", "<tts_text_eod>", "<tts_text_bos_single>"
]
additional_tokens_list = [
"<tool_call>", "</tool_call>",
"<tool_response>", "</tool_response>",
"<think>", "</think>"
]
num_buffer = special_tokens_num - len(special_tokens_list + additional_tokens_list)
buffer_tokens = [f"<|buffer{i}|>" for i in range(1, num_buffer + 1)]  # reserve a number of spare token slots
all_special_tokens = special_tokens_list + additional_tokens_list + buffer_tokens
trainer = trainers.BpeTrainer(
vocab_size=vocab_size,
show_progress=True,
initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
special_tokens=all_special_tokens
)
texts = get_texts(data_path)
tokenizer.train_from_iterator(texts, trainer=trainer)
tokenizer.decoder = decoders.ByteLevel()
tokenizer.add_special_tokens(special_tokens_list)
os.makedirs(tokenizer_dir, exist_ok=True)
tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
tokenizer.model.save(tokenizer_dir)
tokenizer_json_path = os.path.join(tokenizer_dir, "tokenizer.json")
with open(tokenizer_json_path, 'r', encoding='utf-8') as f:
tokenizer_data = json.load(f)
for token_info in tokenizer_data.get('added_tokens', []):
if token_info['content'] not in special_tokens_list:
token_info['special'] = False
with open(tokenizer_json_path, 'w', encoding='utf-8') as f:
json.dump(tokenizer_data, f, ensure_ascii=False, indent=2)
added_tokens_decoder = {}
for i, token in enumerate(all_special_tokens):
idx = tokenizer.token_to_id(token)
added_tokens_decoder[str(idx)] = {
"content": token,
"lstrip": False,
"normalized": False,
"rstrip": False,
"single_word": False,
"special": True if token in special_tokens_list else False
}
config = {
"add_bos_token": False,
"add_eos_token": False,
"add_prefix_space": False,
"added_tokens_decoder": added_tokens_decoder,
"additional_special_tokens": [t for t in special_tokens_list if t not in ["<|endoftext|>"]],
"bos_token": "<|im_start|>",
"clean_up_tokenization_spaces": False,
"eos_token": "<|im_end|>",
"legacy": True,
"model_max_length": 131072,
"pad_token": "<|endoftext|>",
"sp_model_kwargs": {},
"spaces_between_special_tokens": False,
"unk_token": "<|endoftext|>",
"image_token": "<|image_pad|>",
"audio_token": "<|audio_pad|>",
"video_token": "<|video_pad|>",
"vision_bos_token": "<|vision_start|>",
"vision_eos_token": "<|vision_end|>",
"audio_bos_token": "<|audio_start|>",
"audio_eos_token": "<|audio_end|>",
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = 
content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if true %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if open_thinking is defined and open_thinking is true %}\n {{- '<think>\\n' }}\n {%- else %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}",
"tokenizer_class": "PreTrainedTokenizerFast"
}
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w", encoding="utf-8") as f:
json.dump(config, f, ensure_ascii=False, indent=4)
print("Tokenizer training completed.")
def eval_tokenizer(tokenizer_dir):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
messages = [
{"role": "system", "content": "你是一个优秀的聊天机器人,总是给我正确的回应!"},
{"role": "user", "content": '你来自哪里?'},
{"role": "assistant", "content": '我来自月球'},
{"role": "user", "content": '你到底来自哪里?'},
{"role": "assistant", "content": '我来自地球'}
]
new_prompt = tokenizer.apply_chat_template(
messages,
tokenize=False
)
print('-'*100)
print(new_prompt)
print('-'*100)
print('tokenizer vocab size:', len(tokenizer))
model_inputs = tokenizer(new_prompt)
print('encoded length:', len(model_inputs['input_ids']))
response = tokenizer.decode(model_inputs['input_ids'], skip_special_tokens=False)
print('decode consistency:', response == new_prompt, "\n")
print('-'*100)
print('Compression ratio test (Chars/Tokens):')
test_texts = [
# Chinese samples (~200 characters each)
"人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器,该领域的研究包括机器人、语言识别、图像识别、自然语言处理和专家系统等。人工智能从诞生以来,理论和技术日益成熟,应用领域也不断扩大,可以设想,未来人工智能带来的科技产品,将会是人类智慧的“容器”。人工智能可以对人的意识、思维的信息过程的模拟。人工智能不是人的智能,但能像人那样思考、也可能超过人的智能。",
"星际航行是指在星系内甚至星系间的空间中进行的航行。由于宇宙空间极其广阔,传统的化学火箭动力在恒星间航行时显得力不从心。科学家们提出了多种方案,包括离子推进器、核热火箭、甚至是利用反物质作为能源的设想。此外,曲率驱动和虫洞旅行等科幻概念也在理论物理研究中被反复探讨。尽管目前人类的足迹仅限于月球,但随着核聚变技术和材料科学的突破,前往火星乃至更遥远的太阳系边缘将成为可能。",
# English samples (~200 words each)
"Large language models (LLMs) are a type of artificial intelligence (AI) trained on vast amounts of text data to understand and generate human-like language. These models use deep learning techniques, specifically transformers, to process and predict the next word in a sequence. LLMs like GPT-4, Llama, and Claude have demonstrated remarkable capabilities in coding, translation, and creative writing. However, they also face challenges such as hallucinations, where the model generates factually incorrect information, and the need for significant computational resources.",
"The development of sustainable energy is crucial for the future of our planet. As climate change continues to impact global weather patterns, transitioning from fossil fuels to renewable sources like solar, wind, and hydroelectric power has become an urgent priority. Innovations in battery storage technology and smart grid management are essential to ensure a reliable energy supply. International cooperation and policy frameworks are also necessary to drive the global shift towards a greener economy and reduce carbon emissions.",
# Mixed Chinese/English sample
"Python 是一种高级编程语言以其简洁的语法和强大的生态系统而闻名。It is widely used in data science, machine learning, and web development. 开发者可以利用 NumPy, Pandas, and PyTorch 等库快速构建复杂的应用。学习 Python 的过程非常愉快因为它的代码读起来就像英语一样。Whether you are a beginner or an expert, Python offers something for everyone.",
]
total_compression = 0
for i, text in enumerate(test_texts):
encoded = tokenizer.encode(text)
token_count = len(encoded)
char_count = len(text)
compression_ratio = char_count / token_count
total_compression += compression_ratio
print(f"Sample {i+1} | Chars: {char_count:4} | Tokens: {token_count:3} | Ratio: {compression_ratio:.2f}")
print(f"Average compression ratio: {total_compression / len(test_texts):.2f}")
print('-'*100)
print('Streaming decode (byte buffering) test:')
input_ids = model_inputs['input_ids']
token_cache = []
for tid in input_ids:
token_cache.append(tid)
current_decode = tokenizer.decode(token_cache)
if current_decode and '\ufffd' not in current_decode:
raw_tokens = [tokenizer.convert_ids_to_tokens(int(t)) for t in token_cache]
print(f'Token IDs: {str(token_cache):15} -> Raw: {str(raw_tokens):20} -> Decoded: {current_decode}')
token_cache = []
if __name__ == '__main__':
train_tokenizer(DATA_PATH, TOKENIZER_DIR, VOCAB_SIZE)
eval_tokenizer(TOKENIZER_DIR)

177 trainer/trainer_utils.py Normal file
@ -0,0 +1,177 @@
"""Collection of training utility functions."""
import os
import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import random
import math
import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import Sampler
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
from model.model_minimind import MiniMindForCausalLM
def get_model_params(model, config):
total = sum(p.numel() for p in model.parameters()) / 1e6
n_routed = getattr(config, 'n_routed_experts', getattr(config, 'num_experts', 0))
n_active = getattr(config, 'num_experts_per_tok', 0)
n_shared = getattr(config, 'n_shared_experts', 0)
expert = sum(p.numel() for n, p in model.named_parameters() if 'mlp.experts.0.' in n) / 1e6
shared_expert = sum(p.numel() for n, p in model.named_parameters() if 'mlp.shared_experts.0.' in n) / 1e6
base = total - (expert * n_routed) - (shared_expert * n_shared)
active = base + (expert * n_active) + (shared_expert * n_shared)
if active < total: Logger(f'Model Params: {total:.2f}M-A{active:.2f}M')
else: Logger(f'Model Params: {total:.2f}M')
def is_main_process():
return not dist.is_initialized() or dist.get_rank() == 0
def Logger(content):
if is_main_process():
print(content)
def get_lr(current_step, total_steps, lr):
return lr*(0.1 + 0.45*(1 + math.cos(math.pi * current_step / total_steps)))
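The schedule above is a warmup-free cosine decay from the peak learning rate down to a 10% floor. A standalone sketch mirroring `get_lr` (the name `cosine_lr` is illustrative, not part of the trainer):

```python
import math

# Mirror of get_lr: lr * (0.1 + 0.45 * (1 + cos(pi * t / T)))
# -> step 0 gives 1.0 * lr (peak), the midpoint gives 0.55 * lr,
#    and the final step gives 0.1 * lr (the 10% floor).
def cosine_lr(step, total_steps, lr):
    return lr * (0.1 + 0.45 * (1 + math.cos(math.pi * step / total_steps)))
```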
def init_distributed_mode():
if int(os.environ.get("RANK", -1)) == -1:
return 0 # non-DDP mode
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
return local_rank
def setup_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoch=0, step=0, wandb=None, save_dir='../checkpoints', **kwargs):
os.makedirs(save_dir, exist_ok=True)
moe_path = '_moe' if lm_config.use_moe else ''
ckp_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}.pth'
resume_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}_resume.pth'
if model is not None:
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
ckp_tmp = ckp_path + '.tmp'
torch.save(state_dict, ckp_tmp)
os.replace(ckp_tmp, ckp_path)
wandb_id = None
if wandb:
if hasattr(wandb, 'get_run'):
run = wandb.get_run()
wandb_id = getattr(run, 'id', None) if run else None
else:
wandb_id = getattr(wandb, 'id', None)
resume_data = {
'model': state_dict,
'optimizer': optimizer.state_dict(),
'epoch': epoch,
'step': step,
'world_size': dist.get_world_size() if dist.is_initialized() else 1,
'wandb_id': wandb_id
}
for key, value in kwargs.items():
if value is not None:
if hasattr(value, 'state_dict'):
raw_value = value.module if isinstance(value, DistributedDataParallel) else value
raw_value = getattr(raw_value, '_orig_mod', raw_value)
resume_data[key] = raw_value.state_dict()
else:
resume_data[key] = value
resume_tmp = resume_path + '.tmp'
torch.save(resume_data, resume_tmp)
os.replace(resume_tmp, resume_path)
del state_dict, resume_data
torch.cuda.empty_cache()
else: # 加载模式
if os.path.exists(resume_path):
ckp_data = torch.load(resume_path, map_location='cpu')
saved_ws = ckp_data.get('world_size', 1)
current_ws = dist.get_world_size() if dist.is_initialized() else 1
if saved_ws != current_ws:
# keep the global data position constant: step_new = step_old * ws_old // ws_new
ckp_data['step'] = ckp_data['step'] * saved_ws // current_ws
Logger(f'World size changed ({saved_ws} -> {current_ws}); resume step converted to {ckp_data["step"]}')
return ckp_data
return None
def init_model(lm_config, from_weight='pretrain', tokenizer_path='../model', save_dir='../out', device='cuda'):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
model = MiniMindForCausalLM(lm_config)
if from_weight != 'none':
moe_suffix = '_moe' if lm_config.use_moe else ''
weight_path = f'{save_dir}/{from_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
weights = torch.load(weight_path, map_location=device)
model.load_state_dict(weights, strict=False)
get_model_params(model, lm_config)
Logger(f'Trainable Params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f}M')
return model.to(device), tokenizer
class SkipBatchSampler(Sampler):
def __init__(self, sampler, batch_size, skip_batches=0):
self.sampler = sampler
self.batch_size = batch_size
self.skip_batches = skip_batches
def __iter__(self):
batch = []
skipped = 0
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
if skipped < self.skip_batches:
skipped += 1
batch = []
continue
yield batch
batch = []
if len(batch) > 0 and skipped >= self.skip_batches:
yield batch
def __len__(self):
total_batches = (len(self.sampler) + self.batch_size - 1) // self.batch_size
return max(0, total_batches - self.skip_batches)
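The skip-and-resume behavior of `SkipBatchSampler` can be sanity-checked with a standalone mirror of its iteration logic (the helper name below is hypothetical, not part of the trainer):

```python
# Standalone mirror of SkipBatchSampler's skip logic: drop the first
# `skip_batches` full batches (already trained on before the restart),
# then yield the rest, including any short tail batch.
def skip_batches_iter(indices, batch_size, skip_batches):
    batch, skipped = [], 0
    for idx in indices:
        batch.append(idx)
        if len(batch) == batch_size:
            if skipped < skip_batches:  # this batch was consumed pre-restart
                skipped += 1
                batch = []
                continue
            yield batch
            batch = []
    if batch and skipped >= skip_batches:  # tail batch after skipping is done
        yield batch

# With 10 indices, batch_size=3 and skip_batches=2, resumption starts
# at the third batch:
print(list(skip_batches_iter(range(10), 3, 2)))  # [[6, 7, 8], [9]]
```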
class LMForRewardModel:
def __init__(self, model_path, device="cuda", dtype=torch.float16):
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
self.model = AutoModel.from_pretrained(model_path, torch_dtype=dtype, trust_remote_code=True)
self.model = self.model.to(device).eval()
self.device = device
@torch.no_grad()
def get_score(self, messages, response):
history_text = "\n".join([f"{m['role']}: {m['content']}" for m in messages[:-1]])
last_query = messages[-1]['content'] if messages else ""
message_context = f"{history_text}\n以上是对话历史。我的新问题是:\n{last_query}" if history_text else last_query  # Chinese connector kept verbatim (model prompt): "Above is the conversation history. My new question is:"
eval_messages = [
{"role": "user", "content": message_context},
{"role": "assistant", "content": response}
]
score = self.model.get_score(self.tokenizer, eval_messages)
return max(min(score, 3.0), -3.0)