[TRTLLM-8507][fix] Fix ray resource cleanup and error handling in LoRA test (#8175)

Signed-off-by: shuyix <219646547+shuyixiong@users.noreply.github.com>
shuyixiong 2025-10-14 23:46:30 +08:00 committed by GitHub
parent 0d20a8fd61
commit 6776caaad1
6 changed files with 106 additions and 53 deletions

View File

@@ -0,0 +1,28 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager

try:
    import ray
except ImportError:
    import tensorrt_llm.ray_stub as ray


@contextmanager
def unwrap_ray_errors():
    try:
        yield
    except ray.exceptions.RayTaskError as e:
        raise e.as_instanceof_cause() from e
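
For reference, a minimal usage sketch of the new helper (assuming Ray is installed; `fails` is a hypothetical remote function): catching `RayTaskError` and re-raising `as_instanceof_cause()` surfaces the worker-side exception as its original type rather than only as the generic Ray wrapper.

```python
# Minimal usage sketch (assumes Ray is installed; `fails` is hypothetical).
import ray

from tensorrt_llm._ray_utils import unwrap_ray_errors


@ray.remote
def fails() -> None:
    raise ValueError("boom")


ray.init()
try:
    with unwrap_ray_errors():
        ray.get(fails.remote())
except ValueError as exc:
    # Re-raised via as_instanceof_cause(), so it is catchable as the
    # original ValueError instead of only as RayTaskError.
    print(f"caught original error: {exc}")
finally:
    ray.shutdown()
```

Because the helper is built with `contextlib.contextmanager`, it also works as a decorator (`@unwrap_ray_errors()`), which is how `call_all_ray_workers` and `collective_rpc` use it below.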

View File

@@ -12,6 +12,7 @@ from ray.util.placement_group import (PlacementGroup,
get_current_placement_group,
placement_group)
from tensorrt_llm._ray_utils import unwrap_ray_errors
from tensorrt_llm._utils import get_free_port
from tensorrt_llm.logger import logger
@@ -57,6 +58,7 @@ class RayExecutor(GenerationExecutor):
"runtime_env": runtime_env
}
try:
if os.environ.get("TLLM_RAY_FORCE_LOCAL_CLUSTER", "0") != "1":
try:
ray.init(address="auto", **ray_init_args)
@@ -99,6 +101,11 @@ class RayExecutor(GenerationExecutor):
kv_connector_config=kv_connector_config)
self.create_workers(RayGPUWorker, worker_kwargs)
except Exception as e:
# Clean up the Ray resources early during exception
self.shutdown()
logger.error(f"Failed to initialize RayExecutor: {e}")
raise e
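
The early-cleanup pattern added above, shown in isolation as a hedged sketch (hypothetical `_Launcher` class; the real `RayExecutor.__init__` also creates placement groups and worker actors): calling `self.shutdown()` inside the `except` block before re-raising keeps a half-constructed executor from leaking Ray state.

```python
# Hedged sketch of the init-failure cleanup pattern (hypothetical class;
# the real RayExecutor.__init__ also creates placement groups and workers).
import ray


class _Launcher:

    def __init__(self, create_workers) -> None:
        self.started_cluster = False
        try:
            ray.init()
            self.started_cluster = True
            create_workers()  # may raise, e.g. if worker actors die on startup
        except Exception as e:
            # Release whatever was already acquired before propagating.
            self.shutdown()
            raise e

    def shutdown(self) -> None:
        if self.started_cluster and ray.is_initialized():
            ray.shutdown()
```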
@staticmethod
def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle):
@@ -137,12 +144,19 @@ class RayExecutor(GenerationExecutor):
for rank in range(self.world_size)
]
try:
ray.get([worker.__ray_ready__.remote() for worker in self.workers])
except ray.exceptions.ActorDiedError as e:
if "The actor died because of an error raised in its creation task" in str(
e):
raise RuntimeError(
"RayGPUWorker died during initialization") from e
raise
@unwrap_ray_errors()
def call_all_ray_workers(self, func: str, leader_only: bool,
async_call: bool, *args, **kwargs):
workers = (self.workers[0], ) if leader_only else self.workers
if async_call:
return [
getattr(worker, func).remote(*args, **kwargs)
@@ -154,6 +168,7 @@ class RayExecutor(GenerationExecutor):
for worker in workers
])
@unwrap_ray_errors()
def collective_rpc(self,
method: str,
args: tuple = (),
@@ -174,7 +189,6 @@ class RayExecutor(GenerationExecutor):
# Ray actor doesn't work with __getattr__ delegation.
refs.append(w.call_worker_method.remote(method, *args,
**kwargs))
return refs if non_block else ray.get(refs)
def submit(self, request: GenerationRequest) -> GenerationResult:
@@ -224,11 +238,14 @@ class RayExecutor(GenerationExecutor):
self.workers = None
if hasattr(self,
"placement_group") and self.placement_group is not None:
# Only remove placement group if Ray is still initialized
# to avoid triggering auto_init_ray() during program exit
if ray.is_initialized():
ray.util.remove_placement_group(self.placement_group)
self.placement_group = None
self.bundle_indices = None
if self.has_start_local_cluser:
if self.has_start_local_cluser and ray.is_initialized():
logger.debug("Shutting down Ray cluster")
ray.shutdown()
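
A condensed sketch of the guarded shutdown above (hypothetical `_cleanup` helper; the real `shutdown` also tears down worker actors and queues): both the placement-group removal and the cluster shutdown are gated on `ray.is_initialized()`, so cleanup that runs during interpreter exit cannot re-trigger Ray auto-initialization.

```python
# Condensed sketch of the guarded-shutdown pattern (hypothetical helper;
# the real RayExecutor.shutdown also tears down worker actors and queues).
import ray


def _cleanup(placement_group, started_local_cluster: bool) -> None:
    # Touch Ray APIs only while the runtime is still up; calling them after
    # shutdown (e.g. at interpreter exit) would auto-initialize Ray again.
    if placement_group is not None and ray.is_initialized():
        ray.util.remove_placement_group(placement_group)
    if started_local_cluster and ray.is_initialized():
        ray.shutdown()
```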

View File

@@ -16,6 +16,7 @@ try:
except ModuleNotFoundError:
from tensorrt_llm import ray_stub as ray
from .._ray_utils import unwrap_ray_errors
from .._utils import mpi_disabled, nvtx_range_debug
from ..bindings import executor as tllm
from ..disaggregated_params import DisaggregatedParams
@@ -274,7 +275,7 @@ class GenerationResultBase:
else:
self.queue = ray_queue
self.aqueue = None
with unwrap_ray_errors():
ray.get(self.queue.register.remote(id))
else:
if has_event_loop():
@@ -735,6 +736,7 @@ class GenerationResult(GenerationResultBase):
def _result_step(self, timeout: Optional[float] = None):
if mpi_disabled():
with unwrap_ray_errors():
response = ray.get(self.queue.get.remote(self.request_id))
response = self._handle_ray_response(response)
else:

View File

@@ -12,11 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from functools import wraps as _wraps
from tensorrt_llm._utils import mpi_disabled
from tensorrt_llm._utils import mpi_disabled as _mpi_disabled
if mpi_disabled():
if _mpi_disabled():
raise RuntimeError(
"Ray requested (TLLM_DISABLE_MPI=1), but not installed. Please install Ray."
)
@@ -27,10 +27,11 @@ def remote(*args, **kwargs):
def decorator(func):
# Returns a function that always raises.
# Decorated class depends on ray, but ray is not installed.
@functools.wraps(func)
@_wraps(func)
def stub_checker(*_, **__):
raise RuntimeError(
"Ray not installed, cannot use Ray based feature.")
f'Ray not installed, so the remote function / actor "{func.__name__}" is not available.'
)
return stub_checker
@@ -38,3 +39,9 @@ def remote(*args, **kwargs):
return decorator(args[0])
return decorator
def __getattr__(name):
raise RuntimeError(
f'Ray not installed, so "ray.{name}" is unavailable. Please install Ray.'
)
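
The module-level `__getattr__` (PEP 562) makes any attribute the stub does not define fail with a clear message. A rough sketch of the effect, assuming Ray is not installed and `TLLM_DISABLE_MPI` is unset:

```python
# Rough sketch of the stub's behavior when Ray is absent (and MPI enabled).
try:
    import ray
except ImportError:
    import tensorrt_llm.ray_stub as ray

try:
    ray.get(None)  # `get` is not defined in the stub
except RuntimeError as exc:
    # Raised by the module-level __getattr__:
    # 'Ray not installed, so "ray.get" is unavailable. Please install Ray.'
    print(exc)
```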

View File

@@ -414,7 +414,6 @@ def test_llama_7b_multi_lora_evict_and_reload_evicted_adapters_in_cpu_and_gpu_ca
repeats_per_call=1)
@skip_ray
@skip_gpu_memory_less_than_40gb
def test_llama_7b_peft_cache_config_affects_peft_cache_size():
"""Tests that LLM arg of peft_cache_config affects the peft cache sizes.

View File

@@ -19,7 +19,8 @@ except ImportError:
from parameterized import parameterized
import tensorrt_llm
from tensorrt_llm._utils import torch_dtype_to_trt, trt_dtype_to_torch
from tensorrt_llm._utils import (mpi_disabled, torch_dtype_to_trt,
trt_dtype_to_torch)
from tensorrt_llm.llmapi.utils import get_total_gpu_memory
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.quantization import QuantMode
@@ -449,5 +450,4 @@ def check_accuracy(a, b, atol, rtol, percent):
skip_ray = pytest.mark.skipif(
os.environ.get("TLLM_DISABLE_MPI") == "1",
reason="This test is skipped for Ray orchestrator.")
mpi_disabled(), reason="This test is skipped for Ray orchestrator.")
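
The marker now delegates the environment check to `mpi_disabled()` instead of reading `TLLM_DISABLE_MPI` directly; a small usage sketch (hypothetical test name):

```python
# Usage sketch for the updated marker (hypothetical test function).
import pytest

from tensorrt_llm._utils import mpi_disabled

skip_ray = pytest.mark.skipif(
    mpi_disabled(), reason="This test is skipped for Ray orchestrator.")


@skip_ray
def test_mpi_only_path():
    # Runs only when the MPI orchestrator is active (TLLM_DISABLE_MPI != "1").
    assert not mpi_disabled()
```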