diff --git a/docs/zh/guides/ai-core-architecture.md b/docs/zh/guides/ai-core-architecture.md new file mode 100644 index 0000000000..0b377dfa67 --- /dev/null +++ b/docs/zh/guides/ai-core-architecture.md @@ -0,0 +1,2215 @@ +# Cherry Studio AI Core 架构文档 + +> **版本**: v2.0 (基于 @cherrystudio/ai-core 重构后) +> **更新日期**: 2025-01-02 +> **适用范围**: Cherry Studio v1.7.7+ + +本文档详细描述了 Cherry Studio 从用户交互到 AI SDK 调用的完整数据流和架构设计,是理解应用核心功能的关键文档。 + +--- + +## 📖 目录 + +1. [整体架构概览](#1-整体架构概览) +2. [完整调用流程](#2-完整调用流程) +3. [核心组件详解](#3-核心组件详解) +4. [Provider 系统架构](#4-provider-系统架构) +5. [插件与中间件系统](#5-插件与中间件系统) +6. [消息处理流程](#6-消息处理流程) +7. [类型安全机制](#7-类型安全机制) +8. [Trace 和可观测性](#8-trace-和可观测性) +9. [错误处理机制](#9-错误处理机制) +10. [性能优化](#10-性能优化) + +--- + +## 1. 整体架构概览 + +### 1.1 架构分层 + +Cherry Studio 的 AI 调用采用清晰的分层架构: + +``` +┌─────────────────────────────────────────────────────────────┐ +│ UI Layer │ +│ (React Components, Redux Store, User Interactions) │ +└────────────────────────┬────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Service Layer │ +│ src/renderer/src/services/ │ +│ ┌────────────────────────────────────────────────────┐ │ +│ │ ApiService.ts │ │ +│ │ - transformMessagesAndFetch() │ │ +│ │ - fetchChatCompletion() │ │ +│ │ - fetchMessagesSummary() │ │ +│ └────────────────────────────────────────────────────┘ │ +└────────────────────────┬────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ AI Provider Layer │ +│ src/renderer/src/aiCore/ │ +│ ┌────────────────────────────────────────────────────┐ │ +│ │ ModernAiProvider (index_new.ts) │ │ +│ │ - completions() │ │ +│ │ - modernCompletions() │ │ +│ │ - _completionsForTrace() │ │ +│ └────────────────────────────────────────────────────┘ │ +│ ┌────────────────────────────────────────────────────┐ │ +│ │ Provider Config & Adaptation │ │ +│ │ - providerConfig.ts │ │ +│ │ - providerToAiSdkConfig() │ │ +│ └────────────────────────────────────────────────────┘ │ +└────────────────────────┬────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Core Package Layer │ +│ packages/aiCore/ (@cherrystudio/ai-core) │ +│ ┌────────────────────────────────────────────────────┐ │ +│ │ RuntimeExecutor │ │ +│ │ - streamText() │ │ +│ │ - generateText() │ │ +│ │ - generateImage() │ │ +│ └────────────────────────────────────────────────────┘ │ +│ ┌────────────────────────────────────────────────────┐ │ +│ │ Provider Extension System │ │ +│ │ - ProviderExtension (LRU Cache) │ │ +│ │ - ExtensionRegistry │ │ +│ │ - OpenAI/Anthropic/Google Extensions │ │ +│ └────────────────────────────────────────────────────┘ │ +│ ┌────────────────────────────────────────────────────┐ │ +│ │ Plugin Engine │ │ +│ │ - PluginManager │ │ +│ │ - AiPlugin Lifecycle Hooks │ │ +│ └────────────────────────────────────────────────────┘ │ +└────────────────────────┬────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ AI SDK Layer │ +│ Vercel AI SDK v6.x (@ai-sdk/*) │ +│ ┌─────────────────────────────────────────────────────┐ │ +│ │ Provider Implementations │ │ +│ │ - @ai-sdk/openai │ │ +│ │ - @ai-sdk/anthropic │ │ +│ │ - @ai-sdk/google-generative-ai │ │ +│ │ - @ai-sdk/mistral │ │ +│ └─────────────────────────────────────────────────────┘ │ +│ ┌─────────────────────────────────────────────────────┐ │ +│ │ Core Functions │ │ +│ │ - streamText() │ │ +│ │ - generateText() │ │ +│ └─────────────────────────────────────────────────────┘ │ +└────────────────────────┬────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ LLM Provider API +│ (OpenAI, Anthropic, Google, etc.) │ +└─────────────────────────────────────────────────────────────┘ +``` + +### 1.2 核心设计理念 + +#### 1.2.1 关注点分离 (Separation of Concerns) + +- **Service Layer**: 业务逻辑、消息准备、工具调用 +- **AI Provider Layer**: Provider 适配、参数转换、插件构建 +- **Core Package**: 统一 API、Provider 管理、插件执行 +- **AI SDK Layer**: 实际的 LLM API 调用 + +#### 1.2.2 类型安全优先 + +- 端到端 TypeScript 类型推断 +- Provider Settings 自动关联 +- 编译时参数验证 + +#### 1.2.3 可扩展性 + +- 插件化架构 (AiPlugin) +- Provider Extension 系统 +- 中间件机制 + +--- + +## 2. 完整调用流程 + +### 2.1 从用户输入到 LLM 响应的完整流程 + +#### 流程图 + +``` +User Input (UI) + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ 1. UI Event Handler │ +│ - ChatView/MessageInput Component │ +│ - Redux dispatch action │ +└─────────────────────────┬───────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ 2. ApiService.transformMessagesAndFetch() │ +│ Location: src/renderer/src/services/ApiService.ts:92 │ +│ │ +│ Step 2.1: ConversationService.prepareMessagesForModel() │ +│ ├─ 消息格式转换 (UI Message → Model Message) │ +│ ├─ 处理图片/文件附件 │ +│ └─ 应用消息过滤规则 │ +│ │ +│ Step 2.2: replacePromptVariables() │ +│ └─ 替换 system prompt 中的变量 │ +│ │ +│ Step 2.3: injectUserMessageWithKnowledgeSearchPrompt() │ +│ └─ 注入知识库搜索提示(如果启用) │ +│ │ +│ Step 2.4: fetchChatCompletion() ────────────────────────► │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ 3. ApiService.fetchChatCompletion() │ +│ Location: src/renderer/src/services/ApiService.ts:139 │ +│ │ +│ Step 3.1: getProviderByModel() + API Key Rotation │ +│ ├─ 获取 provider 配置 │ +│ ├─ 应用 API Key 轮换(多 key 负载均衡) │ +│ └─ 创建 providerWithRotatedKey │ +│ │ +│ Step 3.2: new ModernAiProvider(model, provider) │ +│ └─ 初始化 AI Provider 实例 │ +│ │ +│ Step 3.3: buildStreamTextParams() │ +│ ├─ 构建 AI SDK 参数 │ +│ ├─ 处理 MCP 工具 │ +│ ├─ 处理 Web Search 配置 │ +│ └─ 返回 aiSdkParams + capabilities │ +│ │ +│ Step 3.4: buildPlugins(middlewareConfig) │ +│ └─ 根据 capabilities 构建插件数组 │ +│ │ +│ Step 3.5: AI.completions(modelId, params, config) ──────► │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ 4. ModernAiProvider.completions() │ +│ Location: src/renderer/src/aiCore/index_new.ts:116 │ +│ │ +│ Step 4.1: providerToAiSdkConfig() │ +│ ├─ 转换 Cherry Provider → AI SDK Config │ +│ ├─ 设置 providerId ('openai', 'anthropic', etc.) │ +│ └─ 设置 providerSettings (apiKey, baseURL, etc.) │ +│ │ +│ Step 4.2: Claude Code OAuth 特殊处理 │ +│ └─ 注入 Claude Code system message(如果是 OAuth) │ +│ │ +│ Step 4.3: 路由选择 │ +│ ├─ 如果启用 trace → _completionsForTrace() │ +│ └─ 否则 → _completionsOrImageGeneration() │ +└─────────────────────────┬───────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ 5. ModernAiProvider._completionsOrImageGeneration() │ +│ Location: src/renderer/src/aiCore/index_new.ts:167 │ +│ │ +│ 判断: │ +│ ├─ 图像生成端点 → legacyProvider.completions() │ +│ └─ 文本生成 → modernCompletions() ──────────────────────► │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ 6. ModernAiProvider.modernCompletions() │ +│ Location: src/renderer/src/aiCore/index_new.ts:284 │ +│ │ +│ Step 6.1: buildPlugins(config) │ +│ └─ 构建插件数组(Reasoning, ToolUse, WebSearch, etc.) │ +│ │ +│ Step 6.2: createExecutor() ─────────────────────────────► │ +│ └─ 创建 RuntimeExecutor 实例 │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ 7. packages/aiCore: createExecutor() │ +│ Location: packages/aiCore/src/core/runtime/index.ts:25 │ +│ │ +│ Step 7.1: extensionRegistry.createProvider() │ +│ ├─ 解析 providerId (支持别名和变体) │ +│ ├─ 获取 ProviderExtension 实例 │ +│ ├─ 计算 settings hash │ +│ ├─ LRU 缓存查找 │ +│ │ ├─ Cache hit → 返回缓存实例 │ +│ │ └─ Cache miss → 创建新实例 │ +│ └─ 返回 ProviderV3 实例 │ +│ │ +│ Step 7.2: RuntimeExecutor.create() │ +│ ├─ 创建 RuntimeExecutor 实例 │ +│ ├─ 注入 provider 引用 │ +│ ├─ 初始化 ModelResolver │ +│ └─ 初始化 PluginEngine │ +│ │ +│ 返回: RuntimeExecutor 实例 ───────────────────────────► │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ 8. RuntimeExecutor.streamText() │ +│ Location: packages/aiCore/src/core/runtime/executor.ts │ +│ │ +│ Step 8.1: 插件生命周期 - onRequestStart │ +│ └─ 执行所有插件的 onRequestStart 钩子 │ +│ │ +│ Step 8.2: 插件转换 - transformParams │ +│ └─ 链式执行所有插件的参数转换 │ +│ │ +│ Step 8.3: modelResolver.resolveModel() │ +│ └─ 解析 model string → LanguageModel 实例 │ +│ │ +│ Step 8.4: 调用 AI SDK streamText() ──────────────────────►│ +│ └─ 传入解析后的 model 和转换后的 params │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ 9. AI SDK: streamText() │ +│ Location: node_modules/ai/core/generate-text/stream-text │ +│ │ +│ Step 9.1: 参数验证 │ +│ Step 9.2: 调用 provider.doStream() │ +│ Step 9.3: 返回 StreamTextResult │ +│ └─ textStream, fullStream, usage, etc. │ +└─────────────────────────┬───────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ 10. 流式数据处理 │ +│ Location: src/renderer/src/aiCore/chunk/ │ +│ │ +│ Step 10.1: AiSdkToChunkAdapter.processStream() │ +│ ├─ 监听 AI SDK 的 textStream │ +│ ├─ 转换为 Cherry Chunk 格式 │ +│ ├─ 处理 tool calls │ +│ ├─ 处理 reasoning blocks │ +│ └─ 发送 chunk 到 onChunkReceived callback │ +│ │ +│ Step 10.2: StreamProcessingService │ +│ └─ 处理不同类型的 chunk 并更新 UI │ +└─────────────────────────┬───────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ 11. 插件生命周期 - 完成阶段 │ +│ │ +│ Step 11.1: transformResult │ +│ └─ 插件可以修改最终结果 │ +│ │ +│ Step 11.2: onRequestEnd │ +│ └─ 执行所有插件的完成钩子 │ +└─────────────────────────┬───────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ 12. UI Update │ +│ - Redux state 更新 │ +│ - React 组件重渲染 │ +│ - 显示完整响应 │ +└─────────────────────────────────────────────────────────────┘ +``` + +### 2.2 关键时序说明 + +#### 2.2.1 Provider 实例创建(LRU 缓存机制) + +```typescript +// 场景 1: 首次请求 OpenAI (Cache Miss) +const executor1 = await createExecutor('openai', { apiKey: 'sk-xxx' }) +// → extensionRegistry.createProvider('openai', { apiKey: 'sk-xxx' }) +// → 计算 hash: "abc123" +// → LRU cache miss +// → OpenAIExtension.factory() 创建新 provider +// → 存入 LRU: cache.set("abc123", provider) + +// 场景 2: 相同配置的第二次请求 (Cache Hit) +const executor2 = await createExecutor('openai', { apiKey: 'sk-xxx' }) +// → 计算 hash: "abc123" (相同) +// → LRU cache hit! +// → 直接返回缓存的 provider +// → executor1 和 executor2 共享同一个 provider 实例 + +// 场景 3: 不同配置 (Cache Miss + 新实例) +const executor3 = await createExecutor('openai', { + apiKey: 'sk-yyy', // 不同的 key + baseURL: 'https://custom.com/v1' +}) +// → 计算 hash: "def456" (不同) +// → LRU cache miss +// → 创建新的独立 provider 实例 +// → 存入 LRU: cache.set("def456", provider2) +``` + +#### 2.2.2 插件执行顺序 + +```typescript +// 示例:启用 Reasoning + ToolUse + WebSearch +plugins = [ReasoningPlugin, ToolUsePlugin, WebSearchPlugin] + +// 执行顺序: +1. onRequestStart: Reasoning → ToolUse → WebSearch +2. transformParams: Reasoning → ToolUse → WebSearch (链式) +3. [AI SDK 调用] +4. transformResult: WebSearch → ToolUse → Reasoning (反向) +5. onRequestEnd: WebSearch → ToolUse → Reasoning (反向) +``` + +--- + +## 3. 核心组件详解 + +### 3.1 ApiService Layer + +#### 文件位置 +`src/renderer/src/services/ApiService.ts` + +#### 核心职责 + +1. **消息准备和转换** +2. **MCP 工具集成** +3. **知识库搜索注入** +4. **API Key 轮换** +5. **调用 ModernAiProvider** + +#### 关键函数详解 + +##### 3.1.1 `transformMessagesAndFetch()` + +**签名**: +```typescript +async function transformMessagesAndFetch( + request: { + messages: Message[] + assistant: Assistant + blockManager: BlockManager + assistantMsgId: string + callbacks: StreamProcessorCallbacks + topicId?: string + options: { + signal?: AbortSignal + timeout?: number + headers?: Record + } + }, + onChunkReceived: (chunk: Chunk) => void +): Promise +``` + +**执行流程**: + +```typescript +// Step 1: 消息准备 +const { modelMessages, uiMessages } = + await ConversationService.prepareMessagesForModel(messages, assistant) + +// modelMessages: 转换为 LLM 理解的格式 +// uiMessages: 保留原始 UI 消息(用于某些特殊场景) + +// Step 2: 替换 prompt 变量 +assistant.prompt = await replacePromptVariables( + assistant.prompt, + assistant.model?.name +) +// 例如: "{model_name}" → "GPT-4" + +// Step 3: 注入知识库搜索 +await injectUserMessageWithKnowledgeSearchPrompt({ + modelMessages, + assistant, + assistantMsgId, + topicId, + blockManager, + setCitationBlockId +}) + +// Step 4: 发起实际请求 +await fetchChatCompletion({ + messages: modelMessages, + assistant, + topicId, + requestOptions, + uiMessages, + onChunkReceived +}) +``` + +##### 3.1.2 `fetchChatCompletion()` + +**关键代码分析**: + +```typescript +export async function fetchChatCompletion({ + messages, + assistant, + requestOptions, + onChunkReceived, + topicId, + uiMessages +}: FetchChatCompletionParams) { + + // 1. Provider 准备 + API Key 轮换 + const baseProvider = getProviderByModel(assistant.model || getDefaultModel()) + const providerWithRotatedKey = { + ...baseProvider, + apiKey: getRotatedApiKey(baseProvider) // ✅ 多 key 负载均衡 + } + + // 2. 创建 AI Provider 实例 + const AI = new ModernAiProvider( + assistant.model || getDefaultModel(), + providerWithRotatedKey + ) + + // 3. 获取 MCP 工具 + const mcpTools: MCPTool[] = [] + if (isPromptToolUse(assistant) || isSupportedToolUse(assistant)) { + mcpTools.push(...(await fetchMcpTools(assistant))) + } + + // 4. 构建 AI SDK 参数 + const { + params: aiSdkParams, + modelId, + capabilities, + webSearchPluginConfig + } = await buildStreamTextParams(messages, assistant, provider, { + mcpTools, + webSearchProviderId: assistant.webSearchProviderId, + requestOptions + }) + + // 5. 构建中间件配置 + const middlewareConfig: AiSdkMiddlewareConfig = { + streamOutput: assistant.settings?.streamOutput ?? true, + onChunk: onChunkReceived, + model: assistant.model, + enableReasoning: capabilities.enableReasoning, + isPromptToolUse: usePromptToolUse, + isSupportedToolUse: isSupportedToolUse(assistant), + isImageGenerationEndpoint: isDedicatedImageGenerationModel(assistant.model), + webSearchPluginConfig, + enableWebSearch: capabilities.enableWebSearch, + enableGenerateImage: capabilities.enableGenerateImage, + enableUrlContext: capabilities.enableUrlContext, + mcpTools, + uiMessages, + knowledgeRecognition: assistant.knowledgeRecognition + } + + // 6. 调用 AI.completions() + await AI.completions(modelId, aiSdkParams, { + ...middlewareConfig, + assistant, + topicId, + callType: 'chat', + uiMessages + }) +} +``` + +**API Key 轮换机制**: + +```typescript +function getRotatedApiKey(provider: Provider): string { + const keys = provider.apiKey.split(',').map(k => k.trim()).filter(Boolean) + + if (keys.length === 1) return keys[0] + + const keyName = `provider:${provider.id}:last_used_key` + const lastUsedKey = window.keyv.get(keyName) + + const currentIndex = keys.indexOf(lastUsedKey) + const nextIndex = (currentIndex + 1) % keys.length + const nextKey = keys[nextIndex] + + window.keyv.set(keyName, nextKey) + return nextKey +} + +// 使用场景: +// provider.apiKey = "sk-key1,sk-key2,sk-key3" +// 请求 1 → 使用 sk-key1 +// 请求 2 → 使用 sk-key2 +// 请求 3 → 使用 sk-key3 +// 请求 4 → 使用 sk-key1 (轮回) +``` + +### 3.2 ModernAiProvider Layer + +#### 文件位置 +`src/renderer/src/aiCore/index_new.ts` + +#### 核心职责 + +1. **Provider 配置转换** (Cherry Provider → AI SDK Config) +2. **插件构建** (根据 capabilities) +3. **Trace 集成** (OpenTelemetry) +4. **调用 RuntimeExecutor** +5. **流式数据适配** (AI SDK Stream → Cherry Chunk) + +#### 构造函数详解 + +```typescript +constructor(modelOrProvider: Model | Provider, provider?: Provider) { + if (this.isModel(modelOrProvider)) { + // 情况 1: new ModernAiProvider(model, provider) + this.model = modelOrProvider + this.actualProvider = provider + ? adaptProvider({ provider, model: modelOrProvider }) + : getActualProvider(modelOrProvider) + + // 同步或异步创建 config + const configOrPromise = providerToAiSdkConfig( + this.actualProvider, + modelOrProvider + ) + this.config = configOrPromise instanceof Promise + ? undefined + : configOrPromise + } else { + // 情况 2: new ModernAiProvider(provider) + this.actualProvider = adaptProvider({ provider: modelOrProvider }) + } + + this.legacyProvider = new LegacyAiProvider(this.actualProvider) +} +``` + +#### completions() 方法详解 + +```typescript +public async completions( + modelId: string, + params: StreamTextParams, + providerConfig: ModernAiProviderConfig +) { + // 1. 确保 config 已准备 + if (!this.config) { + this.config = await Promise.resolve( + providerToAiSdkConfig(this.actualProvider, this.model!) + ) + } + + // 2. Claude Code OAuth 特殊处理 + if (this.actualProvider.id === 'anthropic' && + this.actualProvider.authType === 'oauth') { + const claudeCodeSystemMessage = buildClaudeCodeSystemModelMessage( + params.system + ) + params.system = undefined + params.messages = [...claudeCodeSystemMessage, ...(params.messages || [])] + } + + // 3. 路由选择 + if (providerConfig.topicId && getEnableDeveloperMode()) { + return await this._completionsForTrace(modelId, params, { + ...providerConfig, + topicId: providerConfig.topicId + }) + } else { + return await this._completionsOrImageGeneration(modelId, params, providerConfig) + } +} +``` + +#### modernCompletions() 核心实现 + +```typescript +private async modernCompletions( + modelId: string, + params: StreamTextParams, + config: ModernAiProviderConfig +): Promise { + + // 1. 构建插件 + const plugins = buildPlugins(config) + + // 2. 创建 RuntimeExecutor + const executor = await createExecutor( + this.config!.providerId, + this.config!.providerSettings, + plugins + ) + + // 3. 流式调用 + if (config.onChunk) { + const accumulate = this.model!.supported_text_delta !== false + const adapter = new AiSdkToChunkAdapter( + config.onChunk, + config.mcpTools, + accumulate, + config.enableWebSearch + ) + + const streamResult = await executor.streamText({ + ...params, + model: modelId, + experimental_context: { onChunk: config.onChunk } + }) + + const finalText = await adapter.processStream(streamResult) + + return { getText: () => finalText } + } else { + // 非流式调用 + const streamResult = await executor.streamText({ + ...params, + model: modelId + }) + + await streamResult?.consumeStream() + const finalText = await streamResult.text + + return { getText: () => finalText } + } +} +``` + +#### Trace 集成详解 + +```typescript +private async _completionsForTrace( + modelId: string, + params: StreamTextParams, + config: ModernAiProviderConfig & { topicId: string } +): Promise { + + const traceName = `${this.actualProvider.name}.${modelId}.${config.callType}` + + // 1. 创建 OpenTelemetry Span + const span = addSpan({ + name: traceName, + tag: 'LLM', + topicId: config.topicId, + modelName: config.assistant.model?.name, + inputs: params + }) + + if (!span) { + return await this._completionsOrImageGeneration(modelId, params, config) + } + + try { + // 2. 在 span 上下文中执行 + const result = await this._completionsOrImageGeneration(modelId, params, config) + + // 3. 标记 span 成功 + endSpan({ + topicId: config.topicId, + outputs: result, + span, + modelName: modelId + }) + + return result + } catch (error) { + // 4. 标记 span 失败 + endSpan({ + topicId: config.topicId, + error: error as Error, + span, + modelName: modelId + }) + throw error + } +} +``` + +--- + +## 4. Provider 系统架构 + +### 4.1 Provider 配置转换 + +#### providerToAiSdkConfig() 详解 + +**文件**: `src/renderer/src/aiCore/provider/providerConfig.ts` + +```typescript +export function providerToAiSdkConfig( + provider: Provider, + model?: Model +): ProviderConfig | Promise { + + // 1. 根据 provider.id 路由到具体实现 + switch (provider.id) { + case 'openai': + return { + providerId: 'openai', + providerSettings: { + apiKey: provider.apiKey, + baseURL: provider.apiHost, + organization: provider.apiOrganization, + headers: provider.apiHeaders + } + } + + case 'anthropic': + return { + providerId: 'anthropic', + providerSettings: { + apiKey: provider.apiKey, + baseURL: provider.apiHost + } + } + + case 'openai-compatible': + return { + providerId: 'openai-compatible', + providerSettings: { + baseURL: provider.apiHost, + apiKey: provider.apiKey, + name: provider.name + } + } + + case 'gateway': + // 特殊处理:gateway 需要异步创建 + return createGatewayConfig(provider, model) + + // ... 其他 providers + } +} +``` + +#### Gateway Provider 特殊处理 + +```typescript +async function createGatewayConfig( + provider: Provider, + model?: Model +): Promise { + + // 1. 从 gateway 获取模型列表 + const gatewayModels = await fetchGatewayModels(provider) + + // 2. 标准化模型格式 + const normalizedModels = normalizeGatewayModels(gatewayModels) + + // 3. 使用 AI SDK 的 gateway() 函数 + const gatewayProvider = gateway({ + provider: { + languageModel: (modelId) => { + const targetModel = normalizedModels.find(m => m.id === modelId) + if (!targetModel) { + throw new Error(`Model ${modelId} not found in gateway`) + } + // 动态创建对应的 provider + return createLanguageModel(targetModel) + } + } + }) + + return { + providerId: 'gateway', + provider: gatewayProvider + } +} +``` + +### 4.2 Provider Extension 系统 + +**文件**: `packages/aiCore/src/core/providers/core/ProviderExtension.ts` + +#### 核心设计 + +```typescript +export class ProviderExtension< + TSettings = any, + TStorage extends ExtensionStorage = ExtensionStorage, + TProvider extends ProviderV3 = ProviderV3, + TConfig extends ProviderExtensionConfig = + ProviderExtensionConfig +> { + + // 1. LRU 缓存(settings hash → provider 实例) + private instances: LRUCache + + constructor(public readonly config: TConfig) { + this.instances = new LRUCache({ + max: 10, // 最多缓存 10 个实例 + updateAgeOnGet: true // LRU 行为 + }) + } + + // 2. 创建 provider(带缓存) + async createProvider( + settings?: TSettings, + variantSuffix?: string + ): Promise { + + // 2.1 合并默认配置 + const mergedSettings = this.mergeSettings(settings) + + // 2.2 计算 hash(包含 variantSuffix) + const hash = this.computeHash(mergedSettings, variantSuffix) + + // 2.3 LRU 缓存查找 + const cachedInstance = this.instances.get(hash) + if (cachedInstance) { + return cachedInstance + } + + // 2.4 缓存未命中,创建新实例 + const provider = await this.factory(mergedSettings, variantSuffix) + + // 2.5 执行生命周期钩子 + await this.lifecycle.onCreate?.(provider, mergedSettings) + + // 2.6 存入 LRU 缓存 + this.instances.set(hash, provider) + + return provider + } + + // 3. Hash 计算(保证相同配置得到相同 hash) + private computeHash(settings?: TSettings, variantSuffix?: string): string { + const baseHash = (() => { + if (settings === undefined || settings === null) { + return 'default' + } + + // 稳定序列化(对象键排序) + const stableStringify = (obj: any): string => { + if (obj === null || obj === undefined) return 'null' + if (typeof obj !== 'object') return JSON.stringify(obj) + if (Array.isArray(obj)) return `[${obj.map(stableStringify).join(',')}]` + + const keys = Object.keys(obj).sort() + const pairs = keys.map(key => + `${JSON.stringify(key)}:${stableStringify(obj[key])}` + ) + return `{${pairs.join(',')}}` + } + + const serialized = stableStringify(settings) + + // 简单哈希函数 + let hash = 0 + for (let i = 0; i < serialized.length; i++) { + const char = serialized.charCodeAt(i) + hash = (hash << 5) - hash + char + hash = hash & hash + } + + return `${Math.abs(hash).toString(36)}` + })() + + // 附加 variantSuffix + return variantSuffix ? `${baseHash}:${variantSuffix}` : baseHash + } +} +``` + +#### OpenAI Extension 示例 + +```typescript +// packages/aiCore/src/core/providers/extensions/openai.ts + +export const OpenAIExtension = new ProviderExtension({ + name: 'openai', + aliases: ['oai'], + variants: [ + { + suffix: 'chat', + name: 'OpenAI Chat', + transform: (baseProvider, settings) => { + return customProvider({ + fallbackProvider: { + ...baseProvider, + languageModel: (modelId) => baseProvider.chat(modelId) + } + }) + } + } + ], + + // Factory 函数 + create: async (settings: OpenAIProviderSettings) => { + return createOpenAI({ + apiKey: settings.apiKey, + baseURL: settings.baseURL, + organization: settings.organization, + headers: settings.headers + }) + }, + + // 默认配置 + defaultSettings: { + baseURL: 'https://api.openai.com/v1' + }, + + // 生命周期钩子 + lifecycle: { + onCreate: async (provider, settings) => { + console.log(`OpenAI provider created with baseURL: ${settings.baseURL}`) + } + } +}) +``` + +### 4.3 Extension Registry + +**文件**: `packages/aiCore/src/core/providers/core/ExtensionRegistry.ts` + +```typescript +export class ExtensionRegistry { + private extensions: Map> = new Map() + private aliasMap: Map = new Map() + + // 1. 注册 extension + register(extension: ProviderExtension): this { + const { name, aliases, variants } = extension.config + + // 注册主 ID + this.extensions.set(name, extension) + + // 注册别名 + if (aliases) { + for (const alias of aliases) { + this.aliasMap.set(alias, name) + } + } + + // 注册变体 ID + if (variants) { + for (const variant of variants) { + const variantId = `${name}-${variant.suffix}` + this.aliasMap.set(variantId, name) + } + } + + return this + } + + // 2. 创建 provider(类型安全) + async createProvider( + id: T, + settings: CoreProviderSettingsMap[T] + ): Promise + + async createProvider(id: string, settings?: any): Promise + + async createProvider(id: string, settings?: any): Promise { + // 2.1 解析 ID(支持别名和变体) + const parsed = this.parseProviderId(id) + if (!parsed) { + throw new Error(`Provider extension "${id}" not found`) + } + + const { baseId, mode: variantSuffix } = parsed + + // 2.2 获取 extension + const extension = this.get(baseId) + if (!extension) { + throw new Error(`Provider extension "${baseId}" not found`) + } + + // 2.3 委托给 extension 创建 + try { + return await extension.createProvider(settings, variantSuffix) + } catch (error) { + throw new ProviderCreationError( + `Failed to create provider "${id}"`, + id, + error instanceof Error ? error : new Error(String(error)) + ) + } + } + + // 3. 解析 providerId + parseProviderId(providerId: string): { + baseId: RegisteredProviderId + mode?: string + isVariant: boolean + } | null { + + // 3.1 检查是否是基础 ID 或别名 + const extension = this.get(providerId) + if (extension) { + return { + baseId: extension.config.name as RegisteredProviderId, + isVariant: false + } + } + + // 3.2 查找变体 + for (const ext of this.extensions.values()) { + if (!ext.config.variants) continue + + for (const variant of ext.config.variants) { + const variantId = `${ext.config.name}-${variant.suffix}` + if (variantId === providerId) { + return { + baseId: ext.config.name as RegisteredProviderId, + mode: variant.suffix, + isVariant: true + } + } + } + } + + return null + } +} + +// 全局单例 +export const extensionRegistry = new ExtensionRegistry() +``` + +--- + +## 5. 插件与中间件系统 + +### 5.1 插件架构 + +#### AiPlugin 接口定义 + +**文件**: `packages/aiCore/src/core/plugins/types.ts` + +```typescript +export interface AiPlugin { + /** 插件名称 */ + name: string + + /** 请求开始前 */ + onRequestStart?: (context: PluginContext) => void | Promise + + /** 转换参数(链式调用) */ + transformParams?: ( + params: any, + context: PluginContext + ) => any | Promise + + /** 转换结果 */ + transformResult?: ( + result: any, + context: PluginContext + ) => any | Promise + + /** 请求结束后 */ + onRequestEnd?: (context: PluginContext) => void | Promise + + /** 错误处理 */ + onError?: ( + error: Error, + context: PluginContext + ) => void | Promise +} + +export interface PluginContext { + providerId: string + model?: string + messages?: any[] + tools?: any + // experimental_context 中的自定义数据 + [key: string]: any +} +``` + +#### PluginEngine 实现 + +**文件**: `packages/aiCore/src/core/plugins/PluginEngine.ts` + +```typescript +export class PluginEngine { + constructor( + private providerId: string, + private plugins: AiPlugin[] + ) {} + + // 1. 执行 onRequestStart + async executeOnRequestStart(params: any): Promise { + const context = this.createContext(params) + + for (const plugin of this.plugins) { + if (plugin.onRequestStart) { + await plugin.onRequestStart(context) + } + } + } + + // 2. 链式执行 transformParams + async executeTransformParams(params: any): Promise { + let transformedParams = params + const context = this.createContext(params) + + for (const plugin of this.plugins) { + if (plugin.transformParams) { + transformedParams = await plugin.transformParams( + transformedParams, + context + ) + } + } + + return transformedParams + } + + // 3. 执行 transformResult + async executeTransformResult(result: any, params: any): Promise { + let transformedResult = result + const context = this.createContext(params) + + // 反向执行 + for (let i = this.plugins.length - 1; i >= 0; i--) { + const plugin = this.plugins[i] + if (plugin.transformResult) { + transformedResult = await plugin.transformResult( + transformedResult, + context + ) + } + } + + return transformedResult + } + + // 4. 执行 onRequestEnd + async executeOnRequestEnd(params: any): Promise { + const context = this.createContext(params) + + // 反向执行 + for (let i = this.plugins.length - 1; i >= 0; i--) { + const plugin = this.plugins[i] + if (plugin.onRequestEnd) { + await plugin.onRequestEnd(context) + } + } + } + + // 5. 执行 onError + async executeOnError(error: Error, params: any): Promise { + const context = this.createContext(params) + + for (const plugin of this.plugins) { + if (plugin.onError) { + try { + await plugin.onError(error, context) + } catch (pluginError) { + console.error(`Error in plugin ${plugin.name}:`, pluginError) + } + } + } + } + + private createContext(params: any): PluginContext { + return { + providerId: this.providerId, + model: params.model, + messages: params.messages, + tools: params.tools, + ...params.experimental_context + } + } +} +``` + +### 5.2 内置插件 + +#### 5.2.1 ReasoningPlugin + +**文件**: `src/renderer/src/aiCore/plugins/ReasoningPlugin.ts` + +```typescript +export const ReasoningPlugin: AiPlugin = { + name: 'ReasoningPlugin', + + transformParams: async (params, context) => { + if (!context.enableReasoning) { + return params + } + + // 根据模型类型添加 reasoning 配置 + if (context.model?.includes('o1') || context.model?.includes('o3')) { + // OpenAI o1/o3 系列 + return { + ...params, + reasoning_effort: context.reasoningEffort || 'medium' + } + } else if (context.model?.includes('claude')) { + // Anthropic Claude 系列 + return { + ...params, + thinking: { + type: 'enabled', + budget_tokens: context.thinkingBudget || 2000 + } + } + } else if (context.model?.includes('qwen')) { + // Qwen 系列 + return { + ...params, + experimental_providerMetadata: { + qwen: { think_mode: true } + } + } + } + + return params + } +} +``` + +#### 5.2.2 ToolUsePlugin + +**文件**: `src/renderer/src/aiCore/plugins/ToolUsePlugin.ts` + +```typescript +export const ToolUsePlugin: AiPlugin = { + name: 'ToolUsePlugin', + + transformParams: async (params, context) => { + if (!context.isSupportedToolUse && !context.isPromptToolUse) { + return params + } + + // 1. 收集所有工具 + const tools: Record = {} + + // 1.1 MCP 工具 + if (context.mcpTools && context.mcpTools.length > 0) { + for (const mcpTool of context.mcpTools) { + tools[mcpTool.name] = convertMcpToolToCoreTool(mcpTool) + } + } + + // 1.2 内置工具(WebSearch, GenerateImage, etc.) + if (context.enableWebSearch) { + tools['web_search'] = webSearchTool + } + + if (context.enableGenerateImage) { + tools['generate_image'] = generateImageTool + } + + // 2. Prompt Tool Use 模式特殊处理 + if (context.isPromptToolUse) { + return { + ...params, + messages: injectToolsIntoPrompt(params.messages, tools) + } + } + + // 3. 标准 Function Calling 模式 + return { + ...params, + tools, + toolChoice: 'auto' + } + } +} +``` + +#### 5.2.3 WebSearchPlugin + +**文件**: `src/renderer/src/aiCore/plugins/WebSearchPlugin.ts` + +```typescript +export const WebSearchPlugin: AiPlugin = { + name: 'WebSearchPlugin', + + transformParams: async (params, context) => { + if (!context.enableWebSearch) { + return params + } + + // 添加 web search 工具 + const webSearchTool = { + type: 'function' as const, + function: { + name: 'web_search', + description: 'Search the web for current information', + parameters: { + type: 'object', + properties: { + query: { + type: 'string', + description: 'Search query' + } + }, + required: ['query'] + } + }, + execute: async ({ query }: { query: string }) => { + return await executeWebSearch(query, context.webSearchProviderId) + } + } + + return { + ...params, + tools: { + ...params.tools, + web_search: webSearchTool + } + } + } +} +``` + +### 5.3 插件构建器 + +**文件**: `src/renderer/src/aiCore/plugins/PluginBuilder.ts` + +```typescript +export function buildPlugins(config: AiSdkMiddlewareConfig): AiPlugin[] { + const plugins: AiPlugin[] = [] + + // 1. Reasoning Plugin + if (config.enableReasoning) { + plugins.push(ReasoningPlugin) + } + + // 2. Tool Use Plugin + if (config.isSupportedToolUse || config.isPromptToolUse) { + plugins.push(ToolUsePlugin) + } + + // 3. Web Search Plugin + if (config.enableWebSearch) { + plugins.push(WebSearchPlugin) + } + + // 4. Image Generation Plugin + if (config.enableGenerateImage) { + plugins.push(ImageGenerationPlugin) + } + + // 5. URL Context Plugin + if (config.enableUrlContext) { + plugins.push(UrlContextPlugin) + } + + return plugins +} +``` + +--- + +## 6. 消息处理流程 + +### 6.1 消息转换 + +**文件**: `src/renderer/src/services/ConversationService.ts` + +```typescript +export class ConversationService { + + /** + * 准备消息用于 LLM 调用 + * + * @returns { + * modelMessages: AI SDK 格式的消息 + * uiMessages: 原始 UI 消息(用于特殊场景) + * } + */ + static async prepareMessagesForModel( + messages: Message[], + assistant: Assistant + ): Promise<{ + modelMessages: CoreMessage[] + uiMessages: Message[] + }> { + + // 1. 过滤消息 + let filteredMessages = messages + .filter(m => !m.isDeleted) + .filter(m => m.role !== 'system') + + // 2. 应用上下文窗口限制 + const contextLimit = assistant.settings?.contextLimit || 10 + if (contextLimit > 0) { + filteredMessages = takeRight(filteredMessages, contextLimit) + } + + // 3. 转换为 AI SDK 格式 + const modelMessages: CoreMessage[] = [] + + for (const msg of filteredMessages) { + const converted = await this.convertMessageToAiSdk(msg, assistant) + if (converted) { + modelMessages.push(converted) + } + } + + // 4. 添加 system message + if (assistant.prompt) { + modelMessages.unshift({ + role: 'system', + content: assistant.prompt + }) + } + + return { + modelMessages, + uiMessages: filteredMessages + } + } + + /** + * 转换单条消息 + */ + static async convertMessageToAiSdk( + message: Message, + assistant: Assistant + ): Promise { + + switch (message.role) { + case 'user': + return await this.convertUserMessage(message) + + case 'assistant': + return await this.convertAssistantMessage(message) + + case 'tool': + return { + role: 'tool', + content: message.content, + toolCallId: message.toolCallId + } + + default: + return null + } + } + + /** + * 转换用户消息(处理多模态内容) + */ + static async convertUserMessage(message: Message): Promise { + const parts: Array = [] + + // 1. 处理文本内容 + const textContent = getMainTextContent(message) + if (textContent) { + parts.push({ + type: 'text', + text: textContent + }) + } + + // 2. 处理图片 + const imageBlocks = findImageBlocks(message) + for (const block of imageBlocks) { + const imageData = await this.loadImageData(block.image.url) + parts.push({ + type: 'image', + image: imageData + }) + } + + // 3. 处理文件 + const fileBlocks = findFileBlocks(message) + for (const block of fileBlocks) { + const fileData = await this.loadFileData(block.file) + parts.push({ + type: 'file', + data: fileData, + mimeType: block.file.mime_type + }) + } + + return { + role: 'user', + content: parts + } + } + + /** + * 转换助手消息(处理工具调用) + */ + static async convertAssistantMessage( + message: Message + ): Promise { + + const parts: Array = [] + + // 1. 处理文本内容 + const textContent = getMainTextContent(message) + if (textContent) { + parts.push({ + type: 'text', + text: textContent + }) + } + + // 2. 处理工具调用 + const toolCallBlocks = findToolCallBlocks(message) + for (const block of toolCallBlocks) { + parts.push({ + type: 'tool-call', + toolCallId: block.toolCallId, + toolName: block.toolName, + args: block.args + }) + } + + return { + role: 'assistant', + content: parts + } + } +} +``` + +### 6.2 流式数据适配 + +**文件**: `src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts` + +```typescript +export default class AiSdkToChunkAdapter { + + constructor( + private onChunk: (chunk: Chunk) => void, + private mcpTools?: MCPTool[], + private accumulate: boolean = true, + private enableWebSearch: boolean = false + ) {} + + /** + * 处理 AI SDK 流式结果 + */ + async processStream(streamResult: StreamTextResult): Promise { + const startTime = Date.now() + let fullText = '' + let firstTokenTime = 0 + + try { + // 1. 监听 textStream + for await (const textDelta of streamResult.textStream) { + if (!firstTokenTime) { + firstTokenTime = Date.now() + } + + if (this.accumulate) { + fullText += textDelta + + // 发送文本增量 chunk + this.onChunk({ + type: ChunkType.TEXT_DELTA, + text: textDelta + }) + } else { + // 不累积,直接发送完整文本 + this.onChunk({ + type: ChunkType.TEXT, + text: textDelta + }) + } + } + + // 2. 处理工具调用 + const toolCalls = streamResult.toolCalls + if (toolCalls && toolCalls.length > 0) { + for (const toolCall of toolCalls) { + await this.handleToolCall(toolCall) + } + } + + // 3. 处理 reasoning/thinking + const reasoning = streamResult.experimental_providerMetadata?.reasoning + if (reasoning) { + this.onChunk({ + type: ChunkType.REASONING, + content: reasoning + }) + } + + // 4. 发送完成 chunk + const usage = await streamResult.usage + const finishReason = await streamResult.finishReason + + this.onChunk({ + type: ChunkType.BLOCK_COMPLETE, + response: { + usage: { + prompt_tokens: usage.promptTokens, + completion_tokens: usage.completionTokens, + total_tokens: usage.totalTokens + }, + metrics: { + completion_tokens: usage.completionTokens, + time_first_token_millsec: firstTokenTime - startTime, + time_completion_millsec: Date.now() - startTime + }, + finish_reason: finishReason + } + }) + + this.onChunk({ + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage: { + prompt_tokens: usage.promptTokens, + completion_tokens: usage.completionTokens, + total_tokens: usage.totalTokens + } + } + }) + + return fullText + + } catch (error) { + this.onChunk({ + type: ChunkType.ERROR, + error: error as Error + }) + throw error + } + } + + /** + * 处理工具调用 + */ + private async handleToolCall(toolCall: ToolCall): Promise { + // 1. 发送工具调用开始 chunk + this.onChunk({ + type: ChunkType.TOOL_CALL, + toolCall: { + id: toolCall.toolCallId, + name: toolCall.toolName, + arguments: toolCall.args + } + }) + + // 2. 查找工具定义 + const mcpTool = this.mcpTools?.find(t => t.name === toolCall.toolName) + + // 3. 执行工具 + try { + let result: any + + if (mcpTool) { + // MCP 工具 + result = await window.api.mcp.callTool( + mcpTool.serverName, + toolCall.toolName, + toolCall.args + ) + } else if (toolCall.toolName === 'web_search' && this.enableWebSearch) { + // Web Search 工具 + result = await executeWebSearch(toolCall.args.query) + } else { + result = { error: `Unknown tool: ${toolCall.toolName}` } + } + + // 4. 发送工具结果 chunk + this.onChunk({ + type: ChunkType.TOOL_RESULT, + toolResult: { + toolCallId: toolCall.toolCallId, + toolName: toolCall.toolName, + result + } + }) + + } catch (error) { + // 5. 发送工具错误 chunk + this.onChunk({ + type: ChunkType.TOOL_ERROR, + toolError: { + toolCallId: toolCall.toolCallId, + toolName: toolCall.toolName, + error: error as Error + } + }) + } + } +} +``` + +--- + +## 7. 类型安全机制 + +### 7.1 Provider Settings 类型映射 + +**文件**: `packages/aiCore/src/core/providers/types/index.ts` + +```typescript +/** + * Core Provider Settings Map + * 自动从 Extension 提取类型 + */ +export type CoreProviderSettingsMap = UnionToIntersection< + ExtensionToSettingsMap<(typeof coreExtensions)[number]> +> + +/** + * 结果类型(示例): + * { + * openai: OpenAIProviderSettings + * 'openai-chat': OpenAIProviderSettings + * anthropic: AnthropicProviderSettings + * google: GoogleProviderSettings + * ... + * } + */ +``` + +### 7.2 类型安全的 createExecutor + +```typescript +// 1. 已知 provider(类型安全) +const executor = await createExecutor('openai', { + apiKey: 'sk-xxx', // ✅ 类型推断为 string + baseURL: 'https://...' // ✅ 类型推断为 string | undefined + // wrongField: 123 // ❌ 编译错误:不存在的字段 +}) + +// 2. 动态 provider(any) +const executor = await createExecutor('custom-provider', { + anyField: 'value' // ✅ any 类型 +}) +``` + +### 7.3 Extension Registry 类型安全 + +```typescript +export class ExtensionRegistry { + + // 类型安全的函数重载 + async createProvider< + T extends RegisteredProviderId & keyof CoreProviderSettingsMap + >( + id: T, + settings: CoreProviderSettingsMap[T] + ): Promise + + async createProvider( + id: string, + settings?: any + ): Promise + + async createProvider(id: string, settings?: any): Promise { + // 实现 + } +} + +// 使用: +const provider = await extensionRegistry.createProvider('openai', { + apiKey: 'sk-xxx', // ✅ 类型检查 + baseURL: 'https://...' +}) +``` + +--- + +## 8. Trace 和可观测性 + +### 8.1 OpenTelemetry 集成 + +#### Span 创建 + +**文件**: `src/renderer/src/services/SpanManagerService.ts` + +```typescript +export function addSpan(params: StartSpanParams): Span | null { + const { name, tag, topicId, modelName, inputs } = params + + // 1. 获取或创建 tracer + const tracer = getTracer(topicId) + if (!tracer) return null + + // 2. 创建 span + const span = tracer.startSpan(name, { + kind: SpanKind.CLIENT, + attributes: { + 'llm.tag': tag, + 'llm.model': modelName, + 'llm.topic_id': topicId, + 'llm.input_messages': JSON.stringify(inputs.messages), + 'llm.temperature': inputs.temperature, + 'llm.max_tokens': inputs.maxTokens + } + }) + + // 3. 设置 span context 为 active + context.with(trace.setSpan(context.active(), span), () => { + // 后续的 AI SDK 调用会自动继承这个 span + }) + + return span +} +``` + +#### Span 结束 + +```typescript +export function endSpan(params: EndSpanParams): void { + const { topicId, span, outputs, error, modelName } = params + + if (outputs) { + // 成功情况 + span.setAttributes({ + 'llm.output_text': outputs.getText(), + 'llm.finish_reason': outputs.finishReason, + 'llm.usage.prompt_tokens': outputs.usage.promptTokens, + 'llm.usage.completion_tokens': outputs.usage.completionTokens + }) + span.setStatus({ code: SpanStatusCode.OK }) + } else if (error) { + // 错误情况 + span.recordException(error) + span.setStatus({ + code: SpanStatusCode.ERROR, + message: error.message + }) + } + + span.end() +} +``` + +### 8.2 Trace 层级结构 + +``` +Parent Span: fetchChatCompletion +│ +├─ Child Span: prepareMessagesForModel +│ └─ attributes: message_count, filters_applied +│ +├─ Child Span: buildStreamTextParams +│ └─ attributes: tools_count, web_search_enabled +│ +├─ Child Span: AI.completions (创建于 _completionsForTrace) +│ │ +│ ├─ Child Span: buildPlugins +│ │ └─ attributes: plugin_names +│ │ +│ ├─ Child Span: createExecutor +│ │ └─ attributes: provider_id, cache_hit +│ │ +│ └─ Child Span: executor.streamText +│ │ +│ ├─ Child Span: AI SDK doStream (自动创建) +│ │ └─ attributes: model, temperature, tokens +│ │ +│ └─ Child Span: Tool Execution (如果有工具调用) +│ ├─ attributes: tool_name, args +│ └─ attributes: result, latency +│ +└─ attributes: total_duration, final_token_count +``` + +### 8.3 Trace 导出 + +```typescript +// 配置 OTLP Exporter +const exporter = new OTLPTraceExporter({ + url: 'http://localhost:4318/v1/traces', + headers: { + 'Authorization': 'Bearer xxx' + } +}) + +// 配置 Trace Provider +const provider = new WebTracerProvider({ + resource: new Resource({ + 'service.name': 'cherry-studio', + 'service.version': app.getVersion() + }) +}) + +provider.addSpanProcessor( + new BatchSpanProcessor(exporter, { + maxQueueSize: 100, + maxExportBatchSize: 10, + scheduledDelayMillis: 500 + }) +) + +provider.register() +``` + +--- + +## 9. 错误处理机制 + +### 9.1 错误类型层级 + +```typescript +// 1. Base Error +export class ProviderError extends Error { + constructor( + message: string, + public providerId: string, + public code?: string, + public cause?: Error + ) { + super(message) + this.name = 'ProviderError' + } +} + +// 2. Provider Creation Error +export class ProviderCreationError extends ProviderError { + constructor(message: string, providerId: string, cause: Error) { + super(message, providerId, 'PROVIDER_CREATION_FAILED', cause) + this.name = 'ProviderCreationError' + } +} + +// 3. Model Resolution Error +export class ModelResolutionError extends ProviderError { + constructor( + message: string, + public modelId: string, + providerId: string + ) { + super(message, providerId, 'MODEL_RESOLUTION_FAILED') + this.name = 'ModelResolutionError' + } +} + +// 4. API Error +export class ApiError extends ProviderError { + constructor( + message: string, + providerId: string, + public statusCode?: number, + public response?: any + ) { + super(message, providerId, 'API_REQUEST_FAILED') + this.name = 'ApiError' + } +} +``` + +### 9.2 错误传播 + +``` +RuntimeExecutor.streamText() + │ + ├─ try { + │ await pluginEngine.executeOnRequestStart() + │ } catch (error) { + │ await pluginEngine.executeOnError(error) + │ throw error + │ } + │ + ├─ try { + │ params = await pluginEngine.executeTransformParams(params) + │ } catch (error) { + │ await pluginEngine.executeOnError(error) + │ throw error + │ } + │ + └─ try { + const result = await aiSdk.streamText(...) + return result + } catch (error) { + await pluginEngine.executeOnError(error) + + // 转换 AI SDK 错误为统一格式 + if (isAiSdkError(error)) { + throw new ApiError( + error.message, + this.config.providerId, + error.statusCode, + error.response + ) + } + + throw error + } +``` + +### 9.3 用户友好的错误处理 + +**文件**: `src/renderer/src/services/ApiService.ts` + +```typescript +try { + await fetchChatCompletion({...}) +} catch (error: any) { + + // 1. API Key 错误 + if (error.statusCode === 401) { + onChunkReceived({ + type: ChunkType.ERROR, + error: { + message: i18n.t('error.invalid_api_key'), + code: 'INVALID_API_KEY' + } + }) + return + } + + // 2. Rate Limit + if (error.statusCode === 429) { + onChunkReceived({ + type: ChunkType.ERROR, + error: { + message: i18n.t('error.rate_limit'), + code: 'RATE_LIMIT', + retryAfter: error.response?.headers['retry-after'] + } + }) + return + } + + // 3. Abort + if (isAbortError(error)) { + onChunkReceived({ + type: ChunkType.ERROR, + error: { + message: i18n.t('error.request_aborted'), + code: 'ABORTED' + } + }) + return + } + + // 4. 通用错误 + onChunkReceived({ + type: ChunkType.ERROR, + error: { + message: error.message || i18n.t('error.unknown'), + code: error.code || 'UNKNOWN_ERROR', + details: getEnableDeveloperMode() ? error.stack : undefined + } + }) +} +``` + +--- + +## 10. 性能优化 + +### 10.1 Provider 实例缓存(LRU) + +**优势**: +- ✅ 避免重复创建相同配置的 provider +- ✅ 自动清理最久未使用的实例 +- ✅ 内存可控(max: 10 per extension) + +**性能指标**: +``` +Cache Hit: <1ms (直接从 Map 获取) +Cache Miss: ~50ms (创建新 AI SDK provider) +``` + +### 10.2 并行请求优化 + +```typescript +// ❌ 串行执行(慢) +const mcpTools = await fetchMcpTools(assistant) +const params = await buildStreamTextParams(...) +const plugins = buildPlugins(config) + +// ✅ 并行执行(快) +const [mcpTools, params, plugins] = await Promise.all([ + fetchMcpTools(assistant), + buildStreamTextParams(...), + Promise.resolve(buildPlugins(config)) +]) +``` + +### 10.3 流式响应优化 + +```typescript +// 1. 使用 textStream 而非 fullStream +for await (const textDelta of streamResult.textStream) { + onChunk({ type: ChunkType.TEXT_DELTA, text: textDelta }) +} + +// 2. 批量发送 chunks(减少 IPC 开销) +const chunkBuffer: Chunk[] = [] +for await (const textDelta of streamResult.textStream) { + chunkBuffer.push({ type: ChunkType.TEXT_DELTA, text: textDelta }) + + if (chunkBuffer.length >= 10) { + onChunk({ type: ChunkType.BATCH, chunks: chunkBuffer }) + chunkBuffer.length = 0 + } +} +``` + +### 10.4 内存优化 + +```typescript +// 1. 及时清理大对象 +async processStream(streamResult: StreamTextResult) { + try { + for await (const delta of streamResult.textStream) { + // 处理 delta + } + } finally { + // 确保流被消费完毕 + await streamResult.consumeStream() + } +} + +// 2. LRU 缓存自动淘汰 +// 当缓存达到 max: 10 时,最久未使用的实例会被自动移除 +``` + +--- + +## 附录 A: 关键文件索引 + +### Service Layer +- `src/renderer/src/services/ApiService.ts` - 主要 API 服务 +- `src/renderer/src/services/ConversationService.ts` - 消息准备 +- `src/renderer/src/services/SpanManagerService.ts` - Trace 管理 + +### AI Provider Layer +- `src/renderer/src/aiCore/index_new.ts` - ModernAiProvider +- `src/renderer/src/aiCore/provider/providerConfig.ts` - Provider 配置 +- `src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts` - 流式适配 +- `src/renderer/src/aiCore/plugins/PluginBuilder.ts` - 插件构建 + +### Core Package +- `packages/aiCore/src/core/runtime/executor.ts` - RuntimeExecutor +- `packages/aiCore/src/core/runtime/index.ts` - createExecutor +- `packages/aiCore/src/core/providers/core/ProviderExtension.ts` - Extension 基类 +- `packages/aiCore/src/core/providers/core/ExtensionRegistry.ts` - 注册表 +- `packages/aiCore/src/core/models/ModelResolver.ts` - 模型解析 +- `packages/aiCore/src/core/plugins/PluginEngine.ts` - 插件引擎 + +### Extensions +- `packages/aiCore/src/core/providers/extensions/openai.ts` - OpenAI Extension +- `packages/aiCore/src/core/providers/extensions/anthropic.ts` - Anthropic Extension +- `packages/aiCore/src/core/providers/extensions/google.ts` - Google Extension + +--- + +## 附录 B: 常见问题 + +### Q1: 为什么要用 LRU 缓存? +**A**: 避免为相同配置重复创建 provider,同时自动控制内存(最多 10 个实例/extension)。 + +### Q2: Plugin 和 Middleware 有什么区别? +**A**: +- **Plugin**: Cherry Studio 层面的功能扩展(Reasoning, ToolUse, WebSearch) +- **Middleware**: AI SDK 层面的请求/响应拦截器 + +### Q3: 什么时候用 Legacy Provider? +**A**: 仅在图像生成端点且非 gateway 时使用,因为需要图片编辑等高级功能。 + +### Q4: 如何添加新的 Provider? +**A**: +1. 在 `packages/aiCore/src/core/providers/extensions/` 创建 Extension +2. 注册到 `coreExtensions` 数组 +3. 在 `providerConfig.ts` 添加配置转换逻辑 + +--- + +**文档版本**: v1.0 +**最后更新**: 2025-01-02 +**维护者**: Cherry Studio Team diff --git a/packages/aiCore/package.json b/packages/aiCore/package.json index 100a404693..ecac566b68 100644 --- a/packages/aiCore/package.json +++ b/packages/aiCore/package.json @@ -47,6 +47,7 @@ "@ai-sdk/provider": "^3.0.0", "@ai-sdk/provider-utils": "^4.0.0", "@ai-sdk/xai": "^3.0.0", + "lru-cache": "^11.2.4", "zod": "^4.1.5" }, "devDependencies": { diff --git a/packages/aiCore/src/__tests__/index.ts b/packages/aiCore/src/__tests__/index.ts deleted file mode 100644 index afc498cad1..0000000000 --- a/packages/aiCore/src/__tests__/index.ts +++ /dev/null @@ -1,13 +0,0 @@ -/** - * Test Infrastructure Exports - * Central export point for all test utilities, fixtures, and helpers - */ - -// Fixtures -export * from './fixtures/mock-providers' -export * from './fixtures/mock-responses' - -// Helpers -export * from './helpers/model-test-utils' -export * from './helpers/provider-test-utils' -export * from './helpers/test-utils' diff --git a/packages/aiCore/src/core/README.MD b/packages/aiCore/src/core/README.MD deleted file mode 100644 index fc33fe18d5..0000000000 --- a/packages/aiCore/src/core/README.MD +++ /dev/null @@ -1,3 +0,0 @@ -# @cherryStudio-aiCore - -Core diff --git a/packages/aiCore/src/core/index.ts b/packages/aiCore/src/core/index.ts index 2346ea8cd2..beb7f92538 100644 --- a/packages/aiCore/src/core/index.ts +++ b/packages/aiCore/src/core/index.ts @@ -8,7 +8,7 @@ export type { NamedMiddleware } from './middleware' export { createMiddlewares, wrapModelWithMiddlewares } from './middleware' // 创建管理 -export { globalModelResolver, ModelResolver } from './models' +export { ModelResolver } from './models' export type { ModelConfig as ModelConfigType } from './models/types' // 执行管理 diff --git a/packages/aiCore/src/core/models/ModelResolver.ts b/packages/aiCore/src/core/models/ModelResolver.ts index 23534c4798..da9307d935 100644 --- a/packages/aiCore/src/core/models/ModelResolver.ts +++ b/packages/aiCore/src/core/models/ModelResolver.ts @@ -1,77 +1,56 @@ /** * 模型解析器 - models模块的核心 * 负责将modelId解析为AI SDK的LanguageModel实例 - * 支持传统格式和命名空间格式 - * 集成了来自 ModelCreator 的特殊处理逻辑 + * + * 支持两种格式: + * 1. 传统格式: 'gpt-4' (直接使用当前provider) + * 2. 命名空间格式: 'hub|provider|model' (HubProvider内部路由) */ -import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3, LanguageModelV3Middleware } from '@ai-sdk/provider' +import type { + EmbeddingModelV3, + ImageModelV3, + LanguageModelV3, + LanguageModelV3Middleware, + ProviderV3 +} from '@ai-sdk/provider' import { wrapModelWithMiddlewares } from '../middleware/wrapper' -import { globalProviderStorage } from '../providers/core/ProviderExtension' -import { DEFAULT_SEPARATOR } from '../providers/features/HubProvider' export class ModelResolver { + private provider: ProviderV3 + /** - * 从 globalProviderStorage 获取 provider - * @param providerId - Provider explicit ID - * @throws Error if provider not found + * 构造函数接受provider实例 + * Provider可以是普通provider或HubProvider */ - private getProvider(providerId: string) { - const provider = globalProviderStorage.get(providerId) - if (!provider) { - throw new Error( - `Provider "${providerId}" not found. Please ensure it has been initialized with extension.createProvider(settings, "${providerId}")` - ) - } - return provider + constructor(provider: ProviderV3) { + this.provider = provider } /** - * 解析完整的模型ID (providerId:modelId 格式) - * @returns { providerId, modelId } - */ - private parseFullModelId(fullModelId: string): { providerId: string; modelId: string } { - const parts = fullModelId.split(DEFAULT_SEPARATOR) - if (parts.length < 2) { - throw new Error(`Invalid model ID format: "${fullModelId}". Expected "providerId${DEFAULT_SEPARATOR}modelId"`) - } - // 支持多个分隔符的情况(如 hub:provider:model) - const providerId = parts[0] - const modelId = parts.slice(1).join(DEFAULT_SEPARATOR) - return { providerId, modelId } - } - - /** - * 核心方法:解析任意格式的modelId为语言模型 + * 解析语言模型 * - * @param modelId 模型ID,支持 'gpt-4' 和 'anthropic>claude-3' 两种格式 - * @param fallbackProviderId 当modelId为传统格式时使用的providerId - * @param providerOptions provider配置选项(用于OpenAI模式选择等) - * @param middlewares 中间件数组,会应用到最终模型上 + * @param modelId 模型ID,支持传统格式('gpt-4')或命名空间格式('hub|provider|model') + * @param middlewares 可选的中间件数组,会应用到最终模型上 + * @returns 解析后的语言模型实例 + * + * @example + * ```typescript + * // 传统格式 + * const model = await resolver.resolveLanguageModel('gpt-4') + * + * // 命名空间格式 (需要HubProvider) + * const model = await resolver.resolveLanguageModel('hub|openai|gpt-4') + * ``` */ - async resolveLanguageModel( - modelId: string, - fallbackProviderId: string, - providerOptions?: any, - middlewares?: LanguageModelV3Middleware[] - ): Promise { - let finalProviderId = fallbackProviderId - let model: LanguageModelV3 - // 🎯 处理 OpenAI 模式选择逻辑 (从 ModelCreator 迁移) - if ((fallbackProviderId === 'openai' || fallbackProviderId === 'azure') && providerOptions?.mode === 'chat') { - finalProviderId = `${fallbackProviderId}-chat` - } + async resolveLanguageModel(modelId: string, middlewares?: LanguageModelV3Middleware[]): Promise { + // 直接将完整的modelId传给provider + // - 如果是普通provider,会直接使用modelId + // - 如果是HubProvider,会解析命名空间并路由到正确的provider + let model = this.provider.languageModel(modelId) - // 检查是否是命名空间格式 - if (modelId.includes(DEFAULT_SEPARATOR)) { - model = this.resolveNamespacedModel(modelId) - } else { - // 传统格式:使用处理后的 providerId + modelId - model = this.resolveTraditionalModel(finalProviderId, modelId) - } - - // 🎯 应用中间件(如果有) + // 应用中间件 if (middlewares && middlewares.length > 0) { model = wrapModelWithMiddlewares(model, middlewares) } @@ -81,81 +60,21 @@ export class ModelResolver { /** * 解析文本嵌入模型 + * + * @param modelId 模型ID + * @returns 解析后的嵌入模型实例 */ - async resolveTextEmbeddingModel(modelId: string, fallbackProviderId: string): Promise { - if (modelId.includes(DEFAULT_SEPARATOR)) { - return this.resolveNamespacedEmbeddingModel(modelId) - } - - return this.resolveTraditionalEmbeddingModel(fallbackProviderId, modelId) + async resolveEmbeddingModel(modelId: string): Promise { + return this.provider.embeddingModel(modelId) } /** - * 解析图像模型 + * 解析图像生成模型 + * + * @param modelId 模型ID + * @returns 解析后的图像模型实例 */ - async resolveImageModel(modelId: string, fallbackProviderId: string): Promise { - if (modelId.includes(DEFAULT_SEPARATOR)) { - return this.resolveNamespacedImageModel(modelId) - } - - return this.resolveTraditionalImageModel(fallbackProviderId, modelId) - } - - /** - * 解析命名空间格式的语言模型 - * aihubmix:anthropic:claude-3 -> 从 globalProviderStorage 获取 'aihubmix' provider,调用 languageModel('anthropic:claude-3') - */ - private resolveNamespacedModel(fullModelId: string): LanguageModelV3 { - const { providerId, modelId } = this.parseFullModelId(fullModelId) - const provider = this.getProvider(providerId) - return provider.languageModel(modelId) - } - - /** - * 解析传统格式的语言模型 - * providerId: 'openai', modelId: 'gpt-4' -> 从 globalProviderStorage 获取 'openai' provider,调用 languageModel('gpt-4') - */ - private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV3 { - const provider = this.getProvider(providerId) - return provider.languageModel(modelId) - } - - /** - * 解析命名空间格式的嵌入模型 - */ - private resolveNamespacedEmbeddingModel(fullModelId: string): EmbeddingModelV3 { - const { providerId, modelId } = this.parseFullModelId(fullModelId) - const provider = this.getProvider(providerId) - return provider.embeddingModel(modelId) - } - - /** - * 解析传统格式的嵌入模型 - */ - private resolveTraditionalEmbeddingModel(providerId: string, modelId: string): EmbeddingModelV3 { - const provider = this.getProvider(providerId) - return provider.embeddingModel(modelId) - } - - /** - * 解析命名空间格式的图像模型 - */ - private resolveNamespacedImageModel(fullModelId: string): ImageModelV3 { - const { providerId, modelId } = this.parseFullModelId(fullModelId) - const provider = this.getProvider(providerId) - return provider.imageModel(modelId) - } - - /** - * 解析传统格式的图像模型 - */ - private resolveTraditionalImageModel(providerId: string, modelId: string): ImageModelV3 { - const provider = this.getProvider(providerId) - return provider.imageModel(modelId) + async resolveImageModel(modelId: string): Promise { + return this.provider.imageModel(modelId) } } - -/** - * 全局模型解析器实例 - */ -export const globalModelResolver = new ModelResolver() diff --git a/packages/aiCore/src/core/models/__tests__/ModelResolver.test.ts b/packages/aiCore/src/core/models/__tests__/ModelResolver.test.ts index 0b7ee30c53..f72a3febfe 100644 --- a/packages/aiCore/src/core/models/__tests__/ModelResolver.test.ts +++ b/packages/aiCore/src/core/models/__tests__/ModelResolver.test.ts @@ -1,34 +1,23 @@ /** - * ModelResolver Comprehensive Tests + * ModelResolver Tests * Tests model resolution logic for language, embedding, and image models - * Covers both traditional and namespaced format resolution + * The resolver passes modelId directly to provider - all routing is handled by the provider */ import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider' -import { beforeEach, describe, expect, it, vi } from 'vitest' - import { createMockEmbeddingModel, createMockImageModel, createMockLanguageModel, - createMockMiddleware -} from '../../../__tests__' -import { DEFAULT_SEPARATOR, globalProviderInstanceRegistry } from '../../providers/core/ProviderInstanceRegistry' -import { ModelResolver } from '../ModelResolver' + createMockMiddleware, + createMockProviderV3 +} from '@test-utils' +import { beforeEach, describe, expect, it, vi } from 'vitest' -// Mock the dependencies -vi.mock('../../providers/core/ProviderInstanceRegistry', () => ({ - globalProviderInstanceRegistry: { - languageModel: vi.fn(), - embeddingModel: vi.fn(), - imageModel: vi.fn() - }, - DEFAULT_SEPARATOR: '|' -})) +import { ModelResolver } from '../ModelResolver' vi.mock('../../middleware/wrapper', () => ({ wrapModelWithMiddlewares: vi.fn((model: LanguageModelV3) => { - // Return a wrapped model with a marker return { ...model, _wrapped: true @@ -41,12 +30,12 @@ describe('ModelResolver', () => { let mockLanguageModel: LanguageModelV3 let mockEmbeddingModel: EmbeddingModelV3 let mockImageModel: ImageModelV3 + let mockProvider: any beforeEach(() => { vi.clearAllMocks() - resolver = new ModelResolver() - // Create properly typed mock models using global utilities + // Create properly typed mock models mockLanguageModel = createMockLanguageModel({ provider: 'test-provider', modelId: 'test-model' @@ -62,395 +51,204 @@ describe('ModelResolver', () => { modelId: 'test-image' }) - // Setup default mock implementations - vi.mocked(globalProviderInstanceRegistry.languageModel).mockReturnValue(mockLanguageModel) - vi.mocked(globalProviderInstanceRegistry.embeddingModel).mockReturnValue(mockEmbeddingModel) - vi.mocked(globalProviderInstanceRegistry.imageModel).mockReturnValue(mockImageModel) + // Create mock provider with model methods as spies + mockProvider = createMockProviderV3({ + provider: 'test-provider', + languageModel: vi.fn(() => mockLanguageModel), + embeddingModel: vi.fn(() => mockEmbeddingModel), + imageModel: vi.fn(() => mockImageModel) + }) + + // Create resolver with mock provider + resolver = new ModelResolver(mockProvider) }) describe('resolveLanguageModel', () => { - describe('Traditional Format Resolution', () => { - it('should resolve traditional format modelId without separator', async () => { - const result = await resolver.resolveLanguageModel('gpt-4', 'openai') + it('should resolve modelId by passing it to provider', async () => { + const result = await resolver.resolveLanguageModel('gpt-4') - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith(`openai${DEFAULT_SEPARATOR}gpt-4`) - expect(result).toBe(mockLanguageModel) - }) - - it('should resolve with different provider and modelId combinations', async () => { - const testCases: Array<{ modelId: string; providerId: string; expected: string }> = [ - { modelId: 'claude-3-5-sonnet', providerId: 'anthropic', expected: 'anthropic|claude-3-5-sonnet' }, - { modelId: 'gemini-2.0-flash', providerId: 'google', expected: 'google|gemini-2.0-flash' }, - { modelId: 'grok-2-latest', providerId: 'xai', expected: 'xai|grok-2-latest' }, - { modelId: 'deepseek-chat', providerId: 'deepseek', expected: 'deepseek|deepseek-chat' } - ] - - for (const testCase of testCases) { - vi.clearAllMocks() - await resolver.resolveLanguageModel(testCase.modelId, testCase.providerId) - - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith(testCase.expected) - } - }) - - it('should handle modelIds with special characters', async () => { - const modelIds = ['model-v1.0', 'model_v2', 'model.2024', 'model:free'] - - for (const modelId of modelIds) { - vi.clearAllMocks() - await resolver.resolveLanguageModel(modelId, 'provider') - - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith( - `provider${DEFAULT_SEPARATOR}${modelId}` - ) - } - }) + expect(mockProvider.languageModel).toHaveBeenCalledWith('gpt-4') + expect(result).toBe(mockLanguageModel) }) - describe('Namespaced Format Resolution', () => { - it('should resolve namespaced format with hub', async () => { - const namespacedId = `aihubmix${DEFAULT_SEPARATOR}anthropic${DEFAULT_SEPARATOR}claude-3-5-sonnet` + it('should pass various modelIds directly to provider', async () => { + const modelIds = [ + 'claude-3-5-sonnet', + 'gemini-2.0-flash', + 'grok-2-latest', + 'deepseek-chat', + 'model-v1.0', + 'model_v2', + 'model.2024' + ] - const result = await resolver.resolveLanguageModel(namespacedId, 'openai') + for (const modelId of modelIds) { + vi.clearAllMocks() + await resolver.resolveLanguageModel(modelId) - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith(namespacedId) - expect(result).toBe(mockLanguageModel) - }) - - it('should resolve simple namespaced format', async () => { - const namespacedId = `provider${DEFAULT_SEPARATOR}model-id` - - await resolver.resolveLanguageModel(namespacedId, 'fallback-provider') - - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith(namespacedId) - }) - - it('should handle complex namespaced IDs', async () => { - const complexIds = [ - `hub${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}model`, - `hub${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}model-v1.0`, - `custom${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}gpt-4-turbo` - ] - - for (const id of complexIds) { - vi.clearAllMocks() - await resolver.resolveLanguageModel(id, 'fallback') - - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith(id) - } - }) + expect(mockProvider.languageModel).toHaveBeenCalledWith(modelId) + } }) - describe('OpenAI Mode Selection', () => { - it('should append "-chat" suffix for OpenAI provider with chat mode', async () => { - await resolver.resolveLanguageModel('gpt-4', 'openai', { mode: 'chat' }) + it('should pass namespaced modelIds directly to provider (provider handles routing)', async () => { + // HubProvider handles routing internally - ModelResolver just passes through + const namespacedId = 'openai|gpt-4' - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('openai-chat|gpt-4') - }) + await resolver.resolveLanguageModel(namespacedId) - it('should append "-chat" suffix for Azure provider with chat mode', async () => { - await resolver.resolveLanguageModel('gpt-4', 'azure', { mode: 'chat' }) + expect(mockProvider.languageModel).toHaveBeenCalledWith(namespacedId) + }) - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('azure-chat|gpt-4') - }) + it('should handle empty model IDs', async () => { + await resolver.resolveLanguageModel('') - it('should not append suffix for OpenAI with responses mode', async () => { - await resolver.resolveLanguageModel('gpt-4', 'openai', { mode: 'responses' }) - - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('openai|gpt-4') - }) - - it('should not append suffix for OpenAI without mode', async () => { - await resolver.resolveLanguageModel('gpt-4', 'openai') - - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('openai|gpt-4') - }) - - it('should not append suffix for other providers with chat mode', async () => { - await resolver.resolveLanguageModel('claude-3', 'anthropic', { mode: 'chat' }) - - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('anthropic|claude-3') - }) - - it('should handle namespaced IDs with OpenAI chat mode', async () => { - const namespacedId = `hub${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}gpt-4` - - await resolver.resolveLanguageModel(namespacedId, 'openai', { mode: 'chat' }) - - // Should use the namespaced ID directly, not apply mode logic - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith(namespacedId) - }) + expect(mockProvider.languageModel).toHaveBeenCalledWith('') }) describe('Middleware Application', () => { it('should apply middlewares to resolved model', async () => { const mockMiddleware = createMockMiddleware() - const result = await resolver.resolveLanguageModel('gpt-4', 'openai', undefined, [mockMiddleware]) + const result = await resolver.resolveLanguageModel('gpt-4', [mockMiddleware]) expect(result).toHaveProperty('_wrapped', true) }) - it('should apply multiple middlewares in order', async () => { + it('should apply multiple middlewares', async () => { const middleware1 = createMockMiddleware() const middleware2 = createMockMiddleware() - const result = await resolver.resolveLanguageModel('gpt-4', 'openai', undefined, [middleware1, middleware2]) + const result = await resolver.resolveLanguageModel('gpt-4', [middleware1, middleware2]) expect(result).toHaveProperty('_wrapped', true) }) it('should not apply middlewares when none provided', async () => { - const result = await resolver.resolveLanguageModel('gpt-4', 'openai') + const result = await resolver.resolveLanguageModel('gpt-4') expect(result).not.toHaveProperty('_wrapped') expect(result).toBe(mockLanguageModel) }) it('should not apply middlewares when empty array provided', async () => { - const result = await resolver.resolveLanguageModel('gpt-4', 'openai', undefined, []) + const result = await resolver.resolveLanguageModel('gpt-4', []) expect(result).not.toHaveProperty('_wrapped') }) }) - describe('Provider Options Handling', () => { - it('should pass provider options correctly', async () => { - const options = { baseURL: 'https://api.example.com', apiKey: 'test-key' } - - await resolver.resolveLanguageModel('gpt-4', 'openai', options) - - // Provider options are used for mode selection logic - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalled() - }) - - it('should handle empty provider options', async () => { - await resolver.resolveLanguageModel('gpt-4', 'openai', {}) - - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('openai|gpt-4') - }) - - it('should handle undefined provider options', async () => { - await resolver.resolveLanguageModel('gpt-4', 'openai', undefined) - - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('openai|gpt-4') - }) - }) - }) - - describe('resolveTextEmbeddingModel', () => { - describe('Traditional Format', () => { - it('should resolve traditional embedding model ID', async () => { - const result = await resolver.resolveTextEmbeddingModel('text-embedding-ada-002', 'openai') - - expect(globalProviderInstanceRegistry.embeddingModel).toHaveBeenCalledWith('openai|text-embedding-ada-002') - expect(result).toBe(mockEmbeddingModel) - }) - - it('should resolve different embedding models', async () => { - const testCases = [ - { modelId: 'text-embedding-3-small', providerId: 'openai' }, - { modelId: 'text-embedding-3-large', providerId: 'openai' }, - { modelId: 'embed-english-v3.0', providerId: 'cohere' }, - { modelId: 'voyage-2', providerId: 'voyage' } - ] - - for (const { modelId, providerId } of testCases) { - vi.clearAllMocks() - await resolver.resolveTextEmbeddingModel(modelId, providerId) - - expect(globalProviderInstanceRegistry.embeddingModel).toHaveBeenCalledWith(`${providerId}|${modelId}`) - } - }) - }) - - describe('Namespaced Format', () => { - it('should resolve namespaced embedding model ID', async () => { - const namespacedId = `aihubmix${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}text-embedding-3-small` - - const result = await resolver.resolveTextEmbeddingModel(namespacedId, 'openai') - - expect(globalProviderInstanceRegistry.embeddingModel).toHaveBeenCalledWith(namespacedId) - expect(result).toBe(mockEmbeddingModel) - }) - - it('should handle complex namespaced embedding IDs', async () => { - const complexIds = [ - `hub${DEFAULT_SEPARATOR}cohere${DEFAULT_SEPARATOR}embed-multilingual`, - `custom${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}embedding-model` - ] - - for (const id of complexIds) { - vi.clearAllMocks() - await resolver.resolveTextEmbeddingModel(id, 'fallback') - - expect(globalProviderInstanceRegistry.embeddingModel).toHaveBeenCalledWith(id) - } - }) - }) - }) - - describe('resolveImageModel', () => { - describe('Traditional Format', () => { - it('should resolve traditional image model ID', async () => { - const result = await resolver.resolveImageModel('dall-e-3', 'openai') - - expect(globalProviderInstanceRegistry.imageModel).toHaveBeenCalledWith('openai|dall-e-3') - expect(result).toBe(mockImageModel) - }) - - it('should resolve different image models', async () => { - const testCases = [ - { modelId: 'dall-e-2', providerId: 'openai' }, - { modelId: 'stable-diffusion-xl', providerId: 'stability' }, - { modelId: 'imagen-2', providerId: 'google' }, - { modelId: 'midjourney-v6', providerId: 'midjourney' } - ] - - for (const { modelId, providerId } of testCases) { - vi.clearAllMocks() - await resolver.resolveImageModel(modelId, providerId) - - expect(globalProviderInstanceRegistry.imageModel).toHaveBeenCalledWith(`${providerId}|${modelId}`) - } - }) - }) - - describe('Namespaced Format', () => { - it('should resolve namespaced image model ID', async () => { - const namespacedId = `aihubmix${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}dall-e-3` - - const result = await resolver.resolveImageModel(namespacedId, 'openai') - - expect(globalProviderInstanceRegistry.imageModel).toHaveBeenCalledWith(namespacedId) - expect(result).toBe(mockImageModel) - }) - - it('should handle complex namespaced image IDs', async () => { - const complexIds = [ - `hub${DEFAULT_SEPARATOR}stability${DEFAULT_SEPARATOR}sdxl-turbo`, - `custom${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}image-gen-v2` - ] - - for (const id of complexIds) { - vi.clearAllMocks() - await resolver.resolveImageModel(id, 'fallback') - - expect(globalProviderInstanceRegistry.imageModel).toHaveBeenCalledWith(id) - } - }) - }) - }) - - describe('Edge Cases and Error Scenarios', () => { - it('should handle empty model IDs', async () => { - await resolver.resolveLanguageModel('', 'openai') - - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('openai|') - }) - - it('should handle model IDs with multiple separators', async () => { - const multiSeparatorId = `hub${DEFAULT_SEPARATOR}sub${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}model` - - await resolver.resolveLanguageModel(multiSeparatorId, 'fallback') - - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith(multiSeparatorId) - }) - - it('should handle model IDs with only separator', async () => { - const onlySeparator = DEFAULT_SEPARATOR - - await resolver.resolveLanguageModel(onlySeparator, 'provider') - - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith(onlySeparator) - }) - - it('should throw if globalProviderInstanceRegistry throws', async () => { - const error = new Error('Model not found in registry') - vi.mocked(globalProviderInstanceRegistry.languageModel).mockImplementation(() => { + it('should throw if provider throws', async () => { + const error = new Error('Model not found') + vi.mocked(mockProvider.languageModel).mockImplementation(() => { throw error }) - await expect(resolver.resolveLanguageModel('invalid-model', 'openai')).rejects.toThrow( - 'Model not found in registry' - ) + await expect(resolver.resolveLanguageModel('invalid-model')).rejects.toThrow('Model not found') }) it('should handle concurrent resolution requests', async () => { const promises = [ - resolver.resolveLanguageModel('gpt-4', 'openai'), - resolver.resolveLanguageModel('claude-3', 'anthropic'), - resolver.resolveLanguageModel('gemini-2.0', 'google') + resolver.resolveLanguageModel('gpt-4'), + resolver.resolveLanguageModel('claude-3'), + resolver.resolveLanguageModel('gemini-2.0') ] const results = await Promise.all(promises) expect(results).toHaveLength(3) - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledTimes(3) + expect(mockProvider.languageModel).toHaveBeenCalledTimes(3) + }) + }) + + describe('resolveEmbeddingModel', () => { + it('should resolve embedding model ID', async () => { + const result = await resolver.resolveEmbeddingModel('text-embedding-ada-002') + + expect(mockProvider.embeddingModel).toHaveBeenCalledWith('text-embedding-ada-002') + expect(result).toBe(mockEmbeddingModel) + }) + + it('should resolve different embedding models', async () => { + const modelIds = ['text-embedding-3-small', 'text-embedding-3-large', 'embed-english-v3.0', 'voyage-2'] + + for (const modelId of modelIds) { + vi.clearAllMocks() + await resolver.resolveEmbeddingModel(modelId) + + expect(mockProvider.embeddingModel).toHaveBeenCalledWith(modelId) + } + }) + + it('should pass namespaced embedding modelIds directly to provider', async () => { + const namespacedId = 'openai|text-embedding-3-small' + + await resolver.resolveEmbeddingModel(namespacedId) + + expect(mockProvider.embeddingModel).toHaveBeenCalledWith(namespacedId) + }) + }) + + describe('resolveImageModel', () => { + it('should resolve image model ID', async () => { + const result = await resolver.resolveImageModel('dall-e-3') + + expect(mockProvider.imageModel).toHaveBeenCalledWith('dall-e-3') + expect(result).toBe(mockImageModel) + }) + + it('should resolve different image models', async () => { + const modelIds = ['dall-e-2', 'stable-diffusion-xl', 'imagen-2', 'grok-2-image'] + + for (const modelId of modelIds) { + vi.clearAllMocks() + await resolver.resolveImageModel(modelId) + + expect(mockProvider.imageModel).toHaveBeenCalledWith(modelId) + } + }) + + it('should pass namespaced image modelIds directly to provider', async () => { + const namespacedId = 'openai|dall-e-3' + + await resolver.resolveImageModel(namespacedId) + + expect(mockProvider.imageModel).toHaveBeenCalledWith(namespacedId) }) }) describe('Type Safety', () => { it('should return properly typed LanguageModelV3', async () => { - const result = await resolver.resolveLanguageModel('gpt-4', 'openai') + const result = await resolver.resolveLanguageModel('gpt-4') - // Type assertions expect(result.specificationVersion).toBe('v3') expect(result).toHaveProperty('doGenerate') expect(result).toHaveProperty('doStream') }) it('should return properly typed EmbeddingModelV3', async () => { - const result = await resolver.resolveTextEmbeddingModel('text-embedding-ada-002', 'openai') + const result = await resolver.resolveEmbeddingModel('text-embedding-ada-002') expect(result.specificationVersion).toBe('v3') expect(result).toHaveProperty('doEmbed') }) it('should return properly typed ImageModelV3', async () => { - const result = await resolver.resolveImageModel('dall-e-3', 'openai') + const result = await resolver.resolveImageModel('dall-e-3') expect(result.specificationVersion).toBe('v3') expect(result).toHaveProperty('doGenerate') }) }) - describe('Global ModelResolver Instance', () => { - it('should have a global instance available', async () => { - const { globalModelResolver } = await import('../ModelResolver') + describe('All model types for same provider', () => { + it('should handle all model types correctly', async () => { + await resolver.resolveLanguageModel('gpt-4') + await resolver.resolveEmbeddingModel('text-embedding-3-small') + await resolver.resolveImageModel('dall-e-3') - expect(globalModelResolver).toBeInstanceOf(ModelResolver) - }) - }) - - describe('Integration with Different Provider Types', () => { - it('should work with OpenAI compatible providers', async () => { - await resolver.resolveLanguageModel('custom-model', 'openai-compatible') - - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('openai-compatible|custom-model') - }) - - it('should work with hub providers', async () => { - const hubId = `aihubmix${DEFAULT_SEPARATOR}custom${DEFAULT_SEPARATOR}model-v1` - - await resolver.resolveLanguageModel(hubId, 'aihubmix') - - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith(hubId) - }) - - it('should handle all model types for same provider', async () => { - const providerId = 'openai' - const languageModel = 'gpt-4' - const embeddingModel = 'text-embedding-3-small' - const imageModel = 'dall-e-3' - - await resolver.resolveLanguageModel(languageModel, providerId) - await resolver.resolveTextEmbeddingModel(embeddingModel, providerId) - await resolver.resolveImageModel(imageModel, providerId) - - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith(`${providerId}|${languageModel}`) - expect(globalProviderInstanceRegistry.embeddingModel).toHaveBeenCalledWith(`${providerId}|${embeddingModel}`) - expect(globalProviderInstanceRegistry.imageModel).toHaveBeenCalledWith(`${providerId}|${imageModel}`) + expect(mockProvider.languageModel).toHaveBeenCalledWith('gpt-4') + expect(mockProvider.embeddingModel).toHaveBeenCalledWith('text-embedding-3-small') + expect(mockProvider.imageModel).toHaveBeenCalledWith('dall-e-3') }) }) }) diff --git a/packages/aiCore/src/core/models/index.ts b/packages/aiCore/src/core/models/index.ts index 1e6d33bf2a..fa4803a94e 100644 --- a/packages/aiCore/src/core/models/index.ts +++ b/packages/aiCore/src/core/models/index.ts @@ -3,7 +3,7 @@ */ // 核心模型解析器 -export { globalModelResolver, ModelResolver } from './ModelResolver' +export { ModelResolver } from './ModelResolver' // 保留的类型定义(可能被其他地方使用) export type { ModelConfig as ModelConfigType } from './types' diff --git a/packages/aiCore/src/core/models/types.ts b/packages/aiCore/src/core/models/types.ts index ac8eea290e..bfc35d9b43 100644 --- a/packages/aiCore/src/core/models/types.ts +++ b/packages/aiCore/src/core/models/types.ts @@ -17,7 +17,7 @@ export interface ModelConfig< > { providerId: T modelId: string - providerSettings: T extends keyof TSettingsMap ? TSettingsMap[T] : never + providerSettings: TSettingsMap[T & keyof TSettingsMap] middlewares?: LanguageModelV3Middleware[] extraModelConfig?: JSONObject } diff --git a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/__tests__/StreamEventManager.test.ts b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/__tests__/StreamEventManager.test.ts index cfb2c3df85..b4c92546e9 100644 --- a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/__tests__/StreamEventManager.test.ts +++ b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/__tests__/StreamEventManager.test.ts @@ -10,7 +10,7 @@ import type { import { simulateReadableStream } from 'ai' import { beforeEach, describe, expect, it, vi } from 'vitest' -import { createMockContext, createMockTool } from '../../../../../__tests__' +import { createMockContext, createMockTool } from '@test-utils' import { StreamEventManager } from '../StreamEventManager' import type { StreamController } from '../ToolExecutor' diff --git a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/__tests__/promptToolUsePlugin.test.ts b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/__tests__/promptToolUsePlugin.test.ts index 04a8400d9d..f3a7eb9574 100644 --- a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/__tests__/promptToolUsePlugin.test.ts +++ b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/__tests__/promptToolUsePlugin.test.ts @@ -3,7 +3,7 @@ import { simulateReadableStream } from 'ai' import { convertReadableStreamToArray } from 'ai/test' import { describe, expect, it, vi } from 'vitest' -import { createMockContext, createMockStreamParams, createMockTool, createMockToolSet } from '../../../../../__tests__' +import { createMockContext, createMockStreamParams, createMockTool, createMockToolSet } from '@test-utils' import { createPromptToolUsePlugin, DEFAULT_SYSTEM_PROMPT } from '../promptToolUsePlugin' describe('promptToolUsePlugin', () => { diff --git a/packages/aiCore/src/core/providers/__tests__/ExtensionRegistry.test.ts b/packages/aiCore/src/core/providers/__tests__/ExtensionRegistry.test.ts index 488518a974..4e93ef9efa 100644 --- a/packages/aiCore/src/core/providers/__tests__/ExtensionRegistry.test.ts +++ b/packages/aiCore/src/core/providers/__tests__/ExtensionRegistry.test.ts @@ -2,9 +2,9 @@ * ExtensionRegistry 单元测试 */ +import { createMockProviderV3 } from '@test-utils' import { beforeEach, describe, expect, it, vi } from 'vitest' -import { createMockProviderV3 } from '../../../__tests__' import { ExtensionRegistry } from '../core/ExtensionRegistry' import { ProviderExtension } from '../core/ProviderExtension' import { ProviderCreationError } from '../core/utils' @@ -297,23 +297,6 @@ describe('ExtensionRegistry', () => { }) }) - it.skip('should validate settings before creating', async () => { - const extension = new ProviderExtension({ - name: 'test-provider', - create: createMockProviderV3 as any - }) - - registry.register(extension) - - try { - await registry.createProvider('test-provider', {}) - expect.fail('Should have thrown') - } catch (error) { - expect(error).toBeInstanceOf(ProviderCreationError) - expect((error as ProviderCreationError).cause.message).toContain('API key required') - } - }) - it('should create provider using dynamic import', async () => { const mockProvider = createMockProviderV3() @@ -503,46 +486,6 @@ describe('ExtensionRegistry', () => { await expect(registry.createProvider('test-provider', { apiKey: 'key' })).rejects.toThrow(ProviderCreationError) }) - - it.skip('should still execute validate hook for backward compatibility', async () => { - const validateSpy = vi.fn(() => ({ success: true })) - - registry.register( - new ProviderExtension({ - name: 'test-provider', - create: createMockProviderV3, - validate: validateSpy - }) - ) - - await registry.createProvider('test-provider', { apiKey: 'key' }) - - expect(validateSpy).toHaveBeenCalledWith({ apiKey: 'key' }) - }) - - it.skip('should execute both onBeforeCreate and validate', async () => { - const executionOrder: string[] = [] - - registry.register( - new ProviderExtension({ - name: 'test-provider', - create: createMockProviderV3, - hooks: { - onBeforeCreate: () => { - executionOrder.push('hook') - } - }, - validate: () => { - executionOrder.push('validate') - return { success: true } - } - }) - ) - - await registry.createProvider('test-provider', { apiKey: 'key' }) - - expect(executionOrder).toEqual(['hook', 'validate']) - }) }) describe('ProviderCreationError', () => { diff --git a/packages/aiCore/src/core/providers/__tests__/HubProvider.integration.test.ts b/packages/aiCore/src/core/providers/__tests__/HubProvider.integration.test.ts new file mode 100644 index 0000000000..7bd03cb8cc --- /dev/null +++ b/packages/aiCore/src/core/providers/__tests__/HubProvider.integration.test.ts @@ -0,0 +1,442 @@ +/** + * HubProvider Integration Tests + * Tests end-to-end integration between HubProvider, RuntimeExecutor, and ProviderExtension + */ + +import type { LanguageModelV3 } from '@ai-sdk/provider' +import { createMockLanguageModel, createMockProviderV3 } from '@test-utils' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import { RuntimeExecutor } from '../../runtime/executor' +import { ExtensionRegistry } from '../core/ExtensionRegistry' +import { ProviderExtension } from '../core/ProviderExtension' +import { createHubProviderAsync } from '../features/HubProvider' + +describe('HubProvider Integration Tests', () => { + let registry: ExtensionRegistry + let openaiExtension: ProviderExtension + let anthropicExtension: ProviderExtension + let googleExtension: ProviderExtension + + beforeEach(() => { + vi.clearAllMocks() + + // Create fresh registry + registry = new ExtensionRegistry() + + // Create provider extensions using test utils directly + openaiExtension = ProviderExtension.create({ + name: 'openai', + create: () => createMockProviderV3({ provider: 'openai' }) + } as const) + + anthropicExtension = ProviderExtension.create({ + name: 'anthropic', + create: () => createMockProviderV3({ provider: 'anthropic' }) + } as const) + + googleExtension = ProviderExtension.create({ + name: 'google', + create: () => createMockProviderV3({ provider: 'google' }) + } as const) + + // Register extensions + registry.register(openaiExtension) + registry.register(anthropicExtension) + registry.register(googleExtension) + }) + + describe('End-to-End with RuntimeExecutor', () => { + it('should resolve models through HubProvider using namespace format', async () => { + // Create HubProvider + const hubProvider = await createHubProviderAsync({ + hubId: 'aihubmix', + registry, + providerSettingsMap: new Map([ + ['openai', { apiKey: 'test-openai-key' }], + ['anthropic', { apiKey: 'test-anthropic-key' }] + ]) + }) + + // Test that models are resolved correctly + const openaiModel = hubProvider.languageModel('openai|gpt-4') + const anthropicModel = hubProvider.languageModel('anthropic|claude-3-5-sonnet') + + expect(openaiModel).toBeDefined() + expect(openaiModel.provider).toBe('openai') + expect(openaiModel.modelId).toBe('gpt-4') + + expect(anthropicModel).toBeDefined() + expect(anthropicModel.provider).toBe('anthropic') + expect(anthropicModel.modelId).toBe('claude-3-5-sonnet') + }) + + it('should resolve language model correctly through executor', async () => { + const hubProvider = await createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([['openai', { apiKey: 'test-key' }]]) + }) + + const executor = RuntimeExecutor.create('test-hub', hubProvider, {} as never, []) + + // Access the private resolveModel method through streamText + const result = await executor.streamText({ + model: 'openai|gpt-4-turbo', + messages: [{ role: 'user', content: 'Test' }] + }) + + // Verify the model was created and result is valid + expect(result).toBeDefined() + expect(result.textStream).toBeDefined() + }) + + it('should handle multiple providers in the same hub', async () => { + const hubProvider = await createHubProviderAsync({ + hubId: 'multi-hub', + registry, + providerSettingsMap: new Map([ + ['openai', { apiKey: 'openai-key' }], + ['anthropic', { apiKey: 'anthropic-key' }], + ['google', { apiKey: 'google-key' }] + ]) + }) + + // Test all three providers can be resolved + const openaiModel = hubProvider.languageModel('openai|gpt-4') + const anthropicModel = hubProvider.languageModel('anthropic|claude-3-5-sonnet') + const googleModel = hubProvider.languageModel('google|gemini-2.0-flash') + + expect(openaiModel.provider).toBe('openai') + expect(openaiModel.modelId).toBe('gpt-4') + + expect(anthropicModel.provider).toBe('anthropic') + expect(anthropicModel.modelId).toBe('claude-3-5-sonnet') + + expect(googleModel.provider).toBe('google') + expect(googleModel.modelId).toBe('gemini-2.0-flash') + }) + + it('should work with direct model objects instead of strings', async () => { + const hubProvider = await createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([['openai', { apiKey: 'test-key' }]]) + }) + + const executor = RuntimeExecutor.create('test-hub', hubProvider, {} as never, []) + + // Create a model instance directly + const model = createMockLanguageModel({ + provider: 'openai', + modelId: 'gpt-4' + }) + + // Use the model object directly + const result = await executor.streamText({ + model: model as LanguageModelV3, + messages: [{ role: 'user', content: 'Test with model object' }] + }) + + expect(result).toBeDefined() + }) + }) + + describe('ProviderExtension LRU Cache Integration', () => { + it('should leverage ProviderExtension LRU cache when creating multiple HubProviders', async () => { + const settings = new Map([ + ['openai', { apiKey: 'same-key-1' }], + ['anthropic', { apiKey: 'same-key-2' }] + ]) + + // Create first HubProvider + const hub1 = await createHubProviderAsync({ + hubId: 'hub1', + registry, + providerSettingsMap: settings + }) + + // Create second HubProvider with SAME settings + const hub2 = await createHubProviderAsync({ + hubId: 'hub2', + registry, + providerSettingsMap: settings + }) + + // Extensions should have cached the provider instances + // Create a test model to verify caching + const model1 = hub1.languageModel('openai|gpt-4') + const model2 = hub2.languageModel('openai|gpt-4') + + expect(model1).toBeDefined() + expect(model2).toBeDefined() + + // Both should have the same provider name + expect(model1.provider).toBe('openai') + expect(model2.provider).toBe('openai') + }) + + it('should create new providers when settings differ', async () => { + const settings1 = new Map([['openai', { apiKey: 'key-1' }]]) + const settings2 = new Map([['openai', { apiKey: 'key-2' }]]) + + // Create two HubProviders with DIFFERENT settings + const hub1 = await createHubProviderAsync({ + hubId: 'hub1', + registry, + providerSettingsMap: settings1 + }) + + const hub2 = await createHubProviderAsync({ + hubId: 'hub2', + registry, + providerSettingsMap: settings2 + }) + + const model1 = hub1.languageModel('openai|gpt-4') + const model2 = hub2.languageModel('openai|gpt-4') + + expect(model1).toBeDefined() + expect(model2).toBeDefined() + }) + + it('should handle cache across multiple provider types', async () => { + const settings = new Map([ + ['openai', { apiKey: 'openai-key' }], + ['anthropic', { apiKey: 'anthropic-key' }] + ]) + + const hub = await createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: settings + }) + + // Create models from different providers + const openaiModel = hub.languageModel('openai|gpt-4') + const anthropicModel = hub.languageModel('anthropic|claude-3-5-sonnet') + const openaiEmbedding = hub.embeddingModel('openai|text-embedding-3-small') + + expect(openaiModel.provider).toBe('openai') + expect(anthropicModel.provider).toBe('anthropic') + expect(openaiEmbedding.provider).toBe('openai') + }) + }) + + describe('Error Handling Integration', () => { + it('should throw error when using provider not in providerSettingsMap', async () => { + const hub = await createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([['openai', { apiKey: 'test-key' }]]) + // Note: anthropic NOT included + }) + + // Try to use anthropic (not initialized) + expect(() => { + hub.languageModel('anthropic|claude-3-5-sonnet') + }).toThrow(/Provider "anthropic" not initialized/) + }) + + it('should throw error when extension not registered', async () => { + const emptyRegistry = new ExtensionRegistry() + + await expect( + createHubProviderAsync({ + hubId: 'test-hub', + registry: emptyRegistry, + providerSettingsMap: new Map([['openai', { apiKey: 'test-key' }]]) + }) + ).rejects.toThrow(/Provider extension "openai" not found in registry/) + }) + + it('should throw error on invalid model ID format', async () => { + const hub = await createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([['openai', { apiKey: 'test-key' }]]) + }) + + // Invalid format: no separator + expect(() => { + hub.languageModel('invalid-no-separator') + }).toThrow(/Invalid hub model ID format/) + + // Invalid format: empty provider + expect(() => { + hub.languageModel('|model-id') + }).toThrow(/Invalid hub model ID format/) + + // Invalid format: empty modelId + expect(() => { + hub.languageModel('openai|') + }).toThrow(/Invalid hub model ID format/) + }) + + it('should propagate errors from extension.createProvider', async () => { + // Create an extension that throws on creation + const failingExtension = ProviderExtension.create({ + name: 'failing', + create: () => { + throw new Error('Provider creation failed!') + } + } as const) + + const failRegistry = new ExtensionRegistry() + failRegistry.register(failingExtension) + + await expect( + createHubProviderAsync({ + hubId: 'test-hub', + registry: failRegistry, + providerSettingsMap: new Map([['failing', { apiKey: 'test' }]]) + }) + ).rejects.toThrow(/Failed to create provider "failing"/) + }) + }) + + describe('Advanced Scenarios', () => { + it('should support image generation through hub', async () => { + const hub = await createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([['openai', { apiKey: 'test-key' }]]) + }) + + const executor = RuntimeExecutor.create('test-hub', hub, {} as never, []) + + const result = await executor.generateImage({ + model: 'openai|dall-e-3', + prompt: 'A beautiful sunset' + }) + + expect(result).toBeDefined() + }) + + it('should support embedding models through hub', async () => { + const hub = await createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([['openai', { apiKey: 'test-key' }]]) + }) + + const embeddingModel = hub.embeddingModel('openai|text-embedding-3-small') + + expect(embeddingModel).toBeDefined() + expect(embeddingModel.provider).toBe('openai') + expect(embeddingModel.modelId).toBe('text-embedding-3-small') + }) + + it('should handle concurrent model resolutions', async () => { + const hub = await createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([ + ['openai', { apiKey: 'openai-key' }], + ['anthropic', { apiKey: 'anthropic-key' }] + ]) + }) + + // Concurrent model resolutions + const models = await Promise.all([ + Promise.resolve(hub.languageModel('openai|gpt-4')), + Promise.resolve(hub.languageModel('anthropic|claude-3-5-sonnet')), + Promise.resolve(hub.languageModel('openai|gpt-3.5-turbo')) + ]) + + expect(models).toHaveLength(3) + expect(models[0].provider).toBe('openai') + expect(models[0].modelId).toBe('gpt-4') + expect(models[1].provider).toBe('anthropic') + expect(models[1].modelId).toBe('claude-3-5-sonnet') + expect(models[2].provider).toBe('openai') + expect(models[2].modelId).toBe('gpt-3.5-turbo') + }) + + it('should work with middlewares', async () => { + const hub = await createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([['openai', { apiKey: 'test-key' }]]) + }) + + const executor = RuntimeExecutor.create('test-hub', hub, {} as never, []) + + // Create a mock middleware + const mockMiddleware = { + specificationVersion: 'v3' as const, + wrapGenerate: vi.fn((doGenerate) => doGenerate), + wrapStream: vi.fn((doStream) => doStream) + } + + const result = await executor.streamText( + { + model: 'openai|gpt-4', + messages: [{ role: 'user', content: 'Test with middleware' }] + }, + { middlewares: [mockMiddleware] } + ) + + expect(result).toBeDefined() + }) + }) + + describe('Multiple HubProvider Instances', () => { + it('should support multiple independent hub providers', async () => { + // Create first hub for OpenAI only + const openaiHub = await createHubProviderAsync({ + hubId: 'openai-hub', + registry, + providerSettingsMap: new Map([['openai', { apiKey: 'openai-key' }]]) + }) + + // Create second hub for Anthropic only + const anthropicHub = await createHubProviderAsync({ + hubId: 'anthropic-hub', + registry, + providerSettingsMap: new Map([['anthropic', { apiKey: 'anthropic-key' }]]) + }) + + // Both hubs should work independently + const openaiModel = openaiHub.languageModel('openai|gpt-4') + const anthropicModel = anthropicHub.languageModel('anthropic|claude-3-5-sonnet') + + expect(openaiModel.provider).toBe('openai') + expect(anthropicModel.provider).toBe('anthropic') + + // OpenAI hub should not have anthropic + expect(() => { + openaiHub.languageModel('anthropic|claude-3-5-sonnet') + }).toThrow(/Provider "anthropic" not initialized/) + + // Anthropic hub should not have openai + expect(() => { + anthropicHub.languageModel('openai|gpt-4') + }).toThrow(/Provider "openai" not initialized/) + }) + + it('should support creating multiple executors from same hub', async () => { + const hub = await createHubProviderAsync({ + hubId: 'shared-hub', + registry, + providerSettingsMap: new Map([ + ['openai', { apiKey: 'key-1' }], + ['anthropic', { apiKey: 'key-2' }] + ]) + }) + + // Create multiple executors from the same hub + const executor1 = RuntimeExecutor.create('shared-hub', hub, {} as never, []) + const executor2 = RuntimeExecutor.create('shared-hub', hub, {} as never, []) + + // Both executors should share the same hub and be able to resolve models + const model1 = hub.languageModel('openai|gpt-4') + const model2 = hub.languageModel('anthropic|claude-3-5-sonnet') + + expect(executor1).toBeDefined() + expect(executor2).toBeDefined() + expect(model1.provider).toBe('openai') + expect(model2.provider).toBe('anthropic') + }) + }) +}) diff --git a/packages/aiCore/src/core/providers/__tests__/HubProvider.test.ts b/packages/aiCore/src/core/providers/__tests__/HubProvider.test.ts index 4c97dcece6..9c053ed746 100644 --- a/packages/aiCore/src/core/providers/__tests__/HubProvider.test.ts +++ b/packages/aiCore/src/core/providers/__tests__/HubProvider.test.ts @@ -1,32 +1,30 @@ /** * HubProvider Comprehensive Tests * Tests hub provider routing, model resolution, and error handling - * Covers multi-provider routing with namespaced model IDs + * Updated for ExtensionRegistry architecture with createHubProviderAsync */ import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3, ProviderV3 } from '@ai-sdk/provider' -import { customProvider, wrapProvider } from 'ai' +import { customProvider } from 'ai' import { beforeEach, describe, expect, it, vi } from 'vitest' -import { createMockEmbeddingModel, createMockImageModel, createMockLanguageModel } from '../../../__tests__' -import { DEFAULT_SEPARATOR, globalProviderInstanceRegistry } from '../core/ProviderInstanceRegistry' -import { createHubProvider, type HubProviderConfig, HubProviderError } from '../features/HubProvider' - -// Mock dependencies -vi.mock('../core/ProviderInstanceRegistry', () => ({ - globalProviderInstanceRegistry: { - getProvider: vi.fn() - }, - DEFAULT_SEPARATOR: '|' -})) +import { createMockEmbeddingModel, createMockImageModel, createMockLanguageModel } from '@test-utils' +import { ExtensionRegistry } from '../core/ExtensionRegistry' +import { ProviderExtension } from '../core/ProviderExtension' +import { + createHubProviderAsync, + DEFAULT_SEPARATOR, + type HubProviderConfig, + HubProviderError +} from '../features/HubProvider' vi.mock('ai', () => ({ customProvider: vi.fn((config) => config.fallbackProvider), - wrapProvider: vi.fn((config) => config.provider), jsonSchema: vi.fn((schema) => schema) })) describe('HubProvider', () => { + let registry: ExtensionRegistry let mockOpenAIProvider: ProviderV3 let mockAnthropicProvider: ProviderV3 let mockLanguageModel: LanguageModelV3 @@ -36,7 +34,7 @@ describe('HubProvider', () => { beforeEach(() => { vi.clearAllMocks() - // Create mock models using global utilities + // Create mock models mockLanguageModel = createMockLanguageModel({ provider: 'test', modelId: 'test-model' @@ -67,150 +65,185 @@ describe('HubProvider', () => { imageModel: vi.fn().mockReturnValue(mockImageModel) } as ProviderV3 - // Setup default mock implementation - vi.mocked(globalProviderInstanceRegistry.getProvider).mockImplementation((id) => { - if (id === 'openai') return mockOpenAIProvider - if (id === 'anthropic') return mockAnthropicProvider - return undefined - }) + // Create registry and register extensions + registry = new ExtensionRegistry() + + const openaiExtension = ProviderExtension.create({ + name: 'openai', + create: () => mockOpenAIProvider + } as const) + + const anthropicExtension = ProviderExtension.create({ + name: 'anthropic', + create: () => mockAnthropicProvider + } as const) + + registry.register(openaiExtension) + registry.register(anthropicExtension) }) describe('Provider Creation', () => { - it('should create hub provider with basic config', () => { + it('should create hub provider with basic config', async () => { const config: HubProviderConfig = { - hubId: 'test-hub' + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([['openai', { apiKey: 'test-key' }]]) } - const provider = createHubProvider(config) + const provider = await createHubProviderAsync(config) expect(provider).toBeDefined() expect(customProvider).toHaveBeenCalled() }) - it('should create provider with debug flag', () => { + it('should create provider with debug flag', async () => { const config: HubProviderConfig = { hubId: 'test-hub', - debug: true + debug: true, + registry, + providerSettingsMap: new Map([['openai', {}]]) } - const provider = createHubProvider(config) + const provider = await createHubProviderAsync(config) expect(provider).toBeDefined() }) - it('should return ProviderV3 specification', () => { - const config: HubProviderConfig = { - hubId: 'aihubmix' - } - - const provider = createHubProvider(config) + it('should return ProviderV3 specification', async () => { + const provider = await createHubProviderAsync({ + hubId: 'aihubmix', + registry, + providerSettingsMap: new Map([ + ['openai', {}], + ['anthropic', {}] + ]) + }) expect(provider).toHaveProperty('specificationVersion', 'v3') expect(provider).toHaveProperty('languageModel') expect(provider).toHaveProperty('embeddingModel') expect(provider).toHaveProperty('imageModel') }) + + it('should throw error if extension not found in registry', async () => { + await expect( + createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([['unknown-provider', {}]]) + }) + ).rejects.toThrow(HubProviderError) + }) + + it('should pre-create all providers during initialization', async () => { + await createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([ + ['openai', { apiKey: 'key1' }], + ['anthropic', { apiKey: 'key2' }] + ]) + }) + + // Both providers created successfully + expect(true).toBe(true) + }) }) describe('Model ID Parsing', () => { - it('should parse valid hub model ID format', () => { - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 + it('should parse valid hub model ID format', async () => { + const provider = (await createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([['openai', {}]]) + })) as ProviderV3 - const modelId = `openai${DEFAULT_SEPARATOR}gpt-4` + const result = provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) - const result = provider.languageModel(modelId) - - expect(globalProviderInstanceRegistry.getProvider).toHaveBeenCalledWith('openai') expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4') expect(result).toBe(mockLanguageModel) }) - it('should throw error for invalid model ID format', () => { - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 + it('should throw error for invalid model ID format', async () => { + const provider = (await createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([['openai', {}]]) + })) as ProviderV3 - const invalidId = 'invalid-id-without-separator' - - expect(() => provider.languageModel(invalidId)).toThrow(HubProviderError) + expect(() => provider.languageModel('invalid-id-without-separator')).toThrow(HubProviderError) }) - it('should throw error for model ID with multiple separators', () => { - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 + it('should throw error for model ID with multiple separators', async () => { + const provider = (await createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([['openai', {}]]) + })) as ProviderV3 - const multiSeparatorId = `provider${DEFAULT_SEPARATOR}extra${DEFAULT_SEPARATOR}model` - - expect(() => provider.languageModel(multiSeparatorId)).toThrow(HubProviderError) + expect(() => provider.languageModel(`provider${DEFAULT_SEPARATOR}extra${DEFAULT_SEPARATOR}model`)).toThrow( + HubProviderError + ) }) - it('should throw error for empty model ID', () => { - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 + it('should throw error for empty model ID', async () => { + const provider = (await createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([['openai', {}]]) + })) as ProviderV3 expect(() => provider.languageModel('')).toThrow(HubProviderError) }) - - it('should throw error for model ID with only separator', () => { - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 - - expect(() => provider.languageModel(DEFAULT_SEPARATOR)).toThrow(HubProviderError) - }) }) describe('Language Model Resolution', () => { - it('should route to correct provider for language model', () => { - const config: HubProviderConfig = { hubId: 'aihubmix' } - const provider = createHubProvider(config) as ProviderV3 + it('should route to correct provider for language model', async () => { + const provider = (await createHubProviderAsync({ + hubId: 'aihubmix', + registry, + providerSettingsMap: new Map([['openai', {}]]) + })) as ProviderV3 const result = provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) - expect(globalProviderInstanceRegistry.getProvider).toHaveBeenCalledWith('openai') expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4') expect(result).toBe(mockLanguageModel) }) - it('should route different providers correctly', () => { - const config: HubProviderConfig = { hubId: 'aihubmix' } - const provider = createHubProvider(config) as ProviderV3 + it('should route different providers correctly', async () => { + const provider = (await createHubProviderAsync({ + hubId: 'aihubmix', + registry, + providerSettingsMap: new Map([ + ['openai', {}], + ['anthropic', {}] + ]) + })) as ProviderV3 provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) provider.languageModel(`anthropic${DEFAULT_SEPARATOR}claude-3`) - expect(globalProviderInstanceRegistry.getProvider).toHaveBeenCalledWith('openai') - expect(globalProviderInstanceRegistry.getProvider).toHaveBeenCalledWith('anthropic') expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4') expect(mockAnthropicProvider.languageModel).toHaveBeenCalledWith('claude-3') }) - it('should wrap provider with wrapProvider', () => { - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 + it('should throw HubProviderError if provider not initialized', async () => { + const provider = (await createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([['openai', {}]]) // Only openai initialized + })) as ProviderV3 - provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) - - expect(wrapProvider).toHaveBeenCalledWith({ - provider: mockOpenAIProvider, - languageModelMiddleware: [] - }) + expect(() => provider.languageModel(`anthropic${DEFAULT_SEPARATOR}claude-3`)).toThrow(HubProviderError) }) - it('should throw HubProviderError if provider not initialized', () => { - vi.mocked(globalProviderInstanceRegistry.getProvider).mockReturnValue(undefined) - - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 - - expect(() => provider.languageModel(`uninitialized${DEFAULT_SEPARATOR}model`)).toThrow(HubProviderError) - expect(() => provider.languageModel(`uninitialized${DEFAULT_SEPARATOR}model`)).toThrow(/not initialized/) - }) - - it('should include provider ID in error message', () => { - vi.mocked(globalProviderInstanceRegistry.getProvider).mockReturnValue(undefined) - - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 + it('should include provider ID in error message', async () => { + const provider = (await createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([['openai', {}]]) + })) as ProviderV3 try { provider.languageModel(`missing${DEFAULT_SEPARATOR}model`) @@ -225,20 +258,28 @@ describe('HubProvider', () => { }) describe('Embedding Model Resolution', () => { - it('should route to correct provider for embedding model', () => { - const config: HubProviderConfig = { hubId: 'aihubmix' } - const provider = createHubProvider(config) as ProviderV3 + it('should route to correct provider for embedding model', async () => { + const provider = (await createHubProviderAsync({ + hubId: 'aihubmix', + registry, + providerSettingsMap: new Map([['openai', {}]]) + })) as ProviderV3 const result = provider.embeddingModel(`openai${DEFAULT_SEPARATOR}text-embedding-3-small`) - expect(globalProviderInstanceRegistry.getProvider).toHaveBeenCalledWith('openai') expect(mockOpenAIProvider.embeddingModel).toHaveBeenCalledWith('text-embedding-3-small') expect(result).toBe(mockEmbeddingModel) }) - it('should handle different embedding providers', () => { - const config: HubProviderConfig = { hubId: 'aihubmix' } - const provider = createHubProvider(config) as ProviderV3 + it('should handle different embedding providers', async () => { + const provider = (await createHubProviderAsync({ + hubId: 'aihubmix', + registry, + providerSettingsMap: new Map([ + ['openai', {}], + ['anthropic', {}] + ]) + })) as ProviderV3 provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada-002`) provider.embeddingModel(`anthropic${DEFAULT_SEPARATOR}embed-v1`) @@ -246,32 +287,31 @@ describe('HubProvider', () => { expect(mockOpenAIProvider.embeddingModel).toHaveBeenCalledWith('ada-002') expect(mockAnthropicProvider.embeddingModel).toHaveBeenCalledWith('embed-v1') }) - - it('should throw error for uninitialized embedding provider', () => { - vi.mocked(globalProviderInstanceRegistry.getProvider).mockReturnValue(undefined) - - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 - - expect(() => provider.embeddingModel(`missing${DEFAULT_SEPARATOR}embed`)).toThrow(HubProviderError) - }) }) describe('Image Model Resolution', () => { - it('should route to correct provider for image model', () => { - const config: HubProviderConfig = { hubId: 'aihubmix' } - const provider = createHubProvider(config) as ProviderV3 + it('should route to correct provider for image model', async () => { + const provider = (await createHubProviderAsync({ + hubId: 'aihubmix', + registry, + providerSettingsMap: new Map([['openai', {}]]) + })) as ProviderV3 const result = provider.imageModel(`openai${DEFAULT_SEPARATOR}dall-e-3`) - expect(globalProviderInstanceRegistry.getProvider).toHaveBeenCalledWith('openai') expect(mockOpenAIProvider.imageModel).toHaveBeenCalledWith('dall-e-3') expect(result).toBe(mockImageModel) }) - it('should handle different image providers', () => { - const config: HubProviderConfig = { hubId: 'aihubmix' } - const provider = createHubProvider(config) as ProviderV3 + it('should handle different image providers', async () => { + const provider = (await createHubProviderAsync({ + hubId: 'aihubmix', + registry, + providerSettingsMap: new Map([ + ['openai', {}], + ['anthropic', {}] + ]) + })) as ProviderV3 provider.imageModel(`openai${DEFAULT_SEPARATOR}dall-e-3`) provider.imageModel(`anthropic${DEFAULT_SEPARATOR}image-gen`) @@ -282,9 +322,9 @@ describe('HubProvider', () => { }) describe('Special Model Types', () => { - it('should support transcription models', () => { + it('should support transcription models if provider has them', async () => { const mockTranscriptionModel = { - specificationVersion: 'v3', + specificationVersion: 'v3' as const, doTranscribe: vi.fn() } @@ -293,86 +333,38 @@ describe('HubProvider', () => { transcriptionModel: vi.fn().mockReturnValue(mockTranscriptionModel) } as ProviderV3 - vi.mocked(globalProviderInstanceRegistry.getProvider).mockReturnValue(providerWithTranscription) + // Replace the provider that will be created + const transcriptionExtension = ProviderExtension.create({ + name: 'transcription-provider', + create: () => providerWithTranscription + } as const) - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 + registry.register(transcriptionExtension) - const result = provider.transcriptionModel!(`openai${DEFAULT_SEPARATOR}whisper-1`) + const provider = (await createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([['transcription-provider', {}]]) + })) as ProviderV3 + + const result = provider.transcriptionModel!(`transcription-provider${DEFAULT_SEPARATOR}whisper-1`) expect(providerWithTranscription.transcriptionModel).toHaveBeenCalledWith('whisper-1') expect(result).toBe(mockTranscriptionModel) }) - it('should throw error if provider does not support transcription', () => { - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 + it('should throw error if provider does not support transcription', async () => { + const provider = (await createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([['openai', {}]]) + })) as ProviderV3 expect(() => provider.transcriptionModel!(`openai${DEFAULT_SEPARATOR}whisper`)).toThrow(HubProviderError) expect(() => provider.transcriptionModel!(`openai${DEFAULT_SEPARATOR}whisper`)).toThrow( /does not support transcription/ ) }) - - it('should support speech models', () => { - const mockSpeechModel = { - specificationVersion: 'v3', - doGenerate: vi.fn() - } - - const providerWithSpeech = { - ...mockOpenAIProvider, - speechModel: vi.fn().mockReturnValue(mockSpeechModel) - } as ProviderV3 - - vi.mocked(globalProviderInstanceRegistry.getProvider).mockReturnValue(providerWithSpeech) - - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 - - const result = provider.speechModel!(`openai${DEFAULT_SEPARATOR}tts-1`) - - expect(providerWithSpeech.speechModel).toHaveBeenCalledWith('tts-1') - expect(result).toBe(mockSpeechModel) - }) - - it('should throw error if provider does not support speech', () => { - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 - - expect(() => provider.speechModel!(`openai${DEFAULT_SEPARATOR}tts-1`)).toThrow(HubProviderError) - expect(() => provider.speechModel!(`openai${DEFAULT_SEPARATOR}tts-1`)).toThrow(/does not support speech/) - }) - - it('should support reranking models', () => { - const mockRerankingModel = { - specificationVersion: 'v3', - doRerank: vi.fn() - } - - const providerWithReranking = { - ...mockOpenAIProvider, - rerankingModel: vi.fn().mockReturnValue(mockRerankingModel) - } as ProviderV3 - - vi.mocked(globalProviderInstanceRegistry.getProvider).mockReturnValue(providerWithReranking) - - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 - - const result = provider.rerankingModel!(`openai${DEFAULT_SEPARATOR}rerank-v1`) - - expect(providerWithReranking.rerankingModel).toHaveBeenCalledWith('rerank-v1') - expect(result).toBe(mockRerankingModel) - }) - - it('should throw error if provider does not support reranking', () => { - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 - - expect(() => provider.rerankingModel!(`openai${DEFAULT_SEPARATOR}rerank`)).toThrow(HubProviderError) - expect(() => provider.rerankingModel!(`openai${DEFAULT_SEPARATOR}rerank`)).toThrow(/does not support reranking/) - }) }) describe('Error Handling', () => { @@ -395,106 +387,51 @@ describe('HubProvider', () => { expect(error.providerId).toBeUndefined() expect(error.originalError).toBeUndefined() }) - - it('should wrap provider errors in HubProviderError', () => { - const providerError = new Error('Provider failed') - vi.mocked(globalProviderInstanceRegistry.getProvider).mockImplementation(() => { - throw providerError - }) - - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 - - try { - provider.languageModel(`failing${DEFAULT_SEPARATOR}model`) - expect.fail('Should have thrown HubProviderError') - } catch (error) { - expect(error).toBeInstanceOf(HubProviderError) - const hubError = error as HubProviderError - expect(hubError.originalError).toBe(providerError) - expect(hubError.message).toContain('Failed to get provider') - } - }) - - it('should handle null provider from registry', () => { - vi.mocked(globalProviderInstanceRegistry.getProvider).mockReturnValue(null as any) - - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 - - expect(() => provider.languageModel(`null-provider${DEFAULT_SEPARATOR}model`)).toThrow(HubProviderError) - }) }) describe('Multi-Provider Scenarios', () => { - it('should handle sequential calls to different providers', () => { - const config: HubProviderConfig = { hubId: 'aihubmix' } - const provider = createHubProvider(config) as ProviderV3 + it('should handle sequential calls to different providers', async () => { + const provider = (await createHubProviderAsync({ + hubId: 'aihubmix', + registry, + providerSettingsMap: new Map([ + ['openai', {}], + ['anthropic', {}] + ]) + })) as ProviderV3 provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) provider.languageModel(`anthropic${DEFAULT_SEPARATOR}claude-3`) provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-3.5`) - expect(globalProviderInstanceRegistry.getProvider).toHaveBeenCalledTimes(3) expect(mockOpenAIProvider.languageModel).toHaveBeenCalledTimes(2) expect(mockAnthropicProvider.languageModel).toHaveBeenCalledTimes(1) }) - it('should handle mixed model types from same provider', () => { - const config: HubProviderConfig = { hubId: 'aihubmix' } - const provider = createHubProvider(config) as ProviderV3 + it('should handle mixed model types from same provider', async () => { + const provider = (await createHubProviderAsync({ + hubId: 'aihubmix', + registry, + providerSettingsMap: new Map([['openai', {}]]) + })) as ProviderV3 provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada-002`) provider.imageModel(`openai${DEFAULT_SEPARATOR}dall-e-3`) - expect(globalProviderInstanceRegistry.getProvider).toHaveBeenCalledTimes(3) expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4') expect(mockOpenAIProvider.embeddingModel).toHaveBeenCalledWith('ada-002') expect(mockOpenAIProvider.imageModel).toHaveBeenCalledWith('dall-e-3') }) - - it('should cache provider lookups', () => { - const config: HubProviderConfig = { hubId: 'aihubmix' } - const provider = createHubProvider(config) as ProviderV3 - - provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) - provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-3.5`) - - // Should call getProvider twice (once per model call) - expect(globalProviderInstanceRegistry.getProvider).toHaveBeenCalledTimes(2) - }) - }) - - describe('Provider Wrapping', () => { - it('should wrap all providers with empty middleware', () => { - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 - - provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) - - expect(wrapProvider).toHaveBeenCalledWith({ - provider: mockOpenAIProvider, - languageModelMiddleware: [] - }) - }) - - it('should wrap providers for all model types', () => { - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 - - provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) - provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada`) - provider.imageModel(`openai${DEFAULT_SEPARATOR}dalle`) - - expect(wrapProvider).toHaveBeenCalledTimes(3) - }) }) describe('Type Safety', () => { - it('should return properly typed LanguageModelV3', () => { - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 + it('should return properly typed LanguageModelV3', async () => { + const provider = (await createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([['openai', {}]]) + })) as ProviderV3 const result = provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) @@ -503,9 +440,12 @@ describe('HubProvider', () => { expect(result).toHaveProperty('doStream') }) - it('should return properly typed EmbeddingModelV3', () => { - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 + it('should return properly typed EmbeddingModelV3', async () => { + const provider = (await createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([['openai', {}]]) + })) as ProviderV3 const result = provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada`) @@ -513,9 +453,12 @@ describe('HubProvider', () => { expect(result).toHaveProperty('doEmbed') }) - it('should return properly typed ImageModelV3', () => { - const config: HubProviderConfig = { hubId: 'test-hub' } - const provider = createHubProvider(config) as ProviderV3 + it('should return properly typed ImageModelV3', async () => { + const provider = (await createHubProviderAsync({ + hubId: 'test-hub', + registry, + providerSettingsMap: new Map([['openai', {}]]) + })) as ProviderV3 const result = provider.imageModel(`openai${DEFAULT_SEPARATOR}dalle`) @@ -523,119 +466,4 @@ describe('HubProvider', () => { expect(result).toHaveProperty('doGenerate') }) }) - - describe('Dependency Injection', () => { - it('should use global registry by default', () => { - vi.mocked(globalProviderInstanceRegistry.getProvider).mockReturnValue(mockOpenAIProvider) - - const hubProvider = createHubProvider({ hubId: 'test-hub' }) - const provider = hubProvider as ProviderV3 - - provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) - - // Should call global registry - expect(globalProviderInstanceRegistry.getProvider).toHaveBeenCalledWith('openai') - }) - - it('should use custom registry when provided', () => { - const customRegistry = { - getProvider: vi.fn().mockReturnValue(mockOpenAIProvider) - } - - const hubProvider = createHubProvider({ - hubId: 'test-hub', - providerRegistry: customRegistry as any - }) - const provider = hubProvider as ProviderV3 - - provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) - - // Should call custom registry, not global - expect(customRegistry.getProvider).toHaveBeenCalledWith('openai') - expect(globalProviderInstanceRegistry.getProvider).not.toHaveBeenCalled() - }) - - it('should allow testing with mock registry', () => { - const mockRegistry = { - getProvider: vi.fn((id: string) => { - if (id === 'test-provider') { - return mockOpenAIProvider - } - return undefined - }) - } - - const hubProvider = createHubProvider({ - hubId: 'test-hub', - providerRegistry: mockRegistry as any - }) - const provider = hubProvider as ProviderV3 - - // Should work with mock registry - const model = provider.languageModel(`test-provider${DEFAULT_SEPARATOR}model`) - expect(mockRegistry.getProvider).toHaveBeenCalledWith('test-provider') - expect(model).toBeDefined() - }) - - it('should throw error when provider not found in custom registry', () => { - const emptyRegistry = { - getProvider: vi.fn().mockReturnValue(undefined) - } - - const hubProvider = createHubProvider({ - hubId: 'test-hub', - providerRegistry: emptyRegistry as any - }) - const provider = hubProvider as ProviderV3 - - expect(() => { - provider.languageModel(`unknown${DEFAULT_SEPARATOR}model`) - }).toThrow(HubProviderError) - - expect(emptyRegistry.getProvider).toHaveBeenCalledWith('unknown') - }) - - it('should support multiple hub instances with different registries', () => { - const registry1 = { - getProvider: vi.fn().mockReturnValue(mockOpenAIProvider) - } - - const registry2 = { - getProvider: vi.fn().mockReturnValue(mockAnthropicProvider) - } - - const hub1 = createHubProvider({ - hubId: 'hub-1', - providerRegistry: registry1 as any - }) as ProviderV3 - - const hub2 = createHubProvider({ - hubId: 'hub-2', - providerRegistry: registry2 as any - }) as ProviderV3 - - // Each hub should use its own registry - hub1.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) - hub2.languageModel(`anthropic${DEFAULT_SEPARATOR}claude`) - - expect(registry1.getProvider).toHaveBeenCalledWith('openai') - expect(registry2.getProvider).toHaveBeenCalledWith('anthropic') - - // Registries should be independent - expect(registry1.getProvider).not.toHaveBeenCalledWith('anthropic') - expect(registry2.getProvider).not.toHaveBeenCalledWith('openai') - }) - - it('should make hubId optional and default to "hub"', () => { - vi.mocked(globalProviderInstanceRegistry.getProvider).mockReturnValue(undefined) - - const hubProvider = createHubProvider() // No config - const provider = hubProvider as ProviderV3 - - // Should use default hubId 'hub' in error messages - expect(() => { - provider.languageModel(`unknown${DEFAULT_SEPARATOR}model`) - }).toThrow(HubProviderError) - }) - }) }) diff --git a/packages/aiCore/src/core/providers/__tests__/ProviderExtension.test.ts b/packages/aiCore/src/core/providers/__tests__/ProviderExtension.test.ts index ca4a9927a3..a939cdf90b 100644 --- a/packages/aiCore/src/core/providers/__tests__/ProviderExtension.test.ts +++ b/packages/aiCore/src/core/providers/__tests__/ProviderExtension.test.ts @@ -5,7 +5,7 @@ import type { ProviderV3 } from '@ai-sdk/provider' import { describe, expect, it, vi } from 'vitest' -import { createMockProviderV3 } from '../../../__tests__' +import { createMockProviderV3 } from '@test-utils' import { createProviderExtension, ProviderExtension, @@ -85,7 +85,7 @@ describe('ProviderExtension', () => { expect(extension.config.defaultOptions).toEqual({ apiKey: 'initial-key' }) }) - it('should validate config from function same as from object', () => { + it('should validate config from function same as from object', async () => { expect(() => { ProviderExtension.create(() => ({ name: '', // Invalid @@ -93,15 +93,16 @@ describe('ProviderExtension', () => { })) }).toThrow('name is required') - expect(() => { - ProviderExtension.create( - () => - ({ - name: 'test-provider' - // Missing create - }) as any - ) - }).toThrow('either create or import must be provided') + // Note: create/import validation happens at runtime in createProvider(), not in constructor + // Extension can be created without create/import, but createProvider() will throw + const extension = ProviderExtension.create( + () => + ({ + name: 'test-provider' + // Missing create + }) as any + ) + await expect(extension.createProvider()).rejects.toThrow('cannot create provider') }) }) @@ -115,21 +116,23 @@ describe('ProviderExtension', () => { }).toThrow('name is required') }) - it('should throw error if neither create nor import is provided', () => { - expect(() => { - new ProviderExtension({ - name: 'test-provider' - } as any) - }).toThrow('either create or import must be provided') + it('should throw error at runtime if neither create nor import is provided', async () => { + // Constructor doesn't validate create/import - validation happens at runtime + const extension = new ProviderExtension({ + name: 'test-provider' + } as any) + + await expect(extension.createProvider()).rejects.toThrow('cannot create provider') }) - it('should throw error if import is provided without creatorFunctionName', () => { - expect(() => { - new ProviderExtension({ - name: 'test-provider', - import: async () => ({}) - } as any) - }).toThrow('creatorFunctionName is required when using import') + it('should throw error at runtime if import is provided without creatorFunctionName', async () => { + // Constructor doesn't validate creatorFunctionName - validation happens at runtime + const extension = new ProviderExtension({ + name: 'test-provider', + import: async () => ({}) + } as any) + + await expect(extension.createProvider()).rejects.toThrow('cannot create provider') }) it('should create extension with valid config', () => { @@ -808,16 +811,26 @@ describe('ProviderExtension', () => { expect(onAfterCreate).toHaveBeenCalledTimes(1) }) - it('should support explicit ID parameter', async () => { + it('should support variant suffix parameter', async () => { const extension = new ProviderExtension({ name: 'test-provider', - create: createMockProviderV3 as any + create: createMockProviderV3 as any, + variants: [ + { + suffix: 'chat', + name: 'Test Chat', + transform: (provider) => provider + } + ] }) const settings = { apiKey: 'test-key' } - // Should not throw when providing explicit ID - await expect(extension.createProvider(settings, 'custom-id')).resolves.toBeDefined() + // Should work when providing a valid variant suffix + await expect(extension.createProvider(settings, 'chat')).resolves.toBeDefined() + + // Should throw for unknown variant suffix + await expect(extension.createProvider(settings, 'unknown')).rejects.toThrow('variant "unknown" not found') }) it('should support dynamic import providers', async () => { diff --git a/packages/aiCore/src/core/providers/__tests__/extensions.integration.test.ts b/packages/aiCore/src/core/providers/__tests__/extensions.integration.test.ts deleted file mode 100644 index 64d16239f3..0000000000 --- a/packages/aiCore/src/core/providers/__tests__/extensions.integration.test.ts +++ /dev/null @@ -1,445 +0,0 @@ -/** - * Provider Extensions Integration Tests - * 测试真实 extensions 的完整功能 - */ - -import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' - -import { extensionRegistry } from '../core/ExtensionRegistry' -import { AnthropicExtension } from '../extensions/anthropic' -import { AzureExtension } from '../extensions/azure' -import { OpenAIExtension } from '../extensions/openai' - -// Mock fetch for health checks -global.fetch = vi.fn() - -describe('Provider Extensions Integration', () => { - beforeEach(() => { - // Clear registry before each test - extensionRegistry.clear() - extensionRegistry.clearCache() - vi.clearAllMocks() - }) - - afterEach(() => { - extensionRegistry.clear() - extensionRegistry.clearCache() - }) - - describe('OpenAI Extension', () => { - it('should register and create provider successfully', async () => { - // Register extension - extensionRegistry.register(OpenAIExtension) - - // Verify registration - expect(extensionRegistry.has('openai')).toBe(true) - expect(extensionRegistry.has('oai')).toBe(true) // alias - - // Create provider - const provider = await extensionRegistry.createProvider('openai', { - apiKey: 'sk-test-key-123', - baseURL: 'https://api.openai.com/v1' - }) - - expect(provider).toBeDefined() - }) - - it('should execute onBeforeCreate hook for validation', async () => { - extensionRegistry.register(OpenAIExtension) - - // Invalid API key (doesn't start with "sk-") - await expect( - extensionRegistry.createProvider('openai', { - apiKey: 'invalid-key' - }) - ).rejects.toThrow('Invalid OpenAI API key format') - - // Missing API key - await expect(extensionRegistry.createProvider('openai', {})).rejects.toThrow('OpenAI API key is required') - }) - - it('should execute onAfterCreate hook for caching', async () => { - extensionRegistry.register(OpenAIExtension) - - const settings = { - apiKey: 'sk-test-key-123', - baseURL: 'https://api.openai.com/v1' - } - - // Create provider - const provider = await extensionRegistry.createProvider('openai', settings) - - // Check extension's internal storage (custom cache) - const ext = extensionRegistry.get('openai') - const cache = ext?.storage.get('providerCache') - expect(cache).toBeDefined() - expect(cache?.has('sk-test-key-123')).toBe(true) - expect(cache?.get('sk-test-key-123')).toBe(provider) - }) - - it('should cache providers based on settings', async () => { - extensionRegistry.register(OpenAIExtension) - - const settings = { - apiKey: 'sk-test-key-123', - baseURL: 'https://api.openai.com/v1' - } - - // First call - creates provider - const provider1 = await extensionRegistry.createProvider('openai', settings) - - // Second call with same settings - returns cached - const provider2 = await extensionRegistry.createProvider('openai', settings) - - expect(provider1).toBe(provider2) // Same instance - - // Different settings - creates new provider - const provider3 = await extensionRegistry.createProvider('openai', { - apiKey: 'sk-different-key-456', - baseURL: 'https://api.openai.com/v1' - }) - - expect(provider3).not.toBe(provider1) // Different instance - }) - - it('should support openai-chat variant', async () => { - extensionRegistry.register(OpenAIExtension) - - // Verify variant ID exists - const providerIds = OpenAIExtension.getProviderIds() - expect(providerIds).toContain('openai') - expect(providerIds).toContain('openai-chat') - - // Create variant provider - await extensionRegistry.createAndRegisterProvider('openai', { - apiKey: 'sk-test-key-123' - }) - - // Both base and variant should be available - const stats = extensionRegistry.getStats() - expect(stats.totalExtensions).toBe(1) - expect(stats.extensionsWithVariants).toBe(1) - }) - - it('should skip cache when requested', async () => { - extensionRegistry.register(OpenAIExtension) - - const settings = { - apiKey: 'sk-test-key-123' - } - - // First creation - const provider1 = await extensionRegistry.createProvider('openai', settings) - - // Skip cache - creates new instance - const provider2 = await extensionRegistry.createProvider('openai', settings, { - skipCache: true - }) - - expect(provider2).not.toBe(provider1) // Different instances - }) - - it('should track health status in storage', async () => { - extensionRegistry.register(OpenAIExtension) - - await extensionRegistry.createProvider('openai', { - apiKey: 'sk-test-key-123' - }) - - const ext = extensionRegistry.get('openai') - const health = ext?.storage.get('healthStatus') - - expect(health).toBeDefined() - expect(health?.isHealthy).toBe(true) - expect(health?.consecutiveFailures).toBe(0) - expect(health?.lastCheckTime).toBeGreaterThan(0) - }) - }) - - describe('Anthropic Extension', () => { - it('should validate Anthropic API key format', async () => { - extensionRegistry.register(AnthropicExtension) - - // Invalid format (doesn't start with "sk-ant-") - await expect( - extensionRegistry.createProvider('anthropic', { - apiKey: 'sk-test-key' - }) - ).rejects.toThrow('Invalid Anthropic API key format') - - // Missing API key - await expect(extensionRegistry.createProvider('anthropic', {})).rejects.toThrow('Anthropic API key is required') - - // Valid format - const provider = await extensionRegistry.createProvider('anthropic', { - apiKey: 'sk-ant-test-key-123' - }) - - expect(provider).toBeDefined() - }) - - it('should validate baseURL format', async () => { - extensionRegistry.register(AnthropicExtension) - - // Invalid baseURL (no http/https) - await expect( - extensionRegistry.createProvider('anthropic', { - apiKey: 'sk-ant-test-key', - baseURL: 'api.anthropic.com' // Missing protocol - }) - ).rejects.toThrow('Invalid baseURL format') - - // Valid baseURL - const provider = await extensionRegistry.createProvider('anthropic', { - apiKey: 'sk-ant-test-key', - baseURL: 'https://api.anthropic.com' - }) - - expect(provider).toBeDefined() - }) - - it('should track creation statistics', async () => { - extensionRegistry.register(AnthropicExtension) - - // First successful creation - await extensionRegistry.createProvider('anthropic', { - apiKey: 'sk-ant-test-key-1' - }) - - const ext = extensionRegistry.get('anthropic') - let stats = ext?.storage.get('stats') - expect(stats?.totalCreations).toBe(1) - expect(stats?.failedCreations).toBe(0) - - // Failed creation - try { - await extensionRegistry.createProvider('anthropic', { - apiKey: 'invalid-key' - }) - } catch { - // Expected error - } - - stats = ext?.storage.get('stats') - expect(stats?.totalCreations).toBe(2) - expect(stats?.failedCreations).toBe(1) - - // Second successful creation - await extensionRegistry.createProvider('anthropic', { - apiKey: 'sk-ant-test-key-2' - }) - - stats = ext?.storage.get('stats') - expect(stats?.totalCreations).toBe(3) - expect(stats?.failedCreations).toBe(1) - }) - - it('should record lastSuccessfulCreation timestamp', async () => { - extensionRegistry.register(AnthropicExtension) - - const before = Date.now() - - await extensionRegistry.createProvider('anthropic', { - apiKey: 'sk-ant-test-key' - }) - - const after = Date.now() - - const ext = extensionRegistry.get('anthropic') - const timestamp = ext?.storage.get('lastSuccessfulCreation') - - expect(timestamp).toBeDefined() - expect(timestamp).toBeGreaterThanOrEqual(before) - expect(timestamp).toBeLessThanOrEqual(after) - }) - - it('should support claude alias', async () => { - extensionRegistry.register(AnthropicExtension) - - // Access via alias - expect(extensionRegistry.has('claude')).toBe(true) - - const provider = await extensionRegistry.createProvider('claude', { - apiKey: 'sk-ant-test-key' - }) - - expect(provider).toBeDefined() - }) - }) - - describe('Azure Extension', () => { - it('should validate Azure configuration', async () => { - extensionRegistry.register(AzureExtension) - - // Missing both resourceName and baseURL - await expect( - extensionRegistry.createProvider('azure', { - apiKey: 'test-key' - }) - ).rejects.toThrow('Azure OpenAI requires either resourceName or baseURL') - - // Missing API key - await expect( - extensionRegistry.createProvider('azure', { - resourceName: 'my-resource' - }) - ).rejects.toThrow('Azure OpenAI API key is required') - }) - - it('should validate resourceName format', async () => { - extensionRegistry.register(AzureExtension) - - // Invalid format (uppercase) - await expect( - extensionRegistry.createProvider('azure', { - resourceName: 'MyResource', - apiKey: 'test-key' - }) - ).rejects.toThrow('Invalid Azure resource name format') - - // Invalid format (special chars) - await expect( - extensionRegistry.createProvider('azure', { - resourceName: 'my_resource', - apiKey: 'test-key' - }) - ).rejects.toThrow('Invalid Azure resource name format') - - // Valid format - const provider = await extensionRegistry.createProvider('azure', { - resourceName: 'my-resource-123', - apiKey: 'test-key' - }) - - expect(provider).toBeDefined() - }) - - it('should cache resource endpoints', async () => { - extensionRegistry.register(AzureExtension) - - await extensionRegistry.createProvider('azure', { - resourceName: 'my-resource', - apiKey: 'test-key' - }) - - const ext = extensionRegistry.get('azure') - const endpoints = ext?.storage.get('resourceEndpoints') - - expect(endpoints).toBeDefined() - expect(endpoints?.has('my-resource')).toBe(true) - expect(endpoints?.get('my-resource')).toBe('https://my-resource.openai.azure.com') - }) - - it('should track validated deployments', async () => { - extensionRegistry.register(AzureExtension) - - // First deployment - await extensionRegistry.createProvider('azure', { - resourceName: 'resource-1', - apiKey: 'test-key-1' - }) - - const ext = extensionRegistry.get('azure') - let deployments = ext?.storage.get('validatedDeployments') - expect(deployments?.size).toBe(1) - expect(deployments?.has('resource-1')).toBe(true) - - // Second deployment - await extensionRegistry.createProvider('azure', { - resourceName: 'resource-2', - apiKey: 'test-key-2' - }) - - deployments = ext?.storage.get('validatedDeployments') - expect(deployments?.size).toBe(2) - expect(deployments?.has('resource-2')).toBe(true) - }) - - it('should support azure-responses variant', async () => { - extensionRegistry.register(AzureExtension) - - const providerIds = AzureExtension.getProviderIds() - expect(providerIds).toContain('azure') - expect(providerIds).toContain('azure-responses') - }) - - it('should support azure-openai alias', async () => { - extensionRegistry.register(AzureExtension) - - expect(extensionRegistry.has('azure-openai')).toBe(true) - - const provider = await extensionRegistry.createProvider('azure-openai', { - resourceName: 'my-resource', - apiKey: 'test-key' - }) - - expect(provider).toBeDefined() - }) - }) - - describe('Multiple Extensions', () => { - it('should register multiple extensions simultaneously', () => { - extensionRegistry.registerAll([OpenAIExtension, AnthropicExtension, AzureExtension]) - - const stats = extensionRegistry.getStats() - expect(stats.totalExtensions).toBe(3) - expect(stats.extensionsWithVariants).toBe(2) // OpenAI and Azure - }) - - it('should maintain separate storage for each extension', async () => { - extensionRegistry.registerAll([OpenAIExtension, AnthropicExtension]) - - // Create providers - await extensionRegistry.createProvider('openai', { - apiKey: 'sk-test-key' - }) - - await extensionRegistry.createProvider('anthropic', { - apiKey: 'sk-ant-test-key' - }) - - // Check OpenAI storage - const openaiExt = extensionRegistry.get('openai') - const openaiCache = openaiExt?.storage.get('providerCache') - expect(openaiCache?.size).toBe(1) - - // Check Anthropic storage - const anthropicExt = extensionRegistry.get('anthropic') - const anthropicStats = anthropicExt?.storage.get('stats') - expect(anthropicStats?.totalCreations).toBe(1) - - // Storages are independent - expect(openaiExt?.storage.get('stats')).toBeUndefined() - expect(anthropicExt?.storage.get('providerCache')).toBeUndefined() - }) - - it('should clear cache per extension', async () => { - extensionRegistry.registerAll([OpenAIExtension, AnthropicExtension]) - - // Create providers - await extensionRegistry.createProvider('openai', { - apiKey: 'sk-test-key' - }) - - await extensionRegistry.createProvider('anthropic', { - apiKey: 'sk-ant-test-key' - }) - - // Verify both are cached - const stats1 = extensionRegistry.getStats() - expect(stats1.cachedProviders).toBe(2) - - // Clear only OpenAI cache - extensionRegistry.clearCache('openai') - - const stats2 = extensionRegistry.getStats() - expect(stats2.cachedProviders).toBe(1) // Only Anthropic remains - - // Clear all caches - extensionRegistry.clearCache() - - const stats3 = extensionRegistry.getStats() - expect(stats3.cachedProviders).toBe(0) - }) - }) -}) diff --git a/packages/aiCore/src/core/providers/__tests__/initialization.test.ts b/packages/aiCore/src/core/providers/__tests__/initialization.test.ts deleted file mode 100644 index f27172aca1..0000000000 --- a/packages/aiCore/src/core/providers/__tests__/initialization.test.ts +++ /dev/null @@ -1,165 +0,0 @@ -import type { ProviderV3 } from '@ai-sdk/provider' -import { afterEach, beforeEach, describe, expect, it } from 'vitest' - -import { ExtensionRegistry } from '../core/ExtensionRegistry' -import { isRegisteredProvider } from '../core/initialization' -import { ProviderExtension } from '../core/ProviderExtension' -import { ProviderInstanceRegistry } from '../core/ProviderInstanceRegistry' - -// Mock provider for testing -const createMockProviderV3 = (): ProviderV3 => ({ - specificationVersion: 'v3' as const, - languageModel: () => ({}) as any, - embeddingModel: () => ({}) as any, - imageModel: () => ({}) as any -}) - -describe('initialization utilities', () => { - let testExtensionRegistry: ExtensionRegistry - let testInstanceRegistry: ProviderInstanceRegistry - - beforeEach(() => { - testExtensionRegistry = new ExtensionRegistry() - testInstanceRegistry = new ProviderInstanceRegistry() - }) - - afterEach(() => { - // Clean up registries - testExtensionRegistry = null as any - testInstanceRegistry = null as any - }) - - describe('isRegisteredProvider()', () => { - it('should return true for providers registered in Extension Registry', () => { - testExtensionRegistry.register( - new ProviderExtension({ - name: 'test-provider', - create: createMockProviderV3 - }) - ) - - // Note: isRegisteredProvider uses global registries, so this tests the concept - // In practice, we'd need to modify the function to accept registries as parameters - // For now, this documents the expected behavior - expect(typeof isRegisteredProvider).toBe('function') - }) - - it('should return true for providers registered in Provider Instance Registry', () => { - const mockProvider = createMockProviderV3() - testInstanceRegistry.registerProvider('test-provider', mockProvider) - - // Note: This tests the concept - actual implementation uses global registries - expect(testInstanceRegistry.getProvider('test-provider')).toBeDefined() - }) - - it('should return false for unregistered providers', () => { - // Both registries are empty - const result = isRegisteredProvider('unknown-provider') - - // Note: This will check global registries - expect(typeof result).toBe('boolean') - }) - - it('should work with provider aliases', () => { - testExtensionRegistry.register( - new ProviderExtension({ - name: 'openai', - aliases: ['oai'], - create: createMockProviderV3 - }) - ) - - // Should be able to check both main ID and alias - expect(testExtensionRegistry.has('openai')).toBe(true) - expect(testExtensionRegistry.has('oai')).toBe(true) - }) - - it('should work with variant IDs', () => { - testExtensionRegistry.register( - new ProviderExtension({ - name: 'openai', - create: createMockProviderV3, - variants: [ - { - suffix: 'chat', - name: 'OpenAI Chat', - transform: (provider) => provider - } - ] - }) - ) - - // Base provider should be registered - expect(testExtensionRegistry.has('openai')).toBe(true) - - // Variant ID can be checked with isVariant method - expect(testExtensionRegistry.isVariant('openai-chat')).toBe(true) - - // Base provider ID should be resolvable from variant - expect(testExtensionRegistry.getBaseProviderId('openai-chat')).toBe('openai') - }) - - it('should return true if provider is in either registry', () => { - // Register in extension registry only - testExtensionRegistry.register( - new ProviderExtension({ - name: 'ext-only', - create: createMockProviderV3 - }) - ) - - // Register in instance registry only - const mockProvider = createMockProviderV3() - testInstanceRegistry.registerProvider('instance-only', mockProvider) - - // Both should be considered registered - expect(testExtensionRegistry.has('ext-only')).toBe(true) - expect(testInstanceRegistry.getProvider('instance-only')).toBeDefined() - }) - - it('should handle empty string gracefully', () => { - const result = isRegisteredProvider('') - expect(typeof result).toBe('boolean') - }) - - it('should be case-sensitive', () => { - testExtensionRegistry.register( - new ProviderExtension({ - name: 'openai', - create: createMockProviderV3 - }) - ) - - expect(testExtensionRegistry.has('openai')).toBe(true) - expect(testExtensionRegistry.has('OpenAI')).toBe(false) - expect(testExtensionRegistry.has('OPENAI')).toBe(false) - }) - }) - - describe('Integration: isRegisteredProvider with actual registries', () => { - it('should correctly identify providers across both registries', () => { - // This test documents the expected behavior when both registries are involved - // isRegisteredProvider checks: extensionRegistry.has(id) || instanceRegistry.getProvider(id) !== undefined - - testExtensionRegistry.register( - new ProviderExtension({ - name: 'registered-ext', - create: createMockProviderV3 - }) - ) - - const mockProvider = createMockProviderV3() - testInstanceRegistry.registerProvider('registered-instance', mockProvider) - - // Extension registry check - expect(testExtensionRegistry.has('registered-ext')).toBe(true) - - // Instance registry check - expect(testInstanceRegistry.getProvider('registered-instance')).toBeDefined() - - // Unregistered provider - expect(testExtensionRegistry.has('unregistered')).toBe(false) - expect(testInstanceRegistry.getProvider('unregistered')).toBeUndefined() - }) - }) -}) diff --git a/packages/aiCore/src/core/providers/core/ExtensionRegistry.ts b/packages/aiCore/src/core/providers/core/ExtensionRegistry.ts index 18ba1e7846..743e74e83f 100644 --- a/packages/aiCore/src/core/providers/core/ExtensionRegistry.ts +++ b/packages/aiCore/src/core/providers/core/ExtensionRegistry.ts @@ -5,7 +5,7 @@ import type { ProviderV3 } from '@ai-sdk/provider' -import type { RegisteredProviderId } from '../index' +import type { CoreProviderSettingsMap, RegisteredProviderId } from '../index' import { type ProviderExtension } from './ProviderExtension' import { ProviderCreationError } from './utils' @@ -52,15 +52,12 @@ export class ExtensionRegistry { register(extension: ProviderExtension): this { const { name, aliases, variants } = extension.config - // 检查主 ID 冲突 if (this.extensions.has(name)) { throw new Error(`Provider extension "${name}" is already registered`) } - // 注册主 Extension this.extensions.set(name, extension) - // 注册别名 if (aliases) { for (const alias of aliases) { if (this.aliasMap.has(alias)) { @@ -70,7 +67,6 @@ export class ExtensionRegistry { } } - // 注册变体 ID if (variants) { for (const variant of variants) { const variantId = `${name}-${variant.suffix}` @@ -106,10 +102,8 @@ export class ExtensionRegistry { return false } - // 删除主 Extension this.extensions.delete(name) - // 删除别名 if (extension.config.aliases) { for (const alias of extension.config.aliases) { this.aliasMap.delete(alias) @@ -123,12 +117,10 @@ export class ExtensionRegistry { * 获取 Extension(支持别名) */ get(id: string): ProviderExtension | undefined { - // 直接查找 if (this.extensions.has(id)) { return this.extensions.get(id) } - // 通过别名查找 const realName = this.aliasMap.get(id) if (realName) { return this.extensions.get(realName) @@ -250,17 +242,7 @@ export class ExtensionRegistry { * ``` */ parseProviderId(providerId: string): { baseId: RegisteredProviderId; mode?: string; isVariant: boolean } | null { - // 先检查是否是已注册的 extension(直接或通过别名) - const extension = this.get(providerId) - if (extension) { - // 是基础 ID 或别名,不是变体 - return { - baseId: extension.config.name as RegisteredProviderId, - isVariant: false - } - } - - // 遍历所有 extensions,查找匹配的变体 + // 先遍历所有 extensions,查找匹配的变体(优先于别名检查) for (const ext of this.extensions.values()) { if (!ext.config.variants) { continue @@ -279,6 +261,16 @@ export class ExtensionRegistry { } } + // 再检查是否是已注册的 extension(直接或通过别名) + const extension = this.get(providerId) + if (extension) { + // 是基础 ID 或别名,不是变体 + return { + baseId: extension.config.name as RegisteredProviderId, + isVariant: false + } + } + // 无法解析 return null } @@ -379,15 +371,21 @@ export class ExtensionRegistry { /** * 创建 provider 实例 - * 委托给 ProviderExtension 处理(包括缓存、生命周期钩子等) * - * @param id - Provider ID(支持别名和变体) - * @param settings - Provider 配置选项 - * @param explicitId - 可选的显式ID,用于AI SDK注册 + * 支持两种调用方式: + * 1. 类型安全版本 - 使用已注册的 provider ID,获得完整的类型推导 + * 2. 动态版本 - 使用任意字符串 ID,用于测试或动态注册的 provider + * + * @param id - Provider ID + * @param settings - Provider 配置 * @returns Provider 实例 */ - async createProvider(id: string, settings?: any, explicitId?: string): Promise { - // 解析 provider ID,提取基础 ID 和变体后缀 + async createProvider( + id: T, + settings: CoreProviderSettingsMap[T] + ): Promise + async createProvider(id: string, settings?: unknown): Promise + async createProvider(id: string, settings?: unknown): Promise { const parsed = this.parseProviderId(id) if (!parsed) { throw new Error(`Provider extension "${id}" not found. Did you forget to register it?`) @@ -395,16 +393,13 @@ export class ExtensionRegistry { const { baseId, mode: variantSuffix } = parsed - // 获取基础 extension const extension = this.get(baseId) if (!extension) { throw new Error(`Provider extension "${baseId}" not found. Did you forget to register it?`) } try { - // 委托给 Extension 的 createProvider 方法 - // Extension 负责缓存、生命周期钩子、AI SDK 注册、变体转换等 - return await extension.createProvider(settings, explicitId, variantSuffix) + return await extension.createProvider(settings, variantSuffix) } catch (error) { throw new ProviderCreationError( `Failed to create provider "${id}"`, diff --git a/packages/aiCore/src/core/providers/core/ProviderExtension.ts b/packages/aiCore/src/core/providers/core/ProviderExtension.ts index 40f8c03f69..e471a5ad73 100644 --- a/packages/aiCore/src/core/providers/core/ProviderExtension.ts +++ b/packages/aiCore/src/core/providers/core/ProviderExtension.ts @@ -1,19 +1,9 @@ import type { ProviderV3 } from '@ai-sdk/provider' +import { LRUCache } from 'lru-cache' import { deepMergeObjects } from '../../utils' import type { ExtensionContext, ExtensionStorage, LifecycleHooks, ProviderVariant, StorageAccessor } from '../types' -/** - * 全局 Provider 存储 - * Extension 创建的 provider 实例注册到这里,供 HubProvider 等使用 - * Key: explicit ID (用户指定的唯一标识) - * Value: Provider 实例 - */ -export const globalProviderStorage = new Map() - -/** - * Provider 创建函数类型 - */ export type ProviderCreatorFunction = (settings?: TSettings) => ProviderV3 | Promise /** @@ -80,19 +70,10 @@ interface ProviderExtensionConfigWithCreate< TProvider extends ProviderV3 = ProviderV3, TName extends string = string > extends ProviderExtensionConfigBase { - /** - * 创建 provider 实例的函数 - */ create: ProviderCreatorFunction - /** - * 禁止使用 import(与 create 互斥) - */ import?: never - /** - * 禁止使用 creatorFunctionName(与 create 互斥) - */ creatorFunctionName?: never } @@ -107,21 +88,10 @@ interface ProviderExtensionConfigWithImport< TProvider extends ProviderV3 = ProviderV3, TName extends string = string > extends ProviderExtensionConfigBase { - /** - * 禁止使用 create(与 import 互斥) - */ create?: never - /** - * 动态导入模块的函数 - * 用于延迟加载第三方 provider - */ import: () => Promise> - /** - * 导入模块后的 creator 函数名 - * 必须与 import 一起使用 - */ creatorFunctionName: string } @@ -196,20 +166,23 @@ export class ProviderExtension< > { private _storage: Map - /** Provider 实例缓存 - 按 settings hash 存储 */ - private instances: Map = new Map() + /** Provider 实例缓存 - 按 settings hash 存储,LRU 自动清理 */ + private instances: LRUCache /** Settings hash 映射表 - 用于验证缓存是否仍然有效 */ private settingsHashes: Map = new Map() constructor(public readonly config: TConfig) { - // 验证配置 if (!config.name) { throw new Error('ProviderExtension: name is required') } - // 初始化 storage this._storage = new Map(Object.entries(config.initialStorage || {})) + + this.instances = new LRUCache({ + max: 10, + updateAgeOnGet: true + }) } /** @@ -370,27 +343,14 @@ export class ProviderExtension< } /** - * 注册 Provider 到全局注册表 - * Extension 拥有的 provider 实例会被注册到全局 Map,供 HubProvider 等使用 - * @private - */ - private registerToAiSdk(provider: TProvider, explicitId: string): void { - // 注册到全局 provider storage - // 使用 explicit ID 作为 key - globalProviderStorage.set(explicitId, provider as any) - } - - /** - * 创建 Provider 实例(带缓存) + * 创建 Provider 实例 * 相同 settings 会复用实例,不同 settings 会创建新实例 * * @param settings - Provider 配置 - * @param explicitId - 可选的显式 ID,用于 AI SDK 注册 * @param variantSuffix - 可选的变体后缀,用于应用变体转换 * @returns Provider 实例 */ - async createProvider(settings?: TSettings, explicitId?: string, variantSuffix?: string): Promise { - // 验证变体后缀(如果提供) + async createProvider(settings?: TSettings, variantSuffix?: string): Promise { if (variantSuffix) { const variant = this.getVariant(variantSuffix) if (!variant) { @@ -402,31 +362,22 @@ export class ProviderExtension< } // 合并 default options - const mergedSettings = deepMergeObjects( - (this.config.defaultOptions || {}) as any, - (settings || {}) as any - ) as TSettings + const mergedSettings = deepMergeObjects(this.config.defaultOptions || {}, settings || {}) as TSettings - // 计算 hash(包含变体后缀) const hash = this.computeHash(mergedSettings, variantSuffix) - // 检查缓存 const cachedInstance = this.instances.get(hash) if (cachedInstance) { return cachedInstance } - // 执行 onBeforeCreate 钩子 await this.executeHook('onBeforeCreate', mergedSettings) - // 创建基础 provider 实例 let baseProvider: ProviderV3 if (this.config.create) { - // 使用直接创建函数 baseProvider = await Promise.resolve(this.config.create(mergedSettings)) } else if (this.config.import && this.config.creatorFunctionName) { - // 动态导入 const module = await this.config.import() const creatorFn = module[this.config.creatorFunctionName] @@ -441,39 +392,19 @@ export class ProviderExtension< throw new Error(`ProviderExtension "${this.config.name}": cannot create provider, invalid configuration`) } - // 应用变体转换(如果提供了变体后缀) let finalProvider: TProvider if (variantSuffix) { const variant = this.getVariant(variantSuffix)! - // 应用变体的 transform 函数 finalProvider = (await Promise.resolve(variant.transform(baseProvider as TProvider, mergedSettings))) as TProvider } else { finalProvider = baseProvider as TProvider } - // 执行 onAfterCreate 钩子 await this.executeHook('onAfterCreate', mergedSettings, finalProvider) - // 缓存实例 this.instances.set(hash, finalProvider) this.settingsHashes.set(hash, mergedSettings) - // 确定注册 ID - const registrationId = (() => { - if (explicitId) { - return explicitId - } - // 如果是变体,使用 name-suffix:hash 格式 - if (variantSuffix) { - return `${this.config.name}-${variantSuffix}:${hash}` - } - // 否则使用 name:hash - return `${this.config.name}:${hash}` - })() - - // 注册到 AI SDK - this.registerToAiSdk(finalProvider, registrationId) - return finalProvider } diff --git a/packages/aiCore/src/core/providers/core/initialization.ts b/packages/aiCore/src/core/providers/core/initialization.ts index fe180d78a9..c24493ca70 100644 --- a/packages/aiCore/src/core/providers/core/initialization.ts +++ b/packages/aiCore/src/core/providers/core/initialization.ts @@ -33,7 +33,7 @@ import type { } from '../types' import { extensionRegistry } from './ExtensionRegistry' import type { ProviderExtensionConfig } from './ProviderExtension' -import { globalProviderStorage, ProviderExtension } from './ProviderExtension' +import { ProviderExtension } from './ProviderExtension' // ==================== Core Extensions ==================== @@ -268,14 +268,6 @@ class ProviderInitializationError extends Error { } } -// ==================== 全局 Provider Storage 导出 ==================== - -/** - * 全局 Provider Storage - * Extension 创建的 provider 实例会注册到这里 - */ -export { globalProviderStorage } - // ==================== 工具函数 ==================== /** @@ -292,57 +284,6 @@ export function getSupportedProviders(): Array<{ })) } -/** - * 获取所有已初始化的 providers (explicit IDs) - */ -export function getInitializedProviders(): string[] { - return Array.from(globalProviderStorage.keys()) -} - -/** - * 检查是否有任何已初始化的 providers - */ -export function hasInitializedProviders(): boolean { - return globalProviderStorage.size > 0 -} - -/** - * 检查指定的 provider ID 是否已注册 - * 检查 Extension Registry (template) 或 Global Provider Storage (initialized instance) - * - * @param id - Provider ID to check (extension name or explicit ID) - * @returns true if the provider is registered (either as extension or initialized instance) - * - * @example - * ```typescript - * if (isRegisteredProvider('openai')) { - * // Provider extension exists - * } - * if (isRegisteredProvider('my-openai-instance')) { - * // Initialized provider instance exists - * } - * ``` - */ -export function isRegisteredProvider(id: string): boolean { - return extensionRegistry.has(id) || globalProviderStorage.has(id) -} - -/** - * 创建 Provider - 使用 Extension Registry - * - * @param providerId - Provider ID (extension name) - * @param options - Provider settings - * @param explicitId - 可选的显式 ID,用于注册到 globalProviderStorage。如果不提供,Extension 会使用 `name:hash` 作为默认 ID - * @returns Provider 实例 - */ -export async function createProvider(providerId: string, options: any, explicitId?: string): Promise { - if (!extensionRegistry.has(providerId)) { - throw new Error(`Provider "${providerId}" not found in Extension Registry`) - } - - return await extensionRegistry.createProvider(providerId, options, explicitId) -} - /** * 检查是否有对应的 Provider Extension */ @@ -350,13 +291,6 @@ export function hasProviderConfig(providerId: string): boolean { return extensionRegistry.has(providerId) } -/** - * 清除所有已注册的 provider 实例 - */ -export function clearAllProviders(): void { - globalProviderStorage.clear() -} - // ==================== 导出错误类型 ==================== export { ProviderInitializationError } diff --git a/packages/aiCore/src/core/providers/features/HubProvider.ts b/packages/aiCore/src/core/providers/features/HubProvider.ts index 825fd4d815..24f4021642 100644 --- a/packages/aiCore/src/core/providers/features/HubProvider.ts +++ b/packages/aiCore/src/core/providers/features/HubProvider.ts @@ -1,8 +1,8 @@ /** * Hub Provider - 支持路由到多个底层provider * - * 支持格式: hubId:providerId:modelId - * 例如: aihubmix:anthropic:claude-3.5-sonnet + * 支持格式: hubId|providerId|modelId + * @example aihubmix|anthropic|claude-3.5-sonnet */ import type { @@ -14,10 +14,10 @@ import type { SpeechModelV3, TranscriptionModelV3 } from '@ai-sdk/provider' -import { customProvider, wrapProvider } from 'ai' +import { customProvider } from 'ai' -import { globalProviderStorage } from '../core/ProviderExtension' -import type { AiSdkProvider } from '../types' +import type { ExtensionRegistry } from '../core/ExtensionRegistry' +import type { CoreProviderSettingsMap } from '../types' /** Model ID 分隔符 */ export const DEFAULT_SEPARATOR = '|' @@ -27,6 +27,10 @@ export interface HubProviderConfig { hubId?: string /** 是否启用调试日志 */ debug?: boolean + /** ExtensionRegistry实例(用于获取provider extensions) */ + registry: ExtensionRegistry + /** Provider配置映射 */ + providerSettingsMap: Map } export class HubProviderError extends Error { @@ -46,8 +50,11 @@ export class HubProviderError extends Error { */ function parseHubModelId(modelId: string): { provider: string; actualModelId: string } { const parts = modelId.split(DEFAULT_SEPARATOR) - if (parts.length !== 2) { - throw new HubProviderError(`Invalid hub model ID format. Expected "provider:modelId", got: ${modelId}`, 'unknown') + if (parts.length !== 2 || !parts[0] || !parts[1]) { + throw new HubProviderError( + `Invalid hub model ID format. Expected "provider${DEFAULT_SEPARATOR}modelId", got: ${modelId}`, + 'unknown' + ) } return { provider: parts[0], @@ -56,37 +63,72 @@ function parseHubModelId(modelId: string): { provider: string; actualModelId: st } /** - * 创建Hub Provider + * 异步创建Hub Provider + * + * 预创建所有provider实例以满足AI SDK的同步要求 + * 通过ExtensionRegistry复用ProviderExtension的LRU缓存 */ -export function createHubProvider(config?: HubProviderConfig): AiSdkProvider { - const hubId = config?.hubId ?? 'hub' +export async function createHubProviderAsync(config: HubProviderConfig): Promise { + const { registry, providerSettingsMap, debug, hubId = 'hub' } = config + + // 预创建所有 provider 实例 + const providers = new Map() + + for (const [providerId, settings] of providerSettingsMap.entries()) { + const extension = registry.get(providerId) + if (!extension) { + const availableExtensions = registry + .getAll() + .map((ext) => ext.config.name) + .join(', ') + throw new HubProviderError( + `Provider extension "${providerId}" not found in registry. Available: ${availableExtensions}`, + hubId, + providerId + ) + } - function getTargetProvider(providerId: string): ProviderV3 { - // 从全局 provider storage 获取已注册的provider实例 try { - const provider = globalProviderStorage.get(providerId) - if (!provider) { - throw new HubProviderError( - `Provider "${providerId}" is not registered. Please call extension.createProvider(settings, "${providerId}") first.`, - hubId, - providerId - ) - } - // 使用 wrapProvider 确保返回的是 V3 provider - // 这样可以自动处理 V2 provider 到 V3 的转换 - return wrapProvider({ provider, languageModelMiddleware: [] }) + // 通过 extension 创建 provider(复用 LRU 缓存) + const provider = await extension.createProvider(settings) + providers.set(providerId, provider) } catch (error) { throw new HubProviderError( - `Failed to get provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`, + `Failed to create provider "${providerId}": ${error instanceof Error ? error.message : String(error)}`, hubId, providerId, error instanceof Error ? error : undefined ) } } + return createHubProviderWithProviders(hubId, providers, debug) +} - // 创建符合 ProviderV3 规范的 fallback provider - const hubFallbackProvider = { +/** + * 内部函数:使用预创建的providers创建HubProvider + */ +function createHubProviderWithProviders( + hubId: string, + providers: Map, + debug?: boolean +): ProviderV3 { + function getTargetProvider(providerId: string): ProviderV3 { + const provider = providers.get(providerId) + if (!provider) { + const availableProviders = Array.from(providers.keys()).join(', ') + throw new HubProviderError( + `Provider "${providerId}" not initialized. Available: ${availableProviders}`, + hubId, + providerId + ) + } + if (debug) { + console.log(`[HubProvider:${hubId}] Routing to provider: ${providerId}`) + } + return provider + } + + const hubFallbackProvider: ProviderV3 = { specificationVersion: 'v3' as const, languageModel: (modelId: string): LanguageModelV3 => { @@ -128,6 +170,7 @@ export function createHubProvider(config?: HubProviderConfig): AiSdkProvider { return targetProvider.speechModel(actualModelId) }, + rerankingModel: (modelId: string): RerankingModelV3 => { const { provider, actualModelId } = parseHubModelId(modelId) const targetProvider = getTargetProvider(provider) diff --git a/packages/aiCore/src/core/providers/index.ts b/packages/aiCore/src/core/providers/index.ts index fb4ccc065e..4b720cebcf 100644 --- a/packages/aiCore/src/core/providers/index.ts +++ b/packages/aiCore/src/core/providers/index.ts @@ -3,18 +3,12 @@ */ // ==================== 核心管理器 ==================== -export { globalProviderStorage } from './core/ProviderExtension' // Provider 核心功能 export { - clearAllProviders, coreExtensions, - createProvider, - getInitializedProviders, getSupportedProviders, - hasInitializedProviders, hasProviderConfig, - isRegisteredProvider, ProviderInitializationError, registeredProviderIds } from './core/initialization' @@ -24,7 +18,7 @@ export { // 类型定义 export type { AiSdkModel, ProviderError } from './types' -// 类型提取工具(用于应用层 Merge Point 模式) +// 类型提取工具 export type { CoreProviderSettingsMap, ExtensionConfigToIdResolutionMap, @@ -43,7 +37,11 @@ export { formatPrivateKey, ProviderCreationError } from './core/utils' // ==================== 扩展功能 ==================== // Hub Provider 功能 -export { createHubProvider, type HubProviderConfig, HubProviderError } from './features/HubProvider' +export { + createHubProviderAsync, + type HubProviderConfig, + HubProviderError +} from './features/HubProvider' // ==================== Provider Extension 系统 ==================== diff --git a/packages/aiCore/src/core/runtime/__tests__/executor-resolveModel.test.ts b/packages/aiCore/src/core/runtime/__tests__/executor-resolveModel.test.ts deleted file mode 100644 index 74a80f01b3..0000000000 --- a/packages/aiCore/src/core/runtime/__tests__/executor-resolveModel.test.ts +++ /dev/null @@ -1,650 +0,0 @@ -/** - * RuntimeExecutor.resolveModel Comprehensive Tests - * Tests the private resolveModel and resolveImageModel methods through public APIs - * Covers model resolution, middleware application, and type validation - */ - -import type { ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider' -import { generateImage, generateText, streamText } from 'ai' -import { beforeEach, describe, expect, it, vi } from 'vitest' - -import { - createMockImageModel, - createMockLanguageModel, - createMockMiddleware, - mockProviderConfigs -} from '../../../__tests__' -import { globalModelResolver } from '../../models' -import { ImageModelResolutionError } from '../errors' -import { RuntimeExecutor } from '../executor' - -// Mock AI SDK -vi.mock('ai', async (importOriginal) => { - const actual = (await importOriginal()) as Record - return { - ...actual, - generateText: vi.fn(), - streamText: vi.fn(), - generateImage: vi.fn(), - wrapLanguageModel: vi.fn((config: any) => ({ - ...config.model, - _middlewareApplied: true, - middleware: config.middleware - })) - } -}) - -vi.mock('../../providers/core/ProviderInstanceRegistry', () => ({ - globalRegistryManagement: { - languageModel: vi.fn(), - imageModel: vi.fn() - }, - DEFAULT_SEPARATOR: '|' -})) - -vi.mock('../../models', () => ({ - globalModelResolver: { - resolveLanguageModel: vi.fn(), - resolveImageModel: vi.fn() - } -})) - -describe('RuntimeExecutor - Model Resolution', () => { - let executor: RuntimeExecutor<'openai'> - let mockLanguageModel: LanguageModelV3 - let mockImageModel: ImageModelV3 - - beforeEach(() => { - vi.clearAllMocks() - - executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai) - - mockLanguageModel = createMockLanguageModel({ - specificationVersion: 'v3', - provider: 'openai', - modelId: 'gpt-4' - }) - - mockImageModel = createMockImageModel({ - specificationVersion: 'v3', - provider: 'openai', - modelId: 'dall-e-3' - }) - - vi.mocked(globalModelResolver.resolveLanguageModel).mockResolvedValue(mockLanguageModel) - vi.mocked(globalModelResolver.resolveImageModel).mockResolvedValue(mockImageModel) - vi.mocked(generateText).mockResolvedValue({ - text: 'Test response', - finishReason: 'stop', - usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } - } as any) - vi.mocked(streamText).mockResolvedValue({ - textStream: (async function* () { - yield 'test' - })() - } as any) - vi.mocked(generateImage).mockResolvedValue({ - image: { - base64: 'test-image', - uint8Array: new Uint8Array([1, 2, 3]), - mimeType: 'image/png' - }, - warnings: [] - } as any) - }) - - describe('Language Model Resolution (String modelId)', () => { - it('should resolve string modelId using globalModelResolver', async () => { - await executor.generateText({ - model: 'gpt-4', - messages: [{ role: 'user', content: 'Hello' }] - }) - - expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( - 'gpt-4', - 'openai', - mockProviderConfigs.openai, - undefined - ) - }) - - it('should pass provider settings to model resolver', async () => { - const customExecutor = RuntimeExecutor.create('anthropic', { - apiKey: 'sk-test', - baseURL: 'https://api.anthropic.com' - }) - - vi.mocked(globalModelResolver.resolveLanguageModel).mockResolvedValue(mockLanguageModel) - - await customExecutor.generateText({ - model: 'claude-3-5-sonnet', - messages: [{ role: 'user', content: 'Test' }] - }) - - expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( - 'claude-3-5-sonnet', - 'anthropic', - { - apiKey: 'sk-test', - baseURL: 'https://api.anthropic.com' - }, - undefined - ) - }) - - it('should resolve traditional format modelId', async () => { - await executor.generateText({ - model: 'gpt-4-turbo', - messages: [{ role: 'user', content: 'Test' }] - }) - - expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( - 'gpt-4-turbo', - 'openai', - expect.any(Object), - undefined - ) - }) - - it('should resolve namespaced format modelId', async () => { - await executor.generateText({ - model: 'aihubmix|anthropic|claude-3', - messages: [{ role: 'user', content: 'Test' }] - }) - - expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( - 'aihubmix|anthropic|claude-3', - 'openai', - expect.any(Object), - undefined - ) - }) - - it('should use resolved model for generation', async () => { - await executor.generateText({ - model: 'gpt-4', - messages: [{ role: 'user', content: 'Hello' }] - }) - - expect(generateText).toHaveBeenCalledWith( - expect.objectContaining({ - model: mockLanguageModel - }) - ) - }) - - it('should work with streamText', async () => { - await executor.streamText({ - model: 'gpt-4', - messages: [{ role: 'user', content: 'Stream test' }] - }) - - expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalled() - expect(streamText).toHaveBeenCalledWith( - expect.objectContaining({ - model: mockLanguageModel - }) - ) - }) - }) - - describe('Language Model Resolution (Direct Model Object)', () => { - it('should accept pre-resolved V3 model object', async () => { - const directModel: LanguageModelV3 = createMockLanguageModel({ - specificationVersion: 'v3', - provider: 'openai', - modelId: 'gpt-4' - }) - - await executor.generateText({ - model: directModel, - messages: [{ role: 'user', content: 'Test' }] - }) - - // Should NOT call resolver for direct model - expect(globalModelResolver.resolveLanguageModel).not.toHaveBeenCalled() - - // Should use the model directly - expect(generateText).toHaveBeenCalledWith( - expect.objectContaining({ - model: directModel - }) - ) - }) - - it('should accept V2 model object without validation (plugin engine handles it)', async () => { - const v2Model = { - specificationVersion: 'v2', - provider: 'openai', - modelId: 'gpt-4', - doGenerate: vi.fn() - } as any - - // The plugin engine accepts any model object directly without validation - // V3 validation only happens when resolving string modelIds - await expect( - executor.generateText({ - model: v2Model, - messages: [{ role: 'user', content: 'Test' }] - }) - ).resolves.toBeDefined() - }) - - it('should accept any model object without checking specification version', async () => { - const v2Model = { - specificationVersion: 'v2', - provider: 'custom-provider', - modelId: 'custom-model', - doGenerate: vi.fn() - } as any - - // Direct model objects bypass validation - // The executor trusts that plugins/users provide valid models - await expect( - executor.generateText({ - model: v2Model, - messages: [{ role: 'user', content: 'Test' }] - }) - ).resolves.toBeDefined() - }) - - it('should accept model object with streamText', async () => { - const directModel = createMockLanguageModel({ - specificationVersion: 'v3' - }) - - await executor.streamText({ - model: directModel, - messages: [{ role: 'user', content: 'Stream' }] - }) - - expect(globalModelResolver.resolveLanguageModel).not.toHaveBeenCalled() - expect(streamText).toHaveBeenCalledWith( - expect.objectContaining({ - model: directModel - }) - ) - }) - }) - - describe('Middleware Application', () => { - it('should apply middlewares to string modelId', async () => { - const testMiddleware = createMockMiddleware() - - await executor.generateText( - { - model: 'gpt-4', - messages: [{ role: 'user', content: 'Test' }] - }, - { middlewares: [testMiddleware] } - ) - - expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [ - testMiddleware - ]) - }) - - it('should apply multiple middlewares in order', async () => { - const middleware1 = createMockMiddleware() - const middleware2 = createMockMiddleware() - - await executor.generateText( - { - model: 'gpt-4', - messages: [{ role: 'user', content: 'Test' }] - }, - { middlewares: [middleware1, middleware2] } - ) - - expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [ - middleware1, - middleware2 - ]) - }) - - it('should pass middlewares to model resolver for string modelIds', async () => { - const testMiddleware = createMockMiddleware() - - await executor.generateText( - { - model: 'gpt-4', // String model ID - messages: [{ role: 'user', content: 'Test' }] - }, - { middlewares: [testMiddleware] } - ) - - // Middlewares are passed to the resolver for string modelIds - expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [ - testMiddleware - ]) - }) - - it('should not apply middlewares when none provided', async () => { - await executor.generateText({ - model: 'gpt-4', - messages: [{ role: 'user', content: 'Test' }] - }) - - expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( - 'gpt-4', - 'openai', - expect.any(Object), - undefined - ) - }) - - it('should handle empty middleware array', async () => { - await executor.generateText( - { - model: 'gpt-4', - messages: [{ role: 'user', content: 'Test' }] - }, - { middlewares: [] } - ) - - expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), []) - }) - - it('should work with middlewares in streamText', async () => { - const middleware = createMockMiddleware() - - await executor.streamText( - { - model: 'gpt-4', - messages: [{ role: 'user', content: 'Stream' }] - }, - { middlewares: [middleware] } - ) - - expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [ - middleware - ]) - }) - }) - - describe('Image Model Resolution', () => { - it('should resolve string image modelId using globalModelResolver', async () => { - await executor.generateImage({ - model: 'dall-e-3', - prompt: 'A beautiful sunset' - }) - - expect(globalModelResolver.resolveImageModel).toHaveBeenCalledWith('dall-e-3', 'openai') - }) - - it('should accept direct ImageModelV3 object', async () => { - const directImageModel: ImageModelV3 = createMockImageModel({ - specificationVersion: 'v3', - provider: 'openai', - modelId: 'dall-e-3' - }) - - await executor.generateImage({ - model: directImageModel, - prompt: 'Test image' - }) - - expect(globalModelResolver.resolveImageModel).not.toHaveBeenCalled() - expect(generateImage).toHaveBeenCalledWith( - expect.objectContaining({ - model: directImageModel - }) - ) - }) - - it('should resolve namespaced image model ID', async () => { - await executor.generateImage({ - model: 'aihubmix|openai|dall-e-3', - prompt: 'Namespaced image' - }) - - expect(globalModelResolver.resolveImageModel).toHaveBeenCalledWith('aihubmix|openai|dall-e-3', 'openai') - }) - - it('should throw ImageModelResolutionError on resolution failure', async () => { - const resolutionError = new Error('Model not found') - vi.mocked(globalModelResolver.resolveImageModel).mockRejectedValue(resolutionError) - - await expect( - executor.generateImage({ - model: 'invalid-model', - prompt: 'Test' - }) - ).rejects.toThrow(ImageModelResolutionError) - }) - - it('should include modelId and providerId in ImageModelResolutionError', async () => { - vi.mocked(globalModelResolver.resolveImageModel).mockRejectedValue(new Error('Not found')) - - try { - await executor.generateImage({ - model: 'invalid-model', - prompt: 'Test' - }) - expect.fail('Should have thrown ImageModelResolutionError') - } catch (error) { - expect(error).toBeInstanceOf(ImageModelResolutionError) - const imgError = error as ImageModelResolutionError - expect(imgError.message).toContain('invalid-model') - expect(imgError.providerId).toBe('openai') - } - }) - - it('should extract modelId from direct model object in error', async () => { - const directModel = createMockImageModel({ - modelId: 'direct-model', - doGenerate: vi.fn().mockRejectedValue(new Error('Generation failed')) - }) - - vi.mocked(generateImage).mockRejectedValue(new Error('Generation failed')) - - await expect( - executor.generateImage({ - model: directModel, - prompt: 'Test' - }) - ).rejects.toThrow() - }) - }) - - describe('Provider-Specific Model Resolution', () => { - it('should resolve models for OpenAI provider', async () => { - const openaiExecutor = RuntimeExecutor.create('openai', mockProviderConfigs.openai) - - await openaiExecutor.generateText({ - model: 'gpt-4', - messages: [{ role: 'user', content: 'Test' }] - }) - - expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( - 'gpt-4', - 'openai', - expect.any(Object), - undefined - ) - }) - - it('should resolve models for Anthropic provider', async () => { - const anthropicExecutor = RuntimeExecutor.create('anthropic', mockProviderConfigs.anthropic) - - await anthropicExecutor.generateText({ - model: 'claude-3-5-sonnet', - messages: [{ role: 'user', content: 'Test' }] - }) - - expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( - 'claude-3-5-sonnet', - 'anthropic', - expect.any(Object), - undefined - ) - }) - - it('should resolve models for Google provider', async () => { - const googleExecutor = RuntimeExecutor.create('google', mockProviderConfigs.google) - - await googleExecutor.generateText({ - model: 'gemini-2.0-flash', - messages: [{ role: 'user', content: 'Test' }] - }) - - expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( - 'gemini-2.0-flash', - 'google', - expect.any(Object), - undefined - ) - }) - - it('should resolve models for OpenAI-compatible provider', async () => { - const compatibleExecutor = RuntimeExecutor.createOpenAICompatible(mockProviderConfigs['openai-compatible']) - - await compatibleExecutor.generateText({ - model: 'custom-model', - messages: [{ role: 'user', content: 'Test' }] - }) - - expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( - 'custom-model', - 'openai-compatible', - expect.any(Object), - undefined - ) - }) - }) - - describe('OpenAI Mode Handling', () => { - it('should pass mode setting to model resolver', async () => { - const executorWithMode = RuntimeExecutor.create('openai', { - ...mockProviderConfigs.openai, - mode: 'chat' - }) - - await executorWithMode.generateText({ - model: 'gpt-4', - messages: [{ role: 'user', content: 'Test' }] - }) - - expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( - 'gpt-4', - 'openai', - expect.objectContaining({ - mode: 'chat' - }), - undefined - ) - }) - - it('should handle responses mode', async () => { - const executorWithMode = RuntimeExecutor.create('openai', { - ...mockProviderConfigs.openai, - mode: 'responses' - }) - - await executorWithMode.generateText({ - model: 'gpt-4', - messages: [{ role: 'user', content: 'Test' }] - }) - - expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( - 'gpt-4', - 'openai', - expect.objectContaining({ - mode: 'responses' - }), - undefined - ) - }) - }) - - describe('Edge Cases', () => { - it('should handle empty string modelId', async () => { - await executor.generateText({ - model: '', - messages: [{ role: 'user', content: 'Test' }] - }) - - expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('', 'openai', expect.any(Object), undefined) - }) - - it('should handle model resolution errors gracefully', async () => { - vi.mocked(globalModelResolver.resolveLanguageModel).mockRejectedValue(new Error('Model not found')) - - await expect( - executor.generateText({ - model: 'nonexistent-model', - messages: [{ role: 'user', content: 'Test' }] - }) - ).rejects.toThrow('Model not found') - }) - - it('should handle concurrent model resolutions', async () => { - const promises = [ - executor.generateText({ model: 'gpt-4', messages: [{ role: 'user', content: 'Test 1' }] }), - executor.generateText({ model: 'gpt-4-turbo', messages: [{ role: 'user', content: 'Test 2' }] }), - executor.generateText({ model: 'gpt-3.5-turbo', messages: [{ role: 'user', content: 'Test 3' }] }) - ] - - await Promise.all(promises) - - expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledTimes(3) - }) - - it('should accept model object even without specificationVersion', async () => { - const invalidModel = { - provider: 'test', - modelId: 'test-model' - // Missing specificationVersion - } as any - - // Plugin engine doesn't validate direct model objects - // It's the user's responsibility to provide valid models - await expect( - executor.generateText({ - model: invalidModel, - messages: [{ role: 'user', content: 'Test' }] - }) - ).resolves.toBeDefined() - }) - }) - - describe('Type Safety Validation', () => { - it('should ensure resolved model is LanguageModelV3', async () => { - const v3Model = createMockLanguageModel({ - specificationVersion: 'v3' - }) - - vi.mocked(globalModelResolver.resolveLanguageModel).mockResolvedValue(v3Model) - - await executor.generateText({ - model: 'gpt-4', - messages: [{ role: 'user', content: 'Test' }] - }) - - expect(generateText).toHaveBeenCalledWith( - expect.objectContaining({ - model: expect.objectContaining({ - specificationVersion: 'v3' - }) - }) - ) - }) - - it('should not enforce specification version for direct models', async () => { - const v1Model = { - specificationVersion: 'v1', - provider: 'test', - modelId: 'test' - } as any - - // Direct models bypass validation in the plugin engine - // Only resolved models (from string IDs) are validated - await expect( - executor.generateText({ - model: v1Model, - messages: [{ role: 'user', content: 'Test' }] - }) - ).resolves.toBeDefined() - }) - }) -}) diff --git a/packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts b/packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts index 3f1b5b4231..3dfe7afae2 100644 --- a/packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts +++ b/packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts @@ -1,9 +1,9 @@ import type { ImageModelV3 } from '@ai-sdk/provider' +import { createMockImageModel, createMockProviderV3 } from '@test-utils' import { generateImage as aiGenerateImage, NoImageGeneratedError } from 'ai' import { beforeEach, describe, expect, it, vi } from 'vitest' import { type AiPlugin } from '../../plugins' -import { globalProviderInstanceRegistry } from '../../providers/core/ProviderInstanceRegistry' import { ImageGenerationError, ImageModelResolutionError } from '../errors' import { RuntimeExecutor } from '../executor' @@ -21,32 +21,32 @@ vi.mock('ai', () => ({ } })) -vi.mock('../../providers/core/ProviderInstanceRegistry', () => ({ - globalProviderInstanceRegistry: { - imageModel: vi.fn() - }, - DEFAULT_SEPARATOR: '|' -})) - describe('RuntimeExecutor.generateImage', () => { let executor: RuntimeExecutor<'openai'> let mockImageModel: ImageModelV3 + let mockProvider: any let mockGenerateImageResult: any beforeEach(() => { // Reset all mocks vi.clearAllMocks() - // Create executor instance - executor = RuntimeExecutor.create('openai', { - apiKey: 'test-key' - }) - // Mock image model - mockImageModel = { + mockImageModel = createMockImageModel({ modelId: 'dall-e-3', provider: 'openai' - } as ImageModelV3 + }) + + // Create mock provider with imageModel as a spy + mockProvider = createMockProviderV3({ + provider: 'openai', + imageModel: vi.fn(() => mockImageModel) + }) + + // Create executor instance + executor = RuntimeExecutor.create('openai', mockProvider, { + apiKey: 'test-key' + }) // Mock generateImage result mockGenerateImageResult = { @@ -71,8 +71,6 @@ describe('RuntimeExecutor.generateImage', () => { responses: [] } - // Setup mocks to avoid "No providers registered" error - vi.mocked(globalProviderInstanceRegistry.imageModel).mockReturnValue(mockImageModel) vi.mocked(aiGenerateImage).mockResolvedValue(mockGenerateImageResult) }) @@ -80,7 +78,7 @@ describe('RuntimeExecutor.generateImage', () => { it('should generate a single image with minimal parameters', async () => { const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A futuristic cityscape at sunset' }) - expect(globalProviderInstanceRegistry.imageModel).toHaveBeenCalledWith('openai|dall-e-3') + expect(mockProvider.imageModel).toHaveBeenCalledWith('dall-e-3') expect(aiGenerateImage).toHaveBeenCalledWith({ model: mockImageModel, @@ -96,7 +94,8 @@ describe('RuntimeExecutor.generateImage', () => { prompt: 'A beautiful landscape' }) - // Note: globalProviderInstanceRegistry.imageModel may still be called due to resolveImageModel logic + // Pre-created model is used directly, provider.imageModel is not called + expect(mockProvider.imageModel).not.toHaveBeenCalled() expect(aiGenerateImage).toHaveBeenCalledWith({ model: mockImageModel, prompt: 'A beautiful landscape' @@ -224,6 +223,7 @@ describe('RuntimeExecutor.generateImage', () => { const executorWithPlugin = RuntimeExecutor.create( 'openai', + mockProvider, { apiKey: 'test-key' }, @@ -269,6 +269,7 @@ describe('RuntimeExecutor.generateImage', () => { const executorWithPlugin = RuntimeExecutor.create( 'openai', + mockProvider, { apiKey: 'test-key' }, @@ -309,6 +310,7 @@ describe('RuntimeExecutor.generateImage', () => { const executorWithPlugin = RuntimeExecutor.create( 'openai', + mockProvider, { apiKey: 'test-key' }, @@ -325,7 +327,8 @@ describe('RuntimeExecutor.generateImage', () => { describe('Error handling', () => { it('should handle model creation errors', async () => { const modelError = new Error('Failed to get image model') - vi.mocked(globalProviderInstanceRegistry.imageModel).mockImplementation(() => { + // Since mockProvider.imageModel is already a vi.fn() spy, we can mock it directly + mockProvider.imageModel.mockImplementation(() => { throw modelError }) @@ -336,7 +339,7 @@ describe('RuntimeExecutor.generateImage', () => { it('should handle ImageModelResolutionError correctly', async () => { const resolutionError = new ImageModelResolutionError('invalid-model', 'openai', new Error('Model not found')) - vi.mocked(globalProviderInstanceRegistry.imageModel).mockImplementation(() => { + mockProvider.imageModel.mockImplementation(() => { throw resolutionError }) @@ -353,7 +356,7 @@ describe('RuntimeExecutor.generateImage', () => { it('should handle ImageModelResolutionError without provider', async () => { const resolutionError = new ImageModelResolutionError('unknown-model') - vi.mocked(globalProviderInstanceRegistry.imageModel).mockImplementation(() => { + mockProvider.imageModel.mockImplementation(() => { throw resolutionError }) @@ -398,6 +401,7 @@ describe('RuntimeExecutor.generateImage', () => { const executorWithPlugin = RuntimeExecutor.create( 'openai', + mockProvider, { apiKey: 'test-key' }, @@ -436,23 +440,43 @@ describe('RuntimeExecutor.generateImage', () => { describe('Multiple providers support', () => { it('should work with different providers', async () => { - const googleExecutor = RuntimeExecutor.create('google', { + const googleImageModel = createMockImageModel({ + provider: 'google', + modelId: 'imagen-3.0-generate-002' + }) + + const googleProvider = createMockProviderV3({ + provider: 'google', + imageModel: vi.fn(() => googleImageModel) + }) + + const googleExecutor = RuntimeExecutor.create('google', googleProvider, { apiKey: 'google-key' }) await googleExecutor.generateImage({ model: 'imagen-3.0-generate-002', prompt: 'A landscape' }) - expect(globalProviderInstanceRegistry.imageModel).toHaveBeenCalledWith('google|imagen-3.0-generate-002') + expect(googleProvider.imageModel).toHaveBeenCalledWith('imagen-3.0-generate-002') }) it('should support xAI Grok image models', async () => { - const xaiExecutor = RuntimeExecutor.create('xai', { + const xaiImageModel = createMockImageModel({ + provider: 'xai', + modelId: 'grok-2-image' + }) + + const xaiProvider = createMockProviderV3({ + provider: 'xai', + imageModel: vi.fn(() => xaiImageModel) + }) + + const xaiExecutor = RuntimeExecutor.create('xai', xaiProvider, { apiKey: 'xai-key' }) await xaiExecutor.generateImage({ model: 'grok-2-image', prompt: 'A futuristic robot' }) - expect(globalProviderInstanceRegistry.imageModel).toHaveBeenCalledWith('xai|grok-2-image') + expect(xaiProvider.imageModel).toHaveBeenCalledWith('grok-2-image') }) }) diff --git a/packages/aiCore/src/core/runtime/__tests__/generateText.test.ts b/packages/aiCore/src/core/runtime/__tests__/generateText.test.ts index 85bcab5bd6..e69cb2c4e2 100644 --- a/packages/aiCore/src/core/runtime/__tests__/generateText.test.ts +++ b/packages/aiCore/src/core/runtime/__tests__/generateText.test.ts @@ -3,18 +3,18 @@ * Tests non-streaming text generation across all providers with various parameters */ -import { generateText } from 'ai' -import { beforeEach, describe, expect, it, vi } from 'vitest' - import { createMockLanguageModel, + createMockProviderV3, mockCompleteResponses, mockProviderConfigs, testMessages, testTools -} from '../../../__tests__' +} from '@test-utils' +import { generateText } from 'ai' +import { beforeEach, describe, expect, it, vi } from 'vitest' + import type { AiPlugin } from '../../plugins' -import { globalProviderInstanceRegistry } from '../../providers/core/ProviderInstanceRegistry' import { RuntimeExecutor } from '../executor' // Mock AI SDK - use importOriginal to keep jsonSchema and other non-mocked exports @@ -26,28 +26,28 @@ vi.mock('ai', async (importOriginal) => { } }) -vi.mock('../../providers/core/ProviderInstanceRegistry', () => ({ - globalProviderInstanceRegistry: { - languageModel: vi.fn() - }, - DEFAULT_SEPARATOR: '|' -})) - describe('RuntimeExecutor.generateText', () => { let executor: RuntimeExecutor<'openai'> let mockLanguageModel: any + let mockProvider: any beforeEach(() => { vi.clearAllMocks() - executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai) - mockLanguageModel = createMockLanguageModel({ provider: 'openai', modelId: 'gpt-4' }) - vi.mocked(globalProviderInstanceRegistry.languageModel).mockReturnValue(mockLanguageModel) + // ✅ Create mock provider with languageModel as a spy + mockProvider = createMockProviderV3({ + provider: 'openai', + languageModel: vi.fn(() => mockLanguageModel) + }) + + // ✅ Pass provider instance to RuntimeExecutor.create() + executor = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai) + vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.simple as any) }) @@ -231,75 +231,87 @@ describe('RuntimeExecutor.generateText', () => { describe('Multiple Providers', () => { it('should work with Anthropic provider', async () => { - const anthropicExecutor = RuntimeExecutor.create('anthropic', mockProviderConfigs.anthropic) - const anthropicModel = createMockLanguageModel({ provider: 'anthropic', modelId: 'claude-3-5-sonnet-20241022' }) - vi.mocked(globalProviderInstanceRegistry.languageModel).mockReturnValue(anthropicModel) + const anthropicProvider = createMockProviderV3({ + provider: 'anthropic', + languageModel: vi.fn(() => anthropicModel) + }) + + const anthropicExecutor = RuntimeExecutor.create('anthropic', anthropicProvider, mockProviderConfigs.anthropic) await anthropicExecutor.generateText({ model: 'claude-3-5-sonnet-20241022', messages: testMessages.simple }) - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('anthropic|claude-3-5-sonnet-20241022') + expect(anthropicProvider.languageModel).toHaveBeenCalledWith('claude-3-5-sonnet-20241022') }) it('should work with Google provider', async () => { - const googleExecutor = RuntimeExecutor.create('google', mockProviderConfigs.google) - const googleModel = createMockLanguageModel({ provider: 'google', modelId: 'gemini-2.0-flash-exp' }) - vi.mocked(globalProviderInstanceRegistry.languageModel).mockReturnValue(googleModel) + const googleProvider = createMockProviderV3({ + provider: 'google', + languageModel: vi.fn(() => googleModel) + }) + + const googleExecutor = RuntimeExecutor.create('google', googleProvider, mockProviderConfigs.google) await googleExecutor.generateText({ model: 'gemini-2.0-flash-exp', messages: testMessages.simple }) - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('google|gemini-2.0-flash-exp') + expect(googleProvider.languageModel).toHaveBeenCalledWith('gemini-2.0-flash-exp') }) it('should work with xAI provider', async () => { - const xaiExecutor = RuntimeExecutor.create('xai', mockProviderConfigs.xai) - const xaiModel = createMockLanguageModel({ provider: 'xai', modelId: 'grok-2-latest' }) - vi.mocked(globalProviderInstanceRegistry.languageModel).mockReturnValue(xaiModel) + const xaiProvider = createMockProviderV3({ + provider: 'xai', + languageModel: vi.fn(() => xaiModel) + }) + + const xaiExecutor = RuntimeExecutor.create('xai', xaiProvider, mockProviderConfigs.xai) await xaiExecutor.generateText({ model: 'grok-2-latest', messages: testMessages.simple }) - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('xai|grok-2-latest') + expect(xaiProvider.languageModel).toHaveBeenCalledWith('grok-2-latest') }) it('should work with DeepSeek provider', async () => { - const deepseekExecutor = RuntimeExecutor.create('deepseek', mockProviderConfigs.deepseek) - const deepseekModel = createMockLanguageModel({ provider: 'deepseek', modelId: 'deepseek-chat' }) - vi.mocked(globalProviderInstanceRegistry.languageModel).mockReturnValue(deepseekModel) + const deepseekProvider = createMockProviderV3({ + provider: 'deepseek', + languageModel: vi.fn(() => deepseekModel) + }) + + const deepseekExecutor = RuntimeExecutor.create('deepseek', deepseekProvider, mockProviderConfigs.deepseek) await deepseekExecutor.generateText({ model: 'deepseek-chat', messages: testMessages.simple }) - expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('deepseek|deepseek-chat') + expect(deepseekProvider.languageModel).toHaveBeenCalledWith('deepseek-chat') }) }) @@ -325,7 +337,9 @@ describe('RuntimeExecutor.generateText', () => { }) } - const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin]) + const executorWithPlugin = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai, [ + testPlugin + ]) const result = await executorWithPlugin.generateText({ model: 'gpt-4', @@ -364,7 +378,10 @@ describe('RuntimeExecutor.generateText', () => { }) } - const executorWithPlugins = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [plugin1, plugin2]) + const executorWithPlugins = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai, [ + plugin1, + plugin2 + ]) await executorWithPlugins.generateText({ model: 'gpt-4', @@ -404,7 +421,9 @@ describe('RuntimeExecutor.generateText', () => { onError: vi.fn() } - const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin]) + const executorWithPlugin = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai, [ + errorPlugin + ]) await expect( executorWithPlugin.generateText({ @@ -425,7 +444,7 @@ describe('RuntimeExecutor.generateText', () => { it('should handle model not found error', async () => { const error = new Error('Model not found: invalid-model') - vi.mocked(globalProviderInstanceRegistry.languageModel).mockImplementation(() => { + mockProvider.languageModel.mockImplementationOnce(() => { throw error }) diff --git a/packages/aiCore/src/core/runtime/__tests__/pluginEngine.test.ts b/packages/aiCore/src/core/runtime/__tests__/pluginEngine.test.ts index e0dedf1521..c853ad4b6c 100644 --- a/packages/aiCore/src/core/runtime/__tests__/pluginEngine.test.ts +++ b/packages/aiCore/src/core/runtime/__tests__/pluginEngine.test.ts @@ -5,9 +5,9 @@ */ import type { ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider' +import { createMockImageModel, createMockLanguageModel } from '@test-utils' import { beforeEach, describe, expect, it, vi } from 'vitest' -import { createMockImageModel, createMockLanguageModel } from '../../../__tests__' import { ModelResolutionError, RecursiveDepthError } from '../../errors' import type { AiPlugin, GenerateTextParams, GenerateTextResult } from '../../plugins' import { PluginEngine } from '../pluginEngine' diff --git a/packages/aiCore/src/core/runtime/__tests__/streamText.test.ts b/packages/aiCore/src/core/runtime/__tests__/streamText.test.ts index 0cb08e4322..f49282dece 100644 --- a/packages/aiCore/src/core/runtime/__tests__/streamText.test.ts +++ b/packages/aiCore/src/core/runtime/__tests__/streamText.test.ts @@ -3,12 +3,17 @@ * Tests streaming text generation across all providers with various parameters */ +import { + collectStreamChunks, + createMockLanguageModel, + createMockProviderV3, + mockProviderConfigs, + testMessages +} from '@test-utils' import { streamText } from 'ai' import { beforeEach, describe, expect, it, vi } from 'vitest' -import { collectStreamChunks, createMockLanguageModel, mockProviderConfigs, testMessages } from '../../../__tests__' import type { AiPlugin } from '../../plugins' -import { globalProviderInstanceRegistry } from '../../providers/core/ProviderInstanceRegistry' import { RuntimeExecutor } from '../executor' // Mock AI SDK - use importOriginal to keep jsonSchema and other non-mocked exports @@ -20,28 +25,25 @@ vi.mock('ai', async (importOriginal) => { } }) -vi.mock('../../providers/core/ProviderInstanceRegistry', () => ({ - globalProviderInstanceRegistry: { - languageModel: vi.fn() - }, - DEFAULT_SEPARATOR: '|' -})) - describe('RuntimeExecutor.streamText', () => { let executor: RuntimeExecutor<'openai'> let mockLanguageModel: any + let mockProvider: any beforeEach(() => { vi.clearAllMocks() - executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai) - mockLanguageModel = createMockLanguageModel({ provider: 'openai', modelId: 'gpt-4' }) - vi.mocked(globalProviderInstanceRegistry.languageModel).mockReturnValue(mockLanguageModel) + mockProvider = createMockProviderV3({ + provider: 'openai', + languageModel: () => mockLanguageModel + }) + + executor = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai) }) describe('Basic Functionality', () => { @@ -416,7 +418,9 @@ describe('RuntimeExecutor.streamText', () => { }) } - const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin]) + const executorWithPlugin = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai, [ + testPlugin + ]) const mockStream = { textStream: (async function* () { @@ -509,7 +513,9 @@ describe('RuntimeExecutor.streamText', () => { onError: vi.fn() } - const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin]) + const executorWithPlugin = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai, [ + errorPlugin + ]) await expect( executorWithPlugin.streamText({ diff --git a/packages/aiCore/src/core/runtime/executor.ts b/packages/aiCore/src/core/runtime/executor.ts index 2fb4aa7feb..0bdb5d6f12 100644 --- a/packages/aiCore/src/core/runtime/executor.ts +++ b/packages/aiCore/src/core/runtime/executor.ts @@ -2,7 +2,7 @@ * 运行时执行器 * 专注于插件化的AI调用处理 */ -import type { ImageModelV3, LanguageModelV3, LanguageModelV3Middleware } from '@ai-sdk/provider' +import type { ImageModelV3, LanguageModelV3, LanguageModelV3Middleware, ProviderV3 } from '@ai-sdk/provider' import type { LanguageModel } from 'ai' import { generateImage as _generateImage, @@ -11,7 +11,7 @@ import { wrapLanguageModel } from 'ai' -import { globalModelResolver } from '../models' +import { ModelResolver } from '../models' import { type ModelConfig } from '../models/types' import { isV3Model } from '../models/utils' import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins' @@ -26,11 +26,13 @@ export class RuntimeExecutor< > { public pluginEngine: PluginEngine private config: RuntimeConfig + private modelResolver: ModelResolver constructor(config: RuntimeConfig) { this.config = config // 创建插件客户端 this.pluginEngine = new PluginEngine(config.providerId, config.plugins || []) + this.modelResolver = new ModelResolver(config.provider) } private createResolveModelPlugin(middlewares?: LanguageModelV3Middleware[]) { @@ -175,13 +177,9 @@ export class RuntimeExecutor< middlewares?: LanguageModelV3Middleware[] ): Promise { if (typeof modelOrId === 'string') { - // 🎯 字符串modelId,使用新的ModelResolver解析,传递完整参数 - return await globalModelResolver.resolveLanguageModel( - modelOrId, // 支持 'gpt-4' 和 'aihubmix:anthropic:claude-3.5-sonnet' - this.config.providerId, // fallback provider - this.config.providerSettings, // provider options - middlewares // 中间件数组 - ) + // 字符串modelId,使用 ModelResolver 解析 + // Provider会处理命名空间格式路由(如果是HubProvider) + return await this.modelResolver.resolveLanguageModel(modelOrId, middlewares) } else { // 已经是模型对象 // 所有 provider 都应该返回 V3 模型(通过 wrapProvider 确保) @@ -206,11 +204,9 @@ export class RuntimeExecutor< private async resolveImageModel(modelOrId: ImageModelV3 | string): Promise { try { if (typeof modelOrId === 'string') { - // 字符串modelId,使用新的ModelResolver解析 - return await globalModelResolver.resolveImageModel( - modelOrId, // 支持 'dall-e-3' 和 'aihubmix:openai:dall-e-3' - this.config.providerId // fallback provider - ) + // 字符串modelId,使用 ModelResolver 解析 + // Provider会处理命名空间格式路由(如果是HubProvider) + return await this.modelResolver.resolveImageModel(modelOrId) } else { // 已经是模型,直接返回 return modelOrId @@ -234,11 +230,13 @@ export class RuntimeExecutor< TSettingsMap extends Record = CoreProviderSettingsMap >( providerId: T, + provider: ProviderV3, // ✅ Accept provider instance options: ModelConfig['providerSettings'], plugins?: AiPlugin[] ): RuntimeExecutor { return new RuntimeExecutor({ providerId, + provider, // ✅ Pass provider to config providerSettings: options, plugins }) @@ -246,13 +244,16 @@ export class RuntimeExecutor< /** * 创建OpenAI Compatible执行器 + * ✅ Now accepts provider instance directly */ static createOpenAICompatible( + provider: ProviderV3, // ✅ Accept provider instance options: ModelConfig<'openai-compatible'>['providerSettings'], plugins: AiPlugin[] = [] ): RuntimeExecutor<'openai-compatible'> { return new RuntimeExecutor({ providerId: 'openai-compatible', + provider, // ✅ Pass provider to config providerSettings: options, plugins }) diff --git a/packages/aiCore/src/core/runtime/index.ts b/packages/aiCore/src/core/runtime/index.ts index 4d0c3c1a50..25865d65d1 100644 --- a/packages/aiCore/src/core/runtime/index.ts +++ b/packages/aiCore/src/core/runtime/index.ts @@ -14,7 +14,7 @@ export type { RuntimeConfig } from './types' import type { LanguageModelV3Middleware } from '@ai-sdk/provider' import { type AiPlugin } from '../plugins' -import { extensionRegistry, globalProviderStorage } from '../providers' +import { extensionRegistry } from '../providers' import { type CoreProviderSettingsMap, type RegisteredProviderId } from '../providers/types' import { RuntimeExecutor } from './executor' @@ -26,32 +26,15 @@ export async function createExecutor> -export async function createExecutor( - providerId: T, - options: any, - plugins?: AiPlugin[] -): Promise> -export async function createExecutor( - providerId: string, - options: any, - plugins?: AiPlugin[] -): Promise> { - // 确保 provider 已初始化 - if (!globalProviderStorage.has(providerId) && extensionRegistry.has(providerId)) { - try { - await extensionRegistry.createProvider(providerId, options || {}, providerId) - } catch (error) { - // 创建失败会在 ModelResolver 抛出更详细的错误 - console.warn(`Failed to auto-initialize provider "${providerId}":`, error) - } +): Promise> { + if (!extensionRegistry.has(providerId)) { + throw new Error(`Provider extension "${providerId}" not registered`) } - return RuntimeExecutor.create(providerId as RegisteredProviderId, options, plugins) + const provider = await extensionRegistry.createProvider(providerId, options || {}) + return RuntimeExecutor.create(providerId, provider, options, plugins) } -// === 直接调用API(无需创建executor实例)=== - /** * 直接流式文本生成 - 支持middlewares */ @@ -96,11 +79,13 @@ export async function generateImage { - return RuntimeExecutor.createOpenAICompatible(options, plugins) +): Promise> { + const provider = await extensionRegistry.createProvider('openai-compatible', options) + + return RuntimeExecutor.createOpenAICompatible(provider, options, plugins) } // === Agent 功能预留 === diff --git a/packages/aiCore/src/core/runtime/types.ts b/packages/aiCore/src/core/runtime/types.ts index 37ea44d60a..1b21a68946 100644 --- a/packages/aiCore/src/core/runtime/types.ts +++ b/packages/aiCore/src/core/runtime/types.ts @@ -1,7 +1,7 @@ /** * Runtime 层类型定义 */ -import type { ImageModelV3 } from '@ai-sdk/provider' +import type { ImageModelV3, ProviderV3 } from '@ai-sdk/provider' import type { generateImage, generateText, streamText } from 'ai' import { type ModelConfig } from '../models/types' @@ -19,6 +19,7 @@ export interface RuntimeConfig< TSettingsMap extends Record = CoreProviderSettingsMap > { providerId: T + provider: ProviderV3 providerSettings: ModelConfig['providerSettings'] plugins?: AiPlugin[] } diff --git a/packages/aiCore/src/core/types/index.ts b/packages/aiCore/src/core/types/index.ts index a9e4d7f6d1..26086d9c69 100644 --- a/packages/aiCore/src/core/types/index.ts +++ b/packages/aiCore/src/core/types/index.ts @@ -1 +1,8 @@ export type PlainObject = Record + +/** + * Provider settings map for HubProvider + * Key: provider ID (string) + * Value: provider settings object + */ +export type ProviderSettingsMap = Map> diff --git a/packages/aiCore/src/index.ts b/packages/aiCore/src/index.ts index f54483ebd8..9cff910c20 100644 --- a/packages/aiCore/src/index.ts +++ b/packages/aiCore/src/index.ts @@ -15,7 +15,7 @@ export { } from './core/runtime' // ==================== 高级API ==================== -export { isV2Model, isV3Model, globalModelResolver as modelResolver } from './core/models' +export { isV2Model, isV3Model } from './core/models' // ==================== 插件系统 ==================== export type { diff --git a/packages/aiCore/src/__tests__/helpers/test-utils.ts b/packages/aiCore/test_utils/helpers/common.ts similarity index 97% rename from packages/aiCore/src/__tests__/helpers/test-utils.ts rename to packages/aiCore/test_utils/helpers/common.ts index 8231075785..2498dc6edf 100644 --- a/packages/aiCore/src/__tests__/helpers/test-utils.ts +++ b/packages/aiCore/test_utils/helpers/common.ts @@ -1,12 +1,12 @@ /** - * Test Utilities - * Helper functions for testing AI Core functionality + * Common Test Utilities + * General-purpose helper functions for testing */ import { expect, vi } from 'vitest' -import type { ProviderId } from '../fixtures/mock-providers' -import { createMockImageModel, createMockLanguageModel, mockProviderConfigs } from '../fixtures/mock-providers' +import type { ProviderId } from '../mocks/providers' +import { createMockImageModel, createMockLanguageModel, mockProviderConfigs } from '../mocks/providers' /** * Creates a test provider with streaming support diff --git a/packages/aiCore/src/__tests__/helpers/model-test-utils.ts b/packages/aiCore/test_utils/helpers/model.ts similarity index 69% rename from packages/aiCore/src/__tests__/helpers/model-test-utils.ts rename to packages/aiCore/test_utils/helpers/model.ts index 5a5e73942b..cdc59c99d8 100644 --- a/packages/aiCore/src/__tests__/helpers/model-test-utils.ts +++ b/packages/aiCore/test_utils/helpers/model.ts @@ -16,9 +16,9 @@ import { MockLanguageModelV3 } from 'ai/test' import { vi } from 'vitest' import * as z from 'zod' -import type { StreamTextParams, StreamTextResult } from '../../core/plugins' -import type { RegisteredProviderId } from '../../core/providers/types' -import type { AiRequestContext } from '../../types' +import type { StreamTextParams, StreamTextResult } from '../../src/core/plugins' +import type { RegisteredProviderId } from '../../src/core/providers/types' +import type { AiRequestContext } from '../../src/types' /** * Type for partial overrides that allows omitting the model field @@ -137,45 +137,95 @@ export function createMockProviderV3(overrides?: { imageModel?: (modelId: string) => ImageModelV3 embeddingModel?: (modelId: string) => EmbeddingModelV3 }): ProviderV3 { + const defaultLanguageModel = (modelId: string) => + ({ + specificationVersion: 'v3', + provider: overrides?.provider ?? 'mock-provider', + modelId, + defaultObjectGenerationMode: 'tool', + supportedUrls: {}, + doGenerate: vi.fn().mockResolvedValue({ + text: 'Mock response text', + finishReason: 'stop', + usage: { + inputTokens: 10, + outputTokens: 20, + totalTokens: 30, + inputTokenDetails: {}, + outputTokenDetails: {} + }, + rawCall: { rawPrompt: null, rawSettings: {} }, + rawResponse: { headers: {} }, + warnings: [] + }), + doStream: vi.fn().mockReturnValue({ + stream: (async function* () { + yield { type: 'text-delta', textDelta: 'Mock ' } + yield { type: 'text-delta', textDelta: 'streaming ' } + yield { type: 'text-delta', textDelta: 'response' } + yield { + type: 'finish', + finishReason: 'stop', + usage: { + inputTokens: 10, + outputTokens: 15, + totalTokens: 25, + inputTokenDetails: {}, + outputTokenDetails: {} + } + } + })(), + rawCall: { rawPrompt: null, rawSettings: {} }, + rawResponse: { headers: {} }, + warnings: [] + }) + }) as LanguageModelV3 + + const defaultImageModel = (modelId: string) => + ({ + specificationVersion: 'v3', + provider: overrides?.provider ?? 'mock-provider', + modelId, + maxImagesPerCall: undefined, + doGenerate: vi.fn().mockResolvedValue({ + images: [ + { + base64: 'mock-base64-image-data', + uint8Array: new Uint8Array([1, 2, 3, 4, 5]), + mimeType: 'image/png' + } + ], + warnings: [] + }) + }) as ImageModelV3 + + const defaultEmbeddingModel = (modelId: string) => + ({ + specificationVersion: 'v3', + provider: overrides?.provider ?? 'mock-provider', + modelId, + maxEmbeddingsPerCall: 100, + supportsParallelCalls: true, + doEmbed: vi.fn().mockResolvedValue({ + embeddings: [ + [0.1, 0.2, 0.3, 0.4, 0.5], + [0.6, 0.7, 0.8, 0.9, 1.0] + ], + usage: { + inputTokens: 10, + totalTokens: 10 + }, + rawResponse: { headers: {} } + }) + }) as EmbeddingModelV3 + return { specificationVersion: 'v3', provider: overrides?.provider ?? 'mock-provider', - languageModel: overrides?.languageModel - ? overrides.languageModel - : (modelId: string) => - ({ - specificationVersion: 'v3', - provider: overrides?.provider ?? 'mock-provider', - modelId, - defaultObjectGenerationMode: 'tool', - supportedUrls: {}, - doGenerate: vi.fn(), - doStream: vi.fn() - }) as LanguageModelV3, - - imageModel: overrides?.imageModel - ? overrides.imageModel - : (modelId: string) => - ({ - specificationVersion: 'v3', - provider: overrides?.provider ?? 'mock-provider', - modelId, - maxImagesPerCall: undefined, - doGenerate: vi.fn() - }) as ImageModelV3, - - embeddingModel: overrides?.embeddingModel - ? overrides.embeddingModel - : (modelId: string) => - ({ - specificationVersion: 'v3', - provider: overrides?.provider ?? 'mock-provider', - modelId, - maxEmbeddingsPerCall: 100, - supportsParallelCalls: true, - doEmbed: vi.fn() - }) as EmbeddingModelV3 + languageModel: vi.fn(overrides?.languageModel ?? defaultLanguageModel), + imageModel: vi.fn(overrides?.imageModel ?? defaultImageModel), + embeddingModel: vi.fn(overrides?.embeddingModel ?? defaultEmbeddingModel) } as ProviderV3 } diff --git a/packages/aiCore/src/__tests__/helpers/provider-test-utils.ts b/packages/aiCore/test_utils/helpers/provider.ts similarity index 100% rename from packages/aiCore/src/__tests__/helpers/provider-test-utils.ts rename to packages/aiCore/test_utils/helpers/provider.ts diff --git a/packages/aiCore/test_utils/index.ts b/packages/aiCore/test_utils/index.ts new file mode 100644 index 0000000000..098c538726 --- /dev/null +++ b/packages/aiCore/test_utils/index.ts @@ -0,0 +1,13 @@ +/** + * Test Infrastructure Exports + * Central export point for all test utilities, fixtures, and helpers + */ + +// Mocks +export * from './mocks/providers' +export * from './mocks/responses' + +// Helpers +export * from './helpers/common' +export * from './helpers/model' +export * from './helpers/provider' diff --git a/packages/aiCore/src/__tests__/mocks/ai-sdk-provider.ts b/packages/aiCore/test_utils/mocks/ai-sdk-provider.ts similarity index 100% rename from packages/aiCore/src/__tests__/mocks/ai-sdk-provider.ts rename to packages/aiCore/test_utils/mocks/ai-sdk-provider.ts diff --git a/packages/aiCore/src/__tests__/fixtures/mock-providers.ts b/packages/aiCore/test_utils/mocks/providers.ts similarity index 100% rename from packages/aiCore/src/__tests__/fixtures/mock-providers.ts rename to packages/aiCore/test_utils/mocks/providers.ts diff --git a/packages/aiCore/src/__tests__/fixtures/mock-responses.ts b/packages/aiCore/test_utils/mocks/responses.ts similarity index 100% rename from packages/aiCore/src/__tests__/fixtures/mock-responses.ts rename to packages/aiCore/test_utils/mocks/responses.ts diff --git a/packages/aiCore/src/__tests__/setup.ts b/packages/aiCore/test_utils/setup.ts similarity index 100% rename from packages/aiCore/src/__tests__/setup.ts rename to packages/aiCore/test_utils/setup.ts diff --git a/packages/aiCore/tsconfig.json b/packages/aiCore/tsconfig.json index 110b2106e0..be852753cb 100644 --- a/packages/aiCore/tsconfig.json +++ b/packages/aiCore/tsconfig.json @@ -11,11 +11,16 @@ "noEmitOnError": false, "outDir": "./dist", "resolveJsonModule": true, - "rootDir": "./src", + "rootDir": ".", "skipLibCheck": true, "strict": true, - "target": "ES2020" + "target": "ES2020", + "baseUrl": ".", + "paths": { + "@test-utils": ["./test_utils"], + "@test-utils/*": ["./test_utils/*"] + } }, "exclude": ["node_modules", "dist"], - "include": ["src/**/*"] + "include": ["src/**/*", "test_utils/**/*"] } diff --git a/packages/aiCore/vitest.config.ts b/packages/aiCore/vitest.config.ts index 2f520ea967..801e2ededf 100644 --- a/packages/aiCore/vitest.config.ts +++ b/packages/aiCore/vitest.config.ts @@ -8,13 +8,14 @@ const __dirname = path.dirname(fileURLToPath(import.meta.url)) export default defineConfig({ test: { globals: true, - setupFiles: [path.resolve(__dirname, './src/__tests__/setup.ts')] + setupFiles: [path.resolve(__dirname, './test_utils/setup.ts')] }, resolve: { alias: { '@': path.resolve(__dirname, './src'), + '@test-utils': path.resolve(__dirname, './test_utils'), // Mock external packages that may not be available in test environment - '@cherrystudio/ai-sdk-provider': path.resolve(__dirname, './src/__tests__/mocks/ai-sdk-provider.ts') + '@cherrystudio/ai-sdk-provider': path.resolve(__dirname, './test_utils/mocks/ai-sdk-provider.ts') } }, esbuild: { diff --git a/yarn.lock b/yarn.lock index e6439da366..9e93036f05 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1936,6 +1936,7 @@ __metadata: "@ai-sdk/provider": "npm:^3.0.0" "@ai-sdk/provider-utils": "npm:^4.0.0" "@ai-sdk/xai": "npm:^3.0.0" + lru-cache: "npm:^11.2.4" tsdown: "npm:^0.12.9" typescript: "npm:^5.0.0" vitest: "npm:^3.2.4" @@ -18183,6 +18184,13 @@ __metadata: languageName: node linkType: hard +"lru-cache@npm:^11.2.4": + version: 11.2.4 + resolution: "lru-cache@npm:11.2.4" + checksum: 10c0/4a24f9b17537619f9144d7b8e42cd5a225efdfd7076ebe7b5e7dc02b860a818455201e67fbf000765233fe7e339d3c8229fc815e9b58ee6ede511e07608c19b2 + languageName: node + linkType: hard + "lru-cache@npm:^5.1.1": version: 5.1.1 resolution: "lru-cache@npm:5.1.1"