mirror of
https://github.com/langgenius/dify.git
synced 2026-01-14 06:07:33 +08:00
Refactor code structure for improved readability and maintainability
This commit is contained in:
parent
a513ab9a59
commit
274f9a3f32
@ -1,4 +1,6 @@
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -19,6 +21,9 @@ class Metadata(BaseModel):
|
||||
|
||||
id: str = Field(description="The unique identifier of the virtual environment.")
|
||||
arch: Arch = Field(description="Which architecture was used to create the virtual environment.")
|
||||
store: Mapping[str, Any] = Field(
|
||||
default_factory=dict, description="The store information of the virtual environment., Additional data."
|
||||
)
|
||||
|
||||
|
||||
class ConnectionHandle(BaseModel):
|
||||
@ -34,5 +39,21 @@ class CommandStatus(BaseModel):
|
||||
Status of a command executed in the virtual environment.
|
||||
"""
|
||||
|
||||
class Status(StrEnum):
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
|
||||
pid: int = Field(description="The process ID of the command.")
|
||||
return_code: int = Field(description="The return code of the command execution.")
|
||||
status: Status = Field(description="The status of the command execution.")
|
||||
exit_code: int | None = Field(description="The return code of the command execution.")
|
||||
|
||||
|
||||
class FileState(BaseModel):
|
||||
"""
|
||||
State of a file in the virtual environment.
|
||||
"""
|
||||
|
||||
size: int = Field(description="The size of the file in bytes.")
|
||||
path: str = Field(description="The path of the file in the virtual environment.")
|
||||
created_at: int = Field(description="The creation timestamp of the file.")
|
||||
updated_at: int = Field(description="The last modified timestamp of the file.")
|
||||
|
||||
10
api/core/virtual_environment/__base/exec.py
Normal file
10
api/core/virtual_environment/__base/exec.py
Normal file
@ -0,0 +1,10 @@
|
||||
class ArchNotSupportedError(Exception):
|
||||
"""Exception raised when the architecture is not supported."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class VirtualEnvironmentLaunchFailedError(Exception):
|
||||
"""Exception raised when launching the virtual environment fails."""
|
||||
|
||||
pass
|
||||
@ -1,9 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
from core.virtual_environment.__base.entities import CommandStatus, ConnectionHandle, Metadata
|
||||
from core.virtual_environment.__base.entities import CommandStatus, ConnectionHandle, FileState, Metadata
|
||||
|
||||
|
||||
class VirtualEnvironment(ABC):
|
||||
@ -11,30 +11,30 @@ class VirtualEnvironment(ABC):
|
||||
Base class for virtual environment implementations.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def request_environment(self, options: Mapping[str, Any]) -> Metadata:
|
||||
def __init__(self, options: Mapping[str, Any]) -> None:
|
||||
"""
|
||||
Initialize the virtual environment with metadata.
|
||||
"""
|
||||
Request a virtual environment with the given options.
|
||||
|
||||
Args:
|
||||
options (Mapping[str, Any]): Options for requesting the virtual environment.
|
||||
Those options are implementation-specific, which can be defined in environment
|
||||
self.options = options
|
||||
self.metadata = self.construct_environment(options)
|
||||
|
||||
@abstractmethod
|
||||
def construct_environment(self, options: Mapping[str, Any]) -> Metadata:
|
||||
"""
|
||||
Construct the unique identifier for the virtual environment.
|
||||
|
||||
Returns:
|
||||
Metadata: Metadata about the requested virtual environment.
|
||||
|
||||
Raises:
|
||||
Exception: If the environment cannot be requested.
|
||||
str: The unique identifier of the virtual environment.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def upload_file(self, environment_id: str, destination_path: str, content: BytesIO) -> None:
|
||||
def upload_file(self, path: str, content: BytesIO) -> None:
|
||||
"""
|
||||
Upload a file to the virtual environment.
|
||||
|
||||
Args:
|
||||
environment_id (str): The unique identifier of the virtual environment.
|
||||
destination_path (str): The destination path in the virtual environment.
|
||||
path (str): The destination path in the virtual environment.
|
||||
content (BytesIO): The content of the file to upload.
|
||||
|
||||
Raises:
|
||||
@ -42,12 +42,49 @@ class VirtualEnvironment(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def establish_connection(self, environment_id: str) -> ConnectionHandle:
|
||||
def download_file(self, path: str) -> BytesIO:
|
||||
"""
|
||||
Establish a connection to the virtual environment.
|
||||
Download a file from the virtual environment.
|
||||
|
||||
Args:
|
||||
environment_id (str): The unique identifier of the virtual environment.
|
||||
source_path (str): The source path in the virtual environment.
|
||||
Returns:
|
||||
BytesIO: The content of the downloaded file.
|
||||
Raises:
|
||||
Exception: If the file cannot be downloaded.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
|
||||
"""
|
||||
List files in a directory of the virtual environment.
|
||||
|
||||
Args:
|
||||
directory_path (str): The directory path in the virtual environment.
|
||||
limit (int): The maximum number of files(including recursive paths) to return.
|
||||
Returns:
|
||||
Sequence[FileState]: A list of file states in the specified directory.
|
||||
Raises:
|
||||
Exception: If the files cannot be listed.
|
||||
|
||||
Example:
|
||||
If the directory structure is like:
|
||||
/dir
|
||||
/subdir1
|
||||
file1.txt
|
||||
/subdir2
|
||||
file2.txt
|
||||
And limit is 2, the returned list may look like:
|
||||
[
|
||||
FileState(path="/dir/subdir1/file1.txt", is_directory=False, size=1234, created_at=..., updated_at=...),
|
||||
FileState(path="/dir/subdir2", is_directory=True, size=0, created_at=..., updated_at=...),
|
||||
]
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def establish_connection(self) -> ConnectionHandle:
|
||||
"""
|
||||
Establish a connection to the virtual environment.
|
||||
|
||||
Returns:
|
||||
ConnectionHandle: Handle for managing the connection to the virtual environment.
|
||||
@ -69,13 +106,10 @@ class VirtualEnvironment(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def release_environment(self, environment_id: str) -> None:
|
||||
def release_environment(self) -> None:
|
||||
"""
|
||||
Release the virtual environment.
|
||||
|
||||
Args:
|
||||
environment_id (str): The unique identifier of the virtual environment.
|
||||
|
||||
Raises:
|
||||
Exception: If the environment cannot be released.
|
||||
Multiple calls to `release_environment` with the same `environment_id` is acceptable.
|
||||
@ -92,7 +126,7 @@ class VirtualEnvironment(ABC):
|
||||
|
||||
Returns:
|
||||
tuple[int, int, int, int]: A tuple containing pid and 3 handle to os.pipe(): (stdin, stdout, stderr).
|
||||
After exuection, the 3 handles will be closed by `execute_command` itself.
|
||||
After exuection, the 3 handles will be closed by caller.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@ -0,0 +1,60 @@
|
||||
from collections.abc import Mapping
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from docker.models.containers import Container
|
||||
|
||||
import docker
|
||||
from core.virtual_environment.__base.entities import Arch, Metadata
|
||||
from core.virtual_environment.__base.exec import ArchNotSupportedError, VirtualEnvironmentLaunchFailedError
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
|
||||
|
||||
class DockerDaemonEnvironment(VirtualEnvironment):
|
||||
def construct_environment(self, options: Mapping[str, Any]) -> Metadata:
|
||||
"""
|
||||
Construct the Docker daemon virtual environment.
|
||||
"""
|
||||
|
||||
docker_sock = options.get("docker_sock", "unix:///var/run/docker.sock")
|
||||
docker_client = self.get_docker_daemon(docker_sock)
|
||||
|
||||
# TODO: use a better image in practice
|
||||
default_docker_image = options.get("docker_agent_image", "ubuntu:latest")
|
||||
|
||||
container = docker_client.containers.run(image=default_docker_image, detach=True, remove=True)
|
||||
|
||||
# wait for the container to be fully started
|
||||
container.reload()
|
||||
|
||||
if not container.id:
|
||||
raise VirtualEnvironmentLaunchFailedError("Failed to start Docker container for DockerDaemonEnvironment.")
|
||||
|
||||
return Metadata(
|
||||
id=container.id,
|
||||
arch=self._get_container_architecture(container),
|
||||
)
|
||||
|
||||
@lru_cache(maxsize=5)
|
||||
@classmethod
|
||||
def get_docker_daemon(cls, docker_sock: str) -> docker.DockerClient:
|
||||
"""
|
||||
Get the Docker daemon client.
|
||||
|
||||
NOTE: I guess nobody will use more than 5 different docker sockets in practice....
|
||||
"""
|
||||
return docker.DockerClient(base_url=docker_sock)
|
||||
|
||||
def _get_container_architecture(self, container: Container) -> Arch:
|
||||
"""
|
||||
Get the architecture of the Docker container.
|
||||
"""
|
||||
container.reload()
|
||||
arch_str: str = container.attrs["Architecture"]
|
||||
match arch_str.lower():
|
||||
case "x86_64" | "amd64":
|
||||
return Arch.AMD64
|
||||
case "aarch64" | "arm64":
|
||||
return Arch.ARM64
|
||||
case _:
|
||||
raise ArchNotSupportedError(f"Architecture {arch_str} is not supported in DockerDaemonEnvironment.")
|
||||
@ -0,0 +1,221 @@
|
||||
import os
|
||||
import pathlib
|
||||
import subprocess
|
||||
from collections.abc import Mapping, Sequence
|
||||
from functools import cached_property
|
||||
from io import BytesIO
|
||||
from platform import machine
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata
|
||||
from core.virtual_environment.__base.exec import ArchNotSupportedError
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
|
||||
|
||||
class LocalVirtualEnvironment(VirtualEnvironment):
|
||||
"""
|
||||
Local virtual environment provider without isolation.
|
||||
|
||||
WARNING: This provider does not provide any isolation. It's only suitable for development and testing purposes.
|
||||
NEVER USE IT IN PRODUCTION ENVIRONMENTS.
|
||||
"""
|
||||
|
||||
def construct_environment(self, options: Mapping[str, Any]) -> Metadata:
|
||||
"""
|
||||
Construct the local virtual environment.
|
||||
|
||||
Under local without isolation, this method simply create a path for the environment and return the metadata.
|
||||
"""
|
||||
id = uuid4().hex
|
||||
working_path = os.path.join(self._base_working_path, id)
|
||||
os.makedirs(working_path, exist_ok=True)
|
||||
return Metadata(
|
||||
id=id,
|
||||
arch=self._get_os_architecture(),
|
||||
)
|
||||
|
||||
def release_environment(self) -> None:
|
||||
"""
|
||||
Release the local virtual environment.
|
||||
|
||||
Just simply remove the working directory.
|
||||
"""
|
||||
working_path = self.get_working_path()
|
||||
if os.path.exists(working_path):
|
||||
os.rmdir(working_path)
|
||||
|
||||
def upload_file(self, path: str, content: BytesIO) -> None:
|
||||
"""
|
||||
Upload a file to the local virtual environment.
|
||||
|
||||
Args:
|
||||
path (str): The path to upload the file to.
|
||||
content (BytesIO): The content of the file.
|
||||
"""
|
||||
working_path = self.get_working_path()
|
||||
full_path = os.path.join(working_path, path)
|
||||
os.makedirs(os.path.dirname(full_path), exist_ok=True)
|
||||
pathlib.Path(full_path).write_bytes(content.getbuffer())
|
||||
|
||||
def download_file(self, path: str) -> BytesIO:
|
||||
"""
|
||||
Download a file from the local virtual environment.
|
||||
|
||||
Args:
|
||||
path (str): The path to download the file from.
|
||||
Returns:
|
||||
BytesIO: The content of the file.
|
||||
"""
|
||||
working_path = self.get_working_path()
|
||||
full_path = os.path.join(working_path, path)
|
||||
content = pathlib.Path(full_path).read_bytes()
|
||||
return BytesIO(content)
|
||||
|
||||
def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
|
||||
"""
|
||||
List files in a directory of the local virtual environment.
|
||||
"""
|
||||
working_path = self.get_working_path()
|
||||
full_directory_path = os.path.join(working_path, directory_path)
|
||||
files: list[FileState] = []
|
||||
for root, _, filenames in os.walk(full_directory_path):
|
||||
for filename in filenames:
|
||||
if len(files) >= limit:
|
||||
break
|
||||
file_path = os.path.relpath(os.path.join(root, filename), working_path)
|
||||
state = os.stat(os.path.join(root, filename))
|
||||
files.append(
|
||||
FileState(
|
||||
path=file_path,
|
||||
size=state.st_size,
|
||||
created_at=int(state.st_ctime),
|
||||
updated_at=int(state.st_mtime),
|
||||
)
|
||||
)
|
||||
if len(files) >= limit:
|
||||
# break the outer loop as well
|
||||
return files
|
||||
|
||||
return files
|
||||
|
||||
def establish_connection(self) -> ConnectionHandle:
|
||||
"""
|
||||
Establish a connection to the local virtual environment.
|
||||
"""
|
||||
return ConnectionHandle(
|
||||
id=uuid4().hex,
|
||||
)
|
||||
|
||||
def release_connection(self, connection_handle: ConnectionHandle) -> None:
|
||||
"""
|
||||
Release the connection to the local virtual environment.
|
||||
"""
|
||||
# No action needed for local without isolation
|
||||
pass
|
||||
|
||||
def execute_command(self, connection_handle: ConnectionHandle, command: list[str]) -> tuple[int, int, int, int]:
|
||||
"""
|
||||
Execute a command in the local virtual environment.
|
||||
|
||||
Args:
|
||||
connection_handle (ConnectionHandle): The connection handle.
|
||||
command (list[str]): The command to execute.
|
||||
"""
|
||||
working_path = self.get_working_path()
|
||||
stdin_read_fd, stdin_write_fd = os.pipe()
|
||||
stdout_read_fd, stdout_write_fd = os.pipe()
|
||||
stderr_read_fd, stderr_write_fd = os.pipe()
|
||||
try:
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
stdin=stdin_read_fd,
|
||||
stdout=stdout_write_fd,
|
||||
stderr=stderr_write_fd,
|
||||
cwd=working_path,
|
||||
close_fds=True,
|
||||
)
|
||||
except Exception:
|
||||
# Clean up file descriptors if process creation fails
|
||||
for fd in (
|
||||
stdin_read_fd,
|
||||
stdin_write_fd,
|
||||
stdout_read_fd,
|
||||
stdout_write_fd,
|
||||
stderr_read_fd,
|
||||
stderr_write_fd,
|
||||
):
|
||||
try:
|
||||
os.close(fd)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
# Close unused fds in the parent process
|
||||
os.close(stdin_read_fd)
|
||||
os.close(stdout_write_fd)
|
||||
os.close(stderr_write_fd)
|
||||
|
||||
# Return the process ID and file descriptors for stdin, stdout, and stderr
|
||||
return process.pid, stdin_write_fd, stdout_read_fd, stderr_read_fd
|
||||
|
||||
def get_command_status(self, connection_handle: ConnectionHandle, pid: int) -> CommandStatus:
|
||||
"""
|
||||
Docstring for get_command_status
|
||||
|
||||
:param self: Description
|
||||
:param connection_handle: Description
|
||||
:type connection_handle: ConnectionHandle
|
||||
:param pid: Description
|
||||
:type pid: int
|
||||
:return: Description
|
||||
:rtype: CommandStatus
|
||||
"""
|
||||
try:
|
||||
retcode = os.waitpid(pid, os.WNOHANG)[1]
|
||||
if retcode == 0:
|
||||
return CommandStatus(status=CommandStatus.Status.RUNNING, pid=pid, exit_code=None)
|
||||
else:
|
||||
return CommandStatus(status=CommandStatus.Status.COMPLETED, pid=pid, exit_code=retcode)
|
||||
except ChildProcessError:
|
||||
return CommandStatus(status=CommandStatus.Status.COMPLETED, pid=pid, exit_code=None)
|
||||
|
||||
def _get_os_architecture(self) -> Arch:
|
||||
"""
|
||||
Get the operating system architecture.
|
||||
|
||||
Returns:
|
||||
Arch: The operating system architecture.
|
||||
"""
|
||||
|
||||
arch = machine()
|
||||
match arch:
|
||||
case "x86_64" | "AMD64":
|
||||
return Arch.AMD64
|
||||
case "aarch64" | "ARM64":
|
||||
return Arch.ARM64
|
||||
case _:
|
||||
raise ArchNotSupportedError(f"Unsupported architecture: {arch}")
|
||||
|
||||
@cached_property
|
||||
def _base_working_path(self) -> str:
|
||||
"""
|
||||
Get the base working path for the local virtual environment.
|
||||
|
||||
Args:
|
||||
options (Mapping[str, Any]): Options for requesting the virtual environment.
|
||||
|
||||
Returns:
|
||||
str: The base working path.
|
||||
"""
|
||||
cwd = os.getcwd()
|
||||
return self.options.get("base_working_path", os.path.join(cwd, "local_virtual_environments"))
|
||||
|
||||
def get_working_path(self) -> str:
|
||||
"""
|
||||
Get the working path for the local virtual environment.
|
||||
|
||||
Returns:
|
||||
str: The working path.
|
||||
"""
|
||||
return os.path.join(self._base_working_path, self.metadata.id)
|
||||
@ -93,6 +93,7 @@ dependencies = [
|
||||
"weaviate-client==4.17.0",
|
||||
"apscheduler>=3.11.0",
|
||||
"weave>=0.52.16",
|
||||
"docker>=7.1.0",
|
||||
]
|
||||
# Before adding new dependency, consider place it in
|
||||
# alphabet order (a-z) and suitable group.
|
||||
|
||||
@ -0,0 +1,101 @@
|
||||
import os
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from core.virtual_environment.providers import local_without_isolation
|
||||
from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment
|
||||
|
||||
|
||||
def _read_all(fd: int) -> bytes:
|
||||
chunks: list[bytes] = []
|
||||
while True:
|
||||
data = os.read(fd, 4096)
|
||||
if not data:
|
||||
break
|
||||
chunks.append(data)
|
||||
return b"".join(chunks)
|
||||
|
||||
|
||||
def _close_fds(*fds: int) -> None:
|
||||
for fd in fds:
|
||||
try:
|
||||
os.close(fd)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def local_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> LocalVirtualEnvironment:
|
||||
monkeypatch.setattr(local_without_isolation, "machine", lambda: "x86_64")
|
||||
return LocalVirtualEnvironment({"base_working_path": str(tmp_path)})
|
||||
|
||||
|
||||
def test_construct_environment_creates_working_path(local_env: LocalVirtualEnvironment):
|
||||
working_path = local_env.get_working_path()
|
||||
assert local_env.metadata.id
|
||||
assert os.path.isdir(working_path)
|
||||
|
||||
|
||||
def test_upload_download_roundtrip(local_env: LocalVirtualEnvironment):
|
||||
content = BytesIO(b"payload")
|
||||
local_env.upload_file("nested/file.txt", content)
|
||||
|
||||
downloaded = local_env.download_file("nested/file.txt")
|
||||
|
||||
assert downloaded.getvalue() == b"payload"
|
||||
|
||||
|
||||
def test_list_files_respects_limit(local_env: LocalVirtualEnvironment):
|
||||
local_env.upload_file("dir/file_a.txt", BytesIO(b"a"))
|
||||
local_env.upload_file("file_b.txt", BytesIO(b"b"))
|
||||
|
||||
all_files = local_env.list_files("", limit=10)
|
||||
all_paths = {state.path for state in all_files}
|
||||
|
||||
assert os.path.join("dir", "file_a.txt") in all_paths
|
||||
assert "file_b.txt" in all_paths
|
||||
|
||||
limited_files = local_env.list_files("", limit=1)
|
||||
assert len(limited_files) == 1
|
||||
|
||||
|
||||
def test_execute_command_uses_working_directory(local_env: LocalVirtualEnvironment):
|
||||
local_env.upload_file("message.txt", BytesIO(b"hello"))
|
||||
connection = local_env.establish_connection()
|
||||
command = ["/bin/sh", "-c", "cat message.txt"]
|
||||
|
||||
pid, stdin_fd, stdout_fd, stderr_fd = local_env.execute_command(connection, command)
|
||||
|
||||
try:
|
||||
os.close(stdin_fd)
|
||||
if hasattr(os, "waitpid"):
|
||||
os.waitpid(pid, 0)
|
||||
stdout = _read_all(stdout_fd)
|
||||
stderr = _read_all(stderr_fd)
|
||||
finally:
|
||||
_close_fds(stdin_fd, stdout_fd, stderr_fd)
|
||||
|
||||
assert stdout == b"hello"
|
||||
assert stderr == b""
|
||||
|
||||
|
||||
def test_execute_command_pipes_stdio(local_env: LocalVirtualEnvironment):
|
||||
connection = local_env.establish_connection()
|
||||
command = ["/bin/sh", "-c", "tr a-z A-Z < /dev/stdin; printf ERR >&2"]
|
||||
|
||||
pid, stdin_fd, stdout_fd, stderr_fd = local_env.execute_command(connection, command)
|
||||
|
||||
try:
|
||||
os.write(stdin_fd, b"abc")
|
||||
os.close(stdin_fd)
|
||||
if hasattr(os, "waitpid"):
|
||||
os.waitpid(pid, 0)
|
||||
stdout = _read_all(stdout_fd)
|
||||
stderr = _read_all(stderr_fd)
|
||||
finally:
|
||||
_close_fds(stdin_fd, stdout_fd, stderr_fd)
|
||||
|
||||
assert stdout == b"ABC"
|
||||
assert stderr == b"ERR"
|
||||
4648
api/uv.lock
generated
4648
api/uv.lock
generated
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user