mirror of
https://github.com/langgenius/dify.git
synced 2026-01-29 15:13:53 +08:00
fix(virtual-env): fix Docker stdout/stderr demuxing and exit code parsing
- Add DockerDemuxer to properly separate stdout/stderr from the multiplexed stream
- Fix binary header garbage in Docker exec output (8-byte frame header when tty=False)
- Fix LocalVirtualEnvironment.get_command_status() to use os.WEXITSTATUS()
- Update tests to use the Transport API instead of raw file descriptors
This commit is contained in:
parent
05c3344554
commit
1a203031e0
@ -1,7 +1,7 @@
|
||||
import socket
|
||||
import tarfile
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from enum import IntEnum, StrEnum
|
||||
from functools import lru_cache
|
||||
from io import BytesIO
|
||||
from pathlib import PurePosixPath
|
||||
@ -15,9 +15,134 @@ import docker
|
||||
from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata
|
||||
from core.virtual_environment.__base.exec import VirtualEnvironmentLaunchFailedError
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.virtual_environment.channel.socket_transport import SocketReadCloser, SocketWriteCloser
|
||||
from core.virtual_environment.channel.exec import TransportEOFError
|
||||
from core.virtual_environment.channel.socket_transport import SocketWriteCloser
|
||||
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser
|
||||
|
||||
|
||||
class DockerStreamType(IntEnum):
    """
    Stream identifiers carried in Docker's multiplexed exec/attach framing.

    With tty=False, Docker exec sends stdout and stderr interleaved over one
    socket. Every frame starts with an 8-byte header laid out as:

        [stream_type (1 byte)][0][0][0][payload_size (4 bytes, big-endian)]

    The first header byte is one of the values below, which is how a client
    tells stdout (1) apart from stderr (2).

    See: https://docs.docker.com/engine/api/v1.41/#operation/ContainerAttach
    """

    STDIN = 0
    STDOUT = 1
    STDERR = 2
|
||||
|
||||
|
||||
class DockerDemuxer:
    """
    Demultiplexes Docker's combined stdout/stderr stream.

    Docker exec with tty=False sends stdout and stderr over a single socket,
    each frame prefixed with an 8-byte header:
    - Byte 0: stream type (1=stdout, 2=stderr)
    - Bytes 1-3: reserved (zeros)
    - Bytes 4-7: payload size (big-endian uint32)

    This class reads frames and routes them to separate stdout/stderr buffers.
    Without demuxing, output contains binary garbage like:
        b'\\x01\\x00\\x00\\x00\\x00\\x00\\x00\\x0cHello World\\n'

    Reads behave like a socket ``recv``: they block only until *some* data is
    available for the requested stream and then return at most ``n`` bytes.
    Once the stream has ended and the requested buffer is drained, reads raise
    ``TransportEOFError``.

    NOTE(review): no locking is done here. If the stdout and stderr readers are
    drained from different threads, concurrent frame reads could interleave on
    the socket -- confirm how callers use the two readers.
    """

    _HEADER_SIZE = 8

    def __init__(self, sock: socket.SocketIO):
        self._sock = sock
        self._stdout_buf = bytearray()
        self._stderr_buf = bytearray()
        self._eof = False
        self._closed = False

    def read_stdout(self, n: int) -> bytes:
        """Return up to ``n`` bytes of stdout; raise TransportEOFError at end of stream."""
        return self._read_from_buffer(self._stdout_buf, DockerStreamType.STDOUT, n)

    def read_stderr(self, n: int) -> bytes:
        """Return up to ``n`` bytes of stderr; raise TransportEOFError at end of stream."""
        return self._read_from_buffer(self._stderr_buf, DockerStreamType.STDERR, n)

    def _read_from_buffer(self, buffer: bytearray, target_type: "DockerStreamType", n: int) -> bytes:
        """
        Pop up to ``n`` bytes from ``buffer``, pulling frames from the socket as needed.

        ``target_type`` is not consulted; it is kept so the two public readers
        share one call shape.
        """
        if n <= 0:
            # A zero-length read must not be mistaken for end of stream.
            return b""

        # Stop as soon as this stream has anything to hand back. Waiting for a
        # full ``n`` bytes would stall callers on short or interactive output
        # and keep buffering the other stream in the meantime.
        while not buffer and not self._eof:
            self._read_next_frame()

        if not buffer:
            raise TransportEOFError("End of demuxed stream")

        result = bytes(buffer[:n])
        del buffer[:n]
        return result

    def _read_next_frame(self) -> None:
        """Read one frame from the socket and append its payload to the matching buffer."""
        if self._closed:
            # The socket is gone; report end of stream instead of reading a closed file.
            self._eof = True
            return

        header = self._read_exact(self._HEADER_SIZE)
        if len(header) < self._HEADER_SIZE:
            # Clean EOF (no bytes) or a header cut short: nothing more can be parsed.
            self._eof = True
            return

        frame_type = header[0]
        payload_size = int.from_bytes(header[4:8], "big")

        if payload_size == 0:
            return

        payload = self._read_exact(payload_size)
        if len(payload) < payload_size:
            # The stream ended mid-frame. Keep the bytes that did arrive, but
            # no further frame boundary can be trusted.
            self._eof = True

        if frame_type == DockerStreamType.STDOUT:
            self._stdout_buf.extend(payload)
        elif frame_type == DockerStreamType.STDERR:
            self._stderr_buf.extend(payload)
        # Frames of any other type are dropped.

    def _read_exact(self, size: int) -> bytes:
        """
        Read ``size`` bytes from the socket, looping over short reads.

        Returns fewer bytes (possibly none) if the peer closes or resets the
        connection first.
        """
        data = bytearray()
        remaining = size
        while remaining > 0:
            try:
                chunk = self._sock.read(remaining)
            except (ConnectionResetError, BrokenPipeError):
                break
            if not chunk:
                break
            data.extend(chunk)
            remaining -= len(chunk)
        return bytes(data)

    def close(self) -> None:
        """Close the underlying socket; calling this more than once is harmless."""
        if not self._closed:
            self._closed = True
            self._sock.close()
|
||||
|
||||
|
||||
class DemuxedStdoutReader(TransportReadCloser):
    """Read side of a command's stdout, backed by a shared DockerDemuxer."""

    def __init__(self, demuxer: DockerDemuxer):
        self._demuxer = demuxer

    def read(self, n: int) -> bytes:
        """Delegate to the demuxer's ``read_stdout(n)`` and return its bytes."""
        stdout_chunk = self._demuxer.read_stdout(n)
        return stdout_chunk

    def close(self) -> None:
        """Close the demuxer this reader wraps."""
        backing_demuxer = self._demuxer
        backing_demuxer.close()
|
||||
|
||||
|
||||
class DemuxedStderrReader(TransportReadCloser):
    """Read side of a command's stderr, backed by a shared DockerDemuxer."""

    def __init__(self, demuxer: DockerDemuxer):
        self._demuxer = demuxer

    def read(self, n: int) -> bytes:
        """Delegate to the demuxer's ``read_stderr(n)`` and return its bytes."""
        stderr_chunk = self._demuxer.read_stderr(n)
        return stderr_chunk

    def close(self) -> None:
        """Close the demuxer this reader wraps."""
        backing_demuxer = self._demuxer
        backing_demuxer.close()
|
||||
|
||||
|
||||
"""
|
||||
EXAMPLE:
|
||||
|
||||
@ -288,9 +413,11 @@ class DockerDaemonEnvironment(VirtualEnvironment):
|
||||
raw_sock: socket.SocketIO = cast(socket.SocketIO, api_client.exec_start(exec_id, socket=True, tty=False)) # pyright: ignore[reportUnknownMemberType] #
|
||||
|
||||
stdin_transport = SocketWriteCloser(raw_sock)
|
||||
stdout_transport = SocketReadCloser(raw_sock)
|
||||
demuxer = DockerDemuxer(raw_sock)
|
||||
stdout_transport = DemuxedStdoutReader(demuxer)
|
||||
stderr_transport = DemuxedStderrReader(demuxer)
|
||||
|
||||
return exec_id, stdin_transport, stdout_transport, stdout_transport
|
||||
return exec_id, stdin_transport, stdout_transport, stderr_transport
|
||||
|
||||
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
|
||||
api_client = self.get_docker_api_client(self.get_docker_sock())
|
||||
|
||||
@ -212,24 +212,20 @@ class LocalVirtualEnvironment(VirtualEnvironment):
|
||||
return str(process.pid), stdin_transport, stdout_transport, stderr_transport
|
||||
|
||||
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
    """
    Report whether a locally spawned child process is still running.

    The child is polled once with a non-blocking ``os.waitpid``; the raw wait
    status is decoded rather than returned as-is.

    :param connection_handle: Not used by this implementation.
    :param pid: Process id of the child, as a decimal string.
    :return: RUNNING with no exit code while the child is alive. Otherwise
        COMPLETED with the child's exit status, the negated signal number if a
        signal terminated it, or ``None`` when the status cannot be determined.
    """
    pid_int = int(pid)
    try:
        # Poll exactly once: a successful wait reaps the child, so a second
        # waitpid() on the same pid would raise ChildProcessError.
        waited_pid, wait_status = os.waitpid(pid_int, os.WNOHANG)
    except ChildProcessError:
        # No such child to wait for, so no exit code can be recovered.
        return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None)

    if waited_pid == 0:
        # With WNOHANG, a returned pid of 0 means the child has not changed state yet.
        return CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)

    exit_code: int | None
    if os.WIFEXITED(wait_status):
        exit_code = os.WEXITSTATUS(wait_status)
    elif os.WIFSIGNALED(wait_status):
        exit_code = -os.WTERMSIG(wait_status)
    else:
        exit_code = None

    return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)
|
||||
|
||||
|
||||
@ -1,31 +1,27 @@
|
||||
import os
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from core.virtual_environment.channel.exec import TransportEOFError
|
||||
from core.virtual_environment.channel.transport import TransportReadCloser
|
||||
from core.virtual_environment.providers import local_without_isolation
|
||||
from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment
|
||||
|
||||
|
||||
def _read_all(fd: int) -> bytes:
    """Read the raw file descriptor ``fd`` in 4096-byte chunks until EOF and return all bytes."""
    chunks: list[bytes] = []
    while True:
        data = os.read(fd, 4096)
        if not data:
            break
        chunks.append(data)
    return b"".join(chunks)


def _drain_transport(transport: TransportReadCloser) -> bytes:
    """
    Read ``transport`` in 4096-byte chunks until it is exhausted and return all bytes.

    Reading stops on either an empty read or a ``TransportEOFError``.
    """
    chunks: list[bytes] = []
    try:
        while True:
            data = transport.read(4096)
            if not data:
                break
            chunks.append(data)
    except TransportEOFError:
        # End of stream reported by exception: whatever was collected is the full output.
        pass
    return b"".join(chunks)
|
||||
|
||||
|
||||
def _close_fds(*fds: int) -> None:
|
||||
for fd in fds:
|
||||
try:
|
||||
os.close(fd)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def local_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> LocalVirtualEnvironment:
|
||||
monkeypatch.setattr(local_without_isolation, "machine", lambda: "x86_64")
|
||||
@ -35,7 +31,7 @@ def local_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> LocalVirtualEn
|
||||
def test_construct_environment_creates_working_path(local_env: LocalVirtualEnvironment):
    """A freshly constructed environment has an id and an existing working directory."""
    working_path = local_env.get_working_path()
    assert local_env.metadata.id
    # One directory check is enough; the os.path.isdir() form asserted the same thing.
    assert Path(working_path).is_dir()
|
||||
|
||||
|
||||
def test_upload_download_roundtrip(local_env: LocalVirtualEnvironment):
|
||||
@ -54,7 +50,7 @@ def test_list_files_respects_limit(local_env: LocalVirtualEnvironment):
|
||||
all_files = local_env.list_files("", limit=10)
|
||||
all_paths = {state.path for state in all_files}
|
||||
|
||||
assert os.path.join("dir", "file_a.txt") in all_paths
|
||||
assert "dir/file_a.txt" in all_paths or "dir\\file_a.txt" in all_paths
|
||||
assert "file_b.txt" in all_paths
|
||||
|
||||
limited_files = local_env.list_files("", limit=1)
|
||||
@ -66,16 +62,15 @@ def test_execute_command_uses_working_directory(local_env: LocalVirtualEnvironme
|
||||
connection = local_env.establish_connection()
|
||||
command = ["/bin/sh", "-c", "cat message.txt"]
|
||||
|
||||
pid, stdin_fd, stdout_fd, stderr_fd = local_env.execute_command(connection, command)
|
||||
_, stdin_transport, stdout_transport, stderr_transport = local_env.execute_command(connection, command)
|
||||
|
||||
try:
|
||||
os.close(stdin_fd)
|
||||
if hasattr(os, "waitpid"):
|
||||
os.waitpid(pid, 0)
|
||||
stdout = _read_all(stdout_fd)
|
||||
stderr = _read_all(stderr_fd)
|
||||
stdin_transport.close()
|
||||
stdout = _drain_transport(stdout_transport)
|
||||
stderr = _drain_transport(stderr_transport)
|
||||
finally:
|
||||
_close_fds(stdin_fd, stdout_fd, stderr_fd)
|
||||
stdout_transport.close()
|
||||
stderr_transport.close()
|
||||
|
||||
assert stdout == b"hello"
|
||||
assert stderr == b""
|
||||
@ -85,17 +80,37 @@ def test_execute_command_pipes_stdio(local_env: LocalVirtualEnvironment):
|
||||
connection = local_env.establish_connection()
|
||||
command = ["/bin/sh", "-c", "tr a-z A-Z < /dev/stdin; printf ERR >&2"]
|
||||
|
||||
pid, stdin_fd, stdout_fd, stderr_fd = local_env.execute_command(connection, command)
|
||||
_, stdin_transport, stdout_transport, stderr_transport = local_env.execute_command(connection, command)
|
||||
|
||||
try:
|
||||
os.write(stdin_fd, b"abc")
|
||||
os.close(stdin_fd)
|
||||
if hasattr(os, "waitpid"):
|
||||
os.waitpid(pid, 0)
|
||||
stdout = _read_all(stdout_fd)
|
||||
stderr = _read_all(stderr_fd)
|
||||
stdin_transport.write(b"abc")
|
||||
stdin_transport.close()
|
||||
stdout = _drain_transport(stdout_transport)
|
||||
stderr = _drain_transport(stderr_transport)
|
||||
finally:
|
||||
_close_fds(stdin_fd, stdout_fd, stderr_fd)
|
||||
stdout_transport.close()
|
||||
stderr_transport.close()
|
||||
|
||||
assert stdout == b"ABC"
|
||||
assert stderr == b"ERR"
|
||||
|
||||
|
||||
def test_run_command_returns_output(local_env: LocalVirtualEnvironment):
    """run_command() returns an uploaded file's contents on stdout with exit code 0."""
    local_env.upload_file("message.txt", BytesIO(b"hello"))
    conn = local_env.establish_connection()
    cat_message = ["/bin/sh", "-c", "cat message.txt"]

    pending = local_env.run_command(conn, cat_message)
    outcome = pending.result(timeout=10)

    assert outcome.stdout == b"hello"
    assert outcome.stderr == b""
    assert outcome.exit_code == 0
|
||||
|
||||
|
||||
def test_run_command_captures_stderr(local_env: LocalVirtualEnvironment):
    """run_command() keeps stdout and stderr output in their own result fields."""
    conn = local_env.establish_connection()
    echo_both = ["/bin/sh", "-c", "echo OUT; echo ERR >&2"]

    pending = local_env.run_command(conn, echo_both)
    outcome = pending.result(timeout=10)

    assert b"OUT" in outcome.stdout
    assert b"ERR" in outcome.stderr
    assert outcome.exit_code == 0
|
||||
|
||||
Loading…
Reference in New Issue
Block a user