refactor(sandbox-manager): implement sharded locking for sandbox management

- Enhanced the SandboxManager to use a sharded locking mechanism for improved concurrency and performance.
- Replaced the global lock with shard-specific locks, allowing for lock-free reads and reducing contention.
- Updated methods for registering, retrieving, unregistering, and counting sandboxes to work with the new sharded structure.
- Improved documentation within the class to clarify the purpose and functionality of the sharding approach.
This commit is contained in:
Harry 2026-01-09 12:13:41 +08:00
parent 0da4d64d38
commit 3b454fa95a

View File

@ -8,56 +8,93 @@ logger = logging.getLogger(__name__)
class SandboxManager:
_lock: Final[threading.Lock] = threading.Lock()
_sandboxes: dict[str, VirtualEnvironment] = {}
"""Process-local registry for workflow sandboxes.
Stores `VirtualEnvironment` references keyed by `workflow_execution_id`.
Concurrency: the registry is split into hash shards and each shard is updated with
copy-on-write under a shard lock. Reads are lock-free (snapshot dict) to reduce
contention in hot paths like `get()`.
"""
# FIXME: Prefer a workflow-level context on GraphRuntimeState to store workflow-scoped shared objects.
_NUM_SHARDS: Final[int] = 1024
_SHARD_MASK: Final[int] = _NUM_SHARDS - 1
_shard_locks: Final[tuple[threading.Lock, ...]] = tuple(threading.Lock() for _ in range(_NUM_SHARDS))
_shards: list[dict[str, VirtualEnvironment]] = [{} for _ in range(_NUM_SHARDS)]
@classmethod
def _shard_index(cls, workflow_execution_id: str) -> int:
return hash(workflow_execution_id) & cls._SHARD_MASK
@classmethod
def register(cls, workflow_execution_id: str, sandbox: VirtualEnvironment) -> None:
if not workflow_execution_id:
raise ValueError("workflow_execution_id cannot be empty")
with cls._lock:
if workflow_execution_id in cls._sandboxes:
shard_index = cls._shard_index(workflow_execution_id)
with cls._shard_locks[shard_index]:
shard = cls._shards[shard_index]
if workflow_execution_id in shard:
raise RuntimeError(
f"Sandbox already registered for workflow_execution_id={workflow_execution_id}. "
"Call unregister() first if you need to replace it."
)
cls._sandboxes[workflow_execution_id] = sandbox
logger.debug(
"Registered sandbox for workflow_execution_id=%s, sandbox_id=%s",
workflow_execution_id,
sandbox.metadata.id,
)
new_shard = dict(shard)
new_shard[workflow_execution_id] = sandbox
cls._shards[shard_index] = new_shard
logger.debug(
"Registered sandbox for workflow_execution_id=%s, sandbox_id=%s",
workflow_execution_id,
sandbox.metadata.id,
)
@classmethod
def get(cls, workflow_execution_id: str) -> VirtualEnvironment | None:
with cls._lock:
return cls._sandboxes.get(workflow_execution_id)
shard_index = cls._shard_index(workflow_execution_id)
return cls._shards[shard_index].get(workflow_execution_id)
@classmethod
def unregister(cls, workflow_execution_id: str) -> VirtualEnvironment | None:
with cls._lock:
sandbox = cls._sandboxes.pop(workflow_execution_id, None)
if sandbox:
logger.debug(
"Unregistered sandbox for workflow_execution_id=%s, sandbox_id=%s",
workflow_execution_id,
sandbox.metadata.id,
)
return sandbox
shard_index = cls._shard_index(workflow_execution_id)
with cls._shard_locks[shard_index]:
shard = cls._shards[shard_index]
sandbox = shard.get(workflow_execution_id)
if sandbox is None:
return None
new_shard = dict(shard)
new_shard.pop(workflow_execution_id, None)
cls._shards[shard_index] = new_shard
logger.debug(
"Unregistered sandbox for workflow_execution_id=%s, sandbox_id=%s",
workflow_execution_id,
sandbox.metadata.id,
)
return sandbox
@classmethod
def has(cls, workflow_execution_id: str) -> bool:
with cls._lock:
return workflow_execution_id in cls._sandboxes
shard_index = cls._shard_index(workflow_execution_id)
return workflow_execution_id in cls._shards[shard_index]
@classmethod
def clear(cls) -> None:
with cls._lock:
cls._sandboxes.clear()
for lock in cls._shard_locks:
lock.acquire()
try:
for i in range(cls._NUM_SHARDS):
cls._shards[i] = {}
logger.debug("Cleared all registered sandboxes")
finally:
for lock in reversed(cls._shard_locks):
lock.release()
@classmethod
def count(cls) -> int:
with cls._lock:
return len(cls._sandboxes)
return sum(len(shard) for shard in cls._shards)