From fca41ed966e77ee630c50e58c8a39e6362372eb4 Mon Sep 17 00:00:00 2001 From: suyao Date: Mon, 29 Dec 2025 14:48:24 +0800 Subject: [PATCH] refactor: remove obsolete middleware and README, add new plugins for reasoning and tool selection - Deleted README.md for AiSdkMiddlewareBuilder as it was outdated. - Removed toolChoiceMiddleware.ts as it is no longer needed. - Introduced new plugins: - noThinkPlugin: Appends '/no_think' to user messages to prevent unnecessary thinking. - openrouterGenerateImagePlugin: Configures OpenRouter for image and text modalities. - openrouterReasoningPlugin: Redacts reasoning blocks in OpenRouter responses. - qwenThinkingPlugin: Controls thinking mode for Qwen models based on provider support. - reasoningExtractionPlugin: Extracts reasoning tags from OpenAI/Azure responses. - simulateStreamingPlugin: Converts non-streaming responses to streaming format. - skipGeminiThoughtSignaturePlugin: Skips Gemini3 thought signatures for multi-model requests. - Updated parameterBuilder.ts to correct type definitions. - Added middlewareConfig.ts for better middleware configuration management. - Enhanced reasoning utility functions for better tag name retrieval. - Updated ApiService.ts and aiCoreTypes.ts for consistency with new changes. --- packages/aiCore/package.json | 1 + packages/aiCore/src/core/errors/index.ts | 124 ++++++++ packages/aiCore/src/core/models/index.ts | 3 + packages/aiCore/src/core/models/types.ts | 5 +- packages/aiCore/src/core/models/utils.ts | 27 ++ packages/aiCore/src/core/options/examples.ts | 87 ------ .../built-in/googleToolsPlugin/index.ts | 52 ++-- .../toolUsePlugin/StreamEventManager.ts | 139 ++++++--- .../toolUsePlugin/promptToolUsePlugin.ts | 24 +- packages/aiCore/src/core/plugins/index.ts | 42 ++- packages/aiCore/src/core/plugins/manager.ts | 74 +++-- packages/aiCore/src/core/plugins/types.ts | 107 +++++-- .../aiCore/src/core/providers/HubProvider.ts | 94 ++++-- .../src/core/providers/RegistryManagement.ts | 2 - packages/aiCore/src/core/providers/index.ts | 1 + packages/aiCore/src/core/providers/schemas.ts | 9 +- packages/aiCore/src/core/providers/types.ts | 21 +- packages/aiCore/src/core/runtime/executor.ts | 24 +- .../aiCore/src/core/runtime/pluginEngine.ts | 221 +++++++++----- packages/aiCore/src/core/types/branded.ts | 80 +++++ packages/aiCore/src/index.ts | 42 ++- src/renderer/src/aiCore/index_new.ts | 36 +-- .../middleware/AiSdkMiddlewareBuilder.ts | 286 ------------------ src/renderer/src/aiCore/middleware/README.md | 140 --------- .../aiCore/middleware/toolChoiceMiddleware.ts | 45 --- .../src/aiCore/plugins/PluginBuilder.ts | 68 ++++- .../noThinkPlugin.ts} | 18 +- .../openrouterGenerateImagePlugin.ts} | 19 +- .../openrouterReasoningPlugin.ts} | 24 +- .../qwenThinkingPlugin.ts} | 16 +- .../plugins/reasoningExtractionPlugin.ts | 22 ++ .../plugins/searchOrchestrationPlugin.ts | 24 +- .../aiCore/plugins/simulateStreamingPlugin.ts | 18 ++ .../skipGeminiThoughtSignaturePlugin.ts} | 16 +- .../aiCore/prepareParams/parameterBuilder.ts | 6 +- .../src/aiCore/types/middlewareConfig.ts | 26 ++ src/renderer/src/aiCore/utils/reasoning.ts | 18 ++ src/renderer/src/services/ApiService.ts | 2 +- src/renderer/src/types/aiCoreTypes.ts | 18 +- yarn.lock | 43 ++- 40 files changed, 1139 insertions(+), 885 deletions(-) create mode 100644 packages/aiCore/src/core/errors/index.ts create mode 100644 packages/aiCore/src/core/models/utils.ts delete mode 100644 packages/aiCore/src/core/options/examples.ts create mode 100644 packages/aiCore/src/core/types/branded.ts delete mode 100644 src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts delete mode 100644 src/renderer/src/aiCore/middleware/README.md delete mode 100644 src/renderer/src/aiCore/middleware/toolChoiceMiddleware.ts rename src/renderer/src/aiCore/{middleware/noThinkMiddleware.ts => plugins/noThinkPlugin.ts} (78%) rename src/renderer/src/aiCore/{middleware/openrouterGenerateImageMiddleware.ts => plugins/openrouterGenerateImagePlugin.ts} (67%) rename src/renderer/src/aiCore/{middleware/openrouterReasoningMiddleware.ts => plugins/openrouterReasoningPlugin.ts} (67%) rename src/renderer/src/aiCore/{middleware/qwenThinkingMiddleware.ts => plugins/qwenThinkingPlugin.ts} (72%) create mode 100644 src/renderer/src/aiCore/plugins/reasoningExtractionPlugin.ts create mode 100644 src/renderer/src/aiCore/plugins/simulateStreamingPlugin.ts rename src/renderer/src/aiCore/{middleware/skipGeminiThoughtSignatureMiddleware.ts => plugins/skipGeminiThoughtSignaturePlugin.ts} (71%) create mode 100644 src/renderer/src/aiCore/types/middlewareConfig.ts diff --git a/packages/aiCore/package.json b/packages/aiCore/package.json index 6d4c1c11fc..4d3a6421e7 100644 --- a/packages/aiCore/package.json +++ b/packages/aiCore/package.json @@ -9,6 +9,7 @@ "scripts": { "build": "tsdown", "dev": "tsc -w", + "typecheck": "tsc --noEmit", "clean": "rm -rf dist", "test": "vitest run", "test:watch": "vitest" diff --git a/packages/aiCore/src/core/errors/index.ts b/packages/aiCore/src/core/errors/index.ts new file mode 100644 index 0000000000..1ff110296a --- /dev/null +++ b/packages/aiCore/src/core/errors/index.ts @@ -0,0 +1,124 @@ +/** + * AI Core Error System + * Unified error handling for the AI Core package + */ + +/** + * Base error class for all AI Core errors + * Provides structured error information with error codes, context, and cause tracking + */ +export class AiCoreError extends Error { + constructor( + public readonly code: string, + message: string, + public readonly context?: Record, + public readonly cause?: Error + ) { + super(message) + this.name = 'AiCoreError' + + if (cause) { + this.stack = `${this.stack}\nCaused by: ${cause.stack}` + } + } + + toJSON() { + return { + name: this.name, + code: this.code, + message: this.message, + context: this.context, + cause: this.cause + ? { + name: this.cause.name, + message: this.cause.message + } + : undefined + } + } +} + +/** + * Recursive depth limit exceeded error + * Thrown when recursive calls exceed the maximum allowed depth + */ +export class RecursiveDepthError extends AiCoreError { + constructor(requestId: string, currentDepth: number, maxDepth: number) { + super('RECURSIVE_DEPTH_EXCEEDED', `Maximum recursive depth (${maxDepth}) exceeded at depth ${currentDepth}`, { + requestId, + currentDepth, + maxDepth + }) + this.name = 'RecursiveDepthError' + } +} + +/** + * Model resolution failure error + * Thrown when a model ID cannot be resolved to a model instance + */ +export class ModelResolutionError extends AiCoreError { + constructor(modelId: string, providerId: string, cause?: Error) { + super('MODEL_RESOLUTION_FAILED', `Failed to resolve model: ${modelId}`, { modelId, providerId }, cause) + this.name = 'ModelResolutionError' + } +} + +/** + * Parameter validation error + * Thrown when request parameters fail validation + */ +export class ParameterValidationError extends AiCoreError { + constructor(paramName: string, reason: string, value?: unknown) { + super('PARAMETER_VALIDATION_FAILED', `Invalid parameter '${paramName}': ${reason}`, { + paramName, + reason, + value + }) + this.name = 'ParameterValidationError' + } +} + +/** + * Plugin execution error + * Thrown when a plugin fails during execution + */ +export class PluginExecutionError extends AiCoreError { + constructor(pluginName: string, hookName: string, cause: Error) { + super( + 'PLUGIN_EXECUTION_FAILED', + `Plugin '${pluginName}' failed in hook '${hookName}'`, + { + pluginName, + hookName + }, + cause + ) + this.name = 'PluginExecutionError' + } +} + +/** + * Provider configuration error + * Thrown when provider settings are invalid or missing + */ +export class ProviderConfigError extends AiCoreError { + constructor(providerId: string, reason: string) { + super('PROVIDER_CONFIG_ERROR', `Provider '${providerId}' configuration error: ${reason}`, { + providerId, + reason + }) + this.name = 'ProviderConfigError' + } +} + +/** + * Template loading error + * Thrown when a template cannot be loaded + */ +export class TemplateLoadError extends AiCoreError { + constructor(templateName: string, cause?: Error) { + super('TEMPLATE_LOAD_FAILED', `Failed to load template: ${templateName}`, { templateName }, cause) + this.name = 'TemplateLoadError' + } +} diff --git a/packages/aiCore/src/core/models/index.ts b/packages/aiCore/src/core/models/index.ts index 439d3d0f41..1e6d33bf2a 100644 --- a/packages/aiCore/src/core/models/index.ts +++ b/packages/aiCore/src/core/models/index.ts @@ -7,3 +7,6 @@ export { globalModelResolver, ModelResolver } from './ModelResolver' // 保留的类型定义(可能被其他地方使用) export type { ModelConfig as ModelConfigType } from './types' + +// 模型工具函数 +export { hasModelId, isV2Model, isV3Model } from './utils' diff --git a/packages/aiCore/src/core/models/types.ts b/packages/aiCore/src/core/models/types.ts index 847db6bc3f..c7b107263d 100644 --- a/packages/aiCore/src/core/models/types.ts +++ b/packages/aiCore/src/core/models/types.ts @@ -1,7 +1,7 @@ /** * Creation 模块类型定义 */ -import type { LanguageModelV3Middleware } from '@ai-sdk/provider' +import type { JSONObject, LanguageModelV3Middleware } from '@ai-sdk/provider' import type { ProviderId, ProviderSettingsMap } from '../providers/types' @@ -10,6 +10,5 @@ export interface ModelConfig { modelId: string providerSettings: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' } middlewares?: LanguageModelV3Middleware[] - // 额外模型参数 - extraModelConfig?: Record + extraModelConfig?: JSONObject } diff --git a/packages/aiCore/src/core/models/utils.ts b/packages/aiCore/src/core/models/utils.ts new file mode 100644 index 0000000000..6278e168ab --- /dev/null +++ b/packages/aiCore/src/core/models/utils.ts @@ -0,0 +1,27 @@ +import type { LanguageModelV2, LanguageModelV3 } from '@ai-sdk/provider' + +import type { AiSdkModel } from '../providers' + +export const isV2Model = (model: AiSdkModel): model is LanguageModelV2 => { + return typeof model === 'object' && model !== null && model.specificationVersion === 'v2' +} + +export const isV3Model = (model: AiSdkModel): model is LanguageModelV3 => { + return typeof model === 'object' && model !== null && model.specificationVersion === 'v3' +} + +/** + * Type guard to check if a model has a modelId property + */ +export const hasModelId = (model: unknown): model is { modelId: string } => { + if (typeof model !== 'object' || model === null) { + return false + } + + if (!('modelId' in model)) { + return false + } + + const obj = model as Record + return typeof obj.modelId === 'string' +} diff --git a/packages/aiCore/src/core/options/examples.ts b/packages/aiCore/src/core/options/examples.ts deleted file mode 100644 index 9078437d9c..0000000000 --- a/packages/aiCore/src/core/options/examples.ts +++ /dev/null @@ -1,87 +0,0 @@ -import { streamText } from 'ai' - -import { - createAnthropicOptions, - createGenericProviderOptions, - createGoogleOptions, - createOpenAIOptions, - mergeProviderOptions -} from './factory' - -// 示例1: 使用已知供应商的严格类型约束 -export function exampleOpenAIWithOptions() { - const openaiOptions = createOpenAIOptions({ - reasoningEffort: 'medium' - }) - - // 这里会有类型检查,确保选项符合OpenAI的设置 - return streamText({ - model: {} as any, // 实际使用时替换为真实模型 - prompt: 'Hello', - providerOptions: openaiOptions - }) -} - -// 示例2: 使用Anthropic供应商选项 -export function exampleAnthropicWithOptions() { - const anthropicOptions = createAnthropicOptions({ - thinking: { - type: 'enabled', - budgetTokens: 1000 - } - }) - - return streamText({ - model: {} as any, - prompt: 'Hello', - providerOptions: anthropicOptions - }) -} - -// 示例3: 使用Google供应商选项 -export function exampleGoogleWithOptions() { - const googleOptions = createGoogleOptions({ - thinkingConfig: { - includeThoughts: true, - thinkingBudget: 1000 - } - }) - - return streamText({ - model: {} as any, - prompt: 'Hello', - providerOptions: googleOptions - }) -} - -// 示例4: 使用未知供应商(通用类型) -export function exampleUnknownProviderWithOptions() { - const customProviderOptions = createGenericProviderOptions('custom-provider', { - temperature: 0.7, - customSetting: 'value', - anotherOption: true - }) - - return streamText({ - model: {} as any, - prompt: 'Hello', - providerOptions: customProviderOptions - }) -} - -// 示例5: 合并多个供应商选项 -export function exampleMergedOptions() { - const openaiOptions = createOpenAIOptions({}) - - const customOptions = createGenericProviderOptions('custom', { - customParam: 'value' - }) - - const mergedOptions = mergeProviderOptions(openaiOptions, customOptions) - - return streamText({ - model: {} as any, - prompt: 'Hello', - providerOptions: mergedOptions - }) -} diff --git a/packages/aiCore/src/core/plugins/built-in/googleToolsPlugin/index.ts b/packages/aiCore/src/core/plugins/built-in/googleToolsPlugin/index.ts index 09a741d9f2..c9372ca7b9 100644 --- a/packages/aiCore/src/core/plugins/built-in/googleToolsPlugin/index.ts +++ b/packages/aiCore/src/core/plugins/built-in/googleToolsPlugin/index.ts @@ -1,7 +1,6 @@ import { google } from '@ai-sdk/google' -import { definePlugin } from '../../' -import type { AiRequestContext } from '../../types' +import { type AiPlugin, definePlugin, type StreamTextParams, type StreamTextResult } from '../../' const toolNameMap = { googleSearch: 'google_search', @@ -12,27 +11,40 @@ const toolNameMap = { type ToolConfigKey = keyof typeof toolNameMap type ToolConfig = { googleSearch?: boolean; urlContext?: boolean; codeExecution?: boolean } -export const googleToolsPlugin = (config?: ToolConfig) => - definePlugin({ +export const googleToolsPlugin = (config?: ToolConfig): AiPlugin => + definePlugin({ name: 'googleToolsPlugin', - transformParams: (params: T, context: AiRequestContext): T => { + transformParams: (params, context) => { const { providerId } = context - if (providerId === 'google' && config) { - if (typeof params === 'object' && params !== null) { - const typedParams = params as T & { tools?: Record } - if (!typedParams.tools) { - typedParams.tools = {} - } - // 使用类型安全的方式遍历配置 - ;(Object.keys(config) as ToolConfigKey[]).forEach((key) => { - if (config[key] && key in toolNameMap && key in google.tools) { - const toolName = toolNameMap[key] - typedParams.tools![toolName] = google.tools[key]({}) - } - }) - } + // 只在 Google provider 且有配置时才修改参数 + if (providerId !== 'google' || !config) { + return {} // 返回空 Partial,表示不修改 } - return params + + if (typeof params !== 'object' || params === null) { + return {} + } + + // 构建 tools 对象,确保类型兼容 + const hasTools = (Object.keys(config) as ToolConfigKey[]).some( + (key) => config[key] && key in toolNameMap && key in google.tools + ) + + if (!hasTools) { + return {} // 返回空 Partial,表示不修改 + } + + // 构建符合 AI SDK 的 tools 对象 + const tools: Record> = {} + + ;(Object.keys(config) as ToolConfigKey[]).forEach((key) => { + if (config[key] && key in toolNameMap && key in google.tools) { + const toolName = toolNameMap[key] + tools[toolName] = google.tools[key]({}) + } + }) + + return { tools: tools } } }) diff --git a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/StreamEventManager.ts b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/StreamEventManager.ts index c30c2015f6..4a3025b39e 100644 --- a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/StreamEventManager.ts +++ b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/StreamEventManager.ts @@ -4,11 +4,61 @@ * 负责处理 AI SDK 流事件的发送和管理 * 从 promptToolUsePlugin.ts 中提取出来以降低复杂度 */ -import type { ModelMessage } from 'ai' +import type { EmbeddingModelUsage, ImageModelUsage, LanguageModelUsage, ModelMessage } from 'ai' -import type { AiRequestContext } from '../../types' +import type { AiSdkUsage } from '../../../providers/types' +import type { AiRequestContext, StreamTextParams, StreamTextResult } from '../../types' import type { StreamController } from './ToolExecutor' +/** + * 类型守卫:检查对象是否是有效的流结果(包含 ReadableStream 类型的 fullStream) + */ +function hasFullStream(obj: unknown): obj is StreamTextResult & { fullStream: ReadableStream } { + return typeof obj === 'object' && obj !== null && 'fullStream' in obj && obj.fullStream instanceof ReadableStream +} + +/** + * 类型守卫:检查 usage 是否是 LanguageModelUsage + * LanguageModelUsage 包含 totalTokens, inputTokens, outputTokens 等字段 + */ +function isLanguageModelUsage(usage: unknown): usage is LanguageModelUsage { + return ( + typeof usage === 'object' && + usage !== null && + ('totalTokens' in usage || 'inputTokens' in usage || 'outputTokens' in usage) + ) +} + +/** + * 类型守卫:检查 usage 是否是 ImageModelUsage + * ImageModelUsage 包含 inputTokens, outputTokens, totalTokens 字段 + */ +function isImageModelUsage(usage: unknown): usage is ImageModelUsage { + return ( + typeof usage === 'object' && + usage !== null && + 'inputTokens' in usage && + 'outputTokens' in usage && + // 确保不是 LanguageModelUsage(LanguageModelUsage 可能有 reasoningTokens 等额外字段) + !('reasoningTokens' in usage) + ) +} + +/** + * 类型守卫:检查 usage 是否是 EmbeddingModelUsage + * EmbeddingModelUsage 只包含 tokens 字段 + */ +function isEmbeddingModelUsage(usage: unknown): usage is EmbeddingModelUsage { + return ( + typeof usage === 'object' && + usage !== null && + 'tokens' in usage && + // 确保只有 tokens 字段(没有 inputTokens, outputTokens 等) + !('inputTokens' in usage) && + !('outputTokens' in usage) + ) +} + /** * 流事件管理器类 */ @@ -50,10 +100,10 @@ export class StreamEventManager { /** * 处理递归调用并将结果流接入当前流 */ - async handleRecursiveCall( + async handleRecursiveCall( controller: StreamController, - recursiveParams: any, - context: AiRequestContext + recursiveParams: Partial, + context: AiRequestContext ): Promise { // try { // 重置工具执行状态,准备处理新的步骤 @@ -61,7 +111,7 @@ export class StreamEventManager { const recursiveResult = await context.recursiveCall(recursiveParams) - if (recursiveResult && recursiveResult.fullStream) { + if (hasFullStream(recursiveResult)) { await this.pipeRecursiveStream(controller, recursiveResult.fullStream) } else { console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult) @@ -97,36 +147,20 @@ export class StreamEventManager { } } - /** - * 处理递归调用错误 - */ - // private handleRecursiveCallError(controller: StreamController, error: unknown): void { - // console.error('[MCP Prompt] Recursive call failed:', error) - - // // 使用 AI SDK 标准错误格式,但不中断流 - // controller.enqueue({ - // type: 'error', - // error: { - // message: error instanceof Error ? error.message : String(error), - // name: error instanceof Error ? error.name : 'RecursiveCallError' - // } - // }) - - // // // 继续发送文本增量,保持流的连续性 - // // controller.enqueue({ - // // type: 'text-delta', - // // id: stepId, - // // text: '\n\n[工具执行后递归调用失败,继续对话...]' - // // }) - // } - /** * 构建递归调用的参数 */ - buildRecursiveParams(context: AiRequestContext, textBuffer: string, toolResultsText: string, tools: any): any { + buildRecursiveParams( + context: AiRequestContext, + textBuffer: string, + toolResultsText: string, + tools: any + ): Partial { + const params = context.originalParams + // 构建新的对话消息 const newMessages: ModelMessage[] = [ - ...(context.originalParams.messages || []), + ...(params.messages || []), // 只有当 textBuffer 有内容时才添加 assistant 消息,避免空消息导致 API 错误 ...(textBuffer ? [{ role: 'assistant' as const, content: textBuffer }] : []), { @@ -137,28 +171,47 @@ export class StreamEventManager { // 递归调用,继续对话,重新传递 tools const recursiveParams = { - ...context.originalParams, + ...params, messages: newMessages, tools: tools - } - - // 更新上下文中的消息 - context.originalParams.messages = newMessages + } as Partial return recursiveParams } /** * 累加 usage 数据 + * + * 使用类型守卫来处理不同类型的 usage(LanguageModelUsage, ImageModelUsage, EmbeddingModelUsage) + * - LanguageModelUsage: inputTokens, outputTokens, totalTokens + * - ImageModelUsage: inputTokens, outputTokens, totalTokens + * - EmbeddingModelUsage: tokens */ - accumulateUsage(target: any, source: any): void { + accumulateUsage(target: Partial, source: Partial): void { if (!target || !source) return - // 累加各种 token 类型 - target.inputTokens = (target.inputTokens || 0) + (source.inputTokens || 0) - target.outputTokens = (target.outputTokens || 0) + (source.outputTokens || 0) - target.totalTokens = (target.totalTokens || 0) + (source.totalTokens || 0) - target.reasoningTokens = (target.reasoningTokens || 0) + (source.reasoningTokens || 0) - target.cachedInputTokens = (target.cachedInputTokens || 0) + (source.cachedInputTokens || 0) + if (isLanguageModelUsage(target) && isLanguageModelUsage(source)) { + target.totalTokens = (target.totalTokens || 0) + (source.totalTokens || 0) + target.inputTokens = (target.inputTokens || 0) + (source.inputTokens || 0) + target.outputTokens = (target.outputTokens || 0) + (source.outputTokens || 0) + return + } + if (isImageModelUsage(target) && isImageModelUsage(source)) { + target.totalTokens = (target.totalTokens || 0) + (source.totalTokens || 0) + target.inputTokens = (target.inputTokens || 0) + (source.inputTokens || 0) + target.outputTokens = (target.outputTokens || 0) + (source.outputTokens || 0) + return + } + + if (isEmbeddingModelUsage(target) && isEmbeddingModelUsage(source)) { + target.tokens = (target.tokens || 0) + (source.tokens || 0) + return + } + + // ⚠️ 未知类型或类型不匹配,不进行累加 + console.warn('[StreamEventManager] Unable to accumulate usage - type mismatch or unknown type', { + target, + source + }) } } diff --git a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/promptToolUsePlugin.ts b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/promptToolUsePlugin.ts index 224cee05ae..d20b501ab7 100644 --- a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/promptToolUsePlugin.ts +++ b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/promptToolUsePlugin.ts @@ -6,7 +6,7 @@ import type { TextStreamPart, ToolSet } from 'ai' import { definePlugin } from '../../index' -import type { AiRequestContext } from '../../types' +import type { AiPlugin, StreamTextParams, StreamTextResult } from '../../types' import { StreamEventManager } from './StreamEventManager' import { type TagConfig, TagExtractor } from './tagExtraction' import { ToolExecutor } from './ToolExecutor' @@ -254,23 +254,25 @@ function defaultParseToolUse(content: string, tools: ToolSet): { results: ToolUs return { results, content: contentToProcess } } -export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => { +export const createPromptToolUsePlugin = ( + config: PromptToolUseConfig = {} +): AiPlugin => { const { enabled = true, buildSystemPrompt = defaultBuildSystemPrompt, parseToolUse = defaultParseToolUse } = config - return definePlugin({ + return definePlugin({ name: 'built-in:prompt-tool-use', - transformParams: (params: any, context: AiRequestContext) => { + transformParams: (params, context) => { if (!enabled || !params.tools || typeof params.tools !== 'object') { return params } - // 分离 provider-defined 和其他类型的工具 + // 分离 provider 和其他类型的工具 const providerDefinedTools: ToolSet = {} const promptTools: ToolSet = {} for (const [toolName, tool] of Object.entries(params.tools as ToolSet)) { - if (tool.type === 'provider-defined') { - // provider-defined 类型的工具保留在 tools 参数中 + if (tool.type === 'provider') { + // provider 类型的工具保留在 tools 参数中 providerDefinedTools[toolName] = tool } else { // 其他工具转换为 prompt 模式 @@ -278,12 +280,12 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => { } } - // 只有当有非 provider-defined 工具时才保存到 context + // 只有当有非 provider 工具时才保存到 context if (Object.keys(promptTools).length > 0) { context.mcpTools = promptTools } - // 构建系统提示符(只包含非 provider-defined 工具) + // 构建系统提示符(只包含非 provider 工具) const userSystemPrompt = typeof params.system === 'string' ? params.system : '' const systemPrompt = buildSystemPrompt(userSystemPrompt, promptTools) let systemMessage: string | null = systemPrompt @@ -292,7 +294,7 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => { systemMessage = config.createSystemMessage(systemPrompt, params, context) } - // 保留 provider-defined tools,移除其他 tools + // 保留 provide tools,移除其他 tools const transformedParams = { ...params, ...(systemMessage ? { system: systemMessage } : {}), @@ -301,7 +303,7 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => { context.originalParams = transformedParams return transformedParams }, - transformStream: (_: any, context: AiRequestContext) => () => { + transformStream: (_, context) => () => { let textBuffer = '' // let stepId = '' diff --git a/packages/aiCore/src/core/plugins/index.ts b/packages/aiCore/src/core/plugins/index.ts index 9dc9ef6528..bc8b1c3088 100644 --- a/packages/aiCore/src/core/plugins/index.ts +++ b/packages/aiCore/src/core/plugins/index.ts @@ -1,7 +1,17 @@ // 核心类型和接口 -export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './types' -import type { ImageModelV3 } from '@ai-sdk/provider' -import type { LanguageModel } from 'ai' +export type { + AiPlugin, + AiRequestContext, + AiRequestMetadata, + GenerateTextParams, + GenerateTextResult, + HookResult, + PluginManagerConfig, + RecursiveCallFn, + StreamTextParams, + StreamTextResult +} from './types' +import type { ImageModel, LanguageModel } from 'ai' import type { ProviderId } from '../providers' import type { AiPlugin, AiRequestContext } from './types' @@ -10,11 +20,11 @@ import type { AiPlugin, AiRequestContext } from './types' export { PluginManager } from './manager' // 工具函数 -export function createContext( +export function createContext( providerId: T, - model: LanguageModel | ImageModelV3, - originalParams: any -): AiRequestContext { + model: LanguageModel | ImageModel, + originalParams: TParams +): AiRequestContext { return { providerId, model, @@ -22,14 +32,28 @@ export function createContext( metadata: {}, startTime: Date.now(), requestId: `${providerId}-${typeof model === 'string' ? model : model?.modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`, - // 占位 - recursiveCall: () => Promise.resolve(null) + isRecursiveCall: false, + recursiveDepth: 0, // 初始化递归深度为 0 + maxRecursiveDepth: 10, // 默认最大递归深度为 10 + extensions: new Map(), + middlewares: [], + // 占位递归调用函数,实际使用时会被 PluginEngine 替换 + recursiveCall: () => Promise.resolve(null as any) } } // 插件构建器 - 便于创建插件 + +// 重载 1: 泛型插件(显式指定类型参数) +export function definePlugin(plugin: AiPlugin): AiPlugin + +// 重载 2: 非泛型插件(默认 unknown) export function definePlugin(plugin: AiPlugin): AiPlugin + +// 重载 3: 插件工厂函数 export function definePlugin AiPlugin>(pluginFactory: T): T + +// 实现 export function definePlugin(plugin: AiPlugin | ((...args: any[]) => AiPlugin)) { return plugin } diff --git a/packages/aiCore/src/core/plugins/manager.ts b/packages/aiCore/src/core/plugins/manager.ts index 40f5836c44..a3f4573be1 100644 --- a/packages/aiCore/src/core/plugins/manager.ts +++ b/packages/aiCore/src/core/plugins/manager.ts @@ -1,19 +1,20 @@ import type { AiPlugin, AiRequestContext } from './types' /** - * 插件管理器 + * 插件管理器(泛型化) + * 支持类型安全的插件管理,同时通过逆变保持灵活性 */ -export class PluginManager { - private plugins: AiPlugin[] = [] +export class PluginManager { + private plugins: AiPlugin[] = [] - constructor(plugins: AiPlugin[] = []) { + constructor(plugins: AiPlugin[] = []) { this.plugins = this.sortPlugins(plugins) } /** - * 添加插件 + * 添加插件(支持逆变:AiPlugin 可赋值给 AiPlugin) */ - use(plugin: AiPlugin): this { + use(plugin: AiPlugin): this { this.plugins = this.sortPlugins([...this.plugins, plugin]) return this } @@ -29,10 +30,10 @@ export class PluginManager { /** * 插件排序:pre -> normal -> post */ - private sortPlugins(plugins: AiPlugin[]): AiPlugin[] { - const pre: AiPlugin[] = [] - const normal: AiPlugin[] = [] - const post: AiPlugin[] = [] + private sortPlugins(plugins: AiPlugin[]): AiPlugin[] { + const pre: AiPlugin[] = [] + const normal: AiPlugin[] = [] + const post: AiPlugin[] = [] plugins.forEach((plugin) => { if (plugin.enforce === 'pre') { @@ -53,7 +54,7 @@ export class PluginManager { async executeFirst( hookName: 'resolveModel' | 'loadTemplate', arg: any, - context: AiRequestContext + context: AiRequestContext ): Promise { for (const plugin of this.plugins) { const hook = plugin[hookName] @@ -68,19 +69,42 @@ export class PluginManager { } /** - * 执行 Sequential 钩子 - 链式数据转换 + * 执行 transformParams 钩子 - 链式参数转换 + * 每个插件返回 Partial,逐步合并到原始参数 */ - async executeSequential( - hookName: 'transformParams' | 'transformResult', - initialValue: T, - context: AiRequestContext - ): Promise { + async executeTransformParams( + initialValue: TParams, + context: AiRequestContext + ): Promise { let result = initialValue for (const plugin of this.plugins) { - const hook = plugin[hookName] - if (hook) { - result = await hook(result, context) + if (plugin.transformParams) { + const partial = await plugin.transformParams(result, context) + // 合并 Partial 到现有参数 + result = { ...result, ...partial } + } + } + + return result + } + + /** + * 执行 transformResult 钩子 - 链式结果转换 + * 每个插件接收并返回完整的 TResult + */ + async executeTransformResult( + initialValue: TResult, + context: AiRequestContext + ): Promise { + let result = initialValue + + for (const plugin of this.plugins) { + if (plugin.transformResult) { + // SAFETY: transformResult 的契约保证返回 TResult + // 由于插件接口定义,这个类型断言是安全的 + const transformed = await plugin.transformResult(result, context) + result = transformed as TResult } } @@ -90,7 +114,7 @@ export class PluginManager { /** * 执行 ConfigureContext 钩子 - 串行配置上下文 */ - async executeConfigureContext(context: AiRequestContext): Promise { + async executeConfigureContext(context: AiRequestContext): Promise { for (const plugin of this.plugins) { const hook = plugin.configureContext if (hook) { @@ -104,8 +128,8 @@ export class PluginManager { */ async executeParallel( hookName: 'onRequestStart' | 'onRequestEnd' | 'onError', - context: AiRequestContext, - result?: any, + context: AiRequestContext, + result?: TResult, error?: Error ): Promise { const promises = this.plugins @@ -131,7 +155,7 @@ export class PluginManager { /** * 收集所有流转换器(返回数组,AI SDK 原生支持) */ - collectStreamTransforms(params: any, context: AiRequestContext) { + collectStreamTransforms(params: TParams, context: AiRequestContext) { return this.plugins .filter((plugin) => plugin.transformStream) .map((plugin) => plugin.transformStream?.(params, context)) @@ -140,7 +164,7 @@ export class PluginManager { /** * 获取所有插件信息 */ - getPlugins(): AiPlugin[] { + getPlugins(): AiPlugin[] { return [...this.plugins] } diff --git a/packages/aiCore/src/core/plugins/types.ts b/packages/aiCore/src/core/plugins/types.ts index d7ca283c14..2da8af8025 100644 --- a/packages/aiCore/src/core/plugins/types.ts +++ b/packages/aiCore/src/core/plugins/types.ts @@ -1,79 +1,126 @@ -import type { ImageModelV3 } from '@ai-sdk/provider' -import type { LanguageModel, TextStreamPart, ToolSet } from 'ai' +import type { JSONObject, JSONValue } from '@ai-sdk/provider' +import type { generateText, LanguageModelMiddleware, streamText, TextStreamPart, ToolSet } from 'ai' -import { type ProviderId } from '../providers/types' +import type { AiSdkModel, ProviderId } from '../providers/types' + +/** + * 常用的 AI SDK 参数类型(完整版,用于插件泛型) + */ +export type StreamTextParams = Parameters[0] +export type StreamTextResult = ReturnType +export type GenerateTextParams = Parameters[0] +export type GenerateTextResult = ReturnType + +/** + * AI 请求元数据 + * 定义结构化的元数据字段,避免使用 Record + */ +export interface AiRequestMetadata { + topicId?: string + callType?: string + enableReasoning?: boolean + enableWebSearch?: boolean + enableGenerateImage?: boolean + isPromptToolUse?: boolean + isSupportedToolUse?: boolean + isImageGenerationEndpoint?: boolean + // 自定义元数据,使用 JSONValue 确保类型安全 + custom?: JSONObject +} /** * 递归调用函数类型 - * 使用 any 是因为递归调用时参数和返回类型可能完全不同 + * 泛型化以保持类型推导 */ -export type RecursiveCallFn = (newParams: any) => Promise +export type RecursiveCallFn = (newParams: Partial) => Promise /** * AI 请求上下文 + * 使用泛型参数以支持不同类型的请求 */ -export interface AiRequestContext { +export interface AiRequestContext { providerId: ProviderId - model: LanguageModel | ImageModelV3 - originalParams: any - metadata: Record + model: AiSdkModel + originalParams: TParams + metadata: AiRequestMetadata startTime: number requestId: string - recursiveCall: RecursiveCallFn - isRecursiveCall?: boolean + recursiveCall: RecursiveCallFn + isRecursiveCall: boolean + + // 递归深度控制(防止栈溢出) + recursiveDepth: number // 当前递归深度 + maxRecursiveDepth: number // 最大递归深度限制,默认 10 + mcpTools?: ToolSet + + extensions: Map + + middlewares?: LanguageModelMiddleware[] + + // 向后兼容:允许插件动态添加属性(临时保留) [key: string]: any } /** * 钩子分类 + * 使用泛型参数以支持不同类型的请求和响应 */ -export interface AiPlugin { +export interface AiPlugin { name: string enforce?: 'pre' | 'post' // 【First】首个钩子 - 只执行第一个返回值的插件 resolveModel?: ( modelId: string, - context: AiRequestContext - ) => Promise | LanguageModel | ImageModelV3 | null - loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise + context: AiRequestContext + ) => Promise | AiSdkModel | null + + loadTemplate?: ( + templateName: string, + context: AiRequestContext + ) => JSONValue | null | Promise // 【Sequential】串行钩子 - 链式执行,支持数据转换 - configureContext?: (context: AiRequestContext) => void | Promise - transformParams?: (params: T, context: AiRequestContext) => T | Promise - transformResult?: (result: T, context: AiRequestContext) => T | Promise + configureContext?: (context: AiRequestContext) => void | Promise + + transformParams?: ( + params: TParams, + context: AiRequestContext + ) => Partial | Promise> + + transformResult?: (result: TResult, context: AiRequestContext) => TResult | Promise // 【Parallel】并行钩子 - 不依赖顺序,用于副作用 - onRequestStart?: (context: AiRequestContext) => void | Promise - onRequestEnd?: (context: AiRequestContext, result: any) => void | Promise - onError?: (error: Error, context: AiRequestContext) => void | Promise + onRequestStart?: (context: AiRequestContext) => void | Promise + + onRequestEnd?: (context: AiRequestContext, result: TResult) => void | Promise + + onError?: (error: Error, context: AiRequestContext) => void | Promise // 【Stream】流处理 - 直接使用 AI SDK transformStream?: ( - params: any, - context: AiRequestContext + params: TParams, + context: AiRequestContext ) => (options?: { tools: TOOLS stopStream: () => void }) => TransformStream, TextStreamPart> - - // AI SDK 原生中间件 - // aiSdkMiddlewares?: LanguageModelV1Middleware[] } /** * 插件管理器配置 */ -export interface PluginManagerConfig { - plugins: AiPlugin[] - context: Partial +export interface PluginManagerConfig { + plugins: AiPlugin[] + context: Partial> } /** * 钩子执行结果 + * 泛型参数指定返回值类型 */ -export interface HookResult { +export interface HookResult { value: T stop?: boolean } diff --git a/packages/aiCore/src/core/providers/HubProvider.ts b/packages/aiCore/src/core/providers/HubProvider.ts index 629c41b421..37d40e147e 100644 --- a/packages/aiCore/src/core/providers/HubProvider.ts +++ b/packages/aiCore/src/core/providers/HubProvider.ts @@ -5,11 +5,19 @@ * 例如: aihubmix:anthropic:claude-3.5-sonnet */ -import type { ProviderV3 } from '@ai-sdk/provider' -import { customProvider } from 'ai' +import type { + EmbeddingModelV3, + ImageModelV3, + LanguageModelV3, + ProviderV3, + RerankingModelV3, + SpeechModelV3, + TranscriptionModelV3 +} from '@ai-sdk/provider' +import { customProvider, wrapProvider } from 'ai' -import { globalRegistryManagement } from './RegistryManagement' -import type { AiSdkMethodName, AiSdkModelReturn, AiSdkModelType } from './types' +import { DEFAULT_SEPARATOR, globalRegistryManagement } from './RegistryManagement' +import type { AiSdkProvider } from './types' export interface HubProviderConfig { /** Hub的唯一标识符 */ @@ -34,7 +42,7 @@ export class HubProviderError extends Error { * 解析Hub模型ID */ function parseHubModelId(modelId: string): { provider: string; actualModelId: string } { - const parts = modelId.split(':') + const parts = modelId.split(DEFAULT_SEPARATOR) if (parts.length !== 2) { throw new HubProviderError(`Invalid hub model ID format. Expected "provider:modelId", got: ${modelId}`, 'unknown') } @@ -47,7 +55,7 @@ function parseHubModelId(modelId: string): { provider: string; actualModelId: st /** * 创建Hub Provider */ -export function createHubProvider(config: HubProviderConfig): ProviderV3 { +export function createHubProvider(config: HubProviderConfig): AiSdkProvider { const { hubId } = config function getTargetProvider(providerId: string): ProviderV3 { @@ -61,7 +69,9 @@ export function createHubProvider(config: HubProviderConfig): ProviderV3 { providerId ) } - return provider + // 使用 wrapProvider 确保返回的是 V3 provider + // 这样可以自动处理 V2 provider 到 V3 的转换 + return wrapProvider({ provider, languageModelMiddleware: [] }) } catch (error) { throw new HubProviderError( `Failed to get provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`, @@ -72,30 +82,62 @@ export function createHubProvider(config: HubProviderConfig): ProviderV3 { } } - function resolveModel( - modelId: string, - modelType: T, - methodName: AiSdkMethodName - ): AiSdkModelReturn { - const { provider, actualModelId } = parseHubModelId(modelId) - const targetProvider = getTargetProvider(provider) + // 创建符合 ProviderV3 规范的 fallback provider + const hubFallbackProvider = { + specificationVersion: 'v3' as const, - const fn = targetProvider[methodName] as (id: string) => AiSdkModelReturn + languageModel: (modelId: string): LanguageModelV3 => { + const { provider, actualModelId } = parseHubModelId(modelId) + const targetProvider = getTargetProvider(provider) + return targetProvider.languageModel(actualModelId) + }, - if (!fn) { - throw new HubProviderError(`Provider "${provider}" does not support ${modelType}`, hubId, provider) + embeddingModel: (modelId: string): EmbeddingModelV3 => { + const { provider, actualModelId } = parseHubModelId(modelId) + const targetProvider = getTargetProvider(provider) + return targetProvider.embeddingModel(actualModelId) + }, + + imageModel: (modelId: string): ImageModelV3 => { + const { provider, actualModelId } = parseHubModelId(modelId) + const targetProvider = getTargetProvider(provider) + return targetProvider.imageModel(actualModelId) + }, + + transcriptionModel: (modelId: string): TranscriptionModelV3 => { + const { provider, actualModelId } = parseHubModelId(modelId) + const targetProvider = getTargetProvider(provider) + + if (!targetProvider.transcriptionModel) { + throw new HubProviderError(`Provider "${provider}" does not support transcription models`, hubId, provider) + } + + return targetProvider.transcriptionModel(actualModelId) + }, + + speechModel: (modelId: string): SpeechModelV3 => { + const { provider, actualModelId } = parseHubModelId(modelId) + const targetProvider = getTargetProvider(provider) + + if (!targetProvider.speechModel) { + throw new HubProviderError(`Provider "${provider}" does not support speech models`, hubId, provider) + } + + return targetProvider.speechModel(actualModelId) + }, + rerankingModel: (modelId: string): RerankingModelV3 => { + const { provider, actualModelId } = parseHubModelId(modelId) + const targetProvider = getTargetProvider(provider) + + if (!targetProvider.rerankingModel) { + throw new HubProviderError(`Provider "${provider}" does not support reranking models`, hubId, provider) + } + + return targetProvider.rerankingModel(actualModelId) } - - return fn(actualModelId) } return customProvider({ - fallbackProvider: { - languageModel: (modelId: string) => resolveModel(modelId, 'text', 'languageModel'), - textEmbeddingModel: (modelId: string) => resolveModel(modelId, 'embedding', 'textEmbeddingModel'), - imageModel: (modelId: string) => resolveModel(modelId, 'image', 'imageModel'), - transcriptionModel: (modelId: string) => resolveModel(modelId, 'transcription', 'transcriptionModel'), - speechModel: (modelId: string) => resolveModel(modelId, 'speech', 'speechModel') - } + fallbackProvider: hubFallbackProvider }) } diff --git a/packages/aiCore/src/core/providers/RegistryManagement.ts b/packages/aiCore/src/core/providers/RegistryManagement.ts index aefec2c542..e9a9ca358a 100644 --- a/packages/aiCore/src/core/providers/RegistryManagement.ts +++ b/packages/aiCore/src/core/providers/RegistryManagement.ts @@ -11,8 +11,6 @@ type PROVIDERS = Record export const DEFAULT_SEPARATOR = '|' -// export type MODEL_ID = `${string}${typeof DEFAULT_SEPARATOR}${string}` - export class RegistryManagement { private providers: PROVIDERS = {} private aliases: Set = new Set() // 记录哪些key是别名 diff --git a/packages/aiCore/src/core/providers/index.ts b/packages/aiCore/src/core/providers/index.ts index b9ebd6f682..358987f8ff 100644 --- a/packages/aiCore/src/core/providers/index.ts +++ b/packages/aiCore/src/core/providers/index.ts @@ -56,6 +56,7 @@ export type { } from './schemas' // 从 schemas 导出的类型 export { baseProviderIdSchema, customProviderIdSchema, providerConfigSchema, providerIdSchema } from './schemas' // Schema 导出 export type { + AiSdkModel, DynamicProviderRegistry, ExtensibleProviderSettingsMap, ProviderError, diff --git a/packages/aiCore/src/core/providers/schemas.ts b/packages/aiCore/src/core/providers/schemas.ts index 3274b5c993..8c038bebbf 100644 --- a/packages/aiCore/src/core/providers/schemas.ts +++ b/packages/aiCore/src/core/providers/schemas.ts @@ -12,8 +12,8 @@ import { createOpenAICompatible } from '@ai-sdk/openai-compatible' import type { ProviderV3 } from '@ai-sdk/provider' import { createXai } from '@ai-sdk/xai' import { type CherryInProviderSettings, createCherryIn } from '@cherrystudio/ai-sdk-provider' -import { createOpenRouter } from '@openrouter/ai-sdk-provider' -import { customProvider } from 'ai' +import { createOpenRouter, type OpenRouterProviderSettings } from '@openrouter/ai-sdk-provider' +import { customProvider, wrapProvider } from 'ai' import * as z from 'zod' /** @@ -133,7 +133,10 @@ export const baseProviders = [ { id: 'openrouter', name: 'OpenRouter', - creator: createOpenRouter, + creator: (options?: OpenRouterProviderSettings) => { + const provider = createOpenRouter(options) + return wrapProvider({ provider, languageModelMiddleware: [] }) + }, supportsImageGeneration: true }, { diff --git a/packages/aiCore/src/core/providers/types.ts b/packages/aiCore/src/core/providers/types.ts index 177b80bf80..0ed732cba1 100644 --- a/packages/aiCore/src/core/providers/types.ts +++ b/packages/aiCore/src/core/providers/types.ts @@ -4,15 +4,18 @@ import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek' import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google' import { type OpenAIProviderSettings } from '@ai-sdk/openai' import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible' -import type { - EmbeddingModelV3 as EmbeddingModel, - ImageModelV3 as ImageModel, - LanguageModelV3 as LanguageModel, - ProviderV3, - SpeechModelV3 as SpeechModel, - TranscriptionModelV3 as TranscriptionModel -} from '@ai-sdk/provider' +import type { ProviderV2, ProviderV3 } from '@ai-sdk/provider' import { type XaiProviderSettings } from '@ai-sdk/xai' +import type { + EmbeddingModel, + EmbeddingModelUsage, + ImageModel, + ImageModelUsage, + LanguageModel, + LanguageModelUsage, + SpeechModel, + TranscriptionModel +} from 'ai' // 导入基于 Zod 的 ProviderId 类型 import { type ProviderId as ZodProviderId } from './schemas' @@ -70,6 +73,8 @@ export type { } export type AiSdkModel = LanguageModel | ImageModel | EmbeddingModel | TranscriptionModel | SpeechModel +export type AiSdkProvider = ProviderV2 | ProviderV3 +export type AiSdkUsage = LanguageModelUsage | ImageModelUsage | EmbeddingModelUsage export type AiSdkModelType = 'text' | 'image' | 'embedding' | 'transcription' | 'speech' diff --git a/packages/aiCore/src/core/runtime/executor.ts b/packages/aiCore/src/core/runtime/executor.ts index bffe55b806..db93e21562 100644 --- a/packages/aiCore/src/core/runtime/executor.ts +++ b/packages/aiCore/src/core/runtime/executor.ts @@ -4,10 +4,16 @@ */ import type { ImageModelV3, LanguageModelV3, LanguageModelV3Middleware } from '@ai-sdk/provider' import type { LanguageModel } from 'ai' -import { generateImage as _generateImage, generateText as _generateText, streamText as _streamText } from 'ai' +import { + generateImage as _generateImage, + generateText as _generateText, + streamText as _streamText, + wrapLanguageModel +} from 'ai' import { globalModelResolver } from '../models' import { type ModelConfig } from '../models/types' +import { isV3Model } from '../models/utils' import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins' import { type ProviderId } from '../providers' import { ImageGenerationError, ImageModelResolutionError } from './errors' @@ -181,8 +187,20 @@ export class RuntimeExecutor { middlewares // 中间件数组 ) } else { - // 已经是模型,直接返回 - return modelOrId + // 已经是模型对象 + // 所有 provider 都应该返回 V3 模型(通过 wrapProvider 确保) + if (!isV3Model(modelOrId)) { + throw new Error( + `Model must be V3. Provider "${this.config.providerId}" returned a V2 model. ` + + 'All providers should be wrapped with wrapProvider to return V3 models.' + ) + } + + // V3 模型,使用 wrapLanguageModel 应用中间件 + return wrapLanguageModel({ + model: modelOrId, + middleware: middlewares || [] + }) } } diff --git a/packages/aiCore/src/core/runtime/pluginEngine.ts b/packages/aiCore/src/core/runtime/pluginEngine.ts index 82d2845999..3a81fa4d7b 100644 --- a/packages/aiCore/src/core/runtime/pluginEngine.ts +++ b/packages/aiCore/src/core/runtime/pluginEngine.ts @@ -1,8 +1,18 @@ /* eslint-disable @eslint-react/naming-convention/context-name */ import type { ImageModelV3 } from '@ai-sdk/provider' -import type { generateImage, generateText, LanguageModel, streamText } from 'ai' +import type { generateImage, LanguageModel } from 'ai' -import { type AiPlugin, createContext, PluginManager } from '../plugins' +import { ModelResolutionError, RecursiveDepthError } from '../errors' +import { + type AiPlugin, + type AiRequestContext, + createContext, + type GenerateTextParams, + type GenerateTextResult, + PluginManager, + type StreamTextParams, + type StreamTextResult +} from '../plugins' import { type ProviderId } from '../providers/types' /** @@ -10,21 +20,22 @@ import { type ProviderId } from '../providers/types' * 专注于插件处理,不暴露用户API */ export class PluginEngine { - private pluginManager: PluginManager + // ✅ 存储为非泛型数组(允许混合不同类型的插件) + private basePlugins: AiPlugin[] = [] constructor( private readonly providerId: T, // private readonly options: ProviderSettingsMap[T], plugins: AiPlugin[] = [] ) { - this.pluginManager = new PluginManager(plugins) + this.basePlugins = plugins } /** * 添加插件 */ use(plugin: AiPlugin): this { - this.pluginManager.use(plugin) + this.basePlugins.push(plugin) return this } @@ -32,7 +43,7 @@ export class PluginEngine { * 批量添加插件 */ usePlugins(plugins: AiPlugin[]): this { - plugins.forEach((plugin) => this.use(plugin)) + this.basePlugins.push(...plugins) return this } @@ -40,7 +51,7 @@ export class PluginEngine { * 移除插件 */ removePlugin(pluginName: string): this { - this.pluginManager.remove(pluginName) + this.basePlugins = this.basePlugins.filter((p) => p.name !== pluginName) return this } @@ -48,14 +59,16 @@ export class PluginEngine { * 获取插件统计 */ getPluginStats() { - return this.pluginManager.getStats() + // 创建临时 manager 来获取统计信息 + const tempManager = new PluginManager(this.basePlugins) + return tempManager.getStats() } /** * 获取所有插件 */ getPlugins() { - return this.pluginManager.getPlugins() + return [...this.basePlugins] } /** @@ -63,13 +76,13 @@ export class PluginEngine { * 提供给AiExecutor使用 */ async executeWithPlugins< - TParams extends Parameters[0], - TResult extends ReturnType + TParams extends GenerateTextParams, + TResult extends GenerateTextResult >( methodName: string, params: TParams, executor: (model: LanguageModel, transformedParams: TParams) => TResult, - _context?: ReturnType + _context?: AiRequestContext ): Promise { // 统一处理模型解析 let resolvedModel: LanguageModel | undefined @@ -84,54 +97,76 @@ export class PluginEngine { modelId = model.modelId } - // 使用正确的createContext创建请求上下文 - const context = _context ? _context : createContext(this.providerId, model, params) + // 创建类型安全的 context + const context = _context ?? createContext(this.providerId, model, params) - // 🔥 为上下文添加递归调用能力 - context.recursiveCall = async (newParams: any): Promise => { - // 递归调用自身,重新走完整的插件流程 - context.isRecursiveCall = true - const result = await this.executeWithPlugins(methodName, newParams, executor, context) - context.isRecursiveCall = false - return result + // ✅ 创建类型化的 manager(逆变安全) + const manager = new PluginManager( + this.basePlugins as AiPlugin[] + ) + + // ✅ 递归调用泛型化,增加深度限制 + context.recursiveCall = async (newParams: Partial): Promise => { + if (context.recursiveDepth >= context.maxRecursiveDepth) { + throw new RecursiveDepthError(context.requestId, context.recursiveDepth, context.maxRecursiveDepth) + } + + const previousDepth = context.recursiveDepth + const wasRecursive = context.isRecursiveCall + + try { + context.recursiveDepth = previousDepth + 1 + context.isRecursiveCall = true + + return await this.executeWithPlugins( + methodName, + { ...params, ...newParams } as TParams, + executor, + context + ) as unknown as R + } finally { + // ✅ finally 确保状态恢复 + context.recursiveDepth = previousDepth + context.isRecursiveCall = wasRecursive + } } try { // 0. 配置上下文 - await this.pluginManager.executeConfigureContext(context) + await manager.executeConfigureContext(context) // 1. 触发请求开始事件 - await this.pluginManager.executeParallel('onRequestStart', context) + await manager.executeParallel('onRequestStart', context) // 2. 解析模型(如果是字符串) if (typeof model === 'string') { - const resolved = await this.pluginManager.executeFirst('resolveModel', modelId, context) + const resolved = await manager.executeFirst('resolveModel', modelId, context) if (!resolved) { - throw new Error(`Failed to resolve model: ${modelId}`) + throw new ModelResolutionError(modelId, this.providerId) } resolvedModel = resolved } if (!resolvedModel) { - throw new Error(`Model resolution failed: no model available`) + throw new ModelResolutionError(modelId, this.providerId) } // 3. 转换请求参数 - const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context) + const transformedParams = await manager.executeTransformParams(params, context) // 4. 执行具体的 API 调用 const result = await executor(resolvedModel, transformedParams) // 5. 转换结果(对于非流式调用) - const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context) + const transformedResult = await manager.executeTransformResult(result, context) // 6. 触发完成事件 - await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult) + await manager.executeParallel('onRequestEnd', context, transformedResult) return transformedResult } catch (error) { // 7. 触发错误事件 - await this.pluginManager.executeParallel('onError', context, undefined, error as Error) + await manager.executeParallel('onError', context, undefined, error as Error) throw error } } @@ -147,7 +182,7 @@ export class PluginEngine { methodName: string, params: TParams, executor: (model: ImageModelV3, transformedParams: TParams) => TResult, - _context?: ReturnType + _context?: AiRequestContext ): Promise { // 统一处理模型解析 let resolvedModel: ImageModelV3 | undefined @@ -162,54 +197,76 @@ export class PluginEngine { modelId = model.modelId } - // 使用正确的createContext创建请求上下文 - const context = _context ? _context : createContext(this.providerId, model, params) + // 创建类型安全的 context + const context = _context ?? createContext(this.providerId, model, params) - // 🔥 为上下文添加递归调用能力 - context.recursiveCall = async (newParams: any): Promise => { - // 递归调用自身,重新走完整的插件流程 - context.isRecursiveCall = true - const result = await this.executeImageWithPlugins(methodName, newParams, executor, context) - context.isRecursiveCall = false - return result + // ✅ 创建类型化的 manager(逆变安全) + const manager = new PluginManager( + this.basePlugins as AiPlugin[] + ) + + // ✅ 递归调用泛型化,增加深度限制 + context.recursiveCall = async (newParams: Partial): Promise => { + if (context.recursiveDepth >= context.maxRecursiveDepth) { + throw new RecursiveDepthError(context.requestId, context.recursiveDepth, context.maxRecursiveDepth) + } + + const previousDepth = context.recursiveDepth + const wasRecursive = context.isRecursiveCall + + try { + context.recursiveDepth = previousDepth + 1 + context.isRecursiveCall = true + + return await this.executeImageWithPlugins( + methodName, + { ...params, ...newParams } as TParams, + executor, + context + ) as unknown as R + } finally { + // ✅ finally 确保状态恢复 + context.recursiveDepth = previousDepth + context.isRecursiveCall = wasRecursive + } } try { // 0. 配置上下文 - await this.pluginManager.executeConfigureContext(context) + await manager.executeConfigureContext(context) // 1. 触发请求开始事件 - await this.pluginManager.executeParallel('onRequestStart', context) + await manager.executeParallel('onRequestStart', context) // 2. 解析模型(如果是字符串) if (typeof model === 'string') { - const resolved = await this.pluginManager.executeFirst('resolveModel', modelId, context) + const resolved = await manager.executeFirst('resolveModel', modelId, context) if (!resolved) { - throw new Error(`Failed to resolve image model: ${modelId}`) + throw new ModelResolutionError(modelId, this.providerId) } resolvedModel = resolved } if (!resolvedModel) { - throw new Error(`Image model resolution failed: no model available`) + throw new ModelResolutionError(modelId, this.providerId) } // 3. 转换请求参数 - const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context) + const transformedParams = await manager.executeTransformParams(params, context) // 4. 执行具体的 API 调用 const result = await executor(resolvedModel, transformedParams) // 5. 转换结果 - const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context) + const transformedResult = await manager.executeTransformResult(result, context) // 6. 触发完成事件 - await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult) + await manager.executeParallel('onRequestEnd', context, transformedResult) return transformedResult } catch (error) { // 7. 触发错误事件 - await this.pluginManager.executeParallel('onError', context, undefined, error as Error) + await manager.executeParallel('onError', context, undefined, error as Error) throw error } } @@ -219,13 +276,13 @@ export class PluginEngine { * 提供给AiExecutor使用 */ async executeStreamWithPlugins< - TParams extends Parameters[0], - TResult extends ReturnType + TParams extends StreamTextParams, + TResult extends StreamTextResult >( methodName: string, params: TParams, executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => TResult, - _context?: ReturnType + _context?: AiRequestContext ): Promise { // 统一处理模型解析 let resolvedModel: LanguageModel | undefined @@ -240,56 +297,78 @@ export class PluginEngine { modelId = model.modelId } - // 创建请求上下文 - const context = _context ? _context : createContext(this.providerId, model, params) + // 创建类型安全的 context + const context = _context ?? createContext(this.providerId, model, params) - // 🔥 为上下文添加递归调用能力 - context.recursiveCall = async (newParams: any): Promise => { - // 递归调用自身,重新走完整的插件流程 - context.isRecursiveCall = true - const result = await this.executeStreamWithPlugins(methodName, newParams, executor, context) - context.isRecursiveCall = false - return result + // ✅ 创建类型化的 manager(逆变安全) + const manager = new PluginManager( + this.basePlugins as AiPlugin[] + ) + + // ✅ 递归调用泛型化,增加深度限制 + context.recursiveCall = async (newParams: Partial): Promise => { + if (context.recursiveDepth >= context.maxRecursiveDepth) { + throw new RecursiveDepthError(context.requestId, context.recursiveDepth, context.maxRecursiveDepth) + } + + const previousDepth = context.recursiveDepth + const wasRecursive = context.isRecursiveCall + + try { + context.recursiveDepth = previousDepth + 1 + context.isRecursiveCall = true + + return await this.executeStreamWithPlugins( + methodName, + { ...params, ...newParams } as TParams, + executor, + context + ) as unknown as R + } finally { + // ✅ finally 确保状态恢复 + context.recursiveDepth = previousDepth + context.isRecursiveCall = wasRecursive + } } try { // 0. 配置上下文 - await this.pluginManager.executeConfigureContext(context) + await manager.executeConfigureContext(context) // 1. 触发请求开始事件 - await this.pluginManager.executeParallel('onRequestStart', context) + await manager.executeParallel('onRequestStart', context) // 2. 解析模型(如果是字符串) if (typeof model === 'string') { - const resolved = await this.pluginManager.executeFirst('resolveModel', modelId, context) + const resolved = await manager.executeFirst('resolveModel', modelId, context) if (!resolved) { - throw new Error(`Failed to resolve model: ${modelId}`) + throw new ModelResolutionError(modelId, this.providerId) } resolvedModel = resolved } if (!resolvedModel) { - throw new Error(`Model resolution failed: no model available`) + throw new ModelResolutionError(modelId, this.providerId) } // 3. 转换请求参数 - const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context) + const transformedParams = await manager.executeTransformParams(params, context) // 4. 收集流转换器 - const streamTransforms = this.pluginManager.collectStreamTransforms(transformedParams, context) + const streamTransforms = manager.collectStreamTransforms(transformedParams, context) // 5. 执行流式 API 调用 const result = await executor(resolvedModel, transformedParams, streamTransforms) - const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context) + const transformedResult = await manager.executeTransformResult(result, context) // 6. 触发完成事件(注意:对于流式调用,这里触发的是开始流式响应的事件) - await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult) + await manager.executeParallel('onRequestEnd', context, transformedResult) return transformedResult } catch (error) { // 7. 触发错误事件 - await this.pluginManager.executeParallel('onError', context, undefined, error as Error) + await manager.executeParallel('onError', context, undefined, error as Error) throw error } } diff --git a/packages/aiCore/src/core/types/branded.ts b/packages/aiCore/src/core/types/branded.ts new file mode 100644 index 0000000000..06a37ca6c8 --- /dev/null +++ b/packages/aiCore/src/core/types/branded.ts @@ -0,0 +1,80 @@ +/** + * Branded Types for type-safe IDs + * + * Branded types prevent accidental misuse of primitive types (like string) + * by adding compile-time type safety without runtime overhead. + * + * @example + * ```typescript + * const modelId = ModelId('gpt-4') // ModelId type + * const requestId = RequestId('req-123') // RequestId type + * + * function processModel(id: ModelId) { ... } + * processModel(requestId) // ❌ Compile error - type mismatch + * ``` + */ + +/** + * Brand helper type + */ +type Brand = K & { readonly __brand: T } + +/** + * Model ID branded type + * Represents a unique model identifier + */ +export type ModelId = Brand + +/** + * Request ID branded type + * Represents a unique request identifier for tracing + */ +export type RequestId = Brand + +/** + * Provider ID branded type + * Represents a provider identifier (e.g., 'openai', 'anthropic') + */ +export type ProviderId = Brand + +/** + * Create a ModelId from a string + * @param id - The model identifier string + * @returns Branded ModelId + */ +export const ModelId = (id: string): ModelId => id as ModelId + +/** + * Create a RequestId from a string + * @param id - The request identifier string + * @returns Branded RequestId + */ +export const RequestId = (id: string): RequestId => id as RequestId + +/** + * Create a ProviderId from a string + * @param id - The provider identifier string + * @returns Branded ProviderId + */ +export const ProviderId = (id: string): ProviderId => id as ProviderId + +/** + * Type guard to check if a string is a valid ModelId + */ +export const isModelId = (value: unknown): value is ModelId => { + return typeof value === 'string' && value.length > 0 +} + +/** + * Type guard to check if a string is a valid RequestId + */ +export const isRequestId = (value: unknown): value is RequestId => { + return typeof value === 'string' && value.length > 0 +} + +/** + * Type guard to check if a string is a valid ProviderId + */ +export const isProviderId = (value: unknown): value is ProviderId => { + return typeof value === 'string' && value.length > 0 +} diff --git a/packages/aiCore/src/index.ts b/packages/aiCore/src/index.ts index cbad877669..015ae59856 100644 --- a/packages/aiCore/src/index.ts +++ b/packages/aiCore/src/index.ts @@ -15,19 +15,34 @@ export { } from './core/runtime' // ==================== 高级API ==================== -export { globalModelResolver as modelResolver } from './core/models' +export { isV2Model, isV3Model, globalModelResolver as modelResolver } from './core/models' // ==================== 插件系统 ==================== -export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins' +export type { + AiPlugin, + AiRequestContext, + AiRequestMetadata, + GenerateTextParams, + GenerateTextResult, + HookResult, + PluginManagerConfig, + RecursiveCallFn, + StreamTextParams, + StreamTextResult +} from './core/plugins' export { createContext, definePlugin, PluginManager } from './core/plugins' -// export { createPromptToolUsePlugin, webSearchPlugin } from './core/plugins/built-in' export { PluginEngine } from './core/runtime/pluginEngine' -// ==================== AI SDK 常用类型导出 ==================== -// 直接导出 AI SDK 的常用类型,方便使用 -export type { LanguageModelV3Middleware, LanguageModelV3StreamPart } from '@ai-sdk/provider' -export type { ToolCall } from '@ai-sdk/provider-utils' -export type { ReasoningPart } from '@ai-sdk/provider-utils' +// ==================== 类型工具 ==================== +export type { ModelId, ProviderId, RequestId } from './core/types/branded' +export { isModelId, isProviderId, isRequestId } from './core/types/branded' +// Branded type constructors (values, not types) +export type { AiSdkModel } from './core/providers' +export { + ModelId as createModelId, + ProviderId as createProviderId, + RequestId as createRequestId +} from './core/types/branded' // ==================== 选项 ==================== export { @@ -40,6 +55,17 @@ export { type TypedProviderOptions } from './core/options' +// ==================== 错误处理 ==================== +export { + AiCoreError, + ModelResolutionError, + ParameterValidationError, + PluginExecutionError, + ProviderConfigError, + RecursiveDepthError, + TemplateLoadError +} from './core/errors' + // ==================== 包信息 ==================== export const AI_CORE_VERSION = '1.0.0' export const AI_CORE_NAME = '@cherrystudio/ai-core' diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index 5c84a7254e..ea6fab734d 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -7,6 +7,7 @@ * 2. 暂时保持接口兼容性 */ +import type { AiSdkModel } from '@cherrystudio/ai-core' import { createExecutor } from '@cherrystudio/ai-core' import { loggerService } from '@logger' import { getEnableDeveloperMode } from '@renderer/hooks/useSettings' @@ -14,16 +15,14 @@ import { normalizeGatewayModels, normalizeSdkModels } from '@renderer/services/m import { addSpan, endSpan } from '@renderer/services/SpanManagerService' import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity' import { type Assistant, type GenerateImageParams, type Model, type Provider, SystemProviderIds } from '@renderer/types' -import type { AiSdkModel, StreamTextParams } from '@renderer/types/aiCoreTypes' +import type { StreamTextParams } from '@renderer/types/aiCoreTypes' import { SUPPORTED_IMAGE_ENDPOINT_LIST } from '@renderer/utils' import { buildClaudeCodeSystemModelMessage } from '@shared/anthropic' -import { gateway, type ImageModel, type LanguageModel, type Provider as AiSdkProvider, wrapLanguageModel } from 'ai' +import { gateway, type LanguageModel, type Provider as AiSdkProvider } from 'ai' import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter' import LegacyAiProvider from './legacy/index' import type { CompletionsParams, CompletionsResult } from './legacy/middleware/schemas' -import type { AiSdkMiddlewareConfig } from './middleware/AiSdkMiddlewareBuilder' -import { buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder' import { buildPlugins } from './plugins/PluginBuilder' import { createAiSdkProvider } from './provider/factory' import { @@ -34,6 +33,7 @@ import { providerToAiSdkConfig } from './provider/providerConfig' import type { AiSdkConfig } from './types' +import type { AiSdkMiddlewareConfig } from './types/middlewareConfig' const logger = loggerService.withContext('ModernAiProvider') @@ -144,16 +144,6 @@ export default class ModernAiProvider { this.localProvider = await createAiSdkProvider(this.config) } - // 提前构建中间件 - const middlewares = buildAiSdkMiddlewares({ - ...providerConfig, - provider: this.actualProvider, - assistant: providerConfig.assistant - }) - logger.debug('Built middlewares in completions', { - middlewareCount: middlewares.length, - isImageGeneration: providerConfig.isImageGenerationEndpoint - }) if (!this.localProvider) { throw new Error('Local provider not created') } @@ -164,14 +154,20 @@ export default class ModernAiProvider { model = this.localProvider.imageModel(modelId) } else { model = this.localProvider.languageModel(modelId) - // 如果有中间件,应用到语言模型上 - if (middlewares.length > 0 && typeof model === 'object') { - model = wrapLanguageModel({ model, middleware: middlewares }) - } } if (this.actualProvider.id === 'anthropic' && this.actualProvider.authType === 'oauth') { - const claudeCodeSystemMessage = buildClaudeCodeSystemModelMessage(params.system) + // 类型守卫:确保 system 是 string、Array 或 undefined + const system = params.system + let systemParam: string | Array | undefined + if (typeof system === 'string' || Array.isArray(system) || system === undefined) { + systemParam = system + } else { + // SystemModelMessage 类型,转换为 string + systemParam = undefined + } + + const claudeCodeSystemMessage = buildClaudeCodeSystemModelMessage(systemParam) params.system = undefined // 清除原有system,避免重复 params.messages = [...claudeCodeSystemMessage, ...(params.messages || [])] } @@ -543,7 +539,7 @@ export default class ModernAiProvider { const executor = createExecutor(this.config!.providerId, this.config!.options, []) const result = await executor.generateImage({ - model: this.localProvider?.imageModel(model) as ImageModel, + model: model, // 直接使用 model ID 字符串,由 executor 内部解析 ...aiSdkParams }) diff --git a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts deleted file mode 100644 index b2a796bd33..0000000000 --- a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts +++ /dev/null @@ -1,286 +0,0 @@ -import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins' -import { loggerService } from '@logger' -import { isGemini3Model, isSupportedThinkingTokenQwenModel } from '@renderer/config/models' -import type { MCPTool } from '@renderer/types' -import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types' -import type { Chunk } from '@renderer/types/chunk' -import { isOllamaProvider, isSupportEnableThinkingProvider } from '@renderer/utils/provider' -import type { LanguageModelMiddleware } from 'ai' -import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai' - -import { getAiSdkProviderId } from '../provider/factory' -import { isOpenRouterGeminiGenerateImageModel } from '../utils/image' -import { noThinkMiddleware } from './noThinkMiddleware' -import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware' -import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware' -import { qwenThinkingMiddleware } from './qwenThinkingMiddleware' -import { skipGeminiThoughtSignatureMiddleware } from './skipGeminiThoughtSignatureMiddleware' - -const logger = loggerService.withContext('AiSdkMiddlewareBuilder') - -/** - * AI SDK 中间件配置项 - */ -export interface AiSdkMiddlewareConfig { - streamOutput: boolean - onChunk?: (chunk: Chunk) => void - model?: Model - provider?: Provider - assistant?: Assistant - enableReasoning: boolean - // 是否开启提示词工具调用 - isPromptToolUse: boolean - // 是否支持工具调用 - isSupportedToolUse: boolean - // image generation endpoint - isImageGenerationEndpoint: boolean - // 是否开启内置搜索 - enableWebSearch: boolean - enableGenerateImage: boolean - enableUrlContext: boolean - mcpTools?: MCPTool[] - uiMessages?: Message[] - // 内置搜索配置 - webSearchPluginConfig?: WebSearchPluginConfig - // 知识库识别开关,默认开启 - knowledgeRecognition?: 'off' | 'on' -} - -/** - * 具名的 AI SDK 中间件 - */ -export interface NamedAiSdkMiddleware { - name: string - middleware: LanguageModelMiddleware -} - -/** - * AI SDK 中间件建造者 - * 用于根据不同条件动态构建中间件数组 - */ -export class AiSdkMiddlewareBuilder { - private middlewares: NamedAiSdkMiddleware[] = [] - - /** - * 添加具名中间件 - */ - public add(namedMiddleware: NamedAiSdkMiddleware): this { - this.middlewares.push(namedMiddleware) - return this - } - - /** - * 在指定位置插入中间件 - */ - public insertAfter(targetName: string, middleware: NamedAiSdkMiddleware): this { - const index = this.middlewares.findIndex((m) => m.name === targetName) - if (index !== -1) { - this.middlewares.splice(index + 1, 0, middleware) - } else { - logger.warn(`AiSdkMiddlewareBuilder: Middleware named '${targetName}' not found, cannot insert`) - } - return this - } - - /** - * 检查是否包含指定名称的中间件 - */ - public has(name: string): boolean { - return this.middlewares.some((m) => m.name === name) - } - - /** - * 移除指定名称的中间件 - */ - public remove(name: string): this { - this.middlewares = this.middlewares.filter((m) => m.name !== name) - return this - } - - /** - * 构建最终的中间件数组 - */ - public build(): LanguageModelMiddleware[] { - return this.middlewares.map((m) => m.middleware) - } - - /** - * 获取具名中间件数组(用于调试) - */ - public buildNamed(): NamedAiSdkMiddleware[] { - return [...this.middlewares] - } - - /** - * 清空所有中间件 - */ - public clear(): this { - this.middlewares = [] - return this - } - - /** - * 获取中间件总数 - */ - public get length(): number { - return this.middlewares.length - } -} - -/** - * 根据配置构建AI SDK中间件的工厂函数 - * 这里要注意构建顺序,因为有些中间件需要依赖其他中间件的结果 - */ -export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelMiddleware[] { - const builder = new AiSdkMiddlewareBuilder() - - // 1. 根据provider添加特定中间件 - if (config.provider) { - addProviderSpecificMiddlewares(builder, config) - } - - // 2. 根据模型类型添加特定中间件 - if (config.model) { - addModelSpecificMiddlewares(builder, config) - } - - // 3. 非流式输出时添加模拟流中间件 - if (config.streamOutput === false) { - builder.add({ - name: 'simulate-streaming', - middleware: simulateStreamingMiddleware() - }) - } - - return builder.build() -} - -const tagName = { - reasoning: 'reasoning', - think: 'think', - thought: 'thought', - seedThink: 'seed:think' -} - -function getReasoningTagName(modelId: string | undefined): string { - if (modelId?.includes('gpt-oss')) return tagName.reasoning - if (modelId?.includes('gemini')) return tagName.thought - if (modelId?.includes('seed-oss-36b')) return tagName.seedThink - return tagName.think -} - -/** - * 添加provider特定的中间件 - */ -function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: AiSdkMiddlewareConfig): void { - if (!config.provider) return - - // 根据不同provider添加特定中间件 - switch (config.provider.type) { - case 'anthropic': - // Anthropic特定中间件 - break - case 'openai': - case 'azure-openai': { - if (config.enableReasoning) { - const tagName = getReasoningTagName(config.model?.id.toLowerCase()) - builder.add({ - name: 'thinking-tag-extraction', - middleware: extractReasoningMiddleware({ tagName }) - }) - } - break - } - case 'gemini': - // Gemini特定中间件 - break - case 'aws-bedrock': { - break - } - default: - // 其他provider的通用处理 - break - } - - // OVMS+MCP's specific middleware - if (config.provider.id === 'ovms' && config.mcpTools && config.mcpTools.length > 0) { - builder.add({ - name: 'no-think', - middleware: noThinkMiddleware() - }) - } - - if (config.provider.id === SystemProviderIds.openrouter && config.enableReasoning) { - builder.add({ - name: 'openrouter-reasoning-redaction', - middleware: openrouterReasoningMiddleware() - }) - logger.debug('Added OpenRouter reasoning redaction middleware') - } -} - -/** - * 添加模型特定的中间件 - */ -function addModelSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: AiSdkMiddlewareConfig): void { - if (!config.model || !config.provider) return - - // Qwen models on providers that don't support enable_thinking parameter (like Ollama, LM Studio, NVIDIA) - // Use /think or /no_think suffix to control thinking mode - if ( - config.provider && - !isOllamaProvider(config.provider) && - isSupportedThinkingTokenQwenModel(config.model) && - !isSupportEnableThinkingProvider(config.provider) - ) { - const enableThinking = config.assistant?.settings?.reasoning_effort !== undefined - builder.add({ - name: 'qwen-thinking-control', - middleware: qwenThinkingMiddleware(enableThinking) - }) - logger.debug(`Added Qwen thinking middleware with thinking ${enableThinking ? 'enabled' : 'disabled'}`) - } - - // 可以根据模型ID或特性添加特定中间件 - // 例如:图像生成模型、多模态模型等 - if (isOpenRouterGeminiGenerateImageModel(config.model, config.provider)) { - builder.add({ - name: 'openrouter-gemini-image-generation', - middleware: openrouterGenerateImageMiddleware() - }) - } - - if (isGemini3Model(config.model)) { - const aiSdkId = getAiSdkProviderId(config.provider) - builder.add({ - name: 'skip-gemini3-thought-signature', - middleware: skipGeminiThoughtSignatureMiddleware(aiSdkId) - }) - logger.debug('Added skip Gemini3 thought signature middleware') - } -} - -/** - * 创建一个预配置的中间件建造者 - */ -export function createAiSdkMiddlewareBuilder(): AiSdkMiddlewareBuilder { - return new AiSdkMiddlewareBuilder() -} - -/** - * 创建一个带有默认中间件的建造者 - */ -export function createDefaultAiSdkMiddlewareBuilder(config: AiSdkMiddlewareConfig): AiSdkMiddlewareBuilder { - const builder = new AiSdkMiddlewareBuilder() - const defaultMiddlewares = buildAiSdkMiddlewares(config) - - // 将普通中间件数组转换为具名中间件并添加 - defaultMiddlewares.forEach((middleware, index) => { - builder.add({ - name: `default-middleware-${index}`, - middleware - }) - }) - - return builder -} diff --git a/src/renderer/src/aiCore/middleware/README.md b/src/renderer/src/aiCore/middleware/README.md deleted file mode 100644 index 7731d263c3..0000000000 --- a/src/renderer/src/aiCore/middleware/README.md +++ /dev/null @@ -1,140 +0,0 @@ -# AI SDK 中间件建造者 - -## 概述 - -`AiSdkMiddlewareBuilder` 是一个用于动态构建 AI SDK 中间件数组的建造者模式实现。它可以根据不同的条件(如流式输出、思考模型、provider类型等)自动构建合适的中间件组合。 - -## 使用方式 - -### 基本用法 - -```typescript -import { buildAiSdkMiddlewares, type AiSdkMiddlewareConfig } from './AiSdkMiddlewareBuilder' - -// 配置中间件参数 -const config: AiSdkMiddlewareConfig = { - streamOutput: false, // 非流式输出 - onChunk: chunkHandler, // chunk回调函数 - model: currentModel, // 当前模型 - provider: currentProvider, // 当前provider - enableReasoning: true, // 启用推理 - enableTool: false, // 禁用工具 - enableWebSearch: false // 禁用网页搜索 -} - -// 构建中间件数组 -const middlewares = buildAiSdkMiddlewares(config) - -// 创建带有中间件的客户端 -const client = createClient(providerId, options, middlewares) -``` - -### 手动构建 - -```typescript -import { AiSdkMiddlewareBuilder, createAiSdkMiddlewareBuilder } from './AiSdkMiddlewareBuilder' - -const builder = createAiSdkMiddlewareBuilder() - -// 添加特定中间件 -builder.add({ - name: 'custom-middleware', - aiSdkMiddlewares: [customMiddleware()] -}) - -// 检查是否包含某个中间件 -if (builder.has('thinking-time')) { - console.log('已包含思考时间中间件') -} - -// 移除不需要的中间件 -builder.remove('simulate-streaming') - -// 构建最终数组 -const middlewares = builder.build() -``` - -## 支持的条件 - -### 1. 流式输出控制 - -- **streamOutput = false**: 自动添加 `simulateStreamingMiddleware` -- **streamOutput = true**: 使用原生流式处理 - -### 2. 思考模型处理 - -- **条件**: `onChunk` 存在 && `isReasoningModel(model)` 为 true -- **效果**: 自动添加 `thinkingTimeMiddleware` - -### 3. Provider 特定中间件 - -根据不同的 provider 类型添加特定中间件: - -- **anthropic**: Anthropic 特定处理 -- **openai**: OpenAI 特定处理 -- **gemini**: Gemini 特定处理 - -### 4. 模型特定中间件 - -根据模型特性添加中间件: - -- **图像生成模型**: 添加图像处理相关中间件 -- **多模态模型**: 添加多模态处理中间件 - -## 扩展指南 - -### 添加新的条件判断 - -在 `buildAiSdkMiddlewares` 函数中添加新的条件: - -```typescript -// 例如:添加缓存中间件 -if (config.enableCache) { - builder.add({ - name: 'cache', - aiSdkMiddlewares: [cacheMiddleware(config.cacheOptions)] - }) -} -``` - -### 添加 Provider 特定处理 - -在 `addProviderSpecificMiddlewares` 函数中添加: - -```typescript -case 'custom-provider': - builder.add({ - name: 'custom-provider-middleware', - aiSdkMiddlewares: [customProviderMiddleware()] - }) - break -``` - -### 添加模型特定处理 - -在 `addModelSpecificMiddlewares` 函数中添加: - -```typescript -if (config.model.id.includes('custom-model')) { - builder.add({ - name: 'custom-model-middleware', - aiSdkMiddlewares: [customModelMiddleware()] - }) -} -``` - -## 中间件执行顺序 - -中间件按照添加顺序执行: - -1. **simulate-streaming** (如果 streamOutput = false) -2. **thinking-time** (如果是思考模型且有 onChunk) -3. **provider-specific** (根据 provider 类型) -4. **model-specific** (根据模型类型) - -## 注意事项 - -1. 中间件的执行顺序很重要,确保按正确顺序添加 -2. 避免添加冲突的中间件 -3. 某些中间件可能有依赖关系,需要确保依赖的中间件先添加 -4. 建议在开发环境下启用日志,以便调试中间件构建过程 diff --git a/src/renderer/src/aiCore/middleware/toolChoiceMiddleware.ts b/src/renderer/src/aiCore/middleware/toolChoiceMiddleware.ts deleted file mode 100644 index 7bb00aff55..0000000000 --- a/src/renderer/src/aiCore/middleware/toolChoiceMiddleware.ts +++ /dev/null @@ -1,45 +0,0 @@ -import { loggerService } from '@logger' -import type { LanguageModelMiddleware } from 'ai' - -const logger = loggerService.withContext('toolChoiceMiddleware') - -/** - * Tool Choice Middleware - * Controls tool selection strategy across multiple rounds of tool calls: - * - First round: Forces the model to call a specific tool (e.g., knowledge base search) - * - Subsequent rounds: Allows the model to automatically choose any available tool - * - * This ensures knowledge base is consulted first while still enabling MCP tools - * and other capabilities in follow-up interactions. - * - * @param forceFirstToolName - The tool name to force on the first round - * @returns LanguageModelMiddleware - */ -export function toolChoiceMiddleware(forceFirstToolName: string): LanguageModelMiddleware { - let toolCallRound = 0 - - return { - middlewareVersion: 'v2', - - transformParams: async ({ params }) => { - toolCallRound++ - - const transformedParams = { ...params } - - if (toolCallRound === 1) { - // First round: force the specified tool - logger.debug(`Round ${toolCallRound}: Forcing tool choice to '${forceFirstToolName}'`) - transformedParams.toolChoice = { - type: 'tool', - toolName: forceFirstToolName - } - } else { - // Subsequent rounds: allow automatic tool selection - logger.debug(`Round ${toolCallRound}: Using automatic tool choice`) - transformedParams.toolChoice = { type: 'auto' } - } - - return transformedParams - } - } -} diff --git a/src/renderer/src/aiCore/plugins/PluginBuilder.ts b/src/renderer/src/aiCore/plugins/PluginBuilder.ts index eb46eb7524..ca5e879ce8 100644 --- a/src/renderer/src/aiCore/plugins/PluginBuilder.ts +++ b/src/renderer/src/aiCore/plugins/PluginBuilder.ts @@ -1,11 +1,24 @@ import type { AiPlugin } from '@cherrystudio/ai-core' import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins' import { loggerService } from '@logger' +import { isGemini3Model, isSupportedThinkingTokenQwenModel } from '@renderer/config/models' import { getEnableDeveloperMode } from '@renderer/hooks/useSettings' import type { Assistant } from '@renderer/types' +import { SystemProviderIds } from '@renderer/types' +import { isOllamaProvider, isSupportEnableThinkingProvider } from '@renderer/utils/provider' -import type { AiSdkMiddlewareConfig } from '../middleware/AiSdkMiddlewareBuilder' +import { getAiSdkProviderId } from '../provider/factory' +import type { AiSdkMiddlewareConfig } from '../types/middlewareConfig' +import { isOpenRouterGeminiGenerateImageModel } from '../utils/image' +import { getReasoningTagName } from '../utils/reasoning' +import { createNoThinkPlugin } from './noThinkPlugin' +import { createOpenrouterGenerateImagePlugin } from './openrouterGenerateImagePlugin' +import { createOpenrouterReasoningPlugin } from './openrouterReasoningPlugin' +import { createQwenThinkingPlugin } from './qwenThinkingPlugin' +import { createReasoningExtractionPlugin } from './reasoningExtractionPlugin' import { searchOrchestrationPlugin } from './searchOrchestrationPlugin' +import { createSimulateStreamingPlugin } from './simulateStreamingPlugin' +import { createSkipGeminiThoughtSignaturePlugin } from './skipGeminiThoughtSignaturePlugin' import { createTelemetryPlugin } from './telemetryPlugin' const logger = loggerService.withContext('PluginBuilder') @@ -28,6 +41,59 @@ export function buildPlugins( ) } + // === AI SDK Middleware Plugins === + + // 0.1 Simulate streaming for non-streaming requests + if (!middlewareConfig.streamOutput) { + plugins.push(createSimulateStreamingPlugin()) + } + + // 0.2 Reasoning extraction for OpenAI/Azure providers + if (middlewareConfig.enableReasoning && middlewareConfig.provider) { + const providerType = middlewareConfig.provider.type + if (providerType === 'openai' || providerType === 'azure-openai') { + const tagName = getReasoningTagName(middlewareConfig.model?.id.toLowerCase()) + plugins.push(createReasoningExtractionPlugin({ tagName })) + } + } + + // 0.3 OpenRouter reasoning redaction + if (middlewareConfig.provider?.id === SystemProviderIds.openrouter && middlewareConfig.enableReasoning) { + plugins.push(createOpenrouterReasoningPlugin()) + } + + // 0.4 OVMS no-think for MCP tools + if (middlewareConfig.provider?.id === 'ovms' && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) { + plugins.push(createNoThinkPlugin()) + } + + // 0.5 Qwen thinking control for providers without enable_thinking support + if ( + middlewareConfig.provider && + middlewareConfig.model && + !isOllamaProvider(middlewareConfig.provider) && + isSupportedThinkingTokenQwenModel(middlewareConfig.model) && + !isSupportEnableThinkingProvider(middlewareConfig.provider) + ) { + const enableThinking = middlewareConfig.assistant?.settings?.reasoning_effort !== undefined + plugins.push(createQwenThinkingPlugin(enableThinking)) + } + + // 0.6 OpenRouter Gemini image generation + if ( + middlewareConfig.model && + middlewareConfig.provider && + isOpenRouterGeminiGenerateImageModel(middlewareConfig.model, middlewareConfig.provider) + ) { + plugins.push(createOpenrouterGenerateImagePlugin()) + } + + // 0.7 Skip Gemini3 thought signature + if (middlewareConfig.model && middlewareConfig.provider && isGemini3Model(middlewareConfig.model)) { + const aiSdkId = getAiSdkProviderId(middlewareConfig.provider) + plugins.push(createSkipGeminiThoughtSignaturePlugin(aiSdkId)) + } + // 1. 模型内置搜索 if (middlewareConfig.enableWebSearch && middlewareConfig.webSearchPluginConfig) { plugins.push(webSearchPlugin(middlewareConfig.webSearchPluginConfig)) diff --git a/src/renderer/src/aiCore/middleware/noThinkMiddleware.ts b/src/renderer/src/aiCore/plugins/noThinkPlugin.ts similarity index 78% rename from src/renderer/src/aiCore/middleware/noThinkMiddleware.ts rename to src/renderer/src/aiCore/plugins/noThinkPlugin.ts index 3e5624983c..919d3ab965 100644 --- a/src/renderer/src/aiCore/middleware/noThinkMiddleware.ts +++ b/src/renderer/src/aiCore/plugins/noThinkPlugin.ts @@ -1,7 +1,8 @@ +import { definePlugin } from '@cherrystudio/ai-core' import { loggerService } from '@logger' import type { LanguageModelMiddleware } from 'ai' -const logger = loggerService.withContext('noThinkMiddleware') +const logger = loggerService.withContext('noThinkPlugin') /** * No Think Middleware @@ -9,9 +10,9 @@ const logger = loggerService.withContext('noThinkMiddleware') * This prevents the model from generating unnecessary thinking process and returns results directly * @returns LanguageModelMiddleware */ -export function noThinkMiddleware(): LanguageModelMiddleware { +function createNoThinkMiddleware(): LanguageModelMiddleware { return { - middlewareVersion: 'v2', + specificationVersion: 'v3', transformParams: async ({ params }) => { const transformedParams = { ...params } @@ -50,3 +51,14 @@ export function noThinkMiddleware(): LanguageModelMiddleware { } } } + +export const createNoThinkPlugin = () => + definePlugin({ + name: 'noThink', + enforce: 'pre', + + configureContext: (context) => { + context.middlewares = context.middlewares || [] + context.middlewares.push(createNoThinkMiddleware()) + } + }) diff --git a/src/renderer/src/aiCore/middleware/openrouterGenerateImageMiddleware.ts b/src/renderer/src/aiCore/plugins/openrouterGenerateImagePlugin.ts similarity index 67% rename from src/renderer/src/aiCore/middleware/openrouterGenerateImageMiddleware.ts rename to src/renderer/src/aiCore/plugins/openrouterGenerateImagePlugin.ts index 792192b931..e446b5fd86 100644 --- a/src/renderer/src/aiCore/middleware/openrouterGenerateImageMiddleware.ts +++ b/src/renderer/src/aiCore/plugins/openrouterGenerateImagePlugin.ts @@ -1,3 +1,4 @@ +import { definePlugin } from '@cherrystudio/ai-core' import type { LanguageModelMiddleware } from 'ai' /** @@ -6,7 +7,7 @@ import type { LanguageModelMiddleware } from 'ai' * https://openrouter.ai/docs/features/multimodal/image-generation * * Remarks: - * - The middleware declares middlewareVersion as 'v2'. + * - The middleware declares specificationVersion as 'v3'. * - transformParams asynchronously clones the incoming params and sets * providerOptions.openrouter.modalities = ['image', 'text'], preserving other providerOptions and * openrouter fields when present. @@ -15,9 +16,9 @@ import type { LanguageModelMiddleware } from 'ai' * * @returns LanguageModelMiddleware - a middleware that augments providerOptions for OpenRouter to include image and text modalities. */ -export function openrouterGenerateImageMiddleware(): LanguageModelMiddleware { +function createOpenrouterGenerateImageMiddleware(): LanguageModelMiddleware { return { - middlewareVersion: 'v2', + specificationVersion: 'v3', transformParams: async ({ params }) => { const transformedParams = { ...params } @@ -25,9 +26,19 @@ export function openrouterGenerateImageMiddleware(): LanguageModelMiddleware { ...transformedParams.providerOptions, openrouter: { ...transformedParams.providerOptions?.openrouter, modalities: ['image', 'text'] } } - transformedParams return transformedParams } } } + +export const createOpenrouterGenerateImagePlugin = () => + definePlugin({ + name: 'openrouterGenerateImage', + enforce: 'pre', + + configureContext: (context) => { + context.middlewares = context.middlewares || [] + context.middlewares.push(createOpenrouterGenerateImageMiddleware()) + } + }) diff --git a/src/renderer/src/aiCore/middleware/openrouterReasoningMiddleware.ts b/src/renderer/src/aiCore/plugins/openrouterReasoningPlugin.ts similarity index 67% rename from src/renderer/src/aiCore/middleware/openrouterReasoningMiddleware.ts rename to src/renderer/src/aiCore/plugins/openrouterReasoningPlugin.ts index 9ef3df61e9..7363ccc761 100644 --- a/src/renderer/src/aiCore/middleware/openrouterReasoningMiddleware.ts +++ b/src/renderer/src/aiCore/plugins/openrouterReasoningPlugin.ts @@ -1,4 +1,5 @@ -import type { LanguageModelV2StreamPart } from '@ai-sdk/provider' +import type { LanguageModelV3StreamPart } from '@ai-sdk/provider' +import { definePlugin } from '@cherrystudio/ai-core' import type { LanguageModelMiddleware } from 'ai' /** @@ -6,10 +7,10 @@ import type { LanguageModelMiddleware } from 'ai' * * @returns LanguageModelMiddleware - a middleware filter redacted block */ -export function openrouterReasoningMiddleware(): LanguageModelMiddleware { +function createOpenrouterReasoningMiddleware(): LanguageModelMiddleware { const REDACTED_BLOCK = '[REDACTED]' return { - middlewareVersion: 'v2', + specificationVersion: 'v3', wrapGenerate: async ({ doGenerate }) => { const { content, ...rest } = await doGenerate() const modifiedContent = content.map((part) => { @@ -27,10 +28,10 @@ export function openrouterReasoningMiddleware(): LanguageModelMiddleware { const { stream, ...rest } = await doStream() return { stream: stream.pipeThrough( - new TransformStream({ + new TransformStream({ transform( - chunk: LanguageModelV2StreamPart, - controller: TransformStreamDefaultController + chunk: LanguageModelV3StreamPart, + controller: TransformStreamDefaultController ) { if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) { controller.enqueue({ @@ -48,3 +49,14 @@ export function openrouterReasoningMiddleware(): LanguageModelMiddleware { } } } + +export const createOpenrouterReasoningPlugin = () => + definePlugin({ + name: 'openrouterReasoning', + enforce: 'pre', + + configureContext: (context) => { + context.middlewares = context.middlewares || [] + context.middlewares.push(createOpenrouterReasoningMiddleware()) + } + }) diff --git a/src/renderer/src/aiCore/middleware/qwenThinkingMiddleware.ts b/src/renderer/src/aiCore/plugins/qwenThinkingPlugin.ts similarity index 72% rename from src/renderer/src/aiCore/middleware/qwenThinkingMiddleware.ts rename to src/renderer/src/aiCore/plugins/qwenThinkingPlugin.ts index 931831a1c6..44bcdeffac 100644 --- a/src/renderer/src/aiCore/middleware/qwenThinkingMiddleware.ts +++ b/src/renderer/src/aiCore/plugins/qwenThinkingPlugin.ts @@ -1,3 +1,4 @@ +import { definePlugin } from '@cherrystudio/ai-core' import type { LanguageModelMiddleware } from 'ai' /** @@ -7,11 +8,11 @@ import type { LanguageModelMiddleware } from 'ai' * @param enableThinking - Whether thinking mode is enabled (based on reasoning_effort !== undefined) * @returns LanguageModelMiddleware */ -export function qwenThinkingMiddleware(enableThinking: boolean): LanguageModelMiddleware { +function createQwenThinkingMiddleware(enableThinking: boolean): LanguageModelMiddleware { const suffix = enableThinking ? ' /think' : ' /no_think' return { - middlewareVersion: 'v2', + specificationVersion: 'v3', transformParams: async ({ params }) => { const transformedParams = { ...params } @@ -37,3 +38,14 @@ export function qwenThinkingMiddleware(enableThinking: boolean): LanguageModelMi } } } + +export const createQwenThinkingPlugin = (enableThinking: boolean) => + definePlugin({ + name: 'qwenThinking', + enforce: 'pre', + + configureContext: (context) => { + context.middlewares = context.middlewares || [] + context.middlewares.push(createQwenThinkingMiddleware(enableThinking)) + } + }) diff --git a/src/renderer/src/aiCore/plugins/reasoningExtractionPlugin.ts b/src/renderer/src/aiCore/plugins/reasoningExtractionPlugin.ts new file mode 100644 index 0000000000..ec9ee0f89b --- /dev/null +++ b/src/renderer/src/aiCore/plugins/reasoningExtractionPlugin.ts @@ -0,0 +1,22 @@ +import { definePlugin } from '@cherrystudio/ai-core' +import { extractReasoningMiddleware } from 'ai' + +/** + * Reasoning Extraction Plugin + * Extracts reasoning/thinking tags from OpenAI/Azure responses + * Uses AI SDK's built-in extractReasoningMiddleware + */ +export const createReasoningExtractionPlugin = (options: { tagName?: string } = {}) => + definePlugin({ + name: 'reasoningExtraction', + enforce: 'pre', + + configureContext: (context) => { + context.middlewares = context.middlewares || [] + context.middlewares.push( + extractReasoningMiddleware({ + tagName: options.tagName || 'thinking' + }) + ) + } + }) diff --git a/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts index 5b095a4461..476802bd07 100644 --- a/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts +++ b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts @@ -6,7 +6,13 @@ * 2. transformParams: 根据意图分析结果动态添加对应的工具 * 3. onRequestEnd: 自动记忆存储 */ -import { type AiRequestContext, definePlugin } from '@cherrystudio/ai-core' +import { + type AiPlugin, + type AiRequestContext, + definePlugin, + type StreamTextParams, + type StreamTextResult +} from '@cherrystudio/ai-core' import { loggerService } from '@logger' // import { generateObject } from '@cherrystudio/ai-core' import { @@ -236,18 +242,21 @@ async function storeConversationMemory( /** * 🎯 搜索编排插件 */ -export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string) => { +export const searchOrchestrationPlugin = ( + assistant: Assistant, + topicId: string +): AiPlugin => { // 存储意图分析结果 const intentAnalysisResults: { [requestId: string]: ExtractResults } = {} const userMessages: { [requestId: string]: ModelMessage } = {} - return definePlugin({ + return definePlugin({ name: 'search-orchestration', enforce: 'pre', // 确保在其他插件之前执行 /** * 🔍 Step 1: 意图识别阶段 */ - onRequestStart: async (context: AiRequestContext) => { + onRequestStart: async (context) => { // 没开启任何搜索则不进行意图分析 if (!(assistant.webSearchProviderId || assistant.knowledge_bases?.length || assistant.enableMemory)) return @@ -297,7 +306,7 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string) /** * 🔧 Step 2: 工具配置阶段 */ - transformParams: async (params: any, context: AiRequestContext) => { + transformParams: async (params, context) => { // logger.info('🔧 Configuring tools based on intent...', context.requestId) try { @@ -309,7 +318,7 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string) // 确保 tools 对象存在 if (!params.tools) { - params.tools = {} + return { tools: {} } } // 🌐 网络搜索工具配置 @@ -371,11 +380,12 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string) * 💾 Step 3: 记忆存储阶段 */ - onRequestEnd: async (context: AiRequestContext) => { + onRequestEnd: async (context) => { // context.isAnalyzing = false // logger.info('context.isAnalyzing', context, result) // logger.info('💾 Starting memory storage...', context.requestId) try { + // ✅ 类型安全访问:context.originalParams 已通过泛型正确类型化 const messages = context.originalParams.messages if (messages && assistant) { diff --git a/src/renderer/src/aiCore/plugins/simulateStreamingPlugin.ts b/src/renderer/src/aiCore/plugins/simulateStreamingPlugin.ts new file mode 100644 index 0000000000..fc4c35ddc8 --- /dev/null +++ b/src/renderer/src/aiCore/plugins/simulateStreamingPlugin.ts @@ -0,0 +1,18 @@ +import { definePlugin } from '@cherrystudio/ai-core' +import { simulateStreamingMiddleware } from 'ai' + +/** + * Simulate Streaming Plugin + * Converts non-streaming responses to streaming format + * Uses AI SDK's built-in simulateStreamingMiddleware + */ +export const createSimulateStreamingPlugin = () => + definePlugin({ + name: 'simulateStreaming', + enforce: 'pre', + + configureContext: (context) => { + context.middlewares = context.middlewares || [] + context.middlewares.push(simulateStreamingMiddleware()) + } + }) diff --git a/src/renderer/src/aiCore/middleware/skipGeminiThoughtSignatureMiddleware.ts b/src/renderer/src/aiCore/plugins/skipGeminiThoughtSignaturePlugin.ts similarity index 71% rename from src/renderer/src/aiCore/middleware/skipGeminiThoughtSignatureMiddleware.ts rename to src/renderer/src/aiCore/plugins/skipGeminiThoughtSignaturePlugin.ts index da318ea60d..0264c77fad 100644 --- a/src/renderer/src/aiCore/middleware/skipGeminiThoughtSignatureMiddleware.ts +++ b/src/renderer/src/aiCore/plugins/skipGeminiThoughtSignaturePlugin.ts @@ -1,3 +1,4 @@ +import { definePlugin } from '@cherrystudio/ai-core' import type { LanguageModelMiddleware } from 'ai' /** @@ -8,10 +9,10 @@ import type { LanguageModelMiddleware } from 'ai' * @param aiSdkId AI SDK Provider ID * @returns LanguageModelMiddleware */ -export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelMiddleware { +function createSkipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelMiddleware { const MAGIC_STRING = 'skip_thought_signature_validator' return { - middlewareVersion: 'v2', + specificationVersion: 'v3', transformParams: async ({ params }) => { const transformedParams = { ...params } @@ -34,3 +35,14 @@ export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageM } } } + +export const createSkipGeminiThoughtSignaturePlugin = (aiSdkId: string) => + definePlugin({ + name: 'skipGeminiThoughtSignature', + enforce: 'pre', + + configureContext: (context) => { + context.middlewares = context.middlewares || [] + context.middlewares.push(createSkipGeminiThoughtSignatureMiddleware(aiSdkId)) + } + }) diff --git a/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts index 52234c5f1f..3db524616a 100644 --- a/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts +++ b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts @@ -47,7 +47,7 @@ import { getMaxTokens, getTemperature, getTopP } from './modelParameters' const logger = loggerService.withContext('parameterBuilder') -type ProviderDefinedTool = Extract, { type: 'provider-defined' }> +type ProviderDefinedTool = Extract, { type: 'provider' }> function mapVertexAIGatewayModelToProviderId(model: Model): BaseProviderId | undefined { if (isAnthropicModel(model)) { @@ -62,9 +62,7 @@ function mapVertexAIGatewayModelToProviderId(model: Model): BaseProviderId | und if (isOpenAIModel(model)) { return 'openai' } - logger.warn( - `[mapVertexAIGatewayModelToProviderId] Unknown model type for AI Gateway: ${model.id}. Web search will not be enabled.` - ) + logger.warn(`Unknown model type for AI Gateway: ${model.id}. Web search will not be enabled.`) return undefined } diff --git a/src/renderer/src/aiCore/types/middlewareConfig.ts b/src/renderer/src/aiCore/types/middlewareConfig.ts new file mode 100644 index 0000000000..94743cb140 --- /dev/null +++ b/src/renderer/src/aiCore/types/middlewareConfig.ts @@ -0,0 +1,26 @@ +import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins' +import type { MCPTool } from '@renderer/types' +import type { Assistant, Message, Model, Provider } from '@renderer/types' +import type { Chunk } from '@renderer/types/chunk' + +/** + * AI SDK 中间件配置项(用于插件构建) + */ +export interface AiSdkMiddlewareConfig { + streamOutput: boolean + onChunk?: (chunk: Chunk) => void + model?: Model + provider?: Provider + assistant?: Assistant + enableReasoning: boolean + isPromptToolUse: boolean + isSupportedToolUse: boolean + isImageGenerationEndpoint: boolean + enableWebSearch: boolean + enableGenerateImage: boolean + enableUrlContext: boolean + mcpTools?: MCPTool[] + uiMessages?: Message[] + webSearchPluginConfig?: WebSearchPluginConfig + knowledgeRecognition?: 'off' | 'on' +} diff --git a/src/renderer/src/aiCore/utils/reasoning.ts b/src/renderer/src/aiCore/utils/reasoning.ts index ab8a0b7983..0cb82809f6 100644 --- a/src/renderer/src/aiCore/utils/reasoning.ts +++ b/src/renderer/src/aiCore/utils/reasoning.ts @@ -726,3 +726,21 @@ export function getCustomParameters(assistant: Assistant): Record { }, {}) || {} ) } + +/** + * Get reasoning tag name based on model ID + * Used for extractReasoningMiddleware configuration + */ +export function getReasoningTagName(modelId: string | undefined): string { + const tagName = { + reasoning: 'reasoning', + think: 'think', + thought: 'thought', + seedThink: 'seed:think' + } + + if (modelId?.includes('gpt-oss')) return tagName.reasoning + if (modelId?.includes('gemini')) return tagName.thought + if (modelId?.includes('seed-oss-36b')) return tagName.seedThink + return tagName.think +} diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index 0cd57a353a..f5908becd4 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -2,8 +2,8 @@ * 职责:提供原子化的、无状态的API调用函数 */ import { loggerService } from '@logger' -import type { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/AiSdkMiddlewareBuilder' import { buildStreamTextParams } from '@renderer/aiCore/prepareParams' +import type { AiSdkMiddlewareConfig } from '@renderer/aiCore/types/middlewareConfig' import { isDedicatedImageGenerationModel, isEmbeddingModel, isFunctionCallingModel } from '@renderer/config/models' import { getStoreSetting } from '@renderer/hooks/useSettings' import i18n from '@renderer/i18n' diff --git a/src/renderer/src/types/aiCoreTypes.ts b/src/renderer/src/types/aiCoreTypes.ts index 28250e4053..19a8d67627 100644 --- a/src/renderer/src/types/aiCoreTypes.ts +++ b/src/renderer/src/types/aiCoreTypes.ts @@ -1,9 +1,14 @@ import type OpenAI from '@cherrystudio/openai' import type { NotUndefined } from '@types' -import type { ImageModel, LanguageModel } from 'ai' -import type { generateObject, generateText, ModelMessage, streamObject, streamText } from 'ai' +import type { generateText, ModelMessage, streamText } from 'ai' import * as z from 'zod' +/** + * 渲染器侧参数类型(不包含 model 和 messages,因为它们会单独处理) + * 注意:这与 @cherrystudio/ai-core 导出的完整参数类型不同 + * - @cherrystudio/ai-core 的 StreamTextParams: 完整的 AI SDK 参数(用于插件系统) + * - 此处的 StreamTextParams: 去除 model/messages 的参数(用于渲染器参数构建) + */ export type StreamTextParams = Omit[0], 'model' | 'messages'> & ( | { @@ -15,6 +20,11 @@ export type StreamTextParams = Omit[0], 'model' | prompt?: never } ) + +/** + * 渲染器侧参数类型(不包含 model 和 messages) + * 注意:这与 @cherrystudio/ai-core 导出的完整参数类型不同 + */ export type GenerateTextParams = Omit[0], 'model' | 'messages'> & ( | { @@ -26,10 +36,6 @@ export type GenerateTextParams = Omit[0], 'model prompt?: never } ) -export type StreamObjectParams = Omit[0], 'model'> -export type GenerateObjectParams = Omit[0], 'model'> - -export type AiSdkModel = LanguageModel | ImageModel /** * Constrains the verbosity of the model's response. Lower values will result in more concise responses, while higher values will result in more verbose responses. diff --git a/yarn.lock b/yarn.lock index 01a7e5f2d9..742adfc097 100644 --- a/yarn.lock +++ b/yarn.lock @@ -205,6 +205,18 @@ __metadata: languageName: node linkType: hard +"@ai-sdk/openai-compatible@npm:1.0.28": + version: 1.0.28 + resolution: "@ai-sdk/openai-compatible@npm:1.0.28" + dependencies: + "@ai-sdk/provider": "npm:2.0.0" + "@ai-sdk/provider-utils": "npm:3.0.18" + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + checksum: 10c0/f484774e0094a12674f392d925038a296191723b4c76bd833eabf1b334cf3c84fe77a2e2c5fbac974ec5e18340e113c6a81c86d957c9529a7a60e87cd390ada8 + languageName: node + linkType: hard + "@ai-sdk/openai-compatible@npm:2.0.0, @ai-sdk/openai-compatible@npm:^2.0.0": version: 2.0.0 resolution: "@ai-sdk/openai-compatible@npm:2.0.0" @@ -217,15 +229,15 @@ __metadata: languageName: node linkType: hard -"@ai-sdk/openai-compatible@npm:^1.0.19": - version: 1.0.27 - resolution: "@ai-sdk/openai-compatible@npm:1.0.27" +"@ai-sdk/openai-compatible@patch:@ai-sdk/openai-compatible@npm%3A1.0.28#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch": + version: 1.0.28 + resolution: "@ai-sdk/openai-compatible@patch:@ai-sdk/openai-compatible@npm%3A1.0.28#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch::version=1.0.28&hash=f2cb20" dependencies: "@ai-sdk/provider": "npm:2.0.0" - "@ai-sdk/provider-utils": "npm:3.0.17" + "@ai-sdk/provider-utils": "npm:3.0.18" peerDependencies: zod: ^3.25.76 || ^4.1.8 - checksum: 10c0/9f656e4f2ea4d714dc05be588baafd962b2e0360e9195fef373e745efeb20172698ea87e1033c0c5e1f1aa6e0db76a32629427bc8433eb42bd1a0ee00e04af0c + checksum: 10c0/0b1d99fe8ce506e5c0a3703ae0511ac2017781584074d41faa2df82923c64eb1229ffe9f036de150d0248923613c761a463fe89d5923493983e0463a1101e792 languageName: node linkType: hard @@ -277,16 +289,16 @@ __metadata: languageName: node linkType: hard -"@ai-sdk/provider-utils@npm:3.0.17, @ai-sdk/provider-utils@npm:^3.0.10, @ai-sdk/provider-utils@npm:^3.0.17": - version: 3.0.17 - resolution: "@ai-sdk/provider-utils@npm:3.0.17" +"@ai-sdk/provider-utils@npm:3.0.18": + version: 3.0.18 + resolution: "@ai-sdk/provider-utils@npm:3.0.18" dependencies: "@ai-sdk/provider": "npm:2.0.0" "@standard-schema/spec": "npm:^1.0.0" eventsource-parser: "npm:^3.0.6" peerDependencies: zod: ^3.25.76 || ^4.1.8 - checksum: 10c0/1bae6dc4cacd0305b6aa152f9589bbd61c29f150155482c285a77f83d7ed416d52bc2aa7fdaba2e5764530392d9e8f799baea34a63dce6c72ecd3de364dc62d1 + checksum: 10c0/209c15b0dceef0ba95a7d3de544be0a417ad4a0bd5143496b3966a35fedf144156d93a42ff8c3d7db56781b9836bafc8c132c98978c49240e55bc1a36e18a67f languageName: node linkType: hard @@ -316,6 +328,19 @@ __metadata: languageName: node linkType: hard +"@ai-sdk/provider-utils@npm:^3.0.10, @ai-sdk/provider-utils@npm:^3.0.17": + version: 3.0.17 + resolution: "@ai-sdk/provider-utils@npm:3.0.17" + dependencies: + "@ai-sdk/provider": "npm:2.0.0" + "@standard-schema/spec": "npm:^1.0.0" + eventsource-parser: "npm:^3.0.6" + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + checksum: 10c0/1bae6dc4cacd0305b6aa152f9589bbd61c29f150155482c285a77f83d7ed416d52bc2aa7fdaba2e5764530392d9e8f799baea34a63dce6c72ecd3de364dc62d1 + languageName: node + linkType: hard + "@ai-sdk/provider@npm:2.0.0, @ai-sdk/provider@npm:^2.0.0": version: 2.0.0 resolution: "@ai-sdk/provider@npm:2.0.0"