Mirror of https://github.com/langgenius/dify.git

Commit dbed937fc6
Merge remote-tracking branch 'origin/feat/pull-a-variable' into feat/pull-a-variable
@@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
yield response_chunk

@classmethod
@@ -110,7 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
}

if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
@@ -120,6 +120,6 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))

yield response_chunk
@@ -81,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
yield response_chunk

@classmethod
@@ -109,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
}

if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
@@ -117,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))

yield response_chunk
@@ -81,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
yield response_chunk

@classmethod
@@ -109,7 +109,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
}

if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
@@ -117,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))

yield response_chunk
@@ -70,6 +70,8 @@ class _NodeSnapshot:
"""Empty string means the node is not executing inside an iteration."""
loop_id: str = ""
"""Empty string means the node is not executing inside a loop."""
mention_parent_id: str = ""
"""Empty string means the node is not an extractor node."""

class WorkflowResponseConverter:
@@ -131,6 +133,7 @@ class WorkflowResponseConverter:
start_at=event.start_at,
iteration_id=event.in_iteration_id or "",
loop_id=event.in_loop_id or "",
mention_parent_id=event.in_mention_parent_id or "",
)
node_execution_id = NodeExecutionId(event.node_execution_id)
self._node_snapshots[node_execution_id] = snapshot
@@ -287,6 +290,7 @@ class WorkflowResponseConverter:
created_at=int(snapshot.start_at.timestamp()),
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
mention_parent_id=event.in_mention_parent_id,
agent_strategy=event.agent_strategy,
),
)
@@ -373,6 +377,7 @@ class WorkflowResponseConverter:
files=self.fetch_files_from_node_outputs(event.outputs or {}),
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
mention_parent_id=event.in_mention_parent_id,
),
)

@@ -422,6 +427,7 @@ class WorkflowResponseConverter:
files=self.fetch_files_from_node_outputs(event.outputs or {}),
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
mention_parent_id=event.in_mention_parent_id,
retry_index=event.retry_index,
),
)
@@ -79,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
yield response_chunk

@classmethod
@@ -106,7 +106,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
}

if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
metadata = sub_stream_response_dict.get("metadata", {})
if not isinstance(metadata, dict):
metadata = {}
@@ -116,6 +116,6 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))

yield response_chunk
@@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(cast(dict, data))
else:
response_chunk.update(sub_stream_response.model_dump())
response_chunk.update(sub_stream_response.model_dump(exclude_none=True))
yield response_chunk

@classmethod
@@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
else:
response_chunk.update(sub_stream_response.model_dump())
response_chunk.update(sub_stream_response.model_dump(exclude_none=True))
yield response_chunk
@@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
yield response_chunk

@classmethod
@@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
yield response_chunk
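Note: the recurring change across these converters swaps model_dump(mode="json") for model_dump(mode="json", exclude_none=True), so fields that are unset (for example loop_id or mention_parent_id on a node outside a loop or extractor group) are dropped from the streamed chunk instead of being emitted as null. A minimal sketch of the effect, using a hypothetical pydantic model rather than Dify's actual StreamResponse classes:

from pydantic import BaseModel


class ChunkSketch(BaseModel):
    # Hypothetical stand-in; field names only mirror the ones touched in this diff.
    event: str = "node_started"
    iteration_id: str | None = None
    loop_id: str | None = None
    mention_parent_id: str | None = None


chunk = ChunkSketch()
print(chunk.model_dump(mode="json"))
# {'event': 'node_started', 'iteration_id': None, 'loop_id': None, 'mention_parent_id': None}
print(chunk.model_dump(mode="json", exclude_none=True))
# {'event': 'node_started'}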
@@ -385,6 +385,7 @@ class WorkflowBasedAppRunner:
start_at=event.start_at,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
inputs=inputs,
process_data=process_data,
outputs=outputs,
@@ -405,6 +406,7 @@ class WorkflowBasedAppRunner:
start_at=event.start_at,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
agent_strategy=event.agent_strategy,
provider_type=event.provider_type,
provider_id=event.provider_id,
@@ -428,6 +430,7 @@ class WorkflowBasedAppRunner:
execution_metadata=execution_metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
)
)
elif isinstance(event, NodeRunFailedEvent):
@@ -444,6 +447,7 @@ class WorkflowBasedAppRunner:
execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
)
)
elif isinstance(event, NodeRunExceptionEvent):
@@ -460,6 +464,7 @@ class WorkflowBasedAppRunner:
execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
@@ -469,6 +474,7 @@ class WorkflowBasedAppRunner:
from_variable_selector=list(event.selector),
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
)
)
elif isinstance(event, NodeRunRetrieverResourceEvent):
@@ -477,6 +483,7 @@ class WorkflowBasedAppRunner:
retriever_resources=event.retriever_resources,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
)
)
elif isinstance(event, NodeRunAgentLogEvent):
@@ -190,6 +190,8 @@ class QueueTextChunkEvent(AppQueueEvent):
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""

class QueueAgentMessageEvent(AppQueueEvent):
@@ -229,6 +231,8 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""

class QueueAnnotationReplyEvent(AppQueueEvent):
@@ -306,6 +310,8 @@ class QueueNodeStartedEvent(AppQueueEvent):
node_run_index: int = 1 # FIXME(-LAN-): may not used
in_iteration_id: str | None = None
in_loop_id: str | None = None
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
start_at: datetime
agent_strategy: AgentNodeStrategyInit | None = None

@@ -328,6 +334,8 @@ class QueueNodeSucceededEvent(AppQueueEvent):
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
start_at: datetime

inputs: Mapping[str, object] = Field(default_factory=dict)
@@ -383,6 +391,8 @@ class QueueNodeExceptionEvent(AppQueueEvent):
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
start_at: datetime

inputs: Mapping[str, object] = Field(default_factory=dict)
@@ -407,6 +417,8 @@ class QueueNodeFailedEvent(AppQueueEvent):
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
start_at: datetime

inputs: Mapping[str, object] = Field(default_factory=dict)
@@ -262,6 +262,7 @@ class NodeStartStreamResponse(StreamResponse):
extras: dict[str, object] = Field(default_factory=dict)
iteration_id: str | None = None
loop_id: str | None = None
mention_parent_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None

event: StreamEvent = StreamEvent.NODE_STARTED
@@ -285,6 +286,7 @@ class NodeStartStreamResponse(StreamResponse):
"extras": {},
"iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id,
"mention_parent_id": self.data.mention_parent_id,
},
}

@@ -320,6 +322,7 @@ class NodeFinishStreamResponse(StreamResponse):
files: Sequence[Mapping[str, Any]] | None = []
iteration_id: str | None = None
loop_id: str | None = None
mention_parent_id: str | None = None

event: StreamEvent = StreamEvent.NODE_FINISHED
workflow_run_id: str
@@ -349,6 +352,7 @@ class NodeFinishStreamResponse(StreamResponse):
"files": [],
"iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id,
"mention_parent_id": self.data.mention_parent_id,
},
}

@@ -384,6 +388,7 @@ class NodeRetryStreamResponse(StreamResponse):
files: Sequence[Mapping[str, Any]] | None = []
iteration_id: str | None = None
loop_id: str | None = None
mention_parent_id: str | None = None
retry_index: int = 0

event: StreamEvent = StreamEvent.NODE_RETRY
@@ -414,6 +419,7 @@ class NodeRetryStreamResponse(StreamResponse):
"files": [],
"iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id,
"mention_parent_id": self.data.mention_parent_id,
"retry_index": self.data.retry_index,
},
}
@@ -5,7 +5,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.file import file_manager
from core.file.models import File
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.memory.base import BaseMemory
from core.model_runtime.entities import (
AssistantPromptMessage,
PromptMessage,
@@ -43,7 +43,7 @@ class AdvancedPromptTransform(PromptTransform):
files: Sequence[File],
context: str | None,
memory_config: MemoryConfig | None,
memory: TokenBufferMemory | None,
memory: BaseMemory | None,
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@@ -84,7 +84,7 @@ class AdvancedPromptTransform(PromptTransform):
files: Sequence[File],
context: str | None,
memory_config: MemoryConfig | None,
memory: TokenBufferMemory | None,
memory: BaseMemory | None,
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@@ -145,7 +145,7 @@ class AdvancedPromptTransform(PromptTransform):
files: Sequence[File],
context: str | None,
memory_config: MemoryConfig | None,
memory: TokenBufferMemory | None,
memory: BaseMemory | None,
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@@ -270,7 +270,7 @@ class AdvancedPromptTransform(PromptTransform):

def _set_histories_variable(
self,
memory: TokenBufferMemory,
memory: BaseMemory,
memory_config: MemoryConfig,
raw_prompt: str,
role_prefix: MemoryConfig.RolePrefix,
@@ -1,7 +1,7 @@
from typing import Any

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
@@ -11,7 +11,7 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig
class PromptTransform:
def _append_chat_histories(
self,
memory: TokenBufferMemory,
memory: BaseMemory,
memory_config: MemoryConfig,
prompt_messages: list[PromptMessage],
model_config: ModelConfigWithCredentialsEntity,
@@ -52,7 +52,7 @@ class PromptTransform:

def _get_history_messages_from_memory(
self,
memory: TokenBufferMemory,
memory: BaseMemory,
memory_config: MemoryConfig,
max_token_limit: int,
human_prefix: str | None = None,
@@ -73,7 +73,7 @@ class PromptTransform:
return memory.get_history_prompt_text(**kwargs)

def _get_history_messages_list_from_memory(
self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
self, memory: BaseMemory, memory_config: MemoryConfig, max_token_limit: int
) -> list[PromptMessage]:
"""Get memory messages."""
return list(
@@ -253,6 +253,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
DATASOURCE_INFO = "datasource_info"
COMPLETED_REASON = "completed_reason" # completed reason for loop node
MENTION_PARENT_ID = "mention_parent_id" # parent node id for extractor nodes

class WorkflowNodeExecutionStatus(StrEnum):
@@ -93,8 +93,8 @@ class EventHandler:
Args:
event: The event to handle
"""
# Events in loops or iterations are always collected
if event.in_loop_id or event.in_iteration_id:
# Events in loops, iterations, or extractor groups are always collected
if event.in_loop_id or event.in_iteration_id or event.in_mention_parent_id:
self._event_collector.collect(event)
return
return self._dispatch(event)
@@ -68,6 +68,7 @@ class _NodeRuntimeSnapshot:
predecessor_node_id: str | None
iteration_id: str | None
loop_id: str | None
mention_parent_id: str | None
created_at: datetime

@@ -230,6 +231,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
metadata = {
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
WorkflowNodeExecutionMetadataKey.MENTION_PARENT_ID: event.in_mention_parent_id,
}

domain_execution = WorkflowNodeExecution(
@@ -256,6 +258,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
predecessor_node_id=event.predecessor_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
mention_parent_id=event.in_mention_parent_id,
created_at=event.start_at,
)
self._node_snapshots[event.id] = snapshot
@@ -21,6 +21,12 @@ class GraphNodeEventBase(GraphEngineEvent):
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""Parent node id if this is an extractor node event.

When set, indicates this event belongs to an extractor node that
is extracting values for the specified parent node.
"""

# The version of the node, or "1" if not specified.
node_version: str = "1"
@@ -12,11 +12,14 @@ from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.file import File, FileTransferMethod
from core.memory.base import BaseMemory
from core.memory.node_token_buffer_memory import NodeTokenBufferMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import MemoryMode
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import (
ToolIdentity,
@@ -136,6 +139,9 @@ class AgentNode(Node[AgentNodeData]):
)
return

# Fetch memory for node memory saving
memory = self._fetch_memory_for_save()

try:
yield from self._transform_message(
messages=message_stream,
@@ -149,6 +155,7 @@ class AgentNode(Node[AgentNodeData]):
node_type=self.node_type,
node_id=self._node_id,
node_execution_id=self.id,
memory=memory,
)
except PluginDaemonClientSideError as e:
transform_error = AgentMessageTransformError(
@@ -395,8 +402,20 @@ class AgentNode(Node[AgentNodeData]):
icon = None
return icon

def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
# get conversation id
def _fetch_memory(self, model_instance: ModelInstance) -> BaseMemory | None:
"""
Fetch memory based on configuration mode.

Returns TokenBufferMemory for conversation mode (default),
or NodeTokenBufferMemory for node mode (Chatflow only).
"""
node_data = self.node_data
memory_config = node_data.memory

if not memory_config:
return None

# get conversation id (required for both modes in Chatflow)
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID]
)
@@ -404,16 +423,26 @@ class AgentNode(Node[AgentNodeData]):
return None
conversation_id = conversation_id_variable.value

with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)

if not conversation:
return None

memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

return memory
# Return appropriate memory type based on mode
if memory_config.mode == MemoryMode.NODE:
# Node-level memory (Chatflow only)
return NodeTokenBufferMemory(
app_id=self.app_id,
conversation_id=conversation_id,
node_id=self._node_id,
tenant_id=self.tenant_id,
model_instance=model_instance,
)
else:
# Conversation-level memory (default)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(
Conversation.app_id == self.app_id, Conversation.id == conversation_id
)
conversation = session.scalar(stmt)
if not conversation:
return None
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)

def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
provider_manager = ProviderManager()
@@ -457,6 +486,47 @@ class AgentNode(Node[AgentNodeData]):
else:
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]

def _fetch_memory_for_save(self) -> BaseMemory | None:
"""
Fetch memory instance for saving node memory.
This is a simplified version that doesn't require model_instance.
"""
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType

node_data = self.node_data
if not node_data.memory:
return None

# Get conversation_id
conversation_id_var = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
if not isinstance(conversation_id_var, StringSegment):
return None
conversation_id = conversation_id_var.value

# Return appropriate memory type based on mode
if node_data.memory.mode == MemoryMode.NODE:
# For node memory, we need a model_instance for token counting
# Use a simple default model for this purpose
try:
model_instance = ModelManager().get_default_model_instance(
tenant_id=self.tenant_id,
model_type=ModelType.LLM,
)
except Exception:
return None

return NodeTokenBufferMemory(
app_id=self.app_id,
conversation_id=conversation_id,
node_id=self._node_id,
tenant_id=self.tenant_id,
model_instance=model_instance,
)
else:
# Conversation-level memory doesn't need saving here
return None

def _transform_message(
self,
messages: Generator[ToolInvokeMessage, None, None],
@@ -467,6 +537,7 @@ class AgentNode(Node[AgentNodeData]):
node_type: NodeType,
node_id: str,
node_execution_id: str,
memory: BaseMemory | None = None,
) -> Generator[NodeEventBase, None, None]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
@@ -711,6 +782,21 @@ class AgentNode(Node[AgentNodeData]):
is_final=True,
)

# Save to node memory if in node memory mode
from core.workflow.nodes.llm import llm_utils

# Get user query from sys.query
user_query_var = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.QUERY])
user_query = user_query_var.text if user_query_var else ""

llm_utils.save_node_memory(
memory=memory,
variable_pool=self.graph_runtime_state.variable_pool,
user_query=user_query,
assistant_response=text,
assistant_files=files,
)

yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
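The _fetch_memory change above selects the memory backend from the node's memory mode. A rough, self-contained sketch of that decision; MemoryMode, NodeTokenBufferMemory and TokenBufferMemory are the Dify names used in the diff, while the enum and function below are illustrative stand-ins, not the real constructors:

from enum import Enum


class MemoryModeSketch(str, Enum):
    CONVERSATION = "conversation"  # assumed default mode label, for illustration only
    NODE = "node"                  # mirrors MemoryMode.NODE used in the diff


def choose_memory_kind(mode: MemoryModeSketch, conversation_exists: bool) -> str | None:
    if mode == MemoryModeSketch.NODE:
        # node-scoped history, keyed by app/conversation/node (Chatflow only)
        return "NodeTokenBufferMemory"
    # conversation-scoped history, backed by the Conversation record
    return "TokenBufferMemory" if conversation_exists else None


print(choose_memory_kind(MemoryModeSketch.NODE, conversation_exists=True))          # NodeTokenBufferMemory
print(choose_memory_kind(MemoryModeSketch.CONVERSATION, conversation_exists=False)) # None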
@@ -332,12 +332,17 @@ class Node(Generic[NodeDataT]):

# Execute and process extractor node events
for event in extractor_node.run():
# Tag event with parent node id for stream ordering and history tracking
if isinstance(event, GraphNodeEventBase):
event.in_mention_parent_id = self._node_id

if isinstance(event, NodeRunSucceededEvent):
# Store extractor node outputs in variable pool
outputs = event.node_run_result.outputs
outputs: Mapping[str, Any] = event.node_run_result.outputs
for variable_name, variable_value in outputs.items():
self.graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
yield event
if not isinstance(event, NodeRunStreamChunkEvent):
yield event

def run(self) -> Generator[GraphNodeEventBase, None, None]:
execution_id = self.ensure_execution_id()
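The block above embeds an extractor node run inside its parent node: every extractor event gets in_mention_parent_id set to the parent's id, the extractor's outputs are copied into the variable pool, and its raw stream chunks are not re-yielded. A small illustrative sketch of that tag-and-filter pattern (the Event class and node ids here are made up, not Dify's event classes):

from dataclasses import dataclass


@dataclass
class Event:
    kind: str
    in_mention_parent_id: str = ""


def run_extractor(parent_node_id: str, extractor_events: list[Event]):
    for event in extractor_events:
        event.in_mention_parent_id = parent_node_id  # tag for stream ordering / history tracking
        if event.kind != "stream_chunk":             # suppress the extractor's own chunks
            yield event


events = [Event("started"), Event("stream_chunk"), Event("succeeded")]
print([e.kind for e in run_extractor("tool_node_1", events)])  # ['started', 'succeeded']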
@@ -139,6 +139,50 @@ def fetch_memory(
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)

def save_node_memory(
memory: BaseMemory | None,
variable_pool: VariablePool,
user_query: str,
assistant_response: str,
user_files: Sequence["File"] | None = None,
assistant_files: Sequence["File"] | None = None,
) -> None:
"""
Save dialogue turn to node memory if applicable.

This function handles the storage logic for NodeTokenBufferMemory.
For TokenBufferMemory (conversation-level), no action is taken as it uses
the Message table which is managed elsewhere.

:param memory: Memory instance (NodeTokenBufferMemory or TokenBufferMemory)
:param variable_pool: Variable pool containing system variables
:param user_query: User's input text
:param assistant_response: Assistant's response text
:param user_files: Files attached by user (optional)
:param assistant_files: Files generated by assistant (optional)
"""
if not isinstance(memory, NodeTokenBufferMemory):
return

# Get workflow_run_id as the key for this execution
workflow_run_id_var = variable_pool.get(["sys", SystemVariableKey.WORKFLOW_EXECUTION_ID])
if not isinstance(workflow_run_id_var, StringSegment):
return

workflow_run_id = workflow_run_id_var.value
if not workflow_run_id:
return

memory.add_messages(
workflow_run_id=workflow_run_id,
user_content=user_query,
user_files=list(user_files) if user_files else None,
assistant_content=assistant_response,
assistant_files=list(assistant_files) if assistant_files else None,
)
memory.flush()

def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
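save_node_memory is deliberately a no-op unless the memory object is a NodeTokenBufferMemory; conversation-level TokenBufferMemory keeps using the Message table. A self-contained sketch of that guard pattern (the classes below are stand-ins, not the real memory implementations):

from dataclasses import dataclass, field


@dataclass
class ConversationMemorySketch:
    turns: list[tuple[str, str]] = field(default_factory=list)


@dataclass
class NodeMemorySketch(ConversationMemorySketch):
    pass


def save_turn_sketch(memory, user_query: str, assistant_response: str) -> None:
    if not isinstance(memory, NodeMemorySketch):
        return  # conversation-level history is persisted elsewhere
    memory.turns.append((user_query, assistant_response))


node_memory = NodeMemorySketch()
save_turn_sketch(ConversationMemorySketch(), "hi", "hello")  # ignored
save_turn_sketch(node_memory, "hi", "hello")                 # stored
print(node_memory.turns)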
@@ -17,7 +17,6 @@ from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.base import BaseMemory
from core.memory.node_token_buffer_memory import NodeTokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
ImagePromptMessageContent,
@@ -334,32 +333,16 @@ class LLMNode(Node[LLMNodeData]):
outputs["files"] = ArrayFileSegment(value=self._file_outputs)

# Write to Node Memory if in node memory mode
if isinstance(memory, NodeTokenBufferMemory):
# Get workflow_run_id as the key for this execution
workflow_run_id_var = variable_pool.get(["sys", SystemVariableKey.WORKFLOW_EXECUTION_ID])
workflow_run_id = workflow_run_id_var.value if isinstance(workflow_run_id_var, StringSegment) else ""

if workflow_run_id:
# Resolve the query template to get actual user content
# query may be a template like "{{#sys.query#}}" or "{{#node_id.output#}}"
actual_query = variable_pool.convert_template(query or "").text

# Get user files from sys.files
user_files_var = variable_pool.get(["sys", SystemVariableKey.FILES])
user_files: list[File] = []
if isinstance(user_files_var, ArrayFileSegment):
user_files = list(user_files_var.value)
elif isinstance(user_files_var, FileSegment):
user_files = [user_files_var.value]

memory.add_messages(
workflow_run_id=workflow_run_id,
user_content=actual_query,
user_files=user_files,
assistant_content=clean_text,
assistant_files=self._file_outputs,
)
memory.flush()
# Resolve the query template to get actual user content
actual_query = variable_pool.convert_template(query or "").text
llm_utils.save_node_memory(
memory=memory,
variable_pool=variable_pool,
user_query=actual_query,
assistant_response=clean_text,
user_files=files,
assistant_files=self._file_outputs,
)

# Send final chunk event to indicate streaming is complete
yield StreamChunkEvent(
@@ -7,7 +7,7 @@ from typing import Any, cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import ImagePromptMessageContent
from core.model_runtime.entities.llm_entities import LLMUsage
@@ -145,8 +145,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
tenant_id=self.tenant_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
node_id=self._node_id,
)

if (
@@ -244,6 +246,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
# transform result into standard format
result = self._transform_result(data=node_data, result=result or {})

# Save to node memory if in node memory mode
llm_utils.save_node_memory(
memory=memory,
variable_pool=variable_pool,
user_query=query,
assistant_response=json.dumps(result, ensure_ascii=False),
)

return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
@@ -299,7 +309,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: TokenBufferMemory | None,
memory: BaseMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
@@ -381,7 +391,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: TokenBufferMemory | None,
memory: BaseMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@@ -419,7 +429,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: TokenBufferMemory | None,
memory: BaseMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@@ -453,7 +463,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: TokenBufferMemory | None,
memory: BaseMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@@ -681,7 +691,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: TokenBufferMemory | None,
memory: BaseMemory | None,
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
model_mode = ModelMode(node_data.model.mode)
@@ -708,7 +718,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: TokenBufferMemory | None,
memory: BaseMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
@@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -96,8 +96,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
tenant_id=self.tenant_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
node_id=self._node_id,
)
# fetch instruction
node_data.instruction = node_data.instruction or ""
@@ -203,6 +205,14 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
"usage": jsonable_encoder(usage),
}

# Save to node memory if in node memory mode
llm_utils.save_node_memory(
memory=memory,
variable_pool=variable_pool,
user_query=query or "",
assistant_response=f"class_name: {category_name}, class_id: {category_id}",
)

return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
@@ -312,7 +322,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self,
node_data: QuestionClassifierNodeData,
query: str,
memory: TokenBufferMemory | None,
memory: BaseMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
@@ -1,22 +1,26 @@
import re
from collections.abc import Sequence
from typing import Any, Literal, Union
from typing import Any, Literal, Self, Union

from pydantic import BaseModel, field_validator
from pydantic import BaseModel, field_validator, model_validator
from pydantic_core.core_schema import ValidationInfo

from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.nodes.base.entities import BaseNodeData

# Pattern to match a single variable reference like {{#llm.context#}}
SINGLE_VARIABLE_PATTERN = re.compile(r"^\s*\{\{#[a-zA-Z0-9_]+(?:\.[a-zA-Z_][a-zA-Z0-9_]*)+#\}\}\s*$")

class MentionValue(BaseModel):
"""Value structure for mention type parameters.

Used when a tool parameter needs to be extracted from conversation context
using an extractor LLM node.
class MentionConfig(BaseModel):
"""Configuration for extracting value from context variable.

Used when a tool parameter needs to be extracted from list[PromptMessage]
context using an extractor LLM node.
"""

# Variable selector for list[PromptMessage] input to extractor
variable_selector: Sequence[str]
# Instruction for the extractor LLM to extract the value
instruction: str

# ID of the extractor LLM node
extractor_node_id: str
@@ -60,8 +64,10 @@ class ToolEntity(BaseModel):
class ToolNodeData(BaseNodeData, ToolEntity):
class ToolInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str], MentionValue]
value: Union[Any, list[str]]
type: Literal["mixed", "variable", "constant", "mention"]
# Required config for mention type, extracting value from context variable
mention_config: MentionConfig | None = None

@field_validator("type", mode="before")
@classmethod
@@ -74,6 +80,9 @@ class ToolNodeData(BaseNodeData, ToolEntity):

if typ == "mixed" and not isinstance(value, str):
raise ValueError("value must be a string")
elif typ == "mention":
# Skip here, will be validated in model_validator
pass
elif typ == "variable":
if not isinstance(value, list):
raise ValueError("value must be a list")
@@ -82,19 +91,31 @@ class ToolNodeData(BaseNodeData, ToolEntity):
raise ValueError("value must be a list of strings")
elif typ == "constant" and not isinstance(value, str | int | float | bool | dict):
raise ValueError("value must be a string, int, float, bool or dict")
elif typ == "mention":
# Mention type: value should be a MentionValue or dict with required fields
if isinstance(value, MentionValue):
pass # Already validated by Pydantic
elif isinstance(value, dict):
if "extractor_node_id" not in value:
raise ValueError("value must contain extractor_node_id for mention type")
if "output_selector" not in value:
raise ValueError("value must contain output_selector for mention type")
else:
raise ValueError("value must be a MentionValue or dict for mention type")
return typ

@model_validator(mode="after")
def check_mention_type(self) -> Self:
"""Validate mention type with mention_config."""
if self.type != "mention":
return self

value = self.value
if value is None:
return self

if not isinstance(value, str):
raise ValueError("value must be a string for mention type")
# For mention type, value must be a single variable reference
if not SINGLE_VARIABLE_PATTERN.match(value):
raise ValueError(
"For mention type, value must be a single variable reference "
"like {{#node.variable#}}, cannot contain other content"
)
# mention_config is required for mention type
if self.mention_config is None:
raise ValueError("mention_config is required for mention type")
return self

tool_parameters: dict[str, ToolInput]
# The version of the tool parameter.
# If this value is None, it indicates this is a previous version
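For the new "mention" input type, the validators above require the raw value to be exactly one variable reference and mention_config to be present. A quick runnable check against SINGLE_VARIABLE_PATTERN, copied verbatim from this diff:

import re

SINGLE_VARIABLE_PATTERN = re.compile(r"^\s*\{\{#[a-zA-Z0-9_]+(?:\.[a-zA-Z_][a-zA-Z0-9_]*)+#\}\}\s*$")

print(bool(SINGLE_VARIABLE_PATTERN.match("{{#llm.context#}}")))          # True
print(bool(SINGLE_VARIABLE_PATTERN.match("  {{#ext_1.text#}}  ")))       # True, surrounding whitespace is allowed
print(bool(SINGLE_VARIABLE_PATTERN.match("query: {{#llm.context#}}")))   # False, extra content around the reference
print(bool(SINGLE_VARIABLE_PATTERN.match("{{#llm#}}")))                  # False, needs at least one ".field" part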
@@ -214,16 +214,11 @@ class ToolNode(Node[ToolNodeData]):
parameter_value = variable.value
elif tool_input.type == "mention":
# Mention type: get value from extractor node's output
from .entities import MentionValue

mention_value = tool_input.value
if isinstance(mention_value, MentionValue):
mention_config = mention_value.model_dump()
elif isinstance(mention_value, dict):
mention_config = mention_value
else:
raise ToolParameterError(f"Invalid mention value for parameter '{parameter_name}'")

if tool_input.mention_config is None:
raise ToolParameterError(
f"mention_config is required for mention type parameter '{parameter_name}'"
)
mention_config = tool_input.mention_config.model_dump()
try:
parameter_value, found = variable_pool.resolve_mention(
mention_config, parameter_name=parameter_name
@@ -524,7 +519,7 @@ class ToolNode(Node[ToolNodeData]):
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
elif input.type == "mention":
# Mention type handled by extractor node, no direct variable reference
# Mention type: value is handled by extractor node, no direct variable reference
pass
elif input.type == "constant":
pass
api/tests/fixtures/pav-test-extraction.yml (vendored)
@@ -11,6 +11,11 @@ dependencies:
value:
marketplace_plugin_unique_identifier: langgenius/google:0.0.8@3efcf55ffeef9d0f77715e0afb23534952ae0cb385c051d0637e86d71199d1a6
version: null
- current_identifier: null
type: marketplace
value:
marketplace_plugin_unique_identifier: langgenius/openai:0.2.3@5a7f82fa86e28332ad51941d0b491c1e8a38ead539656442f7bf4c6129cd15fa
version: null
- current_identifier: null
type: marketplace
value:
@@ -115,7 +120,8 @@ workflow:
enabled: false
variable_selector: []
memory:
query_prompt_template: ''
mode: node
query_prompt_template: '{{#sys.query#}}'
role_prefix:
assistant: ''
user: ''
@@ -201,29 +207,17 @@ workflow:
tool_node_version: '2'
tool_parameters:
query:
type: variable
value:
- ext_1
- text
mention_config:
default_value: ''
extractor_node_id: 1767773709491_ext_query
instruction: 请从对话历史中提取用户想要搜索的关键词,只返回关键词本身
null_strategy: use_default
output_selector:
- structured_output
- query
type: mention
value: '{{#llm.context#}}'
type: tool
virtual_nodes:
- data:
model:
completion_params:
temperature: 0.7
mode: chat
name: qwen-max
provider: langgenius/tongyi/tongyi
context:
enabled: false
prompt_template:
- role: user
text: '{{#llm.context#}}'
- role: user
text: 请从对话历史中提取用户想要搜索的关键词,只返回关键词本身,不要返回其他内容
title: 提取搜索关键词
id: ext_1
type: llm
height: 52
id: '1767773709491'
position:
@@ -237,6 +231,54 @@ workflow:
targetPosition: left
type: custom
width: 242
- data:
context:
enabled: false
variable_selector: []
model:
completion_params:
temperature: 0.7
mode: chat
name: gpt-4o-mini
provider: langgenius/openai/openai
parent_node_id: '1767773709491'
prompt_template:
- $context:
- llm
- context
id: 75d58e22-dc59-40c8-ba6f-aeb28f4f305a
- id: 18ba6710-77f5-47f4-b144-9191833bb547
role: user
text: 请从对话历史中提取用户想要搜索的关键词,只返回关键词本身,不要返回其他内容
selected: false
structured_output:
schema:
additionalProperties: false
properties:
query:
description: 搜索的关键词
type: string
required:
- query
type: object
structured_output_enabled: true
title: 提取搜索关键词
type: llm
vision:
enabled: false
height: 88
id: 1767773709491_ext_query
position:
x: 531
y: 382
positionAbsolute:
x: 531
y: 382
selected: true
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
answer: '搜索结果:

@@ -254,13 +296,13 @@ workflow:
positionAbsolute:
x: 984
y: 282
selected: true
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
viewport:
x: 151
y: 141.5
x: -151
y: 123
zoom: 1
rag_pipeline_variables: []