mirror of
https://github.com/langgenius/dify.git
synced 2026-02-14 23:16:14 +08:00
refactor(virtual_environment): add cwd parameter to execute_command method across all providers for improved command execution context
This commit is contained in:
parent
f990f4a8d4
commit
201a18d6ba
@ -134,7 +134,11 @@ class VirtualEnvironment(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def execute_command(
|
||||
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
|
||||
self,
|
||||
connection_handle: ConnectionHandle,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
|
||||
"""
|
||||
Execute a command in the virtual environment.
|
||||
@ -142,6 +146,8 @@ class VirtualEnvironment(ABC):
|
||||
Args:
|
||||
connection_handle (ConnectionHandle): The handle for managing the connection.
|
||||
command (list[str]): The command to execute as a list of strings.
|
||||
environments (Mapping[str, str] | None): Environment variables for the command.
|
||||
cwd (str | None): Working directory for the command. If None, uses the provider's default.
|
||||
|
||||
Returns:
|
||||
tuple[int, TransportWriteCloser, TransportReadCloser, TransportReadCloser]
|
||||
@ -176,6 +182,7 @@ class VirtualEnvironment(ABC):
|
||||
connection_handle: ConnectionHandle,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> CommandFuture:
|
||||
"""
|
||||
Execute a command and return a Future for the result.
|
||||
@ -187,6 +194,7 @@ class VirtualEnvironment(ABC):
|
||||
connection_handle: The connection handle.
|
||||
command: Command as list of strings.
|
||||
environments: Environment variables.
|
||||
cwd: Working directory for the command. If None, uses the provider's default.
|
||||
|
||||
Returns:
|
||||
CommandFuture that can be used to get result with timeout or cancel.
|
||||
@ -195,7 +203,7 @@ class VirtualEnvironment(ABC):
|
||||
result = env.run_command(handle, ["ls", "-la"]).result(timeout=30)
|
||||
"""
|
||||
pid, stdin_transport, stdout_transport, stderr_transport = self.execute_command(
|
||||
connection_handle, command, environments
|
||||
connection_handle, command, environments, cwd
|
||||
)
|
||||
|
||||
return CommandFuture(
|
||||
|
||||
@ -185,7 +185,11 @@ class DaytonaEnvironment(VirtualEnvironment):
|
||||
return files
|
||||
|
||||
def execute_command(
|
||||
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
|
||||
self,
|
||||
connection_handle: ConnectionHandle,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
|
||||
sandbox: Sandbox = self.metadata.store[self.StoreKey.SANDBOX]
|
||||
|
||||
@ -193,9 +197,11 @@ class DaytonaEnvironment(VirtualEnvironment):
|
||||
stderr_stream = QueueTransportReadCloser()
|
||||
pid = uuid4().hex
|
||||
|
||||
working_dir = cwd or self._working_dir
|
||||
|
||||
thread = threading.Thread(
|
||||
target=self._exec_thread,
|
||||
args=(pid, sandbox, command, environments or {}, stdout_stream, stderr_stream),
|
||||
args=(pid, sandbox, command, environments or {}, working_dir, stdout_stream, stderr_stream),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
@ -236,6 +242,7 @@ class DaytonaEnvironment(VirtualEnvironment):
|
||||
sandbox: Sandbox,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str],
|
||||
cwd: str,
|
||||
stdout_stream: QueueTransportReadCloser,
|
||||
stderr_stream: QueueTransportReadCloser,
|
||||
) -> None:
|
||||
@ -249,6 +256,7 @@ class DaytonaEnvironment(VirtualEnvironment):
|
||||
response = sandbox.process.exec(
|
||||
command=shlex.join(command),
|
||||
env=dict(environments),
|
||||
cwd=cwd,
|
||||
)
|
||||
exit_code = response.exit_code
|
||||
output = response.artifacts.stdout if response.artifacts and response.artifacts.stdout else response.result
|
||||
|
||||
@ -449,13 +449,20 @@ class DockerDaemonEnvironment(VirtualEnvironment):
|
||||
return
|
||||
|
||||
def execute_command(
|
||||
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
|
||||
self,
|
||||
connection_handle: ConnectionHandle,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
|
||||
container = self._get_container()
|
||||
container_id = container.id
|
||||
if not isinstance(container_id, str) or not container_id:
|
||||
raise RuntimeError("Docker container ID is not available for exec.")
|
||||
api_client = self.get_docker_api_client(self.get_docker_sock())
|
||||
|
||||
working_dir = cwd or self._working_dir
|
||||
|
||||
exec_info: dict[str, object] = cast(
|
||||
dict[str, object],
|
||||
api_client.exec_create( # pyright: ignore[reportUnknownMemberType] #
|
||||
@ -465,7 +472,7 @@ class DockerDaemonEnvironment(VirtualEnvironment):
|
||||
stdout=True,
|
||||
stderr=True,
|
||||
tty=False,
|
||||
workdir=self._working_dir,
|
||||
workdir=working_dir,
|
||||
environment=environments,
|
||||
),
|
||||
)
|
||||
|
||||
@ -200,7 +200,11 @@ class E2BEnvironment(VirtualEnvironment):
|
||||
]
|
||||
|
||||
def execute_command(
|
||||
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
|
||||
self,
|
||||
connection_handle: ConnectionHandle,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
|
||||
"""
|
||||
Execute a command in the E2B virtual environment.
|
||||
@ -212,9 +216,11 @@ class E2BEnvironment(VirtualEnvironment):
|
||||
stdout_stream = QueueTransportReadCloser()
|
||||
stderr_stream = QueueTransportReadCloser()
|
||||
|
||||
working_dir = cwd or self._WORKDIR
|
||||
|
||||
threading.Thread(
|
||||
target=self._cmd_thread,
|
||||
args=(sandbox, command, environments, stdout_stream, stderr_stream),
|
||||
args=(sandbox, command, environments, working_dir, stdout_stream, stderr_stream),
|
||||
).start()
|
||||
|
||||
return (
|
||||
@ -235,10 +241,10 @@ class E2BEnvironment(VirtualEnvironment):
|
||||
sandbox: Sandbox,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None,
|
||||
cwd: str,
|
||||
stdout_stream: QueueTransportReadCloser,
|
||||
stderr_stream: QueueTransportReadCloser,
|
||||
) -> None:
|
||||
""" """
|
||||
stdout_stream_write_handler = stdout_stream.get_write_handler()
|
||||
stderr_stream_write_handler = stderr_stream.get_write_handler()
|
||||
|
||||
@ -246,7 +252,7 @@ class E2BEnvironment(VirtualEnvironment):
|
||||
sandbox.commands.run(
|
||||
cmd=shlex.join(command),
|
||||
envs=dict(environments or {}),
|
||||
# stdin=True,
|
||||
cwd=cwd,
|
||||
on_stdout=lambda data: stdout_stream_write_handler.write(data.encode()),
|
||||
on_stderr=lambda data: stderr_stream_write_handler.write(data.encode()),
|
||||
)
|
||||
|
||||
@ -171,16 +171,13 @@ class LocalVirtualEnvironment(VirtualEnvironment):
|
||||
pass
|
||||
|
||||
def execute_command(
|
||||
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
|
||||
self,
|
||||
connection_handle: ConnectionHandle,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
|
||||
"""
|
||||
Execute a command in the local virtual environment.
|
||||
|
||||
Args:
|
||||
connection_handle (ConnectionHandle): The connection handle.
|
||||
command (list[str]): The command to execute.
|
||||
"""
|
||||
working_path = self.get_working_path()
|
||||
working_path = cwd or self.get_working_path()
|
||||
stdin_read_fd, stdin_write_fd = os.pipe()
|
||||
stdout_read_fd, stdout_write_fd = os.pipe()
|
||||
stderr_read_fd, stderr_write_fd = os.pipe()
|
||||
|
||||
@ -82,26 +82,8 @@ class CommandNode(Node[CommandNodeData]):
|
||||
connection_handle = sandbox.establish_connection()
|
||||
|
||||
try:
|
||||
# TODO: VirtualEnvironment.run_command lacks native cwd support.
|
||||
# Once the interface adds a `cwd` parameter, remove this shell hack
|
||||
# and pass working_directory directly to run_command.
|
||||
if working_directory:
|
||||
check_cmd = ["test", "-d", working_directory]
|
||||
check_future = sandbox.run_command(connection_handle, check_cmd)
|
||||
check_result = check_future.result(timeout=timeout)
|
||||
|
||||
if check_result.exit_code != 0:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=f"Working directory does not exist: {working_directory}",
|
||||
error_type="WorkingDirectoryNotFoundError",
|
||||
)
|
||||
|
||||
command = ["sh", "-c", f"cd {shlex.quote(working_directory)} && {raw_command}"]
|
||||
else:
|
||||
command = shlex.split(raw_command)
|
||||
|
||||
future = sandbox.run_command(connection_handle, command)
|
||||
command = shlex.split(raw_command)
|
||||
future = sandbox.run_command(connection_handle, command, cwd=working_directory)
|
||||
result = future.result(timeout=timeout)
|
||||
|
||||
outputs: dict[str, Any] = {
|
||||
|
||||
@ -6,7 +6,14 @@ from typing import Any
|
||||
import pytest
|
||||
|
||||
from core.sandbox.manager import SandboxManager
|
||||
from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata
|
||||
from core.virtual_environment.__base.entities import (
|
||||
Arch,
|
||||
CommandStatus,
|
||||
ConnectionHandle,
|
||||
FileState,
|
||||
Metadata,
|
||||
OperatingSystem,
|
||||
)
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
|
||||
|
||||
@ -16,7 +23,7 @@ class FakeVirtualEnvironment(VirtualEnvironment):
|
||||
super().__init__(tenant_id="test-tenant", options={}, environments={})
|
||||
|
||||
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
|
||||
return Metadata(id=self._sandbox_id, arch=Arch.AMD64)
|
||||
return Metadata(id=self._sandbox_id, arch=Arch.AMD64, os=OperatingSystem.LINUX)
|
||||
|
||||
def upload_file(self, path: str, content: BytesIO) -> None:
|
||||
raise NotImplementedError
|
||||
@ -37,7 +44,11 @@ class FakeVirtualEnvironment(VirtualEnvironment):
|
||||
pass
|
||||
|
||||
def execute_command(
|
||||
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
|
||||
self,
|
||||
connection_handle: ConnectionHandle,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> tuple[str, Any, Any, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -6,7 +6,14 @@ from typing import Any
|
||||
import pytest
|
||||
|
||||
from core.sandbox.manager import SandboxManager
|
||||
from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata
|
||||
from core.virtual_environment.__base.entities import (
|
||||
Arch,
|
||||
CommandStatus,
|
||||
ConnectionHandle,
|
||||
FileState,
|
||||
Metadata,
|
||||
OperatingSystem,
|
||||
)
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser
|
||||
from core.virtual_environment.channel.transport import NopTransportWriteCloser
|
||||
@ -31,11 +38,12 @@ class FakeSandbox(VirtualEnvironment):
|
||||
self._statuses = list(statuses or [])
|
||||
self._close_streams = close_streams
|
||||
self.last_execute_command: list[str] | None = None
|
||||
self.last_execute_cwd: str | None = None
|
||||
self.released_connections: list[str] = []
|
||||
super().__init__(tenant_id="test-tenant", options={}, environments={})
|
||||
|
||||
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
|
||||
return Metadata(id="fake", arch=Arch.ARM64)
|
||||
return Metadata(id="fake", arch=Arch.ARM64, os=OperatingSystem.LINUX)
|
||||
|
||||
def upload_file(self, path: str, content: BytesIO) -> None:
|
||||
raise NotImplementedError
|
||||
@ -56,11 +64,16 @@ class FakeSandbox(VirtualEnvironment):
|
||||
return
|
||||
|
||||
def execute_command(
|
||||
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
|
||||
self,
|
||||
connection_handle: ConnectionHandle,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> tuple[str, NopTransportWriteCloser, QueueTransportReadCloser, QueueTransportReadCloser]:
|
||||
_ = connection_handle
|
||||
_ = environments
|
||||
self.last_execute_command = command
|
||||
self.last_execute_cwd = cwd
|
||||
|
||||
stdout = QueueTransportReadCloser()
|
||||
stderr = QueueTransportReadCloser()
|
||||
@ -145,8 +158,8 @@ def test_command_node_success_executes_in_sandbox():
|
||||
assert result.outputs["exit_code"] == 0
|
||||
|
||||
assert sandbox.last_execute_command is not None
|
||||
assert sandbox.last_execute_command[:2] == ["sh", "-c"]
|
||||
assert "cd dir-42 && echo 42" in sandbox.last_execute_command[2]
|
||||
assert sandbox.last_execute_command == ["echo", "42"]
|
||||
assert sandbox.last_execute_cwd == "dir-42"
|
||||
|
||||
|
||||
def test_command_node_nonzero_exit_code_returns_failed_result():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user