TensorRT-LLMs/tensorrt_llm/bench/benchmark/utils/processes.py
Frank 8bb3eea285
perf: Readd iteration logging for trtllm-bench. (#3039)
Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com>
2025-04-01 08:13:09 +08:00

174 lines
7.4 KiB
Python

import os
import tempfile
from contextlib import contextmanager
from multiprocessing import Event, Process
from multiprocessing.synchronize import Event as MpEvent
from pathlib import Path
from typing import Iterator, Optional, Union

from zmq import PULL, Context

from tensorrt_llm import logger
# The IterationWriter class implements a multi-process logging system that captures and writes
# iteration data to a specified file using ZeroMQ (ZMQ) for inter-process communication.
# It uses a producer-consumer pattern: producers (any process with a ZMQ PUSH socket connected
# to the writer's address) send messages, and a separate logging process consumes them and
# writes them to the file.
class IterationWriter:
    """Write iteration data to a file from a dedicated logging process.

    Logging runs in a separate process so that file I/O does not block the
    caller. Producers send JSON messages over a ZeroMQ PUSH socket connected
    to ``full_address``. The logging process binds a ZMQ PULL socket to that
    address and writes each received message to ``log_path``, one per line.

    Attributes:
        log_path (Optional[Path]): File that receives the iteration data. If
            None, logging is disabled and ``capture()`` is a no-op.

    Usage:
        writer = IterationWriter(Path("iterations.log"))
        with writer.capture():
            # Send JSON messages with a ZMQ PUSH socket connected to
            # writer.full_address; they are written to iterations.log.
            ...
    """

    def __init__(self, log_path: Optional[Path] = None) -> None:
        """Configure the writer; the logging process is started by ``capture()``.

        Args:
            log_path (Optional[Path]): Path where iteration data will be
                logged. If None, logging is disabled and ``capture()`` is a
                no-op.
        """
        self.log_path = log_path
        self._socket_path: Optional[Path] = None
        if log_path is not None:
            # mkstemp is used only to reserve a unique filesystem path for the
            # IPC endpoint. Close the descriptor immediately so it is not leaked.
            fd, name = tempfile.mkstemp()
            os.close(fd)
            self._socket_path = Path(name)

    @property
    def full_address(self) -> Union[str, None]:
        """Return the ZMQ IPC endpoint shared by producers and the logger.

        The logging process binds its PULL socket here; producers connect
        PUSH sockets to the same address.

        Returns:
            Union[str, None]: ``"ipc://<socket path>"`` (for example
            ``"ipc:///tmp/tmpab12cd34"``) when logging is enabled, otherwise
            None to indicate logging is disabled.
        """
        if self._socket_path is None:
            return None
        return f"ipc://{self._socket_path}"

    @contextmanager
    def capture(self) -> Iterator[None]:
        """Run the logging process for the duration of the ``with`` block.

        If logging is enabled (``log_path`` is set), a process executing
        ``run()`` is started before the block body runs. On exit, even on
        error, the process is signalled through a multiprocessing event and
        joined. If logging is disabled, this is a no-op context manager.

        Yields:
            None: The context manager provides no value to the caller.

        Example:
            writer = IterationWriter(log_path=Path("log.txt"))
            with writer.capture():
                # Send data to writer.full_address using a ZMQ PUSH socket;
                # it is written to log.txt by the separate logging process.
                ...
        """
        if self._socket_path is None:
            logger.info("No log path provided, skipping logging.")
            yield
            return

        logger.info(f"Logging iterations to {self.log_path}...")
        stop = Event()
        process = Process(name="IterationWriter",
                          target=self.run,
                          args=(self.full_address, self.log_path, stop))
        process.start()
        try:
            yield
        finally:
            # run() checks this event whenever its socket has been idle for a
            # poll interval, so join() returns once producers go quiet (or an
            # end marker arrives) instead of blocking forever in recv.
            stop.set()
            process.join()

    def __del__(self) -> None:
        """Remove the IPC socket path reserved in ``__init__``, if any."""
        # getattr: __del__ can run on an instance whose __init__ never finished.
        socket_path = getattr(self, "_socket_path", None)
        if socket_path is not None:
            # The file may already be gone (ZMQ may unlink it when the bound
            # socket is closed). An exception from __del__ would only be
            # printed as an "Exception ignored" warning, so tolerate it.
            socket_path.unlink(missing_ok=True)

    @staticmethod
    def _is_end_marker(message: object) -> bool:
        """Return True if ``message`` tells the logging process to stop.

        A JSON ``null`` (None) or any message containing ``"end"`` ends
        logging. Here "containing" means a dict key, a list element, or a
        substring of a string. Payloads that do not support ``in`` (numbers,
        booleans) never end logging.
        """
        if message is None:
            return True
        try:
            # NOTE(review): for string payloads this is a substring test, so a
            # string containing "end" (e.g. "pending") also stops logging.
            # Kept for compatibility with existing producers.
            return "end" in message
        except TypeError:
            return False

    @staticmethod
    def _drain(socket, stop_event: MpEvent, poll_timeout_ms: int) -> None:
        """Receive and discard messages until the producer reports it is done.

        Presumably this keeps producers from blocking on a socket nobody reads
        after an interrupt. Returns on an end marker, on the string
        ``"None"``, or when ``stop_event`` is set and no message arrived
        within ``poll_timeout_ms``.
        """
        while True:
            if socket.poll(poll_timeout_ms):
                message = socket.recv_json()
                if message == "None" or IterationWriter._is_end_marker(message):
                    return
            elif stop_event.is_set():
                return

    @staticmethod
    def run(address: str,
            log_path: Path,
            stop_event: MpEvent,
            poll_timeout_ms: int = 100) -> None:
        """Receive iteration data over ZMQ and write it to ``log_path``.

        Intended to run in a separate process (see ``capture()``). It:
        1. Binds a ZMQ PULL socket to ``address``.
        2. Opens ``log_path`` for writing (truncating it).
        3. Writes every received JSON message to the file, one per line.
        4. On KeyboardInterrupt, drains remaining messages without writing.
        5. Closes the ZMQ socket and context on exit.

        Logging stops when any of these happens:
        - an end marker is received (JSON ``null`` or a message containing
          ``"end"``);
        - ``stop_event`` is set and no message arrived within
          ``poll_timeout_ms``, so already-queued messages are still written;
        - a KeyboardInterrupt occurs.

        Args:
            address (str): The ZMQ IPC address to bind for receiving messages.
            log_path (Path): The file path where received messages are written.
            stop_event (MpEvent): Multiprocessing event used to signal shutdown.
            poll_timeout_ms (int): Milliseconds to wait for a message before
                re-checking ``stop_event``. Bounds shutdown latency.
        """
        context = None
        socket = None
        try:
            # Create a ZeroMQ context and socket for inter-process communication.
            logger.debug(f"Iteration logging: Binding to {address}...")
            context = Context(io_threads=1)
            socket = context.socket(PULL)
            socket.bind(address)

            logger.debug(
                f"Iteration logging: Listening for messages on {address}...")
            with open(log_path, "w", encoding="utf-8") as f:
                logger.info(f"Iteration logging: Opened log file {log_path}...")
                while True:
                    # Poll instead of blocking in recv_json() so a stop request
                    # is noticed even if no producer ever sends an end marker.
                    if not socket.poll(poll_timeout_ms):
                        if stop_event.is_set():
                            logger.debug("Iteration logging: Stop requested "
                                         "and no pending messages")
                            break
                        continue
                    message = socket.recv_json()
                    if IterationWriter._is_end_marker(message):
                        logger.debug("Iteration logging: Received end message")
                        break
                    f.write(f"{message}\n")
        except KeyboardInterrupt:
            logger.info("Keyboard interrupt, exiting iteration logging...")
            # The socket may not exist yet if the interrupt arrived very early.
            if socket is not None:
                IterationWriter._drain(socket, stop_event, poll_timeout_ms)
        finally:
            # Close the socket and terminate the context.
            logger.info("Finalizing iteration logging...")
            if socket is not None:
                socket.close()
            if context is not None:
                context.term()
            logger.debug("Iteration logging exiting.")