Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[TRTLLM-8507][fix] Fix ray resource cleanup and error handling in LoRA test (#8175)
Signed-off-by: shuyix <219646547+shuyixiong@users.noreply.github.com>
Commit 6776caaad1 (parent 0d20a8fd61)
tensorrt_llm/_ray_utils.py (new file, 28 lines added)
@@ -0,0 +1,28 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from contextlib import contextmanager
+
+try:
+    import ray
+except ImportError:
+    import tensorrt_llm.ray_stub as ray
+
+
+@contextmanager
+def unwrap_ray_errors():
+    try:
+        yield
+    except ray.exceptions.RayTaskError as e:
+        raise e.as_instanceof_cause() from e
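The new unwrap_ray_errors() helper is meant to wrap ray.get() calls: when a task or actor method fails, Ray normally surfaces a ray.exceptions.RayTaskError, and as_instanceof_cause() re-raises it as an instance of the original exception class, so callers and tests can catch the real error type. A minimal usage sketch (the failing task below is hypothetical, not part of this commit):

    import ray
    from tensorrt_llm._ray_utils import unwrap_ray_errors

    @ray.remote
    def failing_task():
        # Hypothetical failure inside a Ray task.
        raise ValueError("bad input")

    ray.init()
    with unwrap_ray_errors():
        # Without the context manager this surfaces as RayTaskError; with it,
        # the re-raised exception is also an instance of ValueError.
        ray.get(failing_task.remote())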
@@ -12,6 +12,7 @@ from ray.util.placement_group import (PlacementGroup,
                                       get_current_placement_group,
                                       placement_group)
 
+from tensorrt_llm._ray_utils import unwrap_ray_errors
 from tensorrt_llm._utils import get_free_port
 from tensorrt_llm.logger import logger
 
@@ -57,48 +58,54 @@ class RayExecutor(GenerationExecutor):
             "runtime_env": runtime_env
         }
 
-        if os.environ.get("TLLM_RAY_FORCE_LOCAL_CLUSTER", "0") != "1":
-            try:
-                ray.init(address="auto", **ray_init_args)
-                logger.info(f"Attached to an existing Ray cluster.")
-            except ConnectionError:
-                logger.info(f"Ray cluster not found, starting a new one.")
-
-            if not ray.is_initialized():
-                ray.init(**ray_init_args)
-                self.has_start_local_cluser = True
-        else:
-            ray.init(address="local", **ray_init_args)
-            self.has_start_local_cluser = True
-
-        self.world_size = model_world_size
-        self.tp_size = tp_size
-        self.master_address = ray.util.get_node_ip_address()
-        self.master_port = get_free_port()
-
-        self.response_queue = RayAsyncQueue.options(runtime_env={
-            "env_vars": {
-                "TLLM_DISABLE_MPI": "1"
-            }
-        }).remote()
-        self.response_sync_queue = RaySyncQueue.options(runtime_env={
-            "env_vars": {
-                "TLLM_DISABLE_MPI": "1"
-            }
-        }).remote()
-        self.async_response_queue_weakref = self.create_actor_weak_ref(
-            self.response_queue)
-        self.sync_response_queue_weakref = self.create_actor_weak_ref(
-            self.response_sync_queue)
-        self.response_queue.warmup.remote()
-        self.response_sync_queue.warmup.remote()
-
-        worker_kwargs = dict(**worker_kwargs,
-                             postproc_worker_config=postproc_worker_config,
-                             is_llm_executor=is_llm_executor,
-                             kv_connector_config=kv_connector_config)
-
-        self.create_workers(RayGPUWorker, worker_kwargs)
+        try:
+            if os.environ.get("TLLM_RAY_FORCE_LOCAL_CLUSTER", "0") != "1":
+                try:
+                    ray.init(address="auto", **ray_init_args)
+                    logger.info(f"Attached to an existing Ray cluster.")
+                except ConnectionError:
+                    logger.info(f"Ray cluster not found, starting a new one.")
+
+                if not ray.is_initialized():
+                    ray.init(**ray_init_args)
+                    self.has_start_local_cluser = True
+            else:
+                ray.init(address="local", **ray_init_args)
+                self.has_start_local_cluser = True
+
+            self.world_size = model_world_size
+            self.tp_size = tp_size
+            self.master_address = ray.util.get_node_ip_address()
+            self.master_port = get_free_port()
+
+            self.response_queue = RayAsyncQueue.options(runtime_env={
+                "env_vars": {
+                    "TLLM_DISABLE_MPI": "1"
+                }
+            }).remote()
+            self.response_sync_queue = RaySyncQueue.options(runtime_env={
+                "env_vars": {
+                    "TLLM_DISABLE_MPI": "1"
+                }
+            }).remote()
+            self.async_response_queue_weakref = self.create_actor_weak_ref(
+                self.response_queue)
+            self.sync_response_queue_weakref = self.create_actor_weak_ref(
+                self.response_sync_queue)
+            self.response_queue.warmup.remote()
+            self.response_sync_queue.warmup.remote()
+
+            worker_kwargs = dict(**worker_kwargs,
+                                 postproc_worker_config=postproc_worker_config,
+                                 is_llm_executor=is_llm_executor,
+                                 kv_connector_config=kv_connector_config)
+
+            self.create_workers(RayGPUWorker, worker_kwargs)
+        except Exception as e:
+            # Clean up the Ray resources early during exception
+            self.shutdown()
+            logger.error(f"Failed to initialize RayExecutor: {e}")
+            raise e
 
     @staticmethod
     def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle):
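Wrapping the whole constructor body in try/except and calling self.shutdown() before re-raising follows the usual pattern for constructors that allocate external resources: release whatever was partially created, log, then propagate the original error. A stripped-down sketch of the same shape (class and helper names hypothetical, not from this commit):

    class Executor:
        def __init__(self):
            self.workers = None
            try:
                # Hypothetical helper that may fail partway through setup.
                self.workers = self._spawn_workers()
            except Exception as e:
                # Release partially created resources before propagating the error.
                self.shutdown()
                raise e

        def _spawn_workers(self):
            raise RuntimeError("placement group could not be scheduled")  # hypothetical failure

        def shutdown(self):
            if self.workers is not None:
                self.workers = None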
@@ -137,12 +144,19 @@ class RayExecutor(GenerationExecutor):
             for rank in range(self.world_size)
         ]
 
-        ray.get([worker.__ray_ready__.remote() for worker in self.workers])
+        try:
+            ray.get([worker.__ray_ready__.remote() for worker in self.workers])
+        except ray.exceptions.ActorDiedError as e:
+            if "The actor died because of an error raised in its creation task" in str(
+                    e):
+                raise RuntimeError(
+                    "RayGPUWorker died during initialization") from e
+            raise
 
+    @unwrap_ray_errors()
     def call_all_ray_workers(self, func: str, leader_only: bool,
                              async_call: bool, *args, **kwargs):
         workers = (self.workers[0], ) if leader_only else self.workers
 
         if async_call:
             return [
                 getattr(worker, func).remote(*args, **kwargs)
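ray.get() on the built-in __ray_ready__ method forces actor construction to complete, so an exception raised in a worker's __init__ surfaces here as ray.exceptions.ActorDiedError whose message mentions the creation task; the new except clause converts that into a clearer RuntimeError. A hypothetical sketch of the failure mode being guarded against (the actor below is not part of this commit):

    import ray

    @ray.remote
    class BrokenWorker:  # hypothetical actor whose construction fails
        def __init__(self):
            raise RuntimeError("no GPU available")

    ray.init()
    worker = BrokenWorker.remote()
    try:
        ray.get(worker.__ray_ready__.remote())
    except ray.exceptions.ActorDiedError as e:
        # The message typically notes that the actor died in its creation task.
        raise RuntimeError("worker died during initialization") from e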
@@ -154,6 +168,7 @@ class RayExecutor(GenerationExecutor):
                 for worker in workers
             ])
 
+    @unwrap_ray_errors()
     def collective_rpc(self,
                        method: str,
                        args: tuple = (),
@@ -174,7 +189,6 @@ class RayExecutor(GenerationExecutor):
             # Ray actor doesn't work with __getattr__ delegation.
             refs.append(w.call_worker_method.remote(method, *args,
                                                     **kwargs))
-
         return refs if non_block else ray.get(refs)
 
     def submit(self, request: GenerationRequest) -> GenerationResult:
@@ -224,11 +238,14 @@ class RayExecutor(GenerationExecutor):
         self.workers = None
         if hasattr(self,
                    "placement_group") and self.placement_group is not None:
-            ray.util.remove_placement_group(self.placement_group)
+            # Only remove placement group if Ray is still initialized
+            # to avoid triggering auto_init_ray() during program exit
+            if ray.is_initialized():
+                ray.util.remove_placement_group(self.placement_group)
             self.placement_group = None
             self.bundle_indices = None
 
-        if self.has_start_local_cluser:
+        if self.has_start_local_cluser and ray.is_initialized():
             logger.debug("Shutting down Ray cluster")
             ray.shutdown()
 
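The ray.is_initialized() checks matter because ray.util calls such as remove_placement_group() go through auto_init_ray() and would re-initialize Ray when shutdown() runs after the cluster is already gone, for example during interpreter exit. A minimal sketch of the guarded cleanup (the placement_group argument is a hypothetical stand-in for the executor's handle):

    import ray

    def cleanup(placement_group):
        # Skip Ray API calls once the cluster is shut down, so cleanup at
        # program exit does not trigger a fresh, unwanted ray.init().
        if placement_group is not None and ray.is_initialized():
            ray.util.remove_placement_group(placement_group)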
@@ -16,6 +16,7 @@ try:
 except ModuleNotFoundError:
     from tensorrt_llm import ray_stub as ray
 
+from .._ray_utils import unwrap_ray_errors
 from .._utils import mpi_disabled, nvtx_range_debug
 from ..bindings import executor as tllm
 from ..disaggregated_params import DisaggregatedParams
@@ -274,8 +275,8 @@ class GenerationResultBase:
             else:
                 self.queue = ray_queue
                 self.aqueue = None
-                ray.get(self.queue.register.remote(id))
+                with unwrap_ray_errors():
+                    ray.get(self.queue.register.remote(id))
         else:
             if has_event_loop():
                 self.aqueue = AsyncQueue()
@@ -735,7 +736,8 @@ class GenerationResult(GenerationResultBase):
 
     def _result_step(self, timeout: Optional[float] = None):
         if mpi_disabled():
-            response = ray.get(self.queue.get.remote(self.request_id))
+            with unwrap_ray_errors():
+                response = ray.get(self.queue.get.remote(self.request_id))
             response = self._handle_ray_response(response)
         else:
             response = self.queue.get()
@@ -12,11 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import functools
+from functools import wraps as _wraps
 
-from tensorrt_llm._utils import mpi_disabled
+from tensorrt_llm._utils import mpi_disabled as _mpi_disabled
 
-if mpi_disabled():
+if _mpi_disabled():
     raise RuntimeError(
         "Ray requested (TLLM_DISABLE_MPI=1), but not installed. Please install Ray."
     )
@@ -27,10 +27,11 @@ def remote(*args, **kwargs):
     def decorator(func):
         # Returns a function that always raises.
         # Decorated class depends on ray, but ray is not installed.
-        @functools.wraps(func)
+        @_wraps(func)
         def stub_checker(*_, **__):
             raise RuntimeError(
-                "Ray not installed, cannot use Ray based feature.")
+                f'Ray not installed, so the remote function / actor "{func.__name__}" is not available.'
+            )
 
         return stub_checker
 
@@ -38,3 +39,9 @@ def remote(*args, **kwargs):
         return decorator(args[0])
 
     return decorator
+
+
+def __getattr__(name):
+    raise RuntimeError(
+        f'Ray not installed, so "ray.{name}" is unavailable. Please install Ray.'
+    )
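The module-level __getattr__ added to the stub relies on PEP 562: any attribute looked up on the module that is not explicitly defined falls through to it, so code written against the real ray API fails with an actionable "Ray not installed" message instead of a bare AttributeError. The same pattern in isolation (module and attribute names below are hypothetical):

    # optional_dep_stub.py -- hypothetical stand-in for an optional dependency
    def __getattr__(name):
        raise RuntimeError(
            f'"{name}" is unavailable because the optional dependency is not installed.')

    # Client code:
    #   import optional_dep_stub as dep
    #   dep.get_actor("x")  -> RuntimeError with a clear message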
@@ -414,7 +414,6 @@ def test_llama_7b_multi_lora_evict_and_reload_evicted_adapters_in_cpu_and_gpu_ca
                           repeats_per_call=1)
 
 
-@skip_ray
 @skip_gpu_memory_less_than_40gb
 def test_llama_7b_peft_cache_config_affects_peft_cache_size():
     """Tests that LLM arg of peft_cache_config affects the peft cache sizes.
@@ -19,7 +19,8 @@ except ImportError:
 from parameterized import parameterized
 
 import tensorrt_llm
-from tensorrt_llm._utils import torch_dtype_to_trt, trt_dtype_to_torch
+from tensorrt_llm._utils import (mpi_disabled, torch_dtype_to_trt,
+                                 trt_dtype_to_torch)
 from tensorrt_llm.llmapi.utils import get_total_gpu_memory
 from tensorrt_llm.plugin.plugin import ContextFMHAType
 from tensorrt_llm.quantization import QuantMode
@@ -449,5 +450,4 @@ def check_accuracy(a, b, atol, rtol, percent):
 
 
 skip_ray = pytest.mark.skipif(
-    os.environ.get("TLLM_DISABLE_MPI") == "1",
-    reason="This test is skipped for Ray orchestrator.")
+    mpi_disabled(), reason="This test is skipped for Ray orchestrator.")
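With skip_ray now keyed off mpi_disabled() instead of reading TLLM_DISABLE_MPI directly, the tests make the same decision the runtime does when choosing the Ray orchestrator. A hypothetical usage sketch of the marker (the test function below is illustrative only):

    import pytest
    from tensorrt_llm._utils import mpi_disabled

    skip_ray = pytest.mark.skipif(
        mpi_disabled(), reason="This test is skipped for Ray orchestrator.")

    @skip_ray
    def test_mpi_only_behavior():  # hypothetical test
        assert True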