import atexit
import faulthandler
import multiprocessing
import platform
import signal
import traceback
from abc import ABC, abstractmethod
from collections.abc import Mapping
from pathlib import Path
from queue import Queue
from typing import (TYPE_CHECKING, AsyncIterable, Dict, Generator, List,
                    Optional, Union)

import numpy as np
import torch

from tensorrt_llm.inputs.multimodal import MultimodalParams
from tensorrt_llm.logger import logger, set_level

from .._utils import mpi_world_size
from ..bindings import executor as tllm
from ..builder import Engine
from ..disaggregated_params import DisaggregatedParams
from ..llmapi.llm_args import BaseLlmArgs, TorchLlmArgs
from ..llmapi.llm_utils import KvCacheRetentionConfig
from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available,
                                  need_spawn_mpi_workers)
from ..llmapi.tokenizer import TokenizerBase
from ..llmapi.utils import (AsyncQueue, enable_llm_debug,
                            enable_worker_single_process_for_tp1, logger_debug,
                            print_colored)
from ..sampling_params import (BatchedLogitsProcessor, LogprobParams,
                               SamplingParams)
from ..scheduling_params import SchedulingParams
from .ipc import FusedIpcQueue
from .postproc_worker import PostprocParams, PostprocWorkerConfig
from .request import GenerationRequest, LoRARequest, PromptAdapterRequest
from .result import GenerationResult, IterationResult
from .utils import IntraProcessQueue, ProcessPoolExecutorSession, RequestError

if TYPE_CHECKING:
    from .proxy import GenerationExecutorProxy
    from .worker import GenerationExecutorWorker

__all__ = [
    "GenerationExecutor",
    "CppExecutorError",
]

if enable_llm_debug():
    # Mainly enable more detailed logging from cpp runtime.
    set_level("info")


async def empty_async_iterable() -> AsyncIterable:
    if False:  # ensures the function remains an async generator
        yield


class CppExecutorError(RuntimeError):

    def __init__(self, message: Optional[str] = None):
        self.message = message
        self.stack_trace = traceback.format_exc()
        super().__init__(message)

    def __str__(self):
        return f"{self.message}\nStack trace:\n{self.stack_trace}"


class IterationResultQueue:
    is_initialized: bool = False
    # FusedIpcQueue or IntraProcessQueue is used to communicate results from workers to proxy
    queue: Optional[Union[Queue, FusedIpcQueue, IntraProcessQueue]] = None
    aqueue: Optional[AsyncQueue] = None


class GenerationExecutor(ABC):

    def __init__(self,
                 num_postprocess_workers: int = 0,
                 postprocess_tokenizer_dir: Optional[str] = None,
                 is_llm_executor: Optional[bool] = None):
        self.postproc_config = PostprocWorkerConfig(
            num_postprocess_workers=num_postprocess_workers,
            postprocess_tokenizer_dir=postprocess_tokenizer_dir)

        self.kv_events_queues = IterationResultQueue()
        self.stats_queues = IterationResultQueue()

        atexit.register(self.shutdown)

        # This is used to capture the exceptions from the threads.
        self._error_queue = Queue()

        # A flag to avoid calling shutdown() recursively. This happens when the background threads raise errors.
        self.doing_shutdown = False

        self._last_client_id: int = 1

        # Whether this is the executor instance used by the LLM API.
        self._is_llm_executor = is_llm_executor
        self._iter_kv_events_result: IterationResult | None = None
        self._iter_stats_result: IterationResult | None = None

    @abstractmethod
    def submit(self, request: GenerationRequest) -> GenerationResult:
        pass

    @abstractmethod
    def abort_request(self, request_id: int) -> None:
        pass

    def generate_async(
        self,
        prompt_token_ids: List[int],
        sampling_params: SamplingParams,
        query_token_ids: Optional[Union[torch.Tensor, np.ndarray, list]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        streaming: bool = False,
        kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None,
        disaggregated_params: Optional[DisaggregatedParams] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        postproc_params: Optional[PostprocParams] = None,
        multimodal_params: Optional[MultimodalParams] = None,
        scheduling_params: Optional[SchedulingParams] = None,
        cache_salt_id: Optional[int] = None,
        arrival_time: Optional[float] = None,
    ) -> GenerationResult:
        """Generate output for the given prompt token ids in asynchronous mode.

        Asynchronous generation accepts a single prompt only.
        """
        assert isinstance(prompt_token_ids[0], int)
        assert isinstance(sampling_params, SamplingParams)

        self._maybe_initialize_iteration_results()

        if postproc_params:
            postproc_params.postproc_args.num_prompt_tokens = len(
                prompt_token_ids)
        request = GenerationRequest(
            prompt_token_ids,
            sampling_params=sampling_params,
            postproc_params=postproc_params,
            query_token_ids=query_token_ids,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            streaming=streaming,
            kv_cache_retention_config=kv_cache_retention_config,
            disaggregated_params=disaggregated_params,
            trace_headers=trace_headers,
            multimodal_params=multimodal_params,
            scheduling_params=scheduling_params,
            cache_salt_id=cache_salt_id,
            arrival_time=arrival_time)
        result = self.submit(request)
        # Release multimodal data promptly to free memory.
        if hasattr(request, "multimodal_params"):
            del request.multimodal_params
        return result
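
    # Illustrative sketch (comments only, not executed): single-prompt asynchronous
    # generation. It assumes an `executor` obtained elsewhere (e.g. via
    # GenerationExecutor.create) and that SamplingParams(max_tokens=...) is a valid
    # construction; the token ids below are placeholders.
    #
    #   sampling_params = SamplingParams(max_tokens=32)
    #   future = executor.generate_async([101, 2023, 2003],
    #                                    sampling_params=sampling_params,
    #                                    streaming=False)
    #   result = future.result()  # blocks until this prompt finishes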

    def generate(
        self,
        prompt_token_ids: Union[List[int], List[List[int]]],
        sampling_params: Union[SamplingParams, List[SamplingParams]],
        query_token_ids: Optional[Union[torch.Tensor, np.ndarray, list]] = None,
        lora_request: Optional[Union[LoRARequest, List[LoRARequest]]] = None,
        prompt_adapter_request: Optional[Union[
            PromptAdapterRequest, List[PromptAdapterRequest]]] = None,
        disaggregated_params: Optional[DisaggregatedParams] = None,
    ) -> Union[GenerationResult, List[GenerationResult]]:
        """Generate output for the given prompt token ids in synchronous mode.

        Synchronous generation accepts either a single prompt or batched prompts.
        """
        unbatched = isinstance(prompt_token_ids[0], int)

        if unbatched:
            prompt_token_ids = [prompt_token_ids]
            if query_token_ids:
                query_token_ids = [query_token_ids]

        futures = []
        for i, p in enumerate(prompt_token_ids):
            if isinstance(sampling_params, list):
                sp = sampling_params[i]
            else:
                sp = sampling_params
            if isinstance(lora_request, list):
                lora_req = lora_request[i]
            else:
                lora_req = lora_request
            if isinstance(prompt_adapter_request, list):
                pa_req = prompt_adapter_request[i]
            else:
                pa_req = prompt_adapter_request
            future = self.generate_async(
                p,
                sampling_params=sp,
                query_token_ids=query_token_ids,
                lora_request=lora_req,
                prompt_adapter_request=pa_req,
                streaming=False,
                disaggregated_params=disaggregated_params)
            futures.append(future)

        for future in futures:
            future.result()

        if unbatched:
            futures = futures[0]

        return futures
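
    # Illustrative sketch (comments only, not executed): synchronous batched
    # generation. `generate` fans the batch out through `generate_async` and waits
    # for every result; `executor` and the placeholder prompts are assumed as above.
    #
    #   prompts = [[101, 2023], [101, 2003, 1037]]
    #   results = executor.generate(prompts,
    #                               sampling_params=SamplingParams(max_tokens=16))
    #   # `results` is a list of completed GenerationResult objects, one per prompt.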

    def _get_next_client_id(self):
        # Wrap around at the uint64 boundary: (self._last_client_id + 1) % 2**64.
        self._last_client_id = (self._last_client_id + 1) & ((1 << 64) - 1)
        return self._last_client_id

    def _get_logprob_params(
            self, request: GenerationRequest) -> Optional[LogprobParams]:
        """Collect logprob-related fields from the request for the later logprob calculation."""
        logprob_params = None
        if request.sampling_params.logprobs or request.sampling_params.prompt_logprobs:
            logprob_params = LogprobParams(
                logprobs=request.sampling_params.logprobs,
                prompt_logprobs=request.sampling_params.prompt_logprobs,
                # Drop logits if the user didn't explicitly ask for them, or if the postprocess-worker flow is used.
                drop_context_logits=(
                    not request.sampling_params._need_return_context_logits)
                or self.postproc_config.num_postprocess_workers > 0,
                drop_generation_logits=(
                    not request.sampling_params._need_return_generation_logits)
                or self.postproc_config.num_postprocess_workers > 0)

        return logprob_params

    def _maybe_initialize_iteration_results(self):
        if self._is_llm_executor:
            if self._iter_stats_result is None:
                # singleton to store cpp runtime stats
                self._iter_stats_result = IterationResult()
            else:
                # expect more engine stats whenever new prompts are submitted
                self._iter_stats_result.mark_undone()

            if self._iter_kv_events_result is None:
                self._iter_kv_events_result = IterationResult()
            else:
                self._iter_kv_events_result.mark_undone()

    def _handle_background_error(self, error: Optional[Exception | str] = None):
        """Process the errors from the threads or processes.

        NOTE: This should be called in the main thread.
        """
        if error is not None:
            # For details please refer to the comment of `GenerationResult.error`.

            if isinstance(error, RequestError):
                # A per-request error that can be captured and ignored by the caller.
                if enable_llm_debug():
                    print_colored(f"Got per-request error: {repr(error)}\n",
                                  "red")
            elif isinstance(error, str):
                # A per-request error that can be captured and ignored by the caller.
                if enable_llm_debug():
                    print_colored(f"Got per-request error: {repr(error)}\n",
                                  "red")
                    print_colored(str(traceback.extract_stack()) + "\n", "red")
                error = RequestError(error)
            else:
                # A serious error from a background thread or process.
                if not isinstance(error, BaseException):
                    error = RuntimeError(repr(error))
                if enable_llm_debug():
                    print_colored(
                        f"Got background error: {repr(error)}, will shutdown the LLM instance\n",
                        "red")
                self.shutdown()
            raise error

        # Raise the first error in the queue. This method is called repeatedly, so the
        # user can choose to catch more than one error.
        if not self._error_queue.empty():
            e = self._error_queue.get()
            self._error_queue.task_done()
            self.shutdown()
            # The caller may catch this exception.
            raise e

    def is_shutdown(self) -> bool:
        return self.doing_shutdown

    @abstractmethod
    def shutdown(self):
        pass

    @property
    def enable_postprocess_parallel(self) -> bool:
        return self.postproc_config.enabled

    def get_stats(self, timeout: float) -> List[dict]:
        """
        Get iteration statistics from the runtime.

        Args:
            timeout (float): Max wait time in seconds when retrieving stats from the queue.

        Returns:
            List[dict]: A list of runtime stats as dicts.
        """
        if self._iter_stats_result is None:
            print_colored(
                "Iteration statistics are not available yet. To collect runtime statistics, please call get_stats() AFTER prompts have been submitted.\n",
                "yellow")
            return []

        self._iter_stats_result.set_timeout(timeout)
        return self._iter_stats_result.get_results()

    def aget_stats(self, timeout: float) -> IterationResult:
        """
        Get iteration statistics from the runtime asynchronously.

        Args:
            timeout (float): Max wait time in seconds when retrieving stats from the queue.

        Returns:
            IterationResult: An async iterable object containing runtime stats.
        """
        if self._iter_stats_result is None:
            print_colored(
                "Iteration statistics are not available yet. To collect runtime statistics, please call get_stats_async() in an async coroutine or the /metrics endpoint (if you're using trtllm-serve) AFTER prompts have been submitted.\n",
                "yellow")
            return empty_async_iterable()

        self._iter_stats_result.set_timeout(timeout)
        return self._iter_stats_result

    def get_kv_events(self, timeout: float) -> List[dict]:
        """
        Get iteration KV events from the runtime.

        Args:
            timeout (float): Max wait time in seconds when retrieving events from the queue.

        Returns:
            List[dict]: A list of runtime events as dicts.
        """
        assert self._iter_kv_events_result is not None, "KV Event IterationResult is not properly instantiated."

        self._iter_kv_events_result.set_timeout(timeout)
        return self._iter_kv_events_result.get_results()

    def aget_kv_events(self, timeout: Optional[float] = None) -> IterationResult:
        """
        Get iteration KV events from the runtime asynchronously.

        Args:
            timeout (float): Max wait time in seconds when retrieving events from the queue.

        Returns:
            IterationResult: An async iterable object containing runtime events.
        """
        assert self._iter_kv_events_result is not None, "KV Event IterationResult is not properly instantiated."

        self._iter_kv_events_result.set_timeout(timeout)
        return self._iter_kv_events_result
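
    # Illustrative sketch (comments only, not executed): polling runtime statistics
    # and KV cache events after prompts have been submitted. The iteration results
    # are only initialized when this executor serves the LLM API
    # (is_llm_executor=True); otherwise get_stats() returns an empty list and
    # get_kv_events() asserts.
    #
    #   _ = executor.generate([[101, 2023]],
    #                         sampling_params=SamplingParams(max_tokens=8))
    #   stats = executor.get_stats(timeout=2)          # list of per-iteration dicts
    #   kv_events = executor.get_kv_events(timeout=2)  # list of KV cache event dicts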

    @staticmethod
    def _create_ray_executor(
            worker_kwargs: Dict,
            model_world_size: int,
            postproc_worker_config: PostprocWorkerConfig,
            is_llm_executor: bool,
            tp_size: int,
    ):
        logger.warning("Orchestrator is creating Ray executor")
        from .ray_executor import RayExecutor

        return RayExecutor(worker_kwargs,
                           model_world_size=model_world_size,
                           postproc_worker_config=postproc_worker_config,
                           is_llm_executor=is_llm_executor,
                           tp_size=tp_size)

    @staticmethod
    def _create_rpc_executor(
            worker_kwargs: Dict,
            model_world_size: int,
            mpi_session: Optional[MpiSession],
            postproc_worker_config: PostprocWorkerConfig,
            is_llm_executor: bool,
    ):
        """Create RPC-based executor (GenerationExecutorRpcProxy)."""
        from .rpc_proxy import GenerationExecutorRpcProxy
        logger.warning("Orchestrator is creating RPC executor")
        return GenerationExecutorRpcProxy(
            worker_kwargs,
            model_world_size=model_world_size,
            mpi_session=mpi_session,
            postproc_worker_config=postproc_worker_config,
            is_llm_executor=is_llm_executor)

    @staticmethod
    def _create_ipc_executor(
            worker_kwargs: Dict,
            model_world_size: int,
            mpi_session: Optional[MpiSession],
            postproc_worker_config: PostprocWorkerConfig,
            is_llm_executor: bool,
            use_worker: bool = False,
    ):
        """Create IPC-based executor (GenerationExecutorProxy or GenerationExecutorWorker).

        Args:
            use_worker: If True, creates GenerationExecutorWorker (single process).
                If False, creates GenerationExecutorProxy (multi-process with IPC).
        """
        logger.warning("Orchestrator is creating IPC executor")
        if use_worker:
            from .worker import GenerationExecutorWorker
            return GenerationExecutorWorker(**worker_kwargs,
                                            is_llm_executor=is_llm_executor)
        else:
            from .proxy import GenerationExecutorProxy
            return GenerationExecutorProxy(
                worker_kwargs,
                model_world_size=model_world_size,
                mpi_session=mpi_session,
                postproc_worker_config=postproc_worker_config,
                is_llm_executor=is_llm_executor)

    @staticmethod
    def create(
        engine: Union[Path, Engine],
        executor_config: Optional[tllm.ExecutorConfig] = None,
        batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
        model_world_size: int = 1,
        world_size: int = 0,
        mpi_session: Optional[MpiSession] = None,
        reuse_mpi_comm: bool = False,
        return_logits: bool = False,
        postproc_worker_config: Optional[PostprocWorkerConfig] = None,
        is_llm_executor: Optional[bool] = None,
        hf_model_dir: Optional[Path] = None,
        tokenizer: Optional[TokenizerBase] = None,
        llm_args: Optional[BaseLlmArgs] = None,
        **args,
    ) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
        if world_size == 0:
            world_size = mpi_world_size()

        if world_size > 1 and world_size < model_world_size:
            raise RuntimeError(
                "Cannot instantiate Generator for engine built "
                f"for {model_world_size} ranks, while currently running "
                f"on {world_size} ranks.")

        postproc_worker_config = postproc_worker_config or PostprocWorkerConfig()

        if postproc_worker_config.enabled:
            logger_debug(
                f"Using {postproc_worker_config.num_postprocess_workers} postprocess parallel processes.\n",
                "green")

        worker_kwargs = {
            "engine": engine,
            "executor_config": executor_config,
            "batched_logits_processor": batched_logits_processor,
            "hf_model_dir": hf_model_dir,
            "tokenizer": tokenizer,
            "llm_args": llm_args,
        }

        orchestrator_type = None if not isinstance(
            llm_args, TorchLlmArgs) else llm_args.orchestrator_type
        if orchestrator_type == "ray":
            if llm_args and hasattr(llm_args, 'ray_worker_extension_cls'):
                worker_kwargs[
                    "ray_worker_extension_cls"] = llm_args.ray_worker_extension_cls
            return GenerationExecutor._create_ray_executor(
                worker_kwargs,
                model_world_size,
                postproc_worker_config,
                is_llm_executor=is_llm_executor,
                tp_size=args.get("tp_size", 1))
        elif orchestrator_type is not None and orchestrator_type != "rpc":
            raise ValueError(
                f"Unsupported orchestrator_type: {orchestrator_type}")

        # The case where the Python main process is launched by mpirun.
        mpirun_launch = external_mpi_comm_available(model_world_size)
        # The case where the Python main process uses mpi4py to spawn MPI workers.
        spawn_workers = need_spawn_mpi_workers(model_world_size)
        orchestrator_is_rpc = llm_args and llm_args.orchestrator_type == "rpc"

        if spawn_workers or (mpirun_launch and reuse_mpi_comm):
            if reuse_mpi_comm:
                assert mpi_session is not None, "reuse_mpi_comm requires an external MPI session"

            if orchestrator_is_rpc:
                return GenerationExecutor._create_rpc_executor(
                    worker_kwargs,
                    model_world_size=model_world_size,
                    mpi_session=mpi_session,
                    postproc_worker_config=postproc_worker_config,
                    is_llm_executor=is_llm_executor)

            return GenerationExecutor._create_ipc_executor(
                worker_kwargs,
                model_world_size=model_world_size,
                mpi_session=mpi_session,
                postproc_worker_config=postproc_worker_config,
                is_llm_executor=is_llm_executor,
                use_worker=False)

        # WAR: For the performance of gathering logits, use a single-process worker
        # for TP1 to avoid the large overhead of IPC.
        # WAR: Developers can also enable this manually, which makes TP1 debugging
        # easier. We will introduce a better solution in the future.
        if return_logits or enable_worker_single_process_for_tp1():
            logger.warning(
                "Using single process worker for TP1, this may hurt streaming generation performance."
            )
            if orchestrator_is_rpc:
                return GenerationExecutor._create_rpc_executor(
                    worker_kwargs,
                    model_world_size=model_world_size,
                    mpi_session=mpi_session,
                    postproc_worker_config=postproc_worker_config,
                    is_llm_executor=is_llm_executor)

            return GenerationExecutor._create_ipc_executor(
                worker_kwargs,
                model_world_size=model_world_size,
                mpi_session=mpi_session,
                postproc_worker_config=postproc_worker_config,
                is_llm_executor=is_llm_executor,
                use_worker=True)

        # For the single-GPU case:
        # Partition the workload across multiple processes for streaming performance.
        # Note that this requires users to protect their entrypoint with
        # `if __name__ == "__main__":`.
        if not platform.system() == 'Windows':
            if orchestrator_is_rpc:
                return GenerationExecutor._create_rpc_executor(
                    worker_kwargs,
                    model_world_size=model_world_size,
                    mpi_session=mpi_session,
                    postproc_worker_config=postproc_worker_config,
                    is_llm_executor=is_llm_executor)

            return GenerationExecutor._create_ipc_executor(
                worker_kwargs,
                model_world_size=model_world_size,
                mpi_session=None,  # use mpi4py
                postproc_worker_config=postproc_worker_config,
                is_llm_executor=is_llm_executor,
                use_worker=False)
        else:
            ctx = multiprocessing.get_context("spawn")
            # ProcessPoolExecutorSession is used to support Windows, where mpi4py cannot spawn workers.
            mpi_session = ProcessPoolExecutorSession(n_workers=1,
                                                     mp_context=ctx)
            # TODO: add rpc worker here
            return GenerationExecutor._create_ipc_executor(
                worker_kwargs,
                model_world_size=model_world_size,
                mpi_session=mpi_session,
                postproc_worker_config=postproc_worker_config,
                is_llm_executor=is_llm_executor,
                use_worker=False)
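
    # Illustrative sketch (comments only, not executed): a typical way to drive
    # `create`. The engine path and sizes are hypothetical placeholders; with
    # TorchLlmArgs whose orchestrator_type is "ray" or "rpc", the corresponding
    # executor is selected, otherwise the MPI/IPC paths above apply.
    #
    #   executor = GenerationExecutor.create(
    #       engine=Path("/path/to/engine_dir"),  # hypothetical path
    #       model_world_size=1,
    #       is_llm_executor=True)
    #   future = executor.generate_async(
    #       [101, 2023], sampling_params=SamplingParams(max_tokens=8))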

    def wait_first_completed(
            self, futures: List[GenerationResult]
    ) -> Generator[GenerationResult, None, None]:
        wait_set = set(futures)

        # Clear already-finished requests.
        for f in futures:
            if f._done:
                wait_set.remove(f)
                yield f

        # Wait for the remaining active requests.
        while len(wait_set) > 0:
            fut = wait_set.pop()
            if fut.request_id not in self._results:
                yield fut
            else:
                wait_set.add(fut)
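
    # Illustrative sketch (comments only, not executed): consuming results as they
    # finish with `wait_first_completed`. `executor` and `prompts` are assumed as in
    # the earlier sketches; each yielded future is expected to be complete, so
    # calling .result() on it should not block for long.
    #
    #   futures = [
    #       executor.generate_async(p, sampling_params=SamplingParams(max_tokens=8))
    #       for p in prompts
    #   ]
    #   for done in executor.wait_first_completed(futures):
    #       print(done.result())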


if enable_llm_debug():
    print_colored("LLM debug mode enabled.\n", "yellow")
    # Dump all live threads when the process is interrupted by SIGINT.
    faulthandler.register(signal.SIGINT, all_threads=True)