Mirror of https://github.com/jingyaogong/minimind.git, synced 2026-05-01 11:48:14 +08:00
update app

commit 36159fb2ab (parent 561979c7e3)
eval_model.py → eval_llm.py (47 changes; Normal file → Executable file)
@@ -13,13 +13,14 @@ def init_model(args):
     tokenizer = AutoTokenizer.from_pretrained('./model/')
     if args.load == 0:
         moe_path = '_moe' if args.use_moe else ''
-        modes = {0: 'pretrain', 1: 'full_sft', 2: 'rlhf', 3: 'reason', 4: 'grpo'}
+        modes = {0: 'pretrain', 1: 'full_sft', 2: 'rlhf', 3: 'reason', 4: 'ppo_actor', 5: 'grpo'}
         ckp = f'./{args.out_dir}/{modes[args.model_mode]}_{args.hidden_size}{moe_path}.pth'

         model = MiniMindForCausalLM(MiniMindConfig(
             hidden_size=args.hidden_size,
             num_hidden_layers=args.num_hidden_layers,
-            use_moe=args.use_moe
+            use_moe=args.use_moe,
+            inference_rope_scaling=args.inference_rope_scaling
         ))

         model.load_state_dict(torch.load(ckp, map_location=args.device), strict=True)
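As a quick orientation, the new modes dict keys the checkpoint filename off --model_mode. A standalone sketch of the resulting paths, assuming illustrative values out_dir='out', hidden_size=512 and a non-MoE model (none of these values are shown in this hunk):

    modes = {0: 'pretrain', 1: 'full_sft', 2: 'rlhf', 3: 'reason', 4: 'ppo_actor', 5: 'grpo'}
    out_dir, hidden_size, moe_path = 'out', 512, ''
    for mode, name in modes.items():
        print(mode, f'./{out_dir}/{name}_{hidden_size}{moe_path}.pth')
    # e.g. model_mode=5 resolves to ./out/grpo_512.pth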
@@ -28,7 +29,7 @@ def init_model(args):
         apply_lora(model)
         load_lora(model, f'./{args.out_dir}/lora/{args.lora_name}_{args.hidden_size}.pth')
     else:
-        transformers_model_path = './MiniMind2'
+        transformers_model_path = './MiniMind2-MoE'
         tokenizer = AutoTokenizer.from_pretrained(transformers_model_path)
         model = AutoModelForCausalLM.from_pretrained(transformers_model_path, trust_remote_code=True)
     print(f'MiniMind模型参数量: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f}M(illion)')
@@ -48,8 +49,8 @@ def get_prompt_datas(args):
             '杭州市的美食有'
         ]
     else:
-        # General dialogue questions for non-LoRA models
         if args.lora_name == 'None':
+            # General dialogue questions
             prompt_datas = [
                 '请介绍一下自己。',
                 '你更擅长哪一个学科?',
@@ -62,7 +63,7 @@ def get_prompt_datas(args):
                 'Introduce the history of the United States, please.'
             ]
         else:
-            # Domain-specific questions
+            # Domain-specific questions for LoRA fine-tuned models
             lora_prompt_datas = {
                 'lora_identity': [
                     "你是ChatGPT吧。",
@@ -111,13 +112,14 @@ def main():
     parser.add_argument('--num_hidden_layers', default=8, type=int)
     parser.add_argument('--max_seq_len', default=8192, type=int)
     parser.add_argument('--use_moe', default=False, type=bool)
-    # Number of history turns carried as context
-    # history_cnt must be even, i.e. one [user question, model answer] pair per group; when 0, the current query carries no history
-    # Without extrapolation fine-tuning, performance inevitably degrades noticeably on longer chat_template contexts, so mind this setting
+    parser.add_argument('--model_mode', default=5, type=int, help="0: pretrain model, 1: SFT-Chat model, 2: RLHF-Chat model, 3: Reason model, 4: RLAIF-Chat model, 6: Funcall-Chat model")
+    # Enable length extrapolation, 4x by default (note: this only addresses positional-encoding extrapolation; it does not mean the model truly handles long text)
+    parser.add_argument('--inference_rope_scaling', default=False, action='store_true')
+    # history_cnt, the number of history turns carried as context, must be even, i.e. one [user question, model answer] pair per group; when 0, the current query carries no history
+    # Without multi-turn fine-tuning, capability inevitably degrades noticeably over long multi-turn contexts, so mind this setting
     parser.add_argument('--history_cnt', default=0, type=int)
-    parser.add_argument('--load', default=0, type=int, help="0: native torch weights, 1: load via transformers")
-    parser.add_argument('--model_mode', default=1, type=int,
-                        help="0: pretrain model, 1: SFT-Chat model, 2: RLHF-Chat model, 3: Reason model, 4: RLAIF-Chat model")
+    # When load is 1, the hidden_size, num_hidden_layers, max_seq_len arguments above are ignored; the loaded transformers model's config.json takes precedence
+    parser.add_argument('--load', default=1, type=int, help="0: native torch weights, 1: load via transformers")
     args = parser.parse_args()

     model, tokenizer = init_model(args)
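The history_cnt comments above are the heart of that flag: history is kept in whole [user question, model answer] pairs, which is why the value must be even, and 0 disables history entirely. A minimal standalone sketch of the slicing semantics used later in main() (illustrative code, not part of this diff):

    messages = [
        {"role": "user", "content": "q1"}, {"role": "assistant", "content": "a1"},
        {"role": "user", "content": "q2"}, {"role": "assistant", "content": "a2"},
        {"role": "user", "content": "q3"}, {"role": "assistant", "content": "a3"},
    ]
    history_cnt = 4  # keep the last 2 [user, assistant] pairs; 0 would keep nothing
    kept = messages[-history_cnt:] if history_cnt else []
    print([m["content"] for m in kept])  # ['q2', 'a2', 'q3', 'a3']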
@@ -128,18 +130,27 @@ def main():

     messages = []
     for idx, prompt in enumerate(prompts if test_mode == 0 else iter(lambda: input('👶: '), '')):
-        setup_seed(random.randint(0, 2048))
-        # setup_seed(2025)  # switch to a [fixed] random seed for identical output every run
+        # setup_seed(random.randint(0, 2048))
+        setup_seed(2026)  # switch to a [fixed] random seed for identical output every run
         if test_mode == 0: print(f'👶: {prompt}')

         messages = messages[-args.history_cnt:] if args.history_cnt else []
         messages.append({"role": "user", "content": prompt})

-        new_prompt = tokenizer.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=True
-        ) if args.model_mode != 0 else (tokenizer.bos_token + prompt)
+        # 1. Pretrain: pure text-continuation model
+        if args.model_mode == 0:
+            new_prompt = tokenizer.bos_token + prompt
+        # 2. SFT/RL: chat models
+        else:
+            template_args = {
+                "conversation": messages,
+                "tokenize": False,
+                "add_generation_prompt": True
+            }
+            # Only usable with the Reason model; non-thinking models must not be given this argument
+            if args.model_mode == 3:
+                template_args["enable_thinking"] = True  # False disables think
+            new_prompt = tokenizer.apply_chat_template(**template_args)

         inputs = tokenizer(
             new_prompt,
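To make the new branching concrete, here is a self-contained sketch of the same prompt assembly against a Hugging Face tokenizer. It mirrors the diff's logic; note that enable_thinking is simply forwarded to the chat template as an extra keyword, so it only has an effect if the template defines that switch:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('./model/')
    messages = [{"role": "user", "content": "你好"}]
    model_mode = 3  # the Reason model in this diff's numbering

    if model_mode == 0:
        # Pretrain checkpoints just continue raw text after the BOS token.
        new_prompt = tokenizer.bos_token + messages[-1]["content"]
    else:
        template_args = {"conversation": messages, "tokenize": False,
                         "add_generation_prompt": True}
        if model_mode == 3:
            template_args["enable_thinking"] = True
        new_prompt = tokenizer.apply_chat_template(**template_args)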
others/HF-Space/app.py (new file, 320 lines)

@@ -0,0 +1,320 @@
import random
import re
from threading import Thread

import torch
import numpy as np
import streamlit as st

st.set_page_config(page_title="MiniMind", initial_sidebar_state="collapsed")

st.markdown("""
<style>
    /* Adjust main container padding to avoid overflow */
    .stMainBlockContainer {
        padding-top: 1rem !important;
        padding-bottom: 1rem !important;
    }

    /* Action button styling */
    .stButton > button {
        box-sizing: border-box !important;
        border-radius: 50% !important;
        width: 24px !important;
        height: 24px !important;
        min-width: 24px !important;
        min-height: 24px !important;
        max-width: 24px !important;
        max-height: 24px !important;
        padding: 0 !important;
        background-color: transparent !important;
        border: 1px solid #ddd !important;
        display: flex !important;
        align-items: center !important;
        justify-content: center !important;
        font-size: 14px !important;
        color: #888 !important;
        cursor: pointer !important;
        transition: all 0.2s ease !important;
        margin: 2px !important;
    }

    .stButton > button:hover {
        border-color: #999 !important;
        color: #333 !important;
        background-color: #f5f5f5 !important;
    }

    /* Keep the chat container from overflowing */
    .stChatMessage {
        max-width: 100% !important;
    }

    /* Sidebar style tweaks */
    section[data-testid="stSidebar"] {
        overflow-y: auto !important;
    }
</style>
""", unsafe_allow_html=True)

system_prompt = []
device = "cuda" if torch.cuda.is_available() else "cpu"


def process_assistant_content(content):
    # Only R1-style (reasoning) models emit <think> blocks worth rewriting.
    if model_source == "API" and 'R1' not in api_model_name:
        return content
    if model_source != "API" and 'R1' not in MODEL_PATHS[selected_model][1]:
        return content

    if '<think>' in content and '</think>' in content:
        content = re.sub(r'(<think>)(.*?)(</think>)',
                         r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">推理内容(展开)</summary>\2</details>',
                         content,
                         flags=re.DOTALL)

    if '<think>' in content and '</think>' not in content:
        content = re.sub(r'<think>(.*?)$',
                         r'<details open style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">推理中...</summary>\1</details>',
                         content,
                         flags=re.DOTALL)

    if '<think>' not in content and '</think>' in content:
        content = re.sub(r'(.*?)</think>',
                         r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">推理内容(展开)</summary>\1</details>',
                         content,
                         flags=re.DOTALL)

    return content
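# Illustration only, with hypothetical strings: what the three branches above produce.
#   '<think>step 1</think>answer'  -> collapsed panel: <details>...<summary>推理内容(展开)</summary>step 1</details>answer
#   '<think>partial reasoning'     -> an open <details> panel titled 推理中... while the stream is still inside the think block
#   'tail of reasoning</think>ok'  -> everything before </think> wrapped as the collapsed panel, followed by 'ok'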
@st.cache_resource
def load_model_tokenizer(model_path):
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True
    )
    model = model.eval().to(device)
    return model, tokenizer


def clear_chat_messages():
    del st.session_state.messages
    del st.session_state.chat_messages


def init_chat_messages():
    if "messages" in st.session_state:
        for i, message in enumerate(st.session_state.messages):
            if message["role"] == "assistant":
                with st.chat_message("assistant", avatar=image_url):
                    st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
                    if st.button("🗑", key=f"delete_{i}"):
                        st.session_state.messages.pop(i)
                        st.session_state.messages.pop(i - 1)
                        st.session_state.chat_messages.pop(i)
                        st.session_state.chat_messages.pop(i - 1)
                        st.rerun()
            else:
                st.markdown(
                    f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: #ddd; border-radius: 10px; color: black;">{message["content"]}</div></div>',
                    unsafe_allow_html=True)

    else:
        st.session_state.messages = []
        st.session_state.chat_messages = []

    return st.session_state.messages


def regenerate_answer(index):
    st.session_state.messages.pop()
    st.session_state.chat_messages.pop()
    st.rerun()


def delete_conversation(index):
    st.session_state.messages.pop(index)
    st.session_state.messages.pop(index - 1)
    st.session_state.chat_messages.pop(index)
    st.session_state.chat_messages.pop(index - 1)
    st.rerun()

st.sidebar.title("模型设定调整")

# st.sidebar.text("Training-data bias: with more context memory,\nmulti-turn (vs. single-turn) chat degrades more easily")
st.session_state.history_chat_num = st.sidebar.slider("Number of Historical Dialogues", 0, 6, 0, step=2)
# st.session_state.history_chat_num = 0
st.session_state.max_new_tokens = st.sidebar.slider("Max Sequence Length", 256, 8192, 8192, step=1)
st.session_state.temperature = st.sidebar.slider("Temperature", 0.6, 1.2, 0.85, step=0.01)

model_source = st.sidebar.radio("选择模型来源", ["本地模型", "API"], index=0)

if model_source == "API":
    api_url = st.sidebar.text_input("API URL", value="http://127.0.0.1:8000/v1")
    api_model_id = st.sidebar.text_input("Model ID", value="minimind")
    api_model_name = st.sidebar.text_input("Model Name", value="MiniMind2")
    api_key = st.sidebar.text_input("API Key", value="none", type="password")
    slogan = f"Hi, I'm {api_model_name}"
else:
    MODEL_PATHS = {
        "MiniMind2 (0.1B)": ["./MiniMind2", "MiniMind2"],
        "MiniMind2-MoE (0.15B)": ["./MiniMind2-MoE", "MiniMind2-MoE"],
        "MiniMind2-Small (0.02B)": ["./MiniMind2-Small", "MiniMind2-Small"]
    }

    selected_model = st.sidebar.selectbox('Models', list(MODEL_PATHS.keys()), index=0)  # MiniMind2 selected by default
    model_path = MODEL_PATHS[selected_model][0]
    slogan = f"Hi, I'm {MODEL_PATHS[selected_model][1]}"

image_url = "https://www.modelscope.cn/api/v1/studio/gongjy/MiniMind/repo?Revision=master&FilePath=images%2Flogo2.png&View=true"

st.markdown(
    f'<div style="display: flex; flex-direction: column; align-items: center; text-align: center; margin: 0; padding: 0;">'
    '<div style="font-style: italic; font-weight: 900; margin: 0; padding-top: 4px; display: flex; align-items: center; justify-content: center; flex-wrap: wrap; width: 100%;">'
    f'<img src="{image_url}" style="width: 45px; height: 45px; "> '
    f'<span style="font-size: 26px; margin-left: 10px;">{slogan}</span>'
    '</div>'
    '<span style="color: #bbb; font-style: italic; margin-top: 6px; margin-bottom: 10px;">内容完全由AI生成,请务必仔细甄别<br>Content AI-generated, please discern with care</span>'
    '</div>',
    unsafe_allow_html=True
)

def setup_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def main():
    if model_source == "本地模型":
        model, tokenizer = load_model_tokenizer(model_path)
    else:
        model, tokenizer = None, None

    if "messages" not in st.session_state:
        st.session_state.messages = []
        st.session_state.chat_messages = []

    messages = st.session_state.messages

    for i, message in enumerate(messages):
        if message["role"] == "assistant":
            with st.chat_message("assistant", avatar=image_url):
                st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
                if st.button("×", key=f"delete_{i}"):
                    st.session_state.messages = st.session_state.messages[:i - 1]
                    st.session_state.chat_messages = st.session_state.chat_messages[:i - 1]
                    st.rerun()
        else:
            st.markdown(
                f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: gray; border-radius: 10px; color:white; ">{message["content"]}</div></div>',
                unsafe_allow_html=True)

    prompt = st.chat_input(key="input", placeholder="给 MiniMind 发送消息")

    if hasattr(st.session_state, 'regenerate') and st.session_state.regenerate:
        prompt = st.session_state.last_user_message
        regenerate_index = st.session_state.regenerate_index
        delattr(st.session_state, 'regenerate')
        delattr(st.session_state, 'last_user_message')
        delattr(st.session_state, 'regenerate_index')

    if prompt:
        st.markdown(
            f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: gray; border-radius: 10px; color:white; ">{prompt}</div></div>',
            unsafe_allow_html=True)
        messages.append({"role": "user", "content": prompt[-st.session_state.max_new_tokens:]})
        st.session_state.chat_messages.append({"role": "user", "content": prompt[-st.session_state.max_new_tokens:]})

        with st.chat_message("assistant", avatar=image_url):
            placeholder = st.empty()

            if model_source == "API":
                try:
                    from openai import OpenAI

                    client = OpenAI(
                        api_key=api_key,
                        base_url=api_url
                    )
                    history_num = st.session_state.history_chat_num + 1  # +1 so the current user message is included
                    conversation_history = system_prompt + st.session_state.chat_messages[-history_num:]
                    answer = ""
                    response = client.chat.completions.create(
                        model=api_model_id,
                        messages=conversation_history,
                        stream=True,
                        temperature=st.session_state.temperature
                    )

                    for chunk in response:
                        content = chunk.choices[0].delta.content or ""
                        answer += content
                        placeholder.markdown(process_assistant_content(answer), unsafe_allow_html=True)

                except Exception as e:
                    answer = f"API调用出错: {str(e)}"
                    placeholder.markdown(answer, unsafe_allow_html=True)
            else:
                random_seed = random.randint(0, 2 ** 32 - 1)
                setup_seed(random_seed)

                st.session_state.chat_messages = system_prompt + st.session_state.chat_messages[-(st.session_state.history_chat_num + 1):]
                new_prompt = tokenizer.apply_chat_template(
                    st.session_state.chat_messages,
                    tokenize=False,
                    add_generation_prompt=True
                )

                inputs = tokenizer(
                    new_prompt,
                    return_tensors="pt",
                    truncation=True
                ).to(device)

                streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
                generation_kwargs = {
                    "input_ids": inputs.input_ids,
                    "max_length": inputs.input_ids.shape[1] + st.session_state.max_new_tokens,
                    "num_return_sequences": 1,
                    "do_sample": True,
                    "attention_mask": inputs.attention_mask,
                    "pad_token_id": tokenizer.pad_token_id,
                    "eos_token_id": tokenizer.eos_token_id,
                    "temperature": st.session_state.temperature,
                    "top_p": 0.85,
                    "streamer": streamer,
                }

                # Run generation on a background thread; TextIteratorStreamer yields decoded text chunks as they arrive.
                Thread(target=model.generate, kwargs=generation_kwargs).start()

                answer = ""
                for new_text in streamer:
                    answer += new_text
                    placeholder.markdown(process_assistant_content(answer), unsafe_allow_html=True)

        messages.append({"role": "assistant", "content": answer})
        st.session_state.chat_messages.append({"role": "assistant", "content": answer})
        with st.empty():
            if st.button("×", key=f"delete_{len(messages) - 1}"):
                st.session_state.messages = st.session_state.messages[:-2]
                st.session_state.chat_messages = st.session_state.chat_messages[:-2]
                st.rerun()


if __name__ == "__main__":
    # Deferred import: these names are resolved as module globals when the functions above run.
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    main()
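Assuming the usual Streamlit workflow (this diff does not show a launch command), the Space entrypoint would be started with "streamlit run others/HF-Space/app.py". In API mode the sidebar URL can point at any OpenAI-compatible endpoint, such as a locally served model at the default http://127.0.0.1:8000/v1.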