fix(virtual-env): fix Docker stdout/stderr demuxing and exit code parsing

- Add _DockerDemuxer to properly separate stdout/stderr from multiplexed stream
- Fix binary header garbage in Docker exec output (tty=False 8-byte header)
- Fix LocalVirtualEnvironment.get_command_status() to use os.WEXITSTATUS()
- Update tests to use Transport API instead of raw file descriptors
This commit is contained in:
Harry 2026-01-07 12:20:07 +08:00
parent 05c3344554
commit 1a203031e0
3 changed files with 188 additions and 50 deletions

View File

@ -1,7 +1,7 @@
import socket
import tarfile
from collections.abc import Mapping, Sequence
from enum import StrEnum
from enum import IntEnum, StrEnum
from functools import lru_cache
from io import BytesIO
from pathlib import PurePosixPath
@ -15,9 +15,134 @@ import docker
from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata
from core.virtual_environment.__base.exec import VirtualEnvironmentLaunchFailedError
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.channel.socket_transport import SocketReadCloser, SocketWriteCloser
from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.socket_transport import SocketWriteCloser
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser
class DockerStreamType(IntEnum):
"""
Docker multiplexed stream types.
When Docker exec runs with tty=False, it multiplexes stdout and stderr over a single
socket connection. Each frame is prefixed with an 8-byte header:
[stream_type (1 byte)][0][0][0][payload_size (4 bytes, big-endian)]
This allows the client to distinguish between stdout (type=1) and stderr (type=2).
See: https://docs.docker.com/engine/api/v1.41/#operation/ContainerAttach
"""
STDIN = 0
STDOUT = 1
STDERR = 2
class DockerDemuxer:
"""
Demultiplexes Docker's combined stdout/stderr stream.
Docker exec with tty=False sends stdout and stderr over a single socket,
each frame prefixed with an 8-byte header:
- Byte 0: stream type (1=stdout, 2=stderr)
- Bytes 1-3: reserved (zeros)
- Bytes 4-7: payload size (big-endian uint32)
This class reads frames and routes them to separate stdout/stderr buffers.
Without demuxing, output contains binary garbage like:
b'\\x01\\x00\\x00\\x00\\x00\\x00\\x00\\x0eHello World\\n'
"""
_HEADER_SIZE = 8
def __init__(self, sock: socket.SocketIO):
self._sock = sock
self._stdout_buf = bytearray()
self._stderr_buf = bytearray()
self._eof = False
self._closed = False
def read_stdout(self, n: int) -> bytes:
return self._read_from_buffer(self._stdout_buf, DockerStreamType.STDOUT, n)
def read_stderr(self, n: int) -> bytes:
return self._read_from_buffer(self._stderr_buf, DockerStreamType.STDERR, n)
def _read_from_buffer(self, buffer: bytearray, target_type: DockerStreamType, n: int) -> bytes:
while len(buffer) < n and not self._eof:
self._read_next_frame()
if not buffer:
raise TransportEOFError("End of demuxed stream")
result = bytes(buffer[:n])
del buffer[:n]
return result
def _read_next_frame(self) -> None:
header = self._read_exact(self._HEADER_SIZE)
if not header or len(header) < self._HEADER_SIZE:
self._eof = True
return
frame_type = header[0]
payload_size = int.from_bytes(header[4:8], "big")
if payload_size == 0:
return
payload = self._read_exact(payload_size)
if not payload:
self._eof = True
return
if frame_type == DockerStreamType.STDOUT:
self._stdout_buf.extend(payload)
elif frame_type == DockerStreamType.STDERR:
self._stderr_buf.extend(payload)
def _read_exact(self, size: int) -> bytes:
data = bytearray()
remaining = size
while remaining > 0:
try:
chunk = self._sock.read(remaining)
if not chunk:
return bytes(data) if data else b""
data.extend(chunk)
remaining -= len(chunk)
except (ConnectionResetError, BrokenPipeError):
return bytes(data) if data else b""
return bytes(data)
def close(self) -> None:
if not self._closed:
self._closed = True
self._sock.close()
class DemuxedStdoutReader(TransportReadCloser):
def __init__(self, demuxer: DockerDemuxer):
self._demuxer = demuxer
def read(self, n: int) -> bytes:
return self._demuxer.read_stdout(n)
def close(self) -> None:
self._demuxer.close()
class DemuxedStderrReader(TransportReadCloser):
def __init__(self, demuxer: DockerDemuxer):
self._demuxer = demuxer
def read(self, n: int) -> bytes:
return self._demuxer.read_stderr(n)
def close(self) -> None:
self._demuxer.close()
"""
EXAMPLE:
@ -288,9 +413,11 @@ class DockerDaemonEnvironment(VirtualEnvironment):
raw_sock: socket.SocketIO = cast(socket.SocketIO, api_client.exec_start(exec_id, socket=True, tty=False)) # pyright: ignore[reportUnknownMemberType] #
stdin_transport = SocketWriteCloser(raw_sock)
stdout_transport = SocketReadCloser(raw_sock)
demuxer = DockerDemuxer(raw_sock)
stdout_transport = DemuxedStdoutReader(demuxer)
stderr_transport = DemuxedStderrReader(demuxer)
return exec_id, stdin_transport, stdout_transport, stdout_transport
return exec_id, stdin_transport, stdout_transport, stderr_transport
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
api_client = self.get_docker_api_client(self.get_docker_sock())

View File

@ -212,24 +212,20 @@ class LocalVirtualEnvironment(VirtualEnvironment):
return str(process.pid), stdin_transport, stdout_transport, stderr_transport
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
"""
Docstring for get_command_status
:param self: Description
:param connection_handle: Description
:type connection_handle: ConnectionHandle
:param pid: Description
:type pid: int
:return: Description
:rtype: CommandStatus
"""
pid_int = int(pid)
try:
retcode = os.waitpid(pid_int, os.WNOHANG)[1]
if retcode == 0:
waited_pid, wait_status = os.waitpid(pid_int, os.WNOHANG)
if waited_pid == 0:
return CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)
if os.WIFEXITED(wait_status):
exit_code = os.WEXITSTATUS(wait_status)
elif os.WIFSIGNALED(wait_status):
exit_code = -os.WTERMSIG(wait_status)
else:
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=retcode)
exit_code = None
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)
except ChildProcessError:
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None)

View File

@ -1,31 +1,27 @@
import os
from io import BytesIO
from pathlib import Path
import pytest
from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.transport import TransportReadCloser
from core.virtual_environment.providers import local_without_isolation
from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment
def _read_all(fd: int) -> bytes:
def _drain_transport(transport: TransportReadCloser) -> bytes:
chunks: list[bytes] = []
while True:
data = os.read(fd, 4096)
if not data:
break
chunks.append(data)
try:
while True:
data = transport.read(4096)
if not data:
break
chunks.append(data)
except TransportEOFError:
pass
return b"".join(chunks)
def _close_fds(*fds: int) -> None:
for fd in fds:
try:
os.close(fd)
except OSError:
pass
@pytest.fixture
def local_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> LocalVirtualEnvironment:
monkeypatch.setattr(local_without_isolation, "machine", lambda: "x86_64")
@ -35,7 +31,7 @@ def local_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> LocalVirtualEn
def test_construct_environment_creates_working_path(local_env: LocalVirtualEnvironment):
working_path = local_env.get_working_path()
assert local_env.metadata.id
assert os.path.isdir(working_path)
assert Path(working_path).is_dir()
def test_upload_download_roundtrip(local_env: LocalVirtualEnvironment):
@ -54,7 +50,7 @@ def test_list_files_respects_limit(local_env: LocalVirtualEnvironment):
all_files = local_env.list_files("", limit=10)
all_paths = {state.path for state in all_files}
assert os.path.join("dir", "file_a.txt") in all_paths
assert "dir/file_a.txt" in all_paths or "dir\\file_a.txt" in all_paths
assert "file_b.txt" in all_paths
limited_files = local_env.list_files("", limit=1)
@ -66,16 +62,15 @@ def test_execute_command_uses_working_directory(local_env: LocalVirtualEnvironme
connection = local_env.establish_connection()
command = ["/bin/sh", "-c", "cat message.txt"]
pid, stdin_fd, stdout_fd, stderr_fd = local_env.execute_command(connection, command)
_, stdin_transport, stdout_transport, stderr_transport = local_env.execute_command(connection, command)
try:
os.close(stdin_fd)
if hasattr(os, "waitpid"):
os.waitpid(pid, 0)
stdout = _read_all(stdout_fd)
stderr = _read_all(stderr_fd)
stdin_transport.close()
stdout = _drain_transport(stdout_transport)
stderr = _drain_transport(stderr_transport)
finally:
_close_fds(stdin_fd, stdout_fd, stderr_fd)
stdout_transport.close()
stderr_transport.close()
assert stdout == b"hello"
assert stderr == b""
@ -85,17 +80,37 @@ def test_execute_command_pipes_stdio(local_env: LocalVirtualEnvironment):
connection = local_env.establish_connection()
command = ["/bin/sh", "-c", "tr a-z A-Z < /dev/stdin; printf ERR >&2"]
pid, stdin_fd, stdout_fd, stderr_fd = local_env.execute_command(connection, command)
_, stdin_transport, stdout_transport, stderr_transport = local_env.execute_command(connection, command)
try:
os.write(stdin_fd, b"abc")
os.close(stdin_fd)
if hasattr(os, "waitpid"):
os.waitpid(pid, 0)
stdout = _read_all(stdout_fd)
stderr = _read_all(stderr_fd)
stdin_transport.write(b"abc")
stdin_transport.close()
stdout = _drain_transport(stdout_transport)
stderr = _drain_transport(stderr_transport)
finally:
_close_fds(stdin_fd, stdout_fd, stderr_fd)
stdout_transport.close()
stderr_transport.close()
assert stdout == b"ABC"
assert stderr == b"ERR"
def test_run_command_returns_output(local_env: LocalVirtualEnvironment):
local_env.upload_file("message.txt", BytesIO(b"hello"))
connection = local_env.establish_connection()
result = local_env.run_command(connection, ["/bin/sh", "-c", "cat message.txt"]).result(timeout=10)
assert result.stdout == b"hello"
assert result.stderr == b""
assert result.exit_code == 0
def test_run_command_captures_stderr(local_env: LocalVirtualEnvironment):
connection = local_env.establish_connection()
result = local_env.run_command(connection, ["/bin/sh", "-c", "echo OUT; echo ERR >&2"]).result(timeout=10)
assert b"OUT" in result.stdout
assert b"ERR" in result.stderr
assert result.exit_code == 0