"""Streamlit chat demo for MiniMind models.

The script scans its own directory for model folders, lets the user pick one in
the sidebar, and streams generated replies into the page. Optional "thinking"
output is shown in a collapsible section, and tool calls emitted by the model
are executed locally with the results fed back for another generation round.

NOTE(review): in the revision this was rewritten from, every HTML/XML-like tag
inside string literals had been stripped (for example ``if '' in content`` and
``r'(.*?)'``). The ``<tool_call>`` / ``<think>`` tag names are restored from the
surrounding comments and helper names. The presentational HTML/CSS (bubbles,
badges, separators, header, page stylesheet) could not be recovered and is a
functional reconstruction -- confirm against the upstream markup.
"""
import ast
import datetime
import html
import json
import operator
import os
import random
import re
from threading import Thread

import numpy as np
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

st.set_page_config(page_title="MiniMind", initial_sidebar_state="collapsed")

# NOTE(review): this call originally injected a page stylesheet; its content was
# not recoverable, so the (empty) call is kept unchanged.
st.markdown(""" """, unsafe_allow_html=True)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Markers the model emits around reasoning and tool invocations.
THINK_OPEN = '<think>'
THINK_CLOSE = '</think>'
TOOL_CALL_OPEN = '<tool_call>'
TOOL_CALL_RE = re.compile(r'<tool_call>(.*?)</tool_call>', re.DOTALL)
# Heuristic for where the visible answer starts when no think tags are present.
ANSWER_START_RE = re.compile(r'(\n\n(?:我是|您好|你好)[^\n]*)')

MAX_SELECTED_TOOLS = 4
MAX_TOOL_ROUNDS = 16
WEIGHT_SUFFIXES = ('.bin', '.safetensors', '.pt')

# Multilingual UI texts.
LANG_TEXTS = {
    'zh': {
        'settings': '模型设定调整',
        'history_rounds': '历史对话轮次',
        'max_length': '最大生成长度',
        'temperature': '温度',
        'thinking': '思考',
        'tools': '工具',
        'language': '语言',
        'send': '给 MiniMind 发送消息',
        'disclaimer': 'AI 生成内容可能存在错误,请仔细核实',
        'think_tip': '自适应思考,目前多轮对话或Tool Call共存时思考不稳定',
        'tool_select': '工具选择(最多4个)',
    },
    'en': {
        'settings': 'Model Settings',
        'history_rounds': 'History Rounds',
        'max_length': 'Max Length',
        'temperature': 'Temperature',
        'thinking': 'Thinking',
        'tools': 'Tools',
        'language': 'Language',
        'send': 'Send a message to MiniMind',
        'disclaimer': 'AI-generated content may be inaccurate, please verify',
        'think_tip': 'Adaptive thinking; may be unstable with multi-turn or Tool Call',
        'tool_select': 'Tool Selection (max 4)',
    },
}


def get_text(key):
    """Return the UI string for ``key`` in the session language.

    Falls back to the Chinese table, then to ``key`` itself, when the string is
    missing for the active language.
    """
    lang = st.session_state.get('lang', 'en')
    return LANG_TEXTS.get(lang, {}).get(key, LANG_TEXTS['zh'].get(key, key))


# Tool definitions passed to the chat template.
TOOLS = [
    {"type": "function", "function": {
        "name": "calculate_math", "description": "计算数学表达式",
        "parameters": {
            "type": "object",
            "properties": {"expression": {"type": "string", "description": "数学表达式"}},
            "required": ["expression"]}}},
    {"type": "function", "function": {
        "name": "get_current_time", "description": "获取当前时间",
        "parameters": {
            "type": "object",
            "properties": {"timezone": {"type": "string", "default": "Asia/Shanghai"}},
            "required": []}}},
    {"type": "function", "function": {
        "name": "random_number", "description": "生成随机数",
        "parameters": {
            "type": "object",
            "properties": {"min": {"type": "integer"}, "max": {"type": "integer"}},
            "required": ["min", "max"]}}},
    {"type": "function", "function": {
        "name": "text_length", "description": "计算文本长度",
        "parameters": {
            "type": "object",
            "properties": {"text": {"type": "string"}},
            "required": ["text"]}}},
    {"type": "function", "function": {
        "name": "unit_converter", "description": "单位转换",
        "parameters": {
            "type": "object",
            "properties": {"value": {"type": "number"}, "from_unit": {"type": "string"},
                           "to_unit": {"type": "string"}},
            "required": ["value", "from_unit", "to_unit"]}}},
    {"type": "function", "function": {
        "name": "get_current_weather", "description": "获取天气",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"]}}},
    {"type": "function", "function": {
        "name": "get_exchange_rate", "description": "获取汇率",
        "parameters": {
            "type": "object",
            "properties": {"from_currency": {"type": "string"}, "to_currency": {"type": "string"}},
            "required": ["from_currency", "to_currency"]}}},
    {"type": "function", "function": {
        "name": "translate_text", "description": "翻译文本",
        "parameters": {
            "type": "object",
            "properties": {"text": {"type": "string"}, "target_lang": {"type": "string"}},
            "required": ["text", "target_lang"]}}},
]

TOOL_SHORT_NAMES = {
    'calculate_math': '数学',
    'get_current_time': '时间',
    'random_number': '随机',
    'text_length': '字数',
    'unit_converter': '单位',
    'get_current_weather': '天气',
    'get_exchange_rate': '汇率',
    'translate_text': '翻译',
}

# Whitelists for the arithmetic evaluator used by the calculate_math tool.
_MATH_BINOPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
}
_MATH_UNARYOPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}
_MATH_FUNCS = {'abs': abs, 'round': round, 'min': min, 'max': max}
_MAX_EXPONENT = 1000


def _safe_eval_math(expression):
    """Evaluate an arithmetic expression without executing arbitrary code.

    SECURITY: the expression is produced by the language model, so it is
    untrusted. This deliberately replaces the previous ``eval()`` call: only
    numeric literals, ``+ - * / // % **``, unary sign and a few pure builtins
    (``abs``, ``round``, ``min``, ``max``) are accepted.

    Raises:
        ValueError: if the expression uses anything outside the whitelist.
        SyntaxError: if the expression is not valid Python expression syntax.
    """
    def _eval(node):
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant):
            value = node.value
            # bool is a subclass of int; reject it along with str/None/etc.
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                return value
        elif isinstance(node, ast.BinOp) and type(node.op) in _MATH_BINOPS:
            left, right = _eval(node.left), _eval(node.right)
            if isinstance(node.op, ast.Pow) and abs(right) > _MAX_EXPONENT:
                # Keep a hostile "9**9**9" from freezing the app.
                raise ValueError("exponent too large")
            return _MATH_BINOPS[type(node.op)](left, right)
        elif isinstance(node, ast.UnaryOp) and type(node.op) in _MATH_UNARYOPS:
            return _MATH_UNARYOPS[type(node.op)](_eval(node.operand))
        elif (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
              and node.func.id in _MATH_FUNCS and not node.keywords):
            return _MATH_FUNCS[node.func.id](*[_eval(arg) for arg in node.args])
        raise ValueError(f"Unsupported expression: {expression!r}")

    return _eval(ast.parse(str(expression), mode='eval'))


def _current_time_string(tz_name):
    """Format the current time in ``tz_name``, or local time if it is unusable."""
    try:
        # Imported lazily so interpreters without zoneinfo still run.
        from zoneinfo import ZoneInfo
        now = datetime.datetime.now(ZoneInfo(tz_name))
    except (ImportError, KeyError, ValueError, TypeError, OSError):
        # Unknown zone, missing tz database, or malformed name.
        now = datetime.datetime.now()
    return now.strftime('%Y-%m-%d %H:%M:%S')


def execute_tool(tool_name, args):
    """Run one of the demo tools and return its result.

    Most tools are stubs that return canned strings; ``calculate_math``,
    ``get_current_time``, ``random_number`` and ``text_length`` compute real
    values.

    Args:
        tool_name: Name of the tool as listed in ``TOOLS``.
        args: Mapping of argument names to values from the model's tool call.

    Returns:
        ``{"result": ...}`` on success (including ``"Unknown tool"`` for an
        unrecognised name) or ``{"error": message}`` if the tool raised.
    """
    try:
        if tool_name == 'calculate_math':
            return {"result": _safe_eval_math(args.get('expression', '0'))}
        if tool_name == 'get_current_time':
            return {"result": _current_time_string(args.get('timezone', 'Asia/Shanghai'))}
        if tool_name == 'random_number':
            return {"result": random.randint(args.get('min', 0), args.get('max', 100))}
        if tool_name == 'text_length':
            return {"result": len(args.get('text', ''))}
        if tool_name == 'unit_converter':
            return {"result": f"{args.get('value', 0)} {args.get('from_unit', '')} = ? {args.get('to_unit', '')}"}
        if tool_name == 'get_current_weather':
            return {"result": f"{args.get('city', 'Unknown')}: 晴, 7~10°C"}
        if tool_name == 'get_exchange_rate':
            return {"result": f"1 {args.get('from_currency', 'USD')} = 7.2 {args.get('to_currency', 'CNY')}"}
        if tool_name == 'translate_text':
            return {"result": "[翻译结果]: hello world"}
        return {"result": "Unknown tool"}
    except Exception as e:
        # Tool failures (bad arguments, malformed expressions, ...) are reported
        # back to the model instead of crashing the page.
        return {"error": str(e)}


def _fold(title, body, is_open=False):
    """Return a collapsible ``<details>`` section with ``title`` as its summary."""
    open_attr = ' open' if is_open else ''
    return (
        f'<details{open_attr} style="margin: 6px 0; padding: 8px 12px; border-radius: 10px; '
        f'background: rgba(128, 128, 128, 0.15);">'
        f'<summary style="font-weight: bold; cursor: pointer;">{title}</summary>'
        f'<div style="margin-top: 6px; font-style: italic;">{body}</div>'
        f'</details>'
    )


def _tool_badge(label, name, payload):
    """Return a small labelled box showing a tool name and its JSON payload."""
    # Both values come from model output, so escape them before embedding.
    detail = html.escape(f'{name}: {json.dumps(payload, ensure_ascii=False, default=str)}')
    return (
        f'<div style="margin: 6px 0; padding: 6px 12px; border-radius: 10px; '
        f'border: 1px solid rgba(128, 128, 128, 0.4);">'
        f'<div style="font-weight: bold;">{label}</div>'
        f'<div style="font-family: monospace;">{detail}</div>'
        f'</div>'
    )


def _user_bubble(text):
    """Return the right-aligned bubble used for user messages."""
    return (
        '<div style="display: flex; justify-content: flex-end;">'
        '<div style="display: inline-block; margin: 10px 0; padding: 8px 12px; '
        'background-color: gray; border-radius: 10px; color: white;">'
        # User input is rendered with unsafe_allow_html, so escape it.
        f'{html.escape(str(text))}'
        '</div></div>'
    )


def _format_tool_call(match):
    """Regex callback: render one ``<tool_call>`` payload as a badge."""
    try:
        call = json.loads(match.group(1))
        return _tool_badge('ToolCalling', call.get('name', 'unknown'), call.get('arguments', {}))
    except (ValueError, AttributeError):
        # Incomplete/invalid JSON (e.g. mid-stream) or a non-object payload:
        # leave the raw text untouched.
        return match.group(0)


def _format_finished_think(body):
    """Render completed reasoning text, or nothing if it is only whitespace."""
    if body.replace('\n', '').strip():
        return _fold('已思考', body.strip())
    return ''


def process_assistant_content(content, is_streaming=False):
    """Convert raw model output into the HTML/markdown shown in the chat.

    Tool-call payloads become badges and reasoning text is moved into a
    collapsible section.

    Args:
        content: Raw assistant text, possibly containing think/tool_call tags.
        is_streaming: True while the reply is still being generated; with
            thinking enabled and no think tags seen yet, the text so far is
            shown inside the collapsible section from the start.

    Returns:
        The formatted string, to be rendered with ``unsafe_allow_html=True``.
    """
    if TOOL_CALL_OPEN in content:
        content = TOOL_CALL_RE.sub(_format_tool_call, content)

    # While streaming with thinking enabled, fold the output from the start.
    if (is_streaming and st.session_state.get('enable_thinking', False)
            and THINK_OPEN not in content and THINK_CLOSE not in content):
        answer_start = ANSWER_START_RE.search(content)
        if answer_start and answer_start.start(1) > 5:
            split_at = answer_start.start(1)
            return _fold('已思考', content[:split_at].strip()) + content[split_at:]
        if len(content) > 5:
            return _fold('思考中...', content.strip().replace(chr(10), '<br>'), is_open=True)

    # The tag checks are re-evaluated after each substitution on purpose: a
    # rewrite can leave behind an unpaired tag handled by a later step.
    if THINK_OPEN in content and THINK_CLOSE in content:
        content = re.sub(
            r'(<think>)(.*?)(</think>)',
            lambda m: _format_finished_think(m.group(2)),
            content, flags=re.DOTALL)

    if THINK_OPEN in content and THINK_CLOSE not in content:
        content = re.sub(
            r'<think>(.*?)$',
            lambda m: _fold('思考中...', m.group(1).strip().replace(chr(10), '<br>'), is_open=True),
            content, flags=re.DOTALL)

    if THINK_OPEN not in content and THINK_CLOSE in content:
        content = re.sub(
            r'(.*?)</think>',
            lambda m: _format_finished_think(m.group(1)),
            content, flags=re.DOTALL)

    return content


@st.cache_resource
def load_model_tokenizer(model_path):
    """Load the model and tokenizer from ``model_path`` (cached by Streamlit).

    The model is converted to half precision, put in eval mode and moved to
    ``device``.
    """
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = model.half().eval().to(device)
    return model, tokenizer


def clear_chat_messages():
    """Drop the displayed and model-facing chat histories from the session."""
    for key in ('messages', 'chat_messages'):
        if key in st.session_state:
            del st.session_state[key]


def init_chat_messages():
    """Render the stored conversation, creating empty histories if needed.

    Returns:
        The list of displayed messages kept in ``st.session_state.messages``.
    """
    if "messages" not in st.session_state:
        st.session_state.messages = []
        st.session_state.chat_messages = []
        return st.session_state.messages

    if "chat_messages" not in st.session_state:
        st.session_state.chat_messages = []
    for message in st.session_state.messages:
        if message["role"] == "assistant":
            st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
        else:
            st.markdown(_user_bubble(message["content"]), unsafe_allow_html=True)
    return st.session_state.messages


def regenerate_answer(index):
    """Remove the last message from both histories and rerun the script.

    ``index`` is accepted for interface compatibility but is not used.
    """
    st.session_state.messages.pop()
    st.session_state.chat_messages.pop()
    st.rerun()


def _looks_like_model_dir(path):
    """Return True if ``path`` contains weight files or a safetensors index."""
    try:
        entries = os.listdir(path)
    except OSError:
        # Unreadable directory: treat as "not a model" instead of crashing.
        return False
    if os.path.isfile(os.path.join(path, 'model.safetensors.index.json')):
        return True
    return any(
        name.endswith(WEIGHT_SUFFIXES)
        for name in entries
        if os.path.isfile(os.path.join(path, name))
    )


# Discover model directories that sit next to this script.
script_dir = os.path.dirname(os.path.abspath(__file__))
MODEL_PATHS = {}
for d in sorted(os.listdir(script_dir), reverse=True):
    full_path = os.path.join(script_dir, d)
    if os.path.isdir(full_path) and not d.startswith(('.', '_')) and _looks_like_model_dir(full_path):
        MODEL_PATHS[d] = [d, d]
if not MODEL_PATHS:
    MODEL_PATHS = {"No models found": ["", "No models"]}

# Model selection.
selected_model = st.sidebar.selectbox('Model', list(MODEL_PATHS.keys()), index=0)
model_path = MODEL_PATHS[selected_model][0]
if st.session_state.get('lang', 'en') == 'zh':
    slogan = f"我是 {MODEL_PATHS[selected_model][1]},有什么可以帮你的?"
else:
    slogan = f"I am {MODEL_PATHS[selected_model][1]}, how can I help you?"

# NOTE(review): separator markup reconstructed; the original tag was lost.
SIDEBAR_SEPARATOR = '<hr style="margin: 8px 0;">'
st.sidebar.markdown(SIDEBAR_SEPARATOR, unsafe_allow_html=True)

# Language selection.
lang_options = {'中文': 'zh', 'English': 'en'}
current_lang = st.session_state.get('lang', 'en')
lang_index = 0 if current_lang == 'zh' else 1
lang_label = st.sidebar.radio('Language / 语言', list(lang_options.keys()), index=lang_index, horizontal=True)
if lang_options[lang_label] != current_lang:
    st.session_state.lang = lang_options[lang_label]
    st.rerun()
st.sidebar.markdown(SIDEBAR_SEPARATOR, unsafe_allow_html=True)

# Generation parameters.
st.session_state.history_chat_num = st.sidebar.slider(get_text('history_rounds'), 0, 8, 0, step=2)
st.session_state.max_new_tokens = st.sidebar.slider(get_text('max_length'), 256, 8192, 8192, step=1)
st.session_state.temperature = st.sidebar.slider(get_text('temperature'), 0.6, 1.2, 0.90, step=0.01)
st.sidebar.markdown(SIDEBAR_SEPARATOR, unsafe_allow_html=True)

# Feature switches.
st.session_state.enable_thinking = st.sidebar.checkbox(
    get_text('thinking'), value=False, help=get_text('think_tip'))
st.session_state.selected_tools = []
with st.sidebar.expander(get_text('tools')):
    st.caption(get_text('tool_select'))
    selected_count = sum(
        1 for tool in TOOLS if st.session_state.get(f"tool_{tool['function']['name']}", False))
    for tool in TOOLS:
        name = tool['function']['name']
        short_name = TOOL_SHORT_NAMES.get(name, name)
        # Once the limit is reached, only already-ticked boxes stay enabled.
        limit_reached = selected_count >= MAX_SELECTED_TOOLS
        checked = st.checkbox(
            short_name, key=f"tool_{name}",
            disabled=(limit_reached and not st.session_state.get(f"tool_{name}", False)))
        if checked and len(st.session_state.selected_tools) < MAX_SELECTED_TOOLS:
            st.session_state.selected_tools.append(name)

image_url = "https://www.modelscope.cn/api/v1/studio/gongjy/MiniMind/repo?Revision=master&FilePath=images%2Flogo2.png&View=true"

# Page header: logo, slogan and disclaimer.
st.markdown(
    f'<div style="display: flex; flex-direction: column; align-items: center; '
    f'text-align: center; margin: 0; padding: 0;">'
    '<div style="font-style: italic; font-weight: 900; margin: 0; padding: 5px 0; '
    'display: flex; align-items: center; justify-content: center; flex-wrap: wrap;">'
    f'<img src="{image_url}" style="width: 45px; height: 45px;"> '
    f'<span style="font-size: 26px; margin-left: 10px;">{slogan}</span>'
    '</div>'
    f'<span style="color: #bbb; font-style: italic; margin: 6px 0 10px 0;">{get_text("disclaimer")}</span>'
    '</div>',
    unsafe_allow_html=True
)


def setup_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _stream_reply(model, tokenizer, chat_messages, template_kwargs, placeholder, prefix=""):
    """Generate one reply for ``chat_messages`` and stream it into ``placeholder``.

    Args:
        model: Causal LM used for generation.
        tokenizer: Tokenizer providing the chat template.
        chat_messages: Conversation passed to the chat template.
        template_kwargs: Extra keyword arguments for ``apply_chat_template``.
        placeholder: Streamlit placeholder updated as text arrives.
        prefix: Already-displayed text shown before the streamed reply.

    Returns:
        The text generated in this round (without ``prefix``).
    """
    prompt_text = tokenizer.apply_chat_template(chat_messages, **template_kwargs)
    inputs = tokenizer(prompt_text, return_tensors="pt", truncation=True).to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        "input_ids": inputs.input_ids,
        "max_length": inputs.input_ids.shape[1] + st.session_state.max_new_tokens,
        "num_return_sequences": 1,
        "do_sample": True,
        "attention_mask": inputs.attention_mask,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "temperature": st.session_state.temperature,
        "top_p": 0.85,
        "streamer": streamer,
    }

    def _generate():
        try:
            model.generate(**generation_kwargs)
        except Exception:
            # Without this the consumer loop below would wait forever for the
            # end-of-stream signal.
            streamer.end()
            raise

    worker = Thread(target=_generate)
    worker.start()
    answer = ""
    for new_text in streamer:
        answer += new_text
        placeholder.markdown(
            process_assistant_content(prefix + answer, is_streaming=True),
            unsafe_allow_html=True)
    worker.join()
    return answer


def main():
    """Render the conversation and answer a new user message, if any."""
    if not model_path:
        # The "No models found" placeholder entry has an empty path.
        st.error("No model directory found next to this script.")
        return

    model, tokenizer = load_model_tokenizer(model_path)
    messages = init_chat_messages()

    prompt = st.chat_input(key="input", placeholder=get_text('send'))
    if st.session_state.get('regenerate'):
        # A regenerate request replays the stored user message.
        prompt = st.session_state.get('last_user_message')
        for key in ('regenerate', 'last_user_message', 'regenerate_index'):
            if key in st.session_state:
                del st.session_state[key]
    if not prompt:
        return

    st.markdown(_user_bubble(prompt), unsafe_allow_html=True)
    # Existing behavior kept: the prompt is cut to its last max_new_tokens characters.
    stored_prompt = prompt[-st.session_state.max_new_tokens:]
    messages.append({"role": "user", "content": stored_prompt})
    st.session_state.chat_messages.append({"role": "user", "content": stored_prompt})

    placeholder = st.empty()
    setup_seed(random.randint(0, 2 ** 32 - 1))

    selected = st.session_state.get('selected_tools', [])
    tools = [t for t in TOOLS if t['function']['name'] in selected] or None
    # No explicit system prompt is added when tools are passed to the template.
    sys_prompt = [] if tools else [{
        "role": "system",
        "content": "你是MiniMind,一个乐于助人、知识渊博的AI助手。请用完整且友好的方式回答用户问题。"}]
    # Drop system prompts kept from earlier turns before slicing; otherwise a
    # history window longer than the conversation accumulates duplicates.
    history = [m for m in st.session_state.chat_messages if m.get("role") != "system"]
    st.session_state.chat_messages = sys_prompt + history[-(st.session_state.history_chat_num + 1):]
    chat_messages = st.session_state.chat_messages

    template_kwargs = {"tokenize": False, "add_generation_prompt": True}
    if st.session_state.get('enable_thinking', False):
        template_kwargs["open_thinking"] = True
    if tools:
        template_kwargs["tools"] = tools

    answer = _stream_reply(model, tokenizer, chat_messages, template_kwargs, placeholder)
    full_answer = answer

    # Execute requested tools and let the model continue, a bounded number of times.
    for _ in range(MAX_TOOL_ROUNDS):
        tool_calls = TOOL_CALL_RE.findall(answer)
        if not tool_calls:
            break
        chat_messages.append({"role": "assistant", "content": answer})
        tool_results = []
        for raw_call in tool_calls:
            try:
                call = json.loads(raw_call.strip())
                tool_name = call.get('name', '')
                result = execute_tool(tool_name, call.get('arguments', {}))
            except (ValueError, AttributeError):
                # Malformed JSON or a non-object payload: skip this call.
                continue
            chat_messages.append({
                "role": "tool",
                "content": json.dumps(result, ensure_ascii=False, default=str)})
            tool_results.append(_tool_badge('ToolCalled', tool_name, result))
        full_answer += "\n" + "\n".join(tool_results) + "\n"
        placeholder.markdown(
            process_assistant_content(full_answer, is_streaming=True), unsafe_allow_html=True)

        answer = _stream_reply(
            model, tokenizer, chat_messages, template_kwargs, placeholder, prefix=full_answer)
        full_answer += answer

    # The displayed history keeps the full transcript, badges included. The
    # model-facing history gets only the last generated segment: earlier
    # segments and tool results are already there, and badge HTML is not model
    # output.
    messages.append({"role": "assistant", "content": full_answer})
    chat_messages.append({"role": "assistant", "content": answer})


if __name__ == "__main__":
    main()