mirror of
https://github.com/langgenius/dify.git
synced 2026-01-14 06:07:33 +08:00
feat: add mention node executor
This commit is contained in:
parent
8b8e521c4e
commit
5bcd3b6fe6
1418
api/core/workflow/docs/variable_extraction_design.md
Normal file
1418
api/core/workflow/docs/variable_extraction_design.md
Normal file
File diff suppressed because it is too large
Load Diff
@ -125,6 +125,11 @@ class EventHandler:
|
|||||||
Args:
|
Args:
|
||||||
event: The node started event
|
event: The node started event
|
||||||
"""
|
"""
|
||||||
|
# Check if this is a virtual node (extraction node)
|
||||||
|
if self._is_virtual_node(event.node_id):
|
||||||
|
self._handle_virtual_node_started(event)
|
||||||
|
return
|
||||||
|
|
||||||
# Track execution in domain model
|
# Track execution in domain model
|
||||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||||
is_initial_attempt = node_execution.retry_count == 0
|
is_initial_attempt = node_execution.retry_count == 0
|
||||||
@ -164,6 +169,11 @@ class EventHandler:
|
|||||||
Args:
|
Args:
|
||||||
event: The node succeeded event
|
event: The node succeeded event
|
||||||
"""
|
"""
|
||||||
|
# Check if this is a virtual node (extraction node)
|
||||||
|
if self._is_virtual_node(event.node_id):
|
||||||
|
self._handle_virtual_node_success(event)
|
||||||
|
return
|
||||||
|
|
||||||
# Update domain model
|
# Update domain model
|
||||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||||
node_execution.mark_taken()
|
node_execution.mark_taken()
|
||||||
@ -226,6 +236,11 @@ class EventHandler:
|
|||||||
Args:
|
Args:
|
||||||
event: The node failed event
|
event: The node failed event
|
||||||
"""
|
"""
|
||||||
|
# Check if this is a virtual node (extraction node)
|
||||||
|
if self._is_virtual_node(event.node_id):
|
||||||
|
self._handle_virtual_node_failed(event)
|
||||||
|
return
|
||||||
|
|
||||||
# Update domain model
|
# Update domain model
|
||||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||||
node_execution.mark_failed(event.error)
|
node_execution.mark_failed(event.error)
|
||||||
@ -345,3 +360,57 @@ class EventHandler:
|
|||||||
self._graph_runtime_state.set_output("answer", value)
|
self._graph_runtime_state.set_output("answer", value)
|
||||||
else:
|
else:
|
||||||
self._graph_runtime_state.set_output(key, value)
|
self._graph_runtime_state.set_output(key, value)
|
||||||
|
|
||||||
|
def _is_virtual_node(self, node_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if node_id represents a virtual sub-node.
|
||||||
|
|
||||||
|
Virtual nodes have IDs in the format: {parent_node_id}.{local_id}
|
||||||
|
We check if the part before '.' exists in graph nodes.
|
||||||
|
"""
|
||||||
|
if "." in node_id:
|
||||||
|
parent_id = node_id.rsplit(".", 1)[0]
|
||||||
|
return parent_id in self._graph.nodes
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _handle_virtual_node_started(self, event: NodeRunStartedEvent) -> None:
|
||||||
|
"""
|
||||||
|
Handle virtual node started event.
|
||||||
|
|
||||||
|
Virtual nodes don't need full execution tracking, just collect the event.
|
||||||
|
"""
|
||||||
|
# Track in response coordinator for stream ordering
|
||||||
|
self._response_coordinator.track_node_execution(event.node_id, event.id)
|
||||||
|
|
||||||
|
# Collect the event
|
||||||
|
self._event_collector.collect(event)
|
||||||
|
|
||||||
|
def _handle_virtual_node_success(self, event: NodeRunSucceededEvent) -> None:
|
||||||
|
"""
|
||||||
|
Handle virtual node success event.
|
||||||
|
|
||||||
|
Virtual nodes (extraction nodes) need special handling:
|
||||||
|
- Store outputs in variable pool (for reference by other nodes)
|
||||||
|
- Accumulate token usage
|
||||||
|
- Collect the event for logging
|
||||||
|
- Do NOT process edges or enqueue next nodes (parent node handles that)
|
||||||
|
"""
|
||||||
|
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||||
|
|
||||||
|
# Store outputs in variable pool
|
||||||
|
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||||
|
|
||||||
|
# Collect the event
|
||||||
|
self._event_collector.collect(event)
|
||||||
|
|
||||||
|
def _handle_virtual_node_failed(self, event: NodeRunFailedEvent) -> None:
|
||||||
|
"""
|
||||||
|
Handle virtual node failed event.
|
||||||
|
|
||||||
|
Virtual nodes (extraction nodes) failures are collected for logging,
|
||||||
|
but the parent node is responsible for handling the error.
|
||||||
|
"""
|
||||||
|
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||||
|
|
||||||
|
# Collect the event for logging
|
||||||
|
self._event_collector.collect(event)
|
||||||
|
|||||||
@ -20,6 +20,12 @@ class NodeRunStartedEvent(GraphNodeEventBase):
|
|||||||
provider_type: str = ""
|
provider_type: str = ""
|
||||||
provider_id: str = ""
|
provider_id: str = ""
|
||||||
|
|
||||||
|
# Virtual node fields for extraction
|
||||||
|
is_virtual: bool = False
|
||||||
|
parent_node_id: str | None = None
|
||||||
|
extraction_source: str | None = None # e.g., "llm1.context"
|
||||||
|
extraction_prompt: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class NodeRunStreamChunkEvent(GraphNodeEventBase):
|
class NodeRunStreamChunkEvent(GraphNodeEventBase):
|
||||||
# Spec-compliant fields
|
# Spec-compliant fields
|
||||||
|
|||||||
@ -1,5 +1,13 @@
|
|||||||
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
|
from .entities import (
|
||||||
|
BaseIterationNodeData,
|
||||||
|
BaseIterationState,
|
||||||
|
BaseLoopNodeData,
|
||||||
|
BaseLoopState,
|
||||||
|
BaseNodeData,
|
||||||
|
VirtualNodeConfig,
|
||||||
|
)
|
||||||
from .usage_tracking_mixin import LLMUsageTrackingMixin
|
from .usage_tracking_mixin import LLMUsageTrackingMixin
|
||||||
|
from .virtual_node_executor import VirtualNodeExecutionError, VirtualNodeExecutor
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseIterationNodeData",
|
"BaseIterationNodeData",
|
||||||
@ -8,4 +16,7 @@ __all__ = [
|
|||||||
"BaseLoopState",
|
"BaseLoopState",
|
||||||
"BaseNodeData",
|
"BaseNodeData",
|
||||||
"LLMUsageTrackingMixin",
|
"LLMUsageTrackingMixin",
|
||||||
|
"VirtualNodeConfig",
|
||||||
|
"VirtualNodeExecutionError",
|
||||||
|
"VirtualNodeExecutor",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -167,6 +167,24 @@ class DefaultValue(BaseModel):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class VirtualNodeConfig(BaseModel):
|
||||||
|
"""Configuration for a virtual sub-node embedded within a parent node."""
|
||||||
|
|
||||||
|
# Local ID within parent node (e.g., "ext_1")
|
||||||
|
# Will be converted to global ID: "{parent_id}.{id}"
|
||||||
|
id: str
|
||||||
|
|
||||||
|
# Node type (e.g., "llm", "code", "tool")
|
||||||
|
type: str
|
||||||
|
|
||||||
|
# Full node data configuration
|
||||||
|
data: dict[str, Any] = {}
|
||||||
|
|
||||||
|
def get_global_id(self, parent_node_id: str) -> str:
|
||||||
|
"""Get the global node ID by combining parent ID and local ID."""
|
||||||
|
return f"{parent_node_id}.{self.id}"
|
||||||
|
|
||||||
|
|
||||||
class BaseNodeData(ABC, BaseModel):
|
class BaseNodeData(ABC, BaseModel):
|
||||||
title: str
|
title: str
|
||||||
desc: str | None = None
|
desc: str | None = None
|
||||||
@ -175,6 +193,9 @@ class BaseNodeData(ABC, BaseModel):
|
|||||||
default_value: list[DefaultValue] | None = None
|
default_value: list[DefaultValue] | None = None
|
||||||
retry_config: RetryConfig = RetryConfig()
|
retry_config: RetryConfig = RetryConfig()
|
||||||
|
|
||||||
|
# Virtual sub-nodes that execute before the main node
|
||||||
|
virtual_nodes: list[VirtualNodeConfig] = []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default_value_dict(self) -> dict[str, Any]:
|
def default_value_dict(self) -> dict[str, Any]:
|
||||||
if self.default_value:
|
if self.default_value:
|
||||||
|
|||||||
@ -229,6 +229,7 @@ class Node(Generic[NodeDataT]):
|
|||||||
self._node_id = node_id
|
self._node_id = node_id
|
||||||
self._node_execution_id: str = ""
|
self._node_execution_id: str = ""
|
||||||
self._start_at = naive_utc_now()
|
self._start_at = naive_utc_now()
|
||||||
|
self._virtual_node_outputs: dict[str, Any] = {} # Outputs from virtual sub-nodes
|
||||||
|
|
||||||
raw_node_data = config.get("data") or {}
|
raw_node_data = config.get("data") or {}
|
||||||
if not isinstance(raw_node_data, Mapping):
|
if not isinstance(raw_node_data, Mapping):
|
||||||
@ -270,10 +271,52 @@ class Node(Generic[NodeDataT]):
|
|||||||
"""Check if execution should be stopped."""
|
"""Check if execution should be stopped."""
|
||||||
return self.graph_runtime_state.stop_event.is_set()
|
return self.graph_runtime_state.stop_event.is_set()
|
||||||
|
|
||||||
|
def _execute_virtual_nodes(self) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Execute all virtual sub-nodes defined in node configuration.
|
||||||
|
|
||||||
|
Virtual nodes are complete node definitions that execute before the main node.
|
||||||
|
Each virtual node:
|
||||||
|
- Has its own global ID: "{parent_id}.{local_id}"
|
||||||
|
- Generates standard node events
|
||||||
|
- Stores outputs in the variable pool (via event handling)
|
||||||
|
- Supports retry via parent node's retry config
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict mapping local_id -> outputs dict
|
||||||
|
"""
|
||||||
|
from .virtual_node_executor import VirtualNodeExecutor
|
||||||
|
|
||||||
|
virtual_nodes = self.node_data.virtual_nodes
|
||||||
|
if not virtual_nodes:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
executor = VirtualNodeExecutor(
|
||||||
|
graph_init_params=self._graph_init_params,
|
||||||
|
graph_runtime_state=self.graph_runtime_state,
|
||||||
|
parent_node_id=self._node_id,
|
||||||
|
parent_retry_config=self.retry_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (yield from executor.execute_virtual_nodes(virtual_nodes))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def virtual_node_outputs(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get the outputs from virtual sub-nodes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict mapping local_id -> outputs dict
|
||||||
|
"""
|
||||||
|
return self._virtual_node_outputs
|
||||||
|
|
||||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||||
execution_id = self.ensure_execution_id()
|
execution_id = self.ensure_execution_id()
|
||||||
self._start_at = naive_utc_now()
|
self._start_at = naive_utc_now()
|
||||||
|
|
||||||
|
# Step 1: Execute virtual sub-nodes before main node execution
|
||||||
|
self._virtual_node_outputs = yield from self._execute_virtual_nodes()
|
||||||
|
|
||||||
# Create and push start event with required fields
|
# Create and push start event with required fields
|
||||||
start_event = NodeRunStartedEvent(
|
start_event = NodeRunStartedEvent(
|
||||||
id=execution_id,
|
id=execution_id,
|
||||||
|
|||||||
213
api/core/workflow/nodes/base/virtual_node_executor.py
Normal file
213
api/core/workflow/nodes/base/virtual_node_executor.py
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
"""
|
||||||
|
Virtual Node Executor for running embedded sub-nodes within a parent node.
|
||||||
|
|
||||||
|
This module handles the execution of virtual nodes defined in a parent node's
|
||||||
|
`virtual_nodes` configuration. Virtual nodes are complete node definitions
|
||||||
|
that execute before the parent node.
|
||||||
|
|
||||||
|
Example configuration:
|
||||||
|
virtual_nodes:
|
||||||
|
- id: ext_1
|
||||||
|
type: llm
|
||||||
|
data:
|
||||||
|
model: {...}
|
||||||
|
prompt_template: [...]
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from collections.abc import Generator
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from core.workflow.enums import NodeType
|
||||||
|
from core.workflow.graph_events import (
|
||||||
|
GraphNodeEventBase,
|
||||||
|
NodeRunFailedEvent,
|
||||||
|
NodeRunRetryEvent,
|
||||||
|
NodeRunStartedEvent,
|
||||||
|
NodeRunSucceededEvent,
|
||||||
|
)
|
||||||
|
from libs.datetime_utils import naive_utc_now
|
||||||
|
|
||||||
|
from .entities import RetryConfig, VirtualNodeConfig
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.workflow.entities import GraphInitParams
|
||||||
|
from core.workflow.runtime import GraphRuntimeState
|
||||||
|
|
||||||
|
|
||||||
|
class VirtualNodeExecutionError(Exception):
|
||||||
|
"""Error during virtual node execution"""
|
||||||
|
|
||||||
|
def __init__(self, node_id: str, original_error: Exception):
|
||||||
|
self.node_id = node_id
|
||||||
|
self.original_error = original_error
|
||||||
|
super().__init__(f"Virtual node {node_id} execution failed: {original_error}")
|
||||||
|
|
||||||
|
|
||||||
|
class VirtualNodeExecutor:
|
||||||
|
"""
|
||||||
|
Executes virtual sub-nodes embedded within a parent node.
|
||||||
|
|
||||||
|
Virtual nodes are complete node definitions that execute before the parent node.
|
||||||
|
Each virtual node:
|
||||||
|
- Has its own global ID: "{parent_id}.{local_id}"
|
||||||
|
- Generates standard node events
|
||||||
|
- Stores outputs in the variable pool
|
||||||
|
- Supports retry via parent node's retry config
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
graph_init_params: "GraphInitParams",
|
||||||
|
graph_runtime_state: "GraphRuntimeState",
|
||||||
|
parent_node_id: str,
|
||||||
|
parent_retry_config: RetryConfig | None = None,
|
||||||
|
):
|
||||||
|
self._graph_init_params = graph_init_params
|
||||||
|
self._graph_runtime_state = graph_runtime_state
|
||||||
|
self._parent_node_id = parent_node_id
|
||||||
|
self._parent_retry_config = parent_retry_config or RetryConfig()
|
||||||
|
|
||||||
|
def execute_virtual_nodes(
|
||||||
|
self,
|
||||||
|
virtual_nodes: list[VirtualNodeConfig],
|
||||||
|
) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Execute all virtual nodes in order.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
virtual_nodes: List of virtual node configurations
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Node events from each virtual node execution
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict mapping local_id -> outputs dict
|
||||||
|
"""
|
||||||
|
results: dict[str, Any] = {}
|
||||||
|
|
||||||
|
for vnode_config in virtual_nodes:
|
||||||
|
global_id = vnode_config.get_global_id(self._parent_node_id)
|
||||||
|
|
||||||
|
# Execute with retry
|
||||||
|
outputs = yield from self._execute_with_retry(vnode_config, global_id)
|
||||||
|
results[vnode_config.id] = outputs
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _execute_with_retry(
|
||||||
|
self,
|
||||||
|
vnode_config: VirtualNodeConfig,
|
||||||
|
global_id: str,
|
||||||
|
) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Execute virtual node with retry support.
|
||||||
|
"""
|
||||||
|
retry_config = self._parent_retry_config
|
||||||
|
last_error: Exception | None = None
|
||||||
|
|
||||||
|
for attempt in range(retry_config.max_retries + 1):
|
||||||
|
try:
|
||||||
|
return (yield from self._execute_single_node(vnode_config, global_id))
|
||||||
|
except Exception as e:
|
||||||
|
last_error = e
|
||||||
|
|
||||||
|
if attempt < retry_config.max_retries:
|
||||||
|
# Yield retry event
|
||||||
|
yield NodeRunRetryEvent(
|
||||||
|
id=str(uuid4()),
|
||||||
|
node_id=global_id,
|
||||||
|
node_type=self._get_node_type(vnode_config.type),
|
||||||
|
node_title=vnode_config.data.get("title", f"Virtual: {vnode_config.id}"),
|
||||||
|
start_at=naive_utc_now(),
|
||||||
|
error=str(e),
|
||||||
|
retry_index=attempt + 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
time.sleep(retry_config.retry_interval_seconds)
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise VirtualNodeExecutionError(global_id, e) from e
|
||||||
|
|
||||||
|
raise last_error or VirtualNodeExecutionError(global_id, Exception("Unknown error"))
|
||||||
|
|
||||||
|
def _execute_single_node(
|
||||||
|
self,
|
||||||
|
vnode_config: VirtualNodeConfig,
|
||||||
|
global_id: str,
|
||||||
|
) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Execute a single virtual node by instantiating and running it.
|
||||||
|
"""
|
||||||
|
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||||
|
|
||||||
|
# Build node config
|
||||||
|
node_config: dict[str, Any] = {
|
||||||
|
"id": global_id,
|
||||||
|
"data": {
|
||||||
|
**vnode_config.data,
|
||||||
|
"title": vnode_config.data.get("title", f"Virtual: {vnode_config.id}"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get the node class for this type
|
||||||
|
node_type = self._get_node_type(vnode_config.type)
|
||||||
|
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
|
||||||
|
if not node_mapping:
|
||||||
|
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||||
|
|
||||||
|
node_version = str(vnode_config.data.get("version", "1"))
|
||||||
|
node_cls = node_mapping.get(node_version) or node_mapping.get(LATEST_VERSION)
|
||||||
|
if not node_cls:
|
||||||
|
raise ValueError(f"No class found for node type: {node_type}")
|
||||||
|
|
||||||
|
# Instantiate the node
|
||||||
|
node = node_cls(
|
||||||
|
id=global_id,
|
||||||
|
config=node_config,
|
||||||
|
graph_init_params=self._graph_init_params,
|
||||||
|
graph_runtime_state=self._graph_runtime_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run and collect events
|
||||||
|
outputs: dict[str, Any] = {}
|
||||||
|
|
||||||
|
for event in node.run():
|
||||||
|
# Mark event as coming from virtual node
|
||||||
|
self._mark_event_as_virtual(event, vnode_config)
|
||||||
|
yield event
|
||||||
|
|
||||||
|
if isinstance(event, NodeRunSucceededEvent):
|
||||||
|
outputs = event.node_run_result.outputs or {}
|
||||||
|
elif isinstance(event, NodeRunFailedEvent):
|
||||||
|
raise Exception(event.error or "Virtual node execution failed")
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def _mark_event_as_virtual(
|
||||||
|
self,
|
||||||
|
event: GraphNodeEventBase,
|
||||||
|
vnode_config: VirtualNodeConfig,
|
||||||
|
) -> None:
|
||||||
|
"""Mark event as coming from a virtual node."""
|
||||||
|
if isinstance(event, NodeRunStartedEvent):
|
||||||
|
event.is_virtual = True
|
||||||
|
event.parent_node_id = self._parent_node_id
|
||||||
|
|
||||||
|
def _get_node_type(self, type_str: str) -> NodeType:
|
||||||
|
"""Convert type string to NodeType enum."""
|
||||||
|
type_mapping = {
|
||||||
|
"llm": NodeType.LLM,
|
||||||
|
"code": NodeType.CODE,
|
||||||
|
"tool": NodeType.TOOL,
|
||||||
|
"if-else": NodeType.IF_ELSE,
|
||||||
|
"question-classifier": NodeType.QUESTION_CLASSIFIER,
|
||||||
|
"parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
|
||||||
|
"template-transform": NodeType.TEMPLATE_TRANSFORM,
|
||||||
|
"variable-assigner": NodeType.VARIABLE_ASSIGNER,
|
||||||
|
"http-request": NodeType.HTTP_REQUEST,
|
||||||
|
"knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL,
|
||||||
|
}
|
||||||
|
return type_mapping.get(type_str, NodeType.LLM)
|
||||||
@ -89,18 +89,20 @@ class ToolNode(Node[ToolNodeData]):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# get parameters
|
# get parameters (use virtual_node_outputs from base class)
|
||||||
tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
|
tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
|
||||||
parameters = self._generate_parameters(
|
parameters = self._generate_parameters(
|
||||||
tool_parameters=tool_parameters,
|
tool_parameters=tool_parameters,
|
||||||
variable_pool=self.graph_runtime_state.variable_pool,
|
variable_pool=self.graph_runtime_state.variable_pool,
|
||||||
node_data=self.node_data,
|
node_data=self.node_data,
|
||||||
|
virtual_node_outputs=self.virtual_node_outputs,
|
||||||
)
|
)
|
||||||
parameters_for_log = self._generate_parameters(
|
parameters_for_log = self._generate_parameters(
|
||||||
tool_parameters=tool_parameters,
|
tool_parameters=tool_parameters,
|
||||||
variable_pool=self.graph_runtime_state.variable_pool,
|
variable_pool=self.graph_runtime_state.variable_pool,
|
||||||
node_data=self.node_data,
|
node_data=self.node_data,
|
||||||
for_log=True,
|
for_log=True,
|
||||||
|
virtual_node_outputs=self.virtual_node_outputs,
|
||||||
)
|
)
|
||||||
# get conversation id
|
# get conversation id
|
||||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||||
@ -176,6 +178,7 @@ class ToolNode(Node[ToolNodeData]):
|
|||||||
variable_pool: "VariablePool",
|
variable_pool: "VariablePool",
|
||||||
node_data: ToolNodeData,
|
node_data: ToolNodeData,
|
||||||
for_log: bool = False,
|
for_log: bool = False,
|
||||||
|
virtual_node_outputs: dict[str, Any] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||||
@ -184,12 +187,17 @@ class ToolNode(Node[ToolNodeData]):
|
|||||||
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
|
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
|
||||||
variable_pool (VariablePool): The variable pool containing the variables.
|
variable_pool (VariablePool): The variable pool containing the variables.
|
||||||
node_data (ToolNodeData): The data associated with the tool node.
|
node_data (ToolNodeData): The data associated with the tool node.
|
||||||
|
for_log (bool): Whether to generate parameters for logging.
|
||||||
|
virtual_node_outputs (dict[str, Any] | None): Outputs from virtual sub-nodes.
|
||||||
|
Maps local_id -> outputs dict. Virtual node outputs are also in variable_pool
|
||||||
|
with global IDs like "{parent_id}.{local_id}".
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Mapping[str, Any]: A dictionary containing the generated parameters.
|
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
|
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
|
||||||
|
virtual_node_outputs = virtual_node_outputs or {}
|
||||||
|
|
||||||
result: dict[str, Any] = {}
|
result: dict[str, Any] = {}
|
||||||
for parameter_name in node_data.tool_parameters:
|
for parameter_name in node_data.tool_parameters:
|
||||||
@ -199,14 +207,25 @@ class ToolNode(Node[ToolNodeData]):
|
|||||||
continue
|
continue
|
||||||
tool_input = node_data.tool_parameters[parameter_name]
|
tool_input = node_data.tool_parameters[parameter_name]
|
||||||
if tool_input.type == "variable":
|
if tool_input.type == "variable":
|
||||||
variable = variable_pool.get(tool_input.value)
|
# Check if this references a virtual node output (local ID like [ext_1, text])
|
||||||
if variable is None:
|
selector = tool_input.value
|
||||||
if parameter.required:
|
if len(selector) >= 2 and selector[0] in virtual_node_outputs:
|
||||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
# Reference to virtual node output
|
||||||
continue
|
local_id = selector[0]
|
||||||
parameter_value = variable.value
|
var_name = selector[1]
|
||||||
|
outputs = virtual_node_outputs.get(local_id, {})
|
||||||
|
parameter_value = outputs.get(var_name)
|
||||||
|
else:
|
||||||
|
# Normal variable reference
|
||||||
|
variable = variable_pool.get(selector)
|
||||||
|
if variable is None:
|
||||||
|
if parameter.required:
|
||||||
|
raise ToolParameterError(f"Variable {selector} does not exist")
|
||||||
|
continue
|
||||||
|
parameter_value = variable.value
|
||||||
elif tool_input.type in {"mixed", "constant"}:
|
elif tool_input.type in {"mixed", "constant"}:
|
||||||
segment_group = variable_pool.convert_template(str(tool_input.value))
|
template = str(tool_input.value)
|
||||||
|
segment_group = variable_pool.convert_template(template)
|
||||||
parameter_value = segment_group.log if for_log else segment_group.text
|
parameter_value = segment_group.log if for_log else segment_group.text
|
||||||
else:
|
else:
|
||||||
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
|
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
|
||||||
|
|||||||
266
api/tests/fixtures/pav-test-extraction.yml
vendored
Normal file
266
api/tests/fixtures/pav-test-extraction.yml
vendored
Normal file
@ -0,0 +1,266 @@
|
|||||||
|
app:
|
||||||
|
description: Test for variable extraction feature
|
||||||
|
icon: 🤖
|
||||||
|
icon_background: '#FFEAD5'
|
||||||
|
mode: advanced-chat
|
||||||
|
name: pav-test-extraction
|
||||||
|
use_icon_as_answer_icon: false
|
||||||
|
dependencies:
|
||||||
|
- current_identifier: null
|
||||||
|
type: marketplace
|
||||||
|
value:
|
||||||
|
marketplace_plugin_unique_identifier: langgenius/google:0.0.8@3efcf55ffeef9d0f77715e0afb23534952ae0cb385c051d0637e86d71199d1a6
|
||||||
|
version: null
|
||||||
|
- current_identifier: null
|
||||||
|
type: marketplace
|
||||||
|
value:
|
||||||
|
marketplace_plugin_unique_identifier: langgenius/tongyi:0.1.16@d8bffbe45418f0c117fb3393e5e40e61faee98f9a2183f062e5a280e74b15d21
|
||||||
|
version: null
|
||||||
|
kind: app
|
||||||
|
version: 0.5.0
|
||||||
|
workflow:
|
||||||
|
conversation_variables: []
|
||||||
|
environment_variables: []
|
||||||
|
features:
|
||||||
|
file_upload:
|
||||||
|
allowed_file_extensions:
|
||||||
|
- .JPG
|
||||||
|
- .JPEG
|
||||||
|
- .PNG
|
||||||
|
- .GIF
|
||||||
|
- .WEBP
|
||||||
|
- .SVG
|
||||||
|
allowed_file_types:
|
||||||
|
- image
|
||||||
|
allowed_file_upload_methods:
|
||||||
|
- local_file
|
||||||
|
- remote_url
|
||||||
|
enabled: false
|
||||||
|
image:
|
||||||
|
enabled: false
|
||||||
|
number_limits: 3
|
||||||
|
transfer_methods:
|
||||||
|
- local_file
|
||||||
|
- remote_url
|
||||||
|
number_limits: 3
|
||||||
|
opening_statement: 你好!我是一个搜索助手,请告诉我你想搜索什么内容。
|
||||||
|
retriever_resource:
|
||||||
|
enabled: true
|
||||||
|
sensitive_word_avoidance:
|
||||||
|
enabled: false
|
||||||
|
speech_to_text:
|
||||||
|
enabled: false
|
||||||
|
suggested_questions: []
|
||||||
|
suggested_questions_after_answer:
|
||||||
|
enabled: false
|
||||||
|
text_to_speech:
|
||||||
|
enabled: false
|
||||||
|
language: ''
|
||||||
|
voice: ''
|
||||||
|
graph:
|
||||||
|
edges:
|
||||||
|
- data:
|
||||||
|
sourceType: start
|
||||||
|
targetType: llm
|
||||||
|
id: 1767773675796-llm
|
||||||
|
source: '1767773675796'
|
||||||
|
sourceHandle: source
|
||||||
|
target: llm
|
||||||
|
targetHandle: target
|
||||||
|
type: custom
|
||||||
|
- data:
|
||||||
|
isInIteration: false
|
||||||
|
isInLoop: false
|
||||||
|
sourceType: llm
|
||||||
|
targetType: tool
|
||||||
|
id: llm-source-1767773709491-target
|
||||||
|
source: llm
|
||||||
|
sourceHandle: source
|
||||||
|
target: '1767773709491'
|
||||||
|
targetHandle: target
|
||||||
|
type: custom
|
||||||
|
zIndex: 0
|
||||||
|
- data:
|
||||||
|
isInIteration: false
|
||||||
|
isInLoop: false
|
||||||
|
sourceType: tool
|
||||||
|
targetType: answer
|
||||||
|
id: tool-source-answer-target
|
||||||
|
source: '1767773709491'
|
||||||
|
sourceHandle: source
|
||||||
|
target: answer
|
||||||
|
targetHandle: target
|
||||||
|
type: custom
|
||||||
|
zIndex: 0
|
||||||
|
nodes:
|
||||||
|
- data:
|
||||||
|
selected: false
|
||||||
|
title: User Input
|
||||||
|
type: start
|
||||||
|
variables: []
|
||||||
|
height: 73
|
||||||
|
id: '1767773675796'
|
||||||
|
position:
|
||||||
|
x: 80
|
||||||
|
y: 282
|
||||||
|
positionAbsolute:
|
||||||
|
x: 80
|
||||||
|
y: 282
|
||||||
|
sourcePosition: right
|
||||||
|
targetPosition: left
|
||||||
|
type: custom
|
||||||
|
width: 242
|
||||||
|
- data:
|
||||||
|
context:
|
||||||
|
enabled: false
|
||||||
|
variable_selector: []
|
||||||
|
memory:
|
||||||
|
query_prompt_template: ''
|
||||||
|
role_prefix:
|
||||||
|
assistant: ''
|
||||||
|
user: ''
|
||||||
|
window:
|
||||||
|
enabled: true
|
||||||
|
size: 10
|
||||||
|
model:
|
||||||
|
completion_params:
|
||||||
|
temperature: 0.7
|
||||||
|
mode: chat
|
||||||
|
name: qwen-max
|
||||||
|
provider: langgenius/tongyi/tongyi
|
||||||
|
prompt_template:
|
||||||
|
- id: 11d06d15-914a-4915-a5b1-0e35ab4fba51
|
||||||
|
role: system
|
||||||
|
text: '你是一个智能搜索助手。用户会告诉你他们想搜索的内容。
|
||||||
|
|
||||||
|
请与用户进行对话,了解他们的搜索需求。
|
||||||
|
|
||||||
|
当用户明确表达了想要搜索的内容后,你可以回复"好的,我来帮你搜索"。
|
||||||
|
|
||||||
|
'
|
||||||
|
selected: false
|
||||||
|
title: LLM
|
||||||
|
type: llm
|
||||||
|
vision:
|
||||||
|
enabled: false
|
||||||
|
height: 88
|
||||||
|
id: llm
|
||||||
|
position:
|
||||||
|
x: 380
|
||||||
|
y: 282
|
||||||
|
positionAbsolute:
|
||||||
|
x: 380
|
||||||
|
y: 282
|
||||||
|
selected: false
|
||||||
|
sourcePosition: right
|
||||||
|
targetPosition: left
|
||||||
|
type: custom
|
||||||
|
width: 242
|
||||||
|
- data:
|
||||||
|
is_team_authorization: true
|
||||||
|
paramSchemas:
|
||||||
|
- auto_generate: null
|
||||||
|
default: null
|
||||||
|
form: llm
|
||||||
|
human_description:
|
||||||
|
en_US: used for searching
|
||||||
|
ja_JP: used for searching
|
||||||
|
pt_BR: used for searching
|
||||||
|
zh_Hans: 用于搜索网页内容
|
||||||
|
label:
|
||||||
|
en_US: Query string
|
||||||
|
ja_JP: Query string
|
||||||
|
pt_BR: Query string
|
||||||
|
zh_Hans: 查询语句
|
||||||
|
llm_description: key words for searching
|
||||||
|
max: null
|
||||||
|
min: null
|
||||||
|
name: query
|
||||||
|
options: []
|
||||||
|
placeholder: null
|
||||||
|
precision: null
|
||||||
|
required: true
|
||||||
|
scope: null
|
||||||
|
template: null
|
||||||
|
type: string
|
||||||
|
params:
|
||||||
|
query: ''
|
||||||
|
plugin_id: langgenius/google
|
||||||
|
plugin_unique_identifier: langgenius/google:0.0.8@3efcf55ffeef9d0f77715e0afb23534952ae0cb385c051d0637e86d71199d1a6
|
||||||
|
provider_icon: http://localhost:5001/console/api/workspaces/current/plugin/icon?tenant_id=7217e801-f6f5-49ec-8103-d7de97a4b98f&filename=1c5871163478957bac64c3fe33d72d003f767497d921c74b742aad27a8344a74.svg
|
||||||
|
provider_id: langgenius/google/google
|
||||||
|
provider_name: langgenius/google/google
|
||||||
|
provider_type: builtin
|
||||||
|
selected: false
|
||||||
|
title: GoogleSearch
|
||||||
|
tool_configurations: {}
|
||||||
|
tool_description: A tool for performing a Google SERP search and extracting
|
||||||
|
snippets and webpages.Input should be a search query.
|
||||||
|
tool_label: GoogleSearch
|
||||||
|
tool_name: google_search
|
||||||
|
tool_node_version: '2'
|
||||||
|
tool_parameters:
|
||||||
|
query:
|
||||||
|
type: variable
|
||||||
|
value:
|
||||||
|
- ext_1
|
||||||
|
- text
|
||||||
|
type: tool
|
||||||
|
virtual_nodes:
|
||||||
|
- data:
|
||||||
|
model:
|
||||||
|
completion_params:
|
||||||
|
temperature: 0.7
|
||||||
|
mode: chat
|
||||||
|
name: qwen-max
|
||||||
|
provider: langgenius/tongyi/tongyi
|
||||||
|
context:
|
||||||
|
enabled: false
|
||||||
|
prompt_template:
|
||||||
|
- role: user
|
||||||
|
text: '{{#llm.context#}}'
|
||||||
|
- role: user
|
||||||
|
text: 请从对话历史中提取用户想要搜索的关键词,只返回关键词本身,不要返回其他内容
|
||||||
|
title: 提取搜索关键词
|
||||||
|
id: ext_1
|
||||||
|
type: llm
|
||||||
|
height: 52
|
||||||
|
id: '1767773709491'
|
||||||
|
position:
|
||||||
|
x: 682
|
||||||
|
y: 282
|
||||||
|
positionAbsolute:
|
||||||
|
x: 682
|
||||||
|
y: 282
|
||||||
|
selected: false
|
||||||
|
sourcePosition: right
|
||||||
|
targetPosition: left
|
||||||
|
type: custom
|
||||||
|
width: 242
|
||||||
|
- data:
|
||||||
|
answer: '搜索结果:
|
||||||
|
|
||||||
|
{{#1767773709491.text#}}
|
||||||
|
|
||||||
|
'
|
||||||
|
selected: false
|
||||||
|
title: Answer
|
||||||
|
type: answer
|
||||||
|
height: 103
|
||||||
|
id: answer
|
||||||
|
position:
|
||||||
|
x: 984
|
||||||
|
y: 282
|
||||||
|
positionAbsolute:
|
||||||
|
x: 984
|
||||||
|
y: 282
|
||||||
|
selected: true
|
||||||
|
sourcePosition: right
|
||||||
|
targetPosition: left
|
||||||
|
type: custom
|
||||||
|
width: 242
|
||||||
|
viewport:
|
||||||
|
x: 151
|
||||||
|
y: 141.5
|
||||||
|
zoom: 1
|
||||||
|
rag_pipeline_variables: []
|
||||||
@ -0,0 +1,77 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for virtual node configuration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from core.workflow.nodes.base.entities import VirtualNodeConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TestVirtualNodeConfig:
|
||||||
|
"""Tests for VirtualNodeConfig entity."""
|
||||||
|
|
||||||
|
def test_create_basic_config(self):
|
||||||
|
"""Test creating a basic virtual node config."""
|
||||||
|
config = VirtualNodeConfig(
|
||||||
|
id="ext_1",
|
||||||
|
type="llm",
|
||||||
|
data={
|
||||||
|
"title": "Extract keywords",
|
||||||
|
"model": {"provider": "openai", "name": "gpt-4o-mini"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.id == "ext_1"
|
||||||
|
assert config.type == "llm"
|
||||||
|
assert config.data["title"] == "Extract keywords"
|
||||||
|
|
||||||
|
def test_get_global_id(self):
|
||||||
|
"""Test generating global ID from parent ID."""
|
||||||
|
config = VirtualNodeConfig(
|
||||||
|
id="ext_1",
|
||||||
|
type="llm",
|
||||||
|
data={},
|
||||||
|
)
|
||||||
|
|
||||||
|
global_id = config.get_global_id("tool1")
|
||||||
|
assert global_id == "tool1.ext_1"
|
||||||
|
|
||||||
|
def test_get_global_id_with_different_parents(self):
|
||||||
|
"""Test global ID generation with different parent IDs."""
|
||||||
|
config = VirtualNodeConfig(id="sub_node", type="code", data={})
|
||||||
|
|
||||||
|
assert config.get_global_id("parent1") == "parent1.sub_node"
|
||||||
|
assert config.get_global_id("node_123") == "node_123.sub_node"
|
||||||
|
|
||||||
|
def test_empty_data(self):
|
||||||
|
"""Test virtual node config with empty data."""
|
||||||
|
config = VirtualNodeConfig(
|
||||||
|
id="test",
|
||||||
|
type="tool",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.id == "test"
|
||||||
|
assert config.type == "tool"
|
||||||
|
assert config.data == {}
|
||||||
|
|
||||||
|
def test_complex_data(self):
|
||||||
|
"""Test virtual node config with complex data."""
|
||||||
|
config = VirtualNodeConfig(
|
||||||
|
id="llm_1",
|
||||||
|
type="llm",
|
||||||
|
data={
|
||||||
|
"title": "Generate summary",
|
||||||
|
"model": {
|
||||||
|
"provider": "openai",
|
||||||
|
"name": "gpt-4",
|
||||||
|
"mode": "chat",
|
||||||
|
"completion_params": {"temperature": 0.7, "max_tokens": 500},
|
||||||
|
},
|
||||||
|
"prompt_template": [
|
||||||
|
{"role": "user", "text": "{{#llm1.context#}}"},
|
||||||
|
{"role": "user", "text": "Please summarize the conversation"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.data["model"]["provider"] == "openai"
|
||||||
|
assert len(config.data["prompt_template"]) == 2
|
||||||
|
|
||||||
@ -37,7 +37,10 @@ export const useWorkflowNodeFinished = () => {
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
const newNodes = produce(nodes, (draft) => {
|
const newNodes = produce(nodes, (draft) => {
|
||||||
const currentNode = draft.find(node => node.id === data.node_id)!
|
const currentNode = draft.find(node => node.id === data.node_id)
|
||||||
|
// Skip if node not found (e.g., virtual extraction nodes)
|
||||||
|
if (!currentNode)
|
||||||
|
return
|
||||||
currentNode.data._runningStatus = data.status
|
currentNode.data._runningStatus = data.status
|
||||||
if (data.status === NodeRunningStatus.Exception) {
|
if (data.status === NodeRunningStatus.Exception) {
|
||||||
if (data.execution_metadata?.error_strategy === ErrorHandleTypeEnum.failBranch)
|
if (data.execution_metadata?.error_strategy === ErrorHandleTypeEnum.failBranch)
|
||||||
|
|||||||
@ -45,6 +45,11 @@ export const useWorkflowNodeStarted = () => {
|
|||||||
} = reactflow
|
} = reactflow
|
||||||
const currentNodeIndex = nodes.findIndex(node => node.id === data.node_id)
|
const currentNodeIndex = nodes.findIndex(node => node.id === data.node_id)
|
||||||
const currentNode = nodes[currentNodeIndex]
|
const currentNode = nodes[currentNodeIndex]
|
||||||
|
|
||||||
|
// Skip if node not found (e.g., virtual extraction nodes)
|
||||||
|
if (!currentNode)
|
||||||
|
return
|
||||||
|
|
||||||
const position = currentNode.position
|
const position = currentNode.position
|
||||||
const zoom = transform[2]
|
const zoom = transform[2]
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user