Mirror of https://github.com/CherryHQ/cherry-studio.git (synced 2026-01-14 06:07:23 +08:00)

refactor: add ai-sdk embedFunc and remove legacy

parent 53a2e06120
commit f7ee2fc934
@@ -5,6 +5,7 @@
 import type { ImageModelV3, LanguageModelV3, LanguageModelV3Middleware, ProviderV3 } from '@ai-sdk/provider'
 import type { LanguageModel } from 'ai'
 import {
+  embedMany as _embedMany,
   generateImage as _generateImage,
   generateText as _generateText,
   streamText as _streamText,
@@ -17,7 +18,14 @@ import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
 import type { CoreProviderSettingsMap, StringKeys } from '../providers/types'
 import { ImageGenerationError, ImageModelResolutionError } from './errors'
 import { PluginEngine } from './pluginEngine'
-import type { generateImageParams, generateTextParams, RuntimeConfig, streamTextParams } from './types'
+import type {
+  EmbedManyParams,
+  EmbedManyResult,
+  generateImageParams,
+  generateTextParams,
+  RuntimeConfig,
+  streamTextParams
+} from './types'
 
 export class RuntimeExecutor<
   TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
@@ -166,6 +174,23 @@ export class RuntimeExecutor<
     }
   }
 
+  /**
+   * Embed multiple texts in batch.
+   * AI SDK v6 only provides embedMany; there is no single embed.
+   */
+  async embedMany(params: EmbedManyParams): Promise<EmbedManyResult> {
+    const { model: modelOrId, ...options } = params
+
+    // Resolve the embedding model
+    const embeddingModel =
+      typeof modelOrId === 'string' ? await this.modelResolver.resolveEmbeddingModel(modelOrId) : modelOrId
+
+    return _embedMany({
+      model: embeddingModel,
+      ...options
+    })
+  }
+
   // === Helper methods ===
 
   /**
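Note: since AI SDK v6 exposes only batch embedding, a single-text embed is simply a one-element batch. A minimal usage sketch of the new executor method (the provider id and settings below are illustrative placeholders, not taken from this diff):

  const executor = await createExecutor('openai', { apiKey: process.env.OPENAI_API_KEY! })
  const { embeddings } = await executor.embedMany({
    model: 'text-embedding-3-small', // a string id is resolved via modelResolver.resolveEmbeddingModel
    values: ['hello world'] // a single embed is just a one-element batch
  })
  const vector = embeddings[0] // number[]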
@@ -7,7 +7,7 @@
 export { RuntimeExecutor } from './executor'
 
 // Export types
-export type { RuntimeConfig } from './types'
+export type { EmbedManyParams, EmbedManyResult, RuntimeConfig } from './types'
 
 // === Convenience factory functions ===
 
@@ -84,6 +84,23 @@ export async function generateImage<
   return executor.generateImage(params)
 }
 
+/**
+ * Embed multiple texts in batch, directly from a provider id.
+ * AI SDK v6 only provides embedMany; there is no single embed.
+ */
+export async function embedMany<
+  TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
+  T extends StringKeys<TSettingsMap> = StringKeys<TSettingsMap>
+>(
+  providerId: T,
+  options: TSettingsMap[T],
+  params: Parameters<RuntimeExecutor<TSettingsMap, T>['embedMany']>[0],
+  plugins?: AiPlugin[]
+): Promise<ReturnType<RuntimeExecutor<TSettingsMap, T>['embedMany']>> {
+  const executor = await createExecutor<TSettingsMap, T>(providerId, options, plugins)
+  return executor.embedMany(params)
+}
+
 /**
  * Create an OpenAI Compatible executor
  */
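Note: the standalone embedMany wraps executor creation and the method call in one step. A hedged sketch (the provider id and settings object are placeholders whose exact shape depends on CoreProviderSettingsMap):

  const result = await embedMany(
    'openai',
    { apiKey: process.env.OPENAI_API_KEY! },
    { model: 'text-embedding-3-small', values: ['alpha', 'beta'] }
  )
  console.log(result.embeddings.length) // 2, one vector per input value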
@@ -1,8 +1,8 @@
 /**
  * Runtime layer type definitions
  */
-import type { ImageModelV3, ProviderV3 } from '@ai-sdk/provider'
-import type { generateImage, generateText, streamText } from 'ai'
+import type { EmbeddingModelV3, ImageModelV3, ProviderV3 } from '@ai-sdk/provider'
+import type { embedMany, generateImage, generateText, streamText } from 'ai'
 
 import { type AiPlugin } from '../plugins'
 import type { CoreProviderSettingsMap, StringKeys } from '../providers/types'
@@ -28,3 +28,9 @@ export type generateImageParams = Omit<Parameters<typeof generateImage>[0], 'mod
 }
 export type generateTextParams = Parameters<typeof generateText>[0]
 export type streamTextParams = Parameters<typeof streamText>[0]
+
+// Embedding types (AI SDK v6 only has embedMany, no embed)
+export type EmbedManyParams = Omit<Parameters<typeof embedMany>[0], 'model'> & {
+  model: string | EmbeddingModelV3
+}
+export type EmbedManyResult = Awaited<ReturnType<typeof embedMany>>
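Note: widening model to string | EmbeddingModelV3 lets callers pass either form; both of these type-check (model id and the declared instance are placeholders for illustration):

  // A plain id string, resolved internally by the executor:
  await executor.embedMany({ model: 'text-embedding-3-small', values: ['a'] })

  // An already-constructed EmbeddingModelV3 instance:
  declare const embeddingModel: EmbeddingModelV3
  await executor.embedMany({ model: embeddingModel, values: ['a'] })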
@@ -9,11 +9,15 @@
 export {
   createExecutor,
   createOpenAICompatibleExecutor,
+  embedMany,
   generateImage,
   generateText,
   streamText
 } from './core/runtime'
+
+// ==================== Embedding types ====================
+export type { EmbedManyParams, EmbedManyResult } from './core/runtime'
 
 // ==================== Advanced API ====================
 export { isV2Model, isV3Model } from './core/models'
 
@@ -1,16 +1,12 @@
 /**
  * Cherry Studio AI Core - unified entry point
  *
- * This is the new unified entry point, preserving backward compatibility
- * The legacy AiProvider is the default export to keep existing code compatible
+ * Uses ModernAiProvider as the default export
+ * The legacy provider has been removed
  */
 
-// Export the legacy AiProvider as the default export (backward compatible)
-export { default } from './legacy/index'
+import ModernAiProvider from './index_new'
 
-// Also export the modern AiProvider for new code to use
-export { default as ModernAiProvider } from './index_new'
-
-// Export some commonly used types and utilities
-export * from './legacy/clients/types'
-export * from './legacy/middleware/schemas'
+// Default export and named export
+export default ModernAiProvider
+export { ModernAiProvider }
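Note: for consumers the default import keeps working but now resolves to ModernAiProvider. Assuming the path alias used elsewhere in this codebase, both import forms are now equivalent:

  import AiProvider from '@renderer/aiCore' // now ModernAiProvider
  import { ModernAiProvider } from '@renderer/aiCore'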
@@ -27,12 +27,10 @@ import { buildClaudeCodeSystemModelMessage } from '@shared/anthropic'
 import { gateway } from 'ai'
 
 import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
-import LegacyAiProvider from './legacy/index'
-import type { CompletionsParams, CompletionsResult } from './legacy/middleware/schemas'
 import { buildPlugins } from './plugins/PluginBuilder'
 import { adaptProvider, getActualProvider, providerToAiSdkConfig } from './provider/providerConfig'
 import { ModelListService } from './services/ModelListService'
-import type { AppProviderSettingsMap, ProviderConfig } from './types'
+import type { AppProviderSettingsMap, CompletionsResult, ProviderConfig } from './types'
 import type { AiSdkMiddlewareConfig } from './types/middlewareConfig'
 
 const logger = loggerService.withContext('ModernAiProvider')
@@ -45,7 +43,6 @@ export type ModernAiProviderConfig = AiSdkMiddlewareConfig & {
 }
 
 export default class ModernAiProvider {
-  private legacyProvider: LegacyAiProvider
   private config?: ProviderConfig
   private actualProvider: Provider
   private model?: Model
@@ -101,8 +98,6 @@ export default class ModernAiProvider {
       this.actualProvider = adaptProvider({ provider: modelOrProvider })
       // model is optional; some operations (e.g. fetchModels) do not need a model
     }
-
-    this.legacyProvider = new LegacyAiProvider(this.actualProvider)
   }
 
   /**
@@ -169,28 +164,8 @@
     middlewareConfig: ModernAiProviderConfig,
     providerConfig: ProviderConfig
   ): Promise<CompletionsResult> {
-    // ai-gateway is not an image/generation endpoint, so don't go through the legacy path for it
-    if (middlewareConfig.isImageGenerationEndpoint && this.getActualProvider().id !== SystemProviderIds.gateway) {
-      // Use the legacy implementation for image generation (supports image editing and other advanced features)
-      if (!middlewareConfig.uiMessages) {
-        throw new Error('uiMessages is required for image generation endpoint')
-      }
-
-      const legacyParams: CompletionsParams = {
-        callType: 'chat',
-        messages: middlewareConfig.uiMessages, // use the original UI message format
-        assistant: middlewareConfig.assistant,
-        streamOutput: middlewareConfig.streamOutput ?? true,
-        onChunk: middlewareConfig.onChunk,
-        topicId: middlewareConfig.topicId,
-        mcpTools: middlewareConfig.mcpTools,
-        enableWebSearch: middlewareConfig.enableWebSearch
-      }
-
-      // Call legacy completions, which automatically applies the ImageGenerationMiddleware
-      return await this.legacyProvider.completions(legacyParams)
-    }
 
+    // Dedicated image-generation models are already routed to fetchImageGeneration at the ApiService layer
     // Only plain completions are handled here
     return await this.modernCompletions(modelId, params, middlewareConfig, providerConfig)
   }
@@ -350,8 +325,29 @@
     return await ModelListService.listModels(this.actualProvider)
   }
 
   /**
    * Get the dimensions of an embedding model
+   * Uses AI SDK embedMany with a test input to obtain the dimensions
    */
   public async getEmbeddingDimensions(model: Model): Promise<number> {
-    return this.legacyProvider.getEmbeddingDimensions(model)
+    // Make sure config is defined
+    if (!this.config) {
+      this.config = await Promise.resolve(providerToAiSdkConfig(this.actualProvider, model))
+    }
+
+    const executor = await createExecutor<AppProviderSettingsMap>(
+      this.config.providerId,
+      this.config.providerSettings,
+      []
+    )
+
+    // Use AI SDK embedMany with a test input to obtain the dimensions
+    const result = await executor.embedMany({
+      model: model.id,
+      values: ['test']
+    })
+
+    return result.embeddings[0].length
   }
 
   /**
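Note: probing with a throwaway input is the usual way to learn an embedding model's dimensions when the provider does not publish them; the length of the returned vector is the dimension. The same idea in isolation (a sketch, assuming a resolved AI SDK embedding model is at hand):

  import { embedMany } from 'ai'

  async function probeDimensions(model: Parameters<typeof embedMany>[0]['model']): Promise<number> {
    const { embeddings } = await embedMany({ model, values: ['test'] })
    return embeddings[0].length // e.g. 1536 for OpenAI text-embedding-3-small
  }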
@@ -453,10 +449,42 @@
   }
 
   public getBaseURL(): string {
-    return this.legacyProvider.getBaseURL()
+    return this.actualProvider.apiHost || ''
   }
 
   public getApiKey(): string {
-    return this.legacyProvider.getApiKey()
+    const apiKey = this.actualProvider.apiKey
+    if (!apiKey || apiKey.trim() === '') {
+      return ''
+    }
+
+    const keys = apiKey
+      .split(',')
+      .map((key) => key.trim())
+      .filter(Boolean)
+
+    if (keys.length === 0) {
+      return ''
+    }
+
+    if (keys.length === 1) {
+      return keys[0]
+    }
+
+    // Multi-key rotation
+    const keyName = `provider:${this.actualProvider.id}:last_used_key`
+    const lastUsedKey = window.keyv.get(keyName)
+
+    if (!lastUsedKey) {
+      window.keyv.set(keyName, keys[0])
+      return keys[0]
+    }
+
+    const currentIndex = keys.indexOf(lastUsedKey)
+    const nextIndex = (currentIndex + 1) % keys.length
+    const nextKey = keys[nextIndex]
+    window.keyv.set(keyName, nextKey)
+
+    return nextKey
   }
 }
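Note: getApiKey now owns the multi-key round-robin that previously lived in the legacy client: comma-separated keys advance one position per call, with the cursor persisted in window.keyv so rotation survives across provider instances. The rotation logic in isolation (a sketch, with a plain Map standing in for window.keyv):

  const store = new Map<string, string>()

  function rotateKey(providerId: string, rawApiKey: string): string {
    const keys = rawApiKey.split(',').map((k) => k.trim()).filter(Boolean)
    if (keys.length <= 1) return keys[0] ?? ''

    const cursor = `provider:${providerId}:last_used_key`
    // indexOf yields -1 for an unseen cursor, so (-1 + 1) % n starts at keys[0]
    const next = keys[(keys.indexOf(store.get(cursor) ?? '') + 1) % keys.length]
    store.set(cursor, next)
    return next
  }

  rotateKey('openai', 'key1, key2, key3') // successive calls yield key1, key2, key3, key1, ...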
@@ -1,103 +0,0 @@
import { loggerService } from '@logger'
import type { Provider } from '@renderer/types'
import { isNewApiProvider } from '@renderer/utils/provider'

import { AihubmixAPIClient } from './aihubmix/AihubmixAPIClient'
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
import { AwsBedrockAPIClient } from './aws/AwsBedrockAPIClient'
import type { BaseApiClient } from './BaseApiClient'
import { CherryAiAPIClient } from './cherryai/CherryAiAPIClient'
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
import { VertexAPIClient } from './gemini/VertexAPIClient'
import { NewAPIClient } from './newapi/NewAPIClient'
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
import { OVMSClient } from './ovms/OVMSClient'
import { PPIOAPIClient } from './ppio/PPIOAPIClient'
import { ZhipuAPIClient } from './zhipu/ZhipuAPIClient'

const logger = loggerService.withContext('ApiClientFactory')

/**
 * Factory for creating ApiClient instances based on provider configuration
 */
export class ApiClientFactory {
  /**
   * Create an ApiClient instance for the given provider
   */
  static create(provider: Provider): BaseApiClient {
    logger.debug(`Creating ApiClient for provider:`, {
      id: provider.id,
      type: provider.type
    })

    let instance: BaseApiClient

    // First, check for special provider IDs
    if (provider.id === 'cherryai') {
      instance = new CherryAiAPIClient(provider) as BaseApiClient
      return instance
    }

    if (provider.id === 'aihubmix') {
      logger.debug(`Creating AihubmixAPIClient for provider: ${provider.id}`)
      instance = new AihubmixAPIClient(provider) as BaseApiClient
      return instance
    }

    if (isNewApiProvider(provider)) {
      logger.debug(`Creating NewAPIClient for provider: ${provider.id}`)
      instance = new NewAPIClient(provider) as BaseApiClient
      return instance
    }

    if (provider.id === 'ppio') {
      logger.debug(`Creating PPIOAPIClient for provider: ${provider.id}`)
      instance = new PPIOAPIClient(provider) as BaseApiClient
      return instance
    }

    if (provider.id === 'zhipu') {
      instance = new ZhipuAPIClient(provider) as BaseApiClient
      return instance
    }

    if (provider.id === 'ovms') {
      logger.debug(`Creating OVMSClient for provider: ${provider.id}`)
      instance = new OVMSClient(provider) as BaseApiClient
      return instance
    }

    // Then check the standard provider type
    switch (provider.type) {
      case 'openai':
        instance = new OpenAIAPIClient(provider) as BaseApiClient
        break
      case 'azure-openai':
      case 'openai-response':
        instance = new OpenAIResponseAPIClient(provider) as BaseApiClient
        break
      case 'gemini':
        instance = new GeminiAPIClient(provider) as BaseApiClient
        break
      case 'vertexai':
        logger.debug(`Creating VertexAPIClient for provider: ${provider.id}`)
        instance = new VertexAPIClient(provider) as BaseApiClient
        break
      case 'anthropic':
        instance = new AnthropicAPIClient(provider) as BaseApiClient
        break
      case 'aws-bedrock':
        instance = new AwsBedrockAPIClient(provider) as BaseApiClient
        break
      default:
        logger.debug(`Using default OpenAIApiClient for provider: ${provider.id}`)
        instance = new OpenAIAPIClient(provider) as BaseApiClient
        break
    }

    return instance
  }
}
@@ -1,489 +0,0 @@
import { loggerService } from '@logger'
import {
  getModelSupportedVerbosity,
  isFunctionCallingModel,
  isOpenAIModel,
  isSupportFlexServiceTierModel,
  isSupportTemperatureModel,
  isSupportTopPModel
} from '@renderer/config/models'
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import type { RootState } from '@renderer/store'
import type {
  Assistant,
  GenerateImageParams,
  KnowledgeReference,
  MCPCallToolResponse,
  MCPTool,
  MCPToolResponse,
  MemoryItem,
  Model,
  Provider,
  ToolCallResponse,
  WebSearchProviderResponse,
  WebSearchResponse
} from '@renderer/types'
import {
  FileTypes,
  GroqServiceTiers,
  isGroqServiceTier,
  isOpenAIServiceTier,
  OpenAIServiceTiers,
  SystemProviderIds
} from '@renderer/types'
import type { OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
import type { Message } from '@renderer/types/newMessage'
import type {
  RequestOptions,
  SdkInstance,
  SdkMessageParam,
  SdkModel,
  SdkParams,
  SdkRawChunk,
  SdkRawOutput,
  SdkTool,
  SdkToolCall
} from '@renderer/types/sdk'
import { isJSON, parseJSON } from '@renderer/utils'
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { isSupportServiceTierProvider } from '@renderer/utils/provider'
import { defaultTimeout } from '@shared/config/constant'
import { defaultAppHeaders } from '@shared/utils'
import { isEmpty } from 'lodash'

import type { CompletionsContext } from '../middleware/types'
import type { ApiClient, RequestTransformer, ResponseChunkTransformer } from './types'

const logger = loggerService.withContext('BaseApiClient')

/**
 * Abstract base class for API clients.
 * Provides common functionality and structure for specific client implementations.
 */
export abstract class BaseApiClient<
  TSdkInstance extends SdkInstance = SdkInstance,
  TSdkParams extends SdkParams = SdkParams,
  TRawOutput extends SdkRawOutput = SdkRawOutput,
  TRawChunk extends SdkRawChunk = SdkRawChunk,
  TMessageParam extends SdkMessageParam = SdkMessageParam,
  TToolCall extends SdkToolCall = SdkToolCall,
  TSdkSpecificTool extends SdkTool = SdkTool
> implements ApiClient<TSdkInstance, TSdkParams, TRawOutput, TRawChunk, TMessageParam, TToolCall, TSdkSpecificTool>
{
  public provider: Provider
  protected host: string
  protected sdkInstance?: TSdkInstance

  constructor(provider: Provider) {
    this.provider = provider
    this.host = this.getBaseURL()
  }

  /**
   * Get the current API key with rotation support
   * This getter ensures API keys rotate on each access when multiple keys are configured
   */
  protected get apiKey(): string {
    return this.getApiKey()
  }

  /**
   * Get the client's compatibility type
   * Used to decide whether a client supports a given feature, avoiding the type-narrowing problems of instanceof checks
   * Decorator-pattern clients (e.g. AihubmixAPIClient) should return the type of the client they actually use internally
   */
  // oxlint-disable-next-line @typescript-eslint/no-unused-vars
  public getClientCompatibilityType(_model?: Model): string[] {
    // By default, return the class name
    return [this.constructor.name]
  }

  // // The core completions method - in the middleware architecture this is usually just a placeholder
  // abstract completions(params: CompletionsParams, internal?: ProcessingState): Promise<CompletionsResult>

  /**
   * Core API endpoints
   **/

  abstract createCompletions(payload: TSdkParams, options?: RequestOptions): Promise<TRawOutput>

  abstract generateImage(generateImageParams: GenerateImageParams): Promise<string[]>

  abstract getEmbeddingDimensions(model?: Model): Promise<number>

  abstract listModels(): Promise<SdkModel[]>

  abstract getSdkInstance(): Promise<TSdkInstance> | TSdkInstance

  /**
   * Middleware
   **/

  // Used in CoreRequestToSdkParamsMiddleware
  abstract getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
  // Used in RawSdkChunkToGenericChunkMiddleware
  abstract getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<TRawChunk>

  /**
   * Tool conversion
   **/

  // Optional tool conversion methods - implement if needed by the specific provider
  abstract convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[]

  abstract convertSdkToolCallToMcp(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined

  abstract convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse

  abstract buildSdkMessages(
    currentReqMessages: TMessageParam[],
    output: TRawOutput | string | undefined,
    toolResults: TMessageParam[],
    toolCalls?: TToolCall[]
  ): TMessageParam[]

  abstract estimateMessageTokens(message: TMessageParam): number

  abstract convertMcpToolResponseToSdkMessageParam(
    mcpToolResponse: MCPToolResponse,
    resp: MCPCallToolResponse,
    model: Model
  ): TMessageParam | undefined

  /**
   * Extract the message array from the SDK payload (for type-safe access in middleware)
   * Different providers may use different field names (e.g. messages, history)
   */
  abstract extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[]

  /**
   * Utility functions
   **/

  public getBaseURL(): string {
    return this.provider.apiHost
  }

  public getApiKey() {
    const keys = this.provider.apiKey.split(',').map((key) => key.trim())
    const keyName = `provider:${this.provider.id}:last_used_key`

    if (keys.length === 1) {
      return keys[0]
    }

    const lastUsedKey = window.keyv.get(keyName)
    if (!lastUsedKey) {
      window.keyv.set(keyName, keys[0])
      return keys[0]
    }

    const currentIndex = keys.indexOf(lastUsedKey)
    const nextIndex = (currentIndex + 1) % keys.length
    const nextKey = keys[nextIndex]
    window.keyv.set(keyName, nextKey)

    return nextKey
  }

  public defaultHeaders() {
    return {
      ...defaultAppHeaders(),
      'X-Api-Key': this.apiKey
    }
  }

  public get keepAliveTime() {
    return this.provider.id === 'lmstudio' ? getLMStudioKeepAliveTime() : undefined
  }

  public getTemperature(assistant: Assistant, model: Model): number | undefined {
    if (!isSupportTemperatureModel(model)) {
      return undefined
    }
    const assistantSettings = getAssistantSettings(assistant)
    return assistantSettings?.enableTemperature ? assistantSettings?.temperature : undefined
  }

  public getTopP(assistant: Assistant, model: Model): number | undefined {
    if (!isSupportTopPModel(model)) {
      return undefined
    }
    const assistantSettings = getAssistantSettings(assistant)
    return assistantSettings?.enableTopP ? assistantSettings?.topP : undefined
  }

  // NOTE: this could perhaps be moved to OpenAIBaseClient
  protected getServiceTier(model: Model) {
    const serviceTierSetting = this.provider.serviceTier

    if (!isSupportServiceTierProvider(this.provider) || !isOpenAIModel(model) || !serviceTierSetting) {
      return undefined
    }

    // Handle providers that need to fall back to the default value
    if (this.provider.id === SystemProviderIds.groq) {
      if (
        !isGroqServiceTier(serviceTierSetting) ||
        (serviceTierSetting === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
      ) {
        return undefined
      }
    } else {
      // Other OpenAI providers; assume their service tier settings are exactly the same as OpenAI's
      if (
        !isOpenAIServiceTier(serviceTierSetting) ||
        (serviceTierSetting === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
      ) {
        return undefined
      }
    }

    return serviceTierSetting
  }

  protected getVerbosity(model?: Model): OpenAIVerbosity {
    try {
      const state = window.store?.getState() as RootState
      const verbosity = state?.settings?.openAI?.verbosity

      // If model is provided, check if the verbosity is supported by the model
      if (model) {
        const supportedVerbosity = getModelSupportedVerbosity(model)
        // Use user's verbosity if supported, otherwise use the first supported option
        return supportedVerbosity.includes(verbosity) ? verbosity : supportedVerbosity[0]
      }
      return verbosity
    } catch (error) {
      logger.warn('Failed to get verbosity from state. Fallback to undefined.', error as Error)
      return undefined
    }
  }

  protected getTimeout(model: Model) {
    if (isSupportFlexServiceTierModel(model)) {
      return 15 * 1000 * 60
    }
    return defaultTimeout
  }

  public async getMessageContent(
    message: Message
  ): Promise<{ textContent: string; imageContents: { fileId: string; fileExt: string }[] }> {
    const content = getMainTextContent(message)

    if (isEmpty(content)) {
      return {
        textContent: '',
        imageContents: []
      }
    }

    const webSearchReferences = await this.getWebSearchReferencesFromCache(message)
    const knowledgeReferences = await this.getKnowledgeBaseReferencesFromCache(message)
    const memoryReferences = this.getMemoryReferencesFromCache(message)

    const knowledgeTextReferences = knowledgeReferences.filter((k) => k.metadata?.type !== 'image')
    const knowledgeImageReferences = knowledgeReferences.filter((k) => k.metadata?.type === 'image')

    // Add an offset to avoid ID collisions
    const reindexedKnowledgeReferences = knowledgeTextReferences.map((ref) => ({
      ...ref,
      id: ref.id + webSearchReferences.length // offset knowledge-base reference IDs by the number of web search references
    }))

    const allReferences = [...webSearchReferences, ...reindexedKnowledgeReferences, ...memoryReferences]

    logger.debug(`Found ${allReferences.length} references for ID: ${message.id}`, allReferences)

    const referenceContent = `\`\`\`json\n${JSON.stringify(allReferences, null, 2)}\n\`\`\``
    const imageReferences = knowledgeImageReferences.map((r) => {
      return { fileId: r.metadata?.id, fileExt: r.metadata?.ext }
    })

    return {
      textContent: isEmpty(allReferences)
        ? content
        : REFERENCE_PROMPT.replace('{question}', content).replace('{references}', referenceContent),
      imageContents: isEmpty(knowledgeImageReferences) ? [] : imageReferences
    }
  }

  /**
   * Extract the file content from the message
   * @param message - The message
   * @returns The file content
   */
  protected async extractFileContent(message: Message) {
    const fileBlocks = findFileBlocks(message)
    if (fileBlocks.length > 0) {
      const textFileBlocks = fileBlocks.filter(
        (fb) => fb.file && [FileTypes.TEXT, FileTypes.DOCUMENT].includes(fb.file.type)
      )

      if (textFileBlocks.length > 0) {
        let text = ''
        const divider = '\n\n---\n\n'

        for (const fileBlock of textFileBlocks) {
          const file = fileBlock.file
          const fileContent = (await window.api.file.read(file.id + file.ext, true)).trim()
          const fileNameRow = 'file: ' + file.origin_name + '\n\n'
          text = text + fileNameRow + fileContent + divider
        }

        return text
      }
    }

    return ''
  }

  private getMemoryReferencesFromCache(message: Message) {
    const memories = window.keyv.get(`memory-search-${message.id}`) as MemoryItem[] | undefined
    if (memories) {
      const memoryReferences: KnowledgeReference[] = memories.map((mem, index) => ({
        id: index + 1,
        content: `${mem.memory} -- Created at: ${mem.createdAt}`,
        sourceUrl: '',
        type: 'memory'
      }))
      return memoryReferences
    }
    return []
  }

  private async getWebSearchReferencesFromCache(message: Message) {
    const content = getMainTextContent(message)
    if (isEmpty(content)) {
      return []
    }
    const webSearch: WebSearchResponse = window.keyv.get(`web-search-${message.id}`)

    if (webSearch) {
      window.keyv.remove(`web-search-${message.id}`)
      return (webSearch.results as WebSearchProviderResponse).results.map(
        (result, index) =>
          ({
            id: index + 1,
            content: result.content,
            sourceUrl: result.url,
            type: 'url'
          }) as KnowledgeReference
      )
    }

    return []
  }

  /**
   * Get knowledge base references from the cache
   */
  private async getKnowledgeBaseReferencesFromCache(message: Message): Promise<KnowledgeReference[]> {
    const content = getMainTextContent(message)
    if (isEmpty(content)) {
      return []
    }
    const knowledgeReferences: KnowledgeReference[] = window.keyv.get(`knowledge-search-${message.id}`)

    if (!isEmpty(knowledgeReferences)) {
      window.keyv.remove(`knowledge-search-${message.id}`)
      logger.debug(`Found ${knowledgeReferences.length} knowledge base references in cache for ID: ${message.id}`)
      return knowledgeReferences
    }
    logger.debug(`No knowledge base references found in cache for ID: ${message.id}`)
    return []
  }

  protected getCustomParameters(assistant: Assistant) {
    return (
      assistant?.settings?.customParameters?.reduce((acc, param) => {
        if (!param.name?.trim()) {
          return acc
        }
        // Parse JSON type parameters (Legacy API clients)
        // Related: src/renderer/src/pages/settings/AssistantSettings/AssistantModelSettings.tsx:133-148
        // The UI stores JSON type params as strings, this function parses them before sending to API
        if (param.type === 'json') {
          const value = param.value as string
          if (value === 'undefined') {
            return { ...acc, [param.name]: undefined }
          }
          return { ...acc, [param.name]: isJSON(value) ? parseJSON(value) : value }
        }
        return {
          ...acc,
          [param.name]: param.value
        }
      }, {}) || {}
    )
  }

  public createAbortController(messageId?: string, isAddEventListener?: boolean) {
    const abortController = new AbortController()
    const abortFn = () => abortController.abort()

    if (messageId) {
      addAbortController(messageId, abortFn)
    }

    const cleanup = () => {
      if (messageId) {
        signalPromise.resolve?.(undefined)
        removeAbortController(messageId, abortFn)
      }
    }
    const signalPromise: {
      resolve: (value: unknown) => void
      promise: Promise<unknown>
    } = {
      resolve: () => {},
      promise: Promise.resolve()
    }

    if (isAddEventListener) {
      signalPromise.promise = new Promise((resolve, reject) => {
        signalPromise.resolve = resolve
        if (abortController.signal.aborted) {
          reject(new Error('Request was aborted.'))
        }
        // Capture the abort event; some aborts must be caught here
        abortController.signal.addEventListener('abort', () => {
          reject(new Error('Request was aborted.'))
        })
      })
      return {
        abortController,
        cleanup,
        signalPromise
      }
    }
    return {
      abortController,
      cleanup
    }
  }

  // Setup tools configuration based on provided parameters
  public setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
    tools: TSdkSpecificTool[]
  } {
    const { mcpTools, model, enableToolUse } = params
    let tools: TSdkSpecificTool[] = []

    // If there are no tools, return an empty array
    if (!mcpTools?.length) {
      return { tools }
    }

    // If the model supports function calling and tool usage is enabled
    if (isFunctionCallingModel(model) && enableToolUse) {
      tools = this.convertMcpToolsToSdkTools(mcpTools)
    }

    return { tools }
  }
}
@@ -1,181 +0,0 @@
import type {
  GenerateImageParams,
  MCPCallToolResponse,
  MCPTool,
  MCPToolResponse,
  Model,
  Provider,
  ToolCallResponse
} from '@renderer/types'
import type {
  RequestOptions,
  SdkInstance,
  SdkMessageParam,
  SdkModel,
  SdkParams,
  SdkRawChunk,
  SdkRawOutput,
  SdkTool,
  SdkToolCall
} from '@renderer/types/sdk'

import type { CompletionsContext } from '../middleware/types'
import type { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
import { BaseApiClient } from './BaseApiClient'
import type { GeminiAPIClient } from './gemini/GeminiAPIClient'
import type { OpenAIAPIClient } from './openai/OpenAIApiClient'
import type { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
import type { RequestTransformer, ResponseChunkTransformer } from './types'

/**
 * MixedAPIClient - for providers that may expose multiple interface types
 */
export abstract class MixedBaseAPIClient extends BaseApiClient {
  // Use a union type instead of any to preserve type safety
  protected abstract clients: Map<
    string,
    AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient
  >
  protected abstract defaultClient: OpenAIAPIClient
  protected abstract currentClient: BaseApiClient

  constructor(provider: Provider) {
    super(provider)
  }

  override getBaseURL(): string {
    if (!this.currentClient) {
      return this.provider.apiHost
    }
    return this.currentClient.getBaseURL()
  }

  /**
   * Type guard: ensure the client is an instance of BaseApiClient
   */
  protected isValidClient(client: unknown): client is BaseApiClient {
    return (
      client !== null &&
      client !== undefined &&
      typeof client === 'object' &&
      'createCompletions' in client &&
      'getRequestTransformer' in client &&
      'getResponseChunkTransformer' in client
    )
  }

  /**
   * Get the appropriate client for a model
   */
  protected abstract getClient(model: Model): BaseApiClient

  /**
   * Select the appropriate client for the model and delegate the call
   */
  public getClientForModel(model: Model): BaseApiClient {
    this.currentClient = this.getClient(model)
    return this.currentClient
  }

  /**
   * Override the base-class method to return the internally used client type
   */
  public override getClientCompatibilityType(model?: Model): string[] {
    if (!model) {
      return [this.constructor.name]
    }

    const actualClient = this.getClient(model)
    return actualClient.getClientCompatibilityType(model)
  }

  /**
   * Extract the model ID from the SDK payload
   */
  protected extractModelFromPayload(payload: SdkParams): string | null {
    // Different SDKs may use different field names
    if ('model' in payload && typeof payload.model === 'string') {
      return payload.model
    }
    return null
  }

  // ============ Abstract methods from BaseApiClient ============

  async createCompletions(payload: SdkParams, options?: RequestOptions): Promise<SdkRawOutput> {
    // Try to extract model info from the payload to select a client
    const modelId = this.extractModelFromPayload(payload)
    if (modelId) {
      const modelObj = { id: modelId } as Model
      const targetClient = this.getClient(modelObj)
      return targetClient.createCompletions(payload, options)
    }

    // If the model cannot be extracted from the payload, use the currently selected client
    return this.currentClient.createCompletions(payload, options)
  }

  async generateImage(params: GenerateImageParams): Promise<string[]> {
    return this.currentClient.generateImage(params)
  }

  async getEmbeddingDimensions(model?: Model): Promise<number> {
    const client = model ? this.getClient(model) : this.currentClient
    return client.getEmbeddingDimensions(model)
  }

  async listModels(): Promise<SdkModel[]> {
    // Could aggregate models from all clients, or just use the default client
    return this.defaultClient.listModels()
  }

  async getSdkInstance(): Promise<SdkInstance> {
    return this.currentClient.getSdkInstance()
  }

  getRequestTransformer(): RequestTransformer<SdkParams, SdkMessageParam> {
    return this.currentClient.getRequestTransformer()
  }

  getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<SdkRawChunk> {
    return this.currentClient.getResponseChunkTransformer(ctx)
  }

  convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] {
    return this.currentClient.convertMcpToolsToSdkTools(mcpTools)
  }

  convertSdkToolCallToMcp(toolCall: SdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
    return this.currentClient.convertSdkToolCallToMcp(toolCall, mcpTools)
  }

  convertSdkToolCallToMcpToolResponse(toolCall: SdkToolCall, mcpTool: MCPTool): ToolCallResponse {
    return this.currentClient.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
  }

  buildSdkMessages(
    currentReqMessages: SdkMessageParam[],
    output: SdkRawOutput | string,
    toolResults: SdkMessageParam[],
    toolCalls?: SdkToolCall[]
  ): SdkMessageParam[] {
    return this.currentClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
  }

  estimateMessageTokens(message: SdkMessageParam): number {
    return this.currentClient.estimateMessageTokens(message)
  }

  convertMcpToolResponseToSdkMessageParam(
    mcpToolResponse: MCPToolResponse,
    resp: MCPCallToolResponse,
    model: Model
  ): SdkMessageParam | undefined {
    const client = this.getClient(model)
    return client.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
  }

  extractMessagesFromSdkPayload(sdkPayload: SdkParams): SdkMessageParam[] {
    return this.currentClient.extractMessagesFromSdkPayload(sdkPayload)
  }
}
@@ -1,221 +0,0 @@
import type { Provider } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'

import { AihubmixAPIClient } from '../aihubmix/AihubmixAPIClient'
import { AnthropicAPIClient } from '../anthropic/AnthropicAPIClient'
import { ApiClientFactory } from '../ApiClientFactory'
import { AwsBedrockAPIClient } from '../aws/AwsBedrockAPIClient'
import { GeminiAPIClient } from '../gemini/GeminiAPIClient'
import { VertexAPIClient } from '../gemini/VertexAPIClient'
import { NewAPIClient } from '../newapi/NewAPIClient'
import { OpenAIAPIClient } from '../openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from '../openai/OpenAIResponseAPIClient'
import { PPIOAPIClient } from '../ppio/PPIOAPIClient'

// Helper to create a minimal provider for factory tests
// ApiClientFactory only uses the 'id' and 'type' fields to decide which client to create
// Other fields are passed through to the client constructor but do not affect factory logic
const createTestProvider = (id: string, type: string): Provider => ({
  id,
  type: type as Provider['type'],
  name: '',
  apiKey: '',
  apiHost: '',
  models: []
})

// Mock all client modules
vi.mock('../aihubmix/AihubmixAPIClient', () => ({
  AihubmixAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../anthropic/AnthropicAPIClient', () => ({
  AnthropicAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../anthropic/AnthropicVertexClient', () => ({
  AnthropicVertexClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../gemini/GeminiAPIClient', () => ({
  GeminiAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../gemini/VertexAPIClient', () => ({
  VertexAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../newapi/NewAPIClient', () => ({
  NewAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../openai/OpenAIApiClient', () => ({
  OpenAIAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../openai/OpenAIResponseAPIClient', () => ({
  OpenAIResponseAPIClient: vi.fn().mockImplementation(() => ({
    getClient: vi.fn().mockReturnThis()
  }))
}))
vi.mock('../ppio/PPIOAPIClient', () => ({
  PPIOAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../aws/AwsBedrockAPIClient', () => ({
  AwsBedrockAPIClient: vi.fn().mockImplementation(() => ({}))
}))

vi.mock('@renderer/services/AssistantService.ts', () => ({
  getDefaultAssistant: () => {
    return {
      id: 'default',
      name: 'default',
      emoji: '😀',
      prompt: '',
      topics: [],
      messages: [],
      type: 'assistant',
      regularPhrases: [],
      settings: {}
    }
  }
}))

// Mock the models config to prevent circular dependency issues
vi.mock('@renderer/config/models', () => ({
  findTokenLimit: vi.fn(),
  isReasoningModel: vi.fn(),
  isOpenAILLMModel: vi.fn(),
  SYSTEM_MODELS: {
    silicon: [],
    defaultModel: []
  },
  isOpenAIModel: vi.fn(() => false),
  glm45FlashModel: {},
  qwen38bModel: {}
}))

describe('ApiClientFactory', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })

  describe('create', () => {
    // Test client creation for special IDs
    it('should create AihubmixAPIClient for aihubmix provider', () => {
      const provider = createTestProvider('aihubmix', 'openai')

      const client = ApiClientFactory.create(provider)

      expect(AihubmixAPIClient).toHaveBeenCalledWith(provider)
      expect(client).toBeDefined()
    })

    it('should create NewAPIClient for new-api provider', () => {
      const provider = createTestProvider('new-api', 'openai')

      const client = ApiClientFactory.create(provider)

      expect(NewAPIClient).toHaveBeenCalledWith(provider)
      expect(client).toBeDefined()
    })

    it('should create PPIOAPIClient for ppio provider', () => {
      const provider = createTestProvider('ppio', 'openai')

      const client = ApiClientFactory.create(provider)

      expect(PPIOAPIClient).toHaveBeenCalledWith(provider)
      expect(client).toBeDefined()
    })

    // Test client creation for standard types
    it('should create OpenAIAPIClient for openai type', () => {
      const provider = createTestProvider('custom-openai', 'openai')

      const client = ApiClientFactory.create(provider)

      expect(OpenAIAPIClient).toHaveBeenCalledWith(provider)
      expect(client).toBeDefined()
    })

    it('should create OpenAIResponseAPIClient for azure-openai type', () => {
      const provider = createTestProvider('azure-openai', 'azure-openai')

      const client = ApiClientFactory.create(provider)

      expect(OpenAIResponseAPIClient).toHaveBeenCalledWith(provider)
      expect(client).toBeDefined()
    })

    it('should create OpenAIResponseAPIClient for openai-response type', () => {
      const provider = createTestProvider('response', 'openai-response')

      const client = ApiClientFactory.create(provider)

      expect(OpenAIResponseAPIClient).toHaveBeenCalledWith(provider)
      expect(client).toBeDefined()
    })

    it('should create GeminiAPIClient for gemini type', () => {
      const provider = createTestProvider('gemini', 'gemini')

      const client = ApiClientFactory.create(provider)

      expect(GeminiAPIClient).toHaveBeenCalledWith(provider)
      expect(client).toBeDefined()
    })

    it('should create VertexAPIClient for vertexai type', () => {
      const provider = createTestProvider('vertex', 'vertexai')

      const client = ApiClientFactory.create(provider)

      expect(VertexAPIClient).toHaveBeenCalledWith(provider)
      expect(client).toBeDefined()
    })

    it('should create AnthropicAPIClient for anthropic type', () => {
      const provider = createTestProvider('anthropic', 'anthropic')

      const client = ApiClientFactory.create(provider)

      expect(AnthropicAPIClient).toHaveBeenCalledWith(provider)
      expect(client).toBeDefined()
    })

    it('should create AwsBedrockAPIClient for aws-bedrock type', () => {
      const provider = createTestProvider('aws-bedrock', 'aws-bedrock')

      const client = ApiClientFactory.create(provider)

      expect(AwsBedrockAPIClient).toHaveBeenCalledWith(provider)
      expect(client).toBeDefined()
    })

    // Test the default case
    it('should create OpenAIAPIClient as default for unknown type', () => {
      const provider = createTestProvider('unknown', 'unknown-type')

      const client = ApiClientFactory.create(provider)

      expect(OpenAIAPIClient).toHaveBeenCalledWith(provider)
      expect(client).toBeDefined()
    })

    // Test edge cases
    it('should handle provider with minimal configuration', () => {
      const provider = createTestProvider('minimal', 'openai')

      const client = ApiClientFactory.create(provider)

      expect(OpenAIAPIClient).toHaveBeenCalledWith(provider)
      expect(client).toBeDefined()
    })

    // Test that special IDs take precedence over type
    it('should prioritize special ID over type', () => {
      const provider = createTestProvider('aihubmix', 'anthropic') // even though the type is anthropic

      const client = ApiClientFactory.create(provider)

      // Should create AihubmixAPIClient rather than AnthropicAPIClient
      expect(AihubmixAPIClient).toHaveBeenCalledWith(provider)
      expect(AnthropicAPIClient).not.toHaveBeenCalled()
      expect(client).toBeDefined()
    })
  })
})
@@ -1,38 +0,0 @@
import { describe, expect, it } from 'vitest'

import { normalizeAzureOpenAIEndpoint } from '../openai/azureOpenAIEndpoint'

describe('normalizeAzureOpenAIEndpoint', () => {
  it.each([
    {
      apiHost: 'https://example.openai.azure.com/openai',
      expectedEndpoint: 'https://example.openai.azure.com'
    },
    {
      apiHost: 'https://example.openai.azure.com/openai/',
      expectedEndpoint: 'https://example.openai.azure.com'
    },
    {
      apiHost: 'https://example.openai.azure.com/openai/v1',
      expectedEndpoint: 'https://example.openai.azure.com'
    },
    {
      apiHost: 'https://example.openai.azure.com/openai/v1/',
      expectedEndpoint: 'https://example.openai.azure.com'
    },
    {
      apiHost: 'https://example.openai.azure.com',
      expectedEndpoint: 'https://example.openai.azure.com'
    },
    {
      apiHost: 'https://example.openai.azure.com/',
      expectedEndpoint: 'https://example.openai.azure.com'
    },
    {
      apiHost: 'https://example.openai.azure.com/OPENAI/V1',
      expectedEndpoint: 'https://example.openai.azure.com'
    }
  ])('strips trailing /openai from $apiHost', ({ apiHost, expectedEndpoint }) => {
    expect(normalizeAzureOpenAIEndpoint(apiHost)).toBe(expectedEndpoint)
  })
})
@@ -1,354 +0,0 @@
import { AihubmixAPIClient } from '@renderer/aiCore/legacy/clients/aihubmix/AihubmixAPIClient'
import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient'
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
import { GeminiAPIClient } from '@renderer/aiCore/legacy/clients/gemini/GeminiAPIClient'
import { VertexAPIClient } from '@renderer/aiCore/legacy/clients/gemini/VertexAPIClient'
import { NewAPIClient } from '@renderer/aiCore/legacy/clients/newapi/NewAPIClient'
import { OpenAIAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIResponseAPIClient'
import type { EndpointType, Model, Provider } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'

vi.mock('@renderer/config/models', () => ({
  SYSTEM_MODELS: {
    defaultModel: [
      { id: 'gpt-4', name: 'GPT-4' },
      { id: 'gpt-4', name: 'GPT-4' },
      { id: 'gpt-4', name: 'GPT-4' }
    ],
    zhipu: [],
    silicon: [],
    openai: [],
    anthropic: [],
    gemini: []
  },
  isOpenAIModel: vi.fn().mockReturnValue(true),
  isOpenAILLMModel: vi.fn().mockReturnValue(true),
  isOpenAIChatCompletionOnlyModel: vi.fn().mockReturnValue(false),
  isAnthropicLLMModel: vi.fn().mockReturnValue(false),
  isGeminiLLMModel: vi.fn().mockReturnValue(false),
  isSupportedReasoningEffortOpenAIModel: vi.fn().mockReturnValue(false),
  isVisionModel: vi.fn().mockReturnValue(false),
  isClaudeReasoningModel: vi.fn().mockReturnValue(false),
  isReasoningModel: vi.fn().mockReturnValue(false),
  isWebSearchModel: vi.fn().mockReturnValue(false),
  findTokenLimit: vi.fn().mockReturnValue(4096),
  isFunctionCallingModel: vi.fn().mockReturnValue(false),
  DEFAULT_MAX_TOKENS: 4096,
  qwen38bModel: {},
  glm45FlashModel: {}
}))

vi.mock('@renderer/services/AssistantService', () => ({
  getProviderByModel: vi.fn(),
  getAssistantSettings: vi.fn(),
  getDefaultAssistant: vi.fn().mockReturnValue({
    id: 'default',
    name: 'Default Assistant',
    prompt: '',
    settings: {}
  })
}))

vi.mock('@renderer/services/FileManager', () => ({
  default: class {
    static async read() {
      return 'test content'
    }
    static async write() {
      return true
    }
  }
}))

vi.mock('@renderer/services/TokenService', () => ({
  estimateTextTokens: vi.fn().mockReturnValue(100)
}))

vi.mock('@logger', () => ({
  loggerService: {
    withContext: vi.fn().mockReturnValue({
      debug: vi.fn(),
      info: vi.fn(),
      warn: vi.fn(),
      error: vi.fn(),
      silly: vi.fn()
    })
  }
}))

// Who on earth came up with calling a React Hook in the service layer?????????
// Mock additional services and hooks that might be imported
vi.mock('@renderer/hooks/useVertexAI', () => ({
  getVertexAILocation: vi.fn().mockReturnValue('us-central1'),
  getVertexAIProjectId: vi.fn().mockReturnValue('test-project'),
  getVertexAIServiceAccount: vi.fn().mockReturnValue({
    privateKey: 'test-key',
    clientEmail: 'test@example.com'
  }),
  isVertexAIConfigured: vi.fn().mockReturnValue(true),
  isVertexProvider: vi.fn().mockReturnValue(true)
}))

vi.mock('@renderer/hooks/useSettings', () => ({
  getStoreSetting: vi.fn().mockReturnValue({}),
  useSettings: vi.fn().mockReturnValue([{}, vi.fn()])
}))

vi.mock('@renderer/store/settings', () => ({
  default: {},
  settingsSlice: {
    name: 'settings',
    reducer: vi.fn(),
    actions: {}
  }
}))

vi.mock('@renderer/utils/abortController', () => ({
  addAbortController: vi.fn(),
  removeAbortController: vi.fn()
}))

vi.mock('@anthropic-ai/sdk', () => ({
  default: vi.fn().mockImplementation(() => ({}))
}))

vi.mock('@anthropic-ai/vertex-sdk', () => ({
  default: vi.fn().mockImplementation(() => ({}))
}))

vi.mock('openai', () => ({
  default: vi.fn().mockImplementation(() => ({})),
  AzureOpenAI: vi.fn().mockImplementation(() => ({}))
}))

vi.mock('@google/generative-ai', () => ({
  GoogleGenerativeAI: vi.fn().mockImplementation(() => ({}))
}))

vi.mock('@google-cloud/vertexai', () => ({
  VertexAI: vi.fn().mockImplementation(() => ({}))
}))

// Mock the circular dependency between VertexAPIClient and AnthropicVertexClient
vi.mock('@renderer/aiCore/legacy/clients/anthropic/AnthropicVertexClient', () => {
  const MockAnthropicVertexClient = vi.fn()
  MockAnthropicVertexClient.prototype.getClientCompatibilityType = vi.fn().mockReturnValue(['AnthropicVertexAPIClient'])
  return {
    AnthropicVertexClient: MockAnthropicVertexClient
  }
})

// Helper to create test provider
const createTestProvider = (id: string, type: string): Provider => ({
  id,
  type: type as Provider['type'],
  name: 'Test Provider',
  apiKey: 'test-key',
  apiHost: 'https://api.test.com',
  models: []
})

// Helper to create test model
const createTestModel = (id: string, provider?: string, endpointType?: string): Model => ({
  id,
  name: 'Test Model',
  provider: provider || 'test',
  type: [],
  group: 'test',
  endpoint_type: endpointType as EndpointType
})

describe('Client Compatibility Types', () => {
  let openaiProvider: Provider
  let anthropicProvider: Provider
  let geminiProvider: Provider
  let azureProvider: Provider
  let aihubmixProvider: Provider
  let newApiProvider: Provider
  let vertexProvider: Provider

  beforeEach(() => {
    vi.clearAllMocks()

    openaiProvider = createTestProvider('openai', 'openai')
    anthropicProvider = createTestProvider('anthropic', 'anthropic')
    geminiProvider = createTestProvider('gemini', 'gemini')
    azureProvider = createTestProvider('azure-openai', 'azure-openai')
    aihubmixProvider = createTestProvider('aihubmix', 'openai')
    newApiProvider = createTestProvider('new-api', 'openai')
    vertexProvider = createTestProvider('vertex', 'vertexai')
  })

  describe('Direct API Clients', () => {
    it('should return correct compatibility type for OpenAIAPIClient', () => {
      const client = new OpenAIAPIClient(openaiProvider)
      const compatibilityTypes = client.getClientCompatibilityType()

      expect(compatibilityTypes).toEqual(['OpenAIAPIClient'])
    })

    it('should return correct compatibility type for AnthropicAPIClient', () => {
      const client = new AnthropicAPIClient(anthropicProvider)
      const compatibilityTypes = client.getClientCompatibilityType()

      expect(compatibilityTypes).toEqual(['AnthropicAPIClient'])
    })

    it('should return correct compatibility type for GeminiAPIClient', () => {
      const client = new GeminiAPIClient(geminiProvider)
      const compatibilityTypes = client.getClientCompatibilityType()

      expect(compatibilityTypes).toEqual(['GeminiAPIClient'])
    })
  })

  describe('Decorator Pattern API Clients', () => {
    it('should return OpenAIResponseAPIClient for OpenAIResponseAPIClient without model', () => {
      const client = new OpenAIResponseAPIClient(azureProvider)
      const compatibilityTypes = client.getClientCompatibilityType()

      expect(compatibilityTypes).toEqual(['OpenAIResponseAPIClient'])
    })

    it('should delegate to underlying client for OpenAIResponseAPIClient with model', () => {
      const client = new OpenAIResponseAPIClient(azureProvider)
      const testModel = createTestModel('gpt-4', 'azure-openai')

      // Get the actual client selected for this model
      const actualClient = client.getClient(testModel)
      const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)

      // Should return OpenAIResponseAPIClient for non-chat-completion-only models
      expect(compatibilityTypes).toEqual(['OpenAIAPIClient'])
    })

    it('should return AihubmixAPIClient for AihubmixAPIClient without model', () => {
      const client = new AihubmixAPIClient(aihubmixProvider)
      const compatibilityTypes = client.getClientCompatibilityType()

      expect(compatibilityTypes).toEqual(['AihubmixAPIClient'])
    })

    it('should delegate to underlying client for AihubmixAPIClient with model', () => {
      const client = new AihubmixAPIClient(aihubmixProvider)
      const testModel = createTestModel('gpt-4', 'openai')

      // Get the actual client selected for this model
      const actualClient = client.getClientForModel(testModel)
      const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)

      // Should return the actual underlying client type based on model (OpenAI models use OpenAIResponseAPIClient in Aihubmix)
      expect(compatibilityTypes).toEqual(['OpenAIResponseAPIClient'])
    })

    it('should return NewAPIClient for NewAPIClient without model', () => {
      const client = new NewAPIClient(newApiProvider)
      const compatibilityTypes = client.getClientCompatibilityType()

      expect(compatibilityTypes).toEqual(['NewAPIClient'])
    })

    it('should delegate to underlying client for NewAPIClient with model', () => {
      const client = new NewAPIClient(newApiProvider)
      const testModel = createTestModel('gpt-4', 'openai', 'openai-response')

      // Get the actual client selected for this model
      const actualClient = client.getClientForModel(testModel)
      const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)

      // Should return the actual underlying client type based on model
      expect(compatibilityTypes).toEqual(['OpenAIResponseAPIClient'])
    })

    it('should return VertexAPIClient for VertexAPIClient without model', () => {
      const client = new VertexAPIClient(vertexProvider)
      const compatibilityTypes = client.getClientCompatibilityType()

      expect(compatibilityTypes).toEqual(['VertexAPIClient'])
    })

    it('should delegate to underlying client for VertexAPIClient with model', () => {
      const client = new VertexAPIClient(vertexProvider)
      const testModel = createTestModel('claude-3-5-sonnet', 'vertexai')

      // Get the actual client selected for this model
      const actualClient = client.getClient(testModel)
      const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)

      // Should return the actual underlying client type based on model (Claude models use AnthropicVertexClient)
      expect(compatibilityTypes).toEqual(['AnthropicVertexAPIClient'])
    })
  })

  describe('Middleware Compatibility Logic', () => {
    it('should correctly identify OpenAI compatible clients', () => {
      const openaiClient = new OpenAIAPIClient(openaiProvider)
      const openaiResponseClient = new OpenAIResponseAPIClient(azureProvider)

      const openaiTypes = openaiClient.getClientCompatibilityType()
      const responseTypes = openaiResponseClient.getClientCompatibilityType()

      // Test the logic from completions method line 94
      const isOpenAICompatible = (types: string[]) =>
        types.includes('OpenAIAPIClient') || types.includes('OpenAIResponseAPIClient')

      expect(isOpenAICompatible(openaiTypes)).toBe(true)
      expect(isOpenAICompatible(responseTypes)).toBe(true)
    })

    it('should correctly identify Anthropic or OpenAIResponse compatible clients', () => {
      const anthropicClient = new AnthropicAPIClient(anthropicProvider)
      const openaiResponseClient = new OpenAIResponseAPIClient(azureProvider)
      const openaiClient = new OpenAIAPIClient(openaiProvider)

      const anthropicTypes = anthropicClient.getClientCompatibilityType()
      const responseTypes = openaiResponseClient.getClientCompatibilityType()
      const openaiTypes = openaiClient.getClientCompatibilityType()

      // Test the logic from completions method line 101
      const isAnthropicOrOpenAIResponseCompatible = (types: string[]) =>
        types.includes('AnthropicAPIClient') || types.includes('OpenAIResponseAPIClient')

      expect(isAnthropicOrOpenAIResponseCompatible(anthropicTypes)).toBe(true)
      expect(isAnthropicOrOpenAIResponseCompatible(responseTypes)).toBe(true)
      expect(isAnthropicOrOpenAIResponseCompatible(openaiTypes)).toBe(false)
    })

    it('should handle non-compatible clients correctly', () => {
      const geminiClient = new GeminiAPIClient(geminiProvider)
      const geminiTypes = geminiClient.getClientCompatibilityType()

      // Test that Gemini is not OpenAI compatible
      const isOpenAICompatible = (types: string[]) =>
        types.includes('OpenAIAPIClient') || types.includes('OpenAIResponseAPIClient')

      // Test that Gemini is not Anthropic/OpenAIResponse compatible
      const isAnthropicOrOpenAIResponseCompatible = (types: string[]) =>
|
||||
types.includes('AnthropicAPIClient') || types.includes('OpenAIResponseAPIClient')
|
||||
|
||||
expect(isOpenAICompatible(geminiTypes)).toBe(false)
|
||||
expect(isAnthropicOrOpenAIResponseCompatible(geminiTypes)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Factory Integration', () => {
|
||||
it('should return correct compatibility types for factory-created clients', () => {
|
||||
const testCases = [
|
||||
{ provider: openaiProvider, expectedType: 'OpenAIAPIClient' },
|
||||
{ provider: anthropicProvider, expectedType: 'AnthropicAPIClient' },
|
||||
{ provider: azureProvider, expectedType: 'OpenAIResponseAPIClient' },
|
||||
{ provider: aihubmixProvider, expectedType: 'AihubmixAPIClient' },
|
||||
{ provider: newApiProvider, expectedType: 'NewAPIClient' },
|
||||
{ provider: vertexProvider, expectedType: 'VertexAPIClient' }
|
||||
]
|
||||
|
||||
testCases.forEach(({ provider, expectedType }) => {
|
||||
const client = ApiClientFactory.create(provider)
|
||||
const compatibilityTypes = client.getClientCompatibilityType()
|
||||
|
||||
expect(compatibilityTypes).toContain(expectedType)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
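The middleware compatibility checks exercised above reduce to a small predicate over the string array returned by getClientCompatibilityType(). A minimal sketch of that usage, not part of the diff (the standalone helper name is hypothetical; it mirrors the inline predicate in the tests):

    const isOpenAICompatible = (types: string[]): boolean =>
      types.includes('OpenAIAPIClient') || types.includes('OpenAIResponseAPIClient')

    const client = ApiClientFactory.create(openaiProvider)
    if (isOpenAICompatible(client.getClientCompatibilityType())) {
      // take the OpenAI-compatible middleware path
    }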
@@ -1,96 +0,0 @@
import { isOpenAILLMModel } from '@renderer/config/models'
import type { Model, Provider } from '@renderer/types'

import { AnthropicAPIClient } from '../anthropic/AnthropicAPIClient'
import type { BaseApiClient } from '../BaseApiClient'
import { GeminiAPIClient } from '../gemini/GeminiAPIClient'
import { MixedBaseAPIClient } from '../MixedBaseApiClient'
import { OpenAIAPIClient } from '../openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from '../openai/OpenAIResponseAPIClient'

/**
 * AihubmixAPIClient - automatically selects the appropriate ApiClient based on the model type
 * Implemented with the decorator pattern; model routing happens at the ApiClient layer
 */
export class AihubmixAPIClient extends MixedBaseAPIClient {
  // Use a union type instead of any to preserve type safety
  protected clients: Map<string, AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient> =
    new Map()
  protected defaultClient: OpenAIAPIClient
  protected currentClient: BaseApiClient

  constructor(provider: Provider) {
    super(provider)

    const providerExtraHeaders = {
      ...provider,
      extra_headers: {
        ...provider.extra_headers,
        'APP-Code': 'MLTG2087'
      }
    }

    // Initialize the individual clients - now with type safety
    const claudeClient = new AnthropicAPIClient(providerExtraHeaders)
    const geminiClient = new GeminiAPIClient({ ...providerExtraHeaders, apiHost: 'https://aihubmix.com/gemini' })
    const openaiClient = new OpenAIResponseAPIClient(providerExtraHeaders)
    const defaultClient = new OpenAIAPIClient(providerExtraHeaders)

    this.clients.set('claude', claudeClient)
    this.clients.set('gemini', geminiClient)
    this.clients.set('openai', openaiClient)
    this.clients.set('default', defaultClient)

    // Set the default client
    this.defaultClient = defaultClient
    this.currentClient = this.defaultClient as BaseApiClient
  }

  override getBaseURL(): string {
    if (!this.currentClient) {
      return this.provider.apiHost
    }
    return this.currentClient.getBaseURL()
  }

  /**
   * Get the appropriate client for a given model
   */
  protected getClient(model: Model): BaseApiClient {
    const id = model.id.toLowerCase()

    // Ids starting with 'claude'
    if (id.startsWith('claude')) {
      const client = this.clients.get('claude')
      if (!client || !this.isValidClient(client)) {
        throw new Error('Claude client not properly initialized')
      }
      return client
    }

    // Ids starting with 'gemini' or 'imagen' that do not end with -nothink or -search
    if (
      (id.startsWith('gemini') || id.startsWith('imagen')) &&
      !id.endsWith('-nothink') &&
      !id.endsWith('-search') &&
      !id.includes('embedding')
    ) {
      const client = this.clients.get('gemini')
      if (!client || !this.isValidClient(client)) {
        throw new Error('Gemini client not properly initialized')
      }
      return client
    }

    // OpenAI-family models, excluding gpt-oss
    if (isOpenAILLMModel(model) && !model.id.includes('gpt-oss')) {
      const client = this.clients.get('openai')
      if (!client || !this.isValidClient(client)) {
        throw new Error('OpenAI client not properly initialized')
      }
      return client
    }

    return this.defaultClient as BaseApiClient
  }
}
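A brief sketch of the routing behavior above, not part of the diff (the model ids are illustrative examples only):

    const client = new AihubmixAPIClient(aihubmixProvider)
    // 'claude-3-5-sonnet'   -> AnthropicAPIClient
    // 'gemini-2.0-flash'    -> GeminiAPIClient (but ids containing 'embedding' fall through)
    // OpenAI-family models  -> OpenAIResponseAPIClient (except ids containing 'gpt-oss')
    // anything else         -> OpenAIAPIClient (the default)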
@@ -1,788 +0,0 @@
import type Anthropic from '@anthropic-ai/sdk'
import type {
  Base64ImageSource,
  ImageBlockParam,
  MessageParam,
  TextBlockParam,
  ToolResultBlockParam,
  ToolUseBlock,
  WebSearchTool20250305
} from '@anthropic-ai/sdk/resources'
import type {
  ContentBlock,
  ContentBlockParam,
  MessageCreateParamsBase,
  RedactedThinkingBlockParam,
  ServerToolUseBlockParam,
  ThinkingBlockParam,
  ThinkingConfigParam,
  ToolUnion,
  ToolUseBlockParam,
  WebSearchResultBlock,
  WebSearchToolResultBlockParam,
  WebSearchToolResultError
} from '@anthropic-ai/sdk/resources/messages'
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
import type AnthropicVertex from '@anthropic-ai/vertex-sdk'
import { loggerService } from '@logger'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import FileManager from '@renderer/services/FileManager'
import { estimateTextTokens } from '@renderer/services/TokenService'
import type {
  Assistant,
  GenerateImageParams,
  MCPCallToolResponse,
  MCPTool,
  MCPToolResponse,
  Model,
  Provider,
  ToolCallResponse
} from '@renderer/types'
import { EFFORT_RATIO, FileTypes, WebSearchSource } from '@renderer/types'
import type {
  ErrorChunk,
  LLMWebSearchCompleteChunk,
  LLMWebSearchInProgressChunk,
  MCPToolCreatedChunk,
  TextDeltaChunk,
  TextStartChunk,
  ThinkingDeltaChunk,
  ThinkingStartChunk
} from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import { type Message } from '@renderer/types/newMessage'
import type {
  AnthropicSdkMessageParam,
  AnthropicSdkParams,
  AnthropicSdkRawChunk,
  AnthropicSdkRawOutput
} from '@renderer/types/sdk'
import { addImageFileToContents } from '@renderer/utils/formats'
import {
  anthropicToolUseToMcpTool,
  isSupportedToolUse,
  mcpToolCallResponseToAnthropicMessage,
  mcpToolsToAnthropicTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
import { t } from 'i18next'

import type { GenericChunk } from '../../middleware/schemas'
import { BaseApiClient } from '../BaseApiClient'
import type { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'

const logger = loggerService.withContext('AnthropicAPIClient')

export class AnthropicAPIClient extends BaseApiClient<
  Anthropic | AnthropicVertex,
  AnthropicSdkParams,
  AnthropicSdkRawOutput,
  AnthropicSdkRawChunk,
  AnthropicSdkMessageParam,
  ToolUseBlock,
  ToolUnion
> {
  oauthToken: string | undefined = undefined
  sdkInstance: Anthropic | AnthropicVertex | undefined = undefined

  constructor(provider: Provider) {
    super(provider)
  }

  async getSdkInstance(): Promise<Anthropic | AnthropicVertex> {
    if (this.sdkInstance) {
      return this.sdkInstance
    }
    if (this.provider.authType === 'oauth') {
      this.oauthToken = await window.api.anthropic_oauth.getAccessToken()
    }
    this.sdkInstance = getSdkClient(this.provider, this.oauthToken)
    return this.sdkInstance
  }

  override async createCompletions(
    payload: AnthropicSdkParams,
    options?: Anthropic.RequestOptions
  ): Promise<AnthropicSdkRawOutput> {
    if (this.provider.authType === 'oauth') {
      payload.system = buildClaudeCodeSystemMessage(payload.system)
    }
    const sdk = (await this.getSdkInstance()) as Anthropic
    if (payload.stream) {
      return sdk.messages.stream(payload, options)
    }
    return sdk.messages.create(payload, options)
  }

  // @ts-ignore not provided by the SDK
  // oxlint-disable-next-line @typescript-eslint/no-unused-vars
  override async generateImage(generateImageParams: GenerateImageParams): Promise<string[]> {
    return []
  }

  override async listModels(): Promise<Anthropic.ModelInfo[]> {
    const sdk = (await this.getSdkInstance()) as Anthropic
    // prevent auto appended /v1. It's included in baseUrl.
    const response = await sdk.models.list({ path: '/models' })
    return response.data
  }

  // @ts-ignore not provided by the SDK
  override async getEmbeddingDimensions(): Promise<number> {
    throw new Error("Anthropic SDK doesn't support getEmbeddingDimensions method.")
  }

  override getTemperature(assistant: Assistant, model: Model): number | undefined {
    if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
      return undefined
    }
    return super.getTemperature(assistant, model)
  }

  override getTopP(assistant: Assistant, model: Model): number | undefined {
    if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
      return undefined
    }
    return super.getTopP(assistant, model)
  }

  /**
   * Get the reasoning effort
   * @param assistant - The assistant
   * @param model - The model
   * @returns The reasoning effort
   */
  private getBudgetToken(assistant: Assistant, model: Model): ThinkingConfigParam | undefined {
    if (!isReasoningModel(model)) {
      return undefined
    }
    const { maxTokens } = getAssistantSettings(assistant)

    const reasoningEffort = assistant?.settings?.reasoning_effort

    if (reasoningEffort === undefined) {
      return {
        type: 'disabled'
      }
    }

    const effortRatio = EFFORT_RATIO[reasoningEffort]

    const budgetTokens = Math.max(
      1024,
      Math.floor(
        Math.min(
          (findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
            findTokenLimit(model.id)?.min!,
          (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
        )
      )
    )

    return {
      type: 'enabled',
      budget_tokens: budgetTokens
    }
  }
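  // Illustrative only (not part of the original file): a worked example of the
  // formula above, assuming hypothetical values reasoning_effort 'medium' with
  // EFFORT_RATIO 0.5, findTokenLimit -> { min: 1024, max: 32768 }, maxTokens 8192:
  //   budgetTokens = max(1024, floor(min((32768 - 1024) * 0.5 + 1024, 8192 * 0.5)))
  //                = max(1024, floor(min(16896, 4096))) = 4096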

  private static isValidBase64ImageMediaType(mime: string): mime is Base64ImageSource['media_type'] {
    return ['image/jpeg', 'image/png', 'image/gif', 'image/webp'].includes(mime)
  }

  /**
   * Get the message parameter
   * @param message - The message
   * @returns The message parameter
   */
  public async convertMessageToSdkParam(message: Message): Promise<AnthropicSdkMessageParam> {
    const { textContent, imageContents } = await this.getMessageContent(message)

    const parts: MessageParam['content'] = [
      {
        type: 'text',
        text: textContent
      }
    ]

    if (imageContents.length > 0) {
      for (const imageContent of imageContents) {
        const base64Data = await window.api.file.base64Image(imageContent.fileId + imageContent.fileExt)
        base64Data.mime = base64Data.mime.replace('jpg', 'jpeg')
        if (AnthropicAPIClient.isValidBase64ImageMediaType(base64Data.mime)) {
          parts.push({
            type: 'image',
            source: {
              data: base64Data.base64,
              media_type: base64Data.mime,
              type: 'base64'
            }
          })
        } else {
          logger.warn('Unsupported image type, ignored.', { mime: base64Data.mime })
        }
      }
    }

    // Get and process image blocks
    const imageBlocks = findImageBlocks(message)
    for (const imageBlock of imageBlocks) {
      if (imageBlock.file) {
        // Handle uploaded file
        const file = imageBlock.file
        const base64Data = await window.api.file.base64Image(file.id + file.ext)
        parts.push({
          type: 'image',
          source: {
            data: base64Data.base64,
            media_type: base64Data.mime.replace('jpg', 'jpeg') as any,
            type: 'base64'
          }
        })
      }
    }
    // Get and process file blocks
    const fileBlocks = findFileBlocks(message)
    for (const fileBlock of fileBlocks) {
      const { file } = fileBlock
      if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
        if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) {
          const base64Data = await FileManager.readBase64File(file)
          parts.push({
            type: 'document',
            source: {
              type: 'base64',
              media_type: 'application/pdf',
              data: base64Data
            }
          })
        } else {
          const fileContent = (await window.api.file.read(file.id + file.ext, true)).trim()
          parts.push({
            type: 'text',
            text: file.origin_name + '\n' + fileContent
          })
        }
      }
    }

    return {
      role: message.role === 'system' ? 'user' : message.role,
      content: parts
    }
  }

  public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ToolUnion[] {
    return mcpToolsToAnthropicTools(mcpTools)
  }

  public convertMcpToolResponseToSdkMessageParam(
    mcpToolResponse: MCPToolResponse,
    resp: MCPCallToolResponse,
    model: Model
  ): AnthropicSdkMessageParam | undefined {
    if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
      return mcpToolCallResponseToAnthropicMessage(mcpToolResponse, resp, model)
    } else if ('toolCallId' in mcpToolResponse) {
      return {
        role: 'user',
        content: [
          {
            type: 'tool_result',
            tool_use_id: mcpToolResponse.toolCallId!,
            content: resp.content
              .map((item) => {
                if (item.type === 'text') {
                  return {
                    type: 'text',
                    text: item.text || ''
                  } satisfies TextBlockParam
                }
                if (item.type === 'image') {
                  return {
                    type: 'image',
                    source: {
                      data: item.data || '',
                      media_type: (item.mimeType || 'image/png') as Base64ImageSource['media_type'],
                      type: 'base64'
                    }
                  } satisfies ImageBlockParam
                }
                return
              })
              .filter((n) => typeof n !== 'undefined'),
            is_error: resp.isError
          } satisfies ToolResultBlockParam
        ]
      }
    }
    return
  }

  // Implementing abstract methods from BaseApiClient
  convertSdkToolCallToMcp(toolCall: ToolUseBlock, mcpTools: MCPTool[]): MCPTool | undefined {
    // Based on anthropicToolUseToMcpTool logic in AnthropicProvider
    // This might need adjustment based on how tool calls are specifically handled in the new structure
    const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall)
    return mcpTool
  }

  convertSdkToolCallToMcpToolResponse(toolCall: ToolUseBlock, mcpTool: MCPTool): ToolCallResponse {
    return {
      id: toolCall.id,
      toolCallId: toolCall.id,
      tool: mcpTool,
      arguments: toolCall.input as Record<string, unknown>,
      status: 'pending'
    } as ToolCallResponse
  }

  override buildSdkMessages(
    currentReqMessages: AnthropicSdkMessageParam[],
    output: Anthropic.Message,
    toolResults: AnthropicSdkMessageParam[]
  ): AnthropicSdkMessageParam[] {
    const assistantMessage: AnthropicSdkMessageParam = {
      role: output.role,
      content: convertContentBlocksToParams(output.content)
    }

    const newMessages: AnthropicSdkMessageParam[] = [...currentReqMessages, assistantMessage]
    if (toolResults && toolResults.length > 0) {
      newMessages.push(...toolResults)
    }
    return newMessages
  }

  override estimateMessageTokens(message: AnthropicSdkMessageParam): number {
    if (typeof message.content === 'string') {
      return estimateTextTokens(message.content)
    }
    return message.content
      .map((content) => {
        switch (content.type) {
          case 'text':
            return estimateTextTokens(content.text)
          case 'image':
            if (content.source.type === 'base64') {
              return estimateTextTokens(content.source.data)
            } else {
              return estimateTextTokens(content.source.url)
            }
          case 'tool_use':
            return estimateTextTokens(JSON.stringify(content.input))
          case 'tool_result':
            return estimateTextTokens(JSON.stringify(content.content))
          default:
            return 0
        }
      })
      .reduce((acc, curr) => acc + curr, 0)
  }

  public buildAssistantMessage(message: Anthropic.Message): AnthropicSdkMessageParam {
    const messageParam: AnthropicSdkMessageParam = {
      role: message.role,
      content: convertContentBlocksToParams(message.content)
    }
    return messageParam
  }

  public extractMessagesFromSdkPayload(sdkPayload: AnthropicSdkParams): AnthropicSdkMessageParam[] {
    return sdkPayload.messages || []
  }

  /**
   * Anthropic-specific raw stream listener
   * Handles events specific to the MessageStream object
   */
  attachRawStreamListener(
    rawOutput: AnthropicSdkRawOutput,
    listener: RawStreamListener<AnthropicSdkRawChunk>
  ): AnthropicSdkRawOutput {
    logger.debug(`Attaching stream listener to raw output`)
    // Anthropic-specific event handling
    const anthropicListener = listener as AnthropicStreamListener
    // Check whether this is a MessageStream
    if (rawOutput instanceof MessageStream) {
      logger.debug(`Detected Anthropic MessageStream, attaching specialized listener`)

      if (listener.onStart) {
        listener.onStart()
      }

      if (listener.onChunk) {
        rawOutput.on('streamEvent', (event: AnthropicSdkRawChunk) => {
          listener.onChunk!(event)
        })
      }

      if (anthropicListener.onContentBlock) {
        rawOutput.on('contentBlock', anthropicListener.onContentBlock)
      }

      if (anthropicListener.onMessage) {
        rawOutput.on('finalMessage', anthropicListener.onMessage)
      }

      if (listener.onEnd) {
        rawOutput.on('end', () => {
          listener.onEnd!()
        })
      }

      if (listener.onError) {
        rawOutput.on('error', (error: Error) => {
          listener.onError!(error)
        })
      }

      return rawOutput
    }

    if (anthropicListener.onMessage) {
      anthropicListener.onMessage(rawOutput)
    }

    // For non-MessageStream responses
    return rawOutput
  }
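  // Illustrative only (not part of the original file): a minimal sketch of how a
  // caller could attach a listener to a streaming response; the callback bodies
  // are assumptions, the event wiring follows the method above:
  //
  //   const output = await client.createCompletions({ ...params, stream: true })
  //   client.attachRawStreamListener(output, {
  //     onStart: () => logger.debug('stream started'),
  //     onChunk: (chunk) => logger.debug('stream event', { type: chunk.type }),
  //     onEnd: () => logger.debug('stream ended')
  //   })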

  private async getWebSearchParams(model: Model): Promise<WebSearchTool20250305 | undefined> {
    if (!isWebSearchModel(model)) {
      return undefined
    }
    return {
      type: 'web_search_20250305',
      name: 'web_search',
      max_uses: 5
    } as WebSearchTool20250305
  }

  getRequestTransformer(): RequestTransformer<AnthropicSdkParams, AnthropicSdkMessageParam> {
    return {
      transform: async (
        coreRequest,
        assistant,
        model,
        isRecursiveCall,
        recursiveSdkMessages
      ): Promise<{
        payload: AnthropicSdkParams
        messages: AnthropicSdkMessageParam[]
        metadata: Record<string, any>
      }> => {
        const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest
        // 1. Handle the system message
        const systemPrompt = assistant.prompt

        // 2. Set up tools
        const { tools } = this.setupToolsConfig({
          mcpTools: mcpTools,
          model,
          enableToolUse: isSupportedToolUse(assistant)
        })

        const systemMessage: TextBlockParam | undefined = systemPrompt
          ? { type: 'text', text: systemPrompt }
          : undefined

        // 3. Handle user messages
        const sdkMessages: AnthropicSdkMessageParam[] = []
        if (typeof messages === 'string') {
          sdkMessages.push({ role: 'user', content: messages })
        } else {
          const processedMessages = addImageFileToContents(messages)
          for (const message of processedMessages) {
            sdkMessages.push(await this.convertMessageToSdkParam(message))
          }
        }

        if (enableWebSearch) {
          const webSearchTool = await this.getWebSearchParams(model)
          if (webSearchTool) {
            tools.push(webSearchTool)
          }
        }

        const commonParams: MessageCreateParamsBase = {
          model: model.id,
          messages:
            isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
              ? recursiveSdkMessages
              : sdkMessages,
          max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
          temperature: this.getTemperature(assistant, model),
          top_p: this.getTopP(assistant, model),
          system: systemMessage ? [systemMessage] : undefined,
          thinking: this.getBudgetToken(assistant, model),
          tools: tools.length > 0 ? tools : undefined,
          stream: streamOutput,
          // Only apply custom parameters in chat scenarios, to avoid affecting translation, summarization and other flows
          // Note: user-defined custom parameters should always override the others
          ...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
        }

        const timeout = this.getTimeout(model)
        return { payload: commonParams, messages: sdkMessages, metadata: { timeout } }
      }
    }
  }

  getResponseChunkTransformer(): ResponseChunkTransformer<AnthropicSdkRawChunk> {
    return () => {
      let accumulatedJson = ''
      const toolCalls: Record<number, ToolUseBlock> = {}
      return {
        async transform(rawChunk: AnthropicSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
          if (typeof rawChunk === 'string') {
            try {
              rawChunk = JSON.parse(rawChunk)
            } catch (error) {
              logger.error('invalid chunk', { rawChunk, error })
              throw new Error(t('error.chat.chunk.non_json'))
            }
          }
          switch (rawChunk.type) {
            case 'message': {
              let i = 0
              let hasTextContent = false
              let hasThinkingContent = false

              for (const content of rawChunk.content) {
                switch (content.type) {
                  case 'text': {
                    if (!hasTextContent) {
                      controller.enqueue({
                        type: ChunkType.TEXT_START
                      } as TextStartChunk)
                      hasTextContent = true
                    }
                    controller.enqueue({
                      type: ChunkType.TEXT_DELTA,
                      text: content.text
                    } as TextDeltaChunk)
                    break
                  }
                  case 'tool_use': {
                    toolCalls[i] = content
                    i++
                    break
                  }
                  case 'thinking': {
                    if (!hasThinkingContent) {
                      controller.enqueue({
                        type: ChunkType.THINKING_START
                      } as ThinkingStartChunk)
                      hasThinkingContent = true
                    }
                    controller.enqueue({
                      type: ChunkType.THINKING_DELTA,
                      text: content.thinking
                    } as ThinkingDeltaChunk)
                    break
                  }
                  case 'web_search_tool_result': {
                    controller.enqueue({
                      type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
                      llm_web_search: {
                        results: content.content,
                        source: WebSearchSource.ANTHROPIC
                      }
                    } as LLMWebSearchCompleteChunk)
                    break
                  }
                }
              }
              if (i > 0) {
                controller.enqueue({
                  type: ChunkType.MCP_TOOL_CREATED,
                  tool_calls: Object.values(toolCalls)
                } as MCPToolCreatedChunk)
              }
              controller.enqueue({
                type: ChunkType.LLM_RESPONSE_COMPLETE,
                response: {
                  usage: {
                    prompt_tokens: rawChunk.usage.input_tokens || 0,
                    completion_tokens: rawChunk.usage.output_tokens || 0,
                    total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0)
                  }
                }
              })
              break
            }
            case 'content_block_start': {
              const contentBlock = rawChunk.content_block
              switch (contentBlock.type) {
                case 'server_tool_use': {
                  if (contentBlock.name === 'web_search') {
                    controller.enqueue({
                      type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS
                    } as LLMWebSearchInProgressChunk)
                  }
                  break
                }
                case 'web_search_tool_result': {
                  if (
                    contentBlock.content &&
                    (contentBlock.content as WebSearchToolResultError).type === 'web_search_tool_result_error'
                  ) {
                    controller.enqueue({
                      type: ChunkType.ERROR,
                      error: {
                        code: (contentBlock.content as WebSearchToolResultError).error_code,
                        message: (contentBlock.content as WebSearchToolResultError).error_code
                      }
                    } as ErrorChunk)
                  } else {
                    controller.enqueue({
                      type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
                      llm_web_search: {
                        results: contentBlock.content as Array<WebSearchResultBlock>,
                        source: WebSearchSource.ANTHROPIC
                      }
                    } as LLMWebSearchCompleteChunk)
                  }
                  break
                }
                case 'tool_use': {
                  toolCalls[rawChunk.index] = contentBlock
                  break
                }
                case 'text': {
                  controller.enqueue({
                    type: ChunkType.TEXT_START
                  } as TextStartChunk)
                  break
                }
                case 'thinking':
                case 'redacted_thinking': {
                  controller.enqueue({
                    type: ChunkType.THINKING_START
                  } as ThinkingStartChunk)
                  break
                }
              }
              break
            }
            case 'content_block_delta': {
              const messageDelta = rawChunk.delta
              switch (messageDelta.type) {
                case 'text_delta': {
                  if (messageDelta.text) {
                    controller.enqueue({
                      type: ChunkType.TEXT_DELTA,
                      text: messageDelta.text
                    } as TextDeltaChunk)
                  }
                  break
                }
                case 'thinking_delta': {
                  if (messageDelta.thinking) {
                    controller.enqueue({
                      type: ChunkType.THINKING_DELTA,
                      text: messageDelta.thinking
                    } as ThinkingDeltaChunk)
                  }
                  break
                }
                case 'input_json_delta': {
                  if (messageDelta.partial_json) {
                    accumulatedJson += messageDelta.partial_json
                  }
                  break
                }
              }
              break
            }
            case 'content_block_stop': {
              const toolCall = toolCalls[rawChunk.index]
              if (toolCall) {
                try {
                  toolCall.input = accumulatedJson ? JSON.parse(accumulatedJson) : {}
                  logger.debug(`Tool call id: ${toolCall.id}, accumulated json: ${accumulatedJson}`)
                  controller.enqueue({
                    type: ChunkType.MCP_TOOL_CREATED,
                    tool_calls: [toolCall]
                  } as MCPToolCreatedChunk)
                } catch (error) {
                  logger.error('Error parsing tool call input:', error as Error)
                }
              }
              break
            }
            case 'message_delta': {
              controller.enqueue({
                type: ChunkType.LLM_RESPONSE_COMPLETE,
                response: {
                  usage: {
                    prompt_tokens: rawChunk.usage.input_tokens || 0,
                    completion_tokens: rawChunk.usage.output_tokens || 0,
                    total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0)
                  }
                }
              })
            }
          }
        }
      }
    }
  }
}

/**
 * Convert a ContentBlock array into a ContentBlockParam array
 * Strips server-generated extra fields, keeping only what the API needs
 */
function convertContentBlocksToParams(contentBlocks: ContentBlock[]): ContentBlockParam[] {
  return contentBlocks.map((block): ContentBlockParam => {
    switch (block.type) {
      case 'text':
        // TextBlock -> TextBlockParam, dropping server fields such as citations
        return {
          type: 'text',
          text: block.text
        } satisfies TextBlockParam
      case 'tool_use':
        // ToolUseBlock -> ToolUseBlockParam
        return {
          type: 'tool_use',
          id: block.id,
          name: block.name,
          input: block.input
        } satisfies ToolUseBlockParam
      case 'thinking':
        // ThinkingBlock -> ThinkingBlockParam
        return {
          type: 'thinking',
          thinking: block.thinking,
          signature: block.signature
        } satisfies ThinkingBlockParam
      case 'redacted_thinking':
        // RedactedThinkingBlock -> RedactedThinkingBlockParam
        return {
          type: 'redacted_thinking',
          data: block.data
        } satisfies RedactedThinkingBlockParam
      case 'server_tool_use':
        // ServerToolUseBlock -> ServerToolUseBlockParam
        return {
          type: 'server_tool_use',
          id: block.id,
          name: block.name,
          input: block.input
        } satisfies ServerToolUseBlockParam
      case 'web_search_tool_result':
        // WebSearchToolResultBlock -> WebSearchToolResultBlockParam
        return {
          type: 'web_search_tool_result',
          tool_use_id: block.tool_use_id,
          content: block.content
        } satisfies WebSearchToolResultBlockParam
      default:
        return block as ContentBlockParam
    }
  })
}
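A short sketch (not part of the diff) of why convertContentBlocksToParams matters: echoing an assistant Anthropic.Message back into the next request must not carry server-only fields, so buildSdkMessages runs every block through the conversion before re-sending:

    const assistantParam: MessageParam = {
      role: output.role, // output is an Anthropic.Message from the previous turn
      content: convertContentBlocksToParams(output.content) // server-only fields stripped
    }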
@@ -1,104 +0,0 @@
import type Anthropic from '@anthropic-ai/sdk'
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
import { loggerService } from '@logger'
import { getVertexAILocation, getVertexAIProjectId, getVertexAIServiceAccount } from '@renderer/hooks/useVertexAI'
import type { Provider } from '@renderer/types'
import { isEmpty } from 'lodash'

import { AnthropicAPIClient } from './AnthropicAPIClient'

const logger = loggerService.withContext('AnthropicVertexClient')

export class AnthropicVertexClient extends AnthropicAPIClient {
  sdkInstance: AnthropicVertex | undefined = undefined
  private authHeaders?: Record<string, string>
  private authHeadersExpiry?: number

  constructor(provider: Provider) {
    super(provider)
  }

  private formatApiHost(host: string): string {
    const forceUseOriginalHost = () => {
      return host.endsWith('/')
    }

    if (!host) {
      return host
    }

    return forceUseOriginalHost() ? host : `${host}/v1/`
  }

  override getBaseURL() {
    return this.formatApiHost(this.provider.apiHost)
  }

  override async getSdkInstance(): Promise<AnthropicVertex> {
    if (this.sdkInstance) {
      return this.sdkInstance
    }

    const serviceAccount = getVertexAIServiceAccount()
    const projectId = getVertexAIProjectId()
    const location = getVertexAILocation()

    if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId || !location) {
      throw new Error('Vertex AI settings are not configured')
    }

    const authHeaders = await this.getServiceAccountAuthHeaders()

    this.sdkInstance = new AnthropicVertex({
      projectId: projectId,
      region: location,
      dangerouslyAllowBrowser: true,
      defaultHeaders: authHeaders,
      baseURL: isEmpty(this.getBaseURL()) ? undefined : this.getBaseURL()
    })

    return this.sdkInstance
  }

  override async listModels(): Promise<Anthropic.ModelInfo[]> {
    throw new Error('Vertex AI does not support listModels method.')
  }

  /**
   * Get auth headers; when a service account is configured, they are fetched from the main process
   */
  private async getServiceAccountAuthHeaders(): Promise<Record<string, string> | undefined> {
    const serviceAccount = getVertexAIServiceAccount()
    const projectId = getVertexAIProjectId()

    // Check whether a service account is configured
    if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId) {
      return undefined
    }

    // Check for already-valid cached auth headers (treated as expired 5 minutes early)
    const now = Date.now()
    if (this.authHeaders && this.authHeadersExpiry && this.authHeadersExpiry - now > 5 * 60 * 1000) {
      return this.authHeaders
    }

    try {
      // Fetch auth headers from the main process
      this.authHeaders = await window.api.vertexAI.getAuthHeaders({
        projectId,
        serviceAccount: {
          privateKey: serviceAccount.privateKey,
          clientEmail: serviceAccount.clientEmail
        }
      })

      // Set the expiry time (auth headers are typically valid for 1 hour)
      this.authHeadersExpiry = now + 60 * 60 * 1000

      return this.authHeaders
    } catch (error: any) {
      logger.error('Failed to get auth headers:', error)
      throw new Error(`Service Account authentication failed: ${error.message}`)
    }
  }
}
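The header caching above (reuse until five minutes before expiry, assume one hour of validity) is a general pattern; a standalone sketch with all names hypothetical:

    let cached: { headers: Record<string, string>; expiry: number } | undefined

    async function getCachedAuthHeaders(fetchFresh: () => Promise<Record<string, string>>) {
      const now = Date.now()
      if (cached && cached.expiry - now > 5 * 60 * 1000) {
        return cached.headers // still comfortably valid
      }
      const headers = await fetchFresh()
      cached = { headers, expiry: now + 60 * 60 * 1000 } // assumed 1h validity
      return headers
    }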
File diff suppressed because it is too large
@@ -1,51 +0,0 @@
import type OpenAI from '@cherrystudio/openai'
import type { Provider } from '@renderer/types'
import type { OpenAISdkParams, OpenAISdkRawOutput } from '@renderer/types/sdk'

import { OpenAIAPIClient } from '../openai/OpenAIApiClient'

export class CherryAiAPIClient extends OpenAIAPIClient {
  constructor(provider: Provider) {
    super(provider)
  }

  override async createCompletions(
    payload: OpenAISdkParams,
    options?: OpenAI.RequestOptions
  ): Promise<OpenAISdkRawOutput> {
    const sdk = await this.getSdkInstance()
    options = options || {}
    options.headers = options.headers || {}

    const signature = await window.api.cherryai.generateSignature({
      method: 'POST',
      path: '/chat/completions',
      query: '',
      body: payload
    })

    options.headers = {
      ...options.headers,
      ...signature
    }

    // @ts-ignore - SDK params may include extra fields
    return await sdk.chat.completions.create(payload, options)
  }

  override getClientCompatibilityType(): string[] {
    return ['CherryAiAPIClient']
  }

  public async listModels(): Promise<OpenAI.Models.Model[]> {
    const models = ['glm-4.5-flash', 'Qwen/Qwen3-8B']

    const created = Date.now()
    return models.map((id) => ({
      id,
      owned_by: 'cherryai',
      object: 'model' as const,
      created
    }))
  }
}
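Condensed from the createCompletions override above, the signing flow is: build a signature over the method, path, query, and body via the preload bridge, then merge it into the request headers:

    const signature = await window.api.cherryai.generateSignature({
      method: 'POST',
      path: '/chat/completions',
      query: '',
      body: payload
    })
    options.headers = { ...options.headers, ...signature } // signed headers win on conflict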
@@ -1,841 +0,0 @@
import type {
  Content,
  File,
  FunctionCall,
  GenerateContentConfig,
  GenerateImagesConfig,
  Model as GeminiModel,
  Part,
  SafetySetting,
  SendMessageParameters,
  ThinkingConfig,
  Tool
} from '@google/genai'
import { createPartFromUri, GoogleGenAI, HarmBlockThreshold, HarmCategory, Modality } from '@google/genai'
import { loggerService } from '@logger'
import { nanoid } from '@reduxjs/toolkit'
import {
  findTokenLimit,
  GEMINI_FLASH_MODEL_REGEX,
  isGemmaModel,
  isSupportedThinkingTokenGeminiModel,
  isVisionModel
} from '@renderer/config/models'
import { estimateTextTokens } from '@renderer/services/TokenService'
import type {
  Assistant,
  FileMetadata,
  FileUploadResponse,
  GenerateImageParams,
  MCPCallToolResponse,
  MCPTool,
  MCPToolResponse,
  Model,
  Provider,
  ToolCallResponse
} from '@renderer/types'
import { EFFORT_RATIO, FileTypes, WebSearchSource } from '@renderer/types'
import type { LLMWebSearchCompleteChunk, TextStartChunk, ThinkingStartChunk } from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import type { Message } from '@renderer/types/newMessage'
import type {
  GeminiOptions,
  GeminiSdkMessageParam,
  GeminiSdkParams,
  GeminiSdkRawChunk,
  GeminiSdkRawOutput,
  GeminiSdkToolCall
} from '@renderer/types/sdk'
import { isToolUseModeFunction } from '@renderer/utils/assistant'
import {
  geminiFunctionCallToMcpTool,
  isSupportedToolUse,
  mcpToolCallResponseToGeminiMessage,
  mcpToolsToGeminiTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { defaultTimeout, MB } from '@shared/config/constant'
import { getTrailingApiVersion, withoutTrailingApiVersion } from '@shared/utils'
import { t } from 'i18next'

import type { GenericChunk } from '../../middleware/schemas'
import { BaseApiClient } from '../BaseApiClient'
import type { RequestTransformer, ResponseChunkTransformer } from '../types'

const logger = loggerService.withContext('GeminiAPIClient')

export class GeminiAPIClient extends BaseApiClient<
  GoogleGenAI,
  GeminiSdkParams,
  GeminiSdkRawOutput,
  GeminiSdkRawChunk,
  GeminiSdkMessageParam,
  GeminiSdkToolCall,
  Tool
> {
  constructor(provider: Provider) {
    super(provider)
  }

  override async createCompletions(payload: GeminiSdkParams, options?: GeminiOptions): Promise<GeminiSdkRawOutput> {
    const sdk = await this.getSdkInstance()
    const { model, history, ...rest } = payload
    const realPayload: Omit<GeminiSdkParams, 'model'> = {
      ...rest,
      config: {
        ...rest.config,
        abortSignal: options?.signal,
        httpOptions: {
          ...rest.config?.httpOptions,
          timeout: options?.timeout
        }
      }
    } satisfies SendMessageParameters

    const streamOutput = options?.streamOutput

    const chat = sdk.chats.create({
      model: model,
      history: history
    })

    if (streamOutput) {
      const stream = chat.sendMessageStream(realPayload)
      return stream
    } else {
      const response = await chat.sendMessage(realPayload)
      return response
    }
  }

  override async generateImage(generateImageParams: GenerateImageParams): Promise<string[]> {
    const sdk = await this.getSdkInstance()
    try {
      const { model, prompt, imageSize, batchSize, signal } = generateImageParams
      const config: GenerateImagesConfig = {
        numberOfImages: batchSize,
        aspectRatio: imageSize,
        abortSignal: signal,
        httpOptions: {
          timeout: defaultTimeout
        }
      }
      const response = await sdk.models.generateImages({
        model: model,
        prompt,
        config
      })

      if (!response.generatedImages || response.generatedImages.length === 0) {
        return []
      }

      const images = response.generatedImages
        .filter((image) => image.image?.imageBytes)
        .map((image) => {
          const dataPrefix = `data:${image.image?.mimeType || 'image/png'};base64,`
          return dataPrefix + image.image?.imageBytes
        })
      return images
    } catch (error) {
      logger.error('[generateImage] error:', error as Error)
      throw error
    }
  }

  override async getEmbeddingDimensions(model: Model): Promise<number> {
    const sdk = await this.getSdkInstance()

    const data = await sdk.models.embedContent({
      model: model.id,
      contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
    })
    return data.embeddings?.[0]?.values?.length || 0
  }

  override async listModels(): Promise<GeminiModel[]> {
    const sdk = await this.getSdkInstance()
    const response = await sdk.models.list()
    const models: GeminiModel[] = []
    for await (const model of response) {
      models.push(model)
    }
    return models
  }

  override getBaseURL(): string {
    return withoutTrailingApiVersion(super.getBaseURL())
  }

  override async getSdkInstance() {
    if (this.sdkInstance) {
      return this.sdkInstance
    }

    const apiVersion = this.getApiVersion()

    this.sdkInstance = new GoogleGenAI({
      vertexai: false,
      apiKey: this.apiKey,
      apiVersion,
      httpOptions: {
        baseUrl: this.getBaseURL(),
        apiVersion,
        headers: {
          ...this.provider.extra_headers
        }
      }
    })

    return this.sdkInstance
  }

  protected getApiVersion(): string {
    if (this.provider.isVertex) {
      return 'v1'
    }

    // Extract trailing API version from the URL
    const trailingVersion = getTrailingApiVersion(this.provider.apiHost || '')
    if (trailingVersion) {
      return trailingVersion
    }

    return ''
  }

  /**
   * Handle a PDF file
   * @param file - The file
   * @returns The part
   */
  private async handlePdfFile(file: FileMetadata): Promise<Part> {
    const smallFileSize = 20 * MB
    const isSmallFile = file.size < smallFileSize

    if (isSmallFile) {
      const { data, mimeType } = await this.base64File(file)
      return {
        inlineData: {
          data,
          mimeType
        }
      }
    }

    // Retrieve file from Gemini uploaded files
    const fileMetadata: FileUploadResponse = await window.api.fileService.retrieve(this.provider, file.id)

    if (fileMetadata.status === 'success') {
      const remoteFile = fileMetadata.originalFile?.file as File
      return createPartFromUri(remoteFile.uri!, remoteFile.mimeType!)
    }

    // If file is not found, upload it to Gemini
    const result = await window.api.fileService.upload(this.provider, file)
    const remoteFile = result.originalFile
    if (!remoteFile) {
      throw new Error('File upload failed, please try again')
    }
    if (remoteFile.type === 'gemini') {
      const file = remoteFile.file
      if (!file.uri) {
        throw new Error('File URI is required but not found')
      }
      if (!file.mimeType) {
        throw new Error('File MIME type is required but not found')
      }
      return createPartFromUri(file.uri, file.mimeType)
    } else {
      throw new Error('Unsupported file type for Gemini API')
    }
  }

  /**
   * Get the message contents
   * @param message - The message
   * @returns The message contents
   */
  private async convertMessageToSdkParam(message: Message): Promise<Content> {
    const role = message.role === 'user' ? 'user' : 'model'
    const { textContent, imageContents } = await this.getMessageContent(message)
    const parts: Part[] = [{ text: textContent }]

    if (imageContents.length > 0) {
      for (const imageContent of imageContents) {
        const image = await window.api.file.base64Image(imageContent.fileId + imageContent.fileExt)
        parts.push({
          inlineData: {
            data: image.base64,
            mimeType: image.mime
          } satisfies Part['inlineData']
        })
      }
    }

    // Add any generated images from previous responses
    const imageBlocks = findImageBlocks(message)
    for (const imageBlock of imageBlocks) {
      if (
        imageBlock.metadata?.generateImageResponse?.images &&
        imageBlock.metadata.generateImageResponse.images.length > 0
      ) {
        for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
          if (imageUrl && imageUrl.startsWith('data:')) {
            // Extract base64 data and mime type from the data URL
            const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
            if (matches && matches.length === 3) {
              const mimeType = matches[1]
              const base64Data = matches[2]
              parts.push({
                inlineData: {
                  data: base64Data,
                  mimeType: mimeType
                } satisfies Part['inlineData']
              })
            }
          }
        }
      }
      const file = imageBlock.file
      if (file) {
        const base64Data = await window.api.file.base64Image(file.id + file.ext)
        parts.push({
          inlineData: {
            data: base64Data.base64,
            mimeType: base64Data.mime
          } satisfies Part['inlineData']
        })
      }
    }

    const fileBlocks = findFileBlocks(message)
    for (const fileBlock of fileBlocks) {
      const file = fileBlock.file
      if (file.type === FileTypes.IMAGE) {
        const base64Data = await window.api.file.base64Image(file.id + file.ext)
        parts.push({
          inlineData: {
            data: base64Data.base64,
            mimeType: base64Data.mime
          } satisfies Part['inlineData']
        })
      }

      if (file.ext === '.pdf') {
        parts.push(await this.handlePdfFile(file))
        continue
      }
      if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
        const fileContent = (await window.api.file.read(file.id + file.ext, true)).trim()
        parts.push({
          text: file.origin_name + '\n' + fileContent
        })
      }
    }

    return {
      role,
      parts: parts
    }
  }

  // @ts-ignore unused
  private async getImageFileContents(message: Message): Promise<Content> {
    const role = message.role === 'user' ? 'user' : 'model'
    const content = getMainTextContent(message)
    const parts: Part[] = [{ text: content }]
    const imageBlocks = findImageBlocks(message)
    for (const imageBlock of imageBlocks) {
      if (
        imageBlock.metadata?.generateImageResponse?.images &&
        imageBlock.metadata.generateImageResponse.images.length > 0
      ) {
        for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
          if (imageUrl && imageUrl.startsWith('data:')) {
            // Extract base64 data and mime type from the data URL
            const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
            if (matches && matches.length === 3) {
              const mimeType = matches[1]
              const base64Data = matches[2]
              parts.push({
                inlineData: {
                  data: base64Data,
                  mimeType: mimeType
                } satisfies Part['inlineData']
              })
            }
          }
        }
      }
      const file = imageBlock.file
      if (file) {
        const base64Data = await window.api.file.base64Image(file.id + file.ext)
        parts.push({
          inlineData: {
            data: base64Data.base64,
            mimeType: base64Data.mime
          } satisfies Part['inlineData']
        })
      }
    }
    return {
      role,
      parts: parts
    }
  }

  /**
   * Get the safety settings
   * @returns The safety settings
   */
  private getSafetySettings(): SafetySetting[] {
    const safetyThreshold = HarmBlockThreshold.OFF

    return [
      {
        category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        threshold: safetyThreshold
      },
      {
        category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        threshold: safetyThreshold
      },
      {
        category: HarmCategory.HARM_CATEGORY_HARASSMENT,
        threshold: safetyThreshold
      },
      {
        category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        threshold: safetyThreshold
      },
      {
        category: HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
        threshold: HarmBlockThreshold.BLOCK_NONE
      }
    ]
  }

  /**
   * Get the reasoning effort for the assistant
   * @param assistant - The assistant
   * @param model - The model
   * @returns The reasoning effort
   */
  private getBudgetToken(assistant: Assistant, model: Model) {
    if (isSupportedThinkingTokenGeminiModel(model)) {
      const reasoningEffort = assistant?.settings?.reasoning_effort

      // If thinking_budget is undefined, thinking is disabled
      if (reasoningEffort === undefined) {
        return GEMINI_FLASH_MODEL_REGEX.test(model.id)
          ? {
              thinkingConfig: {
                thinkingBudget: 0
              }
            }
          : {}
      }

      if (reasoningEffort === 'auto') {
        return {
          thinkingConfig: {
            includeThoughts: true,
            thinkingBudget: -1
          }
        }
      }
      const effortRatio = EFFORT_RATIO[reasoningEffort]
      const { min, max } = findTokenLimit(model.id) || { min: 0, max: 0 }
      // Compute budgetTokens, making sure it does not fall below min
      const budget = Math.floor((max - min) * effortRatio + min)

      return {
        thinkingConfig: {
          ...(budget > 0 ? { thinkingBudget: budget } : {}),
          includeThoughts: true
        } satisfies ThinkingConfig
      }
    }

    return {}
  }
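  // Illustrative only (not part of the original file): with hypothetical values
  // reasoning_effort 'high' -> EFFORT_RATIO 0.8 and findTokenLimit -> { min: 0, max: 24576 },
  // budget = floor((24576 - 0) * 0.8 + 0) = 19660, so the method returns
  // { thinkingConfig: { thinkingBudget: 19660, includeThoughts: true } }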
|
||||
private getGenerateImageParameter(): Partial<GenerateContentConfig> {
|
||||
return {
|
||||
systemInstruction: undefined,
|
||||
responseModalities: [Modality.TEXT, Modality.IMAGE]
|
||||
}
|
||||
}
|
||||
|
||||
getRequestTransformer(): RequestTransformer<GeminiSdkParams, GeminiSdkMessageParam> {
|
||||
return {
|
||||
transform: async (
|
||||
coreRequest,
|
||||
assistant,
|
||||
model,
|
||||
isRecursiveCall,
|
||||
recursiveSdkMessages
|
||||
): Promise<{
|
||||
payload: GeminiSdkParams
|
||||
messages: GeminiSdkMessageParam[]
|
||||
metadata: Record<string, any>
|
||||
}> => {
|
||||
const { messages, mcpTools, maxTokens, enableWebSearch, enableUrlContext, enableGenerateImage } = coreRequest
|
||||
// 1. 处理系统消息
|
||||
const systemInstruction = assistant.prompt
|
||||
|
||||
// 2. 设置工具
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools,
|
||||
model,
|
||||
enableToolUse: isSupportedToolUse(assistant)
|
||||
})
|
||||
|
||||
let messageContents: Content = { role: 'user', parts: [] } // Initialize messageContents
|
||||
const history: Content[] = []
|
||||
// 3. 处理用户消息
|
||||
if (typeof messages === 'string') {
|
||||
messageContents = {
|
||||
role: 'user',
|
||||
parts: [{ text: messages }]
|
||||
}
|
||||
} else {
|
||||
const userLastMessage = messages.pop()
|
||||
if (userLastMessage) {
|
||||
messageContents = await this.convertMessageToSdkParam(userLastMessage)
|
||||
for (const message of messages) {
|
||||
history.push(await this.convertMessageToSdkParam(message))
|
||||
}
|
||||
messages.push(userLastMessage)
|
||||
}
|
||||
}
|
||||
|
||||
if (tools.length === 0 || !isToolUseModeFunction(assistant)) {
|
||||
if (enableWebSearch) {
|
||||
tools.push({
|
||||
googleSearch: {}
|
||||
})
|
||||
}
|
||||
|
||||
if (enableUrlContext) {
|
||||
tools.push({
|
||||
urlContext: {}
|
||||
})
|
||||
}
|
||||
} else if (enableWebSearch || enableUrlContext) {
|
||||
logger.warn('Native tools cannot be used with function calling for now.')
|
||||
}
|
||||
|
||||
if (isGemmaModel(model) && assistant.prompt) {
|
||||
const isFirstMessage = history.length === 0
|
||||
if (isFirstMessage && messageContents) {
|
||||
const userMessageText =
|
||||
messageContents.parts && messageContents.parts.length > 0 ? (messageContents.parts[0].text ?? '') : ''
|
||||
const systemMessage = [
|
||||
{
|
||||
text:
|
||||
'<start_of_turn>user\n' +
|
||||
systemInstruction +
|
||||
'<end_of_turn>\n' +
|
||||
'<start_of_turn>user\n' +
|
||||
userMessageText +
|
||||
'<end_of_turn>'
|
||||
}
|
||||
] satisfies Part[]
|
||||
if (messageContents && messageContents.parts) {
|
||||
messageContents.parts[0] = systemMessage[0]
|
||||
}
|
||||
}
|
||||
}

        const newHistory =
          isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
            ? recursiveSdkMessages.slice(0, recursiveSdkMessages.length - 1)
            : history

        const newMessageContents =
          isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
            ? recursiveSdkMessages[recursiveSdkMessages.length - 1]
            : messageContents

        const generateContentConfig: GenerateContentConfig = {
          safetySettings: this.getSafetySettings(),
          systemInstruction: isGemmaModel(model) ? undefined : systemInstruction,
          temperature: this.getTemperature(assistant, model),
          topP: this.getTopP(assistant, model),
          maxOutputTokens: maxTokens,
          tools: tools,
          ...(enableGenerateImage ? this.getGenerateImageParameter() : {}),
          ...this.getBudgetToken(assistant, model),
          // Only apply custom parameters in chat scenarios, to avoid affecting translation, summarization and other flows
          // Note: user-defined parameters should always override the others
          ...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
        }

        const param: GeminiSdkParams = {
          model: model.id,
          config: generateContentConfig,
          history: newHistory,
          message: newMessageContents.parts!
        }

        return {
          payload: param,
          messages: [messageContents],
          metadata: {}
        }
      }
    }
  }

  getResponseChunkTransformer(): ResponseChunkTransformer<GeminiSdkRawChunk> {
    const toolCalls: FunctionCall[] = []
    let isFirstTextChunk = true
    let isFirstThinkingChunk = true
    return () => ({
      async transform(chunk: GeminiSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
        logger.silly('chunk', chunk)
        if (typeof chunk === 'string') {
          try {
            chunk = JSON.parse(chunk)
          } catch (error) {
            logger.error('invalid chunk', { chunk, error })
            throw new Error(t('error.chat.chunk.non_json'))
          }
        }
        if (chunk.candidates && chunk.candidates.length > 0) {
          for (const candidate of chunk.candidates) {
            if (candidate.content) {
              candidate.content.parts?.forEach((part) => {
                const text = part.text || ''
                if (part.thought) {
                  if (isFirstThinkingChunk) {
                    controller.enqueue({
                      type: ChunkType.THINKING_START
                    } satisfies ThinkingStartChunk)
                    isFirstThinkingChunk = false
                  }
                  controller.enqueue({
                    type: ChunkType.THINKING_DELTA,
                    text: text
                  })
                } else if (part.text) {
                  if (isFirstTextChunk) {
                    controller.enqueue({
                      type: ChunkType.TEXT_START
                    } satisfies TextStartChunk)
                    isFirstTextChunk = false
                  }
                  controller.enqueue({
                    type: ChunkType.TEXT_DELTA,
                    text: text
                  })
                } else if (part.inlineData) {
                  controller.enqueue({
                    type: ChunkType.IMAGE_COMPLETE,
                    image: {
                      type: 'base64',
                      images: [
                        part.inlineData?.data?.startsWith('data:')
                          ? part.inlineData?.data
                          : `data:${part.inlineData?.mimeType || 'image/png'};base64,${part.inlineData?.data}`
                      ]
                    }
                  })
                } else if (part.functionCall) {
                  toolCalls.push(part.functionCall)
                }
              })
            }

            if (candidate.finishReason) {
              if (candidate.groundingMetadata) {
                controller.enqueue({
                  type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
                  llm_web_search: {
                    results: candidate.groundingMetadata,
                    source: WebSearchSource.GEMINI
                  }
                } satisfies LLMWebSearchCompleteChunk)
              }
              if (toolCalls.length > 0) {
                controller.enqueue({
                  type: ChunkType.MCP_TOOL_CREATED,
                  tool_calls: [...toolCalls]
                })
                toolCalls.length = 0
              }
              controller.enqueue({
                type: ChunkType.LLM_RESPONSE_COMPLETE,
                response: {
                  usage: {
                    prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
                    completion_tokens:
                      (chunk.usageMetadata?.totalTokenCount || 0) - (chunk.usageMetadata?.promptTokenCount || 0),
                    total_tokens: chunk.usageMetadata?.totalTokenCount || 0
                  }
                }
              })
            }
          }
        }

        if (toolCalls.length > 0) {
          controller.enqueue({
            type: ChunkType.MCP_TOOL_CREATED,
            tool_calls: toolCalls
          })
        }
      }
    })
  }

  public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): Tool[] {
    return mcpToolsToGeminiTools(mcpTools)
  }

  public convertSdkToolCallToMcp(toolCall: GeminiSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
    return geminiFunctionCallToMcpTool(mcpTools, toolCall)
  }

  public convertSdkToolCallToMcpToolResponse(toolCall: GeminiSdkToolCall, mcpTool: MCPTool): ToolCallResponse {
    const parsedArgs = (() => {
      try {
        return typeof toolCall.args === 'string' ? JSON.parse(toolCall.args) : toolCall.args
      } catch {
        return toolCall.args
      }
    })()

    return {
      id: toolCall.id || nanoid(),
      toolCallId: toolCall.id,
      tool: mcpTool,
      arguments: parsedArgs,
      status: 'pending'
    } satisfies ToolCallResponse
  }

  public convertMcpToolResponseToSdkMessageParam(
    mcpToolResponse: MCPToolResponse,
    resp: MCPCallToolResponse,
    model: Model
  ): GeminiSdkMessageParam | undefined {
    if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
      return mcpToolCallResponseToGeminiMessage(mcpToolResponse, resp, isVisionModel(model))
    } else if ('toolCallId' in mcpToolResponse) {
      return {
        role: 'user',
        parts: [
          {
            functionResponse: {
              id: mcpToolResponse.toolCallId,
              name: mcpToolResponse.tool.id,
              response: {
                output: !resp.isError ? resp.content : undefined,
                error: resp.isError ? resp.content : undefined
              }
            }
          }
        ]
      } satisfies Content
    }
    return
  }

  public buildSdkMessages(
    currentReqMessages: Content[],
    output: string,
    toolResults: Content[],
    toolCalls: FunctionCall[]
  ): Content[] {
    const parts: Part[] = []
    const modelParts: Part[] = []
    if (output) {
      modelParts.push({
        text: output
      })
    }

    toolCalls.forEach((toolCall) => {
      modelParts.push({
        functionCall: toolCall
      })
    })

    parts.push(
      ...toolResults
        .map((ts) => ts.parts)
        .flat()
        .filter((p) => p !== undefined)
    )

    const userMessage: Content = {
      role: 'user',
      parts: []
    }

    if (modelParts.length > 0) {
      currentReqMessages.push({
        role: 'model',
        parts: modelParts
      })
    }
    if (parts.length > 0) {
      userMessage.parts?.push(...parts)
      currentReqMessages.push(userMessage)
    }

    return currentReqMessages
  }

  override estimateMessageTokens(message: GeminiSdkMessageParam): number {
    return (
      message.parts?.reduce((acc, part) => {
        if (part.text) {
          return acc + estimateTextTokens(part.text)
        }
        if (part.functionCall) {
          return acc + estimateTextTokens(JSON.stringify(part.functionCall))
        }
        if (part.functionResponse) {
          return acc + estimateTextTokens(JSON.stringify(part.functionResponse.response))
        }
        if (part.inlineData) {
          return acc + estimateTextTokens(part.inlineData.data || '')
        }
        if (part.fileData) {
          return acc + estimateTextTokens(part.fileData.fileUri || '')
        }
        return acc
      }, 0) || 0
    )
  }

  public extractMessagesFromSdkPayload(sdkPayload: GeminiSdkParams): GeminiSdkMessageParam[] {
    const messageParam: GeminiSdkMessageParam = {
      role: 'user',
      parts: []
    }
    if (Array.isArray(sdkPayload.message)) {
      sdkPayload.message.forEach((part) => {
        if (typeof part === 'string') {
          messageParam.parts?.push({ text: part })
        } else if (typeof part === 'object') {
          messageParam.parts?.push(part)
        }
      })
    }
    return [...(sdkPayload.history || []), messageParam]
  }

  private async base64File(file: FileMetadata) {
    const { data } = await window.api.file.base64File(file.id + file.ext)
    return {
      data,
      mimeType: 'application/pdf'
    }
  }
}
@@ -1,143 +0,0 @@
import { GoogleGenAI } from '@google/genai'
import { loggerService } from '@logger'
import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
import type { Model, Provider, VertexProvider } from '@renderer/types'
import { isVertexProvider } from '@renderer/utils/provider'
import { isEmpty } from 'lodash'

import { AnthropicVertexClient } from '../anthropic/AnthropicVertexClient'
import { GeminiAPIClient } from './GeminiAPIClient'

const logger = loggerService.withContext('VertexAPIClient')
export class VertexAPIClient extends GeminiAPIClient {
  private authHeaders?: Record<string, string>
  private authHeadersExpiry?: number
  private anthropicVertexClient: AnthropicVertexClient
  private vertexProvider: VertexProvider

  constructor(provider: Provider) {
    super(provider)
    // Check the VertexAI configuration
    if (!isVertexAIConfigured()) {
      throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
    }
    this.anthropicVertexClient = new AnthropicVertexClient(provider)
    // If a plain Provider is passed in, convert it to a VertexProvider
    if (isVertexProvider(provider)) {
      this.vertexProvider = provider
    } else {
      this.vertexProvider = createVertexProvider(provider)
    }
  }

  override getClientCompatibilityType(model?: Model): string[] {
    if (!model) {
      return [this.constructor.name]
    }

    const actualClient = this.getClient(model)
    if (actualClient === this) {
      return [this.constructor.name]
    }

    return actualClient.getClientCompatibilityType(model)
  }

  public getClient(model: Model) {
    if (model.id.includes('claude')) {
      return this.anthropicVertexClient
    }
    return this
  }

  private formatApiHost(baseUrl: string) {
    if (baseUrl.endsWith('/v1/')) {
      baseUrl = baseUrl.slice(0, -4)
    } else if (baseUrl.endsWith('/v1')) {
      baseUrl = baseUrl.slice(0, -3)
    }
    return baseUrl
  }

  override getBaseURL() {
    return this.formatApiHost(this.provider.apiHost)
  }

  override async getSdkInstance() {
    if (this.sdkInstance) {
      return this.sdkInstance
    }

    const { googleCredentials, project, location } = this.vertexProvider

    if (!googleCredentials.privateKey || !googleCredentials.clientEmail || !project || !location) {
      throw new Error('Vertex AI settings are not configured')
    }

    const authHeaders = await this.getServiceAccountAuthHeaders()

    this.sdkInstance = new GoogleGenAI({
      vertexai: true,
      project: project,
      location: location,
      httpOptions: {
        apiVersion: this.getApiVersion(),
        headers: authHeaders,
        baseUrl: isEmpty(this.getBaseURL()) ? undefined : this.getBaseURL()
      }
    })

    return this.sdkInstance
  }

  /**
   * Get auth headers; when a service account is configured, fetch them from the main process
   */
  private async getServiceAccountAuthHeaders(): Promise<Record<string, string> | undefined> {
    const { googleCredentials, project } = this.vertexProvider

    // Check whether a service account is configured
    if (!googleCredentials.privateKey || !googleCredentials.clientEmail || !project) {
      return undefined
    }

    // Reuse valid auth headers if we already have them (refresh 5 minutes before expiry)
    const now = Date.now()
    if (this.authHeaders && this.authHeadersExpiry && this.authHeadersExpiry - now > 5 * 60 * 1000) {
      return this.authHeaders
    }

    try {
      // Fetch auth headers from the main process
      this.authHeaders = await window.api.vertexAI.getAuthHeaders({
        projectId: project,
        serviceAccount: {
          privateKey: googleCredentials.privateKey,
          clientEmail: googleCredentials.clientEmail
        }
      })

      // Set the expiry time (auth headers are typically valid for 1 hour)
      this.authHeadersExpiry = now + 60 * 60 * 1000

      return this.authHeaders
    } catch (error: any) {
      logger.error('Failed to get auth headers:', error)
      throw new Error(`Service Account authentication failed: ${error.message}`)
    }
  }

  /**
   * Clear the auth cache so it gets re-initialized
   */
  clearAuthCache(): void {
    this.authHeaders = undefined
    this.authHeadersExpiry = undefined

    const { googleCredentials, project } = this.vertexProvider

    if (project && googleCredentials.clientEmail) {
      window.api.vertexAI.clearAuthCache(project, googleCredentials.clientEmail)
    }
  }
}
@@ -1,8 +0,0 @@
export * from './ApiClientFactory'
export * from './BaseApiClient'
export * from './types'

// Export specific clients from subdirectories
export * from './anthropic/AnthropicAPIClient'
export * from './openai/OpenAIApiClient'
export * from './openai/OpenAIResponseAPIClient'
@@ -1,110 +0,0 @@
import { loggerService } from '@logger'
import { isSupportedModel } from '@renderer/config/models'
import type { Model, Provider } from '@renderer/types'
import type { NewApiModel } from '@renderer/types/sdk'

import { AnthropicAPIClient } from '../anthropic/AnthropicAPIClient'
import type { BaseApiClient } from '../BaseApiClient'
import { GeminiAPIClient } from '../gemini/GeminiAPIClient'
import { MixedBaseAPIClient } from '../MixedBaseApiClient'
import { OpenAIAPIClient } from '../openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from '../openai/OpenAIResponseAPIClient'

const logger = loggerService.withContext('NewAPIClient')

export class NewAPIClient extends MixedBaseAPIClient {
  // Use a union type instead of any to keep type safety
  protected clients: Map<string, AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient> =
    new Map()
  protected defaultClient: OpenAIAPIClient
  protected currentClient: BaseApiClient

  constructor(provider: Provider) {
    super(provider)

    const claudeClient = new AnthropicAPIClient(provider)
    const geminiClient = new GeminiAPIClient(provider)
    const openaiClient = new OpenAIAPIClient(provider)
    const openaiResponseClient = new OpenAIResponseAPIClient(provider)

    this.clients.set('claude', claudeClient)
    this.clients.set('gemini', geminiClient)
    this.clients.set('openai', openaiClient)
    this.clients.set('openai-response', openaiResponseClient)

    // Set the default client
    this.defaultClient = openaiClient
    this.currentClient = this.defaultClient as BaseApiClient
  }

  override getBaseURL(): string {
    if (!this.currentClient) {
      return this.provider.apiHost
    }
    return this.currentClient.getBaseURL()
  }

  /**
   * Pick the right client for the given model
   */
  protected getClient(model: Model): BaseApiClient {
    if (!model.endpoint_type) {
      throw new Error('Model endpoint type is not defined')
    }

    if (model.endpoint_type === 'anthropic') {
      const client = this.clients.get('claude')
      if (!client || !this.isValidClient(client)) {
        throw new Error('Failed to get claude client')
      }
      return client
    }

    if (model.endpoint_type === 'openai-response') {
      const client = this.clients.get('openai-response')
      if (!client || !this.isValidClient(client)) {
        throw new Error('Failed to get openai-response client')
      }
      return client
    }

    if (model.endpoint_type === 'gemini') {
      const client = this.clients.get('gemini')
      if (!client || !this.isValidClient(client)) {
        throw new Error('Failed to get gemini client')
      }
      return client
    }

    if (model.endpoint_type === 'openai' || model.endpoint_type === 'image-generation') {
      const client = this.clients.get('openai')
      if (!client || !this.isValidClient(client)) {
        throw new Error('Failed to get openai client')
      }
      return client
    }

    throw new Error('Invalid model endpoint type: ' + model.endpoint_type)
  }

  override async listModels(): Promise<NewApiModel[]> {
    try {
      const sdk = await this.defaultClient.getSdkInstance()
      // Explicitly type the expected response shape so that `data` is recognised.
      const response = await sdk.request<{ data: NewApiModel[] }>({
        method: 'get',
        path: '/models'
      })
      const models: NewApiModel[] = response.data ?? []

      models.forEach((model) => {
        model.id = model.id.trim()
      })

      return models.filter(isSupportedModel)
    } catch (error) {
      logger.error('Error listing models:', error as Error)
      return []
    }
  }
}
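// A minimal table-driven sketch (assumption, not part of this diff) of the endpoint_type
// dispatch implemented above; the key-to-client mapping mirrors the constructor's clients map.
const ENDPOINT_TO_CLIENT_KEY: Record<string, string> = {
  anthropic: 'claude',
  'openai-response': 'openai-response',
  gemini: 'gemini',
  openai: 'openai',
  'image-generation': 'openai'
}

function clientKeyFor(endpointType: string): string {
  const key = ENDPOINT_TO_CLIENT_KEY[endpointType]
  if (!key) throw new Error('Invalid model endpoint type: ' + endpointType)
  return key
}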
File diff suppressed because it is too large
@@ -1,309 +0,0 @@
import OpenAI, { AzureOpenAI } from '@cherrystudio/openai'
import { loggerService } from '@logger'
import { COPILOT_DEFAULT_HEADERS } from '@renderer/aiCore/provider/constants'
import {
  isClaudeReasoningModel,
  isOpenAIReasoningModel,
  isSupportedModel,
  isSupportedReasoningEffortOpenAIModel
} from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import type { SettingsState } from '@renderer/store/settings'
import { type Assistant, type GenerateImageParams, type Model, type Provider } from '@renderer/types'
import type {
  OpenAIResponseSdkMessageParam,
  OpenAIResponseSdkParams,
  OpenAIResponseSdkRawChunk,
  OpenAIResponseSdkRawOutput,
  OpenAIResponseSdkTool,
  OpenAIResponseSdkToolCall,
  OpenAISdkMessageParam,
  OpenAISdkParams,
  OpenAISdkRawChunk,
  OpenAISdkRawOutput,
  ReasoningEffortOptionalParams
} from '@renderer/types/sdk'
import { withoutTrailingSlash } from '@renderer/utils/api'
import { isOllamaProvider } from '@renderer/utils/provider'

import { BaseApiClient } from '../BaseApiClient'
import { normalizeAzureOpenAIEndpoint } from './azureOpenAIEndpoint'

const logger = loggerService.withContext('OpenAIBaseClient')

/**
 * Abstract OpenAI base client holding the functionality shared by the two OpenAI clients
 */
export abstract class OpenAIBaseClient<
  TSdkInstance extends OpenAI | AzureOpenAI,
  TSdkParams extends OpenAISdkParams | OpenAIResponseSdkParams,
  TRawOutput extends OpenAISdkRawOutput | OpenAIResponseSdkRawOutput,
  TRawChunk extends OpenAISdkRawChunk | OpenAIResponseSdkRawChunk,
  TMessageParam extends OpenAISdkMessageParam | OpenAIResponseSdkMessageParam,
  TToolCall extends OpenAI.Chat.Completions.ChatCompletionMessageToolCall | OpenAIResponseSdkToolCall,
  TSdkSpecificTool extends OpenAI.Chat.Completions.ChatCompletionTool | OpenAIResponseSdkTool
> extends BaseApiClient<TSdkInstance, TSdkParams, TRawOutput, TRawChunk, TMessageParam, TToolCall, TSdkSpecificTool> {
  constructor(provider: Provider) {
    super(provider)
  }

  // Only applies to openai
  override getBaseURL(): string {
    // apiHost is formatted when called by AiProvider
    return this.provider.apiHost
  }

  override async generateImage({
    model,
    prompt,
    negativePrompt,
    imageSize,
    batchSize,
    seed,
    numInferenceSteps,
    guidanceScale,
    signal,
    promptEnhancement
  }: GenerateImageParams): Promise<string[]> {
    const sdk = await this.getSdkInstance()
    const response = (await sdk.request({
      method: 'post',
      path: '/v1/images/generations',
      signal,
      body: {
        model,
        prompt,
        negative_prompt: negativePrompt,
        image_size: imageSize,
        batch_size: batchSize,
        seed: seed ? parseInt(seed) : undefined,
        num_inference_steps: numInferenceSteps,
        guidance_scale: guidanceScale,
        prompt_enhancement: promptEnhancement
      }
    })) as { data: Array<{ url: string }> }

    return response.data.map((item) => item.url)
  }

  override async getEmbeddingDimensions(model: Model): Promise<number> {
    let sdk: OpenAI = await this.getSdkInstance()
    if (isOllamaProvider(this.provider)) {
      const embedBaseUrl = `${this.provider.apiHost.replace(/(\/(api|v1))\/?$/, '')}/v1`
      sdk = sdk.withOptions({ baseURL: embedBaseUrl })
    }

    const data = await sdk.embeddings.create({
      model: model.id,
      input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi',
      encoding_format: this.provider.id === 'voyageai' ? undefined : 'float'
    })
    return data.data[0].embedding.length
  }

  override async listModels(): Promise<OpenAI.Models.Model[]> {
    try {
      const sdk = await this.getSdkInstance()
      if (this.provider.id === 'openrouter') {
        // https://openrouter.ai/docs/api/api-reference/embeddings/list-embeddings-models
        const embedBaseUrl = 'https://openrouter.ai/api/v1/embeddings'
        const embedSdk = sdk.withOptions({ baseURL: embedBaseUrl })
        const modelPromise = sdk.models.list()
        const embedModelPromise = embedSdk.models.list()
        const [modelResponse, embedModelResponse] = await Promise.all([modelPromise, embedModelPromise])
        const models = [...modelResponse.data, ...embedModelResponse.data]
        const uniqueModels = Array.from(new Map(models.map((model) => [model.id, model])).values())
        return uniqueModels.filter(isSupportedModel)
      }
      if (this.provider.id === 'github') {
        // GitHub Models uses different base URLs for its models and chat completions endpoints
        const baseUrl = 'https://models.github.ai/catalog/'
        const newSdk = sdk.withOptions({ baseURL: baseUrl })
        const response = await newSdk.models.list()

        // @ts-ignore key is not typed
        return response?.body
          .map((model) => ({
            id: model.id,
            description: model.summary,
            object: 'model',
            owned_by: model.publisher
          }))
          .filter(isSupportedModel)
      }

      if (isOllamaProvider(this.provider)) {
        const baseUrl = withoutTrailingSlash(this.getBaseURL())
          .replace(/\/v1$/, '')
          .replace(/\/api$/, '')
        const response = await fetch(`${baseUrl}/api/tags`, {
          headers: {
            Authorization: `Bearer ${this.apiKey}`,
            ...this.defaultHeaders(),
            ...this.provider.extra_headers
          }
        })

        if (!response.ok) {
          throw new Error(`Ollama server returned ${response.status} ${response.statusText}`)
        }

        const data = await response.json()
        if (!data?.models || !Array.isArray(data.models)) {
          throw new Error('Invalid response from Ollama API: missing models array')
        }

        return data.models.map((model) => ({
          id: model.name,
          object: 'model',
          owned_by: 'ollama'
        }))
      }
      const response = await sdk.models.list()
      if (this.provider.id === 'together') {
        // @ts-ignore key is not typed
        return response?.body.map((model) => ({
          id: model.id,
          description: model.display_name,
          object: 'model',
          owned_by: model.organization
        }))
      }
      const models = response.data || []
      models.forEach((model) => {
        model.id = model.id.trim()
      })

      return models.filter(isSupportedModel)
    } catch (error) {
      logger.error('Error listing models:', error as Error)
      return []
    }
  }

  override async getSdkInstance() {
    if (this.sdkInstance) {
      return this.sdkInstance
    }

    let apiKeyForSdkInstance = this.apiKey
    let baseURLForSdkInstance = this.getBaseURL()
    logger.debug('baseURLForSdkInstance', { baseURLForSdkInstance })
    let headersForSdkInstance = {
      ...this.defaultHeaders(),
      ...this.provider.extra_headers
    }

    if (this.provider.id === 'copilot') {
      const defaultHeaders = store.getState().copilot.defaultHeaders
      const { token } = await window.api.copilot.getToken(defaultHeaders)
      // this.provider.apiKey must not be modified
      // this.provider.apiKey = token
      apiKeyForSdkInstance = token
      baseURLForSdkInstance = this.getBaseURL()
      headersForSdkInstance = {
        ...headersForSdkInstance,
        ...COPILOT_DEFAULT_HEADERS
      }
    }

    if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') {
      this.sdkInstance = new AzureOpenAI({
        dangerouslyAllowBrowser: true,
        apiKey: apiKeyForSdkInstance,
        apiVersion: this.provider.apiVersion,
        endpoint: normalizeAzureOpenAIEndpoint(this.provider.apiHost)
      }) as TSdkInstance
    } else {
      this.sdkInstance = new OpenAI({
        dangerouslyAllowBrowser: true,
        apiKey: apiKeyForSdkInstance,
        baseURL: baseURLForSdkInstance,
        defaultHeaders: headersForSdkInstance
      }) as TSdkInstance
    }
    return this.sdkInstance
  }

  override getTemperature(assistant: Assistant, model: Model): number | undefined {
    if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
      return undefined
    }
    return super.getTemperature(assistant, model)
  }

  override getTopP(assistant: Assistant, model: Model): number | undefined {
    if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
      return undefined
    }
    return super.getTopP(assistant, model)
  }

  /**
   * Get the provider specific parameters for the assistant
   * @param assistant - The assistant
   * @param model - The model
   * @returns The provider specific parameters
   */
  protected getProviderSpecificParameters(assistant: Assistant, model: Model) {
    const { maxTokens } = getAssistantSettings(assistant)

    if (this.provider.id === 'openrouter') {
      if (model.id.includes('deepseek-r1')) {
        return {
          include_reasoning: true
        }
      }
    }

    if (isOpenAIReasoningModel(model)) {
      return {
        max_tokens: undefined,
        max_completion_tokens: maxTokens
      }
    }

    return {}
  }

  /**
   * Get the reasoning effort for the assistant
   * @param assistant - The assistant
   * @param model - The model
   * @returns The reasoning effort
   */
  protected getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams {
    if (!isSupportedReasoningEffortOpenAIModel(model)) {
      return {}
    }

    const openAI = getStoreSetting('openAI') as SettingsState['openAI']
    const summaryText = openAI?.summaryText || 'off'

    let summary: string | undefined = undefined

    if (summaryText === 'off' || model.id.includes('o1-pro')) {
      summary = undefined
    } else {
      summary = summaryText
    }

    const reasoningEffort = assistant?.settings?.reasoning_effort
    if (!reasoningEffort) {
      return {}
    }

    if (isSupportedReasoningEffortOpenAIModel(model)) {
      return {
        reasoning: {
          effort: reasoningEffort as OpenAI.ReasoningEffort,
          summary: summary
        } as OpenAI.Reasoning
      }
    }

    return {}
  }
}
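// Sketch of the same dimension probe on the new ai-sdk path that replaces this legacy client;
// the helper name and its call site are assumptions, while the embedMany call itself is the
// real export of the 'ai' package.
import { embedMany } from 'ai'

type AnyEmbeddingModel = Parameters<typeof embedMany>[0]['model']

// One short input is enough to learn the vector width of an embedding model.
async function probeEmbeddingDimensions(model: AnyEmbeddingModel): Promise<number> {
  const { embeddings } = await embedMany({ model, values: ['hi'] })
  return embeddings[0].length
}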
@@ -1,769 +0,0 @@
import OpenAI, { AzureOpenAI } from '@cherrystudio/openai'
import type { ResponseInput } from '@cherrystudio/openai/resources/responses/responses'
import { loggerService } from '@logger'
import type { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas'
import type { CompletionsContext } from '@renderer/aiCore/legacy/middleware/types'
import {
  isGPT5SeriesModel,
  isOpenAIChatCompletionOnlyModel,
  isOpenAILLMModel,
  isOpenAIOpenWeightModel,
  isSupportedReasoningEffortOpenAIModel,
  isSupportVerbosityModel,
  isVisionModel
} from '@renderer/config/models'
import { estimateTextTokens } from '@renderer/services/TokenService'
import type {
  FileMetadata,
  MCPCallToolResponse,
  MCPTool,
  MCPToolResponse,
  Model,
  OpenAIServiceTier,
  Provider,
  ToolCallResponse
} from '@renderer/types'
import { FileTypes, WebSearchSource } from '@renderer/types'
import { ChunkType } from '@renderer/types/chunk'
import type { Message } from '@renderer/types/newMessage'
import type {
  OpenAIResponseSdkMessageParam,
  OpenAIResponseSdkParams,
  OpenAIResponseSdkRawChunk,
  OpenAIResponseSdkRawOutput,
  OpenAIResponseSdkTool,
  OpenAIResponseSdkToolCall
} from '@renderer/types/sdk'
import { addImageFileToContents } from '@renderer/utils/formats'
import {
  isSupportedToolUse,
  mcpToolCallResponseToOpenAIMessage,
  mcpToolsToOpenAIResponseTools,
  openAIToolsToMcpTool
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { isSupportDeveloperRoleProvider } from '@renderer/utils/provider'
import { MB } from '@shared/config/constant'
import { t } from 'i18next'
import { isEmpty } from 'lodash'

import type { RequestTransformer, ResponseChunkTransformer } from '../types'
import { OpenAIAPIClient } from './OpenAIApiClient'
import { OpenAIBaseClient } from './OpenAIBaseClient'

const logger = loggerService.withContext('OpenAIResponseAPIClient')
export class OpenAIResponseAPIClient extends OpenAIBaseClient<
  OpenAI,
  OpenAIResponseSdkParams,
  OpenAIResponseSdkRawOutput,
  OpenAIResponseSdkRawChunk,
  OpenAIResponseSdkMessageParam,
  OpenAIResponseSdkToolCall,
  OpenAIResponseSdkTool
> {
  private client: OpenAIAPIClient
  constructor(provider: Provider) {
    super(provider)
    this.client = new OpenAIAPIClient(provider)
  }

  private formatApiHost() {
    const host = this.provider.apiHost
    if (host.endsWith('/openai/v1')) {
      return host
    } else {
      if (host.endsWith('/')) {
        return host + 'openai/v1'
      } else {
        return host + '/openai/v1'
      }
    }
  }
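  // Example results (hypothetical hosts): 'https://r.azure.com' and 'https://r.azure.com/'
  // both become 'https://r.azure.com/openai/v1'; a host already ending in '/openai/v1' is
  // returned as-is.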

  /**
   * Pick the right client based on the model's characteristics
   */
  public getClient(model: Model) {
    if (this.provider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) {
      return this
    }
    if (isOpenAILLMModel(model) && !isOpenAIChatCompletionOnlyModel(model)) {
      if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') {
        this.provider = { ...this.provider, apiHost: this.formatApiHost() }
        if (this.provider.apiVersion === 'preview' || this.provider.apiVersion === 'v1') {
          return this
        } else {
          return this.client
        }
      }
      return this
    } else {
      return this.client
    }
  }

  /**
   * Override the base-class method to return the client type actually used internally
   */
  public override getClientCompatibilityType(model?: Model): string[] {
    if (!model) {
      return [this.constructor.name]
    }

    const actualClient = this.getClient(model)
    // Avoid a recursive call: if getClient returns this instance, return its own type directly
    if (actualClient === this) {
      return [this.constructor.name]
    }
    return actualClient.getClientCompatibilityType(model)
  }

  override async getSdkInstance() {
    if (this.sdkInstance) {
      return this.sdkInstance
    }
    const baseUrl = this.getBaseURL()

    // Cache the instance so the early-return guard above takes effect on subsequent calls
    if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') {
      this.sdkInstance = new AzureOpenAI({
        dangerouslyAllowBrowser: true,
        apiKey: this.apiKey,
        apiVersion: this.provider.apiVersion,
        baseURL: this.provider.apiHost
      })
    } else {
      this.sdkInstance = new OpenAI({
        dangerouslyAllowBrowser: true,
        apiKey: this.apiKey,
        baseURL: baseUrl,
        defaultHeaders: {
          ...this.defaultHeaders(),
          ...this.provider.extra_headers
        }
      })
    }
    return this.sdkInstance
  }

  override async createCompletions(
    payload: OpenAIResponseSdkParams,
    options?: OpenAI.RequestOptions
  ): Promise<OpenAIResponseSdkRawOutput> {
    const sdk = await this.getSdkInstance()
    return await sdk.responses.create(payload, options)
  }

  private async handlePdfFile(file: FileMetadata): Promise<OpenAI.Responses.ResponseInputFile | undefined> {
    if (file.size > 32 * MB) return undefined
    try {
      const pageCount = await window.api.file.pdfInfo(file.id + file.ext)
      if (pageCount > 100) return undefined
    } catch {
      return undefined
    }

    const { data } = await window.api.file.base64File(file.id + file.ext)
    return {
      type: 'input_file',
      filename: file.origin_name,
      file_data: `data:application/pdf;base64,${data}`
    } as OpenAI.Responses.ResponseInputFile
  }

  public async convertMessageToSdkParam(message: Message, model: Model): Promise<OpenAIResponseSdkMessageParam> {
    const isVision = isVisionModel(model)
    const { textContent, imageContents } = await this.getMessageContent(message)
    const fileBlocks = findFileBlocks(message)
    const imageBlocks = findImageBlocks(message)

    if (fileBlocks.length === 0 && imageBlocks.length === 0 && imageContents.length === 0) {
      if (message.role === 'assistant') {
        return {
          role: 'assistant',
          content: textContent
        }
      } else {
        return {
          role: message.role === 'system' ? 'user' : message.role,
          content: textContent ? [{ type: 'input_text', text: textContent }] : []
        } as OpenAI.Responses.EasyInputMessage
      }
    }

    const parts: OpenAI.Responses.ResponseInputContent[] = []
    if (textContent) {
      parts.push({
        type: 'input_text',
        text: textContent
      })
    }

    if (imageContents.length > 0) {
      for (const imageContent of imageContents) {
        const image = await window.api.file.base64Image(imageContent.fileId + imageContent.fileExt)
        parts.push({
          detail: 'auto',
          type: 'input_image',
          image_url: image.data
        })
      }
    }

    for (const imageBlock of imageBlocks) {
      if (isVision) {
        if (imageBlock.file) {
          const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
          parts.push({
            detail: 'auto',
            type: 'input_image',
            image_url: image.data as string
          })
        } else if (imageBlock.url && imageBlock.url.startsWith('data:')) {
          parts.push({
            detail: 'auto',
            type: 'input_image',
            image_url: imageBlock.url
          })
        }
      }
    }

    for (const fileBlock of fileBlocks) {
      const file = fileBlock.file
      if (!file) continue

      if (isVision && file.ext === '.pdf') {
        const pdfPart = await this.handlePdfFile(file)
        if (pdfPart) {
          parts.push(pdfPart)
          continue
        }
      }

      if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
        const fileContent = (await window.api.file.read(file.id + file.ext, true)).trim()
        parts.push({
          type: 'input_text',
          text: file.origin_name + '\n' + fileContent
        })
      }
    }

    return {
      role: message.role === 'system' ? 'user' : message.role,
      content: parts
    }
  }

  public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): OpenAI.Responses.Tool[] {
    return mcpToolsToOpenAIResponseTools(mcpTools)
  }

  public convertSdkToolCallToMcp(toolCall: OpenAIResponseSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
    return openAIToolsToMcpTool(mcpTools, toolCall)
  }
  public convertSdkToolCallToMcpToolResponse(toolCall: OpenAIResponseSdkToolCall, mcpTool: MCPTool): ToolCallResponse {
    const parsedArgs = (() => {
      try {
        return JSON.parse(toolCall.arguments)
      } catch {
        return toolCall.arguments
      }
    })()

    return {
      id: toolCall.call_id,
      toolCallId: toolCall.call_id,
      tool: mcpTool,
      arguments: parsedArgs,
      status: 'pending'
    }
  }

  public convertMcpToolResponseToSdkMessageParam(
    mcpToolResponse: MCPToolResponse,
    resp: MCPCallToolResponse,
    model: Model
  ): OpenAIResponseSdkMessageParam | undefined {
    if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
      return mcpToolCallResponseToOpenAIMessage(mcpToolResponse, resp, isVisionModel(model))
    } else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) {
      return {
        type: 'function_call_output',
        call_id: mcpToolResponse.toolCallId,
        output: JSON.stringify(resp.content)
      }
    }
    return
  }

  private convertResponseToMessageContent(response: OpenAI.Responses.Response): ResponseInput {
    const content: OpenAI.Responses.ResponseInput = []
    response.output.forEach((item) => {
      if (item.type !== 'apply_patch_call' && item.type !== 'apply_patch_call_output') {
        content.push(item)
      } else if (item.type === 'apply_patch_call') {
        if (item.operation !== undefined) {
          const applyPatchToolCall: OpenAI.Responses.ResponseInputItem.ApplyPatchCall = {
            ...item,
            operation: item.operation
          }
          content.push(applyPatchToolCall)
        } else {
          logger.warn('Undefined tool call operation for ApplyPatchToolCall.')
        }
      } else if (item.type === 'apply_patch_call_output') {
        if (item.output !== undefined) {
          const applyPatchToolCallOutput: OpenAI.Responses.ResponseInputItem.ApplyPatchCallOutput = {
            ...item,
            output: item.output === null ? undefined : item.output
          }
          content.push(applyPatchToolCallOutput)
        } else {
          logger.warn('Undefined tool call output for ApplyPatchToolCallOutput.')
        }
      }
    })
    return content
  }

  public buildSdkMessages(
    currentReqMessages: OpenAIResponseSdkMessageParam[],
    output: OpenAI.Responses.Response | undefined,
    toolResults: OpenAIResponseSdkMessageParam[],
    toolCalls: OpenAIResponseSdkToolCall[]
  ): OpenAIResponseSdkMessageParam[] {
    if (!output && toolCalls.length === 0) {
      return [...currentReqMessages, ...toolResults]
    }

    if (!output) {
      return [...currentReqMessages, ...(toolCalls || []), ...(toolResults || [])]
    }

    const content = this.convertResponseToMessageContent(output)

    return [...currentReqMessages, ...content, ...(toolResults || [])]
  }

  override estimateMessageTokens(message: OpenAIResponseSdkMessageParam): number {
    let sum = 0
    if ('content' in message) {
      if (typeof message.content === 'string') {
        sum += estimateTextTokens(message.content)
      } else if (Array.isArray(message.content)) {
        for (const part of message.content) {
          switch (part.type) {
            case 'input_text':
              sum += estimateTextTokens(part.text)
              break
            case 'input_image':
              sum += estimateTextTokens(part.image_url || '')
              break
            default:
              break
          }
        }
      }
    }
    switch (message.type) {
      case 'function_call_output': {
        let str = ''
        if (typeof message.output === 'string') {
          str = message.output
        } else {
          for (const part of message.output) {
            switch (part.type) {
              case 'input_text':
                str += part.text
                break
              case 'input_image':
                str += part.image_url || ''
                break
              case 'input_file':
                str += part.file_data || ''
                break
            }
          }
        }
        sum += estimateTextTokens(str)
        break
      }
      case 'function_call':
        sum += estimateTextTokens(message.arguments)
        break
      default:
        break
    }
    return sum
  }

  public extractMessagesFromSdkPayload(sdkPayload: OpenAIResponseSdkParams): OpenAIResponseSdkMessageParam[] {
    if (!sdkPayload.input || typeof sdkPayload.input === 'string') {
      return [{ role: 'user', content: sdkPayload.input ?? '' }]
    }
    return sdkPayload.input
  }

  getRequestTransformer(): RequestTransformer<OpenAIResponseSdkParams, OpenAIResponseSdkMessageParam> {
    return {
      transform: async (
        coreRequest,
        assistant,
        model,
        isRecursiveCall,
        recursiveSdkMessages
      ): Promise<{
        payload: OpenAIResponseSdkParams
        messages: OpenAIResponseSdkMessageParam[]
        metadata: Record<string, any>
      }> => {
        const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch, enableGenerateImage } = coreRequest
        // 1. Handle the system message
        const systemMessage: OpenAI.Responses.EasyInputMessage = {
          role: 'system',
          content: []
        }

        const systemMessageContent: OpenAI.Responses.ResponseInputMessageContentList = []
        const systemMessageInput: OpenAI.Responses.ResponseInputText = {
          text: assistant.prompt || '',
          type: 'input_text'
        }
        if (
          isSupportedReasoningEffortOpenAIModel(model) &&
          isSupportDeveloperRoleProvider(this.provider) &&
          isOpenAIOpenWeightModel(model)
        ) {
          systemMessage.role = 'developer'
        }

        // 2. Set up tools
        let tools: OpenAI.Responses.Tool[] = []
        const { tools: extraTools } = this.setupToolsConfig({
          mcpTools: mcpTools,
          model,
          enableToolUse: isSupportedToolUse(assistant)
        })

        systemMessageContent.push(systemMessageInput)
        systemMessage.content = systemMessageContent

        // 3. Handle user messages
        let userMessage: OpenAI.Responses.ResponseInputItem[] = []
        if (typeof messages === 'string') {
          userMessage.push({ role: 'user', content: messages })
        } else {
          const processedMessages = addImageFileToContents(messages)
          for (const message of processedMessages) {
            userMessage.push(await this.convertMessageToSdkParam(message, model))
          }
        }
        // FIXME: it would be better to handle this via previous_response_id directly (or store the image_generation_call id in the database)
        if (enableGenerateImage) {
          const finalAssistantMessage = userMessage.findLast(
            (m) => (m as OpenAI.Responses.EasyInputMessage).role === 'assistant'
          ) as OpenAI.Responses.EasyInputMessage
          const finalUserMessage = userMessage.pop() as OpenAI.Responses.EasyInputMessage
          if (finalUserMessage && Array.isArray(finalUserMessage.content)) {
            if (finalAssistantMessage && Array.isArray(finalAssistantMessage.content)) {
              finalAssistantMessage.content = [...finalAssistantMessage.content, ...finalUserMessage.content]
              // The previous assistant message's content (including images and files) is deliberately sent as a user message here
              userMessage = [{ ...finalAssistantMessage, role: 'user' } as OpenAI.Responses.EasyInputMessage]
            } else {
              userMessage.push(finalUserMessage)
            }
          }
        }

        // 4. Final request messages
        let reqMessages: OpenAI.Responses.ResponseInput
        if (!systemMessage.content) {
          reqMessages = [...userMessage]
        } else {
          reqMessages = [systemMessage, ...userMessage].filter(Boolean) as OpenAI.Responses.EasyInputMessage[]
        }

        if (enableWebSearch) {
          tools.push({
            type: 'web_search_preview'
          })
        }

        if (enableGenerateImage) {
          tools.push({
            type: 'image_generation',
            partial_images: streamOutput ? 2 : undefined
          })
        }

        tools = tools.concat(extraTools)

        const reasoningEffort = this.getReasoningEffort(assistant, model)

        // minimal cannot be used with web_search tool
        if (isGPT5SeriesModel(model) && reasoningEffort.reasoning?.effort === 'minimal' && enableWebSearch) {
          reasoningEffort.reasoning.effort = 'low'
        }

        const commonParams: OpenAIResponseSdkParams = {
          model: model.id,
          input:
            isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
              ? recursiveSdkMessages
              : reqMessages,
          temperature: this.getTemperature(assistant, model),
          top_p: this.getTopP(assistant, model),
          max_output_tokens: maxTokens,
          stream: streamOutput,
          tools: !isEmpty(tools) ? tools : undefined,
          // groq has a different service tier configuration that does not match the openai interface types
          service_tier: this.getServiceTier(model) as OpenAIServiceTier,
          ...(isSupportVerbosityModel(model)
            ? {
                text: {
                  verbosity: this.getVerbosity(model)
                }
              }
            : {}),
          // Spread the adjusted value so the minimal -> low fix above is not discarded
          ...(reasoningEffort as OpenAI.Reasoning),
          // Only apply custom parameters in chat scenarios, to avoid affecting translation, summarization and other flows
          // Note: user-defined parameters should always override the others
          ...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
        }
        const timeout = this.getTimeout(model)
        return { payload: commonParams, messages: reqMessages, metadata: { timeout } }
      }
    }
  }

  getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<OpenAIResponseSdkRawChunk> {
    const toolCalls: OpenAIResponseSdkToolCall[] = []
    const outputItems: OpenAI.Responses.ResponseOutputItem[] = []
    let hasBeenCollectedToolCalls = false
    let hasReasoningSummary = false
    let isFirstThinkingChunk = true
    let isFirstTextChunk = true
    return () => ({
      async transform(chunk: OpenAIResponseSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
        if (typeof chunk === 'string') {
          try {
            chunk = JSON.parse(chunk)
          } catch (error) {
            logger.error('invalid chunk', { chunk, error })
            throw new Error(t('error.chat.chunk.non_json'))
          }
        }
        // Process the chunk
        if ('output' in chunk) {
          if (ctx._internal?.toolProcessingState) {
            ctx._internal.toolProcessingState.output = chunk
          }
          for (const output of chunk.output) {
            switch (output.type) {
              case 'message':
                if (output.content[0].type === 'output_text') {
                  if (isFirstTextChunk) {
                    controller.enqueue({
                      type: ChunkType.TEXT_START
                    })
                    isFirstTextChunk = false
                  }
                  controller.enqueue({
                    type: ChunkType.TEXT_DELTA,
                    text: output.content[0].text
                  })
                  if (output.content[0].annotations && output.content[0].annotations.length > 0) {
                    controller.enqueue({
                      type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
                      llm_web_search: {
                        source: WebSearchSource.OPENAI_RESPONSE,
                        results: output.content[0].annotations
                      }
                    })
                  }
                }
                break
              case 'reasoning':
                if (isFirstThinkingChunk) {
                  controller.enqueue({
                    type: ChunkType.THINKING_START
                  })
                  isFirstThinkingChunk = false
                }
                controller.enqueue({
                  type: ChunkType.THINKING_DELTA,
                  text: output.summary.map((s) => s.text).join('\n')
                })
                break
              case 'function_call':
                toolCalls.push(output)
                break
              case 'image_generation_call':
                controller.enqueue({
                  type: ChunkType.IMAGE_CREATED
                })
                controller.enqueue({
                  type: ChunkType.IMAGE_COMPLETE,
                  image: {
                    type: 'base64',
                    images: [`data:image/png;base64,${output.result}`]
                  }
                })
            }
          }
          if (toolCalls.length > 0) {
            controller.enqueue({
              type: ChunkType.MCP_TOOL_CREATED,
              tool_calls: toolCalls
            })
          }
          controller.enqueue({
            type: ChunkType.LLM_RESPONSE_COMPLETE,
            response: {
              usage: {
                prompt_tokens: chunk.usage?.input_tokens || 0,
                completion_tokens: chunk.usage?.output_tokens || 0,
                total_tokens: chunk.usage?.total_tokens || 0
              }
            }
          })
        } else {
          switch (chunk.type) {
            case 'response.output_item.added':
              if (chunk.item.type === 'function_call') {
                outputItems.push(chunk.item)
              } else if (chunk.item.type === 'web_search_call') {
                controller.enqueue({
                  type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS
                })
              }
              break
            case 'response.reasoning_summary_part.added':
              if (hasReasoningSummary) {
                const separator = '\n\n'
                controller.enqueue({
                  type: ChunkType.THINKING_DELTA,
                  text: separator
                })
              }
              hasReasoningSummary = true
              break
            case 'response.reasoning_summary_text.delta':
              if (isFirstThinkingChunk) {
                controller.enqueue({
                  type: ChunkType.THINKING_START
                })
                isFirstThinkingChunk = false
              }
              controller.enqueue({
                type: ChunkType.THINKING_DELTA,
                text: chunk.delta
              })
              break
            case 'response.image_generation_call.generating':
              controller.enqueue({
                type: ChunkType.IMAGE_CREATED
              })
              break
            case 'response.image_generation_call.partial_image':
              controller.enqueue({
                type: ChunkType.IMAGE_DELTA,
                image: {
                  type: 'base64',
                  images: [`data:image/png;base64,${chunk.partial_image_b64}`]
                }
              })
              break
            case 'response.image_generation_call.completed':
              controller.enqueue({
                type: ChunkType.IMAGE_COMPLETE
              })
              break
            case 'response.output_text.delta': {
              if (isFirstTextChunk) {
                controller.enqueue({
                  type: ChunkType.TEXT_START
                })
                isFirstTextChunk = false
              }
              controller.enqueue({
                type: ChunkType.TEXT_DELTA,
                text: chunk.delta
              })
              break
            }
            case 'response.function_call_arguments.done': {
              const outputItem: OpenAI.Responses.ResponseOutputItem | undefined = outputItems.find(
                (item) => item.id === chunk.item_id
              )
              if (outputItem) {
                if (outputItem.type === 'function_call') {
                  toolCalls.push({
                    ...outputItem,
                    arguments: chunk.arguments,
                    status: 'completed'
                  })
                }
              }
              break
            }
            case 'response.content_part.done': {
              if (chunk.part.type === 'output_text' && chunk.part.annotations && chunk.part.annotations.length > 0) {
                controller.enqueue({
                  type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
                  llm_web_search: {
                    source: WebSearchSource.OPENAI_RESPONSE,
                    results: chunk.part.annotations
                  }
                })
              }
              if (toolCalls.length > 0 && !hasBeenCollectedToolCalls) {
                controller.enqueue({
                  type: ChunkType.MCP_TOOL_CREATED,
                  tool_calls: toolCalls
                })
                hasBeenCollectedToolCalls = true
              }
              break
            }
            case 'response.completed': {
              if (ctx._internal?.toolProcessingState) {
                ctx._internal.toolProcessingState.output = chunk.response
              }
              if (toolCalls.length > 0 && !hasBeenCollectedToolCalls) {
                controller.enqueue({
                  type: ChunkType.MCP_TOOL_CREATED,
                  tool_calls: toolCalls
                })
                hasBeenCollectedToolCalls = true
              }
              const completion_tokens = chunk.response.usage?.output_tokens || 0
              const total_tokens = chunk.response.usage?.total_tokens || 0
              controller.enqueue({
                type: ChunkType.LLM_RESPONSE_COMPLETE,
                response: {
                  usage: {
                    prompt_tokens: chunk.response.usage?.input_tokens || 0,
                    completion_tokens: completion_tokens,
                    total_tokens: total_tokens
                  }
                }
              })
              break
            }
            case 'error': {
              controller.enqueue({
                type: ChunkType.ERROR,
                error: {
                  message: chunk.message,
                  code: chunk.code
                }
              })
              break
            }
          }
        }
      }
    })
  }
}
@@ -1,4 +0,0 @@
export function normalizeAzureOpenAIEndpoint(apiHost: string): string {
  const normalizedHost = apiHost.replace(/\/+$/, '')
  return normalizedHost.replace(/\/openai(?:\/v1)?$/i, '')
}
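// Example results (hypothetical hosts), per the two replaces above:
//   'https://my-res.openai.azure.com/openai/v1/' -> 'https://my-res.openai.azure.com'
//   'https://my-res.openai.azure.com/openai'     -> 'https://my-res.openai.azure.com'
//   'https://my-res.openai.azure.com'            -> 'https://my-res.openai.azure.com' (unchanged)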
@@ -1,56 +0,0 @@
import type OpenAI from '@cherrystudio/openai'
import { loggerService } from '@logger'
import { isSupportedModel } from '@renderer/config/models'
import type { Provider } from '@renderer/types'
import { objectKeys } from '@renderer/types'
import { formatApiHost } from '@renderer/utils'
import { withoutTrailingApiVersion } from '@shared/utils'

import { OpenAIAPIClient } from '../openai/OpenAIApiClient'

const logger = loggerService.withContext('OVMSClient')

export class OVMSClient extends OpenAIAPIClient {
  constructor(provider: Provider) {
    super(provider)
  }

  override async listModels(): Promise<OpenAI.Models.Model[]> {
    try {
      const sdk = await this.getSdkInstance()
      const url = formatApiHost(withoutTrailingApiVersion(this.getBaseURL()), true, 'v1')
      const chatModelsResponse = await sdk.withOptions({ baseURL: url }).get('/config')
      logger.debug(`Chat models response: ${JSON.stringify(chatModelsResponse)}`)

      // Parse the config response to extract model information
      const config = chatModelsResponse as Record<string, any>
      const models = objectKeys(config)
        .map((modelName) => {
          const modelInfo = config[modelName]

          // Check if model has at least one version with "AVAILABLE" state
          const hasAvailableVersion = modelInfo?.model_version_status?.some(
            (versionStatus: any) => versionStatus?.state === 'AVAILABLE'
          )

          if (hasAvailableVersion) {
            return {
              id: modelName,
              object: 'model' as const,
              owned_by: 'ovms',
              created: Date.now()
            }
          }
          return null // Skip models without available versions
        })
        .filter(Boolean) // Remove null entries
      logger.debug(`Processed models: ${JSON.stringify(models)}`)

      // Filter out unsupported models
      return models.filter((model): model is OpenAI.Models.Model => model !== null && isSupportedModel(model))
    } catch (error) {
      logger.error(`Error listing OVMS models: ${error}`)
      return []
    }
  }
}
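// Illustrative /config payload shape the parser above assumes (names and values hypothetical):
// {
//   "llama-3-8b": {
//     "model_version_status": [
//       { "version": "1", "state": "AVAILABLE" }
//     ]
//   }
// }
// Only "llama-3-8b" would be listed; entries whose versions never reach AVAILABLE are skipped.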
@@ -1,72 +0,0 @@
import type OpenAI from '@cherrystudio/openai'
import { loggerService } from '@logger'
import { isSupportedModel } from '@renderer/config/models'
import type { Model, Provider } from '@renderer/types'

import { OpenAIAPIClient } from '../openai/OpenAIApiClient'

const logger = loggerService.withContext('PPIOAPIClient')
export class PPIOAPIClient extends OpenAIAPIClient {
  constructor(provider: Provider) {
    super(provider)
  }

  // oxlint-disable-next-line @typescript-eslint/no-unused-vars
  override getClientCompatibilityType(_model?: Model): string[] {
    return ['OpenAIAPIClient']
  }

  override async listModels(): Promise<OpenAI.Models.Model[]> {
    try {
      const sdk = await this.getSdkInstance()

      // PPIO requires three separate requests to get all model types
      const [chatModelsResponse, embeddingModelsResponse, rerankerModelsResponse] = await Promise.all([
        // Chat/completion models
        sdk.request({
          method: 'get',
          path: '/models'
        }),
        // Embedding models
        sdk.request({
          method: 'get',
          path: '/models?model_type=embedding'
        }),
        // Reranker models
        sdk.request({
          method: 'get',
          path: '/models?model_type=reranker'
        })
      ])

      // Extract models from all responses
      // @ts-ignore - PPIO response structure may not be typed
      const allModels = [
        ...((chatModelsResponse as any)?.data || []),
        ...((embeddingModelsResponse as any)?.data || []),
        ...((rerankerModelsResponse as any)?.data || [])
      ]

      // Process and standardize model data
      const processedModels = allModels.map((model: any) => ({
        id: model.id || model.name,
        description: model.description || model.display_name || model.summary,
        object: 'model' as const,
        owned_by: model.owned_by || model.publisher || model.organization || 'ppio',
        created: model.created || Date.now()
      }))

      // Clean up model IDs and filter supported models
      processedModels.forEach((model) => {
        if (model.id) {
          model.id = model.id.trim()
        }
      })

      return processedModels.filter(isSupportedModel)
    } catch (error) {
      logger.error('Error listing PPIO models:', error as Error)
      return []
    }
  }
}
|
||||
@ -1,141 +0,0 @@
import type Anthropic from '@anthropic-ai/sdk'
import type OpenAI from '@cherrystudio/openai'
import type { Assistant, MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types'
import type { Provider } from '@renderer/types'
import type {
  AnthropicSdkRawChunk,
  OpenAIResponseSdkRawChunk,
  OpenAIResponseSdkRawOutput,
  OpenAISdkRawChunk,
  SdkMessageParam,
  SdkParams,
  SdkRawChunk,
  SdkRawOutput,
  SdkTool,
  SdkToolCall
} from '@renderer/types/sdk'

import type { CompletionsParams, GenericChunk } from '../middleware/schemas'
import type { CompletionsContext } from '../middleware/types'

/**
 * Raw stream listener interface
 */
export interface RawStreamListener<TRawChunk = SdkRawChunk> {
  onChunk?: (chunk: TRawChunk) => void
  onStart?: () => void
  onEnd?: () => void
  onError?: (error: Error) => void
}

/**
 * OpenAI-specific stream listener
 */
export interface OpenAIStreamListener extends RawStreamListener<OpenAISdkRawChunk> {
  onChoice?: (choice: OpenAI.Chat.Completions.ChatCompletionChunk.Choice) => void
  onFinishReason?: (reason: string) => void
}

/**
 * OpenAI Response specific stream listener
 */
export interface OpenAIResponseStreamListener<TChunk extends OpenAIResponseSdkRawChunk = OpenAIResponseSdkRawChunk>
  extends RawStreamListener<TChunk> {
  onMessage?: (response: OpenAIResponseSdkRawOutput) => void
}

/**
 * Anthropic-specific stream listener
 */
export interface AnthropicStreamListener<TChunk extends AnthropicSdkRawChunk = AnthropicSdkRawChunk>
  extends RawStreamListener<TChunk> {
  onContentBlock?: (contentBlock: Anthropic.Messages.ContentBlock) => void
  onMessage?: (message: Anthropic.Messages.Message) => void
}

/**
 * Request transformer interface
 */
export interface RequestTransformer<
  TSdkParams extends SdkParams = SdkParams,
  TMessageParam extends SdkMessageParam = SdkMessageParam
> {
  transform(
    completionsParams: CompletionsParams,
    assistant: Assistant,
    model: Model,
    isRecursiveCall?: boolean,
    recursiveSdkMessages?: TMessageParam[]
  ): Promise<{
    payload: TSdkParams
    messages: TMessageParam[]
    metadata?: Record<string, any>
  }>
}

/**
 * Response chunk transformer interface
 */
export type ResponseChunkTransformer<TRawChunk extends SdkRawChunk = SdkRawChunk, TContext = any> = (
  context?: TContext
) => Transformer<TRawChunk, GenericChunk>

export interface ResponseChunkTransformerContext {
  isStreaming: boolean
  isEnabledToolCalling: boolean
  isEnabledWebSearch: boolean
  isEnabledUrlContext: boolean
  isEnabledReasoning: boolean
  mcpTools: MCPTool[]
  provider: Provider
}

/**
 * API client interface
 */
export interface ApiClient<
  TSdkInstance = any,
  TSdkParams extends SdkParams = SdkParams,
  TRawOutput extends SdkRawOutput = SdkRawOutput,
  TRawChunk extends SdkRawChunk = SdkRawChunk,
  TMessageParam extends SdkMessageParam = SdkMessageParam,
  TToolCall extends SdkToolCall = SdkToolCall,
  TSdkSpecificTool extends SdkTool = SdkTool
> {
  provider: Provider

  // Core method — in the middleware architecture this may be only a placeholder;
  // the actual SDK call is handled by SdkCallMiddleware
  // completions(params: CompletionsParams): Promise<CompletionsResult>

  createCompletions(payload: TSdkParams): Promise<TRawOutput>

  // SDK-related methods
  getSdkInstance(): Promise<TSdkInstance> | TSdkInstance
  getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
  getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<TRawChunk>

  // Raw stream listener method
  attachRawStreamListener?(rawOutput: TRawOutput, listener: RawStreamListener<TRawChunk>): TRawOutput

  // Tool conversion methods (kept optional, since not every provider supports tools)
  convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[]
  convertMcpToolResponseToSdkMessageParam?(
    mcpToolResponse: MCPToolResponse,
    resp: any,
    model: Model
  ): TMessageParam | undefined
  convertSdkToolCallToMcp?(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined
  convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse

  // Build the SDK-specific message list for recursive calls after tool invocation
  buildSdkMessages(
    currentReqMessages: TMessageParam[],
    output: TRawOutput | string,
    toolResults: TMessageParam[],
    toolCalls?: TToolCall[]
  ): TMessageParam[]

  // Extract the message array from the SDK payload (for type-safe access in middleware)
  extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[]
}
@ -1,105 +0,0 @@
import type OpenAI from '@cherrystudio/openai'
import { loggerService } from '@logger'
import type { Provider } from '@renderer/types'
import type { GenerateImageParams } from '@renderer/types'

import { OpenAIAPIClient } from '../openai/OpenAIApiClient'

const logger = loggerService.withContext('ZhipuAPIClient')

export class ZhipuAPIClient extends OpenAIAPIClient {
  constructor(provider: Provider) {
    super(provider)
  }

  override getClientCompatibilityType(): string[] {
    return ['ZhipuAPIClient']
  }

  override async generateImage({
    model,
    prompt,
    negativePrompt,
    imageSize,
    batchSize,
    signal,
    quality
  }: GenerateImageParams): Promise<string[]> {
    const sdk = await this.getSdkInstance()

    // Zhipu AI uses a different parameter format
    const body: any = {
      model,
      prompt
    }

    // Zhipu AI specific parameter format
    body.size = imageSize
    body.n = batchSize
    if (negativePrompt) {
      body.negative_prompt = negativePrompt
    }

    // Only the cogview-4-250304 model supports the quality and style parameters
    if (model === 'cogview-4-250304') {
      if (quality) {
        body.quality = quality
      }
      body.style = 'vivid'
    }

    try {
      logger.debug('Calling Zhipu image generation API with params:', body)

      const response = await sdk.images.generate(body, { signal })

      if (response.data && response.data.length > 0) {
        return response.data.map((image: any) => image.url).filter(Boolean)
      }

      return []
    } catch (error) {
      logger.error('Zhipu image generation failed:', error as Error)
      throw error
    }
  }

  public async listModels(): Promise<OpenAI.Models.Model[]> {
    const models = [
      'glm-4.7',
      'glm-4.6',
      'glm-4.6v',
      'glm-4.6v-flash',
      'glm-4.6v-flashx',
      'glm-4.5',
      'glm-4.5-x',
      'glm-4.5-air',
      'glm-4.5-airx',
      'glm-4.5-flash',
      'glm-4.5v',
      'glm-z1-air',
      'glm-z1-airx',
      'cogview-3-flash',
      'cogview-4-250304',
      'glm-4-long',
      'glm-4-plus',
      'glm-4-air-250414',
      'glm-4-airx',
      'glm-4-flashx',
      'glm-4v',
      'glm-4v-flash',
      'glm-4v-plus-0111',
      'glm-4.1v-thinking-flash',
      'glm-4-alltools',
      'embedding-3'
    ]

    const created = Date.now()
    return models.map((id) => ({
      id,
      owned_by: 'zhipu',
      object: 'model' as const,
      created
    }))
  }
}
@ -1,185 +0,0 @@
import { loggerService } from '@logger'
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
import type { BaseApiClient } from '@renderer/aiCore/legacy/clients/BaseApiClient'
import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
import { withSpanResult } from '@renderer/services/SpanManagerService'
import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
import type { RequestOptions, SdkModel } from '@renderer/types/sdk'
import { isSupportedToolUse } from '@renderer/utils/mcp-tools'

import { AihubmixAPIClient } from './clients/aihubmix/AihubmixAPIClient'
import { VertexAPIClient } from './clients/gemini/VertexAPIClient'
import { NewAPIClient } from './clients/newapi/NewAPIClient'
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
import { CompletionsMiddlewareBuilder } from './middleware/builder'
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware'
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
import { applyCompletionsMiddlewares } from './middleware/composer'
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
import { MiddlewareRegistry } from './middleware/register'
import type { CompletionsParams, CompletionsResult } from './middleware/schemas'

const logger = loggerService.withContext('AiProvider')

export default class AiProvider {
  private apiClient: BaseApiClient

  constructor(provider: Provider) {
    // Use the new ApiClientFactory to get a BaseApiClient instance
    this.apiClient = ApiClientFactory.create(provider)
  }

  public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
    // 1. Identify the correct client for the model
    const model = params.assistant.model
    if (!model) {
      return Promise.reject(new Error('Model is required'))
    }

    // Choose the appropriate handling according to the client type
    let client: BaseApiClient

    if (this.apiClient instanceof AihubmixAPIClient) {
      // AihubmixAPIClient: pick the proper sub-client for the model
      client = this.apiClient.getClientForModel(model)
      if (client instanceof OpenAIResponseAPIClient) {
        client = client.getClient(model) as BaseApiClient
      }
    } else if (this.apiClient instanceof NewAPIClient) {
      client = this.apiClient.getClientForModel(model)
      if (client instanceof OpenAIResponseAPIClient) {
        client = client.getClient(model) as BaseApiClient
      }
    } else if (this.apiClient instanceof OpenAIResponseAPIClient) {
      // OpenAIResponseAPIClient: choose the API type based on model characteristics
      client = this.apiClient.getClient(model) as BaseApiClient
    } else if (this.apiClient instanceof VertexAPIClient) {
      client = this.apiClient.getClient(model) as BaseApiClient
    } else {
      // Use all other clients directly
      client = this.apiClient
    }

    // 2. Build the middleware chain
    const builder = CompletionsMiddlewareBuilder.withDefaults()
    // images api
    if (isDedicatedImageGenerationModel(model)) {
      builder.clear()
      builder
        .add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
        .add(MiddlewareRegistry[ErrorHandlerMiddlewareName])
        .add(MiddlewareRegistry[AbortHandlerMiddlewareName])
        .add(MiddlewareRegistry[ImageGenerationMiddlewareName])
    } else {
      // Existing logic for other models
      logger.silly('Builder Params', params)
      // Use compatibility-type checks to avoid problems with TypeScript type narrowing and the decorator pattern
      const clientTypes = client.getClientCompatibilityType(model)
      const isOpenAICompatible =
        clientTypes.includes('OpenAIAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
      if (!isOpenAICompatible) {
        logger.silly('ThinkingTagExtractionMiddleware is removed')
        builder.remove(ThinkingTagExtractionMiddlewareName)
      }

      const isAnthropicOrOpenAIResponseCompatible =
        clientTypes.includes('AnthropicAPIClient') ||
        clientTypes.includes('OpenAIResponseAPIClient') ||
        clientTypes.includes('AnthropicVertexAPIClient')
      if (!isAnthropicOrOpenAIResponseCompatible) {
        logger.silly('RawStreamListenerMiddleware is removed')
        builder.remove(RawStreamListenerMiddlewareName)
      }
      if (!params.enableWebSearch) {
        logger.silly('WebSearchMiddleware is removed')
        builder.remove(WebSearchMiddlewareName)
      }
      if (!params.mcpTools?.length) {
        builder.remove(ToolUseExtractionMiddlewareName)
        logger.silly('ToolUseExtractionMiddleware is removed')
        builder.remove(McpToolChunkMiddlewareName)
        logger.silly('McpToolChunkMiddleware is removed')
      }
      if (isSupportedToolUse(params.assistant) && isFunctionCallingModel(model)) {
        builder.remove(ToolUseExtractionMiddlewareName)
        logger.silly('ToolUseExtractionMiddleware is removed')
      }
      if (params.callType !== 'chat' && params.callType !== 'check' && params.callType !== 'translate') {
        logger.silly('AbortHandlerMiddleware is removed')
        builder.remove(AbortHandlerMiddlewareName)
      }
      if (params.callType === 'test') {
        builder.remove(ErrorHandlerMiddlewareName)
        logger.silly('ErrorHandlerMiddleware is removed')
        builder.remove(FinalChunkConsumerMiddlewareName)
        logger.silly('FinalChunkConsumerMiddleware is removed')
      }
    }

    const middlewares = builder.build()
    logger.silly(
      'middlewares',
      middlewares.map((m) => m.name)
    )

    // 3. Create the wrapped SDK method with middlewares
    const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)

    // 4. Execute the wrapped method with the original params
    const result = wrappedCompletionMethod(params, options)
    return result
  }

  public async completionsForTrace(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
    const traceName = params.assistant.model?.name
      ? `${params.assistant.model?.name}.${params.callType}`
      : `LLM.${params.callType}`

    const traceParams: StartSpanParams = {
      name: traceName,
      tag: 'LLM',
      topicId: params.topicId || '',
      modelName: params.assistant.model?.name
    }

    return await withSpanResult(this.completions.bind(this), traceParams, params, options)
  }

  public async models(): Promise<SdkModel[]> {
    return this.apiClient.listModels()
  }

  public async getEmbeddingDimensions(model: Model): Promise<number> {
    try {
      // Use the SDK instance to test embedding capabilities
      const dimensions = await this.apiClient.getEmbeddingDimensions(model)
      return dimensions
    } catch (error) {
      logger.error('Error getting embedding dimensions:', error as Error)
      throw error
    }
  }

  public async generateImage(params: GenerateImageParams): Promise<string[]> {
    if (this.apiClient instanceof AihubmixAPIClient) {
      const client = this.apiClient.getClientForModel({ id: params.model } as Model)
      return client.generateImage(params)
    }
    return this.apiClient.generateImage(params)
  }

  public getBaseURL(): string {
    return this.apiClient.getBaseURL()
  }

  public getApiKey(): string {
    return this.apiClient.getApiKey()
  }
}
@ -1,182 +0,0 @@
# MiddlewareBuilder Usage Guide

`MiddlewareBuilder` is a tool for dynamically building and managing middleware chains, offering flexible middleware organization and configuration.

## Key Features

### 1. Unified middleware naming

Every middleware is identified by an exported `MIDDLEWARE_NAME` constant:

```typescript
// Example middleware file
export const MIDDLEWARE_NAME = 'SdkCallMiddleware'
export const SdkCallMiddleware: CompletionsMiddleware = ...
```

### 2. The NamedMiddleware interface

Middlewares use the unified `NamedMiddleware` interface format:

```typescript
interface NamedMiddleware<TMiddleware = any> {
  name: string
  middleware: TMiddleware
}
```

### 3. Middleware registry

All available middlewares are managed centrally through `MiddlewareRegistry`:

```typescript
import { MiddlewareRegistry } from './register'

// Look up a middleware by name
const sdkCallMiddleware = MiddlewareRegistry['SdkCallMiddleware']
```

## Basic Usage

### 1. Using the default middleware chain

```typescript
import { CompletionsMiddlewareBuilder } from './builder'

const builder = CompletionsMiddlewareBuilder.withDefaults()
const middlewares = builder.build()
```

### 2. Custom middleware chains

```typescript
import { createCompletionsBuilder, MiddlewareRegistry } from './builder'

const builder = createCompletionsBuilder([
  MiddlewareRegistry['AbortHandlerMiddleware'],
  MiddlewareRegistry['TextChunkMiddleware']
])

const middlewares = builder.build()
```

### 3. Adjusting the chain dynamically

```typescript
const builder = CompletionsMiddlewareBuilder.withDefaults()

// Add, remove, or replace middlewares based on conditions
if (needsLogging) {
  builder.prepend(MiddlewareRegistry['GenericLoggingMiddleware'])
}

if (disableTools) {
  builder.remove('McpToolChunkMiddleware')
}

if (customThinking) {
  builder.replace('ThinkingTagExtractionMiddleware', customThinkingMiddleware)
}

const middlewares = builder.build()
```

### 4. Chained calls

```typescript
const middlewares = CompletionsMiddlewareBuilder.withDefaults()
  .add(MiddlewareRegistry['CustomMiddleware'])
  .insertBefore('SdkCallMiddleware', MiddlewareRegistry['SecurityCheckMiddleware'])
  .remove('WebSearchMiddleware')
  .build()
```

## API Reference

### CompletionsMiddlewareBuilder

**Static methods:**

- `static withDefaults()`: creates a builder preloaded with the default middleware chain

**Instance methods:**

- `add(middleware: NamedMiddleware)`: appends a middleware to the end of the chain
- `prepend(middleware: NamedMiddleware)`: adds a middleware to the head of the chain
- `insertAfter(targetName: string, middleware: NamedMiddleware)`: inserts after the named middleware
- `insertBefore(targetName: string, middleware: NamedMiddleware)`: inserts before the named middleware
- `replace(targetName: string, middleware: NamedMiddleware)`: replaces the named middleware
- `remove(targetName: string)`: removes the named middleware
- `has(name: string)`: checks whether the chain contains the named middleware
- `build()`: builds the final middleware array
- `getChain()`: returns the current chain (with name information)
- `clear()`: empties the middleware chain
- `execute(context, params, middlewareExecutor)`: executes the built chain directly

### Factory functions

- `createCompletionsBuilder(baseChain?)`: creates a Completions middleware builder
- `createMethodBuilder(baseChain?)`: creates a generic method middleware builder
- `addMiddlewareName(middleware, name)`: helper that attaches a name property to a middleware

### Middleware registry

- `MiddlewareRegistry`: the central access point for all registered middlewares
- `getMiddleware(name)`: looks up a middleware by name
- `getRegisteredMiddlewareNames()`: returns the names of all registered middlewares
- `DefaultCompletionsNamedMiddlewares`: the default Completions middleware chain (in NamedMiddleware format)

## Type Safety

The builder ships with full TypeScript type support:

- `CompletionsMiddlewareBuilder` is dedicated to the `CompletionsMiddleware` type
- `MethodMiddlewareBuilder` serves the generic `MethodMiddleware` type
- All middleware operations are based on the `NamedMiddleware<TMiddleware>` interface

## Default Middleware Chain

The default Completions middleware execution order (see the inspection sketch below):

1. `FinalChunkConsumerMiddleware` - final consumer
2. `TransformCoreToSdkParamsMiddleware` - parameter transformation
3. `AbortHandlerMiddleware` - abort handling
4. `McpToolChunkMiddleware` - tool handling
5. `WebSearchMiddleware` - web search handling
6. `TextChunkMiddleware` - text handling
7. `ThinkingTagExtractionMiddleware` - thinking tag extraction
8. `ThinkChunkMiddleware` - thinking handling
9. `ResponseTransformMiddleware` - response transformation
10. `StreamAdapterMiddleware` - stream adapter
11. `SdkCallMiddleware` - SDK call
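
A quick way to confirm this order at runtime is to inspect the named chain before building it. A minimal sketch, using only the `withDefaults()` and `getChain()` API documented above:

```typescript
import { CompletionsMiddlewareBuilder } from './builder'

// getChain() keeps the name metadata that build() strips away
const chain = CompletionsMiddlewareBuilder.withDefaults().getChain()
chain.forEach((item, i) => console.log(`${i + 1}. ${item.name}`))
```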

## Usage in AiProvider

```typescript
export default class AiProvider {
  public async completions(params: CompletionsParams): Promise<CompletionsResult> {
    // 1. Build the middleware chain
    const builder = CompletionsMiddlewareBuilder.withDefaults()

    // 2. Adjust it dynamically based on the params
    if (params.enableCustomFeature) {
      builder.insertAfter('StreamAdapterMiddleware', customFeatureMiddleware)
    }

    // 3. Apply the middlewares
    const middlewares = builder.build()
    const wrappedMethod = applyCompletionsMiddlewares(this.apiClient, this.apiClient.createCompletions, middlewares)

    return wrappedMethod(params)
  }
}
```

## Notes

1. **Type compatibility**: `MethodMiddleware` and `CompletionsMiddleware` are incompatible; use the matching builder for each
2. **Middleware names**: every middleware must export a `MIDDLEWARE_NAME` constant for identification
3. **Registry management**: new middlewares must be registered in `register.ts`
4. **Default chain**: the default chain is provided via `DefaultCompletionsNamedMiddlewares` and supports lazy loading to avoid circular dependencies (see the sketch below)
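
One common way to get such lazy loading is to keep only middleware names in the default chain and resolve them through the registry when the chain is actually assembled. A hypothetical sketch of the idea, not the registry's actual implementation:

```typescript
// Hypothetical: the default chain stores names only, so this module never has
// to import each middleware file at load time, which breaks import cycles.
const DEFAULT_CHAIN_NAMES = ['FinalChunkConsumerMiddleware', 'TransformCoreToSdkParamsMiddleware' /* ... */]

function resolveDefaultChain(): NamedMiddleware[] {
  // Resolution happens on first call, after all modules have been registered
  return DEFAULT_CHAIN_NAMES.map((name) => MiddlewareRegistry[name])
}
```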

This design makes building middleware chains both flexible and type-safe, while keeping the API surface simple.
@ -1,175 +0,0 @@
# Cherry Studio Middleware Specification

This document defines the design, implementation, and usage conventions for middlewares in the Cherry Studio `aiCore` module. The goal is a flexible, maintainable, and easily extensible middleware system.

## 1. Core Concepts

### 1.1. Middleware

A middleware is a function or object that runs at a specific stage of the AI request pipeline. It can read and modify the request context (`AiProviderMiddlewareContext`) and the request parameters (`Params`), and it controls whether the request is passed on to the next middleware or the pipeline is terminated.

Each middleware should focus on a single cross-cutting concern, such as logging, error handling, stream adaptation, or feature extraction.

### 1.2. `AiProviderMiddlewareContext` (the context object)

This object is passed along the entire middleware chain and carries the following core fields (collected into an interface in the sketch below):

- `_apiClientInstance: ApiClient<any,any,any>`: the currently selected, already instantiated AI provider client.
- `_coreRequest: CoreRequestType`: the normalized internal core request object.
- `resolvePromise: (value: AggregatedResultType) => void`: resolves the Promise returned by `AiCoreService` when the whole operation completes successfully.
- `rejectPromise: (reason?: any) => void`: rejects the Promise returned by `AiCoreService` when an error occurs.
- `onChunk?: (chunk: Chunk) => void`: the streaming chunk callback supplied by the application layer.
- `abortController?: AbortController`: the controller used to abort the request.
- Any other request-scoped dynamic data that middlewares may read or write.
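
Collected into a single type, the context looks roughly like this. A minimal sketch assuming `ApiClient`, `CoreRequestType`, `AggregatedResultType`, and `Chunk` are in scope; the index signature stands in for the open-ended dynamic fields:

```typescript
interface AiProviderMiddlewareContext {
  _apiClientInstance: ApiClient<any, any, any>
  _coreRequest: CoreRequestType
  resolvePromise: (value: AggregatedResultType) => void
  rejectPromise: (reason?: any) => void
  onChunk?: (chunk: Chunk) => void
  abortController?: AbortController
  // Request-scoped dynamic data shared between middlewares
  [key: string]: unknown
}
```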

### 1.3. `MiddlewareName`

To make dynamic operations on middlewares (insert, replace, remove) convenient, every significant middleware that may be referenced by other logic should have a unique, recognizable name. A TypeScript `enum` is recommended:

```typescript
// example
export enum MiddlewareName {
  LOGGING_START = 'LoggingStartMiddleware',
  LOGGING_END = 'LoggingEndMiddleware',
  ERROR_HANDLING = 'ErrorHandlingMiddleware',
  ABORT_HANDLER = 'AbortHandlerMiddleware',
  // Core Flow
  TRANSFORM_CORE_TO_SDK_PARAMS = 'TransformCoreToSdkParamsMiddleware',
  REQUEST_EXECUTION = 'RequestExecutionMiddleware',
  STREAM_ADAPTER = 'StreamAdapterMiddleware',
  RAW_SDK_CHUNK_TO_APP_CHUNK = 'RawSdkChunkToAppChunkMiddleware',
  // Features
  THINKING_TAG_EXTRACTION = 'ThinkingTagExtractionMiddleware',
  TOOL_USE_TAG_EXTRACTION = 'ToolUseTagExtractionMiddleware',
  MCP_TOOL_HANDLER = 'McpToolHandlerMiddleware',
  // Finalization
  FINAL_CHUNK_CONSUMER = 'FinalChunkConsumerAndNotifierMiddleware'
  // Add more as needed
}
```

Middleware instances need some way to expose their `MiddlewareName`, for example through a `name` property.

### 1.4. Middleware execution structure

We use a flexible execution structure. A middleware is usually a function that receives the `Context`, the `Params`, and a `next` function (which invokes the next middleware in the chain).

```typescript
// Simplified middleware function signature
type MiddlewareFunction = (
  context: AiProviderMiddlewareContext,
  params: any, // e.g., CompletionsParams
  next: () => Promise<void> // next usually returns a Promise to support async work
) => Promise<void> // the middleware itself may also return a Promise

// Or the more classic Koa/Express style (three-stage)
// type MiddlewareFactory = (api?: MiddlewareApi) =>
//   (nextMiddleware: (ctx: AiProviderMiddlewareContext, params: any) => Promise<void>) =>
//     (context: AiProviderMiddlewareContext, params: any) => Promise<void>;
// The current design leans toward the simplified MiddlewareFunction above, with the MiddlewareExecutor orchestrating next.
```

The `MiddlewareExecutor` (or `applyMiddlewares`) is responsible for managing the `next` calls.

## 2. `MiddlewareBuilder` (generic middleware builder)

To build and manage middleware chains dynamically, we introduce a generic `MiddlewareBuilder` class.

### 2.1. Design philosophy

`MiddlewareBuilder` offers a fluent API for building middleware chains declaratively. It allows starting from a base chain and then adding, inserting, replacing, or removing middlewares according to specific conditions.

### 2.2. API overview

```typescript
class MiddlewareBuilder {
  constructor(baseChain?: Middleware[])

  add(middleware: Middleware): this
  prepend(middleware: Middleware): this
  insertAfter(targetName: MiddlewareName, middlewareToInsert: Middleware): this
  insertBefore(targetName: MiddlewareName, middlewareToInsert: Middleware): this
  replace(targetName: MiddlewareName, newMiddleware: Middleware): this
  remove(targetName: MiddlewareName): this

  build(): Middleware[] // returns the built middleware array

  // Optional: execute the chain directly
  execute(
    context: AiProviderMiddlewareContext,
    params: any,
    middlewareExecutor: (chain: Middleware[], context: AiProviderMiddlewareContext, params: any) => void
  ): void
}
```

### 2.3. Usage example

```typescript
// 1. Define some middleware instances (assumed to have a .name property)
const loggingStart = { name: MiddlewareName.LOGGING_START, fn: loggingStartFn }
const requestExec = { name: MiddlewareName.REQUEST_EXECUTION, fn: requestExecFn }
const streamAdapter = { name: MiddlewareName.STREAM_ADAPTER, fn: streamAdapterFn }
const customFeature = { name: MiddlewareName.CUSTOM_FEATURE, fn: customFeatureFn } // assumed custom

// 2. Define a base chain (optional)
const BASE_CHAIN: Middleware[] = [loggingStart, requestExec, streamAdapter]

// 3. Use the MiddlewareBuilder
const builder = new MiddlewareBuilder(BASE_CHAIN)

if (params.needsCustomFeature) {
  builder.insertAfter(MiddlewareName.STREAM_ADAPTER, customFeature)
}

if (params.isHighSecurityContext) {
  builder.insertBefore(MiddlewareName.REQUEST_EXECUTION, highSecurityCheckMiddleware)
}

if (params.overrideLogging) {
  builder.replace(MiddlewareName.LOGGING_START, newSpecialLoggingMiddleware)
}

// 4. Get the final chain
const finalChain = builder.build()

// 5. Execute (through an external executor)
// middlewareExecutor(finalChain, context, params);
// or builder.execute(context, params, middlewareExecutor);
```

## 3. `MiddlewareExecutor` / `applyMiddlewares` (middleware executor)

This is the component that receives the middleware chain built by `MiddlewareBuilder` and actually executes it.

### 3.1. Responsibilities

- Receives the `Middleware[]`, `AiProviderMiddlewareContext`, and `Params`.
- Iterates over the middlewares in order.
- Provides each middleware with the correct `next` function, which, when called, executes the next middleware in the chain.
- Handles Promises produced during middleware execution (when middlewares are async).
- Performs basic error catching (detailed error handling belongs to the `ErrorHandlingMiddleware` inside the chain); a minimal executor sketch follows this list.
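
A minimal executor sketch, assuming the simplified `MiddlewareFunction` signature from section 1.4; this illustrates the dispatch pattern rather than the project's actual implementation:

```typescript
async function applyMiddlewares(
  chain: MiddlewareFunction[],
  context: AiProviderMiddlewareContext,
  params: any
): Promise<void> {
  // dispatch(i) runs chain[i], handing it a next() that runs chain[i + 1]
  const dispatch = async (index: number): Promise<void> => {
    if (index >= chain.length) return // end of the chain
    await chain[index](context, params, () => dispatch(index + 1))
  }
  try {
    await dispatch(0)
  } catch (error) {
    // Basic safety net; detailed handling belongs to ErrorHandlingMiddleware
    context.rejectPromise(error)
  }
}
```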

## 4. Usage in `AiCoreService`

Each core business method in `AiCoreService` (such as `executeCompletions`) is responsible for (see the skeleton sketch after this list):

1. Preparing the base data: instantiating the `ApiClient` and converting `Params` into a `CoreRequest`.
2. Instantiating a `MiddlewareBuilder`, possibly with a base middleware chain specific to that method.
3. Calling `MiddlewareBuilder` methods to adjust the chain dynamically, based on conditions in `Params` and the `CoreRequest`.
4. Calling `MiddlewareBuilder.build()` to obtain the final middleware chain.
5. Creating the full `AiProviderMiddlewareContext` (including `resolvePromise`, `rejectPromise`, etc.).
6. Invoking the `MiddlewareExecutor` (or `applyMiddlewares`) to run the built chain.
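
Put together, a skeleton of such a method might look like the following sketch; the helpers `createApiClient`, `toCoreRequest`, and `BASE_COMPLETIONS_CHAIN` are placeholders for illustration:

```typescript
class AiCoreService {
  async executeCompletions(params: CompletionsParams): Promise<AggregatedResultType> {
    return new Promise((resolve, reject) => {
      // 1-2. Prepare the base data and the builder (helper names are hypothetical)
      const apiClient = createApiClient(params)
      const coreRequest = toCoreRequest(params)
      const builder = new MiddlewareBuilder(BASE_COMPLETIONS_CHAIN)

      // 3-4. Adjust the chain conditionally, then build it
      if (!params.mcpTools?.length) builder.remove(MiddlewareName.MCP_TOOL_HANDLER)
      const chain = builder.build()

      // 5. Assemble the full context, wiring in the promise callbacks
      const context: AiProviderMiddlewareContext = {
        _apiClientInstance: apiClient,
        _coreRequest: coreRequest,
        resolvePromise: resolve,
        rejectPromise: reject,
        onChunk: params.onChunk,
        abortController: new AbortController()
      }

      // 6. Run the built chain
      applyMiddlewares(chain, context, params)
    })
  }
}
```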

## 5. Composite features

For composite features (for example, "Completions then Translate"):

- Building one single, monolithic `MiddlewareBuilder` for the whole composite flow is not recommended.
- Instead, add a new method on `AiCoreService` that `await`s the underlying atomic `AiCoreService` methods in sequence (e.g., first `await this.executeCompletions(...)`, then feed its result into `await this.translateText(...)`), as in the sketch below.
- Each atomic method internally uses its own `MiddlewareBuilder` instance to build and execute the middleware chain for its own stage.
- This maximizes reuse and keeps each part's responsibilities clear.
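
A minimal sketch of that sequential composition, assuming `executeCompletions` returns a result exposing `getText()` and that a `translateText` atomic method exists:

```typescript
class AiCoreService {
  // Composite method: each awaited atomic call builds and runs its own middleware chain
  async completeAndTranslate(params: CompletionsParams, targetLanguage: string): Promise<string> {
    const completion = await this.executeCompletions(params)
    return this.translateText(completion.getText(), targetLanguage)
  }
}
```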

## 6. Middleware naming and discovery

Giving each middleware a unique `MiddlewareName` is essential for `MiddlewareBuilder` operations such as `insertAfter`, `insertBefore`, `replace`, and `remove`. Make sure middleware instances expose their name in some way (for example, via a `name` property).
@ -1,79 +0,0 @@
import { ChunkType } from '@renderer/types/chunk'
import { describe, expect, it } from 'vitest'

import { capitalize, createErrorChunk, isAsyncIterable } from '../utils'

describe('utils', () => {
  describe('createErrorChunk', () => {
    it('should handle Error instances', () => {
      const error = new Error('Test error message')
      const result = createErrorChunk(error)

      expect(result.type).toBe(ChunkType.ERROR)
      expect(result.error.message).toBe('Test error message')
      expect(result.error.name).toBe('Error')
      expect(result.error.stack).toBeDefined()
    })

    it('should handle string errors', () => {
      const result = createErrorChunk('Something went wrong')
      expect(result.error).toEqual({ message: 'Something went wrong' })
    })

    it('should handle plain objects', () => {
      const error = { code: 'NETWORK_ERROR', status: 500 }
      const result = createErrorChunk(error)
      expect(result.error).toEqual(error)
    })

    it('should handle null and undefined', () => {
      expect(createErrorChunk(null).error).toEqual({})
      expect(createErrorChunk(undefined).error).toEqual({})
    })

    it('should use custom chunk type when provided', () => {
      const result = createErrorChunk('error', ChunkType.BLOCK_COMPLETE)
      expect(result.type).toBe(ChunkType.BLOCK_COMPLETE)
    })

    it('should use toString for objects without message', () => {
      const error = {
        toString: () => 'Custom error'
      }
      const result = createErrorChunk(error)
      expect(result.error.message).toBe('Custom error')
    })
  })

  describe('capitalize', () => {
    it('should capitalize first letter', () => {
      expect(capitalize('hello')).toBe('Hello')
      expect(capitalize('a')).toBe('A')
    })

    it('should handle edge cases', () => {
      expect(capitalize('')).toBe('')
      expect(capitalize('123')).toBe('123')
      expect(capitalize('Hello')).toBe('Hello')
    })
  })

  describe('isAsyncIterable', () => {
    it('should identify async iterables', () => {
      async function* gen() {
        yield 1
      }
      expect(isAsyncIterable(gen())).toBe(true)
      expect(isAsyncIterable({ [Symbol.asyncIterator]: () => {} })).toBe(true)
    })

    it('should reject non-async iterables', () => {
      expect(isAsyncIterable([1, 2, 3])).toBe(false)
      expect(isAsyncIterable(new Set())).toBe(false)
      expect(isAsyncIterable({})).toBe(false)
      expect(isAsyncIterable(null)).toBe(false)
      expect(isAsyncIterable(123)).toBe(false)
      expect(isAsyncIterable('string')).toBe(false)
    })
  })
})
@ -1,245 +0,0 @@
import { loggerService } from '@logger'

import { DefaultCompletionsNamedMiddlewares } from './register'
import type { BaseContext, CompletionsMiddleware, MethodMiddleware } from './types'

const logger = loggerService.withContext('aiCore:MiddlewareBuilder')

/**
 * Middleware interface carrying a name identifier
 */
export interface NamedMiddleware<TMiddleware = any> {
  name: string
  middleware: TMiddleware
}

/**
 * Middleware executor function type
 */
export type MiddlewareExecutor<TContext extends BaseContext = BaseContext> = (
  chain: any[],
  context: TContext,
  params: any
) => Promise<any>

/**
 * Generic middleware builder class
 * Provides a fluent API for dynamically building and managing middleware chains
 *
 * Note: all middlewares are managed through the MiddlewareRegistry in NamedMiddleware format
 */
export class MiddlewareBuilder<TMiddleware = any> {
  private middlewares: NamedMiddleware<TMiddleware>[]

  /**
   * Constructor
   * @param baseChain - optional base middleware chain (NamedMiddleware format)
   */
  constructor(baseChain?: NamedMiddleware<TMiddleware>[]) {
    this.middlewares = baseChain ? [...baseChain] : []
  }

  /**
   * Appends a middleware to the end of the chain
   * @param middleware - the named middleware to add
   * @returns this, for chaining
   */
  add(middleware: NamedMiddleware<TMiddleware>): this {
    this.middlewares.push(middleware)
    return this
  }

  /**
   * Adds a middleware to the head of the chain
   * @param middleware - the named middleware to add
   * @returns this, for chaining
   */
  prepend(middleware: NamedMiddleware<TMiddleware>): this {
    this.middlewares.unshift(middleware)
    return this
  }

  /**
   * Inserts a new middleware after the given middleware
   * @param targetName - the target middleware name
   * @param middlewareToInsert - the named middleware to insert
   * @returns this, for chaining
   */
  insertAfter(targetName: string, middlewareToInsert: NamedMiddleware<TMiddleware>): this {
    const index = this.findMiddlewareIndex(targetName)
    if (index !== -1) {
      this.middlewares.splice(index + 1, 0, middlewareToInsert)
    } else {
      logger.warn(`No middleware named '${targetName}' found; cannot insert`)
    }
    return this
  }

  /**
   * Inserts a new middleware before the given middleware
   * @param targetName - the target middleware name
   * @param middlewareToInsert - the named middleware to insert
   * @returns this, for chaining
   */
  insertBefore(targetName: string, middlewareToInsert: NamedMiddleware<TMiddleware>): this {
    const index = this.findMiddlewareIndex(targetName)
    if (index !== -1) {
      this.middlewares.splice(index, 0, middlewareToInsert)
    } else {
      logger.warn(`No middleware named '${targetName}' found; cannot insert`)
    }
    return this
  }

  /**
   * Replaces the given middleware
   * @param targetName - the name of the middleware to replace
   * @param newMiddleware - the new named middleware
   * @returns this, for chaining
   */
  replace(targetName: string, newMiddleware: NamedMiddleware<TMiddleware>): this {
    const index = this.findMiddlewareIndex(targetName)
    if (index !== -1) {
      this.middlewares[index] = newMiddleware
    } else {
      logger.warn(`No middleware named '${targetName}' found; cannot replace`)
    }
    return this
  }

  /**
   * Removes the given middleware
   * @param targetName - the name of the middleware to remove
   * @returns this, for chaining
   */
  remove(targetName: string): this {
    const index = this.findMiddlewareIndex(targetName)
    if (index !== -1) {
      this.middlewares.splice(index, 1)
    }
    return this
  }

  /**
   * Builds the final middleware array
   * @returns the built middleware array
   */
  build(): TMiddleware[] {
    return this.middlewares.map((item) => item.middleware)
  }

  /**
   * Returns a copy of the current middleware chain (with name information)
   * @returns a copy of the current chain
   */
  getChain(): NamedMiddleware<TMiddleware>[] {
    return [...this.middlewares]
  }

  /**
   * Checks whether the chain contains a middleware with the given name
   * @param name - the middleware name
   * @returns whether the middleware is present
   */
  has(name: string): boolean {
    return this.findMiddlewareIndex(name) !== -1
  }

  /**
   * Returns the length of the middleware chain
   * @returns the number of middlewares
   */
  get length(): number {
    return this.middlewares.length
  }

  /**
   * Empties the middleware chain
   * @returns this, for chaining
   */
  clear(): this {
    this.middlewares = []
    return this
  }

  /**
   * Executes the built middleware chain directly
   * @param context - the middleware context
   * @param params - the parameters
   * @param middlewareExecutor - the middleware executor
   * @returns the execution result
   */
  execute<TContext extends BaseContext>(
    context: TContext,
    params: any,
    middlewareExecutor: MiddlewareExecutor<TContext>
  ): Promise<any> {
    const chain = this.build()
    return middlewareExecutor(chain, context, params)
  }

  /**
   * Finds a middleware's index in the chain
   * @param name - the middleware name
   * @returns the index, or -1 if not found
   */
  private findMiddlewareIndex(name: string): number {
    return this.middlewares.findIndex((item) => item.name === name)
  }
}

/**
 * Completions middleware builder
 */
export class CompletionsMiddlewareBuilder extends MiddlewareBuilder<CompletionsMiddleware> {
  constructor(baseChain?: NamedMiddleware<CompletionsMiddleware>[]) {
    super(baseChain)
  }

  /**
   * Uses the default Completions middleware chain
   * @returns a CompletionsMiddlewareBuilder instance
   */
  static withDefaults(): CompletionsMiddlewareBuilder {
    return new CompletionsMiddlewareBuilder(DefaultCompletionsNamedMiddlewares)
  }
}

/**
 * Generic method middleware builder
 */
export class MethodMiddlewareBuilder extends MiddlewareBuilder<MethodMiddleware> {
  constructor(baseChain?: NamedMiddleware<MethodMiddleware>[]) {
    super(baseChain)
  }
}

// Convenience factory functions

/**
 * Creates a Completions middleware builder
 * @param baseChain - optional base chain
 * @returns a Completions middleware builder instance
 */
export function createCompletionsBuilder(
  baseChain?: NamedMiddleware<CompletionsMiddleware>[]
): CompletionsMiddlewareBuilder {
  return new CompletionsMiddlewareBuilder(baseChain)
}

/**
 * Creates a generic method middleware builder
 * @param baseChain - optional base chain
 * @returns a generic method middleware builder instance
 */
export function createMethodBuilder(baseChain?: NamedMiddleware<MethodMiddleware>[]): MethodMiddlewareBuilder {
  return new MethodMiddlewareBuilder(baseChain)
}

/**
 * Helper that attaches a name property to a middleware
 * Can be used to add names to existing middlewares
 */
export function addMiddlewareName<T extends object>(middleware: T, name: string): T & { MIDDLEWARE_NAME: string } {
  return Object.assign(middleware, { MIDDLEWARE_NAME: name })
}
@ -1,121 +0,0 @@
import { loggerService } from '@logger'
import type { Chunk, ErrorChunk } from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'

import type { CompletionsParams, CompletionsResult } from '../schemas'
import type { CompletionsContext, CompletionsMiddleware } from '../types'

const logger = loggerService.withContext('aiCore:AbortHandlerMiddleware')

export const MIDDLEWARE_NAME = 'AbortHandlerMiddleware'

export const AbortHandlerMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
    const isRecursiveCall = ctx._internal?.toolProcessingState?.isRecursiveCall || false

    // In recursive calls, skip creating an AbortController and reuse the existing one
    if (isRecursiveCall) {
      const result = await next(ctx, params)
      return result
    }

    const abortController = new AbortController()
    const abortFn = (): void => abortController.abort()
    let abortSignal: AbortSignal | null = abortController.signal
    let abortKey: string

    // Prefer the abortKey passed in the params, if any
    if (params.abortKey) {
      abortKey = params.abortKey
    } else {
      // Get the current message ID for abort management
      // Prefer processed messages; fall back to the original messages
      let messageId: string | undefined

      if (typeof params.messages === 'string') {
        messageId = `message-${Date.now()}-${Math.random().toString(36).substring(2, 9)}`
      } else {
        const processedMessages = params.messages
        const lastUserMessage = processedMessages.findLast((m) => m.role === 'user')
        messageId = lastUserMessage?.id
      }

      if (!messageId) {
        logger.warn(`No messageId found, abort functionality will not be available.`)
        return next(ctx, params)
      }

      abortKey = messageId
    }

    addAbortController(abortKey, abortFn)
    const cleanup = (): void => {
      removeAbortController(abortKey, abortFn)
      if (ctx._internal?.flowControl) {
        ctx._internal.flowControl.abortController = undefined
        ctx._internal.flowControl.abortSignal = undefined
        ctx._internal.flowControl.cleanup = undefined
      }
      abortSignal = null
    }

    // Attach the controller to the flowControl state in _internal
    if (!ctx._internal.flowControl) {
      ctx._internal.flowControl = {}
    }
    ctx._internal.flowControl.abortController = abortController
    ctx._internal.flowControl.abortSignal = abortSignal
    ctx._internal.flowControl.cleanup = cleanup

    const result = await next(ctx, params)

    const error = new DOMException('Request was aborted', 'AbortError')

    const streamWithAbortHandler = (result.stream as ReadableStream<Chunk>).pipeThrough(
      new TransformStream<Chunk, Chunk | ErrorChunk>({
        transform(chunk, controller) {
          // Once an error chunk has arrived, stop checking the abort state
          if (chunk.type === ChunkType.ERROR) {
            controller.enqueue(chunk)
            return
          }

          if (abortSignal?.aborted) {
            // Convert to an ErrorChunk
            const errorChunk: ErrorChunk = {
              type: ChunkType.ERROR,
              error
            }

            controller.enqueue(errorChunk)
            cleanup()
            return
          }

          // Pass the chunk through as usual
          controller.enqueue(chunk)
        },

        flush(controller) {
          // Check the abort state once more when the stream ends
          if (abortSignal?.aborted) {
            const errorChunk: ErrorChunk = {
              type: ChunkType.ERROR,
              error
            }
            controller.enqueue(errorChunk)
          }
          // Clean up the AbortController after the stream is fully processed
          cleanup()
        }
      })
    )

    return {
      ...result,
      stream: streamWithAbortHandler
    }
  }
@ -1,134 +0,0 @@
import { loggerService } from '@logger'
import { isZhipuModel } from '@renderer/config/models'
import { getStoreProviders } from '@renderer/hooks/useStore'
import { getDefaultModel } from '@renderer/services/AssistantService'
import type { Chunk } from '@renderer/types/chunk'

import type { CompletionsParams, CompletionsResult } from '../schemas'
import type { CompletionsContext } from '../types'
import { createErrorChunk } from '../utils'

const logger = loggerService.withContext('ErrorHandlerMiddleware')

export const MIDDLEWARE_NAME = 'ErrorHandlerMiddleware'

/**
 * Creates an error-handling middleware.
 *
 * This is a higher-order function that takes a configuration and returns a standard middleware.
 * Its main responsibility is to catch any error raised by downstream middlewares or API calls.
 *
 * @param config - the middleware configuration.
 * @returns a configured CompletionsMiddleware.
 */
export const ErrorHandlerMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params): Promise<CompletionsResult> => {
    const { shouldThrow } = params

    try {
      // Try to run the next middleware
      return await next(ctx, params)
    } catch (error: any) {
      logger.error(error)

      let processedError = error
      processedError = handleError(error, params)

      // 1. Parse the error into a standard format with the shared utility
      const errorChunk = createErrorChunk(processedError)

      // 2. Invoke the externally supplied onError callback
      if (params.onError) {
        params.onError(processedError)
      }

      // 3. Depending on the configuration, either rethrow the error or pass it downstream as part of the stream
      if (shouldThrow) {
        throw processedError
      }

      // If not rethrowing, create a stream that contains only the error chunk and pass it downstream
      const errorStream = new ReadableStream<Chunk>({
        start(controller) {
          controller.enqueue(errorChunk)
          controller.close()
        }
      })

      return {
        rawOutput: undefined,
        stream: errorStream, // pass the error-carrying stream downstream
        controller: undefined,
        getText: () => '' // no text result in the error case
      }
    }
  }

function handleError(error: any, params: CompletionsParams): any {
  if (isZhipuModel(params.assistant.model || getDefaultModel()) && error.status && !params.enableGenerateImage) {
    return handleZhipuError(error)
  }

  if (error.status === 401 || error.message.includes('401')) {
    return {
      ...error,
      i18nKey: 'chat.no_api_key',
      providerId: params.assistant?.model?.provider
    }
  }

  return error
}

/**
 * Handles Zhipu-specific errors
 * 1. Only the chat feature (enableGenerateImage is false) uses this custom error handling
 * 2. The image-generation feature (enableGenerateImage is true) uses the generic error handling
 */
function handleZhipuError(error: any): any {
  const provider = getStoreProviders().find((p) => p.id === 'zhipu')
  const logger = loggerService.withContext('handleZhipuError')

  // Error pattern mapping
  const errorPatterns = [
    {
      condition: () => error.status === 401 || /令牌已过期|AuthenticationError|Unauthorized/i.test(error.message),
      i18nKey: 'chat.no_api_key',
      providerId: provider?.id
    },
    {
      condition: () => error.error?.code === '1304' || /限额|免费配额|free quota|rate limit/i.test(error.message),
      i18nKey: 'chat.quota_exceeded',
      providerId: provider?.id
    },
    {
      condition: () =>
        (error.status === 429 && error.error?.code === '1113') || /余额不足|insufficient balance/i.test(error.message),
      i18nKey: 'chat.insufficient_balance',
      providerId: provider?.id
    },
    {
      condition: () => !provider?.apiKey?.trim(),
      i18nKey: 'chat.no_api_key',
      providerId: provider?.id
    }
  ]

  // Walk the error patterns and return the first match
  for (const pattern of errorPatterns) {
    if (pattern.condition()) {
      return {
        ...error,
        providerId: pattern.providerId,
        i18nKey: pattern.i18nKey
      }
    }
  }

  // Not a Zhipu-specific error; return the original error
  logger.debug('🔧 Not a Zhipu-specific error, returning the original error')

  return error
}
@ -1,195 +0,0 @@
import { loggerService } from '@logger'
import type { Usage } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'

import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import type { CompletionsContext, CompletionsMiddleware } from '../types'

export const MIDDLEWARE_NAME = 'FinalChunkConsumerAndNotifierMiddleware'

const logger = loggerService.withContext('FinalChunkConsumerMiddleware')

/**
 * Final chunk consumer and notifier middleware
 *
 * Responsibilities:
 * 1. Consume every chunk from the GenericChunk stream and forward it to the onChunk callback
 * 2. Accumulate usage/metrics data (extracted from raw SDK chunks or GenericChunks)
 * 3. Emit a BLOCK_COMPLETE carrying the accumulated data when LLM_RESPONSE_COMPLETE is detected
 * 4. Handle data accumulation across the multiple rounds of MCP tool-call requests
 */
const FinalChunkConsumerMiddleware: CompletionsMiddleware =
  () =>
  (next) =>
  async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
    const isRecursiveCall =
      params._internal?.toolProcessingState?.isRecursiveCall ||
      ctx._internal?.toolProcessingState?.isRecursiveCall ||
      false

    // Initialize the accumulators (only on top-level calls)
    if (!isRecursiveCall) {
      if (!ctx._internal.customState) {
        ctx._internal.customState = {}
      }
      ctx._internal.observer = {
        usage: {
          prompt_tokens: 0,
          completion_tokens: 0,
          total_tokens: 0
        },
        metrics: {
          completion_tokens: 0,
          time_completion_millsec: 0,
          time_first_token_millsec: 0,
          time_thinking_millsec: 0
        }
      }
      // Initialize the text accumulator
      ctx._internal.customState.accumulatedText = ''
      ctx._internal.customState.startTimestamp = Date.now()
    }

    // Call the downstream middlewares
    const result = await next(ctx, params)

    // Post-processing: handle the GenericChunk streaming response
    if (result.stream) {
      const resultFromUpstream = result.stream

      if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
        const reader = resultFromUpstream.getReader()

        try {
          while (true) {
            const { done, value: chunk } = await reader.read()
            logger.silly('chunk', chunk)
            if (done) {
              logger.debug(`Input stream finished.`)
              break
            }

            if (chunk) {
              const genericChunk = chunk as GenericChunk
              // Extract and accumulate usage/metrics data
              extractAndAccumulateUsageMetrics(ctx, genericChunk)

              const shouldSkipChunk =
                isRecursiveCall &&
                (genericChunk.type === ChunkType.BLOCK_COMPLETE ||
                  genericChunk.type === ChunkType.LLM_RESPONSE_COMPLETE)

              if (!shouldSkipChunk) params.onChunk?.(genericChunk)
            } else {
              logger.warn(`Received undefined chunk before stream was done.`)
            }
          }
        } catch (error: any) {
          logger.error(`Error consuming stream:`, error as Error)
          // FIXME: temporary workaround; exceptions from this middleware cannot be caught by ErrorHandlerMiddleware.
          if (params.onError) {
            params.onError(error)
          }
          if (params.shouldThrow) {
            throw error
          }
        } finally {
          if (params.onChunk && !isRecursiveCall) {
            params.onChunk({
              type: ChunkType.BLOCK_COMPLETE,
              response: {
                usage: ctx._internal.observer?.usage ? { ...ctx._internal.observer.usage } : undefined,
                metrics: ctx._internal.observer?.metrics ? { ...ctx._internal.observer.metrics } : undefined
              }
            } as Chunk)
            if (ctx._internal.toolProcessingState) {
              ctx._internal.toolProcessingState = {}
            }
          }
        }

        // Add a getText method for streaming output
        const modifiedResult = {
          ...result,
          stream: new ReadableStream<GenericChunk>({
            start(controller) {
              controller.close()
            }
          }),
          getText: () => {
            return ctx._internal.customState?.accumulatedText || ''
          }
        }

        return modifiedResult
      } else {
        logger.debug(`No GenericChunk stream to process.`)
      }
    }

    return result
  }

/**
 * Extracts usage/metrics data from GenericChunks or raw SDK chunks and accumulates it
 */
function extractAndAccumulateUsageMetrics(ctx: CompletionsContext, chunk: GenericChunk): void {
  if (!ctx._internal.observer?.usage || !ctx._internal.observer?.metrics) {
    return
  }

  try {
    if (ctx._internal.customState && !ctx._internal.customState?.firstTokenTimestamp) {
      ctx._internal.customState.firstTokenTimestamp = Date.now()
      logger.debug(`First token timestamp: ${ctx._internal.customState.firstTokenTimestamp}`)
    }
    if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
      // Extract usage data from the LLM_RESPONSE_COMPLETE chunk
      if (chunk.response?.usage) {
        accumulateUsage(ctx._internal.observer.usage, chunk.response.usage)
      }

      if (ctx._internal.customState && ctx._internal.customState?.firstTokenTimestamp) {
        ctx._internal.observer.metrics.time_first_token_millsec =
          ctx._internal.customState.firstTokenTimestamp - ctx._internal.customState.startTimestamp
        ctx._internal.observer.metrics.time_completion_millsec +=
          Date.now() - ctx._internal.customState.firstTokenTimestamp
      }
    }

    // Metrics data can also be extracted from other chunk types
    if (chunk.type === ChunkType.THINKING_COMPLETE && chunk.thinking_millsec && ctx._internal.observer?.metrics) {
      ctx._internal.observer.metrics.time_thinking_millsec = Math.max(
        ctx._internal.observer.metrics.time_thinking_millsec || 0,
        chunk.thinking_millsec
      )
    }
  } catch (error) {
    logger.error('Error extracting usage/metrics from chunk:', error as Error)
  }
}

/**
 * Accumulates usage data
 */
function accumulateUsage(accumulated: Usage, newUsage: Usage): void {
  if (newUsage.prompt_tokens !== undefined) {
    accumulated.prompt_tokens += newUsage.prompt_tokens
  }
  if (newUsage.completion_tokens !== undefined) {
    accumulated.completion_tokens += newUsage.completion_tokens
  }
  if (newUsage.total_tokens !== undefined) {
    accumulated.total_tokens += newUsage.total_tokens
  }
  if (newUsage.thoughts_tokens !== undefined) {
    accumulated.thoughts_tokens = (accumulated.thoughts_tokens || 0) + newUsage.thoughts_tokens
  }
  // Handle OpenRouter specific cost fields
  if (newUsage.cost !== undefined) {
    accumulated.cost = (accumulated.cost || 0) + newUsage.cost
  }
}

export default FinalChunkConsumerMiddleware
@ -1,68 +0,0 @@
import { loggerService } from '@logger'

import type { BaseContext, MethodMiddleware, MiddlewareAPI } from '../types'

const logger = loggerService.withContext('LoggingMiddleware')

export const MIDDLEWARE_NAME = 'GenericLoggingMiddlewares'

/**
 * Helper function to safely stringify arguments for logging, handling circular references and large objects.
 * @param args - The arguments array to stringify.
 * @returns A string representation of the arguments.
 */
const stringifyArgsForLogging = (args: any[]): string => {
  try {
    return args
      .map((arg) => {
        if (typeof arg === 'function') return '[Function]'
        if (typeof arg === 'object' && arg !== null && arg.constructor === Object && Object.keys(arg).length > 20) {
          return '[Object with >20 keys]'
        }
        // Truncate long strings to avoid flooding logs
        const stringifiedArg = JSON.stringify(arg, null, 2)
        return stringifiedArg && stringifiedArg.length > 200 ? stringifiedArg.substring(0, 200) + '...' : stringifiedArg
      })
      .join(', ')
  } catch (e) {
    return '[Error serializing arguments]' // Handle potential errors during stringification
  }
}

/**
 * Creates a generic logging middleware for provider methods.
 * This middleware logs the initiation, success/failure, and duration of a method call.
 * @returns A `MethodMiddleware` instance.
 */
export const createGenericLoggingMiddleware: () => MethodMiddleware = () => {
  const middlewareName = 'GenericLoggingMiddleware'
  // oxlint-disable-next-line @typescript-eslint/no-unused-vars
  return (_: MiddlewareAPI<BaseContext, any[]>) => (next) => async (ctx, args) => {
    const methodName = ctx.methodName
    const logPrefix = `[${middlewareName} (${methodName})]`
    logger.debug(`${logPrefix} Initiating. Args: ${stringifyArgsForLogging(args)}`)
    const startTime = Date.now()
    try {
      const result = await next(ctx, args)
      const duration = Date.now() - startTime
      // Log successful completion of the method call with duration.
      logger.debug(`${logPrefix} Successful. Duration: ${duration}ms`)
      return result
    } catch (error) {
      const duration = Date.now() - startTime
      // Log failure of the method call with duration and error information.
      logger.error(`${logPrefix} Failed. Duration: ${duration}ms`, error as Error)
      throw error // Re-throw the error to be handled by subsequent layers or the caller
    }
  }
}
@ -1,289 +0,0 @@
|
||||
import { withSpanResult } from '@renderer/services/SpanManagerService'
|
||||
import type {
|
||||
RequestOptions,
|
||||
SdkInstance,
|
||||
SdkMessageParam,
|
||||
SdkParams,
|
||||
SdkRawChunk,
|
||||
SdkRawOutput,
|
||||
SdkTool,
|
||||
SdkToolCall
|
||||
} from '@renderer/types/sdk'
|
||||
|
||||
import type { BaseApiClient } from '../clients'
|
||||
import type { CompletionsParams, CompletionsResult } from './schemas'
|
||||
import type { BaseContext, CompletionsContext, CompletionsMiddleware, MethodMiddleware, MiddlewareAPI } from './types'
|
||||
import { MIDDLEWARE_CONTEXT_SYMBOL } from './types'
|
||||
|
||||
/**
|
||||
* Creates the initial context for a method call, populating method-specific fields. /
|
||||
* 为方法调用创建初始上下文,并填充特定于该方法的字段。
|
||||
* @param methodName - The name of the method being called. / 被调用的方法名。
|
||||
* @param originalCallArgs - The actual arguments array from the proxy/method call. / 代理/方法调用的实际参数数组。
|
||||
* @param providerId - The ID of the provider, if available. / 提供者的ID(如果可用)。
|
||||
* @param providerInstance - The instance of the provider. / 提供者实例。
|
||||
* @param specificContextFactory - An optional factory function to create a specific context type from the base context and original call arguments. / 一个可选的工厂函数,用于从基础上下文和原始调用参数创建特定的上下文类型。
|
||||
* @returns The created context object. / 创建的上下文对象。
|
||||
*/
|
||||
function createInitialCallContext<TContext extends BaseContext, TCallArgs extends unknown[]>(
|
||||
methodName: string,
|
||||
originalCallArgs: TCallArgs, // Renamed from originalArgs to avoid confusion with context.originalArgs
|
||||
// Factory to create specific context from base and the *original call arguments array*
|
||||
specificContextFactory?: (base: BaseContext, callArgs: TCallArgs) => TContext
|
||||
): TContext {
|
||||
const baseContext: BaseContext = {
|
||||
[MIDDLEWARE_CONTEXT_SYMBOL]: true,
|
||||
methodName,
|
||||
originalArgs: originalCallArgs // Store the full original arguments array in the context
|
||||
}
|
||||
|
||||
if (specificContextFactory) {
|
||||
return specificContextFactory(baseContext, originalCallArgs)
|
||||
}
|
||||
return baseContext as TContext // Fallback to base context if no specific factory
|
||||
}
|
||||
|
||||
/**
|
||||
* Composes an array of functions from right to left. /
|
||||
* 从右到左组合一个函数数组。
|
||||
* `compose(f, g, h)` is `(...args) => f(g(h(...args)))`. /
|
||||
* `compose(f, g, h)` 等同于 `(...args) => f(g(h(...args)))`。
|
||||
* Each function in funcs is expected to take the result of the next function
|
||||
* (or the initial value for the rightmost function) as its argument. /
|
||||
* `funcs` 中的每个函数都期望接收下一个函数的结果(或最右侧函数的初始值)作为其参数。
|
||||
* @param funcs - Array of functions to compose. / 要组合的函数数组。
|
||||
* @returns The composed function. / 组合后的函数。
|
||||
*/
|
||||
function compose(...funcs: Array<(...args: any[]) => any>): (...args: any[]) => any {
|
||||
if (funcs.length === 0) {
|
||||
// If no functions to compose, return a function that returns its first argument, or undefined if no args. /
|
||||
// 如果没有要组合的函数,则返回一个函数,该函数返回其第一个参数,如果没有参数则返回undefined。
|
||||
return (...args: any[]) => (args.length > 0 ? args[0] : undefined)
|
||||
}
|
||||
if (funcs.length === 1) {
|
||||
return funcs[0]
|
||||
}
|
||||
return funcs.reduce(
|
||||
(a, b) =>
|
||||
(...args: any[]) =>
|
||||
a(b(...args))
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Applies an array of Redux-style middlewares to a generic provider method. /
|
||||
* 将一组Redux风格的中间件应用于一个通用的提供者方法。
|
||||
* This version keeps arguments as an array throughout the middleware chain. /
|
||||
* 此版本在整个中间件链中将参数保持为数组形式。
|
||||
* @param originalProviderInstance - The original provider instance. / 原始提供者实例。
|
||||
* @param methodName - The name of the method to be enhanced. / 需要增强的方法名。
|
||||
* @param originalMethod - The original method to be wrapped. / 需要包装的原始方法。
|
||||
* @param middlewares - An array of `ProviderMethodMiddleware` to apply. / 要应用的 `ProviderMethodMiddleware` 数组。
|
||||
* @param specificContextFactory - An optional factory to create a specific context for this method. / 可选的工厂函数,用于为此方法创建特定的上下文。
|
||||
* @returns An enhanced method with the middlewares applied. / 应用了中间件的增强方法。
|
||||
*/
|
||||
export function applyMethodMiddlewares<
|
||||
TArgs extends unknown[] = unknown[], // Original method's arguments array type / 原始方法的参数数组类型
|
||||
TResult = unknown,
|
||||
TContext extends BaseContext = BaseContext
|
||||
>(
|
||||
methodName: string,
|
||||
originalMethod: (...args: TArgs) => Promise<TResult>,
|
||||
middlewares: MethodMiddleware[], // Expects generic middlewares / 期望通用中间件
|
||||
specificContextFactory?: (base: BaseContext, callArgs: TArgs) => TContext
|
||||
): (...args: TArgs) => Promise<TResult> {
|
||||
// Returns a function matching the original method signature. /
|
||||
// 返回一个与原始方法签名匹配的函数。
|
||||
return async function enhancedMethod(...methodCallArgs: TArgs): Promise<TResult> {
|
||||
const ctx = createInitialCallContext<TContext, TArgs>(
|
||||
methodName,
|
||||
methodCallArgs, // Pass the actual call arguments array / 传递实际的调用参数数组
|
||||
specificContextFactory
|
||||
)
|
||||
|
||||
const api: MiddlewareAPI<TContext, TArgs> = {
|
||||
getContext: () => ctx,
|
||||
getOriginalArgs: () => methodCallArgs // API provides the original arguments array / API提供原始参数数组
|
||||
}
|
||||
|
||||
// `finalDispatch` is the function that will ultimately call the original provider method. /
|
||||
// `finalDispatch` 是最终将调用原始提供者方法的函数。
|
||||
// It receives the current context and arguments, which may have been transformed by middlewares. /
|
||||
// 它接收当前的上下文和参数,这些参数可能已被中间件转换。
|
||||
const finalDispatch = async (
|
||||
_: TContext,
|
||||
currentArgs: TArgs // Generic final dispatch expects args array / 通用finalDispatch期望参数数组
|
||||
): Promise<TResult> => {
|
||||
return originalMethod.apply(currentArgs)
|
||||
}
|
||||
|
||||
const chain = middlewares.map((middleware) => middleware(api)) // Cast API if TContext/TArgs mismatch general ProviderMethodMiddleware / 如果TContext/TArgs与通用的ProviderMethodMiddleware不匹配,则转换API
|
||||
const composedMiddlewareLogic = compose(...chain)
|
||||
const enhancedDispatch = composedMiddlewareLogic(finalDispatch)
|
||||
|
||||
return enhancedDispatch(ctx, methodCallArgs) // Pass context and original args array / 传递上下文和原始参数数组
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Applies an array of `CompletionsMiddleware` to the `completions` method. /
|
||||
* 将一组 `CompletionsMiddleware` 应用于 `completions` 方法。
|
||||
* This version adapts for `CompletionsMiddleware` expecting a single `params` object. /
|
||||
* 此版本适配了期望单个 `params` 对象的 `CompletionsMiddleware`。
|
||||
* @param originalProviderInstance - The original provider instance. / 原始提供者实例。
|
||||
* @param originalCompletionsMethod - The original SDK `createCompletions` method. / 原始的 SDK `createCompletions` 方法。
|
||||
* @param middlewares - An array of `CompletionsMiddleware` to apply. / 要应用的 `CompletionsMiddleware` 数组。
|
||||
* @returns An enhanced `completions` method with the middlewares applied. / 应用了中间件的增强版 `completions` 方法。
|
||||
*/
|
||||
export function applyCompletionsMiddlewares<
|
||||
TSdkInstance extends SdkInstance = SdkInstance,
|
||||
TSdkParams extends SdkParams = SdkParams,
|
||||
TRawOutput extends SdkRawOutput = SdkRawOutput,
|
||||
TRawChunk extends SdkRawChunk = SdkRawChunk,
|
||||
TMessageParam extends SdkMessageParam = SdkMessageParam,
|
||||
TToolCall extends SdkToolCall = SdkToolCall,
|
||||
TSdkSpecificTool extends SdkTool = SdkTool
|
||||
>(
|
||||
originalApiClientInstance: BaseApiClient<
|
||||
TSdkInstance,
|
||||
TSdkParams,
|
||||
TRawOutput,
|
||||
TRawChunk,
|
||||
TMessageParam,
|
||||
TToolCall,
|
||||
TSdkSpecificTool
|
||||
>,
|
||||
originalCompletionsMethod: (payload: TSdkParams, options?: RequestOptions) => Promise<TRawOutput>,
|
||||
middlewares: CompletionsMiddleware<
|
||||
TSdkParams,
|
||||
TMessageParam,
|
||||
TToolCall,
|
||||
TSdkInstance,
|
||||
TRawOutput,
|
||||
TRawChunk,
|
||||
TSdkSpecificTool
|
||||
>[]
|
||||
): (params: CompletionsParams, options?: RequestOptions) => Promise<CompletionsResult> {
|
||||
// Returns a function matching the original method signature. /
|
||||
// 返回一个与原始方法签名匹配的函数。
|
||||
|
||||
const methodName = 'completions'
|
||||
|
||||
// Factory to create AiProviderMiddlewareCompletionsContext. /
|
||||
// 用于创建 AiProviderMiddlewareCompletionsContext 的工厂函数。
|
||||
const completionsContextFactory = (
|
||||
base: BaseContext,
|
||||
callArgs: [CompletionsParams]
|
||||
): CompletionsContext<
|
||||
TSdkParams,
|
||||
TMessageParam,
|
||||
TToolCall,
|
||||
TSdkInstance,
|
||||
TRawOutput,
|
||||
TRawChunk,
|
||||
TSdkSpecificTool
|
||||
> => {
|
||||
return {
|
||||
...base,
|
||||
methodName,
|
||||
apiClientInstance: originalApiClientInstance,
|
||||
originalArgs: callArgs,
|
||||
_internal: {
|
||||
toolProcessingState: {
|
||||
recursionDepth: 0,
|
||||
isRecursiveCall: false
|
||||
},
|
||||
observer: {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return async function enhancedCompletionsMethod(
|
||||
params: CompletionsParams,
|
||||
options?: RequestOptions
|
||||
): Promise<CompletionsResult> {
|
||||
// `originalCallArgs` for context creation is `[params]`. /
|
||||
// 用于上下文创建的 `originalCallArgs` 是 `[params]`。
|
||||
const originalCallArgs: [CompletionsParams] = [params]
|
||||
const baseContext: BaseContext = {
|
||||
[MIDDLEWARE_CONTEXT_SYMBOL]: true,
|
||||
methodName,
|
||||
originalArgs: originalCallArgs
|
||||
}
|
||||
const ctx = completionsContextFactory(baseContext, originalCallArgs)
|
||||
|
||||
const api: MiddlewareAPI<
|
||||
CompletionsContext<TSdkParams, TMessageParam, TToolCall, TSdkInstance, TRawOutput, TRawChunk, TSdkSpecificTool>,
|
||||
[CompletionsParams]
|
||||
> = {
|
||||
getContext: () => ctx,
|
||||
getOriginalArgs: () => originalCallArgs // API provides [CompletionsParams] / API提供 `[CompletionsParams]`
|
||||
}
|
||||
|
||||
// `finalDispatch` for CompletionsMiddleware: expects (context, params) not (context, args_array). /
|
||||
// `CompletionsMiddleware` 的 `finalDispatch`:期望 (context, params) 而不是 (context, args_array)。
|
||||
const finalDispatch = async (
|
||||
context: CompletionsContext<
|
||||
TSdkParams,
|
||||
TMessageParam,
|
||||
TToolCall,
|
||||
TSdkInstance,
|
||||
TRawOutput,
|
||||
TRawChunk,
|
||||
TSdkSpecificTool
|
||||
> // Context passed through / 上下文透传
|
||||
// _currentParams: CompletionsParams // Directly takes params / 直接接收参数 (unused but required for middleware signature)
|
||||
): Promise<CompletionsResult> => {
|
||||
// At this point, middleware should have transformed CompletionsParams to SDK params
|
||||
// and stored them in context. If no transformation happened, we need to handle it.
|
||||
// 此时,中间件应该已经将 CompletionsParams 转换为 SDK 参数并存储在上下文中。
|
||||
// 如果没有进行转换,我们需要处理它。
|
||||
|
||||
const sdkPayload = context._internal?.sdkPayload
|
||||
if (!sdkPayload) {
|
||||
throw new Error('SDK payload not found in context. Middleware chain should have transformed parameters.')
|
||||
}
|
||||
|
||||
const abortSignal = context._internal.flowControl?.abortSignal
|
||||
const timeout = context._internal.customState?.sdkMetadata?.timeout
|
||||
|
||||
const methodCall = async (payload) => {
|
||||
return await originalCompletionsMethod.call(originalApiClientInstance, payload, {
|
||||
...options,
|
||||
signal: abortSignal,
|
||||
timeout
|
||||
})
|
||||
}
|
||||
|
||||
const traceParams = {
|
||||
name: `${params.assistant?.model?.name}.client`,
|
||||
tag: 'LLM',
|
||||
topicId: params.topicId || '',
|
||||
modelName: params.assistant?.model?.name
|
||||
}
|
||||
|
||||
// Call the original SDK method with transformed parameters
|
||||
// 使用转换后的参数调用原始 SDK 方法
|
||||
const rawOutput = await withSpanResult(methodCall, traceParams, sdkPayload)
|
||||
|
||||
// Return result wrapped in CompletionsResult format
|
||||
// 以 CompletionsResult 格式返回包装的结果
|
||||
return { rawOutput } as CompletionsResult
|
||||
}
|
||||
|
||||
const chain = middlewares.map((middleware) => middleware(api))
|
||||
const composedMiddlewareLogic = compose(...chain)
|
||||
|
||||
// `enhancedDispatch` has the signature `(context, params) => Promise<CompletionsResult>`. /
|
||||
// `enhancedDispatch` 的签名为 `(context, params) => Promise<CompletionsResult>`。
|
||||
const enhancedDispatch = composedMiddlewareLogic(finalDispatch)
|
||||
|
||||
// 将 enhancedDispatch 保存到 context 中,供中间件进行递归调用
|
||||
// 这样可以避免重复执行整个中间件链
|
||||
ctx._internal.enhancedDispatch = enhancedDispatch
|
||||
|
||||
// Execute with context and the single params object. /
|
||||
// 使用上下文和单个参数对象执行。
|
||||
return enhancedDispatch(ctx, params)
|
||||
}
|
||||
}
|
||||
@ -1,591 +0,0 @@
|
||||
import { loggerService } from '@logger'
|
||||
import type { MCPCallToolResponse, MCPTool, MCPToolResponse, Model } from '@renderer/types'
|
||||
import type { MCPToolCreatedChunk } from '@renderer/types/chunk'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import type { SdkMessageParam, SdkRawOutput, SdkToolCall } from '@renderer/types/sdk'
|
||||
import {
|
||||
callBuiltInTool,
|
||||
callMCPTool,
|
||||
getMcpServerByTool,
|
||||
isToolAutoApproved,
|
||||
parseToolUse,
|
||||
upsertMCPToolResponse
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
import { confirmSameNameTools, requestToolConfirmation, setToolIdToNameMapping } from '@renderer/utils/userConfirmation'
|
||||
|
||||
import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import type { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'McpToolChunkMiddleware'
|
||||
const MAX_TOOL_RECURSION_DEPTH = 20 // 防止无限递归
|
||||
|
||||
const logger = loggerService.withContext('McpToolChunkMiddleware')
|
||||
|
||||
/**
|
||||
* MCP工具处理中间件
|
||||
*
|
||||
* 职责:
|
||||
* 1. 检测并拦截MCP工具进展chunk(Function Call方式和Tool Use方式)
|
||||
* 2. 执行工具调用
|
||||
* 3. 递归处理工具结果
|
||||
* 4. 管理工具调用状态和递归深度
|
||||
*/
|
||||
export const McpToolChunkMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
const mcpTools = params.mcpTools || []
|
||||
|
||||
// 如果没有工具,直接调用下一个中间件
|
||||
if (!mcpTools || mcpTools.length === 0) {
|
||||
return next(ctx, params)
|
||||
}
|
||||
|
||||
const executeWithToolHandling = async (currentParams: CompletionsParams, depth = 0): Promise<CompletionsResult> => {
|
||||
if (depth >= MAX_TOOL_RECURSION_DEPTH) {
|
||||
logger.error(`Maximum recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`)
|
||||
throw new Error(`Maximum tool recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`)
|
||||
}
|
||||
|
||||
let result: CompletionsResult
|
||||
|
||||
if (depth === 0) {
|
||||
result = await next(ctx, currentParams)
|
||||
} else {
|
||||
const enhancedCompletions = ctx._internal.enhancedDispatch
|
||||
if (!enhancedCompletions) {
|
||||
logger.error(`Enhanced completions method not found, cannot perform recursive call`)
|
||||
throw new Error('Enhanced completions method not found')
|
||||
}
|
||||
|
||||
ctx._internal.toolProcessingState!.isRecursiveCall = true
|
||||
ctx._internal.toolProcessingState!.recursionDepth = depth
|
||||
|
||||
result = await enhancedCompletions(ctx, currentParams)
|
||||
}
|
||||
|
||||
if (!result.stream) {
|
||||
logger.error(`No stream returned from enhanced completions`)
|
||||
throw new Error('No stream returned from enhanced completions')
|
||||
}
|
||||
|
||||
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
|
||||
const toolHandlingStream = resultFromUpstream.pipeThrough(
|
||||
createToolHandlingTransform(ctx, currentParams, mcpTools, depth, executeWithToolHandling)
|
||||
)
|
||||
|
||||
return {
|
||||
...result,
|
||||
stream: toolHandlingStream
|
||||
}
|
||||
}
|
||||
|
||||
return executeWithToolHandling(params, 0)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建工具处理的 TransformStream
|
||||
*/
|
||||
function createToolHandlingTransform(
|
||||
ctx: CompletionsContext,
|
||||
currentParams: CompletionsParams,
|
||||
mcpTools: MCPTool[],
|
||||
depth: number,
|
||||
executeWithToolHandling: (params: CompletionsParams, depth: number) => Promise<CompletionsResult>
|
||||
): TransformStream<GenericChunk, GenericChunk> {
|
||||
const toolCalls: SdkToolCall[] = []
|
||||
const toolUseResponses: MCPToolResponse[] = []
|
||||
const allToolResponses: MCPToolResponse[] = [] // 统一的工具响应状态管理数组
|
||||
let hasToolCalls = false
|
||||
let hasToolUseResponses = false
|
||||
let streamEnded = false
|
||||
|
||||
// 存储已执行的工具结果
|
||||
const executedToolResults: SdkMessageParam[] = []
|
||||
const executedToolCalls: SdkToolCall[] = []
|
||||
const executionPromises: Promise<void>[] = []
|
||||
|
||||
return new TransformStream({
|
||||
async transform(chunk: GenericChunk, controller) {
|
||||
try {
|
||||
// 处理MCP工具进展chunk
|
||||
logger.silly('chunk', chunk)
|
||||
if (chunk.type === ChunkType.MCP_TOOL_CREATED) {
|
||||
const createdChunk = chunk as MCPToolCreatedChunk
|
||||
|
||||
// 1. 处理Function Call方式的工具调用
|
||||
if (createdChunk.tool_calls && createdChunk.tool_calls.length > 0) {
|
||||
hasToolCalls = true
|
||||
|
||||
for (const toolCall of createdChunk.tool_calls) {
|
||||
toolCalls.push(toolCall)
|
||||
|
||||
const executionPromise = (async () => {
|
||||
try {
|
||||
const result = await executeToolCalls(
|
||||
ctx,
|
||||
[toolCall],
|
||||
mcpTools,
|
||||
allToolResponses,
|
||||
currentParams.onChunk,
|
||||
currentParams.assistant.model!,
|
||||
currentParams.topicId
|
||||
)
|
||||
|
||||
// 缓存执行结果
|
||||
executedToolResults.push(...result.toolResults)
|
||||
executedToolCalls.push(...result.confirmedToolCalls)
|
||||
} catch (error) {
|
||||
logger.error(`Error executing tool call asynchronously:`, error as Error)
|
||||
}
|
||||
})()
|
||||
|
||||
executionPromises.push(executionPromise)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 处理Tool Use方式的工具调用
|
||||
if (createdChunk.tool_use_responses && createdChunk.tool_use_responses.length > 0) {
|
||||
hasToolUseResponses = true
|
||||
for (const toolUseResponse of createdChunk.tool_use_responses) {
|
||||
toolUseResponses.push(toolUseResponse)
|
||||
const executionPromise = (async () => {
|
||||
try {
|
||||
const result = await executeToolUseResponses(
|
||||
ctx,
|
||||
[toolUseResponse], // 单个执行
|
||||
mcpTools,
|
||||
allToolResponses,
|
||||
currentParams.onChunk,
|
||||
currentParams.assistant.model!,
|
||||
currentParams.topicId
|
||||
)
|
||||
|
||||
// 缓存执行结果
|
||||
executedToolResults.push(...result.toolResults)
|
||||
} catch (error) {
|
||||
logger.error(`Error executing tool use response asynchronously:`, error as Error)
|
||||
// 错误时不影响其他工具的执行
|
||||
}
|
||||
})()
|
||||
|
||||
executionPromises.push(executionPromise)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Error processing chunk:`, error as Error)
|
||||
controller.error(error)
|
||||
}
|
||||
},
|
||||
|
||||
async flush(controller) {
|
||||
// 在流结束时等待所有异步工具执行完成,然后进行递归调用
|
||||
if (!streamEnded && (hasToolCalls || hasToolUseResponses)) {
|
||||
streamEnded = true
|
||||
|
||||
try {
|
||||
await Promise.all(executionPromises)
|
||||
if (executedToolResults.length > 0) {
|
||||
const output = ctx._internal.toolProcessingState?.output
|
||||
const newParams = buildParamsWithToolResults(
|
||||
ctx,
|
||||
currentParams,
|
||||
output,
|
||||
executedToolResults,
|
||||
executedToolCalls
|
||||
)
|
||||
|
||||
// 在递归调用前通知UI开始新的LLM响应处理
|
||||
if (currentParams.onChunk) {
|
||||
currentParams.onChunk({
|
||||
type: ChunkType.LLM_RESPONSE_CREATED
|
||||
})
|
||||
}
|
||||
|
||||
await executeWithToolHandling(newParams, depth + 1)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Error in tool processing:`, error as Error)
|
||||
controller.error(error)
|
||||
} finally {
|
||||
hasToolCalls = false
|
||||
hasToolUseResponses = false
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行工具调用(Function Call 方式)
|
||||
*/
|
||||
async function executeToolCalls(
|
||||
ctx: CompletionsContext,
|
||||
toolCalls: SdkToolCall[],
|
||||
mcpTools: MCPTool[],
|
||||
allToolResponses: MCPToolResponse[],
|
||||
onChunk: CompletionsParams['onChunk'],
|
||||
model: Model,
|
||||
topicId?: string
|
||||
): Promise<{ toolResults: SdkMessageParam[]; confirmedToolCalls: SdkToolCall[] }> {
|
||||
const mcpToolResponses: MCPToolResponse[] = toolCalls
|
||||
.map((toolCall) => {
|
||||
const mcpTool = ctx.apiClientInstance.convertSdkToolCallToMcp(toolCall, mcpTools)
|
||||
if (!mcpTool) {
|
||||
return undefined
|
||||
}
|
||||
return ctx.apiClientInstance.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
|
||||
})
|
||||
.filter((t): t is MCPToolResponse => typeof t !== 'undefined')
|
||||
|
||||
if (mcpToolResponses.length === 0) {
|
||||
logger.warn(`No valid MCP tool responses to execute`)
|
||||
return { toolResults: [], confirmedToolCalls: [] }
|
||||
}
|
||||
|
||||
// 使用现有的parseAndCallTools函数执行工具
|
||||
const { toolResults, confirmedToolResponses } = await parseAndCallTools(
|
||||
mcpToolResponses,
|
||||
allToolResponses,
|
||||
onChunk,
|
||||
(mcpToolResponse, resp, model) => {
|
||||
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
|
||||
},
|
||||
model,
|
||||
mcpTools,
|
||||
ctx._internal?.flowControl?.abortSignal,
|
||||
topicId
|
||||
)
|
||||
|
||||
// 找出已确认工具对应的原始toolCalls
|
||||
const confirmedToolCalls = toolCalls.filter((toolCall) => {
|
||||
return confirmedToolResponses.find((confirmed) => {
|
||||
// 根据不同的ID字段匹配原始toolCall
|
||||
return (
|
||||
('name' in toolCall &&
|
||||
(toolCall.name?.includes(confirmed.tool.name) || toolCall.name?.includes(confirmed.tool.id))) ||
|
||||
confirmed.tool.name === toolCall.id ||
|
||||
confirmed.tool.id === toolCall.id ||
|
||||
('toolCallId' in confirmed && confirmed.toolCallId === toolCall.id) ||
|
||||
('function' in toolCall && toolCall.function.name.toLowerCase().includes(confirmed.tool.name.toLowerCase()))
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
return { toolResults, confirmedToolCalls }
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行工具使用响应(Tool Use Response 方式)
|
||||
* 处理已经解析好的 ToolUseResponse[],不需要重新解析字符串
|
||||
*/
|
||||
async function executeToolUseResponses(
|
||||
ctx: CompletionsContext,
|
||||
toolUseResponses: MCPToolResponse[],
|
||||
mcpTools: MCPTool[],
|
||||
allToolResponses: MCPToolResponse[],
|
||||
onChunk: CompletionsParams['onChunk'],
|
||||
model: Model,
|
||||
topicId?: CompletionsParams['topicId']
|
||||
): Promise<{ toolResults: SdkMessageParam[] }> {
|
||||
// 直接使用parseAndCallTools函数处理已经解析好的ToolUseResponse
|
||||
const { toolResults } = await parseAndCallTools(
|
||||
toolUseResponses,
|
||||
allToolResponses,
|
||||
onChunk,
|
||||
(mcpToolResponse, resp, model) => {
|
||||
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
|
||||
},
|
||||
model,
|
||||
mcpTools,
|
||||
ctx._internal?.flowControl?.abortSignal,
|
||||
topicId
|
||||
)
|
||||
|
||||
return { toolResults }
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建包含工具结果的新参数
|
||||
*/
|
||||
function buildParamsWithToolResults(
|
||||
ctx: CompletionsContext,
|
||||
currentParams: CompletionsParams,
|
||||
output: SdkRawOutput | string | undefined,
|
||||
toolResults: SdkMessageParam[],
|
||||
confirmedToolCalls: SdkToolCall[]
|
||||
): CompletionsParams {
|
||||
// 获取当前已经转换好的reqMessages,如果没有则使用原始messages
|
||||
const currentReqMessages = getCurrentReqMessages(ctx)
|
||||
|
||||
const apiClient = ctx.apiClientInstance
|
||||
|
||||
// 从回复中构建助手消息
|
||||
const newReqMessages = apiClient.buildSdkMessages(currentReqMessages, output, toolResults, confirmedToolCalls)
|
||||
|
||||
if (output && ctx._internal.toolProcessingState) {
|
||||
ctx._internal.toolProcessingState.output = undefined
|
||||
}
|
||||
|
||||
// 估算新增消息的 token 消耗并累加到 usage 中
|
||||
if (ctx._internal.observer?.usage && newReqMessages.length > currentReqMessages.length) {
|
||||
try {
|
||||
const newMessages = newReqMessages.slice(currentReqMessages.length)
|
||||
const additionalTokens = newMessages.reduce((acc, message) => {
|
||||
return acc + ctx.apiClientInstance.estimateMessageTokens(message)
|
||||
}, 0)
|
||||
|
||||
if (additionalTokens > 0) {
|
||||
ctx._internal.observer.usage.prompt_tokens += additionalTokens
|
||||
ctx._internal.observer.usage.total_tokens += additionalTokens
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Error estimating token usage for new messages:`, error as Error)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新递归状态
|
||||
if (!ctx._internal.toolProcessingState) {
|
||||
ctx._internal.toolProcessingState = {}
|
||||
}
|
||||
ctx._internal.toolProcessingState.isRecursiveCall = true
|
||||
ctx._internal.toolProcessingState.recursionDepth = (ctx._internal.toolProcessingState?.recursionDepth || 0) + 1
|
||||
|
||||
return {
|
||||
...currentParams,
|
||||
_internal: {
|
||||
...ctx._internal,
|
||||
sdkPayload: ctx._internal.sdkPayload,
|
||||
newReqMessages: newReqMessages
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 类型安全地获取当前请求消息
|
||||
* 使用API客户端提供的抽象方法,保持中间件的provider无关性
|
||||
*/
|
||||
function getCurrentReqMessages(ctx: CompletionsContext): SdkMessageParam[] {
|
||||
const sdkPayload = ctx._internal.sdkPayload
|
||||
if (!sdkPayload) {
|
||||
return []
|
||||
}
|
||||
|
||||
// 使用API客户端的抽象方法来提取消息,保持provider无关性
|
||||
return ctx.apiClientInstance.extractMessagesFromSdkPayload(sdkPayload)
|
||||
}
|
||||
|
||||
export async function parseAndCallTools<R>(
|
||||
tools: MCPToolResponse[],
|
||||
allToolResponses: MCPToolResponse[],
|
||||
onChunk: CompletionsParams['onChunk'],
|
||||
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
|
||||
model: Model,
|
||||
mcpTools?: MCPTool[],
|
||||
abortSignal?: AbortSignal,
|
||||
topicId?: CompletionsParams['topicId']
|
||||
): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }>
|
||||
|
||||
export async function parseAndCallTools<R>(
|
||||
content: string,
|
||||
allToolResponses: MCPToolResponse[],
|
||||
onChunk: CompletionsParams['onChunk'],
|
||||
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
|
||||
model: Model,
|
||||
mcpTools?: MCPTool[],
|
||||
abortSignal?: AbortSignal,
|
||||
topicId?: CompletionsParams['topicId']
|
||||
): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }>
|
||||
|
||||
export async function parseAndCallTools<R>(
|
||||
content: string | MCPToolResponse[],
|
||||
allToolResponses: MCPToolResponse[],
|
||||
onChunk: CompletionsParams['onChunk'],
|
||||
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
|
||||
model: Model,
|
||||
mcpTools?: MCPTool[],
|
||||
abortSignal?: AbortSignal,
|
||||
topicId?: CompletionsParams['topicId']
|
||||
): Promise<{ toolResults: R[]; confirmedToolResponses: MCPToolResponse[] }> {
|
||||
const toolResults: R[] = []
|
||||
let curToolResponses: MCPToolResponse[] = []
|
||||
if (Array.isArray(content)) {
|
||||
curToolResponses = content
|
||||
} else {
|
||||
// process tool use
|
||||
curToolResponses = parseToolUse(content, mcpTools || [], 0)
|
||||
}
|
||||
if (!curToolResponses || curToolResponses.length === 0) {
|
||||
return { toolResults, confirmedToolResponses: [] }
|
||||
}
|
||||
|
||||
for (const toolResponse of curToolResponses) {
|
||||
upsertMCPToolResponse(
|
||||
allToolResponses,
|
||||
{
|
||||
...toolResponse,
|
||||
status: 'pending'
|
||||
},
|
||||
onChunk!
|
||||
)
|
||||
}
|
||||
|
||||
// 创建工具确认Promise映射,并立即处理每个确认
|
||||
const confirmedTools: MCPToolResponse[] = []
|
||||
const pendingPromises: Promise<void>[] = []
|
||||
|
||||
curToolResponses.forEach((toolResponse) => {
|
||||
const server = getMcpServerByTool(toolResponse.tool)
|
||||
const isAutoApproveEnabled = isToolAutoApproved(toolResponse.tool, server)
|
||||
let confirmationPromise: Promise<boolean>
|
||||
if (isAutoApproveEnabled) {
|
||||
confirmationPromise = Promise.resolve(true)
|
||||
} else {
|
||||
setToolIdToNameMapping(toolResponse.id, toolResponse.tool.name)
|
||||
|
||||
confirmationPromise = requestToolConfirmation(toolResponse.id, abortSignal).then((confirmed) => {
|
||||
if (confirmed && server) {
|
||||
// 自动确认其他同名的待确认工具
|
||||
confirmSameNameTools(toolResponse.tool.name)
|
||||
}
|
||||
return confirmed
|
||||
})
|
||||
}
|
||||
|
||||
const processingPromise = confirmationPromise
|
||||
.then(async (confirmed) => {
|
||||
if (confirmed) {
|
||||
// 立即更新为invoking状态
|
||||
upsertMCPToolResponse(
|
||||
allToolResponses,
|
||||
{
|
||||
...toolResponse,
|
||||
status: 'invoking'
|
||||
},
|
||||
onChunk!
|
||||
)
|
||||
|
||||
// 执行工具调用
|
||||
try {
|
||||
const images: string[] = []
|
||||
// 根据工具类型选择不同的调用方式
|
||||
const toolCallResponse = toolResponse.tool.isBuiltIn
|
||||
? await callBuiltInTool(toolResponse)
|
||||
: await callMCPTool(toolResponse, topicId, model.name)
|
||||
|
||||
// 立即更新为done状态
|
||||
upsertMCPToolResponse(
|
||||
allToolResponses,
|
||||
{
|
||||
...toolResponse,
|
||||
status: 'done',
|
||||
response: toolCallResponse
|
||||
},
|
||||
onChunk!
|
||||
)
|
||||
|
||||
if (!toolCallResponse) {
|
||||
return
|
||||
}
|
||||
|
||||
// 处理图片
|
||||
for (const content of toolCallResponse.content) {
|
||||
if (content.type === 'image' && content.data) {
|
||||
images.push(`data:${content.mimeType};base64,${content.data}`)
|
||||
}
|
||||
}
|
||||
|
||||
if (images.length) {
|
||||
onChunk?.({
|
||||
type: ChunkType.IMAGE_CREATED
|
||||
})
|
||||
onChunk?.({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: {
|
||||
type: 'base64',
|
||||
images: images
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 转换消息并添加到结果
|
||||
const convertedMessage = convertToMessage(toolResponse, toolCallResponse, model)
|
||||
if (convertedMessage) {
|
||||
confirmedTools.push(toolResponse)
|
||||
toolResults.push(convertedMessage)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Error executing tool ${toolResponse.id}:`, error as Error)
|
||||
// 更新为错误状态
|
||||
upsertMCPToolResponse(
|
||||
allToolResponses,
|
||||
{
|
||||
...toolResponse,
|
||||
status: 'done',
|
||||
response: {
|
||||
isError: true,
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: `Error executing tool: ${error instanceof Error ? error.message : 'Unknown error'}`
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
onChunk!
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// 立即更新为cancelled状态
|
||||
upsertMCPToolResponse(
|
||||
allToolResponses,
|
||||
{
|
||||
...toolResponse,
|
||||
status: 'cancelled',
|
||||
response: {
|
||||
isError: false,
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'Tool call cancelled by user.'
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
onChunk!
|
||||
)
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
logger.error(`Error waiting for tool confirmation ${toolResponse.id}:`, error as Error)
|
||||
// 立即更新为cancelled状态
|
||||
upsertMCPToolResponse(
|
||||
allToolResponses,
|
||||
{
|
||||
...toolResponse,
|
||||
status: 'cancelled',
|
||||
response: {
|
||||
isError: true,
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: `Error in confirmation process: ${error instanceof Error ? error.message : 'Unknown error'}`
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
onChunk!
|
||||
)
|
||||
})
|
||||
|
||||
pendingPromises.push(processingPromise)
|
||||
})
|
||||
|
||||
// 等待所有工具处理完成(但每个工具的状态已经实时更新)
|
||||
await Promise.all(pendingPromises)
|
||||
|
||||
return { toolResults, confirmedToolResponses: confirmedTools }
|
||||
}
|
||||
@ -1,45 +0,0 @@
|
||||
import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient'
|
||||
import type { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk'
|
||||
|
||||
import type { AnthropicStreamListener } from '../../clients/types'
|
||||
import type { CompletionsParams, CompletionsResult } from '../schemas'
|
||||
import type { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'RawStreamListenerMiddleware'
|
||||
|
||||
export const RawStreamListenerMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
const result = await next(ctx, params)
|
||||
|
||||
// 在这里可以监听到从SDK返回的最原始流
|
||||
if (result.rawOutput) {
|
||||
// TODO: 后面下放到AnthropicAPIClient
|
||||
if (ctx.apiClientInstance instanceof AnthropicAPIClient) {
|
||||
const anthropicListener: AnthropicStreamListener<AnthropicSdkRawChunk> = {
|
||||
onMessage: (message) => {
|
||||
if (ctx._internal?.toolProcessingState) {
|
||||
ctx._internal.toolProcessingState.output = message
|
||||
}
|
||||
}
|
||||
// onContentBlock: (contentBlock) => {
|
||||
// console.log(`[${MIDDLEWARE_NAME}] 📝 Anthropic content block:`, contentBlock.type)
|
||||
// }
|
||||
}
|
||||
|
||||
const specificApiClient = ctx.apiClientInstance as AnthropicAPIClient
|
||||
|
||||
const monitoredOutput = specificApiClient.attachRawStreamListener(
|
||||
result.rawOutput as AnthropicSdkRawOutput,
|
||||
anthropicListener
|
||||
)
|
||||
return {
|
||||
...result,
|
||||
rawOutput: monitoredOutput
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@ -1,88 +0,0 @@
|
||||
import { loggerService } from '@logger'
|
||||
import type { SdkRawChunk } from '@renderer/types/sdk'
|
||||
|
||||
import type { ResponseChunkTransformerContext } from '../../clients/types'
|
||||
import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import type { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'ResponseTransformMiddleware'
|
||||
|
||||
const logger = loggerService.withContext('ResponseTransformMiddleware')
|
||||
|
||||
/**
|
||||
* 响应转换中间件
|
||||
*
|
||||
* 职责:
|
||||
* 1. 检测ReadableStream类型的响应流
|
||||
* 2. 使用ApiClient的getResponseChunkTransformer()将原始SDK响应块转换为通用格式
|
||||
* 3. 将转换后的ReadableStream保存到ctx._internal.apiCall.genericChunkStream,供下游中间件使用
|
||||
*
|
||||
* 注意:此中间件应该在StreamAdapterMiddleware之后执行
|
||||
*/
|
||||
export const ResponseTransformMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
// 调用下游中间件
|
||||
const result = await next(ctx, params)
|
||||
|
||||
// 响应后处理:转换原始SDK响应块
|
||||
if (result.stream) {
|
||||
const adaptedStream = result.stream
|
||||
|
||||
// 处理ReadableStream类型的流
|
||||
if (adaptedStream instanceof ReadableStream) {
|
||||
const apiClient = ctx.apiClientInstance
|
||||
if (!apiClient) {
|
||||
logger.error(`ApiClient instance not found in context`)
|
||||
throw new Error('ApiClient instance not found in context')
|
||||
}
|
||||
|
||||
// 获取响应转换器
|
||||
const responseChunkTransformer = apiClient.getResponseChunkTransformer(ctx)
|
||||
if (!responseChunkTransformer) {
|
||||
logger.warn(`No ResponseChunkTransformer available, skipping transformation`)
|
||||
return result
|
||||
}
|
||||
|
||||
const assistant = params.assistant
|
||||
const model = assistant?.model
|
||||
|
||||
if (!assistant || !model) {
|
||||
logger.error(`Assistant or Model not found for transformation`)
|
||||
throw new Error('Assistant or Model not found for transformation')
|
||||
}
|
||||
|
||||
const transformerContext: ResponseChunkTransformerContext = {
|
||||
isStreaming: params.streamOutput || false,
|
||||
isEnabledToolCalling: (params.mcpTools && params.mcpTools.length > 0) || false,
|
||||
isEnabledWebSearch: params.enableWebSearch || false,
|
||||
isEnabledUrlContext: params.enableUrlContext || false,
|
||||
isEnabledReasoning: params.enableReasoning || false,
|
||||
mcpTools: params.mcpTools || [],
|
||||
provider: ctx.apiClientInstance?.provider
|
||||
}
|
||||
|
||||
logger.debug(`Transforming raw SDK chunks with context:`, transformerContext)
|
||||
|
||||
try {
|
||||
// 创建转换后的流
|
||||
const genericChunkTransformStream = (adaptedStream as ReadableStream<SdkRawChunk>).pipeThrough<GenericChunk>(
|
||||
new TransformStream<SdkRawChunk, GenericChunk>(responseChunkTransformer(transformerContext))
|
||||
)
|
||||
|
||||
// 将转换后的ReadableStream保存到result,供下游中间件使用
|
||||
return {
|
||||
...result,
|
||||
stream: genericChunkTransformStream
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error during chunk transformation:', error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有流或不是ReadableStream,返回原始结果
|
||||
return result
|
||||
}
|
||||
@ -1,56 +0,0 @@
|
||||
import type { SdkRawChunk } from '@renderer/types/sdk'
|
||||
import { asyncGeneratorToReadableStream, createSingleChunkReadableStream } from '@renderer/utils/stream'
|
||||
|
||||
import type { CompletionsParams, CompletionsResult } from '../schemas'
|
||||
import type { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
import { isAsyncIterable } from '../utils'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'StreamAdapterMiddleware'
|
||||
|
||||
/**
|
||||
* 流适配器中间件
|
||||
*
|
||||
* 职责:
|
||||
* 1. 检测ctx._internal.apiCall.rawSdkOutput(优先)或原始AsyncIterable流
|
||||
* 2. 将AsyncIterable转换为WHATWG ReadableStream
|
||||
* 3. 更新响应结果中的stream
|
||||
*
|
||||
* 注意:如果ResponseTransformMiddleware已处理过,会优先使用transformedStream
|
||||
*/
|
||||
export const StreamAdapterMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
// TODO:调用开始,因为这个是最靠近接口请求的地方,next执行代表着开始接口请求了
|
||||
// 但是这个中间件的职责是流适配,是否在这调用优待商榷
|
||||
// 调用下游中间件
|
||||
const result = await next(ctx, params)
|
||||
if (
|
||||
result.rawOutput &&
|
||||
!(result.rawOutput instanceof ReadableStream) &&
|
||||
isAsyncIterable<SdkRawChunk>(result.rawOutput)
|
||||
) {
|
||||
const whatwgReadableStream: ReadableStream<SdkRawChunk> = asyncGeneratorToReadableStream<SdkRawChunk>(
|
||||
result.rawOutput
|
||||
)
|
||||
return {
|
||||
...result,
|
||||
stream: whatwgReadableStream
|
||||
}
|
||||
} else if (result.rawOutput && result.rawOutput instanceof ReadableStream) {
|
||||
return {
|
||||
...result,
|
||||
stream: result.rawOutput
|
||||
}
|
||||
} else if (result.rawOutput) {
|
||||
// 非流式输出,强行变为可读流
|
||||
const whatwgReadableStream: ReadableStream<SdkRawChunk> = createSingleChunkReadableStream<SdkRawChunk>(
|
||||
result.rawOutput as SdkRawChunk
|
||||
)
|
||||
return {
|
||||
...result,
|
||||
stream: whatwgReadableStream
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@ -1,106 +0,0 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
|
||||
import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import type { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'TextChunkMiddleware'
|
||||
|
||||
const logger = loggerService.withContext('TextChunkMiddleware')
|
||||
|
||||
/**
|
||||
* 文本块处理中间件
|
||||
*
|
||||
* 职责:
|
||||
* 1. 累积文本内容(TEXT_DELTA)
|
||||
* 2. 对文本内容进行智能链接转换
|
||||
* 3. 生成TEXT_COMPLETE事件
|
||||
* 4. 暂存Web搜索结果,用于最终链接完善
|
||||
* 5. 处理 onResponse 回调,实时发送文本更新和最终完整文本
|
||||
*/
|
||||
export const TextChunkMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
// 调用下游中间件
|
||||
const result = await next(ctx, params)
|
||||
|
||||
// 响应后处理:转换流式响应中的文本内容
|
||||
if (result.stream) {
|
||||
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
|
||||
|
||||
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
|
||||
const assistant = params.assistant
|
||||
const model = params.assistant?.model
|
||||
|
||||
if (!assistant || !model) {
|
||||
logger.warn(`Missing assistant or model information, skipping text processing`)
|
||||
return result
|
||||
}
|
||||
|
||||
// 用于跨chunk的状态管理
|
||||
let accumulatedTextContent = ''
|
||||
const enhancedTextStream = resultFromUpstream.pipeThrough(
|
||||
new TransformStream<GenericChunk, GenericChunk>({
|
||||
transform(chunk: GenericChunk, controller) {
|
||||
logger.silly('chunk', chunk)
|
||||
if (chunk.type === ChunkType.TEXT_DELTA) {
|
||||
if (model.supported_text_delta === false) {
|
||||
accumulatedTextContent = chunk.text
|
||||
} else {
|
||||
accumulatedTextContent += chunk.text
|
||||
}
|
||||
// 处理 onResponse 回调 - 发送增量文本更新
|
||||
if (params.onResponse) {
|
||||
params.onResponse(accumulatedTextContent, false)
|
||||
}
|
||||
|
||||
controller.enqueue({
|
||||
...chunk,
|
||||
text: accumulatedTextContent // 增量更新
|
||||
})
|
||||
} else if (accumulatedTextContent && chunk.type !== ChunkType.TEXT_START) {
|
||||
ctx._internal.customState!.accumulatedText = accumulatedTextContent
|
||||
if (ctx._internal.toolProcessingState && !ctx._internal.toolProcessingState?.output) {
|
||||
ctx._internal.toolProcessingState.output = accumulatedTextContent
|
||||
}
|
||||
|
||||
if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
|
||||
// 处理 onResponse 回调 - 发送最终完整文本
|
||||
if (params.onResponse) {
|
||||
params.onResponse(accumulatedTextContent, true)
|
||||
}
|
||||
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_COMPLETE,
|
||||
text: accumulatedTextContent
|
||||
})
|
||||
controller.enqueue(chunk)
|
||||
} else {
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_COMPLETE,
|
||||
text: accumulatedTextContent
|
||||
})
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
accumulatedTextContent = ''
|
||||
} else {
|
||||
// 其他类型的chunk直接传递
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
// 更新响应结果
|
||||
return {
|
||||
...result,
|
||||
stream: enhancedTextStream
|
||||
}
|
||||
} else {
|
||||
logger.warn(`No stream to process or not a ReadableStream. Returning original result.`)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@ -1,99 +0,0 @@
|
||||
import { loggerService } from '@logger'
|
||||
import type { ThinkingCompleteChunk, ThinkingDeltaChunk } from '@renderer/types/chunk'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
|
||||
import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import type { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'ThinkChunkMiddleware'
|
||||
|
||||
const logger = loggerService.withContext('ThinkChunkMiddleware')
|
||||
|
||||
/**
|
||||
* 处理思考内容的中间件
|
||||
*
|
||||
* 注意:从 v2 版本开始,流结束语义的判断已移至 ApiClient 层处理
|
||||
* 此中间件现在主要负责:
|
||||
* 1. 处理原始SDK chunk中的reasoning字段
|
||||
* 2. 计算准确的思考时间
|
||||
* 3. 在思考内容结束时生成THINKING_COMPLETE事件
|
||||
*
|
||||
* 职责:
|
||||
* 1. 累积思考内容(THINKING_DELTA)
|
||||
* 2. 监听流结束信号,生成THINKING_COMPLETE事件
|
||||
* 3. 计算准确的思考时间
|
||||
*
|
||||
*/
|
||||
export const ThinkChunkMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
// 调用下游中间件
|
||||
const result = await next(ctx, params)
|
||||
|
||||
// 响应后处理:处理思考内容
|
||||
if (result.stream) {
|
||||
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
|
||||
|
||||
// 检查是否有流需要处理
|
||||
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
|
||||
// thinking 处理状态
|
||||
let accumulatedThinkingContent = ''
|
||||
let hasThinkingContent = false
|
||||
let thinkingStartTime = 0
|
||||
|
||||
const processedStream = resultFromUpstream.pipeThrough(
|
||||
new TransformStream<GenericChunk, GenericChunk>({
|
||||
transform(chunk: GenericChunk, controller) {
|
||||
if (chunk.type === ChunkType.THINKING_DELTA) {
|
||||
const thinkingChunk = chunk as ThinkingDeltaChunk
|
||||
|
||||
// 第一次接收到思考内容时记录开始时间
|
||||
if (!hasThinkingContent) {
|
||||
hasThinkingContent = true
|
||||
thinkingStartTime = Date.now()
|
||||
}
|
||||
|
||||
accumulatedThinkingContent += thinkingChunk.text
|
||||
|
||||
// 更新思考时间并传递
|
||||
const enhancedChunk: ThinkingDeltaChunk = {
|
||||
...thinkingChunk,
|
||||
text: accumulatedThinkingContent,
|
||||
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
|
||||
}
|
||||
controller.enqueue(enhancedChunk)
|
||||
} else if (hasThinkingContent && thinkingStartTime > 0 && chunk.type !== ChunkType.THINKING_START) {
|
||||
// 收到任何非THINKING_DELTA的chunk时,如果有累积的思考内容,生成THINKING_COMPLETE
|
||||
const thinkingCompleteChunk: ThinkingCompleteChunk = {
|
||||
type: ChunkType.THINKING_COMPLETE,
|
||||
text: accumulatedThinkingContent,
|
||||
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
|
||||
}
|
||||
controller.enqueue(thinkingCompleteChunk)
|
||||
hasThinkingContent = false
|
||||
accumulatedThinkingContent = ''
|
||||
thinkingStartTime = 0
|
||||
|
||||
// 继续传递当前chunk
|
||||
controller.enqueue(chunk)
|
||||
} else {
|
||||
// 其他情况直接传递
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
// 更新响应结果
|
||||
return {
|
||||
...result,
|
||||
stream: processedStream
|
||||
}
|
||||
} else {
|
||||
logger.warn(`No generic chunk stream to process or not a ReadableStream.`)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@ -1,81 +0,0 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
|
||||
import type { CompletionsParams, CompletionsResult } from '../schemas'
|
||||
import type { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'TransformCoreToSdkParamsMiddleware'
|
||||
|
||||
const logger = loggerService.withContext('TransformCoreToSdkParamsMiddleware')
|
||||
|
||||
/**
|
||||
* 中间件:将CoreCompletionsRequest转换为SDK特定的参数
|
||||
* 使用上下文中ApiClient实例的requestTransformer进行转换
|
||||
*/
|
||||
export const TransformCoreToSdkParamsMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
const internal = ctx._internal
|
||||
|
||||
// 🔧 检测递归调用:检查 params 中是否携带了预处理的 SDK 消息
|
||||
const isRecursiveCall = internal?.toolProcessingState?.isRecursiveCall || false
|
||||
const newSdkMessages = params._internal?.newReqMessages
|
||||
|
||||
const apiClient = ctx.apiClientInstance
|
||||
|
||||
if (!apiClient) {
|
||||
logger.error(`ApiClient instance not found in context.`)
|
||||
throw new Error('ApiClient instance not found in context')
|
||||
}
|
||||
|
||||
// 检查是否有requestTransformer方法
|
||||
const requestTransformer = apiClient.getRequestTransformer()
|
||||
if (!requestTransformer) {
|
||||
logger.warn(`ApiClient does not have getRequestTransformer method, skipping transformation`)
|
||||
const result = await next(ctx, params)
|
||||
return result
|
||||
}
|
||||
|
||||
// 确保assistant和model可用,它们是transformer所需的
|
||||
const assistant = params.assistant
|
||||
const model = params.assistant.model
|
||||
|
||||
if (!assistant || !model) {
|
||||
logger.error(`Assistant or Model not found for transformation.`)
|
||||
throw new Error('Assistant or Model not found for transformation')
|
||||
}
|
||||
|
||||
try {
|
||||
const transformResult = await requestTransformer.transform(
|
||||
params,
|
||||
assistant,
|
||||
model,
|
||||
isRecursiveCall,
|
||||
newSdkMessages
|
||||
)
|
||||
|
||||
const { payload: sdkPayload, metadata } = transformResult
|
||||
|
||||
// 将SDK特定的payload和metadata存储在状态中,供下游中间件使用
|
||||
ctx._internal.sdkPayload = sdkPayload
|
||||
|
||||
if (metadata) {
|
||||
ctx._internal.customState = {
|
||||
...ctx._internal.customState,
|
||||
sdkMetadata: metadata
|
||||
}
|
||||
}
|
||||
|
||||
if (params.enableGenerateImage) {
|
||||
params.onChunk?.({
|
||||
type: ChunkType.IMAGE_CREATED
|
||||
})
|
||||
}
|
||||
return next(ctx, params)
|
||||
} catch (error) {
|
||||
logger.error('Error during request transformation:', error as Error)
|
||||
// 让错误向上传播,或者可以在这里进行特定的错误处理
|
||||
throw error
|
||||
}
|
||||
}
|
||||
@ -1,102 +0,0 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { convertLinks, flushLinkConverterBuffer } from '@renderer/utils/linkConverter'
|
||||
|
||||
import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import type { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
const logger = loggerService.withContext('WebSearchMiddleware')
|
||||
|
||||
export const MIDDLEWARE_NAME = 'WebSearchMiddleware'
|
||||
|
||||
/**
|
||||
* Web搜索处理中间件 - 基于GenericChunk流处理
|
||||
*
|
||||
* 职责:
|
||||
* 1. 监听和记录Web搜索事件
|
||||
* 2. 可以在此处添加Web搜索结果的后处理逻辑
|
||||
* 3. 维护Web搜索相关的状态
|
||||
*
|
||||
* 注意:Web搜索结果的识别和生成已在ApiClient的响应转换器中处理
|
||||
*/
|
||||
export const WebSearchMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
ctx._internal.webSearchState = {
|
||||
results: undefined
|
||||
}
|
||||
// 调用下游中间件
|
||||
const result = await next(ctx, params)
|
||||
let isFirstChunk = true
|
||||
|
||||
// 响应后处理:记录Web搜索事件
|
||||
if (result.stream) {
|
||||
const resultFromUpstream = result.stream
|
||||
|
||||
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
|
||||
// Web搜索状态跟踪
|
||||
const enhancedStream = (resultFromUpstream as ReadableStream<GenericChunk>).pipeThrough(
|
||||
new TransformStream<GenericChunk, GenericChunk>({
|
||||
transform(chunk: GenericChunk, controller) {
|
||||
if (chunk.type === ChunkType.TEXT_DELTA) {
|
||||
// 使用当前可用的Web搜索结果进行链接转换
|
||||
const text = chunk.text
|
||||
const result = convertLinks(text, isFirstChunk)
|
||||
if (isFirstChunk) {
|
||||
isFirstChunk = false
|
||||
}
|
||||
|
||||
// - 如果有内容被缓冲,说明convertLinks正在等待后续chunk,不使用原文本避免重复
|
||||
// - 如果没有内容被缓冲且结果为空,可能是其他处理问题,使用原文本作为安全回退
|
||||
let finalText: string
|
||||
if (result.hasBufferedContent) {
|
||||
// 有内容被缓冲,使用处理后的结果(可能为空,等待后续chunk)
|
||||
finalText = result.text
|
||||
} else {
|
||||
// 没有内容被缓冲,可以安全使用回退逻辑
|
||||
finalText = result.text || text
|
||||
}
|
||||
|
||||
// 只有当finalText不为空时才发送chunk
|
||||
if (finalText) {
|
||||
controller.enqueue({
|
||||
...chunk,
|
||||
text: finalText
|
||||
})
|
||||
}
|
||||
} else if (chunk.type === ChunkType.LLM_WEB_SEARCH_COMPLETE) {
|
||||
// 暂存Web搜索结果用于链接完善
|
||||
ctx._internal.webSearchState!.results = chunk.llm_web_search
|
||||
|
||||
// 将Web搜索完成事件继续传递下去
|
||||
controller.enqueue(chunk)
|
||||
} else if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
|
||||
// 流结束时,清空链接转换器的buffer并处理剩余内容
|
||||
const remainingText = flushLinkConverterBuffer()
|
||||
if (remainingText) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: remainingText
|
||||
})
|
||||
}
|
||||
// 继续传递LLM_RESPONSE_COMPLETE事件
|
||||
controller.enqueue(chunk)
|
||||
} else {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
return {
|
||||
...result,
|
||||
stream: enhancedStream
|
||||
}
|
||||
} else {
|
||||
logger.debug(`No stream to process or not a ReadableStream.`)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@ -1,148 +0,0 @@
|
||||
import type OpenAI from '@cherrystudio/openai'
|
||||
import { toFile } from '@cherrystudio/openai/uploads'
|
||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||
import FileManager from '@renderer/services/FileManager'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { defaultTimeout } from '@shared/config/constant'
|
||||
|
||||
import type { BaseApiClient } from '../../clients/BaseApiClient'
|
||||
import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import type { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'ImageGenerationMiddleware'
|
||||
|
||||
export const ImageGenerationMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (context: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
const { assistant, messages } = params
|
||||
const client = context.apiClientInstance as BaseApiClient<OpenAI>
|
||||
const signal = context._internal?.flowControl?.abortSignal
|
||||
if (!assistant.model || !isDedicatedImageGenerationModel(assistant.model) || typeof messages === 'string') {
|
||||
return next(context, params)
|
||||
}
|
||||
|
||||
const stream = new ReadableStream<GenericChunk>({
|
||||
async start(controller) {
|
||||
const enqueue = (chunk: GenericChunk) => controller.enqueue(chunk)
|
||||
|
||||
try {
|
||||
if (!assistant.model) {
|
||||
throw new Error('Assistant model is not defined.')
|
||||
}
|
||||
|
||||
const sdk = await client.getSdkInstance()
|
||||
const lastUserMessage = messages.findLast((m) => m.role === 'user')
|
||||
const lastAssistantMessage = messages.findLast((m) => m.role === 'assistant')
|
||||
|
||||
if (!lastUserMessage) {
|
||||
throw new Error('No user message found for image generation.')
|
||||
}
|
||||
|
||||
const prompt = getMainTextContent(lastUserMessage)
|
||||
let imageFiles: Blob[] = []
|
||||
|
||||
// Collect images from user message
|
||||
const userImageBlocks = findImageBlocks(lastUserMessage)
|
||||
const userImages = await Promise.all(
|
||||
userImageBlocks.map(async (block) => {
|
||||
if (!block.file) return null
|
||||
const binaryData: Uint8Array = await FileManager.readBinaryImage(block.file)
|
||||
const mimeType = `${block.file.type}/${block.file.ext.slice(1)}`
|
||||
return await toFile(new Blob([binaryData]), block.file.origin_name || 'image.png', { type: mimeType })
|
||||
})
|
||||
)
|
||||
imageFiles = imageFiles.concat(userImages.filter(Boolean) as Blob[])
|
||||
|
||||
// Collect images from last assistant message
|
||||
if (lastAssistantMessage) {
|
||||
const assistantImageBlocks = findImageBlocks(lastAssistantMessage)
|
||||
const assistantImages = await Promise.all(
|
||||
assistantImageBlocks.map(async (block) => {
|
||||
const b64 = block.url?.replace(/^data:image\/\w+;base64,/, '')
|
||||
if (!b64) return null
|
||||
const binary = atob(b64)
|
||||
const bytes = new Uint8Array(binary.length)
|
||||
for (let i = 0; i < binary.length; i++) bytes[i] = binary.charCodeAt(i)
|
||||
return await toFile(new Blob([bytes]), 'assistant_image.png', { type: 'image/png' })
|
||||
})
|
||||
)
|
||||
imageFiles = imageFiles.concat(assistantImages.filter(Boolean) as Blob[])
|
||||
}
|
||||
|
||||
enqueue({ type: ChunkType.IMAGE_CREATED })
|
||||
|
||||
const startTime = Date.now()
|
||||
let response: OpenAI.Images.ImagesResponse
|
||||
const options = { signal, timeout: defaultTimeout }
|
||||
|
||||
if (imageFiles.length > 0) {
|
||||
const model = assistant.model
|
||||
const provider = context.apiClientInstance.provider
|
||||
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/dall-e?tabs=gpt-image-1#call-the-image-edit-api
|
||||
if (model.id.toLowerCase().includes('gpt-image-1-mini') && provider.type === 'azure-openai') {
|
||||
throw new Error('Azure OpenAI GPT-Image-1-Mini model does not support image editing.')
|
||||
}
|
||||
response = await sdk.images.edit(
|
||||
{
|
||||
model: assistant.model.id,
|
||||
image: imageFiles,
|
||||
prompt: prompt || ''
|
||||
},
|
||||
options
|
||||
)
|
||||
} else {
|
||||
response = await sdk.images.generate(
|
||||
{
|
||||
model: assistant.model.id,
|
||||
prompt: prompt || '',
|
||||
response_format: assistant.model.id.includes('gpt-image-1') ? undefined : 'b64_json'
|
||||
},
|
||||
options
|
||||
)
|
||||
}
|
||||
|
||||
let imageType: 'url' | 'base64' = 'base64'
|
||||
const imageList =
|
||||
response.data?.reduce((acc: string[], image) => {
|
||||
if (image.url) {
|
||||
acc.push(image.url)
|
||||
imageType = 'url'
|
||||
} else if (image.b64_json) {
|
||||
acc.push(`data:image/png;base64,${image.b64_json}`)
|
||||
}
|
||||
return acc
|
||||
}, []) || []
|
||||
|
||||
enqueue({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: { type: imageType, images: imageList }
|
||||
})
|
||||
|
||||
const usage = (response as any).usage || { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }
|
||||
|
||||
enqueue({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
usage,
|
||||
metrics: {
|
||||
completion_tokens: usage.completion_tokens,
|
||||
time_first_token_millsec: 0,
|
||||
time_completion_millsec: Date.now() - startTime
|
||||
}
|
||||
}
|
||||
})
|
||||
} catch (error: any) {
|
||||
enqueue({ type: ChunkType.ERROR, error })
|
||||
} finally {
|
||||
controller.close()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return {
|
||||
stream,
|
||||
getText: () => ''
|
||||
}
|
||||
}
|
||||
@ -1,198 +0,0 @@
|
||||
import { loggerService } from '@logger'
|
||||
import type { Model } from '@renderer/types'
|
||||
import type {
|
||||
TextDeltaChunk,
|
||||
ThinkingCompleteChunk,
|
||||
ThinkingDeltaChunk,
|
||||
ThinkingStartChunk
|
||||
} from '@renderer/types/chunk'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { getLowerBaseModelName } from '@renderer/utils'
|
||||
import type { TagConfig } from '@renderer/utils/tagExtraction'
|
||||
import { TagExtractor } from '@renderer/utils/tagExtraction'
|
||||
|
||||
import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import type { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
const logger = loggerService.withContext('ThinkingTagExtractionMiddleware')
|
||||
|
||||
export const MIDDLEWARE_NAME = 'ThinkingTagExtractionMiddleware'
|
||||
|
||||
// 不同模型的思考标签配置
|
||||
const reasoningTags: TagConfig[] = [
|
||||
{ openingTag: '<think>', closingTag: '</think>', separator: '\n' },
|
||||
{ openingTag: '<thought>', closingTag: '</thought>', separator: '\n' },
|
||||
{ openingTag: '###Thinking', closingTag: '###Response', separator: '\n' },
|
||||
{ openingTag: '◁think▷', closingTag: '◁/think▷', separator: '\n' },
|
||||
{ openingTag: '<thinking>', closingTag: '</thinking>', separator: '\n' },
|
||||
{ openingTag: '<seed:think>', closingTag: '</seed:think>', separator: '\n' }
|
||||
]
|
||||
|
||||
const getAppropriateTag = (model?: Model): TagConfig => {
|
||||
const modelId = model?.id ? getLowerBaseModelName(model.id) : undefined
|
||||
if (modelId?.includes('qwen3')) return reasoningTags[0]
|
||||
if (modelId?.includes('gemini-2.5')) return reasoningTags[1]
|
||||
if (modelId?.includes('kimi-vl-a3b-thinking')) return reasoningTags[3]
|
||||
if (modelId?.includes('seed-oss-36b')) return reasoningTags[5]
|
||||
// 可以在这里添加更多模型特定的标签配置
|
||||
return reasoningTags[0] // 默认使用 <think> 标签
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理文本流中思考标签提取的中间件
|
||||
*
|
||||
* 该中间件专门处理文本流中的思考标签内容(如 <think>...</think>)
|
||||
* 主要用于 OpenAI 等支持思考标签的 provider
|
||||
*
|
||||
* 职责:
|
||||
* 1. 从文本流中提取思考标签内容
|
||||
* 2. 将标签内的内容转换为 THINKING_DELTA chunk
|
||||
* 3. 将标签外的内容作为正常文本输出
|
||||
* 4. 处理不同模型的思考标签格式
|
||||
* 5. 在思考内容结束时生成 THINKING_COMPLETE 事件
|
||||
*/
|
||||
export const ThinkingTagExtractionMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (context: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
// 调用下游中间件
|
||||
const result = await next(context, params)
|
||||
|
||||
// 响应后处理:处理思考标签提取
|
||||
if (result.stream) {
|
||||
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
|
||||
|
||||
// 检查是否有流需要处理
|
||||
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
|
||||
// 获取当前模型的思考标签配置
|
||||
const model = params.assistant?.model
|
||||
const reasoningTag = getAppropriateTag(model)
|
||||
|
||||
// 创建标签提取器
|
||||
const tagExtractor = new TagExtractor(reasoningTag)
|
||||
|
||||
// thinking 处理状态
|
||||
let hasThinkingContent = false
|
||||
let thinkingStartTime = 0
|
||||
|
||||
let accumulatingText = false
|
||||
let accumulatedThinkingContent = ''
|
||||
const processedStream = resultFromUpstream.pipeThrough(
|
||||
new TransformStream<GenericChunk, GenericChunk>({
|
||||
transform(chunk: GenericChunk, controller) {
|
||||
logger.silly('chunk', chunk)
|
||||
|
||||
if (chunk.type === ChunkType.TEXT_DELTA) {
|
||||
const textChunk = chunk as TextDeltaChunk
|
||||
|
||||
// 使用 TagExtractor 处理文本
|
||||
const extractionResults = tagExtractor.processText(textChunk.text)
|
||||
|
||||
for (const extractionResult of extractionResults) {
|
||||
if (extractionResult.complete && extractionResult.tagContentExtracted?.trim()) {
|
||||
// 完成思考
|
||||
// logger.silly(
|
||||
// 'since extractionResult.complete and extractionResult.tagContentExtracted is not empty, THINKING_COMPLETE chunk is generated'
|
||||
// )
|
||||
// 如果完成思考,更新状态
|
||||
accumulatingText = false
|
||||
|
||||
// 生成 THINKING_COMPLETE 事件
|
||||
const thinkingCompleteChunk: ThinkingCompleteChunk = {
|
||||
type: ChunkType.THINKING_COMPLETE,
|
||||
text: extractionResult.tagContentExtracted.trim(),
|
||||
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
|
||||
}
|
||||
controller.enqueue(thinkingCompleteChunk)
|
||||
|
||||
// 重置思考状态
|
||||
hasThinkingContent = false
|
||||
thinkingStartTime = 0
|
||||
} else if (extractionResult.content.length > 0) {
|
||||
// logger.silly(
|
||||
// 'since extractionResult.content is not empty, try to generate THINKING_START/THINKING_DELTA chunk'
|
||||
// )
|
||||
if (extractionResult.isTagContent) {
|
||||
// 如果提取到思考内容,更新状态
|
||||
accumulatingText = false
|
||||
|
||||
// 第一次接收到思考内容时记录开始时间
|
||||
if (!hasThinkingContent) {
|
||||
hasThinkingContent = true
|
||||
thinkingStartTime = Date.now()
|
||||
controller.enqueue({
|
||||
type: ChunkType.THINKING_START
|
||||
} as ThinkingStartChunk)
|
||||
}
|
||||
|
||||
if (extractionResult.content?.trim()) {
|
||||
accumulatedThinkingContent += extractionResult.content.trim()
|
||||
const thinkingDeltaChunk: ThinkingDeltaChunk = {
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: accumulatedThinkingContent,
|
||||
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
|
||||
}
|
||||
controller.enqueue(thinkingDeltaChunk)
|
||||
}
|
||||
} else {
|
||||
// 如果没有思考内容,直接输出文本
|
||||
// logger.silly(
|
||||
// 'since extractionResult.isTagContent is falsy, try to generate TEXT_START/TEXT_DELTA chunk'
|
||||
// )
|
||||
// 在非组成文本状态下接收到非思考内容时,生成 TEXT_START chunk 并更新状态
|
||||
if (!accumulatingText) {
|
||||
// logger.silly('since accumulatingText is false, TEXT_START chunk is generated')
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_START
|
||||
})
|
||||
accumulatingText = true
|
||||
}
|
||||
// 发送清理后的文本内容
|
||||
const cleanTextChunk: TextDeltaChunk = {
|
||||
...textChunk,
|
||||
text: extractionResult.content
|
||||
}
|
||||
controller.enqueue(cleanTextChunk)
|
||||
}
|
||||
} else {
|
||||
// logger.silly('since both condition is false, skip')
|
||||
}
|
||||
}
|
||||
} else if (chunk.type !== ChunkType.TEXT_START) {
|
||||
// logger.silly('since chunk.type is not TEXT_START and not TEXT_DELTA, pass through')
|
||||
|
||||
// logger.silly('since chunk.type is not TEXT_START and not TEXT_DELTA, accumulatingText is set to false')
|
||||
accumulatingText = false
|
||||
// 其他类型的chunk直接传递(包括 THINKING_DELTA, THINKING_COMPLETE 等)
|
||||
controller.enqueue(chunk)
|
||||
} else {
|
||||
// 接收到的 TEXT_START chunk 直接丢弃
|
||||
// logger.silly('since chunk.type is TEXT_START, passed')
|
||||
}
|
||||
},
|
||||
flush(controller) {
|
||||
// 处理可能剩余的思考内容
|
||||
const finalResult = tagExtractor.finalize()
|
||||
if (finalResult?.tagContentExtracted) {
|
||||
const thinkingCompleteChunk: ThinkingCompleteChunk = {
|
||||
type: ChunkType.THINKING_COMPLETE,
|
||||
text: finalResult.tagContentExtracted,
|
||||
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
|
||||
}
|
||||
controller.enqueue(thinkingCompleteChunk)
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
// 更新响应结果
|
||||
return {
|
||||
...result,
|
||||
stream: processedStream
|
||||
}
|
||||
} else {
|
||||
logger.warn(`[${MIDDLEWARE_NAME}] No generic chunk stream to process or not a ReadableStream.`)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@ -1,138 +0,0 @@
|
||||
import { loggerService } from '@logger'
|
||||
import type { MCPTool } from '@renderer/types'
|
||||
import type { MCPToolCreatedChunk, TextDeltaChunk } from '@renderer/types/chunk'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { parseToolUse } from '@renderer/utils/mcp-tools'
|
||||
import type { TagConfig } from '@renderer/utils/tagExtraction'
|
||||
import { TagExtractor } from '@renderer/utils/tagExtraction'
|
||||
|
||||
import type { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import type { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'ToolUseExtractionMiddleware'
|
||||
|
||||
const logger = loggerService.withContext('ToolUseExtractionMiddleware')
|
||||
|
||||
// 工具使用标签配置
|
||||
const TOOL_USE_TAG_CONFIG: TagConfig = {
|
||||
openingTag: '<tool_use>',
|
||||
closingTag: '</tool_use>',
|
||||
separator: '\n'
|
||||
}
|
||||
|
||||
/**
|
||||
* 工具使用提取中间件
|
||||
*
|
||||
* 职责:
|
||||
* 1. 从文本流中检测并提取 <tool_use></tool_use> 标签
|
||||
* 2. 解析工具调用信息并转换为 ToolUseResponse 格式
|
||||
* 3. 生成 MCP_TOOL_CREATED chunk 供 McpToolChunkMiddleware 处理
|
||||
* 4. 丢弃 tool_use 之后的所有内容(助手幻觉)
|
||||
* 5. 清理文本流,移除工具使用标签但保留正常文本
|
||||
*
|
||||
* 注意:此中间件只负责提取和转换,实际工具调用由 McpToolChunkMiddleware 处理
|
||||
*/
|
||||
export const ToolUseExtractionMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
const mcpTools = params.mcpTools || []
|
||||
|
||||
if (!mcpTools || mcpTools.length === 0) return next(ctx, params)
|
||||
|
||||
const result = await next(ctx, params)
|
||||
|
||||
if (result.stream) {
|
||||
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
|
||||
|
||||
const processedStream = resultFromUpstream.pipeThrough(createToolUseExtractionTransform(ctx, mcpTools))
|
||||
|
||||
return {
|
||||
...result,
|
||||
stream: processedStream
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建工具使用提取的 TransformStream
|
||||
*/
|
||||
function createToolUseExtractionTransform(
|
||||
_ctx: CompletionsContext,
|
||||
mcpTools: MCPTool[]
|
||||
): TransformStream<GenericChunk, GenericChunk> {
|
||||
const toolUseExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
|
||||
let hasAnyToolUse = false
|
||||
let toolCounter = 0
|
||||
|
||||
return new TransformStream({
|
||||
async transform(chunk: GenericChunk, controller) {
|
||||
try {
|
||||
// 处理文本内容,检测工具使用标签
|
||||
logger.silly('chunk', chunk)
|
||||
if (chunk.type === ChunkType.TEXT_DELTA) {
|
||||
const textChunk = chunk as TextDeltaChunk
|
||||
|
||||
// 处理 tool_use 标签
|
||||
const toolUseResults = toolUseExtractor.processText(textChunk.text)
|
||||
|
||||
for (const result of toolUseResults) {
|
||||
if (result.complete && result.tagContentExtracted) {
|
||||
// 提取到完整的工具使用内容,解析并转换为 SDK ToolCall 格式
|
||||
const toolUseResponses = parseToolUse(result.tagContentExtracted, mcpTools, toolCounter)
|
||||
toolCounter += toolUseResponses.length
|
||||
|
||||
if (toolUseResponses.length > 0) {
|
||||
// 生成 MCP_TOOL_CREATED chunk
|
||||
const mcpToolCreatedChunk: MCPToolCreatedChunk = {
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_use_responses: toolUseResponses
|
||||
}
|
||||
controller.enqueue(mcpToolCreatedChunk)
|
||||
|
||||
// 标记已有工具调用
|
||||
hasAnyToolUse = true
|
||||
}
|
||||
} else if (!result.isTagContent && result.content) {
|
||||
if (!hasAnyToolUse) {
|
||||
const cleanTextChunk: TextDeltaChunk = {
|
||||
...textChunk,
|
||||
text: result.content
|
||||
}
|
||||
controller.enqueue(cleanTextChunk)
|
||||
}
|
||||
}
|
||||
// tool_use 标签内的内容不转发,避免重复显示
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 转发其他所有chunk
|
||||
controller.enqueue(chunk)
|
||||
} catch (error) {
|
||||
logger.error('Error processing chunk:', error as Error)
|
||||
controller.error(error)
|
||||
}
|
||||
},
|
||||
|
||||
async flush(controller) {
|
||||
// 检查是否有未完成的 tool_use 标签内容
|
||||
const finalToolUseResult = toolUseExtractor.finalize()
|
||||
if (finalToolUseResult && finalToolUseResult.tagContentExtracted) {
|
||||
const toolUseResponses = parseToolUse(finalToolUseResult.tagContentExtracted, mcpTools, toolCounter)
|
||||
if (toolUseResponses.length > 0) {
|
||||
const mcpToolCreatedChunk: MCPToolCreatedChunk = {
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_use_responses: toolUseResponses
|
||||
}
|
||||
controller.enqueue(mcpToolCreatedChunk)
|
||||
hasAnyToolUse = true
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
export default ToolUseExtractionMiddleware
|
||||
@ -1,88 +0,0 @@
|
||||
import type { CompletionsMiddleware, MethodMiddleware } from './types'
|
||||
|
||||
// /**
|
||||
// * Wraps a provider instance with middlewares.
|
||||
// */
|
||||
// export function wrapProviderWithMiddleware(
|
||||
// apiClientInstance: BaseApiClient,
|
||||
// middlewareConfig: MiddlewareConfig
|
||||
// ): BaseApiClient {
|
||||
// console.log(`[wrapProviderWithMiddleware] Wrapping provider: ${apiClientInstance.provider?.id}`)
|
||||
// console.log(`[wrapProviderWithMiddleware] Middleware config:`, {
|
||||
// completions: middlewareConfig.completions?.length || 0,
|
||||
// methods: Object.keys(middlewareConfig.methods || {}).length
|
||||
// })
|
||||
|
||||
// // Cache for already wrapped methods to avoid re-wrapping on every access.
|
||||
// const wrappedMethodsCache = new Map<string, (...args: any[]) => Promise<any>>()
|
||||
|
||||
// const proxy = new Proxy(apiClientInstance, {
|
||||
// get(target, propKey, receiver) {
|
||||
// const methodName = typeof propKey === 'string' ? propKey : undefined
|
||||
|
||||
// if (!methodName) {
|
||||
// return Reflect.get(target, propKey, receiver)
|
||||
// }
|
||||
|
||||
// if (wrappedMethodsCache.has(methodName)) {
|
||||
// console.log(`[wrapProviderWithMiddleware] Using cached wrapped method: ${methodName}`)
|
||||
// return wrappedMethodsCache.get(methodName)
|
||||
// }
|
||||
|
||||
// const originalMethod = Reflect.get(target, propKey, receiver)
|
||||
|
||||
// // If the property is not a function, return it directly.
|
||||
// if (typeof originalMethod !== 'function') {
|
||||
// return originalMethod
|
||||
// }
|
||||
|
||||
// let wrappedMethod: ((...args: any[]) => Promise<any>) | undefined
|
||||
|
||||
// // Handle completions method
|
||||
// if (methodName === 'completions' && middlewareConfig.completions?.length) {
|
||||
// console.log(
|
||||
// `[wrapProviderWithMiddleware] Wrapping completions method with ${middlewareConfig.completions.length} middlewares`
|
||||
// )
|
||||
// const completionsOriginalMethod = originalMethod as (params: CompletionsParams) => Promise<any>
|
||||
// wrappedMethod = applyCompletionsMiddlewares(target, completionsOriginalMethod, middlewareConfig.completions)
|
||||
// }
|
||||
// // Handle other methods
|
||||
// else {
|
||||
// const methodMiddlewares = middlewareConfig.methods?.[methodName]
|
||||
// if (methodMiddlewares?.length) {
|
||||
// console.log(
|
||||
// `[wrapProviderWithMiddleware] Wrapping method ${methodName} with ${methodMiddlewares.length} middlewares`
|
||||
// )
|
||||
// const genericOriginalMethod = originalMethod as (...args: any[]) => Promise<any>
|
||||
// wrappedMethod = applyMethodMiddlewares(target, methodName, genericOriginalMethod, methodMiddlewares)
|
||||
// }
|
||||
// }
|
||||
|
||||
// if (wrappedMethod) {
|
||||
// console.log(`[wrapProviderWithMiddleware] Successfully wrapped method: ${methodName}`)
|
||||
// wrappedMethodsCache.set(methodName, wrappedMethod)
|
||||
// return wrappedMethod
|
||||
// }
|
||||
|
||||
// // If no middlewares are configured for this method, return the original method bound to the target. /
|
||||
// // 如果没有为此方法配置中间件,则返回绑定到目标的原始方法。
|
||||
// console.log(`[wrapProviderWithMiddleware] No middlewares for method ${methodName}, returning original`)
|
||||
// return originalMethod.bind(target)
|
||||
// }
|
||||
// })
|
||||
// return proxy as BaseApiClient
|
||||
// }
|
||||
|
||||
// Export types for external use
|
||||
export type { CompletionsMiddleware, MethodMiddleware }
|
||||
|
||||
// Export MiddlewareBuilder related types and classes
|
||||
export {
|
||||
CompletionsMiddlewareBuilder,
|
||||
createCompletionsBuilder,
|
||||
createMethodBuilder,
|
||||
MethodMiddlewareBuilder,
|
||||
MiddlewareBuilder,
|
||||
type MiddlewareExecutor,
|
||||
type NamedMiddleware
|
||||
} from './builder'
|
||||
@ -1,149 +0,0 @@
|
||||
import * as AbortHandlerModule from './common/AbortHandlerMiddleware'
|
||||
import * as ErrorHandlerModule from './common/ErrorHandlerMiddleware'
|
||||
import * as FinalChunkConsumerModule from './common/FinalChunkConsumerMiddleware'
|
||||
import * as LoggingModule from './common/LoggingMiddleware'
|
||||
import * as McpToolChunkModule from './core/McpToolChunkMiddleware'
|
||||
import * as RawStreamListenerModule from './core/RawStreamListenerMiddleware'
|
||||
import * as ResponseTransformModule from './core/ResponseTransformMiddleware'
|
||||
// import * as SdkCallModule from './core/SdkCallMiddleware'
|
||||
import * as StreamAdapterModule from './core/StreamAdapterMiddleware'
|
||||
import * as TextChunkModule from './core/TextChunkMiddleware'
|
||||
import * as ThinkChunkModule from './core/ThinkChunkMiddleware'
|
||||
import * as TransformCoreToSdkParamsModule from './core/TransformCoreToSdkParamsMiddleware'
|
||||
import * as WebSearchModule from './core/WebSearchMiddleware'
|
||||
import * as ImageGenerationModule from './feat/ImageGenerationMiddleware'
|
||||
import * as ThinkingTagExtractionModule from './feat/ThinkingTagExtractionMiddleware'
|
||||
import * as ToolUseExtractionMiddleware from './feat/ToolUseExtractionMiddleware'
|
||||
|
||||
/**
|
||||
* 中间件注册表 - 提供所有可用中间件的集中访问
|
||||
* 注意:目前中间件文件还未导出 MIDDLEWARE_NAME,会有 linter 错误,这是正常的
|
||||
*/
|
||||
export const MiddlewareRegistry = {
|
||||
[ErrorHandlerModule.MIDDLEWARE_NAME]: {
|
||||
name: ErrorHandlerModule.MIDDLEWARE_NAME,
|
||||
middleware: ErrorHandlerModule.ErrorHandlerMiddleware
|
||||
},
|
||||
// 通用中间件
|
||||
[AbortHandlerModule.MIDDLEWARE_NAME]: {
|
||||
name: AbortHandlerModule.MIDDLEWARE_NAME,
|
||||
middleware: AbortHandlerModule.AbortHandlerMiddleware
|
||||
},
|
||||
[FinalChunkConsumerModule.MIDDLEWARE_NAME]: {
|
||||
name: FinalChunkConsumerModule.MIDDLEWARE_NAME,
|
||||
middleware: FinalChunkConsumerModule.default
|
||||
},
|
||||
|
||||
// 核心流程中间件
|
||||
[TransformCoreToSdkParamsModule.MIDDLEWARE_NAME]: {
|
||||
name: TransformCoreToSdkParamsModule.MIDDLEWARE_NAME,
|
||||
middleware: TransformCoreToSdkParamsModule.TransformCoreToSdkParamsMiddleware
|
||||
},
|
||||
// [SdkCallModule.MIDDLEWARE_NAME]: {
|
||||
// name: SdkCallModule.MIDDLEWARE_NAME,
|
||||
// middleware: SdkCallModule.SdkCallMiddleware
|
||||
// },
|
||||
[StreamAdapterModule.MIDDLEWARE_NAME]: {
|
||||
name: StreamAdapterModule.MIDDLEWARE_NAME,
|
||||
middleware: StreamAdapterModule.StreamAdapterMiddleware
|
||||
},
|
||||
[RawStreamListenerModule.MIDDLEWARE_NAME]: {
|
||||
name: RawStreamListenerModule.MIDDLEWARE_NAME,
|
||||
middleware: RawStreamListenerModule.RawStreamListenerMiddleware
|
||||
},
|
||||
[ResponseTransformModule.MIDDLEWARE_NAME]: {
|
||||
name: ResponseTransformModule.MIDDLEWARE_NAME,
|
||||
middleware: ResponseTransformModule.ResponseTransformMiddleware
|
||||
},
|
||||
|
||||
// 特性处理中间件
|
||||
[ThinkingTagExtractionModule.MIDDLEWARE_NAME]: {
|
||||
name: ThinkingTagExtractionModule.MIDDLEWARE_NAME,
|
||||
middleware: ThinkingTagExtractionModule.ThinkingTagExtractionMiddleware
|
||||
},
|
||||
[ToolUseExtractionMiddleware.MIDDLEWARE_NAME]: {
|
||||
name: ToolUseExtractionMiddleware.MIDDLEWARE_NAME,
|
||||
middleware: ToolUseExtractionMiddleware.ToolUseExtractionMiddleware
|
||||
},
|
||||
[ThinkChunkModule.MIDDLEWARE_NAME]: {
|
||||
name: ThinkChunkModule.MIDDLEWARE_NAME,
|
||||
middleware: ThinkChunkModule.ThinkChunkMiddleware
|
||||
},
|
||||
[McpToolChunkModule.MIDDLEWARE_NAME]: {
|
||||
name: McpToolChunkModule.MIDDLEWARE_NAME,
|
||||
middleware: McpToolChunkModule.McpToolChunkMiddleware
|
||||
},
|
||||
[WebSearchModule.MIDDLEWARE_NAME]: {
|
||||
name: WebSearchModule.MIDDLEWARE_NAME,
|
||||
middleware: WebSearchModule.WebSearchMiddleware
|
||||
},
|
||||
[TextChunkModule.MIDDLEWARE_NAME]: {
|
||||
name: TextChunkModule.MIDDLEWARE_NAME,
|
||||
middleware: TextChunkModule.TextChunkMiddleware
|
||||
},
|
||||
[ImageGenerationModule.MIDDLEWARE_NAME]: {
|
||||
name: ImageGenerationModule.MIDDLEWARE_NAME,
|
||||
middleware: ImageGenerationModule.ImageGenerationMiddleware
|
||||
}
|
||||
} as const
|
||||
|
||||
/**
|
||||
* 根据名称获取中间件
|
||||
* @param name - 中间件名称
|
||||
* @returns 对应的中间件信息
|
||||
*/
|
||||
export function getMiddleware(name: string) {
|
||||
return MiddlewareRegistry[name]
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有注册的中间件名称
|
||||
* @returns 中间件名称列表
|
||||
*/
|
||||
export function getRegisteredMiddlewareNames(): string[] {
|
||||
return Object.keys(MiddlewareRegistry)
|
||||
}
|
||||
|
||||
/**
|
||||
* 默认的 Completions 中间件配置 - NamedMiddleware 格式,用于 MiddlewareBuilder
|
||||
*/
|
||||
export const DefaultCompletionsNamedMiddlewares = [
|
||||
MiddlewareRegistry[FinalChunkConsumerModule.MIDDLEWARE_NAME], // 最终消费者
|
||||
MiddlewareRegistry[ErrorHandlerModule.MIDDLEWARE_NAME], // 错误处理
|
||||
MiddlewareRegistry[TransformCoreToSdkParamsModule.MIDDLEWARE_NAME], // 参数转换
|
||||
MiddlewareRegistry[AbortHandlerModule.MIDDLEWARE_NAME], // 中止处理
|
||||
MiddlewareRegistry[McpToolChunkModule.MIDDLEWARE_NAME], // 工具处理
|
||||
MiddlewareRegistry[TextChunkModule.MIDDLEWARE_NAME], // 文本处理
|
||||
MiddlewareRegistry[WebSearchModule.MIDDLEWARE_NAME], // Web搜索处理
|
||||
MiddlewareRegistry[ToolUseExtractionMiddleware.MIDDLEWARE_NAME], // 工具使用提取处理
|
||||
MiddlewareRegistry[ThinkingTagExtractionModule.MIDDLEWARE_NAME], // 思考标签提取处理(特定provider)
|
||||
MiddlewareRegistry[ThinkChunkModule.MIDDLEWARE_NAME], // 思考处理(通用SDK)
|
||||
MiddlewareRegistry[ResponseTransformModule.MIDDLEWARE_NAME], // 响应转换
|
||||
MiddlewareRegistry[StreamAdapterModule.MIDDLEWARE_NAME], // 流适配器
|
||||
MiddlewareRegistry[RawStreamListenerModule.MIDDLEWARE_NAME] // 原始流监听器
|
||||
]
|
||||
|
||||
/**
|
||||
* 默认的通用方法中间件 - 例如翻译、摘要等
|
||||
*/
|
||||
export const DefaultMethodMiddlewares = {
|
||||
translate: [LoggingModule.createGenericLoggingMiddleware()],
|
||||
summaries: [LoggingModule.createGenericLoggingMiddleware()]
|
||||
}
|
||||
|
||||
/**
|
||||
* 导出所有中间件模块,方便外部使用
|
||||
*/
|
||||
export {
|
||||
AbortHandlerModule,
|
||||
FinalChunkConsumerModule,
|
||||
LoggingModule,
|
||||
McpToolChunkModule,
|
||||
ResponseTransformModule,
|
||||
StreamAdapterModule,
|
||||
TextChunkModule,
|
||||
ThinkChunkModule,
|
||||
ThinkingTagExtractionModule,
|
||||
TransformCoreToSdkParamsModule,
|
||||
WebSearchModule
|
||||
}
|
||||
@ -1,84 +0,0 @@
|
||||
import type { Assistant, MCPTool } from '@renderer/types'
|
||||
import type { Chunk } from '@renderer/types/chunk'
|
||||
import type { Message } from '@renderer/types/newMessage'
|
||||
import type { SdkRawChunk, SdkRawOutput } from '@renderer/types/sdk'
|
||||
|
||||
import type { ProcessingState } from './types'
|
||||
|
||||
// ============================================================================
|
||||
// Core Request Types - 核心请求结构
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* 标准化的内部核心请求结构,用于所有AI Provider的统一处理
|
||||
* 这是应用层参数转换后的标准格式,不包含回调函数和控制逻辑
|
||||
*/
|
||||
export interface CompletionsParams {
|
||||
/**
|
||||
* 调用的业务场景类型,用于中间件判断是否执行
|
||||
* 'chat': 主要对话流程
|
||||
* 'translate': 翻译
|
||||
* 'summary': 摘要
|
||||
* 'search': 搜索摘要
|
||||
* 'generate': 生成
|
||||
* 'check': API检查
|
||||
* 'test': 测试调用
|
||||
* 'translate-lang-detect': 翻译语言检测
|
||||
*/
|
||||
callType?: 'chat' | 'translate' | 'summary' | 'search' | 'generate' | 'check' | 'test' | 'translate-lang-detect'
|
||||
|
||||
// 基础对话数据
|
||||
messages: Message[] | string // 联合类型方便判断是否为空
|
||||
|
||||
assistant: Assistant // 助手为基本单位
|
||||
// model: Model
|
||||
|
||||
onChunk?: (chunk: Chunk) => void
|
||||
onResponse?: (text: string, isComplete: boolean) => void
|
||||
|
||||
// 错误相关
|
||||
onError?: (error: Error) => void
|
||||
shouldThrow?: boolean
|
||||
|
||||
// 工具相关
|
||||
mcpTools?: MCPTool[]
|
||||
|
||||
// 生成参数
|
||||
temperature?: number
|
||||
topP?: number
|
||||
maxTokens?: number
|
||||
|
||||
// 功能开关
|
||||
streamOutput: boolean
|
||||
enableWebSearch?: boolean
|
||||
enableUrlContext?: boolean
|
||||
enableReasoning?: boolean
|
||||
enableGenerateImage?: boolean
|
||||
|
||||
// 上下文控制
|
||||
contextCount?: number
|
||||
topicId?: string // 主题ID,用于关联上下文
|
||||
|
||||
// abort 控制
|
||||
abortKey?: string
|
||||
|
||||
_internal?: ProcessingState
|
||||
}
|
||||
|
||||
export interface CompletionsResult {
|
||||
rawOutput?: SdkRawOutput
|
||||
stream?: ReadableStream<SdkRawChunk> | ReadableStream<Chunk> | AsyncIterable<Chunk>
|
||||
controller?: AbortController
|
||||
|
||||
getText: () => string
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Generic Chunk Types - 通用数据块结构
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* 通用数据块类型
|
||||
* 复用现有的 Chunk 类型,这是所有AI Provider都应该输出的标准化数据块格式
|
||||
*/
|
||||
export type GenericChunk = Chunk
|
||||
@ -1,166 +0,0 @@
|
||||
import type { MCPToolResponse, Metrics, Usage, WebSearchResponse } from '@renderer/types'
|
||||
import type { Chunk, ErrorChunk } from '@renderer/types/chunk'
|
||||
import type {
|
||||
SdkInstance,
|
||||
SdkMessageParam,
|
||||
SdkParams,
|
||||
SdkRawChunk,
|
||||
SdkRawOutput,
|
||||
SdkTool,
|
||||
SdkToolCall
|
||||
} from '@renderer/types/sdk'
|
||||
|
||||
import type { BaseApiClient } from '../clients'
|
||||
import type { CompletionsParams, CompletionsResult } from './schemas'
|
||||
|
||||
/**
|
||||
* Symbol to uniquely identify middleware context objects.
|
||||
*/
|
||||
export const MIDDLEWARE_CONTEXT_SYMBOL = Symbol.for('AiProviderMiddlewareContext')
|
||||
|
||||
/**
|
||||
* Defines the structure for the onChunk callback function.
|
||||
*/
|
||||
export type OnChunkFunction = (chunk: Chunk | ErrorChunk) => void
|
||||
|
||||
/**
|
||||
* Base context that carries information about the current method call.
|
||||
*/
|
||||
export interface BaseContext {
|
||||
[MIDDLEWARE_CONTEXT_SYMBOL]: true
|
||||
methodName: string
|
||||
originalArgs: Readonly<any[]>
|
||||
}
|
||||
|
||||
/**
|
||||
* Processing state shared between middlewares.
|
||||
*/
|
||||
export interface ProcessingState<
|
||||
TParams extends SdkParams = SdkParams,
|
||||
TMessageParam extends SdkMessageParam = SdkMessageParam,
|
||||
TToolCall extends SdkToolCall = SdkToolCall
|
||||
> {
|
||||
sdkPayload?: TParams
|
||||
newReqMessages?: TMessageParam[]
|
||||
observer?: {
|
||||
usage?: Usage
|
||||
metrics?: Metrics
|
||||
}
|
||||
toolProcessingState?: {
|
||||
pendingToolCalls?: Array<TToolCall>
|
||||
executingToolCalls?: Array<{
|
||||
sdkToolCall: TToolCall
|
||||
mcpToolResponse: MCPToolResponse
|
||||
}>
|
||||
output?: SdkRawOutput | string
|
||||
isRecursiveCall?: boolean
|
||||
recursionDepth?: number
|
||||
}
|
||||
webSearchState?: {
|
||||
results?: WebSearchResponse
|
||||
}
|
||||
flowControl?: {
|
||||
abortController?: AbortController
|
||||
abortSignal?: AbortSignal
|
||||
cleanup?: () => void
|
||||
}
|
||||
enhancedDispatch?: (context: CompletionsContext, params: CompletionsParams) => Promise<CompletionsResult>
|
||||
customState?: Record<string, any>
|
||||
}
|
||||
|
||||
/**
|
||||
* Extended context for completions method.
|
||||
*/
|
||||
export interface CompletionsContext<
|
||||
TSdkParams extends SdkParams = SdkParams,
|
||||
TSdkMessageParam extends SdkMessageParam = SdkMessageParam,
|
||||
TSdkToolCall extends SdkToolCall = SdkToolCall,
|
||||
TSdkInstance extends SdkInstance = SdkInstance,
|
||||
TRawOutput extends SdkRawOutput = SdkRawOutput,
|
||||
TRawChunk extends SdkRawChunk = SdkRawChunk,
|
||||
TSdkSpecificTool extends SdkTool = SdkTool
|
||||
> extends BaseContext {
|
||||
readonly methodName: 'completions' // 强制方法名为 'completions'
|
||||
|
||||
apiClientInstance: BaseApiClient<
|
||||
TSdkInstance,
|
||||
TSdkParams,
|
||||
TRawOutput,
|
||||
TRawChunk,
|
||||
TSdkMessageParam,
|
||||
TSdkToolCall,
|
||||
TSdkSpecificTool
|
||||
>
|
||||
|
||||
// --- Mutable internal state for the duration of the middleware chain ---
|
||||
_internal: ProcessingState<TSdkParams, TSdkMessageParam, TSdkToolCall> // 包含所有可变的处理状态
|
||||
}
|
||||
|
||||
export interface MiddlewareAPI<Ctx extends BaseContext = BaseContext, Args extends any[] = any[]> {
|
||||
getContext: () => Ctx // Function to get the current context / 获取当前上下文的函数
|
||||
getOriginalArgs: () => Args // Function to get the original arguments of the method call / 获取方法调用原始参数的函数
|
||||
}
|
||||
|
||||
/**
|
||||
* Base middleware type.
|
||||
*/
|
||||
export type Middleware<TContext extends BaseContext> = (
|
||||
api: MiddlewareAPI<TContext>
|
||||
) => (
|
||||
next: (context: TContext, args: any[]) => Promise<unknown>
|
||||
) => (context: TContext, args: any[]) => Promise<unknown>
|
||||
|
||||
export type MethodMiddleware = Middleware<BaseContext>
|
||||
|
||||
/**
|
||||
* Completions middleware type.
|
||||
*/
|
||||
export type CompletionsMiddleware<
|
||||
TSdkParams extends SdkParams = SdkParams,
|
||||
TSdkMessageParam extends SdkMessageParam = SdkMessageParam,
|
||||
TSdkToolCall extends SdkToolCall = SdkToolCall,
|
||||
TSdkInstance extends SdkInstance = SdkInstance,
|
||||
TRawOutput extends SdkRawOutput = SdkRawOutput,
|
||||
TRawChunk extends SdkRawChunk = SdkRawChunk,
|
||||
TSdkSpecificTool extends SdkTool = SdkTool
|
||||
> = (
|
||||
api: MiddlewareAPI<
|
||||
CompletionsContext<
|
||||
TSdkParams,
|
||||
TSdkMessageParam,
|
||||
TSdkToolCall,
|
||||
TSdkInstance,
|
||||
TRawOutput,
|
||||
TRawChunk,
|
||||
TSdkSpecificTool
|
||||
>,
|
||||
[CompletionsParams]
|
||||
>
|
||||
) => (
|
||||
next: (
|
||||
context: CompletionsContext<
|
||||
TSdkParams,
|
||||
TSdkMessageParam,
|
||||
TSdkToolCall,
|
||||
TSdkInstance,
|
||||
TRawOutput,
|
||||
TRawChunk,
|
||||
TSdkSpecificTool
|
||||
>,
|
||||
params: CompletionsParams
|
||||
) => Promise<CompletionsResult>
|
||||
) => (
|
||||
context: CompletionsContext<
|
||||
TSdkParams,
|
||||
TSdkMessageParam,
|
||||
TSdkToolCall,
|
||||
TSdkInstance,
|
||||
TRawOutput,
|
||||
TRawChunk,
|
||||
TSdkSpecificTool
|
||||
>,
|
||||
params: CompletionsParams
|
||||
) => Promise<CompletionsResult>
|
||||
|
||||
// Re-export for convenience
|
||||
export type { Chunk as OnChunkArg } from '@renderer/types/chunk'
|
||||
@ -1,58 +0,0 @@
|
||||
import type { ErrorChunk } from '@renderer/types/chunk'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
|
||||
/**
|
||||
* Creates an ErrorChunk object with a standardized structure.
|
||||
* @param error The error object or message.
|
||||
* @param chunkType The type of chunk, defaults to ChunkType.ERROR.
|
||||
* @returns An ErrorChunk object.
|
||||
*/
|
||||
export function createErrorChunk(error: any, chunkType: ChunkType = ChunkType.ERROR): ErrorChunk {
|
||||
let errorDetails: Record<string, any> = {}
|
||||
|
||||
if (error instanceof Error) {
|
||||
errorDetails = {
|
||||
message: error.message,
|
||||
name: error.name,
|
||||
stack: error.stack
|
||||
}
|
||||
} else if (typeof error === 'string') {
|
||||
errorDetails = { message: error }
|
||||
} else if (typeof error === 'object' && error !== null) {
|
||||
errorDetails = Object.getOwnPropertyNames(error).reduce(
|
||||
(acc, key) => {
|
||||
acc[key] = error[key]
|
||||
return acc
|
||||
},
|
||||
{} as Record<string, any>
|
||||
)
|
||||
if (!errorDetails.message && error.toString && typeof error.toString === 'function') {
|
||||
const errMsg = error.toString()
|
||||
if (errMsg !== '[object Object]') {
|
||||
errorDetails.message = errMsg
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
type: chunkType,
|
||||
error: errorDetails
|
||||
} as ErrorChunk
|
||||
}
|
||||
|
||||
// Helper to capitalize method names for hook construction
|
||||
export function capitalize(str: string): string {
|
||||
if (!str) return ''
|
||||
return str.charAt(0).toUpperCase() + str.slice(1)
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查对象是否实现了AsyncIterable接口
|
||||
*/
|
||||
export function isAsyncIterable<T = unknown>(obj: unknown): obj is AsyncIterable<T> {
|
||||
return (
|
||||
obj !== null &&
|
||||
typeof obj === 'object' &&
|
||||
typeof (obj as Record<symbol, unknown>)[Symbol.asyncIterator] === 'function'
|
||||
)
|
||||
}
|
||||
@ -29,3 +29,11 @@ export type ProviderConfig<T extends StringKeys<AppProviderSettingsMap> = String
|
||||
|
||||
export type { AppProviderId, AppProviderSettingsMap, AppRuntimeConfig } from './merged'
|
||||
export { appProviderIds, getAllProviderIds, isRegisteredProviderId } from './merged'
|
||||
|
||||
/**
|
||||
* Result of completions operation
|
||||
* Simple interface with getText method to retrieve the generated text
|
||||
*/
|
||||
export type CompletionsResult = {
|
||||
getText: () => string
|
||||
}
|
||||
|
||||
@ -19,7 +19,7 @@ import { isToolUseModeFunction } from '@renderer/utils/assistant'
|
||||
import { isAbortError } from '@renderer/utils/error'
|
||||
import { purifyMarkdownImages } from '@renderer/utils/markdown'
|
||||
import { isPromptToolUse, isSupportedToolUse } from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { containsSupportedVariables, replacePromptVariables } from '@renderer/utils/prompt'
|
||||
import { NOT_SUPPORT_API_KEY_PROVIDER_TYPES, NOT_SUPPORT_API_KEY_PROVIDERS } from '@renderer/utils/provider'
|
||||
import { isEmpty, takeRight } from 'lodash'
|
||||
@ -35,6 +35,7 @@ import {
|
||||
getQuickModel
|
||||
} from './AssistantService'
|
||||
import { ConversationService } from './ConversationService'
|
||||
import FileManager from './FileManager'
|
||||
import { injectUserMessageWithKnowledgeSearchPrompt } from './KnowledgeService'
|
||||
import type { BlockManager } from './messageStreaming'
|
||||
import type { StreamProcessorCallbacks } from './StreamProcessingService'
|
||||
@ -167,6 +168,20 @@ export async function fetchChatCompletion({
|
||||
const AI = new AiProviderNew(assistant.model || getDefaultModel(), providerWithRotatedKey)
|
||||
const provider = AI.getActualProvider()
|
||||
|
||||
// 专用图像生成模型走 generateImage 路径
|
||||
if (isDedicatedImageGenerationModel(assistant.model || getDefaultModel())) {
|
||||
if (!uiMessages || uiMessages.length === 0) {
|
||||
throw new Error('uiMessages is required for dedicated image generation models')
|
||||
}
|
||||
await fetchImageGeneration({
|
||||
messages: uiMessages,
|
||||
assistant,
|
||||
onChunkReceived,
|
||||
aiProvider: AI
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
const mcpTools: MCPTool[] = []
|
||||
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
|
||||
|
||||
@ -225,6 +240,114 @@ export async function fetchChatCompletion({
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 从消息中收集图像(用于图像编辑)
|
||||
* 收集用户消息中上传的图像和助手消息中生成的图像
|
||||
*/
|
||||
async function collectImagesFromMessages(userMessage: Message, assistantMessage?: Message): Promise<string[]> {
|
||||
const images: string[] = []
|
||||
|
||||
// 收集用户消息中的图像
|
||||
const userImageBlocks = findImageBlocks(userMessage)
|
||||
for (const block of userImageBlocks) {
|
||||
if (block.file) {
|
||||
const base64 = await FileManager.readBase64File(block.file)
|
||||
const mimeType = block.file.type || 'image/png'
|
||||
images.push(`data:${mimeType};base64,${base64}`)
|
||||
}
|
||||
}
|
||||
|
||||
// 收集助手消息中的图像(用于继续编辑生成的图像)
|
||||
if (assistantMessage) {
|
||||
const assistantImageBlocks = findImageBlocks(assistantMessage)
|
||||
for (const block of assistantImageBlocks) {
|
||||
if (block.url) {
|
||||
images.push(block.url)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return images
|
||||
}
|
||||
|
||||
/**
|
||||
* 独立的图像生成函数
|
||||
* 专用于 DALL-E、GPT-Image-1 等专用图像生成模型
|
||||
*/
|
||||
async function fetchImageGeneration({
|
||||
messages,
|
||||
assistant,
|
||||
onChunkReceived,
|
||||
aiProvider
|
||||
}: {
|
||||
messages: Message[]
|
||||
assistant: Assistant
|
||||
onChunkReceived: (chunk: Chunk) => void
|
||||
aiProvider: AiProviderNew
|
||||
}) {
|
||||
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
|
||||
onChunkReceived({ type: ChunkType.IMAGE_CREATED })
|
||||
|
||||
const startTime = Date.now()
|
||||
|
||||
try {
|
||||
// 提取 prompt 和图像
|
||||
const lastUserMessage = messages.findLast((m) => m.role === 'user')
|
||||
const lastAssistantMessage = messages.findLast((m) => m.role === 'assistant')
|
||||
|
||||
if (!lastUserMessage) {
|
||||
throw new Error('No user message found for image generation.')
|
||||
}
|
||||
|
||||
const prompt = getMainTextContent(lastUserMessage)
|
||||
const inputImages = await collectImagesFromMessages(lastUserMessage, lastAssistantMessage)
|
||||
|
||||
// 调用 generateImage 或 editImage
|
||||
// 使用默认图像生成配置
|
||||
const imageSize = '1024x1024'
|
||||
const batchSize = 1
|
||||
|
||||
let images: string[]
|
||||
if (inputImages.length > 0) {
|
||||
images = await aiProvider.editImage({
|
||||
model: assistant.model!.id,
|
||||
prompt: prompt || '',
|
||||
inputImages,
|
||||
imageSize
|
||||
})
|
||||
} else {
|
||||
images = await aiProvider.generateImage({
|
||||
model: assistant.model!.id,
|
||||
prompt: prompt || '',
|
||||
imageSize,
|
||||
batchSize
|
||||
})
|
||||
}
|
||||
|
||||
// 发送结果 chunks
|
||||
const imageType = images[0]?.startsWith('data:') ? 'base64' : 'url'
|
||||
onChunkReceived({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: { type: imageType, images }
|
||||
})
|
||||
|
||||
onChunkReceived({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 },
|
||||
metrics: {
|
||||
completion_tokens: 0,
|
||||
time_first_token_millsec: 0,
|
||||
time_completion_millsec: Date.now() - startTime
|
||||
}
|
||||
}
|
||||
})
|
||||
} catch (error) {
|
||||
onChunkReceived({ type: ChunkType.ERROR, error: error as Error })
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
export async function fetchMessagesSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) {
|
||||
let prompt = (getStoreSetting('topicNamingPrompt') as string) || i18n.t('prompts.title')
|
||||
const model = getQuickModel() || assistant?.model || getDefaultModel()
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import { loggerService } from '@logger'
|
||||
import type { Span } from '@opentelemetry/api'
|
||||
import { ModernAiProvider } from '@renderer/aiCore'
|
||||
import AiProvider from '@renderer/aiCore/legacy'
|
||||
import ModernAiProvider from '@renderer/aiCore'
|
||||
import { getMessageContent } from '@renderer/aiCore/plugins/searchOrchestrationPlugin'
|
||||
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT, DEFAULT_KNOWLEDGE_THRESHOLD } from '@renderer/config/constant'
|
||||
import { getEmbeddingMaxContext } from '@renderer/config/embedings'
|
||||
@ -36,7 +35,7 @@ const logger = loggerService.withContext('RendererKnowledgeService')
|
||||
export const getKnowledgeBaseParams = (base: KnowledgeBase): KnowledgeBaseParams => {
|
||||
const rerankProvider = getProviderByModel(base.rerankModel)
|
||||
const aiProvider = new ModernAiProvider(base.model)
|
||||
const rerankAiProvider = new AiProvider(rerankProvider)
|
||||
const rerankAiProvider = new ModernAiProvider(rerankProvider)
|
||||
|
||||
// get preprocess provider from store instead of base.preprocessProvider
|
||||
const preprocessProvider = store
|
||||
|
||||
@ -5,7 +5,6 @@ import type { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
|
||||
import { cleanContext, endContext, getContext, startContext } from '@mcp-trace/trace-web'
|
||||
import type { Context, Span } from '@opentelemetry/api'
|
||||
import { context, SpanStatusCode, trace } from '@opentelemetry/api'
|
||||
import { isAsyncIterable } from '@renderer/aiCore/legacy/middleware/utils'
|
||||
import { db } from '@renderer/databases'
|
||||
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
|
||||
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
|
||||
@ -22,6 +21,11 @@ import type { SdkRawChunk } from '@renderer/types/sdk'
|
||||
|
||||
const logger = loggerService.withContext('SpanManagerService')
|
||||
|
||||
// Type guard for AsyncIterable
|
||||
function isAsyncIterable<T>(obj: any): obj is AsyncIterable<T> {
|
||||
return obj != null && typeof obj === 'object' && typeof obj[Symbol.asyncIterator] === 'function'
|
||||
}
|
||||
|
||||
class SpanManagerService {
|
||||
private spanMap: Map<string, ModelSpanEntity[]> = new Map()
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ import type OpenAI from '@cherrystudio/openai'
|
||||
import type { ChatCompletionChunk } from '@cherrystudio/openai/resources'
|
||||
import type { FunctionCall } from '@google/genai'
|
||||
import { FinishReason, MediaModality } from '@google/genai'
|
||||
import AiProvider from '@renderer/aiCore'
|
||||
import AiProvider from '@renderer/aiCore/legacy'
|
||||
import type { BaseApiClient, OpenAIAPIClient, ResponseChunkTransformerContext } from '@renderer/aiCore/legacy/clients'
|
||||
import type { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient'
|
||||
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
|
||||
@ -1,6 +1,6 @@
|
||||
import type { TokenUsage } from '@mcp-trace/trace-core'
|
||||
import type { Span } from '@opentelemetry/api'
|
||||
import type { CompletionsResult } from '@renderer/aiCore/legacy/middleware/schemas'
|
||||
import type { CompletionsResult } from '@renderer/aiCore/types'
|
||||
import { endSpan } from '@renderer/services/SpanManagerService'
|
||||
|
||||
export class CompletionsResultHandler {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user