mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-27 14:11:59 +08:00
refactor: remove obsolete middleware and README, add new plugins for reasoning and tool selection
- Deleted README.md for AiSdkMiddlewareBuilder as it was outdated. - Removed toolChoiceMiddleware.ts as it is no longer needed. - Introduced new plugins: - noThinkPlugin: Appends '/no_think' to user messages to prevent unnecessary thinking. - openrouterGenerateImagePlugin: Configures OpenRouter for image and text modalities. - openrouterReasoningPlugin: Redacts reasoning blocks in OpenRouter responses. - qwenThinkingPlugin: Controls thinking mode for Qwen models based on provider support. - reasoningExtractionPlugin: Extracts reasoning tags from OpenAI/Azure responses. - simulateStreamingPlugin: Converts non-streaming responses to streaming format. - skipGeminiThoughtSignaturePlugin: Skips Gemini3 thought signatures for multi-model requests. - Updated parameterBuilder.ts to correct type definitions. - Added middlewareConfig.ts for better middleware configuration management. - Enhanced reasoning utility functions for better tag name retrieval. - Updated ApiService.ts and aiCoreTypes.ts for consistency with new changes.
This commit is contained in:
parent
9b1165e073
commit
fca41ed966
@ -9,6 +9,7 @@
|
||||
"scripts": {
|
||||
"build": "tsdown",
|
||||
"dev": "tsc -w",
|
||||
"typecheck": "tsc --noEmit",
|
||||
"clean": "rm -rf dist",
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest"
|
||||
|
||||
124
packages/aiCore/src/core/errors/index.ts
Normal file
124
packages/aiCore/src/core/errors/index.ts
Normal file
@ -0,0 +1,124 @@
|
||||
/**
|
||||
* AI Core Error System
|
||||
* Unified error handling for the AI Core package
|
||||
*/
|
||||
|
||||
/**
|
||||
* Base error class for all AI Core errors
|
||||
* Provides structured error information with error codes, context, and cause tracking
|
||||
*/
|
||||
export class AiCoreError extends Error {
|
||||
constructor(
|
||||
public readonly code: string,
|
||||
message: string,
|
||||
public readonly context?: Record<string, unknown>,
|
||||
public readonly cause?: Error
|
||||
) {
|
||||
super(message)
|
||||
this.name = 'AiCoreError'
|
||||
|
||||
if (cause) {
|
||||
this.stack = `${this.stack}\nCaused by: ${cause.stack}`
|
||||
}
|
||||
}
|
||||
|
||||
toJSON() {
|
||||
return {
|
||||
name: this.name,
|
||||
code: this.code,
|
||||
message: this.message,
|
||||
context: this.context,
|
||||
cause: this.cause
|
||||
? {
|
||||
name: this.cause.name,
|
||||
message: this.cause.message
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursive depth limit exceeded error
|
||||
* Thrown when recursive calls exceed the maximum allowed depth
|
||||
*/
|
||||
export class RecursiveDepthError extends AiCoreError {
|
||||
constructor(requestId: string, currentDepth: number, maxDepth: number) {
|
||||
super('RECURSIVE_DEPTH_EXCEEDED', `Maximum recursive depth (${maxDepth}) exceeded at depth ${currentDepth}`, {
|
||||
requestId,
|
||||
currentDepth,
|
||||
maxDepth
|
||||
})
|
||||
this.name = 'RecursiveDepthError'
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Model resolution failure error
|
||||
* Thrown when a model ID cannot be resolved to a model instance
|
||||
*/
|
||||
export class ModelResolutionError extends AiCoreError {
|
||||
constructor(modelId: string, providerId: string, cause?: Error) {
|
||||
super('MODEL_RESOLUTION_FAILED', `Failed to resolve model: ${modelId}`, { modelId, providerId }, cause)
|
||||
this.name = 'ModelResolutionError'
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parameter validation error
|
||||
* Thrown when request parameters fail validation
|
||||
*/
|
||||
export class ParameterValidationError extends AiCoreError {
|
||||
constructor(paramName: string, reason: string, value?: unknown) {
|
||||
super('PARAMETER_VALIDATION_FAILED', `Invalid parameter '${paramName}': ${reason}`, {
|
||||
paramName,
|
||||
reason,
|
||||
value
|
||||
})
|
||||
this.name = 'ParameterValidationError'
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Plugin execution error
|
||||
* Thrown when a plugin fails during execution
|
||||
*/
|
||||
export class PluginExecutionError extends AiCoreError {
|
||||
constructor(pluginName: string, hookName: string, cause: Error) {
|
||||
super(
|
||||
'PLUGIN_EXECUTION_FAILED',
|
||||
`Plugin '${pluginName}' failed in hook '${hookName}'`,
|
||||
{
|
||||
pluginName,
|
||||
hookName
|
||||
},
|
||||
cause
|
||||
)
|
||||
this.name = 'PluginExecutionError'
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Provider configuration error
|
||||
* Thrown when provider settings are invalid or missing
|
||||
*/
|
||||
export class ProviderConfigError extends AiCoreError {
|
||||
constructor(providerId: string, reason: string) {
|
||||
super('PROVIDER_CONFIG_ERROR', `Provider '${providerId}' configuration error: ${reason}`, {
|
||||
providerId,
|
||||
reason
|
||||
})
|
||||
this.name = 'ProviderConfigError'
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Template loading error
|
||||
* Thrown when a template cannot be loaded
|
||||
*/
|
||||
export class TemplateLoadError extends AiCoreError {
|
||||
constructor(templateName: string, cause?: Error) {
|
||||
super('TEMPLATE_LOAD_FAILED', `Failed to load template: ${templateName}`, { templateName }, cause)
|
||||
this.name = 'TemplateLoadError'
|
||||
}
|
||||
}
|
||||
@ -7,3 +7,6 @@ export { globalModelResolver, ModelResolver } from './ModelResolver'
|
||||
|
||||
// 保留的类型定义(可能被其他地方使用)
|
||||
export type { ModelConfig as ModelConfigType } from './types'
|
||||
|
||||
// 模型工具函数
|
||||
export { hasModelId, isV2Model, isV3Model } from './utils'
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
/**
|
||||
* Creation 模块类型定义
|
||||
*/
|
||||
import type { LanguageModelV3Middleware } from '@ai-sdk/provider'
|
||||
import type { JSONObject, LanguageModelV3Middleware } from '@ai-sdk/provider'
|
||||
|
||||
import type { ProviderId, ProviderSettingsMap } from '../providers/types'
|
||||
|
||||
@ -10,6 +10,5 @@ export interface ModelConfig<T extends ProviderId = ProviderId> {
|
||||
modelId: string
|
||||
providerSettings: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }
|
||||
middlewares?: LanguageModelV3Middleware[]
|
||||
// 额外模型参数
|
||||
extraModelConfig?: Record<string, any>
|
||||
extraModelConfig?: JSONObject
|
||||
}
|
||||
|
||||
27
packages/aiCore/src/core/models/utils.ts
Normal file
27
packages/aiCore/src/core/models/utils.ts
Normal file
@ -0,0 +1,27 @@
|
||||
import type { LanguageModelV2, LanguageModelV3 } from '@ai-sdk/provider'
|
||||
|
||||
import type { AiSdkModel } from '../providers'
|
||||
|
||||
export const isV2Model = (model: AiSdkModel): model is LanguageModelV2 => {
|
||||
return typeof model === 'object' && model !== null && model.specificationVersion === 'v2'
|
||||
}
|
||||
|
||||
export const isV3Model = (model: AiSdkModel): model is LanguageModelV3 => {
|
||||
return typeof model === 'object' && model !== null && model.specificationVersion === 'v3'
|
||||
}
|
||||
|
||||
/**
|
||||
* Type guard to check if a model has a modelId property
|
||||
*/
|
||||
export const hasModelId = (model: unknown): model is { modelId: string } => {
|
||||
if (typeof model !== 'object' || model === null) {
|
||||
return false
|
||||
}
|
||||
|
||||
if (!('modelId' in model)) {
|
||||
return false
|
||||
}
|
||||
|
||||
const obj = model as Record<string, unknown>
|
||||
return typeof obj.modelId === 'string'
|
||||
}
|
||||
@ -1,87 +0,0 @@
|
||||
import { streamText } from 'ai'
|
||||
|
||||
import {
|
||||
createAnthropicOptions,
|
||||
createGenericProviderOptions,
|
||||
createGoogleOptions,
|
||||
createOpenAIOptions,
|
||||
mergeProviderOptions
|
||||
} from './factory'
|
||||
|
||||
// 示例1: 使用已知供应商的严格类型约束
|
||||
export function exampleOpenAIWithOptions() {
|
||||
const openaiOptions = createOpenAIOptions({
|
||||
reasoningEffort: 'medium'
|
||||
})
|
||||
|
||||
// 这里会有类型检查,确保选项符合OpenAI的设置
|
||||
return streamText({
|
||||
model: {} as any, // 实际使用时替换为真实模型
|
||||
prompt: 'Hello',
|
||||
providerOptions: openaiOptions
|
||||
})
|
||||
}
|
||||
|
||||
// 示例2: 使用Anthropic供应商选项
|
||||
export function exampleAnthropicWithOptions() {
|
||||
const anthropicOptions = createAnthropicOptions({
|
||||
thinking: {
|
||||
type: 'enabled',
|
||||
budgetTokens: 1000
|
||||
}
|
||||
})
|
||||
|
||||
return streamText({
|
||||
model: {} as any,
|
||||
prompt: 'Hello',
|
||||
providerOptions: anthropicOptions
|
||||
})
|
||||
}
|
||||
|
||||
// 示例3: 使用Google供应商选项
|
||||
export function exampleGoogleWithOptions() {
|
||||
const googleOptions = createGoogleOptions({
|
||||
thinkingConfig: {
|
||||
includeThoughts: true,
|
||||
thinkingBudget: 1000
|
||||
}
|
||||
})
|
||||
|
||||
return streamText({
|
||||
model: {} as any,
|
||||
prompt: 'Hello',
|
||||
providerOptions: googleOptions
|
||||
})
|
||||
}
|
||||
|
||||
// 示例4: 使用未知供应商(通用类型)
|
||||
export function exampleUnknownProviderWithOptions() {
|
||||
const customProviderOptions = createGenericProviderOptions('custom-provider', {
|
||||
temperature: 0.7,
|
||||
customSetting: 'value',
|
||||
anotherOption: true
|
||||
})
|
||||
|
||||
return streamText({
|
||||
model: {} as any,
|
||||
prompt: 'Hello',
|
||||
providerOptions: customProviderOptions
|
||||
})
|
||||
}
|
||||
|
||||
// 示例5: 合并多个供应商选项
|
||||
export function exampleMergedOptions() {
|
||||
const openaiOptions = createOpenAIOptions({})
|
||||
|
||||
const customOptions = createGenericProviderOptions('custom', {
|
||||
customParam: 'value'
|
||||
})
|
||||
|
||||
const mergedOptions = mergeProviderOptions(openaiOptions, customOptions)
|
||||
|
||||
return streamText({
|
||||
model: {} as any,
|
||||
prompt: 'Hello',
|
||||
providerOptions: mergedOptions
|
||||
})
|
||||
}
|
||||
@ -1,7 +1,6 @@
|
||||
import { google } from '@ai-sdk/google'
|
||||
|
||||
import { definePlugin } from '../../'
|
||||
import type { AiRequestContext } from '../../types'
|
||||
import { type AiPlugin, definePlugin, type StreamTextParams, type StreamTextResult } from '../../'
|
||||
|
||||
const toolNameMap = {
|
||||
googleSearch: 'google_search',
|
||||
@ -12,27 +11,40 @@ const toolNameMap = {
|
||||
type ToolConfigKey = keyof typeof toolNameMap
|
||||
type ToolConfig = { googleSearch?: boolean; urlContext?: boolean; codeExecution?: boolean }
|
||||
|
||||
export const googleToolsPlugin = (config?: ToolConfig) =>
|
||||
definePlugin({
|
||||
export const googleToolsPlugin = (config?: ToolConfig): AiPlugin<StreamTextParams, StreamTextResult> =>
|
||||
definePlugin<StreamTextParams, StreamTextResult>({
|
||||
name: 'googleToolsPlugin',
|
||||
transformParams: <T>(params: T, context: AiRequestContext): T => {
|
||||
transformParams: (params, context) => {
|
||||
const { providerId } = context
|
||||
if (providerId === 'google' && config) {
|
||||
if (typeof params === 'object' && params !== null) {
|
||||
const typedParams = params as T & { tools?: Record<string, unknown> }
|
||||
|
||||
if (!typedParams.tools) {
|
||||
typedParams.tools = {}
|
||||
}
|
||||
// 使用类型安全的方式遍历配置
|
||||
;(Object.keys(config) as ToolConfigKey[]).forEach((key) => {
|
||||
if (config[key] && key in toolNameMap && key in google.tools) {
|
||||
const toolName = toolNameMap[key]
|
||||
typedParams.tools![toolName] = google.tools[key]({})
|
||||
}
|
||||
})
|
||||
}
|
||||
// 只在 Google provider 且有配置时才修改参数
|
||||
if (providerId !== 'google' || !config) {
|
||||
return {} // 返回空 Partial,表示不修改
|
||||
}
|
||||
return params
|
||||
|
||||
if (typeof params !== 'object' || params === null) {
|
||||
return {}
|
||||
}
|
||||
|
||||
// 构建 tools 对象,确保类型兼容
|
||||
const hasTools = (Object.keys(config) as ToolConfigKey[]).some(
|
||||
(key) => config[key] && key in toolNameMap && key in google.tools
|
||||
)
|
||||
|
||||
if (!hasTools) {
|
||||
return {} // 返回空 Partial,表示不修改
|
||||
}
|
||||
|
||||
// 构建符合 AI SDK 的 tools 对象
|
||||
const tools: Record<string, ReturnType<(typeof google.tools)[keyof typeof google.tools]>> = {}
|
||||
|
||||
;(Object.keys(config) as ToolConfigKey[]).forEach((key) => {
|
||||
if (config[key] && key in toolNameMap && key in google.tools) {
|
||||
const toolName = toolNameMap[key]
|
||||
tools[toolName] = google.tools[key]({})
|
||||
}
|
||||
})
|
||||
|
||||
return { tools: tools }
|
||||
}
|
||||
})
|
||||
|
||||
@ -4,11 +4,61 @@
|
||||
* 负责处理 AI SDK 流事件的发送和管理
|
||||
* 从 promptToolUsePlugin.ts 中提取出来以降低复杂度
|
||||
*/
|
||||
import type { ModelMessage } from 'ai'
|
||||
import type { EmbeddingModelUsage, ImageModelUsage, LanguageModelUsage, ModelMessage } from 'ai'
|
||||
|
||||
import type { AiRequestContext } from '../../types'
|
||||
import type { AiSdkUsage } from '../../../providers/types'
|
||||
import type { AiRequestContext, StreamTextParams, StreamTextResult } from '../../types'
|
||||
import type { StreamController } from './ToolExecutor'
|
||||
|
||||
/**
|
||||
* 类型守卫:检查对象是否是有效的流结果(包含 ReadableStream 类型的 fullStream)
|
||||
*/
|
||||
function hasFullStream(obj: unknown): obj is StreamTextResult & { fullStream: ReadableStream } {
|
||||
return typeof obj === 'object' && obj !== null && 'fullStream' in obj && obj.fullStream instanceof ReadableStream
|
||||
}
|
||||
|
||||
/**
|
||||
* 类型守卫:检查 usage 是否是 LanguageModelUsage
|
||||
* LanguageModelUsage 包含 totalTokens, inputTokens, outputTokens 等字段
|
||||
*/
|
||||
function isLanguageModelUsage(usage: unknown): usage is LanguageModelUsage {
|
||||
return (
|
||||
typeof usage === 'object' &&
|
||||
usage !== null &&
|
||||
('totalTokens' in usage || 'inputTokens' in usage || 'outputTokens' in usage)
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 类型守卫:检查 usage 是否是 ImageModelUsage
|
||||
* ImageModelUsage 包含 inputTokens, outputTokens, totalTokens 字段
|
||||
*/
|
||||
function isImageModelUsage(usage: unknown): usage is ImageModelUsage {
|
||||
return (
|
||||
typeof usage === 'object' &&
|
||||
usage !== null &&
|
||||
'inputTokens' in usage &&
|
||||
'outputTokens' in usage &&
|
||||
// 确保不是 LanguageModelUsage(LanguageModelUsage 可能有 reasoningTokens 等额外字段)
|
||||
!('reasoningTokens' in usage)
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 类型守卫:检查 usage 是否是 EmbeddingModelUsage
|
||||
* EmbeddingModelUsage 只包含 tokens 字段
|
||||
*/
|
||||
function isEmbeddingModelUsage(usage: unknown): usage is EmbeddingModelUsage {
|
||||
return (
|
||||
typeof usage === 'object' &&
|
||||
usage !== null &&
|
||||
'tokens' in usage &&
|
||||
// 确保只有 tokens 字段(没有 inputTokens, outputTokens 等)
|
||||
!('inputTokens' in usage) &&
|
||||
!('outputTokens' in usage)
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 流事件管理器类
|
||||
*/
|
||||
@ -50,10 +100,10 @@ export class StreamEventManager {
|
||||
/**
|
||||
* 处理递归调用并将结果流接入当前流
|
||||
*/
|
||||
async handleRecursiveCall(
|
||||
async handleRecursiveCall<TParams extends StreamTextParams>(
|
||||
controller: StreamController,
|
||||
recursiveParams: any,
|
||||
context: AiRequestContext
|
||||
recursiveParams: Partial<TParams>,
|
||||
context: AiRequestContext<TParams, StreamTextResult>
|
||||
): Promise<void> {
|
||||
// try {
|
||||
// 重置工具执行状态,准备处理新的步骤
|
||||
@ -61,7 +111,7 @@ export class StreamEventManager {
|
||||
|
||||
const recursiveResult = await context.recursiveCall(recursiveParams)
|
||||
|
||||
if (recursiveResult && recursiveResult.fullStream) {
|
||||
if (hasFullStream(recursiveResult)) {
|
||||
await this.pipeRecursiveStream(controller, recursiveResult.fullStream)
|
||||
} else {
|
||||
console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
|
||||
@ -97,36 +147,20 @@ export class StreamEventManager {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理递归调用错误
|
||||
*/
|
||||
// private handleRecursiveCallError(controller: StreamController, error: unknown): void {
|
||||
// console.error('[MCP Prompt] Recursive call failed:', error)
|
||||
|
||||
// // 使用 AI SDK 标准错误格式,但不中断流
|
||||
// controller.enqueue({
|
||||
// type: 'error',
|
||||
// error: {
|
||||
// message: error instanceof Error ? error.message : String(error),
|
||||
// name: error instanceof Error ? error.name : 'RecursiveCallError'
|
||||
// }
|
||||
// })
|
||||
|
||||
// // // 继续发送文本增量,保持流的连续性
|
||||
// // controller.enqueue({
|
||||
// // type: 'text-delta',
|
||||
// // id: stepId,
|
||||
// // text: '\n\n[工具执行后递归调用失败,继续对话...]'
|
||||
// // })
|
||||
// }
|
||||
|
||||
/**
|
||||
* 构建递归调用的参数
|
||||
*/
|
||||
buildRecursiveParams(context: AiRequestContext, textBuffer: string, toolResultsText: string, tools: any): any {
|
||||
buildRecursiveParams<TParams extends StreamTextParams>(
|
||||
context: AiRequestContext<TParams, StreamTextResult>,
|
||||
textBuffer: string,
|
||||
toolResultsText: string,
|
||||
tools: any
|
||||
): Partial<TParams> {
|
||||
const params = context.originalParams
|
||||
|
||||
// 构建新的对话消息
|
||||
const newMessages: ModelMessage[] = [
|
||||
...(context.originalParams.messages || []),
|
||||
...(params.messages || []),
|
||||
// 只有当 textBuffer 有内容时才添加 assistant 消息,避免空消息导致 API 错误
|
||||
...(textBuffer ? [{ role: 'assistant' as const, content: textBuffer }] : []),
|
||||
{
|
||||
@ -137,28 +171,47 @@ export class StreamEventManager {
|
||||
|
||||
// 递归调用,继续对话,重新传递 tools
|
||||
const recursiveParams = {
|
||||
...context.originalParams,
|
||||
...params,
|
||||
messages: newMessages,
|
||||
tools: tools
|
||||
}
|
||||
|
||||
// 更新上下文中的消息
|
||||
context.originalParams.messages = newMessages
|
||||
} as Partial<TParams>
|
||||
|
||||
return recursiveParams
|
||||
}
|
||||
|
||||
/**
|
||||
* 累加 usage 数据
|
||||
*
|
||||
* 使用类型守卫来处理不同类型的 usage(LanguageModelUsage, ImageModelUsage, EmbeddingModelUsage)
|
||||
* - LanguageModelUsage: inputTokens, outputTokens, totalTokens
|
||||
* - ImageModelUsage: inputTokens, outputTokens, totalTokens
|
||||
* - EmbeddingModelUsage: tokens
|
||||
*/
|
||||
accumulateUsage(target: any, source: any): void {
|
||||
accumulateUsage(target: Partial<AiSdkUsage>, source: Partial<AiSdkUsage>): void {
|
||||
if (!target || !source) return
|
||||
|
||||
// 累加各种 token 类型
|
||||
target.inputTokens = (target.inputTokens || 0) + (source.inputTokens || 0)
|
||||
target.outputTokens = (target.outputTokens || 0) + (source.outputTokens || 0)
|
||||
target.totalTokens = (target.totalTokens || 0) + (source.totalTokens || 0)
|
||||
target.reasoningTokens = (target.reasoningTokens || 0) + (source.reasoningTokens || 0)
|
||||
target.cachedInputTokens = (target.cachedInputTokens || 0) + (source.cachedInputTokens || 0)
|
||||
if (isLanguageModelUsage(target) && isLanguageModelUsage(source)) {
|
||||
target.totalTokens = (target.totalTokens || 0) + (source.totalTokens || 0)
|
||||
target.inputTokens = (target.inputTokens || 0) + (source.inputTokens || 0)
|
||||
target.outputTokens = (target.outputTokens || 0) + (source.outputTokens || 0)
|
||||
return
|
||||
}
|
||||
if (isImageModelUsage(target) && isImageModelUsage(source)) {
|
||||
target.totalTokens = (target.totalTokens || 0) + (source.totalTokens || 0)
|
||||
target.inputTokens = (target.inputTokens || 0) + (source.inputTokens || 0)
|
||||
target.outputTokens = (target.outputTokens || 0) + (source.outputTokens || 0)
|
||||
return
|
||||
}
|
||||
|
||||
if (isEmbeddingModelUsage(target) && isEmbeddingModelUsage(source)) {
|
||||
target.tokens = (target.tokens || 0) + (source.tokens || 0)
|
||||
return
|
||||
}
|
||||
|
||||
// ⚠️ 未知类型或类型不匹配,不进行累加
|
||||
console.warn('[StreamEventManager] Unable to accumulate usage - type mismatch or unknown type', {
|
||||
target,
|
||||
source
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -6,7 +6,7 @@
|
||||
import type { TextStreamPart, ToolSet } from 'ai'
|
||||
|
||||
import { definePlugin } from '../../index'
|
||||
import type { AiRequestContext } from '../../types'
|
||||
import type { AiPlugin, StreamTextParams, StreamTextResult } from '../../types'
|
||||
import { StreamEventManager } from './StreamEventManager'
|
||||
import { type TagConfig, TagExtractor } from './tagExtraction'
|
||||
import { ToolExecutor } from './ToolExecutor'
|
||||
@ -254,23 +254,25 @@ function defaultParseToolUse(content: string, tools: ToolSet): { results: ToolUs
|
||||
return { results, content: contentToProcess }
|
||||
}
|
||||
|
||||
export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
||||
export const createPromptToolUsePlugin = (
|
||||
config: PromptToolUseConfig = {}
|
||||
): AiPlugin<StreamTextParams, StreamTextResult> => {
|
||||
const { enabled = true, buildSystemPrompt = defaultBuildSystemPrompt, parseToolUse = defaultParseToolUse } = config
|
||||
|
||||
return definePlugin({
|
||||
return definePlugin<StreamTextParams, StreamTextResult>({
|
||||
name: 'built-in:prompt-tool-use',
|
||||
transformParams: (params: any, context: AiRequestContext) => {
|
||||
transformParams: (params, context) => {
|
||||
if (!enabled || !params.tools || typeof params.tools !== 'object') {
|
||||
return params
|
||||
}
|
||||
|
||||
// 分离 provider-defined 和其他类型的工具
|
||||
// 分离 provider 和其他类型的工具
|
||||
const providerDefinedTools: ToolSet = {}
|
||||
const promptTools: ToolSet = {}
|
||||
|
||||
for (const [toolName, tool] of Object.entries(params.tools as ToolSet)) {
|
||||
if (tool.type === 'provider-defined') {
|
||||
// provider-defined 类型的工具保留在 tools 参数中
|
||||
if (tool.type === 'provider') {
|
||||
// provider 类型的工具保留在 tools 参数中
|
||||
providerDefinedTools[toolName] = tool
|
||||
} else {
|
||||
// 其他工具转换为 prompt 模式
|
||||
@ -278,12 +280,12 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
||||
}
|
||||
}
|
||||
|
||||
// 只有当有非 provider-defined 工具时才保存到 context
|
||||
// 只有当有非 provider 工具时才保存到 context
|
||||
if (Object.keys(promptTools).length > 0) {
|
||||
context.mcpTools = promptTools
|
||||
}
|
||||
|
||||
// 构建系统提示符(只包含非 provider-defined 工具)
|
||||
// 构建系统提示符(只包含非 provider 工具)
|
||||
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
|
||||
const systemPrompt = buildSystemPrompt(userSystemPrompt, promptTools)
|
||||
let systemMessage: string | null = systemPrompt
|
||||
@ -292,7 +294,7 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
||||
systemMessage = config.createSystemMessage(systemPrompt, params, context)
|
||||
}
|
||||
|
||||
// 保留 provider-defined tools,移除其他 tools
|
||||
// 保留 provide tools,移除其他 tools
|
||||
const transformedParams = {
|
||||
...params,
|
||||
...(systemMessage ? { system: systemMessage } : {}),
|
||||
@ -301,7 +303,7 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
||||
context.originalParams = transformedParams
|
||||
return transformedParams
|
||||
},
|
||||
transformStream: (_: any, context: AiRequestContext) => () => {
|
||||
transformStream: (_, context) => () => {
|
||||
let textBuffer = ''
|
||||
// let stepId = ''
|
||||
|
||||
|
||||
@ -1,7 +1,17 @@
|
||||
// 核心类型和接口
|
||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './types'
|
||||
import type { ImageModelV3 } from '@ai-sdk/provider'
|
||||
import type { LanguageModel } from 'ai'
|
||||
export type {
|
||||
AiPlugin,
|
||||
AiRequestContext,
|
||||
AiRequestMetadata,
|
||||
GenerateTextParams,
|
||||
GenerateTextResult,
|
||||
HookResult,
|
||||
PluginManagerConfig,
|
||||
RecursiveCallFn,
|
||||
StreamTextParams,
|
||||
StreamTextResult
|
||||
} from './types'
|
||||
import type { ImageModel, LanguageModel } from 'ai'
|
||||
|
||||
import type { ProviderId } from '../providers'
|
||||
import type { AiPlugin, AiRequestContext } from './types'
|
||||
@ -10,11 +20,11 @@ import type { AiPlugin, AiRequestContext } from './types'
|
||||
export { PluginManager } from './manager'
|
||||
|
||||
// 工具函数
|
||||
export function createContext<T extends ProviderId>(
|
||||
export function createContext<T extends ProviderId, TParams = unknown, TResult = unknown>(
|
||||
providerId: T,
|
||||
model: LanguageModel | ImageModelV3,
|
||||
originalParams: any
|
||||
): AiRequestContext {
|
||||
model: LanguageModel | ImageModel,
|
||||
originalParams: TParams
|
||||
): AiRequestContext<TParams, TResult> {
|
||||
return {
|
||||
providerId,
|
||||
model,
|
||||
@ -22,14 +32,28 @@ export function createContext<T extends ProviderId>(
|
||||
metadata: {},
|
||||
startTime: Date.now(),
|
||||
requestId: `${providerId}-${typeof model === 'string' ? model : model?.modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
|
||||
// 占位
|
||||
recursiveCall: () => Promise.resolve(null)
|
||||
isRecursiveCall: false,
|
||||
recursiveDepth: 0, // 初始化递归深度为 0
|
||||
maxRecursiveDepth: 10, // 默认最大递归深度为 10
|
||||
extensions: new Map(),
|
||||
middlewares: [],
|
||||
// 占位递归调用函数,实际使用时会被 PluginEngine 替换
|
||||
recursiveCall: () => Promise.resolve(null as any)
|
||||
}
|
||||
}
|
||||
|
||||
// 插件构建器 - 便于创建插件
|
||||
|
||||
// 重载 1: 泛型插件(显式指定类型参数)
|
||||
export function definePlugin<TParams, TResult>(plugin: AiPlugin<TParams, TResult>): AiPlugin<TParams, TResult>
|
||||
|
||||
// 重载 2: 非泛型插件(默认 unknown)
|
||||
export function definePlugin(plugin: AiPlugin): AiPlugin
|
||||
|
||||
// 重载 3: 插件工厂函数
|
||||
export function definePlugin<T extends (...args: any[]) => AiPlugin>(pluginFactory: T): T
|
||||
|
||||
// 实现
|
||||
export function definePlugin(plugin: AiPlugin | ((...args: any[]) => AiPlugin)) {
|
||||
return plugin
|
||||
}
|
||||
|
||||
@ -1,19 +1,20 @@
|
||||
import type { AiPlugin, AiRequestContext } from './types'
|
||||
|
||||
/**
|
||||
* 插件管理器
|
||||
* 插件管理器(泛型化)
|
||||
* 支持类型安全的插件管理,同时通过逆变保持灵活性
|
||||
*/
|
||||
export class PluginManager {
|
||||
private plugins: AiPlugin[] = []
|
||||
export class PluginManager<TParams = unknown, TResult = unknown> {
|
||||
private plugins: AiPlugin<TParams, TResult>[] = []
|
||||
|
||||
constructor(plugins: AiPlugin[] = []) {
|
||||
constructor(plugins: AiPlugin<TParams, TResult>[] = []) {
|
||||
this.plugins = this.sortPlugins(plugins)
|
||||
}
|
||||
|
||||
/**
|
||||
* 添加插件
|
||||
* 添加插件(支持逆变:AiPlugin<unknown> 可赋值给 AiPlugin<TParams>)
|
||||
*/
|
||||
use(plugin: AiPlugin): this {
|
||||
use(plugin: AiPlugin<TParams, TResult>): this {
|
||||
this.plugins = this.sortPlugins([...this.plugins, plugin])
|
||||
return this
|
||||
}
|
||||
@ -29,10 +30,10 @@ export class PluginManager {
|
||||
/**
|
||||
* 插件排序:pre -> normal -> post
|
||||
*/
|
||||
private sortPlugins(plugins: AiPlugin[]): AiPlugin[] {
|
||||
const pre: AiPlugin[] = []
|
||||
const normal: AiPlugin[] = []
|
||||
const post: AiPlugin[] = []
|
||||
private sortPlugins(plugins: AiPlugin<TParams, TResult>[]): AiPlugin<TParams, TResult>[] {
|
||||
const pre: AiPlugin<TParams, TResult>[] = []
|
||||
const normal: AiPlugin<TParams, TResult>[] = []
|
||||
const post: AiPlugin<TParams, TResult>[] = []
|
||||
|
||||
plugins.forEach((plugin) => {
|
||||
if (plugin.enforce === 'pre') {
|
||||
@ -53,7 +54,7 @@ export class PluginManager {
|
||||
async executeFirst<T>(
|
||||
hookName: 'resolveModel' | 'loadTemplate',
|
||||
arg: any,
|
||||
context: AiRequestContext
|
||||
context: AiRequestContext<TParams, TResult>
|
||||
): Promise<T | null> {
|
||||
for (const plugin of this.plugins) {
|
||||
const hook = plugin[hookName]
|
||||
@ -68,19 +69,42 @@ export class PluginManager {
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行 Sequential 钩子 - 链式数据转换
|
||||
* 执行 transformParams 钩子 - 链式参数转换
|
||||
* 每个插件返回 Partial<TParams>,逐步合并到原始参数
|
||||
*/
|
||||
async executeSequential<T>(
|
||||
hookName: 'transformParams' | 'transformResult',
|
||||
initialValue: T,
|
||||
context: AiRequestContext
|
||||
): Promise<T> {
|
||||
async executeTransformParams(
|
||||
initialValue: TParams,
|
||||
context: AiRequestContext<TParams, TResult>
|
||||
): Promise<TParams> {
|
||||
let result = initialValue
|
||||
|
||||
for (const plugin of this.plugins) {
|
||||
const hook = plugin[hookName]
|
||||
if (hook) {
|
||||
result = await hook<T>(result, context)
|
||||
if (plugin.transformParams) {
|
||||
const partial = await plugin.transformParams(result, context)
|
||||
// 合并 Partial 到现有参数
|
||||
result = { ...result, ...partial }
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行 transformResult 钩子 - 链式结果转换
|
||||
* 每个插件接收并返回完整的 TResult
|
||||
*/
|
||||
async executeTransformResult(
|
||||
initialValue: TResult,
|
||||
context: AiRequestContext<TParams, TResult>
|
||||
): Promise<TResult> {
|
||||
let result = initialValue
|
||||
|
||||
for (const plugin of this.plugins) {
|
||||
if (plugin.transformResult) {
|
||||
// SAFETY: transformResult 的契约保证返回 TResult
|
||||
// 由于插件接口定义,这个类型断言是安全的
|
||||
const transformed = await plugin.transformResult(result, context)
|
||||
result = transformed as TResult
|
||||
}
|
||||
}
|
||||
|
||||
@ -90,7 +114,7 @@ export class PluginManager {
|
||||
/**
|
||||
* 执行 ConfigureContext 钩子 - 串行配置上下文
|
||||
*/
|
||||
async executeConfigureContext(context: AiRequestContext): Promise<void> {
|
||||
async executeConfigureContext(context: AiRequestContext<TParams, TResult>): Promise<void> {
|
||||
for (const plugin of this.plugins) {
|
||||
const hook = plugin.configureContext
|
||||
if (hook) {
|
||||
@ -104,8 +128,8 @@ export class PluginManager {
|
||||
*/
|
||||
async executeParallel(
|
||||
hookName: 'onRequestStart' | 'onRequestEnd' | 'onError',
|
||||
context: AiRequestContext,
|
||||
result?: any,
|
||||
context: AiRequestContext<TParams, TResult>,
|
||||
result?: TResult,
|
||||
error?: Error
|
||||
): Promise<void> {
|
||||
const promises = this.plugins
|
||||
@ -131,7 +155,7 @@ export class PluginManager {
|
||||
/**
|
||||
* 收集所有流转换器(返回数组,AI SDK 原生支持)
|
||||
*/
|
||||
collectStreamTransforms(params: any, context: AiRequestContext) {
|
||||
collectStreamTransforms(params: TParams, context: AiRequestContext<TParams, TResult>) {
|
||||
return this.plugins
|
||||
.filter((plugin) => plugin.transformStream)
|
||||
.map((plugin) => plugin.transformStream?.(params, context))
|
||||
@ -140,7 +164,7 @@ export class PluginManager {
|
||||
/**
|
||||
* 获取所有插件信息
|
||||
*/
|
||||
getPlugins(): AiPlugin[] {
|
||||
getPlugins(): AiPlugin<TParams, TResult>[] {
|
||||
return [...this.plugins]
|
||||
}
|
||||
|
||||
|
||||
@ -1,79 +1,126 @@
|
||||
import type { ImageModelV3 } from '@ai-sdk/provider'
|
||||
import type { LanguageModel, TextStreamPart, ToolSet } from 'ai'
|
||||
import type { JSONObject, JSONValue } from '@ai-sdk/provider'
|
||||
import type { generateText, LanguageModelMiddleware, streamText, TextStreamPart, ToolSet } from 'ai'
|
||||
|
||||
import { type ProviderId } from '../providers/types'
|
||||
import type { AiSdkModel, ProviderId } from '../providers/types'
|
||||
|
||||
/**
|
||||
* 常用的 AI SDK 参数类型(完整版,用于插件泛型)
|
||||
*/
|
||||
export type StreamTextParams = Parameters<typeof streamText>[0]
|
||||
export type StreamTextResult = ReturnType<typeof streamText>
|
||||
export type GenerateTextParams = Parameters<typeof generateText>[0]
|
||||
export type GenerateTextResult = ReturnType<typeof generateText>
|
||||
|
||||
/**
|
||||
* AI 请求元数据
|
||||
* 定义结构化的元数据字段,避免使用 Record<string, any>
|
||||
*/
|
||||
export interface AiRequestMetadata {
|
||||
topicId?: string
|
||||
callType?: string
|
||||
enableReasoning?: boolean
|
||||
enableWebSearch?: boolean
|
||||
enableGenerateImage?: boolean
|
||||
isPromptToolUse?: boolean
|
||||
isSupportedToolUse?: boolean
|
||||
isImageGenerationEndpoint?: boolean
|
||||
// 自定义元数据,使用 JSONValue 确保类型安全
|
||||
custom?: JSONObject
|
||||
}
|
||||
|
||||
/**
|
||||
* 递归调用函数类型
|
||||
* 使用 any 是因为递归调用时参数和返回类型可能完全不同
|
||||
* 泛型化以保持类型推导
|
||||
*/
|
||||
export type RecursiveCallFn = (newParams: any) => Promise<any>
|
||||
export type RecursiveCallFn<TParams = unknown, TResult = unknown> = (newParams: Partial<TParams>) => Promise<TResult>
|
||||
|
||||
/**
|
||||
* AI 请求上下文
|
||||
* 使用泛型参数以支持不同类型的请求
|
||||
*/
|
||||
export interface AiRequestContext {
|
||||
export interface AiRequestContext<TParams = unknown, TResult = unknown> {
|
||||
providerId: ProviderId
|
||||
model: LanguageModel | ImageModelV3
|
||||
originalParams: any
|
||||
metadata: Record<string, any>
|
||||
model: AiSdkModel
|
||||
originalParams: TParams
|
||||
metadata: AiRequestMetadata
|
||||
startTime: number
|
||||
requestId: string
|
||||
recursiveCall: RecursiveCallFn
|
||||
isRecursiveCall?: boolean
|
||||
recursiveCall: RecursiveCallFn<TParams, TResult>
|
||||
isRecursiveCall: boolean
|
||||
|
||||
// 递归深度控制(防止栈溢出)
|
||||
recursiveDepth: number // 当前递归深度
|
||||
maxRecursiveDepth: number // 最大递归深度限制,默认 10
|
||||
|
||||
mcpTools?: ToolSet
|
||||
|
||||
extensions: Map<string, JSONValue>
|
||||
|
||||
middlewares?: LanguageModelMiddleware[]
|
||||
|
||||
// 向后兼容:允许插件动态添加属性(临时保留)
|
||||
[key: string]: any
|
||||
}
|
||||
|
||||
/**
|
||||
* 钩子分类
|
||||
* 使用泛型参数以支持不同类型的请求和响应
|
||||
*/
|
||||
export interface AiPlugin {
|
||||
export interface AiPlugin<TParams = unknown, TResult = unknown> {
|
||||
name: string
|
||||
enforce?: 'pre' | 'post'
|
||||
|
||||
// 【First】首个钩子 - 只执行第一个返回值的插件
|
||||
resolveModel?: (
|
||||
modelId: string,
|
||||
context: AiRequestContext
|
||||
) => Promise<LanguageModel | ImageModelV3 | null> | LanguageModel | ImageModelV3 | null
|
||||
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
|
||||
context: AiRequestContext<TParams, TResult>
|
||||
) => Promise<AiSdkModel | null> | AiSdkModel | null
|
||||
|
||||
loadTemplate?: (
|
||||
templateName: string,
|
||||
context: AiRequestContext<TParams, TResult>
|
||||
) => JSONValue | null | Promise<JSONValue | null>
|
||||
|
||||
// 【Sequential】串行钩子 - 链式执行,支持数据转换
|
||||
configureContext?: (context: AiRequestContext) => void | Promise<void>
|
||||
transformParams?: <T>(params: T, context: AiRequestContext) => T | Promise<T>
|
||||
transformResult?: <T>(result: T, context: AiRequestContext) => T | Promise<T>
|
||||
configureContext?: (context: AiRequestContext<TParams, TResult>) => void | Promise<void>
|
||||
|
||||
transformParams?: (
|
||||
params: TParams,
|
||||
context: AiRequestContext<TParams, TResult>
|
||||
) => Partial<TParams> | Promise<Partial<TParams>>
|
||||
|
||||
transformResult?: (result: TResult, context: AiRequestContext<TParams, TResult>) => TResult | Promise<TResult>
|
||||
|
||||
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用
|
||||
onRequestStart?: (context: AiRequestContext) => void | Promise<void>
|
||||
onRequestEnd?: (context: AiRequestContext, result: any) => void | Promise<void>
|
||||
onError?: (error: Error, context: AiRequestContext) => void | Promise<void>
|
||||
onRequestStart?: (context: AiRequestContext<TParams, TResult>) => void | Promise<void>
|
||||
|
||||
onRequestEnd?: (context: AiRequestContext<TParams, TResult>, result: TResult) => void | Promise<void>
|
||||
|
||||
onError?: (error: Error, context: AiRequestContext<TParams, TResult>) => void | Promise<void>
|
||||
|
||||
// 【Stream】流处理 - 直接使用 AI SDK
|
||||
transformStream?: (
|
||||
params: any,
|
||||
context: AiRequestContext
|
||||
params: TParams,
|
||||
context: AiRequestContext<TParams, TResult>
|
||||
) => <TOOLS extends ToolSet>(options?: {
|
||||
tools: TOOLS
|
||||
stopStream: () => void
|
||||
}) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
|
||||
|
||||
// AI SDK 原生中间件
|
||||
// aiSdkMiddlewares?: LanguageModelV1Middleware[]
|
||||
}
|
||||
|
||||
/**
|
||||
* 插件管理器配置
|
||||
*/
|
||||
export interface PluginManagerConfig {
|
||||
plugins: AiPlugin[]
|
||||
context: Partial<AiRequestContext>
|
||||
export interface PluginManagerConfig<TParams = unknown, TResult = unknown> {
|
||||
plugins: AiPlugin<TParams, TResult>[]
|
||||
context: Partial<AiRequestContext<TParams, TResult>>
|
||||
}
|
||||
|
||||
/**
|
||||
* 钩子执行结果
|
||||
* 泛型参数指定返回值类型
|
||||
*/
|
||||
export interface HookResult<T = any> {
|
||||
export interface HookResult<T = unknown> {
|
||||
value: T
|
||||
stop?: boolean
|
||||
}
|
||||
|
||||
@ -5,11 +5,19 @@
|
||||
* 例如: aihubmix:anthropic:claude-3.5-sonnet
|
||||
*/
|
||||
|
||||
import type { ProviderV3 } from '@ai-sdk/provider'
|
||||
import { customProvider } from 'ai'
|
||||
import type {
|
||||
EmbeddingModelV3,
|
||||
ImageModelV3,
|
||||
LanguageModelV3,
|
||||
ProviderV3,
|
||||
RerankingModelV3,
|
||||
SpeechModelV3,
|
||||
TranscriptionModelV3
|
||||
} from '@ai-sdk/provider'
|
||||
import { customProvider, wrapProvider } from 'ai'
|
||||
|
||||
import { globalRegistryManagement } from './RegistryManagement'
|
||||
import type { AiSdkMethodName, AiSdkModelReturn, AiSdkModelType } from './types'
|
||||
import { DEFAULT_SEPARATOR, globalRegistryManagement } from './RegistryManagement'
|
||||
import type { AiSdkProvider } from './types'
|
||||
|
||||
export interface HubProviderConfig {
|
||||
/** Hub的唯一标识符 */
|
||||
@ -34,7 +42,7 @@ export class HubProviderError extends Error {
|
||||
* 解析Hub模型ID
|
||||
*/
|
||||
function parseHubModelId(modelId: string): { provider: string; actualModelId: string } {
|
||||
const parts = modelId.split(':')
|
||||
const parts = modelId.split(DEFAULT_SEPARATOR)
|
||||
if (parts.length !== 2) {
|
||||
throw new HubProviderError(`Invalid hub model ID format. Expected "provider:modelId", got: ${modelId}`, 'unknown')
|
||||
}
|
||||
@ -47,7 +55,7 @@ function parseHubModelId(modelId: string): { provider: string; actualModelId: st
|
||||
/**
|
||||
* 创建Hub Provider
|
||||
*/
|
||||
export function createHubProvider(config: HubProviderConfig): ProviderV3 {
|
||||
export function createHubProvider(config: HubProviderConfig): AiSdkProvider {
|
||||
const { hubId } = config
|
||||
|
||||
function getTargetProvider(providerId: string): ProviderV3 {
|
||||
@ -61,7 +69,9 @@ export function createHubProvider(config: HubProviderConfig): ProviderV3 {
|
||||
providerId
|
||||
)
|
||||
}
|
||||
return provider
|
||||
// 使用 wrapProvider 确保返回的是 V3 provider
|
||||
// 这样可以自动处理 V2 provider 到 V3 的转换
|
||||
return wrapProvider({ provider, languageModelMiddleware: [] })
|
||||
} catch (error) {
|
||||
throw new HubProviderError(
|
||||
`Failed to get provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
@ -72,30 +82,62 @@ export function createHubProvider(config: HubProviderConfig): ProviderV3 {
|
||||
}
|
||||
}
|
||||
|
||||
function resolveModel<T extends AiSdkModelType>(
|
||||
modelId: string,
|
||||
modelType: T,
|
||||
methodName: AiSdkMethodName<T>
|
||||
): AiSdkModelReturn<T> {
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
// 创建符合 ProviderV3 规范的 fallback provider
|
||||
const hubFallbackProvider = {
|
||||
specificationVersion: 'v3' as const,
|
||||
|
||||
const fn = targetProvider[methodName] as (id: string) => AiSdkModelReturn<T>
|
||||
languageModel: (modelId: string): LanguageModelV3 => {
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
return targetProvider.languageModel(actualModelId)
|
||||
},
|
||||
|
||||
if (!fn) {
|
||||
throw new HubProviderError(`Provider "${provider}" does not support ${modelType}`, hubId, provider)
|
||||
embeddingModel: (modelId: string): EmbeddingModelV3 => {
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
return targetProvider.embeddingModel(actualModelId)
|
||||
},
|
||||
|
||||
imageModel: (modelId: string): ImageModelV3 => {
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
return targetProvider.imageModel(actualModelId)
|
||||
},
|
||||
|
||||
transcriptionModel: (modelId: string): TranscriptionModelV3 => {
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
|
||||
if (!targetProvider.transcriptionModel) {
|
||||
throw new HubProviderError(`Provider "${provider}" does not support transcription models`, hubId, provider)
|
||||
}
|
||||
|
||||
return targetProvider.transcriptionModel(actualModelId)
|
||||
},
|
||||
|
||||
speechModel: (modelId: string): SpeechModelV3 => {
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
|
||||
if (!targetProvider.speechModel) {
|
||||
throw new HubProviderError(`Provider "${provider}" does not support speech models`, hubId, provider)
|
||||
}
|
||||
|
||||
return targetProvider.speechModel(actualModelId)
|
||||
},
|
||||
rerankingModel: (modelId: string): RerankingModelV3 => {
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
|
||||
if (!targetProvider.rerankingModel) {
|
||||
throw new HubProviderError(`Provider "${provider}" does not support reranking models`, hubId, provider)
|
||||
}
|
||||
|
||||
return targetProvider.rerankingModel(actualModelId)
|
||||
}
|
||||
|
||||
return fn(actualModelId)
|
||||
}
|
||||
|
||||
return customProvider({
|
||||
fallbackProvider: {
|
||||
languageModel: (modelId: string) => resolveModel(modelId, 'text', 'languageModel'),
|
||||
textEmbeddingModel: (modelId: string) => resolveModel(modelId, 'embedding', 'textEmbeddingModel'),
|
||||
imageModel: (modelId: string) => resolveModel(modelId, 'image', 'imageModel'),
|
||||
transcriptionModel: (modelId: string) => resolveModel(modelId, 'transcription', 'transcriptionModel'),
|
||||
speechModel: (modelId: string) => resolveModel(modelId, 'speech', 'speechModel')
|
||||
}
|
||||
fallbackProvider: hubFallbackProvider
|
||||
})
|
||||
}
|
||||
|
||||
@ -11,8 +11,6 @@ type PROVIDERS = Record<string, ProviderV3>
|
||||
|
||||
export const DEFAULT_SEPARATOR = '|'
|
||||
|
||||
// export type MODEL_ID = `${string}${typeof DEFAULT_SEPARATOR}${string}`
|
||||
|
||||
export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARATOR> {
|
||||
private providers: PROVIDERS = {}
|
||||
private aliases: Set<string> = new Set() // 记录哪些key是别名
|
||||
|
||||
@ -56,6 +56,7 @@ export type {
|
||||
} from './schemas' // 从 schemas 导出的类型
|
||||
export { baseProviderIdSchema, customProviderIdSchema, providerConfigSchema, providerIdSchema } from './schemas' // Schema 导出
|
||||
export type {
|
||||
AiSdkModel,
|
||||
DynamicProviderRegistry,
|
||||
ExtensibleProviderSettingsMap,
|
||||
ProviderError,
|
||||
|
||||
@ -12,8 +12,8 @@ import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
|
||||
import type { ProviderV3 } from '@ai-sdk/provider'
|
||||
import { createXai } from '@ai-sdk/xai'
|
||||
import { type CherryInProviderSettings, createCherryIn } from '@cherrystudio/ai-sdk-provider'
|
||||
import { createOpenRouter } from '@openrouter/ai-sdk-provider'
|
||||
import { customProvider } from 'ai'
|
||||
import { createOpenRouter, type OpenRouterProviderSettings } from '@openrouter/ai-sdk-provider'
|
||||
import { customProvider, wrapProvider } from 'ai'
|
||||
import * as z from 'zod'
|
||||
|
||||
/**
|
||||
@ -133,7 +133,10 @@ export const baseProviders = [
|
||||
{
|
||||
id: 'openrouter',
|
||||
name: 'OpenRouter',
|
||||
creator: createOpenRouter,
|
||||
creator: (options?: OpenRouterProviderSettings) => {
|
||||
const provider = createOpenRouter(options)
|
||||
return wrapProvider({ provider, languageModelMiddleware: [] })
|
||||
},
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
|
||||
@ -4,15 +4,18 @@ import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek'
|
||||
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
|
||||
import { type OpenAIProviderSettings } from '@ai-sdk/openai'
|
||||
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
|
||||
import type {
|
||||
EmbeddingModelV3 as EmbeddingModel,
|
||||
ImageModelV3 as ImageModel,
|
||||
LanguageModelV3 as LanguageModel,
|
||||
ProviderV3,
|
||||
SpeechModelV3 as SpeechModel,
|
||||
TranscriptionModelV3 as TranscriptionModel
|
||||
} from '@ai-sdk/provider'
|
||||
import type { ProviderV2, ProviderV3 } from '@ai-sdk/provider'
|
||||
import { type XaiProviderSettings } from '@ai-sdk/xai'
|
||||
import type {
|
||||
EmbeddingModel,
|
||||
EmbeddingModelUsage,
|
||||
ImageModel,
|
||||
ImageModelUsage,
|
||||
LanguageModel,
|
||||
LanguageModelUsage,
|
||||
SpeechModel,
|
||||
TranscriptionModel
|
||||
} from 'ai'
|
||||
|
||||
// 导入基于 Zod 的 ProviderId 类型
|
||||
import { type ProviderId as ZodProviderId } from './schemas'
|
||||
@ -70,6 +73,8 @@ export type {
|
||||
}
|
||||
|
||||
export type AiSdkModel = LanguageModel | ImageModel | EmbeddingModel | TranscriptionModel | SpeechModel
|
||||
export type AiSdkProvider = ProviderV2 | ProviderV3
|
||||
export type AiSdkUsage = LanguageModelUsage | ImageModelUsage | EmbeddingModelUsage
|
||||
|
||||
export type AiSdkModelType = 'text' | 'image' | 'embedding' | 'transcription' | 'speech'
|
||||
|
||||
|
||||
@ -4,10 +4,16 @@
|
||||
*/
|
||||
import type { ImageModelV3, LanguageModelV3, LanguageModelV3Middleware } from '@ai-sdk/provider'
|
||||
import type { LanguageModel } from 'ai'
|
||||
import { generateImage as _generateImage, generateText as _generateText, streamText as _streamText } from 'ai'
|
||||
import {
|
||||
generateImage as _generateImage,
|
||||
generateText as _generateText,
|
||||
streamText as _streamText,
|
||||
wrapLanguageModel
|
||||
} from 'ai'
|
||||
|
||||
import { globalModelResolver } from '../models'
|
||||
import { type ModelConfig } from '../models/types'
|
||||
import { isV3Model } from '../models/utils'
|
||||
import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
|
||||
import { type ProviderId } from '../providers'
|
||||
import { ImageGenerationError, ImageModelResolutionError } from './errors'
|
||||
@ -181,8 +187,20 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
middlewares // 中间件数组
|
||||
)
|
||||
} else {
|
||||
// 已经是模型,直接返回
|
||||
return modelOrId
|
||||
// 已经是模型对象
|
||||
// 所有 provider 都应该返回 V3 模型(通过 wrapProvider 确保)
|
||||
if (!isV3Model(modelOrId)) {
|
||||
throw new Error(
|
||||
`Model must be V3. Provider "${this.config.providerId}" returned a V2 model. ` +
|
||||
'All providers should be wrapped with wrapProvider to return V3 models.'
|
||||
)
|
||||
}
|
||||
|
||||
// V3 模型,使用 wrapLanguageModel 应用中间件
|
||||
return wrapLanguageModel({
|
||||
model: modelOrId,
|
||||
middleware: middlewares || []
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,8 +1,18 @@
|
||||
/* eslint-disable @eslint-react/naming-convention/context-name */
|
||||
import type { ImageModelV3 } from '@ai-sdk/provider'
|
||||
import type { generateImage, generateText, LanguageModel, streamText } from 'ai'
|
||||
import type { generateImage, LanguageModel } from 'ai'
|
||||
|
||||
import { type AiPlugin, createContext, PluginManager } from '../plugins'
|
||||
import { ModelResolutionError, RecursiveDepthError } from '../errors'
|
||||
import {
|
||||
type AiPlugin,
|
||||
type AiRequestContext,
|
||||
createContext,
|
||||
type GenerateTextParams,
|
||||
type GenerateTextResult,
|
||||
PluginManager,
|
||||
type StreamTextParams,
|
||||
type StreamTextResult
|
||||
} from '../plugins'
|
||||
import { type ProviderId } from '../providers/types'
|
||||
|
||||
/**
|
||||
@ -10,21 +20,22 @@ import { type ProviderId } from '../providers/types'
|
||||
* 专注于插件处理,不暴露用户API
|
||||
*/
|
||||
export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
private pluginManager: PluginManager
|
||||
// ✅ 存储为非泛型数组(允许混合不同类型的插件)
|
||||
private basePlugins: AiPlugin[] = []
|
||||
|
||||
constructor(
|
||||
private readonly providerId: T,
|
||||
// private readonly options: ProviderSettingsMap[T],
|
||||
plugins: AiPlugin[] = []
|
||||
) {
|
||||
this.pluginManager = new PluginManager(plugins)
|
||||
this.basePlugins = plugins
|
||||
}
|
||||
|
||||
/**
|
||||
* 添加插件
|
||||
*/
|
||||
use(plugin: AiPlugin): this {
|
||||
this.pluginManager.use(plugin)
|
||||
this.basePlugins.push(plugin)
|
||||
return this
|
||||
}
|
||||
|
||||
@ -32,7 +43,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
* 批量添加插件
|
||||
*/
|
||||
usePlugins(plugins: AiPlugin[]): this {
|
||||
plugins.forEach((plugin) => this.use(plugin))
|
||||
this.basePlugins.push(...plugins)
|
||||
return this
|
||||
}
|
||||
|
||||
@ -40,7 +51,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
* 移除插件
|
||||
*/
|
||||
removePlugin(pluginName: string): this {
|
||||
this.pluginManager.remove(pluginName)
|
||||
this.basePlugins = this.basePlugins.filter((p) => p.name !== pluginName)
|
||||
return this
|
||||
}
|
||||
|
||||
@ -48,14 +59,16 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
* 获取插件统计
|
||||
*/
|
||||
getPluginStats() {
|
||||
return this.pluginManager.getStats()
|
||||
// 创建临时 manager 来获取统计信息
|
||||
const tempManager = new PluginManager(this.basePlugins)
|
||||
return tempManager.getStats()
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有插件
|
||||
*/
|
||||
getPlugins() {
|
||||
return this.pluginManager.getPlugins()
|
||||
return [...this.basePlugins]
|
||||
}
|
||||
|
||||
/**
|
||||
@ -63,13 +76,13 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
* 提供给AiExecutor使用
|
||||
*/
|
||||
async executeWithPlugins<
|
||||
TParams extends Parameters<typeof generateText>[0],
|
||||
TResult extends ReturnType<typeof generateText>
|
||||
TParams extends GenerateTextParams,
|
||||
TResult extends GenerateTextResult
|
||||
>(
|
||||
methodName: string,
|
||||
params: TParams,
|
||||
executor: (model: LanguageModel, transformedParams: TParams) => TResult,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
_context?: AiRequestContext<TParams, TResult>
|
||||
): Promise<TResult> {
|
||||
// 统一处理模型解析
|
||||
let resolvedModel: LanguageModel | undefined
|
||||
@ -84,54 +97,76 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
modelId = model.modelId
|
||||
}
|
||||
|
||||
// 使用正确的createContext创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, model, params)
|
||||
// 创建类型安全的 context
|
||||
const context = _context ?? createContext(this.providerId, model, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeWithPlugins(methodName, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
// ✅ 创建类型化的 manager(逆变安全)
|
||||
const manager = new PluginManager<TParams, TResult>(
|
||||
this.basePlugins as AiPlugin<TParams, TResult>[]
|
||||
)
|
||||
|
||||
// ✅ 递归调用泛型化,增加深度限制
|
||||
context.recursiveCall = async <R = TResult>(newParams: Partial<TParams>): Promise<R> => {
|
||||
if (context.recursiveDepth >= context.maxRecursiveDepth) {
|
||||
throw new RecursiveDepthError(context.requestId, context.recursiveDepth, context.maxRecursiveDepth)
|
||||
}
|
||||
|
||||
const previousDepth = context.recursiveDepth
|
||||
const wasRecursive = context.isRecursiveCall
|
||||
|
||||
try {
|
||||
context.recursiveDepth = previousDepth + 1
|
||||
context.isRecursiveCall = true
|
||||
|
||||
return await this.executeWithPlugins(
|
||||
methodName,
|
||||
{ ...params, ...newParams } as TParams,
|
||||
executor,
|
||||
context
|
||||
) as unknown as R
|
||||
} finally {
|
||||
// ✅ finally 确保状态恢复
|
||||
context.recursiveDepth = previousDepth
|
||||
context.isRecursiveCall = wasRecursive
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
// 0. 配置上下文
|
||||
await this.pluginManager.executeConfigureContext(context)
|
||||
await manager.executeConfigureContext(context)
|
||||
|
||||
// 1. 触发请求开始事件
|
||||
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||
await manager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 解析模型(如果是字符串)
|
||||
if (typeof model === 'string') {
|
||||
const resolved = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
|
||||
const resolved = await manager.executeFirst<LanguageModel>('resolveModel', modelId, context)
|
||||
if (!resolved) {
|
||||
throw new Error(`Failed to resolve model: ${modelId}`)
|
||||
throw new ModelResolutionError(modelId, this.providerId)
|
||||
}
|
||||
resolvedModel = resolved
|
||||
}
|
||||
|
||||
if (!resolvedModel) {
|
||||
throw new Error(`Model resolution failed: no model available`)
|
||||
throw new ModelResolutionError(modelId, this.providerId)
|
||||
}
|
||||
|
||||
// 3. 转换请求参数
|
||||
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||
const transformedParams = await manager.executeTransformParams(params, context)
|
||||
|
||||
// 4. 执行具体的 API 调用
|
||||
const result = await executor(resolvedModel, transformedParams)
|
||||
|
||||
// 5. 转换结果(对于非流式调用)
|
||||
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||
const transformedResult = await manager.executeTransformResult(result, context)
|
||||
|
||||
// 6. 触发完成事件
|
||||
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
|
||||
await manager.executeParallel('onRequestEnd', context, transformedResult)
|
||||
|
||||
return transformedResult
|
||||
} catch (error) {
|
||||
// 7. 触发错误事件
|
||||
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
|
||||
await manager.executeParallel('onError', context, undefined, error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
@ -147,7 +182,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
methodName: string,
|
||||
params: TParams,
|
||||
executor: (model: ImageModelV3, transformedParams: TParams) => TResult,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
_context?: AiRequestContext<TParams, TResult>
|
||||
): Promise<TResult> {
|
||||
// 统一处理模型解析
|
||||
let resolvedModel: ImageModelV3 | undefined
|
||||
@ -162,54 +197,76 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
modelId = model.modelId
|
||||
}
|
||||
|
||||
// 使用正确的createContext创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, model, params)
|
||||
// 创建类型安全的 context
|
||||
const context = _context ?? createContext(this.providerId, model, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeImageWithPlugins(methodName, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
// ✅ 创建类型化的 manager(逆变安全)
|
||||
const manager = new PluginManager<TParams, TResult>(
|
||||
this.basePlugins as AiPlugin<TParams, TResult>[]
|
||||
)
|
||||
|
||||
// ✅ 递归调用泛型化,增加深度限制
|
||||
context.recursiveCall = async <R = TResult>(newParams: Partial<TParams>): Promise<R> => {
|
||||
if (context.recursiveDepth >= context.maxRecursiveDepth) {
|
||||
throw new RecursiveDepthError(context.requestId, context.recursiveDepth, context.maxRecursiveDepth)
|
||||
}
|
||||
|
||||
const previousDepth = context.recursiveDepth
|
||||
const wasRecursive = context.isRecursiveCall
|
||||
|
||||
try {
|
||||
context.recursiveDepth = previousDepth + 1
|
||||
context.isRecursiveCall = true
|
||||
|
||||
return await this.executeImageWithPlugins(
|
||||
methodName,
|
||||
{ ...params, ...newParams } as TParams,
|
||||
executor,
|
||||
context
|
||||
) as unknown as R
|
||||
} finally {
|
||||
// ✅ finally 确保状态恢复
|
||||
context.recursiveDepth = previousDepth
|
||||
context.isRecursiveCall = wasRecursive
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
// 0. 配置上下文
|
||||
await this.pluginManager.executeConfigureContext(context)
|
||||
await manager.executeConfigureContext(context)
|
||||
|
||||
// 1. 触发请求开始事件
|
||||
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||
await manager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 解析模型(如果是字符串)
|
||||
if (typeof model === 'string') {
|
||||
const resolved = await this.pluginManager.executeFirst<ImageModelV3>('resolveModel', modelId, context)
|
||||
const resolved = await manager.executeFirst<ImageModelV3>('resolveModel', modelId, context)
|
||||
if (!resolved) {
|
||||
throw new Error(`Failed to resolve image model: ${modelId}`)
|
||||
throw new ModelResolutionError(modelId, this.providerId)
|
||||
}
|
||||
resolvedModel = resolved
|
||||
}
|
||||
|
||||
if (!resolvedModel) {
|
||||
throw new Error(`Image model resolution failed: no model available`)
|
||||
throw new ModelResolutionError(modelId, this.providerId)
|
||||
}
|
||||
|
||||
// 3. 转换请求参数
|
||||
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||
const transformedParams = await manager.executeTransformParams(params, context)
|
||||
|
||||
// 4. 执行具体的 API 调用
|
||||
const result = await executor(resolvedModel, transformedParams)
|
||||
|
||||
// 5. 转换结果
|
||||
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||
const transformedResult = await manager.executeTransformResult(result, context)
|
||||
|
||||
// 6. 触发完成事件
|
||||
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
|
||||
await manager.executeParallel('onRequestEnd', context, transformedResult)
|
||||
|
||||
return transformedResult
|
||||
} catch (error) {
|
||||
// 7. 触发错误事件
|
||||
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
|
||||
await manager.executeParallel('onError', context, undefined, error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
@ -219,13 +276,13 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
* 提供给AiExecutor使用
|
||||
*/
|
||||
async executeStreamWithPlugins<
|
||||
TParams extends Parameters<typeof streamText>[0],
|
||||
TResult extends ReturnType<typeof streamText>
|
||||
TParams extends StreamTextParams,
|
||||
TResult extends StreamTextResult
|
||||
>(
|
||||
methodName: string,
|
||||
params: TParams,
|
||||
executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => TResult,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
_context?: AiRequestContext<TParams, TResult>
|
||||
): Promise<TResult> {
|
||||
// 统一处理模型解析
|
||||
let resolvedModel: LanguageModel | undefined
|
||||
@ -240,56 +297,78 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
modelId = model.modelId
|
||||
}
|
||||
|
||||
// 创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, model, params)
|
||||
// 创建类型安全的 context
|
||||
const context = _context ?? createContext(this.providerId, model, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeStreamWithPlugins(methodName, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
// ✅ 创建类型化的 manager(逆变安全)
|
||||
const manager = new PluginManager<TParams, TResult>(
|
||||
this.basePlugins as AiPlugin<TParams, TResult>[]
|
||||
)
|
||||
|
||||
// ✅ 递归调用泛型化,增加深度限制
|
||||
context.recursiveCall = async <R = TResult>(newParams: Partial<TParams>): Promise<R> => {
|
||||
if (context.recursiveDepth >= context.maxRecursiveDepth) {
|
||||
throw new RecursiveDepthError(context.requestId, context.recursiveDepth, context.maxRecursiveDepth)
|
||||
}
|
||||
|
||||
const previousDepth = context.recursiveDepth
|
||||
const wasRecursive = context.isRecursiveCall
|
||||
|
||||
try {
|
||||
context.recursiveDepth = previousDepth + 1
|
||||
context.isRecursiveCall = true
|
||||
|
||||
return await this.executeStreamWithPlugins(
|
||||
methodName,
|
||||
{ ...params, ...newParams } as TParams,
|
||||
executor,
|
||||
context
|
||||
) as unknown as R
|
||||
} finally {
|
||||
// ✅ finally 确保状态恢复
|
||||
context.recursiveDepth = previousDepth
|
||||
context.isRecursiveCall = wasRecursive
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
// 0. 配置上下文
|
||||
await this.pluginManager.executeConfigureContext(context)
|
||||
await manager.executeConfigureContext(context)
|
||||
|
||||
// 1. 触发请求开始事件
|
||||
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||
await manager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 解析模型(如果是字符串)
|
||||
if (typeof model === 'string') {
|
||||
const resolved = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
|
||||
const resolved = await manager.executeFirst<LanguageModel>('resolveModel', modelId, context)
|
||||
if (!resolved) {
|
||||
throw new Error(`Failed to resolve model: ${modelId}`)
|
||||
throw new ModelResolutionError(modelId, this.providerId)
|
||||
}
|
||||
resolvedModel = resolved
|
||||
}
|
||||
|
||||
if (!resolvedModel) {
|
||||
throw new Error(`Model resolution failed: no model available`)
|
||||
throw new ModelResolutionError(modelId, this.providerId)
|
||||
}
|
||||
|
||||
// 3. 转换请求参数
|
||||
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||
const transformedParams = await manager.executeTransformParams(params, context)
|
||||
|
||||
// 4. 收集流转换器
|
||||
const streamTransforms = this.pluginManager.collectStreamTransforms(transformedParams, context)
|
||||
const streamTransforms = manager.collectStreamTransforms(transformedParams, context)
|
||||
|
||||
// 5. 执行流式 API 调用
|
||||
const result = await executor(resolvedModel, transformedParams, streamTransforms)
|
||||
|
||||
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||
const transformedResult = await manager.executeTransformResult(result, context)
|
||||
|
||||
// 6. 触发完成事件(注意:对于流式调用,这里触发的是开始流式响应的事件)
|
||||
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
|
||||
await manager.executeParallel('onRequestEnd', context, transformedResult)
|
||||
|
||||
return transformedResult
|
||||
} catch (error) {
|
||||
// 7. 触发错误事件
|
||||
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
|
||||
await manager.executeParallel('onError', context, undefined, error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
80
packages/aiCore/src/core/types/branded.ts
Normal file
80
packages/aiCore/src/core/types/branded.ts
Normal file
@ -0,0 +1,80 @@
|
||||
/**
|
||||
* Branded Types for type-safe IDs
|
||||
*
|
||||
* Branded types prevent accidental misuse of primitive types (like string)
|
||||
* by adding compile-time type safety without runtime overhead.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const modelId = ModelId('gpt-4') // ModelId type
|
||||
* const requestId = RequestId('req-123') // RequestId type
|
||||
*
|
||||
* function processModel(id: ModelId) { ... }
|
||||
* processModel(requestId) // ❌ Compile error - type mismatch
|
||||
* ```
|
||||
*/
|
||||
|
||||
/**
|
||||
* Brand helper type
|
||||
*/
|
||||
type Brand<K, T> = K & { readonly __brand: T }
|
||||
|
||||
/**
|
||||
* Model ID branded type
|
||||
* Represents a unique model identifier
|
||||
*/
|
||||
export type ModelId = Brand<string, 'ModelId'>
|
||||
|
||||
/**
|
||||
* Request ID branded type
|
||||
* Represents a unique request identifier for tracing
|
||||
*/
|
||||
export type RequestId = Brand<string, 'RequestId'>
|
||||
|
||||
/**
|
||||
* Provider ID branded type
|
||||
* Represents a provider identifier (e.g., 'openai', 'anthropic')
|
||||
*/
|
||||
export type ProviderId = Brand<string, 'ProviderId'>
|
||||
|
||||
/**
|
||||
* Create a ModelId from a string
|
||||
* @param id - The model identifier string
|
||||
* @returns Branded ModelId
|
||||
*/
|
||||
export const ModelId = (id: string): ModelId => id as ModelId
|
||||
|
||||
/**
|
||||
* Create a RequestId from a string
|
||||
* @param id - The request identifier string
|
||||
* @returns Branded RequestId
|
||||
*/
|
||||
export const RequestId = (id: string): RequestId => id as RequestId
|
||||
|
||||
/**
|
||||
* Create a ProviderId from a string
|
||||
* @param id - The provider identifier string
|
||||
* @returns Branded ProviderId
|
||||
*/
|
||||
export const ProviderId = (id: string): ProviderId => id as ProviderId
|
||||
|
||||
/**
|
||||
* Type guard to check if a string is a valid ModelId
|
||||
*/
|
||||
export const isModelId = (value: unknown): value is ModelId => {
|
||||
return typeof value === 'string' && value.length > 0
|
||||
}
|
||||
|
||||
/**
|
||||
* Type guard to check if a string is a valid RequestId
|
||||
*/
|
||||
export const isRequestId = (value: unknown): value is RequestId => {
|
||||
return typeof value === 'string' && value.length > 0
|
||||
}
|
||||
|
||||
/**
|
||||
* Type guard to check if a string is a valid ProviderId
|
||||
*/
|
||||
export const isProviderId = (value: unknown): value is ProviderId => {
|
||||
return typeof value === 'string' && value.length > 0
|
||||
}
|
||||
@ -15,19 +15,34 @@ export {
|
||||
} from './core/runtime'
|
||||
|
||||
// ==================== 高级API ====================
|
||||
export { globalModelResolver as modelResolver } from './core/models'
|
||||
export { isV2Model, isV3Model, globalModelResolver as modelResolver } from './core/models'
|
||||
|
||||
// ==================== 插件系统 ====================
|
||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins'
|
||||
export type {
|
||||
AiPlugin,
|
||||
AiRequestContext,
|
||||
AiRequestMetadata,
|
||||
GenerateTextParams,
|
||||
GenerateTextResult,
|
||||
HookResult,
|
||||
PluginManagerConfig,
|
||||
RecursiveCallFn,
|
||||
StreamTextParams,
|
||||
StreamTextResult
|
||||
} from './core/plugins'
|
||||
export { createContext, definePlugin, PluginManager } from './core/plugins'
|
||||
// export { createPromptToolUsePlugin, webSearchPlugin } from './core/plugins/built-in'
|
||||
export { PluginEngine } from './core/runtime/pluginEngine'
|
||||
|
||||
// ==================== AI SDK 常用类型导出 ====================
|
||||
// 直接导出 AI SDK 的常用类型,方便使用
|
||||
export type { LanguageModelV3Middleware, LanguageModelV3StreamPart } from '@ai-sdk/provider'
|
||||
export type { ToolCall } from '@ai-sdk/provider-utils'
|
||||
export type { ReasoningPart } from '@ai-sdk/provider-utils'
|
||||
// ==================== 类型工具 ====================
|
||||
export type { ModelId, ProviderId, RequestId } from './core/types/branded'
|
||||
export { isModelId, isProviderId, isRequestId } from './core/types/branded'
|
||||
// Branded type constructors (values, not types)
|
||||
export type { AiSdkModel } from './core/providers'
|
||||
export {
|
||||
ModelId as createModelId,
|
||||
ProviderId as createProviderId,
|
||||
RequestId as createRequestId
|
||||
} from './core/types/branded'
|
||||
|
||||
// ==================== 选项 ====================
|
||||
export {
|
||||
@ -40,6 +55,17 @@ export {
|
||||
type TypedProviderOptions
|
||||
} from './core/options'
|
||||
|
||||
// ==================== 错误处理 ====================
|
||||
export {
|
||||
AiCoreError,
|
||||
ModelResolutionError,
|
||||
ParameterValidationError,
|
||||
PluginExecutionError,
|
||||
ProviderConfigError,
|
||||
RecursiveDepthError,
|
||||
TemplateLoadError
|
||||
} from './core/errors'
|
||||
|
||||
// ==================== 包信息 ====================
|
||||
export const AI_CORE_VERSION = '1.0.0'
|
||||
export const AI_CORE_NAME = '@cherrystudio/ai-core'
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
* 2. 暂时保持接口兼容性
|
||||
*/
|
||||
|
||||
import type { AiSdkModel } from '@cherrystudio/ai-core'
|
||||
import { createExecutor } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
|
||||
@ -14,16 +15,14 @@ import { normalizeGatewayModels, normalizeSdkModels } from '@renderer/services/m
|
||||
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
|
||||
import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
||||
import { type Assistant, type GenerateImageParams, type Model, type Provider, SystemProviderIds } from '@renderer/types'
|
||||
import type { AiSdkModel, StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||
import { SUPPORTED_IMAGE_ENDPOINT_LIST } from '@renderer/utils'
|
||||
import { buildClaudeCodeSystemModelMessage } from '@shared/anthropic'
|
||||
import { gateway, type ImageModel, type LanguageModel, type Provider as AiSdkProvider, wrapLanguageModel } from 'ai'
|
||||
import { gateway, type LanguageModel, type Provider as AiSdkProvider } from 'ai'
|
||||
|
||||
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
|
||||
import LegacyAiProvider from './legacy/index'
|
||||
import type { CompletionsParams, CompletionsResult } from './legacy/middleware/schemas'
|
||||
import type { AiSdkMiddlewareConfig } from './middleware/AiSdkMiddlewareBuilder'
|
||||
import { buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder'
|
||||
import { buildPlugins } from './plugins/PluginBuilder'
|
||||
import { createAiSdkProvider } from './provider/factory'
|
||||
import {
|
||||
@ -34,6 +33,7 @@ import {
|
||||
providerToAiSdkConfig
|
||||
} from './provider/providerConfig'
|
||||
import type { AiSdkConfig } from './types'
|
||||
import type { AiSdkMiddlewareConfig } from './types/middlewareConfig'
|
||||
|
||||
const logger = loggerService.withContext('ModernAiProvider')
|
||||
|
||||
@ -144,16 +144,6 @@ export default class ModernAiProvider {
|
||||
this.localProvider = await createAiSdkProvider(this.config)
|
||||
}
|
||||
|
||||
// 提前构建中间件
|
||||
const middlewares = buildAiSdkMiddlewares({
|
||||
...providerConfig,
|
||||
provider: this.actualProvider,
|
||||
assistant: providerConfig.assistant
|
||||
})
|
||||
logger.debug('Built middlewares in completions', {
|
||||
middlewareCount: middlewares.length,
|
||||
isImageGeneration: providerConfig.isImageGenerationEndpoint
|
||||
})
|
||||
if (!this.localProvider) {
|
||||
throw new Error('Local provider not created')
|
||||
}
|
||||
@ -164,14 +154,20 @@ export default class ModernAiProvider {
|
||||
model = this.localProvider.imageModel(modelId)
|
||||
} else {
|
||||
model = this.localProvider.languageModel(modelId)
|
||||
// 如果有中间件,应用到语言模型上
|
||||
if (middlewares.length > 0 && typeof model === 'object') {
|
||||
model = wrapLanguageModel({ model, middleware: middlewares })
|
||||
}
|
||||
}
|
||||
|
||||
if (this.actualProvider.id === 'anthropic' && this.actualProvider.authType === 'oauth') {
|
||||
const claudeCodeSystemMessage = buildClaudeCodeSystemModelMessage(params.system)
|
||||
// 类型守卫:确保 system 是 string、Array 或 undefined
|
||||
const system = params.system
|
||||
let systemParam: string | Array<any> | undefined
|
||||
if (typeof system === 'string' || Array.isArray(system) || system === undefined) {
|
||||
systemParam = system
|
||||
} else {
|
||||
// SystemModelMessage 类型,转换为 string
|
||||
systemParam = undefined
|
||||
}
|
||||
|
||||
const claudeCodeSystemMessage = buildClaudeCodeSystemModelMessage(systemParam)
|
||||
params.system = undefined // 清除原有system,避免重复
|
||||
params.messages = [...claudeCodeSystemMessage, ...(params.messages || [])]
|
||||
}
|
||||
@ -543,7 +539,7 @@ export default class ModernAiProvider {
|
||||
|
||||
const executor = createExecutor(this.config!.providerId, this.config!.options, [])
|
||||
const result = await executor.generateImage({
|
||||
model: this.localProvider?.imageModel(model) as ImageModel,
|
||||
model: model, // 直接使用 model ID 字符串,由 executor 内部解析
|
||||
...aiSdkParams
|
||||
})
|
||||
|
||||
|
||||
@ -1,286 +0,0 @@
|
||||
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
import { loggerService } from '@logger'
|
||||
import { isGemini3Model, isSupportedThinkingTokenQwenModel } from '@renderer/config/models'
|
||||
import type { MCPTool } from '@renderer/types'
|
||||
import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types'
|
||||
import type { Chunk } from '@renderer/types/chunk'
|
||||
import { isOllamaProvider, isSupportEnableThinkingProvider } from '@renderer/utils/provider'
|
||||
import type { LanguageModelMiddleware } from 'ai'
|
||||
import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
|
||||
|
||||
import { getAiSdkProviderId } from '../provider/factory'
|
||||
import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
|
||||
import { noThinkMiddleware } from './noThinkMiddleware'
|
||||
import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
|
||||
import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'
|
||||
import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
|
||||
import { skipGeminiThoughtSignatureMiddleware } from './skipGeminiThoughtSignatureMiddleware'
|
||||
|
||||
const logger = loggerService.withContext('AiSdkMiddlewareBuilder')
|
||||
|
||||
/**
|
||||
* AI SDK 中间件配置项
|
||||
*/
|
||||
export interface AiSdkMiddlewareConfig {
|
||||
streamOutput: boolean
|
||||
onChunk?: (chunk: Chunk) => void
|
||||
model?: Model
|
||||
provider?: Provider
|
||||
assistant?: Assistant
|
||||
enableReasoning: boolean
|
||||
// 是否开启提示词工具调用
|
||||
isPromptToolUse: boolean
|
||||
// 是否支持工具调用
|
||||
isSupportedToolUse: boolean
|
||||
// image generation endpoint
|
||||
isImageGenerationEndpoint: boolean
|
||||
// 是否开启内置搜索
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
enableUrlContext: boolean
|
||||
mcpTools?: MCPTool[]
|
||||
uiMessages?: Message[]
|
||||
// 内置搜索配置
|
||||
webSearchPluginConfig?: WebSearchPluginConfig
|
||||
// 知识库识别开关,默认开启
|
||||
knowledgeRecognition?: 'off' | 'on'
|
||||
}
|
||||
|
||||
/**
|
||||
* 具名的 AI SDK 中间件
|
||||
*/
|
||||
export interface NamedAiSdkMiddleware {
|
||||
name: string
|
||||
middleware: LanguageModelMiddleware
|
||||
}
|
||||
|
||||
/**
|
||||
* AI SDK 中间件建造者
|
||||
* 用于根据不同条件动态构建中间件数组
|
||||
*/
|
||||
export class AiSdkMiddlewareBuilder {
|
||||
private middlewares: NamedAiSdkMiddleware[] = []
|
||||
|
||||
/**
|
||||
* 添加具名中间件
|
||||
*/
|
||||
public add(namedMiddleware: NamedAiSdkMiddleware): this {
|
||||
this.middlewares.push(namedMiddleware)
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 在指定位置插入中间件
|
||||
*/
|
||||
public insertAfter(targetName: string, middleware: NamedAiSdkMiddleware): this {
|
||||
const index = this.middlewares.findIndex((m) => m.name === targetName)
|
||||
if (index !== -1) {
|
||||
this.middlewares.splice(index + 1, 0, middleware)
|
||||
} else {
|
||||
logger.warn(`AiSdkMiddlewareBuilder: Middleware named '${targetName}' not found, cannot insert`)
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否包含指定名称的中间件
|
||||
*/
|
||||
public has(name: string): boolean {
|
||||
return this.middlewares.some((m) => m.name === name)
|
||||
}
|
||||
|
||||
/**
|
||||
* 移除指定名称的中间件
|
||||
*/
|
||||
public remove(name: string): this {
|
||||
this.middlewares = this.middlewares.filter((m) => m.name !== name)
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建最终的中间件数组
|
||||
*/
|
||||
public build(): LanguageModelMiddleware[] {
|
||||
return this.middlewares.map((m) => m.middleware)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取具名中间件数组(用于调试)
|
||||
*/
|
||||
public buildNamed(): NamedAiSdkMiddleware[] {
|
||||
return [...this.middlewares]
|
||||
}
|
||||
|
||||
/**
|
||||
* 清空所有中间件
|
||||
*/
|
||||
public clear(): this {
|
||||
this.middlewares = []
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取中间件总数
|
||||
*/
|
||||
public get length(): number {
|
||||
return this.middlewares.length
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据配置构建AI SDK中间件的工厂函数
|
||||
* 这里要注意构建顺序,因为有些中间件需要依赖其他中间件的结果
|
||||
*/
|
||||
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelMiddleware[] {
|
||||
const builder = new AiSdkMiddlewareBuilder()
|
||||
|
||||
// 1. 根据provider添加特定中间件
|
||||
if (config.provider) {
|
||||
addProviderSpecificMiddlewares(builder, config)
|
||||
}
|
||||
|
||||
// 2. 根据模型类型添加特定中间件
|
||||
if (config.model) {
|
||||
addModelSpecificMiddlewares(builder, config)
|
||||
}
|
||||
|
||||
// 3. 非流式输出时添加模拟流中间件
|
||||
if (config.streamOutput === false) {
|
||||
builder.add({
|
||||
name: 'simulate-streaming',
|
||||
middleware: simulateStreamingMiddleware()
|
||||
})
|
||||
}
|
||||
|
||||
return builder.build()
|
||||
}
|
||||
|
||||
const tagName = {
|
||||
reasoning: 'reasoning',
|
||||
think: 'think',
|
||||
thought: 'thought',
|
||||
seedThink: 'seed:think'
|
||||
}
|
||||
|
||||
function getReasoningTagName(modelId: string | undefined): string {
|
||||
if (modelId?.includes('gpt-oss')) return tagName.reasoning
|
||||
if (modelId?.includes('gemini')) return tagName.thought
|
||||
if (modelId?.includes('seed-oss-36b')) return tagName.seedThink
|
||||
return tagName.think
|
||||
}
|
||||
|
||||
/**
|
||||
* 添加provider特定的中间件
|
||||
*/
|
||||
function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: AiSdkMiddlewareConfig): void {
|
||||
if (!config.provider) return
|
||||
|
||||
// 根据不同provider添加特定中间件
|
||||
switch (config.provider.type) {
|
||||
case 'anthropic':
|
||||
// Anthropic特定中间件
|
||||
break
|
||||
case 'openai':
|
||||
case 'azure-openai': {
|
||||
if (config.enableReasoning) {
|
||||
const tagName = getReasoningTagName(config.model?.id.toLowerCase())
|
||||
builder.add({
|
||||
name: 'thinking-tag-extraction',
|
||||
middleware: extractReasoningMiddleware({ tagName })
|
||||
})
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'gemini':
|
||||
// Gemini特定中间件
|
||||
break
|
||||
case 'aws-bedrock': {
|
||||
break
|
||||
}
|
||||
default:
|
||||
// 其他provider的通用处理
|
||||
break
|
||||
}
|
||||
|
||||
// OVMS+MCP's specific middleware
|
||||
if (config.provider.id === 'ovms' && config.mcpTools && config.mcpTools.length > 0) {
|
||||
builder.add({
|
||||
name: 'no-think',
|
||||
middleware: noThinkMiddleware()
|
||||
})
|
||||
}
|
||||
|
||||
if (config.provider.id === SystemProviderIds.openrouter && config.enableReasoning) {
|
||||
builder.add({
|
||||
name: 'openrouter-reasoning-redaction',
|
||||
middleware: openrouterReasoningMiddleware()
|
||||
})
|
||||
logger.debug('Added OpenRouter reasoning redaction middleware')
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 添加模型特定的中间件
|
||||
*/
|
||||
function addModelSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: AiSdkMiddlewareConfig): void {
|
||||
if (!config.model || !config.provider) return
|
||||
|
||||
// Qwen models on providers that don't support enable_thinking parameter (like Ollama, LM Studio, NVIDIA)
|
||||
// Use /think or /no_think suffix to control thinking mode
|
||||
if (
|
||||
config.provider &&
|
||||
!isOllamaProvider(config.provider) &&
|
||||
isSupportedThinkingTokenQwenModel(config.model) &&
|
||||
!isSupportEnableThinkingProvider(config.provider)
|
||||
) {
|
||||
const enableThinking = config.assistant?.settings?.reasoning_effort !== undefined
|
||||
builder.add({
|
||||
name: 'qwen-thinking-control',
|
||||
middleware: qwenThinkingMiddleware(enableThinking)
|
||||
})
|
||||
logger.debug(`Added Qwen thinking middleware with thinking ${enableThinking ? 'enabled' : 'disabled'}`)
|
||||
}
|
||||
|
||||
// 可以根据模型ID或特性添加特定中间件
|
||||
// 例如:图像生成模型、多模态模型等
|
||||
if (isOpenRouterGeminiGenerateImageModel(config.model, config.provider)) {
|
||||
builder.add({
|
||||
name: 'openrouter-gemini-image-generation',
|
||||
middleware: openrouterGenerateImageMiddleware()
|
||||
})
|
||||
}
|
||||
|
||||
if (isGemini3Model(config.model)) {
|
||||
const aiSdkId = getAiSdkProviderId(config.provider)
|
||||
builder.add({
|
||||
name: 'skip-gemini3-thought-signature',
|
||||
middleware: skipGeminiThoughtSignatureMiddleware(aiSdkId)
|
||||
})
|
||||
logger.debug('Added skip Gemini3 thought signature middleware')
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建一个预配置的中间件建造者
|
||||
*/
|
||||
export function createAiSdkMiddlewareBuilder(): AiSdkMiddlewareBuilder {
|
||||
return new AiSdkMiddlewareBuilder()
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建一个带有默认中间件的建造者
|
||||
*/
|
||||
export function createDefaultAiSdkMiddlewareBuilder(config: AiSdkMiddlewareConfig): AiSdkMiddlewareBuilder {
|
||||
const builder = new AiSdkMiddlewareBuilder()
|
||||
const defaultMiddlewares = buildAiSdkMiddlewares(config)
|
||||
|
||||
// 将普通中间件数组转换为具名中间件并添加
|
||||
defaultMiddlewares.forEach((middleware, index) => {
|
||||
builder.add({
|
||||
name: `default-middleware-${index}`,
|
||||
middleware
|
||||
})
|
||||
})
|
||||
|
||||
return builder
|
||||
}
|
||||
@ -1,140 +0,0 @@
|
||||
# AI SDK 中间件建造者
|
||||
|
||||
## 概述
|
||||
|
||||
`AiSdkMiddlewareBuilder` 是一个用于动态构建 AI SDK 中间件数组的建造者模式实现。它可以根据不同的条件(如流式输出、思考模型、provider类型等)自动构建合适的中间件组合。
|
||||
|
||||
## 使用方式
|
||||
|
||||
### 基本用法
|
||||
|
||||
```typescript
|
||||
import { buildAiSdkMiddlewares, type AiSdkMiddlewareConfig } from './AiSdkMiddlewareBuilder'
|
||||
|
||||
// 配置中间件参数
|
||||
const config: AiSdkMiddlewareConfig = {
|
||||
streamOutput: false, // 非流式输出
|
||||
onChunk: chunkHandler, // chunk回调函数
|
||||
model: currentModel, // 当前模型
|
||||
provider: currentProvider, // 当前provider
|
||||
enableReasoning: true, // 启用推理
|
||||
enableTool: false, // 禁用工具
|
||||
enableWebSearch: false // 禁用网页搜索
|
||||
}
|
||||
|
||||
// 构建中间件数组
|
||||
const middlewares = buildAiSdkMiddlewares(config)
|
||||
|
||||
// 创建带有中间件的客户端
|
||||
const client = createClient(providerId, options, middlewares)
|
||||
```
|
||||
|
||||
### 手动构建
|
||||
|
||||
```typescript
|
||||
import { AiSdkMiddlewareBuilder, createAiSdkMiddlewareBuilder } from './AiSdkMiddlewareBuilder'
|
||||
|
||||
const builder = createAiSdkMiddlewareBuilder()
|
||||
|
||||
// 添加特定中间件
|
||||
builder.add({
|
||||
name: 'custom-middleware',
|
||||
aiSdkMiddlewares: [customMiddleware()]
|
||||
})
|
||||
|
||||
// 检查是否包含某个中间件
|
||||
if (builder.has('thinking-time')) {
|
||||
console.log('已包含思考时间中间件')
|
||||
}
|
||||
|
||||
// 移除不需要的中间件
|
||||
builder.remove('simulate-streaming')
|
||||
|
||||
// 构建最终数组
|
||||
const middlewares = builder.build()
|
||||
```
|
||||
|
||||
## 支持的条件
|
||||
|
||||
### 1. 流式输出控制
|
||||
|
||||
- **streamOutput = false**: 自动添加 `simulateStreamingMiddleware`
|
||||
- **streamOutput = true**: 使用原生流式处理
|
||||
|
||||
### 2. 思考模型处理
|
||||
|
||||
- **条件**: `onChunk` 存在 && `isReasoningModel(model)` 为 true
|
||||
- **效果**: 自动添加 `thinkingTimeMiddleware`
|
||||
|
||||
### 3. Provider 特定中间件
|
||||
|
||||
根据不同的 provider 类型添加特定中间件:
|
||||
|
||||
- **anthropic**: Anthropic 特定处理
|
||||
- **openai**: OpenAI 特定处理
|
||||
- **gemini**: Gemini 特定处理
|
||||
|
||||
### 4. 模型特定中间件
|
||||
|
||||
根据模型特性添加中间件:
|
||||
|
||||
- **图像生成模型**: 添加图像处理相关中间件
|
||||
- **多模态模型**: 添加多模态处理中间件
|
||||
|
||||
## 扩展指南
|
||||
|
||||
### 添加新的条件判断
|
||||
|
||||
在 `buildAiSdkMiddlewares` 函数中添加新的条件:
|
||||
|
||||
```typescript
|
||||
// 例如:添加缓存中间件
|
||||
if (config.enableCache) {
|
||||
builder.add({
|
||||
name: 'cache',
|
||||
aiSdkMiddlewares: [cacheMiddleware(config.cacheOptions)]
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
### 添加 Provider 特定处理
|
||||
|
||||
在 `addProviderSpecificMiddlewares` 函数中添加:
|
||||
|
||||
```typescript
|
||||
case 'custom-provider':
|
||||
builder.add({
|
||||
name: 'custom-provider-middleware',
|
||||
aiSdkMiddlewares: [customProviderMiddleware()]
|
||||
})
|
||||
break
|
||||
```
|
||||
|
||||
### 添加模型特定处理
|
||||
|
||||
在 `addModelSpecificMiddlewares` 函数中添加:
|
||||
|
||||
```typescript
|
||||
if (config.model.id.includes('custom-model')) {
|
||||
builder.add({
|
||||
name: 'custom-model-middleware',
|
||||
aiSdkMiddlewares: [customModelMiddleware()]
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
## 中间件执行顺序
|
||||
|
||||
中间件按照添加顺序执行:
|
||||
|
||||
1. **simulate-streaming** (如果 streamOutput = false)
|
||||
2. **thinking-time** (如果是思考模型且有 onChunk)
|
||||
3. **provider-specific** (根据 provider 类型)
|
||||
4. **model-specific** (根据模型类型)
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. 中间件的执行顺序很重要,确保按正确顺序添加
|
||||
2. 避免添加冲突的中间件
|
||||
3. 某些中间件可能有依赖关系,需要确保依赖的中间件先添加
|
||||
4. 建议在开发环境下启用日志,以便调试中间件构建过程
|
||||
@ -1,45 +0,0 @@
|
||||
import { loggerService } from '@logger'
|
||||
import type { LanguageModelMiddleware } from 'ai'
|
||||
|
||||
const logger = loggerService.withContext('toolChoiceMiddleware')
|
||||
|
||||
/**
|
||||
* Tool Choice Middleware
|
||||
* Controls tool selection strategy across multiple rounds of tool calls:
|
||||
* - First round: Forces the model to call a specific tool (e.g., knowledge base search)
|
||||
* - Subsequent rounds: Allows the model to automatically choose any available tool
|
||||
*
|
||||
* This ensures knowledge base is consulted first while still enabling MCP tools
|
||||
* and other capabilities in follow-up interactions.
|
||||
*
|
||||
* @param forceFirstToolName - The tool name to force on the first round
|
||||
* @returns LanguageModelMiddleware
|
||||
*/
|
||||
export function toolChoiceMiddleware(forceFirstToolName: string): LanguageModelMiddleware {
|
||||
let toolCallRound = 0
|
||||
|
||||
return {
|
||||
middlewareVersion: 'v2',
|
||||
|
||||
transformParams: async ({ params }) => {
|
||||
toolCallRound++
|
||||
|
||||
const transformedParams = { ...params }
|
||||
|
||||
if (toolCallRound === 1) {
|
||||
// First round: force the specified tool
|
||||
logger.debug(`Round ${toolCallRound}: Forcing tool choice to '${forceFirstToolName}'`)
|
||||
transformedParams.toolChoice = {
|
||||
type: 'tool',
|
||||
toolName: forceFirstToolName
|
||||
}
|
||||
} else {
|
||||
// Subsequent rounds: allow automatic tool selection
|
||||
logger.debug(`Round ${toolCallRound}: Using automatic tool choice`)
|
||||
transformedParams.toolChoice = { type: 'auto' }
|
||||
}
|
||||
|
||||
return transformedParams
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,11 +1,24 @@
|
||||
import type { AiPlugin } from '@cherrystudio/ai-core'
|
||||
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
import { loggerService } from '@logger'
|
||||
import { isGemini3Model, isSupportedThinkingTokenQwenModel } from '@renderer/config/models'
|
||||
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
|
||||
import type { Assistant } from '@renderer/types'
|
||||
import { SystemProviderIds } from '@renderer/types'
|
||||
import { isOllamaProvider, isSupportEnableThinkingProvider } from '@renderer/utils/provider'
|
||||
|
||||
import type { AiSdkMiddlewareConfig } from '../middleware/AiSdkMiddlewareBuilder'
|
||||
import { getAiSdkProviderId } from '../provider/factory'
|
||||
import type { AiSdkMiddlewareConfig } from '../types/middlewareConfig'
|
||||
import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
|
||||
import { getReasoningTagName } from '../utils/reasoning'
|
||||
import { createNoThinkPlugin } from './noThinkPlugin'
|
||||
import { createOpenrouterGenerateImagePlugin } from './openrouterGenerateImagePlugin'
|
||||
import { createOpenrouterReasoningPlugin } from './openrouterReasoningPlugin'
|
||||
import { createQwenThinkingPlugin } from './qwenThinkingPlugin'
|
||||
import { createReasoningExtractionPlugin } from './reasoningExtractionPlugin'
|
||||
import { searchOrchestrationPlugin } from './searchOrchestrationPlugin'
|
||||
import { createSimulateStreamingPlugin } from './simulateStreamingPlugin'
|
||||
import { createSkipGeminiThoughtSignaturePlugin } from './skipGeminiThoughtSignaturePlugin'
|
||||
import { createTelemetryPlugin } from './telemetryPlugin'
|
||||
|
||||
const logger = loggerService.withContext('PluginBuilder')
|
||||
@ -28,6 +41,59 @@ export function buildPlugins(
|
||||
)
|
||||
}
|
||||
|
||||
// === AI SDK Middleware Plugins ===
|
||||
|
||||
// 0.1 Simulate streaming for non-streaming requests
|
||||
if (!middlewareConfig.streamOutput) {
|
||||
plugins.push(createSimulateStreamingPlugin())
|
||||
}
|
||||
|
||||
// 0.2 Reasoning extraction for OpenAI/Azure providers
|
||||
if (middlewareConfig.enableReasoning && middlewareConfig.provider) {
|
||||
const providerType = middlewareConfig.provider.type
|
||||
if (providerType === 'openai' || providerType === 'azure-openai') {
|
||||
const tagName = getReasoningTagName(middlewareConfig.model?.id.toLowerCase())
|
||||
plugins.push(createReasoningExtractionPlugin({ tagName }))
|
||||
}
|
||||
}
|
||||
|
||||
// 0.3 OpenRouter reasoning redaction
|
||||
if (middlewareConfig.provider?.id === SystemProviderIds.openrouter && middlewareConfig.enableReasoning) {
|
||||
plugins.push(createOpenrouterReasoningPlugin())
|
||||
}
|
||||
|
||||
// 0.4 OVMS no-think for MCP tools
|
||||
if (middlewareConfig.provider?.id === 'ovms' && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
|
||||
plugins.push(createNoThinkPlugin())
|
||||
}
|
||||
|
||||
// 0.5 Qwen thinking control for providers without enable_thinking support
|
||||
if (
|
||||
middlewareConfig.provider &&
|
||||
middlewareConfig.model &&
|
||||
!isOllamaProvider(middlewareConfig.provider) &&
|
||||
isSupportedThinkingTokenQwenModel(middlewareConfig.model) &&
|
||||
!isSupportEnableThinkingProvider(middlewareConfig.provider)
|
||||
) {
|
||||
const enableThinking = middlewareConfig.assistant?.settings?.reasoning_effort !== undefined
|
||||
plugins.push(createQwenThinkingPlugin(enableThinking))
|
||||
}
|
||||
|
||||
// 0.6 OpenRouter Gemini image generation
|
||||
if (
|
||||
middlewareConfig.model &&
|
||||
middlewareConfig.provider &&
|
||||
isOpenRouterGeminiGenerateImageModel(middlewareConfig.model, middlewareConfig.provider)
|
||||
) {
|
||||
plugins.push(createOpenrouterGenerateImagePlugin())
|
||||
}
|
||||
|
||||
// 0.7 Skip Gemini3 thought signature
|
||||
if (middlewareConfig.model && middlewareConfig.provider && isGemini3Model(middlewareConfig.model)) {
|
||||
const aiSdkId = getAiSdkProviderId(middlewareConfig.provider)
|
||||
plugins.push(createSkipGeminiThoughtSignaturePlugin(aiSdkId))
|
||||
}
|
||||
|
||||
// 1. 模型内置搜索
|
||||
if (middlewareConfig.enableWebSearch && middlewareConfig.webSearchPluginConfig) {
|
||||
plugins.push(webSearchPlugin(middlewareConfig.webSearchPluginConfig))
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import type { LanguageModelMiddleware } from 'ai'
|
||||
|
||||
const logger = loggerService.withContext('noThinkMiddleware')
|
||||
const logger = loggerService.withContext('noThinkPlugin')
|
||||
|
||||
/**
|
||||
* No Think Middleware
|
||||
@ -9,9 +10,9 @@ const logger = loggerService.withContext('noThinkMiddleware')
|
||||
* This prevents the model from generating unnecessary thinking process and returns results directly
|
||||
* @returns LanguageModelMiddleware
|
||||
*/
|
||||
export function noThinkMiddleware(): LanguageModelMiddleware {
|
||||
function createNoThinkMiddleware(): LanguageModelMiddleware {
|
||||
return {
|
||||
middlewareVersion: 'v2',
|
||||
specificationVersion: 'v3',
|
||||
|
||||
transformParams: async ({ params }) => {
|
||||
const transformedParams = { ...params }
|
||||
@ -50,3 +51,14 @@ export function noThinkMiddleware(): LanguageModelMiddleware {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const createNoThinkPlugin = () =>
|
||||
definePlugin({
|
||||
name: 'noThink',
|
||||
enforce: 'pre',
|
||||
|
||||
configureContext: (context) => {
|
||||
context.middlewares = context.middlewares || []
|
||||
context.middlewares.push(createNoThinkMiddleware())
|
||||
}
|
||||
})
|
||||
@ -1,3 +1,4 @@
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
import type { LanguageModelMiddleware } from 'ai'
|
||||
|
||||
/**
|
||||
@ -6,7 +7,7 @@ import type { LanguageModelMiddleware } from 'ai'
|
||||
* https://openrouter.ai/docs/features/multimodal/image-generation
|
||||
*
|
||||
* Remarks:
|
||||
* - The middleware declares middlewareVersion as 'v2'.
|
||||
* - The middleware declares specificationVersion as 'v3'.
|
||||
* - transformParams asynchronously clones the incoming params and sets
|
||||
* providerOptions.openrouter.modalities = ['image', 'text'], preserving other providerOptions and
|
||||
* openrouter fields when present.
|
||||
@ -15,9 +16,9 @@ import type { LanguageModelMiddleware } from 'ai'
|
||||
*
|
||||
* @returns LanguageModelMiddleware - a middleware that augments providerOptions for OpenRouter to include image and text modalities.
|
||||
*/
|
||||
export function openrouterGenerateImageMiddleware(): LanguageModelMiddleware {
|
||||
function createOpenrouterGenerateImageMiddleware(): LanguageModelMiddleware {
|
||||
return {
|
||||
middlewareVersion: 'v2',
|
||||
specificationVersion: 'v3',
|
||||
|
||||
transformParams: async ({ params }) => {
|
||||
const transformedParams = { ...params }
|
||||
@ -25,9 +26,19 @@ export function openrouterGenerateImageMiddleware(): LanguageModelMiddleware {
|
||||
...transformedParams.providerOptions,
|
||||
openrouter: { ...transformedParams.providerOptions?.openrouter, modalities: ['image', 'text'] }
|
||||
}
|
||||
transformedParams
|
||||
|
||||
return transformedParams
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const createOpenrouterGenerateImagePlugin = () =>
|
||||
definePlugin({
|
||||
name: 'openrouterGenerateImage',
|
||||
enforce: 'pre',
|
||||
|
||||
configureContext: (context) => {
|
||||
context.middlewares = context.middlewares || []
|
||||
context.middlewares.push(createOpenrouterGenerateImageMiddleware())
|
||||
}
|
||||
})
|
||||
@ -1,4 +1,5 @@
|
||||
import type { LanguageModelV2StreamPart } from '@ai-sdk/provider'
|
||||
import type { LanguageModelV3StreamPart } from '@ai-sdk/provider'
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
import type { LanguageModelMiddleware } from 'ai'
|
||||
|
||||
/**
|
||||
@ -6,10 +7,10 @@ import type { LanguageModelMiddleware } from 'ai'
|
||||
*
|
||||
* @returns LanguageModelMiddleware - a middleware filter redacted block
|
||||
*/
|
||||
export function openrouterReasoningMiddleware(): LanguageModelMiddleware {
|
||||
function createOpenrouterReasoningMiddleware(): LanguageModelMiddleware {
|
||||
const REDACTED_BLOCK = '[REDACTED]'
|
||||
return {
|
||||
middlewareVersion: 'v2',
|
||||
specificationVersion: 'v3',
|
||||
wrapGenerate: async ({ doGenerate }) => {
|
||||
const { content, ...rest } = await doGenerate()
|
||||
const modifiedContent = content.map((part) => {
|
||||
@ -27,10 +28,10 @@ export function openrouterReasoningMiddleware(): LanguageModelMiddleware {
|
||||
const { stream, ...rest } = await doStream()
|
||||
return {
|
||||
stream: stream.pipeThrough(
|
||||
new TransformStream<LanguageModelV2StreamPart, LanguageModelV2StreamPart>({
|
||||
new TransformStream<LanguageModelV3StreamPart, LanguageModelV3StreamPart>({
|
||||
transform(
|
||||
chunk: LanguageModelV2StreamPart,
|
||||
controller: TransformStreamDefaultController<LanguageModelV2StreamPart>
|
||||
chunk: LanguageModelV3StreamPart,
|
||||
controller: TransformStreamDefaultController<LanguageModelV3StreamPart>
|
||||
) {
|
||||
if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
|
||||
controller.enqueue({
|
||||
@ -48,3 +49,14 @@ export function openrouterReasoningMiddleware(): LanguageModelMiddleware {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const createOpenrouterReasoningPlugin = () =>
|
||||
definePlugin({
|
||||
name: 'openrouterReasoning',
|
||||
enforce: 'pre',
|
||||
|
||||
configureContext: (context) => {
|
||||
context.middlewares = context.middlewares || []
|
||||
context.middlewares.push(createOpenrouterReasoningMiddleware())
|
||||
}
|
||||
})
|
||||
@ -1,3 +1,4 @@
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
import type { LanguageModelMiddleware } from 'ai'
|
||||
|
||||
/**
|
||||
@ -7,11 +8,11 @@ import type { LanguageModelMiddleware } from 'ai'
|
||||
* @param enableThinking - Whether thinking mode is enabled (based on reasoning_effort !== undefined)
|
||||
* @returns LanguageModelMiddleware
|
||||
*/
|
||||
export function qwenThinkingMiddleware(enableThinking: boolean): LanguageModelMiddleware {
|
||||
function createQwenThinkingMiddleware(enableThinking: boolean): LanguageModelMiddleware {
|
||||
const suffix = enableThinking ? ' /think' : ' /no_think'
|
||||
|
||||
return {
|
||||
middlewareVersion: 'v2',
|
||||
specificationVersion: 'v3',
|
||||
|
||||
transformParams: async ({ params }) => {
|
||||
const transformedParams = { ...params }
|
||||
@ -37,3 +38,14 @@ export function qwenThinkingMiddleware(enableThinking: boolean): LanguageModelMi
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const createQwenThinkingPlugin = (enableThinking: boolean) =>
|
||||
definePlugin({
|
||||
name: 'qwenThinking',
|
||||
enforce: 'pre',
|
||||
|
||||
configureContext: (context) => {
|
||||
context.middlewares = context.middlewares || []
|
||||
context.middlewares.push(createQwenThinkingMiddleware(enableThinking))
|
||||
}
|
||||
})
|
||||
22
src/renderer/src/aiCore/plugins/reasoningExtractionPlugin.ts
Normal file
22
src/renderer/src/aiCore/plugins/reasoningExtractionPlugin.ts
Normal file
@ -0,0 +1,22 @@
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
import { extractReasoningMiddleware } from 'ai'
|
||||
|
||||
/**
|
||||
* Reasoning Extraction Plugin
|
||||
* Extracts reasoning/thinking tags from OpenAI/Azure responses
|
||||
* Uses AI SDK's built-in extractReasoningMiddleware
|
||||
*/
|
||||
export const createReasoningExtractionPlugin = (options: { tagName?: string } = {}) =>
|
||||
definePlugin({
|
||||
name: 'reasoningExtraction',
|
||||
enforce: 'pre',
|
||||
|
||||
configureContext: (context) => {
|
||||
context.middlewares = context.middlewares || []
|
||||
context.middlewares.push(
|
||||
extractReasoningMiddleware({
|
||||
tagName: options.tagName || 'thinking'
|
||||
})
|
||||
)
|
||||
}
|
||||
})
|
||||
@ -6,7 +6,13 @@
|
||||
* 2. transformParams: 根据意图分析结果动态添加对应的工具
|
||||
* 3. onRequestEnd: 自动记忆存储
|
||||
*/
|
||||
import { type AiRequestContext, definePlugin } from '@cherrystudio/ai-core'
|
||||
import {
|
||||
type AiPlugin,
|
||||
type AiRequestContext,
|
||||
definePlugin,
|
||||
type StreamTextParams,
|
||||
type StreamTextResult
|
||||
} from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
// import { generateObject } from '@cherrystudio/ai-core'
|
||||
import {
|
||||
@ -236,18 +242,21 @@ async function storeConversationMemory(
|
||||
/**
|
||||
* 🎯 搜索编排插件
|
||||
*/
|
||||
export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string) => {
|
||||
export const searchOrchestrationPlugin = (
|
||||
assistant: Assistant,
|
||||
topicId: string
|
||||
): AiPlugin<StreamTextParams, StreamTextResult> => {
|
||||
// 存储意图分析结果
|
||||
const intentAnalysisResults: { [requestId: string]: ExtractResults } = {}
|
||||
const userMessages: { [requestId: string]: ModelMessage } = {}
|
||||
|
||||
return definePlugin({
|
||||
return definePlugin<StreamTextParams, StreamTextResult>({
|
||||
name: 'search-orchestration',
|
||||
enforce: 'pre', // 确保在其他插件之前执行
|
||||
/**
|
||||
* 🔍 Step 1: 意图识别阶段
|
||||
*/
|
||||
onRequestStart: async (context: AiRequestContext) => {
|
||||
onRequestStart: async (context) => {
|
||||
// 没开启任何搜索则不进行意图分析
|
||||
if (!(assistant.webSearchProviderId || assistant.knowledge_bases?.length || assistant.enableMemory)) return
|
||||
|
||||
@ -297,7 +306,7 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
|
||||
/**
|
||||
* 🔧 Step 2: 工具配置阶段
|
||||
*/
|
||||
transformParams: async (params: any, context: AiRequestContext) => {
|
||||
transformParams: async (params, context) => {
|
||||
// logger.info('🔧 Configuring tools based on intent...', context.requestId)
|
||||
|
||||
try {
|
||||
@ -309,7 +318,7 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
|
||||
|
||||
// 确保 tools 对象存在
|
||||
if (!params.tools) {
|
||||
params.tools = {}
|
||||
return { tools: {} }
|
||||
}
|
||||
|
||||
// 🌐 网络搜索工具配置
|
||||
@ -371,11 +380,12 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
|
||||
* 💾 Step 3: 记忆存储阶段
|
||||
*/
|
||||
|
||||
onRequestEnd: async (context: AiRequestContext) => {
|
||||
onRequestEnd: async (context) => {
|
||||
// context.isAnalyzing = false
|
||||
// logger.info('context.isAnalyzing', context, result)
|
||||
// logger.info('💾 Starting memory storage...', context.requestId)
|
||||
try {
|
||||
// ✅ 类型安全访问:context.originalParams 已通过泛型正确类型化
|
||||
const messages = context.originalParams.messages
|
||||
|
||||
if (messages && assistant) {
|
||||
|
||||
18
src/renderer/src/aiCore/plugins/simulateStreamingPlugin.ts
Normal file
18
src/renderer/src/aiCore/plugins/simulateStreamingPlugin.ts
Normal file
@ -0,0 +1,18 @@
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
import { simulateStreamingMiddleware } from 'ai'
|
||||
|
||||
/**
|
||||
* Simulate Streaming Plugin
|
||||
* Converts non-streaming responses to streaming format
|
||||
* Uses AI SDK's built-in simulateStreamingMiddleware
|
||||
*/
|
||||
export const createSimulateStreamingPlugin = () =>
|
||||
definePlugin({
|
||||
name: 'simulateStreaming',
|
||||
enforce: 'pre',
|
||||
|
||||
configureContext: (context) => {
|
||||
context.middlewares = context.middlewares || []
|
||||
context.middlewares.push(simulateStreamingMiddleware())
|
||||
}
|
||||
})
|
||||
@ -1,3 +1,4 @@
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
import type { LanguageModelMiddleware } from 'ai'
|
||||
|
||||
/**
|
||||
@ -8,10 +9,10 @@ import type { LanguageModelMiddleware } from 'ai'
|
||||
* @param aiSdkId AI SDK Provider ID
|
||||
* @returns LanguageModelMiddleware
|
||||
*/
|
||||
export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelMiddleware {
|
||||
function createSkipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelMiddleware {
|
||||
const MAGIC_STRING = 'skip_thought_signature_validator'
|
||||
return {
|
||||
middlewareVersion: 'v2',
|
||||
specificationVersion: 'v3',
|
||||
|
||||
transformParams: async ({ params }) => {
|
||||
const transformedParams = { ...params }
|
||||
@ -34,3 +35,14 @@ export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageM
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const createSkipGeminiThoughtSignaturePlugin = (aiSdkId: string) =>
|
||||
definePlugin({
|
||||
name: 'skipGeminiThoughtSignature',
|
||||
enforce: 'pre',
|
||||
|
||||
configureContext: (context) => {
|
||||
context.middlewares = context.middlewares || []
|
||||
context.middlewares.push(createSkipGeminiThoughtSignatureMiddleware(aiSdkId))
|
||||
}
|
||||
})
|
||||
@ -47,7 +47,7 @@ import { getMaxTokens, getTemperature, getTopP } from './modelParameters'
|
||||
|
||||
const logger = loggerService.withContext('parameterBuilder')
|
||||
|
||||
type ProviderDefinedTool = Extract<Tool<any, any>, { type: 'provider-defined' }>
|
||||
type ProviderDefinedTool = Extract<Tool<any, any>, { type: 'provider' }>
|
||||
|
||||
function mapVertexAIGatewayModelToProviderId(model: Model): BaseProviderId | undefined {
|
||||
if (isAnthropicModel(model)) {
|
||||
@ -62,9 +62,7 @@ function mapVertexAIGatewayModelToProviderId(model: Model): BaseProviderId | und
|
||||
if (isOpenAIModel(model)) {
|
||||
return 'openai'
|
||||
}
|
||||
logger.warn(
|
||||
`[mapVertexAIGatewayModelToProviderId] Unknown model type for AI Gateway: ${model.id}. Web search will not be enabled.`
|
||||
)
|
||||
logger.warn(`Unknown model type for AI Gateway: ${model.id}. Web search will not be enabled.`)
|
||||
return undefined
|
||||
}
|
||||
|
||||
|
||||
26
src/renderer/src/aiCore/types/middlewareConfig.ts
Normal file
26
src/renderer/src/aiCore/types/middlewareConfig.ts
Normal file
@ -0,0 +1,26 @@
|
||||
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
import type { MCPTool } from '@renderer/types'
|
||||
import type { Assistant, Message, Model, Provider } from '@renderer/types'
|
||||
import type { Chunk } from '@renderer/types/chunk'
|
||||
|
||||
/**
|
||||
* AI SDK 中间件配置项(用于插件构建)
|
||||
*/
|
||||
export interface AiSdkMiddlewareConfig {
|
||||
streamOutput: boolean
|
||||
onChunk?: (chunk: Chunk) => void
|
||||
model?: Model
|
||||
provider?: Provider
|
||||
assistant?: Assistant
|
||||
enableReasoning: boolean
|
||||
isPromptToolUse: boolean
|
||||
isSupportedToolUse: boolean
|
||||
isImageGenerationEndpoint: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
enableUrlContext: boolean
|
||||
mcpTools?: MCPTool[]
|
||||
uiMessages?: Message[]
|
||||
webSearchPluginConfig?: WebSearchPluginConfig
|
||||
knowledgeRecognition?: 'off' | 'on'
|
||||
}
|
||||
@ -726,3 +726,21 @@ export function getCustomParameters(assistant: Assistant): Record<string, any> {
|
||||
}, {}) || {}
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get reasoning tag name based on model ID
|
||||
* Used for extractReasoningMiddleware configuration
|
||||
*/
|
||||
export function getReasoningTagName(modelId: string | undefined): string {
|
||||
const tagName = {
|
||||
reasoning: 'reasoning',
|
||||
think: 'think',
|
||||
thought: 'thought',
|
||||
seedThink: 'seed:think'
|
||||
}
|
||||
|
||||
if (modelId?.includes('gpt-oss')) return tagName.reasoning
|
||||
if (modelId?.includes('gemini')) return tagName.thought
|
||||
if (modelId?.includes('seed-oss-36b')) return tagName.seedThink
|
||||
return tagName.think
|
||||
}
|
||||
|
||||
@ -2,8 +2,8 @@
|
||||
* 职责:提供原子化的、无状态的API调用函数
|
||||
*/
|
||||
import { loggerService } from '@logger'
|
||||
import type { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/AiSdkMiddlewareBuilder'
|
||||
import { buildStreamTextParams } from '@renderer/aiCore/prepareParams'
|
||||
import type { AiSdkMiddlewareConfig } from '@renderer/aiCore/types/middlewareConfig'
|
||||
import { isDedicatedImageGenerationModel, isEmbeddingModel, isFunctionCallingModel } from '@renderer/config/models'
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import i18n from '@renderer/i18n'
|
||||
|
||||
@ -1,9 +1,14 @@
|
||||
import type OpenAI from '@cherrystudio/openai'
|
||||
import type { NotUndefined } from '@types'
|
||||
import type { ImageModel, LanguageModel } from 'ai'
|
||||
import type { generateObject, generateText, ModelMessage, streamObject, streamText } from 'ai'
|
||||
import type { generateText, ModelMessage, streamText } from 'ai'
|
||||
import * as z from 'zod'
|
||||
|
||||
/**
|
||||
* 渲染器侧参数类型(不包含 model 和 messages,因为它们会单独处理)
|
||||
* 注意:这与 @cherrystudio/ai-core 导出的完整参数类型不同
|
||||
* - @cherrystudio/ai-core 的 StreamTextParams: 完整的 AI SDK 参数(用于插件系统)
|
||||
* - 此处的 StreamTextParams: 去除 model/messages 的参数(用于渲染器参数构建)
|
||||
*/
|
||||
export type StreamTextParams = Omit<Parameters<typeof streamText>[0], 'model' | 'messages'> &
|
||||
(
|
||||
| {
|
||||
@ -15,6 +20,11 @@ export type StreamTextParams = Omit<Parameters<typeof streamText>[0], 'model' |
|
||||
prompt?: never
|
||||
}
|
||||
)
|
||||
|
||||
/**
|
||||
* 渲染器侧参数类型(不包含 model 和 messages)
|
||||
* 注意:这与 @cherrystudio/ai-core 导出的完整参数类型不同
|
||||
*/
|
||||
export type GenerateTextParams = Omit<Parameters<typeof generateText>[0], 'model' | 'messages'> &
|
||||
(
|
||||
| {
|
||||
@ -26,10 +36,6 @@ export type GenerateTextParams = Omit<Parameters<typeof generateText>[0], 'model
|
||||
prompt?: never
|
||||
}
|
||||
)
|
||||
export type StreamObjectParams = Omit<Parameters<typeof streamObject>[0], 'model'>
|
||||
export type GenerateObjectParams = Omit<Parameters<typeof generateObject>[0], 'model'>
|
||||
|
||||
export type AiSdkModel = LanguageModel | ImageModel
|
||||
|
||||
/**
|
||||
* Constrains the verbosity of the model's response. Lower values will result in more concise responses, while higher values will result in more verbose responses.
|
||||
|
||||
43
yarn.lock
43
yarn.lock
@ -205,6 +205,18 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/openai-compatible@npm:1.0.28":
|
||||
version: 1.0.28
|
||||
resolution: "@ai-sdk/openai-compatible@npm:1.0.28"
|
||||
dependencies:
|
||||
"@ai-sdk/provider": "npm:2.0.0"
|
||||
"@ai-sdk/provider-utils": "npm:3.0.18"
|
||||
peerDependencies:
|
||||
zod: ^3.25.76 || ^4.1.8
|
||||
checksum: 10c0/f484774e0094a12674f392d925038a296191723b4c76bd833eabf1b334cf3c84fe77a2e2c5fbac974ec5e18340e113c6a81c86d957c9529a7a60e87cd390ada8
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/openai-compatible@npm:2.0.0, @ai-sdk/openai-compatible@npm:^2.0.0":
|
||||
version: 2.0.0
|
||||
resolution: "@ai-sdk/openai-compatible@npm:2.0.0"
|
||||
@ -217,15 +229,15 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/openai-compatible@npm:^1.0.19":
|
||||
version: 1.0.27
|
||||
resolution: "@ai-sdk/openai-compatible@npm:1.0.27"
|
||||
"@ai-sdk/openai-compatible@patch:@ai-sdk/openai-compatible@npm%3A1.0.28#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch":
|
||||
version: 1.0.28
|
||||
resolution: "@ai-sdk/openai-compatible@patch:@ai-sdk/openai-compatible@npm%3A1.0.28#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.28-5705188855.patch::version=1.0.28&hash=f2cb20"
|
||||
dependencies:
|
||||
"@ai-sdk/provider": "npm:2.0.0"
|
||||
"@ai-sdk/provider-utils": "npm:3.0.17"
|
||||
"@ai-sdk/provider-utils": "npm:3.0.18"
|
||||
peerDependencies:
|
||||
zod: ^3.25.76 || ^4.1.8
|
||||
checksum: 10c0/9f656e4f2ea4d714dc05be588baafd962b2e0360e9195fef373e745efeb20172698ea87e1033c0c5e1f1aa6e0db76a32629427bc8433eb42bd1a0ee00e04af0c
|
||||
checksum: 10c0/0b1d99fe8ce506e5c0a3703ae0511ac2017781584074d41faa2df82923c64eb1229ffe9f036de150d0248923613c761a463fe89d5923493983e0463a1101e792
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
@ -277,16 +289,16 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/provider-utils@npm:3.0.17, @ai-sdk/provider-utils@npm:^3.0.10, @ai-sdk/provider-utils@npm:^3.0.17":
|
||||
version: 3.0.17
|
||||
resolution: "@ai-sdk/provider-utils@npm:3.0.17"
|
||||
"@ai-sdk/provider-utils@npm:3.0.18":
|
||||
version: 3.0.18
|
||||
resolution: "@ai-sdk/provider-utils@npm:3.0.18"
|
||||
dependencies:
|
||||
"@ai-sdk/provider": "npm:2.0.0"
|
||||
"@standard-schema/spec": "npm:^1.0.0"
|
||||
eventsource-parser: "npm:^3.0.6"
|
||||
peerDependencies:
|
||||
zod: ^3.25.76 || ^4.1.8
|
||||
checksum: 10c0/1bae6dc4cacd0305b6aa152f9589bbd61c29f150155482c285a77f83d7ed416d52bc2aa7fdaba2e5764530392d9e8f799baea34a63dce6c72ecd3de364dc62d1
|
||||
checksum: 10c0/209c15b0dceef0ba95a7d3de544be0a417ad4a0bd5143496b3966a35fedf144156d93a42ff8c3d7db56781b9836bafc8c132c98978c49240e55bc1a36e18a67f
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
@ -316,6 +328,19 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/provider-utils@npm:^3.0.10, @ai-sdk/provider-utils@npm:^3.0.17":
|
||||
version: 3.0.17
|
||||
resolution: "@ai-sdk/provider-utils@npm:3.0.17"
|
||||
dependencies:
|
||||
"@ai-sdk/provider": "npm:2.0.0"
|
||||
"@standard-schema/spec": "npm:^1.0.0"
|
||||
eventsource-parser: "npm:^3.0.6"
|
||||
peerDependencies:
|
||||
zod: ^3.25.76 || ^4.1.8
|
||||
checksum: 10c0/1bae6dc4cacd0305b6aa152f9589bbd61c29f150155482c285a77f83d7ed416d52bc2aa7fdaba2e5764530392d9e8f799baea34a63dce6c72ecd3de364dc62d1
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/provider@npm:2.0.0, @ai-sdk/provider@npm:^2.0.0":
|
||||
version: 2.0.0
|
||||
resolution: "@ai-sdk/provider@npm:2.0.0"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user