mirror of
https://github.com/langgenius/dify.git
synced 2026-02-01 16:41:58 +08:00
fix: use thread local isolation the context (#31410)
This commit is contained in:
parent
510a02286f
commit
a112caf5ec
@ -3,6 +3,7 @@ Flask App Context - Flask implementation of AppContext interface.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, final
|
||||
@ -118,6 +119,7 @@ class FlaskExecutionContext:
|
||||
self._context_vars = context_vars
|
||||
self._user = user
|
||||
self._flask_app = flask_app
|
||||
self._local = threading.local()
|
||||
|
||||
@property
|
||||
def app_context(self) -> FlaskAppContext:
|
||||
@ -136,47 +138,39 @@ class FlaskExecutionContext:
|
||||
|
||||
def __enter__(self) -> "FlaskExecutionContext":
|
||||
"""Enter the Flask execution context."""
|
||||
# Restore context variables
|
||||
# Restore non-Flask context variables to avoid leaking Flask tokens across threads
|
||||
for var, val in self._context_vars.items():
|
||||
var.set(val)
|
||||
|
||||
# Save current user from g if available
|
||||
saved_user = None
|
||||
if hasattr(g, "_login_user"):
|
||||
saved_user = g._login_user
|
||||
|
||||
# Enter Flask app context
|
||||
self._cm = self._app_context.enter()
|
||||
self._cm.__enter__()
|
||||
cm = self._app_context.enter()
|
||||
self._local.cm = cm
|
||||
cm.__enter__()
|
||||
|
||||
# Restore user in new app context
|
||||
if saved_user is not None:
|
||||
g._login_user = saved_user
|
||||
if self._user is not None:
|
||||
g._login_user = self._user
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
"""Exit the Flask execution context."""
|
||||
if hasattr(self, "_cm"):
|
||||
self._cm.__exit__(*args)
|
||||
cm = getattr(self._local, "cm", None)
|
||||
if cm is not None:
|
||||
cm.__exit__(*args)
|
||||
|
||||
@contextmanager
|
||||
def enter(self) -> Generator[None, None, None]:
|
||||
"""Enter Flask execution context as context manager."""
|
||||
# Restore context variables
|
||||
# Restore non-Flask context variables to avoid leaking Flask tokens across threads
|
||||
for var, val in self._context_vars.items():
|
||||
var.set(val)
|
||||
|
||||
# Save current user from g if available
|
||||
saved_user = None
|
||||
if hasattr(g, "_login_user"):
|
||||
saved_user = g._login_user
|
||||
|
||||
# Enter Flask app context
|
||||
with self._flask_app.app_context():
|
||||
# Restore user in new app context
|
||||
if saved_user is not None:
|
||||
g._login_user = saved_user
|
||||
if self._user is not None:
|
||||
g._login_user = self._user
|
||||
yield
|
||||
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ Execution Context - Abstracted context management for workflow execution.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
@ -88,6 +89,7 @@ class ExecutionContext:
|
||||
self._app_context = app_context
|
||||
self._context_vars = context_vars
|
||||
self._user = user
|
||||
self._local = threading.local()
|
||||
|
||||
@property
|
||||
def app_context(self) -> AppContext | None:
|
||||
@ -125,14 +127,16 @@ class ExecutionContext:
|
||||
|
||||
def __enter__(self) -> "ExecutionContext":
|
||||
"""Enter the execution context."""
|
||||
self._cm = self.enter()
|
||||
self._cm.__enter__()
|
||||
cm = self.enter()
|
||||
self._local.cm = cm
|
||||
cm.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
"""Exit the execution context."""
|
||||
if hasattr(self, "_cm"):
|
||||
self._cm.__exit__(*args)
|
||||
cm = getattr(self._local, "cm", None)
|
||||
if cm is not None:
|
||||
cm.__exit__(*args)
|
||||
|
||||
|
||||
class NullAppContext(AppContext):
|
||||
|
||||
@ -11,7 +11,6 @@ import time
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, final
|
||||
from uuid import uuid4
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
@ -113,7 +112,7 @@ class Worker(threading.Thread):
|
||||
self._ready_queue.task_done()
|
||||
except Exception as e:
|
||||
error_event = NodeRunFailedEvent(
|
||||
id=str(uuid4()),
|
||||
id=node.execution_id,
|
||||
node_id=node.id,
|
||||
node_type=node.node_type,
|
||||
in_iteration_id=None,
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
"""Tests for execution context module."""
|
||||
|
||||
import contextvars
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@ -149,6 +151,54 @@ class TestExecutionContext:
|
||||
|
||||
assert ctx.user == user
|
||||
|
||||
def test_thread_safe_context_manager(self):
|
||||
"""Test shared ExecutionContext works across threads without token mismatch."""
|
||||
test_var = contextvars.ContextVar("thread_safe_test_var")
|
||||
|
||||
class TrackingAppContext(AppContext):
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
return default
|
||||
|
||||
def get_extension(self, name: str) -> Any:
|
||||
return None
|
||||
|
||||
@contextmanager
|
||||
def enter(self):
|
||||
token = test_var.set(threading.get_ident())
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
test_var.reset(token)
|
||||
|
||||
ctx = ExecutionContext(app_context=TrackingAppContext())
|
||||
errors: list[Exception] = []
|
||||
barrier = threading.Barrier(2)
|
||||
|
||||
def worker():
|
||||
try:
|
||||
for _ in range(20):
|
||||
with ctx:
|
||||
try:
|
||||
barrier.wait()
|
||||
barrier.wait()
|
||||
except threading.BrokenBarrierError:
|
||||
return
|
||||
except Exception as exc:
|
||||
errors.append(exc)
|
||||
try:
|
||||
barrier.abort()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
t1 = threading.Thread(target=worker)
|
||||
t2 = threading.Thread(target=worker)
|
||||
t1.start()
|
||||
t2.start()
|
||||
t1.join(timeout=5)
|
||||
t2.join(timeout=5)
|
||||
|
||||
assert not errors
|
||||
|
||||
|
||||
class TestIExecutionContextProtocol:
|
||||
"""Test IExecutionContext protocol."""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user