refactor(virtual_environment): add cwd parameter to execute_command method across all providers for improved command execution context

This commit is contained in:
Harry 2026-01-12 14:20:03 +08:00
parent f990f4a8d4
commit 201a18d6ba
8 changed files with 79 additions and 47 deletions

View File

@ -134,7 +134,11 @@ class VirtualEnvironment(ABC):
@abstractmethod
def execute_command(
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
self,
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
"""
Execute a command in the virtual environment.
@ -142,6 +146,8 @@ class VirtualEnvironment(ABC):
Args:
connection_handle (ConnectionHandle): The handle for managing the connection.
command (list[str]): The command to execute as a list of strings.
environments (Mapping[str, str] | None): Environment variables for the command.
cwd (str | None): Working directory for the command. If None, uses the provider's default.
Returns:
tuple[int, TransportWriteCloser, TransportReadCloser, TransportReadCloser]
@ -176,6 +182,7 @@ class VirtualEnvironment(ABC):
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> CommandFuture:
"""
Execute a command and return a Future for the result.
@ -187,6 +194,7 @@ class VirtualEnvironment(ABC):
connection_handle: The connection handle.
command: Command as list of strings.
environments: Environment variables.
cwd: Working directory for the command. If None, uses the provider's default.
Returns:
CommandFuture that can be used to get result with timeout or cancel.
@ -195,7 +203,7 @@ class VirtualEnvironment(ABC):
result = env.run_command(handle, ["ls", "-la"]).result(timeout=30)
"""
pid, stdin_transport, stdout_transport, stderr_transport = self.execute_command(
connection_handle, command, environments
connection_handle, command, environments, cwd
)
return CommandFuture(

View File

@ -185,7 +185,11 @@ class DaytonaEnvironment(VirtualEnvironment):
return files
def execute_command(
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
self,
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
sandbox: Sandbox = self.metadata.store[self.StoreKey.SANDBOX]
@ -193,9 +197,11 @@ class DaytonaEnvironment(VirtualEnvironment):
stderr_stream = QueueTransportReadCloser()
pid = uuid4().hex
working_dir = cwd or self._working_dir
thread = threading.Thread(
target=self._exec_thread,
args=(pid, sandbox, command, environments or {}, stdout_stream, stderr_stream),
args=(pid, sandbox, command, environments or {}, working_dir, stdout_stream, stderr_stream),
daemon=True,
)
@ -236,6 +242,7 @@ class DaytonaEnvironment(VirtualEnvironment):
sandbox: Sandbox,
command: list[str],
environments: Mapping[str, str],
cwd: str,
stdout_stream: QueueTransportReadCloser,
stderr_stream: QueueTransportReadCloser,
) -> None:
@ -249,6 +256,7 @@ class DaytonaEnvironment(VirtualEnvironment):
response = sandbox.process.exec(
command=shlex.join(command),
env=dict(environments),
cwd=cwd,
)
exit_code = response.exit_code
output = response.artifacts.stdout if response.artifacts and response.artifacts.stdout else response.result

View File

@ -449,13 +449,20 @@ class DockerDaemonEnvironment(VirtualEnvironment):
return
def execute_command(
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
self,
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
container = self._get_container()
container_id = container.id
if not isinstance(container_id, str) or not container_id:
raise RuntimeError("Docker container ID is not available for exec.")
api_client = self.get_docker_api_client(self.get_docker_sock())
working_dir = cwd or self._working_dir
exec_info: dict[str, object] = cast(
dict[str, object],
api_client.exec_create( # pyright: ignore[reportUnknownMemberType] #
@ -465,7 +472,7 @@ class DockerDaemonEnvironment(VirtualEnvironment):
stdout=True,
stderr=True,
tty=False,
workdir=self._working_dir,
workdir=working_dir,
environment=environments,
),
)

View File

@ -200,7 +200,11 @@ class E2BEnvironment(VirtualEnvironment):
]
def execute_command(
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
self,
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
"""
Execute a command in the E2B virtual environment.
@ -212,9 +216,11 @@ class E2BEnvironment(VirtualEnvironment):
stdout_stream = QueueTransportReadCloser()
stderr_stream = QueueTransportReadCloser()
working_dir = cwd or self._WORKDIR
threading.Thread(
target=self._cmd_thread,
args=(sandbox, command, environments, stdout_stream, stderr_stream),
args=(sandbox, command, environments, working_dir, stdout_stream, stderr_stream),
).start()
return (
@ -235,10 +241,10 @@ class E2BEnvironment(VirtualEnvironment):
sandbox: Sandbox,
command: list[str],
environments: Mapping[str, str] | None,
cwd: str,
stdout_stream: QueueTransportReadCloser,
stderr_stream: QueueTransportReadCloser,
) -> None:
""" """
stdout_stream_write_handler = stdout_stream.get_write_handler()
stderr_stream_write_handler = stderr_stream.get_write_handler()
@ -246,7 +252,7 @@ class E2BEnvironment(VirtualEnvironment):
sandbox.commands.run(
cmd=shlex.join(command),
envs=dict(environments or {}),
# stdin=True,
cwd=cwd,
on_stdout=lambda data: stdout_stream_write_handler.write(data.encode()),
on_stderr=lambda data: stderr_stream_write_handler.write(data.encode()),
)

View File

@ -171,16 +171,13 @@ class LocalVirtualEnvironment(VirtualEnvironment):
pass
def execute_command(
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
self,
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
"""
Execute a command in the local virtual environment.
Args:
connection_handle (ConnectionHandle): The connection handle.
command (list[str]): The command to execute.
"""
working_path = self.get_working_path()
working_path = cwd or self.get_working_path()
stdin_read_fd, stdin_write_fd = os.pipe()
stdout_read_fd, stdout_write_fd = os.pipe()
stderr_read_fd, stderr_write_fd = os.pipe()

View File

@ -82,26 +82,8 @@ class CommandNode(Node[CommandNodeData]):
connection_handle = sandbox.establish_connection()
try:
# TODO: VirtualEnvironment.run_command lacks native cwd support.
# Once the interface adds a `cwd` parameter, remove this shell hack
# and pass working_directory directly to run_command.
if working_directory:
check_cmd = ["test", "-d", working_directory]
check_future = sandbox.run_command(connection_handle, check_cmd)
check_result = check_future.result(timeout=timeout)
if check_result.exit_code != 0:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=f"Working directory does not exist: {working_directory}",
error_type="WorkingDirectoryNotFoundError",
)
command = ["sh", "-c", f"cd {shlex.quote(working_directory)} && {raw_command}"]
else:
command = shlex.split(raw_command)
future = sandbox.run_command(connection_handle, command)
command = shlex.split(raw_command)
future = sandbox.run_command(connection_handle, command, cwd=working_directory)
result = future.result(timeout=timeout)
outputs: dict[str, Any] = {

View File

@ -6,7 +6,14 @@ from typing import Any
import pytest
from core.sandbox.manager import SandboxManager
from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata
from core.virtual_environment.__base.entities import (
Arch,
CommandStatus,
ConnectionHandle,
FileState,
Metadata,
OperatingSystem,
)
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
@ -16,7 +23,7 @@ class FakeVirtualEnvironment(VirtualEnvironment):
super().__init__(tenant_id="test-tenant", options={}, environments={})
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
return Metadata(id=self._sandbox_id, arch=Arch.AMD64)
return Metadata(id=self._sandbox_id, arch=Arch.AMD64, os=OperatingSystem.LINUX)
def upload_file(self, path: str, content: BytesIO) -> None:
raise NotImplementedError
@ -37,7 +44,11 @@ class FakeVirtualEnvironment(VirtualEnvironment):
pass
def execute_command(
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
self,
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> tuple[str, Any, Any, Any]:
raise NotImplementedError

View File

@ -6,7 +6,14 @@ from typing import Any
import pytest
from core.sandbox.manager import SandboxManager
from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata
from core.virtual_environment.__base.entities import (
Arch,
CommandStatus,
ConnectionHandle,
FileState,
Metadata,
OperatingSystem,
)
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser
from core.virtual_environment.channel.transport import NopTransportWriteCloser
@ -31,11 +38,12 @@ class FakeSandbox(VirtualEnvironment):
self._statuses = list(statuses or [])
self._close_streams = close_streams
self.last_execute_command: list[str] | None = None
self.last_execute_cwd: str | None = None
self.released_connections: list[str] = []
super().__init__(tenant_id="test-tenant", options={}, environments={})
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
return Metadata(id="fake", arch=Arch.ARM64)
return Metadata(id="fake", arch=Arch.ARM64, os=OperatingSystem.LINUX)
def upload_file(self, path: str, content: BytesIO) -> None:
raise NotImplementedError
@ -56,11 +64,16 @@ class FakeSandbox(VirtualEnvironment):
return
def execute_command(
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
self,
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> tuple[str, NopTransportWriteCloser, QueueTransportReadCloser, QueueTransportReadCloser]:
_ = connection_handle
_ = environments
self.last_execute_command = command
self.last_execute_cwd = cwd
stdout = QueueTransportReadCloser()
stderr = QueueTransportReadCloser()
@ -145,8 +158,8 @@ def test_command_node_success_executes_in_sandbox():
assert result.outputs["exit_code"] == 0
assert sandbox.last_execute_command is not None
assert sandbox.last_execute_command[:2] == ["sh", "-c"]
assert "cd dir-42 && echo 42" in sandbox.last_execute_command[2]
assert sandbox.last_execute_command == ["echo", "42"]
assert sandbox.last_execute_cwd == "dir-42"
def test_command_node_nonzero_exit_code_returns_failed_result():