fix: Improve provider config type safety and ensure required fields (#12589)

This commit is contained in:
Phantom 2026-01-26 14:35:46 +08:00 committed by GitHub
parent 5366110ce1
commit 0255cb8443
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -11,6 +11,7 @@ import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useV
import { getProviderByModel } from '@renderer/services/AssistantService'
import { getProviderById } from '@renderer/services/ProviderService'
import store from '@renderer/store'
import type { EndpointType } from '@renderer/types'
import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import type { OpenAICompletionsStreamOptions } from '@renderer/types/aiCoreTypes'
import {
@ -140,6 +141,48 @@ export function adaptProvider({ provider, model }: { provider: Provider; model?:
return adaptedProvider
}
interface BaseExtraOptions {
fetch?: typeof fetch
endpoint: string
mode?: 'responses' | 'chat'
headers: Record<string, string>
}
interface AzureOpenAIExtraOptions extends BaseExtraOptions {
apiVersion: string
useDeploymentBasedUrls: true | undefined
}
interface BedrockApiKeyExtraOptions extends BaseExtraOptions {
region: string
apiKey: string
}
interface BedrockAccessKeyExtraOptions extends BaseExtraOptions {
region: string
accessKeyId: string
secretAccessKey: string
}
type BedrockExtraOptions = BedrockApiKeyExtraOptions | BedrockAccessKeyExtraOptions
interface VertexExtraOptions extends BaseExtraOptions {
project: string
location: string
googleCredentials: {
privateKey: string
clientEmail: string
}
}
interface CherryInExtraOptions extends BaseExtraOptions {
endpointType?: EndpointType
anthropicBaseURL?: string
geminiBaseURL?: string
}
type ExtraOptions = BedrockExtraOptions | AzureOpenAIExtraOptions | VertexExtraOptions | CherryInExtraOptions
/**
* Provider AI SDK
*
@ -158,6 +201,8 @@ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): A
includeUsage = store.getState().settings.openAI?.streamOptions?.includeUsage
}
// Specially, some providers which need to early return
// Copilot
const isCopilotProvider = actualProvider.id === SystemProviderIds.copilot
if (isCopilotProvider) {
const storedHeaders = store.getState().copilot.defaultHeaders ?? {}
@ -177,6 +222,7 @@ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): A
}
}
// Ollama
if (isOllamaProvider(actualProvider)) {
return {
providerId: 'ollama',
@ -190,106 +236,142 @@ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): A
}
}
// 处理OpenAI模式
const extraOptions: any = {}
extraOptions.endpoint = endpoint
if (actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && actualProvider.type === 'openai')) {
extraOptions.mode = 'chat'
// Generally, construct extraOptions according to provider & model
// Consider as OpenAI like provider
// Construct baseExtraOptions first
// About mode of azure:
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/responses?tabs=python-key#responses-api
let mode: BaseExtraOptions['mode']
if (
(actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) ||
aiSdkProviderId === 'azure-responses'
) {
mode = 'responses'
} else if (
aiSdkProviderId === 'openai' ||
(aiSdkProviderId === 'cherryin' && actualProvider.type === 'openai') ||
aiSdkProviderId === 'azure'
) {
mode = 'chat'
}
extraOptions.headers = {
const headers: BaseExtraOptions['headers'] = {
...defaultAppHeaders(),
...actualProvider.extra_headers
}
if (aiSdkProviderId === 'openai') {
extraOptions.headers['X-Api-Key'] = baseConfig.apiKey
}
// azure
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/responses?tabs=python-key#responses-api
if (aiSdkProviderId === 'azure-responses') {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'azure') {
extraOptions.mode = 'chat'
}
if (isAzureOpenAIProvider(actualProvider)) {
const apiVersion = actualProvider.apiVersion?.trim()
if (apiVersion) {
extraOptions.apiVersion = apiVersion
if (!['preview', 'v1'].includes(apiVersion)) {
extraOptions.useDeploymentBasedUrls = true
}
if (actualProvider.extra_headers?.['X-Api-Key'] === undefined) {
headers['X-Api-Key'] = baseConfig.apiKey
}
}
// bedrock
if (aiSdkProviderId === 'bedrock') {
const authType = getAwsBedrockAuthType()
extraOptions.region = getAwsBedrockRegion()
if (authType === 'apiKey') {
extraOptions.apiKey = getAwsBedrockApiKey()
} else {
extraOptions.accessKeyId = getAwsBedrockAccessKeyId()
extraOptions.secretAccessKey = getAwsBedrockSecretAccessKey()
}
}
// google-vertex
if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') {
if (!isVertexAIConfigured()) {
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
}
const { project, location, googleCredentials } = createVertexProvider(actualProvider)
extraOptions.project = project
extraOptions.location = location
extraOptions.googleCredentials = {
...googleCredentials,
privateKey: formatPrivateKey(googleCredentials.privateKey)
}
baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
}
// cherryin
if (aiSdkProviderId === 'cherryin') {
if (model.endpoint_type) {
extraOptions.endpointType = model.endpoint_type
}
// CherryIN API Host
const cherryinProvider = getProviderById(SystemProviderIds.cherryin)
if (cherryinProvider) {
extraOptions.anthropicBaseURL = cherryinProvider.anthropicApiHost + '/v1'
extraOptions.geminiBaseURL = cherryinProvider.apiHost + '/v1beta/models'
}
}
let _fetch: typeof fetch | undefined
// Apply developer-to-system role conversion for providers that don't support developer role
// bug: https://github.com/vercel/ai/issues/10982
// fixPR: https://github.com/vercel/ai/pull/11127
// TODO: but the PR don't backport to v5, the code will be removed when upgrading to v6
if (!isSupportDeveloperRoleProvider(actualProvider) || !isOpenAIReasoningModel(model)) {
extraOptions.fetch = createDeveloperToSystemFetch(extraOptions.fetch)
_fetch = createDeveloperToSystemFetch(fetch)
}
const baseExtraOptions = {
fetch: _fetch,
endpoint,
mode,
headers
} as const satisfies BaseExtraOptions
// Create specifical fields in extraOptions for different provider
let extraOptions: ExtraOptions | undefined
if (isAzureOpenAIProvider(actualProvider)) {
const apiVersion = actualProvider.apiVersion?.trim()
let useDeploymentBasedUrls: true | undefined
if (apiVersion) {
if (!['preview', 'v1'].includes(apiVersion)) {
useDeploymentBasedUrls = true
}
}
extraOptions = {
...baseExtraOptions,
apiVersion,
useDeploymentBasedUrls
} satisfies AzureOpenAIExtraOptions
} else if (aiSdkProviderId === 'bedrock') {
// bedrock
const authType = getAwsBedrockAuthType()
const region = getAwsBedrockRegion()
if (authType === 'apiKey') {
extraOptions = {
...baseExtraOptions,
region,
apiKey: getAwsBedrockApiKey()
} satisfies BedrockApiKeyExtraOptions
} else {
extraOptions = {
...baseExtraOptions,
region,
accessKeyId: getAwsBedrockAccessKeyId(),
secretAccessKey: getAwsBedrockSecretAccessKey()
} satisfies BedrockAccessKeyExtraOptions
}
} else if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') {
// google-vertex
if (!isVertexAIConfigured()) {
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
}
const { project, location, googleCredentials } = createVertexProvider(actualProvider)
extraOptions = {
...baseExtraOptions,
project,
location,
googleCredentials: {
...googleCredentials,
privateKey: formatPrivateKey(googleCredentials.privateKey)
}
} satisfies VertexExtraOptions
baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
} else if (aiSdkProviderId === 'cherryin') {
// CherryIN API Host
const cherryinProvider = getProviderById(SystemProviderIds.cherryin)
const endpointType: EndpointType | undefined = model.endpoint_type
let anthropicBaseURL: string | undefined
let geminiBaseURL: string | undefined
if (cherryinProvider) {
anthropicBaseURL = cherryinProvider.anthropicApiHost + '/v1'
geminiBaseURL = cherryinProvider.apiHost + '/v1beta/models'
}
extraOptions = {
...baseExtraOptions,
endpointType,
anthropicBaseURL,
geminiBaseURL
} satisfies CherryInExtraOptions
} else {
extraOptions = baseExtraOptions
}
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
// if the provider has a specific aisdk provider
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
return {
providerId: aiSdkProviderId,
options
}
}
// 否则fallback到openai-compatible
const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey)
return {
providerId: 'openai-compatible',
options: {
...options,
name: actualProvider.id,
...extraOptions,
includeUsage
} else {
// otherwise, fallback to openai-compatible
const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey)
return {
providerId: 'openai-compatible',
options: {
...options,
name: actualProvider.id,
...extraOptions,
includeUsage
}
}
}
}