diff --git a/packages/aiCore/src/core/providers/core/ExtensionRegistry.ts b/packages/aiCore/src/core/providers/core/ExtensionRegistry.ts index 5ed6b7eb59..30ef9bc270 100644 --- a/packages/aiCore/src/core/providers/core/ExtensionRegistry.ts +++ b/packages/aiCore/src/core/providers/core/ExtensionRegistry.ts @@ -51,7 +51,7 @@ export class ExtensionRegistry { * 支持链式调用 */ register(extension: ProviderExtension): this { - const { name, aliases } = extension.config + const { name, aliases, variants } = extension.config // 检查主 ID 冲突 if (this.extensions.has(name)) { @@ -71,6 +71,19 @@ export class ExtensionRegistry { } } + // 注册变体 ID + if (variants) { + for (const variant of variants) { + const variantId = `${name}-${variant.suffix}` + if (this.aliasMap.has(variantId)) { + throw new Error( + `Provider variant ID "${variantId}" is already registered for "${this.aliasMap.get(variantId)}"` + ) + } + this.aliasMap.set(variantId, name) + } + } + return this } @@ -375,15 +388,24 @@ export class ExtensionRegistry { * @returns Provider 实例 */ async createProvider(id: string, settings?: any, explicitId?: string): Promise { - const extension = this.get(id) - if (!extension) { + // 解析 provider ID,提取基础 ID 和变体后缀 + const parsed = this.parseProviderId(id) + if (!parsed) { throw new Error(`Provider extension "${id}" not found. Did you forget to register it?`) } + const { baseId, mode: variantSuffix } = parsed + + // 获取基础 extension + const extension = this.get(baseId) + if (!extension) { + throw new Error(`Provider extension "${baseId}" not found. Did you forget to register it?`) + } + try { - // 直接委托给 Extension 的 createProvider 方法 - // Extension 负责缓存、生命周期钩子、AI SDK 注册等 - return await extension.createProvider(settings, explicitId) + // 委托给 Extension 的 createProvider 方法 + // Extension 负责缓存、生命周期钩子、AI SDK 注册、变体转换等 + return await extension.createProvider(settings, explicitId, variantSuffix) } catch (error) { throw new ProviderCreationError( `Failed to create provider "${id}"`, diff --git a/packages/aiCore/src/core/providers/core/ProviderExtension.ts b/packages/aiCore/src/core/providers/core/ProviderExtension.ts index af8894a468..c8484e1b13 100644 --- a/packages/aiCore/src/core/providers/core/ProviderExtension.ts +++ b/packages/aiCore/src/core/providers/core/ProviderExtension.ts @@ -281,35 +281,43 @@ export class ProviderExtension< /** * 计算 settings 的稳定 hash * 用于缓存 key,确保相同配置复用实例 + * + * @param settings - Provider 配置 + * @param variantSuffix - 可选的变体后缀,用于区分不同变体的缓存 */ - private computeHash(settings?: TSettings): string { - if (settings === undefined || settings === null) { - return 'default' - } + private computeHash(settings?: TSettings, variantSuffix?: string): string { + const baseHash = (() => { + if (settings === undefined || settings === null) { + return 'default' + } - // 使用 JSON.stringify 进行稳定序列化 - // 对于对象按键排序以确保一致性 - const stableStringify = (obj: any): string => { - if (obj === null || obj === undefined) return 'null' - if (typeof obj !== 'object') return JSON.stringify(obj) - if (Array.isArray(obj)) return `[${obj.map(stableStringify).join(',')}]` + // 使用 JSON.stringify 进行稳定序列化 + // 对于对象按键排序以确保一致性 + const stableStringify = (obj: any): string => { + if (obj === null || obj === undefined) return 'null' + if (typeof obj !== 'object') return JSON.stringify(obj) + if (Array.isArray(obj)) return `[${obj.map(stableStringify).join(',')}]` - const keys = Object.keys(obj).sort() - const pairs = keys.map((key) => `${JSON.stringify(key)}:${stableStringify(obj[key])}`) - return `{${pairs.join(',')}}` - } + const keys = Object.keys(obj).sort() + const pairs = keys.map((key) => `${JSON.stringify(key)}:${stableStringify(obj[key])}`) + return `{${pairs.join(',')}}` + } - const serialized = stableStringify(settings) + const serialized = stableStringify(settings) - // 使用简单的哈希函数(不需要加密级别的安全性) - let hash = 0 - for (let i = 0; i < serialized.length; i++) { - const char = serialized.charCodeAt(i) - hash = (hash << 5) - hash + char - hash = hash & hash // Convert to 32bit integer - } + // 使用简单的哈希函数(不需要加密级别的安全性) + let hash = 0 + for (let i = 0; i < serialized.length; i++) { + const char = serialized.charCodeAt(i) + hash = (hash << 5) - hash + char + hash = hash & hash // Convert to 32bit integer + } - return `${Math.abs(hash).toString(36)}` + return `${Math.abs(hash).toString(36)}` + })() + + // 如果有变体后缀,将其附加到 hash 中 + return variantSuffix ? `${baseHash}:${variantSuffix}` : baseHash } /** @@ -329,17 +337,29 @@ export class ProviderExtension< * * @param settings - Provider 配置 * @param explicitId - 可选的显式 ID,用于 AI SDK 注册 + * @param variantSuffix - 可选的变体后缀,用于应用变体转换 * @returns Provider 实例 */ - async createProvider(settings?: TSettings, explicitId?: string): Promise { + async createProvider(settings?: TSettings, explicitId?: string, variantSuffix?: string): Promise { + // 验证变体后缀(如果提供) + if (variantSuffix) { + const variant = this.getVariant(variantSuffix) + if (!variant) { + throw new Error( + `ProviderExtension "${this.config.name}": variant "${variantSuffix}" not found. ` + + `Available variants: ${this.config.variants?.map((v) => v.suffix).join(', ') || 'none'}` + ) + } + } + // 合并 default options const mergedSettings = deepMergeObjects( (this.config.defaultOptions || {}) as any, (settings || {}) as any ) as TSettings - // 计算 hash - const hash = this.computeHash(mergedSettings) + // 计算 hash(包含变体后缀) + const hash = this.computeHash(mergedSettings, variantSuffix) // 检查缓存 const cachedInstance = this.instances.get(hash) @@ -350,12 +370,12 @@ export class ProviderExtension< // 执行 onBeforeCreate 钩子 await this.executeHook('onBeforeCreate', mergedSettings) - // 创建新实例 - let provider: ProviderV3 + // 创建基础 provider 实例 + let baseProvider: ProviderV3 if (this.config.create) { // 使用直接创建函数 - provider = await Promise.resolve(this.config.create(mergedSettings)) + baseProvider = await Promise.resolve(this.config.create(mergedSettings)) } else if (this.config.import && this.config.creatorFunctionName) { // 动态导入 const module = await this.config.import() @@ -367,29 +387,45 @@ export class ProviderExtension< ) } - provider = await Promise.resolve(creatorFn(mergedSettings)) + baseProvider = await Promise.resolve(creatorFn(mergedSettings)) } else { throw new Error(`ProviderExtension "${this.config.name}": cannot create provider, invalid configuration`) } - const typedProvider = provider as TProvider - - // 执行 onAfterCreate 钩子 - await this.executeHook('onAfterCreate', mergedSettings, typedProvider) - - // 缓存实例 - this.instances.set(hash, typedProvider) - this.settingsHashes.set(hash, mergedSettings) - - // 注册到 AI SDK(如果提供了 explicitId) - if (explicitId) { - this.registerToAiSdk(typedProvider, explicitId) + // 应用变体转换(如果提供了变体后缀) + let finalProvider: TProvider + if (variantSuffix) { + const variant = this.getVariant(variantSuffix)! + // 应用变体的 transform 函数 + finalProvider = (await Promise.resolve(variant.transform(baseProvider as TProvider, mergedSettings))) as TProvider } else { - // 使用默认 ID: name:hash - this.registerToAiSdk(typedProvider, `${this.config.name}:${hash}`) + finalProvider = baseProvider as TProvider } - return typedProvider + // 执行 onAfterCreate 钩子 + await this.executeHook('onAfterCreate', mergedSettings, finalProvider) + + // 缓存实例 + this.instances.set(hash, finalProvider) + this.settingsHashes.set(hash, mergedSettings) + + // 确定注册 ID + const registrationId = (() => { + if (explicitId) { + return explicitId + } + // 如果是变体,使用 name-suffix:hash 格式 + if (variantSuffix) { + return `${this.config.name}-${variantSuffix}:${hash}` + } + // 否则使用 name:hash + return `${this.config.name}:${hash}` + })() + + // 注册到 AI SDK + this.registerToAiSdk(finalProvider, registrationId) + + return finalProvider } /** diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index ea6fab734d..2753133118 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -24,7 +24,6 @@ import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter' import LegacyAiProvider from './legacy/index' import type { CompletionsParams, CompletionsResult } from './legacy/middleware/schemas' import { buildPlugins } from './plugins/PluginBuilder' -import { createAiSdkProvider } from './provider/factory' import { adaptProvider, getActualProvider, @@ -32,7 +31,7 @@ import { prepareSpecialProviderConfig, providerToAiSdkConfig } from './provider/providerConfig' -import type { AiSdkConfig } from './types' +import type { ProviderConfig } from './types' import type { AiSdkMiddlewareConfig } from './types/middlewareConfig' const logger = loggerService.withContext('ModernAiProvider') @@ -46,7 +45,7 @@ export type ModernAiProviderConfig = AiSdkMiddlewareConfig & { export default class ModernAiProvider { private legacyProvider: LegacyAiProvider - private config?: AiSdkConfig + private config?: ProviderConfig private actualProvider: Provider private model?: Model private localProvider: Awaited | null = null @@ -133,7 +132,7 @@ export default class ModernAiProvider { if (!this.config) { throw new Error('Provider config is undefined; cannot proceed with completions') } - if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.options.endpoint)) { + if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.providerSettings.endpoint)) { providerConfig.isImageGenerationEndpoint = true } // 准备特殊配置 @@ -141,7 +140,7 @@ export default class ModernAiProvider { // 提前创建本地 provider 实例 if (!this.localProvider) { - this.localProvider = await createAiSdkProvider(this.config) + // this.localProvider = await createAiSdkProvider(this.config) // TODO: Update provider creation } if (!this.localProvider) { @@ -321,7 +320,7 @@ export default class ModernAiProvider { const plugins = buildPlugins(config) // 用构建好的插件数组创建executor - const executor = createExecutor(this.config!.providerId, this.config!.options, plugins) + const executor = createExecutor(this.config!.providerId, this.config!.providerSettings, plugins) // 创建带有中间件的执行器 if (config.onChunk) { @@ -406,7 +405,7 @@ export default class ModernAiProvider { } // 调用新 AI SDK 的图像生成功能 - const executor = createExecutor(this.config!.providerId, this.config!.options, []) + const executor = createExecutor(this.config!.providerId, this.config!.providerSettings, []) const result = await executor.generateImage({ model, ...imageParams @@ -504,7 +503,7 @@ export default class ModernAiProvider { // 确保本地provider已创建 if (!this.localProvider && this.config) { - this.localProvider = await createAiSdkProvider(this.config) + // this.localProvider = await createAiSdkProvider(this.config) // TODO: Update provider creation if (!this.localProvider) { throw new Error('Local provider not created') } @@ -537,7 +536,7 @@ export default class ModernAiProvider { ...(signal && { abortSignal: signal }) } - const executor = createExecutor(this.config!.providerId, this.config!.options, []) + const executor = createExecutor(this.config!.providerId, this.config!.providerSettings, []) const result = await executor.generateImage({ model: model, // 直接使用 model ID 字符串,由 executor 内部解析 ...aiSdkParams diff --git a/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts b/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts index b1d8e34fcd..e5fbed8eda 100644 --- a/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts +++ b/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts @@ -174,9 +174,11 @@ describe('Copilot responses routing', () => { const config = providerToAiSdkConfig(provider, createModel('gpt-5-codex', 'GPT-5-CODEX')) expect(config.providerId).toBe('github-copilot-openai-compatible') - expect(config.options.headers?.['Editor-Version']).toBe(COPILOT_EDITOR_VERSION) - expect(config.options.headers?.['Copilot-Integration-Id']).toBe(COPILOT_DEFAULT_HEADERS['Copilot-Integration-Id']) - expect(config.options.headers?.['copilot-vision-request']).toBe('true') + expect(config.providerSettings.headers?.['Editor-Version']).toBe(COPILOT_EDITOR_VERSION) + expect(config.providerSettings.headers?.['Copilot-Integration-Id']).toBe( + COPILOT_DEFAULT_HEADERS['Copilot-Integration-Id'] + ) + expect(config.providerSettings.headers?.['copilot-vision-request']).toBe('true') }) it('uses the Copilot provider for other models and keeps headers', () => { @@ -184,8 +186,10 @@ describe('Copilot responses routing', () => { const config = providerToAiSdkConfig(provider, createModel('gpt-4')) expect(config.providerId).toBe('github-copilot-openai-compatible') - expect(config.options.headers?.['Editor-Version']).toBe(COPILOT_DEFAULT_HEADERS['Editor-Version']) - expect(config.options.headers?.['Copilot-Integration-Id']).toBe(COPILOT_DEFAULT_HEADERS['Copilot-Integration-Id']) + expect(config.providerSettings.headers?.['Editor-Version']).toBe(COPILOT_DEFAULT_HEADERS['Editor-Version']) + expect(config.providerSettings.headers?.['Copilot-Integration-Id']).toBe( + COPILOT_DEFAULT_HEADERS['Copilot-Integration-Id'] + ) }) }) @@ -388,7 +392,7 @@ describe('Stream options includeUsage configuration', () => { const provider = createOpenAIProvider() const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai')) - expect(config.options.includeUsage).toBeUndefined() + expect(config.providerSettings.includeUsage).toBeUndefined() }) it('uses includeUsage from settings when set to true', () => { @@ -406,7 +410,7 @@ describe('Stream options includeUsage configuration', () => { const provider = createOpenAIProvider() const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai')) - expect(config.options.includeUsage).toBe(true) + expect(config.providerSettings.includeUsage).toBe(true) }) it('uses includeUsage from settings when set to false', () => { @@ -424,7 +428,7 @@ describe('Stream options includeUsage configuration', () => { const provider = createOpenAIProvider() const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai')) - expect(config.options.includeUsage).toBe(false) + expect(config.providerSettings.includeUsage).toBe(false) }) it('respects includeUsage setting for non-supporting providers', () => { @@ -455,7 +459,7 @@ describe('Stream options includeUsage configuration', () => { const config = providerToAiSdkConfig(testProvider, createModel('gpt-4', 'GPT-4', 'test')) // Even though setting is true, provider doesn't support it, so includeUsage should be undefined - expect(config.options.includeUsage).toBeUndefined() + expect(config.providerSettings.includeUsage).toBeUndefined() }) it('uses includeUsage from settings for Copilot provider when set to false', () => { @@ -473,7 +477,7 @@ describe('Stream options includeUsage configuration', () => { const provider = createCopilotProvider() const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot')) - expect(config.options.includeUsage).toBe(false) + expect(config.providerSettings.includeUsage).toBe(false) expect(config.providerId).toBe('github-copilot-openai-compatible') }) @@ -492,7 +496,7 @@ describe('Stream options includeUsage configuration', () => { const provider = createCopilotProvider() const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot')) - expect(config.options.includeUsage).toBe(true) + expect(config.providerSettings.includeUsage).toBe(true) expect(config.providerId).toBe('github-copilot-openai-compatible') }) @@ -511,7 +515,7 @@ describe('Stream options includeUsage configuration', () => { const provider = createCopilotProvider() const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot')) - expect(config.options.includeUsage).toBeUndefined() + expect(config.providerSettings.includeUsage).toBeUndefined() expect(config.providerId).toBe('github-copilot-openai-compatible') }) }) @@ -540,21 +544,21 @@ describe('Azure OpenAI traditional API routing', () => { const config = providerToAiSdkConfig(provider, createModel('gpt-4o', 'GPT-4o', provider.id)) expect(config.providerId).toBe('azure') - expect(config.options.apiVersion).toBe('2024-02-15-preview') - expect(config.options.useDeploymentBasedUrls).toBe(true) + expect(config.providerSettings.apiVersion).toBe('2024-02-15-preview') + expect(config.providerSettings.useDeploymentBasedUrls).toBe(true) }) it('does not force deployment-based URLs for apiVersion v1/preview', () => { const v1Provider = createAzureProvider('v1') const v1Config = providerToAiSdkConfig(v1Provider, createModel('gpt-4o', 'GPT-4o', v1Provider.id)) expect(v1Config.providerId).toBe('azure-responses') - expect(v1Config.options.apiVersion).toBe('v1') - expect(v1Config.options.useDeploymentBasedUrls).toBeUndefined() + expect(v1Config.providerSettings.apiVersion).toBe('v1') + expect(v1Config.providerSettings.useDeploymentBasedUrls).toBeUndefined() const previewProvider = createAzureProvider('preview') const previewConfig = providerToAiSdkConfig(previewProvider, createModel('gpt-4o', 'GPT-4o', previewProvider.id)) expect(previewConfig.providerId).toBe('azure-responses') - expect(previewConfig.options.apiVersion).toBe('preview') - expect(previewConfig.options.useDeploymentBasedUrls).toBeUndefined() + expect(previewConfig.providerSettings.apiVersion).toBe('preview') + expect(previewConfig.providerSettings.useDeploymentBasedUrls).toBeUndefined() }) }) diff --git a/src/renderer/src/aiCore/provider/extensions/index.ts b/src/renderer/src/aiCore/provider/extensions/index.ts index bafc8363d5..9f60eb664a 100644 --- a/src/renderer/src/aiCore/provider/extensions/index.ts +++ b/src/renderer/src/aiCore/provider/extensions/index.ts @@ -3,8 +3,21 @@ * 用于支持运行时动态导入的 AI Providers */ -import type { ProviderV2 } from '@ai-sdk/provider' -import { ProviderExtension, type ProviderExtensionConfig } from '@cherrystudio/ai-core/provider' +import { type AmazonBedrockProviderSettings, createAmazonBedrock } from '@ai-sdk/amazon-bedrock' +import { type AnthropicProviderSettings, createAnthropic } from '@ai-sdk/anthropic' +import { type CerebrasProviderSettings, createCerebras } from '@ai-sdk/cerebras' +import { createGateway, type GatewayProviderSettings } from '@ai-sdk/gateway' +import { createVertexAnthropic } from '@ai-sdk/google-vertex/anthropic/edge' +import { createVertex, type GoogleVertexProviderSettings } from '@ai-sdk/google-vertex/edge' +import { createHuggingFace, type HuggingFaceProviderSettings } from '@ai-sdk/huggingface' +import { createMistral, type MistralProviderSettings } from '@ai-sdk/mistral' +import { createPerplexity, type PerplexityProviderSettings } from '@ai-sdk/perplexity' +import type { ProviderV2, ProviderV3 } from '@ai-sdk/provider' +import { ExtensionStorage, ProviderExtension, type ProviderExtensionConfig } from '@cherrystudio/ai-core/provider' +import { + createGitHubCopilotOpenAICompatible, + type GitHubCopilotProviderSettings +} from '@opeoginni/github-copilot-openai-compatible' import { wrapProvider } from 'ai' import type { OllamaProviderSettings } from 'ollama-ai-provider-v2' import { createOllama } from 'ollama-ai-provider-v2' @@ -16,9 +29,13 @@ export const GoogleVertexExtension = ProviderExtension.create({ name: 'google-vertex', aliases: ['vertexai'] as const, supportsImageGeneration: true, - import: () => import('@ai-sdk/google-vertex/edge'), - creatorFunctionName: 'createVertex' -} as const satisfies ProviderExtensionConfig) + create: createVertex +} as const satisfies ProviderExtensionConfig< + GoogleVertexProviderSettings, + ExtensionStorage, + ProviderV3, + 'google-vertex' +>) /** * Google Vertex AI Anthropic Extension @@ -27,9 +44,13 @@ export const GoogleVertexAnthropicExtension = ProviderExtension.create({ name: 'google-vertex-anthropic', aliases: ['vertexai-anthropic'] as const, supportsImageGeneration: true, - import: () => import('@ai-sdk/google-vertex/anthropic/edge'), - creatorFunctionName: 'createVertexAnthropic' -} as const satisfies ProviderExtensionConfig) + create: createVertexAnthropic +} as const satisfies ProviderExtensionConfig< + GoogleVertexProviderSettings, + ExtensionStorage, + ProviderV3, + 'google-vertex-anthropic' +>) /** * Azure AI Anthropic Extension @@ -38,9 +59,13 @@ export const AzureAnthropicExtension = ProviderExtension.create({ name: 'azure-anthropic', aliases: ['azure-anthropic'] as const, supportsImageGeneration: false, - import: () => import('@ai-sdk/anthropic'), - creatorFunctionName: 'createAnthropic' -} as const satisfies ProviderExtensionConfig) + create: createAnthropic +} as const satisfies ProviderExtensionConfig< + AnthropicProviderSettings, + ExtensionStorage, + ProviderV3, + 'azure-anthropic' +>) /** * GitHub Copilot Extension @@ -49,9 +74,16 @@ export const GitHubCopilotExtension = ProviderExtension.create({ name: 'github-copilot-openai-compatible', aliases: ['copilot', 'github-copilot'] as const, supportsImageGeneration: false, - import: () => import('@opeoginni/github-copilot-openai-compatible'), - creatorFunctionName: 'createGitHubCopilotOpenAICompatible' -} as const satisfies ProviderExtensionConfig) + create: (options?: GitHubCopilotProviderSettings) => { + const provider = createGitHubCopilotOpenAICompatible(options) as unknown as ProviderV2 + return wrapProvider({ provider, languageModelMiddleware: [] }) + } +} as const satisfies ProviderExtensionConfig< + GitHubCopilotProviderSettings, + ExtensionStorage, + ProviderV3, + 'github-copilot-openai-compatible' +>) /** * Amazon Bedrock Extension @@ -60,9 +92,8 @@ export const BedrockExtension = ProviderExtension.create({ name: 'bedrock', aliases: ['aws-bedrock'] as const, supportsImageGeneration: true, - import: () => import('@ai-sdk/amazon-bedrock'), - creatorFunctionName: 'createAmazonBedrock' -} as const satisfies ProviderExtensionConfig) + create: createAmazonBedrock +} as const satisfies ProviderExtensionConfig) /** * Perplexity Extension @@ -70,9 +101,8 @@ export const BedrockExtension = ProviderExtension.create({ export const PerplexityExtension = ProviderExtension.create({ name: 'perplexity', supportsImageGeneration: false, - import: () => import('@ai-sdk/perplexity'), - creatorFunctionName: 'createPerplexity' -} as const satisfies ProviderExtensionConfig) + create: createPerplexity +} as const satisfies ProviderExtensionConfig) /** * Mistral Extension @@ -81,9 +111,8 @@ export const MistralExtension = ProviderExtension.create({ name: 'mistral', aliases: ['mistral'] as const, supportsImageGeneration: false, - import: () => import('@ai-sdk/mistral'), - creatorFunctionName: 'createMistral' -} as const satisfies ProviderExtensionConfig) + create: createMistral +} as const satisfies ProviderExtensionConfig) /** * HuggingFace Extension @@ -92,9 +121,8 @@ export const HuggingFaceExtension = ProviderExtension.create({ name: 'huggingface', aliases: ['hf', 'hugging-face'] as const, supportsImageGeneration: true, - import: () => import('@ai-sdk/huggingface'), - creatorFunctionName: 'createHuggingFace' -} as const satisfies ProviderExtensionConfig) + create: createHuggingFace +} as const satisfies ProviderExtensionConfig) /** * Vercel AI Gateway Extension @@ -103,9 +131,8 @@ export const GatewayExtension = ProviderExtension.create({ name: 'gateway', aliases: ['ai-gateway'] as const, supportsImageGeneration: true, - import: () => import('@ai-sdk/gateway'), - creatorFunctionName: 'createGateway' -} as const satisfies ProviderExtensionConfig) + create: createGateway +} as const satisfies ProviderExtensionConfig) /** * Cerebras Extension @@ -113,9 +140,8 @@ export const GatewayExtension = ProviderExtension.create({ export const CerebrasExtension = ProviderExtension.create({ name: 'cerebras', supportsImageGeneration: false, - import: () => import('@ai-sdk/cerebras'), - creatorFunctionName: 'createCerebras' -} as const satisfies ProviderExtensionConfig) + create: createCerebras +} as const satisfies ProviderExtensionConfig) /** * Ollama Extension @@ -127,7 +153,7 @@ export const OllamaExtension = ProviderExtension.create({ const provider = createOllama(options) as ProviderV2 return wrapProvider({ provider, languageModelMiddleware: [] }) } -} as const satisfies ProviderExtensionConfig) +} as const satisfies ProviderExtensionConfig) /** * 所有项目特定的 Extensions diff --git a/src/renderer/src/aiCore/provider/providerConfig.ts b/src/renderer/src/aiCore/provider/providerConfig.ts index f16275fc68..022f72e2f6 100644 --- a/src/renderer/src/aiCore/provider/providerConfig.ts +++ b/src/renderer/src/aiCore/provider/providerConfig.ts @@ -1,6 +1,5 @@ -import { extensionRegistry, formatPrivateKey, hasProviderConfig } from '@cherrystudio/ai-core/provider' +import { formatPrivateKey, hasProviderConfig } from '@cherrystudio/ai-core/provider' import type { AppProviderId } from '@renderer/aiCore/types' -import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models' import { getAwsBedrockAccessKeyId, getAwsBedrockApiKey, @@ -13,7 +12,6 @@ import { getProviderByModel } from '@renderer/services/AssistantService' import { getProviderById } from '@renderer/services/ProviderService' import store from '@renderer/store' import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types' -import type { OpenAICompletionsStreamOptions } from '@renderer/types/aiCoreTypes' import { formatApiHost, formatAzureOpenAIApiHost, @@ -36,7 +34,7 @@ import { import { defaultAppHeaders } from '@shared/utils' import { cloneDeep, isEmpty } from 'lodash' -import type { AiSdkConfigRuntime } from '../types' +import type { ProviderConfig } from '../types' import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config' import { azureAnthropicProviderCreator } from './config/azure-anthropic' import { COPILOT_DEFAULT_HEADERS } from './constants' @@ -146,155 +144,56 @@ export function adaptProvider({ provider, model }: { provider: Provider; model?: * * @param actualProvider - Cherry Studio provider配置 * @param model - 模型配置 - * @returns 类型安全的 AI SDK 配置 + * @returns 类型安全的 Provider 配置 */ -export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfigRuntime { - const aiSdkProviderId: AppProviderId = getAiSdkProviderId(actualProvider) - - // 构建基础配置 +export function providerToAiSdkConfig(actualProvider: Provider, model: Model): ProviderConfig { + const aiSdkProviderId = getAiSdkProviderId(actualProvider) const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost) - const baseConfig = { - baseURL: baseURL, - apiKey: actualProvider.apiKey - } - let includeUsage: OpenAICompletionsStreamOptions['include_usage'] = undefined - if (isSupportStreamOptionsProvider(actualProvider)) { - includeUsage = store.getState().settings.openAI?.streamOptions?.includeUsage + + // 构建上下文 + const ctx: BuilderContext = { + actualProvider, + model, + baseConfig: { + baseURL, + apiKey: actualProvider.apiKey + }, + endpoint, + aiSdkProviderId } - const isCopilotProvider = actualProvider.id === SystemProviderIds.copilot - if (isCopilotProvider) { - const storedHeaders = store.getState().copilot.defaultHeaders ?? {} - const options = { - ...baseConfig, - headers: { - ...COPILOT_DEFAULT_HEADERS, - ...storedHeaders, - ...actualProvider.extra_headers - }, - name: actualProvider.id, - includeUsage - } - - return { - providerId: 'github-copilot-openai-compatible', - options - } + // 路由到专门的构建器 + if (actualProvider.id === SystemProviderIds.copilot) { + return buildCopilotConfig(ctx) } if (isOllamaProvider(actualProvider)) { - return { - providerId: 'ollama', - options: { - ...baseConfig, - headers: { - ...actualProvider.extra_headers, - Authorization: !isEmpty(baseConfig.apiKey) ? `Bearer ${baseConfig.apiKey}` : undefined - } - } - } + return buildOllamaConfig(ctx) } - // 处理OpenAI模式 - const extraOptions: any = {} - extraOptions.endpoint = endpoint - - // 解析 provider ID,提取 base ID 和 mode - const parsed = extensionRegistry.parseProviderId(aiSdkProviderId) - if (parsed?.mode) { - // 自动设置 mode(如 openai-chat → mode: 'chat', azure-responses → mode: 'responses') - extraOptions.mode = parsed.mode - } - - // 特殊处理:OpenAI responses 模式(基于 provider.type) - if (actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) { - extraOptions.mode = 'responses' - } else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && actualProvider.type === 'openai')) { - // 确保 OpenAI 和 CherryIN(type=openai) 使用 chat 模式 - extraOptions.mode = 'chat' - } - - extraOptions.headers = { - ...defaultAppHeaders(), - ...actualProvider.extra_headers - } - - if (aiSdkProviderId === 'openai') { - const headers = extraOptions.headers as Record - headers['X-Api-Key'] = baseConfig.apiKey - } if (isAzureOpenAIProvider(actualProvider)) { - const apiVersion = actualProvider.apiVersion?.trim() - if (apiVersion) { - extraOptions.apiVersion = apiVersion - if (!['preview', 'v1'].includes(apiVersion)) { - extraOptions.useDeploymentBasedUrls = true - } - } + return buildAzureConfig(ctx) } - // bedrock if (aiSdkProviderId === 'bedrock') { - const authType = getAwsBedrockAuthType() - extraOptions.region = getAwsBedrockRegion() - - if (authType === 'apiKey') { - extraOptions.apiKey = getAwsBedrockApiKey() - } else { - extraOptions.accessKeyId = getAwsBedrockAccessKeyId() - extraOptions.secretAccessKey = getAwsBedrockSecretAccessKey() - } + return buildBedrockConfig(ctx) } - // google-vertex + if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') { - if (!isVertexAIConfigured()) { - throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.') - } - const { project, location, googleCredentials } = createVertexProvider(actualProvider) - extraOptions.project = project - extraOptions.location = location - extraOptions.googleCredentials = { - ...googleCredentials, - privateKey: formatPrivateKey(googleCredentials.privateKey) - } - baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models' + return buildVertexConfig(ctx) } - // cherryin if (aiSdkProviderId === 'cherryin') { - if (model.endpoint_type) { - extraOptions.endpointType = model.endpoint_type - } - // CherryIN API Host - const cherryinProvider = getProviderById(SystemProviderIds.cherryin) - if (cherryinProvider) { - extraOptions.anthropicBaseURL = cherryinProvider.anthropicApiHost + '/v1' - extraOptions.geminiBaseURL = cherryinProvider.apiHost + '/v1beta/models' - } + return buildCherryinConfig(ctx) } + // 有 SDK 支持的 provider if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') { - const options = { - ...baseConfig, - ...extraOptions - } - return { - providerId: aiSdkProviderId, - options - } + return buildGenericProviderConfig(ctx) } - // 否则fallback到openai-compatible - return { - providerId: 'openai-compatible', - options: { - baseURL: baseConfig.baseURL, - apiKey: baseConfig.apiKey, - name: actualProvider.id, - ...extraOptions, - includeUsage - } - } + // 默认 fallback 到 openai-compatible + return buildOpenAICompatibleConfig(ctx) } /** @@ -317,10 +216,7 @@ export function isModernSdkSupported(provider: Provider): boolean { /** * 准备特殊provider的配置,主要用于异步处理的配置 */ -export async function prepareSpecialProviderConfig( - provider: Provider, - config: ReturnType -) { +export async function prepareSpecialProviderConfig(provider: Provider, config: ProviderConfig) { switch (provider.id) { case 'copilot': { const defaultHeaders = store.getState().copilot.defaultHeaders ?? {} @@ -329,15 +225,17 @@ export async function prepareSpecialProviderConfig( ...defaultHeaders } const { token } = await window.api.copilot.getToken(headers) - config.options.apiKey = token - config.options.headers = { + const settings = config.providerSettings as any + settings.apiKey = token + settings.headers = { ...headers, - ...config.options.headers + ...settings.headers } break } case 'cherryai': { - config.options.fetch = async (url, options) => { + const settings = config.providerSettings as any + settings.fetch = async (url: string, options: any) => { // 在这里对最终参数进行签名 const signature = await window.api.cherryai.generateSignature({ method: 'POST', @@ -358,10 +256,11 @@ export async function prepareSpecialProviderConfig( case 'anthropic': { if (provider.authType === 'oauth') { const oauthToken = await window.api.anthropic_oauth.getAccessToken() - config.options = { - ...config.options, + const settings = config.providerSettings as any + config.providerSettings = { + ...settings, headers: { - ...(config.options.headers ? config.options.headers : {}), + ...(settings.headers ? settings.headers : {}), 'Content-Type': 'application/json', 'anthropic-version': '2023-06-01', Authorization: `Bearer ${oauthToken}` @@ -374,3 +273,253 @@ export async function prepareSpecialProviderConfig( } return config } + +/** + * 基础配置 + */ +interface BaseConfig { + baseURL: string + apiKey: string +} + +/** + * 构建器上下文 + */ +interface BuilderContext { + actualProvider: Provider + model: Model + baseConfig: BaseConfig + endpoint?: string + aiSdkProviderId: AppProviderId +} + +/** + * GitHub Copilot 配置构建器 + */ +function buildCopilotConfig(ctx: BuilderContext): ProviderConfig<'github-copilot-openai-compatible'> { + const storedHeaders = store.getState().copilot.defaultHeaders ?? {} + + return { + providerId: 'github-copilot-openai-compatible', + providerSettings: { + ...ctx.baseConfig, + headers: { + ...COPILOT_DEFAULT_HEADERS, + ...storedHeaders, + ...ctx.actualProvider.extra_headers + }, + name: ctx.actualProvider.id + } + } +} + +/** + * Ollama 配置构建器 + */ +function buildOllamaConfig(ctx: BuilderContext): ProviderConfig<'ollama'> { + const headers: ProviderConfig<'ollama'>['providerSettings']['headers'] = { + ...ctx.actualProvider.extra_headers + } + + if (!isEmpty(ctx.baseConfig.apiKey)) { + headers.Authorization = `Bearer ${ctx.baseConfig.apiKey}` + } + + return { + providerId: 'ollama', + providerSettings: { + ...ctx.baseConfig, + headers + } + } +} + +/** + * AWS Bedrock 配置构建器 + */ +function buildBedrockConfig(ctx: BuilderContext): ProviderConfig<'bedrock'> { + const authType = getAwsBedrockAuthType() + const region = getAwsBedrockRegion() + + if (authType === 'apiKey') { + return { + providerId: 'bedrock', + providerSettings: { + ...ctx.baseConfig, + region, + apiKey: getAwsBedrockApiKey() + } + } + } + + return { + providerId: 'bedrock', + providerSettings: { + ...ctx.baseConfig, + region, + accessKeyId: getAwsBedrockAccessKeyId(), + secretAccessKey: getAwsBedrockSecretAccessKey() + } + } +} + +/** + * Google Vertex AI 配置构建器 + */ +function buildVertexConfig( + ctx: BuilderContext +): ProviderConfig<'google-vertex'> | ProviderConfig<'google-vertex-anthropic'> { + if (!isVertexAIConfigured()) { + throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.') + } + + const { project, location, googleCredentials } = createVertexProvider(ctx.actualProvider) + const isAnthropic = ctx.aiSdkProviderId === 'google-vertex-anthropic' + + const baseURL = ctx.baseConfig.baseURL + (isAnthropic ? '/publishers/anthropic/models' : '/publishers/google') + + if (isAnthropic) { + return { + providerId: 'google-vertex-anthropic', + providerSettings: { + ...ctx.baseConfig, + baseURL, + project, + location, + googleCredentials: { + ...googleCredentials, + privateKey: formatPrivateKey(googleCredentials.privateKey) + } + } + } + } + + return { + providerId: 'google-vertex', + providerSettings: { + ...ctx.baseConfig, + baseURL, + project, + location, + googleCredentials: { + ...googleCredentials, + privateKey: formatPrivateKey(googleCredentials.privateKey) + } + } + } +} + +/** + * CherryIN 配置构建器 + */ +function buildCherryinConfig(ctx: BuilderContext): ProviderConfig<'cherryin'> { + const cherryinProvider = getProviderById(SystemProviderIds.cherryin) + + return { + providerId: 'cherryin', + providerSettings: { + ...ctx.baseConfig, + endpointType: ctx.model.endpoint_type, + anthropicBaseURL: cherryinProvider ? cherryinProvider.anthropicApiHost + '/v1' : undefined, + geminiBaseURL: cherryinProvider ? cherryinProvider.apiHost + '/v1beta/models' : undefined, + headers: { + ...defaultAppHeaders(), + ...ctx.actualProvider.extra_headers + } + } + } +} + +/** + * Azure OpenAI 配置构建器 + */ +function buildAzureConfig(ctx: BuilderContext): ProviderConfig<'azure'> | ProviderConfig<'azure-responses'> { + const apiVersion = ctx.actualProvider.apiVersion?.trim() + + // 根据 apiVersion 决定使用 azure 还是 azure-responses + const useResponsesMode = apiVersion && ['preview', 'v1'].includes(apiVersion) + + const providerSettings: Record = { + ...ctx.baseConfig, + endpoint: ctx.endpoint, + headers: { + ...defaultAppHeaders(), + ...ctx.actualProvider.extra_headers + } + } + + if (apiVersion) { + providerSettings.apiVersion = apiVersion + // 只有非 preview/v1 版本才使用 deployment-based URLs + if (!useResponsesMode) { + providerSettings.useDeploymentBasedUrls = true + } + } + + if (useResponsesMode) { + return { + providerId: 'azure-responses', + providerSettings + } + } + + return { + providerId: 'azure', + providerSettings + } +} + +/** + * 构建通用的 OpenAI-compatible 或特定 provider 的额外选项 + */ +function buildCommonOptions(ctx: BuilderContext) { + const options: Record = { + endpoint: ctx.endpoint, + headers: { + ...defaultAppHeaders(), + ...ctx.actualProvider.extra_headers + } + } + + // OpenAI 特殊 header + if (ctx.aiSdkProviderId === 'openai') { + options.headers['X-Api-Key'] = ctx.baseConfig.apiKey + } + + return options +} + +/** + * OpenAI-compatible 配置构建器 + */ +function buildOpenAICompatibleConfig(ctx: BuilderContext): ProviderConfig<'openai-compatible'> { + const commonOptions = buildCommonOptions(ctx) + const includeUsage = isSupportStreamOptionsProvider(ctx.actualProvider) + ? store.getState().settings.openAI?.streamOptions?.includeUsage + : undefined + + return { + providerId: 'openai-compatible', + providerSettings: { + ...ctx.baseConfig, + ...commonOptions, + name: ctx.actualProvider.id, + includeUsage + } + } +} + +/** + * 通用 provider 配置构建器(有 SDK 支持的 provider) + */ +function buildGenericProviderConfig(ctx: BuilderContext): ProviderConfig { + const commonOptions = buildCommonOptions(ctx) + + return { + providerId: ctx.aiSdkProviderId, + providerSettings: { + ...ctx.baseConfig, + ...commonOptions + } + } +} diff --git a/src/renderer/src/aiCore/types/index.ts b/src/renderer/src/aiCore/types/index.ts index 5c82f5f7c0..21b561970f 100644 --- a/src/renderer/src/aiCore/types/index.ts +++ b/src/renderer/src/aiCore/types/index.ts @@ -7,85 +7,38 @@ * TODO: We should separate them clearly. Keep renderer only types in renderer, and main only types in main, and shared types in shared. */ -import type { AppProviderId, AppProviderSettingsMap } from './merged' +import type { AppProviderId, AppRuntimeConfig } from './merged' /** - * Generic AI SDK configuration with compile-time type safety + * Provider 配置(不含 plugins) + * 基于 RuntimeConfig,用于构建 provider 实例的基础配置 * * 🎯 Zero maintenance! Auto-extracts types from core and project extensions. * - * @typeParam T - The specific provider ID type for type-safe options + * @typeParam T - The specific provider ID type for type-safe settings * * @example * ```ts * // Type-safe config for core provider - * const config1: AiSdkConfig<'openai'> = { + * const config1: ProviderConfig<'openai'> = { * providerId: 'openai', - * options: { apiKey: '...', baseURL: '...' } // ✅ Typed as OpenAIProviderSettings + * providerSettings: { apiKey: '...', baseURL: '...' } // ✅ Typed as OpenAIProviderSettings * } * * // Type-safe config for project provider - * const config2: AiSdkConfig<'google-vertex'> = { + * const config2: ProviderConfig<'google-vertex'> = { * providerId: 'google-vertex', - * options: { ... } // ✅ Typed as GoogleVertexProviderSettings + * providerSettings: { ... } // ✅ Typed as GoogleVertexProviderSettings * } * * // Type-safe config with alias - * const config3: AiSdkConfig<'oai'> = { + * const config3: ProviderConfig<'oai'> = { * providerId: 'oai', - * options: { apiKey: '...' } // ✅ Same type as 'openai' + * providerSettings: { apiKey: '...' } // ✅ Same type as 'openai' * } * ``` */ -export type AiSdkConfig = { - providerId: T - options: AppProviderSettingsMap[T] -} - -/** - * Runtime-safe AI SDK configuration for gradual migration - * Use this when provider ID is not known at compile time - * - * 使用联合类型而不是 any,提供更好的类型安全性 - * - * @example - * ```ts - * function createConfig(providerId: AppProviderId): AiSdkConfigRuntime { - * return { - * providerId, - * options: buildOptions(providerId) // ✅ 类型安全:options 必须是某个 provider 的 settings - * } - * } - * ``` - */ -export type AiSdkConfigRuntime = { - providerId: AppProviderId - options: AppProviderSettingsMap[AppProviderId] -} - -/** - * Type guard for runtime validation of AiSdkConfig - * - * @param config - Unknown value to validate - * @returns true if config is a valid AiSdkConfigRuntime - * - * @example - * ```ts - * if (isValidAiSdkConfig(someConfig)) { - * // someConfig is now typed as AiSdkConfigRuntime - * await createAiSdkProvider(someConfig) - * } - * ``` - */ -export function isValidAiSdkConfig(config: unknown): config is AiSdkConfigRuntime { - if (!config || typeof config !== 'object') return false - - const c = config as Record - - return ( - typeof c.providerId === 'string' && c.providerId.length > 0 && typeof c.options === 'object' && c.options !== null - ) -} +export type ProviderConfig = Omit, 'plugins'> export type { AppProviderId, AppProviderSettingsMap } from './merged' export { appProviderIds, getAllProviderIds, isRegisteredProviderId } from './merged'