refactor: add ai-sdk embedFunc and remove legacy

This commit is contained in:
suyao 2026-01-02 08:59:27 +08:00
parent 53a2e06120
commit f7ee2fc934
No known key found for this signature in database
61 changed files with 262 additions and 11115 deletions

View File

@ -5,6 +5,7 @@
import type { ImageModelV3, LanguageModelV3, LanguageModelV3Middleware, ProviderV3 } from '@ai-sdk/provider'
import type { LanguageModel } from 'ai'
import {
embedMany as _embedMany,
generateImage as _generateImage,
generateText as _generateText,
streamText as _streamText,
@ -17,7 +18,14 @@ import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
import type { CoreProviderSettingsMap, StringKeys } from '../providers/types'
import { ImageGenerationError, ImageModelResolutionError } from './errors'
import { PluginEngine } from './pluginEngine'
import type { generateImageParams, generateTextParams, RuntimeConfig, streamTextParams } from './types'
import type {
EmbedManyParams,
EmbedManyResult,
generateImageParams,
generateTextParams,
RuntimeConfig,
streamTextParams
} from './types'
export class RuntimeExecutor<
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
@ -166,6 +174,23 @@ export class RuntimeExecutor<
}
}
/**
*
* Embed multiple values in one call.
* AI SDK v6 only has embedMany, no single embed.
*/
async embedMany(params: EmbedManyParams): Promise<EmbedManyResult> {
const { model: modelOrId, ...options } = params
// Resolve the embedding model
const embeddingModel =
typeof modelOrId === 'string' ? await this.modelResolver.resolveEmbeddingModel(modelOrId) : modelOrId
return _embedMany({
model: embeddingModel,
...options
})
}
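A quick usage sketch of the new method; the provider id, settings and model id below are illustrative placeholders, not part of this change:
const executor = await createExecutor('openai', { apiKey: 'sk-...' }, [])
const { embeddings } = await executor.embedMany({
  model: 'text-embedding-3-small',
  values: ['hello', 'world']
})
// embeddings.length === 2 - one vector per input value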
// === Helper methods ===
/**

View File

@ -7,7 +7,7 @@
export { RuntimeExecutor } from './executor'
// Export types
export type { RuntimeConfig } from './types'
export type { EmbedManyParams, EmbedManyResult, RuntimeConfig } from './types'
// === Convenience factory functions ===
@ -84,6 +84,23 @@ export async function generateImage<
return executor.generateImage(params)
}
/**
*
* Convenience function for batch embedding.
* AI SDK v6 only has embedMany, no single embed.
*/
export async function embedMany<
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
T extends StringKeys<TSettingsMap> = StringKeys<TSettingsMap>
>(
providerId: T,
options: TSettingsMap[T],
params: Parameters<RuntimeExecutor<TSettingsMap, T>['embedMany']>[0],
plugins?: AiPlugin[]
): Promise<ReturnType<RuntimeExecutor<TSettingsMap, T>['embedMany']>> {
const executor = await createExecutor<TSettingsMap, T>(providerId, options, plugins)
return executor.embedMany(params)
}
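The standalone wrapper can be called without creating an executor first; again the provider id, settings and model id are placeholders:
const { embeddings } = await embedMany(
  'openai',
  { apiKey: 'sk-...' },
  { model: 'text-embedding-3-small', values: ['hello', 'world'] }
)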
/**
* OpenAI Compatible
*/

View File

@ -1,8 +1,8 @@
/**
* Runtime type definitions
*/
import type { ImageModelV3, ProviderV3 } from '@ai-sdk/provider'
import type { generateImage, generateText, streamText } from 'ai'
import type { EmbeddingModelV3, ImageModelV3, ProviderV3 } from '@ai-sdk/provider'
import type { embedMany, generateImage, generateText, streamText } from 'ai'
import { type AiPlugin } from '../plugins'
import type { CoreProviderSettingsMap, StringKeys } from '../providers/types'
@ -28,3 +28,9 @@ export type generateImageParams = Omit<Parameters<typeof generateImage>[0], 'mod
}
export type generateTextParams = Parameters<typeof generateText>[0]
export type streamTextParams = Parameters<typeof streamText>[0]
// Embedding types (AI SDK v6 only has embedMany, no embed)
export type EmbedManyParams = Omit<Parameters<typeof embedMany>[0], 'model'> & {
model: string | EmbeddingModelV3
}
export type EmbedManyResult = Awaited<ReturnType<typeof embedMany>>

View File

@ -9,11 +9,15 @@
export {
createExecutor,
createOpenAICompatibleExecutor,
embedMany,
generateImage,
generateText,
streamText
} from './core/runtime'
// ==================== Embedding types ====================
export type { EmbedManyParams, EmbedManyResult } from './core/runtime'
// ==================== High-level API ====================
export { isV2Model, isV3Model } from './core/models'

View File

@ -1,16 +1,12 @@
/**
* Cherry Studio AI Core -
*
*
* The legacy AiProvider used to be exported here to keep existing code compatible.
* Use ModernAiProvider instead; the legacy provider has been removed.
*/
// Export the legacy AiProvider as the default export for backward compatibility
export { default } from './legacy/index'
import ModernAiProvider from './index_new'
// Also export the modern AiProvider for new code
export { default as ModernAiProvider } from './index_new'
// Export some commonly used types and utilities
export * from './legacy/clients/types'
export * from './legacy/middleware/schemas'
// Default export and named export
export default ModernAiProvider
export { ModernAiProvider }

View File

@ -27,12 +27,10 @@ import { buildClaudeCodeSystemModelMessage } from '@shared/anthropic'
import { gateway } from 'ai'
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
import LegacyAiProvider from './legacy/index'
import type { CompletionsParams, CompletionsResult } from './legacy/middleware/schemas'
import { buildPlugins } from './plugins/PluginBuilder'
import { adaptProvider, getActualProvider, providerToAiSdkConfig } from './provider/providerConfig'
import { ModelListService } from './services/ModelListService'
import type { AppProviderSettingsMap, ProviderConfig } from './types'
import type { AppProviderSettingsMap, CompletionsResult, ProviderConfig } from './types'
import type { AiSdkMiddlewareConfig } from './types/middlewareConfig'
const logger = loggerService.withContext('ModernAiProvider')
@ -45,7 +43,6 @@ export type ModernAiProviderConfig = AiSdkMiddlewareConfig & {
}
export default class ModernAiProvider {
private legacyProvider: LegacyAiProvider
private config?: ProviderConfig
private actualProvider: Provider
private model?: Model
@ -101,8 +98,6 @@ export default class ModernAiProvider {
this.actualProvider = adaptProvider({ provider: modelOrProvider })
// model is optional; some operations such as fetchModels do not need it
}
this.legacyProvider = new LegacyAiProvider(this.actualProvider)
}
/**
@ -169,28 +164,8 @@ export default class ModernAiProvider {
middlewareConfig: ModernAiProviderConfig,
providerConfig: ProviderConfig
): Promise<CompletionsResult> {
// ai-gateway does not support the image/generation endpoint, so skip the legacy path for it for now
if (middlewareConfig.isImageGenerationEndpoint && this.getActualProvider().id !== SystemProviderIds.gateway) {
// Use the legacy implementation for image generation (supports image editing and other advanced features)
if (!middlewareConfig.uiMessages) {
throw new Error('uiMessages is required for image generation endpoint')
}
const legacyParams: CompletionsParams = {
callType: 'chat',
messages: middlewareConfig.uiMessages, // use the original UI message format
assistant: middlewareConfig.assistant,
streamOutput: middlewareConfig.streamOutput ?? true,
onChunk: middlewareConfig.onChunk,
topicId: middlewareConfig.topicId,
mcpTools: middlewareConfig.mcpTools,
enableWebSearch: middlewareConfig.enableWebSearch
}
// Call the legacy completions, which automatically uses ImageGenerationMiddleware
return await this.legacyProvider.completions(legacyParams)
}
// Dedicated image-generation models are already routed to fetchImageGeneration at the ApiService layer
// Only plain completions are handled here
return await this.modernCompletions(modelId, params, middlewareConfig, providerConfig)
}
@ -350,8 +325,29 @@ export default class ModernAiProvider {
return await ModelListService.listModels(this.actualProvider)
}
/**
*
* Get the embedding dimensions of a model.
* Uses AI SDK embedMany under the hood.
*/
public async getEmbeddingDimensions(model: Model): Promise<number> {
return this.legacyProvider.getEmbeddingDimensions(model)
// Make sure config is initialized
if (!this.config) {
this.config = await Promise.resolve(providerToAiSdkConfig(this.actualProvider, model))
}
const executor = await createExecutor<AppProviderSettingsMap>(
this.config.providerId,
this.config.providerSettings,
[]
)
// Use AI SDK embedMany with a test value to obtain the dimensions
const result = await executor.embedMany({
model: model.id,
values: ['test']
})
return result.embeddings[0].length
}
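For illustration, assuming a Provider and an embedding Model as used elsewhere in the app (both hypothetical here), the new implementation embeds a single test value and reads the vector length:
const aiProvider = new ModernAiProvider(provider)
const dimensions = await aiProvider.getEmbeddingDimensions(embeddingModel)
// e.g. 1536 or 3072 depending on the embedding model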
/**
@ -453,10 +449,42 @@ export default class ModernAiProvider {
}
public getBaseURL(): string {
return this.legacyProvider.getBaseURL()
return this.actualProvider.apiHost || ''
}
public getApiKey(): string {
return this.legacyProvider.getApiKey()
const apiKey = this.actualProvider.apiKey
if (!apiKey || apiKey.trim() === '') {
return ''
}
const keys = apiKey
.split(',')
.map((key) => key.trim())
.filter(Boolean)
if (keys.length === 0) {
return ''
}
if (keys.length === 1) {
return keys[0]
}
// Multi-key rotation
const keyName = `provider:${this.actualProvider.id}:last_used_key`
const lastUsedKey = window.keyv.get(keyName)
if (!lastUsedKey) {
window.keyv.set(keyName, keys[0])
return keys[0]
}
const currentIndex = keys.indexOf(lastUsedKey)
const nextIndex = (currentIndex + 1) % keys.length
const nextKey = keys[nextIndex]
window.keyv.set(keyName, nextKey)
return nextKey
}
}
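With several comma-separated keys, successive getApiKey() calls rotate round-robin. A sketch assuming apiKey is 'k1,k2,k3' and no key has been used yet:
aiProvider.getApiKey() // 'k1' - stored in window.keyv as the last used key
aiProvider.getApiKey() // 'k2'
aiProvider.getApiKey() // 'k3'
aiProvider.getApiKey() // 'k1' again - the index wraps around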

View File

@ -1,103 +0,0 @@
import { loggerService } from '@logger'
import type { Provider } from '@renderer/types'
import { isNewApiProvider } from '@renderer/utils/provider'
import { AihubmixAPIClient } from './aihubmix/AihubmixAPIClient'
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
import { AwsBedrockAPIClient } from './aws/AwsBedrockAPIClient'
import type { BaseApiClient } from './BaseApiClient'
import { CherryAiAPIClient } from './cherryai/CherryAiAPIClient'
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
import { VertexAPIClient } from './gemini/VertexAPIClient'
import { NewAPIClient } from './newapi/NewAPIClient'
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
import { OVMSClient } from './ovms/OVMSClient'
import { PPIOAPIClient } from './ppio/PPIOAPIClient'
import { ZhipuAPIClient } from './zhipu/ZhipuAPIClient'
const logger = loggerService.withContext('ApiClientFactory')
/**
* Factory for creating ApiClient instances based on provider configuration
*/
export class ApiClientFactory {
/**
* Create an ApiClient instance for the given provider
*/
static create(provider: Provider): BaseApiClient {
logger.debug(`Creating ApiClient for provider:`, {
id: provider.id,
type: provider.type
})
let instance: BaseApiClient
// Check special provider IDs first
if (provider.id === 'cherryai') {
instance = new CherryAiAPIClient(provider) as BaseApiClient
return instance
}
if (provider.id === 'aihubmix') {
logger.debug(`Creating AihubmixAPIClient for provider: ${provider.id}`)
instance = new AihubmixAPIClient(provider) as BaseApiClient
return instance
}
if (isNewApiProvider(provider)) {
logger.debug(`Creating NewAPIClient for provider: ${provider.id}`)
instance = new NewAPIClient(provider) as BaseApiClient
return instance
}
if (provider.id === 'ppio') {
logger.debug(`Creating PPIOAPIClient for provider: ${provider.id}`)
instance = new PPIOAPIClient(provider) as BaseApiClient
return instance
}
if (provider.id === 'zhipu') {
instance = new ZhipuAPIClient(provider) as BaseApiClient
return instance
}
if (provider.id === 'ovms') {
logger.debug(`Creating OVMSClient for provider: ${provider.id}`)
instance = new OVMSClient(provider) as BaseApiClient
return instance
}
// Then check the standard provider type
switch (provider.type) {
case 'openai':
instance = new OpenAIAPIClient(provider) as BaseApiClient
break
case 'azure-openai':
case 'openai-response':
instance = new OpenAIResponseAPIClient(provider) as BaseApiClient
break
case 'gemini':
instance = new GeminiAPIClient(provider) as BaseApiClient
break
case 'vertexai':
logger.debug(`Creating VertexAPIClient for provider: ${provider.id}`)
instance = new VertexAPIClient(provider) as BaseApiClient
break
case 'anthropic':
instance = new AnthropicAPIClient(provider) as BaseApiClient
break
case 'aws-bedrock':
instance = new AwsBedrockAPIClient(provider) as BaseApiClient
break
default:
logger.debug(`Using default OpenAIApiClient for provider: ${provider.id}`)
instance = new OpenAIAPIClient(provider) as BaseApiClient
break
}
return instance
}
}
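A minimal sketch of how the factory was used; the provider fields mirror the minimal shape from the removed factory tests and are otherwise illustrative:
const provider: Provider = { id: 'anthropic', type: 'anthropic', name: '', apiKey: '', apiHost: '', models: [] }
const client = ApiClientFactory.create(provider) // AnthropicAPIClient
// Unknown types fall through the switch and get an OpenAIAPIClient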

View File

@ -1,489 +0,0 @@
import { loggerService } from '@logger'
import {
getModelSupportedVerbosity,
isFunctionCallingModel,
isOpenAIModel,
isSupportFlexServiceTierModel,
isSupportTemperatureModel,
isSupportTopPModel
} from '@renderer/config/models'
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import type { RootState } from '@renderer/store'
import type {
Assistant,
GenerateImageParams,
KnowledgeReference,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
MemoryItem,
Model,
Provider,
ToolCallResponse,
WebSearchProviderResponse,
WebSearchResponse
} from '@renderer/types'
import {
FileTypes,
GroqServiceTiers,
isGroqServiceTier,
isOpenAIServiceTier,
OpenAIServiceTiers,
SystemProviderIds
} from '@renderer/types'
import type { OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
import type { Message } from '@renderer/types/newMessage'
import type {
RequestOptions,
SdkInstance,
SdkMessageParam,
SdkModel,
SdkParams,
SdkRawChunk,
SdkRawOutput,
SdkTool,
SdkToolCall
} from '@renderer/types/sdk'
import { isJSON, parseJSON } from '@renderer/utils'
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { isSupportServiceTierProvider } from '@renderer/utils/provider'
import { defaultTimeout } from '@shared/config/constant'
import { defaultAppHeaders } from '@shared/utils'
import { isEmpty } from 'lodash'
import type { CompletionsContext } from '../middleware/types'
import type { ApiClient, RequestTransformer, ResponseChunkTransformer } from './types'
const logger = loggerService.withContext('BaseApiClient')
/**
* Abstract base class for API clients.
* Provides common functionality and structure for specific client implementations.
*/
export abstract class BaseApiClient<
TSdkInstance extends SdkInstance = SdkInstance,
TSdkParams extends SdkParams = SdkParams,
TRawOutput extends SdkRawOutput = SdkRawOutput,
TRawChunk extends SdkRawChunk = SdkRawChunk,
TMessageParam extends SdkMessageParam = SdkMessageParam,
TToolCall extends SdkToolCall = SdkToolCall,
TSdkSpecificTool extends SdkTool = SdkTool
> implements ApiClient<TSdkInstance, TSdkParams, TRawOutput, TRawChunk, TMessageParam, TToolCall, TSdkSpecificTool>
{
public provider: Provider
protected host: string
protected sdkInstance?: TSdkInstance
constructor(provider: Provider) {
this.provider = provider
this.host = this.getBaseURL()
}
/**
* Get the current API key with rotation support
* This getter ensures API keys rotate on each access when multiple keys are configured
*/
protected get apiKey(): string {
return this.getApiKey()
}
/**
*
* Return the client's compatibility type(s).
* Works around type narrowing issues with instanceof checks;
* composite clients such as AihubmixAPIClient delegate this to the underlying client.
*/
// oxlint-disable-next-line @typescript-eslint/no-unused-vars
public getClientCompatibilityType(_model?: Model): string[] {
// Return the class name by default
return [this.constructor.name]
}
// // The core completions method - in the middleware architecture this is usually just a placeholder
// abstract completions(params: CompletionsParams, internal?: ProcessingState): Promise<CompletionsResult>
/**
* API Endpoint
**/
abstract createCompletions(payload: TSdkParams, options?: RequestOptions): Promise<TRawOutput>
abstract generateImage(generateImageParams: GenerateImageParams): Promise<string[]>
abstract getEmbeddingDimensions(model?: Model): Promise<number>
abstract listModels(): Promise<SdkModel[]>
abstract getSdkInstance(): Promise<TSdkInstance> | TSdkInstance
/**
*
**/
// Used in CoreRequestToSdkParamsMiddleware
abstract getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
// Used in RawSdkChunkToGenericChunkMiddleware
abstract getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<TRawChunk>
/**
*
**/
// Optional tool conversion methods - implement if needed by the specific provider
abstract convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[]
abstract convertSdkToolCallToMcp(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined
abstract convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse
abstract buildSdkMessages(
currentReqMessages: TMessageParam[],
output: TRawOutput | string | undefined,
toolResults: TMessageParam[],
toolCalls?: TToolCall[]
): TMessageParam[]
abstract estimateMessageTokens(message: TMessageParam): number
abstract convertMcpToolResponseToSdkMessageParam(
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
): TMessageParam | undefined
/**
* Extract the message array from the SDK payload for unified access.
* Different SDKs use different field names such as messages or history.
*/
abstract extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[]
/**
*
**/
public getBaseURL(): string {
return this.provider.apiHost
}
public getApiKey() {
const keys = this.provider.apiKey.split(',').map((key) => key.trim())
const keyName = `provider:${this.provider.id}:last_used_key`
if (keys.length === 1) {
return keys[0]
}
const lastUsedKey = window.keyv.get(keyName)
if (!lastUsedKey) {
window.keyv.set(keyName, keys[0])
return keys[0]
}
const currentIndex = keys.indexOf(lastUsedKey)
const nextIndex = (currentIndex + 1) % keys.length
const nextKey = keys[nextIndex]
window.keyv.set(keyName, nextKey)
return nextKey
}
public defaultHeaders() {
return {
...defaultAppHeaders(),
'X-Api-Key': this.apiKey
}
}
public get keepAliveTime() {
return this.provider.id === 'lmstudio' ? getLMStudioKeepAliveTime() : undefined
}
public getTemperature(assistant: Assistant, model: Model): number | undefined {
if (!isSupportTemperatureModel(model)) {
return undefined
}
const assistantSettings = getAssistantSettings(assistant)
return assistantSettings?.enableTemperature ? assistantSettings?.temperature : undefined
}
public getTopP(assistant: Assistant, model: Model): number | undefined {
if (!isSupportTopPModel(model)) {
return undefined
}
const assistantSettings = getAssistantSettings(assistant)
return assistantSettings?.enableTopP ? assistantSettings?.topP : undefined
}
// NOTE: This could perhaps be moved to OpenAIBaseClient
protected getServiceTier(model: Model) {
const serviceTierSetting = this.provider.serviceTier
if (!isSupportServiceTierProvider(this.provider) || !isOpenAIModel(model) || !serviceTierSetting) {
return undefined
}
// Handle providers that need to fall back to the default value
if (this.provider.id === SystemProviderIds.groq) {
if (
!isGroqServiceTier(serviceTierSetting) ||
(serviceTierSetting === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
}
} else {
// For other OpenAI providers, assume their service tier settings are identical to OpenAI's
if (
!isOpenAIServiceTier(serviceTierSetting) ||
(serviceTierSetting === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
}
}
return serviceTierSetting
}
protected getVerbosity(model?: Model): OpenAIVerbosity {
try {
const state = window.store?.getState() as RootState
const verbosity = state?.settings?.openAI?.verbosity
// If model is provided, check if the verbosity is supported by the model
if (model) {
const supportedVerbosity = getModelSupportedVerbosity(model)
// Use user's verbosity if supported, otherwise use the first supported option
return supportedVerbosity.includes(verbosity) ? verbosity : supportedVerbosity[0]
}
return verbosity
} catch (error) {
logger.warn('Failed to get verbosity from state. Fallback to undefined.', error as Error)
return undefined
}
}
protected getTimeout(model: Model) {
if (isSupportFlexServiceTierModel(model)) {
return 15 * 1000 * 60
}
return defaultTimeout
}
public async getMessageContent(
message: Message
): Promise<{ textContent: string; imageContents: { fileId: string; fileExt: string }[] }> {
const content = getMainTextContent(message)
if (isEmpty(content)) {
return {
textContent: '',
imageContents: []
}
}
const webSearchReferences = await this.getWebSearchReferencesFromCache(message)
const knowledgeReferences = await this.getKnowledgeBaseReferencesFromCache(message)
const memoryReferences = this.getMemoryReferencesFromCache(message)
const knowledgeTextReferences = knowledgeReferences.filter((k) => k.metadata?.type !== 'image')
const knowledgeImageReferences = knowledgeReferences.filter((k) => k.metadata?.type === 'image')
// Add an offset to avoid ID conflicts
const reindexedKnowledgeReferences = knowledgeTextReferences.map((ref) => ({
...ref,
id: ref.id + webSearchReferences.length // offset knowledge-base reference IDs by the number of web search references
}))
const allReferences = [...webSearchReferences, ...reindexedKnowledgeReferences, ...memoryReferences]
logger.debug(`Found ${allReferences.length} references for ID: ${message.id}`, allReferences)
const referenceContent = `\`\`\`json\n${JSON.stringify(allReferences, null, 2)}\n\`\`\``
const imageReferences = knowledgeImageReferences.map((r) => {
return { fileId: r.metadata?.id, fileExt: r.metadata?.ext }
})
return {
textContent: isEmpty(allReferences)
? content
: REFERENCE_PROMPT.replace('{question}', content).replace('{references}', referenceContent),
imageContents: isEmpty(knowledgeImageReferences) ? [] : imageReferences
}
}
/**
* Extract the file content from the message
* @param message - The message
* @returns The file content
*/
protected async extractFileContent(message: Message) {
const fileBlocks = findFileBlocks(message)
if (fileBlocks.length > 0) {
const textFileBlocks = fileBlocks.filter(
(fb) => fb.file && [FileTypes.TEXT, FileTypes.DOCUMENT].includes(fb.file.type)
)
if (textFileBlocks.length > 0) {
let text = ''
const divider = '\n\n---\n\n'
for (const fileBlock of textFileBlocks) {
const file = fileBlock.file
const fileContent = (await window.api.file.read(file.id + file.ext, true)).trim()
const fileNameRow = 'file: ' + file.origin_name + '\n\n'
text = text + fileNameRow + fileContent + divider
}
return text
}
}
return ''
}
private getMemoryReferencesFromCache(message: Message) {
const memories = window.keyv.get(`memory-search-${message.id}`) as MemoryItem[] | undefined
if (memories) {
const memoryReferences: KnowledgeReference[] = memories.map((mem, index) => ({
id: index + 1,
content: `${mem.memory} -- Created at: ${mem.createdAt}`,
sourceUrl: '',
type: 'memory'
}))
return memoryReferences
}
return []
}
private async getWebSearchReferencesFromCache(message: Message) {
const content = getMainTextContent(message)
if (isEmpty(content)) {
return []
}
const webSearch: WebSearchResponse = window.keyv.get(`web-search-${message.id}`)
if (webSearch) {
window.keyv.remove(`web-search-${message.id}`)
return (webSearch.results as WebSearchProviderResponse).results.map(
(result, index) =>
({
id: index + 1,
content: result.content,
sourceUrl: result.url,
type: 'url'
}) as KnowledgeReference
)
}
return []
}
/**
*
*/
private async getKnowledgeBaseReferencesFromCache(message: Message): Promise<KnowledgeReference[]> {
const content = getMainTextContent(message)
if (isEmpty(content)) {
return []
}
const knowledgeReferences: KnowledgeReference[] = window.keyv.get(`knowledge-search-${message.id}`)
if (!isEmpty(knowledgeReferences)) {
window.keyv.remove(`knowledge-search-${message.id}`)
logger.debug(`Found ${knowledgeReferences.length} knowledge base references in cache for ID: ${message.id}`)
return knowledgeReferences
}
logger.debug(`No knowledge base references found in cache for ID: ${message.id}`)
return []
}
protected getCustomParameters(assistant: Assistant) {
return (
assistant?.settings?.customParameters?.reduce((acc, param) => {
if (!param.name?.trim()) {
return acc
}
// Parse JSON type parameters (Legacy API clients)
// Related: src/renderer/src/pages/settings/AssistantSettings/AssistantModelSettings.tsx:133-148
// The UI stores JSON type params as strings, this function parses them before sending to API
if (param.type === 'json') {
const value = param.value as string
if (value === 'undefined') {
return { ...acc, [param.name]: undefined }
}
return { ...acc, [param.name]: isJSON(value) ? parseJSON(value) : value }
}
return {
...acc,
[param.name]: param.value
}
}, {}) || {}
)
}
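For example, given hypothetical assistant custom parameters (one plain value, one JSON-typed string), the reducer above would produce:
// customParameters: [
//   { name: 'top_k', value: 40 },
//   { name: 'response_format', type: 'json', value: '{"type":"json_object"}' }
// ]
// getCustomParameters(assistant) => { top_k: 40, response_format: { type: 'json_object' } }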
public createAbortController(messageId?: string, isAddEventListener?: boolean) {
const abortController = new AbortController()
const abortFn = () => abortController.abort()
if (messageId) {
addAbortController(messageId, abortFn)
}
const cleanup = () => {
if (messageId) {
signalPromise.resolve?.(undefined)
removeAbortController(messageId, abortFn)
}
}
const signalPromise: {
resolve: (value: unknown) => void
promise: Promise<unknown>
} = {
resolve: () => {},
promise: Promise.resolve()
}
if (isAddEventListener) {
signalPromise.promise = new Promise((resolve, reject) => {
signalPromise.resolve = resolve
if (abortController.signal.aborted) {
reject(new Error('Request was aborted.'))
}
// Listen for the abort event; some aborts can only be caught through this listener
abortController.signal.addEventListener('abort', () => {
reject(new Error('Request was aborted.'))
})
})
return {
abortController,
cleanup,
signalPromise
}
}
return {
abortController,
cleanup
}
}
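A rough sketch of how a caller would wire this up; the options shape passed to createCompletions is an assumption here, not dictated by this file:
const { abortController, cleanup } = this.createAbortController('message-123', true)
try {
  await this.createCompletions(payload, { signal: abortController.signal })
} finally {
  cleanup() // resolves the signal promise and deregisters the abort handler
}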
// Setup tools configuration based on provided parameters
public setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
tools: TSdkSpecificTool[]
} {
const { mcpTools, model, enableToolUse } = params
let tools: TSdkSpecificTool[] = []
// If there are no tools, return an empty array
if (!mcpTools?.length) {
return { tools }
}
// If the model supports function calling and tool usage is enabled
if (isFunctionCallingModel(model) && enableToolUse) {
tools = this.convertMcpToolsToSdkTools(mcpTools)
}
return { tools }
}
}

View File

@ -1,181 +0,0 @@
import type {
GenerateImageParams,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
Provider,
ToolCallResponse
} from '@renderer/types'
import type {
RequestOptions,
SdkInstance,
SdkMessageParam,
SdkModel,
SdkParams,
SdkRawChunk,
SdkRawOutput,
SdkTool,
SdkToolCall
} from '@renderer/types/sdk'
import type { CompletionsContext } from '../middleware/types'
import type { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
import { BaseApiClient } from './BaseApiClient'
import type { GeminiAPIClient } from './gemini/GeminiAPIClient'
import type { OpenAIAPIClient } from './openai/OpenAIApiClient'
import type { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
import type { RequestTransformer, ResponseChunkTransformer } from './types'
/**
* MixedBaseAPIClient - base class for mixed providers that route requests across multiple underlying clients
*/
export abstract class MixedBaseAPIClient extends BaseApiClient {
// Use a union type instead of any to keep type safety
protected abstract clients: Map<
string,
AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient
>
protected abstract defaultClient: OpenAIAPIClient
protected abstract currentClient: BaseApiClient
constructor(provider: Provider) {
super(provider)
}
override getBaseURL(): string {
if (!this.currentClient) {
return this.provider.apiHost
}
return this.currentClient.getBaseURL()
}
/**
* Check that the client is a valid BaseApiClient instance
*/
protected isValidClient(client: unknown): client is BaseApiClient {
return (
client !== null &&
client !== undefined &&
typeof client === 'object' &&
'createCompletions' in client &&
'getRequestTransformer' in client &&
'getResponseChunkTransformer' in client
)
}
/**
* Select the underlying client for the given model
*/
protected abstract getClient(model: Model): BaseApiClient
/**
* Select the client for the model and delegate calls to it
*/
public getClientForModel(model: Model): BaseApiClient {
this.currentClient = this.getClient(model)
return this.currentClient
}
/**
* Delegate the compatibility type to the client actually used for the model
*/
public override getClientCompatibilityType(model?: Model): string[] {
if (!model) {
return [this.constructor.name]
}
const actualClient = this.getClient(model)
return actualClient.getClientCompatibilityType(model)
}
/**
* Extract the model ID from the SDK payload
*/
protected extractModelFromPayload(payload: SdkParams): string | null {
// Different SDKs may use different field names
if ('model' in payload && typeof payload.model === 'string') {
return payload.model
}
return null
}
// ============ Abstract methods from BaseApiClient ============
async createCompletions(payload: SdkParams, options?: RequestOptions): Promise<SdkRawOutput> {
// Try to extract the model from the payload to pick a client
const modelId = this.extractModelFromPayload(payload)
if (modelId) {
const modelObj = { id: modelId } as Model
const targetClient = this.getClient(modelObj)
return targetClient.createCompletions(payload, options)
}
// If no model can be extracted from the payload, use the currently selected client
return this.currentClient.createCompletions(payload, options)
}
async generateImage(params: GenerateImageParams): Promise<string[]> {
return this.currentClient.generateImage(params)
}
async getEmbeddingDimensions(model?: Model): Promise<number> {
const client = model ? this.getClient(model) : this.currentClient
return client.getEmbeddingDimensions(model)
}
async listModels(): Promise<SdkModel[]> {
// Could aggregate models from all clients; use the default client for now
return this.defaultClient.listModels()
}
async getSdkInstance(): Promise<SdkInstance> {
return this.currentClient.getSdkInstance()
}
getRequestTransformer(): RequestTransformer<SdkParams, SdkMessageParam> {
return this.currentClient.getRequestTransformer()
}
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<SdkRawChunk> {
return this.currentClient.getResponseChunkTransformer(ctx)
}
convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] {
return this.currentClient.convertMcpToolsToSdkTools(mcpTools)
}
convertSdkToolCallToMcp(toolCall: SdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
return this.currentClient.convertSdkToolCallToMcp(toolCall, mcpTools)
}
convertSdkToolCallToMcpToolResponse(toolCall: SdkToolCall, mcpTool: MCPTool): ToolCallResponse {
return this.currentClient.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
}
buildSdkMessages(
currentReqMessages: SdkMessageParam[],
output: SdkRawOutput | string,
toolResults: SdkMessageParam[],
toolCalls?: SdkToolCall[]
): SdkMessageParam[] {
return this.currentClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
}
estimateMessageTokens(message: SdkMessageParam): number {
return this.currentClient.estimateMessageTokens(message)
}
convertMcpToolResponseToSdkMessageParam(
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
): SdkMessageParam | undefined {
const client = this.getClient(model)
return client.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
}
extractMessagesFromSdkPayload(sdkPayload: SdkParams): SdkMessageParam[] {
return this.currentClient.extractMessagesFromSdkPayload(sdkPayload)
}
}

View File

@ -1,221 +0,0 @@
import type { Provider } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { AihubmixAPIClient } from '../aihubmix/AihubmixAPIClient'
import { AnthropicAPIClient } from '../anthropic/AnthropicAPIClient'
import { ApiClientFactory } from '../ApiClientFactory'
import { AwsBedrockAPIClient } from '../aws/AwsBedrockAPIClient'
import { GeminiAPIClient } from '../gemini/GeminiAPIClient'
import { VertexAPIClient } from '../gemini/VertexAPIClient'
import { NewAPIClient } from '../newapi/NewAPIClient'
import { OpenAIAPIClient } from '../openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from '../openai/OpenAIResponseAPIClient'
import { PPIOAPIClient } from '../ppio/PPIOAPIClient'
// Helper to create a minimal provider for the factory tests
// ApiClientFactory only uses the 'id' and 'type' fields to decide which client to create
// The other fields are passed to the client constructor but do not affect the factory logic
const createTestProvider = (id: string, type: string): Provider => ({
id,
type: type as Provider['type'],
name: '',
apiKey: '',
apiHost: '',
models: []
})
// Mock all client modules
vi.mock('../aihubmix/AihubmixAPIClient', () => ({
AihubmixAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../anthropic/AnthropicAPIClient', () => ({
AnthropicAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../anthropic/AnthropicVertexClient', () => ({
AnthropicVertexClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../gemini/GeminiAPIClient', () => ({
GeminiAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../gemini/VertexAPIClient', () => ({
VertexAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../newapi/NewAPIClient', () => ({
NewAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../openai/OpenAIApiClient', () => ({
OpenAIAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../openai/OpenAIResponseAPIClient', () => ({
OpenAIResponseAPIClient: vi.fn().mockImplementation(() => ({
getClient: vi.fn().mockReturnThis()
}))
}))
vi.mock('../ppio/PPIOAPIClient', () => ({
PPIOAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../aws/AwsBedrockAPIClient', () => ({
AwsBedrockAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('@renderer/services/AssistantService.ts', () => ({
getDefaultAssistant: () => {
return {
id: 'default',
name: 'default',
emoji: '😀',
prompt: '',
topics: [],
messages: [],
type: 'assistant',
regularPhrases: [],
settings: {}
}
}
}))
// Mock the models config to prevent circular dependency issues
vi.mock('@renderer/config/models', () => ({
findTokenLimit: vi.fn(),
isReasoningModel: vi.fn(),
isOpenAILLMModel: vi.fn(),
SYSTEM_MODELS: {
silicon: [],
defaultModel: []
},
isOpenAIModel: vi.fn(() => false),
glm45FlashModel: {},
qwen38bModel: {}
}))
describe('ApiClientFactory', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('create', () => {
// Client creation for special provider IDs
it('should create AihubmixAPIClient for aihubmix provider', () => {
const provider = createTestProvider('aihubmix', 'openai')
const client = ApiClientFactory.create(provider)
expect(AihubmixAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
it('should create NewAPIClient for new-api provider', () => {
const provider = createTestProvider('new-api', 'openai')
const client = ApiClientFactory.create(provider)
expect(NewAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
it('should create PPIOAPIClient for ppio provider', () => {
const provider = createTestProvider('ppio', 'openai')
const client = ApiClientFactory.create(provider)
expect(PPIOAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
// Client creation for standard provider types
it('should create OpenAIAPIClient for openai type', () => {
const provider = createTestProvider('custom-openai', 'openai')
const client = ApiClientFactory.create(provider)
expect(OpenAIAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
it('should create OpenAIResponseAPIClient for azure-openai type', () => {
const provider = createTestProvider('azure-openai', 'azure-openai')
const client = ApiClientFactory.create(provider)
expect(OpenAIResponseAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
it('should create OpenAIResponseAPIClient for openai-response type', () => {
const provider = createTestProvider('response', 'openai-response')
const client = ApiClientFactory.create(provider)
expect(OpenAIResponseAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
it('should create GeminiAPIClient for gemini type', () => {
const provider = createTestProvider('gemini', 'gemini')
const client = ApiClientFactory.create(provider)
expect(GeminiAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
it('should create VertexAPIClient for vertexai type', () => {
const provider = createTestProvider('vertex', 'vertexai')
const client = ApiClientFactory.create(provider)
expect(VertexAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
it('should create AnthropicAPIClient for anthropic type', () => {
const provider = createTestProvider('anthropic', 'anthropic')
const client = ApiClientFactory.create(provider)
expect(AnthropicAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
it('should create AwsBedrockAPIClient for aws-bedrock type', () => {
const provider = createTestProvider('aws-bedrock', 'aws-bedrock')
const client = ApiClientFactory.create(provider)
expect(AwsBedrockAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
// Default case
it('should create OpenAIAPIClient as default for unknown type', () => {
const provider = createTestProvider('unknown', 'unknown-type')
const client = ApiClientFactory.create(provider)
expect(OpenAIAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
// Edge cases
it('should handle provider with minimal configuration', () => {
const provider = createTestProvider('minimal', 'openai')
const client = ApiClientFactory.create(provider)
expect(OpenAIAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
// Special IDs take precedence over the type
it('should prioritize special ID over type', () => {
const provider = createTestProvider('aihubmix', 'anthropic') // even though the type is anthropic
const client = ApiClientFactory.create(provider)
// Should create AihubmixAPIClient rather than AnthropicAPIClient
expect(AihubmixAPIClient).toHaveBeenCalledWith(provider)
expect(AnthropicAPIClient).not.toHaveBeenCalled()
expect(client).toBeDefined()
})
})
})

View File

@ -1,38 +0,0 @@
import { describe, expect, it } from 'vitest'
import { normalizeAzureOpenAIEndpoint } from '../openai/azureOpenAIEndpoint'
describe('normalizeAzureOpenAIEndpoint', () => {
it.each([
{
apiHost: 'https://example.openai.azure.com/openai',
expectedEndpoint: 'https://example.openai.azure.com'
},
{
apiHost: 'https://example.openai.azure.com/openai/',
expectedEndpoint: 'https://example.openai.azure.com'
},
{
apiHost: 'https://example.openai.azure.com/openai/v1',
expectedEndpoint: 'https://example.openai.azure.com'
},
{
apiHost: 'https://example.openai.azure.com/openai/v1/',
expectedEndpoint: 'https://example.openai.azure.com'
},
{
apiHost: 'https://example.openai.azure.com',
expectedEndpoint: 'https://example.openai.azure.com'
},
{
apiHost: 'https://example.openai.azure.com/',
expectedEndpoint: 'https://example.openai.azure.com'
},
{
apiHost: 'https://example.openai.azure.com/OPENAI/V1',
expectedEndpoint: 'https://example.openai.azure.com'
}
])('strips trailing /openai from $apiHost', ({ apiHost, expectedEndpoint }) => {
expect(normalizeAzureOpenAIEndpoint(apiHost)).toBe(expectedEndpoint)
})
})
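A minimal sketch of an implementation consistent with these cases; the real normalizeAzureOpenAIEndpoint may differ:
export function normalizeAzureOpenAIEndpoint(apiHost: string): string {
  // Strip a trailing /openai or /openai/v1 segment (case-insensitive), then any trailing slash
  return apiHost.replace(/\/openai(\/v1)?\/?$/i, '').replace(/\/+$/, '')
}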

View File

@ -1,354 +0,0 @@
import { AihubmixAPIClient } from '@renderer/aiCore/legacy/clients/aihubmix/AihubmixAPIClient'
import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient'
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
import { GeminiAPIClient } from '@renderer/aiCore/legacy/clients/gemini/GeminiAPIClient'
import { VertexAPIClient } from '@renderer/aiCore/legacy/clients/gemini/VertexAPIClient'
import { NewAPIClient } from '@renderer/aiCore/legacy/clients/newapi/NewAPIClient'
import { OpenAIAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIResponseAPIClient'
import type { EndpointType, Model, Provider } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
vi.mock('@renderer/config/models', () => ({
SYSTEM_MODELS: {
defaultModel: [
{ id: 'gpt-4', name: 'GPT-4' },
{ id: 'gpt-4', name: 'GPT-4' },
{ id: 'gpt-4', name: 'GPT-4' }
],
zhipu: [],
silicon: [],
openai: [],
anthropic: [],
gemini: []
},
isOpenAIModel: vi.fn().mockReturnValue(true),
isOpenAILLMModel: vi.fn().mockReturnValue(true),
isOpenAIChatCompletionOnlyModel: vi.fn().mockReturnValue(false),
isAnthropicLLMModel: vi.fn().mockReturnValue(false),
isGeminiLLMModel: vi.fn().mockReturnValue(false),
isSupportedReasoningEffortOpenAIModel: vi.fn().mockReturnValue(false),
isVisionModel: vi.fn().mockReturnValue(false),
isClaudeReasoningModel: vi.fn().mockReturnValue(false),
isReasoningModel: vi.fn().mockReturnValue(false),
isWebSearchModel: vi.fn().mockReturnValue(false),
findTokenLimit: vi.fn().mockReturnValue(4096),
isFunctionCallingModel: vi.fn().mockReturnValue(false),
DEFAULT_MAX_TOKENS: 4096,
qwen38bModel: {},
glm45FlashModel: {}
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn(),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
vi.mock('@renderer/services/FileManager', () => ({
default: class {
static async read() {
return 'test content'
}
static async write() {
return true
}
}
}))
vi.mock('@renderer/services/TokenService', () => ({
estimateTextTokens: vi.fn().mockReturnValue(100)
}))
vi.mock('@logger', () => ({
loggerService: {
withContext: vi.fn().mockReturnValue({
debug: vi.fn(),
info: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
silly: vi.fn()
})
}
}))
// Who on earth thought calling a React Hook in the service layer was a good idea?????????
// Mock additional services and hooks that might be imported
vi.mock('@renderer/hooks/useVertexAI', () => ({
getVertexAILocation: vi.fn().mockReturnValue('us-central1'),
getVertexAIProjectId: vi.fn().mockReturnValue('test-project'),
getVertexAIServiceAccount: vi.fn().mockReturnValue({
privateKey: 'test-key',
clientEmail: 'test@example.com'
}),
isVertexAIConfigured: vi.fn().mockReturnValue(true),
isVertexProvider: vi.fn().mockReturnValue(true)
}))
vi.mock('@renderer/hooks/useSettings', () => ({
getStoreSetting: vi.fn().mockReturnValue({}),
useSettings: vi.fn().mockReturnValue([{}, vi.fn()])
}))
vi.mock('@renderer/store/settings', () => ({
default: {},
settingsSlice: {
name: 'settings',
reducer: vi.fn(),
actions: {}
}
}))
vi.mock('@renderer/utils/abortController', () => ({
addAbortController: vi.fn(),
removeAbortController: vi.fn()
}))
vi.mock('@anthropic-ai/sdk', () => ({
default: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('@anthropic-ai/vertex-sdk', () => ({
default: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('openai', () => ({
default: vi.fn().mockImplementation(() => ({})),
AzureOpenAI: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('@google/generative-ai', () => ({
GoogleGenerativeAI: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('@google-cloud/vertexai', () => ({
VertexAI: vi.fn().mockImplementation(() => ({}))
}))
// Mock the circular dependency between VertexAPIClient and AnthropicVertexClient
vi.mock('@renderer/aiCore/legacy/clients/anthropic/AnthropicVertexClient', () => {
const MockAnthropicVertexClient = vi.fn()
MockAnthropicVertexClient.prototype.getClientCompatibilityType = vi.fn().mockReturnValue(['AnthropicVertexAPIClient'])
return {
AnthropicVertexClient: MockAnthropicVertexClient
}
})
// Helper to create test provider
const createTestProvider = (id: string, type: string): Provider => ({
id,
type: type as Provider['type'],
name: 'Test Provider',
apiKey: 'test-key',
apiHost: 'https://api.test.com',
models: []
})
// Helper to create test model
const createTestModel = (id: string, provider?: string, endpointType?: string): Model => ({
id,
name: 'Test Model',
provider: provider || 'test',
type: [],
group: 'test',
endpoint_type: endpointType as EndpointType
})
describe('Client Compatibility Types', () => {
let openaiProvider: Provider
let anthropicProvider: Provider
let geminiProvider: Provider
let azureProvider: Provider
let aihubmixProvider: Provider
let newApiProvider: Provider
let vertexProvider: Provider
beforeEach(() => {
vi.clearAllMocks()
openaiProvider = createTestProvider('openai', 'openai')
anthropicProvider = createTestProvider('anthropic', 'anthropic')
geminiProvider = createTestProvider('gemini', 'gemini')
azureProvider = createTestProvider('azure-openai', 'azure-openai')
aihubmixProvider = createTestProvider('aihubmix', 'openai')
newApiProvider = createTestProvider('new-api', 'openai')
vertexProvider = createTestProvider('vertex', 'vertexai')
})
describe('Direct API Clients', () => {
it('should return correct compatibility type for OpenAIAPIClient', () => {
const client = new OpenAIAPIClient(openaiProvider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toEqual(['OpenAIAPIClient'])
})
it('should return correct compatibility type for AnthropicAPIClient', () => {
const client = new AnthropicAPIClient(anthropicProvider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toEqual(['AnthropicAPIClient'])
})
it('should return correct compatibility type for GeminiAPIClient', () => {
const client = new GeminiAPIClient(geminiProvider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toEqual(['GeminiAPIClient'])
})
})
describe('Decorator Pattern API Clients', () => {
it('should return OpenAIResponseAPIClient for OpenAIResponseAPIClient without model', () => {
const client = new OpenAIResponseAPIClient(azureProvider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toEqual(['OpenAIResponseAPIClient'])
})
it('should delegate to underlying client for OpenAIResponseAPIClient with model', () => {
const client = new OpenAIResponseAPIClient(azureProvider)
const testModel = createTestModel('gpt-4', 'azure-openai')
// Get the actual client selected for this model
const actualClient = client.getClient(testModel)
const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)
// Should return OpenAIResponseAPIClient for non-chat-completion-only models
expect(compatibilityTypes).toEqual(['OpenAIAPIClient'])
})
it('should return AihubmixAPIClient for AihubmixAPIClient without model', () => {
const client = new AihubmixAPIClient(aihubmixProvider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toEqual(['AihubmixAPIClient'])
})
it('should delegate to underlying client for AihubmixAPIClient with model', () => {
const client = new AihubmixAPIClient(aihubmixProvider)
const testModel = createTestModel('gpt-4', 'openai')
// Get the actual client selected for this model
const actualClient = client.getClientForModel(testModel)
const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)
// Should return the actual underlying client type based on model (OpenAI models use OpenAIResponseAPIClient in Aihubmix)
expect(compatibilityTypes).toEqual(['OpenAIResponseAPIClient'])
})
it('should return NewAPIClient for NewAPIClient without model', () => {
const client = new NewAPIClient(newApiProvider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toEqual(['NewAPIClient'])
})
it('should delegate to underlying client for NewAPIClient with model', () => {
const client = new NewAPIClient(newApiProvider)
const testModel = createTestModel('gpt-4', 'openai', 'openai-response')
// Get the actual client selected for this model
const actualClient = client.getClientForModel(testModel)
const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)
// Should return the actual underlying client type based on model
expect(compatibilityTypes).toEqual(['OpenAIResponseAPIClient'])
})
it('should return VertexAPIClient for VertexAPIClient without model', () => {
const client = new VertexAPIClient(vertexProvider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toEqual(['VertexAPIClient'])
})
it('should delegate to underlying client for VertexAPIClient with model', () => {
const client = new VertexAPIClient(vertexProvider)
const testModel = createTestModel('claude-3-5-sonnet', 'vertexai')
// Get the actual client selected for this model
const actualClient = client.getClient(testModel)
const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)
// Should return the actual underlying client type based on model (Claude models use AnthropicVertexClient)
expect(compatibilityTypes).toEqual(['AnthropicVertexAPIClient'])
})
})
describe('Middleware Compatibility Logic', () => {
it('should correctly identify OpenAI compatible clients', () => {
const openaiClient = new OpenAIAPIClient(openaiProvider)
const openaiResponseClient = new OpenAIResponseAPIClient(azureProvider)
const openaiTypes = openaiClient.getClientCompatibilityType()
const responseTypes = openaiResponseClient.getClientCompatibilityType()
// Test the logic from completions method line 94
const isOpenAICompatible = (types: string[]) =>
types.includes('OpenAIAPIClient') || types.includes('OpenAIResponseAPIClient')
expect(isOpenAICompatible(openaiTypes)).toBe(true)
expect(isOpenAICompatible(responseTypes)).toBe(true)
})
it('should correctly identify Anthropic or OpenAIResponse compatible clients', () => {
const anthropicClient = new AnthropicAPIClient(anthropicProvider)
const openaiResponseClient = new OpenAIResponseAPIClient(azureProvider)
const openaiClient = new OpenAIAPIClient(openaiProvider)
const anthropicTypes = anthropicClient.getClientCompatibilityType()
const responseTypes = openaiResponseClient.getClientCompatibilityType()
const openaiTypes = openaiClient.getClientCompatibilityType()
// Test the logic from completions method line 101
const isAnthropicOrOpenAIResponseCompatible = (types: string[]) =>
types.includes('AnthropicAPIClient') || types.includes('OpenAIResponseAPIClient')
expect(isAnthropicOrOpenAIResponseCompatible(anthropicTypes)).toBe(true)
expect(isAnthropicOrOpenAIResponseCompatible(responseTypes)).toBe(true)
expect(isAnthropicOrOpenAIResponseCompatible(openaiTypes)).toBe(false)
})
it('should handle non-compatible clients correctly', () => {
const geminiClient = new GeminiAPIClient(geminiProvider)
const geminiTypes = geminiClient.getClientCompatibilityType()
// Test that Gemini is not OpenAI compatible
const isOpenAICompatible = (types: string[]) =>
types.includes('OpenAIAPIClient') || types.includes('OpenAIResponseAPIClient')
// Test that Gemini is not Anthropic/OpenAIResponse compatible
const isAnthropicOrOpenAIResponseCompatible = (types: string[]) =>
types.includes('AnthropicAPIClient') || types.includes('OpenAIResponseAPIClient')
expect(isOpenAICompatible(geminiTypes)).toBe(false)
expect(isAnthropicOrOpenAIResponseCompatible(geminiTypes)).toBe(false)
})
})
describe('Factory Integration', () => {
it('should return correct compatibility types for factory-created clients', () => {
const testCases = [
{ provider: openaiProvider, expectedType: 'OpenAIAPIClient' },
{ provider: anthropicProvider, expectedType: 'AnthropicAPIClient' },
{ provider: azureProvider, expectedType: 'OpenAIResponseAPIClient' },
{ provider: aihubmixProvider, expectedType: 'AihubmixAPIClient' },
{ provider: newApiProvider, expectedType: 'NewAPIClient' },
{ provider: vertexProvider, expectedType: 'VertexAPIClient' }
]
testCases.forEach(({ provider, expectedType }) => {
const client = ApiClientFactory.create(provider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toContain(expectedType)
})
})
})
})

View File

@ -1,96 +0,0 @@
import { isOpenAILLMModel } from '@renderer/config/models'
import type { Model, Provider } from '@renderer/types'
import { AnthropicAPIClient } from '../anthropic/AnthropicAPIClient'
import type { BaseApiClient } from '../BaseApiClient'
import { GeminiAPIClient } from '../gemini/GeminiAPIClient'
import { MixedBaseAPIClient } from '../MixedBaseApiClient'
import { OpenAIAPIClient } from '../openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from '../openai/OpenAIResponseAPIClient'
/**
* AihubmixAPIClient - composite ApiClient
* Routes models to the appropriate underlying client at the ApiClient level
*/
export class AihubmixAPIClient extends MixedBaseAPIClient {
// Use a union type instead of any to keep type safety
protected clients: Map<string, AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient> =
new Map()
protected defaultClient: OpenAIAPIClient
protected currentClient: BaseApiClient
constructor(provider: Provider) {
super(provider)
const providerExtraHeaders = {
...provider,
extra_headers: {
...provider.extra_headers,
'APP-Code': 'MLTG2087'
}
}
// Initialize each client - now with type safety
const claudeClient = new AnthropicAPIClient(providerExtraHeaders)
const geminiClient = new GeminiAPIClient({ ...providerExtraHeaders, apiHost: 'https://aihubmix.com/gemini' })
const openaiClient = new OpenAIResponseAPIClient(providerExtraHeaders)
const defaultClient = new OpenAIAPIClient(providerExtraHeaders)
this.clients.set('claude', claudeClient)
this.clients.set('gemini', geminiClient)
this.clients.set('openai', openaiClient)
this.clients.set('default', defaultClient)
// Set the default client
this.defaultClient = defaultClient
this.currentClient = this.defaultClient as BaseApiClient
}
override getBaseURL(): string {
if (!this.currentClient) {
return this.provider.apiHost
}
return this.currentClient.getBaseURL()
}
/**
* Select the underlying client for the given model
*/
protected getClient(model: Model): BaseApiClient {
const id = model.id.toLowerCase()
// IDs starting with claude
if (id.startsWith('claude')) {
const client = this.clients.get('claude')
if (!client || !this.isValidClient(client)) {
throw new Error('Claude client not properly initialized')
}
return client
}
// IDs starting with gemini (or imagen) that do not end with -nothink or -search
if (
(id.startsWith('gemini') || id.startsWith('imagen')) &&
!id.endsWith('-nothink') &&
!id.endsWith('-search') &&
!id.includes('embedding')
) {
const client = this.clients.get('gemini')
if (!client || !this.isValidClient(client)) {
throw new Error('Gemini client not properly initialized')
}
return client
}
// OpenAI-family models, excluding gpt-oss
if (isOpenAILLMModel(model) && !model.id.includes('gpt-oss')) {
const client = this.clients.get('openai')
if (!client || !this.isValidClient(client)) {
throw new Error('OpenAI client not properly initialized')
}
return client
}
return this.defaultClient as BaseApiClient
}
}
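Routing examples implied by getClient; the model ids are illustrative, and the gpt-4 case matches the expectation in the compatibility tests removed by this commit:
// 'claude-3-5-sonnet'       -> AnthropicAPIClient
// 'gemini-2.0-flash'        -> GeminiAPIClient
// 'gemini-2.0-flash-search' -> default OpenAIAPIClient (the -search suffix is excluded above)
// 'gpt-4'                   -> OpenAIResponseAPIClient (assuming isOpenAILLMModel returns true)
// 'gpt-oss-20b'             -> default OpenAIAPIClient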

View File

@ -1,788 +0,0 @@
import type Anthropic from '@anthropic-ai/sdk'
import type {
Base64ImageSource,
ImageBlockParam,
MessageParam,
TextBlockParam,
ToolResultBlockParam,
ToolUseBlock,
WebSearchTool20250305
} from '@anthropic-ai/sdk/resources'
import type {
ContentBlock,
ContentBlockParam,
MessageCreateParamsBase,
RedactedThinkingBlockParam,
ServerToolUseBlockParam,
ThinkingBlockParam,
ThinkingConfigParam,
ToolUnion,
ToolUseBlockParam,
WebSearchResultBlock,
WebSearchToolResultBlockParam,
WebSearchToolResultError
} from '@anthropic-ai/sdk/resources/messages'
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
import type AnthropicVertex from '@anthropic-ai/vertex-sdk'
import { loggerService } from '@logger'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import FileManager from '@renderer/services/FileManager'
import { estimateTextTokens } from '@renderer/services/TokenService'
import type {
Assistant,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
Provider,
ToolCallResponse
} from '@renderer/types'
import { EFFORT_RATIO, FileTypes, WebSearchSource } from '@renderer/types'
import type {
ErrorChunk,
LLMWebSearchCompleteChunk,
LLMWebSearchInProgressChunk,
MCPToolCreatedChunk,
TextDeltaChunk,
TextStartChunk,
ThinkingDeltaChunk,
ThinkingStartChunk
} from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import { type Message } from '@renderer/types/newMessage'
import type {
AnthropicSdkMessageParam,
AnthropicSdkParams,
AnthropicSdkRawChunk,
AnthropicSdkRawOutput
} from '@renderer/types/sdk'
import { addImageFileToContents } from '@renderer/utils/formats'
import {
anthropicToolUseToMcpTool,
isSupportedToolUse,
mcpToolCallResponseToAnthropicMessage,
mcpToolsToAnthropicTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
import { t } from 'i18next'
import type { GenericChunk } from '../../middleware/schemas'
import { BaseApiClient } from '../BaseApiClient'
import type { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'
const logger = loggerService.withContext('AnthropicAPIClient')
export class AnthropicAPIClient extends BaseApiClient<
Anthropic | AnthropicVertex,
AnthropicSdkParams,
AnthropicSdkRawOutput,
AnthropicSdkRawChunk,
AnthropicSdkMessageParam,
ToolUseBlock,
ToolUnion
> {
oauthToken: string | undefined = undefined
sdkInstance: Anthropic | AnthropicVertex | undefined = undefined
constructor(provider: Provider) {
super(provider)
}
async getSdkInstance(): Promise<Anthropic | AnthropicVertex> {
if (this.sdkInstance) {
return this.sdkInstance
}
if (this.provider.authType === 'oauth') {
this.oauthToken = await window.api.anthropic_oauth.getAccessToken()
}
this.sdkInstance = getSdkClient(this.provider, this.oauthToken)
return this.sdkInstance
}
override async createCompletions(
payload: AnthropicSdkParams,
options?: Anthropic.RequestOptions
): Promise<AnthropicSdkRawOutput> {
if (this.provider.authType === 'oauth') {
payload.system = buildClaudeCodeSystemMessage(payload.system)
}
const sdk = (await this.getSdkInstance()) as Anthropic
if (payload.stream) {
return sdk.messages.stream(payload, options)
}
return sdk.messages.create(payload, options)
}
// @ts-ignore not provided by the SDK
// oxlint-disable-next-line @typescript-eslint/no-unused-vars
override async generateImage(generateImageParams: GenerateImageParams): Promise<string[]> {
return []
}
override async listModels(): Promise<Anthropic.ModelInfo[]> {
const sdk = (await this.getSdkInstance()) as Anthropic
// prevent the auto-appended /v1; it's already included in baseUrl
const response = await sdk.models.list({ path: '/models' })
return response.data
}
// @ts-ignore not provided by the SDK
override async getEmbeddingDimensions(): Promise<number> {
throw new Error("Anthropic SDK doesn't support getEmbeddingDimensions method.")
}
override getTemperature(assistant: Assistant, model: Model): number | undefined {
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
return undefined
}
return super.getTemperature(assistant, model)
}
override getTopP(assistant: Assistant, model: Model): number | undefined {
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
return undefined
}
return super.getTopP(assistant, model)
}
/**
* Get the thinking budget configuration derived from the reasoning effort
* @param assistant - The assistant
* @param model - The model
* @returns The thinking configuration
*/
private getBudgetToken(assistant: Assistant, model: Model): ThinkingConfigParam | undefined {
if (!isReasoningModel(model)) {
return undefined
}
const { maxTokens } = getAssistantSettings(assistant)
const reasoningEffort = assistant?.settings?.reasoning_effort
if (reasoningEffort === undefined) {
return {
type: 'disabled'
}
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
const budgetTokens = Math.max(
1024,
Math.floor(
Math.min(
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
findTokenLimit(model.id)?.min!,
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
)
)
)
return {
type: 'enabled',
budget_tokens: budgetTokens
}
}
private static isValidBase64ImageMediaType(mime: string): mime is Base64ImageSource['media_type'] {
return ['image/jpeg', 'image/png', 'image/gif', 'image/webp'].includes(mime)
}
/**
* Get the message parameter
* @param message - The message
* @returns The message parameter
*/
public async convertMessageToSdkParam(message: Message): Promise<AnthropicSdkMessageParam> {
const { textContent, imageContents } = await this.getMessageContent(message)
const parts: MessageParam['content'] = [
{
type: 'text',
text: textContent
}
]
if (imageContents.length > 0) {
for (const imageContent of imageContents) {
const base64Data = await window.api.file.base64Image(imageContent.fileId + imageContent.fileExt)
base64Data.mime = base64Data.mime.replace('jpg', 'jpeg')
if (AnthropicAPIClient.isValidBase64ImageMediaType(base64Data.mime)) {
parts.push({
type: 'image',
source: {
data: base64Data.base64,
media_type: base64Data.mime,
type: 'base64'
}
})
} else {
logger.warn('Unsupported image type, ignored.', { mime: base64Data.mime })
}
}
}
// Get and process image blocks
const imageBlocks = findImageBlocks(message)
for (const imageBlock of imageBlocks) {
if (imageBlock.file) {
// Handle uploaded file
const file = imageBlock.file
const base64Data = await window.api.file.base64Image(file.id + file.ext)
parts.push({
type: 'image',
source: {
data: base64Data.base64,
media_type: base64Data.mime.replace('jpg', 'jpeg') as any,
type: 'base64'
}
})
}
}
// Get and process file blocks
const fileBlocks = findFileBlocks(message)
for (const fileBlock of fileBlocks) {
const { file } = fileBlock
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) {
const base64Data = await FileManager.readBase64File(file)
parts.push({
type: 'document',
source: {
type: 'base64',
media_type: 'application/pdf',
data: base64Data
}
})
} else {
const fileContent = (await window.api.file.read(file.id + file.ext, true)).trim()
parts.push({
type: 'text',
text: file.origin_name + '\n' + fileContent
})
}
}
}
return {
role: message.role === 'system' ? 'user' : message.role,
content: parts
}
}
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ToolUnion[] {
return mcpToolsToAnthropicTools(mcpTools)
}
public convertMcpToolResponseToSdkMessageParam(
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
): AnthropicSdkMessageParam | undefined {
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
return mcpToolCallResponseToAnthropicMessage(mcpToolResponse, resp, model)
} else if ('toolCallId' in mcpToolResponse) {
return {
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: mcpToolResponse.toolCallId!,
content: resp.content
.map((item) => {
if (item.type === 'text') {
return {
type: 'text',
text: item.text || ''
} satisfies TextBlockParam
}
if (item.type === 'image') {
return {
type: 'image',
source: {
data: item.data || '',
media_type: (item.mimeType || 'image/png') as Base64ImageSource['media_type'],
type: 'base64'
}
} satisfies ImageBlockParam
}
return
})
.filter((n) => typeof n !== 'undefined'),
is_error: resp.isError
} satisfies ToolResultBlockParam
]
}
}
return
}
// Implementing abstract methods from BaseApiClient
convertSdkToolCallToMcp(toolCall: ToolUseBlock, mcpTools: MCPTool[]): MCPTool | undefined {
// Based on anthropicToolUseToMcpTool logic in AnthropicProvider
// This might need adjustment based on how tool calls are specifically handled in the new structure
const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall)
return mcpTool
}
convertSdkToolCallToMcpToolResponse(toolCall: ToolUseBlock, mcpTool: MCPTool): ToolCallResponse {
return {
id: toolCall.id,
toolCallId: toolCall.id,
tool: mcpTool,
arguments: toolCall.input as Record<string, unknown>,
status: 'pending'
} as ToolCallResponse
}
override buildSdkMessages(
currentReqMessages: AnthropicSdkMessageParam[],
output: Anthropic.Message,
toolResults: AnthropicSdkMessageParam[]
): AnthropicSdkMessageParam[] {
const assistantMessage: AnthropicSdkMessageParam = {
role: output.role,
content: convertContentBlocksToParams(output.content)
}
const newMessages: AnthropicSdkMessageParam[] = [...currentReqMessages, assistantMessage]
if (toolResults && toolResults.length > 0) {
newMessages.push(...toolResults)
}
return newMessages
}
override estimateMessageTokens(message: AnthropicSdkMessageParam): number {
if (typeof message.content === 'string') {
return estimateTextTokens(message.content)
}
return message.content
.map((content) => {
switch (content.type) {
case 'text':
return estimateTextTokens(content.text)
case 'image':
if (content.source.type === 'base64') {
return estimateTextTokens(content.source.data)
} else {
return estimateTextTokens(content.source.url)
}
case 'tool_use':
return estimateTextTokens(JSON.stringify(content.input))
case 'tool_result':
return estimateTextTokens(JSON.stringify(content.content))
default:
return 0
}
})
.reduce((acc, curr) => acc + curr, 0)
}
public buildAssistantMessage(message: Anthropic.Message): AnthropicSdkMessageParam {
const messageParam: AnthropicSdkMessageParam = {
role: message.role,
content: convertContentBlocksToParams(message.content)
}
return messageParam
}
public extractMessagesFromSdkPayload(sdkPayload: AnthropicSdkParams): AnthropicSdkMessageParam[] {
return sdkPayload.messages || []
}
/**
* Anthropic-specific raw stream listener
* Handles the specific events of the MessageStream object
*/
attachRawStreamListener(
rawOutput: AnthropicSdkRawOutput,
listener: RawStreamListener<AnthropicSdkRawChunk>
): AnthropicSdkRawOutput {
logger.debug(`Attaching stream listener to raw output`)
// Anthropic-specific event handling
const anthropicListener = listener as AnthropicStreamListener
// Check whether the output is a MessageStream
if (rawOutput instanceof MessageStream) {
logger.debug(`Detected Anthropic MessageStream, attaching specialized listener`)
if (listener.onStart) {
listener.onStart()
}
if (listener.onChunk) {
rawOutput.on('streamEvent', (event: AnthropicSdkRawChunk) => {
listener.onChunk!(event)
})
}
if (anthropicListener.onContentBlock) {
rawOutput.on('contentBlock', anthropicListener.onContentBlock)
}
if (anthropicListener.onMessage) {
rawOutput.on('finalMessage', anthropicListener.onMessage)
}
if (listener.onEnd) {
rawOutput.on('end', () => {
listener.onEnd!()
})
}
if (listener.onError) {
rawOutput.on('error', (error: Error) => {
listener.onError!(error)
})
}
return rawOutput
}
if (anthropicListener.onMessage) {
anthropicListener.onMessage(rawOutput)
}
// For non-MessageStream responses
return rawOutput
}
private async getWebSearchParams(model: Model): Promise<WebSearchTool20250305 | undefined> {
if (!isWebSearchModel(model)) {
return undefined
}
return {
type: 'web_search_20250305',
name: 'web_search',
max_uses: 5
} as WebSearchTool20250305
}
getRequestTransformer(): RequestTransformer<AnthropicSdkParams, AnthropicSdkMessageParam> {
return {
transform: async (
coreRequest,
assistant,
model,
isRecursiveCall,
recursiveSdkMessages
): Promise<{
payload: AnthropicSdkParams
messages: AnthropicSdkMessageParam[]
metadata: Record<string, any>
}> => {
const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest
// 1. Process the system message
const systemPrompt = assistant.prompt
// 2. Set up tools
const { tools } = this.setupToolsConfig({
mcpTools: mcpTools,
model,
enableToolUse: isSupportedToolUse(assistant)
})
const systemMessage: TextBlockParam | undefined = systemPrompt
? { type: 'text', text: systemPrompt }
: undefined
// 3. Process user messages
const sdkMessages: AnthropicSdkMessageParam[] = []
if (typeof messages === 'string') {
sdkMessages.push({ role: 'user', content: messages })
} else {
const processedMessages = addImageFileToContents(messages)
for (const message of processedMessages) {
sdkMessages.push(await this.convertMessageToSdkParam(message))
}
}
if (enableWebSearch) {
const webSearchTool = await this.getWebSearchParams(model)
if (webSearchTool) {
tools.push(webSearchTool)
}
}
const commonParams: MessageCreateParamsBase = {
model: model.id,
messages:
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
? recursiveSdkMessages
: sdkMessages,
max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
temperature: this.getTemperature(assistant, model),
top_p: this.getTopP(assistant, model),
system: systemMessage ? [systemMessage] : undefined,
thinking: this.getBudgetToken(assistant, model),
tools: tools.length > 0 ? tools : undefined,
stream: streamOutput,
// Apply custom parameters only in chat scenarios, so translation, summarization and other flows are unaffected
// Note: user-defined parameters should always override the other parameters
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
}
const timeout = this.getTimeout(model)
return { payload: commonParams, messages: sdkMessages, metadata: { timeout } }
}
}
}
getResponseChunkTransformer(): ResponseChunkTransformer<AnthropicSdkRawChunk> {
return () => {
let accumulatedJson = ''
const toolCalls: Record<number, ToolUseBlock> = {}
return {
async transform(rawChunk: AnthropicSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
if (typeof rawChunk === 'string') {
try {
rawChunk = JSON.parse(rawChunk)
} catch (error) {
logger.error('invalid chunk', { rawChunk, error })
throw new Error(t('error.chat.chunk.non_json'))
}
}
switch (rawChunk.type) {
case 'message': {
let i = 0
let hasTextContent = false
let hasThinkingContent = false
for (const content of rawChunk.content) {
switch (content.type) {
case 'text': {
if (!hasTextContent) {
controller.enqueue({
type: ChunkType.TEXT_START
} as TextStartChunk)
hasTextContent = true
}
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: content.text
} as TextDeltaChunk)
break
}
case 'tool_use': {
toolCalls[i] = content
i++
break
}
case 'thinking': {
if (!hasThinkingContent) {
controller.enqueue({
type: ChunkType.THINKING_START
} as ThinkingStartChunk)
hasThinkingContent = true
}
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: content.thinking
} as ThinkingDeltaChunk)
break
}
case 'web_search_tool_result': {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
results: content.content,
source: WebSearchSource.ANTHROPIC
}
} as LLMWebSearchCompleteChunk)
break
}
}
}
if (i > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: Object.values(toolCalls)
} as MCPToolCreatedChunk)
}
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: rawChunk.usage.input_tokens || 0,
completion_tokens: rawChunk.usage.output_tokens || 0,
total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0)
}
}
})
break
}
case 'content_block_start': {
const contentBlock = rawChunk.content_block
switch (contentBlock.type) {
case 'server_tool_use': {
if (contentBlock.name === 'web_search') {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS
} as LLMWebSearchInProgressChunk)
}
break
}
case 'web_search_tool_result': {
if (
contentBlock.content &&
(contentBlock.content as WebSearchToolResultError).type === 'web_search_tool_result_error'
) {
controller.enqueue({
type: ChunkType.ERROR,
error: {
code: (contentBlock.content as WebSearchToolResultError).error_code,
message: (contentBlock.content as WebSearchToolResultError).error_code
}
} as ErrorChunk)
} else {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
results: contentBlock.content as Array<WebSearchResultBlock>,
source: WebSearchSource.ANTHROPIC
}
} as LLMWebSearchCompleteChunk)
}
break
}
case 'tool_use': {
toolCalls[rawChunk.index] = contentBlock
break
}
case 'text': {
controller.enqueue({
type: ChunkType.TEXT_START
} as TextStartChunk)
break
}
case 'thinking':
case 'redacted_thinking': {
controller.enqueue({
type: ChunkType.THINKING_START
} as ThinkingStartChunk)
break
}
}
break
}
case 'content_block_delta': {
const messageDelta = rawChunk.delta
switch (messageDelta.type) {
case 'text_delta': {
if (messageDelta.text) {
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: messageDelta.text
} as TextDeltaChunk)
}
break
}
case 'thinking_delta': {
if (messageDelta.thinking) {
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: messageDelta.thinking
} as ThinkingDeltaChunk)
}
break
}
case 'input_json_delta': {
if (messageDelta.partial_json) {
accumulatedJson += messageDelta.partial_json
}
break
}
}
break
}
case 'content_block_stop': {
const toolCall = toolCalls[rawChunk.index]
if (toolCall) {
try {
toolCall.input = accumulatedJson ? JSON.parse(accumulatedJson) : {}
logger.debug(`Tool call id: ${toolCall.id}, accumulated json: ${accumulatedJson}`)
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: [toolCall]
} as MCPToolCreatedChunk)
} catch (error) {
logger.error('Error parsing tool call input:', error as Error)
}
}
break
}
case 'message_delta': {
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: rawChunk.usage.input_tokens || 0,
completion_tokens: rawChunk.usage.output_tokens || 0,
total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0)
}
}
})
}
}
}
}
}
}
}
/**
* Convert ContentBlock objects to ContentBlockParam objects,
* keeping only the fields required by the API
*/
function convertContentBlocksToParams(contentBlocks: ContentBlock[]): ContentBlockParam[] {
return contentBlocks.map((block): ContentBlockParam => {
switch (block.type) {
case 'text':
// TextBlock -> TextBlockParam: drop server-side fields such as citations
return {
type: 'text',
text: block.text
} satisfies TextBlockParam
case 'tool_use':
// ToolUseBlock -> ToolUseBlockParam
return {
type: 'tool_use',
id: block.id,
name: block.name,
input: block.input
} satisfies ToolUseBlockParam
case 'thinking':
// ThinkingBlock -> ThinkingBlockParam
return {
type: 'thinking',
thinking: block.thinking,
signature: block.signature
} satisfies ThinkingBlockParam
case 'redacted_thinking':
// RedactedThinkingBlock -> RedactedThinkingBlockParam
return {
type: 'redacted_thinking',
data: block.data
} satisfies RedactedThinkingBlockParam
case 'server_tool_use':
// ServerToolUseBlock -> ServerToolUseBlockParam
return {
type: 'server_tool_use',
id: block.id,
name: block.name,
input: block.input
} satisfies ServerToolUseBlockParam
case 'web_search_tool_result':
// WebSearchToolResultBlock -> WebSearchToolResultBlockParam
return {
type: 'web_search_tool_result',
tool_use_id: block.tool_use_id,
content: block.content
} satisfies WebSearchToolResultBlockParam
default:
return block as ContentBlockParam
}
})
}
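For reference, the budget arithmetic in getBudgetToken above can be traced with a small standalone sketch; the token limits, effort ratio and maxTokens below are illustrative assumptions, not the app's real configuration:

// Illustrative values only: real limits come from findTokenLimit and the assistant settings.
const min = 1024
const max = 32768
const effortRatio = 0.5 // e.g. a 'medium' reasoning effort
const maxTokens = 8192

// Same shape as getBudgetToken: scale the model's thinking range by the effort ratio,
// cap it by the assistant's max tokens, and never drop below 1024.
const budgetTokens = Math.max(1024, Math.floor(Math.min((max - min) * effortRatio + min, maxTokens * effortRatio)))
// (32768 - 1024) * 0.5 + 1024 = 16896; 8192 * 0.5 = 4096; min(...) = 4096, so budgetTokens === 4096

const thinking = { type: 'enabled', budget_tokens: budgetTokens }
console.log(thinking)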

View File

@ -1,104 +0,0 @@
import type Anthropic from '@anthropic-ai/sdk'
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
import { loggerService } from '@logger'
import { getVertexAILocation, getVertexAIProjectId, getVertexAIServiceAccount } from '@renderer/hooks/useVertexAI'
import type { Provider } from '@renderer/types'
import { isEmpty } from 'lodash'
import { AnthropicAPIClient } from './AnthropicAPIClient'
const logger = loggerService.withContext('AnthropicVertexClient')
export class AnthropicVertexClient extends AnthropicAPIClient {
sdkInstance: AnthropicVertex | undefined = undefined
private authHeaders?: Record<string, string>
private authHeadersExpiry?: number
constructor(provider: Provider) {
super(provider)
}
private formatApiHost(host: string): string {
const forceUseOriginalHost = () => {
return host.endsWith('/')
}
if (!host) {
return host
}
return forceUseOriginalHost() ? host : `${host}/v1/`
}
override getBaseURL() {
return this.formatApiHost(this.provider.apiHost)
}
override async getSdkInstance(): Promise<AnthropicVertex> {
if (this.sdkInstance) {
return this.sdkInstance
}
const serviceAccount = getVertexAIServiceAccount()
const projectId = getVertexAIProjectId()
const location = getVertexAILocation()
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId || !location) {
throw new Error('Vertex AI settings are not configured')
}
const authHeaders = await this.getServiceAccountAuthHeaders()
this.sdkInstance = new AnthropicVertex({
projectId: projectId,
region: location,
dangerouslyAllowBrowser: true,
defaultHeaders: authHeaders,
baseURL: isEmpty(this.getBaseURL()) ? undefined : this.getBaseURL()
})
return this.sdkInstance
}
override async listModels(): Promise<Anthropic.ModelInfo[]> {
throw new Error('Vertex AI does not support listModels method.')
}
/**
* Get auth headers using the service account
*/
private async getServiceAccountAuthHeaders(): Promise<Record<string, string> | undefined> {
const serviceAccount = getVertexAIServiceAccount()
const projectId = getVertexAIProjectId()
// Check whether a service account is configured
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId) {
return undefined
}
// Check for existing valid auth headers (refresh 5 minutes before expiry)
const now = Date.now()
if (this.authHeaders && this.authHeadersExpiry && this.authHeadersExpiry - now > 5 * 60 * 1000) {
return this.authHeaders
}
try {
// Fetch auth headers from the main process
this.authHeaders = await window.api.vertexAI.getAuthHeaders({
projectId,
serviceAccount: {
privateKey: serviceAccount.privateKey,
clientEmail: serviceAccount.clientEmail
}
})
// Set the expiry time (auth headers are typically valid for 1 hour)
this.authHeadersExpiry = now + 60 * 60 * 1000
return this.authHeaders
} catch (error: any) {
logger.error('Failed to get auth headers:', error)
throw new Error(`Service Account authentication failed: ${error.message}`)
}
}
}
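The caching in getServiceAccountAuthHeaders refreshes five minutes before expiry; a minimal sketch of that pattern, with fetchHeaders standing in for the call to the main process (a placeholder, not an actual API):

// Sketch of the early-refresh cache; the types and fetchHeaders are assumptions for illustration.
type AuthHeaders = Record<string, string>

let cachedHeaders: AuthHeaders | undefined
let cachedExpiry: number | undefined

async function getCachedAuthHeaders(fetchHeaders: () => Promise<AuthHeaders>): Promise<AuthHeaders> {
  const now = Date.now()
  // Reuse the cached headers until 5 minutes before they expire
  if (cachedHeaders && cachedExpiry && cachedExpiry - now > 5 * 60 * 1000) {
    return cachedHeaders
  }
  cachedHeaders = await fetchHeaders()
  cachedExpiry = now + 60 * 60 * 1000 // assume a 1-hour lifetime, as above
  return cachedHeaders
}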

View File

@ -1,51 +0,0 @@
import type OpenAI from '@cherrystudio/openai'
import type { Provider } from '@renderer/types'
import type { OpenAISdkParams, OpenAISdkRawOutput } from '@renderer/types/sdk'
import { OpenAIAPIClient } from '../openai/OpenAIApiClient'
export class CherryAiAPIClient extends OpenAIAPIClient {
constructor(provider: Provider) {
super(provider)
}
override async createCompletions(
payload: OpenAISdkParams,
options?: OpenAI.RequestOptions
): Promise<OpenAISdkRawOutput> {
const sdk = await this.getSdkInstance()
options = options || {}
options.headers = options.headers || {}
const signature = await window.api.cherryai.generateSignature({
method: 'POST',
path: '/chat/completions',
query: '',
body: payload
})
options.headers = {
...options.headers,
...signature
}
// @ts-ignore - SDK params may contain extra fields
return await sdk.chat.completions.create(payload, options)
}
override getClientCompatibilityType(): string[] {
return ['CherryAiAPIClient']
}
public async listModels(): Promise<OpenAI.Models.Model[]> {
const models = ['glm-4.5-flash', 'Qwen/Qwen3-8B']
const created = Date.now()
return models.map((id) => ({
id,
owned_by: 'cherryai',
object: 'model' as const,
created
}))
}
}
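The createCompletions override above boils down to merging per-request signature headers into the SDK options before the call; a reduced sketch of that merge (withSignature is a hypothetical helper, not part of the codebase):

// Hypothetical helper illustrating the header merge performed in createCompletions.
type RequestOptions = { headers?: Record<string, string> }

async function withSignature(
  options: RequestOptions,
  generateSignature: () => Promise<Record<string, string>>
): Promise<RequestOptions> {
  const signature = await generateSignature()
  // Signature headers are layered on top of any caller-supplied headers
  return { ...options, headers: { ...(options.headers ?? {}), ...signature } }
}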

View File

@ -1,841 +0,0 @@
import type {
Content,
File,
FunctionCall,
GenerateContentConfig,
GenerateImagesConfig,
Model as GeminiModel,
Part,
SafetySetting,
SendMessageParameters,
ThinkingConfig,
Tool
} from '@google/genai'
import { createPartFromUri, GoogleGenAI, HarmBlockThreshold, HarmCategory, Modality } from '@google/genai'
import { loggerService } from '@logger'
import { nanoid } from '@reduxjs/toolkit'
import {
findTokenLimit,
GEMINI_FLASH_MODEL_REGEX,
isGemmaModel,
isSupportedThinkingTokenGeminiModel,
isVisionModel
} from '@renderer/config/models'
import { estimateTextTokens } from '@renderer/services/TokenService'
import type {
Assistant,
FileMetadata,
FileUploadResponse,
GenerateImageParams,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
Provider,
ToolCallResponse
} from '@renderer/types'
import { EFFORT_RATIO, FileTypes, WebSearchSource } from '@renderer/types'
import type { LLMWebSearchCompleteChunk, TextStartChunk, ThinkingStartChunk } from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import type { Message } from '@renderer/types/newMessage'
import type {
GeminiOptions,
GeminiSdkMessageParam,
GeminiSdkParams,
GeminiSdkRawChunk,
GeminiSdkRawOutput,
GeminiSdkToolCall
} from '@renderer/types/sdk'
import { isToolUseModeFunction } from '@renderer/utils/assistant'
import {
geminiFunctionCallToMcpTool,
isSupportedToolUse,
mcpToolCallResponseToGeminiMessage,
mcpToolsToGeminiTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { defaultTimeout, MB } from '@shared/config/constant'
import { getTrailingApiVersion, withoutTrailingApiVersion } from '@shared/utils'
import { t } from 'i18next'
import type { GenericChunk } from '../../middleware/schemas'
import { BaseApiClient } from '../BaseApiClient'
import type { RequestTransformer, ResponseChunkTransformer } from '../types'
const logger = loggerService.withContext('GeminiAPIClient')
export class GeminiAPIClient extends BaseApiClient<
GoogleGenAI,
GeminiSdkParams,
GeminiSdkRawOutput,
GeminiSdkRawChunk,
GeminiSdkMessageParam,
GeminiSdkToolCall,
Tool
> {
constructor(provider: Provider) {
super(provider)
}
override async createCompletions(payload: GeminiSdkParams, options?: GeminiOptions): Promise<GeminiSdkRawOutput> {
const sdk = await this.getSdkInstance()
const { model, history, ...rest } = payload
const realPayload: Omit<GeminiSdkParams, 'model'> = {
...rest,
config: {
...rest.config,
abortSignal: options?.signal,
httpOptions: {
...rest.config?.httpOptions,
timeout: options?.timeout
}
}
} satisfies SendMessageParameters
const streamOutput = options?.streamOutput
const chat = sdk.chats.create({
model: model,
history: history
})
if (streamOutput) {
const stream = chat.sendMessageStream(realPayload)
return stream
} else {
const response = await chat.sendMessage(realPayload)
return response
}
}
override async generateImage(generateImageParams: GenerateImageParams): Promise<string[]> {
const sdk = await this.getSdkInstance()
try {
const { model, prompt, imageSize, batchSize, signal } = generateImageParams
const config: GenerateImagesConfig = {
numberOfImages: batchSize,
aspectRatio: imageSize,
abortSignal: signal,
httpOptions: {
timeout: defaultTimeout
}
}
const response = await sdk.models.generateImages({
model: model,
prompt,
config
})
if (!response.generatedImages || response.generatedImages.length === 0) {
return []
}
const images = response.generatedImages
.filter((image) => image.image?.imageBytes)
.map((image) => {
const dataPrefix = `data:${image.image?.mimeType || 'image/png'};base64,`
return dataPrefix + image.image?.imageBytes
})
// console.log(response?.generatedImages?.[0]?.image?.imageBytes);
return images
} catch (error) {
logger.error('[generateImage] error:', error as Error)
throw error
}
}
override async getEmbeddingDimensions(model: Model): Promise<number> {
const sdk = await this.getSdkInstance()
const data = await sdk.models.embedContent({
model: model.id,
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
})
return data.embeddings?.[0]?.values?.length || 0
}
override async listModels(): Promise<GeminiModel[]> {
const sdk = await this.getSdkInstance()
const response = await sdk.models.list()
const models: GeminiModel[] = []
for await (const model of response) {
models.push(model)
}
return models
}
override getBaseURL(): string {
return withoutTrailingApiVersion(super.getBaseURL())
}
override async getSdkInstance() {
if (this.sdkInstance) {
return this.sdkInstance
}
const apiVersion = this.getApiVersion()
this.sdkInstance = new GoogleGenAI({
vertexai: false,
apiKey: this.apiKey,
apiVersion,
httpOptions: {
baseUrl: this.getBaseURL(),
apiVersion,
headers: {
...this.provider.extra_headers
}
}
})
return this.sdkInstance
}
protected getApiVersion(): string {
if (this.provider.isVertex) {
return 'v1'
}
// Extract trailing API version from the URL
const trailingVersion = getTrailingApiVersion(this.provider.apiHost || '')
if (trailingVersion) {
return trailingVersion
}
return ''
}
/**
* Handle a PDF file
* @param file - The file
* @returns The part
*/
private async handlePdfFile(file: FileMetadata): Promise<Part> {
const smallFileSize = 20 * MB
const isSmallFile = file.size < smallFileSize
if (isSmallFile) {
const { data, mimeType } = await this.base64File(file)
return {
inlineData: {
data,
mimeType
}
}
}
// Retrieve file from Gemini uploaded files
const fileMetadata: FileUploadResponse = await window.api.fileService.retrieve(this.provider, file.id)
if (fileMetadata.status === 'success') {
const remoteFile = fileMetadata.originalFile?.file as File
return createPartFromUri(remoteFile.uri!, remoteFile.mimeType!)
}
// If file is not found, upload it to Gemini
const result = await window.api.fileService.upload(this.provider, file)
const remoteFile = result.originalFile
if (!remoteFile) {
throw new Error('File upload failed, please try again')
}
if (remoteFile.type === 'gemini') {
const file = remoteFile.file
if (!file.uri) {
throw new Error('File URI is required but not found')
}
if (!file.mimeType) {
throw new Error('File MIME type is required but not found')
}
return createPartFromUri(file.uri, file.mimeType)
} else {
throw new Error('Unsupported file type for Gemini API')
}
}
/**
* Get the message contents
* @param message - The message
* @returns The message contents
*/
private async convertMessageToSdkParam(message: Message): Promise<Content> {
const role = message.role === 'user' ? 'user' : 'model'
const { textContent, imageContents } = await this.getMessageContent(message)
const parts: Part[] = [{ text: textContent }]
if (imageContents.length > 0) {
for (const imageContent of imageContents) {
const image = await window.api.file.base64Image(imageContent.fileId + imageContent.fileExt)
parts.push({
inlineData: {
data: image.base64,
mimeType: image.mime
} satisfies Part['inlineData']
})
}
}
// Add any generated images from previous responses
const imageBlocks = findImageBlocks(message)
for (const imageBlock of imageBlocks) {
if (
imageBlock.metadata?.generateImageResponse?.images &&
imageBlock.metadata.generateImageResponse.images.length > 0
) {
for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
if (imageUrl && imageUrl.startsWith('data:')) {
// Extract base64 data and mime type from the data URL
const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
if (matches && matches.length === 3) {
const mimeType = matches[1]
const base64Data = matches[2]
parts.push({
inlineData: {
data: base64Data,
mimeType: mimeType
} satisfies Part['inlineData']
})
}
}
}
}
const file = imageBlock.file
if (file) {
const base64Data = await window.api.file.base64Image(file.id + file.ext)
parts.push({
inlineData: {
data: base64Data.base64,
mimeType: base64Data.mime
} satisfies Part['inlineData']
})
}
}
const fileBlocks = findFileBlocks(message)
for (const fileBlock of fileBlocks) {
const file = fileBlock.file
if (file.type === FileTypes.IMAGE) {
const base64Data = await window.api.file.base64Image(file.id + file.ext)
parts.push({
inlineData: {
data: base64Data.base64,
mimeType: base64Data.mime
} satisfies Part['inlineData']
})
}
if (file.ext === '.pdf') {
parts.push(await this.handlePdfFile(file))
continue
}
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
const fileContent = (await window.api.file.read(file.id + file.ext, true)).trim()
parts.push({
text: file.origin_name + '\n' + fileContent
})
}
}
return {
role,
parts: parts
}
}
// @ts-ignore unused
private async getImageFileContents(message: Message): Promise<Content> {
const role = message.role === 'user' ? 'user' : 'model'
const content = getMainTextContent(message)
const parts: Part[] = [{ text: content }]
const imageBlocks = findImageBlocks(message)
for (const imageBlock of imageBlocks) {
if (
imageBlock.metadata?.generateImageResponse?.images &&
imageBlock.metadata.generateImageResponse.images.length > 0
) {
for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
if (imageUrl && imageUrl.startsWith('data:')) {
// Extract base64 data and mime type from the data URL
const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
if (matches && matches.length === 3) {
const mimeType = matches[1]
const base64Data = matches[2]
parts.push({
inlineData: {
data: base64Data,
mimeType: mimeType
} satisfies Part['inlineData']
})
}
}
}
}
const file = imageBlock.file
if (file) {
const base64Data = await window.api.file.base64Image(file.id + file.ext)
parts.push({
inlineData: {
data: base64Data.base64,
mimeType: base64Data.mime
} satisfies Part['inlineData']
})
}
}
return {
role,
parts: parts
}
}
/**
* Get the safety settings
* @returns The safety settings
*/
private getSafetySettings(): SafetySetting[] {
const safetyThreshold = HarmBlockThreshold.OFF
return [
{
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold: safetyThreshold
},
{
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold: safetyThreshold
},
{
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold: safetyThreshold
},
{
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold: safetyThreshold
},
{
category: HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
threshold: HarmBlockThreshold.BLOCK_NONE
}
]
}
/**
* Get the reasoning effort for the assistant
* @param assistant - The assistant
* @param model - The model
* @returns The reasoning effort
*/
private getBudgetToken(assistant: Assistant, model: Model) {
if (isSupportedThinkingTokenGeminiModel(model)) {
const reasoningEffort = assistant?.settings?.reasoning_effort
// If the thinking budget is undefined, disable thinking
if (reasoningEffort === undefined) {
return GEMINI_FLASH_MODEL_REGEX.test(model.id)
? {
thinkingConfig: {
thinkingBudget: 0
}
}
: {}
}
if (reasoningEffort === 'auto') {
return {
thinkingConfig: {
includeThoughts: true,
thinkingBudget: -1
}
}
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
const { min, max } = findTokenLimit(model.id) || { min: 0, max: 0 }
// Compute budgetTokens, ensuring it is not below min
const budget = Math.floor((max - min) * effortRatio + min)
return {
thinkingConfig: {
...(budget > 0 ? { thinkingBudget: budget } : {}),
includeThoughts: true
} satisfies ThinkingConfig
}
}
return {}
}
private getGenerateImageParameter(): Partial<GenerateContentConfig> {
return {
systemInstruction: undefined,
responseModalities: [Modality.TEXT, Modality.IMAGE]
}
}
getRequestTransformer(): RequestTransformer<GeminiSdkParams, GeminiSdkMessageParam> {
return {
transform: async (
coreRequest,
assistant,
model,
isRecursiveCall,
recursiveSdkMessages
): Promise<{
payload: GeminiSdkParams
messages: GeminiSdkMessageParam[]
metadata: Record<string, any>
}> => {
const { messages, mcpTools, maxTokens, enableWebSearch, enableUrlContext, enableGenerateImage } = coreRequest
// 1. Process the system message
const systemInstruction = assistant.prompt
// 2. Set up tools
const { tools } = this.setupToolsConfig({
mcpTools,
model,
enableToolUse: isSupportedToolUse(assistant)
})
let messageContents: Content = { role: 'user', parts: [] } // Initialize messageContents
const history: Content[] = []
// 3. Process user messages
if (typeof messages === 'string') {
messageContents = {
role: 'user',
parts: [{ text: messages }]
}
} else {
const userLastMessage = messages.pop()
if (userLastMessage) {
messageContents = await this.convertMessageToSdkParam(userLastMessage)
for (const message of messages) {
history.push(await this.convertMessageToSdkParam(message))
}
messages.push(userLastMessage)
}
}
if (tools.length === 0 || !isToolUseModeFunction(assistant)) {
if (enableWebSearch) {
tools.push({
googleSearch: {}
})
}
if (enableUrlContext) {
tools.push({
urlContext: {}
})
}
} else if (enableWebSearch || enableUrlContext) {
logger.warn('Native tools cannot be used with function calling for now.')
}
if (isGemmaModel(model) && assistant.prompt) {
const isFirstMessage = history.length === 0
if (isFirstMessage && messageContents) {
const userMessageText =
messageContents.parts && messageContents.parts.length > 0 ? (messageContents.parts[0].text ?? '') : ''
const systemMessage = [
{
text:
'<start_of_turn>user\n' +
systemInstruction +
'<end_of_turn>\n' +
'<start_of_turn>user\n' +
userMessageText +
'<end_of_turn>'
}
] satisfies Part[]
if (messageContents && messageContents.parts) {
messageContents.parts[0] = systemMessage[0]
}
}
}
const newHistory =
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
? recursiveSdkMessages.slice(0, recursiveSdkMessages.length - 1)
: history
const newMessageContents =
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
? recursiveSdkMessages[recursiveSdkMessages.length - 1]
: messageContents
const generateContentConfig: GenerateContentConfig = {
safetySettings: this.getSafetySettings(),
systemInstruction: isGemmaModel(model) ? undefined : systemInstruction,
temperature: this.getTemperature(assistant, model),
topP: this.getTopP(assistant, model),
maxOutputTokens: maxTokens,
tools: tools,
...(enableGenerateImage ? this.getGenerateImageParameter() : {}),
...this.getBudgetToken(assistant, model),
// Apply custom parameters only in chat scenarios, so translation, summarization and other flows are unaffected
// Note: user-defined parameters should always override the other parameters
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
}
const param: GeminiSdkParams = {
model: model.id,
config: generateContentConfig,
history: newHistory,
message: newMessageContents.parts!
}
return {
payload: param,
messages: [messageContents],
metadata: {}
}
}
}
}
getResponseChunkTransformer(): ResponseChunkTransformer<GeminiSdkRawChunk> {
const toolCalls: FunctionCall[] = []
let isFirstTextChunk = true
let isFirstThinkingChunk = true
return () => ({
async transform(chunk: GeminiSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
logger.silly('chunk', chunk)
if (typeof chunk === 'string') {
try {
chunk = JSON.parse(chunk)
} catch (error) {
logger.error('invalid chunk', { chunk, error })
throw new Error(t('error.chat.chunk.non_json'))
}
}
if (chunk.candidates && chunk.candidates.length > 0) {
for (const candidate of chunk.candidates) {
if (candidate.content) {
candidate.content.parts?.forEach((part) => {
const text = part.text || ''
if (part.thought) {
if (isFirstThinkingChunk) {
controller.enqueue({
type: ChunkType.THINKING_START
} satisfies ThinkingStartChunk)
isFirstThinkingChunk = false
}
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: text
})
} else if (part.text) {
if (isFirstTextChunk) {
controller.enqueue({
type: ChunkType.TEXT_START
} satisfies TextStartChunk)
isFirstTextChunk = false
}
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: text
})
} else if (part.inlineData) {
controller.enqueue({
type: ChunkType.IMAGE_COMPLETE,
image: {
type: 'base64',
images: [
part.inlineData?.data?.startsWith('data:')
? part.inlineData?.data
: `data:${part.inlineData?.mimeType || 'image/png'};base64,${part.inlineData?.data}`
]
}
})
} else if (part.functionCall) {
toolCalls.push(part.functionCall)
}
})
}
if (candidate.finishReason) {
if (candidate.groundingMetadata) {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
results: candidate.groundingMetadata,
source: WebSearchSource.GEMINI
}
} satisfies LLMWebSearchCompleteChunk)
}
if (toolCalls.length > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: [...toolCalls]
})
toolCalls.length = 0
}
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
completion_tokens:
(chunk.usageMetadata?.totalTokenCount || 0) - (chunk.usageMetadata?.promptTokenCount || 0),
total_tokens: chunk.usageMetadata?.totalTokenCount || 0
}
}
})
}
}
}
if (toolCalls.length > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
}
}
})
}
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): Tool[] {
return mcpToolsToGeminiTools(mcpTools)
}
public convertSdkToolCallToMcp(toolCall: GeminiSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
return geminiFunctionCallToMcpTool(mcpTools, toolCall)
}
public convertSdkToolCallToMcpToolResponse(toolCall: GeminiSdkToolCall, mcpTool: MCPTool): ToolCallResponse {
const parsedArgs = (() => {
try {
return typeof toolCall.args === 'string' ? JSON.parse(toolCall.args) : toolCall.args
} catch {
return toolCall.args
}
})()
return {
id: toolCall.id || nanoid(),
toolCallId: toolCall.id,
tool: mcpTool,
arguments: parsedArgs,
status: 'pending'
} satisfies ToolCallResponse
}
public convertMcpToolResponseToSdkMessageParam(
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
): GeminiSdkMessageParam | undefined {
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
return mcpToolCallResponseToGeminiMessage(mcpToolResponse, resp, isVisionModel(model))
} else if ('toolCallId' in mcpToolResponse) {
return {
role: 'user',
parts: [
{
functionResponse: {
id: mcpToolResponse.toolCallId,
name: mcpToolResponse.tool.id,
response: {
output: !resp.isError ? resp.content : undefined,
error: resp.isError ? resp.content : undefined
}
}
}
]
} satisfies Content
}
return
}
public buildSdkMessages(
currentReqMessages: Content[],
output: string,
toolResults: Content[],
toolCalls: FunctionCall[]
): Content[] {
const parts: Part[] = []
const modelParts: Part[] = []
if (output) {
modelParts.push({
text: output
})
}
toolCalls.forEach((toolCall) => {
modelParts.push({
functionCall: toolCall
})
})
parts.push(
...toolResults
.map((ts) => ts.parts)
.flat()
.filter((p) => p !== undefined)
)
const userMessage: Content = {
role: 'user',
parts: []
}
if (modelParts.length > 0) {
currentReqMessages.push({
role: 'model',
parts: modelParts
})
}
if (parts.length > 0) {
userMessage.parts?.push(...parts)
currentReqMessages.push(userMessage)
}
return currentReqMessages
}
override estimateMessageTokens(message: GeminiSdkMessageParam): number {
return (
message.parts?.reduce((acc, part) => {
if (part.text) {
return acc + estimateTextTokens(part.text)
}
if (part.functionCall) {
return acc + estimateTextTokens(JSON.stringify(part.functionCall))
}
if (part.functionResponse) {
return acc + estimateTextTokens(JSON.stringify(part.functionResponse.response))
}
if (part.inlineData) {
return acc + estimateTextTokens(part.inlineData.data || '')
}
if (part.fileData) {
return acc + estimateTextTokens(part.fileData.fileUri || '')
}
return acc
}, 0) || 0
)
}
public extractMessagesFromSdkPayload(sdkPayload: GeminiSdkParams): GeminiSdkMessageParam[] {
const messageParam: GeminiSdkMessageParam = {
role: 'user',
parts: []
}
if (Array.isArray(sdkPayload.message)) {
sdkPayload.message.forEach((part) => {
if (typeof part === 'string') {
messageParam.parts?.push({ text: part })
} else if (typeof part === 'object') {
messageParam.parts?.push(part)
}
})
}
return [...(sdkPayload.history || []), messageParam]
}
private async base64File(file: FileMetadata) {
const { data } = await window.api.file.base64File(file.id + file.ext)
return {
data,
mimeType: 'application/pdf'
}
}
}
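convertMessageToSdkParam splits previously generated data URLs into Gemini inlineData parts with a regex; a small self-contained example of that conversion (the sample data URL is made up):

// Illustrative only: extract the mime type and base64 payload from a data URL.
const imageUrl = 'data:image/png;base64,iVBORw0KGgo=' // hypothetical sample value
const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
if (matches && matches.length === 3) {
  const part = {
    inlineData: {
      mimeType: matches[1], // 'image/png'
      data: matches[2] // the raw base64 payload
    }
  }
  console.log(part)
}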

View File

@ -1,143 +0,0 @@
import { GoogleGenAI } from '@google/genai'
import { loggerService } from '@logger'
import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
import type { Model, Provider, VertexProvider } from '@renderer/types'
import { isVertexProvider } from '@renderer/utils/provider'
import { isEmpty } from 'lodash'
import { AnthropicVertexClient } from '../anthropic/AnthropicVertexClient'
import { GeminiAPIClient } from './GeminiAPIClient'
const logger = loggerService.withContext('VertexAPIClient')
export class VertexAPIClient extends GeminiAPIClient {
private authHeaders?: Record<string, string>
private authHeadersExpiry?: number
private anthropicVertexClient: AnthropicVertexClient
private vertexProvider: VertexProvider
constructor(provider: Provider) {
super(provider)
// Check the VertexAI configuration
if (!isVertexAIConfigured()) {
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
}
this.anthropicVertexClient = new AnthropicVertexClient(provider)
// If a plain Provider is passed in, convert it to a VertexProvider
if (isVertexProvider(provider)) {
this.vertexProvider = provider
} else {
this.vertexProvider = createVertexProvider(provider)
}
}
override getClientCompatibilityType(model?: Model): string[] {
if (!model) {
return [this.constructor.name]
}
const actualClient = this.getClient(model)
if (actualClient === this) {
return [this.constructor.name]
}
return actualClient.getClientCompatibilityType(model)
}
public getClient(model: Model) {
if (model.id.includes('claude')) {
return this.anthropicVertexClient
}
return this
}
private formatApiHost(baseUrl: string) {
if (baseUrl.endsWith('/v1/')) {
baseUrl = baseUrl.slice(0, -4)
} else if (baseUrl.endsWith('/v1')) {
baseUrl = baseUrl.slice(0, -3)
}
return baseUrl
}
override getBaseURL() {
return this.formatApiHost(this.provider.apiHost)
}
override async getSdkInstance() {
if (this.sdkInstance) {
return this.sdkInstance
}
const { googleCredentials, project, location } = this.vertexProvider
if (!googleCredentials.privateKey || !googleCredentials.clientEmail || !project || !location) {
throw new Error('Vertex AI settings are not configured')
}
const authHeaders = await this.getServiceAccountAuthHeaders()
this.sdkInstance = new GoogleGenAI({
vertexai: true,
project: project,
location: location,
httpOptions: {
apiVersion: this.getApiVersion(),
headers: authHeaders,
baseUrl: isEmpty(this.getBaseURL()) ? undefined : this.getBaseURL()
}
})
return this.sdkInstance
}
/**
* Get auth headers using the service account
*/
private async getServiceAccountAuthHeaders(): Promise<Record<string, string> | undefined> {
const { googleCredentials, project } = this.vertexProvider
// Check whether a service account is configured
if (!googleCredentials.privateKey || !googleCredentials.clientEmail || !project) {
return undefined
}
// Check for existing valid auth headers (refresh 5 minutes before expiry)
const now = Date.now()
if (this.authHeaders && this.authHeadersExpiry && this.authHeadersExpiry - now > 5 * 60 * 1000) {
return this.authHeaders
}
try {
// Fetch auth headers from the main process
this.authHeaders = await window.api.vertexAI.getAuthHeaders({
projectId: project,
serviceAccount: {
privateKey: googleCredentials.privateKey,
clientEmail: googleCredentials.clientEmail
}
})
// Set the expiry time (auth headers are typically valid for 1 hour)
this.authHeadersExpiry = now + 60 * 60 * 1000
return this.authHeaders
} catch (error: any) {
logger.error('Failed to get auth headers:', error)
throw new Error(`Service Account authentication failed: ${error.message}`)
}
}
/**
* Clear the cached auth headers
*/
clearAuthCache(): void {
this.authHeaders = undefined
this.authHeadersExpiry = undefined
const { googleCredentials, project } = this.vertexProvider
if (project && googleCredentials.clientEmail) {
window.api.vertexAI.clearAuthCache(project, googleCredentials.clientEmail)
}
}
}

View File

@ -1,8 +0,0 @@
export * from './ApiClientFactory'
export * from './BaseApiClient'
export * from './types'
// Export specific clients from subdirectories
export * from './anthropic/AnthropicAPIClient'
export * from './openai/OpenAIApiClient'
export * from './openai/OpenAIResponseAPIClient'

View File

@ -1,110 +0,0 @@
import { loggerService } from '@logger'
import { isSupportedModel } from '@renderer/config/models'
import type { Model, Provider } from '@renderer/types'
import type { NewApiModel } from '@renderer/types/sdk'
import { AnthropicAPIClient } from '../anthropic/AnthropicAPIClient'
import type { BaseApiClient } from '../BaseApiClient'
import { GeminiAPIClient } from '../gemini/GeminiAPIClient'
import { MixedBaseAPIClient } from '../MixedBaseApiClient'
import { OpenAIAPIClient } from '../openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from '../openai/OpenAIResponseAPIClient'
const logger = loggerService.withContext('NewAPIClient')
export class NewAPIClient extends MixedBaseAPIClient {
// Use a union type instead of any to keep type safety
protected clients: Map<string, AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient> =
new Map()
protected defaultClient: OpenAIAPIClient
protected currentClient: BaseApiClient
constructor(provider: Provider) {
super(provider)
const claudeClient = new AnthropicAPIClient(provider)
const geminiClient = new GeminiAPIClient(provider)
const openaiClient = new OpenAIAPIClient(provider)
const openaiResponseClient = new OpenAIResponseAPIClient(provider)
this.clients.set('claude', claudeClient)
this.clients.set('gemini', geminiClient)
this.clients.set('openai', openaiClient)
this.clients.set('openai-response', openaiResponseClient)
// Set the default client
this.defaultClient = openaiClient
this.currentClient = this.defaultClient as BaseApiClient
}
override getBaseURL(): string {
if (!this.currentClient) {
return this.provider.apiHost
}
return this.currentClient.getBaseURL()
}
/**
* Get the client for the given model
*/
protected getClient(model: Model): BaseApiClient {
if (!model.endpoint_type) {
throw new Error('Model endpoint type is not defined')
}
if (model.endpoint_type === 'anthropic') {
const client = this.clients.get('claude')
if (!client || !this.isValidClient(client)) {
throw new Error('Failed to get claude client')
}
return client
}
if (model.endpoint_type === 'openai-response') {
const client = this.clients.get('openai-response')
if (!client || !this.isValidClient(client)) {
throw new Error('Failed to get openai-response client')
}
return client
}
if (model.endpoint_type === 'gemini') {
const client = this.clients.get('gemini')
if (!client || !this.isValidClient(client)) {
throw new Error('Failed to get gemini client')
}
return client
}
if (model.endpoint_type === 'openai' || model.endpoint_type === 'image-generation') {
const client = this.clients.get('openai')
if (!client || !this.isValidClient(client)) {
throw new Error('Failed to get openai client')
}
return client
}
throw new Error('Invalid model endpoint type: ' + model.endpoint_type)
}
override async listModels(): Promise<NewApiModel[]> {
try {
const sdk = await this.defaultClient.getSdkInstance()
// Explicitly type the expected response shape so that `data` is recognised.
const response = await sdk.request<{ data: NewApiModel[] }>({
method: 'get',
path: '/models'
})
const models: NewApiModel[] = response.data ?? []
models.forEach((model) => {
model.id = model.id.trim()
})
return models.filter(isSupportedModel)
} catch (error) {
logger.error('Error listing models:', error as Error)
return []
}
}
}
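The dispatch in NewAPIClient.getClient is essentially a lookup from endpoint_type to one of the registered clients; a reduced sketch of that routing, with clients stubbed as strings rather than BaseApiClient instances:

// Simplified routing table mirroring getClient; the real code returns BaseApiClient instances.
const routes = new Map<string, string>([
  ['anthropic', 'claude'],
  ['openai-response', 'openai-response'],
  ['gemini', 'gemini'],
  ['openai', 'openai'],
  ['image-generation', 'openai']
])

function resolveClientKey(endpointType?: string): string {
  if (!endpointType) {
    throw new Error('Model endpoint type is not defined')
  }
  const key = routes.get(endpointType)
  if (!key) {
    throw new Error('Invalid model endpoint type: ' + endpointType)
  }
  return key
}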

View File

@ -1,309 +0,0 @@
import OpenAI, { AzureOpenAI } from '@cherrystudio/openai'
import { loggerService } from '@logger'
import { COPILOT_DEFAULT_HEADERS } from '@renderer/aiCore/provider/constants'
import {
isClaudeReasoningModel,
isOpenAIReasoningModel,
isSupportedModel,
isSupportedReasoningEffortOpenAIModel
} from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import type { SettingsState } from '@renderer/store/settings'
import { type Assistant, type GenerateImageParams, type Model, type Provider } from '@renderer/types'
import type {
OpenAIResponseSdkMessageParam,
OpenAIResponseSdkParams,
OpenAIResponseSdkRawChunk,
OpenAIResponseSdkRawOutput,
OpenAIResponseSdkTool,
OpenAIResponseSdkToolCall,
OpenAISdkMessageParam,
OpenAISdkParams,
OpenAISdkRawChunk,
OpenAISdkRawOutput,
ReasoningEffortOptionalParams
} from '@renderer/types/sdk'
import { withoutTrailingSlash } from '@renderer/utils/api'
import { isOllamaProvider } from '@renderer/utils/provider'
import { BaseApiClient } from '../BaseApiClient'
import { normalizeAzureOpenAIEndpoint } from './azureOpenAIEndpoint'
const logger = loggerService.withContext('OpenAIBaseClient')
/**
* OpenAI base client class: shared functionality across the OpenAI clients
*/
export abstract class OpenAIBaseClient<
TSdkInstance extends OpenAI | AzureOpenAI,
TSdkParams extends OpenAISdkParams | OpenAIResponseSdkParams,
TRawOutput extends OpenAISdkRawOutput | OpenAIResponseSdkRawOutput,
TRawChunk extends OpenAISdkRawChunk | OpenAIResponseSdkRawChunk,
TMessageParam extends OpenAISdkMessageParam | OpenAIResponseSdkMessageParam,
TToolCall extends OpenAI.Chat.Completions.ChatCompletionMessageToolCall | OpenAIResponseSdkToolCall,
TSdkSpecificTool extends OpenAI.Chat.Completions.ChatCompletionTool | OpenAIResponseSdkTool
> extends BaseApiClient<TSdkInstance, TSdkParams, TRawOutput, TRawChunk, TMessageParam, TToolCall, TSdkSpecificTool> {
constructor(provider: Provider) {
super(provider)
}
// Only applies to openai
override getBaseURL(): string {
// apiHost is formatted when called by AiProvider
return this.provider.apiHost
}
override async generateImage({
model,
prompt,
negativePrompt,
imageSize,
batchSize,
seed,
numInferenceSteps,
guidanceScale,
signal,
promptEnhancement
}: GenerateImageParams): Promise<string[]> {
const sdk = await this.getSdkInstance()
const response = (await sdk.request({
method: 'post',
path: '/v1/images/generations',
signal,
body: {
model,
prompt,
negative_prompt: negativePrompt,
image_size: imageSize,
batch_size: batchSize,
seed: seed ? parseInt(seed) : undefined,
num_inference_steps: numInferenceSteps,
guidance_scale: guidanceScale,
prompt_enhancement: promptEnhancement
}
})) as { data: Array<{ url: string }> }
return response.data.map((item) => item.url)
}
override async getEmbeddingDimensions(model: Model): Promise<number> {
let sdk: OpenAI = await this.getSdkInstance()
if (isOllamaProvider(this.provider)) {
const embedBaseUrl = `${this.provider.apiHost.replace(/(\/(api|v1))\/?$/, '')}/v1`
sdk = sdk.withOptions({ baseURL: embedBaseUrl })
}
const data = await sdk.embeddings.create({
model: model.id,
input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi',
encoding_format: this.provider.id === 'voyageai' ? undefined : 'float'
})
return data.data[0].embedding.length
}
override async listModels(): Promise<OpenAI.Models.Model[]> {
try {
const sdk = await this.getSdkInstance()
if (this.provider.id === 'openrouter') {
// https://openrouter.ai/docs/api/api-reference/embeddings/list-embeddings-models
const embedBaseUrl = 'https://openrouter.ai/api/v1/embeddings'
const embedSdk = sdk.withOptions({ baseURL: embedBaseUrl })
const modelPromise = sdk.models.list()
const embedModelPromise = embedSdk.models.list()
const [modelResponse, embedModelResponse] = await Promise.all([modelPromise, embedModelPromise])
const models = [...modelResponse.data, ...embedModelResponse.data]
const uniqueModels = Array.from(new Map(models.map((model) => [model.id, model])).values())
return uniqueModels.filter(isSupportedModel)
}
if (this.provider.id === 'github') {
// GitHub Models uses different baseUrls for the models and chat completions endpoints
const baseUrl = 'https://models.github.ai/catalog/'
const newSdk = sdk.withOptions({ baseURL: baseUrl })
const response = await newSdk.models.list()
// @ts-ignore key is not typed
return response?.body
.map((model) => ({
id: model.id,
description: model.summary,
object: 'model',
owned_by: model.publisher
}))
.filter(isSupportedModel)
}
if (isOllamaProvider(this.provider)) {
const baseUrl = withoutTrailingSlash(this.getBaseURL())
.replace(/\/v1$/, '')
.replace(/\/api$/, '')
const response = await fetch(`${baseUrl}/api/tags`, {
headers: {
Authorization: `Bearer ${this.apiKey}`,
...this.defaultHeaders(),
...this.provider.extra_headers
}
})
if (!response.ok) {
throw new Error(`Ollama server returned ${response.status} ${response.statusText}`)
}
const data = await response.json()
if (!data?.models || !Array.isArray(data.models)) {
throw new Error('Invalid response from Ollama API: missing models array')
}
return data.models.map((model) => ({
id: model.name,
object: 'model',
owned_by: 'ollama'
}))
}
const response = await sdk.models.list()
if (this.provider.id === 'together') {
// @ts-ignore key is not typed
return response?.body.map((model) => ({
id: model.id,
description: model.display_name,
object: 'model',
owned_by: model.organization
}))
}
const models = response.data || []
models.forEach((model) => {
model.id = model.id.trim()
})
return models.filter(isSupportedModel)
} catch (error) {
logger.error('Error listing models:', error as Error)
return []
}
}
override async getSdkInstance() {
if (this.sdkInstance) {
return this.sdkInstance
}
let apiKeyForSdkInstance = this.apiKey
let baseURLForSdkInstance = this.getBaseURL()
logger.debug('baseURLForSdkInstance', { baseURLForSdkInstance })
let headersForSdkInstance = {
...this.defaultHeaders(),
...this.provider.extra_headers
}
if (this.provider.id === 'copilot') {
const defaultHeaders = store.getState().copilot.defaultHeaders
const { token } = await window.api.copilot.getToken(defaultHeaders)
// this.provider.apiKey must not be modified
// this.provider.apiKey = token
apiKeyForSdkInstance = token
baseURLForSdkInstance = this.getBaseURL()
headersForSdkInstance = {
...headersForSdkInstance,
...COPILOT_DEFAULT_HEADERS
}
}
if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') {
this.sdkInstance = new AzureOpenAI({
dangerouslyAllowBrowser: true,
apiKey: apiKeyForSdkInstance,
apiVersion: this.provider.apiVersion,
endpoint: normalizeAzureOpenAIEndpoint(this.provider.apiHost)
}) as TSdkInstance
} else {
this.sdkInstance = new OpenAI({
dangerouslyAllowBrowser: true,
apiKey: apiKeyForSdkInstance,
baseURL: baseURLForSdkInstance,
defaultHeaders: headersForSdkInstance
}) as TSdkInstance
}
return this.sdkInstance
}
override getTemperature(assistant: Assistant, model: Model): number | undefined {
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
return undefined
}
return super.getTemperature(assistant, model)
}
override getTopP(assistant: Assistant, model: Model): number | undefined {
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
return undefined
}
return super.getTopP(assistant, model)
}
/**
* Get the provider specific parameters for the assistant
* @param assistant - The assistant
* @param model - The model
* @returns The provider specific parameters
*/
protected getProviderSpecificParameters(assistant: Assistant, model: Model) {
const { maxTokens } = getAssistantSettings(assistant)
if (this.provider.id === 'openrouter') {
if (model.id.includes('deepseek-r1')) {
return {
include_reasoning: true
}
}
}
if (isOpenAIReasoningModel(model)) {
return {
max_tokens: undefined,
max_completion_tokens: maxTokens
}
}
return {}
}
/**
* Get the reasoning effort for the assistant
* @param assistant - The assistant
* @param model - The model
* @returns The reasoning effort
*/
protected getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams {
if (!isSupportedReasoningEffortOpenAIModel(model)) {
return {}
}
const openAI = getStoreSetting('openAI') as SettingsState['openAI']
const summaryText = openAI?.summaryText || 'off'
let summary: string | undefined = undefined
if (summaryText === 'off' || model.id.includes('o1-pro')) {
summary = undefined
} else {
summary = summaryText
}
const reasoningEffort = assistant?.settings?.reasoning_effort
if (!reasoningEffort) {
return {}
}
if (isSupportedReasoningEffortOpenAIModel(model)) {
return {
reasoning: {
effort: reasoningEffort as OpenAI.ReasoningEffort,
summary: summary
} as OpenAI.Reasoning
}
}
return {}
}
}
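getEmbeddingDimensions above probes the embedding endpoint with a tiny input and reads the vector length; a hedged sketch of the same idea against the standard openai package (the model id and API key are placeholders):

import OpenAI from 'openai'

// Sketch only: probe an embedding model with a one-word input and return its dimension.
async function probeEmbeddingDimensions(modelId: string): Promise<number> {
  const sdk = new OpenAI({ apiKey: process.env.OPENAI_API_KEY })
  const data = await sdk.embeddings.create({
    model: modelId, // e.g. 'text-embedding-3-small' (placeholder)
    input: 'hi',
    encoding_format: 'float'
  })
  return data.data[0].embedding.length
}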

View File

@ -1,769 +0,0 @@
import OpenAI, { AzureOpenAI } from '@cherrystudio/openai'
import type { ResponseInput } from '@cherrystudio/openai/resources/responses/responses'
import { loggerService } from '@logger'
import type { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas'
import type { CompletionsContext } from '@renderer/aiCore/legacy/middleware/types'
import {
isGPT5SeriesModel,
isOpenAIChatCompletionOnlyModel,
isOpenAILLMModel,
isOpenAIOpenWeightModel,
isSupportedReasoningEffortOpenAIModel,
isSupportVerbosityModel,
isVisionModel
} from '@renderer/config/models'
import { estimateTextTokens } from '@renderer/services/TokenService'
import type {
FileMetadata,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
OpenAIServiceTier,
Provider,
ToolCallResponse
} from '@renderer/types'
import { FileTypes, WebSearchSource } from '@renderer/types'
import { ChunkType } from '@renderer/types/chunk'
import type { Message } from '@renderer/types/newMessage'
import type {
OpenAIResponseSdkMessageParam,
OpenAIResponseSdkParams,
OpenAIResponseSdkRawChunk,
OpenAIResponseSdkRawOutput,
OpenAIResponseSdkTool,
OpenAIResponseSdkToolCall
} from '@renderer/types/sdk'
import { addImageFileToContents } from '@renderer/utils/formats'
import {
isSupportedToolUse,
mcpToolCallResponseToOpenAIMessage,
mcpToolsToOpenAIResponseTools,
openAIToolsToMcpTool
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { isSupportDeveloperRoleProvider } from '@renderer/utils/provider'
import { MB } from '@shared/config/constant'
import { t } from 'i18next'
import { isEmpty } from 'lodash'
import type { RequestTransformer, ResponseChunkTransformer } from '../types'
import { OpenAIAPIClient } from './OpenAIApiClient'
import { OpenAIBaseClient } from './OpenAIBaseClient'
const logger = loggerService.withContext('OpenAIResponseAPIClient')
export class OpenAIResponseAPIClient extends OpenAIBaseClient<
OpenAI,
OpenAIResponseSdkParams,
OpenAIResponseSdkRawOutput,
OpenAIResponseSdkRawChunk,
OpenAIResponseSdkMessageParam,
OpenAIResponseSdkToolCall,
OpenAIResponseSdkTool
> {
private client: OpenAIAPIClient
constructor(provider: Provider) {
super(provider)
this.client = new OpenAIAPIClient(provider)
}
private formatApiHost() {
const host = this.provider.apiHost
if (host.endsWith('/openai/v1')) {
return host
} else {
if (host.endsWith('/')) {
return host + 'openai/v1'
} else {
return host + '/openai/v1'
}
}
}
/**
* Select the actual client to use for the given model
*/
public getClient(model: Model) {
if (this.provider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) {
return this
}
if (isOpenAILLMModel(model) && !isOpenAIChatCompletionOnlyModel(model)) {
if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') {
this.provider = { ...this.provider, apiHost: this.formatApiHost() }
if (this.provider.apiVersion === 'preview' || this.provider.apiVersion === 'v1') {
return this
} else {
return this.client
}
}
return this
} else {
return this.client
}
}
/**
* Get the compatibility types of the client that is actually used for this model
*/
public override getClientCompatibilityType(model?: Model): string[] {
if (!model) {
return [this.constructor.name]
}
const actualClient = this.getClient(model)
// Avoid a circular call: if the resolved client is this instance, return its own type directly
if (actualClient === this) {
return [this.constructor.name]
}
return actualClient.getClientCompatibilityType(model)
}
override async getSdkInstance() {
if (this.sdkInstance) {
return this.sdkInstance
}
const baseUrl = this.getBaseURL()
if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') {
return new AzureOpenAI({
dangerouslyAllowBrowser: true,
apiKey: this.apiKey,
apiVersion: this.provider.apiVersion,
baseURL: this.provider.apiHost
})
} else {
return new OpenAI({
dangerouslyAllowBrowser: true,
apiKey: this.apiKey,
baseURL: baseUrl,
defaultHeaders: {
...this.defaultHeaders(),
...this.provider.extra_headers
}
})
}
}
override async createCompletions(
payload: OpenAIResponseSdkParams,
options?: OpenAI.RequestOptions
): Promise<OpenAIResponseSdkRawOutput> {
const sdk = await this.getSdkInstance()
return await sdk.responses.create(payload, options)
}
private async handlePdfFile(file: FileMetadata): Promise<OpenAI.Responses.ResponseInputFile | undefined> {
if (file.size > 32 * MB) return undefined
try {
const pageCount = await window.api.file.pdfInfo(file.id + file.ext)
if (pageCount > 100) return undefined
} catch {
return undefined
}
const { data } = await window.api.file.base64File(file.id + file.ext)
return {
type: 'input_file',
filename: file.origin_name,
file_data: `data:application/pdf;base64,${data}`
} as OpenAI.Responses.ResponseInputFile
}
public async convertMessageToSdkParam(message: Message, model: Model): Promise<OpenAIResponseSdkMessageParam> {
const isVision = isVisionModel(model)
const { textContent, imageContents } = await this.getMessageContent(message)
const fileBlocks = findFileBlocks(message)
const imageBlocks = findImageBlocks(message)
if (fileBlocks.length === 0 && imageBlocks.length === 0 && imageContents.length === 0) {
if (message.role === 'assistant') {
return {
role: 'assistant',
content: textContent
}
} else {
return {
role: message.role === 'system' ? 'user' : message.role,
content: textContent ? [{ type: 'input_text', text: textContent }] : []
} as OpenAI.Responses.EasyInputMessage
}
}
const parts: OpenAI.Responses.ResponseInputContent[] = []
if (imageContents) {
parts.push({
type: 'input_text',
text: textContent
})
}
if (imageContents.length > 0) {
for (const imageContent of imageContents) {
const image = await window.api.file.base64Image(imageContent.fileId + imageContent.fileExt)
parts.push({
detail: 'auto',
type: 'input_image',
image_url: image.data
})
}
}
for (const imageBlock of imageBlocks) {
if (isVision) {
if (imageBlock.file) {
const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
parts.push({
detail: 'auto',
type: 'input_image',
image_url: image.data as string
})
} else if (imageBlock.url && imageBlock.url.startsWith('data:')) {
parts.push({
detail: 'auto',
type: 'input_image',
image_url: imageBlock.url
})
}
}
}
for (const fileBlock of fileBlocks) {
const file = fileBlock.file
if (!file) continue
if (isVision && file.ext === '.pdf') {
const pdfPart = await this.handlePdfFile(file)
if (pdfPart) {
parts.push(pdfPart)
continue
}
}
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
const fileContent = (await window.api.file.read(file.id + file.ext, true)).trim()
parts.push({
type: 'input_text',
text: file.origin_name + '\n' + fileContent
})
}
}
return {
role: message.role === 'system' ? 'user' : message.role,
content: parts
}
}
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): OpenAI.Responses.Tool[] {
return mcpToolsToOpenAIResponseTools(mcpTools)
}
public convertSdkToolCallToMcp(toolCall: OpenAIResponseSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
return openAIToolsToMcpTool(mcpTools, toolCall)
}
public convertSdkToolCallToMcpToolResponse(toolCall: OpenAIResponseSdkToolCall, mcpTool: MCPTool): ToolCallResponse {
const parsedArgs = (() => {
try {
return JSON.parse(toolCall.arguments)
} catch {
return toolCall.arguments
}
})()
return {
id: toolCall.call_id,
toolCallId: toolCall.call_id,
tool: mcpTool,
arguments: parsedArgs,
status: 'pending'
}
}
public convertMcpToolResponseToSdkMessageParam(
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
): OpenAIResponseSdkMessageParam | undefined {
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
return mcpToolCallResponseToOpenAIMessage(mcpToolResponse, resp, isVisionModel(model))
} else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) {
return {
type: 'function_call_output',
call_id: mcpToolResponse.toolCallId,
output: JSON.stringify(resp.content)
}
}
return
}
private convertResponseToMessageContent(response: OpenAI.Responses.Response): ResponseInput {
const content: OpenAI.Responses.ResponseInput = []
response.output.forEach((item) => {
if (item.type !== 'apply_patch_call' && item.type !== 'apply_patch_call_output') {
content.push(item)
} else if (item.type === 'apply_patch_call') {
if (item.operation !== undefined) {
const applyPatchToolCall: OpenAI.Responses.ResponseInputItem.ApplyPatchCall = {
...item,
operation: item.operation
}
content.push(applyPatchToolCall)
} else {
logger.warn('Undefined tool call operation for ApplyPatchToolCall.')
}
} else if (item.type === 'apply_patch_call_output') {
if (item.output !== undefined) {
const applyPatchToolCallOutput: OpenAI.Responses.ResponseInputItem.ApplyPatchCallOutput = {
...item,
output: item.output === null ? undefined : item.output
}
content.push(applyPatchToolCallOutput)
} else {
logger.warn('Undefined tool call output for ApplyPatchToolCallOutput.')
}
}
})
return content
}
public buildSdkMessages(
currentReqMessages: OpenAIResponseSdkMessageParam[],
output: OpenAI.Responses.Response | undefined,
toolResults: OpenAIResponseSdkMessageParam[],
toolCalls: OpenAIResponseSdkToolCall[]
): OpenAIResponseSdkMessageParam[] {
if (!output && toolCalls.length === 0) {
return [...currentReqMessages, ...toolResults]
}
if (!output) {
return [...currentReqMessages, ...(toolCalls || []), ...(toolResults || [])]
}
const content = this.convertResponseToMessageContent(output)
return [...currentReqMessages, ...content, ...(toolResults || [])]
}
override estimateMessageTokens(message: OpenAIResponseSdkMessageParam): number {
let sum = 0
if ('content' in message) {
if (typeof message.content === 'string') {
sum += estimateTextTokens(message.content)
} else if (Array.isArray(message.content)) {
for (const part of message.content) {
switch (part.type) {
case 'input_text':
sum += estimateTextTokens(part.text)
break
case 'input_image':
sum += estimateTextTokens(part.image_url || '')
break
default:
break
}
}
}
}
switch (message.type) {
case 'function_call_output': {
let str = ''
if (typeof message.output === 'string') {
str = message.output
} else {
for (const part of message.output) {
switch (part.type) {
case 'input_text':
str += part.text
break
case 'input_image':
str += part.image_url || ''
break
case 'input_file':
str += part.file_data || ''
break
}
}
}
sum += estimateTextTokens(str)
break
}
case 'function_call':
sum += estimateTextTokens(message.arguments)
break
default:
break
}
return sum
}
public extractMessagesFromSdkPayload(sdkPayload: OpenAIResponseSdkParams): OpenAIResponseSdkMessageParam[] {
if (!sdkPayload.input || typeof sdkPayload.input === 'string') {
return [{ role: 'user', content: sdkPayload.input ?? '' }]
}
return sdkPayload.input
}
getRequestTransformer(): RequestTransformer<OpenAIResponseSdkParams, OpenAIResponseSdkMessageParam> {
return {
transform: async (
coreRequest,
assistant,
model,
isRecursiveCall,
recursiveSdkMessages
): Promise<{
payload: OpenAIResponseSdkParams
messages: OpenAIResponseSdkMessageParam[]
metadata: Record<string, any>
}> => {
const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch, enableGenerateImage } = coreRequest
// 1. Build the system message
const systemMessage: OpenAI.Responses.EasyInputMessage = {
role: 'system',
content: []
}
const systemMessageContent: OpenAI.Responses.ResponseInputMessageContentList = []
const systemMessageInput: OpenAI.Responses.ResponseInputText = {
text: assistant.prompt || '',
type: 'input_text'
}
if (
isSupportedReasoningEffortOpenAIModel(model) &&
isSupportDeveloperRoleProvider(this.provider) &&
isOpenAIOpenWeightModel(model)
) {
systemMessage.role = 'developer'
}
// 2. Set up tools
let tools: OpenAI.Responses.Tool[] = []
const { tools: extraTools } = this.setupToolsConfig({
mcpTools: mcpTools,
model,
enableToolUse: isSupportedToolUse(assistant)
})
systemMessageContent.push(systemMessageInput)
systemMessage.content = systemMessageContent
// 3. Build the user messages
let userMessage: OpenAI.Responses.ResponseInputItem[] = []
if (typeof messages === 'string') {
userMessage.push({ role: 'user', content: messages })
} else {
const processedMessages = addImageFileToContents(messages)
for (const message of processedMessages) {
userMessage.push(await this.convertMessageToSdkParam(message, model))
}
}
// FIXME: better handled via previous_response_id, or by storing the image_generation_call id in the database
if (enableGenerateImage) {
const finalAssistantMessage = userMessage.findLast(
(m) => (m as OpenAI.Responses.EasyInputMessage).role === 'assistant'
) as OpenAI.Responses.EasyInputMessage
const finalUserMessage = userMessage.pop() as OpenAI.Responses.EasyInputMessage
if (finalUserMessage && Array.isArray(finalUserMessage.content)) {
if (finalAssistantMessage && Array.isArray(finalAssistantMessage.content)) {
finalAssistantMessage.content = [...finalAssistantMessage.content, ...finalUserMessage.content]
// Intentionally resend the previous assistant message content (including images and files) as a user message
userMessage = [{ ...finalAssistantMessage, role: 'user' } as OpenAI.Responses.EasyInputMessage]
} else {
userMessage.push(finalUserMessage)
}
}
}
// 4. Assemble the final request messages
let reqMessages: OpenAI.Responses.ResponseInput
if (!systemMessage.content) {
reqMessages = [...userMessage]
} else {
reqMessages = [systemMessage, ...userMessage].filter(Boolean) as OpenAI.Responses.EasyInputMessage[]
}
if (enableWebSearch) {
tools.push({
type: 'web_search_preview'
})
}
if (enableGenerateImage) {
tools.push({
type: 'image_generation',
partial_images: streamOutput ? 2 : undefined
})
}
tools = tools.concat(extraTools)
const reasoningEffort = this.getReasoningEffort(assistant, model)
// minimal cannot be used with web_search tool
if (isGPT5SeriesModel(model) && reasoningEffort.reasoning?.effort === 'minimal' && enableWebSearch) {
reasoningEffort.reasoning.effort = 'low'
}
const commonParams: OpenAIResponseSdkParams = {
model: model.id,
input:
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
? recursiveSdkMessages
: reqMessages,
temperature: this.getTemperature(assistant, model),
top_p: this.getTopP(assistant, model),
max_output_tokens: maxTokens,
stream: streamOutput,
tools: !isEmpty(tools) ? tools : undefined,
// groq uses a different service tier configuration that does not match the OpenAI interface type
service_tier: this.getServiceTier(model) as OpenAIServiceTier,
...(isSupportVerbosityModel(model)
? {
text: {
verbosity: this.getVerbosity(model)
}
}
: {}),
...reasoningEffort, // reuse the (possibly adjusted) reasoning effort computed above
// Apply custom parameters only for chat calls, to avoid affecting translation, summarization and other flows
// Note: user-defined parameters should always override the others
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
}
const timeout = this.getTimeout(model)
return { payload: commonParams, messages: reqMessages, metadata: { timeout } }
}
}
}
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<OpenAIResponseSdkRawChunk> {
const toolCalls: OpenAIResponseSdkToolCall[] = []
const outputItems: OpenAI.Responses.ResponseOutputItem[] = []
let hasBeenCollectedToolCalls = false
let hasReasoningSummary = false
let isFirstThinkingChunk = true
let isFirstTextChunk = true
return () => ({
async transform(chunk: OpenAIResponseSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
if (typeof chunk === 'string') {
try {
chunk = JSON.parse(chunk)
} catch (error) {
logger.error('invalid chunk', { chunk, error })
throw new Error(t('error.chat.chunk.non_json'))
}
}
// Process the chunk
if ('output' in chunk) {
if (ctx._internal?.toolProcessingState) {
ctx._internal.toolProcessingState.output = chunk
}
for (const output of chunk.output) {
switch (output.type) {
case 'message':
if (output.content[0].type === 'output_text') {
if (isFirstTextChunk) {
controller.enqueue({
type: ChunkType.TEXT_START
})
isFirstTextChunk = false
}
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: output.content[0].text
})
if (output.content[0].annotations && output.content[0].annotations.length > 0) {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
source: WebSearchSource.OPENAI_RESPONSE,
results: output.content[0].annotations
}
})
}
}
break
case 'reasoning':
if (isFirstThinkingChunk) {
controller.enqueue({
type: ChunkType.THINKING_START
})
isFirstThinkingChunk = false
}
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: output.summary.map((s) => s.text).join('\n')
})
break
case 'function_call':
toolCalls.push(output)
break
case 'image_generation_call':
controller.enqueue({
type: ChunkType.IMAGE_CREATED
})
controller.enqueue({
type: ChunkType.IMAGE_COMPLETE,
image: {
type: 'base64',
images: [`data:image/png;base64,${output.result}`]
}
})
}
}
if (toolCalls.length > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
}
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: chunk.usage?.input_tokens || 0,
completion_tokens: chunk.usage?.output_tokens || 0,
total_tokens: chunk.usage?.total_tokens || 0
}
}
})
} else {
switch (chunk.type) {
case 'response.output_item.added':
if (chunk.item.type === 'function_call') {
outputItems.push(chunk.item)
} else if (chunk.item.type === 'web_search_call') {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS
})
}
break
case 'response.reasoning_summary_part.added':
if (hasReasoningSummary) {
const separator = '\n\n'
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: separator
})
}
hasReasoningSummary = true
break
case 'response.reasoning_summary_text.delta':
if (isFirstThinkingChunk) {
controller.enqueue({
type: ChunkType.THINKING_START
})
isFirstThinkingChunk = false
}
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: chunk.delta
})
break
case 'response.image_generation_call.generating':
controller.enqueue({
type: ChunkType.IMAGE_CREATED
})
break
case 'response.image_generation_call.partial_image':
controller.enqueue({
type: ChunkType.IMAGE_DELTA,
image: {
type: 'base64',
images: [`data:image/png;base64,${chunk.partial_image_b64}`]
}
})
break
case 'response.image_generation_call.completed':
controller.enqueue({
type: ChunkType.IMAGE_COMPLETE
})
break
case 'response.output_text.delta': {
if (isFirstTextChunk) {
controller.enqueue({
type: ChunkType.TEXT_START
})
isFirstTextChunk = false
}
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: chunk.delta
})
break
}
case 'response.function_call_arguments.done': {
const outputItem: OpenAI.Responses.ResponseOutputItem | undefined = outputItems.find(
(item) => item.id === chunk.item_id
)
if (outputItem) {
if (outputItem.type === 'function_call') {
toolCalls.push({
...outputItem,
arguments: chunk.arguments,
status: 'completed'
})
}
}
break
}
case 'response.content_part.done': {
if (chunk.part.type === 'output_text' && chunk.part.annotations && chunk.part.annotations.length > 0) {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
source: WebSearchSource.OPENAI_RESPONSE,
results: chunk.part.annotations
}
})
}
if (toolCalls.length > 0 && !hasBeenCollectedToolCalls) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
hasBeenCollectedToolCalls = true
}
break
}
case 'response.completed': {
if (ctx._internal?.toolProcessingState) {
ctx._internal.toolProcessingState.output = chunk.response
}
if (toolCalls.length > 0 && !hasBeenCollectedToolCalls) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
hasBeenCollectedToolCalls = true
}
const completion_tokens = chunk.response.usage?.output_tokens || 0
const total_tokens = chunk.response.usage?.total_tokens || 0
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: chunk.response.usage?.input_tokens || 0,
completion_tokens: completion_tokens,
total_tokens: total_tokens
}
}
})
break
}
case 'error': {
controller.enqueue({
type: ChunkType.ERROR,
error: {
message: chunk.message,
code: chunk.code
}
})
break
}
}
}
}
})
}
}

View File

@ -1,4 +0,0 @@
export function normalizeAzureOpenAIEndpoint(apiHost: string): string {
const normalizedHost = apiHost.replace(/\/+$/, '')
return normalizedHost.replace(/\/openai(?:\/v1)?$/i, '')
}

View File

@ -1,56 +0,0 @@
import type OpenAI from '@cherrystudio/openai'
import { loggerService } from '@logger'
import { isSupportedModel } from '@renderer/config/models'
import type { Provider } from '@renderer/types'
import { objectKeys } from '@renderer/types'
import { formatApiHost } from '@renderer/utils'
import { withoutTrailingApiVersion } from '@shared/utils'
import { OpenAIAPIClient } from '../openai/OpenAIApiClient'
const logger = loggerService.withContext('OVMSClient')
export class OVMSClient extends OpenAIAPIClient {
constructor(provider: Provider) {
super(provider)
}
override async listModels(): Promise<OpenAI.Models.Model[]> {
try {
const sdk = await this.getSdkInstance()
const url = formatApiHost(withoutTrailingApiVersion(this.getBaseURL()), true, 'v1')
const chatModelsResponse = await sdk.withOptions({ baseURL: url }).get('/config')
logger.debug(`Chat models response: ${JSON.stringify(chatModelsResponse)}`)
// Parse the config response to extract model information
const config = chatModelsResponse as Record<string, any>
const models = objectKeys(config)
.map((modelName) => {
const modelInfo = config[modelName]
// Check if model has at least one version with "AVAILABLE" state
const hasAvailableVersion = modelInfo?.model_version_status?.some(
(versionStatus: any) => versionStatus?.state === 'AVAILABLE'
)
if (hasAvailableVersion) {
return {
id: modelName,
object: 'model' as const,
owned_by: 'ovms',
created: Date.now()
}
}
return null // Skip models without available versions
})
.filter(Boolean) // Remove null entries
logger.debug(`Processed models: ${JSON.stringify(models)}`)
// Filter out unsupported models
return models.filter((model): model is OpenAI.Models.Model => model !== null && isSupportedModel(model))
} catch (error) {
logger.error(`Error listing OVMS models: ${error}`)
return []
}
}
}

View File

@ -1,72 +0,0 @@
import type OpenAI from '@cherrystudio/openai'
import { loggerService } from '@logger'
import { isSupportedModel } from '@renderer/config/models'
import type { Model, Provider } from '@renderer/types'
import { OpenAIAPIClient } from '../openai/OpenAIApiClient'
const logger = loggerService.withContext('PPIOAPIClient')
export class PPIOAPIClient extends OpenAIAPIClient {
constructor(provider: Provider) {
super(provider)
}
// oxlint-disable-next-line @typescript-eslint/no-unused-vars
override getClientCompatibilityType(_model?: Model): string[] {
return ['OpenAIAPIClient']
}
override async listModels(): Promise<OpenAI.Models.Model[]> {
try {
const sdk = await this.getSdkInstance()
// PPIO requires three separate requests to get all model types
const [chatModelsResponse, embeddingModelsResponse, rerankerModelsResponse] = await Promise.all([
// Chat/completion models
sdk.request({
method: 'get',
path: '/models'
}),
// Embedding models
sdk.request({
method: 'get',
path: '/models?model_type=embedding'
}),
// Reranker models
sdk.request({
method: 'get',
path: '/models?model_type=reranker'
})
])
// Extract models from all responses
// @ts-ignore - PPIO response structure may not be typed
const allModels = [
...((chatModelsResponse as any)?.data || []),
...((embeddingModelsResponse as any)?.data || []),
...((rerankerModelsResponse as any)?.data || [])
]
// Process and standardize model data
const processedModels = allModels.map((model: any) => ({
id: model.id || model.name,
description: model.description || model.display_name || model.summary,
object: 'model' as const,
owned_by: model.owned_by || model.publisher || model.organization || 'ppio',
created: model.created || Date.now()
}))
// Clean up model IDs and filter supported models
processedModels.forEach((model) => {
if (model.id) {
model.id = model.id.trim()
}
})
return processedModels.filter(isSupportedModel)
} catch (error) {
logger.error('Error listing PPIO models:', error as Error)
return []
}
}
}

View File

@ -1,141 +0,0 @@
import type Anthropic from '@anthropic-ai/sdk'
import type OpenAI from '@cherrystudio/openai'
import type { Assistant, MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types'
import type { Provider } from '@renderer/types'
import type {
AnthropicSdkRawChunk,
OpenAIResponseSdkRawChunk,
OpenAIResponseSdkRawOutput,
OpenAISdkRawChunk,
SdkMessageParam,
SdkParams,
SdkRawChunk,
SdkRawOutput,
SdkTool,
SdkToolCall
} from '@renderer/types/sdk'
import type { CompletionsParams, GenericChunk } from '../middleware/schemas'
import type { CompletionsContext } from '../middleware/types'
/**
* Listener for raw SDK stream events
*/
export interface RawStreamListener<TRawChunk = SdkRawChunk> {
onChunk?: (chunk: TRawChunk) => void
onStart?: () => void
onEnd?: () => void
onError?: (error: Error) => void
}
/**
* OpenAI raw stream listener
*/
export interface OpenAIStreamListener extends RawStreamListener<OpenAISdkRawChunk> {
onChoice?: (choice: OpenAI.Chat.Completions.ChatCompletionChunk.Choice) => void
onFinishReason?: (reason: string) => void
}
/**
* OpenAI Response API raw stream listener
*/
export interface OpenAIResponseStreamListener<TChunk extends OpenAIResponseSdkRawChunk = OpenAIResponseSdkRawChunk>
extends RawStreamListener<TChunk> {
onMessage?: (response: OpenAIResponseSdkRawOutput) => void
}
/**
* Anthropic raw stream listener
*/
export interface AnthropicStreamListener<TChunk extends AnthropicSdkRawChunk = AnthropicSdkRawChunk>
extends RawStreamListener<TChunk> {
onContentBlock?: (contentBlock: Anthropic.Messages.ContentBlock) => void
onMessage?: (message: Anthropic.Messages.Message) => void
}
/**
* Request transformer: converts core completion params into an SDK-specific payload
*/
export interface RequestTransformer<
TSdkParams extends SdkParams = SdkParams,
TMessageParam extends SdkMessageParam = SdkMessageParam
> {
transform(
completionsParams: CompletionsParams,
assistant: Assistant,
model: Model,
isRecursiveCall?: boolean,
recursiveSdkMessages?: TMessageParam[]
): Promise<{
payload: TSdkParams
messages: TMessageParam[]
metadata?: Record<string, any>
}>
}
/**
* Response chunk transformer: converts raw SDK chunks into generic chunks
*/
export type ResponseChunkTransformer<TRawChunk extends SdkRawChunk = SdkRawChunk, TContext = any> = (
context?: TContext
) => Transformer<TRawChunk, GenericChunk>
export interface ResponseChunkTransformerContext {
isStreaming: boolean
isEnabledToolCalling: boolean
isEnabledWebSearch: boolean
isEnabledUrlContext: boolean
isEnabledReasoning: boolean
mcpTools: MCPTool[]
provider: Provider
}
/**
* API client interface
*/
export interface ApiClient<
TSdkInstance = any,
TSdkParams extends SdkParams = SdkParams,
TRawOutput extends SdkRawOutput = SdkRawOutput,
TRawChunk extends SdkRawChunk = SdkRawChunk,
TMessageParam extends SdkMessageParam = SdkMessageParam,
TToolCall extends SdkToolCall = SdkToolCall,
TSdkSpecificTool extends SdkTool = SdkTool
> {
provider: Provider
// Core method - in the middleware architecture this may be just a placeholder
// The actual SDK call is handled by SdkCallMiddleware
// completions(params: CompletionsParams): Promise<CompletionsResult>
createCompletions(payload: TSdkParams): Promise<TRawOutput>
// SDK-related methods
getSdkInstance(): Promise<TSdkInstance> | TSdkInstance
getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<TRawChunk>
// Raw stream listener method
attachRawStreamListener?(rawOutput: TRawOutput, listener: RawStreamListener<TRawChunk>): TRawOutput
// Tool conversion methods (kept optional because not all providers support tools)
convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[]
convertMcpToolResponseToSdkMessageParam?(
mcpToolResponse: MCPToolResponse,
resp: any,
model: Model
): TMessageParam | undefined
convertSdkToolCallToMcp?(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined
convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse
// Build the SDK-specific message list for the recursive call after tool execution
buildSdkMessages(
currentReqMessages: TMessageParam[],
output: TRawOutput | string,
toolResults: TMessageParam[],
toolCalls?: TToolCall[]
): TMessageParam[]
// Extract the message array from the SDK payload for type-safe access in middlewares
extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[]
}

View File

@ -1,105 +0,0 @@
import type OpenAI from '@cherrystudio/openai'
import { loggerService } from '@logger'
import type { Provider } from '@renderer/types'
import type { GenerateImageParams } from '@renderer/types'
import { OpenAIAPIClient } from '../openai/OpenAIApiClient'
const logger = loggerService.withContext('ZhipuAPIClient')
export class ZhipuAPIClient extends OpenAIAPIClient {
constructor(provider: Provider) {
super(provider)
}
override getClientCompatibilityType(): string[] {
return ['ZhipuAPIClient']
}
override async generateImage({
model,
prompt,
negativePrompt,
imageSize,
batchSize,
signal,
quality
}: GenerateImageParams): Promise<string[]> {
const sdk = await this.getSdkInstance()
// Zhipu AI uses a different parameter format
const body: any = {
model,
prompt
}
// Zhipu-specific parameter format
body.size = imageSize
body.n = batchSize
if (negativePrompt) {
body.negative_prompt = negativePrompt
}
// Only the cogview-4-250304 model supports the quality and style parameters
if (model === 'cogview-4-250304') {
if (quality) {
body.quality = quality
}
body.style = 'vivid'
}
try {
logger.debug('Calling Zhipu image generation API with params:', body)
const response = await sdk.images.generate(body, { signal })
if (response.data && response.data.length > 0) {
return response.data.map((image: any) => image.url).filter(Boolean)
}
return []
} catch (error) {
logger.error('Zhipu image generation failed:', error as Error)
throw error
}
}
public async listModels(): Promise<OpenAI.Models.Model[]> {
const models = [
'glm-4.7',
'glm-4.6',
'glm-4.6v',
'glm-4.6v-flash',
'glm-4.6v-flashx',
'glm-4.5',
'glm-4.5-x',
'glm-4.5-air',
'glm-4.5-airx',
'glm-4.5-flash',
'glm-4.5v',
'glm-z1-air',
'glm-z1-airx',
'cogview-3-flash',
'cogview-4-250304',
'glm-4-long',
'glm-4-plus',
'glm-4-air-250414',
'glm-4-airx',
'glm-4-flashx',
'glm-4v',
'glm-4v-flash',
'glm-4v-plus-0111',
'glm-4.1v-thinking-flash',
'glm-4-alltools',
'embedding-3'
]
const created = Date.now()
return models.map((id) => ({
id,
owned_by: 'zhipu',
object: 'model' as const,
created
}))
}
}

View File

@ -1,185 +0,0 @@
import { loggerService } from '@logger'
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
import type { BaseApiClient } from '@renderer/aiCore/legacy/clients/BaseApiClient'
import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
import { withSpanResult } from '@renderer/services/SpanManagerService'
import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
import type { RequestOptions, SdkModel } from '@renderer/types/sdk'
import { isSupportedToolUse } from '@renderer/utils/mcp-tools'
import { AihubmixAPIClient } from './clients/aihubmix/AihubmixAPIClient'
import { VertexAPIClient } from './clients/gemini/VertexAPIClient'
import { NewAPIClient } from './clients/newapi/NewAPIClient'
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
import { CompletionsMiddlewareBuilder } from './middleware/builder'
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware'
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
import { applyCompletionsMiddlewares } from './middleware/composer'
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
import { MiddlewareRegistry } from './middleware/register'
import type { CompletionsParams, CompletionsResult } from './middleware/schemas'
const logger = loggerService.withContext('AiProvider')
export default class AiProvider {
private apiClient: BaseApiClient
constructor(provider: Provider) {
// Use the new ApiClientFactory to get a BaseApiClient instance
this.apiClient = ApiClientFactory.create(provider)
}
public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
// 1. Pick the correct client for the model
const model = params.assistant.model
if (!model) {
return Promise.reject(new Error('Model is required'))
}
// Choose how to handle the request based on the client type
let client: BaseApiClient
if (this.apiClient instanceof AihubmixAPIClient) {
// AihubmixAPIClient: pick the appropriate sub-client for the model
client = this.apiClient.getClientForModel(model)
if (client instanceof OpenAIResponseAPIClient) {
client = client.getClient(model) as BaseApiClient
}
} else if (this.apiClient instanceof NewAPIClient) {
client = this.apiClient.getClientForModel(model)
if (client instanceof OpenAIResponseAPIClient) {
client = client.getClient(model) as BaseApiClient
}
} else if (this.apiClient instanceof OpenAIResponseAPIClient) {
// OpenAIResponseAPIClient: pick the API type based on model characteristics
client = this.apiClient.getClient(model) as BaseApiClient
} else if (this.apiClient instanceof VertexAPIClient) {
client = this.apiClient.getClient(model) as BaseApiClient
} else {
// Other clients are used directly
client = this.apiClient
}
// 2. Build the middleware chain
const builder = CompletionsMiddlewareBuilder.withDefaults()
// images api
if (isDedicatedImageGenerationModel(model)) {
builder.clear()
builder
.add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
.add(MiddlewareRegistry[ErrorHandlerMiddlewareName])
.add(MiddlewareRegistry[AbortHandlerMiddlewareName])
.add(MiddlewareRegistry[ImageGenerationMiddlewareName])
} else {
// Existing logic for other models
logger.silly('Builder Params', params)
// Use compatibility type checks to avoid issues from TypeScript narrowing and the decorator pattern
const clientTypes = client.getClientCompatibilityType(model)
const isOpenAICompatible =
clientTypes.includes('OpenAIAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
if (!isOpenAICompatible) {
logger.silly('ThinkingTagExtractionMiddleware is removed')
builder.remove(ThinkingTagExtractionMiddlewareName)
}
const isAnthropicOrOpenAIResponseCompatible =
clientTypes.includes('AnthropicAPIClient') ||
clientTypes.includes('OpenAIResponseAPIClient') ||
clientTypes.includes('AnthropicVertexAPIClient')
if (!isAnthropicOrOpenAIResponseCompatible) {
logger.silly('RawStreamListenerMiddleware is removed')
builder.remove(RawStreamListenerMiddlewareName)
}
if (!params.enableWebSearch) {
logger.silly('WebSearchMiddleware is removed')
builder.remove(WebSearchMiddlewareName)
}
if (!params.mcpTools?.length) {
builder.remove(ToolUseExtractionMiddlewareName)
logger.silly('ToolUseExtractionMiddleware is removed')
builder.remove(McpToolChunkMiddlewareName)
logger.silly('McpToolChunkMiddleware is removed')
}
if (isSupportedToolUse(params.assistant) && isFunctionCallingModel(model)) {
builder.remove(ToolUseExtractionMiddlewareName)
logger.silly('ToolUseExtractionMiddleware is removed')
}
if (params.callType !== 'chat' && params.callType !== 'check' && params.callType !== 'translate') {
logger.silly('AbortHandlerMiddleware is removed')
builder.remove(AbortHandlerMiddlewareName)
}
if (params.callType === 'test') {
builder.remove(ErrorHandlerMiddlewareName)
logger.silly('ErrorHandlerMiddleware is removed')
builder.remove(FinalChunkConsumerMiddlewareName)
logger.silly('FinalChunkConsumerMiddleware is removed')
}
}
const middlewares = builder.build()
logger.silly(
'middlewares',
middlewares.map((m) => m.name)
)
// 3. Create the wrapped SDK method with middlewares
const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)
// 4. Execute the wrapped method with the original params
const result = wrappedCompletionMethod(params, options)
return result
}
public async completionsForTrace(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
const traceName = params.assistant.model?.name
? `${params.assistant.model?.name}.${params.callType}`
: `LLM.${params.callType}`
const traceParams: StartSpanParams = {
name: traceName,
tag: 'LLM',
topicId: params.topicId || '',
modelName: params.assistant.model?.name
}
return await withSpanResult(this.completions.bind(this), traceParams, params, options)
}
public async models(): Promise<SdkModel[]> {
return this.apiClient.listModels()
}
public async getEmbeddingDimensions(model: Model): Promise<number> {
try {
// Use the SDK instance to test embedding capabilities
const dimensions = await this.apiClient.getEmbeddingDimensions(model)
return dimensions
} catch (error) {
logger.error('Error getting embedding dimensions:', error as Error)
throw error
}
}
public async generateImage(params: GenerateImageParams): Promise<string[]> {
if (this.apiClient instanceof AihubmixAPIClient) {
const client = this.apiClient.getClientForModel({ id: params.model } as Model)
return client.generateImage(params)
}
return this.apiClient.generateImage(params)
}
public getBaseURL(): string {
return this.apiClient.getBaseURL()
}
public getApiKey(): string {
return this.apiClient.getApiKey()
}
}

View File

@ -1,182 +0,0 @@
# MiddlewareBuilder Usage Guide
`MiddlewareBuilder` is a tool for dynamically building and managing middleware chains, providing flexible middleware organization and configuration.
## Key Features
### 1. Unified middleware naming
Every middleware is identified by an exported `MIDDLEWARE_NAME` constant:
```typescript
// Example middleware file
export const MIDDLEWARE_NAME = 'SdkCallMiddleware'
export const SdkCallMiddleware: CompletionsMiddleware = ...
```
### 2. The NamedMiddleware interface
Middlewares use the unified `NamedMiddleware` interface format:
```typescript
interface NamedMiddleware<TMiddleware = any> {
  name: string
  middleware: TMiddleware
}
```
### 3. Middleware registry
All available middlewares are managed centrally through `MiddlewareRegistry`:
```typescript
import { MiddlewareRegistry } from './register'
// Look up a middleware by name
const sdkCallMiddleware = MiddlewareRegistry['SdkCallMiddleware']
```
## Basic Usage
### 1. Use the default middleware chain
```typescript
import { CompletionsMiddlewareBuilder } from './builder'
const builder = CompletionsMiddlewareBuilder.withDefaults()
const middlewares = builder.build()
```
### 2. Build a custom middleware chain
```typescript
import { createCompletionsBuilder, MiddlewareRegistry } from './builder'
const builder = createCompletionsBuilder([
  MiddlewareRegistry['AbortHandlerMiddleware'],
  MiddlewareRegistry['TextChunkMiddleware']
])
const middlewares = builder.build()
```
### 3. Adjust the chain dynamically
```typescript
const builder = CompletionsMiddlewareBuilder.withDefaults()
// Add, remove or replace middlewares based on conditions
if (needsLogging) {
  builder.prepend(MiddlewareRegistry['GenericLoggingMiddleware'])
}
if (disableTools) {
  builder.remove('McpToolChunkMiddleware')
}
if (customThinking) {
  builder.replace('ThinkingTagExtractionMiddleware', customThinkingMiddleware)
}
const middlewares = builder.build()
```
### 4. Chained operations
```typescript
const middlewares = CompletionsMiddlewareBuilder.withDefaults()
  .add(MiddlewareRegistry['CustomMiddleware'])
  .insertBefore('SdkCallMiddleware', MiddlewareRegistry['SecurityCheckMiddleware'])
  .remove('WebSearchMiddleware')
  .build()
```
## API Reference
### CompletionsMiddlewareBuilder
**Static methods:**
- `static withDefaults()`: create a builder pre-populated with the default middleware chain
**Instance methods:**
- `add(middleware: NamedMiddleware)`: append a middleware to the end of the chain
- `prepend(middleware: NamedMiddleware)`: add a middleware to the beginning of the chain
- `insertAfter(targetName: string, middleware: NamedMiddleware)`: insert after the named middleware
- `insertBefore(targetName: string, middleware: NamedMiddleware)`: insert before the named middleware
- `replace(targetName: string, middleware: NamedMiddleware)`: replace the named middleware
- `remove(targetName: string)`: remove the named middleware
- `has(name: string)`: check whether the chain contains the named middleware
- `build()`: build the final middleware array
- `getChain()`: get the current chain (including name information)
- `clear()`: clear the middleware chain
- `execute(context, params, middlewareExecutor)`: execute the built chain directly
### Factory functions
- `createCompletionsBuilder(baseChain?)`: create a Completions middleware builder
- `createMethodBuilder(baseChain?)`: create a generic method middleware builder
- `addMiddlewareName(middleware, name)`: helper that attaches a name property to a middleware
### Middleware registry
The registry helpers below can be combined with the builder; a short usage sketch follows this list.
- `MiddlewareRegistry`: the central access point for all registered middlewares
- `getMiddleware(name)`: look up a middleware by name
- `getRegisteredMiddlewareNames()`: list all registered middleware names
- `DefaultCompletionsNamedMiddlewares`: the default Completions middleware chain (in NamedMiddleware format)
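A minimal sketch of combining these registry helpers with the builder; it assumes `GenericLoggingMiddleware` is registered and that `getMiddleware` returns `undefined` for unknown names:
```typescript
import { CompletionsMiddlewareBuilder } from './builder'
import { getMiddleware, getRegisteredMiddlewareNames } from './register'
// Inspect what is registered before building a chain
console.log('registered middlewares:', getRegisteredMiddlewareNames())
// Build a default chain, then prepend a middleware looked up by name
const builder = CompletionsMiddlewareBuilder.withDefaults()
const logging = getMiddleware('GenericLoggingMiddleware')
if (logging) {
  builder.prepend(logging)
}
const middlewares = builder.build()
```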
## Type Safety
The builder provides full TypeScript type support:
- `CompletionsMiddlewareBuilder` is dedicated to the `CompletionsMiddleware` type
- `MethodMiddlewareBuilder` is used for the generic `MethodMiddleware` type
- All middleware operations are based on the `NamedMiddleware<TMiddleware>` interface
## Default Middleware Chain
The default Completions middleware execution order:
1. `FinalChunkConsumerMiddleware` - final consumer
2. `TransformCoreToSdkParamsMiddleware` - parameter transformation
3. `AbortHandlerMiddleware` - abort handling
4. `McpToolChunkMiddleware` - tool handling
5. `WebSearchMiddleware` - web search handling
6. `TextChunkMiddleware` - text handling
7. `ThinkingTagExtractionMiddleware` - thinking tag extraction
8. `ThinkChunkMiddleware` - thinking handling
9. `ResponseTransformMiddleware` - response transformation
10. `StreamAdapterMiddleware` - stream adapter
11. `SdkCallMiddleware` - SDK call
## Usage in AiProvider
```typescript
export default class AiProvider {
  public async completions(params: CompletionsParams): Promise<CompletionsResult> {
    // 1. Build the middleware chain
    const builder = CompletionsMiddlewareBuilder.withDefaults()
    // 2. Adjust it dynamically based on the params
    if (params.enableCustomFeature) {
      builder.insertAfter('StreamAdapterMiddleware', customFeatureMiddleware)
    }
    // 3. Apply the middlewares
    const middlewares = builder.build()
    const wrappedMethod = applyCompletionsMiddlewares(this.apiClient, this.apiClient.createCompletions, middlewares)
    return wrappedMethod(params)
  }
}
```
## Notes
1. **Type compatibility**: `MethodMiddleware` and `CompletionsMiddleware` are not interchangeable; use the corresponding builder for each
2. **Middleware names**: every middleware must export a `MIDDLEWARE_NAME` constant used for identification
3. **Registry management**: new middlewares must be registered in `register.ts`
4. **Default chain**: the default chain is provided by `DefaultCompletionsNamedMiddlewares` and is lazily loaded to avoid circular dependencies
This design makes building middleware chains both flexible and type-safe while keeping the API surface simple.

View File

@ -1,175 +0,0 @@
# Cherry Studio Middleware Specification
This document defines the design, implementation and usage conventions for middlewares in the Cherry Studio `aiCore` module. The goal is a flexible, maintainable and easily extensible middleware system.
## 1. Core Concepts
### 1.1. Middleware
A middleware is a function or object that runs at a specific stage of AI request processing. It can access and modify the request context (`AiProviderMiddlewareContext`) and request parameters (`Params`), and it decides whether to pass the request on to the next middleware or terminate the flow.
Each middleware should focus on a single cross-cutting concern, such as logging, error handling, stream adaptation or feature parsing.
### 1.2. `AiProviderMiddlewareContext` (context object)
This object is passed along the entire middleware chain and contains the following core fields (a rough type sketch follows the list):
- `_apiClientInstance: ApiClient<any,any,any>`: the currently selected, instantiated AI provider client.
- `_coreRequest: CoreRequestType`: the normalized internal core request object.
- `resolvePromise: (value: AggregatedResultType) => void`: resolves the Promise returned by `AiCoreService` when the whole operation completes successfully.
- `rejectPromise: (reason?: any) => void`: rejects the Promise returned by `AiCoreService` when an error occurs.
- `onChunk?: (chunk: Chunk) => void`: the streaming chunk callback provided by the application layer.
- `abortController?: AbortController`: the controller used to abort the request.
- Any other per-request dynamic data that middlewares may read or write.
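A rough type sketch of this context, assembled from the fields listed above; the exact generics and optional extras in the real code may differ:
```typescript
interface AiProviderMiddlewareContext {
  _apiClientInstance: ApiClient<any, any, any>
  _coreRequest: CoreRequestType
  resolvePromise: (value: AggregatedResultType) => void
  rejectPromise: (reason?: any) => void
  onChunk?: (chunk: Chunk) => void
  abortController?: AbortController
  // Plus any per-request dynamic data that other middlewares read or write
  [key: string]: any
}
```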
### 1.3. `MiddlewareName` (middleware name)
To make dynamic operations (insert, replace, remove) convenient, every significant middleware that may be referenced by other logic should have a unique, recognizable name. A TypeScript `enum` is recommended:
```typescript
// example
export enum MiddlewareName {
  LOGGING_START = 'LoggingStartMiddleware',
  LOGGING_END = 'LoggingEndMiddleware',
  ERROR_HANDLING = 'ErrorHandlingMiddleware',
  ABORT_HANDLER = 'AbortHandlerMiddleware',
  // Core Flow
  TRANSFORM_CORE_TO_SDK_PARAMS = 'TransformCoreToSdkParamsMiddleware',
  REQUEST_EXECUTION = 'RequestExecutionMiddleware',
  STREAM_ADAPTER = 'StreamAdapterMiddleware',
  RAW_SDK_CHUNK_TO_APP_CHUNK = 'RawSdkChunkToAppChunkMiddleware',
  // Features
  THINKING_TAG_EXTRACTION = 'ThinkingTagExtractionMiddleware',
  TOOL_USE_TAG_EXTRACTION = 'ToolUseTagExtractionMiddleware',
  MCP_TOOL_HANDLER = 'McpToolHandlerMiddleware',
  // Finalization
  FINAL_CHUNK_CONSUMER = 'FinalChunkConsumerAndNotifierMiddleware'
  // Add more as needed
}
```
Middleware instances need some way to expose their `MiddlewareName`, for example via a `name` property.
### 1.4. Middleware execution structure
We use a flexible middleware execution structure. A middleware is usually a function that receives the `Context`, the `Params`, and a `next` function (used to invoke the next middleware in the chain).
```typescript
// Simplified middleware function signature
type MiddlewareFunction = (
  context: AiProviderMiddlewareContext,
  params: any, // e.g., CompletionsParams
  next: () => Promise<void> // next usually returns a Promise to support async work
) => Promise<void> // the middleware itself may also return a Promise
// Or the more classic Koa/Express style (curried, three stages)
// type MiddlewareFactory = (api?: MiddlewareApi) =>
//   (nextMiddleware: (ctx: AiProviderMiddlewareContext, params: any) => Promise<void>) =>
//     (context: AiProviderMiddlewareContext, params: any) => Promise<void>;
// The current design leans towards the simplified MiddlewareFunction above; the MiddlewareExecutor orchestrates next.
```
The `MiddlewareExecutor` (or `applyMiddlewares`) is responsible for managing the `next` calls.
## 2. `MiddlewareBuilder` (generic middleware builder)
To build and manage middleware chains dynamically, we introduce a generic `MiddlewareBuilder` class.
### 2.1. Design philosophy
`MiddlewareBuilder` offers a fluent API for building middleware chains declaratively. It allows starting from a base chain and then adding, inserting, replacing or removing middlewares based on specific conditions.
### 2.2. API overview
```typescript
class MiddlewareBuilder {
  constructor(baseChain?: Middleware[])
  add(middleware: Middleware): this
  prepend(middleware: Middleware): this
  insertAfter(targetName: MiddlewareName, middlewareToInsert: Middleware): this
  insertBefore(targetName: MiddlewareName, middlewareToInsert: Middleware): this
  replace(targetName: MiddlewareName, newMiddleware: Middleware): this
  remove(targetName: MiddlewareName): this
  build(): Middleware[] // returns the built middleware array
  // Optional: execute the chain directly
  execute(
    context: AiProviderMiddlewareContext,
    params: any,
    middlewareExecutor: (chain: Middleware[], context: AiProviderMiddlewareContext, params: any) => void
  ): void
}
```
### 2.3. Usage example
```typescript
// 1. Define some middleware instances (assume they expose a .name property)
const loggingStart = { name: MiddlewareName.LOGGING_START, fn: loggingStartFn }
const requestExec = { name: MiddlewareName.REQUEST_EXECUTION, fn: requestExecFn }
const streamAdapter = { name: MiddlewareName.STREAM_ADAPTER, fn: streamAdapterFn }
const customFeature = { name: MiddlewareName.CUSTOM_FEATURE, fn: customFeatureFn } // assumed custom
// 2. Define a base chain (optional)
const BASE_CHAIN: Middleware[] = [loggingStart, requestExec, streamAdapter]
// 3. Use the MiddlewareBuilder
const builder = new MiddlewareBuilder(BASE_CHAIN)
if (params.needsCustomFeature) {
  builder.insertAfter(MiddlewareName.STREAM_ADAPTER, customFeature)
}
if (params.isHighSecurityContext) {
  builder.insertBefore(MiddlewareName.REQUEST_EXECUTION, highSecurityCheckMiddleware)
}
if (params.overrideLogging) {
  builder.replace(MiddlewareName.LOGGING_START, newSpecialLoggingMiddleware)
}
// 4. Get the final chain
const finalChain = builder.build()
// 5. Execute it (via an external executor)
// middlewareExecutor(finalChain, context, params);
// or builder.execute(context, params, middlewareExecutor);
```
## 3. `MiddlewareExecutor` / `applyMiddlewares` (middleware executor)
This is the component that receives the middleware chain built by `MiddlewareBuilder` and actually executes it; a minimal sketch follows the list below.
### 3.1. Responsibilities
- Receive the `Middleware[]`, the `AiProviderMiddlewareContext` and the `Params`.
- Iterate over the middlewares in order.
- Provide each middleware with the correct `next` function, which, when called, executes the next middleware in the chain.
- Handle Promises produced during middleware execution (when middlewares are asynchronous).
- Basic error capture (detailed error handling should be done by an `ErrorHandlingMiddleware` inside the chain).
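A minimal sketch of such an executor, based on the simplified `MiddlewareFunction` signature from section 1.4: the chain is composed from last to first so that each middleware receives a `next` that invokes its successor (the real implementation additionally handles typing and error propagation):
```typescript
async function applyMiddlewares(
  chain: MiddlewareFunction[],
  context: AiProviderMiddlewareContext,
  params: any
): Promise<void> {
  // Start from a no-op "next" and wrap each middleware around it, from last to first
  let next: () => Promise<void> = async () => {}
  for (let i = chain.length - 1; i >= 0; i--) {
    const middleware = chain[i]
    const downstream = next
    next = () => middleware(context, params, downstream)
  }
  // Calling next() now runs the first middleware, which drives the rest of the chain
  return next()
}
```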
## 4. Usage in `AiCoreService`
Each core business method in `AiCoreService` (such as `executeCompletions`) is responsible for the steps below; a rough sketch of how they fit together follows the list.
1. Prepare the base data: instantiate the `ApiClient` and convert `Params` into a `CoreRequest`.
2. Instantiate a `MiddlewareBuilder`, possibly passing a base middleware chain specific to that business method.
3. Based on conditions in `Params` and the `CoreRequest`, call `MiddlewareBuilder` methods to adjust the chain dynamically.
4. Call `MiddlewareBuilder.build()` to obtain the final middleware chain.
5. Create the full `AiProviderMiddlewareContext` (including `resolvePromise`, `rejectPromise`, etc.).
6. Call the `MiddlewareExecutor` (or `applyMiddlewares`) to run the built chain.
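Putting the steps together, a hypothetical `executeCompletions` might look roughly like this; `toCoreRequest`, `BASE_COMPLETIONS_CHAIN` and the exact param fields are illustrative names for this sketch, not the actual implementation:
```typescript
class AiCoreService {
  async executeCompletions(params: CompletionsParams): Promise<AggregatedResultType> {
    // 1. Prepare base data
    const apiClient = ApiClientFactory.create(params.provider)
    const coreRequest = toCoreRequest(params) // hypothetical conversion helper
    // 2-4. Build the middleware chain for this call
    const builder = new MiddlewareBuilder(BASE_COMPLETIONS_CHAIN)
    if (!params.mcpTools?.length) {
      builder.remove(MiddlewareName.MCP_TOOL_HANDLER)
    }
    const chain = builder.build()
    // 5-6. Create the full context and run the chain through the executor
    return new Promise((resolve, reject) => {
      const context: AiProviderMiddlewareContext = {
        _apiClientInstance: apiClient,
        _coreRequest: coreRequest,
        resolvePromise: resolve,
        rejectPromise: reject,
        onChunk: params.onChunk
      }
      void applyMiddlewares(chain, context, params)
    })
  }
}
```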
## 5. Composed Features
For composed features (for example "Completions then Translate"), see the sketch after this list:
- Building one single, monolithic `MiddlewareBuilder` for the whole composed flow is not recommended.
- Instead, add a new method on `AiCoreService` that `await`s the underlying atomic `AiCoreService` methods in sequence (for example, first `await this.executeCompletions(...)`, then feed its result into `await this.translateText(...)`).
- Each atomic method internally uses its own `MiddlewareBuilder` instance to build and run the middleware chain for its own stage.
- This maximizes reuse and keeps the responsibilities of each part clear.
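A sketch of that pattern, assuming atomic `executeCompletions` and `translateText` methods exist on the service and that the completion result exposes its final text via `getText()`:
```typescript
class AiCoreService {
  async completionsThenTranslate(params: CompletionsParams, targetLanguage: string): Promise<string> {
    // Each atomic call builds and runs its own middleware chain internally
    const completion = await this.executeCompletions(params)
    // getText() is assumed to expose the final text of the completion result
    return this.translateText({ text: completion.getText(), targetLanguage })
  }
}
```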
## 6. Middleware Naming and Discovery
Giving each middleware a unique `MiddlewareName` is essential for `MiddlewareBuilder` operations such as `insertAfter`, `insertBefore`, `replace` and `remove`. Make sure middleware instances expose their name in some way (for example, via a `name` property).

View File

@ -1,79 +0,0 @@
import { ChunkType } from '@renderer/types/chunk'
import { describe, expect, it } from 'vitest'
import { capitalize, createErrorChunk, isAsyncIterable } from '../utils'
describe('utils', () => {
describe('createErrorChunk', () => {
it('should handle Error instances', () => {
const error = new Error('Test error message')
const result = createErrorChunk(error)
expect(result.type).toBe(ChunkType.ERROR)
expect(result.error.message).toBe('Test error message')
expect(result.error.name).toBe('Error')
expect(result.error.stack).toBeDefined()
})
it('should handle string errors', () => {
const result = createErrorChunk('Something went wrong')
expect(result.error).toEqual({ message: 'Something went wrong' })
})
it('should handle plain objects', () => {
const error = { code: 'NETWORK_ERROR', status: 500 }
const result = createErrorChunk(error)
expect(result.error).toEqual(error)
})
it('should handle null and undefined', () => {
expect(createErrorChunk(null).error).toEqual({})
expect(createErrorChunk(undefined).error).toEqual({})
})
it('should use custom chunk type when provided', () => {
const result = createErrorChunk('error', ChunkType.BLOCK_COMPLETE)
expect(result.type).toBe(ChunkType.BLOCK_COMPLETE)
})
it('should use toString for objects without message', () => {
const error = {
toString: () => 'Custom error'
}
const result = createErrorChunk(error)
expect(result.error.message).toBe('Custom error')
})
})
describe('capitalize', () => {
it('should capitalize first letter', () => {
expect(capitalize('hello')).toBe('Hello')
expect(capitalize('a')).toBe('A')
})
it('should handle edge cases', () => {
expect(capitalize('')).toBe('')
expect(capitalize('123')).toBe('123')
expect(capitalize('Hello')).toBe('Hello')
})
})
describe('isAsyncIterable', () => {
it('should identify async iterables', () => {
async function* gen() {
yield 1
}
expect(isAsyncIterable(gen())).toBe(true)
expect(isAsyncIterable({ [Symbol.asyncIterator]: () => {} })).toBe(true)
})
it('should reject non-async iterables', () => {
expect(isAsyncIterable([1, 2, 3])).toBe(false)
expect(isAsyncIterable(new Set())).toBe(false)
expect(isAsyncIterable({})).toBe(false)
expect(isAsyncIterable(null)).toBe(false)
expect(isAsyncIterable(123)).toBe(false)
expect(isAsyncIterable('string')).toBe(false)
})
})
})

View File

@ -1,245 +0,0 @@
import { loggerService } from '@logger'
import { DefaultCompletionsNamedMiddlewares } from './register'
import type { BaseContext, CompletionsMiddleware, MethodMiddleware } from './types'
const logger = loggerService.withContext('aiCore:MiddlewareBuilder')
/**
* A middleware paired with its identifying name
*/
export interface NamedMiddleware<TMiddleware = any> {
name: string
middleware: TMiddleware
}
/**
* Middleware executor function type
*/
export type MiddlewareExecutor<TContext extends BaseContext = BaseContext> = (
chain: any[],
context: TContext,
params: any
) => Promise<any>
/**
* Generic middleware chain builder
* Provides a fluent API for assembling middleware chains
* Middlewares are typically taken from MiddlewareRegistry as NamedMiddleware items
*/
export class MiddlewareBuilder<TMiddleware = any> {
private middlewares: NamedMiddleware<TMiddleware>[]
/**
* Create a builder
* @param baseChain - optional base chain of NamedMiddleware items
*/
constructor(baseChain?: NamedMiddleware<TMiddleware>[]) {
this.middlewares = baseChain ? [...baseChain] : []
}
/**
* Append a middleware to the end of the chain
* @param middleware - the middleware to append
* @returns this
*/
add(middleware: NamedMiddleware<TMiddleware>): this {
this.middlewares.push(middleware)
return this
}
/**
* Add a middleware to the beginning of the chain
* @param middleware - the middleware to prepend
* @returns this
*/
prepend(middleware: NamedMiddleware<TMiddleware>): this {
this.middlewares.unshift(middleware)
return this
}
/**
* Insert a middleware after the named middleware
* @param targetName - name of the target middleware
* @param middlewareToInsert - the middleware to insert
* @returns this
*/
insertAfter(targetName: string, middlewareToInsert: NamedMiddleware<TMiddleware>): this {
const index = this.findMiddlewareIndex(targetName)
if (index !== -1) {
this.middlewares.splice(index + 1, 0, middlewareToInsert)
} else {
logger.warn(`Middleware '${targetName}' not found, cannot insert`)
}
return this
}
/**
* Insert a middleware before the named middleware
* @param targetName - name of the target middleware
* @param middlewareToInsert - the middleware to insert
* @returns this
*/
insertBefore(targetName: string, middlewareToInsert: NamedMiddleware<TMiddleware>): this {
const index = this.findMiddlewareIndex(targetName)
if (index !== -1) {
this.middlewares.splice(index, 0, middlewareToInsert)
} else {
logger.warn(`Middleware '${targetName}' not found, cannot insert`)
}
return this
}
/**
* Replace the named middleware
* @param targetName - name of the middleware to replace
* @param newMiddleware - the replacement middleware
* @returns this
*/
replace(targetName: string, newMiddleware: NamedMiddleware<TMiddleware>): this {
const index = this.findMiddlewareIndex(targetName)
if (index !== -1) {
this.middlewares[index] = newMiddleware
} else {
logger.warn(`Middleware '${targetName}' not found, cannot replace`)
}
return this
}
/**
* Remove the named middleware
* @param targetName - name of the middleware to remove
* @returns this
*/
remove(targetName: string): this {
const index = this.findMiddlewareIndex(targetName)
if (index !== -1) {
this.middlewares.splice(index, 1)
}
return this
}
/**
* Build the final middleware array
* @returns the middlewares in chain order
*/
build(): TMiddleware[] {
return this.middlewares.map((item) => item.middleware)
}
/**
* Get the current chain, including name information
* @returns a copy of the named middleware chain
*/
getChain(): NamedMiddleware<TMiddleware>[] {
return [...this.middlewares]
}
/**
* Check whether the chain contains the named middleware
* @param name - middleware name to look for
* @returns true if present
*/
has(name: string): boolean {
return this.findMiddlewareIndex(name) !== -1
}
/**
* Number of middlewares currently in the chain
* @returns the chain length
*/
get length(): number {
return this.middlewares.length
}
/**
* Clear the middleware chain
* @returns this
*/
clear(): this {
this.middlewares = []
return this
}
/**
* Execute the built chain directly
* @param context - the middleware context
* @param params - the request parameters
* @param middlewareExecutor - the executor that runs the chain
* @returns the executor's result
*/
execute<TContext extends BaseContext>(
context: TContext,
params: any,
middlewareExecutor: MiddlewareExecutor<TContext>
): Promise<any> {
const chain = this.build()
return middlewareExecutor(chain, context, params)
}
/**
* Find the index of the named middleware
* @param name - middleware name to look for
* @returns the index, or -1 if not found
*/
private findMiddlewareIndex(name: string): number {
return this.middlewares.findIndex((item) => item.name === name)
}
}
/**
* Builder specialized for Completions middlewares
*/
export class CompletionsMiddlewareBuilder extends MiddlewareBuilder<CompletionsMiddleware> {
constructor(baseChain?: NamedMiddleware<CompletionsMiddleware>[]) {
super(baseChain)
}
/**
* Create a builder pre-populated with the default Completions middleware chain
* @returns a CompletionsMiddlewareBuilder
*/
static withDefaults(): CompletionsMiddlewareBuilder {
return new CompletionsMiddlewareBuilder(DefaultCompletionsNamedMiddlewares)
}
}
/**
* Builder for generic method middlewares
*/
export class MethodMiddlewareBuilder extends MiddlewareBuilder<MethodMiddleware> {
constructor(baseChain?: NamedMiddleware<MethodMiddleware>[]) {
super(baseChain)
}
}
// Convenience factory functions
/**
* Create a Completions middleware builder
* @param baseChain - optional base chain
* @returns a CompletionsMiddlewareBuilder
*/
export function createCompletionsBuilder(
baseChain?: NamedMiddleware<CompletionsMiddleware>[]
): CompletionsMiddlewareBuilder {
return new CompletionsMiddlewareBuilder(baseChain)
}
/**
* Create a generic method middleware builder
* @param baseChain - optional base chain
* @returns a MethodMiddlewareBuilder
*/
export function createMethodBuilder(baseChain?: NamedMiddleware<MethodMiddleware>[]): MethodMiddlewareBuilder {
return new MethodMiddlewareBuilder(baseChain)
}
/**
* Helper that attaches a MIDDLEWARE_NAME property to a middleware object
*/
export function addMiddlewareName<T extends object>(middleware: T, name: string): T & { MIDDLEWARE_NAME: string } {
return Object.assign(middleware, { MIDDLEWARE_NAME: name })
}

View File

@ -1,121 +0,0 @@
import { loggerService } from '@logger'
import type { Chunk, ErrorChunk } from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
import type { CompletionsParams, CompletionsResult } from '../schemas'
import type { CompletionsContext, CompletionsMiddleware } from '../types'
const logger = loggerService.withContext('aiCore:AbortHandlerMiddleware')
export const MIDDLEWARE_NAME = 'AbortHandlerMiddleware'
export const AbortHandlerMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
const isRecursiveCall = ctx._internal?.toolProcessingState?.isRecursiveCall || false
// In recursive calls, skip creating an AbortController and reuse the existing one
if (isRecursiveCall) {
const result = await next(ctx, params)
return result
}
const abortController = new AbortController()
const abortFn = (): void => abortController.abort()
let abortSignal: AbortSignal | null = abortController.signal
let abortKey: string
// Prefer the abortKey passed in the params, if provided
if (params.abortKey) {
abortKey = params.abortKey
} else {
// Get the current message ID for abort management
// Prefer the processed messages; fall back to the original messages otherwise
let messageId: string | undefined
if (typeof params.messages === 'string') {
messageId = `message-${Date.now()}-${Math.random().toString(36).substring(2, 9)}`
} else {
const processedMessages = params.messages
const lastUserMessage = processedMessages.findLast((m) => m.role === 'user')
messageId = lastUserMessage?.id
}
if (!messageId) {
logger.warn(`No messageId found, abort functionality will not be available.`)
return next(ctx, params)
}
abortKey = messageId
}
addAbortController(abortKey, abortFn)
const cleanup = (): void => {
removeAbortController(abortKey, abortFn)
if (ctx._internal?.flowControl) {
ctx._internal.flowControl.abortController = undefined
ctx._internal.flowControl.abortSignal = undefined
ctx._internal.flowControl.cleanup = undefined
}
abortSignal = null
}
// Attach the controller to the flowControl state in _internal
if (!ctx._internal.flowControl) {
ctx._internal.flowControl = {}
}
ctx._internal.flowControl.abortController = abortController
ctx._internal.flowControl.abortSignal = abortSignal
ctx._internal.flowControl.cleanup = cleanup
const result = await next(ctx, params)
const error = new DOMException('Request was aborted', 'AbortError')
const streamWithAbortHandler = (result.stream as ReadableStream<Chunk>).pipeThrough(
new TransformStream<Chunk, Chunk | ErrorChunk>({
transform(chunk, controller) {
// If an error chunk has already been received, stop checking the abort state
if (chunk.type === ChunkType.ERROR) {
controller.enqueue(chunk)
return
}
if (abortSignal?.aborted) {
// Convert to an ErrorChunk
const errorChunk: ErrorChunk = {
type: ChunkType.ERROR,
error
}
controller.enqueue(errorChunk)
cleanup()
return
}
// Pass the chunk through unchanged
controller.enqueue(chunk)
},
flush(controller) {
// Check the abort state again when the stream ends
if (abortSignal?.aborted) {
const errorChunk: ErrorChunk = {
type: ChunkType.ERROR,
error
}
controller.enqueue(errorChunk)
}
// Clean up the AbortController after the stream has been fully processed
cleanup()
}
})
)
return {
...result,
stream: streamWithAbortHandler
}
}

View File

@ -1,134 +0,0 @@
import { loggerService } from '@logger'
import { isZhipuModel } from '@renderer/config/models'
import { getStoreProviders } from '@renderer/hooks/useStore'
import { getDefaultModel } from '@renderer/services/AssistantService'
import type { Chunk } from '@renderer/types/chunk'
import type { CompletionsParams, CompletionsResult } from '../schemas'
import type { CompletionsContext } from '../types'
import { createErrorChunk } from '../utils'
const logger = loggerService.withContext('ErrorHandlerMiddleware')
export const MIDDLEWARE_NAME = 'ErrorHandlerMiddleware'
/**
* Error handling middleware
*
* Catches any error raised during the downstream API call, normalizes it,
* invokes the onError callback, and either rethrows it or forwards it as an error chunk
*
* @param config - middleware configuration (currently unused)
* @returns CompletionsMiddleware
*/
export const ErrorHandlerMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params): Promise<CompletionsResult> => {
const { shouldThrow } = params
try {
// Try to run the next middleware
return await next(ctx, params)
} catch (error: any) {
logger.error(error)
let processedError = error
processedError = handleError(error, params)
// 1. Use the shared utility to normalize the error into a standard shape
const errorChunk = createErrorChunk(processedError)
// 2. Invoke the externally provided onError callback
if (params.onError) {
params.onError(processedError)
}
// 3. Depending on configuration, either rethrow the error or pass it downstream as part of the stream
if (shouldThrow) {
throw processedError
}
// If not throwing, create a stream containing only the error chunk and pass it downstream
const errorStream = new ReadableStream<Chunk>({
start(controller) {
controller.enqueue(errorChunk)
controller.close()
}
})
return {
rawOutput: undefined,
stream: errorStream, // pass the stream containing the error downstream
controller: undefined,
getText: () => '' // no text result in the error case
}
}
}
function handleError(error: any, params: CompletionsParams): any {
if (isZhipuModel(params.assistant.model || getDefaultModel()) && error.status && !params.enableGenerateImage) {
return handleZhipuError(error)
}
if (error.status === 401 || error.message.includes('401')) {
return {
...error,
i18nKey: 'chat.no_api_key',
providerId: params.assistant?.model?.provider
}
}
return error
}
/**
* Handles Zhipu-specific errors.
* 1. When enableGenerateImage is false, known Zhipu error patterns (expired token, quota
*    exceeded, insufficient balance, missing API key) are mapped to i18n keys.
* 2. Otherwise the original error is returned unchanged.
*/
function handleZhipuError(error: any): any {
const provider = getStoreProviders().find((p) => p.id === 'zhipu')
const logger = loggerService.withContext('handleZhipuError')
// Error pattern mapping
const errorPatterns = [
{
condition: () => error.status === 401 || /令牌已过期|AuthenticationError|Unauthorized/i.test(error.message),
i18nKey: 'chat.no_api_key',
providerId: provider?.id
},
{
condition: () => error.error?.code === '1304' || /限额|免费配额|free quota|rate limit/i.test(error.message),
i18nKey: 'chat.quota_exceeded',
providerId: provider?.id
},
{
condition: () =>
(error.status === 429 && error.error?.code === '1113') || /余额不足|insufficient balance/i.test(error.message),
i18nKey: 'chat.insufficient_balance',
providerId: provider?.id
},
{
condition: () => !provider?.apiKey?.trim(),
i18nKey: 'chat.no_api_key',
providerId: provider?.id
}
]
// Iterate over the error patterns and return the first match
for (const pattern of errorPatterns) {
if (pattern.condition()) {
return {
...error,
providerId: pattern.providerId,
i18nKey: pattern.i18nKey
}
}
}
// Not a Zhipu-specific error, return the original error
logger.debug('🔧 Not a Zhipu-specific error, returning the original error')
return error
}

View File

@ -1,195 +0,0 @@
import { loggerService } from '@logger'
import type { Usage } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import type { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'FinalChunkConsumerAndNotifierMiddleware'
const logger = loggerService.withContext('FinalChunkConsumerMiddleware')
/**
* Chunk consumption and notification middleware.
*
* Responsibilities:
* 1. Consume chunks from the GenericChunk stream and forward them to the onChunk callback
* 2. Extract usage/metrics data (from SDK chunks or GenericChunks)
* 3. Emit a BLOCK_COMPLETE containing the accumulated data when LLM_RESPONSE_COMPLETE is seen
* 4. Accumulate data across the multiple rounds of an MCP tool-call request
*/
const FinalChunkConsumerMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
const isRecursiveCall =
params._internal?.toolProcessingState?.isRecursiveCall ||
ctx._internal?.toolProcessingState?.isRecursiveCall ||
false
// Initialize accumulated data (only on the top-level call)
if (!isRecursiveCall) {
if (!ctx._internal.customState) {
ctx._internal.customState = {}
}
ctx._internal.observer = {
usage: {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0
},
metrics: {
completion_tokens: 0,
time_completion_millsec: 0,
time_first_token_millsec: 0,
time_thinking_millsec: 0
}
}
// Initialize the text accumulator
ctx._internal.customState.accumulatedText = ''
ctx._internal.customState.startTimestamp = Date.now()
}
// Call the downstream middleware
const result = await next(ctx, params)
// Post-processing: consume the GenericChunk streaming response
if (result.stream) {
const resultFromUpstream = result.stream
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
const reader = resultFromUpstream.getReader()
try {
while (true) {
const { done, value: chunk } = await reader.read()
logger.silly('chunk', chunk)
if (done) {
logger.debug(`Input stream finished.`)
break
}
if (chunk) {
const genericChunk = chunk as GenericChunk
// Extract and accumulate usage/metrics data
extractAndAccumulateUsageMetrics(ctx, genericChunk)
const shouldSkipChunk =
isRecursiveCall &&
(genericChunk.type === ChunkType.BLOCK_COMPLETE ||
genericChunk.type === ChunkType.LLM_RESPONSE_COMPLETE)
if (!shouldSkipChunk) params.onChunk?.(genericChunk)
} else {
logger.warn(`Received undefined chunk before stream was done.`)
}
}
} catch (error: any) {
logger.error(`Error consuming stream:`, error as Error)
// FIXME: temporary workaround. Errors thrown in this middleware cannot be caught by ErrorHandlerMiddleware.
if (params.onError) {
params.onError(error)
}
if (params.shouldThrow) {
throw error
}
} finally {
if (params.onChunk && !isRecursiveCall) {
params.onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
usage: ctx._internal.observer?.usage ? { ...ctx._internal.observer.usage } : undefined,
metrics: ctx._internal.observer?.metrics ? { ...ctx._internal.observer.metrics } : undefined
}
} as Chunk)
if (ctx._internal.toolProcessingState) {
ctx._internal.toolProcessingState = {}
}
}
}
// Add a getText method for the streaming output
const modifiedResult = {
...result,
stream: new ReadableStream<GenericChunk>({
start(controller) {
controller.close()
}
}),
getText: () => {
return ctx._internal.customState?.accumulatedText || ''
}
}
return modifiedResult
} else {
logger.debug(`No GenericChunk stream to process.`)
}
}
return result
}
/**
* Extracts usage/metrics data from a GenericChunk (or raw SDK chunks) and accumulates it
*/
function extractAndAccumulateUsageMetrics(ctx: CompletionsContext, chunk: GenericChunk): void {
if (!ctx._internal.observer?.usage || !ctx._internal.observer?.metrics) {
return
}
try {
if (ctx._internal.customState && !ctx._internal.customState?.firstTokenTimestamp) {
ctx._internal.customState.firstTokenTimestamp = Date.now()
logger.debug(`First token timestamp: ${ctx._internal.customState.firstTokenTimestamp}`)
}
if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
// Extract usage data from the LLM_RESPONSE_COMPLETE chunk
if (chunk.response?.usage) {
accumulateUsage(ctx._internal.observer.usage, chunk.response.usage)
}
if (ctx._internal.customState && ctx._internal.customState?.firstTokenTimestamp) {
ctx._internal.observer.metrics.time_first_token_millsec =
ctx._internal.customState.firstTokenTimestamp - ctx._internal.customState.startTimestamp
ctx._internal.observer.metrics.time_completion_millsec +=
Date.now() - ctx._internal.customState.firstTokenTimestamp
}
}
// Metrics data can also be extracted from other chunk types
if (chunk.type === ChunkType.THINKING_COMPLETE && chunk.thinking_millsec && ctx._internal.observer?.metrics) {
ctx._internal.observer.metrics.time_thinking_millsec = Math.max(
ctx._internal.observer.metrics.time_thinking_millsec || 0,
chunk.thinking_millsec
)
}
} catch (error) {
logger.error('Error extracting usage/metrics from chunk:', error as Error)
}
}
/**
* Accumulates usage data
*/
function accumulateUsage(accumulated: Usage, newUsage: Usage): void {
if (newUsage.prompt_tokens !== undefined) {
accumulated.prompt_tokens += newUsage.prompt_tokens
}
if (newUsage.completion_tokens !== undefined) {
accumulated.completion_tokens += newUsage.completion_tokens
}
if (newUsage.total_tokens !== undefined) {
accumulated.total_tokens += newUsage.total_tokens
}
if (newUsage.thoughts_tokens !== undefined) {
accumulated.thoughts_tokens = (accumulated.thoughts_tokens || 0) + newUsage.thoughts_tokens
}
// Handle OpenRouter specific cost fields
if (newUsage.cost !== undefined) {
accumulated.cost = (accumulated.cost || 0) + newUsage.cost
}
}
export default FinalChunkConsumerMiddleware

View File

@ -1,68 +0,0 @@
import { loggerService } from '@logger'
import type { BaseContext, MethodMiddleware, MiddlewareAPI } from '../types'
const logger = loggerService.withContext('LoggingMiddleware')
export const MIDDLEWARE_NAME = 'GenericLoggingMiddlewares'
/**
* Helper function to safely stringify arguments for logging, handling circular references and large objects.
*
* @param args - The arguments array to stringify.
* @returns A string representation of the arguments.
*/
const stringifyArgsForLogging = (args: any[]): string => {
try {
return args
.map((arg) => {
if (typeof arg === 'function') return '[Function]'
if (typeof arg === 'object' && arg !== null && arg.constructor === Object && Object.keys(arg).length > 20) {
return '[Object with >20 keys]'
}
// Truncate long strings to avoid flooding logs
const stringifiedArg = JSON.stringify(arg, null, 2)
return stringifiedArg && stringifiedArg.length > 200 ? stringifiedArg.substring(0, 200) + '...' : stringifiedArg
})
.join(', ')
} catch (e) {
return '[Error serializing arguments]' // Handle potential errors during stringification
}
}
/**
* Creates a generic logging middleware for provider methods.
*
* The middleware logs the initiation, success/failure, and duration of a method call.
*
* @returns A `MethodMiddleware` instance.
*/
export const createGenericLoggingMiddleware: () => MethodMiddleware = () => {
const middlewareName = 'GenericLoggingMiddleware'
// oxlint-disable-next-line @typescript-eslint/no-unused-vars
return (_: MiddlewareAPI<BaseContext, any[]>) => (next) => async (ctx, args) => {
const methodName = ctx.methodName
const logPrefix = `[${middlewareName} (${methodName})]`
logger.debug(`${logPrefix} Initiating. Args: ${stringifyArgsForLogging(args)}`)
const startTime = Date.now()
try {
const result = await next(ctx, args)
const duration = Date.now() - startTime
// Log successful completion of the method call with duration.
logger.debug(`${logPrefix} Successful. Duration: ${duration}ms`)
return result
} catch (error) {
const duration = Date.now() - startTime
// Log failure of the method call with duration and error information.
logger.error(`${logPrefix} Failed. Duration: ${duration}ms`, error as Error)
throw error // Re-throw the error to be handled by subsequent layers or the caller
}
}
}

View File

@ -1,289 +0,0 @@
import { withSpanResult } from '@renderer/services/SpanManagerService'
import type {
RequestOptions,
SdkInstance,
SdkMessageParam,
SdkParams,
SdkRawChunk,
SdkRawOutput,
SdkTool,
SdkToolCall
} from '@renderer/types/sdk'
import type { BaseApiClient } from '../clients'
import type { CompletionsParams, CompletionsResult } from './schemas'
import type { BaseContext, CompletionsContext, CompletionsMiddleware, MethodMiddleware, MiddlewareAPI } from './types'
import { MIDDLEWARE_CONTEXT_SYMBOL } from './types'
/**
* Creates the initial context for a method call, populating method-specific fields.
*
* @param methodName - The name of the method being called.
* @param originalCallArgs - The actual arguments array from the proxy/method call.
* @param specificContextFactory - An optional factory function to create a specific context type from the base context and original call arguments.
* @returns The created context object.
*/
function createInitialCallContext<TContext extends BaseContext, TCallArgs extends unknown[]>(
methodName: string,
originalCallArgs: TCallArgs, // Renamed from originalArgs to avoid confusion with context.originalArgs
// Factory to create specific context from base and the *original call arguments array*
specificContextFactory?: (base: BaseContext, callArgs: TCallArgs) => TContext
): TContext {
const baseContext: BaseContext = {
[MIDDLEWARE_CONTEXT_SYMBOL]: true,
methodName,
originalArgs: originalCallArgs // Store the full original arguments array in the context
}
if (specificContextFactory) {
return specificContextFactory(baseContext, originalCallArgs)
}
return baseContext as TContext // Fallback to base context if no specific factory
}
/**
* Composes an array of functions from right to left.
*
* `compose(f, g, h)` is `(...args) => f(g(h(...args)))`.
* Each function in `funcs` is expected to take the result of the next function
* (or the initial value for the rightmost function) as its argument.
* @param funcs - Array of functions to compose.
* @returns The composed function.
*/
function compose(...funcs: Array<(...args: any[]) => any>): (...args: any[]) => any {
if (funcs.length === 0) {
// If there are no functions to compose, return a function that returns its first argument, or undefined if there are no args.
return (...args: any[]) => (args.length > 0 ? args[0] : undefined)
}
if (funcs.length === 1) {
return funcs[0]
}
return funcs.reduce(
(a, b) =>
(...args: any[]) =>
a(b(...args))
)
}
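// A minimal illustration of the composition order (the helper functions below are hypothetical):
// compose(a, b) produces (...args) => a(b(...args)), so the rightmost function runs first.
const inc = (n: number) => n + 1
const double = (n: number) => n * 2
const composed = compose(double, inc) // double(inc(n))
console.log(composed(3)) // 8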
/**
* Applies an array of Redux-style middlewares to a generic provider method.
* This version keeps arguments as an array throughout the middleware chain.
*
* @param methodName - The name of the method to be enhanced.
* @param originalMethod - The original method to be wrapped.
* @param middlewares - An array of `MethodMiddleware` to apply.
* @param specificContextFactory - An optional factory to create a specific context for this method.
* @returns An enhanced method with the middlewares applied.
*/
export function applyMethodMiddlewares<
TArgs extends unknown[] = unknown[], // Original method's arguments array type
TResult = unknown,
TContext extends BaseContext = BaseContext
>(
methodName: string,
originalMethod: (...args: TArgs) => Promise<TResult>,
middlewares: MethodMiddleware[], // Expects generic middlewares
specificContextFactory?: (base: BaseContext, callArgs: TArgs) => TContext
): (...args: TArgs) => Promise<TResult> {
// Returns a function matching the original method signature.
return async function enhancedMethod(...methodCallArgs: TArgs): Promise<TResult> {
const ctx = createInitialCallContext<TContext, TArgs>(
methodName,
methodCallArgs, // Pass the actual call arguments array
specificContextFactory
)
const api: MiddlewareAPI<TContext, TArgs> = {
getContext: () => ctx,
getOriginalArgs: () => methodCallArgs // API provides the original arguments array
}
// `finalDispatch` is the function that will ultimately call the original provider method.
// It receives the current context and arguments, which may have been transformed by middlewares.
const finalDispatch = async (
_: TContext,
currentArgs: TArgs // The generic final dispatch expects an args array
): Promise<TResult> => {
return originalMethod(...currentArgs)
}
const chain = middlewares.map((middleware) => middleware(api)) // Cast API if TContext/TArgs mismatch the general ProviderMethodMiddleware
const composedMiddlewareLogic = compose(...chain)
const enhancedDispatch = composedMiddlewareLogic(finalDispatch)
return enhancedDispatch(ctx, methodCallArgs) // Pass the context and the original args array
}
}
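// A hedged usage sketch of applyMethodMiddlewares; the timing middleware and fetchGreeting
// function below are hypothetical, but they follow the MethodMiddleware shape used in this file:
// (api) => (next) => async (ctx, args) => result.
const timingMiddleware: MethodMiddleware = () => (next) => async (ctx, args) => {
  const start = Date.now()
  const result = await next(ctx, args)
  console.log(`${ctx.methodName} took ${Date.now() - start}ms`)
  return result
}

const fetchGreeting = async (name: string) => `hello ${name}`
const enhancedFetchGreeting = applyMethodMiddlewares('fetchGreeting', fetchGreeting, [timingMiddleware])
// await enhancedFetchGreeting('world') // -> 'hello world', with the duration logged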
/**
* Applies an array of `CompletionsMiddleware` to the `completions` method.
* This version adapts for `CompletionsMiddleware` expecting a single `params` object.
* @param originalApiClientInstance - The original API client instance.
* @param originalCompletionsMethod - The original SDK `createCompletions` method.
* @param middlewares - An array of `CompletionsMiddleware` to apply.
* @returns An enhanced `completions` method with the middlewares applied.
*/
export function applyCompletionsMiddlewares<
TSdkInstance extends SdkInstance = SdkInstance,
TSdkParams extends SdkParams = SdkParams,
TRawOutput extends SdkRawOutput = SdkRawOutput,
TRawChunk extends SdkRawChunk = SdkRawChunk,
TMessageParam extends SdkMessageParam = SdkMessageParam,
TToolCall extends SdkToolCall = SdkToolCall,
TSdkSpecificTool extends SdkTool = SdkTool
>(
originalApiClientInstance: BaseApiClient<
TSdkInstance,
TSdkParams,
TRawOutput,
TRawChunk,
TMessageParam,
TToolCall,
TSdkSpecificTool
>,
originalCompletionsMethod: (payload: TSdkParams, options?: RequestOptions) => Promise<TRawOutput>,
middlewares: CompletionsMiddleware<
TSdkParams,
TMessageParam,
TToolCall,
TSdkInstance,
TRawOutput,
TRawChunk,
TSdkSpecificTool
>[]
): (params: CompletionsParams, options?: RequestOptions) => Promise<CompletionsResult> {
// Returns a function matching the original method signature.
const methodName = 'completions'
// Factory to create AiProviderMiddlewareCompletionsContext.
const completionsContextFactory = (
base: BaseContext,
callArgs: [CompletionsParams]
): CompletionsContext<
TSdkParams,
TMessageParam,
TToolCall,
TSdkInstance,
TRawOutput,
TRawChunk,
TSdkSpecificTool
> => {
return {
...base,
methodName,
apiClientInstance: originalApiClientInstance,
originalArgs: callArgs,
_internal: {
toolProcessingState: {
recursionDepth: 0,
isRecursiveCall: false
},
observer: {}
}
}
}
return async function enhancedCompletionsMethod(
params: CompletionsParams,
options?: RequestOptions
): Promise<CompletionsResult> {
// `originalCallArgs` for context creation is `[params]`.
const originalCallArgs: [CompletionsParams] = [params]
const baseContext: BaseContext = {
[MIDDLEWARE_CONTEXT_SYMBOL]: true,
methodName,
originalArgs: originalCallArgs
}
const ctx = completionsContextFactory(baseContext, originalCallArgs)
const api: MiddlewareAPI<
CompletionsContext<TSdkParams, TMessageParam, TToolCall, TSdkInstance, TRawOutput, TRawChunk, TSdkSpecificTool>,
[CompletionsParams]
> = {
getContext: () => ctx,
getOriginalArgs: () => originalCallArgs // The API provides [CompletionsParams]
}
// `finalDispatch` for CompletionsMiddleware: expects (context, params), not (context, args_array).
const finalDispatch = async (
context: CompletionsContext<
TSdkParams,
TMessageParam,
TToolCall,
TSdkInstance,
TRawOutput,
TRawChunk,
TSdkSpecificTool
> // Context passed through
// _currentParams: CompletionsParams // Directly takes params (unused but required for the middleware signature)
): Promise<CompletionsResult> => {
// At this point, middleware should have transformed CompletionsParams to SDK params
// and stored them in context. If no transformation happened, we need to handle it.
const sdkPayload = context._internal?.sdkPayload
if (!sdkPayload) {
throw new Error('SDK payload not found in context. Middleware chain should have transformed parameters.')
}
const abortSignal = context._internal.flowControl?.abortSignal
const timeout = context._internal.customState?.sdkMetadata?.timeout
const methodCall = async (payload) => {
return await originalCompletionsMethod.call(originalApiClientInstance, payload, {
...options,
signal: abortSignal,
timeout
})
}
const traceParams = {
name: `${params.assistant?.model?.name}.client`,
tag: 'LLM',
topicId: params.topicId || '',
modelName: params.assistant?.model?.name
}
// Call the original SDK method with the transformed parameters
const rawOutput = await withSpanResult(methodCall, traceParams, sdkPayload)
// Return the result wrapped in the CompletionsResult format
return { rawOutput } as CompletionsResult
}
const chain = middlewares.map((middleware) => middleware(api))
const composedMiddlewareLogic = compose(...chain)
// `enhancedDispatch` has the signature `(context, params) => Promise<CompletionsResult>`.
const enhancedDispatch = composedMiddlewareLogic(finalDispatch)
// Save enhancedDispatch on the context so middlewares can make recursive calls
// without re-running the whole middleware chain
ctx._internal.enhancedDispatch = enhancedDispatch
// Execute with the context and the single params object.
return enhancedDispatch(ctx, params)
}
}

View File

@ -1,591 +0,0 @@
import { loggerService } from '@logger'
import type { MCPCallToolResponse, MCPTool, MCPToolResponse, Model } from '@renderer/types'
import type { MCPToolCreatedChunk } from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import type { SdkMessageParam, SdkRawOutput, SdkToolCall } from '@renderer/types/sdk'
import {
callBuiltInTool,
callMCPTool,
getMcpServerByTool,
isToolAutoApproved,
parseToolUse,
upsertMCPToolResponse
} from '@renderer/utils/mcp-tools'
import { confirmSameNameTools, requestToolConfirmation, setToolIdToNameMapping } from '@renderer/utils/userConfirmation'
import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import type { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'McpToolChunkMiddleware'
const MAX_TOOL_RECURSION_DEPTH = 20 // Guard against infinite recursion
const logger = loggerService.withContext('McpToolChunkMiddleware')
/**
* MCP tool-handling middleware.
*
* Responsibilities:
* 1. Handle MCP tool-progress chunks, for both the Function Call style and the Tool Use style
* 2. Execute the requested tools (with user confirmation where required)
* 3. Feed the tool results back into the conversation
* 4. Recursively re-invoke the completions chain until no further tool calls are produced
*/
export const McpToolChunkMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
const mcpTools = params.mcpTools || []
// If there are no tools, call the next middleware directly
if (!mcpTools || mcpTools.length === 0) {
return next(ctx, params)
}
const executeWithToolHandling = async (currentParams: CompletionsParams, depth = 0): Promise<CompletionsResult> => {
if (depth >= MAX_TOOL_RECURSION_DEPTH) {
logger.error(`Maximum recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`)
throw new Error(`Maximum tool recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`)
}
let result: CompletionsResult
if (depth === 0) {
result = await next(ctx, currentParams)
} else {
const enhancedCompletions = ctx._internal.enhancedDispatch
if (!enhancedCompletions) {
logger.error(`Enhanced completions method not found, cannot perform recursive call`)
throw new Error('Enhanced completions method not found')
}
ctx._internal.toolProcessingState!.isRecursiveCall = true
ctx._internal.toolProcessingState!.recursionDepth = depth
result = await enhancedCompletions(ctx, currentParams)
}
if (!result.stream) {
logger.error(`No stream returned from enhanced completions`)
throw new Error('No stream returned from enhanced completions')
}
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
const toolHandlingStream = resultFromUpstream.pipeThrough(
createToolHandlingTransform(ctx, currentParams, mcpTools, depth, executeWithToolHandling)
)
return {
...result,
stream: toolHandlingStream
}
}
return executeWithToolHandling(params, 0)
}
/**
* Creates the TransformStream that intercepts tool-call chunks and drives tool execution
*/
function createToolHandlingTransform(
ctx: CompletionsContext,
currentParams: CompletionsParams,
mcpTools: MCPTool[],
depth: number,
executeWithToolHandling: (params: CompletionsParams, depth: number) => Promise<CompletionsResult>
): TransformStream<GenericChunk, GenericChunk> {
const toolCalls: SdkToolCall[] = []
const toolUseResponses: MCPToolResponse[] = []
const allToolResponses: MCPToolResponse[] = [] // Unified array for tracking tool-response state
let hasToolCalls = false
let hasToolUseResponses = false
let streamEnded = false
// Results of tools that have already been executed
const executedToolResults: SdkMessageParam[] = []
const executedToolCalls: SdkToolCall[] = []
const executionPromises: Promise<void>[] = []
return new TransformStream({
async transform(chunk: GenericChunk, controller) {
try {
// Handle MCP tool-progress chunks
logger.silly('chunk', chunk)
if (chunk.type === ChunkType.MCP_TOOL_CREATED) {
const createdChunk = chunk as MCPToolCreatedChunk
// 1. Handle Function Call style tool invocations
if (createdChunk.tool_calls && createdChunk.tool_calls.length > 0) {
hasToolCalls = true
for (const toolCall of createdChunk.tool_calls) {
toolCalls.push(toolCall)
const executionPromise = (async () => {
try {
const result = await executeToolCalls(
ctx,
[toolCall],
mcpTools,
allToolResponses,
currentParams.onChunk,
currentParams.assistant.model!,
currentParams.topicId
)
// Cache the execution results
executedToolResults.push(...result.toolResults)
executedToolCalls.push(...result.confirmedToolCalls)
} catch (error) {
logger.error(`Error executing tool call asynchronously:`, error as Error)
}
})()
executionPromises.push(executionPromise)
}
}
// 2. Handle Tool Use style tool invocations
if (createdChunk.tool_use_responses && createdChunk.tool_use_responses.length > 0) {
hasToolUseResponses = true
for (const toolUseResponse of createdChunk.tool_use_responses) {
toolUseResponses.push(toolUseResponse)
const executionPromise = (async () => {
try {
const result = await executeToolUseResponses(
ctx,
[toolUseResponse], // execute one at a time
mcpTools,
allToolResponses,
currentParams.onChunk,
currentParams.assistant.model!,
currentParams.topicId
)
// Cache the execution results
executedToolResults.push(...result.toolResults)
} catch (error) {
logger.error(`Error executing tool use response asynchronously:`, error as Error)
// Errors here must not affect the execution of other tools
}
})()
executionPromises.push(executionPromise)
}
}
} else {
controller.enqueue(chunk)
}
} catch (error) {
logger.error(`Error processing chunk:`, error as Error)
controller.error(error)
}
},
async flush(controller) {
// When the stream ends, wait for all asynchronous tool executions to finish, then make the recursive call
if (!streamEnded && (hasToolCalls || hasToolUseResponses)) {
streamEnded = true
try {
await Promise.all(executionPromises)
if (executedToolResults.length > 0) {
const output = ctx._internal.toolProcessingState?.output
const newParams = buildParamsWithToolResults(
ctx,
currentParams,
output,
executedToolResults,
executedToolCalls
)
// Before the recursive call, notify the UI that a new LLM response is starting
if (currentParams.onChunk) {
currentParams.onChunk({
type: ChunkType.LLM_RESPONSE_CREATED
})
}
await executeWithToolHandling(newParams, depth + 1)
}
} catch (error) {
logger.error(`Error in tool processing:`, error as Error)
controller.error(error)
} finally {
hasToolCalls = false
hasToolUseResponses = false
}
}
}
})
}
/**
* Executes Function Call style tool invocations
*/
async function executeToolCalls(
ctx: CompletionsContext,
toolCalls: SdkToolCall[],
mcpTools: MCPTool[],
allToolResponses: MCPToolResponse[],
onChunk: CompletionsParams['onChunk'],
model: Model,
topicId?: string
): Promise<{ toolResults: SdkMessageParam[]; confirmedToolCalls: SdkToolCall[] }> {
const mcpToolResponses: MCPToolResponse[] = toolCalls
.map((toolCall) => {
const mcpTool = ctx.apiClientInstance.convertSdkToolCallToMcp(toolCall, mcpTools)
if (!mcpTool) {
return undefined
}
return ctx.apiClientInstance.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
})
.filter((t): t is MCPToolResponse => typeof t !== 'undefined')
if (mcpToolResponses.length === 0) {
logger.warn(`No valid MCP tool responses to execute`)
return { toolResults: [], confirmedToolCalls: [] }
}
// Execute the tools using the existing parseAndCallTools helper
const { toolResults, confirmedToolResponses } = await parseAndCallTools(
mcpToolResponses,
allToolResponses,
onChunk,
(mcpToolResponse, resp, model) => {
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
},
model,
mcpTools,
ctx._internal?.flowControl?.abortSignal,
topicId
)
// Find the original toolCalls that correspond to the confirmed tools
const confirmedToolCalls = toolCalls.filter((toolCall) => {
return confirmedToolResponses.find((confirmed) => {
// Match the original toolCall by the various ID fields
return (
('name' in toolCall &&
(toolCall.name?.includes(confirmed.tool.name) || toolCall.name?.includes(confirmed.tool.id))) ||
confirmed.tool.name === toolCall.id ||
confirmed.tool.id === toolCall.id ||
('toolCallId' in confirmed && confirmed.toolCallId === toolCall.id) ||
('function' in toolCall && toolCall.function.name.toLowerCase().includes(confirmed.tool.name.toLowerCase()))
)
})
})
return { toolResults, confirmedToolCalls }
}
/**
* Executes Tool Use style responses.
* Receives ToolUseResponse[] that have already been parsed upstream.
*/
async function executeToolUseResponses(
ctx: CompletionsContext,
toolUseResponses: MCPToolResponse[],
mcpTools: MCPTool[],
allToolResponses: MCPToolResponse[],
onChunk: CompletionsParams['onChunk'],
model: Model,
topicId?: CompletionsParams['topicId']
): Promise<{ toolResults: SdkMessageParam[] }> {
// Use parseAndCallTools directly on the already-parsed ToolUseResponses
const { toolResults } = await parseAndCallTools(
toolUseResponses,
allToolResponses,
onChunk,
(mcpToolResponse, resp, model) => {
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
},
model,
mcpTools,
ctx._internal?.flowControl?.abortSignal,
topicId
)
return { toolResults }
}
/**
* Builds new request params that include the tool results
*/
function buildParamsWithToolResults(
ctx: CompletionsContext,
currentParams: CompletionsParams,
output: SdkRawOutput | string | undefined,
toolResults: SdkMessageParam[],
confirmedToolCalls: SdkToolCall[]
): CompletionsParams {
// Get the already-converted reqMessages; fall back to the original messages if absent
const currentReqMessages = getCurrentReqMessages(ctx)
const apiClient = ctx.apiClientInstance
// Build the assistant message from the response output
const newReqMessages = apiClient.buildSdkMessages(currentReqMessages, output, toolResults, confirmedToolCalls)
if (output && ctx._internal.toolProcessingState) {
ctx._internal.toolProcessingState.output = undefined
}
// Estimate the token cost of the newly added messages and add it to usage
if (ctx._internal.observer?.usage && newReqMessages.length > currentReqMessages.length) {
try {
const newMessages = newReqMessages.slice(currentReqMessages.length)
const additionalTokens = newMessages.reduce((acc, message) => {
return acc + ctx.apiClientInstance.estimateMessageTokens(message)
}, 0)
if (additionalTokens > 0) {
ctx._internal.observer.usage.prompt_tokens += additionalTokens
ctx._internal.observer.usage.total_tokens += additionalTokens
}
} catch (error) {
logger.error(`Error estimating token usage for new messages:`, error as Error)
}
}
// Update the recursion state
if (!ctx._internal.toolProcessingState) {
ctx._internal.toolProcessingState = {}
}
ctx._internal.toolProcessingState.isRecursiveCall = true
ctx._internal.toolProcessingState.recursionDepth = (ctx._internal.toolProcessingState?.recursionDepth || 0) + 1
return {
...currentParams,
_internal: {
...ctx._internal,
sdkPayload: ctx._internal.sdkPayload,
newReqMessages: newReqMessages
}
}
}
/**
* Gets the current request messages from the SDK payload.
* Uses the abstract method provided by the API client to stay provider-agnostic.
*/
function getCurrentReqMessages(ctx: CompletionsContext): SdkMessageParam[] {
const sdkPayload = ctx._internal.sdkPayload
if (!sdkPayload) {
return []
}
// Use the API client's abstract method to extract the messages, keeping this provider-agnostic
return ctx.apiClientInstance.extractMessagesFromSdkPayload(sdkPayload)
}
export async function parseAndCallTools<R>(
tools: MCPToolResponse[],
allToolResponses: MCPToolResponse[],
onChunk: CompletionsParams['onChunk'],
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
model: Model,
mcpTools?: MCPTool[],
abortSignal?: AbortSignal,
topicId?: CompletionsParams['topicId']
): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }>
export async function parseAndCallTools<R>(
content: string,
allToolResponses: MCPToolResponse[],
onChunk: CompletionsParams['onChunk'],
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
model: Model,
mcpTools?: MCPTool[],
abortSignal?: AbortSignal,
topicId?: CompletionsParams['topicId']
): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }>
export async function parseAndCallTools<R>(
content: string | MCPToolResponse[],
allToolResponses: MCPToolResponse[],
onChunk: CompletionsParams['onChunk'],
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
model: Model,
mcpTools?: MCPTool[],
abortSignal?: AbortSignal,
topicId?: CompletionsParams['topicId']
): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }> {
const toolResults: R[] = []
let curToolResponses: MCPToolResponse[] = []
if (Array.isArray(content)) {
curToolResponses = content
} else {
// process tool use
curToolResponses = parseToolUse(content, mcpTools || [], 0)
}
if (!curToolResponses || curToolResponses.length === 0) {
return { toolResults, confirmedToolResponses: [] }
}
for (const toolResponse of curToolResponses) {
upsertMCPToolResponse(
allToolResponses,
{
...toolResponse,
status: 'pending'
},
onChunk!
)
}
// Create a confirmation promise per tool and handle each confirmation immediately
const confirmedTools: MCPToolResponse[] = []
const pendingPromises: Promise<void>[] = []
curToolResponses.forEach((toolResponse) => {
const server = getMcpServerByTool(toolResponse.tool)
const isAutoApproveEnabled = isToolAutoApproved(toolResponse.tool, server)
let confirmationPromise: Promise<boolean>
if (isAutoApproveEnabled) {
confirmationPromise = Promise.resolve(true)
} else {
setToolIdToNameMapping(toolResponse.id, toolResponse.tool.name)
confirmationPromise = requestToolConfirmation(toolResponse.id, abortSignal).then((confirmed) => {
if (confirmed && server) {
// Automatically confirm other pending tools with the same name
confirmSameNameTools(toolResponse.tool.name)
}
return confirmed
})
}
const processingPromise = confirmationPromise
.then(async (confirmed) => {
if (confirmed) {
// Immediately mark the tool as invoking
upsertMCPToolResponse(
allToolResponses,
{
...toolResponse,
status: 'invoking'
},
onChunk!
)
// Execute the tool call
try {
const images: string[] = []
// Choose the invocation path based on the tool type
const toolCallResponse = toolResponse.tool.isBuiltIn
? await callBuiltInTool(toolResponse)
: await callMCPTool(toolResponse, topicId, model.name)
// Immediately mark the tool as done
upsertMCPToolResponse(
allToolResponses,
{
...toolResponse,
status: 'done',
response: toolCallResponse
},
onChunk!
)
if (!toolCallResponse) {
return
}
// Handle images
for (const content of toolCallResponse.content) {
if (content.type === 'image' && content.data) {
images.push(`data:${content.mimeType};base64,${content.data}`)
}
}
if (images.length) {
onChunk?.({
type: ChunkType.IMAGE_CREATED
})
onChunk?.({
type: ChunkType.IMAGE_COMPLETE,
image: {
type: 'base64',
images: images
}
})
}
// Convert the response to a message and add it to the results
const convertedMessage = convertToMessage(toolResponse, toolCallResponse, model)
if (convertedMessage) {
confirmedTools.push(toolResponse)
toolResults.push(convertedMessage)
}
} catch (error) {
logger.error(`Error executing tool ${toolResponse.id}:`, error as Error)
// Update to the error state
upsertMCPToolResponse(
allToolResponses,
{
...toolResponse,
status: 'done',
response: {
isError: true,
content: [
{
type: 'text',
text: `Error executing tool: ${error instanceof Error ? error.message : 'Unknown error'}`
}
]
}
},
onChunk!
)
}
} else {
// Immediately mark the tool as cancelled
upsertMCPToolResponse(
allToolResponses,
{
...toolResponse,
status: 'cancelled',
response: {
isError: false,
content: [
{
type: 'text',
text: 'Tool call cancelled by user.'
}
]
}
},
onChunk!
)
}
})
.catch((error) => {
logger.error(`Error waiting for tool confirmation ${toolResponse.id}:`, error as Error)
// Immediately mark the tool as cancelled
upsertMCPToolResponse(
allToolResponses,
{
...toolResponse,
status: 'cancelled',
response: {
isError: true,
content: [
{
type: 'text',
text: `Error in confirmation process: ${error instanceof Error ? error.message : 'Unknown error'}`
}
]
}
},
onChunk!
)
})
pendingPromises.push(processingPromise)
})
// Wait for all tools to finish (each tool's status has already been updated in real time)
await Promise.all(pendingPromises)
return { toolResults, confirmedToolResponses: confirmedTools }
}

View File

@ -1,45 +0,0 @@
import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient'
import type { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk'
import type { AnthropicStreamListener } from '../../clients/types'
import type { CompletionsParams, CompletionsResult } from '../schemas'
import type { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'RawStreamListenerMiddleware'
export const RawStreamListenerMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
const result = await next(ctx, params)
// This is where the rawest stream returned by the SDK can be observed
if (result.rawOutput) {
// TODO: move this down into AnthropicAPIClient later
if (ctx.apiClientInstance instanceof AnthropicAPIClient) {
const anthropicListener: AnthropicStreamListener<AnthropicSdkRawChunk> = {
onMessage: (message) => {
if (ctx._internal?.toolProcessingState) {
ctx._internal.toolProcessingState.output = message
}
}
// onContentBlock: (contentBlock) => {
// console.log(`[${MIDDLEWARE_NAME}] 📝 Anthropic content block:`, contentBlock.type)
// }
}
const specificApiClient = ctx.apiClientInstance as AnthropicAPIClient
const monitoredOutput = specificApiClient.attachRawStreamListener(
result.rawOutput as AnthropicSdkRawOutput,
anthropicListener
)
return {
...result,
rawOutput: monitoredOutput
}
}
}
return result
}

View File

@ -1,88 +0,0 @@
import { loggerService } from '@logger'
import type { SdkRawChunk } from '@renderer/types/sdk'
import type { ResponseChunkTransformerContext } from '../../clients/types'
import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import type { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'ResponseTransformMiddleware'
const logger = loggerService.withContext('ResponseTransformMiddleware')
/**
* Response transform middleware.
*
* Responsibilities:
* 1. Handle ReadableStream responses
* 2. Use the ApiClient's getResponseChunkTransformer() to convert raw SDK response chunks into the generic format
* 3. Pass the transformed ReadableStream downstream for further processing
*
* Runs after StreamAdapterMiddleware
*/
export const ResponseTransformMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
// Call the downstream middleware
const result = await next(ctx, params)
// Post-processing: transform the raw SDK response chunks
if (result.stream) {
const adaptedStream = result.stream
// Handle ReadableStream-typed streams
if (adaptedStream instanceof ReadableStream) {
const apiClient = ctx.apiClientInstance
if (!apiClient) {
logger.error(`ApiClient instance not found in context`)
throw new Error('ApiClient instance not found in context')
}
// Get the response chunk transformer
const responseChunkTransformer = apiClient.getResponseChunkTransformer(ctx)
if (!responseChunkTransformer) {
logger.warn(`No ResponseChunkTransformer available, skipping transformation`)
return result
}
const assistant = params.assistant
const model = assistant?.model
if (!assistant || !model) {
logger.error(`Assistant or Model not found for transformation`)
throw new Error('Assistant or Model not found for transformation')
}
const transformerContext: ResponseChunkTransformerContext = {
isStreaming: params.streamOutput || false,
isEnabledToolCalling: (params.mcpTools && params.mcpTools.length > 0) || false,
isEnabledWebSearch: params.enableWebSearch || false,
isEnabledUrlContext: params.enableUrlContext || false,
isEnabledReasoning: params.enableReasoning || false,
mcpTools: params.mcpTools || [],
provider: ctx.apiClientInstance?.provider
}
logger.debug(`Transforming raw SDK chunks with context:`, transformerContext)
try {
// Create the transformed stream
const genericChunkTransformStream = (adaptedStream as ReadableStream<SdkRawChunk>).pipeThrough<GenericChunk>(
new TransformStream<SdkRawChunk, GenericChunk>(responseChunkTransformer(transformerContext))
)
// Put the transformed ReadableStream on the result for downstream middlewares
return {
...result,
stream: genericChunkTransformStream
}
} catch (error) {
logger.error('Error during chunk transformation:', error as Error)
throw error
}
}
}
// If there is no stream, or it is not a ReadableStream, return the original result
return result
}

View File

@ -1,56 +0,0 @@
import type { SdkRawChunk } from '@renderer/types/sdk'
import { asyncGeneratorToReadableStream, createSingleChunkReadableStream } from '@renderer/utils/stream'
import type { CompletionsParams, CompletionsResult } from '../schemas'
import type { CompletionsContext, CompletionsMiddleware } from '../types'
import { isAsyncIterable } from '../utils'
export const MIDDLEWARE_NAME = 'StreamAdapterMiddleware'
/**
* Stream adapter middleware.
*
* Responsibilities:
* 1. Take the raw SDK output (often an AsyncIterable stream)
* 2. Convert the AsyncIterable into a WHATWG ReadableStream
* 3. Expose it downstream as `stream`
*
* If ResponseTransformMiddleware has already produced a transformed stream, that stream is used as-is.
*/
export const StreamAdapterMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
// TODO: this is the point closest to the actual API request; calling next here starts the request.
// Whether stream adaptation should also live here is debatable.
// Call the downstream middleware
const result = await next(ctx, params)
if (
result.rawOutput &&
!(result.rawOutput instanceof ReadableStream) &&
isAsyncIterable<SdkRawChunk>(result.rawOutput)
) {
const whatwgReadableStream: ReadableStream<SdkRawChunk> = asyncGeneratorToReadableStream<SdkRawChunk>(
result.rawOutput
)
return {
...result,
stream: whatwgReadableStream
}
} else if (result.rawOutput && result.rawOutput instanceof ReadableStream) {
return {
...result,
stream: result.rawOutput
}
} else if (result.rawOutput) {
// Non-streaming output: force it into a single-chunk readable stream
const whatwgReadableStream: ReadableStream<SdkRawChunk> = createSingleChunkReadableStream<SdkRawChunk>(
result.rawOutput as SdkRawChunk
)
return {
...result,
stream: whatwgReadableStream
}
}
return result
}
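// A minimal sketch of the AsyncIterable -> ReadableStream adaptation this middleware relies on.
// asyncGeneratorToReadableStream presumably does something along these lines; the helper below
// is illustrative only and uses nothing but standard web-stream APIs.
function toReadableStream<T>(iterable: AsyncIterable<T>): ReadableStream<T> {
  const iterator = iterable[Symbol.asyncIterator]()
  return new ReadableStream<T>({
    async pull(controller) {
      const { value, done } = await iterator.next()
      if (done) {
        controller.close()
      } else {
        controller.enqueue(value)
      }
    },
    async cancel(reason) {
      // Propagate cancellation back to the async iterator
      await iterator.return?.(reason)
    }
  })
}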

View File

@ -1,106 +0,0 @@
import { loggerService } from '@logger'
import { ChunkType } from '@renderer/types/chunk'
import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import type { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'TextChunkMiddleware'
const logger = loggerService.withContext('TextChunkMiddleware')
/**
* Text chunk middleware.
*
* Responsibilities:
* 1. Accumulate TEXT_DELTA chunks
* 2. Track the accumulated text across chunks
* 3. Emit a TEXT_COMPLETE event when the text segment ends
* 4. Store the final text for later use (e.g. web-search post-processing and tool handling)
* 5. Invoke the onResponse callback with incremental and final text
*/
export const TextChunkMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
// Call the downstream middleware
const result = await next(ctx, params)
// Post-processing: transform the text content of the streaming response
if (result.stream) {
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
const assistant = params.assistant
const model = params.assistant?.model
if (!assistant || !model) {
logger.warn(`Missing assistant or model information, skipping text processing`)
return result
}
// State shared across chunks
let accumulatedTextContent = ''
const enhancedTextStream = resultFromUpstream.pipeThrough(
new TransformStream<GenericChunk, GenericChunk>({
transform(chunk: GenericChunk, controller) {
logger.silly('chunk', chunk)
if (chunk.type === ChunkType.TEXT_DELTA) {
if (model.supported_text_delta === false) {
accumulatedTextContent = chunk.text
} else {
accumulatedTextContent += chunk.text
}
// Handle the onResponse callback - send an incremental text update
if (params.onResponse) {
params.onResponse(accumulatedTextContent, false)
}
controller.enqueue({
...chunk,
text: accumulatedTextContent // incremental update
})
} else if (accumulatedTextContent && chunk.type !== ChunkType.TEXT_START) {
ctx._internal.customState!.accumulatedText = accumulatedTextContent
if (ctx._internal.toolProcessingState && !ctx._internal.toolProcessingState?.output) {
ctx._internal.toolProcessingState.output = accumulatedTextContent
}
if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
// Handle the onResponse callback - send the final full text
if (params.onResponse) {
params.onResponse(accumulatedTextContent, true)
}
controller.enqueue({
type: ChunkType.TEXT_COMPLETE,
text: accumulatedTextContent
})
controller.enqueue(chunk)
} else {
controller.enqueue({
type: ChunkType.TEXT_COMPLETE,
text: accumulatedTextContent
})
controller.enqueue(chunk)
}
accumulatedTextContent = ''
} else {
// Pass other chunk types through unchanged
controller.enqueue(chunk)
}
}
})
)
// Update the response result
return {
...result,
stream: enhancedTextStream
}
} else {
logger.warn(`No stream to process or not a ReadableStream. Returning original result.`)
}
}
return result
}

View File

@ -1,99 +0,0 @@
import { loggerService } from '@logger'
import type { ThinkingCompleteChunk, ThinkingDeltaChunk } from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import type { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'ThinkChunkMiddleware'
const logger = loggerService.withContext('ThinkChunkMiddleware')
/**
* Thinking chunk middleware.
*
* For v2 ApiClients that already emit reasoning (the SDK chunk's reasoning field) as dedicated chunks.
*
* Responsibilities:
* 1. Accumulate THINKING_DELTA chunks
* 2. Track the thinking duration
* 3. Emit a THINKING_COMPLETE event once a non-thinking chunk arrives
*/
export const ThinkChunkMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
// Call the downstream middleware
const result = await next(ctx, params)
// Post-processing: handle thinking content
if (result.stream) {
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
// Check whether there is a stream to process
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
// thinking processing state
let accumulatedThinkingContent = ''
let hasThinkingContent = false
let thinkingStartTime = 0
const processedStream = resultFromUpstream.pipeThrough(
new TransformStream<GenericChunk, GenericChunk>({
transform(chunk: GenericChunk, controller) {
if (chunk.type === ChunkType.THINKING_DELTA) {
const thinkingChunk = chunk as ThinkingDeltaChunk
// Record the start time when the first piece of thinking content arrives
if (!hasThinkingContent) {
hasThinkingContent = true
thinkingStartTime = Date.now()
}
accumulatedThinkingContent += thinkingChunk.text
// Update the thinking duration and pass the chunk on
const enhancedChunk: ThinkingDeltaChunk = {
...thinkingChunk,
text: accumulatedThinkingContent,
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
}
controller.enqueue(enhancedChunk)
} else if (hasThinkingContent && thinkingStartTime > 0 && chunk.type !== ChunkType.THINKING_START) {
// When any non-THINKING_DELTA chunk arrives and thinking content has accumulated, emit THINKING_COMPLETE
const thinkingCompleteChunk: ThinkingCompleteChunk = {
type: ChunkType.THINKING_COMPLETE,
text: accumulatedThinkingContent,
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
}
controller.enqueue(thinkingCompleteChunk)
hasThinkingContent = false
accumulatedThinkingContent = ''
thinkingStartTime = 0
// Continue passing the current chunk through
controller.enqueue(chunk)
} else {
// Pass everything else through unchanged
controller.enqueue(chunk)
}
}
})
)
// Update the response result
return {
...result,
stream: processedStream
}
} else {
logger.warn(`No generic chunk stream to process or not a ReadableStream.`)
}
}
return result
}

View File

@ -1,81 +0,0 @@
import { loggerService } from '@logger'
import { ChunkType } from '@renderer/types/chunk'
import type { CompletionsParams, CompletionsResult } from '../schemas'
import type { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'TransformCoreToSdkParamsMiddleware'
const logger = loggerService.withContext('TransformCoreToSdkParamsMiddleware')
/**
* Transforms the core completions request (CompletionsParams) into SDK-specific parameters,
* using the requestTransformer of the ApiClient instance.
*/
export const TransformCoreToSdkParamsMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
const internal = ctx._internal
// 🔧 Detect recursive calls: check whether params carries pre-processed SDK messages
const isRecursiveCall = internal?.toolProcessingState?.isRecursiveCall || false
const newSdkMessages = params._internal?.newReqMessages
const apiClient = ctx.apiClientInstance
if (!apiClient) {
logger.error(`ApiClient instance not found in context.`)
throw new Error('ApiClient instance not found in context')
}
// Check whether a requestTransformer is available
const requestTransformer = apiClient.getRequestTransformer()
if (!requestTransformer) {
logger.warn(`ApiClient does not have getRequestTransformer method, skipping transformation`)
const result = await next(ctx, params)
return result
}
// Make sure assistant and model are available; the transformer needs them
const assistant = params.assistant
const model = params.assistant.model
if (!assistant || !model) {
logger.error(`Assistant or Model not found for transformation.`)
throw new Error('Assistant or Model not found for transformation')
}
try {
const transformResult = await requestTransformer.transform(
params,
assistant,
model,
isRecursiveCall,
newSdkMessages
)
const { payload: sdkPayload, metadata } = transformResult
// Store the SDK-specific payload and metadata in the context for downstream middlewares
ctx._internal.sdkPayload = sdkPayload
if (metadata) {
ctx._internal.customState = {
...ctx._internal.customState,
sdkMetadata: metadata
}
}
if (params.enableGenerateImage) {
params.onChunk?.({
type: ChunkType.IMAGE_CREATED
})
}
return next(ctx, params)
} catch (error) {
logger.error('Error during request transformation:', error as Error)
// Let the error propagate; specific error handling could also be done here
throw error
}
}

View File

@ -1,102 +0,0 @@
import { loggerService } from '@logger'
import { ChunkType } from '@renderer/types/chunk'
import { convertLinks, flushLinkConverterBuffer } from '@renderer/utils/linkConverter'
import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import type { CompletionsContext, CompletionsMiddleware } from '../types'
const logger = loggerService.withContext('WebSearchMiddleware')
export const MIDDLEWARE_NAME = 'WebSearchMiddleware'
/**
* Web search middleware - GenericChunk stream processing.
*
* Responsibilities:
* 1. Track web search events
* 2. Post-process web search results (link conversion in text chunks)
* 3. Maintain web-search related state
*
* Note: identifying and generating web search results is already handled in the ApiClient's response transformer
*/
export const WebSearchMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
ctx._internal.webSearchState = {
results: undefined
}
// Call the downstream middleware
const result = await next(ctx, params)
let isFirstChunk = true
// Post-processing: track web search events
if (result.stream) {
const resultFromUpstream = result.stream
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
// Web search state tracking
const enhancedStream = (resultFromUpstream as ReadableStream<GenericChunk>).pipeThrough(
new TransformStream<GenericChunk, GenericChunk>({
transform(chunk: GenericChunk, controller) {
if (chunk.type === ChunkType.TEXT_DELTA) {
// Convert links using whatever web search results are currently available
const text = chunk.text
const result = convertLinks(text, isFirstChunk)
if (isFirstChunk) {
isFirstChunk = false
}
// - If content was buffered, convertLinks is waiting for the next chunk; don't reuse the original text, to avoid duplication
// - If nothing was buffered and the result is empty, something else went wrong; fall back to the original text
let finalText: string
if (result.hasBufferedContent) {
// Content was buffered: use the processed result (possibly empty, waiting for the next chunk)
finalText = result.text
} else {
// Nothing was buffered, so the fallback is safe
finalText = result.text || text
}
// Only emit the chunk when finalText is non-empty
if (finalText) {
controller.enqueue({
...chunk,
text: finalText
})
}
} else if (chunk.type === ChunkType.LLM_WEB_SEARCH_COMPLETE) {
// Stash the web search results for link completion
ctx._internal.webSearchState!.results = chunk.llm_web_search
// Keep passing the web-search-complete event downstream
controller.enqueue(chunk)
} else if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
// When the stream ends, flush the link converter buffer and emit any remaining content
const remainingText = flushLinkConverterBuffer()
if (remainingText) {
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: remainingText
})
}
// Keep passing the LLM_RESPONSE_COMPLETE event downstream
controller.enqueue(chunk)
} else {
controller.enqueue(chunk)
}
}
})
)
return {
...result,
stream: enhancedStream
}
} else {
logger.debug(`No stream to process or not a ReadableStream.`)
}
}
return result
}

View File

@ -1,148 +0,0 @@
import type OpenAI from '@cherrystudio/openai'
import { toFile } from '@cherrystudio/openai/uploads'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import FileManager from '@renderer/services/FileManager'
import { ChunkType } from '@renderer/types/chunk'
import { findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { defaultTimeout } from '@shared/config/constant'
import type { BaseApiClient } from '../../clients/BaseApiClient'
import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import type { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'ImageGenerationMiddleware'
export const ImageGenerationMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (context: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
const { assistant, messages } = params
const client = context.apiClientInstance as BaseApiClient<OpenAI>
const signal = context._internal?.flowControl?.abortSignal
if (!assistant.model || !isDedicatedImageGenerationModel(assistant.model) || typeof messages === 'string') {
return next(context, params)
}
const stream = new ReadableStream<GenericChunk>({
async start(controller) {
const enqueue = (chunk: GenericChunk) => controller.enqueue(chunk)
try {
if (!assistant.model) {
throw new Error('Assistant model is not defined.')
}
const sdk = await client.getSdkInstance()
const lastUserMessage = messages.findLast((m) => m.role === 'user')
const lastAssistantMessage = messages.findLast((m) => m.role === 'assistant')
if (!lastUserMessage) {
throw new Error('No user message found for image generation.')
}
const prompt = getMainTextContent(lastUserMessage)
let imageFiles: Blob[] = []
// Collect images from user message
const userImageBlocks = findImageBlocks(lastUserMessage)
const userImages = await Promise.all(
userImageBlocks.map(async (block) => {
if (!block.file) return null
const binaryData: Uint8Array = await FileManager.readBinaryImage(block.file)
const mimeType = `${block.file.type}/${block.file.ext.slice(1)}`
return await toFile(new Blob([binaryData]), block.file.origin_name || 'image.png', { type: mimeType })
})
)
imageFiles = imageFiles.concat(userImages.filter(Boolean) as Blob[])
// Collect images from last assistant message
if (lastAssistantMessage) {
const assistantImageBlocks = findImageBlocks(lastAssistantMessage)
const assistantImages = await Promise.all(
assistantImageBlocks.map(async (block) => {
const b64 = block.url?.replace(/^data:image\/\w+;base64,/, '')
if (!b64) return null
const binary = atob(b64)
const bytes = new Uint8Array(binary.length)
for (let i = 0; i < binary.length; i++) bytes[i] = binary.charCodeAt(i)
return await toFile(new Blob([bytes]), 'assistant_image.png', { type: 'image/png' })
})
)
imageFiles = imageFiles.concat(assistantImages.filter(Boolean) as Blob[])
}
enqueue({ type: ChunkType.IMAGE_CREATED })
const startTime = Date.now()
let response: OpenAI.Images.ImagesResponse
const options = { signal, timeout: defaultTimeout }
if (imageFiles.length > 0) {
const model = assistant.model
const provider = context.apiClientInstance.provider
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/dall-e?tabs=gpt-image-1#call-the-image-edit-api
if (model.id.toLowerCase().includes('gpt-image-1-mini') && provider.type === 'azure-openai') {
throw new Error('Azure OpenAI GPT-Image-1-Mini model does not support image editing.')
}
response = await sdk.images.edit(
{
model: assistant.model.id,
image: imageFiles,
prompt: prompt || ''
},
options
)
} else {
response = await sdk.images.generate(
{
model: assistant.model.id,
prompt: prompt || '',
response_format: assistant.model.id.includes('gpt-image-1') ? undefined : 'b64_json'
},
options
)
}
let imageType: 'url' | 'base64' = 'base64'
const imageList =
response.data?.reduce((acc: string[], image) => {
if (image.url) {
acc.push(image.url)
imageType = 'url'
} else if (image.b64_json) {
acc.push(`data:image/png;base64,${image.b64_json}`)
}
return acc
}, []) || []
enqueue({
type: ChunkType.IMAGE_COMPLETE,
image: { type: imageType, images: imageList }
})
const usage = (response as any).usage || { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }
enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage,
metrics: {
completion_tokens: usage.completion_tokens,
time_first_token_millsec: 0,
time_completion_millsec: Date.now() - startTime
}
}
})
} catch (error: any) {
enqueue({ type: ChunkType.ERROR, error })
} finally {
controller.close()
}
}
})
return {
stream,
getText: () => ''
}
}

View File

@ -1,198 +0,0 @@
import { loggerService } from '@logger'
import type { Model } from '@renderer/types'
import type {
TextDeltaChunk,
ThinkingCompleteChunk,
ThinkingDeltaChunk,
ThinkingStartChunk
} from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import { getLowerBaseModelName } from '@renderer/utils'
import type { TagConfig } from '@renderer/utils/tagExtraction'
import { TagExtractor } from '@renderer/utils/tagExtraction'
import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import type { CompletionsContext, CompletionsMiddleware } from '../types'
const logger = loggerService.withContext('ThinkingTagExtractionMiddleware')
export const MIDDLEWARE_NAME = 'ThinkingTagExtractionMiddleware'
// Thinking tag configuration for different models
const reasoningTags: TagConfig[] = [
{ openingTag: '<think>', closingTag: '</think>', separator: '\n' },
{ openingTag: '<thought>', closingTag: '</thought>', separator: '\n' },
{ openingTag: '###Thinking', closingTag: '###Response', separator: '\n' },
{ openingTag: '◁think▷', closingTag: '◁/think▷', separator: '\n' },
{ openingTag: '<thinking>', closingTag: '</thinking>', separator: '\n' },
{ openingTag: '<seed:think>', closingTag: '</seed:think>', separator: '\n' }
]
const getAppropriateTag = (model?: Model): TagConfig => {
const modelId = model?.id ? getLowerBaseModelName(model.id) : undefined
if (modelId?.includes('qwen3')) return reasoningTags[0]
if (modelId?.includes('gemini-2.5')) return reasoningTags[1]
if (modelId?.includes('kimi-vl-a3b-thinking')) return reasoningTags[3]
if (modelId?.includes('seed-oss-36b')) return reasoningTags[5]
// More model-specific tag configurations can be added here
return reasoningTags[0] // default to the <think> tag
}
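// A small usage sketch: the model literals below are hypothetical examples, cast loosely for brevity;
// the tag config returned here drives the TagExtractor created in the middleware below.
const qwenTag = getAppropriateTag({ id: 'qwen3-32b' } as unknown as Model) // -> '<think>' / '</think>'
const geminiTag = getAppropriateTag({ id: 'gemini-2.5-pro' } as unknown as Model) // -> '<thought>' / '</thought>'
const defaultTag = getAppropriateTag(undefined) // -> the default '<think>' configuration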
/**
* Thinking tag extraction middleware.
*
* Extracts reasoning wrapped in tags such as <think>...</think> from the text stream,
* which is how some OpenAI-compatible providers return reasoning content.
*
* Responsibilities:
* 1. Detect model-specific thinking tags in the text
* 2. Emit THINKING_DELTA chunks for the tagged content
* 3. Track the thinking duration
* 4. Emit cleaned TEXT chunks for the remaining text
* 5. Emit THINKING_COMPLETE when the tag closes
*/
export const ThinkingTagExtractionMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (context: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
// Call the downstream middleware
const result = await next(context, params)
// Post-processing: extract thinking tags
if (result.stream) {
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
// Check whether there is a stream to process
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
// Get the thinking tag configuration for the current model
const model = params.assistant?.model
const reasoningTag = getAppropriateTag(model)
// Create the tag extractor
const tagExtractor = new TagExtractor(reasoningTag)
// Thinking processing state
let hasThinkingContent = false
let thinkingStartTime = 0
let accumulatingText = false
let accumulatedThinkingContent = ''
const processedStream = resultFromUpstream.pipeThrough(
new TransformStream<GenericChunk, GenericChunk>({
transform(chunk: GenericChunk, controller) {
logger.silly('chunk', chunk)
if (chunk.type === ChunkType.TEXT_DELTA) {
const textChunk = chunk as TextDeltaChunk
// Process the text with the TagExtractor
const extractionResults = tagExtractor.processText(textChunk.text)
for (const extractionResult of extractionResults) {
if (extractionResult.complete && extractionResult.tagContentExtracted?.trim()) {
// Thinking finished
// logger.silly(
// 'since extractionResult.complete and extractionResult.tagContentExtracted is not empty, THINKING_COMPLETE chunk is generated'
// )
// Thinking is complete, update the state
accumulatingText = false
// Emit a THINKING_COMPLETE event
const thinkingCompleteChunk: ThinkingCompleteChunk = {
type: ChunkType.THINKING_COMPLETE,
text: extractionResult.tagContentExtracted.trim(),
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
}
controller.enqueue(thinkingCompleteChunk)
// Reset the thinking state
hasThinkingContent = false
thinkingStartTime = 0
} else if (extractionResult.content.length > 0) {
// logger.silly(
// 'since extractionResult.content is not empty, try to generate THINKING_START/THINKING_DELTA chunk'
// )
if (extractionResult.isTagContent) {
// Thinking content was extracted, update the state
accumulatingText = false
// Record the start time when thinking content is first received
if (!hasThinkingContent) {
hasThinkingContent = true
thinkingStartTime = Date.now()
controller.enqueue({
type: ChunkType.THINKING_START
} as ThinkingStartChunk)
}
if (extractionResult.content?.trim()) {
accumulatedThinkingContent += extractionResult.content.trim()
const thinkingDeltaChunk: ThinkingDeltaChunk = {
type: ChunkType.THINKING_DELTA,
text: accumulatedThinkingContent,
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
}
controller.enqueue(thinkingDeltaChunk)
}
} else {
// No thinking content, output the text directly
// logger.silly(
// 'since extractionResult.isTagContent is falsy, try to generate TEXT_START/TEXT_DELTA chunk'
// )
// When non-thinking content arrives while not yet accumulating text, emit a TEXT_START chunk and update the state
if (!accumulatingText) {
// logger.silly('since accumulatingText is false, TEXT_START chunk is generated')
controller.enqueue({
type: ChunkType.TEXT_START
})
accumulatingText = true
}
// Forward the cleaned text content
const cleanTextChunk: TextDeltaChunk = {
...textChunk,
text: extractionResult.content
}
controller.enqueue(cleanTextChunk)
}
} else {
// logger.silly('since both condition is false, skip')
}
}
} else if (chunk.type !== ChunkType.TEXT_START) {
// logger.silly('since chunk.type is not TEXT_START and not TEXT_DELTA, pass through')
// logger.silly('since chunk.type is not TEXT_START and not TEXT_DELTA, accumulatingText is set to false')
accumulatingText = false
// Pass through all other chunk types (including THINKING_DELTA, THINKING_COMPLETE, etc.)
controller.enqueue(chunk)
} else {
// Incoming TEXT_START chunks are dropped directly
// logger.silly('since chunk.type is TEXT_START, passed')
}
},
flush(controller) {
// Handle any remaining thinking content
const finalResult = tagExtractor.finalize()
if (finalResult?.tagContentExtracted) {
const thinkingCompleteChunk: ThinkingCompleteChunk = {
type: ChunkType.THINKING_COMPLETE,
text: finalResult.tagContentExtracted,
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
}
controller.enqueue(thinkingCompleteChunk)
}
}
})
)
// Update the response result
return {
...result,
stream: processedStream
}
} else {
logger.warn(`[${MIDDLEWARE_NAME}] No generic chunk stream to process or not a ReadableStream.`)
}
}
return result
}
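For orientation, here is a minimal sketch of the TagExtractor contract as it is used above (field names are taken from this file; the exact result sequence is an assumption about the tagExtraction utility):

const extractor = new TagExtractor({ openingTag: '<think>', closingTag: '</think>', separator: '\n' })
// Feed streamed text deltas; each result separates tag content from plain content.
for (const r of extractor.processText('<think>step 1</think>Hello')) {
  if (r.complete && r.tagContentExtracted) {
    // 'step 1' -> would be emitted as a THINKING_COMPLETE chunk
  } else if (r.isTagContent) {
    // partial thinking content -> THINKING_DELTA
  } else if (r.content) {
    // 'Hello' -> forwarded as a plain TEXT_DELTA
  }
}
extractor.finalize() // flushes any unterminated tag content at the end of the stream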

View File

@ -1,138 +0,0 @@
import { loggerService } from '@logger'
import type { MCPTool } from '@renderer/types'
import type { MCPToolCreatedChunk, TextDeltaChunk } from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import { parseToolUse } from '@renderer/utils/mcp-tools'
import type { TagConfig } from '@renderer/utils/tagExtraction'
import { TagExtractor } from '@renderer/utils/tagExtraction'
import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import type { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'ToolUseExtractionMiddleware'
const logger = loggerService.withContext('ToolUseExtractionMiddleware')
// Tool-use tag configuration
const TOOL_USE_TAG_CONFIG: TagConfig = {
openingTag: '<tool_use>',
closingTag: '</tool_use>',
separator: '\n'
}
/**
 * Tool-use extraction middleware
 *
 * Responsibilities:
 * 1. Detect <tool_use></tool_use> tags in the text stream
 * 2. Parse the tag content into ToolUseResponse objects
 * 3. Emit MCP_TOOL_CREATED chunks for McpToolChunkMiddleware to consume
 * 4. Strip tool_use tag content from the forwarded text
 * 5. Leave the actual tool execution to downstream handling
 *
 * Designed to work together with McpToolChunkMiddleware
 */
export const ToolUseExtractionMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
const mcpTools = params.mcpTools || []
if (!mcpTools || mcpTools.length === 0) return next(ctx, params)
const result = await next(ctx, params)
if (result.stream) {
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
const processedStream = resultFromUpstream.pipeThrough(createToolUseExtractionTransform(ctx, mcpTools))
return {
...result,
stream: processedStream
}
}
return result
}
/**
 * Creates the TransformStream used for tool-use extraction
 */
function createToolUseExtractionTransform(
_ctx: CompletionsContext,
mcpTools: MCPTool[]
): TransformStream<GenericChunk, GenericChunk> {
const toolUseExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
let hasAnyToolUse = false
let toolCounter = 0
return new TransformStream({
async transform(chunk: GenericChunk, controller) {
try {
// Process text content and detect tool-use tags
logger.silly('chunk', chunk)
if (chunk.type === ChunkType.TEXT_DELTA) {
const textChunk = chunk as TextDeltaChunk
// Handle tool_use tags
const toolUseResults = toolUseExtractor.processText(textChunk.text)
for (const result of toolUseResults) {
if (result.complete && result.tagContentExtracted) {
// A complete tool-use block was extracted; parse it and convert it to the SDK ToolCall format
const toolUseResponses = parseToolUse(result.tagContentExtracted, mcpTools, toolCounter)
toolCounter += toolUseResponses.length
if (toolUseResponses.length > 0) {
// Emit an MCP_TOOL_CREATED chunk
const mcpToolCreatedChunk: MCPToolCreatedChunk = {
type: ChunkType.MCP_TOOL_CREATED,
tool_use_responses: toolUseResponses
}
controller.enqueue(mcpToolCreatedChunk)
// Mark that a tool call has occurred
hasAnyToolUse = true
}
} else if (!result.isTagContent && result.content) {
if (!hasAnyToolUse) {
const cleanTextChunk: TextDeltaChunk = {
...textChunk,
text: result.content
}
controller.enqueue(cleanTextChunk)
}
}
// Content inside tool_use tags is not forwarded, to avoid duplicate display
}
return
}
// Forward all other chunks
controller.enqueue(chunk)
} catch (error) {
logger.error('Error processing chunk:', error as Error)
controller.error(error)
}
},
async flush(controller) {
// Check for any unfinished tool_use tag content
const finalToolUseResult = toolUseExtractor.finalize()
if (finalToolUseResult && finalToolUseResult.tagContentExtracted) {
const toolUseResponses = parseToolUse(finalToolUseResult.tagContentExtracted, mcpTools, toolCounter)
if (toolUseResponses.length > 0) {
const mcpToolCreatedChunk: MCPToolCreatedChunk = {
type: ChunkType.MCP_TOOL_CREATED,
tool_use_responses: toolUseResponses
}
controller.enqueue(mcpToolCreatedChunk)
hasAnyToolUse = true
}
}
}
})
}
export default ToolUseExtractionMiddleware
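Both extraction middlewares above follow the same Web Streams pattern: pipe the upstream generic-chunk stream through a TransformStream and hand the processed stream back in the result. A minimal sketch of that pattern (upstream and result are placeholder names):

const processed = upstream.pipeThrough(
  new TransformStream<GenericChunk, GenericChunk>({
    transform(chunk, controller) {
      // inspect, rewrite, or drop chunks here
      controller.enqueue(chunk)
    },
    flush(controller) {
      // emit any buffered trailing content (e.g. an unterminated tag) via controller.enqueue(...)
    }
  })
)
return { ...result, stream: processed }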

View File

@ -1,88 +0,0 @@
import type { CompletionsMiddleware, MethodMiddleware } from './types'
// /**
// * Wraps a provider instance with middlewares.
// */
// export function wrapProviderWithMiddleware(
// apiClientInstance: BaseApiClient,
// middlewareConfig: MiddlewareConfig
// ): BaseApiClient {
// console.log(`[wrapProviderWithMiddleware] Wrapping provider: ${apiClientInstance.provider?.id}`)
// console.log(`[wrapProviderWithMiddleware] Middleware config:`, {
// completions: middlewareConfig.completions?.length || 0,
// methods: Object.keys(middlewareConfig.methods || {}).length
// })
// // Cache for already wrapped methods to avoid re-wrapping on every access.
// const wrappedMethodsCache = new Map<string, (...args: any[]) => Promise<any>>()
// const proxy = new Proxy(apiClientInstance, {
// get(target, propKey, receiver) {
// const methodName = typeof propKey === 'string' ? propKey : undefined
// if (!methodName) {
// return Reflect.get(target, propKey, receiver)
// }
// if (wrappedMethodsCache.has(methodName)) {
// console.log(`[wrapProviderWithMiddleware] Using cached wrapped method: ${methodName}`)
// return wrappedMethodsCache.get(methodName)
// }
// const originalMethod = Reflect.get(target, propKey, receiver)
// // If the property is not a function, return it directly.
// if (typeof originalMethod !== 'function') {
// return originalMethod
// }
// let wrappedMethod: ((...args: any[]) => Promise<any>) | undefined
// // Handle completions method
// if (methodName === 'completions' && middlewareConfig.completions?.length) {
// console.log(
// `[wrapProviderWithMiddleware] Wrapping completions method with ${middlewareConfig.completions.length} middlewares`
// )
// const completionsOriginalMethod = originalMethod as (params: CompletionsParams) => Promise<any>
// wrappedMethod = applyCompletionsMiddlewares(target, completionsOriginalMethod, middlewareConfig.completions)
// }
// // Handle other methods
// else {
// const methodMiddlewares = middlewareConfig.methods?.[methodName]
// if (methodMiddlewares?.length) {
// console.log(
// `[wrapProviderWithMiddleware] Wrapping method ${methodName} with ${methodMiddlewares.length} middlewares`
// )
// const genericOriginalMethod = originalMethod as (...args: any[]) => Promise<any>
// wrappedMethod = applyMethodMiddlewares(target, methodName, genericOriginalMethod, methodMiddlewares)
// }
// }
// if (wrappedMethod) {
// console.log(`[wrapProviderWithMiddleware] Successfully wrapped method: ${methodName}`)
// wrappedMethodsCache.set(methodName, wrappedMethod)
// return wrappedMethod
// }
// // If no middlewares are configured for this method, return the original method bound to the target.
// console.log(`[wrapProviderWithMiddleware] No middlewares for method ${methodName}, returning original`)
// return originalMethod.bind(target)
// }
// })
// return proxy as BaseApiClient
// }
// Export types for external use
export type { CompletionsMiddleware, MethodMiddleware }
// Export MiddlewareBuilder related types and classes
export {
CompletionsMiddlewareBuilder,
createCompletionsBuilder,
createMethodBuilder,
MethodMiddlewareBuilder,
MiddlewareBuilder,
type MiddlewareExecutor,
type NamedMiddleware
} from './builder'

View File

@ -1,149 +0,0 @@
import * as AbortHandlerModule from './common/AbortHandlerMiddleware'
import * as ErrorHandlerModule from './common/ErrorHandlerMiddleware'
import * as FinalChunkConsumerModule from './common/FinalChunkConsumerMiddleware'
import * as LoggingModule from './common/LoggingMiddleware'
import * as McpToolChunkModule from './core/McpToolChunkMiddleware'
import * as RawStreamListenerModule from './core/RawStreamListenerMiddleware'
import * as ResponseTransformModule from './core/ResponseTransformMiddleware'
// import * as SdkCallModule from './core/SdkCallMiddleware'
import * as StreamAdapterModule from './core/StreamAdapterMiddleware'
import * as TextChunkModule from './core/TextChunkMiddleware'
import * as ThinkChunkModule from './core/ThinkChunkMiddleware'
import * as TransformCoreToSdkParamsModule from './core/TransformCoreToSdkParamsMiddleware'
import * as WebSearchModule from './core/WebSearchMiddleware'
import * as ImageGenerationModule from './feat/ImageGenerationMiddleware'
import * as ThinkingTagExtractionModule from './feat/ThinkingTagExtractionMiddleware'
import * as ToolUseExtractionMiddleware from './feat/ToolUseExtractionMiddleware'
/**
 * Middleware registry - a single access point for all middlewares
 * Keys reuse each module's exported MIDDLEWARE_NAME so the linter can verify them
 */
export const MiddlewareRegistry = {
[ErrorHandlerModule.MIDDLEWARE_NAME]: {
name: ErrorHandlerModule.MIDDLEWARE_NAME,
middleware: ErrorHandlerModule.ErrorHandlerMiddleware
},
// Common middlewares
[AbortHandlerModule.MIDDLEWARE_NAME]: {
name: AbortHandlerModule.MIDDLEWARE_NAME,
middleware: AbortHandlerModule.AbortHandlerMiddleware
},
[FinalChunkConsumerModule.MIDDLEWARE_NAME]: {
name: FinalChunkConsumerModule.MIDDLEWARE_NAME,
middleware: FinalChunkConsumerModule.default
},
// Core pipeline middlewares
[TransformCoreToSdkParamsModule.MIDDLEWARE_NAME]: {
name: TransformCoreToSdkParamsModule.MIDDLEWARE_NAME,
middleware: TransformCoreToSdkParamsModule.TransformCoreToSdkParamsMiddleware
},
// [SdkCallModule.MIDDLEWARE_NAME]: {
// name: SdkCallModule.MIDDLEWARE_NAME,
// middleware: SdkCallModule.SdkCallMiddleware
// },
[StreamAdapterModule.MIDDLEWARE_NAME]: {
name: StreamAdapterModule.MIDDLEWARE_NAME,
middleware: StreamAdapterModule.StreamAdapterMiddleware
},
[RawStreamListenerModule.MIDDLEWARE_NAME]: {
name: RawStreamListenerModule.MIDDLEWARE_NAME,
middleware: RawStreamListenerModule.RawStreamListenerMiddleware
},
[ResponseTransformModule.MIDDLEWARE_NAME]: {
name: ResponseTransformModule.MIDDLEWARE_NAME,
middleware: ResponseTransformModule.ResponseTransformMiddleware
},
// Feature middlewares
[ThinkingTagExtractionModule.MIDDLEWARE_NAME]: {
name: ThinkingTagExtractionModule.MIDDLEWARE_NAME,
middleware: ThinkingTagExtractionModule.ThinkingTagExtractionMiddleware
},
[ToolUseExtractionMiddleware.MIDDLEWARE_NAME]: {
name: ToolUseExtractionMiddleware.MIDDLEWARE_NAME,
middleware: ToolUseExtractionMiddleware.ToolUseExtractionMiddleware
},
[ThinkChunkModule.MIDDLEWARE_NAME]: {
name: ThinkChunkModule.MIDDLEWARE_NAME,
middleware: ThinkChunkModule.ThinkChunkMiddleware
},
[McpToolChunkModule.MIDDLEWARE_NAME]: {
name: McpToolChunkModule.MIDDLEWARE_NAME,
middleware: McpToolChunkModule.McpToolChunkMiddleware
},
[WebSearchModule.MIDDLEWARE_NAME]: {
name: WebSearchModule.MIDDLEWARE_NAME,
middleware: WebSearchModule.WebSearchMiddleware
},
[TextChunkModule.MIDDLEWARE_NAME]: {
name: TextChunkModule.MIDDLEWARE_NAME,
middleware: TextChunkModule.TextChunkMiddleware
},
[ImageGenerationModule.MIDDLEWARE_NAME]: {
name: ImageGenerationModule.MIDDLEWARE_NAME,
middleware: ImageGenerationModule.ImageGenerationMiddleware
}
} as const
/**
 * Gets a registered middleware by name
 * @param name - the middleware name
 * @returns the corresponding registry entry
 */
export function getMiddleware(name: string) {
return MiddlewareRegistry[name]
}
/**
 * Gets the names of all registered middlewares
 * @returns the list of registered middleware names
 */
export function getRegisteredMiddlewareNames(): string[] {
return Object.keys(MiddlewareRegistry)
}
/**
 * Default completions middleware chain - NamedMiddleware entries for use with MiddlewareBuilder
 */
export const DefaultCompletionsNamedMiddlewares = [
MiddlewareRegistry[FinalChunkConsumerModule.MIDDLEWARE_NAME], // final chunk consumer
MiddlewareRegistry[ErrorHandlerModule.MIDDLEWARE_NAME], // error handling
MiddlewareRegistry[TransformCoreToSdkParamsModule.MIDDLEWARE_NAME], // parameter transformation
MiddlewareRegistry[AbortHandlerModule.MIDDLEWARE_NAME], // abort handling
MiddlewareRegistry[McpToolChunkModule.MIDDLEWARE_NAME], // tool handling
MiddlewareRegistry[TextChunkModule.MIDDLEWARE_NAME], // text handling
MiddlewareRegistry[WebSearchModule.MIDDLEWARE_NAME], // web search handling
MiddlewareRegistry[ToolUseExtractionMiddleware.MIDDLEWARE_NAME], // tool-use extraction
MiddlewareRegistry[ThinkingTagExtractionModule.MIDDLEWARE_NAME], // thinking tag extraction (provider-specific)
MiddlewareRegistry[ThinkChunkModule.MIDDLEWARE_NAME], // thinking handling (generic SDK)
MiddlewareRegistry[ResponseTransformModule.MIDDLEWARE_NAME], // response transformation
MiddlewareRegistry[StreamAdapterModule.MIDDLEWARE_NAME], // stream adapter
MiddlewareRegistry[RawStreamListenerModule.MIDDLEWARE_NAME] // raw stream listener
]
/**
 * Default method-level middlewares - applied to methods other than completions
 */
export const DefaultMethodMiddlewares = {
translate: [LoggingModule.createGenericLoggingMiddleware()],
summaries: [LoggingModule.createGenericLoggingMiddleware()]
}
/**
 * Re-export middleware modules for convenient external use
 */
export {
AbortHandlerModule,
FinalChunkConsumerModule,
LoggingModule,
McpToolChunkModule,
ResponseTransformModule,
StreamAdapterModule,
TextChunkModule,
ThinkChunkModule,
ThinkingTagExtractionModule,
TransformCoreToSdkParamsModule,
WebSearchModule
}
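A brief usage sketch of the registry helpers above (the lookup string is the MIDDLEWARE_NAME exported by the corresponding module):

const entry = getMiddleware('ThinkingTagExtractionMiddleware')
if (entry) {
  // entry.name is the registered name, entry.middleware is the middleware itself
}
const registered = getRegisteredMiddlewareNames()
const defaultChainOrder = DefaultCompletionsNamedMiddlewares.map((m) => m.name)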

View File

@ -1,84 +0,0 @@
import type { Assistant, MCPTool } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import type { Message } from '@renderer/types/newMessage'
import type { SdkRawChunk, SdkRawOutput } from '@renderer/types/sdk'
import type { ProcessingState } from './types'
// ============================================================================
// Core Request Types - core request structures
// ============================================================================
/**
 * Unified completions request parameters handled by every AI provider
 */
export interface CompletionsParams {
/**
 * The type of call being made:
 * 'chat': regular conversation
 * 'translate': translation
 * 'summary': summarization
 * 'search': search
 * 'generate': text generation
 * 'check': API availability check
 * 'test': testing
 * 'translate-lang-detect': language detection for translation
 */
callType?: 'chat' | 'translate' | 'summary' | 'search' | 'generate' | 'check' | 'test' | 'translate-lang-detect'
// Basic conversation data
messages: Message[] | string // union type, makes it easy to check for empty input
assistant: Assistant // the assistant is the basic unit
// model: Model
onChunk?: (chunk: Chunk) => void
onResponse?: (text: string, isComplete: boolean) => void
// Error handling
onError?: (error: Error) => void
shouldThrow?: boolean
// Tool-related
mcpTools?: MCPTool[]
// Generation parameters
temperature?: number
topP?: number
maxTokens?: number
// Feature flags
streamOutput: boolean
enableWebSearch?: boolean
enableUrlContext?: boolean
enableReasoning?: boolean
enableGenerateImage?: boolean
// Context control
contextCount?: number
topicId?: string // topic ID, used to associate context
// Abort control
abortKey?: string
_internal?: ProcessingState
}
export interface CompletionsResult {
rawOutput?: SdkRawOutput
stream?: ReadableStream<SdkRawChunk> | ReadableStream<Chunk> | AsyncIterable<Chunk>
controller?: AbortController
getText: () => string
}
// ============================================================================
// Generic Chunk Types - generic chunk structures
// ============================================================================
/**
 * Generic chunk type
 * Chunk is the standardized data block format that every AI provider should emit
 */
export type GenericChunk = Chunk
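As a rough illustration of the request shape, a minimal CompletionsParams value could look like the following (the assistant object is assumed to come from the caller; only messages, assistant and streamOutput are required):

const params: CompletionsParams = {
  callType: 'chat',
  messages: 'Hello', // a plain string is accepted as well as Message[]
  assistant, // an Assistant object provided by the caller
  streamOutput: true,
  onChunk: (chunk) => console.log(chunk.type)
}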

View File

@ -1,166 +0,0 @@
import type { MCPToolResponse, Metrics, Usage, WebSearchResponse } from '@renderer/types'
import type { Chunk, ErrorChunk } from '@renderer/types/chunk'
import type {
SdkInstance,
SdkMessageParam,
SdkParams,
SdkRawChunk,
SdkRawOutput,
SdkTool,
SdkToolCall
} from '@renderer/types/sdk'
import type { BaseApiClient } from '../clients'
import type { CompletionsParams, CompletionsResult } from './schemas'
/**
* Symbol to uniquely identify middleware context objects.
*/
export const MIDDLEWARE_CONTEXT_SYMBOL = Symbol.for('AiProviderMiddlewareContext')
/**
* Defines the structure for the onChunk callback function.
*/
export type OnChunkFunction = (chunk: Chunk | ErrorChunk) => void
/**
* Base context that carries information about the current method call.
*/
export interface BaseContext {
[MIDDLEWARE_CONTEXT_SYMBOL]: true
methodName: string
originalArgs: Readonly<any[]>
}
/**
* Processing state shared between middlewares.
*/
export interface ProcessingState<
TParams extends SdkParams = SdkParams,
TMessageParam extends SdkMessageParam = SdkMessageParam,
TToolCall extends SdkToolCall = SdkToolCall
> {
sdkPayload?: TParams
newReqMessages?: TMessageParam[]
observer?: {
usage?: Usage
metrics?: Metrics
}
toolProcessingState?: {
pendingToolCalls?: Array<TToolCall>
executingToolCalls?: Array<{
sdkToolCall: TToolCall
mcpToolResponse: MCPToolResponse
}>
output?: SdkRawOutput | string
isRecursiveCall?: boolean
recursionDepth?: number
}
webSearchState?: {
results?: WebSearchResponse
}
flowControl?: {
abortController?: AbortController
abortSignal?: AbortSignal
cleanup?: () => void
}
enhancedDispatch?: (context: CompletionsContext, params: CompletionsParams) => Promise<CompletionsResult>
customState?: Record<string, any>
}
/**
* Extended context for completions method.
*/
export interface CompletionsContext<
TSdkParams extends SdkParams = SdkParams,
TSdkMessageParam extends SdkMessageParam = SdkMessageParam,
TSdkToolCall extends SdkToolCall = SdkToolCall,
TSdkInstance extends SdkInstance = SdkInstance,
TRawOutput extends SdkRawOutput = SdkRawOutput,
TRawChunk extends SdkRawChunk = SdkRawChunk,
TSdkSpecificTool extends SdkTool = SdkTool
> extends BaseContext {
readonly methodName: 'completions' // the method name is fixed to 'completions'
apiClientInstance: BaseApiClient<
TSdkInstance,
TSdkParams,
TRawOutput,
TRawChunk,
TSdkMessageParam,
TSdkToolCall,
TSdkSpecificTool
>
// --- Mutable internal state for the duration of the middleware chain ---
_internal: ProcessingState<TSdkParams, TSdkMessageParam, TSdkToolCall> // holds all mutable processing state
}
export interface MiddlewareAPI<Ctx extends BaseContext = BaseContext, Args extends any[] = any[]> {
getContext: () => Ctx // Function to get the current context
getOriginalArgs: () => Args // Function to get the original arguments of the method call
}
/**
* Base middleware type.
*/
export type Middleware<TContext extends BaseContext> = (
api: MiddlewareAPI<TContext>
) => (
next: (context: TContext, args: any[]) => Promise<unknown>
) => (context: TContext, args: any[]) => Promise<unknown>
export type MethodMiddleware = Middleware<BaseContext>
/**
* Completions middleware type.
*/
export type CompletionsMiddleware<
TSdkParams extends SdkParams = SdkParams,
TSdkMessageParam extends SdkMessageParam = SdkMessageParam,
TSdkToolCall extends SdkToolCall = SdkToolCall,
TSdkInstance extends SdkInstance = SdkInstance,
TRawOutput extends SdkRawOutput = SdkRawOutput,
TRawChunk extends SdkRawChunk = SdkRawChunk,
TSdkSpecificTool extends SdkTool = SdkTool
> = (
api: MiddlewareAPI<
CompletionsContext<
TSdkParams,
TSdkMessageParam,
TSdkToolCall,
TSdkInstance,
TRawOutput,
TRawChunk,
TSdkSpecificTool
>,
[CompletionsParams]
>
) => (
next: (
context: CompletionsContext<
TSdkParams,
TSdkMessageParam,
TSdkToolCall,
TSdkInstance,
TRawOutput,
TRawChunk,
TSdkSpecificTool
>,
params: CompletionsParams
) => Promise<CompletionsResult>
) => (
context: CompletionsContext<
TSdkParams,
TSdkMessageParam,
TSdkToolCall,
TSdkInstance,
TRawOutput,
TRawChunk,
TSdkSpecificTool
>,
params: CompletionsParams
) => Promise<CompletionsResult>
// Re-export for convenience
export type { Chunk as OnChunkArg } from '@renderer/types/chunk'
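To make the nested signature above concrete, a pass-through middleware that satisfies CompletionsMiddleware is simply:

const PassThroughMiddleware: CompletionsMiddleware = () => (next) => async (context, params) => {
  // nothing before or after; just delegate to the next handler in the chain
  return next(context, params)
}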

View File

@ -1,58 +0,0 @@
import type { ErrorChunk } from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
/**
* Creates an ErrorChunk object with a standardized structure.
* @param error The error object or message.
* @param chunkType The type of chunk, defaults to ChunkType.ERROR.
* @returns An ErrorChunk object.
*/
export function createErrorChunk(error: any, chunkType: ChunkType = ChunkType.ERROR): ErrorChunk {
let errorDetails: Record<string, any> = {}
if (error instanceof Error) {
errorDetails = {
message: error.message,
name: error.name,
stack: error.stack
}
} else if (typeof error === 'string') {
errorDetails = { message: error }
} else if (typeof error === 'object' && error !== null) {
errorDetails = Object.getOwnPropertyNames(error).reduce(
(acc, key) => {
acc[key] = error[key]
return acc
},
{} as Record<string, any>
)
if (!errorDetails.message && error.toString && typeof error.toString === 'function') {
const errMsg = error.toString()
if (errMsg !== '[object Object]') {
errorDetails.message = errMsg
}
}
}
return {
type: chunkType,
error: errorDetails
} as ErrorChunk
}
// Helper to capitalize method names for hook construction
export function capitalize(str: string): string {
if (!str) return ''
return str.charAt(0).toUpperCase() + str.slice(1)
}
/**
 * Checks whether an object implements the AsyncIterable interface
 */
export function isAsyncIterable<T = unknown>(obj: unknown): obj is AsyncIterable<T> {
return (
obj !== null &&
typeof obj === 'object' &&
typeof (obj as Record<symbol, unknown>)[Symbol.asyncIterator] === 'function'
)
}
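A short sketch of the helpers above in use:

const errChunk = createErrorChunk(new Error('request timed out'))
// -> { type: ChunkType.ERROR, error: { message: 'request timed out', name: 'Error', stack: '...' } }

async function* numbers() {
  yield 1
}
isAsyncIterable(numbers()) // true
isAsyncIterable([1, 2, 3]) // false - arrays are only synchronously iterable

capitalize('completions') // 'Completions' - used when building hook names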

View File

@ -29,3 +29,11 @@ export type ProviderConfig<T extends StringKeys<AppProviderSettingsMap> = String
export type { AppProviderId, AppProviderSettingsMap, AppRuntimeConfig } from './merged'
export { appProviderIds, getAllProviderIds, isRegisteredProviderId } from './merged'
/**
* Result of completions operation
* Simple interface with getText method to retrieve the generated text
*/
export type CompletionsResult = {
getText: () => string
}

View File

@ -19,7 +19,7 @@ import { isToolUseModeFunction } from '@renderer/utils/assistant'
import { isAbortError } from '@renderer/utils/error'
import { purifyMarkdownImages } from '@renderer/utils/markdown'
import { isPromptToolUse, isSupportedToolUse } from '@renderer/utils/mcp-tools'
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { containsSupportedVariables, replacePromptVariables } from '@renderer/utils/prompt'
import { NOT_SUPPORT_API_KEY_PROVIDER_TYPES, NOT_SUPPORT_API_KEY_PROVIDERS } from '@renderer/utils/provider'
import { isEmpty, takeRight } from 'lodash'
@ -35,6 +35,7 @@ import {
getQuickModel
} from './AssistantService'
import { ConversationService } from './ConversationService'
import FileManager from './FileManager'
import { injectUserMessageWithKnowledgeSearchPrompt } from './KnowledgeService'
import type { BlockManager } from './messageStreaming'
import type { StreamProcessorCallbacks } from './StreamProcessingService'
@ -167,6 +168,20 @@ export async function fetchChatCompletion({
const AI = new AiProviderNew(assistant.model || getDefaultModel(), providerWithRotatedKey)
const provider = AI.getActualProvider()
// Dedicated image generation models go through the generateImage path
if (isDedicatedImageGenerationModel(assistant.model || getDefaultModel())) {
if (!uiMessages || uiMessages.length === 0) {
throw new Error('uiMessages is required for dedicated image generation models')
}
await fetchImageGeneration({
messages: uiMessages,
assistant,
onChunkReceived,
aiProvider: AI
})
return
}
const mcpTools: MCPTool[] = []
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
@ -225,6 +240,114 @@ export async function fetchChatCompletion({
})
}
/**
 * Collects input images from the user message and, optionally, from the
 * previous assistant message (so generated images can be edited further)
 */
async function collectImagesFromMessages(userMessage: Message, assistantMessage?: Message): Promise<string[]> {
const images: string[] = []
// Collect images from the user message
const userImageBlocks = findImageBlocks(userMessage)
for (const block of userImageBlocks) {
if (block.file) {
const base64 = await FileManager.readBase64File(block.file)
const mimeType = block.file.type || 'image/png'
images.push(`data:${mimeType};base64,${base64}`)
}
}
// Collect images from the assistant message (to continue editing generated images)
if (assistantMessage) {
const assistantImageBlocks = findImageBlocks(assistantMessage)
for (const block of assistantImageBlocks) {
if (block.url) {
images.push(block.url)
}
}
}
return images
}
/**
 * Handles dedicated image generation models
 * (e.g. DALL-E, GPT-Image-1)
 */
async function fetchImageGeneration({
messages,
assistant,
onChunkReceived,
aiProvider
}: {
messages: Message[]
assistant: Assistant
onChunkReceived: (chunk: Chunk) => void
aiProvider: AiProviderNew
}) {
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
onChunkReceived({ type: ChunkType.IMAGE_CREATED })
const startTime = Date.now()
try {
// Extract the prompt and input images
const lastUserMessage = messages.findLast((m) => m.role === 'user')
const lastAssistantMessage = messages.findLast((m) => m.role === 'assistant')
if (!lastUserMessage) {
throw new Error('No user message found for image generation.')
}
const prompt = getMainTextContent(lastUserMessage)
const inputImages = await collectImagesFromMessages(lastUserMessage, lastAssistantMessage)
// Call generateImage or editImage
// Use the default image generation configuration
const imageSize = '1024x1024'
const batchSize = 1
let images: string[]
if (inputImages.length > 0) {
images = await aiProvider.editImage({
model: assistant.model!.id,
prompt: prompt || '',
inputImages,
imageSize
})
} else {
images = await aiProvider.generateImage({
model: assistant.model!.id,
prompt: prompt || '',
imageSize,
batchSize
})
}
// Emit result chunks
const imageType = images[0]?.startsWith('data:') ? 'base64' : 'url'
onChunkReceived({
type: ChunkType.IMAGE_COMPLETE,
image: { type: imageType, images }
})
onChunkReceived({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 },
metrics: {
completion_tokens: 0,
time_first_token_millsec: 0,
time_completion_millsec: Date.now() - startTime
}
}
})
} catch (error) {
onChunkReceived({ type: ChunkType.ERROR, error: error as Error })
throw error
}
}
export async function fetchMessagesSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) {
let prompt = (getStoreSetting('topicNamingPrompt') as string) || i18n.t('prompts.title')
const model = getQuickModel() || assistant?.model || getDefaultModel()

View File

@ -1,7 +1,6 @@
import { loggerService } from '@logger'
import type { Span } from '@opentelemetry/api'
import { ModernAiProvider } from '@renderer/aiCore'
import AiProvider from '@renderer/aiCore/legacy'
import ModernAiProvider from '@renderer/aiCore'
import { getMessageContent } from '@renderer/aiCore/plugins/searchOrchestrationPlugin'
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT, DEFAULT_KNOWLEDGE_THRESHOLD } from '@renderer/config/constant'
import { getEmbeddingMaxContext } from '@renderer/config/embedings'
@ -36,7 +35,7 @@ const logger = loggerService.withContext('RendererKnowledgeService')
export const getKnowledgeBaseParams = (base: KnowledgeBase): KnowledgeBaseParams => {
const rerankProvider = getProviderByModel(base.rerankModel)
const aiProvider = new ModernAiProvider(base.model)
const rerankAiProvider = new AiProvider(rerankProvider)
const rerankAiProvider = new ModernAiProvider(rerankProvider)
// get preprocess provider from store instead of base.preprocessProvider
const preprocessProvider = store

View File

@ -5,7 +5,6 @@ import type { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
import { cleanContext, endContext, getContext, startContext } from '@mcp-trace/trace-web'
import type { Context, Span } from '@opentelemetry/api'
import { context, SpanStatusCode, trace } from '@opentelemetry/api'
import { isAsyncIterable } from '@renderer/aiCore/legacy/middleware/utils'
import { db } from '@renderer/databases'
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
@ -22,6 +21,11 @@ import type { SdkRawChunk } from '@renderer/types/sdk'
const logger = loggerService.withContext('SpanManagerService')
// Type guard for AsyncIterable
function isAsyncIterable<T>(obj: any): obj is AsyncIterable<T> {
return obj != null && typeof obj === 'object' && typeof obj[Symbol.asyncIterator] === 'function'
}
class SpanManagerService {
private spanMap: Map<string, ModelSpanEntity[]> = new Map()

View File

@ -10,7 +10,7 @@ import type OpenAI from '@cherrystudio/openai'
import type { ChatCompletionChunk } from '@cherrystudio/openai/resources'
import type { FunctionCall } from '@google/genai'
import { FinishReason, MediaModality } from '@google/genai'
import AiProvider from '@renderer/aiCore'
import AiProvider from '@renderer/aiCore/legacy'
import type { BaseApiClient, OpenAIAPIClient, ResponseChunkTransformerContext } from '@renderer/aiCore/legacy/clients'
import type { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient'
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'

View File

@ -1,6 +1,6 @@
import type { TokenUsage } from '@mcp-trace/trace-core'
import type { Span } from '@opentelemetry/api'
import type { CompletionsResult } from '@renderer/aiCore/legacy/middleware/schemas'
import type { CompletionsResult } from '@renderer/aiCore/types'
import { endSpan } from '@renderer/services/SpanManagerService'
export class CompletionsResultHandler {