mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-06-06 00:04:50 +00:00
421 lines
22 KiB
Python
421 lines
22 KiB
Python
import random
|
|
import re
|
|
import json
|
|
import os
|
|
from threading import Thread
|
|
|
|
import torch
|
|
import numpy as np
|
|
import streamlit as st
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
|
|
|
st.set_page_config(page_title="MiniMind", initial_sidebar_state="collapsed")
|
|
|
|
st.markdown("""
|
|
<style>
|
|
/* 添加操作按钮样式 */
|
|
.stButton button {
|
|
border-radius: 50% !important; /* 改为圆形 */
|
|
width: 32px !important; /* 固定宽度 */
|
|
height: 32px !important; /* 固定高度 */
|
|
padding: 0 !important; /* 移除内边距 */
|
|
background-color: transparent !important;
|
|
border: 1px solid #ddd !important;
|
|
display: flex !important;
|
|
align-items: center !important;
|
|
justify-content: center !important;
|
|
font-size: 14px !important;
|
|
color: #666 !important; /* 更柔和的颜色 */
|
|
margin: 5px 10px 5px 0 !important; /* 调整按钮间距 */
|
|
}
|
|
.stButton button:hover {
|
|
border-color: #999 !important;
|
|
color: #333 !important;
|
|
background-color: #f5f5f5 !important;
|
|
}
|
|
.stMainBlockContainer > div:first-child {
|
|
margin-top: -50px !important;
|
|
}
|
|
.stApp > div:last-child {
|
|
margin-bottom: -35px !important;
|
|
}
|
|
|
|
/* 重置按钮基础样式 */
|
|
.stButton > button {
|
|
all: unset !important; /* 重置所有默认样式 */
|
|
box-sizing: border-box !important;
|
|
border-radius: 50% !important;
|
|
width: 18px !important;
|
|
height: 18px !important;
|
|
min-width: 18px !important;
|
|
min-height: 18px !important;
|
|
max-width: 18px !important;
|
|
max-height: 18px !important;
|
|
padding: 0 !important;
|
|
background-color: transparent !important;
|
|
border: 1px solid #ddd !important;
|
|
display: flex !important;
|
|
align-items: center !important;
|
|
justify-content: center !important;
|
|
font-size: 14px !important;
|
|
color: #888 !important;
|
|
cursor: pointer !important;
|
|
transition: all 0.2s ease !important;
|
|
margin: 0 2px !important; /* 调整这里的 margin 值 */
|
|
}
|
|
|
|
</style>
|
|
""", unsafe_allow_html=True)
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
# 多语言文本
|
|
LANG_TEXTS = {
|
|
'zh': {
|
|
'settings': '模型设定调整',
|
|
'history_rounds': '历史对话轮次',
|
|
'max_length': '最大生成长度',
|
|
'temperature': '温度',
|
|
'thinking': '思考',
|
|
'tools': '工具',
|
|
'language': '语言',
|
|
'send': '给 MiniMind 发送消息',
|
|
'disclaimer': 'AI 生成内容可能存在错误,请仔细核实',
|
|
'think_tip': '自适应思考,目前多轮对话或Tool Call共存时思考不稳定',
|
|
'tool_select': '工具选择(最多4个)',
|
|
},
|
|
'en': {
|
|
'settings': 'Model Settings',
|
|
'history_rounds': 'History Rounds',
|
|
'max_length': 'Max Length',
|
|
'temperature': 'Temperature',
|
|
'thinking': 'Thinking',
|
|
'tools': 'Tools',
|
|
'language': 'Language',
|
|
'send': 'Send a message to MiniMind',
|
|
'disclaimer': 'AI-generated content may be inaccurate, please verify',
|
|
'think_tip': 'Adaptive thinking; may be unstable with multi-turn or Tool Call',
|
|
'tool_select': 'Tool Selection (max 4)',
|
|
}
|
|
}
|
|
|
|
def get_text(key):
|
|
lang = st.session_state.get('lang', 'en')
|
|
return LANG_TEXTS.get(lang, {}).get(key, LANG_TEXTS['zh'].get(key, key))
|
|
|
|
# 工具定义
|
|
TOOLS = [
|
|
{"type": "function", "function": {"name": "calculate_math", "description": "计算数学表达式", "parameters": {"type": "object", "properties": {"expression": {"type": "string", "description": "数学表达式"}}, "required": ["expression"]}}},
|
|
{"type": "function", "function": {"name": "get_current_time", "description": "获取当前时间", "parameters": {"type": "object", "properties": {"timezone": {"type": "string", "default": "Asia/Shanghai"}}, "required": []}}},
|
|
{"type": "function", "function": {"name": "random_number", "description": "生成随机数", "parameters": {"type": "object", "properties": {"min": {"type": "integer"}, "max": {"type": "integer"}}, "required": ["min", "max"]}}},
|
|
{"type": "function", "function": {"name": "text_length", "description": "计算文本长度", "parameters": {"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]}}},
|
|
{"type": "function", "function": {"name": "unit_converter", "description": "单位转换", "parameters": {"type": "object", "properties": {"value": {"type": "number"}, "from_unit": {"type": "string"}, "to_unit": {"type": "string"}}, "required": ["value", "from_unit", "to_unit"]}}},
|
|
{"type": "function", "function": {"name": "get_current_weather", "description": "获取天气", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}}},
|
|
{"type": "function", "function": {"name": "get_exchange_rate", "description": "获取汇率", "parameters": {"type": "object", "properties": {"from_currency": {"type": "string"}, "to_currency": {"type": "string"}}, "required": ["from_currency", "to_currency"]}}},
|
|
{"type": "function", "function": {"name": "translate_text", "description": "翻译文本", "parameters": {"type": "object", "properties": {"text": {"type": "string"}, "target_lang": {"type": "string"}}, "required": ["text", "target_lang"]}}},
|
|
]
|
|
|
|
TOOL_SHORT_NAMES = {
|
|
'calculate_math': '数学', 'get_current_time': '时间', 'random_number': '随机',
|
|
'text_length': '字数', 'unit_converter': '单位', 'get_current_weather': '天气',
|
|
'get_exchange_rate': '汇率', 'translate_text': '翻译'
|
|
}
|
|
|
|
def execute_tool(tool_name, args):
|
|
import datetime
|
|
try:
|
|
if tool_name == 'calculate_math':
|
|
return {"result": eval(args.get('expression', '0'))}
|
|
elif tool_name == 'get_current_time':
|
|
tz = args.get('timezone', 'Asia/Shanghai')
|
|
return {"result": datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
|
|
elif tool_name == 'random_number':
|
|
return {"result": random.randint(args.get('min', 0), args.get('max', 100))}
|
|
elif tool_name == 'text_length':
|
|
return {"result": len(args.get('text', ''))}
|
|
elif tool_name == 'unit_converter':
|
|
return {"result": f"{args.get('value', 0)} {args.get('from_unit', '')} = ? {args.get('to_unit', '')}"}
|
|
elif tool_name == 'get_current_weather':
|
|
return {"result": f"{args.get('city', 'Unknown')}: 晴, 7~10°C"}
|
|
elif tool_name == 'get_exchange_rate':
|
|
return {"result": f"1 {args.get('from_currency', 'USD')} = 7.2 {args.get('to_currency', 'CNY')}"}
|
|
elif tool_name == 'translate_text':
|
|
return {"result": f"[翻译结果]: hello world"}
|
|
return {"result": "Unknown tool"}
|
|
except Exception as e:
|
|
return {"error": str(e)}
|
|
|
|
|
|
def process_assistant_content(content, is_streaming=False):
|
|
# 处理tool_call标签,格式化显示
|
|
if '<tool_call>' in content:
|
|
def format_tool_call(match):
|
|
try:
|
|
tc = json.loads(match.group(1))
|
|
name = tc.get('name', 'unknown')
|
|
args = tc.get('arguments', {})
|
|
return f'<div style="background: rgba(80, 110, 150, 0.20); border: 1px solid rgba(140, 170, 210, 0.30); padding: 10px 12px; border-radius: 12px; margin: 6px 0;"><div style="font-size:12px;opacity:.75;display:block;margin:0 0 6px 0;line-height:1;">ToolCalling</div><div><b>{name}</b>: {json.dumps(args, ensure_ascii=False)}</div></div>'
|
|
except:
|
|
return match.group(0)
|
|
content = re.sub(r'<tool_call>(.*?)</tool_call>', format_tool_call, content, flags=re.DOTALL)
|
|
|
|
# 流式生成且开启思考时,一开始就放到折叠里
|
|
if is_streaming and st.session_state.get('enable_thinking', False) and '</think>' not in content and '<think>' not in content:
|
|
m = re.search(r'(\n\n(?:我是|您好|你好)[^\n]*)', content)
|
|
if m and m.start(1) > 5:
|
|
i = m.start(1)
|
|
think_part = content[:i]
|
|
answer_part = content[i:]
|
|
return f'<details open style="border-left: 2px solid #666; padding-left: 12px; margin: 8px 0;"><summary style="cursor: pointer; color: #888;">已思考</summary><div style="color: #aaa; font-size: 0.95em; margin-top: 8px; max-height: 100px; overflow-y: auto;">{think_part.strip()}</div></details>{answer_part}'
|
|
elif len(content) > 5:
|
|
return f'<details open style="border-left: 2px solid #666; padding-left: 12px; margin: 8px 0;"><summary style="cursor: pointer; color: #888;">思考中...</summary><div style="color: #aaa; font-size: 0.95em; margin-top: 8px; max-height: 100px; overflow-y: auto; display: flex; flex-direction: column-reverse;"><div style="margin-bottom: auto;">{content.strip().replace(chr(10), "<br>")}</div></div></details>'
|
|
|
|
if '<think>' in content and '</think>' in content:
|
|
def format_think(match):
|
|
think_content = match.group(2)
|
|
if think_content.replace('\n', '').strip(): # 不是全换行
|
|
return f'<details open style="border-left: 2px solid #666; padding-left: 12px; margin: 8px 0;"><summary style="cursor: pointer; color: #888;">已思考</summary><div style="color: #aaa; font-size: 0.95em; margin-top: 8px; max-height: 100px; overflow-y: auto;">{think_content.strip()}</div></details>'
|
|
return ''
|
|
content = re.sub(r'(<think>)(.*?)(</think>)', format_think, content, flags=re.DOTALL)
|
|
|
|
if '<think>' in content and '</think>' not in content:
|
|
def format_think_in_progress(match):
|
|
tc = match.group(1)
|
|
return f'<details open style="border-left: 2px solid #666; padding-left: 12px; margin: 8px 0;"><summary style="cursor: pointer; color: #888;">思考中...</summary><div style="color: #aaa; font-size: 0.95em; margin-top: 8px; max-height: 100px; overflow-y: auto; display: flex; flex-direction: column-reverse;"><div style="margin-bottom: auto;">{tc.strip().replace(chr(10), "<br>")}</div></div></details>'
|
|
content = re.sub(r'<think>(.*?)$', format_think_in_progress, content, flags=re.DOTALL)
|
|
|
|
if '<think>' not in content and '</think>' in content:
|
|
def format_think_no_start(match):
|
|
think_content = match.group(1)
|
|
if think_content.replace('\n', '').strip():
|
|
return f'<details open style="border-left: 2px solid #666; padding-left: 12px; margin: 8px 0;"><summary style="cursor: pointer; color: #888;">已思考</summary><div style="color: #aaa; font-size: 0.95em; margin-top: 8px; max-height: 100px; overflow-y: auto;">{think_content.strip()}</div></details>'
|
|
return ''
|
|
content = re.sub(r'(.*?)</think>', format_think_no_start, content, flags=re.DOTALL)
|
|
|
|
return content
|
|
|
|
|
|
@st.cache_resource
|
|
def load_model_tokenizer(model_path):
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
model_path,
|
|
trust_remote_code=True
|
|
)
|
|
tokenizer = AutoTokenizer.from_pretrained(
|
|
model_path,
|
|
trust_remote_code=True
|
|
)
|
|
model = model.half().eval().to(device)
|
|
return model, tokenizer
|
|
|
|
|
|
def clear_chat_messages():
|
|
del st.session_state.messages
|
|
del st.session_state.chat_messages
|
|
|
|
|
|
def init_chat_messages():
|
|
if "messages" in st.session_state:
|
|
for i, message in enumerate(st.session_state.messages):
|
|
if message["role"] == "assistant":
|
|
st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
|
|
else:
|
|
st.markdown(
|
|
f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: #3d4450; border-radius: 22px; color: white;">{message["content"]}</div></div>',
|
|
unsafe_allow_html=True)
|
|
|
|
else:
|
|
st.session_state.messages = []
|
|
st.session_state.chat_messages = []
|
|
|
|
return st.session_state.messages
|
|
|
|
def regenerate_answer(index):
|
|
st.session_state.messages.pop()
|
|
st.session_state.chat_messages.pop()
|
|
st.rerun()
|
|
|
|
|
|
# 动态扫描模型目录
|
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
MODEL_PATHS = {}
|
|
for d in sorted(os.listdir(script_dir), reverse=True):
|
|
full_path = os.path.join(script_dir, d)
|
|
if os.path.isdir(full_path) and not d.startswith('.') and not d.startswith('_'):
|
|
if any(f.endswith(('.bin', '.safetensors', '.pt')) or os.path.exists(os.path.join(full_path, 'model.safetensors.index.json')) for f in os.listdir(full_path) if os.path.isfile(os.path.join(full_path, f))):
|
|
MODEL_PATHS[d] = [d, d]
|
|
if not MODEL_PATHS:
|
|
MODEL_PATHS = {"No models found": ["", "No models"]}
|
|
|
|
# 模型选择
|
|
selected_model = st.sidebar.selectbox('Model', list(MODEL_PATHS.keys()), index=0)
|
|
model_path = MODEL_PATHS[selected_model][0]
|
|
slogan = f"我是 {MODEL_PATHS[selected_model][1]},有什么可以帮你的?" if st.session_state.get('lang', 'en') == 'zh' else f"I am {MODEL_PATHS[selected_model][1]}, how can I help you?"
|
|
|
|
st.sidebar.markdown('<hr style="margin: 12px 0 16px 0;">', unsafe_allow_html=True)
|
|
|
|
# 语言选择
|
|
lang_options = {'中文': 'zh', 'English': 'en'}
|
|
current_lang = st.session_state.get('lang', 'en')
|
|
lang_index = 0 if current_lang == 'zh' else 1
|
|
lang_label = st.sidebar.radio('Language / 语言', list(lang_options.keys()), index=lang_index, horizontal=True)
|
|
if lang_options[lang_label] != current_lang:
|
|
st.session_state.lang = lang_options[lang_label]
|
|
st.rerun()
|
|
|
|
st.sidebar.markdown('<hr style="margin: 12px 0 16px 0;">', unsafe_allow_html=True)
|
|
|
|
# 参数设置
|
|
st.session_state.history_chat_num = st.sidebar.slider(get_text('history_rounds'), 0, 8, 0, step=2)
|
|
st.session_state.max_new_tokens = st.sidebar.slider(get_text('max_length'), 256, 8192, 8192, step=1)
|
|
st.session_state.temperature = st.sidebar.slider(get_text('temperature'), 0.6, 1.2, 0.90, step=0.01)
|
|
|
|
st.sidebar.markdown('<hr style="margin: 12px 0 16px 0;">', unsafe_allow_html=True)
|
|
|
|
# 功能开关
|
|
st.session_state.enable_thinking = st.sidebar.checkbox(get_text('thinking'), value=False, help=get_text('think_tip'))
|
|
st.session_state.selected_tools = []
|
|
with st.sidebar.expander(get_text('tools')):
|
|
st.caption(get_text('tool_select'))
|
|
selected_count = sum(1 for tool in TOOLS if st.session_state.get(f"tool_{tool['function']['name']}", False))
|
|
for tool in TOOLS:
|
|
name = tool['function']['name']
|
|
short_name = TOOL_SHORT_NAMES.get(name, name)
|
|
checked = st.checkbox(short_name, key=f"tool_{name}", disabled=(selected_count >= 4 and not st.session_state.get(f"tool_{name}", False)))
|
|
if checked and len(st.session_state.selected_tools) < 4:
|
|
st.session_state.selected_tools.append(name)
|
|
|
|
image_url = "https://www.modelscope.cn/api/v1/studio/gongjy/MiniMind/repo?Revision=master&FilePath=images%2Flogo2.png&View=true"
|
|
|
|
st.markdown(
|
|
f'<div style="display: flex; flex-direction: column; align-items: center; text-align: center; margin: 0; padding: 0;">'
|
|
'<div style="font-style: italic; font-weight: 900; margin: 0; padding-top: 4px; display: flex; align-items: center; justify-content: center; flex-wrap: wrap; width: 100%;">'
|
|
f'<img src="{image_url}" style="width: 40px; height: 40px; "> '
|
|
f'<span style="font-size: 26px; margin-left: 10px;">{slogan}</span>'
|
|
'</div>'
|
|
f'<span style="color: #bbb; font-style: italic; margin-top: 6px; margin-bottom: 10px;">{get_text("disclaimer")}</span>'
|
|
'</div>',
|
|
unsafe_allow_html=True
|
|
)
|
|
|
|
|
|
def setup_seed(seed):
|
|
random.seed(seed)
|
|
np.random.seed(seed)
|
|
torch.manual_seed(seed)
|
|
torch.cuda.manual_seed(seed)
|
|
torch.cuda.manual_seed_all(seed)
|
|
torch.backends.cudnn.deterministic = True
|
|
torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
def main():
|
|
model, tokenizer = load_model_tokenizer(model_path)
|
|
|
|
if "messages" not in st.session_state:
|
|
st.session_state.messages = []
|
|
st.session_state.chat_messages = []
|
|
|
|
messages = st.session_state.messages
|
|
|
|
for i, message in enumerate(messages):
|
|
if message["role"] == "assistant":
|
|
st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
|
|
else:
|
|
st.markdown(
|
|
f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: #3d4450; border-radius: 22px; color: white;">{message["content"]}</div></div>',
|
|
unsafe_allow_html=True)
|
|
|
|
prompt = st.chat_input(key="input", placeholder=get_text('send'))
|
|
|
|
if hasattr(st.session_state, 'regenerate') and st.session_state.regenerate:
|
|
prompt = st.session_state.last_user_message
|
|
regenerate_index = st.session_state.regenerate_index
|
|
delattr(st.session_state, 'regenerate')
|
|
delattr(st.session_state, 'last_user_message')
|
|
delattr(st.session_state, 'regenerate_index')
|
|
|
|
if prompt:
|
|
st.markdown(
|
|
f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px; background-color: #3d4450; border-radius: 22px; color: white;">{prompt}</div></div>',
|
|
unsafe_allow_html=True)
|
|
messages.append({"role": "user", "content": prompt[-st.session_state.max_new_tokens:]})
|
|
st.session_state.chat_messages.append({"role": "user", "content": prompt[-st.session_state.max_new_tokens:]})
|
|
|
|
placeholder = st.empty()
|
|
|
|
random_seed = random.randint(0, 2 ** 32 - 1)
|
|
setup_seed(random_seed)
|
|
|
|
tools = [t for t in TOOLS if t['function']['name'] in st.session_state.get('selected_tools', [])] or None
|
|
sys_prompt = [] if tools else [{"role": "system", "content": "你是MiniMind,一个乐于助人、知识渊博的AI助手。请用完整且友好的方式回答用户问题。"}]
|
|
st.session_state.chat_messages = sys_prompt + st.session_state.chat_messages[-(st.session_state.history_chat_num + 1):]
|
|
template_kwargs = {"tokenize": False, "add_generation_prompt": True}
|
|
if st.session_state.get('enable_thinking', False):
|
|
template_kwargs["open_thinking"] = True
|
|
if tools:
|
|
template_kwargs["tools"] = tools
|
|
new_prompt = tokenizer.apply_chat_template(st.session_state.chat_messages, **template_kwargs)
|
|
|
|
inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)
|
|
|
|
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
|
generation_kwargs = {
|
|
"input_ids": inputs.input_ids,
|
|
"max_length": inputs.input_ids.shape[1] + st.session_state.max_new_tokens,
|
|
"num_return_sequences": 1,
|
|
"do_sample": True,
|
|
"attention_mask": inputs.attention_mask,
|
|
"pad_token_id": tokenizer.pad_token_id,
|
|
"eos_token_id": tokenizer.eos_token_id,
|
|
"temperature": st.session_state.temperature,
|
|
"top_p": 0.85,
|
|
"streamer": streamer,
|
|
}
|
|
|
|
Thread(target=model.generate, kwargs=generation_kwargs).start()
|
|
|
|
answer = ""
|
|
for new_text in streamer:
|
|
answer += new_text
|
|
placeholder.markdown(process_assistant_content(answer, is_streaming=True), unsafe_allow_html=True)
|
|
|
|
full_answer = answer
|
|
for _ in range(16):
|
|
tool_calls = re.findall(r'<tool_call>(.*?)</tool_call>', answer, re.DOTALL)
|
|
if not tool_calls:
|
|
break
|
|
st.session_state.chat_messages.append({"role": "assistant", "content": answer})
|
|
tool_results = []
|
|
for tc_str in tool_calls:
|
|
try:
|
|
tc = json.loads(tc_str.strip())
|
|
result = execute_tool(tc.get('name', ''), tc.get('arguments', {}))
|
|
st.session_state.chat_messages.append({"role": "tool", "content": json.dumps(result, ensure_ascii=False)})
|
|
tool_results.append(f'<div style="background: rgba(90, 130, 110, 0.20); border: 1px solid rgba(150, 200, 170, 0.30); padding: 10px 12px; border-radius: 12px; margin: 6px 0;"><div style="font-size:12px;opacity:.75;display:block;margin:0 0 6px 0;line-height:1;">ToolCalled</div><div><b>{tc.get("name", "")}</b>: {json.dumps(result, ensure_ascii=False)}</div></div>')
|
|
except:
|
|
pass
|
|
full_answer += "\n" + "\n".join(tool_results) + "\n"
|
|
placeholder.markdown(process_assistant_content(full_answer, is_streaming=True), unsafe_allow_html=True)
|
|
new_prompt = tokenizer.apply_chat_template(st.session_state.chat_messages, **template_kwargs)
|
|
inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)
|
|
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
|
generation_kwargs["input_ids"] = inputs.input_ids
|
|
generation_kwargs["attention_mask"] = inputs.attention_mask
|
|
generation_kwargs["max_length"] = inputs.input_ids.shape[1] + st.session_state.max_new_tokens
|
|
generation_kwargs["streamer"] = streamer
|
|
Thread(target=model.generate, kwargs=generation_kwargs).start()
|
|
answer = ""
|
|
for new_text in streamer:
|
|
answer += new_text
|
|
placeholder.markdown(process_assistant_content(full_answer + answer, is_streaming=True), unsafe_allow_html=True)
|
|
full_answer += answer
|
|
answer = full_answer
|
|
|
|
messages.append({"role": "assistant", "content": answer})
|
|
st.session_state.chat_messages.append({"role": "assistant", "content": answer})
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|