"""MiniMind tool-calling (function-calling) evaluation script.

Sends prompts to either a local MiniMind / transformers model or an
OpenAI-compatible API, executes every tool call the model emits against the
mock implementations in MOCK_RESULTS, feeds the results back as ``tool``
messages, and repeats until the model answers without calling a tool (or the
per-case round limit is reached).
"""
import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import re
import ast
import json
import math
import time
import random
import argparse
import operator
import warnings

import torch
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from openai import OpenAI

from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from trainer.trainer_utils import setup_seed, get_model_params

warnings.filterwarnings('ignore')

# OpenAI-style tool schemas offered to the model.
TOOLS = [
    {"type": "function", "function": {"name": "calculate_math", "description": "计算数学表达式的结果,支持加减乘除、幂运算、开方等", "parameters": {"type": "object", "properties": {"expression": {"type": "string", "description": "数学表达式,如123+456、2**10、sqrt(144)"}}, "required": ["expression"]}}},
    {"type": "function", "function": {"name": "get_current_time", "description": "获取当前日期和时间,支持指定时区", "parameters": {"type": "object", "properties": {"timezone": {"type": "string", "description": "时区名称,如Asia/Shanghai、America/New_York", "default": "Asia/Shanghai"}}, "required": []}}},
    {"type": "function", "function": {"name": "random_number", "description": "生成指定范围内的随机数", "parameters": {"type": "object", "properties": {"min": {"type": "integer", "description": "最小值", "default": 0}, "max": {"type": "integer", "description": "最大值", "default": 100}}, "required": []}}},
    {"type": "function", "function": {"name": "text_length", "description": "计算文本的字符数和单词数", "parameters": {"type": "object", "properties": {"text": {"type": "string", "description": "要统计的文本"}}, "required": ["text"]}}},
    {"type": "function", "function": {"name": "unit_converter", "description": "进行单位换算,支持长度、重量、温度等", "parameters": {"type": "object", "properties": {"value": {"type": "number", "description": "要转换的数值"}, "from_unit": {"type": "string", "description": "源单位,如km、miles、kg、pounds、celsius、fahrenheit"}, "to_unit": {"type": "string", "description": "目标单位"}}, "required": ["value", "from_unit", "to_unit"]}}},
    {"type": "function", "function": {"name": "get_current_weather", "description": "获取指定城市的当前天气信息,包括温度、湿度和天气状况", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "城市名称,如北京、上海、New York"}, "unit": {"type": "string", "description": "温度单位,celsius或fahrenheit", "enum": ["celsius", "fahrenheit"], "default": "celsius"}}, "required": ["location"]}}},
    {"type": "function", "function": {"name": "get_exchange_rate", "description": "查询两种货币之间的实时汇率", "parameters": {"type": "object", "properties": {"from_currency": {"type": "string", "description": "源货币代码,如USD、CNY、EUR"}, "to_currency": {"type": "string", "description": "目标货币代码,如USD、CNY、EUR"}}, "required": ["from_currency", "to_currency"]}}},
    {"type": "function", "function": {"name": "translate_text", "description": "将文本翻译成目标语言", "parameters": {"type": "object", "properties": {"text": {"type": "string", "description": "要翻译的文本"}, "target_language": {"type": "string", "description": "目标语言,如english、chinese、japanese、french"}}, "required": ["text", "target_language"]}}},
]

# ---------------------------------------------------------------------------
# Safe arithmetic evaluation for the calculate_math mock.
# The expression is model output, so it must never reach eval().
# ---------------------------------------------------------------------------

# Unicode / shorthand operators a model may emit, rewritten to Python syntax.
# Full-width parentheses are normalised to ASCII ones.
_MATH_CHAR_MAP = str.maketrans({
    "^": "**", "×": "*", "÷": "/", "−": "-", "²": "**2", "³": "**3", "（": "(", "）": ")",
})
_MATH_BINOPS = {
    ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul, ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv, ast.Mod: operator.mod, ast.Pow: operator.pow,
}
_MATH_UNARYOPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}
_MATH_FUNCS = {
    "sqrt": math.sqrt, "exp": math.exp, "log": math.log, "log2": math.log2, "log10": math.log10,
    "sin": math.sin, "cos": math.cos, "tan": math.tan, "floor": math.floor, "ceil": math.ceil,
    "abs": abs, "round": round,
}
_MATH_CONSTS = {"pi": math.pi, "e": math.e}
# Upper bound on the exponent of `**`, so inputs like 9**9**9 cannot stall the evaluation.
_MAX_EXPONENT = 1000


def _eval_math_node(node):
    """Recursively evaluate one whitelisted arithmetic AST node; raise ValueError otherwise."""
    if isinstance(node, ast.Constant) and type(node.value) in (int, float):
        return node.value
    if isinstance(node, ast.BinOp) and type(node.op) in _MATH_BINOPS:
        left = _eval_math_node(node.left)
        right = _eval_math_node(node.right)
        if isinstance(node.op, ast.Pow) and abs(right) > _MAX_EXPONENT:
            raise ValueError(f"指数过大: {right}")
        return _MATH_BINOPS[type(node.op)](left, right)
    if isinstance(node, ast.UnaryOp) and type(node.op) in _MATH_UNARYOPS:
        return _MATH_UNARYOPS[type(node.op)](_eval_math_node(node.operand))
    if isinstance(node, ast.Name) and node.id in _MATH_CONSTS:
        return _MATH_CONSTS[node.id]
    if (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
            and node.func.id in _MATH_FUNCS and not node.keywords):
        return _MATH_FUNCS[node.func.id](*[_eval_math_node(arg) for arg in node.args])
    raise ValueError(f"不支持的表达式元素: {type(node).__name__}")


def _safe_eval_math(expression):
    """Evaluate an arithmetic expression string without ``eval``.

    Only numbers, + - * / // % **, unary +/-, the constants in _MATH_CONSTS and
    the functions in _MATH_FUNCS are accepted; anything else raises ValueError.
    """
    tree = ast.parse(str(expression).translate(_MATH_CHAR_MAP).strip(), mode="eval")
    return _eval_math_node(tree.body)


# ---------------------------------------------------------------------------
# Unit conversion for the unit_converter mock.
# ---------------------------------------------------------------------------

# Factor that converts one unit into the base unit of its dimension (metre / kilogram).
_LENGTH_TO_M = {
    "km": 1000.0, "公里": 1000.0, "千米": 1000.0, "m": 1.0, "米": 1.0, "cm": 0.01, "mm": 0.001,
    "miles": 1609.344, "mile": 1609.344, "mi": 1609.344, "英里": 1609.344,
    "ft": 0.3048, "feet": 0.3048, "foot": 0.3048, "inch": 0.0254, "inches": 0.0254,
}
_MASS_TO_KG = {
    "kg": 1.0, "千克": 1.0, "公斤": 1.0, "g": 0.001, "克": 0.001,
    "pounds": 0.45359237, "pound": 0.45359237, "lb": 0.45359237, "lbs": 0.45359237, "磅": 0.45359237,
}
# Temperature aliases mapped to a canonical scale code: c / f / k.
_TEMPERATURE_SCALES = {
    "celsius": "c", "c": "c", "°c": "c", "摄氏度": "c",
    "fahrenheit": "f", "f": "f", "°f": "f", "华氏度": "f",
    "kelvin": "k", "k": "k", "开尔文": "k",
}


def _to_celsius(value, scale):
    """Convert a temperature on scale c/f/k into Celsius."""
    if scale == "f":
        return (value - 32) * 5 / 9
    if scale == "k":
        return value - 273.15
    return value


def _from_celsius(value, scale):
    """Convert a Celsius temperature into scale c/f/k."""
    if scale == "f":
        return value * 9 / 5 + 32
    if scale == "k":
        return value + 273.15
    return value


def _convert_units(value, from_unit, to_unit):
    """Convert ``value`` between two length, mass or temperature units.

    Raises:
        ValueError: if the units are unknown or belong to different dimensions.
    """
    src, dst = from_unit.strip().lower(), to_unit.strip().lower()
    for table in (_LENGTH_TO_M, _MASS_TO_KG):
        if src in table and dst in table:
            return value * table[src] / table[dst]
    if src in _TEMPERATURE_SCALES and dst in _TEMPERATURE_SCALES:
        return _from_celsius(_to_celsius(value, _TEMPERATURE_SCALES[src]), _TEMPERATURE_SCALES[dst])
    raise ValueError(f"不支持的单位换算: {from_unit} -> {to_unit}")


def _mock_unit_converter(args):
    """Mock for unit_converter: real conversion, result rounded to 2 decimals."""
    value = float(args.get("value", 0))
    from_unit = str(args.get("from_unit", ""))
    to_unit = str(args.get("to_unit", ""))
    return {"result": round(_convert_units(value, from_unit, to_unit), 2),
            "from": f"{args.get('value', 0)} {from_unit}", "to": to_unit}


def _mock_current_time(args):
    """Mock for get_current_time: current time in the requested IANA timezone.

    Falls back to the host's local time when ``zoneinfo`` is unavailable
    (Python < 3.9) or the timezone name cannot be resolved.
    """
    tz_name = args.get("timezone", "Asia/Shanghai")
    try:
        from zoneinfo import ZoneInfo
        now = datetime.now(ZoneInfo(tz_name))
    except (ImportError, KeyError, ValueError, TypeError, OSError):
        # ZoneInfoNotFoundError subclasses KeyError; bad names may raise ValueError/OSError.
        now = datetime.now()
    return {"datetime": now.strftime("%Y-%m-%d %H:%M:%S"), "timezone": tz_name}


def _mock_random_number(args):
    """Mock for random_number: inclusive random int; tolerates min > max by swapping."""
    low, high = sorted((int(args.get("min", 0)), int(args.get("max", 100))))
    return {"result": random.randint(low, high)}


# Tool name -> callable(args: dict) -> JSON-serialisable result dict.
MOCK_RESULTS = {
    "calculate_math": lambda args: {"result": str(_safe_eval_math(args.get("expression", "0")))},
    "get_current_time": _mock_current_time,
    "random_number": _mock_random_number,
    "text_length": lambda args: {"characters": len(args.get("text", "")), "words": len(args.get("text", "").split())},
    "unit_converter": _mock_unit_converter,
    "get_current_weather": lambda args: {"city": args.get("location"), "temperature": "22°C", "humidity": "65%", "condition": "晴"},
    "get_exchange_rate": lambda args: {"from": args.get("from_currency", ""), "to": args.get("to_currency", ""), "rate": 7.15},
    "translate_text": lambda args: {"translated": "hello world"},
}

TOOL_MAP = {t["function"]["name"]: t for t in TOOLS}


def get_tools(names):
    """Return the TOOLS schemas for ``names`` (KeyError on an unknown name)."""
    return [TOOL_MAP[n] for n in names]


TEST_CASES = [
    {"prompt": "帮我算一下 256 乘以 37 等于多少", "tools": ["calculate_math", "get_current_time"]},
    {"prompt": "现在几点了?", "tools": ["get_current_time", "random_number"]},
    {"prompt": "帮我把100公里换算成英里", "tools": ["unit_converter", "calculate_math"]},
    {"prompt": "帮我生成一个1到1000的随机数,然后计算它的平方", "tools": ["random_number", "calculate_math", "text_length"]},
    {"prompt": "北京今天天气怎么样?", "tools": ["get_current_weather", "get_current_time"]},
    {"prompt": "查一下美元兑人民币汇率", "tools": ["get_exchange_rate", "get_current_time"]},
    {"prompt": "把'你好世界'翻译成英文", "tools": ["translate_text", "text_length"]},
    {"prompt": "What is the weather in Tokyo? Also convert 30 celsius to fahrenheit.", "tools": ["get_current_weather", "unit_converter", "get_current_time"]},
]


def init_model(args):
    """Load the tokenizer and model described by ``args``.

    A ``load_from`` path containing 'model' loads native MiniMind ``.pth`` weights
    from ``save_dir``; any other path is loaded with transformers.

    Returns:
        (model, tokenizer) with the model in eval mode on ``args.device``.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.load_from)
    if 'model' in args.load_from:
        model = MiniMindForCausalLM(MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe)))
        moe_suffix = '_moe' if args.use_moe else ''
        ckp = f'./{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
        model.load_state_dict(torch.load(ckp, map_location=args.device), strict=True)
    else:
        model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
    get_model_params(model, model.config)
    # fp16 only on CUDA: half-precision kernels on CPU are slow or missing in older PyTorch.
    if str(args.device).startswith('cuda'):
        model = model.half()
    return model.eval().to(args.device), tokenizer


# Tool calls are expected as JSON wrapped in <tool_call>...</tool_call> tags.
# NOTE(review): tag format assumed from the MiniMind (Qwen-style) chat template —
# confirm against the tokenizer's chat_template if it changes.
_TOOL_CALL_RE = re.compile(r'<tool_call>(.*?)</tool_call>', re.DOTALL)


def parse_tool_calls(text):
    """Extract tool calls from locally generated text.

    Returns:
        List of parsed JSON objects (dicts with at least a non-empty "name");
        malformed or non-object payloads are skipped.
    """
    calls = []
    for raw in _TOOL_CALL_RE.findall(text or ""):
        try:
            data = json.loads(raw.strip())
        except ValueError:
            continue  # model emitted malformed JSON inside the tags
        if isinstance(data, dict) and data.get("name"):
            calls.append(data)
    return calls


def parse_tool_call_from_text(content):
    """Fallback for API backends that return tool calls as plain text.

    Returns:
        List of OpenAI-style ``{"id", "function": {"name", "arguments"}}`` dicts
        (``arguments`` as a JSON string), or None when nothing valid was found.
    """
    tool_calls = []
    for i, raw in enumerate(_TOOL_CALL_RE.findall(content or "")):
        try:
            data = json.loads(raw.strip())
        except ValueError:
            continue
        if not isinstance(data, dict):
            continue
        arguments = data.get("arguments", {})
        if not isinstance(arguments, str):  # avoid double-encoding pre-serialised arguments
            arguments = json.dumps(arguments, ensure_ascii=False)
        tool_calls.append({"id": f"call_{i}", "function": {"name": data.get("name", ""), "arguments": arguments}})
    return tool_calls or None


def execute_tool(call, arguments=None):
    """Run the mock implementation of a tool.

    Args:
        call: either a dict with "name" / "arguments" keys, or the tool name as a string.
        arguments: used only when ``call`` is a name; a JSON string or a dict.

    Returns:
        The mock's result dict, or ``{"error": ...}`` for unknown tools / failures.
    """
    name = call.get("name", "") if isinstance(call, dict) else call
    raw_args = call.get("arguments", {}) if isinstance(call, dict) else arguments
    try:
        args = json.loads(raw_args) if isinstance(raw_args, str) else raw_args
    except ValueError:
        args = {}  # unparsable arguments: let the tool fall back to its defaults
    fn = MOCK_RESULTS.get(name)
    if not fn:
        return {"error": f"未知工具: {name}"}
    try:
        return fn(args)
    except Exception as e:  # tool boundary: report the failure back to the model
        return {"error": f"工具执行失败: {str(e)[:80]}"}


def generate(model, tokenizer, messages, tools, args):
    """Generate one assistant turn with the local model, streaming it to stdout.

    Returns:
        The decoded newly generated text (special tokens skipped).
    """
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=tools, open_thinking=False)
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True).to(args.device)
    st = time.time()
    print('🧠: ', end='')
    generated_ids = model.generate(
        inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=args.max_new_tokens,
        do_sample=True, streamer=streamer, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id,
        top_p=args.top_p, temperature=args.temperature
    )
    prompt_len = len(inputs["input_ids"][0])
    response = tokenizer.decode(generated_ids[0][prompt_len:], skip_special_tokens=True)
    if args.show_speed:
        gen_tokens = len(generated_ids[0]) - prompt_len
        elapsed = max(time.time() - st, 1e-9)
        print(f'\n[Speed]: {gen_tokens / elapsed:.2f} tokens/s')
    else:
        print()
    return response


def chat_api(client, messages, tools, args, stream=True):
    """Run one assistant turn through an OpenAI-compatible chat completions API.

    Returns:
        (content, tool_calls). ``tool_calls`` is the SDK's tool-call objects
        (non-stream), a list of merged dicts (stream), the text-parsed fallback,
        or None/empty when the model did not call a tool.
    """
    response = client.chat.completions.create(
        model=args.api_model, messages=messages, tools=tools, stream=stream,
        temperature=args.temperature, max_tokens=8192, top_p=args.top_p
    )
    if not stream:
        message = response.choices[0].message
        content = message.content or ""
        tool_calls = message.tool_calls or parse_tool_call_from_text(content)
        print(f'🧠: {content}')
        return content, tool_calls

    print('🧠: ', end='', flush=True)
    pieces, tool_calls = [], None
    for chunk in response:
        if not chunk.choices:
            continue  # some servers send usage / keep-alive chunks without choices
        delta = chunk.choices[0].delta
        if delta.content:
            print(delta.content, end="", flush=True)
            pieces.append(delta.content)
        if delta.tool_calls:
            if tool_calls is None:
                tool_calls = []
            # Tool calls arrive as fragments keyed by index; concatenate per index.
            for tc_chunk in delta.tool_calls:
                idx = tc_chunk.index if tc_chunk.index is not None else len(tool_calls)
                while len(tool_calls) <= idx:
                    tool_calls.append({"id": "", "function": {"name": "", "arguments": ""}})
                if tc_chunk.id:
                    tool_calls[idx]["id"] += tc_chunk.id
                if tc_chunk.function:
                    if tc_chunk.function.name:
                        tool_calls[idx]["function"]["name"] += tc_chunk.function.name
                    if tc_chunk.function.arguments:
                        tool_calls[idx]["function"]["arguments"] += tc_chunk.function.arguments
    print()
    content = "".join(pieces)
    if not tool_calls:
        tool_calls = parse_tool_call_from_text(content)
    return content, tool_calls


def _normalize_api_tool_calls(tool_calls):
    """Flatten SDK tool-call objects or merged dicts into {"id", "name", "arguments"} dicts."""
    normalized = []
    for tc in tool_calls:
        if hasattr(tc, "function"):  # openai SDK object (non-stream response)
            normalized.append({"id": tc.id or "", "name": tc.function.name, "arguments": tc.function.arguments})
        else:  # dict from chat_api's stream merger or parse_tool_call_from_text
            normalized.append({"id": tc.get("id", ""), "name": tc["function"]["name"], "arguments": tc["function"]["arguments"]})
    return normalized


def run_case(prompt, tools, args, model=None, tokenizer=None, client=None, max_rounds=10):
    """Run one prompt through the tool-calling loop.

    Each round asks the model for a reply; if it contains tool calls, they are
    executed and their results appended as ``tool`` messages before the next round.
    Stops when the model replies without tool calls or after ``max_rounds`` rounds.
    """
    messages = [{"role": "user", "content": prompt}]
    is_local = args.backend == 'local'
    for _ in range(max_rounds):
        if is_local:
            content = generate(model, tokenizer, messages, tools, args)
            tool_calls = parse_tool_calls(content)
        else:
            content, tool_calls = chat_api(client, messages, tools, args, stream=bool(args.stream))
        if not tool_calls:
            return

        if is_local:
            # Local template re-reads the raw <tool_call> text from the assistant content.
            messages.append({"role": "assistant", "content": content})
        else:
            tool_calls = _normalize_api_tool_calls(tool_calls)
            messages.append({"role": "assistant", "content": content, "tool_calls": [
                {"id": tc["id"], "type": "function", "function": {"name": tc["name"], "arguments": tc["arguments"]}}
                for tc in tool_calls
            ]})

        for tc in tool_calls:
            name = tc.get("name", "")
            arguments = tc.get("arguments", {})  # tools without parameters may omit it
            print(f'📞 [Tool Calling]: {name} | args={arguments}')
            result = execute_tool(tc if is_local else name, arguments)
            result_json = json.dumps(result, ensure_ascii=False)
            print(f'✅ [Tool Called]: {result_json}')
            tool_message = {"role": "tool", "content": result_json}
            if not is_local:
                tool_message["tool_call_id"] = tc["id"]
            messages.append(tool_message)
    print(f'⚠️ 已达到最大工具调用轮数({max_rounds}),停止该用例')


def _manual_cases():
    """Yield interactive cases (all TOOLS available) until the user enters an empty prompt."""
    tool_names = [t["function"]["name"] for t in TOOLS]
    while True:
        prompt = input('💬: ')
        if not prompt:
            return
        yield {"prompt": prompt, "tools": TOOLS, "tool_names": tool_names}


def main():
    """CLI entry point: parse arguments, set up the backend, run auto or manual cases."""
    parser = argparse.ArgumentParser(description="MiniMind ToolCall评测")
    parser.add_argument('--backend', default='local', choices=['local', 'api'], type=str, help="推理后端(local=本地模型,api=OpenAI兼容接口)")
    parser.add_argument('--load_from', default='../model', type=str, help="模型加载路径(model=原生torch权重,其他路径=transformers格式)")
    parser.add_argument('--save_dir', default='../out', type=str, help="模型权重目录")
    parser.add_argument('--weight', default='full_sft', type=str, help="权重名称前缀(pretrain, full_sft, rlhf, reason, ppo_actor, grpo, spo)")
    parser.add_argument('--hidden_size', default=768, type=int, help="隐藏层维度")
    parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
    parser.add_argument('--max_new_tokens', default=512, type=int, help="最大生成长度")
    parser.add_argument('--temperature', default=0.9, type=float, help="生成温度,控制随机性(0-1,越大越随机)")
    parser.add_argument('--top_p', default=0.9, type=float, help="nucleus采样阈值(0-1)")
    parser.add_argument('--show_speed', default=0, type=int, help="显示decode速度(tokens/s)")
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="运行设备")
    parser.add_argument('--api_base_url', default="http://localhost:11434/v1", type=str, help="OpenAI兼容接口的base_url")
    parser.add_argument('--api_key', default='sk-123', type=str, help="OpenAI兼容接口的api_key")
    parser.add_argument('--api_model', default='jingyaogong/minimind-3:latest', type=str, help="API请求时使用的模型名称")
    parser.add_argument('--stream', default=1, type=int, help="API模式下是否流式输出(0=否,1=是)")
    parser.add_argument('--max_tool_rounds', default=10, type=int, help="单个用例最多的模型调用轮数(防止反复调用工具导致死循环)")
    args = parser.parse_args()

    model = tokenizer = client = None
    if args.backend == 'local':
        model, tokenizer = init_model(args)
    else:
        client = OpenAI(api_key=args.api_key, base_url=args.api_base_url)

    try:
        input_mode = int(input('[0] 自动测试\n[1] 手动输入\n'))
    except ValueError:
        input_mode = 1  # non-numeric choice: fall back to manual input instead of crashing

    if input_mode == 0:
        cases = [{"prompt": case["prompt"], "tools": get_tools(case["tools"]), "tool_names": case["tools"]} for case in TEST_CASES]
    else:
        cases = _manual_cases()

    for case in cases:
        setup_seed(random.randint(0, 31415926))
        if input_mode == 0:
            print(f'📦 可用工具: {case["tool_names"]}\n')
            print(f'💬: {case["prompt"]}')
        run_case(case["prompt"], case["tools"], args, model=model, tokenizer=tokenizer, client=client,
                 max_rounds=args.max_tool_rounds)
        print('\n' + '-' * 50 + '\n')


if __name__ == "__main__":
    main()