Merge remote-tracking branch 'origin/feat/pull-a-variable' into feat/pull-a-variable

This commit is contained in:
zhsama 2026-01-13 15:17:24 +08:00
commit dbed937fc6
25 changed files with 371 additions and 134 deletions

View File

@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.model_dump(mode="json")) response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
yield response_chunk yield response_chunk
@classmethod @classmethod
@ -110,7 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
} }
if isinstance(sub_stream_response, MessageEndStreamResponse): if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json") sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
metadata = sub_stream_response_dict.get("metadata", {}) metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict) response_chunk.update(sub_stream_response_dict)
@ -120,6 +120,6 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else: else:
response_chunk.update(sub_stream_response.model_dump(mode="json")) response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
yield response_chunk yield response_chunk

View File

@ -81,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.model_dump(mode="json")) response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
yield response_chunk yield response_chunk
@classmethod @classmethod
@ -109,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
} }
if isinstance(sub_stream_response, MessageEndStreamResponse): if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json") sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
metadata = sub_stream_response_dict.get("metadata", {}) metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict) response_chunk.update(sub_stream_response_dict)
@ -117,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.model_dump(mode="json")) response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
yield response_chunk yield response_chunk

View File

@ -81,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.model_dump(mode="json")) response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
yield response_chunk yield response_chunk
@classmethod @classmethod
@ -109,7 +109,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
} }
if isinstance(sub_stream_response, MessageEndStreamResponse): if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json") sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
metadata = sub_stream_response_dict.get("metadata", {}) metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict) response_chunk.update(sub_stream_response_dict)
@ -117,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.model_dump(mode="json")) response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
yield response_chunk yield response_chunk

View File

@ -70,6 +70,8 @@ class _NodeSnapshot:
"""Empty string means the node is not executing inside an iteration.""" """Empty string means the node is not executing inside an iteration."""
loop_id: str = "" loop_id: str = ""
"""Empty string means the node is not executing inside a loop.""" """Empty string means the node is not executing inside a loop."""
mention_parent_id: str = ""
"""Empty string means the node is not an extractor node."""
class WorkflowResponseConverter: class WorkflowResponseConverter:
@ -131,6 +133,7 @@ class WorkflowResponseConverter:
start_at=event.start_at, start_at=event.start_at,
iteration_id=event.in_iteration_id or "", iteration_id=event.in_iteration_id or "",
loop_id=event.in_loop_id or "", loop_id=event.in_loop_id or "",
mention_parent_id=event.in_mention_parent_id or "",
) )
node_execution_id = NodeExecutionId(event.node_execution_id) node_execution_id = NodeExecutionId(event.node_execution_id)
self._node_snapshots[node_execution_id] = snapshot self._node_snapshots[node_execution_id] = snapshot
@ -287,6 +290,7 @@ class WorkflowResponseConverter:
created_at=int(snapshot.start_at.timestamp()), created_at=int(snapshot.start_at.timestamp()),
iteration_id=event.in_iteration_id, iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id, loop_id=event.in_loop_id,
mention_parent_id=event.in_mention_parent_id,
agent_strategy=event.agent_strategy, agent_strategy=event.agent_strategy,
), ),
) )
@ -373,6 +377,7 @@ class WorkflowResponseConverter:
files=self.fetch_files_from_node_outputs(event.outputs or {}), files=self.fetch_files_from_node_outputs(event.outputs or {}),
iteration_id=event.in_iteration_id, iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id, loop_id=event.in_loop_id,
mention_parent_id=event.in_mention_parent_id,
), ),
) )
@ -422,6 +427,7 @@ class WorkflowResponseConverter:
files=self.fetch_files_from_node_outputs(event.outputs or {}), files=self.fetch_files_from_node_outputs(event.outputs or {}),
iteration_id=event.in_iteration_id, iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id, loop_id=event.in_loop_id,
mention_parent_id=event.in_mention_parent_id,
retry_index=event.retry_index, retry_index=event.retry_index,
), ),
) )

View File

@ -79,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.model_dump(mode="json")) response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
yield response_chunk yield response_chunk
@classmethod @classmethod
@ -106,7 +106,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
} }
if isinstance(sub_stream_response, MessageEndStreamResponse): if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json") sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
metadata = sub_stream_response_dict.get("metadata", {}) metadata = sub_stream_response_dict.get("metadata", {})
if not isinstance(metadata, dict): if not isinstance(metadata, dict):
metadata = {} metadata = {}
@ -116,6 +116,6 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.model_dump(mode="json")) response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
yield response_chunk yield response_chunk

View File

@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(cast(dict, data)) response_chunk.update(cast(dict, data))
else: else:
response_chunk.update(sub_stream_response.model_dump()) response_chunk.update(sub_stream_response.model_dump(exclude_none=True))
yield response_chunk yield response_chunk
@classmethod @classmethod
@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict())) response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
else: else:
response_chunk.update(sub_stream_response.model_dump()) response_chunk.update(sub_stream_response.model_dump(exclude_none=True))
yield response_chunk yield response_chunk

View File

@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.model_dump(mode="json")) response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
yield response_chunk yield response_chunk
@classmethod @classmethod
@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else: else:
response_chunk.update(sub_stream_response.model_dump(mode="json")) response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
yield response_chunk yield response_chunk

View File

@ -385,6 +385,7 @@ class WorkflowBasedAppRunner:
start_at=event.start_at, start_at=event.start_at,
in_iteration_id=event.in_iteration_id, in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id, in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
inputs=inputs, inputs=inputs,
process_data=process_data, process_data=process_data,
outputs=outputs, outputs=outputs,
@ -405,6 +406,7 @@ class WorkflowBasedAppRunner:
start_at=event.start_at, start_at=event.start_at,
in_iteration_id=event.in_iteration_id, in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id, in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
agent_strategy=event.agent_strategy, agent_strategy=event.agent_strategy,
provider_type=event.provider_type, provider_type=event.provider_type,
provider_id=event.provider_id, provider_id=event.provider_id,
@ -428,6 +430,7 @@ class WorkflowBasedAppRunner:
execution_metadata=execution_metadata, execution_metadata=execution_metadata,
in_iteration_id=event.in_iteration_id, in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id, in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
) )
) )
elif isinstance(event, NodeRunFailedEvent): elif isinstance(event, NodeRunFailedEvent):
@ -444,6 +447,7 @@ class WorkflowBasedAppRunner:
execution_metadata=event.node_run_result.metadata, execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id, in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id, in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
) )
) )
elif isinstance(event, NodeRunExceptionEvent): elif isinstance(event, NodeRunExceptionEvent):
@ -460,6 +464,7 @@ class WorkflowBasedAppRunner:
execution_metadata=event.node_run_result.metadata, execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id, in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id, in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
) )
) )
elif isinstance(event, NodeRunStreamChunkEvent): elif isinstance(event, NodeRunStreamChunkEvent):
@ -469,6 +474,7 @@ class WorkflowBasedAppRunner:
from_variable_selector=list(event.selector), from_variable_selector=list(event.selector),
in_iteration_id=event.in_iteration_id, in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id, in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
) )
) )
elif isinstance(event, NodeRunRetrieverResourceEvent): elif isinstance(event, NodeRunRetrieverResourceEvent):
@ -477,6 +483,7 @@ class WorkflowBasedAppRunner:
retriever_resources=event.retriever_resources, retriever_resources=event.retriever_resources,
in_iteration_id=event.in_iteration_id, in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id, in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
) )
) )
elif isinstance(event, NodeRunAgentLogEvent): elif isinstance(event, NodeRunAgentLogEvent):

View File

@ -190,6 +190,8 @@ class QueueTextChunkEvent(AppQueueEvent):
"""iteration id if node is in iteration""" """iteration id if node is in iteration"""
in_loop_id: str | None = None in_loop_id: str | None = None
"""loop id if node is in loop""" """loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
class QueueAgentMessageEvent(AppQueueEvent): class QueueAgentMessageEvent(AppQueueEvent):
@ -229,6 +231,8 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
"""iteration id if node is in iteration""" """iteration id if node is in iteration"""
in_loop_id: str | None = None in_loop_id: str | None = None
"""loop id if node is in loop""" """loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
class QueueAnnotationReplyEvent(AppQueueEvent): class QueueAnnotationReplyEvent(AppQueueEvent):
@ -306,6 +310,8 @@ class QueueNodeStartedEvent(AppQueueEvent):
node_run_index: int = 1 # FIXME(-LAN-): may not used node_run_index: int = 1 # FIXME(-LAN-): may not used
in_iteration_id: str | None = None in_iteration_id: str | None = None
in_loop_id: str | None = None in_loop_id: str | None = None
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
start_at: datetime start_at: datetime
agent_strategy: AgentNodeStrategyInit | None = None agent_strategy: AgentNodeStrategyInit | None = None
@ -328,6 +334,8 @@ class QueueNodeSucceededEvent(AppQueueEvent):
"""iteration id if node is in iteration""" """iteration id if node is in iteration"""
in_loop_id: str | None = None in_loop_id: str | None = None
"""loop id if node is in loop""" """loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
start_at: datetime start_at: datetime
inputs: Mapping[str, object] = Field(default_factory=dict) inputs: Mapping[str, object] = Field(default_factory=dict)
@ -383,6 +391,8 @@ class QueueNodeExceptionEvent(AppQueueEvent):
"""iteration id if node is in iteration""" """iteration id if node is in iteration"""
in_loop_id: str | None = None in_loop_id: str | None = None
"""loop id if node is in loop""" """loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
start_at: datetime start_at: datetime
inputs: Mapping[str, object] = Field(default_factory=dict) inputs: Mapping[str, object] = Field(default_factory=dict)
@ -407,6 +417,8 @@ class QueueNodeFailedEvent(AppQueueEvent):
"""iteration id if node is in iteration""" """iteration id if node is in iteration"""
in_loop_id: str | None = None in_loop_id: str | None = None
"""loop id if node is in loop""" """loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
start_at: datetime start_at: datetime
inputs: Mapping[str, object] = Field(default_factory=dict) inputs: Mapping[str, object] = Field(default_factory=dict)

View File

@ -262,6 +262,7 @@ class NodeStartStreamResponse(StreamResponse):
extras: dict[str, object] = Field(default_factory=dict) extras: dict[str, object] = Field(default_factory=dict)
iteration_id: str | None = None iteration_id: str | None = None
loop_id: str | None = None loop_id: str | None = None
mention_parent_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None agent_strategy: AgentNodeStrategyInit | None = None
event: StreamEvent = StreamEvent.NODE_STARTED event: StreamEvent = StreamEvent.NODE_STARTED
@ -285,6 +286,7 @@ class NodeStartStreamResponse(StreamResponse):
"extras": {}, "extras": {},
"iteration_id": self.data.iteration_id, "iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id, "loop_id": self.data.loop_id,
"mention_parent_id": self.data.mention_parent_id,
}, },
} }
@ -320,6 +322,7 @@ class NodeFinishStreamResponse(StreamResponse):
files: Sequence[Mapping[str, Any]] | None = [] files: Sequence[Mapping[str, Any]] | None = []
iteration_id: str | None = None iteration_id: str | None = None
loop_id: str | None = None loop_id: str | None = None
mention_parent_id: str | None = None
event: StreamEvent = StreamEvent.NODE_FINISHED event: StreamEvent = StreamEvent.NODE_FINISHED
workflow_run_id: str workflow_run_id: str
@ -349,6 +352,7 @@ class NodeFinishStreamResponse(StreamResponse):
"files": [], "files": [],
"iteration_id": self.data.iteration_id, "iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id, "loop_id": self.data.loop_id,
"mention_parent_id": self.data.mention_parent_id,
}, },
} }
@ -384,6 +388,7 @@ class NodeRetryStreamResponse(StreamResponse):
files: Sequence[Mapping[str, Any]] | None = [] files: Sequence[Mapping[str, Any]] | None = []
iteration_id: str | None = None iteration_id: str | None = None
loop_id: str | None = None loop_id: str | None = None
mention_parent_id: str | None = None
retry_index: int = 0 retry_index: int = 0
event: StreamEvent = StreamEvent.NODE_RETRY event: StreamEvent = StreamEvent.NODE_RETRY
@ -414,6 +419,7 @@ class NodeRetryStreamResponse(StreamResponse):
"files": [], "files": [],
"iteration_id": self.data.iteration_id, "iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id, "loop_id": self.data.loop_id,
"mention_parent_id": self.data.mention_parent_id,
"retry_index": self.data.retry_index, "retry_index": self.data.retry_index,
}, },
} }

View File

@ -5,7 +5,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.file import file_manager from core.file import file_manager
from core.file.models import File from core.file.models import File
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.base import BaseMemory
from core.model_runtime.entities import ( from core.model_runtime.entities import (
AssistantPromptMessage, AssistantPromptMessage,
PromptMessage, PromptMessage,
@ -43,7 +43,7 @@ class AdvancedPromptTransform(PromptTransform):
files: Sequence[File], files: Sequence[File],
context: str | None, context: str | None,
memory_config: MemoryConfig | None, memory_config: MemoryConfig | None,
memory: TokenBufferMemory | None, memory: BaseMemory | None,
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None, image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]: ) -> list[PromptMessage]:
@ -84,7 +84,7 @@ class AdvancedPromptTransform(PromptTransform):
files: Sequence[File], files: Sequence[File],
context: str | None, context: str | None,
memory_config: MemoryConfig | None, memory_config: MemoryConfig | None,
memory: TokenBufferMemory | None, memory: BaseMemory | None,
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None, image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]: ) -> list[PromptMessage]:
@ -145,7 +145,7 @@ class AdvancedPromptTransform(PromptTransform):
files: Sequence[File], files: Sequence[File],
context: str | None, context: str | None,
memory_config: MemoryConfig | None, memory_config: MemoryConfig | None,
memory: TokenBufferMemory | None, memory: BaseMemory | None,
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None, image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]: ) -> list[PromptMessage]:
@ -270,7 +270,7 @@ class AdvancedPromptTransform(PromptTransform):
def _set_histories_variable( def _set_histories_variable(
self, self,
memory: TokenBufferMemory, memory: BaseMemory,
memory_config: MemoryConfig, memory_config: MemoryConfig,
raw_prompt: str, raw_prompt: str,
role_prefix: MemoryConfig.RolePrefix, role_prefix: MemoryConfig.RolePrefix,

View File

@ -1,7 +1,7 @@
from typing import Any from typing import Any
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.base import BaseMemory
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.entities.model_entities import ModelPropertyKey
@ -11,7 +11,7 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig
class PromptTransform: class PromptTransform:
def _append_chat_histories( def _append_chat_histories(
self, self,
memory: TokenBufferMemory, memory: BaseMemory,
memory_config: MemoryConfig, memory_config: MemoryConfig,
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
@ -52,7 +52,7 @@ class PromptTransform:
def _get_history_messages_from_memory( def _get_history_messages_from_memory(
self, self,
memory: TokenBufferMemory, memory: BaseMemory,
memory_config: MemoryConfig, memory_config: MemoryConfig,
max_token_limit: int, max_token_limit: int,
human_prefix: str | None = None, human_prefix: str | None = None,
@ -73,7 +73,7 @@ class PromptTransform:
return memory.get_history_prompt_text(**kwargs) return memory.get_history_prompt_text(**kwargs)
def _get_history_messages_list_from_memory( def _get_history_messages_list_from_memory(
self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int self, memory: BaseMemory, memory_config: MemoryConfig, max_token_limit: int
) -> list[PromptMessage]: ) -> list[PromptMessage]:
"""Get memory messages.""" """Get memory messages."""
return list( return list(

View File

@ -253,6 +253,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
DATASOURCE_INFO = "datasource_info" DATASOURCE_INFO = "datasource_info"
COMPLETED_REASON = "completed_reason" # completed reason for loop node COMPLETED_REASON = "completed_reason" # completed reason for loop node
MENTION_PARENT_ID = "mention_parent_id" # parent node id for extractor nodes
class WorkflowNodeExecutionStatus(StrEnum): class WorkflowNodeExecutionStatus(StrEnum):

View File

@ -93,8 +93,8 @@ class EventHandler:
Args: Args:
event: The event to handle event: The event to handle
""" """
# Events in loops or iterations are always collected # Events in loops, iterations, or extractor groups are always collected
if event.in_loop_id or event.in_iteration_id: if event.in_loop_id or event.in_iteration_id or event.in_mention_parent_id:
self._event_collector.collect(event) self._event_collector.collect(event)
return return
return self._dispatch(event) return self._dispatch(event)

View File

@ -68,6 +68,7 @@ class _NodeRuntimeSnapshot:
predecessor_node_id: str | None predecessor_node_id: str | None
iteration_id: str | None iteration_id: str | None
loop_id: str | None loop_id: str | None
mention_parent_id: str | None
created_at: datetime created_at: datetime
@ -230,6 +231,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
metadata = { metadata = {
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id, WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id, WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
WorkflowNodeExecutionMetadataKey.MENTION_PARENT_ID: event.in_mention_parent_id,
} }
domain_execution = WorkflowNodeExecution( domain_execution = WorkflowNodeExecution(
@ -256,6 +258,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
predecessor_node_id=event.predecessor_node_id, predecessor_node_id=event.predecessor_node_id,
iteration_id=event.in_iteration_id, iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id, loop_id=event.in_loop_id,
mention_parent_id=event.in_mention_parent_id,
created_at=event.start_at, created_at=event.start_at,
) )
self._node_snapshots[event.id] = snapshot self._node_snapshots[event.id] = snapshot

View File

@ -21,6 +21,12 @@ class GraphNodeEventBase(GraphEngineEvent):
"""iteration id if node is in iteration""" """iteration id if node is in iteration"""
in_loop_id: str | None = None in_loop_id: str | None = None
"""loop id if node is in loop""" """loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""Parent node id if this is an extractor node event.
When set, indicates this event belongs to an extractor node that
is extracting values for the specified parent node.
"""
# The version of the node, or "1" if not specified. # The version of the node, or "1" if not specified.
node_version: str = "1" node_version: str = "1"

View File

@ -12,11 +12,14 @@ from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter from core.agent.plugin_entities import AgentStrategyParameter
from core.file import File, FileTransferMethod from core.file import File, FileTransferMethod
from core.memory.base import BaseMemory
from core.memory.node_token_buffer_memory import NodeTokenBufferMemory
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import MemoryMode
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ToolIdentity, ToolIdentity,
@ -136,6 +139,9 @@ class AgentNode(Node[AgentNodeData]):
) )
return return
# Fetch memory for node memory saving
memory = self._fetch_memory_for_save()
try: try:
yield from self._transform_message( yield from self._transform_message(
messages=message_stream, messages=message_stream,
@ -149,6 +155,7 @@ class AgentNode(Node[AgentNodeData]):
node_type=self.node_type, node_type=self.node_type,
node_id=self._node_id, node_id=self._node_id,
node_execution_id=self.id, node_execution_id=self.id,
memory=memory,
) )
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
transform_error = AgentMessageTransformError( transform_error = AgentMessageTransformError(
@ -395,8 +402,20 @@ class AgentNode(Node[AgentNodeData]):
icon = None icon = None
return icon return icon
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None: def _fetch_memory(self, model_instance: ModelInstance) -> BaseMemory | None:
# get conversation id """
Fetch memory based on configuration mode.
Returns TokenBufferMemory for conversation mode (default),
or NodeTokenBufferMemory for node mode (Chatflow only).
"""
node_data = self.node_data
memory_config = node_data.memory
if not memory_config:
return None
# get conversation id (required for both modes in Chatflow)
conversation_id_variable = self.graph_runtime_state.variable_pool.get( conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID] ["sys", SystemVariableKey.CONVERSATION_ID]
) )
@ -404,16 +423,26 @@ class AgentNode(Node[AgentNodeData]):
return None return None
conversation_id = conversation_id_variable.value conversation_id = conversation_id_variable.value
with Session(db.engine, expire_on_commit=False) as session: # Return appropriate memory type based on mode
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id) if memory_config.mode == MemoryMode.NODE:
conversation = session.scalar(stmt) # Node-level memory (Chatflow only)
return NodeTokenBufferMemory(
if not conversation: app_id=self.app_id,
return None conversation_id=conversation_id,
node_id=self._node_id,
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) tenant_id=self.tenant_id,
model_instance=model_instance,
return memory )
else:
# Conversation-level memory (default)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(
Conversation.app_id == self.app_id, Conversation.id == conversation_id
)
conversation = session.scalar(stmt)
if not conversation:
return None
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]: def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
provider_manager = ProviderManager() provider_manager = ProviderManager()
@ -457,6 +486,47 @@ class AgentNode(Node[AgentNodeData]):
else: else:
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP] return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
def _fetch_memory_for_save(self) -> BaseMemory | None:
"""
Fetch memory instance for saving node memory.
This is a simplified version that doesn't require model_instance.
"""
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
node_data = self.node_data
if not node_data.memory:
return None
# Get conversation_id
conversation_id_var = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
if not isinstance(conversation_id_var, StringSegment):
return None
conversation_id = conversation_id_var.value
# Return appropriate memory type based on mode
if node_data.memory.mode == MemoryMode.NODE:
# For node memory, we need a model_instance for token counting
# Use a simple default model for this purpose
try:
model_instance = ModelManager().get_default_model_instance(
tenant_id=self.tenant_id,
model_type=ModelType.LLM,
)
except Exception:
return None
return NodeTokenBufferMemory(
app_id=self.app_id,
conversation_id=conversation_id,
node_id=self._node_id,
tenant_id=self.tenant_id,
model_instance=model_instance,
)
else:
# Conversation-level memory doesn't need saving here
return None
def _transform_message( def _transform_message(
self, self,
messages: Generator[ToolInvokeMessage, None, None], messages: Generator[ToolInvokeMessage, None, None],
@ -467,6 +537,7 @@ class AgentNode(Node[AgentNodeData]):
node_type: NodeType, node_type: NodeType,
node_id: str, node_id: str,
node_execution_id: str, node_execution_id: str,
memory: BaseMemory | None = None,
) -> Generator[NodeEventBase, None, None]: ) -> Generator[NodeEventBase, None, None]:
""" """
Convert ToolInvokeMessages into tuple[plain_text, files] Convert ToolInvokeMessages into tuple[plain_text, files]
@ -711,6 +782,21 @@ class AgentNode(Node[AgentNodeData]):
is_final=True, is_final=True,
) )
# Save to node memory if in node memory mode
from core.workflow.nodes.llm import llm_utils
# Get user query from sys.query
user_query_var = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.QUERY])
user_query = user_query_var.text if user_query_var else ""
llm_utils.save_node_memory(
memory=memory,
variable_pool=self.graph_runtime_state.variable_pool,
user_query=user_query,
assistant_response=text,
assistant_files=files,
)
yield StreamCompletedEvent( yield StreamCompletedEvent(
node_run_result=NodeRunResult( node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,

View File

@ -332,12 +332,17 @@ class Node(Generic[NodeDataT]):
# Execute and process extractor node events # Execute and process extractor node events
for event in extractor_node.run(): for event in extractor_node.run():
# Tag event with parent node id for stream ordering and history tracking
if isinstance(event, GraphNodeEventBase):
event.in_mention_parent_id = self._node_id
if isinstance(event, NodeRunSucceededEvent): if isinstance(event, NodeRunSucceededEvent):
# Store extractor node outputs in variable pool # Store extractor node outputs in variable pool
outputs = event.node_run_result.outputs outputs: Mapping[str, Any] = event.node_run_result.outputs
for variable_name, variable_value in outputs.items(): for variable_name, variable_value in outputs.items():
self.graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value) self.graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
yield event if not isinstance(event, NodeRunStreamChunkEvent):
yield event
def run(self) -> Generator[GraphNodeEventBase, None, None]: def run(self) -> Generator[GraphNodeEventBase, None, None]:
execution_id = self.ensure_execution_id() execution_id = self.ensure_execution_id()

View File

@ -139,6 +139,50 @@ def fetch_memory(
return TokenBufferMemory(conversation=conversation, model_instance=model_instance) return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
def save_node_memory(
memory: BaseMemory | None,
variable_pool: VariablePool,
user_query: str,
assistant_response: str,
user_files: Sequence["File"] | None = None,
assistant_files: Sequence["File"] | None = None,
) -> None:
"""
Save dialogue turn to node memory if applicable.
This function handles the storage logic for NodeTokenBufferMemory.
For TokenBufferMemory (conversation-level), no action is taken as it uses
the Message table which is managed elsewhere.
:param memory: Memory instance (NodeTokenBufferMemory or TokenBufferMemory)
:param variable_pool: Variable pool containing system variables
:param user_query: User's input text
:param assistant_response: Assistant's response text
:param user_files: Files attached by user (optional)
:param assistant_files: Files generated by assistant (optional)
"""
if not isinstance(memory, NodeTokenBufferMemory):
return
# Get workflow_run_id as the key for this execution
workflow_run_id_var = variable_pool.get(["sys", SystemVariableKey.WORKFLOW_EXECUTION_ID])
if not isinstance(workflow_run_id_var, StringSegment):
return
workflow_run_id = workflow_run_id_var.value
if not workflow_run_id:
return
memory.add_messages(
workflow_run_id=workflow_run_id,
user_content=user_query,
user_files=list(user_files) if user_files else None,
assistant_content=assistant_response,
assistant_files=list(assistant_files) if assistant_files else None,
)
memory.flush()
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage): def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
provider_model_bundle = model_instance.provider_model_bundle provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration provider_configuration = provider_model_bundle.configuration

View File

@ -17,7 +17,6 @@ from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.base import BaseMemory from core.memory.base import BaseMemory
from core.memory.node_token_buffer_memory import NodeTokenBufferMemory
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import ( from core.model_runtime.entities import (
ImagePromptMessageContent, ImagePromptMessageContent,
@ -334,32 +333,16 @@ class LLMNode(Node[LLMNodeData]):
outputs["files"] = ArrayFileSegment(value=self._file_outputs) outputs["files"] = ArrayFileSegment(value=self._file_outputs)
# Write to Node Memory if in node memory mode # Write to Node Memory if in node memory mode
if isinstance(memory, NodeTokenBufferMemory): # Resolve the query template to get actual user content
# Get workflow_run_id as the key for this execution actual_query = variable_pool.convert_template(query or "").text
workflow_run_id_var = variable_pool.get(["sys", SystemVariableKey.WORKFLOW_EXECUTION_ID]) llm_utils.save_node_memory(
workflow_run_id = workflow_run_id_var.value if isinstance(workflow_run_id_var, StringSegment) else "" memory=memory,
variable_pool=variable_pool,
if workflow_run_id: user_query=actual_query,
# Resolve the query template to get actual user content assistant_response=clean_text,
# query may be a template like "{{#sys.query#}}" or "{{#node_id.output#}}" user_files=files,
actual_query = variable_pool.convert_template(query or "").text assistant_files=self._file_outputs,
)
# Get user files from sys.files
user_files_var = variable_pool.get(["sys", SystemVariableKey.FILES])
user_files: list[File] = []
if isinstance(user_files_var, ArrayFileSegment):
user_files = list(user_files_var.value)
elif isinstance(user_files_var, FileSegment):
user_files = [user_files_var.value]
memory.add_messages(
workflow_run_id=workflow_run_id,
user_content=actual_query,
user_files=user_files,
assistant_content=clean_text,
assistant_files=self._file_outputs,
)
memory.flush()
# Send final chunk event to indicate streaming is complete # Send final chunk event to indicate streaming is complete
yield StreamChunkEvent( yield StreamChunkEvent(

View File

@ -7,7 +7,7 @@ from typing import Any, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import File from core.file import File
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.base import BaseMemory
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities import ImagePromptMessageContent from core.model_runtime.entities import ImagePromptMessageContent
from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.llm_entities import LLMUsage
@ -145,8 +145,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
memory = llm_utils.fetch_memory( memory = llm_utils.fetch_memory(
variable_pool=variable_pool, variable_pool=variable_pool,
app_id=self.app_id, app_id=self.app_id,
tenant_id=self.tenant_id,
node_data_memory=node_data.memory, node_data_memory=node_data.memory,
model_instance=model_instance, model_instance=model_instance,
node_id=self._node_id,
) )
if ( if (
@ -244,6 +246,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
# transform result into standard format # transform result into standard format
result = self._transform_result(data=node_data, result=result or {}) result = self._transform_result(data=node_data, result=result or {})
# Save to node memory if in node memory mode
llm_utils.save_node_memory(
memory=memory,
variable_pool=variable_pool,
user_query=query,
assistant_response=json.dumps(result, ensure_ascii=False),
)
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs, inputs=inputs,
@ -299,7 +309,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str, query: str,
variable_pool: VariablePool, variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
memory: TokenBufferMemory | None, memory: BaseMemory | None,
files: Sequence[File], files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None, vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> tuple[list[PromptMessage], list[PromptMessageTool]]: ) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
@ -381,7 +391,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str, query: str,
variable_pool: VariablePool, variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
memory: TokenBufferMemory | None, memory: BaseMemory | None,
files: Sequence[File], files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None, vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]: ) -> list[PromptMessage]:
@ -419,7 +429,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str, query: str,
variable_pool: VariablePool, variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
memory: TokenBufferMemory | None, memory: BaseMemory | None,
files: Sequence[File], files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None, vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]: ) -> list[PromptMessage]:
@ -453,7 +463,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str, query: str,
variable_pool: VariablePool, variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
memory: TokenBufferMemory | None, memory: BaseMemory | None,
files: Sequence[File], files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None, vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]: ) -> list[PromptMessage]:
@ -681,7 +691,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData, node_data: ParameterExtractorNodeData,
query: str, query: str,
variable_pool: VariablePool, variable_pool: VariablePool,
memory: TokenBufferMemory | None, memory: BaseMemory | None,
max_token_limit: int = 2000, max_token_limit: int = 2000,
) -> list[ChatModelMessage]: ) -> list[ChatModelMessage]:
model_mode = ModelMode(node_data.model.mode) model_mode = ModelMode(node_data.model.mode)
@ -708,7 +718,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData, node_data: ParameterExtractorNodeData,
query: str, query: str,
variable_pool: VariablePool, variable_pool: VariablePool,
memory: TokenBufferMemory | None, memory: BaseMemory | None,
max_token_limit: int = 2000, max_token_limit: int = 2000,
): ):
model_mode = ModelMode(node_data.model.mode) model_mode = ModelMode(node_data.model.mode)

View File

@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.base import BaseMemory
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
@ -96,8 +96,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
memory = llm_utils.fetch_memory( memory = llm_utils.fetch_memory(
variable_pool=variable_pool, variable_pool=variable_pool,
app_id=self.app_id, app_id=self.app_id,
tenant_id=self.tenant_id,
node_data_memory=node_data.memory, node_data_memory=node_data.memory,
model_instance=model_instance, model_instance=model_instance,
node_id=self._node_id,
) )
# fetch instruction # fetch instruction
node_data.instruction = node_data.instruction or "" node_data.instruction = node_data.instruction or ""
@ -203,6 +205,14 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
"usage": jsonable_encoder(usage), "usage": jsonable_encoder(usage),
} }
# Save to node memory if in node memory mode
llm_utils.save_node_memory(
memory=memory,
variable_pool=variable_pool,
user_query=query or "",
assistant_response=f"class_name: {category_name}, class_id: {category_id}",
)
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables, inputs=variables,
@ -312,7 +322,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self, self,
node_data: QuestionClassifierNodeData, node_data: QuestionClassifierNodeData,
query: str, query: str,
memory: TokenBufferMemory | None, memory: BaseMemory | None,
max_token_limit: int = 2000, max_token_limit: int = 2000,
): ):
model_mode = ModelMode(node_data.model.mode) model_mode = ModelMode(node_data.model.mode)

View File

@ -1,22 +1,26 @@
import re
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, Literal, Union from typing import Any, Literal, Self, Union
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator, model_validator
from pydantic_core.core_schema import ValidationInfo from pydantic_core.core_schema import ValidationInfo
from core.tools.entities.tool_entities import ToolProviderType from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.base.entities import BaseNodeData
# Pattern to match a single variable reference like {{#llm.context#}}
SINGLE_VARIABLE_PATTERN = re.compile(r"^\s*\{\{#[a-zA-Z0-9_]+(?:\.[a-zA-Z_][a-zA-Z0-9_]*)+#\}\}\s*$")
class MentionValue(BaseModel):
"""Value structure for mention type parameters.
Used when a tool parameter needs to be extracted from conversation context class MentionConfig(BaseModel):
using an extractor LLM node. """Configuration for extracting value from context variable.
Used when a tool parameter needs to be extracted from list[PromptMessage]
context using an extractor LLM node.
""" """
# Variable selector for list[PromptMessage] input to extractor # Instruction for the extractor LLM to extract the value
variable_selector: Sequence[str] instruction: str
# ID of the extractor LLM node # ID of the extractor LLM node
extractor_node_id: str extractor_node_id: str
@ -60,8 +64,10 @@ class ToolEntity(BaseModel):
class ToolNodeData(BaseNodeData, ToolEntity): class ToolNodeData(BaseNodeData, ToolEntity):
class ToolInput(BaseModel): class ToolInput(BaseModel):
# TODO: check this type # TODO: check this type
value: Union[Any, list[str], MentionValue] value: Union[Any, list[str]]
type: Literal["mixed", "variable", "constant", "mention"] type: Literal["mixed", "variable", "constant", "mention"]
# Required config for mention type, extracting value from context variable
mention_config: MentionConfig | None = None
@field_validator("type", mode="before") @field_validator("type", mode="before")
@classmethod @classmethod
@ -74,6 +80,9 @@ class ToolNodeData(BaseNodeData, ToolEntity):
if typ == "mixed" and not isinstance(value, str): if typ == "mixed" and not isinstance(value, str):
raise ValueError("value must be a string") raise ValueError("value must be a string")
elif typ == "mention":
# Skip here, will be validated in model_validator
pass
elif typ == "variable": elif typ == "variable":
if not isinstance(value, list): if not isinstance(value, list):
raise ValueError("value must be a list") raise ValueError("value must be a list")
@ -82,19 +91,31 @@ class ToolNodeData(BaseNodeData, ToolEntity):
raise ValueError("value must be a list of strings") raise ValueError("value must be a list of strings")
elif typ == "constant" and not isinstance(value, str | int | float | bool | dict): elif typ == "constant" and not isinstance(value, str | int | float | bool | dict):
raise ValueError("value must be a string, int, float, bool or dict") raise ValueError("value must be a string, int, float, bool or dict")
elif typ == "mention":
# Mention type: value should be a MentionValue or dict with required fields
if isinstance(value, MentionValue):
pass # Already validated by Pydantic
elif isinstance(value, dict):
if "extractor_node_id" not in value:
raise ValueError("value must contain extractor_node_id for mention type")
if "output_selector" not in value:
raise ValueError("value must contain output_selector for mention type")
else:
raise ValueError("value must be a MentionValue or dict for mention type")
return typ return typ
@model_validator(mode="after")
def check_mention_type(self) -> Self:
"""Validate mention type with mention_config."""
if self.type != "mention":
return self
value = self.value
if value is None:
return self
if not isinstance(value, str):
raise ValueError("value must be a string for mention type")
# For mention type, value must be a single variable reference
if not SINGLE_VARIABLE_PATTERN.match(value):
raise ValueError(
"For mention type, value must be a single variable reference "
"like {{#node.variable#}}, cannot contain other content"
)
# mention_config is required for mention type
if self.mention_config is None:
raise ValueError("mention_config is required for mention type")
return self
tool_parameters: dict[str, ToolInput] tool_parameters: dict[str, ToolInput]
# The version of the tool parameter. # The version of the tool parameter.
# If this value is None, it indicates this is a previous version # If this value is None, it indicates this is a previous version

View File

@ -214,16 +214,11 @@ class ToolNode(Node[ToolNodeData]):
parameter_value = variable.value parameter_value = variable.value
elif tool_input.type == "mention": elif tool_input.type == "mention":
# Mention type: get value from extractor node's output # Mention type: get value from extractor node's output
from .entities import MentionValue if tool_input.mention_config is None:
raise ToolParameterError(
mention_value = tool_input.value f"mention_config is required for mention type parameter '{parameter_name}'"
if isinstance(mention_value, MentionValue): )
mention_config = mention_value.model_dump() mention_config = tool_input.mention_config.model_dump()
elif isinstance(mention_value, dict):
mention_config = mention_value
else:
raise ToolParameterError(f"Invalid mention value for parameter '{parameter_name}'")
try: try:
parameter_value, found = variable_pool.resolve_mention( parameter_value, found = variable_pool.resolve_mention(
mention_config, parameter_name=parameter_name mention_config, parameter_name=parameter_name
@ -524,7 +519,7 @@ class ToolNode(Node[ToolNodeData]):
selector_key = ".".join(input.value) selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value result[f"#{selector_key}#"] = input.value
elif input.type == "mention": elif input.type == "mention":
# Mention type handled by extractor node, no direct variable reference # Mention type: value is handled by extractor node, no direct variable reference
pass pass
elif input.type == "constant": elif input.type == "constant":
pass pass

View File

@ -11,6 +11,11 @@ dependencies:
value: value:
marketplace_plugin_unique_identifier: langgenius/google:0.0.8@3efcf55ffeef9d0f77715e0afb23534952ae0cb385c051d0637e86d71199d1a6 marketplace_plugin_unique_identifier: langgenius/google:0.0.8@3efcf55ffeef9d0f77715e0afb23534952ae0cb385c051d0637e86d71199d1a6
version: null version: null
- current_identifier: null
type: marketplace
value:
marketplace_plugin_unique_identifier: langgenius/openai:0.2.3@5a7f82fa86e28332ad51941d0b491c1e8a38ead539656442f7bf4c6129cd15fa
version: null
- current_identifier: null - current_identifier: null
type: marketplace type: marketplace
value: value:
@ -115,7 +120,8 @@ workflow:
enabled: false enabled: false
variable_selector: [] variable_selector: []
memory: memory:
query_prompt_template: '' mode: node
query_prompt_template: '{{#sys.query#}}'
role_prefix: role_prefix:
assistant: '' assistant: ''
user: '' user: ''
@ -201,29 +207,17 @@ workflow:
tool_node_version: '2' tool_node_version: '2'
tool_parameters: tool_parameters:
query: query:
type: variable mention_config:
value: default_value: ''
- ext_1 extractor_node_id: 1767773709491_ext_query
- text instruction: 请从对话历史中提取用户想要搜索的关键词,只返回关键词本身
null_strategy: use_default
output_selector:
- structured_output
- query
type: mention
value: '{{#llm.context#}}'
type: tool type: tool
virtual_nodes:
- data:
model:
completion_params:
temperature: 0.7
mode: chat
name: qwen-max
provider: langgenius/tongyi/tongyi
context:
enabled: false
prompt_template:
- role: user
text: '{{#llm.context#}}'
- role: user
text: 请从对话历史中提取用户想要搜索的关键词,只返回关键词本身,不要返回其他内容
title: 提取搜索关键词
id: ext_1
type: llm
height: 52 height: 52
id: '1767773709491' id: '1767773709491'
position: position:
@ -237,6 +231,54 @@ workflow:
targetPosition: left targetPosition: left
type: custom type: custom
width: 242 width: 242
- data:
context:
enabled: false
variable_selector: []
model:
completion_params:
temperature: 0.7
mode: chat
name: gpt-4o-mini
provider: langgenius/openai/openai
parent_node_id: '1767773709491'
prompt_template:
- $context:
- llm
- context
id: 75d58e22-dc59-40c8-ba6f-aeb28f4f305a
- id: 18ba6710-77f5-47f4-b144-9191833bb547
role: user
text: 请从对话历史中提取用户想要搜索的关键词,只返回关键词本身,不要返回其他内容
selected: false
structured_output:
schema:
additionalProperties: false
properties:
query:
description: 搜索的关键词
type: string
required:
- query
type: object
structured_output_enabled: true
title: 提取搜索关键词
type: llm
vision:
enabled: false
height: 88
id: 1767773709491_ext_query
position:
x: 531
y: 382
positionAbsolute:
x: 531
y: 382
selected: true
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data: - data:
answer: '搜索结果: answer: '搜索结果:
@ -254,13 +296,13 @@ workflow:
positionAbsolute: positionAbsolute:
x: 984 x: 984
y: 282 y: 282
selected: true selected: false
sourcePosition: right sourcePosition: right
targetPosition: left targetPosition: left
type: custom type: custom
width: 242 width: 242
viewport: viewport:
x: 151 x: -151
y: 141.5 y: 123
zoom: 1 zoom: 1
rag_pipeline_variables: [] rag_pipeline_variables: []