diff --git a/api/core/virtual_environment/channel/exec.py b/api/core/virtual_environment/channel/exec.py new file mode 100644 index 0000000000..6a03e2f766 --- /dev/null +++ b/api/core/virtual_environment/channel/exec.py @@ -0,0 +1,4 @@ +class TransportEOFError(Exception): + """Exception raised when attempting to read from a closed transport.""" + + pass diff --git a/api/core/virtual_environment/channel/pipe_transport.py b/api/core/virtual_environment/channel/pipe_transport.py index 27bc4a3bf5..aecddeb6fc 100644 --- a/api/core/virtual_environment/channel/pipe_transport.py +++ b/api/core/virtual_environment/channel/pipe_transport.py @@ -1,5 +1,6 @@ import os +from core.virtual_environment.channel.exec import TransportEOFError from core.virtual_environment.channel.transport import Transport, TransportReadCloser, TransportWriteCloser @@ -18,10 +19,16 @@ class PipeTransport(Transport): self.w_fd = w_fd def write(self, data: bytes) -> None: - os.write(self.w_fd, data) + try: + os.write(self.w_fd, data) + except OSError: + raise TransportEOFError("Pipe write error, maybe the read end is closed") def read(self, n: int) -> bytes: - return os.read(self.r_fd, n) + data = os.read(self.r_fd, n) + if data == b"": + raise TransportEOFError("End of Pipe reached") + return data def close(self) -> None: os.close(self.r_fd) @@ -37,7 +44,11 @@ class PipeReadCloser(TransportReadCloser): self.r_fd = r_fd def read(self, n: int) -> bytes: - return os.read(self.r_fd, n) + data = os.read(self.r_fd, n) + if data == b"": + raise TransportEOFError("End of Pipe reached") + + return data def close(self) -> None: os.close(self.r_fd) @@ -52,7 +63,10 @@ class PipeWriteCloser(TransportWriteCloser): self.w_fd = w_fd def write(self, data: bytes) -> None: - os.write(self.w_fd, data) + try: + os.write(self.w_fd, data) + except OSError: + raise TransportEOFError("Pipe write error, maybe the read end is closed") def close(self) -> None: os.close(self.w_fd) diff --git a/api/core/virtual_environment/channel/queue_transport.py b/api/core/virtual_environment/channel/queue_transport.py index d87badaf4a..7fd9cbcc35 100644 --- a/api/core/virtual_environment/channel/queue_transport.py +++ b/api/core/virtual_environment/channel/queue_transport.py @@ -1,5 +1,6 @@ from queue import Queue +from core.virtual_environment.channel.exec import TransportEOFError from core.virtual_environment.channel.transport import TransportReadCloser @@ -39,6 +40,9 @@ class QueueTransportReadCloser(TransportReadCloser): Initialize the QueueTransportReadCloser with write function. """ self.q = Queue[bytes | None]() + self._read_buffer = bytearray() + self._closed = False + self._write_channel_closed = False def get_write_handler(self) -> WriteHandler: """ @@ -50,17 +54,47 @@ class QueueTransportReadCloser(TransportReadCloser): """ Close the transport by putting a sentinel value in the queue. """ + if self._write_channel_closed: + raise TransportEOFError("Write channel already closed") + + self._write_channel_closed = True self.q.put(None) def read(self, n: int) -> bytes: """ Read up to n bytes from the queue. + + NEVER USE IT IN A MULTI-THREADED CONTEXT WITHOUT PROPER SYNCHRONIZATION. """ - data = bytearray() - while len(data) < n: + if n <= 0: + return b"" + + if self._closed: + raise TransportEOFError("Transport is closed") + + to_return = self._drain_buffer(n) + while len(to_return) < n and not self._closed: chunk = self.q.get() if chunk is None: - break - data.extend(chunk) + self._closed = True + raise TransportEOFError("Transport is closed") - return bytes(data) + self._read_buffer.extend(chunk) + + if n - len(to_return) > 0: + # Drain the buffer if we still need more data + to_return += self._drain_buffer(n - len(to_return)) + else: + # No more data needed, break + break + + if self.q.qsize() == 0: + # If no more data is available, break to return what we have + break + + return to_return + + def _drain_buffer(self, n: int) -> bytes: + data = bytes(self._read_buffer[:n]) + del self._read_buffer[:n] + return data diff --git a/api/core/virtual_environment/channel/socket_transport.py b/api/core/virtual_environment/channel/socket_transport.py index 87d5cebf6a..904e42df37 100644 --- a/api/core/virtual_environment/channel/socket_transport.py +++ b/api/core/virtual_environment/channel/socket_transport.py @@ -1,5 +1,6 @@ import socket +from core.virtual_environment.channel.exec import TransportEOFError from core.virtual_environment.channel.transport import Transport, TransportReadCloser, TransportWriteCloser @@ -12,10 +13,19 @@ class SocketTransport(Transport): self.sock = sock def write(self, data: bytes) -> None: - self.sock.write(data) + try: + self.sock.write(data) + except (ConnectionResetError, BrokenPipeError): + raise TransportEOFError("Socket write error, maybe the read end is closed") def read(self, n: int) -> bytes: - return self.sock.read(n) + try: + data = self.sock.read(n) + if data == b"": + raise TransportEOFError("End of Socket reached") + except (ConnectionResetError, BrokenPipeError): + raise TransportEOFError("Socket connection reset") + return data def close(self) -> None: self.sock.close() @@ -30,7 +40,13 @@ class SocketReadCloser(TransportReadCloser): self.sock = sock def read(self, n: int) -> bytes: - return self.sock.read(n) + try: + data = self.sock.read(n) + if data == b"": + raise TransportEOFError("End of Socket reached") + return data + except (ConnectionResetError, BrokenPipeError): + raise TransportEOFError("Socket connection reset") def close(self) -> None: self.sock.close() @@ -45,7 +61,10 @@ class SocketWriteCloser(TransportWriteCloser): self.sock = sock def write(self, data: bytes) -> None: - self.sock.write(data) + try: + self.sock.write(data) + except (ConnectionResetError, BrokenPipeError): + raise TransportEOFError("Socket write error, maybe the read end is closed") def close(self) -> None: self.sock.close() diff --git a/api/core/virtual_environment/channel/transport.py b/api/core/virtual_environment/channel/transport.py index 67d24d8797..130538ab63 100644 --- a/api/core/virtual_environment/channel/transport.py +++ b/api/core/virtual_environment/channel/transport.py @@ -23,6 +23,8 @@ class TransportWriter(Protocol): def write(self, data: bytes) -> None: """ Write data to the transport. + + Raises TransportEOFError if the transport is closed. """ @@ -35,6 +37,8 @@ class TransportReader(Protocol): def read(self, n: int) -> bytes: """ Read up to n bytes from the transport. + + Raises TransportEOFError if the end of the transport is reached. """ diff --git a/api/core/virtual_environment/providers/docker_daemon_sandbox.py b/api/core/virtual_environment/providers/docker_daemon_sandbox.py index 9d57aa6148..bb9b626756 100644 --- a/api/core/virtual_environment/providers/docker_daemon_sandbox.py +++ b/api/core/virtual_environment/providers/docker_daemon_sandbox.py @@ -50,8 +50,16 @@ pid, transport_stdout, transport_stderr, transport_stdin = environment.execute_c print(f"Executed command with PID: {pid}") # consume stdout -output = transport_stdout.read(1024) -print(f"Command output: {output.decode().strip()}") +# consume stdout +while True: + try: + output = transport_stdout.read(1024) + except TransportEOFError: + logger.info("End of stdout reached") + break + + logger.info("Command output: %s", output.decode().strip()) + environment.release_connection(connection_handle) environment.release_environment() diff --git a/api/core/virtual_environment/providers/e2b_sandbox.py b/api/core/virtual_environment/providers/e2b_sandbox.py index 1f84d79fbf..86fa58fb68 100644 --- a/api/core/virtual_environment/providers/e2b_sandbox.py +++ b/api/core/virtual_environment/providers/e2b_sandbox.py @@ -1,4 +1,5 @@ import os +import shlex import threading from collections.abc import Mapping, Sequence from enum import StrEnum @@ -50,8 +51,16 @@ pid, transport_stdin, transport_stdout, transport_stderr = environment.execute_c logger.info("Executed command with PID: %s", pid) # consume stdout -output = transport_stdout.read(1024) -logger.info("Command output: %s", output.decode().strip()) +# consume stdout +while True: + try: + output = transport_stdout.read(1024) + except TransportEOFError: + logger.info("End of stdout reached") + break + + logger.info("Command output: %s", output.decode().strip()) + environment.release_connection(connection_handle) environment.release_environment() @@ -204,17 +213,19 @@ class E2BEnvironment(VirtualEnvironment): """ """ stdout_stream_write_handler = stdout_stream.get_write_handler() stderr_stream_write_handler = stderr_stream.get_write_handler() - sandbox.commands.run( - cmd=" ".join(command), - envs=dict(environments or {}), - # stdin=True, - on_stdout=lambda data: stdout_stream_write_handler.write(data.encode()), - on_stderr=lambda data: stderr_stream_write_handler.write(data.encode()), - ) - # Close the write handlers to signal EOF - stdout_stream.close() - stderr_stream.close() + try: + sandbox.commands.run( + cmd=shlex.join(command), + envs=dict(environments or {}), + # stdin=True, + on_stdout=lambda data: stdout_stream_write_handler.write(data.encode()), + on_stderr=lambda data: stderr_stream_write_handler.write(data.encode()), + ) + finally: + # Close the write handlers to signal EOF + stdout_stream.close() + stderr_stream.close() @cached_property def api_key(self) -> str: diff --git a/api/core/virtual_environment/providers/local_without_isolation.py b/api/core/virtual_environment/providers/local_without_isolation.py index 68ffc6978c..c8641e64c7 100644 --- a/api/core/virtual_environment/providers/local_without_isolation.py +++ b/api/core/virtual_environment/providers/local_without_isolation.py @@ -14,6 +14,48 @@ from core.virtual_environment.__base.virtual_environment import VirtualEnvironme from core.virtual_environment.channel.pipe_transport import PipeReadCloser, PipeWriteCloser from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser +""" +USAGE: + +import logging +from collections.abc import Mapping +from typing import Any + +from core.virtual_environment.channel.exec import TransportEOFError +from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment + +options: Mapping[str, Any] = {} + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.DEBUG) + +environment = LocalVirtualEnvironment(options=options) + +connection_handle = environment.establish_connection() + +pid, transport_stdin, transport_stdout, transport_stderr = environment.execute_command( + connection_handle, + ["sh", "-lc", "for i in 1 2 3 4 5; do date '+%F %T'; sleep 1; done"], +) + +logger.info("Executed command with PID: %s", pid) + +# consume stdout +while True: + try: + output = transport_stdout.read(1024) + except TransportEOFError: + logger.info("End of stdout reached") + break + + logger.info("Command output: %s", output.decode().strip()) + + +environment.release_connection(connection_handle) +environment.release_environment() + +""" + class LocalVirtualEnvironment(VirtualEnvironment): """