diff --git a/src/main/apiServer/middleware/gateway.ts b/src/main/apiServer/middleware/gateway.ts index 4fa4f5f77f..2de7a6a726 100644 --- a/src/main/apiServer/middleware/gateway.ts +++ b/src/main/apiServer/middleware/gateway.ts @@ -7,13 +7,210 @@ * - Model injection for simplified external app integration */ +import { buildFunctionCallToolName } from '@main/utils/mcp' +import type { MCPTool } from '@types' import type { NextFunction, Request, Response } from 'express' import { loggerService } from '../../services/LoggerService' +import { reduxService } from '../../services/ReduxService' import { config } from '../config' +import { mcpApiService } from '../services/mcp' +import { getMCPServersFromRedux } from '../utils/mcp' const logger = loggerService.withContext('GatewayMiddleware') +type AssistantConfig = { + id: string + name: string + prompt?: string + model?: { id: string; provider: string } + defaultModel?: { id: string; provider: string } + settings?: { + streamOutput?: boolean + enableTemperature?: boolean + temperature?: number + enableTopP?: boolean + topP?: number + enableMaxTokens?: boolean + maxTokens?: number + } + mcpServers?: Array<{ id: string }> + allowed_tools?: string[] +} + +type ToolDefinition = { + name: string + description?: string + inputSchema: MCPTool['inputSchema'] +} + +const getEndpointFormat = (endpoint: string): 'openai' | 'anthropic' | 'responses' | null => { + if (endpoint.startsWith('/v1/chat/completions')) return 'openai' + if (endpoint.startsWith('/v1/messages')) return 'anthropic' + if (endpoint.startsWith('/v1/responses')) return 'responses' + return null +} + +const buildAssistantModelId = (assistant: AssistantConfig): string | null => { + const model = assistant.model ?? assistant.defaultModel + if (!model?.provider || !model?.id) { + return null + } + return `${model.provider}:${model.id}` +} + +const applyAssistantMessageOverrides = ( + body: Record, + assistant: AssistantConfig, + format: 'openai' | 'anthropic' | 'responses' +): Record => { + const nextBody = { ...body } + const prompt = assistant.prompt ?? '' + + if (format === 'openai') { + const messages = Array.isArray(nextBody.messages) ? nextBody.messages : [] + const filtered = messages.filter((message) => message?.role !== 'system' && message?.role !== 'developer') + if (prompt.trim().length > 0) { + filtered.unshift({ role: 'system', content: prompt }) + } + nextBody.messages = filtered + } else if (format === 'responses') { + nextBody.instructions = prompt + } else { + nextBody.system = prompt + } + + return nextBody +} + +const applyAssistantParameterOverrides = ( + body: Record, + assistant: AssistantConfig, + format: 'openai' | 'anthropic' | 'responses' +): Record => { + const nextBody = { ...body } + const settings = assistant.settings ?? {} + + if (typeof settings.streamOutput === 'boolean') { + nextBody.stream = settings.streamOutput + } + + if (settings.enableTemperature && typeof settings.temperature === 'number') { + nextBody.temperature = settings.temperature + } else if ('temperature' in nextBody) { + delete nextBody.temperature + } + + if (settings.enableTopP && typeof settings.topP === 'number') { + nextBody.top_p = settings.topP + } else if ('top_p' in nextBody) { + delete nextBody.top_p + } + + if (settings.enableMaxTokens && typeof settings.maxTokens === 'number') { + if (format === 'responses') { + nextBody.max_output_tokens = settings.maxTokens + delete nextBody.max_tokens + } else { + nextBody.max_tokens = settings.maxTokens + if ('max_output_tokens' in nextBody) { + delete nextBody.max_output_tokens + } + } + } else { + if ('max_tokens' in nextBody) { + delete nextBody.max_tokens + } + if ('max_output_tokens' in nextBody) { + delete nextBody.max_output_tokens + } + } + + delete nextBody.tool_choice + + return nextBody +} + +const mapToolsForOpenAI = (tools: ToolDefinition[]) => + tools.map((toolDef) => ({ + type: 'function', + function: { + name: toolDef.name, + description: toolDef.description || '', + parameters: toolDef.inputSchema + } + })) + +const mapToolsForResponses = (tools: ToolDefinition[]) => + tools.map((toolDef) => ({ + type: 'function', + name: toolDef.name, + description: toolDef.description || '', + parameters: toolDef.inputSchema + })) + +const mapToolsForAnthropic = (tools: ToolDefinition[]) => + tools.map((toolDef) => ({ + name: toolDef.name, + description: toolDef.description || '', + input_schema: toolDef.inputSchema + })) + +const buildAssistantTools = async (assistant: AssistantConfig): Promise => { + const serverIds = assistant.mcpServers?.map((server) => server.id).filter(Boolean) ?? [] + if (serverIds.length === 0) { + return [] + } + + const allowedTools = Array.isArray(assistant.allowed_tools) ? new Set(assistant.allowed_tools) : null + const servers = await getMCPServersFromRedux() + const tools: ToolDefinition[] = [] + + for (const serverId of serverIds) { + const server = servers.find((entry) => entry.id === serverId) + if (!server || !server.isActive) { + continue + } + + const info = await mcpApiService.getServerInfo(serverId) + if (!info?.tools || !Array.isArray(info.tools)) { + continue + } + + for (const tool of info.tools as Array<{ + name: string + description?: string + inputSchema?: MCPTool['inputSchema'] + }>) { + if (!tool?.name || !tool.inputSchema) { + continue + } + + if (server.disabledTools?.includes(tool.name)) { + continue + } + + const toolName = buildFunctionCallToolName(info.name, tool.name) + if (allowedTools && !allowedTools.has(toolName)) { + continue + } + + tools.push({ + name: toolName, + description: tool.description, + inputSchema: tool.inputSchema + }) + } + } + + return tools +} + +const resolveAssistantById = async (assistantId: string): Promise => { + const assistants = (await reduxService.select('state.assistants.assistants')) as AssistantConfig[] | null + return assistants?.find((assistant) => assistant.id === assistantId) ?? null +} + /** * Gateway middleware for model group routing * @@ -43,16 +240,79 @@ export const gatewayMiddleware = async (req: Request, res: Response, next: NextF return } - // Inject the group's model into the request - req.body = { - ...req.body, - model: `${group.providerId}:${group.modelId}` - } + const endpoint = req.path.startsWith('/') ? req.path : `/${req.path}` + const endpointFormat = getEndpointFormat(endpoint) - logger.debug('Injected model from group', { - groupName, - model: `${group.providerId}:${group.modelId}` - }) + if (group.mode === 'assistant' && group.assistantId) { + if (!endpointFormat) { + res.status(400).json({ + error: { + type: 'invalid_request_error', + message: `Unsupported endpoint ${endpoint}` + } + }) + return + } + + const assistant = await resolveAssistantById(group.assistantId) + if (!assistant) { + res.status(400).json({ + error: { + type: 'invalid_request_error', + message: `Assistant '${group.assistantId}' not found` + } + }) + return + } + + const modelId = buildAssistantModelId(assistant) + if (!modelId) { + res.status(400).json({ + error: { + type: 'invalid_request_error', + message: `Assistant '${group.assistantId}' is missing a model` + } + }) + return + } + + let nextBody = { + ...req.body, + model: modelId + } + + nextBody = applyAssistantMessageOverrides(nextBody, assistant, endpointFormat) + nextBody = applyAssistantParameterOverrides(nextBody, assistant, endpointFormat) + + const tools = await buildAssistantTools(assistant) + if (endpointFormat === 'openai') { + nextBody.tools = tools.length > 0 ? mapToolsForOpenAI(tools) : undefined + } else if (endpointFormat === 'responses') { + nextBody.tools = tools.length > 0 ? mapToolsForResponses(tools) : undefined + } else { + nextBody.tools = tools.length > 0 ? mapToolsForAnthropic(tools) : undefined + } + + req.body = nextBody + + logger.debug('Injected assistant preset into request', { + groupName, + assistantId: assistant.id, + model: modelId, + toolCount: tools.length + }) + } else { + // Inject the group's model into the request + req.body = { + ...req.body, + model: `${group.providerId}:${group.modelId}` + } + + logger.debug('Injected model from group', { + groupName, + model: `${group.providerId}:${group.modelId}` + }) + } } // Get the endpoint path (for group routes, use the part after groupName) diff --git a/src/renderer/src/i18n/locales/en-us.json b/src/renderer/src/i18n/locales/en-us.json index feb2a804e9..f73c531cbc 100644 --- a/src/renderer/src/i18n/locales/en-us.json +++ b/src/renderer/src/i18n/locales/en-us.json @@ -353,6 +353,13 @@ "description": "Create model groups with unique URLs for different provider/model combinations", "empty": "No model groups configured. Click 'Add Group' to create one.", "label": "Model Groups", + "mode": { + "assistant": "Assistant Preset", + "assistantHint": "Assistant preset overrides request parameters.", + "assistantPlaceholder": "Select assistant", + "label": "Mode", + "model": "Direct Model" + }, "namePlaceholder": "Group name" }, "networkAccess": { diff --git a/src/renderer/src/i18n/locales/zh-cn.json b/src/renderer/src/i18n/locales/zh-cn.json index da208d2407..f5da50c87e 100644 --- a/src/renderer/src/i18n/locales/zh-cn.json +++ b/src/renderer/src/i18n/locales/zh-cn.json @@ -353,6 +353,13 @@ "description": "创建具有唯一 URL 的模型分组,用于不同的提供商/模型组合", "empty": "尚未配置模型分组。点击「添加分组」创建一个。", "label": "模型分组", + "mode": { + "assistant": "助手预设", + "assistantHint": "助手预设会覆盖请求参数。", + "assistantPlaceholder": "选择助手", + "label": "模式", + "model": "直接模型" + }, "namePlaceholder": "分组名称" }, "networkAccess": { diff --git a/src/renderer/src/pages/settings/ToolSettings/ApiServerSettings/ApiServerSettings.tsx b/src/renderer/src/pages/settings/ToolSettings/ApiServerSettings/ApiServerSettings.tsx index ae8b549086..65a2e3df60 100644 --- a/src/renderer/src/pages/settings/ToolSettings/ApiServerSettings/ApiServerSettings.tsx +++ b/src/renderer/src/pages/settings/ToolSettings/ApiServerSettings/ApiServerSettings.tsx @@ -46,6 +46,7 @@ const ApiServerSettings: FC = () => { // API Gateway state with proper defaults const apiServerConfig = useSelector((state: RootState) => state.settings.apiServer) + const assistants = useSelector((state: RootState) => state.assistants.assistants) const { apiServerRunning, apiServerLoading, startApiServer, stopApiServer, restartApiServer, setApiServerEnabled } = useApiServer() @@ -108,6 +109,8 @@ const ApiServerSettings: FC = () => { name: `group-${apiServerConfig.modelGroups.length + 1}`, // URL-safe name providerId: '', modelId: '', + mode: 'model', + assistantId: '', createdAt: Date.now() } dispatch(addApiGatewayModelGroup(newGroup)) @@ -266,7 +269,13 @@ const ApiServerSettings: FC = () => { ) : ( {apiServerConfig.modelGroups.map((group) => ( - + ))} )} @@ -295,6 +304,7 @@ const ApiServerSettings: FC = () => { // Model Group Card Component interface ModelGroupCardProps { group: ModelGroup + assistants: RootState['assistants']['assistants'] onUpdate: (group: ModelGroup) => void onDelete: (groupId: string) => void } @@ -305,11 +315,12 @@ const ENV_FORMAT_TO_ENDPOINT: Record = { responses: '/v1/responses' } -const ModelGroupCard: FC = ({ group, onUpdate, onDelete }) => { +const ModelGroupCard: FC = ({ group, assistants, onUpdate, onDelete }) => { const { t } = useTranslation() const { providers } = useProviders() const apiServerConfig = useSelector((state: RootState) => state.settings.apiServer) const [envFormat, setEnvFormat] = useState('openai') + const mode = group.mode ?? 'model' // Reset envFormat when selected endpoint is disabled useEffect(() => { @@ -398,7 +409,36 @@ const ModelGroupCard: FC = ({ group, onUpdate, onDelete }) }) } - const isConfigured = group.providerId && group.modelId + const handleModeChange = (nextMode: 'model' | 'assistant') => { + if (nextMode === mode) return + onUpdate({ + ...group, + mode: nextMode, + assistantId: nextMode === 'assistant' ? group.assistantId || '' : '', + providerId: nextMode === 'model' ? group.providerId : '', + modelId: nextMode === 'model' ? group.modelId : '' + }) + } + + const handleAssistantChange = (assistantId: string | null) => { + onUpdate({ + ...group, + assistantId: assistantId || '' + }) + } + + const selectedAssistant = useMemo(() => { + if (!group.assistantId) return undefined + return assistants.find((assistant) => assistant.id === group.assistantId) + }, [assistants, group.assistantId]) + + const assistantModelLabel = useMemo(() => { + const model = selectedAssistant?.model ?? selectedAssistant?.defaultModel + if (!model) return undefined + return model.name || model.id + }, [selectedAssistant]) + + const isConfigured = mode === 'assistant' ? !!group.assistantId : !!(group.providerId && group.modelId) return ( @@ -427,35 +467,71 @@ const ModelGroupCard: FC = ({ group, onUpdate, onDelete }) + + {t('apiGateway.fields.modelGroups.mode.label', 'Mode')} + handleModeChange(value as 'model' | 'assistant')} + options={[ + { label: t('apiGateway.fields.modelGroups.mode.model', 'Direct Model'), value: 'model' }, + { label: t('apiGateway.fields.modelGroups.mode.assistant', 'Assistant Preset'), value: 'assistant' } + ]} + /> + - handleProviderChange((value as string) || null)} - placeholder={t('apiGateway.fields.defaultModel.providerPlaceholder')} - allowClear - showSearch - optionFilterProp="label" - style={{ flex: 1 }} - options={providers.map((p) => ({ - value: p.id, - label: p.isSystem ? getProviderLabel(p.id) : p.name - }))} - /> - handleModelChange((value as string) || null)} - placeholder={t('apiGateway.fields.defaultModel.modelPlaceholder')} - allowClear - showSearch - optionFilterProp="label" - disabled={!group.providerId} - style={{ flex: 1 }} - options={models.map((m) => ({ - value: m.id, - label: m.name || m.id - }))} - /> + {mode === 'assistant' ? ( + handleAssistantChange((value as string) || null)} + placeholder={t('apiGateway.fields.modelGroups.mode.assistantPlaceholder', 'Select assistant')} + allowClear + showSearch + optionFilterProp="label" + style={{ flex: 1 }} + options={assistants.map((assistant) => ({ + value: assistant.id, + label: assistant.name + }))} + /> + ) : ( + <> + handleProviderChange((value as string) || null)} + placeholder={t('apiGateway.fields.defaultModel.providerPlaceholder')} + allowClear + showSearch + optionFilterProp="label" + style={{ flex: 1 }} + options={providers.map((p) => ({ + value: p.id, + label: p.isSystem ? getProviderLabel(p.id) : p.name + }))} + /> + handleModelChange((value as string) || null)} + placeholder={t('apiGateway.fields.defaultModel.modelPlaceholder')} + allowClear + showSearch + optionFilterProp="label" + disabled={!group.providerId} + style={{ flex: 1 }} + options={models.map((m) => ({ + value: m.id, + label: m.name || m.id + }))} + /> + + )} + {mode === 'assistant' && ( + + {t('apiGateway.fields.modelGroups.mode.assistantHint', 'Assistant preset overrides request parameters.')} + {assistantModelLabel ? ` (${assistantModelLabel})` : ''} + + )} {isConfigured && ( @@ -819,6 +895,23 @@ const GroupContent = styled.div` gap: 12px; ` +const ModeRow = styled.div` + display: flex; + align-items: center; + justify-content: space-between; + gap: 12px; +` + +const ModeLabel = styled.div` + font-size: 12px; + color: var(--color-text-3); +` + +const ModeHint = styled.div` + font-size: 12px; + color: var(--color-text-3); +` + const GroupUrlSection = styled.div` display: flex; flex-direction: column; diff --git a/src/renderer/src/types/apiServer.ts b/src/renderer/src/types/apiServer.ts index 28df8b9cd8..a35a70a407 100644 --- a/src/renderer/src/types/apiServer.ts +++ b/src/renderer/src/types/apiServer.ts @@ -12,6 +12,8 @@ export type ModelGroup = { name: string // display name providerId: string modelId: string + mode?: 'model' | 'assistant' + assistantId?: string createdAt: number }