mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-27 06:02:00 +08:00
186 lines
4.8 KiB
TypeScript
186 lines
4.8 KiB
TypeScript
import { Tool, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources'
|
|
import { FunctionCall, FunctionDeclaration, SchemaType, Tool as geminiToool } from '@google/generative-ai'
|
|
import { MCPServer, MCPTool, MCPToolResponse } from '@renderer/types'
|
|
import { ChatCompletionMessageToolCall, ChatCompletionTool } from 'openai/resources'
|
|
|
|
import { ChunkCallbackData } from '../providers'
|
|
|
|
/**
 * Whitelist of JSON-Schema keywords kept on each tool parameter by
 * `filterPropertieAttributes`; every other keyword of a property schema is
 * dropped. `format` is intentionally not in the list.
 */
const supportedAttributes: string[] = [
  'type',
  'nullable',
  'required',
  'description',
  'properties',
  'items',
  'enum',
  'anyOf'
]
|
|
|
|
/**
 * Returns a copy of `tool.inputSchema.properties` in which every property
 * schema keeps only the keywords listed in `supportedAttributes`.
 *
 * The tool itself is NOT modified. The previous implementation wrote the
 * filtered entries back into `tool.inputSchema.properties`, permanently
 * stripping keywords from the shared tool definition. The Anthropic
 * converter, which sends `inputSchema` unfiltered, also reads that definition.
 *
 * Only the top level of `properties` is filtered; nested schemas kept under
 * `properties` / `items` / `anyOf` are copied by reference as-is.
 *
 * @param tool - MCP tool whose input-schema properties should be filtered.
 * @returns A new map of property name -> filtered property schema, or `{}`
 *          when the schema declares no properties.
 */
function filterPropertieAttributes(tool: MCPTool): Record<string, any> {
  const properties: Record<string, any> = tool.inputSchema?.properties ?? {}

  // Keep only whitelisted keys of one property schema; non-object values pass through untouched
  // instead of throwing (null) or being exploded into character entries (strings).
  const pickSupported = (schema: unknown): unknown => {
    if (schema === null || typeof schema !== 'object') return schema
    return Object.fromEntries(Object.entries(schema).filter(([key]) => supportedAttributes.includes(key)))
  }

  return Object.fromEntries(Object.entries(properties).map(([name, schema]) => [name, pickSupported(schema)]))
}
|
|
|
|
/**
 * Converts MCP tools into OpenAI Chat Completions `function` tools.
 *
 * The tool `id` is used as the function name, so a tool call returned by the
 * model can be mapped back with `openAIToolsToMcpTool`. Property schemas are
 * reduced to the supported keyword whitelist.
 *
 * The schema's top-level `required` list is forwarded when it is a non-empty
 * array. Previously it was dropped, which left every argument optional from
 * the model's point of view.
 *
 * @param mcpTools - Tools to expose to the model.
 * @returns One `function` tool per MCP tool, in the same order.
 */
export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
  return mcpTools.map((tool): ChatCompletionTool => {
    const required = tool.inputSchema?.required
    return {
      type: 'function',
      function: {
        name: tool.id,
        description: tool.description,
        parameters: {
          type: 'object',
          properties: filterPropertieAttributes(tool),
          ...(Array.isArray(required) && required.length > 0 ? { required } : {})
        }
      }
    }
  })
}
|
|
|
|
/**
 * Parses the JSON-encoded arguments of a tool call into a plain object.
 *
 * An empty or whitespace-only string (e.g. a call with no arguments) yields
 * `{}` without logging. Malformed JSON is logged and yields `{}`. A JSON value
 * that is not a plain object (`null`, numbers, arrays, ...) also yields `{}`,
 * because tool arguments must be a keyed object.
 */
function parseToolCallArguments(raw: string | null | undefined): any {
  if (!raw || raw.trim() === '') return {}
  try {
    const parsed: unknown = JSON.parse(raw)
    return parsed !== null && typeof parsed === 'object' && !Array.isArray(parsed) ? parsed : {}
  } catch (e) {
    console.error('Error parsing arguments', e)
    return {}
  }
}

/**
 * Resolves an OpenAI tool call back to the MCP tool it targets.
 *
 * Returns a NEW MCPTool whose `inputSchema` carries the call's parsed
 * arguments; `callMCPTool` forwards that field as `args`. The entry in
 * `mcpTools` is not modified.
 *
 * @param mcpTools - Tools that were offered to the model; `undefined` yields `undefined`.
 * @param llmTool - Tool call from the model; `function.name` is matched against `MCPTool.id`.
 * @returns The matching tool carrying the call arguments, or `undefined` if no tool matches.
 */
export function openAIToolsToMcpTool(
  mcpTools: MCPTool[] | undefined,
  llmTool: ChatCompletionMessageToolCall
): MCPTool | undefined {
  if (!mcpTools) return undefined

  const tool = mcpTools.find((candidate) => candidate.id === llmTool.function.name)
  if (!tool) return undefined

  console.log(
    `[MCP] OpenAI Tool to MCP Tool: ${tool.serverName} ${tool.name}`,
    tool,
    'args',
    llmTool.function.arguments
  )

  return {
    id: tool.id,
    serverName: tool.serverName,
    name: tool.name,
    description: tool.description,
    inputSchema: parseToolCallArguments(llmTool.function.arguments)
  }
}
|
|
|
|
/**
 * Invokes an MCP tool via `window.api.mcp.callTool`.
 *
 * `tool.inputSchema` is sent as the call `args`. The tool is therefore
 * expected to already carry the call arguments in that field, as produced by
 * the `*ToMcpTool` / `*FunctionCallToMcpTool` converters, not its JSON schema.
 *
 * @param tool - Tool to call; `serverName` selects the client and `name` the tool.
 * @returns Whatever `callTool` resolves with (also logged).
 */
export async function callMCPTool(tool: MCPTool): Promise<any> {
  const { serverName, name, inputSchema } = tool
  console.log(`[MCP] Calling Tool: ${serverName} ${name}`, tool)

  const result = await window.api.mcp.callTool({ client: serverName, name, args: inputSchema })

  console.log(`[MCP] Tool called: ${serverName} ${name}`, result)
  return result
}
|
|
|
|
/**
 * Converts MCP tools into Anthropic tool definitions.
 *
 * The tool `id` becomes the Anthropic tool name, so a `tool_use` block can be
 * mapped back with `anthropicToolUseToMcpTool`. The MCP input schema is passed
 * through as `input_schema` unchanged (no keyword filtering).
 *
 * @param mcpTools - Tools to expose to the model.
 * @returns One Anthropic tool per MCP tool, in the same order.
 */
export function mcpToolsToAnthropicTools(mcpTools: MCPTool[]): Array<ToolUnion> {
  const toAnthropicTool = (tool: MCPTool): Tool => ({
    name: tool.id,
    description: tool.description,
    // @ts-ignore no check
    input_schema: tool.inputSchema
  })
  return mcpTools.map(toAnthropicTool)
}
|
|
|
|
/**
 * Resolves an Anthropic `tool_use` block to the MCP tool it targets.
 *
 * Returns a COPY of the matching tool whose `inputSchema` holds
 * `toolUse.input`, i.e. the call arguments later forwarded by `callMCPTool`.
 * The tool inside `mcpTools` is left untouched. The previous implementation
 * overwrote its `inputSchema` in place. Re-converting the same tool objects
 * (e.g. via `mcpToolsToAnthropicTools`) would then advertise the last call's
 * arguments as the tool's schema.
 *
 * @param mcpTools - Tools offered to the model; `undefined` yields `undefined`.
 * @param toolUse - The `tool_use` block; `name` is matched against `MCPTool.id`.
 * @returns A new MCPTool carrying the call arguments, or `undefined` when no tool matches.
 */
export function anthropicToolUseToMcpTool(mcpTools: MCPTool[] | undefined, toolUse: ToolUseBlock): MCPTool | undefined {
  if (!mcpTools) return undefined

  const tool = mcpTools.find((candidate) => candidate.id === toolUse.name)
  if (!tool) return undefined

  // The SDK's type for `input` does not match MCPTool['inputSchema']; it holds call arguments, not a schema.
  return { ...tool, inputSchema: toolUse.input as MCPTool['inputSchema'] }
}
|
|
|
|
/**
 * Converts MCP tools into a single Gemini `Tool` holding one function
 * declaration per MCP tool. The tool `id` becomes the function name.
 *
 * Property schemas are reduced to the supported keyword whitelist, and the
 * schema's top-level `required` list is forwarded when it is non-empty.
 *
 * Tools without any input properties are declared WITHOUT `parameters`. The
 * Gemini API rejects an OBJECT schema whose `properties` is empty ("should be
 * non-empty for OBJECT type"), which would fail the whole request.
 *
 * @param mcpTools - Tools to expose; `undefined` or empty yields `[]`.
 * @returns `[]` when there are no tools, otherwise a one-element array.
 */
export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiToool[] {
  if (!mcpTools || mcpTools.length === 0) {
    // No tools available
    return []
  }

  const functionDeclarations: FunctionDeclaration[] = mcpTools.map((tool) => {
    const declaration: FunctionDeclaration = {
      name: tool.id,
      description: tool.description
    }

    const properties = filterPropertieAttributes(tool)
    if (Object.keys(properties).length > 0) {
      const required = tool.inputSchema?.required
      declaration.parameters = {
        type: SchemaType.OBJECT,
        properties,
        ...(Array.isArray(required) && required.length > 0 ? { required } : {})
      }
    }

    return declaration
  })

  return [{ functionDeclarations }]
}
|
|
|
|
/**
 * Resolves a Gemini function call to the MCP tool it targets.
 *
 * Returns a COPY of the matching tool whose `inputSchema` holds `fcall.args`,
 * i.e. the call arguments later forwarded by `callMCPTool`. The tool inside
 * `mcpTools` is left untouched. The previous implementation overwrote its
 * `inputSchema` in place, replacing the shared tool's JSON schema with the
 * last call's arguments.
 *
 * @param mcpTools - Tools offered to the model; `undefined` yields `undefined`.
 * @param fcall - The function call; `undefined` yields `undefined`. `name` is matched against `MCPTool.id`.
 * @returns A new MCPTool carrying the call arguments, or `undefined` when no tool matches.
 */
export function geminiFunctionCallToMcpTool(
  mcpTools: MCPTool[] | undefined,
  fcall: FunctionCall | undefined
): MCPTool | undefined {
  if (!fcall || !mcpTools) return undefined

  const tool = mcpTools.find((candidate) => candidate.id === fcall.name)
  if (!tool) return undefined

  // The SDK's type for `args` does not match MCPTool['inputSchema']; it holds call arguments, not a schema.
  return { ...tool, inputSchema: fcall.args as MCPTool['inputSchema'] }
}
|
|
|
|
/**
 * Inserts or updates a tool response in `results`, matched by `id`, then
 * emits the whole list through `onChunk` with `'\n'` as the text.
 *
 * When an entry with the same `id` exists, only its `response` and `status`
 * are copied over (the entry object is updated in place). Otherwise `resp`
 * itself is appended. `onChunk` is invoked from a `finally`, so it fires even
 * if the update throws.
 *
 * @param results - Accumulated responses; mutated in place.
 * @param resp - The new or updated response.
 * @param onChunk - Receives `{ text: '\n', mcpToolResponse: results }`.
 */
export function upsertMCPToolResponse(
  results: MCPToolResponse[],
  resp: MCPToolResponse,
  onChunk: ({ mcpToolResponse }: ChunkCallbackData) => void
) {
  try {
    const existing = results.find((entry) => entry.id === resp.id)
    if (existing === undefined) {
      results.push(resp)
    } else {
      existing.response = resp.response
      existing.status = resp.status
    }
  } finally {
    onChunk({
      text: '\n',
      mcpToolResponse: results
    })
  }
}
|
|
|
|
/**
 * Restricts `mcpTools` to tools whose `serverName` equals the `name` of one
 * of the enabled servers.
 *
 * @param mcpTools - All known tools; `undefined` is returned unchanged.
 * @param enabledServers - Servers currently enabled; `undefined` disables every tool.
 * @returns `undefined` when `mcpTools` is undefined, `[]` when no server list
 *          is given, otherwise a new array with the matching tools.
 */
export function filterMCPTools(
  mcpTools: MCPTool[] | undefined,
  enabledServers: MCPServer[] | undefined
): MCPTool[] | undefined {
  if (!mcpTools) return mcpTools
  if (!enabledServers) return []
  return mcpTools.filter((tool) => enabledServers.some((server) => server.name === tool.serverName))
}
|