diff --git a/src/main/apiServer/routes/messages.ts b/src/main/apiServer/routes/messages.ts index bcba993915..0b26db3960 100644 --- a/src/main/apiServer/routes/messages.ts +++ b/src/main/apiServer/routes/messages.ts @@ -46,20 +46,11 @@ const providerRouter = express.Router({ mergeParams: true }) /** * Estimate token count from messages - * Uses tokenx library for accurate token estimation and supports images + * Uses tokenx library for accurate token estimation and supports images, tools */ interface CountTokensInput { - messages: Array<{ - role: string - content: - | string - | Array<{ - type: string - text?: string - source?: { type: string; media_type?: string; data?: string } - }> - }> - system?: string | Array<{ type: string; text?: string }> + messages: MessageCreateParams['messages'] + system?: MessageCreateParams['system'] } function estimateTokenCount(input: CountTokensInput): number { @@ -89,14 +80,40 @@ function estimateTokenCount(input: CountTokensInput): number { totalTokens += approximateTokenSize(block.text) } else if (block.type === 'image') { // Image token estimation (consistent with TokenService) - // Base64 images: estimate from data length - if (block.source?.data) { + if (block.source.type === 'base64') { + // Base64 images: estimate from data length const dataSize = block.source.data.length * 0.75 // base64 to bytes totalTokens += Math.floor(dataSize / 100) } else { - // Default image token estimate + // URL images: use default estimate totalTokens += 1000 } + } else if (block.type === 'tool_use') { + // Tool use token estimation: name + input JSON + if (block.name) { + totalTokens += approximateTokenSize(block.name) + } + if (block.input) { + const inputJson = JSON.stringify(block.input) + totalTokens += approximateTokenSize(inputJson) + } + // Add overhead for tool use structure + totalTokens += 10 + } else if (block.type === 'tool_result') { + // Tool result token estimation + if (typeof block.content === 'string') { + totalTokens += approximateTokenSize(block.content) + } else if (Array.isArray(block.content)) { + for (const item of block.content) { + if (typeof item === 'string') { + totalTokens += approximateTokenSize(item) + } else if (item.type === 'text' && item.text) { + totalTokens += approximateTokenSize(item.text) + } + } + } + // Add overhead for tool result structure + totalTokens += 10 } } } @@ -127,6 +144,70 @@ async function validateRequestBody(req: Request): Promise<{ valid: boolean; erro return { valid: true } } +/** + * Shared handler for count_tokens endpoint + * Validates request and returns token count estimation + */ +async function handleCountTokens( + req: Request, + res: Response, + options: { + requireModel?: boolean + logContext?: Record + } = {} +): Promise { + try { + const { model, messages, system } = req.body + const { requireModel = false, logContext = {} } = options + + // Validate model parameter if required + if (requireModel && !model) { + return res.status(400).json({ + type: 'error', + error: { + type: 'invalid_request_error', + message: 'model parameter is required' + } + }) + } + + // Validate messages parameter + if (!messages || !Array.isArray(messages)) { + return res.status(400).json({ + type: 'error', + error: { + type: 'invalid_request_error', + message: 'messages parameter is required' + } + }) + } + + // Estimate token count + const estimatedTokens = estimateTokenCount({ messages, system }) + + // Log with context + logger.debug('Token count estimated', { + model, + messageCount: messages.length, + estimatedTokens, + ...logContext + }) + + return res.json({ + input_tokens: estimatedTokens + }) + } catch (error: any) { + logger.error('Token counting error', { error }) + return res.status(500).json({ + type: 'error', + error: { + type: 'api_error', + message: error.message || 'Internal server error' + } + }) + } +} + interface HandleMessageProcessingOptions { res: Response provider: Provider @@ -650,91 +731,17 @@ providerRouter.post('/', async (req: Request, res: Response) => { * description: Bad request */ router.post('/count_tokens', async (req: Request, res: Response) => { - try { - const { model, messages, system } = req.body - - if (!model) { - return res.status(400).json({ - type: 'error', - error: { - type: 'invalid_request_error', - message: 'model parameter is required' - } - }) - } - - if (!messages || !Array.isArray(messages)) { - return res.status(400).json({ - type: 'error', - error: { - type: 'invalid_request_error', - message: 'messages parameter is required' - } - }) - } - - const estimatedTokens = estimateTokenCount({ messages, system }) - - logger.debug('Token count estimated', { - model, - messageCount: messages.length, - estimatedTokens - }) - - return res.json({ - input_tokens: estimatedTokens - }) - } catch (error: any) { - logger.error('Token counting error', { error }) - return res.status(500).json({ - type: 'error', - error: { - type: 'api_error', - message: error.message || 'Internal server error' - } - }) - } + return handleCountTokens(req, res, { requireModel: true }) }) /** * Provider-specific count_tokens endpoint */ providerRouter.post('/count_tokens', async (req: Request, res: Response) => { - try { - const { model, messages, system } = req.body - - if (!messages || !Array.isArray(messages)) { - return res.status(400).json({ - type: 'error', - error: { - type: 'invalid_request_error', - message: 'messages parameter is required' - } - }) - } - - const estimatedTokens = estimateTokenCount({ messages, system }) - - logger.debug('Token count estimated (provider route)', { - providerId: req.params.provider, - model, - messageCount: messages.length, - estimatedTokens - }) - - return res.json({ - input_tokens: estimatedTokens - }) - } catch (error: any) { - logger.error('Token counting error', { error }) - return res.status(500).json({ - type: 'error', - error: { - type: 'api_error', - message: error.message || 'Internal server error' - } - }) - } + return handleCountTokens(req, res, { + requireModel: false, + logContext: { providerId: req.params.provider } + }) }) export { providerRouter as messagesProviderRoutes, router as messagesRoutes }