From a112caf5ec9ed976840ef8304bbdddc5467a9e61 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Thu, 22 Jan 2026 18:02:54 +0800 Subject: [PATCH] fix: use thread local isolation the context (#31410) --- api/context/flask_app_context.py | 34 ++++++------- .../workflow/context/execution_context.py | 12 +++-- api/core/workflow/graph_engine/worker.py | 3 +- .../context/test_execution_context.py | 50 +++++++++++++++++++ 4 files changed, 73 insertions(+), 26 deletions(-) diff --git a/api/context/flask_app_context.py b/api/context/flask_app_context.py index 360be16beb..2d465c8cf4 100644 --- a/api/context/flask_app_context.py +++ b/api/context/flask_app_context.py @@ -3,6 +3,7 @@ Flask App Context - Flask implementation of AppContext interface. """ import contextvars +import threading from collections.abc import Generator from contextlib import contextmanager from typing import Any, final @@ -118,6 +119,7 @@ class FlaskExecutionContext: self._context_vars = context_vars self._user = user self._flask_app = flask_app + self._local = threading.local() @property def app_context(self) -> FlaskAppContext: @@ -136,47 +138,39 @@ class FlaskExecutionContext: def __enter__(self) -> "FlaskExecutionContext": """Enter the Flask execution context.""" - # Restore context variables + # Restore non-Flask context variables to avoid leaking Flask tokens across threads for var, val in self._context_vars.items(): var.set(val) - # Save current user from g if available - saved_user = None - if hasattr(g, "_login_user"): - saved_user = g._login_user - # Enter Flask app context - self._cm = self._app_context.enter() - self._cm.__enter__() + cm = self._app_context.enter() + self._local.cm = cm + cm.__enter__() # Restore user in new app context - if saved_user is not None: - g._login_user = saved_user + if self._user is not None: + g._login_user = self._user return self def __exit__(self, *args: Any) -> None: """Exit the Flask execution context.""" - if hasattr(self, "_cm"): - self._cm.__exit__(*args) + cm = getattr(self._local, "cm", None) + if cm is not None: + cm.__exit__(*args) @contextmanager def enter(self) -> Generator[None, None, None]: """Enter Flask execution context as context manager.""" - # Restore context variables + # Restore non-Flask context variables to avoid leaking Flask tokens across threads for var, val in self._context_vars.items(): var.set(val) - # Save current user from g if available - saved_user = None - if hasattr(g, "_login_user"): - saved_user = g._login_user - # Enter Flask app context with self._flask_app.app_context(): # Restore user in new app context - if saved_user is not None: - g._login_user = saved_user + if self._user is not None: + g._login_user = self._user yield diff --git a/api/core/workflow/context/execution_context.py b/api/core/workflow/context/execution_context.py index d951c95d68..e3007530f0 100644 --- a/api/core/workflow/context/execution_context.py +++ b/api/core/workflow/context/execution_context.py @@ -3,6 +3,7 @@ Execution Context - Abstracted context management for workflow execution. """ import contextvars +import threading from abc import ABC, abstractmethod from collections.abc import Callable, Generator from contextlib import AbstractContextManager, contextmanager @@ -88,6 +89,7 @@ class ExecutionContext: self._app_context = app_context self._context_vars = context_vars self._user = user + self._local = threading.local() @property def app_context(self) -> AppContext | None: @@ -125,14 +127,16 @@ class ExecutionContext: def __enter__(self) -> "ExecutionContext": """Enter the execution context.""" - self._cm = self.enter() - self._cm.__enter__() + cm = self.enter() + self._local.cm = cm + cm.__enter__() return self def __exit__(self, *args: Any) -> None: """Exit the execution context.""" - if hasattr(self, "_cm"): - self._cm.__exit__(*args) + cm = getattr(self._local, "cm", None) + if cm is not None: + cm.__exit__(*args) class NullAppContext(AppContext): diff --git a/api/core/workflow/graph_engine/worker.py b/api/core/workflow/graph_engine/worker.py index 95db5c5c92..6c69ea5df0 100644 --- a/api/core/workflow/graph_engine/worker.py +++ b/api/core/workflow/graph_engine/worker.py @@ -11,7 +11,6 @@ import time from collections.abc import Sequence from datetime import datetime from typing import TYPE_CHECKING, final -from uuid import uuid4 from typing_extensions import override @@ -113,7 +112,7 @@ class Worker(threading.Thread): self._ready_queue.task_done() except Exception as e: error_event = NodeRunFailedEvent( - id=str(uuid4()), + id=node.execution_id, node_id=node.id, node_type=node.node_type, in_iteration_id=None, diff --git a/api/tests/unit_tests/core/workflow/context/test_execution_context.py b/api/tests/unit_tests/core/workflow/context/test_execution_context.py index 63466cfb5e..8dd669e17f 100644 --- a/api/tests/unit_tests/core/workflow/context/test_execution_context.py +++ b/api/tests/unit_tests/core/workflow/context/test_execution_context.py @@ -1,6 +1,8 @@ """Tests for execution context module.""" import contextvars +import threading +from contextlib import contextmanager from typing import Any from unittest.mock import MagicMock @@ -149,6 +151,54 @@ class TestExecutionContext: assert ctx.user == user + def test_thread_safe_context_manager(self): + """Test shared ExecutionContext works across threads without token mismatch.""" + test_var = contextvars.ContextVar("thread_safe_test_var") + + class TrackingAppContext(AppContext): + def get_config(self, key: str, default: Any = None) -> Any: + return default + + def get_extension(self, name: str) -> Any: + return None + + @contextmanager + def enter(self): + token = test_var.set(threading.get_ident()) + try: + yield + finally: + test_var.reset(token) + + ctx = ExecutionContext(app_context=TrackingAppContext()) + errors: list[Exception] = [] + barrier = threading.Barrier(2) + + def worker(): + try: + for _ in range(20): + with ctx: + try: + barrier.wait() + barrier.wait() + except threading.BrokenBarrierError: + return + except Exception as exc: + errors.append(exc) + try: + barrier.abort() + except Exception: + pass + + t1 = threading.Thread(target=worker) + t2 = threading.Thread(target=worker) + t1.start() + t2.start() + t1.join(timeout=5) + t2.join(timeout=5) + + assert not errors + class TestIExecutionContextProtocol: """Test IExecutionContext protocol."""