fix: use thread local isolation the context (#31410)

This commit is contained in:
wangxiaolei 2026-01-22 18:02:54 +08:00 committed by GitHub
parent 510a02286f
commit a112caf5ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 73 additions and 26 deletions

View File

@ -3,6 +3,7 @@ Flask App Context - Flask implementation of AppContext interface.
"""
import contextvars
import threading
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, final
@ -118,6 +119,7 @@ class FlaskExecutionContext:
self._context_vars = context_vars
self._user = user
self._flask_app = flask_app
self._local = threading.local()
@property
def app_context(self) -> FlaskAppContext:
@ -136,47 +138,39 @@ class FlaskExecutionContext:
def __enter__(self) -> "FlaskExecutionContext":
"""Enter the Flask execution context."""
# Restore context variables
# Restore non-Flask context variables to avoid leaking Flask tokens across threads
for var, val in self._context_vars.items():
var.set(val)
# Save current user from g if available
saved_user = None
if hasattr(g, "_login_user"):
saved_user = g._login_user
# Enter Flask app context
self._cm = self._app_context.enter()
self._cm.__enter__()
cm = self._app_context.enter()
self._local.cm = cm
cm.__enter__()
# Restore user in new app context
if saved_user is not None:
g._login_user = saved_user
if self._user is not None:
g._login_user = self._user
return self
def __exit__(self, *args: Any) -> None:
"""Exit the Flask execution context."""
if hasattr(self, "_cm"):
self._cm.__exit__(*args)
cm = getattr(self._local, "cm", None)
if cm is not None:
cm.__exit__(*args)
@contextmanager
def enter(self) -> Generator[None, None, None]:
"""Enter Flask execution context as context manager."""
# Restore context variables
# Restore non-Flask context variables to avoid leaking Flask tokens across threads
for var, val in self._context_vars.items():
var.set(val)
# Save current user from g if available
saved_user = None
if hasattr(g, "_login_user"):
saved_user = g._login_user
# Enter Flask app context
with self._flask_app.app_context():
# Restore user in new app context
if saved_user is not None:
g._login_user = saved_user
if self._user is not None:
g._login_user = self._user
yield

View File

@ -3,6 +3,7 @@ Execution Context - Abstracted context management for workflow execution.
"""
import contextvars
import threading
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from contextlib import AbstractContextManager, contextmanager
@ -88,6 +89,7 @@ class ExecutionContext:
self._app_context = app_context
self._context_vars = context_vars
self._user = user
self._local = threading.local()
@property
def app_context(self) -> AppContext | None:
@ -125,14 +127,16 @@ class ExecutionContext:
def __enter__(self) -> "ExecutionContext":
"""Enter the execution context."""
self._cm = self.enter()
self._cm.__enter__()
cm = self.enter()
self._local.cm = cm
cm.__enter__()
return self
def __exit__(self, *args: Any) -> None:
"""Exit the execution context."""
if hasattr(self, "_cm"):
self._cm.__exit__(*args)
cm = getattr(self._local, "cm", None)
if cm is not None:
cm.__exit__(*args)
class NullAppContext(AppContext):

View File

@ -11,7 +11,6 @@ import time
from collections.abc import Sequence
from datetime import datetime
from typing import TYPE_CHECKING, final
from uuid import uuid4
from typing_extensions import override
@ -113,7 +112,7 @@ class Worker(threading.Thread):
self._ready_queue.task_done()
except Exception as e:
error_event = NodeRunFailedEvent(
id=str(uuid4()),
id=node.execution_id,
node_id=node.id,
node_type=node.node_type,
in_iteration_id=None,

View File

@ -1,6 +1,8 @@
"""Tests for execution context module."""
import contextvars
import threading
from contextlib import contextmanager
from typing import Any
from unittest.mock import MagicMock
@ -149,6 +151,54 @@ class TestExecutionContext:
assert ctx.user == user
def test_thread_safe_context_manager(self):
"""Test shared ExecutionContext works across threads without token mismatch."""
test_var = contextvars.ContextVar("thread_safe_test_var")
class TrackingAppContext(AppContext):
def get_config(self, key: str, default: Any = None) -> Any:
return default
def get_extension(self, name: str) -> Any:
return None
@contextmanager
def enter(self):
token = test_var.set(threading.get_ident())
try:
yield
finally:
test_var.reset(token)
ctx = ExecutionContext(app_context=TrackingAppContext())
errors: list[Exception] = []
barrier = threading.Barrier(2)
def worker():
try:
for _ in range(20):
with ctx:
try:
barrier.wait()
barrier.wait()
except threading.BrokenBarrierError:
return
except Exception as exc:
errors.append(exc)
try:
barrier.abort()
except Exception:
pass
t1 = threading.Thread(target=worker)
t2 = threading.Thread(target=worker)
t1.start()
t2.start()
t1.join(timeout=5)
t2.join(timeout=5)
assert not errors
class TestIExecutionContextProtocol:
"""Test IExecutionContext protocol."""