mirror of
https://github.com/langgenius/dify.git
synced 2026-01-14 06:07:33 +08:00
refactor(sandbox-manager): implement sharded locking for sandbox management
- Enhanced the SandboxManager to use a sharded locking mechanism for improved concurrency and performance.
- Replaced the global lock with shard-specific locks, allowing for lock-free reads and reducing contention.
- Updated methods for registering, retrieving, unregistering, and counting sandboxes to work with the new sharded structure.
- Improved documentation within the class to clarify the purpose and functionality of the sharding approach.
This commit is contained in:
parent
0da4d64d38
commit
3b454fa95a
@ -8,56 +8,93 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SandboxManager:
|
||||
_lock: Final[threading.Lock] = threading.Lock()
|
||||
_sandboxes: dict[str, VirtualEnvironment] = {}
|
||||
"""Process-local registry for workflow sandboxes.
|
||||
|
||||
Stores `VirtualEnvironment` references keyed by `workflow_execution_id`.
|
||||
|
||||
Concurrency: the registry is split into hash shards and each shard is updated with
|
||||
copy-on-write under a shard lock. Reads are lock-free (snapshot dict) to reduce
|
||||
contention in hot paths like `get()`.
|
||||
"""
|
||||
|
||||
# FIXME: Prefer a workflow-level context on GraphRuntimeState to store workflow-scoped shared objects.
|
||||
|
||||
_NUM_SHARDS: Final[int] = 1024
|
||||
_SHARD_MASK: Final[int] = _NUM_SHARDS - 1
|
||||
|
||||
_shard_locks: Final[tuple[threading.Lock, ...]] = tuple(threading.Lock() for _ in range(_NUM_SHARDS))
|
||||
_shards: list[dict[str, VirtualEnvironment]] = [{} for _ in range(_NUM_SHARDS)]
|
||||
|
||||
@classmethod
|
||||
def _shard_index(cls, workflow_execution_id: str) -> int:
|
||||
return hash(workflow_execution_id) & cls._SHARD_MASK
|
||||
|
||||
@classmethod
|
||||
def register(cls, workflow_execution_id: str, sandbox: VirtualEnvironment) -> None:
|
||||
if not workflow_execution_id:
|
||||
raise ValueError("workflow_execution_id cannot be empty")
|
||||
|
||||
with cls._lock:
|
||||
if workflow_execution_id in cls._sandboxes:
|
||||
shard_index = cls._shard_index(workflow_execution_id)
|
||||
with cls._shard_locks[shard_index]:
|
||||
shard = cls._shards[shard_index]
|
||||
if workflow_execution_id in shard:
|
||||
raise RuntimeError(
|
||||
f"Sandbox already registered for workflow_execution_id={workflow_execution_id}. "
|
||||
"Call unregister() first if you need to replace it."
|
||||
)
|
||||
cls._sandboxes[workflow_execution_id] = sandbox
|
||||
logger.debug(
|
||||
"Registered sandbox for workflow_execution_id=%s, sandbox_id=%s",
|
||||
workflow_execution_id,
|
||||
sandbox.metadata.id,
|
||||
)
|
||||
|
||||
new_shard = dict(shard)
|
||||
new_shard[workflow_execution_id] = sandbox
|
||||
cls._shards[shard_index] = new_shard
|
||||
|
||||
logger.debug(
|
||||
"Registered sandbox for workflow_execution_id=%s, sandbox_id=%s",
|
||||
workflow_execution_id,
|
||||
sandbox.metadata.id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get(cls, workflow_execution_id: str) -> VirtualEnvironment | None:
|
||||
with cls._lock:
|
||||
return cls._sandboxes.get(workflow_execution_id)
|
||||
shard_index = cls._shard_index(workflow_execution_id)
|
||||
return cls._shards[shard_index].get(workflow_execution_id)
|
||||
|
||||
@classmethod
|
||||
def unregister(cls, workflow_execution_id: str) -> VirtualEnvironment | None:
|
||||
with cls._lock:
|
||||
sandbox = cls._sandboxes.pop(workflow_execution_id, None)
|
||||
if sandbox:
|
||||
logger.debug(
|
||||
"Unregistered sandbox for workflow_execution_id=%s, sandbox_id=%s",
|
||||
workflow_execution_id,
|
||||
sandbox.metadata.id,
|
||||
)
|
||||
return sandbox
|
||||
shard_index = cls._shard_index(workflow_execution_id)
|
||||
with cls._shard_locks[shard_index]:
|
||||
shard = cls._shards[shard_index]
|
||||
sandbox = shard.get(workflow_execution_id)
|
||||
if sandbox is None:
|
||||
return None
|
||||
|
||||
new_shard = dict(shard)
|
||||
new_shard.pop(workflow_execution_id, None)
|
||||
cls._shards[shard_index] = new_shard
|
||||
|
||||
logger.debug(
|
||||
"Unregistered sandbox for workflow_execution_id=%s, sandbox_id=%s",
|
||||
workflow_execution_id,
|
||||
sandbox.metadata.id,
|
||||
)
|
||||
return sandbox
|
||||
|
||||
@classmethod
|
||||
def has(cls, workflow_execution_id: str) -> bool:
|
||||
with cls._lock:
|
||||
return workflow_execution_id in cls._sandboxes
|
||||
shard_index = cls._shard_index(workflow_execution_id)
|
||||
return workflow_execution_id in cls._shards[shard_index]
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
with cls._lock:
|
||||
cls._sandboxes.clear()
|
||||
for lock in cls._shard_locks:
|
||||
lock.acquire()
|
||||
try:
|
||||
for i in range(cls._NUM_SHARDS):
|
||||
cls._shards[i] = {}
|
||||
logger.debug("Cleared all registered sandboxes")
|
||||
finally:
|
||||
for lock in reversed(cls._shard_locks):
|
||||
lock.release()
|
||||
|
||||
@classmethod
|
||||
def count(cls) -> int:
|
||||
with cls._lock:
|
||||
return len(cls._sandboxes)
|
||||
return sum(len(shard) for shard in cls._shards)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user