From 953b6d30427cc4e5afc40287c94e79ca0bdaf0f9 Mon Sep 17 00:00:00 2001 From: suyao Date: Fri, 2 Jan 2026 09:38:34 +0800 Subject: [PATCH] refactor: use ai sdk getFromApi --- .../src/aiCore/services/ModelListService.ts | 83 +++- .../src/aiCore/services/get-from-api.ts | 393 ------------------ src/renderer/src/aiCore/services/index.ts | 1 - 3 files changed, 65 insertions(+), 412 deletions(-) delete mode 100644 src/renderer/src/aiCore/services/get-from-api.ts diff --git a/src/renderer/src/aiCore/services/ModelListService.ts b/src/renderer/src/aiCore/services/ModelListService.ts index ebc105bf16..bf7f06c78f 100644 --- a/src/renderer/src/aiCore/services/ModelListService.ts +++ b/src/renderer/src/aiCore/services/ModelListService.ts @@ -1,15 +1,22 @@ /** * ModelListService - Unified model listing service - * Fetches model lists from various providers using getFromApi + * Fetches model lists from various providers using AI SDK's getFromApi */ + +import { + createJsonErrorResponseHandler, + createJsonResponseHandler, + getFromApi as aiSdkGetFromApi, + zodSchema +} from '@ai-sdk/provider-utils' import { loggerService } from '@logger' -import type { Model, Provider } from '@renderer/types' +import type { EndpointType, Model, Provider } from '@renderer/types' import { SystemProviderIds } from '@renderer/types' import { formatApiHost, withoutTrailingSlash } from '@renderer/utils' import { isAIGatewayProvider, isGeminiProvider, isOllamaProvider } from '@renderer/utils/provider' import { defaultAppHeaders } from '@shared/utils' +import * as z from 'zod' -import { APICallError, getJsonFromApi } from './get-from-api' import { type GeminiModelsResponse, GeminiModelsResponseSchema, @@ -29,6 +36,47 @@ import { const logger = loggerService.withContext('ModelListService') +// Error schema for API error responses +const ApiErrorSchema = z.object({ + error: z + .object({ + message: z.string().optional(), + code: z.string().optional() + }) + .optional(), + message: z.string().optional() +}) + +type ApiError = z.infer + +/** + * Type-safe fetch wrapper using AI SDK's getFromApi with Zod schema validation + */ +async function getFromApi({ + url, + headers, + responseSchema, + abortSignal +}: { + url: string + headers?: Record + responseSchema: z.ZodType + abortSignal?: AbortSignal +}): Promise { + const { value } = await aiSdkGetFromApi({ + url, + headers, + successfulResponseHandler: createJsonResponseHandler(zodSchema(responseSchema)), + failedResponseHandler: createJsonErrorResponseHandler({ + errorSchema: zodSchema(ApiErrorSchema), + errorToMessage: (error: ApiError) => error.error?.message || error.message || 'Unknown error' + }), + abortSignal + }) + + return value +} + // === Helper Functions === function getApiKey(provider: Provider): string { @@ -237,7 +285,8 @@ function convertNewApiModelsToModels(provider: Provider, response: NewApiModelsR provider: provider.id, group: getDefaultGroupName(id, provider.id), owned_by: model.owned_by, - supported_endpoint_types: model.supported_endpoint_types as any + // The Zod schema type is a subset of EndpointType, safe to cast + supported_endpoint_types: model.supported_endpoint_types as EndpointType[] | undefined }) } @@ -320,7 +369,7 @@ export class ModelListService { const baseUrl = formatApiHost(provider.apiHost) const url = `${baseUrl}/models` - const response = await getJsonFromApi({ + const response = await getFromApi({ url, headers: getDefaultHeaders(provider), responseSchema: OpenAIModelsResponseSchema, @@ -336,7 +385,7 @@ export class ModelListService { .replace(/\/api$/, '') const url = `${baseUrl}/api/tags` - const response = await getJsonFromApi({ + const response = await getFromApi({ url, headers: getDefaultHeaders(provider), responseSchema: OllamaTagsResponseSchema, @@ -354,7 +403,7 @@ export class ModelListService { const apiVersion = provider.apiVersion || 'v1beta' const url = `${baseUrl}/${apiVersion}/models?key=${getApiKey(provider)}` - const response = await getJsonFromApi({ + const response = await getFromApi({ url, headers: { ...defaultAppHeaders(), @@ -370,7 +419,7 @@ export class ModelListService { private static async fetchGitHubModels(provider: Provider, abortSignal?: AbortSignal): Promise { const url = 'https://models.github.ai/catalog/' - const response = await getJsonFromApi({ + const response = await getFromApi({ url, headers: getDefaultHeaders(provider), responseSchema: GitHubModelsResponseSchema, @@ -384,7 +433,7 @@ export class ModelListService { const baseUrl = formatApiHost(withoutTrailingSlash(provider.apiHost).replace(/\/v1$/, ''), true, 'v1') const url = `${baseUrl}/config` - const response = await getJsonFromApi({ + const response = await getFromApi({ url, headers: getDefaultHeaders(provider), responseSchema: OVMSConfigResponseSchema, @@ -398,7 +447,7 @@ export class ModelListService { const baseUrl = formatApiHost(provider.apiHost) const url = `${baseUrl}/models` - const response = await getJsonFromApi({ + const response = await getFromApi({ url, headers: getDefaultHeaders(provider), responseSchema: TogetherModelsResponseSchema, @@ -412,7 +461,7 @@ export class ModelListService { const baseUrl = formatApiHost(provider.apiHost) const url = `${baseUrl}/models` - const response = await getJsonFromApi({ + const response = await getFromApi({ url, headers: getDefaultHeaders(provider), responseSchema: NewApiModelsResponseSchema, @@ -428,13 +477,13 @@ export class ModelListService { const embedBaseUrl = 'https://openrouter.ai/api/v1/embeddings' const [modelsResponse, embedModelsResponse] = await Promise.all([ - getJsonFromApi({ + getFromApi({ url: `${baseUrl}/models`, headers: getDefaultHeaders(provider), responseSchema: OpenAIModelsResponseSchema, abortSignal }), - getJsonFromApi({ + getFromApi({ url: `${embedBaseUrl}/models`, headers: getDefaultHeaders(provider), responseSchema: OpenAIModelsResponseSchema, @@ -454,19 +503,19 @@ export class ModelListService { // PPIO requires three separate requests to get all model types const [chatModelsResponse, embeddingModelsResponse, rerankerModelsResponse] = await Promise.all([ - getJsonFromApi({ + getFromApi({ url: `${baseUrl}/models`, headers: getDefaultHeaders(provider), responseSchema: OpenAIModelsResponseSchema, abortSignal }), - getJsonFromApi({ + getFromApi({ url: `${baseUrl}/models?model_type=embedding`, headers: getDefaultHeaders(provider), responseSchema: OpenAIModelsResponseSchema, abortSignal }).catch(() => ({ data: [] })), - getJsonFromApi({ + getFromApi({ url: `${baseUrl}/models?model_type=reranker`, headers: getDefaultHeaders(provider), responseSchema: OpenAIModelsResponseSchema, @@ -480,5 +529,3 @@ export class ModelListService { return convertOpenAIModelsToModels(provider, { data: allModels }) } } - -export { APICallError } diff --git a/src/renderer/src/aiCore/services/get-from-api.ts b/src/renderer/src/aiCore/services/get-from-api.ts deleted file mode 100644 index ba00658cb9..0000000000 --- a/src/renderer/src/aiCore/services/get-from-api.ts +++ /dev/null @@ -1,393 +0,0 @@ -/** - * Unified HTTP GET utility for API calls - * Inspired by AI SDK's postToApi pattern - */ -import type * as z from 'zod' - -// === Types === - -export type FetchFunction = typeof globalThis.fetch - -export type ResponseHandler = (options: { url: string; response: Response }) => PromiseLike<{ - value: T - responseHeaders?: Record -}> - -export interface APICallErrorOptions { - message: string - url: string - statusCode?: number - responseHeaders?: Record - responseBody?: string - cause?: unknown - isRetryable?: boolean -} - -// === Error Classes === - -export class APICallError extends Error { - readonly url: string - readonly statusCode?: number - readonly responseHeaders?: Record - readonly responseBody?: string - readonly cause?: unknown - readonly isRetryable: boolean - - constructor(options: APICallErrorOptions) { - super(options.message) - this.name = 'APICallError' - this.url = options.url - this.statusCode = options.statusCode - this.responseHeaders = options.responseHeaders - this.responseBody = options.responseBody - this.cause = options.cause - this.isRetryable = options.isRetryable ?? false - } - - static isInstance(error: unknown): error is APICallError { - return error instanceof APICallError - } -} - -export class JSONParseError extends Error { - readonly text: string - readonly cause?: unknown - - constructor(options: { text: string; cause?: unknown }) { - super('Failed to parse JSON') - this.name = 'JSONParseError' - this.text = options.text - this.cause = options.cause - } - - static isInstance(error: unknown): error is JSONParseError { - return error instanceof JSONParseError - } -} - -export class TypeValidationError extends Error { - readonly value: unknown - readonly cause?: unknown - - constructor(options: { value: unknown; cause?: unknown }) { - super('Type validation failed') - this.name = 'TypeValidationError' - this.value = options.value - this.cause = options.cause - } - - static isInstance(error: unknown): error is TypeValidationError { - return error instanceof TypeValidationError - } -} - -// === Utility Functions === - -function extractResponseHeaders(response: Response): Record { - return Object.fromEntries([...response.headers]) -} - -function isAbortError(error: unknown): error is Error { - return ( - (error instanceof Error || error instanceof DOMException) && - (error.name === 'AbortError' || error.name === 'ResponseAborted' || error.name === 'TimeoutError') - ) -} - -const FETCH_FAILED_ERROR_MESSAGES = ['fetch failed', 'failed to fetch'] - -function handleFetchError({ error, url }: { error: unknown; url: string }) { - if (isAbortError(error)) { - return error - } - - // Unwrap original error when fetch failed (for easier debugging) - if (error instanceof TypeError && FETCH_FAILED_ERROR_MESSAGES.includes(error.message.toLowerCase())) { - const cause = (error as any).cause - - if (cause != null) { - return new APICallError({ - message: `Cannot connect to API: ${cause.message}`, - cause, - url, - isRetryable: true - }) - } - } - - return error -} - -// === JSON Parsing === - -export type ParseResult = - | { success: true; value: T; rawValue: unknown } - | { success: false; error: JSONParseError | TypeValidationError; rawValue?: unknown } - -export async function safeParseJSON(options: { text: string; schema: z.ZodType }): Promise> -export async function safeParseJSON(options: { text: string; schema?: undefined }): Promise> -export async function safeParseJSON({ - text, - schema -}: { - text: string - schema?: z.ZodType -}): Promise> { - try { - const value = JSON.parse(text) - - if (schema == null) { - return { success: true, value: value as T, rawValue: value } - } - - const result = schema.safeParse(value) - if (result.success) { - return { success: true, value: result.data, rawValue: value } - } else { - return { - success: false, - error: new TypeValidationError({ value, cause: result.error }), - rawValue: value - } - } - } catch (error) { - return { - success: false, - error: JSONParseError.isInstance(error) ? error : new JSONParseError({ text, cause: error }), - rawValue: undefined - } - } -} - -// === Response Handlers === - -export const createJsonResponseHandler = - (responseSchema: z.ZodType): ResponseHandler => - async ({ response, url }) => { - const responseBody = await response.text() - const parsedResult = await safeParseJSON({ text: responseBody, schema: responseSchema }) - const responseHeaders = extractResponseHeaders(response) - - if (!parsedResult.success) { - throw new APICallError({ - message: 'Invalid JSON response', - cause: parsedResult.error, - statusCode: response.status, - responseHeaders, - responseBody, - url - }) - } - - return { - responseHeaders, - value: parsedResult.value - } - } - -export const createJsonErrorResponseHandler = - ({ - errorSchema, - errorToMessage, - isRetryable - }: { - errorSchema: z.ZodType - errorToMessage: (error: T) => string - isRetryable?: (response: Response, error?: T) => boolean - }): ResponseHandler => - async ({ response, url }) => { - const responseBody = await response.text() - const responseHeaders = extractResponseHeaders(response) - - // Some providers return an empty response body for some errors - if (responseBody.trim() === '') { - return { - responseHeaders, - value: new APICallError({ - message: response.statusText, - url, - statusCode: response.status, - responseHeaders, - responseBody, - isRetryable: isRetryable?.(response) - }) - } - } - - // Resilient parsing in case the response is not JSON or does not match the schema - try { - const parsedResult = await safeParseJSON({ text: responseBody, schema: errorSchema }) - - if (parsedResult.success) { - return { - responseHeaders, - value: new APICallError({ - message: errorToMessage(parsedResult.value), - url, - statusCode: response.status, - responseHeaders, - responseBody, - isRetryable: isRetryable?.(response, parsedResult.value) - }) - } - } - } catch { - // Fall through to default error - } - - return { - responseHeaders, - value: new APICallError({ - message: response.statusText, - url, - statusCode: response.status, - responseHeaders, - responseBody, - isRetryable: isRetryable?.(response) - }) - } - } - -export const createStatusCodeErrorResponseHandler = - (): ResponseHandler => - async ({ response, url }) => { - const responseHeaders = extractResponseHeaders(response) - const responseBody = await response.text() - - return { - responseHeaders, - value: new APICallError({ - message: response.statusText, - url, - statusCode: response.status, - responseHeaders, - responseBody - }) - } - } - -// === Main GET Function === - -export interface GetFromApiOptions { - url: string - headers?: Record - successfulResponseHandler: ResponseHandler - failedResponseHandler: ResponseHandler - abortSignal?: AbortSignal - fetch?: FetchFunction -} - -export const getFromApi = async ({ - url, - headers = {}, - successfulResponseHandler, - failedResponseHandler, - abortSignal, - fetch = globalThis.fetch -}: GetFromApiOptions): Promise<{ value: T; responseHeaders?: Record }> => { - try { - // Filter out undefined headers - const cleanHeaders: Record = {} - for (const [key, value] of Object.entries(headers)) { - if (value !== undefined) { - cleanHeaders[key] = value - } - } - - const response = await fetch(url, { - method: 'GET', - headers: cleanHeaders, - signal: abortSignal - }) - - const responseHeaders = extractResponseHeaders(response) - - if (!response.ok) { - let errorInformation: { - value: Error - responseHeaders?: Record | undefined - } - - try { - errorInformation = await failedResponseHandler({ - response, - url - }) - } catch (error) { - if (isAbortError(error) || APICallError.isInstance(error)) { - throw error - } - - throw new APICallError({ - message: 'Failed to process error response', - cause: error, - statusCode: response.status, - url, - responseHeaders - }) - } - - throw errorInformation.value - } - - try { - return await successfulResponseHandler({ - response, - url - }) - } catch (error) { - if (error instanceof Error) { - if (isAbortError(error) || APICallError.isInstance(error)) { - throw error - } - } - - throw new APICallError({ - message: 'Failed to process successful response', - cause: error, - statusCode: response.status, - url, - responseHeaders - }) - } - } catch (error) { - throw handleFetchError({ error, url }) - } -} - -// === Convenience Functions === - -/** - * Fetch JSON from an API endpoint with schema validation - */ -export async function getJsonFromApi({ - url, - headers, - responseSchema, - errorSchema, - errorToMessage, - abortSignal, - fetch -}: { - url: string - headers?: Record - responseSchema: z.ZodType - errorSchema?: z.ZodType - errorToMessage?: (error: any) => string - abortSignal?: AbortSignal - fetch?: FetchFunction -}): Promise { - const result = await getFromApi({ - url, - headers, - successfulResponseHandler: createJsonResponseHandler(responseSchema), - failedResponseHandler: - errorSchema && errorToMessage - ? createJsonErrorResponseHandler({ errorSchema, errorToMessage }) - : createStatusCodeErrorResponseHandler(), - abortSignal, - fetch - }) - - return result.value -} diff --git a/src/renderer/src/aiCore/services/index.ts b/src/renderer/src/aiCore/services/index.ts index fcc3583c7b..5cd96877ad 100644 --- a/src/renderer/src/aiCore/services/index.ts +++ b/src/renderer/src/aiCore/services/index.ts @@ -3,6 +3,5 @@ * Unified services for AI operations */ -export * from './get-from-api' export * from './ModelListService' export * from './schemas'