diff --git a/packages/aiCore/src/core/runtime/executor.ts b/packages/aiCore/src/core/runtime/executor.ts
index 9c6c75ef47..cb8170f568 100644
--- a/packages/aiCore/src/core/runtime/executor.ts
+++ b/packages/aiCore/src/core/runtime/executor.ts
@@ -5,6 +5,7 @@
 import type { ImageModelV3, LanguageModelV3, LanguageModelV3Middleware, ProviderV3 } from '@ai-sdk/provider'
 import type { LanguageModel } from 'ai'
 import {
+  embedMany as _embedMany,
   generateImage as _generateImage,
   generateText as _generateText,
   streamText as _streamText,
@@ -17,7 +18,14 @@
 import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
 import type { CoreProviderSettingsMap, StringKeys } from '../providers/types'
 import { ImageGenerationError, ImageModelResolutionError } from './errors'
 import { PluginEngine } from './pluginEngine'
-import type { generateImageParams, generateTextParams, RuntimeConfig, streamTextParams } from './types'
+import type {
+  EmbedManyParams,
+  EmbedManyResult,
+  generateImageParams,
+  generateTextParams,
+  RuntimeConfig,
+  streamTextParams
+} from './types'
 
 export class RuntimeExecutor<
   TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
@@ -166,6 +174,23 @@
     }
   }
 
+  /**
+   * Embed a batch of texts
+   * AI SDK v6 only provides embedMany; there is no single-value embed
+   */
+  async embedMany(params: EmbedManyParams): Promise<EmbedManyResult> {
+    const { model: modelOrId, ...options } = params
+
+    // Resolve the embedding model
+    const embeddingModel =
+      typeof modelOrId === 'string' ? await this.modelResolver.resolveEmbeddingModel(modelOrId) : modelOrId
+
+    return _embedMany({
+      model: embeddingModel,
+      ...options
+    })
+  }
+
   // === 辅助方法 ===
 
   /**
diff --git a/packages/aiCore/src/core/runtime/index.ts b/packages/aiCore/src/core/runtime/index.ts
index 10e4152581..e7f9df6d90 100644
--- a/packages/aiCore/src/core/runtime/index.ts
+++ b/packages/aiCore/src/core/runtime/index.ts
@@ -7,7 +7,7 @@
 export { RuntimeExecutor } from './executor'
 
 // 导出类型
-export type { RuntimeConfig } from './types'
+export type { EmbedManyParams, EmbedManyResult, RuntimeConfig } from './types'
 
 // === 便捷工厂函数 ===
@@ -84,6 +84,23 @@
   return executor.generateImage(params)
 }
 
+/**
+ * Embed a batch of texts directly
+ * AI SDK v6 only provides embedMany; there is no single-value embed
+ */
+export async function embedMany<
+  TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
+  T extends StringKeys<TSettingsMap> = StringKeys<TSettingsMap>
+>(
+  providerId: T,
+  options: TSettingsMap[T],
+  params: Parameters<RuntimeExecutor<TSettingsMap, T>['embedMany']>[0],
+  plugins?: AiPlugin[]
+): Promise<ReturnType<RuntimeExecutor<TSettingsMap, T>['embedMany']>> {
+  const executor = await createExecutor(providerId, options, plugins)
+  return executor.embedMany(params)
+}
+
 /**
  * 创建 OpenAI Compatible 执行器
  */
diff --git a/packages/aiCore/src/core/runtime/types.ts b/packages/aiCore/src/core/runtime/types.ts
index 22a0ec4af8..ee005a55b9 100644
--- a/packages/aiCore/src/core/runtime/types.ts
+++ b/packages/aiCore/src/core/runtime/types.ts
@@ -1,8 +1,8 @@
 /**
  * Runtime 层类型定义
  */
-import type { ImageModelV3, ProviderV3 } from '@ai-sdk/provider'
-import type { generateImage, generateText, streamText } from 'ai'
+import type { EmbeddingModelV3, ImageModelV3, ProviderV3 } from '@ai-sdk/provider'
+import type { embedMany, generateImage, generateText, streamText } from 'ai'
 
 import { type AiPlugin } from '../plugins'
 import type { CoreProviderSettingsMap, StringKeys } from '../providers/types'
@@ -28,3 +28,9 @@ export type generateImageParams = Omit<Parameters<typeof generateImage>[0], 'mod
 }
 export type generateTextParams = Parameters<typeof generateText>[0]
 export type streamTextParams = Parameters<typeof streamText>[0]
+
+// Embedding types (AI SDK v6 only has embedMany, no embed)
+export type EmbedManyParams = Omit<Parameters<typeof embedMany>[0], 'model'> & {
+  model: string | EmbeddingModelV3
+}
+export type EmbedManyResult = Awaited<ReturnType<typeof embedMany>>
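Reviewer note: a minimal usage sketch of the new embedMany surface, to make the API shape above concrete. The import path, provider id, settings shape, and model id below are illustrative assumptions and not part of this diff; the parameter shape follows the EmbedManyParams type added in types.ts.

```ts
// Hypothetical usage sketch (package name and provider settings assumed)
import { embedMany } from '@cherrystudio/ai-core'

const result = await embedMany(
  'openai', // any provider id registered in CoreProviderSettingsMap
  { apiKey: process.env.OPENAI_API_KEY! }, // settings shape is provider-specific
  {
    model: 'text-embedding-3-small', // string ids are resolved via modelResolver.resolveEmbeddingModel
    values: ['hello world', 'cherry studio'] // AI SDK v6 batches everything through embedMany
  }
)

console.log(result.embeddings.length) // 2, one vector per input value
console.log(result.embeddings[0].length) // the model's embedding dimension
```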
diff --git a/packages/aiCore/src/index.ts b/packages/aiCore/src/index.ts
index 9cff910c20..65e2ab6c8a 100644
--- a/packages/aiCore/src/index.ts
+++ b/packages/aiCore/src/index.ts
@@ -9,11 +9,15 @@
 export {
   createExecutor,
   createOpenAICompatibleExecutor,
+  embedMany,
   generateImage,
   generateText,
   streamText
 } from './core/runtime'
 
+// ==================== Embedding types ====================
+export type { EmbedManyParams, EmbedManyResult } from './core/runtime'
+
 // ==================== 高级API ====================
 export { isV2Model, isV3Model } from './core/models'
diff --git a/src/renderer/src/aiCore/index.ts b/src/renderer/src/aiCore/index.ts
index eb68da74ea..ef916b3c3a 100644
--- a/src/renderer/src/aiCore/index.ts
+++ b/src/renderer/src/aiCore/index.ts
@@ -1,16 +1,12 @@
 /**
  * Cherry Studio AI Core - 统一入口点
  *
- * 这是新的统一入口,保持向后兼容性
- * 默认导出legacy AiProvider以保持现有代码的兼容性
+ * ModernAiProvider is now the default export
+ * The legacy provider has been removed
  */
-// 导出Legacy AiProvider作为默认导出(保持向后兼容)
-export { default } from './legacy/index'
+import ModernAiProvider from './index_new'
-
-// 同时导出Modern AiProvider供新代码使用
-export { default as ModernAiProvider } from './index_new'
-
-// 导出一些常用的类型和工具
-export * from './legacy/clients/types'
-export * from './legacy/middleware/schemas'
+
+// Default export and named export
+export default ModernAiProvider
+export { ModernAiProvider }
diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts
index 67d0778448..ec0a353ad7 100644
--- a/src/renderer/src/aiCore/index_new.ts
+++ b/src/renderer/src/aiCore/index_new.ts
@@ -27,12 +27,10 @@
 import { buildClaudeCodeSystemModelMessage } from '@shared/anthropic'
 import { gateway } from 'ai'
 
 import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
-import LegacyAiProvider from './legacy/index'
-import type { CompletionsParams, CompletionsResult } from './legacy/middleware/schemas'
 import { buildPlugins } from './plugins/PluginBuilder'
 import { adaptProvider, getActualProvider, providerToAiSdkConfig } from './provider/providerConfig'
 import { ModelListService } from './services/ModelListService'
-import type { AppProviderSettingsMap, ProviderConfig } from './types'
+import type { AppProviderSettingsMap, CompletionsResult, ProviderConfig } from './types'
 import type { AiSdkMiddlewareConfig } from './types/middlewareConfig'
 
 const logger = loggerService.withContext('ModernAiProvider')
@@ -45,7 +43,6 @@ export type ModernAiProviderConfig = AiSdkMiddlewareConfig & {
 }
 
 export default class ModernAiProvider {
-  private legacyProvider: LegacyAiProvider
   private config?: ProviderConfig
   private actualProvider: Provider
   private model?: Model
@@ -101,8 +98,6 @@
       this.actualProvider = adaptProvider({ provider: modelOrProvider })
       // model为可选,某些操作(如fetchModels)不需要model
     }
-
-    this.legacyProvider = new LegacyAiProvider(this.actualProvider)
   }
 
   /**
@@ -169,28 +164,8 @@
     middlewareConfig: ModernAiProviderConfig,
     providerConfig: ProviderConfig
   ): Promise<CompletionsResult> {
-    // ai-gateway不是image/generation 端点,所以就先不走legacy了
-    if (middlewareConfig.isImageGenerationEndpoint && this.getActualProvider().id !== SystemProviderIds.gateway) {
-      // 使用 legacy 实现处理图像生成(支持图片编辑等高级功能)
-      if (!middlewareConfig.uiMessages) {
-        throw new Error('uiMessages is required for image generation endpoint')
-      }
-
-      const legacyParams: CompletionsParams = {
-        callType: 'chat',
-        messages: middlewareConfig.uiMessages, // 使用原始的 UI 消息格式
-        assistant: middlewareConfig.assistant,
-        streamOutput: middlewareConfig.streamOutput ?? true,
-        onChunk: middlewareConfig.onChunk,
-        topicId: middlewareConfig.topicId,
-        mcpTools: middlewareConfig.mcpTools,
-        enableWebSearch: middlewareConfig.enableWebSearch
-      }
-
-      // 调用 legacy 的 completions,会自动使用 ImageGenerationMiddleware
-      return await this.legacyProvider.completions(legacyParams)
-    }
-
+    // Dedicated image-generation models are already routed to fetchImageGeneration at the ApiService layer
+    // Only plain completions are handled here
     return await this.modernCompletions(modelId, params, middlewareConfig, providerConfig)
   }
@@ -350,8 +325,29 @@
     return await ModelListService.listModels(this.actualProvider)
   }
 
+  /**
+   * Get the dimension of an embedding model
+   * Probes it with a test embedMany call through the AI SDK
+   */
   public async getEmbeddingDimensions(model: Model): Promise<number> {
-    return this.legacyProvider.getEmbeddingDimensions(model)
+    // Make sure config is initialized
+    if (!this.config) {
+      this.config = await Promise.resolve(providerToAiSdkConfig(this.actualProvider, model))
+    }
+
+    const executor = await createExecutor(
+      this.config.providerId,
+      this.config.providerSettings,
+      []
+    )
+
+    // Run a test embedMany call through the AI SDK to obtain the dimension
+    const result = await executor.embedMany({
+      model: model.id,
+      values: ['test']
+    })
+
+    return result.embeddings[0].length
   }
 
   /**
@@ -453,10 +449,42 @@
   }
 
   public getBaseURL(): string {
-    return this.legacyProvider.getBaseURL()
+    return this.actualProvider.apiHost || ''
   }
 
   public getApiKey(): string {
-    return this.legacyProvider.getApiKey()
+    const apiKey = this.actualProvider.apiKey
+    if (!apiKey || apiKey.trim() === '') {
+      return ''
+    }
+
+    const keys = apiKey
+      .split(',')
+      .map((key) => key.trim())
+      .filter(Boolean)
+
+    if (keys.length === 0) {
+      return ''
+    }
+
+    if (keys.length === 1) {
+      return keys[0]
+    }
+
+    // Multi-key rotation
+    const keyName = `provider:${this.actualProvider.id}:last_used_key`
+    const lastUsedKey = window.keyv.get(keyName)
+
+    if (!lastUsedKey) {
+      window.keyv.set(keyName, keys[0])
+      return keys[0]
+    }
+
+    const currentIndex = keys.indexOf(lastUsedKey)
+    const nextIndex = (currentIndex + 1) % keys.length
+    const nextKey = keys[nextIndex]
+    window.keyv.set(keyName, nextKey)
+
+    return nextKey
+  }
 }
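Reviewer note on the getApiKey change: it inlines the legacy client's round-robin rotation over comma-separated keys, persisting the last-used key in window.keyv. The same logic as a standalone sketch, where the function name is hypothetical and a plain Map stands in for the keyv store:

```ts
// Round-robin key rotation, mirroring the getApiKey implementation above
function nextApiKey(providerId: string, rawKeys: string, store: Map<string, string>): string {
  const keys = rawKeys.split(',').map((k) => k.trim()).filter(Boolean)
  if (keys.length === 0) return ''
  if (keys.length === 1) return keys[0]

  const slot = `provider:${providerId}:last_used_key`
  const last = store.get(slot)
  // indexOf returns -1 for an unseen key, so (-1 + 1) % n restarts the cycle at keys[0]
  const next = keys[(keys.indexOf(last ?? '') + 1) % keys.length]
  store.set(slot, next)
  return next
}

// nextApiKey('openai', 'k1, k2, k3', store) yields 'k1', then 'k2', 'k3', 'k1', ...
```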
diff --git a/src/renderer/src/aiCore/legacy/clients/ApiClientFactory.ts b/src/renderer/src/aiCore/legacy/clients/ApiClientFactory.ts deleted file mode 100644 index ee878f5861..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/ApiClientFactory.ts +++ /dev/null @@ -1,103 +0,0 @@ -import { loggerService } from '@logger' -import type { Provider } from '@renderer/types' -import { isNewApiProvider } from '@renderer/utils/provider' - -import { AihubmixAPIClient } from './aihubmix/AihubmixAPIClient' -import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient' -import { AwsBedrockAPIClient } from './aws/AwsBedrockAPIClient' -import type { BaseApiClient } from './BaseApiClient' -import { CherryAiAPIClient } from './cherryai/CherryAiAPIClient' -import { GeminiAPIClient } from './gemini/GeminiAPIClient' -import { VertexAPIClient } from './gemini/VertexAPIClient' -import { NewAPIClient } from './newapi/NewAPIClient' -import { OpenAIAPIClient } from './openai/OpenAIApiClient' -import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient' -import { OVMSClient } from './ovms/OVMSClient' -import { PPIOAPIClient } from './ppio/PPIOAPIClient' -import { ZhipuAPIClient } from './zhipu/ZhipuAPIClient' - -const logger = 
loggerService.withContext('ApiClientFactory') - -/** - * Factory for creating ApiClient instances based on provider configuration - * 根据提供者配置创建ApiClient实例的工厂 - */ -export class ApiClientFactory { - /** - * Create an ApiClient instance for the given provider - * 为给定的提供者创建ApiClient实例 - */ - static create(provider: Provider): BaseApiClient { - logger.debug(`Creating ApiClient for provider:`, { - id: provider.id, - type: provider.type - }) - - let instance: BaseApiClient - - // 首先检查特殊的 Provider ID - if (provider.id === 'cherryai') { - instance = new CherryAiAPIClient(provider) as BaseApiClient - return instance - } - - if (provider.id === 'aihubmix') { - logger.debug(`Creating AihubmixAPIClient for provider: ${provider.id}`) - instance = new AihubmixAPIClient(provider) as BaseApiClient - return instance - } - - if (isNewApiProvider(provider)) { - logger.debug(`Creating NewAPIClient for provider: ${provider.id}`) - instance = new NewAPIClient(provider) as BaseApiClient - return instance - } - - if (provider.id === 'ppio') { - logger.debug(`Creating PPIOAPIClient for provider: ${provider.id}`) - instance = new PPIOAPIClient(provider) as BaseApiClient - return instance - } - - if (provider.id === 'zhipu') { - instance = new ZhipuAPIClient(provider) as BaseApiClient - return instance - } - - if (provider.id === 'ovms') { - logger.debug(`Creating OVMSClient for provider: ${provider.id}`) - instance = new OVMSClient(provider) as BaseApiClient - return instance - } - - // 然后检查标准的 Provider Type - switch (provider.type) { - case 'openai': - instance = new OpenAIAPIClient(provider) as BaseApiClient - break - case 'azure-openai': - case 'openai-response': - instance = new OpenAIResponseAPIClient(provider) as BaseApiClient - break - case 'gemini': - instance = new GeminiAPIClient(provider) as BaseApiClient - break - case 'vertexai': - logger.debug(`Creating VertexAPIClient for provider: ${provider.id}`) - instance = new VertexAPIClient(provider) as BaseApiClient - break - case 'anthropic': - instance = new AnthropicAPIClient(provider) as BaseApiClient - break - case 'aws-bedrock': - instance = new AwsBedrockAPIClient(provider) as BaseApiClient - break - default: - logger.debug(`Using default OpenAIApiClient for provider: ${provider.id}`) - instance = new OpenAIAPIClient(provider) as BaseApiClient - break - } - - return instance - } -} diff --git a/src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts b/src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts deleted file mode 100644 index 5d435b9074..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts +++ /dev/null @@ -1,489 +0,0 @@ -import { loggerService } from '@logger' -import { - getModelSupportedVerbosity, - isFunctionCallingModel, - isOpenAIModel, - isSupportFlexServiceTierModel, - isSupportTemperatureModel, - isSupportTopPModel -} from '@renderer/config/models' -import { REFERENCE_PROMPT } from '@renderer/config/prompts' -import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio' -import { getAssistantSettings } from '@renderer/services/AssistantService' -import type { RootState } from '@renderer/store' -import type { - Assistant, - GenerateImageParams, - KnowledgeReference, - MCPCallToolResponse, - MCPTool, - MCPToolResponse, - MemoryItem, - Model, - Provider, - ToolCallResponse, - WebSearchProviderResponse, - WebSearchResponse -} from '@renderer/types' -import { - FileTypes, - GroqServiceTiers, - isGroqServiceTier, - isOpenAIServiceTier, - OpenAIServiceTiers, - SystemProviderIds -} from '@renderer/types' -import 
type { OpenAIVerbosity } from '@renderer/types/aiCoreTypes' -import type { Message } from '@renderer/types/newMessage' -import type { - RequestOptions, - SdkInstance, - SdkMessageParam, - SdkModel, - SdkParams, - SdkRawChunk, - SdkRawOutput, - SdkTool, - SdkToolCall -} from '@renderer/types/sdk' -import { isJSON, parseJSON } from '@renderer/utils' -import { addAbortController, removeAbortController } from '@renderer/utils/abortController' -import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' -import { isSupportServiceTierProvider } from '@renderer/utils/provider' -import { defaultTimeout } from '@shared/config/constant' -import { defaultAppHeaders } from '@shared/utils' -import { isEmpty } from 'lodash' - -import type { CompletionsContext } from '../middleware/types' -import type { ApiClient, RequestTransformer, ResponseChunkTransformer } from './types' - -const logger = loggerService.withContext('BaseApiClient') - -/** - * Abstract base class for API clients. - * Provides common functionality and structure for specific client implementations. - */ -export abstract class BaseApiClient< - TSdkInstance extends SdkInstance = SdkInstance, - TSdkParams extends SdkParams = SdkParams, - TRawOutput extends SdkRawOutput = SdkRawOutput, - TRawChunk extends SdkRawChunk = SdkRawChunk, - TMessageParam extends SdkMessageParam = SdkMessageParam, - TToolCall extends SdkToolCall = SdkToolCall, - TSdkSpecificTool extends SdkTool = SdkTool -> implements ApiClient -{ - public provider: Provider - protected host: string - protected sdkInstance?: TSdkInstance - - constructor(provider: Provider) { - this.provider = provider - this.host = this.getBaseURL() - } - - /** - * Get the current API key with rotation support - * This getter ensures API keys rotate on each access when multiple keys are configured - */ - protected get apiKey(): string { - return this.getApiKey() - } - - /** - * 获取客户端的兼容性类型 - * 用于判断客户端是否支持特定功能,避免instanceof检查的类型收窄问题 - * 对于装饰器模式的客户端(如AihubmixAPIClient),应该返回其内部实际使用的客户端类型 - */ - // oxlint-disable-next-line @typescript-eslint/no-unused-vars - public getClientCompatibilityType(_model?: Model): string[] { - // 默认返回类的名称 - return [this.constructor.name] - } - - // // 核心的completions方法 - 在中间件架构中,这通常只是一个占位符 - // abstract completions(params: CompletionsParams, internal?: ProcessingState): Promise - - /** - * 核心API Endpoint - **/ - - abstract createCompletions(payload: TSdkParams, options?: RequestOptions): Promise - - abstract generateImage(generateImageParams: GenerateImageParams): Promise - - abstract getEmbeddingDimensions(model?: Model): Promise - - abstract listModels(): Promise - - abstract getSdkInstance(): Promise | TSdkInstance - - /** - * 中间件 - **/ - - // 在 CoreRequestToSdkParamsMiddleware中使用 - abstract getRequestTransformer(): RequestTransformer - // 在RawSdkChunkToGenericChunkMiddleware中使用 - abstract getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer - - /** - * 工具转换 - **/ - - // Optional tool conversion methods - implement if needed by the specific provider - abstract convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[] - - abstract convertSdkToolCallToMcp(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined - - abstract convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse - - abstract buildSdkMessages( - currentReqMessages: TMessageParam[], - output: TRawOutput | string | undefined, - toolResults: TMessageParam[], - toolCalls?: TToolCall[] - ): TMessageParam[] - - 
abstract estimateMessageTokens(message: TMessageParam): number - - abstract convertMcpToolResponseToSdkMessageParam( - mcpToolResponse: MCPToolResponse, - resp: MCPCallToolResponse, - model: Model - ): TMessageParam | undefined - - /** - * 从SDK载荷中提取消息数组(用于中间件中的类型安全访问) - * 不同的提供商可能使用不同的字段名(如messages、history等) - */ - abstract extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[] - - /** - * 通用函数 - **/ - - public getBaseURL(): string { - return this.provider.apiHost - } - - public getApiKey() { - const keys = this.provider.apiKey.split(',').map((key) => key.trim()) - const keyName = `provider:${this.provider.id}:last_used_key` - - if (keys.length === 1) { - return keys[0] - } - - const lastUsedKey = window.keyv.get(keyName) - if (!lastUsedKey) { - window.keyv.set(keyName, keys[0]) - return keys[0] - } - - const currentIndex = keys.indexOf(lastUsedKey) - const nextIndex = (currentIndex + 1) % keys.length - const nextKey = keys[nextIndex] - window.keyv.set(keyName, nextKey) - - return nextKey - } - - public defaultHeaders() { - return { - ...defaultAppHeaders(), - 'X-Api-Key': this.apiKey - } - } - - public get keepAliveTime() { - return this.provider.id === 'lmstudio' ? getLMStudioKeepAliveTime() : undefined - } - - public getTemperature(assistant: Assistant, model: Model): number | undefined { - if (!isSupportTemperatureModel(model)) { - return undefined - } - const assistantSettings = getAssistantSettings(assistant) - return assistantSettings?.enableTemperature ? assistantSettings?.temperature : undefined - } - - public getTopP(assistant: Assistant, model: Model): number | undefined { - if (!isSupportTopPModel(model)) { - return undefined - } - const assistantSettings = getAssistantSettings(assistant) - return assistantSettings?.enableTopP ? assistantSettings?.topP : undefined - } - - // NOTE: 这个也许可以迁移到OpenAIBaseClient - protected getServiceTier(model: Model) { - const serviceTierSetting = this.provider.serviceTier - - if (!isSupportServiceTierProvider(this.provider) || !isOpenAIModel(model) || !serviceTierSetting) { - return undefined - } - - // 处理不同供应商需要 fallback 到默认值的情况 - if (this.provider.id === SystemProviderIds.groq) { - if ( - !isGroqServiceTier(serviceTierSetting) || - (serviceTierSetting === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model)) - ) { - return undefined - } - } else { - // 其他 OpenAI 供应商,假设他们的服务层级设置和 OpenAI 完全相同 - if ( - !isOpenAIServiceTier(serviceTierSetting) || - (serviceTierSetting === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model)) - ) { - return undefined - } - } - - return serviceTierSetting - } - - protected getVerbosity(model?: Model): OpenAIVerbosity { - try { - const state = window.store?.getState() as RootState - const verbosity = state?.settings?.openAI?.verbosity - - // If model is provided, check if the verbosity is supported by the model - if (model) { - const supportedVerbosity = getModelSupportedVerbosity(model) - // Use user's verbosity if supported, otherwise use the first supported option - return supportedVerbosity.includes(verbosity) ? verbosity : supportedVerbosity[0] - } - return verbosity - } catch (error) { - logger.warn('Failed to get verbosity from state. 
Fallback to undefined.', error as Error) - return undefined - } - } - - protected getTimeout(model: Model) { - if (isSupportFlexServiceTierModel(model)) { - return 15 * 1000 * 60 - } - return defaultTimeout - } - - public async getMessageContent( - message: Message - ): Promise<{ textContent: string; imageContents: { fileId: string; fileExt: string }[] }> { - const content = getMainTextContent(message) - - if (isEmpty(content)) { - return { - textContent: '', - imageContents: [] - } - } - - const webSearchReferences = await this.getWebSearchReferencesFromCache(message) - const knowledgeReferences = await this.getKnowledgeBaseReferencesFromCache(message) - const memoryReferences = this.getMemoryReferencesFromCache(message) - - const knowledgeTextReferences = knowledgeReferences.filter((k) => k.metadata?.type !== 'image') - const knowledgeImageReferences = knowledgeReferences.filter((k) => k.metadata?.type === 'image') - - // 添加偏移量以避免ID冲突 - const reindexedKnowledgeReferences = knowledgeTextReferences.map((ref) => ({ - ...ref, - id: ref.id + webSearchReferences.length // 为知识库引用的ID添加网络搜索引用的数量作为偏移量 - })) - - const allReferences = [...webSearchReferences, ...reindexedKnowledgeReferences, ...memoryReferences] - - logger.debug(`Found ${allReferences.length} references for ID: ${message.id}`, allReferences) - - const referenceContent = `\`\`\`json\n${JSON.stringify(allReferences, null, 2)}\n\`\`\`` - const imageReferences = knowledgeImageReferences.map((r) => { - return { fileId: r.metadata?.id, fileExt: r.metadata?.ext } - }) - - return { - textContent: isEmpty(allReferences) - ? content - : REFERENCE_PROMPT.replace('{question}', content).replace('{references}', referenceContent), - imageContents: isEmpty(knowledgeImageReferences) ? [] : imageReferences - } - } - - /** - * Extract the file content from the message - * @param message - The message - * @returns The file content - */ - protected async extractFileContent(message: Message) { - const fileBlocks = findFileBlocks(message) - if (fileBlocks.length > 0) { - const textFileBlocks = fileBlocks.filter( - (fb) => fb.file && [FileTypes.TEXT, FileTypes.DOCUMENT].includes(fb.file.type) - ) - - if (textFileBlocks.length > 0) { - let text = '' - const divider = '\n\n---\n\n' - - for (const fileBlock of textFileBlocks) { - const file = fileBlock.file - const fileContent = (await window.api.file.read(file.id + file.ext, true)).trim() - const fileNameRow = 'file: ' + file.origin_name + '\n\n' - text = text + fileNameRow + fileContent + divider - } - - return text - } - } - - return '' - } - - private getMemoryReferencesFromCache(message: Message) { - const memories = window.keyv.get(`memory-search-${message.id}`) as MemoryItem[] | undefined - if (memories) { - const memoryReferences: KnowledgeReference[] = memories.map((mem, index) => ({ - id: index + 1, - content: `${mem.memory} -- Created at: ${mem.createdAt}`, - sourceUrl: '', - type: 'memory' - })) - return memoryReferences - } - return [] - } - - private async getWebSearchReferencesFromCache(message: Message) { - const content = getMainTextContent(message) - if (isEmpty(content)) { - return [] - } - const webSearch: WebSearchResponse = window.keyv.get(`web-search-${message.id}`) - - if (webSearch) { - window.keyv.remove(`web-search-${message.id}`) - return (webSearch.results as WebSearchProviderResponse).results.map( - (result, index) => - ({ - id: index + 1, - content: result.content, - sourceUrl: result.url, - type: 'url' - }) as KnowledgeReference - ) - } - - return [] - } - - /** - * 从缓存中获取知识库引用 
- */ - private async getKnowledgeBaseReferencesFromCache(message: Message): Promise { - const content = getMainTextContent(message) - if (isEmpty(content)) { - return [] - } - const knowledgeReferences: KnowledgeReference[] = window.keyv.get(`knowledge-search-${message.id}`) - - if (!isEmpty(knowledgeReferences)) { - window.keyv.remove(`knowledge-search-${message.id}`) - logger.debug(`Found ${knowledgeReferences.length} knowledge base references in cache for ID: ${message.id}`) - return knowledgeReferences - } - logger.debug(`No knowledge base references found in cache for ID: ${message.id}`) - return [] - } - - protected getCustomParameters(assistant: Assistant) { - return ( - assistant?.settings?.customParameters?.reduce((acc, param) => { - if (!param.name?.trim()) { - return acc - } - // Parse JSON type parameters (Legacy API clients) - // Related: src/renderer/src/pages/settings/AssistantSettings/AssistantModelSettings.tsx:133-148 - // The UI stores JSON type params as strings, this function parses them before sending to API - if (param.type === 'json') { - const value = param.value as string - if (value === 'undefined') { - return { ...acc, [param.name]: undefined } - } - return { ...acc, [param.name]: isJSON(value) ? parseJSON(value) : value } - } - return { - ...acc, - [param.name]: param.value - } - }, {}) || {} - ) - } - - public createAbortController(messageId?: string, isAddEventListener?: boolean) { - const abortController = new AbortController() - const abortFn = () => abortController.abort() - - if (messageId) { - addAbortController(messageId, abortFn) - } - - const cleanup = () => { - if (messageId) { - signalPromise.resolve?.(undefined) - removeAbortController(messageId, abortFn) - } - } - const signalPromise: { - resolve: (value: unknown) => void - promise: Promise - } = { - resolve: () => {}, - promise: Promise.resolve() - } - - if (isAddEventListener) { - signalPromise.promise = new Promise((resolve, reject) => { - signalPromise.resolve = resolve - if (abortController.signal.aborted) { - reject(new Error('Request was aborted.')) - } - // 捕获abort事件,有些abort事件必须 - abortController.signal.addEventListener('abort', () => { - reject(new Error('Request was aborted.')) - }) - }) - return { - abortController, - cleanup, - signalPromise - } - } - return { - abortController, - cleanup - } - } - - // Setup tools configuration based on provided parameters - public setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): { - tools: TSdkSpecificTool[] - } { - const { mcpTools, model, enableToolUse } = params - let tools: TSdkSpecificTool[] = [] - - // If there are no tools, return an empty array - if (!mcpTools?.length) { - return { tools } - } - - // If the model supports function calling and tool usage is enabled - if (isFunctionCallingModel(model) && enableToolUse) { - tools = this.convertMcpToolsToSdkTools(mcpTools) - } - - return { tools } - } -} diff --git a/src/renderer/src/aiCore/legacy/clients/MixedBaseApiClient.ts b/src/renderer/src/aiCore/legacy/clients/MixedBaseApiClient.ts deleted file mode 100644 index fb5568a6e8..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/MixedBaseApiClient.ts +++ /dev/null @@ -1,181 +0,0 @@ -import type { - GenerateImageParams, - MCPCallToolResponse, - MCPTool, - MCPToolResponse, - Model, - Provider, - ToolCallResponse -} from '@renderer/types' -import type { - RequestOptions, - SdkInstance, - SdkMessageParam, - SdkModel, - SdkParams, - SdkRawChunk, - SdkRawOutput, - SdkTool, - SdkToolCall -} from 
'@renderer/types/sdk' - -import type { CompletionsContext } from '../middleware/types' -import type { AnthropicAPIClient } from './anthropic/AnthropicAPIClient' -import { BaseApiClient } from './BaseApiClient' -import type { GeminiAPIClient } from './gemini/GeminiAPIClient' -import type { OpenAIAPIClient } from './openai/OpenAIApiClient' -import type { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient' -import type { RequestTransformer, ResponseChunkTransformer } from './types' - -/** - * MixedAPIClient - 适用于可能含有多种接口类型的Provider - */ -export abstract class MixedBaseAPIClient extends BaseApiClient { - // 使用联合类型而不是any,保持类型安全 - protected abstract clients: Map< - string, - AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient - > - protected abstract defaultClient: OpenAIAPIClient - protected abstract currentClient: BaseApiClient - - constructor(provider: Provider) { - super(provider) - } - - override getBaseURL(): string { - if (!this.currentClient) { - return this.provider.apiHost - } - return this.currentClient.getBaseURL() - } - - /** - * 类型守卫:确保client是BaseApiClient的实例 - */ - protected isValidClient(client: unknown): client is BaseApiClient { - return ( - client !== null && - client !== undefined && - typeof client === 'object' && - 'createCompletions' in client && - 'getRequestTransformer' in client && - 'getResponseChunkTransformer' in client - ) - } - - /** - * 根据模型获取合适的client - */ - protected abstract getClient(model: Model): BaseApiClient - - /** - * 根据模型选择合适的client并委托调用 - */ - public getClientForModel(model: Model): BaseApiClient { - this.currentClient = this.getClient(model) - return this.currentClient - } - - /** - * 重写基类方法,返回内部实际使用的客户端类型 - */ - public override getClientCompatibilityType(model?: Model): string[] { - if (!model) { - return [this.constructor.name] - } - - const actualClient = this.getClient(model) - return actualClient.getClientCompatibilityType(model) - } - - /** - * 从SDK payload中提取模型ID - */ - protected extractModelFromPayload(payload: SdkParams): string | null { - // 不同的SDK可能有不同的字段名 - if ('model' in payload && typeof payload.model === 'string') { - return payload.model - } - return null - } - - // ============ BaseApiClient 的抽象方法 ============ - - async createCompletions(payload: SdkParams, options?: RequestOptions): Promise { - // 尝试从payload中提取模型信息来选择client - const modelId = this.extractModelFromPayload(payload) - if (modelId) { - const modelObj = { id: modelId } as Model - const targetClient = this.getClient(modelObj) - return targetClient.createCompletions(payload, options) - } - - // 如果无法从payload中提取模型,使用当前设置的client - return this.currentClient.createCompletions(payload, options) - } - - async generateImage(params: GenerateImageParams): Promise { - return this.currentClient.generateImage(params) - } - - async getEmbeddingDimensions(model?: Model): Promise { - const client = model ? 
this.getClient(model) : this.currentClient - return client.getEmbeddingDimensions(model) - } - - async listModels(): Promise { - // 可以聚合所有client的模型,或者使用默认client - return this.defaultClient.listModels() - } - - async getSdkInstance(): Promise { - return this.currentClient.getSdkInstance() - } - - getRequestTransformer(): RequestTransformer { - return this.currentClient.getRequestTransformer() - } - - getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer { - return this.currentClient.getResponseChunkTransformer(ctx) - } - - convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] { - return this.currentClient.convertMcpToolsToSdkTools(mcpTools) - } - - convertSdkToolCallToMcp(toolCall: SdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined { - return this.currentClient.convertSdkToolCallToMcp(toolCall, mcpTools) - } - - convertSdkToolCallToMcpToolResponse(toolCall: SdkToolCall, mcpTool: MCPTool): ToolCallResponse { - return this.currentClient.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool) - } - - buildSdkMessages( - currentReqMessages: SdkMessageParam[], - output: SdkRawOutput | string, - toolResults: SdkMessageParam[], - toolCalls?: SdkToolCall[] - ): SdkMessageParam[] { - return this.currentClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls) - } - - estimateMessageTokens(message: SdkMessageParam): number { - return this.currentClient.estimateMessageTokens(message) - } - - convertMcpToolResponseToSdkMessageParam( - mcpToolResponse: MCPToolResponse, - resp: MCPCallToolResponse, - model: Model - ): SdkMessageParam | undefined { - const client = this.getClient(model) - return client.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model) - } - - extractMessagesFromSdkPayload(sdkPayload: SdkParams): SdkMessageParam[] { - return this.currentClient.extractMessagesFromSdkPayload(sdkPayload) - } -} diff --git a/src/renderer/src/aiCore/legacy/clients/__tests__/ApiClientFactory.test.ts b/src/renderer/src/aiCore/legacy/clients/__tests__/ApiClientFactory.test.ts deleted file mode 100644 index 991c436ca3..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/__tests__/ApiClientFactory.test.ts +++ /dev/null @@ -1,221 +0,0 @@ -import type { Provider } from '@renderer/types' -import { beforeEach, describe, expect, it, vi } from 'vitest' - -import { AihubmixAPIClient } from '../aihubmix/AihubmixAPIClient' -import { AnthropicAPIClient } from '../anthropic/AnthropicAPIClient' -import { ApiClientFactory } from '../ApiClientFactory' -import { AwsBedrockAPIClient } from '../aws/AwsBedrockAPIClient' -import { GeminiAPIClient } from '../gemini/GeminiAPIClient' -import { VertexAPIClient } from '../gemini/VertexAPIClient' -import { NewAPIClient } from '../newapi/NewAPIClient' -import { OpenAIAPIClient } from '../openai/OpenAIApiClient' -import { OpenAIResponseAPIClient } from '../openai/OpenAIResponseAPIClient' -import { PPIOAPIClient } from '../ppio/PPIOAPIClient' - -// 为工厂测试创建最小化 provider 的辅助函数 -// ApiClientFactory 只使用 'id' 和 'type' 字段来决定创建哪个客户端 -// 其他字段会传递给客户端构造函数,但不影响工厂逻辑 -const createTestProvider = (id: string, type: string): Provider => ({ - id, - type: type as Provider['type'], - name: '', - apiKey: '', - apiHost: '', - models: [] -}) - -// Mock 所有客户端模块 -vi.mock('../aihubmix/AihubmixAPIClient', () => ({ - AihubmixAPIClient: vi.fn().mockImplementation(() => ({})) -})) -vi.mock('../anthropic/AnthropicAPIClient', () => ({ - AnthropicAPIClient: vi.fn().mockImplementation(() => ({})) -})) -vi.mock('../anthropic/AnthropicVertexClient', () => 
({ - AnthropicVertexClient: vi.fn().mockImplementation(() => ({})) -})) -vi.mock('../gemini/GeminiAPIClient', () => ({ - GeminiAPIClient: vi.fn().mockImplementation(() => ({})) -})) -vi.mock('../gemini/VertexAPIClient', () => ({ - VertexAPIClient: vi.fn().mockImplementation(() => ({})) -})) -vi.mock('../newapi/NewAPIClient', () => ({ - NewAPIClient: vi.fn().mockImplementation(() => ({})) -})) -vi.mock('../openai/OpenAIApiClient', () => ({ - OpenAIAPIClient: vi.fn().mockImplementation(() => ({})) -})) -vi.mock('../openai/OpenAIResponseAPIClient', () => ({ - OpenAIResponseAPIClient: vi.fn().mockImplementation(() => ({ - getClient: vi.fn().mockReturnThis() - })) -})) -vi.mock('../ppio/PPIOAPIClient', () => ({ - PPIOAPIClient: vi.fn().mockImplementation(() => ({})) -})) -vi.mock('../aws/AwsBedrockAPIClient', () => ({ - AwsBedrockAPIClient: vi.fn().mockImplementation(() => ({})) -})) - -vi.mock('@renderer/services/AssistantService.ts', () => ({ - getDefaultAssistant: () => { - return { - id: 'default', - name: 'default', - emoji: '😀', - prompt: '', - topics: [], - messages: [], - type: 'assistant', - regularPhrases: [], - settings: {} - } - } -})) - -// Mock the models config to prevent circular dependency issues -vi.mock('@renderer/config/models', () => ({ - findTokenLimit: vi.fn(), - isReasoningModel: vi.fn(), - isOpenAILLMModel: vi.fn(), - SYSTEM_MODELS: { - silicon: [], - defaultModel: [] - }, - isOpenAIModel: vi.fn(() => false), - glm45FlashModel: {}, - qwen38bModel: {} -})) - -describe('ApiClientFactory', () => { - beforeEach(() => { - vi.clearAllMocks() - }) - - describe('create', () => { - // 测试特殊 ID 的客户端创建 - it('should create AihubmixAPIClient for aihubmix provider', () => { - const provider = createTestProvider('aihubmix', 'openai') - - const client = ApiClientFactory.create(provider) - - expect(AihubmixAPIClient).toHaveBeenCalledWith(provider) - expect(client).toBeDefined() - }) - - it('should create NewAPIClient for new-api provider', () => { - const provider = createTestProvider('new-api', 'openai') - - const client = ApiClientFactory.create(provider) - - expect(NewAPIClient).toHaveBeenCalledWith(provider) - expect(client).toBeDefined() - }) - - it('should create PPIOAPIClient for ppio provider', () => { - const provider = createTestProvider('ppio', 'openai') - - const client = ApiClientFactory.create(provider) - - expect(PPIOAPIClient).toHaveBeenCalledWith(provider) - expect(client).toBeDefined() - }) - - // 测试标准类型的客户端创建 - it('should create OpenAIAPIClient for openai type', () => { - const provider = createTestProvider('custom-openai', 'openai') - - const client = ApiClientFactory.create(provider) - - expect(OpenAIAPIClient).toHaveBeenCalledWith(provider) - expect(client).toBeDefined() - }) - - it('should create OpenAIResponseAPIClient for azure-openai type', () => { - const provider = createTestProvider('azure-openai', 'azure-openai') - - const client = ApiClientFactory.create(provider) - - expect(OpenAIResponseAPIClient).toHaveBeenCalledWith(provider) - expect(client).toBeDefined() - }) - - it('should create OpenAIResponseAPIClient for openai-response type', () => { - const provider = createTestProvider('response', 'openai-response') - - const client = ApiClientFactory.create(provider) - - expect(OpenAIResponseAPIClient).toHaveBeenCalledWith(provider) - expect(client).toBeDefined() - }) - - it('should create GeminiAPIClient for gemini type', () => { - const provider = createTestProvider('gemini', 'gemini') - - const client = ApiClientFactory.create(provider) - - 
expect(GeminiAPIClient).toHaveBeenCalledWith(provider) - expect(client).toBeDefined() - }) - - it('should create VertexAPIClient for vertexai type', () => { - const provider = createTestProvider('vertex', 'vertexai') - - const client = ApiClientFactory.create(provider) - - expect(VertexAPIClient).toHaveBeenCalledWith(provider) - expect(client).toBeDefined() - }) - - it('should create AnthropicAPIClient for anthropic type', () => { - const provider = createTestProvider('anthropic', 'anthropic') - - const client = ApiClientFactory.create(provider) - - expect(AnthropicAPIClient).toHaveBeenCalledWith(provider) - expect(client).toBeDefined() - }) - - it('should create AwsBedrockAPIClient for aws-bedrock type', () => { - const provider = createTestProvider('aws-bedrock', 'aws-bedrock') - - const client = ApiClientFactory.create(provider) - - expect(AwsBedrockAPIClient).toHaveBeenCalledWith(provider) - expect(client).toBeDefined() - }) - - // 测试默认情况 - it('should create OpenAIAPIClient as default for unknown type', () => { - const provider = createTestProvider('unknown', 'unknown-type') - - const client = ApiClientFactory.create(provider) - - expect(OpenAIAPIClient).toHaveBeenCalledWith(provider) - expect(client).toBeDefined() - }) - - // 测试边界条件 - it('should handle provider with minimal configuration', () => { - const provider = createTestProvider('minimal', 'openai') - - const client = ApiClientFactory.create(provider) - - expect(OpenAIAPIClient).toHaveBeenCalledWith(provider) - expect(client).toBeDefined() - }) - - // 测试特殊 ID 优先级高于类型 - it('should prioritize special ID over type', () => { - const provider = createTestProvider('aihubmix', 'anthropic') // 即使类型是 anthropic - - const client = ApiClientFactory.create(provider) - - // 应该创建 AihubmixAPIClient 而不是 AnthropicAPIClient - expect(AihubmixAPIClient).toHaveBeenCalledWith(provider) - expect(AnthropicAPIClient).not.toHaveBeenCalled() - expect(client).toBeDefined() - }) - }) -}) diff --git a/src/renderer/src/aiCore/legacy/clients/__tests__/OpenAIBaseClient.azureEndpoint.test.ts b/src/renderer/src/aiCore/legacy/clients/__tests__/OpenAIBaseClient.azureEndpoint.test.ts deleted file mode 100644 index e3b2ef2676..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/__tests__/OpenAIBaseClient.azureEndpoint.test.ts +++ /dev/null @@ -1,38 +0,0 @@ -import { describe, expect, it } from 'vitest' - -import { normalizeAzureOpenAIEndpoint } from '../openai/azureOpenAIEndpoint' - -describe('normalizeAzureOpenAIEndpoint', () => { - it.each([ - { - apiHost: 'https://example.openai.azure.com/openai', - expectedEndpoint: 'https://example.openai.azure.com' - }, - { - apiHost: 'https://example.openai.azure.com/openai/', - expectedEndpoint: 'https://example.openai.azure.com' - }, - { - apiHost: 'https://example.openai.azure.com/openai/v1', - expectedEndpoint: 'https://example.openai.azure.com' - }, - { - apiHost: 'https://example.openai.azure.com/openai/v1/', - expectedEndpoint: 'https://example.openai.azure.com' - }, - { - apiHost: 'https://example.openai.azure.com', - expectedEndpoint: 'https://example.openai.azure.com' - }, - { - apiHost: 'https://example.openai.azure.com/', - expectedEndpoint: 'https://example.openai.azure.com' - }, - { - apiHost: 'https://example.openai.azure.com/OPENAI/V1', - expectedEndpoint: 'https://example.openai.azure.com' - } - ])('strips trailing /openai from $apiHost', ({ apiHost, expectedEndpoint }) => { - expect(normalizeAzureOpenAIEndpoint(apiHost)).toBe(expectedEndpoint) - }) -}) diff --git 
a/src/renderer/src/aiCore/legacy/clients/__tests__/index.clientCompatibilityTypes.test.ts b/src/renderer/src/aiCore/legacy/clients/__tests__/index.clientCompatibilityTypes.test.ts deleted file mode 100644 index bcff572410..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/__tests__/index.clientCompatibilityTypes.test.ts +++ /dev/null @@ -1,354 +0,0 @@ -import { AihubmixAPIClient } from '@renderer/aiCore/legacy/clients/aihubmix/AihubmixAPIClient' -import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient' -import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory' -import { GeminiAPIClient } from '@renderer/aiCore/legacy/clients/gemini/GeminiAPIClient' -import { VertexAPIClient } from '@renderer/aiCore/legacy/clients/gemini/VertexAPIClient' -import { NewAPIClient } from '@renderer/aiCore/legacy/clients/newapi/NewAPIClient' -import { OpenAIAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIApiClient' -import { OpenAIResponseAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIResponseAPIClient' -import type { EndpointType, Model, Provider } from '@renderer/types' -import { beforeEach, describe, expect, it, vi } from 'vitest' - -vi.mock('@renderer/config/models', () => ({ - SYSTEM_MODELS: { - defaultModel: [ - { id: 'gpt-4', name: 'GPT-4' }, - { id: 'gpt-4', name: 'GPT-4' }, - { id: 'gpt-4', name: 'GPT-4' } - ], - zhipu: [], - silicon: [], - openai: [], - anthropic: [], - gemini: [] - }, - isOpenAIModel: vi.fn().mockReturnValue(true), - isOpenAILLMModel: vi.fn().mockReturnValue(true), - isOpenAIChatCompletionOnlyModel: vi.fn().mockReturnValue(false), - isAnthropicLLMModel: vi.fn().mockReturnValue(false), - isGeminiLLMModel: vi.fn().mockReturnValue(false), - isSupportedReasoningEffortOpenAIModel: vi.fn().mockReturnValue(false), - isVisionModel: vi.fn().mockReturnValue(false), - isClaudeReasoningModel: vi.fn().mockReturnValue(false), - isReasoningModel: vi.fn().mockReturnValue(false), - isWebSearchModel: vi.fn().mockReturnValue(false), - findTokenLimit: vi.fn().mockReturnValue(4096), - isFunctionCallingModel: vi.fn().mockReturnValue(false), - DEFAULT_MAX_TOKENS: 4096, - qwen38bModel: {}, - glm45FlashModel: {} -})) - -vi.mock('@renderer/services/AssistantService', () => ({ - getProviderByModel: vi.fn(), - getAssistantSettings: vi.fn(), - getDefaultAssistant: vi.fn().mockReturnValue({ - id: 'default', - name: 'Default Assistant', - prompt: '', - settings: {} - }) -})) - -vi.mock('@renderer/services/FileManager', () => ({ - default: class { - static async read() { - return 'test content' - } - static async write() { - return true - } - } -})) - -vi.mock('@renderer/services/TokenService', () => ({ - estimateTextTokens: vi.fn().mockReturnValue(100) -})) - -vi.mock('@logger', () => ({ - loggerService: { - withContext: vi.fn().mockReturnValue({ - debug: vi.fn(), - info: vi.fn(), - warn: vi.fn(), - error: vi.fn(), - silly: vi.fn() - }) - } -})) - -// 到底是谁想出来的在服务层调用 React Hook ????????? 
-// Mock additional services and hooks that might be imported -vi.mock('@renderer/hooks/useVertexAI', () => ({ - getVertexAILocation: vi.fn().mockReturnValue('us-central1'), - getVertexAIProjectId: vi.fn().mockReturnValue('test-project'), - getVertexAIServiceAccount: vi.fn().mockReturnValue({ - privateKey: 'test-key', - clientEmail: 'test@example.com' - }), - isVertexAIConfigured: vi.fn().mockReturnValue(true), - isVertexProvider: vi.fn().mockReturnValue(true) -})) - -vi.mock('@renderer/hooks/useSettings', () => ({ - getStoreSetting: vi.fn().mockReturnValue({}), - useSettings: vi.fn().mockReturnValue([{}, vi.fn()]) -})) - -vi.mock('@renderer/store/settings', () => ({ - default: {}, - settingsSlice: { - name: 'settings', - reducer: vi.fn(), - actions: {} - } -})) - -vi.mock('@renderer/utils/abortController', () => ({ - addAbortController: vi.fn(), - removeAbortController: vi.fn() -})) - -vi.mock('@anthropic-ai/sdk', () => ({ - default: vi.fn().mockImplementation(() => ({})) -})) - -vi.mock('@anthropic-ai/vertex-sdk', () => ({ - default: vi.fn().mockImplementation(() => ({})) -})) - -vi.mock('openai', () => ({ - default: vi.fn().mockImplementation(() => ({})), - AzureOpenAI: vi.fn().mockImplementation(() => ({})) -})) - -vi.mock('@google/generative-ai', () => ({ - GoogleGenerativeAI: vi.fn().mockImplementation(() => ({})) -})) - -vi.mock('@google-cloud/vertexai', () => ({ - VertexAI: vi.fn().mockImplementation(() => ({})) -})) - -// Mock the circular dependency between VertexAPIClient and AnthropicVertexClient -vi.mock('@renderer/aiCore/legacy/clients/anthropic/AnthropicVertexClient', () => { - const MockAnthropicVertexClient = vi.fn() - MockAnthropicVertexClient.prototype.getClientCompatibilityType = vi.fn().mockReturnValue(['AnthropicVertexAPIClient']) - return { - AnthropicVertexClient: MockAnthropicVertexClient - } -}) - -// Helper to create test provider -const createTestProvider = (id: string, type: string): Provider => ({ - id, - type: type as Provider['type'], - name: 'Test Provider', - apiKey: 'test-key', - apiHost: 'https://api.test.com', - models: [] -}) - -// Helper to create test model -const createTestModel = (id: string, provider?: string, endpointType?: string): Model => ({ - id, - name: 'Test Model', - provider: provider || 'test', - type: [], - group: 'test', - endpoint_type: endpointType as EndpointType -}) - -describe('Client Compatibility Types', () => { - let openaiProvider: Provider - let anthropicProvider: Provider - let geminiProvider: Provider - let azureProvider: Provider - let aihubmixProvider: Provider - let newApiProvider: Provider - let vertexProvider: Provider - - beforeEach(() => { - vi.clearAllMocks() - - openaiProvider = createTestProvider('openai', 'openai') - anthropicProvider = createTestProvider('anthropic', 'anthropic') - geminiProvider = createTestProvider('gemini', 'gemini') - azureProvider = createTestProvider('azure-openai', 'azure-openai') - aihubmixProvider = createTestProvider('aihubmix', 'openai') - newApiProvider = createTestProvider('new-api', 'openai') - vertexProvider = createTestProvider('vertex', 'vertexai') - }) - - describe('Direct API Clients', () => { - it('should return correct compatibility type for OpenAIAPIClient', () => { - const client = new OpenAIAPIClient(openaiProvider) - const compatibilityTypes = client.getClientCompatibilityType() - - expect(compatibilityTypes).toEqual(['OpenAIAPIClient']) - }) - - it('should return correct compatibility type for AnthropicAPIClient', () => { - const client = new 
AnthropicAPIClient(anthropicProvider) - const compatibilityTypes = client.getClientCompatibilityType() - - expect(compatibilityTypes).toEqual(['AnthropicAPIClient']) - }) - - it('should return correct compatibility type for GeminiAPIClient', () => { - const client = new GeminiAPIClient(geminiProvider) - const compatibilityTypes = client.getClientCompatibilityType() - - expect(compatibilityTypes).toEqual(['GeminiAPIClient']) - }) - }) - - describe('Decorator Pattern API Clients', () => { - it('should return OpenAIResponseAPIClient for OpenAIResponseAPIClient without model', () => { - const client = new OpenAIResponseAPIClient(azureProvider) - const compatibilityTypes = client.getClientCompatibilityType() - - expect(compatibilityTypes).toEqual(['OpenAIResponseAPIClient']) - }) - - it('should delegate to underlying client for OpenAIResponseAPIClient with model', () => { - const client = new OpenAIResponseAPIClient(azureProvider) - const testModel = createTestModel('gpt-4', 'azure-openai') - - // Get the actual client selected for this model - const actualClient = client.getClient(testModel) - const compatibilityTypes = actualClient.getClientCompatibilityType(testModel) - - // Should return OpenAIResponseAPIClient for non-chat-completion-only models - expect(compatibilityTypes).toEqual(['OpenAIAPIClient']) - }) - - it('should return AihubmixAPIClient for AihubmixAPIClient without model', () => { - const client = new AihubmixAPIClient(aihubmixProvider) - const compatibilityTypes = client.getClientCompatibilityType() - - expect(compatibilityTypes).toEqual(['AihubmixAPIClient']) - }) - - it('should delegate to underlying client for AihubmixAPIClient with model', () => { - const client = new AihubmixAPIClient(aihubmixProvider) - const testModel = createTestModel('gpt-4', 'openai') - - // Get the actual client selected for this model - const actualClient = client.getClientForModel(testModel) - const compatibilityTypes = actualClient.getClientCompatibilityType(testModel) - - // Should return the actual underlying client type based on model (OpenAI models use OpenAIResponseAPIClient in Aihubmix) - expect(compatibilityTypes).toEqual(['OpenAIResponseAPIClient']) - }) - - it('should return NewAPIClient for NewAPIClient without model', () => { - const client = new NewAPIClient(newApiProvider) - const compatibilityTypes = client.getClientCompatibilityType() - - expect(compatibilityTypes).toEqual(['NewAPIClient']) - }) - - it('should delegate to underlying client for NewAPIClient with model', () => { - const client = new NewAPIClient(newApiProvider) - const testModel = createTestModel('gpt-4', 'openai', 'openai-response') - - // Get the actual client selected for this model - const actualClient = client.getClientForModel(testModel) - const compatibilityTypes = actualClient.getClientCompatibilityType(testModel) - - // Should return the actual underlying client type based on model - expect(compatibilityTypes).toEqual(['OpenAIResponseAPIClient']) - }) - - it('should return VertexAPIClient for VertexAPIClient without model', () => { - const client = new VertexAPIClient(vertexProvider) - const compatibilityTypes = client.getClientCompatibilityType() - - expect(compatibilityTypes).toEqual(['VertexAPIClient']) - }) - - it('should delegate to underlying client for VertexAPIClient with model', () => { - const client = new VertexAPIClient(vertexProvider) - const testModel = createTestModel('claude-3-5-sonnet', 'vertexai') - - // Get the actual client selected for this model - const actualClient = 
client.getClient(testModel) - const compatibilityTypes = actualClient.getClientCompatibilityType(testModel) - - // Should return the actual underlying client type based on model (Claude models use AnthropicVertexClient) - expect(compatibilityTypes).toEqual(['AnthropicVertexAPIClient']) - }) - }) - - describe('Middleware Compatibility Logic', () => { - it('should correctly identify OpenAI compatible clients', () => { - const openaiClient = new OpenAIAPIClient(openaiProvider) - const openaiResponseClient = new OpenAIResponseAPIClient(azureProvider) - - const openaiTypes = openaiClient.getClientCompatibilityType() - const responseTypes = openaiResponseClient.getClientCompatibilityType() - - // Test the logic from completions method line 94 - const isOpenAICompatible = (types: string[]) => - types.includes('OpenAIAPIClient') || types.includes('OpenAIResponseAPIClient') - - expect(isOpenAICompatible(openaiTypes)).toBe(true) - expect(isOpenAICompatible(responseTypes)).toBe(true) - }) - - it('should correctly identify Anthropic or OpenAIResponse compatible clients', () => { - const anthropicClient = new AnthropicAPIClient(anthropicProvider) - const openaiResponseClient = new OpenAIResponseAPIClient(azureProvider) - const openaiClient = new OpenAIAPIClient(openaiProvider) - - const anthropicTypes = anthropicClient.getClientCompatibilityType() - const responseTypes = openaiResponseClient.getClientCompatibilityType() - const openaiTypes = openaiClient.getClientCompatibilityType() - - // Test the logic from completions method line 101 - const isAnthropicOrOpenAIResponseCompatible = (types: string[]) => - types.includes('AnthropicAPIClient') || types.includes('OpenAIResponseAPIClient') - - expect(isAnthropicOrOpenAIResponseCompatible(anthropicTypes)).toBe(true) - expect(isAnthropicOrOpenAIResponseCompatible(responseTypes)).toBe(true) - expect(isAnthropicOrOpenAIResponseCompatible(openaiTypes)).toBe(false) - }) - - it('should handle non-compatible clients correctly', () => { - const geminiClient = new GeminiAPIClient(geminiProvider) - const geminiTypes = geminiClient.getClientCompatibilityType() - - // Test that Gemini is not OpenAI compatible - const isOpenAICompatible = (types: string[]) => - types.includes('OpenAIAPIClient') || types.includes('OpenAIResponseAPIClient') - - // Test that Gemini is not Anthropic/OpenAIResponse compatible - const isAnthropicOrOpenAIResponseCompatible = (types: string[]) => - types.includes('AnthropicAPIClient') || types.includes('OpenAIResponseAPIClient') - - expect(isOpenAICompatible(geminiTypes)).toBe(false) - expect(isAnthropicOrOpenAIResponseCompatible(geminiTypes)).toBe(false) - }) - }) - - describe('Factory Integration', () => { - it('should return correct compatibility types for factory-created clients', () => { - const testCases = [ - { provider: openaiProvider, expectedType: 'OpenAIAPIClient' }, - { provider: anthropicProvider, expectedType: 'AnthropicAPIClient' }, - { provider: azureProvider, expectedType: 'OpenAIResponseAPIClient' }, - { provider: aihubmixProvider, expectedType: 'AihubmixAPIClient' }, - { provider: newApiProvider, expectedType: 'NewAPIClient' }, - { provider: vertexProvider, expectedType: 'VertexAPIClient' } - ] - - testCases.forEach(({ provider, expectedType }) => { - const client = ApiClientFactory.create(provider) - const compatibilityTypes = client.getClientCompatibilityType() - - expect(compatibilityTypes).toContain(expectedType) - }) - }) - }) -}) diff --git a/src/renderer/src/aiCore/legacy/clients/aihubmix/AihubmixAPIClient.ts 
b/src/renderer/src/aiCore/legacy/clients/aihubmix/AihubmixAPIClient.ts deleted file mode 100644 index a8a0ca5ac6..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/aihubmix/AihubmixAPIClient.ts +++ /dev/null @@ -1,96 +0,0 @@ -import { isOpenAILLMModel } from '@renderer/config/models' -import type { Model, Provider } from '@renderer/types' - -import { AnthropicAPIClient } from '../anthropic/AnthropicAPIClient' -import type { BaseApiClient } from '../BaseApiClient' -import { GeminiAPIClient } from '../gemini/GeminiAPIClient' -import { MixedBaseAPIClient } from '../MixedBaseApiClient' -import { OpenAIAPIClient } from '../openai/OpenAIApiClient' -import { OpenAIResponseAPIClient } from '../openai/OpenAIResponseAPIClient' - -/** - * AihubmixAPIClient - 根据模型类型自动选择合适的ApiClient - * 使用装饰器模式实现,在ApiClient层面进行模型路由 - */ -export class AihubmixAPIClient extends MixedBaseAPIClient { - // 使用联合类型而不是any,保持类型安全 - protected clients: Map = - new Map() - protected defaultClient: OpenAIAPIClient - protected currentClient: BaseApiClient - - constructor(provider: Provider) { - super(provider) - - const providerExtraHeaders = { - ...provider, - extra_headers: { - ...provider.extra_headers, - 'APP-Code': 'MLTG2087' - } - } - - // 初始化各个client - 现在有类型安全 - const claudeClient = new AnthropicAPIClient(providerExtraHeaders) - const geminiClient = new GeminiAPIClient({ ...providerExtraHeaders, apiHost: 'https://aihubmix.com/gemini' }) - const openaiClient = new OpenAIResponseAPIClient(providerExtraHeaders) - const defaultClient = new OpenAIAPIClient(providerExtraHeaders) - - this.clients.set('claude', claudeClient) - this.clients.set('gemini', geminiClient) - this.clients.set('openai', openaiClient) - this.clients.set('default', defaultClient) - - // 设置默认client - this.defaultClient = defaultClient - this.currentClient = this.defaultClient as BaseApiClient - } - - override getBaseURL(): string { - if (!this.currentClient) { - return this.provider.apiHost - } - return this.currentClient.getBaseURL() - } - - /** - * 根据模型获取合适的client - */ - protected getClient(model: Model): BaseApiClient { - const id = model.id.toLowerCase() - - // claude开头 - if (id.startsWith('claude')) { - const client = this.clients.get('claude') - if (!client || !this.isValidClient(client)) { - throw new Error('Claude client not properly initialized') - } - return client - } - - // gemini开头 且不以-nothink、-search结尾 - if ( - (id.startsWith('gemini') || id.startsWith('imagen')) && - !id.endsWith('-nothink') && - !id.endsWith('-search') && - !id.includes('embedding') - ) { - const client = this.clients.get('gemini') - if (!client || !this.isValidClient(client)) { - throw new Error('Gemini client not properly initialized') - } - return client - } - - // OpenAI系列模型 不包含gpt-oss - if (isOpenAILLMModel(model) && !model.id.includes('gpt-oss')) { - const client = this.clients.get('openai') - if (!client || !this.isValidClient(client)) { - throw new Error('OpenAI client not properly initialized') - } - return client - } - - return this.defaultClient as BaseApiClient - } -} diff --git a/src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicAPIClient.ts deleted file mode 100644 index 9b63b77ddf..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicAPIClient.ts +++ /dev/null @@ -1,788 +0,0 @@ -import type Anthropic from '@anthropic-ai/sdk' -import type { - Base64ImageSource, - ImageBlockParam, - MessageParam, - TextBlockParam, - ToolResultBlockParam, - ToolUseBlock, - 
WebSearchTool20250305 -} from '@anthropic-ai/sdk/resources' -import type { - ContentBlock, - ContentBlockParam, - MessageCreateParamsBase, - RedactedThinkingBlockParam, - ServerToolUseBlockParam, - ThinkingBlockParam, - ThinkingConfigParam, - ToolUnion, - ToolUseBlockParam, - WebSearchResultBlock, - WebSearchToolResultBlockParam, - WebSearchToolResultError -} from '@anthropic-ai/sdk/resources/messages' -import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages' -import type AnthropicVertex from '@anthropic-ai/vertex-sdk' -import { loggerService } from '@logger' -import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' -import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models' -import { getAssistantSettings } from '@renderer/services/AssistantService' -import FileManager from '@renderer/services/FileManager' -import { estimateTextTokens } from '@renderer/services/TokenService' -import type { - Assistant, - MCPCallToolResponse, - MCPTool, - MCPToolResponse, - Model, - Provider, - ToolCallResponse -} from '@renderer/types' -import { EFFORT_RATIO, FileTypes, WebSearchSource } from '@renderer/types' -import type { - ErrorChunk, - LLMWebSearchCompleteChunk, - LLMWebSearchInProgressChunk, - MCPToolCreatedChunk, - TextDeltaChunk, - TextStartChunk, - ThinkingDeltaChunk, - ThinkingStartChunk -} from '@renderer/types/chunk' -import { ChunkType } from '@renderer/types/chunk' -import { type Message } from '@renderer/types/newMessage' -import type { - AnthropicSdkMessageParam, - AnthropicSdkParams, - AnthropicSdkRawChunk, - AnthropicSdkRawOutput -} from '@renderer/types/sdk' -import { addImageFileToContents } from '@renderer/utils/formats' -import { - anthropicToolUseToMcpTool, - isSupportedToolUse, - mcpToolCallResponseToAnthropicMessage, - mcpToolsToAnthropicTools -} from '@renderer/utils/mcp-tools' -import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find' -import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic' -import { t } from 'i18next' - -import type { GenericChunk } from '../../middleware/schemas' -import { BaseApiClient } from '../BaseApiClient' -import type { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types' - -const logger = loggerService.withContext('AnthropicAPIClient') - -export class AnthropicAPIClient extends BaseApiClient< - Anthropic | AnthropicVertex, - AnthropicSdkParams, - AnthropicSdkRawOutput, - AnthropicSdkRawChunk, - AnthropicSdkMessageParam, - ToolUseBlock, - ToolUnion -> { - oauthToken: string | undefined = undefined - sdkInstance: Anthropic | AnthropicVertex | undefined = undefined - - constructor(provider: Provider) { - super(provider) - } - - async getSdkInstance(): Promise { - if (this.sdkInstance) { - return this.sdkInstance - } - if (this.provider.authType === 'oauth') { - this.oauthToken = await window.api.anthropic_oauth.getAccessToken() - } - this.sdkInstance = getSdkClient(this.provider, this.oauthToken) - return this.sdkInstance - } - - override async createCompletions( - payload: AnthropicSdkParams, - options?: Anthropic.RequestOptions - ): Promise { - if (this.provider.authType === 'oauth') { - payload.system = buildClaudeCodeSystemMessage(payload.system) - } - const sdk = (await this.getSdkInstance()) as Anthropic - if (payload.stream) { - return sdk.messages.stream(payload, options) - } - return sdk.messages.create(payload, options) - } - - // @ts-ignore sdk未提供 - // 
oxlint-disable-next-line @typescript-eslint/no-unused-vars - override async generateImage(generateImageParams: GenerateImageParams): Promise { - return [] - } - - override async listModels(): Promise { - const sdk = (await this.getSdkInstance()) as Anthropic - // prevent auto appended /v1. It's included in baseUrl. - const response = await sdk.models.list({ path: '/models' }) - return response.data - } - - // @ts-ignore sdk未提供 - override async getEmbeddingDimensions(): Promise { - throw new Error("Anthropic SDK doesn't support getEmbeddingDimensions method.") - } - - override getTemperature(assistant: Assistant, model: Model): number | undefined { - if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) { - return undefined - } - return super.getTemperature(assistant, model) - } - - override getTopP(assistant: Assistant, model: Model): number | undefined { - if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) { - return undefined - } - return super.getTopP(assistant, model) - } - - /** - * Get the reasoning effort - * @param assistant - The assistant - * @param model - The model - * @returns The reasoning effort - */ - private getBudgetToken(assistant: Assistant, model: Model): ThinkingConfigParam | undefined { - if (!isReasoningModel(model)) { - return undefined - } - const { maxTokens } = getAssistantSettings(assistant) - - const reasoningEffort = assistant?.settings?.reasoning_effort - - if (reasoningEffort === undefined) { - return { - type: 'disabled' - } - } - - const effortRatio = EFFORT_RATIO[reasoningEffort] - - const budgetTokens = Math.max( - 1024, - Math.floor( - Math.min( - (findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + - findTokenLimit(model.id)?.min!, - (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio - ) - ) - ) - - return { - type: 'enabled', - budget_tokens: budgetTokens - } - } - - private static isValidBase64ImageMediaType(mime: string): mime is Base64ImageSource['media_type'] { - return ['image/jpeg', 'image/png', 'image/gif', 'image/webp'].includes(mime) - } - - /** - * Get the message parameter - * @param message - The message - * @returns The message parameter - */ - public async convertMessageToSdkParam(message: Message): Promise { - const { textContent, imageContents } = await this.getMessageContent(message) - - const parts: MessageParam['content'] = [ - { - type: 'text', - text: textContent - } - ] - - if (imageContents.length > 0) { - for (const imageContent of imageContents) { - const base64Data = await window.api.file.base64Image(imageContent.fileId + imageContent.fileExt) - base64Data.mime = base64Data.mime.replace('jpg', 'jpeg') - if (AnthropicAPIClient.isValidBase64ImageMediaType(base64Data.mime)) { - parts.push({ - type: 'image', - source: { - data: base64Data.base64, - media_type: base64Data.mime, - type: 'base64' - } - }) - } else { - logger.warn('Unsupported image type, ignored.', { mime: base64Data.mime }) - } - } - } - - // Get and process image blocks - const imageBlocks = findImageBlocks(message) - for (const imageBlock of imageBlocks) { - if (imageBlock.file) { - // Handle uploaded file - const file = imageBlock.file - const base64Data = await window.api.file.base64Image(file.id + file.ext) - parts.push({ - type: 'image', - source: { - data: base64Data.base64, - media_type: base64Data.mime.replace('jpg', 'jpeg') as any, - type: 'base64' - } - }) - } - } - // Get and process file blocks - const fileBlocks = findFileBlocks(message) - for (const fileBlock of fileBlocks) { - 
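// A worked instance of the thinking-budget formula from getBudgetToken above,
// using assumed numbers rather than real model limits: token window
// [min = 1024, max = 32000], effort ratio 0.5, assistant maxTokens 8192.
const limit = { min: 1024, max: 32000 } // assumed for illustration
const effortRatio = 0.5
const maxTokens = 8192

const budgetTokens = Math.max(
  1024,
  Math.floor(Math.min((limit.max - limit.min) * effortRatio + limit.min, maxTokens * effortRatio))
)
// (32000 - 1024) * 0.5 + 1024 = 16512, capped by 8192 * 0.5 = 4096
console.log(budgetTokens) // 4096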
const { file } = fileBlock - if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { - if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) { - const base64Data = await FileManager.readBase64File(file) - parts.push({ - type: 'document', - source: { - type: 'base64', - media_type: 'application/pdf', - data: base64Data - } - }) - } else { - const fileContent = await (await window.api.file.read(file.id + file.ext, true)).trim() - parts.push({ - type: 'text', - text: file.origin_name + '\n' + fileContent - }) - } - } - } - - return { - role: message.role === 'system' ? 'user' : message.role, - content: parts - } - } - - public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ToolUnion[] { - return mcpToolsToAnthropicTools(mcpTools) - } - - public convertMcpToolResponseToSdkMessageParam( - mcpToolResponse: MCPToolResponse, - resp: MCPCallToolResponse, - model: Model - ): AnthropicSdkMessageParam | undefined { - if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) { - return mcpToolCallResponseToAnthropicMessage(mcpToolResponse, resp, model) - } else if ('toolCallId' in mcpToolResponse) { - return { - role: 'user', - content: [ - { - type: 'tool_result', - tool_use_id: mcpToolResponse.toolCallId!, - content: resp.content - .map((item) => { - if (item.type === 'text') { - return { - type: 'text', - text: item.text || '' - } satisfies TextBlockParam - } - if (item.type === 'image') { - return { - type: 'image', - source: { - data: item.data || '', - media_type: (item.mimeType || 'image/png') as Base64ImageSource['media_type'], - type: 'base64' - } - } satisfies ImageBlockParam - } - return - }) - .filter((n) => typeof n !== 'undefined'), - is_error: resp.isError - } satisfies ToolResultBlockParam - ] - } - } - return - } - - // Implementing abstract methods from BaseApiClient - convertSdkToolCallToMcp(toolCall: ToolUseBlock, mcpTools: MCPTool[]): MCPTool | undefined { - // Based on anthropicToolUseToMcpTool logic in AnthropicProvider - // This might need adjustment based on how tool calls are specifically handled in the new structure - const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall) - return mcpTool - } - - convertSdkToolCallToMcpToolResponse(toolCall: ToolUseBlock, mcpTool: MCPTool): ToolCallResponse { - return { - id: toolCall.id, - toolCallId: toolCall.id, - tool: mcpTool, - arguments: toolCall.input as Record, - status: 'pending' - } as ToolCallResponse - } - - override buildSdkMessages( - currentReqMessages: AnthropicSdkMessageParam[], - output: Anthropic.Message, - toolResults: AnthropicSdkMessageParam[] - ): AnthropicSdkMessageParam[] { - const assistantMessage: AnthropicSdkMessageParam = { - role: output.role, - content: convertContentBlocksToParams(output.content) - } - - const newMessages: AnthropicSdkMessageParam[] = [...currentReqMessages, assistantMessage] - if (toolResults && toolResults.length > 0) { - newMessages.push(...toolResults) - } - return newMessages - } - - override estimateMessageTokens(message: AnthropicSdkMessageParam): number { - if (typeof message.content === 'string') { - return estimateTextTokens(message.content) - } - return message.content - .map((content) => { - switch (content.type) { - case 'text': - return estimateTextTokens(content.text) - case 'image': - if (content.source.type === 'base64') { - return estimateTextTokens(content.source.data) - } else { - return estimateTextTokens(content.source.url) - } - case 'tool_use': - return estimateTextTokens(JSON.stringify(content.input)) - case 'tool_result': - return 
estimateTextTokens(JSON.stringify(content.content)) - default: - return 0 - } - }) - .reduce((acc, curr) => acc + curr, 0) - } - - public buildAssistantMessage(message: Anthropic.Message): AnthropicSdkMessageParam { - const messageParam: AnthropicSdkMessageParam = { - role: message.role, - content: convertContentBlocksToParams(message.content) - } - return messageParam - } - - public extractMessagesFromSdkPayload(sdkPayload: AnthropicSdkParams): AnthropicSdkMessageParam[] { - return sdkPayload.messages || [] - } - - /** - * Anthropic专用的原始流监听器 - * 处理MessageStream对象的特定事件 - */ - attachRawStreamListener( - rawOutput: AnthropicSdkRawOutput, - listener: RawStreamListener - ): AnthropicSdkRawOutput { - logger.debug(`Attaching stream listener to raw output`) - // 专用的Anthropic事件处理 - const anthropicListener = listener as AnthropicStreamListener - // 检查是否为MessageStream - if (rawOutput instanceof MessageStream) { - logger.debug(`Detected Anthropic MessageStream, attaching specialized listener`) - - if (listener.onStart) { - listener.onStart() - } - - if (listener.onChunk) { - rawOutput.on('streamEvent', (event: AnthropicSdkRawChunk) => { - listener.onChunk!(event) - }) - } - - if (anthropicListener.onContentBlock) { - rawOutput.on('contentBlock', anthropicListener.onContentBlock) - } - - if (anthropicListener.onMessage) { - rawOutput.on('finalMessage', anthropicListener.onMessage) - } - - if (listener.onEnd) { - rawOutput.on('end', () => { - listener.onEnd!() - }) - } - - if (listener.onError) { - rawOutput.on('error', (error: Error) => { - listener.onError!(error) - }) - } - - return rawOutput - } - - if (anthropicListener.onMessage) { - anthropicListener.onMessage(rawOutput) - } - - // 对于非MessageStream响应 - return rawOutput - } - - private async getWebSearchParams(model: Model): Promise { - if (!isWebSearchModel(model)) { - return undefined - } - return { - type: 'web_search_20250305', - name: 'web_search', - max_uses: 5 - } as WebSearchTool20250305 - } - - getRequestTransformer(): RequestTransformer { - return { - transform: async ( - coreRequest, - assistant, - model, - isRecursiveCall, - recursiveSdkMessages - ): Promise<{ - payload: AnthropicSdkParams - messages: AnthropicSdkMessageParam[] - metadata: Record - }> => { - const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest - // 1. 处理系统消息 - const systemPrompt = assistant.prompt - - // 2. 设置工具 - const { tools } = this.setupToolsConfig({ - mcpTools: mcpTools, - model, - enableToolUse: isSupportedToolUse(assistant) - }) - - const systemMessage: TextBlockParam | undefined = systemPrompt - ? { type: 'text', text: systemPrompt } - : undefined - - // 3. 处理用户消息 - const sdkMessages: AnthropicSdkMessageParam[] = [] - if (typeof messages === 'string') { - sdkMessages.push({ role: 'user', content: messages }) - } else { - const processedMessages = addImageFileToContents(messages) - for (const message of processedMessages) { - sdkMessages.push(await this.convertMessageToSdkParam(message)) - } - } - - if (enableWebSearch) { - const webSearchTool = await this.getWebSearchParams(model) - if (webSearchTool) { - tools.push(webSearchTool) - } - } - - const commonParams: MessageCreateParamsBase = { - model: model.id, - messages: - isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0 - ? recursiveSdkMessages - : sdkMessages, - max_tokens: maxTokens || DEFAULT_MAX_TOKENS, - temperature: this.getTemperature(assistant, model), - top_p: this.getTopP(assistant, model), - system: systemMessage ? 
[systemMessage] : undefined, - thinking: this.getBudgetToken(assistant, model), - tools: tools.length > 0 ? tools : undefined, - stream: streamOutput, - // 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑 - // 注意:用户自定义参数总是应该覆盖其他参数 - ...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {}) - } - - const timeout = this.getTimeout(model) - return { payload: commonParams, messages: sdkMessages, metadata: { timeout } } - } - } - } - - getResponseChunkTransformer(): ResponseChunkTransformer { - return () => { - let accumulatedJson = '' - const toolCalls: Record = {} - return { - async transform(rawChunk: AnthropicSdkRawChunk, controller: TransformStreamDefaultController) { - if (typeof rawChunk === 'string') { - try { - rawChunk = JSON.parse(rawChunk) - } catch (error) { - logger.error('invalid chunk', { rawChunk, error }) - throw new Error(t('error.chat.chunk.non_json')) - } - } - switch (rawChunk.type) { - case 'message': { - let i = 0 - let hasTextContent = false - let hasThinkingContent = false - - for (const content of rawChunk.content) { - switch (content.type) { - case 'text': { - if (!hasTextContent) { - controller.enqueue({ - type: ChunkType.TEXT_START - } as TextStartChunk) - hasTextContent = true - } - controller.enqueue({ - type: ChunkType.TEXT_DELTA, - text: content.text - } as TextDeltaChunk) - break - } - case 'tool_use': { - toolCalls[i] = content - i++ - break - } - case 'thinking': { - if (!hasThinkingContent) { - controller.enqueue({ - type: ChunkType.THINKING_START - } as ThinkingStartChunk) - hasThinkingContent = true - } - controller.enqueue({ - type: ChunkType.THINKING_DELTA, - text: content.thinking - } as ThinkingDeltaChunk) - break - } - case 'web_search_tool_result': { - controller.enqueue({ - type: ChunkType.LLM_WEB_SEARCH_COMPLETE, - llm_web_search: { - results: content.content, - source: WebSearchSource.ANTHROPIC - } - } as LLMWebSearchCompleteChunk) - break - } - } - } - if (i > 0) { - controller.enqueue({ - type: ChunkType.MCP_TOOL_CREATED, - tool_calls: Object.values(toolCalls) - } as MCPToolCreatedChunk) - } - controller.enqueue({ - type: ChunkType.LLM_RESPONSE_COMPLETE, - response: { - usage: { - prompt_tokens: rawChunk.usage.input_tokens || 0, - completion_tokens: rawChunk.usage.output_tokens || 0, - total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0) - } - } - }) - break - } - case 'content_block_start': { - const contentBlock = rawChunk.content_block - switch (contentBlock.type) { - case 'server_tool_use': { - if (contentBlock.name === 'web_search') { - controller.enqueue({ - type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS - } as LLMWebSearchInProgressChunk) - } - break - } - case 'web_search_tool_result': { - if ( - contentBlock.content && - (contentBlock.content as WebSearchToolResultError).type === 'web_search_tool_result_error' - ) { - controller.enqueue({ - type: ChunkType.ERROR, - error: { - code: (contentBlock.content as WebSearchToolResultError).error_code, - message: (contentBlock.content as WebSearchToolResultError).error_code - } - } as ErrorChunk) - } else { - controller.enqueue({ - type: ChunkType.LLM_WEB_SEARCH_COMPLETE, - llm_web_search: { - results: contentBlock.content as Array, - source: WebSearchSource.ANTHROPIC - } - } as LLMWebSearchCompleteChunk) - } - break - } - case 'tool_use': { - toolCalls[rawChunk.index] = contentBlock - break - } - case 'text': { - controller.enqueue({ - type: ChunkType.TEXT_START - } as TextStartChunk) - break - } - case 'thinking': - case 'redacted_thinking': { - 
controller.enqueue({ - type: ChunkType.THINKING_START - } as ThinkingStartChunk) - break - } - } - break - } - case 'content_block_delta': { - const messageDelta = rawChunk.delta - switch (messageDelta.type) { - case 'text_delta': { - if (messageDelta.text) { - controller.enqueue({ - type: ChunkType.TEXT_DELTA, - text: messageDelta.text - } as TextDeltaChunk) - } - break - } - case 'thinking_delta': { - if (messageDelta.thinking) { - controller.enqueue({ - type: ChunkType.THINKING_DELTA, - text: messageDelta.thinking - } as ThinkingDeltaChunk) - } - break - } - case 'input_json_delta': { - if (messageDelta.partial_json) { - accumulatedJson += messageDelta.partial_json - } - break - } - } - break - } - case 'content_block_stop': { - const toolCall = toolCalls[rawChunk.index] - if (toolCall) { - try { - toolCall.input = accumulatedJson ? JSON.parse(accumulatedJson) : {} - logger.debug(`Tool call id: ${toolCall.id}, accumulated json: ${accumulatedJson}`) - controller.enqueue({ - type: ChunkType.MCP_TOOL_CREATED, - tool_calls: [toolCall] - } as MCPToolCreatedChunk) - } catch (error) { - logger.error('Error parsing tool call input:', error as Error) - } - } - break - } - case 'message_delta': { - controller.enqueue({ - type: ChunkType.LLM_RESPONSE_COMPLETE, - response: { - usage: { - prompt_tokens: rawChunk.usage.input_tokens || 0, - completion_tokens: rawChunk.usage.output_tokens || 0, - total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0) - } - } - }) - } - } - } - } - } - } -} - -/** - * 将 ContentBlock 数组转换为 ContentBlockParam 数组 - * 去除服务器生成的额外字段,只保留发送给API所需的字段 - */ -function convertContentBlocksToParams(contentBlocks: ContentBlock[]): ContentBlockParam[] { - return contentBlocks.map((block): ContentBlockParam => { - switch (block.type) { - case 'text': - // TextBlock -> TextBlockParam,去除 citations 等服务器字段 - return { - type: 'text', - text: block.text - } satisfies TextBlockParam - case 'tool_use': - // ToolUseBlock -> ToolUseBlockParam - return { - type: 'tool_use', - id: block.id, - name: block.name, - input: block.input - } satisfies ToolUseBlockParam - case 'thinking': - // ThinkingBlock -> ThinkingBlockParam - return { - type: 'thinking', - thinking: block.thinking, - signature: block.signature - } satisfies ThinkingBlockParam - case 'redacted_thinking': - // RedactedThinkingBlock -> RedactedThinkingBlockParam - return { - type: 'redacted_thinking', - data: block.data - } satisfies RedactedThinkingBlockParam - case 'server_tool_use': - // ServerToolUseBlock -> ServerToolUseBlockParam - return { - type: 'server_tool_use', - id: block.id, - name: block.name, - input: block.input - } satisfies ServerToolUseBlockParam - case 'web_search_tool_result': - // WebSearchToolResultBlock -> WebSearchToolResultBlockParam - return { - type: 'web_search_tool_result', - tool_use_id: block.tool_use_id, - content: block.content - } satisfies WebSearchToolResultBlockParam - default: - return block as ContentBlockParam - } - }) -} diff --git a/src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicVertexClient.ts b/src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicVertexClient.ts deleted file mode 100644 index 2fe16e8875..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicVertexClient.ts +++ /dev/null @@ -1,104 +0,0 @@ -import type Anthropic from '@anthropic-ai/sdk' -import AnthropicVertex from '@anthropic-ai/vertex-sdk' -import { loggerService } from '@logger' -import { getVertexAILocation, getVertexAIProjectId, getVertexAIServiceAccount } 
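// A minimal sketch of the accumulate-then-parse pattern the transformer above
// applies to streamed tool arguments: input_json_delta fragments are buffered
// and parsed once when the content block stops. makeArgsBuffer is illustrative.
function makeArgsBuffer() {
  let buf = ''
  return {
    push(fragment: string) {
      buf += fragment
    },
    finish(): Record<string, unknown> {
      const parsed = buf ? JSON.parse(buf) : {}
      buf = ''
      return parsed
    }
  }
}
// push() on each input_json_delta, finish() on content_block_stop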
from '@renderer/hooks/useVertexAI' -import type { Provider } from '@renderer/types' -import { isEmpty } from 'lodash' - -import { AnthropicAPIClient } from './AnthropicAPIClient' - -const logger = loggerService.withContext('AnthropicVertexClient') - -export class AnthropicVertexClient extends AnthropicAPIClient { - sdkInstance: AnthropicVertex | undefined = undefined - private authHeaders?: Record - private authHeadersExpiry?: number - - constructor(provider: Provider) { - super(provider) - } - - private formatApiHost(host: string): string { - const forceUseOriginalHost = () => { - return host.endsWith('/') - } - - if (!host) { - return host - } - - return forceUseOriginalHost() ? host : `${host}/v1/` - } - - override getBaseURL() { - return this.formatApiHost(this.provider.apiHost) - } - - override async getSdkInstance(): Promise { - if (this.sdkInstance) { - return this.sdkInstance - } - - const serviceAccount = getVertexAIServiceAccount() - const projectId = getVertexAIProjectId() - const location = getVertexAILocation() - - if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId || !location) { - throw new Error('Vertex AI settings are not configured') - } - - const authHeaders = await this.getServiceAccountAuthHeaders() - - this.sdkInstance = new AnthropicVertex({ - projectId: projectId, - region: location, - dangerouslyAllowBrowser: true, - defaultHeaders: authHeaders, - baseURL: isEmpty(this.getBaseURL()) ? undefined : this.getBaseURL() - }) - - return this.sdkInstance - } - - override async listModels(): Promise { - throw new Error('Vertex AI does not support listModels method.') - } - - /** - * 获取认证头,如果配置了 service account 则从主进程获取 - */ - private async getServiceAccountAuthHeaders(): Promise | undefined> { - const serviceAccount = getVertexAIServiceAccount() - const projectId = getVertexAIProjectId() - - // 检查是否配置了 service account - if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId) { - return undefined - } - - // 检查是否已有有效的认证头(提前 5 分钟过期) - const now = Date.now() - if (this.authHeaders && this.authHeadersExpiry && this.authHeadersExpiry - now > 5 * 60 * 1000) { - return this.authHeaders - } - - try { - // 从主进程获取认证头 - this.authHeaders = await window.api.vertexAI.getAuthHeaders({ - projectId, - serviceAccount: { - privateKey: serviceAccount.privateKey, - clientEmail: serviceAccount.clientEmail - } - }) - - // 设置过期时间(通常认证头有效期为 1 小时) - this.authHeadersExpiry = now + 60 * 60 * 1000 - - return this.authHeaders - } catch (error: any) { - logger.error('Failed to get auth headers:', error) - throw new Error(`Service Account authentication failed: ${error.message}`) - } - } -} diff --git a/src/renderer/src/aiCore/legacy/clients/aws/AwsBedrockAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/aws/AwsBedrockAPIClient.ts deleted file mode 100644 index c4b0140579..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/aws/AwsBedrockAPIClient.ts +++ /dev/null @@ -1,1106 +0,0 @@ -import { BedrockClient, ListFoundationModelsCommand, ListInferenceProfilesCommand } from '@aws-sdk/client-bedrock' -import { - BedrockRuntimeClient, - type BedrockRuntimeClientConfig, - ConverseCommand, - InvokeModelCommand, - InvokeModelWithResponseStreamCommand -} from '@aws-sdk/client-bedrock-runtime' -import { loggerService } from '@logger' -import type { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas' -import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' -import { findTokenLimit, isReasoningModel } from '@renderer/config/models' -import { - 
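// A condensed sketch of the header-caching strategy in the removed
// AnthropicVertexClient above: fetched auth headers are reused until five
// minutes before an assumed one-hour expiry. fetchHeaders is a stand-in for
// the real main-process call.
const ONE_HOUR_MS = 60 * 60 * 1000
const EARLY_REFRESH_MS = 5 * 60 * 1000

let cached: { headers: Record<string, string>; expiresAt: number } | undefined

async function getCachedAuthHeaders(
  fetchHeaders: () => Promise<Record<string, string>>
): Promise<Record<string, string>> {
  const now = Date.now()
  if (cached && cached.expiresAt - now > EARLY_REFRESH_MS) {
    return cached.headers
  }
  const headers = await fetchHeaders()
  cached = { headers, expiresAt: now + ONE_HOUR_MS }
  return headers
}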
getAwsBedrockAccessKeyId, - getAwsBedrockApiKey, - getAwsBedrockAuthType, - getAwsBedrockRegion, - getAwsBedrockSecretAccessKey -} from '@renderer/hooks/useAwsBedrock' -import { getAssistantSettings } from '@renderer/services/AssistantService' -import { estimateTextTokens } from '@renderer/services/TokenService' -import type { - Assistant, - GenerateImageParams, - MCPCallToolResponse, - MCPTool, - MCPToolResponse, - Model, - Provider, - ToolCallResponse -} from '@renderer/types' -import { EFFORT_RATIO, FileTypes } from '@renderer/types' -import type { MCPToolCreatedChunk, TextDeltaChunk, ThinkingDeltaChunk, ThinkingStartChunk } from '@renderer/types/chunk' -import { ChunkType } from '@renderer/types/chunk' -import type { Message } from '@renderer/types/newMessage' -import type { - AwsBedrockSdkInstance, - AwsBedrockSdkMessageParam, - AwsBedrockSdkParams, - AwsBedrockSdkRawChunk, - AwsBedrockSdkRawOutput, - AwsBedrockSdkTool, - AwsBedrockSdkToolCall, - AwsBedrockStreamChunk, - SdkModel -} from '@renderer/types/sdk' -import { convertBase64ImageToAwsBedrockFormat } from '@renderer/utils/aws-bedrock-utils' -import { - awsBedrockToolUseToMcpTool, - isSupportedToolUse, - mcpToolCallResponseToAwsBedrockMessage, - mcpToolsToAwsBedrockTools -} from '@renderer/utils/mcp-tools' -import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find' -import { t } from 'i18next' - -import { BaseApiClient } from '../BaseApiClient' -import type { RequestTransformer, ResponseChunkTransformer } from '../types' - -const logger = loggerService.withContext('AwsBedrockAPIClient') - -export class AwsBedrockAPIClient extends BaseApiClient< - AwsBedrockSdkInstance, - AwsBedrockSdkParams, - AwsBedrockSdkRawOutput, - AwsBedrockSdkRawChunk, - AwsBedrockSdkMessageParam, - AwsBedrockSdkToolCall, - AwsBedrockSdkTool -> { - constructor(provider: Provider) { - super(provider) - } - - async getSdkInstance(): Promise { - if (this.sdkInstance) { - return this.sdkInstance - } - - const region = getAwsBedrockRegion() - const authType = getAwsBedrockAuthType() - - if (!region) { - throw new Error('AWS region is required. Please configure AWS region in settings.') - } - - // Build client configuration based on auth type - let clientConfig: BedrockRuntimeClientConfig - - if (authType === 'iam') { - // IAM credentials authentication - const accessKeyId = getAwsBedrockAccessKeyId() - const secretAccessKey = getAwsBedrockSecretAccessKey() - - if (!accessKeyId || !secretAccessKey) { - throw new Error('AWS credentials are required. Please configure Access Key ID and Secret Access Key.') - } - - clientConfig = { - region, - credentials: { - accessKeyId, - secretAccessKey - } - } - } else { - // API Key authentication - const awsBedrockApiKey = getAwsBedrockApiKey() - - if (!awsBedrockApiKey) { - throw new Error('AWS Bedrock API Key is required. 
Please configure API Key in settings.') - } - - clientConfig = { - region, - token: { token: awsBedrockApiKey }, - authSchemePreference: ['httpBearerAuth'] - } - } - - const client = new BedrockRuntimeClient(clientConfig) - const bedrockClient = new BedrockClient(clientConfig) - - this.sdkInstance = { client, bedrockClient, region } - return this.sdkInstance - } - - override async createCompletions(payload: AwsBedrockSdkParams): Promise { - const sdk = await this.getSdkInstance() - - // 转换消息格式(用于 InvokeModelWithResponseStreamCommand) - const awsMessages = payload.messages.map((msg) => ({ - role: msg.role, - content: msg.content.map((content) => { - if (content.text) { - return { type: 'text', text: content.text } - } - if (content.image) { - // 处理图片数据,将 Uint8Array 或数字数组转换为 base64 字符串 - let base64Data = '' - if (content.image.source.bytes) { - if (typeof content.image.source.bytes === 'string') { - // 如果已经是字符串,直接使用 - base64Data = content.image.source.bytes - } else { - // 如果是数组或 Uint8Array,转换为 base64 - const uint8Array = new Uint8Array(Object.values(content.image.source.bytes)) - const binaryString = Array.from(uint8Array) - .map((byte) => String.fromCharCode(byte)) - .join('') - base64Data = btoa(binaryString) - } - } - - return { - type: 'image', - source: { - type: 'base64', - media_type: `image/${content.image.format}`, - data: base64Data - } - } - } - if (content.toolResult) { - return { - type: 'tool_result', - tool_use_id: content.toolResult.toolUseId, - content: content.toolResult.content - } - } - if (content.toolUse) { - return { - type: 'tool_use', - id: content.toolUse.toolUseId, - name: content.toolUse.name, - input: content.toolUse.input - } - } - return { type: 'text', text: 'Unknown content type' } - }) - })) - - logger.info('Creating completions with model ID:', { modelId: payload.modelId }) - - const excludeKeys = ['modelId', 'messages', 'system', 'maxTokens', 'temperature', 'topP', 'stream', 'tools'] - const additionalParams = Object.keys(payload) - .filter((key) => !excludeKeys.includes(key)) - .reduce((acc, key) => ({ ...acc, [key]: payload[key] }), {}) - - const commonParams = { - modelId: payload.modelId, - messages: awsMessages as any, - system: payload.system ? [{ text: payload.system }] : undefined, - inferenceConfig: { - maxTokens: payload.maxTokens || DEFAULT_MAX_TOKENS, - temperature: payload.temperature || 0.7, - topP: payload.topP || 1 - }, - toolConfig: - payload.tools && payload.tools.length > 0 - ? 
{ - tools: payload.tools - } - : undefined - } - - try { - if (payload.stream) { - // 根据模型类型选择正确的 API 格式 - const requestBody = this.createRequestBodyForModel(commonParams, additionalParams) - - const command = new InvokeModelWithResponseStreamCommand({ - modelId: commonParams.modelId, - body: JSON.stringify(requestBody), - contentType: 'application/json', - accept: 'application/json' - }) - - const response = await sdk.client.send(command) - return this.createInvokeModelStreamIterator(response) - } else { - const command = new ConverseCommand(commonParams) - const response = await sdk.client.send(command) - return { output: response } - } - } catch (error) { - logger.error('Failed to create completions with AWS Bedrock:', error as Error) - throw error - } - } - - /** - * 根据模型类型创建请求体 - */ - private createRequestBodyForModel(commonParams: any, additionalParams: any): any { - const modelId = commonParams.modelId.toLowerCase() - - // Claude 系列模型使用 Anthropic API 格式 - if (modelId.includes('claude')) { - return { - anthropic_version: 'bedrock-2023-05-31', - max_tokens: commonParams.inferenceConfig.maxTokens, - temperature: commonParams.inferenceConfig.temperature, - top_p: commonParams.inferenceConfig.topP, - messages: commonParams.messages, - ...(commonParams.system && commonParams.system[0]?.text ? { system: commonParams.system[0].text } : {}), - ...(commonParams.toolConfig?.tools ? { tools: commonParams.toolConfig.tools } : {}), - ...additionalParams - } - } - - // OpenAI 系列模型 - if (modelId.includes('gpt') || modelId.includes('openai')) { - const messages: any[] = [] - - // 添加系统消息 - if (commonParams.system && commonParams.system[0]?.text) { - messages.push({ - role: 'system', - content: commonParams.system[0].text - }) - } - - // 转换消息格式 - for (const message of commonParams.messages) { - const content: any[] = [] - for (const part of message.content) { - if (part.text) { - content.push({ type: 'text', text: part.text }) - } else if (part.image) { - content.push({ - type: 'image_url', - image_url: { - url: `data:image/${part.image.format};base64,${part.image.source.bytes}` - } - }) - } - } - messages.push({ - role: message.role, - content: content.length === 1 && content[0].type === 'text' ? content[0].text : content - }) - } - - const baseBody: any = { - model: commonParams.modelId, - messages: messages, - max_tokens: commonParams.inferenceConfig.maxTokens, - temperature: commonParams.inferenceConfig.temperature, - top_p: commonParams.inferenceConfig.topP, - stream: true, - ...(commonParams.toolConfig?.tools ? 
{ tools: commonParams.toolConfig.tools } : {}) - } - - // OpenAI 模型的 thinking 参数格式 - if (additionalParams.reasoning_effort) { - baseBody.reasoning_effort = additionalParams.reasoning_effort - delete additionalParams.reasoning_effort - } - - return { - ...baseBody, - ...additionalParams - } - } - - // Llama 系列模型 - if (modelId.includes('llama')) { - const baseBody: any = { - prompt: this.convertMessagesToPrompt(commonParams.messages, commonParams.system), - max_gen_len: commonParams.inferenceConfig.maxTokens, - temperature: commonParams.inferenceConfig.temperature, - top_p: commonParams.inferenceConfig.topP - } - - // Llama 模型的 thinking 参数格式 - if (additionalParams.thinking_mode) { - baseBody.thinking_mode = additionalParams.thinking_mode - delete additionalParams.thinking_mode - } - - return { - ...baseBody, - ...additionalParams - } - } - - // Amazon Titan 系列模型 - if (modelId.includes('titan')) { - const textGenerationConfig: any = { - maxTokenCount: commonParams.inferenceConfig.maxTokens, - temperature: commonParams.inferenceConfig.temperature, - topP: commonParams.inferenceConfig.topP - } - - // 将 thinking 相关参数添加到 textGenerationConfig 中 - if (additionalParams.thinking) { - textGenerationConfig.thinking = additionalParams.thinking - delete additionalParams.thinking - } - - return { - inputText: this.convertMessagesToPrompt(commonParams.messages, commonParams.system), - textGenerationConfig: { - ...textGenerationConfig, - ...Object.keys(additionalParams).reduce((acc, key) => { - if (['thinking_tokens', 'reasoning_mode'].includes(key)) { - acc[key] = additionalParams[key] - delete additionalParams[key] - } - return acc - }, {} as any) - }, - ...additionalParams - } - } - - // Cohere Command 系列模型 - if (modelId.includes('cohere') || modelId.includes('command')) { - const baseBody: any = { - message: this.convertMessagesToPrompt(commonParams.messages, commonParams.system), - max_tokens: commonParams.inferenceConfig.maxTokens, - temperature: commonParams.inferenceConfig.temperature, - p: commonParams.inferenceConfig.topP - } - - // Cohere 模型的 thinking 参数格式 - if (additionalParams.thinking) { - baseBody.thinking = additionalParams.thinking - delete additionalParams.thinking - } - if (additionalParams.reasoning_tokens) { - baseBody.reasoning_tokens = additionalParams.reasoning_tokens - delete additionalParams.reasoning_tokens - } - - return { - ...baseBody, - ...additionalParams - } - } - - // 默认使用通用格式 - const baseBody: any = { - prompt: this.convertMessagesToPrompt(commonParams.messages, commonParams.system), - max_tokens: commonParams.inferenceConfig.maxTokens, - temperature: commonParams.inferenceConfig.temperature, - top_p: commonParams.inferenceConfig.topP - } - - return { - ...baseBody, - ...additionalParams - } - } - - /** - * 将消息转换为简单的 prompt 格式 - */ - private convertMessagesToPrompt(messages: any[], system?: any[]): string { - let prompt = '' - - // 添加系统消息 - if (system && system[0]?.text) { - prompt += `System: ${system[0].text}\n\n` - } - - // 添加对话消息 - for (const message of messages) { - const role = message.role === 'assistant' ? 
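// A condensed sketch of the model-family dispatch in createRequestBodyForModel
// above: one shared inference config fans out into vendor-specific body shapes.
// Field names mirror the deleted code; the helper itself is illustrative and
// omits messages, system prompts, and tool wiring.
interface InferenceConfig {
  maxTokens: number
  temperature: number
  topP: number
}

function bedrockBodyFor(modelId: string, cfg: InferenceConfig): Record<string, unknown> {
  const id = modelId.toLowerCase()
  if (id.includes('claude')) {
    return {
      anthropic_version: 'bedrock-2023-05-31',
      max_tokens: cfg.maxTokens,
      temperature: cfg.temperature,
      top_p: cfg.topP
    }
  }
  if (id.includes('llama')) {
    return { max_gen_len: cfg.maxTokens, temperature: cfg.temperature, top_p: cfg.topP }
  }
  if (id.includes('titan')) {
    return {
      textGenerationConfig: { maxTokenCount: cfg.maxTokens, temperature: cfg.temperature, topP: cfg.topP }
    }
  }
  return { max_tokens: cfg.maxTokens, temperature: cfg.temperature, top_p: cfg.topP }
}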
'Assistant' : 'Human' - let content = '' - - for (const part of message.content) { - if (part.text) { - content += part.text - } else if (part.image) { - content += '[Image]' - } - } - - prompt += `${role}: ${content}\n\n` - } - - prompt += 'Assistant:' - return prompt - } - - private async *createInvokeModelStreamIterator(response: any): AsyncIterable { - try { - if (response.body) { - for await (const event of response.body) { - if (event.chunk) { - const chunk: AwsBedrockStreamChunk = JSON.parse(new TextDecoder().decode(event.chunk.bytes)) - - // 转换为标准格式 - if (chunk.type === 'content_block_delta') { - yield { - contentBlockDelta: { - delta: chunk.delta, - contentBlockIndex: chunk.index - } - } - } else if (chunk.type === 'message_start') { - yield { messageStart: chunk } - } else if (chunk.type === 'message_stop') { - yield { messageStop: chunk } - } else if (chunk.type === 'content_block_start') { - yield { - contentBlockStart: { - start: chunk.content_block, - contentBlockIndex: chunk.index - } - } - } else if (chunk.type === 'content_block_stop') { - yield { - contentBlockStop: { - contentBlockIndex: chunk.index - } - } - } - } - } - } - } catch (error) { - logger.error('Error in AWS Bedrock stream iterator:', error as Error) - throw error - } - } - - // @ts-ignore sdk未提供 - // oxlint-disable-next-line @typescript-eslint/no-unused-vars - override async generateImage(_generateImageParams: GenerateImageParams): Promise { - return [] - } - - override async getEmbeddingDimensions(model?: Model): Promise { - if (!model) { - throw new Error('Model is required for AWS Bedrock embedding dimensions.') - } - - const sdk = await this.getSdkInstance() - - // AWS Bedrock 支持的嵌入模型及其维度 - const embeddingModels: Record = { - 'cohere.embed-english-v3': 1024, - 'cohere.embed-multilingual-v3': 1024, - // Amazon Titan embeddings - 'amazon.titan-embed-text-v1': 1536, - 'amazon.titan-embed-text-v2:0': 1024 - // 可以根据需要添加更多模型 - } - - // 如果是已知的嵌入模型,直接返回维度 - if (embeddingModels[model.id]) { - return embeddingModels[model.id] - } - - // 对于未知模型,尝试实际调用API获取维度 - try { - let requestBody: any - - if (model.id.startsWith('cohere.embed')) { - // Cohere Embed API 格式 - requestBody = { - texts: ['test'], - input_type: 'search_document', - embedding_types: ['float'] - } - } else if (model.id.startsWith('amazon.titan-embed')) { - // Amazon Titan Embed API 格式 - requestBody = { - inputText: 'test' - } - } else { - // 通用格式,大多数嵌入模型都支持 - requestBody = { - inputText: 'test' - } - } - - const command = new InvokeModelCommand({ - modelId: model.id, - body: JSON.stringify(requestBody), - contentType: 'application/json', - accept: 'application/json' - }) - - const response = await sdk.client.send(command) - const responseBody = JSON.parse(new TextDecoder().decode(response.body)) - - // 解析响应获取嵌入维度 - if (responseBody.embeddings && responseBody.embeddings.length > 0) { - // Cohere 格式 - if (responseBody.embeddings[0].values) { - return responseBody.embeddings[0].values.length - } - // 其他可能的格式 - if (Array.isArray(responseBody.embeddings[0])) { - return responseBody.embeddings[0].length - } - } - - if (responseBody.embedding && Array.isArray(responseBody.embedding)) { - // Amazon Titan 格式 - return responseBody.embedding.length - } - - // 如果无法解析,则抛出错误 - throw new Error(`Unable to determine embedding dimensions for model ${model.id}`) - } catch (error) { - logger.error('Failed to get embedding dimensions from AWS Bedrock:', error as Error) - - // 根据模型名称推测维度 - if (model.id.includes('titan')) { - return 1536 // Amazon Titan 默认维度 - } - if 
(model.id.includes('cohere')) { - return 1024 // Cohere 默认维度 - } - - throw new Error(`Unable to determine embedding dimensions for model ${model.id}: ${(error as Error).message}`) - } - } - - override async listModels(): Promise { - try { - const sdk = await this.getSdkInstance() - - // 获取支持ON_DEMAND的基础模型列表 - const modelsCommand = new ListFoundationModelsCommand({ - byInferenceType: 'ON_DEMAND', - byOutputModality: 'TEXT' - }) - const modelsResponse = await sdk.bedrockClient.send(modelsCommand) - - // 获取推理配置文件列表 - const profilesCommand = new ListInferenceProfilesCommand({}) - const profilesResponse = await sdk.bedrockClient.send(profilesCommand) - - logger.info('Found ON_DEMAND foundation models:', { count: modelsResponse.modelSummaries?.length || 0 }) - logger.info('Found inference profiles:', { count: profilesResponse.inferenceProfileSummaries?.length || 0 }) - - const models: any[] = [] - - // 处理ON_DEMAND基础模型 - if (modelsResponse.modelSummaries) { - for (const model of modelsResponse.modelSummaries) { - if (!model.modelId || !model.modelName) continue - - logger.info('Adding ON_DEMAND model', { modelId: model.modelId }) - models.push({ - id: model.modelId, - name: model.modelName, - display_name: model.modelName, - description: `${model.providerName || 'AWS'} - ${model.modelName}`, - owned_by: model.providerName || 'AWS', - provider: this.provider.id, - group: 'AWS Bedrock', - isInferenceProfile: false - }) - } - } - - // 处理推理配置文件 - if (profilesResponse.inferenceProfileSummaries) { - for (const profile of profilesResponse.inferenceProfileSummaries) { - if (!profile.inferenceProfileArn || !profile.inferenceProfileName) continue - - logger.info('Adding inference profile', { - profileArn: profile.inferenceProfileArn, - profileName: profile.inferenceProfileName - }) - - models.push({ - id: profile.inferenceProfileArn, - name: `${profile.inferenceProfileName} (Profile)`, - display_name: `${profile.inferenceProfileName} (Profile)`, - description: `AWS Inference Profile - ${profile.inferenceProfileName}`, - owned_by: 'AWS', - provider: this.provider.id, - group: 'AWS Bedrock Profiles', - isInferenceProfile: true, - inferenceProfileId: profile.inferenceProfileId, - inferenceProfileArn: profile.inferenceProfileArn - }) - } - } - - logger.info('Total models added to list', { count: models.length }) - return models - } catch (error) { - logger.error('Failed to list AWS Bedrock models:', error as Error) - return [] - } - } - - public async convertMessageToSdkParam(message: Message): Promise { - const { textContent, imageContents } = await this.getMessageContent(message) - const parts: Array<{ - text?: string - image?: { - format: 'png' | 'jpeg' | 'gif' | 'webp' - source: { - bytes?: Uint8Array - s3Location?: { - uri: string - bucketOwner?: string - } - } - } - }> = [] - - // 添加文本内容 - 只在有非空内容时添加 - if (textContent && textContent.trim()) { - parts.push({ text: textContent }) - } - - if (imageContents.length > 0) { - for (const imageContent of imageContents) { - try { - const image = await window.api.file.base64Image(imageContent.fileId + imageContent.fileExt) - const mimeType = image.mime || 'image/png' - const base64Data = image.base64 - - const awsImage = convertBase64ImageToAwsBedrockFormat(base64Data, mimeType) - if (awsImage) { - parts.push({ image: awsImage }) - } else { - // 不支持的格式,转换为文本描述 - parts.push({ text: `[Image: ${mimeType}]` }) - } - } catch (error) { - logger.error('Error processing image:', error as Error) - parts.push({ text: '[Image processing failed]' }) - } - } - } - - // 处理图片内容 - 
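// A reduced sketch of the probe-then-measure fallback in getEmbeddingDimensions
// above: when a model's dimensions are not in the known table, embed a trivial
// input and read the vector length. embedOnce stands in for the provider call.
async function probeEmbeddingDimensions(embedOnce: (text: string) => Promise<number[]>): Promise<number> {
  const vector = await embedOnce('test')
  if (!Array.isArray(vector) || vector.length === 0) {
    throw new Error('Unable to determine embedding dimensions')
  }
  return vector.length
}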
const imageBlocks = findImageBlocks(message) - for (const imageBlock of imageBlocks) { - if (imageBlock.file) { - try { - const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext) - const mimeType = image.mime || 'image/png' - const base64Data = image.base64 - - const awsImage = convertBase64ImageToAwsBedrockFormat(base64Data, mimeType) - if (awsImage) { - parts.push({ image: awsImage }) - } else { - // 不支持的格式,转换为文本描述 - parts.push({ text: `[Image: ${mimeType}]` }) - } - } catch (error) { - logger.error('Error processing image:', error as Error) - parts.push({ text: '[Image processing failed]' }) - } - } else if (imageBlock.url && imageBlock.url.startsWith('data:')) { - try { - // 处理base64图片URL - const matches = imageBlock.url.match(/^data:(.+);base64,(.*)$/) - if (matches && matches.length === 3) { - const mimeType = matches[1] - const base64Data = matches[2] - - const awsImage = convertBase64ImageToAwsBedrockFormat(base64Data, mimeType) - if (awsImage) { - parts.push({ image: awsImage }) - } else { - parts.push({ text: `[Image: ${mimeType}]` }) - } - } - } catch (error) { - logger.error('Error processing base64 image:', error as Error) - parts.push({ text: '[Image processing failed]' }) - } - } - } - - // 处理文件内容 - const fileBlocks = findFileBlocks(message) - for (const fileBlock of fileBlocks) { - const file = fileBlock.file - if (!file) { - logger.warn(`No file in the file block. Passed.`, { fileBlock }) - continue - } - - if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { - try { - const fileContent = (await window.api.file.read(file.id + file.ext, true)).trim() - if (fileContent) { - parts.push({ - text: `${file.origin_name}\n${fileContent}` - }) - } - } catch (error) { - logger.error('Error reading file content:', error as Error) - parts.push({ text: `[File: ${file.origin_name} - Failed to read content]` }) - } - } - } - - // 如果没有任何内容,添加默认文本而不是空文本 - if (parts.length === 0) { - parts.push({ text: 'No content provided' }) - } - - return { - role: message.role === 'system' ? 'user' : message.role, - content: parts - } - } - - getRequestTransformer(): RequestTransformer { - return { - transform: async ( - coreRequest, - assistant, - model, - isRecursiveCall, - recursiveSdkMessages - ): Promise<{ - payload: AwsBedrockSdkParams - messages: AwsBedrockSdkMessageParam[] - metadata: Record - }> => { - const { messages, mcpTools, maxTokens, streamOutput } = coreRequest - // 1. 处理系统消息 - const systemPrompt = assistant.prompt - // 2. 设置工具 - const { tools } = this.setupToolsConfig({ - mcpTools: mcpTools, - model, - enableToolUse: isSupportedToolUse(assistant) - }) - - // 3. 处理消息 - const sdkMessages: AwsBedrockSdkMessageParam[] = [] - if (typeof messages === 'string') { - sdkMessages.push({ role: 'user', content: [{ text: messages }] }) - } else { - for (const message of messages) { - sdkMessages.push(await this.convertMessageToSdkParam(message)) - } - } - - // 获取推理预算token(对所有支持推理的模型) - const budgetTokens = this.getBudgetToken(assistant, model) - - // 构建基础自定义参数 - const customParams: Record = - coreRequest.callType === 'chat' ? 
this.getCustomParameters(assistant) : {} - - // 根据模型类型添加 thinking 参数 - if (budgetTokens) { - const modelId = model.id.toLowerCase() - - if (modelId.includes('claude')) { - // Claude 模型使用 Anthropic 格式 - customParams.thinking = { type: 'enabled', budget_tokens: budgetTokens } - } else if (modelId.includes('gpt') || modelId.includes('openai')) { - // OpenAI 模型格式 - customParams.reasoning_effort = assistant?.settings?.reasoning_effort - } else if (modelId.includes('llama')) { - // Llama 模型格式 - customParams.thinking_mode = true - customParams.thinking_tokens = budgetTokens - } else if (modelId.includes('titan')) { - // Titan 模型格式 - customParams.thinking = { enabled: true } - customParams.thinking_tokens = budgetTokens - } else if (modelId.includes('cohere') || modelId.includes('command')) { - // Cohere 模型格式 - customParams.thinking = { enabled: true } - customParams.reasoning_tokens = budgetTokens - } - } - - const payload: AwsBedrockSdkParams = { - modelId: model.id, - messages: - isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0 - ? recursiveSdkMessages - : sdkMessages, - system: systemPrompt, - maxTokens: maxTokens || DEFAULT_MAX_TOKENS, - temperature: this.getTemperature(assistant, model), - topP: this.getTopP(assistant, model), - stream: streamOutput !== false, - tools: tools.length > 0 ? tools : undefined, - ...customParams - } - - const timeout = this.getTimeout(model) - return { payload, messages: sdkMessages, metadata: { timeout } } - } - } - } - - getResponseChunkTransformer(): ResponseChunkTransformer { - return () => { - let hasStartedText = false - let hasStartedThinking = false - let accumulatedJson = '' - const toolCalls: Record = {} - - return { - async transform(rawChunk: AwsBedrockSdkRawChunk, controller: TransformStreamDefaultController) { - logger.silly('Processing AWS Bedrock chunk:', rawChunk) - - if (typeof rawChunk === 'string') { - try { - rawChunk = JSON.parse(rawChunk) - } catch (error) { - logger.error('invalid chunk', { rawChunk, error }) - throw new Error(t('error.chat.chunk.non_json')) - } - } - - // 处理消息开始事件 - if (rawChunk.messageStart) { - controller.enqueue({ - type: ChunkType.TEXT_START - }) - hasStartedText = true - logger.debug('Message started') - } - - // 处理内容块开始事件 - 参考 Anthropic 的 content_block_start 处理 - if (rawChunk.contentBlockStart?.start?.toolUse) { - const toolUse = rawChunk.contentBlockStart.start.toolUse - const blockIndex = rawChunk.contentBlockStart.contentBlockIndex || 0 - toolCalls[blockIndex] = { - id: toolUse.toolUseId, // 设置 id 字段与 toolUseId 相同 - name: toolUse.name, - toolUseId: toolUse.toolUseId, - input: {} - } - logger.debug('Tool use started:', toolUse) - } - - // 处理内容块增量事件 - 参考 Anthropic 的 content_block_delta 处理 - if (rawChunk.contentBlockDelta?.delta?.toolUse?.input) { - const inputDelta = rawChunk.contentBlockDelta.delta.toolUse.input - accumulatedJson += inputDelta - } - - // 处理文本增量 - if (rawChunk.contentBlockDelta?.delta?.text) { - if (!hasStartedText) { - controller.enqueue({ - type: ChunkType.TEXT_START - }) - hasStartedText = true - } - - controller.enqueue({ - type: ChunkType.TEXT_DELTA, - text: rawChunk.contentBlockDelta.delta.text - } as TextDeltaChunk) - } - - // 处理thinking增量 - if ( - rawChunk.contentBlockDelta?.delta?.type === 'thinking_delta' && - rawChunk.contentBlockDelta?.delta?.thinking - ) { - if (!hasStartedThinking) { - controller.enqueue({ - type: ChunkType.THINKING_START - } as ThinkingStartChunk) - hasStartedThinking = true - } - - controller.enqueue({ - type: ChunkType.THINKING_DELTA, - 
text: rawChunk.contentBlockDelta.delta.thinking - } as ThinkingDeltaChunk) - } - - // 处理内容块停止事件 - 参考 Anthropic 的 content_block_stop 处理 - if (rawChunk.contentBlockStop) { - const blockIndex = rawChunk.contentBlockStop.contentBlockIndex || 0 - const toolCall = toolCalls[blockIndex] - if (toolCall && accumulatedJson) { - try { - toolCall.input = JSON.parse(accumulatedJson) - controller.enqueue({ - type: ChunkType.MCP_TOOL_CREATED, - tool_calls: [toolCall] - } as MCPToolCreatedChunk) - accumulatedJson = '' - } catch (error) { - logger.error('Error parsing tool call input:', error as Error) - } - } - } - - // 处理消息结束事件 - if (rawChunk.messageStop) { - // 从metadata中提取usage信息 - const usage = rawChunk.metadata?.usage || {} - - controller.enqueue({ - type: ChunkType.LLM_RESPONSE_COMPLETE, - response: { - usage: { - prompt_tokens: usage.inputTokens || 0, - completion_tokens: usage.outputTokens || 0, - total_tokens: (usage.inputTokens || 0) + (usage.outputTokens || 0) - } - } - }) - } - } - } - } - } - - public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): AwsBedrockSdkTool[] { - return mcpToolsToAwsBedrockTools(mcpTools) - } - - convertSdkToolCallToMcp(toolCall: AwsBedrockSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined { - return awsBedrockToolUseToMcpTool(mcpTools, toolCall) - } - - convertSdkToolCallToMcpToolResponse(toolCall: AwsBedrockSdkToolCall, mcpTool: MCPTool): ToolCallResponse { - return { - id: toolCall.id, - tool: mcpTool, - arguments: toolCall.input || {}, - status: 'pending', - toolCallId: toolCall.id - } - } - - override buildSdkMessages( - currentReqMessages: AwsBedrockSdkMessageParam[], - output: AwsBedrockSdkRawOutput | string | undefined, - toolResults: AwsBedrockSdkMessageParam[] - ): AwsBedrockSdkMessageParam[] { - const messages: AwsBedrockSdkMessageParam[] = [...currentReqMessages] - - if (typeof output === 'string') { - messages.push({ - role: 'assistant', - content: [{ text: output }] - }) - } - - if (toolResults.length > 0) { - messages.push(...toolResults) - } - - return messages - } - - override estimateMessageTokens(message: AwsBedrockSdkMessageParam): number { - if (typeof message.content === 'string') { - return estimateTextTokens(message.content) - } - const content = message.content - if (Array.isArray(content)) { - return content.reduce((total, item) => { - if (item.text) { - return total + estimateTextTokens(item.text) - } - return total - }, 0) - } - return 0 - } - - public convertMcpToolResponseToSdkMessageParam( - mcpToolResponse: MCPToolResponse, - resp: MCPCallToolResponse, - model: Model - ): AwsBedrockSdkMessageParam | undefined { - if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) { - // 使用专用的转换函数处理 toolUseId 情况 - return mcpToolCallResponseToAwsBedrockMessage(mcpToolResponse, resp, model) - } else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) { - return { - role: 'user', - content: [ - { - toolResult: { - toolUseId: mcpToolResponse.toolCallId, - content: resp.content - .map((item) => { - if (item.type === 'text') { - // 确保文本不为空,如果为空则提供默认文本 - return { text: item.text && item.text.trim() ? 
item.text : 'No text content' } - } - if (item.type === 'image' && item.data) { - const awsImage = convertBase64ImageToAwsBedrockFormat(item.data, item.mimeType) - if (awsImage) { - return { image: awsImage } - } else { - // 如果转换失败,返回描述性文本 - return { text: `[Image: ${item.mimeType || 'unknown format'}]` } - } - } - return { text: JSON.stringify(item) } - }) - .filter((content) => content !== null) - } - } - ] - } - } - return undefined - } - - extractMessagesFromSdkPayload(sdkPayload: AwsBedrockSdkParams): AwsBedrockSdkMessageParam[] { - return sdkPayload.messages || [] - } - - /** - * 获取 AWS Bedrock 的推理工作量预算token - * @param assistant - The assistant - * @param model - The model - * @returns The budget tokens for reasoning effort - */ - private getBudgetToken(assistant: Assistant, model: Model): number | undefined { - try { - if (!isReasoningModel(model)) { - return undefined - } - - const { maxTokens } = getAssistantSettings(assistant) - const reasoningEffort = assistant?.settings?.reasoning_effort - - if (reasoningEffort === undefined) { - return undefined - } - - const effortRatio = EFFORT_RATIO[reasoningEffort] - const tokenLimits = findTokenLimit(model.id) - - if (tokenLimits) { - // 使用模型特定的 token 限制 - const budgetTokens = Math.max( - 1024, - Math.floor( - Math.min( - (tokenLimits.max - tokenLimits.min) * effortRatio + tokenLimits.min, - (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio - ) - ) - ) - return budgetTokens - } else { - // 对于没有特定限制的模型,使用简化计算 - const budgetTokens = Math.max(1024, Math.floor((maxTokens || DEFAULT_MAX_TOKENS) * effortRatio)) - return budgetTokens - } - } catch (error) { - logger.warn('Failed to calculate budget tokens for reasoning effort:', error as Error) - return undefined - } - } -} diff --git a/src/renderer/src/aiCore/legacy/clients/cherryai/CherryAiAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/cherryai/CherryAiAPIClient.ts deleted file mode 100644 index b72e0a8829..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/cherryai/CherryAiAPIClient.ts +++ /dev/null @@ -1,51 +0,0 @@ -import type OpenAI from '@cherrystudio/openai' -import type { Provider } from '@renderer/types' -import type { OpenAISdkParams, OpenAISdkRawOutput } from '@renderer/types/sdk' - -import { OpenAIAPIClient } from '../openai/OpenAIApiClient' - -export class CherryAiAPIClient extends OpenAIAPIClient { - constructor(provider: Provider) { - super(provider) - } - - override async createCompletions( - payload: OpenAISdkParams, - options?: OpenAI.RequestOptions - ): Promise { - const sdk = await this.getSdkInstance() - options = options || {} - options.headers = options.headers || {} - - const signature = await window.api.cherryai.generateSignature({ - method: 'POST', - path: '/chat/completions', - query: '', - body: payload - }) - - options.headers = { - ...options.headers, - ...signature - } - - // @ts-ignore - SDK参数可能有额外的字段 - return await sdk.chat.completions.create(payload, options) - } - - override getClientCompatibilityType(): string[] { - return ['CherryAiAPIClient'] - } - - public async listModels(): Promise { - const models = ['glm-4.5-flash', 'Qwen/Qwen3-8B'] - - const created = Date.now() - return models.map((id) => ({ - id, - owned_by: 'cherryai', - object: 'model' as const, - created - })) - } -} diff --git a/src/renderer/src/aiCore/legacy/clients/gemini/GeminiAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/gemini/GeminiAPIClient.ts deleted file mode 100644 index d7f14326f6..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/gemini/GeminiAPIClient.ts +++ 
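// A small sketch of the request-signing flow in the removed CherryAiAPIClient
// above: a signature computed over method, path, and body is merged into the
// outgoing headers. signRequest is a stand-in for the real signing bridge.
async function withSignedHeaders(
  payload: unknown,
  signRequest: (req: { method: string; path: string; query: string; body: unknown }) => Promise<Record<string, string>>,
  headers: Record<string, string> = {}
): Promise<Record<string, string>> {
  const signature = await signRequest({ method: 'POST', path: '/chat/completions', query: '', body: payload })
  return { ...headers, ...signature }
}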
/dev/null @@ -1,841 +0,0 @@ -import type { - Content, - File, - FunctionCall, - GenerateContentConfig, - GenerateImagesConfig, - Model as GeminiModel, - Part, - SafetySetting, - SendMessageParameters, - ThinkingConfig, - Tool -} from '@google/genai' -import { createPartFromUri, GoogleGenAI, HarmBlockThreshold, HarmCategory, Modality } from '@google/genai' -import { loggerService } from '@logger' -import { nanoid } from '@reduxjs/toolkit' -import { - findTokenLimit, - GEMINI_FLASH_MODEL_REGEX, - isGemmaModel, - isSupportedThinkingTokenGeminiModel, - isVisionModel -} from '@renderer/config/models' -import { estimateTextTokens } from '@renderer/services/TokenService' -import type { - Assistant, - FileMetadata, - FileUploadResponse, - GenerateImageParams, - MCPCallToolResponse, - MCPTool, - MCPToolResponse, - Model, - Provider, - ToolCallResponse -} from '@renderer/types' -import { EFFORT_RATIO, FileTypes, WebSearchSource } from '@renderer/types' -import type { LLMWebSearchCompleteChunk, TextStartChunk, ThinkingStartChunk } from '@renderer/types/chunk' -import { ChunkType } from '@renderer/types/chunk' -import type { Message } from '@renderer/types/newMessage' -import type { - GeminiOptions, - GeminiSdkMessageParam, - GeminiSdkParams, - GeminiSdkRawChunk, - GeminiSdkRawOutput, - GeminiSdkToolCall -} from '@renderer/types/sdk' -import { isToolUseModeFunction } from '@renderer/utils/assistant' -import { - geminiFunctionCallToMcpTool, - isSupportedToolUse, - mcpToolCallResponseToGeminiMessage, - mcpToolsToGeminiTools -} from '@renderer/utils/mcp-tools' -import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' -import { defaultTimeout, MB } from '@shared/config/constant' -import { getTrailingApiVersion, withoutTrailingApiVersion } from '@shared/utils' -import { t } from 'i18next' - -import type { GenericChunk } from '../../middleware/schemas' -import { BaseApiClient } from '../BaseApiClient' -import type { RequestTransformer, ResponseChunkTransformer } from '../types' - -const logger = loggerService.withContext('GeminiAPIClient') - -export class GeminiAPIClient extends BaseApiClient< - GoogleGenAI, - GeminiSdkParams, - GeminiSdkRawOutput, - GeminiSdkRawChunk, - GeminiSdkMessageParam, - GeminiSdkToolCall, - Tool -> { - constructor(provider: Provider) { - super(provider) - } - - override async createCompletions(payload: GeminiSdkParams, options?: GeminiOptions): Promise { - const sdk = await this.getSdkInstance() - const { model, history, ...rest } = payload - const realPayload: Omit = { - ...rest, - config: { - ...rest.config, - abortSignal: options?.signal, - httpOptions: { - ...rest.config?.httpOptions, - timeout: options?.timeout - } - } - } satisfies SendMessageParameters - - const streamOutput = options?.streamOutput - - const chat = sdk.chats.create({ - model: model, - history: history - }) - - if (streamOutput) { - const stream = chat.sendMessageStream(realPayload) - return stream - } else { - const response = await chat.sendMessage(realPayload) - return response - } - } - - override async generateImage(generateImageParams: GenerateImageParams): Promise { - const sdk = await this.getSdkInstance() - try { - const { model, prompt, imageSize, batchSize, signal } = generateImageParams - const config: GenerateImagesConfig = { - numberOfImages: batchSize, - aspectRatio: imageSize, - abortSignal: signal, - httpOptions: { - timeout: defaultTimeout - } - } - const response = await sdk.models.generateImages({ - model: model, - prompt, - config - }) - - if 
(!response.generatedImages || response.generatedImages.length === 0) { - return [] - } - - const images = response.generatedImages - .filter((image) => image.image?.imageBytes) - .map((image) => { - const dataPrefix = `data:${image.image?.mimeType || 'image/png'};base64,` - return dataPrefix + image.image?.imageBytes - }) - // console.log(response?.generatedImages?.[0]?.image?.imageBytes); - return images - } catch (error) { - logger.error('[generateImage] error:', error as Error) - throw error - } - } - - override async getEmbeddingDimensions(model: Model): Promise { - const sdk = await this.getSdkInstance() - - const data = await sdk.models.embedContent({ - model: model.id, - contents: [{ role: 'user', parts: [{ text: 'hi' }] }] - }) - return data.embeddings?.[0]?.values?.length || 0 - } - - override async listModels(): Promise { - const sdk = await this.getSdkInstance() - const response = await sdk.models.list() - const models: GeminiModel[] = [] - for await (const model of response) { - models.push(model) - } - return models - } - - override getBaseURL(): string { - return withoutTrailingApiVersion(super.getBaseURL()) - } - - override async getSdkInstance() { - if (this.sdkInstance) { - return this.sdkInstance - } - - const apiVersion = this.getApiVersion() - - this.sdkInstance = new GoogleGenAI({ - vertexai: false, - apiKey: this.apiKey, - apiVersion, - httpOptions: { - baseUrl: this.getBaseURL(), - apiVersion, - headers: { - ...this.provider.extra_headers - } - } - }) - - return this.sdkInstance - } - - protected getApiVersion(): string { - if (this.provider.isVertex) { - return 'v1' - } - - // Extract trailing API version from the URL - const trailingVersion = getTrailingApiVersion(this.provider.apiHost || '') - if (trailingVersion) { - return trailingVersion - } - - return '' - } - - /** - * Handle a PDF file - * @param file - The file - * @returns The part - */ - private async handlePdfFile(file: FileMetadata): Promise { - const smallFileSize = 20 * MB - const isSmallFile = file.size < smallFileSize - - if (isSmallFile) { - const { data, mimeType } = await this.base64File(file) - return { - inlineData: { - data, - mimeType - } - } - } - - // Retrieve file from Gemini uploaded files - const fileMetadata: FileUploadResponse = await window.api.fileService.retrieve(this.provider, file.id) - - if (fileMetadata.status === 'success') { - const remoteFile = fileMetadata.originalFile?.file as File - return createPartFromUri(remoteFile.uri!, remoteFile.mimeType!) - } - - // If file is not found, upload it to Gemini - const result = await window.api.fileService.upload(this.provider, file) - const remoteFile = result.originalFile - if (!remoteFile) { - throw new Error('File upload failed, please try again') - } - if (remoteFile.type === 'gemini') { - const file = remoteFile.file - if (!file.uri) { - throw new Error('File URI is required but not found') - } - if (!file.mimeType) { - throw new Error('File MIME type is required but not found') - } - return createPartFromUri(file.uri, file.mimeType) - } else { - throw new Error('Unsupported file type for Gemini API') - } - } - - /** - * Get the message contents - * @param message - The message - * @returns The message contents - */ - private async convertMessageToSdkParam(message: Message): Promise { - const role = message.role === 'user' ? 
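// The handlePdfFile method above picks between inline base64 and the Gemini Files API
// on a 20 MB threshold. A minimal sketch of that decision in isolation (the base64Of
// and uploadToGemini helpers are hypothetical, not part of the removed client):
import { createPartFromUri, type Part } from '@google/genai'

const INLINE_LIMIT_BYTES = 20 * 1024 * 1024

async function pdfToPart(file: { size: number; path: string }): Promise<Part> {
  if (file.size < INLINE_LIMIT_BYTES) {
    // Small PDFs travel inside the request body as base64.
    return { inlineData: { data: await base64Of(file.path), mimeType: 'application/pdf' } }
  }
  // Larger PDFs are uploaded once and then referenced by URI.
  const uploaded = await uploadToGemini(file.path)
  return createPartFromUri(uploaded.uri, uploaded.mimeType)
}

// Hypothetical helpers assumed above:
declare function base64Of(path: string): Promise<string>
declare function uploadToGemini(path: string): Promise<{ uri: string; mimeType: string }>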
'user' : 'model' - const { textContent, imageContents } = await this.getMessageContent(message) - const parts: Part[] = [{ text: textContent }] - - if (imageContents.length > 0) { - for (const imageContent of imageContents) { - const image = await window.api.file.base64Image(imageContent.fileId + imageContent.fileExt) - parts.push({ - inlineData: { - data: image.base64, - mimeType: image.mime - } satisfies Part['inlineData'] - }) - } - } - - // Add any generated images from previous responses - const imageBlocks = findImageBlocks(message) - for (const imageBlock of imageBlocks) { - if ( - imageBlock.metadata?.generateImageResponse?.images && - imageBlock.metadata.generateImageResponse.images.length > 0 - ) { - for (const imageUrl of imageBlock.metadata.generateImageResponse.images) { - if (imageUrl && imageUrl.startsWith('data:')) { - // Extract base64 data and mime type from the data URL - const matches = imageUrl.match(/^data:(.+);base64,(.*)$/) - if (matches && matches.length === 3) { - const mimeType = matches[1] - const base64Data = matches[2] - parts.push({ - inlineData: { - data: base64Data, - mimeType: mimeType - } satisfies Part['inlineData'] - }) - } - } - } - } - const file = imageBlock.file - if (file) { - const base64Data = await window.api.file.base64Image(file.id + file.ext) - parts.push({ - inlineData: { - data: base64Data.base64, - mimeType: base64Data.mime - } satisfies Part['inlineData'] - }) - } - } - - const fileBlocks = findFileBlocks(message) - for (const fileBlock of fileBlocks) { - const file = fileBlock.file - if (file.type === FileTypes.IMAGE) { - const base64Data = await window.api.file.base64Image(file.id + file.ext) - parts.push({ - inlineData: { - data: base64Data.base64, - mimeType: base64Data.mime - } satisfies Part['inlineData'] - }) - } - - if (file.ext === '.pdf') { - parts.push(await this.handlePdfFile(file)) - continue - } - if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { - const fileContent = await (await window.api.file.read(file.id + file.ext, true)).trim() - parts.push({ - text: file.origin_name + '\n' + fileContent - }) - } - } - - return { - role, - parts: parts - } - } - - // @ts-ignore unused - private async getImageFileContents(message: Message): Promise { - const role = message.role === 'user' ? 
'user' : 'model' - const content = getMainTextContent(message) - const parts: Part[] = [{ text: content }] - const imageBlocks = findImageBlocks(message) - for (const imageBlock of imageBlocks) { - if ( - imageBlock.metadata?.generateImageResponse?.images && - imageBlock.metadata.generateImageResponse.images.length > 0 - ) { - for (const imageUrl of imageBlock.metadata.generateImageResponse.images) { - if (imageUrl && imageUrl.startsWith('data:')) { - // Extract base64 data and mime type from the data URL - const matches = imageUrl.match(/^data:(.+);base64,(.*)$/) - if (matches && matches.length === 3) { - const mimeType = matches[1] - const base64Data = matches[2] - parts.push({ - inlineData: { - data: base64Data, - mimeType: mimeType - } satisfies Part['inlineData'] - }) - } - } - } - } - const file = imageBlock.file - if (file) { - const base64Data = await window.api.file.base64Image(file.id + file.ext) - parts.push({ - inlineData: { - data: base64Data.base64, - mimeType: base64Data.mime - } satisfies Part['inlineData'] - }) - } - } - return { - role, - parts: parts - } - } - - /** - * Get the safety settings - * @returns The safety settings - */ - private getSafetySettings(): SafetySetting[] { - const safetyThreshold = HarmBlockThreshold.OFF - - return [ - { - category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, - threshold: safetyThreshold - }, - { - category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - threshold: safetyThreshold - }, - { - category: HarmCategory.HARM_CATEGORY_HARASSMENT, - threshold: safetyThreshold - }, - { - category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - threshold: safetyThreshold - }, - { - category: HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY, - threshold: HarmBlockThreshold.BLOCK_NONE - } - ] - } - - /** - * Get the reasoning effort for the assistant - * @param assistant - The assistant - * @param model - The model - * @returns The reasoning effort - */ - private getBudgetToken(assistant: Assistant, model: Model) { - if (isSupportedThinkingTokenGeminiModel(model)) { - const reasoningEffort = assistant?.settings?.reasoning_effort - - // 如果thinking_budget是undefined,不思考 - if (reasoningEffort === undefined) { - return GEMINI_FLASH_MODEL_REGEX.test(model.id) - ? { - thinkingConfig: { - thinkingBudget: 0 - } - } - : {} - } - - if (reasoningEffort === 'auto') { - return { - thinkingConfig: { - includeThoughts: true, - thinkingBudget: -1 - } - } - } - const effortRatio = EFFORT_RATIO[reasoningEffort] - const { min, max } = findTokenLimit(model.id) || { min: 0, max: 0 } - // 计算 budgetTokens,确保不低于 min - const budget = Math.floor((max - min) * effortRatio + min) - - return { - thinkingConfig: { - ...(budget > 0 ? { thinkingBudget: budget } : {}), - includeThoughts: true - } satisfies ThinkingConfig - } - } - - return {} - } - - private getGenerateImageParameter(): Partial { - return { - systemInstruction: undefined, - responseModalities: [Modality.TEXT, Modality.IMAGE] - } - } - - getRequestTransformer(): RequestTransformer { - return { - transform: async ( - coreRequest, - assistant, - model, - isRecursiveCall, - recursiveSdkMessages - ): Promise<{ - payload: GeminiSdkParams - messages: GeminiSdkMessageParam[] - metadata: Record - }> => { - const { messages, mcpTools, maxTokens, enableWebSearch, enableUrlContext, enableGenerateImage } = coreRequest - // 1. 处理系统消息 - const systemInstruction = assistant.prompt - - // 2. 
Set up tools
-        const { tools } = this.setupToolsConfig({
-          mcpTools,
-          model,
-          enableToolUse: isSupportedToolUse(assistant)
-        })
-
-        let messageContents: Content = { role: 'user', parts: [] } // Initialize messageContents
-        const history: Content[] = []
-        // 3. Process user messages
-        if (typeof messages === 'string') {
-          messageContents = {
-            role: 'user',
-            parts: [{ text: messages }]
-          }
-        } else {
-          const userLastMessage = messages.pop()
-          if (userLastMessage) {
-            messageContents = await this.convertMessageToSdkParam(userLastMessage)
-            for (const message of messages) {
-              history.push(await this.convertMessageToSdkParam(message))
-            }
-            messages.push(userLastMessage)
-          }
-        }
-
-        if (tools.length === 0 || !isToolUseModeFunction(assistant)) {
-          if (enableWebSearch) {
-            tools.push({
-              googleSearch: {}
-            })
-          }
-
-          if (enableUrlContext) {
-            tools.push({
-              urlContext: {}
-            })
-          }
-        } else if (enableWebSearch || enableUrlContext) {
-          logger.warn('Native tools cannot be used with function calling for now.')
-        }
-
-        if (isGemmaModel(model) && assistant.prompt) {
-          const isFirstMessage = history.length === 0
-          if (isFirstMessage && messageContents) {
-            const userMessageText =
-              messageContents.parts && messageContents.parts.length > 0 ? (messageContents.parts[0].text ?? '') : ''
-            const systemMessage = [
-              {
-                text:
-                  '<start_of_turn>user\n' +
-                  systemInstruction +
-                  '<end_of_turn>\n' +
-                  '<start_of_turn>user\n' +
-                  userMessageText +
-                  '<end_of_turn>'
-              }
-            ] satisfies Part[]
-            if (messageContents && messageContents.parts) {
-              messageContents.parts[0] = systemMessage[0]
-            }
-          }
-        }
-
-        const newHistory =
-          isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
-            ? recursiveSdkMessages.slice(0, recursiveSdkMessages.length - 1)
-            : history
-
-        const newMessageContents =
-          isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
-            ? recursiveSdkMessages[recursiveSdkMessages.length - 1]
-            : messageContents
-
-        const generateContentConfig: GenerateContentConfig = {
-          safetySettings: this.getSafetySettings(),
-          systemInstruction: isGemmaModel(model) ? undefined : systemInstruction,
-          temperature: this.getTemperature(assistant, model),
-          topP: this.getTopP(assistant, model),
-          maxOutputTokens: maxTokens,
-          tools: tools,
-          ...(enableGenerateImage ? this.getGenerateImageParameter() : {}),
-          ...this.getBudgetToken(assistant, model),
-          // Apply custom parameters only for chat calls, so translation, summarization and other flows stay unaffected
-          // Note: user-defined parameters must always override the others
-          ...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
-        }
-
-        const param: GeminiSdkParams = {
-          model: model.id,
-          config: generateContentConfig,
-          history: newHistory,
-          message: newMessageContents.parts!
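// Why the transform above pops the last user message: the @google/genai chat API takes
// prior turns once at chats.create() time and only the newest turn in sendMessage().
// A standalone sketch of that split, assuming the turns are already Content objects:
import { GoogleGenAI, type Content } from '@google/genai'

async function sendWithHistory(ai: GoogleGenAI, modelId: string, turns: Content[]) {
  const history = turns.slice(0, -1)
  const last = turns[turns.length - 1]
  const chat = ai.chats.create({ model: modelId, history })
  // sendMessage() takes only the Parts of the newest turn.
  return chat.sendMessage({ message: last.parts ?? [] })
}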
- } - - return { - payload: param, - messages: [messageContents], - metadata: {} - } - } - } - } - - getResponseChunkTransformer(): ResponseChunkTransformer { - const toolCalls: FunctionCall[] = [] - let isFirstTextChunk = true - let isFirstThinkingChunk = true - return () => ({ - async transform(chunk: GeminiSdkRawChunk, controller: TransformStreamDefaultController) { - logger.silly('chunk', chunk) - if (typeof chunk === 'string') { - try { - chunk = JSON.parse(chunk) - } catch (error) { - logger.error('invalid chunk', { chunk, error }) - throw new Error(t('error.chat.chunk.non_json')) - } - } - if (chunk.candidates && chunk.candidates.length > 0) { - for (const candidate of chunk.candidates) { - if (candidate.content) { - candidate.content.parts?.forEach((part) => { - const text = part.text || '' - if (part.thought) { - if (isFirstThinkingChunk) { - controller.enqueue({ - type: ChunkType.THINKING_START - } satisfies ThinkingStartChunk) - isFirstThinkingChunk = false - } - controller.enqueue({ - type: ChunkType.THINKING_DELTA, - text: text - }) - } else if (part.text) { - if (isFirstTextChunk) { - controller.enqueue({ - type: ChunkType.TEXT_START - } satisfies TextStartChunk) - isFirstTextChunk = false - } - controller.enqueue({ - type: ChunkType.TEXT_DELTA, - text: text - }) - } else if (part.inlineData) { - controller.enqueue({ - type: ChunkType.IMAGE_COMPLETE, - image: { - type: 'base64', - images: [ - part.inlineData?.data?.startsWith('data:') - ? part.inlineData?.data - : `data:${part.inlineData?.mimeType || 'image/png'};base64,${part.inlineData?.data}` - ] - } - }) - } else if (part.functionCall) { - toolCalls.push(part.functionCall) - } - }) - } - - if (candidate.finishReason) { - if (candidate.groundingMetadata) { - controller.enqueue({ - type: ChunkType.LLM_WEB_SEARCH_COMPLETE, - llm_web_search: { - results: candidate.groundingMetadata, - source: WebSearchSource.GEMINI - } - } satisfies LLMWebSearchCompleteChunk) - } - if (toolCalls.length > 0) { - controller.enqueue({ - type: ChunkType.MCP_TOOL_CREATED, - tool_calls: [...toolCalls] - }) - toolCalls.length = 0 - } - controller.enqueue({ - type: ChunkType.LLM_RESPONSE_COMPLETE, - response: { - usage: { - prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0, - completion_tokens: - (chunk.usageMetadata?.totalTokenCount || 0) - (chunk.usageMetadata?.promptTokenCount || 0), - total_tokens: chunk.usageMetadata?.totalTokenCount || 0 - } - } - }) - } - } - } - - if (toolCalls.length > 0) { - controller.enqueue({ - type: ChunkType.MCP_TOOL_CREATED, - tool_calls: toolCalls - }) - } - } - }) - } - - public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): Tool[] { - return mcpToolsToGeminiTools(mcpTools) - } - - public convertSdkToolCallToMcp(toolCall: GeminiSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined { - return geminiFunctionCallToMcpTool(mcpTools, toolCall) - } - - public convertSdkToolCallToMcpToolResponse(toolCall: GeminiSdkToolCall, mcpTool: MCPTool): ToolCallResponse { - const parsedArgs = (() => { - try { - return typeof toolCall.args === 'string' ? 
JSON.parse(toolCall.args) : toolCall.args - } catch { - return toolCall.args - } - })() - - return { - id: toolCall.id || nanoid(), - toolCallId: toolCall.id, - tool: mcpTool, - arguments: parsedArgs, - status: 'pending' - } satisfies ToolCallResponse - } - - public convertMcpToolResponseToSdkMessageParam( - mcpToolResponse: MCPToolResponse, - resp: MCPCallToolResponse, - model: Model - ): GeminiSdkMessageParam | undefined { - if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) { - return mcpToolCallResponseToGeminiMessage(mcpToolResponse, resp, isVisionModel(model)) - } else if ('toolCallId' in mcpToolResponse) { - return { - role: 'user', - parts: [ - { - functionResponse: { - id: mcpToolResponse.toolCallId, - name: mcpToolResponse.tool.id, - response: { - output: !resp.isError ? resp.content : undefined, - error: resp.isError ? resp.content : undefined - } - } - } - ] - } satisfies Content - } - return - } - - public buildSdkMessages( - currentReqMessages: Content[], - output: string, - toolResults: Content[], - toolCalls: FunctionCall[] - ): Content[] { - const parts: Part[] = [] - const modelParts: Part[] = [] - if (output) { - modelParts.push({ - text: output - }) - } - - toolCalls.forEach((toolCall) => { - modelParts.push({ - functionCall: toolCall - }) - }) - - parts.push( - ...toolResults - .map((ts) => ts.parts) - .flat() - .filter((p) => p !== undefined) - ) - - const userMessage: Content = { - role: 'user', - parts: [] - } - - if (modelParts.length > 0) { - currentReqMessages.push({ - role: 'model', - parts: modelParts - }) - } - if (parts.length > 0) { - userMessage.parts?.push(...parts) - currentReqMessages.push(userMessage) - } - - return currentReqMessages - } - - override estimateMessageTokens(message: GeminiSdkMessageParam): number { - return ( - message.parts?.reduce((acc, part) => { - if (part.text) { - return acc + estimateTextTokens(part.text) - } - if (part.functionCall) { - return acc + estimateTextTokens(JSON.stringify(part.functionCall)) - } - if (part.functionResponse) { - return acc + estimateTextTokens(JSON.stringify(part.functionResponse.response)) - } - if (part.inlineData) { - return acc + estimateTextTokens(part.inlineData.data || '') - } - if (part.fileData) { - return acc + estimateTextTokens(part.fileData.fileUri || '') - } - return acc - }, 0) || 0 - ) - } - - public extractMessagesFromSdkPayload(sdkPayload: GeminiSdkParams): GeminiSdkMessageParam[] { - const messageParam: GeminiSdkMessageParam = { - role: 'user', - parts: [] - } - if (Array.isArray(sdkPayload.message)) { - sdkPayload.message.forEach((part) => { - if (typeof part === 'string') { - messageParam.parts?.push({ text: part }) - } else if (typeof part === 'object') { - messageParam.parts?.push(part) - } - }) - } - return [...(sdkPayload.history || []), messageParam] - } - - private async base64File(file: FileMetadata) { - const { data } = await window.api.file.base64File(file.id + file.ext) - return { - data, - mimeType: 'application/pdf' - } - } -} diff --git a/src/renderer/src/aiCore/legacy/clients/gemini/VertexAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/gemini/VertexAPIClient.ts deleted file mode 100644 index fb371d9ae5..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/gemini/VertexAPIClient.ts +++ /dev/null @@ -1,143 +0,0 @@ -import { GoogleGenAI } from '@google/genai' -import { loggerService } from '@logger' -import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI' -import type { Model, Provider, VertexProvider } from 
'@renderer/types' -import { isVertexProvider } from '@renderer/utils/provider' -import { isEmpty } from 'lodash' - -import { AnthropicVertexClient } from '../anthropic/AnthropicVertexClient' -import { GeminiAPIClient } from './GeminiAPIClient' - -const logger = loggerService.withContext('VertexAPIClient') -export class VertexAPIClient extends GeminiAPIClient { - private authHeaders?: Record - private authHeadersExpiry?: number - private anthropicVertexClient: AnthropicVertexClient - private vertexProvider: VertexProvider - - constructor(provider: Provider) { - super(provider) - // 检查 VertexAI 配置 - if (!isVertexAIConfigured()) { - throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.') - } - this.anthropicVertexClient = new AnthropicVertexClient(provider) - // 如果传入的是普通 Provider,转换为 VertexProvider - if (isVertexProvider(provider)) { - this.vertexProvider = provider - } else { - this.vertexProvider = createVertexProvider(provider) - } - } - - override getClientCompatibilityType(model?: Model): string[] { - if (!model) { - return [this.constructor.name] - } - - const actualClient = this.getClient(model) - if (actualClient === this) { - return [this.constructor.name] - } - - return actualClient.getClientCompatibilityType(model) - } - - public getClient(model: Model) { - if (model.id.includes('claude')) { - return this.anthropicVertexClient - } - return this - } - - private formatApiHost(baseUrl: string) { - if (baseUrl.endsWith('/v1/')) { - baseUrl = baseUrl.slice(0, -4) - } else if (baseUrl.endsWith('/v1')) { - baseUrl = baseUrl.slice(0, -3) - } - return baseUrl - } - - override getBaseURL() { - return this.formatApiHost(this.provider.apiHost) - } - - override async getSdkInstance() { - if (this.sdkInstance) { - return this.sdkInstance - } - - const { googleCredentials, project, location } = this.vertexProvider - - if (!googleCredentials.privateKey || !googleCredentials.clientEmail || !project || !location) { - throw new Error('Vertex AI settings are not configured') - } - - const authHeaders = await this.getServiceAccountAuthHeaders() - - this.sdkInstance = new GoogleGenAI({ - vertexai: true, - project: project, - location: location, - httpOptions: { - apiVersion: this.getApiVersion(), - headers: authHeaders, - baseUrl: isEmpty(this.getBaseURL()) ? 
undefined : this.getBaseURL() - } - }) - - return this.sdkInstance - } - - /** - * 获取认证头,如果配置了 service account 则从主进程获取 - */ - private async getServiceAccountAuthHeaders(): Promise | undefined> { - const { googleCredentials, project } = this.vertexProvider - - // 检查是否配置了 service account - if (!googleCredentials.privateKey || !googleCredentials.clientEmail || !project) { - return undefined - } - - // 检查是否已有有效的认证头(提前 5 分钟过期) - const now = Date.now() - if (this.authHeaders && this.authHeadersExpiry && this.authHeadersExpiry - now > 5 * 60 * 1000) { - return this.authHeaders - } - - try { - // 从主进程获取认证头 - this.authHeaders = await window.api.vertexAI.getAuthHeaders({ - projectId: project, - serviceAccount: { - privateKey: googleCredentials.privateKey, - clientEmail: googleCredentials.clientEmail - } - }) - - // 设置过期时间(通常认证头有效期为 1 小时) - this.authHeadersExpiry = now + 60 * 60 * 1000 - - return this.authHeaders - } catch (error: any) { - logger.error('Failed to get auth headers:', error) - throw new Error(`Service Account authentication failed: ${error.message}`) - } - } - - /** - * 清理认证缓存并重新初始化 - */ - clearAuthCache(): void { - this.authHeaders = undefined - this.authHeadersExpiry = undefined - - const { googleCredentials, project } = this.vertexProvider - - if (project && googleCredentials.clientEmail) { - window.api.vertexAI.clearAuthCache(project, googleCredentials.clientEmail) - } - } -} diff --git a/src/renderer/src/aiCore/legacy/clients/index.ts b/src/renderer/src/aiCore/legacy/clients/index.ts deleted file mode 100644 index f364dbcee6..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/index.ts +++ /dev/null @@ -1,8 +0,0 @@ -export * from './ApiClientFactory' -export * from './BaseApiClient' -export * from './types' - -// Export specific clients from subdirectories -export * from './anthropic/AnthropicAPIClient' -export * from './openai/OpenAIApiClient' -export * from './openai/OpenAIResponseAPIClient' diff --git a/src/renderer/src/aiCore/legacy/clients/newapi/NewAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/newapi/NewAPIClient.ts deleted file mode 100644 index f3e04e0d55..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/newapi/NewAPIClient.ts +++ /dev/null @@ -1,110 +0,0 @@ -import { loggerService } from '@logger' -import { isSupportedModel } from '@renderer/config/models' -import type { Model, Provider } from '@renderer/types' -import type { NewApiModel } from '@renderer/types/sdk' - -import { AnthropicAPIClient } from '../anthropic/AnthropicAPIClient' -import type { BaseApiClient } from '../BaseApiClient' -import { GeminiAPIClient } from '../gemini/GeminiAPIClient' -import { MixedBaseAPIClient } from '../MixedBaseApiClient' -import { OpenAIAPIClient } from '../openai/OpenAIApiClient' -import { OpenAIResponseAPIClient } from '../openai/OpenAIResponseAPIClient' - -const logger = loggerService.withContext('NewAPIClient') - -export class NewAPIClient extends MixedBaseAPIClient { - // 使用联合类型而不是any,保持类型安全 - protected clients: Map = - new Map() - protected defaultClient: OpenAIAPIClient - protected currentClient: BaseApiClient - - constructor(provider: Provider) { - super(provider) - - const claudeClient = new AnthropicAPIClient(provider) - const geminiClient = new GeminiAPIClient(provider) - const openaiClient = new OpenAIAPIClient(provider) - const openaiResponseClient = new OpenAIResponseAPIClient(provider) - - this.clients.set('claude', claudeClient) - this.clients.set('gemini', geminiClient) - this.clients.set('openai', openaiClient) - 
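// The constructor above registers one sub-client per protocol family; getClient()
// later routes on model.endpoint_type. The same dispatch as a bare lookup table
// (illustrative only; the keys mirror the registrations around this constructor):
type EndpointType = 'anthropic' | 'gemini' | 'openai' | 'openai-response' | 'image-generation'

const clientKeyByEndpoint: Record<EndpointType, string> = {
  anthropic: 'claude',
  gemini: 'gemini',
  'openai-response': 'openai-response',
  openai: 'openai',
  'image-generation': 'openai' // image generation rides the OpenAI-compatible client
}

function resolveClientKey(endpointType?: string): string {
  const key = clientKeyByEndpoint[endpointType as EndpointType]
  if (!key) throw new Error('Invalid model endpoint type: ' + endpointType)
  return key
}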
this.clients.set('openai-response', openaiResponseClient) - - // 设置默认client - this.defaultClient = openaiClient - this.currentClient = this.defaultClient as BaseApiClient - } - - override getBaseURL(): string { - if (!this.currentClient) { - return this.provider.apiHost - } - return this.currentClient.getBaseURL() - } - - /** - * 根据模型获取合适的client - */ - protected getClient(model: Model): BaseApiClient { - if (!model.endpoint_type) { - throw new Error('Model endpoint type is not defined') - } - - if (model.endpoint_type === 'anthropic') { - const client = this.clients.get('claude') - if (!client || !this.isValidClient(client)) { - throw new Error('Failed to get claude client') - } - return client - } - - if (model.endpoint_type === 'openai-response') { - const client = this.clients.get('openai-response') - if (!client || !this.isValidClient(client)) { - throw new Error('Failed to get openai-response client') - } - return client - } - - if (model.endpoint_type === 'gemini') { - const client = this.clients.get('gemini') - if (!client || !this.isValidClient(client)) { - throw new Error('Failed to get gemini client') - } - return client - } - - if (model.endpoint_type === 'openai' || model.endpoint_type === 'image-generation') { - const client = this.clients.get('openai') - if (!client || !this.isValidClient(client)) { - throw new Error('Failed to get openai client') - } - return client - } - - throw new Error('Invalid model endpoint type: ' + model.endpoint_type) - } - - override async listModels(): Promise { - try { - const sdk = await this.defaultClient.getSdkInstance() - // Explicitly type the expected response shape so that `data` is recognised. - const response = await sdk.request<{ data: NewApiModel[] }>({ - method: 'get', - path: '/models' - }) - const models: NewApiModel[] = response.data ?? 
[] - - models.forEach((model) => { - model.id = model.id.trim() - }) - - return models.filter(isSupportedModel) - } catch (error) { - logger.error('Error listing models:', error as Error) - return [] - } - } -} diff --git a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts deleted file mode 100644 index 73a5bed4fe..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts +++ /dev/null @@ -1,1104 +0,0 @@ -import type { AzureOpenAI } from '@cherrystudio/openai' -import type OpenAI from '@cherrystudio/openai' -import type { - ChatCompletionContentPart, - ChatCompletionContentPartRefusal, - ChatCompletionTool -} from '@cherrystudio/openai/resources' -import { loggerService } from '@logger' -import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' -import { - findTokenLimit, - GEMINI_FLASH_MODEL_REGEX, - getModelSupportedReasoningEffortOptions, - isDeepSeekHybridInferenceModel, - isDoubaoThinkingAutoModel, - isGPT5SeriesModel, - isGrokReasoningModel, - isNotSupportSystemMessageModel, - isOpenAIDeepResearchModel, - isOpenAIOpenWeightModel, - isOpenAIReasoningModel, - isQwenAlwaysThinkModel, - isQwenMTModel, - isQwenReasoningModel, - isReasoningModel, - isSupportedReasoningEffortModel, - isSupportedReasoningEffortOpenAIModel, - isSupportedThinkingTokenClaudeModel, - isSupportedThinkingTokenDoubaoModel, - isSupportedThinkingTokenGeminiModel, - isSupportedThinkingTokenHunyuanModel, - isSupportedThinkingTokenModel, - isSupportedThinkingTokenQwenModel, - isSupportedThinkingTokenZhipuModel, - isVisionModel, - ZHIPU_RESULT_TOKENS -} from '@renderer/config/models' -import { mapLanguageToQwenMTModel } from '@renderer/config/translate' -import { processPostsuffixQwen3Model, processReqMessages } from '@renderer/services/ModelMessageService' -import { estimateTextTokens } from '@renderer/services/TokenService' -// For Copilot token -import type { - Assistant, - MCPCallToolResponse, - MCPTool, - MCPToolResponse, - Model, - OpenAIServiceTier, - Provider, - ToolCallResponse -} from '@renderer/types' -import { - EFFORT_RATIO, - FileTypes, - isSystemProvider, - isTranslateAssistant, - SystemProviderIds, - WebSearchSource -} from '@renderer/types' -import type { TextStartChunk, ThinkingStartChunk } from '@renderer/types/chunk' -import { ChunkType } from '@renderer/types/chunk' -import type { Message } from '@renderer/types/newMessage' -import type { - OpenAIExtraBody, - OpenAIModality, - OpenAISdkMessageParam, - OpenAISdkParams, - OpenAISdkRawChunk, - OpenAISdkRawContentSource, - OpenAISdkRawOutput, - ReasoningEffortOptionalParams -} from '@renderer/types/sdk' -import { addImageFileToContents } from '@renderer/utils/formats' -import { - isSupportedToolUse, - mcpToolCallResponseToOpenAICompatibleMessage, - mcpToolsToOpenAIChatTools, - openAIToolsToMcpTool -} from '@renderer/utils/mcp-tools' -import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find' -import { - isSupportArrayContentProvider, - isSupportDeveloperRoleProvider, - isSupportEnableThinkingProvider, - isSupportStreamOptionsProvider -} from '@renderer/utils/provider' -import { t } from 'i18next' - -import type { GenericChunk } from '../../middleware/schemas' -import type { RequestTransformer, ResponseChunkTransformer, ResponseChunkTransformerContext } from '../types' -import { OpenAIBaseClient } from './OpenAIBaseClient' - -const logger = loggerService.withContext('OpenAIApiClient') - -export class OpenAIAPIClient 
extends OpenAIBaseClient< - OpenAI | AzureOpenAI, - OpenAISdkParams, - OpenAISdkRawOutput, - OpenAISdkRawChunk, - OpenAISdkMessageParam, - OpenAI.Chat.Completions.ChatCompletionMessageToolCall, - ChatCompletionTool -> { - constructor(provider: Provider) { - super(provider) - } - - override async createCompletions( - payload: OpenAISdkParams, - options?: OpenAI.RequestOptions - ): Promise { - const sdk = await this.getSdkInstance() - // @ts-ignore - SDK参数可能有额外的字段 - return await sdk.chat.completions.create(payload, options) - } - - /** - * Get the reasoning effort for the assistant - * @param assistant - The assistant - * @param model - The model - * @returns The reasoning effort - */ - // Method for reasoning effort, moved from OpenAIProvider - override getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams { - if (this.provider.id === SystemProviderIds.groq) { - return {} - } - - if (!isReasoningModel(model)) { - return {} - } - - if (isOpenAIDeepResearchModel(model)) { - return { - reasoning_effort: 'medium' - } - } - - const reasoningEffort = assistant?.settings?.reasoning_effort - - if (isSupportedThinkingTokenZhipuModel(model)) { - return { thinking: { type: reasoningEffort ? 'enabled' : 'disabled' } } - } - - if (reasoningEffort === 'default') { - return {} - } - - if (!reasoningEffort) { - // DeepSeek hybrid inference models, v3.1 and maybe more in the future - // 不同的 provider 有不同的思考控制方式,在这里统一解决 - // if (isDeepSeekHybridInferenceModel(model)) { - // // do nothing for now. default to non-think. - // } - - // openrouter: use reasoning - // openrouter 如果关闭思考,会隐藏思考内容,所以对于总是思考的模型需要特别处理 - if (model.provider === SystemProviderIds.openrouter) { - // Don't disable reasoning for Gemini models that support thinking tokens - if (isSupportedThinkingTokenGeminiModel(model) && !GEMINI_FLASH_MODEL_REGEX.test(model.id)) { - return {} - } - // Don't disable reasoning for models that require it - if (isGrokReasoningModel(model) || isOpenAIReasoningModel(model)) { - return {} - } - if (isReasoningModel(model) && !isSupportedThinkingTokenModel(model)) { - return {} - } - return { reasoning: { enabled: false, exclude: true } } - } - - // providers that use enable_thinking - if ( - isSupportEnableThinkingProvider(this.provider) && - (isSupportedThinkingTokenQwenModel(model) || - isSupportedThinkingTokenHunyuanModel(model) || - (this.provider.id === SystemProviderIds.dashscope && isDeepSeekHybridInferenceModel(model))) - ) { - return { enable_thinking: false } - } - - // claude - if (isSupportedThinkingTokenClaudeModel(model)) { - return {} - } - - // gemini - if (isSupportedThinkingTokenGeminiModel(model)) { - if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) { - return { - extra_body: { - google: { - thinking_config: { - thinking_budget: 0 - } - } - } - } - } - return {} - } - - if (isSupportedThinkingTokenDoubaoModel(model)) { - return { thinking: { type: 'disabled' } } - } - - return {} - } - - // reasoningEffort有效的情况 - const effortRatio = EFFORT_RATIO[reasoningEffort] - const budgetTokens = Math.floor( - (findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + findTokenLimit(model.id)?.min! 
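// Worked example of the interpolation above, with illustrative numbers rather than
// Cherry Studio's real token limits: min = 1024, max = 32768 and effortRatio = 0.5
// place the budget halfway between the two bounds.
const exampleMin = 1024
const exampleMax = 32768
const exampleEffortRatio = 0.5
const exampleBudget = Math.floor((exampleMax - exampleMin) * exampleEffortRatio + exampleMin) // 16896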
- ) - - // DeepSeek hybrid inference models, v3.1 and maybe more in the future - // 不同的 provider 有不同的思考控制方式,在这里统一解决 - if (isDeepSeekHybridInferenceModel(model)) { - if (isSystemProvider(this.provider)) { - switch (this.provider.id) { - case SystemProviderIds.dashscope: - return { - enable_thinking: true, - incremental_output: true - } - case SystemProviderIds.doubao: - return { - thinking: { - type: 'enabled' // auto is invalid - } - } - case SystemProviderIds.openrouter: - return { - reasoning: { - enabled: true - } - } - case 'nvidia': - return { - chat_template_kwargs: { - thinking: true - } - } - case SystemProviderIds.silicon: - case SystemProviderIds.ppio: - return { - enable_thinking: true - } - default: - logger.warn( - `Use enable_thinking option as fallback for provider ${this.provider.name} since DeepSeek v3.1 thinking control method is unknown` - ) - return { - enable_thinking: true - } - } - } - } - - // OpenRouter models - if (model.provider === SystemProviderIds.openrouter) { - if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) { - return { - reasoning: { - effort: reasoningEffort === 'auto' ? 'medium' : reasoningEffort - } - } - } - } - - // Doubao 思考模式支持 - if (isSupportedThinkingTokenDoubaoModel(model)) { - if (reasoningEffort === 'high') { - return { thinking: { type: 'enabled' } } - } - if (reasoningEffort === 'auto' && isDoubaoThinkingAutoModel(model)) { - return { thinking: { type: 'auto' } } - } - // 其他情况不带 thinking 字段 - return {} - } - - // Qwen models - if (isQwenReasoningModel(model)) { - const thinkConfig = { - enable_thinking: - isQwenAlwaysThinkModel(model) || !isSupportEnableThinkingProvider(this.provider) ? undefined : true, - thinking_budget: budgetTokens - } - if (this.provider.id === SystemProviderIds.dashscope) { - return { - ...thinkConfig, - incremental_output: true - } - } - return thinkConfig - } - - // Hunyuan models - if (isSupportedThinkingTokenHunyuanModel(model) && isSupportEnableThinkingProvider(this.provider)) { - return { - enable_thinking: true - } - } - - // Grok models/Perplexity models/OpenAI models - if (isSupportedReasoningEffortModel(model)) { - // 检查模型是否支持所选选项 - const supportedOptions = getModelSupportedReasoningEffortOptions(model)?.filter((option) => option !== 'default') - if (supportedOptions?.includes(reasoningEffort)) { - return { - reasoning_effort: reasoningEffort - } - } else { - // 如果不支持,fallback到第一个支持的值 - return { - reasoning_effort: supportedOptions?.[0] - } - } - } - - if (isSupportedThinkingTokenGeminiModel(model)) { - if (reasoningEffort === 'auto') { - return { - extra_body: { - google: { - thinking_config: { - thinking_budget: -1, - include_thoughts: true - } - } - } - } - } - return { - extra_body: { - google: { - thinking_config: { - thinking_budget: budgetTokens, - include_thoughts: true - } - } - } - } - } - - // Claude models - if (isSupportedThinkingTokenClaudeModel(model)) { - const maxTokens = assistant.settings?.maxTokens - return { - thinking: { - type: 'enabled', - budget_tokens: Math.floor( - Math.max(1024, Math.min(budgetTokens, (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio)) - ) - } - } - } - - // Doubao models - if (isSupportedThinkingTokenDoubaoModel(model)) { - if (assistant.settings?.reasoning_effort === 'high') { - return { - thinking: { - type: 'enabled' - } - } - } - } - - // Default case: no special thinking settings - return {} - } - - /** - * Check if the provider does not support files - * @returns True if the provider does not support files, false otherwise - 
*/ - private get isNotSupportFiles() { - if (this.provider?.isNotSupportArrayContent) { - return true - } - - return !isSupportArrayContentProvider(this.provider) - } - - /** - * Get the message parameter - * @param message - The message - * @param model - The model - * @returns The message parameter - */ - public async convertMessageToSdkParam(message: Message, model: Model): Promise { - const isVision = isVisionModel(model) - const { textContent, imageContents } = await this.getMessageContent(message) - const fileBlocks = findFileBlocks(message) - const imageBlocks = findImageBlocks(message) - - // If the model does not support files, extract the file content - if (this.isNotSupportFiles) { - const fileContent = await this.extractFileContent(message) - - return { - role: message.role === 'system' ? 'user' : message.role, - content: textContent + '\n\n---\n\n' + fileContent - } as OpenAISdkMessageParam - } - - // Check if we only have text content and no other media - if (fileBlocks.length === 0 && imageBlocks.length === 0 && imageContents.length === 0) { - return { - role: message.role === 'system' ? 'user' : message.role, - content: textContent - } as OpenAISdkMessageParam - } - - // If the model supports files, add the file content to the message - const parts: ChatCompletionContentPart[] = [] - - if (textContent) { - parts.push({ type: 'text', text: textContent }) - } - - if (imageContents.length > 0) { - for (const imageContent of imageContents) { - const image = await window.api.file.base64Image(imageContent.fileId + imageContent.fileExt) - parts.push({ type: 'image_url', image_url: { url: image.data } }) - } - } - - for (const imageBlock of imageBlocks) { - if (isVision) { - if (imageBlock.file) { - const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext) - parts.push({ type: 'image_url', image_url: { url: image.data } }) - } else if (imageBlock.url && imageBlock.url.startsWith('data:')) { - parts.push({ type: 'image_url', image_url: { url: imageBlock.url } }) - } - } - } - - for (const fileBlock of fileBlocks) { - const file = fileBlock.file - if (!file) { - continue - } - - if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { - const fileContent = await (await window.api.file.read(file.id + file.ext, true)).trim() - parts.push({ - type: 'text', - text: file.origin_name + '\n' + fileContent - }) - } - } - - return { - role: message.role === 'system' ? 
'user' : message.role, - content: parts - } as OpenAISdkMessageParam - } - - public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ChatCompletionTool[] { - return mcpToolsToOpenAIChatTools(mcpTools) - } - - public convertSdkToolCallToMcp( - toolCall: OpenAI.Chat.Completions.ChatCompletionMessageToolCall, - mcpTools: MCPTool[] - ): MCPTool | undefined { - return openAIToolsToMcpTool(mcpTools, toolCall) - } - - public convertSdkToolCallToMcpToolResponse( - toolCall: OpenAI.Chat.Completions.ChatCompletionMessageToolCall, - mcpTool: MCPTool - ): ToolCallResponse { - let parsedArgs: any - try { - if ('function' in toolCall) { - parsedArgs = JSON.parse(toolCall.function.arguments) - } - } catch { - if ('function' in toolCall) { - parsedArgs = toolCall.function.arguments - } - } - return { - id: toolCall.id, - toolCallId: toolCall.id, - tool: mcpTool, - arguments: parsedArgs, - status: 'pending' - } as ToolCallResponse - } - - public convertMcpToolResponseToSdkMessageParam( - mcpToolResponse: MCPToolResponse, - resp: MCPCallToolResponse, - model: Model - ): OpenAISdkMessageParam | undefined { - if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) { - // This case is for Anthropic/Claude like tool usage, OpenAI uses tool_call_id - // For OpenAI, we primarily expect toolCallId. This might need adjustment if mixing provider concepts. - return mcpToolCallResponseToOpenAICompatibleMessage( - mcpToolResponse, - resp, - isVisionModel(model), - !isSupportArrayContentProvider(this.provider) - ) - } else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) { - return { - role: 'tool', - tool_call_id: mcpToolResponse.toolCallId, - content: JSON.stringify(resp.content) - } as OpenAI.Chat.Completions.ChatCompletionToolMessageParam - } - return undefined - } - - public buildSdkMessages( - currentReqMessages: OpenAISdkMessageParam[], - output: string | undefined, - toolResults: OpenAISdkMessageParam[], - toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] - ): OpenAISdkMessageParam[] { - if (!output && toolCalls.length === 0) { - return [...currentReqMessages, ...toolResults] - } - - const assistantMessage: OpenAISdkMessageParam = { - role: 'assistant', - content: output, - tool_calls: toolCalls.length > 0 ? 
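// Wire shape that these conversion helpers feed and that buildSdkMessages stitches
// together: the assistant turn carries its tool_calls, and each tool result answers
// by tool_call_id. All ids and payloads here are made-up literals:
import type OpenAI from '@cherrystudio/openai'

const followUpMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = [
  { role: 'user', content: 'What is the weather in Paris?' },
  {
    role: 'assistant',
    content: null,
    tool_calls: [
      { id: 'call_1', type: 'function', function: { name: 'get_weather', arguments: '{"city":"Paris"}' } }
    ]
  },
  { role: 'tool', tool_call_id: 'call_1', content: '{"tempC":21}' }
]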
toolCalls : undefined - } - const newReqMessages = [...currentReqMessages, assistantMessage, ...toolResults] - return newReqMessages - } - - override estimateMessageTokens(message: OpenAISdkMessageParam): number { - let sum = 0 - if (typeof message.content === 'string') { - sum += estimateTextTokens(message.content) - } else if (Array.isArray(message.content)) { - sum += (message.content || []) - .map((part: ChatCompletionContentPart | ChatCompletionContentPartRefusal) => { - switch (part.type) { - case 'text': - return estimateTextTokens(part.text) - case 'image_url': - return estimateTextTokens(part.image_url.url) - case 'input_audio': - return estimateTextTokens(part.input_audio.data) - case 'file': - return estimateTextTokens(part.file.file_data || '') - default: - return 0 - } - }) - .reduce((acc, curr) => acc + curr, 0) - } - if ('tool_calls' in message && message.tool_calls) { - sum += message.tool_calls.reduce((acc, toolCall) => { - if (toolCall.type === 'function' && 'function' in toolCall) { - return acc + estimateTextTokens(JSON.stringify(toolCall.function.arguments)) - } - return acc - }, 0) - } - return sum - } - - public extractMessagesFromSdkPayload(sdkPayload: OpenAISdkParams): OpenAISdkMessageParam[] { - return sdkPayload.messages || [] - } - - getRequestTransformer(): RequestTransformer { - return { - transform: async ( - coreRequest, - assistant, - model, - isRecursiveCall, - recursiveSdkMessages - ): Promise<{ - payload: OpenAISdkParams - messages: OpenAISdkMessageParam[] - metadata: Record - }> => { - const { messages, mcpTools, maxTokens, enableWebSearch, enableGenerateImage } = coreRequest - let { streamOutput } = coreRequest - - // Qwen3商业版(思考模式)、Qwen3开源版、QwQ、QVQ只支持流式输出。 - if (isQwenReasoningModel(model)) { - streamOutput = true - } - - const extra_body: OpenAIExtraBody = {} - - if (isQwenMTModel(model)) { - if (isTranslateAssistant(assistant)) { - const targetLanguage = mapLanguageToQwenMTModel(assistant.targetLanguage) - if (!targetLanguage) { - throw new Error(t('translate.error.not_supported', { language: assistant.targetLanguage.value })) - } - const translationOptions = { - source_lang: 'auto', - target_lang: targetLanguage - } as const - extra_body.translation_options = translationOptions - } else { - throw new Error(t('translate.error.chat_qwen_mt')) - } - } - - // 1. 处理系统消息 - const systemMessage = { role: 'system', content: assistant.prompt || '' } - - if ( - isSupportedReasoningEffortOpenAIModel(model) && - isSupportDeveloperRoleProvider(this.provider) && - !isOpenAIOpenWeightModel(model) - ) { - systemMessage.role = 'developer' - } - - if (model.id.includes('o1-mini') || model.id.includes('o1-preview')) { - systemMessage.role = 'assistant' - } - - // 2. 设置工具(必须在this.usesystemPromptForTools前面) - const { tools } = this.setupToolsConfig({ - mcpTools: mcpTools, - model, - enableToolUse: isSupportedToolUse(assistant) - }) - - // 3. 处理用户消息 - const userMessages: OpenAISdkMessageParam[] = [] - if (typeof messages === 'string') { - userMessages.push({ role: 'user', content: messages }) - } else { - const processedMessages = addImageFileToContents(messages) - for (const message of processedMessages) { - userMessages.push(await this.convertMessageToSdkParam(message, model)) - } - } - if (userMessages.length === 0) { - logger.warn('No user message. 
Some providers may not support.') - } - - const reasoningEffort = this.getReasoningEffort(assistant, model) - - const lastUserMsg = userMessages.findLast((m) => m.role === 'user') - if (lastUserMsg) { - if (isSupportedThinkingTokenQwenModel(model) && !isSupportEnableThinkingProvider(this.provider)) { - const qwenThinkModeEnabled = assistant.settings?.qwenThinkMode === true - const currentContent = lastUserMsg.content - - lastUserMsg.content = processPostsuffixQwen3Model(currentContent, qwenThinkModeEnabled) - } - } - - // 4. 最终请求消息 - let reqMessages: OpenAISdkMessageParam[] - if (!systemMessage.content) { - reqMessages = [...userMessages] - } else if (isNotSupportSystemMessageModel(model)) { - // transform into user message - const firstUserMsg = userMessages.shift() - if (firstUserMsg) { - firstUserMsg.content = `System Instruction: \n${systemMessage.content}\n\nUser Message(s):\n${firstUserMsg.content}` - reqMessages = [firstUserMsg, ...userMessages] - } else { - reqMessages = [] - } - } else { - reqMessages = [systemMessage, ...userMessages].filter(Boolean) as OpenAISdkMessageParam[] - } - - reqMessages = processReqMessages(model, reqMessages) - - // 5. 创建通用参数 - // Create the appropriate parameters object based on whether streaming is enabled - // Note: Some providers like Mistral don't support stream_options - const shouldIncludeStreamOptions = streamOutput && isSupportStreamOptionsProvider(this.provider) - - // minimal cannot be used with web_search tool - if (isGPT5SeriesModel(model) && reasoningEffort.reasoning_effort === 'minimal' && enableWebSearch) { - reasoningEffort.reasoning_effort = 'low' - } - - const modalities: { - modalities?: OpenAIModality[] - } = {} - // for openrouter generate image - // https://openrouter.ai/docs/features/multimodal/image-generation - if (enableGenerateImage && this.provider.id === SystemProviderIds.openrouter) { - modalities.modalities = ['image', 'text'] - } - - const commonParams: OpenAISdkParams = { - model: model.id, - messages: - isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0 - ? recursiveSdkMessages - : reqMessages, - temperature: this.getTemperature(assistant, model), - top_p: this.getTopP(assistant, model), - max_tokens: maxTokens, - tools: tools.length > 0 ? tools : undefined, - stream: streamOutput, - ...(shouldIncludeStreamOptions ? { stream_options: { include_usage: true } } : {}), - ...modalities, - // groq 有不同的 service tier 配置,不符合 openai 接口类型 - service_tier: this.getServiceTier(model) as OpenAIServiceTier, - // verbosity. getVerbosity ensures the returned value is valid. - verbosity: this.getVerbosity(model), - ...this.getProviderSpecificParameters(assistant, model), - ...reasoningEffort, - // ...getOpenAIWebSearchParams(model, enableWebSearch), - // OpenRouter usage tracking - ...(this.provider.id === 'openrouter' ? { usage: { include: true } } : {}), - ...extra_body, - // 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑 - // 注意:用户自定义参数总是应该覆盖其他参数 - ...(coreRequest.callType === 'chat' ? 
this.getCustomParameters(assistant) : {}) - } - - const timeout = this.getTimeout(model) - - return { payload: commonParams, messages: reqMessages, metadata: { timeout } } - } - } - } - - // 在RawSdkChunkToGenericChunkMiddleware中使用 - getResponseChunkTransformer(): ResponseChunkTransformer { - let hasBeenCollectedWebSearch = false - let hasEmittedWebSearchInProgress = false - const collectWebSearchData = ( - chunk: OpenAISdkRawChunk, - contentSource: OpenAISdkRawContentSource, - context: ResponseChunkTransformerContext - ) => { - if (hasBeenCollectedWebSearch) { - return - } - // OpenAI annotations - // @ts-ignore - annotations may not be in standard type definitions - const annotations = contentSource.annotations || chunk.annotations - if (annotations && annotations.length > 0 && annotations[0].type === 'url_citation') { - hasBeenCollectedWebSearch = true - return { - results: annotations, - source: WebSearchSource.OPENAI - } - } - - // Grok citations - // @ts-ignore - citations may not be in standard type definitions - if (context.provider?.id === 'grok' && chunk.citations) { - hasBeenCollectedWebSearch = true - return { - // @ts-ignore - citations may not be in standard type definitions - results: chunk.citations, - source: WebSearchSource.GROK - } - } - - // Perplexity citations - // @ts-ignore - citations may not be in standard type definitions - if (context.provider?.id === 'perplexity' && chunk.search_results && chunk.search_results.length > 0) { - hasBeenCollectedWebSearch = true - return { - // @ts-ignore - citations may not be in standard type definitions - results: chunk.search_results, - source: WebSearchSource.PERPLEXITY - } - } - - // OpenRouter citations - // @ts-ignore - citations may not be in standard type definitions - if (context.provider?.id === 'openrouter' && chunk.citations && chunk.citations.length > 0) { - hasBeenCollectedWebSearch = true - return { - // @ts-ignore - citations may not be in standard type definitions - results: chunk.citations, - source: WebSearchSource.OPENROUTER - } - } - - // Zhipu web search - // @ts-ignore - web_search may not be in standard type definitions - if (context.provider?.id === 'zhipu' && chunk.web_search) { - hasBeenCollectedWebSearch = true - return { - // @ts-ignore - web_search may not be in standard type definitions - results: chunk.web_search, - source: WebSearchSource.ZHIPU - } - } - - // Hunyuan web search - // @ts-ignore - search_info may not be in standard type definitions - if (context.provider?.id === 'hunyuan' && chunk.search_info?.search_results) { - hasBeenCollectedWebSearch = true - return { - // @ts-ignore - search_info may not be in standard type definitions - results: chunk.search_info.search_results, - source: WebSearchSource.HUNYUAN - } - } - - // TODO: 放到AnthropicApiClient中 - // // Other providers... 
- // // @ts-ignore - web_search may not be in standard type definitions - // if (chunk.web_search) { - // const sourceMap: Record = { - // openai: 'openai', - // anthropic: 'anthropic', - // qwenlm: 'qwen' - // } - // const source = sourceMap[context.provider?.id] || 'openai_response' - // return { - // results: chunk.web_search, - // source: source as const - // } - // } - - return null - } - - const toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] = [] - let isFinished = false - let lastUsageInfo: any = null - let hasFinishReason = false // Track if we've seen a finish_reason - - /** - * 统一的完成信号发送逻辑 - * - 有 finish_reason 时 - * - 无 finish_reason 但是流正常结束时 - */ - const emitCompletionSignals = (controller: TransformStreamDefaultController) => { - if (isFinished) return - - if (toolCalls.length > 0) { - controller.enqueue({ - type: ChunkType.MCP_TOOL_CREATED, - tool_calls: toolCalls - }) - } - - const usage = lastUsageInfo || { - prompt_tokens: 0, - completion_tokens: 0, - total_tokens: 0 - } - - controller.enqueue({ - type: ChunkType.LLM_RESPONSE_COMPLETE, - response: { usage } - }) - - // 防止重复发送 - isFinished = true - } - - let isThinking = false - let accumulatingText = false - return (context: ResponseChunkTransformerContext) => ({ - async transform(chunk: OpenAISdkRawChunk, controller: TransformStreamDefaultController) { - // 持续更新usage信息 - logger.silly('chunk', chunk) - if (chunk.usage) { - const usage = chunk.usage - lastUsageInfo = { - prompt_tokens: usage.prompt_tokens || 0, - completion_tokens: usage.completion_tokens || 0, - total_tokens: usage.total_tokens || (usage.prompt_tokens || 0) + (usage.completion_tokens || 0), - // Handle OpenRouter specific cost fields - ...(usage.cost !== undefined ? { cost: usage.cost } : {}) - } - } - - // if we've already seen finish_reason, emit completion signals. No matter whether we get usage or not. 
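// The finish/usage bookkeeping above in miniature: with stream_options.include_usage,
// the usage chunk arrives after the finish_reason chunk, so completion is deferred
// until both are known (or the stream flushes). A standalone sketch with simplified types:
type SketchUsage = { prompt_tokens: number; completion_tokens: number; total_tokens: number }

function makeCompletionGate(emit: (usage: SketchUsage) => void) {
  let usage: SketchUsage | null = null
  let finished = false
  let emitted = false
  const tryEmit = () => {
    if (!emitted && finished && usage) {
      emitted = true
      emit(usage)
    }
  }
  return {
    onUsage(u: SketchUsage) { usage = u; tryEmit() },
    onFinishReason() { finished = true; tryEmit() },
    // flush() covers providers that never send a usage chunk at all
    flush() {
      if (emitted) return
      emitted = true
      emit(usage ?? { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 })
    }
  }
}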
- if (hasFinishReason && !isFinished) { - emitCompletionSignals(controller) - return - } - - if (typeof chunk === 'string') { - try { - chunk = JSON.parse(chunk) - } catch (error) { - logger.error('invalid chunk', { chunk, error }) - throw new Error(t('error.chat.chunk.non_json')) - } - } - - // 处理chunk - if ('choices' in chunk && chunk.choices && chunk.choices.length > 0) { - for (const choice of chunk.choices) { - if (!choice) continue - - // 对于流式响应,使用 delta;对于非流式响应,使用 message。 - // 然而某些 OpenAI 兼容平台在非流式请求时会错误地返回一个空对象的 delta 字段。 - // 如果 delta 为空对象或content为空,应当忽略它并回退到 message,避免造成内容缺失。 - let contentSource: OpenAISdkRawContentSource | null = null - if ( - 'delta' in choice && - choice.delta && - Object.keys(choice.delta).length > 0 && - (!('content' in choice.delta) || - (choice.delta.tool_calls && choice.delta.tool_calls.length > 0) || - (typeof choice.delta.content === 'string' && choice.delta.content !== '') || - (typeof (choice.delta as any).reasoning_content === 'string' && - (choice.delta as any).reasoning_content !== '') || - (typeof (choice.delta as any).reasoning === 'string' && (choice.delta as any).reasoning !== '') || - ((choice.delta as OpenAISdkRawContentSource).images && - Array.isArray((choice.delta as OpenAISdkRawContentSource).images))) - ) { - contentSource = choice.delta - } else if ('message' in choice) { - contentSource = choice.message - } - - // 状态管理 - if (!contentSource?.content) { - accumulatingText = false - } - // @ts-ignore - reasoning_content is not in standard OpenAI types but some providers use it - if (!contentSource?.reasoning_content && !contentSource?.reasoning) { - isThinking = false - } - - if (!contentSource) { - if ('finish_reason' in choice && choice.finish_reason) { - // OpenAI Chat Completions API 在启用 stream_options: { include_usage: true } 以后 - // 包含 usage 的 chunk 会在包含 finish_reason: stop 的 chunk 之后 - // 所以试图等到拿到 usage 之后再发出结束信号 - hasFinishReason = true - // If we already have usage info, emit completion signals now - if (lastUsageInfo && lastUsageInfo.total_tokens > 0) { - emitCompletionSignals(controller) - } - } - continue - } - - const webSearchData = collectWebSearchData(chunk, contentSource, context) - if (webSearchData) { - // 如果还未发送搜索进度事件,先发送进度事件 - if (!hasEmittedWebSearchInProgress) { - controller.enqueue({ - type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS - }) - hasEmittedWebSearchInProgress = true - } - controller.enqueue({ - type: ChunkType.LLM_WEB_SEARCH_COMPLETE, - llm_web_search: webSearchData - }) - } - - // 处理推理内容 (e.g. 
from OpenRouter DeepSeek-R1) - // @ts-ignore - reasoning_content is not in standard OpenAI types but some providers use it - const reasoningText = contentSource.reasoning_content || contentSource.reasoning - if (reasoningText) { - // logger.silly('since reasoningText is trusy, try to enqueue THINKING_START AND THINKING_DELTA') - if (!isThinking) { - // logger.silly('since isThinking is falsy, try to enqueue THINKING_START') - controller.enqueue({ - type: ChunkType.THINKING_START - } as ThinkingStartChunk) - isThinking = true - } - - // logger.silly('enqueue THINKING_DELTA') - controller.enqueue({ - type: ChunkType.THINKING_DELTA, - text: reasoningText - }) - } else { - isThinking = false - } - - // 处理文本内容 - if (contentSource.content) { - // logger.silly('since contentSource.content is trusy, try to enqueue TEXT_START and TEXT_DELTA') - if (!accumulatingText) { - // logger.silly('enqueue TEXT_START') - controller.enqueue({ - type: ChunkType.TEXT_START - } as TextStartChunk) - accumulatingText = true - } - // logger.silly('enqueue TEXT_DELTA') - // 处理特殊token - // 智谱api的一个chunk中只会输出一个token,因而使用 ===,避免正常内容被误判 - if ( - context.provider.id === SystemProviderIds.zhipu && - ZHIPU_RESULT_TOKENS.some((pattern) => contentSource.content === pattern) - ) { - controller.enqueue({ - type: ChunkType.TEXT_DELTA, - text: '**' // strong - }) - } else { - controller.enqueue({ - type: ChunkType.TEXT_DELTA, - text: contentSource.content - }) - } - } else { - accumulatingText = false - } - - // 处理图片内容 (e.g. from OpenRouter Gemini image generation models) - if (contentSource.images && Array.isArray(contentSource.images)) { - controller.enqueue({ - type: ChunkType.IMAGE_CREATED - }) - controller.enqueue({ - type: ChunkType.IMAGE_COMPLETE, - image: { - type: 'base64', - images: contentSource.images.map((image) => image.image_url?.url || '') - } - }) - } - - // 处理工具调用 - if (contentSource.tool_calls) { - for (const toolCall of contentSource.tool_calls) { - if ('index' in toolCall) { - const { id, index, function: fun } = toolCall - if (fun?.name) { - const toolCallObject = { - id: id || '', - function: { - name: fun.name, - arguments: fun.arguments || '' - }, - type: 'function' as const - } - - if (index === -1) { - toolCalls.push(toolCallObject) - } else { - toolCalls[index] = toolCallObject - } - } else if (fun?.arguments) { - if (toolCalls[index] && toolCalls[index].type === 'function' && 'function' in toolCalls[index]) { - toolCalls[index].function.arguments += fun.arguments - } - } - } else { - toolCalls.push(toolCall) - } - } - } - - // 处理finish_reason,发送流结束信号 - if ('finish_reason' in choice && choice.finish_reason) { - logger.debug(`Stream finished with reason: ${choice.finish_reason}`) - const webSearchData = collectWebSearchData(chunk, contentSource, context) - if (webSearchData) { - // 如果还未发送搜索进度事件,先发送进度事件 - if (!hasEmittedWebSearchInProgress) { - controller.enqueue({ - type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS - }) - hasEmittedWebSearchInProgress = true - } - controller.enqueue({ - type: ChunkType.LLM_WEB_SEARCH_COMPLETE, - llm_web_search: webSearchData - }) - } - - // Don't emit completion signals immediately after finish_reason - // Wait for the usage chunk that comes after - hasFinishReason = true - // If we already have usage info, emit completion signals now - if (lastUsageInfo && lastUsageInfo.total_tokens > 0) { - emitCompletionSignals(controller) - } - } - } - } - }, - - // 流正常结束时,检查是否需要发送完成信号 - flush(controller) { - if (isFinished) return - - logger.debug('Stream ended without finish_reason, 
emitting fallback completion signals') - emitCompletionSignals(controller) - } - }) - } -} diff --git a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts deleted file mode 100644 index efc3f4f7ce..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts +++ /dev/null @@ -1,309 +0,0 @@ -import OpenAI, { AzureOpenAI } from '@cherrystudio/openai' -import { loggerService } from '@logger' -import { COPILOT_DEFAULT_HEADERS } from '@renderer/aiCore/provider/constants' -import { - isClaudeReasoningModel, - isOpenAIReasoningModel, - isSupportedModel, - isSupportedReasoningEffortOpenAIModel -} from '@renderer/config/models' -import { getStoreSetting } from '@renderer/hooks/useSettings' -import { getAssistantSettings } from '@renderer/services/AssistantService' -import store from '@renderer/store' -import type { SettingsState } from '@renderer/store/settings' -import { type Assistant, type GenerateImageParams, type Model, type Provider } from '@renderer/types' -import type { - OpenAIResponseSdkMessageParam, - OpenAIResponseSdkParams, - OpenAIResponseSdkRawChunk, - OpenAIResponseSdkRawOutput, - OpenAIResponseSdkTool, - OpenAIResponseSdkToolCall, - OpenAISdkMessageParam, - OpenAISdkParams, - OpenAISdkRawChunk, - OpenAISdkRawOutput, - ReasoningEffortOptionalParams -} from '@renderer/types/sdk' -import { withoutTrailingSlash } from '@renderer/utils/api' -import { isOllamaProvider } from '@renderer/utils/provider' - -import { BaseApiClient } from '../BaseApiClient' -import { normalizeAzureOpenAIEndpoint } from './azureOpenAIEndpoint' - -const logger = loggerService.withContext('OpenAIBaseClient') - -/** - * 抽象的OpenAI基础客户端类,包含两个OpenAI客户端之间的共享功能 - */ -export abstract class OpenAIBaseClient< - TSdkInstance extends OpenAI | AzureOpenAI, - TSdkParams extends OpenAISdkParams | OpenAIResponseSdkParams, - TRawOutput extends OpenAISdkRawOutput | OpenAIResponseSdkRawOutput, - TRawChunk extends OpenAISdkRawChunk | OpenAIResponseSdkRawChunk, - TMessageParam extends OpenAISdkMessageParam | OpenAIResponseSdkMessageParam, - TToolCall extends OpenAI.Chat.Completions.ChatCompletionMessageToolCall | OpenAIResponseSdkToolCall, - TSdkSpecificTool extends OpenAI.Chat.Completions.ChatCompletionTool | OpenAIResponseSdkTool -> extends BaseApiClient { - constructor(provider: Provider) { - super(provider) - } - - // 仅适用于openai - override getBaseURL(): string { - // apiHost is formatted when called by AiProvider - return this.provider.apiHost - } - - override async generateImage({ - model, - prompt, - negativePrompt, - imageSize, - batchSize, - seed, - numInferenceSteps, - guidanceScale, - signal, - promptEnhancement - }: GenerateImageParams): Promise { - const sdk = await this.getSdkInstance() - const response = (await sdk.request({ - method: 'post', - path: '/v1/images/generations', - signal, - body: { - model, - prompt, - negative_prompt: negativePrompt, - image_size: imageSize, - batch_size: batchSize, - seed: seed ? 
parseInt(seed) : undefined, - num_inference_steps: numInferenceSteps, - guidance_scale: guidanceScale, - prompt_enhancement: promptEnhancement - } - })) as { data: Array<{ url: string }> } - - return response.data.map((item) => item.url) - } - - override async getEmbeddingDimensions(model: Model): Promise { - let sdk: OpenAI = await this.getSdkInstance() - if (isOllamaProvider(this.provider)) { - const embedBaseUrl = `${this.provider.apiHost.replace(/(\/(api|v1))\/?$/, '')}/v1` - sdk = sdk.withOptions({ baseURL: embedBaseUrl }) - } - - const data = await sdk.embeddings.create({ - model: model.id, - input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi', - encoding_format: this.provider.id === 'voyageai' ? undefined : 'float' - }) - return data.data[0].embedding.length - } - - override async listModels(): Promise { - try { - const sdk = await this.getSdkInstance() - if (this.provider.id === 'openrouter') { - // https://openrouter.ai/docs/api/api-reference/embeddings/list-embeddings-models - const embedBaseUrl = 'https://openrouter.ai/api/v1/embeddings' - const embedSdk = sdk.withOptions({ baseURL: embedBaseUrl }) - const modelPromise = sdk.models.list() - const embedModelPromise = embedSdk.models.list() - const [modelResponse, embedModelResponse] = await Promise.all([modelPromise, embedModelPromise]) - const models = [...modelResponse.data, ...embedModelResponse.data] - const uniqueModels = Array.from(new Map(models.map((model) => [model.id, model])).values()) - return uniqueModels.filter(isSupportedModel) - } - if (this.provider.id === 'github') { - // GitHub Models 其 models 和 chat completions 两个接口的 baseUrl 不一样 - const baseUrl = 'https://models.github.ai/catalog/' - const newSdk = sdk.withOptions({ baseURL: baseUrl }) - const response = await newSdk.models.list() - - // @ts-ignore key is not typed - return response?.body - .map((model) => ({ - id: model.id, - description: model.summary, - object: 'model', - owned_by: model.publisher - })) - .filter(isSupportedModel) - } - - if (isOllamaProvider(this.provider)) { - const baseUrl = withoutTrailingSlash(this.getBaseURL()) - .replace(/\/v1$/, '') - .replace(/\/api$/, '') - const response = await fetch(`${baseUrl}/api/tags`, { - headers: { - Authorization: `Bearer ${this.apiKey}`, - ...this.defaultHeaders(), - ...this.provider.extra_headers - } - }) - - if (!response.ok) { - throw new Error(`Ollama server returned ${response.status} ${response.statusText}`) - } - - const data = await response.json() - if (!data?.models || !Array.isArray(data.models)) { - throw new Error('Invalid response from Ollama API: missing models array') - } - - return data.models.map((model) => ({ - id: model.name, - object: 'model', - owned_by: 'ollama' - })) - } - const response = await sdk.models.list() - if (this.provider.id === 'together') { - // @ts-ignore key is not typed - return response?.body.map((model) => ({ - id: model.id, - description: model.display_name, - object: 'model', - owned_by: model.organization - })) - } - const models = response.data || [] - models.forEach((model) => { - model.id = model.id.trim() - }) - - return models.filter(isSupportedModel) - } catch (error) { - logger.error('Error listing models:', error as Error) - return [] - } - } - - override async getSdkInstance() { - if (this.sdkInstance) { - return this.sdkInstance - } - - let apiKeyForSdkInstance = this.apiKey - let baseURLForSdkInstance = this.getBaseURL() - logger.debug('baseURLForSdkInstance', { baseURLForSdkInstance }) - let headersForSdkInstance = { - 
...this.defaultHeaders(), - ...this.provider.extra_headers - } - - if (this.provider.id === 'copilot') { - const defaultHeaders = store.getState().copilot.defaultHeaders - const { token } = await window.api.copilot.getToken(defaultHeaders) - // this.provider.apiKey不允许修改 - // this.provider.apiKey = token - apiKeyForSdkInstance = token - baseURLForSdkInstance = this.getBaseURL() - headersForSdkInstance = { - ...headersForSdkInstance, - ...COPILOT_DEFAULT_HEADERS - } - } - - if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') { - this.sdkInstance = new AzureOpenAI({ - dangerouslyAllowBrowser: true, - apiKey: apiKeyForSdkInstance, - apiVersion: this.provider.apiVersion, - endpoint: normalizeAzureOpenAIEndpoint(this.provider.apiHost) - }) as TSdkInstance - } else { - this.sdkInstance = new OpenAI({ - dangerouslyAllowBrowser: true, - apiKey: apiKeyForSdkInstance, - baseURL: baseURLForSdkInstance, - defaultHeaders: headersForSdkInstance - }) as TSdkInstance - } - return this.sdkInstance - } - - override getTemperature(assistant: Assistant, model: Model): number | undefined { - if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) { - return undefined - } - return super.getTemperature(assistant, model) - } - - override getTopP(assistant: Assistant, model: Model): number | undefined { - if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) { - return undefined - } - return super.getTopP(assistant, model) - } - - /** - * Get the provider specific parameters for the assistant - * @param assistant - The assistant - * @param model - The model - * @returns The provider specific parameters - */ - protected getProviderSpecificParameters(assistant: Assistant, model: Model) { - const { maxTokens } = getAssistantSettings(assistant) - - if (this.provider.id === 'openrouter') { - if (model.id.includes('deepseek-r1')) { - return { - include_reasoning: true - } - } - } - - if (isOpenAIReasoningModel(model)) { - return { - max_tokens: undefined, - max_completion_tokens: maxTokens - } - } - - return {} - } - - /** - * Get the reasoning effort for the assistant - * @param assistant - The assistant - * @param model - The model - * @returns The reasoning effort - */ - protected getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams { - if (!isSupportedReasoningEffortOpenAIModel(model)) { - return {} - } - - const openAI = getStoreSetting('openAI') as SettingsState['openAI'] - const summaryText = openAI?.summaryText || 'off' - - let summary: string | undefined = undefined - - if (summaryText === 'off' || model.id.includes('o1-pro')) { - summary = undefined - } else { - summary = summaryText - } - - const reasoningEffort = assistant?.settings?.reasoning_effort - if (!reasoningEffort) { - return {} - } - - if (isSupportedReasoningEffortOpenAIModel(model)) { - return { - reasoning: { - effort: reasoningEffort as OpenAI.ReasoningEffort, - summary: summary - } as OpenAI.Reasoning - } - } - - return {} - } -} diff --git a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIResponseAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIResponseAPIClient.ts deleted file mode 100644 index b4f63e2bce..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIResponseAPIClient.ts +++ /dev/null @@ -1,769 +0,0 @@ -import OpenAI, { AzureOpenAI } from '@cherrystudio/openai' -import type { ResponseInput } from '@cherrystudio/openai/resources/responses/responses' -import { loggerService } from '@logger' -import 
type { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas' -import type { CompletionsContext } from '@renderer/aiCore/legacy/middleware/types' -import { - isGPT5SeriesModel, - isOpenAIChatCompletionOnlyModel, - isOpenAILLMModel, - isOpenAIOpenWeightModel, - isSupportedReasoningEffortOpenAIModel, - isSupportVerbosityModel, - isVisionModel -} from '@renderer/config/models' -import { estimateTextTokens } from '@renderer/services/TokenService' -import type { - FileMetadata, - MCPCallToolResponse, - MCPTool, - MCPToolResponse, - Model, - OpenAIServiceTier, - Provider, - ToolCallResponse -} from '@renderer/types' -import { FileTypes, WebSearchSource } from '@renderer/types' -import { ChunkType } from '@renderer/types/chunk' -import type { Message } from '@renderer/types/newMessage' -import type { - OpenAIResponseSdkMessageParam, - OpenAIResponseSdkParams, - OpenAIResponseSdkRawChunk, - OpenAIResponseSdkRawOutput, - OpenAIResponseSdkTool, - OpenAIResponseSdkToolCall -} from '@renderer/types/sdk' -import { addImageFileToContents } from '@renderer/utils/formats' -import { - isSupportedToolUse, - mcpToolCallResponseToOpenAIMessage, - mcpToolsToOpenAIResponseTools, - openAIToolsToMcpTool -} from '@renderer/utils/mcp-tools' -import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find' -import { isSupportDeveloperRoleProvider } from '@renderer/utils/provider' -import { MB } from '@shared/config/constant' -import { t } from 'i18next' -import { isEmpty } from 'lodash' - -import type { RequestTransformer, ResponseChunkTransformer } from '../types' -import { OpenAIAPIClient } from './OpenAIApiClient' -import { OpenAIBaseClient } from './OpenAIBaseClient' - -const logger = loggerService.withContext('OpenAIResponseAPIClient') -export class OpenAIResponseAPIClient extends OpenAIBaseClient< - OpenAI, - OpenAIResponseSdkParams, - OpenAIResponseSdkRawOutput, - OpenAIResponseSdkRawChunk, - OpenAIResponseSdkMessageParam, - OpenAIResponseSdkToolCall, - OpenAIResponseSdkTool -> { - private client: OpenAIAPIClient - constructor(provider: Provider) { - super(provider) - this.client = new OpenAIAPIClient(provider) - } - - private formatApiHost() { - const host = this.provider.apiHost - if (host.endsWith('/openai/v1')) { - return host - } else { - if (host.endsWith('/')) { - return host + 'openai/v1' - } else { - return host + '/openai/v1' - } - } - } - - /** - * 根据模型特征选择合适的客户端 - */ - public getClient(model: Model) { - if (this.provider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) { - return this - } - if (isOpenAILLMModel(model) && !isOpenAIChatCompletionOnlyModel(model)) { - if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') { - this.provider = { ...this.provider, apiHost: this.formatApiHost() } - if (this.provider.apiVersion === 'preview' || this.provider.apiVersion === 'v1') { - return this - } else { - return this.client - } - } - return this - } else { - return this.client - } - } - - /** - * 重写基类方法,返回内部实际使用的客户端类型 - */ - public override getClientCompatibilityType(model?: Model): string[] { - if (!model) { - return [this.constructor.name] - } - - const actualClient = this.getClient(model) - // 避免循环调用:如果返回的是自己,直接返回自己的类型 - if (actualClient === this) { - return [this.constructor.name] - } - return actualClient.getClientCompatibilityType(model) - } - - override async getSdkInstance() { - if (this.sdkInstance) { - return this.sdkInstance - } - const baseUrl = this.getBaseURL() - - if (this.provider.id === 'azure-openai' || 
this.provider.type === 'azure-openai') { - return new AzureOpenAI({ - dangerouslyAllowBrowser: true, - apiKey: this.apiKey, - apiVersion: this.provider.apiVersion, - baseURL: this.provider.apiHost - }) - } else { - return new OpenAI({ - dangerouslyAllowBrowser: true, - apiKey: this.apiKey, - baseURL: baseUrl, - defaultHeaders: { - ...this.defaultHeaders(), - ...this.provider.extra_headers - } - }) - } - } - - override async createCompletions( - payload: OpenAIResponseSdkParams, - options?: OpenAI.RequestOptions - ): Promise { - const sdk = await this.getSdkInstance() - return await sdk.responses.create(payload, options) - } - - private async handlePdfFile(file: FileMetadata): Promise { - if (file.size > 32 * MB) return undefined - try { - const pageCount = await window.api.file.pdfInfo(file.id + file.ext) - if (pageCount > 100) return undefined - } catch { - return undefined - } - - const { data } = await window.api.file.base64File(file.id + file.ext) - return { - type: 'input_file', - filename: file.origin_name, - file_data: `data:application/pdf;base64,${data}` - } as OpenAI.Responses.ResponseInputFile - } - - public async convertMessageToSdkParam(message: Message, model: Model): Promise { - const isVision = isVisionModel(model) - const { textContent, imageContents } = await this.getMessageContent(message) - const fileBlocks = findFileBlocks(message) - const imageBlocks = findImageBlocks(message) - - if (fileBlocks.length === 0 && imageBlocks.length === 0 && imageContents.length === 0) { - if (message.role === 'assistant') { - return { - role: 'assistant', - content: textContent - } - } else { - return { - role: message.role === 'system' ? 'user' : message.role, - content: textContent ? [{ type: 'input_text', text: textContent }] : [] - } as OpenAI.Responses.EasyInputMessage - } - } - - const parts: OpenAI.Responses.ResponseInputContent[] = [] - if (imageContents) { - parts.push({ - type: 'input_text', - text: textContent - }) - } - - if (imageContents.length > 0) { - for (const imageContent of imageContents) { - const image = await window.api.file.base64Image(imageContent.fileId + imageContent.fileExt) - parts.push({ - detail: 'auto', - type: 'input_image', - image_url: image.data - }) - } - } - - for (const imageBlock of imageBlocks) { - if (isVision) { - if (imageBlock.file) { - const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext) - parts.push({ - detail: 'auto', - type: 'input_image', - image_url: image.data as string - }) - } else if (imageBlock.url && imageBlock.url.startsWith('data:')) { - parts.push({ - detail: 'auto', - type: 'input_image', - image_url: imageBlock.url - }) - } - } - } - - for (const fileBlock of fileBlocks) { - const file = fileBlock.file - if (!file) continue - - if (isVision && file.ext === '.pdf') { - const pdfPart = await this.handlePdfFile(file) - if (pdfPart) { - parts.push(pdfPart) - continue - } - } - - if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { - const fileContent = (await window.api.file.read(file.id + file.ext, true)).trim() - parts.push({ - type: 'input_text', - text: file.origin_name + '\n' + fileContent - }) - } - } - - return { - role: message.role === 'system' ? 
'user' : message.role, - content: parts - } - } - - public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): OpenAI.Responses.Tool[] { - return mcpToolsToOpenAIResponseTools(mcpTools) - } - - public convertSdkToolCallToMcp(toolCall: OpenAIResponseSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined { - return openAIToolsToMcpTool(mcpTools, toolCall) - } - public convertSdkToolCallToMcpToolResponse(toolCall: OpenAIResponseSdkToolCall, mcpTool: MCPTool): ToolCallResponse { - const parsedArgs = (() => { - try { - return JSON.parse(toolCall.arguments) - } catch { - return toolCall.arguments - } - })() - - return { - id: toolCall.call_id, - toolCallId: toolCall.call_id, - tool: mcpTool, - arguments: parsedArgs, - status: 'pending' - } - } - - public convertMcpToolResponseToSdkMessageParam( - mcpToolResponse: MCPToolResponse, - resp: MCPCallToolResponse, - model: Model - ): OpenAIResponseSdkMessageParam | undefined { - if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) { - return mcpToolCallResponseToOpenAIMessage(mcpToolResponse, resp, isVisionModel(model)) - } else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) { - return { - type: 'function_call_output', - call_id: mcpToolResponse.toolCallId, - output: JSON.stringify(resp.content) - } - } - return - } - - private convertResponseToMessageContent(response: OpenAI.Responses.Response): ResponseInput { - const content: OpenAI.Responses.ResponseInput = [] - response.output.forEach((item) => { - if (item.type !== 'apply_patch_call' && item.type !== 'apply_patch_call_output') { - content.push(item) - } else if (item.type === 'apply_patch_call') { - if (item.operation !== undefined) { - const applyPatchToolCall: OpenAI.Responses.ResponseInputItem.ApplyPatchCall = { - ...item, - operation: item.operation - } - content.push(applyPatchToolCall) - } else { - logger.warn('Undefined tool call operation for ApplyPatchToolCall.') - } - } else if (item.type === 'apply_patch_call_output') { - if (item.output !== undefined) { - const applyPatchToolCallOutput: OpenAI.Responses.ResponseInputItem.ApplyPatchCallOutput = { - ...item, - output: item.output === null ? 
undefined : item.output - } - content.push(applyPatchToolCallOutput) - } else { - logger.warn('Undefined tool call operation for ApplyPatchToolCall.') - } - } - }) - return content - } - - public buildSdkMessages( - currentReqMessages: OpenAIResponseSdkMessageParam[], - output: OpenAI.Responses.Response | undefined, - toolResults: OpenAIResponseSdkMessageParam[], - toolCalls: OpenAIResponseSdkToolCall[] - ): OpenAIResponseSdkMessageParam[] { - if (!output && toolCalls.length === 0) { - return [...currentReqMessages, ...toolResults] - } - - if (!output) { - return [...currentReqMessages, ...(toolCalls || []), ...(toolResults || [])] - } - - const content = this.convertResponseToMessageContent(output) - - return [...currentReqMessages, ...content, ...(toolResults || [])] - } - - override estimateMessageTokens(message: OpenAIResponseSdkMessageParam): number { - let sum = 0 - if ('content' in message) { - if (typeof message.content === 'string') { - sum += estimateTextTokens(message.content) - } else if (Array.isArray(message.content)) { - for (const part of message.content) { - switch (part.type) { - case 'input_text': - sum += estimateTextTokens(part.text) - break - case 'input_image': - sum += estimateTextTokens(part.image_url || '') - break - default: - break - } - } - } - } - switch (message.type) { - case 'function_call_output': { - let str = '' - if (typeof message.output === 'string') { - str = message.output - } else { - for (const part of message.output) { - switch (part.type) { - case 'input_text': - str += part.text - break - case 'input_image': - str += part.image_url || '' - break - case 'input_file': - str += part.file_data || '' - break - } - } - } - sum += estimateTextTokens(str) - break - } - case 'function_call': - sum += estimateTextTokens(message.arguments) - break - default: - break - } - return sum - } - - public extractMessagesFromSdkPayload(sdkPayload: OpenAIResponseSdkParams): OpenAIResponseSdkMessageParam[] { - if (!sdkPayload.input || typeof sdkPayload.input === 'string') { - return [{ role: 'user', content: sdkPayload.input ?? '' }] - } - return sdkPayload.input - } - - getRequestTransformer(): RequestTransformer { - return { - transform: async ( - coreRequest, - assistant, - model, - isRecursiveCall, - recursiveSdkMessages - ): Promise<{ - payload: OpenAIResponseSdkParams - messages: OpenAIResponseSdkMessageParam[] - metadata: Record - }> => { - const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch, enableGenerateImage } = coreRequest - // 1. 处理系统消息 - const systemMessage: OpenAI.Responses.EasyInputMessage = { - role: 'system', - content: [] - } - - const systemMessageContent: OpenAI.Responses.ResponseInputMessageContentList = [] - const systemMessageInput: OpenAI.Responses.ResponseInputText = { - text: assistant.prompt || '', - type: 'input_text' - } - if ( - isSupportedReasoningEffortOpenAIModel(model) && - isSupportDeveloperRoleProvider(this.provider) && - isOpenAIOpenWeightModel(model) - ) { - systemMessage.role = 'developer' - } - - // 2. 设置工具 - let tools: OpenAI.Responses.Tool[] = [] - const { tools: extraTools } = this.setupToolsConfig({ - mcpTools: mcpTools, - model, - enableToolUse: isSupportedToolUse(assistant) - }) - - systemMessageContent.push(systemMessageInput) - systemMessage.content = systemMessageContent - - // 3. 
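// ---------------------------------------------------------------------------
// Illustrative sketch (an aside, not part of the deleted client): the token
// estimation above reduces to a fold over message parts. `roughTokens` is a
// stand-in for the project's estimateTextTokens; the ~4-chars-per-token
// heuristic is an assumption for the example, not the real tokenizer.
const roughTokens = (text: string): number => Math.ceil(text.length / 4)

type SimplePart =
  | { type: 'input_text'; text: string }
  | { type: 'input_image'; image_url?: string }

function roughMessageTokens(content: string | SimplePart[]): number {
  if (typeof content === 'string') return roughTokens(content)
  return content.reduce((sum, part) => {
    switch (part.type) {
      case 'input_text':
        return sum + roughTokens(part.text)
      case 'input_image':
        return sum + roughTokens(part.image_url ?? '')
      default:
        return sum
    }
  }, 0)
}
// ---------------------------------------------------------------------------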
处理用户消息 - let userMessage: OpenAI.Responses.ResponseInputItem[] = [] - if (typeof messages === 'string') { - userMessage.push({ role: 'user', content: messages }) - } else { - const processedMessages = addImageFileToContents(messages) - for (const message of processedMessages) { - userMessage.push(await this.convertMessageToSdkParam(message, model)) - } - } - // FIXME: 最好还是直接使用previous_response_id来处理(或者在数据库中存储image_generation_call的id) - if (enableGenerateImage) { - const finalAssistantMessage = userMessage.findLast( - (m) => (m as OpenAI.Responses.EasyInputMessage).role === 'assistant' - ) as OpenAI.Responses.EasyInputMessage - const finalUserMessage = userMessage.pop() as OpenAI.Responses.EasyInputMessage - if (finalUserMessage && Array.isArray(finalUserMessage.content)) { - if (finalAssistantMessage && Array.isArray(finalAssistantMessage.content)) { - finalAssistantMessage.content = [...finalAssistantMessage.content, ...finalUserMessage.content] - // 这里是故意将上条助手消息的内容(包含图片和文件)作为用户消息发送 - userMessage = [{ ...finalAssistantMessage, role: 'user' } as OpenAI.Responses.EasyInputMessage] - } else { - userMessage.push(finalUserMessage) - } - } - } - - // 4. 最终请求消息 - let reqMessages: OpenAI.Responses.ResponseInput - if (!systemMessage.content) { - reqMessages = [...userMessage] - } else { - reqMessages = [systemMessage, ...userMessage].filter(Boolean) as OpenAI.Responses.EasyInputMessage[] - } - - if (enableWebSearch) { - tools.push({ - type: 'web_search_preview' - }) - } - - if (enableGenerateImage) { - tools.push({ - type: 'image_generation', - partial_images: streamOutput ? 2 : undefined - }) - } - - tools = tools.concat(extraTools) - - const reasoningEffort = this.getReasoningEffort(assistant, model) - - // minimal cannot be used with web_search tool - if (isGPT5SeriesModel(model) && reasoningEffort.reasoning?.effort === 'minimal' && enableWebSearch) { - reasoningEffort.reasoning.effort = 'low' - } - - const commonParams: OpenAIResponseSdkParams = { - model: model.id, - input: - isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0 - ? recursiveSdkMessages - : reqMessages, - temperature: this.getTemperature(assistant, model), - top_p: this.getTopP(assistant, model), - max_output_tokens: maxTokens, - stream: streamOutput, - tools: !isEmpty(tools) ? tools : undefined, - // groq 有不同的 service tier 配置,不符合 openai 接口类型 - service_tier: this.getServiceTier(model) as OpenAIServiceTier, - ...(isSupportVerbosityModel(model) - ? { - text: { - verbosity: this.getVerbosity(model) - } - } - : {}), - ...(this.getReasoningEffort(assistant, model) as OpenAI.Reasoning), - // 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑 - // 注意:用户自定义参数总是应该覆盖其他参数 - ...(coreRequest.callType === 'chat' ? 
this.getCustomParameters(assistant) : {}) - } - const timeout = this.getTimeout(model) - return { payload: commonParams, messages: reqMessages, metadata: { timeout } } - } - } - } - - getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer { - const toolCalls: OpenAIResponseSdkToolCall[] = [] - const outputItems: OpenAI.Responses.ResponseOutputItem[] = [] - let hasBeenCollectedToolCalls = false - let hasReasoningSummary = false - let isFirstThinkingChunk = true - let isFirstTextChunk = true - return () => ({ - async transform(chunk: OpenAIResponseSdkRawChunk, controller: TransformStreamDefaultController) { - if (typeof chunk === 'string') { - try { - chunk = JSON.parse(chunk) - } catch (error) { - logger.error('invalid chunk', { chunk, error }) - throw new Error(t('error.chat.chunk.non_json')) - } - } - // 处理chunk - if ('output' in chunk) { - if (ctx._internal?.toolProcessingState) { - ctx._internal.toolProcessingState.output = chunk - } - for (const output of chunk.output) { - switch (output.type) { - case 'message': - if (output.content[0].type === 'output_text') { - if (isFirstTextChunk) { - controller.enqueue({ - type: ChunkType.TEXT_START - }) - isFirstTextChunk = false - } - controller.enqueue({ - type: ChunkType.TEXT_DELTA, - text: output.content[0].text - }) - if (output.content[0].annotations && output.content[0].annotations.length > 0) { - controller.enqueue({ - type: ChunkType.LLM_WEB_SEARCH_COMPLETE, - llm_web_search: { - source: WebSearchSource.OPENAI_RESPONSE, - results: output.content[0].annotations - } - }) - } - } - break - case 'reasoning': - if (isFirstThinkingChunk) { - controller.enqueue({ - type: ChunkType.THINKING_START - }) - isFirstThinkingChunk = false - } - controller.enqueue({ - type: ChunkType.THINKING_DELTA, - text: output.summary.map((s) => s.text).join('\n') - }) - break - case 'function_call': - toolCalls.push(output) - break - case 'image_generation_call': - controller.enqueue({ - type: ChunkType.IMAGE_CREATED - }) - controller.enqueue({ - type: ChunkType.IMAGE_COMPLETE, - image: { - type: 'base64', - images: [`data:image/png;base64,${output.result}`] - } - }) - } - } - if (toolCalls.length > 0) { - controller.enqueue({ - type: ChunkType.MCP_TOOL_CREATED, - tool_calls: toolCalls - }) - } - controller.enqueue({ - type: ChunkType.LLM_RESPONSE_COMPLETE, - response: { - usage: { - prompt_tokens: chunk.usage?.input_tokens || 0, - completion_tokens: chunk.usage?.output_tokens || 0, - total_tokens: chunk.usage?.total_tokens || 0 - } - } - }) - } else { - switch (chunk.type) { - case 'response.output_item.added': - if (chunk.item.type === 'function_call') { - outputItems.push(chunk.item) - } else if (chunk.item.type === 'web_search_call') { - controller.enqueue({ - type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS - }) - } - break - case 'response.reasoning_summary_part.added': - if (hasReasoningSummary) { - const separator = '\n\n' - controller.enqueue({ - type: ChunkType.THINKING_DELTA, - text: separator - }) - } - hasReasoningSummary = true - break - case 'response.reasoning_summary_text.delta': - if (isFirstThinkingChunk) { - controller.enqueue({ - type: ChunkType.THINKING_START - }) - isFirstThinkingChunk = false - } - controller.enqueue({ - type: ChunkType.THINKING_DELTA, - text: chunk.delta - }) - break - case 'response.image_generation_call.generating': - controller.enqueue({ - type: ChunkType.IMAGE_CREATED - }) - break - case 'response.image_generation_call.partial_image': - controller.enqueue({ - type: ChunkType.IMAGE_DELTA, - image: { - 
type: 'base64', - images: [`data:image/png;base64,${chunk.partial_image_b64}`] - } - }) - break - case 'response.image_generation_call.completed': - controller.enqueue({ - type: ChunkType.IMAGE_COMPLETE - }) - break - case 'response.output_text.delta': { - if (isFirstTextChunk) { - controller.enqueue({ - type: ChunkType.TEXT_START - }) - isFirstTextChunk = false - } - controller.enqueue({ - type: ChunkType.TEXT_DELTA, - text: chunk.delta - }) - break - } - case 'response.function_call_arguments.done': { - const outputItem: OpenAI.Responses.ResponseOutputItem | undefined = outputItems.find( - (item) => item.id === chunk.item_id - ) - if (outputItem) { - if (outputItem.type === 'function_call') { - toolCalls.push({ - ...outputItem, - arguments: chunk.arguments, - status: 'completed' - }) - } - } - break - } - case 'response.content_part.done': { - if (chunk.part.type === 'output_text' && chunk.part.annotations && chunk.part.annotations.length > 0) { - controller.enqueue({ - type: ChunkType.LLM_WEB_SEARCH_COMPLETE, - llm_web_search: { - source: WebSearchSource.OPENAI_RESPONSE, - results: chunk.part.annotations - } - }) - } - if (toolCalls.length > 0 && !hasBeenCollectedToolCalls) { - controller.enqueue({ - type: ChunkType.MCP_TOOL_CREATED, - tool_calls: toolCalls - }) - hasBeenCollectedToolCalls = true - } - break - } - case 'response.completed': { - if (ctx._internal?.toolProcessingState) { - ctx._internal.toolProcessingState.output = chunk.response - } - if (toolCalls.length > 0 && !hasBeenCollectedToolCalls) { - controller.enqueue({ - type: ChunkType.MCP_TOOL_CREATED, - tool_calls: toolCalls - }) - hasBeenCollectedToolCalls = true - } - const completion_tokens = chunk.response.usage?.output_tokens || 0 - const total_tokens = chunk.response.usage?.total_tokens || 0 - controller.enqueue({ - type: ChunkType.LLM_RESPONSE_COMPLETE, - response: { - usage: { - prompt_tokens: chunk.response.usage?.input_tokens || 0, - completion_tokens: completion_tokens, - total_tokens: total_tokens - } - } - }) - break - } - case 'error': { - controller.enqueue({ - type: ChunkType.ERROR, - error: { - message: chunk.message, - code: chunk.code - } - }) - break - } - } - } - } - }) - } -} diff --git a/src/renderer/src/aiCore/legacy/clients/openai/azureOpenAIEndpoint.ts b/src/renderer/src/aiCore/legacy/clients/openai/azureOpenAIEndpoint.ts deleted file mode 100644 index 777dbe74d7..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/openai/azureOpenAIEndpoint.ts +++ /dev/null @@ -1,4 +0,0 @@ -export function normalizeAzureOpenAIEndpoint(apiHost: string): string { - const normalizedHost = apiHost.replace(/\/+$/, '') - return normalizedHost.replace(/\/openai(?:\/v1)?$/i, '') -} diff --git a/src/renderer/src/aiCore/legacy/clients/ovms/OVMSClient.ts b/src/renderer/src/aiCore/legacy/clients/ovms/OVMSClient.ts deleted file mode 100644 index 4936b693ee..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/ovms/OVMSClient.ts +++ /dev/null @@ -1,56 +0,0 @@ -import type OpenAI from '@cherrystudio/openai' -import { loggerService } from '@logger' -import { isSupportedModel } from '@renderer/config/models' -import type { Provider } from '@renderer/types' -import { objectKeys } from '@renderer/types' -import { formatApiHost } from '@renderer/utils' -import { withoutTrailingApiVersion } from '@shared/utils' - -import { OpenAIAPIClient } from '../openai/OpenAIApiClient' - -const logger = loggerService.withContext('OVMSClient') - -export class OVMSClient extends OpenAIAPIClient { - constructor(provider: Provider) { - 
super(provider) - } - - override async listModels(): Promise { - try { - const sdk = await this.getSdkInstance() - const url = formatApiHost(withoutTrailingApiVersion(this.getBaseURL()), true, 'v1') - const chatModelsResponse = await sdk.withOptions({ baseURL: url }).get('/config') - logger.debug(`Chat models response: ${JSON.stringify(chatModelsResponse)}`) - - // Parse the config response to extract model information - const config = chatModelsResponse as Record - const models = objectKeys(config) - .map((modelName) => { - const modelInfo = config[modelName] - - // Check if model has at least one version with "AVAILABLE" state - const hasAvailableVersion = modelInfo?.model_version_status?.some( - (versionStatus: any) => versionStatus?.state === 'AVAILABLE' - ) - - if (hasAvailableVersion) { - return { - id: modelName, - object: 'model' as const, - owned_by: 'ovms', - created: Date.now() - } - } - return null // Skip models without available versions - }) - .filter(Boolean) // Remove null entries - logger.debug(`Processed models: ${JSON.stringify(models)}`) - - // Filter out unsupported models - return models.filter((model): model is OpenAI.Models.Model => model !== null && isSupportedModel(model)) - } catch (error) { - logger.error(`Error listing OVMS models: ${error}`) - return [] - } - } -} diff --git a/src/renderer/src/aiCore/legacy/clients/ppio/PPIOAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/ppio/PPIOAPIClient.ts deleted file mode 100644 index 345496e156..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/ppio/PPIOAPIClient.ts +++ /dev/null @@ -1,72 +0,0 @@ -import type OpenAI from '@cherrystudio/openai' -import { loggerService } from '@logger' -import { isSupportedModel } from '@renderer/config/models' -import type { Model, Provider } from '@renderer/types' - -import { OpenAIAPIClient } from '../openai/OpenAIApiClient' - -const logger = loggerService.withContext('PPIOAPIClient') -export class PPIOAPIClient extends OpenAIAPIClient { - constructor(provider: Provider) { - super(provider) - } - - // oxlint-disable-next-line @typescript-eslint/no-unused-vars - override getClientCompatibilityType(_model?: Model): string[] { - return ['OpenAIAPIClient'] - } - - override async listModels(): Promise { - try { - const sdk = await this.getSdkInstance() - - // PPIO requires three separate requests to get all model types - const [chatModelsResponse, embeddingModelsResponse, rerankerModelsResponse] = await Promise.all([ - // Chat/completion models - sdk.request({ - method: 'get', - path: '/models' - }), - // Embedding models - sdk.request({ - method: 'get', - path: '/models?model_type=embedding' - }), - // Reranker models - sdk.request({ - method: 'get', - path: '/models?model_type=reranker' - }) - ]) - - // Extract models from all responses - // @ts-ignore - PPIO response structure may not be typed - const allModels = [ - ...((chatModelsResponse as any)?.data || []), - ...((embeddingModelsResponse as any)?.data || []), - ...((rerankerModelsResponse as any)?.data || []) - ] - - // Process and standardize model data - const processedModels = allModels.map((model: any) => ({ - id: model.id || model.name, - description: model.description || model.display_name || model.summary, - object: 'model' as const, - owned_by: model.owned_by || model.publisher || model.organization || 'ppio', - created: model.created || Date.now() - })) - - // Clean up model IDs and filter supported models - processedModels.forEach((model) => { - if (model.id) { - model.id = model.id.trim() - } - }) - - return 
processedModels.filter(isSupportedModel) - } catch (error) { - logger.error('Error listing PPIO models:', error as Error) - return [] - } - } -} diff --git a/src/renderer/src/aiCore/legacy/clients/types.ts b/src/renderer/src/aiCore/legacy/clients/types.ts deleted file mode 100644 index bf7b129d93..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/types.ts +++ /dev/null @@ -1,141 +0,0 @@ -import type Anthropic from '@anthropic-ai/sdk' -import type OpenAI from '@cherrystudio/openai' -import type { Assistant, MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types' -import type { Provider } from '@renderer/types' -import type { - AnthropicSdkRawChunk, - OpenAIResponseSdkRawChunk, - OpenAIResponseSdkRawOutput, - OpenAISdkRawChunk, - SdkMessageParam, - SdkParams, - SdkRawChunk, - SdkRawOutput, - SdkTool, - SdkToolCall -} from '@renderer/types/sdk' - -import type { CompletionsParams, GenericChunk } from '../middleware/schemas' -import type { CompletionsContext } from '../middleware/types' - -/** - * 原始流监听器接口 - */ -export interface RawStreamListener { - onChunk?: (chunk: TRawChunk) => void - onStart?: () => void - onEnd?: () => void - onError?: (error: Error) => void -} - -/** - * OpenAI 专用的流监听器 - */ -export interface OpenAIStreamListener extends RawStreamListener { - onChoice?: (choice: OpenAI.Chat.Completions.ChatCompletionChunk.Choice) => void - onFinishReason?: (reason: string) => void -} - -/** - * OpenAI Response 专用的流监听器 - */ -export interface OpenAIResponseStreamListener - extends RawStreamListener { - onMessage?: (response: OpenAIResponseSdkRawOutput) => void -} - -/** - * Anthropic 专用的流监听器 - */ -export interface AnthropicStreamListener - extends RawStreamListener { - onContentBlock?: (contentBlock: Anthropic.Messages.ContentBlock) => void - onMessage?: (message: Anthropic.Messages.Message) => void -} - -/** - * 请求转换器接口 - */ -export interface RequestTransformer< - TSdkParams extends SdkParams = SdkParams, - TMessageParam extends SdkMessageParam = SdkMessageParam -> { - transform( - completionsParams: CompletionsParams, - assistant: Assistant, - model: Model, - isRecursiveCall?: boolean, - recursiveSdkMessages?: TMessageParam[] - ): Promise<{ - payload: TSdkParams - messages: TMessageParam[] - metadata?: Record - }> -} - -/** - * 响应块转换器接口 - */ -export type ResponseChunkTransformer = ( - context?: TContext -) => Transformer - -export interface ResponseChunkTransformerContext { - isStreaming: boolean - isEnabledToolCalling: boolean - isEnabledWebSearch: boolean - isEnabledUrlContext: boolean - isEnabledReasoning: boolean - mcpTools: MCPTool[] - provider: Provider -} - -/** - * API客户端接口 - */ -export interface ApiClient< - TSdkInstance = any, - TSdkParams extends SdkParams = SdkParams, - TRawOutput extends SdkRawOutput = SdkRawOutput, - TRawChunk extends SdkRawChunk = SdkRawChunk, - TMessageParam extends SdkMessageParam = SdkMessageParam, - TToolCall extends SdkToolCall = SdkToolCall, - TSdkSpecificTool extends SdkTool = SdkTool -> { - provider: Provider - - // 核心方法 - 在中间件架构中,这个方法可能只是一个占位符 - // 实际的SDK调用由SdkCallMiddleware处理 - // completions(params: CompletionsParams): Promise - - createCompletions(payload: TSdkParams): Promise - - // SDK相关方法 - getSdkInstance(): Promise | TSdkInstance - getRequestTransformer(): RequestTransformer - getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer - - // 原始流监听方法 - attachRawStreamListener?(rawOutput: TRawOutput, listener: RawStreamListener): TRawOutput - - // 工具转换相关方法 (保持可选,因为不是所有Provider都支持工具) - 
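// ---------------------------------------------------------------------------
// Illustrative sketch (an aside, not part of the deleted types): a minimal
// object satisfying the RawStreamListener shape above — all hooks are
// optional, so a chunk-counting debug listener only needs the hooks it
// cares about.
function createCountingListener<TRawChunk>() {
  let count = 0
  return {
    onStart: () => {
      count = 0
    },
    onChunk: (_chunk: TRawChunk) => {
      count += 1
    },
    onEnd: () => {
      console.debug(`stream finished after ${count} chunks`)
    },
    onError: (error: Error) => {
      console.error('stream failed before completion', error)
    }
  }
}
// ---------------------------------------------------------------------------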
convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[] - convertMcpToolResponseToSdkMessageParam?( - mcpToolResponse: MCPToolResponse, - resp: any, - model: Model - ): TMessageParam | undefined - convertSdkToolCallToMcp?(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined - convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse - - // 构建SDK特定的消息列表,用于工具调用后的递归调用 - buildSdkMessages( - currentReqMessages: TMessageParam[], - output: TRawOutput | string, - toolResults: TMessageParam[], - toolCalls?: TToolCall[] - ): TMessageParam[] - - // 从SDK载荷中提取消息数组(用于中间件中的类型安全访问) - extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[] -} diff --git a/src/renderer/src/aiCore/legacy/clients/zhipu/ZhipuAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/zhipu/ZhipuAPIClient.ts deleted file mode 100644 index 9c590996f1..0000000000 --- a/src/renderer/src/aiCore/legacy/clients/zhipu/ZhipuAPIClient.ts +++ /dev/null @@ -1,105 +0,0 @@ -import type OpenAI from '@cherrystudio/openai' -import { loggerService } from '@logger' -import type { Provider } from '@renderer/types' -import type { GenerateImageParams } from '@renderer/types' - -import { OpenAIAPIClient } from '../openai/OpenAIApiClient' - -const logger = loggerService.withContext('ZhipuAPIClient') - -export class ZhipuAPIClient extends OpenAIAPIClient { - constructor(provider: Provider) { - super(provider) - } - - override getClientCompatibilityType(): string[] { - return ['ZhipuAPIClient'] - } - - override async generateImage({ - model, - prompt, - negativePrompt, - imageSize, - batchSize, - signal, - quality - }: GenerateImageParams): Promise { - const sdk = await this.getSdkInstance() - - // 智谱AI使用不同的参数格式 - const body: any = { - model, - prompt - } - - // 智谱AI特有的参数格式 - body.size = imageSize - body.n = batchSize - if (negativePrompt) { - body.negative_prompt = negativePrompt - } - - // 只有cogview-4-250304模型支持quality和style参数 - if (model === 'cogview-4-250304') { - if (quality) { - body.quality = quality - } - body.style = 'vivid' - } - - try { - logger.debug('Calling Zhipu image generation API with params:', body) - - const response = await sdk.images.generate(body, { signal }) - - if (response.data && response.data.length > 0) { - return response.data.map((image: any) => image.url).filter(Boolean) - } - - return [] - } catch (error) { - logger.error('Zhipu image generation failed:', error as Error) - throw error - } - } - - public async listModels(): Promise { - const models = [ - 'glm-4.7', - 'glm-4.6', - 'glm-4.6v', - 'glm-4.6v-flash', - 'glm-4.6v-flashx', - 'glm-4.5', - 'glm-4.5-x', - 'glm-4.5-air', - 'glm-4.5-airx', - 'glm-4.5-flash', - 'glm-4.5v', - 'glm-z1-air', - 'glm-z1-airx', - 'cogview-3-flash', - 'cogview-4-250304', - 'glm-4-long', - 'glm-4-plus', - 'glm-4-air-250414', - 'glm-4-airx', - 'glm-4-flashx', - 'glm-4v', - 'glm-4v-flash', - 'glm-4v-plus-0111', - 'glm-4.1v-thinking-flash', - 'glm-4-alltools', - 'embedding-3' - ] - - const created = Date.now() - return models.map((id) => ({ - id, - owned_by: 'zhipu', - object: 'model' as const, - created - })) - } -} diff --git a/src/renderer/src/aiCore/legacy/index.ts b/src/renderer/src/aiCore/legacy/index.ts deleted file mode 100644 index 7c5f5211d9..0000000000 --- a/src/renderer/src/aiCore/legacy/index.ts +++ /dev/null @@ -1,185 +0,0 @@ -import { loggerService } from '@logger' -import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory' -import type { BaseApiClient } from 
'@renderer/aiCore/legacy/clients/BaseApiClient' -import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models' -import { withSpanResult } from '@renderer/services/SpanManagerService' -import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity' -import type { GenerateImageParams, Model, Provider } from '@renderer/types' -import type { RequestOptions, SdkModel } from '@renderer/types/sdk' -import { isSupportedToolUse } from '@renderer/utils/mcp-tools' - -import { AihubmixAPIClient } from './clients/aihubmix/AihubmixAPIClient' -import { VertexAPIClient } from './clients/gemini/VertexAPIClient' -import { NewAPIClient } from './clients/newapi/NewAPIClient' -import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient' -import { CompletionsMiddlewareBuilder } from './middleware/builder' -import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware' -import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware' -import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware' -import { applyCompletionsMiddlewares } from './middleware/composer' -import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware' -import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware' -import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware' -import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware' -import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware' -import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware' -import { MiddlewareRegistry } from './middleware/register' -import type { CompletionsParams, CompletionsResult } from './middleware/schemas' - -const logger = loggerService.withContext('AiProvider') - -export default class AiProvider { - private apiClient: BaseApiClient - - constructor(provider: Provider) { - // Use the new ApiClientFactory to get a BaseApiClient instance - this.apiClient = ApiClientFactory.create(provider) - } - - public async completions(params: CompletionsParams, options?: RequestOptions): Promise { - // 1. 根据模型识别正确的客户端 - const model = params.assistant.model - if (!model) { - return Promise.reject(new Error('Model is required')) - } - - // 根据client类型选择合适的处理方式 - let client: BaseApiClient - - if (this.apiClient instanceof AihubmixAPIClient) { - // AihubmixAPIClient: 根据模型选择合适的子client - client = this.apiClient.getClientForModel(model) - if (client instanceof OpenAIResponseAPIClient) { - client = client.getClient(model) as BaseApiClient - } - } else if (this.apiClient instanceof NewAPIClient) { - client = this.apiClient.getClientForModel(model) - if (client instanceof OpenAIResponseAPIClient) { - client = client.getClient(model) as BaseApiClient - } - } else if (this.apiClient instanceof OpenAIResponseAPIClient) { - // OpenAIResponseAPIClient: 根据模型特征选择API类型 - client = this.apiClient.getClient(model) as BaseApiClient - } else if (this.apiClient instanceof VertexAPIClient) { - client = this.apiClient.getClient(model) as BaseApiClient - } else { - // 其他client直接使用 - client = this.apiClient - } - - // 2. 
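// ---------------------------------------------------------------------------
// Illustrative sketch (assumed shape, not the real client classes): the
// instanceof ladder below is an "unwrap until stable" resolution — a routing
// client (Aihubmix/NewAPI/OpenAIResponse/Vertex) may hand back another
// routing client, so selection repeats until the client stops changing.
interface RoutableClient {
  getClientForModel?: (model: { id: string }) => RoutableClient
}

function resolveClient(root: RoutableClient, model: { id: string }): RoutableClient {
  let current = root
  // A small depth cap guards against accidental routing cycles.
  for (let i = 0; i < 4; i++) {
    const next = current.getClientForModel?.(model) ?? current
    if (next === current) break
    current = next
  }
  return current
}
// ---------------------------------------------------------------------------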
构建中间件链 - const builder = CompletionsMiddlewareBuilder.withDefaults() - // images api - if (isDedicatedImageGenerationModel(model)) { - builder.clear() - builder - .add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName]) - .add(MiddlewareRegistry[ErrorHandlerMiddlewareName]) - .add(MiddlewareRegistry[AbortHandlerMiddlewareName]) - .add(MiddlewareRegistry[ImageGenerationMiddlewareName]) - } else { - // Existing logic for other models - logger.silly('Builder Params', params) - // 使用兼容性类型检查,避免typescript类型收窄和装饰器模式的问题 - const clientTypes = client.getClientCompatibilityType(model) - const isOpenAICompatible = - clientTypes.includes('OpenAIAPIClient') || clientTypes.includes('OpenAIResponseAPIClient') - if (!isOpenAICompatible) { - logger.silly('ThinkingTagExtractionMiddleware is removed') - builder.remove(ThinkingTagExtractionMiddlewareName) - } - - const isAnthropicOrOpenAIResponseCompatible = - clientTypes.includes('AnthropicAPIClient') || - clientTypes.includes('OpenAIResponseAPIClient') || - clientTypes.includes('AnthropicVertexAPIClient') - if (!isAnthropicOrOpenAIResponseCompatible) { - logger.silly('RawStreamListenerMiddleware is removed') - builder.remove(RawStreamListenerMiddlewareName) - } - if (!params.enableWebSearch) { - logger.silly('WebSearchMiddleware is removed') - builder.remove(WebSearchMiddlewareName) - } - if (!params.mcpTools?.length) { - builder.remove(ToolUseExtractionMiddlewareName) - logger.silly('ToolUseExtractionMiddleware is removed') - builder.remove(McpToolChunkMiddlewareName) - logger.silly('McpToolChunkMiddleware is removed') - } - if (isSupportedToolUse(params.assistant) && isFunctionCallingModel(model)) { - builder.remove(ToolUseExtractionMiddlewareName) - logger.silly('ToolUseExtractionMiddleware is removed') - } - if (params.callType !== 'chat' && params.callType !== 'check' && params.callType !== 'translate') { - logger.silly('AbortHandlerMiddleware is removed') - builder.remove(AbortHandlerMiddlewareName) - } - if (params.callType === 'test') { - builder.remove(ErrorHandlerMiddlewareName) - logger.silly('ErrorHandlerMiddleware is removed') - builder.remove(FinalChunkConsumerMiddlewareName) - logger.silly('FinalChunkConsumerMiddleware is removed') - } - } - - const middlewares = builder.build() - logger.silly( - 'middlewares', - middlewares.map((m) => m.name) - ) - - // 3. Create the wrapped SDK method with middlewares - const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares) - - // 4. Execute the wrapped method with the original params - const result = wrappedCompletionMethod(params, options) - return result - } - - public async completionsForTrace(params: CompletionsParams, options?: RequestOptions): Promise { - const traceName = params.assistant.model?.name - ? 
`${params.assistant.model?.name}.${params.callType}` - : `LLM.${params.callType}` - - const traceParams: StartSpanParams = { - name: traceName, - tag: 'LLM', - topicId: params.topicId || '', - modelName: params.assistant.model?.name - } - - return await withSpanResult(this.completions.bind(this), traceParams, params, options) - } - - public async models(): Promise { - return this.apiClient.listModels() - } - - public async getEmbeddingDimensions(model: Model): Promise { - try { - // Use the SDK instance to test embedding capabilities - const dimensions = await this.apiClient.getEmbeddingDimensions(model) - return dimensions - } catch (error) { - logger.error('Error getting embedding dimensions:', error as Error) - throw error - } - } - - public async generateImage(params: GenerateImageParams): Promise { - if (this.apiClient instanceof AihubmixAPIClient) { - const client = this.apiClient.getClientForModel({ id: params.model } as Model) - return client.generateImage(params) - } - return this.apiClient.generateImage(params) - } - - public getBaseURL(): string { - return this.apiClient.getBaseURL() - } - - public getApiKey(): string { - return this.apiClient.getApiKey() - } -} diff --git a/src/renderer/src/aiCore/legacy/middleware/BUILDER_USAGE.md b/src/renderer/src/aiCore/legacy/middleware/BUILDER_USAGE.md deleted file mode 100644 index 27d9e32136..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/BUILDER_USAGE.md +++ /dev/null @@ -1,182 +0,0 @@ -# MiddlewareBuilder 使用指南 - -`MiddlewareBuilder` 是一个用于动态构建和管理中间件链的工具,提供灵活的中间件组织和配置能力。 - -## 主要特性 - -### 1. 统一的中间件命名 - -所有中间件都通过导出的 `MIDDLEWARE_NAME` 常量标识: - -```typescript -// 中间件文件示例 -export const MIDDLEWARE_NAME = 'SdkCallMiddleware' -export const SdkCallMiddleware: CompletionsMiddleware = ... -``` - -### 2. NamedMiddleware 接口 - -中间件使用统一的 `NamedMiddleware` 接口格式: - -```typescript -interface NamedMiddleware { - name: string - middleware: TMiddleware -} -``` - -### 3. 中间件注册表 - -通过 `MiddlewareRegistry` 集中管理所有可用中间件: - -```typescript -import { MiddlewareRegistry } from './register' - -// 通过名称获取中间件 -const sdkCallMiddleware = MiddlewareRegistry['SdkCallMiddleware'] -``` - -## 基本用法 - -### 1. 使用默认中间件链 - -```typescript -import { CompletionsMiddlewareBuilder } from './builder' - -const builder = CompletionsMiddlewareBuilder.withDefaults() -const middlewares = builder.build() -``` - -### 2. 自定义中间件链 - -```typescript -import { createCompletionsBuilder, MiddlewareRegistry } from './builder' - -const builder = createCompletionsBuilder([ - MiddlewareRegistry['AbortHandlerMiddleware'], - MiddlewareRegistry['TextChunkMiddleware'] -]) - -const middlewares = builder.build() -``` - -### 3. 动态调整中间件链 - -```typescript -const builder = CompletionsMiddlewareBuilder.withDefaults() - -// 根据条件添加、移除、替换中间件 -if (needsLogging) { - builder.prepend(MiddlewareRegistry['GenericLoggingMiddleware']) -} - -if (disableTools) { - builder.remove('McpToolChunkMiddleware') -} - -if (customThinking) { - builder.replace('ThinkingTagExtractionMiddleware', customThinkingMiddleware) -} - -const middlewares = builder.build() -``` - -### 4. 
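A chain can also be inspected before it is built, which makes conditional edits safer (illustrative usage; assumes the named middlewares are registered as above):

```typescript
const builder = CompletionsMiddlewareBuilder.withDefaults()

// Only replace the middleware when it is actually present in the chain.
if (builder.has('ThinkingTagExtractionMiddleware')) {
  builder.replace('ThinkingTagExtractionMiddleware', customThinkingMiddleware)
}

// getChain() keeps the name metadata, e.g. for logging the final order.
console.log(builder.getChain().map((entry) => entry.name))
```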
链式操作 - -```typescript -const middlewares = CompletionsMiddlewareBuilder.withDefaults() - .add(MiddlewareRegistry['CustomMiddleware']) - .insertBefore('SdkCallMiddleware', MiddlewareRegistry['SecurityCheckMiddleware']) - .remove('WebSearchMiddleware') - .build() -``` - -## API 参考 - -### CompletionsMiddlewareBuilder - -**静态方法:** - -- `static withDefaults()`: 创建带有默认中间件链的构建器 - -**实例方法:** - -- `add(middleware: NamedMiddleware)`: 在链末尾添加中间件 -- `prepend(middleware: NamedMiddleware)`: 在链开头添加中间件 -- `insertAfter(targetName: string, middleware: NamedMiddleware)`: 在指定中间件后插入 -- `insertBefore(targetName: string, middleware: NamedMiddleware)`: 在指定中间件前插入 -- `replace(targetName: string, middleware: NamedMiddleware)`: 替换指定中间件 -- `remove(targetName: string)`: 移除指定中间件 -- `has(name: string)`: 检查是否包含指定中间件 -- `build()`: 构建最终的中间件数组 -- `getChain()`: 获取当前链(包含名称信息) -- `clear()`: 清空中间件链 -- `execute(context, params, middlewareExecutor)`: 直接执行构建好的中间件链 - -### 工厂函数 - -- `createCompletionsBuilder(baseChain?)`: 创建 Completions 中间件构建器 -- `createMethodBuilder(baseChain?)`: 创建通用方法中间件构建器 -- `addMiddlewareName(middleware, name)`: 为中间件添加名称属性的辅助函数 - -### 中间件注册表 - -- `MiddlewareRegistry`: 所有注册中间件的集中访问点 -- `getMiddleware(name)`: 根据名称获取中间件 -- `getRegisteredMiddlewareNames()`: 获取所有注册的中间件名称 -- `DefaultCompletionsNamedMiddlewares`: 默认的 Completions 中间件链(NamedMiddleware 格式) - -## 类型安全 - -构建器提供完整的 TypeScript 类型支持: - -- `CompletionsMiddlewareBuilder` 专门用于 `CompletionsMiddleware` 类型 -- `MethodMiddlewareBuilder` 用于通用的 `MethodMiddleware` 类型 -- 所有中间件操作都基于 `NamedMiddleware` 接口 - -## 默认中间件链 - -默认的 Completions 中间件执行顺序: - -1. `FinalChunkConsumerMiddleware` - 最终消费者 -2. `TransformCoreToSdkParamsMiddleware` - 参数转换 -3. `AbortHandlerMiddleware` - 中止处理 -4. `McpToolChunkMiddleware` - 工具处理 -5. `WebSearchMiddleware` - Web搜索处理 -6. `TextChunkMiddleware` - 文本处理 -7. `ThinkingTagExtractionMiddleware` - 思考标签提取处理 -8. `ThinkChunkMiddleware` - 思考处理 -9. `ResponseTransformMiddleware` - 响应转换 -10. `StreamAdapterMiddleware` - 流适配器 -11. `SdkCallMiddleware` - SDK调用 - -## 在 AiProvider 中的使用 - -```typescript -export default class AiProvider { - public async completions(params: CompletionsParams): Promise { - // 1. 构建中间件链 - const builder = CompletionsMiddlewareBuilder.withDefaults() - - // 2. 根据参数动态调整 - if (params.enableCustomFeature) { - builder.insertAfter('StreamAdapterMiddleware', customFeatureMiddleware) - } - - // 3. 应用中间件 - const middlewares = builder.build() - const wrappedMethod = applyCompletionsMiddlewares(this.apiClient, this.apiClient.createCompletions, middlewares) - - return wrappedMethod(params) - } -} -``` - -## 注意事项 - -1. **类型兼容性**:`MethodMiddleware` 和 `CompletionsMiddleware` 不兼容,需要使用对应的构建器 -2. **中间件名称**:所有中间件必须导出 `MIDDLEWARE_NAME` 常量用于标识 -3. **注册表管理**:新增中间件需要在 `register.ts` 中注册 -4. **默认链**:默认链通过 `DefaultCompletionsNamedMiddlewares` 提供,支持延迟加载避免循环依赖 - -这种设计使得中间件链的构建既灵活又类型安全,同时保持了简洁的 API 接口。 diff --git a/src/renderer/src/aiCore/legacy/middleware/MIDDLEWARE_SPECIFICATION.md b/src/renderer/src/aiCore/legacy/middleware/MIDDLEWARE_SPECIFICATION.md deleted file mode 100644 index 6437282ff2..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/MIDDLEWARE_SPECIFICATION.md +++ /dev/null @@ -1,175 +0,0 @@ -# Cherry Studio 中间件规范 - -本文档定义了 Cherry Studio `aiCore` 模块中中间件的设计、实现和使用规范。目标是建立一个灵活、可维护且易于扩展的中间件系统。 - -## 1. 核心概念 - -### 1.1. 中间件 (Middleware) - -中间件是一个函数或对象,它在 AI 请求的处理流程中的特定阶段执行,可以访问和修改请求上下文 (`AiProviderMiddlewareContext`)、请求参数 (`Params`),并控制是否将请求传递给下一个中间件或终止流程。 - -每个中间件应该专注于一个单一的横切关注点,例如日志记录、错误处理、流适配、特性解析等。 - -### 1.2. 
`AiProviderMiddlewareContext` (上下文对象) - -这是一个在整个中间件链执行过程中传递的对象,包含以下核心信息: - -- `_apiClientInstance: ApiClient`: 当前选定的、已实例化的 AI Provider 客户端。 -- `_coreRequest: CoreRequestType`: 标准化的内部核心请求对象。 -- `resolvePromise: (value: AggregatedResultType) => void`: 用于在整个操作成功完成时解析 `AiCoreService` 返回的 Promise。 -- `rejectPromise: (reason?: any) => void`: 用于在发生错误时拒绝 `AiCoreService` 返回的 Promise。 -- `onChunk?: (chunk: Chunk) => void`: 应用层提供的流式数据块回调。 -- `abortController?: AbortController`: 用于中止请求的控制器。 -- 其他中间件可能读写的、与当前请求相关的动态数据。 - -### 1.3. `MiddlewareName` (中间件名称) - -为了方便动态操作(如插入、替换、移除)中间件,每个重要的、可能被其他逻辑引用的中间件都应该有一个唯一的、可识别的名称。推荐使用 TypeScript 的 `enum` 来定义: - -```typescript -// example -export enum MiddlewareName { - LOGGING_START = 'LoggingStartMiddleware', - LOGGING_END = 'LoggingEndMiddleware', - ERROR_HANDLING = 'ErrorHandlingMiddleware', - ABORT_HANDLER = 'AbortHandlerMiddleware', - // Core Flow - TRANSFORM_CORE_TO_SDK_PARAMS = 'TransformCoreToSdkParamsMiddleware', - REQUEST_EXECUTION = 'RequestExecutionMiddleware', - STREAM_ADAPTER = 'StreamAdapterMiddleware', - RAW_SDK_CHUNK_TO_APP_CHUNK = 'RawSdkChunkToAppChunkMiddleware', - // Features - THINKING_TAG_EXTRACTION = 'ThinkingTagExtractionMiddleware', - TOOL_USE_TAG_EXTRACTION = 'ToolUseTagExtractionMiddleware', - MCP_TOOL_HANDLER = 'McpToolHandlerMiddleware', - // Finalization - FINAL_CHUNK_CONSUMER = 'FinalChunkConsumerAndNotifierMiddleware' - // Add more as needed -} -``` - -中间件实例需要某种方式暴露其 `MiddlewareName`,例如通过一个 `name` 属性。 - -### 1.4. 中间件执行结构 - -我们采用一种灵活的中间件执行结构。一个中间件通常是一个函数,它接收 `Context`、`Params`,以及一个 `next` 函数(用于调用链中的下一个中间件)。 - -```typescript -// 简化形式的中间件函数签名 -type MiddlewareFunction = ( - context: AiProviderMiddlewareContext, - params: any, // e.g., CompletionsParams - next: () => Promise // next 通常返回 Promise 以支持异步操作 -) => Promise // 中间件自身也可能返回 Promise - -// 或者更经典的 Koa/Express 风格 (三段式) -// type MiddlewareFactory = (api?: MiddlewareApi) => -// (nextMiddleware: (ctx: AiProviderMiddlewareContext, params: any) => Promise) => -// (context: AiProviderMiddlewareContext, params: any) => Promise; -// 当前设计更倾向于上述简化的 MiddlewareFunction,由 MiddlewareExecutor 负责 next 的编排。 -``` - -`MiddlewareExecutor` (或 `applyMiddlewares`) 会负责管理 `next` 的调用。 - -## 2. `MiddlewareBuilder` (通用中间件构建器) - -为了动态构建和管理中间件链,我们引入一个通用的 `MiddlewareBuilder` 类。 - -### 2.1. 设计理念 - -`MiddlewareBuilder` 提供了一个流式 API,用于以声明式的方式构建中间件链。它允许从一个基础链开始,然后根据特定条件添加、插入、替换或移除中间件。 - -### 2.2. API 概览 - -```typescript -class MiddlewareBuilder { - constructor(baseChain?: Middleware[]) - - add(middleware: Middleware): this - prepend(middleware: Middleware): this - insertAfter(targetName: MiddlewareName, middlewareToInsert: Middleware): this - insertBefore(targetName: MiddlewareName, middlewareToInsert: Middleware): this - replace(targetName: MiddlewareName, newMiddleware: Middleware): this - remove(targetName: MiddlewareName): this - - build(): Middleware[] // 返回构建好的中间件数组 - - // 可选:直接执行链 - execute( - context: AiProviderMiddlewareContext, - params: any, - middlewareExecutor: (chain: Middleware[], context: AiProviderMiddlewareContext, params: any) => void - ): void -} -``` - -### 2.3. 使用示例 - -```typescript -// 1. 定义一些中间件实例 (假设它们有 .name 属性) -const loggingStart = { name: MiddlewareName.LOGGING_START, fn: loggingStartFn } -const requestExec = { name: MiddlewareName.REQUEST_EXECUTION, fn: requestExecFn } -const streamAdapter = { name: MiddlewareName.STREAM_ADAPTER, fn: streamAdapterFn } -const customFeature = { name: MiddlewareName.CUSTOM_FEATURE, fn: customFeatureFn } // 假设自定义 - -// 2. 
定义一个基础链 (可选)
-const BASE_CHAIN: Middleware[] = [loggingStart, requestExec, streamAdapter]
-
-// 3. 使用 MiddlewareBuilder
-const builder = new MiddlewareBuilder(BASE_CHAIN)
-
-if (params.needsCustomFeature) {
-  builder.insertAfter(MiddlewareName.STREAM_ADAPTER, customFeature)
-}
-
-if (params.isHighSecurityContext) {
-  builder.insertBefore(MiddlewareName.REQUEST_EXECUTION, highSecurityCheckMiddleware)
-}
-
-if (params.overrideLogging) {
-  builder.replace(MiddlewareName.LOGGING_START, newSpecialLoggingMiddleware)
-}
-
-// 4. 获取最终链
-const finalChain = builder.build()
-
-// 5. 执行 (通过外部执行器)
-// middlewareExecutor(finalChain, context, params);
-// 或者 builder.execute(context, params, middlewareExecutor);
-```
-
-## 3. `MiddlewareExecutor` / `applyMiddlewares` (中间件执行器)
-
-这是负责接收 `MiddlewareBuilder` 构建的中间件链并实际执行它们的组件。
-
-### 3.1. 职责
-
-- 接收 `Middleware[]`, `AiProviderMiddlewareContext`, `Params`。
-- 按顺序迭代中间件。
-- 为每个中间件提供正确的 `next` 函数,该函数在被调用时会执行链中的下一个中间件。
-- 处理中间件执行过程中的 Promise(如果中间件是异步的)。
-- 基础的错误捕获(具体错误处理应由链内的 `ErrorHandlingMiddleware` 负责)。
-
-## 4. 在 `AiCoreService` 中使用
-
-`AiCoreService` 中的每个核心业务方法 (如 `executeCompletions`) 将负责:
-
-1. 准备基础数据:实例化 `ApiClient`,转换 `Params` 为 `CoreRequest`。
-2. 实例化 `MiddlewareBuilder`,可能会传入一个特定于该业务方法的基础中间件链。
-3. 根据 `Params` 和 `CoreRequest` 中的条件,调用 `MiddlewareBuilder` 的方法来动态调整中间件链。
-4. 调用 `MiddlewareBuilder.build()` 获取最终的中间件链。
-5. 创建完整的 `AiProviderMiddlewareContext` (包含 `resolvePromise`, `rejectPromise` 等)。
-6. 调用 `MiddlewareExecutor` (或 `applyMiddlewares`) 来执行构建好的链。
-
-## 5. 组合功能
-
-对于组合功能(例如 "Completions then Translate"):
-
-- 不推荐创建一个单一、庞大的 `MiddlewareBuilder` 来处理整个组合流程。
-- 推荐在 `AiCoreService` 中创建一个新的方法,该方法按顺序 `await` 调用底层的原子 `AiCoreService` 方法(例如,先 `await this.executeCompletions(...)`,然后用其结果 `await this.translateText(...)`),如下面的示意代码所示。
-- 每个被调用的原子方法内部会使用其自身的 `MiddlewareBuilder` 实例来构建和执行其特定阶段的中间件链。
-- 这种方式最大化了复用,并保持了各部分职责的清晰。
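-
-下面是一个顺序组合的最小示意(仅为草图:`executeCompletions`、`translateText` 及其参数均为示意写法,具体签名以实际实现为准):
-
-```typescript
-// 组合功能示意:先补全,再用补全结果做翻译
-// 每个原子方法内部各自构建并执行自己的中间件链
-async function completionsThenTranslate(service: AiCoreService, params: CompletionsParams): Promise<string> {
-  // 第一阶段:执行补全
-  const completionsResult = await service.executeCompletions(params)
-
-  // 第二阶段:将补全得到的文本交给翻译
-  return service.translateText({
-    text: completionsResult.getText(),
-    targetLanguage: 'en'
-  })
-}
-```
-
-这样,翻译阶段可以完整复用 `translateText` 自身的中间件链,而无需把两个阶段塞进同一条链。
-
-## 6. 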
中间件命名和发现 - -为中间件赋予唯一的 `MiddlewareName` 对于 `MiddlewareBuilder` 的 `insertAfter`, `insertBefore`, `replace`, `remove` 等操作至关重要。确保中间件实例能够以某种方式暴露其名称(例如,一个 `name` 属性)。 diff --git a/src/renderer/src/aiCore/legacy/middleware/__tests__/utils.test.ts b/src/renderer/src/aiCore/legacy/middleware/__tests__/utils.test.ts deleted file mode 100644 index 94602e5125..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/__tests__/utils.test.ts +++ /dev/null @@ -1,79 +0,0 @@ -import { ChunkType } from '@renderer/types/chunk' -import { describe, expect, it } from 'vitest' - -import { capitalize, createErrorChunk, isAsyncIterable } from '../utils' - -describe('utils', () => { - describe('createErrorChunk', () => { - it('should handle Error instances', () => { - const error = new Error('Test error message') - const result = createErrorChunk(error) - - expect(result.type).toBe(ChunkType.ERROR) - expect(result.error.message).toBe('Test error message') - expect(result.error.name).toBe('Error') - expect(result.error.stack).toBeDefined() - }) - - it('should handle string errors', () => { - const result = createErrorChunk('Something went wrong') - expect(result.error).toEqual({ message: 'Something went wrong' }) - }) - - it('should handle plain objects', () => { - const error = { code: 'NETWORK_ERROR', status: 500 } - const result = createErrorChunk(error) - expect(result.error).toEqual(error) - }) - - it('should handle null and undefined', () => { - expect(createErrorChunk(null).error).toEqual({}) - expect(createErrorChunk(undefined).error).toEqual({}) - }) - - it('should use custom chunk type when provided', () => { - const result = createErrorChunk('error', ChunkType.BLOCK_COMPLETE) - expect(result.type).toBe(ChunkType.BLOCK_COMPLETE) - }) - - it('should use toString for objects without message', () => { - const error = { - toString: () => 'Custom error' - } - const result = createErrorChunk(error) - expect(result.error.message).toBe('Custom error') - }) - }) - - describe('capitalize', () => { - it('should capitalize first letter', () => { - expect(capitalize('hello')).toBe('Hello') - expect(capitalize('a')).toBe('A') - }) - - it('should handle edge cases', () => { - expect(capitalize('')).toBe('') - expect(capitalize('123')).toBe('123') - expect(capitalize('Hello')).toBe('Hello') - }) - }) - - describe('isAsyncIterable', () => { - it('should identify async iterables', () => { - async function* gen() { - yield 1 - } - expect(isAsyncIterable(gen())).toBe(true) - expect(isAsyncIterable({ [Symbol.asyncIterator]: () => {} })).toBe(true) - }) - - it('should reject non-async iterables', () => { - expect(isAsyncIterable([1, 2, 3])).toBe(false) - expect(isAsyncIterable(new Set())).toBe(false) - expect(isAsyncIterable({})).toBe(false) - expect(isAsyncIterable(null)).toBe(false) - expect(isAsyncIterable(123)).toBe(false) - expect(isAsyncIterable('string')).toBe(false) - }) - }) -}) diff --git a/src/renderer/src/aiCore/legacy/middleware/builder.ts b/src/renderer/src/aiCore/legacy/middleware/builder.ts deleted file mode 100644 index 1d0b9d136d..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/builder.ts +++ /dev/null @@ -1,245 +0,0 @@ -import { loggerService } from '@logger' - -import { DefaultCompletionsNamedMiddlewares } from './register' -import type { BaseContext, CompletionsMiddleware, MethodMiddleware } from './types' - -const logger = loggerService.withContext('aiCore:MiddlewareBuilder') - -/** - * 带有名称标识的中间件接口 - */ -export interface NamedMiddleware { - name: string - middleware: TMiddleware -} - -/** - * 
中间件执行器函数类型
- */
-export type MiddlewareExecutor<TContext extends BaseContext = BaseContext, TResult = any> = (
-  chain: any[],
-  context: TContext,
-  params: any
-) => Promise<TResult>
-
-/**
- * 通用中间件构建器类
- * 提供流式 API 用于动态构建和管理中间件链
- *
- * 注意:所有中间件都通过 MiddlewareRegistry 管理,使用 NamedMiddleware 格式
- */
-export class MiddlewareBuilder<TMiddleware = CompletionsMiddleware | MethodMiddleware> {
-  private middlewares: NamedMiddleware<TMiddleware>[]
-
-  /**
-   * 构造函数
-   * @param baseChain - 可选的基础中间件链(NamedMiddleware 格式)
-   */
-  constructor(baseChain?: NamedMiddleware<TMiddleware>[]) {
-    this.middlewares = baseChain ? [...baseChain] : []
-  }
-
-  /**
-   * 在链的末尾添加中间件
-   * @param middleware - 要添加的具名中间件
-   * @returns this,支持链式调用
-   */
-  add(middleware: NamedMiddleware<TMiddleware>): this {
-    this.middlewares.push(middleware)
-    return this
-  }
-
-  /**
-   * 在链的开头添加中间件
-   * @param middleware - 要添加的具名中间件
-   * @returns this,支持链式调用
-   */
-  prepend(middleware: NamedMiddleware<TMiddleware>): this {
-    this.middlewares.unshift(middleware)
-    return this
-  }
-
-  /**
-   * 在指定中间件之后插入新中间件
-   * @param targetName - 目标中间件名称
-   * @param middlewareToInsert - 要插入的具名中间件
-   * @returns this,支持链式调用
-   */
-  insertAfter(targetName: string, middlewareToInsert: NamedMiddleware<TMiddleware>): this {
-    const index = this.findMiddlewareIndex(targetName)
-    if (index !== -1) {
-      this.middlewares.splice(index + 1, 0, middlewareToInsert)
-    } else {
-      logger.warn(`未找到名为 '${targetName}' 的中间件,无法插入`)
-    }
-    return this
-  }
-
-  /**
-   * 在指定中间件之前插入新中间件
-   * @param targetName - 目标中间件名称
-   * @param middlewareToInsert - 要插入的具名中间件
-   * @returns this,支持链式调用
-   */
-  insertBefore(targetName: string, middlewareToInsert: NamedMiddleware<TMiddleware>): this {
-    const index = this.findMiddlewareIndex(targetName)
-    if (index !== -1) {
-      this.middlewares.splice(index, 0, middlewareToInsert)
-    } else {
-      logger.warn(`未找到名为 '${targetName}' 的中间件,无法插入`)
-    }
-    return this
-  }
-
-  /**
-   * 替换指定的中间件
-   * @param targetName - 要替换的中间件名称
-   * @param newMiddleware - 新的具名中间件
-   * @returns this,支持链式调用
-   */
-  replace(targetName: string, newMiddleware: NamedMiddleware<TMiddleware>): this {
-    const index = this.findMiddlewareIndex(targetName)
-    if (index !== -1) {
-      this.middlewares[index] = newMiddleware
-    } else {
-      logger.warn(`未找到名为 '${targetName}' 的中间件,无法替换`)
-    }
-    return this
-  }
-
-  /**
-   * 移除指定的中间件
-   * @param targetName - 要移除的中间件名称
-   * @returns this,支持链式调用
-   */
-  remove(targetName: string): this {
-    const index = this.findMiddlewareIndex(targetName)
-    if (index !== -1) {
-      this.middlewares.splice(index, 1)
-    }
-    return this
-  }
-
-  /**
-   * 构建最终的中间件数组
-   * @returns 构建好的中间件数组
-   */
-  build(): TMiddleware[] {
-    return this.middlewares.map((item) => item.middleware)
-  }
-
-  /**
-   * 获取当前中间件链的副本(包含名称信息)
-   * @returns 当前中间件链的副本
-   */
-  getChain(): NamedMiddleware<TMiddleware>[] {
-    return [...this.middlewares]
-  }
-
-  /**
-   * 检查是否包含指定名称的中间件
-   * @param name - 中间件名称
-   * @returns 是否包含该中间件
-   */
-  has(name: string): boolean {
-    return this.findMiddlewareIndex(name) !== -1
-  }
-
-  /**
-   * 获取中间件链的长度
-   * @returns 中间件数量
-   */
-  get length(): number {
-    return this.middlewares.length
-  }
-
-  /**
-   * 清空中间件链
-   * @returns this,支持链式调用
-   */
-  clear(): this {
-    this.middlewares = []
-    return this
-  }
-
-  /**
-   * 直接执行构建好的中间件链
-   * @param context - 中间件上下文
-   * @param params - 参数
-   * @param middlewareExecutor - 中间件执行器
-   * @returns 执行结果
-   */
-  execute<TContext extends BaseContext = BaseContext, TResult = any>(
-    context: TContext,
-    params: any,
-    middlewareExecutor: MiddlewareExecutor<TContext, TResult>
-  ): Promise<TResult> {
-    const chain = this.build()
-    return middlewareExecutor(chain, context, params)
-  }
-
-  /**
-   * 查找中间件在链中的索引
-   * @param name - 中间件名称
-   * @returns 索引,如果未找到返回 -1
-   */
-  private findMiddlewareIndex(name: string): number {
-    return this.middlewares.findIndex((item) => item.name === name)
-  }
-}
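-
-// 使用示意(草图,'MyMiddleware' 与 myMiddleware 仅为演示用名称):
-// const builder = new MiddlewareBuilder<CompletionsMiddleware>(DefaultCompletionsNamedMiddlewares)
-// builder.insertBefore('SdkCallMiddleware', { name: 'MyMiddleware', middleware: myMiddleware })
-// const chain: CompletionsMiddleware[] = builder.build()
-
-/**
- * 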
Completions 中间件构建器 - */ -export class CompletionsMiddlewareBuilder extends MiddlewareBuilder { - constructor(baseChain?: NamedMiddleware[]) { - super(baseChain) - } - - /** - * 使用默认的 Completions 中间件链 - * @returns CompletionsMiddlewareBuilder 实例 - */ - static withDefaults(): CompletionsMiddlewareBuilder { - return new CompletionsMiddlewareBuilder(DefaultCompletionsNamedMiddlewares) - } -} - -/** - * 通用方法中间件构建器 - */ -export class MethodMiddlewareBuilder extends MiddlewareBuilder { - constructor(baseChain?: NamedMiddleware[]) { - super(baseChain) - } -} - -// 便捷的工厂函数 - -/** - * 创建 Completions 中间件构建器 - * @param baseChain - 可选的基础链 - * @returns Completions 中间件构建器实例 - */ -export function createCompletionsBuilder( - baseChain?: NamedMiddleware[] -): CompletionsMiddlewareBuilder { - return new CompletionsMiddlewareBuilder(baseChain) -} - -/** - * 创建通用方法中间件构建器 - * @param baseChain - 可选的基础链 - * @returns 通用方法中间件构建器实例 - */ -export function createMethodBuilder(baseChain?: NamedMiddleware[]): MethodMiddlewareBuilder { - return new MethodMiddlewareBuilder(baseChain) -} - -/** - * 为中间件添加名称属性的辅助函数 - * 可以用于给现有的中间件添加名称属性 - */ -export function addMiddlewareName(middleware: T, name: string): T & { MIDDLEWARE_NAME: string } { - return Object.assign(middleware, { MIDDLEWARE_NAME: name }) -} diff --git a/src/renderer/src/aiCore/legacy/middleware/common/AbortHandlerMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/common/AbortHandlerMiddleware.ts deleted file mode 100644 index 5f24797813..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/common/AbortHandlerMiddleware.ts +++ /dev/null @@ -1,121 +0,0 @@ -import { loggerService } from '@logger' -import type { Chunk, ErrorChunk } from '@renderer/types/chunk' -import { ChunkType } from '@renderer/types/chunk' -import { addAbortController, removeAbortController } from '@renderer/utils/abortController' - -import type { CompletionsParams, CompletionsResult } from '../schemas' -import type { CompletionsContext, CompletionsMiddleware } from '../types' - -const logger = loggerService.withContext('aiCore:AbortHandlerMiddleware') - -export const MIDDLEWARE_NAME = 'AbortHandlerMiddleware' - -export const AbortHandlerMiddleware: CompletionsMiddleware = - () => - (next) => - async (ctx: CompletionsContext, params: CompletionsParams): Promise => { - const isRecursiveCall = ctx._internal?.toolProcessingState?.isRecursiveCall || false - - // 在递归调用中,跳过 AbortController 的创建,直接使用已有的 - if (isRecursiveCall) { - const result = await next(ctx, params) - return result - } - - const abortController = new AbortController() - const abortFn = (): void => abortController.abort() - let abortSignal: AbortSignal | null = abortController.signal - let abortKey: string - - // 如果参数中传入了abortKey则优先使用 - if (params.abortKey) { - abortKey = params.abortKey - } else { - // 获取当前消息的ID用于abort管理 - // 优先使用处理过的消息,如果没有则使用原始消息 - let messageId: string | undefined - - if (typeof params.messages === 'string') { - messageId = `message-${Date.now()}-${Math.random().toString(36).substring(2, 9)}` - } else { - const processedMessages = params.messages - const lastUserMessage = processedMessages.findLast((m) => m.role === 'user') - messageId = lastUserMessage?.id - } - - if (!messageId) { - logger.warn(`No messageId found, abort functionality will not be available.`) - return next(ctx, params) - } - - abortKey = messageId - } - - addAbortController(abortKey, abortFn) - const cleanup = (): void => { - removeAbortController(abortKey, abortFn) - if (ctx._internal?.flowControl) { - 
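// 重置 flowControl 状态,避免后续请求复用已失效的 controller 或 signal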
ctx._internal.flowControl.abortController = undefined - ctx._internal.flowControl.abortSignal = undefined - ctx._internal.flowControl.cleanup = undefined - } - abortSignal = null - } - - // 将controller添加到_internal中的flowControl状态 - if (!ctx._internal.flowControl) { - ctx._internal.flowControl = {} - } - ctx._internal.flowControl.abortController = abortController - ctx._internal.flowControl.abortSignal = abortSignal - ctx._internal.flowControl.cleanup = cleanup - - const result = await next(ctx, params) - - const error = new DOMException('Request was aborted', 'AbortError') - - const streamWithAbortHandler = (result.stream as ReadableStream).pipeThrough( - new TransformStream({ - transform(chunk, controller) { - // 如果已经收到错误块,不再检查 abort 状态 - if (chunk.type === ChunkType.ERROR) { - controller.enqueue(chunk) - return - } - - if (abortSignal?.aborted) { - // 转换为 ErrorChunk - const errorChunk: ErrorChunk = { - type: ChunkType.ERROR, - error - } - - controller.enqueue(errorChunk) - cleanup() - return - } - - // 正常传递 chunk - controller.enqueue(chunk) - }, - - flush(controller) { - // 在流结束时再次检查 abort 状态 - if (abortSignal?.aborted) { - const errorChunk: ErrorChunk = { - type: ChunkType.ERROR, - error - } - controller.enqueue(errorChunk) - } - // 在流完全处理完成后清理 AbortController - cleanup() - } - }) - ) - - return { - ...result, - stream: streamWithAbortHandler - } - } diff --git a/src/renderer/src/aiCore/legacy/middleware/common/ErrorHandlerMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/common/ErrorHandlerMiddleware.ts deleted file mode 100644 index c93e42fbb2..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/common/ErrorHandlerMiddleware.ts +++ /dev/null @@ -1,134 +0,0 @@ -import { loggerService } from '@logger' -import { isZhipuModel } from '@renderer/config/models' -import { getStoreProviders } from '@renderer/hooks/useStore' -import { getDefaultModel } from '@renderer/services/AssistantService' -import type { Chunk } from '@renderer/types/chunk' - -import type { CompletionsParams, CompletionsResult } from '../schemas' -import type { CompletionsContext } from '../types' -import { createErrorChunk } from '../utils' - -const logger = loggerService.withContext('ErrorHandlerMiddleware') - -export const MIDDLEWARE_NAME = 'ErrorHandlerMiddleware' - -/** - * 创建一个错误处理中间件。 - * - * 这是一个高阶函数,它接收配置并返回一个标准的中间件。 - * 它的主要职责是捕获下游中间件或API调用中发生的任何错误。 - * - * @param config - 中间件的配置。 - * @returns 一个配置好的CompletionsMiddleware。 - */ -export const ErrorHandlerMiddleware = - () => - (next) => - async (ctx: CompletionsContext, params): Promise => { - const { shouldThrow } = params - - try { - // 尝试执行下一个中间件 - return await next(ctx, params) - } catch (error: any) { - logger.error(error) - - let processedError = error - processedError = handleError(error, params) - - // 1. 使用通用的工具函数将错误解析为标准格式 - const errorChunk = createErrorChunk(processedError) - - // 2. 调用从外部传入的 onError 回调 - if (params.onError) { - params.onError(processedError) - } - - // 3. 
根据配置决定是重新抛出错误,还是将其作为流的一部分向下传递 - if (shouldThrow) { - throw processedError - } - - // 如果不抛出,则创建一个只包含该错误块的流并向下传递 - const errorStream = new ReadableStream({ - start(controller) { - controller.enqueue(errorChunk) - controller.close() - } - }) - - return { - rawOutput: undefined, - stream: errorStream, // 将包含错误的流传递下去 - controller: undefined, - getText: () => '' // 错误情况下没有文本结果 - } - } - } - -function handleError(error: any, params: CompletionsParams): any { - if (isZhipuModel(params.assistant.model || getDefaultModel()) && error.status && !params.enableGenerateImage) { - return handleZhipuError(error) - } - - if (error.status === 401 || error.message.includes('401')) { - return { - ...error, - i18nKey: 'chat.no_api_key', - providerId: params.assistant?.model?.provider - } - } - - return error -} - -/** - * 处理智谱特定错误 - * 1. 只有对话功能(enableGenerateImage为false)才使用自定义错误处理 - * 2. 绘画功能(enableGenerateImage为true)使用通用错误处理 - */ -function handleZhipuError(error: any): any { - const provider = getStoreProviders().find((p) => p.id === 'zhipu') - const logger = loggerService.withContext('handleZhipuError') - - // 定义错误模式映射 - const errorPatterns = [ - { - condition: () => error.status === 401 || /令牌已过期|AuthenticationError|Unauthorized/i.test(error.message), - i18nKey: 'chat.no_api_key', - providerId: provider?.id - }, - { - condition: () => error.error?.code === '1304' || /限额|免费配额|free quota|rate limit/i.test(error.message), - i18nKey: 'chat.quota_exceeded', - providerId: provider?.id - }, - { - condition: () => - (error.status === 429 && error.error?.code === '1113') || /余额不足|insufficient balance/i.test(error.message), - i18nKey: 'chat.insufficient_balance', - providerId: provider?.id - }, - { - condition: () => !provider?.apiKey?.trim(), - i18nKey: 'chat.no_api_key', - providerId: provider?.id - } - ] - - // 遍历错误模式,返回第一个匹配的错误 - for (const pattern of errorPatterns) { - if (pattern.condition()) { - return { - ...error, - providerId: pattern.providerId, - i18nKey: pattern.i18nKey - } - } - } - - // 如果不是智谱特定错误,返回原始错误 - logger.debug('🔧 不是智谱特定错误,返回原始错误') - - return error -} diff --git a/src/renderer/src/aiCore/legacy/middleware/common/FinalChunkConsumerMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/common/FinalChunkConsumerMiddleware.ts deleted file mode 100644 index 0325e4e21a..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/common/FinalChunkConsumerMiddleware.ts +++ /dev/null @@ -1,195 +0,0 @@ -import { loggerService } from '@logger' -import type { Usage } from '@renderer/types' -import type { Chunk } from '@renderer/types/chunk' -import { ChunkType } from '@renderer/types/chunk' - -import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' -import type { CompletionsContext, CompletionsMiddleware } from '../types' - -export const MIDDLEWARE_NAME = 'FinalChunkConsumerAndNotifierMiddleware' - -const logger = loggerService.withContext('FinalChunkConsumerMiddleware') - -/** - * 最终Chunk消费和通知中间件 - * - * 职责: - * 1. 消费所有GenericChunk流中的chunks并转发给onChunk回调 - * 2. 累加usage/metrics数据(从原始SDK chunks或GenericChunk中提取) - * 3. 在检测到LLM_RESPONSE_COMPLETE时发送包含累计数据的BLOCK_COMPLETE - * 4. 
处理MCP工具调用的多轮请求中的数据累加 - */ -const FinalChunkConsumerMiddleware: CompletionsMiddleware = - () => - (next) => - async (ctx: CompletionsContext, params: CompletionsParams): Promise => { - const isRecursiveCall = - params._internal?.toolProcessingState?.isRecursiveCall || - ctx._internal?.toolProcessingState?.isRecursiveCall || - false - - // 初始化累计数据(只在顶层调用时初始化) - if (!isRecursiveCall) { - if (!ctx._internal.customState) { - ctx._internal.customState = {} - } - ctx._internal.observer = { - usage: { - prompt_tokens: 0, - completion_tokens: 0, - total_tokens: 0 - }, - metrics: { - completion_tokens: 0, - time_completion_millsec: 0, - time_first_token_millsec: 0, - time_thinking_millsec: 0 - } - } - // 初始化文本累积器 - ctx._internal.customState.accumulatedText = '' - ctx._internal.customState.startTimestamp = Date.now() - } - - // 调用下游中间件 - const result = await next(ctx, params) - - // 响应后处理:处理GenericChunk流式响应 - if (result.stream) { - const resultFromUpstream = result.stream - - if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) { - const reader = resultFromUpstream.getReader() - - try { - while (true) { - const { done, value: chunk } = await reader.read() - logger.silly('chunk', chunk) - if (done) { - logger.debug(`Input stream finished.`) - break - } - - if (chunk) { - const genericChunk = chunk as GenericChunk - // 提取并累加usage/metrics数据 - extractAndAccumulateUsageMetrics(ctx, genericChunk) - - const shouldSkipChunk = - isRecursiveCall && - (genericChunk.type === ChunkType.BLOCK_COMPLETE || - genericChunk.type === ChunkType.LLM_RESPONSE_COMPLETE) - - if (!shouldSkipChunk) params.onChunk?.(genericChunk) - } else { - logger.warn(`Received undefined chunk before stream was done.`) - } - } - } catch (error: any) { - logger.error(`Error consuming stream:`, error as Error) - // FIXME: 临时解决方案。该中间件的异常无法被 ErrorHandlerMiddleware捕获。 - if (params.onError) { - params.onError(error) - } - if (params.shouldThrow) { - throw error - } - } finally { - if (params.onChunk && !isRecursiveCall) { - params.onChunk({ - type: ChunkType.BLOCK_COMPLETE, - response: { - usage: ctx._internal.observer?.usage ? { ...ctx._internal.observer.usage } : undefined, - metrics: ctx._internal.observer?.metrics ? 
{ ...ctx._internal.observer.metrics } : undefined - } - } as Chunk) - if (ctx._internal.toolProcessingState) { - ctx._internal.toolProcessingState = {} - } - } - } - - // 为流式输出添加getText方法 - const modifiedResult = { - ...result, - stream: new ReadableStream({ - start(controller) { - controller.close() - } - }), - getText: () => { - return ctx._internal.customState?.accumulatedText || '' - } - } - - return modifiedResult - } else { - logger.debug(`No GenericChunk stream to process.`) - } - } - - return result - } - -/** - * 从GenericChunk或原始SDK chunks中提取usage/metrics数据并累加 - */ -function extractAndAccumulateUsageMetrics(ctx: CompletionsContext, chunk: GenericChunk): void { - if (!ctx._internal.observer?.usage || !ctx._internal.observer?.metrics) { - return - } - - try { - if (ctx._internal.customState && !ctx._internal.customState?.firstTokenTimestamp) { - ctx._internal.customState.firstTokenTimestamp = Date.now() - logger.debug(`First token timestamp: ${ctx._internal.customState.firstTokenTimestamp}`) - } - if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) { - // 从LLM_RESPONSE_COMPLETE chunk中提取usage数据 - if (chunk.response?.usage) { - accumulateUsage(ctx._internal.observer.usage, chunk.response.usage) - } - - if (ctx._internal.customState && ctx._internal.customState?.firstTokenTimestamp) { - ctx._internal.observer.metrics.time_first_token_millsec = - ctx._internal.customState.firstTokenTimestamp - ctx._internal.customState.startTimestamp - ctx._internal.observer.metrics.time_completion_millsec += - Date.now() - ctx._internal.customState.firstTokenTimestamp - } - } - - // 也可以从其他chunk类型中提取metrics数据 - if (chunk.type === ChunkType.THINKING_COMPLETE && chunk.thinking_millsec && ctx._internal.observer?.metrics) { - ctx._internal.observer.metrics.time_thinking_millsec = Math.max( - ctx._internal.observer.metrics.time_thinking_millsec || 0, - chunk.thinking_millsec - ) - } - } catch (error) { - logger.error('Error extracting usage/metrics from chunk:', error as Error) - } -} - -/** - * 累加usage数据 - */ -function accumulateUsage(accumulated: Usage, newUsage: Usage): void { - if (newUsage.prompt_tokens !== undefined) { - accumulated.prompt_tokens += newUsage.prompt_tokens - } - if (newUsage.completion_tokens !== undefined) { - accumulated.completion_tokens += newUsage.completion_tokens - } - if (newUsage.total_tokens !== undefined) { - accumulated.total_tokens += newUsage.total_tokens - } - if (newUsage.thoughts_tokens !== undefined) { - accumulated.thoughts_tokens = (accumulated.thoughts_tokens || 0) + newUsage.thoughts_tokens - } - // Handle OpenRouter specific cost fields - if (newUsage.cost !== undefined) { - accumulated.cost = (accumulated.cost || 0) + newUsage.cost - } -} - -export default FinalChunkConsumerMiddleware diff --git a/src/renderer/src/aiCore/legacy/middleware/common/LoggingMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/common/LoggingMiddleware.ts deleted file mode 100644 index 480cbbc39f..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/common/LoggingMiddleware.ts +++ /dev/null @@ -1,68 +0,0 @@ -import { loggerService } from '@logger' - -import type { BaseContext, MethodMiddleware, MiddlewareAPI } from '../types' - -const logger = loggerService.withContext('LoggingMiddleware') - -export const MIDDLEWARE_NAME = 'GenericLoggingMiddlewares' - -/** - * Helper function to safely stringify arguments for logging, handling circular references and large objects. - * 安全地字符串化日志参数的辅助函数,处理循环引用和大型对象。 - * @param args - The arguments array to stringify. 
要字符串化的参数数组。 - * @returns A string representation of the arguments. 参数的字符串表示形式。 - */ -const stringifyArgsForLogging = (args: any[]): string => { - try { - return args - .map((arg) => { - if (typeof arg === 'function') return '[Function]' - if (typeof arg === 'object' && arg !== null && arg.constructor === Object && Object.keys(arg).length > 20) { - return '[Object with >20 keys]' - } - // Truncate long strings to avoid flooding logs 截断长字符串以避免日志泛滥 - const stringifiedArg = JSON.stringify(arg, null, 2) - return stringifiedArg && stringifiedArg.length > 200 ? stringifiedArg.substring(0, 200) + '...' : stringifiedArg - }) - .join(', ') - } catch (e) { - return '[Error serializing arguments]' // Handle potential errors during stringification 处理字符串化期间的潜在错误 - } -} - -/** - * Generic logging middleware for provider methods. - * 为提供者方法创建一个通用的日志中间件。 - * This middleware logs the initiation, success/failure, and duration of a method call. - * 此中间件记录方法调用的启动、成功/失败以及持续时间。 - */ - -/** - * Creates a generic logging middleware for provider methods. - * 为提供者方法创建一个通用的日志中间件。 - * @returns A `MethodMiddleware` instance. 一个 `MethodMiddleware` 实例。 - */ -export const createGenericLoggingMiddleware: () => MethodMiddleware = () => { - const middlewareName = 'GenericLoggingMiddleware' - // oxlint-disable-next-line @typescript-eslint/no-unused-vars - return (_: MiddlewareAPI) => (next) => async (ctx, args) => { - const methodName = ctx.methodName - const logPrefix = `[${middlewareName} (${methodName})]` - logger.debug(`${logPrefix} Initiating. Args: ${stringifyArgsForLogging(args)}`) - const startTime = Date.now() - try { - const result = await next(ctx, args) - const duration = Date.now() - startTime - // Log successful completion of the method call with duration. / - // 记录方法调用成功完成及其持续时间。 - logger.debug(`${logPrefix} Successful. Duration: ${duration}ms`) - return result - } catch (error) { - const duration = Date.now() - startTime - // Log failure of the method call with duration and error information. / - // 记录方法调用失败及其持续时间和错误信息。 - logger.error(`${logPrefix} Failed. Duration: ${duration}ms`, error as Error) - throw error // Re-throw the error to be handled by subsequent layers or the caller / 重新抛出错误,由后续层或调用者处理 - } - } -} diff --git a/src/renderer/src/aiCore/legacy/middleware/composer.ts b/src/renderer/src/aiCore/legacy/middleware/composer.ts deleted file mode 100644 index 97bbf0a38d..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/composer.ts +++ /dev/null @@ -1,289 +0,0 @@ -import { withSpanResult } from '@renderer/services/SpanManagerService' -import type { - RequestOptions, - SdkInstance, - SdkMessageParam, - SdkParams, - SdkRawChunk, - SdkRawOutput, - SdkTool, - SdkToolCall -} from '@renderer/types/sdk' - -import type { BaseApiClient } from '../clients' -import type { CompletionsParams, CompletionsResult } from './schemas' -import type { BaseContext, CompletionsContext, CompletionsMiddleware, MethodMiddleware, MiddlewareAPI } from './types' -import { MIDDLEWARE_CONTEXT_SYMBOL } from './types' - -/** - * Creates the initial context for a method call, populating method-specific fields. / - * 为方法调用创建初始上下文,并填充特定于该方法的字段。 - * @param methodName - The name of the method being called. / 被调用的方法名。 - * @param originalCallArgs - The actual arguments array from the proxy/method call. / 代理/方法调用的实际参数数组。 - * @param providerId - The ID of the provider, if available. / 提供者的ID(如果可用)。 - * @param providerInstance - The instance of the provider. 
/ 提供者实例。 - * @param specificContextFactory - An optional factory function to create a specific context type from the base context and original call arguments. / 一个可选的工厂函数,用于从基础上下文和原始调用参数创建特定的上下文类型。 - * @returns The created context object. / 创建的上下文对象。 - */ -function createInitialCallContext( - methodName: string, - originalCallArgs: TCallArgs, // Renamed from originalArgs to avoid confusion with context.originalArgs - // Factory to create specific context from base and the *original call arguments array* - specificContextFactory?: (base: BaseContext, callArgs: TCallArgs) => TContext -): TContext { - const baseContext: BaseContext = { - [MIDDLEWARE_CONTEXT_SYMBOL]: true, - methodName, - originalArgs: originalCallArgs // Store the full original arguments array in the context - } - - if (specificContextFactory) { - return specificContextFactory(baseContext, originalCallArgs) - } - return baseContext as TContext // Fallback to base context if no specific factory -} - -/** - * Composes an array of functions from right to left. / - * 从右到左组合一个函数数组。 - * `compose(f, g, h)` is `(...args) => f(g(h(...args)))`. / - * `compose(f, g, h)` 等同于 `(...args) => f(g(h(...args)))`。 - * Each function in funcs is expected to take the result of the next function - * (or the initial value for the rightmost function) as its argument. / - * `funcs` 中的每个函数都期望接收下一个函数的结果(或最右侧函数的初始值)作为其参数。 - * @param funcs - Array of functions to compose. / 要组合的函数数组。 - * @returns The composed function. / 组合后的函数。 - */ -function compose(...funcs: Array<(...args: any[]) => any>): (...args: any[]) => any { - if (funcs.length === 0) { - // If no functions to compose, return a function that returns its first argument, or undefined if no args. / - // 如果没有要组合的函数,则返回一个函数,该函数返回其第一个参数,如果没有参数则返回undefined。 - return (...args: any[]) => (args.length > 0 ? args[0] : undefined) - } - if (funcs.length === 1) { - return funcs[0] - } - return funcs.reduce( - (a, b) => - (...args: any[]) => - a(b(...args)) - ) -} - -/** - * Applies an array of Redux-style middlewares to a generic provider method. / - * 将一组Redux风格的中间件应用于一个通用的提供者方法。 - * This version keeps arguments as an array throughout the middleware chain. / - * 此版本在整个中间件链中将参数保持为数组形式。 - * @param originalProviderInstance - The original provider instance. / 原始提供者实例。 - * @param methodName - The name of the method to be enhanced. / 需要增强的方法名。 - * @param originalMethod - The original method to be wrapped. / 需要包装的原始方法。 - * @param middlewares - An array of `ProviderMethodMiddleware` to apply. / 要应用的 `ProviderMethodMiddleware` 数组。 - * @param specificContextFactory - An optional factory to create a specific context for this method. / 可选的工厂函数,用于为此方法创建特定的上下文。 - * @returns An enhanced method with the middlewares applied. / 应用了中间件的增强方法。 - */ -export function applyMethodMiddlewares< - TArgs extends unknown[] = unknown[], // Original method's arguments array type / 原始方法的参数数组类型 - TResult = unknown, - TContext extends BaseContext = BaseContext ->( - methodName: string, - originalMethod: (...args: TArgs) => Promise, - middlewares: MethodMiddleware[], // Expects generic middlewares / 期望通用中间件 - specificContextFactory?: (base: BaseContext, callArgs: TArgs) => TContext -): (...args: TArgs) => Promise { - // Returns a function matching the original method signature. 
/
- // 返回一个与原始方法签名匹配的函数。
-  return async function enhancedMethod(...methodCallArgs: TArgs): Promise<TResult> {
-    const ctx = createInitialCallContext(
-      methodName,
-      methodCallArgs, // Pass the actual call arguments array / 传递实际的调用参数数组
-      specificContextFactory
-    )
-
-    const api: MiddlewareAPI<TContext, TArgs> = {
-      getContext: () => ctx,
-      getOriginalArgs: () => methodCallArgs // API provides the original arguments array / API提供原始参数数组
-    }
-
-    // `finalDispatch` is the function that will ultimately call the original provider method. /
-    // `finalDispatch` 是最终将调用原始提供者方法的函数。
-    // It receives the current context and arguments, which may have been transformed by middlewares. /
-    // 它接收当前的上下文和参数,这些参数可能已被中间件转换。
-    const finalDispatch = async (
-      _: TContext,
-      currentArgs: TArgs // Generic final dispatch expects args array / 通用finalDispatch期望参数数组
-    ): Promise<TResult> => {
-      // Spread the (possibly transformed) arguments into the original method / 将(可能已被中间件转换的)参数展开传给原始方法
-      return originalMethod(...currentArgs)
-    }
-
-    const chain = middlewares.map((middleware) => middleware(api)) // Cast API if TContext/TArgs mismatch general ProviderMethodMiddleware / 如果TContext/TArgs与通用的ProviderMethodMiddleware不匹配,则转换API
-    const composedMiddlewareLogic = compose(...chain)
-    const enhancedDispatch = composedMiddlewareLogic(finalDispatch)
-
-    return enhancedDispatch(ctx, methodCallArgs) // Pass context and original args array / 传递上下文和原始参数数组
-  }
-}
-
-/**
- * Applies an array of `CompletionsMiddleware` to the `completions` method. /
- * 将一组 `CompletionsMiddleware` 应用于 `completions` 方法。
- * This version adapts for `CompletionsMiddleware` expecting a single `params` object. /
- * 此版本适配了期望单个 `params` 对象的 `CompletionsMiddleware`。
- * @param originalApiClientInstance - The original API client instance. / 原始 API 客户端实例。
- * @param originalCompletionsMethod - The original SDK `createCompletions` method. / 原始的 SDK `createCompletions` 方法。
- * @param middlewares - An array of `CompletionsMiddleware` to apply. / 要应用的 `CompletionsMiddleware` 数组。
- * @returns An enhanced `completions` method with the middlewares applied. / 应用了中间件的增强版 `completions` 方法。
- */
-export function applyCompletionsMiddlewares<
-  TSdkInstance extends SdkInstance = SdkInstance,
-  TSdkParams extends SdkParams = SdkParams,
-  TRawOutput extends SdkRawOutput = SdkRawOutput,
-  TRawChunk extends SdkRawChunk = SdkRawChunk,
-  TMessageParam extends SdkMessageParam = SdkMessageParam,
-  TToolCall extends SdkToolCall = SdkToolCall,
-  TSdkSpecificTool extends SdkTool = SdkTool
->(
-  originalApiClientInstance: BaseApiClient<
-    TSdkInstance,
-    TSdkParams,
-    TRawOutput,
-    TRawChunk,
-    TMessageParam,
-    TToolCall,
-    TSdkSpecificTool
-  >,
-  originalCompletionsMethod: (payload: TSdkParams, options?: RequestOptions) => Promise<TRawOutput>,
-  middlewares: CompletionsMiddleware<
-    TSdkParams,
-    TMessageParam,
-    TToolCall,
-    TSdkInstance,
-    TRawOutput,
-    TRawChunk,
-    TSdkSpecificTool
-  >[]
-): (params: CompletionsParams, options?: RequestOptions) => Promise<CompletionsResult> {
-  // Returns a function matching the original method signature. /
-  // 返回一个与原始方法签名匹配的函数。
-
-  const methodName = 'completions'
-
-  // Factory to create AiProviderMiddlewareCompletionsContext. 
/ - // 用于创建 AiProviderMiddlewareCompletionsContext 的工厂函数。 - const completionsContextFactory = ( - base: BaseContext, - callArgs: [CompletionsParams] - ): CompletionsContext< - TSdkParams, - TMessageParam, - TToolCall, - TSdkInstance, - TRawOutput, - TRawChunk, - TSdkSpecificTool - > => { - return { - ...base, - methodName, - apiClientInstance: originalApiClientInstance, - originalArgs: callArgs, - _internal: { - toolProcessingState: { - recursionDepth: 0, - isRecursiveCall: false - }, - observer: {} - } - } - } - - return async function enhancedCompletionsMethod( - params: CompletionsParams, - options?: RequestOptions - ): Promise { - // `originalCallArgs` for context creation is `[params]`. / - // 用于上下文创建的 `originalCallArgs` 是 `[params]`。 - const originalCallArgs: [CompletionsParams] = [params] - const baseContext: BaseContext = { - [MIDDLEWARE_CONTEXT_SYMBOL]: true, - methodName, - originalArgs: originalCallArgs - } - const ctx = completionsContextFactory(baseContext, originalCallArgs) - - const api: MiddlewareAPI< - CompletionsContext, - [CompletionsParams] - > = { - getContext: () => ctx, - getOriginalArgs: () => originalCallArgs // API provides [CompletionsParams] / API提供 `[CompletionsParams]` - } - - // `finalDispatch` for CompletionsMiddleware: expects (context, params) not (context, args_array). / - // `CompletionsMiddleware` 的 `finalDispatch`:期望 (context, params) 而不是 (context, args_array)。 - const finalDispatch = async ( - context: CompletionsContext< - TSdkParams, - TMessageParam, - TToolCall, - TSdkInstance, - TRawOutput, - TRawChunk, - TSdkSpecificTool - > // Context passed through / 上下文透传 - // _currentParams: CompletionsParams // Directly takes params / 直接接收参数 (unused but required for middleware signature) - ): Promise => { - // At this point, middleware should have transformed CompletionsParams to SDK params - // and stored them in context. If no transformation happened, we need to handle it. - // 此时,中间件应该已经将 CompletionsParams 转换为 SDK 参数并存储在上下文中。 - // 如果没有进行转换,我们需要处理它。 - - const sdkPayload = context._internal?.sdkPayload - if (!sdkPayload) { - throw new Error('SDK payload not found in context. Middleware chain should have transformed parameters.') - } - - const abortSignal = context._internal.flowControl?.abortSignal - const timeout = context._internal.customState?.sdkMetadata?.timeout - - const methodCall = async (payload) => { - return await originalCompletionsMethod.call(originalApiClientInstance, payload, { - ...options, - signal: abortSignal, - timeout - }) - } - - const traceParams = { - name: `${params.assistant?.model?.name}.client`, - tag: 'LLM', - topicId: params.topicId || '', - modelName: params.assistant?.model?.name - } - - // Call the original SDK method with transformed parameters - // 使用转换后的参数调用原始 SDK 方法 - const rawOutput = await withSpanResult(methodCall, traceParams, sdkPayload) - - // Return result wrapped in CompletionsResult format - // 以 CompletionsResult 格式返回包装的结果 - return { rawOutput } as CompletionsResult - } - - const chain = middlewares.map((middleware) => middleware(api)) - const composedMiddlewareLogic = compose(...chain) - - // `enhancedDispatch` has the signature `(context, params) => Promise`. / - // `enhancedDispatch` 的签名为 `(context, params) => Promise`。 - const enhancedDispatch = composedMiddlewareLogic(finalDispatch) - - // 将 enhancedDispatch 保存到 context 中,供中间件进行递归调用 - // 这样可以避免重复执行整个中间件链 - ctx._internal.enhancedDispatch = enhancedDispatch - - // Execute with context and the single params object. 
/ - // 使用上下文和单个参数对象执行。 - return enhancedDispatch(ctx, params) - } -} diff --git a/src/renderer/src/aiCore/legacy/middleware/core/McpToolChunkMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/McpToolChunkMiddleware.ts deleted file mode 100644 index 6affa5a565..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/core/McpToolChunkMiddleware.ts +++ /dev/null @@ -1,591 +0,0 @@ -import { loggerService } from '@logger' -import type { MCPCallToolResponse, MCPTool, MCPToolResponse, Model } from '@renderer/types' -import type { MCPToolCreatedChunk } from '@renderer/types/chunk' -import { ChunkType } from '@renderer/types/chunk' -import type { SdkMessageParam, SdkRawOutput, SdkToolCall } from '@renderer/types/sdk' -import { - callBuiltInTool, - callMCPTool, - getMcpServerByTool, - isToolAutoApproved, - parseToolUse, - upsertMCPToolResponse -} from '@renderer/utils/mcp-tools' -import { confirmSameNameTools, requestToolConfirmation, setToolIdToNameMapping } from '@renderer/utils/userConfirmation' - -import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' -import type { CompletionsContext, CompletionsMiddleware } from '../types' - -export const MIDDLEWARE_NAME = 'McpToolChunkMiddleware' -const MAX_TOOL_RECURSION_DEPTH = 20 // 防止无限递归 - -const logger = loggerService.withContext('McpToolChunkMiddleware') - -/** - * MCP工具处理中间件 - * - * 职责: - * 1. 检测并拦截MCP工具进展chunk(Function Call方式和Tool Use方式) - * 2. 执行工具调用 - * 3. 递归处理工具结果 - * 4. 管理工具调用状态和递归深度 - */ -export const McpToolChunkMiddleware: CompletionsMiddleware = - () => - (next) => - async (ctx: CompletionsContext, params: CompletionsParams): Promise => { - const mcpTools = params.mcpTools || [] - - // 如果没有工具,直接调用下一个中间件 - if (!mcpTools || mcpTools.length === 0) { - return next(ctx, params) - } - - const executeWithToolHandling = async (currentParams: CompletionsParams, depth = 0): Promise => { - if (depth >= MAX_TOOL_RECURSION_DEPTH) { - logger.error(`Maximum recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`) - throw new Error(`Maximum tool recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`) - } - - let result: CompletionsResult - - if (depth === 0) { - result = await next(ctx, currentParams) - } else { - const enhancedCompletions = ctx._internal.enhancedDispatch - if (!enhancedCompletions) { - logger.error(`Enhanced completions method not found, cannot perform recursive call`) - throw new Error('Enhanced completions method not found') - } - - ctx._internal.toolProcessingState!.isRecursiveCall = true - ctx._internal.toolProcessingState!.recursionDepth = depth - - result = await enhancedCompletions(ctx, currentParams) - } - - if (!result.stream) { - logger.error(`No stream returned from enhanced completions`) - throw new Error('No stream returned from enhanced completions') - } - - const resultFromUpstream = result.stream as ReadableStream - const toolHandlingStream = resultFromUpstream.pipeThrough( - createToolHandlingTransform(ctx, currentParams, mcpTools, depth, executeWithToolHandling) - ) - - return { - ...result, - stream: toolHandlingStream - } - } - - return executeWithToolHandling(params, 0) - } - -/** - * 创建工具处理的 TransformStream - */ -function createToolHandlingTransform( - ctx: CompletionsContext, - currentParams: CompletionsParams, - mcpTools: MCPTool[], - depth: number, - executeWithToolHandling: (params: CompletionsParams, depth: number) => Promise -): TransformStream { - const toolCalls: SdkToolCall[] = [] - const toolUseResponses: MCPToolResponse[] = [] - const allToolResponses: 
MCPToolResponse[] = [] // 统一的工具响应状态管理数组 - let hasToolCalls = false - let hasToolUseResponses = false - let streamEnded = false - - // 存储已执行的工具结果 - const executedToolResults: SdkMessageParam[] = [] - const executedToolCalls: SdkToolCall[] = [] - const executionPromises: Promise[] = [] - - return new TransformStream({ - async transform(chunk: GenericChunk, controller) { - try { - // 处理MCP工具进展chunk - logger.silly('chunk', chunk) - if (chunk.type === ChunkType.MCP_TOOL_CREATED) { - const createdChunk = chunk as MCPToolCreatedChunk - - // 1. 处理Function Call方式的工具调用 - if (createdChunk.tool_calls && createdChunk.tool_calls.length > 0) { - hasToolCalls = true - - for (const toolCall of createdChunk.tool_calls) { - toolCalls.push(toolCall) - - const executionPromise = (async () => { - try { - const result = await executeToolCalls( - ctx, - [toolCall], - mcpTools, - allToolResponses, - currentParams.onChunk, - currentParams.assistant.model!, - currentParams.topicId - ) - - // 缓存执行结果 - executedToolResults.push(...result.toolResults) - executedToolCalls.push(...result.confirmedToolCalls) - } catch (error) { - logger.error(`Error executing tool call asynchronously:`, error as Error) - } - })() - - executionPromises.push(executionPromise) - } - } - - // 2. 处理Tool Use方式的工具调用 - if (createdChunk.tool_use_responses && createdChunk.tool_use_responses.length > 0) { - hasToolUseResponses = true - for (const toolUseResponse of createdChunk.tool_use_responses) { - toolUseResponses.push(toolUseResponse) - const executionPromise = (async () => { - try { - const result = await executeToolUseResponses( - ctx, - [toolUseResponse], // 单个执行 - mcpTools, - allToolResponses, - currentParams.onChunk, - currentParams.assistant.model!, - currentParams.topicId - ) - - // 缓存执行结果 - executedToolResults.push(...result.toolResults) - } catch (error) { - logger.error(`Error executing tool use response asynchronously:`, error as Error) - // 错误时不影响其他工具的执行 - } - })() - - executionPromises.push(executionPromise) - } - } - } else { - controller.enqueue(chunk) - } - } catch (error) { - logger.error(`Error processing chunk:`, error as Error) - controller.error(error) - } - }, - - async flush(controller) { - // 在流结束时等待所有异步工具执行完成,然后进行递归调用 - if (!streamEnded && (hasToolCalls || hasToolUseResponses)) { - streamEnded = true - - try { - await Promise.all(executionPromises) - if (executedToolResults.length > 0) { - const output = ctx._internal.toolProcessingState?.output - const newParams = buildParamsWithToolResults( - ctx, - currentParams, - output, - executedToolResults, - executedToolCalls - ) - - // 在递归调用前通知UI开始新的LLM响应处理 - if (currentParams.onChunk) { - currentParams.onChunk({ - type: ChunkType.LLM_RESPONSE_CREATED - }) - } - - await executeWithToolHandling(newParams, depth + 1) - } - } catch (error) { - logger.error(`Error in tool processing:`, error as Error) - controller.error(error) - } finally { - hasToolCalls = false - hasToolUseResponses = false - } - } - } - }) -} - -/** - * 执行工具调用(Function Call 方式) - */ -async function executeToolCalls( - ctx: CompletionsContext, - toolCalls: SdkToolCall[], - mcpTools: MCPTool[], - allToolResponses: MCPToolResponse[], - onChunk: CompletionsParams['onChunk'], - model: Model, - topicId?: string -): Promise<{ toolResults: SdkMessageParam[]; confirmedToolCalls: SdkToolCall[] }> { - const mcpToolResponses: MCPToolResponse[] = toolCalls - .map((toolCall) => { - const mcpTool = ctx.apiClientInstance.convertSdkToolCallToMcp(toolCall, mcpTools) - if (!mcpTool) { - return undefined - } - return 
ctx.apiClientInstance.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool) - }) - .filter((t): t is MCPToolResponse => typeof t !== 'undefined') - - if (mcpToolResponses.length === 0) { - logger.warn(`No valid MCP tool responses to execute`) - return { toolResults: [], confirmedToolCalls: [] } - } - - // 使用现有的parseAndCallTools函数执行工具 - const { toolResults, confirmedToolResponses } = await parseAndCallTools( - mcpToolResponses, - allToolResponses, - onChunk, - (mcpToolResponse, resp, model) => { - return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model) - }, - model, - mcpTools, - ctx._internal?.flowControl?.abortSignal, - topicId - ) - - // 找出已确认工具对应的原始toolCalls - const confirmedToolCalls = toolCalls.filter((toolCall) => { - return confirmedToolResponses.find((confirmed) => { - // 根据不同的ID字段匹配原始toolCall - return ( - ('name' in toolCall && - (toolCall.name?.includes(confirmed.tool.name) || toolCall.name?.includes(confirmed.tool.id))) || - confirmed.tool.name === toolCall.id || - confirmed.tool.id === toolCall.id || - ('toolCallId' in confirmed && confirmed.toolCallId === toolCall.id) || - ('function' in toolCall && toolCall.function.name.toLowerCase().includes(confirmed.tool.name.toLowerCase())) - ) - }) - }) - - return { toolResults, confirmedToolCalls } -} - -/** - * 执行工具使用响应(Tool Use Response 方式) - * 处理已经解析好的 ToolUseResponse[],不需要重新解析字符串 - */ -async function executeToolUseResponses( - ctx: CompletionsContext, - toolUseResponses: MCPToolResponse[], - mcpTools: MCPTool[], - allToolResponses: MCPToolResponse[], - onChunk: CompletionsParams['onChunk'], - model: Model, - topicId?: CompletionsParams['topicId'] -): Promise<{ toolResults: SdkMessageParam[] }> { - // 直接使用parseAndCallTools函数处理已经解析好的ToolUseResponse - const { toolResults } = await parseAndCallTools( - toolUseResponses, - allToolResponses, - onChunk, - (mcpToolResponse, resp, model) => { - return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model) - }, - model, - mcpTools, - ctx._internal?.flowControl?.abortSignal, - topicId - ) - - return { toolResults } -} - -/** - * 构建包含工具结果的新参数 - */ -function buildParamsWithToolResults( - ctx: CompletionsContext, - currentParams: CompletionsParams, - output: SdkRawOutput | string | undefined, - toolResults: SdkMessageParam[], - confirmedToolCalls: SdkToolCall[] -): CompletionsParams { - // 获取当前已经转换好的reqMessages,如果没有则使用原始messages - const currentReqMessages = getCurrentReqMessages(ctx) - - const apiClient = ctx.apiClientInstance - - // 从回复中构建助手消息 - const newReqMessages = apiClient.buildSdkMessages(currentReqMessages, output, toolResults, confirmedToolCalls) - - if (output && ctx._internal.toolProcessingState) { - ctx._internal.toolProcessingState.output = undefined - } - - // 估算新增消息的 token 消耗并累加到 usage 中 - if (ctx._internal.observer?.usage && newReqMessages.length > currentReqMessages.length) { - try { - const newMessages = newReqMessages.slice(currentReqMessages.length) - const additionalTokens = newMessages.reduce((acc, message) => { - return acc + ctx.apiClientInstance.estimateMessageTokens(message) - }, 0) - - if (additionalTokens > 0) { - ctx._internal.observer.usage.prompt_tokens += additionalTokens - ctx._internal.observer.usage.total_tokens += additionalTokens - } - } catch (error) { - logger.error(`Error estimating token usage for new messages:`, error as Error) - } - } - - // 更新递归状态 - if (!ctx._internal.toolProcessingState) { - ctx._internal.toolProcessingState = {} - } - 
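// 标记进入递归调用:下游中间件(如 AbortHandlerMiddleware、FinalChunkConsumerMiddleware)据此跳过一次性初始化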
ctx._internal.toolProcessingState.isRecursiveCall = true - ctx._internal.toolProcessingState.recursionDepth = (ctx._internal.toolProcessingState?.recursionDepth || 0) + 1 - - return { - ...currentParams, - _internal: { - ...ctx._internal, - sdkPayload: ctx._internal.sdkPayload, - newReqMessages: newReqMessages - } - } -} - -/** - * 类型安全地获取当前请求消息 - * 使用API客户端提供的抽象方法,保持中间件的provider无关性 - */ -function getCurrentReqMessages(ctx: CompletionsContext): SdkMessageParam[] { - const sdkPayload = ctx._internal.sdkPayload - if (!sdkPayload) { - return [] - } - - // 使用API客户端的抽象方法来提取消息,保持provider无关性 - return ctx.apiClientInstance.extractMessagesFromSdkPayload(sdkPayload) -} - -export async function parseAndCallTools( - tools: MCPToolResponse[], - allToolResponses: MCPToolResponse[], - onChunk: CompletionsParams['onChunk'], - convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined, - model: Model, - mcpTools?: MCPTool[], - abortSignal?: AbortSignal, - topicId?: CompletionsParams['topicId'] -): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }> - -export async function parseAndCallTools( - content: string, - allToolResponses: MCPToolResponse[], - onChunk: CompletionsParams['onChunk'], - convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined, - model: Model, - mcpTools?: MCPTool[], - abortSignal?: AbortSignal, - topicId?: CompletionsParams['topicId'] -): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }> - -export async function parseAndCallTools( - content: string | MCPToolResponse[], - allToolResponses: MCPToolResponse[], - onChunk: CompletionsParams['onChunk'], - convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined, - model: Model, - mcpTools?: MCPTool[], - abortSignal?: AbortSignal, - topicId?: CompletionsParams['topicId'] -): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }> { - const toolResults: R[] = [] - let curToolResponses: MCPToolResponse[] = [] - if (Array.isArray(content)) { - curToolResponses = content - } else { - // process tool use - curToolResponses = parseToolUse(content, mcpTools || [], 0) - } - if (!curToolResponses || curToolResponses.length === 0) { - return { toolResults, confirmedToolResponses: [] } - } - - for (const toolResponse of curToolResponses) { - upsertMCPToolResponse( - allToolResponses, - { - ...toolResponse, - status: 'pending' - }, - onChunk! - ) - } - - // 创建工具确认Promise映射,并立即处理每个确认 - const confirmedTools: MCPToolResponse[] = [] - const pendingPromises: Promise[] = [] - - curToolResponses.forEach((toolResponse) => { - const server = getMcpServerByTool(toolResponse.tool) - const isAutoApproveEnabled = isToolAutoApproved(toolResponse.tool, server) - let confirmationPromise: Promise - if (isAutoApproveEnabled) { - confirmationPromise = Promise.resolve(true) - } else { - setToolIdToNameMapping(toolResponse.id, toolResponse.tool.name) - - confirmationPromise = requestToolConfirmation(toolResponse.id, abortSignal).then((confirmed) => { - if (confirmed && server) { - // 自动确认其他同名的待确认工具 - confirmSameNameTools(toolResponse.tool.name) - } - return confirmed - }) - } - - const processingPromise = confirmationPromise - .then(async (confirmed) => { - if (confirmed) { - // 立即更新为invoking状态 - upsertMCPToolResponse( - allToolResponses, - { - ...toolResponse, - status: 'invoking' - }, - onChunk! 
- ) - - // 执行工具调用 - try { - const images: string[] = [] - // 根据工具类型选择不同的调用方式 - const toolCallResponse = toolResponse.tool.isBuiltIn - ? await callBuiltInTool(toolResponse) - : await callMCPTool(toolResponse, topicId, model.name) - - // 立即更新为done状态 - upsertMCPToolResponse( - allToolResponses, - { - ...toolResponse, - status: 'done', - response: toolCallResponse - }, - onChunk! - ) - - if (!toolCallResponse) { - return - } - - // 处理图片 - for (const content of toolCallResponse.content) { - if (content.type === 'image' && content.data) { - images.push(`data:${content.mimeType};base64,${content.data}`) - } - } - - if (images.length) { - onChunk?.({ - type: ChunkType.IMAGE_CREATED - }) - onChunk?.({ - type: ChunkType.IMAGE_COMPLETE, - image: { - type: 'base64', - images: images - } - }) - } - - // 转换消息并添加到结果 - const convertedMessage = convertToMessage(toolResponse, toolCallResponse, model) - if (convertedMessage) { - confirmedTools.push(toolResponse) - toolResults.push(convertedMessage) - } - } catch (error) { - logger.error(`Error executing tool ${toolResponse.id}:`, error as Error) - // 更新为错误状态 - upsertMCPToolResponse( - allToolResponses, - { - ...toolResponse, - status: 'done', - response: { - isError: true, - content: [ - { - type: 'text', - text: `Error executing tool: ${error instanceof Error ? error.message : 'Unknown error'}` - } - ] - } - }, - onChunk! - ) - } - } else { - // 立即更新为cancelled状态 - upsertMCPToolResponse( - allToolResponses, - { - ...toolResponse, - status: 'cancelled', - response: { - isError: false, - content: [ - { - type: 'text', - text: 'Tool call cancelled by user.' - } - ] - } - }, - onChunk! - ) - } - }) - .catch((error) => { - logger.error(`Error waiting for tool confirmation ${toolResponse.id}:`, error as Error) - // 立即更新为cancelled状态 - upsertMCPToolResponse( - allToolResponses, - { - ...toolResponse, - status: 'cancelled', - response: { - isError: true, - content: [ - { - type: 'text', - text: `Error in confirmation process: ${error instanceof Error ? error.message : 'Unknown error'}` - } - ] - } - }, - onChunk! 
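-        // 确认流程出错时同样通过 onChunk 将取消状态实时同步到 UI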
- ) - }) - - pendingPromises.push(processingPromise) - }) - - // 等待所有工具处理完成(但每个工具的状态已经实时更新) - await Promise.all(pendingPromises) - - return { toolResults, confirmedToolResponses: confirmedTools } -} diff --git a/src/renderer/src/aiCore/legacy/middleware/core/RawStreamListenerMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/RawStreamListenerMiddleware.ts deleted file mode 100644 index 04bfd751e2..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/core/RawStreamListenerMiddleware.ts +++ /dev/null @@ -1,45 +0,0 @@ -import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient' -import type { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk' - -import type { AnthropicStreamListener } from '../../clients/types' -import type { CompletionsParams, CompletionsResult } from '../schemas' -import type { CompletionsContext, CompletionsMiddleware } from '../types' - -export const MIDDLEWARE_NAME = 'RawStreamListenerMiddleware' - -export const RawStreamListenerMiddleware: CompletionsMiddleware = - () => - (next) => - async (ctx: CompletionsContext, params: CompletionsParams): Promise => { - const result = await next(ctx, params) - - // 在这里可以监听到从SDK返回的最原始流 - if (result.rawOutput) { - // TODO: 后面下放到AnthropicAPIClient - if (ctx.apiClientInstance instanceof AnthropicAPIClient) { - const anthropicListener: AnthropicStreamListener = { - onMessage: (message) => { - if (ctx._internal?.toolProcessingState) { - ctx._internal.toolProcessingState.output = message - } - } - // onContentBlock: (contentBlock) => { - // console.log(`[${MIDDLEWARE_NAME}] 📝 Anthropic content block:`, contentBlock.type) - // } - } - - const specificApiClient = ctx.apiClientInstance as AnthropicAPIClient - - const monitoredOutput = specificApiClient.attachRawStreamListener( - result.rawOutput as AnthropicSdkRawOutput, - anthropicListener - ) - return { - ...result, - rawOutput: monitoredOutput - } - } - } - - return result - } diff --git a/src/renderer/src/aiCore/legacy/middleware/core/ResponseTransformMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/ResponseTransformMiddleware.ts deleted file mode 100644 index bdab7b8783..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/core/ResponseTransformMiddleware.ts +++ /dev/null @@ -1,88 +0,0 @@ -import { loggerService } from '@logger' -import type { SdkRawChunk } from '@renderer/types/sdk' - -import type { ResponseChunkTransformerContext } from '../../clients/types' -import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' -import type { CompletionsContext, CompletionsMiddleware } from '../types' - -export const MIDDLEWARE_NAME = 'ResponseTransformMiddleware' - -const logger = loggerService.withContext('ResponseTransformMiddleware') - -/** - * 响应转换中间件 - * - * 职责: - * 1. 检测ReadableStream类型的响应流 - * 2. 使用ApiClient的getResponseChunkTransformer()将原始SDK响应块转换为通用格式 - * 3. 
将转换后的ReadableStream保存到ctx._internal.apiCall.genericChunkStream,供下游中间件使用 - * - * 注意:此中间件应该在StreamAdapterMiddleware之后执行 - */ -export const ResponseTransformMiddleware: CompletionsMiddleware = - () => - (next) => - async (ctx: CompletionsContext, params: CompletionsParams): Promise => { - // 调用下游中间件 - const result = await next(ctx, params) - - // 响应后处理:转换原始SDK响应块 - if (result.stream) { - const adaptedStream = result.stream - - // 处理ReadableStream类型的流 - if (adaptedStream instanceof ReadableStream) { - const apiClient = ctx.apiClientInstance - if (!apiClient) { - logger.error(`ApiClient instance not found in context`) - throw new Error('ApiClient instance not found in context') - } - - // 获取响应转换器 - const responseChunkTransformer = apiClient.getResponseChunkTransformer(ctx) - if (!responseChunkTransformer) { - logger.warn(`No ResponseChunkTransformer available, skipping transformation`) - return result - } - - const assistant = params.assistant - const model = assistant?.model - - if (!assistant || !model) { - logger.error(`Assistant or Model not found for transformation`) - throw new Error('Assistant or Model not found for transformation') - } - - const transformerContext: ResponseChunkTransformerContext = { - isStreaming: params.streamOutput || false, - isEnabledToolCalling: (params.mcpTools && params.mcpTools.length > 0) || false, - isEnabledWebSearch: params.enableWebSearch || false, - isEnabledUrlContext: params.enableUrlContext || false, - isEnabledReasoning: params.enableReasoning || false, - mcpTools: params.mcpTools || [], - provider: ctx.apiClientInstance?.provider - } - - logger.debug(`Transforming raw SDK chunks with context:`, transformerContext) - - try { - // 创建转换后的流 - const genericChunkTransformStream = (adaptedStream as ReadableStream).pipeThrough( - new TransformStream(responseChunkTransformer(transformerContext)) - ) - - // 将转换后的ReadableStream保存到result,供下游中间件使用 - return { - ...result, - stream: genericChunkTransformStream - } - } catch (error) { - logger.error('Error during chunk transformation:', error as Error) - throw error - } - } - } - - // 如果没有流或不是ReadableStream,返回原始结果 - return result - } diff --git a/src/renderer/src/aiCore/legacy/middleware/core/StreamAdapterMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/StreamAdapterMiddleware.ts deleted file mode 100644 index b6dc13e602..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/core/StreamAdapterMiddleware.ts +++ /dev/null @@ -1,56 +0,0 @@ -import type { SdkRawChunk } from '@renderer/types/sdk' -import { asyncGeneratorToReadableStream, createSingleChunkReadableStream } from '@renderer/utils/stream' - -import type { CompletionsParams, CompletionsResult } from '../schemas' -import type { CompletionsContext, CompletionsMiddleware } from '../types' -import { isAsyncIterable } from '../utils' - -export const MIDDLEWARE_NAME = 'StreamAdapterMiddleware' - -/** - * 流适配器中间件 - * - * 职责: - * 1. 检测ctx._internal.apiCall.rawSdkOutput(优先)或原始AsyncIterable流 - * 2. 将AsyncIterable转换为WHATWG ReadableStream - * 3. 
更新响应结果中的stream - * - * 注意:如果ResponseTransformMiddleware已处理过,会优先使用transformedStream - */ -export const StreamAdapterMiddleware: CompletionsMiddleware = - () => - (next) => - async (ctx: CompletionsContext, params: CompletionsParams): Promise => { - // TODO:调用开始,因为这个是最靠近接口请求的地方,next执行代表着开始接口请求了 - // 但是这个中间件的职责是流适配,是否在这调用优待商榷 - // 调用下游中间件 - const result = await next(ctx, params) - if ( - result.rawOutput && - !(result.rawOutput instanceof ReadableStream) && - isAsyncIterable(result.rawOutput) - ) { - const whatwgReadableStream: ReadableStream = asyncGeneratorToReadableStream( - result.rawOutput - ) - return { - ...result, - stream: whatwgReadableStream - } - } else if (result.rawOutput && result.rawOutput instanceof ReadableStream) { - return { - ...result, - stream: result.rawOutput - } - } else if (result.rawOutput) { - // 非流式输出,强行变为可读流 - const whatwgReadableStream: ReadableStream = createSingleChunkReadableStream( - result.rawOutput as SdkRawChunk - ) - return { - ...result, - stream: whatwgReadableStream - } - } - return result - } diff --git a/src/renderer/src/aiCore/legacy/middleware/core/TextChunkMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/TextChunkMiddleware.ts deleted file mode 100644 index 837244981a..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/core/TextChunkMiddleware.ts +++ /dev/null @@ -1,106 +0,0 @@ -import { loggerService } from '@logger' -import { ChunkType } from '@renderer/types/chunk' - -import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' -import type { CompletionsContext, CompletionsMiddleware } from '../types' - -export const MIDDLEWARE_NAME = 'TextChunkMiddleware' - -const logger = loggerService.withContext('TextChunkMiddleware') - -/** - * 文本块处理中间件 - * - * 职责: - * 1. 累积文本内容(TEXT_DELTA) - * 2. 对文本内容进行智能链接转换 - * 3. 生成TEXT_COMPLETE事件 - * 4. 暂存Web搜索结果,用于最终链接完善 - * 5. 
处理 onResponse 回调,实时发送文本更新和最终完整文本 - */ -export const TextChunkMiddleware: CompletionsMiddleware = - () => - (next) => - async (ctx: CompletionsContext, params: CompletionsParams): Promise => { - // 调用下游中间件 - const result = await next(ctx, params) - - // 响应后处理:转换流式响应中的文本内容 - if (result.stream) { - const resultFromUpstream = result.stream as ReadableStream - - if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) { - const assistant = params.assistant - const model = params.assistant?.model - - if (!assistant || !model) { - logger.warn(`Missing assistant or model information, skipping text processing`) - return result - } - - // 用于跨chunk的状态管理 - let accumulatedTextContent = '' - const enhancedTextStream = resultFromUpstream.pipeThrough( - new TransformStream({ - transform(chunk: GenericChunk, controller) { - logger.silly('chunk', chunk) - if (chunk.type === ChunkType.TEXT_DELTA) { - if (model.supported_text_delta === false) { - accumulatedTextContent = chunk.text - } else { - accumulatedTextContent += chunk.text - } - // 处理 onResponse 回调 - 发送增量文本更新 - if (params.onResponse) { - params.onResponse(accumulatedTextContent, false) - } - - controller.enqueue({ - ...chunk, - text: accumulatedTextContent // 增量更新 - }) - } else if (accumulatedTextContent && chunk.type !== ChunkType.TEXT_START) { - ctx._internal.customState!.accumulatedText = accumulatedTextContent - if (ctx._internal.toolProcessingState && !ctx._internal.toolProcessingState?.output) { - ctx._internal.toolProcessingState.output = accumulatedTextContent - } - - if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) { - // 处理 onResponse 回调 - 发送最终完整文本 - if (params.onResponse) { - params.onResponse(accumulatedTextContent, true) - } - - controller.enqueue({ - type: ChunkType.TEXT_COMPLETE, - text: accumulatedTextContent - }) - controller.enqueue(chunk) - } else { - controller.enqueue({ - type: ChunkType.TEXT_COMPLETE, - text: accumulatedTextContent - }) - controller.enqueue(chunk) - } - accumulatedTextContent = '' - } else { - // 其他类型的chunk直接传递 - controller.enqueue(chunk) - } - } - }) - ) - - // 更新响应结果 - return { - ...result, - stream: enhancedTextStream - } - } else { - logger.warn(`No stream to process or not a ReadableStream. Returning original result.`) - } - } - - return result - } diff --git a/src/renderer/src/aiCore/legacy/middleware/core/ThinkChunkMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/ThinkChunkMiddleware.ts deleted file mode 100644 index 5920cdc0ea..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/core/ThinkChunkMiddleware.ts +++ /dev/null @@ -1,99 +0,0 @@ -import { loggerService } from '@logger' -import type { ThinkingCompleteChunk, ThinkingDeltaChunk } from '@renderer/types/chunk' -import { ChunkType } from '@renderer/types/chunk' - -import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' -import type { CompletionsContext, CompletionsMiddleware } from '../types' - -export const MIDDLEWARE_NAME = 'ThinkChunkMiddleware' - -const logger = loggerService.withContext('ThinkChunkMiddleware') - -/** - * 处理思考内容的中间件 - * - * 注意:从 v2 版本开始,流结束语义的判断已移至 ApiClient 层处理 - * 此中间件现在主要负责: - * 1. 处理原始SDK chunk中的reasoning字段 - * 2. 计算准确的思考时间 - * 3. 在思考内容结束时生成THINKING_COMPLETE事件 - * - * 职责: - * 1. 累积思考内容(THINKING_DELTA) - * 2. 监听流结束信号,生成THINKING_COMPLETE事件 - * 3. 
计算准确的思考时间 - * - */ -export const ThinkChunkMiddleware: CompletionsMiddleware = - () => - (next) => - async (ctx: CompletionsContext, params: CompletionsParams): Promise => { - // 调用下游中间件 - const result = await next(ctx, params) - - // 响应后处理:处理思考内容 - if (result.stream) { - const resultFromUpstream = result.stream as ReadableStream - - // 检查是否有流需要处理 - if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) { - // thinking 处理状态 - let accumulatedThinkingContent = '' - let hasThinkingContent = false - let thinkingStartTime = 0 - - const processedStream = resultFromUpstream.pipeThrough( - new TransformStream({ - transform(chunk: GenericChunk, controller) { - if (chunk.type === ChunkType.THINKING_DELTA) { - const thinkingChunk = chunk as ThinkingDeltaChunk - - // 第一次接收到思考内容时记录开始时间 - if (!hasThinkingContent) { - hasThinkingContent = true - thinkingStartTime = Date.now() - } - - accumulatedThinkingContent += thinkingChunk.text - - // 更新思考时间并传递 - const enhancedChunk: ThinkingDeltaChunk = { - ...thinkingChunk, - text: accumulatedThinkingContent, - thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0 - } - controller.enqueue(enhancedChunk) - } else if (hasThinkingContent && thinkingStartTime > 0 && chunk.type !== ChunkType.THINKING_START) { - // 收到任何非THINKING_DELTA的chunk时,如果有累积的思考内容,生成THINKING_COMPLETE - const thinkingCompleteChunk: ThinkingCompleteChunk = { - type: ChunkType.THINKING_COMPLETE, - text: accumulatedThinkingContent, - thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0 - } - controller.enqueue(thinkingCompleteChunk) - hasThinkingContent = false - accumulatedThinkingContent = '' - thinkingStartTime = 0 - - // 继续传递当前chunk - controller.enqueue(chunk) - } else { - // 其他情况直接传递 - controller.enqueue(chunk) - } - } - }) - ) - - // 更新响应结果 - return { - ...result, - stream: processedStream - } - } else { - logger.warn(`No generic chunk stream to process or not a ReadableStream.`) - } - } - - return result - } diff --git a/src/renderer/src/aiCore/legacy/middleware/core/TransformCoreToSdkParamsMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/TransformCoreToSdkParamsMiddleware.ts deleted file mode 100644 index ebc86f5a5e..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/core/TransformCoreToSdkParamsMiddleware.ts +++ /dev/null @@ -1,81 +0,0 @@ -import { loggerService } from '@logger' -import { ChunkType } from '@renderer/types/chunk' - -import type { CompletionsParams, CompletionsResult } from '../schemas' -import type { CompletionsContext, CompletionsMiddleware } from '../types' - -export const MIDDLEWARE_NAME = 'TransformCoreToSdkParamsMiddleware' - -const logger = loggerService.withContext('TransformCoreToSdkParamsMiddleware') - -/** - * 中间件:将CoreCompletionsRequest转换为SDK特定的参数 - * 使用上下文中ApiClient实例的requestTransformer进行转换 - */ -export const TransformCoreToSdkParamsMiddleware: CompletionsMiddleware = - () => - (next) => - async (ctx: CompletionsContext, params: CompletionsParams): Promise => { - const internal = ctx._internal - - // 🔧 检测递归调用:检查 params 中是否携带了预处理的 SDK 消息 - const isRecursiveCall = internal?.toolProcessingState?.isRecursiveCall || false - const newSdkMessages = params._internal?.newReqMessages - - const apiClient = ctx.apiClientInstance - - if (!apiClient) { - logger.error(`ApiClient instance not found in context.`) - throw new Error('ApiClient instance not found in context') - } - - // 检查是否有requestTransformer方法 - const requestTransformer = apiClient.getRequestTransformer() - if (!requestTransformer) { 
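// No transformer was returned for this provider: warn and run the downstream chain with the
// untransformed core params rather than failing the whole completions call.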
- logger.warn(`ApiClient does not have getRequestTransformer method, skipping transformation`) - const result = await next(ctx, params) - return result - } - - // 确保assistant和model可用,它们是transformer所需的 - const assistant = params.assistant - const model = params.assistant.model - - if (!assistant || !model) { - logger.error(`Assistant or Model not found for transformation.`) - throw new Error('Assistant or Model not found for transformation') - } - - try { - const transformResult = await requestTransformer.transform( - params, - assistant, - model, - isRecursiveCall, - newSdkMessages - ) - - const { payload: sdkPayload, metadata } = transformResult - - // 将SDK特定的payload和metadata存储在状态中,供下游中间件使用 - ctx._internal.sdkPayload = sdkPayload - - if (metadata) { - ctx._internal.customState = { - ...ctx._internal.customState, - sdkMetadata: metadata - } - } - - if (params.enableGenerateImage) { - params.onChunk?.({ - type: ChunkType.IMAGE_CREATED - }) - } - return next(ctx, params) - } catch (error) { - logger.error('Error during request transformation:', error as Error) - // 让错误向上传播,或者可以在这里进行特定的错误处理 - throw error - } - } diff --git a/src/renderer/src/aiCore/legacy/middleware/core/WebSearchMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/WebSearchMiddleware.ts deleted file mode 100644 index 3365b163b6..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/core/WebSearchMiddleware.ts +++ /dev/null @@ -1,102 +0,0 @@ -import { loggerService } from '@logger' -import { ChunkType } from '@renderer/types/chunk' -import { convertLinks, flushLinkConverterBuffer } from '@renderer/utils/linkConverter' - -import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' -import type { CompletionsContext, CompletionsMiddleware } from '../types' - -const logger = loggerService.withContext('WebSearchMiddleware') - -export const MIDDLEWARE_NAME = 'WebSearchMiddleware' - -/** - * Web搜索处理中间件 - 基于GenericChunk流处理 - * - * 职责: - * 1. 监听和记录Web搜索事件 - * 2. 可以在此处添加Web搜索结果的后处理逻辑 - * 3. 
维护Web搜索相关的状态 - * - * 注意:Web搜索结果的识别和生成已在ApiClient的响应转换器中处理 - */ -export const WebSearchMiddleware: CompletionsMiddleware = - () => - (next) => - async (ctx: CompletionsContext, params: CompletionsParams): Promise => { - ctx._internal.webSearchState = { - results: undefined - } - // 调用下游中间件 - const result = await next(ctx, params) - let isFirstChunk = true - - // 响应后处理:记录Web搜索事件 - if (result.stream) { - const resultFromUpstream = result.stream - - if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) { - // Web搜索状态跟踪 - const enhancedStream = (resultFromUpstream as ReadableStream).pipeThrough( - new TransformStream({ - transform(chunk: GenericChunk, controller) { - if (chunk.type === ChunkType.TEXT_DELTA) { - // 使用当前可用的Web搜索结果进行链接转换 - const text = chunk.text - const result = convertLinks(text, isFirstChunk) - if (isFirstChunk) { - isFirstChunk = false - } - - // - 如果有内容被缓冲,说明convertLinks正在等待后续chunk,不使用原文本避免重复 - // - 如果没有内容被缓冲且结果为空,可能是其他处理问题,使用原文本作为安全回退 - let finalText: string - if (result.hasBufferedContent) { - // 有内容被缓冲,使用处理后的结果(可能为空,等待后续chunk) - finalText = result.text - } else { - // 没有内容被缓冲,可以安全使用回退逻辑 - finalText = result.text || text - } - - // 只有当finalText不为空时才发送chunk - if (finalText) { - controller.enqueue({ - ...chunk, - text: finalText - }) - } - } else if (chunk.type === ChunkType.LLM_WEB_SEARCH_COMPLETE) { - // 暂存Web搜索结果用于链接完善 - ctx._internal.webSearchState!.results = chunk.llm_web_search - - // 将Web搜索完成事件继续传递下去 - controller.enqueue(chunk) - } else if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) { - // 流结束时,清空链接转换器的buffer并处理剩余内容 - const remainingText = flushLinkConverterBuffer() - if (remainingText) { - controller.enqueue({ - type: ChunkType.TEXT_DELTA, - text: remainingText - }) - } - // 继续传递LLM_RESPONSE_COMPLETE事件 - controller.enqueue(chunk) - } else { - controller.enqueue(chunk) - } - } - }) - ) - - return { - ...result, - stream: enhancedStream - } - } else { - logger.debug(`No stream to process or not a ReadableStream.`) - } - } - - return result - } diff --git a/src/renderer/src/aiCore/legacy/middleware/feat/ImageGenerationMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/feat/ImageGenerationMiddleware.ts deleted file mode 100644 index 0df303e41e..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/feat/ImageGenerationMiddleware.ts +++ /dev/null @@ -1,148 +0,0 @@ -import type OpenAI from '@cherrystudio/openai' -import { toFile } from '@cherrystudio/openai/uploads' -import { isDedicatedImageGenerationModel } from '@renderer/config/models' -import FileManager from '@renderer/services/FileManager' -import { ChunkType } from '@renderer/types/chunk' -import { findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' -import { defaultTimeout } from '@shared/config/constant' - -import type { BaseApiClient } from '../../clients/BaseApiClient' -import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' -import type { CompletionsContext, CompletionsMiddleware } from '../types' - -export const MIDDLEWARE_NAME = 'ImageGenerationMiddleware' - -export const ImageGenerationMiddleware: CompletionsMiddleware = - () => - (next) => - async (context: CompletionsContext, params: CompletionsParams): Promise => { - const { assistant, messages } = params - const client = context.apiClientInstance as BaseApiClient - const signal = context._internal?.flowControl?.abortSignal - if (!assistant.model || !isDedicatedImageGenerationModel(assistant.model) || typeof messages === 'string') { - return next(context, params) - } 
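// From here the middleware takes over instead of calling next(): it synthesizes its own
// GenericChunk stream and returns it with a no-op getText(), so downstream middlewares still
// see a normal chunk stream. A minimal sketch of that takeover shape, reusing the legacy types
// deleted later in this diff (the TakeoverMiddleware name, the shouldHandle() check and the
// single TEXT_DELTA payload are illustrative, not part of the codebase):
import { ChunkType } from '@renderer/types/chunk'

import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import type { CompletionsContext, CompletionsMiddleware } from '../types'

const shouldHandle = (params: CompletionsParams): boolean => typeof params.messages !== 'string'

export const TakeoverMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
    // Pass through when this middleware is not applicable
    if (!shouldHandle(params)) return next(ctx, params)

    // Take over: emit a synthetic chunk sequence and close the stream
    const stream = new ReadableStream<GenericChunk>({
      start(controller) {
        controller.enqueue({ type: ChunkType.TEXT_DELTA, text: 'hello' } as GenericChunk)
        controller.enqueue({
          type: ChunkType.LLM_RESPONSE_COMPLETE,
          response: { usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 } }
        } as GenericChunk)
        controller.close()
      }
    })

    return { stream, getText: () => '' }
  }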
- - const stream = new ReadableStream({ - async start(controller) { - const enqueue = (chunk: GenericChunk) => controller.enqueue(chunk) - - try { - if (!assistant.model) { - throw new Error('Assistant model is not defined.') - } - - const sdk = await client.getSdkInstance() - const lastUserMessage = messages.findLast((m) => m.role === 'user') - const lastAssistantMessage = messages.findLast((m) => m.role === 'assistant') - - if (!lastUserMessage) { - throw new Error('No user message found for image generation.') - } - - const prompt = getMainTextContent(lastUserMessage) - let imageFiles: Blob[] = [] - - // Collect images from user message - const userImageBlocks = findImageBlocks(lastUserMessage) - const userImages = await Promise.all( - userImageBlocks.map(async (block) => { - if (!block.file) return null - const binaryData: Uint8Array = await FileManager.readBinaryImage(block.file) - const mimeType = `${block.file.type}/${block.file.ext.slice(1)}` - return await toFile(new Blob([binaryData]), block.file.origin_name || 'image.png', { type: mimeType }) - }) - ) - imageFiles = imageFiles.concat(userImages.filter(Boolean) as Blob[]) - - // Collect images from last assistant message - if (lastAssistantMessage) { - const assistantImageBlocks = findImageBlocks(lastAssistantMessage) - const assistantImages = await Promise.all( - assistantImageBlocks.map(async (block) => { - const b64 = block.url?.replace(/^data:image\/\w+;base64,/, '') - if (!b64) return null - const binary = atob(b64) - const bytes = new Uint8Array(binary.length) - for (let i = 0; i < binary.length; i++) bytes[i] = binary.charCodeAt(i) - return await toFile(new Blob([bytes]), 'assistant_image.png', { type: 'image/png' }) - }) - ) - imageFiles = imageFiles.concat(assistantImages.filter(Boolean) as Blob[]) - } - - enqueue({ type: ChunkType.IMAGE_CREATED }) - - const startTime = Date.now() - let response: OpenAI.Images.ImagesResponse - const options = { signal, timeout: defaultTimeout } - - if (imageFiles.length > 0) { - const model = assistant.model - const provider = context.apiClientInstance.provider - // https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/dall-e?tabs=gpt-image-1#call-the-image-edit-api - if (model.id.toLowerCase().includes('gpt-image-1-mini') && provider.type === 'azure-openai') { - throw new Error('Azure OpenAI GPT-Image-1-Mini model does not support image editing.') - } - response = await sdk.images.edit( - { - model: assistant.model.id, - image: imageFiles, - prompt: prompt || '' - }, - options - ) - } else { - response = await sdk.images.generate( - { - model: assistant.model.id, - prompt: prompt || '', - response_format: assistant.model.id.includes('gpt-image-1') ? 
undefined : 'b64_json' - }, - options - ) - } - - let imageType: 'url' | 'base64' = 'base64' - const imageList = - response.data?.reduce((acc: string[], image) => { - if (image.url) { - acc.push(image.url) - imageType = 'url' - } else if (image.b64_json) { - acc.push(`data:image/png;base64,${image.b64_json}`) - } - return acc - }, []) || [] - - enqueue({ - type: ChunkType.IMAGE_COMPLETE, - image: { type: imageType, images: imageList } - }) - - const usage = (response as any).usage || { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 } - - enqueue({ - type: ChunkType.LLM_RESPONSE_COMPLETE, - response: { - usage, - metrics: { - completion_tokens: usage.completion_tokens, - time_first_token_millsec: 0, - time_completion_millsec: Date.now() - startTime - } - } - }) - } catch (error: any) { - enqueue({ type: ChunkType.ERROR, error }) - } finally { - controller.close() - } - } - }) - - return { - stream, - getText: () => '' - } - } diff --git a/src/renderer/src/aiCore/legacy/middleware/feat/ThinkingTagExtractionMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/feat/ThinkingTagExtractionMiddleware.ts deleted file mode 100644 index dea679eaac..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/feat/ThinkingTagExtractionMiddleware.ts +++ /dev/null @@ -1,198 +0,0 @@ -import { loggerService } from '@logger' -import type { Model } from '@renderer/types' -import type { - TextDeltaChunk, - ThinkingCompleteChunk, - ThinkingDeltaChunk, - ThinkingStartChunk -} from '@renderer/types/chunk' -import { ChunkType } from '@renderer/types/chunk' -import { getLowerBaseModelName } from '@renderer/utils' -import type { TagConfig } from '@renderer/utils/tagExtraction' -import { TagExtractor } from '@renderer/utils/tagExtraction' - -import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' -import type { CompletionsContext, CompletionsMiddleware } from '../types' - -const logger = loggerService.withContext('ThinkingTagExtractionMiddleware') - -export const MIDDLEWARE_NAME = 'ThinkingTagExtractionMiddleware' - -// 不同模型的思考标签配置 -const reasoningTags: TagConfig[] = [ - { openingTag: '', closingTag: '', separator: '\n' }, - { openingTag: '', closingTag: '', separator: '\n' }, - { openingTag: '###Thinking', closingTag: '###Response', separator: '\n' }, - { openingTag: '◁think▷', closingTag: '◁/think▷', separator: '\n' }, - { openingTag: '', closingTag: '', separator: '\n' }, - { openingTag: '', closingTag: '', separator: '\n' } -] - -const getAppropriateTag = (model?: Model): TagConfig => { - const modelId = model?.id ? getLowerBaseModelName(model.id) : undefined - if (modelId?.includes('qwen3')) return reasoningTags[0] - if (modelId?.includes('gemini-2.5')) return reasoningTags[1] - if (modelId?.includes('kimi-vl-a3b-thinking')) return reasoningTags[3] - if (modelId?.includes('seed-oss-36b')) return reasoningTags[5] - // 可以在这里添加更多模型特定的标签配置 - return reasoningTags[0] // 默认使用 标签 -} - -/** - * 处理文本流中思考标签提取的中间件 - * - * 该中间件专门处理文本流中的思考标签内容(如 ...) - * 主要用于 OpenAI 等支持思考标签的 provider - * - * 职责: - * 1. 从文本流中提取思考标签内容 - * 2. 将标签内的内容转换为 THINKING_DELTA chunk - * 3. 将标签外的内容作为正常文本输出 - * 4. 处理不同模型的思考标签格式 - * 5. 
在思考内容结束时生成 THINKING_COMPLETE 事件 - */ -export const ThinkingTagExtractionMiddleware: CompletionsMiddleware = - () => - (next) => - async (context: CompletionsContext, params: CompletionsParams): Promise => { - // 调用下游中间件 - const result = await next(context, params) - - // 响应后处理:处理思考标签提取 - if (result.stream) { - const resultFromUpstream = result.stream as ReadableStream - - // 检查是否有流需要处理 - if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) { - // 获取当前模型的思考标签配置 - const model = params.assistant?.model - const reasoningTag = getAppropriateTag(model) - - // 创建标签提取器 - const tagExtractor = new TagExtractor(reasoningTag) - - // thinking 处理状态 - let hasThinkingContent = false - let thinkingStartTime = 0 - - let accumulatingText = false - let accumulatedThinkingContent = '' - const processedStream = resultFromUpstream.pipeThrough( - new TransformStream({ - transform(chunk: GenericChunk, controller) { - logger.silly('chunk', chunk) - - if (chunk.type === ChunkType.TEXT_DELTA) { - const textChunk = chunk as TextDeltaChunk - - // 使用 TagExtractor 处理文本 - const extractionResults = tagExtractor.processText(textChunk.text) - - for (const extractionResult of extractionResults) { - if (extractionResult.complete && extractionResult.tagContentExtracted?.trim()) { - // 完成思考 - // logger.silly( - // 'since extractionResult.complete and extractionResult.tagContentExtracted is not empty, THINKING_COMPLETE chunk is generated' - // ) - // 如果完成思考,更新状态 - accumulatingText = false - - // 生成 THINKING_COMPLETE 事件 - const thinkingCompleteChunk: ThinkingCompleteChunk = { - type: ChunkType.THINKING_COMPLETE, - text: extractionResult.tagContentExtracted.trim(), - thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0 - } - controller.enqueue(thinkingCompleteChunk) - - // 重置思考状态 - hasThinkingContent = false - thinkingStartTime = 0 - } else if (extractionResult.content.length > 0) { - // logger.silly( - // 'since extractionResult.content is not empty, try to generate THINKING_START/THINKING_DELTA chunk' - // ) - if (extractionResult.isTagContent) { - // 如果提取到思考内容,更新状态 - accumulatingText = false - - // 第一次接收到思考内容时记录开始时间 - if (!hasThinkingContent) { - hasThinkingContent = true - thinkingStartTime = Date.now() - controller.enqueue({ - type: ChunkType.THINKING_START - } as ThinkingStartChunk) - } - - if (extractionResult.content?.trim()) { - accumulatedThinkingContent += extractionResult.content.trim() - const thinkingDeltaChunk: ThinkingDeltaChunk = { - type: ChunkType.THINKING_DELTA, - text: accumulatedThinkingContent, - thinking_millsec: thinkingStartTime > 0 ? 
Date.now() - thinkingStartTime : 0 - } - controller.enqueue(thinkingDeltaChunk) - } - } else { - // 如果没有思考内容,直接输出文本 - // logger.silly( - // 'since extractionResult.isTagContent is falsy, try to generate TEXT_START/TEXT_DELTA chunk' - // ) - // 在非组成文本状态下接收到非思考内容时,生成 TEXT_START chunk 并更新状态 - if (!accumulatingText) { - // logger.silly('since accumulatingText is false, TEXT_START chunk is generated') - controller.enqueue({ - type: ChunkType.TEXT_START - }) - accumulatingText = true - } - // 发送清理后的文本内容 - const cleanTextChunk: TextDeltaChunk = { - ...textChunk, - text: extractionResult.content - } - controller.enqueue(cleanTextChunk) - } - } else { - // logger.silly('since both condition is false, skip') - } - } - } else if (chunk.type !== ChunkType.TEXT_START) { - // logger.silly('since chunk.type is not TEXT_START and not TEXT_DELTA, pass through') - - // logger.silly('since chunk.type is not TEXT_START and not TEXT_DELTA, accumulatingText is set to false') - accumulatingText = false - // 其他类型的chunk直接传递(包括 THINKING_DELTA, THINKING_COMPLETE 等) - controller.enqueue(chunk) - } else { - // 接收到的 TEXT_START chunk 直接丢弃 - // logger.silly('since chunk.type is TEXT_START, passed') - } - }, - flush(controller) { - // 处理可能剩余的思考内容 - const finalResult = tagExtractor.finalize() - if (finalResult?.tagContentExtracted) { - const thinkingCompleteChunk: ThinkingCompleteChunk = { - type: ChunkType.THINKING_COMPLETE, - text: finalResult.tagContentExtracted, - thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0 - } - controller.enqueue(thinkingCompleteChunk) - } - } - }) - ) - - // 更新响应结果 - return { - ...result, - stream: processedStream - } - } else { - logger.warn(`[${MIDDLEWARE_NAME}] No generic chunk stream to process or not a ReadableStream.`) - } - } - return result - } diff --git a/src/renderer/src/aiCore/legacy/middleware/feat/ToolUseExtractionMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/feat/ToolUseExtractionMiddleware.ts deleted file mode 100644 index 38d842e08d..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/feat/ToolUseExtractionMiddleware.ts +++ /dev/null @@ -1,138 +0,0 @@ -import { loggerService } from '@logger' -import type { MCPTool } from '@renderer/types' -import type { MCPToolCreatedChunk, TextDeltaChunk } from '@renderer/types/chunk' -import { ChunkType } from '@renderer/types/chunk' -import { parseToolUse } from '@renderer/utils/mcp-tools' -import type { TagConfig } from '@renderer/utils/tagExtraction' -import { TagExtractor } from '@renderer/utils/tagExtraction' - -import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' -import type { CompletionsContext, CompletionsMiddleware } from '../types' - -export const MIDDLEWARE_NAME = 'ToolUseExtractionMiddleware' - -const logger = loggerService.withContext('ToolUseExtractionMiddleware') - -// 工具使用标签配置 -const TOOL_USE_TAG_CONFIG: TagConfig = { - openingTag: '', - closingTag: '', - separator: '\n' -} - -/** - * 工具使用提取中间件 - * - * 职责: - * 1. 从文本流中检测并提取 标签 - * 2. 解析工具调用信息并转换为 ToolUseResponse 格式 - * 3. 生成 MCP_TOOL_CREATED chunk 供 McpToolChunkMiddleware 处理 - * 4. 丢弃 tool_use 之后的所有内容(助手幻觉) - * 5. 
清理文本流,移除工具使用标签但保留正常文本 - * - * 注意:此中间件只负责提取和转换,实际工具调用由 McpToolChunkMiddleware 处理 - */ -export const ToolUseExtractionMiddleware: CompletionsMiddleware = - () => - (next) => - async (ctx: CompletionsContext, params: CompletionsParams): Promise => { - const mcpTools = params.mcpTools || [] - - if (!mcpTools || mcpTools.length === 0) return next(ctx, params) - - const result = await next(ctx, params) - - if (result.stream) { - const resultFromUpstream = result.stream as ReadableStream - - const processedStream = resultFromUpstream.pipeThrough(createToolUseExtractionTransform(ctx, mcpTools)) - - return { - ...result, - stream: processedStream - } - } - - return result - } - -/** - * 创建工具使用提取的 TransformStream - */ -function createToolUseExtractionTransform( - _ctx: CompletionsContext, - mcpTools: MCPTool[] -): TransformStream { - const toolUseExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG) - let hasAnyToolUse = false - let toolCounter = 0 - - return new TransformStream({ - async transform(chunk: GenericChunk, controller) { - try { - // 处理文本内容,检测工具使用标签 - logger.silly('chunk', chunk) - if (chunk.type === ChunkType.TEXT_DELTA) { - const textChunk = chunk as TextDeltaChunk - - // 处理 tool_use 标签 - const toolUseResults = toolUseExtractor.processText(textChunk.text) - - for (const result of toolUseResults) { - if (result.complete && result.tagContentExtracted) { - // 提取到完整的工具使用内容,解析并转换为 SDK ToolCall 格式 - const toolUseResponses = parseToolUse(result.tagContentExtracted, mcpTools, toolCounter) - toolCounter += toolUseResponses.length - - if (toolUseResponses.length > 0) { - // 生成 MCP_TOOL_CREATED chunk - const mcpToolCreatedChunk: MCPToolCreatedChunk = { - type: ChunkType.MCP_TOOL_CREATED, - tool_use_responses: toolUseResponses - } - controller.enqueue(mcpToolCreatedChunk) - - // 标记已有工具调用 - hasAnyToolUse = true - } - } else if (!result.isTagContent && result.content) { - if (!hasAnyToolUse) { - const cleanTextChunk: TextDeltaChunk = { - ...textChunk, - text: result.content - } - controller.enqueue(cleanTextChunk) - } - } - // tool_use 标签内的内容不转发,避免重复显示 - } - return - } - - // 转发其他所有chunk - controller.enqueue(chunk) - } catch (error) { - logger.error('Error processing chunk:', error as Error) - controller.error(error) - } - }, - - async flush(controller) { - // 检查是否有未完成的 tool_use 标签内容 - const finalToolUseResult = toolUseExtractor.finalize() - if (finalToolUseResult && finalToolUseResult.tagContentExtracted) { - const toolUseResponses = parseToolUse(finalToolUseResult.tagContentExtracted, mcpTools, toolCounter) - if (toolUseResponses.length > 0) { - const mcpToolCreatedChunk: MCPToolCreatedChunk = { - type: ChunkType.MCP_TOOL_CREATED, - tool_use_responses: toolUseResponses - } - controller.enqueue(mcpToolCreatedChunk) - hasAnyToolUse = true - } - } - } - }) -} - -export default ToolUseExtractionMiddleware diff --git a/src/renderer/src/aiCore/legacy/middleware/index.ts b/src/renderer/src/aiCore/legacy/middleware/index.ts deleted file mode 100644 index 66213c33b8..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/index.ts +++ /dev/null @@ -1,88 +0,0 @@ -import type { CompletionsMiddleware, MethodMiddleware } from './types' - -// /** -// * Wraps a provider instance with middlewares. 
-// */ -// export function wrapProviderWithMiddleware( -// apiClientInstance: BaseApiClient, -// middlewareConfig: MiddlewareConfig -// ): BaseApiClient { -// console.log(`[wrapProviderWithMiddleware] Wrapping provider: ${apiClientInstance.provider?.id}`) -// console.log(`[wrapProviderWithMiddleware] Middleware config:`, { -// completions: middlewareConfig.completions?.length || 0, -// methods: Object.keys(middlewareConfig.methods || {}).length -// }) - -// // Cache for already wrapped methods to avoid re-wrapping on every access. -// const wrappedMethodsCache = new Map Promise>() - -// const proxy = new Proxy(apiClientInstance, { -// get(target, propKey, receiver) { -// const methodName = typeof propKey === 'string' ? propKey : undefined - -// if (!methodName) { -// return Reflect.get(target, propKey, receiver) -// } - -// if (wrappedMethodsCache.has(methodName)) { -// console.log(`[wrapProviderWithMiddleware] Using cached wrapped method: ${methodName}`) -// return wrappedMethodsCache.get(methodName) -// } - -// const originalMethod = Reflect.get(target, propKey, receiver) - -// // If the property is not a function, return it directly. -// if (typeof originalMethod !== 'function') { -// return originalMethod -// } - -// let wrappedMethod: ((...args: any[]) => Promise) | undefined - -// // Handle completions method -// if (methodName === 'completions' && middlewareConfig.completions?.length) { -// console.log( -// `[wrapProviderWithMiddleware] Wrapping completions method with ${middlewareConfig.completions.length} middlewares` -// ) -// const completionsOriginalMethod = originalMethod as (params: CompletionsParams) => Promise -// wrappedMethod = applyCompletionsMiddlewares(target, completionsOriginalMethod, middlewareConfig.completions) -// } -// // Handle other methods -// else { -// const methodMiddlewares = middlewareConfig.methods?.[methodName] -// if (methodMiddlewares?.length) { -// console.log( -// `[wrapProviderWithMiddleware] Wrapping method ${methodName} with ${methodMiddlewares.length} middlewares` -// ) -// const genericOriginalMethod = originalMethod as (...args: any[]) => Promise -// wrappedMethod = applyMethodMiddlewares(target, methodName, genericOriginalMethod, methodMiddlewares) -// } -// } - -// if (wrappedMethod) { -// console.log(`[wrapProviderWithMiddleware] Successfully wrapped method: ${methodName}`) -// wrappedMethodsCache.set(methodName, wrappedMethod) -// return wrappedMethod -// } - -// // If no middlewares are configured for this method, return the original method bound to the target. 
/ -// // 如果没有为此方法配置中间件,则返回绑定到目标的原始方法。 -// console.log(`[wrapProviderWithMiddleware] No middlewares for method ${methodName}, returning original`) -// return originalMethod.bind(target) -// } -// }) -// return proxy as BaseApiClient -// } - -// Export types for external use -export type { CompletionsMiddleware, MethodMiddleware } - -// Export MiddlewareBuilder related types and classes -export { - CompletionsMiddlewareBuilder, - createCompletionsBuilder, - createMethodBuilder, - MethodMiddlewareBuilder, - MiddlewareBuilder, - type MiddlewareExecutor, - type NamedMiddleware -} from './builder' diff --git a/src/renderer/src/aiCore/legacy/middleware/register.ts b/src/renderer/src/aiCore/legacy/middleware/register.ts deleted file mode 100644 index 003ce7e93a..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/register.ts +++ /dev/null @@ -1,149 +0,0 @@ -import * as AbortHandlerModule from './common/AbortHandlerMiddleware' -import * as ErrorHandlerModule from './common/ErrorHandlerMiddleware' -import * as FinalChunkConsumerModule from './common/FinalChunkConsumerMiddleware' -import * as LoggingModule from './common/LoggingMiddleware' -import * as McpToolChunkModule from './core/McpToolChunkMiddleware' -import * as RawStreamListenerModule from './core/RawStreamListenerMiddleware' -import * as ResponseTransformModule from './core/ResponseTransformMiddleware' -// import * as SdkCallModule from './core/SdkCallMiddleware' -import * as StreamAdapterModule from './core/StreamAdapterMiddleware' -import * as TextChunkModule from './core/TextChunkMiddleware' -import * as ThinkChunkModule from './core/ThinkChunkMiddleware' -import * as TransformCoreToSdkParamsModule from './core/TransformCoreToSdkParamsMiddleware' -import * as WebSearchModule from './core/WebSearchMiddleware' -import * as ImageGenerationModule from './feat/ImageGenerationMiddleware' -import * as ThinkingTagExtractionModule from './feat/ThinkingTagExtractionMiddleware' -import * as ToolUseExtractionMiddleware from './feat/ToolUseExtractionMiddleware' - -/** - * 中间件注册表 - 提供所有可用中间件的集中访问 - * 注意:目前中间件文件还未导出 MIDDLEWARE_NAME,会有 linter 错误,这是正常的 - */ -export const MiddlewareRegistry = { - [ErrorHandlerModule.MIDDLEWARE_NAME]: { - name: ErrorHandlerModule.MIDDLEWARE_NAME, - middleware: ErrorHandlerModule.ErrorHandlerMiddleware - }, - // 通用中间件 - [AbortHandlerModule.MIDDLEWARE_NAME]: { - name: AbortHandlerModule.MIDDLEWARE_NAME, - middleware: AbortHandlerModule.AbortHandlerMiddleware - }, - [FinalChunkConsumerModule.MIDDLEWARE_NAME]: { - name: FinalChunkConsumerModule.MIDDLEWARE_NAME, - middleware: FinalChunkConsumerModule.default - }, - - // 核心流程中间件 - [TransformCoreToSdkParamsModule.MIDDLEWARE_NAME]: { - name: TransformCoreToSdkParamsModule.MIDDLEWARE_NAME, - middleware: TransformCoreToSdkParamsModule.TransformCoreToSdkParamsMiddleware - }, - // [SdkCallModule.MIDDLEWARE_NAME]: { - // name: SdkCallModule.MIDDLEWARE_NAME, - // middleware: SdkCallModule.SdkCallMiddleware - // }, - [StreamAdapterModule.MIDDLEWARE_NAME]: { - name: StreamAdapterModule.MIDDLEWARE_NAME, - middleware: StreamAdapterModule.StreamAdapterMiddleware - }, - [RawStreamListenerModule.MIDDLEWARE_NAME]: { - name: RawStreamListenerModule.MIDDLEWARE_NAME, - middleware: RawStreamListenerModule.RawStreamListenerMiddleware - }, - [ResponseTransformModule.MIDDLEWARE_NAME]: { - name: ResponseTransformModule.MIDDLEWARE_NAME, - middleware: ResponseTransformModule.ResponseTransformMiddleware - }, - - // 特性处理中间件 - [ThinkingTagExtractionModule.MIDDLEWARE_NAME]: { - name: 
ThinkingTagExtractionModule.MIDDLEWARE_NAME, - middleware: ThinkingTagExtractionModule.ThinkingTagExtractionMiddleware - }, - [ToolUseExtractionMiddleware.MIDDLEWARE_NAME]: { - name: ToolUseExtractionMiddleware.MIDDLEWARE_NAME, - middleware: ToolUseExtractionMiddleware.ToolUseExtractionMiddleware - }, - [ThinkChunkModule.MIDDLEWARE_NAME]: { - name: ThinkChunkModule.MIDDLEWARE_NAME, - middleware: ThinkChunkModule.ThinkChunkMiddleware - }, - [McpToolChunkModule.MIDDLEWARE_NAME]: { - name: McpToolChunkModule.MIDDLEWARE_NAME, - middleware: McpToolChunkModule.McpToolChunkMiddleware - }, - [WebSearchModule.MIDDLEWARE_NAME]: { - name: WebSearchModule.MIDDLEWARE_NAME, - middleware: WebSearchModule.WebSearchMiddleware - }, - [TextChunkModule.MIDDLEWARE_NAME]: { - name: TextChunkModule.MIDDLEWARE_NAME, - middleware: TextChunkModule.TextChunkMiddleware - }, - [ImageGenerationModule.MIDDLEWARE_NAME]: { - name: ImageGenerationModule.MIDDLEWARE_NAME, - middleware: ImageGenerationModule.ImageGenerationMiddleware - } -} as const - -/** - * 根据名称获取中间件 - * @param name - 中间件名称 - * @returns 对应的中间件信息 - */ -export function getMiddleware(name: string) { - return MiddlewareRegistry[name] -} - -/** - * 获取所有注册的中间件名称 - * @returns 中间件名称列表 - */ -export function getRegisteredMiddlewareNames(): string[] { - return Object.keys(MiddlewareRegistry) -} - -/** - * 默认的 Completions 中间件配置 - NamedMiddleware 格式,用于 MiddlewareBuilder - */ -export const DefaultCompletionsNamedMiddlewares = [ - MiddlewareRegistry[FinalChunkConsumerModule.MIDDLEWARE_NAME], // 最终消费者 - MiddlewareRegistry[ErrorHandlerModule.MIDDLEWARE_NAME], // 错误处理 - MiddlewareRegistry[TransformCoreToSdkParamsModule.MIDDLEWARE_NAME], // 参数转换 - MiddlewareRegistry[AbortHandlerModule.MIDDLEWARE_NAME], // 中止处理 - MiddlewareRegistry[McpToolChunkModule.MIDDLEWARE_NAME], // 工具处理 - MiddlewareRegistry[TextChunkModule.MIDDLEWARE_NAME], // 文本处理 - MiddlewareRegistry[WebSearchModule.MIDDLEWARE_NAME], // Web搜索处理 - MiddlewareRegistry[ToolUseExtractionMiddleware.MIDDLEWARE_NAME], // 工具使用提取处理 - MiddlewareRegistry[ThinkingTagExtractionModule.MIDDLEWARE_NAME], // 思考标签提取处理(特定provider) - MiddlewareRegistry[ThinkChunkModule.MIDDLEWARE_NAME], // 思考处理(通用SDK) - MiddlewareRegistry[ResponseTransformModule.MIDDLEWARE_NAME], // 响应转换 - MiddlewareRegistry[StreamAdapterModule.MIDDLEWARE_NAME], // 流适配器 - MiddlewareRegistry[RawStreamListenerModule.MIDDLEWARE_NAME] // 原始流监听器 -] - -/** - * 默认的通用方法中间件 - 例如翻译、摘要等 - */ -export const DefaultMethodMiddlewares = { - translate: [LoggingModule.createGenericLoggingMiddleware()], - summaries: [LoggingModule.createGenericLoggingMiddleware()] -} - -/** - * 导出所有中间件模块,方便外部使用 - */ -export { - AbortHandlerModule, - FinalChunkConsumerModule, - LoggingModule, - McpToolChunkModule, - ResponseTransformModule, - StreamAdapterModule, - TextChunkModule, - ThinkChunkModule, - ThinkingTagExtractionModule, - TransformCoreToSdkParamsModule, - WebSearchModule -} diff --git a/src/renderer/src/aiCore/legacy/middleware/schemas.ts b/src/renderer/src/aiCore/legacy/middleware/schemas.ts deleted file mode 100644 index 9119d818db..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/schemas.ts +++ /dev/null @@ -1,84 +0,0 @@ -import type { Assistant, MCPTool } from '@renderer/types' -import type { Chunk } from '@renderer/types/chunk' -import type { Message } from '@renderer/types/newMessage' -import type { SdkRawChunk, SdkRawOutput } from '@renderer/types/sdk' - -import type { ProcessingState } from './types' - -// 
============================================================================ -// Core Request Types - 核心请求结构 -// ============================================================================ - -/** - * 标准化的内部核心请求结构,用于所有AI Provider的统一处理 - * 这是应用层参数转换后的标准格式,不包含回调函数和控制逻辑 - */ -export interface CompletionsParams { - /** - * 调用的业务场景类型,用于中间件判断是否执行 - * 'chat': 主要对话流程 - * 'translate': 翻译 - * 'summary': 摘要 - * 'search': 搜索摘要 - * 'generate': 生成 - * 'check': API检查 - * 'test': 测试调用 - * 'translate-lang-detect': 翻译语言检测 - */ - callType?: 'chat' | 'translate' | 'summary' | 'search' | 'generate' | 'check' | 'test' | 'translate-lang-detect' - - // 基础对话数据 - messages: Message[] | string // 联合类型方便判断是否为空 - - assistant: Assistant // 助手为基本单位 - // model: Model - - onChunk?: (chunk: Chunk) => void - onResponse?: (text: string, isComplete: boolean) => void - - // 错误相关 - onError?: (error: Error) => void - shouldThrow?: boolean - - // 工具相关 - mcpTools?: MCPTool[] - - // 生成参数 - temperature?: number - topP?: number - maxTokens?: number - - // 功能开关 - streamOutput: boolean - enableWebSearch?: boolean - enableUrlContext?: boolean - enableReasoning?: boolean - enableGenerateImage?: boolean - - // 上下文控制 - contextCount?: number - topicId?: string // 主题ID,用于关联上下文 - - // abort 控制 - abortKey?: string - - _internal?: ProcessingState -} - -export interface CompletionsResult { - rawOutput?: SdkRawOutput - stream?: ReadableStream | ReadableStream | AsyncIterable - controller?: AbortController - - getText: () => string -} - -// ============================================================================ -// Generic Chunk Types - 通用数据块结构 -// ============================================================================ - -/** - * 通用数据块类型 - * 复用现有的 Chunk 类型,这是所有AI Provider都应该输出的标准化数据块格式 - */ -export type GenericChunk = Chunk diff --git a/src/renderer/src/aiCore/legacy/middleware/types.ts b/src/renderer/src/aiCore/legacy/middleware/types.ts deleted file mode 100644 index 3762035107..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/types.ts +++ /dev/null @@ -1,166 +0,0 @@ -import type { MCPToolResponse, Metrics, Usage, WebSearchResponse } from '@renderer/types' -import type { Chunk, ErrorChunk } from '@renderer/types/chunk' -import type { - SdkInstance, - SdkMessageParam, - SdkParams, - SdkRawChunk, - SdkRawOutput, - SdkTool, - SdkToolCall -} from '@renderer/types/sdk' - -import type { BaseApiClient } from '../clients' -import type { CompletionsParams, CompletionsResult } from './schemas' - -/** - * Symbol to uniquely identify middleware context objects. - */ -export const MIDDLEWARE_CONTEXT_SYMBOL = Symbol.for('AiProviderMiddlewareContext') - -/** - * Defines the structure for the onChunk callback function. - */ -export type OnChunkFunction = (chunk: Chunk | ErrorChunk) => void - -/** - * Base context that carries information about the current method call. - */ -export interface BaseContext { - [MIDDLEWARE_CONTEXT_SYMBOL]: true - methodName: string - originalArgs: Readonly -} - -/** - * Processing state shared between middlewares. 
- */ -export interface ProcessingState< - TParams extends SdkParams = SdkParams, - TMessageParam extends SdkMessageParam = SdkMessageParam, - TToolCall extends SdkToolCall = SdkToolCall -> { - sdkPayload?: TParams - newReqMessages?: TMessageParam[] - observer?: { - usage?: Usage - metrics?: Metrics - } - toolProcessingState?: { - pendingToolCalls?: Array - executingToolCalls?: Array<{ - sdkToolCall: TToolCall - mcpToolResponse: MCPToolResponse - }> - output?: SdkRawOutput | string - isRecursiveCall?: boolean - recursionDepth?: number - } - webSearchState?: { - results?: WebSearchResponse - } - flowControl?: { - abortController?: AbortController - abortSignal?: AbortSignal - cleanup?: () => void - } - enhancedDispatch?: (context: CompletionsContext, params: CompletionsParams) => Promise - customState?: Record -} - -/** - * Extended context for completions method. - */ -export interface CompletionsContext< - TSdkParams extends SdkParams = SdkParams, - TSdkMessageParam extends SdkMessageParam = SdkMessageParam, - TSdkToolCall extends SdkToolCall = SdkToolCall, - TSdkInstance extends SdkInstance = SdkInstance, - TRawOutput extends SdkRawOutput = SdkRawOutput, - TRawChunk extends SdkRawChunk = SdkRawChunk, - TSdkSpecificTool extends SdkTool = SdkTool -> extends BaseContext { - readonly methodName: 'completions' // 强制方法名为 'completions' - - apiClientInstance: BaseApiClient< - TSdkInstance, - TSdkParams, - TRawOutput, - TRawChunk, - TSdkMessageParam, - TSdkToolCall, - TSdkSpecificTool - > - - // --- Mutable internal state for the duration of the middleware chain --- - _internal: ProcessingState // 包含所有可变的处理状态 -} - -export interface MiddlewareAPI { - getContext: () => Ctx // Function to get the current context / 获取当前上下文的函数 - getOriginalArgs: () => Args // Function to get the original arguments of the method call / 获取方法调用原始参数的函数 -} - -/** - * Base middleware type. - */ -export type Middleware = ( - api: MiddlewareAPI -) => ( - next: (context: TContext, args: any[]) => Promise -) => (context: TContext, args: any[]) => Promise - -export type MethodMiddleware = Middleware - -/** - * Completions middleware type. 
- */ -export type CompletionsMiddleware< - TSdkParams extends SdkParams = SdkParams, - TSdkMessageParam extends SdkMessageParam = SdkMessageParam, - TSdkToolCall extends SdkToolCall = SdkToolCall, - TSdkInstance extends SdkInstance = SdkInstance, - TRawOutput extends SdkRawOutput = SdkRawOutput, - TRawChunk extends SdkRawChunk = SdkRawChunk, - TSdkSpecificTool extends SdkTool = SdkTool -> = ( - api: MiddlewareAPI< - CompletionsContext< - TSdkParams, - TSdkMessageParam, - TSdkToolCall, - TSdkInstance, - TRawOutput, - TRawChunk, - TSdkSpecificTool - >, - [CompletionsParams] - > -) => ( - next: ( - context: CompletionsContext< - TSdkParams, - TSdkMessageParam, - TSdkToolCall, - TSdkInstance, - TRawOutput, - TRawChunk, - TSdkSpecificTool - >, - params: CompletionsParams - ) => Promise -) => ( - context: CompletionsContext< - TSdkParams, - TSdkMessageParam, - TSdkToolCall, - TSdkInstance, - TRawOutput, - TRawChunk, - TSdkSpecificTool - >, - params: CompletionsParams -) => Promise - -// Re-export for convenience -export type { Chunk as OnChunkArg } from '@renderer/types/chunk' diff --git a/src/renderer/src/aiCore/legacy/middleware/utils.ts b/src/renderer/src/aiCore/legacy/middleware/utils.ts deleted file mode 100644 index 32e94e16b6..0000000000 --- a/src/renderer/src/aiCore/legacy/middleware/utils.ts +++ /dev/null @@ -1,58 +0,0 @@ -import type { ErrorChunk } from '@renderer/types/chunk' -import { ChunkType } from '@renderer/types/chunk' - -/** - * Creates an ErrorChunk object with a standardized structure. - * @param error The error object or message. - * @param chunkType The type of chunk, defaults to ChunkType.ERROR. - * @returns An ErrorChunk object. - */ -export function createErrorChunk(error: any, chunkType: ChunkType = ChunkType.ERROR): ErrorChunk { - let errorDetails: Record = {} - - if (error instanceof Error) { - errorDetails = { - message: error.message, - name: error.name, - stack: error.stack - } - } else if (typeof error === 'string') { - errorDetails = { message: error } - } else if (typeof error === 'object' && error !== null) { - errorDetails = Object.getOwnPropertyNames(error).reduce( - (acc, key) => { - acc[key] = error[key] - return acc - }, - {} as Record - ) - if (!errorDetails.message && error.toString && typeof error.toString === 'function') { - const errMsg = error.toString() - if (errMsg !== '[object Object]') { - errorDetails.message = errMsg - } - } - } - - return { - type: chunkType, - error: errorDetails - } as ErrorChunk -} - -// Helper to capitalize method names for hook construction -export function capitalize(str: string): string { - if (!str) return '' - return str.charAt(0).toUpperCase() + str.slice(1) -} - -/** - * 检查对象是否实现了AsyncIterable接口 - */ -export function isAsyncIterable(obj: unknown): obj is AsyncIterable { - return ( - obj !== null && - typeof obj === 'object' && - typeof (obj as Record)[Symbol.asyncIterator] === 'function' - ) -} diff --git a/src/renderer/src/aiCore/types/index.ts b/src/renderer/src/aiCore/types/index.ts index 41cf3b3fb8..4de59fcfea 100644 --- a/src/renderer/src/aiCore/types/index.ts +++ b/src/renderer/src/aiCore/types/index.ts @@ -29,3 +29,11 @@ export type ProviderConfig = String export type { AppProviderId, AppProviderSettingsMap, AppRuntimeConfig } from './merged' export { appProviderIds, getAllProviderIds, isRegisteredProviderId } from './merged' + +/** + * Result of completions operation + * Simple interface with getText method to retrieve the generated text + */ +export type CompletionsResult = { + getText: () => string +} 
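The replacement `CompletionsResult` above keeps only `getText()`, where the legacy type deleted from `middleware/schemas.ts` also exposed `rawOutput`, `stream`, and `controller`. A minimal sketch of what a call site can still rely on after this change (the `completeAndLog` helper and its `runCompletions` callback are hypothetical; streaming output is assumed to flow through `onChunk`-style callbacks rather than the result object):

import type { CompletionsResult } from '@renderer/aiCore/types'

// Hypothetical consumer: the final text is the only supported surface of the result;
// raw SDK output and streams are no longer reachable from here.
async function completeAndLog(runCompletions: () => Promise<CompletionsResult>): Promise<string> {
  const result = await runCompletions()
  const text = result.getText()
  console.log(text)
  return text
}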
diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index f5908becd4..f04cd88ffb 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -19,7 +19,7 @@ import { isToolUseModeFunction } from '@renderer/utils/assistant' import { isAbortError } from '@renderer/utils/error' import { purifyMarkdownImages } from '@renderer/utils/markdown' import { isPromptToolUse, isSupportedToolUse } from '@renderer/utils/mcp-tools' -import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' +import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' import { containsSupportedVariables, replacePromptVariables } from '@renderer/utils/prompt' import { NOT_SUPPORT_API_KEY_PROVIDER_TYPES, NOT_SUPPORT_API_KEY_PROVIDERS } from '@renderer/utils/provider' import { isEmpty, takeRight } from 'lodash' @@ -35,6 +35,7 @@ import { getQuickModel } from './AssistantService' import { ConversationService } from './ConversationService' +import FileManager from './FileManager' import { injectUserMessageWithKnowledgeSearchPrompt } from './KnowledgeService' import type { BlockManager } from './messageStreaming' import type { StreamProcessorCallbacks } from './StreamProcessingService' @@ -167,6 +168,20 @@ export async function fetchChatCompletion({ const AI = new AiProviderNew(assistant.model || getDefaultModel(), providerWithRotatedKey) const provider = AI.getActualProvider() + // 专用图像生成模型走 generateImage 路径 + if (isDedicatedImageGenerationModel(assistant.model || getDefaultModel())) { + if (!uiMessages || uiMessages.length === 0) { + throw new Error('uiMessages is required for dedicated image generation models') + } + await fetchImageGeneration({ + messages: uiMessages, + assistant, + onChunkReceived, + aiProvider: AI + }) + return + } + const mcpTools: MCPTool[] = [] onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED }) @@ -225,6 +240,114 @@ export async function fetchChatCompletion({ }) } +/** + * 从消息中收集图像(用于图像编辑) + * 收集用户消息中上传的图像和助手消息中生成的图像 + */ +async function collectImagesFromMessages(userMessage: Message, assistantMessage?: Message): Promise { + const images: string[] = [] + + // 收集用户消息中的图像 + const userImageBlocks = findImageBlocks(userMessage) + for (const block of userImageBlocks) { + if (block.file) { + const base64 = await FileManager.readBase64File(block.file) + const mimeType = block.file.type || 'image/png' + images.push(`data:${mimeType};base64,${base64}`) + } + } + + // 收集助手消息中的图像(用于继续编辑生成的图像) + if (assistantMessage) { + const assistantImageBlocks = findImageBlocks(assistantMessage) + for (const block of assistantImageBlocks) { + if (block.url) { + images.push(block.url) + } + } + } + + return images +} + +/** + * 独立的图像生成函数 + * 专用于 DALL-E、GPT-Image-1 等专用图像生成模型 + */ +async function fetchImageGeneration({ + messages, + assistant, + onChunkReceived, + aiProvider +}: { + messages: Message[] + assistant: Assistant + onChunkReceived: (chunk: Chunk) => void + aiProvider: AiProviderNew +}) { + onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED }) + onChunkReceived({ type: ChunkType.IMAGE_CREATED }) + + const startTime = Date.now() + + try { + // 提取 prompt 和图像 + const lastUserMessage = messages.findLast((m) => m.role === 'user') + const lastAssistantMessage = messages.findLast((m) => m.role === 'assistant') + + if (!lastUserMessage) { + throw new Error('No user message found for image generation.') + } + + const prompt = getMainTextContent(lastUserMessage) + 
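// `prompt` now holds the main text of the latest user turn. collectImagesFromMessages (above)
// returns user uploads as data URLs (via FileManager.readBase64File) plus any assistant image
// block URLs; whether that list is empty decides between editImage and generateImage below.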
const inputImages = await collectImagesFromMessages(lastUserMessage, lastAssistantMessage) + + // 调用 generateImage 或 editImage + // 使用默认图像生成配置 + const imageSize = '1024x1024' + const batchSize = 1 + + let images: string[] + if (inputImages.length > 0) { + images = await aiProvider.editImage({ + model: assistant.model!.id, + prompt: prompt || '', + inputImages, + imageSize + }) + } else { + images = await aiProvider.generateImage({ + model: assistant.model!.id, + prompt: prompt || '', + imageSize, + batchSize + }) + } + + // 发送结果 chunks + const imageType = images[0]?.startsWith('data:') ? 'base64' : 'url' + onChunkReceived({ + type: ChunkType.IMAGE_COMPLETE, + image: { type: imageType, images } + }) + + onChunkReceived({ + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }, + metrics: { + completion_tokens: 0, + time_first_token_millsec: 0, + time_completion_millsec: Date.now() - startTime + } + } + }) + } catch (error) { + onChunkReceived({ type: ChunkType.ERROR, error: error as Error }) + throw error + } +} + export async function fetchMessagesSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) { let prompt = (getStoreSetting('topicNamingPrompt') as string) || i18n.t('prompts.title') const model = getQuickModel() || assistant?.model || getDefaultModel() diff --git a/src/renderer/src/services/KnowledgeService.ts b/src/renderer/src/services/KnowledgeService.ts index ce9577c68d..7caa5dcde6 100644 --- a/src/renderer/src/services/KnowledgeService.ts +++ b/src/renderer/src/services/KnowledgeService.ts @@ -1,7 +1,6 @@ import { loggerService } from '@logger' import type { Span } from '@opentelemetry/api' -import { ModernAiProvider } from '@renderer/aiCore' -import AiProvider from '@renderer/aiCore/legacy' +import ModernAiProvider from '@renderer/aiCore' import { getMessageContent } from '@renderer/aiCore/plugins/searchOrchestrationPlugin' import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT, DEFAULT_KNOWLEDGE_THRESHOLD } from '@renderer/config/constant' import { getEmbeddingMaxContext } from '@renderer/config/embedings' @@ -36,7 +35,7 @@ const logger = loggerService.withContext('RendererKnowledgeService') export const getKnowledgeBaseParams = (base: KnowledgeBase): KnowledgeBaseParams => { const rerankProvider = getProviderByModel(base.rerankModel) const aiProvider = new ModernAiProvider(base.model) - const rerankAiProvider = new AiProvider(rerankProvider) + const rerankAiProvider = new ModernAiProvider(rerankProvider) // get preprocess provider from store instead of base.preprocessProvider const preprocessProvider = store diff --git a/src/renderer/src/services/SpanManagerService.ts b/src/renderer/src/services/SpanManagerService.ts index 71f00d85b9..bb6f65699d 100644 --- a/src/renderer/src/services/SpanManagerService.ts +++ b/src/renderer/src/services/SpanManagerService.ts @@ -5,7 +5,6 @@ import type { SpanEntity, TokenUsage } from '@mcp-trace/trace-core' import { cleanContext, endContext, getContext, startContext } from '@mcp-trace/trace-web' import type { Context, Span } from '@opentelemetry/api' import { context, SpanStatusCode, trace } from '@opentelemetry/api' -import { isAsyncIterable } from '@renderer/aiCore/legacy/middleware/utils' import { db } from '@renderer/databases' import { getEnableDeveloperMode } from '@renderer/hooks/useSettings' import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService' @@ -22,6 +21,11 @@ import type { SdkRawChunk } from '@renderer/types/sdk' const logger = 
loggerService.withContext('SpanManagerService') +// Type guard for AsyncIterable +function isAsyncIterable(obj: any): obj is AsyncIterable { + return obj != null && typeof obj === 'object' && typeof obj[Symbol.asyncIterator] === 'function' +} + class SpanManagerService { private spanMap: Map = new Map() diff --git a/src/renderer/src/services/__tests__/ApiService.test.ts b/src/renderer/src/services/__tests__/ApiService.legacy.test.ts.bak similarity index 99% rename from src/renderer/src/services/__tests__/ApiService.test.ts rename to src/renderer/src/services/__tests__/ApiService.legacy.test.ts.bak index 1e9792cdcd..ebb773291a 100644 --- a/src/renderer/src/services/__tests__/ApiService.test.ts +++ b/src/renderer/src/services/__tests__/ApiService.legacy.test.ts.bak @@ -10,7 +10,7 @@ import type OpenAI from '@cherrystudio/openai' import type { ChatCompletionChunk } from '@cherrystudio/openai/resources' import type { FunctionCall } from '@google/genai' import { FinishReason, MediaModality } from '@google/genai' -import AiProvider from '@renderer/aiCore' +import AiProvider from '@renderer/aiCore/legacy' import type { BaseApiClient, OpenAIAPIClient, ResponseChunkTransformerContext } from '@renderer/aiCore/legacy/clients' import type { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient' import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory' diff --git a/src/renderer/src/trace/dataHandler/CommonResultHandler.ts b/src/renderer/src/trace/dataHandler/CommonResultHandler.ts index 468555d948..dce5a2c688 100644 --- a/src/renderer/src/trace/dataHandler/CommonResultHandler.ts +++ b/src/renderer/src/trace/dataHandler/CommonResultHandler.ts @@ -1,6 +1,6 @@ import type { TokenUsage } from '@mcp-trace/trace-core' import type { Span } from '@opentelemetry/api' -import type { CompletionsResult } from '@renderer/aiCore/legacy/middleware/schemas' +import type { CompletionsResult } from '@renderer/aiCore/types' import { endSpan } from '@renderer/services/SpanManagerService' export class CompletionsResultHandler {