diff --git a/api/core/app/layers/sandbox_layer.py b/api/core/app/layers/sandbox_layer.py index 8385cc6b24..61e1a06e8c 100644 --- a/api/core/app/layers/sandbox_layer.py +++ b/api/core/app/layers/sandbox_layer.py @@ -1,185 +1,95 @@ -""" -Sandbox Layer for managing VirtualEnvironment lifecycle during workflow execution. -""" - -import contextlib import logging from collections.abc import Mapping from typing import Any from core.virtual_environment.__base.virtual_environment import VirtualEnvironment -from core.virtual_environment.factory import SandboxFactory, SandboxType -from core.workflow.enums import NodeType +from core.virtual_environment.sandbox_manager import SandboxManager from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_events.base import GraphEngineEvent -from core.workflow.nodes.base.node import Node logger = logging.getLogger(__name__) class SandboxInitializationError(Exception): - """Raised when sandbox initialization fails.""" - pass class SandboxLayer(GraphEngineLayer): - """ - Manages VirtualEnvironment (sandbox) lifecycle during workflow execution. - - Responsibilities: - - on_graph_start: Initialize the sandbox environment - - on_graph_end: Release the sandbox environment (cleanup) - - Example: - # Using tenant-specific configuration (recommended): - layer = SandboxLayer(tenant_id="tenant-uuid") - - # Using explicit configuration (for testing/override): - layer = SandboxLayer( - sandbox_type=SandboxType.DOCKER, - options={"docker_image": "python:3.11-slim"}, - ) - graph_engine.layer(layer) - - # During workflow execution, access sandbox via: - # layer.sandbox.execute_command(...) - """ - def __init__( self, - tenant_id: str | None = None, - sandbox_type: SandboxType | None = None, + tenant_id: str, options: Mapping[str, Any] | None = None, environments: Mapping[str, str] | None = None, ) -> None: - """ - Initialize the SandboxLayer. - - Args: - tenant_id: Tenant ID to load sandbox configuration from database. - If provided, sandbox_type and options are ignored and - loaded from the tenant's active sandbox provider. - sandbox_type: Type of sandbox to create (default: DOCKER). - Only used if tenant_id is not provided. - options: Sandbox-specific configuration options. - Only used if tenant_id is not provided. - environments: Environment variables to set in the sandbox. - """ super().__init__() self._tenant_id = tenant_id - self._sandbox_type = sandbox_type self._options: Mapping[str, Any] = options or {} self._environments: Mapping[str, str] = environments or {} - self._sandbox: VirtualEnvironment | None = None + self._workflow_execution_id: str | None = None + + def _get_workflow_execution_id(self) -> str: + workflow_execution_id = self.graph_runtime_state.system_variable.workflow_execution_id + if not workflow_execution_id: + raise RuntimeError("workflow_execution_id is not set in system variables") + return workflow_execution_id @property def sandbox(self) -> VirtualEnvironment: - """ - Get the current sandbox instance. - - Returns: - The initialized VirtualEnvironment instance - - Raises: - RuntimeError: If sandbox has not been initialized - """ - if self._sandbox is None: + if self._workflow_execution_id is None: raise RuntimeError("Sandbox not initialized. Ensure on_graph_start() has been called.") - return self._sandbox + sandbox = SandboxManager.get(self._workflow_execution_id) + if sandbox is None: + raise RuntimeError(f"Sandbox not found for workflow_execution_id={self._workflow_execution_id}") + return sandbox def on_graph_start(self) -> None: - """ - Initialize the sandbox when workflow execution starts. + self._workflow_execution_id = self._get_workflow_execution_id() - If tenant_id was provided, uses SandboxProviderService to create - the sandbox with the tenant's active provider configuration. - Otherwise, falls back to explicit sandbox_type/options. - - Raises: - SandboxInitializationError: If sandbox cannot be created - """ try: - if self._tenant_id: - # Use SandboxProviderService to create sandbox based on tenant config - from services.sandbox.sandbox_provider_service import SandboxProviderService + sandbox: VirtualEnvironment + from services.sandbox.sandbox_provider_service import SandboxProviderService - logger.info("Initializing sandbox for tenant_id=%s", self._tenant_id) - self._sandbox = SandboxProviderService.create_sandbox( - tenant_id=self._tenant_id, - environments=self._environments, - ) - else: - # Fallback to explicit configuration (backward compatibility) - sandbox_type = self._sandbox_type or SandboxType.DOCKER - logger.info("Initializing sandbox, sandbox_type=%s", sandbox_type) - # Use a placeholder tenant_id for backward compatibility when tenant_id is not provided - effective_tenant_id = self._tenant_id or "default" - self._sandbox = SandboxFactory.create( - tenant_id=effective_tenant_id, - sandbox_type=sandbox_type, - options=self._options, - environments=self._environments, - ) + logger.info("Initializing sandbox for tenant_id=%s", self._tenant_id) + sandbox = SandboxProviderService.create_sandbox( + tenant_id=self._tenant_id, + environments=self._environments, + ) + SandboxManager.register(self._workflow_execution_id, sandbox) logger.info( - "Sandbox initialized, sandbox_id=%s, sandbox_arch=%s", - self._sandbox.metadata.id, - self._sandbox.metadata.arch, + "Sandbox initialized, workflow_execution_id=%s, sandbox_id=%s, sandbox_arch=%s", + self._workflow_execution_id, + sandbox.metadata.id, + sandbox.metadata.arch, ) except Exception as e: logger.exception("Failed to initialize sandbox") raise SandboxInitializationError(f"Failed to initialize sandbox: {e}") from e def on_event(self, event: GraphEngineEvent) -> None: - """ - Handle graph engine events. - - Currently a no-op, but can be extended for sandbox monitoring/health checks. - """ pass - def on_node_run_start(self, node: Node[Any]) -> None: - """Attach sandbox handle to CommandNode instances.""" - if node.node_type is not NodeType.COMMAND: - return - - try: - # FIXME: type: ignore[attr-defined] - node.sandbox = self.sandbox # type: ignore[attr-defined] - except Exception: - logger.exception("Failed to attach sandbox to node") - - def on_node_run_end(self, node: Node[Any], error: Exception | None) -> None: - _ = error - if node.node_type is not NodeType.COMMAND: - return - - with contextlib.suppress(Exception): - # FIXME: type: ignore[attr-defined] - node.sandbox = None # type: ignore[attr-defined] - def on_graph_end(self, error: Exception | None) -> None: - """ - Release the sandbox when workflow execution ends. - - This method is idempotent and will not raise exceptions on cleanup failure. - - Args: - error: The exception that caused execution to fail, or None if successful - """ - if self._sandbox is None: - logger.debug("No sandbox to release") + if self._workflow_execution_id is None: + logger.debug("No workflow_execution_id set, nothing to release") return - sandbox_id = self._sandbox.metadata.id - logger.info("Releasing sandbox, sandbox_id=%s", sandbox_id) + sandbox = SandboxManager.unregister(self._workflow_execution_id) + if sandbox is None: + logger.debug("No sandbox to release for workflow_execution_id=%s", self._workflow_execution_id) + return + + sandbox_id = sandbox.metadata.id + logger.info( + "Releasing sandbox, workflow_execution_id=%s, sandbox_id=%s", + self._workflow_execution_id, + sandbox_id, + ) try: - self._sandbox.release_environment() + sandbox.release_environment() logger.info("Sandbox released, sandbox_id=%s", sandbox_id) except Exception: - # Log but don't raise - cleanup failures should not break workflow completion logger.exception("Failed to release sandbox, sandbox_id=%s", sandbox_id) finally: - self._sandbox = None + self._workflow_execution_id = None diff --git a/api/core/virtual_environment/sandbox_manager.py b/api/core/virtual_environment/sandbox_manager.py new file mode 100644 index 0000000000..b29f4e25c9 --- /dev/null +++ b/api/core/virtual_environment/sandbox_manager.py @@ -0,0 +1,63 @@ +import logging +import threading +from typing import Final + +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment + +logger = logging.getLogger(__name__) + + +class SandboxManager: + _lock: Final[threading.Lock] = threading.Lock() + _sandboxes: dict[str, VirtualEnvironment] = {} + + @classmethod + def register(cls, workflow_execution_id: str, sandbox: VirtualEnvironment) -> None: + if not workflow_execution_id: + raise ValueError("workflow_execution_id cannot be empty") + + with cls._lock: + if workflow_execution_id in cls._sandboxes: + raise RuntimeError( + f"Sandbox already registered for workflow_execution_id={workflow_execution_id}. " + "Call unregister() first if you need to replace it." + ) + cls._sandboxes[workflow_execution_id] = sandbox + logger.debug( + "Registered sandbox for workflow_execution_id=%s, sandbox_id=%s", + workflow_execution_id, + sandbox.metadata.id, + ) + + @classmethod + def get(cls, workflow_execution_id: str) -> VirtualEnvironment | None: + with cls._lock: + return cls._sandboxes.get(workflow_execution_id) + + @classmethod + def unregister(cls, workflow_execution_id: str) -> VirtualEnvironment | None: + with cls._lock: + sandbox = cls._sandboxes.pop(workflow_execution_id, None) + if sandbox: + logger.debug( + "Unregistered sandbox for workflow_execution_id=%s, sandbox_id=%s", + workflow_execution_id, + sandbox.metadata.id, + ) + return sandbox + + @classmethod + def has(cls, workflow_execution_id: str) -> bool: + with cls._lock: + return workflow_execution_id in cls._sandboxes + + @classmethod + def clear(cls) -> None: + with cls._lock: + cls._sandboxes.clear() + logger.debug("Cleared all registered sandboxes") + + @classmethod + def count(cls) -> int: + with cls._lock: + return len(cls._sandboxes) diff --git a/api/core/workflow/nodes/command/node.py b/api/core/workflow/nodes/command/node.py index cecd1549aa..be603cc6bd 100644 --- a/api/core/workflow/nodes/command/node.py +++ b/api/core/workflow/nodes/command/node.py @@ -6,6 +6,7 @@ from typing import Any from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError from core.virtual_environment.__base.virtual_environment import VirtualEnvironment +from core.virtual_environment.sandbox_manager import SandboxManager from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base import variable_template_parser @@ -21,14 +22,14 @@ COMMAND_NODE_TIMEOUT_SECONDS = 60 class CommandNode(Node[CommandNodeData]): - # FIXME: This is a temporary solution for sandbox injection from SandboxLayer. - # The sandbox is dynamically attached by SandboxLayer.on_node_run_start() before - # node execution and cleared by on_node_run_end(). A cleaner approach would be - # to pass sandbox through GraphRuntimeState or use a proper dependency injection pattern. - sandbox: VirtualEnvironment | None = None - node_type = NodeType.COMMAND + def _get_sandbox(self) -> VirtualEnvironment | None: + workflow_execution_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id + if not workflow_execution_id: + return None + return SandboxManager.get(workflow_execution_id) + def _render_template(self, template: str) -> str: parser = VariableTemplateParser(template=template) selectors = parser.extract_variable_selectors() @@ -57,7 +58,8 @@ class CommandNode(Node[CommandNodeData]): return "1" def _run(self) -> NodeRunResult: - if not isinstance(self.sandbox, VirtualEnvironment): + sandbox = self._get_sandbox() + if sandbox is None: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error="Sandbox not available for CommandNode.", @@ -77,15 +79,15 @@ class CommandNode(Node[CommandNodeData]): ) timeout = COMMAND_NODE_TIMEOUT_SECONDS if COMMAND_NODE_TIMEOUT_SECONDS > 0 else None - connection_handle = self.sandbox.establish_connection() + connection_handle = sandbox.establish_connection() try: - # FIXME: VirtualEnvironment.run_command lacks native cwd support. + # TODO: VirtualEnvironment.run_command lacks native cwd support. # Once the interface adds a `cwd` parameter, remove this shell hack # and pass working_directory directly to run_command. if working_directory: check_cmd = ["test", "-d", working_directory] - check_future = self.sandbox.run_command(connection_handle, check_cmd) + check_future = sandbox.run_command(connection_handle, check_cmd) check_result = check_future.result(timeout=timeout) if check_result.exit_code != 0: @@ -99,7 +101,7 @@ class CommandNode(Node[CommandNodeData]): else: command = shlex.split(raw_command) - future = self.sandbox.run_command(connection_handle, command) + future = sandbox.run_command(connection_handle, command) result = future.result(timeout=timeout) outputs: dict[str, Any] = { @@ -149,7 +151,7 @@ class CommandNode(Node[CommandNodeData]): ) finally: with contextlib.suppress(Exception): - self.sandbox.release_connection(connection_handle) + sandbox.release_connection(connection_handle) @classmethod def _extract_variable_selector_to_variable_mapping( diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 776866999c..a8a54ec3f6 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -697,11 +697,17 @@ class WorkflowService: else: enclosing_node_id = None - # FIXME: Consolidate runtime config checking into a unified location. + # TODO: Consolidate runtime config checking into a unified location. runtime = draft_workflow.features_dict.get("runtime") sandbox = None + single_step_execution_id: str | None = None if isinstance(runtime, dict) and runtime.get("enabled"): sandbox = SandboxProviderService.create_sandbox(tenant_id=draft_workflow.tenant_id) + single_step_execution_id = f"single-step-{uuid.uuid4()}" + from core.virtual_environment.sandbox_manager import SandboxManager + + SandboxManager.register(single_step_execution_id, sandbox) + variable_pool.system_variables.workflow_execution_id = single_step_execution_id try: node, generator = WorkflowEntry.single_step_run( @@ -713,10 +719,6 @@ class WorkflowService: variable_loader=variable_loader, ) - # FIXME: Use a proper dependency injection pattern for sandbox. - if sandbox: - node.sandbox = sandbox # type: ignore[attr-defined] - # Run draft workflow node start_at = time.perf_counter() node_execution = self._handle_single_step_result( @@ -725,12 +727,15 @@ class WorkflowService: node_id=node_id, ) finally: - # Release sandbox after node execution - if sandbox: - try: - sandbox.release_environment() - except Exception: - logger.exception("Failed to release sandbox") + if single_step_execution_id: + from core.virtual_environment.sandbox_manager import SandboxManager + + sandbox = SandboxManager.unregister(single_step_execution_id) + if sandbox: + try: + sandbox.release_environment() + except Exception: + logger.exception("Failed to release sandbox") # Set workflow_id on the NodeExecution node_execution.workflow_id = draft_workflow.id diff --git a/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py b/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py index a6518d45a9..4358a8d4d6 100644 --- a/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py @@ -1,11 +1,3 @@ -""" -Unit tests for the SandboxLayer. - -This module tests the SandboxLayer lifecycle management including initialization, -event handling, and cleanup of VirtualEnvironment instances. -""" - -from pathlib import Path from unittest.mock import MagicMock, patch import pytest @@ -13,7 +5,7 @@ import pytest from core.app.layers.sandbox_layer import SandboxInitializationError, SandboxLayer from core.virtual_environment.__base.entities import Arch from core.virtual_environment.__base.virtual_environment import VirtualEnvironment -from core.virtual_environment.factory import SandboxFactory, SandboxType +from core.virtual_environment.sandbox_manager import SandboxManager from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError from core.workflow.graph_events.graph import ( GraphRunFailedEvent, @@ -23,16 +15,12 @@ from core.workflow.graph_events.graph import ( class MockMetadata: - """Mock metadata for testing.""" - def __init__(self, sandbox_id: str = "test-sandbox-id", arch: Arch = Arch.AMD64): self.id = sandbox_id self.arch = arch class MockVirtualEnvironment: - """Mock VirtualEnvironment for testing.""" - def __init__(self, sandbox_id: str = "test-sandbox-id"): self.metadata = MockMetadata(sandbox_id=sandbox_id) self._released = False @@ -41,33 +29,46 @@ class MockVirtualEnvironment: self._released = True +class MockSystemVariableView: + def __init__(self, workflow_execution_id: str | None = "test-workflow-exec-id"): + self._workflow_execution_id = workflow_execution_id + + @property + def workflow_execution_id(self) -> str | None: + return self._workflow_execution_id + + +class MockReadOnlyGraphRuntimeStateWrapper: + def __init__(self, workflow_execution_id: str | None = "test-workflow-exec-id"): + self._system_variable = MockSystemVariableView(workflow_execution_id) + + @property + def system_variable(self) -> MockSystemVariableView: + return self._system_variable + + +@pytest.fixture(autouse=True) +def clean_sandbox_manager(): + SandboxManager.clear() + yield + SandboxManager.clear() + + class TestSandboxLayer: - """Unit tests for SandboxLayer.""" - - def test_init_with_default_parameters(self): - """Test SandboxLayer initialization with default parameters.""" - layer = SandboxLayer() - - assert layer._sandbox_type is None # pyright: ignore[reportPrivateUsage] - assert layer._options == {} # pyright: ignore[reportPrivateUsage] - assert layer._environments == {} # pyright: ignore[reportPrivateUsage] - assert layer._sandbox is None # pyright: ignore[reportPrivateUsage] - - def test_init_with_custom_parameters(self): - """Test SandboxLayer initialization with custom parameters.""" + def test_init_with_parameters(self): layer = SandboxLayer( - sandbox_type=SandboxType.LOCAL, + tenant_id="test-tenant", options={"base_working_path": "/tmp/sandbox"}, environments={"PYTHONUNBUFFERED": "1"}, ) - assert layer._sandbox_type == SandboxType.LOCAL # pyright: ignore[reportPrivateUsage] + assert layer._tenant_id == "test-tenant" # pyright: ignore[reportPrivateUsage] assert layer._options == {"base_working_path": "/tmp/sandbox"} # pyright: ignore[reportPrivateUsage] assert layer._environments == {"PYTHONUNBUFFERED": "1"} # pyright: ignore[reportPrivateUsage] + assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage] def test_sandbox_property_raises_when_not_initialized(self): - """Test that accessing sandbox property raises error before initialization.""" - layer = SandboxLayer() + layer = SandboxLayer(tenant_id="test-tenant") with pytest.raises(RuntimeError) as exc_info: _ = layer.sandbox @@ -75,170 +76,213 @@ class TestSandboxLayer: assert "Sandbox not initialized" in str(exc_info.value) def test_sandbox_property_returns_sandbox_after_initialization(self): - """Test that sandbox property returns the sandbox after on_graph_start.""" - layer = SandboxLayer() + layer = SandboxLayer(tenant_id="test-tenant") mock_sandbox = MockVirtualEnvironment() + mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-id") + layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] - with patch.object(SandboxFactory, "create", return_value=mock_sandbox): + with patch( + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox", + return_value=mock_sandbox, + ): layer.on_graph_start() assert layer.sandbox is mock_sandbox - def test_on_graph_start_creates_sandbox(self): - """Test that on_graph_start creates a sandbox via factory.""" + def test_on_graph_start_creates_sandbox_and_registers_with_manager(self): layer = SandboxLayer( - sandbox_type=SandboxType.DOCKER, - options={"docker_image": "python:3.11"}, + tenant_id="test-tenant-123", environments={"PATH": "/usr/bin"}, ) mock_sandbox = MockVirtualEnvironment() + mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-123") + layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] - with patch.object(SandboxFactory, "create", return_value=mock_sandbox) as mock_create: + with patch( + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox", + return_value=mock_sandbox, + ) as mock_create: layer.on_graph_start() mock_create.assert_called_once_with( - tenant_id="default", - sandbox_type=SandboxType.DOCKER, - options={"docker_image": "python:3.11"}, + tenant_id="test-tenant-123", environments={"PATH": "/usr/bin"}, ) - def test_on_graph_start_raises_sandbox_initialization_error_on_failure(self): - """Test that on_graph_start raises SandboxInitializationError on factory failure.""" - layer = SandboxLayer(sandbox_type=SandboxType.DOCKER) + assert SandboxManager.get("test-exec-123") is mock_sandbox - with patch.object(SandboxFactory, "create", side_effect=Exception("Docker not available")): + def test_on_graph_start_raises_sandbox_initialization_error_on_failure(self): + layer = SandboxLayer(tenant_id="test-tenant") + mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-id") + layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] + + with patch( + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox", + side_effect=Exception("Sandbox provider not available"), + ): with pytest.raises(SandboxInitializationError) as exc_info: layer.on_graph_start() assert "Failed to initialize sandbox" in str(exc_info.value) - assert "Docker not available" in str(exc_info.value) + assert "Sandbox provider not available" in str(exc_info.value) + + def test_on_graph_start_raises_when_workflow_execution_id_not_set(self): + layer = SandboxLayer(tenant_id="test-tenant") + mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id=None) + layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] + + with pytest.raises(RuntimeError) as exc_info: + layer.on_graph_start() + + assert "workflow_execution_id is not set" in str(exc_info.value) def test_on_event_is_noop(self): - """Test that on_event does nothing (no-op).""" - layer = SandboxLayer() + layer = SandboxLayer(tenant_id="test-tenant") - # These should not raise any exceptions layer.on_event(GraphRunStartedEvent()) layer.on_event(GraphRunSucceededEvent(outputs={})) layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1)) - def test_on_graph_end_releases_sandbox(self): - """Test that on_graph_end releases the sandbox.""" - layer = SandboxLayer() + def test_on_graph_end_releases_sandbox_and_unregisters_from_manager(self): + layer = SandboxLayer(tenant_id="test-tenant") mock_sandbox = MagicMock(spec=VirtualEnvironment) mock_sandbox.metadata = MockMetadata() + workflow_execution_id = "test-exec-456" + mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id) + layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] - with patch.object(SandboxFactory, "create", return_value=mock_sandbox): + with patch( + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox", + return_value=mock_sandbox, + ): layer.on_graph_start() + assert SandboxManager.has(workflow_execution_id) + layer.on_graph_end(error=None) mock_sandbox.release_environment.assert_called_once() - assert layer._sandbox is None # pyright: ignore[reportPrivateUsage] + assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage] + assert not SandboxManager.has(workflow_execution_id) def test_on_graph_end_releases_sandbox_even_on_error(self): - """Test that on_graph_end releases sandbox even when workflow had an error.""" - layer = SandboxLayer() + layer = SandboxLayer(tenant_id="test-tenant") mock_sandbox = MagicMock(spec=VirtualEnvironment) mock_sandbox.metadata = MockMetadata() + workflow_execution_id = "test-exec-789" + mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id) + layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] - with patch.object(SandboxFactory, "create", return_value=mock_sandbox): + with patch( + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox", + return_value=mock_sandbox, + ): layer.on_graph_start() layer.on_graph_end(error=Exception("Workflow failed")) mock_sandbox.release_environment.assert_called_once() - assert layer._sandbox is None # pyright: ignore[reportPrivateUsage] + assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage] + assert not SandboxManager.has(workflow_execution_id) def test_on_graph_end_handles_release_failure_gracefully(self): - """Test that on_graph_end handles release failures without raising.""" - layer = SandboxLayer() + layer = SandboxLayer(tenant_id="test-tenant") mock_sandbox = MagicMock(spec=VirtualEnvironment) mock_sandbox.metadata = MockMetadata() mock_sandbox.release_environment.side_effect = Exception("Container already removed") + workflow_execution_id = "test-exec-fail" + mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id) + layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] - with patch.object(SandboxFactory, "create", return_value=mock_sandbox): + with patch( + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox", + return_value=mock_sandbox, + ): layer.on_graph_start() - # Should not raise exception layer.on_graph_end(error=None) mock_sandbox.release_environment.assert_called_once() - assert layer._sandbox is None # pyright: ignore[reportPrivateUsage] + assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage] def test_on_graph_end_noop_when_sandbox_not_initialized(self): - """Test that on_graph_end is a no-op when sandbox was never initialized.""" - layer = SandboxLayer() + layer = SandboxLayer(tenant_id="test-tenant") - # Should not raise exception layer.on_graph_end(error=None) - assert layer._sandbox is None # pyright: ignore[reportPrivateUsage] + assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage] def test_on_graph_end_is_idempotent(self): - """Test that calling on_graph_end multiple times is safe.""" - layer = SandboxLayer() + layer = SandboxLayer(tenant_id="test-tenant") mock_sandbox = MagicMock(spec=VirtualEnvironment) mock_sandbox.metadata = MockMetadata() + workflow_execution_id = "test-exec-idempotent" + mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id) + layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] - with patch.object(SandboxFactory, "create", return_value=mock_sandbox): + with patch( + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox", + return_value=mock_sandbox, + ): layer.on_graph_start() layer.on_graph_end(error=None) - layer.on_graph_end(error=None) # Second call should be no-op + layer.on_graph_end(error=None) mock_sandbox.release_environment.assert_called_once() def test_layer_inherits_from_graph_engine_layer(self): - """Test that SandboxLayer properly inherits from GraphEngineLayer.""" - layer = SandboxLayer() + layer = SandboxLayer(tenant_id="test-tenant") - # Should have the graph_runtime_state property from base class with pytest.raises(GraphEngineLayerNotInitializedError): _ = layer.graph_runtime_state - # Should have command_channel from base class assert layer.command_channel is None class TestSandboxLayerIntegration: - """Integration tests for SandboxLayer with real LocalVirtualEnvironment.""" + def test_full_lifecycle_with_mocked_provider(self): + layer = SandboxLayer(tenant_id="integration-tenant") + workflow_execution_id = "integration-test-exec" + mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id) + layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] + mock_sandbox = MagicMock(spec=VirtualEnvironment) + mock_sandbox.metadata = MockMetadata(sandbox_id="integration-sandbox") - def test_full_lifecycle_with_local_sandbox(self, tmp_path: Path): - """Test complete lifecycle: init -> start -> end with local sandbox.""" - layer = SandboxLayer( - sandbox_type=SandboxType.LOCAL, - options={"base_working_path": str(tmp_path)}, - ) + with patch( + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox", + return_value=mock_sandbox, + ): + layer.on_graph_start() - # Start - layer.on_graph_start() + assert layer._workflow_execution_id == workflow_execution_id # pyright: ignore[reportPrivateUsage] + assert layer.sandbox is mock_sandbox + assert SandboxManager.get(workflow_execution_id) is mock_sandbox - # Verify sandbox is created - assert layer._sandbox is not None # pyright: ignore[reportPrivateUsage] - sandbox_id = layer.sandbox.metadata.id - assert sandbox_id is not None - - # End layer.on_graph_end(error=None) - # Verify sandbox is released - assert layer._sandbox is None # pyright: ignore[reportPrivateUsage] + assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage] + assert not SandboxManager.has(workflow_execution_id) + mock_sandbox.release_environment.assert_called_once() - def test_lifecycle_with_workflow_error(self, tmp_path: Path): - """Test lifecycle when workflow encounters an error.""" - layer = SandboxLayer( - sandbox_type=SandboxType.LOCAL, - options={"base_working_path": str(tmp_path)}, - ) + def test_lifecycle_with_workflow_error(self): + layer = SandboxLayer(tenant_id="error-tenant") + workflow_execution_id = "integration-error-test" + mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id) + layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment] + mock_sandbox = MagicMock(spec=VirtualEnvironment) + mock_sandbox.metadata = MockMetadata() + + with patch( + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox", + return_value=mock_sandbox, + ): + layer.on_graph_start() - layer.on_graph_start() assert layer.sandbox.metadata.id is not None - # Simulate workflow error layer.on_graph_end(error=Exception("Workflow execution failed")) - # Sandbox should still be cleaned up - # pyright: ignore[reportPrivateUsage] - assert layer._sandbox is None # pyright: ignore[reportPrivateUsage] + assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage] + assert not SandboxManager.has(workflow_execution_id) + mock_sandbox.release_environment.assert_called_once() diff --git a/api/tests/unit_tests/core/virtual_environment/test_sandbox_manager.py b/api/tests/unit_tests/core/virtual_environment/test_sandbox_manager.py new file mode 100644 index 0000000000..512365fd80 --- /dev/null +++ b/api/tests/unit_tests/core/virtual_environment/test_sandbox_manager.py @@ -0,0 +1,153 @@ +import threading +from collections.abc import Mapping +from io import BytesIO +from typing import Any + +import pytest + +from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment +from core.virtual_environment.sandbox_manager import SandboxManager + + +class FakeVirtualEnvironment(VirtualEnvironment): + def __init__(self, sandbox_id: str = "fake-id"): + self._sandbox_id = sandbox_id + super().__init__(tenant_id="test-tenant", options={}, environments={}) + + def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata: + return Metadata(id=self._sandbox_id, arch=Arch.AMD64) + + def upload_file(self, path: str, content: BytesIO) -> None: + raise NotImplementedError + + def download_file(self, path: str) -> BytesIO: + raise NotImplementedError + + def list_files(self, directory_path: str, limit: int) -> list[FileState]: + return [] + + def establish_connection(self) -> ConnectionHandle: + return ConnectionHandle(id="conn") + + def release_connection(self, connection_handle: ConnectionHandle) -> None: + pass + + def release_environment(self) -> None: + pass + + def execute_command( + self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None + ) -> tuple[str, Any, Any, Any]: + raise NotImplementedError + + def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus: + return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=0) + + @classmethod + def validate(cls, options: Mapping[str, Any]) -> None: + pass + + +@pytest.fixture(autouse=True) +def clean_sandbox_manager(): + SandboxManager.clear() + yield + SandboxManager.clear() + + +class TestSandboxManager: + def test_register_and_get(self): + sandbox = FakeVirtualEnvironment("sandbox-1") + + SandboxManager.register("exec-1", sandbox) + result = SandboxManager.get("exec-1") + + assert result is sandbox + + def test_get_returns_none_for_unknown_id(self): + result = SandboxManager.get("unknown-id") + assert result is None + + def test_register_raises_on_empty_workflow_execution_id(self): + sandbox = FakeVirtualEnvironment() + + with pytest.raises(ValueError, match="workflow_execution_id cannot be empty"): + SandboxManager.register("", sandbox) + + def test_register_raises_on_duplicate(self): + sandbox1 = FakeVirtualEnvironment("sandbox-1") + sandbox2 = FakeVirtualEnvironment("sandbox-2") + + SandboxManager.register("exec-dup", sandbox1) + + with pytest.raises(RuntimeError, match="already registered"): + SandboxManager.register("exec-dup", sandbox2) + + def test_unregister_returns_sandbox(self): + sandbox = FakeVirtualEnvironment("sandbox-to-remove") + SandboxManager.register("exec-remove", sandbox) + + result = SandboxManager.unregister("exec-remove") + + assert result is sandbox + assert SandboxManager.get("exec-remove") is None + + def test_unregister_returns_none_for_unknown(self): + result = SandboxManager.unregister("nonexistent") + assert result is None + + def test_has_returns_true_when_registered(self): + sandbox = FakeVirtualEnvironment() + SandboxManager.register("exec-has", sandbox) + + assert SandboxManager.has("exec-has") is True + + def test_has_returns_false_when_not_registered(self): + assert SandboxManager.has("exec-no") is False + + def test_clear_removes_all_sandboxes(self): + sandbox1 = FakeVirtualEnvironment("s1") + sandbox2 = FakeVirtualEnvironment("s2") + SandboxManager.register("exec-1", sandbox1) + SandboxManager.register("exec-2", sandbox2) + + SandboxManager.clear() + + assert SandboxManager.count() == 0 + assert SandboxManager.get("exec-1") is None + assert SandboxManager.get("exec-2") is None + + def test_count_returns_number_of_sandboxes(self): + assert SandboxManager.count() == 0 + + SandboxManager.register("e1", FakeVirtualEnvironment("s1")) + assert SandboxManager.count() == 1 + + SandboxManager.register("e2", FakeVirtualEnvironment("s2")) + assert SandboxManager.count() == 2 + + SandboxManager.unregister("e1") + assert SandboxManager.count() == 1 + + def test_thread_safety(self): + results: list[bool] = [] + errors: list[Exception] = [] + + def register_sandbox(exec_id: str): + try: + sandbox = FakeVirtualEnvironment(f"sandbox-{exec_id}") + SandboxManager.register(exec_id, sandbox) + results.append(True) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=register_sandbox, args=(f"exec-{i}",)) for i in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(errors) == 0 + assert len(results) == 10 + assert SandboxManager.count() == 10 diff --git a/api/tests/unit_tests/core/workflow/nodes/command/test_command_node.py b/api/tests/unit_tests/core/workflow/nodes/command/test_command_node.py index 02de7f8c81..6e0d2350c7 100644 --- a/api/tests/unit_tests/core/workflow/nodes/command/test_command_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/command/test_command_node.py @@ -1,11 +1,15 @@ import time +from collections.abc import Mapping from io import BytesIO from typing import Any +import pytest + from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata from core.virtual_environment.__base.virtual_environment import VirtualEnvironment from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser from core.virtual_environment.channel.transport import NopTransportWriteCloser +from core.virtual_environment.sandbox_manager import SandboxManager from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.nodes.command.node import CommandNode @@ -30,7 +34,7 @@ class FakeSandbox(VirtualEnvironment): self.released_connections: list[str] = [] super().__init__(tenant_id="test-tenant", options={}, environments={}) - def _construct_environment(self, options, environments): # type: ignore[override] + def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata: return Metadata(id="fake", arch=Arch.ARM64) def upload_file(self, path: str, content: BytesIO) -> None: @@ -51,7 +55,9 @@ class FakeSandbox(VirtualEnvironment): def release_environment(self) -> None: return - def execute_command(self, connection_handle: ConnectionHandle, command: list[str], environments=None): # type: ignore[override] + def execute_command( + self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None + ) -> tuple[str, NopTransportWriteCloser, QueueTransportReadCloser, QueueTransportReadCloser]: _ = connection_handle _ = environments self.last_execute_command = command @@ -76,12 +82,22 @@ class FakeSandbox(VirtualEnvironment): return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=0) @classmethod - def validate(cls, options: Any) -> None: + def validate(cls, options: Mapping[str, Any]) -> None: pass -def _make_node(*, command: str, working_directory: str = "") -> CommandNode: - variable_pool = VariablePool(system_variables=SystemVariable.empty(), user_inputs={}) +@pytest.fixture(autouse=True) +def clean_sandbox_manager(): + SandboxManager.clear() + yield + SandboxManager.clear() + + +def _make_node( + *, command: str, working_directory: str = "", workflow_execution_id: str = "test-workflow-exec-id" +) -> CommandNode: + system_variables = SystemVariable(workflow_execution_id=workflow_execution_id) + variable_pool = VariablePool(system_variables=system_variables, user_inputs={}) runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) init_params = GraphInitParams( tenant_id="t", @@ -110,11 +126,16 @@ def _make_node(*, command: str, working_directory: str = "") -> CommandNode: def test_command_node_success_executes_in_sandbox(): - node = _make_node(command="echo {{#pre_node_id.number#}}", working_directory="dir-{{#pre_node_id.number#}}") + workflow_execution_id = "test-exec-success" + node = _make_node( + command="echo {{#pre_node_id.number#}}", + working_directory="dir-{{#pre_node_id.number#}}", + workflow_execution_id=workflow_execution_id, + ) node.graph_runtime_state.variable_pool.add(("pre_node_id", "number"), 42) sandbox = FakeSandbox(stdout=b"ok\n", stderr=b"") - node.sandbox = sandbox + SandboxManager.register(workflow_execution_id, sandbox) result = node._run() # pyright: ignore[reportPrivateUsage] @@ -124,18 +145,19 @@ def test_command_node_success_executes_in_sandbox(): assert result.outputs["exit_code"] == 0 assert sandbox.last_execute_command is not None - assert sandbox.last_execute_command[:2] == ["sh", "-lc"] + assert sandbox.last_execute_command[:2] == ["sh", "-c"] assert "cd dir-42 && echo 42" in sandbox.last_execute_command[2] def test_command_node_nonzero_exit_code_returns_failed_result(): - node = _make_node(command="false") + workflow_execution_id = "test-exec-nonzero" + node = _make_node(command="false", workflow_execution_id=workflow_execution_id) sandbox = FakeSandbox( stdout=b"out", stderr=b"err", statuses=[CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=2)], ) - node.sandbox = sandbox + SandboxManager.register(workflow_execution_id, sandbox) result = node._run() # pyright: ignore[reportPrivateUsage] @@ -149,17 +171,29 @@ def test_command_node_timeout_returns_failed_result_and_closes_transports(monkey monkeypatch.setattr(command_node_module, "COMMAND_NODE_TIMEOUT_SECONDS", 1) - node = _make_node(command="sleep 10") + workflow_execution_id = "test-exec-timeout" + node = _make_node(command="sleep 10", workflow_execution_id=workflow_execution_id) sandbox = FakeSandbox( stdout=b"", stderr=b"", statuses=[CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)] * 1000, close_streams=False, ) - node.sandbox = sandbox + SandboxManager.register(workflow_execution_id, sandbox) result = node._run() # pyright: ignore[reportPrivateUsage] assert result.status == WorkflowNodeExecutionStatus.FAILED assert result.error_type == "CommandTimeoutError" assert "timed out" in result.error + + +def test_command_node_no_sandbox_returns_failed(): + workflow_execution_id = "test-exec-no-sandbox" + node = _make_node(command="echo hello", workflow_execution_id=workflow_execution_id) + + result = node._run() # pyright: ignore[reportPrivateUsage] + + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error_type == "SandboxNotInitializedError" + assert "Sandbox not available" in result.error