mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-02-13 14:23:44 +08:00
feat: implement BlockManager and associated callbacks for message str… (#8167)
* feat: implement BlockManager and associated callbacks for message streaming - Introduced BlockManager to manage message blocks with smart update strategies. - Added various callback handlers for different message types including text, image, citation, and tool responses. - Enhanced state management for active blocks and transitions between different message types. - Created utility functions for handling block updates and transitions, improving overall message processing flow. - Refactored message thunk to utilize BlockManager for better organization and maintainability. This implementation lays the groundwork for more efficient message streaming and processing in the application. * refactor: clean up BlockManager and callback implementations - Removed redundant assignments of lastBlockType in various callback files. - Updated error handling logic to ensure correct message status updates. - Added console logs for debugging purposes in BlockManager and citation callbacks. - Enhanced smartBlockUpdate method call in citation callbacks for better state management. * refactor: streamline BlockManager and callback logic - Removed unnecessary accumulated content variables in text and thinking callbacks. - Updated content handling in callbacks to directly use incoming text instead of accumulating. - Enhanced smartBlockUpdate calls for better state management in message streaming. - Cleaned up console log statements for improved readability and debugging.
This commit is contained in:
parent
aa254a3772
commit
7e471bfea4
136
src/renderer/src/services/messageStreaming/BlockManager.ts
Normal file
136
src/renderer/src/services/messageStreaming/BlockManager.ts
Normal file
@ -0,0 +1,136 @@
|
||||
import type { AppDispatch, RootState } from '@renderer/store'
|
||||
import { updateOneBlock, upsertOneBlock } from '@renderer/store/messageBlock'
|
||||
import { newMessagesActions } from '@renderer/store/newMessage'
|
||||
import { MessageBlock, MessageBlockType } from '@renderer/types/newMessage'
|
||||
|
||||
interface ActiveBlockInfo {
|
||||
id: string
|
||||
type: MessageBlockType
|
||||
}
|
||||
|
||||
interface BlockManagerDependencies {
|
||||
dispatch: AppDispatch
|
||||
getState: () => RootState
|
||||
saveUpdatedBlockToDB: (
|
||||
blockId: string | null,
|
||||
messageId: string,
|
||||
topicId: string,
|
||||
getState: () => RootState
|
||||
) => Promise<void>
|
||||
saveUpdatesToDB: (
|
||||
messageId: string,
|
||||
topicId: string,
|
||||
messageUpdates: Partial<any>,
|
||||
blocksToUpdate: MessageBlock[]
|
||||
) => Promise<void>
|
||||
assistantMsgId: string
|
||||
topicId: string
|
||||
// 节流器管理从外部传入
|
||||
throttledBlockUpdate: (id: string, blockUpdate: any) => void
|
||||
cancelThrottledBlockUpdate: (id: string) => void
|
||||
}
|
||||
|
||||
export class BlockManager {
|
||||
private deps: BlockManagerDependencies
|
||||
|
||||
// 简化后的状态管理
|
||||
private _activeBlockInfo: ActiveBlockInfo | null = null
|
||||
private _lastBlockType: MessageBlockType | null = null // 保留用于错误处理
|
||||
|
||||
constructor(dependencies: BlockManagerDependencies) {
|
||||
this.deps = dependencies
|
||||
}
|
||||
|
||||
// Getters
|
||||
get activeBlockInfo() {
|
||||
return this._activeBlockInfo
|
||||
}
|
||||
|
||||
get lastBlockType() {
|
||||
return this._lastBlockType
|
||||
}
|
||||
|
||||
get hasInitialPlaceholder() {
|
||||
return this._activeBlockInfo?.type === MessageBlockType.UNKNOWN
|
||||
}
|
||||
|
||||
get initialPlaceholderBlockId() {
|
||||
return this.hasInitialPlaceholder ? this._activeBlockInfo?.id || null : null
|
||||
}
|
||||
|
||||
// Setters
|
||||
set lastBlockType(value: MessageBlockType | null) {
|
||||
this._lastBlockType = value
|
||||
}
|
||||
|
||||
set activeBlockInfo(value: ActiveBlockInfo | null) {
|
||||
this._activeBlockInfo = value
|
||||
}
|
||||
|
||||
/**
|
||||
* 智能更新策略:根据块类型连续性自动判断使用节流还是立即更新
|
||||
*/
|
||||
smartBlockUpdate(
|
||||
blockId: string,
|
||||
changes: Partial<MessageBlock>,
|
||||
blockType: MessageBlockType,
|
||||
isComplete: boolean = false
|
||||
) {
|
||||
const isBlockTypeChanged = this._lastBlockType !== null && this._lastBlockType !== blockType
|
||||
if (isBlockTypeChanged || isComplete) {
|
||||
// 如果块类型改变,则取消上一个块的节流更新
|
||||
if (isBlockTypeChanged && this._activeBlockInfo) {
|
||||
this.deps.cancelThrottledBlockUpdate(this._activeBlockInfo.id)
|
||||
}
|
||||
// 如果当前块完成,则取消当前块的节流更新
|
||||
if (isComplete) {
|
||||
this.deps.cancelThrottledBlockUpdate(blockId)
|
||||
this._activeBlockInfo = null // 块完成时清空activeBlockInfo
|
||||
} else {
|
||||
this._activeBlockInfo = { id: blockId, type: blockType } // 更新活跃块信息
|
||||
}
|
||||
this.deps.dispatch(updateOneBlock({ id: blockId, changes }))
|
||||
this.deps.saveUpdatedBlockToDB(blockId, this.deps.assistantMsgId, this.deps.topicId, this.deps.getState)
|
||||
this._lastBlockType = blockType
|
||||
} else {
|
||||
this._activeBlockInfo = { id: blockId, type: blockType } // 更新活跃块信息
|
||||
this.deps.throttledBlockUpdate(blockId, changes)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理块转换
|
||||
*/
|
||||
async handleBlockTransition(newBlock: MessageBlock, newBlockType: MessageBlockType) {
|
||||
this._lastBlockType = newBlockType
|
||||
this._activeBlockInfo = { id: newBlock.id, type: newBlockType } // 设置新的活跃块信息
|
||||
|
||||
this.deps.dispatch(
|
||||
newMessagesActions.updateMessage({
|
||||
topicId: this.deps.topicId,
|
||||
messageId: this.deps.assistantMsgId,
|
||||
updates: { blockInstruction: { id: newBlock.id } }
|
||||
})
|
||||
)
|
||||
this.deps.dispatch(upsertOneBlock(newBlock))
|
||||
this.deps.dispatch(
|
||||
newMessagesActions.upsertBlockReference({
|
||||
messageId: this.deps.assistantMsgId,
|
||||
blockId: newBlock.id,
|
||||
status: newBlock.status
|
||||
})
|
||||
)
|
||||
|
||||
const currentState = this.deps.getState()
|
||||
const updatedMessage = currentState.messages.entities[this.deps.assistantMsgId]
|
||||
if (updatedMessage) {
|
||||
await this.deps.saveUpdatesToDB(this.deps.assistantMsgId, this.deps.topicId, { blocks: updatedMessage.blocks }, [
|
||||
newBlock
|
||||
])
|
||||
} else {
|
||||
console.error(
|
||||
`[handleBlockTransition] Failed to get updated message ${this.deps.assistantMsgId} from state for DB save.`
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,214 @@
|
||||
import { autoRenameTopic } from '@renderer/hooks/useTopic'
|
||||
import i18n from '@renderer/i18n'
|
||||
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
|
||||
import { NotificationService } from '@renderer/services/NotificationService'
|
||||
import { estimateMessagesUsage } from '@renderer/services/TokenService'
|
||||
import { selectMessagesForTopic } from '@renderer/store/newMessage'
|
||||
import { newMessagesActions } from '@renderer/store/newMessage'
|
||||
import type { Assistant } from '@renderer/types'
|
||||
import type { Response } from '@renderer/types/newMessage'
|
||||
import {
|
||||
AssistantMessageStatus,
|
||||
MessageBlockStatus,
|
||||
MessageBlockType,
|
||||
PlaceholderMessageBlock
|
||||
} from '@renderer/types/newMessage'
|
||||
import { uuid } from '@renderer/utils'
|
||||
import { formatErrorMessage, isAbortError } from '@renderer/utils/error'
|
||||
import { createBaseMessageBlock, createErrorBlock } from '@renderer/utils/messageUtils/create'
|
||||
import { findAllBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { isFocused, isOnHomePage } from '@renderer/utils/window'
|
||||
|
||||
import { BlockManager } from '../BlockManager'
|
||||
|
||||
interface BaseCallbacksDependencies {
|
||||
blockManager: BlockManager
|
||||
dispatch: any
|
||||
getState: any
|
||||
topicId: string
|
||||
assistantMsgId: string
|
||||
saveUpdatesToDB: any
|
||||
assistant: Assistant
|
||||
}
|
||||
|
||||
export const createBaseCallbacks = (deps: BaseCallbacksDependencies) => {
|
||||
const { blockManager, dispatch, getState, topicId, assistantMsgId, saveUpdatesToDB, assistant } = deps
|
||||
|
||||
const startTime = Date.now()
|
||||
const notificationService = NotificationService.getInstance()
|
||||
|
||||
// 通用的 block 查找函数
|
||||
const findBlockIdForCompletion = (message?: any) => {
|
||||
// 优先使用 BlockManager 中的 activeBlockInfo
|
||||
const activeBlockInfo = blockManager.activeBlockInfo
|
||||
|
||||
if (activeBlockInfo) {
|
||||
return activeBlockInfo.id
|
||||
}
|
||||
|
||||
// 如果没有活跃的block,从message中查找最新的block作为备选
|
||||
const targetMessage = message || getState().messages.entities[assistantMsgId]
|
||||
if (targetMessage) {
|
||||
const allBlocks = findAllBlocks(targetMessage)
|
||||
if (allBlocks.length > 0) {
|
||||
return allBlocks[allBlocks.length - 1].id // 返回最新的block
|
||||
}
|
||||
}
|
||||
|
||||
// 最后的备选方案:从 blockManager 获取占位符块ID
|
||||
return blockManager.initialPlaceholderBlockId
|
||||
}
|
||||
|
||||
return {
|
||||
onLLMResponseCreated: async () => {
|
||||
const baseBlock = createBaseMessageBlock(assistantMsgId, MessageBlockType.UNKNOWN, {
|
||||
status: MessageBlockStatus.PROCESSING
|
||||
})
|
||||
await blockManager.handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
|
||||
},
|
||||
|
||||
onError: async (error: any) => {
|
||||
console.dir(error, { depth: null })
|
||||
const isErrorTypeAbort = isAbortError(error)
|
||||
let pauseErrorLanguagePlaceholder = ''
|
||||
if (isErrorTypeAbort) {
|
||||
pauseErrorLanguagePlaceholder = 'pause_placeholder'
|
||||
}
|
||||
|
||||
const serializableError = {
|
||||
name: error.name,
|
||||
message: pauseErrorLanguagePlaceholder || error.message || formatErrorMessage(error),
|
||||
originalMessage: error.message,
|
||||
stack: error.stack,
|
||||
status: error.status || error.code,
|
||||
requestId: error.request_id
|
||||
}
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
// 发送错误通知(除了中止错误)
|
||||
if (!isErrorTypeAbort) {
|
||||
const timeOut = duration > 30 * 1000
|
||||
if ((!isOnHomePage() && timeOut) || (!isFocused() && timeOut)) {
|
||||
await notificationService.send({
|
||||
id: uuid(),
|
||||
type: 'error',
|
||||
title: i18n.t('notification.assistant'),
|
||||
message: serializableError.message,
|
||||
silent: false,
|
||||
timestamp: Date.now(),
|
||||
source: 'assistant'
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const possibleBlockId = findBlockIdForCompletion()
|
||||
|
||||
if (possibleBlockId) {
|
||||
// 更改上一个block的状态为ERROR
|
||||
const changes = {
|
||||
status: isErrorTypeAbort ? MessageBlockStatus.PAUSED : MessageBlockStatus.ERROR
|
||||
}
|
||||
blockManager.smartBlockUpdate(possibleBlockId, changes, blockManager.lastBlockType!, true)
|
||||
}
|
||||
|
||||
const errorBlock = createErrorBlock(assistantMsgId, serializableError, { status: MessageBlockStatus.SUCCESS })
|
||||
await blockManager.handleBlockTransition(errorBlock, MessageBlockType.ERROR)
|
||||
const messageErrorUpdate = {
|
||||
status: isErrorTypeAbort ? AssistantMessageStatus.SUCCESS : AssistantMessageStatus.ERROR
|
||||
}
|
||||
dispatch(
|
||||
newMessagesActions.updateMessage({
|
||||
topicId,
|
||||
messageId: assistantMsgId,
|
||||
updates: messageErrorUpdate
|
||||
})
|
||||
)
|
||||
await saveUpdatesToDB(assistantMsgId, topicId, messageErrorUpdate, [])
|
||||
|
||||
EventEmitter.emit(EVENT_NAMES.MESSAGE_COMPLETE, {
|
||||
id: assistantMsgId,
|
||||
topicId,
|
||||
status: isErrorTypeAbort ? 'pause' : 'error',
|
||||
error: error.message
|
||||
})
|
||||
},
|
||||
|
||||
onComplete: async (status: AssistantMessageStatus, response?: Response) => {
|
||||
const finalStateOnComplete = getState()
|
||||
const finalAssistantMsg = finalStateOnComplete.messages.entities[assistantMsgId]
|
||||
|
||||
if (status === 'success' && finalAssistantMsg) {
|
||||
const userMsgId = finalAssistantMsg.askId
|
||||
const orderedMsgs = selectMessagesForTopic(finalStateOnComplete, topicId)
|
||||
const userMsgIndex = orderedMsgs.findIndex((m) => m.id === userMsgId)
|
||||
const contextForUsage = userMsgIndex !== -1 ? orderedMsgs.slice(0, userMsgIndex + 1) : []
|
||||
const finalContextWithAssistant = [...contextForUsage, finalAssistantMsg]
|
||||
|
||||
const possibleBlockId = findBlockIdForCompletion(finalAssistantMsg)
|
||||
|
||||
if (possibleBlockId) {
|
||||
const changes = {
|
||||
status: MessageBlockStatus.SUCCESS
|
||||
}
|
||||
blockManager.smartBlockUpdate(possibleBlockId, changes, blockManager.lastBlockType!, true)
|
||||
}
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
const content = getMainTextContent(finalAssistantMsg)
|
||||
|
||||
const timeOut = duration > 30 * 1000
|
||||
// 发送长时间运行消息的成功通知
|
||||
if ((!isOnHomePage() && timeOut) || (!isFocused() && timeOut)) {
|
||||
await notificationService.send({
|
||||
id: uuid(),
|
||||
type: 'success',
|
||||
title: i18n.t('notification.assistant'),
|
||||
message: content.length > 50 ? content.slice(0, 47) + '...' : content,
|
||||
silent: false,
|
||||
timestamp: Date.now(),
|
||||
source: 'assistant',
|
||||
channel: 'system'
|
||||
})
|
||||
}
|
||||
|
||||
// 更新topic的name
|
||||
autoRenameTopic(assistant, topicId)
|
||||
|
||||
// 处理usage估算
|
||||
if (
|
||||
response &&
|
||||
(response.usage?.total_tokens === 0 ||
|
||||
response?.usage?.prompt_tokens === 0 ||
|
||||
response?.usage?.completion_tokens === 0)
|
||||
) {
|
||||
const usage = await estimateMessagesUsage({ assistant, messages: finalContextWithAssistant })
|
||||
response.usage = usage
|
||||
}
|
||||
}
|
||||
|
||||
if (response && response.metrics) {
|
||||
if (response.metrics.completion_tokens === 0 && response.usage?.completion_tokens) {
|
||||
response = {
|
||||
...response,
|
||||
metrics: {
|
||||
...response.metrics,
|
||||
completion_tokens: response.usage.completion_tokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const messageUpdates = { status, metrics: response?.metrics, usage: response?.usage }
|
||||
dispatch(
|
||||
newMessagesActions.updateMessage({
|
||||
topicId,
|
||||
messageId: assistantMsgId,
|
||||
updates: messageUpdates
|
||||
})
|
||||
)
|
||||
await saveUpdatesToDB(assistantMsgId, topicId, messageUpdates, [])
|
||||
|
||||
EventEmitter.emit(EVENT_NAMES.MESSAGE_COMPLETE, { id: assistantMsgId, topicId, status })
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,112 @@
|
||||
import type { ExternalToolResult } from '@renderer/types'
|
||||
import { CitationMessageBlock, MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
|
||||
import { createCitationBlock } from '@renderer/utils/messageUtils/create'
|
||||
import { findMainTextBlocks } from '@renderer/utils/messageUtils/find'
|
||||
|
||||
import { BlockManager } from '../BlockManager'
|
||||
|
||||
interface CitationCallbacksDependencies {
|
||||
blockManager: BlockManager
|
||||
assistantMsgId: string
|
||||
getState: any
|
||||
}
|
||||
|
||||
export const createCitationCallbacks = (deps: CitationCallbacksDependencies) => {
|
||||
const { blockManager, assistantMsgId, getState } = deps
|
||||
|
||||
// 内部维护的状态
|
||||
let citationBlockId: string | null = null
|
||||
|
||||
return {
|
||||
onExternalToolInProgress: async () => {
|
||||
const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING })
|
||||
citationBlockId = citationBlock.id
|
||||
await blockManager.handleBlockTransition(citationBlock, MessageBlockType.CITATION)
|
||||
},
|
||||
|
||||
onExternalToolComplete: (externalToolResult: ExternalToolResult) => {
|
||||
if (citationBlockId) {
|
||||
const changes: Partial<CitationMessageBlock> = {
|
||||
response: externalToolResult.webSearch,
|
||||
knowledge: externalToolResult.knowledge,
|
||||
status: MessageBlockStatus.SUCCESS
|
||||
}
|
||||
blockManager.smartBlockUpdate(citationBlockId, changes, MessageBlockType.CITATION, true)
|
||||
} else {
|
||||
console.error('[onExternalToolComplete] citationBlockId is null. Cannot update.')
|
||||
}
|
||||
},
|
||||
|
||||
onLLMWebSearchInProgress: async () => {
|
||||
if (blockManager.hasInitialPlaceholder) {
|
||||
// blockManager.lastBlockType = MessageBlockType.CITATION
|
||||
console.log('blockManager.initialPlaceholderBlockId', blockManager.initialPlaceholderBlockId)
|
||||
citationBlockId = blockManager.initialPlaceholderBlockId!
|
||||
console.log('citationBlockId', citationBlockId)
|
||||
|
||||
const changes = {
|
||||
type: MessageBlockType.CITATION,
|
||||
status: MessageBlockStatus.PROCESSING
|
||||
}
|
||||
blockManager.smartBlockUpdate(citationBlockId, changes, MessageBlockType.CITATION)
|
||||
} else {
|
||||
const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING })
|
||||
citationBlockId = citationBlock.id
|
||||
await blockManager.handleBlockTransition(citationBlock, MessageBlockType.CITATION)
|
||||
}
|
||||
},
|
||||
|
||||
onLLMWebSearchComplete: async (llmWebSearchResult: any) => {
|
||||
const blockId = citationBlockId || blockManager.initialPlaceholderBlockId
|
||||
if (blockId) {
|
||||
const changes: Partial<CitationMessageBlock> = {
|
||||
type: MessageBlockType.CITATION,
|
||||
response: llmWebSearchResult,
|
||||
status: MessageBlockStatus.SUCCESS
|
||||
}
|
||||
blockManager.smartBlockUpdate(blockId, changes, MessageBlockType.CITATION, true)
|
||||
|
||||
const state = getState()
|
||||
const existingMainTextBlocks = findMainTextBlocks(state.messages.entities[assistantMsgId])
|
||||
if (existingMainTextBlocks.length > 0) {
|
||||
const existingMainTextBlock = existingMainTextBlocks[0]
|
||||
const currentRefs = existingMainTextBlock.citationReferences || []
|
||||
const mainTextChanges = {
|
||||
citationReferences: [...currentRefs, { blockId, citationBlockSource: llmWebSearchResult.source }]
|
||||
}
|
||||
blockManager.smartBlockUpdate(existingMainTextBlock.id, mainTextChanges, MessageBlockType.MAIN_TEXT, true)
|
||||
}
|
||||
|
||||
if (blockManager.hasInitialPlaceholder) {
|
||||
citationBlockId = blockManager.initialPlaceholderBlockId
|
||||
}
|
||||
} else {
|
||||
const citationBlock = createCitationBlock(
|
||||
assistantMsgId,
|
||||
{
|
||||
response: llmWebSearchResult
|
||||
},
|
||||
{
|
||||
status: MessageBlockStatus.SUCCESS
|
||||
}
|
||||
)
|
||||
citationBlockId = citationBlock.id
|
||||
|
||||
const state = getState()
|
||||
const existingMainTextBlocks = findMainTextBlocks(state.messages.entities[assistantMsgId])
|
||||
if (existingMainTextBlocks.length > 0) {
|
||||
const existingMainTextBlock = existingMainTextBlocks[0]
|
||||
const currentRefs = existingMainTextBlock.citationReferences || []
|
||||
const mainTextChanges = {
|
||||
citationReferences: [...currentRefs, { citationBlockId, citationBlockSource: llmWebSearchResult.source }]
|
||||
}
|
||||
blockManager.smartBlockUpdate(existingMainTextBlock.id, mainTextChanges, MessageBlockType.MAIN_TEXT, true)
|
||||
}
|
||||
await blockManager.handleBlockTransition(citationBlock, MessageBlockType.CITATION)
|
||||
}
|
||||
},
|
||||
|
||||
// 暴露给外部的方法,用于textCallbacks中获取citationBlockId
|
||||
getCitationBlockId: () => citationBlockId
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,69 @@
|
||||
import { ImageMessageBlock, MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
|
||||
import { createImageBlock } from '@renderer/utils/messageUtils/create'
|
||||
|
||||
import { BlockManager } from '../BlockManager'
|
||||
|
||||
interface ImageCallbacksDependencies {
|
||||
blockManager: BlockManager
|
||||
assistantMsgId: string
|
||||
}
|
||||
|
||||
export const createImageCallbacks = (deps: ImageCallbacksDependencies) => {
|
||||
const { blockManager, assistantMsgId } = deps
|
||||
|
||||
// 内部维护的状态
|
||||
let imageBlockId: string | null = null
|
||||
|
||||
return {
|
||||
onImageCreated: async () => {
|
||||
if (blockManager.hasInitialPlaceholder) {
|
||||
const initialChanges = {
|
||||
type: MessageBlockType.IMAGE,
|
||||
status: MessageBlockStatus.PENDING
|
||||
}
|
||||
imageBlockId = blockManager.initialPlaceholderBlockId!
|
||||
blockManager.smartBlockUpdate(imageBlockId, initialChanges, MessageBlockType.IMAGE)
|
||||
} else if (!imageBlockId) {
|
||||
const imageBlock = createImageBlock(assistantMsgId, {
|
||||
status: MessageBlockStatus.PENDING
|
||||
})
|
||||
imageBlockId = imageBlock.id
|
||||
await blockManager.handleBlockTransition(imageBlock, MessageBlockType.IMAGE)
|
||||
}
|
||||
},
|
||||
|
||||
onImageDelta: (imageData: any) => {
|
||||
const imageUrl = imageData.images?.[0] || 'placeholder_image_url'
|
||||
if (imageBlockId) {
|
||||
const changes: Partial<ImageMessageBlock> = {
|
||||
url: imageUrl,
|
||||
metadata: { generateImageResponse: imageData },
|
||||
status: MessageBlockStatus.STREAMING
|
||||
}
|
||||
blockManager.smartBlockUpdate(imageBlockId, changes, MessageBlockType.IMAGE, true)
|
||||
}
|
||||
},
|
||||
|
||||
onImageGenerated: (imageData: any) => {
|
||||
if (imageBlockId) {
|
||||
if (!imageData) {
|
||||
const changes: Partial<ImageMessageBlock> = {
|
||||
status: MessageBlockStatus.SUCCESS
|
||||
}
|
||||
blockManager.smartBlockUpdate(imageBlockId, changes, MessageBlockType.IMAGE)
|
||||
} else {
|
||||
const imageUrl = imageData.images?.[0] || 'placeholder_image_url'
|
||||
const changes: Partial<ImageMessageBlock> = {
|
||||
url: imageUrl,
|
||||
metadata: { generateImageResponse: imageData },
|
||||
status: MessageBlockStatus.SUCCESS
|
||||
}
|
||||
blockManager.smartBlockUpdate(imageBlockId, changes, MessageBlockType.IMAGE, true)
|
||||
}
|
||||
imageBlockId = null
|
||||
} else {
|
||||
console.error('[onImageGenerated] Last block was not an Image block or ID is missing.')
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,79 @@
|
||||
import type { Assistant } from '@renderer/types'
|
||||
|
||||
import { BlockManager } from '../BlockManager'
|
||||
import { createBaseCallbacks } from './baseCallbacks'
|
||||
import { createCitationCallbacks } from './citationCallbacks'
|
||||
import { createImageCallbacks } from './imageCallbacks'
|
||||
import { createTextCallbacks } from './textCallbacks'
|
||||
import { createThinkingCallbacks } from './thinkingCallbacks'
|
||||
import { createToolCallbacks } from './toolCallbacks'
|
||||
|
||||
interface CallbacksDependencies {
|
||||
blockManager: BlockManager
|
||||
dispatch: any
|
||||
getState: any
|
||||
topicId: string
|
||||
assistantMsgId: string
|
||||
saveUpdatesToDB: any
|
||||
assistant: Assistant
|
||||
}
|
||||
|
||||
export const createCallbacks = (deps: CallbacksDependencies) => {
|
||||
const { blockManager, dispatch, getState, topicId, assistantMsgId, saveUpdatesToDB, assistant } = deps
|
||||
|
||||
// 创建基础回调
|
||||
const baseCallbacks = createBaseCallbacks({
|
||||
blockManager,
|
||||
dispatch,
|
||||
getState,
|
||||
topicId,
|
||||
assistantMsgId,
|
||||
saveUpdatesToDB,
|
||||
assistant
|
||||
})
|
||||
|
||||
// 创建各类回调
|
||||
const thinkingCallbacks = createThinkingCallbacks({
|
||||
blockManager,
|
||||
assistantMsgId
|
||||
})
|
||||
|
||||
const toolCallbacks = createToolCallbacks({
|
||||
blockManager,
|
||||
assistantMsgId
|
||||
})
|
||||
|
||||
const imageCallbacks = createImageCallbacks({
|
||||
blockManager,
|
||||
assistantMsgId
|
||||
})
|
||||
|
||||
const citationCallbacks = createCitationCallbacks({
|
||||
blockManager,
|
||||
assistantMsgId,
|
||||
getState
|
||||
})
|
||||
|
||||
// 创建textCallbacks时传入citationCallbacks的getCitationBlockId方法
|
||||
const textCallbacks = createTextCallbacks({
|
||||
blockManager,
|
||||
getState,
|
||||
assistantMsgId,
|
||||
getCitationBlockId: citationCallbacks.getCitationBlockId
|
||||
})
|
||||
|
||||
// 组合所有回调
|
||||
return {
|
||||
...baseCallbacks,
|
||||
...textCallbacks,
|
||||
...thinkingCallbacks,
|
||||
...toolCallbacks,
|
||||
...imageCallbacks,
|
||||
...citationCallbacks,
|
||||
// 清理资源的方法
|
||||
cleanup: () => {
|
||||
// 清理由 messageThunk 中的节流函数管理,这里不需要特别处理
|
||||
// 如果需要,可以调用 blockManager 的相关清理方法
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,69 @@
|
||||
import { WebSearchSource } from '@renderer/types'
|
||||
import { CitationMessageBlock, MessageBlock, MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
|
||||
import { createMainTextBlock } from '@renderer/utils/messageUtils/create'
|
||||
|
||||
import { BlockManager } from '../BlockManager'
|
||||
|
||||
interface TextCallbacksDependencies {
|
||||
blockManager: BlockManager
|
||||
getState: any
|
||||
assistantMsgId: string
|
||||
getCitationBlockId: () => string | null
|
||||
}
|
||||
|
||||
export const createTextCallbacks = (deps: TextCallbacksDependencies) => {
|
||||
const { blockManager, getState, assistantMsgId, getCitationBlockId } = deps
|
||||
|
||||
// 内部维护的状态
|
||||
let mainTextBlockId: string | null = null
|
||||
|
||||
return {
|
||||
onTextStart: async () => {
|
||||
if (blockManager.hasInitialPlaceholder) {
|
||||
const changes = {
|
||||
type: MessageBlockType.MAIN_TEXT,
|
||||
content: '',
|
||||
status: MessageBlockStatus.STREAMING
|
||||
}
|
||||
mainTextBlockId = blockManager.initialPlaceholderBlockId!
|
||||
blockManager.smartBlockUpdate(mainTextBlockId, changes, MessageBlockType.MAIN_TEXT, true)
|
||||
} else if (!mainTextBlockId) {
|
||||
const newBlock = createMainTextBlock(assistantMsgId, '', {
|
||||
status: MessageBlockStatus.STREAMING
|
||||
})
|
||||
mainTextBlockId = newBlock.id
|
||||
await blockManager.handleBlockTransition(newBlock, MessageBlockType.MAIN_TEXT)
|
||||
}
|
||||
},
|
||||
|
||||
onTextChunk: async (text: string) => {
|
||||
const citationBlockId = getCitationBlockId()
|
||||
const citationBlockSource = citationBlockId
|
||||
? (getState().messageBlocks.entities[citationBlockId] as CitationMessageBlock).response?.source
|
||||
: WebSearchSource.WEBSEARCH
|
||||
if (text) {
|
||||
const blockChanges: Partial<MessageBlock> = {
|
||||
content: text,
|
||||
status: MessageBlockStatus.STREAMING,
|
||||
citationReferences: citationBlockId ? [{ citationBlockId, citationBlockSource }] : []
|
||||
}
|
||||
blockManager.smartBlockUpdate(mainTextBlockId!, blockChanges, MessageBlockType.MAIN_TEXT)
|
||||
}
|
||||
},
|
||||
|
||||
onTextComplete: async (finalText: string) => {
|
||||
if (mainTextBlockId) {
|
||||
const changes = {
|
||||
content: finalText,
|
||||
status: MessageBlockStatus.SUCCESS
|
||||
}
|
||||
blockManager.smartBlockUpdate(mainTextBlockId, changes, MessageBlockType.MAIN_TEXT, true)
|
||||
mainTextBlockId = null
|
||||
} else {
|
||||
console.warn(
|
||||
`[onTextComplete] Received text.complete but last block was not MAIN_TEXT (was ${blockManager.lastBlockType}) or lastBlockId is null.`
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,66 @@
|
||||
import { MessageBlock, MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
|
||||
import { createThinkingBlock } from '@renderer/utils/messageUtils/create'
|
||||
|
||||
import { BlockManager } from '../BlockManager'
|
||||
|
||||
interface ThinkingCallbacksDependencies {
|
||||
blockManager: BlockManager
|
||||
assistantMsgId: string
|
||||
}
|
||||
|
||||
export const createThinkingCallbacks = (deps: ThinkingCallbacksDependencies) => {
|
||||
const { blockManager, assistantMsgId } = deps
|
||||
|
||||
// 内部维护的状态
|
||||
let thinkingBlockId: string | null = null
|
||||
|
||||
return {
|
||||
onThinkingStart: async () => {
|
||||
if (blockManager.hasInitialPlaceholder) {
|
||||
const changes = {
|
||||
type: MessageBlockType.THINKING,
|
||||
content: '',
|
||||
status: MessageBlockStatus.STREAMING,
|
||||
thinking_millsec: 0
|
||||
}
|
||||
thinkingBlockId = blockManager.initialPlaceholderBlockId!
|
||||
blockManager.smartBlockUpdate(thinkingBlockId, changes, MessageBlockType.THINKING, true)
|
||||
} else if (!thinkingBlockId) {
|
||||
const newBlock = createThinkingBlock(assistantMsgId, '', {
|
||||
status: MessageBlockStatus.STREAMING,
|
||||
thinking_millsec: 0
|
||||
})
|
||||
thinkingBlockId = newBlock.id
|
||||
await blockManager.handleBlockTransition(newBlock, MessageBlockType.THINKING)
|
||||
}
|
||||
},
|
||||
|
||||
onThinkingChunk: async (text: string, thinking_millsec?: number) => {
|
||||
if (thinkingBlockId) {
|
||||
const blockChanges: Partial<MessageBlock> = {
|
||||
content: text,
|
||||
status: MessageBlockStatus.STREAMING,
|
||||
thinking_millsec: thinking_millsec || 0
|
||||
}
|
||||
blockManager.smartBlockUpdate(thinkingBlockId, blockChanges, MessageBlockType.THINKING)
|
||||
}
|
||||
},
|
||||
|
||||
onThinkingComplete: (finalText: string, final_thinking_millsec?: number) => {
|
||||
if (thinkingBlockId) {
|
||||
const changes = {
|
||||
type: MessageBlockType.THINKING,
|
||||
content: finalText,
|
||||
status: MessageBlockStatus.SUCCESS,
|
||||
thinking_millsec: final_thinking_millsec || 0
|
||||
}
|
||||
blockManager.smartBlockUpdate(thinkingBlockId, changes, MessageBlockType.THINKING, true)
|
||||
thinkingBlockId = null
|
||||
} else {
|
||||
console.warn(
|
||||
`[onThinkingComplete] Received thinking.complete but last block was not THINKING (was ${blockManager.lastBlockType}) or lastBlockId is null.`
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,106 @@
|
||||
import type { MCPToolResponse } from '@renderer/types'
|
||||
import { MessageBlockStatus, MessageBlockType, ToolMessageBlock } from '@renderer/types/newMessage'
|
||||
import { createToolBlock } from '@renderer/utils/messageUtils/create'
|
||||
|
||||
import { BlockManager } from '../BlockManager'
|
||||
|
||||
interface ToolCallbacksDependencies {
|
||||
blockManager: BlockManager
|
||||
assistantMsgId: string
|
||||
}
|
||||
|
||||
export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
|
||||
const { blockManager, assistantMsgId } = deps
|
||||
|
||||
// 内部维护的状态
|
||||
const toolCallIdToBlockIdMap = new Map<string, string>()
|
||||
let toolBlockId: string | null = null
|
||||
|
||||
return {
|
||||
onToolCallPending: (toolResponse: MCPToolResponse) => {
|
||||
if (blockManager.hasInitialPlaceholder) {
|
||||
const changes = {
|
||||
type: MessageBlockType.TOOL,
|
||||
status: MessageBlockStatus.PENDING,
|
||||
toolName: toolResponse.tool.name,
|
||||
metadata: { rawMcpToolResponse: toolResponse }
|
||||
}
|
||||
toolBlockId = blockManager.initialPlaceholderBlockId!
|
||||
blockManager.smartBlockUpdate(toolBlockId, changes, MessageBlockType.TOOL)
|
||||
toolCallIdToBlockIdMap.set(toolResponse.id, toolBlockId)
|
||||
} else if (toolResponse.status === 'pending') {
|
||||
const toolBlock = createToolBlock(assistantMsgId, toolResponse.id, {
|
||||
toolName: toolResponse.tool.name,
|
||||
status: MessageBlockStatus.PENDING,
|
||||
metadata: { rawMcpToolResponse: toolResponse }
|
||||
})
|
||||
toolBlockId = toolBlock.id
|
||||
blockManager.handleBlockTransition(toolBlock, MessageBlockType.TOOL)
|
||||
toolCallIdToBlockIdMap.set(toolResponse.id, toolBlock.id)
|
||||
} else {
|
||||
console.warn(
|
||||
`[onToolCallPending] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
|
||||
)
|
||||
}
|
||||
},
|
||||
|
||||
onToolCallInProgress: (toolResponse: MCPToolResponse) => {
|
||||
// 根据 toolResponse.id 查找对应的块ID
|
||||
const targetBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
|
||||
|
||||
if (targetBlockId && toolResponse.status === 'invoking') {
|
||||
const changes = {
|
||||
status: MessageBlockStatus.PROCESSING,
|
||||
metadata: { rawMcpToolResponse: toolResponse }
|
||||
}
|
||||
blockManager.smartBlockUpdate(targetBlockId, changes, MessageBlockType.TOOL)
|
||||
} else if (!targetBlockId) {
|
||||
console.warn(
|
||||
`[onToolCallInProgress] No block ID found for tool ID: ${toolResponse.id}. Available mappings:`,
|
||||
Array.from(toolCallIdToBlockIdMap.entries())
|
||||
)
|
||||
} else {
|
||||
console.warn(
|
||||
`[onToolCallInProgress] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
|
||||
)
|
||||
}
|
||||
},
|
||||
|
||||
onToolCallComplete: (toolResponse: MCPToolResponse) => {
|
||||
const existingBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
|
||||
toolCallIdToBlockIdMap.delete(toolResponse.id)
|
||||
|
||||
if (toolResponse.status === 'done' || toolResponse.status === 'error' || toolResponse.status === 'cancelled') {
|
||||
if (!existingBlockId) {
|
||||
console.error(
|
||||
`[onToolCallComplete] No existing block found for completed/error tool call ID: ${toolResponse.id}. Cannot update.`
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
const finalStatus =
|
||||
toolResponse.status === 'done' || toolResponse.status === 'cancelled'
|
||||
? MessageBlockStatus.SUCCESS
|
||||
: MessageBlockStatus.ERROR
|
||||
|
||||
const changes: Partial<ToolMessageBlock> = {
|
||||
content: toolResponse.response,
|
||||
status: finalStatus,
|
||||
metadata: { rawMcpToolResponse: toolResponse }
|
||||
}
|
||||
|
||||
if (finalStatus === MessageBlockStatus.ERROR) {
|
||||
changes.error = { message: `Tool execution failed/error`, details: toolResponse.response }
|
||||
}
|
||||
|
||||
blockManager.smartBlockUpdate(existingBlockId, changes, MessageBlockType.TOOL, true)
|
||||
} else {
|
||||
console.warn(
|
||||
`[onToolCallComplete] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
|
||||
)
|
||||
}
|
||||
|
||||
toolBlockId = null
|
||||
}
|
||||
}
|
||||
}
|
||||
3
src/renderer/src/services/messageStreaming/index.ts
Normal file
3
src/renderer/src/services/messageStreaming/index.ts
Normal file
@ -0,0 +1,3 @@
|
||||
export { BlockManager } from './BlockManager'
|
||||
export type { createCallbacks as CreateCallbacksFunction } from './callbacks'
|
||||
export { createCallbacks } from './callbacks'
|
||||
@ -1,9 +1,10 @@
|
||||
import { combineReducers, configureStore } from '@reduxjs/toolkit'
|
||||
import { BlockManager } from '@renderer/services/messageStreaming/BlockManager'
|
||||
import { createCallbacks } from '@renderer/services/messageStreaming/callbacks'
|
||||
import { createStreamProcessor } from '@renderer/services/StreamProcessingService'
|
||||
import type { AppDispatch } from '@renderer/store'
|
||||
import { messageBlocksSlice } from '@renderer/store/messageBlock'
|
||||
import { messagesSlice } from '@renderer/store/newMessage'
|
||||
import { streamCallback } from '@renderer/store/thunk/messageThunk'
|
||||
import type { Assistant, ExternalToolResult, MCPTool, Model } from '@renderer/types'
|
||||
import { WebSearchSource } from '@renderer/types'
|
||||
import type { Chunk } from '@renderer/types/chunk'
|
||||
@ -13,6 +14,32 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { RootState } from '../../index'
|
||||
|
||||
const createMockCallbacks = (
|
||||
mockAssistantMsgId: string,
|
||||
mockTopicId: string,
|
||||
mockAssistant: Assistant,
|
||||
dispatch: AppDispatch,
|
||||
getState: () => ReturnType<typeof reducer> & RootState
|
||||
) =>
|
||||
createCallbacks({
|
||||
blockManager: new BlockManager({
|
||||
dispatch,
|
||||
getState,
|
||||
saveUpdatedBlockToDB: vi.fn(),
|
||||
saveUpdatesToDB: vi.fn(),
|
||||
assistantMsgId: mockAssistantMsgId,
|
||||
topicId: mockTopicId,
|
||||
throttledBlockUpdate: vi.fn(),
|
||||
cancelThrottledBlockUpdate: vi.fn()
|
||||
}),
|
||||
dispatch,
|
||||
getState,
|
||||
topicId: mockTopicId,
|
||||
assistantMsgId: mockAssistantMsgId,
|
||||
saveUpdatesToDB: vi.fn(),
|
||||
assistant: mockAssistant
|
||||
})
|
||||
|
||||
// Mock external dependencies
|
||||
vi.mock('@renderer/config/models', () => ({
|
||||
SYSTEM_MODELS: {
|
||||
@ -186,7 +213,8 @@ vi.mock('@renderer/utils/queue', () => ({
|
||||
vi.mock('@renderer/utils/messageUtils/find', () => ({
|
||||
default: {},
|
||||
findMainTextBlocks: vi.fn(() => []),
|
||||
getMainTextContent: vi.fn(() => 'Test content')
|
||||
getMainTextContent: vi.fn(() => 'Test content'),
|
||||
findAllBlocks: vi.fn(() => [])
|
||||
}))
|
||||
|
||||
vi.mock('i18next', () => {
|
||||
@ -239,7 +267,7 @@ const createMockStore = () => {
|
||||
}
|
||||
|
||||
// Helper function to simulate processing chunks through the stream processor
|
||||
const processChunks = async (chunks: Chunk[], callbacks: ReturnType<typeof streamCallback>) => {
|
||||
const processChunks = async (chunks: Chunk[], callbacks: ReturnType<typeof createCallbacks>) => {
|
||||
const streamProcessor = createStreamProcessor(callbacks)
|
||||
|
||||
const stream = new ReadableStream<Chunk>({
|
||||
@ -326,7 +354,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
})
|
||||
|
||||
it('should handle complete text streaming flow', async () => {
|
||||
const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId)
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
|
||||
|
||||
const chunks: Chunk[] = [
|
||||
{ type: ChunkType.LLM_RESPONSE_CREATED },
|
||||
@ -369,7 +397,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
})
|
||||
|
||||
it('should handle thinking flow', async () => {
|
||||
const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId)
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
|
||||
|
||||
const chunks: Chunk[] = [
|
||||
{ type: ChunkType.LLM_RESPONSE_CREATED },
|
||||
@ -394,7 +422,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
})
|
||||
|
||||
it('should handle tool call flow', async () => {
|
||||
const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId)
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
|
||||
|
||||
const mockTool: MCPTool = {
|
||||
id: 'tool-1',
|
||||
@ -464,7 +492,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
})
|
||||
|
||||
it('should handle image generation flow', async () => {
|
||||
const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId)
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
|
||||
|
||||
const chunks: Chunk[] = [
|
||||
{ type: ChunkType.LLM_RESPONSE_CREATED },
|
||||
@ -504,7 +532,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
})
|
||||
|
||||
it('should handle web search flow', async () => {
|
||||
const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId)
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
|
||||
|
||||
const mockWebSearchResult = {
|
||||
source: WebSearchSource.WEBSEARCH,
|
||||
@ -523,7 +551,6 @@ describe('streamCallback Integration Tests', () => {
|
||||
// 验证 Redux 状态
|
||||
const state = getState()
|
||||
const blocks = Object.values(state.messageBlocks.entities)
|
||||
|
||||
const citationBlock = blocks.find((block) => block.type === MessageBlockType.CITATION)
|
||||
expect(citationBlock).toBeDefined()
|
||||
expect(citationBlock?.response?.source).toEqual(mockWebSearchResult.source)
|
||||
@ -531,7 +558,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
})
|
||||
|
||||
it('should handle mixed content flow (thinking + tool + text)', async () => {
|
||||
const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId)
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
|
||||
|
||||
const mockCalculatorTool: MCPTool = {
|
||||
id: 'tool-1',
|
||||
@ -632,7 +659,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
})
|
||||
|
||||
it('should handle error flow', async () => {
|
||||
const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId)
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
|
||||
|
||||
const mockError = new Error('Test error')
|
||||
|
||||
@ -662,7 +689,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
})
|
||||
|
||||
it('should handle external tool flow', async () => {
|
||||
const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId)
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
|
||||
|
||||
const mockExternalToolResult: ExternalToolResult = {
|
||||
webSearch: {
|
||||
@ -700,7 +727,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
})
|
||||
|
||||
it('should handle abort error correctly', async () => {
|
||||
const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId)
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
|
||||
|
||||
// 创建一个模拟的 abort 错误
|
||||
const abortError = new Error('Request aborted')
|
||||
@ -731,7 +758,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
})
|
||||
|
||||
it('should maintain block reference integrity during streaming', async () => {
|
||||
const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId)
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
|
||||
|
||||
const chunks: Chunk[] = [
|
||||
{ type: ChunkType.LLM_RESPONSE_CREATED },
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user