mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-02-21 18:14:45 +08:00
refactor:type inference
This commit is contained in:
parent
42ff133732
commit
e3351097a9
@ -51,7 +51,7 @@ export class ExtensionRegistry {
|
||||
* 支持链式调用
|
||||
*/
|
||||
register(extension: ProviderExtension<any, any, any>): this {
|
||||
const { name, aliases } = extension.config
|
||||
const { name, aliases, variants } = extension.config
|
||||
|
||||
// 检查主 ID 冲突
|
||||
if (this.extensions.has(name)) {
|
||||
@ -71,6 +71,19 @@ export class ExtensionRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
// 注册变体 ID
|
||||
if (variants) {
|
||||
for (const variant of variants) {
|
||||
const variantId = `${name}-${variant.suffix}`
|
||||
if (this.aliasMap.has(variantId)) {
|
||||
throw new Error(
|
||||
`Provider variant ID "${variantId}" is already registered for "${this.aliasMap.get(variantId)}"`
|
||||
)
|
||||
}
|
||||
this.aliasMap.set(variantId, name)
|
||||
}
|
||||
}
|
||||
|
||||
return this
|
||||
}
|
||||
|
||||
@ -375,15 +388,24 @@ export class ExtensionRegistry {
|
||||
* @returns Provider 实例
|
||||
*/
|
||||
async createProvider(id: string, settings?: any, explicitId?: string): Promise<ProviderV3> {
|
||||
const extension = this.get(id)
|
||||
if (!extension) {
|
||||
// 解析 provider ID,提取基础 ID 和变体后缀
|
||||
const parsed = this.parseProviderId(id)
|
||||
if (!parsed) {
|
||||
throw new Error(`Provider extension "${id}" not found. Did you forget to register it?`)
|
||||
}
|
||||
|
||||
const { baseId, mode: variantSuffix } = parsed
|
||||
|
||||
// 获取基础 extension
|
||||
const extension = this.get(baseId)
|
||||
if (!extension) {
|
||||
throw new Error(`Provider extension "${baseId}" not found. Did you forget to register it?`)
|
||||
}
|
||||
|
||||
try {
|
||||
// 直接委托给 Extension 的 createProvider 方法
|
||||
// Extension 负责缓存、生命周期钩子、AI SDK 注册等
|
||||
return await extension.createProvider(settings, explicitId)
|
||||
// 委托给 Extension 的 createProvider 方法
|
||||
// Extension 负责缓存、生命周期钩子、AI SDK 注册、变体转换等
|
||||
return await extension.createProvider(settings, explicitId, variantSuffix)
|
||||
} catch (error) {
|
||||
throw new ProviderCreationError(
|
||||
`Failed to create provider "${id}"`,
|
||||
|
||||
@ -281,35 +281,43 @@ export class ProviderExtension<
|
||||
/**
|
||||
* 计算 settings 的稳定 hash
|
||||
* 用于缓存 key,确保相同配置复用实例
|
||||
*
|
||||
* @param settings - Provider 配置
|
||||
* @param variantSuffix - 可选的变体后缀,用于区分不同变体的缓存
|
||||
*/
|
||||
private computeHash(settings?: TSettings): string {
|
||||
if (settings === undefined || settings === null) {
|
||||
return 'default'
|
||||
}
|
||||
private computeHash(settings?: TSettings, variantSuffix?: string): string {
|
||||
const baseHash = (() => {
|
||||
if (settings === undefined || settings === null) {
|
||||
return 'default'
|
||||
}
|
||||
|
||||
// 使用 JSON.stringify 进行稳定序列化
|
||||
// 对于对象按键排序以确保一致性
|
||||
const stableStringify = (obj: any): string => {
|
||||
if (obj === null || obj === undefined) return 'null'
|
||||
if (typeof obj !== 'object') return JSON.stringify(obj)
|
||||
if (Array.isArray(obj)) return `[${obj.map(stableStringify).join(',')}]`
|
||||
// 使用 JSON.stringify 进行稳定序列化
|
||||
// 对于对象按键排序以确保一致性
|
||||
const stableStringify = (obj: any): string => {
|
||||
if (obj === null || obj === undefined) return 'null'
|
||||
if (typeof obj !== 'object') return JSON.stringify(obj)
|
||||
if (Array.isArray(obj)) return `[${obj.map(stableStringify).join(',')}]`
|
||||
|
||||
const keys = Object.keys(obj).sort()
|
||||
const pairs = keys.map((key) => `${JSON.stringify(key)}:${stableStringify(obj[key])}`)
|
||||
return `{${pairs.join(',')}}`
|
||||
}
|
||||
const keys = Object.keys(obj).sort()
|
||||
const pairs = keys.map((key) => `${JSON.stringify(key)}:${stableStringify(obj[key])}`)
|
||||
return `{${pairs.join(',')}}`
|
||||
}
|
||||
|
||||
const serialized = stableStringify(settings)
|
||||
const serialized = stableStringify(settings)
|
||||
|
||||
// 使用简单的哈希函数(不需要加密级别的安全性)
|
||||
let hash = 0
|
||||
for (let i = 0; i < serialized.length; i++) {
|
||||
const char = serialized.charCodeAt(i)
|
||||
hash = (hash << 5) - hash + char
|
||||
hash = hash & hash // Convert to 32bit integer
|
||||
}
|
||||
// 使用简单的哈希函数(不需要加密级别的安全性)
|
||||
let hash = 0
|
||||
for (let i = 0; i < serialized.length; i++) {
|
||||
const char = serialized.charCodeAt(i)
|
||||
hash = (hash << 5) - hash + char
|
||||
hash = hash & hash // Convert to 32bit integer
|
||||
}
|
||||
|
||||
return `${Math.abs(hash).toString(36)}`
|
||||
return `${Math.abs(hash).toString(36)}`
|
||||
})()
|
||||
|
||||
// 如果有变体后缀,将其附加到 hash 中
|
||||
return variantSuffix ? `${baseHash}:${variantSuffix}` : baseHash
|
||||
}
|
||||
|
||||
/**
|
||||
@ -329,17 +337,29 @@ export class ProviderExtension<
|
||||
*
|
||||
* @param settings - Provider 配置
|
||||
* @param explicitId - 可选的显式 ID,用于 AI SDK 注册
|
||||
* @param variantSuffix - 可选的变体后缀,用于应用变体转换
|
||||
* @returns Provider 实例
|
||||
*/
|
||||
async createProvider(settings?: TSettings, explicitId?: string): Promise<TProvider> {
|
||||
async createProvider(settings?: TSettings, explicitId?: string, variantSuffix?: string): Promise<TProvider> {
|
||||
// 验证变体后缀(如果提供)
|
||||
if (variantSuffix) {
|
||||
const variant = this.getVariant(variantSuffix)
|
||||
if (!variant) {
|
||||
throw new Error(
|
||||
`ProviderExtension "${this.config.name}": variant "${variantSuffix}" not found. ` +
|
||||
`Available variants: ${this.config.variants?.map((v) => v.suffix).join(', ') || 'none'}`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// 合并 default options
|
||||
const mergedSettings = deepMergeObjects(
|
||||
(this.config.defaultOptions || {}) as any,
|
||||
(settings || {}) as any
|
||||
) as TSettings
|
||||
|
||||
// 计算 hash
|
||||
const hash = this.computeHash(mergedSettings)
|
||||
// 计算 hash(包含变体后缀)
|
||||
const hash = this.computeHash(mergedSettings, variantSuffix)
|
||||
|
||||
// 检查缓存
|
||||
const cachedInstance = this.instances.get(hash)
|
||||
@ -350,12 +370,12 @@ export class ProviderExtension<
|
||||
// 执行 onBeforeCreate 钩子
|
||||
await this.executeHook('onBeforeCreate', mergedSettings)
|
||||
|
||||
// 创建新实例
|
||||
let provider: ProviderV3
|
||||
// 创建基础 provider 实例
|
||||
let baseProvider: ProviderV3
|
||||
|
||||
if (this.config.create) {
|
||||
// 使用直接创建函数
|
||||
provider = await Promise.resolve(this.config.create(mergedSettings))
|
||||
baseProvider = await Promise.resolve(this.config.create(mergedSettings))
|
||||
} else if (this.config.import && this.config.creatorFunctionName) {
|
||||
// 动态导入
|
||||
const module = await this.config.import()
|
||||
@ -367,29 +387,45 @@ export class ProviderExtension<
|
||||
)
|
||||
}
|
||||
|
||||
provider = await Promise.resolve(creatorFn(mergedSettings))
|
||||
baseProvider = await Promise.resolve(creatorFn(mergedSettings))
|
||||
} else {
|
||||
throw new Error(`ProviderExtension "${this.config.name}": cannot create provider, invalid configuration`)
|
||||
}
|
||||
|
||||
const typedProvider = provider as TProvider
|
||||
|
||||
// 执行 onAfterCreate 钩子
|
||||
await this.executeHook('onAfterCreate', mergedSettings, typedProvider)
|
||||
|
||||
// 缓存实例
|
||||
this.instances.set(hash, typedProvider)
|
||||
this.settingsHashes.set(hash, mergedSettings)
|
||||
|
||||
// 注册到 AI SDK(如果提供了 explicitId)
|
||||
if (explicitId) {
|
||||
this.registerToAiSdk(typedProvider, explicitId)
|
||||
// 应用变体转换(如果提供了变体后缀)
|
||||
let finalProvider: TProvider
|
||||
if (variantSuffix) {
|
||||
const variant = this.getVariant(variantSuffix)!
|
||||
// 应用变体的 transform 函数
|
||||
finalProvider = (await Promise.resolve(variant.transform(baseProvider as TProvider, mergedSettings))) as TProvider
|
||||
} else {
|
||||
// 使用默认 ID: name:hash
|
||||
this.registerToAiSdk(typedProvider, `${this.config.name}:${hash}`)
|
||||
finalProvider = baseProvider as TProvider
|
||||
}
|
||||
|
||||
return typedProvider
|
||||
// 执行 onAfterCreate 钩子
|
||||
await this.executeHook('onAfterCreate', mergedSettings, finalProvider)
|
||||
|
||||
// 缓存实例
|
||||
this.instances.set(hash, finalProvider)
|
||||
this.settingsHashes.set(hash, mergedSettings)
|
||||
|
||||
// 确定注册 ID
|
||||
const registrationId = (() => {
|
||||
if (explicitId) {
|
||||
return explicitId
|
||||
}
|
||||
// 如果是变体,使用 name-suffix:hash 格式
|
||||
if (variantSuffix) {
|
||||
return `${this.config.name}-${variantSuffix}:${hash}`
|
||||
}
|
||||
// 否则使用 name:hash
|
||||
return `${this.config.name}:${hash}`
|
||||
})()
|
||||
|
||||
// 注册到 AI SDK
|
||||
this.registerToAiSdk(finalProvider, registrationId)
|
||||
|
||||
return finalProvider
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -24,7 +24,6 @@ import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
|
||||
import LegacyAiProvider from './legacy/index'
|
||||
import type { CompletionsParams, CompletionsResult } from './legacy/middleware/schemas'
|
||||
import { buildPlugins } from './plugins/PluginBuilder'
|
||||
import { createAiSdkProvider } from './provider/factory'
|
||||
import {
|
||||
adaptProvider,
|
||||
getActualProvider,
|
||||
@ -32,7 +31,7 @@ import {
|
||||
prepareSpecialProviderConfig,
|
||||
providerToAiSdkConfig
|
||||
} from './provider/providerConfig'
|
||||
import type { AiSdkConfig } from './types'
|
||||
import type { ProviderConfig } from './types'
|
||||
import type { AiSdkMiddlewareConfig } from './types/middlewareConfig'
|
||||
|
||||
const logger = loggerService.withContext('ModernAiProvider')
|
||||
@ -46,7 +45,7 @@ export type ModernAiProviderConfig = AiSdkMiddlewareConfig & {
|
||||
|
||||
export default class ModernAiProvider {
|
||||
private legacyProvider: LegacyAiProvider
|
||||
private config?: AiSdkConfig
|
||||
private config?: ProviderConfig
|
||||
private actualProvider: Provider
|
||||
private model?: Model
|
||||
private localProvider: Awaited<AiSdkProvider> | null = null
|
||||
@ -133,7 +132,7 @@ export default class ModernAiProvider {
|
||||
if (!this.config) {
|
||||
throw new Error('Provider config is undefined; cannot proceed with completions')
|
||||
}
|
||||
if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.options.endpoint)) {
|
||||
if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.providerSettings.endpoint)) {
|
||||
providerConfig.isImageGenerationEndpoint = true
|
||||
}
|
||||
// 准备特殊配置
|
||||
@ -141,7 +140,7 @@ export default class ModernAiProvider {
|
||||
|
||||
// 提前创建本地 provider 实例
|
||||
if (!this.localProvider) {
|
||||
this.localProvider = await createAiSdkProvider(this.config)
|
||||
// this.localProvider = await createAiSdkProvider(this.config) // TODO: Update provider creation
|
||||
}
|
||||
|
||||
if (!this.localProvider) {
|
||||
@ -321,7 +320,7 @@ export default class ModernAiProvider {
|
||||
const plugins = buildPlugins(config)
|
||||
|
||||
// 用构建好的插件数组创建executor
|
||||
const executor = createExecutor(this.config!.providerId, this.config!.options, plugins)
|
||||
const executor = createExecutor(this.config!.providerId, this.config!.providerSettings, plugins)
|
||||
|
||||
// 创建带有中间件的执行器
|
||||
if (config.onChunk) {
|
||||
@ -406,7 +405,7 @@ export default class ModernAiProvider {
|
||||
}
|
||||
|
||||
// 调用新 AI SDK 的图像生成功能
|
||||
const executor = createExecutor(this.config!.providerId, this.config!.options, [])
|
||||
const executor = createExecutor(this.config!.providerId, this.config!.providerSettings, [])
|
||||
const result = await executor.generateImage({
|
||||
model,
|
||||
...imageParams
|
||||
@ -504,7 +503,7 @@ export default class ModernAiProvider {
|
||||
|
||||
// 确保本地provider已创建
|
||||
if (!this.localProvider && this.config) {
|
||||
this.localProvider = await createAiSdkProvider(this.config)
|
||||
// this.localProvider = await createAiSdkProvider(this.config) // TODO: Update provider creation
|
||||
if (!this.localProvider) {
|
||||
throw new Error('Local provider not created')
|
||||
}
|
||||
@ -537,7 +536,7 @@ export default class ModernAiProvider {
|
||||
...(signal && { abortSignal: signal })
|
||||
}
|
||||
|
||||
const executor = createExecutor(this.config!.providerId, this.config!.options, [])
|
||||
const executor = createExecutor(this.config!.providerId, this.config!.providerSettings, [])
|
||||
const result = await executor.generateImage({
|
||||
model: model, // 直接使用 model ID 字符串,由 executor 内部解析
|
||||
...aiSdkParams
|
||||
|
||||
@ -174,9 +174,11 @@ describe('Copilot responses routing', () => {
|
||||
const config = providerToAiSdkConfig(provider, createModel('gpt-5-codex', 'GPT-5-CODEX'))
|
||||
|
||||
expect(config.providerId).toBe('github-copilot-openai-compatible')
|
||||
expect(config.options.headers?.['Editor-Version']).toBe(COPILOT_EDITOR_VERSION)
|
||||
expect(config.options.headers?.['Copilot-Integration-Id']).toBe(COPILOT_DEFAULT_HEADERS['Copilot-Integration-Id'])
|
||||
expect(config.options.headers?.['copilot-vision-request']).toBe('true')
|
||||
expect(config.providerSettings.headers?.['Editor-Version']).toBe(COPILOT_EDITOR_VERSION)
|
||||
expect(config.providerSettings.headers?.['Copilot-Integration-Id']).toBe(
|
||||
COPILOT_DEFAULT_HEADERS['Copilot-Integration-Id']
|
||||
)
|
||||
expect(config.providerSettings.headers?.['copilot-vision-request']).toBe('true')
|
||||
})
|
||||
|
||||
it('uses the Copilot provider for other models and keeps headers', () => {
|
||||
@ -184,8 +186,10 @@ describe('Copilot responses routing', () => {
|
||||
const config = providerToAiSdkConfig(provider, createModel('gpt-4'))
|
||||
|
||||
expect(config.providerId).toBe('github-copilot-openai-compatible')
|
||||
expect(config.options.headers?.['Editor-Version']).toBe(COPILOT_DEFAULT_HEADERS['Editor-Version'])
|
||||
expect(config.options.headers?.['Copilot-Integration-Id']).toBe(COPILOT_DEFAULT_HEADERS['Copilot-Integration-Id'])
|
||||
expect(config.providerSettings.headers?.['Editor-Version']).toBe(COPILOT_DEFAULT_HEADERS['Editor-Version'])
|
||||
expect(config.providerSettings.headers?.['Copilot-Integration-Id']).toBe(
|
||||
COPILOT_DEFAULT_HEADERS['Copilot-Integration-Id']
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@ -388,7 +392,7 @@ describe('Stream options includeUsage configuration', () => {
|
||||
const provider = createOpenAIProvider()
|
||||
const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai'))
|
||||
|
||||
expect(config.options.includeUsage).toBeUndefined()
|
||||
expect(config.providerSettings.includeUsage).toBeUndefined()
|
||||
})
|
||||
|
||||
it('uses includeUsage from settings when set to true', () => {
|
||||
@ -406,7 +410,7 @@ describe('Stream options includeUsage configuration', () => {
|
||||
const provider = createOpenAIProvider()
|
||||
const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai'))
|
||||
|
||||
expect(config.options.includeUsage).toBe(true)
|
||||
expect(config.providerSettings.includeUsage).toBe(true)
|
||||
})
|
||||
|
||||
it('uses includeUsage from settings when set to false', () => {
|
||||
@ -424,7 +428,7 @@ describe('Stream options includeUsage configuration', () => {
|
||||
const provider = createOpenAIProvider()
|
||||
const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai'))
|
||||
|
||||
expect(config.options.includeUsage).toBe(false)
|
||||
expect(config.providerSettings.includeUsage).toBe(false)
|
||||
})
|
||||
|
||||
it('respects includeUsage setting for non-supporting providers', () => {
|
||||
@ -455,7 +459,7 @@ describe('Stream options includeUsage configuration', () => {
|
||||
const config = providerToAiSdkConfig(testProvider, createModel('gpt-4', 'GPT-4', 'test'))
|
||||
|
||||
// Even though setting is true, provider doesn't support it, so includeUsage should be undefined
|
||||
expect(config.options.includeUsage).toBeUndefined()
|
||||
expect(config.providerSettings.includeUsage).toBeUndefined()
|
||||
})
|
||||
|
||||
it('uses includeUsage from settings for Copilot provider when set to false', () => {
|
||||
@ -473,7 +477,7 @@ describe('Stream options includeUsage configuration', () => {
|
||||
const provider = createCopilotProvider()
|
||||
const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot'))
|
||||
|
||||
expect(config.options.includeUsage).toBe(false)
|
||||
expect(config.providerSettings.includeUsage).toBe(false)
|
||||
expect(config.providerId).toBe('github-copilot-openai-compatible')
|
||||
})
|
||||
|
||||
@ -492,7 +496,7 @@ describe('Stream options includeUsage configuration', () => {
|
||||
const provider = createCopilotProvider()
|
||||
const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot'))
|
||||
|
||||
expect(config.options.includeUsage).toBe(true)
|
||||
expect(config.providerSettings.includeUsage).toBe(true)
|
||||
expect(config.providerId).toBe('github-copilot-openai-compatible')
|
||||
})
|
||||
|
||||
@ -511,7 +515,7 @@ describe('Stream options includeUsage configuration', () => {
|
||||
const provider = createCopilotProvider()
|
||||
const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot'))
|
||||
|
||||
expect(config.options.includeUsage).toBeUndefined()
|
||||
expect(config.providerSettings.includeUsage).toBeUndefined()
|
||||
expect(config.providerId).toBe('github-copilot-openai-compatible')
|
||||
})
|
||||
})
|
||||
@ -540,21 +544,21 @@ describe('Azure OpenAI traditional API routing', () => {
|
||||
const config = providerToAiSdkConfig(provider, createModel('gpt-4o', 'GPT-4o', provider.id))
|
||||
|
||||
expect(config.providerId).toBe('azure')
|
||||
expect(config.options.apiVersion).toBe('2024-02-15-preview')
|
||||
expect(config.options.useDeploymentBasedUrls).toBe(true)
|
||||
expect(config.providerSettings.apiVersion).toBe('2024-02-15-preview')
|
||||
expect(config.providerSettings.useDeploymentBasedUrls).toBe(true)
|
||||
})
|
||||
|
||||
it('does not force deployment-based URLs for apiVersion v1/preview', () => {
|
||||
const v1Provider = createAzureProvider('v1')
|
||||
const v1Config = providerToAiSdkConfig(v1Provider, createModel('gpt-4o', 'GPT-4o', v1Provider.id))
|
||||
expect(v1Config.providerId).toBe('azure-responses')
|
||||
expect(v1Config.options.apiVersion).toBe('v1')
|
||||
expect(v1Config.options.useDeploymentBasedUrls).toBeUndefined()
|
||||
expect(v1Config.providerSettings.apiVersion).toBe('v1')
|
||||
expect(v1Config.providerSettings.useDeploymentBasedUrls).toBeUndefined()
|
||||
|
||||
const previewProvider = createAzureProvider('preview')
|
||||
const previewConfig = providerToAiSdkConfig(previewProvider, createModel('gpt-4o', 'GPT-4o', previewProvider.id))
|
||||
expect(previewConfig.providerId).toBe('azure-responses')
|
||||
expect(previewConfig.options.apiVersion).toBe('preview')
|
||||
expect(previewConfig.options.useDeploymentBasedUrls).toBeUndefined()
|
||||
expect(previewConfig.providerSettings.apiVersion).toBe('preview')
|
||||
expect(previewConfig.providerSettings.useDeploymentBasedUrls).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
@ -3,8 +3,21 @@
|
||||
* 用于支持运行时动态导入的 AI Providers
|
||||
*/
|
||||
|
||||
import type { ProviderV2 } from '@ai-sdk/provider'
|
||||
import { ProviderExtension, type ProviderExtensionConfig } from '@cherrystudio/ai-core/provider'
|
||||
import { type AmazonBedrockProviderSettings, createAmazonBedrock } from '@ai-sdk/amazon-bedrock'
|
||||
import { type AnthropicProviderSettings, createAnthropic } from '@ai-sdk/anthropic'
|
||||
import { type CerebrasProviderSettings, createCerebras } from '@ai-sdk/cerebras'
|
||||
import { createGateway, type GatewayProviderSettings } from '@ai-sdk/gateway'
|
||||
import { createVertexAnthropic } from '@ai-sdk/google-vertex/anthropic/edge'
|
||||
import { createVertex, type GoogleVertexProviderSettings } from '@ai-sdk/google-vertex/edge'
|
||||
import { createHuggingFace, type HuggingFaceProviderSettings } from '@ai-sdk/huggingface'
|
||||
import { createMistral, type MistralProviderSettings } from '@ai-sdk/mistral'
|
||||
import { createPerplexity, type PerplexityProviderSettings } from '@ai-sdk/perplexity'
|
||||
import type { ProviderV2, ProviderV3 } from '@ai-sdk/provider'
|
||||
import { ExtensionStorage, ProviderExtension, type ProviderExtensionConfig } from '@cherrystudio/ai-core/provider'
|
||||
import {
|
||||
createGitHubCopilotOpenAICompatible,
|
||||
type GitHubCopilotProviderSettings
|
||||
} from '@opeoginni/github-copilot-openai-compatible'
|
||||
import { wrapProvider } from 'ai'
|
||||
import type { OllamaProviderSettings } from 'ollama-ai-provider-v2'
|
||||
import { createOllama } from 'ollama-ai-provider-v2'
|
||||
@ -16,9 +29,13 @@ export const GoogleVertexExtension = ProviderExtension.create({
|
||||
name: 'google-vertex',
|
||||
aliases: ['vertexai'] as const,
|
||||
supportsImageGeneration: true,
|
||||
import: () => import('@ai-sdk/google-vertex/edge'),
|
||||
creatorFunctionName: 'createVertex'
|
||||
} as const satisfies ProviderExtensionConfig<any, any, any, 'google-vertex'>)
|
||||
create: createVertex
|
||||
} as const satisfies ProviderExtensionConfig<
|
||||
GoogleVertexProviderSettings,
|
||||
ExtensionStorage,
|
||||
ProviderV3,
|
||||
'google-vertex'
|
||||
>)
|
||||
|
||||
/**
|
||||
* Google Vertex AI Anthropic Extension
|
||||
@ -27,9 +44,13 @@ export const GoogleVertexAnthropicExtension = ProviderExtension.create({
|
||||
name: 'google-vertex-anthropic',
|
||||
aliases: ['vertexai-anthropic'] as const,
|
||||
supportsImageGeneration: true,
|
||||
import: () => import('@ai-sdk/google-vertex/anthropic/edge'),
|
||||
creatorFunctionName: 'createVertexAnthropic'
|
||||
} as const satisfies ProviderExtensionConfig<any, any, any, 'google-vertex-anthropic'>)
|
||||
create: createVertexAnthropic
|
||||
} as const satisfies ProviderExtensionConfig<
|
||||
GoogleVertexProviderSettings,
|
||||
ExtensionStorage,
|
||||
ProviderV3,
|
||||
'google-vertex-anthropic'
|
||||
>)
|
||||
|
||||
/**
|
||||
* Azure AI Anthropic Extension
|
||||
@ -38,9 +59,13 @@ export const AzureAnthropicExtension = ProviderExtension.create({
|
||||
name: 'azure-anthropic',
|
||||
aliases: ['azure-anthropic'] as const,
|
||||
supportsImageGeneration: false,
|
||||
import: () => import('@ai-sdk/anthropic'),
|
||||
creatorFunctionName: 'createAnthropic'
|
||||
} as const satisfies ProviderExtensionConfig<any, any, any, 'azure-anthropic'>)
|
||||
create: createAnthropic
|
||||
} as const satisfies ProviderExtensionConfig<
|
||||
AnthropicProviderSettings,
|
||||
ExtensionStorage,
|
||||
ProviderV3,
|
||||
'azure-anthropic'
|
||||
>)
|
||||
|
||||
/**
|
||||
* GitHub Copilot Extension
|
||||
@ -49,9 +74,16 @@ export const GitHubCopilotExtension = ProviderExtension.create({
|
||||
name: 'github-copilot-openai-compatible',
|
||||
aliases: ['copilot', 'github-copilot'] as const,
|
||||
supportsImageGeneration: false,
|
||||
import: () => import('@opeoginni/github-copilot-openai-compatible'),
|
||||
creatorFunctionName: 'createGitHubCopilotOpenAICompatible'
|
||||
} as const satisfies ProviderExtensionConfig<any, any, any, 'github-copilot-openai-compatible'>)
|
||||
create: (options?: GitHubCopilotProviderSettings) => {
|
||||
const provider = createGitHubCopilotOpenAICompatible(options) as unknown as ProviderV2
|
||||
return wrapProvider({ provider, languageModelMiddleware: [] })
|
||||
}
|
||||
} as const satisfies ProviderExtensionConfig<
|
||||
GitHubCopilotProviderSettings,
|
||||
ExtensionStorage,
|
||||
ProviderV3,
|
||||
'github-copilot-openai-compatible'
|
||||
>)
|
||||
|
||||
/**
|
||||
* Amazon Bedrock Extension
|
||||
@ -60,9 +92,8 @@ export const BedrockExtension = ProviderExtension.create({
|
||||
name: 'bedrock',
|
||||
aliases: ['aws-bedrock'] as const,
|
||||
supportsImageGeneration: true,
|
||||
import: () => import('@ai-sdk/amazon-bedrock'),
|
||||
creatorFunctionName: 'createAmazonBedrock'
|
||||
} as const satisfies ProviderExtensionConfig<any, any, any, 'bedrock'>)
|
||||
create: createAmazonBedrock
|
||||
} as const satisfies ProviderExtensionConfig<AmazonBedrockProviderSettings, ExtensionStorage, ProviderV3, 'bedrock'>)
|
||||
|
||||
/**
|
||||
* Perplexity Extension
|
||||
@ -70,9 +101,8 @@ export const BedrockExtension = ProviderExtension.create({
|
||||
export const PerplexityExtension = ProviderExtension.create({
|
||||
name: 'perplexity',
|
||||
supportsImageGeneration: false,
|
||||
import: () => import('@ai-sdk/perplexity'),
|
||||
creatorFunctionName: 'createPerplexity'
|
||||
} as const satisfies ProviderExtensionConfig<any, any, any, 'perplexity'>)
|
||||
create: createPerplexity
|
||||
} as const satisfies ProviderExtensionConfig<PerplexityProviderSettings, ExtensionStorage, ProviderV3, 'perplexity'>)
|
||||
|
||||
/**
|
||||
* Mistral Extension
|
||||
@ -81,9 +111,8 @@ export const MistralExtension = ProviderExtension.create({
|
||||
name: 'mistral',
|
||||
aliases: ['mistral'] as const,
|
||||
supportsImageGeneration: false,
|
||||
import: () => import('@ai-sdk/mistral'),
|
||||
creatorFunctionName: 'createMistral'
|
||||
} as const satisfies ProviderExtensionConfig<any, any, any, 'mistral'>)
|
||||
create: createMistral
|
||||
} as const satisfies ProviderExtensionConfig<MistralProviderSettings, ExtensionStorage, ProviderV3, 'mistral'>)
|
||||
|
||||
/**
|
||||
* HuggingFace Extension
|
||||
@ -92,9 +121,8 @@ export const HuggingFaceExtension = ProviderExtension.create({
|
||||
name: 'huggingface',
|
||||
aliases: ['hf', 'hugging-face'] as const,
|
||||
supportsImageGeneration: true,
|
||||
import: () => import('@ai-sdk/huggingface'),
|
||||
creatorFunctionName: 'createHuggingFace'
|
||||
} as const satisfies ProviderExtensionConfig<any, any, any, 'huggingface'>)
|
||||
create: createHuggingFace
|
||||
} as const satisfies ProviderExtensionConfig<HuggingFaceProviderSettings, ExtensionStorage, ProviderV3, 'huggingface'>)
|
||||
|
||||
/**
|
||||
* Vercel AI Gateway Extension
|
||||
@ -103,9 +131,8 @@ export const GatewayExtension = ProviderExtension.create({
|
||||
name: 'gateway',
|
||||
aliases: ['ai-gateway'] as const,
|
||||
supportsImageGeneration: true,
|
||||
import: () => import('@ai-sdk/gateway'),
|
||||
creatorFunctionName: 'createGateway'
|
||||
} as const satisfies ProviderExtensionConfig<any, any, any, 'gateway'>)
|
||||
create: createGateway
|
||||
} as const satisfies ProviderExtensionConfig<GatewayProviderSettings, ExtensionStorage, ProviderV3, 'gateway'>)
|
||||
|
||||
/**
|
||||
* Cerebras Extension
|
||||
@ -113,9 +140,8 @@ export const GatewayExtension = ProviderExtension.create({
|
||||
export const CerebrasExtension = ProviderExtension.create({
|
||||
name: 'cerebras',
|
||||
supportsImageGeneration: false,
|
||||
import: () => import('@ai-sdk/cerebras'),
|
||||
creatorFunctionName: 'createCerebras'
|
||||
} as const satisfies ProviderExtensionConfig<any, any, any, 'cerebras'>)
|
||||
create: createCerebras
|
||||
} as const satisfies ProviderExtensionConfig<CerebrasProviderSettings, ExtensionStorage, ProviderV3, 'cerebras'>)
|
||||
|
||||
/**
|
||||
* Ollama Extension
|
||||
@ -127,7 +153,7 @@ export const OllamaExtension = ProviderExtension.create({
|
||||
const provider = createOllama(options) as ProviderV2
|
||||
return wrapProvider({ provider, languageModelMiddleware: [] })
|
||||
}
|
||||
} as const satisfies ProviderExtensionConfig<OllamaProviderSettings, any, any, 'ollama'>)
|
||||
} as const satisfies ProviderExtensionConfig<OllamaProviderSettings, ExtensionStorage, ProviderV3, 'ollama'>)
|
||||
|
||||
/**
|
||||
* 所有项目特定的 Extensions
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import { extensionRegistry, formatPrivateKey, hasProviderConfig } from '@cherrystudio/ai-core/provider'
|
||||
import { formatPrivateKey, hasProviderConfig } from '@cherrystudio/ai-core/provider'
|
||||
import type { AppProviderId } from '@renderer/aiCore/types'
|
||||
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
|
||||
import {
|
||||
getAwsBedrockAccessKeyId,
|
||||
getAwsBedrockApiKey,
|
||||
@ -13,7 +12,6 @@ import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { getProviderById } from '@renderer/services/ProviderService'
|
||||
import store from '@renderer/store'
|
||||
import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types'
|
||||
import type { OpenAICompletionsStreamOptions } from '@renderer/types/aiCoreTypes'
|
||||
import {
|
||||
formatApiHost,
|
||||
formatAzureOpenAIApiHost,
|
||||
@ -36,7 +34,7 @@ import {
|
||||
import { defaultAppHeaders } from '@shared/utils'
|
||||
import { cloneDeep, isEmpty } from 'lodash'
|
||||
|
||||
import type { AiSdkConfigRuntime } from '../types'
|
||||
import type { ProviderConfig } from '../types'
|
||||
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
|
||||
import { azureAnthropicProviderCreator } from './config/azure-anthropic'
|
||||
import { COPILOT_DEFAULT_HEADERS } from './constants'
|
||||
@ -146,155 +144,56 @@ export function adaptProvider({ provider, model }: { provider: Provider; model?:
|
||||
*
|
||||
* @param actualProvider - Cherry Studio provider配置
|
||||
* @param model - 模型配置
|
||||
* @returns 类型安全的 AI SDK 配置
|
||||
* @returns 类型安全的 Provider 配置
|
||||
*/
|
||||
export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfigRuntime {
|
||||
const aiSdkProviderId: AppProviderId = getAiSdkProviderId(actualProvider)
|
||||
|
||||
// 构建基础配置
|
||||
export function providerToAiSdkConfig(actualProvider: Provider, model: Model): ProviderConfig {
|
||||
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
||||
const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost)
|
||||
const baseConfig = {
|
||||
baseURL: baseURL,
|
||||
apiKey: actualProvider.apiKey
|
||||
}
|
||||
let includeUsage: OpenAICompletionsStreamOptions['include_usage'] = undefined
|
||||
if (isSupportStreamOptionsProvider(actualProvider)) {
|
||||
includeUsage = store.getState().settings.openAI?.streamOptions?.includeUsage
|
||||
|
||||
// 构建上下文
|
||||
const ctx: BuilderContext = {
|
||||
actualProvider,
|
||||
model,
|
||||
baseConfig: {
|
||||
baseURL,
|
||||
apiKey: actualProvider.apiKey
|
||||
},
|
||||
endpoint,
|
||||
aiSdkProviderId
|
||||
}
|
||||
|
||||
const isCopilotProvider = actualProvider.id === SystemProviderIds.copilot
|
||||
if (isCopilotProvider) {
|
||||
const storedHeaders = store.getState().copilot.defaultHeaders ?? {}
|
||||
const options = {
|
||||
...baseConfig,
|
||||
headers: {
|
||||
...COPILOT_DEFAULT_HEADERS,
|
||||
...storedHeaders,
|
||||
...actualProvider.extra_headers
|
||||
},
|
||||
name: actualProvider.id,
|
||||
includeUsage
|
||||
}
|
||||
|
||||
return {
|
||||
providerId: 'github-copilot-openai-compatible',
|
||||
options
|
||||
}
|
||||
// 路由到专门的构建器
|
||||
if (actualProvider.id === SystemProviderIds.copilot) {
|
||||
return buildCopilotConfig(ctx)
|
||||
}
|
||||
|
||||
if (isOllamaProvider(actualProvider)) {
|
||||
return {
|
||||
providerId: 'ollama',
|
||||
options: {
|
||||
...baseConfig,
|
||||
headers: {
|
||||
...actualProvider.extra_headers,
|
||||
Authorization: !isEmpty(baseConfig.apiKey) ? `Bearer ${baseConfig.apiKey}` : undefined
|
||||
}
|
||||
}
|
||||
}
|
||||
return buildOllamaConfig(ctx)
|
||||
}
|
||||
|
||||
// 处理OpenAI模式
|
||||
const extraOptions: any = {}
|
||||
extraOptions.endpoint = endpoint
|
||||
|
||||
// 解析 provider ID,提取 base ID 和 mode
|
||||
const parsed = extensionRegistry.parseProviderId(aiSdkProviderId)
|
||||
if (parsed?.mode) {
|
||||
// 自动设置 mode(如 openai-chat → mode: 'chat', azure-responses → mode: 'responses')
|
||||
extraOptions.mode = parsed.mode
|
||||
}
|
||||
|
||||
// 特殊处理:OpenAI responses 模式(基于 provider.type)
|
||||
if (actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) {
|
||||
extraOptions.mode = 'responses'
|
||||
} else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && actualProvider.type === 'openai')) {
|
||||
// 确保 OpenAI 和 CherryIN(type=openai) 使用 chat 模式
|
||||
extraOptions.mode = 'chat'
|
||||
}
|
||||
|
||||
extraOptions.headers = {
|
||||
...defaultAppHeaders(),
|
||||
...actualProvider.extra_headers
|
||||
}
|
||||
|
||||
if (aiSdkProviderId === 'openai') {
|
||||
const headers = extraOptions.headers as Record<string, string>
|
||||
headers['X-Api-Key'] = baseConfig.apiKey
|
||||
}
|
||||
if (isAzureOpenAIProvider(actualProvider)) {
|
||||
const apiVersion = actualProvider.apiVersion?.trim()
|
||||
if (apiVersion) {
|
||||
extraOptions.apiVersion = apiVersion
|
||||
if (!['preview', 'v1'].includes(apiVersion)) {
|
||||
extraOptions.useDeploymentBasedUrls = true
|
||||
}
|
||||
}
|
||||
return buildAzureConfig(ctx)
|
||||
}
|
||||
|
||||
// bedrock
|
||||
if (aiSdkProviderId === 'bedrock') {
|
||||
const authType = getAwsBedrockAuthType()
|
||||
extraOptions.region = getAwsBedrockRegion()
|
||||
|
||||
if (authType === 'apiKey') {
|
||||
extraOptions.apiKey = getAwsBedrockApiKey()
|
||||
} else {
|
||||
extraOptions.accessKeyId = getAwsBedrockAccessKeyId()
|
||||
extraOptions.secretAccessKey = getAwsBedrockSecretAccessKey()
|
||||
}
|
||||
return buildBedrockConfig(ctx)
|
||||
}
|
||||
// google-vertex
|
||||
|
||||
if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') {
|
||||
if (!isVertexAIConfigured()) {
|
||||
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
|
||||
}
|
||||
const { project, location, googleCredentials } = createVertexProvider(actualProvider)
|
||||
extraOptions.project = project
|
||||
extraOptions.location = location
|
||||
extraOptions.googleCredentials = {
|
||||
...googleCredentials,
|
||||
privateKey: formatPrivateKey(googleCredentials.privateKey)
|
||||
}
|
||||
baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
|
||||
return buildVertexConfig(ctx)
|
||||
}
|
||||
|
||||
// cherryin
|
||||
if (aiSdkProviderId === 'cherryin') {
|
||||
if (model.endpoint_type) {
|
||||
extraOptions.endpointType = model.endpoint_type
|
||||
}
|
||||
// CherryIN API Host
|
||||
const cherryinProvider = getProviderById(SystemProviderIds.cherryin)
|
||||
if (cherryinProvider) {
|
||||
extraOptions.anthropicBaseURL = cherryinProvider.anthropicApiHost + '/v1'
|
||||
extraOptions.geminiBaseURL = cherryinProvider.apiHost + '/v1beta/models'
|
||||
}
|
||||
return buildCherryinConfig(ctx)
|
||||
}
|
||||
|
||||
// 有 SDK 支持的 provider
|
||||
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
||||
const options = {
|
||||
...baseConfig,
|
||||
...extraOptions
|
||||
}
|
||||
return {
|
||||
providerId: aiSdkProviderId,
|
||||
options
|
||||
}
|
||||
return buildGenericProviderConfig(ctx)
|
||||
}
|
||||
|
||||
// 否则fallback到openai-compatible
|
||||
return {
|
||||
providerId: 'openai-compatible',
|
||||
options: {
|
||||
baseURL: baseConfig.baseURL,
|
||||
apiKey: baseConfig.apiKey,
|
||||
name: actualProvider.id,
|
||||
...extraOptions,
|
||||
includeUsage
|
||||
}
|
||||
}
|
||||
// 默认 fallback 到 openai-compatible
|
||||
return buildOpenAICompatibleConfig(ctx)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -317,10 +216,7 @@ export function isModernSdkSupported(provider: Provider): boolean {
|
||||
/**
|
||||
* 准备特殊provider的配置,主要用于异步处理的配置
|
||||
*/
|
||||
export async function prepareSpecialProviderConfig(
|
||||
provider: Provider,
|
||||
config: ReturnType<typeof providerToAiSdkConfig>
|
||||
) {
|
||||
export async function prepareSpecialProviderConfig(provider: Provider, config: ProviderConfig) {
|
||||
switch (provider.id) {
|
||||
case 'copilot': {
|
||||
const defaultHeaders = store.getState().copilot.defaultHeaders ?? {}
|
||||
@ -329,15 +225,17 @@ export async function prepareSpecialProviderConfig(
|
||||
...defaultHeaders
|
||||
}
|
||||
const { token } = await window.api.copilot.getToken(headers)
|
||||
config.options.apiKey = token
|
||||
config.options.headers = {
|
||||
const settings = config.providerSettings as any
|
||||
settings.apiKey = token
|
||||
settings.headers = {
|
||||
...headers,
|
||||
...config.options.headers
|
||||
...settings.headers
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'cherryai': {
|
||||
config.options.fetch = async (url, options) => {
|
||||
const settings = config.providerSettings as any
|
||||
settings.fetch = async (url: string, options: any) => {
|
||||
// 在这里对最终参数进行签名
|
||||
const signature = await window.api.cherryai.generateSignature({
|
||||
method: 'POST',
|
||||
@ -358,10 +256,11 @@ export async function prepareSpecialProviderConfig(
|
||||
case 'anthropic': {
|
||||
if (provider.authType === 'oauth') {
|
||||
const oauthToken = await window.api.anthropic_oauth.getAccessToken()
|
||||
config.options = {
|
||||
...config.options,
|
||||
const settings = config.providerSettings as any
|
||||
config.providerSettings = {
|
||||
...settings,
|
||||
headers: {
|
||||
...(config.options.headers ? config.options.headers : {}),
|
||||
...(settings.headers ? settings.headers : {}),
|
||||
'Content-Type': 'application/json',
|
||||
'anthropic-version': '2023-06-01',
|
||||
Authorization: `Bearer ${oauthToken}`
|
||||
@ -374,3 +273,253 @@ export async function prepareSpecialProviderConfig(
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
/**
|
||||
* 基础配置
|
||||
*/
|
||||
interface BaseConfig {
|
||||
baseURL: string
|
||||
apiKey: string
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建器上下文
|
||||
*/
|
||||
interface BuilderContext {
|
||||
actualProvider: Provider
|
||||
model: Model
|
||||
baseConfig: BaseConfig
|
||||
endpoint?: string
|
||||
aiSdkProviderId: AppProviderId
|
||||
}
|
||||
|
||||
/**
|
||||
* GitHub Copilot 配置构建器
|
||||
*/
|
||||
function buildCopilotConfig(ctx: BuilderContext): ProviderConfig<'github-copilot-openai-compatible'> {
|
||||
const storedHeaders = store.getState().copilot.defaultHeaders ?? {}
|
||||
|
||||
return {
|
||||
providerId: 'github-copilot-openai-compatible',
|
||||
providerSettings: {
|
||||
...ctx.baseConfig,
|
||||
headers: {
|
||||
...COPILOT_DEFAULT_HEADERS,
|
||||
...storedHeaders,
|
||||
...ctx.actualProvider.extra_headers
|
||||
},
|
||||
name: ctx.actualProvider.id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Ollama 配置构建器
|
||||
*/
|
||||
function buildOllamaConfig(ctx: BuilderContext): ProviderConfig<'ollama'> {
|
||||
const headers: ProviderConfig<'ollama'>['providerSettings']['headers'] = {
|
||||
...ctx.actualProvider.extra_headers
|
||||
}
|
||||
|
||||
if (!isEmpty(ctx.baseConfig.apiKey)) {
|
||||
headers.Authorization = `Bearer ${ctx.baseConfig.apiKey}`
|
||||
}
|
||||
|
||||
return {
|
||||
providerId: 'ollama',
|
||||
providerSettings: {
|
||||
...ctx.baseConfig,
|
||||
headers
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* AWS Bedrock 配置构建器
|
||||
*/
|
||||
function buildBedrockConfig(ctx: BuilderContext): ProviderConfig<'bedrock'> {
|
||||
const authType = getAwsBedrockAuthType()
|
||||
const region = getAwsBedrockRegion()
|
||||
|
||||
if (authType === 'apiKey') {
|
||||
return {
|
||||
providerId: 'bedrock',
|
||||
providerSettings: {
|
||||
...ctx.baseConfig,
|
||||
region,
|
||||
apiKey: getAwsBedrockApiKey()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
providerId: 'bedrock',
|
||||
providerSettings: {
|
||||
...ctx.baseConfig,
|
||||
region,
|
||||
accessKeyId: getAwsBedrockAccessKeyId(),
|
||||
secretAccessKey: getAwsBedrockSecretAccessKey()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Google Vertex AI 配置构建器
|
||||
*/
|
||||
function buildVertexConfig(
|
||||
ctx: BuilderContext
|
||||
): ProviderConfig<'google-vertex'> | ProviderConfig<'google-vertex-anthropic'> {
|
||||
if (!isVertexAIConfigured()) {
|
||||
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
|
||||
}
|
||||
|
||||
const { project, location, googleCredentials } = createVertexProvider(ctx.actualProvider)
|
||||
const isAnthropic = ctx.aiSdkProviderId === 'google-vertex-anthropic'
|
||||
|
||||
const baseURL = ctx.baseConfig.baseURL + (isAnthropic ? '/publishers/anthropic/models' : '/publishers/google')
|
||||
|
||||
if (isAnthropic) {
|
||||
return {
|
||||
providerId: 'google-vertex-anthropic',
|
||||
providerSettings: {
|
||||
...ctx.baseConfig,
|
||||
baseURL,
|
||||
project,
|
||||
location,
|
||||
googleCredentials: {
|
||||
...googleCredentials,
|
||||
privateKey: formatPrivateKey(googleCredentials.privateKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
providerId: 'google-vertex',
|
||||
providerSettings: {
|
||||
...ctx.baseConfig,
|
||||
baseURL,
|
||||
project,
|
||||
location,
|
||||
googleCredentials: {
|
||||
...googleCredentials,
|
||||
privateKey: formatPrivateKey(googleCredentials.privateKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* CherryIN 配置构建器
|
||||
*/
|
||||
function buildCherryinConfig(ctx: BuilderContext): ProviderConfig<'cherryin'> {
|
||||
const cherryinProvider = getProviderById(SystemProviderIds.cherryin)
|
||||
|
||||
return {
|
||||
providerId: 'cherryin',
|
||||
providerSettings: {
|
||||
...ctx.baseConfig,
|
||||
endpointType: ctx.model.endpoint_type,
|
||||
anthropicBaseURL: cherryinProvider ? cherryinProvider.anthropicApiHost + '/v1' : undefined,
|
||||
geminiBaseURL: cherryinProvider ? cherryinProvider.apiHost + '/v1beta/models' : undefined,
|
||||
headers: {
|
||||
...defaultAppHeaders(),
|
||||
...ctx.actualProvider.extra_headers
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Azure OpenAI 配置构建器
|
||||
*/
|
||||
function buildAzureConfig(ctx: BuilderContext): ProviderConfig<'azure'> | ProviderConfig<'azure-responses'> {
|
||||
const apiVersion = ctx.actualProvider.apiVersion?.trim()
|
||||
|
||||
// 根据 apiVersion 决定使用 azure 还是 azure-responses
|
||||
const useResponsesMode = apiVersion && ['preview', 'v1'].includes(apiVersion)
|
||||
|
||||
const providerSettings: Record<string, any> = {
|
||||
...ctx.baseConfig,
|
||||
endpoint: ctx.endpoint,
|
||||
headers: {
|
||||
...defaultAppHeaders(),
|
||||
...ctx.actualProvider.extra_headers
|
||||
}
|
||||
}
|
||||
|
||||
if (apiVersion) {
|
||||
providerSettings.apiVersion = apiVersion
|
||||
// 只有非 preview/v1 版本才使用 deployment-based URLs
|
||||
if (!useResponsesMode) {
|
||||
providerSettings.useDeploymentBasedUrls = true
|
||||
}
|
||||
}
|
||||
|
||||
if (useResponsesMode) {
|
||||
return {
|
||||
providerId: 'azure-responses',
|
||||
providerSettings
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
providerId: 'azure',
|
||||
providerSettings
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建通用的 OpenAI-compatible 或特定 provider 的额外选项
|
||||
*/
|
||||
function buildCommonOptions(ctx: BuilderContext) {
|
||||
const options: Record<string, any> = {
|
||||
endpoint: ctx.endpoint,
|
||||
headers: {
|
||||
...defaultAppHeaders(),
|
||||
...ctx.actualProvider.extra_headers
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI 特殊 header
|
||||
if (ctx.aiSdkProviderId === 'openai') {
|
||||
options.headers['X-Api-Key'] = ctx.baseConfig.apiKey
|
||||
}
|
||||
|
||||
return options
|
||||
}
|
||||
|
||||
/**
|
||||
* OpenAI-compatible 配置构建器
|
||||
*/
|
||||
function buildOpenAICompatibleConfig(ctx: BuilderContext): ProviderConfig<'openai-compatible'> {
|
||||
const commonOptions = buildCommonOptions(ctx)
|
||||
const includeUsage = isSupportStreamOptionsProvider(ctx.actualProvider)
|
||||
? store.getState().settings.openAI?.streamOptions?.includeUsage
|
||||
: undefined
|
||||
|
||||
return {
|
||||
providerId: 'openai-compatible',
|
||||
providerSettings: {
|
||||
...ctx.baseConfig,
|
||||
...commonOptions,
|
||||
name: ctx.actualProvider.id,
|
||||
includeUsage
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 通用 provider 配置构建器(有 SDK 支持的 provider)
|
||||
*/
|
||||
function buildGenericProviderConfig(ctx: BuilderContext): ProviderConfig {
|
||||
const commonOptions = buildCommonOptions(ctx)
|
||||
|
||||
return {
|
||||
providerId: ctx.aiSdkProviderId,
|
||||
providerSettings: {
|
||||
...ctx.baseConfig,
|
||||
...commonOptions
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -7,85 +7,38 @@
|
||||
* TODO: We should separate them clearly. Keep renderer only types in renderer, and main only types in main, and shared types in shared.
|
||||
*/
|
||||
|
||||
import type { AppProviderId, AppProviderSettingsMap } from './merged'
|
||||
import type { AppProviderId, AppRuntimeConfig } from './merged'
|
||||
|
||||
/**
|
||||
* Generic AI SDK configuration with compile-time type safety
|
||||
* Provider 配置(不含 plugins)
|
||||
* 基于 RuntimeConfig,用于构建 provider 实例的基础配置
|
||||
*
|
||||
* 🎯 Zero maintenance! Auto-extracts types from core and project extensions.
|
||||
*
|
||||
* @typeParam T - The specific provider ID type for type-safe options
|
||||
* @typeParam T - The specific provider ID type for type-safe settings
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* // Type-safe config for core provider
|
||||
* const config1: AiSdkConfig<'openai'> = {
|
||||
* const config1: ProviderConfig<'openai'> = {
|
||||
* providerId: 'openai',
|
||||
* options: { apiKey: '...', baseURL: '...' } // ✅ Typed as OpenAIProviderSettings
|
||||
* providerSettings: { apiKey: '...', baseURL: '...' } // ✅ Typed as OpenAIProviderSettings
|
||||
* }
|
||||
*
|
||||
* // Type-safe config for project provider
|
||||
* const config2: AiSdkConfig<'google-vertex'> = {
|
||||
* const config2: ProviderConfig<'google-vertex'> = {
|
||||
* providerId: 'google-vertex',
|
||||
* options: { ... } // ✅ Typed as GoogleVertexProviderSettings
|
||||
* providerSettings: { ... } // ✅ Typed as GoogleVertexProviderSettings
|
||||
* }
|
||||
*
|
||||
* // Type-safe config with alias
|
||||
* const config3: AiSdkConfig<'oai'> = {
|
||||
* const config3: ProviderConfig<'oai'> = {
|
||||
* providerId: 'oai',
|
||||
* options: { apiKey: '...' } // ✅ Same type as 'openai'
|
||||
* providerSettings: { apiKey: '...' } // ✅ Same type as 'openai'
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
export type AiSdkConfig<T extends AppProviderId = AppProviderId> = {
|
||||
providerId: T
|
||||
options: AppProviderSettingsMap[T]
|
||||
}
|
||||
|
||||
/**
|
||||
* Runtime-safe AI SDK configuration for gradual migration
|
||||
* Use this when provider ID is not known at compile time
|
||||
*
|
||||
* 使用联合类型而不是 any,提供更好的类型安全性
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* function createConfig(providerId: AppProviderId): AiSdkConfigRuntime {
|
||||
* return {
|
||||
* providerId,
|
||||
* options: buildOptions(providerId) // ✅ 类型安全:options 必须是某个 provider 的 settings
|
||||
* }
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
export type AiSdkConfigRuntime = {
|
||||
providerId: AppProviderId
|
||||
options: AppProviderSettingsMap[AppProviderId]
|
||||
}
|
||||
|
||||
/**
|
||||
* Type guard for runtime validation of AiSdkConfig
|
||||
*
|
||||
* @param config - Unknown value to validate
|
||||
* @returns true if config is a valid AiSdkConfigRuntime
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* if (isValidAiSdkConfig(someConfig)) {
|
||||
* // someConfig is now typed as AiSdkConfigRuntime
|
||||
* await createAiSdkProvider(someConfig)
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
export function isValidAiSdkConfig(config: unknown): config is AiSdkConfigRuntime {
|
||||
if (!config || typeof config !== 'object') return false
|
||||
|
||||
const c = config as Record<string, unknown>
|
||||
|
||||
return (
|
||||
typeof c.providerId === 'string' && c.providerId.length > 0 && typeof c.options === 'object' && c.options !== null
|
||||
)
|
||||
}
|
||||
export type ProviderConfig<T extends AppProviderId = AppProviderId> = Omit<AppRuntimeConfig<T>, 'plugins'>
|
||||
|
||||
export type { AppProviderId, AppProviderSettingsMap } from './merged'
|
||||
export { appProviderIds, getAllProviderIds, isRegisteredProviderId } from './merged'
|
||||
|
||||
Loading…
Reference in New Issue
Block a user