feat: implement BlockManager and associated callbacks for message str… (#8167)

* feat: implement BlockManager and associated callbacks for message streaming

- Introduced BlockManager to manage message blocks with smart update strategies.
- Added various callback handlers for different message types including text, image, citation, and tool responses.
- Enhanced state management for active blocks and transitions between different message types.
- Created utility functions for handling block updates and transitions, improving overall message processing flow.
- Refactored message thunk to utilize BlockManager for better organization and maintainability.

This implementation lays the groundwork for more efficient message streaming and processing in the application.

* refactor: clean up BlockManager and callback implementations

- Removed redundant assignments of lastBlockType in various callback files.
- Updated error handling logic to ensure correct message status updates.
- Added console logs for debugging purposes in BlockManager and citation callbacks.
- Enhanced smartBlockUpdate method call in citation callbacks for better state management.

* refactor: streamline BlockManager and callback logic

- Removed unnecessary accumulated content variables in text and thinking callbacks.
- Updated content handling in callbacks to directly use incoming text instead of accumulating.
- Enhanced smartBlockUpdate calls for better state management in message streaming.
- Cleaned up console log statements for improved readability and debugging.
This commit is contained in:
MyPrototypeWhat 2025-07-17 10:03:14 +08:00 committed by GitHub
parent aa254a3772
commit 7e471bfea4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 1468 additions and 584 deletions

View File

@ -0,0 +1,136 @@
import type { AppDispatch, RootState } from '@renderer/store'
import { updateOneBlock, upsertOneBlock } from '@renderer/store/messageBlock'
import { newMessagesActions } from '@renderer/store/newMessage'
import { MessageBlock, MessageBlockType } from '@renderer/types/newMessage'
interface ActiveBlockInfo {
id: string
type: MessageBlockType
}
interface BlockManagerDependencies {
dispatch: AppDispatch
getState: () => RootState
saveUpdatedBlockToDB: (
blockId: string | null,
messageId: string,
topicId: string,
getState: () => RootState
) => Promise<void>
saveUpdatesToDB: (
messageId: string,
topicId: string,
messageUpdates: Partial<any>,
blocksToUpdate: MessageBlock[]
) => Promise<void>
assistantMsgId: string
topicId: string
// 节流器管理从外部传入
throttledBlockUpdate: (id: string, blockUpdate: any) => void
cancelThrottledBlockUpdate: (id: string) => void
}
export class BlockManager {
private deps: BlockManagerDependencies
// 简化后的状态管理
private _activeBlockInfo: ActiveBlockInfo | null = null
private _lastBlockType: MessageBlockType | null = null // 保留用于错误处理
constructor(dependencies: BlockManagerDependencies) {
this.deps = dependencies
}
// Getters
get activeBlockInfo() {
return this._activeBlockInfo
}
get lastBlockType() {
return this._lastBlockType
}
get hasInitialPlaceholder() {
return this._activeBlockInfo?.type === MessageBlockType.UNKNOWN
}
get initialPlaceholderBlockId() {
return this.hasInitialPlaceholder ? this._activeBlockInfo?.id || null : null
}
// Setters
set lastBlockType(value: MessageBlockType | null) {
this._lastBlockType = value
}
set activeBlockInfo(value: ActiveBlockInfo | null) {
this._activeBlockInfo = value
}
/**
* 使
*/
smartBlockUpdate(
blockId: string,
changes: Partial<MessageBlock>,
blockType: MessageBlockType,
isComplete: boolean = false
) {
const isBlockTypeChanged = this._lastBlockType !== null && this._lastBlockType !== blockType
if (isBlockTypeChanged || isComplete) {
// 如果块类型改变,则取消上一个块的节流更新
if (isBlockTypeChanged && this._activeBlockInfo) {
this.deps.cancelThrottledBlockUpdate(this._activeBlockInfo.id)
}
// 如果当前块完成,则取消当前块的节流更新
if (isComplete) {
this.deps.cancelThrottledBlockUpdate(blockId)
this._activeBlockInfo = null // 块完成时清空activeBlockInfo
} else {
this._activeBlockInfo = { id: blockId, type: blockType } // 更新活跃块信息
}
this.deps.dispatch(updateOneBlock({ id: blockId, changes }))
this.deps.saveUpdatedBlockToDB(blockId, this.deps.assistantMsgId, this.deps.topicId, this.deps.getState)
this._lastBlockType = blockType
} else {
this._activeBlockInfo = { id: blockId, type: blockType } // 更新活跃块信息
this.deps.throttledBlockUpdate(blockId, changes)
}
}
/**
*
*/
async handleBlockTransition(newBlock: MessageBlock, newBlockType: MessageBlockType) {
this._lastBlockType = newBlockType
this._activeBlockInfo = { id: newBlock.id, type: newBlockType } // 设置新的活跃块信息
this.deps.dispatch(
newMessagesActions.updateMessage({
topicId: this.deps.topicId,
messageId: this.deps.assistantMsgId,
updates: { blockInstruction: { id: newBlock.id } }
})
)
this.deps.dispatch(upsertOneBlock(newBlock))
this.deps.dispatch(
newMessagesActions.upsertBlockReference({
messageId: this.deps.assistantMsgId,
blockId: newBlock.id,
status: newBlock.status
})
)
const currentState = this.deps.getState()
const updatedMessage = currentState.messages.entities[this.deps.assistantMsgId]
if (updatedMessage) {
await this.deps.saveUpdatesToDB(this.deps.assistantMsgId, this.deps.topicId, { blocks: updatedMessage.blocks }, [
newBlock
])
} else {
console.error(
`[handleBlockTransition] Failed to get updated message ${this.deps.assistantMsgId} from state for DB save.`
)
}
}
}

View File

@ -0,0 +1,214 @@
import { autoRenameTopic } from '@renderer/hooks/useTopic'
import i18n from '@renderer/i18n'
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
import { NotificationService } from '@renderer/services/NotificationService'
import { estimateMessagesUsage } from '@renderer/services/TokenService'
import { selectMessagesForTopic } from '@renderer/store/newMessage'
import { newMessagesActions } from '@renderer/store/newMessage'
import type { Assistant } from '@renderer/types'
import type { Response } from '@renderer/types/newMessage'
import {
AssistantMessageStatus,
MessageBlockStatus,
MessageBlockType,
PlaceholderMessageBlock
} from '@renderer/types/newMessage'
import { uuid } from '@renderer/utils'
import { formatErrorMessage, isAbortError } from '@renderer/utils/error'
import { createBaseMessageBlock, createErrorBlock } from '@renderer/utils/messageUtils/create'
import { findAllBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { isFocused, isOnHomePage } from '@renderer/utils/window'
import { BlockManager } from '../BlockManager'
interface BaseCallbacksDependencies {
blockManager: BlockManager
dispatch: any
getState: any
topicId: string
assistantMsgId: string
saveUpdatesToDB: any
assistant: Assistant
}
export const createBaseCallbacks = (deps: BaseCallbacksDependencies) => {
const { blockManager, dispatch, getState, topicId, assistantMsgId, saveUpdatesToDB, assistant } = deps
const startTime = Date.now()
const notificationService = NotificationService.getInstance()
// 通用的 block 查找函数
const findBlockIdForCompletion = (message?: any) => {
// 优先使用 BlockManager 中的 activeBlockInfo
const activeBlockInfo = blockManager.activeBlockInfo
if (activeBlockInfo) {
return activeBlockInfo.id
}
// 如果没有活跃的block从message中查找最新的block作为备选
const targetMessage = message || getState().messages.entities[assistantMsgId]
if (targetMessage) {
const allBlocks = findAllBlocks(targetMessage)
if (allBlocks.length > 0) {
return allBlocks[allBlocks.length - 1].id // 返回最新的block
}
}
// 最后的备选方案:从 blockManager 获取占位符块ID
return blockManager.initialPlaceholderBlockId
}
return {
onLLMResponseCreated: async () => {
const baseBlock = createBaseMessageBlock(assistantMsgId, MessageBlockType.UNKNOWN, {
status: MessageBlockStatus.PROCESSING
})
await blockManager.handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
},
onError: async (error: any) => {
console.dir(error, { depth: null })
const isErrorTypeAbort = isAbortError(error)
let pauseErrorLanguagePlaceholder = ''
if (isErrorTypeAbort) {
pauseErrorLanguagePlaceholder = 'pause_placeholder'
}
const serializableError = {
name: error.name,
message: pauseErrorLanguagePlaceholder || error.message || formatErrorMessage(error),
originalMessage: error.message,
stack: error.stack,
status: error.status || error.code,
requestId: error.request_id
}
const duration = Date.now() - startTime
// 发送错误通知(除了中止错误)
if (!isErrorTypeAbort) {
const timeOut = duration > 30 * 1000
if ((!isOnHomePage() && timeOut) || (!isFocused() && timeOut)) {
await notificationService.send({
id: uuid(),
type: 'error',
title: i18n.t('notification.assistant'),
message: serializableError.message,
silent: false,
timestamp: Date.now(),
source: 'assistant'
})
}
}
const possibleBlockId = findBlockIdForCompletion()
if (possibleBlockId) {
// 更改上一个block的状态为ERROR
const changes = {
status: isErrorTypeAbort ? MessageBlockStatus.PAUSED : MessageBlockStatus.ERROR
}
blockManager.smartBlockUpdate(possibleBlockId, changes, blockManager.lastBlockType!, true)
}
const errorBlock = createErrorBlock(assistantMsgId, serializableError, { status: MessageBlockStatus.SUCCESS })
await blockManager.handleBlockTransition(errorBlock, MessageBlockType.ERROR)
const messageErrorUpdate = {
status: isErrorTypeAbort ? AssistantMessageStatus.SUCCESS : AssistantMessageStatus.ERROR
}
dispatch(
newMessagesActions.updateMessage({
topicId,
messageId: assistantMsgId,
updates: messageErrorUpdate
})
)
await saveUpdatesToDB(assistantMsgId, topicId, messageErrorUpdate, [])
EventEmitter.emit(EVENT_NAMES.MESSAGE_COMPLETE, {
id: assistantMsgId,
topicId,
status: isErrorTypeAbort ? 'pause' : 'error',
error: error.message
})
},
onComplete: async (status: AssistantMessageStatus, response?: Response) => {
const finalStateOnComplete = getState()
const finalAssistantMsg = finalStateOnComplete.messages.entities[assistantMsgId]
if (status === 'success' && finalAssistantMsg) {
const userMsgId = finalAssistantMsg.askId
const orderedMsgs = selectMessagesForTopic(finalStateOnComplete, topicId)
const userMsgIndex = orderedMsgs.findIndex((m) => m.id === userMsgId)
const contextForUsage = userMsgIndex !== -1 ? orderedMsgs.slice(0, userMsgIndex + 1) : []
const finalContextWithAssistant = [...contextForUsage, finalAssistantMsg]
const possibleBlockId = findBlockIdForCompletion(finalAssistantMsg)
if (possibleBlockId) {
const changes = {
status: MessageBlockStatus.SUCCESS
}
blockManager.smartBlockUpdate(possibleBlockId, changes, blockManager.lastBlockType!, true)
}
const duration = Date.now() - startTime
const content = getMainTextContent(finalAssistantMsg)
const timeOut = duration > 30 * 1000
// 发送长时间运行消息的成功通知
if ((!isOnHomePage() && timeOut) || (!isFocused() && timeOut)) {
await notificationService.send({
id: uuid(),
type: 'success',
title: i18n.t('notification.assistant'),
message: content.length > 50 ? content.slice(0, 47) + '...' : content,
silent: false,
timestamp: Date.now(),
source: 'assistant',
channel: 'system'
})
}
// 更新topic的name
autoRenameTopic(assistant, topicId)
// 处理usage估算
if (
response &&
(response.usage?.total_tokens === 0 ||
response?.usage?.prompt_tokens === 0 ||
response?.usage?.completion_tokens === 0)
) {
const usage = await estimateMessagesUsage({ assistant, messages: finalContextWithAssistant })
response.usage = usage
}
}
if (response && response.metrics) {
if (response.metrics.completion_tokens === 0 && response.usage?.completion_tokens) {
response = {
...response,
metrics: {
...response.metrics,
completion_tokens: response.usage.completion_tokens
}
}
}
}
const messageUpdates = { status, metrics: response?.metrics, usage: response?.usage }
dispatch(
newMessagesActions.updateMessage({
topicId,
messageId: assistantMsgId,
updates: messageUpdates
})
)
await saveUpdatesToDB(assistantMsgId, topicId, messageUpdates, [])
EventEmitter.emit(EVENT_NAMES.MESSAGE_COMPLETE, { id: assistantMsgId, topicId, status })
}
}
}

View File

@ -0,0 +1,112 @@
import type { ExternalToolResult } from '@renderer/types'
import { CitationMessageBlock, MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
import { createCitationBlock } from '@renderer/utils/messageUtils/create'
import { findMainTextBlocks } from '@renderer/utils/messageUtils/find'
import { BlockManager } from '../BlockManager'
interface CitationCallbacksDependencies {
blockManager: BlockManager
assistantMsgId: string
getState: any
}
export const createCitationCallbacks = (deps: CitationCallbacksDependencies) => {
const { blockManager, assistantMsgId, getState } = deps
// 内部维护的状态
let citationBlockId: string | null = null
return {
onExternalToolInProgress: async () => {
const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING })
citationBlockId = citationBlock.id
await blockManager.handleBlockTransition(citationBlock, MessageBlockType.CITATION)
},
onExternalToolComplete: (externalToolResult: ExternalToolResult) => {
if (citationBlockId) {
const changes: Partial<CitationMessageBlock> = {
response: externalToolResult.webSearch,
knowledge: externalToolResult.knowledge,
status: MessageBlockStatus.SUCCESS
}
blockManager.smartBlockUpdate(citationBlockId, changes, MessageBlockType.CITATION, true)
} else {
console.error('[onExternalToolComplete] citationBlockId is null. Cannot update.')
}
},
onLLMWebSearchInProgress: async () => {
if (blockManager.hasInitialPlaceholder) {
// blockManager.lastBlockType = MessageBlockType.CITATION
console.log('blockManager.initialPlaceholderBlockId', blockManager.initialPlaceholderBlockId)
citationBlockId = blockManager.initialPlaceholderBlockId!
console.log('citationBlockId', citationBlockId)
const changes = {
type: MessageBlockType.CITATION,
status: MessageBlockStatus.PROCESSING
}
blockManager.smartBlockUpdate(citationBlockId, changes, MessageBlockType.CITATION)
} else {
const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING })
citationBlockId = citationBlock.id
await blockManager.handleBlockTransition(citationBlock, MessageBlockType.CITATION)
}
},
onLLMWebSearchComplete: async (llmWebSearchResult: any) => {
const blockId = citationBlockId || blockManager.initialPlaceholderBlockId
if (blockId) {
const changes: Partial<CitationMessageBlock> = {
type: MessageBlockType.CITATION,
response: llmWebSearchResult,
status: MessageBlockStatus.SUCCESS
}
blockManager.smartBlockUpdate(blockId, changes, MessageBlockType.CITATION, true)
const state = getState()
const existingMainTextBlocks = findMainTextBlocks(state.messages.entities[assistantMsgId])
if (existingMainTextBlocks.length > 0) {
const existingMainTextBlock = existingMainTextBlocks[0]
const currentRefs = existingMainTextBlock.citationReferences || []
const mainTextChanges = {
citationReferences: [...currentRefs, { blockId, citationBlockSource: llmWebSearchResult.source }]
}
blockManager.smartBlockUpdate(existingMainTextBlock.id, mainTextChanges, MessageBlockType.MAIN_TEXT, true)
}
if (blockManager.hasInitialPlaceholder) {
citationBlockId = blockManager.initialPlaceholderBlockId
}
} else {
const citationBlock = createCitationBlock(
assistantMsgId,
{
response: llmWebSearchResult
},
{
status: MessageBlockStatus.SUCCESS
}
)
citationBlockId = citationBlock.id
const state = getState()
const existingMainTextBlocks = findMainTextBlocks(state.messages.entities[assistantMsgId])
if (existingMainTextBlocks.length > 0) {
const existingMainTextBlock = existingMainTextBlocks[0]
const currentRefs = existingMainTextBlock.citationReferences || []
const mainTextChanges = {
citationReferences: [...currentRefs, { citationBlockId, citationBlockSource: llmWebSearchResult.source }]
}
blockManager.smartBlockUpdate(existingMainTextBlock.id, mainTextChanges, MessageBlockType.MAIN_TEXT, true)
}
await blockManager.handleBlockTransition(citationBlock, MessageBlockType.CITATION)
}
},
// 暴露给外部的方法用于textCallbacks中获取citationBlockId
getCitationBlockId: () => citationBlockId
}
}

View File

@ -0,0 +1,69 @@
import { ImageMessageBlock, MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
import { createImageBlock } from '@renderer/utils/messageUtils/create'
import { BlockManager } from '../BlockManager'
interface ImageCallbacksDependencies {
blockManager: BlockManager
assistantMsgId: string
}
export const createImageCallbacks = (deps: ImageCallbacksDependencies) => {
const { blockManager, assistantMsgId } = deps
// 内部维护的状态
let imageBlockId: string | null = null
return {
onImageCreated: async () => {
if (blockManager.hasInitialPlaceholder) {
const initialChanges = {
type: MessageBlockType.IMAGE,
status: MessageBlockStatus.PENDING
}
imageBlockId = blockManager.initialPlaceholderBlockId!
blockManager.smartBlockUpdate(imageBlockId, initialChanges, MessageBlockType.IMAGE)
} else if (!imageBlockId) {
const imageBlock = createImageBlock(assistantMsgId, {
status: MessageBlockStatus.PENDING
})
imageBlockId = imageBlock.id
await blockManager.handleBlockTransition(imageBlock, MessageBlockType.IMAGE)
}
},
onImageDelta: (imageData: any) => {
const imageUrl = imageData.images?.[0] || 'placeholder_image_url'
if (imageBlockId) {
const changes: Partial<ImageMessageBlock> = {
url: imageUrl,
metadata: { generateImageResponse: imageData },
status: MessageBlockStatus.STREAMING
}
blockManager.smartBlockUpdate(imageBlockId, changes, MessageBlockType.IMAGE, true)
}
},
onImageGenerated: (imageData: any) => {
if (imageBlockId) {
if (!imageData) {
const changes: Partial<ImageMessageBlock> = {
status: MessageBlockStatus.SUCCESS
}
blockManager.smartBlockUpdate(imageBlockId, changes, MessageBlockType.IMAGE)
} else {
const imageUrl = imageData.images?.[0] || 'placeholder_image_url'
const changes: Partial<ImageMessageBlock> = {
url: imageUrl,
metadata: { generateImageResponse: imageData },
status: MessageBlockStatus.SUCCESS
}
blockManager.smartBlockUpdate(imageBlockId, changes, MessageBlockType.IMAGE, true)
}
imageBlockId = null
} else {
console.error('[onImageGenerated] Last block was not an Image block or ID is missing.')
}
}
}
}

View File

@ -0,0 +1,79 @@
import type { Assistant } from '@renderer/types'
import { BlockManager } from '../BlockManager'
import { createBaseCallbacks } from './baseCallbacks'
import { createCitationCallbacks } from './citationCallbacks'
import { createImageCallbacks } from './imageCallbacks'
import { createTextCallbacks } from './textCallbacks'
import { createThinkingCallbacks } from './thinkingCallbacks'
import { createToolCallbacks } from './toolCallbacks'
interface CallbacksDependencies {
blockManager: BlockManager
dispatch: any
getState: any
topicId: string
assistantMsgId: string
saveUpdatesToDB: any
assistant: Assistant
}
export const createCallbacks = (deps: CallbacksDependencies) => {
const { blockManager, dispatch, getState, topicId, assistantMsgId, saveUpdatesToDB, assistant } = deps
// 创建基础回调
const baseCallbacks = createBaseCallbacks({
blockManager,
dispatch,
getState,
topicId,
assistantMsgId,
saveUpdatesToDB,
assistant
})
// 创建各类回调
const thinkingCallbacks = createThinkingCallbacks({
blockManager,
assistantMsgId
})
const toolCallbacks = createToolCallbacks({
blockManager,
assistantMsgId
})
const imageCallbacks = createImageCallbacks({
blockManager,
assistantMsgId
})
const citationCallbacks = createCitationCallbacks({
blockManager,
assistantMsgId,
getState
})
// 创建textCallbacks时传入citationCallbacks的getCitationBlockId方法
const textCallbacks = createTextCallbacks({
blockManager,
getState,
assistantMsgId,
getCitationBlockId: citationCallbacks.getCitationBlockId
})
// 组合所有回调
return {
...baseCallbacks,
...textCallbacks,
...thinkingCallbacks,
...toolCallbacks,
...imageCallbacks,
...citationCallbacks,
// 清理资源的方法
cleanup: () => {
// 清理由 messageThunk 中的节流函数管理,这里不需要特别处理
// 如果需要,可以调用 blockManager 的相关清理方法
}
}
}

View File

@ -0,0 +1,69 @@
import { WebSearchSource } from '@renderer/types'
import { CitationMessageBlock, MessageBlock, MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
import { createMainTextBlock } from '@renderer/utils/messageUtils/create'
import { BlockManager } from '../BlockManager'
interface TextCallbacksDependencies {
blockManager: BlockManager
getState: any
assistantMsgId: string
getCitationBlockId: () => string | null
}
export const createTextCallbacks = (deps: TextCallbacksDependencies) => {
const { blockManager, getState, assistantMsgId, getCitationBlockId } = deps
// 内部维护的状态
let mainTextBlockId: string | null = null
return {
onTextStart: async () => {
if (blockManager.hasInitialPlaceholder) {
const changes = {
type: MessageBlockType.MAIN_TEXT,
content: '',
status: MessageBlockStatus.STREAMING
}
mainTextBlockId = blockManager.initialPlaceholderBlockId!
blockManager.smartBlockUpdate(mainTextBlockId, changes, MessageBlockType.MAIN_TEXT, true)
} else if (!mainTextBlockId) {
const newBlock = createMainTextBlock(assistantMsgId, '', {
status: MessageBlockStatus.STREAMING
})
mainTextBlockId = newBlock.id
await blockManager.handleBlockTransition(newBlock, MessageBlockType.MAIN_TEXT)
}
},
onTextChunk: async (text: string) => {
const citationBlockId = getCitationBlockId()
const citationBlockSource = citationBlockId
? (getState().messageBlocks.entities[citationBlockId] as CitationMessageBlock).response?.source
: WebSearchSource.WEBSEARCH
if (text) {
const blockChanges: Partial<MessageBlock> = {
content: text,
status: MessageBlockStatus.STREAMING,
citationReferences: citationBlockId ? [{ citationBlockId, citationBlockSource }] : []
}
blockManager.smartBlockUpdate(mainTextBlockId!, blockChanges, MessageBlockType.MAIN_TEXT)
}
},
onTextComplete: async (finalText: string) => {
if (mainTextBlockId) {
const changes = {
content: finalText,
status: MessageBlockStatus.SUCCESS
}
blockManager.smartBlockUpdate(mainTextBlockId, changes, MessageBlockType.MAIN_TEXT, true)
mainTextBlockId = null
} else {
console.warn(
`[onTextComplete] Received text.complete but last block was not MAIN_TEXT (was ${blockManager.lastBlockType}) or lastBlockId is null.`
)
}
}
}
}

View File

@ -0,0 +1,66 @@
import { MessageBlock, MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
import { createThinkingBlock } from '@renderer/utils/messageUtils/create'
import { BlockManager } from '../BlockManager'
interface ThinkingCallbacksDependencies {
blockManager: BlockManager
assistantMsgId: string
}
export const createThinkingCallbacks = (deps: ThinkingCallbacksDependencies) => {
const { blockManager, assistantMsgId } = deps
// 内部维护的状态
let thinkingBlockId: string | null = null
return {
onThinkingStart: async () => {
if (blockManager.hasInitialPlaceholder) {
const changes = {
type: MessageBlockType.THINKING,
content: '',
status: MessageBlockStatus.STREAMING,
thinking_millsec: 0
}
thinkingBlockId = blockManager.initialPlaceholderBlockId!
blockManager.smartBlockUpdate(thinkingBlockId, changes, MessageBlockType.THINKING, true)
} else if (!thinkingBlockId) {
const newBlock = createThinkingBlock(assistantMsgId, '', {
status: MessageBlockStatus.STREAMING,
thinking_millsec: 0
})
thinkingBlockId = newBlock.id
await blockManager.handleBlockTransition(newBlock, MessageBlockType.THINKING)
}
},
onThinkingChunk: async (text: string, thinking_millsec?: number) => {
if (thinkingBlockId) {
const blockChanges: Partial<MessageBlock> = {
content: text,
status: MessageBlockStatus.STREAMING,
thinking_millsec: thinking_millsec || 0
}
blockManager.smartBlockUpdate(thinkingBlockId, blockChanges, MessageBlockType.THINKING)
}
},
onThinkingComplete: (finalText: string, final_thinking_millsec?: number) => {
if (thinkingBlockId) {
const changes = {
type: MessageBlockType.THINKING,
content: finalText,
status: MessageBlockStatus.SUCCESS,
thinking_millsec: final_thinking_millsec || 0
}
blockManager.smartBlockUpdate(thinkingBlockId, changes, MessageBlockType.THINKING, true)
thinkingBlockId = null
} else {
console.warn(
`[onThinkingComplete] Received thinking.complete but last block was not THINKING (was ${blockManager.lastBlockType}) or lastBlockId is null.`
)
}
}
}
}

View File

@ -0,0 +1,106 @@
import type { MCPToolResponse } from '@renderer/types'
import { MessageBlockStatus, MessageBlockType, ToolMessageBlock } from '@renderer/types/newMessage'
import { createToolBlock } from '@renderer/utils/messageUtils/create'
import { BlockManager } from '../BlockManager'
interface ToolCallbacksDependencies {
blockManager: BlockManager
assistantMsgId: string
}
export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
const { blockManager, assistantMsgId } = deps
// 内部维护的状态
const toolCallIdToBlockIdMap = new Map<string, string>()
let toolBlockId: string | null = null
return {
onToolCallPending: (toolResponse: MCPToolResponse) => {
if (blockManager.hasInitialPlaceholder) {
const changes = {
type: MessageBlockType.TOOL,
status: MessageBlockStatus.PENDING,
toolName: toolResponse.tool.name,
metadata: { rawMcpToolResponse: toolResponse }
}
toolBlockId = blockManager.initialPlaceholderBlockId!
blockManager.smartBlockUpdate(toolBlockId, changes, MessageBlockType.TOOL)
toolCallIdToBlockIdMap.set(toolResponse.id, toolBlockId)
} else if (toolResponse.status === 'pending') {
const toolBlock = createToolBlock(assistantMsgId, toolResponse.id, {
toolName: toolResponse.tool.name,
status: MessageBlockStatus.PENDING,
metadata: { rawMcpToolResponse: toolResponse }
})
toolBlockId = toolBlock.id
blockManager.handleBlockTransition(toolBlock, MessageBlockType.TOOL)
toolCallIdToBlockIdMap.set(toolResponse.id, toolBlock.id)
} else {
console.warn(
`[onToolCallPending] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
)
}
},
onToolCallInProgress: (toolResponse: MCPToolResponse) => {
// 根据 toolResponse.id 查找对应的块ID
const targetBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
if (targetBlockId && toolResponse.status === 'invoking') {
const changes = {
status: MessageBlockStatus.PROCESSING,
metadata: { rawMcpToolResponse: toolResponse }
}
blockManager.smartBlockUpdate(targetBlockId, changes, MessageBlockType.TOOL)
} else if (!targetBlockId) {
console.warn(
`[onToolCallInProgress] No block ID found for tool ID: ${toolResponse.id}. Available mappings:`,
Array.from(toolCallIdToBlockIdMap.entries())
)
} else {
console.warn(
`[onToolCallInProgress] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
)
}
},
onToolCallComplete: (toolResponse: MCPToolResponse) => {
const existingBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
toolCallIdToBlockIdMap.delete(toolResponse.id)
if (toolResponse.status === 'done' || toolResponse.status === 'error' || toolResponse.status === 'cancelled') {
if (!existingBlockId) {
console.error(
`[onToolCallComplete] No existing block found for completed/error tool call ID: ${toolResponse.id}. Cannot update.`
)
return
}
const finalStatus =
toolResponse.status === 'done' || toolResponse.status === 'cancelled'
? MessageBlockStatus.SUCCESS
: MessageBlockStatus.ERROR
const changes: Partial<ToolMessageBlock> = {
content: toolResponse.response,
status: finalStatus,
metadata: { rawMcpToolResponse: toolResponse }
}
if (finalStatus === MessageBlockStatus.ERROR) {
changes.error = { message: `Tool execution failed/error`, details: toolResponse.response }
}
blockManager.smartBlockUpdate(existingBlockId, changes, MessageBlockType.TOOL, true)
} else {
console.warn(
`[onToolCallComplete] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
)
}
toolBlockId = null
}
}
}

View File

@ -0,0 +1,3 @@
export { BlockManager } from './BlockManager'
export type { createCallbacks as CreateCallbacksFunction } from './callbacks'
export { createCallbacks } from './callbacks'

View File

@ -1,9 +1,10 @@
import { combineReducers, configureStore } from '@reduxjs/toolkit'
import { BlockManager } from '@renderer/services/messageStreaming/BlockManager'
import { createCallbacks } from '@renderer/services/messageStreaming/callbacks'
import { createStreamProcessor } from '@renderer/services/StreamProcessingService'
import type { AppDispatch } from '@renderer/store'
import { messageBlocksSlice } from '@renderer/store/messageBlock'
import { messagesSlice } from '@renderer/store/newMessage'
import { streamCallback } from '@renderer/store/thunk/messageThunk'
import type { Assistant, ExternalToolResult, MCPTool, Model } from '@renderer/types'
import { WebSearchSource } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
@ -13,6 +14,32 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import type { RootState } from '../../index'
const createMockCallbacks = (
mockAssistantMsgId: string,
mockTopicId: string,
mockAssistant: Assistant,
dispatch: AppDispatch,
getState: () => ReturnType<typeof reducer> & RootState
) =>
createCallbacks({
blockManager: new BlockManager({
dispatch,
getState,
saveUpdatedBlockToDB: vi.fn(),
saveUpdatesToDB: vi.fn(),
assistantMsgId: mockAssistantMsgId,
topicId: mockTopicId,
throttledBlockUpdate: vi.fn(),
cancelThrottledBlockUpdate: vi.fn()
}),
dispatch,
getState,
topicId: mockTopicId,
assistantMsgId: mockAssistantMsgId,
saveUpdatesToDB: vi.fn(),
assistant: mockAssistant
})
// Mock external dependencies
vi.mock('@renderer/config/models', () => ({
SYSTEM_MODELS: {
@ -186,7 +213,8 @@ vi.mock('@renderer/utils/queue', () => ({
vi.mock('@renderer/utils/messageUtils/find', () => ({
default: {},
findMainTextBlocks: vi.fn(() => []),
getMainTextContent: vi.fn(() => 'Test content')
getMainTextContent: vi.fn(() => 'Test content'),
findAllBlocks: vi.fn(() => [])
}))
vi.mock('i18next', () => {
@ -239,7 +267,7 @@ const createMockStore = () => {
}
// Helper function to simulate processing chunks through the stream processor
const processChunks = async (chunks: Chunk[], callbacks: ReturnType<typeof streamCallback>) => {
const processChunks = async (chunks: Chunk[], callbacks: ReturnType<typeof createCallbacks>) => {
const streamProcessor = createStreamProcessor(callbacks)
const stream = new ReadableStream<Chunk>({
@ -326,7 +354,7 @@ describe('streamCallback Integration Tests', () => {
})
it('should handle complete text streaming flow', async () => {
const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId)
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
const chunks: Chunk[] = [
{ type: ChunkType.LLM_RESPONSE_CREATED },
@ -369,7 +397,7 @@ describe('streamCallback Integration Tests', () => {
})
it('should handle thinking flow', async () => {
const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId)
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
const chunks: Chunk[] = [
{ type: ChunkType.LLM_RESPONSE_CREATED },
@ -394,7 +422,7 @@ describe('streamCallback Integration Tests', () => {
})
it('should handle tool call flow', async () => {
const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId)
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
const mockTool: MCPTool = {
id: 'tool-1',
@ -464,7 +492,7 @@ describe('streamCallback Integration Tests', () => {
})
it('should handle image generation flow', async () => {
const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId)
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
const chunks: Chunk[] = [
{ type: ChunkType.LLM_RESPONSE_CREATED },
@ -504,7 +532,7 @@ describe('streamCallback Integration Tests', () => {
})
it('should handle web search flow', async () => {
const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId)
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
const mockWebSearchResult = {
source: WebSearchSource.WEBSEARCH,
@ -523,7 +551,6 @@ describe('streamCallback Integration Tests', () => {
// 验证 Redux 状态
const state = getState()
const blocks = Object.values(state.messageBlocks.entities)
const citationBlock = blocks.find((block) => block.type === MessageBlockType.CITATION)
expect(citationBlock).toBeDefined()
expect(citationBlock?.response?.source).toEqual(mockWebSearchResult.source)
@ -531,7 +558,7 @@ describe('streamCallback Integration Tests', () => {
})
it('should handle mixed content flow (thinking + tool + text)', async () => {
const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId)
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
const mockCalculatorTool: MCPTool = {
id: 'tool-1',
@ -632,7 +659,7 @@ describe('streamCallback Integration Tests', () => {
})
it('should handle error flow', async () => {
const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId)
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
const mockError = new Error('Test error')
@ -662,7 +689,7 @@ describe('streamCallback Integration Tests', () => {
})
it('should handle external tool flow', async () => {
const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId)
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
const mockExternalToolResult: ExternalToolResult = {
webSearch: {
@ -700,7 +727,7 @@ describe('streamCallback Integration Tests', () => {
})
it('should handle abort error correctly', async () => {
const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId)
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
// 创建一个模拟的 abort 错误
const abortError = new Error('Request aborted')
@ -731,7 +758,7 @@ describe('streamCallback Integration Tests', () => {
})
it('should maintain block reference integrity during streaming', async () => {
const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId)
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
const chunks: Chunk[] = [
{ type: ChunkType.LLM_RESPONSE_CREATED },

File diff suppressed because it is too large Load Diff