mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-24 12:41:56 +08:00
- Move test utilities from src/__tests__/ to test_utils/ - Fix ModelResolver tests for simplified API (2 params instead of 4) - Fix generateImage/generateText tests with proper vi.fn() mocks - Fix ExtensionRegistry.parseProviderId to check variants before aliases - Add createProvider method overload for dynamic provider IDs - Update ProviderExtension tests for runtime validation behavior - Delete outdated tests: initialization.test.ts, extensions.integration.test.ts, executor-resolveModel.test.ts - Remove 3 skipped tests for removed validate hook - Add HubProvider.integration.test.ts - All 359 tests passing, 0 skipped 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
351 lines
9.2 KiB
TypeScript
351 lines
9.2 KiB
TypeScript
/**
|
|
* Model Test Utilities
|
|
* Provides comprehensive mock creators for AI SDK v3 models and related test utilities
|
|
*/
|
|
|
|
import type {
|
|
EmbeddingModelV3,
|
|
ImageModelV3,
|
|
LanguageModelV3,
|
|
LanguageModelV3Middleware,
|
|
ProviderV3
|
|
} from '@ai-sdk/provider'
|
|
import type { Tool, ToolSet } from 'ai'
|
|
import { tool } from 'ai'
|
|
import { MockLanguageModelV3 } from 'ai/test'
|
|
import { vi } from 'vitest'
|
|
import * as z from 'zod'
|
|
|
|
import type { StreamTextParams, StreamTextResult } from '../../src/core/plugins'
|
|
import type { RegisteredProviderId } from '../../src/core/providers/types'
|
|
import type { AiRequestContext } from '../../src/types'
|
|
|
|
/**
|
|
* Type for partial overrides that allows omitting the model field
|
|
* The model will be automatically added by createMockContext
|
|
*/
|
|
type ContextOverrides = Partial<Omit<AiRequestContext<StreamTextParams, StreamTextResult>, 'originalParams'>> & {
|
|
originalParams?: Partial<Omit<StreamTextParams, 'model'>> & { model?: StreamTextParams['model'] }
|
|
}
|
|
|
|
/**
|
|
* Creates a mock AiRequestContext with type safety
|
|
* The model field is automatically added to originalParams if not provided
|
|
*
|
|
* @example
|
|
* ```ts
|
|
* const context = createMockContext({
|
|
* providerId: 'openai',
|
|
* metadata: { requestId: 'test-123' }
|
|
* })
|
|
* ```
|
|
*/
|
|
export function createMockContext(overrides?: ContextOverrides): AiRequestContext<StreamTextParams, StreamTextResult> {
|
|
const mockModel = new MockLanguageModelV3({
|
|
provider: 'test-provider',
|
|
modelId: 'test-model'
|
|
})
|
|
|
|
const base: AiRequestContext<StreamTextParams, StreamTextResult> = {
|
|
providerId: 'openai' as RegisteredProviderId,
|
|
model: mockModel,
|
|
originalParams: {
|
|
model: mockModel,
|
|
messages: [{ role: 'user', content: 'Test message' }]
|
|
} as StreamTextParams,
|
|
metadata: {},
|
|
startTime: Date.now(),
|
|
requestId: 'test-request-id',
|
|
recursiveCall: vi.fn(),
|
|
isRecursiveCall: false,
|
|
recursiveDepth: 0,
|
|
maxRecursiveDepth: 10,
|
|
extensions: new Map()
|
|
}
|
|
|
|
if (overrides) {
|
|
// Ensure model is always present in originalParams
|
|
const mergedOriginalParams = {
|
|
...base.originalParams,
|
|
...overrides.originalParams,
|
|
model: overrides.originalParams?.model ?? mockModel
|
|
}
|
|
|
|
return {
|
|
...base,
|
|
...overrides,
|
|
originalParams: mergedOriginalParams as StreamTextParams
|
|
}
|
|
}
|
|
|
|
return base
|
|
}
|
|
|
|
/**
|
|
* Creates a mock embedding model with customizable behavior
|
|
* Compliant with AI SDK v3 specification
|
|
*
|
|
* @example
|
|
* ```ts
|
|
* const embeddingModel = createMockEmbeddingModel({
|
|
* provider: 'openai',
|
|
* modelId: 'text-embedding-3-small',
|
|
* maxEmbeddingsPerCall: 2048
|
|
* })
|
|
* ```
|
|
*/
|
|
export function createMockEmbeddingModel(overrides?: Partial<EmbeddingModelV3>): EmbeddingModelV3 {
|
|
return {
|
|
specificationVersion: 'v3',
|
|
provider: 'mock-provider',
|
|
modelId: 'mock-embedding-model',
|
|
maxEmbeddingsPerCall: 100,
|
|
supportsParallelCalls: true,
|
|
|
|
doEmbed: vi.fn().mockResolvedValue({
|
|
embeddings: [
|
|
[0.1, 0.2, 0.3, 0.4, 0.5],
|
|
[0.6, 0.7, 0.8, 0.9, 1.0]
|
|
],
|
|
usage: {
|
|
inputTokens: 10,
|
|
totalTokens: 10
|
|
},
|
|
rawResponse: { headers: {} }
|
|
}),
|
|
|
|
...overrides
|
|
} as EmbeddingModelV3
|
|
}
|
|
|
|
/**
|
|
* Creates a complete mock ProviderV3 with all model types
|
|
* Useful for testing provider registration and management
|
|
*
|
|
* @example
|
|
* ```ts
|
|
* const provider = createMockProviderV3({
|
|
* provider: 'openai',
|
|
* languageModel: customLanguageModel,
|
|
* imageModel: customImageModel
|
|
* })
|
|
* ```
|
|
*/
|
|
export function createMockProviderV3(overrides?: {
|
|
provider?: string
|
|
languageModel?: (modelId: string) => LanguageModelV3
|
|
imageModel?: (modelId: string) => ImageModelV3
|
|
embeddingModel?: (modelId: string) => EmbeddingModelV3
|
|
}): ProviderV3 {
|
|
const defaultLanguageModel = (modelId: string) =>
|
|
({
|
|
specificationVersion: 'v3',
|
|
provider: overrides?.provider ?? 'mock-provider',
|
|
modelId,
|
|
defaultObjectGenerationMode: 'tool',
|
|
supportedUrls: {},
|
|
doGenerate: vi.fn().mockResolvedValue({
|
|
text: 'Mock response text',
|
|
finishReason: 'stop',
|
|
usage: {
|
|
inputTokens: 10,
|
|
outputTokens: 20,
|
|
totalTokens: 30,
|
|
inputTokenDetails: {},
|
|
outputTokenDetails: {}
|
|
},
|
|
rawCall: { rawPrompt: null, rawSettings: {} },
|
|
rawResponse: { headers: {} },
|
|
warnings: []
|
|
}),
|
|
doStream: vi.fn().mockReturnValue({
|
|
stream: (async function* () {
|
|
yield { type: 'text-delta', textDelta: 'Mock ' }
|
|
yield { type: 'text-delta', textDelta: 'streaming ' }
|
|
yield { type: 'text-delta', textDelta: 'response' }
|
|
yield {
|
|
type: 'finish',
|
|
finishReason: 'stop',
|
|
usage: {
|
|
inputTokens: 10,
|
|
outputTokens: 15,
|
|
totalTokens: 25,
|
|
inputTokenDetails: {},
|
|
outputTokenDetails: {}
|
|
}
|
|
}
|
|
})(),
|
|
rawCall: { rawPrompt: null, rawSettings: {} },
|
|
rawResponse: { headers: {} },
|
|
warnings: []
|
|
})
|
|
}) as LanguageModelV3
|
|
|
|
const defaultImageModel = (modelId: string) =>
|
|
({
|
|
specificationVersion: 'v3',
|
|
provider: overrides?.provider ?? 'mock-provider',
|
|
modelId,
|
|
maxImagesPerCall: undefined,
|
|
doGenerate: vi.fn().mockResolvedValue({
|
|
images: [
|
|
{
|
|
base64: 'mock-base64-image-data',
|
|
uint8Array: new Uint8Array([1, 2, 3, 4, 5]),
|
|
mimeType: 'image/png'
|
|
}
|
|
],
|
|
warnings: []
|
|
})
|
|
}) as ImageModelV3
|
|
|
|
const defaultEmbeddingModel = (modelId: string) =>
|
|
({
|
|
specificationVersion: 'v3',
|
|
provider: overrides?.provider ?? 'mock-provider',
|
|
modelId,
|
|
maxEmbeddingsPerCall: 100,
|
|
supportsParallelCalls: true,
|
|
doEmbed: vi.fn().mockResolvedValue({
|
|
embeddings: [
|
|
[0.1, 0.2, 0.3, 0.4, 0.5],
|
|
[0.6, 0.7, 0.8, 0.9, 1.0]
|
|
],
|
|
usage: {
|
|
inputTokens: 10,
|
|
totalTokens: 10
|
|
},
|
|
rawResponse: { headers: {} }
|
|
})
|
|
}) as EmbeddingModelV3
|
|
|
|
return {
|
|
specificationVersion: 'v3',
|
|
provider: overrides?.provider ?? 'mock-provider',
|
|
|
|
languageModel: vi.fn(overrides?.languageModel ?? defaultLanguageModel),
|
|
imageModel: vi.fn(overrides?.imageModel ?? defaultImageModel),
|
|
embeddingModel: vi.fn(overrides?.embeddingModel ?? defaultEmbeddingModel)
|
|
} as ProviderV3
|
|
}
|
|
|
|
/**
|
|
* Creates a mock middleware for testing middleware chains
|
|
* Supports both generate and stream wrapping
|
|
*
|
|
* @example
|
|
* ```ts
|
|
* const middleware = createMockMiddleware({
|
|
* name: 'test-middleware'
|
|
* })
|
|
* ```
|
|
*/
|
|
export function createMockMiddleware(): LanguageModelV3Middleware {
|
|
return {
|
|
specificationVersion: 'v3',
|
|
wrapGenerate: vi.fn((doGenerate) => doGenerate),
|
|
wrapStream: vi.fn((doStream) => doStream)
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Creates a type-safe function tool for testing using AI SDK's tool() function
|
|
*
|
|
* @example
|
|
* ```ts
|
|
* const weatherTool = createMockTool('getWeather', 'Get current weather')
|
|
* ```
|
|
*/
|
|
export function createMockTool(name: string, description?: string): Tool<{ value?: string }, string> {
|
|
return tool({
|
|
description: description || `Mock tool: ${name}`,
|
|
inputSchema: z.object({
|
|
value: z.string().optional()
|
|
}),
|
|
execute: vi.fn(async () => 'mock result')
|
|
})
|
|
}
|
|
|
|
/**
|
|
* Creates a provider-defined tool for testing
|
|
*/
|
|
export function createMockProviderTool(name: string, description?: string): { type: 'provider'; description: string } {
|
|
return {
|
|
type: 'provider' as const,
|
|
description: description || `Mock provider tool: ${name}`
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Creates a ToolSet with multiple tools
|
|
*
|
|
* @example
|
|
* ```ts
|
|
* const tools = createMockToolSet({
|
|
* getWeather: 'function',
|
|
* searchDatabase: 'function',
|
|
* nativeSearch: 'provider'
|
|
* })
|
|
* ```
|
|
*/
|
|
export function createMockToolSet(tools: Record<string, 'function' | 'provider'>): ToolSet {
|
|
const toolSet: ToolSet = {}
|
|
|
|
for (const [name, type] of Object.entries(tools)) {
|
|
if (type === 'function') {
|
|
toolSet[name] = createMockTool(name)
|
|
} else {
|
|
toolSet[name] = createMockProviderTool(name) as Tool
|
|
}
|
|
}
|
|
|
|
return toolSet
|
|
}
|
|
|
|
/**
|
|
* Creates mock stream params for testing
|
|
*
|
|
* @example
|
|
* ```ts
|
|
* const params = createMockStreamParams({
|
|
* messages: [{ role: 'user', content: 'Custom message' }],
|
|
* temperature: 0.7
|
|
* })
|
|
* ```
|
|
*/
|
|
export function createMockStreamParams(overrides?: Partial<StreamTextParams>): StreamTextParams {
|
|
return {
|
|
messages: [{ role: 'user', content: 'Test message' }],
|
|
...overrides
|
|
} as StreamTextParams
|
|
}
|
|
|
|
/**
|
|
* Common mock model instances for quick testing
|
|
*/
|
|
export const mockModels = {
|
|
/** Standard language model for general testing */
|
|
language: new MockLanguageModelV3({
|
|
provider: 'test-provider',
|
|
modelId: 'test-model'
|
|
}),
|
|
|
|
/** Mock OpenAI GPT-4 model */
|
|
gpt4: new MockLanguageModelV3({
|
|
provider: 'openai',
|
|
modelId: 'gpt-4'
|
|
}),
|
|
|
|
/** Mock Anthropic Claude model */
|
|
claude: new MockLanguageModelV3({
|
|
provider: 'anthropic',
|
|
modelId: 'claude-3-5-sonnet-20241022'
|
|
}),
|
|
|
|
/** Mock Google Gemini model */
|
|
gemini: new MockLanguageModelV3({
|
|
provider: 'google',
|
|
modelId: 'gemini-2.0-flash-exp'
|
|
})
|
|
} as const
|