mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-02-16 15:44:45 +08:00
280 lines
10 KiB
TypeScript
280 lines
10 KiB
TypeScript
/**
 * Parameter builder module.
 * Builds the streaming and non-streaming parameters for the AI SDK.
 */
|
|
|
|
import { anthropic } from '@ai-sdk/anthropic'
|
|
import { azure } from '@ai-sdk/azure'
|
|
import { google } from '@ai-sdk/google'
|
|
import { vertexAnthropic } from '@ai-sdk/google-vertex/anthropic/edge'
|
|
import { vertex } from '@ai-sdk/google-vertex/edge'
|
|
import { combineHeaders } from '@ai-sdk/provider-utils'
|
|
import type { AnthropicSearchConfig, WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
|
|
import { isBaseProvider } from '@cherrystudio/ai-core/core/providers/schemas'
|
|
import type { BaseProviderId } from '@cherrystudio/ai-core/provider'
|
|
import { loggerService } from '@logger'
|
|
import {
|
|
isAnthropicModel,
|
|
isFixedReasoningModel,
|
|
isGeminiModel,
|
|
isGenerateImageModel,
|
|
isGrokModel,
|
|
isOpenAIModel,
|
|
isOpenRouterBuiltInWebSearchModel,
|
|
isPureGenerateImageModel,
|
|
isSupportedReasoningEffortModel,
|
|
isSupportedThinkingTokenModel,
|
|
isWebSearchModel
|
|
} from '@renderer/config/models'
|
|
import { getHubModeSystemPrompt } from '@renderer/config/prompts-code-mode'
|
|
import { fetchAllActiveServerTools } from '@renderer/services/ApiService'
|
|
import { getDefaultModel } from '@renderer/services/AssistantService'
|
|
import store from '@renderer/store'
|
|
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
|
|
import type { Model } from '@renderer/types'
|
|
import { type Assistant, getEffectiveMcpMode, type MCPTool, type Provider, SystemProviderIds } from '@renderer/types'
|
|
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
|
import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern'
|
|
import { replacePromptVariables } from '@renderer/utils/prompt'
|
|
import { isAIGatewayProvider, isAwsBedrockProvider, isSupportUrlContextProvider } from '@renderer/utils/provider'
|
|
import type { ModelMessage, Tool } from 'ai'
|
|
import { stepCountIs } from 'ai'
|
|
|
|
import { getAiSdkProviderId } from '../provider/factory'
|
|
import { setupToolsConfig } from '../utils/mcp'
|
|
import { buildProviderOptions } from '../utils/options'
|
|
import { buildProviderBuiltinWebSearchConfig } from '../utils/websearch'
|
|
import { addAnthropicHeaders } from './header'
|
|
import { getMaxTokens, getTemperature, getTopP } from './modelParameters'
|
|
|
|
/** Logger scoped to this module via the `parameterBuilder` context. */
const logger = loggerService.withContext('parameterBuilder')

/** The provider-defined variant (`type: 'provider'`) of the AI SDK `Tool` union. */
type ProviderDefinedTool = Extract<Tool<any, any>, { type: 'provider' }>
|
|
|
|
/**
 * Picks the base provider id to use for built-in web search when a model is
 * served through the AI Gateway.
 *
 * The model-family predicates are tried in a fixed order (Anthropic, Gemini,
 * Grok, OpenAI) and the first one that matches wins.
 *
 * @param model - Model being requested through the gateway.
 * @returns The matching base provider id, or `undefined` (after logging a
 * warning) when none of the predicates match.
 */
function mapVertexAIGatewayModelToProviderId(model: Model): BaseProviderId | undefined {
  // Order matters: it mirrors the precedence of the family checks.
  const familyChecks: Array<[(candidate: Model) => boolean, BaseProviderId]> = [
    [isAnthropicModel, 'anthropic'],
    [isGeminiModel, 'google'],
    [isGrokModel, 'xai'],
    [isOpenAIModel, 'openai']
  ]

  for (const [matchesFamily, providerId] of familyChecks) {
    if (matchesFamily(model)) {
      return providerId
    }
  }

  logger.warn(`Unknown model type for AI Gateway: ${model.id}. Web search will not be enabled.`)
  return undefined
}
|
|
|
|
/**
 * Builds the parameters for an AI SDK `streamText` call.
 *
 * This is the main parameter builder. It:
 * 1. derives the capability flags (reasoning, web search, URL context, image generation),
 * 2. assembles the tool set (MCP tools plus provider-defined web search / URL context tools),
 * 3. merges request headers (adding `anthropic-beta` when applicable),
 * 4. composes the system prompt (assistant prompt plus the hub-mode prompt in `auto` MCP mode).
 *
 * @param sdkMessages - Messages already in AI SDK format; defaults to an empty list.
 * @param assistant - Assistant whose model, settings and prompt drive the request.
 *   When it has no model, the default model is used.
 * @param provider - Provider the request will be sent to.
 * @param options - Per-request extras. `mcpTools`, `webSearchProviderId`,
 *   `requestOptions.signal` and `requestOptions.headers` are read here.
 *   `webSearchConfig` and `requestOptions.timeout` are declared but not read by this
 *   function; the web search settings are taken from the store instead.
 * @returns The `streamText` params, the resolved model id, the capability flags, and
 *   the built-in web search plugin config when one was built.
 */
export async function buildStreamTextParams(
  sdkMessages: StreamTextParams['messages'] = [],
  assistant: Assistant,
  provider: Provider,
  options: {
    mcpTools?: MCPTool[]
    webSearchProviderId?: string
    webSearchConfig?: CherryWebSearchConfig
    requestOptions?: {
      signal?: AbortSignal
      timeout?: number
      headers?: Record<string, string>
    }
  }
): Promise<{
  params: StreamTextParams
  modelId: string
  capabilities: {
    enableReasoning: boolean
    enableWebSearch: boolean
    enableGenerateImage: boolean
    enableUrlContext: boolean
  }
  webSearchPluginConfig?: WebSearchPluginConfig
}> {
  const { mcpTools } = options

  const model = assistant.model || getDefaultModel()
  const aiSdkProviderId = getAiSdkProviderId(provider)

  // These flags are passed through to the caller so it can enable the matching
  // plugins/middleware. They could alternatively be computed outside and passed in.
  // FIXME: for qwen3, enableReasoning still ends up true even when thinking is turned off.
  const enableReasoning =
    ((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) &&
      assistant.settings?.reasoning_effort !== undefined) ||
    isFixedReasoningModel(model)

  // Decide whether to use the provider's built-in search.
  // Condition: no external search provider &&
  //            (user enabled built-in search || the model forces built-in search)
  const hasExternalSearch = !!options.webSearchProviderId
  const enableWebSearch =
    !hasExternalSearch &&
    ((assistant.enableWebSearch && isWebSearchModel(model)) ||
      isOpenRouterBuiltInWebSearchModel(model) ||
      model.id.includes('sonar'))

  // Validate provider and model support to prevent stale state from triggering urlContext
  const enableUrlContext = !!(
    assistant.enableUrlContext &&
    isSupportUrlContextProvider(provider) &&
    !isPureGenerateImageModel(model) &&
    (isGeminiModel(model) || isAnthropicModel(model))
  )

  const enableGenerateImage = !!(isGenerateImageModel(model) && assistant.enableGenerateImage)

  let tools = setupToolsConfig(mcpTools)

  // Web search settings come from the store; read the slice once.
  const websearchState = store.getState().websearch
  const webSearchConfig: CherryWebSearchConfig = {
    maxResults: websearchState.maxResults,
    excludeDomains: websearchState.excludeDomains,
    searchWithTime: websearchState.searchWithTime
  }

  // Build the actual providerOptions.
  const { providerOptions, standardParams } = buildProviderOptions(assistant, model, provider, {
    enableReasoning,
    enableWebSearch,
    enableGenerateImage
  })

  let webSearchPluginConfig: WebSearchPluginConfig | undefined = undefined
  if (enableWebSearch) {
    if (isBaseProvider(aiSdkProviderId)) {
      webSearchPluginConfig = buildProviderBuiltinWebSearchConfig(aiSdkProviderId, webSearchConfig, model)
    } else if (isAIGatewayProvider(provider) || SystemProviderIds.gateway === provider.id) {
      // Named distinctly so it does not shadow the outer `aiSdkProviderId`,
      // which the tool selection below still relies on.
      const gatewayProviderId = mapVertexAIGatewayModelToProviderId(model)
      if (gatewayProviderId) {
        webSearchPluginConfig = buildProviderBuiltinWebSearchConfig(gatewayProviderId, webSearchConfig, model)
      }
    }

    if (!tools) {
      tools = {}
    }

    // Some providers expose web search as a provider-defined tool; register it here.
    if (aiSdkProviderId === 'google-vertex') {
      tools.google_search = vertex.tools.googleSearch({}) as ProviderDefinedTool
    } else if (aiSdkProviderId === 'google-vertex-anthropic') {
      const blockedDomains = mapRegexToPatterns(webSearchConfig.excludeDomains)
      tools.web_search = vertexAnthropic.tools.webSearch_20250305({
        maxUses: webSearchConfig.maxResults,
        blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
      }) as ProviderDefinedTool
    } else if (aiSdkProviderId === 'azure-responses') {
      tools.web_search_preview = azure.tools.webSearchPreview({
        // Optional-chain through `openai` as well: the plugin config may be undefined or
        // may lack an `openai` section, and a non-null assertion would throw at runtime.
        searchContextSize: webSearchPluginConfig?.openai?.searchContextSize
      }) as ProviderDefinedTool
    } else if (aiSdkProviderId === 'azure-anthropic') {
      const blockedDomains = mapRegexToPatterns(webSearchConfig.excludeDomains)
      const anthropicSearchOptions: AnthropicSearchConfig = {
        maxUses: webSearchConfig.maxResults,
        blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
      }
      tools.web_search = anthropic.tools.webSearch_20250305(anthropicSearchOptions) as ProviderDefinedTool
    }
  }

  if (enableUrlContext) {
    if (!tools) {
      tools = {}
    }
    const blockedDomains = mapRegexToPatterns(webSearchConfig.excludeDomains)

    switch (aiSdkProviderId) {
      case 'google-vertex':
        tools.url_context = vertex.tools.urlContext({}) as ProviderDefinedTool
        break
      case 'google':
        tools.url_context = google.tools.urlContext({}) as ProviderDefinedTool
        break
      case 'anthropic':
      case 'azure-anthropic':
        tools.web_fetch = anthropic.tools.webFetch_20250910({
          maxUses: webSearchConfig.maxResults,
          blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
        }) as ProviderDefinedTool
        break
      case 'google-vertex-anthropic':
        // No URL context tool is registered for this provider id.
        break
    }
  }

  let headers: Record<string, string | undefined> = options.requestOptions?.headers ?? {}

  if (isAnthropicModel(model) && !isAwsBedrockProvider(provider)) {
    const betaHeaders = addAnthropicHeaders(assistant, model)
    // Only add the anthropic-beta header if there are actual beta headers to include
    if (betaHeaders.length > 0) {
      const newBetaHeaders = { 'anthropic-beta': betaHeaders.join(',') }
      headers = combineHeaders(headers, newBetaHeaders)
    }
  }

  // Build the base parameters.
  // Note: standardParams (topK, frequencyPenalty, presencePenalty, stopSequences, seed)
  // are extracted from custom parameters and passed directly to streamText()
  // instead of being placed in providerOptions
  const params: StreamTextParams = {
    messages: sdkMessages,
    maxOutputTokens: getMaxTokens(assistant, model),
    temperature: getTemperature(assistant, model),
    topP: getTopP(assistant, model),
    // Include AI SDK standard params extracted from custom parameters
    ...standardParams,
    abortSignal: options.requestOptions?.signal,
    headers,
    providerOptions,
    // Stop condition: step count reaches 20.
    stopWhen: stepCountIs(20),
    maxRetries: 0
  }

  if (tools) {
    params.tools = tools
  }

  let systemPrompt = assistant.prompt ? await replacePromptVariables(assistant.prompt, model.name) : ''

  // In `auto` MCP mode, append the hub-mode prompt built from all active server tools.
  if (getEffectiveMcpMode(assistant) === 'auto') {
    const allActiveTools = await fetchAllActiveServerTools()
    const autoModePrompt = getHubModeSystemPrompt(allActiveTools)
    if (autoModePrompt) {
      systemPrompt = systemPrompt ? `${systemPrompt}\n\n${autoModePrompt}` : autoModePrompt
    }
  }

  // Leave `system` unset when there is no prompt text at all.
  if (systemPrompt) {
    params.system = systemPrompt
  }

  logger.debug('params', params)

  return {
    params,
    modelId: model.id,
    capabilities: { enableReasoning, enableWebSearch, enableGenerateImage, enableUrlContext },
    webSearchPluginConfig
  }
}
|
|
|
|
/**
 * Builds the parameters for a non-streaming `generateText` call.
 *
 * Delegates to `buildStreamTextParams` with the same arguments and resolves to
 * whatever that function resolves to.
 *
 * @param messages - Messages in AI SDK `ModelMessage` format.
 * @param assistant - Assistant the request is built for.
 * @param provider - Provider the request will be sent to.
 * @param options - Optional extras, forwarded unchanged; defaults to `{}`.
 *   `enableTools` is declared but not read in this function.
 * @returns The result of `buildStreamTextParams` (typed as `any`).
 */
export async function buildGenerateTextParams(
  messages: ModelMessage[],
  assistant: Assistant,
  provider: Provider,
  options: {
    mcpTools?: MCPTool[]
    enableTools?: boolean
  } = {}
): Promise<any> {
  // Reuse the streaming parameter builder.
  const built = await buildStreamTextParams(messages, assistant, provider, options)
  return built
}
|