Merge branch 'feat/proxy-api-server' into refactor/api-gateway

suyao 2026-01-04 18:57:47 +08:00
commit 1220adad9b
4 changed files with 89 additions and 190 deletions

File 1 of 4: the OpenAI-compatible chat completions route (ApiServerChatRoutes)

@@ -3,7 +3,7 @@ import express from 'express'
 import { loggerService } from '../../services/LoggerService'
 import type { ExtendedChatCompletionCreateParams } from '../adapters'
-import { generateMessage, streamToResponse } from '../services/ProxyStreamService'
+import { processMessage } from '../services/ProxyStreamService'
 import { validateModelId } from '../utils'

 const logger = loggerService.withContext('ApiServerChatRoutes')
@@ -205,38 +205,15 @@ router.post('/completions', async (req: Request, res: Response) => {
     const provider = modelValidation.provider!
     const modelId = modelValidation.modelId!

-    const isStreaming = !!request.stream
-    if (isStreaming) {
-      try {
-        await streamToResponse({
-          response: res,
-          provider,
-          modelId,
-          params: request,
-          inputFormat: 'openai',
-          outputFormat: 'openai'
-        })
-      } catch (streamError) {
-        logger.error('Stream error', { error: streamError })
-        // If headers weren't sent yet, return JSON error
-        if (!res.headersSent) {
-          const { status, body } = mapChatCompletionError(streamError)
-          return res.status(status).json(body)
-        }
-        // Otherwise the error is already handled by streamToResponse
-      }
-      return
-    }
-
-    const response = await generateMessage({
+    return processMessage({
+      response: res,
       provider,
       modelId,
       params: request,
       inputFormat: 'openai',
       outputFormat: 'openai'
     })
-    return res.json(response)
   } catch (error: unknown) {
     const { status, body } = mapChatCompletionError(error)
     return res.status(status).json(body)
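With the two branches collapsed into a single processMessage call, the route picks its mode purely from the stream flag in the request body, and error mapping stays in the route's catch. A quick way to exercise both modes against the proxy; the endpoint path, port, and model id below are assumptions, adjust for your deployment:

  // Sketch: hit the unified completions route in non-streaming mode.
  const res = await fetch('http://localhost:3000/v1/chat/completions', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      model: 'my-provider:my-model', // placeholder; must pass validateModelId
      messages: [{ role: 'user', content: 'ping' }],
      stream: false // flip to true to receive SSE chunks instead of one JSON body
    })
  })
  console.log(await res.json())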

File 2 of 4: the Anthropic-compatible messages route

@@ -8,7 +8,7 @@ import express from 'express'
 import { approximateTokenSize } from 'tokenx'
 import { messagesService } from '../services/messages'
-import { generateMessage, streamToResponse } from '../services/ProxyStreamService'
+import { processMessage } from '../services/ProxyStreamService'
 import { getProviderById, isModelAnthropicCompatible, validateModelId } from '../utils'

 /**
@@ -321,29 +321,19 @@ async function handleUnifiedProcessing({
       providerId: provider.id
     })

-    if (request.stream) {
-      await streamToResponse({
-        response: res,
-        provider,
-        modelId: actualModelId,
-        params: request,
-        middlewares,
-        onError: (error) => {
-          logger.error('Stream error', error as Error)
-        },
-        onComplete: () => {
-          logger.debug('Stream completed')
-        }
-      })
-    } else {
-      const response = await generateMessage({
-        provider,
-        modelId: actualModelId,
-        params: request,
-        middlewares
-      })
-      res.json(response)
-    }
+    await processMessage({
+      response: res,
+      provider,
+      modelId: actualModelId,
+      params: request,
+      middlewares,
+      onError: (error) => {
+        logger.error('Message error', error as Error)
+      },
+      onComplete: () => {
+        logger.debug('Message completed')
+      }
+    })
   } catch (error: any) {
     const { statusCode, errorResponse } = messagesService.transformError(error)
     res.status(statusCode).json(errorResponse)
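The middlewares forwarded here end up wrapping the language model inside executeStream (see the ProxyStreamService diff below). A minimal sketch of one, assuming the AI SDK middleware shape behind the service's LanguageModelV2Middleware alias; the import path and the timing concern are illustrative:

  import type { LanguageModelV2Middleware } from '@ai-sdk/provider' // assumed export location

  // Pass-through middleware that only measures how long the upstream call
  // takes to open its stream; params and chunks are left untouched.
  const timingMiddleware: LanguageModelV2Middleware = {
    wrapStream: async ({ doStream }) => {
      const started = Date.now()
      const result = await doStream()
      console.log(`upstream stream opened after ${Date.now() - started}ms`)
      return result
    }
  }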

File 3 of 4: ProxyStreamService

@@ -33,6 +33,8 @@ import type { Response } from 'express'
 import type { InputFormat, InputParamsMap, IStreamAdapter } from '../adapters'
 import { MessageConverterFactory, type OutputFormat, StreamAdapterFactory } from '../adapters'
+import { LONG_POLL_TIMEOUT_MS } from '../config/timeouts'
+import { createStreamAbortController } from '../utils/createStreamAbortController'
 import { googleReasoningCache, openRouterReasoningCache } from './reasoning-cache'

 const logger = loggerService.withContext('ProxyStreamService')
@@ -42,10 +44,6 @@
   error: (message, error) => logger.error(message, error)
 })

-// ============================================================================
-// Configuration Interfaces
-// ============================================================================
-
 /**
  * Middleware type alias
  */
@@ -57,9 +55,9 @@ type LanguageModelMiddleware = LanguageModelV2Middleware
 type InputParams = InputParamsMap[InputFormat]

 /**
- * Configuration for streaming message requests
+ * Configuration for message requests (both streaming and non-streaming)
  */
-export interface StreamConfig {
+export interface MessageConfig {
   response: Response
   provider: Provider
   modelId: string
@@ -72,19 +70,6 @@ export interface StreamConfig {
   plugins?: AiPlugin[]
 }

-/**
- * Configuration for non-streaming message generation
- */
-export interface GenerateConfig {
-  provider: Provider
-  modelId: string
-  params: InputParams
-  inputFormat?: InputFormat
-  outputFormat?: OutputFormat
-  middlewares?: LanguageModelMiddleware[]
-  plugins?: AiPlugin[]
-}
-
 /**
  * Internal configuration for stream execution
  */
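Pieced together from this hunk, the MessageConfig head above, and the two route call sites, the merged interface plausibly reads as below; the diff elides its middle, so the callback signatures and field order are assumptions:

  // Reconstruction, not verbatim from the diff.
  export interface MessageConfig {
    response: Response
    provider: Provider
    modelId: string
    params: InputParams
    inputFormat?: InputFormat
    outputFormat?: OutputFormat
    middlewares?: LanguageModelMiddleware[]
    plugins?: AiPlugin[]
    onError?: (error: unknown) => void // signature assumed from call sites
    onComplete?: () => void // signature assumed from call sites
  }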
@@ -96,6 +81,7 @@ interface ExecuteStreamConfig {
   outputFormat: OutputFormat
   middlewares?: LanguageModelMiddleware[]
   plugins?: AiPlugin[]
+  abortSignal?: AbortSignal
 }

 // ============================================================================
@@ -248,7 +234,7 @@ async function executeStream(config: ExecuteStreamConfig): Promise<{
   adapter: IStreamAdapter
   outputStream: ReadableStream<unknown>
 }> {
-  const { provider, modelId, params, inputFormat, outputFormat, middlewares = [], plugins = [] } = config
+  const { provider, modelId, params, inputFormat, outputFormat, middlewares = [], plugins = [], abortSignal } = config

   // Convert provider config to AI SDK config
   let sdkConfig = providerToAiSdkConfig(provider, modelId)
@@ -291,7 +277,8 @@
     stopWhen: stepCountIs(100),
     headers: defaultAppHeaders(),
     tools,
-    providerOptions
+    providerOptions,
+    abortSignal
   })

   // Transform stream using adapter
@@ -300,27 +287,14 @@
   return { adapter, outputStream }
 }

-// ============================================================================
-// Public API
-// ============================================================================
-
 /**
- * Stream a message request and write to HTTP response
+ * Process a message request - handles both streaming and non-streaming
  *
- * Uses TransformStream-based adapters for efficient streaming.
- *
- * @example
- * ```typescript
- * await streamToResponse({
- *   response: res,
- *   provider,
- *   modelId: 'claude-3-opus',
- *   params: messageCreateParams,
- *   outputFormat: 'anthropic'
- * })
- * ```
+ * Automatically detects streaming mode from params.stream:
+ * - stream=true: SSE streaming response
+ * - stream=false: JSON response
  */
-export async function streamToResponse(config: StreamConfig): Promise<void> {
+export async function processMessage(config: MessageConfig): Promise<void> {
   const {
     response,
     provider,
@@ -334,7 +308,9 @@ export async function streamToResponse(config: StreamConfig): Promise<void> {
     plugins = []
   } = config

-  logger.info('Starting proxy stream', {
+  const isStreaming = 'stream' in params && params.stream === true
+
+  logger.info(`Starting ${isStreaming ? 'streaming' : 'non-streaming'} message`, {
     providerId: provider.id,
     providerType: provider.type,
     modelId,
@@ -344,90 +320,21 @@ export async function streamToResponse(config: StreamConfig): Promise<void> {
     pluginCount: plugins.length
   })

-  try {
-    // Set SSE headers
-    response.setHeader('Content-Type', 'text/event-stream')
-    response.setHeader('Cache-Control', 'no-cache')
-    response.setHeader('Connection', 'keep-alive')
-    response.setHeader('X-Accel-Buffering', 'no')
-
-    const { outputStream } = await executeStream({
-      provider,
-      modelId,
-      params,
-      inputFormat,
-      outputFormat,
-      middlewares,
-      plugins
-    })
-
-    // Get formatter for the output format
-    const formatter = StreamAdapterFactory.getFormatter(outputFormat)
-
-    // Stream events to response
-    const reader = outputStream.getReader()
-    try {
-      while (true) {
-        const { done, value } = await reader.read()
-        if (done) break
-        response.write(formatter.formatEvent(value))
-      }
-    } finally {
-      reader.releaseLock()
-    }
-
-    // Send done marker and end response
-    response.write(formatter.formatDone())
-    response.end()
-
-    logger.info('Proxy stream completed', { providerId: provider.id, modelId })
-    onComplete?.()
-  } catch (error) {
-    logger.error('Error in proxy stream', error as Error, { providerId: provider.id, modelId })
-    onError?.(error)
-    throw error
-  }
-}
-
-/**
- * Generate a non-streaming message response
- *
- * Uses simulateStreamingMiddleware to reuse the same streaming logic.
- *
- * @example
- * ```typescript
- * const message = await generateMessage({
- *   provider,
- *   modelId: 'claude-3-opus',
- *   params: messageCreateParams,
- *   outputFormat: 'anthropic'
- * })
- * ```
- */
-export async function generateMessage(config: GenerateConfig): Promise<unknown> {
-  const {
-    provider,
-    modelId,
-    params,
-    inputFormat = 'anthropic',
-    outputFormat = 'anthropic',
-    middlewares = [],
-    plugins = []
-  } = config
-
-  logger.info('Starting message generation', {
-    providerId: provider.id,
-    providerType: provider.type,
-    modelId,
-    inputFormat,
-    outputFormat,
-    middlewareCount: middlewares.length,
-    pluginCount: plugins.length
-  })
-
-  try {
-    // Add simulateStreamingMiddleware to reuse streaming logic
-    const allMiddlewares = [simulateStreamingMiddleware(), ...middlewares]
+  // Create abort controller with timeout
+  const streamController = createStreamAbortController({ timeoutMs: LONG_POLL_TIMEOUT_MS })
+  const { abortController, dispose } = streamController
+
+  const handleDisconnect = () => {
+    if (abortController.signal.aborted) return
+    logger.info('Client disconnected, aborting', { providerId: provider.id, modelId })
+    abortController.abort('Client disconnected')
+  }
+  response.on('close', handleDisconnect)
+
+  try {
+    // For non-streaming, add simulateStreamingMiddleware
+    const allMiddlewares = isStreaming ? middlewares : [simulateStreamingMiddleware(), ...middlewares]

     const { adapter, outputStream } = await executeStream({
       provider,
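createStreamAbortController is imported from ../utils/createStreamAbortController, but its body is not part of this diff. From the call site it must return an abortController plus a dispose cleanup, and abort on its own once timeoutMs elapses. A minimal sketch with that contract; the internals are an assumption:

  // Sketch only: ties an AbortController to a long-poll timeout. dispose()
  // must run (the finally block in processMessage does) so the timer cannot
  // fire after the request has already finished.
  export function createStreamAbortController({ timeoutMs }: { timeoutMs: number }) {
    const abortController = new AbortController()
    const timer = setTimeout(() => {
      if (!abortController.signal.aborted) abortController.abort('Request timed out')
    }, timeoutMs)
    return { abortController, dispose: () => clearTimeout(timer) }
  }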
@@ -436,30 +343,60 @@ export async function generateMessage(config: GenerateConfig): Promise<unknown>
       inputFormat,
       outputFormat,
       middlewares: allMiddlewares,
-      plugins
+      plugins,
+      abortSignal: abortController.signal
     })

-    // Consume the stream to populate adapter state
-    const reader = outputStream.getReader()
-    while (true) {
-      const { done } = await reader.read()
-      if (done) break
+    if (isStreaming) {
+      // Streaming: Set SSE headers and stream events
+      response.setHeader('Content-Type', 'text/event-stream')
+      response.setHeader('Cache-Control', 'no-cache')
+      response.setHeader('Connection', 'keep-alive')
+      response.setHeader('X-Accel-Buffering', 'no')
+
+      const formatter = StreamAdapterFactory.getFormatter(outputFormat)
+      const reader = outputStream.getReader()
+      try {
+        while (true) {
+          const { done, value } = await reader.read()
+          if (done) break
+          if (response.writableEnded) break
+          response.write(formatter.formatEvent(value))
+        }
+      } finally {
+        reader.releaseLock()
+      }
+
+      if (!response.writableEnded) {
+        response.write(formatter.formatDone())
+        response.end()
+      }
+    } else {
+      // Non-streaming: Consume stream and return JSON
+      const reader = outputStream.getReader()
+      while (true) {
+        const { done } = await reader.read()
+        if (done) break
+      }
+      reader.releaseLock()
+
+      const finalResponse = adapter.buildNonStreamingResponse()
+      response.json(finalResponse)
     }
-    reader.releaseLock()

-    // Build final response from adapter
-    const finalResponse = adapter.buildNonStreamingResponse()
-    logger.info('Message generation completed', { providerId: provider.id, modelId })
-    return finalResponse
+    logger.info('Message completed', { providerId: provider.id, modelId, streaming: isStreaming })
+    onComplete?.()
   } catch (error) {
-    logger.error('Error in message generation', error as Error, { providerId: provider.id, modelId })
+    logger.error('Error in message processing', error as Error, { providerId: provider.id, modelId })
+    onError?.(error)
     throw error
+  } finally {
+    response.off('close', handleDisconnect)
+    dispose()
   }
 }

 export default {
-  streamToResponse,
-  generateMessage
+  processMessage
 }
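The piece that lets one pipeline serve both modes is simulateStreamingMiddleware: for stream=false the upstream call is made as a single generate, surfaced locally as a stream, drained to populate the adapter, and then flattened by buildNonStreamingResponse(). A sketch of driving the unified entry point directly; provider, params, and res are placeholders:

  await processMessage({
    response: res, // express Response
    provider, // resolved Provider record
    modelId: 'some-model-id', // placeholder
    params, // params.stream === true selects SSE, otherwise JSON
    inputFormat: 'anthropic',
    outputFormat: 'anthropic'
  })

Because the non-streaming path also runs through executeStream, the middlewares, plugins, and new abortSignal plumbing behave identically in both modes.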

File 4 of 4: ModelsService

@@ -15,11 +15,6 @@ export class ModelsService {
     const providers = await getAvailableProviders()

-    // Note: When providerType === 'anthropic', we now return ALL available models
-    // because the API Server's unified adapter (AiSdkToAnthropicSSE) can convert
-    // any provider's response to Anthropic SSE format. This enables Claude Code Agent
-    // to work with OpenAI, Gemini, and other providers transparently.
     const models = await listAllAvailableModels(providers)

     // Use Map to deduplicate models by their full ID (provider:model_id)
     const uniqueModels = new Map<string, ApiModel>()
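The dedup hinges on keying the Map with the fully qualified id. A sketch of the pattern, assuming ApiModel exposes the combined provider:model_id as its id field:

  const uniqueModels = new Map<string, ApiModel>()
  for (const model of models) {
    // Keyed on 'provider:model_id': the same upstream model offered by two
    // providers stays as two entries; exact duplicates collapse to one.
    uniqueModels.set(model.id, model) // the id field name is an assumption
  }
  const deduped = [...uniqueModels.values()]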