mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-14 06:07:23 +08:00
feat: Add an assistant preset feature that supports request parameter overriding and assistant selection.
This commit is contained in:
parent
8837542e1e
commit
426224c3f3
@ -7,13 +7,210 @@
|
||||
* - Model injection for simplified external app integration
|
||||
*/
|
||||
|
||||
import { buildFunctionCallToolName } from '@main/utils/mcp'
|
||||
import type { MCPTool } from '@types'
|
||||
import type { NextFunction, Request, Response } from 'express'
|
||||
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
import { reduxService } from '../../services/ReduxService'
|
||||
import { config } from '../config'
|
||||
import { mcpApiService } from '../services/mcp'
|
||||
import { getMCPServersFromRedux } from '../utils/mcp'
|
||||
|
||||
const logger = loggerService.withContext('GatewayMiddleware')
|
||||
|
||||
/**
 * Minimal shape of an assistant preset as read from the renderer's Redux state.
 * Only the fields this middleware consumes are modeled here.
 */
type AssistantConfig = {
  id: string
  name: string
  // System prompt text injected into outbound requests; undefined/empty means no prompt.
  prompt?: string
  // Explicitly selected model; buildAssistantModelId falls back to defaultModel when absent.
  model?: { id: string; provider: string }
  defaultModel?: { id: string; provider: string }
  // Request parameter overrides; each enable* flag gates its companion value.
  settings?: {
    streamOutput?: boolean
    enableTemperature?: boolean
    temperature?: number
    enableTopP?: boolean
    topP?: number
    enableMaxTokens?: boolean
    maxTokens?: number
  }
  // MCP servers whose tools should be exposed to the model.
  mcpServers?: Array<{ id: string }>
  // When present, restricts exposed tools to this whitelist of qualified tool names.
  allowed_tools?: string[]
}
|
||||
|
||||
/**
 * Normalized tool definition shared by the per-format tool mappers below.
 * `inputSchema` reuses the MCP tool JSON-schema shape.
 */
type ToolDefinition = {
  name: string
  description?: string
  inputSchema: MCPTool['inputSchema']
}
|
||||
|
||||
const getEndpointFormat = (endpoint: string): 'openai' | 'anthropic' | 'responses' | null => {
|
||||
if (endpoint.startsWith('/v1/chat/completions')) return 'openai'
|
||||
if (endpoint.startsWith('/v1/messages')) return 'anthropic'
|
||||
if (endpoint.startsWith('/v1/responses')) return 'responses'
|
||||
return null
|
||||
}
|
||||
|
||||
const buildAssistantModelId = (assistant: AssistantConfig): string | null => {
|
||||
const model = assistant.model ?? assistant.defaultModel
|
||||
if (!model?.provider || !model?.id) {
|
||||
return null
|
||||
}
|
||||
return `${model.provider}:${model.id}`
|
||||
}
|
||||
|
||||
const applyAssistantMessageOverrides = (
|
||||
body: Record<string, any>,
|
||||
assistant: AssistantConfig,
|
||||
format: 'openai' | 'anthropic' | 'responses'
|
||||
): Record<string, any> => {
|
||||
const nextBody = { ...body }
|
||||
const prompt = assistant.prompt ?? ''
|
||||
|
||||
if (format === 'openai') {
|
||||
const messages = Array.isArray(nextBody.messages) ? nextBody.messages : []
|
||||
const filtered = messages.filter((message) => message?.role !== 'system' && message?.role !== 'developer')
|
||||
if (prompt.trim().length > 0) {
|
||||
filtered.unshift({ role: 'system', content: prompt })
|
||||
}
|
||||
nextBody.messages = filtered
|
||||
} else if (format === 'responses') {
|
||||
nextBody.instructions = prompt
|
||||
} else {
|
||||
nextBody.system = prompt
|
||||
}
|
||||
|
||||
return nextBody
|
||||
}
|
||||
|
||||
const applyAssistantParameterOverrides = (
|
||||
body: Record<string, any>,
|
||||
assistant: AssistantConfig,
|
||||
format: 'openai' | 'anthropic' | 'responses'
|
||||
): Record<string, any> => {
|
||||
const nextBody = { ...body }
|
||||
const settings = assistant.settings ?? {}
|
||||
|
||||
if (typeof settings.streamOutput === 'boolean') {
|
||||
nextBody.stream = settings.streamOutput
|
||||
}
|
||||
|
||||
if (settings.enableTemperature && typeof settings.temperature === 'number') {
|
||||
nextBody.temperature = settings.temperature
|
||||
} else if ('temperature' in nextBody) {
|
||||
delete nextBody.temperature
|
||||
}
|
||||
|
||||
if (settings.enableTopP && typeof settings.topP === 'number') {
|
||||
nextBody.top_p = settings.topP
|
||||
} else if ('top_p' in nextBody) {
|
||||
delete nextBody.top_p
|
||||
}
|
||||
|
||||
if (settings.enableMaxTokens && typeof settings.maxTokens === 'number') {
|
||||
if (format === 'responses') {
|
||||
nextBody.max_output_tokens = settings.maxTokens
|
||||
delete nextBody.max_tokens
|
||||
} else {
|
||||
nextBody.max_tokens = settings.maxTokens
|
||||
if ('max_output_tokens' in nextBody) {
|
||||
delete nextBody.max_output_tokens
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if ('max_tokens' in nextBody) {
|
||||
delete nextBody.max_tokens
|
||||
}
|
||||
if ('max_output_tokens' in nextBody) {
|
||||
delete nextBody.max_output_tokens
|
||||
}
|
||||
}
|
||||
|
||||
delete nextBody.tool_choice
|
||||
|
||||
return nextBody
|
||||
}
|
||||
|
||||
const mapToolsForOpenAI = (tools: ToolDefinition[]) =>
|
||||
tools.map((toolDef) => ({
|
||||
type: 'function',
|
||||
function: {
|
||||
name: toolDef.name,
|
||||
description: toolDef.description || '',
|
||||
parameters: toolDef.inputSchema
|
||||
}
|
||||
}))
|
||||
|
||||
const mapToolsForResponses = (tools: ToolDefinition[]) =>
|
||||
tools.map((toolDef) => ({
|
||||
type: 'function',
|
||||
name: toolDef.name,
|
||||
description: toolDef.description || '',
|
||||
parameters: toolDef.inputSchema
|
||||
}))
|
||||
|
||||
const mapToolsForAnthropic = (tools: ToolDefinition[]) =>
|
||||
tools.map((toolDef) => ({
|
||||
name: toolDef.name,
|
||||
description: toolDef.description || '',
|
||||
input_schema: toolDef.inputSchema
|
||||
}))
|
||||
|
||||
const buildAssistantTools = async (assistant: AssistantConfig): Promise<ToolDefinition[]> => {
|
||||
const serverIds = assistant.mcpServers?.map((server) => server.id).filter(Boolean) ?? []
|
||||
if (serverIds.length === 0) {
|
||||
return []
|
||||
}
|
||||
|
||||
const allowedTools = Array.isArray(assistant.allowed_tools) ? new Set(assistant.allowed_tools) : null
|
||||
const servers = await getMCPServersFromRedux()
|
||||
const tools: ToolDefinition[] = []
|
||||
|
||||
for (const serverId of serverIds) {
|
||||
const server = servers.find((entry) => entry.id === serverId)
|
||||
if (!server || !server.isActive) {
|
||||
continue
|
||||
}
|
||||
|
||||
const info = await mcpApiService.getServerInfo(serverId)
|
||||
if (!info?.tools || !Array.isArray(info.tools)) {
|
||||
continue
|
||||
}
|
||||
|
||||
for (const tool of info.tools as Array<{
|
||||
name: string
|
||||
description?: string
|
||||
inputSchema?: MCPTool['inputSchema']
|
||||
}>) {
|
||||
if (!tool?.name || !tool.inputSchema) {
|
||||
continue
|
||||
}
|
||||
|
||||
if (server.disabledTools?.includes(tool.name)) {
|
||||
continue
|
||||
}
|
||||
|
||||
const toolName = buildFunctionCallToolName(info.name, tool.name)
|
||||
if (allowedTools && !allowedTools.has(toolName)) {
|
||||
continue
|
||||
}
|
||||
|
||||
tools.push({
|
||||
name: toolName,
|
||||
description: tool.description,
|
||||
inputSchema: tool.inputSchema
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return tools
|
||||
}
|
||||
|
||||
const resolveAssistantById = async (assistantId: string): Promise<AssistantConfig | null> => {
|
||||
const assistants = (await reduxService.select('state.assistants.assistants')) as AssistantConfig[] | null
|
||||
return assistants?.find((assistant) => assistant.id === assistantId) ?? null
|
||||
}
|
||||
|
||||
/**
|
||||
* Gateway middleware for model group routing
|
||||
*
|
||||
@ -43,16 +240,79 @@ export const gatewayMiddleware = async (req: Request, res: Response, next: NextF
|
||||
return
|
||||
}
|
||||
|
||||
// Inject the group's model into the request
|
||||
req.body = {
|
||||
...req.body,
|
||||
model: `${group.providerId}:${group.modelId}`
|
||||
}
|
||||
const endpoint = req.path.startsWith('/') ? req.path : `/${req.path}`
|
||||
const endpointFormat = getEndpointFormat(endpoint)
|
||||
|
||||
logger.debug('Injected model from group', {
|
||||
groupName,
|
||||
model: `${group.providerId}:${group.modelId}`
|
||||
})
|
||||
if (group.mode === 'assistant' && group.assistantId) {
|
||||
if (!endpointFormat) {
|
||||
res.status(400).json({
|
||||
error: {
|
||||
type: 'invalid_request_error',
|
||||
message: `Unsupported endpoint ${endpoint}`
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
const assistant = await resolveAssistantById(group.assistantId)
|
||||
if (!assistant) {
|
||||
res.status(400).json({
|
||||
error: {
|
||||
type: 'invalid_request_error',
|
||||
message: `Assistant '${group.assistantId}' not found`
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
const modelId = buildAssistantModelId(assistant)
|
||||
if (!modelId) {
|
||||
res.status(400).json({
|
||||
error: {
|
||||
type: 'invalid_request_error',
|
||||
message: `Assistant '${group.assistantId}' is missing a model`
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
let nextBody = {
|
||||
...req.body,
|
||||
model: modelId
|
||||
}
|
||||
|
||||
nextBody = applyAssistantMessageOverrides(nextBody, assistant, endpointFormat)
|
||||
nextBody = applyAssistantParameterOverrides(nextBody, assistant, endpointFormat)
|
||||
|
||||
const tools = await buildAssistantTools(assistant)
|
||||
if (endpointFormat === 'openai') {
|
||||
nextBody.tools = tools.length > 0 ? mapToolsForOpenAI(tools) : undefined
|
||||
} else if (endpointFormat === 'responses') {
|
||||
nextBody.tools = tools.length > 0 ? mapToolsForResponses(tools) : undefined
|
||||
} else {
|
||||
nextBody.tools = tools.length > 0 ? mapToolsForAnthropic(tools) : undefined
|
||||
}
|
||||
|
||||
req.body = nextBody
|
||||
|
||||
logger.debug('Injected assistant preset into request', {
|
||||
groupName,
|
||||
assistantId: assistant.id,
|
||||
model: modelId,
|
||||
toolCount: tools.length
|
||||
})
|
||||
} else {
|
||||
// Inject the group's model into the request
|
||||
req.body = {
|
||||
...req.body,
|
||||
model: `${group.providerId}:${group.modelId}`
|
||||
}
|
||||
|
||||
logger.debug('Injected model from group', {
|
||||
groupName,
|
||||
model: `${group.providerId}:${group.modelId}`
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Get the endpoint path (for group routes, use the part after groupName)
|
||||
|
||||
@ -353,6 +353,13 @@
|
||||
"description": "Create model groups with unique URLs for different provider/model combinations",
|
||||
"empty": "No model groups configured. Click 'Add Group' to create one.",
|
||||
"label": "Model Groups",
|
||||
"mode": {
|
||||
"assistant": "Assistant Preset",
|
||||
"assistantHint": "Assistant preset overrides request parameters.",
|
||||
"assistantPlaceholder": "Select assistant",
|
||||
"label": "Mode",
|
||||
"model": "Direct Model"
|
||||
},
|
||||
"namePlaceholder": "Group name"
|
||||
},
|
||||
"networkAccess": {
|
||||
|
||||
@ -353,6 +353,13 @@
|
||||
"description": "创建具有唯一 URL 的模型分组,用于不同的提供商/模型组合",
|
||||
"empty": "尚未配置模型分组。点击「添加分组」创建一个。",
|
||||
"label": "模型分组",
|
||||
"mode": {
|
||||
"assistant": "助手预设",
|
||||
"assistantHint": "助手预设会覆盖请求参数。",
|
||||
"assistantPlaceholder": "选择助手",
|
||||
"label": "模式",
|
||||
"model": "直接模型"
|
||||
},
|
||||
"namePlaceholder": "分组名称"
|
||||
},
|
||||
"networkAccess": {
|
||||
|
||||
@ -46,6 +46,7 @@ const ApiServerSettings: FC = () => {
|
||||
|
||||
// API Gateway state with proper defaults
|
||||
const apiServerConfig = useSelector((state: RootState) => state.settings.apiServer)
|
||||
const assistants = useSelector((state: RootState) => state.assistants.assistants)
|
||||
const { apiServerRunning, apiServerLoading, startApiServer, stopApiServer, restartApiServer, setApiServerEnabled } =
|
||||
useApiServer()
|
||||
|
||||
@ -108,6 +109,8 @@ const ApiServerSettings: FC = () => {
|
||||
name: `group-${apiServerConfig.modelGroups.length + 1}`, // URL-safe name
|
||||
providerId: '',
|
||||
modelId: '',
|
||||
mode: 'model',
|
||||
assistantId: '',
|
||||
createdAt: Date.now()
|
||||
}
|
||||
dispatch(addApiGatewayModelGroup(newGroup))
|
||||
@ -266,7 +269,13 @@ const ApiServerSettings: FC = () => {
|
||||
) : (
|
||||
<ModelGroupList>
|
||||
{apiServerConfig.modelGroups.map((group) => (
|
||||
<ModelGroupCard key={group.id} group={group} onUpdate={updateModelGroup} onDelete={deleteModelGroup} />
|
||||
<ModelGroupCard
|
||||
key={group.id}
|
||||
group={group}
|
||||
assistants={assistants}
|
||||
onUpdate={updateModelGroup}
|
||||
onDelete={deleteModelGroup}
|
||||
/>
|
||||
))}
|
||||
</ModelGroupList>
|
||||
)}
|
||||
@ -295,6 +304,7 @@ const ApiServerSettings: FC = () => {
|
||||
// Model Group Card Component
|
||||
interface ModelGroupCardProps {
|
||||
group: ModelGroup
|
||||
assistants: RootState['assistants']['assistants']
|
||||
onUpdate: (group: ModelGroup) => void
|
||||
onDelete: (groupId: string) => void
|
||||
}
|
||||
@ -305,11 +315,12 @@ const ENV_FORMAT_TO_ENDPOINT: Record<EnvFormat, GatewayEndpoint> = {
|
||||
responses: '/v1/responses'
|
||||
}
|
||||
|
||||
const ModelGroupCard: FC<ModelGroupCardProps> = ({ group, onUpdate, onDelete }) => {
|
||||
const ModelGroupCard: FC<ModelGroupCardProps> = ({ group, assistants, onUpdate, onDelete }) => {
|
||||
const { t } = useTranslation()
|
||||
const { providers } = useProviders()
|
||||
const apiServerConfig = useSelector((state: RootState) => state.settings.apiServer)
|
||||
const [envFormat, setEnvFormat] = useState<EnvFormat>('openai')
|
||||
const mode = group.mode ?? 'model'
|
||||
|
||||
// Reset envFormat when selected endpoint is disabled
|
||||
useEffect(() => {
|
||||
@ -398,7 +409,36 @@ const ModelGroupCard: FC<ModelGroupCardProps> = ({ group, onUpdate, onDelete })
|
||||
})
|
||||
}
|
||||
|
||||
const isConfigured = group.providerId && group.modelId
|
||||
const handleModeChange = (nextMode: 'model' | 'assistant') => {
|
||||
if (nextMode === mode) return
|
||||
onUpdate({
|
||||
...group,
|
||||
mode: nextMode,
|
||||
assistantId: nextMode === 'assistant' ? group.assistantId || '' : '',
|
||||
providerId: nextMode === 'model' ? group.providerId : '',
|
||||
modelId: nextMode === 'model' ? group.modelId : ''
|
||||
})
|
||||
}
|
||||
|
||||
const handleAssistantChange = (assistantId: string | null) => {
|
||||
onUpdate({
|
||||
...group,
|
||||
assistantId: assistantId || ''
|
||||
})
|
||||
}
|
||||
|
||||
const selectedAssistant = useMemo(() => {
|
||||
if (!group.assistantId) return undefined
|
||||
return assistants.find((assistant) => assistant.id === group.assistantId)
|
||||
}, [assistants, group.assistantId])
|
||||
|
||||
const assistantModelLabel = useMemo(() => {
|
||||
const model = selectedAssistant?.model ?? selectedAssistant?.defaultModel
|
||||
if (!model) return undefined
|
||||
return model.name || model.id
|
||||
}, [selectedAssistant])
|
||||
|
||||
const isConfigured = mode === 'assistant' ? !!group.assistantId : !!(group.providerId && group.modelId)
|
||||
|
||||
return (
|
||||
<GroupCard $configured={!!isConfigured}>
|
||||
@ -427,35 +467,71 @@ const ModelGroupCard: FC<ModelGroupCardProps> = ({ group, onUpdate, onDelete })
|
||||
</GroupHeader>
|
||||
|
||||
<GroupContent>
|
||||
<ModeRow>
|
||||
<ModeLabel>{t('apiGateway.fields.modelGroups.mode.label', 'Mode')}</ModeLabel>
|
||||
<Segmented
|
||||
size="small"
|
||||
value={mode}
|
||||
onChange={(value) => handleModeChange(value as 'model' | 'assistant')}
|
||||
options={[
|
||||
{ label: t('apiGateway.fields.modelGroups.mode.model', 'Direct Model'), value: 'model' },
|
||||
{ label: t('apiGateway.fields.modelGroups.mode.assistant', 'Assistant Preset'), value: 'assistant' }
|
||||
]}
|
||||
/>
|
||||
</ModeRow>
|
||||
<SelectRow>
|
||||
<StyledSelect
|
||||
value={group.providerId || undefined}
|
||||
onChange={(value) => handleProviderChange((value as string) || null)}
|
||||
placeholder={t('apiGateway.fields.defaultModel.providerPlaceholder')}
|
||||
allowClear
|
||||
showSearch
|
||||
optionFilterProp="label"
|
||||
style={{ flex: 1 }}
|
||||
options={providers.map((p) => ({
|
||||
value: p.id,
|
||||
label: p.isSystem ? getProviderLabel(p.id) : p.name
|
||||
}))}
|
||||
/>
|
||||
<StyledSelect
|
||||
value={group.modelId || undefined}
|
||||
onChange={(value) => handleModelChange((value as string) || null)}
|
||||
placeholder={t('apiGateway.fields.defaultModel.modelPlaceholder')}
|
||||
allowClear
|
||||
showSearch
|
||||
optionFilterProp="label"
|
||||
disabled={!group.providerId}
|
||||
style={{ flex: 1 }}
|
||||
options={models.map((m) => ({
|
||||
value: m.id,
|
||||
label: m.name || m.id
|
||||
}))}
|
||||
/>
|
||||
{mode === 'assistant' ? (
|
||||
<StyledSelect
|
||||
value={group.assistantId || undefined}
|
||||
onChange={(value) => handleAssistantChange((value as string) || null)}
|
||||
placeholder={t('apiGateway.fields.modelGroups.mode.assistantPlaceholder', 'Select assistant')}
|
||||
allowClear
|
||||
showSearch
|
||||
optionFilterProp="label"
|
||||
style={{ flex: 1 }}
|
||||
options={assistants.map((assistant) => ({
|
||||
value: assistant.id,
|
||||
label: assistant.name
|
||||
}))}
|
||||
/>
|
||||
) : (
|
||||
<>
|
||||
<StyledSelect
|
||||
value={group.providerId || undefined}
|
||||
onChange={(value) => handleProviderChange((value as string) || null)}
|
||||
placeholder={t('apiGateway.fields.defaultModel.providerPlaceholder')}
|
||||
allowClear
|
||||
showSearch
|
||||
optionFilterProp="label"
|
||||
style={{ flex: 1 }}
|
||||
options={providers.map((p) => ({
|
||||
value: p.id,
|
||||
label: p.isSystem ? getProviderLabel(p.id) : p.name
|
||||
}))}
|
||||
/>
|
||||
<StyledSelect
|
||||
value={group.modelId || undefined}
|
||||
onChange={(value) => handleModelChange((value as string) || null)}
|
||||
placeholder={t('apiGateway.fields.defaultModel.modelPlaceholder')}
|
||||
allowClear
|
||||
showSearch
|
||||
optionFilterProp="label"
|
||||
disabled={!group.providerId}
|
||||
style={{ flex: 1 }}
|
||||
options={models.map((m) => ({
|
||||
value: m.id,
|
||||
label: m.name || m.id
|
||||
}))}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
</SelectRow>
|
||||
{mode === 'assistant' && (
|
||||
<ModeHint>
|
||||
{t('apiGateway.fields.modelGroups.mode.assistantHint', 'Assistant preset overrides request parameters.')}
|
||||
{assistantModelLabel ? ` (${assistantModelLabel})` : ''}
|
||||
</ModeHint>
|
||||
)}
|
||||
|
||||
{isConfigured && (
|
||||
<GroupUrlSection>
|
||||
@ -819,6 +895,23 @@ const GroupContent = styled.div`
|
||||
gap: 12px;
|
||||
`
|
||||
|
||||
const ModeRow = styled.div`
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
gap: 12px;
|
||||
`
|
||||
|
||||
const ModeLabel = styled.div`
|
||||
font-size: 12px;
|
||||
color: var(--color-text-3);
|
||||
`
|
||||
|
||||
const ModeHint = styled.div`
|
||||
font-size: 12px;
|
||||
color: var(--color-text-3);
|
||||
`
|
||||
|
||||
const GroupUrlSection = styled.div`
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
|
||||
@ -12,6 +12,8 @@ export type ModelGroup = {
|
||||
name: string // display name
|
||||
providerId: string
|
||||
modelId: string
|
||||
mode?: 'model' | 'assistant'
|
||||
assistantId?: string
|
||||
createdAt: number
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user