Compare commits
No commits in common. "v2" and "master" have entirely different histories.
6
.gitignore
vendored
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
.DS_Store
|
||||||
|
out
|
||||||
|
website/
|
||||||
|
docs-minimind/
|
||||||
@ -1,19 +0,0 @@
|
|||||||
# Read the Docs 配置文件
|
|
||||||
version: 2
|
|
||||||
|
|
||||||
# 构建配置
|
|
||||||
build:
|
|
||||||
os: ubuntu-22.04
|
|
||||||
tools:
|
|
||||||
python: "3.11"
|
|
||||||
|
|
||||||
# MkDocs 配置
|
|
||||||
mkdocs:
|
|
||||||
configuration: mkdocs.yml
|
|
||||||
fail_on_warning: false
|
|
||||||
|
|
||||||
# Python 依赖
|
|
||||||
python:
|
|
||||||
install:
|
|
||||||
- requirements: requirements.txt
|
|
||||||
|
|
||||||
128
CODE_OF_CONDUCT.md
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
# Contributor Covenant Code of Conduct
|
||||||
|
|
||||||
|
## Our Pledge
|
||||||
|
|
||||||
|
We as members, contributors, and leaders pledge to make participation in our
|
||||||
|
community a harassment-free experience for everyone, regardless of age, body
|
||||||
|
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
||||||
|
identity and expression, level of experience, education, socio-economic status,
|
||||||
|
nationality, personal appearance, race, religion, or sexual identity
|
||||||
|
and orientation.
|
||||||
|
|
||||||
|
We pledge to act and interact in ways that contribute to an open, welcoming,
|
||||||
|
diverse, inclusive, and healthy community.
|
||||||
|
|
||||||
|
## Our Standards
|
||||||
|
|
||||||
|
Examples of behavior that contributes to a positive environment for our
|
||||||
|
community include:
|
||||||
|
|
||||||
|
* Demonstrating empathy and kindness toward other people
|
||||||
|
* Being respectful of differing opinions, viewpoints, and experiences
|
||||||
|
* Giving and gracefully accepting constructive feedback
|
||||||
|
* Accepting responsibility and apologizing to those affected by our mistakes,
|
||||||
|
and learning from the experience
|
||||||
|
* Focusing on what is best not just for us as individuals, but for the
|
||||||
|
overall community
|
||||||
|
|
||||||
|
Examples of unacceptable behavior include:
|
||||||
|
|
||||||
|
* The use of sexualized language or imagery, and sexual attention or
|
||||||
|
advances of any kind
|
||||||
|
* Trolling, insulting or derogatory comments, and personal or political attacks
|
||||||
|
* Public or private harassment
|
||||||
|
* Publishing others' private information, such as a physical or email
|
||||||
|
address, without their explicit permission
|
||||||
|
* Other conduct which could reasonably be considered inappropriate in a
|
||||||
|
professional setting
|
||||||
|
|
||||||
|
## Enforcement Responsibilities
|
||||||
|
|
||||||
|
Community leaders are responsible for clarifying and enforcing our standards of
|
||||||
|
acceptable behavior and will take appropriate and fair corrective action in
|
||||||
|
response to any behavior that they deem inappropriate, threatening, offensive,
|
||||||
|
or harmful.
|
||||||
|
|
||||||
|
Community leaders have the right and responsibility to remove, edit, or reject
|
||||||
|
comments, commits, code, wiki edits, issues, and other contributions that are
|
||||||
|
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
||||||
|
decisions when appropriate.
|
||||||
|
|
||||||
|
## Scope
|
||||||
|
|
||||||
|
This Code of Conduct applies within all community spaces, and also applies when
|
||||||
|
an individual is officially representing the community in public spaces.
|
||||||
|
Examples of representing our community include using an official e-mail address,
|
||||||
|
posting via an official social media account, or acting as an appointed
|
||||||
|
representative at an online or offline event.
|
||||||
|
|
||||||
|
## Enforcement
|
||||||
|
|
||||||
|
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||||
|
reported to the community leaders responsible for enforcement at
|
||||||
|
.
|
||||||
|
All complaints will be reviewed and investigated promptly and fairly.
|
||||||
|
|
||||||
|
All community leaders are obligated to respect the privacy and security of the
|
||||||
|
reporter of any incident.
|
||||||
|
|
||||||
|
## Enforcement Guidelines
|
||||||
|
|
||||||
|
Community leaders will follow these Community Impact Guidelines in determining
|
||||||
|
the consequences for any action they deem in violation of this Code of Conduct:
|
||||||
|
|
||||||
|
### 1. Correction
|
||||||
|
|
||||||
|
**Community Impact**: Use of inappropriate language or other behavior deemed
|
||||||
|
unprofessional or unwelcome in the community.
|
||||||
|
|
||||||
|
**Consequence**: A private, written warning from community leaders, providing
|
||||||
|
clarity around the nature of the violation and an explanation of why the
|
||||||
|
behavior was inappropriate. A public apology may be requested.
|
||||||
|
|
||||||
|
### 2. Warning
|
||||||
|
|
||||||
|
**Community Impact**: A violation through a single incident or series
|
||||||
|
of actions.
|
||||||
|
|
||||||
|
**Consequence**: A warning with consequences for continued behavior. No
|
||||||
|
interaction with the people involved, including unsolicited interaction with
|
||||||
|
those enforcing the Code of Conduct, for a specified period of time. This
|
||||||
|
includes avoiding interactions in community spaces as well as external channels
|
||||||
|
like social media. Violating these terms may lead to a temporary or
|
||||||
|
permanent ban.
|
||||||
|
|
||||||
|
### 3. Temporary Ban
|
||||||
|
|
||||||
|
**Community Impact**: A serious violation of community standards, including
|
||||||
|
sustained inappropriate behavior.
|
||||||
|
|
||||||
|
**Consequence**: A temporary ban from any sort of interaction or public
|
||||||
|
communication with the community for a specified period of time. No public or
|
||||||
|
private interaction with the people involved, including unsolicited interaction
|
||||||
|
with those enforcing the Code of Conduct, is allowed during this period.
|
||||||
|
Violating these terms may lead to a permanent ban.
|
||||||
|
|
||||||
|
### 4. Permanent Ban
|
||||||
|
|
||||||
|
**Community Impact**: Demonstrating a pattern of violation of community
|
||||||
|
standards, including sustained inappropriate behavior, harassment of an
|
||||||
|
individual, or aggression toward or disparagement of classes of individuals.
|
||||||
|
|
||||||
|
**Consequence**: A permanent ban from any sort of public interaction within
|
||||||
|
the community.
|
||||||
|
|
||||||
|
## Attribution
|
||||||
|
|
||||||
|
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
||||||
|
version 2.0, available at
|
||||||
|
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
|
||||||
|
|
||||||
|
Community Impact Guidelines were inspired by [Mozilla's code of conduct
|
||||||
|
enforcement ladder](https://github.com/mozilla/diversity).
|
||||||
|
|
||||||
|
[homepage]: https://www.contributor-covenant.org
|
||||||
|
|
||||||
|
For answers to common questions about this code of conduct, see the FAQ at
|
||||||
|
https://www.contributor-covenant.org/faq. Translations are available at
|
||||||
|
https://www.contributor-covenant.org/translations.
|
||||||
201
LICENSE
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
1992
README_en.md
Normal file
0
dataset/__init__.py
Normal file
5
dataset/dataset.md
Executable file
@ -0,0 +1,5 @@
|
|||||||
|
# MiniMind Datasets
|
||||||
|
|
||||||
|
将所有下载的数据集文件放置到当前目录.
|
||||||
|
|
||||||
|
Place the downloaded dataset file in the current directory.
|
||||||
256
dataset/lm_dataset.py
Normal file
@ -0,0 +1,256 @@
|
|||||||
|
from torch.utils.data import Dataset
|
||||||
|
import torch
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from datasets import load_dataset, Features, Sequence, Value
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
def pre_processing_chat(conversations, add_system_ratio=0.2):
|
||||||
|
# tool use 数据完整保留不做处理
|
||||||
|
if any(conv.get('tools') for conv in conversations): return conversations
|
||||||
|
|
||||||
|
SYSTEM_PROMPTS = [
|
||||||
|
"你是一个知识丰富的AI,尽力为用户提供准确的信息。",
|
||||||
|
"你是minimind,一个小巧但有用的语言模型。",
|
||||||
|
"你是一个专业的AI助手,请提供有价值的回答。",
|
||||||
|
"你是minimind,请尽力帮助用户解决问题。",
|
||||||
|
"你是一个可靠的AI,请给出准确的回答。",
|
||||||
|
"You are a helpful AI assistant.",
|
||||||
|
"You are minimind, a lightweight intelligent assistant.",
|
||||||
|
"You are a friendly chatbot. Please answer the user's questions carefully.",
|
||||||
|
"You are a knowledgeable AI. Try your best to provide accurate information.",
|
||||||
|
"You are minimind, a small but useful language model."
|
||||||
|
]
|
||||||
|
# 概率性添加system
|
||||||
|
if conversations[0].get('role') != 'system':
|
||||||
|
if random.random() < add_system_ratio:
|
||||||
|
return [{'role': 'system', 'content': random.choice(SYSTEM_PROMPTS)}] + conversations
|
||||||
|
return conversations
|
||||||
|
|
||||||
|
def post_processing_chat(prompt_content, empty_think_ratio=0.2):
|
||||||
|
# 以80%概率移除空思考标签
|
||||||
|
if '<think>\n\n</think>\n\n' in prompt_content and random.random() > empty_think_ratio:
|
||||||
|
prompt_content = prompt_content.replace('<think>\n\n</think>\n\n', '')
|
||||||
|
return prompt_content
|
||||||
|
|
||||||
|
class PretrainDataset(Dataset):
|
||||||
|
def __init__(self, data_path, tokenizer, max_length=512):
|
||||||
|
super().__init__()
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.max_length = max_length
|
||||||
|
self.samples = load_dataset('json', data_files=data_path, split='train')
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.samples)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
sample = self.samples[index]
|
||||||
|
tokens = self.tokenizer(str(sample['text']), add_special_tokens=False, max_length=self.max_length - 2, truncation=True).input_ids
|
||||||
|
tokens = [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
|
||||||
|
input_ids = tokens + [self.tokenizer.pad_token_id] * (self.max_length - len(tokens))
|
||||||
|
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
||||||
|
labels = input_ids.clone()
|
||||||
|
labels[input_ids == self.tokenizer.pad_token_id] = -100
|
||||||
|
return input_ids, labels
|
||||||
|
|
||||||
|
|
||||||
|
class SFTDataset(Dataset):
|
||||||
|
def __init__(self, jsonl_path, tokenizer, max_length=1024):
|
||||||
|
super().__init__()
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.max_length = max_length
|
||||||
|
features = Features({'conversations': [{'role': Value('string'), 'content': Value('string'), 'reasoning_content': Value('string'), 'tools': Value('string'), 'tool_calls': Value('string')}]})
|
||||||
|
self.samples = load_dataset('json', data_files=jsonl_path, split='train', features=features)
|
||||||
|
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant\n', add_special_tokens=False).input_ids
|
||||||
|
self.eos_id = tokenizer(f'{tokenizer.eos_token}\n', add_special_tokens=False).input_ids
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.samples)
|
||||||
|
|
||||||
|
def create_chat_prompt(self, conversations):
|
||||||
|
messages = []
|
||||||
|
tools = None
|
||||||
|
for message in conversations:
|
||||||
|
message = dict(message)
|
||||||
|
if message.get("role") == "system" and message.get("tools"):
|
||||||
|
tools = json.loads(message["tools"]) if isinstance(message["tools"], str) else message["tools"]
|
||||||
|
if message.get("tool_calls") and isinstance(message["tool_calls"], str):
|
||||||
|
message["tool_calls"] = json.loads(message["tool_calls"])
|
||||||
|
messages.append(message)
|
||||||
|
return self.tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=False,
|
||||||
|
tools=tools
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_labels(self, input_ids):
|
||||||
|
labels = [-100] * len(input_ids)
|
||||||
|
i = 0
|
||||||
|
while i < len(input_ids):
|
||||||
|
if input_ids[i:i + len(self.bos_id)] == self.bos_id:
|
||||||
|
start = i + len(self.bos_id)
|
||||||
|
end = start
|
||||||
|
while end < len(input_ids):
|
||||||
|
if input_ids[end:end + len(self.eos_id)] == self.eos_id:
|
||||||
|
break
|
||||||
|
end += 1
|
||||||
|
for j in range(start, min(end + len(self.eos_id), self.max_length)):
|
||||||
|
labels[j] = input_ids[j]
|
||||||
|
i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
|
return labels
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
sample = self.samples[index]
|
||||||
|
conversations = pre_processing_chat(sample['conversations'])
|
||||||
|
prompt = self.create_chat_prompt(conversations)
|
||||||
|
prompt = post_processing_chat(prompt)
|
||||||
|
input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
|
||||||
|
input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
|
||||||
|
labels = self.generate_labels(input_ids)
|
||||||
|
# # === 调试打印 ===
|
||||||
|
# print(f"\n--- Sample {index} ---")
|
||||||
|
# for i, (x, y) in enumerate(zip(input_ids[:-1], labels[1:])):
|
||||||
|
# print(f"{i:3d}: X={self.tokenizer.decode([x])!r:16s} ---> Y={self.tokenizer.decode([input_ids[i+1]])!r:16s} label={y}")
|
||||||
|
# # ================
|
||||||
|
return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)
|
||||||
|
|
||||||
|
|
||||||
|
class DPODataset(Dataset):
|
||||||
|
def __init__(self, file_path, tokenizer, max_length=4096):
|
||||||
|
super().__init__()
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.max_length = max_length
|
||||||
|
self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
||||||
|
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant\n', add_special_tokens=False).input_ids
|
||||||
|
self.eos_id = tokenizer(f'{tokenizer.eos_token}\n', add_special_tokens=False).input_ids
|
||||||
|
self.samples = load_dataset('json', data_files=file_path, split='train')
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.samples)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
sample = self.samples[index]
|
||||||
|
chosen = sample['chosen'] # 是一个 list,里面包含若干 {role, content}
|
||||||
|
rejected = sample['rejected'] # 同上
|
||||||
|
chosen_prompt = self.tokenizer.apply_chat_template(
|
||||||
|
chosen, tokenize=False, add_generation_prompt=False
|
||||||
|
)
|
||||||
|
chosen_prompt = post_processing_chat(chosen_prompt)
|
||||||
|
|
||||||
|
rejected_prompt = self.tokenizer.apply_chat_template(
|
||||||
|
rejected, tokenize=False, add_generation_prompt=False
|
||||||
|
)
|
||||||
|
rejected_prompt = post_processing_chat(rejected_prompt)
|
||||||
|
chosen_encoding = self.tokenizer(
|
||||||
|
chosen_prompt, truncation=True, max_length=self.max_length, padding='max_length'
|
||||||
|
)
|
||||||
|
rejected_encoding = self.tokenizer(
|
||||||
|
rejected_prompt, truncation=True, max_length=self.max_length, padding='max_length'
|
||||||
|
)
|
||||||
|
|
||||||
|
chosen_input_ids = chosen_encoding['input_ids']
|
||||||
|
chosen_loss_mask = self.generate_loss_mask(chosen_input_ids)
|
||||||
|
|
||||||
|
rejected_input_ids = rejected_encoding['input_ids']
|
||||||
|
rejected_loss_mask = self.generate_loss_mask(rejected_input_ids)
|
||||||
|
x_chosen = torch.tensor(chosen_input_ids[:-1], dtype=torch.long)
|
||||||
|
y_chosen = torch.tensor(chosen_input_ids[1:], dtype=torch.long)
|
||||||
|
mask_chosen = torch.tensor(chosen_loss_mask[1:], dtype=torch.long)
|
||||||
|
x_rejected = torch.tensor(rejected_input_ids[:-1], dtype=torch.long)
|
||||||
|
y_rejected = torch.tensor(rejected_input_ids[1:], dtype=torch.long)
|
||||||
|
mask_rejected = torch.tensor(rejected_loss_mask[1:], dtype=torch.long)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'x_chosen': x_chosen,
|
||||||
|
'y_chosen': y_chosen,
|
||||||
|
'mask_chosen': mask_chosen,
|
||||||
|
'x_rejected': x_rejected,
|
||||||
|
'y_rejected': y_rejected,
|
||||||
|
'mask_rejected': mask_rejected
|
||||||
|
}
|
||||||
|
|
||||||
|
def generate_loss_mask(self, input_ids):
|
||||||
|
loss_mask = [0] * len(input_ids)
|
||||||
|
i = 0
|
||||||
|
while i < len(input_ids):
|
||||||
|
if input_ids[i:i + len(self.bos_id)] == self.bos_id:
|
||||||
|
start = i + len(self.bos_id)
|
||||||
|
end = start
|
||||||
|
while end < len(input_ids):
|
||||||
|
if input_ids[end:end + len(self.eos_id)] == self.eos_id:
|
||||||
|
break
|
||||||
|
end += 1
|
||||||
|
for j in range(start, min(end + len(self.eos_id), self.max_length)):
|
||||||
|
loss_mask[j] = 1
|
||||||
|
i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
|
return loss_mask
|
||||||
|
|
||||||
|
|
||||||
|
class RLAIFDataset(Dataset):
|
||||||
|
def __init__(self, jsonl_path, tokenizer, max_length=1024, thinking_ratio=0.5):
|
||||||
|
super().__init__()
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.max_length = max_length
|
||||||
|
self.thinking_ratio = thinking_ratio # 按概率开启 thinking
|
||||||
|
self.samples = load_dataset('json', data_files=jsonl_path, split='train')
|
||||||
|
self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
|
||||||
|
self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.samples)
|
||||||
|
|
||||||
|
def create_chat_prompt(self, conversations):
|
||||||
|
conversations = pre_processing_chat(conversations)
|
||||||
|
use_thinking = random.random() < self.thinking_ratio
|
||||||
|
return self.tokenizer.apply_chat_template(
|
||||||
|
conversations[:-1],
|
||||||
|
tokenize=False,
|
||||||
|
open_thinking=use_thinking,
|
||||||
|
add_generation_prompt=True
|
||||||
|
)
|
||||||
|
def __getitem__(self, index):
|
||||||
|
sample = self.samples[index]
|
||||||
|
prompt = self.create_chat_prompt(sample['conversations'])
|
||||||
|
|
||||||
|
return {
|
||||||
|
'prompt': prompt,
|
||||||
|
'answer': ""
|
||||||
|
}
|
||||||
|
|
||||||
|
class AgentRLDataset(Dataset):
|
||||||
|
def __init__(self, jsonl_path, tokenizer, max_length=1024):
|
||||||
|
super().__init__()
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.max_length = max_length
|
||||||
|
self.samples = []
|
||||||
|
with open(jsonl_path, 'r', encoding='utf-8') as f:
|
||||||
|
for line in f:
|
||||||
|
self.samples.append(json.loads(line.strip()))
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.samples)
|
||||||
|
|
||||||
|
def parse_conversations(self, conversations):
|
||||||
|
messages = []
|
||||||
|
tools = None
|
||||||
|
for message in conversations:
|
||||||
|
message = dict(message)
|
||||||
|
if message.get("role") == "system" and message.get("tools"):
|
||||||
|
tools = json.loads(message["tools"]) if isinstance(message["tools"], str) else message["tools"]
|
||||||
|
messages.append(message)
|
||||||
|
return messages[:-1], tools
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
sample = self.samples[index]
|
||||||
|
messages, tools = self.parse_conversations(sample['conversations'])
|
||||||
|
return {'messages': messages, 'tools': tools, 'gt': sample['gt']}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pass
|
||||||
|
Before Width: | Height: | Size: 136 KiB |
|
Before Width: | Height: | Size: 73 KiB |
|
Before Width: | Height: | Size: 230 KiB |
|
Before Width: | Height: | Size: 104 KiB |
|
Before Width: | Height: | Size: 239 KiB |
|
Before Width: | Height: | Size: 121 KiB |
|
Before Width: | Height: | Size: 372 KiB |
|
Before Width: | Height: | Size: 519 KiB |
|
Before Width: | Height: | Size: 146 KiB |
|
Before Width: | Height: | Size: 3.8 MiB |
|
Before Width: | Height: | Size: 559 KiB |
|
Before Width: | Height: | Size: 531 KiB |
|
Before Width: | Height: | Size: 144 KiB |
|
Before Width: | Height: | Size: 1006 KiB |
|
Before Width: | Height: | Size: 943 KiB |
|
Before Width: | Height: | Size: 214 KiB |
|
Before Width: | Height: | Size: 246 KiB |
|
Before Width: | Height: | Size: 246 KiB |
|
Before Width: | Height: | Size: 241 KiB |
|
Before Width: | Height: | Size: 234 KiB |
|
Before Width: | Height: | Size: 145 KiB |
|
Before Width: | Height: | Size: 152 KiB |
124
docs/index.md
@ -1,124 +0,0 @@
|
|||||||
# Welcome to MiniMind!
|
|
||||||
|
|
||||||
<figure markdown>
|
|
||||||

|
|
||||||
<figcaption><strong>"Simplicity is the ultimate sophistication"</strong></figcaption>
|
|
||||||
</figure>
|
|
||||||
|
|
||||||
## 📌 Introduction
|
|
||||||
|
|
||||||
**MiniMind** is a complete, open-source project for training ultra-small language models from scratch with minimal cost. Train a **26M** ChatBot in just **2 hours** with only **$3** on a single 3090 GPU!
|
|
||||||
|
|
||||||
- **MiniMind** series is extremely lightweight, the smallest version is **1/7000** the size of GPT-3
|
|
||||||
- Complete implementation covering:
|
|
||||||
- **Tokenizer training** with custom vocabulary
|
|
||||||
- **Pretraining** (knowledge learning)
|
|
||||||
- **Supervised Fine-Tuning (SFT)** (conversation patterns)
|
|
||||||
- **LoRA fine-tuning** (parameter-efficient adaptation)
|
|
||||||
- **Direct Preference Optimization (DPO)** (human preference alignment)
|
|
||||||
- **RLAIF algorithms** (PPO/GRPO/SPO - reinforcement learning)
|
|
||||||
- **Knowledge distillation** (compress large model knowledge)
|
|
||||||
- **Model reasoning distillation** (DeepSeek-R1 style)
|
|
||||||
- **YaRN algorithm** (context length extrapolation)
|
|
||||||
- **Pure PyTorch implementation**: All core algorithms are implemented from scratch using native PyTorch, without relying on third-party abstract interfaces
|
|
||||||
- **Educational value**: This is not only a full-stage open-source reproduction of large language models, but also a comprehensive tutorial for getting started with LLMs
|
|
||||||
- **Extended capabilities**: MiniMind now supports [MiniMind-V](https://github.com/jingyaogong/minimind-v) for vision multimodal tasks
|
|
||||||
|
|
||||||
!!! note "Training Cost & Time"
|
|
||||||
"2 hours" is based on **NVIDIA 3090** hardware (single card) testing
|
|
||||||
|
|
||||||
"$3" refers to GPU server rental cost
|
|
||||||
|
|
||||||
With 8× RTX 4090 GPUs, training time can be compressed to **under 10 minutes**
|
|
||||||
|
|
||||||
## ✨ Key Highlights
|
|
||||||
|
|
||||||
- **Ultra-low cost**: Single 3090, 2 hours, $3 to train a fully functional ChatBot from scratch
|
|
||||||
- **Complete pipeline**: Tokenizer → Pretraining → SFT → LoRA → DPO/RLAIF → Distillation → Reasoning
|
|
||||||
- **Latest algorithms**: Implements cutting-edge techniques including GRPO, SPO, and YaRN
|
|
||||||
- **Education-friendly**: Clean, well-documented code suitable for learning LLM principles
|
|
||||||
- **Ecosystem compatible**: Seamless support for `transformers`, `trl`, `peft`, `llama.cpp`, `vllm`, `ollama`, and `Llama-Factory`
|
|
||||||
- **Full capabilities**: Supports multi-GPU training (DDP/DeepSpeed), model visualization (Wandb/SwanLab), and dynamic checkpoint management
|
|
||||||
- **Production-ready**: OpenAI API protocol support for easy integration with third-party UIs (FastGPT, Open-WebUI, etc.)
|
|
||||||
- **Multimodal extension**: Extended to vision with [MiniMind-V](https://github.com/jingyaogong/minimind-v)
|
|
||||||
|
|
||||||
## 📊 Model Series
|
|
||||||
|
|
||||||
### MiniMind2 Series (Latest - 2025.04.26)
|
|
||||||
|
|
||||||
| Model | Parameters | Vocabulary | Layers | Hidden Dim | Context | Inference Memory |
|
|
||||||
|-------|-----------|------------|--------|-----------|---------|-----------------|
|
|
||||||
| MiniMind2-small | 26M | 6,400 | 8 | 512 | 2K | ~0.5 GB |
|
|
||||||
| MiniMind2-MoE | 145M | 6,400 | 8 | 640 | 2K | ~1.0 GB |
|
|
||||||
| MiniMind2 | 104M | 6,400 | 16 | 768 | 2K | ~1.0 GB |
|
|
||||||
|
|
||||||
### MiniMind-V1 Series (Legacy - 2024.09.01)
|
|
||||||
|
|
||||||
| Model | Parameters | Vocabulary | Layers | Hidden Dim | Context |
|
|
||||||
|-------|-----------|------------|--------|-----------|---------|
|
|
||||||
| minimind-v1-small | 26M | 6,400 | 8 | 512 | 2K |
|
|
||||||
| minimind-v1-moe | 104M | 6,400 | 8 | 512 | 2K |
|
|
||||||
| minimind-v1 | 108M | 6,400 | 16 | 768 | 2K |
|
|
||||||
|
|
||||||
## 📅 Latest Updates (2025-10-24)
|
|
||||||
|
|
||||||
🔥 **RLAIF Training Algorithms**: Native implementation of PPO, GRPO, and SPO
|
|
||||||
|
|
||||||
- **YaRN Algorithm**: RoPE length extrapolation for improved long-sequence handling
|
|
||||||
- **Adaptive Thinking**: Reasoning models support optional thinking chains
|
|
||||||
- **Full template support**: Tool calling and reasoning tags (`<tool_call>`, `<think>`, etc.)
|
|
||||||
- **Visualization**: Switched from WandB to [SwanLab](https://swanlab.cn/) (China-friendly)
|
|
||||||
- **Reasoning models**: Complete MiniMind-Reason series based on DeepSeek-R1 distillation
|
|
||||||
|
|
||||||
## 🎯 Project Contents
|
|
||||||
|
|
||||||
- Complete MiniMind-LLM architecture code (Dense + MoE models)
|
|
||||||
- Detailed Tokenizer training code
|
|
||||||
- Full training pipeline: Pretrain → SFT → LoRA → RLHF/RLAIF → Distillation
|
|
||||||
- High-quality, curated and deduplicated datasets at all stages
|
|
||||||
- Native PyTorch implementation of key algorithms, minimal third-party dependencies
|
|
||||||
- Multi-GPU training support (single-machine multi-card DDP, DeepSpeed, distributed clusters)
|
|
||||||
- Visualization with Wandb/SwanLab
|
|
||||||
- Model evaluation on third-party benchmarks (C-Eval, C-MMLU, OpenBookQA)
|
|
||||||
- YaRN algorithm for RoPE context length extrapolation
|
|
||||||
- OpenAI API protocol server for easy integration
|
|
||||||
- Streamlit web UI for chat
|
|
||||||
- Full compatibility with community tools: llama.cpp, vllm, ollama, Llama-Factory
|
|
||||||
- MiniMind-Reason models: Complete open-source data + weights for reasoning distillation
|
|
||||||
|
|
||||||
## 🚀 Quick Navigation
|
|
||||||
|
|
||||||
- **[Quick Start](quickstart.md)** - Environment setup, model download, quick testing
|
|
||||||
- **[Model Training](training.md)** - Pretraining, SFT, LoRA, RLHF, RLAIF, and reasoning training
|
|
||||||
|
|
||||||
## 🔗 Links & Resources
|
|
||||||
|
|
||||||
**Project Repositories**:
|
|
||||||
- **GitHub**: [https://github.com/jingyaogong/minimind](https://github.com/jingyaogong/minimind)
|
|
||||||
- **HuggingFace**: [MiniMind Collection](https://huggingface.co/collections/jingyaogong/minimind-66caf8d999f5c7fa64f399e5)
|
|
||||||
- **ModelScope**: [MiniMind Profile](https://www.modelscope.cn/profile/gongjy)
|
|
||||||
|
|
||||||
**Online Demos**:
|
|
||||||
- [ModelScope Studio - Standard Chat](https://www.modelscope.cn/studios/gongjy/MiniMind)
|
|
||||||
- [ModelScope Studio - Reasoning Model](https://www.modelscope.cn/studios/gongjy/MiniMind-Reasoning)
|
|
||||||
- [Bilibili Video Introduction](https://www.bilibili.com/video/BV12dHPeqE72/)
|
|
||||||
|
|
||||||
**Vision Extension**:
|
|
||||||
- [MiniMind-V](https://github.com/jingyaogong/minimind-v) - Multimodal vision language models
|
|
||||||
|
|
||||||
## 💡 Why MiniMind?
|
|
||||||
|
|
||||||
The AI community is flooded with high-cost, complex frameworks that abstract away the fundamentals. MiniMind aims to democratize LLM learning by:
|
|
||||||
|
|
||||||
1. **Lowering the barrier**: No need for expensive GPUs or cloud services
|
|
||||||
2. **Understanding, not just using**: Learn every detail from tokenization to inference
|
|
||||||
3. **End-to-end learning**: Train from scratch, not just fine-tune existing models
|
|
||||||
4. **Code clarity**: Pure PyTorch implementations you can read and understand
|
|
||||||
5. **Practical results**: Get a working ChatBot with minimal resources
|
|
||||||
|
|
||||||
As we say: **"Building a Lego airplane is far more exciting than flying first class!"**
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
Next: [Get Started →](quickstart.md)
|
|
||||||
|
|
||||||
@ -1,279 +0,0 @@
|
|||||||
# Quick Start
|
|
||||||
|
|
||||||
Get MiniMind up and running in minutes!
|
|
||||||
|
|
||||||
## 📋 Requirements
|
|
||||||
|
|
||||||
### Hardware
|
|
||||||
|
|
||||||
- **GPU Memory**: 8GB minimum (24GB recommended for comfortable development)
|
|
||||||
- **Recommended GPU**: NVIDIA RTX 3090 (24GB)
|
|
||||||
|
|
||||||
### Software
|
|
||||||
|
|
||||||
- **Python**: 3.10+
|
|
||||||
- **PyTorch**: 2.0+ (with CUDA 12.2+ for GPU support)
|
|
||||||
- **CUDA**: 12.2+ (optional, for GPU acceleration)
|
|
||||||
|
|
||||||
!!! tip "Hardware Configuration Reference"
|
|
||||||
- **CPU**: Intel i9-10980XE @ 3.00GHz
|
|
||||||
- **RAM**: 128 GB
|
|
||||||
- **GPU**: NVIDIA GeForce RTX 3090 (24GB) × 8
|
|
||||||
- **OS**: Ubuntu 20.04
|
|
||||||
- **CUDA**: 12.2
|
|
||||||
- **Python**: 3.10.16
|
|
||||||
|
|
||||||
## 🚀 Step 0: Clone the Repository
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/jingyaogong/minimind.git
|
|
||||||
cd minimind
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🎯 Section I: Testing Existing Models
|
|
||||||
|
|
||||||
### 1. Environment Setup
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
|
|
||||||
```
|
|
||||||
|
|
||||||
!!! warning "Verify CUDA Support"
|
|
||||||
After installation, verify PyTorch can access CUDA:
|
|
||||||
```python
|
|
||||||
import torch
|
|
||||||
print(torch.cuda.is_available())
|
|
||||||
print(torch.cuda.get_device_name(0))
|
|
||||||
```
|
|
||||||
If `False`, download the correct PyTorch version from [PyTorch Official](https://download.pytorch.org/whl/torch_stable.html)
|
|
||||||
|
|
||||||
### 2. Download Pretrained Models
|
|
||||||
|
|
||||||
Choose one option:
|
|
||||||
|
|
||||||
**From HuggingFace** (recommended for international users):
|
|
||||||
```bash
|
|
||||||
git clone https://huggingface.co/jingyaogong/MiniMind2
|
|
||||||
```
|
|
||||||
|
|
||||||
**From ModelScope** (recommended for China users):
|
|
||||||
```bash
|
|
||||||
git clone https://www.modelscope.cn/models/gongjy/MiniMind2.git
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Command-Line Chat
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# load=0: load PyTorch model, load=1: load transformers model
|
|
||||||
python eval_model.py --load 1 --model_mode 2
|
|
||||||
```
|
|
||||||
|
|
||||||
**Model Modes**:
|
|
||||||
- `model_mode 0`: Pretrain model (word continuation)
|
|
||||||
- `model_mode 1`: SFT Chat model (conversation)
|
|
||||||
- `model_mode 2`: RLHF model (refined responses, currently same as SFT for small models)
|
|
||||||
- `model_mode 3`: Reasoning model (with thinking chains)
|
|
||||||
- `model_mode 4/5`: RLAIF models (PPO/GRPO trained)
|
|
||||||
|
|
||||||
**Example Session**:
|
|
||||||
```text
|
|
||||||
👶: Hello, please introduce yourself.
|
|
||||||
🤖️: I am MiniMind, an AI assistant developed by Jingyao Gong.
|
|
||||||
I use natural language processing and machine learning algorithms to interact with users.
|
|
||||||
|
|
||||||
👶: What's the capital of France?
|
|
||||||
🤖️: The capital of France is Paris, which is located in the northern central part of France.
|
|
||||||
It is the largest city in France and serves as its political, economic, and cultural center.
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. Web UI Demo (Optional)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Requires Python >= 3.10
|
|
||||||
pip install streamlit
|
|
||||||
|
|
||||||
cd scripts
|
|
||||||
streamlit run web_demo.py
|
|
||||||
```
|
|
||||||
|
|
||||||
Visit `http://localhost:8501` to use the interactive web interface.
|
|
||||||
|
|
||||||
### 5. Rope Length Extrapolation with YaRN
|
|
||||||
|
|
||||||
Extend context length beyond training with RoPE extrapolation:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python eval_model.py --inference_rope_scaling True
|
|
||||||
```
|
|
||||||
|
|
||||||
This enables the YaRN algorithm to handle sequences longer than the 2K training context, useful for processing documents and long conversations.
|
|
||||||
|
|
||||||
## 🔧 Third-Party Inference Frameworks
|
|
||||||
|
|
||||||
MiniMind is compatible with popular inference engines:
|
|
||||||
|
|
||||||
### Ollama (Easiest)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
ollama run jingyaogong/minimind2
|
|
||||||
```
|
|
||||||
|
|
||||||
### vLLM (Fastest)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
vllm serve ./MiniMind2/ --served-model-name "minimind" --port 8000
|
|
||||||
|
|
||||||
# Test with curl
|
|
||||||
curl http://localhost:8000/v1/chat/completions \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{
|
|
||||||
"model": "minimind",
|
|
||||||
"messages": [{"role": "user", "content": "Hello!"}],
|
|
||||||
"temperature": 0.7,
|
|
||||||
"max_tokens": 512
|
|
||||||
}'
|
|
||||||
```
|
|
||||||
|
|
||||||
### llama.cpp (CPU-Friendly)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Convert to GGUF format
|
|
||||||
python scripts/convert_model.py ./MiniMind2/ --output ./MiniMind2.gguf
|
|
||||||
|
|
||||||
# Quantize for size reduction
|
|
||||||
./llama-quantize ./MiniMind2.gguf ./MiniMind2-Q4.gguf Q4_K_M
|
|
||||||
|
|
||||||
# Run inference
|
|
||||||
./llama-cli -m ./MiniMind2-Q4.gguf -p "Hello" -n 128
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🔌 OpenAI API Server (For Integration)
|
|
||||||
|
|
||||||
Run MiniMind as an OpenAI API-compatible service:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python scripts/serve_openai_api.py
|
|
||||||
```
|
|
||||||
|
|
||||||
Test the API:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# In another terminal
|
|
||||||
python scripts/chat_openai_api.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**cURL Example**:
|
|
||||||
```bash
|
|
||||||
curl http://localhost:8000/v1/chat/completions \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{
|
|
||||||
"model": "minimind",
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "Explain machine learning in one sentence."}
|
|
||||||
],
|
|
||||||
"temperature": 0.7,
|
|
||||||
"max_tokens": 256,
|
|
||||||
"stream": true
|
|
||||||
}'
|
|
||||||
```
|
|
||||||
|
|
||||||
This enables integration with:
|
|
||||||
- [FastGPT](https://fastgpt.run/)
|
|
||||||
- [Open-WebUI](https://github.com/open-webui/open-webui)
|
|
||||||
- [Dify](https://dify.ai/)
|
|
||||||
- Any OpenAI API-compatible client
|
|
||||||
|
|
||||||
## 📊 Model Selection Guide
|
|
||||||
|
|
||||||
| Use Case | Recommended Model | Memory | Speed |
|
|
||||||
|----------|------------------|--------|-------|
|
|
||||||
| Learning/Testing | MiniMind2-small (26M) | ~0.5 GB | Fastest |
|
|
||||||
| Balanced | MiniMind2 (104M) | ~1.0 GB | Fast |
|
|
||||||
| Expert System (MoE) | MiniMind2-MoE (145M) | ~1.0 GB | Dynamic |
|
|
||||||
| Reasoning/Complex | MiniMind-Reason (104M) | ~1.0 GB | Standard |
|
|
||||||
|
|
||||||
## ⚡ Quick Test Results
|
|
||||||
|
|
||||||
**Model**: MiniMind2 (104M parameters)
|
|
||||||
|
|
||||||
```text
|
|
||||||
Q: What is photosynthesis?
|
|
||||||
A: Photosynthesis is a process in which plants convert light energy from the sun
|
|
||||||
into chemical energy to produce glucose. This process occurs mainly in leaves
|
|
||||||
and is essential for plant growth and survival.
|
|
||||||
|
|
||||||
Q: Write a Python function to calculate Fibonacci numbers.
|
|
||||||
A: def fibonacci(n):
|
|
||||||
if n <= 1:
|
|
||||||
return n
|
|
||||||
return fibonacci(n-1) + fibonacci(n-2)
|
|
||||||
|
|
||||||
# For better performance, use dynamic programming:
|
|
||||||
def fibonacci_dp(n):
|
|
||||||
dp = [0] * (n + 1)
|
|
||||||
for i in range(2, n + 1):
|
|
||||||
dp[i] = dp[i-1] + dp[i-2]
|
|
||||||
return dp[n]
|
|
||||||
|
|
||||||
Q: 世界上最高的山峰是什么? (What is the highest mountain?)
|
|
||||||
A: 珠穆朗玛峰(Mount Everest)是世界上最高的山峰,位于喜马拉雅山脉...
|
|
||||||
(Mount Everest is the world's highest mountain, located in the Himalayas...)
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🆘 Troubleshooting
|
|
||||||
|
|
||||||
### Issue: CUDA Out of Memory
|
|
||||||
|
|
||||||
**Solution**:
|
|
||||||
```bash
|
|
||||||
# Reduce batch size
|
|
||||||
python eval_model.py --batch_size 1
|
|
||||||
|
|
||||||
# Or use CPU (slow but works)
|
|
||||||
python eval_model.py --device cpu
|
|
||||||
```
|
|
||||||
|
|
||||||
### Issue: Slow Inference
|
|
||||||
|
|
||||||
**Solutions**:
|
|
||||||
- Use vLLM or llama.cpp for faster inference
|
|
||||||
- Enable quantization (4-bit, 8-bit)
|
|
||||||
- Use GPU instead of CPU
|
|
||||||
- Reduce `max_tokens` parameter
|
|
||||||
|
|
||||||
### Issue: Model Responses Are Poor Quality
|
|
||||||
|
|
||||||
**Possible Causes**:
|
|
||||||
- Using pretrain model (`model_mode 0`) instead of SFT (`model_mode 1`)
|
|
||||||
- Model is undertrained - download the full checkpoint instead
|
|
||||||
- Input prompt is too short - provide more context
|
|
||||||
|
|
||||||
### Issue: Python/PyTorch Version Mismatch
|
|
||||||
|
|
||||||
**Solution**:
|
|
||||||
```bash
|
|
||||||
# Use conda for clean environment
|
|
||||||
conda create -n minimind python=3.10
|
|
||||||
conda activate minimind
|
|
||||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu122
|
|
||||||
pip install -r requirements.txt
|
|
||||||
```
|
|
||||||
|
|
||||||
## 📖 Next Steps
|
|
||||||
|
|
||||||
- **[Model Training Guide](training.md)** - Train your own MiniMind from scratch
|
|
||||||
- **[Source Code](https://github.com/jingyaogong/minimind)** - Explore and learn LLM implementation
|
|
||||||
- **[Inference Benchmarks](https://huggingface.co/collections/jingyaogong/minimind-66caf8d999f5c7fa64f399e5)** - See model performance comparisons
|
|
||||||
|
|
||||||
## 💡 Pro Tips
|
|
||||||
|
|
||||||
1. **GPU Memory Optimization**: Use `torch.cuda.empty_cache()` periodically
|
|
||||||
2. **Batch Processing**: For efficiency, process multiple prompts in batches
|
|
||||||
3. **Temperature Tuning**: Lower (0.3-0.7) = more consistent, Higher (0.8-1.0) = more creative
|
|
||||||
4. **Prompt Engineering**: Better prompts → better results, even for small models
|
|
||||||
5. **Model Quantization**: Use 4-bit quantization to run on smaller GPUs
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
Done! Now you're ready to use MiniMind. Start with the Quick Start, then move to [Model Training](training.md) to learn how to train your own models.
|
|
||||||
|
|
||||||
679
docs/training.md
@ -1,679 +0,0 @@
|
|||||||
# Model Training Guide
|
|
||||||
|
|
||||||
Learn how to train MiniMind language models from scratch using pure PyTorch.
|
|
||||||
|
|
||||||
## 📊 Training Overview
|
|
||||||
|
|
||||||
MiniMind implements a complete training pipeline:
|
|
||||||
|
|
||||||
```
|
|
||||||
Tokenizer Training
|
|
||||||
↓
|
|
||||||
Pretraining (Learn knowledge)
|
|
||||||
↓
|
|
||||||
SFT (Learn conversation)
|
|
||||||
↓
|
|
||||||
┌───────────────────┬─────────────────────┬──────────────┐
|
|
||||||
↓ ↓ ↓ ↓
|
|
||||||
LoRA DPO/RLHF RLAIF (PPO/GRPO/SPO) Distillation
|
|
||||||
(Domain adapt) (Preference) (Reinforcement Learn) (Reasoning)
|
|
||||||
```
|
|
||||||
|
|
||||||
## 💰 Training Costs (Single NVIDIA 3090)
|
|
||||||
|
|
||||||
| Model | Dataset | Duration | Cost (RMB) | Quality |
|
|
||||||
|-------|---------|----------|-----------|---------|
|
|
||||||
| MiniMind2-Small | pretrain_hq + sft_mini_512 | 2.1h | ≈3 | 😊😊 |
|
|
||||||
| MiniMind2-Small | Full dataset | 38h | ≈50 | 😊😊😊😊😊😊 |
|
|
||||||
| MiniMind2 | pretrain_hq + sft_mini_512 | 3.3h | ≈5 | 😊😊 |
|
|
||||||
| MiniMind2 | Full dataset | 122h | ≈160 | 😊😊😊😊😊😊😊 |
|
|
||||||
|
|
||||||
!!! success "Ultra-Fast Training"
|
|
||||||
**Just 2.1 hours + $3 = Functional ChatBot!**
|
|
||||||
|
|
||||||
Use `pretrain_hq.jsonl` + `sft_mini_512.jsonl` for fastest reproduction
|
|
||||||
|
|
||||||
## 📋 Data Preparation
|
|
||||||
|
|
||||||
### 1. Download Datasets
|
|
||||||
|
|
||||||
Download from [ModelScope](https://www.modelscope.cn/datasets/gongjy/minimind_dataset) or [HuggingFace](https://huggingface.co/datasets/jingyaogong/minimind_dataset):
|
|
||||||
|
|
||||||
```bash
|
|
||||||
mkdir -p dataset
|
|
||||||
cd dataset
|
|
||||||
# Download required files
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. Dataset Directory Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
./dataset/
|
|
||||||
├── pretrain_hq.jsonl ✨ (1.6GB, required for pretraining)
|
|
||||||
├── sft_mini_512.jsonl ✨ (1.2GB, fastest SFT)
|
|
||||||
├── sft_512.jsonl (7.5GB, standard SFT)
|
|
||||||
├── sft_1024.jsonl (5.6GB, longer SFT)
|
|
||||||
├── sft_2048.jsonl (9GB, very long SFT)
|
|
||||||
├── dpo.jsonl (909MB, DPO training)
|
|
||||||
├── r1_mix_1024.jsonl (340MB, reasoning distillation)
|
|
||||||
├── rlaif-mini.jsonl (1MB, RLAIF algorithms)
|
|
||||||
├── lora_identity.jsonl (22.8KB, identity LoRA)
|
|
||||||
└── lora_medical.jsonl (34MB, medical domain LoRA)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Data Formats
|
|
||||||
|
|
||||||
**Pretraining Data** (`pretrain_hq.jsonl`):
|
|
||||||
```json
|
|
||||||
{"text": "How to overcome procrastination? Overcoming procrastination is not easy, but these suggestions may help..."}
|
|
||||||
```
|
|
||||||
|
|
||||||
**SFT Data** (`sft_*.jsonl`):
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"conversations": [
|
|
||||||
{"role": "user", "content": "Hello!"},
|
|
||||||
{"role": "assistant", "content": "Hello! How can I help?"},
|
|
||||||
{"role": "user", "content": "Tell me a joke."},
|
|
||||||
{"role": "assistant", "content": "Why did the scarecrow win an award? Because he was outstanding in his field!"}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**DPO Data** (`dpo.jsonl`):
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"chosen": [
|
|
||||||
{"role": "user", "content": "What is 2+2?"},
|
|
||||||
{"role": "assistant", "content": "2+2 equals 4."}
|
|
||||||
],
|
|
||||||
"rejected": [
|
|
||||||
{"role": "user", "content": "What is 2+2?"},
|
|
||||||
{"role": "assistant", "content": "2+2 equals 5."}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**LoRA Domain Data** (`lora_*.jsonl`):
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"conversations": [
|
|
||||||
{"role": "user", "content": "What's the treatment for cervical spondylosis?"},
|
|
||||||
{"role": "assistant", "content": "Cervical spondylosis treatment typically includes..."}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🎯 Complete Training Pipeline
|
|
||||||
|
|
||||||
All training scripts are in the `./trainer` directory.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd trainer
|
|
||||||
```
|
|
||||||
|
|
||||||
### Stage 1: Pretraining
|
|
||||||
|
|
||||||
**Purpose**: Learn foundational knowledge (word continuation)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Single GPU
|
|
||||||
python train_pretrain.py
|
|
||||||
|
|
||||||
# Multi-GPU (DDP)
|
|
||||||
torchrun --nproc_per_node 2 train_pretrain.py
|
|
||||||
|
|
||||||
# Multi-GPU (DeepSpeed)
|
|
||||||
deepspeed --master_port 29500 --num_gpus=2 train_pretrain.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Key Parameters**:
|
|
||||||
- `max_seq_len`: 512 (adjust based on GPU memory)
|
|
||||||
- `learning_rate`: 1e-4
|
|
||||||
- `epochs`: Adjust based on dataset size
|
|
||||||
|
|
||||||
**Output**: `./out/pretrain_*.pth`
|
|
||||||
|
|
||||||
**Training Duration**:
|
|
||||||
- MiniMind2-Small (26M): ~1.1h
|
|
||||||
- MiniMind2 (104M): ~3.9h
|
|
||||||
|
|
||||||
!!! tip "Pretraining Tips"
|
|
||||||
- Start with `pretrain_hq.jsonl` for best results
|
|
||||||
- Quality > Quantity for pretraining data
|
|
||||||
- Monitor loss curve to detect overfitting
|
|
||||||
|
|
||||||
### Stage 2: Supervised Fine-Tuning (SFT)
|
|
||||||
|
|
||||||
**Purpose**: Teach conversation patterns and chat templates
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Single GPU
|
|
||||||
python train_full_sft.py
|
|
||||||
|
|
||||||
# Multi-GPU
|
|
||||||
torchrun --nproc_per_node 2 train_full_sft.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Configuration**:
|
|
||||||
- Load pretrained model from Stage 1
|
|
||||||
- Use SFT dataset (`sft_mini_512.jsonl` or `sft_512.jsonl`)
|
|
||||||
- Adjust `max_seq_len` to match training data
|
|
||||||
|
|
||||||
**Output**: `./out/full_sft_*.pth`
|
|
||||||
|
|
||||||
**Training Duration**:
|
|
||||||
- With sft_mini_512: 1-3 hours
|
|
||||||
- With full sft_512: 20-25 hours
|
|
||||||
|
|
||||||
!!! warning "SFT Data Selection"
|
|
||||||
- `sft_mini_512.jsonl`: Fastest, ~1.2GB, 512 tokens max
|
|
||||||
- `sft_512.jsonl`: Standard, ~7.5GB, 512 tokens max
|
|
||||||
- `sft_1024.jsonl`: Longer, ~5.6GB, 1024 tokens max
|
|
||||||
- `sft_2048.jsonl`: Extended, ~9GB, 2048 tokens max
|
|
||||||
|
|
||||||
### Stage 3: LoRA Fine-Tuning (Optional)
|
|
||||||
|
|
||||||
**Purpose**: Parameter-efficient domain adaptation
|
|
||||||
|
|
||||||
**Use Cases**:
|
|
||||||
- Medical Q&A knowledge
|
|
||||||
- Personal identity/self-awareness
|
|
||||||
- Proprietary domain knowledge
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Edit train_lora.py to set correct dataset and base model
|
|
||||||
python train_lora.py
|
|
||||||
|
|
||||||
# Multi-GPU
|
|
||||||
torchrun --nproc_per_node 2 train_lora.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Output**: `./out/lora/lora_*.pth`
|
|
||||||
|
|
||||||
**Example 1: Medical Domain**
|
|
||||||
|
|
||||||
Prepare `dataset/lora_medical.jsonl`:
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"conversations": [
|
|
||||||
{"role": "user", "content": "What's the correct pillow height for cervical spondylosis?"},
|
|
||||||
{"role": "assistant", "content": "For cervical spondylosis, pillow height should be..."}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Train:
|
|
||||||
```bash
|
|
||||||
# Modify train_lora.py: lora_name = 'medical'
|
|
||||||
python train_lora.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Example 2: Identity/Self-Awareness**
|
|
||||||
|
|
||||||
Prepare `dataset/lora_identity.jsonl`:
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"conversations": [
|
|
||||||
{"role": "user", "content": "Who are you?"},
|
|
||||||
{"role": "assistant", "content": "I am MiniMind..."}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### Stage 4: Direct Preference Optimization (DPO)
|
|
||||||
|
|
||||||
**Purpose**: Align model responses with human preferences
|
|
||||||
|
|
||||||
DPO eliminates the need for separate reward models by directly optimizing preference pairs.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python train_dpo.py
|
|
||||||
|
|
||||||
# Multi-GPU
|
|
||||||
torchrun --nproc_per_node 2 train_dpo.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Output**: `./out/rlhf_*.pth`
|
|
||||||
|
|
||||||
**Key Features**:
|
|
||||||
- Off-policy training (reuse data across epochs)
|
|
||||||
- No separate reward model needed
|
|
||||||
- Better sample efficiency than PPO
|
|
||||||
- Stable training convergence
|
|
||||||
|
|
||||||
**Training Duration**: ~1-3 hours
|
|
||||||
|
|
||||||
### Stage 5: Reinforcement Learning from AI Feedback (RLAIF)
|
|
||||||
|
|
||||||
RLAIF is an advanced training approach using AI-generated rewards instead of human annotations. MiniMind implements three modern algorithms:
|
|
||||||
|
|
||||||
#### 5.1 PPO (Proximal Policy Optimization)
|
|
||||||
|
|
||||||
Classical on-policy RL algorithm with proven stability.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python train_ppo.py
|
|
||||||
|
|
||||||
# Multi-GPU
|
|
||||||
torchrun --nproc_per_node 2 train_ppo.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Algorithm**:
|
|
||||||
$$\mathcal{L}_{PPO} = -\mathbb{E}\left[\min(r_t \cdot A_t, \text{clip}(r_t, 1-\varepsilon, 1+\varepsilon) \cdot A_t)\right] + \beta \cdot \mathbb{E}[\text{KL}]$$
|
|
||||||
|
|
||||||
**Characteristics**:
|
|
||||||
- Stable but slower reward improvement
|
|
||||||
- Requires both Actor and Critic networks
|
|
||||||
- High memory usage (1.5-2× single network)
|
|
||||||
- Good for exploration
|
|
||||||
|
|
||||||
**Output**: `./out/ppo_actor_*.pth`
|
|
||||||
|
|
||||||
**Training Duration**: ~1-3 hours
|
|
||||||
|
|
||||||
#### 5.2 GRPO (Group Relative Policy Optimization)
|
|
||||||
|
|
||||||
Modern algorithm used in DeepSeek-R1, with faster convergence.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python train_grpo.py
|
|
||||||
|
|
||||||
# Multi-GPU
|
|
||||||
torchrun --nproc_per_node 2 train_grpo.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Algorithm**:
|
|
||||||
$$\mathcal{L}_{GRPO} = -\mathbb{E}\left[r_t \cdot A_t - \beta \cdot \text{KL}_t\right]$$
|
|
||||||
|
|
||||||
Where advantage is computed as:
|
|
||||||
$$A_t = \frac{R - \mu_{group}}{\sigma_{group}}$$
|
|
||||||
|
|
||||||
**Characteristics**:
|
|
||||||
- Single-network design (memory efficient)
|
|
||||||
- Faster reward improvement
|
|
||||||
- Group normalization removes bias
|
|
||||||
- Better convergence stability
|
|
||||||
|
|
||||||
**Output**: `./out/grpo_*.pth`
|
|
||||||
|
|
||||||
**Training Duration**: ~1-3 hours
|
|
||||||
|
|
||||||
#### 5.3 SPO (Single-stream Policy Optimization)
|
|
||||||
|
|
||||||
Newest algorithm (2025) addressing GRPO's degenerate group problem.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python train_spo.py
|
|
||||||
|
|
||||||
# Multi-GPU
|
|
||||||
torchrun --nproc_per_node 2 train_spo.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Algorithm**:
|
|
||||||
$$\mathcal{L}_{SPO} = -\mathbb{E}\left[\log \pi_\theta(a_t|s) \cdot A_t - \beta \cdot \text{KL}_t\right]$$
|
|
||||||
|
|
||||||
With adaptive baseline: $B_t^{adaptive}$
|
|
||||||
|
|
||||||
**Characteristics**:
|
|
||||||
- No group dependency (1 input → 1 training sample)
|
|
||||||
- Adaptive value tracking
|
|
||||||
- Better handling of difficult examples
|
|
||||||
- Experimental on small models
|
|
||||||
|
|
||||||
**Output**: `./out/spo_*.pth`
|
|
||||||
|
|
||||||
**Training Duration**: ~1-3 hours
|
|
||||||
|
|
||||||
#### RLAIF Dataset Preparation
|
|
||||||
|
|
||||||
All RLAIF algorithms use `rlaif-mini.jsonl` (1MB, 10k examples):
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Download dataset
|
|
||||||
# Format: Same as SFT, but assistant content is "无" (none)
|
|
||||||
{
|
|
||||||
"conversations": [
|
|
||||||
{"role": "user", "content": "Explain photosynthesis briefly."},
|
|
||||||
{"role": "assistant", "content": "无"}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
The model generates completions during training, which are scored by a **Reward Model** (e.g., InternLM2-1.8B-Reward).
|
|
||||||
|
|
||||||
**Reward Model Setup**:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Download reward model to parent directory
|
|
||||||
cd ../
|
|
||||||
git clone https://huggingface.co/internlm/internlm2-1_8b-reward
|
|
||||||
|
|
||||||
# Directory structure should be:
|
|
||||||
# project/
|
|
||||||
# ├── minimind/
|
|
||||||
# └── internlm2-1_8b-reward/
|
|
||||||
```
|
|
||||||
|
|
||||||
#### RLAIF vs DPO Comparison
|
|
||||||
|
|
||||||
| Aspect | DPO | RLAIF (PPO/GRPO/SPO) |
|
|
||||||
|--------|-----|---------------------|
|
|
||||||
| Training Type | Off-policy | On-policy |
|
|
||||||
| Data Freshness | Static pairs | Dynamic (generated) |
|
|
||||||
| Reward Source | Implicit | Explicit model |
|
|
||||||
| Convergence | Fast | Slower |
|
|
||||||
| Memory Usage | Lower | Higher |
|
|
||||||
| Best For | Preference refinement | Capability improvement |
|
|
||||||
|
|
||||||
### Stage 6: Reasoning Model Distillation
|
|
||||||
|
|
||||||
**Purpose**: Distill DeepSeek-R1-style reasoning into MiniMind
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python train_distill_reason.py
|
|
||||||
|
|
||||||
# Multi-GPU
|
|
||||||
torchrun --nproc_per_node 2 train_distill_reason.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Data Format** (`r1_mix_1024.jsonl`):
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"conversations": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "Solve: 5 + 3 = ?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "<think>\nI need to add 5 and 3.\n5 + 3 = 8\n</think>\n<answer>\n5 + 3 = 8\n</answer>"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**Output**: `./out/reason_*.pth`
|
|
||||||
|
|
||||||
**Training Features**:
|
|
||||||
- Enforces `<think>` and `<answer>` tags
|
|
||||||
- Penalty loss for format violations
|
|
||||||
- Mixed data (reasoning + multi-turn + English)
|
|
||||||
|
|
||||||
## 🔧 Multi-GPU Training
|
|
||||||
|
|
||||||
### DDP (Distributed Data Parallel)
|
|
||||||
|
|
||||||
Best for single-machine multi-GPU:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
torchrun --nproc_per_node N train_xxx.py
|
|
||||||
# N = number of GPUs
|
|
||||||
```
|
|
||||||
|
|
||||||
### DeepSpeed
|
|
||||||
|
|
||||||
For advanced optimization:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
deepspeed --master_port 29500 --num_gpus=N train_xxx.py
|
|
||||||
```
|
|
||||||
|
|
||||||
### Wandb Monitoring
|
|
||||||
|
|
||||||
Track training progress:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Login first
|
|
||||||
wandb login
|
|
||||||
|
|
||||||
# Enable wandb logging
|
|
||||||
torchrun --nproc_per_node N train_xxx.py --use_wandb
|
|
||||||
|
|
||||||
# Or SwanLab (China-friendly alternative)
|
|
||||||
python train_xxx.py --use_wandb # Automatically uses SwanLab if available
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🧪 Model Testing
|
|
||||||
|
|
||||||
### Evaluate Pretrain Model
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python eval_model.py --model_mode 0
|
|
||||||
```
|
|
||||||
|
|
||||||
### Evaluate Chat Model
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python eval_model.py --model_mode 1
|
|
||||||
```
|
|
||||||
|
|
||||||
### Evaluate with LoRA
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python eval_model.py --lora_name 'lora_medical' --model_mode 1
|
|
||||||
```
|
|
||||||
|
|
||||||
### Evaluate Reasoning Model
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python eval_model.py --model_mode 3
|
|
||||||
```
|
|
||||||
|
|
||||||
### Evaluate RLAIF Models
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# PPO model
|
|
||||||
python eval_model.py --model_mode 4
|
|
||||||
|
|
||||||
# GRPO model
|
|
||||||
python eval_model.py --model_mode 4
|
|
||||||
```
|
|
||||||
|
|
||||||
### RoPE Length Extrapolation
|
|
||||||
|
|
||||||
Test with extended context:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python eval_model.py --model_mode 1 --inference_rope_scaling True
|
|
||||||
```
|
|
||||||
|
|
||||||
## 📐 Model Architecture
|
|
||||||
|
|
||||||
### MiniMind Structure
|
|
||||||
|
|
||||||
**Decoder-Only Transformer** (similar to Llama3):
|
|
||||||
|
|
||||||
```
|
|
||||||
Input Tokens
|
|
||||||
↓
|
|
||||||
Token Embedding (6400 vocab)
|
|
||||||
↓
|
|
||||||
Rotary Embeddings (RoPE) [with YaRN for length extrapolation]
|
|
||||||
↓
|
|
||||||
[Transformer Blocks] ×N
|
|
||||||
├─ Attention (Multi-Head)
|
|
||||||
├─ RMSNorm
|
|
||||||
├─ SwiGLU FFN [or MoE for MoE variant]
|
|
||||||
└─ Residual Connections
|
|
||||||
↓
|
|
||||||
RMSNorm
|
|
||||||
↓
|
|
||||||
LM Head (→ 6400 vocab logits)
|
|
||||||
↓
|
|
||||||
Output Probabilities
|
|
||||||
```
|
|
||||||
|
|
||||||
### Model Configurations
|
|
||||||
|
|
||||||
| Config | MiniMind2-Small | MiniMind2 | MiniMind2-MoE |
|
|
||||||
|--------|-----------------|----------|---------------|
|
|
||||||
| Parameters | 26M | 104M | 145M |
|
|
||||||
| Hidden Dim | 512 | 768 | 640 |
|
|
||||||
| Layers | 8 | 16 | 8 |
|
|
||||||
| KV Heads | 2 | 2 | 2 |
|
|
||||||
| Q Heads | 8 | 8 | 8 |
|
|
||||||
| Vocab Size | 6,400 | 6,400 | 6,400 |
|
|
||||||
| Context Length | 2,048 | 2,048 | 2,048 |
|
|
||||||
|
|
||||||
### Modify Architecture
|
|
||||||
|
|
||||||
Edit `./model/LMConfig.py`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class LMConfig:
|
|
||||||
hidden_size: int = 768
|
|
||||||
num_layers: int = 16
|
|
||||||
num_heads: int = 8
|
|
||||||
num_kv_heads: int = 2
|
|
||||||
# ... other configs
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🔍 Training Tips & Best Practices
|
|
||||||
|
|
||||||
### Data Quality > Quantity
|
|
||||||
|
|
||||||
- High-quality pretraining data accelerates convergence
|
|
||||||
- `pretrain_hq.jsonl` is carefully curated for quality
|
|
||||||
- Consider data deduplication and cleaning
|
|
||||||
|
|
||||||
### Learning Rate Scheduling
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Recommended schedules
|
|
||||||
- Linear warmup then decay
|
|
||||||
- Initial: 1e-4 to 5e-4
|
|
||||||
- Warmup steps: 10% of total
|
|
||||||
- Final: 10% of initial LR
|
|
||||||
```
|
|
||||||
|
|
||||||
### Batch Size & Sequence Length
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Balance between GPU memory and convergence
|
|
||||||
- Pretraining: max_seq_len=512, batch_size=32
|
|
||||||
- SFT: max_seq_len=512, batch_size=16
|
|
||||||
- LoRA: max_seq_len=512, batch_size=16
|
|
||||||
```
|
|
||||||
|
|
||||||
### Memory Optimization
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Reduce batch size if OOM
|
|
||||||
python train_xxx.py --batch_size 8
|
|
||||||
|
|
||||||
# Or use gradient accumulation
|
|
||||||
python train_xxx.py --gradient_accumulation_steps 4
|
|
||||||
```
|
|
||||||
|
|
||||||
### Checkpoint Management
|
|
||||||
|
|
||||||
- Saves every 100 steps by default
|
|
||||||
- Each new save overwrites the old one
|
|
||||||
- Automatic backup before training
|
|
||||||
|
|
||||||
## 🚨 Common Issues & Solutions
|
|
||||||
|
|
||||||
### Issue: CUDA Out of Memory
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Solution 1: Reduce batch size
|
|
||||||
python train_xxx.py --batch_size 4
|
|
||||||
|
|
||||||
# Solution 2: Use gradient accumulation
|
|
||||||
python train_xxx.py --batch_size 16 --gradient_accumulation_steps 2
|
|
||||||
|
|
||||||
# Solution 3: Use smaller model
|
|
||||||
# Edit trainer script to use MiniMind2-Small instead
|
|
||||||
```
|
|
||||||
|
|
||||||
### Issue: Training Not Converging
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Possible causes:
|
|
||||||
1. Learning rate too high/low
|
|
||||||
2. Data quality issues
|
|
||||||
3. Model capacity mismatch
|
|
||||||
|
|
||||||
# Solutions:
|
|
||||||
- Reduce learning rate: --learning_rate 1e-5
|
|
||||||
- Check data format and quality
|
|
||||||
- Try smaller model first
|
|
||||||
```
|
|
||||||
|
|
||||||
### Issue: Multi-GPU Sync Errors
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Ensure:
|
|
||||||
1. All GPUs visible: nvidia-smi
|
|
||||||
2. Same CUDA version across all GPUs
|
|
||||||
3. Network connectivity for distributed training
|
|
||||||
|
|
||||||
# Debug:
|
|
||||||
torchrun --nproc_per_node 2 train_xxx.py --debug
|
|
||||||
```
|
|
||||||
|
|
||||||
### Issue: Different Results Than Expected
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Check:
|
|
||||||
1. Random seed set (reproducibility)
|
|
||||||
2. Correct model checkpoint loaded
|
|
||||||
3. Correct dataset being used
|
|
||||||
4. Same hyperparameters as reference
|
|
||||||
```
|
|
||||||
|
|
||||||
## 📈 Training Progression
|
|
||||||
|
|
||||||
Typical training curves:
|
|
||||||
|
|
||||||
```
|
|
||||||
Pretraining Loss: ↘↘↘ (steep decline, then plateau)
|
|
||||||
SFT Loss: ↘ (steady decline)
|
|
||||||
PPO Reward: ↗ (rising, may plateau)
|
|
||||||
GRPO Reward: ↗↗ (faster rise, more stable)
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🎓 Advanced Topics
|
|
||||||
|
|
||||||
### Custom Datasets
|
|
||||||
|
|
||||||
Create your own dataset:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Format: JSONL with conversations list
|
|
||||||
# Each line is one training example
|
|
||||||
# Ensure consistent quality and format
|
|
||||||
```
|
|
||||||
|
|
||||||
### Model Quantization (Post-training)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 4-bit quantization for inference
|
|
||||||
# Use tools like:
|
|
||||||
# - llama.cpp (gguf format)
|
|
||||||
# - bitsandbytes (dynamic quantization)
|
|
||||||
# - AutoGPTQ (static quantization)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Model Merging
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Merge base model + LoRA weights
|
|
||||||
# Use tools like: peft, llama.cpp
|
|
||||||
```
|
|
||||||
|
|
||||||
## 📚 References
|
|
||||||
|
|
||||||
- [Scaling Laws](https://arxiv.org/pdf/2001.08361.pdf)
|
|
||||||
- [RoPE Position Embeddings](https://arxiv.org/abs/2104.09864)
|
|
||||||
- [YaRN Length Extrapolation](https://arxiv.org/abs/2309.00071)
|
|
||||||
- [PPO Algorithm](https://arxiv.org/abs/1707.06347)
|
|
||||||
- [GRPO (DeepSeek)](https://arxiv.org/pdf/2402.03300)
|
|
||||||
- [SPO Algorithm](https://arxiv.org/abs/2509.13232)
|
|
||||||
- [DPO](https://arxiv.org/abs/2305.18290)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
**Next**: Deploy your trained model or explore [advanced inference options](quickstart.md#third-party-inference-frameworks)
|
|
||||||
|
|
||||||
94
eval_llm.py
Executable file
@ -0,0 +1,94 @@
|
|||||||
|
import time
|
||||||
|
import argparse
|
||||||
|
import random
|
||||||
|
import warnings
|
||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
|
||||||
|
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||||
|
from model.model_lora import *
|
||||||
|
from trainer.trainer_utils import setup_seed, get_model_params
|
||||||
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
|
def init_model(args):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(args.load_from)
|
||||||
|
if 'model' in args.load_from:
|
||||||
|
model = MiniMindForCausalLM(MiniMindConfig(
|
||||||
|
hidden_size=args.hidden_size,
|
||||||
|
num_hidden_layers=args.num_hidden_layers,
|
||||||
|
use_moe=bool(args.use_moe),
|
||||||
|
inference_rope_scaling=args.inference_rope_scaling
|
||||||
|
))
|
||||||
|
moe_suffix = '_moe' if args.use_moe else ''
|
||||||
|
ckp = f'./{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
|
||||||
|
model.load_state_dict(torch.load(ckp, map_location=args.device), strict=True)
|
||||||
|
if args.lora_weight != 'None':
|
||||||
|
apply_lora(model)
|
||||||
|
load_lora(model, f'./{args.save_dir}/{args.lora_weight}_{args.hidden_size}.pth')
|
||||||
|
else:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
|
||||||
|
get_model_params(model, model.config)
|
||||||
|
return model.half().eval().to(args.device), tokenizer
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="MiniMind模型推理与对话")
|
||||||
|
parser.add_argument('--load_from', default='model', type=str, help="模型加载路径(model=原生torch权重,其他路径=transformers格式)")
|
||||||
|
parser.add_argument('--save_dir', default='out', type=str, help="模型权重目录")
|
||||||
|
parser.add_argument('--weight', default='full_sft', type=str, help="权重名称前缀(pretrain, full_sft, rlhf, reason, ppo_actor, grpo, spo)")
|
||||||
|
parser.add_argument('--lora_weight', default='None', type=str, help="LoRA权重名称(None表示不使用,可选:lora_identity, lora_medical)")
|
||||||
|
parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度")
|
||||||
|
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||||
|
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||||
|
parser.add_argument('--inference_rope_scaling', default=False, action='store_true', help="启用RoPE位置编码外推(4倍,仅解决位置编码问题)")
|
||||||
|
parser.add_argument('--max_new_tokens', default=8192, type=int, help="最大生成长度(注意:并非模型实际长文本能力)")
|
||||||
|
parser.add_argument('--temperature', default=0.85, type=float, help="生成温度,控制随机性(0-1,越大越随机)")
|
||||||
|
parser.add_argument('--top_p', default=0.95, type=float, help="nucleus采样阈值(0-1)")
|
||||||
|
parser.add_argument('--open_thinking', default=0, type=int, help="是否开启自适应思考(0=否,1=是)")
|
||||||
|
parser.add_argument('--historys', default=0, type=int, help="携带历史对话轮数(需为偶数,0表示不携带历史)")
|
||||||
|
parser.add_argument('--show_speed', default=1, type=int, help="显示decode速度(tokens/s)")
|
||||||
|
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="运行设备")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
prompts = [
|
||||||
|
'你有什么特长?',
|
||||||
|
'为什么天空是蓝色的',
|
||||||
|
'请用Python写一个计算斐波那契数列的函数',
|
||||||
|
'解释一下"光合作用"的基本过程',
|
||||||
|
'如果明天下雨,我应该如何出门',
|
||||||
|
'比较一下猫和狗作为宠物的优缺点',
|
||||||
|
'解释什么是机器学习',
|
||||||
|
'推荐一些中国的美食'
|
||||||
|
]
|
||||||
|
|
||||||
|
conversation = []
|
||||||
|
model, tokenizer = init_model(args)
|
||||||
|
input_mode = int(input('[0] 自动测试\n[1] 手动输入\n'))
|
||||||
|
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||||
|
|
||||||
|
prompt_iter = prompts if input_mode == 0 else iter(lambda: input('💬: '), '')
|
||||||
|
for prompt in prompt_iter:
|
||||||
|
setup_seed(random.randint(0, 31415926))
|
||||||
|
if input_mode == 0: print(f'💬: {prompt}')
|
||||||
|
conversation = conversation[-args.historys:] if args.historys else []
|
||||||
|
conversation.append({"role": "user", "content": prompt})
|
||||||
|
if 'pretrain' in args.weight:
|
||||||
|
inputs = tokenizer.bos_token + prompt
|
||||||
|
else:
|
||||||
|
inputs = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True, open_thinking=bool(args.open_thinking))
|
||||||
|
|
||||||
|
inputs = tokenizer(inputs, return_tensors="pt", truncation=True).to(args.device)
|
||||||
|
|
||||||
|
print('🧠: ', end='')
|
||||||
|
st = time.time()
|
||||||
|
generated_ids = model.generate(
|
||||||
|
inputs=inputs["input_ids"], attention_mask=inputs["attention_mask"],
|
||||||
|
max_new_tokens=args.max_new_tokens, do_sample=True, streamer=streamer,
|
||||||
|
pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id,
|
||||||
|
top_p=args.top_p, temperature=args.temperature, repetition_penalty=1
|
||||||
|
)
|
||||||
|
response = tokenizer.decode(generated_ids[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
|
||||||
|
conversation.append({"role": "assistant", "content": response})
|
||||||
|
gen_tokens = len(generated_ids[0]) - len(inputs["input_ids"][0])
|
||||||
|
print(f'\n[Speed]: {gen_tokens / (time.time() - st):.2f} tokens/s\n\n') if args.show_speed else print('\n\n')
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
BIN
images/LLM-structure-moe.jpg
Normal file
|
After Width: | Height: | Size: 313 KiB |
BIN
images/LLM-structure.jpg
Normal file
|
After Width: | Height: | Size: 262 KiB |
BIN
images/agent_rl_loss.jpg
Normal file
|
After Width: | Height: | Size: 702 KiB |
BIN
images/agent_webui.jpg
Normal file
|
After Width: | Height: | Size: 124 KiB |
BIN
images/benchmark_radar.jpg
Normal file
|
After Width: | Height: | Size: 96 KiB |
BIN
images/dataset.jpg
Normal file
|
After Width: | Height: | Size: 123 KiB |
|
Before Width: | Height: | Size: 66 KiB After Width: | Height: | Size: 66 KiB |
BIN
images/grpo_loss.jpg
Normal file
|
After Width: | Height: | Size: 590 KiB |
|
Before Width: | Height: | Size: 495 KiB After Width: | Height: | Size: 495 KiB |
|
Before Width: | Height: | Size: 615 KiB After Width: | Height: | Size: 615 KiB |
BIN
images/minimind-3.gif
Normal file
|
After Width: | Height: | Size: 5.7 MiB |
BIN
images/ppo_loss.jpg
Normal file
|
After Width: | Height: | Size: 601 KiB |
BIN
images/pretrain_loss.jpg
Normal file
|
After Width: | Height: | Size: 292 KiB |
BIN
images/rl-structure.jpg
Normal file
|
After Width: | Height: | Size: 231 KiB |
BIN
images/rope_ppl.png
Normal file
|
After Width: | Height: | Size: 79 KiB |
BIN
images/sft_loss.jpg
Normal file
|
After Width: | Height: | Size: 466 KiB |
|
Before Width: | Height: | Size: 178 KiB After Width: | Height: | Size: 178 KiB |
|
Before Width: | Height: | Size: 150 KiB After Width: | Height: | Size: 150 KiB |
66
mkdocs.yml
@ -1,66 +0,0 @@
|
|||||||
site_name: MiniMind
|
|
||||||
site_description: MiniMind - 轻量级语言模型训练框架 / Lightweight Language Model Training Framework
|
|
||||||
site_author: jingyaogong
|
|
||||||
site_url: https://minimind.readthedocs.io/
|
|
||||||
|
|
||||||
# 搜索插件配置
|
|
||||||
plugins:
|
|
||||||
- search:
|
|
||||||
lang: en
|
|
||||||
|
|
||||||
# 主题配置
|
|
||||||
theme:
|
|
||||||
name: material
|
|
||||||
favicon: images/logo.png
|
|
||||||
icon:
|
|
||||||
logo: material/book-open-page-variant
|
|
||||||
palette:
|
|
||||||
# 浅色模式
|
|
||||||
- scheme: default
|
|
||||||
primary: white
|
|
||||||
accent: blue
|
|
||||||
toggle:
|
|
||||||
icon: material/brightness-7
|
|
||||||
name: 切换至深色模式
|
|
||||||
# 深色模式
|
|
||||||
- scheme: slate
|
|
||||||
primary: black
|
|
||||||
accent: blue
|
|
||||||
toggle:
|
|
||||||
icon: material/brightness-4
|
|
||||||
name: 切换至浅色模式
|
|
||||||
features:
|
|
||||||
# - navigation.instant # 与多语言切换不兼容,已禁用
|
|
||||||
- navigation.tracking # 锚点跟踪
|
|
||||||
- navigation.sections # 导航分组
|
|
||||||
- navigation.expand # 默认展开导航
|
|
||||||
- navigation.top # 返回顶部按钮
|
|
||||||
- search.suggest # 搜索建议
|
|
||||||
- search.highlight # 搜索高亮
|
|
||||||
- content.code.copy # 代码复制按钮
|
|
||||||
- toc.follow # 目录跟随
|
|
||||||
- toc.integrate # 目录集成到左侧边栏
|
|
||||||
language: en
|
|
||||||
|
|
||||||
# 导航结构
|
|
||||||
nav:
|
|
||||||
- Home: index.md
|
|
||||||
- Quick Start: quickstart.md
|
|
||||||
- Model Training: training.md
|
|
||||||
|
|
||||||
# Markdown 扩展
|
|
||||||
markdown_extensions:
|
|
||||||
- toc:
|
|
||||||
permalink: true
|
|
||||||
- admonition
|
|
||||||
- pymdownx.highlight:
|
|
||||||
anchor_linenums: true
|
|
||||||
- pymdownx.inlinehilite
|
|
||||||
- pymdownx.snippets
|
|
||||||
- pymdownx.superfences
|
|
||||||
- pymdownx.details
|
|
||||||
- pymdownx.tabbed:
|
|
||||||
alternate_style: true
|
|
||||||
- attr_list
|
|
||||||
- md_in_html
|
|
||||||
|
|
||||||
0
model/__init__.py
Normal file
65
model/model_lora.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
import torch
|
||||||
|
from torch import optim, nn
|
||||||
|
|
||||||
|
|
||||||
|
# 定义Lora网络结构
|
||||||
|
class LoRA(nn.Module):
|
||||||
|
def __init__(self, in_features, out_features, rank):
|
||||||
|
super().__init__()
|
||||||
|
self.rank = rank # LoRA的秩(rank),控制低秩矩阵的大小
|
||||||
|
self.A = nn.Linear(in_features, rank, bias=False) # 低秩矩阵A
|
||||||
|
self.B = nn.Linear(rank, out_features, bias=False) # 低秩矩阵B
|
||||||
|
# 矩阵A高斯初始化
|
||||||
|
self.A.weight.data.normal_(mean=0.0, std=0.02)
|
||||||
|
# 矩阵B全0初始化
|
||||||
|
self.B.weight.data.zero_()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.B(self.A(x))
|
||||||
|
|
||||||
|
|
||||||
|
def apply_lora(model, rank=16):
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if isinstance(module, nn.Linear) and module.weight.shape[0] == module.weight.shape[1]:
|
||||||
|
lora = LoRA(module.weight.shape[0], module.weight.shape[1], rank=rank).to(model.device)
|
||||||
|
setattr(module, "lora", lora)
|
||||||
|
original_forward = module.forward
|
||||||
|
|
||||||
|
# 显式绑定
|
||||||
|
def forward_with_lora(x, layer1=original_forward, layer2=lora):
|
||||||
|
return layer1(x) + layer2(x)
|
||||||
|
|
||||||
|
module.forward = forward_with_lora
|
||||||
|
|
||||||
|
|
||||||
|
def load_lora(model, path):
|
||||||
|
state_dict = torch.load(path, map_location=model.device)
|
||||||
|
state_dict = {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()}
|
||||||
|
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if hasattr(module, 'lora'):
|
||||||
|
lora_state = {k.replace(f'{name}.lora.', ''): v for k, v in state_dict.items() if f'{name}.lora.' in k}
|
||||||
|
module.lora.load_state_dict(lora_state)
|
||||||
|
|
||||||
|
|
||||||
|
def save_lora(model, path):
|
||||||
|
raw_model = getattr(model, '_orig_mod', model)
|
||||||
|
state_dict = {}
|
||||||
|
for name, module in raw_model.named_modules():
|
||||||
|
if hasattr(module, 'lora'):
|
||||||
|
clean_name = name[7:] if name.startswith("module.") else name
|
||||||
|
lora_state = {f'{clean_name}.lora.{k}': v.cpu().half() for k, v in module.lora.state_dict().items()}
|
||||||
|
state_dict.update(lora_state)
|
||||||
|
torch.save(state_dict, path)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_lora(model, lora_path, save_path):
|
||||||
|
load_lora(model, lora_path)
|
||||||
|
raw_model = getattr(model, '_orig_mod', model)
|
||||||
|
state_dict = {k: v.cpu().half() for k, v in raw_model.state_dict().items() if '.lora.' not in k}
|
||||||
|
for name, module in raw_model.named_modules():
|
||||||
|
if isinstance(module, nn.Linear) and '.lora.' not in name:
|
||||||
|
state_dict[f'{name}.weight'] = module.weight.data.clone().cpu().half()
|
||||||
|
if hasattr(module, 'lora'):
|
||||||
|
state_dict[f'{name}.weight'] += (module.lora.B.weight.data @ module.lora.A.weight.data).cpu().half()
|
||||||
|
torch.save(state_dict, save_path)
|
||||||
287
model/model_minimind.py
Executable file
@ -0,0 +1,287 @@
|
|||||||
|
import math, torch, torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
|
||||||
|
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
|
||||||
|
|
||||||
|
# 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏
|
||||||
|
# MiniMind Config
|
||||||
|
# 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏
|
||||||
|
class MiniMindConfig(PretrainedConfig):
|
||||||
|
model_type = "minimind"
|
||||||
|
def __init__(self, hidden_size=768, num_hidden_layers=8, use_moe=False, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.use_moe = use_moe
|
||||||
|
self.dropout = kwargs.get("dropout", 0.0)
|
||||||
|
self.vocab_size = kwargs.get("vocab_size", 6400)
|
||||||
|
self.bos_token_id = kwargs.get("bos_token_id", 1)
|
||||||
|
self.eos_token_id = kwargs.get("eos_token_id", 2)
|
||||||
|
self.flash_attn = kwargs.get("flash_attn", True)
|
||||||
|
self.num_attention_heads = kwargs.get("num_attention_heads", 8)
|
||||||
|
self.num_key_value_heads = kwargs.get("num_key_value_heads", 4)
|
||||||
|
self.head_dim = kwargs.get("head_dim", self.hidden_size // self.num_attention_heads)
|
||||||
|
self.hidden_act = kwargs.get("hidden_act", 'silu')
|
||||||
|
self.intermediate_size = kwargs.get("intermediate_size", math.ceil(hidden_size * math.pi / 64) * 64)
|
||||||
|
self.max_position_embeddings = kwargs.get("max_position_embeddings", 32768)
|
||||||
|
self.rms_norm_eps = kwargs.get("rms_norm_eps", 1e-6)
|
||||||
|
self.rope_theta = kwargs.get("rope_theta", 1e6)
|
||||||
|
self.tie_word_embeddings = kwargs.get("tie_word_embeddings", True)
|
||||||
|
self.inference_rope_scaling = kwargs.get("inference_rope_scaling", False)
|
||||||
|
self.rope_scaling = {
|
||||||
|
"beta_fast": 32,
|
||||||
|
"beta_slow": 1,
|
||||||
|
"factor": 16,
|
||||||
|
"original_max_position_embeddings": 2048,
|
||||||
|
"attention_factor": 1.0,
|
||||||
|
"type": "yarn"
|
||||||
|
} if self.inference_rope_scaling else None
|
||||||
|
### MoE specific configs (ignored if use_moe = False)
|
||||||
|
self.num_experts = kwargs.get("num_experts", 4)
|
||||||
|
self.num_experts_per_tok = kwargs.get("num_experts_per_tok", 1)
|
||||||
|
self.moe_intermediate_size = kwargs.get("moe_intermediate_size", self.intermediate_size)
|
||||||
|
self.norm_topk_prob = kwargs.get("norm_topk_prob", True)
|
||||||
|
self.router_aux_loss_coef = kwargs.get("router_aux_loss_coef", 5e-4)
|
||||||
|
|
||||||
|
# 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏
|
||||||
|
# MiniMind Model
|
||||||
|
# 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏
|
||||||
|
class RMSNorm(torch.nn.Module):
|
||||||
|
def __init__(self, dim: int, eps: float = 1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.weight = nn.Parameter(torch.ones(dim))
|
||||||
|
|
||||||
|
def norm(self, x):
|
||||||
|
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return (self.weight * self.norm(x.float())).type_as(x)
|
||||||
|
|
||||||
|
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6, rope_scaling: dict = None):
|
||||||
|
freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0
|
||||||
|
if rope_scaling is not None: # YaRN: f'(i) = f(i)((1-γ) + γ/s), where γ∈[0,1] is linear ramp
|
||||||
|
orig_max, factor, beta_fast, beta_slow, attn_factor = (
|
||||||
|
rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 16),
|
||||||
|
rope_scaling.get("beta_fast", 32.0), rope_scaling.get("beta_slow", 1.0), rope_scaling.get("attention_factor", 1.0)
|
||||||
|
)
|
||||||
|
if end / orig_max > 1.0:
|
||||||
|
inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
|
||||||
|
low, high = max(math.floor(inv_dim(beta_fast)), 0), min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
|
||||||
|
ramp = torch.clamp((torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001), 0, 1)
|
||||||
|
freqs = freqs * (1 - ramp + ramp / factor)
|
||||||
|
t = torch.arange(end, device=freqs.device)
|
||||||
|
freqs = torch.outer(t, freqs).float()
|
||||||
|
freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) * attn_factor
|
||||||
|
freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) * attn_factor
|
||||||
|
return freqs_cos, freqs_sin
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
||||||
|
def rotate_half(x): return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)
|
||||||
|
q_embed = ((q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))).to(q.dtype)
|
||||||
|
k_embed = ((k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))).to(k.dtype)
|
||||||
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
|
bs, slen, num_key_value_heads, head_dim = x.shape
|
||||||
|
if n_rep == 1: return x
|
||||||
|
return (x[:, :, :, None, :].expand(bs, slen, num_key_value_heads, n_rep, head_dim).reshape(bs, slen, num_key_value_heads * n_rep, head_dim))
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, config: MiniMindConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.num_key_value_heads = config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
|
||||||
|
self.n_local_heads = config.num_attention_heads
|
||||||
|
self.n_local_kv_heads = self.num_key_value_heads
|
||||||
|
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||||
|
self.head_dim = config.head_dim
|
||||||
|
self.is_causal = True
|
||||||
|
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
|
||||||
|
self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
||||||
|
self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
||||||
|
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
|
||||||
|
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||||
|
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||||
|
self.attn_dropout = nn.Dropout(config.dropout)
|
||||||
|
self.resid_dropout = nn.Dropout(config.dropout)
|
||||||
|
self.dropout = config.dropout
|
||||||
|
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and config.flash_attn
|
||||||
|
|
||||||
|
def forward(self, x, position_embeddings, past_key_value=None, use_cache=False, attention_mask=None):
|
||||||
|
bsz, seq_len, _ = x.shape
|
||||||
|
xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||||
|
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
|
||||||
|
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
|
||||||
|
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
|
||||||
|
xq, xk = self.q_norm(xq), self.k_norm(xk)
|
||||||
|
cos, sin = position_embeddings
|
||||||
|
xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)
|
||||||
|
if past_key_value is not None:
|
||||||
|
xk = torch.cat([past_key_value[0], xk], dim=1)
|
||||||
|
xv = torch.cat([past_key_value[1], xv], dim=1)
|
||||||
|
past_kv = (xk, xv) if use_cache else None
|
||||||
|
xq, xk, xv = (xq.transpose(1, 2), repeat_kv(xk, self.n_rep).transpose(1, 2), repeat_kv(xv, self.n_rep).transpose(1, 2))
|
||||||
|
if self.flash and (seq_len > 1) and (not self.is_causal or past_key_value is None) and (attention_mask is None or torch.all(attention_mask == 1)):
|
||||||
|
output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=self.is_causal)
|
||||||
|
else:
|
||||||
|
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
||||||
|
if self.is_causal: scores[:, :, :, -seq_len:] += torch.full((seq_len, seq_len), float("-inf"), device=scores.device).triu(1)
|
||||||
|
if attention_mask is not None: scores += (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * -1e9
|
||||||
|
output = self.attn_dropout(F.softmax(scores.float(), dim=-1).type_as(xq)) @ xv
|
||||||
|
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
|
||||||
|
output = self.resid_dropout(self.o_proj(output))
|
||||||
|
return output, past_kv
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(self, config: MiniMindConfig, intermediate_size: int = None):
|
||||||
|
super().__init__()
|
||||||
|
intermediate_size = intermediate_size or config.intermediate_size
|
||||||
|
self.gate_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
|
||||||
|
self.down_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)
|
||||||
|
self.up_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
|
||||||
|
self.act_fn = ACT2FN[config.hidden_act]
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||||
|
|
||||||
|
class MOEFeedForward(nn.Module):
|
||||||
|
def __init__(self, config: MiniMindConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
|
||||||
|
self.experts = nn.ModuleList([FeedForward(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.num_experts)])
|
||||||
|
self.act_fn = ACT2FN[config.hidden_act]
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
batch_size, seq_len, hidden_dim = x.shape
|
||||||
|
x_flat = x.view(-1, hidden_dim)
|
||||||
|
scores = F.softmax(self.gate(x_flat), dim=-1)
|
||||||
|
topk_weight, topk_idx = torch.topk(scores, k=self.config.num_experts_per_tok, dim=-1, sorted=False)
|
||||||
|
if self.config.norm_topk_prob: topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)
|
||||||
|
y = torch.zeros_like(x_flat)
|
||||||
|
for i, expert in enumerate(self.experts):
|
||||||
|
mask = (topk_idx == i)
|
||||||
|
if mask.any():
|
||||||
|
token_idx = mask.any(dim=-1).nonzero().flatten()
|
||||||
|
weight = topk_weight[mask].view(-1, 1)
|
||||||
|
y.index_add_(0, token_idx, (expert(x_flat[token_idx]) * weight).to(y.dtype))
|
||||||
|
elif self.training:
|
||||||
|
y[0, 0] += 0 * sum(p.sum() for p in expert.parameters())
|
||||||
|
if self.training and self.config.router_aux_loss_coef > 0:
|
||||||
|
load = F.one_hot(topk_idx, self.config.num_experts).float().mean(0)
|
||||||
|
self.aux_loss = (load * scores.mean(0)).sum() * self.config.num_experts * self.config.router_aux_loss_coef
|
||||||
|
else:
|
||||||
|
self.aux_loss = scores.new_zeros(1).squeeze()
|
||||||
|
return y.view(batch_size, seq_len, hidden_dim)
|
||||||
|
|
||||||
|
class MiniMindBlock(nn.Module):
|
||||||
|
def __init__(self, layer_id: int, config: MiniMindConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.self_attn = Attention(config)
|
||||||
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.mlp = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, position_embeddings, past_key_value=None, use_cache=False, attention_mask=None):
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states, present_key_value = self.self_attn(
|
||||||
|
self.input_layernorm(hidden_states), position_embeddings,
|
||||||
|
past_key_value, use_cache, attention_mask
|
||||||
|
)
|
||||||
|
hidden_states += residual
|
||||||
|
hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))
|
||||||
|
return hidden_states, present_key_value
|
||||||
|
|
||||||
|
class MiniMindModel(nn.Module):
|
||||||
|
def __init__(self, config: MiniMindConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.vocab_size, self.num_hidden_layers = config.vocab_size, config.num_hidden_layers
|
||||||
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||||
|
self.dropout = nn.Dropout(config.dropout)
|
||||||
|
self.layers = nn.ModuleList([MiniMindBlock(l, config) for l in range(self.num_hidden_layers)])
|
||||||
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
freqs_cos, freqs_sin = precompute_freqs_cis(dim=config.head_dim, end=config.max_position_embeddings, rope_base=config.rope_theta, rope_scaling=config.rope_scaling)
|
||||||
|
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
|
||||||
|
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
|
||||||
|
|
||||||
|
def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False, **kwargs):
|
||||||
|
batch_size, seq_length = input_ids.shape
|
||||||
|
if hasattr(past_key_values, 'layers'): past_key_values = None
|
||||||
|
past_key_values = past_key_values or [None] * len(self.layers)
|
||||||
|
start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
|
||||||
|
hidden_states = self.dropout(self.embed_tokens(input_ids))
|
||||||
|
# Recompute RoPE buffers lost during meta-device init (transformers>=5.x)
|
||||||
|
if self.freqs_cos[0, 0] == 0:
|
||||||
|
freqs_cos, freqs_sin = precompute_freqs_cis(dim=self.config.head_dim, end=self.config.max_position_embeddings, rope_base=self.config.rope_theta, rope_scaling=self.config.rope_scaling)
|
||||||
|
self.freqs_cos, self.freqs_sin = freqs_cos.to(hidden_states.device), freqs_sin.to(hidden_states.device)
|
||||||
|
position_embeddings = (self.freqs_cos[start_pos:start_pos + seq_length], self.freqs_sin[start_pos:start_pos + seq_length])
|
||||||
|
presents = []
|
||||||
|
for layer, past_key_value in zip(self.layers, past_key_values):
|
||||||
|
hidden_states, present = layer(
|
||||||
|
hidden_states,
|
||||||
|
position_embeddings,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
use_cache=use_cache,
|
||||||
|
attention_mask=attention_mask
|
||||||
|
)
|
||||||
|
presents.append(present)
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
aux_loss = sum([l.mlp.aux_loss for l in self.layers if isinstance(l.mlp, MOEFeedForward)], hidden_states.new_zeros(1).squeeze())
|
||||||
|
return hidden_states, presents, aux_loss
|
||||||
|
|
||||||
|
class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
|
||||||
|
config_class = MiniMindConfig
|
||||||
|
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||||
|
def __init__(self, config: MiniMindConfig = None):
|
||||||
|
self.config = config or MiniMindConfig()
|
||||||
|
super().__init__(self.config)
|
||||||
|
self.model = MiniMindModel(self.config)
|
||||||
|
self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
|
||||||
|
if self.config.tie_word_embeddings: self.model.embed_tokens.weight = self.lm_head.weight
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False, logits_to_keep=0, labels=None, **kwargs):
|
||||||
|
hidden_states, past_key_values, aux_loss = self.model(input_ids, attention_mask, past_key_values, use_cache, **kwargs)
|
||||||
|
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||||
|
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
x, y = logits[..., :-1, :].contiguous(), labels[..., 1:].contiguous()
|
||||||
|
loss = F.cross_entropy(x.view(-1, x.size(-1)), y.view(-1), ignore_index=-100)
|
||||||
|
return MoeCausalLMOutputWithPast(loss=loss, aux_loss=aux_loss, logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
|
||||||
|
|
||||||
|
# https://github.com/jingyaogong/minimind/discussions/611
|
||||||
|
@torch.inference_mode()
|
||||||
|
def generate(self, inputs=None, attention_mask=None, max_new_tokens=8192, temperature=0.85, top_p=0.85, top_k=50, eos_token_id=2, streamer=None, use_cache=True, num_return_sequences=1, do_sample=True, repetition_penalty=1.0, **kwargs):
|
||||||
|
input_ids = kwargs.pop("input_ids", inputs).repeat(num_return_sequences, 1)
|
||||||
|
attention_mask = attention_mask.repeat(num_return_sequences, 1) if attention_mask is not None else None
|
||||||
|
past_key_values = kwargs.pop("past_key_values", None)
|
||||||
|
finished = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
|
||||||
|
if streamer: streamer.put(input_ids.cpu())
|
||||||
|
for _ in range(max_new_tokens):
|
||||||
|
past_len = past_key_values[0][0].shape[1] if past_key_values else 0
|
||||||
|
outputs = self.forward(input_ids[:, past_len:], attention_mask, past_key_values, use_cache=use_cache, **kwargs)
|
||||||
|
attention_mask = torch.cat([attention_mask, attention_mask.new_ones(attention_mask.shape[0], 1)], -1) if attention_mask is not None else None
|
||||||
|
logits = outputs.logits[:, -1, :] / temperature
|
||||||
|
if repetition_penalty != 1.0:
|
||||||
|
for i in range(input_ids.shape[0]): logits[i, torch.unique(input_ids[i])] /= repetition_penalty
|
||||||
|
if top_k > 0:
|
||||||
|
logits[logits < torch.topk(logits, top_k)[0][..., -1, None]] = -float('inf')
|
||||||
|
if top_p < 1.0:
|
||||||
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||||
|
mask = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) > top_p
|
||||||
|
mask[..., 1:], mask[..., 0] = mask[..., :-1].clone(), 0
|
||||||
|
logits[mask.scatter(1, sorted_indices, mask)] = -float('inf')
|
||||||
|
next_token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1) if do_sample else torch.argmax(logits, dim=-1, keepdim=True)
|
||||||
|
if eos_token_id is not None: next_token = torch.where(finished.unsqueeze(-1), next_token.new_full((next_token.shape[0], 1), eos_token_id), next_token)
|
||||||
|
input_ids = torch.cat([input_ids, next_token], dim=-1)
|
||||||
|
past_key_values = outputs.past_key_values if use_cache else None
|
||||||
|
if streamer: streamer.put(next_token.cpu())
|
||||||
|
if eos_token_id is not None:
|
||||||
|
finished |= next_token.squeeze(-1).eq(eos_token_id)
|
||||||
|
if finished.all(): break
|
||||||
|
if streamer: streamer.end()
|
||||||
|
if kwargs.get("return_kv"): return {'generated_ids': input_ids, 'past_kv': past_key_values}
|
||||||
|
return input_ids
|
||||||
31191
model/tokenizer.json
Normal file
335
model/tokenizer_config.json
Normal file
@ -0,0 +1,335 @@
|
|||||||
|
{
|
||||||
|
"add_bos_token": false,
|
||||||
|
"add_eos_token": false,
|
||||||
|
"add_prefix_space": false,
|
||||||
|
"added_tokens_decoder": {
|
||||||
|
"0": {
|
||||||
|
"content": "<|endoftext|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"1": {
|
||||||
|
"content": "<|im_start|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"content": "<|im_end|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"3": {
|
||||||
|
"content": "<|object_ref_start|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"content": "<|object_ref_end|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"5": {
|
||||||
|
"content": "<|box_start|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"6": {
|
||||||
|
"content": "<|box_end|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"7": {
|
||||||
|
"content": "<|quad_start|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"content": "<|quad_end|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"9": {
|
||||||
|
"content": "<|vision_start|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"10": {
|
||||||
|
"content": "<|vision_end|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"11": {
|
||||||
|
"content": "<|vision_pad|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"12": {
|
||||||
|
"content": "<|image_pad|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"13": {
|
||||||
|
"content": "<|video_pad|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"14": {
|
||||||
|
"content": "<|audio_start|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"15": {
|
||||||
|
"content": "<|audio_end|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"content": "<|audio_pad|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"17": {
|
||||||
|
"content": "<tts_pad>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"18": {
|
||||||
|
"content": "<tts_text_bos>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"19": {
|
||||||
|
"content": "<tts_text_eod>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"20": {
|
||||||
|
"content": "<tts_text_bos_single>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"21": {
|
||||||
|
"content": "<tool_call>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"22": {
|
||||||
|
"content": "</tool_call>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"23": {
|
||||||
|
"content": "<tool_response>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"24": {
|
||||||
|
"content": "</tool_response>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"25": {
|
||||||
|
"content": "<think>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"26": {
|
||||||
|
"content": "</think>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"27": {
|
||||||
|
"content": "<|buffer1|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"28": {
|
||||||
|
"content": "<|buffer2|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"29": {
|
||||||
|
"content": "<|buffer3|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"30": {
|
||||||
|
"content": "<|buffer4|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"31": {
|
||||||
|
"content": "<|buffer5|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"32": {
|
||||||
|
"content": "<|buffer6|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"33": {
|
||||||
|
"content": "<|buffer7|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"34": {
|
||||||
|
"content": "<|buffer8|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"35": {
|
||||||
|
"content": "<|buffer9|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additional_special_tokens": [
|
||||||
|
"<|im_start|>",
|
||||||
|
"<|im_end|>",
|
||||||
|
"<|object_ref_start|>",
|
||||||
|
"<|object_ref_end|>",
|
||||||
|
"<|box_start|>",
|
||||||
|
"<|box_end|>",
|
||||||
|
"<|quad_start|>",
|
||||||
|
"<|quad_end|>",
|
||||||
|
"<|vision_start|>",
|
||||||
|
"<|vision_end|>",
|
||||||
|
"<|vision_pad|>",
|
||||||
|
"<|image_pad|>",
|
||||||
|
"<|video_pad|>",
|
||||||
|
"<|audio_start|>",
|
||||||
|
"<|audio_end|>",
|
||||||
|
"<|audio_pad|>",
|
||||||
|
"<tts_pad>",
|
||||||
|
"<tts_text_bos>",
|
||||||
|
"<tts_text_eod>",
|
||||||
|
"<tts_text_bos_single>"
|
||||||
|
],
|
||||||
|
"bos_token": "<|im_start|>",
|
||||||
|
"clean_up_tokenization_spaces": false,
|
||||||
|
"eos_token": "<|im_end|>",
|
||||||
|
"legacy": true,
|
||||||
|
"model_max_length": 131072,
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
"sp_model_kwargs": {},
|
||||||
|
"spaces_between_special_tokens": false,
|
||||||
|
"unk_token": "<|endoftext|>",
|
||||||
|
"image_token": "<|image_pad|>",
|
||||||
|
"audio_token": "<|audio_pad|>",
|
||||||
|
"video_token": "<|video_pad|>",
|
||||||
|
"vision_bos_token": "<|vision_start|>",
|
||||||
|
"vision_eos_token": "<|vision_end|>",
|
||||||
|
"audio_bos_token": "<|audio_start|>",
|
||||||
|
"audio_eos_token": "<|audio_end|>",
|
||||||
|
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if true %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if open_thinking is defined and open_thinking is true %}\n {{- '<think>\\n' }}\n {%- else %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}",
|
||||||
|
"tokenizer_class": "PreTrainedTokenizerFast"
|
||||||
|
}
|
||||||
@ -1,2 +1,32 @@
|
|||||||
mkdocs>=1.5.0
|
datasets==3.6.0
|
||||||
mkdocs-material>=9.0.0
|
datasketch==1.6.4
|
||||||
|
Flask==3.0.3
|
||||||
|
Flask_Cors==4.0.0
|
||||||
|
jieba==0.42.1
|
||||||
|
jsonlines==4.0.0
|
||||||
|
marshmallow==3.22.0
|
||||||
|
# matplotlib==3.10.0
|
||||||
|
ngrok==1.4.0
|
||||||
|
nltk==3.8
|
||||||
|
numpy==1.26.4
|
||||||
|
openai==1.59.6
|
||||||
|
# peft==0.7.1
|
||||||
|
psutil==5.9.8
|
||||||
|
pydantic==2.11.5
|
||||||
|
rich==13.7.1
|
||||||
|
scikit_learn==1.5.1
|
||||||
|
sentence_transformers==2.3.1
|
||||||
|
simhash==2.1.2
|
||||||
|
tiktoken==0.10.0
|
||||||
|
transformers==4.57.6
|
||||||
|
jinja2==3.1.2
|
||||||
|
jsonlines==4.0.0
|
||||||
|
trl==0.13.0
|
||||||
|
ujson==5.1.0
|
||||||
|
wandb==0.18.3
|
||||||
|
streamlit==1.50.0
|
||||||
|
einops==0.8.1
|
||||||
|
swanlab==0.7.11
|
||||||
|
modelscope==1.30.0
|
||||||
|
# torch==2.6.0
|
||||||
|
# torchvision==0.21.0
|
||||||
40
scripts/chat_api.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
client = OpenAI(
|
||||||
|
api_key="sk-123",
|
||||||
|
base_url="http://localhost:11434/v1"
|
||||||
|
)
|
||||||
|
stream = True
|
||||||
|
conversation_history_origin = []
|
||||||
|
conversation_history = conversation_history_origin.copy()
|
||||||
|
history_messages_num = 0 # 必须设置为偶数(Q+A),为0则不携带历史对话
|
||||||
|
while True:
|
||||||
|
query = input('[Q]: ')
|
||||||
|
conversation_history.append({"role": "user", "content": query})
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="minimind-local:latest",
|
||||||
|
messages=conversation_history[-(history_messages_num or 1):],
|
||||||
|
stream=stream,
|
||||||
|
temperature=0.8,
|
||||||
|
max_tokens=2048,
|
||||||
|
top_p=0.8,
|
||||||
|
extra_body={"chat_template_kwargs": {"open_thinking": True}, "reasoning_effort": "medium"} # 思考开关
|
||||||
|
)
|
||||||
|
if not stream:
|
||||||
|
assistant_res = response.choices[0].message.content
|
||||||
|
print('[A]: ', assistant_res)
|
||||||
|
else:
|
||||||
|
print('[A]: ', end='', flush=True)
|
||||||
|
assistant_res = ''
|
||||||
|
for chunk in response:
|
||||||
|
delta = chunk.choices[0].delta
|
||||||
|
r = getattr(delta, 'reasoning_content', None) or ""
|
||||||
|
c = delta.content or ""
|
||||||
|
if r:
|
||||||
|
print(f'\033[90m{r}\033[0m', end="", flush=True)
|
||||||
|
if c:
|
||||||
|
print(c, end="", flush=True)
|
||||||
|
assistant_res += c
|
||||||
|
|
||||||
|
conversation_history.append({"role": "assistant", "content": assistant_res})
|
||||||
|
print('\n\n')
|
||||||
144
scripts/convert_model.py
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
|
||||||
|
__package__ = "scripts"
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
import warnings
|
||||||
|
from transformers import AutoTokenizer, AutoModelForCausalLM, Qwen3Config, Qwen3ForCausalLM, Qwen3MoeConfig, Qwen3MoeForCausalLM
|
||||||
|
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||||
|
from model.model_lora import apply_lora, merge_lora
|
||||||
|
|
||||||
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
|
|
||||||
|
def convert_torch2transformers_minimind(torch_path, transformers_path, dtype=torch.float16):
|
||||||
|
MiniMindConfig.register_for_auto_class()
|
||||||
|
MiniMindForCausalLM.register_for_auto_class("AutoModelForCausalLM")
|
||||||
|
lm_model = MiniMindForCausalLM(lm_config)
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
state_dict = torch.load(torch_path, map_location=device)
|
||||||
|
lm_model.load_state_dict(state_dict, strict=False)
|
||||||
|
lm_model = lm_model.to(dtype) # 转换模型权重精度
|
||||||
|
model_params = sum(p.numel() for p in lm_model.parameters() if p.requires_grad)
|
||||||
|
print(f'模型参数: {model_params / 1e6} 百万 = {model_params / 1e9} B (Billion)')
|
||||||
|
lm_model.save_pretrained(transformers_path, safe_serialization=False)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained('../model/')
|
||||||
|
tokenizer.save_pretrained(transformers_path)
|
||||||
|
# ======= transformers-5.0的兼容低版本写法 =======
|
||||||
|
if int(transformers.__version__.split('.')[0]) >= 5:
|
||||||
|
tokenizer_config_path, config_path = os.path.join(transformers_path, "tokenizer_config.json"), os.path.join(transformers_path, "config.json")
|
||||||
|
json.dump({**json.load(open(tokenizer_config_path, 'r', encoding='utf-8')), "tokenizer_class": "PreTrainedTokenizerFast", "extra_special_tokens": {}}, open(tokenizer_config_path, 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
|
||||||
|
config = json.load(open(config_path, 'r', encoding='utf-8'))
|
||||||
|
config['rope_theta'] = lm_config.rope_theta; config['rope_scaling'] = None; del config['rope_parameters']
|
||||||
|
json.dump(config, open(config_path, 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
|
||||||
|
print(f"模型已保存为 Transformers-MiniMind 格式: {transformers_path}")
|
||||||
|
|
||||||
|
|
||||||
|
# QwenForCausalLM/LlamaForCausalLM结构兼容生态
|
||||||
|
def convert_torch2transformers(torch_path, transformers_path, dtype=torch.float16):
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
state_dict = torch.load(torch_path, map_location=device)
|
||||||
|
common_config = {
|
||||||
|
"vocab_size": lm_config.vocab_size,
|
||||||
|
"hidden_size": lm_config.hidden_size,
|
||||||
|
"intermediate_size": lm_config.intermediate_size,
|
||||||
|
"num_hidden_layers": lm_config.num_hidden_layers,
|
||||||
|
"num_attention_heads": lm_config.num_attention_heads,
|
||||||
|
"num_key_value_heads": lm_config.num_key_value_heads,
|
||||||
|
"head_dim": lm_config.hidden_size // lm_config.num_attention_heads,
|
||||||
|
"max_position_embeddings": lm_config.max_position_embeddings,
|
||||||
|
"rms_norm_eps": lm_config.rms_norm_eps,
|
||||||
|
"rope_theta": lm_config.rope_theta,
|
||||||
|
"tie_word_embeddings": lm_config.tie_word_embeddings
|
||||||
|
}
|
||||||
|
if not lm_config.use_moe:
|
||||||
|
qwen_config = Qwen3Config(
|
||||||
|
**common_config,
|
||||||
|
use_sliding_window=False,
|
||||||
|
sliding_window=None
|
||||||
|
)
|
||||||
|
qwen_model = Qwen3ForCausalLM(qwen_config)
|
||||||
|
else:
|
||||||
|
qwen_config = Qwen3MoeConfig(
|
||||||
|
**common_config,
|
||||||
|
num_experts=lm_config.num_experts,
|
||||||
|
num_experts_per_tok=lm_config.num_experts_per_tok,
|
||||||
|
moe_intermediate_size=lm_config.moe_intermediate_size,
|
||||||
|
norm_topk_prob=lm_config.norm_topk_prob
|
||||||
|
)
|
||||||
|
qwen_model = Qwen3MoeForCausalLM(qwen_config)
|
||||||
|
# ======= transformers-5.0的兼容低版本写法 =======
|
||||||
|
if int(transformers.__version__.split('.')[0]) >= 5:
|
||||||
|
new_sd = {k: v for k, v in state_dict.items() if 'experts.' not in k or 'gate.weight' in k}
|
||||||
|
for l in range(lm_config.num_hidden_layers):
|
||||||
|
p = f'model.layers.{l}.mlp.experts'
|
||||||
|
new_sd[f'{p}.gate_up_proj'] = torch.cat([torch.stack([state_dict[f'{p}.{e}.gate_proj.weight'] for e in range(lm_config.num_experts)]), torch.stack([state_dict[f'{p}.{e}.up_proj.weight'] for e in range(lm_config.num_experts)])], dim=1)
|
||||||
|
new_sd[f'{p}.down_proj'] = torch.stack([state_dict[f'{p}.{e}.down_proj.weight'] for e in range(lm_config.num_experts)])
|
||||||
|
state_dict = new_sd
|
||||||
|
|
||||||
|
qwen_model.load_state_dict(state_dict, strict=True)
|
||||||
|
qwen_model = qwen_model.to(dtype) # 转换模型权重精度
|
||||||
|
qwen_model.save_pretrained(transformers_path)
|
||||||
|
model_params = sum(p.numel() for p in qwen_model.parameters() if p.requires_grad)
|
||||||
|
print(f'模型参数: {model_params / 1e6} 百万 = {model_params / 1e9} B (Billion)')
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained('../model/')
|
||||||
|
tokenizer.save_pretrained(transformers_path)
|
||||||
|
|
||||||
|
# ======= transformers-5.0的兼容低版本写法 =======
|
||||||
|
if int(transformers.__version__.split('.')[0]) >= 5:
|
||||||
|
tokenizer_config_path, config_path = os.path.join(transformers_path, "tokenizer_config.json"), os.path.join(transformers_path, "config.json")
|
||||||
|
json.dump({**json.load(open(tokenizer_config_path, 'r', encoding='utf-8')), "tokenizer_class": "PreTrainedTokenizerFast", "extra_special_tokens": {}}, open(tokenizer_config_path, 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
|
||||||
|
config = json.load(open(config_path, 'r', encoding='utf-8'))
|
||||||
|
config['rope_theta'] = lm_config.rope_theta; config['rope_scaling'] = None; del config['rope_parameters']
|
||||||
|
json.dump(config, open(config_path, 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
|
||||||
|
print(f"模型已保存为 Transformers 格式: {transformers_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def convert_transformers2torch(transformers_path, torch_path):
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(transformers_path, trust_remote_code=True)
|
||||||
|
torch.save({k: v.cpu().half() for k, v in model.state_dict().items()}, torch_path)
|
||||||
|
print(f"模型已保存为 PyTorch 格式: {torch_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def convert_merge_base_lora(base_torch_path, lora_path, merged_torch_path):
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
lm_model = MiniMindForCausalLM(lm_config).to(device)
|
||||||
|
state_dict = torch.load(base_torch_path, map_location=device)
|
||||||
|
lm_model.load_state_dict(state_dict, strict=False)
|
||||||
|
apply_lora(lm_model)
|
||||||
|
merge_lora(lm_model, lora_path, merged_torch_path)
|
||||||
|
print(f"LoRA 已合并并保存为基模结构 PyTorch 格式: {merged_torch_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def convert_jinja_to_json(jinja_path):
|
||||||
|
with open(jinja_path, 'r') as f: template = f.read()
|
||||||
|
escaped = json.dumps(template)
|
||||||
|
print(f'"chat_template": {escaped}')
|
||||||
|
|
||||||
|
|
||||||
|
def convert_json_to_jinja(json_file_path, output_path):
|
||||||
|
with open(json_file_path, 'r') as f: config = json.load(f)
|
||||||
|
template = config['chat_template']
|
||||||
|
with open(output_path, 'w') as f: f.write(template)
|
||||||
|
print(f"模板已保存为 jinja 文件: {output_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
lm_config = MiniMindConfig(hidden_size=768, num_hidden_layers=8, max_seq_len=8192, use_moe=False)
|
||||||
|
|
||||||
|
# convert torch to transformers
|
||||||
|
torch_path = f"../out/full_sft_{lm_config.hidden_size}{'_moe' if lm_config.use_moe else ''}.pth"
|
||||||
|
transformers_path = '../minimind-3'
|
||||||
|
convert_torch2transformers(torch_path, transformers_path)
|
||||||
|
|
||||||
|
# # merge lora
|
||||||
|
# base_torch_path = f"../out/full_sft_{lm_config.hidden_size}{'_moe' if lm_config.use_moe else ''}.pth"
|
||||||
|
# lora_path = f"../out/lora_identity_{lm_config.hidden_size}{'_moe' if lm_config.use_moe else ''}.pth"
|
||||||
|
# merged_torch_path = f"../out/merge_identity_{lm_config.hidden_size}{'_moe' if lm_config.use_moe else ''}.pth"
|
||||||
|
# convert_merge_base_lora(base_torch_path, lora_path, merged_torch_path)
|
||||||
|
|
||||||
|
# convert_transformers2torch(transformers_path, torch_path)
|
||||||
|
# convert_json_to_jinja('../model/tokenizer_config.json', '../model/chat_template.jinja')
|
||||||
|
# convert_jinja_to_json('../model/chat_template.jinja')
|
||||||
240
scripts/eval_toolcall.py
Normal file
@ -0,0 +1,240 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
import argparse
|
||||||
|
import warnings
|
||||||
|
import torch
|
||||||
|
from datetime import datetime
|
||||||
|
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
|
||||||
|
from openai import OpenAI
|
||||||
|
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||||
|
from trainer.trainer_utils import setup_seed, get_model_params
|
||||||
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
|
TOOLS = [
|
||||||
|
{"type": "function", "function": {"name": "calculate_math", "description": "计算数学表达式的结果,支持加减乘除、幂运算、开方等", "parameters": {"type": "object", "properties": {"expression": {"type": "string", "description": "数学表达式,如123+456、2**10、sqrt(144)"}}, "required": ["expression"]}}},
|
||||||
|
{"type": "function", "function": {"name": "get_current_time", "description": "获取当前日期和时间,支持指定时区", "parameters": {"type": "object", "properties": {"timezone": {"type": "string", "description": "时区名称,如Asia/Shanghai、America/New_York", "default": "Asia/Shanghai"}}, "required": []}}},
|
||||||
|
{"type": "function", "function": {"name": "random_number", "description": "生成指定范围内的随机数", "parameters": {"type": "object", "properties": {"min": {"type": "integer", "description": "最小值", "default": 0}, "max": {"type": "integer", "description": "最大值", "default": 100}}, "required": []}}},
|
||||||
|
{"type": "function", "function": {"name": "text_length", "description": "计算文本的字符数和单词数", "parameters": {"type": "object", "properties": {"text": {"type": "string", "description": "要统计的文本"}}, "required": ["text"]}}},
|
||||||
|
{"type": "function", "function": {"name": "unit_converter", "description": "进行单位换算,支持长度、重量、温度等", "parameters": {"type": "object", "properties": {"value": {"type": "number", "description": "要转换的数值"}, "from_unit": {"type": "string", "description": "源单位,如km、miles、kg、pounds、celsius、fahrenheit"}, "to_unit": {"type": "string", "description": "目标单位"}}, "required": ["value", "from_unit", "to_unit"]}}},
|
||||||
|
{"type": "function", "function": {"name": "get_current_weather", "description": "获取指定城市的当前天气信息,包括温度、湿度和天气状况", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "城市名称,如北京、上海、New York"}, "unit": {"type": "string", "description": "温度单位,celsius或fahrenheit", "enum": ["celsius", "fahrenheit"], "default": "celsius"}}, "required": ["location"]}}},
|
||||||
|
{"type": "function", "function": {"name": "get_exchange_rate", "description": "查询两种货币之间的实时汇率", "parameters": {"type": "object", "properties": {"from_currency": {"type": "string", "description": "源货币代码,如USD、CNY、EUR"}, "to_currency": {"type": "string", "description": "目标货币代码,如USD、CNY、EUR"}}, "required": ["from_currency", "to_currency"]}}},
|
||||||
|
{"type": "function", "function": {"name": "translate_text", "description": "将文本翻译成目标语言", "parameters": {"type": "object", "properties": {"text": {"type": "string", "description": "要翻译的文本"}, "target_language": {"type": "string", "description": "目标语言,如english、chinese、japanese、french"}}, "required": ["text", "target_language"]}}},
|
||||||
|
]
|
||||||
|
|
||||||
|
MOCK_RESULTS = {
|
||||||
|
"calculate_math": lambda args: {"result": str(eval(str(args.get("expression", "0")).replace("^", "**").replace("×", "*").replace("÷", "/").replace("−", "-").replace("²", "**2").replace("³", "**3").replace("(", "(").replace(")", ")")))},
|
||||||
|
"get_current_time": lambda args: {"datetime": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "timezone": args.get("timezone", "Asia/Shanghai")},
|
||||||
|
"random_number": lambda args: {"result": random.randint(int(args.get("min", 0)), int(args.get("max", 100)))},
|
||||||
|
"text_length": lambda args: {"characters": len(args.get("text", "")), "words": len(args.get("text", "").split())},
|
||||||
|
"unit_converter": lambda args: {"result": round(float(args.get("value", 0)) * 0.621371, 2), "from": f"{args.get('value', 0)} {args.get('from_unit', '')}", "to": args.get("to_unit", "")},
|
||||||
|
"get_current_weather": lambda args: {"city": args.get("location"), "temperature": "22°C", "humidity": "65%", "condition": "晴"},
|
||||||
|
"get_exchange_rate": lambda args: {"from": args.get("from_currency", ""), "to": args.get("to_currency", ""), "rate": 7.15},
|
||||||
|
"translate_text": lambda args: {"translated": "hello world"},
|
||||||
|
}
|
||||||
|
|
||||||
|
TOOL_MAP = {t["function"]["name"]: t for t in TOOLS}
|
||||||
|
|
||||||
|
def get_tools(names):
|
||||||
|
return [TOOL_MAP[n] for n in names]
|
||||||
|
|
||||||
|
TEST_CASES = [
|
||||||
|
{"prompt": "帮我算一下 256 乘以 37 等于多少", "tools": ["calculate_math", "get_current_time"]},
|
||||||
|
{"prompt": "现在几点了?", "tools": ["get_current_time", "random_number"]},
|
||||||
|
{"prompt": "帮我把100公里换算成英里", "tools": ["unit_converter", "calculate_math"]},
|
||||||
|
{"prompt": "帮我生成一个1到1000的随机数,然后计算它的平方", "tools": ["random_number", "calculate_math", "text_length"]},
|
||||||
|
{"prompt": "北京今天天气怎么样?", "tools": ["get_current_weather", "get_current_time"]},
|
||||||
|
{"prompt": "查一下美元兑人民币汇率", "tools": ["get_exchange_rate", "get_current_time"]},
|
||||||
|
{"prompt": "把'你好世界'翻译成英文", "tools": ["translate_text", "text_length"]},
|
||||||
|
{"prompt": "What is the weather in Tokyo? Also convert 30 celsius to fahrenheit.", "tools": ["get_current_weather", "unit_converter", "get_current_time"]},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def init_model(args):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(args.load_from)
|
||||||
|
if 'model' in args.load_from:
|
||||||
|
model = MiniMindForCausalLM(MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe)))
|
||||||
|
moe_suffix = '_moe' if args.use_moe else ''
|
||||||
|
ckp = f'./{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
|
||||||
|
model.load_state_dict(torch.load(ckp, map_location=args.device), strict=True)
|
||||||
|
else:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
|
||||||
|
get_model_params(model, model.config)
|
||||||
|
return model.half().eval().to(args.device), tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def parse_tool_calls(text):
|
||||||
|
matches = re.findall(r'<tool_call>(.*?)</tool_call>', text, re.DOTALL)
|
||||||
|
calls = []
|
||||||
|
for m in matches:
|
||||||
|
try:
|
||||||
|
calls.append(json.loads(m.strip()))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return calls
|
||||||
|
|
||||||
|
|
||||||
|
def parse_tool_call_from_text(content):
|
||||||
|
pattern = r'<tool_call>\s*(\{.*?\})\s*</tool_call>'
|
||||||
|
matches = re.findall(pattern, content, re.DOTALL)
|
||||||
|
if not matches:
|
||||||
|
return None
|
||||||
|
tool_calls = []
|
||||||
|
for i, match in enumerate(matches):
|
||||||
|
try:
|
||||||
|
data = json.loads(match)
|
||||||
|
tool_calls.append({
|
||||||
|
"id": f"call_{i}",
|
||||||
|
"function": {"name": data.get("name", ""), "arguments": json.dumps(data.get("arguments", {}), ensure_ascii=False)}
|
||||||
|
})
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return tool_calls if tool_calls else None
|
||||||
|
|
||||||
|
|
||||||
|
def execute_tool(call, arguments=None):
|
||||||
|
name = call.get("name", "") if isinstance(call, dict) else call
|
||||||
|
try:
|
||||||
|
raw_args = call.get("arguments", {}) if isinstance(call, dict) else arguments
|
||||||
|
args = json.loads(raw_args) if isinstance(raw_args, str) else raw_args
|
||||||
|
except Exception:
|
||||||
|
args = {}
|
||||||
|
fn = MOCK_RESULTS.get(name)
|
||||||
|
if not fn:
|
||||||
|
return {"error": f"未知工具: {name}"}
|
||||||
|
try:
|
||||||
|
return fn(args)
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": f"工具执行失败: {str(e)[:80]}"}
|
||||||
|
|
||||||
|
|
||||||
|
def generate(model, tokenizer, messages, tools, args):
|
||||||
|
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||||
|
input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=tools, open_thinking=False)
|
||||||
|
inputs = tokenizer(input_text, return_tensors="pt", truncation=True).to(args.device)
|
||||||
|
st = time.time()
|
||||||
|
print('🧠: ', end='')
|
||||||
|
generated_ids = model.generate(
|
||||||
|
inputs["input_ids"], attention_mask=inputs["attention_mask"],
|
||||||
|
max_new_tokens=args.max_new_tokens, do_sample=True, streamer=streamer,
|
||||||
|
pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id,
|
||||||
|
top_p=args.top_p, temperature=args.temperature
|
||||||
|
)
|
||||||
|
response = tokenizer.decode(generated_ids[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
|
||||||
|
gen_tokens = len(generated_ids[0]) - len(inputs["input_ids"][0])
|
||||||
|
print(f'\n[Speed]: {gen_tokens / (time.time() - st):.2f} tokens/s') if args.show_speed else print()
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def chat_api(client, messages, tools, args, stream=True):
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model=args.api_model, messages=messages, tools=tools,
|
||||||
|
stream=stream, temperature=args.temperature,
|
||||||
|
max_tokens=8192, top_p=args.top_p
|
||||||
|
)
|
||||||
|
if not stream:
|
||||||
|
choice = response.choices[0]
|
||||||
|
content = choice.message.content or ""
|
||||||
|
tool_calls = choice.message.tool_calls
|
||||||
|
if not tool_calls:
|
||||||
|
tool_calls = parse_tool_call_from_text(content)
|
||||||
|
print(f'🧠: {content}')
|
||||||
|
return content, tool_calls
|
||||||
|
print('🧠: ', end='', flush=True)
|
||||||
|
content, tool_calls = "", None
|
||||||
|
for chunk in response:
|
||||||
|
delta = chunk.choices[0].delta
|
||||||
|
if delta.content:
|
||||||
|
print(delta.content, end="", flush=True)
|
||||||
|
content += delta.content
|
||||||
|
if delta.tool_calls:
|
||||||
|
if tool_calls is None:
|
||||||
|
tool_calls = []
|
||||||
|
for tc_chunk in delta.tool_calls:
|
||||||
|
idx = tc_chunk.index if tc_chunk.index is not None else len(tool_calls)
|
||||||
|
while len(tool_calls) <= idx:
|
||||||
|
tool_calls.append({
|
||||||
|
"id": "",
|
||||||
|
"function": {"name": "", "arguments": ""}
|
||||||
|
})
|
||||||
|
if tc_chunk.id:
|
||||||
|
tool_calls[idx]["id"] += tc_chunk.id
|
||||||
|
if tc_chunk.function:
|
||||||
|
if tc_chunk.function.name:
|
||||||
|
tool_calls[idx]["function"]["name"] += tc_chunk.function.name
|
||||||
|
if tc_chunk.function.arguments:
|
||||||
|
tool_calls[idx]["function"]["arguments"] += tc_chunk.function.arguments
|
||||||
|
print()
|
||||||
|
if not tool_calls:
|
||||||
|
tool_calls = parse_tool_call_from_text(content)
|
||||||
|
return content, tool_calls
|
||||||
|
|
||||||
|
|
||||||
|
def run_case(prompt, tools, args, model=None, tokenizer=None, client=None):
|
||||||
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
while True:
|
||||||
|
if args.backend == 'local':
|
||||||
|
content = generate(model, tokenizer, messages, tools, args)
|
||||||
|
tool_calls = parse_tool_calls(content)
|
||||||
|
else:
|
||||||
|
content, tool_calls = chat_api(client, messages, tools, args, stream=bool(args.stream))
|
||||||
|
if not tool_calls:
|
||||||
|
break
|
||||||
|
tool_calls = [{
|
||||||
|
"id": tc.id if hasattr(tc, 'id') else tc.get("id", ""),
|
||||||
|
"name": tc.function.name if hasattr(tc, 'function') else tc["function"]["name"],
|
||||||
|
"arguments": tc.function.arguments if hasattr(tc, 'function') else tc["function"]["arguments"]
|
||||||
|
} for tc in tool_calls] if args.backend == 'api' else tool_calls
|
||||||
|
messages.append({"role": "assistant", "content": content} if args.backend == 'local' else {"role": "assistant", "content": content, "tool_calls": [{"id": tc["id"], "type": "function", "function": {"name": tc["name"], "arguments": tc["arguments"]}} for tc in tool_calls]})
|
||||||
|
for tc in tool_calls:
|
||||||
|
name = tc["name"]
|
||||||
|
arguments = tc["arguments"]
|
||||||
|
print(f'📞 [Tool Calling]: {name} | args={arguments}')
|
||||||
|
result = execute_tool(tc if args.backend == 'local' else name, arguments)
|
||||||
|
print(f'✅ [Tool Called]: {json.dumps(result, ensure_ascii=False)}')
|
||||||
|
messages.append({"role": "tool", "content": json.dumps(result, ensure_ascii=False)} if args.backend == 'local' else {"role": "tool", "content": json.dumps(result, ensure_ascii=False), "tool_call_id": tc["id"]})
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="MiniMind ToolCall评测")
|
||||||
|
parser.add_argument('--backend', default='local', choices=['local', 'api'], type=str, help="推理后端(local=本地模型,api=OpenAI兼容接口)")
|
||||||
|
parser.add_argument('--load_from', default='../model', type=str, help="模型加载路径(model=原生torch权重,其他路径=transformers格式)")
|
||||||
|
parser.add_argument('--save_dir', default='../out', type=str, help="模型权重目录")
|
||||||
|
parser.add_argument('--weight', default='full_sft', type=str, help="权重名称前缀(pretrain, full_sft, rlhf, reason, ppo_actor, grpo, spo)")
|
||||||
|
parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度")
|
||||||
|
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||||
|
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||||
|
parser.add_argument('--max_new_tokens', default=512, type=int, help="最大生成长度")
|
||||||
|
parser.add_argument('--temperature', default=0.9, type=float, help="生成温度,控制随机性(0-1,越大越随机)")
|
||||||
|
parser.add_argument('--top_p', default=0.9, type=float, help="nucleus采样阈值(0-1)")
|
||||||
|
parser.add_argument('--show_speed', default=0, type=int, help="显示decode速度(tokens/s)")
|
||||||
|
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="运行设备")
|
||||||
|
parser.add_argument('--api_base_url', default="http://localhost:11434/v1", type=str, help="OpenAI兼容接口的base_url")
|
||||||
|
parser.add_argument('--api_key', default='sk-123', type=str, help="OpenAI兼容接口的api_key")
|
||||||
|
parser.add_argument('--api_model', default='jingyaogong/minimind-3:latest', type=str, help="API请求时使用的模型名称")
|
||||||
|
parser.add_argument('--stream', default=1, type=int, help="API模式下是否流式输出(0=否,1=是)")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
model = tokenizer = client = None
|
||||||
|
if args.backend == 'local': model, tokenizer = init_model(args)
|
||||||
|
else: client = OpenAI(api_key=args.api_key, base_url=args.api_base_url)
|
||||||
|
|
||||||
|
input_mode = int(input('[0] 自动测试\n[1] 手动输入\n'))
|
||||||
|
|
||||||
|
cases = [{"prompt": case["prompt"], "tools": get_tools(case["tools"]), "tool_names": case["tools"]} for case in TEST_CASES] if input_mode == 0 else iter(lambda: {"prompt": input('💬: '), "tools": TOOLS, "tool_names": [t["function"]["name"] for t in TOOLS]}, {"prompt": "", "tools": TOOLS, "tool_names": []})
|
||||||
|
for case in cases:
|
||||||
|
if not case["prompt"]: break
|
||||||
|
setup_seed(random.randint(0, 31415926))
|
||||||
|
if input_mode == 0:
|
||||||
|
print(f'📦 可用工具: {case["tool_names"]}\n')
|
||||||
|
print(f'💬: {case["prompt"]}')
|
||||||
|
run_case(case["prompt"], case["tools"], args, model=model, tokenizer=tokenizer, client=client)
|
||||||
|
print('\n' + '-' * 50 + '\n')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
245
scripts/serve_openai_api.py
Normal file
@ -0,0 +1,245 @@
|
|||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
__package__ = "scripts"
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import warnings
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
from threading import Thread
|
||||||
|
from queue import Queue
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
|
||||||
|
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||||
|
from model.model_lora import apply_lora, load_lora
|
||||||
|
|
||||||
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
def init_model(args):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(args.load_from)
|
||||||
|
if 'model' in args.load_from:
|
||||||
|
moe_suffix = '_moe' if args.use_moe else ''
|
||||||
|
ckp = f'../{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
|
||||||
|
model = MiniMindForCausalLM(MiniMindConfig(
|
||||||
|
hidden_size=args.hidden_size,
|
||||||
|
num_hidden_layers=args.num_hidden_layers,
|
||||||
|
max_seq_len=args.max_seq_len,
|
||||||
|
use_moe=bool(args.use_moe),
|
||||||
|
inference_rope_scaling=args.inference_rope_scaling
|
||||||
|
))
|
||||||
|
model.load_state_dict(torch.load(ckp, map_location=device), strict=True)
|
||||||
|
if args.lora_weight != 'None':
|
||||||
|
apply_lora(model)
|
||||||
|
load_lora(model, f'../{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth')
|
||||||
|
else:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
|
||||||
|
print(f'MiniMind模型参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M(illion)')
|
||||||
|
return model.half().eval().to(device), tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRequest(BaseModel):
|
||||||
|
model: str
|
||||||
|
messages: list
|
||||||
|
temperature: float = 0.7
|
||||||
|
top_p: float = 0.92
|
||||||
|
max_tokens: int = 8192
|
||||||
|
stream: bool = True
|
||||||
|
tools: list = []
|
||||||
|
open_thinking: bool = False
|
||||||
|
chat_template_kwargs: dict = None
|
||||||
|
|
||||||
|
def get_open_thinking(self) -> bool:
|
||||||
|
"""兼容多种方式开启 thinking"""
|
||||||
|
if self.open_thinking:
|
||||||
|
return True
|
||||||
|
if self.chat_template_kwargs:
|
||||||
|
return self.chat_template_kwargs.get('open_thinking', False) or \
|
||||||
|
self.chat_template_kwargs.get('enable_thinking', False)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class CustomStreamer(TextStreamer):
|
||||||
|
def __init__(self, tokenizer, queue):
|
||||||
|
super().__init__(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||||
|
self.queue = queue
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
|
def on_finalized_text(self, text: str, stream_end: bool = False):
|
||||||
|
self.queue.put(text)
|
||||||
|
if stream_end:
|
||||||
|
self.queue.put(None)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_response(text):
|
||||||
|
reasoning_content = None
|
||||||
|
think_match = re.search(r'<think>(.*?)</think>', text, re.DOTALL)
|
||||||
|
if think_match:
|
||||||
|
reasoning_content = think_match.group(1).strip()
|
||||||
|
text = re.sub(r'<think>.*?</think>\s*', '', text, flags=re.DOTALL)
|
||||||
|
elif '</think>' in text:
|
||||||
|
parts = text.split('</think>', 1)
|
||||||
|
reasoning_content = parts[0].strip()
|
||||||
|
text = parts[1].strip() if len(parts) > 1 else ''
|
||||||
|
tool_calls = []
|
||||||
|
for i, m in enumerate(re.findall(r'<tool_call>(.*?)</tool_call>', text, re.DOTALL)):
|
||||||
|
try:
|
||||||
|
call = json.loads(m.strip())
|
||||||
|
tool_calls.append({"id": f"call_{int(time.time())}_{i}", "type": "function", "function": {"name": call.get("name", ""), "arguments": json.dumps(call.get("arguments", {}), ensure_ascii=False)}})
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if tool_calls:
|
||||||
|
text = re.sub(r'<tool_call>.*?</tool_call>', '', text, flags=re.DOTALL)
|
||||||
|
return text.strip(), reasoning_content, tool_calls or None
|
||||||
|
|
||||||
|
|
||||||
|
def generate_stream_response(messages, temperature, top_p, max_tokens, tools=None, open_thinking=False):
|
||||||
|
try:
|
||||||
|
new_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=tools or None, open_thinking=open_thinking)[-max_tokens:]
|
||||||
|
inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)
|
||||||
|
|
||||||
|
queue = Queue()
|
||||||
|
streamer = CustomStreamer(tokenizer, queue)
|
||||||
|
|
||||||
|
def _generate():
|
||||||
|
model.generate(
|
||||||
|
inputs.input_ids,
|
||||||
|
max_new_tokens=max_tokens,
|
||||||
|
do_sample=True,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
attention_mask=inputs.attention_mask,
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
|
streamer=streamer
|
||||||
|
)
|
||||||
|
|
||||||
|
Thread(target=_generate).start()
|
||||||
|
|
||||||
|
full_text = ""
|
||||||
|
emitted = 0
|
||||||
|
thinking_ended = not bool(open_thinking)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
text = queue.get()
|
||||||
|
if text is None:
|
||||||
|
break
|
||||||
|
full_text += text
|
||||||
|
|
||||||
|
if not thinking_ended:
|
||||||
|
pos = full_text.find('</think>')
|
||||||
|
if pos >= 0:
|
||||||
|
thinking_ended = True
|
||||||
|
new_r = full_text[emitted:pos]
|
||||||
|
if new_r:
|
||||||
|
yield json.dumps({"choices": [{"delta": {"reasoning_content": new_r}}]}, ensure_ascii=False)
|
||||||
|
emitted = pos + len('</think>')
|
||||||
|
after = full_text[emitted:].lstrip('\n')
|
||||||
|
emitted = len(full_text) - len(after)
|
||||||
|
if after:
|
||||||
|
yield json.dumps({"choices": [{"delta": {"content": after}}]}, ensure_ascii=False)
|
||||||
|
emitted = len(full_text)
|
||||||
|
else:
|
||||||
|
new_r = full_text[emitted:]
|
||||||
|
if new_r:
|
||||||
|
yield json.dumps({"choices": [{"delta": {"reasoning_content": new_r}}]}, ensure_ascii=False)
|
||||||
|
emitted = len(full_text)
|
||||||
|
else:
|
||||||
|
new_c = full_text[emitted:]
|
||||||
|
if new_c:
|
||||||
|
yield json.dumps({"choices": [{"delta": {"content": new_c}}]}, ensure_ascii=False)
|
||||||
|
emitted = len(full_text)
|
||||||
|
|
||||||
|
_, _, tool_calls = parse_response(full_text)
|
||||||
|
if tool_calls:
|
||||||
|
yield json.dumps({"choices": [{"delta": {"tool_calls": tool_calls}}]}, ensure_ascii=False)
|
||||||
|
yield json.dumps({"choices": [{"delta": {}, "finish_reason": "tool_calls" if tool_calls else "stop"}]}, ensure_ascii=False)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
yield json.dumps({"error": str(e)})
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/chat/completions")
|
||||||
|
async def chat_completions(request: ChatRequest):
|
||||||
|
try:
|
||||||
|
if request.stream:
|
||||||
|
return StreamingResponse(
|
||||||
|
(f"data: {chunk}\n\n" for chunk in generate_stream_response(
|
||||||
|
messages=request.messages,
|
||||||
|
temperature=request.temperature,
|
||||||
|
top_p=request.top_p,
|
||||||
|
max_tokens=request.max_tokens,
|
||||||
|
tools=request.tools,
|
||||||
|
open_thinking=request.get_open_thinking()
|
||||||
|
)),
|
||||||
|
media_type="text/event-stream"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_prompt = tokenizer.apply_chat_template(
|
||||||
|
request.messages,
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tools=request.tools or None,
|
||||||
|
open_thinking=request.get_open_thinking()
|
||||||
|
)[-request.max_tokens:]
|
||||||
|
inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)
|
||||||
|
with torch.no_grad():
|
||||||
|
generated_ids = model.generate(
|
||||||
|
inputs["input_ids"],
|
||||||
|
max_length=inputs["input_ids"].shape[1] + request.max_tokens,
|
||||||
|
do_sample=True,
|
||||||
|
attention_mask=inputs["attention_mask"],
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
|
top_p=request.top_p,
|
||||||
|
temperature=request.temperature
|
||||||
|
)
|
||||||
|
answer = tokenizer.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
|
||||||
|
content, reasoning_content, tool_calls = parse_response(answer)
|
||||||
|
message = {"role": "assistant", "content": content}
|
||||||
|
if reasoning_content:
|
||||||
|
message["reasoning_content"] = reasoning_content
|
||||||
|
if tool_calls:
|
||||||
|
message["tool_calls"] = tool_calls
|
||||||
|
return {
|
||||||
|
"id": f"chatcmpl-{int(time.time())}",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": int(time.time()),
|
||||||
|
"model": "minimind",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": message,
|
||||||
|
"finish_reason": "tool_calls" if tool_calls else "stop"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Server for MiniMind")
|
||||||
|
parser.add_argument('--load_from', default='../model', type=str, help="模型加载路径(model=原生torch权重,其他路径=transformers格式)")
|
||||||
|
parser.add_argument('--save_dir', default='out', type=str, help="模型权重目录")
|
||||||
|
parser.add_argument('--weight', default='full_sft', type=str, help="权重名称前缀(pretrain, full_sft, dpo, reason, ppo_actor, grpo, spo)")
|
||||||
|
parser.add_argument('--lora_weight', default='None', type=str, help="LoRA权重名称(None表示不使用,可选:lora_identity, lora_medical)")
|
||||||
|
parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度")
|
||||||
|
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||||
|
parser.add_argument('--max_seq_len', default=8192, type=int, help="最大序列长度")
|
||||||
|
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||||
|
parser.add_argument('--inference_rope_scaling', default=False, action='store_true', help="启用RoPE位置编码外推(4倍,仅解决位置编码问题)")
|
||||||
|
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="运行设备")
|
||||||
|
args = parser.parse_args()
|
||||||
|
device = args.device
|
||||||
|
model, tokenizer = init_model(args)
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=8998)
|
||||||
420
scripts/web_demo.py
Normal file
@ -0,0 +1,420 @@
|
|||||||
|
import random
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import streamlit as st
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
||||||
|
|
||||||
|
st.set_page_config(page_title="MiniMind", initial_sidebar_state="collapsed")
|
||||||
|
|
||||||
|
st.markdown("""
|
||||||
|
<style>
|
||||||
|
/* 添加操作按钮样式 */
|
||||||
|
.stButton button {
|
||||||
|
border-radius: 50% !important; /* 改为圆形 */
|
||||||
|
width: 32px !important; /* 固定宽度 */
|
||||||
|
height: 32px !important; /* 固定高度 */
|
||||||
|
padding: 0 !important; /* 移除内边距 */
|
||||||
|
background-color: transparent !important;
|
||||||
|
border: 1px solid #ddd !important;
|
||||||
|
display: flex !important;
|
||||||
|
align-items: center !important;
|
||||||
|
justify-content: center !important;
|
||||||
|
font-size: 14px !important;
|
||||||
|
color: #666 !important; /* 更柔和的颜色 */
|
||||||
|
margin: 5px 10px 5px 0 !important; /* 调整按钮间距 */
|
||||||
|
}
|
||||||
|
.stButton button:hover {
|
||||||
|
border-color: #999 !important;
|
||||||
|
color: #333 !important;
|
||||||
|
background-color: #f5f5f5 !important;
|
||||||
|
}
|
||||||
|
.stMainBlockContainer > div:first-child {
|
||||||
|
margin-top: -50px !important;
|
||||||
|
}
|
||||||
|
.stApp > div:last-child {
|
||||||
|
margin-bottom: -35px !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 重置按钮基础样式 */
|
||||||
|
.stButton > button {
|
||||||
|
all: unset !important; /* 重置所有默认样式 */
|
||||||
|
box-sizing: border-box !important;
|
||||||
|
border-radius: 50% !important;
|
||||||
|
width: 18px !important;
|
||||||
|
height: 18px !important;
|
||||||
|
min-width: 18px !important;
|
||||||
|
min-height: 18px !important;
|
||||||
|
max-width: 18px !important;
|
||||||
|
max-height: 18px !important;
|
||||||
|
padding: 0 !important;
|
||||||
|
background-color: transparent !important;
|
||||||
|
border: 1px solid #ddd !important;
|
||||||
|
display: flex !important;
|
||||||
|
align-items: center !important;
|
||||||
|
justify-content: center !important;
|
||||||
|
font-size: 14px !important;
|
||||||
|
color: #888 !important;
|
||||||
|
cursor: pointer !important;
|
||||||
|
transition: all 0.2s ease !important;
|
||||||
|
margin: 0 2px !important; /* 调整这里的 margin 值 */
|
||||||
|
}
|
||||||
|
|
||||||
|
</style>
|
||||||
|
""", unsafe_allow_html=True)
|
||||||
|
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
# 多语言文本
|
||||||
|
LANG_TEXTS = {
|
||||||
|
'zh': {
|
||||||
|
'settings': '模型设定调整',
|
||||||
|
'history_rounds': '历史对话轮次',
|
||||||
|
'max_length': '最大生成长度',
|
||||||
|
'temperature': '温度',
|
||||||
|
'thinking': '思考',
|
||||||
|
'tools': '工具',
|
||||||
|
'language': '语言',
|
||||||
|
'send': '给 MiniMind 发送消息',
|
||||||
|
'disclaimer': 'AI 生成内容可能存在错误,请仔细核实',
|
||||||
|
'think_tip': '自适应思考,目前多轮对话或Tool Call共存时思考不稳定',
|
||||||
|
'tool_select': '工具选择(最多4个)',
|
||||||
|
},
|
||||||
|
'en': {
|
||||||
|
'settings': 'Model Settings',
|
||||||
|
'history_rounds': 'History Rounds',
|
||||||
|
'max_length': 'Max Length',
|
||||||
|
'temperature': 'Temperature',
|
||||||
|
'thinking': 'Thinking',
|
||||||
|
'tools': 'Tools',
|
||||||
|
'language': 'Language',
|
||||||
|
'send': 'Send a message to MiniMind',
|
||||||
|
'disclaimer': 'AI-generated content may be inaccurate, please verify',
|
||||||
|
'think_tip': 'Adaptive thinking; may be unstable with multi-turn or Tool Call',
|
||||||
|
'tool_select': 'Tool Selection (max 4)',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_text(key):
|
||||||
|
lang = st.session_state.get('lang', 'en')
|
||||||
|
return LANG_TEXTS.get(lang, {}).get(key, LANG_TEXTS['zh'].get(key, key))
|
||||||
|
|
||||||
|
# 工具定义
|
||||||
|
TOOLS = [
|
||||||
|
{"type": "function", "function": {"name": "calculate_math", "description": "计算数学表达式", "parameters": {"type": "object", "properties": {"expression": {"type": "string", "description": "数学表达式"}}, "required": ["expression"]}}},
|
||||||
|
{"type": "function", "function": {"name": "get_current_time", "description": "获取当前时间", "parameters": {"type": "object", "properties": {"timezone": {"type": "string", "default": "Asia/Shanghai"}}, "required": []}}},
|
||||||
|
{"type": "function", "function": {"name": "random_number", "description": "生成随机数", "parameters": {"type": "object", "properties": {"min": {"type": "integer"}, "max": {"type": "integer"}}, "required": ["min", "max"]}}},
|
||||||
|
{"type": "function", "function": {"name": "text_length", "description": "计算文本长度", "parameters": {"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]}}},
|
||||||
|
{"type": "function", "function": {"name": "unit_converter", "description": "单位转换", "parameters": {"type": "object", "properties": {"value": {"type": "number"}, "from_unit": {"type": "string"}, "to_unit": {"type": "string"}}, "required": ["value", "from_unit", "to_unit"]}}},
|
||||||
|
{"type": "function", "function": {"name": "get_current_weather", "description": "获取天气", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}}},
|
||||||
|
{"type": "function", "function": {"name": "get_exchange_rate", "description": "获取汇率", "parameters": {"type": "object", "properties": {"from_currency": {"type": "string"}, "to_currency": {"type": "string"}}, "required": ["from_currency", "to_currency"]}}},
|
||||||
|
{"type": "function", "function": {"name": "translate_text", "description": "翻译文本", "parameters": {"type": "object", "properties": {"text": {"type": "string"}, "target_lang": {"type": "string"}}, "required": ["text", "target_lang"]}}},
|
||||||
|
]
|
||||||
|
|
||||||
|
TOOL_SHORT_NAMES = {
|
||||||
|
'calculate_math': '数学', 'get_current_time': '时间', 'random_number': '随机',
|
||||||
|
'text_length': '字数', 'unit_converter': '单位', 'get_current_weather': '天气',
|
||||||
|
'get_exchange_rate': '汇率', 'translate_text': '翻译'
|
||||||
|
}
|
||||||
|
|
||||||
|
def execute_tool(tool_name, args):
|
||||||
|
import datetime
|
||||||
|
try:
|
||||||
|
if tool_name == 'calculate_math':
|
||||||
|
return {"result": eval(args.get('expression', '0'))}
|
||||||
|
elif tool_name == 'get_current_time':
|
||||||
|
tz = args.get('timezone', 'Asia/Shanghai')
|
||||||
|
return {"result": datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||||
|
elif tool_name == 'random_number':
|
||||||
|
return {"result": random.randint(args.get('min', 0), args.get('max', 100))}
|
||||||
|
elif tool_name == 'text_length':
|
||||||
|
return {"result": len(args.get('text', ''))}
|
||||||
|
elif tool_name == 'unit_converter':
|
||||||
|
return {"result": f"{args.get('value', 0)} {args.get('from_unit', '')} = ? {args.get('to_unit', '')}"}
|
||||||
|
elif tool_name == 'get_current_weather':
|
||||||
|
return {"result": f"{args.get('city', 'Unknown')}: 晴, 7~10°C"}
|
||||||
|
elif tool_name == 'get_exchange_rate':
|
||||||
|
return {"result": f"1 {args.get('from_currency', 'USD')} = 7.2 {args.get('to_currency', 'CNY')}"}
|
||||||
|
elif tool_name == 'translate_text':
|
||||||
|
return {"result": f"[翻译结果]: hello world"}
|
||||||
|
return {"result": "Unknown tool"}
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
def process_assistant_content(content, is_streaming=False):
|
||||||
|
# 处理tool_call标签,格式化显示
|
||||||
|
if '<tool_call>' in content:
|
||||||
|
def format_tool_call(match):
|
||||||
|
try:
|
||||||
|
tc = json.loads(match.group(1))
|
||||||
|
name = tc.get('name', 'unknown')
|
||||||
|
args = tc.get('arguments', {})
|
||||||
|
return f'<div style="background: rgba(80, 110, 150, 0.20); border: 1px solid rgba(140, 170, 210, 0.30); padding: 10px 12px; border-radius: 12px; margin: 6px 0;"><div style="font-size:12px;opacity:.75;display:block;margin:0 0 6px 0;line-height:1;">ToolCalling</div><div><b>{name}</b>: {json.dumps(args, ensure_ascii=False)}</div></div>'
|
||||||
|
except:
|
||||||
|
return match.group(0)
|
||||||
|
content = re.sub(r'<tool_call>(.*?)</tool_call>', format_tool_call, content, flags=re.DOTALL)
|
||||||
|
|
||||||
|
# 流式生成且开启思考时,一开始就放到折叠里
|
||||||
|
if is_streaming and st.session_state.get('enable_thinking', False) and '</think>' not in content and '<think>' not in content:
|
||||||
|
m = re.search(r'(\n\n(?:我是|您好|你好)[^\n]*)', content)
|
||||||
|
if m and m.start(1) > 5:
|
||||||
|
i = m.start(1)
|
||||||
|
think_part = content[:i]
|
||||||
|
answer_part = content[i:]
|
||||||
|
return f'<details open style="border-left: 2px solid #666; padding-left: 12px; margin: 8px 0;"><summary style="cursor: pointer; color: #888;">已思考</summary><div style="color: #aaa; font-size: 0.95em; margin-top: 8px; max-height: 100px; overflow-y: auto;">{think_part.strip()}</div></details>{answer_part}'
|
||||||
|
elif len(content) > 5:
|
||||||
|
return f'<details open style="border-left: 2px solid #666; padding-left: 12px; margin: 8px 0;"><summary style="cursor: pointer; color: #888;">思考中...</summary><div style="color: #aaa; font-size: 0.95em; margin-top: 8px; max-height: 100px; overflow-y: auto; display: flex; flex-direction: column-reverse;"><div style="margin-bottom: auto;">{content.strip().replace(chr(10), "<br>")}</div></div></details>'
|
||||||
|
|
||||||
|
if '<think>' in content and '</think>' in content:
|
||||||
|
def format_think(match):
|
||||||
|
think_content = match.group(2)
|
||||||
|
if think_content.replace('\n', '').strip(): # 不是全换行
|
||||||
|
return f'<details open style="border-left: 2px solid #666; padding-left: 12px; margin: 8px 0;"><summary style="cursor: pointer; color: #888;">已思考</summary><div style="color: #aaa; font-size: 0.95em; margin-top: 8px; max-height: 100px; overflow-y: auto;">{think_content.strip()}</div></details>'
|
||||||
|
return ''
|
||||||
|
content = re.sub(r'(<think>)(.*?)(</think>)', format_think, content, flags=re.DOTALL)
|
||||||
|
|
||||||
|
if '<think>' in content and '</think>' not in content:
|
||||||
|
def format_think_in_progress(match):
|
||||||
|
tc = match.group(1)
|
||||||
|
return f'<details open style="border-left: 2px solid #666; padding-left: 12px; margin: 8px 0;"><summary style="cursor: pointer; color: #888;">思考中...</summary><div style="color: #aaa; font-size: 0.95em; margin-top: 8px; max-height: 100px; overflow-y: auto; display: flex; flex-direction: column-reverse;"><div style="margin-bottom: auto;">{tc.strip().replace(chr(10), "<br>")}</div></div></details>'
|
||||||
|
content = re.sub(r'<think>(.*?)$', format_think_in_progress, content, flags=re.DOTALL)
|
||||||
|
|
||||||
|
if '<think>' not in content and '</think>' in content:
|
||||||
|
def format_think_no_start(match):
|
||||||
|
think_content = match.group(1)
|
||||||
|
if think_content.replace('\n', '').strip():
|
||||||
|
return f'<details open style="border-left: 2px solid #666; padding-left: 12px; margin: 8px 0;"><summary style="cursor: pointer; color: #888;">已思考</summary><div style="color: #aaa; font-size: 0.95em; margin-top: 8px; max-height: 100px; overflow-y: auto;">{think_content.strip()}</div></details>'
|
||||||
|
return ''
|
||||||
|
content = re.sub(r'(.*?)</think>', format_think_no_start, content, flags=re.DOTALL)
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
@st.cache_resource
|
||||||
|
def load_model_tokenizer(model_path):
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
trust_remote_code=True
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
trust_remote_code=True
|
||||||
|
)
|
||||||
|
model = model.half().eval().to(device)
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def clear_chat_messages():
|
||||||
|
del st.session_state.messages
|
||||||
|
del st.session_state.chat_messages
|
||||||
|
|
||||||
|
|
||||||
|
def init_chat_messages():
|
||||||
|
if "messages" in st.session_state:
|
||||||
|
for i, message in enumerate(st.session_state.messages):
|
||||||
|
if message["role"] == "assistant":
|
||||||
|
st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
|
||||||
|
else:
|
||||||
|
st.markdown(
|
||||||
|
f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: #3d4450; border-radius: 22px; color: white;">{message["content"]}</div></div>',
|
||||||
|
unsafe_allow_html=True)
|
||||||
|
|
||||||
|
else:
|
||||||
|
st.session_state.messages = []
|
||||||
|
st.session_state.chat_messages = []
|
||||||
|
|
||||||
|
return st.session_state.messages
|
||||||
|
|
||||||
|
def regenerate_answer(index):
|
||||||
|
st.session_state.messages.pop()
|
||||||
|
st.session_state.chat_messages.pop()
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
|
||||||
|
# 动态扫描模型目录
|
||||||
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
MODEL_PATHS = {}
|
||||||
|
for d in sorted(os.listdir(script_dir), reverse=True):
|
||||||
|
full_path = os.path.join(script_dir, d)
|
||||||
|
if os.path.isdir(full_path) and not d.startswith('.') and not d.startswith('_'):
|
||||||
|
if any(f.endswith(('.bin', '.safetensors', '.pt')) or os.path.exists(os.path.join(full_path, 'model.safetensors.index.json')) for f in os.listdir(full_path) if os.path.isfile(os.path.join(full_path, f))):
|
||||||
|
MODEL_PATHS[d] = [d, d]
|
||||||
|
if not MODEL_PATHS:
|
||||||
|
MODEL_PATHS = {"No models found": ["", "No models"]}
|
||||||
|
|
||||||
|
# 模型选择
|
||||||
|
selected_model = st.sidebar.selectbox('Model', list(MODEL_PATHS.keys()), index=0)
|
||||||
|
model_path = MODEL_PATHS[selected_model][0]
|
||||||
|
slogan = f"我是 {MODEL_PATHS[selected_model][1]},有什么可以帮你的?" if st.session_state.get('lang', 'en') == 'zh' else f"I am {MODEL_PATHS[selected_model][1]}, how can I help you?"
|
||||||
|
|
||||||
|
st.sidebar.markdown('<hr style="margin: 12px 0 16px 0;">', unsafe_allow_html=True)
|
||||||
|
|
||||||
|
# 语言选择
|
||||||
|
lang_options = {'中文': 'zh', 'English': 'en'}
|
||||||
|
current_lang = st.session_state.get('lang', 'en')
|
||||||
|
lang_index = 0 if current_lang == 'zh' else 1
|
||||||
|
lang_label = st.sidebar.radio('Language / 语言', list(lang_options.keys()), index=lang_index, horizontal=True)
|
||||||
|
if lang_options[lang_label] != current_lang:
|
||||||
|
st.session_state.lang = lang_options[lang_label]
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
st.sidebar.markdown('<hr style="margin: 12px 0 16px 0;">', unsafe_allow_html=True)
|
||||||
|
|
||||||
|
# 参数设置
|
||||||
|
st.session_state.history_chat_num = st.sidebar.slider(get_text('history_rounds'), 0, 8, 0, step=2)
|
||||||
|
st.session_state.max_new_tokens = st.sidebar.slider(get_text('max_length'), 256, 8192, 8192, step=1)
|
||||||
|
st.session_state.temperature = st.sidebar.slider(get_text('temperature'), 0.6, 1.2, 0.90, step=0.01)
|
||||||
|
|
||||||
|
st.sidebar.markdown('<hr style="margin: 12px 0 16px 0;">', unsafe_allow_html=True)
|
||||||
|
|
||||||
|
# 功能开关
|
||||||
|
st.session_state.enable_thinking = st.sidebar.checkbox(get_text('thinking'), value=False, help=get_text('think_tip'))
|
||||||
|
st.session_state.selected_tools = []
|
||||||
|
with st.sidebar.expander(get_text('tools')):
|
||||||
|
st.caption(get_text('tool_select'))
|
||||||
|
selected_count = sum(1 for tool in TOOLS if st.session_state.get(f"tool_{tool['function']['name']}", False))
|
||||||
|
for tool in TOOLS:
|
||||||
|
name = tool['function']['name']
|
||||||
|
short_name = TOOL_SHORT_NAMES.get(name, name)
|
||||||
|
checked = st.checkbox(short_name, key=f"tool_{name}", disabled=(selected_count >= 4 and not st.session_state.get(f"tool_{name}", False)))
|
||||||
|
if checked and len(st.session_state.selected_tools) < 4:
|
||||||
|
st.session_state.selected_tools.append(name)
|
||||||
|
|
||||||
|
image_url = "https://www.modelscope.cn/api/v1/studio/gongjy/MiniMind/repo?Revision=master&FilePath=images%2Flogo2.png&View=true"
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
f'<div style="display: flex; flex-direction: column; align-items: center; text-align: center; margin: 0; padding: 0;">'
|
||||||
|
'<div style="font-style: italic; font-weight: 900; margin: 0; padding-top: 4px; display: flex; align-items: center; justify-content: center; flex-wrap: wrap; width: 100%;">'
|
||||||
|
f'<img src="{image_url}" style="width: 40px; height: 40px; "> '
|
||||||
|
f'<span style="font-size: 26px; margin-left: 10px;">{slogan}</span>'
|
||||||
|
'</div>'
|
||||||
|
f'<span style="color: #bbb; font-style: italic; margin-top: 6px; margin-bottom: 10px;">{get_text("disclaimer")}</span>'
|
||||||
|
'</div>',
|
||||||
|
unsafe_allow_html=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_seed(seed):
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
model, tokenizer = load_model_tokenizer(model_path)
|
||||||
|
|
||||||
|
if "messages" not in st.session_state:
|
||||||
|
st.session_state.messages = []
|
||||||
|
st.session_state.chat_messages = []
|
||||||
|
|
||||||
|
messages = st.session_state.messages
|
||||||
|
|
||||||
|
for i, message in enumerate(messages):
|
||||||
|
if message["role"] == "assistant":
|
||||||
|
st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
|
||||||
|
else:
|
||||||
|
st.markdown(
|
||||||
|
f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: #3d4450; border-radius: 22px; color: white;">{message["content"]}</div></div>',
|
||||||
|
unsafe_allow_html=True)
|
||||||
|
|
||||||
|
prompt = st.chat_input(key="input", placeholder=get_text('send'))
|
||||||
|
|
||||||
|
if hasattr(st.session_state, 'regenerate') and st.session_state.regenerate:
|
||||||
|
prompt = st.session_state.last_user_message
|
||||||
|
regenerate_index = st.session_state.regenerate_index
|
||||||
|
delattr(st.session_state, 'regenerate')
|
||||||
|
delattr(st.session_state, 'last_user_message')
|
||||||
|
delattr(st.session_state, 'regenerate_index')
|
||||||
|
|
||||||
|
if prompt:
|
||||||
|
st.markdown(
|
||||||
|
f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: #3d4450; border-radius: 22px; color: white;">{prompt}</div></div>',
|
||||||
|
unsafe_allow_html=True)
|
||||||
|
messages.append({"role": "user", "content": prompt[-st.session_state.max_new_tokens:]})
|
||||||
|
st.session_state.chat_messages.append({"role": "user", "content": prompt[-st.session_state.max_new_tokens:]})
|
||||||
|
|
||||||
|
placeholder = st.empty()
|
||||||
|
|
||||||
|
random_seed = random.randint(0, 2 ** 32 - 1)
|
||||||
|
setup_seed(random_seed)
|
||||||
|
|
||||||
|
tools = [t for t in TOOLS if t['function']['name'] in st.session_state.get('selected_tools', [])] or None
|
||||||
|
sys_prompt = [] if tools else [{"role": "system", "content": "你是MiniMind,一个乐于助人、知识渊博的AI助手。请用完整且友好的方式回答用户问题。"}]
|
||||||
|
st.session_state.chat_messages = sys_prompt + st.session_state.chat_messages[-(st.session_state.history_chat_num + 1):]
|
||||||
|
template_kwargs = {"tokenize": False, "add_generation_prompt": True}
|
||||||
|
if st.session_state.get('enable_thinking', False):
|
||||||
|
template_kwargs["open_thinking"] = True
|
||||||
|
if tools:
|
||||||
|
template_kwargs["tools"] = tools
|
||||||
|
new_prompt = tokenizer.apply_chat_template(st.session_state.chat_messages, **template_kwargs)
|
||||||
|
|
||||||
|
inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)
|
||||||
|
|
||||||
|
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||||
|
generation_kwargs = {
|
||||||
|
"input_ids": inputs.input_ids,
|
||||||
|
"max_length": inputs.input_ids.shape[1] + st.session_state.max_new_tokens,
|
||||||
|
"num_return_sequences": 1,
|
||||||
|
"do_sample": True,
|
||||||
|
"attention_mask": inputs.attention_mask,
|
||||||
|
"pad_token_id": tokenizer.pad_token_id,
|
||||||
|
"eos_token_id": tokenizer.eos_token_id,
|
||||||
|
"temperature": st.session_state.temperature,
|
||||||
|
"top_p": 0.85,
|
||||||
|
"streamer": streamer,
|
||||||
|
}
|
||||||
|
|
||||||
|
Thread(target=model.generate, kwargs=generation_kwargs).start()
|
||||||
|
|
||||||
|
answer = ""
|
||||||
|
for new_text in streamer:
|
||||||
|
answer += new_text
|
||||||
|
placeholder.markdown(process_assistant_content(answer, is_streaming=True), unsafe_allow_html=True)
|
||||||
|
|
||||||
|
full_answer = answer
|
||||||
|
for _ in range(16):
|
||||||
|
tool_calls = re.findall(r'<tool_call>(.*?)</tool_call>', answer, re.DOTALL)
|
||||||
|
if not tool_calls:
|
||||||
|
break
|
||||||
|
st.session_state.chat_messages.append({"role": "assistant", "content": answer})
|
||||||
|
tool_results = []
|
||||||
|
for tc_str in tool_calls:
|
||||||
|
try:
|
||||||
|
tc = json.loads(tc_str.strip())
|
||||||
|
result = execute_tool(tc.get('name', ''), tc.get('arguments', {}))
|
||||||
|
st.session_state.chat_messages.append({"role": "tool", "content": json.dumps(result, ensure_ascii=False)})
|
||||||
|
tool_results.append(f'<div style="background: rgba(90, 130, 110, 0.20); border: 1px solid rgba(150, 200, 170, 0.30); padding: 10px 12px; border-radius: 12px; margin: 6px 0;"><div style="font-size:12px;opacity:.75;display:block;margin:0 0 6px 0;line-height:1;">ToolCalled</div><div><b>{tc.get("name", "")}</b>: {json.dumps(result, ensure_ascii=False)}</div></div>')
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
full_answer += "\n" + "\n".join(tool_results) + "\n"
|
||||||
|
placeholder.markdown(process_assistant_content(full_answer, is_streaming=True), unsafe_allow_html=True)
|
||||||
|
new_prompt = tokenizer.apply_chat_template(st.session_state.chat_messages, **template_kwargs)
|
||||||
|
inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)
|
||||||
|
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||||
|
generation_kwargs["input_ids"] = inputs.input_ids
|
||||||
|
generation_kwargs["attention_mask"] = inputs.attention_mask
|
||||||
|
generation_kwargs["max_length"] = inputs.input_ids.shape[1] + st.session_state.max_new_tokens
|
||||||
|
generation_kwargs["streamer"] = streamer
|
||||||
|
Thread(target=model.generate, kwargs=generation_kwargs).start()
|
||||||
|
answer = ""
|
||||||
|
for new_text in streamer:
|
||||||
|
answer += new_text
|
||||||
|
placeholder.markdown(process_assistant_content(full_answer + answer, is_streaming=True), unsafe_allow_html=True)
|
||||||
|
full_answer += answer
|
||||||
|
answer = full_answer
|
||||||
|
|
||||||
|
messages.append({"role": "assistant", "content": answer})
|
||||||
|
st.session_state.chat_messages.append({"role": "assistant", "content": answer})
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
224
trainer/rollout_engine.py
Normal file
@ -0,0 +1,224 @@
|
|||||||
|
"""
|
||||||
|
# 如果使用sglang加速,需通过以下命令首先启动(transformers格式)模型:
|
||||||
|
python -m sglang.launch_server --model-path ./minimind-3 --attention-backend triton --host 0.0.0.0 --port 8998
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
__package__ = "trainer"
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
# ===== 计算每个 token 的 logprob =====
|
||||||
|
def compute_per_token_logps(model, input_ids: Tensor, n_keep: int, attention_mask: Optional[Tensor] = None) -> Tensor:
|
||||||
|
if n_keep <= 0:
|
||||||
|
return input_ids.new_empty((input_ids.size(0), 0), dtype=torch.float32)
|
||||||
|
unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
|
||||||
|
input_ids = input_ids.detach().clone() if input_ids.is_inference() else input_ids
|
||||||
|
logits = unwrapped(input_ids, attention_mask=attention_mask, logits_to_keep=n_keep + 1).logits[:, :-1, :]
|
||||||
|
per_token_logps = []
|
||||||
|
for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
|
||||||
|
ids_row = ids_row.detach().clone() if ids_row.is_inference() else ids_row
|
||||||
|
per_token_logps.append(
|
||||||
|
torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1)
|
||||||
|
)
|
||||||
|
return torch.stack(per_token_logps)
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Rollout 结果 =====
|
||||||
|
@dataclass
|
||||||
|
class RolloutResult:
|
||||||
|
output_ids: Tensor
|
||||||
|
completion_ids: Tensor
|
||||||
|
per_token_logps: Tensor
|
||||||
|
completions: List[str]
|
||||||
|
prompt_lens: Tensor
|
||||||
|
completion_mask: Tensor
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Rollout 引擎抽象基类 =====
|
||||||
|
class RolloutEngine(ABC):
|
||||||
|
tokenizer = None
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def rollout(self, prompt_ids: Tensor, attention_mask: Tensor, num_generations: int, max_new_tokens: int, temperature: float = 0.8) -> RolloutResult:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update_policy(self, model: torch.nn.Module):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ===== PyTorch 原生推理引擎 =====
|
||||||
|
class TorchRolloutEngine(RolloutEngine):
|
||||||
|
def __init__(self, policy_model: torch.nn.Module, tokenizer, device: str = "cuda", autocast_ctx=None):
|
||||||
|
self.policy_model = policy_model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.device = device
|
||||||
|
self.autocast_ctx = autocast_ctx
|
||||||
|
|
||||||
|
def rollout(self, prompt_ids: Tensor, attention_mask: Tensor, num_generations: int, max_new_tokens: int, temperature: float = 0.8) -> RolloutResult:
|
||||||
|
model = self.policy_model.module if isinstance(self.policy_model, DistributedDataParallel) else self.policy_model
|
||||||
|
ctx = self.autocast_ctx if self.autocast_ctx else nullcontext()
|
||||||
|
with torch.no_grad(), ctx:
|
||||||
|
output_ids = model.generate(
|
||||||
|
input_ids=prompt_ids.repeat_interleave(num_generations, dim=0),
|
||||||
|
attention_mask=attention_mask.repeat_interleave(num_generations, dim=0),
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
do_sample=True,
|
||||||
|
temperature=temperature,
|
||||||
|
num_return_sequences=1,
|
||||||
|
pad_token_id=self.tokenizer.pad_token_id,
|
||||||
|
eos_token_id=self.tokenizer.eos_token_id,
|
||||||
|
) # [B*num_gen, P+R]
|
||||||
|
prompt_len = prompt_ids.size(1)
|
||||||
|
completion_ids = output_ids[:, prompt_len:] # [B*num_gen, R]
|
||||||
|
full_mask = (output_ids != self.tokenizer.pad_token_id).long()
|
||||||
|
per_token_logps = compute_per_token_logps(self.policy_model, output_ids, completion_ids.size(1), attention_mask=full_mask)
|
||||||
|
completions = self.tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
|
||||||
|
return RolloutResult(output_ids, completion_ids, per_token_logps, completions,
|
||||||
|
prompt_ids.new_full((output_ids.size(0),), prompt_len),
|
||||||
|
attention_mask.new_ones(output_ids.size(0), completion_ids.size(1)))
|
||||||
|
|
||||||
|
def update_policy(self, model: torch.nn.Module):
|
||||||
|
self.policy_model = model
|
||||||
|
|
||||||
|
|
||||||
|
# ===== SGLang HTTP API 推理引擎 =====
|
||||||
|
class SGLangRolloutEngine(RolloutEngine):
|
||||||
|
def __init__(self, base_url: str, model_path: str, shared_ckpt_path: str = "./sglang_ckpt", timeout: int = 120):
|
||||||
|
self.base_url = base_url.rstrip('/')
|
||||||
|
self.shared_ckpt_path = shared_ckpt_path
|
||||||
|
self.timeout = timeout
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
self.http = requests
|
||||||
|
|
||||||
|
def rollout(self, prompt_ids: Tensor, attention_mask: Tensor, num_generations: int, max_new_tokens: int, temperature: float = 0.8) -> RolloutResult:
|
||||||
|
# 去除左侧 padding tokens,只保留有效 token
|
||||||
|
input_ids_list = []
|
||||||
|
for ids, mask in zip(prompt_ids, attention_mask):
|
||||||
|
valid_ids = ids[mask.bool()].tolist()
|
||||||
|
input_ids_list.append(valid_ids)
|
||||||
|
all_input_ids = [ids for ids in input_ids_list for _ in range(num_generations)]
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"input_ids": all_input_ids,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_new_tokens": max_new_tokens,
|
||||||
|
"stop_token_ids": [self.tokenizer.eos_token_id] if self.tokenizer.eos_token_id else [],
|
||||||
|
},
|
||||||
|
"return_logprob": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = self.http.post(f"{self.base_url}/generate", json=payload, timeout=self.timeout)
|
||||||
|
resp.raise_for_status()
|
||||||
|
|
||||||
|
results = resp.json()
|
||||||
|
if not isinstance(results, list):
|
||||||
|
results = [results]
|
||||||
|
|
||||||
|
all_output_ids, all_completion_ids, all_logprobs = [], [], []
|
||||||
|
completions = []
|
||||||
|
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
meta = result.get("meta_info", {})
|
||||||
|
completion_ids = meta.get("output_ids", result.get("output_ids", []))
|
||||||
|
raw_logprobs = meta.get("output_token_logprobs", [])
|
||||||
|
|
||||||
|
logprobs = []
|
||||||
|
for item in raw_logprobs:
|
||||||
|
if isinstance(item, (list, tuple)) and len(item) >= 1:
|
||||||
|
logprobs.append(item[0])
|
||||||
|
elif isinstance(item, (int, float)):
|
||||||
|
logprobs.append(item)
|
||||||
|
|
||||||
|
if len(logprobs) < len(completion_ids):
|
||||||
|
logprobs = [0.0] * (len(completion_ids) - len(logprobs)) + logprobs
|
||||||
|
elif len(logprobs) > len(completion_ids):
|
||||||
|
logprobs = logprobs[-len(completion_ids):] if completion_ids else []
|
||||||
|
prompt = all_input_ids[i]
|
||||||
|
full_output = prompt + completion_ids
|
||||||
|
all_output_ids.append(full_output)
|
||||||
|
all_completion_ids.append(completion_ids)
|
||||||
|
all_logprobs.append(logprobs)
|
||||||
|
completions.append(self.tokenizer.decode(completion_ids, skip_special_tokens=True))
|
||||||
|
|
||||||
|
device = prompt_ids.device
|
||||||
|
max_comp_len = max(1, max(len(ids) for ids in all_completion_ids))
|
||||||
|
max_out_len = max(len(ids) for ids in all_input_ids) + max_comp_len
|
||||||
|
|
||||||
|
def pad_to_tensor(seqs, max_len, pad_val=0):
|
||||||
|
return torch.tensor([s + [pad_val] * (max_len - len(s)) for s in seqs], device=device)
|
||||||
|
|
||||||
|
pad_id = self.tokenizer.pad_token_id
|
||||||
|
return RolloutResult(
|
||||||
|
output_ids=pad_to_tensor(all_output_ids, max_out_len, pad_val=pad_id),
|
||||||
|
completion_ids=pad_to_tensor(all_completion_ids, max_comp_len, pad_val=pad_id),
|
||||||
|
per_token_logps=pad_to_tensor(all_logprobs, max_comp_len, pad_val=0.0),
|
||||||
|
completions=completions,
|
||||||
|
prompt_lens=torch.tensor([len(ids) for ids in all_input_ids], device=device),
|
||||||
|
completion_mask=torch.tensor([[1] * len(ids) + [0] * (max_comp_len - len(ids)) for ids in all_completion_ids], device=device),
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_policy(self, model: torch.nn.Module):
|
||||||
|
ok = True
|
||||||
|
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||||
|
try:
|
||||||
|
unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
|
||||||
|
unwrapped = getattr(unwrapped, '_orig_mod', unwrapped)
|
||||||
|
abs_path = os.path.abspath(self.shared_ckpt_path)
|
||||||
|
state_dict = {k: v.detach().half().cpu() for k, v in unwrapped.state_dict().items()}
|
||||||
|
unwrapped.save_pretrained(abs_path, state_dict=state_dict, safe_serialization=False)
|
||||||
|
self.tokenizer.save_pretrained(abs_path)
|
||||||
|
resp = self.http.post(f"{self.base_url}/update_weights_from_disk", json={"model_path": abs_path}, timeout=self.timeout)
|
||||||
|
if resp.status_code != 200: print(f"[SGLANG WARNING] update_weights 失败: {resp.status_code}, {resp.text}")
|
||||||
|
ok = resp.status_code == 200
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[SGLANG WARNING] update_weights 异常: {e}"); ok = False
|
||||||
|
if dist.is_initialized():
|
||||||
|
ok_t = torch.tensor(int(ok), device=next(model.parameters()).device)
|
||||||
|
dist.broadcast(ok_t, src=0); dist.barrier(); ok = bool(ok_t.item())
|
||||||
|
if not ok: raise RuntimeError("SGLang update_policy failed")
|
||||||
|
return ok
|
||||||
|
|
||||||
|
def flush_cache(self) -> bool:
|
||||||
|
resp = self.http.post(f"{self.base_url}/flush_cache", timeout=30)
|
||||||
|
return resp.status_code == 200
|
||||||
|
|
||||||
|
def health(self) -> bool:
|
||||||
|
try:
|
||||||
|
resp = self.http.get(f"{self.base_url}/health", timeout=5)
|
||||||
|
return resp.status_code == 200
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# ===== 工厂函数 =====
|
||||||
|
def create_rollout_engine(
|
||||||
|
engine_type: str = "torch",
|
||||||
|
policy_model: torch.nn.Module = None,
|
||||||
|
tokenizer = None,
|
||||||
|
device: str = "cuda",
|
||||||
|
autocast_ctx = None,
|
||||||
|
sglang_base_url: str = None,
|
||||||
|
sglang_model_path: str = None,
|
||||||
|
sglang_shared_path: str = None,
|
||||||
|
) -> RolloutEngine:
|
||||||
|
if engine_type == "torch":
|
||||||
|
return TorchRolloutEngine(policy_model, tokenizer, device, autocast_ctx)
|
||||||
|
elif engine_type == "sglang":
|
||||||
|
return SGLangRolloutEngine(sglang_base_url, sglang_model_path, sglang_shared_path)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的引擎类型: {engine_type}")
|
||||||
488
trainer/train_agent.py
Normal file
@ -0,0 +1,488 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
__package__ = "trainer"
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
|
||||||
|
import re
|
||||||
|
import gc
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
import signal
|
||||||
|
import argparse
|
||||||
|
import warnings
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.distributed as dist
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from torch import optim
|
||||||
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||||
|
from dataset.lm_dataset import AgentRLDataset
|
||||||
|
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model, LMForRewardModel
|
||||||
|
from trainer.rollout_engine import create_rollout_engine, compute_per_token_logps
|
||||||
|
|
||||||
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
|
# ================================ 工具与 Reward = Start ================================
|
||||||
|
|
||||||
|
def rep_penalty(text, n=3, cap=0.5):
|
||||||
|
toks = re.findall(r"\w+|[^\w\s]", text.lower())
|
||||||
|
grams = [tuple(toks[i:i + n]) for i in range(len(toks) - n + 1)]
|
||||||
|
return min(cap, (len(grams) - len(set(grams))) * cap * 2 / len(grams)) if grams else 0.0
|
||||||
|
|
||||||
|
# ======== 工具定义 ========
|
||||||
|
TOOLS = [
|
||||||
|
{"type": "function", "function": {"name": "calculate_math", "description": "计算数学表达式", "parameters": {"type": "object", "properties": {"expression": {"type": "string"}}, "required": ["expression"]}}},
|
||||||
|
{"type": "function", "function": {"name": "unit_converter", "description": "单位换算", "parameters": {"type": "object", "properties": {"value": {"type": "number"}, "from_unit": {"type": "string"}, "to_unit": {"type": "string"}}, "required": ["value", "from_unit", "to_unit"]}}},
|
||||||
|
{"type": "function", "function": {"name": "get_current_weather", "description": "获取天气", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}},
|
||||||
|
{"type": "function", "function": {"name": "get_current_time", "description": "获取时间", "parameters": {"type": "object", "properties": {"timezone": {"type": "string", "default": "Asia/Shanghai"}}, "required": []}}},
|
||||||
|
{"type": "function", "function": {"name": "get_exchange_rate", "description": "查询汇率", "parameters": {"type": "object", "properties": {"from_currency": {"type": "string"}, "to_currency": {"type": "string"}}, "required": ["from_currency", "to_currency"]}}},
|
||||||
|
{"type": "function", "function": {"name": "translate_text", "description": "翻译文本", "parameters": {"type": "object", "properties": {"text": {"type": "string"}, "target_language": {"type": "string"}}, "required": ["text", "target_language"]}}},
|
||||||
|
]
|
||||||
|
|
||||||
|
# ======== 模拟数据 ========
|
||||||
|
WEATHER_DATA = {"北京": ("28°C", "晴"), "上海": ("15°C", "多云"), "广州": ("32°C", "闷热"), "深圳": ("30°C", "晴"), "杭州": ("22°C", "阴"), "成都": ("18°C", "小雨"), "武汉": ("25°C", "多云"), "南京": ("20°C", "晴"), "西安": ("16°C", "大风"), "重庆": ("26°C", "阴"), "Tokyo": ("12°C", "晴"), "New York": ("8°C", "多云"), "London": ("5°C", "小雨"), "Paris": ("10°C", "阴"), "Sydney": ("25°C", "晴朗")}
|
||||||
|
TIME_DATA = {"Asia/Shanghai": "2025-03-07 14:30:00", "America/New_York": "2025-03-07 01:30:00", "Europe/London": "2025-03-07 06:30:00", "Asia/Tokyo": "2025-03-07 15:30:00", "Europe/Paris": "2025-03-07 07:30:00", "Australia/Sydney": "2025-03-07 17:30:00"}
|
||||||
|
EXCHANGE_DATA = {("USD", "CNY"): 7.21, ("EUR", "CNY"): 7.85, ("GBP", "CNY"): 9.12, ("JPY", "CNY"): 0.048, ("USD", "EUR"): 0.92, ("USD", "GBP"): 0.79, ("CNY", "JPY"): 20.83, ("AUD", "CNY"): 4.72}
|
||||||
|
TRANSLATE_DATA = {("你好世界", "english"): "Hello World", ("Good morning", "chinese"): "早上好", ("今天天气真好", "english"): "The weather is nice today", ("I love programming", "chinese"): "我喜欢编程", ("机器学习很有趣", "english"): "Machine learning is interesting", ("Happy birthday", "chinese"): "生日快乐"}
|
||||||
|
UNIT_DATA = {"km_miles": 0.621371, "miles_km": 1.60934, "kg_pounds": 2.20462, "pounds_kg": 0.453592, "meters_feet": 3.28084, "feet_meters": 0.3048, "celsius_fahrenheit": 1.8, "fahrenheit_celsius": 0.5556}
|
||||||
|
|
||||||
|
# ======== 模拟执行 ========
|
||||||
|
MOCK_RESULTS = {
|
||||||
|
"calculate_math": lambda args: {"result": str(eval(str(args.get("expression", "0")).replace("^", "**").replace("×", "*").replace("÷", "/").replace("−", "-").replace("(", "(").replace(")", ")"), {"__builtins__": {}, "math": math}))},
|
||||||
|
"unit_converter": lambda args: {"result": round(float(args.get("value", 0)) * UNIT_DATA.get(f"{args.get('from_unit', '').lower()}_{args.get('to_unit', '').lower()}", 1), 4)},
|
||||||
|
"get_current_weather": lambda args: (lambda w: {"city": args.get("location"), "temperature": w[0], "humidity": "65%", "condition": w[1]})(WEATHER_DATA.get(args.get("location"), ("22°C", "晴"))),
|
||||||
|
"get_current_time": lambda args: {"datetime": TIME_DATA.get(args.get("timezone", "Asia/Shanghai"), "2025-03-07 14:30:00"), "timezone": args.get("timezone", "Asia/Shanghai")},
|
||||||
|
"get_exchange_rate": lambda args: {"from": args.get("from_currency"), "to": args.get("to_currency"), "rate": EXCHANGE_DATA.get((args.get("from_currency"), args.get("to_currency")), 1.0)},
|
||||||
|
"translate_text": lambda args: {"translated_text": TRANSLATE_DATA.get((args.get("text"), args.get("target_language")), args.get("text", ""))},
|
||||||
|
}
|
||||||
|
|
||||||
|
# ======== 参数校验 ========
|
||||||
|
CHECK_ARGS = {
|
||||||
|
"calculate_math": lambda a: bool(a.get("expression")),
|
||||||
|
"unit_converter": lambda a: a.get("value") is not None and a.get("from_unit") and a.get("to_unit"),
|
||||||
|
"get_current_weather": lambda a: bool(a.get("location")),
|
||||||
|
"get_current_time": lambda a: True,
|
||||||
|
"get_exchange_rate": lambda a: bool(a.get("from_currency")) and bool(a.get("to_currency")),
|
||||||
|
"translate_text": lambda a: bool(a.get("text")) and bool(a.get("target_language")),
|
||||||
|
}
|
||||||
|
|
||||||
|
# ======== 工具调用解析与执行 ========
|
||||||
|
def parse_tool_calls(text):
|
||||||
|
calls = []
|
||||||
|
for m in re.findall(r'<tool_call>(.*?)</tool_call>', text, re.DOTALL):
|
||||||
|
try: calls.append(json.loads(m.strip()))
|
||||||
|
except: pass
|
||||||
|
return calls
|
||||||
|
|
||||||
|
def execute_tool(name, args):
|
||||||
|
fn = MOCK_RESULTS.get(name)
|
||||||
|
if not fn: return None
|
||||||
|
try:
|
||||||
|
signal.signal(signal.SIGALRM, lambda *_: (_ for _ in ()).throw(TimeoutError()))
|
||||||
|
signal.alarm(1)
|
||||||
|
return fn(args)
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
finally:
|
||||||
|
try: signal.alarm(0)
|
||||||
|
except: pass
|
||||||
|
|
||||||
|
# ======== 多轮 Rollout ========
|
||||||
|
def rollout_single(rollout_engine, tokenizer, messages, tools, max_turns=3, max_new_tokens=256, thinking_ratio=0.5, device="cuda"):
|
||||||
|
all_outputs = []
|
||||||
|
prompt_ids = None
|
||||||
|
response_ids = []
|
||||||
|
response_mask = []
|
||||||
|
response_old_logps = []
|
||||||
|
final_context = ""
|
||||||
|
unfinished = False
|
||||||
|
open_thinking = random.random() < thinking_ratio
|
||||||
|
for turn in range(max_turns):
|
||||||
|
context = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=tools, open_thinking=open_thinking)
|
||||||
|
inputs = tokenizer(context, return_tensors="pt", add_special_tokens=False).to(device)
|
||||||
|
context_ids = inputs["input_ids"][0].tolist()
|
||||||
|
if prompt_ids is None:
|
||||||
|
prompt_ids = context_ids
|
||||||
|
rollout_result = rollout_engine.rollout(
|
||||||
|
prompt_ids=inputs["input_ids"],
|
||||||
|
attention_mask=inputs["attention_mask"],
|
||||||
|
num_generations=1,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
temperature=0.8,
|
||||||
|
)
|
||||||
|
new_ids = rollout_result.completion_ids[0].tolist()
|
||||||
|
new_logps = rollout_result.per_token_logps[0].tolist()
|
||||||
|
if len(new_ids) != len(new_logps): Logger(f"rollout token/logprob length mismatch: {len(new_ids)} vs {len(new_logps)}")
|
||||||
|
pairs = [(t, lp) for t, lp in zip(new_ids, new_logps) if t != tokenizer.pad_token_id and t != tokenizer.eos_token_id]
|
||||||
|
new_ids = [t for t, _ in pairs]
|
||||||
|
new_logps = [lp for _, lp in pairs]
|
||||||
|
new_text = rollout_result.completions[0]
|
||||||
|
all_outputs.append(new_text)
|
||||||
|
response_ids.extend(new_ids)
|
||||||
|
response_mask.extend([1] * len(new_ids))
|
||||||
|
response_old_logps.extend(new_logps)
|
||||||
|
final_context = context + new_text
|
||||||
|
calls = parse_tool_calls(new_text)
|
||||||
|
if not calls:
|
||||||
|
break
|
||||||
|
unfinished = turn == max_turns - 1
|
||||||
|
messages.append({"role": "assistant", "content": new_text})
|
||||||
|
for call in calls:
|
||||||
|
name, raw = call.get("name", ""), call.get("arguments", {})
|
||||||
|
if isinstance(raw, str):
|
||||||
|
try: raw = json.loads(raw)
|
||||||
|
except: raw = {}
|
||||||
|
result = execute_tool(name, raw)
|
||||||
|
result_str = (json.dumps(result, ensure_ascii=False) if result else '{"error": "tool not found"}')[:2048] # 防止天文数字撑爆tokenizer
|
||||||
|
messages.append({"role": "tool", "content": result_str})
|
||||||
|
|
||||||
|
observe_context = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=not unfinished, tools=tools, open_thinking=open_thinking)
|
||||||
|
observe_ids = tokenizer(observe_context, return_tensors="pt", add_special_tokens=False)["input_ids"][0].tolist()
|
||||||
|
current_len = len(prompt_ids) + len(response_ids)
|
||||||
|
obs_delta = observe_ids[current_len:]
|
||||||
|
response_ids.extend(obs_delta)
|
||||||
|
response_mask.extend([0] * len(obs_delta))
|
||||||
|
response_old_logps.extend([0.0] * len(obs_delta))
|
||||||
|
final_context = observe_context
|
||||||
|
|
||||||
|
final_output = all_outputs[-1] if all_outputs else ""
|
||||||
|
prompt_ids = prompt_ids or []
|
||||||
|
return final_output, final_context, prompt_ids, response_ids, response_mask, response_old_logps, list(all_outputs), unfinished
|
||||||
|
|
||||||
|
def rollout_batch(rollout_engine, tokenizer, messages_batch, tools_batch, num_gen, max_turns=3, max_new_tokens=256, thinking_ratio=0.5, device="cuda"):
|
||||||
|
all_completions = []
|
||||||
|
all_contexts = []
|
||||||
|
all_prompt_ids = []
|
||||||
|
all_response_ids = []
|
||||||
|
all_response_masks = []
|
||||||
|
all_response_old_logps = []
|
||||||
|
all_turn_outputs = []
|
||||||
|
all_unfinished = []
|
||||||
|
for messages, tools in zip(messages_batch, tools_batch):
|
||||||
|
for _ in range(num_gen):
|
||||||
|
msgs_copy = [dict(m) for m in messages]
|
||||||
|
completion, context, prompt_ids, response_ids, response_mask, response_old_logps, turn_outputs, unfinished = rollout_single(rollout_engine, tokenizer, msgs_copy, tools, max_turns, max_new_tokens, thinking_ratio, device)
|
||||||
|
all_completions.append(completion)
|
||||||
|
all_contexts.append(context)
|
||||||
|
all_prompt_ids.append(prompt_ids)
|
||||||
|
all_response_ids.append(response_ids)
|
||||||
|
all_response_masks.append(response_mask)
|
||||||
|
all_response_old_logps.append(response_old_logps)
|
||||||
|
all_turn_outputs.append(turn_outputs)
|
||||||
|
all_unfinished.append(unfinished)
|
||||||
|
return all_completions, all_contexts, all_prompt_ids, all_response_ids, all_response_masks, all_response_old_logps, all_turn_outputs, all_unfinished
|
||||||
|
|
||||||
|
# ======== Reward 计算 ========
|
||||||
|
def validate_gt_in_text(text, gt_list):
|
||||||
|
text, text_num = str(text), str(text).replace(',', '')
|
||||||
|
nums = [float(x) for x in re.findall(r'(?<![\w.])[-+]?\d+(?:\.\d+)?(?![\w.])', text_num)]
|
||||||
|
return {g for g in gt_list if ((s := str(g).strip()) and s.lower() in text.lower()) or (re.fullmatch(r'[-+]?\d+(?:\.\d+)?', str(g).strip().replace(',', '')) and any(abs(float(str(g).strip().replace(',', '')) - n) < 1e-6 for n in nums))}
|
||||||
|
|
||||||
|
def calculate_rewards(prompts, completions, gt_batch, tools_batch, num_gen, reward_model=None, device="cuda", turn_outputs_batch=None, unfinished_batch=None):
|
||||||
|
rewards = torch.zeros(len(completions), device=device)
|
||||||
|
for idx, response in enumerate(completions):
|
||||||
|
reward, answer = 0.0, response
|
||||||
|
sample_idx = idx // num_gen
|
||||||
|
tools = tools_batch[sample_idx]
|
||||||
|
turn_outputs = turn_outputs_batch[idx] if turn_outputs_batch is not None else [response]
|
||||||
|
unfinished = unfinished_batch[idx] if unfinished_batch is not None else False
|
||||||
|
turn_answers = [turn.split('</think>', 1)[-1].strip() if '</think>' in turn else turn.strip() for turn in turn_outputs]
|
||||||
|
answer = turn_answers[-1] if turn_answers else response.strip()
|
||||||
|
valid_names = {t['function']['name'] for t in tools} if tools else set()
|
||||||
|
tool_calls = []
|
||||||
|
for turn_answer in turn_answers: tool_calls.extend(parse_tool_calls(turn_answer)) # 解析tool调用
|
||||||
|
reward -= 0.5 * sum(abs(turn.count('<tool_call>') - turn.count('</tool_call>')) for turn in turn_answers) # 标签扣分
|
||||||
|
# -------- 无工具调用:格式+reward奖励 --------
|
||||||
|
if not tool_calls:
|
||||||
|
reward += 0.5 if 5 <= len(response.strip()) <= 800 else -0.5 # 长度分
|
||||||
|
if '</think>' in response:
|
||||||
|
think, answer = response.split('</think>', 1)
|
||||||
|
reward += 1.0 if 20 <= len(think.strip()) <= 300 else -0.5 # 思考长度分
|
||||||
|
reward += 0.25 if response.count('</think>') == 1 else -0.25 # 思考闭合分
|
||||||
|
answer = answer.strip()
|
||||||
|
if reward_model is not None:
|
||||||
|
prompt = prompts[sample_idx]
|
||||||
|
pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
|
||||||
|
matches = re.findall(pattern, prompt, re.DOTALL)
|
||||||
|
messages = [{"role": role, "content": content.strip()} for role, content in matches]
|
||||||
|
score = reward_model.get_score(messages, answer)
|
||||||
|
reward += score # RM分
|
||||||
|
reward -= rep_penalty(answer)
|
||||||
|
rewards[idx] = max(min(reward, 3.0), -3.0) # 总分Clip
|
||||||
|
# -------- 有工具调用:执行结果奖励 --------
|
||||||
|
else:
|
||||||
|
gt = gt_batch[sample_idx]
|
||||||
|
valid_call_count = 0
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
name, raw = tool_call.get("name", ""), tool_call.get("arguments", {})
|
||||||
|
if isinstance(raw, str):
|
||||||
|
try: raw = json.loads(raw)
|
||||||
|
except: raw = {}
|
||||||
|
check = CHECK_ARGS.get(name)
|
||||||
|
valid_call_count += int(bool(name in valid_names and check and check(raw)))
|
||||||
|
tool_gap = abs(valid_call_count - len(gt)) + max(0, len(tool_calls) - valid_call_count) # tool数差值
|
||||||
|
reward += 0.5 if tool_gap == 0 else -0.5 * tool_gap # tool对齐分
|
||||||
|
|
||||||
|
final_text = "" if unfinished else (answer.split('</tool_call>')[-1] if '</tool_call>' in answer else answer)
|
||||||
|
verified = validate_gt_in_text(final_text, gt) if gt else set()
|
||||||
|
if gt: reward += 2.5 * len(verified) / len(gt) # GT分
|
||||||
|
if unfinished: reward -= 0.5 # 未完成扣分
|
||||||
|
reward -= rep_penalty(final_text if final_text else answer)
|
||||||
|
rewards[idx] = max(min(reward, 3.0), -3.0) # 总分Clip
|
||||||
|
return rewards
|
||||||
|
|
||||||
|
# ================================ 工具与 Reward = End ================================
|
||||||
|
def rl_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model=None, start_step=0, wandb=None, use_sglang=False):
|
||||||
|
last_step = start_step
|
||||||
|
for step, batch in enumerate(loader, start=start_step + 1):
|
||||||
|
messages_batch = batch['messages']
|
||||||
|
tools_batch = batch['tools']
|
||||||
|
gt_batch = batch['gt']
|
||||||
|
last_step = step
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
completions, contexts, prompt_ids_batch, response_ids_batch, response_masks_batch, response_old_logps_batch, turn_outputs_batch, unfinished_batch = rollout_batch(rollout_engine, tokenizer, messages_batch, tools_batch, args.num_generations, max_turns=3, max_new_tokens=args.max_gen_len, thinking_ratio=args.thinking_ratio, device=args.device)
|
||||||
|
|
||||||
|
prompts = [tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True, tools=t) for m, t in zip(messages_batch, tools_batch)]
|
||||||
|
packed_samples = []
|
||||||
|
for p, r, m, old_lp in zip(prompt_ids_batch, response_ids_batch, response_masks_batch, response_old_logps_batch):
|
||||||
|
ids = p + r
|
||||||
|
mask = [0] * len(p) + m
|
||||||
|
old_logps = [0.0] * max(len(p) - 1, 0) + old_lp
|
||||||
|
if len(ids) > args.max_total_len:
|
||||||
|
ids = ids[-args.max_total_len:]
|
||||||
|
mask = mask[-args.max_total_len:]
|
||||||
|
old_logps = old_logps[-(len(ids) - 1):]
|
||||||
|
prompt_len = next((i for i, v in enumerate(mask) if v == 1), len(mask))
|
||||||
|
packed_samples.append((ids, mask, prompt_len, old_logps))
|
||||||
|
seq_lens = torch.tensor([len(ids) for ids, _, _, _ in packed_samples], device=args.device)
|
||||||
|
max_len = seq_lens.max().item()
|
||||||
|
input_ids = torch.tensor([ids + [tokenizer.pad_token_id] * (max_len - len(ids)) for ids, _, _, _ in packed_samples], device=args.device)
|
||||||
|
prompt_lens = torch.tensor([prompt_len for _, _, prompt_len, _ in packed_samples], device=args.device)
|
||||||
|
full_response_masks = torch.tensor([mask + [0] * (max_len - len(mask)) for _, mask, _, _ in packed_samples], device=args.device, dtype=torch.float32)
|
||||||
|
old_per_token_logps = torch.tensor([old_logps + [0.0] * ((max_len - 1) - len(old_logps)) for _, _, _, old_logps in packed_samples], device=args.device, dtype=torch.float32)
|
||||||
|
full_mask = (input_ids != tokenizer.pad_token_id).long()
|
||||||
|
|
||||||
|
model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
|
||||||
|
with autocast_ctx:
|
||||||
|
res = model_unwrapped(input_ids, attention_mask=full_mask)
|
||||||
|
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
|
||||||
|
logits = res.logits[:, :-1, :]
|
||||||
|
per_token_logps = F.log_softmax(logits, dim=-1).gather(2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
ref_per_token_logps = compute_per_token_logps(ref_model, input_ids, input_ids.size(1) - 1, attention_mask=full_mask)
|
||||||
|
|
||||||
|
completion_mask = full_response_masks[:, 1:]
|
||||||
|
is_eos = (input_ids[:, 1:] == tokenizer.eos_token_id) & completion_mask.bool()
|
||||||
|
eos_idx = torch.full((completion_mask.size(0),), completion_mask.size(1) - 1, device=args.device, dtype=torch.long)
|
||||||
|
has_eos = is_eos.any(dim=1)
|
||||||
|
eos_idx[has_eos] = is_eos.int().argmax(dim=1)[has_eos]
|
||||||
|
pos = torch.arange(completion_mask.size(1), device=args.device).unsqueeze(0)
|
||||||
|
completion_mask = completion_mask * (pos <= eos_idx.unsqueeze(1)).float()
|
||||||
|
token_counts = completion_mask.sum(dim=1)
|
||||||
|
valid_rows = token_counts > 0
|
||||||
|
rewards = calculate_rewards(prompts, completions, gt_batch, tools_batch, args.num_generations, reward_model, device=args.device, turn_outputs_batch=turn_outputs_batch, unfinished_batch=unfinished_batch)
|
||||||
|
|
||||||
|
if args.debug_mode and is_main_process() and step % args.debug_interval == 0:
|
||||||
|
for i in range(len(messages_batch)):
|
||||||
|
Logger(f"[DEBUG] step={step}, gt[{i}]: {repr(gt_batch[i])}")
|
||||||
|
Logger('-'*100)
|
||||||
|
for j in range(args.num_generations):
|
||||||
|
idx = i * args.num_generations + j
|
||||||
|
plen, slen = prompt_lens[idx].item(), seq_lens[idx].item()
|
||||||
|
Logger(f"{'=' * 30} [DEBUG] gen[{i}][{j}] CONTEXT_BEGIN {'=' * 30}")
|
||||||
|
Logger(contexts[idx])
|
||||||
|
Logger(f"{'=' * 31} [DEBUG] gen[{i}][{j}] CONTEXT_END {'=' * 31}")
|
||||||
|
Logger(f"[DEBUG] gen[{i}][{j}] prompt_len={plen}, seq_len={slen}")
|
||||||
|
tokens = input_ids[idx, plen:slen].tolist()
|
||||||
|
text = tokenizer.decode(tokens, skip_special_tokens=False)
|
||||||
|
Logger(f"{'=' * 28} [DEBUG] gen[{i}][{j}] COMPLETION_BEGIN [{plen}:{slen}] {'=' * 28}")
|
||||||
|
Logger(text)
|
||||||
|
Logger(f"{'=' * 29} [DEBUG] gen[{i}][{j}] COMPLETION_END {'=' * 29}")
|
||||||
|
Logger(f"[DEBUG] gen[{i}][{j}] reward={rewards[idx].item():.4f}")
|
||||||
|
Logger('='*100)
|
||||||
|
|
||||||
|
grouped_rewards = rewards.view(-1, args.num_generations)
|
||||||
|
mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations)
|
||||||
|
std_r = grouped_rewards.std(dim=1, unbiased=False).repeat_interleave(args.num_generations)
|
||||||
|
advantages = (rewards - mean_r) / (std_r + 1e-4)
|
||||||
|
|
||||||
|
kl_div = ref_per_token_logps - per_token_logps
|
||||||
|
per_token_kl = torch.exp(kl_div) - kl_div - 1
|
||||||
|
ratio = torch.exp(per_token_logps - old_per_token_logps)
|
||||||
|
if args.loss_type == "cispo":
|
||||||
|
clamped_ratio = torch.clamp(ratio, max=args.epsilon_high).detach()
|
||||||
|
per_token_loss = -(clamped_ratio * advantages.unsqueeze(1) * per_token_logps - args.beta * per_token_kl)
|
||||||
|
else:
|
||||||
|
clipped_ratio = torch.clamp(ratio, 1 - args.epsilon, 1 + args.epsilon)
|
||||||
|
per_token_loss1 = ratio * advantages.unsqueeze(1)
|
||||||
|
per_token_loss2 = clipped_ratio * advantages.unsqueeze(1)
|
||||||
|
per_token_loss = -(torch.min(per_token_loss1, per_token_loss2) - args.beta * per_token_kl)
|
||||||
|
policy_loss = (((per_token_loss * completion_mask).sum(dim=1)[valid_rows] / token_counts[valid_rows].clamp(min=1)).mean()
|
||||||
|
if valid_rows.any() else per_token_loss.sum() * 0.0)
|
||||||
|
loss = (policy_loss + aux_loss) / args.accumulation_steps
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
if step % args.accumulation_steps == 0:
|
||||||
|
if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||||
|
optimizer.step(); scheduler.step(); optimizer.zero_grad()
|
||||||
|
|
||||||
|
if step % args.log_interval == 0 or step == iters:
|
||||||
|
pl = loss.item() * args.accumulation_steps
|
||||||
|
ar = rewards.mean().item()
|
||||||
|
al = token_counts.float().mean().item()
|
||||||
|
kl = ((ref_per_token_logps - per_token_logps) * completion_mask).sum().item() / max(token_counts.sum().item(), 1)
|
||||||
|
gs = grouped_rewards.std(dim=1, unbiased=False).mean().item()
|
||||||
|
am, ast = advantages.mean().item(), advantages.std().item()
|
||||||
|
lr = optimizer.param_groups[0]['lr']
|
||||||
|
Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}), Reward:{ar:.4f}, KL:{kl:.4f}, GrpStd:{gs:.4f}, AdvStd:{ast:.4f}, Loss:{pl:.4f}, AvgLen:{al:.2f}, AdvMean:{am:.4f}, LR:{lr:.8f}')
|
||||||
|
if wandb and is_main_process():
|
||||||
|
wandb.log({"reward":ar,"kl_ref":kl,"group_reward_std":gs,"advantages_std":ast,"policy_loss":pl,"avg_response_len":al,"advantages_mean":am,"learning_rate":lr})
|
||||||
|
|
||||||
|
if (step % args.save_interval == 0 or step == iters) and is_main_process():
|
||||||
|
model.eval()
|
||||||
|
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||||
|
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||||
|
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
|
||||||
|
raw_model = getattr(raw_model, '_orig_mod', raw_model)
|
||||||
|
state_dict = raw_model.state_dict()
|
||||||
|
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
|
||||||
|
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
|
||||||
|
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
|
||||||
|
model.train()
|
||||||
|
del state_dict
|
||||||
|
|
||||||
|
if step % args.save_interval == 0 or step == iters: rollout_engine.update_policy(model)
|
||||||
|
|
||||||
|
del per_token_logps, ref_per_token_logps
|
||||||
|
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
|
||||||
|
|
||||||
|
if last_step > start_step and last_step % args.accumulation_steps != 0:
|
||||||
|
if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||||
|
optimizer.step(); scheduler.step(); optimizer.zero_grad()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="MiniMind Agent RL")
|
||||||
|
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||||
|
parser.add_argument('--save_weight', default='agent', type=str, help="保存权重名称")
|
||||||
|
parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
|
||||||
|
parser.add_argument("--batch_size", type=int, default=2, help="批次大小")
|
||||||
|
parser.add_argument("--learning_rate", type=float, default=3e-7, help="学习率")
|
||||||
|
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||||
|
parser.add_argument("--dtype", type=str, default="bfloat16", help="数据类型 bfloat16/float16")
|
||||||
|
parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
|
||||||
|
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
|
||||||
|
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
|
||||||
|
parser.add_argument("--log_interval", type=int, default=1, help="日志打印间隔")
|
||||||
|
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
|
||||||
|
parser.add_argument('--hidden_size', default=768, type=int, help="模型隐藏层维度")
|
||||||
|
parser.add_argument('--num_hidden_layers', default=8, type=int, help="模型层数")
|
||||||
|
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE")
|
||||||
|
parser.add_argument('--max_seq_len', default=1024, type=int, help="最大序列长度")
|
||||||
|
parser.add_argument("--max_gen_len", type=int, default=768, help="单次最大生成长度")
|
||||||
|
parser.add_argument("--max_total_len", type=int, default=2500, help="训练侧最终总长度上界")
|
||||||
|
parser.add_argument("--data_path", type=str, default="../dataset/agent_rl.jsonl", help="训练数据路径")
|
||||||
|
parser.add_argument("--num_generations", type=int, default=4, help="每个prompt生成数量")
|
||||||
|
parser.add_argument("--beta", type=float, default=0.1, help="KL散度惩罚系数")
|
||||||
|
parser.add_argument("--loss_type", type=str, default="cispo", choices=["grpo", "cispo"], help="loss类型")
|
||||||
|
parser.add_argument("--epsilon", type=float, default=0.2, help="GRPO的PPO clip epsilon")
|
||||||
|
parser.add_argument("--epsilon_high", type=float, default=5.0, help="epsilon上界")
|
||||||
|
parser.add_argument('--from_weight', default='full_sft', type=str, help="加载预训练权重名称")
|
||||||
|
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否从checkpoint恢复")
|
||||||
|
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb记录")
|
||||||
|
parser.add_argument("--wandb_project", type=str, default="MiniMind-Agent-RL", help="wandb项目名称")
|
||||||
|
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile")
|
||||||
|
parser.add_argument("--debug_mode", action="store_true", help="调试模式")
|
||||||
|
parser.add_argument("--debug_interval", type=int, default=20, help="调试日志间隔")
|
||||||
|
parser.add_argument("--thinking_ratio", type=float, default=0.1, help="按概率开启thinking(0.0~1.0)")
|
||||||
|
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
|
||||||
|
parser.add_argument("--rollout_engine", type=str, default="torch", choices=["torch", "sglang"], help="rollout引擎类型")
|
||||||
|
parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8998", help="SGLang服务器URL")
|
||||||
|
parser.add_argument("--sglang_model_path", type=str, default="../model", help="SGLang tokenizer路径")
|
||||||
|
parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_agent", help="SGLang共享存储路径")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
local_rank = init_distributed_mode()
|
||||||
|
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||||
|
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||||
|
|
||||||
|
os.makedirs(args.save_dir, exist_ok=True)
|
||||||
|
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
|
||||||
|
max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe))
|
||||||
|
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None
|
||||||
|
|
||||||
|
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||||
|
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||||
|
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||||
|
|
||||||
|
wandb = None
|
||||||
|
if args.use_wandb and is_main_process():
|
||||||
|
import swanlab as wandb
|
||||||
|
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||||
|
resume = 'must' if wandb_id else None
|
||||||
|
wandb.init(project=args.wandb_project, name=f"Agent-RL-E{args.epochs}-B{args.batch_size}-LR{args.learning_rate}", id=wandb_id, resume=resume)
|
||||||
|
|
||||||
|
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||||
|
|
||||||
|
ref_model, _ = init_model(lm_config, args.from_weight, device=args.device)
|
||||||
|
ref_model = ref_model.eval().requires_grad_(False)
|
||||||
|
|
||||||
|
reward_model = LMForRewardModel(args.reward_model_path, device=args.device, dtype=torch.float16)
|
||||||
|
Logger(f'Loaded reward model from {args.reward_model_path}')
|
||||||
|
# Rollout引擎
|
||||||
|
rollout_engine = create_rollout_engine(
|
||||||
|
engine_type=args.rollout_engine,
|
||||||
|
policy_model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
device=args.device,
|
||||||
|
autocast_ctx=autocast_ctx,
|
||||||
|
sglang_base_url=args.sglang_base_url,
|
||||||
|
sglang_model_path=args.sglang_model_path,
|
||||||
|
sglang_shared_path=args.sglang_shared_path,
|
||||||
|
)
|
||||||
|
train_ds = AgentRLDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
|
||||||
|
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||||
|
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||||
|
def collate_fn(batch): return {'messages': [b['messages'] for b in batch], 'tools': [b['tools'] for b in batch], 'gt': [b['gt'] for b in batch]}
|
||||||
|
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler, collate_fn=collate_fn)
|
||||||
|
iters = len(loader_for_count)
|
||||||
|
total_optimizer_steps = math.ceil(iters / args.accumulation_steps) * args.epochs
|
||||||
|
scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
|
||||||
|
|
||||||
|
start_epoch, start_step = 0, 0
|
||||||
|
if ckp_data:
|
||||||
|
model.load_state_dict(ckp_data['model'])
|
||||||
|
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||||
|
scheduler.load_state_dict(ckp_data['scheduler'])
|
||||||
|
start_epoch = ckp_data['epoch']
|
||||||
|
start_step = ckp_data.get('step', 0)
|
||||||
|
|
||||||
|
if args.use_compile == 1:
|
||||||
|
model = torch.compile(model)
|
||||||
|
Logger('torch.compile enabled')
|
||||||
|
rollout_engine.update_policy(model)
|
||||||
|
if dist.is_initialized():
|
||||||
|
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||||
|
rollout_engine.update_policy(model)
|
||||||
|
|
||||||
|
for epoch in range(start_epoch, args.epochs):
|
||||||
|
train_sampler and train_sampler.set_epoch(epoch)
|
||||||
|
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||||
|
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||||
|
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||||
|
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True, collate_fn=collate_fn)
|
||||||
|
if skip > 0:
|
||||||
|
Logger(f'Epoch [{epoch+1}/{args.epochs}]: skip {start_step} steps')
|
||||||
|
rl_train_epoch(epoch, loader, len(loader) + skip, rollout_engine, ref_model, reward_model, start_step, wandb, use_sglang = (args.rollout_engine == "sglang"))
|
||||||
|
else:
|
||||||
|
rl_train_epoch(epoch, loader, len(loader), rollout_engine, ref_model, reward_model, 0, wandb, use_sglang = (args.rollout_engine == "sglang"))
|
||||||
|
|
||||||
|
if dist.is_initialized(): dist.destroy_process_group()
|
||||||
245
trainer/train_distillation.py
Normal file
@ -0,0 +1,245 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
__package__ = "trainer"
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.distributed as dist
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from torch import optim
|
||||||
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
|
from model.model_minimind import MiniMindConfig
|
||||||
|
from dataset.lm_dataset import SFTDataset
|
||||||
|
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
|
||||||
|
|
||||||
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
|
|
||||||
|
def distillation_loss(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
|
||||||
|
with torch.no_grad():
|
||||||
|
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1).detach()
|
||||||
|
|
||||||
|
student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
|
||||||
|
|
||||||
|
kl = F.kl_div(
|
||||||
|
student_log_probs,
|
||||||
|
teacher_probs,
|
||||||
|
reduction=reduction
|
||||||
|
)
|
||||||
|
return (temperature ** 2) * kl
|
||||||
|
|
||||||
|
|
||||||
|
def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_step=0, wandb=None, alpha=0.0, temperature=1.0):
|
||||||
|
start_time = time.time()
|
||||||
|
last_step = start_step
|
||||||
|
|
||||||
|
if teacher_model is not None:
|
||||||
|
teacher_model.eval()
|
||||||
|
teacher_model.requires_grad_(False)
|
||||||
|
|
||||||
|
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
|
||||||
|
last_step = step
|
||||||
|
input_ids = input_ids.to(args.device)
|
||||||
|
labels = labels.to(args.device)
|
||||||
|
loss_mask = (labels[..., 1:] != -100).float()
|
||||||
|
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||||
|
for param_group in optimizer.param_groups:
|
||||||
|
param_group['lr'] = lr
|
||||||
|
|
||||||
|
# 前向传播(学生模型)
|
||||||
|
with autocast_ctx:
|
||||||
|
res = model(input_ids)
|
||||||
|
student_logits = res.logits[..., :-1, :].contiguous()
|
||||||
|
|
||||||
|
# 教师模型前向传播(只在eval & no_grad)
|
||||||
|
if teacher_model is not None:
|
||||||
|
with torch.no_grad():
|
||||||
|
teacher_logits = teacher_model(input_ids).logits[..., :-1, :].contiguous()
|
||||||
|
vocab_size_student = student_logits.size(-1)
|
||||||
|
teacher_logits = teacher_logits[..., :vocab_size_student]
|
||||||
|
|
||||||
|
# ========== 计算损失 ==========
|
||||||
|
# 1) Ground-Truth CE Loss
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
loss_mask_flat = loss_mask.view(-1)
|
||||||
|
ce_loss = F.cross_entropy(
|
||||||
|
student_logits.view(-1, student_logits.size(-1)),
|
||||||
|
shift_labels.view(-1),
|
||||||
|
ignore_index=-100,
|
||||||
|
reduction='none'
|
||||||
|
)
|
||||||
|
ce_loss_raw = torch.sum(ce_loss * loss_mask_flat) / (loss_mask_flat.sum() + 1e-8)
|
||||||
|
if lm_config_student.use_moe: ce_loss = ce_loss_raw + res.aux_loss
|
||||||
|
else: ce_loss = ce_loss_raw
|
||||||
|
|
||||||
|
# 2) Distillation Loss
|
||||||
|
if teacher_model is not None:
|
||||||
|
distill_loss = distillation_loss(
|
||||||
|
student_logits.view(-1, student_logits.size(-1))[loss_mask_flat == 1],
|
||||||
|
teacher_logits.view(-1, teacher_logits.size(-1))[loss_mask_flat == 1],
|
||||||
|
temperature=temperature
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
distill_loss = torch.tensor(0.0, device=args.device)
|
||||||
|
|
||||||
|
# 3) 总损失 = alpha * CE + (1-alpha) * Distill
|
||||||
|
loss = (alpha * ce_loss + (1 - alpha) * distill_loss) / args.accumulation_steps
|
||||||
|
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
|
if step % args.accumulation_steps == 0:
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
|
if step % args.log_interval == 0 or step == iters:
|
||||||
|
spend_time = time.time() - start_time
|
||||||
|
current_loss = loss.item() * args.accumulation_steps
|
||||||
|
current_ce_loss = ce_loss_raw.item()
|
||||||
|
current_aux_loss = res.aux_loss.item() if lm_config_student.use_moe else 0.0
|
||||||
|
current_lr = optimizer.param_groups[-1]['lr']
|
||||||
|
eta_min = spend_time / max(step - start_step, 1) * (iters - step) // 60
|
||||||
|
|
||||||
|
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, ce: {current_ce_loss:.4f}, aux_loss: {current_aux_loss:.4f}, distill: {distill_loss.item():.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')
|
||||||
|
|
||||||
|
if wandb:
|
||||||
|
wandb.log({
|
||||||
|
"loss": current_loss,
|
||||||
|
"ce_loss": current_ce_loss,
|
||||||
|
"aux_loss": current_aux_loss,
|
||||||
|
"distill_loss": distill_loss.item() if teacher_model is not None else 0.0,
|
||||||
|
"learning_rate": current_lr,
|
||||||
|
"epoch_time": eta_min
|
||||||
|
})
|
||||||
|
|
||||||
|
if (step % args.save_interval == 0 or step == iters) and is_main_process():
|
||||||
|
model.eval()
|
||||||
|
moe_suffix = '_moe' if lm_config_student.use_moe else ''
|
||||||
|
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config_student.hidden_size}{moe_suffix}.pth'
|
||||||
|
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
|
||||||
|
raw_model = getattr(raw_model, '_orig_mod', raw_model)
|
||||||
|
state_dict = raw_model.state_dict()
|
||||||
|
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
|
||||||
|
lm_checkpoint(lm_config_student, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
|
||||||
|
model.train()
|
||||||
|
del state_dict
|
||||||
|
|
||||||
|
del input_ids, labels, loss_mask, res, student_logits, ce_loss, distill_loss, loss
|
||||||
|
|
||||||
|
if last_step > start_step and last_step % args.accumulation_steps != 0:
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 模拟用moe模型蒸馏dense模型,也可以用更大teacher_hidden_size模型蒸馏更小student_hidden_size的
|
||||||
|
parser = argparse.ArgumentParser(description="MiniMind Knowledge Distillation")
|
||||||
|
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||||
|
parser.add_argument('--save_weight', default='full_dist', type=str, help="保存权重的前缀名")
|
||||||
|
parser.add_argument("--epochs", type=int, default=6, help="训练轮数")
|
||||||
|
parser.add_argument("--batch_size", type=int, default=32, help="batch size")
|
||||||
|
parser.add_argument("--learning_rate", type=float, default=5e-6, help="初始学习率")
|
||||||
|
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||||
|
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
|
||||||
|
parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
|
||||||
|
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
|
||||||
|
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
|
||||||
|
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
|
||||||
|
parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
|
||||||
|
parser.add_argument("--max_seq_len", type=int, default=340, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
|
||||||
|
parser.add_argument("--data_path", type=str, default="../dataset/sft_t2t_mini.jsonl", help="训练数据路径")
|
||||||
|
parser.add_argument('--student_hidden_size', default=768, type=int, help="学生模型隐藏层维度")
|
||||||
|
parser.add_argument('--student_num_layers', default=8, type=int, help="学生模型隐藏层数量")
|
||||||
|
parser.add_argument('--teacher_hidden_size', default=768, type=int, help="教师模型隐藏层维度")
|
||||||
|
parser.add_argument('--teacher_num_layers', default=8, type=int, help="教师模型隐藏层数量")
|
||||||
|
parser.add_argument('--student_use_moe', default=0, type=int, choices=[0, 1], help="学生模型是否使用MoE(0=否,1=是)")
|
||||||
|
parser.add_argument('--teacher_use_moe', default=1, type=int, choices=[0, 1], help="教师模型是否使用MoE(0=否,1=是)")
|
||||||
|
parser.add_argument('--from_student_weight', default='full_sft', type=str, help="学生模型基于哪个权重")
|
||||||
|
parser.add_argument('--from_teacher_weight', default='full_sft', type=str, help="教师模型基于哪个权重")
|
||||||
|
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
|
||||||
|
parser.add_argument('--alpha', default=0.5, type=float, help="CE损失权重,总损失=alpha*CE+(1-alpha)*KL")
|
||||||
|
parser.add_argument('--temperature', default=1.5, type=float, help="蒸馏温度(推荐范围1.0-2.0)")
|
||||||
|
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||||
|
parser.add_argument("--wandb_project", type=str, default="MiniMind-Distillation", help="wandb项目名")
|
||||||
|
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# ========== 1. 初始化环境和随机种子 ==========
|
||||||
|
local_rank = init_distributed_mode()
|
||||||
|
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||||
|
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||||
|
|
||||||
|
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||||
|
os.makedirs(args.save_dir, exist_ok=True)
|
||||||
|
lm_config_student = MiniMindConfig(hidden_size=args.student_hidden_size, num_hidden_layers=args.student_num_layers, use_moe=bool(args.student_use_moe))
|
||||||
|
lm_config_teacher = MiniMindConfig(hidden_size=args.teacher_hidden_size, num_hidden_layers=args.teacher_num_layers, use_moe=bool(args.teacher_use_moe))
|
||||||
|
ckp_data = lm_checkpoint(lm_config_student, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||||
|
|
||||||
|
# ========== 3. 设置混合精度 ==========
|
||||||
|
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||||
|
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||||
|
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||||
|
|
||||||
|
# ========== 4. 配wandb ==========
|
||||||
|
wandb = None
|
||||||
|
if args.use_wandb and is_main_process():
|
||||||
|
import swanlab as wandb
|
||||||
|
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||||
|
resume = 'must' if wandb_id else None
|
||||||
|
wandb_run_name = f"MiniMind-Distill-S{args.student_hidden_size}T{args.teacher_hidden_size}-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
|
||||||
|
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||||
|
|
||||||
|
# ========== 5. 定义学生和教师模型 ==========
|
||||||
|
model, tokenizer = init_model(lm_config_student, args.from_student_weight, device=args.device)
|
||||||
|
Logger(f'学生模型总参数量:{sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
|
||||||
|
teacher_model, _ = init_model(lm_config_teacher, args.from_teacher_weight, device=args.device)
|
||||||
|
teacher_model.eval()
|
||||||
|
teacher_model.requires_grad_(False)
|
||||||
|
Logger(f'教师模型总参数量:{sum(p.numel() for p in teacher_model.parameters()) / 1e6:.3f} M')
|
||||||
|
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||||
|
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||||
|
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||||
|
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||||
|
|
||||||
|
# ========== 6. 从ckp恢复状态 ==========
|
||||||
|
start_epoch, start_step = 0, 0
|
||||||
|
if ckp_data:
|
||||||
|
model.load_state_dict(ckp_data['model'])
|
||||||
|
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||||
|
scaler.load_state_dict(ckp_data['scaler'])
|
||||||
|
start_epoch = ckp_data['epoch']
|
||||||
|
start_step = ckp_data.get('step', 0)
|
||||||
|
|
||||||
|
# ========== 7. 编译和分布式包装 ==========
|
||||||
|
if args.use_compile == 1:
|
||||||
|
model = torch.compile(model)
|
||||||
|
Logger('torch.compile enabled')
|
||||||
|
if dist.is_initialized():
|
||||||
|
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||||
|
|
||||||
|
# ========== 8. 开始训练 ==========
|
||||||
|
for epoch in range(start_epoch, args.epochs):
|
||||||
|
train_sampler and train_sampler.set_epoch(epoch)
|
||||||
|
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||||
|
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||||
|
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||||
|
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||||
|
if skip > 0:
|
||||||
|
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||||
|
train_epoch(epoch, loader, len(loader) + skip, teacher_model, lm_config_student, start_step, wandb, args.alpha, args.temperature)
|
||||||
|
else:
|
||||||
|
train_epoch(epoch, loader, len(loader), teacher_model, lm_config_student, 0, wandb, args.alpha, args.temperature)
|
||||||
|
|
||||||
|
# ========== 9. 清理分布进程 ==========
|
||||||
|
if dist.is_initialized(): dist.destroy_process_group()
|
||||||
225
trainer/train_dpo.py
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
__package__ = "trainer"
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.distributed as dist
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from torch import optim
|
||||||
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
|
from model.model_minimind import MiniMindConfig
|
||||||
|
from dataset.lm_dataset import DPODataset
|
||||||
|
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
|
||||||
|
|
||||||
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
|
|
||||||
|
def logits_to_log_probs(logits, labels):
|
||||||
|
# logits shape: (batch_size, seq_len, vocab_size)
|
||||||
|
# labels shape: (batch_size, seq_len)
|
||||||
|
# log_probs shape: (batch_size, seq_len)
|
||||||
|
log_probs = F.log_softmax(logits, dim=2)
|
||||||
|
log_probs_per_token = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(-1)
|
||||||
|
return log_probs_per_token
|
||||||
|
|
||||||
|
|
||||||
|
def dpo_loss(ref_log_probs, policy_log_probs, mask, beta):
|
||||||
|
# ref_log_probs 和 policy_log_probs 都是 shape: (batch_size, seq_len)
|
||||||
|
ref_log_probs = (ref_log_probs * mask).sum(dim=1)
|
||||||
|
policy_log_probs = (policy_log_probs * mask).sum(dim=1)
|
||||||
|
|
||||||
|
# 将 chosen 和 rejected 数据分开
|
||||||
|
batch_size = ref_log_probs.shape[0]
|
||||||
|
chosen_ref_log_probs = ref_log_probs[:batch_size // 2]
|
||||||
|
reject_ref_log_probs = ref_log_probs[batch_size // 2:]
|
||||||
|
chosen_policy_log_probs = policy_log_probs[:batch_size // 2]
|
||||||
|
reject_policy_log_probs = policy_log_probs[batch_size // 2:]
|
||||||
|
|
||||||
|
pi_logratios = chosen_policy_log_probs - reject_policy_log_probs
|
||||||
|
ref_logratios = chosen_ref_log_probs - reject_ref_log_probs
|
||||||
|
logits = pi_logratios - ref_logratios
|
||||||
|
loss = -F.logsigmoid(beta * logits)
|
||||||
|
return loss.mean()
|
||||||
|
|
||||||
|
|
||||||
|
def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=None, beta=0.1):
|
||||||
|
start_time = time.time()
|
||||||
|
last_step = start_step
|
||||||
|
|
||||||
|
for step, batch in enumerate(loader, start=start_step + 1):
|
||||||
|
last_step = step
|
||||||
|
x_chosen = batch['x_chosen'].to(args.device)
|
||||||
|
x_rejected = batch['x_rejected'].to(args.device)
|
||||||
|
y_chosen = batch['y_chosen'].to(args.device)
|
||||||
|
y_rejected = batch['y_rejected'].to(args.device)
|
||||||
|
mask_chosen = batch['mask_chosen'].to(args.device)
|
||||||
|
mask_rejected = batch['mask_rejected'].to(args.device)
|
||||||
|
x = torch.cat([x_chosen, x_rejected], dim=0)
|
||||||
|
y = torch.cat([y_chosen, y_rejected], dim=0)
|
||||||
|
mask = torch.cat([mask_chosen, mask_rejected], dim=0)
|
||||||
|
|
||||||
|
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||||
|
for param_group in optimizer.param_groups:
|
||||||
|
param_group['lr'] = lr
|
||||||
|
|
||||||
|
with autocast_ctx:
|
||||||
|
with torch.no_grad():
|
||||||
|
ref_outputs = ref_model(x)
|
||||||
|
ref_logits = ref_outputs.logits
|
||||||
|
ref_log_probs = logits_to_log_probs(ref_logits, y)
|
||||||
|
|
||||||
|
outputs = model(x)
|
||||||
|
logits = outputs.logits
|
||||||
|
policy_log_probs = logits_to_log_probs(logits, y)
|
||||||
|
|
||||||
|
dpo_loss_val = dpo_loss(ref_log_probs, policy_log_probs, mask, beta=beta)
|
||||||
|
loss = dpo_loss_val + outputs.aux_loss
|
||||||
|
loss = loss / args.accumulation_steps
|
||||||
|
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
|
if step % args.accumulation_steps == 0:
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
|
if step % args.log_interval == 0 or step == iters:
|
||||||
|
spend_time = time.time() - start_time
|
||||||
|
current_loss = loss.item() * args.accumulation_steps
|
||||||
|
current_dpo_loss = dpo_loss_val.item()
|
||||||
|
current_aux_loss = outputs.aux_loss.item()
|
||||||
|
current_lr = optimizer.param_groups[-1]['lr']
|
||||||
|
eta_min = spend_time / max(step - start_step, 1) * (iters - step) // 60
|
||||||
|
|
||||||
|
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, dpo_loss: {current_dpo_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')
|
||||||
|
|
||||||
|
if wandb: wandb.log({"loss": current_loss, "dpo_loss": current_dpo_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
|
||||||
|
|
||||||
|
if (step % args.save_interval == 0 or step == iters) and is_main_process():
|
||||||
|
model.eval()
|
||||||
|
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||||
|
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||||
|
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
|
||||||
|
raw_model = getattr(raw_model, '_orig_mod', raw_model)
|
||||||
|
state_dict = raw_model.state_dict()
|
||||||
|
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
|
||||||
|
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
|
||||||
|
model.train()
|
||||||
|
del state_dict
|
||||||
|
|
||||||
|
del x_chosen, x_rejected, y_chosen, y_rejected, mask_chosen, mask_rejected, x, y, mask
|
||||||
|
del ref_outputs, ref_logits, ref_log_probs, outputs, logits, policy_log_probs, loss
|
||||||
|
|
||||||
|
if last_step > start_step and last_step % args.accumulation_steps != 0:
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="MiniMind DPO (Direct Preference Optimization)")
|
||||||
|
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||||
|
parser.add_argument('--save_weight', default='dpo', type=str, help="保存权重的前缀名")
|
||||||
|
parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
|
||||||
|
parser.add_argument("--batch_size", type=int, default=4, help="batch size")
|
||||||
|
parser.add_argument("--learning_rate", type=float, default=4e-8, help="初始学习率(建议<=5e-8避免遗忘)")
|
||||||
|
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||||
|
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
|
||||||
|
parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
|
||||||
|
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
|
||||||
|
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
|
||||||
|
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
|
||||||
|
parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
|
||||||
|
parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度")
|
||||||
|
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||||
|
parser.add_argument('--max_seq_len', default=1024, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
|
||||||
|
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||||
|
parser.add_argument("--data_path", type=str, default="../dataset/dpo.jsonl", help="DPO训练数据路径")
|
||||||
|
parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练")
|
||||||
|
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
|
||||||
|
parser.add_argument('--beta', default=0.15, type=float, help="DPO中的beta参数")
|
||||||
|
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||||
|
parser.add_argument("--wandb_project", type=str, default="MiniMind-DPO", help="wandb项目名")
|
||||||
|
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# ========== 1. 初始化环境和随机种子 ==========
|
||||||
|
local_rank = init_distributed_mode()
|
||||||
|
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||||
|
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||||
|
|
||||||
|
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||||
|
os.makedirs(args.save_dir, exist_ok=True)
|
||||||
|
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
|
||||||
|
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||||
|
|
||||||
|
# ========== 3. 设置混合精度 ==========
|
||||||
|
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||||
|
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||||
|
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||||
|
|
||||||
|
# ========== 4. 配wandb ==========
|
||||||
|
wandb = None
|
||||||
|
if args.use_wandb and is_main_process():
|
||||||
|
import swanlab as wandb
|
||||||
|
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||||
|
resume = 'must' if wandb_id else None
|
||||||
|
wandb_run_name = f"MiniMind-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
|
||||||
|
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||||
|
|
||||||
|
# ========== 5. 定义模型和参考模型 ==========
|
||||||
|
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||||
|
Logger(f'策略模型总参数量:{sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
|
||||||
|
# 初始化参考模型(ref_model冻结)
|
||||||
|
ref_model, _ = init_model(lm_config, args.from_weight, device=args.device)
|
||||||
|
ref_model.eval()
|
||||||
|
ref_model.requires_grad_(False)
|
||||||
|
Logger(f'参考模型总参数量:{sum(p.numel() for p in ref_model.parameters()) / 1e6:.3f} M')
|
||||||
|
|
||||||
|
train_ds = DPODataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||||
|
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||||
|
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||||
|
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||||
|
|
||||||
|
# ========== 6. 从ckp恢复状态 ==========
|
||||||
|
start_epoch, start_step = 0, 0
|
||||||
|
if ckp_data:
|
||||||
|
model.load_state_dict(ckp_data['model'])
|
||||||
|
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||||
|
scaler.load_state_dict(ckp_data['scaler'])
|
||||||
|
start_epoch = ckp_data['epoch']
|
||||||
|
start_step = ckp_data.get('step', 0)
|
||||||
|
|
||||||
|
# ========== 7. 编译和分布式包装 ==========
|
||||||
|
if args.use_compile == 1:
|
||||||
|
model = torch.compile(model)
|
||||||
|
Logger('torch.compile enabled')
|
||||||
|
if dist.is_initialized():
|
||||||
|
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||||
|
|
||||||
|
# ========== 8. 开始训练 ==========
|
||||||
|
for epoch in range(start_epoch, args.epochs):
|
||||||
|
train_sampler and train_sampler.set_epoch(epoch)
|
||||||
|
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||||
|
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||||
|
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||||
|
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||||
|
if skip > 0:
|
||||||
|
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||||
|
train_epoch(epoch, loader, len(loader) + skip, ref_model, lm_config, start_step, wandb, args.beta)
|
||||||
|
else:
|
||||||
|
train_epoch(epoch, loader, len(loader), ref_model, lm_config, 0, wandb, args.beta)
|
||||||
|
|
||||||
|
# ========== 9. 清理分布进程 ==========
|
||||||
|
if dist.is_initialized(): dist.destroy_process_group()
|
||||||
170
trainer/train_full_sft.py
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
__package__ = "trainer"
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from torch import optim, nn
|
||||||
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
|
from model.model_minimind import MiniMindConfig
|
||||||
|
from dataset.lm_dataset import SFTDataset
|
||||||
|
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
|
||||||
|
|
||||||
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
|
|
||||||
|
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||||
|
start_time = time.time()
|
||||||
|
last_step = start_step
|
||||||
|
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
|
||||||
|
input_ids = input_ids.to(args.device)
|
||||||
|
labels = labels.to(args.device)
|
||||||
|
last_step = step
|
||||||
|
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||||
|
for param_group in optimizer.param_groups:
|
||||||
|
param_group['lr'] = lr
|
||||||
|
|
||||||
|
with autocast_ctx:
|
||||||
|
res = model(input_ids, labels=labels)
|
||||||
|
loss = res.loss + res.aux_loss
|
||||||
|
loss = loss / args.accumulation_steps
|
||||||
|
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
|
if step % args.accumulation_steps == 0:
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||||
|
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
|
if step % args.log_interval == 0 or step == iters:
|
||||||
|
spend_time = time.time() - start_time
|
||||||
|
current_loss = loss.item() * args.accumulation_steps
|
||||||
|
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
|
||||||
|
current_logits_loss = current_loss - current_aux_loss
|
||||||
|
current_lr = optimizer.param_groups[-1]['lr']
|
||||||
|
eta_min = spend_time / max(step - start_step, 1) * (iters - step) // 60
|
||||||
|
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
|
||||||
|
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
|
||||||
|
|
||||||
|
if (step % args.save_interval == 0 or step == iters) and is_main_process():
|
||||||
|
model.eval()
|
||||||
|
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||||
|
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||||
|
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
|
||||||
|
raw_model = getattr(raw_model, '_orig_mod', raw_model)
|
||||||
|
state_dict = raw_model.state_dict()
|
||||||
|
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
|
||||||
|
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
|
||||||
|
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler)
|
||||||
|
model.train()
|
||||||
|
del state_dict
|
||||||
|
|
||||||
|
del input_ids, labels, res, loss
|
||||||
|
|
||||||
|
if last_step > start_step and last_step % args.accumulation_steps != 0:
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="MiniMind Full SFT")
|
||||||
|
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||||
|
parser.add_argument('--save_weight', default='full_sft', type=str, help="保存权重的前缀名")
|
||||||
|
parser.add_argument("--epochs", type=int, default=2, help="训练轮数")
|
||||||
|
parser.add_argument("--batch_size", type=int, default=16, help="batch size")
|
||||||
|
parser.add_argument("--learning_rate", type=float, default=1e-5, help="初始学习率")
|
||||||
|
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||||
|
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
|
||||||
|
parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
|
||||||
|
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
|
||||||
|
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
|
||||||
|
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
|
||||||
|
parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
|
||||||
|
parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度")
|
||||||
|
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||||
|
parser.add_argument('--max_seq_len', default=768, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
|
||||||
|
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||||
|
parser.add_argument("--data_path", type=str, default="../dataset/sft_t2t_mini.jsonl", help="训练数据路径")
|
||||||
|
parser.add_argument('--from_weight', default='pretrain', type=str, help="基于哪个权重训练,为none则不基于任何权重训练")
|
||||||
|
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
|
||||||
|
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||||
|
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT", help="wandb项目名")
|
||||||
|
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# ========== 1. 初始化环境和随机种子 ==========
|
||||||
|
local_rank = init_distributed_mode()
|
||||||
|
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||||
|
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||||
|
|
||||||
|
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||||
|
os.makedirs(args.save_dir, exist_ok=True)
|
||||||
|
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
|
||||||
|
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||||
|
|
||||||
|
# ========== 3. 设置混合精度 ==========
|
||||||
|
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||||
|
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||||
|
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||||
|
|
||||||
|
# ========== 4. 配wandb ==========
|
||||||
|
wandb = None
|
||||||
|
if args.use_wandb and is_main_process():
|
||||||
|
import swanlab as wandb
|
||||||
|
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||||
|
resume = 'must' if wandb_id else None
|
||||||
|
wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||||
|
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||||
|
|
||||||
|
# ========== 5. 定义模型、数据、优化器 ==========
|
||||||
|
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||||
|
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||||
|
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||||
|
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||||
|
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||||
|
|
||||||
|
# ========== 6. 从ckp恢复状态 ==========
|
||||||
|
start_epoch, start_step = 0, 0
|
||||||
|
if ckp_data:
|
||||||
|
model.load_state_dict(ckp_data['model'])
|
||||||
|
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||||
|
scaler.load_state_dict(ckp_data['scaler'])
|
||||||
|
start_epoch = ckp_data['epoch']
|
||||||
|
start_step = ckp_data.get('step', 0)
|
||||||
|
|
||||||
|
# ========== 7. 编译和分布式包装 ==========
|
||||||
|
if args.use_compile == 1:
|
||||||
|
model = torch.compile(model)
|
||||||
|
Logger('torch.compile enabled')
|
||||||
|
if dist.is_initialized():
|
||||||
|
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||||
|
|
||||||
|
# ========== 8. 开始训练 ==========
|
||||||
|
for epoch in range(start_epoch, args.epochs):
|
||||||
|
train_sampler and train_sampler.set_epoch(epoch)
|
||||||
|
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||||
|
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||||
|
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||||
|
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||||
|
if skip > 0:
|
||||||
|
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||||
|
train_epoch(epoch, loader, len(loader) + skip, start_step, wandb)
|
||||||
|
else:
|
||||||
|
train_epoch(epoch, loader, len(loader), 0, wandb)
|
||||||
|
|
||||||
|
# ========== 9. 清理分布进程 ==========
|
||||||
|
if dist.is_initialized(): dist.destroy_process_group()
|
||||||
330
trainer/train_grpo.py
Executable file
@ -0,0 +1,330 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
__package__ = "trainer"
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
import re
|
||||||
|
import gc
|
||||||
|
import warnings
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.distributed as dist
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from torch import optim
|
||||||
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
from transformers import AutoModel
|
||||||
|
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||||
|
from dataset.lm_dataset import RLAIFDataset
|
||||||
|
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model, LMForRewardModel
|
||||||
|
from trainer.rollout_engine import create_rollout_engine
|
||||||
|
|
||||||
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
|
|
||||||
|
def rep_penalty(text, n=3, cap=0.5):
|
||||||
|
toks = re.findall(r"\w+|[^\w\s]", text.lower())
|
||||||
|
grams = [tuple(toks[i:i + n]) for i in range(len(toks) - n + 1)]
|
||||||
|
return min(cap, (len(grams) - len(set(grams))) * cap * 2 / len(grams)) if grams else 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_rewards(prompts, responses, reward_model):
|
||||||
|
rewards = torch.zeros(len(responses), device=args.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
reward_model_scores = []
|
||||||
|
batch_size = len(prompts)
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
for j in range(args.num_generations):
|
||||||
|
response_idx = i * args.num_generations + j
|
||||||
|
response = responses[response_idx]
|
||||||
|
prompt = prompts[i]
|
||||||
|
|
||||||
|
pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
|
||||||
|
matches = re.findall(pattern, prompt, re.DOTALL)
|
||||||
|
messages = [{"role": role, "content": content.strip()} for role, content in matches]
|
||||||
|
answer = response
|
||||||
|
rewards[response_idx] += 0.5 if 20 <= len(response.strip()) <= 800 else -0.5
|
||||||
|
if '</think>' in response:
|
||||||
|
thinking_content, answer_content = response.split('</think>', 1)
|
||||||
|
rewards[response_idx] += 1.0 if 20 <= len(thinking_content.strip()) <= 300 else -0.5
|
||||||
|
rewards[response_idx] += 0.25 if response.count('</think>') == 1 else -0.25
|
||||||
|
answer = answer_content.strip()
|
||||||
|
rewards[response_idx] -= rep_penalty(answer)
|
||||||
|
|
||||||
|
score = reward_model.get_score(messages, answer)
|
||||||
|
reward_model_scores.append(score)
|
||||||
|
|
||||||
|
reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
|
||||||
|
rewards += reward_model_scores
|
||||||
|
|
||||||
|
return rewards
|
||||||
|
|
||||||
|
|
||||||
|
def grpo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, reward_model, start_step=0, wandb=None, use_sglang=False):
|
||||||
|
for step, batch in enumerate(loader, start=start_step + 1):
|
||||||
|
prompts = batch['prompt'] # list[str], length B
|
||||||
|
prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
|
||||||
|
padding_side="left", add_special_tokens=False).to(args.device)
|
||||||
|
if args.max_seq_len:
|
||||||
|
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
|
||||||
|
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
|
||||||
|
|
||||||
|
rollout_result = rollout_engine.rollout(
|
||||||
|
prompt_ids=prompt_inputs["input_ids"],
|
||||||
|
attention_mask=prompt_inputs["attention_mask"],
|
||||||
|
num_generations=args.num_generations,
|
||||||
|
max_new_tokens=args.max_gen_len,
|
||||||
|
temperature=0.8,
|
||||||
|
)
|
||||||
|
outputs = rollout_result.output_ids
|
||||||
|
completion_ids = rollout_result.completion_ids
|
||||||
|
completions = rollout_result.completions
|
||||||
|
old_per_token_logps = rollout_result.per_token_logps.to(args.device)
|
||||||
|
prompt_lens = rollout_result.prompt_lens.to(args.device)
|
||||||
|
full_mask = (outputs != tokenizer.pad_token_id).long()
|
||||||
|
logp_pos = prompt_lens.unsqueeze(1) - 1 + torch.arange(completion_ids.size(1), device=args.device).unsqueeze(0)
|
||||||
|
|
||||||
|
model_unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
|
||||||
|
with autocast_ctx:
|
||||||
|
res = model_unwrapped(outputs, attention_mask=full_mask)
|
||||||
|
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
|
||||||
|
per_token_logps = F.log_softmax(res.logits[:, :-1, :], dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1).gather(1, logp_pos)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
ref_per_token_logps = F.log_softmax(ref_model(outputs, attention_mask=full_mask).logits[:, :-1, :], dim=-1).gather(2, outputs[:, 1:].unsqueeze(-1)).squeeze(-1).gather(1, logp_pos)
|
||||||
|
rewards = calculate_rewards(prompts, completions, reward_model).to(args.device) # [B*num_gen]
|
||||||
|
|
||||||
|
if args.debug_mode and is_main_process() and step % args.debug_interval == 0:
|
||||||
|
for i in range(len(prompts)):
|
||||||
|
Logger(f"[DEBUG] step={step}, sample[{i}]")
|
||||||
|
Logger('-'*100)
|
||||||
|
Logger(f"{'=' * 30} [DEBUG] sample[{i}] CONTEXT_BEGIN {'=' * 30}")
|
||||||
|
Logger(prompts[i])
|
||||||
|
Logger(f"{'=' * 31} [DEBUG] sample[{i}] CONTEXT_END {'=' * 31}")
|
||||||
|
for j in range(args.num_generations):
|
||||||
|
idx = i * args.num_generations + j
|
||||||
|
Logger(f"{'=' * 28} [DEBUG] gen[{j}] RESPONSE_BEGIN {'=' * 28}")
|
||||||
|
Logger(completions[idx])
|
||||||
|
Logger(f"{'=' * 29} [DEBUG] gen[{j}] RESPONSE_END {'=' * 29}")
|
||||||
|
Logger(f"[DEBUG] gen[{j}] reward={rewards[idx].item():.4f}")
|
||||||
|
Logger('='*100)
|
||||||
|
|
||||||
|
grouped_rewards = rewards.view(-1, args.num_generations) # [B, num_gen]
|
||||||
|
mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations) # [B*num_gen]
|
||||||
|
std_r = grouped_rewards.std(dim=1, unbiased=False).repeat_interleave(args.num_generations) # [B*num_gen]
|
||||||
|
advantages = (rewards - mean_r) / (std_r + 1e-4) # [B*num_gen]
|
||||||
|
|
||||||
|
completion_pad_mask = rollout_result.completion_mask.to(args.device).bool()
|
||||||
|
is_eos = (completion_ids == tokenizer.eos_token_id) & completion_pad_mask # [B*num_gen, R]
|
||||||
|
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1) - 1, dtype=torch.long, device=args.device)
|
||||||
|
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
||||||
|
completion_mask = ((torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)) & completion_pad_mask).int() # [B*num_gen, R]
|
||||||
|
|
||||||
|
kl_div = ref_per_token_logps - per_token_logps
|
||||||
|
per_token_kl = torch.exp(kl_div) - kl_div - 1 # [B*num_gen, R]
|
||||||
|
ratio = torch.exp(per_token_logps - old_per_token_logps) # [B*num_gen, R]
|
||||||
|
if args.loss_type == "cispo":
|
||||||
|
clamped_ratio = torch.clamp(ratio, max=args.epsilon_high).detach()
|
||||||
|
per_token_loss = -(clamped_ratio * advantages.unsqueeze(1) * per_token_logps - args.beta * per_token_kl)
|
||||||
|
else:
|
||||||
|
clipped_ratio = torch.clamp(ratio, 1 - args.epsilon, 1 + args.epsilon)
|
||||||
|
per_token_loss1 = ratio * advantages.unsqueeze(1)
|
||||||
|
per_token_loss2 = clipped_ratio * advantages.unsqueeze(1)
|
||||||
|
per_token_loss = -(torch.min(per_token_loss1, per_token_loss2) - args.beta * per_token_kl)
|
||||||
|
policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1).clamp(min=1)).mean()
|
||||||
|
loss = (policy_loss + aux_loss) / args.accumulation_steps # scalar
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
if step % args.accumulation_steps == 0:
|
||||||
|
if args.grad_clip > 0:
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||||
|
optimizer.step()
|
||||||
|
scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
if step % args.log_interval == 0 or step == iters:
|
||||||
|
policy_loss_val = loss.item() * args.accumulation_steps
|
||||||
|
current_aux_loss = aux_loss.item()
|
||||||
|
avg_reward_val = rewards.mean().item()
|
||||||
|
avg_len_val = completion_mask.sum(dim=1).float().mean().item()
|
||||||
|
kl_ref_val = ((ref_per_token_logps - per_token_logps) * completion_mask).sum().item() / max(completion_mask.sum().item(), 1)
|
||||||
|
advantages_mean_val = advantages.mean().item()
|
||||||
|
advantages_std_val = advantages.std().item()
|
||||||
|
current_lr = optimizer.param_groups[0]['lr']
|
||||||
|
|
||||||
|
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), '
|
||||||
|
f'Reward: {avg_reward_val:.4f}, KL_ref: {kl_ref_val:.4f}, '
|
||||||
|
f'Adv Std: {advantages_std_val:.4f}, Adv Mean: {advantages_mean_val:.4f}, '
|
||||||
|
f'Actor Loss: {policy_loss_val:.4f}, Avg Response Len: {avg_len_val:.2f}, Learning Rate: {current_lr:.8f}')
|
||||||
|
|
||||||
|
if wandb and is_main_process():
|
||||||
|
wandb.log({
|
||||||
|
"reward": avg_reward_val,
|
||||||
|
"kl_ref": kl_ref_val,
|
||||||
|
"advantages_std": advantages_std_val,
|
||||||
|
"advantages_mean": advantages_mean_val,
|
||||||
|
"policy_loss": policy_loss_val,
|
||||||
|
"avg_response_len": avg_len_val,
|
||||||
|
"learning_rate": current_lr
|
||||||
|
})
|
||||||
|
|
||||||
|
if (step % args.save_interval == 0 or step == iters) and is_main_process():
|
||||||
|
model.eval()
|
||||||
|
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||||
|
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||||
|
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
|
||||||
|
raw_model = getattr(raw_model, '_orig_mod', raw_model)
|
||||||
|
state_dict = raw_model.state_dict()
|
||||||
|
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
|
||||||
|
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
|
||||||
|
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
|
||||||
|
model.train()
|
||||||
|
del state_dict
|
||||||
|
|
||||||
|
if step % args.save_interval == 0 or step == iters: rollout_engine.update_policy(model)
|
||||||
|
|
||||||
|
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
|
||||||
|
del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask, completion_pad_mask, prompt_lens, logp_pos
|
||||||
|
|
||||||
|
if step > start_step and step % args.accumulation_steps != 0:
|
||||||
|
if args.grad_clip > 0:
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||||
|
optimizer.step()
|
||||||
|
scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="MiniMind GRPO (Group Relative Policy Optimization)")
|
||||||
|
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||||
|
parser.add_argument('--save_weight', default='grpo', type=str, help="保存权重的前缀名")
|
||||||
|
parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
|
||||||
|
parser.add_argument("--batch_size", type=int, default=2, help="batch size")
|
||||||
|
parser.add_argument("--learning_rate", type=float, default=3e-7, help="初始学习率")
|
||||||
|
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||||
|
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
|
||||||
|
parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
|
||||||
|
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
|
||||||
|
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
|
||||||
|
parser.add_argument("--log_interval", type=int, default=1, help="日志打印间隔")
|
||||||
|
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
|
||||||
|
parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度")
|
||||||
|
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||||
|
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||||
|
parser.add_argument('--max_seq_len', default=768, type=int, help="Prompt最大长度")
|
||||||
|
parser.add_argument("--max_gen_len", type=int, default=1024, help="生成的最大长度")
|
||||||
|
parser.add_argument("--data_path", type=str, default="../dataset/rlaif.jsonl", help="RLAIF数据路径")
|
||||||
|
parser.add_argument("--num_generations", type=int, default=6, help="每个prompt生成的样本数")
|
||||||
|
parser.add_argument("--beta", type=float, default=0.1, help="KL惩罚系数")
|
||||||
|
parser.add_argument("--loss_type", type=str, default="cispo", choices=["grpo", "cispo"], help="loss类型")
|
||||||
|
parser.add_argument("--epsilon", type=float, default=0.2, help="GRPO的PPO clip epsilon")
|
||||||
|
parser.add_argument("--epsilon_high", type=float, default=5.0, help="epsilon上界")
|
||||||
|
parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练")
|
||||||
|
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
|
||||||
|
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
|
||||||
|
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||||
|
parser.add_argument("--wandb_project", type=str, default="MiniMind-GRPO", help="wandb项目名")
|
||||||
|
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
|
||||||
|
parser.add_argument("--debug_mode", action="store_true", help="是否打印训练调试采样")
|
||||||
|
parser.add_argument("--debug_interval", type=int, default=20, help="debug模式下每隔多少step打印一次采样")
|
||||||
|
parser.add_argument("--thinking_ratio", type=float, default=0.9, help="按概率开启thinking(0.0~1.0)")
|
||||||
|
parser.add_argument("--rollout_engine", type=str, default="torch", choices=["torch", "sglang"], help="rollout引擎类型")
|
||||||
|
parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8998", help="SGLang服务器URL")
|
||||||
|
parser.add_argument("--sglang_model_path", type=str, default="../model", help="SGLang tokenizer路径")
|
||||||
|
parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_grpo", help="SGLang共享存储路径")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# ========== 1. 初始化环境和随机种子 ==========
|
||||||
|
local_rank = init_distributed_mode()
|
||||||
|
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||||
|
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||||
|
|
||||||
|
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||||
|
os.makedirs(args.save_dir, exist_ok=True)
|
||||||
|
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
|
||||||
|
max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe))
|
||||||
|
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||||
|
|
||||||
|
# ========== 3. 设置混合精度 ==========
|
||||||
|
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||||
|
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||||
|
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||||
|
|
||||||
|
# ========== 4. 配wandb ==========
|
||||||
|
wandb = None
|
||||||
|
if args.use_wandb and is_main_process():
|
||||||
|
import swanlab as wandb
|
||||||
|
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||||
|
resume = 'must' if wandb_id else None
|
||||||
|
wandb_run_name = f"MiniMind-GRPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
|
||||||
|
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||||
|
|
||||||
|
# ========== 5. 初始化模型和数据 ==========
|
||||||
|
base_weight = args.from_weight
|
||||||
|
# Policy模型
|
||||||
|
model, tokenizer = init_model(lm_config, base_weight, device=args.device)
|
||||||
|
# Reference模型
|
||||||
|
ref_model, _ = init_model(lm_config, base_weight, device=args.device)
|
||||||
|
ref_model = ref_model.eval().requires_grad_(False)
|
||||||
|
# Reward模型
|
||||||
|
reward_model = LMForRewardModel(args.reward_model_path, device=args.device, dtype=torch.float16)
|
||||||
|
# Rollout引擎(可插拔替换,只负责 policy 推理)
|
||||||
|
rollout_engine = create_rollout_engine(
|
||||||
|
engine_type=args.rollout_engine,
|
||||||
|
policy_model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
device=args.device,
|
||||||
|
autocast_ctx=autocast_ctx,
|
||||||
|
sglang_base_url=args.sglang_base_url,
|
||||||
|
sglang_model_path=args.sglang_model_path,
|
||||||
|
sglang_shared_path=args.sglang_shared_path,
|
||||||
|
)
|
||||||
|
# 数据和优化器
|
||||||
|
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len, thinking_ratio=args.thinking_ratio)
|
||||||
|
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||||
|
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||||
|
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
|
||||||
|
iters = len(loader_for_count)
|
||||||
|
total_optimizer_steps = math.ceil(iters / args.accumulation_steps) * args.epochs
|
||||||
|
scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
|
||||||
|
|
||||||
|
# ========== 6. 从ckp恢复状态 ==========
|
||||||
|
start_epoch, start_step = 0, 0
|
||||||
|
if ckp_data:
|
||||||
|
model.load_state_dict(ckp_data['model'])
|
||||||
|
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||||
|
scheduler.load_state_dict(ckp_data['scheduler'])
|
||||||
|
start_epoch = ckp_data['epoch']
|
||||||
|
start_step = ckp_data.get('step', 0)
|
||||||
|
|
||||||
|
# ========== 7. 编译和分布式包装 ==========
|
||||||
|
if args.use_compile == 1:
|
||||||
|
model = torch.compile(model)
|
||||||
|
Logger('torch.compile enabled')
|
||||||
|
rollout_engine.update_policy(model)
|
||||||
|
if dist.is_initialized():
|
||||||
|
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||||
|
rollout_engine.update_policy(model)
|
||||||
|
|
||||||
|
# ========== 8. 开始训练 ==========
|
||||||
|
for epoch in range(start_epoch, args.epochs):
|
||||||
|
train_sampler and train_sampler.set_epoch(epoch)
|
||||||
|
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||||
|
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||||
|
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||||
|
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||||
|
if skip > 0:
|
||||||
|
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||||
|
grpo_train_epoch(epoch, loader, len(loader) + skip, rollout_engine, ref_model, reward_model, start_step, wandb, use_sglang = (args.rollout_engine == "sglang"))
|
||||||
|
else:
|
||||||
|
grpo_train_epoch(epoch, loader, len(loader), rollout_engine, ref_model, reward_model, 0, wandb, use_sglang = (args.rollout_engine == "sglang"))
|
||||||
|
|
||||||
|
# ========== 9. 清理分布进程 ==========
|
||||||
|
if dist.is_initialized(): dist.destroy_process_group()
|
||||||
183
trainer/train_lora.py
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
__package__ = "trainer"
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from torch import optim, nn
|
||||||
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
|
from model.model_minimind import MiniMindConfig
|
||||||
|
from dataset.lm_dataset import SFTDataset
|
||||||
|
from model.model_lora import save_lora, apply_lora
|
||||||
|
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
|
||||||
|
|
||||||
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
|
|
||||||
|
def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
|
||||||
|
start_time = time.time()
|
||||||
|
last_step = start_step
|
||||||
|
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
|
||||||
|
input_ids = input_ids.to(args.device)
|
||||||
|
labels = labels.to(args.device)
|
||||||
|
last_step = step
|
||||||
|
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||||
|
for param_group in optimizer.param_groups:
|
||||||
|
param_group['lr'] = lr
|
||||||
|
|
||||||
|
with autocast_ctx:
|
||||||
|
res = model(input_ids, labels=labels)
|
||||||
|
loss = res.loss + res.aux_loss
|
||||||
|
loss = loss / args.accumulation_steps
|
||||||
|
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
|
if step % args.accumulation_steps == 0:
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
|
torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
|
if step % args.log_interval == 0 or step == iters:
|
||||||
|
spend_time = time.time() - start_time
|
||||||
|
current_loss = loss.item() * args.accumulation_steps
|
||||||
|
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
|
||||||
|
current_logits_loss = current_loss - current_aux_loss
|
||||||
|
current_lr = optimizer.param_groups[-1]['lr']
|
||||||
|
eta_min = spend_time / max(step - start_step, 1) * (iters - step) // 60
|
||||||
|
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
|
||||||
|
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
|
||||||
|
|
||||||
|
if (step % args.save_interval == 0 or step == iters) and is_main_process():
|
||||||
|
model.eval()
|
||||||
|
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||||
|
lora_save_path = f'{args.save_dir}/{args.lora_name}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||||
|
# LoRA只保存LoRA权重
|
||||||
|
save_lora(model, lora_save_path)
|
||||||
|
lm_checkpoint(lm_config, weight=args.lora_name, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
del input_ids, labels, res, loss
|
||||||
|
|
||||||
|
if last_step > start_step and last_step % args.accumulation_steps != 0:
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
|
torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="MiniMind LoRA Fine-tuning")
|
||||||
|
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||||
|
parser.add_argument("--lora_name", type=str, default="lora_medical", help="LoRA权重名称(如lora_identity/lora_medical等)")
|
||||||
|
parser.add_argument("--epochs", type=int, default=10, help="训练轮数")
|
||||||
|
parser.add_argument("--batch_size", type=int, default=32, help="batch size")
|
||||||
|
parser.add_argument("--learning_rate", type=float, default=1e-4, help="初始学习率")
|
||||||
|
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||||
|
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
|
||||||
|
parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
|
||||||
|
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
|
||||||
|
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
|
||||||
|
parser.add_argument("--log_interval", type=int, default=10, help="日志打印间隔")
|
||||||
|
parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
|
||||||
|
parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度")
|
||||||
|
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||||
|
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
|
||||||
|
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||||
|
parser.add_argument("--data_path", type=str, default="../dataset/lora_medical.jsonl", help="LoRA训练数据路径")
|
||||||
|
parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练,默认full_sft")
|
||||||
|
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
|
||||||
|
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||||
|
parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="wandb项目名")
|
||||||
|
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# ========== 1. 初始化环境和随机种子 ==========
|
||||||
|
local_rank = init_distributed_mode()
|
||||||
|
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||||
|
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||||
|
|
||||||
|
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||||
|
os.makedirs(args.save_dir, exist_ok=True)
|
||||||
|
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
|
||||||
|
ckp_data = lm_checkpoint(lm_config, weight=args.lora_name, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||||
|
|
||||||
|
# ========== 3. 设置混合精度 ==========
|
||||||
|
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||||
|
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||||
|
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||||
|
|
||||||
|
# ========== 4. 配wandb ==========
|
||||||
|
wandb = None
|
||||||
|
if args.use_wandb and is_main_process():
|
||||||
|
import swanlab as wandb
|
||||||
|
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||||
|
resume = 'must' if wandb_id else None
|
||||||
|
wandb_run_name = f"MiniMind-LoRA-{args.lora_name}-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
|
||||||
|
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||||
|
|
||||||
|
# ========== 5. 定义模型、应用LoRA、冻结非LoRA参数 ==========
|
||||||
|
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||||
|
apply_lora(model)
|
||||||
|
|
||||||
|
# 统计参数
|
||||||
|
total_params = sum(p.numel() for p in model.parameters())
|
||||||
|
lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name)
|
||||||
|
Logger(f"LLM 总参数量: {total_params / 1e6:.3f} M")
|
||||||
|
Logger(f"LoRA 参数量: {lora_params_count / 1e6:.3f} M")
|
||||||
|
Logger(f"LoRA 参数占比: {lora_params_count / total_params * 100:.2f}%")
|
||||||
|
|
||||||
|
# 冻结非LoRA参数,收集LoRA参数
|
||||||
|
lora_params = []
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if 'lora' in name:
|
||||||
|
param.requires_grad = True
|
||||||
|
lora_params.append(param)
|
||||||
|
else:
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
# ========== 6. 定义数据和优化器 ==========
|
||||||
|
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||||
|
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||||
|
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||||
|
optimizer = optim.AdamW(lora_params, lr=args.learning_rate)
|
||||||
|
|
||||||
|
# ========== 7. 从ckp恢复状态 ==========
|
||||||
|
start_epoch, start_step = 0, 0
|
||||||
|
if ckp_data:
|
||||||
|
model.load_state_dict(ckp_data['model'], strict=False)
|
||||||
|
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||||
|
scaler.load_state_dict(ckp_data['scaler'])
|
||||||
|
start_epoch = ckp_data['epoch']
|
||||||
|
start_step = ckp_data.get('step', 0)
|
||||||
|
|
||||||
|
# ========== 8. 编译和分布式包装 ==========
|
||||||
|
if args.use_compile == 1:
|
||||||
|
args.use_compile = 0
|
||||||
|
Logger('[LoRA] monkey-patch forward 与 torch.compile 不兼容,use_compile 已自动关闭')
|
||||||
|
if dist.is_initialized():
|
||||||
|
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||||
|
|
||||||
|
# ========== 9. 开始训练 ==========
|
||||||
|
for epoch in range(start_epoch, args.epochs):
|
||||||
|
train_sampler and train_sampler.set_epoch(epoch)
|
||||||
|
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||||
|
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||||
|
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||||
|
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||||
|
if skip > 0:
|
||||||
|
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||||
|
train_epoch(epoch, loader, len(loader) + skip, lora_params, start_step, wandb)
|
||||||
|
else:
|
||||||
|
train_epoch(epoch, loader, len(loader), lora_params, 0, wandb)
|
||||||
|
|
||||||
|
# ========== 10. 清理分布进程 ==========
|
||||||
|
if dist.is_initialized(): dist.destroy_process_group()
|
||||||
434
trainer/train_ppo.py
Normal file
@ -0,0 +1,434 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
__package__ = "trainer"
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
import re
|
||||||
|
import warnings
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from torch import optim, nn
|
||||||
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
|
from torch.nn.utils import clip_grad_norm_
|
||||||
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||||
|
from dataset.lm_dataset import RLAIFDataset
|
||||||
|
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model, LMForRewardModel
|
||||||
|
from trainer.rollout_engine import create_rollout_engine
|
||||||
|
|
||||||
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
|
|
||||||
|
def rep_penalty(text, n=3, cap=0.5):
|
||||||
|
toks = re.findall(r"\w+|[^\w\s]", text.lower())
|
||||||
|
grams = [tuple(toks[i:i + n]) for i in range(len(toks) - n + 1)]
|
||||||
|
return min(cap, (len(grams) - len(set(grams))) * cap * 2 / len(grams)) if grams else 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# 自定义的Critic模型,继承自MiniMindLM
|
||||||
|
class CriticModel(MiniMindForCausalLM):
|
||||||
|
def __init__(self, params):
|
||||||
|
super().__init__(params)
|
||||||
|
# 替换lm_head为输出单一价值的线性层
|
||||||
|
self.value_head = nn.Linear(params.hidden_size, 1)
|
||||||
|
|
||||||
|
def forward(self, input_ids=None, attention_mask=None, **kwargs):
|
||||||
|
# 使用基础模型获取隐藏状态
|
||||||
|
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
|
||||||
|
hidden_states = self.model.norm(outputs[0])
|
||||||
|
# 使用value_head获取价值估计
|
||||||
|
values = self.value_head(hidden_states).squeeze(-1)
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_rewards(prompts, responses, reward_model):
|
||||||
|
rewards = torch.zeros(len(responses), device=args.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
reward_model_scores = []
|
||||||
|
for i, (prompt, response) in enumerate(zip(prompts, responses)):
|
||||||
|
pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
|
||||||
|
matches = re.findall(pattern, prompt, re.DOTALL)
|
||||||
|
messages = [{"role": role, "content": content.strip()} for role, content in matches]
|
||||||
|
answer = response
|
||||||
|
rewards[i] += 0.5 if 20 <= len(response.strip()) <= 800 else -0.5
|
||||||
|
if '</think>' in response:
|
||||||
|
thinking_content, answer_content = response.split('</think>', 1)
|
||||||
|
rewards[i] += 1.0 if 20 <= len(thinking_content.strip()) <= 300 else -0.5
|
||||||
|
rewards[i] += 0.25 if response.count('</think>') == 1 else -0.25
|
||||||
|
answer = answer_content.strip()
|
||||||
|
rewards[i] -= rep_penalty(answer)
|
||||||
|
|
||||||
|
score = reward_model.get_score(messages, answer)
|
||||||
|
reward_model_scores.append(score)
|
||||||
|
|
||||||
|
reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
|
||||||
|
rewards += reward_model_scores
|
||||||
|
|
||||||
|
return rewards
|
||||||
|
|
||||||
|
|
||||||
|
def ppo_train_epoch(epoch, loader, iters, rollout_engine, ref_model, actor_scheduler, critic_scheduler, reward_model, start_step=0, wandb=None, use_sglang=False):
|
||||||
|
actor_model.train()
|
||||||
|
critic_model.train()
|
||||||
|
grad_accum_step = 0
|
||||||
|
|
||||||
|
for step, batch in enumerate(loader, start=start_step + 1):
|
||||||
|
prompts = batch["prompt"] # list[str], length B
|
||||||
|
enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=args.max_seq_len,
|
||||||
|
padding_side="left").to(args.device) # input_ids: [B, P], attention_mask: [B, P]
|
||||||
|
|
||||||
|
rollout_result = rollout_engine.rollout(
|
||||||
|
prompt_ids=enc.input_ids,
|
||||||
|
attention_mask=enc.attention_mask,
|
||||||
|
num_generations=1,
|
||||||
|
max_new_tokens=args.max_gen_len,
|
||||||
|
temperature=0.8,
|
||||||
|
)
|
||||||
|
gen_out = rollout_result.output_ids
|
||||||
|
completion_ids = rollout_result.completion_ids
|
||||||
|
prompt_lens = rollout_result.prompt_lens.to(args.device)
|
||||||
|
responses_text = rollout_result.completions
|
||||||
|
old_resp_logp = rollout_result.per_token_logps.to(args.device)
|
||||||
|
rewards = calculate_rewards(prompts, responses_text, reward_model) # [B]
|
||||||
|
|
||||||
|
if args.debug_mode and is_main_process() and step % args.debug_interval == 0:
|
||||||
|
for i in range(len(prompts)):
|
||||||
|
Logger(f"[DEBUG] step={step}, sample[{i}]")
|
||||||
|
Logger('-'*100)
|
||||||
|
Logger(f"{'=' * 30} [DEBUG] sample[{i}] CONTEXT_BEGIN {'=' * 30}")
|
||||||
|
Logger(prompts[i])
|
||||||
|
Logger(f"{'=' * 31} [DEBUG] sample[{i}] CONTEXT_END {'=' * 31}")
|
||||||
|
Logger(f"[DEBUG] prompt_len={prompt_lens[i].item()}, response_len={len(responses_text[i])}")
|
||||||
|
Logger(f"{'=' * 28} [DEBUG] sample[{i}] RESPONSE_BEGIN {'=' * 28}")
|
||||||
|
Logger(responses_text[i])
|
||||||
|
Logger(f"{'=' * 29} [DEBUG] sample[{i}] RESPONSE_END {'=' * 29}")
|
||||||
|
Logger(f"[DEBUG] reward={rewards[i].item():.4f}")
|
||||||
|
Logger('='*100)
|
||||||
|
|
||||||
|
full_mask = (gen_out != tokenizer.pad_token_id).long() # [B, P+R]
|
||||||
|
labels = gen_out[:, 1:].clone() # [B, P+R-1]
|
||||||
|
B = len(prompts)
|
||||||
|
resp_labels = completion_ids
|
||||||
|
resp_idx = torch.arange(resp_labels.size(1), device=gen_out.device).unsqueeze(0)
|
||||||
|
logp_pos = prompt_lens.unsqueeze(1) - 1 + resp_idx
|
||||||
|
resp_pad_mask = rollout_result.completion_mask.to(args.device).bool()
|
||||||
|
resp_lengths = resp_pad_mask.sum(dim=1); valid_resp = resp_lengths > 0; eos_mask = resp_labels.eq(tokenizer.eos_token_id) & resp_pad_mask
|
||||||
|
has_eos = eos_mask.any(dim=1); eos_pos = torch.argmax(eos_mask.int(), dim=1)
|
||||||
|
resp_lengths = torch.where(has_eos, eos_pos + 1, resp_lengths).long().clamp(min=1)
|
||||||
|
resp_policy_mask = ((resp_idx < resp_lengths.unsqueeze(1)) & resp_pad_mask).float()
|
||||||
|
resp_value_mask = resp_policy_mask.clone()
|
||||||
|
|
||||||
|
with torch.no_grad(): # Rollout阶段只需推理获取old_logp和old_values,切断梯度省显存
|
||||||
|
critic_for_rollout = critic_model.module if isinstance(critic_model, DistributedDataParallel) else critic_model
|
||||||
|
values_seq = critic_for_rollout(input_ids=gen_out, attention_mask=full_mask)
|
||||||
|
old_resp_values = values_seq.gather(1, logp_pos) * resp_value_mask
|
||||||
|
|
||||||
|
ref_resp_logp = F.log_softmax(ref_model(input_ids=gen_out, attention_mask=full_mask).logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1).gather(1, logp_pos)
|
||||||
|
token_rewards = torch.zeros_like(old_resp_logp)
|
||||||
|
last_idx = resp_lengths - 1 # [B]
|
||||||
|
token_rewards[torch.arange(B, device=args.device)[valid_resp], last_idx[valid_resp]] += rewards[valid_resp] # 末尾加外部奖励
|
||||||
|
|
||||||
|
gen_len = old_resp_values.size(1); lastgaelam = torch.zeros(B, device=args.device); advs_rev = []
|
||||||
|
for t in reversed(range(gen_len)):
|
||||||
|
nv = old_resp_values[:, t + 1] if t < gen_len - 1 else 0.0
|
||||||
|
delta = token_rewards[:, t] + args.gamma * nv - old_resp_values[:, t]
|
||||||
|
lastgaelam = delta + args.gamma * args.lam * lastgaelam
|
||||||
|
advs_rev.append(lastgaelam)
|
||||||
|
advantages = torch.stack(advs_rev[::-1], dim=1) # [B, R]
|
||||||
|
returns = advantages + old_resp_values # [B, R]
|
||||||
|
|
||||||
|
adv_mean = (advantages * resp_policy_mask).sum() / resp_policy_mask.sum().clamp(min=1)
|
||||||
|
adv_var = ((advantages - adv_mean) ** 2 * resp_policy_mask).sum() / resp_policy_mask.sum().clamp(min=1)
|
||||||
|
advantages = (advantages - adv_mean) * torch.rsqrt(adv_var + 1e-8) * resp_policy_mask
|
||||||
|
|
||||||
|
mb_size = max(1, min(args.mini_batch_size, B))
|
||||||
|
stop_ppo = False
|
||||||
|
policy_loss_sum = 0.0
|
||||||
|
value_loss_sum = 0.0
|
||||||
|
kl_sum = 0.0
|
||||||
|
kl_ref_sum = 0.0
|
||||||
|
clipfrac_sum = 0.0
|
||||||
|
aux_loss_sum = 0.0
|
||||||
|
log_count = 0
|
||||||
|
actor_unwrapped = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
|
||||||
|
critic_unwrapped = critic_model.module if isinstance(critic_model, DistributedDataParallel) else critic_model
|
||||||
|
for ppo_epoch in range(args.ppo_update_iters):
|
||||||
|
if stop_ppo:
|
||||||
|
break
|
||||||
|
b_inds = torch.randperm(B, device=args.device)
|
||||||
|
for i in range(0, B, mb_size):
|
||||||
|
inds = b_inds[i:i + mb_size]
|
||||||
|
|
||||||
|
mb_values_seq = critic_unwrapped(input_ids=gen_out[inds], attention_mask=full_mask[inds])
|
||||||
|
mb_resp_values = mb_values_seq.gather(1, logp_pos[inds])
|
||||||
|
|
||||||
|
with autocast_ctx:
|
||||||
|
res = actor_unwrapped(input_ids=gen_out[inds], attention_mask=full_mask[inds])
|
||||||
|
aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)
|
||||||
|
|
||||||
|
mb_resp_logp = F.log_softmax(res.logits[:, :-1], dim=-1).gather(2, labels[inds].unsqueeze(-1)).squeeze(-1).gather(1, logp_pos[inds])
|
||||||
|
|
||||||
|
log_ratio = mb_resp_logp - old_resp_logp[inds]
|
||||||
|
approx_kl = (0.5 * (log_ratio ** 2) * resp_policy_mask[inds]).sum() / resp_policy_mask[inds].sum().clamp(min=1)
|
||||||
|
|
||||||
|
# 同步各卡的 approx_kl,防止某卡 break 而其它卡继续导致 DDP 死锁
|
||||||
|
approx_kl_val = approx_kl.detach().clone()
|
||||||
|
if dist.is_initialized():
|
||||||
|
dist.all_reduce(approx_kl_val, op=dist.ReduceOp.AVG)
|
||||||
|
|
||||||
|
if approx_kl_val > args.early_stop_kl:
|
||||||
|
stop_ppo = True
|
||||||
|
|
||||||
|
ratio = torch.exp(log_ratio)
|
||||||
|
clipfrac = ((((ratio - 1.0).abs() > args.clip_epsilon).float() * resp_policy_mask[inds]).sum()
|
||||||
|
/ resp_policy_mask[inds].sum().clamp(min=1))
|
||||||
|
kl_ref_penalty = ((torch.exp(ref_resp_logp[inds] - mb_resp_logp) - (ref_resp_logp[inds] - mb_resp_logp) - 1.0)
|
||||||
|
* resp_policy_mask[inds]).sum() / resp_policy_mask[inds].sum().clamp(min=1)
|
||||||
|
policy_loss = ((torch.max(-advantages[inds] * ratio,
|
||||||
|
-advantages[inds] * torch.clamp(ratio, 1.0 - args.clip_epsilon, 1.0 + args.clip_epsilon))
|
||||||
|
* resp_policy_mask[inds]).sum() / resp_policy_mask[inds].sum().clamp(min=1)
|
||||||
|
+ args.kl_coef * kl_ref_penalty)
|
||||||
|
value_loss = 0.5 * (torch.max((mb_resp_values - returns[inds]) ** 2,
|
||||||
|
(torch.clamp(mb_resp_values, old_resp_values[inds] - args.cliprange_value,
|
||||||
|
old_resp_values[inds] + args.cliprange_value) - returns[inds]) ** 2)
|
||||||
|
* resp_value_mask[inds]).sum() / resp_value_mask[inds].sum().clamp(min=1)
|
||||||
|
|
||||||
|
kl = approx_kl_val
|
||||||
|
kl_ref = kl_ref_penalty.detach()
|
||||||
|
|
||||||
|
# 早停时必须保证 forward-backward 闭环,故只截断 loss 不中断 DDP 通信
|
||||||
|
if stop_ppo:
|
||||||
|
loss = (policy_loss + args.vf_coef * value_loss + aux_loss) * 0.0
|
||||||
|
else:
|
||||||
|
loss = (policy_loss + args.vf_coef * value_loss + aux_loss) / args.accumulation_steps
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
policy_loss_sum += policy_loss.item()
|
||||||
|
value_loss_sum += value_loss.item()
|
||||||
|
kl_sum += kl.item()
|
||||||
|
kl_ref_sum += kl_ref.item()
|
||||||
|
clipfrac_sum += clipfrac.item()
|
||||||
|
aux_loss_sum += aux_loss.item()
|
||||||
|
log_count += 1
|
||||||
|
|
||||||
|
grad_accum_step += 1
|
||||||
|
|
||||||
|
if grad_accum_step % args.accumulation_steps == 0:
|
||||||
|
clip_grad_norm_(actor_model.parameters(), args.grad_clip)
|
||||||
|
clip_grad_norm_(critic_model.parameters(), args.grad_clip)
|
||||||
|
actor_optimizer.step()
|
||||||
|
critic_optimizer.step()
|
||||||
|
actor_scheduler.step()
|
||||||
|
critic_scheduler.step()
|
||||||
|
actor_optimizer.zero_grad()
|
||||||
|
critic_optimizer.zero_grad()
|
||||||
|
|
||||||
|
if grad_accum_step % args.accumulation_steps != 0:
|
||||||
|
clip_grad_norm_(actor_model.parameters(), args.grad_clip)
|
||||||
|
clip_grad_norm_(critic_model.parameters(), args.grad_clip)
|
||||||
|
actor_optimizer.step()
|
||||||
|
critic_optimizer.step()
|
||||||
|
actor_scheduler.step()
|
||||||
|
critic_scheduler.step()
|
||||||
|
actor_optimizer.zero_grad()
|
||||||
|
critic_optimizer.zero_grad()
|
||||||
|
|
||||||
|
if step % args.save_interval == 0 or step == iters: rollout_engine.update_policy(actor_model)
|
||||||
|
|
||||||
|
if is_main_process():
|
||||||
|
critic_loss_val = value_loss_sum / max(log_count, 1)
|
||||||
|
reward_val = rewards.mean().item()
|
||||||
|
approx_kl_val = kl_sum / max(log_count, 1)
|
||||||
|
kl_ref_val = kl_ref_sum / max(log_count, 1)
|
||||||
|
clipfrac_val = clipfrac_sum / max(log_count, 1)
|
||||||
|
avg_len_val = resp_lengths.float().mean().item()
|
||||||
|
actor_lr, critic_lr = actor_optimizer.param_groups[0]['lr'], critic_optimizer.param_groups[0]['lr']
|
||||||
|
|
||||||
|
if wandb is not None:
|
||||||
|
wandb.log({
|
||||||
|
"reward": reward_val,
|
||||||
|
"kl_ref": kl_ref_val,
|
||||||
|
"approx_kl": approx_kl_val,
|
||||||
|
"clipfrac": clipfrac_val,
|
||||||
|
"critic_loss": critic_loss_val,
|
||||||
|
"avg_response_len": avg_len_val,
|
||||||
|
"actor_lr": actor_lr,
|
||||||
|
"critic_lr": critic_lr,
|
||||||
|
})
|
||||||
|
|
||||||
|
Logger(f"Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), "
|
||||||
|
f"Reward: {reward_val:.4f}, KL_ref: {kl_ref_val:.4f}, Approx KL: {approx_kl_val:.4f}, "
|
||||||
|
f"ClipFrac: {clipfrac_val:.4f}, Critic Loss: {critic_loss_val:.4f}, "
|
||||||
|
f"Avg Response Len: {avg_len_val:.2f}, Actor LR: {actor_lr:.8f}, Critic LR: {critic_lr:.8f}")
|
||||||
|
|
||||||
|
if (step % args.save_interval == 0 or step == iters) and is_main_process():
|
||||||
|
actor_model.eval()
|
||||||
|
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||||
|
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||||
|
raw_actor = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
|
||||||
|
raw_actor = getattr(raw_actor, '_orig_mod', raw_actor)
|
||||||
|
actor_state = raw_actor.state_dict()
|
||||||
|
torch.save({k: v.half().cpu() for k, v in actor_state.items()}, ckp)
|
||||||
|
|
||||||
|
# 使用 lm_checkpoint 保存完整状态(包括 critic)
|
||||||
|
lm_checkpoint(lm_config, weight=args.save_weight, model=actor_model, optimizer=actor_optimizer,
|
||||||
|
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints',
|
||||||
|
scheduler=actor_scheduler, critic_model=critic_model,
|
||||||
|
critic_optimizer=critic_optimizer, critic_scheduler=critic_scheduler)
|
||||||
|
actor_model.train()
|
||||||
|
del actor_state
|
||||||
|
|
||||||
|
del enc, gen_out, completion_ids, responses_text, rewards, full_mask, values_seq, advantages
|
||||||
|
del labels, resp_labels, resp_idx, resp_pad_mask, valid_resp, eos_mask, has_eos, eos_pos, resp_lengths, resp_policy_mask, resp_value_mask, old_resp_logp, ref_resp_logp
|
||||||
|
del kl, kl_ref, policy_loss, value_loss, loss, token_rewards, returns, old_resp_values, prompt_lens, logp_pos
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="MiniMind PPO (Proximal Policy Optimization)")
|
||||||
|
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||||
|
parser.add_argument('--save_weight', default='ppo_actor', type=str, help="保存权重的前缀名")
|
||||||
|
parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
|
||||||
|
parser.add_argument("--batch_size", type=int, default=2, help="batch size")
|
||||||
|
parser.add_argument("--learning_rate", type=float, default=3e-7, help="Actor学习率")
|
||||||
|
parser.add_argument("--critic_learning_rate", type=float, default=5e-7, help="Critic学习率")
|
||||||
|
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||||
|
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
|
||||||
|
parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
|
||||||
|
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
|
||||||
|
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
|
||||||
|
parser.add_argument("--log_interval", type=int, default=1, help="日志打印间隔")
|
||||||
|
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
|
||||||
|
parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度")
|
||||||
|
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||||
|
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||||
|
parser.add_argument('--max_seq_len', default=768, type=int, help="Prompt最大长度")
|
||||||
|
parser.add_argument("--max_gen_len", type=int, default=1024, help="生成的最大长度")
|
||||||
|
parser.add_argument("--data_path", type=str, default="../dataset/rlaif.jsonl", help="RLAIF数据路径")
|
||||||
|
parser.add_argument("--clip_epsilon", type=float, default=0.2, help="PPO裁剪参数")
|
||||||
|
parser.add_argument("--vf_coef", type=float, default=0.5, help="Value function系数")
|
||||||
|
parser.add_argument("--kl_coef", type=float, default=0.02, help="KL散度惩罚系数")
|
||||||
|
parser.add_argument("--gamma", type=float, default=1.0, help="GAE折扣因子")
|
||||||
|
parser.add_argument("--lam", type=float, default=0.95, help="GAE lambda参数")
|
||||||
|
parser.add_argument("--cliprange_value", type=float, default=0.2, help="Value function裁剪范围")
|
||||||
|
parser.add_argument("--ppo_update_iters", type=int, default=2, help="同一批rollout重复更新次数")
|
||||||
|
parser.add_argument("--early_stop_kl", type=float, default=0.25, help="PPO early stop 的 KL 阈值")
|
||||||
|
parser.add_argument("--mini_batch_size", type=int, default=2, help="PPO每次更新的minibatch大小")
|
||||||
|
parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练")
|
||||||
|
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
|
||||||
|
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
|
||||||
|
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||||
|
parser.add_argument("--wandb_project", type=str, default="MiniMind-PPO", help="wandb项目名")
|
||||||
|
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
|
||||||
|
parser.add_argument("--debug_mode", action="store_true", help="是否打印训练调试采样")
|
||||||
|
parser.add_argument("--debug_interval", type=int, default=20, help="debug模式下每隔多少step打印一次采样")
|
||||||
|
parser.add_argument("--thinking_ratio", type=float, default=0.9, help="按概率开启thinking(0.0~1.0)")
|
||||||
|
parser.add_argument("--rollout_engine", type=str, default="torch", choices=["torch", "sglang"], help="rollout引擎类型")
|
||||||
|
parser.add_argument("--sglang_base_url", type=str, default="http://localhost:8998", help="SGLang服务器URL")
|
||||||
|
parser.add_argument("--sglang_model_path", type=str, default="../model", help="SGLang tokenizer路径")
|
||||||
|
parser.add_argument("--sglang_shared_path", type=str, default="./sglang_ckpt_ppo", help="SGLang共享存储路径")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# ========== 1. 初始化环境和随机种子 ==========
|
||||||
|
local_rank = init_distributed_mode()
|
||||||
|
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||||
|
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||||
|
|
||||||
|
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||||
|
os.makedirs(args.save_dir, exist_ok=True)
|
||||||
|
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
|
||||||
|
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||||
|
|
||||||
|
# ========== 3. 设置混合精度 ==========
|
||||||
|
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||||
|
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||||
|
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||||
|
|
||||||
|
# ========== 4. 配wandb ==========
|
||||||
|
wandb = None
|
||||||
|
if args.use_wandb and is_main_process():
|
||||||
|
import swanlab as wandb
|
||||||
|
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||||
|
resume = 'must' if wandb_id else None
|
||||||
|
wandb_run_name = f"MiniMind-PPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
|
||||||
|
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||||
|
|
||||||
|
# ========== 5. 初始化模型和数据 ==========
|
||||||
|
base_weight = args.from_weight
|
||||||
|
# Actor模型
|
||||||
|
actor_model, tokenizer = init_model(lm_config, base_weight, device=args.device)
|
||||||
|
ref_model, _ = init_model(lm_config, base_weight, device=args.device)
|
||||||
|
ref_model = ref_model.eval().requires_grad_(False)
|
||||||
|
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||||
|
ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||||
|
state_dict = torch.load(ckp, map_location=args.device)
|
||||||
|
critic_model = CriticModel(lm_config)
|
||||||
|
critic_model.load_state_dict(state_dict, strict=False)
|
||||||
|
critic_model = critic_model.to(args.device)
|
||||||
|
reward_model = LMForRewardModel(args.reward_model_path, device=args.device, dtype=torch.float16)
|
||||||
|
# Rollout引擎
|
||||||
|
rollout_engine = create_rollout_engine(
|
||||||
|
engine_type=args.rollout_engine,
|
||||||
|
policy_model=actor_model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
device=args.device,
|
||||||
|
autocast_ctx=autocast_ctx,
|
||||||
|
sglang_base_url=args.sglang_base_url,
|
||||||
|
sglang_model_path=args.sglang_model_path,
|
||||||
|
sglang_shared_path=args.sglang_shared_path,
|
||||||
|
)
|
||||||
|
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=(args.max_seq_len + args.max_gen_len), thinking_ratio=args.thinking_ratio)
|
||||||
|
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||||
|
actor_optimizer = optim.AdamW(actor_model.parameters(), lr=args.learning_rate)
|
||||||
|
critic_optimizer = optim.AdamW(critic_model.parameters(), lr=args.critic_learning_rate)
|
||||||
|
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
|
||||||
|
iters = len(loader_for_count)
|
||||||
|
mb_factor = max(1, math.ceil(args.batch_size / args.mini_batch_size))
|
||||||
|
total_optimizer_steps = math.ceil(iters * args.epochs * args.ppo_update_iters * mb_factor / args.accumulation_steps)
|
||||||
|
actor_scheduler = CosineAnnealingLR(actor_optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
|
||||||
|
critic_scheduler = CosineAnnealingLR(critic_optimizer, T_max=total_optimizer_steps, eta_min=args.critic_learning_rate / 10)
|
||||||
|
|
||||||
|
start_epoch, start_step = 0, 0
|
||||||
|
if ckp_data:
|
||||||
|
actor_model.load_state_dict(ckp_data['model'])
|
||||||
|
critic_model.load_state_dict(ckp_data['critic_model'])
|
||||||
|
actor_optimizer.load_state_dict(ckp_data['optimizer'])
|
||||||
|
critic_optimizer.load_state_dict(ckp_data['critic_optimizer'])
|
||||||
|
actor_scheduler.load_state_dict(ckp_data['scheduler'])
|
||||||
|
critic_scheduler.load_state_dict(ckp_data['critic_scheduler'])
|
||||||
|
start_epoch = ckp_data['epoch']
|
||||||
|
start_step = ckp_data.get('step', 0)
|
||||||
|
|
||||||
|
# ========== 7. 编译和分布式包装 ==========
|
||||||
|
if args.use_compile == 1:
|
||||||
|
actor_model = torch.compile(actor_model)
|
||||||
|
Logger('torch.compile enabled')
|
||||||
|
rollout_engine.update_policy(actor_model)
|
||||||
|
if dist.is_initialized():
|
||||||
|
actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank])
|
||||||
|
critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank])
|
||||||
|
rollout_engine.update_policy(actor_model)
|
||||||
|
|
||||||
|
# ========== 8. 开始训练 ==========
|
||||||
|
for epoch in range(start_epoch, args.epochs):
|
||||||
|
train_sampler and train_sampler.set_epoch(epoch)
|
||||||
|
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||||
|
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||||
|
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||||
|
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||||
|
if skip > 0:
|
||||||
|
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||||
|
ppo_train_epoch(epoch, loader, len(loader) + skip, rollout_engine, ref_model, actor_scheduler, critic_scheduler, reward_model, start_step, wandb, use_sglang = (args.rollout_engine == "sglang"))
|
||||||
|
else:
|
||||||
|
ppo_train_epoch(epoch, loader, len(loader), rollout_engine, ref_model, actor_scheduler, critic_scheduler, reward_model, 0, wandb, use_sglang = (args.rollout_engine == "sglang"))
|
||||||
|
|
||||||
|
# ========== 9. 清理分布进程 ==========
|
||||||
|
if dist.is_initialized(): dist.destroy_process_group()
|
||||||
169
trainer/train_pretrain.py
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
__package__ = "trainer"
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from torch import optim, nn
|
||||||
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
|
from model.model_minimind import MiniMindConfig
|
||||||
|
from dataset.lm_dataset import PretrainDataset
|
||||||
|
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
|
||||||
|
|
||||||
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
|
|
||||||
|
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||||
|
start_time = time.time()
|
||||||
|
last_step = start_step
|
||||||
|
for step, (input_ids, labels) in enumerate(loader, start=start_step + 1):
|
||||||
|
input_ids = input_ids.to(args.device)
|
||||||
|
labels = labels.to(args.device)
|
||||||
|
last_step = step
|
||||||
|
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||||
|
for param_group in optimizer.param_groups:
|
||||||
|
param_group['lr'] = lr
|
||||||
|
|
||||||
|
with autocast_ctx:
|
||||||
|
res = model(input_ids, labels=labels)
|
||||||
|
loss = res.loss + res.aux_loss
|
||||||
|
loss = loss / args.accumulation_steps
|
||||||
|
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
|
if step % args.accumulation_steps == 0:
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||||
|
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
|
if step % args.log_interval == 0 or step == iters:
|
||||||
|
spend_time = time.time() - start_time
|
||||||
|
current_loss = loss.item() * args.accumulation_steps
|
||||||
|
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
|
||||||
|
current_logits_loss = current_loss - current_aux_loss
|
||||||
|
current_lr = optimizer.param_groups[-1]['lr']
|
||||||
|
eta_min = spend_time / max(step - start_step, 1) * (iters - step) // 60
|
||||||
|
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
|
||||||
|
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
|
||||||
|
|
||||||
|
if (step % args.save_interval == 0 or step == iters) and is_main_process():
|
||||||
|
model.eval()
|
||||||
|
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||||
|
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||||
|
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
|
||||||
|
raw_model = getattr(raw_model, '_orig_mod', raw_model)
|
||||||
|
state_dict = raw_model.state_dict()
|
||||||
|
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
|
||||||
|
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
|
||||||
|
model.train()
|
||||||
|
del state_dict
|
||||||
|
|
||||||
|
del input_ids, labels, res, loss
|
||||||
|
|
||||||
|
if last_step > start_step and last_step % args.accumulation_steps != 0:
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="MiniMind Pretraining")
|
||||||
|
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||||
|
parser.add_argument('--save_weight', default='pretrain', type=str, help="保存权重的前缀名")
|
||||||
|
parser.add_argument("--epochs", type=int, default=2, help="训练轮数")
|
||||||
|
parser.add_argument("--batch_size", type=int, default=32, help="batch size")
|
||||||
|
parser.add_argument("--learning_rate", type=float, default=5e-4, help="初始学习率")
|
||||||
|
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||||
|
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
|
||||||
|
parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
|
||||||
|
parser.add_argument("--accumulation_steps", type=int, default=8, help="梯度累积步数")
|
||||||
|
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
|
||||||
|
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
|
||||||
|
parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔")
|
||||||
|
parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度")
|
||||||
|
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||||
|
parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
|
||||||
|
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
|
||||||
|
parser.add_argument("--data_path", type=str, default="../dataset/pretrain_t2t_mini.jsonl", help="预训练数据路径")
|
||||||
|
parser.add_argument('--from_weight', default='none', type=str, help="基于哪个权重训练,为none则从头开始")
|
||||||
|
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
|
||||||
|
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||||
|
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb项目名")
|
||||||
|
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# ========== 1. 初始化环境和随机种子 ==========
|
||||||
|
local_rank = init_distributed_mode()
|
||||||
|
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||||
|
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||||
|
|
||||||
|
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||||
|
os.makedirs(args.save_dir, exist_ok=True)
|
||||||
|
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
|
||||||
|
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||||
|
|
||||||
|
# ========== 3. 设置混合精度 ==========
|
||||||
|
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||||
|
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||||
|
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||||
|
|
||||||
|
# ========== 4. 配wandb ==========
|
||||||
|
wandb = None
|
||||||
|
if args.use_wandb and is_main_process():
|
||||||
|
import swanlab as wandb
|
||||||
|
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||||
|
resume = 'must' if wandb_id else None
|
||||||
|
wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||||
|
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||||
|
|
||||||
|
# ========== 5. 定义模型、数据、优化器 ==========
|
||||||
|
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
|
||||||
|
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||||
|
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||||
|
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||||
|
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||||
|
|
||||||
|
# ========== 6. 从ckp恢复状态 ==========
|
||||||
|
start_epoch, start_step = 0, 0
|
||||||
|
if ckp_data:
|
||||||
|
model.load_state_dict(ckp_data['model'])
|
||||||
|
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||||
|
scaler.load_state_dict(ckp_data['scaler'])
|
||||||
|
start_epoch = ckp_data['epoch']
|
||||||
|
start_step = ckp_data.get('step', 0)
|
||||||
|
|
||||||
|
# ========== 7. 编译和分布式包装 ==========
|
||||||
|
if args.use_compile == 1:
|
||||||
|
model = torch.compile(model)
|
||||||
|
Logger('torch.compile enabled')
|
||||||
|
if dist.is_initialized():
|
||||||
|
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||||
|
|
||||||
|
# ========== 8. 开始训练 ==========
|
||||||
|
for epoch in range(start_epoch, args.epochs):
|
||||||
|
train_sampler and train_sampler.set_epoch(epoch)
|
||||||
|
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||||
|
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||||
|
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||||
|
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||||
|
if skip > 0:
|
||||||
|
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||||
|
train_epoch(epoch, loader, len(loader) + skip, start_step, wandb)
|
||||||
|
else:
|
||||||
|
train_epoch(epoch, loader, len(loader), 0, wandb)
|
||||||
|
|
||||||
|
# ========== 9. 清理分布进程 ==========
|
||||||
|
if dist.is_initialized(): dist.destroy_process_group()
|
||||||
168
trainer/train_tokenizer.py
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
# 注:不建议再重复训练tokenizer(“词典”),MiniMind已自带,此脚本仅供学习和参考。基于不同词典训练的模型将导致输出完全不统一,降低社区的模型复用性
|
||||||
|
# Note: It is not recommended to re-train the tokenizer. MiniMind already includes one. This script is for learning and reference only. Training models with different tokenizers will lead to inconsistent outputs and reduce model reusability in the community.
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from tokenizers import decoders, models, pre_tokenizers, trainers, Tokenizer
|
||||||
|
|
||||||
|
DATA_PATH = '../dataset/sft_t2t_mini.jsonl'
|
||||||
|
TOKENIZER_DIR = '../model_learn_tokenizer/'
|
||||||
|
VOCAB_SIZE = 6400
|
||||||
|
SPECIAL_TOKENS_NUM = 36
|
||||||
|
|
||||||
|
def get_texts(data_path):
|
||||||
|
with open(data_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||||
|
for i, line in enumerate(f):
|
||||||
|
if i >= 10000: break # 选10000行测试
|
||||||
|
try:
|
||||||
|
data = json.loads(line)
|
||||||
|
contents = [item.get('content') for item in data.get('conversations', []) if item.get('content')]
|
||||||
|
if contents:
|
||||||
|
yield "\n".join(contents)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
def train_tokenizer(data_path, tokenizer_dir, vocab_size, special_tokens_num=SPECIAL_TOKENS_NUM):
|
||||||
|
tokenizer = Tokenizer(models.BPE())
|
||||||
|
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
||||||
|
|
||||||
|
special_tokens_list = [
|
||||||
|
"<|endoftext|>", "<|im_start|>", "<|im_end|>",
|
||||||
|
"<|object_ref_start|>", "<|object_ref_end|>", "<|box_start|>", "<|box_end|>", "<|quad_start|>", "<|quad_end|>",
|
||||||
|
"<|vision_start|>", "<|vision_end|>", "<|vision_pad|>", "<|image_pad|>", "<|video_pad|>",
|
||||||
|
"<|audio_start|>", "<|audio_end|>", "<|audio_pad|>", "<tts_pad>", "<tts_text_bos>", "<tts_text_eod>", "<tts_text_bos_single>"
|
||||||
|
]
|
||||||
|
|
||||||
|
additional_tokens_list = [
|
||||||
|
"<tool_call>", "</tool_call>",
|
||||||
|
"<tool_response>", "</tool_response>",
|
||||||
|
"<think>", "</think>"
|
||||||
|
]
|
||||||
|
num_buffer = special_tokens_num - len(special_tokens_list + additional_tokens_list)
|
||||||
|
buffer_tokens = [f"<|buffer{i}|>" for i in range(1, num_buffer + 1)] # 预留一定数量的token位置
|
||||||
|
all_special_tokens = special_tokens_list + additional_tokens_list + buffer_tokens
|
||||||
|
trainer = trainers.BpeTrainer(
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
show_progress=True,
|
||||||
|
initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
|
||||||
|
special_tokens=all_special_tokens
|
||||||
|
)
|
||||||
|
texts = get_texts(data_path)
|
||||||
|
tokenizer.train_from_iterator(texts, trainer=trainer)
|
||||||
|
tokenizer.decoder = decoders.ByteLevel()
|
||||||
|
tokenizer.add_special_tokens(special_tokens_list)
|
||||||
|
|
||||||
|
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||||
|
tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||||
|
tokenizer.model.save(tokenizer_dir)
|
||||||
|
tokenizer_json_path = os.path.join(tokenizer_dir, "tokenizer.json")
|
||||||
|
with open(tokenizer_json_path, 'r', encoding='utf-8') as f:
|
||||||
|
tokenizer_data = json.load(f)
|
||||||
|
for token_info in tokenizer_data.get('added_tokens', []):
|
||||||
|
if token_info['content'] not in special_tokens_list:
|
||||||
|
token_info['special'] = False
|
||||||
|
with open(tokenizer_json_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(tokenizer_data, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
added_tokens_decoder = {}
|
||||||
|
for i, token in enumerate(all_special_tokens):
|
||||||
|
idx = tokenizer.token_to_id(token)
|
||||||
|
added_tokens_decoder[str(idx)] = {
|
||||||
|
"content": token,
|
||||||
|
"lstrip": False,
|
||||||
|
"normalized": False,
|
||||||
|
"rstrip": False,
|
||||||
|
"single_word": False,
|
||||||
|
"special": True if token in special_tokens_list else False
|
||||||
|
}
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"add_bos_token": False,
|
||||||
|
"add_eos_token": False,
|
||||||
|
"add_prefix_space": False,
|
||||||
|
"added_tokens_decoder": added_tokens_decoder,
|
||||||
|
"additional_special_tokens": [t for t in special_tokens_list if t not in ["<|endoftext|>"]],
|
||||||
|
"bos_token": "<|im_start|>",
|
||||||
|
"clean_up_tokenization_spaces": False,
|
||||||
|
"eos_token": "<|im_end|>",
|
||||||
|
"legacy": True,
|
||||||
|
"model_max_length": 131072,
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
"sp_model_kwargs": {},
|
||||||
|
"spaces_between_special_tokens": False,
|
||||||
|
"unk_token": "<|endoftext|>",
|
||||||
|
"image_token": "<|image_pad|>",
|
||||||
|
"audio_token": "<|audio_pad|>",
|
||||||
|
"video_token": "<|video_pad|>",
|
||||||
|
"vision_bos_token": "<|vision_start|>",
|
||||||
|
"vision_eos_token": "<|vision_end|>",
|
||||||
|
"audio_bos_token": "<|audio_start|>",
|
||||||
|
"audio_eos_token": "<|audio_end|>",
|
||||||
|
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if true %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if open_thinking is defined and open_thinking is true %}\n {{- '<think>\\n' }}\n {%- else %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}",
|
||||||
|
"tokenizer_class": "PreTrainedTokenizerFast"
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w", encoding="utf-8") as f:
|
||||||
|
json.dump(config, f, ensure_ascii=False, indent=4)
|
||||||
|
print("Tokenizer training completed.")
|
||||||
|
|
||||||
|
def eval_tokenizer(tokenizer_dir):
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "你是一个优秀的聊天机器人,总是给我正确的回应!"},
|
||||||
|
{"role": "user", "content": '你来自哪里?'},
|
||||||
|
{"role": "assistant", "content": '我来自月球'},
|
||||||
|
{"role": "user", "content": '你到底来自哪里?'},
|
||||||
|
{"role": "assistant", "content": '我来自地球'}
|
||||||
|
]
|
||||||
|
new_prompt = tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=False
|
||||||
|
)
|
||||||
|
print('-'*100)
|
||||||
|
print(new_prompt)
|
||||||
|
print('-'*100)
|
||||||
|
print('tokenizer词表长度:', len(tokenizer))
|
||||||
|
model_inputs = tokenizer(new_prompt)
|
||||||
|
print('encoder长度:', len(model_inputs['input_ids']))
|
||||||
|
response = tokenizer.decode(model_inputs['input_ids'], skip_special_tokens=False)
|
||||||
|
print('decoder一致性:', response == new_prompt, "\n")
|
||||||
|
print('-'*100)
|
||||||
|
print('压缩率测试(Chars/Tokens):')
|
||||||
|
test_texts = [
|
||||||
|
# 中文样本 (约200字)
|
||||||
|
"人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器,该领域的研究包括机器人、语言识别、图像识别、自然语言处理和专家系统等。人工智能从诞生以来,理论和技术日益成熟,应用领域也不断扩大,可以设想,未来人工智能带来的科技产品,将会是人类智慧的“容器”。人工智能可以对人的意识、思维的信息过程的模拟。人工智能不是人的智能,但能像人那样思考、也可能超过人的智能。",
|
||||||
|
"星际航行是指在星系内甚至星系间的空间中进行的航行。由于宇宙空间极其广阔,传统的化学火箭动力在恒星间航行时显得力不从心。科学家们提出了多种方案,包括离子推进器、核热火箭、甚至是利用反物质作为能源的设想。此外,曲率驱动和虫洞旅行等科幻概念也在理论物理研究中被反复探讨。尽管目前人类的足迹仅限于月球,但随着核聚变技术和材料科学的突破,前往火星乃至更遥远的太阳系边缘将成为可能。",
|
||||||
|
# 英文样本 (约200词/字符)
|
||||||
|
"Large language models (LLMs) are a type of artificial intelligence (AI) trained on vast amounts of text data to understand and generate human-like language. These models use deep learning techniques, specifically transformers, to process and predict the next word in a sequence. LLMs like GPT-4, Llama, and Claude have demonstrated remarkable capabilities in coding, translation, and creative writing. However, they also face challenges such as hallucinations, where the model generates factually incorrect information, and the need for significant computational resources.",
|
||||||
|
"The development of sustainable energy is crucial for the future of our planet. As climate change continues to impact global weather patterns, transitioning from fossil fuels to renewable sources like solar, wind, and hydroelectric power has become an urgent priority. Innovations in battery storage technology and smart grid management are essential to ensure a reliable energy supply. International cooperation and policy frameworks are also necessary to drive the global shift towards a greener economy and reduce carbon emissions.",
|
||||||
|
# 混合样本
|
||||||
|
"Python 是一种高级编程语言,以其简洁的语法和强大的生态系统而闻名。It is widely used in data science, machine learning, and web development. 开发者可以利用 NumPy, Pandas, and PyTorch 等库快速构建复杂的应用。学习 Python 的过程非常愉快,因为它的代码读起来就像英语一样。Whether you are a beginner or an expert, Python offers something for everyone.",
|
||||||
|
]
|
||||||
|
|
||||||
|
total_compression = 0
|
||||||
|
for i, text in enumerate(test_texts):
|
||||||
|
encoded = tokenizer.encode(text)
|
||||||
|
token_count = len(encoded)
|
||||||
|
char_count = len(text)
|
||||||
|
compression_ratio = char_count / token_count
|
||||||
|
total_compression += compression_ratio
|
||||||
|
print(f"样本 {i+1} | 字符数: {char_count:4} | Tokens: {token_count:3} | 压缩率: {compression_ratio:.2f}")
|
||||||
|
|
||||||
|
print(f"平均压缩率: {total_compression / len(test_texts):.2f}")
|
||||||
|
print('-'*100)
|
||||||
|
print('流式解码(字节缓冲)测试:')
|
||||||
|
input_ids = model_inputs['input_ids']
|
||||||
|
token_cache = []
|
||||||
|
for tid in input_ids:
|
||||||
|
token_cache.append(tid)
|
||||||
|
current_decode = tokenizer.decode(token_cache)
|
||||||
|
if current_decode and '\ufffd' not in current_decode:
|
||||||
|
display_ids = token_cache[0] if len(token_cache) == 1 else token_cache
|
||||||
|
raw_tokens = [tokenizer.convert_ids_to_tokens(int(t)) for t in (token_cache if isinstance(token_cache, list) else [token_cache])]
|
||||||
|
print(f'Token ID: {str(display_ids):15} -> Raw: {str(raw_tokens):20} -> Decode Str: {current_decode}')
|
||||||
|
token_cache = []
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
train_tokenizer(DATA_PATH, TOKENIZER_DIR, VOCAB_SIZE)
|
||||||
|
eval_tokenizer(TOKENIZER_DIR)
|
||||||
177
trainer/trainer_utils.py
Normal file
@ -0,0 +1,177 @@
|
|||||||
|
"""
|
||||||
|
训练工具函数集合
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
__package__ = "trainer"
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
import random
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
|
from torch.utils.data import Sampler
|
||||||
|
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
|
||||||
|
from model.model_minimind import MiniMindForCausalLM
|
||||||
|
|
||||||
|
def get_model_params(model, config):
|
||||||
|
total = sum(p.numel() for p in model.parameters()) / 1e6
|
||||||
|
n_routed = getattr(config, 'n_routed_experts', getattr(config, 'num_experts', 0))
|
||||||
|
n_active = getattr(config, 'num_experts_per_tok', 0)
|
||||||
|
n_shared = getattr(config, 'n_shared_experts', 0)
|
||||||
|
expert = sum(p.numel() for n, p in model.named_parameters() if 'mlp.experts.0.' in n) / 1e6
|
||||||
|
shared_expert = sum(p.numel() for n, p in model.named_parameters() if 'mlp.shared_experts.0.' in n) / 1e6
|
||||||
|
base = total - (expert * n_routed) - (shared_expert * n_shared)
|
||||||
|
active = base + (expert * n_active) + (shared_expert * n_shared)
|
||||||
|
if active < total: Logger(f'Model Params: {total:.2f}M-A{active:.2f}M')
|
||||||
|
else: Logger(f'Model Params: {total:.2f}M')
|
||||||
|
|
||||||
|
|
||||||
|
def is_main_process():
|
||||||
|
return not dist.is_initialized() or dist.get_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def Logger(content):
|
||||||
|
if is_main_process():
|
||||||
|
print(content)
|
||||||
|
|
||||||
|
|
||||||
|
def get_lr(current_step, total_steps, lr):
|
||||||
|
return lr*(0.1 + 0.45*(1 + math.cos(math.pi * current_step / total_steps)))
|
||||||
|
|
||||||
|
|
||||||
|
def init_distributed_mode():
|
||||||
|
if int(os.environ.get("RANK", -1)) == -1:
|
||||||
|
return 0 # 非DDP模式
|
||||||
|
|
||||||
|
dist.init_process_group(backend="nccl")
|
||||||
|
local_rank = int(os.environ["LOCAL_RANK"])
|
||||||
|
torch.cuda.set_device(local_rank)
|
||||||
|
return local_rank
|
||||||
|
|
||||||
|
|
||||||
|
def setup_seed(seed: int):
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
|
||||||
|
def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoch=0, step=0, wandb=None, save_dir='../checkpoints', **kwargs):
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
moe_path = '_moe' if lm_config.use_moe else ''
|
||||||
|
ckp_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}.pth'
|
||||||
|
resume_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}_resume.pth'
|
||||||
|
|
||||||
|
if model is not None:
|
||||||
|
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
|
||||||
|
raw_model = getattr(raw_model, '_orig_mod', raw_model)
|
||||||
|
state_dict = raw_model.state_dict()
|
||||||
|
state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
|
||||||
|
ckp_tmp = ckp_path + '.tmp'
|
||||||
|
torch.save(state_dict, ckp_tmp)
|
||||||
|
os.replace(ckp_tmp, ckp_path)
|
||||||
|
wandb_id = None
|
||||||
|
if wandb:
|
||||||
|
if hasattr(wandb, 'get_run'):
|
||||||
|
run = wandb.get_run()
|
||||||
|
wandb_id = getattr(run, 'id', None) if run else None
|
||||||
|
else:
|
||||||
|
wandb_id = getattr(wandb, 'id', None)
|
||||||
|
|
||||||
|
resume_data = {
|
||||||
|
'model': state_dict,
|
||||||
|
'optimizer': optimizer.state_dict(),
|
||||||
|
'epoch': epoch,
|
||||||
|
'step': step,
|
||||||
|
'world_size': dist.get_world_size() if dist.is_initialized() else 1,
|
||||||
|
'wandb_id': wandb_id
|
||||||
|
}
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if value is not None:
|
||||||
|
if hasattr(value, 'state_dict'):
|
||||||
|
raw_value = value.module if isinstance(value, DistributedDataParallel) else value
|
||||||
|
raw_value = getattr(raw_value, '_orig_mod', raw_value)
|
||||||
|
resume_data[key] = raw_value.state_dict()
|
||||||
|
else:
|
||||||
|
resume_data[key] = value
|
||||||
|
|
||||||
|
resume_tmp = resume_path + '.tmp'
|
||||||
|
torch.save(resume_data, resume_tmp)
|
||||||
|
os.replace(resume_tmp, resume_path)
|
||||||
|
del state_dict, resume_data
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
else: # 加载模式
|
||||||
|
if os.path.exists(resume_path):
|
||||||
|
ckp_data = torch.load(resume_path, map_location='cpu')
|
||||||
|
saved_ws = ckp_data.get('world_size', 1)
|
||||||
|
current_ws = dist.get_world_size() if dist.is_initialized() else 1
|
||||||
|
if saved_ws != current_ws:
|
||||||
|
ckp_data['step'] = ckp_data['step'] * saved_ws // current_ws
|
||||||
|
Logger(f'GPU数量变化({saved_ws}→{current_ws}),step已自动转换为{ckp_data["step"]}')
|
||||||
|
return ckp_data
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def init_model(lm_config, from_weight='pretrain', tokenizer_path='../model', save_dir='../out', device='cuda'):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||||
|
model = MiniMindForCausalLM(lm_config)
|
||||||
|
|
||||||
|
if from_weight!= 'none':
|
||||||
|
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||||
|
weight_path = f'{save_dir}/{from_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||||
|
weights = torch.load(weight_path, map_location=device)
|
||||||
|
model.load_state_dict(weights, strict=False)
|
||||||
|
|
||||||
|
get_model_params(model, lm_config)
|
||||||
|
Logger(f'Trainable Params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f}M')
|
||||||
|
return model.to(device), tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class SkipBatchSampler(Sampler):
|
||||||
|
def __init__(self, sampler, batch_size, skip_batches=0):
|
||||||
|
self.sampler = sampler
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.skip_batches = skip_batches
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
batch = []
|
||||||
|
skipped = 0
|
||||||
|
for idx in self.sampler:
|
||||||
|
batch.append(idx)
|
||||||
|
if len(batch) == self.batch_size:
|
||||||
|
if skipped < self.skip_batches:
|
||||||
|
skipped += 1
|
||||||
|
batch = []
|
||||||
|
continue
|
||||||
|
yield batch
|
||||||
|
batch = []
|
||||||
|
if len(batch) > 0 and skipped >= self.skip_batches:
|
||||||
|
yield batch
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
total_batches = (len(self.sampler) + self.batch_size - 1) // self.batch_size
|
||||||
|
return max(0, total_batches - self.skip_batches)
|
||||||
|
|
||||||
|
|
||||||
|
class LMForRewardModel:
|
||||||
|
def __init__(self, model_path, device="cuda", dtype=torch.float16):
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
self.model = AutoModel.from_pretrained(model_path, torch_dtype=dtype, trust_remote_code=True)
|
||||||
|
self.model = self.model.to(device).eval()
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_score(self, messages, response):
|
||||||
|
history_text = "\n".join([f"{m['role']}: {m['content']}" for m in messages[:-1]])
|
||||||
|
last_query = messages[-1]['content'] if messages else ""
|
||||||
|
message_context = f"{history_text}\n以上是对话历史。我的新问题是:\n{last_query}" if history_text else last_query
|
||||||
|
eval_messages = [
|
||||||
|
{"role": "user", "content": message_context},
|
||||||
|
{"role": "assistant", "content": response}
|
||||||
|
]
|
||||||
|
score = self.model.get_score(self.tokenizer, eval_messages)
|
||||||
|
return max(min(score, 3.0), -3.0)
|
||||||