class DockerStreamType(IntEnum):
    """
    Docker multiplexed stream types.

    When Docker exec runs with tty=False, it multiplexes stdout and stderr over a single
    socket connection. Each frame is prefixed with an 8-byte header:

        [stream_type (1 byte)][0][0][0][payload_size (4 bytes, big-endian)]

    This allows the client to distinguish between stdout (type=1) and stderr (type=2).
    See: https://docs.docker.com/engine/api/v1.41/#operation/ContainerAttach
    """

    STDIN = 0
    STDOUT = 1
    STDERR = 2


class DockerDemuxer:
    """
    Demultiplexes Docker's combined stdout/stderr stream.

    Docker exec with tty=False sends stdout and stderr over a single socket,
    each frame prefixed with an 8-byte header:
    - Byte 0: stream type (1=stdout, 2=stderr)
    - Bytes 1-3: reserved (zeros)
    - Bytes 4-7: payload size (big-endian uint32)

    This class reads frames and routes them to separate stdout/stderr buffers.
    Without demuxing, output contains binary garbage like:
        b'\\x01\\x00\\x00\\x00\\x00\\x00\\x00\\x0cHello World\\n'

    Frames for the stream that is *not* currently being read are kept in memory
    until their reader asks for them, so a consumer that drains only one stream
    lets the other buffer grow without bound.

    NOTE(review): there is no locking here. If stdout and stderr are drained
    from different threads, concurrent frame reads on the shared socket would
    interleave -- confirm how callers consume the two readers.
    """

    _HEADER_SIZE = 8

    def __init__(self, sock: socket.SocketIO):
        self._sock = sock
        self._stdout_buf = bytearray()
        self._stderr_buf = bytearray()
        self._eof = False
        self._closed = False

    def read_stdout(self, n: int) -> bytes:
        """Return up to ``n`` bytes of demultiplexed stdout (see ``_read_from_buffer``)."""
        return self._read_from_buffer(self._stdout_buf, DockerStreamType.STDOUT, n)

    def read_stderr(self, n: int) -> bytes:
        """Return up to ``n`` bytes of demultiplexed stderr (see ``_read_from_buffer``)."""
        return self._read_from_buffer(self._stderr_buf, DockerStreamType.STDERR, n)

    def _read_from_buffer(self, buffer: bytearray, target_type: DockerStreamType, n: int) -> bytes:
        """
        Return between 1 and ``n`` bytes from ``buffer``, pulling frames as needed.

        :param buffer: the per-stream buffer to serve from.
        :param target_type: stream the buffer belongs to; not consulted here,
            routing is decided by which buffer the caller passes.
        :param n: maximum number of bytes to return; ``n <= 0`` yields ``b""``.
        :raises TransportEOFError: when the buffer is empty and the stream has ended
            (or the demuxer was closed).
        """
        if n <= 0:
            return b""

        # Hand data back as soon as *any* is available instead of waiting for a
        # full ``n`` bytes: waiting would hold back short output until the
        # process exits.
        while not buffer and not self._eof:
            self._read_next_frame()

        if not buffer:
            raise TransportEOFError("End of demuxed stream")

        result = bytes(buffer[:n])
        del buffer[:n]
        return result

    def _read_next_frame(self) -> None:
        """Read one frame from the socket and append its payload to the matching buffer."""
        header = self._read_exact(self._HEADER_SIZE)
        if len(header) < self._HEADER_SIZE:
            # Missing or truncated header: nothing more can be parsed.
            self._eof = True
            return

        frame_type = header[0]
        payload_size = int.from_bytes(header[4:8], "big")

        if payload_size == 0:
            return

        payload = self._read_exact(payload_size)
        if len(payload) < payload_size:
            # Stream ended mid-frame: keep whatever arrived, but no further
            # frames can follow.
            self._eof = True

        if frame_type == DockerStreamType.STDOUT:
            self._stdout_buf.extend(payload)
        elif frame_type == DockerStreamType.STDERR:
            self._stderr_buf.extend(payload)
        # Any other frame type is dropped.

    def _read_exact(self, size: int) -> bytes:
        """
        Read exactly ``size`` bytes unless the stream ends first.

        Returns fewer bytes (possibly ``b""``) when the socket yields a falsy
        chunk or the connection is reset / the pipe is broken.
        """
        data = bytearray()
        remaining = size
        while remaining > 0:
            try:
                chunk = self._sock.read(remaining)
            except (ConnectionResetError, BrokenPipeError):
                break
            if not chunk:
                break
            data.extend(chunk)
            remaining -= len(chunk)
        return bytes(data)

    def close(self) -> None:
        """Close the socket once; later reads serve buffered data, then signal EOF."""
        if not self._closed:
            self._closed = True
            # Stops any reader that shares this demuxer from touching the
            # closed socket.
            self._eof = True
            self._sock.close()
class DemuxedStdoutReader(TransportReadCloser):
    """Read-only transport that serves the stdout side of a ``DockerDemuxer``."""

    def __init__(self, demuxer: DockerDemuxer):
        self._demuxer = demuxer

    def read(self, n: int) -> bytes:
        """Delegate to ``DockerDemuxer.read_stdout`` and return its result."""
        return self._demuxer.read_stdout(n)

    def close(self) -> None:
        """Close the demuxer; this is the same object the stderr reader is built on."""
        self._demuxer.close()


class DemuxedStderrReader(TransportReadCloser):
    """Read-only transport that serves the stderr side of a ``DockerDemuxer``."""

    def __init__(self, demuxer: DockerDemuxer):
        self._demuxer = demuxer

    def read(self, n: int) -> bytes:
        """Delegate to ``DockerDemuxer.read_stderr`` and return its result."""
        return self._demuxer.read_stderr(n)

    def close(self) -> None:
        """Close the demuxer; this is the same object the stdout reader is built on."""
        self._demuxer.close()
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
    """
    Poll a child process without blocking and report its state.

    :param connection_handle: not used by this implementation.
    :param pid: process id of the child, as a decimal string.
    :return: ``RUNNING`` while ``os.waitpid`` reports no state change;
        otherwise ``COMPLETED`` with the exit status, the negated signal
        number if the child was killed by a signal, or ``None`` when the
        status cannot be decoded or the pid is not a waitable child.
    """
    completed = CommandStatus.Status.COMPLETED

    try:
        reaped_pid, raw_status = os.waitpid(int(pid), os.WNOHANG)
    except ChildProcessError:
        # No such child to wait for; the exit code is not available.
        return CommandStatus(status=completed, exit_code=None)

    # With WNOHANG, a returned pid of 0 means the child has not changed state.
    if reaped_pid == 0:
        return CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)

    exit_code: int | None = None
    if os.WIFEXITED(raw_status):
        exit_code = os.WEXITSTATUS(raw_status)
    elif os.WIFSIGNALED(raw_status):
        exit_code = -os.WTERMSIG(raw_status)

    return CommandStatus(status=completed, exit_code=exit_code)
def _drain_transport(transport: TransportReadCloser) -> bytes:
    """Read ``transport`` in 4096-byte requests until it ends and return everything read.

    End of stream is recognised either by an empty read or by
    ``TransportEOFError`` being raised.
    """
    collected = bytearray()
    try:
        while piece := transport.read(4096):
            collected += piece
    except TransportEOFError:
        pass
    return bytes(collected)
def test_run_command_returns_output(local_env: LocalVirtualEnvironment):
    """run_command returns the file content on stdout, empty stderr and exit code 0."""
    local_env.upload_file("message.txt", BytesIO(b"hello"))
    connection = local_env.establish_connection()
    command = ["/bin/sh", "-c", "cat message.txt"]

    future = local_env.run_command(connection, command)
    result = future.result(timeout=10)

    assert result.stdout == b"hello"
    assert result.stderr == b""
    assert result.exit_code == 0


def test_run_command_captures_stderr(local_env: LocalVirtualEnvironment):
    """run_command keeps stdout and stderr output in their own fields."""
    connection = local_env.establish_connection()
    command = ["/bin/sh", "-c", "echo OUT; echo ERR >&2"]

    future = local_env.run_command(connection, command)
    result = future.result(timeout=10)

    assert b"OUT" in result.stdout
    assert b"ERR" in result.stderr
    assert result.exit_code == 0