refactor:type inference

This commit is contained in:
suyao 2026-01-01 19:33:10 +08:00
parent 42ff133732
commit e3351097a9
No known key found for this signature in database
7 changed files with 499 additions and 310 deletions

View File

@ -51,7 +51,7 @@ export class ExtensionRegistry {
*
*/
register(extension: ProviderExtension<any, any, any>): this {
const { name, aliases } = extension.config
const { name, aliases, variants } = extension.config
// 检查主 ID 冲突
if (this.extensions.has(name)) {
@ -71,6 +71,19 @@ export class ExtensionRegistry {
}
}
// 注册变体 ID
if (variants) {
for (const variant of variants) {
const variantId = `${name}-${variant.suffix}`
if (this.aliasMap.has(variantId)) {
throw new Error(
`Provider variant ID "${variantId}" is already registered for "${this.aliasMap.get(variantId)}"`
)
}
this.aliasMap.set(variantId, name)
}
}
return this
}
@ -375,15 +388,24 @@ export class ExtensionRegistry {
* @returns Provider
*/
async createProvider(id: string, settings?: any, explicitId?: string): Promise<ProviderV3> {
const extension = this.get(id)
if (!extension) {
// 解析 provider ID提取基础 ID 和变体后缀
const parsed = this.parseProviderId(id)
if (!parsed) {
throw new Error(`Provider extension "${id}" not found. Did you forget to register it?`)
}
const { baseId, mode: variantSuffix } = parsed
// 获取基础 extension
const extension = this.get(baseId)
if (!extension) {
throw new Error(`Provider extension "${baseId}" not found. Did you forget to register it?`)
}
try {
// 直接委托给 Extension 的 createProvider 方法
// Extension 负责缓存、生命周期钩子、AI SDK 注册等
return await extension.createProvider(settings, explicitId)
// 委托给 Extension 的 createProvider 方法
// Extension 负责缓存、生命周期钩子、AI SDK 注册、变体转换
return await extension.createProvider(settings, explicitId, variantSuffix)
} catch (error) {
throw new ProviderCreationError(
`Failed to create provider "${id}"`,

View File

@ -281,35 +281,43 @@ export class ProviderExtension<
/**
* settings hash
* key
*
* @param settings - Provider
* @param variantSuffix -
*/
private computeHash(settings?: TSettings): string {
if (settings === undefined || settings === null) {
return 'default'
}
private computeHash(settings?: TSettings, variantSuffix?: string): string {
const baseHash = (() => {
if (settings === undefined || settings === null) {
return 'default'
}
// 使用 JSON.stringify 进行稳定序列化
// 对于对象按键排序以确保一致性
const stableStringify = (obj: any): string => {
if (obj === null || obj === undefined) return 'null'
if (typeof obj !== 'object') return JSON.stringify(obj)
if (Array.isArray(obj)) return `[${obj.map(stableStringify).join(',')}]`
// 使用 JSON.stringify 进行稳定序列化
// 对于对象按键排序以确保一致性
const stableStringify = (obj: any): string => {
if (obj === null || obj === undefined) return 'null'
if (typeof obj !== 'object') return JSON.stringify(obj)
if (Array.isArray(obj)) return `[${obj.map(stableStringify).join(',')}]`
const keys = Object.keys(obj).sort()
const pairs = keys.map((key) => `${JSON.stringify(key)}:${stableStringify(obj[key])}`)
return `{${pairs.join(',')}}`
}
const keys = Object.keys(obj).sort()
const pairs = keys.map((key) => `${JSON.stringify(key)}:${stableStringify(obj[key])}`)
return `{${pairs.join(',')}}`
}
const serialized = stableStringify(settings)
const serialized = stableStringify(settings)
// 使用简单的哈希函数(不需要加密级别的安全性)
let hash = 0
for (let i = 0; i < serialized.length; i++) {
const char = serialized.charCodeAt(i)
hash = (hash << 5) - hash + char
hash = hash & hash // Convert to 32bit integer
}
// 使用简单的哈希函数(不需要加密级别的安全性)
let hash = 0
for (let i = 0; i < serialized.length; i++) {
const char = serialized.charCodeAt(i)
hash = (hash << 5) - hash + char
hash = hash & hash // Convert to 32bit integer
}
return `${Math.abs(hash).toString(36)}`
return `${Math.abs(hash).toString(36)}`
})()
// 如果有变体后缀,将其附加到 hash 中
return variantSuffix ? `${baseHash}:${variantSuffix}` : baseHash
}
/**
@ -329,17 +337,29 @@ export class ProviderExtension<
*
* @param settings - Provider
* @param explicitId - ID AI SDK
* @param variantSuffix -
* @returns Provider
*/
async createProvider(settings?: TSettings, explicitId?: string): Promise<TProvider> {
async createProvider(settings?: TSettings, explicitId?: string, variantSuffix?: string): Promise<TProvider> {
// 验证变体后缀(如果提供)
if (variantSuffix) {
const variant = this.getVariant(variantSuffix)
if (!variant) {
throw new Error(
`ProviderExtension "${this.config.name}": variant "${variantSuffix}" not found. ` +
`Available variants: ${this.config.variants?.map((v) => v.suffix).join(', ') || 'none'}`
)
}
}
// 合并 default options
const mergedSettings = deepMergeObjects(
(this.config.defaultOptions || {}) as any,
(settings || {}) as any
) as TSettings
// 计算 hash
const hash = this.computeHash(mergedSettings)
// 计算 hash(包含变体后缀)
const hash = this.computeHash(mergedSettings, variantSuffix)
// 检查缓存
const cachedInstance = this.instances.get(hash)
@ -350,12 +370,12 @@ export class ProviderExtension<
// 执行 onBeforeCreate 钩子
await this.executeHook('onBeforeCreate', mergedSettings)
// 创建实例
let provider: ProviderV3
// 创建基础 provider 实例
let baseProvider: ProviderV3
if (this.config.create) {
// 使用直接创建函数
provider = await Promise.resolve(this.config.create(mergedSettings))
baseProvider = await Promise.resolve(this.config.create(mergedSettings))
} else if (this.config.import && this.config.creatorFunctionName) {
// 动态导入
const module = await this.config.import()
@ -367,29 +387,45 @@ export class ProviderExtension<
)
}
provider = await Promise.resolve(creatorFn(mergedSettings))
baseProvider = await Promise.resolve(creatorFn(mergedSettings))
} else {
throw new Error(`ProviderExtension "${this.config.name}": cannot create provider, invalid configuration`)
}
const typedProvider = provider as TProvider
// 执行 onAfterCreate 钩子
await this.executeHook('onAfterCreate', mergedSettings, typedProvider)
// 缓存实例
this.instances.set(hash, typedProvider)
this.settingsHashes.set(hash, mergedSettings)
// 注册到 AI SDK如果提供了 explicitId
if (explicitId) {
this.registerToAiSdk(typedProvider, explicitId)
// 应用变体转换(如果提供了变体后缀)
let finalProvider: TProvider
if (variantSuffix) {
const variant = this.getVariant(variantSuffix)!
// 应用变体的 transform 函数
finalProvider = (await Promise.resolve(variant.transform(baseProvider as TProvider, mergedSettings))) as TProvider
} else {
// 使用默认 ID: name:hash
this.registerToAiSdk(typedProvider, `${this.config.name}:${hash}`)
finalProvider = baseProvider as TProvider
}
return typedProvider
// 执行 onAfterCreate 钩子
await this.executeHook('onAfterCreate', mergedSettings, finalProvider)
// 缓存实例
this.instances.set(hash, finalProvider)
this.settingsHashes.set(hash, mergedSettings)
// 确定注册 ID
const registrationId = (() => {
if (explicitId) {
return explicitId
}
// 如果是变体,使用 name-suffix:hash 格式
if (variantSuffix) {
return `${this.config.name}-${variantSuffix}:${hash}`
}
// 否则使用 name:hash
return `${this.config.name}:${hash}`
})()
// 注册到 AI SDK
this.registerToAiSdk(finalProvider, registrationId)
return finalProvider
}
/**

View File

@ -24,7 +24,6 @@ import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
import LegacyAiProvider from './legacy/index'
import type { CompletionsParams, CompletionsResult } from './legacy/middleware/schemas'
import { buildPlugins } from './plugins/PluginBuilder'
import { createAiSdkProvider } from './provider/factory'
import {
adaptProvider,
getActualProvider,
@ -32,7 +31,7 @@ import {
prepareSpecialProviderConfig,
providerToAiSdkConfig
} from './provider/providerConfig'
import type { AiSdkConfig } from './types'
import type { ProviderConfig } from './types'
import type { AiSdkMiddlewareConfig } from './types/middlewareConfig'
const logger = loggerService.withContext('ModernAiProvider')
@ -46,7 +45,7 @@ export type ModernAiProviderConfig = AiSdkMiddlewareConfig & {
export default class ModernAiProvider {
private legacyProvider: LegacyAiProvider
private config?: AiSdkConfig
private config?: ProviderConfig
private actualProvider: Provider
private model?: Model
private localProvider: Awaited<AiSdkProvider> | null = null
@ -133,7 +132,7 @@ export default class ModernAiProvider {
if (!this.config) {
throw new Error('Provider config is undefined; cannot proceed with completions')
}
if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.options.endpoint)) {
if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.providerSettings.endpoint)) {
providerConfig.isImageGenerationEndpoint = true
}
// 准备特殊配置
@ -141,7 +140,7 @@ export default class ModernAiProvider {
// 提前创建本地 provider 实例
if (!this.localProvider) {
this.localProvider = await createAiSdkProvider(this.config)
// this.localProvider = await createAiSdkProvider(this.config) // TODO: Update provider creation
}
if (!this.localProvider) {
@ -321,7 +320,7 @@ export default class ModernAiProvider {
const plugins = buildPlugins(config)
// 用构建好的插件数组创建executor
const executor = createExecutor(this.config!.providerId, this.config!.options, plugins)
const executor = createExecutor(this.config!.providerId, this.config!.providerSettings, plugins)
// 创建带有中间件的执行器
if (config.onChunk) {
@ -406,7 +405,7 @@ export default class ModernAiProvider {
}
// 调用新 AI SDK 的图像生成功能
const executor = createExecutor(this.config!.providerId, this.config!.options, [])
const executor = createExecutor(this.config!.providerId, this.config!.providerSettings, [])
const result = await executor.generateImage({
model,
...imageParams
@ -504,7 +503,7 @@ export default class ModernAiProvider {
// 确保本地provider已创建
if (!this.localProvider && this.config) {
this.localProvider = await createAiSdkProvider(this.config)
// this.localProvider = await createAiSdkProvider(this.config) // TODO: Update provider creation
if (!this.localProvider) {
throw new Error('Local provider not created')
}
@ -537,7 +536,7 @@ export default class ModernAiProvider {
...(signal && { abortSignal: signal })
}
const executor = createExecutor(this.config!.providerId, this.config!.options, [])
const executor = createExecutor(this.config!.providerId, this.config!.providerSettings, [])
const result = await executor.generateImage({
model: model, // 直接使用 model ID 字符串,由 executor 内部解析
...aiSdkParams

View File

@ -174,9 +174,11 @@ describe('Copilot responses routing', () => {
const config = providerToAiSdkConfig(provider, createModel('gpt-5-codex', 'GPT-5-CODEX'))
expect(config.providerId).toBe('github-copilot-openai-compatible')
expect(config.options.headers?.['Editor-Version']).toBe(COPILOT_EDITOR_VERSION)
expect(config.options.headers?.['Copilot-Integration-Id']).toBe(COPILOT_DEFAULT_HEADERS['Copilot-Integration-Id'])
expect(config.options.headers?.['copilot-vision-request']).toBe('true')
expect(config.providerSettings.headers?.['Editor-Version']).toBe(COPILOT_EDITOR_VERSION)
expect(config.providerSettings.headers?.['Copilot-Integration-Id']).toBe(
COPILOT_DEFAULT_HEADERS['Copilot-Integration-Id']
)
expect(config.providerSettings.headers?.['copilot-vision-request']).toBe('true')
})
it('uses the Copilot provider for other models and keeps headers', () => {
@ -184,8 +186,10 @@ describe('Copilot responses routing', () => {
const config = providerToAiSdkConfig(provider, createModel('gpt-4'))
expect(config.providerId).toBe('github-copilot-openai-compatible')
expect(config.options.headers?.['Editor-Version']).toBe(COPILOT_DEFAULT_HEADERS['Editor-Version'])
expect(config.options.headers?.['Copilot-Integration-Id']).toBe(COPILOT_DEFAULT_HEADERS['Copilot-Integration-Id'])
expect(config.providerSettings.headers?.['Editor-Version']).toBe(COPILOT_DEFAULT_HEADERS['Editor-Version'])
expect(config.providerSettings.headers?.['Copilot-Integration-Id']).toBe(
COPILOT_DEFAULT_HEADERS['Copilot-Integration-Id']
)
})
})
@ -388,7 +392,7 @@ describe('Stream options includeUsage configuration', () => {
const provider = createOpenAIProvider()
const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai'))
expect(config.options.includeUsage).toBeUndefined()
expect(config.providerSettings.includeUsage).toBeUndefined()
})
it('uses includeUsage from settings when set to true', () => {
@ -406,7 +410,7 @@ describe('Stream options includeUsage configuration', () => {
const provider = createOpenAIProvider()
const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai'))
expect(config.options.includeUsage).toBe(true)
expect(config.providerSettings.includeUsage).toBe(true)
})
it('uses includeUsage from settings when set to false', () => {
@ -424,7 +428,7 @@ describe('Stream options includeUsage configuration', () => {
const provider = createOpenAIProvider()
const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai'))
expect(config.options.includeUsage).toBe(false)
expect(config.providerSettings.includeUsage).toBe(false)
})
it('respects includeUsage setting for non-supporting providers', () => {
@ -455,7 +459,7 @@ describe('Stream options includeUsage configuration', () => {
const config = providerToAiSdkConfig(testProvider, createModel('gpt-4', 'GPT-4', 'test'))
// Even though setting is true, provider doesn't support it, so includeUsage should be undefined
expect(config.options.includeUsage).toBeUndefined()
expect(config.providerSettings.includeUsage).toBeUndefined()
})
it('uses includeUsage from settings for Copilot provider when set to false', () => {
@ -473,7 +477,7 @@ describe('Stream options includeUsage configuration', () => {
const provider = createCopilotProvider()
const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot'))
expect(config.options.includeUsage).toBe(false)
expect(config.providerSettings.includeUsage).toBe(false)
expect(config.providerId).toBe('github-copilot-openai-compatible')
})
@ -492,7 +496,7 @@ describe('Stream options includeUsage configuration', () => {
const provider = createCopilotProvider()
const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot'))
expect(config.options.includeUsage).toBe(true)
expect(config.providerSettings.includeUsage).toBe(true)
expect(config.providerId).toBe('github-copilot-openai-compatible')
})
@ -511,7 +515,7 @@ describe('Stream options includeUsage configuration', () => {
const provider = createCopilotProvider()
const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot'))
expect(config.options.includeUsage).toBeUndefined()
expect(config.providerSettings.includeUsage).toBeUndefined()
expect(config.providerId).toBe('github-copilot-openai-compatible')
})
})
@ -540,21 +544,21 @@ describe('Azure OpenAI traditional API routing', () => {
const config = providerToAiSdkConfig(provider, createModel('gpt-4o', 'GPT-4o', provider.id))
expect(config.providerId).toBe('azure')
expect(config.options.apiVersion).toBe('2024-02-15-preview')
expect(config.options.useDeploymentBasedUrls).toBe(true)
expect(config.providerSettings.apiVersion).toBe('2024-02-15-preview')
expect(config.providerSettings.useDeploymentBasedUrls).toBe(true)
})
it('does not force deployment-based URLs for apiVersion v1/preview', () => {
const v1Provider = createAzureProvider('v1')
const v1Config = providerToAiSdkConfig(v1Provider, createModel('gpt-4o', 'GPT-4o', v1Provider.id))
expect(v1Config.providerId).toBe('azure-responses')
expect(v1Config.options.apiVersion).toBe('v1')
expect(v1Config.options.useDeploymentBasedUrls).toBeUndefined()
expect(v1Config.providerSettings.apiVersion).toBe('v1')
expect(v1Config.providerSettings.useDeploymentBasedUrls).toBeUndefined()
const previewProvider = createAzureProvider('preview')
const previewConfig = providerToAiSdkConfig(previewProvider, createModel('gpt-4o', 'GPT-4o', previewProvider.id))
expect(previewConfig.providerId).toBe('azure-responses')
expect(previewConfig.options.apiVersion).toBe('preview')
expect(previewConfig.options.useDeploymentBasedUrls).toBeUndefined()
expect(previewConfig.providerSettings.apiVersion).toBe('preview')
expect(previewConfig.providerSettings.useDeploymentBasedUrls).toBeUndefined()
})
})

View File

@ -3,8 +3,21 @@
* AI Providers
*/
import type { ProviderV2 } from '@ai-sdk/provider'
import { ProviderExtension, type ProviderExtensionConfig } from '@cherrystudio/ai-core/provider'
import { type AmazonBedrockProviderSettings, createAmazonBedrock } from '@ai-sdk/amazon-bedrock'
import { type AnthropicProviderSettings, createAnthropic } from '@ai-sdk/anthropic'
import { type CerebrasProviderSettings, createCerebras } from '@ai-sdk/cerebras'
import { createGateway, type GatewayProviderSettings } from '@ai-sdk/gateway'
import { createVertexAnthropic } from '@ai-sdk/google-vertex/anthropic/edge'
import { createVertex, type GoogleVertexProviderSettings } from '@ai-sdk/google-vertex/edge'
import { createHuggingFace, type HuggingFaceProviderSettings } from '@ai-sdk/huggingface'
import { createMistral, type MistralProviderSettings } from '@ai-sdk/mistral'
import { createPerplexity, type PerplexityProviderSettings } from '@ai-sdk/perplexity'
import type { ProviderV2, ProviderV3 } from '@ai-sdk/provider'
import { ExtensionStorage, ProviderExtension, type ProviderExtensionConfig } from '@cherrystudio/ai-core/provider'
import {
createGitHubCopilotOpenAICompatible,
type GitHubCopilotProviderSettings
} from '@opeoginni/github-copilot-openai-compatible'
import { wrapProvider } from 'ai'
import type { OllamaProviderSettings } from 'ollama-ai-provider-v2'
import { createOllama } from 'ollama-ai-provider-v2'
@ -16,9 +29,13 @@ export const GoogleVertexExtension = ProviderExtension.create({
name: 'google-vertex',
aliases: ['vertexai'] as const,
supportsImageGeneration: true,
import: () => import('@ai-sdk/google-vertex/edge'),
creatorFunctionName: 'createVertex'
} as const satisfies ProviderExtensionConfig<any, any, any, 'google-vertex'>)
create: createVertex
} as const satisfies ProviderExtensionConfig<
GoogleVertexProviderSettings,
ExtensionStorage,
ProviderV3,
'google-vertex'
>)
/**
* Google Vertex AI Anthropic Extension
@ -27,9 +44,13 @@ export const GoogleVertexAnthropicExtension = ProviderExtension.create({
name: 'google-vertex-anthropic',
aliases: ['vertexai-anthropic'] as const,
supportsImageGeneration: true,
import: () => import('@ai-sdk/google-vertex/anthropic/edge'),
creatorFunctionName: 'createVertexAnthropic'
} as const satisfies ProviderExtensionConfig<any, any, any, 'google-vertex-anthropic'>)
create: createVertexAnthropic
} as const satisfies ProviderExtensionConfig<
GoogleVertexProviderSettings,
ExtensionStorage,
ProviderV3,
'google-vertex-anthropic'
>)
/**
* Azure AI Anthropic Extension
@ -38,9 +59,13 @@ export const AzureAnthropicExtension = ProviderExtension.create({
name: 'azure-anthropic',
aliases: ['azure-anthropic'] as const,
supportsImageGeneration: false,
import: () => import('@ai-sdk/anthropic'),
creatorFunctionName: 'createAnthropic'
} as const satisfies ProviderExtensionConfig<any, any, any, 'azure-anthropic'>)
create: createAnthropic
} as const satisfies ProviderExtensionConfig<
AnthropicProviderSettings,
ExtensionStorage,
ProviderV3,
'azure-anthropic'
>)
/**
* GitHub Copilot Extension
@ -49,9 +74,16 @@ export const GitHubCopilotExtension = ProviderExtension.create({
name: 'github-copilot-openai-compatible',
aliases: ['copilot', 'github-copilot'] as const,
supportsImageGeneration: false,
import: () => import('@opeoginni/github-copilot-openai-compatible'),
creatorFunctionName: 'createGitHubCopilotOpenAICompatible'
} as const satisfies ProviderExtensionConfig<any, any, any, 'github-copilot-openai-compatible'>)
create: (options?: GitHubCopilotProviderSettings) => {
const provider = createGitHubCopilotOpenAICompatible(options) as unknown as ProviderV2
return wrapProvider({ provider, languageModelMiddleware: [] })
}
} as const satisfies ProviderExtensionConfig<
GitHubCopilotProviderSettings,
ExtensionStorage,
ProviderV3,
'github-copilot-openai-compatible'
>)
/**
* Amazon Bedrock Extension
@ -60,9 +92,8 @@ export const BedrockExtension = ProviderExtension.create({
name: 'bedrock',
aliases: ['aws-bedrock'] as const,
supportsImageGeneration: true,
import: () => import('@ai-sdk/amazon-bedrock'),
creatorFunctionName: 'createAmazonBedrock'
} as const satisfies ProviderExtensionConfig<any, any, any, 'bedrock'>)
create: createAmazonBedrock
} as const satisfies ProviderExtensionConfig<AmazonBedrockProviderSettings, ExtensionStorage, ProviderV3, 'bedrock'>)
/**
* Perplexity Extension
@ -70,9 +101,8 @@ export const BedrockExtension = ProviderExtension.create({
export const PerplexityExtension = ProviderExtension.create({
name: 'perplexity',
supportsImageGeneration: false,
import: () => import('@ai-sdk/perplexity'),
creatorFunctionName: 'createPerplexity'
} as const satisfies ProviderExtensionConfig<any, any, any, 'perplexity'>)
create: createPerplexity
} as const satisfies ProviderExtensionConfig<PerplexityProviderSettings, ExtensionStorage, ProviderV3, 'perplexity'>)
/**
* Mistral Extension
@ -81,9 +111,8 @@ export const MistralExtension = ProviderExtension.create({
name: 'mistral',
aliases: ['mistral'] as const,
supportsImageGeneration: false,
import: () => import('@ai-sdk/mistral'),
creatorFunctionName: 'createMistral'
} as const satisfies ProviderExtensionConfig<any, any, any, 'mistral'>)
create: createMistral
} as const satisfies ProviderExtensionConfig<MistralProviderSettings, ExtensionStorage, ProviderV3, 'mistral'>)
/**
* HuggingFace Extension
@ -92,9 +121,8 @@ export const HuggingFaceExtension = ProviderExtension.create({
name: 'huggingface',
aliases: ['hf', 'hugging-face'] as const,
supportsImageGeneration: true,
import: () => import('@ai-sdk/huggingface'),
creatorFunctionName: 'createHuggingFace'
} as const satisfies ProviderExtensionConfig<any, any, any, 'huggingface'>)
create: createHuggingFace
} as const satisfies ProviderExtensionConfig<HuggingFaceProviderSettings, ExtensionStorage, ProviderV3, 'huggingface'>)
/**
* Vercel AI Gateway Extension
@ -103,9 +131,8 @@ export const GatewayExtension = ProviderExtension.create({
name: 'gateway',
aliases: ['ai-gateway'] as const,
supportsImageGeneration: true,
import: () => import('@ai-sdk/gateway'),
creatorFunctionName: 'createGateway'
} as const satisfies ProviderExtensionConfig<any, any, any, 'gateway'>)
create: createGateway
} as const satisfies ProviderExtensionConfig<GatewayProviderSettings, ExtensionStorage, ProviderV3, 'gateway'>)
/**
* Cerebras Extension
@ -113,9 +140,8 @@ export const GatewayExtension = ProviderExtension.create({
export const CerebrasExtension = ProviderExtension.create({
name: 'cerebras',
supportsImageGeneration: false,
import: () => import('@ai-sdk/cerebras'),
creatorFunctionName: 'createCerebras'
} as const satisfies ProviderExtensionConfig<any, any, any, 'cerebras'>)
create: createCerebras
} as const satisfies ProviderExtensionConfig<CerebrasProviderSettings, ExtensionStorage, ProviderV3, 'cerebras'>)
/**
* Ollama Extension
@ -127,7 +153,7 @@ export const OllamaExtension = ProviderExtension.create({
const provider = createOllama(options) as ProviderV2
return wrapProvider({ provider, languageModelMiddleware: [] })
}
} as const satisfies ProviderExtensionConfig<OllamaProviderSettings, any, any, 'ollama'>)
} as const satisfies ProviderExtensionConfig<OllamaProviderSettings, ExtensionStorage, ProviderV3, 'ollama'>)
/**
* Extensions

View File

@ -1,6 +1,5 @@
import { extensionRegistry, formatPrivateKey, hasProviderConfig } from '@cherrystudio/ai-core/provider'
import { formatPrivateKey, hasProviderConfig } from '@cherrystudio/ai-core/provider'
import type { AppProviderId } from '@renderer/aiCore/types'
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
import {
getAwsBedrockAccessKeyId,
getAwsBedrockApiKey,
@ -13,7 +12,6 @@ import { getProviderByModel } from '@renderer/services/AssistantService'
import { getProviderById } from '@renderer/services/ProviderService'
import store from '@renderer/store'
import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import type { OpenAICompletionsStreamOptions } from '@renderer/types/aiCoreTypes'
import {
formatApiHost,
formatAzureOpenAIApiHost,
@ -36,7 +34,7 @@ import {
import { defaultAppHeaders } from '@shared/utils'
import { cloneDeep, isEmpty } from 'lodash'
import type { AiSdkConfigRuntime } from '../types'
import type { ProviderConfig } from '../types'
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
import { azureAnthropicProviderCreator } from './config/azure-anthropic'
import { COPILOT_DEFAULT_HEADERS } from './constants'
@ -146,155 +144,56 @@ export function adaptProvider({ provider, model }: { provider: Provider; model?:
*
* @param actualProvider - Cherry Studio provider配置
* @param model -
* @returns AI SDK
* @returns Provider
*/
export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfigRuntime {
const aiSdkProviderId: AppProviderId = getAiSdkProviderId(actualProvider)
// 构建基础配置
export function providerToAiSdkConfig(actualProvider: Provider, model: Model): ProviderConfig {
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost)
const baseConfig = {
baseURL: baseURL,
apiKey: actualProvider.apiKey
}
let includeUsage: OpenAICompletionsStreamOptions['include_usage'] = undefined
if (isSupportStreamOptionsProvider(actualProvider)) {
includeUsage = store.getState().settings.openAI?.streamOptions?.includeUsage
// 构建上下文
const ctx: BuilderContext = {
actualProvider,
model,
baseConfig: {
baseURL,
apiKey: actualProvider.apiKey
},
endpoint,
aiSdkProviderId
}
const isCopilotProvider = actualProvider.id === SystemProviderIds.copilot
if (isCopilotProvider) {
const storedHeaders = store.getState().copilot.defaultHeaders ?? {}
const options = {
...baseConfig,
headers: {
...COPILOT_DEFAULT_HEADERS,
...storedHeaders,
...actualProvider.extra_headers
},
name: actualProvider.id,
includeUsage
}
return {
providerId: 'github-copilot-openai-compatible',
options
}
// 路由到专门的构建器
if (actualProvider.id === SystemProviderIds.copilot) {
return buildCopilotConfig(ctx)
}
if (isOllamaProvider(actualProvider)) {
return {
providerId: 'ollama',
options: {
...baseConfig,
headers: {
...actualProvider.extra_headers,
Authorization: !isEmpty(baseConfig.apiKey) ? `Bearer ${baseConfig.apiKey}` : undefined
}
}
}
return buildOllamaConfig(ctx)
}
// 处理OpenAI模式
const extraOptions: any = {}
extraOptions.endpoint = endpoint
// 解析 provider ID提取 base ID 和 mode
const parsed = extensionRegistry.parseProviderId(aiSdkProviderId)
if (parsed?.mode) {
// 自动设置 mode如 openai-chat → mode: 'chat', azure-responses → mode: 'responses'
extraOptions.mode = parsed.mode
}
// 特殊处理OpenAI responses 模式(基于 provider.type
if (actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && actualProvider.type === 'openai')) {
// 确保 OpenAI 和 CherryIN(type=openai) 使用 chat 模式
extraOptions.mode = 'chat'
}
extraOptions.headers = {
...defaultAppHeaders(),
...actualProvider.extra_headers
}
if (aiSdkProviderId === 'openai') {
const headers = extraOptions.headers as Record<string, string>
headers['X-Api-Key'] = baseConfig.apiKey
}
if (isAzureOpenAIProvider(actualProvider)) {
const apiVersion = actualProvider.apiVersion?.trim()
if (apiVersion) {
extraOptions.apiVersion = apiVersion
if (!['preview', 'v1'].includes(apiVersion)) {
extraOptions.useDeploymentBasedUrls = true
}
}
return buildAzureConfig(ctx)
}
// bedrock
if (aiSdkProviderId === 'bedrock') {
const authType = getAwsBedrockAuthType()
extraOptions.region = getAwsBedrockRegion()
if (authType === 'apiKey') {
extraOptions.apiKey = getAwsBedrockApiKey()
} else {
extraOptions.accessKeyId = getAwsBedrockAccessKeyId()
extraOptions.secretAccessKey = getAwsBedrockSecretAccessKey()
}
return buildBedrockConfig(ctx)
}
// google-vertex
if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') {
if (!isVertexAIConfigured()) {
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
}
const { project, location, googleCredentials } = createVertexProvider(actualProvider)
extraOptions.project = project
extraOptions.location = location
extraOptions.googleCredentials = {
...googleCredentials,
privateKey: formatPrivateKey(googleCredentials.privateKey)
}
baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
return buildVertexConfig(ctx)
}
// cherryin
if (aiSdkProviderId === 'cherryin') {
if (model.endpoint_type) {
extraOptions.endpointType = model.endpoint_type
}
// CherryIN API Host
const cherryinProvider = getProviderById(SystemProviderIds.cherryin)
if (cherryinProvider) {
extraOptions.anthropicBaseURL = cherryinProvider.anthropicApiHost + '/v1'
extraOptions.geminiBaseURL = cherryinProvider.apiHost + '/v1beta/models'
}
return buildCherryinConfig(ctx)
}
// 有 SDK 支持的 provider
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = {
...baseConfig,
...extraOptions
}
return {
providerId: aiSdkProviderId,
options
}
return buildGenericProviderConfig(ctx)
}
// 否则fallback到openai-compatible
return {
providerId: 'openai-compatible',
options: {
baseURL: baseConfig.baseURL,
apiKey: baseConfig.apiKey,
name: actualProvider.id,
...extraOptions,
includeUsage
}
}
// 默认 fallback 到 openai-compatible
return buildOpenAICompatibleConfig(ctx)
}
/**
@ -317,10 +216,7 @@ export function isModernSdkSupported(provider: Provider): boolean {
/**
* provider的配置,
*/
export async function prepareSpecialProviderConfig(
provider: Provider,
config: ReturnType<typeof providerToAiSdkConfig>
) {
export async function prepareSpecialProviderConfig(provider: Provider, config: ProviderConfig) {
switch (provider.id) {
case 'copilot': {
const defaultHeaders = store.getState().copilot.defaultHeaders ?? {}
@ -329,15 +225,17 @@ export async function prepareSpecialProviderConfig(
...defaultHeaders
}
const { token } = await window.api.copilot.getToken(headers)
config.options.apiKey = token
config.options.headers = {
const settings = config.providerSettings as any
settings.apiKey = token
settings.headers = {
...headers,
...config.options.headers
...settings.headers
}
break
}
case 'cherryai': {
config.options.fetch = async (url, options) => {
const settings = config.providerSettings as any
settings.fetch = async (url: string, options: any) => {
// 在这里对最终参数进行签名
const signature = await window.api.cherryai.generateSignature({
method: 'POST',
@ -358,10 +256,11 @@ export async function prepareSpecialProviderConfig(
case 'anthropic': {
if (provider.authType === 'oauth') {
const oauthToken = await window.api.anthropic_oauth.getAccessToken()
config.options = {
...config.options,
const settings = config.providerSettings as any
config.providerSettings = {
...settings,
headers: {
...(config.options.headers ? config.options.headers : {}),
...(settings.headers ? settings.headers : {}),
'Content-Type': 'application/json',
'anthropic-version': '2023-06-01',
Authorization: `Bearer ${oauthToken}`
@ -374,3 +273,253 @@ export async function prepareSpecialProviderConfig(
}
return config
}
/**
*
*/
interface BaseConfig {
baseURL: string
apiKey: string
}
/**
*
*/
interface BuilderContext {
actualProvider: Provider
model: Model
baseConfig: BaseConfig
endpoint?: string
aiSdkProviderId: AppProviderId
}
/**
* GitHub Copilot
*/
function buildCopilotConfig(ctx: BuilderContext): ProviderConfig<'github-copilot-openai-compatible'> {
const storedHeaders = store.getState().copilot.defaultHeaders ?? {}
return {
providerId: 'github-copilot-openai-compatible',
providerSettings: {
...ctx.baseConfig,
headers: {
...COPILOT_DEFAULT_HEADERS,
...storedHeaders,
...ctx.actualProvider.extra_headers
},
name: ctx.actualProvider.id
}
}
}
/**
* Ollama
*/
function buildOllamaConfig(ctx: BuilderContext): ProviderConfig<'ollama'> {
const headers: ProviderConfig<'ollama'>['providerSettings']['headers'] = {
...ctx.actualProvider.extra_headers
}
if (!isEmpty(ctx.baseConfig.apiKey)) {
headers.Authorization = `Bearer ${ctx.baseConfig.apiKey}`
}
return {
providerId: 'ollama',
providerSettings: {
...ctx.baseConfig,
headers
}
}
}
/**
* AWS Bedrock
*/
function buildBedrockConfig(ctx: BuilderContext): ProviderConfig<'bedrock'> {
const authType = getAwsBedrockAuthType()
const region = getAwsBedrockRegion()
if (authType === 'apiKey') {
return {
providerId: 'bedrock',
providerSettings: {
...ctx.baseConfig,
region,
apiKey: getAwsBedrockApiKey()
}
}
}
return {
providerId: 'bedrock',
providerSettings: {
...ctx.baseConfig,
region,
accessKeyId: getAwsBedrockAccessKeyId(),
secretAccessKey: getAwsBedrockSecretAccessKey()
}
}
}
/**
* Google Vertex AI
*/
function buildVertexConfig(
ctx: BuilderContext
): ProviderConfig<'google-vertex'> | ProviderConfig<'google-vertex-anthropic'> {
if (!isVertexAIConfigured()) {
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
}
const { project, location, googleCredentials } = createVertexProvider(ctx.actualProvider)
const isAnthropic = ctx.aiSdkProviderId === 'google-vertex-anthropic'
const baseURL = ctx.baseConfig.baseURL + (isAnthropic ? '/publishers/anthropic/models' : '/publishers/google')
if (isAnthropic) {
return {
providerId: 'google-vertex-anthropic',
providerSettings: {
...ctx.baseConfig,
baseURL,
project,
location,
googleCredentials: {
...googleCredentials,
privateKey: formatPrivateKey(googleCredentials.privateKey)
}
}
}
}
return {
providerId: 'google-vertex',
providerSettings: {
...ctx.baseConfig,
baseURL,
project,
location,
googleCredentials: {
...googleCredentials,
privateKey: formatPrivateKey(googleCredentials.privateKey)
}
}
}
}
/**
* CherryIN
*/
function buildCherryinConfig(ctx: BuilderContext): ProviderConfig<'cherryin'> {
const cherryinProvider = getProviderById(SystemProviderIds.cherryin)
return {
providerId: 'cherryin',
providerSettings: {
...ctx.baseConfig,
endpointType: ctx.model.endpoint_type,
anthropicBaseURL: cherryinProvider ? cherryinProvider.anthropicApiHost + '/v1' : undefined,
geminiBaseURL: cherryinProvider ? cherryinProvider.apiHost + '/v1beta/models' : undefined,
headers: {
...defaultAppHeaders(),
...ctx.actualProvider.extra_headers
}
}
}
}
/**
* Azure OpenAI
*/
function buildAzureConfig(ctx: BuilderContext): ProviderConfig<'azure'> | ProviderConfig<'azure-responses'> {
const apiVersion = ctx.actualProvider.apiVersion?.trim()
// 根据 apiVersion 决定使用 azure 还是 azure-responses
const useResponsesMode = apiVersion && ['preview', 'v1'].includes(apiVersion)
const providerSettings: Record<string, any> = {
...ctx.baseConfig,
endpoint: ctx.endpoint,
headers: {
...defaultAppHeaders(),
...ctx.actualProvider.extra_headers
}
}
if (apiVersion) {
providerSettings.apiVersion = apiVersion
// 只有非 preview/v1 版本才使用 deployment-based URLs
if (!useResponsesMode) {
providerSettings.useDeploymentBasedUrls = true
}
}
if (useResponsesMode) {
return {
providerId: 'azure-responses',
providerSettings
}
}
return {
providerId: 'azure',
providerSettings
}
}
/**
* OpenAI-compatible provider
*/
function buildCommonOptions(ctx: BuilderContext) {
const options: Record<string, any> = {
endpoint: ctx.endpoint,
headers: {
...defaultAppHeaders(),
...ctx.actualProvider.extra_headers
}
}
// OpenAI 特殊 header
if (ctx.aiSdkProviderId === 'openai') {
options.headers['X-Api-Key'] = ctx.baseConfig.apiKey
}
return options
}
/**
* OpenAI-compatible
*/
function buildOpenAICompatibleConfig(ctx: BuilderContext): ProviderConfig<'openai-compatible'> {
const commonOptions = buildCommonOptions(ctx)
const includeUsage = isSupportStreamOptionsProvider(ctx.actualProvider)
? store.getState().settings.openAI?.streamOptions?.includeUsage
: undefined
return {
providerId: 'openai-compatible',
providerSettings: {
...ctx.baseConfig,
...commonOptions,
name: ctx.actualProvider.id,
includeUsage
}
}
}
/**
* provider SDK provider
*/
function buildGenericProviderConfig(ctx: BuilderContext): ProviderConfig {
const commonOptions = buildCommonOptions(ctx)
return {
providerId: ctx.aiSdkProviderId,
providerSettings: {
...ctx.baseConfig,
...commonOptions
}
}
}

View File

@ -7,85 +7,38 @@
* TODO: We should separate them clearly. Keep renderer only types in renderer, and main only types in main, and shared types in shared.
*/
import type { AppProviderId, AppProviderSettingsMap } from './merged'
import type { AppProviderId, AppRuntimeConfig } from './merged'
/**
* Generic AI SDK configuration with compile-time type safety
* Provider plugins
* RuntimeConfig provider
*
* 🎯 Zero maintenance! Auto-extracts types from core and project extensions.
*
* @typeParam T - The specific provider ID type for type-safe options
* @typeParam T - The specific provider ID type for type-safe settings
*
* @example
* ```ts
* // Type-safe config for core provider
* const config1: AiSdkConfig<'openai'> = {
* const config1: ProviderConfig<'openai'> = {
* providerId: 'openai',
* options: { apiKey: '...', baseURL: '...' } // ✅ Typed as OpenAIProviderSettings
* providerSettings: { apiKey: '...', baseURL: '...' } // ✅ Typed as OpenAIProviderSettings
* }
*
* // Type-safe config for project provider
* const config2: AiSdkConfig<'google-vertex'> = {
* const config2: ProviderConfig<'google-vertex'> = {
* providerId: 'google-vertex',
* options: { ... } // ✅ Typed as GoogleVertexProviderSettings
* providerSettings: { ... } // ✅ Typed as GoogleVertexProviderSettings
* }
*
* // Type-safe config with alias
* const config3: AiSdkConfig<'oai'> = {
* const config3: ProviderConfig<'oai'> = {
* providerId: 'oai',
* options: { apiKey: '...' } // ✅ Same type as 'openai'
* providerSettings: { apiKey: '...' } // ✅ Same type as 'openai'
* }
* ```
*/
export type AiSdkConfig<T extends AppProviderId = AppProviderId> = {
providerId: T
options: AppProviderSettingsMap[T]
}
/**
* Runtime-safe AI SDK configuration for gradual migration
* Use this when provider ID is not known at compile time
*
* 使 any
*
* @example
* ```ts
* function createConfig(providerId: AppProviderId): AiSdkConfigRuntime {
* return {
* providerId,
* options: buildOptions(providerId) // ✅ 类型安全options 必须是某个 provider 的 settings
* }
* }
* ```
*/
export type AiSdkConfigRuntime = {
providerId: AppProviderId
options: AppProviderSettingsMap[AppProviderId]
}
/**
* Type guard for runtime validation of AiSdkConfig
*
* @param config - Unknown value to validate
* @returns true if config is a valid AiSdkConfigRuntime
*
* @example
* ```ts
* if (isValidAiSdkConfig(someConfig)) {
* // someConfig is now typed as AiSdkConfigRuntime
* await createAiSdkProvider(someConfig)
* }
* ```
*/
export function isValidAiSdkConfig(config: unknown): config is AiSdkConfigRuntime {
if (!config || typeof config !== 'object') return false
const c = config as Record<string, unknown>
return (
typeof c.providerId === 'string' && c.providerId.length > 0 && typeof c.options === 'object' && c.options !== null
)
}
export type ProviderConfig<T extends AppProviderId = AppProviderId> = Omit<AppRuntimeConfig<T>, 'plugins'>
export type { AppProviderId, AppProviderSettingsMap } from './merged'
export { appProviderIds, getAllProviderIds, isRegisteredProviderId } from './merged'