Add comprehensive tests for PluginEngine functionality

- Implement tests for plugin registration, management, and lifecycle
- Validate plugin execution order and context management
- Test model resolution, parameter transformation, and result transformation
- Ensure error handling and recursive call support
- Cover streaming execution and image model handling
- Verify type safety for plugin parameters
This commit is contained in:
suyao 2025-12-29 17:11:11 +08:00
parent fca41ed966
commit 3c23c32232
No known key found for this signature in database
16 changed files with 4683 additions and 47 deletions

View File

@ -11,18 +11,20 @@ import { vi } from 'vitest'
*/
export function createMockLanguageModel(overrides?: Partial<LanguageModelV3>): LanguageModelV3 {
return {
specificationVersion: 'v1',
specificationVersion: 'v3',
provider: 'mock-provider',
modelId: 'mock-model',
defaultObjectGenerationMode: 'tool',
supportedUrls: {},
doGenerate: vi.fn().mockResolvedValue({
text: 'Mock response text',
finishReason: 'stop',
usage: {
promptTokens: 10,
completionTokens: 20,
totalTokens: 30
inputTokens: 10,
outputTokens: 20,
totalTokens: 30,
inputTokenDetails: {},
outputTokenDetails: {}
},
rawCall: { rawPrompt: null, rawSettings: {} },
rawResponse: { headers: {} },
@ -47,9 +49,11 @@ export function createMockLanguageModel(overrides?: Partial<LanguageModelV3>): L
type: 'finish',
finishReason: 'stop',
usage: {
promptTokens: 10,
completionTokens: 15,
totalTokens: 25
inputTokens: 10,
outputTokens: 15,
totalTokens: 25,
inputTokenDetails: {},
outputTokenDetails: {}
}
}
})(),
@ -64,12 +68,14 @@ export function createMockLanguageModel(overrides?: Partial<LanguageModelV3>): L
/**
* Creates a mock image model with customizable behavior
* Compliant with AI SDK v3 specification
*/
export function createMockImageModel(overrides?: Partial<ImageModelV3>): ImageModelV3 {
return {
specificationVersion: 'V3',
specificationVersion: 'v3',
provider: 'mock-provider',
modelId: 'mock-image-model',
maxImagesPerCall: undefined,
doGenerate: vi.fn().mockResolvedValue({
images: [

View File

@ -0,0 +1,300 @@
/**
* Model Test Utilities
* Provides comprehensive mock creators for AI SDK v3 models and related test utilities
*/
import type {
EmbeddingModelV3,
ImageModelV3,
LanguageModelV3,
LanguageModelV3Middleware,
ProviderV3
} from '@ai-sdk/provider'
import type { Tool, ToolSet } from 'ai'
import { tool } from 'ai'
import { MockLanguageModelV3 } from 'ai/test'
import { vi } from 'vitest'
import * as z from 'zod'
import { StreamTextParams, StreamTextResult } from '../../core/plugins'
import type { ProviderId } from '../../core/providers/types'
import { AiRequestContext } from '../../types'
/**
 * Type for partial overrides that allows omitting the model field
 * The model will be automatically added by createMockContext
 */
type ContextOverrides = Partial<Omit<AiRequestContext<StreamTextParams, StreamTextResult>, 'originalParams'>> & {
  // originalParams may omit `model`; createMockContext injects a default mock model when absent
  originalParams?: Partial<Omit<StreamTextParams, 'model'>> & { model?: StreamTextParams['model'] }
}
/**
 * Creates a mock AiRequestContext with type safety
 * The model field is automatically added to originalParams if not provided
 *
 * @param overrides - optional partial context; merged over the defaults
 * @returns a fully populated AiRequestContext suitable for plugin tests
 *
 * @example
 * ```ts
 * const context = createMockContext({
 *   providerId: 'openai',
 *   metadata: { requestId: 'test-123' }
 * })
 * ```
 */
export function createMockContext(overrides?: ContextOverrides): AiRequestContext<StreamTextParams, StreamTextResult> {
  const defaultModel = new MockLanguageModelV3({
    provider: 'test-provider',
    modelId: 'test-model'
  })
  // Baseline context; every field can be replaced through `overrides`.
  const defaults: AiRequestContext<StreamTextParams, StreamTextResult> = {
    providerId: 'openai' as ProviderId,
    model: defaultModel,
    originalParams: {
      model: defaultModel,
      messages: [{ role: 'user', content: 'Test message' }]
    } as StreamTextParams,
    metadata: {},
    startTime: Date.now(),
    requestId: 'test-request-id',
    recursiveCall: vi.fn(),
    isRecursiveCall: false,
    recursiveDepth: 0,
    maxRecursiveDepth: 10,
    extensions: new Map()
  }
  if (!overrides) {
    return defaults
  }
  // Merge originalParams separately so the `model` field is never lost:
  // an override without a model falls back to the default mock model.
  const mergedOriginalParams = {
    ...defaults.originalParams,
    ...overrides.originalParams,
    model: overrides.originalParams?.model ?? defaultModel
  } as StreamTextParams
  return {
    ...defaults,
    ...overrides,
    originalParams: mergedOriginalParams
  }
}
/**
 * Creates a mock embedding model with customizable behavior
 * Compliant with AI SDK v3 specification
 *
 * @param overrides - partial EmbeddingModelV3 fields that replace the defaults
 *
 * @example
 * ```ts
 * const embeddingModel = createMockEmbeddingModel({
 *   provider: 'openai',
 *   modelId: 'text-embedding-3-small',
 *   maxEmbeddingsPerCall: 2048
 * })
 * ```
 */
export function createMockEmbeddingModel(overrides?: Partial<EmbeddingModelV3>): EmbeddingModelV3 {
  // Default doEmbed resolves with two fixed 5-dimensional vectors.
  const doEmbed = vi.fn().mockResolvedValue({
    embeddings: [
      [0.1, 0.2, 0.3, 0.4, 0.5],
      [0.6, 0.7, 0.8, 0.9, 1.0]
    ],
    usage: {
      inputTokens: 10,
      totalTokens: 10
    },
    rawResponse: { headers: {} }
  })
  const defaults = {
    specificationVersion: 'v3',
    provider: 'mock-provider',
    modelId: 'mock-embedding-model',
    maxEmbeddingsPerCall: 100,
    supportsParallelCalls: true,
    doEmbed
  }
  // Overrides win over defaults; the cast mirrors the loose typing of test doubles.
  return { ...defaults, ...overrides } as EmbeddingModelV3
}
/**
 * Creates a complete mock ProviderV3 with all model types
 * Useful for testing provider registration and management
 *
 * @param overrides - optional provider name and per-model-type factories
 *
 * @example
 * ```ts
 * const provider = createMockProviderV3({
 *   provider: 'openai',
 *   languageModel: customLanguageModel,
 *   imageModel: customImageModel
 * })
 * ```
 */
export function createMockProviderV3(overrides?: {
  provider?: string
  languageModel?: (modelId: string) => LanguageModelV3
  imageModel?: (modelId: string) => ImageModelV3
  embeddingModel?: (modelId: string) => EmbeddingModelV3
}): ProviderV3 {
  const providerName = overrides?.provider ?? 'mock-provider'
  // Default factories used whenever the caller does not supply its own.
  const defaultLanguageModel = (modelId: string): LanguageModelV3 =>
    ({
      specificationVersion: 'v3',
      provider: providerName,
      modelId,
      defaultObjectGenerationMode: 'tool',
      supportedUrls: {},
      doGenerate: vi.fn(),
      doStream: vi.fn()
    }) as LanguageModelV3
  const defaultImageModel = (modelId: string): ImageModelV3 =>
    ({
      specificationVersion: 'v3',
      provider: providerName,
      modelId,
      maxImagesPerCall: undefined,
      doGenerate: vi.fn()
    }) as ImageModelV3
  const defaultEmbeddingModel = (modelId: string): EmbeddingModelV3 =>
    ({
      specificationVersion: 'v3',
      provider: providerName,
      modelId,
      maxEmbeddingsPerCall: 100,
      supportsParallelCalls: true,
      doEmbed: vi.fn()
    }) as EmbeddingModelV3
  return {
    specificationVersion: 'v3',
    provider: providerName,
    languageModel: overrides?.languageModel ?? defaultLanguageModel,
    imageModel: overrides?.imageModel ?? defaultImageModel,
    embeddingModel: overrides?.embeddingModel ?? defaultEmbeddingModel
  } as ProviderV3
}
/**
 * Creates a mock middleware for testing middleware chains
 * Supports both generate and stream wrapping
 *
 * @param _options - currently unused; reserved for naming middleware in assertions
 *
 * @example
 * ```ts
 * const middleware = createMockMiddleware({
 *   name: 'test-middleware'
 * })
 * ```
 */
export function createMockMiddleware(_options?: { name?: string }): LanguageModelV3Middleware {
  // Identity wrappers: each hook hands back the original executor untouched,
  // while vi.fn records invocations for chain-order assertions.
  const middleware: LanguageModelV3Middleware = {
    specificationVersion: 'v3',
    wrapGenerate: vi.fn((generate) => generate),
    wrapStream: vi.fn((stream) => stream)
  }
  return middleware
}
/**
 * Creates a type-safe function tool for testing using AI SDK's tool() function
 *
 * @param name - tool name, used in the default description
 * @param description - optional description; falls back to "Mock tool: <name>"
 *
 * @example
 * ```ts
 * const weatherTool = createMockTool('getWeather', 'Get current weather')
 * ```
 */
export function createMockTool(name: string, description?: string): Tool<{ value?: string }, string> {
  // Single optional string input keeps the schema permissive for tests.
  const inputSchema = z.object({
    value: z.string().optional()
  })
  return tool({
    // `||` (not `??`) preserved: an empty-string description also falls back.
    description: description || `Mock tool: ${name}`,
    inputSchema,
    execute: vi.fn(async () => 'mock result')
  })
}
/**
 * Creates a provider-defined tool for testing
 *
 * @param name - tool name, used in the default description
 * @param description - optional description; falls back to "Mock provider tool: <name>"
 */
export function createMockProviderTool(name: string, description?: string): { type: 'provider'; description: string } {
  // `||` (not `??`) preserved: an empty-string description also falls back.
  const resolvedDescription = description || `Mock provider tool: ${name}`
  return {
    type: 'provider',
    description: resolvedDescription
  }
}
/**
 * Creates a ToolSet with multiple tools
 *
 * @param tools - map of tool name to tool kind ('function' or 'provider')
 *
 * @example
 * ```ts
 * const tools = createMockToolSet({
 *   getWeather: 'function',
 *   searchDatabase: 'function',
 *   nativeSearch: 'provider'
 * })
 * ```
 */
export function createMockToolSet(tools: Record<string, 'function' | 'provider'>): ToolSet {
  // Fold the declarative spec into a ToolSet, picking the matching mock factory per entry.
  return Object.entries(tools).reduce<ToolSet>((toolSet, [name, kind]) => {
    toolSet[name] = kind === 'function' ? createMockTool(name) : (createMockProviderTool(name) as Tool)
    return toolSet
  }, {})
}
/**
 * Creates mock stream params for testing
 *
 * @param overrides - partial StreamTextParams merged over the defaults
 *
 * @example
 * ```ts
 * const params = createMockStreamParams({
 *   messages: [{ role: 'user', content: 'Custom message' }],
 *   temperature: 0.7
 * })
 * ```
 */
export function createMockStreamParams(overrides?: Partial<StreamTextParams>): StreamTextParams {
  // A single user message is the minimal valid default payload.
  const defaults = {
    messages: [{ role: 'user', content: 'Test message' }]
  }
  return { ...defaults, ...overrides } as StreamTextParams
}
/**
 * Common mock model instances for quick testing
 */
const makeMockModel = (provider: string, modelId: string) => new MockLanguageModelV3({ provider, modelId })

export const mockModels = {
  /** Standard language model for general testing */
  language: makeMockModel('test-provider', 'test-model'),
  /** Mock OpenAI GPT-4 model */
  gpt4: makeMockModel('openai', 'gpt-4'),
  /** Mock Anthropic Claude model */
  claude: makeMockModel('anthropic', 'claude-3-5-sonnet-20241022'),
  /** Mock Google Gemini model */
  gemini: makeMockModel('google', 'gemini-2.0-flash-exp')
} as const

View File

@ -8,5 +8,6 @@ export * from './fixtures/mock-providers'
export * from './fixtures/mock-responses'
// Helpers
export * from './helpers/model-test-utils'
export * from './helpers/provider-test-utils'
export * from './helpers/test-utils'

View File

@ -0,0 +1,454 @@
/**
* ModelResolver Comprehensive Tests
* Tests model resolution logic for language, embedding, and image models
* Covers both traditional and namespaced format resolution
*/
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import {
createMockEmbeddingModel,
createMockImageModel,
createMockLanguageModel,
createMockMiddleware
} from '../../../__tests__'
import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../../providers/RegistryManagement'
import { ModelResolver } from '../ModelResolver'
// Mock the dependencies
// globalRegistryManagement is stubbed so tests can assert the exact registry
// IDs the resolver builds, without touching any real provider registry.
vi.mock('../../providers/RegistryManagement', () => ({
  globalRegistryManagement: {
    languageModel: vi.fn(),
    embeddingModel: vi.fn(),
    imageModel: vi.fn()
  },
  DEFAULT_SEPARATOR: '|'
}))
// wrapModelWithMiddlewares is stubbed to tag the model with `_wrapped`, so
// middleware application can be detected without executing real middleware.
vi.mock('../../middleware/wrapper', () => ({
  wrapModelWithMiddlewares: vi.fn((model: LanguageModelV3) => {
    // Return a wrapped model with a marker
    return {
      ...model,
      _wrapped: true
    } as LanguageModelV3
  })
}))
describe('ModelResolver', () => {
  let resolver: ModelResolver
  let mockLanguageModel: LanguageModelV3
  let mockEmbeddingModel: EmbeddingModelV3
  let mockImageModel: ImageModelV3

  beforeEach(() => {
    // clearAllMocks resets call history only; implementations are re-set below.
    vi.clearAllMocks()
    resolver = new ModelResolver()
    // Create properly typed mock models using global utilities
    mockLanguageModel = createMockLanguageModel({
      provider: 'test-provider',
      modelId: 'test-model'
    })
    mockEmbeddingModel = createMockEmbeddingModel({
      provider: 'test-provider',
      modelId: 'test-embedding'
    })
    mockImageModel = createMockImageModel({
      provider: 'test-provider',
      modelId: 'test-image'
    })
    // Setup default mock implementations
    vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
    vi.mocked(globalRegistryManagement.embeddingModel).mockReturnValue(mockEmbeddingModel)
    vi.mocked(globalRegistryManagement.imageModel).mockReturnValue(mockImageModel)
  })

  describe('resolveLanguageModel', () => {
    describe('Traditional Format Resolution', () => {
      // Traditional format: a bare modelId is combined with providerId
      // as `provider<SEP>model` before the registry lookup.
      it('should resolve traditional format modelId without separator', async () => {
        const result = await resolver.resolveLanguageModel('gpt-4', 'openai')
        expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(`openai${DEFAULT_SEPARATOR}gpt-4`)
        expect(result).toBe(mockLanguageModel)
      })
      it('should resolve with different provider and modelId combinations', async () => {
        const testCases: Array<{ modelId: string; providerId: string; expected: string }> = [
          { modelId: 'claude-3-5-sonnet', providerId: 'anthropic', expected: 'anthropic|claude-3-5-sonnet' },
          { modelId: 'gemini-2.0-flash', providerId: 'google', expected: 'google|gemini-2.0-flash' },
          { modelId: 'grok-2-latest', providerId: 'xai', expected: 'xai|grok-2-latest' },
          { modelId: 'deepseek-chat', providerId: 'deepseek', expected: 'deepseek|deepseek-chat' }
        ]
        for (const testCase of testCases) {
          // clear between iterations so toHaveBeenCalledWith sees only this call
          vi.clearAllMocks()
          await resolver.resolveLanguageModel(testCase.modelId, testCase.providerId)
          expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(testCase.expected)
        }
      })
      it('should handle modelIds with special characters', async () => {
        const modelIds = ['model-v1.0', 'model_v2', 'model.2024', 'model:free']
        for (const modelId of modelIds) {
          vi.clearAllMocks()
          await resolver.resolveLanguageModel(modelId, 'provider')
          expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(`provider${DEFAULT_SEPARATOR}${modelId}`)
        }
      })
    })

    describe('Namespaced Format Resolution', () => {
      // Namespaced format: modelId already contains the separator and is
      // passed through to the registry unchanged (providerId is ignored).
      it('should resolve namespaced format with hub', async () => {
        const namespacedId = `aihubmix${DEFAULT_SEPARATOR}anthropic${DEFAULT_SEPARATOR}claude-3-5-sonnet`
        const result = await resolver.resolveLanguageModel(namespacedId, 'openai')
        expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(namespacedId)
        expect(result).toBe(mockLanguageModel)
      })
      it('should resolve simple namespaced format', async () => {
        const namespacedId = `provider${DEFAULT_SEPARATOR}model-id`
        await resolver.resolveLanguageModel(namespacedId, 'fallback-provider')
        expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(namespacedId)
      })
      it('should handle complex namespaced IDs', async () => {
        const complexIds = [
          `hub${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}model`,
          `hub${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}model-v1.0`,
          `custom${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}gpt-4-turbo`
        ]
        for (const id of complexIds) {
          vi.clearAllMocks()
          await resolver.resolveLanguageModel(id, 'fallback')
          expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(id)
        }
      })
    })

    describe('OpenAI Mode Selection', () => {
      // OpenAI/Azure 'chat' mode selects the `<provider>-chat` registry alias;
      // other providers and other modes use the plain provider id.
      it('should append "-chat" suffix for OpenAI provider with chat mode', async () => {
        await resolver.resolveLanguageModel('gpt-4', 'openai', { mode: 'chat' })
        expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai-chat|gpt-4')
      })
      it('should append "-chat" suffix for Azure provider with chat mode', async () => {
        await resolver.resolveLanguageModel('gpt-4', 'azure', { mode: 'chat' })
        expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('azure-chat|gpt-4')
      })
      it('should not append suffix for OpenAI with responses mode', async () => {
        await resolver.resolveLanguageModel('gpt-4', 'openai', { mode: 'responses' })
        expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|gpt-4')
      })
      it('should not append suffix for OpenAI without mode', async () => {
        await resolver.resolveLanguageModel('gpt-4', 'openai')
        expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|gpt-4')
      })
      it('should not append suffix for other providers with chat mode', async () => {
        await resolver.resolveLanguageModel('claude-3', 'anthropic', { mode: 'chat' })
        expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('anthropic|claude-3')
      })
      it('should handle namespaced IDs with OpenAI chat mode', async () => {
        const namespacedId = `hub${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}gpt-4`
        await resolver.resolveLanguageModel(namespacedId, 'openai', { mode: 'chat' })
        // Should use the namespaced ID directly, not apply mode logic
        expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(namespacedId)
      })
    })

    describe('Middleware Application', () => {
      // The wrapper module is mocked to tag wrapped models with `_wrapped: true`
      // (see vi.mock above), so these tests only assert presence of the marker.
      it('should apply middlewares to resolved model', async () => {
        const mockMiddleware = createMockMiddleware({ name: 'test-middleware' })
        const result = await resolver.resolveLanguageModel('gpt-4', 'openai', undefined, [mockMiddleware])
        expect(result).toHaveProperty('_wrapped', true)
      })
      it('should apply multiple middlewares in order', async () => {
        const middleware1 = createMockMiddleware({ name: 'middleware-1' })
        const middleware2 = createMockMiddleware({ name: 'middleware-2' })
        const result = await resolver.resolveLanguageModel('gpt-4', 'openai', undefined, [middleware1, middleware2])
        expect(result).toHaveProperty('_wrapped', true)
      })
      it('should not apply middlewares when none provided', async () => {
        const result = await resolver.resolveLanguageModel('gpt-4', 'openai')
        expect(result).not.toHaveProperty('_wrapped')
        expect(result).toBe(mockLanguageModel)
      })
      it('should not apply middlewares when empty array provided', async () => {
        const result = await resolver.resolveLanguageModel('gpt-4', 'openai', undefined, [])
        expect(result).not.toHaveProperty('_wrapped')
      })
    })

    describe('Provider Options Handling', () => {
      it('should pass provider options correctly', async () => {
        const options = { baseURL: 'https://api.example.com', apiKey: 'test-key' }
        await resolver.resolveLanguageModel('gpt-4', 'openai', options)
        // Provider options are used for mode selection logic
        expect(globalRegistryManagement.languageModel).toHaveBeenCalled()
      })
      it('should handle empty provider options', async () => {
        await resolver.resolveLanguageModel('gpt-4', 'openai', {})
        expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|gpt-4')
      })
      it('should handle undefined provider options', async () => {
        await resolver.resolveLanguageModel('gpt-4', 'openai', undefined)
        expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|gpt-4')
      })
    })
  })

  describe('resolveTextEmbeddingModel', () => {
    describe('Traditional Format', () => {
      it('should resolve traditional embedding model ID', async () => {
        const result = await resolver.resolveTextEmbeddingModel('text-embedding-ada-002', 'openai')
        expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith('openai|text-embedding-ada-002')
        expect(result).toBe(mockEmbeddingModel)
      })
      it('should resolve different embedding models', async () => {
        const testCases = [
          { modelId: 'text-embedding-3-small', providerId: 'openai' },
          { modelId: 'text-embedding-3-large', providerId: 'openai' },
          { modelId: 'embed-english-v3.0', providerId: 'cohere' },
          { modelId: 'voyage-2', providerId: 'voyage' }
        ]
        for (const { modelId, providerId } of testCases) {
          vi.clearAllMocks()
          await resolver.resolveTextEmbeddingModel(modelId, providerId)
          expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith(`${providerId}|${modelId}`)
        }
      })
    })

    describe('Namespaced Format', () => {
      it('should resolve namespaced embedding model ID', async () => {
        const namespacedId = `aihubmix${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}text-embedding-3-small`
        const result = await resolver.resolveTextEmbeddingModel(namespacedId, 'openai')
        expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith(namespacedId)
        expect(result).toBe(mockEmbeddingModel)
      })
      it('should handle complex namespaced embedding IDs', async () => {
        const complexIds = [
          `hub${DEFAULT_SEPARATOR}cohere${DEFAULT_SEPARATOR}embed-multilingual`,
          `custom${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}embedding-model`
        ]
        for (const id of complexIds) {
          vi.clearAllMocks()
          await resolver.resolveTextEmbeddingModel(id, 'fallback')
          expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith(id)
        }
      })
    })
  })

  describe('resolveImageModel', () => {
    describe('Traditional Format', () => {
      it('should resolve traditional image model ID', async () => {
        const result = await resolver.resolveImageModel('dall-e-3', 'openai')
        expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('openai|dall-e-3')
        expect(result).toBe(mockImageModel)
      })
      it('should resolve different image models', async () => {
        const testCases = [
          { modelId: 'dall-e-2', providerId: 'openai' },
          { modelId: 'stable-diffusion-xl', providerId: 'stability' },
          { modelId: 'imagen-2', providerId: 'google' },
          { modelId: 'midjourney-v6', providerId: 'midjourney' }
        ]
        for (const { modelId, providerId } of testCases) {
          vi.clearAllMocks()
          await resolver.resolveImageModel(modelId, providerId)
          expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith(`${providerId}|${modelId}`)
        }
      })
    })

    describe('Namespaced Format', () => {
      it('should resolve namespaced image model ID', async () => {
        const namespacedId = `aihubmix${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}dall-e-3`
        const result = await resolver.resolveImageModel(namespacedId, 'openai')
        expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith(namespacedId)
        expect(result).toBe(mockImageModel)
      })
      it('should handle complex namespaced image IDs', async () => {
        const complexIds = [
          `hub${DEFAULT_SEPARATOR}stability${DEFAULT_SEPARATOR}sdxl-turbo`,
          `custom${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}image-gen-v2`
        ]
        for (const id of complexIds) {
          vi.clearAllMocks()
          await resolver.resolveImageModel(id, 'fallback')
          expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith(id)
        }
      })
    })
  })

  describe('Edge Cases and Error Scenarios', () => {
    it('should handle empty model IDs', async () => {
      await resolver.resolveLanguageModel('', 'openai')
      expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|')
    })
    it('should handle model IDs with multiple separators', async () => {
      const multiSeparatorId = `hub${DEFAULT_SEPARATOR}sub${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}model`
      await resolver.resolveLanguageModel(multiSeparatorId, 'fallback')
      expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(multiSeparatorId)
    })
    it('should handle model IDs with only separator', async () => {
      const onlySeparator = DEFAULT_SEPARATOR
      await resolver.resolveLanguageModel(onlySeparator, 'provider')
      expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(onlySeparator)
    })
    it('should throw if globalRegistryManagement throws', async () => {
      // mockImplementation persists only until beforeEach re-applies
      // mockReturnValue for the next test.
      const error = new Error('Model not found in registry')
      vi.mocked(globalRegistryManagement.languageModel).mockImplementation(() => {
        throw error
      })
      await expect(resolver.resolveLanguageModel('invalid-model', 'openai')).rejects.toThrow(
        'Model not found in registry'
      )
    })
    it('should handle concurrent resolution requests', async () => {
      const promises = [
        resolver.resolveLanguageModel('gpt-4', 'openai'),
        resolver.resolveLanguageModel('claude-3', 'anthropic'),
        resolver.resolveLanguageModel('gemini-2.0', 'google')
      ]
      const results = await Promise.all(promises)
      expect(results).toHaveLength(3)
      expect(globalRegistryManagement.languageModel).toHaveBeenCalledTimes(3)
    })
  })

  describe('Type Safety', () => {
    it('should return properly typed LanguageModelV3', async () => {
      const result = await resolver.resolveLanguageModel('gpt-4', 'openai')
      // Type assertions
      expect(result.specificationVersion).toBe('v3')
      expect(result).toHaveProperty('doGenerate')
      expect(result).toHaveProperty('doStream')
    })
    it('should return properly typed EmbeddingModelV3', async () => {
      const result = await resolver.resolveTextEmbeddingModel('text-embedding-ada-002', 'openai')
      expect(result.specificationVersion).toBe('v3')
      expect(result).toHaveProperty('doEmbed')
    })
    it('should return properly typed ImageModelV3', async () => {
      const result = await resolver.resolveImageModel('dall-e-3', 'openai')
      expect(result.specificationVersion).toBe('v3')
      expect(result).toHaveProperty('doGenerate')
    })
  })

  describe('Global ModelResolver Instance', () => {
    it('should have a global instance available', async () => {
      // dynamic import so the module is loaded after vi.mock is in effect
      const { globalModelResolver } = await import('../ModelResolver')
      expect(globalModelResolver).toBeInstanceOf(ModelResolver)
    })
  })

  describe('Integration with Different Provider Types', () => {
    it('should work with OpenAI compatible providers', async () => {
      await resolver.resolveLanguageModel('custom-model', 'openai-compatible')
      expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai-compatible|custom-model')
    })
    it('should work with hub providers', async () => {
      const hubId = `aihubmix${DEFAULT_SEPARATOR}custom${DEFAULT_SEPARATOR}model-v1`
      await resolver.resolveLanguageModel(hubId, 'aihubmix')
      expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(hubId)
    })
    it('should handle all model types for same provider', async () => {
      const providerId = 'openai'
      const languageModel = 'gpt-4'
      const embeddingModel = 'text-embedding-3-small'
      const imageModel = 'dall-e-3'
      await resolver.resolveLanguageModel(languageModel, providerId)
      await resolver.resolveTextEmbeddingModel(embeddingModel, providerId)
      await resolver.resolveImageModel(imageModel, providerId)
      expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(`${providerId}|${languageModel}`)
      expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith(`${providerId}|${embeddingModel}`)
      expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith(`${providerId}|${imageModel}`)
    })
  })
})

View File

@ -0,0 +1,171 @@
import type { LanguageModelV2, LanguageModelV3 } from '@ai-sdk/provider'
import { describe, expect, it } from 'vitest'
import type { AiSdkModel } from '../../providers'
import { hasModelId, isV2Model, isV3Model } from '../utils'
describe('Model Type Guards', () => {
  describe('isV2Model', () => {
    it('should return true for V2 models', () => {
      const v2Model: AiSdkModel = {
        specificationVersion: 'v2',
        modelId: 'test-model',
        provider: 'test-provider'
      } as LanguageModelV2
      expect(isV2Model(v2Model)).toBe(true)
    })
    it('should return false for V3 models', () => {
      const v3Model: AiSdkModel = {
        specificationVersion: 'v3',
        modelId: 'test-model',
        provider: 'test-provider'
      } as LanguageModelV3
      expect(isV2Model(v3Model)).toBe(false)
    })
    it('should return false for non-object values', () => {
      // `as any` is deliberate: the guard must be robust to untyped input
      expect(isV2Model('model-id' as any)).toBe(false)
      expect(isV2Model(null as any)).toBe(false)
      expect(isV2Model(undefined as any)).toBe(false)
      expect(isV2Model(123 as any)).toBe(false)
    })
    it('should return false for objects without specificationVersion', () => {
      const invalidModel = {
        modelId: 'test-model',
        provider: 'test-provider'
      } as any
      expect(isV2Model(invalidModel)).toBe(false)
    })
  })

  describe('isV3Model', () => {
    it('should return true for V3 models', () => {
      const v3Model: AiSdkModel = {
        specificationVersion: 'v3',
        modelId: 'test-model',
        provider: 'test-provider'
      } as LanguageModelV3
      expect(isV3Model(v3Model)).toBe(true)
    })
    it('should return false for V2 models', () => {
      const v2Model: AiSdkModel = {
        specificationVersion: 'v2',
        modelId: 'test-model',
        provider: 'test-provider'
      } as LanguageModelV2
      expect(isV3Model(v2Model)).toBe(false)
    })
    it('should return false for non-object values', () => {
      expect(isV3Model('model-id' as any)).toBe(false)
      expect(isV3Model(null as any)).toBe(false)
      expect(isV3Model(undefined as any)).toBe(false)
    })
    it('should return false for objects without specificationVersion', () => {
      const invalidModel = {
        modelId: 'test-model',
        provider: 'test-provider'
      } as any
      expect(isV3Model(invalidModel)).toBe(false)
    })
  })

  describe('Type Guard Correctness', () => {
    // The two guards must be mutually exclusive for any given model.
    it('should correctly distinguish between V2 and V3 models', () => {
      const v2Model: AiSdkModel = {
        specificationVersion: 'v2',
        modelId: 'v2-model'
      } as LanguageModelV2
      const v3Model: AiSdkModel = {
        specificationVersion: 'v3',
        modelId: 'v3-model'
      } as LanguageModelV3
      // V2 model should only match isV2Model
      expect(isV2Model(v2Model)).toBe(true)
      expect(isV3Model(v2Model)).toBe(false)
      // V3 model should only match isV3Model
      expect(isV2Model(v3Model)).toBe(false)
      expect(isV3Model(v3Model)).toBe(true)
    })
    it('should narrow type correctly for V2 models', () => {
      const model: AiSdkModel = {
        specificationVersion: 'v2',
        modelId: 'test'
      } as LanguageModelV2
      // inside the guard, TypeScript narrows `model` to LanguageModelV2
      if (isV2Model(model)) {
        expect(model.specificationVersion).toBe('v2')
      }
    })
    it('should narrow type correctly for V3 models', () => {
      const model: AiSdkModel = {
        specificationVersion: 'v3',
        modelId: 'test'
      } as LanguageModelV3
      if (isV3Model(model)) {
        expect(model.specificationVersion).toBe('v3')
      }
    })
  })

  describe('hasModelId', () => {
    it('should return true for objects with modelId string property', () => {
      const modelWithId = {
        modelId: 'test-model-id',
        other: 'property'
      }
      expect(hasModelId(modelWithId)).toBe(true)
    })
    it('should return false for objects without modelId property', () => {
      const modelWithoutId = {
        other: 'property'
      }
      expect(hasModelId(modelWithoutId)).toBe(false)
    })
    it('should return false for objects with non-string modelId', () => {
      const modelWithNumericId = {
        modelId: 123
      }
      expect(hasModelId(modelWithNumericId)).toBe(false)
    })
    it('should return false for non-object values', () => {
      expect(hasModelId(null)).toBe(false)
      expect(hasModelId(undefined)).toBe(false)
      expect(hasModelId('string')).toBe(false)
      expect(hasModelId(123)).toBe(false)
      expect(hasModelId(true)).toBe(false)
    })
    it('should narrow type correctly', () => {
      const unknownValue: unknown = { modelId: 'test-id' }
      if (hasModelId(unknownValue)) {
        // TypeScript should allow accessing modelId as string
        expect(typeof unknownValue.modelId).toBe('string')
        expect(unknownValue.modelId).toBe('test-id')
      }
    })
  })
})

View File

@ -4,6 +4,7 @@
* AI SDK
* promptToolUsePlugin.ts
*/
import type { SharedV3ProviderMetadata } from '@ai-sdk/provider'
import type { EmbeddingModelUsage, ImageModelUsage, LanguageModelUsage, ModelMessage } from 'ai'
import type { AiSdkUsage } from '../../../providers/types'
@ -79,7 +80,11 @@ export class StreamEventManager {
*/
sendStepFinishEvent(
controller: StreamController,
chunk: any,
chunk: {
usage?: Partial<AiSdkUsage>
response?: { id: string; [key: string]: unknown }
providerMetadata?: SharedV3ProviderMetadata
},
context: AiRequestContext,
finishReason: string = 'stop'
): void {
@ -154,7 +159,7 @@ export class StreamEventManager {
context: AiRequestContext<TParams, StreamTextResult>,
textBuffer: string,
toolResultsText: string,
tools: any
tools: Record<string, unknown>
): Partial<TParams> {
const params = context.originalParams

View File

@ -14,15 +14,16 @@ import type { ToolUseResult } from './type'
export interface ExecutedResult {
toolCallId: string
toolName: string
result: any
result: unknown
isError?: boolean
}
/**
* AI SDK
* Generic type parameter allows for type-safe chunk enqueuing
*/
export interface StreamController {
enqueue(chunk: any): void
export interface StreamController<TChunk = unknown> {
enqueue(chunk: TChunk): void
}
/**

View File

@ -0,0 +1,566 @@
import type { SharedV3ProviderMetadata } from '@ai-sdk/provider'
import type {
EmbeddingModelUsage,
ImageModelUsage,
LanguageModelUsage as AiSdkUsage,
LanguageModelUsage,
TextStreamPart,
ToolSet
} from 'ai'
import { simulateReadableStream } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createMockContext, createMockTool } from '../../../../../__tests__'
import { StreamEventManager } from '../StreamEventManager'
import type { StreamController } from '../ToolExecutor'
/**
 * Type alias for empty toolset (no tools)
 * Using Record<string, never> ensures type safety for tests without tools
 */
type EmptyToolSet = Record<string, never>
/**
 * Mock StreamController for testing
 * Provides type-safe enqueue function that accepts TextStreamPart chunks
 */
interface MockStreamController<TOOLS extends ToolSet = EmptyToolSet> extends StreamController {
  // vi.fn mock so tests can inspect the TextStreamPart chunks that were enqueued
  enqueue: ReturnType<typeof vi.fn<(chunk: TextStreamPart<TOOLS>) => void>>
}
/**
 * Create a type-safe mock stream controller
 *
 * @returns a controller whose enqueue is a vi.fn, recording every chunk pushed
 */
function createMockStreamController<TOOLS extends ToolSet = EmptyToolSet>(): MockStreamController<TOOLS> {
  const controller: MockStreamController<TOOLS> = {
    enqueue: vi.fn()
  }
  return controller
}
/**
 * Type for chunk data in finish-step events
 * Mirrors the chunk shape accepted by StreamEventManager.sendStepFinishEvent
 * (usage, response metadata, and provider metadata are all optional).
 */
interface FinishStepChunk {
  usage?: Partial<AiSdkUsage>
  response?: { id: string; [key: string]: unknown }
  providerMetadata?: SharedV3ProviderMetadata
}
describe('StreamEventManager', () => {
let manager: StreamEventManager
beforeEach(() => {
manager = new StreamEventManager()
})
// accumulateUsage mutates `target` in place, adding each counter from
// `source`. It supports the three usage shapes (language / image /
// embedding) and warns instead of throwing on a shape mismatch.
describe('accumulateUsage', () => {
  describe('LanguageModelUsage', () => {
    it('should accumulate language model usage correctly', () => {
      const target: Partial<LanguageModelUsage> = {
        inputTokens: 10,
        outputTokens: 20,
        totalTokens: 30
      }
      const source: Partial<LanguageModelUsage> = {
        inputTokens: 5,
        outputTokens: 10,
        totalTokens: 15
      }
      manager.accumulateUsage(target, source)
      // Each field is the sum of target + source
      expect(target.inputTokens).toBe(15)
      expect(target.outputTokens).toBe(30)
      expect(target.totalTokens).toBe(45)
    })
    it('should handle undefined values in target', () => {
      const target: Partial<LanguageModelUsage> = { inputTokens: 10 }
      const source: Partial<LanguageModelUsage> = {
        inputTokens: 5,
        outputTokens: 10,
        totalTokens: 15
      }
      manager.accumulateUsage(target, source)
      // Fields absent on target are taken over from source as-is
      expect(target.inputTokens).toBe(15)
      expect(target.outputTokens).toBe(10)
      expect(target.totalTokens).toBe(15)
    })
    it('should handle undefined values in source', () => {
      const target: Partial<LanguageModelUsage> = {
        inputTokens: 10,
        outputTokens: 20,
        totalTokens: 30
      }
      const source: Partial<LanguageModelUsage> = { inputTokens: 5 }
      manager.accumulateUsage(target, source)
      // Fields absent on source leave target values untouched
      expect(target.inputTokens).toBe(15)
      expect(target.outputTokens).toBe(20)
      expect(target.totalTokens).toBe(30)
    })
    it('should handle zero values correctly', () => {
      const target: Partial<LanguageModelUsage> = {
        inputTokens: 0,
        outputTokens: 0,
        totalTokens: 0
      }
      const source: Partial<LanguageModelUsage> = {
        inputTokens: 5,
        outputTokens: 10,
        totalTokens: 15
      }
      manager.accumulateUsage(target, source)
      // Zero is a valid value and must not be treated as "missing"
      expect(target.inputTokens).toBe(5)
      expect(target.outputTokens).toBe(10)
      expect(target.totalTokens).toBe(15)
    })
  })
  describe('ImageModelUsage', () => {
    it('should accumulate image model usage correctly', () => {
      const target: Partial<ImageModelUsage> = {
        inputTokens: 100,
        outputTokens: 50,
        totalTokens: 150
      }
      const source: Partial<ImageModelUsage> = {
        inputTokens: 50,
        outputTokens: 25,
        totalTokens: 75
      }
      manager.accumulateUsage(target, source)
      expect(target.inputTokens).toBe(150)
      expect(target.outputTokens).toBe(75)
      expect(target.totalTokens).toBe(225)
    })
    it('should handle undefined values', () => {
      const target: Partial<ImageModelUsage> = { inputTokens: 100 }
      const source: Partial<ImageModelUsage> = {
        outputTokens: 50,
        totalTokens: 50
      }
      manager.accumulateUsage(target, source)
      expect(target.inputTokens).toBe(100)
      expect(target.outputTokens).toBe(50)
      expect(target.totalTokens).toBe(50)
    })
  })
  describe('EmbeddingModelUsage', () => {
    it('should accumulate embedding model usage correctly', () => {
      const target: Partial<EmbeddingModelUsage> = { tokens: 100 }
      const source: Partial<EmbeddingModelUsage> = { tokens: 50 }
      manager.accumulateUsage(target, source)
      expect(target.tokens).toBe(150)
    })
    it('should handle zero to non-zero accumulation', () => {
      const target: Partial<EmbeddingModelUsage> = { tokens: 0 }
      const source: Partial<EmbeddingModelUsage> = { tokens: 50 }
      manager.accumulateUsage(target, source)
      expect(target.tokens).toBe(50)
    })
    it('should handle zero values', () => {
      const target: Partial<EmbeddingModelUsage> = { tokens: 0 }
      const source: Partial<EmbeddingModelUsage> = { tokens: 100 }
      manager.accumulateUsage(target, source)
      expect(target.tokens).toBe(100)
    })
  })
  describe('Type Guard Validation', () => {
    it('should warn on type mismatch between LanguageModelUsage and EmbeddingModelUsage', () => {
      // mockImplementation keeps the expected warning out of the test output
      // while the spy still records the call for the assertion below
      const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {})
      const target: Partial<LanguageModelUsage> = { inputTokens: 10 }
      const source: Partial<EmbeddingModelUsage> = { tokens: 5 }
      manager.accumulateUsage(target, source)
      expect(warnSpy).toHaveBeenCalledWith(
        expect.stringContaining('Unable to accumulate usage'),
        expect.objectContaining({
          target,
          source
        })
      )
      warnSpy.mockRestore()
    })
    it('should warn on type mismatch between ImageModelUsage and EmbeddingModelUsage', () => {
      // Silenced for the same reason as above
      const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {})
      const target: Partial<ImageModelUsage> = { inputTokens: 100 }
      const source: Partial<EmbeddingModelUsage> = { tokens: 50 }
      manager.accumulateUsage(target, source)
      expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining('Unable to accumulate usage'), expect.any(Object))
      warnSpy.mockRestore()
    })
  })
})
// buildRecursiveParams assembles the parameters for the follow-up model call
// after tool execution: the original messages, an optional assistant message
// holding buffered text, and a trailing user message with the tool results.
describe('buildRecursiveParams', () => {
  it('should include textBuffer in assistant message when not empty', () => {
    const context = createMockContext()
    const textBuffer = 'test response'
    const toolResultsText = '<tool_use_result>...</tool_use_result>'
    const tools = {
      test_tool: createMockTool('test_tool')
    }
    const params = manager.buildRecursiveParams(context, textBuffer, toolResultsText, tools)
    // original user message + assistant(textBuffer) + user(tool results)
    expect(params.messages).toHaveLength(3)
    expect(params.messages?.[0]).toEqual({ role: 'user', content: 'Test message' })
    expect(params.messages?.[1]).toEqual({ role: 'assistant', content: textBuffer })
    expect(params.messages?.[2]).toEqual({
      role: 'user',
      content: toolResultsText
    })
    expect(params.tools).toBe(tools)
  })
  it('should skip empty textBuffer in messages', () => {
    const context = createMockContext()
    const textBuffer = ''
    const toolResultsText = '<tool_use_result>...</tool_use_result>'
    const tools = {}
    const params = manager.buildRecursiveParams(context, textBuffer, toolResultsText, tools)
    // Should only have original user message and new user message with tool results
    expect(params.messages).toHaveLength(2)
    expect(params.messages?.[0]).toEqual({ role: 'user', content: 'Test message' })
    expect(params.messages?.[1]).toEqual({
      role: 'user',
      content: toolResultsText
    })
    // An empty buffer must not produce an empty assistant message
    const assistantMessages = params.messages?.filter((m) => m.role === 'assistant')
    expect(assistantMessages).toHaveLength(0)
  })
  it('should preserve all original messages', () => {
    const context = createMockContext({
      originalParams: {
        messages: [
          { role: 'user', content: 'First message' },
          { role: 'assistant', content: 'First response' },
          { role: 'user', content: 'Second message' }
        ]
      }
    })
    const params = manager.buildRecursiveParams(context, 'New response', 'Tool results', {})
    // Prior conversation is kept in order, then the two new messages appended
    expect(params.messages).toHaveLength(5)
    expect(params.messages?.[0]).toEqual({ role: 'user', content: 'First message' })
    expect(params.messages?.[1]).toEqual({
      role: 'assistant',
      content: 'First response'
    })
    expect(params.messages?.[2]).toEqual({ role: 'user', content: 'Second message' })
    expect(params.messages?.[3]).toEqual({ role: 'assistant', content: 'New response' })
    expect(params.messages?.[4]).toEqual({ role: 'user', content: 'Tool results' })
  })
  it('should pass through tools parameter', () => {
    const context = createMockContext()
    const tools = {
      tool1: createMockTool('tool1'),
      tool2: createMockTool('tool2')
    }
    const params = manager.buildRecursiveParams(context, 'response', 'results', tools)
    // Same reference, not a copy
    expect(params.tools).toBe(tools)
    expect(Object.keys(params.tools!)).toHaveLength(2)
  })
})
// sendStepStartEvent should emit exactly one 'start-step' chunk with an
// empty request and no warnings.
describe('sendStepStartEvent', () => {
  it('should enqueue start-step event with correct structure', () => {
    const streamController = createMockStreamController()
    manager.sendStepStartEvent(streamController)
    const expectedEvent = { type: 'start-step', request: {}, warnings: [] }
    expect(streamController.enqueue).toHaveBeenCalledWith(expectedEvent)
  })
})
// sendStepFinishEvent emits a 'finish-step' chunk and folds the step's usage
// into context.accumulatedUsage.
describe('sendStepFinishEvent', () => {
  it('should enqueue finish-step event with provided finishReason', () => {
    const controller = createMockStreamController()
    const chunk: FinishStepChunk = {
      usage: {
        inputTokens: 10,
        outputTokens: 20,
        totalTokens: 30
      },
      response: { id: 'test-response' },
      providerMetadata: { 'test-provider': {} }
    }
    const context = createMockContext({
      accumulatedUsage: {
        inputTokens: 0,
        outputTokens: 0,
        totalTokens: 0
      }
    })
    manager.sendStepFinishEvent(controller, chunk, context, 'tool-calls')
    // response/usage/providerMetadata are forwarded from the chunk unchanged
    expect(controller.enqueue).toHaveBeenCalledWith({
      type: 'finish-step',
      finishReason: 'tool-calls',
      response: chunk.response,
      usage: chunk.usage,
      providerMetadata: chunk.providerMetadata
    })
  })
  it('should accumulate usage when provided', () => {
    const controller = createMockStreamController()
    const chunk: FinishStepChunk = {
      usage: {
        inputTokens: 10,
        outputTokens: 20,
        totalTokens: 30
      }
    }
    const context = createMockContext({
      accumulatedUsage: {
        inputTokens: 5,
        outputTokens: 10,
        totalTokens: 15
      }
    })
    manager.sendStepFinishEvent(controller, chunk, context)
    // Verify accumulation happened
    expect(context.accumulatedUsage.inputTokens).toBe(15)
    expect(context.accumulatedUsage.outputTokens).toBe(30)
    expect(context.accumulatedUsage.totalTokens).toBe(45)
  })
  it('should handle missing usage gracefully', () => {
    const controller = createMockStreamController()
    const chunk: FinishStepChunk = {}
    const context = createMockContext({
      accumulatedUsage: {
        inputTokens: 5,
        outputTokens: 10,
        totalTokens: 15
      }
    })
    // A chunk without usage must be tolerated, not throw
    expect(() => manager.sendStepFinishEvent(controller, chunk, context)).not.toThrow()
    // Verify accumulation did not change
    expect(context.accumulatedUsage.inputTokens).toBe(5)
    expect(context.accumulatedUsage.outputTokens).toBe(10)
    expect(context.accumulatedUsage.totalTokens).toBe(15)
  })
  it('should use default finishReason of "stop" when not provided', () => {
    const controller = createMockStreamController()
    const chunk: FinishStepChunk = {}
    const context = createMockContext()
    manager.sendStepFinishEvent(controller, chunk, context)
    expect(controller.enqueue).toHaveBeenCalledWith(
      expect.objectContaining({
        finishReason: 'stop'
      })
    )
  })
})
// handleRecursiveCall re-invokes the model via context.recursiveCall and
// pipes the resulting fullStream back into the outer controller, dropping
// the inner 'start' chunk and stopping at the inner 'finish'.
describe('handleRecursiveCall', () => {
  it('should reset hasExecutedToolsInCurrentStep flag', async () => {
    const controller = createMockStreamController()
    const mockStream = simulateReadableStream<TextStreamPart<EmptyToolSet>>({
      chunks: [
        {
          type: 'text-delta',
          id: 'test-id',
          text: 'test'
        } as TextStreamPart<EmptyToolSet>
      ],
      initialDelayInMs: 0,
      chunkDelayInMs: 0
    })
    const context = createMockContext({
      hasExecutedToolsInCurrentStep: true,
      recursiveCall: vi.fn().mockResolvedValue({
        fullStream: mockStream
      })
    })
    const params = { messages: [] }
    await manager.handleRecursiveCall(controller, params, context)
    // Flag is cleared before the next step; params are forwarded untouched
    expect(context.hasExecutedToolsInCurrentStep).toBe(false)
    expect(context.recursiveCall).toHaveBeenCalledWith(params)
  })
  it('should pipe recursive stream to controller', async () => {
    const enqueuedChunks: TextStreamPart<EmptyToolSet>[] = []
    const controller = createMockStreamController()
    controller.enqueue.mockImplementation((chunk: TextStreamPart<EmptyToolSet>) => {
      enqueuedChunks.push(chunk)
    })
    const mockChunks: TextStreamPart<EmptyToolSet>[] = [
      { type: 'start' as const },
      { type: 'start-step' as const, request: {}, warnings: [] },
      { type: 'text-delta' as const, id: 'chunk-1', text: 'recursive' },
      { type: 'text-delta' as const, id: 'chunk-2', text: ' response' },
      {
        type: 'finish-step' as const,
        finishReason: 'stop',
        rawFinishReason: 'stop',
        response: {
          id: 'test-response-id',
          timestamp: new Date(),
          modelId: 'test-model'
        },
        usage: {
          totalTokens: 0,
          inputTokens: 0,
          outputTokens: 0,
          inputTokenDetails: {
            noCacheTokens: 0,
            cacheReadTokens: 0,
            cacheWriteTokens: 0
          },
          outputTokenDetails: {
            textTokens: 0,
            reasoningTokens: 0
          }
        },
        providerMetadata: undefined
      },
      {
        type: 'finish' as const,
        finishReason: 'stop',
        rawFinishReason: 'stop',
        totalUsage: {
          totalTokens: 0,
          inputTokens: 0,
          outputTokens: 0,
          inputTokenDetails: {
            noCacheTokens: 0,
            cacheReadTokens: 0,
            cacheWriteTokens: 0
          },
          outputTokenDetails: {
            textTokens: 0,
            reasoningTokens: 0
          }
        }
      }
    ]
    const mockStream = simulateReadableStream<TextStreamPart<EmptyToolSet>>({
      chunks: mockChunks,
      initialDelayInMs: 0,
      chunkDelayInMs: 0
    })
    const context = createMockContext({
      hasExecutedToolsInCurrentStep: true,
      recursiveCall: vi.fn().mockResolvedValue({
        fullStream: mockStream
      })
    })
    await manager.handleRecursiveCall(controller, {}, context)
    // Should skip 'start' type and stop at 'finish' type
    expect(enqueuedChunks).toHaveLength(4)
    expect(enqueuedChunks[0]).toEqual({ type: 'start-step', request: {}, warnings: [] })
    expect(enqueuedChunks[1]).toEqual({ type: 'text-delta', id: 'chunk-1', text: 'recursive' })
    expect(enqueuedChunks[2]).toEqual({ type: 'text-delta', id: 'chunk-2', text: ' response' })
    expect(enqueuedChunks[3]).toMatchObject({
      type: 'finish-step',
      finishReason: 'stop',
      rawFinishReason: 'stop',
      providerMetadata: undefined,
      usage: {
        totalTokens: 0,
        inputTokens: 0,
        outputTokens: 0,
        inputTokenDetails: {
          noCacheTokens: 0,
          cacheReadTokens: 0,
          cacheWriteTokens: 0
        },
        outputTokenDetails: {
          textTokens: 0,
          reasoningTokens: 0
        }
      }
    })
  })
  it('should warn when no fullStream is found', async () => {
    // mockImplementation silences the real console.warn so the expected
    // warning does not pollute test output; the spy still records the call
    const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {})
    const controller = createMockStreamController()
    const context = createMockContext({
      hasExecutedToolsInCurrentStep: true,
      recursiveCall: vi.fn().mockResolvedValue({
        // No fullStream property
        someOtherProperty: 'value'
      })
    })
    await manager.handleRecursiveCall(controller, {}, context)
    expect(warnSpy).toHaveBeenCalledWith(
      expect.stringContaining('[MCP Prompt] No fullstream found'),
      expect.any(Object)
    )
    warnSpy.mockRestore()
  })
})
})

View File

@ -0,0 +1,547 @@
import type { TextStreamPart, ToolSet } from 'ai'
import { simulateReadableStream } from 'ai'
import { convertReadableStreamToArray } from 'ai/test'
import { describe, expect, it, vi } from 'vitest'
import { createMockContext, createMockStreamParams, createMockTool, createMockToolSet } from '../../../../../__tests__'
import { createPromptToolUsePlugin, DEFAULT_SYSTEM_PROMPT } from '../promptToolUsePlugin'
describe('promptToolUsePlugin', () => {
// Construction-time behavior: the factory must accept every documented
// configuration shape and always yield a named plugin with both hooks.
describe('Factory Function', () => {
  it('should return AiPlugin with correct name', () => {
    const created = createPromptToolUsePlugin()
    expect(created.name).toBe('built-in:prompt-tool-use')
    expect(created.transformParams).toBeDefined()
    expect(created.transformStream).toBeDefined()
  })
  it('should accept empty configuration', () => {
    const created = createPromptToolUsePlugin({})
    expect(created).toBeDefined()
    expect(created.name).toBe('built-in:prompt-tool-use')
  })
  it('should accept custom buildSystemPrompt', () => {
    const promptBuilder = vi.fn((userSystemPrompt: string) => userSystemPrompt)
    const created = createPromptToolUsePlugin({ buildSystemPrompt: promptBuilder })
    expect(created).toBeDefined()
  })
  it('should accept custom parseToolUse', () => {
    const toolUseParser = vi.fn(() => ({ results: [], content: '' }))
    const created = createPromptToolUsePlugin({ parseToolUse: toolUseParser })
    expect(created).toBeDefined()
  })
  it('should accept enabled flag', () => {
    const disabled = createPromptToolUsePlugin({ enabled: false })
    const enabled = createPromptToolUsePlugin({ enabled: true })
    expect(disabled).toBeDefined()
    expect(enabled).toBeDefined()
  })
})
// transformParams splits incoming tools into provider-executed tools (kept in
// params.tools) and prompt-executed tools (moved to context.mcpTools), and
// injects a tool-use system prompt when prompt tools are present.
describe('transformParams', () => {
  it('should separate provider and prompt tools', async () => {
    const plugin = createPromptToolUsePlugin()
    const context = createMockContext()
    const params = createMockStreamParams({
      tools: createMockToolSet({
        provider_tool: 'provider',
        prompt_tool: 'function'
      })
    })
    // Promise.resolve normalizes the hook's sync-or-async return
    const result = await Promise.resolve(plugin.transformParams!(params, context))
    // Provider tools should remain in tools
    expect(result.tools).toBeDefined()
    expect(result.tools).toHaveProperty('provider_tool')
    expect(result.tools).not.toHaveProperty('prompt_tool')
    // Prompt tools should be moved to context.mcpTools
    expect(context.mcpTools).toBeDefined()
    expect(context.mcpTools).toHaveProperty('prompt_tool')
    expect(context.mcpTools).not.toHaveProperty('provider_tool')
  })
  it('should handle only provider tools', async () => {
    const plugin = createPromptToolUsePlugin()
    const context = createMockContext()
    const params = createMockStreamParams({
      tools: createMockToolSet({
        provider_tool1: 'provider',
        provider_tool2: 'provider'
      })
    })
    const result = await Promise.resolve(plugin.transformParams!(params, context))
    // Nothing to move: tools pass through and mcpTools stays unset
    expect(result.tools).toEqual(params.tools)
    expect(context.mcpTools).toBeUndefined()
  })
  it('should handle only prompt tools', async () => {
    const plugin = createPromptToolUsePlugin()
    const context = createMockContext()
    const params = createMockStreamParams({
      tools: createMockToolSet({
        prompt_tool1: 'function',
        prompt_tool2: 'function'
      })
    })
    const result = await Promise.resolve(plugin.transformParams!(params, context))
    // Everything moves: tools is cleared, mcpTools holds the full set
    expect(result.tools).toBeUndefined()
    expect(context.mcpTools).toEqual(params.tools)
  })
  it('should build system prompt for prompt tools', async () => {
    const plugin = createPromptToolUsePlugin()
    const context = createMockContext()
    const params = createMockStreamParams({
      system: 'Original system prompt',
      tools: {
        test_tool: createMockTool('test_tool', 'Test tool description')
      }
    })
    const result = await Promise.resolve(plugin.transformParams!(params, context))
    // Injected prompt embeds the tool name/description and keeps the
    // caller's original system prompt
    expect(result.system).toBeDefined()
    expect(typeof result.system).toBe('string')
    expect(result.system).toContain('In this environment you have access to a set of tools')
    expect(result.system).toContain('test_tool')
    expect(result.system).toContain('Test tool description')
    expect(result.system).toContain('Original system prompt')
  })
  it('should handle empty user system prompt', async () => {
    const plugin = createPromptToolUsePlugin()
    const context = createMockContext()
    const params = createMockStreamParams({
      tools: {
        test_tool: createMockTool('test_tool')
      }
    })
    const result = await Promise.resolve(plugin.transformParams!(params, context))
    expect(result.system).toBeDefined()
    expect(result.system).toContain('In this environment you have access to a set of tools')
  })
  it('should skip system prompt when disabled', async () => {
    const plugin = createPromptToolUsePlugin({ enabled: false })
    const context = createMockContext()
    const params = createMockStreamParams({
      system: 'Original',
      tools: {
        test_tool: createMockTool('test_tool')
      }
    })
    const result = await Promise.resolve(plugin.transformParams!(params, context))
    // Disabled plugin is a no-op: params returned unchanged, context untouched
    expect(result).toEqual(params)
    expect(context.mcpTools).toBeUndefined()
  })
  it('should skip when no tools provided', async () => {
    const plugin = createPromptToolUsePlugin()
    const context = createMockContext()
    const params = createMockStreamParams({
      system: 'Original'
    })
    const result = await Promise.resolve(plugin.transformParams!(params, context))
    expect(result).toEqual(params)
  })
  it('should skip when tools is not an object', async () => {
    const plugin = createPromptToolUsePlugin()
    const context = createMockContext()
    const params = createMockStreamParams({
      system: 'Original',
      // deliberately malformed input to exercise the guard path
      tools: 'invalid' as any
    })
    const result = await Promise.resolve(plugin.transformParams!(params, context))
    expect(result).toEqual(params)
  })
  it('should use custom buildSystemPrompt when provided', async () => {
    const customBuildSystemPrompt = vi.fn((userSystemPrompt: string, tools: ToolSet) => {
      return `Custom prompt with ${Object.keys(tools).length} tools and user prompt: ${userSystemPrompt}`
    })
    const plugin = createPromptToolUsePlugin({
      buildSystemPrompt: customBuildSystemPrompt
    })
    const context = createMockContext()
    const params = createMockStreamParams({
      system: 'User prompt',
      tools: {
        tool1: createMockTool('tool1')
      }
    })
    const result = await Promise.resolve(plugin.transformParams!(params, context))
    // Custom builder fully replaces the default prompt template
    expect(customBuildSystemPrompt).toHaveBeenCalled()
    expect(result.system).toBe('Custom prompt with 1 tools and user prompt: User prompt')
  })
  it('should use custom createSystemMessage when provided', async () => {
    const customCreateSystemMessage = vi.fn(() => {
      return `Modified system message`
    })
    const plugin = createPromptToolUsePlugin({
      createSystemMessage: customCreateSystemMessage
    })
    const context = createMockContext()
    const params = createMockStreamParams({
      system: 'Original',
      tools: {
        test: createMockTool('test')
      }
    })
    const result = await Promise.resolve(plugin.transformParams!(params, context))
    expect(customCreateSystemMessage).toHaveBeenCalled()
    expect(result.system).toContain('Modified')
  })
  it('should save originalParams to context', async () => {
    const plugin = createPromptToolUsePlugin()
    const context = createMockContext()
    const params = createMockStreamParams({
      system: 'Original',
      tools: {
        test: createMockTool('test')
      }
    })
    await Promise.resolve(plugin.transformParams!(params, context))
    // originalParams is kept so later recursive calls can rebuild messages
    expect(context.originalParams).toBeDefined()
    expect(context.originalParams.system).toBeDefined()
  })
})
// transformStream returns a TransformStream factory. When active it strips
// <tool_use> tag content from text deltas, withholds text-start/text-end
// until real text appears, and accumulates per-step usage into the context.
describe('transformStream', () => {
  it('should return identity transform when disabled', async () => {
    const plugin = createPromptToolUsePlugin({ enabled: false })
    const context = createMockContext()
    const inputChunks: Array<{ type: 'text-delta'; text: string }> = [
      { type: 'text-delta', text: 'Hello' },
      { type: 'text-delta', text: ' World' }
    ]
    const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
      chunks: inputChunks as TextStreamPart<ToolSet>[],
      initialDelayInMs: 0,
      chunkDelayInMs: 0
    })
    // transformStream!(...) builds the factory; the trailing () creates the
    // actual TransformStream instance
    const transform = plugin.transformStream!(createMockStreamParams(), context)()
    const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
    // Chunks pass through unmodified
    expect(result).toEqual(inputChunks)
  })
  it('should return identity transform when no mcpTools in context', async () => {
    const plugin = createPromptToolUsePlugin()
    const context = createMockContext()
    // Don't set context.mcpTools
    const inputChunks: Array<{ type: 'text-delta'; text: string }> = [
      { type: 'text-delta', text: 'Hello' },
      { type: 'text-delta', text: ' World' }
    ]
    const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
      chunks: inputChunks as TextStreamPart<ToolSet>[],
      initialDelayInMs: 0,
      chunkDelayInMs: 0
    })
    const transform = plugin.transformStream!(createMockStreamParams(), context)()
    const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
    expect(result).toEqual(inputChunks)
  })
  it('should initialize accumulatedUsage in context', () => {
    const plugin = createPromptToolUsePlugin()
    const context = createMockContext()
    context.mcpTools = {
      test: createMockTool('test')
    }
    plugin.transformStream!(createMockStreamParams(), context)()
    // All counters start from zero before any chunk is processed
    expect(context.accumulatedUsage).toBeDefined()
    expect(context.accumulatedUsage).toEqual({
      inputTokens: 0,
      outputTokens: 0,
      totalTokens: 0,
      reasoningTokens: 0,
      cachedInputTokens: 0
    })
  })
  it('should filter tool tags from text-delta chunks', async () => {
    const plugin = createPromptToolUsePlugin()
    const context = createMockContext()
    context.mcpTools = {
      test: createMockTool('test')
    }
    // Tool-use XML arrives split across multiple deltas, surrounded by text
    const inputChunks = [
      { type: 'text-start' as const },
      { type: 'text-delta' as const, text: 'Before ' },
      { type: 'text-delta' as const, text: '<tool_use>' },
      { type: 'text-delta' as const, text: '<name>test</name>' },
      { type: 'text-delta' as const, text: '<arguments>{}</arguments>' },
      { type: 'text-delta' as const, text: '</tool_use>' },
      { type: 'text-delta' as const, text: ' After' },
      { type: 'text-end' as const }
    ]
    const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
      chunks: inputChunks as TextStreamPart<ToolSet>[],
      initialDelayInMs: 0,
      chunkDelayInMs: 0
    })
    const transform = plugin.transformStream!(createMockStreamParams(), context)()
    const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
    // Extract text from text-delta chunks
    const textChunks = result.filter((chunk) => chunk.type === 'text-delta')
    const fullText = textChunks.map((chunk) => 'text' in chunk && chunk.text).join('')
    // Tool tags should be filtered out
    expect(fullText).not.toContain('<tool_use>')
    expect(fullText).not.toContain('</tool_use>')
    expect(fullText).toContain('Before')
    expect(fullText).toContain('After')
  })
  it('should hold text-start until non-tag content appears', async () => {
    const plugin = createPromptToolUsePlugin()
    const context = createMockContext()
    context.mcpTools = {
      test: createMockTool('test')
    }
    // Only tool tags, no actual content
    const inputChunks = [
      { type: 'text-start' as const },
      { type: 'text-delta' as const, text: '<tool_use>' },
      { type: 'text-delta' as const, text: '<name>test</name>' },
      { type: 'text-delta' as const, text: '<arguments>{}</arguments>' },
      { type: 'text-delta' as const, text: '</tool_use>' },
      { type: 'text-end' as const }
    ]
    const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
      chunks: inputChunks as TextStreamPart<ToolSet>[],
      initialDelayInMs: 0,
      chunkDelayInMs: 0
    })
    const transform = plugin.transformStream!(createMockStreamParams(), context)()
    const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
    // Should not have text-start or text-end since all content was tool tags
    expect(result.some((chunk) => chunk.type === 'text-start')).toBe(false)
    expect(result.some((chunk) => chunk.type === 'text-end')).toBe(false)
  })
  it('should send text-start when non-tag content appears', async () => {
    const plugin = createPromptToolUsePlugin()
    const context = createMockContext()
    context.mcpTools = {
      test: createMockTool('test')
    }
    const inputChunks = [
      { type: 'text-start' as const },
      { type: 'text-delta' as const, text: 'Actual content' },
      { type: 'text-end' as const }
    ]
    const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
      chunks: inputChunks as TextStreamPart<ToolSet>[],
      initialDelayInMs: 0,
      chunkDelayInMs: 0
    })
    const transform = plugin.transformStream!(createMockStreamParams(), context)()
    const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
    // Should have text-start, text-delta, and text-end
    expect(result.some((chunk) => chunk.type === 'text-start')).toBe(true)
    expect(result.some((chunk) => chunk.type === 'text-delta')).toBe(true)
    expect(result.some((chunk) => chunk.type === 'text-end')).toBe(true)
  })
  it('should pass through non-text events', async () => {
    const plugin = createPromptToolUsePlugin()
    const context = createMockContext()
    context.mcpTools = {
      test: createMockTool('test')
    }
    const stepStartEvent = { type: 'start-step' as const, request: {}, warnings: [] }
    const inputChunks = [stepStartEvent]
    const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
      chunks: inputChunks as TextStreamPart<ToolSet>[],
      initialDelayInMs: 0,
      chunkDelayInMs: 0
    })
    const transform = plugin.transformStream!(createMockStreamParams(), context)()
    const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
    expect(result[0]).toEqual(stepStartEvent)
  })
  it('should accumulate usage from finish-step events', async () => {
    const plugin = createPromptToolUsePlugin()
    const context = createMockContext()
    context.mcpTools = {
      test: createMockTool('test')
    }
    const inputChunks = [
      {
        type: 'finish-step' as const,
        finishReason: 'stop' as const,
        usage: {
          inputTokens: 10,
          outputTokens: 20,
          totalTokens: 30
        }
      }
    ]
    const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
      chunks: inputChunks as TextStreamPart<ToolSet>[],
      initialDelayInMs: 0,
      chunkDelayInMs: 0
    })
    const transform = plugin.transformStream!(createMockStreamParams(), context)()
    await convertReadableStreamToArray(inputStream.pipeThrough(transform))
    // Verify usage was accumulated
    expect(context.accumulatedUsage).toBeDefined()
    expect(context.accumulatedUsage!.inputTokens).toBe(10)
    expect(context.accumulatedUsage!.outputTokens).toBe(20)
    expect(context.accumulatedUsage!.totalTokens).toBe(30)
  })
  it('should include accumulated usage in finish event', async () => {
    const plugin = createPromptToolUsePlugin()
    const context = createMockContext()
    context.mcpTools = {
      test: createMockTool('test')
    }
    // Pre-populate accumulated usage
    context.accumulatedUsage = {
      inputTokens: 5,
      outputTokens: 10,
      totalTokens: 15,
      reasoningTokens: 0,
      cachedInputTokens: 0
    }
    const inputChunks = [
      {
        type: 'finish' as const,
        finishReason: 'stop' as const,
        usage: {
          inputTokens: 100,
          outputTokens: 200,
          totalTokens: 300
        }
      }
    ]
    const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
      chunks: inputChunks as unknown as TextStreamPart<ToolSet>[],
      initialDelayInMs: 0,
      chunkDelayInMs: 0
    })
    const transform = plugin.transformStream!(createMockStreamParams(), context)()
    const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
    const finishEvent = result.find((chunk) => chunk.type === 'finish')
    expect(finishEvent).toBeDefined()
    // The outgoing finish event reports the accumulated totals, not the
    // per-chunk usage from the raw finish event
    if (finishEvent && 'totalUsage' in finishEvent) {
      expect(finishEvent.totalUsage).toEqual(context.accumulatedUsage)
    }
  })
})
// Compile-time smoke test: assigning through an explicit alias of the
// inferred plugin type verifies the factory's generic parameters resolve.
describe('Type Safety', () => {
  it('should have correct generic parameters for StreamTextParams and StreamTextResult', () => {
    const plugin = createPromptToolUsePlugin()
    const checked: typeof plugin = plugin
    expect(checked.name).toBe('built-in:prompt-tool-use')
  })
})
// Table-driven checks over the static prompt template: required section
// headings, substitution placeholders, and the XML tool-use examples.
describe('DEFAULT_SYSTEM_PROMPT', () => {
  it('should contain required sections', () => {
    const requiredSections = [
      'Tool Use Formatting',
      'Tool Use Examples',
      'Tool Use Available Tools',
      'Tool Use Rules',
      'Response rules'
    ]
    requiredSections.forEach((section) => expect(DEFAULT_SYSTEM_PROMPT).toContain(section))
  })
  it('should have placeholders for dynamic content', () => {
    const placeholders = ['{{ TOOL_USE_EXAMPLES }}', '{{ AVAILABLE_TOOLS }}', '{{ USER_SYSTEM_PROMPT }}']
    placeholders.forEach((placeholder) => expect(DEFAULT_SYSTEM_PROMPT).toContain(placeholder))
  })
  it('should contain XML tag examples', () => {
    const xmlTags = ['<tool_use>', '</tool_use>', '<name>', '<arguments>']
    xmlTags.forEach((tag) => expect(DEFAULT_SYSTEM_PROMPT).toContain(tag))
  })
})
})

View File

@ -33,8 +33,8 @@ export function createContext<T extends ProviderId, TParams = unknown, TResult =
startTime: Date.now(),
requestId: `${providerId}-${typeof model === 'string' ? model : model?.modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
isRecursiveCall: false,
recursiveDepth: 0, // 初始化递归深度为 0
maxRecursiveDepth: 10, // 默认最大递归深度为 10
recursiveDepth: 0, // 初始化递归深度为 0
maxRecursiveDepth: 10, // 默认最大递归深度为 10
extensions: new Map(),
middlewares: [],
// 占位递归调用函数,实际使用时会被 PluginEngine 替换

View File

@ -72,10 +72,7 @@ export class PluginManager<TParams = unknown, TResult = unknown> {
* transformParams -
* Partial<TParams>
*/
async executeTransformParams(
initialValue: TParams,
context: AiRequestContext<TParams, TResult>
): Promise<TParams> {
async executeTransformParams(initialValue: TParams, context: AiRequestContext<TParams, TResult>): Promise<TParams> {
let result = initialValue
for (const plugin of this.plugins) {
@ -93,10 +90,7 @@ export class PluginManager<TParams = unknown, TResult = unknown> {
* transformResult -
* TResult
*/
async executeTransformResult(
initialValue: TResult,
context: AiRequestContext<TParams, TResult>
): Promise<TResult> {
async executeTransformResult(initialValue: TResult, context: AiRequestContext<TParams, TResult>): Promise<TResult> {
let result = initialValue
for (const plugin of this.plugins) {

View File

@ -0,0 +1,525 @@
/**
* HubProvider Comprehensive Tests
* Tests hub provider routing, model resolution, and error handling
* Covers multi-provider routing with namespaced model IDs
*/
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3, ProviderV3 } from '@ai-sdk/provider'
import { customProvider, wrapProvider } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createMockEmbeddingModel, createMockImageModel, createMockLanguageModel } from '../../../__tests__'
import { createHubProvider, type HubProviderConfig, HubProviderError } from '../HubProvider'
import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../RegistryManagement'
// Mock dependencies
// The registry is replaced so each test controls which provider ids resolve;
// DEFAULT_SEPARATOR is pinned to '|' for predictable model-id parsing.
vi.mock('../RegistryManagement', () => ({
  globalRegistryManagement: {
    getProvider: vi.fn()
  },
  DEFAULT_SEPARATOR: '|'
}))
// customProvider/wrapProvider are stubbed to return their input provider
// unchanged, keeping assertions focused on HubProvider's own routing logic.
vi.mock('ai', () => ({
  customProvider: vi.fn((config) => config.fallbackProvider),
  wrapProvider: vi.fn((config) => config.provider)
}))
describe('HubProvider', () => {
let mockOpenAIProvider: ProviderV3
let mockAnthropicProvider: ProviderV3
let mockLanguageModel: LanguageModelV3
let mockEmbeddingModel: EmbeddingModelV3
let mockImageModel: ImageModelV3
// Fresh mocks per test: two fully-stubbed providers ('openai' and
// 'anthropic') that both hand back the shared mock models.
beforeEach(() => {
  vi.clearAllMocks()
  // Create mock models using global utilities
  mockLanguageModel = createMockLanguageModel({
    provider: 'test',
    modelId: 'test-model'
  })
  mockEmbeddingModel = createMockEmbeddingModel({
    provider: 'test',
    modelId: 'test-embedding'
  })
  mockImageModel = createMockImageModel({
    provider: 'test',
    modelId: 'test-image'
  })
  // Create mock providers
  mockOpenAIProvider = {
    specificationVersion: 'v3',
    languageModel: vi.fn().mockReturnValue(mockLanguageModel),
    embeddingModel: vi.fn().mockReturnValue(mockEmbeddingModel),
    imageModel: vi.fn().mockReturnValue(mockImageModel)
  } as ProviderV3
  mockAnthropicProvider = {
    specificationVersion: 'v3',
    languageModel: vi.fn().mockReturnValue(mockLanguageModel),
    embeddingModel: vi.fn().mockReturnValue(mockEmbeddingModel),
    imageModel: vi.fn().mockReturnValue(mockImageModel)
  } as ProviderV3
  // Setup default mock implementation
  // Only 'openai' and 'anthropic' resolve; any other id yields undefined,
  // exercising the unknown-provider path in the tests below.
  vi.mocked(globalRegistryManagement.getProvider).mockImplementation((id) => {
    if (id === 'openai') return mockOpenAIProvider
    if (id === 'anthropic') return mockAnthropicProvider
    return undefined
  })
})
describe('Provider Creation', () => {
  // The factory should accept a minimal config and hand back a
  // ProviderV3-shaped object via the (mocked) customProvider helper.
  it('should create hub provider with basic config', () => {
    const provider = createHubProvider({ hubId: 'test-hub' })
    expect(provider).toBeDefined()
    expect(customProvider).toHaveBeenCalled()
  })
  it('should create provider with debug flag', () => {
    const debugConfig: HubProviderConfig = { hubId: 'test-hub', debug: true }
    expect(createHubProvider(debugConfig)).toBeDefined()
  })
  it('should return ProviderV3 specification', () => {
    const provider = createHubProvider({ hubId: 'aihubmix' })
    expect(provider).toHaveProperty('specificationVersion', 'v3')
    // Every core model accessor must be exposed.
    for (const accessor of ['languageModel', 'embeddingModel', 'imageModel']) {
      expect(provider).toHaveProperty(accessor)
    }
  })
})
describe('Model ID Parsing', () => {
// Hub model IDs are expected as `<providerId><separator><modelId>` with
// exactly one separator; malformed IDs must raise HubProviderError.
it('should parse valid hub model ID format', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
const modelId = `openai${DEFAULT_SEPARATOR}gpt-4`
const result = provider.languageModel(modelId)
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai')
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4')
expect(result).toBe(mockLanguageModel)
})
it('should throw error for invalid model ID format', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
// No separator at all
const invalidId = 'invalid-id-without-separator'
expect(() => provider.languageModel(invalidId)).toThrow(HubProviderError)
})
it('should throw error for model ID with multiple separators', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
const multiSeparatorId = `provider${DEFAULT_SEPARATOR}extra${DEFAULT_SEPARATOR}model`
expect(() => provider.languageModel(multiSeparatorId)).toThrow(HubProviderError)
})
it('should throw error for empty model ID', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
expect(() => provider.languageModel('')).toThrow(HubProviderError)
})
it('should throw error for model ID with only separator', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
// Both provider and model parts empty
expect(() => provider.languageModel(DEFAULT_SEPARATOR)).toThrow(HubProviderError)
})
})
describe('Language Model Resolution', () => {
// Routing: the provider prefix selects which registered provider serves
// the request; failures surface as HubProviderError with context fields.
it('should route to correct provider for language model', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
const result = provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai')
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4')
expect(result).toBe(mockLanguageModel)
})
it('should route different providers correctly', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
provider.languageModel(`anthropic${DEFAULT_SEPARATOR}claude-3`)
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai')
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('anthropic')
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4')
expect(mockAnthropicProvider.languageModel).toHaveBeenCalledWith('claude-3')
})
it('should wrap provider with wrapProvider', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
// Middleware list is currently empty; wrapProvider is still always applied
expect(wrapProvider).toHaveBeenCalledWith({
provider: mockOpenAIProvider,
languageModelMiddleware: []
})
})
it('should throw HubProviderError if provider not initialized', () => {
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(undefined)
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
expect(() => provider.languageModel(`uninitialized${DEFAULT_SEPARATOR}model`)).toThrow(HubProviderError)
expect(() => provider.languageModel(`uninitialized${DEFAULT_SEPARATOR}model`)).toThrow(/not initialized/)
})
it('should include provider ID in error message', () => {
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(undefined)
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
try {
provider.languageModel(`missing${DEFAULT_SEPARATOR}model`)
expect.fail('Should have thrown HubProviderError')
} catch (error) {
// Error carries both the failing provider id and the hub id
expect(error).toBeInstanceOf(HubProviderError)
const hubError = error as HubProviderError
expect(hubError.providerId).toBe('missing')
expect(hubError.hubId).toBe('test-hub')
}
})
})
describe('Embedding Model Resolution', () => {
// Same prefix-based routing as language models, via embeddingModel().
it('should route to correct provider for embedding model', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
const result = provider.embeddingModel(`openai${DEFAULT_SEPARATOR}text-embedding-3-small`)
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai')
expect(mockOpenAIProvider.embeddingModel).toHaveBeenCalledWith('text-embedding-3-small')
expect(result).toBe(mockEmbeddingModel)
})
it('should handle different embedding providers', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada-002`)
provider.embeddingModel(`anthropic${DEFAULT_SEPARATOR}embed-v1`)
expect(mockOpenAIProvider.embeddingModel).toHaveBeenCalledWith('ada-002')
expect(mockAnthropicProvider.embeddingModel).toHaveBeenCalledWith('embed-v1')
})
it('should throw error for uninitialized embedding provider', () => {
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(undefined)
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
expect(() => provider.embeddingModel(`missing${DEFAULT_SEPARATOR}embed`)).toThrow(HubProviderError)
})
})
describe('Image Model Resolution', () => {
// Same prefix-based routing, via imageModel().
it('should route to correct provider for image model', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
const result = provider.imageModel(`openai${DEFAULT_SEPARATOR}dall-e-3`)
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai')
expect(mockOpenAIProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
expect(result).toBe(mockImageModel)
})
it('should handle different image providers', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
provider.imageModel(`openai${DEFAULT_SEPARATOR}dall-e-3`)
provider.imageModel(`anthropic${DEFAULT_SEPARATOR}image-gen`)
expect(mockOpenAIProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
expect(mockAnthropicProvider.imageModel).toHaveBeenCalledWith('image-gen')
})
})
describe('Special Model Types', () => {
// Transcription/speech/reranking mocks are ad-hoc objects: the shared test
// utilities imported at the top only cover language/embedding/image models.
it('should support transcription models', () => {
const mockTranscriptionModel = {
specificationVersion: 'v3',
doTranscribe: vi.fn()
}
const providerWithTranscription = {
...mockOpenAIProvider,
transcriptionModel: vi.fn().mockReturnValue(mockTranscriptionModel)
} as ProviderV3
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(providerWithTranscription)
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
const result = provider.transcriptionModel!(`openai${DEFAULT_SEPARATOR}whisper-1`)
expect(providerWithTranscription.transcriptionModel).toHaveBeenCalledWith('whisper-1')
expect(result).toBe(mockTranscriptionModel)
})
it('should throw error if provider does not support transcription', () => {
// Default mockOpenAIProvider has no transcriptionModel member
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
expect(() => provider.transcriptionModel!(`openai${DEFAULT_SEPARATOR}whisper`)).toThrow(HubProviderError)
expect(() => provider.transcriptionModel!(`openai${DEFAULT_SEPARATOR}whisper`)).toThrow(
/does not support transcription/
)
})
it('should support speech models', () => {
const mockSpeechModel = {
specificationVersion: 'v3',
doGenerate: vi.fn()
}
const providerWithSpeech = {
...mockOpenAIProvider,
speechModel: vi.fn().mockReturnValue(mockSpeechModel)
} as ProviderV3
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(providerWithSpeech)
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
const result = provider.speechModel!(`openai${DEFAULT_SEPARATOR}tts-1`)
expect(providerWithSpeech.speechModel).toHaveBeenCalledWith('tts-1')
expect(result).toBe(mockSpeechModel)
})
it('should throw error if provider does not support speech', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
expect(() => provider.speechModel!(`openai${DEFAULT_SEPARATOR}tts-1`)).toThrow(HubProviderError)
expect(() => provider.speechModel!(`openai${DEFAULT_SEPARATOR}tts-1`)).toThrow(/does not support speech/)
})
it('should support reranking models', () => {
const mockRerankingModel = {
specificationVersion: 'v3',
doRerank: vi.fn()
}
const providerWithReranking = {
...mockOpenAIProvider,
rerankingModel: vi.fn().mockReturnValue(mockRerankingModel)
} as ProviderV3
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(providerWithReranking)
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
const result = provider.rerankingModel!(`openai${DEFAULT_SEPARATOR}rerank-v1`)
expect(providerWithReranking.rerankingModel).toHaveBeenCalledWith('rerank-v1')
expect(result).toBe(mockRerankingModel)
})
it('should throw error if provider does not support reranking', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
expect(() => provider.rerankingModel!(`openai${DEFAULT_SEPARATOR}rerank`)).toThrow(HubProviderError)
expect(() => provider.rerankingModel!(`openai${DEFAULT_SEPARATOR}rerank`)).toThrow(/does not support reranking/)
})
})
describe('Error Handling', () => {
// HubProviderError contract: message, hubId, optional providerId and
// originalError; registry failures must be wrapped, never leaked raw.
it('should create HubProviderError with all properties', () => {
const originalError = new Error('Original error')
const error = new HubProviderError('Test message', 'test-hub', 'test-provider', originalError)
expect(error.message).toBe('Test message')
expect(error.hubId).toBe('test-hub')
expect(error.providerId).toBe('test-provider')
expect(error.originalError).toBe(originalError)
expect(error.name).toBe('HubProviderError')
})
it('should create HubProviderError without optional parameters', () => {
const error = new HubProviderError('Test message', 'test-hub')
expect(error.message).toBe('Test message')
expect(error.hubId).toBe('test-hub')
expect(error.providerId).toBeUndefined()
expect(error.originalError).toBeUndefined()
})
it('should wrap provider errors in HubProviderError', () => {
const providerError = new Error('Provider failed')
vi.mocked(globalRegistryManagement.getProvider).mockImplementation(() => {
throw providerError
})
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
try {
provider.languageModel(`failing${DEFAULT_SEPARATOR}model`)
expect.fail('Should have thrown HubProviderError')
} catch (error) {
expect(error).toBeInstanceOf(HubProviderError)
const hubError = error as HubProviderError
// Original cause is preserved on the wrapper
expect(hubError.originalError).toBe(providerError)
expect(hubError.message).toContain('Failed to get provider')
}
})
it('should handle null provider from registry', () => {
// null (not just undefined) must also be treated as "not initialized"
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(null as any)
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
expect(() => provider.languageModel(`null-provider${DEFAULT_SEPARATOR}model`)).toThrow(HubProviderError)
})
})
describe('Multi-Provider Scenarios', () => {
  // Cross-provider routing with repeated and interleaved calls.
  it('should handle sequential calls to different providers', () => {
    const config: HubProviderConfig = { hubId: 'aihubmix' }
    const provider = createHubProvider(config) as ProviderV3
    provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
    provider.languageModel(`anthropic${DEFAULT_SEPARATOR}claude-3`)
    provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-3.5`)
    expect(globalRegistryManagement.getProvider).toHaveBeenCalledTimes(3)
    expect(mockOpenAIProvider.languageModel).toHaveBeenCalledTimes(2)
    expect(mockAnthropicProvider.languageModel).toHaveBeenCalledTimes(1)
  })
  it('should handle mixed model types from same provider', () => {
    const config: HubProviderConfig = { hubId: 'aihubmix' }
    const provider = createHubProvider(config) as ProviderV3
    provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
    provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada-002`)
    provider.imageModel(`openai${DEFAULT_SEPARATOR}dall-e-3`)
    expect(globalRegistryManagement.getProvider).toHaveBeenCalledTimes(3)
    expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4')
    expect(mockOpenAIProvider.embeddingModel).toHaveBeenCalledWith('ada-002')
    expect(mockOpenAIProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
  })
  // Renamed from 'should cache provider lookups': the assertion below
  // documents the OPPOSITE — the registry is consulted once per model call,
  // so there is no lookup caching. The old title contradicted the check.
  it('should look up the provider on every model call (no caching)', () => {
    const config: HubProviderConfig = { hubId: 'aihubmix' }
    const provider = createHubProvider(config) as ProviderV3
    provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
    provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-3.5`)
    // getProvider is invoked twice — once per model call
    expect(globalRegistryManagement.getProvider).toHaveBeenCalledTimes(2)
  })
})
describe('Provider Wrapping', () => {
// NOTE(review): the first test duplicates 'should wrap provider with
// wrapProvider' in the Language Model Resolution suite — consider
// consolidating to one location.
it('should wrap all providers with empty middleware', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
expect(wrapProvider).toHaveBeenCalledWith({
provider: mockOpenAIProvider,
languageModelMiddleware: []
})
})
it('should wrap providers for all model types', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada`)
provider.imageModel(`openai${DEFAULT_SEPARATOR}dalle`)
// One wrap per resolution call
expect(wrapProvider).toHaveBeenCalledTimes(3)
})
})
describe('Type Safety', () => {
// Resolved models must carry the v3 spec version and the expected
// execution entry points (per-kind doGenerate/doStream/doEmbed).
it('should return properly typed LanguageModelV3', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
const result = provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
expect(result.specificationVersion).toBe('v3')
expect(result).toHaveProperty('doGenerate')
expect(result).toHaveProperty('doStream')
})
it('should return properly typed EmbeddingModelV3', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
const result = provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada`)
expect(result.specificationVersion).toBe('v3')
expect(result).toHaveProperty('doEmbed')
})
it('should return properly typed ImageModelV3', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
const result = provider.imageModel(`openai${DEFAULT_SEPARATOR}dalle`)
expect(result.specificationVersion).toBe('v3')
expect(result).toHaveProperty('doGenerate')
})
})
})

View File

@ -0,0 +1,561 @@
/**
 * RegistryManagement Comprehensive Tests
 * Tests provider registry management, model resolution, and alias handling
 * Covers registration, retrieval, and cleanup operations
 */
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3, ProviderV3 } from '@ai-sdk/provider'
import { createProviderRegistry } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createMockEmbeddingModel, createMockImageModel, createMockLanguageModel } from '../../../__tests__'
import { DEFAULT_SEPARATOR, RegistryManagement } from '../RegistryManagement'
// Mock AI SDK:
// createProviderRegistry is stubbed so tests can assert when (and with what
// options) RegistryManagement rebuilds its internal registry.
vi.mock('ai', () => ({
createProviderRegistry: vi.fn()
}))
describe('RegistryManagement', () => {
// Shared fixtures, rebuilt before every test by the beforeEach below.
let registry: RegistryManagement
let mockProvider: ProviderV3
let mockLanguageModel: LanguageModelV3
let mockEmbeddingModel: EmbeddingModelV3
let mockImageModel: ImageModelV3
beforeEach(() => {
vi.clearAllMocks()
// Create mock models using global utilities
mockLanguageModel = createMockLanguageModel({
provider: 'test',
modelId: 'test-model'
})
mockEmbeddingModel = createMockEmbeddingModel({
provider: 'test',
modelId: 'test-embedding'
})
mockImageModel = createMockImageModel({
provider: 'test',
modelId: 'test-image'
})
// Create mock provider
mockProvider = {
specificationVersion: 'v3',
languageModel: vi.fn().mockReturnValue(mockLanguageModel),
embeddingModel: vi.fn().mockReturnValue(mockEmbeddingModel),
imageModel: vi.fn().mockReturnValue(mockImageModel),
transcriptionModel: vi.fn(),
speechModel: vi.fn()
} as ProviderV3
// Setup mock registry: this is what the stubbed createProviderRegistry
// returns; 'as any' is a deliberate partial stub for test purposes.
const mockRegistry = {
languageModel: vi.fn().mockReturnValue(mockLanguageModel),
embeddingModel: vi.fn().mockReturnValue(mockEmbeddingModel),
imageModel: vi.fn().mockReturnValue(mockImageModel),
transcriptionModel: vi.fn(),
speechModel: vi.fn()
}
vi.mocked(createProviderRegistry).mockReturnValue(mockRegistry as any)
registry = new RegistryManagement()
})
describe('Constructor and Initialization', () => {
  // A freshly constructed registry is empty, regardless of separator choice.
  it('should create registry with default separator', () => {
    const fresh = new RegistryManagement()
    expect(fresh).toBeInstanceOf(RegistryManagement)
    expect(fresh.hasProviders()).toBe(false)
  })
  it('should create registry with custom separator', () => {
    expect(new RegistryManagement({ separator: ':' })).toBeInstanceOf(RegistryManagement)
  })
  it('should start with empty provider list', () => {
    expect(registry.getRegisteredProviders()).toEqual([])
  })
})
describe('Provider Registration', () => {
// Single-provider registration, fluent chaining, rebuild-on-register,
// and alias bookkeeping.
it('should register a provider', () => {
registry.registerProvider('openai', mockProvider)
expect(registry.getProvider('openai')).toBe(mockProvider)
expect(registry.hasProviders()).toBe(true)
})
it('should register multiple providers', () => {
const provider2 = { ...mockProvider }
registry.registerProvider('openai', mockProvider)
registry.registerProvider('anthropic', provider2)
expect(registry.getProvider('openai')).toBe(mockProvider)
expect(registry.getProvider('anthropic')).toBe(provider2)
})
it('should return this for chaining', () => {
const result = registry.registerProvider('openai', mockProvider)
expect(result).toBe(registry)
})
it('should rebuild registry after registration', () => {
registry.registerProvider('openai', mockProvider)
// Rebuild must pass the provider map plus the configured separator
expect(createProviderRegistry).toHaveBeenCalledWith(
expect.objectContaining({
openai: mockProvider
}),
{ separator: DEFAULT_SEPARATOR }
)
})
it('should register provider with aliases', () => {
registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt'])
expect(registry.getProvider('openai')).toBe(mockProvider)
expect(registry.getProvider('gpt')).toBe(mockProvider)
expect(registry.getProvider('chatgpt')).toBe(mockProvider)
})
it('should track aliases separately', () => {
registry.registerProvider('openai', mockProvider, ['gpt'])
// The canonical id is not itself considered an alias
expect(registry.isAlias('gpt')).toBe(true)
expect(registry.isAlias('openai')).toBe(false)
})
it('should handle multiple aliases for same provider', () => {
const aliases = ['alias1', 'alias2', 'alias3']
registry.registerProvider('provider', mockProvider, aliases)
aliases.forEach((alias) => {
expect(registry.getProvider(alias)).toBe(mockProvider)
expect(registry.isAlias(alias)).toBe(true)
})
})
})
describe('Bulk Registration', () => {
// registerProviders accepts a map of id -> provider in one call.
it('should register multiple providers at once', () => {
const providers = {
openai: mockProvider,
anthropic: { ...mockProvider },
google: { ...mockProvider }
}
registry.registerProviders(providers)
expect(registry.getProvider('openai')).toBe(providers.openai)
expect(registry.getProvider('anthropic')).toBe(providers.anthropic)
expect(registry.getProvider('google')).toBe(providers.google)
})
it('should return this for chaining', () => {
const result = registry.registerProviders({ openai: mockProvider })
expect(result).toBe(registry)
})
})
describe('Provider Retrieval', () => {
beforeEach(() => {
// Baseline: 'openai' is always registered in this suite
registry.registerProvider('openai', mockProvider)
})
it('should retrieve registered provider', () => {
const provider = registry.getProvider('openai')
expect(provider).toBe(mockProvider)
})
it('should return undefined for unregistered provider', () => {
const provider = registry.getProvider('nonexistent')
expect(provider).toBeUndefined()
})
it('should retrieve provider by alias', () => {
registry.registerProvider('anthropic', mockProvider, ['claude'])
const provider = registry.getProvider('claude')
expect(provider).toBe(mockProvider)
})
it('should get list of all registered providers', () => {
registry.registerProvider('anthropic', mockProvider)
registry.registerProvider('google', mockProvider, ['gemini'])
const providers = registry.getRegisteredProviders()
expect(providers).toContain('openai')
expect(providers).toContain('anthropic')
expect(providers).toContain('google')
expect(providers).toContain('gemini') // Aliases included
})
})
describe('Provider Unregistration', () => {
  // Removing a canonical id removes its aliases too; removing an alias
  // leaves the canonical id (and sibling aliases) intact.
  it('should unregister provider', () => {
    registry.registerProvider('openai', mockProvider)
    registry.unregisterProvider('openai')
    expect(registry.getProvider('openai')).toBeUndefined()
  })
  it('should unregister provider with all its aliases', () => {
    registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt'])
    registry.unregisterProvider('openai')
    expect(registry.getProvider('openai')).toBeUndefined()
    expect(registry.getProvider('gpt')).toBeUndefined()
    expect(registry.getProvider('chatgpt')).toBeUndefined()
  })
  it('should unregister only alias when alias is removed', () => {
    registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt'])
    registry.unregisterProvider('gpt')
    expect(registry.getProvider('openai')).toBe(mockProvider)
    expect(registry.getProvider('gpt')).toBeUndefined()
    expect(registry.getProvider('chatgpt')).toBe(mockProvider)
  })
  it('should handle unregistering non-existent provider', () => {
    expect(() => registry.unregisterProvider('nonexistent')).not.toThrow()
  })
  it('should return this for chaining', () => {
    registry.registerProvider('openai', mockProvider)
    const result = registry.unregisterProvider('openai')
    expect(result).toBe(registry)
  })
  // Renamed from 'should rebuild registry after unregistration': the
  // assertion verifies that NO rebuild happens once the last provider is
  // gone — the old title stated the opposite of what was being checked.
  it('should not rebuild registry when the last provider is removed', () => {
    registry.registerProvider('openai', mockProvider)
    vi.clearAllMocks()
    registry.unregisterProvider('openai')
    // No rebuild when the provider map is empty
    expect(createProviderRegistry).not.toHaveBeenCalled()
  })
})
describe('Model Resolution', () => {
beforeEach(() => {
registry.registerProvider('openai', mockProvider)
})
// The 'as any' casts bypass the template-literal model-id type so plain
// strings can be used as fully-qualified ids in tests.
it('should resolve language model', () => {
const modelId = `openai${DEFAULT_SEPARATOR}gpt-4` as any
const result = registry.languageModel(modelId)
expect(result).toBe(mockLanguageModel)
})
it('should resolve embedding model', () => {
const modelId = `openai${DEFAULT_SEPARATOR}text-embedding-3-small` as any
const result = registry.embeddingModel(modelId)
expect(result).toBe(mockEmbeddingModel)
})
it('should resolve image model', () => {
const modelId = `openai${DEFAULT_SEPARATOR}dall-e-3` as any
const result = registry.imageModel(modelId)
expect(result).toBe(mockImageModel)
})
it('should resolve transcription model', () => {
const modelId = `openai${DEFAULT_SEPARATOR}whisper-1` as any
registry.transcriptionModel(modelId)
// Verify it calls through to the mock registry
expect(createProviderRegistry).toHaveBeenCalled()
})
it('should resolve speech model', () => {
const modelId = `openai${DEFAULT_SEPARATOR}tts-1` as any
registry.speechModel(modelId)
expect(createProviderRegistry).toHaveBeenCalled()
})
it('should throw error when no providers registered', () => {
const emptyRegistry = new RegistryManagement()
expect(() => emptyRegistry.languageModel('openai|gpt-4' as any)).toThrow('No providers registered')
})
})
describe('Alias Management', () => {
// Aliases map alternative ids onto a canonical provider id.
it('should resolve provider ID from alias', () => {
registry.registerProvider('openai', mockProvider, ['gpt'])
const realId = registry.resolveProviderId('gpt')
expect(realId).toBe('openai')
})
it('should return same ID if not an alias', () => {
registry.registerProvider('openai', mockProvider)
const realId = registry.resolveProviderId('openai')
expect(realId).toBe('openai')
})
it('should check if ID is alias', () => {
registry.registerProvider('openai', mockProvider, ['gpt'])
expect(registry.isAlias('gpt')).toBe(true)
expect(registry.isAlias('openai')).toBe(false)
})
it('should get all alias mappings', () => {
const provider2 = { ...mockProvider }
registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt'])
registry.registerProvider('anthropic', provider2, ['claude'])
const aliases = registry.getAllAliases()
// Check that all aliases are present
expect(aliases['gpt']).toBe('openai')
expect(aliases['chatgpt']).toBe('openai')
expect(aliases['claude']).toBe('anthropic')
})
it('should return empty object when no aliases', () => {
registry.registerProvider('openai', mockProvider)
const aliases = registry.getAllAliases()
expect(aliases).toEqual({})
})
})
describe('Registry State', () => {
// hasProviders / clear lifecycle.
it('should check if has providers', () => {
expect(registry.hasProviders()).toBe(false)
registry.registerProvider('openai', mockProvider)
expect(registry.hasProviders()).toBe(true)
})
it('should clear all providers', () => {
registry.registerProvider('openai', mockProvider, ['gpt'])
registry.registerProvider('anthropic', mockProvider)
registry.clear()
expect(registry.hasProviders()).toBe(false)
expect(registry.getRegisteredProviders()).toEqual([])
expect(registry.getAllAliases()).toEqual({})
})
it('should return this after clear for chaining', () => {
const result = registry.clear()
expect(result).toBe(registry)
})
})
describe('Registry Rebuilding', () => {
// The inner AI-SDK registry is rebuilt on add/remove; when the provider
// map becomes empty, model resolution must fail fast instead.
it('should rebuild registry when provider added', () => {
registry.registerProvider('openai', mockProvider)
expect(createProviderRegistry).toHaveBeenCalledTimes(1)
})
it('should rebuild registry when provider removed', () => {
registry.registerProvider('openai', mockProvider)
registry.registerProvider('anthropic', mockProvider)
vi.clearAllMocks()
registry.unregisterProvider('openai')
expect(createProviderRegistry).toHaveBeenCalledTimes(1)
})
it('should set registry to null when all providers removed', () => {
registry.registerProvider('openai', mockProvider)
registry.unregisterProvider('openai')
expect(() => registry.languageModel('any|model' as any)).toThrow('No providers registered')
})
it('should rebuild with correct separator', () => {
const customRegistry = new RegistryManagement({ separator: ':' })
customRegistry.registerProvider('openai', mockProvider)
expect(createProviderRegistry).toHaveBeenCalledWith(expect.any(Object), { separator: ':' })
})
})
describe('Global Registry Instance', () => {
it('should have a global instance with default separator', async () => {
// Dynamic import avoids shadowing the mocked module-level bindings
const module = await import('../RegistryManagement')
expect(module.globalRegistryManagement).toBeInstanceOf(RegistryManagement)
})
it('should have DEFAULT_SEPARATOR exported', () => {
expect(DEFAULT_SEPARATOR).toBe('|')
})
})
describe('Edge Cases', () => {
it('should handle registering same provider twice', () => {
registry.registerProvider('openai', mockProvider)
const provider2 = { ...mockProvider }
registry.registerProvider('openai', provider2)
// Re-registration replaces the previous provider
expect(registry.getProvider('openai')).toBe(provider2)
})
it('should handle alias conflicts (first wins)', () => {
registry.registerProvider('provider1', mockProvider, ['shared-alias'])
registry.registerProvider('provider2', mockProvider, ['shared-alias'])
// First registered alias wins (the implementation doesn't override)
expect(registry.resolveProviderId('shared-alias')).toBe('provider1')
})
it('should handle empty alias array', () => {
registry.registerProvider('openai', mockProvider, [])
expect(registry.getAllAliases()).toEqual({})
})
it('should handle null registry operations gracefully', () => {
const emptyRegistry = new RegistryManagement()
expect(() => emptyRegistry.languageModel('test|model' as any)).toThrow('No providers registered')
expect(() => emptyRegistry.embeddingModel('test|embed' as any)).toThrow('No providers registered')
expect(() => emptyRegistry.imageModel('test|image' as any)).toThrow('No providers registered')
})
it('should handle special characters in provider IDs', () => {
const specialIds = ['provider-1', 'provider_2', 'provider.3', 'provider:4']
specialIds.forEach((id) => {
registry.registerProvider(id, mockProvider)
expect(registry.getProvider(id)).toBe(mockProvider)
})
})
})
describe('Concurrent Operations', () => {
// NOTE(review): registerProvider is synchronous; Promise.resolve only wraps
// already-computed results, so these tests do not exercise true concurrency
// or interleaving — they verify sequential correctness under a Promise API.
it('should handle concurrent registrations', () => {
const promises = [
Promise.resolve(registry.registerProvider('provider1', mockProvider)),
Promise.resolve(registry.registerProvider('provider2', mockProvider)),
Promise.resolve(registry.registerProvider('provider3', mockProvider))
]
return Promise.all(promises).then(() => {
expect(registry.getRegisteredProviders()).toHaveLength(3)
})
})
it('should handle mixed operations', () => {
registry.registerProvider('openai', mockProvider)
registry.registerProvider('anthropic', mockProvider)
const provider1 = registry.getProvider('openai')
registry.unregisterProvider('anthropic')
const provider2 = registry.getProvider('openai')
// Unregistering one provider must not disturb another
expect(provider1).toBe(provider2)
})
})
describe('Type Safety', () => {
// Resolved models must carry the v3 spec version and their per-kind
// execution entry points.
it('should enforce model ID format with template literal types', () => {
registry.registerProvider('openai', mockProvider)
// These should be type-safe
const validId = 'openai|gpt-4' as `${string}${typeof DEFAULT_SEPARATOR}${string}`
expect(() => registry.languageModel(validId)).not.toThrow()
})
it('should return properly typed LanguageModelV3', () => {
registry.registerProvider('openai', mockProvider)
const model = registry.languageModel('openai|gpt-4' as any)
expect(model.specificationVersion).toBe('v3')
expect(model).toHaveProperty('doGenerate')
expect(model).toHaveProperty('doStream')
})
it('should return properly typed EmbeddingModelV3', () => {
registry.registerProvider('openai', mockProvider)
const model = registry.embeddingModel('openai|ada-002' as any)
expect(model.specificationVersion).toBe('v3')
expect(model).toHaveProperty('doEmbed')
})
it('should return properly typed ImageModelV3', () => {
registry.registerProvider('openai', mockProvider)
const model = registry.imageModel('openai|dall-e-3' as any)
expect(model.specificationVersion).toBe('v3')
expect(model).toHaveProperty('doGenerate')
})
})
describe('Memory Management', () => {
// clear() and unregisterProvider() must leave no stale provider or alias
// entries behind.
it('should properly clean up on clear', () => {
registry.registerProvider('p1', mockProvider, ['a1'])
registry.registerProvider('p2', mockProvider, ['a2'])
registry.clear()
expect(registry.getRegisteredProviders()).toHaveLength(0)
expect(Object.keys(registry.getAllAliases())).toHaveLength(0)
})
it('should properly clean up on unregister', () => {
registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt'])
registry.unregisterProvider('openai')
expect(registry.getProvider('openai')).toBeUndefined()
expect(registry.isAlias('gpt')).toBe(false)
expect(registry.isAlias('chatgpt')).toBe(false)
})
})
})

View File

@ -0,0 +1,650 @@
/**
* RuntimeExecutor.resolveModel Comprehensive Tests
* Tests the private resolveModel and resolveImageModel methods through public APIs
* Covers model resolution, middleware application, and type validation
*/
import type { ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider'
import { generateImage, generateText, streamText } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import {
createMockImageModel,
createMockLanguageModel,
createMockMiddleware,
mockProviderConfigs
} from '../../../__tests__'
import { globalModelResolver } from '../../models'
import { ImageModelResolutionError } from '../errors'
import { RuntimeExecutor } from '../executor'
// Mock AI SDK
// NOTE: vi.mock factories are hoisted by vitest above the imports, so they
// must not close over variables declared later in this module.
vi.mock('ai', async (importOriginal) => {
  const actual = (await importOriginal()) as Record<string, unknown>
  return {
    ...actual,
    generateText: vi.fn(),
    streamText: vi.fn(),
    generateImage: vi.fn(),
    // Tags the wrapped model so tests can detect that middleware was applied
    // and inspect which middleware list was passed in.
    wrapLanguageModel: vi.fn((config: any) => ({
      ...config.model,
      _middlewareApplied: true,
      middleware: config.middleware
    }))
  }
})
// Stub the global registry so no real provider registration is required.
vi.mock('../../providers/RegistryManagement', () => ({
  globalRegistryManagement: {
    languageModel: vi.fn(),
    imageModel: vi.fn()
  },
  DEFAULT_SEPARATOR: '|'
}))
// Stub the model resolver; individual tests control its resolved values.
vi.mock('../../models', () => ({
  globalModelResolver: {
    resolveLanguageModel: vi.fn(),
    resolveImageModel: vi.fn()
  }
}))
describe('RuntimeExecutor - Model Resolution', () => {
let executor: RuntimeExecutor<'openai'>
let mockLanguageModel: LanguageModelV3
let mockImageModel: ImageModelV3
  // Fresh executor, mock models, and default happy-path mock resolutions per test.
  beforeEach(() => {
    vi.clearAllMocks()
    executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
    mockLanguageModel = createMockLanguageModel({
      specificationVersion: 'v3',
      provider: 'openai',
      modelId: 'gpt-4'
    })
    mockImageModel = createMockImageModel({
      specificationVersion: 'v3',
      provider: 'openai',
      modelId: 'dall-e-3'
    })
    // Default resolutions; individual tests override these where needed.
    vi.mocked(globalModelResolver.resolveLanguageModel).mockResolvedValue(mockLanguageModel)
    vi.mocked(globalModelResolver.resolveImageModel).mockResolvedValue(mockImageModel)
    vi.mocked(generateText).mockResolvedValue({
      text: 'Test response',
      finishReason: 'stop',
      usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }
    } as any)
    // NOTE(review): the real streamText returns its result synchronously;
    // mockResolvedValue wraps it in a Promise — confirm the executor awaits it.
    vi.mocked(streamText).mockResolvedValue({
      textStream: (async function* () {
        yield 'test'
      })()
    } as any)
    vi.mocked(generateImage).mockResolvedValue({
      image: {
        base64: 'test-image',
        uint8Array: new Uint8Array([1, 2, 3]),
        mimeType: 'image/png'
      },
      warnings: []
    } as any)
  })
  // String model IDs must be routed through globalModelResolver with the
  // executor's own provider id and settings.
  describe('Language Model Resolution (String modelId)', () => {
    it('should resolve string modelId using globalModelResolver', async () => {
      await executor.generateText({
        model: 'gpt-4',
        messages: [{ role: 'user', content: 'Hello' }]
      })
      // Resolver signature: (modelId, providerId, providerSettings, middlewares)
      expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
        'gpt-4',
        'openai',
        mockProviderConfigs.openai,
        undefined
      )
    })
    it('should pass provider settings to model resolver', async () => {
      const customExecutor = RuntimeExecutor.create('anthropic', {
        apiKey: 'sk-test',
        baseURL: 'https://api.anthropic.com'
      })
      vi.mocked(globalModelResolver.resolveLanguageModel).mockResolvedValue(mockLanguageModel)
      await customExecutor.generateText({
        model: 'claude-3-5-sonnet',
        messages: [{ role: 'user', content: 'Test' }]
      })
      // The exact settings object handed to create() must reach the resolver.
      expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
        'claude-3-5-sonnet',
        'anthropic',
        {
          apiKey: 'sk-test',
          baseURL: 'https://api.anthropic.com'
        },
        undefined
      )
    })
    it('should resolve traditional format modelId', async () => {
      await executor.generateText({
        model: 'gpt-4-turbo',
        messages: [{ role: 'user', content: 'Test' }]
      })
      expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
        'gpt-4-turbo',
        'openai',
        expect.any(Object),
        undefined
      )
    })
    it('should resolve namespaced format modelId', async () => {
      // '|'-separated IDs are passed through verbatim; splitting is the resolver's job.
      await executor.generateText({
        model: 'aihubmix|anthropic|claude-3',
        messages: [{ role: 'user', content: 'Test' }]
      })
      expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
        'aihubmix|anthropic|claude-3',
        'openai',
        expect.any(Object),
        undefined
      )
    })
    it('should use resolved model for generation', async () => {
      await executor.generateText({
        model: 'gpt-4',
        messages: [{ role: 'user', content: 'Hello' }]
      })
      // The resolver's return value is what gets handed to the AI SDK.
      expect(generateText).toHaveBeenCalledWith(
        expect.objectContaining({
          model: mockLanguageModel
        })
      )
    })
    it('should work with streamText', async () => {
      await executor.streamText({
        model: 'gpt-4',
        messages: [{ role: 'user', content: 'Stream test' }]
      })
      expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalled()
      expect(streamText).toHaveBeenCalledWith(
        expect.objectContaining({
          model: mockLanguageModel
        })
      )
    })
  })
  // Pre-resolved model objects bypass the resolver entirely and are used as-is.
  describe('Language Model Resolution (Direct Model Object)', () => {
    it('should accept pre-resolved V3 model object', async () => {
      const directModel: LanguageModelV3 = createMockLanguageModel({
        specificationVersion: 'v3',
        provider: 'openai',
        modelId: 'gpt-4'
      })
      await executor.generateText({
        model: directModel,
        messages: [{ role: 'user', content: 'Test' }]
      })
      // Should NOT call resolver for direct model
      expect(globalModelResolver.resolveLanguageModel).not.toHaveBeenCalled()
      // Should use the model directly
      expect(generateText).toHaveBeenCalledWith(
        expect.objectContaining({
          model: directModel
        })
      )
    })
    it('should accept V2 model object without validation (plugin engine handles it)', async () => {
      const v2Model = {
        specificationVersion: 'v2',
        provider: 'openai',
        modelId: 'gpt-4',
        doGenerate: vi.fn()
      } as any
      // The plugin engine accepts any model object directly without validation
      // V3 validation only happens when resolving string modelIds
      await expect(
        executor.generateText({
          model: v2Model,
          messages: [{ role: 'user', content: 'Test' }]
        })
      ).resolves.toBeDefined()
    })
    it('should accept any model object without checking specification version', async () => {
      const v2Model = {
        specificationVersion: 'v2',
        provider: 'custom-provider',
        modelId: 'custom-model',
        doGenerate: vi.fn()
      } as any
      // Direct model objects bypass validation
      // The executor trusts that plugins/users provide valid models
      await expect(
        executor.generateText({
          model: v2Model,
          messages: [{ role: 'user', content: 'Test' }]
        })
      ).resolves.toBeDefined()
    })
    it('should accept model object with streamText', async () => {
      const directModel = createMockLanguageModel({
        specificationVersion: 'v3'
      })
      await executor.streamText({
        model: directModel,
        messages: [{ role: 'user', content: 'Stream' }]
      })
      // Same bypass behavior for the streaming path.
      expect(globalModelResolver.resolveLanguageModel).not.toHaveBeenCalled()
      expect(streamText).toHaveBeenCalledWith(
        expect.objectContaining({
          model: directModel
        })
      )
    })
  })
  // Middlewares supplied per-call are forwarded (in order) as the resolver's
  // fourth argument; no middlewares means an explicit `undefined`.
  describe('Middleware Application', () => {
    it('should apply middlewares to string modelId', async () => {
      const testMiddleware = createMockMiddleware({ name: 'test-middleware' })
      await executor.generateText(
        {
          model: 'gpt-4',
          messages: [{ role: 'user', content: 'Test' }]
        },
        { middlewares: [testMiddleware] }
      )
      expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [
        testMiddleware
      ])
    })
    it('should apply multiple middlewares in order', async () => {
      const middleware1 = createMockMiddleware({ name: 'middleware-1' })
      const middleware2 = createMockMiddleware({ name: 'middleware-2' })
      await executor.generateText(
        {
          model: 'gpt-4',
          messages: [{ role: 'user', content: 'Test' }]
        },
        { middlewares: [middleware1, middleware2] }
      )
      // Array order must be preserved exactly as supplied by the caller.
      expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [
        middleware1,
        middleware2
      ])
    })
    it('should pass middlewares to model resolver for string modelIds', async () => {
      const testMiddleware = createMockMiddleware({ name: 'test-middleware' })
      await executor.generateText(
        {
          model: 'gpt-4', // String model ID
          messages: [{ role: 'user', content: 'Test' }]
        },
        { middlewares: [testMiddleware] }
      )
      // Middlewares are passed to the resolver for string modelIds
      expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [
        testMiddleware
      ])
    })
    it('should not apply middlewares when none provided', async () => {
      await executor.generateText({
        model: 'gpt-4',
        messages: [{ role: 'user', content: 'Test' }]
      })
      expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
        'gpt-4',
        'openai',
        expect.any(Object),
        undefined
      )
    })
    it('should handle empty middleware array', async () => {
      // An empty array is forwarded as-is, distinct from `undefined`.
      await executor.generateText(
        {
          model: 'gpt-4',
          messages: [{ role: 'user', content: 'Test' }]
        },
        { middlewares: [] }
      )
      expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [])
    })
    it('should work with middlewares in streamText', async () => {
      const middleware = createMockMiddleware({ name: 'stream-middleware' })
      await executor.streamText(
        {
          model: 'gpt-4',
          messages: [{ role: 'user', content: 'Stream' }]
        },
        { middlewares: [middleware] }
      )
      expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [
        middleware
      ])
    })
  })
  // Image models resolve through resolveImageModel (no middleware argument);
  // resolution failures are wrapped in ImageModelResolutionError.
  describe('Image Model Resolution', () => {
    it('should resolve string image modelId using globalModelResolver', async () => {
      await executor.generateImage({
        model: 'dall-e-3',
        prompt: 'A beautiful sunset'
      })
      expect(globalModelResolver.resolveImageModel).toHaveBeenCalledWith('dall-e-3', 'openai')
    })
    it('should accept direct ImageModelV3 object', async () => {
      const directImageModel: ImageModelV3 = createMockImageModel({
        specificationVersion: 'v3',
        provider: 'openai',
        modelId: 'dall-e-3'
      })
      await executor.generateImage({
        model: directImageModel,
        prompt: 'Test image'
      })
      // Direct objects bypass the resolver, mirroring the language-model path.
      expect(globalModelResolver.resolveImageModel).not.toHaveBeenCalled()
      expect(generateImage).toHaveBeenCalledWith(
        expect.objectContaining({
          model: directImageModel
        })
      )
    })
    it('should resolve namespaced image model ID', async () => {
      await executor.generateImage({
        model: 'aihubmix|openai|dall-e-3',
        prompt: 'Namespaced image'
      })
      expect(globalModelResolver.resolveImageModel).toHaveBeenCalledWith('aihubmix|openai|dall-e-3', 'openai')
    })
    it('should throw ImageModelResolutionError on resolution failure', async () => {
      const resolutionError = new Error('Model not found')
      vi.mocked(globalModelResolver.resolveImageModel).mockRejectedValue(resolutionError)
      await expect(
        executor.generateImage({
          model: 'invalid-model',
          prompt: 'Test'
        })
      ).rejects.toThrow(ImageModelResolutionError)
    })
    it('should include modelId and providerId in ImageModelResolutionError', async () => {
      vi.mocked(globalModelResolver.resolveImageModel).mockRejectedValue(new Error('Not found'))
      try {
        await executor.generateImage({
          model: 'invalid-model',
          prompt: 'Test'
        })
        expect.fail('Should have thrown ImageModelResolutionError')
      } catch (error) {
        // The wrapped error carries enough context to identify the failing model.
        expect(error).toBeInstanceOf(ImageModelResolutionError)
        const imgError = error as ImageModelResolutionError
        expect(imgError.message).toContain('invalid-model')
        expect(imgError.providerId).toBe('openai')
      }
    })
    it('should extract modelId from direct model object in error', async () => {
      const directModel = createMockImageModel({
        modelId: 'direct-model',
        doGenerate: vi.fn().mockRejectedValue(new Error('Generation failed'))
      })
      vi.mocked(generateImage).mockRejectedValue(new Error('Generation failed'))
      // Generation failures (as opposed to resolution failures) still propagate.
      await expect(
        executor.generateImage({
          model: directModel,
          prompt: 'Test'
        })
      ).rejects.toThrow()
    })
  })
  // Each executor flavor must forward its own provider id to the resolver.
  describe('Provider-Specific Model Resolution', () => {
    it('should resolve models for OpenAI provider', async () => {
      const openaiExecutor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
      await openaiExecutor.generateText({
        model: 'gpt-4',
        messages: [{ role: 'user', content: 'Test' }]
      })
      expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
        'gpt-4',
        'openai',
        expect.any(Object),
        undefined
      )
    })
    it('should resolve models for Anthropic provider', async () => {
      const anthropicExecutor = RuntimeExecutor.create('anthropic', mockProviderConfigs.anthropic)
      await anthropicExecutor.generateText({
        model: 'claude-3-5-sonnet',
        messages: [{ role: 'user', content: 'Test' }]
      })
      expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
        'claude-3-5-sonnet',
        'anthropic',
        expect.any(Object),
        undefined
      )
    })
    it('should resolve models for Google provider', async () => {
      const googleExecutor = RuntimeExecutor.create('google', mockProviderConfigs.google)
      await googleExecutor.generateText({
        model: 'gemini-2.0-flash',
        messages: [{ role: 'user', content: 'Test' }]
      })
      expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
        'gemini-2.0-flash',
        'google',
        expect.any(Object),
        undefined
      )
    })
    it('should resolve models for OpenAI-compatible provider', async () => {
      // createOpenAICompatible takes only a config; the provider id is implied.
      const compatibleExecutor = RuntimeExecutor.createOpenAICompatible(mockProviderConfigs['openai-compatible'])
      await compatibleExecutor.generateText({
        model: 'custom-model',
        messages: [{ role: 'user', content: 'Test' }]
      })
      expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
        'custom-model',
        'openai-compatible',
        expect.any(Object),
        undefined
      )
    })
  })
describe('OpenAI Mode Handling', () => {
it('should pass mode setting to model resolver', async () => {
const executorWithMode = RuntimeExecutor.create('openai', {
...mockProviderConfigs.openai,
mode: 'chat'
})
await executorWithMode.generateText({
model: 'gpt-4',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
'gpt-4',
'openai',
expect.objectContaining({
mode: 'chat'
}),
undefined
)
})
it('should handle responses mode', async () => {
const executorWithMode = RuntimeExecutor.create('openai', {
...mockProviderConfigs.openai,
mode: 'responses'
})
await executorWithMode.generateText({
model: 'gpt-4',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
'gpt-4',
'openai',
expect.objectContaining({
mode: 'responses'
}),
undefined
)
})
})
  // Degenerate inputs: empty IDs, resolver failures, concurrency, malformed models.
  describe('Edge Cases', () => {
    it('should handle empty string modelId', async () => {
      // An empty string is still a string modelId, so it goes to the resolver unmodified.
      await executor.generateText({
        model: '',
        messages: [{ role: 'user', content: 'Test' }]
      })
      expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('', 'openai', expect.any(Object), undefined)
    })
    it('should handle model resolution errors gracefully', async () => {
      vi.mocked(globalModelResolver.resolveLanguageModel).mockRejectedValue(new Error('Model not found'))
      await expect(
        executor.generateText({
          model: 'nonexistent-model',
          messages: [{ role: 'user', content: 'Test' }]
        })
      ).rejects.toThrow('Model not found')
    })
    it('should handle concurrent model resolutions', async () => {
      const promises = [
        executor.generateText({ model: 'gpt-4', messages: [{ role: 'user', content: 'Test 1' }] }),
        executor.generateText({ model: 'gpt-4-turbo', messages: [{ role: 'user', content: 'Test 2' }] }),
        executor.generateText({ model: 'gpt-3.5-turbo', messages: [{ role: 'user', content: 'Test 3' }] })
      ]
      await Promise.all(promises)
      // One resolver call per request; no caching or deduplication is assumed here.
      expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledTimes(3)
    })
    it('should accept model object even without specificationVersion', async () => {
      const invalidModel = {
        provider: 'test',
        modelId: 'test-model'
        // Missing specificationVersion
      } as any
      // Plugin engine doesn't validate direct model objects
      // It's the user's responsibility to provide valid models
      await expect(
        executor.generateText({
          model: invalidModel,
          messages: [{ role: 'user', content: 'Test' }]
        })
      ).resolves.toBeDefined()
    })
  })
  // Spec-version expectations: resolved models are V3, direct models are trusted.
  describe('Type Safety Validation', () => {
    it('should ensure resolved model is LanguageModelV3', async () => {
      const v3Model = createMockLanguageModel({
        specificationVersion: 'v3'
      })
      vi.mocked(globalModelResolver.resolveLanguageModel).mockResolvedValue(v3Model)
      await executor.generateText({
        model: 'gpt-4',
        messages: [{ role: 'user', content: 'Test' }]
      })
      expect(generateText).toHaveBeenCalledWith(
        expect.objectContaining({
          model: expect.objectContaining({
            specificationVersion: 'v3'
          })
        })
      )
    })
    it('should not enforce specification version for direct models', async () => {
      const v1Model = {
        specificationVersion: 'v1',
        provider: 'test',
        modelId: 'test'
      } as any
      // Direct models bypass validation in the plugin engine
      // Only resolved models (from string IDs) are validated
      await expect(
        executor.generateText({
          model: v1Model,
          messages: [{ role: 'user', content: 'Test' }]
        })
      ).resolves.toBeDefined()
    })
  })
})

View File

@ -0,0 +1,867 @@
/**
* PluginEngine Comprehensive Tests
* Tests plugin lifecycle, execution order, and coordination
* Covers both streaming and non-streaming execution paths
*/
import type { ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createMockImageModel, createMockLanguageModel } from '../../../__tests__'
import { ModelResolutionError, RecursiveDepthError } from '../../errors'
import type { AiPlugin, GenerateTextParams, GenerateTextResult } from '../../plugins'
import { PluginEngine } from '../pluginEngine'
describe('PluginEngine', () => {
let engine: PluginEngine<'openai'>
let mockLanguageModel: LanguageModelV3
let mockImageModel: ImageModelV3
  // Fresh mocks per test; each test constructs its own PluginEngine instance.
  beforeEach(() => {
    vi.clearAllMocks()
    mockLanguageModel = createMockLanguageModel({
      provider: 'openai',
      modelId: 'gpt-4'
    })
    mockImageModel = createMockImageModel({
      provider: 'openai',
      modelId: 'dall-e-3'
    })
  })
describe('Plugin Registration and Management', () => {
it('should create engine with empty plugins', () => {
engine = new PluginEngine('openai', [])
expect(engine.getPlugins()).toEqual([])
})
it('should create engine with initial plugins', () => {
const plugin1: AiPlugin = { name: 'plugin-1' }
const plugin2: AiPlugin = { name: 'plugin-2' }
engine = new PluginEngine('openai', [plugin1, plugin2])
expect(engine.getPlugins()).toHaveLength(2)
expect(engine.getPlugins()).toEqual([plugin1, plugin2])
})
it('should add plugin with use()', () => {
engine = new PluginEngine('openai', [])
const plugin: AiPlugin = { name: 'test-plugin' }
const result = engine.use(plugin)
expect(result).toBe(engine) // Chainable
expect(engine.getPlugins()).toContain(plugin)
})
it('should add multiple plugins with usePlugins()', () => {
engine = new PluginEngine('openai', [])
const plugins: AiPlugin[] = [{ name: 'plugin-1' }, { name: 'plugin-2' }, { name: 'plugin-3' }]
const result = engine.usePlugins(plugins)
expect(result).toBe(engine) // Chainable
expect(engine.getPlugins()).toHaveLength(3)
})
it('should support method chaining', () => {
engine = new PluginEngine('openai', [])
engine
.use({ name: 'plugin-1' })
.use({ name: 'plugin-2' })
.usePlugins([{ name: 'plugin-3' }])
expect(engine.getPlugins()).toHaveLength(3)
})
it('should remove plugin by name', () => {
const plugin1: AiPlugin = { name: 'plugin-1' }
const plugin2: AiPlugin = { name: 'plugin-2' }
engine = new PluginEngine('openai', [plugin1, plugin2])
engine.removePlugin('plugin-1')
expect(engine.getPlugins()).toHaveLength(1)
expect(engine.getPlugins()[0]).toBe(plugin2)
})
it('should not error when removing non-existent plugin', () => {
engine = new PluginEngine('openai', [{ name: 'existing' }])
engine.removePlugin('non-existent')
expect(engine.getPlugins()).toHaveLength(1)
})
it('should get plugin statistics', () => {
const plugins: AiPlugin[] = [
{ name: 'plugin-1', enforce: 'pre' },
{ name: 'plugin-2' },
{ name: 'plugin-3', enforce: 'post' }
]
engine = new PluginEngine('openai', plugins)
const stats = engine.getPluginStats()
expect(stats).toHaveProperty('total')
expect(stats.total).toBe(3)
})
it('should preserve plugin order', () => {
const plugins: AiPlugin[] = [{ name: 'first' }, { name: 'second' }, { name: 'third' }]
engine = new PluginEngine('openai', plugins)
const retrieved = engine.getPlugins()
expect(retrieved[0].name).toBe('first')
expect(retrieved[1].name).toBe('second')
expect(retrieved[2].name).toBe('third')
})
})
  // Hook ordering and success/failure lifecycle for the non-streaming path.
  describe('Plugin Lifecycle - Non-Streaming', () => {
    it('should execute all plugin hooks in correct order', async () => {
      const executionOrder: string[] = []
      const plugin: AiPlugin = {
        name: 'lifecycle-test',
        configureContext: vi.fn(async () => {
          executionOrder.push('configureContext')
        }),
        onRequestStart: vi.fn(async () => {
          executionOrder.push('onRequestStart')
        }),
        resolveModel: vi.fn(async () => {
          executionOrder.push('resolveModel')
          return mockLanguageModel
        }),
        transformParams: vi.fn(async (params) => {
          executionOrder.push('transformParams')
          return params
        }),
        transformResult: vi.fn(async (result) => {
          executionOrder.push('transformResult')
          return result
        }),
        onRequestEnd: vi.fn(async () => {
          executionOrder.push('onRequestEnd')
        })
      }
      engine = new PluginEngine('openai', [plugin])
      const mockExecutor = vi.fn().mockResolvedValue({ text: 'test', finishReason: 'stop' })
      await engine.executeWithPlugins<GenerateTextParams, GenerateTextResult>(
        'generateText',
        { model: 'gpt-4', messages: [] },
        mockExecutor
      )
      // The canonical hook order for a successful non-streaming request.
      expect(executionOrder).toEqual([
        'configureContext',
        'onRequestStart',
        'resolveModel',
        'transformParams',
        'transformResult',
        'onRequestEnd'
      ])
    })
    it('should call configureContext before other hooks', async () => {
      const configureContextSpy = vi.fn()
      const onRequestStartSpy = vi.fn()
      const plugin: AiPlugin = {
        name: 'test',
        configureContext: configureContextSpy,
        onRequestStart: onRequestStartSpy,
        resolveModel: vi.fn().mockResolvedValue(mockLanguageModel)
      }
      engine = new PluginEngine('openai', [plugin])
      await engine.executeWithPlugins(
        'generateText',
        { model: 'gpt-4', messages: [] },
        vi.fn().mockResolvedValue({ text: 'test' })
      )
      expect(configureContextSpy).toHaveBeenCalled()
      expect(onRequestStartSpy).toHaveBeenCalled()
      // configureContext should be called before onRequestStart
      expect(configureContextSpy.mock.invocationCallOrder[0]).toBeLessThan(
        onRequestStartSpy.mock.invocationCallOrder[0]
      )
    })
    it('should execute onRequestEnd after successful execution', async () => {
      const onRequestEndSpy = vi.fn()
      const plugin: AiPlugin = {
        name: 'test',
        resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
        onRequestEnd: onRequestEndSpy
      }
      engine = new PluginEngine('openai', [plugin])
      await engine.executeWithPlugins(
        'generateText',
        { model: 'gpt-4', messages: [] },
        vi.fn().mockResolvedValue({ text: 'result' })
      )
      // onRequestEnd receives the execution context and the final result.
      expect(onRequestEndSpy).toHaveBeenCalledWith(
        expect.any(Object), // context
        expect.objectContaining({ text: 'result' })
      )
    })
    it('should execute onError on failure', async () => {
      const onErrorSpy = vi.fn()
      const testError = new Error('Test error')
      const plugin: AiPlugin = {
        name: 'error-handler',
        resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
        onError: onErrorSpy
      }
      engine = new PluginEngine('openai', [plugin])
      // The engine notifies onError but still rethrows the original error.
      await expect(
        engine.executeWithPlugins(
          'generateText',
          { model: 'gpt-4', messages: [] },
          vi.fn().mockRejectedValue(testError)
        )
      ).rejects.toThrow('Test error')
      expect(onErrorSpy).toHaveBeenCalledWith(
        testError,
        expect.any(Object) // context
      )
    })
    it('should not call onRequestEnd when error occurs', async () => {
      const onRequestEndSpy = vi.fn()
      const plugin: AiPlugin = {
        name: 'test',
        resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
        onRequestEnd: onRequestEndSpy
      }
      engine = new PluginEngine('openai', [plugin])
      await expect(
        engine.executeWithPlugins(
          'generateText',
          { model: 'gpt-4', messages: [] },
          vi.fn().mockRejectedValue(new Error('Execution error'))
        )
      ).rejects.toThrow()
      // onRequestEnd is a success-only hook; onError covers the failure path.
      expect(onRequestEndSpy).not.toHaveBeenCalled()
    })
  })
  // Plugin-driven model resolution: first resolver wins, direct objects skip it,
  // and a missing or null resolution is a ModelResolutionError.
  describe('Model Resolution', () => {
    it('should resolve string model through plugin', async () => {
      const resolveModelSpy = vi.fn().mockResolvedValue(mockLanguageModel)
      const plugin: AiPlugin = {
        name: 'resolver',
        resolveModel: resolveModelSpy
      }
      engine = new PluginEngine('openai', [plugin])
      await engine.executeWithPlugins(
        'generateText',
        { model: 'gpt-4', messages: [] },
        vi.fn().mockResolvedValue({ text: 'test' })
      )
      // resolveModel receives the raw model id plus the execution context.
      expect(resolveModelSpy).toHaveBeenCalledWith('gpt-4', expect.any(Object))
    })
    it('should use first plugin that resolves model', async () => {
      const resolver1 = vi.fn().mockResolvedValue(mockLanguageModel)
      const resolver2 = vi.fn().mockResolvedValue(mockLanguageModel)
      const plugin1: AiPlugin = { name: 'resolver-1', resolveModel: resolver1 }
      const plugin2: AiPlugin = { name: 'resolver-2', resolveModel: resolver2 }
      engine = new PluginEngine('openai', [plugin1, plugin2])
      await engine.executeWithPlugins(
        'generateText',
        { model: 'gpt-4', messages: [] },
        vi.fn().mockResolvedValue({ text: 'test' })
      )
      expect(resolver1).toHaveBeenCalled()
      expect(resolver2).not.toHaveBeenCalled() // Should stop after first resolver
    })
    it('should throw ModelResolutionError if no plugin resolves model', async () => {
      const plugin: AiPlugin = {
        name: 'no-resolver'
        // No resolveModel hook
      }
      engine = new PluginEngine('openai', [plugin])
      await expect(
        engine.executeWithPlugins(
          'generateText',
          { model: 'unknown-model', messages: [] },
          vi.fn().mockResolvedValue({ text: 'test' })
        )
      ).rejects.toThrow(ModelResolutionError)
    })
    it('should skip resolution for direct model objects', async () => {
      const resolveModelSpy = vi.fn()
      const plugin: AiPlugin = {
        name: 'resolver',
        resolveModel: resolveModelSpy
      }
      engine = new PluginEngine('openai', [plugin])
      await engine.executeWithPlugins(
        'generateText',
        { model: mockLanguageModel, messages: [] },
        vi.fn().mockResolvedValue({ text: 'test' })
      )
      // A pre-resolved model object never touches plugin resolvers.
      expect(resolveModelSpy).not.toHaveBeenCalled()
    })
    it('should throw if resolved model is null/undefined', async () => {
      const plugin: AiPlugin = {
        name: 'bad-resolver',
        resolveModel: vi.fn().mockResolvedValue(null)
      }
      engine = new PluginEngine('openai', [plugin])
      // A resolver that returns null counts as a failed resolution.
      await expect(
        engine.executeWithPlugins(
          'generateText',
          { model: 'gpt-4', messages: [] },
          vi.fn().mockResolvedValue({ text: 'test' })
        )
      ).rejects.toThrow(ModelResolutionError)
    })
  })
  // transformParams hooks run before the executor; multiple plugins compose,
  // each receiving the previous plugin's output.
  describe('Parameter Transformation', () => {
    it('should transform parameters through plugin', async () => {
      const transformParamsSpy = vi.fn().mockImplementation(async (params) => ({
        ...params,
        temperature: 0.8
      }))
      const plugin: AiPlugin = {
        name: 'transformer',
        resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
        transformParams: transformParamsSpy
      }
      engine = new PluginEngine('openai', [plugin])
      const mockExecutor = vi.fn().mockResolvedValue({ text: 'test' })
      await engine.executeWithPlugins('generateText', { model: 'gpt-4', messages: [] }, mockExecutor)
      // The executor receives the transformed params, not the originals.
      expect(mockExecutor).toHaveBeenCalledWith(
        expect.any(Object),
        expect.objectContaining({
          temperature: 0.8
        })
      )
    })
    it('should chain parameter transformations across plugins', async () => {
      const plugin1: AiPlugin = {
        name: 'transform-1',
        transformParams: vi.fn().mockImplementation(async (params) => ({
          ...params,
          temperature: 0.5
        }))
      }
      const plugin2: AiPlugin = {
        name: 'transform-2',
        transformParams: vi.fn().mockImplementation(async (params) => ({
          ...params,
          maxTokens: 100
        }))
      }
      engine = new PluginEngine('openai', [plugin1, plugin2])
      engine.usePlugins([{ name: 'resolver', resolveModel: vi.fn().mockResolvedValue(mockLanguageModel) }])
      const mockExecutor = vi.fn().mockResolvedValue({ text: 'test' })
      await engine.executeWithPlugins('generateText', { model: 'gpt-4', messages: [] }, mockExecutor)
      // Both transformations are visible in the final params object.
      expect(mockExecutor).toHaveBeenCalledWith(
        expect.any(Object),
        expect.objectContaining({
          temperature: 0.5,
          maxTokens: 100
        })
      )
    })
  })
  // transformResult hooks run after the executor; plugins compose in
  // registration order, each seeing the previous plugin's output.
  describe('Result Transformation', () => {
    it('should transform result through plugin', async () => {
      const transformResultSpy = vi.fn().mockImplementation(async (result) => ({
        ...result,
        text: `${result.text} [modified]`
      }))
      const plugin: AiPlugin = {
        name: 'result-transformer',
        resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
        transformResult: transformResultSpy
      }
      engine = new PluginEngine('openai', [plugin])
      const result = await engine.executeWithPlugins(
        'generateText',
        { model: 'gpt-4', messages: [] },
        vi.fn().mockResolvedValue({ text: 'original' })
      )
      expect(result.text).toBe('original [modified]')
    })
    it('should chain result transformations across plugins', async () => {
      const plugin1: AiPlugin = {
        name: 'transform-1',
        transformResult: vi.fn().mockImplementation(async (result) => ({
          ...result,
          text: `${result.text} + plugin1`
        }))
      }
      const plugin2: AiPlugin = {
        name: 'transform-2',
        transformResult: vi.fn().mockImplementation(async (result) => ({
          ...result,
          text: `${result.text} + plugin2`
        }))
      }
      engine = new PluginEngine('openai', [plugin1, plugin2])
      engine.usePlugins([{ name: 'resolver', resolveModel: vi.fn().mockResolvedValue(mockLanguageModel) }])
      const result = await engine.executeWithPlugins(
        'generateText',
        { model: 'gpt-4', messages: [] },
        vi.fn().mockResolvedValue({ text: 'base' })
      )
      // Suffix order proves plugin1 ran before plugin2.
      expect(result.text).toBe('base + plugin1 + plugin2')
    })
  })
  // context.recursiveCall lets a plugin re-enter the engine; depth is tracked,
  // capped (RecursiveDepthError), and restored after each nested call returns.
  describe('Recursive Calls', () => {
    it('should support recursive calls through context', async () => {
      let recursionCount = 0
      const plugin: AiPlugin = {
        name: 'recursive',
        resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
        transformParams: vi.fn().mockImplementation(async (params, context) => {
          // Re-enter at most twice; the counter is the external stop condition.
          if (recursionCount < 2 && context.recursiveCall) {
            recursionCount++
            await context.recursiveCall({ messages: [{ role: 'user', content: 'recursive' }] })
          }
          return params
        })
      }
      engine = new PluginEngine('openai', [plugin])
      await engine.executeWithPlugins(
        'generateText',
        { model: 'gpt-4', messages: [] },
        vi.fn().mockResolvedValue({ text: 'test' })
      )
      expect(recursionCount).toBe(2)
    })
    it('should track recursion depth', async () => {
      const depths: number[] = []
      const plugin: AiPlugin = {
        name: 'depth-tracker',
        resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
        transformParams: vi.fn().mockImplementation(async (params, context) => {
          depths.push(context.recursiveDepth)
          if (context.recursiveDepth < 3 && context.recursiveCall) {
            await context.recursiveCall({ messages: [] })
          }
          return params
        })
      }
      engine = new PluginEngine('openai', [plugin])
      await engine.executeWithPlugins(
        'generateText',
        { model: 'gpt-4', messages: [] },
        vi.fn().mockResolvedValue({ text: 'test' })
      )
      // Depth increments by one per nesting level, starting at 0.
      expect(depths).toEqual([0, 1, 2, 3])
    })
    it('should throw RecursiveDepthError when max depth exceeded', async () => {
      // Unconditional recursion must be stopped by the engine's depth cap.
      const plugin: AiPlugin = {
        name: 'infinite',
        resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
        transformParams: vi.fn().mockImplementation(async (params, context) => {
          if (context.recursiveCall) {
            await context.recursiveCall({ messages: [] })
          }
          return params
        })
      }
      engine = new PluginEngine('openai', [plugin])
      await expect(
        engine.executeWithPlugins(
          'generateText',
          { model: 'gpt-4', messages: [] },
          vi.fn().mockResolvedValue({ text: 'test' })
        )
      ).rejects.toThrow(RecursiveDepthError)
    })
    it('should restore recursion state after recursive call', async () => {
      const states: Array<{ depth: number; isRecursive: boolean }> = []
      const plugin: AiPlugin = {
        name: 'state-tracker',
        resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
        transformParams: vi.fn().mockImplementation(async (params, context) => {
          states.push({ depth: context.recursiveDepth, isRecursive: context.isRecursiveCall })
          if (context.recursiveDepth === 0 && context.recursiveCall) {
            await context.recursiveCall({ messages: [] })
            // Captured after the nested call returns: state must be rolled back.
            states.push({ depth: context.recursiveDepth, isRecursive: context.isRecursiveCall })
          }
          return params
        })
      }
      engine = new PluginEngine('openai', [plugin])
      await engine.executeWithPlugins(
        'generateText',
        { model: 'gpt-4', messages: [] },
        vi.fn().mockResolvedValue({ text: 'test' })
      )
      expect(states[0]).toEqual({ depth: 0, isRecursive: false })
      expect(states[1]).toEqual({ depth: 1, isRecursive: true })
      expect(states[2]).toEqual({ depth: 0, isRecursive: false })
    })
  })
  // The image path mirrors the text path: string ids go through plugin
  // resolvers, direct ImageModelV3 objects are used verbatim.
  describe('Image Model Execution', () => {
    it('should execute image generation with plugins', async () => {
      const plugin: AiPlugin = {
        name: 'image-plugin',
        resolveModel: vi.fn().mockResolvedValue(mockImageModel),
        transformParams: vi.fn().mockImplementation(async (params) => params)
      }
      engine = new PluginEngine('openai', [plugin])
      const mockExecutor = vi.fn().mockResolvedValue({
        image: { base64: 'test', uint8Array: new Uint8Array(), mimeType: 'image/png' }
      })
      await engine.executeImageWithPlugins('generateImage', { model: 'dall-e-3', prompt: 'test' }, mockExecutor)
      expect(plugin.resolveModel).toHaveBeenCalledWith('dall-e-3', expect.any(Object))
      expect(mockExecutor).toHaveBeenCalled()
    })
    it('should skip resolution for direct image model objects', async () => {
      const resolveModelSpy = vi.fn()
      const plugin: AiPlugin = {
        name: 'image-resolver',
        resolveModel: resolveModelSpy
      }
      engine = new PluginEngine('openai', [plugin])
      await engine.executeImageWithPlugins(
        'generateImage',
        { model: mockImageModel, prompt: 'test' },
        vi.fn().mockResolvedValue({ image: {} })
      )
      expect(resolveModelSpy).not.toHaveBeenCalled()
    })
  })
describe('Streaming Execution', () => {
  it('should execute streaming with plugins', async () => {
    // Streaming path: model is resolved via the plugin and the executor runs
    // exactly once with the transformed parameters.
    const plugin: AiPlugin = {
      name: 'stream-plugin',
      resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
      transformParams: vi.fn().mockImplementation(async (params) => params)
    }
    engine = new PluginEngine('openai', [plugin])
    const mockExecutor = vi.fn().mockResolvedValue({
      textStream: (async function* () {
        yield 'test'
      })()
    })
    await engine.executeStreamWithPlugins('streamText', { model: 'gpt-4', messages: [] }, mockExecutor)
    expect(plugin.resolveModel).toHaveBeenCalled()
    expect(mockExecutor).toHaveBeenCalledTimes(1)
  })
  it('should collect stream transforms from plugins', async () => {
    const mockTransform = vi.fn()
    const plugin: AiPlugin = {
      name: 'stream-transformer',
      resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
      transformStream: mockTransform
    }
    engine = new PluginEngine('openai', [plugin])
    const mockExecutor = vi.fn().mockResolvedValue({ textStream: (async function* () {})() })
    await engine.executeStreamWithPlugins('streamText', { model: 'gpt-4', messages: [] }, mockExecutor)
    // The executor's third argument is the collected stream-transform list.
    // The previous matcher, expect.arrayContaining([]), matches ANY array
    // (including an empty one), so it proved nothing about collection.
    // Inspect the actual call arguments instead.
    expect(mockExecutor).toHaveBeenCalledTimes(1)
    const [, , streamTransforms] = mockExecutor.mock.calls[0]
    expect(Array.isArray(streamTransforms)).toBe(true)
    // TODO(review): once the engine's transformStream contract is confirmed,
    // also assert the 'stream-transformer' transform is present in this array.
  })
})
describe('Context Management', () => {
  it('should create context with correct provider and model', async () => {
    // configureContext must see a context carrying the engine's provider id
    // and the originally requested model id.
    const configureContextSpy = vi.fn()
    const checkerPlugin: AiPlugin = {
      name: 'context-checker',
      configureContext: configureContextSpy,
      resolveModel: vi.fn().mockResolvedValue(mockLanguageModel)
    }
    engine = new PluginEngine('openai', [checkerPlugin])
    await engine.executeWithPlugins(
      'generateText',
      { model: 'gpt-4', messages: [] },
      vi.fn().mockResolvedValue({ text: 'test' })
    )
    const expectedContext = expect.objectContaining({
      providerId: 'openai',
      model: 'gpt-4'
    })
    expect(configureContextSpy).toHaveBeenCalledWith(expectedContext)
  })
  it('should pass context to all hooks', async () => {
    // Each hook records the context it receives; every entry must be the
    // exact same object (identity, not deep equality).
    const seenContexts: any[] = []
    const trackerPlugin: AiPlugin = {
      name: 'context-tracker',
      configureContext: vi.fn().mockImplementation(async (context) => {
        seenContexts.push(context)
      }),
      onRequestStart: vi.fn().mockImplementation(async (context) => {
        seenContexts.push(context)
      }),
      resolveModel: vi.fn().mockImplementation(async (_, context) => {
        seenContexts.push(context)
        return mockLanguageModel
      }),
      transformParams: vi.fn().mockImplementation(async (params, context) => {
        seenContexts.push(context)
        return params
      })
    }
    engine = new PluginEngine('openai', [trackerPlugin])
    await engine.executeWithPlugins(
      'generateText',
      { model: 'gpt-4', messages: [] },
      vi.fn().mockResolvedValue({ text: 'test' })
    )
    expect(seenContexts.length).toBeGreaterThan(0)
    for (const context of seenContexts) {
      expect(context).toBe(seenContexts[0])
    }
  })
})
describe('Error Handling', () => {
  it('should propagate errors from executor', async () => {
    // An executor rejection must surface unchanged to the caller.
    const resolverPlugin: AiPlugin = {
      name: 'test',
      resolveModel: vi.fn().mockResolvedValue(mockLanguageModel)
    }
    engine = new PluginEngine('openai', [resolverPlugin])
    const failingExecutor = vi.fn().mockRejectedValue(new Error('Executor failed'))
    await expect(
      engine.executeWithPlugins('generateText', { model: 'gpt-4', messages: [] }, failingExecutor)
    ).rejects.toThrow('Executor failed')
  })
  it('should trigger onError for all plugins on failure', async () => {
    // Every registered plugin's onError hook fires with the same error.
    const onError1 = vi.fn()
    const onError2 = vi.fn()
    const plugin1: AiPlugin = { name: 'error-1', onError: onError1 }
    const plugin2: AiPlugin = { name: 'error-2', onError: onError2 }
    engine = new PluginEngine('openai', [plugin1, plugin2])
    engine.usePlugins([{ name: 'resolver', resolveModel: vi.fn().mockResolvedValue(mockLanguageModel) }])
    const failure = new Error('Test failure')
    const failingExecutor = vi.fn().mockRejectedValue(failure)
    await expect(
      engine.executeWithPlugins('generateText', { model: 'gpt-4', messages: [] }, failingExecutor)
    ).rejects.toThrow()
    expect(onError1).toHaveBeenCalledWith(failure, expect.any(Object))
    expect(onError2).toHaveBeenCalledWith(failure, expect.any(Object))
  })
  it('should handle errors in plugin hooks gracefully', async () => {
    // A rejection inside transformParams must abort the run with that error.
    const failingPlugin: AiPlugin = {
      name: 'failing-plugin',
      resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
      transformParams: vi.fn().mockRejectedValue(new Error('Transform failed'))
    }
    engine = new PluginEngine('openai', [failingPlugin])
    await expect(
      engine.executeWithPlugins('generateText', { model: 'gpt-4', messages: [] }, vi.fn())
    ).rejects.toThrow('Transform failed')
  })
})
describe('Plugin Enforcement', () => {
  it('should respect plugin enforce ordering (pre, normal, post)', async () => {
    const executionOrder: string[] = []
    // Factory: each phase plugin differs only in name, tag, and enforce flag.
    const makePhasePlugin = (name: string, tag: string, enforce?: 'pre' | 'post'): AiPlugin => ({
      name,
      ...(enforce ? { enforce } : {}),
      onRequestStart: vi.fn(async () => {
        executionOrder.push(tag)
      })
    })
    // Register deliberately out of order to exercise the engine's sorting.
    engine = new PluginEngine('openai', [
      makePhasePlugin('post-plugin', 'post', 'post'),
      makePhasePlugin('normal-plugin', 'normal'),
      makePhasePlugin('pre-plugin', 'pre', 'pre')
    ])
    engine.usePlugins([{ name: 'resolver', resolveModel: vi.fn().mockResolvedValue(mockLanguageModel) }])
    await engine.executeWithPlugins(
      'generateText',
      { model: 'gpt-4', messages: [] },
      vi.fn().mockResolvedValue({ text: 'test' })
    )
    // Hooks must run pre -> normal -> post regardless of registration order.
    expect(executionOrder).toEqual(['pre', 'normal', 'post'])
  })
})
describe('Type Safety', () => {
  it('should properly type plugin parameters', async () => {
    // Hooks assert on the runtime shape of what they receive: params carry
    // messages, results carry text.
    const transformParamsSpy = vi.fn().mockImplementation(async (params) => {
      expect(params).toHaveProperty('messages')
      return params
    })
    const transformResultSpy = vi.fn().mockImplementation(async (result) => {
      expect(result).toHaveProperty('text')
      return result
    })
    const shapeCheckingPlugin: AiPlugin = {
      name: 'typed-plugin',
      resolveModel: vi.fn().mockResolvedValue(mockLanguageModel),
      transformParams: transformParamsSpy,
      transformResult: transformResultSpy
    }
    engine = new PluginEngine('openai', [shapeCheckingPlugin])
    await engine.executeWithPlugins(
      'generateText',
      { model: 'gpt-4', messages: [] },
      vi.fn().mockResolvedValue({ text: 'test', finishReason: 'stop' })
    )
    // Both hooks must have run (their internal expects did the shape checks).
    expect(transformParamsSpy).toHaveBeenCalled()
    expect(transformResultSpy).toHaveBeenCalled()
  })
})
})

View File

@ -75,10 +75,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
*
* AiExecutor使用
*/
async executeWithPlugins<
TParams extends GenerateTextParams,
TResult extends GenerateTextResult
>(
async executeWithPlugins<TParams extends GenerateTextParams, TResult extends GenerateTextResult>(
methodName: string,
params: TParams,
executor: (model: LanguageModel, transformedParams: TParams) => TResult,
@ -101,9 +98,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
const context = _context ?? createContext(this.providerId, model, params)
// ✅ 创建类型化的 manager逆变安全
const manager = new PluginManager<TParams, TResult>(
this.basePlugins as AiPlugin<TParams, TResult>[]
)
const manager = new PluginManager<TParams, TResult>(this.basePlugins as AiPlugin<TParams, TResult>[])
// ✅ 递归调用泛型化,增加深度限制
context.recursiveCall = async <R = TResult>(newParams: Partial<TParams>): Promise<R> => {
@ -118,12 +113,12 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
context.recursiveDepth = previousDepth + 1
context.isRecursiveCall = true
return await this.executeWithPlugins(
return (await this.executeWithPlugins(
methodName,
{ ...params, ...newParams } as TParams,
executor,
context
) as unknown as R
)) as unknown as R
} finally {
// ✅ finally 确保状态恢复
context.recursiveDepth = previousDepth
@ -201,9 +196,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
const context = _context ?? createContext(this.providerId, model, params)
// ✅ 创建类型化的 manager逆变安全
const manager = new PluginManager<TParams, TResult>(
this.basePlugins as AiPlugin<TParams, TResult>[]
)
const manager = new PluginManager<TParams, TResult>(this.basePlugins as AiPlugin<TParams, TResult>[])
// ✅ 递归调用泛型化,增加深度限制
context.recursiveCall = async <R = TResult>(newParams: Partial<TParams>): Promise<R> => {
@ -218,12 +211,12 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
context.recursiveDepth = previousDepth + 1
context.isRecursiveCall = true
return await this.executeImageWithPlugins(
return (await this.executeImageWithPlugins(
methodName,
{ ...params, ...newParams } as TParams,
executor,
context
) as unknown as R
)) as unknown as R
} finally {
// ✅ finally 确保状态恢复
context.recursiveDepth = previousDepth
@ -275,10 +268,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
*
* AiExecutor使用
*/
async executeStreamWithPlugins<
TParams extends StreamTextParams,
TResult extends StreamTextResult
>(
async executeStreamWithPlugins<TParams extends StreamTextParams, TResult extends StreamTextResult>(
methodName: string,
params: TParams,
executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => TResult,
@ -301,9 +291,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
const context = _context ?? createContext(this.providerId, model, params)
// ✅ 创建类型化的 manager逆变安全
const manager = new PluginManager<TParams, TResult>(
this.basePlugins as AiPlugin<TParams, TResult>[]
)
const manager = new PluginManager<TParams, TResult>(this.basePlugins as AiPlugin<TParams, TResult>[])
// ✅ 递归调用泛型化,增加深度限制
context.recursiveCall = async <R = TResult>(newParams: Partial<TParams>): Promise<R> => {
@ -318,12 +306,12 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
context.recursiveDepth = previousDepth + 1
context.isRecursiveCall = true
return await this.executeStreamWithPlugins(
return (await this.executeStreamWithPlugins(
methodName,
{ ...params, ...newParams } as TParams,
executor,
context
) as unknown as R
)) as unknown as R
} finally {
// ✅ finally 确保状态恢复
context.recursiveDepth = previousDepth