refactor(aiCore): restructure test utilities and fix failing tests

- Move test utilities from src/__tests__/ to test_utils/
- Fix ModelResolver tests for simplified API (2 params instead of 4)
- Fix generateImage/generateText tests with proper vi.fn() mocks
- Fix ExtensionRegistry.parseProviderId to check variants before aliases
- Add createProvider method overload for dynamic provider IDs
- Update ProviderExtension tests for runtime validation behavior
- Delete outdated tests: initialization.test.ts, extensions.integration.test.ts, executor-resolveModel.test.ts
- Remove 3 skipped tests for removed validate hook
- Add HubProvider.integration.test.ts
- All 359 tests passing, 0 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
suyao 2026-01-02 03:15:50 +08:00
parent 2e97d07c10
commit f805ddc285
No known key found for this signature in database
43 changed files with 3488 additions and 2584 deletions

File diff suppressed because it is too large Load Diff

View File

@ -47,6 +47,7 @@
"@ai-sdk/provider": "^3.0.0",
"@ai-sdk/provider-utils": "^4.0.0",
"@ai-sdk/xai": "^3.0.0",
"lru-cache": "^11.2.4",
"zod": "^4.1.5"
},
"devDependencies": {

View File

@ -1,13 +0,0 @@
/**
* Test Infrastructure Exports
* Central export point for all test utilities, fixtures, and helpers
*/
// Fixtures
export * from './fixtures/mock-providers'
export * from './fixtures/mock-responses'
// Helpers
export * from './helpers/model-test-utils'
export * from './helpers/provider-test-utils'
export * from './helpers/test-utils'

View File

@ -1,3 +0,0 @@
# @cherryStudio-aiCore
Core

View File

@ -8,7 +8,7 @@ export type { NamedMiddleware } from './middleware'
export { createMiddlewares, wrapModelWithMiddlewares } from './middleware'
// 创建管理
export { globalModelResolver, ModelResolver } from './models'
export { ModelResolver } from './models'
export type { ModelConfig as ModelConfigType } from './models/types'
// 执行管理

View File

@ -1,77 +1,56 @@
/**
* - models模块的核心
* modelId解析为AI SDK的LanguageModel实例
*
* ModelCreator
*
* :
* 1. : 'gpt-4' (使provider)
* 2. : 'hub|provider|model' (HubProvider内部路由)
*/
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3, LanguageModelV3Middleware } from '@ai-sdk/provider'
import type {
EmbeddingModelV3,
ImageModelV3,
LanguageModelV3,
LanguageModelV3Middleware,
ProviderV3
} from '@ai-sdk/provider'
import { wrapModelWithMiddlewares } from '../middleware/wrapper'
import { globalProviderStorage } from '../providers/core/ProviderExtension'
import { DEFAULT_SEPARATOR } from '../providers/features/HubProvider'
export class ModelResolver {
private provider: ProviderV3
/**
* globalProviderStorage provider
* @param providerId - Provider explicit ID
* @throws Error if provider not found
* provider实例
* Provider可以是普通provider或HubProvider
*/
private getProvider(providerId: string) {
const provider = globalProviderStorage.get(providerId)
if (!provider) {
throw new Error(
`Provider "${providerId}" not found. Please ensure it has been initialized with extension.createProvider(settings, "${providerId}")`
)
}
return provider
constructor(provider: ProviderV3) {
this.provider = provider
}
/**
* ID (providerId:modelId )
* @returns { providerId, modelId }
*/
private parseFullModelId(fullModelId: string): { providerId: string; modelId: string } {
const parts = fullModelId.split(DEFAULT_SEPARATOR)
if (parts.length < 2) {
throw new Error(`Invalid model ID format: "${fullModelId}". Expected "providerId${DEFAULT_SEPARATOR}modelId"`)
}
// 支持多个分隔符的情况(如 hub:provider:model
const providerId = parts[0]
const modelId = parts.slice(1).join(DEFAULT_SEPARATOR)
return { providerId, modelId }
}
/**
* modelId为语言模型
*
*
* @param modelId ID 'gpt-4' 'anthropic>claude-3'
* @param fallbackProviderId modelId为传统格式时使用的providerId
* @param providerOptions provider配置选项OpenAI模式选择等
* @param middlewares
* @param modelId ID('gpt-4')('hub|provider|model')
* @param middlewares
* @returns
*
* @example
* ```typescript
* // 传统格式
* const model = await resolver.resolveLanguageModel('gpt-4')
*
* // 命名空间格式 (需要HubProvider)
* const model = await resolver.resolveLanguageModel('hub|openai|gpt-4')
* ```
*/
async resolveLanguageModel(
modelId: string,
fallbackProviderId: string,
providerOptions?: any,
middlewares?: LanguageModelV3Middleware[]
): Promise<LanguageModelV3> {
let finalProviderId = fallbackProviderId
let model: LanguageModelV3
// 🎯 处理 OpenAI 模式选择逻辑 (从 ModelCreator 迁移)
if ((fallbackProviderId === 'openai' || fallbackProviderId === 'azure') && providerOptions?.mode === 'chat') {
finalProviderId = `${fallbackProviderId}-chat`
}
async resolveLanguageModel(modelId: string, middlewares?: LanguageModelV3Middleware[]): Promise<LanguageModelV3> {
// 直接将完整的modelId传给provider
// - 如果是普通provider会直接使用modelId
// - 如果是HubProvider会解析命名空间并路由到正确的provider
let model = this.provider.languageModel(modelId)
// 检查是否是命名空间格式
if (modelId.includes(DEFAULT_SEPARATOR)) {
model = this.resolveNamespacedModel(modelId)
} else {
// 传统格式:使用处理后的 providerId + modelId
model = this.resolveTraditionalModel(finalProviderId, modelId)
}
// 🎯 应用中间件(如果有)
// 应用中间件
if (middlewares && middlewares.length > 0) {
model = wrapModelWithMiddlewares(model, middlewares)
}
@ -81,81 +60,21 @@ export class ModelResolver {
/**
*
*
* @param modelId ID
* @returns
*/
async resolveTextEmbeddingModel(modelId: string, fallbackProviderId: string): Promise<EmbeddingModelV3> {
if (modelId.includes(DEFAULT_SEPARATOR)) {
return this.resolveNamespacedEmbeddingModel(modelId)
}
return this.resolveTraditionalEmbeddingModel(fallbackProviderId, modelId)
async resolveEmbeddingModel(modelId: string): Promise<EmbeddingModelV3> {
return this.provider.embeddingModel(modelId)
}
/**
*
*
*
* @param modelId ID
* @returns
*/
async resolveImageModel(modelId: string, fallbackProviderId: string): Promise<ImageModelV3> {
if (modelId.includes(DEFAULT_SEPARATOR)) {
return this.resolveNamespacedImageModel(modelId)
}
return this.resolveTraditionalImageModel(fallbackProviderId, modelId)
}
/**
*
* aihubmix:anthropic:claude-3 -> globalProviderStorage 'aihubmix' provider languageModel('anthropic:claude-3')
*/
private resolveNamespacedModel(fullModelId: string): LanguageModelV3 {
const { providerId, modelId } = this.parseFullModelId(fullModelId)
const provider = this.getProvider(providerId)
return provider.languageModel(modelId)
}
/**
*
* providerId: 'openai', modelId: 'gpt-4' -> globalProviderStorage 'openai' provider languageModel('gpt-4')
*/
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV3 {
const provider = this.getProvider(providerId)
return provider.languageModel(modelId)
}
/**
*
*/
private resolveNamespacedEmbeddingModel(fullModelId: string): EmbeddingModelV3 {
const { providerId, modelId } = this.parseFullModelId(fullModelId)
const provider = this.getProvider(providerId)
return provider.embeddingModel(modelId)
}
/**
*
*/
private resolveTraditionalEmbeddingModel(providerId: string, modelId: string): EmbeddingModelV3 {
const provider = this.getProvider(providerId)
return provider.embeddingModel(modelId)
}
/**
*
*/
private resolveNamespacedImageModel(fullModelId: string): ImageModelV3 {
const { providerId, modelId } = this.parseFullModelId(fullModelId)
const provider = this.getProvider(providerId)
return provider.imageModel(modelId)
}
/**
*
*/
private resolveTraditionalImageModel(providerId: string, modelId: string): ImageModelV3 {
const provider = this.getProvider(providerId)
return provider.imageModel(modelId)
async resolveImageModel(modelId: string): Promise<ImageModelV3> {
return this.provider.imageModel(modelId)
}
}
/**
*
*/
export const globalModelResolver = new ModelResolver()

View File

@ -1,34 +1,23 @@
/**
* ModelResolver Comprehensive Tests
* ModelResolver Tests
* Tests model resolution logic for language, embedding, and image models
* Covers both traditional and namespaced format resolution
* The resolver passes modelId directly to provider - all routing is handled by the provider
*/
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import {
createMockEmbeddingModel,
createMockImageModel,
createMockLanguageModel,
createMockMiddleware
} from '../../../__tests__'
import { DEFAULT_SEPARATOR, globalProviderInstanceRegistry } from '../../providers/core/ProviderInstanceRegistry'
import { ModelResolver } from '../ModelResolver'
createMockMiddleware,
createMockProviderV3
} from '@test-utils'
import { beforeEach, describe, expect, it, vi } from 'vitest'
// Mock the dependencies
vi.mock('../../providers/core/ProviderInstanceRegistry', () => ({
globalProviderInstanceRegistry: {
languageModel: vi.fn(),
embeddingModel: vi.fn(),
imageModel: vi.fn()
},
DEFAULT_SEPARATOR: '|'
}))
import { ModelResolver } from '../ModelResolver'
vi.mock('../../middleware/wrapper', () => ({
wrapModelWithMiddlewares: vi.fn((model: LanguageModelV3) => {
// Return a wrapped model with a marker
return {
...model,
_wrapped: true
@ -41,12 +30,12 @@ describe('ModelResolver', () => {
let mockLanguageModel: LanguageModelV3
let mockEmbeddingModel: EmbeddingModelV3
let mockImageModel: ImageModelV3
let mockProvider: any
beforeEach(() => {
vi.clearAllMocks()
resolver = new ModelResolver()
// Create properly typed mock models using global utilities
// Create properly typed mock models
mockLanguageModel = createMockLanguageModel({
provider: 'test-provider',
modelId: 'test-model'
@ -62,395 +51,204 @@ describe('ModelResolver', () => {
modelId: 'test-image'
})
// Setup default mock implementations
vi.mocked(globalProviderInstanceRegistry.languageModel).mockReturnValue(mockLanguageModel)
vi.mocked(globalProviderInstanceRegistry.embeddingModel).mockReturnValue(mockEmbeddingModel)
vi.mocked(globalProviderInstanceRegistry.imageModel).mockReturnValue(mockImageModel)
// Create mock provider with model methods as spies
mockProvider = createMockProviderV3({
provider: 'test-provider',
languageModel: vi.fn(() => mockLanguageModel),
embeddingModel: vi.fn(() => mockEmbeddingModel),
imageModel: vi.fn(() => mockImageModel)
})
// Create resolver with mock provider
resolver = new ModelResolver(mockProvider)
})
describe('resolveLanguageModel', () => {
describe('Traditional Format Resolution', () => {
it('should resolve traditional format modelId without separator', async () => {
const result = await resolver.resolveLanguageModel('gpt-4', 'openai')
it('should resolve modelId by passing it to provider', async () => {
const result = await resolver.resolveLanguageModel('gpt-4')
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith(`openai${DEFAULT_SEPARATOR}gpt-4`)
expect(result).toBe(mockLanguageModel)
})
it('should resolve with different provider and modelId combinations', async () => {
const testCases: Array<{ modelId: string; providerId: string; expected: string }> = [
{ modelId: 'claude-3-5-sonnet', providerId: 'anthropic', expected: 'anthropic|claude-3-5-sonnet' },
{ modelId: 'gemini-2.0-flash', providerId: 'google', expected: 'google|gemini-2.0-flash' },
{ modelId: 'grok-2-latest', providerId: 'xai', expected: 'xai|grok-2-latest' },
{ modelId: 'deepseek-chat', providerId: 'deepseek', expected: 'deepseek|deepseek-chat' }
]
for (const testCase of testCases) {
vi.clearAllMocks()
await resolver.resolveLanguageModel(testCase.modelId, testCase.providerId)
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith(testCase.expected)
}
})
it('should handle modelIds with special characters', async () => {
const modelIds = ['model-v1.0', 'model_v2', 'model.2024', 'model:free']
for (const modelId of modelIds) {
vi.clearAllMocks()
await resolver.resolveLanguageModel(modelId, 'provider')
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith(
`provider${DEFAULT_SEPARATOR}${modelId}`
)
}
})
expect(mockProvider.languageModel).toHaveBeenCalledWith('gpt-4')
expect(result).toBe(mockLanguageModel)
})
describe('Namespaced Format Resolution', () => {
it('should resolve namespaced format with hub', async () => {
const namespacedId = `aihubmix${DEFAULT_SEPARATOR}anthropic${DEFAULT_SEPARATOR}claude-3-5-sonnet`
it('should pass various modelIds directly to provider', async () => {
const modelIds = [
'claude-3-5-sonnet',
'gemini-2.0-flash',
'grok-2-latest',
'deepseek-chat',
'model-v1.0',
'model_v2',
'model.2024'
]
const result = await resolver.resolveLanguageModel(namespacedId, 'openai')
for (const modelId of modelIds) {
vi.clearAllMocks()
await resolver.resolveLanguageModel(modelId)
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith(namespacedId)
expect(result).toBe(mockLanguageModel)
})
it('should resolve simple namespaced format', async () => {
const namespacedId = `provider${DEFAULT_SEPARATOR}model-id`
await resolver.resolveLanguageModel(namespacedId, 'fallback-provider')
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith(namespacedId)
})
it('should handle complex namespaced IDs', async () => {
const complexIds = [
`hub${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}model`,
`hub${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}model-v1.0`,
`custom${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}gpt-4-turbo`
]
for (const id of complexIds) {
vi.clearAllMocks()
await resolver.resolveLanguageModel(id, 'fallback')
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith(id)
}
})
expect(mockProvider.languageModel).toHaveBeenCalledWith(modelId)
}
})
describe('OpenAI Mode Selection', () => {
it('should append "-chat" suffix for OpenAI provider with chat mode', async () => {
await resolver.resolveLanguageModel('gpt-4', 'openai', { mode: 'chat' })
it('should pass namespaced modelIds directly to provider (provider handles routing)', async () => {
// HubProvider handles routing internally - ModelResolver just passes through
const namespacedId = 'openai|gpt-4'
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('openai-chat|gpt-4')
})
await resolver.resolveLanguageModel(namespacedId)
it('should append "-chat" suffix for Azure provider with chat mode', async () => {
await resolver.resolveLanguageModel('gpt-4', 'azure', { mode: 'chat' })
expect(mockProvider.languageModel).toHaveBeenCalledWith(namespacedId)
})
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('azure-chat|gpt-4')
})
it('should handle empty model IDs', async () => {
await resolver.resolveLanguageModel('')
it('should not append suffix for OpenAI with responses mode', async () => {
await resolver.resolveLanguageModel('gpt-4', 'openai', { mode: 'responses' })
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('openai|gpt-4')
})
it('should not append suffix for OpenAI without mode', async () => {
await resolver.resolveLanguageModel('gpt-4', 'openai')
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('openai|gpt-4')
})
it('should not append suffix for other providers with chat mode', async () => {
await resolver.resolveLanguageModel('claude-3', 'anthropic', { mode: 'chat' })
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('anthropic|claude-3')
})
it('should handle namespaced IDs with OpenAI chat mode', async () => {
const namespacedId = `hub${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}gpt-4`
await resolver.resolveLanguageModel(namespacedId, 'openai', { mode: 'chat' })
// Should use the namespaced ID directly, not apply mode logic
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith(namespacedId)
})
expect(mockProvider.languageModel).toHaveBeenCalledWith('')
})
describe('Middleware Application', () => {
it('should apply middlewares to resolved model', async () => {
const mockMiddleware = createMockMiddleware()
const result = await resolver.resolveLanguageModel('gpt-4', 'openai', undefined, [mockMiddleware])
const result = await resolver.resolveLanguageModel('gpt-4', [mockMiddleware])
expect(result).toHaveProperty('_wrapped', true)
})
it('should apply multiple middlewares in order', async () => {
it('should apply multiple middlewares', async () => {
const middleware1 = createMockMiddleware()
const middleware2 = createMockMiddleware()
const result = await resolver.resolveLanguageModel('gpt-4', 'openai', undefined, [middleware1, middleware2])
const result = await resolver.resolveLanguageModel('gpt-4', [middleware1, middleware2])
expect(result).toHaveProperty('_wrapped', true)
})
it('should not apply middlewares when none provided', async () => {
const result = await resolver.resolveLanguageModel('gpt-4', 'openai')
const result = await resolver.resolveLanguageModel('gpt-4')
expect(result).not.toHaveProperty('_wrapped')
expect(result).toBe(mockLanguageModel)
})
it('should not apply middlewares when empty array provided', async () => {
const result = await resolver.resolveLanguageModel('gpt-4', 'openai', undefined, [])
const result = await resolver.resolveLanguageModel('gpt-4', [])
expect(result).not.toHaveProperty('_wrapped')
})
})
describe('Provider Options Handling', () => {
it('should pass provider options correctly', async () => {
const options = { baseURL: 'https://api.example.com', apiKey: 'test-key' }
await resolver.resolveLanguageModel('gpt-4', 'openai', options)
// Provider options are used for mode selection logic
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalled()
})
it('should handle empty provider options', async () => {
await resolver.resolveLanguageModel('gpt-4', 'openai', {})
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('openai|gpt-4')
})
it('should handle undefined provider options', async () => {
await resolver.resolveLanguageModel('gpt-4', 'openai', undefined)
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('openai|gpt-4')
})
})
})
describe('resolveTextEmbeddingModel', () => {
describe('Traditional Format', () => {
it('should resolve traditional embedding model ID', async () => {
const result = await resolver.resolveTextEmbeddingModel('text-embedding-ada-002', 'openai')
expect(globalProviderInstanceRegistry.embeddingModel).toHaveBeenCalledWith('openai|text-embedding-ada-002')
expect(result).toBe(mockEmbeddingModel)
})
it('should resolve different embedding models', async () => {
const testCases = [
{ modelId: 'text-embedding-3-small', providerId: 'openai' },
{ modelId: 'text-embedding-3-large', providerId: 'openai' },
{ modelId: 'embed-english-v3.0', providerId: 'cohere' },
{ modelId: 'voyage-2', providerId: 'voyage' }
]
for (const { modelId, providerId } of testCases) {
vi.clearAllMocks()
await resolver.resolveTextEmbeddingModel(modelId, providerId)
expect(globalProviderInstanceRegistry.embeddingModel).toHaveBeenCalledWith(`${providerId}|${modelId}`)
}
})
})
describe('Namespaced Format', () => {
it('should resolve namespaced embedding model ID', async () => {
const namespacedId = `aihubmix${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}text-embedding-3-small`
const result = await resolver.resolveTextEmbeddingModel(namespacedId, 'openai')
expect(globalProviderInstanceRegistry.embeddingModel).toHaveBeenCalledWith(namespacedId)
expect(result).toBe(mockEmbeddingModel)
})
it('should handle complex namespaced embedding IDs', async () => {
const complexIds = [
`hub${DEFAULT_SEPARATOR}cohere${DEFAULT_SEPARATOR}embed-multilingual`,
`custom${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}embedding-model`
]
for (const id of complexIds) {
vi.clearAllMocks()
await resolver.resolveTextEmbeddingModel(id, 'fallback')
expect(globalProviderInstanceRegistry.embeddingModel).toHaveBeenCalledWith(id)
}
})
})
})
describe('resolveImageModel', () => {
describe('Traditional Format', () => {
it('should resolve traditional image model ID', async () => {
const result = await resolver.resolveImageModel('dall-e-3', 'openai')
expect(globalProviderInstanceRegistry.imageModel).toHaveBeenCalledWith('openai|dall-e-3')
expect(result).toBe(mockImageModel)
})
it('should resolve different image models', async () => {
const testCases = [
{ modelId: 'dall-e-2', providerId: 'openai' },
{ modelId: 'stable-diffusion-xl', providerId: 'stability' },
{ modelId: 'imagen-2', providerId: 'google' },
{ modelId: 'midjourney-v6', providerId: 'midjourney' }
]
for (const { modelId, providerId } of testCases) {
vi.clearAllMocks()
await resolver.resolveImageModel(modelId, providerId)
expect(globalProviderInstanceRegistry.imageModel).toHaveBeenCalledWith(`${providerId}|${modelId}`)
}
})
})
describe('Namespaced Format', () => {
it('should resolve namespaced image model ID', async () => {
const namespacedId = `aihubmix${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}dall-e-3`
const result = await resolver.resolveImageModel(namespacedId, 'openai')
expect(globalProviderInstanceRegistry.imageModel).toHaveBeenCalledWith(namespacedId)
expect(result).toBe(mockImageModel)
})
it('should handle complex namespaced image IDs', async () => {
const complexIds = [
`hub${DEFAULT_SEPARATOR}stability${DEFAULT_SEPARATOR}sdxl-turbo`,
`custom${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}image-gen-v2`
]
for (const id of complexIds) {
vi.clearAllMocks()
await resolver.resolveImageModel(id, 'fallback')
expect(globalProviderInstanceRegistry.imageModel).toHaveBeenCalledWith(id)
}
})
})
})
describe('Edge Cases and Error Scenarios', () => {
it('should handle empty model IDs', async () => {
await resolver.resolveLanguageModel('', 'openai')
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('openai|')
})
it('should handle model IDs with multiple separators', async () => {
const multiSeparatorId = `hub${DEFAULT_SEPARATOR}sub${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}model`
await resolver.resolveLanguageModel(multiSeparatorId, 'fallback')
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith(multiSeparatorId)
})
it('should handle model IDs with only separator', async () => {
const onlySeparator = DEFAULT_SEPARATOR
await resolver.resolveLanguageModel(onlySeparator, 'provider')
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith(onlySeparator)
})
it('should throw if globalProviderInstanceRegistry throws', async () => {
const error = new Error('Model not found in registry')
vi.mocked(globalProviderInstanceRegistry.languageModel).mockImplementation(() => {
it('should throw if provider throws', async () => {
const error = new Error('Model not found')
vi.mocked(mockProvider.languageModel).mockImplementation(() => {
throw error
})
await expect(resolver.resolveLanguageModel('invalid-model', 'openai')).rejects.toThrow(
'Model not found in registry'
)
await expect(resolver.resolveLanguageModel('invalid-model')).rejects.toThrow('Model not found')
})
it('should handle concurrent resolution requests', async () => {
const promises = [
resolver.resolveLanguageModel('gpt-4', 'openai'),
resolver.resolveLanguageModel('claude-3', 'anthropic'),
resolver.resolveLanguageModel('gemini-2.0', 'google')
resolver.resolveLanguageModel('gpt-4'),
resolver.resolveLanguageModel('claude-3'),
resolver.resolveLanguageModel('gemini-2.0')
]
const results = await Promise.all(promises)
expect(results).toHaveLength(3)
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledTimes(3)
expect(mockProvider.languageModel).toHaveBeenCalledTimes(3)
})
})
describe('resolveEmbeddingModel', () => {
it('should resolve embedding model ID', async () => {
const result = await resolver.resolveEmbeddingModel('text-embedding-ada-002')
expect(mockProvider.embeddingModel).toHaveBeenCalledWith('text-embedding-ada-002')
expect(result).toBe(mockEmbeddingModel)
})
it('should resolve different embedding models', async () => {
const modelIds = ['text-embedding-3-small', 'text-embedding-3-large', 'embed-english-v3.0', 'voyage-2']
for (const modelId of modelIds) {
vi.clearAllMocks()
await resolver.resolveEmbeddingModel(modelId)
expect(mockProvider.embeddingModel).toHaveBeenCalledWith(modelId)
}
})
it('should pass namespaced embedding modelIds directly to provider', async () => {
const namespacedId = 'openai|text-embedding-3-small'
await resolver.resolveEmbeddingModel(namespacedId)
expect(mockProvider.embeddingModel).toHaveBeenCalledWith(namespacedId)
})
})
describe('resolveImageModel', () => {
it('should resolve image model ID', async () => {
const result = await resolver.resolveImageModel('dall-e-3')
expect(mockProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
expect(result).toBe(mockImageModel)
})
it('should resolve different image models', async () => {
const modelIds = ['dall-e-2', 'stable-diffusion-xl', 'imagen-2', 'grok-2-image']
for (const modelId of modelIds) {
vi.clearAllMocks()
await resolver.resolveImageModel(modelId)
expect(mockProvider.imageModel).toHaveBeenCalledWith(modelId)
}
})
it('should pass namespaced image modelIds directly to provider', async () => {
const namespacedId = 'openai|dall-e-3'
await resolver.resolveImageModel(namespacedId)
expect(mockProvider.imageModel).toHaveBeenCalledWith(namespacedId)
})
})
describe('Type Safety', () => {
it('should return properly typed LanguageModelV3', async () => {
const result = await resolver.resolveLanguageModel('gpt-4', 'openai')
const result = await resolver.resolveLanguageModel('gpt-4')
// Type assertions
expect(result.specificationVersion).toBe('v3')
expect(result).toHaveProperty('doGenerate')
expect(result).toHaveProperty('doStream')
})
it('should return properly typed EmbeddingModelV3', async () => {
const result = await resolver.resolveTextEmbeddingModel('text-embedding-ada-002', 'openai')
const result = await resolver.resolveEmbeddingModel('text-embedding-ada-002')
expect(result.specificationVersion).toBe('v3')
expect(result).toHaveProperty('doEmbed')
})
it('should return properly typed ImageModelV3', async () => {
const result = await resolver.resolveImageModel('dall-e-3', 'openai')
const result = await resolver.resolveImageModel('dall-e-3')
expect(result.specificationVersion).toBe('v3')
expect(result).toHaveProperty('doGenerate')
})
})
describe('Global ModelResolver Instance', () => {
it('should have a global instance available', async () => {
const { globalModelResolver } = await import('../ModelResolver')
describe('All model types for same provider', () => {
it('should handle all model types correctly', async () => {
await resolver.resolveLanguageModel('gpt-4')
await resolver.resolveEmbeddingModel('text-embedding-3-small')
await resolver.resolveImageModel('dall-e-3')
expect(globalModelResolver).toBeInstanceOf(ModelResolver)
})
})
describe('Integration with Different Provider Types', () => {
it('should work with OpenAI compatible providers', async () => {
await resolver.resolveLanguageModel('custom-model', 'openai-compatible')
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('openai-compatible|custom-model')
})
it('should work with hub providers', async () => {
const hubId = `aihubmix${DEFAULT_SEPARATOR}custom${DEFAULT_SEPARATOR}model-v1`
await resolver.resolveLanguageModel(hubId, 'aihubmix')
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith(hubId)
})
it('should handle all model types for same provider', async () => {
const providerId = 'openai'
const languageModel = 'gpt-4'
const embeddingModel = 'text-embedding-3-small'
const imageModel = 'dall-e-3'
await resolver.resolveLanguageModel(languageModel, providerId)
await resolver.resolveTextEmbeddingModel(embeddingModel, providerId)
await resolver.resolveImageModel(imageModel, providerId)
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith(`${providerId}|${languageModel}`)
expect(globalProviderInstanceRegistry.embeddingModel).toHaveBeenCalledWith(`${providerId}|${embeddingModel}`)
expect(globalProviderInstanceRegistry.imageModel).toHaveBeenCalledWith(`${providerId}|${imageModel}`)
expect(mockProvider.languageModel).toHaveBeenCalledWith('gpt-4')
expect(mockProvider.embeddingModel).toHaveBeenCalledWith('text-embedding-3-small')
expect(mockProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
})
})
})

View File

@ -3,7 +3,7 @@
*/
// 核心模型解析器
export { globalModelResolver, ModelResolver } from './ModelResolver'
export { ModelResolver } from './ModelResolver'
// 保留的类型定义(可能被其他地方使用)
export type { ModelConfig as ModelConfigType } from './types'

View File

@ -17,7 +17,7 @@ export interface ModelConfig<
> {
providerId: T
modelId: string
providerSettings: T extends keyof TSettingsMap ? TSettingsMap[T] : never
providerSettings: TSettingsMap[T & keyof TSettingsMap]
middlewares?: LanguageModelV3Middleware[]
extraModelConfig?: JSONObject
}

View File

@ -10,7 +10,7 @@ import type {
import { simulateReadableStream } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createMockContext, createMockTool } from '../../../../../__tests__'
import { createMockContext, createMockTool } from '@test-utils'
import { StreamEventManager } from '../StreamEventManager'
import type { StreamController } from '../ToolExecutor'

View File

@ -3,7 +3,7 @@ import { simulateReadableStream } from 'ai'
import { convertReadableStreamToArray } from 'ai/test'
import { describe, expect, it, vi } from 'vitest'
import { createMockContext, createMockStreamParams, createMockTool, createMockToolSet } from '../../../../../__tests__'
import { createMockContext, createMockStreamParams, createMockTool, createMockToolSet } from '@test-utils'
import { createPromptToolUsePlugin, DEFAULT_SYSTEM_PROMPT } from '../promptToolUsePlugin'
describe('promptToolUsePlugin', () => {

View File

@ -2,9 +2,9 @@
* ExtensionRegistry
*/
import { createMockProviderV3 } from '@test-utils'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createMockProviderV3 } from '../../../__tests__'
import { ExtensionRegistry } from '../core/ExtensionRegistry'
import { ProviderExtension } from '../core/ProviderExtension'
import { ProviderCreationError } from '../core/utils'
@ -297,23 +297,6 @@ describe('ExtensionRegistry', () => {
})
})
it.skip('should validate settings before creating', async () => {
const extension = new ProviderExtension<any>({
name: 'test-provider',
create: createMockProviderV3 as any
})
registry.register(extension)
try {
await registry.createProvider('test-provider', {})
expect.fail('Should have thrown')
} catch (error) {
expect(error).toBeInstanceOf(ProviderCreationError)
expect((error as ProviderCreationError).cause.message).toContain('API key required')
}
})
it('should create provider using dynamic import', async () => {
const mockProvider = createMockProviderV3()
@ -503,46 +486,6 @@ describe('ExtensionRegistry', () => {
await expect(registry.createProvider('test-provider', { apiKey: 'key' })).rejects.toThrow(ProviderCreationError)
})
it.skip('should still execute validate hook for backward compatibility', async () => {
const validateSpy = vi.fn(() => ({ success: true }))
registry.register(
new ProviderExtension({
name: 'test-provider',
create: createMockProviderV3,
validate: validateSpy
})
)
await registry.createProvider('test-provider', { apiKey: 'key' })
expect(validateSpy).toHaveBeenCalledWith({ apiKey: 'key' })
})
it.skip('should execute both onBeforeCreate and validate', async () => {
const executionOrder: string[] = []
registry.register(
new ProviderExtension({
name: 'test-provider',
create: createMockProviderV3,
hooks: {
onBeforeCreate: () => {
executionOrder.push('hook')
}
},
validate: () => {
executionOrder.push('validate')
return { success: true }
}
})
)
await registry.createProvider('test-provider', { apiKey: 'key' })
expect(executionOrder).toEqual(['hook', 'validate'])
})
})
describe('ProviderCreationError', () => {

View File

@ -0,0 +1,442 @@
/**
* HubProvider Integration Tests
* Tests end-to-end integration between HubProvider, RuntimeExecutor, and ProviderExtension
*/
import type { LanguageModelV3 } from '@ai-sdk/provider'
import { createMockLanguageModel, createMockProviderV3 } from '@test-utils'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { RuntimeExecutor } from '../../runtime/executor'
import { ExtensionRegistry } from '../core/ExtensionRegistry'
import { ProviderExtension } from '../core/ProviderExtension'
import { createHubProviderAsync } from '../features/HubProvider'
describe('HubProvider Integration Tests', () => {
let registry: ExtensionRegistry
let openaiExtension: ProviderExtension<any, any, any, any>
let anthropicExtension: ProviderExtension<any, any, any, any>
let googleExtension: ProviderExtension<any, any, any, any>
beforeEach(() => {
vi.clearAllMocks()
// Create fresh registry
registry = new ExtensionRegistry()
// Create provider extensions using test utils directly
openaiExtension = ProviderExtension.create({
name: 'openai',
create: () => createMockProviderV3({ provider: 'openai' })
} as const)
anthropicExtension = ProviderExtension.create({
name: 'anthropic',
create: () => createMockProviderV3({ provider: 'anthropic' })
} as const)
googleExtension = ProviderExtension.create({
name: 'google',
create: () => createMockProviderV3({ provider: 'google' })
} as const)
// Register extensions
registry.register(openaiExtension)
registry.register(anthropicExtension)
registry.register(googleExtension)
})
describe('End-to-End with RuntimeExecutor', () => {
it('should resolve models through HubProvider using namespace format', async () => {
// Create HubProvider
const hubProvider = await createHubProviderAsync({
hubId: 'aihubmix',
registry,
providerSettingsMap: new Map([
['openai', { apiKey: 'test-openai-key' }],
['anthropic', { apiKey: 'test-anthropic-key' }]
])
})
// Test that models are resolved correctly
const openaiModel = hubProvider.languageModel('openai|gpt-4')
const anthropicModel = hubProvider.languageModel('anthropic|claude-3-5-sonnet')
expect(openaiModel).toBeDefined()
expect(openaiModel.provider).toBe('openai')
expect(openaiModel.modelId).toBe('gpt-4')
expect(anthropicModel).toBeDefined()
expect(anthropicModel.provider).toBe('anthropic')
expect(anthropicModel.modelId).toBe('claude-3-5-sonnet')
})
it('should resolve language model correctly through executor', async () => {
const hubProvider = await createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([['openai', { apiKey: 'test-key' }]])
})
const executor = RuntimeExecutor.create('test-hub', hubProvider, {} as never, [])
// Access the private resolveModel method through streamText
const result = await executor.streamText({
model: 'openai|gpt-4-turbo',
messages: [{ role: 'user', content: 'Test' }]
})
// Verify the model was created and result is valid
expect(result).toBeDefined()
expect(result.textStream).toBeDefined()
})
it('should handle multiple providers in the same hub', async () => {
const hubProvider = await createHubProviderAsync({
hubId: 'multi-hub',
registry,
providerSettingsMap: new Map([
['openai', { apiKey: 'openai-key' }],
['anthropic', { apiKey: 'anthropic-key' }],
['google', { apiKey: 'google-key' }]
])
})
// Test all three providers can be resolved
const openaiModel = hubProvider.languageModel('openai|gpt-4')
const anthropicModel = hubProvider.languageModel('anthropic|claude-3-5-sonnet')
const googleModel = hubProvider.languageModel('google|gemini-2.0-flash')
expect(openaiModel.provider).toBe('openai')
expect(openaiModel.modelId).toBe('gpt-4')
expect(anthropicModel.provider).toBe('anthropic')
expect(anthropicModel.modelId).toBe('claude-3-5-sonnet')
expect(googleModel.provider).toBe('google')
expect(googleModel.modelId).toBe('gemini-2.0-flash')
})
it('should work with direct model objects instead of strings', async () => {
const hubProvider = await createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([['openai', { apiKey: 'test-key' }]])
})
const executor = RuntimeExecutor.create('test-hub', hubProvider, {} as never, [])
// Create a model instance directly
const model = createMockLanguageModel({
provider: 'openai',
modelId: 'gpt-4'
})
// Use the model object directly
const result = await executor.streamText({
model: model as LanguageModelV3,
messages: [{ role: 'user', content: 'Test with model object' }]
})
expect(result).toBeDefined()
})
})
describe('ProviderExtension LRU Cache Integration', () => {
it('should leverage ProviderExtension LRU cache when creating multiple HubProviders', async () => {
const settings = new Map([
['openai', { apiKey: 'same-key-1' }],
['anthropic', { apiKey: 'same-key-2' }]
])
// Create first HubProvider
const hub1 = await createHubProviderAsync({
hubId: 'hub1',
registry,
providerSettingsMap: settings
})
// Create second HubProvider with SAME settings
const hub2 = await createHubProviderAsync({
hubId: 'hub2',
registry,
providerSettingsMap: settings
})
// Extensions should have cached the provider instances
// Create a test model to verify caching
const model1 = hub1.languageModel('openai|gpt-4')
const model2 = hub2.languageModel('openai|gpt-4')
expect(model1).toBeDefined()
expect(model2).toBeDefined()
// Both should have the same provider name
expect(model1.provider).toBe('openai')
expect(model2.provider).toBe('openai')
})
it('should create new providers when settings differ', async () => {
const settings1 = new Map([['openai', { apiKey: 'key-1' }]])
const settings2 = new Map([['openai', { apiKey: 'key-2' }]])
// Create two HubProviders with DIFFERENT settings
const hub1 = await createHubProviderAsync({
hubId: 'hub1',
registry,
providerSettingsMap: settings1
})
const hub2 = await createHubProviderAsync({
hubId: 'hub2',
registry,
providerSettingsMap: settings2
})
const model1 = hub1.languageModel('openai|gpt-4')
const model2 = hub2.languageModel('openai|gpt-4')
expect(model1).toBeDefined()
expect(model2).toBeDefined()
})
it('should handle cache across multiple provider types', async () => {
const settings = new Map([
['openai', { apiKey: 'openai-key' }],
['anthropic', { apiKey: 'anthropic-key' }]
])
const hub = await createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: settings
})
// Create models from different providers
const openaiModel = hub.languageModel('openai|gpt-4')
const anthropicModel = hub.languageModel('anthropic|claude-3-5-sonnet')
const openaiEmbedding = hub.embeddingModel('openai|text-embedding-3-small')
expect(openaiModel.provider).toBe('openai')
expect(anthropicModel.provider).toBe('anthropic')
expect(openaiEmbedding.provider).toBe('openai')
})
})
describe('Error Handling Integration', () => {
it('should throw error when using provider not in providerSettingsMap', async () => {
const hub = await createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([['openai', { apiKey: 'test-key' }]])
// Note: anthropic NOT included
})
// Try to use anthropic (not initialized)
expect(() => {
hub.languageModel('anthropic|claude-3-5-sonnet')
}).toThrow(/Provider "anthropic" not initialized/)
})
it('should throw error when extension not registered', async () => {
const emptyRegistry = new ExtensionRegistry()
await expect(
createHubProviderAsync({
hubId: 'test-hub',
registry: emptyRegistry,
providerSettingsMap: new Map([['openai', { apiKey: 'test-key' }]])
})
).rejects.toThrow(/Provider extension "openai" not found in registry/)
})
it('should throw error on invalid model ID format', async () => {
const hub = await createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([['openai', { apiKey: 'test-key' }]])
})
// Invalid format: no separator
expect(() => {
hub.languageModel('invalid-no-separator')
}).toThrow(/Invalid hub model ID format/)
// Invalid format: empty provider
expect(() => {
hub.languageModel('|model-id')
}).toThrow(/Invalid hub model ID format/)
// Invalid format: empty modelId
expect(() => {
hub.languageModel('openai|')
}).toThrow(/Invalid hub model ID format/)
})
it('should propagate errors from extension.createProvider', async () => {
// Create an extension that throws on creation
const failingExtension = ProviderExtension.create({
name: 'failing',
create: () => {
throw new Error('Provider creation failed!')
}
} as const)
const failRegistry = new ExtensionRegistry()
failRegistry.register(failingExtension)
await expect(
createHubProviderAsync({
hubId: 'test-hub',
registry: failRegistry,
providerSettingsMap: new Map([['failing', { apiKey: 'test' }]])
})
).rejects.toThrow(/Failed to create provider "failing"/)
})
})
describe('Advanced Scenarios', () => {
it('should support image generation through hub', async () => {
const hub = await createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([['openai', { apiKey: 'test-key' }]])
})
const executor = RuntimeExecutor.create('test-hub', hub, {} as never, [])
const result = await executor.generateImage({
model: 'openai|dall-e-3',
prompt: 'A beautiful sunset'
})
expect(result).toBeDefined()
})
it('should support embedding models through hub', async () => {
const hub = await createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([['openai', { apiKey: 'test-key' }]])
})
const embeddingModel = hub.embeddingModel('openai|text-embedding-3-small')
expect(embeddingModel).toBeDefined()
expect(embeddingModel.provider).toBe('openai')
expect(embeddingModel.modelId).toBe('text-embedding-3-small')
})
it('should handle concurrent model resolutions', async () => {
const hub = await createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([
['openai', { apiKey: 'openai-key' }],
['anthropic', { apiKey: 'anthropic-key' }]
])
})
// Concurrent model resolutions
const models = await Promise.all([
Promise.resolve(hub.languageModel('openai|gpt-4')),
Promise.resolve(hub.languageModel('anthropic|claude-3-5-sonnet')),
Promise.resolve(hub.languageModel('openai|gpt-3.5-turbo'))
])
expect(models).toHaveLength(3)
expect(models[0].provider).toBe('openai')
expect(models[0].modelId).toBe('gpt-4')
expect(models[1].provider).toBe('anthropic')
expect(models[1].modelId).toBe('claude-3-5-sonnet')
expect(models[2].provider).toBe('openai')
expect(models[2].modelId).toBe('gpt-3.5-turbo')
})
it('should work with middlewares', async () => {
const hub = await createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([['openai', { apiKey: 'test-key' }]])
})
const executor = RuntimeExecutor.create('test-hub', hub, {} as never, [])
// Create a mock middleware
const mockMiddleware = {
specificationVersion: 'v3' as const,
wrapGenerate: vi.fn((doGenerate) => doGenerate),
wrapStream: vi.fn((doStream) => doStream)
}
const result = await executor.streamText(
{
model: 'openai|gpt-4',
messages: [{ role: 'user', content: 'Test with middleware' }]
},
{ middlewares: [mockMiddleware] }
)
expect(result).toBeDefined()
})
})
describe('Multiple HubProvider Instances', () => {
it('should support multiple independent hub providers', async () => {
// Create first hub for OpenAI only
const openaiHub = await createHubProviderAsync({
hubId: 'openai-hub',
registry,
providerSettingsMap: new Map([['openai', { apiKey: 'openai-key' }]])
})
// Create second hub for Anthropic only
const anthropicHub = await createHubProviderAsync({
hubId: 'anthropic-hub',
registry,
providerSettingsMap: new Map([['anthropic', { apiKey: 'anthropic-key' }]])
})
// Both hubs should work independently
const openaiModel = openaiHub.languageModel('openai|gpt-4')
const anthropicModel = anthropicHub.languageModel('anthropic|claude-3-5-sonnet')
expect(openaiModel.provider).toBe('openai')
expect(anthropicModel.provider).toBe('anthropic')
// OpenAI hub should not have anthropic
expect(() => {
openaiHub.languageModel('anthropic|claude-3-5-sonnet')
}).toThrow(/Provider "anthropic" not initialized/)
// Anthropic hub should not have openai
expect(() => {
anthropicHub.languageModel('openai|gpt-4')
}).toThrow(/Provider "openai" not initialized/)
})
it('should support creating multiple executors from same hub', async () => {
const hub = await createHubProviderAsync({
hubId: 'shared-hub',
registry,
providerSettingsMap: new Map([
['openai', { apiKey: 'key-1' }],
['anthropic', { apiKey: 'key-2' }]
])
})
// Create multiple executors from the same hub
const executor1 = RuntimeExecutor.create('shared-hub', hub, {} as never, [])
const executor2 = RuntimeExecutor.create('shared-hub', hub, {} as never, [])
// Both executors should share the same hub and be able to resolve models
const model1 = hub.languageModel('openai|gpt-4')
const model2 = hub.languageModel('anthropic|claude-3-5-sonnet')
expect(executor1).toBeDefined()
expect(executor2).toBeDefined()
expect(model1.provider).toBe('openai')
expect(model2.provider).toBe('anthropic')
})
})
})

View File

@ -1,32 +1,30 @@
/**
* HubProvider Comprehensive Tests
* Tests hub provider routing, model resolution, and error handling
* Covers multi-provider routing with namespaced model IDs
* Updated for ExtensionRegistry architecture with createHubProviderAsync
*/
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3, ProviderV3 } from '@ai-sdk/provider'
import { customProvider, wrapProvider } from 'ai'
import { customProvider } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createMockEmbeddingModel, createMockImageModel, createMockLanguageModel } from '../../../__tests__'
import { DEFAULT_SEPARATOR, globalProviderInstanceRegistry } from '../core/ProviderInstanceRegistry'
import { createHubProvider, type HubProviderConfig, HubProviderError } from '../features/HubProvider'
// Mock dependencies
vi.mock('../core/ProviderInstanceRegistry', () => ({
globalProviderInstanceRegistry: {
getProvider: vi.fn()
},
DEFAULT_SEPARATOR: '|'
}))
import { createMockEmbeddingModel, createMockImageModel, createMockLanguageModel } from '@test-utils'
import { ExtensionRegistry } from '../core/ExtensionRegistry'
import { ProviderExtension } from '../core/ProviderExtension'
import {
createHubProviderAsync,
DEFAULT_SEPARATOR,
type HubProviderConfig,
HubProviderError
} from '../features/HubProvider'
vi.mock('ai', () => ({
customProvider: vi.fn((config) => config.fallbackProvider),
wrapProvider: vi.fn((config) => config.provider),
jsonSchema: vi.fn((schema) => schema)
}))
describe('HubProvider', () => {
let registry: ExtensionRegistry
let mockOpenAIProvider: ProviderV3
let mockAnthropicProvider: ProviderV3
let mockLanguageModel: LanguageModelV3
@ -36,7 +34,7 @@ describe('HubProvider', () => {
beforeEach(() => {
vi.clearAllMocks()
// Create mock models using global utilities
// Create mock models
mockLanguageModel = createMockLanguageModel({
provider: 'test',
modelId: 'test-model'
@ -67,150 +65,185 @@ describe('HubProvider', () => {
imageModel: vi.fn().mockReturnValue(mockImageModel)
} as ProviderV3
// Setup default mock implementation
vi.mocked(globalProviderInstanceRegistry.getProvider).mockImplementation((id) => {
if (id === 'openai') return mockOpenAIProvider
if (id === 'anthropic') return mockAnthropicProvider
return undefined
})
// Create registry and register extensions
registry = new ExtensionRegistry()
const openaiExtension = ProviderExtension.create({
name: 'openai',
create: () => mockOpenAIProvider
} as const)
const anthropicExtension = ProviderExtension.create({
name: 'anthropic',
create: () => mockAnthropicProvider
} as const)
registry.register(openaiExtension)
registry.register(anthropicExtension)
})
describe('Provider Creation', () => {
it('should create hub provider with basic config', () => {
it('should create hub provider with basic config', async () => {
const config: HubProviderConfig = {
hubId: 'test-hub'
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([['openai', { apiKey: 'test-key' }]])
}
const provider = createHubProvider(config)
const provider = await createHubProviderAsync(config)
expect(provider).toBeDefined()
expect(customProvider).toHaveBeenCalled()
})
it('should create provider with debug flag', () => {
it('should create provider with debug flag', async () => {
const config: HubProviderConfig = {
hubId: 'test-hub',
debug: true
debug: true,
registry,
providerSettingsMap: new Map([['openai', {}]])
}
const provider = createHubProvider(config)
const provider = await createHubProviderAsync(config)
expect(provider).toBeDefined()
})
it('should return ProviderV3 specification', () => {
const config: HubProviderConfig = {
hubId: 'aihubmix'
}
const provider = createHubProvider(config)
it('should return ProviderV3 specification', async () => {
const provider = await createHubProviderAsync({
hubId: 'aihubmix',
registry,
providerSettingsMap: new Map([
['openai', {}],
['anthropic', {}]
])
})
expect(provider).toHaveProperty('specificationVersion', 'v3')
expect(provider).toHaveProperty('languageModel')
expect(provider).toHaveProperty('embeddingModel')
expect(provider).toHaveProperty('imageModel')
})
it('should throw error if extension not found in registry', async () => {
await expect(
createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([['unknown-provider', {}]])
})
).rejects.toThrow(HubProviderError)
})
it('should pre-create all providers during initialization', async () => {
await createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([
['openai', { apiKey: 'key1' }],
['anthropic', { apiKey: 'key2' }]
])
})
// Both providers created successfully
expect(true).toBe(true)
})
})
describe('Model ID Parsing', () => {
it('should parse valid hub model ID format', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
it('should parse valid hub model ID format', async () => {
const provider = (await createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([['openai', {}]])
})) as ProviderV3
const modelId = `openai${DEFAULT_SEPARATOR}gpt-4`
const result = provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
const result = provider.languageModel(modelId)
expect(globalProviderInstanceRegistry.getProvider).toHaveBeenCalledWith('openai')
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4')
expect(result).toBe(mockLanguageModel)
})
it('should throw error for invalid model ID format', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
it('should throw error for invalid model ID format', async () => {
const provider = (await createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([['openai', {}]])
})) as ProviderV3
const invalidId = 'invalid-id-without-separator'
expect(() => provider.languageModel(invalidId)).toThrow(HubProviderError)
expect(() => provider.languageModel('invalid-id-without-separator')).toThrow(HubProviderError)
})
it('should throw error for model ID with multiple separators', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
it('should throw error for model ID with multiple separators', async () => {
const provider = (await createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([['openai', {}]])
})) as ProviderV3
const multiSeparatorId = `provider${DEFAULT_SEPARATOR}extra${DEFAULT_SEPARATOR}model`
expect(() => provider.languageModel(multiSeparatorId)).toThrow(HubProviderError)
expect(() => provider.languageModel(`provider${DEFAULT_SEPARATOR}extra${DEFAULT_SEPARATOR}model`)).toThrow(
HubProviderError
)
})
it('should throw error for empty model ID', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
it('should throw error for empty model ID', async () => {
const provider = (await createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([['openai', {}]])
})) as ProviderV3
expect(() => provider.languageModel('')).toThrow(HubProviderError)
})
it('should throw error for model ID with only separator', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
expect(() => provider.languageModel(DEFAULT_SEPARATOR)).toThrow(HubProviderError)
})
})
describe('Language Model Resolution', () => {
it('should route to correct provider for language model', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
it('should route to correct provider for language model', async () => {
const provider = (await createHubProviderAsync({
hubId: 'aihubmix',
registry,
providerSettingsMap: new Map([['openai', {}]])
})) as ProviderV3
const result = provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
expect(globalProviderInstanceRegistry.getProvider).toHaveBeenCalledWith('openai')
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4')
expect(result).toBe(mockLanguageModel)
})
it('should route different providers correctly', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
it('should route different providers correctly', async () => {
const provider = (await createHubProviderAsync({
hubId: 'aihubmix',
registry,
providerSettingsMap: new Map([
['openai', {}],
['anthropic', {}]
])
})) as ProviderV3
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
provider.languageModel(`anthropic${DEFAULT_SEPARATOR}claude-3`)
expect(globalProviderInstanceRegistry.getProvider).toHaveBeenCalledWith('openai')
expect(globalProviderInstanceRegistry.getProvider).toHaveBeenCalledWith('anthropic')
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4')
expect(mockAnthropicProvider.languageModel).toHaveBeenCalledWith('claude-3')
})
it('should wrap provider with wrapProvider', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
it('should throw HubProviderError if provider not initialized', async () => {
const provider = (await createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([['openai', {}]]) // Only openai initialized
})) as ProviderV3
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
expect(wrapProvider).toHaveBeenCalledWith({
provider: mockOpenAIProvider,
languageModelMiddleware: []
})
expect(() => provider.languageModel(`anthropic${DEFAULT_SEPARATOR}claude-3`)).toThrow(HubProviderError)
})
it('should throw HubProviderError if provider not initialized', () => {
vi.mocked(globalProviderInstanceRegistry.getProvider).mockReturnValue(undefined)
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
expect(() => provider.languageModel(`uninitialized${DEFAULT_SEPARATOR}model`)).toThrow(HubProviderError)
expect(() => provider.languageModel(`uninitialized${DEFAULT_SEPARATOR}model`)).toThrow(/not initialized/)
})
it('should include provider ID in error message', () => {
vi.mocked(globalProviderInstanceRegistry.getProvider).mockReturnValue(undefined)
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
it('should include provider ID in error message', async () => {
const provider = (await createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([['openai', {}]])
})) as ProviderV3
try {
provider.languageModel(`missing${DEFAULT_SEPARATOR}model`)
@ -225,20 +258,28 @@ describe('HubProvider', () => {
})
describe('Embedding Model Resolution', () => {
it('should route to correct provider for embedding model', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
it('should route to correct provider for embedding model', async () => {
const provider = (await createHubProviderAsync({
hubId: 'aihubmix',
registry,
providerSettingsMap: new Map([['openai', {}]])
})) as ProviderV3
const result = provider.embeddingModel(`openai${DEFAULT_SEPARATOR}text-embedding-3-small`)
expect(globalProviderInstanceRegistry.getProvider).toHaveBeenCalledWith('openai')
expect(mockOpenAIProvider.embeddingModel).toHaveBeenCalledWith('text-embedding-3-small')
expect(result).toBe(mockEmbeddingModel)
})
it('should handle different embedding providers', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
it('should handle different embedding providers', async () => {
const provider = (await createHubProviderAsync({
hubId: 'aihubmix',
registry,
providerSettingsMap: new Map([
['openai', {}],
['anthropic', {}]
])
})) as ProviderV3
provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada-002`)
provider.embeddingModel(`anthropic${DEFAULT_SEPARATOR}embed-v1`)
@ -246,32 +287,31 @@ describe('HubProvider', () => {
expect(mockOpenAIProvider.embeddingModel).toHaveBeenCalledWith('ada-002')
expect(mockAnthropicProvider.embeddingModel).toHaveBeenCalledWith('embed-v1')
})
it('should throw error for uninitialized embedding provider', () => {
vi.mocked(globalProviderInstanceRegistry.getProvider).mockReturnValue(undefined)
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
expect(() => provider.embeddingModel(`missing${DEFAULT_SEPARATOR}embed`)).toThrow(HubProviderError)
})
})
describe('Image Model Resolution', () => {
it('should route to correct provider for image model', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
it('should route to correct provider for image model', async () => {
const provider = (await createHubProviderAsync({
hubId: 'aihubmix',
registry,
providerSettingsMap: new Map([['openai', {}]])
})) as ProviderV3
const result = provider.imageModel(`openai${DEFAULT_SEPARATOR}dall-e-3`)
expect(globalProviderInstanceRegistry.getProvider).toHaveBeenCalledWith('openai')
expect(mockOpenAIProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
expect(result).toBe(mockImageModel)
})
it('should handle different image providers', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
it('should handle different image providers', async () => {
const provider = (await createHubProviderAsync({
hubId: 'aihubmix',
registry,
providerSettingsMap: new Map([
['openai', {}],
['anthropic', {}]
])
})) as ProviderV3
provider.imageModel(`openai${DEFAULT_SEPARATOR}dall-e-3`)
provider.imageModel(`anthropic${DEFAULT_SEPARATOR}image-gen`)
@ -282,9 +322,9 @@ describe('HubProvider', () => {
})
describe('Special Model Types', () => {
it('should support transcription models', () => {
it('should support transcription models if provider has them', async () => {
const mockTranscriptionModel = {
specificationVersion: 'v3',
specificationVersion: 'v3' as const,
doTranscribe: vi.fn()
}
@ -293,86 +333,38 @@ describe('HubProvider', () => {
transcriptionModel: vi.fn().mockReturnValue(mockTranscriptionModel)
} as ProviderV3
vi.mocked(globalProviderInstanceRegistry.getProvider).mockReturnValue(providerWithTranscription)
// Replace the provider that will be created
const transcriptionExtension = ProviderExtension.create({
name: 'transcription-provider',
create: () => providerWithTranscription
} as const)
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
registry.register(transcriptionExtension)
const result = provider.transcriptionModel!(`openai${DEFAULT_SEPARATOR}whisper-1`)
const provider = (await createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([['transcription-provider', {}]])
})) as ProviderV3
const result = provider.transcriptionModel!(`transcription-provider${DEFAULT_SEPARATOR}whisper-1`)
expect(providerWithTranscription.transcriptionModel).toHaveBeenCalledWith('whisper-1')
expect(result).toBe(mockTranscriptionModel)
})
it('should throw error if provider does not support transcription', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
it('should throw error if provider does not support transcription', async () => {
const provider = (await createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([['openai', {}]])
})) as ProviderV3
expect(() => provider.transcriptionModel!(`openai${DEFAULT_SEPARATOR}whisper`)).toThrow(HubProviderError)
expect(() => provider.transcriptionModel!(`openai${DEFAULT_SEPARATOR}whisper`)).toThrow(
/does not support transcription/
)
})
it('should support speech models', () => {
const mockSpeechModel = {
specificationVersion: 'v3',
doGenerate: vi.fn()
}
const providerWithSpeech = {
...mockOpenAIProvider,
speechModel: vi.fn().mockReturnValue(mockSpeechModel)
} as ProviderV3
vi.mocked(globalProviderInstanceRegistry.getProvider).mockReturnValue(providerWithSpeech)
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
const result = provider.speechModel!(`openai${DEFAULT_SEPARATOR}tts-1`)
expect(providerWithSpeech.speechModel).toHaveBeenCalledWith('tts-1')
expect(result).toBe(mockSpeechModel)
})
it('should throw error if provider does not support speech', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
expect(() => provider.speechModel!(`openai${DEFAULT_SEPARATOR}tts-1`)).toThrow(HubProviderError)
expect(() => provider.speechModel!(`openai${DEFAULT_SEPARATOR}tts-1`)).toThrow(/does not support speech/)
})
it('should support reranking models', () => {
const mockRerankingModel = {
specificationVersion: 'v3',
doRerank: vi.fn()
}
const providerWithReranking = {
...mockOpenAIProvider,
rerankingModel: vi.fn().mockReturnValue(mockRerankingModel)
} as ProviderV3
vi.mocked(globalProviderInstanceRegistry.getProvider).mockReturnValue(providerWithReranking)
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
const result = provider.rerankingModel!(`openai${DEFAULT_SEPARATOR}rerank-v1`)
expect(providerWithReranking.rerankingModel).toHaveBeenCalledWith('rerank-v1')
expect(result).toBe(mockRerankingModel)
})
it('should throw error if provider does not support reranking', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
expect(() => provider.rerankingModel!(`openai${DEFAULT_SEPARATOR}rerank`)).toThrow(HubProviderError)
expect(() => provider.rerankingModel!(`openai${DEFAULT_SEPARATOR}rerank`)).toThrow(/does not support reranking/)
})
})
describe('Error Handling', () => {
@ -395,106 +387,51 @@ describe('HubProvider', () => {
expect(error.providerId).toBeUndefined()
expect(error.originalError).toBeUndefined()
})
it('should wrap provider errors in HubProviderError', () => {
const providerError = new Error('Provider failed')
vi.mocked(globalProviderInstanceRegistry.getProvider).mockImplementation(() => {
throw providerError
})
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
try {
provider.languageModel(`failing${DEFAULT_SEPARATOR}model`)
expect.fail('Should have thrown HubProviderError')
} catch (error) {
expect(error).toBeInstanceOf(HubProviderError)
const hubError = error as HubProviderError
expect(hubError.originalError).toBe(providerError)
expect(hubError.message).toContain('Failed to get provider')
}
})
it('should handle null provider from registry', () => {
vi.mocked(globalProviderInstanceRegistry.getProvider).mockReturnValue(null as any)
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
expect(() => provider.languageModel(`null-provider${DEFAULT_SEPARATOR}model`)).toThrow(HubProviderError)
})
})
describe('Multi-Provider Scenarios', () => {
it('should handle sequential calls to different providers', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
it('should handle sequential calls to different providers', async () => {
const provider = (await createHubProviderAsync({
hubId: 'aihubmix',
registry,
providerSettingsMap: new Map([
['openai', {}],
['anthropic', {}]
])
})) as ProviderV3
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
provider.languageModel(`anthropic${DEFAULT_SEPARATOR}claude-3`)
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-3.5`)
expect(globalProviderInstanceRegistry.getProvider).toHaveBeenCalledTimes(3)
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledTimes(2)
expect(mockAnthropicProvider.languageModel).toHaveBeenCalledTimes(1)
})
it('should handle mixed model types from same provider', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
it('should handle mixed model types from same provider', async () => {
const provider = (await createHubProviderAsync({
hubId: 'aihubmix',
registry,
providerSettingsMap: new Map([['openai', {}]])
})) as ProviderV3
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada-002`)
provider.imageModel(`openai${DEFAULT_SEPARATOR}dall-e-3`)
expect(globalProviderInstanceRegistry.getProvider).toHaveBeenCalledTimes(3)
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4')
expect(mockOpenAIProvider.embeddingModel).toHaveBeenCalledWith('ada-002')
expect(mockOpenAIProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
})
it('should cache provider lookups', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-3.5`)
// Should call getProvider twice (once per model call)
expect(globalProviderInstanceRegistry.getProvider).toHaveBeenCalledTimes(2)
})
})
describe('Provider Wrapping', () => {
it('should wrap all providers with empty middleware', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
expect(wrapProvider).toHaveBeenCalledWith({
provider: mockOpenAIProvider,
languageModelMiddleware: []
})
})
it('should wrap providers for all model types', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada`)
provider.imageModel(`openai${DEFAULT_SEPARATOR}dalle`)
expect(wrapProvider).toHaveBeenCalledTimes(3)
})
})
describe('Type Safety', () => {
it('should return properly typed LanguageModelV3', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
it('should return properly typed LanguageModelV3', async () => {
const provider = (await createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([['openai', {}]])
})) as ProviderV3
const result = provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
@ -503,9 +440,12 @@ describe('HubProvider', () => {
expect(result).toHaveProperty('doStream')
})
it('should return properly typed EmbeddingModelV3', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
it('should return properly typed EmbeddingModelV3', async () => {
const provider = (await createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([['openai', {}]])
})) as ProviderV3
const result = provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada`)
@ -513,9 +453,12 @@ describe('HubProvider', () => {
expect(result).toHaveProperty('doEmbed')
})
it('should return properly typed ImageModelV3', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
it('should return properly typed ImageModelV3', async () => {
const provider = (await createHubProviderAsync({
hubId: 'test-hub',
registry,
providerSettingsMap: new Map([['openai', {}]])
})) as ProviderV3
const result = provider.imageModel(`openai${DEFAULT_SEPARATOR}dalle`)
@ -523,119 +466,4 @@ describe('HubProvider', () => {
expect(result).toHaveProperty('doGenerate')
})
})
describe('Dependency Injection', () => {
it('should use global registry by default', () => {
vi.mocked(globalProviderInstanceRegistry.getProvider).mockReturnValue(mockOpenAIProvider)
const hubProvider = createHubProvider({ hubId: 'test-hub' })
const provider = hubProvider as ProviderV3
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
// Should call global registry
expect(globalProviderInstanceRegistry.getProvider).toHaveBeenCalledWith('openai')
})
it('should use custom registry when provided', () => {
const customRegistry = {
getProvider: vi.fn().mockReturnValue(mockOpenAIProvider)
}
const hubProvider = createHubProvider({
hubId: 'test-hub',
providerRegistry: customRegistry as any
})
const provider = hubProvider as ProviderV3
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
// Should call custom registry, not global
expect(customRegistry.getProvider).toHaveBeenCalledWith('openai')
expect(globalProviderInstanceRegistry.getProvider).not.toHaveBeenCalled()
})
it('should allow testing with mock registry', () => {
const mockRegistry = {
getProvider: vi.fn((id: string) => {
if (id === 'test-provider') {
return mockOpenAIProvider
}
return undefined
})
}
const hubProvider = createHubProvider({
hubId: 'test-hub',
providerRegistry: mockRegistry as any
})
const provider = hubProvider as ProviderV3
// Should work with mock registry
const model = provider.languageModel(`test-provider${DEFAULT_SEPARATOR}model`)
expect(mockRegistry.getProvider).toHaveBeenCalledWith('test-provider')
expect(model).toBeDefined()
})
it('should throw error when provider not found in custom registry', () => {
const emptyRegistry = {
getProvider: vi.fn().mockReturnValue(undefined)
}
const hubProvider = createHubProvider({
hubId: 'test-hub',
providerRegistry: emptyRegistry as any
})
const provider = hubProvider as ProviderV3
expect(() => {
provider.languageModel(`unknown${DEFAULT_SEPARATOR}model`)
}).toThrow(HubProviderError)
expect(emptyRegistry.getProvider).toHaveBeenCalledWith('unknown')
})
it('should support multiple hub instances with different registries', () => {
const registry1 = {
getProvider: vi.fn().mockReturnValue(mockOpenAIProvider)
}
const registry2 = {
getProvider: vi.fn().mockReturnValue(mockAnthropicProvider)
}
const hub1 = createHubProvider({
hubId: 'hub-1',
providerRegistry: registry1 as any
}) as ProviderV3
const hub2 = createHubProvider({
hubId: 'hub-2',
providerRegistry: registry2 as any
}) as ProviderV3
// Each hub should use its own registry
hub1.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
hub2.languageModel(`anthropic${DEFAULT_SEPARATOR}claude`)
expect(registry1.getProvider).toHaveBeenCalledWith('openai')
expect(registry2.getProvider).toHaveBeenCalledWith('anthropic')
// Registries should be independent
expect(registry1.getProvider).not.toHaveBeenCalledWith('anthropic')
expect(registry2.getProvider).not.toHaveBeenCalledWith('openai')
})
it('should make hubId optional and default to "hub"', () => {
vi.mocked(globalProviderInstanceRegistry.getProvider).mockReturnValue(undefined)
const hubProvider = createHubProvider() // No config
const provider = hubProvider as ProviderV3
// Should use default hubId 'hub' in error messages
expect(() => {
provider.languageModel(`unknown${DEFAULT_SEPARATOR}model`)
}).toThrow(HubProviderError)
})
})
})

View File

@ -5,7 +5,7 @@
import type { ProviderV3 } from '@ai-sdk/provider'
import { describe, expect, it, vi } from 'vitest'
import { createMockProviderV3 } from '../../../__tests__'
import { createMockProviderV3 } from '@test-utils'
import {
createProviderExtension,
ProviderExtension,
@ -85,7 +85,7 @@ describe('ProviderExtension', () => {
expect(extension.config.defaultOptions).toEqual({ apiKey: 'initial-key' })
})
it('should validate config from function same as from object', () => {
it('should validate config from function same as from object', async () => {
expect(() => {
ProviderExtension.create(() => ({
name: '', // Invalid
@ -93,15 +93,16 @@ describe('ProviderExtension', () => {
}))
}).toThrow('name is required')
expect(() => {
ProviderExtension.create(
() =>
({
name: 'test-provider'
// Missing create
}) as any
)
}).toThrow('either create or import must be provided')
// Note: create/import validation happens at runtime in createProvider(), not in constructor
// Extension can be created without create/import, but createProvider() will throw
const extension = ProviderExtension.create(
() =>
({
name: 'test-provider'
// Missing create
}) as any
)
await expect(extension.createProvider()).rejects.toThrow('cannot create provider')
})
})
@ -115,21 +116,23 @@ describe('ProviderExtension', () => {
}).toThrow('name is required')
})
it('should throw error if neither create nor import is provided', () => {
expect(() => {
new ProviderExtension({
name: 'test-provider'
} as any)
}).toThrow('either create or import must be provided')
it('should throw error at runtime if neither create nor import is provided', async () => {
// Constructor doesn't validate create/import - validation happens at runtime
const extension = new ProviderExtension({
name: 'test-provider'
} as any)
await expect(extension.createProvider()).rejects.toThrow('cannot create provider')
})
it('should throw error if import is provided without creatorFunctionName', () => {
expect(() => {
new ProviderExtension({
name: 'test-provider',
import: async () => ({})
} as any)
}).toThrow('creatorFunctionName is required when using import')
it('should throw error at runtime if import is provided without creatorFunctionName', async () => {
// Constructor doesn't validate creatorFunctionName - validation happens at runtime
const extension = new ProviderExtension({
name: 'test-provider',
import: async () => ({})
} as any)
await expect(extension.createProvider()).rejects.toThrow('cannot create provider')
})
it('should create extension with valid config', () => {
@ -808,16 +811,26 @@ describe('ProviderExtension', () => {
expect(onAfterCreate).toHaveBeenCalledTimes(1)
})
it('should support explicit ID parameter', async () => {
it('should support variant suffix parameter', async () => {
const extension = new ProviderExtension<TestSettings>({
name: 'test-provider',
create: createMockProviderV3 as any
create: createMockProviderV3 as any,
variants: [
{
suffix: 'chat',
name: 'Test Chat',
transform: (provider) => provider
}
]
})
const settings = { apiKey: 'test-key' }
// Should not throw when providing explicit ID
await expect(extension.createProvider(settings, 'custom-id')).resolves.toBeDefined()
// Should work when providing a valid variant suffix
await expect(extension.createProvider(settings, 'chat')).resolves.toBeDefined()
// Should throw for unknown variant suffix
await expect(extension.createProvider(settings, 'unknown')).rejects.toThrow('variant "unknown" not found')
})
it('should support dynamic import providers', async () => {

View File

@ -1,445 +0,0 @@
/**
* Provider Extensions Integration Tests
* extensions
*/
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { extensionRegistry } from '../core/ExtensionRegistry'
import { AnthropicExtension } from '../extensions/anthropic'
import { AzureExtension } from '../extensions/azure'
import { OpenAIExtension } from '../extensions/openai'
// Mock fetch for health checks
global.fetch = vi.fn()
describe('Provider Extensions Integration', () => {
beforeEach(() => {
// Clear registry before each test
extensionRegistry.clear()
extensionRegistry.clearCache()
vi.clearAllMocks()
})
afterEach(() => {
extensionRegistry.clear()
extensionRegistry.clearCache()
})
describe('OpenAI Extension', () => {
it('should register and create provider successfully', async () => {
// Register extension
extensionRegistry.register(OpenAIExtension)
// Verify registration
expect(extensionRegistry.has('openai')).toBe(true)
expect(extensionRegistry.has('oai')).toBe(true) // alias
// Create provider
const provider = await extensionRegistry.createProvider('openai', {
apiKey: 'sk-test-key-123',
baseURL: 'https://api.openai.com/v1'
})
expect(provider).toBeDefined()
})
it('should execute onBeforeCreate hook for validation', async () => {
extensionRegistry.register(OpenAIExtension)
// Invalid API key (doesn't start with "sk-")
await expect(
extensionRegistry.createProvider('openai', {
apiKey: 'invalid-key'
})
).rejects.toThrow('Invalid OpenAI API key format')
// Missing API key
await expect(extensionRegistry.createProvider('openai', {})).rejects.toThrow('OpenAI API key is required')
})
it('should execute onAfterCreate hook for caching', async () => {
extensionRegistry.register(OpenAIExtension)
const settings = {
apiKey: 'sk-test-key-123',
baseURL: 'https://api.openai.com/v1'
}
// Create provider
const provider = await extensionRegistry.createProvider('openai', settings)
// Check extension's internal storage (custom cache)
const ext = extensionRegistry.get('openai')
const cache = ext?.storage.get('providerCache')
expect(cache).toBeDefined()
expect(cache?.has('sk-test-key-123')).toBe(true)
expect(cache?.get('sk-test-key-123')).toBe(provider)
})
it('should cache providers based on settings', async () => {
extensionRegistry.register(OpenAIExtension)
const settings = {
apiKey: 'sk-test-key-123',
baseURL: 'https://api.openai.com/v1'
}
// First call - creates provider
const provider1 = await extensionRegistry.createProvider('openai', settings)
// Second call with same settings - returns cached
const provider2 = await extensionRegistry.createProvider('openai', settings)
expect(provider1).toBe(provider2) // Same instance
// Different settings - creates new provider
const provider3 = await extensionRegistry.createProvider('openai', {
apiKey: 'sk-different-key-456',
baseURL: 'https://api.openai.com/v1'
})
expect(provider3).not.toBe(provider1) // Different instance
})
it('should support openai-chat variant', async () => {
extensionRegistry.register(OpenAIExtension)
// Verify variant ID exists
const providerIds = OpenAIExtension.getProviderIds()
expect(providerIds).toContain('openai')
expect(providerIds).toContain('openai-chat')
// Create variant provider
await extensionRegistry.createAndRegisterProvider('openai', {
apiKey: 'sk-test-key-123'
})
// Both base and variant should be available
const stats = extensionRegistry.getStats()
expect(stats.totalExtensions).toBe(1)
expect(stats.extensionsWithVariants).toBe(1)
})
it('should skip cache when requested', async () => {
extensionRegistry.register(OpenAIExtension)
const settings = {
apiKey: 'sk-test-key-123'
}
// First creation
const provider1 = await extensionRegistry.createProvider('openai', settings)
// Skip cache - creates new instance
const provider2 = await extensionRegistry.createProvider('openai', settings, {
skipCache: true
})
expect(provider2).not.toBe(provider1) // Different instances
})
it('should track health status in storage', async () => {
extensionRegistry.register(OpenAIExtension)
await extensionRegistry.createProvider('openai', {
apiKey: 'sk-test-key-123'
})
const ext = extensionRegistry.get('openai')
const health = ext?.storage.get('healthStatus')
expect(health).toBeDefined()
expect(health?.isHealthy).toBe(true)
expect(health?.consecutiveFailures).toBe(0)
expect(health?.lastCheckTime).toBeGreaterThan(0)
})
})
describe('Anthropic Extension', () => {
it('should validate Anthropic API key format', async () => {
extensionRegistry.register(AnthropicExtension)
// Invalid format (doesn't start with "sk-ant-")
await expect(
extensionRegistry.createProvider('anthropic', {
apiKey: 'sk-test-key'
})
).rejects.toThrow('Invalid Anthropic API key format')
// Missing API key
await expect(extensionRegistry.createProvider('anthropic', {})).rejects.toThrow('Anthropic API key is required')
// Valid format
const provider = await extensionRegistry.createProvider('anthropic', {
apiKey: 'sk-ant-test-key-123'
})
expect(provider).toBeDefined()
})
it('should validate baseURL format', async () => {
extensionRegistry.register(AnthropicExtension)
// Invalid baseURL (no http/https)
await expect(
extensionRegistry.createProvider('anthropic', {
apiKey: 'sk-ant-test-key',
baseURL: 'api.anthropic.com' // Missing protocol
})
).rejects.toThrow('Invalid baseURL format')
// Valid baseURL
const provider = await extensionRegistry.createProvider('anthropic', {
apiKey: 'sk-ant-test-key',
baseURL: 'https://api.anthropic.com'
})
expect(provider).toBeDefined()
})
it('should track creation statistics', async () => {
extensionRegistry.register(AnthropicExtension)
// First successful creation
await extensionRegistry.createProvider('anthropic', {
apiKey: 'sk-ant-test-key-1'
})
const ext = extensionRegistry.get('anthropic')
let stats = ext?.storage.get('stats')
expect(stats?.totalCreations).toBe(1)
expect(stats?.failedCreations).toBe(0)
// Failed creation
try {
await extensionRegistry.createProvider('anthropic', {
apiKey: 'invalid-key'
})
} catch {
// Expected error
}
stats = ext?.storage.get('stats')
expect(stats?.totalCreations).toBe(2)
expect(stats?.failedCreations).toBe(1)
// Second successful creation
await extensionRegistry.createProvider('anthropic', {
apiKey: 'sk-ant-test-key-2'
})
stats = ext?.storage.get('stats')
expect(stats?.totalCreations).toBe(3)
expect(stats?.failedCreations).toBe(1)
})
it('should record lastSuccessfulCreation timestamp', async () => {
extensionRegistry.register(AnthropicExtension)
const before = Date.now()
await extensionRegistry.createProvider('anthropic', {
apiKey: 'sk-ant-test-key'
})
const after = Date.now()
const ext = extensionRegistry.get('anthropic')
const timestamp = ext?.storage.get('lastSuccessfulCreation')
expect(timestamp).toBeDefined()
expect(timestamp).toBeGreaterThanOrEqual(before)
expect(timestamp).toBeLessThanOrEqual(after)
})
it('should support claude alias', async () => {
extensionRegistry.register(AnthropicExtension)
// Access via alias
expect(extensionRegistry.has('claude')).toBe(true)
const provider = await extensionRegistry.createProvider('claude', {
apiKey: 'sk-ant-test-key'
})
expect(provider).toBeDefined()
})
})
describe('Azure Extension', () => {
it('should validate Azure configuration', async () => {
extensionRegistry.register(AzureExtension)
// Missing both resourceName and baseURL
await expect(
extensionRegistry.createProvider('azure', {
apiKey: 'test-key'
})
).rejects.toThrow('Azure OpenAI requires either resourceName or baseURL')
// Missing API key
await expect(
extensionRegistry.createProvider('azure', {
resourceName: 'my-resource'
})
).rejects.toThrow('Azure OpenAI API key is required')
})
it('should validate resourceName format', async () => {
extensionRegistry.register(AzureExtension)
// Invalid format (uppercase)
await expect(
extensionRegistry.createProvider('azure', {
resourceName: 'MyResource',
apiKey: 'test-key'
})
).rejects.toThrow('Invalid Azure resource name format')
// Invalid format (special chars)
await expect(
extensionRegistry.createProvider('azure', {
resourceName: 'my_resource',
apiKey: 'test-key'
})
).rejects.toThrow('Invalid Azure resource name format')
// Valid format
const provider = await extensionRegistry.createProvider('azure', {
resourceName: 'my-resource-123',
apiKey: 'test-key'
})
expect(provider).toBeDefined()
})
it('should cache resource endpoints', async () => {
extensionRegistry.register(AzureExtension)
await extensionRegistry.createProvider('azure', {
resourceName: 'my-resource',
apiKey: 'test-key'
})
const ext = extensionRegistry.get('azure')
const endpoints = ext?.storage.get('resourceEndpoints')
expect(endpoints).toBeDefined()
expect(endpoints?.has('my-resource')).toBe(true)
expect(endpoints?.get('my-resource')).toBe('https://my-resource.openai.azure.com')
})
it('should track validated deployments', async () => {
extensionRegistry.register(AzureExtension)
// First deployment
await extensionRegistry.createProvider('azure', {
resourceName: 'resource-1',
apiKey: 'test-key-1'
})
const ext = extensionRegistry.get('azure')
let deployments = ext?.storage.get('validatedDeployments')
expect(deployments?.size).toBe(1)
expect(deployments?.has('resource-1')).toBe(true)
// Second deployment
await extensionRegistry.createProvider('azure', {
resourceName: 'resource-2',
apiKey: 'test-key-2'
})
deployments = ext?.storage.get('validatedDeployments')
expect(deployments?.size).toBe(2)
expect(deployments?.has('resource-2')).toBe(true)
})
it('should support azure-responses variant', async () => {
extensionRegistry.register(AzureExtension)
const providerIds = AzureExtension.getProviderIds()
expect(providerIds).toContain('azure')
expect(providerIds).toContain('azure-responses')
})
it('should support azure-openai alias', async () => {
extensionRegistry.register(AzureExtension)
expect(extensionRegistry.has('azure-openai')).toBe(true)
const provider = await extensionRegistry.createProvider('azure-openai', {
resourceName: 'my-resource',
apiKey: 'test-key'
})
expect(provider).toBeDefined()
})
})
describe('Multiple Extensions', () => {
it('should register multiple extensions simultaneously', () => {
extensionRegistry.registerAll([OpenAIExtension, AnthropicExtension, AzureExtension])
const stats = extensionRegistry.getStats()
expect(stats.totalExtensions).toBe(3)
expect(stats.extensionsWithVariants).toBe(2) // OpenAI and Azure
})
it('should maintain separate storage for each extension', async () => {
extensionRegistry.registerAll([OpenAIExtension, AnthropicExtension])
// Create providers
await extensionRegistry.createProvider('openai', {
apiKey: 'sk-test-key'
})
await extensionRegistry.createProvider('anthropic', {
apiKey: 'sk-ant-test-key'
})
// Check OpenAI storage
const openaiExt = extensionRegistry.get('openai')
const openaiCache = openaiExt?.storage.get('providerCache')
expect(openaiCache?.size).toBe(1)
// Check Anthropic storage
const anthropicExt = extensionRegistry.get('anthropic')
const anthropicStats = anthropicExt?.storage.get('stats')
expect(anthropicStats?.totalCreations).toBe(1)
// Storages are independent
expect(openaiExt?.storage.get('stats')).toBeUndefined()
expect(anthropicExt?.storage.get('providerCache')).toBeUndefined()
})
it('should clear cache per extension', async () => {
extensionRegistry.registerAll([OpenAIExtension, AnthropicExtension])
// Create providers
await extensionRegistry.createProvider('openai', {
apiKey: 'sk-test-key'
})
await extensionRegistry.createProvider('anthropic', {
apiKey: 'sk-ant-test-key'
})
// Verify both are cached
const stats1 = extensionRegistry.getStats()
expect(stats1.cachedProviders).toBe(2)
// Clear only OpenAI cache
extensionRegistry.clearCache('openai')
const stats2 = extensionRegistry.getStats()
expect(stats2.cachedProviders).toBe(1) // Only Anthropic remains
// Clear all caches
extensionRegistry.clearCache()
const stats3 = extensionRegistry.getStats()
expect(stats3.cachedProviders).toBe(0)
})
})
})

View File

@ -1,165 +0,0 @@
import type { ProviderV3 } from '@ai-sdk/provider'
import { afterEach, beforeEach, describe, expect, it } from 'vitest'
import { ExtensionRegistry } from '../core/ExtensionRegistry'
import { isRegisteredProvider } from '../core/initialization'
import { ProviderExtension } from '../core/ProviderExtension'
import { ProviderInstanceRegistry } from '../core/ProviderInstanceRegistry'
// Mock provider for testing
const createMockProviderV3 = (): ProviderV3 => ({
specificationVersion: 'v3' as const,
languageModel: () => ({}) as any,
embeddingModel: () => ({}) as any,
imageModel: () => ({}) as any
})
describe('initialization utilities', () => {
let testExtensionRegistry: ExtensionRegistry
let testInstanceRegistry: ProviderInstanceRegistry
beforeEach(() => {
testExtensionRegistry = new ExtensionRegistry()
testInstanceRegistry = new ProviderInstanceRegistry()
})
afterEach(() => {
// Clean up registries
testExtensionRegistry = null as any
testInstanceRegistry = null as any
})
describe('isRegisteredProvider()', () => {
it('should return true for providers registered in Extension Registry', () => {
testExtensionRegistry.register(
new ProviderExtension({
name: 'test-provider',
create: createMockProviderV3
})
)
// Note: isRegisteredProvider uses global registries, so this tests the concept
// In practice, we'd need to modify the function to accept registries as parameters
// For now, this documents the expected behavior
expect(typeof isRegisteredProvider).toBe('function')
})
it('should return true for providers registered in Provider Instance Registry', () => {
const mockProvider = createMockProviderV3()
testInstanceRegistry.registerProvider('test-provider', mockProvider)
// Note: This tests the concept - actual implementation uses global registries
expect(testInstanceRegistry.getProvider('test-provider')).toBeDefined()
})
it('should return false for unregistered providers', () => {
// Both registries are empty
const result = isRegisteredProvider('unknown-provider')
// Note: This will check global registries
expect(typeof result).toBe('boolean')
})
it('should work with provider aliases', () => {
testExtensionRegistry.register(
new ProviderExtension({
name: 'openai',
aliases: ['oai'],
create: createMockProviderV3
})
)
// Should be able to check both main ID and alias
expect(testExtensionRegistry.has('openai')).toBe(true)
expect(testExtensionRegistry.has('oai')).toBe(true)
})
it('should work with variant IDs', () => {
testExtensionRegistry.register(
new ProviderExtension({
name: 'openai',
create: createMockProviderV3,
variants: [
{
suffix: 'chat',
name: 'OpenAI Chat',
transform: (provider) => provider
}
]
})
)
// Base provider should be registered
expect(testExtensionRegistry.has('openai')).toBe(true)
// Variant ID can be checked with isVariant method
expect(testExtensionRegistry.isVariant('openai-chat')).toBe(true)
// Base provider ID should be resolvable from variant
expect(testExtensionRegistry.getBaseProviderId('openai-chat')).toBe('openai')
})
it('should return true if provider is in either registry', () => {
// Register in extension registry only
testExtensionRegistry.register(
new ProviderExtension({
name: 'ext-only',
create: createMockProviderV3
})
)
// Register in instance registry only
const mockProvider = createMockProviderV3()
testInstanceRegistry.registerProvider('instance-only', mockProvider)
// Both should be considered registered
expect(testExtensionRegistry.has('ext-only')).toBe(true)
expect(testInstanceRegistry.getProvider('instance-only')).toBeDefined()
})
it('should handle empty string gracefully', () => {
const result = isRegisteredProvider('')
expect(typeof result).toBe('boolean')
})
it('should be case-sensitive', () => {
testExtensionRegistry.register(
new ProviderExtension({
name: 'openai',
create: createMockProviderV3
})
)
expect(testExtensionRegistry.has('openai')).toBe(true)
expect(testExtensionRegistry.has('OpenAI')).toBe(false)
expect(testExtensionRegistry.has('OPENAI')).toBe(false)
})
})
describe('Integration: isRegisteredProvider with actual registries', () => {
it('should correctly identify providers across both registries', () => {
// This test documents the expected behavior when both registries are involved
// isRegisteredProvider checks: extensionRegistry.has(id) || instanceRegistry.getProvider(id) !== undefined
testExtensionRegistry.register(
new ProviderExtension({
name: 'registered-ext',
create: createMockProviderV3
})
)
const mockProvider = createMockProviderV3()
testInstanceRegistry.registerProvider('registered-instance', mockProvider)
// Extension registry check
expect(testExtensionRegistry.has('registered-ext')).toBe(true)
// Instance registry check
expect(testInstanceRegistry.getProvider('registered-instance')).toBeDefined()
// Unregistered provider
expect(testExtensionRegistry.has('unregistered')).toBe(false)
expect(testInstanceRegistry.getProvider('unregistered')).toBeUndefined()
})
})
})

View File

@ -5,7 +5,7 @@
import type { ProviderV3 } from '@ai-sdk/provider'
import type { RegisteredProviderId } from '../index'
import type { CoreProviderSettingsMap, RegisteredProviderId } from '../index'
import { type ProviderExtension } from './ProviderExtension'
import { ProviderCreationError } from './utils'
@ -52,15 +52,12 @@ export class ExtensionRegistry {
register(extension: ProviderExtension<any, any, any>): this {
const { name, aliases, variants } = extension.config
// 检查主 ID 冲突
if (this.extensions.has(name)) {
throw new Error(`Provider extension "${name}" is already registered`)
}
// 注册主 Extension
this.extensions.set(name, extension)
// 注册别名
if (aliases) {
for (const alias of aliases) {
if (this.aliasMap.has(alias)) {
@ -70,7 +67,6 @@ export class ExtensionRegistry {
}
}
// 注册变体 ID
if (variants) {
for (const variant of variants) {
const variantId = `${name}-${variant.suffix}`
@ -106,10 +102,8 @@ export class ExtensionRegistry {
return false
}
// 删除主 Extension
this.extensions.delete(name)
// 删除别名
if (extension.config.aliases) {
for (const alias of extension.config.aliases) {
this.aliasMap.delete(alias)
@ -123,12 +117,10 @@ export class ExtensionRegistry {
* Extension
*/
get(id: string): ProviderExtension<any, any, any> | undefined {
// 直接查找
if (this.extensions.has(id)) {
return this.extensions.get(id)
}
// 通过别名查找
const realName = this.aliasMap.get(id)
if (realName) {
return this.extensions.get(realName)
@ -250,17 +242,7 @@ export class ExtensionRegistry {
* ```
*/
parseProviderId(providerId: string): { baseId: RegisteredProviderId; mode?: string; isVariant: boolean } | null {
// 先检查是否是已注册的 extension直接或通过别名
const extension = this.get(providerId)
if (extension) {
// 是基础 ID 或别名,不是变体
return {
baseId: extension.config.name as RegisteredProviderId,
isVariant: false
}
}
// 遍历所有 extensions查找匹配的变体
// 先遍历所有 extensions查找匹配的变体优先于别名检查
for (const ext of this.extensions.values()) {
if (!ext.config.variants) {
continue
@ -279,6 +261,16 @@ export class ExtensionRegistry {
}
}
// 再检查是否是已注册的 extension直接或通过别名
const extension = this.get(providerId)
if (extension) {
// 是基础 ID 或别名,不是变体
return {
baseId: extension.config.name as RegisteredProviderId,
isVariant: false
}
}
// 无法解析
return null
}
@ -379,15 +371,21 @@ export class ExtensionRegistry {
/**
* provider
* ProviderExtension
*
* @param id - Provider ID
* @param settings - Provider
* @param explicitId - IDAI SDK注册
* :
* 1. - 使 provider ID
* 2. - 使 ID provider
*
* @param id - Provider ID
* @param settings - Provider
* @returns Provider
*/
async createProvider(id: string, settings?: any, explicitId?: string): Promise<ProviderV3> {
// 解析 provider ID提取基础 ID 和变体后缀
async createProvider<T extends RegisteredProviderId & keyof CoreProviderSettingsMap>(
id: T,
settings: CoreProviderSettingsMap[T]
): Promise<ProviderV3>
async createProvider(id: string, settings?: unknown): Promise<ProviderV3>
async createProvider(id: string, settings?: unknown): Promise<ProviderV3> {
const parsed = this.parseProviderId(id)
if (!parsed) {
throw new Error(`Provider extension "${id}" not found. Did you forget to register it?`)
@ -395,16 +393,13 @@ export class ExtensionRegistry {
const { baseId, mode: variantSuffix } = parsed
// 获取基础 extension
const extension = this.get(baseId)
if (!extension) {
throw new Error(`Provider extension "${baseId}" not found. Did you forget to register it?`)
}
try {
// 委托给 Extension 的 createProvider 方法
// Extension 负责缓存、生命周期钩子、AI SDK 注册、变体转换等
return await extension.createProvider(settings, explicitId, variantSuffix)
return await extension.createProvider(settings, variantSuffix)
} catch (error) {
throw new ProviderCreationError(
`Failed to create provider "${id}"`,

View File

@ -1,19 +1,9 @@
import type { ProviderV3 } from '@ai-sdk/provider'
import { LRUCache } from 'lru-cache'
import { deepMergeObjects } from '../../utils'
import type { ExtensionContext, ExtensionStorage, LifecycleHooks, ProviderVariant, StorageAccessor } from '../types'
/**
* Provider
* Extension provider HubProvider 使
* Key: explicit ID ()
* Value: Provider
*/
export const globalProviderStorage = new Map<string, ProviderV3>()
/**
* Provider
*/
export type ProviderCreatorFunction<TSettings = any> = (settings?: TSettings) => ProviderV3 | Promise<ProviderV3>
/**
@ -80,19 +70,10 @@ interface ProviderExtensionConfigWithCreate<
TProvider extends ProviderV3 = ProviderV3,
TName extends string = string
> extends ProviderExtensionConfigBase<TSettings, TStorage, TProvider, TName> {
/**
* provider
*/
create: ProviderCreatorFunction<TSettings>
/**
* 使 import create
*/
import?: never
/**
* 使 creatorFunctionName create
*/
creatorFunctionName?: never
}
@ -107,21 +88,10 @@ interface ProviderExtensionConfigWithImport<
TProvider extends ProviderV3 = ProviderV3,
TName extends string = string
> extends ProviderExtensionConfigBase<TSettings, TStorage, TProvider, TName> {
/**
* 使 create import
*/
create?: never
/**
*
* provider
*/
import: () => Promise<ProviderModule<TSettings>>
/**
* creator
* import 使
*/
creatorFunctionName: string
}
@ -196,20 +166,23 @@ export class ProviderExtension<
> {
private _storage: Map<string, any>
/** Provider 实例缓存 - 按 settings hash 存储 */
private instances: Map<string, TProvider> = new Map()
/** Provider 实例缓存 - 按 settings hash 存储LRU 自动清理 */
private instances: LRUCache<string, TProvider>
/** Settings hash 映射表 - 用于验证缓存是否仍然有效 */
private settingsHashes: Map<string, TSettings | undefined> = new Map()
constructor(public readonly config: TConfig) {
// 验证配置
if (!config.name) {
throw new Error('ProviderExtension: name is required')
}
// 初始化 storage
this._storage = new Map(Object.entries(config.initialStorage || {}))
this.instances = new LRUCache<string, TProvider>({
max: 10,
updateAgeOnGet: true
})
}
/**
@ -370,27 +343,14 @@ export class ProviderExtension<
}
/**
* Provider
* Extension provider Map HubProvider 使
* @private
*/
private registerToAiSdk(provider: TProvider, explicitId: string): void {
// 注册到全局 provider storage
// 使用 explicit ID 作为 key
globalProviderStorage.set(explicitId, provider as any)
}
/**
* Provider
* Provider
* settings settings
*
* @param settings - Provider
* @param explicitId - ID AI SDK
* @param variantSuffix -
* @returns Provider
*/
async createProvider(settings?: TSettings, explicitId?: string, variantSuffix?: string): Promise<TProvider> {
// 验证变体后缀(如果提供)
async createProvider(settings?: TSettings, variantSuffix?: string): Promise<TProvider> {
if (variantSuffix) {
const variant = this.getVariant(variantSuffix)
if (!variant) {
@ -402,31 +362,22 @@ export class ProviderExtension<
}
// 合并 default options
const mergedSettings = deepMergeObjects(
(this.config.defaultOptions || {}) as any,
(settings || {}) as any
) as TSettings
const mergedSettings = deepMergeObjects(this.config.defaultOptions || {}, settings || {}) as TSettings
// 计算 hash包含变体后缀
const hash = this.computeHash(mergedSettings, variantSuffix)
// 检查缓存
const cachedInstance = this.instances.get(hash)
if (cachedInstance) {
return cachedInstance
}
// 执行 onBeforeCreate 钩子
await this.executeHook('onBeforeCreate', mergedSettings)
// 创建基础 provider 实例
let baseProvider: ProviderV3
if (this.config.create) {
// 使用直接创建函数
baseProvider = await Promise.resolve(this.config.create(mergedSettings))
} else if (this.config.import && this.config.creatorFunctionName) {
// 动态导入
const module = await this.config.import()
const creatorFn = module[this.config.creatorFunctionName]
@ -441,39 +392,19 @@ export class ProviderExtension<
throw new Error(`ProviderExtension "${this.config.name}": cannot create provider, invalid configuration`)
}
// 应用变体转换(如果提供了变体后缀)
let finalProvider: TProvider
if (variantSuffix) {
const variant = this.getVariant(variantSuffix)!
// 应用变体的 transform 函数
finalProvider = (await Promise.resolve(variant.transform(baseProvider as TProvider, mergedSettings))) as TProvider
} else {
finalProvider = baseProvider as TProvider
}
// 执行 onAfterCreate 钩子
await this.executeHook('onAfterCreate', mergedSettings, finalProvider)
// 缓存实例
this.instances.set(hash, finalProvider)
this.settingsHashes.set(hash, mergedSettings)
// 确定注册 ID
const registrationId = (() => {
if (explicitId) {
return explicitId
}
// 如果是变体,使用 name-suffix:hash 格式
if (variantSuffix) {
return `${this.config.name}-${variantSuffix}:${hash}`
}
// 否则使用 name:hash
return `${this.config.name}:${hash}`
})()
// 注册到 AI SDK
this.registerToAiSdk(finalProvider, registrationId)
return finalProvider
}

View File

@ -33,7 +33,7 @@ import type {
} from '../types'
import { extensionRegistry } from './ExtensionRegistry'
import type { ProviderExtensionConfig } from './ProviderExtension'
import { globalProviderStorage, ProviderExtension } from './ProviderExtension'
import { ProviderExtension } from './ProviderExtension'
// ==================== Core Extensions ====================
@ -268,14 +268,6 @@ class ProviderInitializationError extends Error {
}
}
// ==================== 全局 Provider Storage 导出 ====================
/**
* Provider Storage
* Extension provider
*/
export { globalProviderStorage }
// ==================== 工具函数 ====================
/**
@ -292,57 +284,6 @@ export function getSupportedProviders(): Array<{
}))
}
/**
* providers (explicit IDs)
*/
export function getInitializedProviders(): string[] {
return Array.from(globalProviderStorage.keys())
}
/**
* providers
*/
export function hasInitializedProviders(): boolean {
return globalProviderStorage.size > 0
}
/**
* provider ID
* Extension Registry (template) Global Provider Storage (initialized instance)
*
* @param id - Provider ID to check (extension name or explicit ID)
* @returns true if the provider is registered (either as extension or initialized instance)
*
* @example
* ```typescript
* if (isRegisteredProvider('openai')) {
* // Provider extension exists
* }
* if (isRegisteredProvider('my-openai-instance')) {
* // Initialized provider instance exists
* }
* ```
*/
export function isRegisteredProvider(id: string): boolean {
return extensionRegistry.has(id) || globalProviderStorage.has(id)
}
/**
* Provider - 使 Extension Registry
*
* @param providerId - Provider ID (extension name)
* @param options - Provider settings
* @param explicitId - ID globalProviderStorageExtension 使 `name:hash` ID
* @returns Provider
*/
export async function createProvider(providerId: string, options: any, explicitId?: string): Promise<any> {
if (!extensionRegistry.has(providerId)) {
throw new Error(`Provider "${providerId}" not found in Extension Registry`)
}
return await extensionRegistry.createProvider(providerId, options, explicitId)
}
/**
* Provider Extension
*/
@ -350,13 +291,6 @@ export function hasProviderConfig(providerId: string): boolean {
return extensionRegistry.has(providerId)
}
/**
* provider
*/
export function clearAllProviders(): void {
globalProviderStorage.clear()
}
// ==================== 导出错误类型 ====================
export { ProviderInitializationError }

View File

@ -1,8 +1,8 @@
/**
* Hub Provider - provider
*
* 支持格式: hubId:providerId:modelId
* 例如: aihubmix:anthropic:claude-3.5-sonnet
* 支持格式: hubId|providerId|modelId
* @example aihubmix|anthropic|claude-3.5-sonnet
*/
import type {
@ -14,10 +14,10 @@ import type {
SpeechModelV3,
TranscriptionModelV3
} from '@ai-sdk/provider'
import { customProvider, wrapProvider } from 'ai'
import { customProvider } from 'ai'
import { globalProviderStorage } from '../core/ProviderExtension'
import type { AiSdkProvider } from '../types'
import type { ExtensionRegistry } from '../core/ExtensionRegistry'
import type { CoreProviderSettingsMap } from '../types'
/** Model ID 分隔符 */
export const DEFAULT_SEPARATOR = '|'
@ -27,6 +27,10 @@ export interface HubProviderConfig {
hubId?: string
/** 是否启用调试日志 */
debug?: boolean
/** ExtensionRegistry实例用于获取provider extensions */
registry: ExtensionRegistry
/** Provider配置映射 */
providerSettingsMap: Map<string, CoreProviderSettingsMap[keyof CoreProviderSettingsMap]>
}
export class HubProviderError extends Error {
@ -46,8 +50,11 @@ export class HubProviderError extends Error {
*/
function parseHubModelId(modelId: string): { provider: string; actualModelId: string } {
const parts = modelId.split(DEFAULT_SEPARATOR)
if (parts.length !== 2) {
throw new HubProviderError(`Invalid hub model ID format. Expected "provider:modelId", got: ${modelId}`, 'unknown')
if (parts.length !== 2 || !parts[0] || !parts[1]) {
throw new HubProviderError(
`Invalid hub model ID format. Expected "provider${DEFAULT_SEPARATOR}modelId", got: ${modelId}`,
'unknown'
)
}
return {
provider: parts[0],
@ -56,37 +63,72 @@ function parseHubModelId(modelId: string): { provider: string; actualModelId: st
}
/**
* Hub Provider
* Hub Provider
*
* provider实例以满足AI SDK的同步要求
* ExtensionRegistry复用ProviderExtension的LRU缓存
*/
export function createHubProvider(config?: HubProviderConfig): AiSdkProvider {
const hubId = config?.hubId ?? 'hub'
export async function createHubProviderAsync(config: HubProviderConfig): Promise<ProviderV3> {
const { registry, providerSettingsMap, debug, hubId = 'hub' } = config
// 预创建所有 provider 实例
const providers = new Map<string, ProviderV3>()
for (const [providerId, settings] of providerSettingsMap.entries()) {
const extension = registry.get(providerId)
if (!extension) {
const availableExtensions = registry
.getAll()
.map((ext) => ext.config.name)
.join(', ')
throw new HubProviderError(
`Provider extension "${providerId}" not found in registry. Available: ${availableExtensions}`,
hubId,
providerId
)
}
function getTargetProvider(providerId: string): ProviderV3 {
// 从全局 provider storage 获取已注册的provider实例
try {
const provider = globalProviderStorage.get(providerId)
if (!provider) {
throw new HubProviderError(
`Provider "${providerId}" is not registered. Please call extension.createProvider(settings, "${providerId}") first.`,
hubId,
providerId
)
}
// 使用 wrapProvider 确保返回的是 V3 provider
// 这样可以自动处理 V2 provider 到 V3 的转换
return wrapProvider({ provider, languageModelMiddleware: [] })
// 通过 extension 创建 provider复用 LRU 缓存)
const provider = await extension.createProvider(settings)
providers.set(providerId, provider)
} catch (error) {
throw new HubProviderError(
`Failed to get provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
`Failed to create provider "${providerId}": ${error instanceof Error ? error.message : String(error)}`,
hubId,
providerId,
error instanceof Error ? error : undefined
)
}
}
return createHubProviderWithProviders(hubId, providers, debug)
}
// 创建符合 ProviderV3 规范的 fallback provider
const hubFallbackProvider = {
/**
* 使providers创建HubProvider
*/
function createHubProviderWithProviders(
hubId: string,
providers: Map<string, ProviderV3>,
debug?: boolean
): ProviderV3 {
function getTargetProvider(providerId: string): ProviderV3 {
const provider = providers.get(providerId)
if (!provider) {
const availableProviders = Array.from(providers.keys()).join(', ')
throw new HubProviderError(
`Provider "${providerId}" not initialized. Available: ${availableProviders}`,
hubId,
providerId
)
}
if (debug) {
console.log(`[HubProvider:${hubId}] Routing to provider: ${providerId}`)
}
return provider
}
const hubFallbackProvider: ProviderV3 = {
specificationVersion: 'v3' as const,
languageModel: (modelId: string): LanguageModelV3 => {
@ -128,6 +170,7 @@ export function createHubProvider(config?: HubProviderConfig): AiSdkProvider {
return targetProvider.speechModel(actualModelId)
},
rerankingModel: (modelId: string): RerankingModelV3 => {
const { provider, actualModelId } = parseHubModelId(modelId)
const targetProvider = getTargetProvider(provider)

View File

@ -3,18 +3,12 @@
*/
// ==================== 核心管理器 ====================
export { globalProviderStorage } from './core/ProviderExtension'
// Provider 核心功能
export {
clearAllProviders,
coreExtensions,
createProvider,
getInitializedProviders,
getSupportedProviders,
hasInitializedProviders,
hasProviderConfig,
isRegisteredProvider,
ProviderInitializationError,
registeredProviderIds
} from './core/initialization'
@ -24,7 +18,7 @@ export {
// 类型定义
export type { AiSdkModel, ProviderError } from './types'
// 类型提取工具(用于应用层 Merge Point 模式)
// 类型提取工具
export type {
CoreProviderSettingsMap,
ExtensionConfigToIdResolutionMap,
@ -43,7 +37,11 @@ export { formatPrivateKey, ProviderCreationError } from './core/utils'
// ==================== 扩展功能 ====================
// Hub Provider 功能
export { createHubProvider, type HubProviderConfig, HubProviderError } from './features/HubProvider'
export {
createHubProviderAsync,
type HubProviderConfig,
HubProviderError
} from './features/HubProvider'
// ==================== Provider Extension 系统 ====================

View File

@ -1,650 +0,0 @@
/**
* RuntimeExecutor.resolveModel Comprehensive Tests
* Tests the private resolveModel and resolveImageModel methods through public APIs
* Covers model resolution, middleware application, and type validation
*/
import type { ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider'
import { generateImage, generateText, streamText } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import {
createMockImageModel,
createMockLanguageModel,
createMockMiddleware,
mockProviderConfigs
} from '../../../__tests__'
import { globalModelResolver } from '../../models'
import { ImageModelResolutionError } from '../errors'
import { RuntimeExecutor } from '../executor'
// Mock AI SDK
vi.mock('ai', async (importOriginal) => {
const actual = (await importOriginal()) as Record<string, unknown>
return {
...actual,
generateText: vi.fn(),
streamText: vi.fn(),
generateImage: vi.fn(),
wrapLanguageModel: vi.fn((config: any) => ({
...config.model,
_middlewareApplied: true,
middleware: config.middleware
}))
}
})
vi.mock('../../providers/core/ProviderInstanceRegistry', () => ({
globalRegistryManagement: {
languageModel: vi.fn(),
imageModel: vi.fn()
},
DEFAULT_SEPARATOR: '|'
}))
vi.mock('../../models', () => ({
globalModelResolver: {
resolveLanguageModel: vi.fn(),
resolveImageModel: vi.fn()
}
}))
describe('RuntimeExecutor - Model Resolution', () => {
let executor: RuntimeExecutor<'openai'>
let mockLanguageModel: LanguageModelV3
let mockImageModel: ImageModelV3
beforeEach(() => {
vi.clearAllMocks()
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
mockLanguageModel = createMockLanguageModel({
specificationVersion: 'v3',
provider: 'openai',
modelId: 'gpt-4'
})
mockImageModel = createMockImageModel({
specificationVersion: 'v3',
provider: 'openai',
modelId: 'dall-e-3'
})
vi.mocked(globalModelResolver.resolveLanguageModel).mockResolvedValue(mockLanguageModel)
vi.mocked(globalModelResolver.resolveImageModel).mockResolvedValue(mockImageModel)
vi.mocked(generateText).mockResolvedValue({
text: 'Test response',
finishReason: 'stop',
usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }
} as any)
vi.mocked(streamText).mockResolvedValue({
textStream: (async function* () {
yield 'test'
})()
} as any)
vi.mocked(generateImage).mockResolvedValue({
image: {
base64: 'test-image',
uint8Array: new Uint8Array([1, 2, 3]),
mimeType: 'image/png'
},
warnings: []
} as any)
})
describe('Language Model Resolution (String modelId)', () => {
it('should resolve string modelId using globalModelResolver', async () => {
await executor.generateText({
model: 'gpt-4',
messages: [{ role: 'user', content: 'Hello' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
'gpt-4',
'openai',
mockProviderConfigs.openai,
undefined
)
})
it('should pass provider settings to model resolver', async () => {
const customExecutor = RuntimeExecutor.create('anthropic', {
apiKey: 'sk-test',
baseURL: 'https://api.anthropic.com'
})
vi.mocked(globalModelResolver.resolveLanguageModel).mockResolvedValue(mockLanguageModel)
await customExecutor.generateText({
model: 'claude-3-5-sonnet',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
'claude-3-5-sonnet',
'anthropic',
{
apiKey: 'sk-test',
baseURL: 'https://api.anthropic.com'
},
undefined
)
})
it('should resolve traditional format modelId', async () => {
await executor.generateText({
model: 'gpt-4-turbo',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
'gpt-4-turbo',
'openai',
expect.any(Object),
undefined
)
})
it('should resolve namespaced format modelId', async () => {
await executor.generateText({
model: 'aihubmix|anthropic|claude-3',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
'aihubmix|anthropic|claude-3',
'openai',
expect.any(Object),
undefined
)
})
it('should use resolved model for generation', async () => {
await executor.generateText({
model: 'gpt-4',
messages: [{ role: 'user', content: 'Hello' }]
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
model: mockLanguageModel
})
)
})
it('should work with streamText', async () => {
await executor.streamText({
model: 'gpt-4',
messages: [{ role: 'user', content: 'Stream test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalled()
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
model: mockLanguageModel
})
)
})
})
describe('Language Model Resolution (Direct Model Object)', () => {
it('should accept pre-resolved V3 model object', async () => {
const directModel: LanguageModelV3 = createMockLanguageModel({
specificationVersion: 'v3',
provider: 'openai',
modelId: 'gpt-4'
})
await executor.generateText({
model: directModel,
messages: [{ role: 'user', content: 'Test' }]
})
// Should NOT call resolver for direct model
expect(globalModelResolver.resolveLanguageModel).not.toHaveBeenCalled()
// Should use the model directly
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
model: directModel
})
)
})
it('should accept V2 model object without validation (plugin engine handles it)', async () => {
const v2Model = {
specificationVersion: 'v2',
provider: 'openai',
modelId: 'gpt-4',
doGenerate: vi.fn()
} as any
// The plugin engine accepts any model object directly without validation
// V3 validation only happens when resolving string modelIds
await expect(
executor.generateText({
model: v2Model,
messages: [{ role: 'user', content: 'Test' }]
})
).resolves.toBeDefined()
})
it('should accept any model object without checking specification version', async () => {
const v2Model = {
specificationVersion: 'v2',
provider: 'custom-provider',
modelId: 'custom-model',
doGenerate: vi.fn()
} as any
// Direct model objects bypass validation
// The executor trusts that plugins/users provide valid models
await expect(
executor.generateText({
model: v2Model,
messages: [{ role: 'user', content: 'Test' }]
})
).resolves.toBeDefined()
})
it('should accept model object with streamText', async () => {
const directModel = createMockLanguageModel({
specificationVersion: 'v3'
})
await executor.streamText({
model: directModel,
messages: [{ role: 'user', content: 'Stream' }]
})
expect(globalModelResolver.resolveLanguageModel).not.toHaveBeenCalled()
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
model: directModel
})
)
})
})
describe('Middleware Application', () => {
it('should apply middlewares to string modelId', async () => {
const testMiddleware = createMockMiddleware()
await executor.generateText(
{
model: 'gpt-4',
messages: [{ role: 'user', content: 'Test' }]
},
{ middlewares: [testMiddleware] }
)
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [
testMiddleware
])
})
it('should apply multiple middlewares in order', async () => {
const middleware1 = createMockMiddleware()
const middleware2 = createMockMiddleware()
await executor.generateText(
{
model: 'gpt-4',
messages: [{ role: 'user', content: 'Test' }]
},
{ middlewares: [middleware1, middleware2] }
)
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [
middleware1,
middleware2
])
})
it('should pass middlewares to model resolver for string modelIds', async () => {
const testMiddleware = createMockMiddleware()
await executor.generateText(
{
model: 'gpt-4', // String model ID
messages: [{ role: 'user', content: 'Test' }]
},
{ middlewares: [testMiddleware] }
)
// Middlewares are passed to the resolver for string modelIds
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [
testMiddleware
])
})
it('should not apply middlewares when none provided', async () => {
await executor.generateText({
model: 'gpt-4',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
'gpt-4',
'openai',
expect.any(Object),
undefined
)
})
it('should handle empty middleware array', async () => {
await executor.generateText(
{
model: 'gpt-4',
messages: [{ role: 'user', content: 'Test' }]
},
{ middlewares: [] }
)
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [])
})
it('should work with middlewares in streamText', async () => {
const middleware = createMockMiddleware()
await executor.streamText(
{
model: 'gpt-4',
messages: [{ role: 'user', content: 'Stream' }]
},
{ middlewares: [middleware] }
)
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object), [
middleware
])
})
})
describe('Image Model Resolution', () => {
it('should resolve string image modelId using globalModelResolver', async () => {
await executor.generateImage({
model: 'dall-e-3',
prompt: 'A beautiful sunset'
})
expect(globalModelResolver.resolveImageModel).toHaveBeenCalledWith('dall-e-3', 'openai')
})
it('should accept direct ImageModelV3 object', async () => {
const directImageModel: ImageModelV3 = createMockImageModel({
specificationVersion: 'v3',
provider: 'openai',
modelId: 'dall-e-3'
})
await executor.generateImage({
model: directImageModel,
prompt: 'Test image'
})
expect(globalModelResolver.resolveImageModel).not.toHaveBeenCalled()
expect(generateImage).toHaveBeenCalledWith(
expect.objectContaining({
model: directImageModel
})
)
})
it('should resolve namespaced image model ID', async () => {
await executor.generateImage({
model: 'aihubmix|openai|dall-e-3',
prompt: 'Namespaced image'
})
expect(globalModelResolver.resolveImageModel).toHaveBeenCalledWith('aihubmix|openai|dall-e-3', 'openai')
})
it('should throw ImageModelResolutionError on resolution failure', async () => {
const resolutionError = new Error('Model not found')
vi.mocked(globalModelResolver.resolveImageModel).mockRejectedValue(resolutionError)
await expect(
executor.generateImage({
model: 'invalid-model',
prompt: 'Test'
})
).rejects.toThrow(ImageModelResolutionError)
})
it('should include modelId and providerId in ImageModelResolutionError', async () => {
vi.mocked(globalModelResolver.resolveImageModel).mockRejectedValue(new Error('Not found'))
try {
await executor.generateImage({
model: 'invalid-model',
prompt: 'Test'
})
expect.fail('Should have thrown ImageModelResolutionError')
} catch (error) {
expect(error).toBeInstanceOf(ImageModelResolutionError)
const imgError = error as ImageModelResolutionError
expect(imgError.message).toContain('invalid-model')
expect(imgError.providerId).toBe('openai')
}
})
it('should extract modelId from direct model object in error', async () => {
const directModel = createMockImageModel({
modelId: 'direct-model',
doGenerate: vi.fn().mockRejectedValue(new Error('Generation failed'))
})
vi.mocked(generateImage).mockRejectedValue(new Error('Generation failed'))
await expect(
executor.generateImage({
model: directModel,
prompt: 'Test'
})
).rejects.toThrow()
})
})
describe('Provider-Specific Model Resolution', () => {
it('should resolve models for OpenAI provider', async () => {
const openaiExecutor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
await openaiExecutor.generateText({
model: 'gpt-4',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
'gpt-4',
'openai',
expect.any(Object),
undefined
)
})
it('should resolve models for Anthropic provider', async () => {
const anthropicExecutor = RuntimeExecutor.create('anthropic', mockProviderConfigs.anthropic)
await anthropicExecutor.generateText({
model: 'claude-3-5-sonnet',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
'claude-3-5-sonnet',
'anthropic',
expect.any(Object),
undefined
)
})
it('should resolve models for Google provider', async () => {
const googleExecutor = RuntimeExecutor.create('google', mockProviderConfigs.google)
await googleExecutor.generateText({
model: 'gemini-2.0-flash',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
'gemini-2.0-flash',
'google',
expect.any(Object),
undefined
)
})
it('should resolve models for OpenAI-compatible provider', async () => {
const compatibleExecutor = RuntimeExecutor.createOpenAICompatible(mockProviderConfigs['openai-compatible'])
await compatibleExecutor.generateText({
model: 'custom-model',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
'custom-model',
'openai-compatible',
expect.any(Object),
undefined
)
})
})
describe('OpenAI Mode Handling', () => {
it('should pass mode setting to model resolver', async () => {
const executorWithMode = RuntimeExecutor.create('openai', {
...mockProviderConfigs.openai,
mode: 'chat'
})
await executorWithMode.generateText({
model: 'gpt-4',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
'gpt-4',
'openai',
expect.objectContaining({
mode: 'chat'
}),
undefined
)
})
it('should handle responses mode', async () => {
const executorWithMode = RuntimeExecutor.create('openai', {
...mockProviderConfigs.openai,
mode: 'responses'
})
await executorWithMode.generateText({
model: 'gpt-4',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
'gpt-4',
'openai',
expect.objectContaining({
mode: 'responses'
}),
undefined
)
})
})
describe('Edge Cases', () => {
it('should handle empty string modelId', async () => {
await executor.generateText({
model: '',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('', 'openai', expect.any(Object), undefined)
})
it('should handle model resolution errors gracefully', async () => {
vi.mocked(globalModelResolver.resolveLanguageModel).mockRejectedValue(new Error('Model not found'))
await expect(
executor.generateText({
model: 'nonexistent-model',
messages: [{ role: 'user', content: 'Test' }]
})
).rejects.toThrow('Model not found')
})
it('should handle concurrent model resolutions', async () => {
const promises = [
executor.generateText({ model: 'gpt-4', messages: [{ role: 'user', content: 'Test 1' }] }),
executor.generateText({ model: 'gpt-4-turbo', messages: [{ role: 'user', content: 'Test 2' }] }),
executor.generateText({ model: 'gpt-3.5-turbo', messages: [{ role: 'user', content: 'Test 3' }] })
]
await Promise.all(promises)
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledTimes(3)
})
it('should accept model object even without specificationVersion', async () => {
const invalidModel = {
provider: 'test',
modelId: 'test-model'
// Missing specificationVersion
} as any
// Plugin engine doesn't validate direct model objects
// It's the user's responsibility to provide valid models
await expect(
executor.generateText({
model: invalidModel,
messages: [{ role: 'user', content: 'Test' }]
})
).resolves.toBeDefined()
})
})
describe('Type Safety Validation', () => {
it('should ensure resolved model is LanguageModelV3', async () => {
const v3Model = createMockLanguageModel({
specificationVersion: 'v3'
})
vi.mocked(globalModelResolver.resolveLanguageModel).mockResolvedValue(v3Model)
await executor.generateText({
model: 'gpt-4',
messages: [{ role: 'user', content: 'Test' }]
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
model: expect.objectContaining({
specificationVersion: 'v3'
})
})
)
})
it('should not enforce specification version for direct models', async () => {
const v1Model = {
specificationVersion: 'v1',
provider: 'test',
modelId: 'test'
} as any
// Direct models bypass validation in the plugin engine
// Only resolved models (from string IDs) are validated
await expect(
executor.generateText({
model: v1Model,
messages: [{ role: 'user', content: 'Test' }]
})
).resolves.toBeDefined()
})
})
})

View File

@ -1,9 +1,9 @@
import type { ImageModelV3 } from '@ai-sdk/provider'
import { createMockImageModel, createMockProviderV3 } from '@test-utils'
import { generateImage as aiGenerateImage, NoImageGeneratedError } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { type AiPlugin } from '../../plugins'
import { globalProviderInstanceRegistry } from '../../providers/core/ProviderInstanceRegistry'
import { ImageGenerationError, ImageModelResolutionError } from '../errors'
import { RuntimeExecutor } from '../executor'
@ -21,32 +21,32 @@ vi.mock('ai', () => ({
}
}))
vi.mock('../../providers/core/ProviderInstanceRegistry', () => ({
globalProviderInstanceRegistry: {
imageModel: vi.fn()
},
DEFAULT_SEPARATOR: '|'
}))
describe('RuntimeExecutor.generateImage', () => {
let executor: RuntimeExecutor<'openai'>
let mockImageModel: ImageModelV3
let mockProvider: any
let mockGenerateImageResult: any
beforeEach(() => {
// Reset all mocks
vi.clearAllMocks()
// Create executor instance
executor = RuntimeExecutor.create('openai', {
apiKey: 'test-key'
})
// Mock image model
mockImageModel = {
mockImageModel = createMockImageModel({
modelId: 'dall-e-3',
provider: 'openai'
} as ImageModelV3
})
// Create mock provider with imageModel as a spy
mockProvider = createMockProviderV3({
provider: 'openai',
imageModel: vi.fn(() => mockImageModel)
})
// Create executor instance
executor = RuntimeExecutor.create('openai', mockProvider, {
apiKey: 'test-key'
})
// Mock generateImage result
mockGenerateImageResult = {
@ -71,8 +71,6 @@ describe('RuntimeExecutor.generateImage', () => {
responses: []
}
// Setup mocks to avoid "No providers registered" error
vi.mocked(globalProviderInstanceRegistry.imageModel).mockReturnValue(mockImageModel)
vi.mocked(aiGenerateImage).mockResolvedValue(mockGenerateImageResult)
})
@ -80,7 +78,7 @@ describe('RuntimeExecutor.generateImage', () => {
it('should generate a single image with minimal parameters', async () => {
const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A futuristic cityscape at sunset' })
expect(globalProviderInstanceRegistry.imageModel).toHaveBeenCalledWith('openai|dall-e-3')
expect(mockProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
@ -96,7 +94,8 @@ describe('RuntimeExecutor.generateImage', () => {
prompt: 'A beautiful landscape'
})
// Note: globalProviderInstanceRegistry.imageModel may still be called due to resolveImageModel logic
// Pre-created model is used directly, provider.imageModel is not called
expect(mockProvider.imageModel).not.toHaveBeenCalled()
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A beautiful landscape'
@ -224,6 +223,7 @@ describe('RuntimeExecutor.generateImage', () => {
const executorWithPlugin = RuntimeExecutor.create(
'openai',
mockProvider,
{
apiKey: 'test-key'
},
@ -269,6 +269,7 @@ describe('RuntimeExecutor.generateImage', () => {
const executorWithPlugin = RuntimeExecutor.create(
'openai',
mockProvider,
{
apiKey: 'test-key'
},
@ -309,6 +310,7 @@ describe('RuntimeExecutor.generateImage', () => {
const executorWithPlugin = RuntimeExecutor.create(
'openai',
mockProvider,
{
apiKey: 'test-key'
},
@ -325,7 +327,8 @@ describe('RuntimeExecutor.generateImage', () => {
describe('Error handling', () => {
it('should handle model creation errors', async () => {
const modelError = new Error('Failed to get image model')
vi.mocked(globalProviderInstanceRegistry.imageModel).mockImplementation(() => {
// Since mockProvider.imageModel is already a vi.fn() spy, we can mock it directly
mockProvider.imageModel.mockImplementation(() => {
throw modelError
})
@ -336,7 +339,7 @@ describe('RuntimeExecutor.generateImage', () => {
it('should handle ImageModelResolutionError correctly', async () => {
const resolutionError = new ImageModelResolutionError('invalid-model', 'openai', new Error('Model not found'))
vi.mocked(globalProviderInstanceRegistry.imageModel).mockImplementation(() => {
mockProvider.imageModel.mockImplementation(() => {
throw resolutionError
})
@ -353,7 +356,7 @@ describe('RuntimeExecutor.generateImage', () => {
it('should handle ImageModelResolutionError without provider', async () => {
const resolutionError = new ImageModelResolutionError('unknown-model')
vi.mocked(globalProviderInstanceRegistry.imageModel).mockImplementation(() => {
mockProvider.imageModel.mockImplementation(() => {
throw resolutionError
})
@ -398,6 +401,7 @@ describe('RuntimeExecutor.generateImage', () => {
const executorWithPlugin = RuntimeExecutor.create(
'openai',
mockProvider,
{
apiKey: 'test-key'
},
@ -436,23 +440,43 @@ describe('RuntimeExecutor.generateImage', () => {
describe('Multiple providers support', () => {
it('should work with different providers', async () => {
const googleExecutor = RuntimeExecutor.create('google', {
const googleImageModel = createMockImageModel({
provider: 'google',
modelId: 'imagen-3.0-generate-002'
})
const googleProvider = createMockProviderV3({
provider: 'google',
imageModel: vi.fn(() => googleImageModel)
})
const googleExecutor = RuntimeExecutor.create('google', googleProvider, {
apiKey: 'google-key'
})
await googleExecutor.generateImage({ model: 'imagen-3.0-generate-002', prompt: 'A landscape' })
expect(globalProviderInstanceRegistry.imageModel).toHaveBeenCalledWith('google|imagen-3.0-generate-002')
expect(googleProvider.imageModel).toHaveBeenCalledWith('imagen-3.0-generate-002')
})
it('should support xAI Grok image models', async () => {
const xaiExecutor = RuntimeExecutor.create('xai', {
const xaiImageModel = createMockImageModel({
provider: 'xai',
modelId: 'grok-2-image'
})
const xaiProvider = createMockProviderV3({
provider: 'xai',
imageModel: vi.fn(() => xaiImageModel)
})
const xaiExecutor = RuntimeExecutor.create('xai', xaiProvider, {
apiKey: 'xai-key'
})
await xaiExecutor.generateImage({ model: 'grok-2-image', prompt: 'A futuristic robot' })
expect(globalProviderInstanceRegistry.imageModel).toHaveBeenCalledWith('xai|grok-2-image')
expect(xaiProvider.imageModel).toHaveBeenCalledWith('grok-2-image')
})
})

View File

@ -3,18 +3,18 @@
* Tests non-streaming text generation across all providers with various parameters
*/
import { generateText } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import {
createMockLanguageModel,
createMockProviderV3,
mockCompleteResponses,
mockProviderConfigs,
testMessages,
testTools
} from '../../../__tests__'
} from '@test-utils'
import { generateText } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import type { AiPlugin } from '../../plugins'
import { globalProviderInstanceRegistry } from '../../providers/core/ProviderInstanceRegistry'
import { RuntimeExecutor } from '../executor'
// Mock AI SDK - use importOriginal to keep jsonSchema and other non-mocked exports
@ -26,28 +26,28 @@ vi.mock('ai', async (importOriginal) => {
}
})
vi.mock('../../providers/core/ProviderInstanceRegistry', () => ({
globalProviderInstanceRegistry: {
languageModel: vi.fn()
},
DEFAULT_SEPARATOR: '|'
}))
describe('RuntimeExecutor.generateText', () => {
let executor: RuntimeExecutor<'openai'>
let mockLanguageModel: any
let mockProvider: any
beforeEach(() => {
vi.clearAllMocks()
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
mockLanguageModel = createMockLanguageModel({
provider: 'openai',
modelId: 'gpt-4'
})
vi.mocked(globalProviderInstanceRegistry.languageModel).mockReturnValue(mockLanguageModel)
// ✅ Create mock provider with languageModel as a spy
mockProvider = createMockProviderV3({
provider: 'openai',
languageModel: vi.fn(() => mockLanguageModel)
})
// ✅ Pass provider instance to RuntimeExecutor.create()
executor = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai)
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.simple as any)
})
@ -231,75 +231,87 @@ describe('RuntimeExecutor.generateText', () => {
describe('Multiple Providers', () => {
it('should work with Anthropic provider', async () => {
const anthropicExecutor = RuntimeExecutor.create('anthropic', mockProviderConfigs.anthropic)
const anthropicModel = createMockLanguageModel({
provider: 'anthropic',
modelId: 'claude-3-5-sonnet-20241022'
})
vi.mocked(globalProviderInstanceRegistry.languageModel).mockReturnValue(anthropicModel)
const anthropicProvider = createMockProviderV3({
provider: 'anthropic',
languageModel: vi.fn(() => anthropicModel)
})
const anthropicExecutor = RuntimeExecutor.create('anthropic', anthropicProvider, mockProviderConfigs.anthropic)
await anthropicExecutor.generateText({
model: 'claude-3-5-sonnet-20241022',
messages: testMessages.simple
})
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('anthropic|claude-3-5-sonnet-20241022')
expect(anthropicProvider.languageModel).toHaveBeenCalledWith('claude-3-5-sonnet-20241022')
})
it('should work with Google provider', async () => {
const googleExecutor = RuntimeExecutor.create('google', mockProviderConfigs.google)
const googleModel = createMockLanguageModel({
provider: 'google',
modelId: 'gemini-2.0-flash-exp'
})
vi.mocked(globalProviderInstanceRegistry.languageModel).mockReturnValue(googleModel)
const googleProvider = createMockProviderV3({
provider: 'google',
languageModel: vi.fn(() => googleModel)
})
const googleExecutor = RuntimeExecutor.create('google', googleProvider, mockProviderConfigs.google)
await googleExecutor.generateText({
model: 'gemini-2.0-flash-exp',
messages: testMessages.simple
})
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('google|gemini-2.0-flash-exp')
expect(googleProvider.languageModel).toHaveBeenCalledWith('gemini-2.0-flash-exp')
})
it('should work with xAI provider', async () => {
const xaiExecutor = RuntimeExecutor.create('xai', mockProviderConfigs.xai)
const xaiModel = createMockLanguageModel({
provider: 'xai',
modelId: 'grok-2-latest'
})
vi.mocked(globalProviderInstanceRegistry.languageModel).mockReturnValue(xaiModel)
const xaiProvider = createMockProviderV3({
provider: 'xai',
languageModel: vi.fn(() => xaiModel)
})
const xaiExecutor = RuntimeExecutor.create('xai', xaiProvider, mockProviderConfigs.xai)
await xaiExecutor.generateText({
model: 'grok-2-latest',
messages: testMessages.simple
})
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('xai|grok-2-latest')
expect(xaiProvider.languageModel).toHaveBeenCalledWith('grok-2-latest')
})
it('should work with DeepSeek provider', async () => {
const deepseekExecutor = RuntimeExecutor.create('deepseek', mockProviderConfigs.deepseek)
const deepseekModel = createMockLanguageModel({
provider: 'deepseek',
modelId: 'deepseek-chat'
})
vi.mocked(globalProviderInstanceRegistry.languageModel).mockReturnValue(deepseekModel)
const deepseekProvider = createMockProviderV3({
provider: 'deepseek',
languageModel: vi.fn(() => deepseekModel)
})
const deepseekExecutor = RuntimeExecutor.create('deepseek', deepseekProvider, mockProviderConfigs.deepseek)
await deepseekExecutor.generateText({
model: 'deepseek-chat',
messages: testMessages.simple
})
expect(globalProviderInstanceRegistry.languageModel).toHaveBeenCalledWith('deepseek|deepseek-chat')
expect(deepseekProvider.languageModel).toHaveBeenCalledWith('deepseek-chat')
})
})
@ -325,7 +337,9 @@ describe('RuntimeExecutor.generateText', () => {
})
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin])
const executorWithPlugin = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai, [
testPlugin
])
const result = await executorWithPlugin.generateText({
model: 'gpt-4',
@ -364,7 +378,10 @@ describe('RuntimeExecutor.generateText', () => {
})
}
const executorWithPlugins = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [plugin1, plugin2])
const executorWithPlugins = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai, [
plugin1,
plugin2
])
await executorWithPlugins.generateText({
model: 'gpt-4',
@ -404,7 +421,9 @@ describe('RuntimeExecutor.generateText', () => {
onError: vi.fn()
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin])
const executorWithPlugin = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai, [
errorPlugin
])
await expect(
executorWithPlugin.generateText({
@ -425,7 +444,7 @@ describe('RuntimeExecutor.generateText', () => {
it('should handle model not found error', async () => {
const error = new Error('Model not found: invalid-model')
vi.mocked(globalProviderInstanceRegistry.languageModel).mockImplementation(() => {
mockProvider.languageModel.mockImplementationOnce(() => {
throw error
})

View File

@ -5,9 +5,9 @@
*/
import type { ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider'
import { createMockImageModel, createMockLanguageModel } from '@test-utils'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createMockImageModel, createMockLanguageModel } from '../../../__tests__'
import { ModelResolutionError, RecursiveDepthError } from '../../errors'
import type { AiPlugin, GenerateTextParams, GenerateTextResult } from '../../plugins'
import { PluginEngine } from '../pluginEngine'

View File

@ -3,12 +3,17 @@
* Tests streaming text generation across all providers with various parameters
*/
import {
collectStreamChunks,
createMockLanguageModel,
createMockProviderV3,
mockProviderConfigs,
testMessages
} from '@test-utils'
import { streamText } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { collectStreamChunks, createMockLanguageModel, mockProviderConfigs, testMessages } from '../../../__tests__'
import type { AiPlugin } from '../../plugins'
import { globalProviderInstanceRegistry } from '../../providers/core/ProviderInstanceRegistry'
import { RuntimeExecutor } from '../executor'
// Mock AI SDK - use importOriginal to keep jsonSchema and other non-mocked exports
@ -20,28 +25,25 @@ vi.mock('ai', async (importOriginal) => {
}
})
vi.mock('../../providers/core/ProviderInstanceRegistry', () => ({
globalProviderInstanceRegistry: {
languageModel: vi.fn()
},
DEFAULT_SEPARATOR: '|'
}))
describe('RuntimeExecutor.streamText', () => {
let executor: RuntimeExecutor<'openai'>
let mockLanguageModel: any
let mockProvider: any
beforeEach(() => {
vi.clearAllMocks()
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
mockLanguageModel = createMockLanguageModel({
provider: 'openai',
modelId: 'gpt-4'
})
vi.mocked(globalProviderInstanceRegistry.languageModel).mockReturnValue(mockLanguageModel)
mockProvider = createMockProviderV3({
provider: 'openai',
languageModel: () => mockLanguageModel
})
executor = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai)
})
describe('Basic Functionality', () => {
@ -416,7 +418,9 @@ describe('RuntimeExecutor.streamText', () => {
})
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin])
const executorWithPlugin = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai, [
testPlugin
])
const mockStream = {
textStream: (async function* () {
@ -509,7 +513,9 @@ describe('RuntimeExecutor.streamText', () => {
onError: vi.fn()
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin])
const executorWithPlugin = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai, [
errorPlugin
])
await expect(
executorWithPlugin.streamText({

View File

@ -2,7 +2,7 @@
*
* AI调用处理
*/
import type { ImageModelV3, LanguageModelV3, LanguageModelV3Middleware } from '@ai-sdk/provider'
import type { ImageModelV3, LanguageModelV3, LanguageModelV3Middleware, ProviderV3 } from '@ai-sdk/provider'
import type { LanguageModel } from 'ai'
import {
generateImage as _generateImage,
@ -11,7 +11,7 @@ import {
wrapLanguageModel
} from 'ai'
import { globalModelResolver } from '../models'
import { ModelResolver } from '../models'
import { type ModelConfig } from '../models/types'
import { isV3Model } from '../models/utils'
import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
@ -26,11 +26,13 @@ export class RuntimeExecutor<
> {
public pluginEngine: PluginEngine<T>
private config: RuntimeConfig<T, TSettingsMap>
private modelResolver: ModelResolver
constructor(config: RuntimeConfig<T, TSettingsMap>) {
this.config = config
// 创建插件客户端
this.pluginEngine = new PluginEngine(config.providerId, config.plugins || [])
this.modelResolver = new ModelResolver(config.provider)
}
private createResolveModelPlugin(middlewares?: LanguageModelV3Middleware[]) {
@ -175,13 +177,9 @@ export class RuntimeExecutor<
middlewares?: LanguageModelV3Middleware[]
): Promise<LanguageModelV3> {
if (typeof modelOrId === 'string') {
// 🎯 字符串modelId使用新的ModelResolver解析传递完整参数
return await globalModelResolver.resolveLanguageModel(
modelOrId, // 支持 'gpt-4' 和 'aihubmix:anthropic:claude-3.5-sonnet'
this.config.providerId, // fallback provider
this.config.providerSettings, // provider options
middlewares // 中间件数组
)
// 字符串modelId使用 ModelResolver 解析
// Provider会处理命名空间格式路由如果是HubProvider
return await this.modelResolver.resolveLanguageModel(modelOrId, middlewares)
} else {
// 已经是模型对象
// 所有 provider 都应该返回 V3 模型(通过 wrapProvider 确保)
@ -206,11 +204,9 @@ export class RuntimeExecutor<
private async resolveImageModel(modelOrId: ImageModelV3 | string): Promise<ImageModelV3> {
try {
if (typeof modelOrId === 'string') {
// 字符串modelId使用新的ModelResolver解析
return await globalModelResolver.resolveImageModel(
modelOrId, // 支持 'dall-e-3' 和 'aihubmix:openai:dall-e-3'
this.config.providerId // fallback provider
)
// 字符串modelId使用 ModelResolver 解析
// Provider会处理命名空间格式路由如果是HubProvider
return await this.modelResolver.resolveImageModel(modelOrId)
} else {
// 已经是模型,直接返回
return modelOrId
@ -234,11 +230,13 @@ export class RuntimeExecutor<
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap
>(
providerId: T,
provider: ProviderV3, // ✅ Accept provider instance
options: ModelConfig<T, TSettingsMap>['providerSettings'],
plugins?: AiPlugin[]
): RuntimeExecutor<T, TSettingsMap> {
return new RuntimeExecutor({
providerId,
provider, // ✅ Pass provider to config
providerSettings: options,
plugins
})
@ -246,13 +244,16 @@ export class RuntimeExecutor<
/**
* OpenAI Compatible执行器
* Now accepts provider instance directly
*/
static createOpenAICompatible(
provider: ProviderV3, // ✅ Accept provider instance
options: ModelConfig<'openai-compatible'>['providerSettings'],
plugins: AiPlugin[] = []
): RuntimeExecutor<'openai-compatible'> {
return new RuntimeExecutor({
providerId: 'openai-compatible',
provider, // ✅ Pass provider to config
providerSettings: options,
plugins
})

View File

@ -14,7 +14,7 @@ export type { RuntimeConfig } from './types'
import type { LanguageModelV3Middleware } from '@ai-sdk/provider'
import { type AiPlugin } from '../plugins'
import { extensionRegistry, globalProviderStorage } from '../providers'
import { extensionRegistry } from '../providers'
import { type CoreProviderSettingsMap, type RegisteredProviderId } from '../providers/types'
import { RuntimeExecutor } from './executor'
@ -26,32 +26,15 @@ export async function createExecutor<T extends RegisteredProviderId & keyof Core
providerId: T,
options: CoreProviderSettingsMap[T],
plugins?: AiPlugin[]
): Promise<RuntimeExecutor<T>>
export async function createExecutor<T extends string>(
providerId: T,
options: any,
plugins?: AiPlugin[]
): Promise<RuntimeExecutor<T>>
export async function createExecutor(
providerId: string,
options: any,
plugins?: AiPlugin[]
): Promise<RuntimeExecutor<string>> {
// 确保 provider 已初始化
if (!globalProviderStorage.has(providerId) && extensionRegistry.has(providerId)) {
try {
await extensionRegistry.createProvider(providerId, options || {}, providerId)
} catch (error) {
// 创建失败会在 ModelResolver 抛出更详细的错误
console.warn(`Failed to auto-initialize provider "${providerId}":`, error)
}
): Promise<RuntimeExecutor<T>> {
if (!extensionRegistry.has(providerId)) {
throw new Error(`Provider extension "${providerId}" not registered`)
}
return RuntimeExecutor.create(providerId as RegisteredProviderId, options, plugins)
const provider = await extensionRegistry.createProvider<T>(providerId, options || {})
return RuntimeExecutor.create<T, CoreProviderSettingsMap>(providerId, provider, options, plugins)
}
// === 直接调用API无需创建executor实例===
/**
* - middlewares
*/
@ -96,11 +79,13 @@ export async function generateImage<T extends RegisteredProviderId & keyof CoreP
/**
* OpenAI Compatible
*/
export function createOpenAICompatibleExecutor(
export async function createOpenAICompatibleExecutor(
options: CoreProviderSettingsMap['openai-compatible'],
plugins?: AiPlugin[]
): RuntimeExecutor<'openai-compatible'> {
return RuntimeExecutor.createOpenAICompatible(options, plugins)
): Promise<RuntimeExecutor<'openai-compatible'>> {
const provider = await extensionRegistry.createProvider('openai-compatible', options)
return RuntimeExecutor.createOpenAICompatible(provider, options, plugins)
}
// === Agent 功能预留 ===

View File

@ -1,7 +1,7 @@
/**
* Runtime
*/
import type { ImageModelV3 } from '@ai-sdk/provider'
import type { ImageModelV3, ProviderV3 } from '@ai-sdk/provider'
import type { generateImage, generateText, streamText } from 'ai'
import { type ModelConfig } from '../models/types'
@ -19,6 +19,7 @@ export interface RuntimeConfig<
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap
> {
providerId: T
provider: ProviderV3
providerSettings: ModelConfig<T, TSettingsMap>['providerSettings']
plugins?: AiPlugin[]
}

View File

@ -1 +1,8 @@
export type PlainObject = Record<string, any>
/**
* Provider settings map for HubProvider
* Key: provider ID (string)
* Value: provider settings object
*/
export type ProviderSettingsMap = Map<string, Record<string, unknown>>

View File

@ -15,7 +15,7 @@ export {
} from './core/runtime'
// ==================== 高级API ====================
export { isV2Model, isV3Model, globalModelResolver as modelResolver } from './core/models'
export { isV2Model, isV3Model } from './core/models'
// ==================== 插件系统 ====================
export type {

View File

@ -1,12 +1,12 @@
/**
* Test Utilities
* Helper functions for testing AI Core functionality
* Common Test Utilities
* General-purpose helper functions for testing
*/
import { expect, vi } from 'vitest'
import type { ProviderId } from '../fixtures/mock-providers'
import { createMockImageModel, createMockLanguageModel, mockProviderConfigs } from '../fixtures/mock-providers'
import type { ProviderId } from '../mocks/providers'
import { createMockImageModel, createMockLanguageModel, mockProviderConfigs } from '../mocks/providers'
/**
* Creates a test provider with streaming support

View File

@ -16,9 +16,9 @@ import { MockLanguageModelV3 } from 'ai/test'
import { vi } from 'vitest'
import * as z from 'zod'
import type { StreamTextParams, StreamTextResult } from '../../core/plugins'
import type { RegisteredProviderId } from '../../core/providers/types'
import type { AiRequestContext } from '../../types'
import type { StreamTextParams, StreamTextResult } from '../../src/core/plugins'
import type { RegisteredProviderId } from '../../src/core/providers/types'
import type { AiRequestContext } from '../../src/types'
/**
* Type for partial overrides that allows omitting the model field
@ -137,45 +137,95 @@ export function createMockProviderV3(overrides?: {
imageModel?: (modelId: string) => ImageModelV3
embeddingModel?: (modelId: string) => EmbeddingModelV3
}): ProviderV3 {
const defaultLanguageModel = (modelId: string) =>
({
specificationVersion: 'v3',
provider: overrides?.provider ?? 'mock-provider',
modelId,
defaultObjectGenerationMode: 'tool',
supportedUrls: {},
doGenerate: vi.fn().mockResolvedValue({
text: 'Mock response text',
finishReason: 'stop',
usage: {
inputTokens: 10,
outputTokens: 20,
totalTokens: 30,
inputTokenDetails: {},
outputTokenDetails: {}
},
rawCall: { rawPrompt: null, rawSettings: {} },
rawResponse: { headers: {} },
warnings: []
}),
doStream: vi.fn().mockReturnValue({
stream: (async function* () {
yield { type: 'text-delta', textDelta: 'Mock ' }
yield { type: 'text-delta', textDelta: 'streaming ' }
yield { type: 'text-delta', textDelta: 'response' }
yield {
type: 'finish',
finishReason: 'stop',
usage: {
inputTokens: 10,
outputTokens: 15,
totalTokens: 25,
inputTokenDetails: {},
outputTokenDetails: {}
}
}
})(),
rawCall: { rawPrompt: null, rawSettings: {} },
rawResponse: { headers: {} },
warnings: []
})
}) as LanguageModelV3
const defaultImageModel = (modelId: string) =>
({
specificationVersion: 'v3',
provider: overrides?.provider ?? 'mock-provider',
modelId,
maxImagesPerCall: undefined,
doGenerate: vi.fn().mockResolvedValue({
images: [
{
base64: 'mock-base64-image-data',
uint8Array: new Uint8Array([1, 2, 3, 4, 5]),
mimeType: 'image/png'
}
],
warnings: []
})
}) as ImageModelV3
const defaultEmbeddingModel = (modelId: string) =>
({
specificationVersion: 'v3',
provider: overrides?.provider ?? 'mock-provider',
modelId,
maxEmbeddingsPerCall: 100,
supportsParallelCalls: true,
doEmbed: vi.fn().mockResolvedValue({
embeddings: [
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.6, 0.7, 0.8, 0.9, 1.0]
],
usage: {
inputTokens: 10,
totalTokens: 10
},
rawResponse: { headers: {} }
})
}) as EmbeddingModelV3
return {
specificationVersion: 'v3',
provider: overrides?.provider ?? 'mock-provider',
languageModel: overrides?.languageModel
? overrides.languageModel
: (modelId: string) =>
({
specificationVersion: 'v3',
provider: overrides?.provider ?? 'mock-provider',
modelId,
defaultObjectGenerationMode: 'tool',
supportedUrls: {},
doGenerate: vi.fn(),
doStream: vi.fn()
}) as LanguageModelV3,
imageModel: overrides?.imageModel
? overrides.imageModel
: (modelId: string) =>
({
specificationVersion: 'v3',
provider: overrides?.provider ?? 'mock-provider',
modelId,
maxImagesPerCall: undefined,
doGenerate: vi.fn()
}) as ImageModelV3,
embeddingModel: overrides?.embeddingModel
? overrides.embeddingModel
: (modelId: string) =>
({
specificationVersion: 'v3',
provider: overrides?.provider ?? 'mock-provider',
modelId,
maxEmbeddingsPerCall: 100,
supportsParallelCalls: true,
doEmbed: vi.fn()
}) as EmbeddingModelV3
languageModel: vi.fn(overrides?.languageModel ?? defaultLanguageModel),
imageModel: vi.fn(overrides?.imageModel ?? defaultImageModel),
embeddingModel: vi.fn(overrides?.embeddingModel ?? defaultEmbeddingModel)
} as ProviderV3
}

View File

@ -0,0 +1,13 @@
/**
* Test Infrastructure Exports
* Central export point for all test utilities, fixtures, and helpers
*/
// Mocks
export * from './mocks/providers'
export * from './mocks/responses'
// Helpers
export * from './helpers/common'
export * from './helpers/model'
export * from './helpers/provider'

View File

@ -11,11 +11,16 @@
"noEmitOnError": false,
"outDir": "./dist",
"resolveJsonModule": true,
"rootDir": "./src",
"rootDir": ".",
"skipLibCheck": true,
"strict": true,
"target": "ES2020"
"target": "ES2020",
"baseUrl": ".",
"paths": {
"@test-utils": ["./test_utils"],
"@test-utils/*": ["./test_utils/*"]
}
},
"exclude": ["node_modules", "dist"],
"include": ["src/**/*"]
"include": ["src/**/*", "test_utils/**/*"]
}

View File

@ -8,13 +8,14 @@ const __dirname = path.dirname(fileURLToPath(import.meta.url))
export default defineConfig({
test: {
globals: true,
setupFiles: [path.resolve(__dirname, './src/__tests__/setup.ts')]
setupFiles: [path.resolve(__dirname, './test_utils/setup.ts')]
},
resolve: {
alias: {
'@': path.resolve(__dirname, './src'),
'@test-utils': path.resolve(__dirname, './test_utils'),
// Mock external packages that may not be available in test environment
'@cherrystudio/ai-sdk-provider': path.resolve(__dirname, './src/__tests__/mocks/ai-sdk-provider.ts')
'@cherrystudio/ai-sdk-provider': path.resolve(__dirname, './test_utils/mocks/ai-sdk-provider.ts')
}
},
esbuild: {

View File

@ -1936,6 +1936,7 @@ __metadata:
"@ai-sdk/provider": "npm:^3.0.0"
"@ai-sdk/provider-utils": "npm:^4.0.0"
"@ai-sdk/xai": "npm:^3.0.0"
lru-cache: "npm:^11.2.4"
tsdown: "npm:^0.12.9"
typescript: "npm:^5.0.0"
vitest: "npm:^3.2.4"
@ -18183,6 +18184,13 @@ __metadata:
languageName: node
linkType: hard
"lru-cache@npm:^11.2.4":
version: 11.2.4
resolution: "lru-cache@npm:11.2.4"
checksum: 10c0/4a24f9b17537619f9144d7b8e42cd5a225efdfd7076ebe7b5e7dc02b860a818455201e67fbf000765233fe7e339d3c8229fc815e9b58ee6ede511e07608c19b2
languageName: node
linkType: hard
"lru-cache@npm:^5.1.1":
version: 5.1.1
resolution: "lru-cache@npm:5.1.1"