From 6776caaad126cf916bc191ed8d784c96b915402f Mon Sep 17 00:00:00 2001 From: shuyixiong <219646547+shuyixiong@users.noreply.github.com> Date: Tue, 14 Oct 2025 23:46:30 +0800 Subject: [PATCH] [TRTLLM-8507][fix] Fix ray resource cleanup and error handling in LoRA test (#8175) Signed-off-by: shuyix <219646547+shuyixiong@users.noreply.github.com> --- tensorrt_llm/_ray_utils.py | 28 +++++++ tensorrt_llm/executor/ray_executor.py | 99 +++++++++++++---------- tensorrt_llm/executor/result.py | 8 +- tensorrt_llm/ray_stub.py | 17 ++-- tests/unittest/llmapi/test_llm_pytorch.py | 1 - tests/unittest/utils/util.py | 6 +- 6 files changed, 106 insertions(+), 53 deletions(-) create mode 100644 tensorrt_llm/_ray_utils.py diff --git a/tensorrt_llm/_ray_utils.py b/tensorrt_llm/_ray_utils.py new file mode 100644 index 0000000000..8a3d41ee6d --- /dev/null +++ b/tensorrt_llm/_ray_utils.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from contextlib import contextmanager + +try: + import ray +except ImportError: + import tensorrt_llm.ray_stub as ray + + +@contextmanager +def unwrap_ray_errors(): + try: + yield + except ray.exceptions.RayTaskError as e: + raise e.as_instanceof_cause() from e diff --git a/tensorrt_llm/executor/ray_executor.py b/tensorrt_llm/executor/ray_executor.py index bbf4832ddd..ee232451e6 100644 --- a/tensorrt_llm/executor/ray_executor.py +++ b/tensorrt_llm/executor/ray_executor.py @@ -12,6 +12,7 @@ from ray.util.placement_group import (PlacementGroup, get_current_placement_group, placement_group) +from tensorrt_llm._ray_utils import unwrap_ray_errors from tensorrt_llm._utils import get_free_port from tensorrt_llm.logger import logger @@ -57,48 +58,54 @@ class RayExecutor(GenerationExecutor): "runtime_env": runtime_env } - if os.environ.get("TLLM_RAY_FORCE_LOCAL_CLUSTER", "0") != "1": - try: - ray.init(address="auto", **ray_init_args) - logger.info(f"Attached to an existing Ray cluster.") - except ConnectionError: - logger.info(f"Ray cluster not found, starting a new one.") + try: + if os.environ.get("TLLM_RAY_FORCE_LOCAL_CLUSTER", "0") != "1": + try: + ray.init(address="auto", **ray_init_args) + logger.info(f"Attached to an existing Ray cluster.") + except ConnectionError: + logger.info(f"Ray cluster not found, starting a new one.") - if not ray.is_initialized(): - ray.init(**ray_init_args) + if not ray.is_initialized(): + ray.init(**ray_init_args) + self.has_start_local_cluser = True + else: + ray.init(address="local", **ray_init_args) self.has_start_local_cluser = True - else: - ray.init(address="local", **ray_init_args) - self.has_start_local_cluser = True - self.world_size = model_world_size - self.tp_size = tp_size - self.master_address = ray.util.get_node_ip_address() - self.master_port = get_free_port() + self.world_size = model_world_size + self.tp_size = tp_size + self.master_address = ray.util.get_node_ip_address() + 
self.master_port = get_free_port() - self.response_queue = RayAsyncQueue.options(runtime_env={ - "env_vars": { - "TLLM_DISABLE_MPI": "1" - } - }).remote() - self.response_sync_queue = RaySyncQueue.options(runtime_env={ - "env_vars": { - "TLLM_DISABLE_MPI": "1" - } - }).remote() - self.async_response_queue_weakref = self.create_actor_weak_ref( - self.response_queue) - self.sync_response_queue_weakref = self.create_actor_weak_ref( - self.response_sync_queue) - self.response_queue.warmup.remote() - self.response_sync_queue.warmup.remote() + self.response_queue = RayAsyncQueue.options(runtime_env={ + "env_vars": { + "TLLM_DISABLE_MPI": "1" + } + }).remote() + self.response_sync_queue = RaySyncQueue.options(runtime_env={ + "env_vars": { + "TLLM_DISABLE_MPI": "1" + } + }).remote() + self.async_response_queue_weakref = self.create_actor_weak_ref( + self.response_queue) + self.sync_response_queue_weakref = self.create_actor_weak_ref( + self.response_sync_queue) + self.response_queue.warmup.remote() + self.response_sync_queue.warmup.remote() - worker_kwargs = dict(**worker_kwargs, - postproc_worker_config=postproc_worker_config, - is_llm_executor=is_llm_executor, - kv_connector_config=kv_connector_config) + worker_kwargs = dict(**worker_kwargs, + postproc_worker_config=postproc_worker_config, + is_llm_executor=is_llm_executor, + kv_connector_config=kv_connector_config) - self.create_workers(RayGPUWorker, worker_kwargs) + self.create_workers(RayGPUWorker, worker_kwargs) + except Exception as e: + # Clean up the Ray resources early during exception + self.shutdown() + logger.error(f"Failed to initialize RayExecutor: {e}") + raise e @staticmethod def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle): @@ -137,12 +144,19 @@ class RayExecutor(GenerationExecutor): for rank in range(self.world_size) ] - ray.get([worker.__ray_ready__.remote() for worker in self.workers]) + try: + ray.get([worker.__ray_ready__.remote() for worker in self.workers]) + except ray.exceptions.ActorDiedError as e: + if "The actor died because of an error raised in its creation task" in str( + e): + raise RuntimeError( + "RayGPUWorker died during initialization") from e + raise + @unwrap_ray_errors() def call_all_ray_workers(self, func: str, leader_only: bool, async_call: bool, *args, **kwargs): workers = (self.workers[0], ) if leader_only else self.workers - if async_call: return [ getattr(worker, func).remote(*args, **kwargs) @@ -154,6 +168,7 @@ class RayExecutor(GenerationExecutor): for worker in workers ]) + @unwrap_ray_errors() def collective_rpc(self, method: str, args: tuple = (), @@ -174,7 +189,6 @@ class RayExecutor(GenerationExecutor): # Ray actor doesn't work with __getattr__ delegation. 
refs.append(w.call_worker_method.remote(method, *args, **kwargs)) - return refs if non_block else ray.get(refs) def submit(self, request: GenerationRequest) -> GenerationResult: @@ -224,11 +238,14 @@ class RayExecutor(GenerationExecutor): self.workers = None if hasattr(self, "placement_group") and self.placement_group is not None: - ray.util.remove_placement_group(self.placement_group) + # Only remove placement group if Ray is still initialized + # to avoid triggering auto_init_ray() during program exit + if ray.is_initialized(): + ray.util.remove_placement_group(self.placement_group) self.placement_group = None self.bundle_indices = None - if self.has_start_local_cluser: + if self.has_start_local_cluser and ray.is_initialized(): logger.debug("Shutting down Ray cluster") ray.shutdown() diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 92dc3bece7..9ec2870140 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -16,6 +16,7 @@ try: except ModuleNotFoundError: from tensorrt_llm import ray_stub as ray +from .._ray_utils import unwrap_ray_errors from .._utils import mpi_disabled, nvtx_range_debug from ..bindings import executor as tllm from ..disaggregated_params import DisaggregatedParams @@ -274,8 +275,8 @@ class GenerationResultBase: else: self.queue = ray_queue self.aqueue = None - - ray.get(self.queue.register.remote(id)) + with unwrap_ray_errors(): + ray.get(self.queue.register.remote(id)) else: if has_event_loop(): self.aqueue = AsyncQueue() @@ -735,7 +736,8 @@ class GenerationResult(GenerationResultBase): def _result_step(self, timeout: Optional[float] = None): if mpi_disabled(): - response = ray.get(self.queue.get.remote(self.request_id)) + with unwrap_ray_errors(): + response = ray.get(self.queue.get.remote(self.request_id)) response = self._handle_ray_response(response) else: response = self.queue.get() diff --git a/tensorrt_llm/ray_stub.py b/tensorrt_llm/ray_stub.py index fbff5ed306..9bd699d929 100644 --- a/tensorrt_llm/ray_stub.py +++ b/tensorrt_llm/ray_stub.py @@ -12,11 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import functools +from functools import wraps as _wraps -from tensorrt_llm._utils import mpi_disabled +from tensorrt_llm._utils import mpi_disabled as _mpi_disabled -if mpi_disabled(): +if _mpi_disabled(): raise RuntimeError( "Ray requested (TLLM_DISABLE_MPI=1), but not installed. Please install Ray." ) @@ -27,10 +27,11 @@ def remote(*args, **kwargs): def decorator(func): # Returns a function that always raises. # Decorated class depends on ray, but ray is not installed. - @functools.wraps(func) + @_wraps(func) def stub_checker(*_, **__): raise RuntimeError( - "Ray not installed, cannot use Ray based feature.") + f'Ray not installed, so the remote function / actor "{func.__name__}" is not available.' + ) return stub_checker @@ -38,3 +39,9 @@ def remote(*args, **kwargs): return decorator(args[0]) return decorator + + +def __getattr__(name): + raise RuntimeError( + f'Ray not installed, so "ray.{name}" is unavailable. Please install Ray.' 
+ ) diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index b8e9b0b15b..9e72698c08 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -414,7 +414,6 @@ def test_llama_7b_multi_lora_evict_and_reload_evicted_adapters_in_cpu_and_gpu_ca repeats_per_call=1) -@skip_ray @skip_gpu_memory_less_than_40gb def test_llama_7b_peft_cache_config_affects_peft_cache_size(): """Tests that LLM arg of peft_cache_config affects the peft cache sizes. diff --git a/tests/unittest/utils/util.py b/tests/unittest/utils/util.py index f4b1c68e82..babed7f101 100644 --- a/tests/unittest/utils/util.py +++ b/tests/unittest/utils/util.py @@ -19,7 +19,8 @@ except ImportError: from parameterized import parameterized import tensorrt_llm -from tensorrt_llm._utils import torch_dtype_to_trt, trt_dtype_to_torch +from tensorrt_llm._utils import (mpi_disabled, torch_dtype_to_trt, + trt_dtype_to_torch) from tensorrt_llm.llmapi.utils import get_total_gpu_memory from tensorrt_llm.plugin.plugin import ContextFMHAType from tensorrt_llm.quantization import QuantMode @@ -449,5 +450,4 @@ def check_accuracy(a, b, atol, rtol, percent): skip_ray = pytest.mark.skipif( - os.environ.get("TLLM_DISABLE_MPI") == "1", - reason="This test is skipped for Ray orchestrator.") + mpi_disabled(), reason="This test is skipped for Ray orchestrator.")
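
Note (editorial addition, not part of the patch): a minimal usage sketch of the new unwrap_ray_errors() context manager introduced in tensorrt_llm/_ray_utils.py. It shows how the helper re-raises the original exception type instead of the wrapped ray.exceptions.RayTaskError, which is what the patch relies on in ray_executor.py and result.py. The remote function boom and the explicit ray.init()/ray.shutdown() calls below are hypothetical scaffolding for illustration only.

# Sketch: surface the original exception raised inside a Ray task.
import ray

from tensorrt_llm._ray_utils import unwrap_ray_errors


@ray.remote
def boom():
    # Hypothetical failing task used only for this example.
    raise ValueError("original error")


ray.init()
try:
    with unwrap_ray_errors():
        # Without the context manager, ray.get() raises
        # ray.exceptions.RayTaskError wrapping the ValueError; with it,
        # the caller sees the cause via RayTaskError.as_instanceof_cause().
        ray.get(boom.remote())
except ValueError as e:
    print(f"caught original exception type: {e}")
finally:
    ray.shutdown()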