fix: loop streaming by clearing stale subgraph variables (#30059)
Some checks are pending
autofix.ci / autofix (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions

This commit is contained in:
Novice 2025-12-24 11:28:52 +08:00 committed by GitHub
parent a5309bee25
commit f439e081b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 27 additions and 2 deletions

View File

@ -247,6 +247,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
DATASOURCE_INFO = "datasource_info" DATASOURCE_INFO = "datasource_info"
COMPLETED_REASON = "completed_reason" # completed reason for loop node
class WorkflowNodeExecutionStatus(StrEnum): class WorkflowNodeExecutionStatus(StrEnum):

View File

@ -1,3 +1,4 @@
from enum import StrEnum
from typing import Annotated, Any, Literal from typing import Annotated, Any, Literal
from pydantic import AfterValidator, BaseModel, Field, field_validator from pydantic import AfterValidator, BaseModel, Field, field_validator
@ -96,3 +97,8 @@ class LoopState(BaseLoopState):
Get current output. Get current output.
""" """
return self.current_output return self.current_output
class LoopCompletedReason(StrEnum):
LOOP_BREAK = "loop_break"
LOOP_COMPLETED = "loop_completed"

View File

@ -29,7 +29,7 @@ from core.workflow.node_events import (
) )
from core.workflow.nodes.base import LLMUsageTrackingMixin from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData from core.workflow.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
from core.workflow.utils.condition.processor import ConditionProcessor from core.workflow.utils.condition.processor import ConditionProcessor
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
@ -96,6 +96,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
loop_duration_map: dict[str, float] = {} loop_duration_map: dict[str, float] = {}
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
loop_usage = LLMUsage.empty_usage() loop_usage = LLMUsage.empty_usage()
loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id)
# Start Loop event # Start Loop event
yield LoopStartedEvent( yield LoopStartedEvent(
@ -118,6 +119,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
loop_count = 0 loop_count = 0
for i in range(loop_count): for i in range(loop_count):
# Clear stale variables from previous loop iterations to avoid streaming old values
self._clear_loop_subgraph_variables(loop_node_ids)
graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id) graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id)
loop_start_time = naive_utc_now() loop_start_time = naive_utc_now()
@ -177,7 +180,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
"completed_reason": "loop_break" if reach_break_condition else "loop_completed", WorkflowNodeExecutionMetadataKey.COMPLETED_REASON: (
LoopCompletedReason.LOOP_BREAK
if reach_break_condition
else LoopCompletedReason.LOOP_COMPLETED.value
),
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
}, },
@ -274,6 +281,17 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata: if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata:
event.node_run_result.metadata = {**current_metadata, **loop_metadata} event.node_run_result.metadata = {**current_metadata, **loop_metadata}
def _clear_loop_subgraph_variables(self, loop_node_ids: set[str]) -> None:
"""
Remove variables produced by loop sub-graph nodes from previous iterations.
Keeping stale variables causes a freshly created response coordinator in the
next iteration to fall back to outdated values when no stream chunks exist.
"""
variable_pool = self.graph_runtime_state.variable_pool
for node_id in loop_node_ids:
variable_pool.remove([node_id])
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping( def _extract_variable_selector_to_variable_mapping(
cls, cls,