feat(sandbox-layer): refactor sandbox management and integrate with SandboxManager

- Simplified the SandboxLayer initialization by removing unused parameters and consolidating sandbox creation logic.
- Integrated SandboxManager for better lifecycle management of sandboxes during workflow execution.
- Updated error handling to ensure proper initialization and cleanup of sandboxes.
- Enhanced CommandNode to retrieve sandboxes from SandboxManager, improving sandbox availability checks.
- Added unit tests to validate the new sandbox management approach and ensure robust error handling.
This commit is contained in:
Harry 2026-01-09 11:08:55 +08:00
parent b09a831d15
commit 0da4d64d38
7 changed files with 481 additions and 270 deletions

View File

@ -1,185 +1,95 @@
"""
Sandbox Layer for managing VirtualEnvironment lifecycle during workflow execution.
"""
import contextlib
import logging
from collections.abc import Mapping
from typing import Any
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.factory import SandboxFactory, SandboxType
from core.workflow.enums import NodeType
from core.virtual_environment.sandbox_manager import SandboxManager
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.nodes.base.node import Node
logger = logging.getLogger(__name__)
class SandboxInitializationError(Exception):
    """Signal that the sandbox for a workflow run could not be set up.

    Carries no extra state; the message passed by the raiser describes the
    underlying failure.
    """
class SandboxLayer(GraphEngineLayer):
"""
Manages VirtualEnvironment (sandbox) lifecycle during workflow execution.
Responsibilities:
- on_graph_start: Initialize the sandbox environment
- on_graph_end: Release the sandbox environment (cleanup)
Example:
# Using tenant-specific configuration (recommended):
layer = SandboxLayer(tenant_id="tenant-uuid")
# Using explicit configuration (for testing/override):
layer = SandboxLayer(
sandbox_type=SandboxType.DOCKER,
options={"docker_image": "python:3.11-slim"},
)
graph_engine.layer(layer)
# During workflow execution, access sandbox via:
# layer.sandbox.execute_command(...)
"""
def __init__(
self,
tenant_id: str | None = None,
sandbox_type: SandboxType | None = None,
tenant_id: str,
options: Mapping[str, Any] | None = None,
environments: Mapping[str, str] | None = None,
) -> None:
"""
Initialize the SandboxLayer.
Args:
tenant_id: Tenant ID to load sandbox configuration from database.
If provided, sandbox_type and options are ignored and
loaded from the tenant's active sandbox provider.
sandbox_type: Type of sandbox to create (default: DOCKER).
Only used if tenant_id is not provided.
options: Sandbox-specific configuration options.
Only used if tenant_id is not provided.
environments: Environment variables to set in the sandbox.
"""
super().__init__()
self._tenant_id = tenant_id
self._sandbox_type = sandbox_type
self._options: Mapping[str, Any] = options or {}
self._environments: Mapping[str, str] = environments or {}
self._sandbox: VirtualEnvironment | None = None
self._workflow_execution_id: str | None = None
def _get_workflow_execution_id(self) -> str:
    """Read the workflow execution ID from the graph runtime state.

    Returns:
        The non-empty ``workflow_execution_id`` system variable.

    Raises:
        RuntimeError: If the system variable is missing or empty.
    """
    execution_id = self.graph_runtime_state.system_variable.workflow_execution_id
    if execution_id:
        return execution_id
    raise RuntimeError("workflow_execution_id is not set in system variables")
@property
def sandbox(self) -> VirtualEnvironment:
"""
Get the current sandbox instance.
Returns:
The initialized VirtualEnvironment instance
Raises:
RuntimeError: If sandbox has not been initialized
"""
if self._sandbox is None:
if self._workflow_execution_id is None:
raise RuntimeError("Sandbox not initialized. Ensure on_graph_start() has been called.")
return self._sandbox
sandbox = SandboxManager.get(self._workflow_execution_id)
if sandbox is None:
raise RuntimeError(f"Sandbox not found for workflow_execution_id={self._workflow_execution_id}")
return sandbox
def on_graph_start(self) -> None:
"""
Initialize the sandbox when workflow execution starts.
self._workflow_execution_id = self._get_workflow_execution_id()
If tenant_id was provided, uses SandboxProviderService to create
the sandbox with the tenant's active provider configuration.
Otherwise, falls back to explicit sandbox_type/options.
Raises:
SandboxInitializationError: If sandbox cannot be created
"""
try:
if self._tenant_id:
# Use SandboxProviderService to create sandbox based on tenant config
from services.sandbox.sandbox_provider_service import SandboxProviderService
sandbox: VirtualEnvironment
from services.sandbox.sandbox_provider_service import SandboxProviderService
logger.info("Initializing sandbox for tenant_id=%s", self._tenant_id)
self._sandbox = SandboxProviderService.create_sandbox(
tenant_id=self._tenant_id,
environments=self._environments,
)
else:
# Fallback to explicit configuration (backward compatibility)
sandbox_type = self._sandbox_type or SandboxType.DOCKER
logger.info("Initializing sandbox, sandbox_type=%s", sandbox_type)
# Use a placeholder tenant_id for backward compatibility when tenant_id is not provided
effective_tenant_id = self._tenant_id or "default"
self._sandbox = SandboxFactory.create(
tenant_id=effective_tenant_id,
sandbox_type=sandbox_type,
options=self._options,
environments=self._environments,
)
logger.info("Initializing sandbox for tenant_id=%s", self._tenant_id)
sandbox = SandboxProviderService.create_sandbox(
tenant_id=self._tenant_id,
environments=self._environments,
)
SandboxManager.register(self._workflow_execution_id, sandbox)
logger.info(
"Sandbox initialized, sandbox_id=%s, sandbox_arch=%s",
self._sandbox.metadata.id,
self._sandbox.metadata.arch,
"Sandbox initialized, workflow_execution_id=%s, sandbox_id=%s, sandbox_arch=%s",
self._workflow_execution_id,
sandbox.metadata.id,
sandbox.metadata.arch,
)
except Exception as e:
logger.exception("Failed to initialize sandbox")
raise SandboxInitializationError(f"Failed to initialize sandbox: {e}") from e
def on_event(self, event: GraphEngineEvent) -> None:
    """Receive a graph engine event and deliberately ignore it.

    The hook is intentionally empty; it exists as an extension point for
    sandbox monitoring or health checks.
    """
    return None
def on_node_run_start(self, node: Node[Any]) -> None:
    """Hand the layer's sandbox to a command node just before it runs.

    Nodes of any other type are left untouched. A failure to obtain or attach
    the sandbox is logged and swallowed, so this hook never raises.
    """
    if node.node_type is NodeType.COMMAND:
        try:
            # FIXME: ``sandbox`` is not declared on the generic Node type,
            # hence the attr-defined suppression.
            node.sandbox = self.sandbox  # type: ignore[attr-defined]
        except Exception:
            logger.exception("Failed to attach sandbox to node")
def on_node_run_end(self, node: Node[Any], error: Exception | None) -> None:
    """Clear the sandbox reference from a command node after it has run.

    ``error`` is accepted to satisfy the hook signature and is not used.
    Nodes of any other type are left untouched, and any exception raised
    while clearing the attribute is suppressed.
    """
    del error  # unused; part of the hook signature only
    if node.node_type is NodeType.COMMAND:
        with contextlib.suppress(Exception):
            # FIXME: ``sandbox`` is not declared on the generic Node type,
            # hence the attr-defined suppression.
            node.sandbox = None  # type: ignore[attr-defined]
def on_graph_end(self, error: Exception | None) -> None:
"""
Release the sandbox when workflow execution ends.
This method is idempotent and will not raise exceptions on cleanup failure.
Args:
error: The exception that caused execution to fail, or None if successful
"""
if self._sandbox is None:
logger.debug("No sandbox to release")
if self._workflow_execution_id is None:
logger.debug("No workflow_execution_id set, nothing to release")
return
sandbox_id = self._sandbox.metadata.id
logger.info("Releasing sandbox, sandbox_id=%s", sandbox_id)
sandbox = SandboxManager.unregister(self._workflow_execution_id)
if sandbox is None:
logger.debug("No sandbox to release for workflow_execution_id=%s", self._workflow_execution_id)
return
sandbox_id = sandbox.metadata.id
logger.info(
"Releasing sandbox, workflow_execution_id=%s, sandbox_id=%s",
self._workflow_execution_id,
sandbox_id,
)
try:
self._sandbox.release_environment()
sandbox.release_environment()
logger.info("Sandbox released, sandbox_id=%s", sandbox_id)
except Exception:
# Log but don't raise - cleanup failures should not break workflow completion
logger.exception("Failed to release sandbox, sandbox_id=%s", sandbox_id)
finally:
self._sandbox = None
self._workflow_execution_id = None

View File

@ -0,0 +1,63 @@
import logging
import threading
from typing import Final
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
logger = logging.getLogger(__name__)
class SandboxManager:
    """Class-level registry that maps a workflow execution ID to its sandbox.

    State is stored on the class itself, so every caller in the process shares
    one registry. Each accessor takes ``_lock`` around its dictionary
    operation.

    The registry only tracks sandboxes: none of its methods call
    ``release_environment()``. Whoever removes an entry is responsible for
    releasing the sandbox it gets back.
    """

    # Guards every read and write of ``_sandboxes``.
    _lock: Final[threading.Lock] = threading.Lock()
    # workflow_execution_id -> sandbox registered for that execution.
    _sandboxes: dict[str, VirtualEnvironment] = {}

    @classmethod
    def register(cls, workflow_execution_id: str, sandbox: VirtualEnvironment) -> None:
        """Associate ``sandbox`` with ``workflow_execution_id``.

        Args:
            workflow_execution_id: Non-empty key identifying the execution.
            sandbox: The sandbox to track under that key.

        Raises:
            ValueError: If ``workflow_execution_id`` is empty.
            RuntimeError: If a sandbox is already registered under that ID;
                the existing entry is left in place.
        """
        if not workflow_execution_id:
            raise ValueError("workflow_execution_id cannot be empty")

        with cls._lock:
            if workflow_execution_id in cls._sandboxes:
                raise RuntimeError(
                    f"Sandbox already registered for workflow_execution_id={workflow_execution_id}. "
                    "Call unregister() first if you need to replace it."
                )
            cls._sandboxes[workflow_execution_id] = sandbox

        logger.debug(
            "Registered sandbox for workflow_execution_id=%s, sandbox_id=%s",
            workflow_execution_id,
            sandbox.metadata.id,
        )

    @classmethod
    def get(cls, workflow_execution_id: str) -> VirtualEnvironment | None:
        """Return the sandbox registered under the ID, or ``None`` if there is none."""
        with cls._lock:
            return cls._sandboxes.get(workflow_execution_id)

    @classmethod
    def unregister(cls, workflow_execution_id: str) -> VirtualEnvironment | None:
        """Remove and return the sandbox registered under the ID.

        The sandbox is not released here; the caller must release it.

        Returns:
            The removed sandbox, or ``None`` if nothing was registered.
        """
        with cls._lock:
            sandbox = cls._sandboxes.pop(workflow_execution_id, None)

        # Compare against None explicitly: a sandbox object that happens to
        # evaluate as falsy is still a real entry that was just removed.
        if sandbox is not None:
            logger.debug(
                "Unregistered sandbox for workflow_execution_id=%s, sandbox_id=%s",
                workflow_execution_id,
                sandbox.metadata.id,
            )
        return sandbox

    @classmethod
    def has(cls, workflow_execution_id: str) -> bool:
        """Return whether a sandbox is currently registered under the ID."""
        with cls._lock:
            return workflow_execution_id in cls._sandboxes

    @classmethod
    def clear(cls) -> None:
        """Forget every registered sandbox.

        Entries are dropped without being released, so any sandbox still
        registered at this point is no longer reachable through the registry.
        """
        with cls._lock:
            cls._sandboxes.clear()
        logger.debug("Cleared all registered sandboxes")

    @classmethod
    def count(cls) -> int:
        """Return the number of sandboxes currently registered."""
        with cls._lock:
            return len(cls._sandboxes)

View File

@ -6,6 +6,7 @@ from typing import Any
from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.sandbox_manager import SandboxManager
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
@ -21,14 +22,14 @@ COMMAND_NODE_TIMEOUT_SECONDS = 60
class CommandNode(Node[CommandNodeData]):
# FIXME: This is a temporary solution for sandbox injection from SandboxLayer.
# The sandbox is dynamically attached by SandboxLayer.on_node_run_start() before
# node execution and cleared by on_node_run_end(). A cleaner approach would be
# to pass sandbox through GraphRuntimeState or use a proper dependency injection pattern.
sandbox: VirtualEnvironment | None = None
node_type = NodeType.COMMAND
def _get_sandbox(self) -> VirtualEnvironment | None:
    """Look up the sandbox registered for this node's workflow execution.

    Returns:
        The sandbox that SandboxManager holds for the current
        ``workflow_execution_id``, or ``None`` when the ID is unset or no
        sandbox is registered under it.
    """
    execution_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
    return SandboxManager.get(execution_id) if execution_id else None
def _render_template(self, template: str) -> str:
parser = VariableTemplateParser(template=template)
selectors = parser.extract_variable_selectors()
@ -57,7 +58,8 @@ class CommandNode(Node[CommandNodeData]):
return "1"
def _run(self) -> NodeRunResult:
if not isinstance(self.sandbox, VirtualEnvironment):
sandbox = self._get_sandbox()
if sandbox is None:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error="Sandbox not available for CommandNode.",
@ -77,15 +79,15 @@ class CommandNode(Node[CommandNodeData]):
)
timeout = COMMAND_NODE_TIMEOUT_SECONDS if COMMAND_NODE_TIMEOUT_SECONDS > 0 else None
connection_handle = self.sandbox.establish_connection()
connection_handle = sandbox.establish_connection()
try:
# FIXME: VirtualEnvironment.run_command lacks native cwd support.
# TODO: VirtualEnvironment.run_command lacks native cwd support.
# Once the interface adds a `cwd` parameter, remove this shell hack
# and pass working_directory directly to run_command.
if working_directory:
check_cmd = ["test", "-d", working_directory]
check_future = self.sandbox.run_command(connection_handle, check_cmd)
check_future = sandbox.run_command(connection_handle, check_cmd)
check_result = check_future.result(timeout=timeout)
if check_result.exit_code != 0:
@ -99,7 +101,7 @@ class CommandNode(Node[CommandNodeData]):
else:
command = shlex.split(raw_command)
future = self.sandbox.run_command(connection_handle, command)
future = sandbox.run_command(connection_handle, command)
result = future.result(timeout=timeout)
outputs: dict[str, Any] = {
@ -149,7 +151,7 @@ class CommandNode(Node[CommandNodeData]):
)
finally:
with contextlib.suppress(Exception):
self.sandbox.release_connection(connection_handle)
sandbox.release_connection(connection_handle)
@classmethod
def _extract_variable_selector_to_variable_mapping(

View File

@ -697,11 +697,17 @@ class WorkflowService:
else:
enclosing_node_id = None
# FIXME: Consolidate runtime config checking into a unified location.
# TODO: Consolidate runtime config checking into a unified location.
runtime = draft_workflow.features_dict.get("runtime")
sandbox = None
single_step_execution_id: str | None = None
if isinstance(runtime, dict) and runtime.get("enabled"):
sandbox = SandboxProviderService.create_sandbox(tenant_id=draft_workflow.tenant_id)
single_step_execution_id = f"single-step-{uuid.uuid4()}"
from core.virtual_environment.sandbox_manager import SandboxManager
SandboxManager.register(single_step_execution_id, sandbox)
variable_pool.system_variables.workflow_execution_id = single_step_execution_id
try:
node, generator = WorkflowEntry.single_step_run(
@ -713,10 +719,6 @@ class WorkflowService:
variable_loader=variable_loader,
)
# FIXME: Use a proper dependency injection pattern for sandbox.
if sandbox:
node.sandbox = sandbox # type: ignore[attr-defined]
# Run draft workflow node
start_at = time.perf_counter()
node_execution = self._handle_single_step_result(
@ -725,12 +727,15 @@ class WorkflowService:
node_id=node_id,
)
finally:
# Release sandbox after node execution
if sandbox:
try:
sandbox.release_environment()
except Exception:
logger.exception("Failed to release sandbox")
if single_step_execution_id:
from core.virtual_environment.sandbox_manager import SandboxManager
sandbox = SandboxManager.unregister(single_step_execution_id)
if sandbox:
try:
sandbox.release_environment()
except Exception:
logger.exception("Failed to release sandbox")
# Set workflow_id on the NodeExecution
node_execution.workflow_id = draft_workflow.id

View File

@ -1,11 +1,3 @@
"""
Unit tests for the SandboxLayer.
This module tests the SandboxLayer lifecycle management including initialization,
event handling, and cleanup of VirtualEnvironment instances.
"""
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
@ -13,7 +5,7 @@ import pytest
from core.app.layers.sandbox_layer import SandboxInitializationError, SandboxLayer
from core.virtual_environment.__base.entities import Arch
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.factory import SandboxFactory, SandboxType
from core.virtual_environment.sandbox_manager import SandboxManager
from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError
from core.workflow.graph_events.graph import (
GraphRunFailedEvent,
@ -23,16 +15,12 @@ from core.workflow.graph_events.graph import (
class MockMetadata:
    """Test double exposing the two metadata fields the layer reads: ``id`` and ``arch``."""

    def __init__(self, sandbox_id: str = "test-sandbox-id", arch: Arch = Arch.AMD64):
        self.id, self.arch = sandbox_id, arch
class MockVirtualEnvironment:
"""Mock VirtualEnvironment for testing."""
def __init__(self, sandbox_id: str = "test-sandbox-id"):
self.metadata = MockMetadata(sandbox_id=sandbox_id)
self._released = False
@ -41,33 +29,46 @@ class MockVirtualEnvironment:
self._released = True
class MockSystemVariableView:
def __init__(self, workflow_execution_id: str | None = "test-workflow-exec-id"):
self._workflow_execution_id = workflow_execution_id
@property
def workflow_execution_id(self) -> str | None:
return self._workflow_execution_id
class MockReadOnlyGraphRuntimeStateWrapper:
    """Test double for the runtime state; exposes only ``system_variable``."""

    def __init__(self, workflow_execution_id: str | None = "test-workflow-exec-id"):
        # The view is built once here, so every access returns the same object.
        self._system_variable = MockSystemVariableView(workflow_execution_id)

    @property
    def system_variable(self) -> MockSystemVariableView:
        """The system-variable view created for this wrapper."""
        return self._system_variable
@pytest.fixture(autouse=True)
def clean_sandbox_manager():
    """Empty the shared SandboxManager registry before and after every test."""
    # Setup: start each test from an empty registry.
    SandboxManager.clear()
    yield
    # Teardown: drop anything the test left registered.
    SandboxManager.clear()
class TestSandboxLayer:
"""Unit tests for SandboxLayer."""
def test_init_with_default_parameters(self):
"""Test SandboxLayer initialization with default parameters."""
layer = SandboxLayer()
assert layer._sandbox_type is None # pyright: ignore[reportPrivateUsage]
assert layer._options == {} # pyright: ignore[reportPrivateUsage]
assert layer._environments == {} # pyright: ignore[reportPrivateUsage]
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
def test_init_with_custom_parameters(self):
"""Test SandboxLayer initialization with custom parameters."""
def test_init_with_parameters(self):
layer = SandboxLayer(
sandbox_type=SandboxType.LOCAL,
tenant_id="test-tenant",
options={"base_working_path": "/tmp/sandbox"},
environments={"PYTHONUNBUFFERED": "1"},
)
assert layer._sandbox_type == SandboxType.LOCAL # pyright: ignore[reportPrivateUsage]
assert layer._tenant_id == "test-tenant" # pyright: ignore[reportPrivateUsage]
assert layer._options == {"base_working_path": "/tmp/sandbox"} # pyright: ignore[reportPrivateUsage]
assert layer._environments == {"PYTHONUNBUFFERED": "1"} # pyright: ignore[reportPrivateUsage]
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
def test_sandbox_property_raises_when_not_initialized(self):
"""Test that accessing sandbox property raises error before initialization."""
layer = SandboxLayer()
layer = SandboxLayer(tenant_id="test-tenant")
with pytest.raises(RuntimeError) as exc_info:
_ = layer.sandbox
@ -75,170 +76,213 @@ class TestSandboxLayer:
assert "Sandbox not initialized" in str(exc_info.value)
def test_sandbox_property_returns_sandbox_after_initialization(self):
"""Test that sandbox property returns the sandbox after on_graph_start."""
layer = SandboxLayer()
layer = SandboxLayer(tenant_id="test-tenant")
mock_sandbox = MockVirtualEnvironment()
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-id")
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
with patch.object(SandboxFactory, "create", return_value=mock_sandbox):
with patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
return_value=mock_sandbox,
):
layer.on_graph_start()
assert layer.sandbox is mock_sandbox
def test_on_graph_start_creates_sandbox(self):
"""Test that on_graph_start creates a sandbox via factory."""
def test_on_graph_start_creates_sandbox_and_registers_with_manager(self):
layer = SandboxLayer(
sandbox_type=SandboxType.DOCKER,
options={"docker_image": "python:3.11"},
tenant_id="test-tenant-123",
environments={"PATH": "/usr/bin"},
)
mock_sandbox = MockVirtualEnvironment()
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-123")
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
with patch.object(SandboxFactory, "create", return_value=mock_sandbox) as mock_create:
with patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
return_value=mock_sandbox,
) as mock_create:
layer.on_graph_start()
mock_create.assert_called_once_with(
tenant_id="default",
sandbox_type=SandboxType.DOCKER,
options={"docker_image": "python:3.11"},
tenant_id="test-tenant-123",
environments={"PATH": "/usr/bin"},
)
def test_on_graph_start_raises_sandbox_initialization_error_on_failure(self):
"""Test that on_graph_start raises SandboxInitializationError on factory failure."""
layer = SandboxLayer(sandbox_type=SandboxType.DOCKER)
assert SandboxManager.get("test-exec-123") is mock_sandbox
with patch.object(SandboxFactory, "create", side_effect=Exception("Docker not available")):
def test_on_graph_start_raises_sandbox_initialization_error_on_failure(self):
layer = SandboxLayer(tenant_id="test-tenant")
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-id")
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
with patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
side_effect=Exception("Sandbox provider not available"),
):
with pytest.raises(SandboxInitializationError) as exc_info:
layer.on_graph_start()
assert "Failed to initialize sandbox" in str(exc_info.value)
assert "Docker not available" in str(exc_info.value)
assert "Sandbox provider not available" in str(exc_info.value)
def test_on_graph_start_raises_when_workflow_execution_id_not_set(self):
layer = SandboxLayer(tenant_id="test-tenant")
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id=None)
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
with pytest.raises(RuntimeError) as exc_info:
layer.on_graph_start()
assert "workflow_execution_id is not set" in str(exc_info.value)
def test_on_event_is_noop(self):
"""Test that on_event does nothing (no-op)."""
layer = SandboxLayer()
layer = SandboxLayer(tenant_id="test-tenant")
# These should not raise any exceptions
layer.on_event(GraphRunStartedEvent())
layer.on_event(GraphRunSucceededEvent(outputs={}))
layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1))
def test_on_graph_end_releases_sandbox(self):
"""Test that on_graph_end releases the sandbox."""
layer = SandboxLayer()
def test_on_graph_end_releases_sandbox_and_unregisters_from_manager(self):
layer = SandboxLayer(tenant_id="test-tenant")
mock_sandbox = MagicMock(spec=VirtualEnvironment)
mock_sandbox.metadata = MockMetadata()
workflow_execution_id = "test-exec-456"
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
with patch.object(SandboxFactory, "create", return_value=mock_sandbox):
with patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
return_value=mock_sandbox,
):
layer.on_graph_start()
assert SandboxManager.has(workflow_execution_id)
layer.on_graph_end(error=None)
mock_sandbox.release_environment.assert_called_once()
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
assert not SandboxManager.has(workflow_execution_id)
def test_on_graph_end_releases_sandbox_even_on_error(self):
"""Test that on_graph_end releases sandbox even when workflow had an error."""
layer = SandboxLayer()
layer = SandboxLayer(tenant_id="test-tenant")
mock_sandbox = MagicMock(spec=VirtualEnvironment)
mock_sandbox.metadata = MockMetadata()
workflow_execution_id = "test-exec-789"
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
with patch.object(SandboxFactory, "create", return_value=mock_sandbox):
with patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
return_value=mock_sandbox,
):
layer.on_graph_start()
layer.on_graph_end(error=Exception("Workflow failed"))
mock_sandbox.release_environment.assert_called_once()
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
assert not SandboxManager.has(workflow_execution_id)
def test_on_graph_end_handles_release_failure_gracefully(self):
"""Test that on_graph_end handles release failures without raising."""
layer = SandboxLayer()
layer = SandboxLayer(tenant_id="test-tenant")
mock_sandbox = MagicMock(spec=VirtualEnvironment)
mock_sandbox.metadata = MockMetadata()
mock_sandbox.release_environment.side_effect = Exception("Container already removed")
workflow_execution_id = "test-exec-fail"
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
with patch.object(SandboxFactory, "create", return_value=mock_sandbox):
with patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
return_value=mock_sandbox,
):
layer.on_graph_start()
# Should not raise exception
layer.on_graph_end(error=None)
mock_sandbox.release_environment.assert_called_once()
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
def test_on_graph_end_noop_when_sandbox_not_initialized(self):
"""Test that on_graph_end is a no-op when sandbox was never initialized."""
layer = SandboxLayer()
layer = SandboxLayer(tenant_id="test-tenant")
# Should not raise exception
layer.on_graph_end(error=None)
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
def test_on_graph_end_is_idempotent(self):
"""Test that calling on_graph_end multiple times is safe."""
layer = SandboxLayer()
layer = SandboxLayer(tenant_id="test-tenant")
mock_sandbox = MagicMock(spec=VirtualEnvironment)
mock_sandbox.metadata = MockMetadata()
workflow_execution_id = "test-exec-idempotent"
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
with patch.object(SandboxFactory, "create", return_value=mock_sandbox):
with patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
return_value=mock_sandbox,
):
layer.on_graph_start()
layer.on_graph_end(error=None)
layer.on_graph_end(error=None) # Second call should be no-op
layer.on_graph_end(error=None)
mock_sandbox.release_environment.assert_called_once()
def test_layer_inherits_from_graph_engine_layer(self):
"""Test that SandboxLayer properly inherits from GraphEngineLayer."""
layer = SandboxLayer()
layer = SandboxLayer(tenant_id="test-tenant")
# Should have the graph_runtime_state property from base class
with pytest.raises(GraphEngineLayerNotInitializedError):
_ = layer.graph_runtime_state
# Should have command_channel from base class
assert layer.command_channel is None
class TestSandboxLayerIntegration:
"""Integration tests for SandboxLayer with real LocalVirtualEnvironment."""
def test_full_lifecycle_with_mocked_provider(self):
layer = SandboxLayer(tenant_id="integration-tenant")
workflow_execution_id = "integration-test-exec"
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
mock_sandbox = MagicMock(spec=VirtualEnvironment)
mock_sandbox.metadata = MockMetadata(sandbox_id="integration-sandbox")
def test_full_lifecycle_with_local_sandbox(self, tmp_path: Path):
"""Test complete lifecycle: init -> start -> end with local sandbox."""
layer = SandboxLayer(
sandbox_type=SandboxType.LOCAL,
options={"base_working_path": str(tmp_path)},
)
with patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
return_value=mock_sandbox,
):
layer.on_graph_start()
# Start
layer.on_graph_start()
assert layer._workflow_execution_id == workflow_execution_id # pyright: ignore[reportPrivateUsage]
assert layer.sandbox is mock_sandbox
assert SandboxManager.get(workflow_execution_id) is mock_sandbox
# Verify sandbox is created
assert layer._sandbox is not None # pyright: ignore[reportPrivateUsage]
sandbox_id = layer.sandbox.metadata.id
assert sandbox_id is not None
# End
layer.on_graph_end(error=None)
# Verify sandbox is released
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
assert not SandboxManager.has(workflow_execution_id)
mock_sandbox.release_environment.assert_called_once()
def test_lifecycle_with_workflow_error(self, tmp_path: Path):
"""Test lifecycle when workflow encounters an error."""
layer = SandboxLayer(
sandbox_type=SandboxType.LOCAL,
options={"base_working_path": str(tmp_path)},
)
def test_lifecycle_with_workflow_error(self):
layer = SandboxLayer(tenant_id="error-tenant")
workflow_execution_id = "integration-error-test"
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
mock_sandbox = MagicMock(spec=VirtualEnvironment)
mock_sandbox.metadata = MockMetadata()
with patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
return_value=mock_sandbox,
):
layer.on_graph_start()
layer.on_graph_start()
assert layer.sandbox.metadata.id is not None
# Simulate workflow error
layer.on_graph_end(error=Exception("Workflow execution failed"))
# Sandbox should still be cleaned up
# pyright: ignore[reportPrivateUsage]
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
assert not SandboxManager.has(workflow_execution_id)
mock_sandbox.release_environment.assert_called_once()

View File

@ -0,0 +1,153 @@
import threading
from collections.abc import Mapping
from io import BytesIO
from typing import Any
import pytest
from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.sandbox_manager import SandboxManager
class FakeVirtualEnvironment(VirtualEnvironment):
    """Minimal VirtualEnvironment subclass used to exercise SandboxManager.

    Only the metadata ID varies between instances. Every operation either
    does nothing, returns a fixed value, or raises ``NotImplementedError``.
    """

    def __init__(self, sandbox_id: str = "fake-id"):
        # Stored before super().__init__() — presumably the base initializer
        # calls _construct_environment(), which reads it; verify against
        # VirtualEnvironment.
        self._sandbox_id = sandbox_id
        super().__init__(tenant_id="test-tenant", options={}, environments={})

    def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
        """Build metadata carrying this fake's ID and a fixed AMD64 architecture."""
        return Metadata(id=self._sandbox_id, arch=Arch.AMD64)

    def upload_file(self, path: str, content: BytesIO) -> None:
        """Not supported by the fake."""
        raise NotImplementedError

    def download_file(self, path: str) -> BytesIO:
        """Not supported by the fake."""
        raise NotImplementedError

    def list_files(self, directory_path: str, limit: int) -> list[FileState]:
        """Report an empty directory regardless of the arguments."""
        return []

    def establish_connection(self) -> ConnectionHandle:
        """Return a handle with the fixed ID ``"conn"``."""
        return ConnectionHandle(id="conn")

    def release_connection(self, connection_handle: ConnectionHandle) -> None:
        """Do nothing; the fake holds no connection state."""
        return None

    def release_environment(self) -> None:
        """Do nothing; the fake holds no environment resources."""
        return None

    def execute_command(
        self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
    ) -> tuple[str, Any, Any, Any]:
        """Not supported by the fake."""
        raise NotImplementedError

    def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
        """Always report a completed command with exit code 0."""
        return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=0)

    @classmethod
    def validate(cls, options: Mapping[str, Any]) -> None:
        """Accept any options without checking them."""
        return None
@pytest.fixture(autouse=True)
def clean_sandbox_manager():
    """Empty the shared SandboxManager registry before and after every test."""
    # Setup: start each test from an empty registry.
    SandboxManager.clear()
    yield
    # Teardown: drop anything the test left registered.
    SandboxManager.clear()
class TestSandboxManager:
    """Behavioral tests for the SandboxManager registry."""

    def test_register_and_get(self):
        """A registered sandbox is returned by get() under the same ID."""
        env = FakeVirtualEnvironment("sandbox-1")
        SandboxManager.register("exec-1", env)

        assert SandboxManager.get("exec-1") is env

    def test_get_returns_none_for_unknown_id(self):
        """get() yields None for an ID that was never registered."""
        assert SandboxManager.get("unknown-id") is None

    def test_register_raises_on_empty_workflow_execution_id(self):
        """An empty execution ID is rejected with ValueError."""
        env = FakeVirtualEnvironment()

        with pytest.raises(ValueError, match="workflow_execution_id cannot be empty"):
            SandboxManager.register("", env)

    def test_register_raises_on_duplicate(self):
        """Registering a second sandbox under an existing ID raises RuntimeError."""
        first = FakeVirtualEnvironment("sandbox-1")
        second = FakeVirtualEnvironment("sandbox-2")
        SandboxManager.register("exec-dup", first)

        with pytest.raises(RuntimeError, match="already registered"):
            SandboxManager.register("exec-dup", second)

    def test_unregister_returns_sandbox(self):
        """unregister() hands back the sandbox and removes the entry."""
        env = FakeVirtualEnvironment("sandbox-to-remove")
        SandboxManager.register("exec-remove", env)

        removed = SandboxManager.unregister("exec-remove")

        assert removed is env
        assert SandboxManager.get("exec-remove") is None

    def test_unregister_returns_none_for_unknown(self):
        """unregister() yields None when nothing is registered under the ID."""
        assert SandboxManager.unregister("nonexistent") is None

    def test_has_returns_true_when_registered(self):
        """has() reports True for a registered ID."""
        SandboxManager.register("exec-has", FakeVirtualEnvironment())

        assert SandboxManager.has("exec-has") is True

    def test_has_returns_false_when_not_registered(self):
        """has() reports False for an unknown ID."""
        assert SandboxManager.has("exec-no") is False

    def test_clear_removes_all_sandboxes(self):
        """clear() empties the registry."""
        first = FakeVirtualEnvironment("s1")
        second = FakeVirtualEnvironment("s2")
        SandboxManager.register("exec-1", first)
        SandboxManager.register("exec-2", second)

        SandboxManager.clear()

        assert SandboxManager.count() == 0
        assert SandboxManager.get("exec-1") is None
        assert SandboxManager.get("exec-2") is None

    def test_count_returns_number_of_sandboxes(self):
        """count() follows registrations and removals."""
        assert SandboxManager.count() == 0

        SandboxManager.register("e1", FakeVirtualEnvironment("s1"))
        assert SandboxManager.count() == 1

        SandboxManager.register("e2", FakeVirtualEnvironment("s2"))
        assert SandboxManager.count() == 2

        SandboxManager.unregister("e1")
        assert SandboxManager.count() == 1

    def test_thread_safety(self):
        """Ten threads registering distinct IDs all succeed and are all counted."""
        successes: list[bool] = []
        failures: list[Exception] = []

        def worker(exec_id: str):
            try:
                env = FakeVirtualEnvironment(f"sandbox-{exec_id}")
                SandboxManager.register(exec_id, env)
                successes.append(True)
            except Exception as exc:
                failures.append(exc)

        workers = [threading.Thread(target=worker, args=(f"exec-{index}",)) for index in range(10)]
        for thread in workers:
            thread.start()
        for thread in workers:
            thread.join()

        assert len(failures) == 0
        assert len(successes) == 10
        assert SandboxManager.count() == 10

View File

@ -1,11 +1,15 @@
import time
from collections.abc import Mapping
from io import BytesIO
from typing import Any
import pytest
from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser
from core.virtual_environment.channel.transport import NopTransportWriteCloser
from core.virtual_environment.sandbox_manager import SandboxManager
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.nodes.command.node import CommandNode
@ -30,7 +34,7 @@ class FakeSandbox(VirtualEnvironment):
self.released_connections: list[str] = []
super().__init__(tenant_id="test-tenant", options={}, environments={})
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
    """Build the metadata describing this fake sandbox.

    Neither ``options`` nor ``environments`` is read here; the returned
    metadata always has id ``"fake"`` and the ARM64 arch.
    """
    return Metadata(id="fake", arch=Arch.ARM64)
def upload_file(self, path: str, content: BytesIO) -> None:
@ -51,7 +55,9 @@ class FakeSandbox(VirtualEnvironment):
def release_environment(self) -> None:
    """Do nothing; this fake performs no work on release."""
    return None
def execute_command(self, connection_handle: ConnectionHandle, command: list[str], environments=None): # type: ignore[override]
def execute_command(
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
) -> tuple[str, NopTransportWriteCloser, QueueTransportReadCloser, QueueTransportReadCloser]:
_ = connection_handle
_ = environments
self.last_execute_command = command
@ -76,12 +82,22 @@ class FakeSandbox(VirtualEnvironment):
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=0)
@classmethod
def validate(cls, options: Mapping[str, Any]) -> None:
    """Accept any ``options`` mapping; this fake performs no validation."""
    pass
def _make_node(*, command: str, working_directory: str = "") -> CommandNode:
variable_pool = VariablePool(system_variables=SystemVariable.empty(), user_inputs={})
@pytest.fixture(autouse=True)
def clean_sandbox_manager():
    """Clear ``SandboxManager`` before and after every test in this module."""
    # Setup: drop anything left registered before the test runs.
    SandboxManager.clear()
    yield
    # Teardown: drop whatever the test itself registered.
    SandboxManager.clear()
def _make_node(
*, command: str, working_directory: str = "", workflow_execution_id: str = "test-workflow-exec-id"
) -> CommandNode:
system_variables = SystemVariable(workflow_execution_id=workflow_execution_id)
variable_pool = VariablePool(system_variables=system_variables, user_inputs={})
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
init_params = GraphInitParams(
tenant_id="t",
@ -110,11 +126,16 @@ def _make_node(*, command: str, working_directory: str = "") -> CommandNode:
def test_command_node_success_executes_in_sandbox():
node = _make_node(command="echo {{#pre_node_id.number#}}", working_directory="dir-{{#pre_node_id.number#}}")
workflow_execution_id = "test-exec-success"
node = _make_node(
command="echo {{#pre_node_id.number#}}",
working_directory="dir-{{#pre_node_id.number#}}",
workflow_execution_id=workflow_execution_id,
)
node.graph_runtime_state.variable_pool.add(("pre_node_id", "number"), 42)
sandbox = FakeSandbox(stdout=b"ok\n", stderr=b"")
node.sandbox = sandbox
SandboxManager.register(workflow_execution_id, sandbox)
result = node._run() # pyright: ignore[reportPrivateUsage]
@ -124,18 +145,19 @@ def test_command_node_success_executes_in_sandbox():
assert result.outputs["exit_code"] == 0
assert sandbox.last_execute_command is not None
assert sandbox.last_execute_command[:2] == ["sh", "-lc"]
assert sandbox.last_execute_command[:2] == ["sh", "-c"]
assert "cd dir-42 && echo 42" in sandbox.last_execute_command[2]
def test_command_node_nonzero_exit_code_returns_failed_result():
node = _make_node(command="false")
workflow_execution_id = "test-exec-nonzero"
node = _make_node(command="false", workflow_execution_id=workflow_execution_id)
sandbox = FakeSandbox(
stdout=b"out",
stderr=b"err",
statuses=[CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=2)],
)
node.sandbox = sandbox
SandboxManager.register(workflow_execution_id, sandbox)
result = node._run() # pyright: ignore[reportPrivateUsage]
@ -149,17 +171,29 @@ def test_command_node_timeout_returns_failed_result_and_closes_transports(monkey
monkeypatch.setattr(command_node_module, "COMMAND_NODE_TIMEOUT_SECONDS", 1)
node = _make_node(command="sleep 10")
workflow_execution_id = "test-exec-timeout"
node = _make_node(command="sleep 10", workflow_execution_id=workflow_execution_id)
sandbox = FakeSandbox(
stdout=b"",
stderr=b"",
statuses=[CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)] * 1000,
close_streams=False,
)
node.sandbox = sandbox
SandboxManager.register(workflow_execution_id, sandbox)
result = node._run() # pyright: ignore[reportPrivateUsage]
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert result.error_type == "CommandTimeoutError"
assert "timed out" in result.error
def test_command_node_no_sandbox_returns_failed():
    """A command node run with no sandbox registered for its execution id fails."""
    # No SandboxManager.register call is made for this execution id.
    node = _make_node(command="echo hello", workflow_execution_id="test-exec-no-sandbox")

    outcome = node._run()  # pyright: ignore[reportPrivateUsage]

    assert outcome.status == WorkflowNodeExecutionStatus.FAILED
    assert outcome.error_type == "SandboxNotInitializedError"
    assert "Sandbox not available" in outcome.error