diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index 37925c704a..a824f96c20 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -37,7 +37,13 @@ import { import { type Chunk, ChunkType } from '@renderer/types/chunk' import { Message } from '@renderer/types/newMessage' import { SdkModel } from '@renderer/types/sdk' -import { removeSpecialCharactersForTopicName } from '@renderer/utils' +import { removeSpecialCharactersForTopicName, uuid } from '@renderer/utils' +import { + abortCompletion, + addAbortController, + createAbortPromise, + removeAbortController +} from '@renderer/utils/abortController' import { isAbortError } from '@renderer/utils/error' import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract' import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' @@ -846,30 +852,66 @@ export function checkApiProvider(provider: Provider): void { export async function checkApi(provider: Provider, model: Model): Promise { checkApiProvider(provider) + const timeout = 15000 + const controller = new AbortController() + const abortFn = () => controller.abort() + const taskId = uuid() + addAbortController(taskId, abortFn) + const ai = new AiProvider(provider) const assistant = getDefaultAssistant() assistant.model = model try { if (isEmbeddingModel(model)) { - await ai.getEmbeddingDimensions(model) + // race 超时 15s + logger.silly("it's a embedding model") + const timerPromise = new Promise((_, reject) => setTimeout(() => reject('Timeout'), timeout)) + await Promise.race([ai.getEmbeddingDimensions(model), timerPromise]) } else { + // 通过该状态判断abort原因 + let streamError: Error | undefined = undefined + + // 15s超时 + const timer = setTimeout(() => { + abortCompletion(taskId) + streamError = new Error('Timeout') + }, timeout) + const params: CompletionsParams = { callType: 'check', messages: 'hi', assistant, streamOutput: true, enableReasoning: false, - shouldThrow: true + onChunk: () => { + // 接收到任意chunk都直接abort + abortCompletion(taskId) + }, + onError: (e) => { + // 捕获stream error + streamError = e + abortCompletion(taskId) + } } // Try streaming check first - const result = await ai.completions(params) - if (!result.getText()) { - throw new Error('No response received') + try { + await createAbortPromise(controller.signal, ai.completions(params)) + } catch (e: any) { + if (isAbortError(e)) { + if (streamError) { + throw streamError + } + } else { + throw e + } + } finally { + clearTimeout(timer) } } } catch (error: any) { + // FIXME: 这种判断方法无法严格保证错误是流式引起的 if (error.message.includes('stream')) { const params: CompletionsParams = { callType: 'check', @@ -878,12 +920,13 @@ export async function checkApi(provider: Provider, model: Model): Promise streamOutput: false, shouldThrow: true } - const result = await ai.completions(params) - if (!result.getText()) { - throw new Error('No response received') - } + // 超时判断 + const timeoutPromise = new Promise((_, reject) => setTimeout(() => reject('Timeout'), timeout)) + await Promise.race([ai.completions(params), timeoutPromise]) } else { throw error } + } finally { + removeAbortController(taskId, abortFn) } } diff --git a/src/renderer/src/utils/abortController.ts b/src/renderer/src/utils/abortController.ts index 21ac75c494..6593b1c162 100644 --- a/src/renderer/src/utils/abortController.ts +++ b/src/renderer/src/utils/abortController.ts @@ -30,8 +30,8 @@ export const abortCompletion = (id: string) => { } } -export function createAbortPromise(signal: AbortSignal, finallyPromise: Promise) { - return new Promise((_resolve, reject) => { +export function createAbortPromise(signal: AbortSignal, finallyPromise: Promise) { + return new Promise((_resolve, reject) => { if (signal.aborted) { reject(new DOMException('Operation aborted', 'AbortError')) return