From 201a18d6ba38a419bf4843279ff4ef9003c7b567 Mon Sep 17 00:00:00 2001 From: Harry Date: Mon, 12 Jan 2026 14:20:03 +0800 Subject: [PATCH] refactor(virtual_environment): add cwd parameter to execute_command method across all providers for improved command execution context --- .../__base/virtual_environment.py | 12 ++++++++-- .../providers/daytona_sandbox.py | 12 ++++++++-- .../providers/docker_daemon_sandbox.py | 11 +++++++-- .../providers/e2b_sandbox.py | 14 +++++++---- .../providers/local_without_isolation.py | 15 +++++------- api/core/workflow/nodes/command/node.py | 22 ++---------------- .../test_sandbox_manager.py | 17 +++++++++++--- .../nodes/command/test_command_node.py | 23 +++++++++++++++---- 8 files changed, 79 insertions(+), 47 deletions(-) diff --git a/api/core/virtual_environment/__base/virtual_environment.py b/api/core/virtual_environment/__base/virtual_environment.py index 4cd1e7178a..b74230f8c5 100644 --- a/api/core/virtual_environment/__base/virtual_environment.py +++ b/api/core/virtual_environment/__base/virtual_environment.py @@ -134,7 +134,11 @@ class VirtualEnvironment(ABC): @abstractmethod def execute_command( - self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None + self, + connection_handle: ConnectionHandle, + command: list[str], + environments: Mapping[str, str] | None = None, + cwd: str | None = None, ) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]: """ Execute a command in the virtual environment. @@ -142,6 +146,8 @@ class VirtualEnvironment(ABC): Args: connection_handle (ConnectionHandle): The handle for managing the connection. command (list[str]): The command to execute as a list of strings. + environments (Mapping[str, str] | None): Environment variables for the command. + cwd (str | None): Working directory for the command. If None, uses the provider's default. Returns: tuple[int, TransportWriteCloser, TransportReadCloser, TransportReadCloser] @@ -176,6 +182,7 @@ class VirtualEnvironment(ABC): connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None, + cwd: str | None = None, ) -> CommandFuture: """ Execute a command and return a Future for the result. @@ -187,6 +194,7 @@ class VirtualEnvironment(ABC): connection_handle: The connection handle. command: Command as list of strings. environments: Environment variables. + cwd: Working directory for the command. If None, uses the provider's default. Returns: CommandFuture that can be used to get result with timeout or cancel. @@ -195,7 +203,7 @@ class VirtualEnvironment(ABC): result = env.run_command(handle, ["ls", "-la"]).result(timeout=30) """ pid, stdin_transport, stdout_transport, stderr_transport = self.execute_command( - connection_handle, command, environments + connection_handle, command, environments, cwd ) return CommandFuture( diff --git a/api/core/virtual_environment/providers/daytona_sandbox.py b/api/core/virtual_environment/providers/daytona_sandbox.py index 26993baa15..c19b37a8cb 100644 --- a/api/core/virtual_environment/providers/daytona_sandbox.py +++ b/api/core/virtual_environment/providers/daytona_sandbox.py @@ -185,7 +185,11 @@ class DaytonaEnvironment(VirtualEnvironment): return files def execute_command( - self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None + self, + connection_handle: ConnectionHandle, + command: list[str], + environments: Mapping[str, str] | None = None, + cwd: str | None = None, ) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]: sandbox: Sandbox = self.metadata.store[self.StoreKey.SANDBOX] @@ -193,9 +197,11 @@ class DaytonaEnvironment(VirtualEnvironment): stderr_stream = QueueTransportReadCloser() pid = uuid4().hex + working_dir = cwd or self._working_dir + thread = threading.Thread( target=self._exec_thread, - args=(pid, sandbox, command, environments or {}, stdout_stream, stderr_stream), + args=(pid, sandbox, command, environments or {}, working_dir, stdout_stream, stderr_stream), daemon=True, ) @@ -236,6 +242,7 @@ class DaytonaEnvironment(VirtualEnvironment): sandbox: Sandbox, command: list[str], environments: Mapping[str, str], + cwd: str, stdout_stream: QueueTransportReadCloser, stderr_stream: QueueTransportReadCloser, ) -> None: @@ -249,6 +256,7 @@ class DaytonaEnvironment(VirtualEnvironment): response = sandbox.process.exec( command=shlex.join(command), env=dict(environments), + cwd=cwd, ) exit_code = response.exit_code output = response.artifacts.stdout if response.artifacts and response.artifacts.stdout else response.result diff --git a/api/core/virtual_environment/providers/docker_daemon_sandbox.py b/api/core/virtual_environment/providers/docker_daemon_sandbox.py index bf71ac27cb..78f90646e4 100644 --- a/api/core/virtual_environment/providers/docker_daemon_sandbox.py +++ b/api/core/virtual_environment/providers/docker_daemon_sandbox.py @@ -449,13 +449,20 @@ class DockerDaemonEnvironment(VirtualEnvironment): return def execute_command( - self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None + self, + connection_handle: ConnectionHandle, + command: list[str], + environments: Mapping[str, str] | None = None, + cwd: str | None = None, ) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]: container = self._get_container() container_id = container.id if not isinstance(container_id, str) or not container_id: raise RuntimeError("Docker container ID is not available for exec.") api_client = self.get_docker_api_client(self.get_docker_sock()) + + working_dir = cwd or self._working_dir + exec_info: dict[str, object] = cast( dict[str, object], api_client.exec_create( # pyright: ignore[reportUnknownMemberType] # @@ -465,7 +472,7 @@ class DockerDaemonEnvironment(VirtualEnvironment): stdout=True, stderr=True, tty=False, - workdir=self._working_dir, + workdir=working_dir, environment=environments, ), ) diff --git a/api/core/virtual_environment/providers/e2b_sandbox.py b/api/core/virtual_environment/providers/e2b_sandbox.py index 98fec805df..0cef0c3b44 100644 --- a/api/core/virtual_environment/providers/e2b_sandbox.py +++ b/api/core/virtual_environment/providers/e2b_sandbox.py @@ -200,7 +200,11 @@ class E2BEnvironment(VirtualEnvironment): ] def execute_command( - self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None + self, + connection_handle: ConnectionHandle, + command: list[str], + environments: Mapping[str, str] | None = None, + cwd: str | None = None, ) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]: """ Execute a command in the E2B virtual environment. @@ -212,9 +216,11 @@ class E2BEnvironment(VirtualEnvironment): stdout_stream = QueueTransportReadCloser() stderr_stream = QueueTransportReadCloser() + working_dir = cwd or self._WORKDIR + threading.Thread( target=self._cmd_thread, - args=(sandbox, command, environments, stdout_stream, stderr_stream), + args=(sandbox, command, environments, working_dir, stdout_stream, stderr_stream), ).start() return ( @@ -235,10 +241,10 @@ class E2BEnvironment(VirtualEnvironment): sandbox: Sandbox, command: list[str], environments: Mapping[str, str] | None, + cwd: str, stdout_stream: QueueTransportReadCloser, stderr_stream: QueueTransportReadCloser, ) -> None: - """ """ stdout_stream_write_handler = stdout_stream.get_write_handler() stderr_stream_write_handler = stderr_stream.get_write_handler() @@ -246,7 +252,7 @@ class E2BEnvironment(VirtualEnvironment): sandbox.commands.run( cmd=shlex.join(command), envs=dict(environments or {}), - # stdin=True, + cwd=cwd, on_stdout=lambda data: stdout_stream_write_handler.write(data.encode()), on_stderr=lambda data: stderr_stream_write_handler.write(data.encode()), ) diff --git a/api/core/virtual_environment/providers/local_without_isolation.py b/api/core/virtual_environment/providers/local_without_isolation.py index 6b4d47203d..d1beb57b8e 100644 --- a/api/core/virtual_environment/providers/local_without_isolation.py +++ b/api/core/virtual_environment/providers/local_without_isolation.py @@ -171,16 +171,13 @@ class LocalVirtualEnvironment(VirtualEnvironment): pass def execute_command( - self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None + self, + connection_handle: ConnectionHandle, + command: list[str], + environments: Mapping[str, str] | None = None, + cwd: str | None = None, ) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]: - """ - Execute a command in the local virtual environment. - - Args: - connection_handle (ConnectionHandle): The connection handle. - command (list[str]): The command to execute. - """ - working_path = self.get_working_path() + working_path = cwd or self.get_working_path() stdin_read_fd, stdin_write_fd = os.pipe() stdout_read_fd, stdout_write_fd = os.pipe() stderr_read_fd, stderr_write_fd = os.pipe() diff --git a/api/core/workflow/nodes/command/node.py b/api/core/workflow/nodes/command/node.py index 0a76ac0399..e753f8dc0f 100644 --- a/api/core/workflow/nodes/command/node.py +++ b/api/core/workflow/nodes/command/node.py @@ -82,26 +82,8 @@ class CommandNode(Node[CommandNodeData]): connection_handle = sandbox.establish_connection() try: - # TODO: VirtualEnvironment.run_command lacks native cwd support. - # Once the interface adds a `cwd` parameter, remove this shell hack - # and pass working_directory directly to run_command. - if working_directory: - check_cmd = ["test", "-d", working_directory] - check_future = sandbox.run_command(connection_handle, check_cmd) - check_result = check_future.result(timeout=timeout) - - if check_result.exit_code != 0: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=f"Working directory does not exist: {working_directory}", - error_type="WorkingDirectoryNotFoundError", - ) - - command = ["sh", "-c", f"cd {shlex.quote(working_directory)} && {raw_command}"] - else: - command = shlex.split(raw_command) - - future = sandbox.run_command(connection_handle, command) + command = shlex.split(raw_command) + future = sandbox.run_command(connection_handle, command, cwd=working_directory) result = future.result(timeout=timeout) outputs: dict[str, Any] = { diff --git a/api/tests/unit_tests/core/virtual_environment/test_sandbox_manager.py b/api/tests/unit_tests/core/virtual_environment/test_sandbox_manager.py index c7ac09b8c0..f00049baf8 100644 --- a/api/tests/unit_tests/core/virtual_environment/test_sandbox_manager.py +++ b/api/tests/unit_tests/core/virtual_environment/test_sandbox_manager.py @@ -6,7 +6,14 @@ from typing import Any import pytest from core.sandbox.manager import SandboxManager -from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata +from core.virtual_environment.__base.entities import ( + Arch, + CommandStatus, + ConnectionHandle, + FileState, + Metadata, + OperatingSystem, +) from core.virtual_environment.__base.virtual_environment import VirtualEnvironment @@ -16,7 +23,7 @@ class FakeVirtualEnvironment(VirtualEnvironment): super().__init__(tenant_id="test-tenant", options={}, environments={}) def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata: - return Metadata(id=self._sandbox_id, arch=Arch.AMD64) + return Metadata(id=self._sandbox_id, arch=Arch.AMD64, os=OperatingSystem.LINUX) def upload_file(self, path: str, content: BytesIO) -> None: raise NotImplementedError @@ -37,7 +44,11 @@ class FakeVirtualEnvironment(VirtualEnvironment): pass def execute_command( - self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None + self, + connection_handle: ConnectionHandle, + command: list[str], + environments: Mapping[str, str] | None = None, + cwd: str | None = None, ) -> tuple[str, Any, Any, Any]: raise NotImplementedError diff --git a/api/tests/unit_tests/core/workflow/nodes/command/test_command_node.py b/api/tests/unit_tests/core/workflow/nodes/command/test_command_node.py index 6dc035076b..7253275366 100644 --- a/api/tests/unit_tests/core/workflow/nodes/command/test_command_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/command/test_command_node.py @@ -6,7 +6,14 @@ from typing import Any import pytest from core.sandbox.manager import SandboxManager -from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata +from core.virtual_environment.__base.entities import ( + Arch, + CommandStatus, + ConnectionHandle, + FileState, + Metadata, + OperatingSystem, +) from core.virtual_environment.__base.virtual_environment import VirtualEnvironment from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser from core.virtual_environment.channel.transport import NopTransportWriteCloser @@ -31,11 +38,12 @@ class FakeSandbox(VirtualEnvironment): self._statuses = list(statuses or []) self._close_streams = close_streams self.last_execute_command: list[str] | None = None + self.last_execute_cwd: str | None = None self.released_connections: list[str] = [] super().__init__(tenant_id="test-tenant", options={}, environments={}) def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata: - return Metadata(id="fake", arch=Arch.ARM64) + return Metadata(id="fake", arch=Arch.ARM64, os=OperatingSystem.LINUX) def upload_file(self, path: str, content: BytesIO) -> None: raise NotImplementedError @@ -56,11 +64,16 @@ class FakeSandbox(VirtualEnvironment): return def execute_command( - self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None + self, + connection_handle: ConnectionHandle, + command: list[str], + environments: Mapping[str, str] | None = None, + cwd: str | None = None, ) -> tuple[str, NopTransportWriteCloser, QueueTransportReadCloser, QueueTransportReadCloser]: _ = connection_handle _ = environments self.last_execute_command = command + self.last_execute_cwd = cwd stdout = QueueTransportReadCloser() stderr = QueueTransportReadCloser() @@ -145,8 +158,8 @@ def test_command_node_success_executes_in_sandbox(): assert result.outputs["exit_code"] == 0 assert sandbox.last_execute_command is not None - assert sandbox.last_execute_command[:2] == ["sh", "-c"] - assert "cd dir-42 && echo 42" in sandbox.last_execute_command[2] + assert sandbox.last_execute_command == ["echo", "42"] + assert sandbox.last_execute_cwd == "dir-42" def test_command_node_nonzero_exit_code_returns_failed_result():