diff --git a/packages/aiCore/src/__tests__/fixtures/mock-providers.ts b/packages/aiCore/src/__tests__/fixtures/mock-providers.ts index 74824e4719..9e59e5e16d 100644 --- a/packages/aiCore/src/__tests__/fixtures/mock-providers.ts +++ b/packages/aiCore/src/__tests__/fixtures/mock-providers.ts @@ -11,18 +11,20 @@ import { vi } from 'vitest' */ export function createMockLanguageModel(overrides?: Partial): LanguageModelV3 { return { - specificationVersion: 'v1', + specificationVersion: 'v3', provider: 'mock-provider', modelId: 'mock-model', - defaultObjectGenerationMode: 'tool', + supportedUrls: {}, doGenerate: vi.fn().mockResolvedValue({ text: 'Mock response text', finishReason: 'stop', usage: { - promptTokens: 10, - completionTokens: 20, - totalTokens: 30 + inputTokens: 10, + outputTokens: 20, + totalTokens: 30, + inputTokenDetails: {}, + outputTokenDetails: {} }, rawCall: { rawPrompt: null, rawSettings: {} }, rawResponse: { headers: {} }, @@ -47,9 +49,11 @@ export function createMockLanguageModel(overrides?: Partial): L type: 'finish', finishReason: 'stop', usage: { - promptTokens: 10, - completionTokens: 15, - totalTokens: 25 + inputTokens: 10, + outputTokens: 15, + totalTokens: 25, + inputTokenDetails: {}, + outputTokenDetails: {} } } })(), @@ -64,12 +68,14 @@ export function createMockLanguageModel(overrides?: Partial): L /** * Creates a mock image model with customizable behavior + * Compliant with AI SDK v3 specification */ export function createMockImageModel(overrides?: Partial): ImageModelV3 { return { - specificationVersion: 'V3', + specificationVersion: 'v3', provider: 'mock-provider', modelId: 'mock-image-model', + maxImagesPerCall: undefined, doGenerate: vi.fn().mockResolvedValue({ images: [ diff --git a/packages/aiCore/src/__tests__/helpers/model-test-utils.ts b/packages/aiCore/src/__tests__/helpers/model-test-utils.ts new file mode 100644 index 0000000000..067d93e8c6 --- /dev/null +++ b/packages/aiCore/src/__tests__/helpers/model-test-utils.ts @@ -0,0 +1,300 @@ +/** + * Model Test Utilities + * Provides comprehensive mock creators for AI SDK v3 models and related test utilities + */ + +import type { + EmbeddingModelV3, + ImageModelV3, + LanguageModelV3, + LanguageModelV3Middleware, + ProviderV3 +} from '@ai-sdk/provider' +import type { Tool, ToolSet } from 'ai' +import { tool } from 'ai' +import { MockLanguageModelV3 } from 'ai/test' +import { vi } from 'vitest' +import * as z from 'zod' + +import { StreamTextParams, StreamTextResult } from '../../core/plugins' +import type { ProviderId } from '../../core/providers/types' +import { AiRequestContext } from '../../types' + +/** + * Type for partial overrides that allows omitting the model field + * The model will be automatically added by createMockContext + */ +type ContextOverrides = Partial, 'originalParams'>> & { + originalParams?: Partial> & { model?: StreamTextParams['model'] } +} + +/** + * Creates a mock AiRequestContext with type safety + * The model field is automatically added to originalParams if not provided + * + * @example + * ```ts + * const context = createMockContext({ + * providerId: 'openai', + * metadata: { requestId: 'test-123' } + * }) + * ``` + */ +export function createMockContext(overrides?: ContextOverrides): AiRequestContext { + const mockModel = new MockLanguageModelV3({ + provider: 'test-provider', + modelId: 'test-model' + }) + + const base: AiRequestContext = { + providerId: 'openai' as ProviderId, + model: mockModel, + originalParams: { + model: mockModel, + messages: [{ role: 'user', content: 'Test message' }] + } as StreamTextParams, + metadata: {}, + startTime: Date.now(), + requestId: 'test-request-id', + recursiveCall: vi.fn(), + isRecursiveCall: false, + recursiveDepth: 0, + maxRecursiveDepth: 10, + extensions: new Map() + } + + if (overrides) { + // Ensure model is always present in originalParams + const mergedOriginalParams = { + ...base.originalParams, + ...overrides.originalParams, + model: overrides.originalParams?.model ?? mockModel + } + + return { + ...base, + ...overrides, + originalParams: mergedOriginalParams as StreamTextParams + } + } + + return base +} + +/** + * Creates a mock embedding model with customizable behavior + * Compliant with AI SDK v3 specification + * + * @example + * ```ts + * const embeddingModel = createMockEmbeddingModel({ + * provider: 'openai', + * modelId: 'text-embedding-3-small', + * maxEmbeddingsPerCall: 2048 + * }) + * ``` + */ +export function createMockEmbeddingModel(overrides?: Partial): EmbeddingModelV3 { + return { + specificationVersion: 'v3', + provider: 'mock-provider', + modelId: 'mock-embedding-model', + maxEmbeddingsPerCall: 100, + supportsParallelCalls: true, + + doEmbed: vi.fn().mockResolvedValue({ + embeddings: [ + [0.1, 0.2, 0.3, 0.4, 0.5], + [0.6, 0.7, 0.8, 0.9, 1.0] + ], + usage: { + inputTokens: 10, + totalTokens: 10 + }, + rawResponse: { headers: {} } + }), + + ...overrides + } as EmbeddingModelV3 +} + +/** + * Creates a complete mock ProviderV3 with all model types + * Useful for testing provider registration and management + * + * @example + * ```ts + * const provider = createMockProviderV3({ + * provider: 'openai', + * languageModel: customLanguageModel, + * imageModel: customImageModel + * }) + * ``` + */ +export function createMockProviderV3(overrides?: { + provider?: string + languageModel?: (modelId: string) => LanguageModelV3 + imageModel?: (modelId: string) => ImageModelV3 + embeddingModel?: (modelId: string) => EmbeddingModelV3 +}): ProviderV3 { + return { + specificationVersion: 'v3', + provider: overrides?.provider ?? 'mock-provider', + + languageModel: overrides?.languageModel + ? overrides.languageModel + : (modelId: string) => + ({ + specificationVersion: 'v3', + provider: overrides?.provider ?? 'mock-provider', + modelId, + defaultObjectGenerationMode: 'tool', + supportedUrls: {}, + doGenerate: vi.fn(), + doStream: vi.fn() + }) as LanguageModelV3, + + imageModel: overrides?.imageModel + ? overrides.imageModel + : (modelId: string) => + ({ + specificationVersion: 'v3', + provider: overrides?.provider ?? 'mock-provider', + modelId, + maxImagesPerCall: undefined, + doGenerate: vi.fn() + }) as ImageModelV3, + + embeddingModel: overrides?.embeddingModel + ? overrides.embeddingModel + : (modelId: string) => + ({ + specificationVersion: 'v3', + provider: overrides?.provider ?? 'mock-provider', + modelId, + maxEmbeddingsPerCall: 100, + supportsParallelCalls: true, + doEmbed: vi.fn() + }) as EmbeddingModelV3 + } as ProviderV3 +} + +/** + * Creates a mock middleware for testing middleware chains + * Supports both generate and stream wrapping + * + * @example + * ```ts + * const middleware = createMockMiddleware({ + * name: 'test-middleware' + * }) + * ``` + */ +export function createMockMiddleware(_options?: { name?: string }): LanguageModelV3Middleware { + return { + specificationVersion: 'v3', + wrapGenerate: vi.fn((doGenerate) => doGenerate), + wrapStream: vi.fn((doStream) => doStream) + } +} + +/** + * Creates a type-safe function tool for testing using AI SDK's tool() function + * + * @example + * ```ts + * const weatherTool = createMockTool('getWeather', 'Get current weather') + * ``` + */ +export function createMockTool(name: string, description?: string): Tool<{ value?: string }, string> { + return tool({ + description: description || `Mock tool: ${name}`, + inputSchema: z.object({ + value: z.string().optional() + }), + execute: vi.fn(async () => 'mock result') + }) +} + +/** + * Creates a provider-defined tool for testing + */ +export function createMockProviderTool(name: string, description?: string): { type: 'provider'; description: string } { + return { + type: 'provider' as const, + description: description || `Mock provider tool: ${name}` + } +} + +/** + * Creates a ToolSet with multiple tools + * + * @example + * ```ts + * const tools = createMockToolSet({ + * getWeather: 'function', + * searchDatabase: 'function', + * nativeSearch: 'provider' + * }) + * ``` + */ +export function createMockToolSet(tools: Record): ToolSet { + const toolSet: ToolSet = {} + + for (const [name, type] of Object.entries(tools)) { + if (type === 'function') { + toolSet[name] = createMockTool(name) + } else { + toolSet[name] = createMockProviderTool(name) as Tool + } + } + + return toolSet +} + +/** + * Creates mock stream params for testing + * + * @example + * ```ts + * const params = createMockStreamParams({ + * messages: [{ role: 'user', content: 'Custom message' }], + * temperature: 0.7 + * }) + * ``` + */ +export function createMockStreamParams(overrides?: Partial): StreamTextParams { + return { + messages: [{ role: 'user', content: 'Test message' }], + ...overrides + } as StreamTextParams +} + +/** + * Common mock model instances for quick testing + */ +export const mockModels = { + /** Standard language model for general testing */ + language: new MockLanguageModelV3({ + provider: 'test-provider', + modelId: 'test-model' + }), + + /** Mock OpenAI GPT-4 model */ + gpt4: new MockLanguageModelV3({ + provider: 'openai', + modelId: 'gpt-4' + }), + + /** Mock Anthropic Claude model */ + claude: new MockLanguageModelV3({ + provider: 'anthropic', + modelId: 'claude-3-5-sonnet-20241022' + }), + + /** Mock Google Gemini model */ + gemini: new MockLanguageModelV3({ + provider: 'google', + modelId: 'gemini-2.0-flash-exp' + }) +} as const diff --git a/packages/aiCore/src/__tests__/index.ts b/packages/aiCore/src/__tests__/index.ts index 23ecd167a4..afc498cad1 100644 --- a/packages/aiCore/src/__tests__/index.ts +++ b/packages/aiCore/src/__tests__/index.ts @@ -8,5 +8,6 @@ export * from './fixtures/mock-providers' export * from './fixtures/mock-responses' // Helpers +export * from './helpers/model-test-utils' export * from './helpers/provider-test-utils' export * from './helpers/test-utils' diff --git a/packages/aiCore/src/core/models/__tests__/ModelResolver.test.ts b/packages/aiCore/src/core/models/__tests__/ModelResolver.test.ts new file mode 100644 index 0000000000..a2ee74fd00 --- /dev/null +++ b/packages/aiCore/src/core/models/__tests__/ModelResolver.test.ts @@ -0,0 +1,454 @@ +/** + * ModelResolver Comprehensive Tests + * Tests model resolution logic for language, embedding, and image models + * Covers both traditional and namespaced format resolution + */ + +import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import { + createMockEmbeddingModel, + createMockImageModel, + createMockLanguageModel, + createMockMiddleware +} from '../../../__tests__' +import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../../providers/RegistryManagement' +import { ModelResolver } from '../ModelResolver' + +// Mock the dependencies +vi.mock('../../providers/RegistryManagement', () => ({ + globalRegistryManagement: { + languageModel: vi.fn(), + embeddingModel: vi.fn(), + imageModel: vi.fn() + }, + DEFAULT_SEPARATOR: '|' +})) + +vi.mock('../../middleware/wrapper', () => ({ + wrapModelWithMiddlewares: vi.fn((model: LanguageModelV3) => { + // Return a wrapped model with a marker + return { + ...model, + _wrapped: true + } as LanguageModelV3 + }) +})) + +describe('ModelResolver', () => { + let resolver: ModelResolver + let mockLanguageModel: LanguageModelV3 + let mockEmbeddingModel: EmbeddingModelV3 + let mockImageModel: ImageModelV3 + + beforeEach(() => { + vi.clearAllMocks() + resolver = new ModelResolver() + + // Create properly typed mock models using global utilities + mockLanguageModel = createMockLanguageModel({ + provider: 'test-provider', + modelId: 'test-model' + }) + + mockEmbeddingModel = createMockEmbeddingModel({ + provider: 'test-provider', + modelId: 'test-embedding' + }) + + mockImageModel = createMockImageModel({ + provider: 'test-provider', + modelId: 'test-image' + }) + + // Setup default mock implementations + vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel) + vi.mocked(globalRegistryManagement.embeddingModel).mockReturnValue(mockEmbeddingModel) + vi.mocked(globalRegistryManagement.imageModel).mockReturnValue(mockImageModel) + }) + + describe('resolveLanguageModel', () => { + describe('Traditional Format Resolution', () => { + it('should resolve traditional format modelId without separator', async () => { + const result = await resolver.resolveLanguageModel('gpt-4', 'openai') + + expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(`openai${DEFAULT_SEPARATOR}gpt-4`) + expect(result).toBe(mockLanguageModel) + }) + + it('should resolve with different provider and modelId combinations', async () => { + const testCases: Array<{ modelId: string; providerId: string; expected: string }> = [ + { modelId: 'claude-3-5-sonnet', providerId: 'anthropic', expected: 'anthropic|claude-3-5-sonnet' }, + { modelId: 'gemini-2.0-flash', providerId: 'google', expected: 'google|gemini-2.0-flash' }, + { modelId: 'grok-2-latest', providerId: 'xai', expected: 'xai|grok-2-latest' }, + { modelId: 'deepseek-chat', providerId: 'deepseek', expected: 'deepseek|deepseek-chat' } + ] + + for (const testCase of testCases) { + vi.clearAllMocks() + await resolver.resolveLanguageModel(testCase.modelId, testCase.providerId) + + expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(testCase.expected) + } + }) + + it('should handle modelIds with special characters', async () => { + const modelIds = ['model-v1.0', 'model_v2', 'model.2024', 'model:free'] + + for (const modelId of modelIds) { + vi.clearAllMocks() + await resolver.resolveLanguageModel(modelId, 'provider') + + expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(`provider${DEFAULT_SEPARATOR}${modelId}`) + } + }) + }) + + describe('Namespaced Format Resolution', () => { + it('should resolve namespaced format with hub', async () => { + const namespacedId = `aihubmix${DEFAULT_SEPARATOR}anthropic${DEFAULT_SEPARATOR}claude-3-5-sonnet` + + const result = await resolver.resolveLanguageModel(namespacedId, 'openai') + + expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(namespacedId) + expect(result).toBe(mockLanguageModel) + }) + + it('should resolve simple namespaced format', async () => { + const namespacedId = `provider${DEFAULT_SEPARATOR}model-id` + + await resolver.resolveLanguageModel(namespacedId, 'fallback-provider') + + expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(namespacedId) + }) + + it('should handle complex namespaced IDs', async () => { + const complexIds = [ + `hub${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}model`, + `hub${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}model-v1.0`, + `custom${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}gpt-4-turbo` + ] + + for (const id of complexIds) { + vi.clearAllMocks() + await resolver.resolveLanguageModel(id, 'fallback') + + expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(id) + } + }) + }) + + describe('OpenAI Mode Selection', () => { + it('should append "-chat" suffix for OpenAI provider with chat mode', async () => { + await resolver.resolveLanguageModel('gpt-4', 'openai', { mode: 'chat' }) + + expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai-chat|gpt-4') + }) + + it('should append "-chat" suffix for Azure provider with chat mode', async () => { + await resolver.resolveLanguageModel('gpt-4', 'azure', { mode: 'chat' }) + + expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('azure-chat|gpt-4') + }) + + it('should not append suffix for OpenAI with responses mode', async () => { + await resolver.resolveLanguageModel('gpt-4', 'openai', { mode: 'responses' }) + + expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|gpt-4') + }) + + it('should not append suffix for OpenAI without mode', async () => { + await resolver.resolveLanguageModel('gpt-4', 'openai') + + expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|gpt-4') + }) + + it('should not append suffix for other providers with chat mode', async () => { + await resolver.resolveLanguageModel('claude-3', 'anthropic', { mode: 'chat' }) + + expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('anthropic|claude-3') + }) + + it('should handle namespaced IDs with OpenAI chat mode', async () => { + const namespacedId = `hub${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}gpt-4` + + await resolver.resolveLanguageModel(namespacedId, 'openai', { mode: 'chat' }) + + // Should use the namespaced ID directly, not apply mode logic + expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(namespacedId) + }) + }) + + describe('Middleware Application', () => { + it('should apply middlewares to resolved model', async () => { + const mockMiddleware = createMockMiddleware({ name: 'test-middleware' }) + + const result = await resolver.resolveLanguageModel('gpt-4', 'openai', undefined, [mockMiddleware]) + + expect(result).toHaveProperty('_wrapped', true) + }) + + it('should apply multiple middlewares in order', async () => { + const middleware1 = createMockMiddleware({ name: 'middleware-1' }) + const middleware2 = createMockMiddleware({ name: 'middleware-2' }) + + const result = await resolver.resolveLanguageModel('gpt-4', 'openai', undefined, [middleware1, middleware2]) + + expect(result).toHaveProperty('_wrapped', true) + }) + + it('should not apply middlewares when none provided', async () => { + const result = await resolver.resolveLanguageModel('gpt-4', 'openai') + + expect(result).not.toHaveProperty('_wrapped') + expect(result).toBe(mockLanguageModel) + }) + + it('should not apply middlewares when empty array provided', async () => { + const result = await resolver.resolveLanguageModel('gpt-4', 'openai', undefined, []) + + expect(result).not.toHaveProperty('_wrapped') + }) + }) + + describe('Provider Options Handling', () => { + it('should pass provider options correctly', async () => { + const options = { baseURL: 'https://api.example.com', apiKey: 'test-key' } + + await resolver.resolveLanguageModel('gpt-4', 'openai', options) + + // Provider options are used for mode selection logic + expect(globalRegistryManagement.languageModel).toHaveBeenCalled() + }) + + it('should handle empty provider options', async () => { + await resolver.resolveLanguageModel('gpt-4', 'openai', {}) + + expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|gpt-4') + }) + + it('should handle undefined provider options', async () => { + await resolver.resolveLanguageModel('gpt-4', 'openai', undefined) + + expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|gpt-4') + }) + }) + }) + + describe('resolveTextEmbeddingModel', () => { + describe('Traditional Format', () => { + it('should resolve traditional embedding model ID', async () => { + const result = await resolver.resolveTextEmbeddingModel('text-embedding-ada-002', 'openai') + + expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith('openai|text-embedding-ada-002') + expect(result).toBe(mockEmbeddingModel) + }) + + it('should resolve different embedding models', async () => { + const testCases = [ + { modelId: 'text-embedding-3-small', providerId: 'openai' }, + { modelId: 'text-embedding-3-large', providerId: 'openai' }, + { modelId: 'embed-english-v3.0', providerId: 'cohere' }, + { modelId: 'voyage-2', providerId: 'voyage' } + ] + + for (const { modelId, providerId } of testCases) { + vi.clearAllMocks() + await resolver.resolveTextEmbeddingModel(modelId, providerId) + + expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith(`${providerId}|${modelId}`) + } + }) + }) + + describe('Namespaced Format', () => { + it('should resolve namespaced embedding model ID', async () => { + const namespacedId = `aihubmix${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}text-embedding-3-small` + + const result = await resolver.resolveTextEmbeddingModel(namespacedId, 'openai') + + expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith(namespacedId) + expect(result).toBe(mockEmbeddingModel) + }) + + it('should handle complex namespaced embedding IDs', async () => { + const complexIds = [ + `hub${DEFAULT_SEPARATOR}cohere${DEFAULT_SEPARATOR}embed-multilingual`, + `custom${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}embedding-model` + ] + + for (const id of complexIds) { + vi.clearAllMocks() + await resolver.resolveTextEmbeddingModel(id, 'fallback') + + expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith(id) + } + }) + }) + }) + + describe('resolveImageModel', () => { + describe('Traditional Format', () => { + it('should resolve traditional image model ID', async () => { + const result = await resolver.resolveImageModel('dall-e-3', 'openai') + + expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('openai|dall-e-3') + expect(result).toBe(mockImageModel) + }) + + it('should resolve different image models', async () => { + const testCases = [ + { modelId: 'dall-e-2', providerId: 'openai' }, + { modelId: 'stable-diffusion-xl', providerId: 'stability' }, + { modelId: 'imagen-2', providerId: 'google' }, + { modelId: 'midjourney-v6', providerId: 'midjourney' } + ] + + for (const { modelId, providerId } of testCases) { + vi.clearAllMocks() + await resolver.resolveImageModel(modelId, providerId) + + expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith(`${providerId}|${modelId}`) + } + }) + }) + + describe('Namespaced Format', () => { + it('should resolve namespaced image model ID', async () => { + const namespacedId = `aihubmix${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}dall-e-3` + + const result = await resolver.resolveImageModel(namespacedId, 'openai') + + expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith(namespacedId) + expect(result).toBe(mockImageModel) + }) + + it('should handle complex namespaced image IDs', async () => { + const complexIds = [ + `hub${DEFAULT_SEPARATOR}stability${DEFAULT_SEPARATOR}sdxl-turbo`, + `custom${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}image-gen-v2` + ] + + for (const id of complexIds) { + vi.clearAllMocks() + await resolver.resolveImageModel(id, 'fallback') + + expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith(id) + } + }) + }) + }) + + describe('Edge Cases and Error Scenarios', () => { + it('should handle empty model IDs', async () => { + await resolver.resolveLanguageModel('', 'openai') + + expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|') + }) + + it('should handle model IDs with multiple separators', async () => { + const multiSeparatorId = `hub${DEFAULT_SEPARATOR}sub${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}model` + + await resolver.resolveLanguageModel(multiSeparatorId, 'fallback') + + expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(multiSeparatorId) + }) + + it('should handle model IDs with only separator', async () => { + const onlySeparator = DEFAULT_SEPARATOR + + await resolver.resolveLanguageModel(onlySeparator, 'provider') + + expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(onlySeparator) + }) + + it('should throw if globalRegistryManagement throws', async () => { + const error = new Error('Model not found in registry') + vi.mocked(globalRegistryManagement.languageModel).mockImplementation(() => { + throw error + }) + + await expect(resolver.resolveLanguageModel('invalid-model', 'openai')).rejects.toThrow( + 'Model not found in registry' + ) + }) + + it('should handle concurrent resolution requests', async () => { + const promises = [ + resolver.resolveLanguageModel('gpt-4', 'openai'), + resolver.resolveLanguageModel('claude-3', 'anthropic'), + resolver.resolveLanguageModel('gemini-2.0', 'google') + ] + + const results = await Promise.all(promises) + + expect(results).toHaveLength(3) + expect(globalRegistryManagement.languageModel).toHaveBeenCalledTimes(3) + }) + }) + + describe('Type Safety', () => { + it('should return properly typed LanguageModelV3', async () => { + const result = await resolver.resolveLanguageModel('gpt-4', 'openai') + + // Type assertions + expect(result.specificationVersion).toBe('v3') + expect(result).toHaveProperty('doGenerate') + expect(result).toHaveProperty('doStream') + }) + + it('should return properly typed EmbeddingModelV3', async () => { + const result = await resolver.resolveTextEmbeddingModel('text-embedding-ada-002', 'openai') + + expect(result.specificationVersion).toBe('v3') + expect(result).toHaveProperty('doEmbed') + }) + + it('should return properly typed ImageModelV3', async () => { + const result = await resolver.resolveImageModel('dall-e-3', 'openai') + + expect(result.specificationVersion).toBe('v3') + expect(result).toHaveProperty('doGenerate') + }) + }) + + describe('Global ModelResolver Instance', () => { + it('should have a global instance available', async () => { + const { globalModelResolver } = await import('../ModelResolver') + + expect(globalModelResolver).toBeInstanceOf(ModelResolver) + }) + }) + + describe('Integration with Different Provider Types', () => { + it('should work with OpenAI compatible providers', async () => { + await resolver.resolveLanguageModel('custom-model', 'openai-compatible') + + expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai-compatible|custom-model') + }) + + it('should work with hub providers', async () => { + const hubId = `aihubmix${DEFAULT_SEPARATOR}custom${DEFAULT_SEPARATOR}model-v1` + + await resolver.resolveLanguageModel(hubId, 'aihubmix') + + expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(hubId) + }) + + it('should handle all model types for same provider', async () => { + const providerId = 'openai' + const languageModel = 'gpt-4' + const embeddingModel = 'text-embedding-3-small' + const imageModel = 'dall-e-3' + + await resolver.resolveLanguageModel(languageModel, providerId) + await resolver.resolveTextEmbeddingModel(embeddingModel, providerId) + await resolver.resolveImageModel(imageModel, providerId) + + expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(`${providerId}|${languageModel}`) + expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith(`${providerId}|${embeddingModel}`) + expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith(`${providerId}|${imageModel}`) + }) + }) +}) diff --git a/packages/aiCore/src/core/models/__tests__/utils.test.ts b/packages/aiCore/src/core/models/__tests__/utils.test.ts new file mode 100644 index 0000000000..90d9d01854 --- /dev/null +++ b/packages/aiCore/src/core/models/__tests__/utils.test.ts @@ -0,0 +1,171 @@ +import type { LanguageModelV2, LanguageModelV3 } from '@ai-sdk/provider' +import { describe, expect, it } from 'vitest' + +import type { AiSdkModel } from '../../providers' +import { hasModelId, isV2Model, isV3Model } from '../utils' + +describe('Model Type Guards', () => { + describe('isV2Model', () => { + it('should return true for V2 models', () => { + const v2Model: AiSdkModel = { + specificationVersion: 'v2', + modelId: 'test-model', + provider: 'test-provider' + } as LanguageModelV2 + + expect(isV2Model(v2Model)).toBe(true) + }) + + it('should return false for V3 models', () => { + const v3Model: AiSdkModel = { + specificationVersion: 'v3', + modelId: 'test-model', + provider: 'test-provider' + } as LanguageModelV3 + + expect(isV2Model(v3Model)).toBe(false) + }) + + it('should return false for non-object values', () => { + expect(isV2Model('model-id' as any)).toBe(false) + expect(isV2Model(null as any)).toBe(false) + expect(isV2Model(undefined as any)).toBe(false) + expect(isV2Model(123 as any)).toBe(false) + }) + + it('should return false for objects without specificationVersion', () => { + const invalidModel = { + modelId: 'test-model', + provider: 'test-provider' + } as any + + expect(isV2Model(invalidModel)).toBe(false) + }) + }) + + describe('isV3Model', () => { + it('should return true for V3 models', () => { + const v3Model: AiSdkModel = { + specificationVersion: 'v3', + modelId: 'test-model', + provider: 'test-provider' + } as LanguageModelV3 + + expect(isV3Model(v3Model)).toBe(true) + }) + + it('should return false for V2 models', () => { + const v2Model: AiSdkModel = { + specificationVersion: 'v2', + modelId: 'test-model', + provider: 'test-provider' + } as LanguageModelV2 + + expect(isV3Model(v2Model)).toBe(false) + }) + + it('should return false for non-object values', () => { + expect(isV3Model('model-id' as any)).toBe(false) + expect(isV3Model(null as any)).toBe(false) + expect(isV3Model(undefined as any)).toBe(false) + }) + + it('should return false for objects without specificationVersion', () => { + const invalidModel = { + modelId: 'test-model', + provider: 'test-provider' + } as any + + expect(isV3Model(invalidModel)).toBe(false) + }) + }) + + describe('Type Guard Correctness', () => { + it('should correctly distinguish between V2 and V3 models', () => { + const v2Model: AiSdkModel = { + specificationVersion: 'v2', + modelId: 'v2-model' + } as LanguageModelV2 + + const v3Model: AiSdkModel = { + specificationVersion: 'v3', + modelId: 'v3-model' + } as LanguageModelV3 + + // V2 model should only match isV2Model + expect(isV2Model(v2Model)).toBe(true) + expect(isV3Model(v2Model)).toBe(false) + + // V3 model should only match isV3Model + expect(isV2Model(v3Model)).toBe(false) + expect(isV3Model(v3Model)).toBe(true) + }) + + it('should narrow type correctly for V2 models', () => { + const model: AiSdkModel = { + specificationVersion: 'v2', + modelId: 'test' + } as LanguageModelV2 + + if (isV2Model(model)) { + expect(model.specificationVersion).toBe('v2') + } + }) + + it('should narrow type correctly for V3 models', () => { + const model: AiSdkModel = { + specificationVersion: 'v3', + modelId: 'test' + } as LanguageModelV3 + + if (isV3Model(model)) { + expect(model.specificationVersion).toBe('v3') + } + }) + }) + + describe('hasModelId', () => { + it('should return true for objects with modelId string property', () => { + const modelWithId = { + modelId: 'test-model-id', + other: 'property' + } + + expect(hasModelId(modelWithId)).toBe(true) + }) + + it('should return false for objects without modelId property', () => { + const modelWithoutId = { + other: 'property' + } + + expect(hasModelId(modelWithoutId)).toBe(false) + }) + + it('should return false for objects with non-string modelId', () => { + const modelWithNumericId = { + modelId: 123 + } + + expect(hasModelId(modelWithNumericId)).toBe(false) + }) + + it('should return false for non-object values', () => { + expect(hasModelId(null)).toBe(false) + expect(hasModelId(undefined)).toBe(false) + expect(hasModelId('string')).toBe(false) + expect(hasModelId(123)).toBe(false) + expect(hasModelId(true)).toBe(false) + }) + + it('should narrow type correctly', () => { + const unknownValue: unknown = { modelId: 'test-id' } + + if (hasModelId(unknownValue)) { + // TypeScript should allow accessing modelId as string + expect(typeof unknownValue.modelId).toBe('string') + expect(unknownValue.modelId).toBe('test-id') + } + }) + }) +}) diff --git a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/StreamEventManager.ts b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/StreamEventManager.ts index 4a3025b39e..a27f168d52 100644 --- a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/StreamEventManager.ts +++ b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/StreamEventManager.ts @@ -4,6 +4,7 @@ * 负责处理 AI SDK 流事件的发送和管理 * 从 promptToolUsePlugin.ts 中提取出来以降低复杂度 */ +import type { SharedV3ProviderMetadata } from '@ai-sdk/provider' import type { EmbeddingModelUsage, ImageModelUsage, LanguageModelUsage, ModelMessage } from 'ai' import type { AiSdkUsage } from '../../../providers/types' @@ -79,7 +80,11 @@ export class StreamEventManager { */ sendStepFinishEvent( controller: StreamController, - chunk: any, + chunk: { + usage?: Partial + response?: { id: string; [key: string]: unknown } + providerMetadata?: SharedV3ProviderMetadata + }, context: AiRequestContext, finishReason: string = 'stop' ): void { @@ -154,7 +159,7 @@ export class StreamEventManager { context: AiRequestContext, textBuffer: string, toolResultsText: string, - tools: any + tools: Record ): Partial { const params = context.originalParams diff --git a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/ToolExecutor.ts b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/ToolExecutor.ts index 29d644554e..d788bcc68d 100644 --- a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/ToolExecutor.ts +++ b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/ToolExecutor.ts @@ -14,15 +14,16 @@ import type { ToolUseResult } from './type' export interface ExecutedResult { toolCallId: string toolName: string - result: any + result: unknown isError?: boolean } /** * 流控制器类型(从 AI SDK 提取) + * Generic type parameter allows for type-safe chunk enqueuing */ -export interface StreamController { - enqueue(chunk: any): void +export interface StreamController { + enqueue(chunk: TChunk): void } /** diff --git a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/__tests__/StreamEventManager.test.ts b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/__tests__/StreamEventManager.test.ts new file mode 100644 index 0000000000..cfb2c3df85 --- /dev/null +++ b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/__tests__/StreamEventManager.test.ts @@ -0,0 +1,566 @@ +import type { SharedV3ProviderMetadata } from '@ai-sdk/provider' +import type { + EmbeddingModelUsage, + ImageModelUsage, + LanguageModelUsage as AiSdkUsage, + LanguageModelUsage, + TextStreamPart, + ToolSet +} from 'ai' +import { simulateReadableStream } from 'ai' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import { createMockContext, createMockTool } from '../../../../../__tests__' +import { StreamEventManager } from '../StreamEventManager' +import type { StreamController } from '../ToolExecutor' + +/** + * Type alias for empty toolset (no tools) + * Using Record ensures type safety for tests without tools + */ +type EmptyToolSet = Record + +/** + * Mock StreamController for testing + * Provides type-safe enqueue function that accepts TextStreamPart chunks + */ +interface MockStreamController extends StreamController { + enqueue: ReturnType) => void>> +} + +/** + * Create a type-safe mock stream controller + */ +function createMockStreamController(): MockStreamController { + return { + enqueue: vi.fn() + } +} + +/** + * Type for chunk data in finish-step events + */ +interface FinishStepChunk { + usage?: Partial + response?: { id: string; [key: string]: unknown } + providerMetadata?: SharedV3ProviderMetadata +} + +describe('StreamEventManager', () => { + let manager: StreamEventManager + + beforeEach(() => { + manager = new StreamEventManager() + }) + + describe('accumulateUsage', () => { + describe('LanguageModelUsage', () => { + it('should accumulate language model usage correctly', () => { + const target: Partial = { + inputTokens: 10, + outputTokens: 20, + totalTokens: 30 + } + const source: Partial = { + inputTokens: 5, + outputTokens: 10, + totalTokens: 15 + } + + manager.accumulateUsage(target, source) + + expect(target.inputTokens).toBe(15) + expect(target.outputTokens).toBe(30) + expect(target.totalTokens).toBe(45) + }) + + it('should handle undefined values in target', () => { + const target: Partial = { inputTokens: 10 } + const source: Partial = { + inputTokens: 5, + outputTokens: 10, + totalTokens: 15 + } + + manager.accumulateUsage(target, source) + + expect(target.inputTokens).toBe(15) + expect(target.outputTokens).toBe(10) + expect(target.totalTokens).toBe(15) + }) + + it('should handle undefined values in source', () => { + const target: Partial = { + inputTokens: 10, + outputTokens: 20, + totalTokens: 30 + } + const source: Partial = { inputTokens: 5 } + + manager.accumulateUsage(target, source) + + expect(target.inputTokens).toBe(15) + expect(target.outputTokens).toBe(20) + expect(target.totalTokens).toBe(30) + }) + + it('should handle zero values correctly', () => { + const target: Partial = { + inputTokens: 0, + outputTokens: 0, + totalTokens: 0 + } + const source: Partial = { + inputTokens: 5, + outputTokens: 10, + totalTokens: 15 + } + + manager.accumulateUsage(target, source) + + expect(target.inputTokens).toBe(5) + expect(target.outputTokens).toBe(10) + expect(target.totalTokens).toBe(15) + }) + }) + + describe('ImageModelUsage', () => { + it('should accumulate image model usage correctly', () => { + const target: Partial = { + inputTokens: 100, + outputTokens: 50, + totalTokens: 150 + } + const source: Partial = { + inputTokens: 50, + outputTokens: 25, + totalTokens: 75 + } + + manager.accumulateUsage(target, source) + + expect(target.inputTokens).toBe(150) + expect(target.outputTokens).toBe(75) + expect(target.totalTokens).toBe(225) + }) + + it('should handle undefined values', () => { + const target: Partial = { inputTokens: 100 } + const source: Partial = { + outputTokens: 50, + totalTokens: 50 + } + + manager.accumulateUsage(target, source) + + expect(target.inputTokens).toBe(100) + expect(target.outputTokens).toBe(50) + expect(target.totalTokens).toBe(50) + }) + }) + + describe('EmbeddingModelUsage', () => { + it('should accumulate embedding model usage correctly', () => { + const target: Partial = { tokens: 100 } + const source: Partial = { tokens: 50 } + + manager.accumulateUsage(target, source) + + expect(target.tokens).toBe(150) + }) + + it('should handle zero to non-zero accumulation', () => { + const target: Partial = { tokens: 0 } + const source: Partial = { tokens: 50 } + + manager.accumulateUsage(target, source) + + expect(target.tokens).toBe(50) + }) + + it('should handle zero values', () => { + const target: Partial = { tokens: 0 } + const source: Partial = { tokens: 100 } + + manager.accumulateUsage(target, source) + + expect(target.tokens).toBe(100) + }) + }) + + describe('Type Guard Validation', () => { + it('should warn on type mismatch between LanguageModelUsage and EmbeddingModelUsage', () => { + const warnSpy = vi.spyOn(console, 'warn') + const target: Partial = { inputTokens: 10 } + const source: Partial = { tokens: 5 } + + manager.accumulateUsage(target, source) + + expect(warnSpy).toHaveBeenCalledWith( + expect.stringContaining('Unable to accumulate usage'), + expect.objectContaining({ + target, + source + }) + ) + + warnSpy.mockRestore() + }) + + it('should warn on type mismatch between ImageModelUsage and EmbeddingModelUsage', () => { + const warnSpy = vi.spyOn(console, 'warn') + const target: Partial = { inputTokens: 100 } + const source: Partial = { tokens: 50 } + + manager.accumulateUsage(target, source) + + expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining('Unable to accumulate usage'), expect.any(Object)) + + warnSpy.mockRestore() + }) + }) + }) + + describe('buildRecursiveParams', () => { + it('should include textBuffer in assistant message when not empty', () => { + const context = createMockContext() + const textBuffer = 'test response' + const toolResultsText = '...' + const tools = { + test_tool: createMockTool('test_tool') + } + + const params = manager.buildRecursiveParams(context, textBuffer, toolResultsText, tools) + + expect(params.messages).toHaveLength(3) + expect(params.messages?.[0]).toEqual({ role: 'user', content: 'Test message' }) + expect(params.messages?.[1]).toEqual({ role: 'assistant', content: textBuffer }) + expect(params.messages?.[2]).toEqual({ + role: 'user', + content: toolResultsText + }) + expect(params.tools).toBe(tools) + }) + + it('should skip empty textBuffer in messages', () => { + const context = createMockContext() + const textBuffer = '' + const toolResultsText = '...' + const tools = {} + + const params = manager.buildRecursiveParams(context, textBuffer, toolResultsText, tools) + + // Should only have original user message and new user message with tool results + expect(params.messages).toHaveLength(2) + expect(params.messages?.[0]).toEqual({ role: 'user', content: 'Test message' }) + expect(params.messages?.[1]).toEqual({ + role: 'user', + content: toolResultsText + }) + + const assistantMessages = params.messages?.filter((m) => m.role === 'assistant') + expect(assistantMessages).toHaveLength(0) + }) + + it('should preserve all original messages', () => { + const context = createMockContext({ + originalParams: { + messages: [ + { role: 'user', content: 'First message' }, + { role: 'assistant', content: 'First response' }, + { role: 'user', content: 'Second message' } + ] + } + }) + + const params = manager.buildRecursiveParams(context, 'New response', 'Tool results', {}) + + expect(params.messages).toHaveLength(5) + expect(params.messages?.[0]).toEqual({ role: 'user', content: 'First message' }) + expect(params.messages?.[1]).toEqual({ + role: 'assistant', + content: 'First response' + }) + expect(params.messages?.[2]).toEqual({ role: 'user', content: 'Second message' }) + expect(params.messages?.[3]).toEqual({ role: 'assistant', content: 'New response' }) + expect(params.messages?.[4]).toEqual({ role: 'user', content: 'Tool results' }) + }) + + it('should pass through tools parameter', () => { + const context = createMockContext() + const tools = { + tool1: createMockTool('tool1'), + tool2: createMockTool('tool2') + } + + const params = manager.buildRecursiveParams(context, 'response', 'results', tools) + + expect(params.tools).toBe(tools) + expect(Object.keys(params.tools!)).toHaveLength(2) + }) + }) + + describe('sendStepStartEvent', () => { + it('should enqueue start-step event with correct structure', () => { + const controller = createMockStreamController() + + manager.sendStepStartEvent(controller) + + expect(controller.enqueue).toHaveBeenCalledWith({ + type: 'start-step', + request: {}, + warnings: [] + }) + }) + }) + + describe('sendStepFinishEvent', () => { + it('should enqueue finish-step event with provided finishReason', () => { + const controller = createMockStreamController() + + const chunk: FinishStepChunk = { + usage: { + inputTokens: 10, + outputTokens: 20, + totalTokens: 30 + }, + response: { id: 'test-response' }, + providerMetadata: { 'test-provider': {} } + } + + const context = createMockContext({ + accumulatedUsage: { + inputTokens: 0, + outputTokens: 0, + totalTokens: 0 + } + }) + + manager.sendStepFinishEvent(controller, chunk, context, 'tool-calls') + + expect(controller.enqueue).toHaveBeenCalledWith({ + type: 'finish-step', + finishReason: 'tool-calls', + response: chunk.response, + usage: chunk.usage, + providerMetadata: chunk.providerMetadata + }) + }) + + it('should accumulate usage when provided', () => { + const controller = createMockStreamController() + + const chunk: FinishStepChunk = { + usage: { + inputTokens: 10, + outputTokens: 20, + totalTokens: 30 + } + } + + const context = createMockContext({ + accumulatedUsage: { + inputTokens: 5, + outputTokens: 10, + totalTokens: 15 + } + }) + + manager.sendStepFinishEvent(controller, chunk, context) + + // Verify accumulation happened + expect(context.accumulatedUsage.inputTokens).toBe(15) + expect(context.accumulatedUsage.outputTokens).toBe(30) + expect(context.accumulatedUsage.totalTokens).toBe(45) + }) + + it('should handle missing usage gracefully', () => { + const controller = createMockStreamController() + + const chunk: FinishStepChunk = {} + const context = createMockContext({ + accumulatedUsage: { + inputTokens: 5, + outputTokens: 10, + totalTokens: 15 + } + }) + + expect(() => manager.sendStepFinishEvent(controller, chunk, context)).not.toThrow() + + // Verify accumulation did not change + expect(context.accumulatedUsage.inputTokens).toBe(5) + expect(context.accumulatedUsage.outputTokens).toBe(10) + expect(context.accumulatedUsage.totalTokens).toBe(15) + }) + + it('should use default finishReason of "stop" when not provided', () => { + const controller = createMockStreamController() + + const chunk: FinishStepChunk = {} + const context = createMockContext() + + manager.sendStepFinishEvent(controller, chunk, context) + + expect(controller.enqueue).toHaveBeenCalledWith( + expect.objectContaining({ + finishReason: 'stop' + }) + ) + }) + }) + + describe('handleRecursiveCall', () => { + it('should reset hasExecutedToolsInCurrentStep flag', async () => { + const controller = createMockStreamController() + + const mockStream = simulateReadableStream>({ + chunks: [ + { + type: 'text-delta', + id: 'test-id', + text: 'test' + } as TextStreamPart + ], + initialDelayInMs: 0, + chunkDelayInMs: 0 + }) + + const context = createMockContext({ + hasExecutedToolsInCurrentStep: true, + recursiveCall: vi.fn().mockResolvedValue({ + fullStream: mockStream + }) + }) + + const params = { messages: [] } + + await manager.handleRecursiveCall(controller, params, context) + + expect(context.hasExecutedToolsInCurrentStep).toBe(false) + expect(context.recursiveCall).toHaveBeenCalledWith(params) + }) + + it('should pipe recursive stream to controller', async () => { + const enqueuedChunks: TextStreamPart[] = [] + const controller = createMockStreamController() + controller.enqueue.mockImplementation((chunk: TextStreamPart) => { + enqueuedChunks.push(chunk) + }) + + const mockChunks: TextStreamPart[] = [ + { type: 'start' as const }, + { type: 'start-step' as const, request: {}, warnings: [] }, + { type: 'text-delta' as const, id: 'chunk-1', text: 'recursive' }, + { type: 'text-delta' as const, id: 'chunk-2', text: ' response' }, + { + type: 'finish-step' as const, + finishReason: 'stop', + rawFinishReason: 'stop', + response: { + id: 'test-response-id', + timestamp: new Date(), + modelId: 'test-model' + }, + usage: { + totalTokens: 0, + inputTokens: 0, + outputTokens: 0, + inputTokenDetails: { + noCacheTokens: 0, + cacheReadTokens: 0, + cacheWriteTokens: 0 + }, + outputTokenDetails: { + textTokens: 0, + reasoningTokens: 0 + } + }, + providerMetadata: undefined + }, + { + type: 'finish' as const, + finishReason: 'stop', + rawFinishReason: 'stop', + totalUsage: { + totalTokens: 0, + inputTokens: 0, + outputTokens: 0, + inputTokenDetails: { + noCacheTokens: 0, + cacheReadTokens: 0, + cacheWriteTokens: 0 + }, + outputTokenDetails: { + textTokens: 0, + reasoningTokens: 0 + } + } + } + ] + + const mockStream = simulateReadableStream>({ + chunks: mockChunks, + initialDelayInMs: 0, + chunkDelayInMs: 0 + }) + + const context = createMockContext({ + hasExecutedToolsInCurrentStep: true, + recursiveCall: vi.fn().mockResolvedValue({ + fullStream: mockStream + }) + }) + + await manager.handleRecursiveCall(controller, {}, context) + + // Should skip 'start' type and stop at 'finish' type + expect(enqueuedChunks).toHaveLength(4) + expect(enqueuedChunks[0]).toEqual({ type: 'start-step', request: {}, warnings: [] }) + expect(enqueuedChunks[1]).toEqual({ type: 'text-delta', id: 'chunk-1', text: 'recursive' }) + expect(enqueuedChunks[2]).toEqual({ type: 'text-delta', id: 'chunk-2', text: ' response' }) + expect(enqueuedChunks[3]).toMatchObject({ + type: 'finish-step', + finishReason: 'stop', + rawFinishReason: 'stop', + providerMetadata: undefined, + usage: { + totalTokens: 0, + inputTokens: 0, + outputTokens: 0, + inputTokenDetails: { + noCacheTokens: 0, + cacheReadTokens: 0, + cacheWriteTokens: 0 + }, + outputTokenDetails: { + textTokens: 0, + reasoningTokens: 0 + } + } + }) + }) + + it('should warn when no fullStream is found', async () => { + const warnSpy = vi.spyOn(console, 'warn') + const controller = createMockStreamController() + + const context = createMockContext({ + hasExecutedToolsInCurrentStep: true, + recursiveCall: vi.fn().mockResolvedValue({ + // No fullStream property + someOtherProperty: 'value' + }) + }) + + await manager.handleRecursiveCall(controller, {}, context) + + expect(warnSpy).toHaveBeenCalledWith( + expect.stringContaining('[MCP Prompt] No fullstream found'), + expect.any(Object) + ) + + warnSpy.mockRestore() + }) + }) +}) diff --git a/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/__tests__/promptToolUsePlugin.test.ts b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/__tests__/promptToolUsePlugin.test.ts new file mode 100644 index 0000000000..1bfdd1250a --- /dev/null +++ b/packages/aiCore/src/core/plugins/built-in/toolUsePlugin/__tests__/promptToolUsePlugin.test.ts @@ -0,0 +1,547 @@ +import type { TextStreamPart, ToolSet } from 'ai' +import { simulateReadableStream } from 'ai' +import { convertReadableStreamToArray } from 'ai/test' +import { describe, expect, it, vi } from 'vitest' + +import { createMockContext, createMockStreamParams, createMockTool, createMockToolSet } from '../../../../../__tests__' +import { createPromptToolUsePlugin, DEFAULT_SYSTEM_PROMPT } from '../promptToolUsePlugin' + +describe('promptToolUsePlugin', () => { + describe('Factory Function', () => { + it('should return AiPlugin with correct name', () => { + const plugin = createPromptToolUsePlugin() + + expect(plugin.name).toBe('built-in:prompt-tool-use') + expect(plugin.transformParams).toBeDefined() + expect(plugin.transformStream).toBeDefined() + }) + + it('should accept empty configuration', () => { + const plugin = createPromptToolUsePlugin({}) + + expect(plugin).toBeDefined() + expect(plugin.name).toBe('built-in:prompt-tool-use') + }) + + it('should accept custom buildSystemPrompt', () => { + const customBuildSystemPrompt = vi.fn((userSystemPrompt: string) => userSystemPrompt) + + const plugin = createPromptToolUsePlugin({ + buildSystemPrompt: customBuildSystemPrompt + }) + + expect(plugin).toBeDefined() + }) + + it('should accept custom parseToolUse', () => { + const customParseToolUse = vi.fn(() => ({ results: [], content: '' })) + + const plugin = createPromptToolUsePlugin({ + parseToolUse: customParseToolUse + }) + + expect(plugin).toBeDefined() + }) + + it('should accept enabled flag', () => { + const pluginDisabled = createPromptToolUsePlugin({ enabled: false }) + const pluginEnabled = createPromptToolUsePlugin({ enabled: true }) + + expect(pluginDisabled).toBeDefined() + expect(pluginEnabled).toBeDefined() + }) + }) + + describe('transformParams', () => { + it('should separate provider and prompt tools', async () => { + const plugin = createPromptToolUsePlugin() + const context = createMockContext() + const params = createMockStreamParams({ + tools: createMockToolSet({ + provider_tool: 'provider', + prompt_tool: 'function' + }) + }) + + const result = await Promise.resolve(plugin.transformParams!(params, context)) + + // Provider tools should remain in tools + expect(result.tools).toBeDefined() + expect(result.tools).toHaveProperty('provider_tool') + expect(result.tools).not.toHaveProperty('prompt_tool') + + // Prompt tools should be moved to context.mcpTools + expect(context.mcpTools).toBeDefined() + expect(context.mcpTools).toHaveProperty('prompt_tool') + expect(context.mcpTools).not.toHaveProperty('provider_tool') + }) + + it('should handle only provider tools', async () => { + const plugin = createPromptToolUsePlugin() + const context = createMockContext() + const params = createMockStreamParams({ + tools: createMockToolSet({ + provider_tool1: 'provider', + provider_tool2: 'provider' + }) + }) + + const result = await Promise.resolve(plugin.transformParams!(params, context)) + + expect(result.tools).toEqual(params.tools) + expect(context.mcpTools).toBeUndefined() + }) + + it('should handle only prompt tools', async () => { + const plugin = createPromptToolUsePlugin() + const context = createMockContext() + const params = createMockStreamParams({ + tools: createMockToolSet({ + prompt_tool1: 'function', + prompt_tool2: 'function' + }) + }) + + const result = await Promise.resolve(plugin.transformParams!(params, context)) + + expect(result.tools).toBeUndefined() + expect(context.mcpTools).toEqual(params.tools) + }) + + it('should build system prompt for prompt tools', async () => { + const plugin = createPromptToolUsePlugin() + const context = createMockContext() + const params = createMockStreamParams({ + system: 'Original system prompt', + tools: { + test_tool: createMockTool('test_tool', 'Test tool description') + } + }) + + const result = await Promise.resolve(plugin.transformParams!(params, context)) + + expect(result.system).toBeDefined() + expect(typeof result.system).toBe('string') + expect(result.system).toContain('In this environment you have access to a set of tools') + expect(result.system).toContain('test_tool') + expect(result.system).toContain('Test tool description') + expect(result.system).toContain('Original system prompt') + }) + + it('should handle empty user system prompt', async () => { + const plugin = createPromptToolUsePlugin() + const context = createMockContext() + const params = createMockStreamParams({ + tools: { + test_tool: createMockTool('test_tool') + } + }) + + const result = await Promise.resolve(plugin.transformParams!(params, context)) + + expect(result.system).toBeDefined() + expect(result.system).toContain('In this environment you have access to a set of tools') + }) + + it('should skip system prompt when disabled', async () => { + const plugin = createPromptToolUsePlugin({ enabled: false }) + const context = createMockContext() + const params = createMockStreamParams({ + system: 'Original', + tools: { + test_tool: createMockTool('test_tool') + } + }) + + const result = await Promise.resolve(plugin.transformParams!(params, context)) + + expect(result).toEqual(params) + expect(context.mcpTools).toBeUndefined() + }) + + it('should skip when no tools provided', async () => { + const plugin = createPromptToolUsePlugin() + const context = createMockContext() + const params = createMockStreamParams({ + system: 'Original' + }) + + const result = await Promise.resolve(plugin.transformParams!(params, context)) + + expect(result).toEqual(params) + }) + + it('should skip when tools is not an object', async () => { + const plugin = createPromptToolUsePlugin() + const context = createMockContext() + const params = createMockStreamParams({ + system: 'Original', + tools: 'invalid' as any + }) + + const result = await Promise.resolve(plugin.transformParams!(params, context)) + + expect(result).toEqual(params) + }) + + it('should use custom buildSystemPrompt when provided', async () => { + const customBuildSystemPrompt = vi.fn((userSystemPrompt: string, tools: ToolSet) => { + return `Custom prompt with ${Object.keys(tools).length} tools and user prompt: ${userSystemPrompt}` + }) + + const plugin = createPromptToolUsePlugin({ + buildSystemPrompt: customBuildSystemPrompt + }) + + const context = createMockContext() + const params = createMockStreamParams({ + system: 'User prompt', + tools: { + tool1: createMockTool('tool1') + } + }) + + const result = await Promise.resolve(plugin.transformParams!(params, context)) + + expect(customBuildSystemPrompt).toHaveBeenCalled() + expect(result.system).toBe('Custom prompt with 1 tools and user prompt: User prompt') + }) + + it('should use custom createSystemMessage when provided', async () => { + const customCreateSystemMessage = vi.fn(() => { + return `Modified system message` + }) + + const plugin = createPromptToolUsePlugin({ + createSystemMessage: customCreateSystemMessage + }) + + const context = createMockContext() + const params = createMockStreamParams({ + system: 'Original', + tools: { + test: createMockTool('test') + } + }) + + const result = await Promise.resolve(plugin.transformParams!(params, context)) + + expect(customCreateSystemMessage).toHaveBeenCalled() + expect(result.system).toContain('Modified') + }) + + it('should save originalParams to context', async () => { + const plugin = createPromptToolUsePlugin() + const context = createMockContext() + const params = createMockStreamParams({ + system: 'Original', + tools: { + test: createMockTool('test') + } + }) + + await Promise.resolve(plugin.transformParams!(params, context)) + + expect(context.originalParams).toBeDefined() + expect(context.originalParams.system).toBeDefined() + }) + }) + + describe('transformStream', () => { + it('should return identity transform when disabled', async () => { + const plugin = createPromptToolUsePlugin({ enabled: false }) + const context = createMockContext() + + const inputChunks: Array<{ type: 'text-delta'; text: string }> = [ + { type: 'text-delta', text: 'Hello' }, + { type: 'text-delta', text: ' World' } + ] + + const inputStream = simulateReadableStream>({ + chunks: inputChunks as TextStreamPart[], + initialDelayInMs: 0, + chunkDelayInMs: 0 + }) + + const transform = plugin.transformStream!(createMockStreamParams(), context)() + const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform)) + + expect(result).toEqual(inputChunks) + }) + + it('should return identity transform when no mcpTools in context', async () => { + const plugin = createPromptToolUsePlugin() + const context = createMockContext() + // Don't set context.mcpTools + + const inputChunks: Array<{ type: 'text-delta'; text: string }> = [ + { type: 'text-delta', text: 'Hello' }, + { type: 'text-delta', text: ' World' } + ] + + const inputStream = simulateReadableStream>({ + chunks: inputChunks as TextStreamPart[], + initialDelayInMs: 0, + chunkDelayInMs: 0 + }) + + const transform = plugin.transformStream!(createMockStreamParams(), context)() + const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform)) + + expect(result).toEqual(inputChunks) + }) + + it('should initialize accumulatedUsage in context', () => { + const plugin = createPromptToolUsePlugin() + const context = createMockContext() + context.mcpTools = { + test: createMockTool('test') + } + + plugin.transformStream!(createMockStreamParams(), context)() + + expect(context.accumulatedUsage).toBeDefined() + expect(context.accumulatedUsage).toEqual({ + inputTokens: 0, + outputTokens: 0, + totalTokens: 0, + reasoningTokens: 0, + cachedInputTokens: 0 + }) + }) + + it('should filter tool tags from text-delta chunks', async () => { + const plugin = createPromptToolUsePlugin() + const context = createMockContext() + context.mcpTools = { + test: createMockTool('test') + } + + const inputChunks = [ + { type: 'text-start' as const }, + { type: 'text-delta' as const, text: 'Before ' }, + { type: 'text-delta' as const, text: '' }, + { type: 'text-delta' as const, text: 'test' }, + { type: 'text-delta' as const, text: '{}' }, + { type: 'text-delta' as const, text: '' }, + { type: 'text-delta' as const, text: ' After' }, + { type: 'text-end' as const } + ] + + const inputStream = simulateReadableStream>({ + chunks: inputChunks as TextStreamPart[], + initialDelayInMs: 0, + chunkDelayInMs: 0 + }) + + const transform = plugin.transformStream!(createMockStreamParams(), context)() + const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform)) + + // Extract text from text-delta chunks + const textChunks = result.filter((chunk) => chunk.type === 'text-delta') + const fullText = textChunks.map((chunk) => 'text' in chunk && chunk.text).join('') + + // Tool tags should be filtered out + expect(fullText).not.toContain('') + expect(fullText).not.toContain('') + expect(fullText).toContain('Before') + expect(fullText).toContain('After') + }) + + it('should hold text-start until non-tag content appears', async () => { + const plugin = createPromptToolUsePlugin() + const context = createMockContext() + context.mcpTools = { + test: createMockTool('test') + } + + // Only tool tags, no actual content + const inputChunks = [ + { type: 'text-start' as const }, + { type: 'text-delta' as const, text: '' }, + { type: 'text-delta' as const, text: 'test' }, + { type: 'text-delta' as const, text: '{}' }, + { type: 'text-delta' as const, text: '' }, + { type: 'text-end' as const } + ] + + const inputStream = simulateReadableStream>({ + chunks: inputChunks as TextStreamPart[], + initialDelayInMs: 0, + chunkDelayInMs: 0 + }) + + const transform = plugin.transformStream!(createMockStreamParams(), context)() + const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform)) + + // Should not have text-start or text-end since all content was tool tags + expect(result.some((chunk) => chunk.type === 'text-start')).toBe(false) + expect(result.some((chunk) => chunk.type === 'text-end')).toBe(false) + }) + + it('should send text-start when non-tag content appears', async () => { + const plugin = createPromptToolUsePlugin() + const context = createMockContext() + context.mcpTools = { + test: createMockTool('test') + } + + const inputChunks = [ + { type: 'text-start' as const }, + { type: 'text-delta' as const, text: 'Actual content' }, + { type: 'text-end' as const } + ] + + const inputStream = simulateReadableStream>({ + chunks: inputChunks as TextStreamPart[], + initialDelayInMs: 0, + chunkDelayInMs: 0 + }) + + const transform = plugin.transformStream!(createMockStreamParams(), context)() + const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform)) + + // Should have text-start, text-delta, and text-end + expect(result.some((chunk) => chunk.type === 'text-start')).toBe(true) + expect(result.some((chunk) => chunk.type === 'text-delta')).toBe(true) + expect(result.some((chunk) => chunk.type === 'text-end')).toBe(true) + }) + + it('should pass through non-text events', async () => { + const plugin = createPromptToolUsePlugin() + const context = createMockContext() + context.mcpTools = { + test: createMockTool('test') + } + + const stepStartEvent = { type: 'start-step' as const, request: {}, warnings: [] } + + const inputChunks = [stepStartEvent] + + const inputStream = simulateReadableStream>({ + chunks: inputChunks as TextStreamPart[], + initialDelayInMs: 0, + chunkDelayInMs: 0 + }) + + const transform = plugin.transformStream!(createMockStreamParams(), context)() + const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform)) + + expect(result[0]).toEqual(stepStartEvent) + }) + + it('should accumulate usage from finish-step events', async () => { + const plugin = createPromptToolUsePlugin() + const context = createMockContext() + context.mcpTools = { + test: createMockTool('test') + } + + const inputChunks = [ + { + type: 'finish-step' as const, + finishReason: 'stop' as const, + usage: { + inputTokens: 10, + outputTokens: 20, + totalTokens: 30 + } + } + ] + + const inputStream = simulateReadableStream>({ + chunks: inputChunks as TextStreamPart[], + initialDelayInMs: 0, + chunkDelayInMs: 0 + }) + + const transform = plugin.transformStream!(createMockStreamParams(), context)() + await convertReadableStreamToArray(inputStream.pipeThrough(transform)) + + // Verify usage was accumulated + expect(context.accumulatedUsage).toBeDefined() + expect(context.accumulatedUsage!.inputTokens).toBe(10) + expect(context.accumulatedUsage!.outputTokens).toBe(20) + expect(context.accumulatedUsage!.totalTokens).toBe(30) + }) + + it('should include accumulated usage in finish event', async () => { + const plugin = createPromptToolUsePlugin() + const context = createMockContext() + context.mcpTools = { + test: createMockTool('test') + } + + // Pre-populate accumulated usage + context.accumulatedUsage = { + inputTokens: 5, + outputTokens: 10, + totalTokens: 15, + reasoningTokens: 0, + cachedInputTokens: 0 + } + + const inputChunks = [ + { + type: 'finish' as const, + finishReason: 'stop' as const, + usage: { + inputTokens: 100, + outputTokens: 200, + totalTokens: 300 + } + } + ] + + const inputStream = simulateReadableStream>({ + chunks: inputChunks as unknown as TextStreamPart[], + initialDelayInMs: 0, + chunkDelayInMs: 0 + }) + + const transform = plugin.transformStream!(createMockStreamParams(), context)() + const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform)) + + const finishEvent = result.find((chunk) => chunk.type === 'finish') + expect(finishEvent).toBeDefined() + if (finishEvent && 'totalUsage' in finishEvent) { + expect(finishEvent.totalUsage).toEqual(context.accumulatedUsage) + } + }) + }) + + describe('Type Safety', () => { + it('should have correct generic parameters for StreamTextParams and StreamTextResult', () => { + const plugin = createPromptToolUsePlugin() + + // Type assertion to verify the plugin has the correct type + type PluginType = typeof plugin + const typeTest: PluginType = plugin + + expect(typeTest.name).toBe('built-in:prompt-tool-use') + }) + }) + + describe('DEFAULT_SYSTEM_PROMPT', () => { + it('should contain required sections', () => { + expect(DEFAULT_SYSTEM_PROMPT).toContain('Tool Use Formatting') + expect(DEFAULT_SYSTEM_PROMPT).toContain('Tool Use Examples') + expect(DEFAULT_SYSTEM_PROMPT).toContain('Tool Use Available Tools') + expect(DEFAULT_SYSTEM_PROMPT).toContain('Tool Use Rules') + expect(DEFAULT_SYSTEM_PROMPT).toContain('Response rules') + }) + + it('should have placeholders for dynamic content', () => { + expect(DEFAULT_SYSTEM_PROMPT).toContain('{{ TOOL_USE_EXAMPLES }}') + expect(DEFAULT_SYSTEM_PROMPT).toContain('{{ AVAILABLE_TOOLS }}') + expect(DEFAULT_SYSTEM_PROMPT).toContain('{{ USER_SYSTEM_PROMPT }}') + }) + + it('should contain XML tag examples', () => { + expect(DEFAULT_SYSTEM_PROMPT).toContain('') + expect(DEFAULT_SYSTEM_PROMPT).toContain('') + expect(DEFAULT_SYSTEM_PROMPT).toContain('') + expect(DEFAULT_SYSTEM_PROMPT).toContain('') + }) + }) +}) diff --git a/packages/aiCore/src/core/plugins/index.ts b/packages/aiCore/src/core/plugins/index.ts index bc8b1c3088..1e280abbe8 100644 --- a/packages/aiCore/src/core/plugins/index.ts +++ b/packages/aiCore/src/core/plugins/index.ts @@ -33,8 +33,8 @@ export function createContext { * 执行 transformParams 钩子 - 链式参数转换 * 每个插件返回 Partial,逐步合并到原始参数 */ - async executeTransformParams( - initialValue: TParams, - context: AiRequestContext - ): Promise { + async executeTransformParams(initialValue: TParams, context: AiRequestContext): Promise { let result = initialValue for (const plugin of this.plugins) { @@ -93,10 +90,7 @@ export class PluginManager { * 执行 transformResult 钩子 - 链式结果转换 * 每个插件接收并返回完整的 TResult */ - async executeTransformResult( - initialValue: TResult, - context: AiRequestContext - ): Promise { + async executeTransformResult(initialValue: TResult, context: AiRequestContext): Promise { let result = initialValue for (const plugin of this.plugins) { diff --git a/packages/aiCore/src/core/providers/__tests__/HubProvider.test.ts b/packages/aiCore/src/core/providers/__tests__/HubProvider.test.ts new file mode 100644 index 0000000000..1f608791ee --- /dev/null +++ b/packages/aiCore/src/core/providers/__tests__/HubProvider.test.ts @@ -0,0 +1,525 @@ +/** + * HubProvider Comprehensive Tests + * Tests hub provider routing, model resolution, and error handling + * Covers multi-provider routing with namespaced model IDs + */ + +import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3, ProviderV3 } from '@ai-sdk/provider' +import { customProvider, wrapProvider } from 'ai' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import { createMockEmbeddingModel, createMockImageModel, createMockLanguageModel } from '../../../__tests__' +import { createHubProvider, type HubProviderConfig, HubProviderError } from '../HubProvider' +import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../RegistryManagement' + +// Mock dependencies +vi.mock('../RegistryManagement', () => ({ + globalRegistryManagement: { + getProvider: vi.fn() + }, + DEFAULT_SEPARATOR: '|' +})) + +vi.mock('ai', () => ({ + customProvider: vi.fn((config) => config.fallbackProvider), + wrapProvider: vi.fn((config) => config.provider) +})) + +describe('HubProvider', () => { + let mockOpenAIProvider: ProviderV3 + let mockAnthropicProvider: ProviderV3 + let mockLanguageModel: LanguageModelV3 + let mockEmbeddingModel: EmbeddingModelV3 + let mockImageModel: ImageModelV3 + + beforeEach(() => { + vi.clearAllMocks() + + // Create mock models using global utilities + mockLanguageModel = createMockLanguageModel({ + provider: 'test', + modelId: 'test-model' + }) + + mockEmbeddingModel = createMockEmbeddingModel({ + provider: 'test', + modelId: 'test-embedding' + }) + + mockImageModel = createMockImageModel({ + provider: 'test', + modelId: 'test-image' + }) + + // Create mock providers + mockOpenAIProvider = { + specificationVersion: 'v3', + languageModel: vi.fn().mockReturnValue(mockLanguageModel), + embeddingModel: vi.fn().mockReturnValue(mockEmbeddingModel), + imageModel: vi.fn().mockReturnValue(mockImageModel) + } as ProviderV3 + + mockAnthropicProvider = { + specificationVersion: 'v3', + languageModel: vi.fn().mockReturnValue(mockLanguageModel), + embeddingModel: vi.fn().mockReturnValue(mockEmbeddingModel), + imageModel: vi.fn().mockReturnValue(mockImageModel) + } as ProviderV3 + + // Setup default mock implementation + vi.mocked(globalRegistryManagement.getProvider).mockImplementation((id) => { + if (id === 'openai') return mockOpenAIProvider + if (id === 'anthropic') return mockAnthropicProvider + return undefined + }) + }) + + describe('Provider Creation', () => { + it('should create hub provider with basic config', () => { + const config: HubProviderConfig = { + hubId: 'test-hub' + } + + const provider = createHubProvider(config) + + expect(provider).toBeDefined() + expect(customProvider).toHaveBeenCalled() + }) + + it('should create provider with debug flag', () => { + const config: HubProviderConfig = { + hubId: 'test-hub', + debug: true + } + + const provider = createHubProvider(config) + + expect(provider).toBeDefined() + }) + + it('should return ProviderV3 specification', () => { + const config: HubProviderConfig = { + hubId: 'aihubmix' + } + + const provider = createHubProvider(config) + + expect(provider).toHaveProperty('specificationVersion', 'v3') + expect(provider).toHaveProperty('languageModel') + expect(provider).toHaveProperty('embeddingModel') + expect(provider).toHaveProperty('imageModel') + }) + }) + + describe('Model ID Parsing', () => { + it('should parse valid hub model ID format', () => { + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + const modelId = `openai${DEFAULT_SEPARATOR}gpt-4` + + const result = provider.languageModel(modelId) + + expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai') + expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4') + expect(result).toBe(mockLanguageModel) + }) + + it('should throw error for invalid model ID format', () => { + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + const invalidId = 'invalid-id-without-separator' + + expect(() => provider.languageModel(invalidId)).toThrow(HubProviderError) + }) + + it('should throw error for model ID with multiple separators', () => { + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + const multiSeparatorId = `provider${DEFAULT_SEPARATOR}extra${DEFAULT_SEPARATOR}model` + + expect(() => provider.languageModel(multiSeparatorId)).toThrow(HubProviderError) + }) + + it('should throw error for empty model ID', () => { + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + expect(() => provider.languageModel('')).toThrow(HubProviderError) + }) + + it('should throw error for model ID with only separator', () => { + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + expect(() => provider.languageModel(DEFAULT_SEPARATOR)).toThrow(HubProviderError) + }) + }) + + describe('Language Model Resolution', () => { + it('should route to correct provider for language model', () => { + const config: HubProviderConfig = { hubId: 'aihubmix' } + const provider = createHubProvider(config) as ProviderV3 + + const result = provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) + + expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai') + expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4') + expect(result).toBe(mockLanguageModel) + }) + + it('should route different providers correctly', () => { + const config: HubProviderConfig = { hubId: 'aihubmix' } + const provider = createHubProvider(config) as ProviderV3 + + provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) + provider.languageModel(`anthropic${DEFAULT_SEPARATOR}claude-3`) + + expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai') + expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('anthropic') + expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4') + expect(mockAnthropicProvider.languageModel).toHaveBeenCalledWith('claude-3') + }) + + it('should wrap provider with wrapProvider', () => { + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) + + expect(wrapProvider).toHaveBeenCalledWith({ + provider: mockOpenAIProvider, + languageModelMiddleware: [] + }) + }) + + it('should throw HubProviderError if provider not initialized', () => { + vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(undefined) + + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + expect(() => provider.languageModel(`uninitialized${DEFAULT_SEPARATOR}model`)).toThrow(HubProviderError) + expect(() => provider.languageModel(`uninitialized${DEFAULT_SEPARATOR}model`)).toThrow(/not initialized/) + }) + + it('should include provider ID in error message', () => { + vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(undefined) + + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + try { + provider.languageModel(`missing${DEFAULT_SEPARATOR}model`) + expect.fail('Should have thrown HubProviderError') + } catch (error) { + expect(error).toBeInstanceOf(HubProviderError) + const hubError = error as HubProviderError + expect(hubError.providerId).toBe('missing') + expect(hubError.hubId).toBe('test-hub') + } + }) + }) + + describe('Embedding Model Resolution', () => { + it('should route to correct provider for embedding model', () => { + const config: HubProviderConfig = { hubId: 'aihubmix' } + const provider = createHubProvider(config) as ProviderV3 + + const result = provider.embeddingModel(`openai${DEFAULT_SEPARATOR}text-embedding-3-small`) + + expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai') + expect(mockOpenAIProvider.embeddingModel).toHaveBeenCalledWith('text-embedding-3-small') + expect(result).toBe(mockEmbeddingModel) + }) + + it('should handle different embedding providers', () => { + const config: HubProviderConfig = { hubId: 'aihubmix' } + const provider = createHubProvider(config) as ProviderV3 + + provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada-002`) + provider.embeddingModel(`anthropic${DEFAULT_SEPARATOR}embed-v1`) + + expect(mockOpenAIProvider.embeddingModel).toHaveBeenCalledWith('ada-002') + expect(mockAnthropicProvider.embeddingModel).toHaveBeenCalledWith('embed-v1') + }) + + it('should throw error for uninitialized embedding provider', () => { + vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(undefined) + + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + expect(() => provider.embeddingModel(`missing${DEFAULT_SEPARATOR}embed`)).toThrow(HubProviderError) + }) + }) + + describe('Image Model Resolution', () => { + it('should route to correct provider for image model', () => { + const config: HubProviderConfig = { hubId: 'aihubmix' } + const provider = createHubProvider(config) as ProviderV3 + + const result = provider.imageModel(`openai${DEFAULT_SEPARATOR}dall-e-3`) + + expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai') + expect(mockOpenAIProvider.imageModel).toHaveBeenCalledWith('dall-e-3') + expect(result).toBe(mockImageModel) + }) + + it('should handle different image providers', () => { + const config: HubProviderConfig = { hubId: 'aihubmix' } + const provider = createHubProvider(config) as ProviderV3 + + provider.imageModel(`openai${DEFAULT_SEPARATOR}dall-e-3`) + provider.imageModel(`anthropic${DEFAULT_SEPARATOR}image-gen`) + + expect(mockOpenAIProvider.imageModel).toHaveBeenCalledWith('dall-e-3') + expect(mockAnthropicProvider.imageModel).toHaveBeenCalledWith('image-gen') + }) + }) + + describe('Special Model Types', () => { + it('should support transcription models', () => { + const mockTranscriptionModel = { + specificationVersion: 'v3', + doTranscribe: vi.fn() + } + + const providerWithTranscription = { + ...mockOpenAIProvider, + transcriptionModel: vi.fn().mockReturnValue(mockTranscriptionModel) + } as ProviderV3 + + vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(providerWithTranscription) + + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + const result = provider.transcriptionModel!(`openai${DEFAULT_SEPARATOR}whisper-1`) + + expect(providerWithTranscription.transcriptionModel).toHaveBeenCalledWith('whisper-1') + expect(result).toBe(mockTranscriptionModel) + }) + + it('should throw error if provider does not support transcription', () => { + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + expect(() => provider.transcriptionModel!(`openai${DEFAULT_SEPARATOR}whisper`)).toThrow(HubProviderError) + expect(() => provider.transcriptionModel!(`openai${DEFAULT_SEPARATOR}whisper`)).toThrow( + /does not support transcription/ + ) + }) + + it('should support speech models', () => { + const mockSpeechModel = { + specificationVersion: 'v3', + doGenerate: vi.fn() + } + + const providerWithSpeech = { + ...mockOpenAIProvider, + speechModel: vi.fn().mockReturnValue(mockSpeechModel) + } as ProviderV3 + + vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(providerWithSpeech) + + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + const result = provider.speechModel!(`openai${DEFAULT_SEPARATOR}tts-1`) + + expect(providerWithSpeech.speechModel).toHaveBeenCalledWith('tts-1') + expect(result).toBe(mockSpeechModel) + }) + + it('should throw error if provider does not support speech', () => { + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + expect(() => provider.speechModel!(`openai${DEFAULT_SEPARATOR}tts-1`)).toThrow(HubProviderError) + expect(() => provider.speechModel!(`openai${DEFAULT_SEPARATOR}tts-1`)).toThrow(/does not support speech/) + }) + + it('should support reranking models', () => { + const mockRerankingModel = { + specificationVersion: 'v3', + doRerank: vi.fn() + } + + const providerWithReranking = { + ...mockOpenAIProvider, + rerankingModel: vi.fn().mockReturnValue(mockRerankingModel) + } as ProviderV3 + + vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(providerWithReranking) + + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + const result = provider.rerankingModel!(`openai${DEFAULT_SEPARATOR}rerank-v1`) + + expect(providerWithReranking.rerankingModel).toHaveBeenCalledWith('rerank-v1') + expect(result).toBe(mockRerankingModel) + }) + + it('should throw error if provider does not support reranking', () => { + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + expect(() => provider.rerankingModel!(`openai${DEFAULT_SEPARATOR}rerank`)).toThrow(HubProviderError) + expect(() => provider.rerankingModel!(`openai${DEFAULT_SEPARATOR}rerank`)).toThrow(/does not support reranking/) + }) + }) + + describe('Error Handling', () => { + it('should create HubProviderError with all properties', () => { + const originalError = new Error('Original error') + const error = new HubProviderError('Test message', 'test-hub', 'test-provider', originalError) + + expect(error.message).toBe('Test message') + expect(error.hubId).toBe('test-hub') + expect(error.providerId).toBe('test-provider') + expect(error.originalError).toBe(originalError) + expect(error.name).toBe('HubProviderError') + }) + + it('should create HubProviderError without optional parameters', () => { + const error = new HubProviderError('Test message', 'test-hub') + + expect(error.message).toBe('Test message') + expect(error.hubId).toBe('test-hub') + expect(error.providerId).toBeUndefined() + expect(error.originalError).toBeUndefined() + }) + + it('should wrap provider errors in HubProviderError', () => { + const providerError = new Error('Provider failed') + vi.mocked(globalRegistryManagement.getProvider).mockImplementation(() => { + throw providerError + }) + + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + try { + provider.languageModel(`failing${DEFAULT_SEPARATOR}model`) + expect.fail('Should have thrown HubProviderError') + } catch (error) { + expect(error).toBeInstanceOf(HubProviderError) + const hubError = error as HubProviderError + expect(hubError.originalError).toBe(providerError) + expect(hubError.message).toContain('Failed to get provider') + } + }) + + it('should handle null provider from registry', () => { + vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(null as any) + + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + expect(() => provider.languageModel(`null-provider${DEFAULT_SEPARATOR}model`)).toThrow(HubProviderError) + }) + }) + + describe('Multi-Provider Scenarios', () => { + it('should handle sequential calls to different providers', () => { + const config: HubProviderConfig = { hubId: 'aihubmix' } + const provider = createHubProvider(config) as ProviderV3 + + provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) + provider.languageModel(`anthropic${DEFAULT_SEPARATOR}claude-3`) + provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-3.5`) + + expect(globalRegistryManagement.getProvider).toHaveBeenCalledTimes(3) + expect(mockOpenAIProvider.languageModel).toHaveBeenCalledTimes(2) + expect(mockAnthropicProvider.languageModel).toHaveBeenCalledTimes(1) + }) + + it('should handle mixed model types from same provider', () => { + const config: HubProviderConfig = { hubId: 'aihubmix' } + const provider = createHubProvider(config) as ProviderV3 + + provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) + provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada-002`) + provider.imageModel(`openai${DEFAULT_SEPARATOR}dall-e-3`) + + expect(globalRegistryManagement.getProvider).toHaveBeenCalledTimes(3) + expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4') + expect(mockOpenAIProvider.embeddingModel).toHaveBeenCalledWith('ada-002') + expect(mockOpenAIProvider.imageModel).toHaveBeenCalledWith('dall-e-3') + }) + + it('should cache provider lookups', () => { + const config: HubProviderConfig = { hubId: 'aihubmix' } + const provider = createHubProvider(config) as ProviderV3 + + provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) + provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-3.5`) + + // Should call getProvider twice (once per model call) + expect(globalRegistryManagement.getProvider).toHaveBeenCalledTimes(2) + }) + }) + + describe('Provider Wrapping', () => { + it('should wrap all providers with empty middleware', () => { + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) + + expect(wrapProvider).toHaveBeenCalledWith({ + provider: mockOpenAIProvider, + languageModelMiddleware: [] + }) + }) + + it('should wrap providers for all model types', () => { + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) + provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada`) + provider.imageModel(`openai${DEFAULT_SEPARATOR}dalle`) + + expect(wrapProvider).toHaveBeenCalledTimes(3) + }) + }) + + describe('Type Safety', () => { + it('should return properly typed LanguageModelV3', () => { + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + const result = provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`) + + expect(result.specificationVersion).toBe('v3') + expect(result).toHaveProperty('doGenerate') + expect(result).toHaveProperty('doStream') + }) + + it('should return properly typed EmbeddingModelV3', () => { + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + const result = provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada`) + + expect(result.specificationVersion).toBe('v3') + expect(result).toHaveProperty('doEmbed') + }) + + it('should return properly typed ImageModelV3', () => { + const config: HubProviderConfig = { hubId: 'test-hub' } + const provider = createHubProvider(config) as ProviderV3 + + const result = provider.imageModel(`openai${DEFAULT_SEPARATOR}dalle`) + + expect(result.specificationVersion).toBe('v3') + expect(result).toHaveProperty('doGenerate') + }) + }) +}) diff --git a/packages/aiCore/src/core/providers/__tests__/RegistryManagement.test.ts b/packages/aiCore/src/core/providers/__tests__/RegistryManagement.test.ts new file mode 100644 index 0000000000..56c67a7c40 --- /dev/null +++ b/packages/aiCore/src/core/providers/__tests__/RegistryManagement.test.ts @@ -0,0 +1,561 @@ +/** + * RegistryManagement Comprehensive Tests + * Tests provider registry management, model resolution, and alias handling + * Covers registration, retrieval, and cleanup operations + */ + +import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3, ProviderV3 } from '@ai-sdk/provider' +import { createProviderRegistry } from 'ai' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import { createMockEmbeddingModel, createMockImageModel, createMockLanguageModel } from '../../../__tests__' +import { DEFAULT_SEPARATOR, RegistryManagement } from '../RegistryManagement' + +// Mock AI SDK +vi.mock('ai', () => ({ + createProviderRegistry: vi.fn() +})) + +describe('RegistryManagement', () => { + let registry: RegistryManagement + let mockProvider: ProviderV3 + let mockLanguageModel: LanguageModelV3 + let mockEmbeddingModel: EmbeddingModelV3 + let mockImageModel: ImageModelV3 + + beforeEach(() => { + vi.clearAllMocks() + + // Create mock models using global utilities + mockLanguageModel = createMockLanguageModel({ + provider: 'test', + modelId: 'test-model' + }) + + mockEmbeddingModel = createMockEmbeddingModel({ + provider: 'test', + modelId: 'test-embedding' + }) + + mockImageModel = createMockImageModel({ + provider: 'test', + modelId: 'test-image' + }) + + // Create mock provider + mockProvider = { + specificationVersion: 'v3', + languageModel: vi.fn().mockReturnValue(mockLanguageModel), + embeddingModel: vi.fn().mockReturnValue(mockEmbeddingModel), + imageModel: vi.fn().mockReturnValue(mockImageModel), + transcriptionModel: vi.fn(), + speechModel: vi.fn() + } as ProviderV3 + + // Setup mock registry + const mockRegistry = { + languageModel: vi.fn().mockReturnValue(mockLanguageModel), + embeddingModel: vi.fn().mockReturnValue(mockEmbeddingModel), + imageModel: vi.fn().mockReturnValue(mockImageModel), + transcriptionModel: vi.fn(), + speechModel: vi.fn() + } + + vi.mocked(createProviderRegistry).mockReturnValue(mockRegistry as any) + + registry = new RegistryManagement() + }) + + describe('Constructor and Initialization', () => { + it('should create registry with default separator', () => { + const reg = new RegistryManagement() + + expect(reg).toBeInstanceOf(RegistryManagement) + expect(reg.hasProviders()).toBe(false) + }) + + it('should create registry with custom separator', () => { + const customSeparator = ':' + const reg = new RegistryManagement({ separator: customSeparator }) + + expect(reg).toBeInstanceOf(RegistryManagement) + }) + + it('should start with empty provider list', () => { + expect(registry.getRegisteredProviders()).toEqual([]) + }) + }) + + describe('Provider Registration', () => { + it('should register a provider', () => { + registry.registerProvider('openai', mockProvider) + + expect(registry.getProvider('openai')).toBe(mockProvider) + expect(registry.hasProviders()).toBe(true) + }) + + it('should register multiple providers', () => { + const provider2 = { ...mockProvider } + + registry.registerProvider('openai', mockProvider) + registry.registerProvider('anthropic', provider2) + + expect(registry.getProvider('openai')).toBe(mockProvider) + expect(registry.getProvider('anthropic')).toBe(provider2) + }) + + it('should return this for chaining', () => { + const result = registry.registerProvider('openai', mockProvider) + + expect(result).toBe(registry) + }) + + it('should rebuild registry after registration', () => { + registry.registerProvider('openai', mockProvider) + + expect(createProviderRegistry).toHaveBeenCalledWith( + expect.objectContaining({ + openai: mockProvider + }), + { separator: DEFAULT_SEPARATOR } + ) + }) + + it('should register provider with aliases', () => { + registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt']) + + expect(registry.getProvider('openai')).toBe(mockProvider) + expect(registry.getProvider('gpt')).toBe(mockProvider) + expect(registry.getProvider('chatgpt')).toBe(mockProvider) + }) + + it('should track aliases separately', () => { + registry.registerProvider('openai', mockProvider, ['gpt']) + + expect(registry.isAlias('gpt')).toBe(true) + expect(registry.isAlias('openai')).toBe(false) + }) + + it('should handle multiple aliases for same provider', () => { + const aliases = ['alias1', 'alias2', 'alias3'] + registry.registerProvider('provider', mockProvider, aliases) + + aliases.forEach((alias) => { + expect(registry.getProvider(alias)).toBe(mockProvider) + expect(registry.isAlias(alias)).toBe(true) + }) + }) + }) + + describe('Bulk Registration', () => { + it('should register multiple providers at once', () => { + const providers = { + openai: mockProvider, + anthropic: { ...mockProvider }, + google: { ...mockProvider } + } + + registry.registerProviders(providers) + + expect(registry.getProvider('openai')).toBe(providers.openai) + expect(registry.getProvider('anthropic')).toBe(providers.anthropic) + expect(registry.getProvider('google')).toBe(providers.google) + }) + + it('should return this for chaining', () => { + const result = registry.registerProviders({ openai: mockProvider }) + + expect(result).toBe(registry) + }) + }) + + describe('Provider Retrieval', () => { + beforeEach(() => { + registry.registerProvider('openai', mockProvider) + }) + + it('should retrieve registered provider', () => { + const provider = registry.getProvider('openai') + + expect(provider).toBe(mockProvider) + }) + + it('should return undefined for unregistered provider', () => { + const provider = registry.getProvider('nonexistent') + + expect(provider).toBeUndefined() + }) + + it('should retrieve provider by alias', () => { + registry.registerProvider('anthropic', mockProvider, ['claude']) + + const provider = registry.getProvider('claude') + + expect(provider).toBe(mockProvider) + }) + + it('should get list of all registered providers', () => { + registry.registerProvider('anthropic', mockProvider) + registry.registerProvider('google', mockProvider, ['gemini']) + + const providers = registry.getRegisteredProviders() + + expect(providers).toContain('openai') + expect(providers).toContain('anthropic') + expect(providers).toContain('google') + expect(providers).toContain('gemini') // Aliases included + }) + }) + + describe('Provider Unregistration', () => { + it('should unregister provider', () => { + registry.registerProvider('openai', mockProvider) + + registry.unregisterProvider('openai') + + expect(registry.getProvider('openai')).toBeUndefined() + }) + + it('should unregister provider with all its aliases', () => { + registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt']) + + registry.unregisterProvider('openai') + + expect(registry.getProvider('openai')).toBeUndefined() + expect(registry.getProvider('gpt')).toBeUndefined() + expect(registry.getProvider('chatgpt')).toBeUndefined() + }) + + it('should unregister only alias when alias is removed', () => { + registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt']) + + registry.unregisterProvider('gpt') + + expect(registry.getProvider('openai')).toBe(mockProvider) + expect(registry.getProvider('gpt')).toBeUndefined() + expect(registry.getProvider('chatgpt')).toBe(mockProvider) + }) + + it('should handle unregistering non-existent provider', () => { + expect(() => registry.unregisterProvider('nonexistent')).not.toThrow() + }) + + it('should return this for chaining', () => { + registry.registerProvider('openai', mockProvider) + + const result = registry.unregisterProvider('openai') + + expect(result).toBe(registry) + }) + + it('should rebuild registry after unregistration', () => { + registry.registerProvider('openai', mockProvider) + vi.clearAllMocks() + + registry.unregisterProvider('openai') + + // Should rebuild with empty providers + expect(createProviderRegistry).not.toHaveBeenCalled() // No rebuild when empty + }) + }) + + describe('Model Resolution', () => { + beforeEach(() => { + registry.registerProvider('openai', mockProvider) + }) + + it('should resolve language model', () => { + const modelId = `openai${DEFAULT_SEPARATOR}gpt-4` as any + + const result = registry.languageModel(modelId) + + expect(result).toBe(mockLanguageModel) + }) + + it('should resolve embedding model', () => { + const modelId = `openai${DEFAULT_SEPARATOR}text-embedding-3-small` as any + + const result = registry.embeddingModel(modelId) + + expect(result).toBe(mockEmbeddingModel) + }) + + it('should resolve image model', () => { + const modelId = `openai${DEFAULT_SEPARATOR}dall-e-3` as any + + const result = registry.imageModel(modelId) + + expect(result).toBe(mockImageModel) + }) + + it('should resolve transcription model', () => { + const modelId = `openai${DEFAULT_SEPARATOR}whisper-1` as any + + registry.transcriptionModel(modelId) + + // Verify it calls through to the mock registry + expect(createProviderRegistry).toHaveBeenCalled() + }) + + it('should resolve speech model', () => { + const modelId = `openai${DEFAULT_SEPARATOR}tts-1` as any + + registry.speechModel(modelId) + + expect(createProviderRegistry).toHaveBeenCalled() + }) + + it('should throw error when no providers registered', () => { + const emptyRegistry = new RegistryManagement() + + expect(() => emptyRegistry.languageModel('openai|gpt-4' as any)).toThrow('No providers registered') + }) + }) + + describe('Alias Management', () => { + it('should resolve provider ID from alias', () => { + registry.registerProvider('openai', mockProvider, ['gpt']) + + const realId = registry.resolveProviderId('gpt') + + expect(realId).toBe('openai') + }) + + it('should return same ID if not an alias', () => { + registry.registerProvider('openai', mockProvider) + + const realId = registry.resolveProviderId('openai') + + expect(realId).toBe('openai') + }) + + it('should check if ID is alias', () => { + registry.registerProvider('openai', mockProvider, ['gpt']) + + expect(registry.isAlias('gpt')).toBe(true) + expect(registry.isAlias('openai')).toBe(false) + }) + + it('should get all alias mappings', () => { + const provider2 = { ...mockProvider } + + registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt']) + registry.registerProvider('anthropic', provider2, ['claude']) + + const aliases = registry.getAllAliases() + + // Check that all aliases are present + expect(aliases['gpt']).toBe('openai') + expect(aliases['chatgpt']).toBe('openai') + expect(aliases['claude']).toBe('anthropic') + }) + + it('should return empty object when no aliases', () => { + registry.registerProvider('openai', mockProvider) + + const aliases = registry.getAllAliases() + + expect(aliases).toEqual({}) + }) + }) + + describe('Registry State', () => { + it('should check if has providers', () => { + expect(registry.hasProviders()).toBe(false) + + registry.registerProvider('openai', mockProvider) + + expect(registry.hasProviders()).toBe(true) + }) + + it('should clear all providers', () => { + registry.registerProvider('openai', mockProvider, ['gpt']) + registry.registerProvider('anthropic', mockProvider) + + registry.clear() + + expect(registry.hasProviders()).toBe(false) + expect(registry.getRegisteredProviders()).toEqual([]) + expect(registry.getAllAliases()).toEqual({}) + }) + + it('should return this after clear for chaining', () => { + const result = registry.clear() + + expect(result).toBe(registry) + }) + }) + + describe('Registry Rebuilding', () => { + it('should rebuild registry when provider added', () => { + registry.registerProvider('openai', mockProvider) + + expect(createProviderRegistry).toHaveBeenCalledTimes(1) + }) + + it('should rebuild registry when provider removed', () => { + registry.registerProvider('openai', mockProvider) + registry.registerProvider('anthropic', mockProvider) + + vi.clearAllMocks() + + registry.unregisterProvider('openai') + + expect(createProviderRegistry).toHaveBeenCalledTimes(1) + }) + + it('should set registry to null when all providers removed', () => { + registry.registerProvider('openai', mockProvider) + registry.unregisterProvider('openai') + + expect(() => registry.languageModel('any|model' as any)).toThrow('No providers registered') + }) + + it('should rebuild with correct separator', () => { + const customRegistry = new RegistryManagement({ separator: ':' }) + customRegistry.registerProvider('openai', mockProvider) + + expect(createProviderRegistry).toHaveBeenCalledWith(expect.any(Object), { separator: ':' }) + }) + }) + + describe('Global Registry Instance', () => { + it('should have a global instance with default separator', async () => { + const module = await import('../RegistryManagement') + + expect(module.globalRegistryManagement).toBeInstanceOf(RegistryManagement) + }) + + it('should have DEFAULT_SEPARATOR exported', () => { + expect(DEFAULT_SEPARATOR).toBe('|') + }) + }) + + describe('Edge Cases', () => { + it('should handle registering same provider twice', () => { + registry.registerProvider('openai', mockProvider) + + const provider2 = { ...mockProvider } + registry.registerProvider('openai', provider2) + + expect(registry.getProvider('openai')).toBe(provider2) + }) + + it('should handle alias conflicts (first wins)', () => { + registry.registerProvider('provider1', mockProvider, ['shared-alias']) + registry.registerProvider('provider2', mockProvider, ['shared-alias']) + + // First registered alias wins (the implementation doesn't override) + expect(registry.resolveProviderId('shared-alias')).toBe('provider1') + }) + + it('should handle empty alias array', () => { + registry.registerProvider('openai', mockProvider, []) + + expect(registry.getAllAliases()).toEqual({}) + }) + + it('should handle null registry operations gracefully', () => { + const emptyRegistry = new RegistryManagement() + + expect(() => emptyRegistry.languageModel('test|model' as any)).toThrow('No providers registered') + expect(() => emptyRegistry.embeddingModel('test|embed' as any)).toThrow('No providers registered') + expect(() => emptyRegistry.imageModel('test|image' as any)).toThrow('No providers registered') + }) + + it('should handle special characters in provider IDs', () => { + const specialIds = ['provider-1', 'provider_2', 'provider.3', 'provider:4'] + + specialIds.forEach((id) => { + registry.registerProvider(id, mockProvider) + expect(registry.getProvider(id)).toBe(mockProvider) + }) + }) + }) + + describe('Concurrent Operations', () => { + it('should handle concurrent registrations', () => { + const promises = [ + Promise.resolve(registry.registerProvider('provider1', mockProvider)), + Promise.resolve(registry.registerProvider('provider2', mockProvider)), + Promise.resolve(registry.registerProvider('provider3', mockProvider)) + ] + + return Promise.all(promises).then(() => { + expect(registry.getRegisteredProviders()).toHaveLength(3) + }) + }) + + it('should handle mixed operations', () => { + registry.registerProvider('openai', mockProvider) + registry.registerProvider('anthropic', mockProvider) + + const provider1 = registry.getProvider('openai') + registry.unregisterProvider('anthropic') + const provider2 = registry.getProvider('openai') + + expect(provider1).toBe(provider2) + }) + }) + + describe('Type Safety', () => { + it('should enforce model ID format with template literal types', () => { + registry.registerProvider('openai', mockProvider) + + // These should be type-safe + const validId = 'openai|gpt-4' as `${string}${typeof DEFAULT_SEPARATOR}${string}` + + expect(() => registry.languageModel(validId)).not.toThrow() + }) + + it('should return properly typed LanguageModelV3', () => { + registry.registerProvider('openai', mockProvider) + + const model = registry.languageModel('openai|gpt-4' as any) + + expect(model.specificationVersion).toBe('v3') + expect(model).toHaveProperty('doGenerate') + expect(model).toHaveProperty('doStream') + }) + + it('should return properly typed EmbeddingModelV3', () => { + registry.registerProvider('openai', mockProvider) + + const model = registry.embeddingModel('openai|ada-002' as any) + + expect(model.specificationVersion).toBe('v3') + expect(model).toHaveProperty('doEmbed') + }) + + it('should return properly typed ImageModelV3', () => { + registry.registerProvider('openai', mockProvider) + + const model = registry.imageModel('openai|dall-e-3' as any) + + expect(model.specificationVersion).toBe('v3') + expect(model).toHaveProperty('doGenerate') + }) + }) + + describe('Memory Management', () => { + it('should properly clean up on clear', () => { + registry.registerProvider('p1', mockProvider, ['a1']) + registry.registerProvider('p2', mockProvider, ['a2']) + + registry.clear() + + expect(registry.getRegisteredProviders()).toHaveLength(0) + expect(Object.keys(registry.getAllAliases())).toHaveLength(0) + }) + + it('should properly clean up on unregister', () => { + registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt']) + + registry.unregisterProvider('openai') + + expect(registry.getProvider('openai')).toBeUndefined() + expect(registry.isAlias('gpt')).toBe(false) + expect(registry.isAlias('chatgpt')).toBe(false) + }) + }) +}) diff --git a/packages/aiCore/src/core/runtime/__tests__/executor-resolveModel.test.ts b/packages/aiCore/src/core/runtime/__tests__/executor-resolveModel.test.ts new file mode 100644 index 0000000000..948d81063f --- /dev/null +++ b/packages/aiCore/src/core/runtime/__tests__/executor-resolveModel.test.ts @@ -0,0 +1,650 @@ +/** + * RuntimeExecutor.resolveModel Comprehensive Tests + * Tests the private resolveModel and resolveImageModel methods through public APIs + * Covers model resolution, middleware application, and type validation + */ + +import type { ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider' +import { generateImage, generateText, streamText } from 'ai' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import { + createMockImageModel, + createMockLanguageModel, + createMockMiddleware, + mockProviderConfigs +} from '../../../__tests__' +import { globalModelResolver } from '../../models' +import { ImageModelResolutionError } from '../errors' +import { RuntimeExecutor } from '../executor' + +// Mock AI SDK +vi.mock('ai', async (importOriginal) => { + const actual = (await importOriginal()) as Record + return { + ...actual, + generateText: vi.fn(), + streamText: vi.fn(), + generateImage: vi.fn(), + wrapLanguageModel: vi.fn((config: any) => ({ + ...config.model, + _middlewareApplied: true, + middleware: config.middleware + })) + } +}) + +vi.mock('../../providers/RegistryManagement', () => ({ + globalRegistryManagement: { + languageModel: vi.fn(), + imageModel: vi.fn() + }, + DEFAULT_SEPARATOR: '|' +})) + +vi.mock('../../models', () => ({ + globalModelResolver: { + resolveLanguageModel: vi.fn(), + resolveImageModel: vi.fn() + } +})) + +describe('RuntimeExecutor - Model Resolution', () => { + let executor: RuntimeExecutor<'openai'> + let mockLanguageModel: LanguageModelV3 + let mockImageModel: ImageModelV3 + + beforeEach(() => { + vi.clearAllMocks() + + executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai) + + mockLanguageModel = createMockLanguageModel({ + specificationVersion: 'v3', + provider: 'openai', + modelId: 'gpt-4' + }) + + mockImageModel = createMockImageModel({ + specificationVersion: 'v3', + provider: 'openai', + modelId: 'dall-e-3' + }) + + vi.mocked(globalModelResolver.resolveLanguageModel).mockResolvedValue(mockLanguageModel) + vi.mocked(globalModelResolver.resolveImageModel).mockResolvedValue(mockImageModel) + vi.mocked(generateText).mockResolvedValue({ + text: 'Test response', + finishReason: 'stop', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } + } as any) + vi.mocked(streamText).mockResolvedValue({ + textStream: (async function* () { + yield 'test' + })() + } as any) + vi.mocked(generateImage).mockResolvedValue({ + image: { + base64: 'test-image', + uint8Array: new Uint8Array([1, 2, 3]), + mimeType: 'image/png' + }, + warnings: [] + } as any) + }) + + describe('Language Model Resolution (String modelId)', () => { + it('should resolve string modelId using globalModelResolver', async () => { + await executor.generateText({ + model: 'gpt-4', + messages: [{ role: 'user', content: 'Hello' }] + }) + + expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( + 'gpt-4', + 'openai', + mockProviderConfigs.openai, + undefined + ) + }) + + it('should pass provider settings to model resolver', async () => { + const customExecutor = RuntimeExecutor.create('anthropic', { + apiKey: 'sk-test', + baseURL: 'https://api.anthropic.com' + }) + + vi.mocked(globalModelResolver.resolveLanguageModel).mockResolvedValue(mockLanguageModel) + + await customExecutor.generateText({ + model: 'claude-3-5-sonnet', + messages: [{ role: 'user', content: 'Test' }] + }) + + expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( + 'claude-3-5-sonnet', + 'anthropic', + { + apiKey: 'sk-test', + baseURL: 'https://api.anthropic.com' + }, + undefined + ) + }) + + it('should resolve traditional format modelId', async () => { + await executor.generateText({ + model: 'gpt-4-turbo', + messages: [{ role: 'user', content: 'Test' }] + }) + + expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( + 'gpt-4-turbo', + 'openai', + expect.any(Object), + undefined + ) + }) + + it('should resolve namespaced format modelId', async () => { + await executor.generateText({ + model: 'aihubmix|anthropic|claude-3', + messages: [{ role: 'user', content: 'Test' }] + }) + + expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( + 'aihubmix|anthropic|claude-3', + 'openai', + expect.any(Object), + undefined + ) + }) + + it('should use resolved model for generation', async () => { + await executor.generateText({ + model: 'gpt-4', + messages: [{ role: 'user', content: 'Hello' }] + }) + + expect(generateText).toHaveBeenCalledWith( + expect.objectContaining({ + model: mockLanguageModel + }) + ) + }) + + it('should work with streamText', async () => { + await executor.streamText({ + model: 'gpt-4', + messages: [{ role: 'user', content: 'Stream test' }] + }) + + expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalled() + expect(streamText).toHaveBeenCalledWith( + expect.objectContaining({ + model: mockLanguageModel + }) + ) + }) + }) + + describe('Language Model Resolution (Direct Model Object)', () => { + it('should accept pre-resolved V3 model object', async () => { + const directModel: LanguageModelV3 = createMockLanguageModel({ + specificationVersion: 'v3', + provider: 'openai', + modelId: 'gpt-4' + }) + + await executor.generateText({ + model: directModel, + messages: [{ role: 'user', content: 'Test' }] + }) + + // Should NOT call resolver for direct model + expect(globalModelResolver.resolveLanguageModel).not.toHaveBeenCalled() + + // Should use the model directly + expect(generateText).toHaveBeenCalledWith( + expect.objectContaining({ + model: directModel + }) + ) + }) + + it('should accept V2 model object without validation (plugin engine handles it)', async () => { + const v2Model = { + specificationVersion: 'v2', + provider: 'openai', + modelId: 'gpt-4', + doGenerate: vi.fn() + } as any + + // The plugin engine accepts any model object directly without validation + // V3 validation only happens when resolving string modelIds + await expect( + executor.generateText({ + model: v2Model, + messages: [{ role: 'user', content: 'Test' }] + }) + ).resolves.toBeDefined() + }) + + it('should accept any model object without checking specification version', async () => { + const v2Model = { + specificationVersion: 'v2', + provider: 'custom-provider', + modelId: 'custom-model', + doGenerate: vi.fn() + } as any + + // Direct model objects bypass validation + // The executor trusts that plugins/users provide valid models + await expect( + executor.generateText({ + model: v2Model, + messages: [{ role: 'user', content: 'Test' }] + }) + ).resolves.toBeDefined() + }) + + it('should accept model object with streamText', async () => { + const directModel = createMockLanguageModel({ + specificationVersion: 'v3' + }) + + await executor.streamText({ + model: directModel, + messages: [{ role: 'user', content: 'Stream' }] + }) + + expect(globalModelResolver.resolveLanguageModel).not.toHaveBeenCalled() + expect(streamText).toHaveBeenCalledWith( + expect.objectContaining({ + model: directModel + }) + ) + }) + }) + + describe('Middleware Application', () => { + it('should apply middlewares to string modelId', async () => { + const testMiddleware = createMockMiddleware({ name: 'test-middleware' }) + + await executor.generateText( + { + model: 'gpt-4', + messages: [{ role: 'user', content: 'Test' }] + }, + { middlewares: [testMiddleware] } + ) + + expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [ + testMiddleware + ]) + }) + + it('should apply multiple middlewares in order', async () => { + const middleware1 = createMockMiddleware({ name: 'middleware-1' }) + const middleware2 = createMockMiddleware({ name: 'middleware-2' }) + + await executor.generateText( + { + model: 'gpt-4', + messages: [{ role: 'user', content: 'Test' }] + }, + { middlewares: [middleware1, middleware2] } + ) + + expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [ + middleware1, + middleware2 + ]) + }) + + it('should pass middlewares to model resolver for string modelIds', async () => { + const testMiddleware = createMockMiddleware({ name: 'test-middleware' }) + + await executor.generateText( + { + model: 'gpt-4', // String model ID + messages: [{ role: 'user', content: 'Test' }] + }, + { middlewares: [testMiddleware] } + ) + + // Middlewares are passed to the resolver for string modelIds + expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [ + testMiddleware + ]) + }) + + it('should not apply middlewares when none provided', async () => { + await executor.generateText({ + model: 'gpt-4', + messages: [{ role: 'user', content: 'Test' }] + }) + + expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( + 'gpt-4', + 'openai', + expect.any(Object), + undefined + ) + }) + + it('should handle empty middleware array', async () => { + await executor.generateText( + { + model: 'gpt-4', + messages: [{ role: 'user', content: 'Test' }] + }, + { middlewares: [] } + ) + + expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), []) + }) + + it('should work with middlewares in streamText', async () => { + const middleware = createMockMiddleware({ name: 'stream-middleware' }) + + await executor.streamText( + { + model: 'gpt-4', + messages: [{ role: 'user', content: 'Stream' }] + }, + { middlewares: [middleware] } + ) + + expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [ + middleware + ]) + }) + }) + + describe('Image Model Resolution', () => { + it('should resolve string image modelId using globalModelResolver', async () => { + await executor.generateImage({ + model: 'dall-e-3', + prompt: 'A beautiful sunset' + }) + + expect(globalModelResolver.resolveImageModel).toHaveBeenCalledWith('dall-e-3', 'openai') + }) + + it('should accept direct ImageModelV3 object', async () => { + const directImageModel: ImageModelV3 = createMockImageModel({ + specificationVersion: 'v3', + provider: 'openai', + modelId: 'dall-e-3' + }) + + await executor.generateImage({ + model: directImageModel, + prompt: 'Test image' + }) + + expect(globalModelResolver.resolveImageModel).not.toHaveBeenCalled() + expect(generateImage).toHaveBeenCalledWith( + expect.objectContaining({ + model: directImageModel + }) + ) + }) + + it('should resolve namespaced image model ID', async () => { + await executor.generateImage({ + model: 'aihubmix|openai|dall-e-3', + prompt: 'Namespaced image' + }) + + expect(globalModelResolver.resolveImageModel).toHaveBeenCalledWith('aihubmix|openai|dall-e-3', 'openai') + }) + + it('should throw ImageModelResolutionError on resolution failure', async () => { + const resolutionError = new Error('Model not found') + vi.mocked(globalModelResolver.resolveImageModel).mockRejectedValue(resolutionError) + + await expect( + executor.generateImage({ + model: 'invalid-model', + prompt: 'Test' + }) + ).rejects.toThrow(ImageModelResolutionError) + }) + + it('should include modelId and providerId in ImageModelResolutionError', async () => { + vi.mocked(globalModelResolver.resolveImageModel).mockRejectedValue(new Error('Not found')) + + try { + await executor.generateImage({ + model: 'invalid-model', + prompt: 'Test' + }) + expect.fail('Should have thrown ImageModelResolutionError') + } catch (error) { + expect(error).toBeInstanceOf(ImageModelResolutionError) + const imgError = error as ImageModelResolutionError + expect(imgError.message).toContain('invalid-model') + expect(imgError.providerId).toBe('openai') + } + }) + + it('should extract modelId from direct model object in error', async () => { + const directModel = createMockImageModel({ + modelId: 'direct-model', + doGenerate: vi.fn().mockRejectedValue(new Error('Generation failed')) + }) + + vi.mocked(generateImage).mockRejectedValue(new Error('Generation failed')) + + await expect( + executor.generateImage({ + model: directModel, + prompt: 'Test' + }) + ).rejects.toThrow() + }) + }) + + describe('Provider-Specific Model Resolution', () => { + it('should resolve models for OpenAI provider', async () => { + const openaiExecutor = RuntimeExecutor.create('openai', mockProviderConfigs.openai) + + await openaiExecutor.generateText({ + model: 'gpt-4', + messages: [{ role: 'user', content: 'Test' }] + }) + + expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( + 'gpt-4', + 'openai', + expect.any(Object), + undefined + ) + }) + + it('should resolve models for Anthropic provider', async () => { + const anthropicExecutor = RuntimeExecutor.create('anthropic', mockProviderConfigs.anthropic) + + await anthropicExecutor.generateText({ + model: 'claude-3-5-sonnet', + messages: [{ role: 'user', content: 'Test' }] + }) + + expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( + 'claude-3-5-sonnet', + 'anthropic', + expect.any(Object), + undefined + ) + }) + + it('should resolve models for Google provider', async () => { + const googleExecutor = RuntimeExecutor.create('google', mockProviderConfigs.google) + + await googleExecutor.generateText({ + model: 'gemini-2.0-flash', + messages: [{ role: 'user', content: 'Test' }] + }) + + expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( + 'gemini-2.0-flash', + 'google', + expect.any(Object), + undefined + ) + }) + + it('should resolve models for OpenAI-compatible provider', async () => { + const compatibleExecutor = RuntimeExecutor.createOpenAICompatible(mockProviderConfigs['openai-compatible']) + + await compatibleExecutor.generateText({ + model: 'custom-model', + messages: [{ role: 'user', content: 'Test' }] + }) + + expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( + 'custom-model', + 'openai-compatible', + expect.any(Object), + undefined + ) + }) + }) + + describe('OpenAI Mode Handling', () => { + it('should pass mode setting to model resolver', async () => { + const executorWithMode = RuntimeExecutor.create('openai', { + ...mockProviderConfigs.openai, + mode: 'chat' + }) + + await executorWithMode.generateText({ + model: 'gpt-4', + messages: [{ role: 'user', content: 'Test' }] + }) + + expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( + 'gpt-4', + 'openai', + expect.objectContaining({ + mode: 'chat' + }), + undefined + ) + }) + + it('should handle responses mode', async () => { + const executorWithMode = RuntimeExecutor.create('openai', { + ...mockProviderConfigs.openai, + mode: 'responses' + }) + + await executorWithMode.generateText({ + model: 'gpt-4', + messages: [{ role: 'user', content: 'Test' }] + }) + + expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith( + 'gpt-4', + 'openai', + expect.objectContaining({ + mode: 'responses' + }), + undefined + ) + }) + }) + + describe('Edge Cases', () => { + it('should handle empty string modelId', async () => { + await executor.generateText({ + model: '', + messages: [{ role: 'user', content: 'Test' }] + }) + + expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('', 'openai', expect.any(Object), undefined) + }) + + it('should handle model resolution errors gracefully', async () => { + vi.mocked(globalModelResolver.resolveLanguageModel).mockRejectedValue(new Error('Model not found')) + + await expect( + executor.generateText({ + model: 'nonexistent-model', + messages: [{ role: 'user', content: 'Test' }] + }) + ).rejects.toThrow('Model not found') + }) + + it('should handle concurrent model resolutions', async () => { + const promises = [ + executor.generateText({ model: 'gpt-4', messages: [{ role: 'user', content: 'Test 1' }] }), + executor.generateText({ model: 'gpt-4-turbo', messages: [{ role: 'user', content: 'Test 2' }] }), + executor.generateText({ model: 'gpt-3.5-turbo', messages: [{ role: 'user', content: 'Test 3' }] }) + ] + + await Promise.all(promises) + + expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledTimes(3) + }) + + it('should accept model object even without specificationVersion', async () => { + const invalidModel = { + provider: 'test', + modelId: 'test-model' + // Missing specificationVersion + } as any + + // Plugin engine doesn't validate direct model objects + // It's the user's responsibility to provide valid models + await expect( + executor.generateText({ + model: invalidModel, + messages: [{ role: 'user', content: 'Test' }] + }) + ).resolves.toBeDefined() + }) + }) + + describe('Type Safety Validation', () => { + it('should ensure resolved model is LanguageModelV3', async () => { + const v3Model = createMockLanguageModel({ + specificationVersion: 'v3' + }) + + vi.mocked(globalModelResolver.resolveLanguageModel).mockResolvedValue(v3Model) + + await executor.generateText({ + model: 'gpt-4', + messages: [{ role: 'user', content: 'Test' }] + }) + + expect(generateText).toHaveBeenCalledWith( + expect.objectContaining({ + model: expect.objectContaining({ + specificationVersion: 'v3' + }) + }) + ) + }) + + it('should not enforce specification version for direct models', async () => { + const v1Model = { + specificationVersion: 'v1', + provider: 'test', + modelId: 'test' + } as any + + // Direct models bypass validation in the plugin engine + // Only resolved models (from string IDs) are validated + await expect( + executor.generateText({ + model: v1Model, + messages: [{ role: 'user', content: 'Test' }] + }) + ).resolves.toBeDefined() + }) + }) +}) diff --git a/packages/aiCore/src/core/runtime/__tests__/pluginEngine.test.ts b/packages/aiCore/src/core/runtime/__tests__/pluginEngine.test.ts new file mode 100644 index 0000000000..e0dedf1521 --- /dev/null +++ b/packages/aiCore/src/core/runtime/__tests__/pluginEngine.test.ts @@ -0,0 +1,867 @@ +/** + * PluginEngine Comprehensive Tests + * Tests plugin lifecycle, execution order, and coordination + * Covers both streaming and non-streaming execution paths + */ + +import type { ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import { createMockImageModel, createMockLanguageModel } from '../../../__tests__' +import { ModelResolutionError, RecursiveDepthError } from '../../errors' +import type { AiPlugin, GenerateTextParams, GenerateTextResult } from '../../plugins' +import { PluginEngine } from '../pluginEngine' + +describe('PluginEngine', () => { + let engine: PluginEngine<'openai'> + let mockLanguageModel: LanguageModelV3 + let mockImageModel: ImageModelV3 + + beforeEach(() => { + vi.clearAllMocks() + + mockLanguageModel = createMockLanguageModel({ + provider: 'openai', + modelId: 'gpt-4' + }) + + mockImageModel = createMockImageModel({ + provider: 'openai', + modelId: 'dall-e-3' + }) + }) + + describe('Plugin Registration and Management', () => { + it('should create engine with empty plugins', () => { + engine = new PluginEngine('openai', []) + + expect(engine.getPlugins()).toEqual([]) + }) + + it('should create engine with initial plugins', () => { + const plugin1: AiPlugin = { name: 'plugin-1' } + const plugin2: AiPlugin = { name: 'plugin-2' } + + engine = new PluginEngine('openai', [plugin1, plugin2]) + + expect(engine.getPlugins()).toHaveLength(2) + expect(engine.getPlugins()).toEqual([plugin1, plugin2]) + }) + + it('should add plugin with use()', () => { + engine = new PluginEngine('openai', []) + const plugin: AiPlugin = { name: 'test-plugin' } + + const result = engine.use(plugin) + + expect(result).toBe(engine) // Chainable + expect(engine.getPlugins()).toContain(plugin) + }) + + it('should add multiple plugins with usePlugins()', () => { + engine = new PluginEngine('openai', []) + const plugins: AiPlugin[] = [{ name: 'plugin-1' }, { name: 'plugin-2' }, { name: 'plugin-3' }] + + const result = engine.usePlugins(plugins) + + expect(result).toBe(engine) // Chainable + expect(engine.getPlugins()).toHaveLength(3) + }) + + it('should support method chaining', () => { + engine = new PluginEngine('openai', []) + + engine + .use({ name: 'plugin-1' }) + .use({ name: 'plugin-2' }) + .usePlugins([{ name: 'plugin-3' }]) + + expect(engine.getPlugins()).toHaveLength(3) + }) + + it('should remove plugin by name', () => { + const plugin1: AiPlugin = { name: 'plugin-1' } + const plugin2: AiPlugin = { name: 'plugin-2' } + + engine = new PluginEngine('openai', [plugin1, plugin2]) + + engine.removePlugin('plugin-1') + + expect(engine.getPlugins()).toHaveLength(1) + expect(engine.getPlugins()[0]).toBe(plugin2) + }) + + it('should not error when removing non-existent plugin', () => { + engine = new PluginEngine('openai', [{ name: 'existing' }]) + + engine.removePlugin('non-existent') + + expect(engine.getPlugins()).toHaveLength(1) + }) + + it('should get plugin statistics', () => { + const plugins: AiPlugin[] = [ + { name: 'plugin-1', enforce: 'pre' }, + { name: 'plugin-2' }, + { name: 'plugin-3', enforce: 'post' } + ] + + engine = new PluginEngine('openai', plugins) + + const stats = engine.getPluginStats() + + expect(stats).toHaveProperty('total') + expect(stats.total).toBe(3) + }) + + it('should preserve plugin order', () => { + const plugins: AiPlugin[] = [{ name: 'first' }, { name: 'second' }, { name: 'third' }] + + engine = new PluginEngine('openai', plugins) + + const retrieved = engine.getPlugins() + expect(retrieved[0].name).toBe('first') + expect(retrieved[1].name).toBe('second') + expect(retrieved[2].name).toBe('third') + }) + }) + + describe('Plugin Lifecycle - Non-Streaming', () => { + it('should execute all plugin hooks in correct order', async () => { + const executionOrder: string[] = [] + + const plugin: AiPlugin = { + name: 'lifecycle-test', + configureContext: vi.fn(async () => { + executionOrder.push('configureContext') + }), + onRequestStart: vi.fn(async () => { + executionOrder.push('onRequestStart') + }), + resolveModel: vi.fn(async () => { + executionOrder.push('resolveModel') + return mockLanguageModel + }), + transformParams: vi.fn(async (params) => { + executionOrder.push('transformParams') + return params + }), + transformResult: vi.fn(async (result) => { + executionOrder.push('transformResult') + return result + }), + onRequestEnd: vi.fn(async () => { + executionOrder.push('onRequestEnd') + }) + } + + engine = new PluginEngine('openai', [plugin]) + + const mockExecutor = vi.fn().mockResolvedValue({ text: 'test', finishReason: 'stop' }) + + await engine.executeWithPlugins( + 'generateText', + { model: 'gpt-4', messages: [] }, + mockExecutor + ) + + expect(executionOrder).toEqual([ + 'configureContext', + 'onRequestStart', + 'resolveModel', + 'transformParams', + 'transformResult', + 'onRequestEnd' + ]) + }) + + it('should call configureContext before other hooks', async () => { + const configureContextSpy = vi.fn() + const onRequestStartSpy = vi.fn() + + const plugin: AiPlugin = { + name: 'test', + configureContext: configureContextSpy, + onRequestStart: onRequestStartSpy, + resolveModel: vi.fn().mockResolvedValue(mockLanguageModel) + } + + engine = new PluginEngine('openai', [plugin]) + + await engine.executeWithPlugins( + 'generateText', + { model: 'gpt-4', messages: [] }, + vi.fn().mockResolvedValue({ text: 'test' }) + ) + + expect(configureContextSpy).toHaveBeenCalled() + expect(onRequestStartSpy).toHaveBeenCalled() + + // configureContext should be called before onRequestStart + expect(configureContextSpy.mock.invocationCallOrder[0]).toBeLessThan( + onRequestStartSpy.mock.invocationCallOrder[0] + ) + }) + + it('should execute onRequestEnd after successful execution', async () => { + const onRequestEndSpy = vi.fn() + + const plugin: AiPlugin = { + name: 'test', + resolveModel: vi.fn().mockResolvedValue(mockLanguageModel), + onRequestEnd: onRequestEndSpy + } + + engine = new PluginEngine('openai', [plugin]) + + await engine.executeWithPlugins( + 'generateText', + { model: 'gpt-4', messages: [] }, + vi.fn().mockResolvedValue({ text: 'result' }) + ) + + expect(onRequestEndSpy).toHaveBeenCalledWith( + expect.any(Object), // context + expect.objectContaining({ text: 'result' }) + ) + }) + + it('should execute onError on failure', async () => { + const onErrorSpy = vi.fn() + const testError = new Error('Test error') + + const plugin: AiPlugin = { + name: 'error-handler', + resolveModel: vi.fn().mockResolvedValue(mockLanguageModel), + onError: onErrorSpy + } + + engine = new PluginEngine('openai', [plugin]) + + await expect( + engine.executeWithPlugins( + 'generateText', + { model: 'gpt-4', messages: [] }, + vi.fn().mockRejectedValue(testError) + ) + ).rejects.toThrow('Test error') + + expect(onErrorSpy).toHaveBeenCalledWith( + testError, + expect.any(Object) // context + ) + }) + + it('should not call onRequestEnd when error occurs', async () => { + const onRequestEndSpy = vi.fn() + + const plugin: AiPlugin = { + name: 'test', + resolveModel: vi.fn().mockResolvedValue(mockLanguageModel), + onRequestEnd: onRequestEndSpy + } + + engine = new PluginEngine('openai', [plugin]) + + await expect( + engine.executeWithPlugins( + 'generateText', + { model: 'gpt-4', messages: [] }, + vi.fn().mockRejectedValue(new Error('Execution error')) + ) + ).rejects.toThrow() + + expect(onRequestEndSpy).not.toHaveBeenCalled() + }) + }) + + describe('Model Resolution', () => { + it('should resolve string model through plugin', async () => { + const resolveModelSpy = vi.fn().mockResolvedValue(mockLanguageModel) + + const plugin: AiPlugin = { + name: 'resolver', + resolveModel: resolveModelSpy + } + + engine = new PluginEngine('openai', [plugin]) + + await engine.executeWithPlugins( + 'generateText', + { model: 'gpt-4', messages: [] }, + vi.fn().mockResolvedValue({ text: 'test' }) + ) + + expect(resolveModelSpy).toHaveBeenCalledWith('gpt-4', expect.any(Object)) + }) + + it('should use first plugin that resolves model', async () => { + const resolver1 = vi.fn().mockResolvedValue(mockLanguageModel) + const resolver2 = vi.fn().mockResolvedValue(mockLanguageModel) + + const plugin1: AiPlugin = { name: 'resolver-1', resolveModel: resolver1 } + const plugin2: AiPlugin = { name: 'resolver-2', resolveModel: resolver2 } + + engine = new PluginEngine('openai', [plugin1, plugin2]) + + await engine.executeWithPlugins( + 'generateText', + { model: 'gpt-4', messages: [] }, + vi.fn().mockResolvedValue({ text: 'test' }) + ) + + expect(resolver1).toHaveBeenCalled() + expect(resolver2).not.toHaveBeenCalled() // Should stop after first resolver + }) + + it('should throw ModelResolutionError if no plugin resolves model', async () => { + const plugin: AiPlugin = { + name: 'no-resolver' + // No resolveModel hook + } + + engine = new PluginEngine('openai', [plugin]) + + await expect( + engine.executeWithPlugins( + 'generateText', + { model: 'unknown-model', messages: [] }, + vi.fn().mockResolvedValue({ text: 'test' }) + ) + ).rejects.toThrow(ModelResolutionError) + }) + + it('should skip resolution for direct model objects', async () => { + const resolveModelSpy = vi.fn() + + const plugin: AiPlugin = { + name: 'resolver', + resolveModel: resolveModelSpy + } + + engine = new PluginEngine('openai', [plugin]) + + await engine.executeWithPlugins( + 'generateText', + { model: mockLanguageModel, messages: [] }, + vi.fn().mockResolvedValue({ text: 'test' }) + ) + + expect(resolveModelSpy).not.toHaveBeenCalled() + }) + + it('should throw if resolved model is null/undefined', async () => { + const plugin: AiPlugin = { + name: 'bad-resolver', + resolveModel: vi.fn().mockResolvedValue(null) + } + + engine = new PluginEngine('openai', [plugin]) + + await expect( + engine.executeWithPlugins( + 'generateText', + { model: 'gpt-4', messages: [] }, + vi.fn().mockResolvedValue({ text: 'test' }) + ) + ).rejects.toThrow(ModelResolutionError) + }) + }) + + describe('Parameter Transformation', () => { + it('should transform parameters through plugin', async () => { + const transformParamsSpy = vi.fn().mockImplementation(async (params) => ({ + ...params, + temperature: 0.8 + })) + + const plugin: AiPlugin = { + name: 'transformer', + resolveModel: vi.fn().mockResolvedValue(mockLanguageModel), + transformParams: transformParamsSpy + } + + engine = new PluginEngine('openai', [plugin]) + + const mockExecutor = vi.fn().mockResolvedValue({ text: 'test' }) + + await engine.executeWithPlugins('generateText', { model: 'gpt-4', messages: [] }, mockExecutor) + + expect(mockExecutor).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + temperature: 0.8 + }) + ) + }) + + it('should chain parameter transformations across plugins', async () => { + const plugin1: AiPlugin = { + name: 'transform-1', + transformParams: vi.fn().mockImplementation(async (params) => ({ + ...params, + temperature: 0.5 + })) + } + + const plugin2: AiPlugin = { + name: 'transform-2', + transformParams: vi.fn().mockImplementation(async (params) => ({ + ...params, + maxTokens: 100 + })) + } + + engine = new PluginEngine('openai', [plugin1, plugin2]) + engine.usePlugins([{ name: 'resolver', resolveModel: vi.fn().mockResolvedValue(mockLanguageModel) }]) + + const mockExecutor = vi.fn().mockResolvedValue({ text: 'test' }) + + await engine.executeWithPlugins('generateText', { model: 'gpt-4', messages: [] }, mockExecutor) + + expect(mockExecutor).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + temperature: 0.5, + maxTokens: 100 + }) + ) + }) + }) + + describe('Result Transformation', () => { + it('should transform result through plugin', async () => { + const transformResultSpy = vi.fn().mockImplementation(async (result) => ({ + ...result, + text: `${result.text} [modified]` + })) + + const plugin: AiPlugin = { + name: 'result-transformer', + resolveModel: vi.fn().mockResolvedValue(mockLanguageModel), + transformResult: transformResultSpy + } + + engine = new PluginEngine('openai', [plugin]) + + const result = await engine.executeWithPlugins( + 'generateText', + { model: 'gpt-4', messages: [] }, + vi.fn().mockResolvedValue({ text: 'original' }) + ) + + expect(result.text).toBe('original [modified]') + }) + + it('should chain result transformations across plugins', async () => { + const plugin1: AiPlugin = { + name: 'transform-1', + transformResult: vi.fn().mockImplementation(async (result) => ({ + ...result, + text: `${result.text} + plugin1` + })) + } + + const plugin2: AiPlugin = { + name: 'transform-2', + transformResult: vi.fn().mockImplementation(async (result) => ({ + ...result, + text: `${result.text} + plugin2` + })) + } + + engine = new PluginEngine('openai', [plugin1, plugin2]) + engine.usePlugins([{ name: 'resolver', resolveModel: vi.fn().mockResolvedValue(mockLanguageModel) }]) + + const result = await engine.executeWithPlugins( + 'generateText', + { model: 'gpt-4', messages: [] }, + vi.fn().mockResolvedValue({ text: 'base' }) + ) + + expect(result.text).toBe('base + plugin1 + plugin2') + }) + }) + + describe('Recursive Calls', () => { + it('should support recursive calls through context', async () => { + let recursionCount = 0 + + const plugin: AiPlugin = { + name: 'recursive', + resolveModel: vi.fn().mockResolvedValue(mockLanguageModel), + transformParams: vi.fn().mockImplementation(async (params, context) => { + if (recursionCount < 2 && context.recursiveCall) { + recursionCount++ + await context.recursiveCall({ messages: [{ role: 'user', content: 'recursive' }] }) + } + return params + }) + } + + engine = new PluginEngine('openai', [plugin]) + + await engine.executeWithPlugins( + 'generateText', + { model: 'gpt-4', messages: [] }, + vi.fn().mockResolvedValue({ text: 'test' }) + ) + + expect(recursionCount).toBe(2) + }) + + it('should track recursion depth', async () => { + const depths: number[] = [] + + const plugin: AiPlugin = { + name: 'depth-tracker', + resolveModel: vi.fn().mockResolvedValue(mockLanguageModel), + transformParams: vi.fn().mockImplementation(async (params, context) => { + depths.push(context.recursiveDepth) + + if (context.recursiveDepth < 3 && context.recursiveCall) { + await context.recursiveCall({ messages: [] }) + } + + return params + }) + } + + engine = new PluginEngine('openai', [plugin]) + + await engine.executeWithPlugins( + 'generateText', + { model: 'gpt-4', messages: [] }, + vi.fn().mockResolvedValue({ text: 'test' }) + ) + + expect(depths).toEqual([0, 1, 2, 3]) + }) + + it('should throw RecursiveDepthError when max depth exceeded', async () => { + const plugin: AiPlugin = { + name: 'infinite', + resolveModel: vi.fn().mockResolvedValue(mockLanguageModel), + transformParams: vi.fn().mockImplementation(async (params, context) => { + if (context.recursiveCall) { + await context.recursiveCall({ messages: [] }) + } + return params + }) + } + + engine = new PluginEngine('openai', [plugin]) + + await expect( + engine.executeWithPlugins( + 'generateText', + { model: 'gpt-4', messages: [] }, + vi.fn().mockResolvedValue({ text: 'test' }) + ) + ).rejects.toThrow(RecursiveDepthError) + }) + + it('should restore recursion state after recursive call', async () => { + const states: Array<{ depth: number; isRecursive: boolean }> = [] + + const plugin: AiPlugin = { + name: 'state-tracker', + resolveModel: vi.fn().mockResolvedValue(mockLanguageModel), + transformParams: vi.fn().mockImplementation(async (params, context) => { + states.push({ depth: context.recursiveDepth, isRecursive: context.isRecursiveCall }) + + if (context.recursiveDepth === 0 && context.recursiveCall) { + await context.recursiveCall({ messages: [] }) + states.push({ depth: context.recursiveDepth, isRecursive: context.isRecursiveCall }) + } + + return params + }) + } + + engine = new PluginEngine('openai', [plugin]) + + await engine.executeWithPlugins( + 'generateText', + { model: 'gpt-4', messages: [] }, + vi.fn().mockResolvedValue({ text: 'test' }) + ) + + expect(states[0]).toEqual({ depth: 0, isRecursive: false }) + expect(states[1]).toEqual({ depth: 1, isRecursive: true }) + expect(states[2]).toEqual({ depth: 0, isRecursive: false }) + }) + }) + + describe('Image Model Execution', () => { + it('should execute image generation with plugins', async () => { + const plugin: AiPlugin = { + name: 'image-plugin', + resolveModel: vi.fn().mockResolvedValue(mockImageModel), + transformParams: vi.fn().mockImplementation(async (params) => params) + } + + engine = new PluginEngine('openai', [plugin]) + + const mockExecutor = vi.fn().mockResolvedValue({ + image: { base64: 'test', uint8Array: new Uint8Array(), mimeType: 'image/png' } + }) + + await engine.executeImageWithPlugins('generateImage', { model: 'dall-e-3', prompt: 'test' }, mockExecutor) + + expect(plugin.resolveModel).toHaveBeenCalledWith('dall-e-3', expect.any(Object)) + expect(mockExecutor).toHaveBeenCalled() + }) + + it('should skip resolution for direct image model objects', async () => { + const resolveModelSpy = vi.fn() + + const plugin: AiPlugin = { + name: 'image-resolver', + resolveModel: resolveModelSpy + } + + engine = new PluginEngine('openai', [plugin]) + + await engine.executeImageWithPlugins( + 'generateImage', + { model: mockImageModel, prompt: 'test' }, + vi.fn().mockResolvedValue({ image: {} }) + ) + + expect(resolveModelSpy).not.toHaveBeenCalled() + }) + }) + + describe('Streaming Execution', () => { + it('should execute streaming with plugins', async () => { + const plugin: AiPlugin = { + name: 'stream-plugin', + resolveModel: vi.fn().mockResolvedValue(mockLanguageModel), + transformParams: vi.fn().mockImplementation(async (params) => params) + } + + engine = new PluginEngine('openai', [plugin]) + + const mockExecutor = vi.fn().mockResolvedValue({ + textStream: (async function* () { + yield 'test' + })() + }) + + await engine.executeStreamWithPlugins('streamText', { model: 'gpt-4', messages: [] }, mockExecutor) + + expect(plugin.resolveModel).toHaveBeenCalled() + expect(mockExecutor).toHaveBeenCalled() + }) + + it('should collect stream transforms from plugins', async () => { + const mockTransform = vi.fn() + + const plugin: AiPlugin = { + name: 'stream-transformer', + resolveModel: vi.fn().mockResolvedValue(mockLanguageModel), + transformStream: mockTransform + } + + engine = new PluginEngine('openai', [plugin]) + + const mockExecutor = vi.fn().mockResolvedValue({ textStream: (async function* () {})() }) + + await engine.executeStreamWithPlugins('streamText', { model: 'gpt-4', messages: [] }, mockExecutor) + + // Executor should receive stream transforms + expect(mockExecutor).toHaveBeenCalledWith(expect.any(Object), expect.any(Object), expect.arrayContaining([])) + }) + }) + + describe('Context Management', () => { + it('should create context with correct provider and model', async () => { + const configureContextSpy = vi.fn() + + const plugin: AiPlugin = { + name: 'context-checker', + configureContext: configureContextSpy, + resolveModel: vi.fn().mockResolvedValue(mockLanguageModel) + } + + engine = new PluginEngine('openai', [plugin]) + + await engine.executeWithPlugins( + 'generateText', + { model: 'gpt-4', messages: [] }, + vi.fn().mockResolvedValue({ text: 'test' }) + ) + + expect(configureContextSpy).toHaveBeenCalledWith( + expect.objectContaining({ + providerId: 'openai', + model: 'gpt-4' + }) + ) + }) + + it('should pass context to all hooks', async () => { + const contextRefs: any[] = [] + + const plugin: AiPlugin = { + name: 'context-tracker', + configureContext: vi.fn().mockImplementation(async (context) => { + contextRefs.push(context) + }), + onRequestStart: vi.fn().mockImplementation(async (context) => { + contextRefs.push(context) + }), + resolveModel: vi.fn().mockImplementation(async (_, context) => { + contextRefs.push(context) + return mockLanguageModel + }), + transformParams: vi.fn().mockImplementation(async (params, context) => { + contextRefs.push(context) + return params + }) + } + + engine = new PluginEngine('openai', [plugin]) + + await engine.executeWithPlugins( + 'generateText', + { model: 'gpt-4', messages: [] }, + vi.fn().mockResolvedValue({ text: 'test' }) + ) + + // All context refs should point to the same object + expect(contextRefs.length).toBeGreaterThan(0) + const firstContext = contextRefs[0] + contextRefs.forEach((context) => { + expect(context).toBe(firstContext) + }) + }) + }) + + describe('Error Handling', () => { + it('should propagate errors from executor', async () => { + const plugin: AiPlugin = { + name: 'test', + resolveModel: vi.fn().mockResolvedValue(mockLanguageModel) + } + + engine = new PluginEngine('openai', [plugin]) + + const error = new Error('Executor failed') + + await expect( + engine.executeWithPlugins('generateText', { model: 'gpt-4', messages: [] }, vi.fn().mockRejectedValue(error)) + ).rejects.toThrow('Executor failed') + }) + + it('should trigger onError for all plugins on failure', async () => { + const onError1 = vi.fn() + const onError2 = vi.fn() + + const plugin1: AiPlugin = { name: 'error-1', onError: onError1 } + const plugin2: AiPlugin = { name: 'error-2', onError: onError2 } + + engine = new PluginEngine('openai', [plugin1, plugin2]) + engine.usePlugins([{ name: 'resolver', resolveModel: vi.fn().mockResolvedValue(mockLanguageModel) }]) + + const error = new Error('Test failure') + + await expect( + engine.executeWithPlugins('generateText', { model: 'gpt-4', messages: [] }, vi.fn().mockRejectedValue(error)) + ).rejects.toThrow() + + expect(onError1).toHaveBeenCalledWith(error, expect.any(Object)) + expect(onError2).toHaveBeenCalledWith(error, expect.any(Object)) + }) + + it('should handle errors in plugin hooks gracefully', async () => { + const plugin: AiPlugin = { + name: 'failing-plugin', + resolveModel: vi.fn().mockResolvedValue(mockLanguageModel), + transformParams: vi.fn().mockRejectedValue(new Error('Transform failed')) + } + + engine = new PluginEngine('openai', [plugin]) + + await expect( + engine.executeWithPlugins('generateText', { model: 'gpt-4', messages: [] }, vi.fn()) + ).rejects.toThrow('Transform failed') + }) + }) + + describe('Plugin Enforcement', () => { + it('should respect plugin enforce ordering (pre, normal, post)', async () => { + const executionOrder: string[] = [] + + const prePlugin: AiPlugin = { + name: 'pre-plugin', + enforce: 'pre', + onRequestStart: vi.fn(async () => { + executionOrder.push('pre') + }) + } + + const normalPlugin: AiPlugin = { + name: 'normal-plugin', + onRequestStart: vi.fn(async () => { + executionOrder.push('normal') + }) + } + + const postPlugin: AiPlugin = { + name: 'post-plugin', + enforce: 'post', + onRequestStart: vi.fn(async () => { + executionOrder.push('post') + }) + } + + // Add in reverse order to test sorting + engine = new PluginEngine('openai', [postPlugin, normalPlugin, prePlugin]) + engine.usePlugins([{ name: 'resolver', resolveModel: vi.fn().mockResolvedValue(mockLanguageModel) }]) + + await engine.executeWithPlugins( + 'generateText', + { model: 'gpt-4', messages: [] }, + vi.fn().mockResolvedValue({ text: 'test' }) + ) + + // Should execute in order: pre -> normal -> post + expect(executionOrder).toEqual(['pre', 'normal', 'post']) + }) + }) + + describe('Type Safety', () => { + it('should properly type plugin parameters', async () => { + const transformParamsSpy = vi.fn().mockImplementation(async (params) => { + // Type assertions for safety + expect(params).toHaveProperty('messages') + return params + }) + + const transformResultSpy = vi.fn().mockImplementation(async (result) => { + expect(result).toHaveProperty('text') + return result + }) + + const typedPlugin: AiPlugin = { + name: 'typed-plugin', + resolveModel: vi.fn().mockResolvedValue(mockLanguageModel), + transformParams: transformParamsSpy, + transformResult: transformResultSpy + } + + engine = new PluginEngine('openai', [typedPlugin]) + + await engine.executeWithPlugins( + 'generateText', + { model: 'gpt-4', messages: [] }, + vi.fn().mockResolvedValue({ text: 'test', finishReason: 'stop' }) + ) + + expect(transformParamsSpy).toHaveBeenCalled() + expect(transformResultSpy).toHaveBeenCalled() + }) + }) +}) diff --git a/packages/aiCore/src/core/runtime/pluginEngine.ts b/packages/aiCore/src/core/runtime/pluginEngine.ts index 3a81fa4d7b..7f992a3b03 100644 --- a/packages/aiCore/src/core/runtime/pluginEngine.ts +++ b/packages/aiCore/src/core/runtime/pluginEngine.ts @@ -75,10 +75,7 @@ export class PluginEngine { * 执行带插件的操作(非流式) * 提供给AiExecutor使用 */ - async executeWithPlugins< - TParams extends GenerateTextParams, - TResult extends GenerateTextResult - >( + async executeWithPlugins( methodName: string, params: TParams, executor: (model: LanguageModel, transformedParams: TParams) => TResult, @@ -101,9 +98,7 @@ export class PluginEngine { const context = _context ?? createContext(this.providerId, model, params) // ✅ 创建类型化的 manager(逆变安全) - const manager = new PluginManager( - this.basePlugins as AiPlugin[] - ) + const manager = new PluginManager(this.basePlugins as AiPlugin[]) // ✅ 递归调用泛型化,增加深度限制 context.recursiveCall = async (newParams: Partial): Promise => { @@ -118,12 +113,12 @@ export class PluginEngine { context.recursiveDepth = previousDepth + 1 context.isRecursiveCall = true - return await this.executeWithPlugins( + return (await this.executeWithPlugins( methodName, { ...params, ...newParams } as TParams, executor, context - ) as unknown as R + )) as unknown as R } finally { // ✅ finally 确保状态恢复 context.recursiveDepth = previousDepth @@ -201,9 +196,7 @@ export class PluginEngine { const context = _context ?? createContext(this.providerId, model, params) // ✅ 创建类型化的 manager(逆变安全) - const manager = new PluginManager( - this.basePlugins as AiPlugin[] - ) + const manager = new PluginManager(this.basePlugins as AiPlugin[]) // ✅ 递归调用泛型化,增加深度限制 context.recursiveCall = async (newParams: Partial): Promise => { @@ -218,12 +211,12 @@ export class PluginEngine { context.recursiveDepth = previousDepth + 1 context.isRecursiveCall = true - return await this.executeImageWithPlugins( + return (await this.executeImageWithPlugins( methodName, { ...params, ...newParams } as TParams, executor, context - ) as unknown as R + )) as unknown as R } finally { // ✅ finally 确保状态恢复 context.recursiveDepth = previousDepth @@ -275,10 +268,7 @@ export class PluginEngine { * 执行流式调用的通用逻辑(支持流转换器) * 提供给AiExecutor使用 */ - async executeStreamWithPlugins< - TParams extends StreamTextParams, - TResult extends StreamTextResult - >( + async executeStreamWithPlugins( methodName: string, params: TParams, executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => TResult, @@ -301,9 +291,7 @@ export class PluginEngine { const context = _context ?? createContext(this.providerId, model, params) // ✅ 创建类型化的 manager(逆变安全) - const manager = new PluginManager( - this.basePlugins as AiPlugin[] - ) + const manager = new PluginManager(this.basePlugins as AiPlugin[]) // ✅ 递归调用泛型化,增加深度限制 context.recursiveCall = async (newParams: Partial): Promise => { @@ -318,12 +306,12 @@ export class PluginEngine { context.recursiveDepth = previousDepth + 1 context.isRecursiveCall = true - return await this.executeStreamWithPlugins( + return (await this.executeStreamWithPlugins( methodName, { ...params, ...newParams } as TParams, executor, context - ) as unknown as R + )) as unknown as R } finally { // ✅ finally 确保状态恢复 context.recursiveDepth = previousDepth