mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[TRTLLM-8507][fix] Fix ray resource cleanup and error handling in LoRA test (#8175)
Signed-off-by: shuyix <219646547+shuyixiong@users.noreply.github.com>
parent 0d20a8fd61
commit 6776caaad1
tensorrt_llm/_ray_utils.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from contextlib import contextmanager
+
+try:
+    import ray
+except ImportError:
+    import tensorrt_llm.ray_stub as ray
+
+
+@contextmanager
+def unwrap_ray_errors():
+    try:
+        yield
+    except ray.exceptions.RayTaskError as e:
+        raise e.as_instanceof_cause() from e
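The helper above converts Ray's RayTaskError wrapper back into the original exception type via as_instanceof_cause(). A minimal sketch (not part of this commit) of how it behaves, assuming Ray is installed; failing_task is a made-up example task:

# Minimal sketch, assuming Ray is installed; failing_task is hypothetical.
import ray

from tensorrt_llm._ray_utils import unwrap_ray_errors

ray.init(ignore_reinit_error=True)


@ray.remote
def failing_task():
    raise ValueError("original error")


try:
    with unwrap_ray_errors():
        # ray.get would normally raise ray.exceptions.RayTaskError here;
        # the context manager re-raises it as the original ValueError.
        ray.get(failing_task.remote())
except ValueError as err:
    print(f"caught the original exception type: {err}")

Because @contextmanager produces a contextlib.ContextDecorator, the same helper doubles as a method decorator (@unwrap_ray_errors()), which is how the executor changes below apply it.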
@@ -12,6 +12,7 @@ from ray.util.placement_group import (PlacementGroup,
                                       get_current_placement_group,
                                       placement_group)
 
+from tensorrt_llm._ray_utils import unwrap_ray_errors
 from tensorrt_llm._utils import get_free_port
 from tensorrt_llm.logger import logger
@@ -57,6 +58,7 @@ class RayExecutor(GenerationExecutor):
             "runtime_env": runtime_env
         }
 
+        try:
             if os.environ.get("TLLM_RAY_FORCE_LOCAL_CLUSTER", "0") != "1":
                 try:
                     ray.init(address="auto", **ray_init_args)
@@ -99,6 +101,11 @@ class RayExecutor(GenerationExecutor):
                 kv_connector_config=kv_connector_config)
 
             self.create_workers(RayGPUWorker, worker_kwargs)
+        except Exception as e:
+            # Clean up the Ray resources early during exception
+            self.shutdown()
+            logger.error(f"Failed to initialize RayExecutor: {e}")
+            raise e
 
     @staticmethod
     def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle):
@@ -137,12 +144,19 @@ class RayExecutor(GenerationExecutor):
             for rank in range(self.world_size)
         ]
 
-        ray.get([worker.__ray_ready__.remote() for worker in self.workers])
+        try:
+            ray.get([worker.__ray_ready__.remote() for worker in self.workers])
+        except ray.exceptions.ActorDiedError as e:
+            if "The actor died because of an error raised in its creation task" in str(
+                    e):
+                raise RuntimeError(
+                    "RayGPUWorker died during initialization") from e
+            raise
 
+    @unwrap_ray_errors()
     def call_all_ray_workers(self, func: str, leader_only: bool,
                              async_call: bool, *args, **kwargs):
         workers = (self.workers[0], ) if leader_only else self.workers
 
         if async_call:
             return [
                 getattr(worker, func).remote(*args, **kwargs)
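A hypothetical sketch (assuming Ray is installed) of the readiness check the hunk above wraps in try/except: __ray_ready__ completes only after an actor's __init__ has run, so a constructor failure surfaces as ActorDiedError here instead of on the first real method call. BrokenWorker is a made-up actor:

# Hypothetical sketch; BrokenWorker stands in for RayGPUWorker.
import ray

ray.init(ignore_reinit_error=True)


@ray.remote
class BrokenWorker:
    def __init__(self):
        raise RuntimeError("error raised in its creation task")


worker = BrokenWorker.remote()
try:
    ray.get(worker.__ray_ready__.remote())
except ray.exceptions.ActorDiedError as err:
    print(f"worker died during initialization: {err}")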
@@ -154,6 +168,7 @@ class RayExecutor(GenerationExecutor):
                 for worker in workers
             ])
 
+    @unwrap_ray_errors()
     def collective_rpc(self,
                        method: str,
                        args: tuple = (),
@@ -174,7 +189,6 @@ class RayExecutor(GenerationExecutor):
                 # Ray actor doesn't work with __getattr__ delegation.
                 refs.append(w.call_worker_method.remote(method, *args,
                                                         **kwargs))
-
         return refs if non_block else ray.get(refs)
 
     def submit(self, request: GenerationRequest) -> GenerationResult:
@@ -224,11 +238,14 @@ class RayExecutor(GenerationExecutor):
         self.workers = None
         if hasattr(self,
                    "placement_group") and self.placement_group is not None:
-            ray.util.remove_placement_group(self.placement_group)
+            # Only remove placement group if Ray is still initialized
+            # to avoid triggering auto_init_ray() during program exit
+            if ray.is_initialized():
+                ray.util.remove_placement_group(self.placement_group)
             self.placement_group = None
             self.bundle_indices = None
 
-        if self.has_start_local_cluser:
+        if self.has_start_local_cluser and ray.is_initialized():
             logger.debug("Shutting down Ray cluster")
             ray.shutdown()
 
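A hedged sketch of the guard pattern the shutdown hunk introduces; cleanup() is a hypothetical helper, not the actual shutdown() method. Shutdown can run at interpreter exit, after the Ray cluster is already gone, and calling into ray at that point would re-initialize it (auto_init_ray) just to tear state down, so every Ray call is gated on ray.is_initialized():

# Hedged sketch of the is_initialized() gating; cleanup() is hypothetical.
import ray


def cleanup(placement_group) -> None:
    # Only touch cluster state while Ray is still up; otherwise there is
    # nothing left to clean and calling ray would re-initialize it.
    if placement_group is not None and ray.is_initialized():
        ray.util.remove_placement_group(placement_group)
    if ray.is_initialized():
        ray.shutdown()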
@@ -16,6 +16,7 @@ try:
 except ModuleNotFoundError:
     from tensorrt_llm import ray_stub as ray
 
+from .._ray_utils import unwrap_ray_errors
 from .._utils import mpi_disabled, nvtx_range_debug
 from ..bindings import executor as tllm
 from ..disaggregated_params import DisaggregatedParams
@@ -274,7 +275,7 @@ class GenerationResultBase:
         else:
             self.queue = ray_queue
             self.aqueue = None
-
-            ray.get(self.queue.register.remote(id))
+            with unwrap_ray_errors():
+                ray.get(self.queue.register.remote(id))
         else:
             if has_event_loop():
@@ -735,6 +736,7 @@ class GenerationResult(GenerationResultBase):
 
     def _result_step(self, timeout: Optional[float] = None):
         if mpi_disabled():
-            response = ray.get(self.queue.get.remote(self.request_id))
+            with unwrap_ray_errors():
+                response = ray.get(self.queue.get.remote(self.request_id))
             response = self._handle_ray_response(response)
         else:
@@ -12,11 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import functools
+from functools import wraps as _wraps
 
-from tensorrt_llm._utils import mpi_disabled
+from tensorrt_llm._utils import mpi_disabled as _mpi_disabled
 
-if mpi_disabled():
+if _mpi_disabled():
     raise RuntimeError(
         "Ray requested (TLLM_DISABLE_MPI=1), but not installed. Please install Ray."
     )
@@ -27,10 +27,11 @@ def remote(*args, **kwargs):
     def decorator(func):
         # Returns a function that always raises.
        # Decorated class depends on ray, but ray is not installed.
-        @functools.wraps(func)
+        @_wraps(func)
         def stub_checker(*_, **__):
             raise RuntimeError(
-                "Ray not installed, cannot use Ray based feature.")
+                f'Ray not installed, so the remote function / actor "{func.__name__}" is not available.'
+            )
 
         return stub_checker
 
@@ -38,3 +39,9 @@ def remote(*args, **kwargs):
         return decorator(args[0])
 
     return decorator
+
+
+def __getattr__(name):
+    raise RuntimeError(
+        f'Ray not installed, so "ray.{name}" is unavailable. Please install Ray.'
+    )
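A hypothetical illustration of what the stub's new module-level __getattr__ (PEP 562) buys: before this commit, accessing an attribute the stub did not define raised a bare AttributeError; now any such access fails with an actionable message. This assumes Ray is NOT installed and TLLM_DISABLE_MPI is unset, so the stub imports cleanly:

# Hypothetical illustration; assumes Ray is absent and MPI is not disabled.
import tensorrt_llm.ray_stub as ray

try:
    ray.get  # attribute access alone triggers the module-level __getattr__
except RuntimeError as err:
    # Ray not installed, so "ray.get" is unavailable. Please install Ray.
    print(err)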
@@ -414,7 +414,6 @@ def test_llama_7b_multi_lora_evict_and_reload_evicted_adapters_in_cpu_and_gpu_ca
                             repeats_per_call=1)
 
 
-@skip_ray
 @skip_gpu_memory_less_than_40gb
 def test_llama_7b_peft_cache_config_affects_peft_cache_size():
     """Tests that LLM arg of peft_cache_config affects the peft cache sizes.
@@ -19,7 +19,8 @@ except ImportError:
 from parameterized import parameterized
 
 import tensorrt_llm
-from tensorrt_llm._utils import torch_dtype_to_trt, trt_dtype_to_torch
+from tensorrt_llm._utils import (mpi_disabled, torch_dtype_to_trt,
+                                 trt_dtype_to_torch)
 from tensorrt_llm.llmapi.utils import get_total_gpu_memory
 from tensorrt_llm.plugin.plugin import ContextFMHAType
 from tensorrt_llm.quantization import QuantMode
@@ -449,5 +450,4 @@ def check_accuracy(a, b, atol, rtol, percent):
 
 
 skip_ray = pytest.mark.skipif(
-    os.environ.get("TLLM_DISABLE_MPI") == "1",
-    reason="This test is skipped for Ray orchestrator.")
+    mpi_disabled(), reason="This test is skipped for Ray orchestrator.")
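A hedged sketch of applying the reworked marker; test_example is a made-up test name. Evaluating mpi_disabled() at collection time keeps the skip condition consistent with the library's own check instead of re-parsing the TLLM_DISABLE_MPI environment variable:

# Hedged sketch; test_example is hypothetical.
import pytest

from tensorrt_llm._utils import mpi_disabled

skip_ray = pytest.mark.skipif(
    mpi_disabled(), reason="This test is skipped for Ray orchestrator.")


@skip_ray
def test_example():
    assert True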