mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-02-14 14:53:20 +08:00
refactor: remove ProviderId
This commit is contained in:
parent
7268d8eef2
commit
0f276a3f1d
@ -17,7 +17,7 @@ import { vi } from 'vitest'
|
||||
import * as z from 'zod'
|
||||
|
||||
import type { StreamTextParams, StreamTextResult } from '../../core/plugins'
|
||||
import type { ProviderId } from '../../core/providers/types'
|
||||
import type { RegisteredProviderId } from '../../core/providers/types'
|
||||
import type { AiRequestContext } from '../../types'
|
||||
|
||||
/**
|
||||
@ -47,7 +47,7 @@ export function createMockContext(overrides?: ContextOverrides): AiRequestContex
|
||||
})
|
||||
|
||||
const base: AiRequestContext<StreamTextParams, StreamTextResult> = {
|
||||
providerId: 'openai' as ProviderId,
|
||||
providerId: 'openai' as RegisteredProviderId,
|
||||
model: mockModel,
|
||||
originalParams: {
|
||||
model: mockModel,
|
||||
|
||||
@ -13,14 +13,14 @@ export type {
|
||||
} from './types'
|
||||
import type { ImageModel, LanguageModel } from 'ai'
|
||||
|
||||
import type { ProviderId } from '../providers'
|
||||
import type { RegisteredProviderId } from '../providers'
|
||||
import type { AiPlugin, AiRequestContext } from './types'
|
||||
|
||||
// 插件管理器
|
||||
export { PluginManager } from './manager'
|
||||
|
||||
// 工具函数
|
||||
export function createContext<T extends ProviderId, TParams = unknown, TResult = unknown>(
|
||||
export function createContext<T extends RegisteredProviderId | (string & {}), TParams = unknown, TResult = unknown>(
|
||||
providerId: T,
|
||||
model: LanguageModel | ImageModel,
|
||||
originalParams: TParams
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import type { JSONObject, JSONValue } from '@ai-sdk/provider'
|
||||
import type { generateText, LanguageModelMiddleware, streamText, TextStreamPart, ToolSet } from 'ai'
|
||||
|
||||
import type { AiSdkModel, ProviderId } from '../providers/types'
|
||||
import type { AiSdkModel, RegisteredProviderId } from '../providers/types'
|
||||
|
||||
/**
|
||||
* 常用的 AI SDK 参数类型(完整版,用于插件泛型)
|
||||
@ -39,7 +39,7 @@ export type RecursiveCallFn<TParams = unknown, TResult = unknown> = (newParams:
|
||||
* 使用泛型参数以支持不同类型的请求
|
||||
*/
|
||||
export interface AiRequestContext<TParams = unknown, TResult = unknown> {
|
||||
providerId: ProviderId
|
||||
providerId: RegisteredProviderId | (string & {})
|
||||
model: AiSdkModel
|
||||
originalParams: TParams
|
||||
metadata: AiRequestMetadata
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
import type { ProviderV3 } from '@ai-sdk/provider'
|
||||
|
||||
import type { RegisteredProviderId } from '../index'
|
||||
import type { ProviderId } from '../types'
|
||||
import { type ProviderExtension } from './ProviderExtension'
|
||||
import { ProviderCreationError } from './utils'
|
||||
|
||||
@ -250,13 +249,13 @@ export class ExtensionRegistry {
|
||||
* parseProviderId('unknown') // → null
|
||||
* ```
|
||||
*/
|
||||
parseProviderId(providerId: string): { baseId: ProviderId; mode?: string; isVariant: boolean } | null {
|
||||
parseProviderId(providerId: string): { baseId: RegisteredProviderId; mode?: string; isVariant: boolean } | null {
|
||||
// 先检查是否是已注册的 extension(直接或通过别名)
|
||||
const extension = this.get(providerId)
|
||||
if (extension) {
|
||||
// 是基础 ID 或别名,不是变体
|
||||
return {
|
||||
baseId: extension.config.name as ProviderId,
|
||||
baseId: extension.config.name as RegisteredProviderId,
|
||||
isVariant: false
|
||||
}
|
||||
}
|
||||
@ -272,7 +271,7 @@ export class ExtensionRegistry {
|
||||
const variantId = `${ext.config.name}-${variant.suffix}`
|
||||
if (variantId === providerId) {
|
||||
return {
|
||||
baseId: ext.config.name as ProviderId,
|
||||
baseId: ext.config.name as RegisteredProviderId,
|
||||
mode: variant.suffix,
|
||||
isVariant: true
|
||||
}
|
||||
@ -322,7 +321,7 @@ export class ExtensionRegistry {
|
||||
* getBaseProviderId('unknown') // → null
|
||||
* ```
|
||||
*/
|
||||
getBaseProviderId(id: string): ProviderId | null {
|
||||
getBaseProviderId(id: string): RegisteredProviderId | null {
|
||||
const parsed = this.parseProviderId(id)
|
||||
return parsed?.baseId ?? null
|
||||
}
|
||||
|
||||
@ -21,10 +21,8 @@ export {
|
||||
|
||||
// ==================== 基础数据和类型 ====================
|
||||
|
||||
// 类型定义和Schema
|
||||
export type { AiSdkModel, ProviderError, ProviderTypeRegistrar } from './types'
|
||||
export type { ProviderId } from './types/schemas'
|
||||
export { providerIdSchema } from './types/schemas'
|
||||
// 类型定义
|
||||
export type { AiSdkModel, ProviderError } from './types'
|
||||
|
||||
// 类型提取工具(用于应用层 Merge Point 模式)
|
||||
export type {
|
||||
|
||||
@ -12,13 +12,14 @@ import type {
|
||||
|
||||
import type { coreExtensions, CoreProviderId } from '../core/initialization'
|
||||
import type { ProviderExtension } from '../core/ProviderExtension'
|
||||
// 导入基于 Zod 的 ProviderId 类型
|
||||
import { type ProviderId as ZodProviderId } from './schemas'
|
||||
|
||||
/**
|
||||
* 核心 Provider ID
|
||||
* 已注册的 Provider ID
|
||||
* 从 coreExtensions 数组自动提取所有 Provider IDs
|
||||
* 类型安全的 literal union
|
||||
*
|
||||
* 如果需要支持动态/未注册的 provider,使用:
|
||||
* RegisteredProviderId | (string & {})
|
||||
*/
|
||||
export type RegisteredProviderId = CoreProviderId
|
||||
|
||||
@ -35,14 +36,6 @@ export class ProviderError extends Error {
|
||||
}
|
||||
}
|
||||
|
||||
// 动态ProviderId类型 - 基于 Zod Schema,支持运行时扩展和验证
|
||||
export type ProviderId = ZodProviderId
|
||||
|
||||
export interface ProviderTypeRegistrar {
|
||||
registerProviderType<T extends string, S>(providerId: T, settingsType: S): void
|
||||
getProviderSettings<T extends string>(providerId: T): any
|
||||
}
|
||||
|
||||
export type AiSdkModel = LanguageModel | ImageModel | EmbeddingModel | TranscriptionModel | SpeechModel
|
||||
export type AiSdkProvider = ProviderV2 | ProviderV3
|
||||
export type AiSdkUsage = LanguageModelUsage | ImageModelUsage | EmbeddingModelUsage
|
||||
|
||||
@ -1,16 +0,0 @@
|
||||
/**
|
||||
* Provider ID Schema
|
||||
*/
|
||||
|
||||
import * as z from 'zod'
|
||||
|
||||
/**
|
||||
* Provider ID Schema
|
||||
* 通过 module augmentation 扩展的类型安全 ID
|
||||
*/
|
||||
export const providerIdSchema = z.string().min(1)
|
||||
|
||||
/**
|
||||
* Provider ID 类型 - 基于 zod schema 推导
|
||||
*/
|
||||
export type ProviderId = z.infer<typeof providerIdSchema>
|
||||
@ -13,13 +13,13 @@ import {
|
||||
type StreamTextParams,
|
||||
type StreamTextResult
|
||||
} from '../plugins'
|
||||
import { type ProviderId } from '../providers/types'
|
||||
import { type RegisteredProviderId } from '../providers/types'
|
||||
|
||||
/**
|
||||
* 插件增强的 AI 客户端
|
||||
* 专注于插件处理,不暴露用户API
|
||||
*/
|
||||
export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
export class PluginEngine<T extends RegisteredProviderId | (string & {}) = RegisteredProviderId> {
|
||||
/**
|
||||
* Plugin storage with explicit any/any generics
|
||||
*
|
||||
|
||||
@ -34,15 +34,7 @@ export { createContext, definePlugin, PluginManager } from './core/plugins'
|
||||
export { PluginEngine } from './core/runtime/pluginEngine'
|
||||
|
||||
// ==================== 类型工具 ====================
|
||||
export type { ModelId, ProviderId, RequestId } from './core/types/branded'
|
||||
export { isModelId, isProviderId, isRequestId } from './core/types/branded'
|
||||
// Branded type constructors (values, not types)
|
||||
export type { AiSdkModel } from './core/providers'
|
||||
export {
|
||||
ModelId as createModelId,
|
||||
ProviderId as createProviderId,
|
||||
RequestId as createRequestId
|
||||
} from './core/types/branded'
|
||||
|
||||
// ==================== 选项 ====================
|
||||
export {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user