mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-02-12 05:43:15 +08:00
Add comprehensive tests for PluginEngine functionality
- Implement tests for plugin registration, management, and lifecycle - Validate plugin execution order and context management - Test model resolution, parameter transformation, and result transformation - Ensure error handling and recursive call support - Cover streaming execution and image model handling - Verify type safety for plugin parameters
This commit is contained in:
parent
fca41ed966
commit
3c23c32232
@ -11,18 +11,20 @@ import { vi } from 'vitest'
|
||||
*/
|
||||
export function createMockLanguageModel(overrides?: Partial<LanguageModelV3>): LanguageModelV3 {
|
||||
return {
|
||||
specificationVersion: 'v1',
|
||||
specificationVersion: 'v3',
|
||||
provider: 'mock-provider',
|
||||
modelId: 'mock-model',
|
||||
defaultObjectGenerationMode: 'tool',
|
||||
supportedUrls: {},
|
||||
|
||||
doGenerate: vi.fn().mockResolvedValue({
|
||||
text: 'Mock response text',
|
||||
finishReason: 'stop',
|
||||
usage: {
|
||||
promptTokens: 10,
|
||||
completionTokens: 20,
|
||||
totalTokens: 30
|
||||
inputTokens: 10,
|
||||
outputTokens: 20,
|
||||
totalTokens: 30,
|
||||
inputTokenDetails: {},
|
||||
outputTokenDetails: {}
|
||||
},
|
||||
rawCall: { rawPrompt: null, rawSettings: {} },
|
||||
rawResponse: { headers: {} },
|
||||
@ -47,9 +49,11 @@ export function createMockLanguageModel(overrides?: Partial<LanguageModelV3>): L
|
||||
type: 'finish',
|
||||
finishReason: 'stop',
|
||||
usage: {
|
||||
promptTokens: 10,
|
||||
completionTokens: 15,
|
||||
totalTokens: 25
|
||||
inputTokens: 10,
|
||||
outputTokens: 15,
|
||||
totalTokens: 25,
|
||||
inputTokenDetails: {},
|
||||
outputTokenDetails: {}
|
||||
}
|
||||
}
|
||||
})(),
|
||||
@ -64,12 +68,14 @@ export function createMockLanguageModel(overrides?: Partial<LanguageModelV3>): L
|
||||
|
||||
/**
|
||||
* Creates a mock image model with customizable behavior
|
||||
* Compliant with AI SDK v3 specification
|
||||
*/
|
||||
export function createMockImageModel(overrides?: Partial<ImageModelV3>): ImageModelV3 {
|
||||
return {
|
||||
specificationVersion: 'V3',
|
||||
specificationVersion: 'v3',
|
||||
provider: 'mock-provider',
|
||||
modelId: 'mock-image-model',
|
||||
maxImagesPerCall: undefined,
|
||||
|
||||
doGenerate: vi.fn().mockResolvedValue({
|
||||
images: [
|
||||
|
||||
300
packages/aiCore/src/__tests__/helpers/model-test-utils.ts
Normal file
300
packages/aiCore/src/__tests__/helpers/model-test-utils.ts
Normal file
@ -0,0 +1,300 @@
|
||||
/**
|
||||
* Model Test Utilities
|
||||
* Provides comprehensive mock creators for AI SDK v3 models and related test utilities
|
||||
*/
|
||||
|
||||
import type {
|
||||
EmbeddingModelV3,
|
||||
ImageModelV3,
|
||||
LanguageModelV3,
|
||||
LanguageModelV3Middleware,
|
||||
ProviderV3
|
||||
} from '@ai-sdk/provider'
|
||||
import type { Tool, ToolSet } from 'ai'
|
||||
import { tool } from 'ai'
|
||||
import { MockLanguageModelV3 } from 'ai/test'
|
||||
import { vi } from 'vitest'
|
||||
import * as z from 'zod'
|
||||
|
||||
import { StreamTextParams, StreamTextResult } from '../../core/plugins'
|
||||
import type { ProviderId } from '../../core/providers/types'
|
||||
import { AiRequestContext } from '../../types'
|
||||
|
||||
/**
|
||||
* Type for partial overrides that allows omitting the model field
|
||||
* The model will be automatically added by createMockContext
|
||||
*/
|
||||
type ContextOverrides = Partial<Omit<AiRequestContext<StreamTextParams, StreamTextResult>, 'originalParams'>> & {
|
||||
originalParams?: Partial<Omit<StreamTextParams, 'model'>> & { model?: StreamTextParams['model'] }
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a mock AiRequestContext with type safety
|
||||
* The model field is automatically added to originalParams if not provided
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* const context = createMockContext({
|
||||
* providerId: 'openai',
|
||||
* metadata: { requestId: 'test-123' }
|
||||
* })
|
||||
* ```
|
||||
*/
|
||||
export function createMockContext(overrides?: ContextOverrides): AiRequestContext<StreamTextParams, StreamTextResult> {
|
||||
const mockModel = new MockLanguageModelV3({
|
||||
provider: 'test-provider',
|
||||
modelId: 'test-model'
|
||||
})
|
||||
|
||||
const base: AiRequestContext<StreamTextParams, StreamTextResult> = {
|
||||
providerId: 'openai' as ProviderId,
|
||||
model: mockModel,
|
||||
originalParams: {
|
||||
model: mockModel,
|
||||
messages: [{ role: 'user', content: 'Test message' }]
|
||||
} as StreamTextParams,
|
||||
metadata: {},
|
||||
startTime: Date.now(),
|
||||
requestId: 'test-request-id',
|
||||
recursiveCall: vi.fn(),
|
||||
isRecursiveCall: false,
|
||||
recursiveDepth: 0,
|
||||
maxRecursiveDepth: 10,
|
||||
extensions: new Map()
|
||||
}
|
||||
|
||||
if (overrides) {
|
||||
// Ensure model is always present in originalParams
|
||||
const mergedOriginalParams = {
|
||||
...base.originalParams,
|
||||
...overrides.originalParams,
|
||||
model: overrides.originalParams?.model ?? mockModel
|
||||
}
|
||||
|
||||
return {
|
||||
...base,
|
||||
...overrides,
|
||||
originalParams: mergedOriginalParams as StreamTextParams
|
||||
}
|
||||
}
|
||||
|
||||
return base
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a mock embedding model with customizable behavior
|
||||
* Compliant with AI SDK v3 specification
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* const embeddingModel = createMockEmbeddingModel({
|
||||
* provider: 'openai',
|
||||
* modelId: 'text-embedding-3-small',
|
||||
* maxEmbeddingsPerCall: 2048
|
||||
* })
|
||||
* ```
|
||||
*/
|
||||
export function createMockEmbeddingModel(overrides?: Partial<EmbeddingModelV3>): EmbeddingModelV3 {
|
||||
return {
|
||||
specificationVersion: 'v3',
|
||||
provider: 'mock-provider',
|
||||
modelId: 'mock-embedding-model',
|
||||
maxEmbeddingsPerCall: 100,
|
||||
supportsParallelCalls: true,
|
||||
|
||||
doEmbed: vi.fn().mockResolvedValue({
|
||||
embeddings: [
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
[0.6, 0.7, 0.8, 0.9, 1.0]
|
||||
],
|
||||
usage: {
|
||||
inputTokens: 10,
|
||||
totalTokens: 10
|
||||
},
|
||||
rawResponse: { headers: {} }
|
||||
}),
|
||||
|
||||
...overrides
|
||||
} as EmbeddingModelV3
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a complete mock ProviderV3 with all model types
|
||||
* Useful for testing provider registration and management
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* const provider = createMockProviderV3({
|
||||
* provider: 'openai',
|
||||
* languageModel: customLanguageModel,
|
||||
* imageModel: customImageModel
|
||||
* })
|
||||
* ```
|
||||
*/
|
||||
export function createMockProviderV3(overrides?: {
|
||||
provider?: string
|
||||
languageModel?: (modelId: string) => LanguageModelV3
|
||||
imageModel?: (modelId: string) => ImageModelV3
|
||||
embeddingModel?: (modelId: string) => EmbeddingModelV3
|
||||
}): ProviderV3 {
|
||||
return {
|
||||
specificationVersion: 'v3',
|
||||
provider: overrides?.provider ?? 'mock-provider',
|
||||
|
||||
languageModel: overrides?.languageModel
|
||||
? overrides.languageModel
|
||||
: (modelId: string) =>
|
||||
({
|
||||
specificationVersion: 'v3',
|
||||
provider: overrides?.provider ?? 'mock-provider',
|
||||
modelId,
|
||||
defaultObjectGenerationMode: 'tool',
|
||||
supportedUrls: {},
|
||||
doGenerate: vi.fn(),
|
||||
doStream: vi.fn()
|
||||
}) as LanguageModelV3,
|
||||
|
||||
imageModel: overrides?.imageModel
|
||||
? overrides.imageModel
|
||||
: (modelId: string) =>
|
||||
({
|
||||
specificationVersion: 'v3',
|
||||
provider: overrides?.provider ?? 'mock-provider',
|
||||
modelId,
|
||||
maxImagesPerCall: undefined,
|
||||
doGenerate: vi.fn()
|
||||
}) as ImageModelV3,
|
||||
|
||||
embeddingModel: overrides?.embeddingModel
|
||||
? overrides.embeddingModel
|
||||
: (modelId: string) =>
|
||||
({
|
||||
specificationVersion: 'v3',
|
||||
provider: overrides?.provider ?? 'mock-provider',
|
||||
modelId,
|
||||
maxEmbeddingsPerCall: 100,
|
||||
supportsParallelCalls: true,
|
||||
doEmbed: vi.fn()
|
||||
}) as EmbeddingModelV3
|
||||
} as ProviderV3
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a mock middleware for testing middleware chains
|
||||
* Supports both generate and stream wrapping
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* const middleware = createMockMiddleware({
|
||||
* name: 'test-middleware'
|
||||
* })
|
||||
* ```
|
||||
*/
|
||||
export function createMockMiddleware(_options?: { name?: string }): LanguageModelV3Middleware {
|
||||
return {
|
||||
specificationVersion: 'v3',
|
||||
wrapGenerate: vi.fn((doGenerate) => doGenerate),
|
||||
wrapStream: vi.fn((doStream) => doStream)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a type-safe function tool for testing using AI SDK's tool() function
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* const weatherTool = createMockTool('getWeather', 'Get current weather')
|
||||
* ```
|
||||
*/
|
||||
export function createMockTool(name: string, description?: string): Tool<{ value?: string }, string> {
|
||||
return tool({
|
||||
description: description || `Mock tool: ${name}`,
|
||||
inputSchema: z.object({
|
||||
value: z.string().optional()
|
||||
}),
|
||||
execute: vi.fn(async () => 'mock result')
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a provider-defined tool for testing
|
||||
*/
|
||||
export function createMockProviderTool(name: string, description?: string): { type: 'provider'; description: string } {
|
||||
return {
|
||||
type: 'provider' as const,
|
||||
description: description || `Mock provider tool: ${name}`
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a ToolSet with multiple tools
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* const tools = createMockToolSet({
|
||||
* getWeather: 'function',
|
||||
* searchDatabase: 'function',
|
||||
* nativeSearch: 'provider'
|
||||
* })
|
||||
* ```
|
||||
*/
|
||||
export function createMockToolSet(tools: Record<string, 'function' | 'provider'>): ToolSet {
|
||||
const toolSet: ToolSet = {}
|
||||
|
||||
for (const [name, type] of Object.entries(tools)) {
|
||||
if (type === 'function') {
|
||||
toolSet[name] = createMockTool(name)
|
||||
} else {
|
||||
toolSet[name] = createMockProviderTool(name) as Tool
|
||||
}
|
||||
}
|
||||
|
||||
return toolSet
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates mock stream params for testing
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* const params = createMockStreamParams({
|
||||
* messages: [{ role: 'user', content: 'Custom message' }],
|
||||
* temperature: 0.7
|
||||
* })
|
||||
* ```
|
||||
*/
|
||||
export function createMockStreamParams(overrides?: Partial<StreamTextParams>): StreamTextParams {
|
||||
return {
|
||||
messages: [{ role: 'user', content: 'Test message' }],
|
||||
...overrides
|
||||
} as StreamTextParams
|
||||
}
|
||||
|
||||
/**
|
||||
* Common mock model instances for quick testing
|
||||
*/
|
||||
export const mockModels = {
|
||||
/** Standard language model for general testing */
|
||||
language: new MockLanguageModelV3({
|
||||
provider: 'test-provider',
|
||||
modelId: 'test-model'
|
||||
}),
|
||||
|
||||
/** Mock OpenAI GPT-4 model */
|
||||
gpt4: new MockLanguageModelV3({
|
||||
provider: 'openai',
|
||||
modelId: 'gpt-4'
|
||||
}),
|
||||
|
||||
/** Mock Anthropic Claude model */
|
||||
claude: new MockLanguageModelV3({
|
||||
provider: 'anthropic',
|
||||
modelId: 'claude-3-5-sonnet-20241022'
|
||||
}),
|
||||
|
||||
/** Mock Google Gemini model */
|
||||
gemini: new MockLanguageModelV3({
|
||||
provider: 'google',
|
||||
modelId: 'gemini-2.0-flash-exp'
|
||||
})
|
||||
} as const
|
||||
@ -8,5 +8,6 @@ export * from './fixtures/mock-providers'
|
||||
export * from './fixtures/mock-responses'
|
||||
|
||||
// Helpers
|
||||
export * from './helpers/model-test-utils'
|
||||
export * from './helpers/provider-test-utils'
|
||||
export * from './helpers/test-utils'
|
||||
|
||||
454
packages/aiCore/src/core/models/__tests__/ModelResolver.test.ts
Normal file
454
packages/aiCore/src/core/models/__tests__/ModelResolver.test.ts
Normal file
@ -0,0 +1,454 @@
|
||||
/**
|
||||
* ModelResolver Comprehensive Tests
|
||||
* Tests model resolution logic for language, embedding, and image models
|
||||
* Covers both traditional and namespaced format resolution
|
||||
*/
|
||||
|
||||
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
createMockEmbeddingModel,
|
||||
createMockImageModel,
|
||||
createMockLanguageModel,
|
||||
createMockMiddleware
|
||||
} from '../../../__tests__'
|
||||
import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../../providers/RegistryManagement'
|
||||
import { ModelResolver } from '../ModelResolver'
|
||||
|
||||
// Mock the dependencies
|
||||
vi.mock('../../providers/RegistryManagement', () => ({
|
||||
globalRegistryManagement: {
|
||||
languageModel: vi.fn(),
|
||||
embeddingModel: vi.fn(),
|
||||
imageModel: vi.fn()
|
||||
},
|
||||
DEFAULT_SEPARATOR: '|'
|
||||
}))
|
||||
|
||||
vi.mock('../../middleware/wrapper', () => ({
|
||||
wrapModelWithMiddlewares: vi.fn((model: LanguageModelV3) => {
|
||||
// Return a wrapped model with a marker
|
||||
return {
|
||||
...model,
|
||||
_wrapped: true
|
||||
} as LanguageModelV3
|
||||
})
|
||||
}))
|
||||
|
||||
describe('ModelResolver', () => {
|
||||
let resolver: ModelResolver
|
||||
let mockLanguageModel: LanguageModelV3
|
||||
let mockEmbeddingModel: EmbeddingModelV3
|
||||
let mockImageModel: ImageModelV3
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
resolver = new ModelResolver()
|
||||
|
||||
// Create properly typed mock models using global utilities
|
||||
mockLanguageModel = createMockLanguageModel({
|
||||
provider: 'test-provider',
|
||||
modelId: 'test-model'
|
||||
})
|
||||
|
||||
mockEmbeddingModel = createMockEmbeddingModel({
|
||||
provider: 'test-provider',
|
||||
modelId: 'test-embedding'
|
||||
})
|
||||
|
||||
mockImageModel = createMockImageModel({
|
||||
provider: 'test-provider',
|
||||
modelId: 'test-image'
|
||||
})
|
||||
|
||||
// Setup default mock implementations
|
||||
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
|
||||
vi.mocked(globalRegistryManagement.embeddingModel).mockReturnValue(mockEmbeddingModel)
|
||||
vi.mocked(globalRegistryManagement.imageModel).mockReturnValue(mockImageModel)
|
||||
})
|
||||
|
||||
describe('resolveLanguageModel', () => {
|
||||
describe('Traditional Format Resolution', () => {
|
||||
it('should resolve traditional format modelId without separator', async () => {
|
||||
const result = await resolver.resolveLanguageModel('gpt-4', 'openai')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(`openai${DEFAULT_SEPARATOR}gpt-4`)
|
||||
expect(result).toBe(mockLanguageModel)
|
||||
})
|
||||
|
||||
it('should resolve with different provider and modelId combinations', async () => {
|
||||
const testCases: Array<{ modelId: string; providerId: string; expected: string }> = [
|
||||
{ modelId: 'claude-3-5-sonnet', providerId: 'anthropic', expected: 'anthropic|claude-3-5-sonnet' },
|
||||
{ modelId: 'gemini-2.0-flash', providerId: 'google', expected: 'google|gemini-2.0-flash' },
|
||||
{ modelId: 'grok-2-latest', providerId: 'xai', expected: 'xai|grok-2-latest' },
|
||||
{ modelId: 'deepseek-chat', providerId: 'deepseek', expected: 'deepseek|deepseek-chat' }
|
||||
]
|
||||
|
||||
for (const testCase of testCases) {
|
||||
vi.clearAllMocks()
|
||||
await resolver.resolveLanguageModel(testCase.modelId, testCase.providerId)
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(testCase.expected)
|
||||
}
|
||||
})
|
||||
|
||||
it('should handle modelIds with special characters', async () => {
|
||||
const modelIds = ['model-v1.0', 'model_v2', 'model.2024', 'model:free']
|
||||
|
||||
for (const modelId of modelIds) {
|
||||
vi.clearAllMocks()
|
||||
await resolver.resolveLanguageModel(modelId, 'provider')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(`provider${DEFAULT_SEPARATOR}${modelId}`)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('Namespaced Format Resolution', () => {
|
||||
it('should resolve namespaced format with hub', async () => {
|
||||
const namespacedId = `aihubmix${DEFAULT_SEPARATOR}anthropic${DEFAULT_SEPARATOR}claude-3-5-sonnet`
|
||||
|
||||
const result = await resolver.resolveLanguageModel(namespacedId, 'openai')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(namespacedId)
|
||||
expect(result).toBe(mockLanguageModel)
|
||||
})
|
||||
|
||||
it('should resolve simple namespaced format', async () => {
|
||||
const namespacedId = `provider${DEFAULT_SEPARATOR}model-id`
|
||||
|
||||
await resolver.resolveLanguageModel(namespacedId, 'fallback-provider')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(namespacedId)
|
||||
})
|
||||
|
||||
it('should handle complex namespaced IDs', async () => {
|
||||
const complexIds = [
|
||||
`hub${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}model`,
|
||||
`hub${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}model-v1.0`,
|
||||
`custom${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}gpt-4-turbo`
|
||||
]
|
||||
|
||||
for (const id of complexIds) {
|
||||
vi.clearAllMocks()
|
||||
await resolver.resolveLanguageModel(id, 'fallback')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(id)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('OpenAI Mode Selection', () => {
|
||||
it('should append "-chat" suffix for OpenAI provider with chat mode', async () => {
|
||||
await resolver.resolveLanguageModel('gpt-4', 'openai', { mode: 'chat' })
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai-chat|gpt-4')
|
||||
})
|
||||
|
||||
it('should append "-chat" suffix for Azure provider with chat mode', async () => {
|
||||
await resolver.resolveLanguageModel('gpt-4', 'azure', { mode: 'chat' })
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('azure-chat|gpt-4')
|
||||
})
|
||||
|
||||
it('should not append suffix for OpenAI with responses mode', async () => {
|
||||
await resolver.resolveLanguageModel('gpt-4', 'openai', { mode: 'responses' })
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|gpt-4')
|
||||
})
|
||||
|
||||
it('should not append suffix for OpenAI without mode', async () => {
|
||||
await resolver.resolveLanguageModel('gpt-4', 'openai')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|gpt-4')
|
||||
})
|
||||
|
||||
it('should not append suffix for other providers with chat mode', async () => {
|
||||
await resolver.resolveLanguageModel('claude-3', 'anthropic', { mode: 'chat' })
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('anthropic|claude-3')
|
||||
})
|
||||
|
||||
it('should handle namespaced IDs with OpenAI chat mode', async () => {
|
||||
const namespacedId = `hub${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}gpt-4`
|
||||
|
||||
await resolver.resolveLanguageModel(namespacedId, 'openai', { mode: 'chat' })
|
||||
|
||||
// Should use the namespaced ID directly, not apply mode logic
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(namespacedId)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Middleware Application', () => {
|
||||
it('should apply middlewares to resolved model', async () => {
|
||||
const mockMiddleware = createMockMiddleware({ name: 'test-middleware' })
|
||||
|
||||
const result = await resolver.resolveLanguageModel('gpt-4', 'openai', undefined, [mockMiddleware])
|
||||
|
||||
expect(result).toHaveProperty('_wrapped', true)
|
||||
})
|
||||
|
||||
it('should apply multiple middlewares in order', async () => {
|
||||
const middleware1 = createMockMiddleware({ name: 'middleware-1' })
|
||||
const middleware2 = createMockMiddleware({ name: 'middleware-2' })
|
||||
|
||||
const result = await resolver.resolveLanguageModel('gpt-4', 'openai', undefined, [middleware1, middleware2])
|
||||
|
||||
expect(result).toHaveProperty('_wrapped', true)
|
||||
})
|
||||
|
||||
it('should not apply middlewares when none provided', async () => {
|
||||
const result = await resolver.resolveLanguageModel('gpt-4', 'openai')
|
||||
|
||||
expect(result).not.toHaveProperty('_wrapped')
|
||||
expect(result).toBe(mockLanguageModel)
|
||||
})
|
||||
|
||||
it('should not apply middlewares when empty array provided', async () => {
|
||||
const result = await resolver.resolveLanguageModel('gpt-4', 'openai', undefined, [])
|
||||
|
||||
expect(result).not.toHaveProperty('_wrapped')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider Options Handling', () => {
|
||||
it('should pass provider options correctly', async () => {
|
||||
const options = { baseURL: 'https://api.example.com', apiKey: 'test-key' }
|
||||
|
||||
await resolver.resolveLanguageModel('gpt-4', 'openai', options)
|
||||
|
||||
// Provider options are used for mode selection logic
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle empty provider options', async () => {
|
||||
await resolver.resolveLanguageModel('gpt-4', 'openai', {})
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|gpt-4')
|
||||
})
|
||||
|
||||
it('should handle undefined provider options', async () => {
|
||||
await resolver.resolveLanguageModel('gpt-4', 'openai', undefined)
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|gpt-4')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('resolveTextEmbeddingModel', () => {
|
||||
describe('Traditional Format', () => {
|
||||
it('should resolve traditional embedding model ID', async () => {
|
||||
const result = await resolver.resolveTextEmbeddingModel('text-embedding-ada-002', 'openai')
|
||||
|
||||
expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith('openai|text-embedding-ada-002')
|
||||
expect(result).toBe(mockEmbeddingModel)
|
||||
})
|
||||
|
||||
it('should resolve different embedding models', async () => {
|
||||
const testCases = [
|
||||
{ modelId: 'text-embedding-3-small', providerId: 'openai' },
|
||||
{ modelId: 'text-embedding-3-large', providerId: 'openai' },
|
||||
{ modelId: 'embed-english-v3.0', providerId: 'cohere' },
|
||||
{ modelId: 'voyage-2', providerId: 'voyage' }
|
||||
]
|
||||
|
||||
for (const { modelId, providerId } of testCases) {
|
||||
vi.clearAllMocks()
|
||||
await resolver.resolveTextEmbeddingModel(modelId, providerId)
|
||||
|
||||
expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith(`${providerId}|${modelId}`)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('Namespaced Format', () => {
|
||||
it('should resolve namespaced embedding model ID', async () => {
|
||||
const namespacedId = `aihubmix${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}text-embedding-3-small`
|
||||
|
||||
const result = await resolver.resolveTextEmbeddingModel(namespacedId, 'openai')
|
||||
|
||||
expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith(namespacedId)
|
||||
expect(result).toBe(mockEmbeddingModel)
|
||||
})
|
||||
|
||||
it('should handle complex namespaced embedding IDs', async () => {
|
||||
const complexIds = [
|
||||
`hub${DEFAULT_SEPARATOR}cohere${DEFAULT_SEPARATOR}embed-multilingual`,
|
||||
`custom${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}embedding-model`
|
||||
]
|
||||
|
||||
for (const id of complexIds) {
|
||||
vi.clearAllMocks()
|
||||
await resolver.resolveTextEmbeddingModel(id, 'fallback')
|
||||
|
||||
expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith(id)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('resolveImageModel', () => {
|
||||
describe('Traditional Format', () => {
|
||||
it('should resolve traditional image model ID', async () => {
|
||||
const result = await resolver.resolveImageModel('dall-e-3', 'openai')
|
||||
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('openai|dall-e-3')
|
||||
expect(result).toBe(mockImageModel)
|
||||
})
|
||||
|
||||
it('should resolve different image models', async () => {
|
||||
const testCases = [
|
||||
{ modelId: 'dall-e-2', providerId: 'openai' },
|
||||
{ modelId: 'stable-diffusion-xl', providerId: 'stability' },
|
||||
{ modelId: 'imagen-2', providerId: 'google' },
|
||||
{ modelId: 'midjourney-v6', providerId: 'midjourney' }
|
||||
]
|
||||
|
||||
for (const { modelId, providerId } of testCases) {
|
||||
vi.clearAllMocks()
|
||||
await resolver.resolveImageModel(modelId, providerId)
|
||||
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith(`${providerId}|${modelId}`)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('Namespaced Format', () => {
|
||||
it('should resolve namespaced image model ID', async () => {
|
||||
const namespacedId = `aihubmix${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}dall-e-3`
|
||||
|
||||
const result = await resolver.resolveImageModel(namespacedId, 'openai')
|
||||
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith(namespacedId)
|
||||
expect(result).toBe(mockImageModel)
|
||||
})
|
||||
|
||||
it('should handle complex namespaced image IDs', async () => {
|
||||
const complexIds = [
|
||||
`hub${DEFAULT_SEPARATOR}stability${DEFAULT_SEPARATOR}sdxl-turbo`,
|
||||
`custom${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}image-gen-v2`
|
||||
]
|
||||
|
||||
for (const id of complexIds) {
|
||||
vi.clearAllMocks()
|
||||
await resolver.resolveImageModel(id, 'fallback')
|
||||
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith(id)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases and Error Scenarios', () => {
|
||||
it('should handle empty model IDs', async () => {
|
||||
await resolver.resolveLanguageModel('', 'openai')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|')
|
||||
})
|
||||
|
||||
it('should handle model IDs with multiple separators', async () => {
|
||||
const multiSeparatorId = `hub${DEFAULT_SEPARATOR}sub${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}model`
|
||||
|
||||
await resolver.resolveLanguageModel(multiSeparatorId, 'fallback')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(multiSeparatorId)
|
||||
})
|
||||
|
||||
it('should handle model IDs with only separator', async () => {
|
||||
const onlySeparator = DEFAULT_SEPARATOR
|
||||
|
||||
await resolver.resolveLanguageModel(onlySeparator, 'provider')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(onlySeparator)
|
||||
})
|
||||
|
||||
it('should throw if globalRegistryManagement throws', async () => {
|
||||
const error = new Error('Model not found in registry')
|
||||
vi.mocked(globalRegistryManagement.languageModel).mockImplementation(() => {
|
||||
throw error
|
||||
})
|
||||
|
||||
await expect(resolver.resolveLanguageModel('invalid-model', 'openai')).rejects.toThrow(
|
||||
'Model not found in registry'
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle concurrent resolution requests', async () => {
|
||||
const promises = [
|
||||
resolver.resolveLanguageModel('gpt-4', 'openai'),
|
||||
resolver.resolveLanguageModel('claude-3', 'anthropic'),
|
||||
resolver.resolveLanguageModel('gemini-2.0', 'google')
|
||||
]
|
||||
|
||||
const results = await Promise.all(promises)
|
||||
|
||||
expect(results).toHaveLength(3)
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Type Safety', () => {
|
||||
it('should return properly typed LanguageModelV3', async () => {
|
||||
const result = await resolver.resolveLanguageModel('gpt-4', 'openai')
|
||||
|
||||
// Type assertions
|
||||
expect(result.specificationVersion).toBe('v3')
|
||||
expect(result).toHaveProperty('doGenerate')
|
||||
expect(result).toHaveProperty('doStream')
|
||||
})
|
||||
|
||||
it('should return properly typed EmbeddingModelV3', async () => {
|
||||
const result = await resolver.resolveTextEmbeddingModel('text-embedding-ada-002', 'openai')
|
||||
|
||||
expect(result.specificationVersion).toBe('v3')
|
||||
expect(result).toHaveProperty('doEmbed')
|
||||
})
|
||||
|
||||
it('should return properly typed ImageModelV3', async () => {
|
||||
const result = await resolver.resolveImageModel('dall-e-3', 'openai')
|
||||
|
||||
expect(result.specificationVersion).toBe('v3')
|
||||
expect(result).toHaveProperty('doGenerate')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Global ModelResolver Instance', () => {
|
||||
it('should have a global instance available', async () => {
|
||||
const { globalModelResolver } = await import('../ModelResolver')
|
||||
|
||||
expect(globalModelResolver).toBeInstanceOf(ModelResolver)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Integration with Different Provider Types', () => {
|
||||
it('should work with OpenAI compatible providers', async () => {
|
||||
await resolver.resolveLanguageModel('custom-model', 'openai-compatible')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai-compatible|custom-model')
|
||||
})
|
||||
|
||||
it('should work with hub providers', async () => {
|
||||
const hubId = `aihubmix${DEFAULT_SEPARATOR}custom${DEFAULT_SEPARATOR}model-v1`
|
||||
|
||||
await resolver.resolveLanguageModel(hubId, 'aihubmix')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(hubId)
|
||||
})
|
||||
|
||||
it('should handle all model types for same provider', async () => {
|
||||
const providerId = 'openai'
|
||||
const languageModel = 'gpt-4'
|
||||
const embeddingModel = 'text-embedding-3-small'
|
||||
const imageModel = 'dall-e-3'
|
||||
|
||||
await resolver.resolveLanguageModel(languageModel, providerId)
|
||||
await resolver.resolveTextEmbeddingModel(embeddingModel, providerId)
|
||||
await resolver.resolveImageModel(imageModel, providerId)
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(`${providerId}|${languageModel}`)
|
||||
expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith(`${providerId}|${embeddingModel}`)
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith(`${providerId}|${imageModel}`)
|
||||
})
|
||||
})
|
||||
})
|
||||
171
packages/aiCore/src/core/models/__tests__/utils.test.ts
Normal file
171
packages/aiCore/src/core/models/__tests__/utils.test.ts
Normal file
@ -0,0 +1,171 @@
|
||||
import type { LanguageModelV2, LanguageModelV3 } from '@ai-sdk/provider'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import type { AiSdkModel } from '../../providers'
|
||||
import { hasModelId, isV2Model, isV3Model } from '../utils'
|
||||
|
||||
describe('Model Type Guards', () => {
|
||||
describe('isV2Model', () => {
|
||||
it('should return true for V2 models', () => {
|
||||
const v2Model: AiSdkModel = {
|
||||
specificationVersion: 'v2',
|
||||
modelId: 'test-model',
|
||||
provider: 'test-provider'
|
||||
} as LanguageModelV2
|
||||
|
||||
expect(isV2Model(v2Model)).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for V3 models', () => {
|
||||
const v3Model: AiSdkModel = {
|
||||
specificationVersion: 'v3',
|
||||
modelId: 'test-model',
|
||||
provider: 'test-provider'
|
||||
} as LanguageModelV3
|
||||
|
||||
expect(isV2Model(v3Model)).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for non-object values', () => {
|
||||
expect(isV2Model('model-id' as any)).toBe(false)
|
||||
expect(isV2Model(null as any)).toBe(false)
|
||||
expect(isV2Model(undefined as any)).toBe(false)
|
||||
expect(isV2Model(123 as any)).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for objects without specificationVersion', () => {
|
||||
const invalidModel = {
|
||||
modelId: 'test-model',
|
||||
provider: 'test-provider'
|
||||
} as any
|
||||
|
||||
expect(isV2Model(invalidModel)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isV3Model', () => {
|
||||
it('should return true for V3 models', () => {
|
||||
const v3Model: AiSdkModel = {
|
||||
specificationVersion: 'v3',
|
||||
modelId: 'test-model',
|
||||
provider: 'test-provider'
|
||||
} as LanguageModelV3
|
||||
|
||||
expect(isV3Model(v3Model)).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for V2 models', () => {
|
||||
const v2Model: AiSdkModel = {
|
||||
specificationVersion: 'v2',
|
||||
modelId: 'test-model',
|
||||
provider: 'test-provider'
|
||||
} as LanguageModelV2
|
||||
|
||||
expect(isV3Model(v2Model)).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for non-object values', () => {
|
||||
expect(isV3Model('model-id' as any)).toBe(false)
|
||||
expect(isV3Model(null as any)).toBe(false)
|
||||
expect(isV3Model(undefined as any)).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for objects without specificationVersion', () => {
|
||||
const invalidModel = {
|
||||
modelId: 'test-model',
|
||||
provider: 'test-provider'
|
||||
} as any
|
||||
|
||||
expect(isV3Model(invalidModel)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Type Guard Correctness', () => {
|
||||
it('should correctly distinguish between V2 and V3 models', () => {
|
||||
const v2Model: AiSdkModel = {
|
||||
specificationVersion: 'v2',
|
||||
modelId: 'v2-model'
|
||||
} as LanguageModelV2
|
||||
|
||||
const v3Model: AiSdkModel = {
|
||||
specificationVersion: 'v3',
|
||||
modelId: 'v3-model'
|
||||
} as LanguageModelV3
|
||||
|
||||
// V2 model should only match isV2Model
|
||||
expect(isV2Model(v2Model)).toBe(true)
|
||||
expect(isV3Model(v2Model)).toBe(false)
|
||||
|
||||
// V3 model should only match isV3Model
|
||||
expect(isV2Model(v3Model)).toBe(false)
|
||||
expect(isV3Model(v3Model)).toBe(true)
|
||||
})
|
||||
|
||||
it('should narrow type correctly for V2 models', () => {
|
||||
const model: AiSdkModel = {
|
||||
specificationVersion: 'v2',
|
||||
modelId: 'test'
|
||||
} as LanguageModelV2
|
||||
|
||||
if (isV2Model(model)) {
|
||||
expect(model.specificationVersion).toBe('v2')
|
||||
}
|
||||
})
|
||||
|
||||
it('should narrow type correctly for V3 models', () => {
|
||||
const model: AiSdkModel = {
|
||||
specificationVersion: 'v3',
|
||||
modelId: 'test'
|
||||
} as LanguageModelV3
|
||||
|
||||
if (isV3Model(model)) {
|
||||
expect(model.specificationVersion).toBe('v3')
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('hasModelId', () => {
|
||||
it('should return true for objects with modelId string property', () => {
|
||||
const modelWithId = {
|
||||
modelId: 'test-model-id',
|
||||
other: 'property'
|
||||
}
|
||||
|
||||
expect(hasModelId(modelWithId)).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for objects without modelId property', () => {
|
||||
const modelWithoutId = {
|
||||
other: 'property'
|
||||
}
|
||||
|
||||
expect(hasModelId(modelWithoutId)).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for objects with non-string modelId', () => {
|
||||
const modelWithNumericId = {
|
||||
modelId: 123
|
||||
}
|
||||
|
||||
expect(hasModelId(modelWithNumericId)).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for non-object values', () => {
|
||||
expect(hasModelId(null)).toBe(false)
|
||||
expect(hasModelId(undefined)).toBe(false)
|
||||
expect(hasModelId('string')).toBe(false)
|
||||
expect(hasModelId(123)).toBe(false)
|
||||
expect(hasModelId(true)).toBe(false)
|
||||
})
|
||||
|
||||
it('should narrow type correctly', () => {
|
||||
const unknownValue: unknown = { modelId: 'test-id' }
|
||||
|
||||
if (hasModelId(unknownValue)) {
|
||||
// TypeScript should allow accessing modelId as string
|
||||
expect(typeof unknownValue.modelId).toBe('string')
|
||||
expect(unknownValue.modelId).toBe('test-id')
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -4,6 +4,7 @@
|
||||
* 负责处理 AI SDK 流事件的发送和管理
|
||||
* 从 promptToolUsePlugin.ts 中提取出来以降低复杂度
|
||||
*/
|
||||
import type { SharedV3ProviderMetadata } from '@ai-sdk/provider'
|
||||
import type { EmbeddingModelUsage, ImageModelUsage, LanguageModelUsage, ModelMessage } from 'ai'
|
||||
|
||||
import type { AiSdkUsage } from '../../../providers/types'
|
||||
@ -79,7 +80,11 @@ export class StreamEventManager {
|
||||
*/
|
||||
sendStepFinishEvent(
|
||||
controller: StreamController,
|
||||
chunk: any,
|
||||
chunk: {
|
||||
usage?: Partial<AiSdkUsage>
|
||||
response?: { id: string; [key: string]: unknown }
|
||||
providerMetadata?: SharedV3ProviderMetadata
|
||||
},
|
||||
context: AiRequestContext,
|
||||
finishReason: string = 'stop'
|
||||
): void {
|
||||
@ -154,7 +159,7 @@ export class StreamEventManager {
|
||||
context: AiRequestContext<TParams, StreamTextResult>,
|
||||
textBuffer: string,
|
||||
toolResultsText: string,
|
||||
tools: any
|
||||
tools: Record<string, unknown>
|
||||
): Partial<TParams> {
|
||||
const params = context.originalParams
|
||||
|
||||
|
||||
@ -14,15 +14,16 @@ import type { ToolUseResult } from './type'
|
||||
export interface ExecutedResult {
|
||||
toolCallId: string
|
||||
toolName: string
|
||||
result: any
|
||||
result: unknown
|
||||
isError?: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* 流控制器类型(从 AI SDK 提取)
|
||||
* Generic type parameter allows for type-safe chunk enqueuing
|
||||
*/
|
||||
export interface StreamController {
|
||||
enqueue(chunk: any): void
|
||||
export interface StreamController<TChunk = unknown> {
|
||||
enqueue(chunk: TChunk): void
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -0,0 +1,566 @@
|
||||
import type { SharedV3ProviderMetadata } from '@ai-sdk/provider'
|
||||
import type {
|
||||
EmbeddingModelUsage,
|
||||
ImageModelUsage,
|
||||
LanguageModelUsage as AiSdkUsage,
|
||||
LanguageModelUsage,
|
||||
TextStreamPart,
|
||||
ToolSet
|
||||
} from 'ai'
|
||||
import { simulateReadableStream } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { createMockContext, createMockTool } from '../../../../../__tests__'
|
||||
import { StreamEventManager } from '../StreamEventManager'
|
||||
import type { StreamController } from '../ToolExecutor'
|
||||
|
||||
/**
|
||||
* Type alias for empty toolset (no tools)
|
||||
* Using Record<string, never> ensures type safety for tests without tools
|
||||
*/
|
||||
type EmptyToolSet = Record<string, never>
|
||||
|
||||
/**
|
||||
* Mock StreamController for testing
|
||||
* Provides type-safe enqueue function that accepts TextStreamPart chunks
|
||||
*/
|
||||
interface MockStreamController<TOOLS extends ToolSet = EmptyToolSet> extends StreamController {
|
||||
enqueue: ReturnType<typeof vi.fn<(chunk: TextStreamPart<TOOLS>) => void>>
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a type-safe mock stream controller
|
||||
*/
|
||||
function createMockStreamController<TOOLS extends ToolSet = EmptyToolSet>(): MockStreamController<TOOLS> {
|
||||
return {
|
||||
enqueue: vi.fn()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Type for chunk data in finish-step events
|
||||
*/
|
||||
interface FinishStepChunk {
|
||||
usage?: Partial<AiSdkUsage>
|
||||
response?: { id: string; [key: string]: unknown }
|
||||
providerMetadata?: SharedV3ProviderMetadata
|
||||
}
|
||||
|
||||
describe('StreamEventManager', () => {
|
||||
let manager: StreamEventManager
|
||||
|
||||
beforeEach(() => {
|
||||
manager = new StreamEventManager()
|
||||
})
|
||||
|
||||
describe('accumulateUsage', () => {
|
||||
describe('LanguageModelUsage', () => {
|
||||
it('should accumulate language model usage correctly', () => {
|
||||
const target: Partial<LanguageModelUsage> = {
|
||||
inputTokens: 10,
|
||||
outputTokens: 20,
|
||||
totalTokens: 30
|
||||
}
|
||||
const source: Partial<LanguageModelUsage> = {
|
||||
inputTokens: 5,
|
||||
outputTokens: 10,
|
||||
totalTokens: 15
|
||||
}
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(target.inputTokens).toBe(15)
|
||||
expect(target.outputTokens).toBe(30)
|
||||
expect(target.totalTokens).toBe(45)
|
||||
})
|
||||
|
||||
it('should handle undefined values in target', () => {
|
||||
const target: Partial<LanguageModelUsage> = { inputTokens: 10 }
|
||||
const source: Partial<LanguageModelUsage> = {
|
||||
inputTokens: 5,
|
||||
outputTokens: 10,
|
||||
totalTokens: 15
|
||||
}
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(target.inputTokens).toBe(15)
|
||||
expect(target.outputTokens).toBe(10)
|
||||
expect(target.totalTokens).toBe(15)
|
||||
})
|
||||
|
||||
it('should handle undefined values in source', () => {
|
||||
const target: Partial<LanguageModelUsage> = {
|
||||
inputTokens: 10,
|
||||
outputTokens: 20,
|
||||
totalTokens: 30
|
||||
}
|
||||
const source: Partial<LanguageModelUsage> = { inputTokens: 5 }
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(target.inputTokens).toBe(15)
|
||||
expect(target.outputTokens).toBe(20)
|
||||
expect(target.totalTokens).toBe(30)
|
||||
})
|
||||
|
||||
it('should handle zero values correctly', () => {
|
||||
const target: Partial<LanguageModelUsage> = {
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
totalTokens: 0
|
||||
}
|
||||
const source: Partial<LanguageModelUsage> = {
|
||||
inputTokens: 5,
|
||||
outputTokens: 10,
|
||||
totalTokens: 15
|
||||
}
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(target.inputTokens).toBe(5)
|
||||
expect(target.outputTokens).toBe(10)
|
||||
expect(target.totalTokens).toBe(15)
|
||||
})
|
||||
})
|
||||
|
||||
describe('ImageModelUsage', () => {
|
||||
it('should accumulate image model usage correctly', () => {
|
||||
const target: Partial<ImageModelUsage> = {
|
||||
inputTokens: 100,
|
||||
outputTokens: 50,
|
||||
totalTokens: 150
|
||||
}
|
||||
const source: Partial<ImageModelUsage> = {
|
||||
inputTokens: 50,
|
||||
outputTokens: 25,
|
||||
totalTokens: 75
|
||||
}
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(target.inputTokens).toBe(150)
|
||||
expect(target.outputTokens).toBe(75)
|
||||
expect(target.totalTokens).toBe(225)
|
||||
})
|
||||
|
||||
it('should handle undefined values', () => {
|
||||
const target: Partial<ImageModelUsage> = { inputTokens: 100 }
|
||||
const source: Partial<ImageModelUsage> = {
|
||||
outputTokens: 50,
|
||||
totalTokens: 50
|
||||
}
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(target.inputTokens).toBe(100)
|
||||
expect(target.outputTokens).toBe(50)
|
||||
expect(target.totalTokens).toBe(50)
|
||||
})
|
||||
})
|
||||
|
||||
describe('EmbeddingModelUsage', () => {
|
||||
it('should accumulate embedding model usage correctly', () => {
|
||||
const target: Partial<EmbeddingModelUsage> = { tokens: 100 }
|
||||
const source: Partial<EmbeddingModelUsage> = { tokens: 50 }
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(target.tokens).toBe(150)
|
||||
})
|
||||
|
||||
it('should handle zero to non-zero accumulation', () => {
|
||||
const target: Partial<EmbeddingModelUsage> = { tokens: 0 }
|
||||
const source: Partial<EmbeddingModelUsage> = { tokens: 50 }
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(target.tokens).toBe(50)
|
||||
})
|
||||
|
||||
it('should handle zero values', () => {
|
||||
const target: Partial<EmbeddingModelUsage> = { tokens: 0 }
|
||||
const source: Partial<EmbeddingModelUsage> = { tokens: 100 }
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(target.tokens).toBe(100)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Type Guard Validation', () => {
|
||||
it('should warn on type mismatch between LanguageModelUsage and EmbeddingModelUsage', () => {
|
||||
const warnSpy = vi.spyOn(console, 'warn')
|
||||
const target: Partial<LanguageModelUsage> = { inputTokens: 10 }
|
||||
const source: Partial<EmbeddingModelUsage> = { tokens: 5 }
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Unable to accumulate usage'),
|
||||
expect.objectContaining({
|
||||
target,
|
||||
source
|
||||
})
|
||||
)
|
||||
|
||||
warnSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('should warn on type mismatch between ImageModelUsage and EmbeddingModelUsage', () => {
|
||||
const warnSpy = vi.spyOn(console, 'warn')
|
||||
const target: Partial<ImageModelUsage> = { inputTokens: 100 }
|
||||
const source: Partial<EmbeddingModelUsage> = { tokens: 50 }
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining('Unable to accumulate usage'), expect.any(Object))
|
||||
|
||||
warnSpy.mockRestore()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('buildRecursiveParams', () => {
|
||||
it('should include textBuffer in assistant message when not empty', () => {
|
||||
const context = createMockContext()
|
||||
const textBuffer = 'test response'
|
||||
const toolResultsText = '<tool_use_result>...</tool_use_result>'
|
||||
const tools = {
|
||||
test_tool: createMockTool('test_tool')
|
||||
}
|
||||
|
||||
const params = manager.buildRecursiveParams(context, textBuffer, toolResultsText, tools)
|
||||
|
||||
expect(params.messages).toHaveLength(3)
|
||||
expect(params.messages?.[0]).toEqual({ role: 'user', content: 'Test message' })
|
||||
expect(params.messages?.[1]).toEqual({ role: 'assistant', content: textBuffer })
|
||||
expect(params.messages?.[2]).toEqual({
|
||||
role: 'user',
|
||||
content: toolResultsText
|
||||
})
|
||||
expect(params.tools).toBe(tools)
|
||||
})
|
||||
|
||||
it('should skip empty textBuffer in messages', () => {
|
||||
const context = createMockContext()
|
||||
const textBuffer = ''
|
||||
const toolResultsText = '<tool_use_result>...</tool_use_result>'
|
||||
const tools = {}
|
||||
|
||||
const params = manager.buildRecursiveParams(context, textBuffer, toolResultsText, tools)
|
||||
|
||||
// Should only have original user message and new user message with tool results
|
||||
expect(params.messages).toHaveLength(2)
|
||||
expect(params.messages?.[0]).toEqual({ role: 'user', content: 'Test message' })
|
||||
expect(params.messages?.[1]).toEqual({
|
||||
role: 'user',
|
||||
content: toolResultsText
|
||||
})
|
||||
|
||||
const assistantMessages = params.messages?.filter((m) => m.role === 'assistant')
|
||||
expect(assistantMessages).toHaveLength(0)
|
||||
})
|
||||
|
||||
it('should preserve all original messages', () => {
|
||||
const context = createMockContext({
|
||||
originalParams: {
|
||||
messages: [
|
||||
{ role: 'user', content: 'First message' },
|
||||
{ role: 'assistant', content: 'First response' },
|
||||
{ role: 'user', content: 'Second message' }
|
||||
]
|
||||
}
|
||||
})
|
||||
|
||||
const params = manager.buildRecursiveParams(context, 'New response', 'Tool results', {})
|
||||
|
||||
expect(params.messages).toHaveLength(5)
|
||||
expect(params.messages?.[0]).toEqual({ role: 'user', content: 'First message' })
|
||||
expect(params.messages?.[1]).toEqual({
|
||||
role: 'assistant',
|
||||
content: 'First response'
|
||||
})
|
||||
expect(params.messages?.[2]).toEqual({ role: 'user', content: 'Second message' })
|
||||
expect(params.messages?.[3]).toEqual({ role: 'assistant', content: 'New response' })
|
||||
expect(params.messages?.[4]).toEqual({ role: 'user', content: 'Tool results' })
|
||||
})
|
||||
|
||||
it('should pass through tools parameter', () => {
|
||||
const context = createMockContext()
|
||||
const tools = {
|
||||
tool1: createMockTool('tool1'),
|
||||
tool2: createMockTool('tool2')
|
||||
}
|
||||
|
||||
const params = manager.buildRecursiveParams(context, 'response', 'results', tools)
|
||||
|
||||
expect(params.tools).toBe(tools)
|
||||
expect(Object.keys(params.tools!)).toHaveLength(2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('sendStepStartEvent', () => {
|
||||
it('should enqueue start-step event with correct structure', () => {
|
||||
const controller = createMockStreamController()
|
||||
|
||||
manager.sendStepStartEvent(controller)
|
||||
|
||||
expect(controller.enqueue).toHaveBeenCalledWith({
|
||||
type: 'start-step',
|
||||
request: {},
|
||||
warnings: []
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('sendStepFinishEvent', () => {
|
||||
it('should enqueue finish-step event with provided finishReason', () => {
|
||||
const controller = createMockStreamController()
|
||||
|
||||
const chunk: FinishStepChunk = {
|
||||
usage: {
|
||||
inputTokens: 10,
|
||||
outputTokens: 20,
|
||||
totalTokens: 30
|
||||
},
|
||||
response: { id: 'test-response' },
|
||||
providerMetadata: { 'test-provider': {} }
|
||||
}
|
||||
|
||||
const context = createMockContext({
|
||||
accumulatedUsage: {
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
totalTokens: 0
|
||||
}
|
||||
})
|
||||
|
||||
manager.sendStepFinishEvent(controller, chunk, context, 'tool-calls')
|
||||
|
||||
expect(controller.enqueue).toHaveBeenCalledWith({
|
||||
type: 'finish-step',
|
||||
finishReason: 'tool-calls',
|
||||
response: chunk.response,
|
||||
usage: chunk.usage,
|
||||
providerMetadata: chunk.providerMetadata
|
||||
})
|
||||
})
|
||||
|
||||
it('should accumulate usage when provided', () => {
|
||||
const controller = createMockStreamController()
|
||||
|
||||
const chunk: FinishStepChunk = {
|
||||
usage: {
|
||||
inputTokens: 10,
|
||||
outputTokens: 20,
|
||||
totalTokens: 30
|
||||
}
|
||||
}
|
||||
|
||||
const context = createMockContext({
|
||||
accumulatedUsage: {
|
||||
inputTokens: 5,
|
||||
outputTokens: 10,
|
||||
totalTokens: 15
|
||||
}
|
||||
})
|
||||
|
||||
manager.sendStepFinishEvent(controller, chunk, context)
|
||||
|
||||
// Verify accumulation happened
|
||||
expect(context.accumulatedUsage.inputTokens).toBe(15)
|
||||
expect(context.accumulatedUsage.outputTokens).toBe(30)
|
||||
expect(context.accumulatedUsage.totalTokens).toBe(45)
|
||||
})
|
||||
|
||||
it('should handle missing usage gracefully', () => {
|
||||
const controller = createMockStreamController()
|
||||
|
||||
const chunk: FinishStepChunk = {}
|
||||
const context = createMockContext({
|
||||
accumulatedUsage: {
|
||||
inputTokens: 5,
|
||||
outputTokens: 10,
|
||||
totalTokens: 15
|
||||
}
|
||||
})
|
||||
|
||||
expect(() => manager.sendStepFinishEvent(controller, chunk, context)).not.toThrow()
|
||||
|
||||
// Verify accumulation did not change
|
||||
expect(context.accumulatedUsage.inputTokens).toBe(5)
|
||||
expect(context.accumulatedUsage.outputTokens).toBe(10)
|
||||
expect(context.accumulatedUsage.totalTokens).toBe(15)
|
||||
})
|
||||
|
||||
it('should use default finishReason of "stop" when not provided', () => {
|
||||
const controller = createMockStreamController()
|
||||
|
||||
const chunk: FinishStepChunk = {}
|
||||
const context = createMockContext()
|
||||
|
||||
manager.sendStepFinishEvent(controller, chunk, context)
|
||||
|
||||
expect(controller.enqueue).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
finishReason: 'stop'
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('handleRecursiveCall', () => {
|
||||
it('should reset hasExecutedToolsInCurrentStep flag', async () => {
|
||||
const controller = createMockStreamController()
|
||||
|
||||
const mockStream = simulateReadableStream<TextStreamPart<EmptyToolSet>>({
|
||||
chunks: [
|
||||
{
|
||||
type: 'text-delta',
|
||||
id: 'test-id',
|
||||
text: 'test'
|
||||
} as TextStreamPart<EmptyToolSet>
|
||||
],
|
||||
initialDelayInMs: 0,
|
||||
chunkDelayInMs: 0
|
||||
})
|
||||
|
||||
const context = createMockContext({
|
||||
hasExecutedToolsInCurrentStep: true,
|
||||
recursiveCall: vi.fn().mockResolvedValue({
|
||||
fullStream: mockStream
|
||||
})
|
||||
})
|
||||
|
||||
const params = { messages: [] }
|
||||
|
||||
await manager.handleRecursiveCall(controller, params, context)
|
||||
|
||||
expect(context.hasExecutedToolsInCurrentStep).toBe(false)
|
||||
expect(context.recursiveCall).toHaveBeenCalledWith(params)
|
||||
})
|
||||
|
||||
it('should pipe recursive stream to controller', async () => {
|
||||
const enqueuedChunks: TextStreamPart<EmptyToolSet>[] = []
|
||||
const controller = createMockStreamController()
|
||||
controller.enqueue.mockImplementation((chunk: TextStreamPart<EmptyToolSet>) => {
|
||||
enqueuedChunks.push(chunk)
|
||||
})
|
||||
|
||||
const mockChunks: TextStreamPart<EmptyToolSet>[] = [
|
||||
{ type: 'start' as const },
|
||||
{ type: 'start-step' as const, request: {}, warnings: [] },
|
||||
{ type: 'text-delta' as const, id: 'chunk-1', text: 'recursive' },
|
||||
{ type: 'text-delta' as const, id: 'chunk-2', text: ' response' },
|
||||
{
|
||||
type: 'finish-step' as const,
|
||||
finishReason: 'stop',
|
||||
rawFinishReason: 'stop',
|
||||
response: {
|
||||
id: 'test-response-id',
|
||||
timestamp: new Date(),
|
||||
modelId: 'test-model'
|
||||
},
|
||||
usage: {
|
||||
totalTokens: 0,
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
inputTokenDetails: {
|
||||
noCacheTokens: 0,
|
||||
cacheReadTokens: 0,
|
||||
cacheWriteTokens: 0
|
||||
},
|
||||
outputTokenDetails: {
|
||||
textTokens: 0,
|
||||
reasoningTokens: 0
|
||||
}
|
||||
},
|
||||
providerMetadata: undefined
|
||||
},
|
||||
{
|
||||
type: 'finish' as const,
|
||||
finishReason: 'stop',
|
||||
rawFinishReason: 'stop',
|
||||
totalUsage: {
|
||||
totalTokens: 0,
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
inputTokenDetails: {
|
||||
noCacheTokens: 0,
|
||||
cacheReadTokens: 0,
|
||||
cacheWriteTokens: 0
|
||||
},
|
||||
outputTokenDetails: {
|
||||
textTokens: 0,
|
||||
reasoningTokens: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const mockStream = simulateReadableStream<TextStreamPart<EmptyToolSet>>({
|
||||
chunks: mockChunks,
|
||||
initialDelayInMs: 0,
|
||||
chunkDelayInMs: 0
|
||||
})
|
||||
|
||||
const context = createMockContext({
|
||||
hasExecutedToolsInCurrentStep: true,
|
||||
recursiveCall: vi.fn().mockResolvedValue({
|
||||
fullStream: mockStream
|
||||
})
|
||||
})
|
||||
|
||||
await manager.handleRecursiveCall(controller, {}, context)
|
||||
|
||||
// Should skip 'start' type and stop at 'finish' type
|
||||
expect(enqueuedChunks).toHaveLength(4)
|
||||
expect(enqueuedChunks[0]).toEqual({ type: 'start-step', request: {}, warnings: [] })
|
||||
expect(enqueuedChunks[1]).toEqual({ type: 'text-delta', id: 'chunk-1', text: 'recursive' })
|
||||
expect(enqueuedChunks[2]).toEqual({ type: 'text-delta', id: 'chunk-2', text: ' response' })
|
||||
expect(enqueuedChunks[3]).toMatchObject({
|
||||
type: 'finish-step',
|
||||
finishReason: 'stop',
|
||||
rawFinishReason: 'stop',
|
||||
providerMetadata: undefined,
|
||||
usage: {
|
||||
totalTokens: 0,
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
inputTokenDetails: {
|
||||
noCacheTokens: 0,
|
||||
cacheReadTokens: 0,
|
||||
cacheWriteTokens: 0
|
||||
},
|
||||
outputTokenDetails: {
|
||||
textTokens: 0,
|
||||
reasoningTokens: 0
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should warn when no fullStream is found', async () => {
|
||||
const warnSpy = vi.spyOn(console, 'warn')
|
||||
const controller = createMockStreamController()
|
||||
|
||||
const context = createMockContext({
|
||||
hasExecutedToolsInCurrentStep: true,
|
||||
recursiveCall: vi.fn().mockResolvedValue({
|
||||
// No fullStream property
|
||||
someOtherProperty: 'value'
|
||||
})
|
||||
})
|
||||
|
||||
await manager.handleRecursiveCall(controller, {}, context)
|
||||
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('[MCP Prompt] No fullstream found'),
|
||||
expect.any(Object)
|
||||
)
|
||||
|
||||
warnSpy.mockRestore()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,547 @@
|
||||
import type { TextStreamPart, ToolSet } from 'ai'
|
||||
import { simulateReadableStream } from 'ai'
|
||||
import { convertReadableStreamToArray } from 'ai/test'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { createMockContext, createMockStreamParams, createMockTool, createMockToolSet } from '../../../../../__tests__'
|
||||
import { createPromptToolUsePlugin, DEFAULT_SYSTEM_PROMPT } from '../promptToolUsePlugin'
|
||||
|
||||
describe('promptToolUsePlugin', () => {
|
||||
describe('Factory Function', () => {
|
||||
it('should return AiPlugin with correct name', () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
|
||||
expect(plugin.name).toBe('built-in:prompt-tool-use')
|
||||
expect(plugin.transformParams).toBeDefined()
|
||||
expect(plugin.transformStream).toBeDefined()
|
||||
})
|
||||
|
||||
it('should accept empty configuration', () => {
|
||||
const plugin = createPromptToolUsePlugin({})
|
||||
|
||||
expect(plugin).toBeDefined()
|
||||
expect(plugin.name).toBe('built-in:prompt-tool-use')
|
||||
})
|
||||
|
||||
it('should accept custom buildSystemPrompt', () => {
|
||||
const customBuildSystemPrompt = vi.fn((userSystemPrompt: string) => userSystemPrompt)
|
||||
|
||||
const plugin = createPromptToolUsePlugin({
|
||||
buildSystemPrompt: customBuildSystemPrompt
|
||||
})
|
||||
|
||||
expect(plugin).toBeDefined()
|
||||
})
|
||||
|
||||
it('should accept custom parseToolUse', () => {
|
||||
const customParseToolUse = vi.fn(() => ({ results: [], content: '' }))
|
||||
|
||||
const plugin = createPromptToolUsePlugin({
|
||||
parseToolUse: customParseToolUse
|
||||
})
|
||||
|
||||
expect(plugin).toBeDefined()
|
||||
})
|
||||
|
||||
it('should accept enabled flag', () => {
|
||||
const pluginDisabled = createPromptToolUsePlugin({ enabled: false })
|
||||
const pluginEnabled = createPromptToolUsePlugin({ enabled: true })
|
||||
|
||||
expect(pluginDisabled).toBeDefined()
|
||||
expect(pluginEnabled).toBeDefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('transformParams', () => {
|
||||
it('should separate provider and prompt tools', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
tools: createMockToolSet({
|
||||
provider_tool: 'provider',
|
||||
prompt_tool: 'function'
|
||||
})
|
||||
})
|
||||
|
||||
const result = await Promise.resolve(plugin.transformParams!(params, context))
|
||||
|
||||
// Provider tools should remain in tools
|
||||
expect(result.tools).toBeDefined()
|
||||
expect(result.tools).toHaveProperty('provider_tool')
|
||||
expect(result.tools).not.toHaveProperty('prompt_tool')
|
||||
|
||||
// Prompt tools should be moved to context.mcpTools
|
||||
expect(context.mcpTools).toBeDefined()
|
||||
expect(context.mcpTools).toHaveProperty('prompt_tool')
|
||||
expect(context.mcpTools).not.toHaveProperty('provider_tool')
|
||||
})
|
||||
|
||||
it('should handle only provider tools', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
tools: createMockToolSet({
|
||||
provider_tool1: 'provider',
|
||||
provider_tool2: 'provider'
|
||||
})
|
||||
})
|
||||
|
||||
const result = await Promise.resolve(plugin.transformParams!(params, context))
|
||||
|
||||
expect(result.tools).toEqual(params.tools)
|
||||
expect(context.mcpTools).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should handle only prompt tools', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
tools: createMockToolSet({
|
||||
prompt_tool1: 'function',
|
||||
prompt_tool2: 'function'
|
||||
})
|
||||
})
|
||||
|
||||
const result = await Promise.resolve(plugin.transformParams!(params, context))
|
||||
|
||||
expect(result.tools).toBeUndefined()
|
||||
expect(context.mcpTools).toEqual(params.tools)
|
||||
})
|
||||
|
||||
it('should build system prompt for prompt tools', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
system: 'Original system prompt',
|
||||
tools: {
|
||||
test_tool: createMockTool('test_tool', 'Test tool description')
|
||||
}
|
||||
})
|
||||
|
||||
const result = await Promise.resolve(plugin.transformParams!(params, context))
|
||||
|
||||
expect(result.system).toBeDefined()
|
||||
expect(typeof result.system).toBe('string')
|
||||
expect(result.system).toContain('In this environment you have access to a set of tools')
|
||||
expect(result.system).toContain('test_tool')
|
||||
expect(result.system).toContain('Test tool description')
|
||||
expect(result.system).toContain('Original system prompt')
|
||||
})
|
||||
|
||||
it('should handle empty user system prompt', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
tools: {
|
||||
test_tool: createMockTool('test_tool')
|
||||
}
|
||||
})
|
||||
|
||||
const result = await Promise.resolve(plugin.transformParams!(params, context))
|
||||
|
||||
expect(result.system).toBeDefined()
|
||||
expect(result.system).toContain('In this environment you have access to a set of tools')
|
||||
})
|
||||
|
||||
it('should skip system prompt when disabled', async () => {
|
||||
const plugin = createPromptToolUsePlugin({ enabled: false })
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
system: 'Original',
|
||||
tools: {
|
||||
test_tool: createMockTool('test_tool')
|
||||
}
|
||||
})
|
||||
|
||||
const result = await Promise.resolve(plugin.transformParams!(params, context))
|
||||
|
||||
expect(result).toEqual(params)
|
||||
expect(context.mcpTools).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should skip when no tools provided', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
system: 'Original'
|
||||
})
|
||||
|
||||
const result = await Promise.resolve(plugin.transformParams!(params, context))
|
||||
|
||||
expect(result).toEqual(params)
|
||||
})
|
||||
|
||||
it('should skip when tools is not an object', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
system: 'Original',
|
||||
tools: 'invalid' as any
|
||||
})
|
||||
|
||||
const result = await Promise.resolve(plugin.transformParams!(params, context))
|
||||
|
||||
expect(result).toEqual(params)
|
||||
})
|
||||
|
||||
it('should use custom buildSystemPrompt when provided', async () => {
|
||||
const customBuildSystemPrompt = vi.fn((userSystemPrompt: string, tools: ToolSet) => {
|
||||
return `Custom prompt with ${Object.keys(tools).length} tools and user prompt: ${userSystemPrompt}`
|
||||
})
|
||||
|
||||
const plugin = createPromptToolUsePlugin({
|
||||
buildSystemPrompt: customBuildSystemPrompt
|
||||
})
|
||||
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
system: 'User prompt',
|
||||
tools: {
|
||||
tool1: createMockTool('tool1')
|
||||
}
|
||||
})
|
||||
|
||||
const result = await Promise.resolve(plugin.transformParams!(params, context))
|
||||
|
||||
expect(customBuildSystemPrompt).toHaveBeenCalled()
|
||||
expect(result.system).toBe('Custom prompt with 1 tools and user prompt: User prompt')
|
||||
})
|
||||
|
||||
it('should use custom createSystemMessage when provided', async () => {
|
||||
const customCreateSystemMessage = vi.fn(() => {
|
||||
return `Modified system message`
|
||||
})
|
||||
|
||||
const plugin = createPromptToolUsePlugin({
|
||||
createSystemMessage: customCreateSystemMessage
|
||||
})
|
||||
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
system: 'Original',
|
||||
tools: {
|
||||
test: createMockTool('test')
|
||||
}
|
||||
})
|
||||
|
||||
const result = await Promise.resolve(plugin.transformParams!(params, context))
|
||||
|
||||
expect(customCreateSystemMessage).toHaveBeenCalled()
|
||||
expect(result.system).toContain('Modified')
|
||||
})
|
||||
|
||||
it('should save originalParams to context', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
system: 'Original',
|
||||
tools: {
|
||||
test: createMockTool('test')
|
||||
}
|
||||
})
|
||||
|
||||
await Promise.resolve(plugin.transformParams!(params, context))
|
||||
|
||||
expect(context.originalParams).toBeDefined()
|
||||
expect(context.originalParams.system).toBeDefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('transformStream', () => {
|
||||
it('should return identity transform when disabled', async () => {
|
||||
const plugin = createPromptToolUsePlugin({ enabled: false })
|
||||
const context = createMockContext()
|
||||
|
||||
const inputChunks: Array<{ type: 'text-delta'; text: string }> = [
|
||||
{ type: 'text-delta', text: 'Hello' },
|
||||
{ type: 'text-delta', text: ' World' }
|
||||
]
|
||||
|
||||
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
|
||||
chunks: inputChunks as TextStreamPart<ToolSet>[],
|
||||
initialDelayInMs: 0,
|
||||
chunkDelayInMs: 0
|
||||
})
|
||||
|
||||
const transform = plugin.transformStream!(createMockStreamParams(), context)()
|
||||
const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
|
||||
|
||||
expect(result).toEqual(inputChunks)
|
||||
})
|
||||
|
||||
it('should return identity transform when no mcpTools in context', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
// Don't set context.mcpTools
|
||||
|
||||
const inputChunks: Array<{ type: 'text-delta'; text: string }> = [
|
||||
{ type: 'text-delta', text: 'Hello' },
|
||||
{ type: 'text-delta', text: ' World' }
|
||||
]
|
||||
|
||||
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
|
||||
chunks: inputChunks as TextStreamPart<ToolSet>[],
|
||||
initialDelayInMs: 0,
|
||||
chunkDelayInMs: 0
|
||||
})
|
||||
|
||||
const transform = plugin.transformStream!(createMockStreamParams(), context)()
|
||||
const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
|
||||
|
||||
expect(result).toEqual(inputChunks)
|
||||
})
|
||||
|
||||
it('should initialize accumulatedUsage in context', () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
context.mcpTools = {
|
||||
test: createMockTool('test')
|
||||
}
|
||||
|
||||
plugin.transformStream!(createMockStreamParams(), context)()
|
||||
|
||||
expect(context.accumulatedUsage).toBeDefined()
|
||||
expect(context.accumulatedUsage).toEqual({
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
totalTokens: 0,
|
||||
reasoningTokens: 0,
|
||||
cachedInputTokens: 0
|
||||
})
|
||||
})
|
||||
|
||||
it('should filter tool tags from text-delta chunks', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
context.mcpTools = {
|
||||
test: createMockTool('test')
|
||||
}
|
||||
|
||||
const inputChunks = [
|
||||
{ type: 'text-start' as const },
|
||||
{ type: 'text-delta' as const, text: 'Before ' },
|
||||
{ type: 'text-delta' as const, text: '<tool_use>' },
|
||||
{ type: 'text-delta' as const, text: '<name>test</name>' },
|
||||
{ type: 'text-delta' as const, text: '<arguments>{}</arguments>' },
|
||||
{ type: 'text-delta' as const, text: '</tool_use>' },
|
||||
{ type: 'text-delta' as const, text: ' After' },
|
||||
{ type: 'text-end' as const }
|
||||
]
|
||||
|
||||
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
|
||||
chunks: inputChunks as TextStreamPart<ToolSet>[],
|
||||
initialDelayInMs: 0,
|
||||
chunkDelayInMs: 0
|
||||
})
|
||||
|
||||
const transform = plugin.transformStream!(createMockStreamParams(), context)()
|
||||
const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
|
||||
|
||||
// Extract text from text-delta chunks
|
||||
const textChunks = result.filter((chunk) => chunk.type === 'text-delta')
|
||||
const fullText = textChunks.map((chunk) => 'text' in chunk && chunk.text).join('')
|
||||
|
||||
// Tool tags should be filtered out
|
||||
expect(fullText).not.toContain('<tool_use>')
|
||||
expect(fullText).not.toContain('</tool_use>')
|
||||
expect(fullText).toContain('Before')
|
||||
expect(fullText).toContain('After')
|
||||
})
|
||||
|
||||
it('should hold text-start until non-tag content appears', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
context.mcpTools = {
|
||||
test: createMockTool('test')
|
||||
}
|
||||
|
||||
// Only tool tags, no actual content
|
||||
const inputChunks = [
|
||||
{ type: 'text-start' as const },
|
||||
{ type: 'text-delta' as const, text: '<tool_use>' },
|
||||
{ type: 'text-delta' as const, text: '<name>test</name>' },
|
||||
{ type: 'text-delta' as const, text: '<arguments>{}</arguments>' },
|
||||
{ type: 'text-delta' as const, text: '</tool_use>' },
|
||||
{ type: 'text-end' as const }
|
||||
]
|
||||
|
||||
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
|
||||
chunks: inputChunks as TextStreamPart<ToolSet>[],
|
||||
initialDelayInMs: 0,
|
||||
chunkDelayInMs: 0
|
||||
})
|
||||
|
||||
const transform = plugin.transformStream!(createMockStreamParams(), context)()
|
||||
const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
|
||||
|
||||
// Should not have text-start or text-end since all content was tool tags
|
||||
expect(result.some((chunk) => chunk.type === 'text-start')).toBe(false)
|
||||
expect(result.some((chunk) => chunk.type === 'text-end')).toBe(false)
|
||||
})
|
||||
|
||||
it('should send text-start when non-tag content appears', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
context.mcpTools = {
|
||||
test: createMockTool('test')
|
||||
}
|
||||
|
||||
const inputChunks = [
|
||||
{ type: 'text-start' as const },
|
||||
{ type: 'text-delta' as const, text: 'Actual content' },
|
||||
{ type: 'text-end' as const }
|
||||
]
|
||||
|
||||
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
|
||||
chunks: inputChunks as TextStreamPart<ToolSet>[],
|
||||
initialDelayInMs: 0,
|
||||
chunkDelayInMs: 0
|
||||
})
|
||||
|
||||
const transform = plugin.transformStream!(createMockStreamParams(), context)()
|
||||
const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
|
||||
|
||||
// Should have text-start, text-delta, and text-end
|
||||
expect(result.some((chunk) => chunk.type === 'text-start')).toBe(true)
|
||||
expect(result.some((chunk) => chunk.type === 'text-delta')).toBe(true)
|
||||
expect(result.some((chunk) => chunk.type === 'text-end')).toBe(true)
|
||||
})
|
||||
|
||||
it('should pass through non-text events', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
context.mcpTools = {
|
||||
test: createMockTool('test')
|
||||
}
|
||||
|
||||
const stepStartEvent = { type: 'start-step' as const, request: {}, warnings: [] }
|
||||
|
||||
const inputChunks = [stepStartEvent]
|
||||
|
||||
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
|
||||
chunks: inputChunks as TextStreamPart<ToolSet>[],
|
||||
initialDelayInMs: 0,
|
||||
chunkDelayInMs: 0
|
||||
})
|
||||
|
||||
const transform = plugin.transformStream!(createMockStreamParams(), context)()
|
||||
const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
|
||||
|
||||
expect(result[0]).toEqual(stepStartEvent)
|
||||
})
|
||||
|
||||
it('should accumulate usage from finish-step events', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
context.mcpTools = {
|
||||
test: createMockTool('test')
|
||||
}
|
||||
|
||||
const inputChunks = [
|
||||
{
|
||||
type: 'finish-step' as const,
|
||||
finishReason: 'stop' as const,
|
||||
usage: {
|
||||
inputTokens: 10,
|
||||
outputTokens: 20,
|
||||
totalTokens: 30
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
|
||||
chunks: inputChunks as TextStreamPart<ToolSet>[],
|
||||
initialDelayInMs: 0,
|
||||
chunkDelayInMs: 0
|
||||
})
|
||||
|
||||
const transform = plugin.transformStream!(createMockStreamParams(), context)()
|
||||
await convertReadableStreamToArray(inputStream.pipeThrough(transform))
|
||||
|
||||
// Verify usage was accumulated
|
||||
expect(context.accumulatedUsage).toBeDefined()
|
||||
expect(context.accumulatedUsage!.inputTokens).toBe(10)
|
||||
expect(context.accumulatedUsage!.outputTokens).toBe(20)
|
||||
expect(context.accumulatedUsage!.totalTokens).toBe(30)
|
||||
})
|
||||
|
||||
it('should include accumulated usage in finish event', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
context.mcpTools = {
|
||||
test: createMockTool('test')
|
||||
}
|
||||
|
||||
// Pre-populate accumulated usage
|
||||
context.accumulatedUsage = {
|
||||
inputTokens: 5,
|
||||
outputTokens: 10,
|
||||
totalTokens: 15,
|
||||
reasoningTokens: 0,
|
||||
cachedInputTokens: 0
|
||||
}
|
||||
|
||||
const inputChunks = [
|
||||
{
|
||||
type: 'finish' as const,
|
||||
finishReason: 'stop' as const,
|
||||
usage: {
|
||||
inputTokens: 100,
|
||||
outputTokens: 200,
|
||||
totalTokens: 300
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
|
||||
chunks: inputChunks as unknown as TextStreamPart<ToolSet>[],
|
||||
initialDelayInMs: 0,
|
||||
chunkDelayInMs: 0
|
||||
})
|
||||
|
||||
const transform = plugin.transformStream!(createMockStreamParams(), context)()
|
||||
const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
|
||||
|
||||
const finishEvent = result.find((chunk) => chunk.type === 'finish')
|
||||
expect(finishEvent).toBeDefined()
|
||||
if (finishEvent && 'totalUsage' in finishEvent) {
|
||||
expect(finishEvent.totalUsage).toEqual(context.accumulatedUsage)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('Type Safety', () => {
|
||||
it('should have correct generic parameters for StreamTextParams and StreamTextResult', () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
|
||||
// Type assertion to verify the plugin has the correct type
|
||||
type PluginType = typeof plugin
|
||||
const typeTest: PluginType = plugin
|
||||
|
||||
expect(typeTest.name).toBe('built-in:prompt-tool-use')
|
||||
})
|
||||
})
|
||||
|
||||
describe('DEFAULT_SYSTEM_PROMPT', () => {
|
||||
it('should contain required sections', () => {
|
||||
expect(DEFAULT_SYSTEM_PROMPT).toContain('Tool Use Formatting')
|
||||
expect(DEFAULT_SYSTEM_PROMPT).toContain('Tool Use Examples')
|
||||
expect(DEFAULT_SYSTEM_PROMPT).toContain('Tool Use Available Tools')
|
||||
expect(DEFAULT_SYSTEM_PROMPT).toContain('Tool Use Rules')
|
||||
expect(DEFAULT_SYSTEM_PROMPT).toContain('Response rules')
|
||||
})
|
||||
|
||||
it('should have placeholders for dynamic content', () => {
|
||||
expect(DEFAULT_SYSTEM_PROMPT).toContain('{{ TOOL_USE_EXAMPLES }}')
|
||||
expect(DEFAULT_SYSTEM_PROMPT).toContain('{{ AVAILABLE_TOOLS }}')
|
||||
expect(DEFAULT_SYSTEM_PROMPT).toContain('{{ USER_SYSTEM_PROMPT }}')
|
||||
})
|
||||
|
||||
it('should contain XML tag examples', () => {
|
||||
expect(DEFAULT_SYSTEM_PROMPT).toContain('<tool_use>')
|
||||
expect(DEFAULT_SYSTEM_PROMPT).toContain('</tool_use>')
|
||||
expect(DEFAULT_SYSTEM_PROMPT).toContain('<name>')
|
||||
expect(DEFAULT_SYSTEM_PROMPT).toContain('<arguments>')
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -33,8 +33,8 @@ export function createContext<T extends ProviderId, TParams = unknown, TResult =
|
||||
startTime: Date.now(),
|
||||
requestId: `${providerId}-${typeof model === 'string' ? model : model?.modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
|
||||
isRecursiveCall: false,
|
||||
recursiveDepth: 0, // 初始化递归深度为 0
|
||||
maxRecursiveDepth: 10, // 默认最大递归深度为 10
|
||||
recursiveDepth: 0, // 初始化递归深度为 0
|
||||
maxRecursiveDepth: 10, // 默认最大递归深度为 10
|
||||
extensions: new Map(),
|
||||
middlewares: [],
|
||||
// 占位递归调用函数,实际使用时会被 PluginEngine 替换
|
||||
|
||||
@ -72,10 +72,7 @@ export class PluginManager<TParams = unknown, TResult = unknown> {
|
||||
* 执行 transformParams 钩子 - 链式参数转换
|
||||
* 每个插件返回 Partial<TParams>,逐步合并到原始参数
|
||||
*/
|
||||
async executeTransformParams(
|
||||
initialValue: TParams,
|
||||
context: AiRequestContext<TParams, TResult>
|
||||
): Promise<TParams> {
|
||||
async executeTransformParams(initialValue: TParams, context: AiRequestContext<TParams, TResult>): Promise<TParams> {
|
||||
let result = initialValue
|
||||
|
||||
for (const plugin of this.plugins) {
|
||||
@ -93,10 +90,7 @@ export class PluginManager<TParams = unknown, TResult = unknown> {
|
||||
* 执行 transformResult 钩子 - 链式结果转换
|
||||
* 每个插件接收并返回完整的 TResult
|
||||
*/
|
||||
async executeTransformResult(
|
||||
initialValue: TResult,
|
||||
context: AiRequestContext<TParams, TResult>
|
||||
): Promise<TResult> {
|
||||
async executeTransformResult(initialValue: TResult, context: AiRequestContext<TParams, TResult>): Promise<TResult> {
|
||||
let result = initialValue
|
||||
|
||||
for (const plugin of this.plugins) {
|
||||
|
||||
525
packages/aiCore/src/core/providers/__tests__/HubProvider.test.ts
Normal file
525
packages/aiCore/src/core/providers/__tests__/HubProvider.test.ts
Normal file
@ -0,0 +1,525 @@
|
||||
/**
|
||||
* HubProvider Comprehensive Tests
|
||||
* Tests hub provider routing, model resolution, and error handling
|
||||
* Covers multi-provider routing with namespaced model IDs
|
||||
*/
|
||||
|
||||
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3, ProviderV3 } from '@ai-sdk/provider'
|
||||
import { customProvider, wrapProvider } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { createMockEmbeddingModel, createMockImageModel, createMockLanguageModel } from '../../../__tests__'
|
||||
import { createHubProvider, type HubProviderConfig, HubProviderError } from '../HubProvider'
|
||||
import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../RegistryManagement'
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('../RegistryManagement', () => ({
|
||||
globalRegistryManagement: {
|
||||
getProvider: vi.fn()
|
||||
},
|
||||
DEFAULT_SEPARATOR: '|'
|
||||
}))
|
||||
|
||||
vi.mock('ai', () => ({
|
||||
customProvider: vi.fn((config) => config.fallbackProvider),
|
||||
wrapProvider: vi.fn((config) => config.provider)
|
||||
}))
|
||||
|
||||
describe('HubProvider', () => {
|
||||
let mockOpenAIProvider: ProviderV3
|
||||
let mockAnthropicProvider: ProviderV3
|
||||
let mockLanguageModel: LanguageModelV3
|
||||
let mockEmbeddingModel: EmbeddingModelV3
|
||||
let mockImageModel: ImageModelV3
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
// Create mock models using global utilities
|
||||
mockLanguageModel = createMockLanguageModel({
|
||||
provider: 'test',
|
||||
modelId: 'test-model'
|
||||
})
|
||||
|
||||
mockEmbeddingModel = createMockEmbeddingModel({
|
||||
provider: 'test',
|
||||
modelId: 'test-embedding'
|
||||
})
|
||||
|
||||
mockImageModel = createMockImageModel({
|
||||
provider: 'test',
|
||||
modelId: 'test-image'
|
||||
})
|
||||
|
||||
// Create mock providers
|
||||
mockOpenAIProvider = {
|
||||
specificationVersion: 'v3',
|
||||
languageModel: vi.fn().mockReturnValue(mockLanguageModel),
|
||||
embeddingModel: vi.fn().mockReturnValue(mockEmbeddingModel),
|
||||
imageModel: vi.fn().mockReturnValue(mockImageModel)
|
||||
} as ProviderV3
|
||||
|
||||
mockAnthropicProvider = {
|
||||
specificationVersion: 'v3',
|
||||
languageModel: vi.fn().mockReturnValue(mockLanguageModel),
|
||||
embeddingModel: vi.fn().mockReturnValue(mockEmbeddingModel),
|
||||
imageModel: vi.fn().mockReturnValue(mockImageModel)
|
||||
} as ProviderV3
|
||||
|
||||
// Setup default mock implementation
|
||||
vi.mocked(globalRegistryManagement.getProvider).mockImplementation((id) => {
|
||||
if (id === 'openai') return mockOpenAIProvider
|
||||
if (id === 'anthropic') return mockAnthropicProvider
|
||||
return undefined
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider Creation', () => {
|
||||
it('should create hub provider with basic config', () => {
|
||||
const config: HubProviderConfig = {
|
||||
hubId: 'test-hub'
|
||||
}
|
||||
|
||||
const provider = createHubProvider(config)
|
||||
|
||||
expect(provider).toBeDefined()
|
||||
expect(customProvider).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should create provider with debug flag', () => {
|
||||
const config: HubProviderConfig = {
|
||||
hubId: 'test-hub',
|
||||
debug: true
|
||||
}
|
||||
|
||||
const provider = createHubProvider(config)
|
||||
|
||||
expect(provider).toBeDefined()
|
||||
})
|
||||
|
||||
it('should return ProviderV3 specification', () => {
|
||||
const config: HubProviderConfig = {
|
||||
hubId: 'aihubmix'
|
||||
}
|
||||
|
||||
const provider = createHubProvider(config)
|
||||
|
||||
expect(provider).toHaveProperty('specificationVersion', 'v3')
|
||||
expect(provider).toHaveProperty('languageModel')
|
||||
expect(provider).toHaveProperty('embeddingModel')
|
||||
expect(provider).toHaveProperty('imageModel')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Model ID Parsing', () => {
|
||||
it('should parse valid hub model ID format', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const modelId = `openai${DEFAULT_SEPARATOR}gpt-4`
|
||||
|
||||
const result = provider.languageModel(modelId)
|
||||
|
||||
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai')
|
||||
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4')
|
||||
expect(result).toBe(mockLanguageModel)
|
||||
})
|
||||
|
||||
it('should throw error for invalid model ID format', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const invalidId = 'invalid-id-without-separator'
|
||||
|
||||
expect(() => provider.languageModel(invalidId)).toThrow(HubProviderError)
|
||||
})
|
||||
|
||||
it('should throw error for model ID with multiple separators', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const multiSeparatorId = `provider${DEFAULT_SEPARATOR}extra${DEFAULT_SEPARATOR}model`
|
||||
|
||||
expect(() => provider.languageModel(multiSeparatorId)).toThrow(HubProviderError)
|
||||
})
|
||||
|
||||
it('should throw error for empty model ID', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
expect(() => provider.languageModel('')).toThrow(HubProviderError)
|
||||
})
|
||||
|
||||
it('should throw error for model ID with only separator', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
expect(() => provider.languageModel(DEFAULT_SEPARATOR)).toThrow(HubProviderError)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Language Model Resolution', () => {
|
||||
it('should route to correct provider for language model', () => {
|
||||
const config: HubProviderConfig = { hubId: 'aihubmix' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const result = provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
|
||||
|
||||
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai')
|
||||
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4')
|
||||
expect(result).toBe(mockLanguageModel)
|
||||
})
|
||||
|
||||
it('should route different providers correctly', () => {
|
||||
const config: HubProviderConfig = { hubId: 'aihubmix' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
|
||||
provider.languageModel(`anthropic${DEFAULT_SEPARATOR}claude-3`)
|
||||
|
||||
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai')
|
||||
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('anthropic')
|
||||
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4')
|
||||
expect(mockAnthropicProvider.languageModel).toHaveBeenCalledWith('claude-3')
|
||||
})
|
||||
|
||||
it('should wrap provider with wrapProvider', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
|
||||
|
||||
expect(wrapProvider).toHaveBeenCalledWith({
|
||||
provider: mockOpenAIProvider,
|
||||
languageModelMiddleware: []
|
||||
})
|
||||
})
|
||||
|
||||
it('should throw HubProviderError if provider not initialized', () => {
|
||||
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(undefined)
|
||||
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
expect(() => provider.languageModel(`uninitialized${DEFAULT_SEPARATOR}model`)).toThrow(HubProviderError)
|
||||
expect(() => provider.languageModel(`uninitialized${DEFAULT_SEPARATOR}model`)).toThrow(/not initialized/)
|
||||
})
|
||||
|
||||
it('should include provider ID in error message', () => {
|
||||
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(undefined)
|
||||
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
try {
|
||||
provider.languageModel(`missing${DEFAULT_SEPARATOR}model`)
|
||||
expect.fail('Should have thrown HubProviderError')
|
||||
} catch (error) {
|
||||
expect(error).toBeInstanceOf(HubProviderError)
|
||||
const hubError = error as HubProviderError
|
||||
expect(hubError.providerId).toBe('missing')
|
||||
expect(hubError.hubId).toBe('test-hub')
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('Embedding Model Resolution', () => {
|
||||
it('should route to correct provider for embedding model', () => {
|
||||
const config: HubProviderConfig = { hubId: 'aihubmix' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const result = provider.embeddingModel(`openai${DEFAULT_SEPARATOR}text-embedding-3-small`)
|
||||
|
||||
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai')
|
||||
expect(mockOpenAIProvider.embeddingModel).toHaveBeenCalledWith('text-embedding-3-small')
|
||||
expect(result).toBe(mockEmbeddingModel)
|
||||
})
|
||||
|
||||
it('should handle different embedding providers', () => {
|
||||
const config: HubProviderConfig = { hubId: 'aihubmix' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada-002`)
|
||||
provider.embeddingModel(`anthropic${DEFAULT_SEPARATOR}embed-v1`)
|
||||
|
||||
expect(mockOpenAIProvider.embeddingModel).toHaveBeenCalledWith('ada-002')
|
||||
expect(mockAnthropicProvider.embeddingModel).toHaveBeenCalledWith('embed-v1')
|
||||
})
|
||||
|
||||
it('should throw error for uninitialized embedding provider', () => {
|
||||
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(undefined)
|
||||
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
expect(() => provider.embeddingModel(`missing${DEFAULT_SEPARATOR}embed`)).toThrow(HubProviderError)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Image Model Resolution', () => {
|
||||
it('should route to correct provider for image model', () => {
|
||||
const config: HubProviderConfig = { hubId: 'aihubmix' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const result = provider.imageModel(`openai${DEFAULT_SEPARATOR}dall-e-3`)
|
||||
|
||||
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai')
|
||||
expect(mockOpenAIProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
|
||||
expect(result).toBe(mockImageModel)
|
||||
})
|
||||
|
||||
it('should handle different image providers', () => {
|
||||
const config: HubProviderConfig = { hubId: 'aihubmix' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
provider.imageModel(`openai${DEFAULT_SEPARATOR}dall-e-3`)
|
||||
provider.imageModel(`anthropic${DEFAULT_SEPARATOR}image-gen`)
|
||||
|
||||
expect(mockOpenAIProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
|
||||
expect(mockAnthropicProvider.imageModel).toHaveBeenCalledWith('image-gen')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Special Model Types', () => {
|
||||
it('should support transcription models', () => {
|
||||
const mockTranscriptionModel = {
|
||||
specificationVersion: 'v3',
|
||||
doTranscribe: vi.fn()
|
||||
}
|
||||
|
||||
const providerWithTranscription = {
|
||||
...mockOpenAIProvider,
|
||||
transcriptionModel: vi.fn().mockReturnValue(mockTranscriptionModel)
|
||||
} as ProviderV3
|
||||
|
||||
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(providerWithTranscription)
|
||||
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const result = provider.transcriptionModel!(`openai${DEFAULT_SEPARATOR}whisper-1`)
|
||||
|
||||
expect(providerWithTranscription.transcriptionModel).toHaveBeenCalledWith('whisper-1')
|
||||
expect(result).toBe(mockTranscriptionModel)
|
||||
})
|
||||
|
||||
it('should throw error if provider does not support transcription', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
expect(() => provider.transcriptionModel!(`openai${DEFAULT_SEPARATOR}whisper`)).toThrow(HubProviderError)
|
||||
expect(() => provider.transcriptionModel!(`openai${DEFAULT_SEPARATOR}whisper`)).toThrow(
|
||||
/does not support transcription/
|
||||
)
|
||||
})
|
||||
|
||||
it('should support speech models', () => {
|
||||
const mockSpeechModel = {
|
||||
specificationVersion: 'v3',
|
||||
doGenerate: vi.fn()
|
||||
}
|
||||
|
||||
const providerWithSpeech = {
|
||||
...mockOpenAIProvider,
|
||||
speechModel: vi.fn().mockReturnValue(mockSpeechModel)
|
||||
} as ProviderV3
|
||||
|
||||
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(providerWithSpeech)
|
||||
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const result = provider.speechModel!(`openai${DEFAULT_SEPARATOR}tts-1`)
|
||||
|
||||
expect(providerWithSpeech.speechModel).toHaveBeenCalledWith('tts-1')
|
||||
expect(result).toBe(mockSpeechModel)
|
||||
})
|
||||
|
||||
it('should throw error if provider does not support speech', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
expect(() => provider.speechModel!(`openai${DEFAULT_SEPARATOR}tts-1`)).toThrow(HubProviderError)
|
||||
expect(() => provider.speechModel!(`openai${DEFAULT_SEPARATOR}tts-1`)).toThrow(/does not support speech/)
|
||||
})
|
||||
|
||||
it('should support reranking models', () => {
|
||||
const mockRerankingModel = {
|
||||
specificationVersion: 'v3',
|
||||
doRerank: vi.fn()
|
||||
}
|
||||
|
||||
const providerWithReranking = {
|
||||
...mockOpenAIProvider,
|
||||
rerankingModel: vi.fn().mockReturnValue(mockRerankingModel)
|
||||
} as ProviderV3
|
||||
|
||||
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(providerWithReranking)
|
||||
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const result = provider.rerankingModel!(`openai${DEFAULT_SEPARATOR}rerank-v1`)
|
||||
|
||||
expect(providerWithReranking.rerankingModel).toHaveBeenCalledWith('rerank-v1')
|
||||
expect(result).toBe(mockRerankingModel)
|
||||
})
|
||||
|
||||
it('should throw error if provider does not support reranking', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
expect(() => provider.rerankingModel!(`openai${DEFAULT_SEPARATOR}rerank`)).toThrow(HubProviderError)
|
||||
expect(() => provider.rerankingModel!(`openai${DEFAULT_SEPARATOR}rerank`)).toThrow(/does not support reranking/)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Error Handling', () => {
|
||||
it('should create HubProviderError with all properties', () => {
|
||||
const originalError = new Error('Original error')
|
||||
const error = new HubProviderError('Test message', 'test-hub', 'test-provider', originalError)
|
||||
|
||||
expect(error.message).toBe('Test message')
|
||||
expect(error.hubId).toBe('test-hub')
|
||||
expect(error.providerId).toBe('test-provider')
|
||||
expect(error.originalError).toBe(originalError)
|
||||
expect(error.name).toBe('HubProviderError')
|
||||
})
|
||||
|
||||
it('should create HubProviderError without optional parameters', () => {
|
||||
const error = new HubProviderError('Test message', 'test-hub')
|
||||
|
||||
expect(error.message).toBe('Test message')
|
||||
expect(error.hubId).toBe('test-hub')
|
||||
expect(error.providerId).toBeUndefined()
|
||||
expect(error.originalError).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should wrap provider errors in HubProviderError', () => {
|
||||
const providerError = new Error('Provider failed')
|
||||
vi.mocked(globalRegistryManagement.getProvider).mockImplementation(() => {
|
||||
throw providerError
|
||||
})
|
||||
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
try {
|
||||
provider.languageModel(`failing${DEFAULT_SEPARATOR}model`)
|
||||
expect.fail('Should have thrown HubProviderError')
|
||||
} catch (error) {
|
||||
expect(error).toBeInstanceOf(HubProviderError)
|
||||
const hubError = error as HubProviderError
|
||||
expect(hubError.originalError).toBe(providerError)
|
||||
expect(hubError.message).toContain('Failed to get provider')
|
||||
}
|
||||
})
|
||||
|
||||
it('should handle null provider from registry', () => {
|
||||
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(null as any)
|
||||
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
expect(() => provider.languageModel(`null-provider${DEFAULT_SEPARATOR}model`)).toThrow(HubProviderError)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Multi-Provider Scenarios', () => {
|
||||
it('should handle sequential calls to different providers', () => {
|
||||
const config: HubProviderConfig = { hubId: 'aihubmix' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
|
||||
provider.languageModel(`anthropic${DEFAULT_SEPARATOR}claude-3`)
|
||||
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-3.5`)
|
||||
|
||||
expect(globalRegistryManagement.getProvider).toHaveBeenCalledTimes(3)
|
||||
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledTimes(2)
|
||||
expect(mockAnthropicProvider.languageModel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should handle mixed model types from same provider', () => {
|
||||
const config: HubProviderConfig = { hubId: 'aihubmix' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
|
||||
provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada-002`)
|
||||
provider.imageModel(`openai${DEFAULT_SEPARATOR}dall-e-3`)
|
||||
|
||||
expect(globalRegistryManagement.getProvider).toHaveBeenCalledTimes(3)
|
||||
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4')
|
||||
expect(mockOpenAIProvider.embeddingModel).toHaveBeenCalledWith('ada-002')
|
||||
expect(mockOpenAIProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
|
||||
})
|
||||
|
||||
it('should cache provider lookups', () => {
|
||||
const config: HubProviderConfig = { hubId: 'aihubmix' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
|
||||
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-3.5`)
|
||||
|
||||
// Should call getProvider twice (once per model call)
|
||||
expect(globalRegistryManagement.getProvider).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider Wrapping', () => {
|
||||
it('should wrap all providers with empty middleware', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
|
||||
|
||||
expect(wrapProvider).toHaveBeenCalledWith({
|
||||
provider: mockOpenAIProvider,
|
||||
languageModelMiddleware: []
|
||||
})
|
||||
})
|
||||
|
||||
it('should wrap providers for all model types', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
|
||||
provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada`)
|
||||
provider.imageModel(`openai${DEFAULT_SEPARATOR}dalle`)
|
||||
|
||||
expect(wrapProvider).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Type Safety', () => {
|
||||
it('should return properly typed LanguageModelV3', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const result = provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
|
||||
|
||||
expect(result.specificationVersion).toBe('v3')
|
||||
expect(result).toHaveProperty('doGenerate')
|
||||
expect(result).toHaveProperty('doStream')
|
||||
})
|
||||
|
||||
it('should return properly typed EmbeddingModelV3', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const result = provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada`)
|
||||
|
||||
expect(result.specificationVersion).toBe('v3')
|
||||
expect(result).toHaveProperty('doEmbed')
|
||||
})
|
||||
|
||||
it('should return properly typed ImageModelV3', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const result = provider.imageModel(`openai${DEFAULT_SEPARATOR}dalle`)
|
||||
|
||||
expect(result.specificationVersion).toBe('v3')
|
||||
expect(result).toHaveProperty('doGenerate')
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,561 @@
|
||||
/**
|
||||
* RegistryManagement Comprehensive Tests
|
||||
* Tests provider registry management, model resolution, and alias handling
|
||||
* Covers registration, retrieval, and cleanup operations
|
||||
*/
|
||||
|
||||
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3, ProviderV3 } from '@ai-sdk/provider'
|
||||
import { createProviderRegistry } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { createMockEmbeddingModel, createMockImageModel, createMockLanguageModel } from '../../../__tests__'
|
||||
import { DEFAULT_SEPARATOR, RegistryManagement } from '../RegistryManagement'
|
||||
|
||||
// Mock AI SDK
|
||||
vi.mock('ai', () => ({
|
||||
createProviderRegistry: vi.fn()
|
||||
}))
|
||||
|
||||
describe('RegistryManagement', () => {
|
||||
let registry: RegistryManagement
|
||||
let mockProvider: ProviderV3
|
||||
let mockLanguageModel: LanguageModelV3
|
||||
let mockEmbeddingModel: EmbeddingModelV3
|
||||
let mockImageModel: ImageModelV3
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
// Create mock models using global utilities
|
||||
mockLanguageModel = createMockLanguageModel({
|
||||
provider: 'test',
|
||||
modelId: 'test-model'
|
||||
})
|
||||
|
||||
mockEmbeddingModel = createMockEmbeddingModel({
|
||||
provider: 'test',
|
||||
modelId: 'test-embedding'
|
||||
})
|
||||
|
||||
mockImageModel = createMockImageModel({
|
||||
provider: 'test',
|
||||
modelId: 'test-image'
|
||||
})
|
||||
|
||||
// Create mock provider
|
||||
mockProvider = {
|
||||
specificationVersion: 'v3',
|
||||
languageModel: vi.fn().mockReturnValue(mockLanguageModel),
|
||||
embeddingModel: vi.fn().mockReturnValue(mockEmbeddingModel),
|
||||
imageModel: vi.fn().mockReturnValue(mockImageModel),
|
||||
transcriptionModel: vi.fn(),
|
||||
speechModel: vi.fn()
|
||||
} as ProviderV3
|
||||
|
||||
// Setup mock registry
|
||||
const mockRegistry = {
|
||||
languageModel: vi.fn().mockReturnValue(mockLanguageModel),
|
||||
embeddingModel: vi.fn().mockReturnValue(mockEmbeddingModel),
|
||||
imageModel: vi.fn().mockReturnValue(mockImageModel),
|
||||
transcriptionModel: vi.fn(),
|
||||
speechModel: vi.fn()
|
||||
}
|
||||
|
||||
vi.mocked(createProviderRegistry).mockReturnValue(mockRegistry as any)
|
||||
|
||||
registry = new RegistryManagement()
|
||||
})
|
||||
|
||||
describe('Constructor and Initialization', () => {
|
||||
it('should create registry with default separator', () => {
|
||||
const reg = new RegistryManagement()
|
||||
|
||||
expect(reg).toBeInstanceOf(RegistryManagement)
|
||||
expect(reg.hasProviders()).toBe(false)
|
||||
})
|
||||
|
||||
it('should create registry with custom separator', () => {
|
||||
const customSeparator = ':'
|
||||
const reg = new RegistryManagement({ separator: customSeparator })
|
||||
|
||||
expect(reg).toBeInstanceOf(RegistryManagement)
|
||||
})
|
||||
|
||||
it('should start with empty provider list', () => {
|
||||
expect(registry.getRegisteredProviders()).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider Registration', () => {
|
||||
it('should register a provider', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
expect(registry.getProvider('openai')).toBe(mockProvider)
|
||||
expect(registry.hasProviders()).toBe(true)
|
||||
})
|
||||
|
||||
it('should register multiple providers', () => {
|
||||
const provider2 = { ...mockProvider }
|
||||
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
registry.registerProvider('anthropic', provider2)
|
||||
|
||||
expect(registry.getProvider('openai')).toBe(mockProvider)
|
||||
expect(registry.getProvider('anthropic')).toBe(provider2)
|
||||
})
|
||||
|
||||
it('should return this for chaining', () => {
|
||||
const result = registry.registerProvider('openai', mockProvider)
|
||||
|
||||
expect(result).toBe(registry)
|
||||
})
|
||||
|
||||
it('should rebuild registry after registration', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
expect(createProviderRegistry).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
openai: mockProvider
|
||||
}),
|
||||
{ separator: DEFAULT_SEPARATOR }
|
||||
)
|
||||
})
|
||||
|
||||
it('should register provider with aliases', () => {
|
||||
registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt'])
|
||||
|
||||
expect(registry.getProvider('openai')).toBe(mockProvider)
|
||||
expect(registry.getProvider('gpt')).toBe(mockProvider)
|
||||
expect(registry.getProvider('chatgpt')).toBe(mockProvider)
|
||||
})
|
||||
|
||||
it('should track aliases separately', () => {
|
||||
registry.registerProvider('openai', mockProvider, ['gpt'])
|
||||
|
||||
expect(registry.isAlias('gpt')).toBe(true)
|
||||
expect(registry.isAlias('openai')).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle multiple aliases for same provider', () => {
|
||||
const aliases = ['alias1', 'alias2', 'alias3']
|
||||
registry.registerProvider('provider', mockProvider, aliases)
|
||||
|
||||
aliases.forEach((alias) => {
|
||||
expect(registry.getProvider(alias)).toBe(mockProvider)
|
||||
expect(registry.isAlias(alias)).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Bulk Registration', () => {
|
||||
it('should register multiple providers at once', () => {
|
||||
const providers = {
|
||||
openai: mockProvider,
|
||||
anthropic: { ...mockProvider },
|
||||
google: { ...mockProvider }
|
||||
}
|
||||
|
||||
registry.registerProviders(providers)
|
||||
|
||||
expect(registry.getProvider('openai')).toBe(providers.openai)
|
||||
expect(registry.getProvider('anthropic')).toBe(providers.anthropic)
|
||||
expect(registry.getProvider('google')).toBe(providers.google)
|
||||
})
|
||||
|
||||
it('should return this for chaining', () => {
|
||||
const result = registry.registerProviders({ openai: mockProvider })
|
||||
|
||||
expect(result).toBe(registry)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider Retrieval', () => {
|
||||
beforeEach(() => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
})
|
||||
|
||||
it('should retrieve registered provider', () => {
|
||||
const provider = registry.getProvider('openai')
|
||||
|
||||
expect(provider).toBe(mockProvider)
|
||||
})
|
||||
|
||||
it('should return undefined for unregistered provider', () => {
|
||||
const provider = registry.getProvider('nonexistent')
|
||||
|
||||
expect(provider).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should retrieve provider by alias', () => {
|
||||
registry.registerProvider('anthropic', mockProvider, ['claude'])
|
||||
|
||||
const provider = registry.getProvider('claude')
|
||||
|
||||
expect(provider).toBe(mockProvider)
|
||||
})
|
||||
|
||||
it('should get list of all registered providers', () => {
|
||||
registry.registerProvider('anthropic', mockProvider)
|
||||
registry.registerProvider('google', mockProvider, ['gemini'])
|
||||
|
||||
const providers = registry.getRegisteredProviders()
|
||||
|
||||
expect(providers).toContain('openai')
|
||||
expect(providers).toContain('anthropic')
|
||||
expect(providers).toContain('google')
|
||||
expect(providers).toContain('gemini') // Aliases included
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider Unregistration', () => {
|
||||
it('should unregister provider', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
registry.unregisterProvider('openai')
|
||||
|
||||
expect(registry.getProvider('openai')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should unregister provider with all its aliases', () => {
|
||||
registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt'])
|
||||
|
||||
registry.unregisterProvider('openai')
|
||||
|
||||
expect(registry.getProvider('openai')).toBeUndefined()
|
||||
expect(registry.getProvider('gpt')).toBeUndefined()
|
||||
expect(registry.getProvider('chatgpt')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should unregister only alias when alias is removed', () => {
|
||||
registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt'])
|
||||
|
||||
registry.unregisterProvider('gpt')
|
||||
|
||||
expect(registry.getProvider('openai')).toBe(mockProvider)
|
||||
expect(registry.getProvider('gpt')).toBeUndefined()
|
||||
expect(registry.getProvider('chatgpt')).toBe(mockProvider)
|
||||
})
|
||||
|
||||
it('should handle unregistering non-existent provider', () => {
|
||||
expect(() => registry.unregisterProvider('nonexistent')).not.toThrow()
|
||||
})
|
||||
|
||||
it('should return this for chaining', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
const result = registry.unregisterProvider('openai')
|
||||
|
||||
expect(result).toBe(registry)
|
||||
})
|
||||
|
||||
it('should rebuild registry after unregistration', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
vi.clearAllMocks()
|
||||
|
||||
registry.unregisterProvider('openai')
|
||||
|
||||
// Should rebuild with empty providers
|
||||
expect(createProviderRegistry).not.toHaveBeenCalled() // No rebuild when empty
|
||||
})
|
||||
})
|
||||
|
||||
describe('Model Resolution', () => {
|
||||
beforeEach(() => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
})
|
||||
|
||||
it('should resolve language model', () => {
|
||||
const modelId = `openai${DEFAULT_SEPARATOR}gpt-4` as any
|
||||
|
||||
const result = registry.languageModel(modelId)
|
||||
|
||||
expect(result).toBe(mockLanguageModel)
|
||||
})
|
||||
|
||||
it('should resolve embedding model', () => {
|
||||
const modelId = `openai${DEFAULT_SEPARATOR}text-embedding-3-small` as any
|
||||
|
||||
const result = registry.embeddingModel(modelId)
|
||||
|
||||
expect(result).toBe(mockEmbeddingModel)
|
||||
})
|
||||
|
||||
it('should resolve image model', () => {
|
||||
const modelId = `openai${DEFAULT_SEPARATOR}dall-e-3` as any
|
||||
|
||||
const result = registry.imageModel(modelId)
|
||||
|
||||
expect(result).toBe(mockImageModel)
|
||||
})
|
||||
|
||||
it('should resolve transcription model', () => {
|
||||
const modelId = `openai${DEFAULT_SEPARATOR}whisper-1` as any
|
||||
|
||||
registry.transcriptionModel(modelId)
|
||||
|
||||
// Verify it calls through to the mock registry
|
||||
expect(createProviderRegistry).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should resolve speech model', () => {
|
||||
const modelId = `openai${DEFAULT_SEPARATOR}tts-1` as any
|
||||
|
||||
registry.speechModel(modelId)
|
||||
|
||||
expect(createProviderRegistry).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should throw error when no providers registered', () => {
|
||||
const emptyRegistry = new RegistryManagement()
|
||||
|
||||
expect(() => emptyRegistry.languageModel('openai|gpt-4' as any)).toThrow('No providers registered')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Alias Management', () => {
|
||||
it('should resolve provider ID from alias', () => {
|
||||
registry.registerProvider('openai', mockProvider, ['gpt'])
|
||||
|
||||
const realId = registry.resolveProviderId('gpt')
|
||||
|
||||
expect(realId).toBe('openai')
|
||||
})
|
||||
|
||||
it('should return same ID if not an alias', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
const realId = registry.resolveProviderId('openai')
|
||||
|
||||
expect(realId).toBe('openai')
|
||||
})
|
||||
|
||||
it('should check if ID is alias', () => {
|
||||
registry.registerProvider('openai', mockProvider, ['gpt'])
|
||||
|
||||
expect(registry.isAlias('gpt')).toBe(true)
|
||||
expect(registry.isAlias('openai')).toBe(false)
|
||||
})
|
||||
|
||||
it('should get all alias mappings', () => {
|
||||
const provider2 = { ...mockProvider }
|
||||
|
||||
registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt'])
|
||||
registry.registerProvider('anthropic', provider2, ['claude'])
|
||||
|
||||
const aliases = registry.getAllAliases()
|
||||
|
||||
// Check that all aliases are present
|
||||
expect(aliases['gpt']).toBe('openai')
|
||||
expect(aliases['chatgpt']).toBe('openai')
|
||||
expect(aliases['claude']).toBe('anthropic')
|
||||
})
|
||||
|
||||
it('should return empty object when no aliases', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
const aliases = registry.getAllAliases()
|
||||
|
||||
expect(aliases).toEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Registry State', () => {
|
||||
it('should check if has providers', () => {
|
||||
expect(registry.hasProviders()).toBe(false)
|
||||
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
expect(registry.hasProviders()).toBe(true)
|
||||
})
|
||||
|
||||
it('should clear all providers', () => {
|
||||
registry.registerProvider('openai', mockProvider, ['gpt'])
|
||||
registry.registerProvider('anthropic', mockProvider)
|
||||
|
||||
registry.clear()
|
||||
|
||||
expect(registry.hasProviders()).toBe(false)
|
||||
expect(registry.getRegisteredProviders()).toEqual([])
|
||||
expect(registry.getAllAliases()).toEqual({})
|
||||
})
|
||||
|
||||
it('should return this after clear for chaining', () => {
|
||||
const result = registry.clear()
|
||||
|
||||
expect(result).toBe(registry)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Registry Rebuilding', () => {
|
||||
it('should rebuild registry when provider added', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
expect(createProviderRegistry).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should rebuild registry when provider removed', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
registry.registerProvider('anthropic', mockProvider)
|
||||
|
||||
vi.clearAllMocks()
|
||||
|
||||
registry.unregisterProvider('openai')
|
||||
|
||||
expect(createProviderRegistry).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should set registry to null when all providers removed', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
registry.unregisterProvider('openai')
|
||||
|
||||
expect(() => registry.languageModel('any|model' as any)).toThrow('No providers registered')
|
||||
})
|
||||
|
||||
it('should rebuild with correct separator', () => {
|
||||
const customRegistry = new RegistryManagement({ separator: ':' })
|
||||
customRegistry.registerProvider('openai', mockProvider)
|
||||
|
||||
expect(createProviderRegistry).toHaveBeenCalledWith(expect.any(Object), { separator: ':' })
|
||||
})
|
||||
})
|
||||
|
||||
describe('Global Registry Instance', () => {
|
||||
it('should have a global instance with default separator', async () => {
|
||||
const module = await import('../RegistryManagement')
|
||||
|
||||
expect(module.globalRegistryManagement).toBeInstanceOf(RegistryManagement)
|
||||
})
|
||||
|
||||
it('should have DEFAULT_SEPARATOR exported', () => {
|
||||
expect(DEFAULT_SEPARATOR).toBe('|')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle registering same provider twice', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
const provider2 = { ...mockProvider }
|
||||
registry.registerProvider('openai', provider2)
|
||||
|
||||
expect(registry.getProvider('openai')).toBe(provider2)
|
||||
})
|
||||
|
||||
it('should handle alias conflicts (first wins)', () => {
|
||||
registry.registerProvider('provider1', mockProvider, ['shared-alias'])
|
||||
registry.registerProvider('provider2', mockProvider, ['shared-alias'])
|
||||
|
||||
// First registered alias wins (the implementation doesn't override)
|
||||
expect(registry.resolveProviderId('shared-alias')).toBe('provider1')
|
||||
})
|
||||
|
||||
it('should handle empty alias array', () => {
|
||||
registry.registerProvider('openai', mockProvider, [])
|
||||
|
||||
expect(registry.getAllAliases()).toEqual({})
|
||||
})
|
||||
|
||||
it('should handle null registry operations gracefully', () => {
|
||||
const emptyRegistry = new RegistryManagement()
|
||||
|
||||
expect(() => emptyRegistry.languageModel('test|model' as any)).toThrow('No providers registered')
|
||||
expect(() => emptyRegistry.embeddingModel('test|embed' as any)).toThrow('No providers registered')
|
||||
expect(() => emptyRegistry.imageModel('test|image' as any)).toThrow('No providers registered')
|
||||
})
|
||||
|
||||
it('should handle special characters in provider IDs', () => {
|
||||
const specialIds = ['provider-1', 'provider_2', 'provider.3', 'provider:4']
|
||||
|
||||
specialIds.forEach((id) => {
|
||||
registry.registerProvider(id, mockProvider)
|
||||
expect(registry.getProvider(id)).toBe(mockProvider)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Concurrent Operations', () => {
|
||||
it('should handle concurrent registrations', () => {
|
||||
const promises = [
|
||||
Promise.resolve(registry.registerProvider('provider1', mockProvider)),
|
||||
Promise.resolve(registry.registerProvider('provider2', mockProvider)),
|
||||
Promise.resolve(registry.registerProvider('provider3', mockProvider))
|
||||
]
|
||||
|
||||
return Promise.all(promises).then(() => {
|
||||
expect(registry.getRegisteredProviders()).toHaveLength(3)
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle mixed operations', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
registry.registerProvider('anthropic', mockProvider)
|
||||
|
||||
const provider1 = registry.getProvider('openai')
|
||||
registry.unregisterProvider('anthropic')
|
||||
const provider2 = registry.getProvider('openai')
|
||||
|
||||
expect(provider1).toBe(provider2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Type Safety', () => {
|
||||
it('should enforce model ID format with template literal types', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
// These should be type-safe
|
||||
const validId = 'openai|gpt-4' as `${string}${typeof DEFAULT_SEPARATOR}${string}`
|
||||
|
||||
expect(() => registry.languageModel(validId)).not.toThrow()
|
||||
})
|
||||
|
||||
it('should return properly typed LanguageModelV3', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
const model = registry.languageModel('openai|gpt-4' as any)
|
||||
|
||||
expect(model.specificationVersion).toBe('v3')
|
||||
expect(model).toHaveProperty('doGenerate')
|
||||
expect(model).toHaveProperty('doStream')
|
||||
})
|
||||
|
||||
it('should return properly typed EmbeddingModelV3', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
const model = registry.embeddingModel('openai|ada-002' as any)
|
||||
|
||||
expect(model.specificationVersion).toBe('v3')
|
||||
expect(model).toHaveProperty('doEmbed')
|
||||
})
|
||||
|
||||
it('should return properly typed ImageModelV3', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
const model = registry.imageModel('openai|dall-e-3' as any)
|
||||
|
||||
expect(model.specificationVersion).toBe('v3')
|
||||
expect(model).toHaveProperty('doGenerate')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Memory Management', () => {
|
||||
it('should properly clean up on clear', () => {
|
||||
registry.registerProvider('p1', mockProvider, ['a1'])
|
||||
registry.registerProvider('p2', mockProvider, ['a2'])
|
||||
|
||||
registry.clear()
|
||||
|
||||
expect(registry.getRegisteredProviders()).toHaveLength(0)
|
||||
expect(Object.keys(registry.getAllAliases())).toHaveLength(0)
|
||||
})
|
||||
|
||||
it('should properly clean up on unregister', () => {
|
||||
registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt'])
|
||||
|
||||
registry.unregisterProvider('openai')
|
||||
|
||||
expect(registry.getProvider('openai')).toBeUndefined()
|
||||
expect(registry.isAlias('gpt')).toBe(false)
|
||||
expect(registry.isAlias('chatgpt')).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,650 @@
|
||||
/**
|
||||
* RuntimeExecutor.resolveModel Comprehensive Tests
|
||||
* Tests the private resolveModel and resolveImageModel methods through public APIs
|
||||
* Covers model resolution, middleware application, and type validation
|
||||
*/
|
||||
|
||||
import type { ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider'
|
||||
import { generateImage, generateText, streamText } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
createMockImageModel,
|
||||
createMockLanguageModel,
|
||||
createMockMiddleware,
|
||||
mockProviderConfigs
|
||||
} from '../../../__tests__'
|
||||
import { globalModelResolver } from '../../models'
|
||||
import { ImageModelResolutionError } from '../errors'
|
||||
import { RuntimeExecutor } from '../executor'
|
||||
|
||||
// Mock AI SDK
|
||||
vi.mock('ai', async (importOriginal) => {
|
||||
const actual = (await importOriginal()) as Record<string, unknown>
|
||||
return {
|
||||
...actual,
|
||||
generateText: vi.fn(),
|
||||
streamText: vi.fn(),
|
||||
generateImage: vi.fn(),
|
||||
wrapLanguageModel: vi.fn((config: any) => ({
|
||||
...config.model,
|
||||
_middlewareApplied: true,
|
||||
middleware: config.middleware
|
||||
}))
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('../../providers/RegistryManagement', () => ({
|
||||
globalRegistryManagement: {
|
||||
languageModel: vi.fn(),
|
||||
imageModel: vi.fn()
|
||||
},
|
||||
DEFAULT_SEPARATOR: '|'
|
||||
}))
|
||||
|
||||
vi.mock('../../models', () => ({
|
||||
globalModelResolver: {
|
||||
resolveLanguageModel: vi.fn(),
|
||||
resolveImageModel: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
describe('RuntimeExecutor - Model Resolution', () => {
|
||||
let executor: RuntimeExecutor<'openai'>
|
||||
let mockLanguageModel: LanguageModelV3
|
||||
let mockImageModel: ImageModelV3
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
|
||||
|
||||
mockLanguageModel = createMockLanguageModel({
|
||||
specificationVersion: 'v3',
|
||||
provider: 'openai',
|
||||
modelId: 'gpt-4'
|
||||
})
|
||||
|
||||
mockImageModel = createMockImageModel({
|
||||
specificationVersion: 'v3',
|
||||
provider: 'openai',
|
||||
modelId: 'dall-e-3'
|
||||
})
|
||||
|
||||
vi.mocked(globalModelResolver.resolveLanguageModel).mockResolvedValue(mockLanguageModel)
|
||||
vi.mocked(globalModelResolver.resolveImageModel).mockResolvedValue(mockImageModel)
|
||||
vi.mocked(generateText).mockResolvedValue({
|
||||
text: 'Test response',
|
||||
finishReason: 'stop',
|
||||
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }
|
||||
} as any)
|
||||
vi.mocked(streamText).mockResolvedValue({
|
||||
textStream: (async function* () {
|
||||
yield 'test'
|
||||
})()
|
||||
} as any)
|
||||
vi.mocked(generateImage).mockResolvedValue({
|
||||
image: {
|
||||
base64: 'test-image',
|
||||
uint8Array: new Uint8Array([1, 2, 3]),
|
||||
mimeType: 'image/png'
|
||||
},
|
||||
warnings: []
|
||||
} as any)
|
||||
})
|
||||
|
||||
describe('Language Model Resolution (String modelId)', () => {
|
||||
it('should resolve string modelId using globalModelResolver', async () => {
|
||||
await executor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: [{ role: 'user', content: 'Hello' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
|
||||
'gpt-4',
|
||||
'openai',
|
||||
mockProviderConfigs.openai,
|
||||
undefined
|
||||
)
|
||||
})
|
||||
|
||||
it('should pass provider settings to model resolver', async () => {
|
||||
const customExecutor = RuntimeExecutor.create('anthropic', {
|
||||
apiKey: 'sk-test',
|
||||
baseURL: 'https://api.anthropic.com'
|
||||
})
|
||||
|
||||
vi.mocked(globalModelResolver.resolveLanguageModel).mockResolvedValue(mockLanguageModel)
|
||||
|
||||
await customExecutor.generateText({
|
||||
model: 'claude-3-5-sonnet',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
|
||||
'claude-3-5-sonnet',
|
||||
'anthropic',
|
||||
{
|
||||
apiKey: 'sk-test',
|
||||
baseURL: 'https://api.anthropic.com'
|
||||
},
|
||||
undefined
|
||||
)
|
||||
})
|
||||
|
||||
it('should resolve traditional format modelId', async () => {
|
||||
await executor.generateText({
|
||||
model: 'gpt-4-turbo',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
|
||||
'gpt-4-turbo',
|
||||
'openai',
|
||||
expect.any(Object),
|
||||
undefined
|
||||
)
|
||||
})
|
||||
|
||||
it('should resolve namespaced format modelId', async () => {
|
||||
await executor.generateText({
|
||||
model: 'aihubmix|anthropic|claude-3',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
|
||||
'aihubmix|anthropic|claude-3',
|
||||
'openai',
|
||||
expect.any(Object),
|
||||
undefined
|
||||
)
|
||||
})
|
||||
|
||||
it('should use resolved model for generation', async () => {
|
||||
await executor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: [{ role: 'user', content: 'Hello' }]
|
||||
})
|
||||
|
||||
expect(generateText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: mockLanguageModel
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should work with streamText', async () => {
|
||||
await executor.streamText({
|
||||
model: 'gpt-4',
|
||||
messages: [{ role: 'user', content: 'Stream test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalled()
|
||||
expect(streamText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: mockLanguageModel
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Language Model Resolution (Direct Model Object)', () => {
|
||||
it('should accept pre-resolved V3 model object', async () => {
|
||||
const directModel: LanguageModelV3 = createMockLanguageModel({
|
||||
specificationVersion: 'v3',
|
||||
provider: 'openai',
|
||||
modelId: 'gpt-4'
|
||||
})
|
||||
|
||||
await executor.generateText({
|
||||
model: directModel,
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
// Should NOT call resolver for direct model
|
||||
expect(globalModelResolver.resolveLanguageModel).not.toHaveBeenCalled()
|
||||
|
||||
// Should use the model directly
|
||||
expect(generateText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: directModel
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should accept V2 model object without validation (plugin engine handles it)', async () => {
|
||||
const v2Model = {
|
||||
specificationVersion: 'v2',
|
||||
provider: 'openai',
|
||||
modelId: 'gpt-4',
|
||||
doGenerate: vi.fn()
|
||||
} as any
|
||||
|
||||
// The plugin engine accepts any model object directly without validation
|
||||
// V3 validation only happens when resolving string modelIds
|
||||
await expect(
|
||||
executor.generateText({
|
||||
model: v2Model,
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
).resolves.toBeDefined()
|
||||
})
|
||||
|
||||
it('should accept any model object without checking specification version', async () => {
|
||||
const v2Model = {
|
||||
specificationVersion: 'v2',
|
||||
provider: 'custom-provider',
|
||||
modelId: 'custom-model',
|
||||
doGenerate: vi.fn()
|
||||
} as any
|
||||
|
||||
// Direct model objects bypass validation
|
||||
// The executor trusts that plugins/users provide valid models
|
||||
await expect(
|
||||
executor.generateText({
|
||||
model: v2Model,
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
).resolves.toBeDefined()
|
||||
})
|
||||
|
||||
it('should accept model object with streamText', async () => {
|
||||
const directModel = createMockLanguageModel({
|
||||
specificationVersion: 'v3'
|
||||
})
|
||||
|
||||
await executor.streamText({
|
||||
model: directModel,
|
||||
messages: [{ role: 'user', content: 'Stream' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).not.toHaveBeenCalled()
|
||||
expect(streamText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: directModel
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Middleware Application', () => {
|
||||
it('should apply middlewares to string modelId', async () => {
|
||||
const testMiddleware = createMockMiddleware({ name: 'test-middleware' })
|
||||
|
||||
await executor.generateText(
|
||||
{
|
||||
model: 'gpt-4',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
},
|
||||
{ middlewares: [testMiddleware] }
|
||||
)
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [
|
||||
testMiddleware
|
||||
])
|
||||
})
|
||||
|
||||
it('should apply multiple middlewares in order', async () => {
|
||||
const middleware1 = createMockMiddleware({ name: 'middleware-1' })
|
||||
const middleware2 = createMockMiddleware({ name: 'middleware-2' })
|
||||
|
||||
await executor.generateText(
|
||||
{
|
||||
model: 'gpt-4',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
},
|
||||
{ middlewares: [middleware1, middleware2] }
|
||||
)
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [
|
||||
middleware1,
|
||||
middleware2
|
||||
])
|
||||
})
|
||||
|
||||
it('should pass middlewares to model resolver for string modelIds', async () => {
|
||||
const testMiddleware = createMockMiddleware({ name: 'test-middleware' })
|
||||
|
||||
await executor.generateText(
|
||||
{
|
||||
model: 'gpt-4', // String model ID
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
},
|
||||
{ middlewares: [testMiddleware] }
|
||||
)
|
||||
|
||||
// Middlewares are passed to the resolver for string modelIds
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [
|
||||
testMiddleware
|
||||
])
|
||||
})
|
||||
|
||||
it('should not apply middlewares when none provided', async () => {
|
||||
await executor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
|
||||
'gpt-4',
|
||||
'openai',
|
||||
expect.any(Object),
|
||||
undefined
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle empty middleware array', async () => {
|
||||
await executor.generateText(
|
||||
{
|
||||
model: 'gpt-4',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
},
|
||||
{ middlewares: [] }
|
||||
)
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [])
|
||||
})
|
||||
|
||||
it('should work with middlewares in streamText', async () => {
|
||||
const middleware = createMockMiddleware({ name: 'stream-middleware' })
|
||||
|
||||
await executor.streamText(
|
||||
{
|
||||
model: 'gpt-4',
|
||||
messages: [{ role: 'user', content: 'Stream' }]
|
||||
},
|
||||
{ middlewares: [middleware] }
|
||||
)
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [
|
||||
middleware
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
describe('Image Model Resolution', () => {
|
||||
it('should resolve string image modelId using globalModelResolver', async () => {
|
||||
await executor.generateImage({
|
||||
model: 'dall-e-3',
|
||||
prompt: 'A beautiful sunset'
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveImageModel).toHaveBeenCalledWith('dall-e-3', 'openai')
|
||||
})
|
||||
|
||||
it('should accept direct ImageModelV3 object', async () => {
|
||||
const directImageModel: ImageModelV3 = createMockImageModel({
|
||||
specificationVersion: 'v3',
|
||||
provider: 'openai',
|
||||
modelId: 'dall-e-3'
|
||||
})
|
||||
|
||||
await executor.generateImage({
|
||||
model: directImageModel,
|
||||
prompt: 'Test image'
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveImageModel).not.toHaveBeenCalled()
|
||||
expect(generateImage).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: directImageModel
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should resolve namespaced image model ID', async () => {
|
||||
await executor.generateImage({
|
||||
model: 'aihubmix|openai|dall-e-3',
|
||||
prompt: 'Namespaced image'
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveImageModel).toHaveBeenCalledWith('aihubmix|openai|dall-e-3', 'openai')
|
||||
})
|
||||
|
||||
it('should throw ImageModelResolutionError on resolution failure', async () => {
|
||||
const resolutionError = new Error('Model not found')
|
||||
vi.mocked(globalModelResolver.resolveImageModel).mockRejectedValue(resolutionError)
|
||||
|
||||
await expect(
|
||||
executor.generateImage({
|
||||
model: 'invalid-model',
|
||||
prompt: 'Test'
|
||||
})
|
||||
).rejects.toThrow(ImageModelResolutionError)
|
||||
})
|
||||
|
||||
it('should include modelId and providerId in ImageModelResolutionError', async () => {
|
||||
vi.mocked(globalModelResolver.resolveImageModel).mockRejectedValue(new Error('Not found'))
|
||||
|
||||
try {
|
||||
await executor.generateImage({
|
||||
model: 'invalid-model',
|
||||
prompt: 'Test'
|
||||
})
|
||||
expect.fail('Should have thrown ImageModelResolutionError')
|
||||
} catch (error) {
|
||||
expect(error).toBeInstanceOf(ImageModelResolutionError)
|
||||
const imgError = error as ImageModelResolutionError
|
||||
expect(imgError.message).toContain('invalid-model')
|
||||
expect(imgError.providerId).toBe('openai')
|
||||
}
|
||||
})
|
||||
|
||||
it('should extract modelId from direct model object in error', async () => {
|
||||
const directModel = createMockImageModel({
|
||||
modelId: 'direct-model',
|
||||
doGenerate: vi.fn().mockRejectedValue(new Error('Generation failed'))
|
||||
})
|
||||
|
||||
vi.mocked(generateImage).mockRejectedValue(new Error('Generation failed'))
|
||||
|
||||
await expect(
|
||||
executor.generateImage({
|
||||
model: directModel,
|
||||
prompt: 'Test'
|
||||
})
|
||||
).rejects.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider-Specific Model Resolution', () => {
|
||||
it('should resolve models for OpenAI provider', async () => {
|
||||
const openaiExecutor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
|
||||
|
||||
await openaiExecutor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
|
||||
'gpt-4',
|
||||
'openai',
|
||||
expect.any(Object),
|
||||
undefined
|
||||
)
|
||||
})
|
||||
|
||||
it('should resolve models for Anthropic provider', async () => {
|
||||
const anthropicExecutor = RuntimeExecutor.create('anthropic', mockProviderConfigs.anthropic)
|
||||
|
||||
await anthropicExecutor.generateText({
|
||||
model: 'claude-3-5-sonnet',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
|
||||
'claude-3-5-sonnet',
|
||||
'anthropic',
|
||||
expect.any(Object),
|
||||
undefined
|
||||
)
|
||||
})
|
||||
|
||||
it('should resolve models for Google provider', async () => {
|
||||
const googleExecutor = RuntimeExecutor.create('google', mockProviderConfigs.google)
|
||||
|
||||
await googleExecutor.generateText({
|
||||
model: 'gemini-2.0-flash',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
|
||||
'gemini-2.0-flash',
|
||||
'google',
|
||||
expect.any(Object),
|
||||
undefined
|
||||
)
|
||||
})
|
||||
|
||||
it('should resolve models for OpenAI-compatible provider', async () => {
|
||||
const compatibleExecutor = RuntimeExecutor.createOpenAICompatible(mockProviderConfigs['openai-compatible'])
|
||||
|
||||
await compatibleExecutor.generateText({
|
||||
model: 'custom-model',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
|
||||
'custom-model',
|
||||
'openai-compatible',
|
||||
expect.any(Object),
|
||||
undefined
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('OpenAI Mode Handling', () => {
|
||||
it('should pass mode setting to model resolver', async () => {
|
||||
const executorWithMode = RuntimeExecutor.create('openai', {
|
||||
...mockProviderConfigs.openai,
|
||||
mode: 'chat'
|
||||
})
|
||||
|
||||
await executorWithMode.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
|
||||
'gpt-4',
|
||||
'openai',
|
||||
expect.objectContaining({
|
||||
mode: 'chat'
|
||||
}),
|
||||
undefined
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle responses mode', async () => {
|
||||
const executorWithMode = RuntimeExecutor.create('openai', {
|
||||
...mockProviderConfigs.openai,
|
||||
mode: 'responses'
|
||||
})
|
||||
|
||||
await executorWithMode.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
|
||||
'gpt-4',
|
||||
'openai',
|
||||
expect.objectContaining({
|
||||
mode: 'responses'
|
||||
}),
|
||||
undefined
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle empty string modelId', async () => {
|
||||
await executor.generateText({
|
||||
model: '',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('', 'openai', expect.any(Object), undefined)
|
||||
})
|
||||
|
||||
it('should handle model resolution errors gracefully', async () => {
|
||||
vi.mocked(globalModelResolver.resolveLanguageModel).mockRejectedValue(new Error('Model not found'))
|
||||
|
||||
await expect(
|
||||
executor.generateText({
|
||||
model: 'nonexistent-model',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
).rejects.toThrow('Model not found')
|
||||
})
|
||||
|
||||
it('should handle concurrent model resolutions', async () => {
|
||||
const promises = [
|
||||
executor.generateText({ model: 'gpt-4', messages: [{ role: 'user', content: 'Test 1' }] }),
|
||||
executor.generateText({ model: 'gpt-4-turbo', messages: [{ role: 'user', content: 'Test 2' }] }),
|
||||
executor.generateText({ model: 'gpt-3.5-turbo', messages: [{ role: 'user', content: 'Test 3' }] })
|
||||
]
|
||||
|
||||
await Promise.all(promises)
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
|
||||
it('should accept model object even without specificationVersion', async () => {
|
||||
const invalidModel = {
|
||||
provider: 'test',
|
||||
modelId: 'test-model'
|
||||
// Missing specificationVersion
|
||||
} as any
|
||||
|
||||
// Plugin engine doesn't validate direct model objects
|
||||
// It's the user's responsibility to provide valid models
|
||||
await expect(
|
||||
executor.generateText({
|
||||
model: invalidModel,
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
).resolves.toBeDefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Type Safety Validation', () => {
|
||||
it('should ensure resolved model is LanguageModelV3', async () => {
|
||||
const v3Model = createMockLanguageModel({
|
||||
specificationVersion: 'v3'
|
||||
})
|
||||
|
||||
vi.mocked(globalModelResolver.resolveLanguageModel).mockResolvedValue(v3Model)
|
||||
|
||||
await executor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(generateText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: expect.objectContaining({
|
||||
specificationVersion: 'v3'
|
||||
})
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should not enforce specification version for direct models', async () => {
|
||||
const v1Model = {
|
||||
specificationVersion: 'v1',
|
||||
provider: 'test',
|
||||
modelId: 'test'
|
||||
} as any
|
||||
|
||||
// Direct models bypass validation in the plugin engine
|
||||
// Only resolved models (from string IDs) are validated
|
||||
await expect(
|
||||
executor.generateText({
|
||||
model: v1Model,
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
).resolves.toBeDefined()
|
||||
})
|
||||
})
|
||||
})
|
||||
867
packages/aiCore/src/core/runtime/__tests__/pluginEngine.test.ts
Normal file
867
packages/aiCore/src/core/runtime/__tests__/pluginEngine.test.ts
Normal file
@ -0,0 +1,867 @@
|
||||
/**
|
||||
* PluginEngine Comprehensive Tests
|
||||
* Tests plugin lifecycle, execution order, and coordination
|
||||
* Covers both streaming and non-streaming execution paths
|
||||
*/
|
||||
|
||||
import type { ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { createMockImageModel, createMockLanguageModel } from '../../../__tests__'
|
||||
import { ModelResolutionError, RecursiveDepthError } from '../../errors'
|
||||
import type { AiPlugin, GenerateTextParams, GenerateTextResult } from '../../plugins'
|
||||
import { PluginEngine } from '../pluginEngine'
|
||||
|
||||
describe('PluginEngine', () => {
|
||||
let engine: PluginEngine<'openai'>
|
||||
let mockLanguageModel: LanguageModelV3
|
||||
let mockImageModel: ImageModelV3
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
mockLanguageModel = createMockLanguageModel({
|
||||
provider: 'openai',
|
||||
modelId: 'gpt-4'
|
||||
})
|
||||
|
||||
mockImageModel = createMockImageModel({
|
||||
provider: 'openai',
|
||||
modelId: 'dall-e-3'
|
||||
})
|
||||
})
|
||||
|
||||
describe('Plugin Registration and Management', () => {
|
||||
it('should create engine with empty plugins', () => {
|
||||
engine = new PluginEngine('openai', [])
|
||||
|
||||
expect(engine.getPlugins()).toEqual([])
|
||||
})
|
||||
|
||||
it('should create engine with initial plugins', () => {
|
||||
const plugin1: AiPlugin = { name: 'plugin-1' }
|
||||
const plugin2: AiPlugin = { name: 'plugin-2' }
|
||||
|
||||
engine = new PluginEngine('openai', [plugin1, plugin2])
|
||||
|
||||
expect(engine.getPlugins()).toHaveLength(2)
|
||||
expect(engine.getPlugins()).toEqual([plugin1, plugin2])
|
||||
})
|
||||
|
||||
it('should add plugin with use()', () => {
|
||||
engine = new PluginEngine('openai', [])
|
||||
const plugin: AiPlugin = { name: 'test-plugin' }
|
||||
|
||||
const result = engine.use(plugin)
|
||||
|
||||
expect(result).toBe(engine) // Chainable
|
||||
expect(engine.getPlugins()).toContain(plugin)
|
||||
})
|
||||
|
||||
it('should add multiple plugins with usePlugins()', () => {
|
||||
engine = new PluginEngine('openai', [])
|
||||
const plugins: AiPlugin[] = [{ name: 'plugin-1' }, { name: 'plugin-2' }, { name: 'plugin-3' }]
|
||||
|
||||
const result = engine.usePlugins(plugins)
|
||||
|
||||
expect(result).toBe(engine) // Chainable
|
||||
expect(engine.getPlugins()).toHaveLength(3)
|
||||
})
|
||||
|
||||
it('should support method chaining', () => {
|
||||
engine = new PluginEngine('openai', [])
|
||||
|
||||
engine
|
||||
.use({ name: 'plugin-1' })
|
||||
.use({ name: 'plugin-2' })
|
||||
.usePlugins([{ name: 'plugin-3' }])
|
||||
|
||||
expect(engine.getPlugins()).toHaveLength(3)
|
||||
})
|
||||
|
||||
it('should remove plugin by name', () => {
|
||||
const plugin1: AiPlugin = { name: 'plugin-1' }
|
||||
const plugin2: AiPlugin = { name: 'plugin-2' }
|
||||
|
||||
engine = new PluginEngine('openai', [plugin1, plugin2])
|
||||
|
||||
engine.removePlugin('plugin-1')
|
||||
|
||||
expect(engine.getPlugins()).toHaveLength(1)
|
||||
expect(engine.getPlugins()[0]).toBe(plugin2)
|
||||
})
|
||||
|
||||
it('should not error when removing non-existent plugin', () => {
|
||||
engine = new PluginEngine('openai', [{ name: 'existing' }])
|
||||
|
||||
engine.removePlugin('non-existent')
|
||||
|
||||
expect(engine.getPlugins()).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('should get plugin statistics', () => {
|
||||
const plugins: AiPlugin[] = [
|
||||
{ name: 'plugin-1', enforce: 'pre' },
|
||||
{ name: 'plugin-2' },
|
||||
{ name: 'plugin-3', enforce: 'post' }
|
||||
]
|
||||
|
||||
engine = new PluginEngine('openai', plugins)
|
||||
|
||||
const stats = engine.getPluginStats()
|
||||
|
||||
expect(stats).toHaveProperty('total')
|
||||
expect(stats.total).toBe(3)
|
||||
})
|
||||
|
||||
it('should preserve plugin order', () => {
|
||||
const plugins: AiPlugin[] = [{ name: 'first' }, { name: 'second' }, { name: 'third' }]
|
||||
|
||||
engine = new PluginEngine('openai', plugins)
|
||||
|
||||
const retrieved = engine.getPlugins()
|
||||
expect(retrieved[0].name).toBe('first')
|
||||
expect(retrieved[1].name).toBe('second')
|
||||
expect(retrieved[2].name).toBe('third')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Plugin Lifecycle - Non-Streaming', () => {
|
||||
it('should execute all plugin hooks in correct order', async () => {
|
||||
const executionOrder: string[] = []
|
||||
|
||||
const plugin: AiPlugin = {
|
||||
name: 'lifecycle-test',
|
||||
configureContext: vi.fn(async () => {
|
||||
executionOrder.push('configureContext')
|
||||
}),
|
||||
onRequestStart: vi.fn(async () => {
|
||||
executionOrder.push('onRequestStart')
|
||||
}),
|
||||
resolveModel: vi.fn(async () => {
|
||||
executionOrder.push('resolveModel')
|
||||
return mockLanguageModel
|
||||
}),
|
||||
transformParams: vi.fn(async (params) => {
|
||||
executionOrder.push('transformParams')
|
||||
return params
|
||||
}),
|
||||
transformResult: vi.fn(async (result) => {
|
||||
executionOrder.push('transformResult')
|
||||
return result
|
||||
}),
|
||||
onRequestEnd: vi.fn(async () => {
|
||||
executionOrder.push('onRequestEnd')
|
||||
})
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
const mockExecutor = vi.fn().mockResolvedValue({ text: 'test', finishReason: 'stop' })
|
||||
|
||||
await engine.executeWithPlugins<GenerateTextParams, GenerateTextResult>(
|
||||
'generateText',
|
||||
{ model: 'gpt-4', messages: [] },
|
||||
mockExecutor
|
||||
)
|
||||
|
||||
expect(executionOrder).toEqual([
|
||||
'configureContext',
|
||||
'onRequestStart',
|
||||
'resolveModel',
|
||||
'transformParams',
|
||||
'transformResult',
|
||||
'onRequestEnd'
|
||||
])
|
||||
})
|
||||
|
||||
it('should call configureContext before other hooks', async () => {
|
||||
const configureContextSpy = vi.fn()
|
||||
const onRequestStartSpy = vi.fn()
|
||||
|
||||
const plugin: AiPlugin = {
|
||||
name: 'test',
|
||||
configureContext: configureContextSpy,
|
||||
onRequestStart: onRequestStartSpy,
|
||||
resolveModel: vi.fn().mockResolvedValue(mockLanguageModel)
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
await engine.executeWithPlugins(
|
||||
'generateText',
|
||||
{ model: 'gpt-4', messages: [] },
|
||||
vi.fn().mockResolvedValue({ text: 'test' })
|
||||
)
|
||||
|
||||
expect(configureContextSpy).toHaveBeenCalled()
|
||||
expect(onRequestStartSpy).toHaveBeenCalled()
|
||||
|
||||
// configureContext should be called before onRequestStart
|
||||
expect(configureContextSpy.mock.invocationCallOrder[0]).toBeLessThan(
|
||||
onRequestStartSpy.mock.invocationCallOrder[0]
|
||||
)
|
||||
})
|
||||
|
||||
it('should execute onRequestEnd after successful execution', async () => {
|
||||
const onRequestEndSpy = vi.fn()
|
||||
|
||||
const plugin: AiPlugin = {
|
||||
name: 'test',
|
||||
resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
|
||||
onRequestEnd: onRequestEndSpy
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
await engine.executeWithPlugins(
|
||||
'generateText',
|
||||
{ model: 'gpt-4', messages: [] },
|
||||
vi.fn().mockResolvedValue({ text: 'result' })
|
||||
)
|
||||
|
||||
expect(onRequestEndSpy).toHaveBeenCalledWith(
|
||||
expect.any(Object), // context
|
||||
expect.objectContaining({ text: 'result' })
|
||||
)
|
||||
})
|
||||
|
||||
it('should execute onError on failure', async () => {
|
||||
const onErrorSpy = vi.fn()
|
||||
const testError = new Error('Test error')
|
||||
|
||||
const plugin: AiPlugin = {
|
||||
name: 'error-handler',
|
||||
resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
|
||||
onError: onErrorSpy
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
await expect(
|
||||
engine.executeWithPlugins(
|
||||
'generateText',
|
||||
{ model: 'gpt-4', messages: [] },
|
||||
vi.fn().mockRejectedValue(testError)
|
||||
)
|
||||
).rejects.toThrow('Test error')
|
||||
|
||||
expect(onErrorSpy).toHaveBeenCalledWith(
|
||||
testError,
|
||||
expect.any(Object) // context
|
||||
)
|
||||
})
|
||||
|
||||
it('should not call onRequestEnd when error occurs', async () => {
|
||||
const onRequestEndSpy = vi.fn()
|
||||
|
||||
const plugin: AiPlugin = {
|
||||
name: 'test',
|
||||
resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
|
||||
onRequestEnd: onRequestEndSpy
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
await expect(
|
||||
engine.executeWithPlugins(
|
||||
'generateText',
|
||||
{ model: 'gpt-4', messages: [] },
|
||||
vi.fn().mockRejectedValue(new Error('Execution error'))
|
||||
)
|
||||
).rejects.toThrow()
|
||||
|
||||
expect(onRequestEndSpy).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Model Resolution', () => {
|
||||
it('should resolve string model through plugin', async () => {
|
||||
const resolveModelSpy = vi.fn().mockResolvedValue(mockLanguageModel)
|
||||
|
||||
const plugin: AiPlugin = {
|
||||
name: 'resolver',
|
||||
resolveModel: resolveModelSpy
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
await engine.executeWithPlugins(
|
||||
'generateText',
|
||||
{ model: 'gpt-4', messages: [] },
|
||||
vi.fn().mockResolvedValue({ text: 'test' })
|
||||
)
|
||||
|
||||
expect(resolveModelSpy).toHaveBeenCalledWith('gpt-4', expect.any(Object))
|
||||
})
|
||||
|
||||
it('should use first plugin that resolves model', async () => {
|
||||
const resolver1 = vi.fn().mockResolvedValue(mockLanguageModel)
|
||||
const resolver2 = vi.fn().mockResolvedValue(mockLanguageModel)
|
||||
|
||||
const plugin1: AiPlugin = { name: 'resolver-1', resolveModel: resolver1 }
|
||||
const plugin2: AiPlugin = { name: 'resolver-2', resolveModel: resolver2 }
|
||||
|
||||
engine = new PluginEngine('openai', [plugin1, plugin2])
|
||||
|
||||
await engine.executeWithPlugins(
|
||||
'generateText',
|
||||
{ model: 'gpt-4', messages: [] },
|
||||
vi.fn().mockResolvedValue({ text: 'test' })
|
||||
)
|
||||
|
||||
expect(resolver1).toHaveBeenCalled()
|
||||
expect(resolver2).not.toHaveBeenCalled() // Should stop after first resolver
|
||||
})
|
||||
|
||||
it('should throw ModelResolutionError if no plugin resolves model', async () => {
|
||||
const plugin: AiPlugin = {
|
||||
name: 'no-resolver'
|
||||
// No resolveModel hook
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
await expect(
|
||||
engine.executeWithPlugins(
|
||||
'generateText',
|
||||
{ model: 'unknown-model', messages: [] },
|
||||
vi.fn().mockResolvedValue({ text: 'test' })
|
||||
)
|
||||
).rejects.toThrow(ModelResolutionError)
|
||||
})
|
||||
|
||||
it('should skip resolution for direct model objects', async () => {
|
||||
const resolveModelSpy = vi.fn()
|
||||
|
||||
const plugin: AiPlugin = {
|
||||
name: 'resolver',
|
||||
resolveModel: resolveModelSpy
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
await engine.executeWithPlugins(
|
||||
'generateText',
|
||||
{ model: mockLanguageModel, messages: [] },
|
||||
vi.fn().mockResolvedValue({ text: 'test' })
|
||||
)
|
||||
|
||||
expect(resolveModelSpy).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should throw if resolved model is null/undefined', async () => {
|
||||
const plugin: AiPlugin = {
|
||||
name: 'bad-resolver',
|
||||
resolveModel: vi.fn().mockResolvedValue(null)
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
await expect(
|
||||
engine.executeWithPlugins(
|
||||
'generateText',
|
||||
{ model: 'gpt-4', messages: [] },
|
||||
vi.fn().mockResolvedValue({ text: 'test' })
|
||||
)
|
||||
).rejects.toThrow(ModelResolutionError)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Parameter Transformation', () => {
|
||||
it('should transform parameters through plugin', async () => {
|
||||
const transformParamsSpy = vi.fn().mockImplementation(async (params) => ({
|
||||
...params,
|
||||
temperature: 0.8
|
||||
}))
|
||||
|
||||
const plugin: AiPlugin = {
|
||||
name: 'transformer',
|
||||
resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
|
||||
transformParams: transformParamsSpy
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
const mockExecutor = vi.fn().mockResolvedValue({ text: 'test' })
|
||||
|
||||
await engine.executeWithPlugins('generateText', { model: 'gpt-4', messages: [] }, mockExecutor)
|
||||
|
||||
expect(mockExecutor).toHaveBeenCalledWith(
|
||||
expect.any(Object),
|
||||
expect.objectContaining({
|
||||
temperature: 0.8
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should chain parameter transformations across plugins', async () => {
|
||||
const plugin1: AiPlugin = {
|
||||
name: 'transform-1',
|
||||
transformParams: vi.fn().mockImplementation(async (params) => ({
|
||||
...params,
|
||||
temperature: 0.5
|
||||
}))
|
||||
}
|
||||
|
||||
const plugin2: AiPlugin = {
|
||||
name: 'transform-2',
|
||||
transformParams: vi.fn().mockImplementation(async (params) => ({
|
||||
...params,
|
||||
maxTokens: 100
|
||||
}))
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin1, plugin2])
|
||||
engine.usePlugins([{ name: 'resolver', resolveModel: vi.fn().mockResolvedValue(mockLanguageModel) }])
|
||||
|
||||
const mockExecutor = vi.fn().mockResolvedValue({ text: 'test' })
|
||||
|
||||
await engine.executeWithPlugins('generateText', { model: 'gpt-4', messages: [] }, mockExecutor)
|
||||
|
||||
expect(mockExecutor).toHaveBeenCalledWith(
|
||||
expect.any(Object),
|
||||
expect.objectContaining({
|
||||
temperature: 0.5,
|
||||
maxTokens: 100
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Result Transformation', () => {
|
||||
it('should transform result through plugin', async () => {
|
||||
const transformResultSpy = vi.fn().mockImplementation(async (result) => ({
|
||||
...result,
|
||||
text: `${result.text} [modified]`
|
||||
}))
|
||||
|
||||
const plugin: AiPlugin = {
|
||||
name: 'result-transformer',
|
||||
resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
|
||||
transformResult: transformResultSpy
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
const result = await engine.executeWithPlugins(
|
||||
'generateText',
|
||||
{ model: 'gpt-4', messages: [] },
|
||||
vi.fn().mockResolvedValue({ text: 'original' })
|
||||
)
|
||||
|
||||
expect(result.text).toBe('original [modified]')
|
||||
})
|
||||
|
||||
it('should chain result transformations across plugins', async () => {
|
||||
const plugin1: AiPlugin = {
|
||||
name: 'transform-1',
|
||||
transformResult: vi.fn().mockImplementation(async (result) => ({
|
||||
...result,
|
||||
text: `${result.text} + plugin1`
|
||||
}))
|
||||
}
|
||||
|
||||
const plugin2: AiPlugin = {
|
||||
name: 'transform-2',
|
||||
transformResult: vi.fn().mockImplementation(async (result) => ({
|
||||
...result,
|
||||
text: `${result.text} + plugin2`
|
||||
}))
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin1, plugin2])
|
||||
engine.usePlugins([{ name: 'resolver', resolveModel: vi.fn().mockResolvedValue(mockLanguageModel) }])
|
||||
|
||||
const result = await engine.executeWithPlugins(
|
||||
'generateText',
|
||||
{ model: 'gpt-4', messages: [] },
|
||||
vi.fn().mockResolvedValue({ text: 'base' })
|
||||
)
|
||||
|
||||
expect(result.text).toBe('base + plugin1 + plugin2')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Recursive Calls', () => {
|
||||
it('should support recursive calls through context', async () => {
|
||||
let recursionCount = 0
|
||||
|
||||
const plugin: AiPlugin = {
|
||||
name: 'recursive',
|
||||
resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
|
||||
transformParams: vi.fn().mockImplementation(async (params, context) => {
|
||||
if (recursionCount < 2 && context.recursiveCall) {
|
||||
recursionCount++
|
||||
await context.recursiveCall({ messages: [{ role: 'user', content: 'recursive' }] })
|
||||
}
|
||||
return params
|
||||
})
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
await engine.executeWithPlugins(
|
||||
'generateText',
|
||||
{ model: 'gpt-4', messages: [] },
|
||||
vi.fn().mockResolvedValue({ text: 'test' })
|
||||
)
|
||||
|
||||
expect(recursionCount).toBe(2)
|
||||
})
|
||||
|
||||
it('should track recursion depth', async () => {
|
||||
const depths: number[] = []
|
||||
|
||||
const plugin: AiPlugin = {
|
||||
name: 'depth-tracker',
|
||||
resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
|
||||
transformParams: vi.fn().mockImplementation(async (params, context) => {
|
||||
depths.push(context.recursiveDepth)
|
||||
|
||||
if (context.recursiveDepth < 3 && context.recursiveCall) {
|
||||
await context.recursiveCall({ messages: [] })
|
||||
}
|
||||
|
||||
return params
|
||||
})
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
await engine.executeWithPlugins(
|
||||
'generateText',
|
||||
{ model: 'gpt-4', messages: [] },
|
||||
vi.fn().mockResolvedValue({ text: 'test' })
|
||||
)
|
||||
|
||||
expect(depths).toEqual([0, 1, 2, 3])
|
||||
})
|
||||
|
||||
it('should throw RecursiveDepthError when max depth exceeded', async () => {
|
||||
const plugin: AiPlugin = {
|
||||
name: 'infinite',
|
||||
resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
|
||||
transformParams: vi.fn().mockImplementation(async (params, context) => {
|
||||
if (context.recursiveCall) {
|
||||
await context.recursiveCall({ messages: [] })
|
||||
}
|
||||
return params
|
||||
})
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
await expect(
|
||||
engine.executeWithPlugins(
|
||||
'generateText',
|
||||
{ model: 'gpt-4', messages: [] },
|
||||
vi.fn().mockResolvedValue({ text: 'test' })
|
||||
)
|
||||
).rejects.toThrow(RecursiveDepthError)
|
||||
})
|
||||
|
||||
it('should restore recursion state after recursive call', async () => {
|
||||
const states: Array<{ depth: number; isRecursive: boolean }> = []
|
||||
|
||||
const plugin: AiPlugin = {
|
||||
name: 'state-tracker',
|
||||
resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
|
||||
transformParams: vi.fn().mockImplementation(async (params, context) => {
|
||||
states.push({ depth: context.recursiveDepth, isRecursive: context.isRecursiveCall })
|
||||
|
||||
if (context.recursiveDepth === 0 && context.recursiveCall) {
|
||||
await context.recursiveCall({ messages: [] })
|
||||
states.push({ depth: context.recursiveDepth, isRecursive: context.isRecursiveCall })
|
||||
}
|
||||
|
||||
return params
|
||||
})
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
await engine.executeWithPlugins(
|
||||
'generateText',
|
||||
{ model: 'gpt-4', messages: [] },
|
||||
vi.fn().mockResolvedValue({ text: 'test' })
|
||||
)
|
||||
|
||||
expect(states[0]).toEqual({ depth: 0, isRecursive: false })
|
||||
expect(states[1]).toEqual({ depth: 1, isRecursive: true })
|
||||
expect(states[2]).toEqual({ depth: 0, isRecursive: false })
|
||||
})
|
||||
})
|
||||
|
||||
describe('Image Model Execution', () => {
|
||||
it('should execute image generation with plugins', async () => {
|
||||
const plugin: AiPlugin = {
|
||||
name: 'image-plugin',
|
||||
resolveModel: vi.fn().mockResolvedValue(mockImageModel),
|
||||
transformParams: vi.fn().mockImplementation(async (params) => params)
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
const mockExecutor = vi.fn().mockResolvedValue({
|
||||
image: { base64: 'test', uint8Array: new Uint8Array(), mimeType: 'image/png' }
|
||||
})
|
||||
|
||||
await engine.executeImageWithPlugins('generateImage', { model: 'dall-e-3', prompt: 'test' }, mockExecutor)
|
||||
|
||||
expect(plugin.resolveModel).toHaveBeenCalledWith('dall-e-3', expect.any(Object))
|
||||
expect(mockExecutor).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should skip resolution for direct image model objects', async () => {
|
||||
const resolveModelSpy = vi.fn()
|
||||
|
||||
const plugin: AiPlugin = {
|
||||
name: 'image-resolver',
|
||||
resolveModel: resolveModelSpy
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
await engine.executeImageWithPlugins(
|
||||
'generateImage',
|
||||
{ model: mockImageModel, prompt: 'test' },
|
||||
vi.fn().mockResolvedValue({ image: {} })
|
||||
)
|
||||
|
||||
expect(resolveModelSpy).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Streaming Execution', () => {
|
||||
it('should execute streaming with plugins', async () => {
|
||||
const plugin: AiPlugin = {
|
||||
name: 'stream-plugin',
|
||||
resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
|
||||
transformParams: vi.fn().mockImplementation(async (params) => params)
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
const mockExecutor = vi.fn().mockResolvedValue({
|
||||
textStream: (async function* () {
|
||||
yield 'test'
|
||||
})()
|
||||
})
|
||||
|
||||
await engine.executeStreamWithPlugins('streamText', { model: 'gpt-4', messages: [] }, mockExecutor)
|
||||
|
||||
expect(plugin.resolveModel).toHaveBeenCalled()
|
||||
expect(mockExecutor).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should collect stream transforms from plugins', async () => {
|
||||
const mockTransform = vi.fn()
|
||||
|
||||
const plugin: AiPlugin = {
|
||||
name: 'stream-transformer',
|
||||
resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
|
||||
transformStream: mockTransform
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
const mockExecutor = vi.fn().mockResolvedValue({ textStream: (async function* () {})() })
|
||||
|
||||
await engine.executeStreamWithPlugins('streamText', { model: 'gpt-4', messages: [] }, mockExecutor)
|
||||
|
||||
// Executor should receive stream transforms
|
||||
expect(mockExecutor).toHaveBeenCalledWith(expect.any(Object), expect.any(Object), expect.arrayContaining([]))
|
||||
})
|
||||
})
|
||||
|
||||
describe('Context Management', () => {
|
||||
it('should create context with correct provider and model', async () => {
|
||||
const configureContextSpy = vi.fn()
|
||||
|
||||
const plugin: AiPlugin = {
|
||||
name: 'context-checker',
|
||||
configureContext: configureContextSpy,
|
||||
resolveModel: vi.fn().mockResolvedValue(mockLanguageModel)
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
await engine.executeWithPlugins(
|
||||
'generateText',
|
||||
{ model: 'gpt-4', messages: [] },
|
||||
vi.fn().mockResolvedValue({ text: 'test' })
|
||||
)
|
||||
|
||||
expect(configureContextSpy).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
providerId: 'openai',
|
||||
model: 'gpt-4'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should pass context to all hooks', async () => {
|
||||
const contextRefs: any[] = []
|
||||
|
||||
const plugin: AiPlugin = {
|
||||
name: 'context-tracker',
|
||||
configureContext: vi.fn().mockImplementation(async (context) => {
|
||||
contextRefs.push(context)
|
||||
}),
|
||||
onRequestStart: vi.fn().mockImplementation(async (context) => {
|
||||
contextRefs.push(context)
|
||||
}),
|
||||
resolveModel: vi.fn().mockImplementation(async (_, context) => {
|
||||
contextRefs.push(context)
|
||||
return mockLanguageModel
|
||||
}),
|
||||
transformParams: vi.fn().mockImplementation(async (params, context) => {
|
||||
contextRefs.push(context)
|
||||
return params
|
||||
})
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
await engine.executeWithPlugins(
|
||||
'generateText',
|
||||
{ model: 'gpt-4', messages: [] },
|
||||
vi.fn().mockResolvedValue({ text: 'test' })
|
||||
)
|
||||
|
||||
// All context refs should point to the same object
|
||||
expect(contextRefs.length).toBeGreaterThan(0)
|
||||
const firstContext = contextRefs[0]
|
||||
contextRefs.forEach((context) => {
|
||||
expect(context).toBe(firstContext)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Error Handling', () => {
|
||||
it('should propagate errors from executor', async () => {
|
||||
const plugin: AiPlugin = {
|
||||
name: 'test',
|
||||
resolveModel: vi.fn().mockResolvedValue(mockLanguageModel)
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
const error = new Error('Executor failed')
|
||||
|
||||
await expect(
|
||||
engine.executeWithPlugins('generateText', { model: 'gpt-4', messages: [] }, vi.fn().mockRejectedValue(error))
|
||||
).rejects.toThrow('Executor failed')
|
||||
})
|
||||
|
||||
it('should trigger onError for all plugins on failure', async () => {
|
||||
const onError1 = vi.fn()
|
||||
const onError2 = vi.fn()
|
||||
|
||||
const plugin1: AiPlugin = { name: 'error-1', onError: onError1 }
|
||||
const plugin2: AiPlugin = { name: 'error-2', onError: onError2 }
|
||||
|
||||
engine = new PluginEngine('openai', [plugin1, plugin2])
|
||||
engine.usePlugins([{ name: 'resolver', resolveModel: vi.fn().mockResolvedValue(mockLanguageModel) }])
|
||||
|
||||
const error = new Error('Test failure')
|
||||
|
||||
await expect(
|
||||
engine.executeWithPlugins('generateText', { model: 'gpt-4', messages: [] }, vi.fn().mockRejectedValue(error))
|
||||
).rejects.toThrow()
|
||||
|
||||
expect(onError1).toHaveBeenCalledWith(error, expect.any(Object))
|
||||
expect(onError2).toHaveBeenCalledWith(error, expect.any(Object))
|
||||
})
|
||||
|
||||
it('should handle errors in plugin hooks gracefully', async () => {
|
||||
const plugin: AiPlugin = {
|
||||
name: 'failing-plugin',
|
||||
resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
|
||||
transformParams: vi.fn().mockRejectedValue(new Error('Transform failed'))
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [plugin])
|
||||
|
||||
await expect(
|
||||
engine.executeWithPlugins('generateText', { model: 'gpt-4', messages: [] }, vi.fn())
|
||||
).rejects.toThrow('Transform failed')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Plugin Enforcement', () => {
|
||||
it('should respect plugin enforce ordering (pre, normal, post)', async () => {
|
||||
const executionOrder: string[] = []
|
||||
|
||||
const prePlugin: AiPlugin = {
|
||||
name: 'pre-plugin',
|
||||
enforce: 'pre',
|
||||
onRequestStart: vi.fn(async () => {
|
||||
executionOrder.push('pre')
|
||||
})
|
||||
}
|
||||
|
||||
const normalPlugin: AiPlugin = {
|
||||
name: 'normal-plugin',
|
||||
onRequestStart: vi.fn(async () => {
|
||||
executionOrder.push('normal')
|
||||
})
|
||||
}
|
||||
|
||||
const postPlugin: AiPlugin = {
|
||||
name: 'post-plugin',
|
||||
enforce: 'post',
|
||||
onRequestStart: vi.fn(async () => {
|
||||
executionOrder.push('post')
|
||||
})
|
||||
}
|
||||
|
||||
// Add in reverse order to test sorting
|
||||
engine = new PluginEngine('openai', [postPlugin, normalPlugin, prePlugin])
|
||||
engine.usePlugins([{ name: 'resolver', resolveModel: vi.fn().mockResolvedValue(mockLanguageModel) }])
|
||||
|
||||
await engine.executeWithPlugins(
|
||||
'generateText',
|
||||
{ model: 'gpt-4', messages: [] },
|
||||
vi.fn().mockResolvedValue({ text: 'test' })
|
||||
)
|
||||
|
||||
// Should execute in order: pre -> normal -> post
|
||||
expect(executionOrder).toEqual(['pre', 'normal', 'post'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('Type Safety', () => {
|
||||
it('should properly type plugin parameters', async () => {
|
||||
const transformParamsSpy = vi.fn().mockImplementation(async (params) => {
|
||||
// Type assertions for safety
|
||||
expect(params).toHaveProperty('messages')
|
||||
return params
|
||||
})
|
||||
|
||||
const transformResultSpy = vi.fn().mockImplementation(async (result) => {
|
||||
expect(result).toHaveProperty('text')
|
||||
return result
|
||||
})
|
||||
|
||||
const typedPlugin: AiPlugin = {
|
||||
name: 'typed-plugin',
|
||||
resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
|
||||
transformParams: transformParamsSpy,
|
||||
transformResult: transformResultSpy
|
||||
}
|
||||
|
||||
engine = new PluginEngine('openai', [typedPlugin])
|
||||
|
||||
await engine.executeWithPlugins(
|
||||
'generateText',
|
||||
{ model: 'gpt-4', messages: [] },
|
||||
vi.fn().mockResolvedValue({ text: 'test', finishReason: 'stop' })
|
||||
)
|
||||
|
||||
expect(transformParamsSpy).toHaveBeenCalled()
|
||||
expect(transformResultSpy).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -75,10 +75,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
* 执行带插件的操作(非流式)
|
||||
* 提供给AiExecutor使用
|
||||
*/
|
||||
async executeWithPlugins<
|
||||
TParams extends GenerateTextParams,
|
||||
TResult extends GenerateTextResult
|
||||
>(
|
||||
async executeWithPlugins<TParams extends GenerateTextParams, TResult extends GenerateTextResult>(
|
||||
methodName: string,
|
||||
params: TParams,
|
||||
executor: (model: LanguageModel, transformedParams: TParams) => TResult,
|
||||
@ -101,9 +98,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
const context = _context ?? createContext(this.providerId, model, params)
|
||||
|
||||
// ✅ 创建类型化的 manager(逆变安全)
|
||||
const manager = new PluginManager<TParams, TResult>(
|
||||
this.basePlugins as AiPlugin<TParams, TResult>[]
|
||||
)
|
||||
const manager = new PluginManager<TParams, TResult>(this.basePlugins as AiPlugin<TParams, TResult>[])
|
||||
|
||||
// ✅ 递归调用泛型化,增加深度限制
|
||||
context.recursiveCall = async <R = TResult>(newParams: Partial<TParams>): Promise<R> => {
|
||||
@ -118,12 +113,12 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
context.recursiveDepth = previousDepth + 1
|
||||
context.isRecursiveCall = true
|
||||
|
||||
return await this.executeWithPlugins(
|
||||
return (await this.executeWithPlugins(
|
||||
methodName,
|
||||
{ ...params, ...newParams } as TParams,
|
||||
executor,
|
||||
context
|
||||
) as unknown as R
|
||||
)) as unknown as R
|
||||
} finally {
|
||||
// ✅ finally 确保状态恢复
|
||||
context.recursiveDepth = previousDepth
|
||||
@ -201,9 +196,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
const context = _context ?? createContext(this.providerId, model, params)
|
||||
|
||||
// ✅ 创建类型化的 manager(逆变安全)
|
||||
const manager = new PluginManager<TParams, TResult>(
|
||||
this.basePlugins as AiPlugin<TParams, TResult>[]
|
||||
)
|
||||
const manager = new PluginManager<TParams, TResult>(this.basePlugins as AiPlugin<TParams, TResult>[])
|
||||
|
||||
// ✅ 递归调用泛型化,增加深度限制
|
||||
context.recursiveCall = async <R = TResult>(newParams: Partial<TParams>): Promise<R> => {
|
||||
@ -218,12 +211,12 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
context.recursiveDepth = previousDepth + 1
|
||||
context.isRecursiveCall = true
|
||||
|
||||
return await this.executeImageWithPlugins(
|
||||
return (await this.executeImageWithPlugins(
|
||||
methodName,
|
||||
{ ...params, ...newParams } as TParams,
|
||||
executor,
|
||||
context
|
||||
) as unknown as R
|
||||
)) as unknown as R
|
||||
} finally {
|
||||
// ✅ finally 确保状态恢复
|
||||
context.recursiveDepth = previousDepth
|
||||
@ -275,10 +268,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
* 执行流式调用的通用逻辑(支持流转换器)
|
||||
* 提供给AiExecutor使用
|
||||
*/
|
||||
async executeStreamWithPlugins<
|
||||
TParams extends StreamTextParams,
|
||||
TResult extends StreamTextResult
|
||||
>(
|
||||
async executeStreamWithPlugins<TParams extends StreamTextParams, TResult extends StreamTextResult>(
|
||||
methodName: string,
|
||||
params: TParams,
|
||||
executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => TResult,
|
||||
@ -301,9 +291,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
const context = _context ?? createContext(this.providerId, model, params)
|
||||
|
||||
// ✅ 创建类型化的 manager(逆变安全)
|
||||
const manager = new PluginManager<TParams, TResult>(
|
||||
this.basePlugins as AiPlugin<TParams, TResult>[]
|
||||
)
|
||||
const manager = new PluginManager<TParams, TResult>(this.basePlugins as AiPlugin<TParams, TResult>[])
|
||||
|
||||
// ✅ 递归调用泛型化,增加深度限制
|
||||
context.recursiveCall = async <R = TResult>(newParams: Partial<TParams>): Promise<R> => {
|
||||
@ -318,12 +306,12 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
context.recursiveDepth = previousDepth + 1
|
||||
context.isRecursiveCall = true
|
||||
|
||||
return await this.executeStreamWithPlugins(
|
||||
return (await this.executeStreamWithPlugins(
|
||||
methodName,
|
||||
{ ...params, ...newParams } as TParams,
|
||||
executor,
|
||||
context
|
||||
) as unknown as R
|
||||
)) as unknown as R
|
||||
} finally {
|
||||
// ✅ finally 确保状态恢复
|
||||
context.recursiveDepth = previousDepth
|
||||
|
||||
Loading…
Reference in New Issue
Block a user