feat: Add an assistant preset mode to API gateway model groups, supporting assistant selection and request parameter overriding.

This commit is contained in:
suyao 2026-01-07 13:38:20 +08:00
parent 8837542e1e
commit 426224c3f3
5 changed files with 408 additions and 39 deletions

View File

@@ -7,13 +7,210 @@
* - Model injection for simplified external app integration
*/
import { buildFunctionCallToolName } from '@main/utils/mcp'
import type { MCPTool } from '@types'
import type { NextFunction, Request, Response } from 'express'
import { loggerService } from '../../services/LoggerService'
import { reduxService } from '../../services/ReduxService'
import { config } from '../config'
import { mcpApiService } from '../services/mcp'
import { getMCPServersFromRedux } from '../utils/mcp'
const logger = loggerService.withContext('GatewayMiddleware')
type AssistantConfig = {
id: string
name: string
prompt?: string
model?: { id: string; provider: string }
defaultModel?: { id: string; provider: string }
settings?: {
streamOutput?: boolean
enableTemperature?: boolean
temperature?: number
enableTopP?: boolean
topP?: number
enableMaxTokens?: boolean
maxTokens?: number
}
mcpServers?: Array<{ id: string }>
allowed_tools?: string[]
}
type ToolDefinition = {
name: string
description?: string
inputSchema: MCPTool['inputSchema']
}
const getEndpointFormat = (endpoint: string): 'openai' | 'anthropic' | 'responses' | null => {
if (endpoint.startsWith('/v1/chat/completions')) return 'openai'
if (endpoint.startsWith('/v1/messages')) return 'anthropic'
if (endpoint.startsWith('/v1/responses')) return 'responses'
return null
}
const buildAssistantModelId = (assistant: AssistantConfig): string | null => {
const model = assistant.model ?? assistant.defaultModel
if (!model?.provider || !model?.id) {
return null
}
return `${model.provider}:${model.id}`
}
const applyAssistantMessageOverrides = (
body: Record<string, any>,
assistant: AssistantConfig,
format: 'openai' | 'anthropic' | 'responses'
): Record<string, any> => {
const nextBody = { ...body }
const prompt = assistant.prompt ?? ''
if (format === 'openai') {
const messages = Array.isArray(nextBody.messages) ? nextBody.messages : []
const filtered = messages.filter((message) => message?.role !== 'system' && message?.role !== 'developer')
if (prompt.trim().length > 0) {
filtered.unshift({ role: 'system', content: prompt })
}
nextBody.messages = filtered
} else if (format === 'responses') {
nextBody.instructions = prompt
} else {
nextBody.system = prompt
}
return nextBody
}
const applyAssistantParameterOverrides = (
body: Record<string, any>,
assistant: AssistantConfig,
format: 'openai' | 'anthropic' | 'responses'
): Record<string, any> => {
const nextBody = { ...body }
const settings = assistant.settings ?? {}
if (typeof settings.streamOutput === 'boolean') {
nextBody.stream = settings.streamOutput
}
if (settings.enableTemperature && typeof settings.temperature === 'number') {
nextBody.temperature = settings.temperature
} else if ('temperature' in nextBody) {
delete nextBody.temperature
}
if (settings.enableTopP && typeof settings.topP === 'number') {
nextBody.top_p = settings.topP
} else if ('top_p' in nextBody) {
delete nextBody.top_p
}
if (settings.enableMaxTokens && typeof settings.maxTokens === 'number') {
if (format === 'responses') {
nextBody.max_output_tokens = settings.maxTokens
delete nextBody.max_tokens
} else {
nextBody.max_tokens = settings.maxTokens
if ('max_output_tokens' in nextBody) {
delete nextBody.max_output_tokens
}
}
} else {
if ('max_tokens' in nextBody) {
delete nextBody.max_tokens
}
if ('max_output_tokens' in nextBody) {
delete nextBody.max_output_tokens
}
}
delete nextBody.tool_choice
return nextBody
}
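To make the precedence concrete, a minimal sketch (the request body and assistant settings below are hypothetical): enabled settings replace whatever the client sent, disabled settings strip the field entirely, and tool_choice is always dropped.
const clientBody = { model: 'gpt-4o', temperature: 0.9, top_p: 0.5, max_tokens: 2048, stream: false, tool_choice: 'auto' }
const settings = { streamOutput: true, enableTemperature: true, temperature: 0.2, enableTopP: false, enableMaxTokens: false }
// applyAssistantParameterOverrides would yield, for the 'openai' format:
// { model: 'gpt-4o', stream: true, temperature: 0.2 }
//   - stream mirrors streamOutput
//   - temperature is replaced because enableTemperature is true
//   - top_p and max_tokens are deleted because their toggles are off
//   - tool_choice is always removed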
const mapToolsForOpenAI = (tools: ToolDefinition[]) =>
tools.map((toolDef) => ({
type: 'function',
function: {
name: toolDef.name,
description: toolDef.description || '',
parameters: toolDef.inputSchema
}
}))
const mapToolsForResponses = (tools: ToolDefinition[]) =>
tools.map((toolDef) => ({
type: 'function',
name: toolDef.name,
description: toolDef.description || '',
parameters: toolDef.inputSchema
}))
const mapToolsForAnthropic = (tools: ToolDefinition[]) =>
tools.map((toolDef) => ({
name: toolDef.name,
description: toolDef.description || '',
input_schema: toolDef.inputSchema
}))
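The three mappers emit the same ToolDefinition in the shape each endpoint family expects; a quick illustration (the tool name and schema below are made up):
const exampleTool = {
  name: 'filesystem__read_file',
  description: 'Read a file',
  inputSchema: { type: 'object', properties: { path: { type: 'string' } }, required: ['path'] }
}
// mapToolsForOpenAI([exampleTool])    -> [{ type: 'function', function: { name, description, parameters: inputSchema } }]
// mapToolsForResponses([exampleTool]) -> [{ type: 'function', name, description, parameters: inputSchema }]  (flattened)
// mapToolsForAnthropic([exampleTool]) -> [{ name, description, input_schema: inputSchema }]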
const buildAssistantTools = async (assistant: AssistantConfig): Promise<ToolDefinition[]> => {
const serverIds = assistant.mcpServers?.map((server) => server.id).filter(Boolean) ?? []
if (serverIds.length === 0) {
return []
}
const allowedTools = Array.isArray(assistant.allowed_tools) ? new Set(assistant.allowed_tools) : null
const servers = await getMCPServersFromRedux()
const tools: ToolDefinition[] = []
for (const serverId of serverIds) {
const server = servers.find((entry) => entry.id === serverId)
if (!server || !server.isActive) {
continue
}
const info = await mcpApiService.getServerInfo(serverId)
if (!info?.tools || !Array.isArray(info.tools)) {
continue
}
for (const tool of info.tools as Array<{
name: string
description?: string
inputSchema?: MCPTool['inputSchema']
}>) {
if (!tool?.name || !tool.inputSchema) {
continue
}
if (server.disabledTools?.includes(tool.name)) {
continue
}
const toolName = buildFunctionCallToolName(info.name, tool.name)
if (allowedTools && !allowedTools.has(toolName)) {
continue
}
tools.push({
name: toolName,
description: tool.description,
inputSchema: tool.inputSchema
})
}
}
return tools
}
const resolveAssistantById = async (assistantId: string): Promise<AssistantConfig | null> => {
const assistants = (await reduxService.select('state.assistants.assistants')) as AssistantConfig[] | null
return assistants?.find((assistant) => assistant.id === assistantId) ?? null
}
/**
* Gateway middleware for model group routing
*
@@ -43,16 +240,79 @@ export const gatewayMiddleware = async (req: Request, res: Response, next: NextF
return
}
// Inject the group's model into the request
req.body = {
...req.body,
model: `${group.providerId}:${group.modelId}`
}
const endpoint = req.path.startsWith('/') ? req.path : `/${req.path}`
const endpointFormat = getEndpointFormat(endpoint)
logger.debug('Injected model from group', {
groupName,
model: `${group.providerId}:${group.modelId}`
})
if (group.mode === 'assistant' && group.assistantId) {
if (!endpointFormat) {
res.status(400).json({
error: {
type: 'invalid_request_error',
message: `Unsupported endpoint ${endpoint}`
}
})
return
}
const assistant = await resolveAssistantById(group.assistantId)
if (!assistant) {
res.status(400).json({
error: {
type: 'invalid_request_error',
message: `Assistant '${group.assistantId}' not found`
}
})
return
}
const modelId = buildAssistantModelId(assistant)
if (!modelId) {
res.status(400).json({
error: {
type: 'invalid_request_error',
message: `Assistant '${group.assistantId}' is missing a model`
}
})
return
}
let nextBody = {
...req.body,
model: modelId
}
nextBody = applyAssistantMessageOverrides(nextBody, assistant, endpointFormat)
nextBody = applyAssistantParameterOverrides(nextBody, assistant, endpointFormat)
const tools = await buildAssistantTools(assistant)
if (endpointFormat === 'openai') {
nextBody.tools = tools.length > 0 ? mapToolsForOpenAI(tools) : undefined
} else if (endpointFormat === 'responses') {
nextBody.tools = tools.length > 0 ? mapToolsForResponses(tools) : undefined
} else {
nextBody.tools = tools.length > 0 ? mapToolsForAnthropic(tools) : undefined
}
req.body = nextBody
logger.debug('Injected assistant preset into request', {
groupName,
assistantId: assistant.id,
model: modelId,
toolCount: tools.length
})
} else {
// Inject the group's model into the request
req.body = {
...req.body,
model: `${group.providerId}:${group.modelId}`
}
logger.debug('Injected model from group', {
groupName,
model: `${group.providerId}:${group.modelId}`
})
}
}
// Get the endpoint path (for group routes, use the part after groupName)
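Net effect for a group in assistant mode, sketched for the OpenAI-compatible endpoint (the assistant, its prompt, and the tool name below are illustrative, not taken from the diff):
// Client sends this to the group's /v1/chat/completions route:
const incoming = {
  model: 'anything',
  messages: [
    { role: 'system', content: 'You are a generic bot.' },
    { role: 'user', content: 'Hello' }
  ],
  temperature: 1.0
}
// Assuming the selected assistant uses provider 'openai' and model 'gpt-4o-mini',
// has its own prompt, has temperature disabled, and exposes one MCP tool,
// req.body is rewritten to roughly:
const rewritten = {
  model: 'openai:gpt-4o-mini', // buildAssistantModelId: `${provider}:${id}`
  messages: [
    { role: 'system', content: '<assistant prompt>' }, // client system/developer messages are dropped
    { role: 'user', content: 'Hello' }
  ],
  // temperature removed: enableTemperature is false for this assistant
  tools: [
    {
      type: 'function',
      function: { name: 'filesystem__read_file', description: 'Read a file', parameters: { type: 'object' } }
    }
  ]
}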

View File

@@ -353,6 +353,13 @@
"description": "Create model groups with unique URLs for different provider/model combinations",
"empty": "No model groups configured. Click 'Add Group' to create one.",
"label": "Model Groups",
"mode": {
"assistant": "Assistant Preset",
"assistantHint": "Assistant preset overrides request parameters.",
"assistantPlaceholder": "Select assistant",
"label": "Mode",
"model": "Direct Model"
},
"namePlaceholder": "Group name"
},
"networkAccess": {

View File

@@ -353,6 +353,13 @@
"description": "创建具有唯一 URL 的模型分组,用于不同的提供商/模型组合",
"empty": "尚未配置模型分组。点击「添加分组」创建一个。",
"label": "模型分组",
"mode": {
"assistant": "助手预设",
"assistantHint": "助手预设会覆盖请求参数。",
"assistantPlaceholder": "选择助手",
"label": "模式",
"model": "直接模型"
},
"namePlaceholder": "分组名称"
},
"networkAccess": {

View File

@@ -46,6 +46,7 @@ const ApiServerSettings: FC = () => {
// API Gateway state with proper defaults
const apiServerConfig = useSelector((state: RootState) => state.settings.apiServer)
const assistants = useSelector((state: RootState) => state.assistants.assistants)
const { apiServerRunning, apiServerLoading, startApiServer, stopApiServer, restartApiServer, setApiServerEnabled } =
useApiServer()
@@ -108,6 +109,8 @@ const ApiServerSettings: FC = () => {
name: `group-${apiServerConfig.modelGroups.length + 1}`, // URL-safe name
providerId: '',
modelId: '',
mode: 'model',
assistantId: '',
createdAt: Date.now()
}
dispatch(addApiGatewayModelGroup(newGroup))
@@ -266,7 +269,13 @@ const ApiServerSettings: FC = () => {
) : (
<ModelGroupList>
{apiServerConfig.modelGroups.map((group) => (
<ModelGroupCard key={group.id} group={group} onUpdate={updateModelGroup} onDelete={deleteModelGroup} />
<ModelGroupCard
key={group.id}
group={group}
assistants={assistants}
onUpdate={updateModelGroup}
onDelete={deleteModelGroup}
/>
))}
</ModelGroupList>
)}
@@ -295,6 +304,7 @@ const ApiServerSettings: FC = () => {
// Model Group Card Component
interface ModelGroupCardProps {
group: ModelGroup
assistants: RootState['assistants']['assistants']
onUpdate: (group: ModelGroup) => void
onDelete: (groupId: string) => void
}
@@ -305,11 +315,12 @@ const ENV_FORMAT_TO_ENDPOINT: Record<EnvFormat, GatewayEndpoint> = {
responses: '/v1/responses'
}
const ModelGroupCard: FC<ModelGroupCardProps> = ({ group, onUpdate, onDelete }) => {
const ModelGroupCard: FC<ModelGroupCardProps> = ({ group, assistants, onUpdate, onDelete }) => {
const { t } = useTranslation()
const { providers } = useProviders()
const apiServerConfig = useSelector((state: RootState) => state.settings.apiServer)
const [envFormat, setEnvFormat] = useState<EnvFormat>('openai')
const mode = group.mode ?? 'model'
// Reset envFormat when selected endpoint is disabled
useEffect(() => {
@@ -398,7 +409,36 @@ const ModelGroupCard: FC<ModelGroupCardProps> = ({ group, onUpdate, onDelete })
})
}
const isConfigured = group.providerId && group.modelId
const handleModeChange = (nextMode: 'model' | 'assistant') => {
if (nextMode === mode) return
onUpdate({
...group,
mode: nextMode,
assistantId: nextMode === 'assistant' ? group.assistantId || '' : '',
providerId: nextMode === 'model' ? group.providerId : '',
modelId: nextMode === 'model' ? group.modelId : ''
})
}
const handleAssistantChange = (assistantId: string | null) => {
onUpdate({
...group,
assistantId: assistantId || ''
})
}
const selectedAssistant = useMemo(() => {
if (!group.assistantId) return undefined
return assistants.find((assistant) => assistant.id === group.assistantId)
}, [assistants, group.assistantId])
const assistantModelLabel = useMemo(() => {
const model = selectedAssistant?.model ?? selectedAssistant?.defaultModel
if (!model) return undefined
return model.name || model.id
}, [selectedAssistant])
const isConfigured = mode === 'assistant' ? !!group.assistantId : !!(group.providerId && group.modelId)
return (
<GroupCard $configured={!!isConfigured}>
@@ -427,35 +467,71 @@ const ModelGroupCard: FC<ModelGroupCardProps> = ({ group, onUpdate, onDelete })
</GroupHeader>
<GroupContent>
<ModeRow>
<ModeLabel>{t('apiGateway.fields.modelGroups.mode.label', 'Mode')}</ModeLabel>
<Segmented
size="small"
value={mode}
onChange={(value) => handleModeChange(value as 'model' | 'assistant')}
options={[
{ label: t('apiGateway.fields.modelGroups.mode.model', 'Direct Model'), value: 'model' },
{ label: t('apiGateway.fields.modelGroups.mode.assistant', 'Assistant Preset'), value: 'assistant' }
]}
/>
</ModeRow>
<SelectRow>
<StyledSelect
value={group.providerId || undefined}
onChange={(value) => handleProviderChange((value as string) || null)}
placeholder={t('apiGateway.fields.defaultModel.providerPlaceholder')}
allowClear
showSearch
optionFilterProp="label"
style={{ flex: 1 }}
options={providers.map((p) => ({
value: p.id,
label: p.isSystem ? getProviderLabel(p.id) : p.name
}))}
/>
<StyledSelect
value={group.modelId || undefined}
onChange={(value) => handleModelChange((value as string) || null)}
placeholder={t('apiGateway.fields.defaultModel.modelPlaceholder')}
allowClear
showSearch
optionFilterProp="label"
disabled={!group.providerId}
style={{ flex: 1 }}
options={models.map((m) => ({
value: m.id,
label: m.name || m.id
}))}
/>
{mode === 'assistant' ? (
<StyledSelect
value={group.assistantId || undefined}
onChange={(value) => handleAssistantChange((value as string) || null)}
placeholder={t('apiGateway.fields.modelGroups.mode.assistantPlaceholder', 'Select assistant')}
allowClear
showSearch
optionFilterProp="label"
style={{ flex: 1 }}
options={assistants.map((assistant) => ({
value: assistant.id,
label: assistant.name
}))}
/>
) : (
<>
<StyledSelect
value={group.providerId || undefined}
onChange={(value) => handleProviderChange((value as string) || null)}
placeholder={t('apiGateway.fields.defaultModel.providerPlaceholder')}
allowClear
showSearch
optionFilterProp="label"
style={{ flex: 1 }}
options={providers.map((p) => ({
value: p.id,
label: p.isSystem ? getProviderLabel(p.id) : p.name
}))}
/>
<StyledSelect
value={group.modelId || undefined}
onChange={(value) => handleModelChange((value as string) || null)}
placeholder={t('apiGateway.fields.defaultModel.modelPlaceholder')}
allowClear
showSearch
optionFilterProp="label"
disabled={!group.providerId}
style={{ flex: 1 }}
options={models.map((m) => ({
value: m.id,
label: m.name || m.id
}))}
/>
</>
)}
</SelectRow>
{mode === 'assistant' && (
<ModeHint>
{t('apiGateway.fields.modelGroups.mode.assistantHint', 'Assistant preset overrides the request model, prompt, and parameters.')}
{assistantModelLabel ? ` (${assistantModelLabel})` : ''}
</ModeHint>
)}
{isConfigured && (
<GroupUrlSection>
@@ -819,6 +895,23 @@ const GroupContent = styled.div`
gap: 12px;
`
const ModeRow = styled.div`
display: flex;
align-items: center;
justify-content: space-between;
gap: 12px;
`
const ModeLabel = styled.div`
font-size: 12px;
color: var(--color-text-3);
`
const ModeHint = styled.div`
font-size: 12px;
color: var(--color-text-3);
`
const GroupUrlSection = styled.div`
display: flex;
flex-direction: column;

View File

@@ -12,6 +12,8 @@ export type ModelGroup = {
name: string // display name
providerId: string
modelId: string
mode?: 'model' | 'assistant'
assistantId?: string
createdAt: number
}
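For reference, a sketch of the two variants a ModelGroup can now take (values are illustrative): 'model' mode routes via providerId/modelId, while 'assistant' mode routes via assistantId and leaves the provider/model fields empty, mirroring how handleModeChange clears them.
const directGroup: ModelGroup = {
  id: 'group-direct',
  name: 'group-1',
  providerId: 'openai',
  modelId: 'gpt-4o',
  mode: 'model',
  assistantId: '',
  createdAt: Date.now()
}
const assistantGroup: ModelGroup = {
  id: 'group-assistant',
  name: 'group-2',
  providerId: '',
  modelId: '',
  mode: 'assistant',
  assistantId: 'assistant-123',
  createdAt: Date.now()
}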