mirror of https://github.com/CherryHQ/cherry-studio.git (synced 2026-01-14 06:07:23 +08:00)

Merge branch 'feat/proxy-api-server' into refactor/api-gateway

commit 1220adad9b
@@ -3,7 +3,7 @@ import express from 'express'

 import { loggerService } from '../../services/LoggerService'
 import type { ExtendedChatCompletionCreateParams } from '../adapters'
-import { generateMessage, streamToResponse } from '../services/ProxyStreamService'
+import { processMessage } from '../services/ProxyStreamService'
 import { validateModelId } from '../utils'

 const logger = loggerService.withContext('ApiServerChatRoutes')
@@ -205,38 +205,15 @@ router.post('/completions', async (req: Request, res: Response) => {

     const provider = modelValidation.provider!
     const modelId = modelValidation.modelId!
-    const isStreaming = !!request.stream
-
-    if (isStreaming) {
-      try {
-        await streamToResponse({
-          response: res,
-          provider,
-          modelId,
-          params: request,
-          inputFormat: 'openai',
-          outputFormat: 'openai'
-        })
-      } catch (streamError) {
-        logger.error('Stream error', { error: streamError })
-        // If headers weren't sent yet, return JSON error
-        if (!res.headersSent) {
-          const { status, body } = mapChatCompletionError(streamError)
-          return res.status(status).json(body)
-        }
-        // Otherwise the error is already handled by streamToResponse
-      }
-      return
-    }
-
-    const response = await generateMessage({
+    return processMessage({
+      response: res,
       provider,
       modelId,
       params: request,
       inputFormat: 'openai',
       outputFormat: 'openai'
     })
-    return res.json(response)
   } catch (error: unknown) {
     const { status, body } = mapChatCompletionError(error)
     return res.status(status).json(body)
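The net effect on the OpenAI-compatible completions route: the explicit streaming branch disappears, because `processMessage` inspects `params.stream` itself. A minimal sketch of the resulting handler, with `declare`d stand-ins for the project's modules — the shapes below are inferred from this diff, and `validateModelId`'s return type in particular is an assumption:

```typescript
import type { Request, Response } from 'express'

type Provider = { id: string; type: string }

// Stand-ins for the project's real modules, for illustration only.
declare function processMessage(config: {
  response: Response
  provider: Provider
  modelId: string
  params: unknown
  inputFormat?: 'openai' | 'anthropic'
  outputFormat?: 'openai' | 'anthropic'
}): Promise<void>
declare function mapChatCompletionError(error: unknown): { status: number; body: unknown }
declare function validateModelId(model: string): { provider?: Provider; modelId?: string }

export async function completionsHandler(req: Request, res: Response) {
  try {
    const modelValidation = validateModelId(req.body.model)
    const provider = modelValidation.provider!
    const modelId = modelValidation.modelId!

    // No stream/non-stream branch here: processMessage reads params.stream
    // and either streams SSE or writes a single JSON body to `res`.
    return await processMessage({
      response: res,
      provider,
      modelId,
      params: req.body,
      inputFormat: 'openai',
      outputFormat: 'openai'
    })
  } catch (error: unknown) {
    const { status, body } = mapChatCompletionError(error)
    if (!res.headersSent) return res.status(status).json(body)
  }
}
```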
@@ -8,7 +8,7 @@ import express from 'express'
 import { approximateTokenSize } from 'tokenx'

 import { messagesService } from '../services/messages'
-import { generateMessage, streamToResponse } from '../services/ProxyStreamService'
+import { processMessage } from '../services/ProxyStreamService'
 import { getProviderById, isModelAnthropicCompatible, validateModelId } from '../utils'

 /**
@@ -321,29 +321,19 @@ async function handleUnifiedProcessing({
       providerId: provider.id
     })

-    if (request.stream) {
-      await streamToResponse({
-        response: res,
-        provider,
-        modelId: actualModelId,
-        params: request,
-        middlewares,
-        onError: (error) => {
-          logger.error('Stream error', error as Error)
-        },
-        onComplete: () => {
-          logger.debug('Stream completed')
-        }
-      })
-    } else {
-      const response = await generateMessage({
-        provider,
-        modelId: actualModelId,
-        params: request,
-        middlewares
-      })
-      res.json(response)
-    }
+    await processMessage({
+      response: res,
+      provider,
+      modelId: actualModelId,
+      params: request,
+      middlewares,
+      onError: (error) => {
+        logger.error('Message error', error as Error)
+      },
+      onComplete: () => {
+        logger.debug('Message completed')
+      }
+    })
   } catch (error: any) {
     const { statusCode, errorResponse } = messagesService.transformError(error)
     res.status(statusCode).json(errorResponse)
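Unlike the chat route, this call site passes `onError`/`onComplete` callbacks. As the ProxyStreamService hunks below show, `processMessage` invokes `onError` and then rethrows, so the route still owns HTTP error mapping (the `transformError` call above). A small illustrative sketch of that contract, with a declared stand-in:

```typescript
// Sketch only — processMessage here is a stand-in with just the callback fields.
declare function processMessage(config: {
  onError?: (error: unknown) => void
  onComplete?: () => void
  [key: string]: unknown
}): Promise<void>

async function handleUnified(res: { status(code: number): { json(body: unknown): void } }) {
  try {
    await processMessage({
      onError: (error) => console.error('Message error', error), // observe only
      onComplete: () => console.debug('Message completed')
    })
  } catch (error) {
    // processMessage rethrows after onError, so the HTTP mapping stays here,
    // mirroring messagesService.transformError + res.status(...).json(...).
    res.status(500).json({ type: 'error', error: { message: String(error) } })
  }
}
```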
@@ -33,6 +33,8 @@ import type { Response } from 'express'

 import type { InputFormat, InputParamsMap, IStreamAdapter } from '../adapters'
 import { MessageConverterFactory, type OutputFormat, StreamAdapterFactory } from '../adapters'
+import { LONG_POLL_TIMEOUT_MS } from '../config/timeouts'
+import { createStreamAbortController } from '../utils/createStreamAbortController'
 import { googleReasoningCache, openRouterReasoningCache } from './reasoning-cache'

 const logger = loggerService.withContext('ProxyStreamService')
@@ -42,10 +44,6 @@ initializeSharedProviders({
   error: (message, error) => logger.error(message, error)
 })

-// ============================================================================
-// Configuration Interfaces
-// ============================================================================
-
 /**
  * Middleware type alias
  */
@@ -57,9 +55,9 @@ type LanguageModelMiddleware = LanguageModelV2Middleware
 type InputParams = InputParamsMap[InputFormat]

 /**
- * Configuration for streaming message requests
+ * Configuration for message requests (both streaming and non-streaming)
  */
-export interface StreamConfig {
+export interface MessageConfig {
   response: Response
   provider: Provider
   modelId: string
@@ -72,19 +70,6 @@ export interface StreamConfig {
   plugins?: AiPlugin[]
 }

-/**
- * Configuration for non-streaming message generation
- */
-export interface GenerateConfig {
-  provider: Provider
-  modelId: string
-  params: InputParams
-  inputFormat?: InputFormat
-  outputFormat?: OutputFormat
-  middlewares?: LanguageModelMiddleware[]
-  plugins?: AiPlugin[]
-}
-
 /**
  * Internal configuration for stream execution
  */
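After this removal, a single interface covers both paths. A hedged reconstruction of the consolidated shape, assembled from the hunks above (the callback fields come from the route call sites; the stand-in types replace the project's real imports — treat this as a sketch, not the exact source):

```typescript
import type { Response } from 'express'

// Placeholder types standing in for the project's own imports.
type Provider = { id: string; type: string }
type InputFormat = 'openai' | 'anthropic'
type OutputFormat = 'openai' | 'anthropic'
type InputParams = Record<string, unknown>
type LanguageModelMiddleware = unknown
type AiPlugin = unknown

/** One config for message requests, streaming or not. */
export interface MessageConfig {
  response: Response
  provider: Provider
  modelId: string
  params: InputParams
  inputFormat?: InputFormat
  outputFormat?: OutputFormat
  middlewares?: LanguageModelMiddleware[]
  plugins?: AiPlugin[]
  onError?: (error: unknown) => void
  onComplete?: () => void
}
```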
@@ -96,6 +81,7 @@ interface ExecuteStreamConfig {
   outputFormat: OutputFormat
   middlewares?: LanguageModelMiddleware[]
   plugins?: AiPlugin[]
+  abortSignal?: AbortSignal
 }

 // ============================================================================
@@ -248,7 +234,7 @@ async function executeStream(config: ExecuteStreamConfig): Promise<{
   adapter: IStreamAdapter
   outputStream: ReadableStream<unknown>
 }> {
-  const { provider, modelId, params, inputFormat, outputFormat, middlewares = [], plugins = [] } = config
+  const { provider, modelId, params, inputFormat, outputFormat, middlewares = [], plugins = [], abortSignal } = config

   // Convert provider config to AI SDK config
   let sdkConfig = providerToAiSdkConfig(provider, modelId)
@@ -291,7 +277,8 @@ async function executeStream(config: ExecuteStreamConfig): Promise<{
     stopWhen: stepCountIs(100),
     headers: defaultAppHeaders(),
     tools,
-    providerOptions
+    providerOptions,
+    abortSignal
   })

   // Transform stream using adapter
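The `abortSignal` threaded through here lands on the AI SDK call, which cancels the upstream provider request when the signal fires. A minimal standalone sketch of that mechanism, assuming the Vercel AI SDK's `streamText` (which accepts an `abortSignal` option); the model name and timeout are illustrative:

```typescript
import { streamText } from 'ai'
import { openai } from '@ai-sdk/openai'

const controller = new AbortController()
// Abort if the stream runs too long (stand-in for LONG_POLL_TIMEOUT_MS).
const timeout = setTimeout(() => controller.abort('timeout'), 60_000)

const result = streamText({
  model: openai('gpt-4o-mini'),
  prompt: 'Hello',
  abortSignal: controller.signal // cancels the upstream request on abort
})

try {
  for await (const chunk of result.textStream) {
    process.stdout.write(chunk) // aborting mid-stream throws here
  }
} finally {
  clearTimeout(timeout)
}
```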
@@ -300,27 +287,14 @@ async function executeStream(config: ExecuteStreamConfig): Promise<{
   return { adapter, outputStream }
 }

-// ============================================================================
-// Public API
-// ============================================================================
-
 /**
- * Stream a message request and write to HTTP response
+ * Process a message request - handles both streaming and non-streaming
  *
- * Uses TransformStream-based adapters for efficient streaming.
- *
- * @example
- * ```typescript
- * await streamToResponse({
- *   response: res,
- *   provider,
- *   modelId: 'claude-3-opus',
- *   params: messageCreateParams,
- *   outputFormat: 'anthropic'
- * })
- * ```
+ * Automatically detects streaming mode from params.stream:
+ * - stream=true: SSE streaming response
+ * - stream=false: JSON response
  */
-export async function streamToResponse(config: StreamConfig): Promise<void> {
+export async function processMessage(config: MessageConfig): Promise<void> {
   const {
     response,
     provider,
@@ -334,7 +308,9 @@ export async function streamToResponse(config: StreamConfig): Promise<void> {
     plugins = []
   } = config

-  logger.info('Starting proxy stream', {
+  const isStreaming = 'stream' in params && params.stream === true
+
+  logger.info(`Starting ${isStreaming ? 'streaming' : 'non-streaming'} message`, {
     providerId: provider.id,
     providerType: provider.type,
     modelId,
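Why the `'stream' in params` guard: `params` is a union of the supported input formats, and the `in` check narrows it safely before reading the flag. A toy version with simplified stand-in types:

```typescript
// Both request shapes carry an optional `stream` flag, so one runtime check
// can pick the mode (simplified stand-ins for the real param types).
type OpenAIParams = { stream?: boolean | null; messages: unknown[] }
type AnthropicParams = { stream?: boolean; max_tokens: number; messages: unknown[] }
type InputParams = OpenAIParams | AnthropicParams

function isStreamingRequest(params: InputParams): boolean {
  // `'stream' in params` narrows the union even if a format omits the field.
  return 'stream' in params && params.stream === true
}
```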
@@ -344,90 +320,21 @@ export async function streamToResponse(config: StreamConfig): Promise<void> {
     pluginCount: plugins.length
   })

-  try {
-    // Set SSE headers
-    response.setHeader('Content-Type', 'text/event-stream')
-    response.setHeader('Cache-Control', 'no-cache')
-    response.setHeader('Connection', 'keep-alive')
-    response.setHeader('X-Accel-Buffering', 'no')
+  // Create abort controller with timeout
+  const streamController = createStreamAbortController({ timeoutMs: LONG_POLL_TIMEOUT_MS })
+  const { abortController, dispose } = streamController

-    const { outputStream } = await executeStream({
-      provider,
-      modelId,
-      params,
-      inputFormat,
-      outputFormat,
-      middlewares,
-      plugins
-    })
-
-    // Get formatter for the output format
-    const formatter = StreamAdapterFactory.getFormatter(outputFormat)
-
-    // Stream events to response
-    const reader = outputStream.getReader()
-    try {
-      while (true) {
-        const { done, value } = await reader.read()
-        if (done) break
-        response.write(formatter.formatEvent(value))
-      }
-    } finally {
-      reader.releaseLock()
-    }
-
-    // Send done marker and end response
-    response.write(formatter.formatDone())
-    response.end()
-
-    logger.info('Proxy stream completed', { providerId: provider.id, modelId })
-    onComplete?.()
-  } catch (error) {
-    logger.error('Error in proxy stream', error as Error, { providerId: provider.id, modelId })
-    onError?.(error)
-    throw error
+  const handleDisconnect = () => {
+    if (abortController.signal.aborted) return
+    logger.info('Client disconnected, aborting', { providerId: provider.id, modelId })
+    abortController.abort('Client disconnected')
   }
-}
-
-/**
- * Generate a non-streaming message response
- *
- * Uses simulateStreamingMiddleware to reuse the same streaming logic.
- *
- * @example
- * ```typescript
- * const message = await generateMessage({
- *   provider,
- *   modelId: 'claude-3-opus',
- *   params: messageCreateParams,
- *   outputFormat: 'anthropic'
- * })
- * ```
- */
-export async function generateMessage(config: GenerateConfig): Promise<unknown> {
-  const {
-    provider,
-    modelId,
-    params,
-    inputFormat = 'anthropic',
-    outputFormat = 'anthropic',
-    middlewares = [],
-    plugins = []
-  } = config
-
-  logger.info('Starting message generation', {
-    providerId: provider.id,
-    providerType: provider.type,
-    modelId,
-    inputFormat,
-    outputFormat,
-    middlewareCount: middlewares.length,
-    pluginCount: plugins.length
-  })
+  response.on('close', handleDisconnect)

   try {
-    // Add simulateStreamingMiddleware to reuse streaming logic
-    const allMiddlewares = [simulateStreamingMiddleware(), ...middlewares]
+    // For non-streaming, add simulateStreamingMiddleware
+    const allMiddlewares = isStreaming ? middlewares : [simulateStreamingMiddleware(), ...middlewares]

     const { adapter, outputStream } = await executeStream({
       provider,
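The disconnect wiring above follows a common Express/Node pattern: listen for the response's `'close'` event and abort upstream work if it fires before completion. A compact standalone sketch; detaching the listener afterwards matters because `'close'` also fires on normal completion:

```typescript
import type { Response } from 'express'

// Returns a cleanup function; call it in a finally block, as the diff does
// with response.off('close', handleDisconnect).
function watchClientDisconnect(response: Response, abortController: AbortController): () => void {
  const handleDisconnect = () => {
    if (abortController.signal.aborted) return // timeout may have fired first
    abortController.abort('Client disconnected')
  }
  response.on('close', handleDisconnect)
  return () => response.off('close', handleDisconnect)
}
```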
@@ -436,30 +343,60 @@ export async function generateMessage(config: GenerateConfig): Promise<unknown>
       inputFormat,
       outputFormat,
       middlewares: allMiddlewares,
-      plugins
+      plugins,
+      abortSignal: abortController.signal
     })

-    // Consume the stream to populate adapter state
-    const reader = outputStream.getReader()
-    while (true) {
-      const { done } = await reader.read()
-      if (done) break
+    if (isStreaming) {
+      // Streaming: Set SSE headers and stream events
+      response.setHeader('Content-Type', 'text/event-stream')
+      response.setHeader('Cache-Control', 'no-cache')
+      response.setHeader('Connection', 'keep-alive')
+      response.setHeader('X-Accel-Buffering', 'no')
+
+      const formatter = StreamAdapterFactory.getFormatter(outputFormat)
+      const reader = outputStream.getReader()
+
+      try {
+        while (true) {
+          const { done, value } = await reader.read()
+          if (done) break
+          if (response.writableEnded) break
+          response.write(formatter.formatEvent(value))
+        }
+      } finally {
+        reader.releaseLock()
+      }
+
+      if (!response.writableEnded) {
+        response.write(formatter.formatDone())
+        response.end()
+      }
+    } else {
+      // Non-streaming: Consume stream and return JSON
+      const reader = outputStream.getReader()
+      while (true) {
+        const { done } = await reader.read()
+        if (done) break
+      }
+      reader.releaseLock()
+
+      const finalResponse = adapter.buildNonStreamingResponse()
+      response.json(finalResponse)
     }
-    reader.releaseLock()

-    // Build final response from adapter
-    const finalResponse = adapter.buildNonStreamingResponse()
-
-    logger.info('Message generation completed', { providerId: provider.id, modelId })
-
-    return finalResponse
+    logger.info('Message completed', { providerId: provider.id, modelId, streaming: isStreaming })
+    onComplete?.()
   } catch (error) {
-    logger.error('Error in message generation', error as Error, { providerId: provider.id, modelId })
+    logger.error('Error in message processing', error as Error, { providerId: provider.id, modelId })
     onError?.(error)
     throw error
+  } finally {
+    response.off('close', handleDisconnect)
+    dispose()
   }
 }

 export default {
-  streamToResponse,
-  generateMessage
+  processMessage
 }
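Since the refactor drops both JSDoc `@example` blocks, a replacement usage example may help. This is a fragment in the style of the removed examples, not verbatim source: `res`, `provider`, `messageCreateParams`, and `logger` are assumed to be in scope, and `'claude-3-opus'` is illustrative.

```typescript
// stream: true in params → SSE written to res; otherwise a single JSON body.
await processMessage({
  response: res,                 // Express Response to write to
  provider,                      // resolved Provider record
  modelId: 'claude-3-opus',
  params: messageCreateParams,
  inputFormat: 'anthropic',
  outputFormat: 'anthropic',
  onError: (error) => logger.error('Message error', error as Error),
  onComplete: () => logger.debug('Message completed')
})
```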
@@ -15,11 +15,6 @@ export class ModelsService {

     const providers = await getAvailableProviders()

-    // Note: When providerType === 'anthropic', we now return ALL available models
-    // because the API Server's unified adapter (AiSdkToAnthropicSSE) can convert
-    // any provider's response to Anthropic SSE format. This enables Claude Code Agent
-    // to work with OpenAI, Gemini, and other providers transparently.
-
     const models = await listAllAvailableModels(providers)
     // Use Map to deduplicate models by their full ID (provider:model_id)
     const uniqueModels = new Map<string, ApiModel>()
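A standalone sketch of the Map-based de-duplication the comment above describes: key each model by its full `provider:model_id` string so repeats collapse to a single entry. Types are simplified stand-ins for the project's `ApiModel`:

```typescript
type ApiModel = { id: string; owned_by?: string }

function dedupeModels(entries: Array<{ providerId: string; model: ApiModel }>): ApiModel[] {
  const uniqueModels = new Map<string, ApiModel>()
  for (const { providerId, model } of entries) {
    const fullId = `${providerId}:${model.id}` // e.g. "openai:gpt-4o"
    if (!uniqueModels.has(fullId)) uniqueModels.set(fullId, model)
  }
  return [...uniqueModels.values()]
}
```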