mirror of
https://github.com/langgenius/dify.git
synced 2026-01-14 06:07:33 +08:00
feat(sandbox-layer): refactor sandbox management and integrate with SandboxManager
- Simplified the SandboxLayer initialization by removing unused parameters and consolidating sandbox creation logic. - Integrated SandboxManager for better lifecycle management of sandboxes during workflow execution. - Updated error handling to ensure proper initialization and cleanup of sandboxes. - Enhanced CommandNode to retrieve sandboxes from SandboxManager, improving sandbox availability checks. - Added unit tests to validate the new sandbox management approach and ensure robust error handling.
This commit is contained in:
parent
b09a831d15
commit
0da4d64d38
@ -1,185 +1,95 @@
|
||||
"""
|
||||
Sandbox Layer for managing VirtualEnvironment lifecycle during workflow execution.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.virtual_environment.factory import SandboxFactory, SandboxType
|
||||
from core.workflow.enums import NodeType
|
||||
from core.virtual_environment.sandbox_manager import SandboxManager
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events.base import GraphEngineEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SandboxInitializationError(Exception):
|
||||
"""Raised when sandbox initialization fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SandboxLayer(GraphEngineLayer):
|
||||
"""
|
||||
Manages VirtualEnvironment (sandbox) lifecycle during workflow execution.
|
||||
|
||||
Responsibilities:
|
||||
- on_graph_start: Initialize the sandbox environment
|
||||
- on_graph_end: Release the sandbox environment (cleanup)
|
||||
|
||||
Example:
|
||||
# Using tenant-specific configuration (recommended):
|
||||
layer = SandboxLayer(tenant_id="tenant-uuid")
|
||||
|
||||
# Using explicit configuration (for testing/override):
|
||||
layer = SandboxLayer(
|
||||
sandbox_type=SandboxType.DOCKER,
|
||||
options={"docker_image": "python:3.11-slim"},
|
||||
)
|
||||
graph_engine.layer(layer)
|
||||
|
||||
# During workflow execution, access sandbox via:
|
||||
# layer.sandbox.execute_command(...)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str | None = None,
|
||||
sandbox_type: SandboxType | None = None,
|
||||
tenant_id: str,
|
||||
options: Mapping[str, Any] | None = None,
|
||||
environments: Mapping[str, str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the SandboxLayer.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID to load sandbox configuration from database.
|
||||
If provided, sandbox_type and options are ignored and
|
||||
loaded from the tenant's active sandbox provider.
|
||||
sandbox_type: Type of sandbox to create (default: DOCKER).
|
||||
Only used if tenant_id is not provided.
|
||||
options: Sandbox-specific configuration options.
|
||||
Only used if tenant_id is not provided.
|
||||
environments: Environment variables to set in the sandbox.
|
||||
"""
|
||||
super().__init__()
|
||||
self._tenant_id = tenant_id
|
||||
self._sandbox_type = sandbox_type
|
||||
self._options: Mapping[str, Any] = options or {}
|
||||
self._environments: Mapping[str, str] = environments or {}
|
||||
self._sandbox: VirtualEnvironment | None = None
|
||||
self._workflow_execution_id: str | None = None
|
||||
|
||||
def _get_workflow_execution_id(self) -> str:
|
||||
workflow_execution_id = self.graph_runtime_state.system_variable.workflow_execution_id
|
||||
if not workflow_execution_id:
|
||||
raise RuntimeError("workflow_execution_id is not set in system variables")
|
||||
return workflow_execution_id
|
||||
|
||||
@property
|
||||
def sandbox(self) -> VirtualEnvironment:
|
||||
"""
|
||||
Get the current sandbox instance.
|
||||
|
||||
Returns:
|
||||
The initialized VirtualEnvironment instance
|
||||
|
||||
Raises:
|
||||
RuntimeError: If sandbox has not been initialized
|
||||
"""
|
||||
if self._sandbox is None:
|
||||
if self._workflow_execution_id is None:
|
||||
raise RuntimeError("Sandbox not initialized. Ensure on_graph_start() has been called.")
|
||||
return self._sandbox
|
||||
sandbox = SandboxManager.get(self._workflow_execution_id)
|
||||
if sandbox is None:
|
||||
raise RuntimeError(f"Sandbox not found for workflow_execution_id={self._workflow_execution_id}")
|
||||
return sandbox
|
||||
|
||||
def on_graph_start(self) -> None:
|
||||
"""
|
||||
Initialize the sandbox when workflow execution starts.
|
||||
self._workflow_execution_id = self._get_workflow_execution_id()
|
||||
|
||||
If tenant_id was provided, uses SandboxProviderService to create
|
||||
the sandbox with the tenant's active provider configuration.
|
||||
Otherwise, falls back to explicit sandbox_type/options.
|
||||
|
||||
Raises:
|
||||
SandboxInitializationError: If sandbox cannot be created
|
||||
"""
|
||||
try:
|
||||
if self._tenant_id:
|
||||
# Use SandboxProviderService to create sandbox based on tenant config
|
||||
from services.sandbox.sandbox_provider_service import SandboxProviderService
|
||||
sandbox: VirtualEnvironment
|
||||
from services.sandbox.sandbox_provider_service import SandboxProviderService
|
||||
|
||||
logger.info("Initializing sandbox for tenant_id=%s", self._tenant_id)
|
||||
self._sandbox = SandboxProviderService.create_sandbox(
|
||||
tenant_id=self._tenant_id,
|
||||
environments=self._environments,
|
||||
)
|
||||
else:
|
||||
# Fallback to explicit configuration (backward compatibility)
|
||||
sandbox_type = self._sandbox_type or SandboxType.DOCKER
|
||||
logger.info("Initializing sandbox, sandbox_type=%s", sandbox_type)
|
||||
# Use a placeholder tenant_id for backward compatibility when tenant_id is not provided
|
||||
effective_tenant_id = self._tenant_id or "default"
|
||||
self._sandbox = SandboxFactory.create(
|
||||
tenant_id=effective_tenant_id,
|
||||
sandbox_type=sandbox_type,
|
||||
options=self._options,
|
||||
environments=self._environments,
|
||||
)
|
||||
logger.info("Initializing sandbox for tenant_id=%s", self._tenant_id)
|
||||
sandbox = SandboxProviderService.create_sandbox(
|
||||
tenant_id=self._tenant_id,
|
||||
environments=self._environments,
|
||||
)
|
||||
|
||||
SandboxManager.register(self._workflow_execution_id, sandbox)
|
||||
logger.info(
|
||||
"Sandbox initialized, sandbox_id=%s, sandbox_arch=%s",
|
||||
self._sandbox.metadata.id,
|
||||
self._sandbox.metadata.arch,
|
||||
"Sandbox initialized, workflow_execution_id=%s, sandbox_id=%s, sandbox_arch=%s",
|
||||
self._workflow_execution_id,
|
||||
sandbox.metadata.id,
|
||||
sandbox.metadata.arch,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to initialize sandbox")
|
||||
raise SandboxInitializationError(f"Failed to initialize sandbox: {e}") from e
|
||||
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Handle graph engine events.
|
||||
|
||||
Currently a no-op, but can be extended for sandbox monitoring/health checks.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_node_run_start(self, node: Node[Any]) -> None:
|
||||
"""Attach sandbox handle to CommandNode instances."""
|
||||
if node.node_type is not NodeType.COMMAND:
|
||||
return
|
||||
|
||||
try:
|
||||
# FIXME: type: ignore[attr-defined]
|
||||
node.sandbox = self.sandbox # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
logger.exception("Failed to attach sandbox to node")
|
||||
|
||||
def on_node_run_end(self, node: Node[Any], error: Exception | None) -> None:
|
||||
_ = error
|
||||
if node.node_type is not NodeType.COMMAND:
|
||||
return
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
# FIXME: type: ignore[attr-defined]
|
||||
node.sandbox = None # type: ignore[attr-defined]
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""
|
||||
Release the sandbox when workflow execution ends.
|
||||
|
||||
This method is idempotent and will not raise exceptions on cleanup failure.
|
||||
|
||||
Args:
|
||||
error: The exception that caused execution to fail, or None if successful
|
||||
"""
|
||||
if self._sandbox is None:
|
||||
logger.debug("No sandbox to release")
|
||||
if self._workflow_execution_id is None:
|
||||
logger.debug("No workflow_execution_id set, nothing to release")
|
||||
return
|
||||
|
||||
sandbox_id = self._sandbox.metadata.id
|
||||
logger.info("Releasing sandbox, sandbox_id=%s", sandbox_id)
|
||||
sandbox = SandboxManager.unregister(self._workflow_execution_id)
|
||||
if sandbox is None:
|
||||
logger.debug("No sandbox to release for workflow_execution_id=%s", self._workflow_execution_id)
|
||||
return
|
||||
|
||||
sandbox_id = sandbox.metadata.id
|
||||
logger.info(
|
||||
"Releasing sandbox, workflow_execution_id=%s, sandbox_id=%s",
|
||||
self._workflow_execution_id,
|
||||
sandbox_id,
|
||||
)
|
||||
|
||||
try:
|
||||
self._sandbox.release_environment()
|
||||
sandbox.release_environment()
|
||||
logger.info("Sandbox released, sandbox_id=%s", sandbox_id)
|
||||
except Exception:
|
||||
# Log but don't raise - cleanup failures should not break workflow completion
|
||||
logger.exception("Failed to release sandbox, sandbox_id=%s", sandbox_id)
|
||||
finally:
|
||||
self._sandbox = None
|
||||
self._workflow_execution_id = None
|
||||
|
||||
63
api/core/virtual_environment/sandbox_manager.py
Normal file
63
api/core/virtual_environment/sandbox_manager.py
Normal file
@ -0,0 +1,63 @@
|
||||
import logging
|
||||
import threading
|
||||
from typing import Final
|
||||
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SandboxManager:
|
||||
_lock: Final[threading.Lock] = threading.Lock()
|
||||
_sandboxes: dict[str, VirtualEnvironment] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, workflow_execution_id: str, sandbox: VirtualEnvironment) -> None:
|
||||
if not workflow_execution_id:
|
||||
raise ValueError("workflow_execution_id cannot be empty")
|
||||
|
||||
with cls._lock:
|
||||
if workflow_execution_id in cls._sandboxes:
|
||||
raise RuntimeError(
|
||||
f"Sandbox already registered for workflow_execution_id={workflow_execution_id}. "
|
||||
"Call unregister() first if you need to replace it."
|
||||
)
|
||||
cls._sandboxes[workflow_execution_id] = sandbox
|
||||
logger.debug(
|
||||
"Registered sandbox for workflow_execution_id=%s, sandbox_id=%s",
|
||||
workflow_execution_id,
|
||||
sandbox.metadata.id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get(cls, workflow_execution_id: str) -> VirtualEnvironment | None:
|
||||
with cls._lock:
|
||||
return cls._sandboxes.get(workflow_execution_id)
|
||||
|
||||
@classmethod
|
||||
def unregister(cls, workflow_execution_id: str) -> VirtualEnvironment | None:
|
||||
with cls._lock:
|
||||
sandbox = cls._sandboxes.pop(workflow_execution_id, None)
|
||||
if sandbox:
|
||||
logger.debug(
|
||||
"Unregistered sandbox for workflow_execution_id=%s, sandbox_id=%s",
|
||||
workflow_execution_id,
|
||||
sandbox.metadata.id,
|
||||
)
|
||||
return sandbox
|
||||
|
||||
@classmethod
|
||||
def has(cls, workflow_execution_id: str) -> bool:
|
||||
with cls._lock:
|
||||
return workflow_execution_id in cls._sandboxes
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
with cls._lock:
|
||||
cls._sandboxes.clear()
|
||||
logger.debug("Cleared all registered sandboxes")
|
||||
|
||||
@classmethod
|
||||
def count(cls) -> int:
|
||||
with cls._lock:
|
||||
return len(cls._sandboxes)
|
||||
@ -6,6 +6,7 @@ from typing import Any
|
||||
|
||||
from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.virtual_environment.sandbox_manager import SandboxManager
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base import variable_template_parser
|
||||
@ -21,14 +22,14 @@ COMMAND_NODE_TIMEOUT_SECONDS = 60
|
||||
|
||||
|
||||
class CommandNode(Node[CommandNodeData]):
|
||||
# FIXME: This is a temporary solution for sandbox injection from SandboxLayer.
|
||||
# The sandbox is dynamically attached by SandboxLayer.on_node_run_start() before
|
||||
# node execution and cleared by on_node_run_end(). A cleaner approach would be
|
||||
# to pass sandbox through GraphRuntimeState or use a proper dependency injection pattern.
|
||||
sandbox: VirtualEnvironment | None = None
|
||||
|
||||
node_type = NodeType.COMMAND
|
||||
|
||||
def _get_sandbox(self) -> VirtualEnvironment | None:
|
||||
workflow_execution_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
|
||||
if not workflow_execution_id:
|
||||
return None
|
||||
return SandboxManager.get(workflow_execution_id)
|
||||
|
||||
def _render_template(self, template: str) -> str:
|
||||
parser = VariableTemplateParser(template=template)
|
||||
selectors = parser.extract_variable_selectors()
|
||||
@ -57,7 +58,8 @@ class CommandNode(Node[CommandNodeData]):
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
if not isinstance(self.sandbox, VirtualEnvironment):
|
||||
sandbox = self._get_sandbox()
|
||||
if sandbox is None:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error="Sandbox not available for CommandNode.",
|
||||
@ -77,15 +79,15 @@ class CommandNode(Node[CommandNodeData]):
|
||||
)
|
||||
|
||||
timeout = COMMAND_NODE_TIMEOUT_SECONDS if COMMAND_NODE_TIMEOUT_SECONDS > 0 else None
|
||||
connection_handle = self.sandbox.establish_connection()
|
||||
connection_handle = sandbox.establish_connection()
|
||||
|
||||
try:
|
||||
# FIXME: VirtualEnvironment.run_command lacks native cwd support.
|
||||
# TODO: VirtualEnvironment.run_command lacks native cwd support.
|
||||
# Once the interface adds a `cwd` parameter, remove this shell hack
|
||||
# and pass working_directory directly to run_command.
|
||||
if working_directory:
|
||||
check_cmd = ["test", "-d", working_directory]
|
||||
check_future = self.sandbox.run_command(connection_handle, check_cmd)
|
||||
check_future = sandbox.run_command(connection_handle, check_cmd)
|
||||
check_result = check_future.result(timeout=timeout)
|
||||
|
||||
if check_result.exit_code != 0:
|
||||
@ -99,7 +101,7 @@ class CommandNode(Node[CommandNodeData]):
|
||||
else:
|
||||
command = shlex.split(raw_command)
|
||||
|
||||
future = self.sandbox.run_command(connection_handle, command)
|
||||
future = sandbox.run_command(connection_handle, command)
|
||||
result = future.result(timeout=timeout)
|
||||
|
||||
outputs: dict[str, Any] = {
|
||||
@ -149,7 +151,7 @@ class CommandNode(Node[CommandNodeData]):
|
||||
)
|
||||
finally:
|
||||
with contextlib.suppress(Exception):
|
||||
self.sandbox.release_connection(connection_handle)
|
||||
sandbox.release_connection(connection_handle)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
|
||||
@ -697,11 +697,17 @@ class WorkflowService:
|
||||
else:
|
||||
enclosing_node_id = None
|
||||
|
||||
# FIXME: Consolidate runtime config checking into a unified location.
|
||||
# TODO: Consolidate runtime config checking into a unified location.
|
||||
runtime = draft_workflow.features_dict.get("runtime")
|
||||
sandbox = None
|
||||
single_step_execution_id: str | None = None
|
||||
if isinstance(runtime, dict) and runtime.get("enabled"):
|
||||
sandbox = SandboxProviderService.create_sandbox(tenant_id=draft_workflow.tenant_id)
|
||||
single_step_execution_id = f"single-step-{uuid.uuid4()}"
|
||||
from core.virtual_environment.sandbox_manager import SandboxManager
|
||||
|
||||
SandboxManager.register(single_step_execution_id, sandbox)
|
||||
variable_pool.system_variables.workflow_execution_id = single_step_execution_id
|
||||
|
||||
try:
|
||||
node, generator = WorkflowEntry.single_step_run(
|
||||
@ -713,10 +719,6 @@ class WorkflowService:
|
||||
variable_loader=variable_loader,
|
||||
)
|
||||
|
||||
# FIXME: Use a proper dependency injection pattern for sandbox.
|
||||
if sandbox:
|
||||
node.sandbox = sandbox # type: ignore[attr-defined]
|
||||
|
||||
# Run draft workflow node
|
||||
start_at = time.perf_counter()
|
||||
node_execution = self._handle_single_step_result(
|
||||
@ -725,12 +727,15 @@ class WorkflowService:
|
||||
node_id=node_id,
|
||||
)
|
||||
finally:
|
||||
# Release sandbox after node execution
|
||||
if sandbox:
|
||||
try:
|
||||
sandbox.release_environment()
|
||||
except Exception:
|
||||
logger.exception("Failed to release sandbox")
|
||||
if single_step_execution_id:
|
||||
from core.virtual_environment.sandbox_manager import SandboxManager
|
||||
|
||||
sandbox = SandboxManager.unregister(single_step_execution_id)
|
||||
if sandbox:
|
||||
try:
|
||||
sandbox.release_environment()
|
||||
except Exception:
|
||||
logger.exception("Failed to release sandbox")
|
||||
|
||||
# Set workflow_id on the NodeExecution
|
||||
node_execution.workflow_id = draft_workflow.id
|
||||
|
||||
@ -1,11 +1,3 @@
|
||||
"""
|
||||
Unit tests for the SandboxLayer.
|
||||
|
||||
This module tests the SandboxLayer lifecycle management including initialization,
|
||||
event handling, and cleanup of VirtualEnvironment instances.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -13,7 +5,7 @@ import pytest
|
||||
from core.app.layers.sandbox_layer import SandboxInitializationError, SandboxLayer
|
||||
from core.virtual_environment.__base.entities import Arch
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.virtual_environment.factory import SandboxFactory, SandboxType
|
||||
from core.virtual_environment.sandbox_manager import SandboxManager
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError
|
||||
from core.workflow.graph_events.graph import (
|
||||
GraphRunFailedEvent,
|
||||
@ -23,16 +15,12 @@ from core.workflow.graph_events.graph import (
|
||||
|
||||
|
||||
class MockMetadata:
|
||||
"""Mock metadata for testing."""
|
||||
|
||||
def __init__(self, sandbox_id: str = "test-sandbox-id", arch: Arch = Arch.AMD64):
|
||||
self.id = sandbox_id
|
||||
self.arch = arch
|
||||
|
||||
|
||||
class MockVirtualEnvironment:
|
||||
"""Mock VirtualEnvironment for testing."""
|
||||
|
||||
def __init__(self, sandbox_id: str = "test-sandbox-id"):
|
||||
self.metadata = MockMetadata(sandbox_id=sandbox_id)
|
||||
self._released = False
|
||||
@ -41,33 +29,46 @@ class MockVirtualEnvironment:
|
||||
self._released = True
|
||||
|
||||
|
||||
class MockSystemVariableView:
|
||||
def __init__(self, workflow_execution_id: str | None = "test-workflow-exec-id"):
|
||||
self._workflow_execution_id = workflow_execution_id
|
||||
|
||||
@property
|
||||
def workflow_execution_id(self) -> str | None:
|
||||
return self._workflow_execution_id
|
||||
|
||||
|
||||
class MockReadOnlyGraphRuntimeStateWrapper:
|
||||
def __init__(self, workflow_execution_id: str | None = "test-workflow-exec-id"):
|
||||
self._system_variable = MockSystemVariableView(workflow_execution_id)
|
||||
|
||||
@property
|
||||
def system_variable(self) -> MockSystemVariableView:
|
||||
return self._system_variable
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clean_sandbox_manager():
|
||||
SandboxManager.clear()
|
||||
yield
|
||||
SandboxManager.clear()
|
||||
|
||||
|
||||
class TestSandboxLayer:
|
||||
"""Unit tests for SandboxLayer."""
|
||||
|
||||
def test_init_with_default_parameters(self):
|
||||
"""Test SandboxLayer initialization with default parameters."""
|
||||
layer = SandboxLayer()
|
||||
|
||||
assert layer._sandbox_type is None # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._options == {} # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._environments == {} # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
def test_init_with_custom_parameters(self):
|
||||
"""Test SandboxLayer initialization with custom parameters."""
|
||||
def test_init_with_parameters(self):
|
||||
layer = SandboxLayer(
|
||||
sandbox_type=SandboxType.LOCAL,
|
||||
tenant_id="test-tenant",
|
||||
options={"base_working_path": "/tmp/sandbox"},
|
||||
environments={"PYTHONUNBUFFERED": "1"},
|
||||
)
|
||||
|
||||
assert layer._sandbox_type == SandboxType.LOCAL # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._tenant_id == "test-tenant" # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._options == {"base_working_path": "/tmp/sandbox"} # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._environments == {"PYTHONUNBUFFERED": "1"} # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
def test_sandbox_property_raises_when_not_initialized(self):
|
||||
"""Test that accessing sandbox property raises error before initialization."""
|
||||
layer = SandboxLayer()
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
_ = layer.sandbox
|
||||
@ -75,170 +76,213 @@ class TestSandboxLayer:
|
||||
assert "Sandbox not initialized" in str(exc_info.value)
|
||||
|
||||
def test_sandbox_property_returns_sandbox_after_initialization(self):
|
||||
"""Test that sandbox property returns the sandbox after on_graph_start."""
|
||||
layer = SandboxLayer()
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
mock_sandbox = MockVirtualEnvironment()
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-id")
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with patch.object(SandboxFactory, "create", return_value=mock_sandbox):
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
assert layer.sandbox is mock_sandbox
|
||||
|
||||
def test_on_graph_start_creates_sandbox(self):
|
||||
"""Test that on_graph_start creates a sandbox via factory."""
|
||||
def test_on_graph_start_creates_sandbox_and_registers_with_manager(self):
|
||||
layer = SandboxLayer(
|
||||
sandbox_type=SandboxType.DOCKER,
|
||||
options={"docker_image": "python:3.11"},
|
||||
tenant_id="test-tenant-123",
|
||||
environments={"PATH": "/usr/bin"},
|
||||
)
|
||||
mock_sandbox = MockVirtualEnvironment()
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-123")
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with patch.object(SandboxFactory, "create", return_value=mock_sandbox) as mock_create:
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
) as mock_create:
|
||||
layer.on_graph_start()
|
||||
|
||||
mock_create.assert_called_once_with(
|
||||
tenant_id="default",
|
||||
sandbox_type=SandboxType.DOCKER,
|
||||
options={"docker_image": "python:3.11"},
|
||||
tenant_id="test-tenant-123",
|
||||
environments={"PATH": "/usr/bin"},
|
||||
)
|
||||
|
||||
def test_on_graph_start_raises_sandbox_initialization_error_on_failure(self):
|
||||
"""Test that on_graph_start raises SandboxInitializationError on factory failure."""
|
||||
layer = SandboxLayer(sandbox_type=SandboxType.DOCKER)
|
||||
assert SandboxManager.get("test-exec-123") is mock_sandbox
|
||||
|
||||
with patch.object(SandboxFactory, "create", side_effect=Exception("Docker not available")):
|
||||
def test_on_graph_start_raises_sandbox_initialization_error_on_failure(self):
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-id")
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
side_effect=Exception("Sandbox provider not available"),
|
||||
):
|
||||
with pytest.raises(SandboxInitializationError) as exc_info:
|
||||
layer.on_graph_start()
|
||||
|
||||
assert "Failed to initialize sandbox" in str(exc_info.value)
|
||||
assert "Docker not available" in str(exc_info.value)
|
||||
assert "Sandbox provider not available" in str(exc_info.value)
|
||||
|
||||
def test_on_graph_start_raises_when_workflow_execution_id_not_set(self):
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id=None)
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
layer.on_graph_start()
|
||||
|
||||
assert "workflow_execution_id is not set" in str(exc_info.value)
|
||||
|
||||
def test_on_event_is_noop(self):
|
||||
"""Test that on_event does nothing (no-op)."""
|
||||
layer = SandboxLayer()
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
|
||||
# These should not raise any exceptions
|
||||
layer.on_event(GraphRunStartedEvent())
|
||||
layer.on_event(GraphRunSucceededEvent(outputs={}))
|
||||
layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1))
|
||||
|
||||
def test_on_graph_end_releases_sandbox(self):
|
||||
"""Test that on_graph_end releases the sandbox."""
|
||||
layer = SandboxLayer()
|
||||
def test_on_graph_end_releases_sandbox_and_unregisters_from_manager(self):
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata()
|
||||
workflow_execution_id = "test-exec-456"
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with patch.object(SandboxFactory, "create", return_value=mock_sandbox):
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
assert SandboxManager.has(workflow_execution_id)
|
||||
|
||||
layer.on_graph_end(error=None)
|
||||
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
|
||||
assert not SandboxManager.has(workflow_execution_id)
|
||||
|
||||
def test_on_graph_end_releases_sandbox_even_on_error(self):
|
||||
"""Test that on_graph_end releases sandbox even when workflow had an error."""
|
||||
layer = SandboxLayer()
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata()
|
||||
workflow_execution_id = "test-exec-789"
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with patch.object(SandboxFactory, "create", return_value=mock_sandbox):
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
layer.on_graph_end(error=Exception("Workflow failed"))
|
||||
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
|
||||
assert not SandboxManager.has(workflow_execution_id)
|
||||
|
||||
def test_on_graph_end_handles_release_failure_gracefully(self):
|
||||
"""Test that on_graph_end handles release failures without raising."""
|
||||
layer = SandboxLayer()
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata()
|
||||
mock_sandbox.release_environment.side_effect = Exception("Container already removed")
|
||||
workflow_execution_id = "test-exec-fail"
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with patch.object(SandboxFactory, "create", return_value=mock_sandbox):
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
# Should not raise exception
|
||||
layer.on_graph_end(error=None)
|
||||
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
def test_on_graph_end_noop_when_sandbox_not_initialized(self):
|
||||
"""Test that on_graph_end is a no-op when sandbox was never initialized."""
|
||||
layer = SandboxLayer()
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
|
||||
# Should not raise exception
|
||||
layer.on_graph_end(error=None)
|
||||
|
||||
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
def test_on_graph_end_is_idempotent(self):
|
||||
"""Test that calling on_graph_end multiple times is safe."""
|
||||
layer = SandboxLayer()
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata()
|
||||
workflow_execution_id = "test-exec-idempotent"
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with patch.object(SandboxFactory, "create", return_value=mock_sandbox):
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
layer.on_graph_end(error=None)
|
||||
layer.on_graph_end(error=None) # Second call should be no-op
|
||||
layer.on_graph_end(error=None)
|
||||
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
|
||||
def test_layer_inherits_from_graph_engine_layer(self):
|
||||
"""Test that SandboxLayer properly inherits from GraphEngineLayer."""
|
||||
layer = SandboxLayer()
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
|
||||
# Should have the graph_runtime_state property from base class
|
||||
with pytest.raises(GraphEngineLayerNotInitializedError):
|
||||
_ = layer.graph_runtime_state
|
||||
|
||||
# Should have command_channel from base class
|
||||
assert layer.command_channel is None
|
||||
|
||||
|
||||
class TestSandboxLayerIntegration:
|
||||
"""Integration tests for SandboxLayer with real LocalVirtualEnvironment."""
|
||||
def test_full_lifecycle_with_mocked_provider(self):
|
||||
layer = SandboxLayer(tenant_id="integration-tenant")
|
||||
workflow_execution_id = "integration-test-exec"
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata(sandbox_id="integration-sandbox")
|
||||
|
||||
def test_full_lifecycle_with_local_sandbox(self, tmp_path: Path):
|
||||
"""Test complete lifecycle: init -> start -> end with local sandbox."""
|
||||
layer = SandboxLayer(
|
||||
sandbox_type=SandboxType.LOCAL,
|
||||
options={"base_working_path": str(tmp_path)},
|
||||
)
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
# Start
|
||||
layer.on_graph_start()
|
||||
assert layer._workflow_execution_id == workflow_execution_id # pyright: ignore[reportPrivateUsage]
|
||||
assert layer.sandbox is mock_sandbox
|
||||
assert SandboxManager.get(workflow_execution_id) is mock_sandbox
|
||||
|
||||
# Verify sandbox is created
|
||||
assert layer._sandbox is not None # pyright: ignore[reportPrivateUsage]
|
||||
sandbox_id = layer.sandbox.metadata.id
|
||||
assert sandbox_id is not None
|
||||
|
||||
# End
|
||||
layer.on_graph_end(error=None)
|
||||
|
||||
# Verify sandbox is released
|
||||
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
|
||||
assert not SandboxManager.has(workflow_execution_id)
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
|
||||
def test_lifecycle_with_workflow_error(self, tmp_path: Path):
|
||||
"""Test lifecycle when workflow encounters an error."""
|
||||
layer = SandboxLayer(
|
||||
sandbox_type=SandboxType.LOCAL,
|
||||
options={"base_working_path": str(tmp_path)},
|
||||
)
|
||||
def test_lifecycle_with_workflow_error(self):
|
||||
layer = SandboxLayer(tenant_id="error-tenant")
|
||||
workflow_execution_id = "integration-error-test"
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata()
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
layer.on_graph_start()
|
||||
assert layer.sandbox.metadata.id is not None
|
||||
|
||||
# Simulate workflow error
|
||||
layer.on_graph_end(error=Exception("Workflow execution failed"))
|
||||
|
||||
# Sandbox should still be cleaned up
|
||||
# pyright: ignore[reportPrivateUsage]
|
||||
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
|
||||
assert not SandboxManager.has(workflow_execution_id)
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
|
||||
@ -0,0 +1,153 @@
|
||||
import threading
|
||||
from collections.abc import Mapping
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.virtual_environment.sandbox_manager import SandboxManager
|
||||
|
||||
|
||||
class FakeVirtualEnvironment(VirtualEnvironment):
|
||||
def __init__(self, sandbox_id: str = "fake-id"):
|
||||
self._sandbox_id = sandbox_id
|
||||
super().__init__(tenant_id="test-tenant", options={}, environments={})
|
||||
|
||||
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
|
||||
return Metadata(id=self._sandbox_id, arch=Arch.AMD64)
|
||||
|
||||
def upload_file(self, path: str, content: BytesIO) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def download_file(self, path: str) -> BytesIO:
|
||||
raise NotImplementedError
|
||||
|
||||
def list_files(self, directory_path: str, limit: int) -> list[FileState]:
|
||||
return []
|
||||
|
||||
def establish_connection(self) -> ConnectionHandle:
|
||||
return ConnectionHandle(id="conn")
|
||||
|
||||
def release_connection(self, connection_handle: ConnectionHandle) -> None:
|
||||
pass
|
||||
|
||||
def release_environment(self) -> None:
|
||||
pass
|
||||
|
||||
def execute_command(
|
||||
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
|
||||
) -> tuple[str, Any, Any, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
|
||||
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=0)
|
||||
|
||||
@classmethod
|
||||
def validate(cls, options: Mapping[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clean_sandbox_manager():
|
||||
SandboxManager.clear()
|
||||
yield
|
||||
SandboxManager.clear()
|
||||
|
||||
|
||||
class TestSandboxManager:
|
||||
def test_register_and_get(self):
|
||||
sandbox = FakeVirtualEnvironment("sandbox-1")
|
||||
|
||||
SandboxManager.register("exec-1", sandbox)
|
||||
result = SandboxManager.get("exec-1")
|
||||
|
||||
assert result is sandbox
|
||||
|
||||
def test_get_returns_none_for_unknown_id(self):
|
||||
result = SandboxManager.get("unknown-id")
|
||||
assert result is None
|
||||
|
||||
def test_register_raises_on_empty_workflow_execution_id(self):
|
||||
sandbox = FakeVirtualEnvironment()
|
||||
|
||||
with pytest.raises(ValueError, match="workflow_execution_id cannot be empty"):
|
||||
SandboxManager.register("", sandbox)
|
||||
|
||||
def test_register_raises_on_duplicate(self):
|
||||
sandbox1 = FakeVirtualEnvironment("sandbox-1")
|
||||
sandbox2 = FakeVirtualEnvironment("sandbox-2")
|
||||
|
||||
SandboxManager.register("exec-dup", sandbox1)
|
||||
|
||||
with pytest.raises(RuntimeError, match="already registered"):
|
||||
SandboxManager.register("exec-dup", sandbox2)
|
||||
|
||||
def test_unregister_returns_sandbox(self):
|
||||
sandbox = FakeVirtualEnvironment("sandbox-to-remove")
|
||||
SandboxManager.register("exec-remove", sandbox)
|
||||
|
||||
result = SandboxManager.unregister("exec-remove")
|
||||
|
||||
assert result is sandbox
|
||||
assert SandboxManager.get("exec-remove") is None
|
||||
|
||||
def test_unregister_returns_none_for_unknown(self):
|
||||
result = SandboxManager.unregister("nonexistent")
|
||||
assert result is None
|
||||
|
||||
def test_has_returns_true_when_registered(self):
|
||||
sandbox = FakeVirtualEnvironment()
|
||||
SandboxManager.register("exec-has", sandbox)
|
||||
|
||||
assert SandboxManager.has("exec-has") is True
|
||||
|
||||
def test_has_returns_false_when_not_registered(self):
|
||||
assert SandboxManager.has("exec-no") is False
|
||||
|
||||
def test_clear_removes_all_sandboxes(self):
|
||||
sandbox1 = FakeVirtualEnvironment("s1")
|
||||
sandbox2 = FakeVirtualEnvironment("s2")
|
||||
SandboxManager.register("exec-1", sandbox1)
|
||||
SandboxManager.register("exec-2", sandbox2)
|
||||
|
||||
SandboxManager.clear()
|
||||
|
||||
assert SandboxManager.count() == 0
|
||||
assert SandboxManager.get("exec-1") is None
|
||||
assert SandboxManager.get("exec-2") is None
|
||||
|
||||
def test_count_returns_number_of_sandboxes(self):
|
||||
assert SandboxManager.count() == 0
|
||||
|
||||
SandboxManager.register("e1", FakeVirtualEnvironment("s1"))
|
||||
assert SandboxManager.count() == 1
|
||||
|
||||
SandboxManager.register("e2", FakeVirtualEnvironment("s2"))
|
||||
assert SandboxManager.count() == 2
|
||||
|
||||
SandboxManager.unregister("e1")
|
||||
assert SandboxManager.count() == 1
|
||||
|
||||
def test_thread_safety(self):
|
||||
results: list[bool] = []
|
||||
errors: list[Exception] = []
|
||||
|
||||
def register_sandbox(exec_id: str):
|
||||
try:
|
||||
sandbox = FakeVirtualEnvironment(f"sandbox-{exec_id}")
|
||||
SandboxManager.register(exec_id, sandbox)
|
||||
results.append(True)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [threading.Thread(target=register_sandbox, args=(f"exec-{i}",)) for i in range(10)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert len(errors) == 0
|
||||
assert len(results) == 10
|
||||
assert SandboxManager.count() == 10
|
||||
@ -1,11 +1,15 @@
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser
|
||||
from core.virtual_environment.channel.transport import NopTransportWriteCloser
|
||||
from core.virtual_environment.sandbox_manager import SandboxManager
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.command.node import CommandNode
|
||||
@ -30,7 +34,7 @@ class FakeSandbox(VirtualEnvironment):
|
||||
self.released_connections: list[str] = []
|
||||
super().__init__(tenant_id="test-tenant", options={}, environments={})
|
||||
|
||||
def _construct_environment(self, options, environments): # type: ignore[override]
|
||||
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
|
||||
return Metadata(id="fake", arch=Arch.ARM64)
|
||||
|
||||
def upload_file(self, path: str, content: BytesIO) -> None:
|
||||
@ -51,7 +55,9 @@ class FakeSandbox(VirtualEnvironment):
|
||||
def release_environment(self) -> None:
|
||||
return
|
||||
|
||||
def execute_command(self, connection_handle: ConnectionHandle, command: list[str], environments=None): # type: ignore[override]
|
||||
def execute_command(
|
||||
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
|
||||
) -> tuple[str, NopTransportWriteCloser, QueueTransportReadCloser, QueueTransportReadCloser]:
|
||||
_ = connection_handle
|
||||
_ = environments
|
||||
self.last_execute_command = command
|
||||
@ -76,12 +82,22 @@ class FakeSandbox(VirtualEnvironment):
|
||||
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=0)
|
||||
|
||||
@classmethod
|
||||
def validate(cls, options: Any) -> None:
|
||||
def validate(cls, options: Mapping[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def _make_node(*, command: str, working_directory: str = "") -> CommandNode:
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.empty(), user_inputs={})
|
||||
@pytest.fixture(autouse=True)
|
||||
def clean_sandbox_manager():
|
||||
SandboxManager.clear()
|
||||
yield
|
||||
SandboxManager.clear()
|
||||
|
||||
|
||||
def _make_node(
|
||||
*, command: str, working_directory: str = "", workflow_execution_id: str = "test-workflow-exec-id"
|
||||
) -> CommandNode:
|
||||
system_variables = SystemVariable(workflow_execution_id=workflow_execution_id)
|
||||
variable_pool = VariablePool(system_variables=system_variables, user_inputs={})
|
||||
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="t",
|
||||
@ -110,11 +126,16 @@ def _make_node(*, command: str, working_directory: str = "") -> CommandNode:
|
||||
|
||||
|
||||
def test_command_node_success_executes_in_sandbox():
|
||||
node = _make_node(command="echo {{#pre_node_id.number#}}", working_directory="dir-{{#pre_node_id.number#}}")
|
||||
workflow_execution_id = "test-exec-success"
|
||||
node = _make_node(
|
||||
command="echo {{#pre_node_id.number#}}",
|
||||
working_directory="dir-{{#pre_node_id.number#}}",
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
)
|
||||
node.graph_runtime_state.variable_pool.add(("pre_node_id", "number"), 42)
|
||||
|
||||
sandbox = FakeSandbox(stdout=b"ok\n", stderr=b"")
|
||||
node.sandbox = sandbox
|
||||
SandboxManager.register(workflow_execution_id, sandbox)
|
||||
|
||||
result = node._run() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
@ -124,18 +145,19 @@ def test_command_node_success_executes_in_sandbox():
|
||||
assert result.outputs["exit_code"] == 0
|
||||
|
||||
assert sandbox.last_execute_command is not None
|
||||
assert sandbox.last_execute_command[:2] == ["sh", "-lc"]
|
||||
assert sandbox.last_execute_command[:2] == ["sh", "-c"]
|
||||
assert "cd dir-42 && echo 42" in sandbox.last_execute_command[2]
|
||||
|
||||
|
||||
def test_command_node_nonzero_exit_code_returns_failed_result():
|
||||
node = _make_node(command="false")
|
||||
workflow_execution_id = "test-exec-nonzero"
|
||||
node = _make_node(command="false", workflow_execution_id=workflow_execution_id)
|
||||
sandbox = FakeSandbox(
|
||||
stdout=b"out",
|
||||
stderr=b"err",
|
||||
statuses=[CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=2)],
|
||||
)
|
||||
node.sandbox = sandbox
|
||||
SandboxManager.register(workflow_execution_id, sandbox)
|
||||
|
||||
result = node._run() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
@ -149,17 +171,29 @@ def test_command_node_timeout_returns_failed_result_and_closes_transports(monkey
|
||||
|
||||
monkeypatch.setattr(command_node_module, "COMMAND_NODE_TIMEOUT_SECONDS", 1)
|
||||
|
||||
node = _make_node(command="sleep 10")
|
||||
workflow_execution_id = "test-exec-timeout"
|
||||
node = _make_node(command="sleep 10", workflow_execution_id=workflow_execution_id)
|
||||
sandbox = FakeSandbox(
|
||||
stdout=b"",
|
||||
stderr=b"",
|
||||
statuses=[CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)] * 1000,
|
||||
close_streams=False,
|
||||
)
|
||||
node.sandbox = sandbox
|
||||
SandboxManager.register(workflow_execution_id, sandbox)
|
||||
|
||||
result = node._run() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert result.error_type == "CommandTimeoutError"
|
||||
assert "timed out" in result.error
|
||||
|
||||
|
||||
def test_command_node_no_sandbox_returns_failed():
|
||||
workflow_execution_id = "test-exec-no-sandbox"
|
||||
node = _make_node(command="echo hello", workflow_execution_id=workflow_execution_id)
|
||||
|
||||
result = node._run() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert result.error_type == "SandboxNotInitializedError"
|
||||
assert "Sandbox not available" in result.error
|
||||
|
||||
Loading…
Reference in New Issue
Block a user