Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
[TRTLLM-8511][feat] Add update_weights and sleep_wakeup support for rl integration (#8302)
Signed-off-by: shuyix <219646547+shuyixiong@users.noreply.github.com>
Co-authored-by: Liwei Ma <liweim@nvidia.com>
Co-authored-by: Jonas Yang CN <joyang@nvidia.com>
This commit is contained in:
parent e2b2675120
commit 70e4d72ffa
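A minimal usage sketch assembled from the examples and tests in this change (the model path and ipc_handles are placeholders; single GPU assumed):

from tensorrt_llm import LLM
from tensorrt_llm._torch.virtual_memory import ExecutorMemoryType

llm = LLM(
    model=model_dir,                                        # placeholder model path
    orchestrator_type="ray",                                # collective RPC requires the Ray orchestrator
    ray_worker_extension_cls="rlhf_utils.WorkerExtension",  # extension added by this commit
    enable_sleep=True,                                      # opt-in setup for sleep/wakeup
    cuda_graph_config=None,                                 # CUDA Graph not supported with sleep yet
)

# Release the model-engine weights while the trainer runs, then restore them.
llm._collective_rpc("sleep", ([ExecutorMemoryType.MODEL_ENGINE_MAIN],))
llm._collective_rpc("wakeup", ([ExecutorMemoryType.MODEL_ENGINE_MAIN],))

# Push refreshed weights over CUDA IPC and drop stale KV-cache reuse state.
llm._collective_rpc("update_weights", args=(ipc_handles,))  # ipc_handles built on the trainer side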
@@ -871,6 +871,13 @@ public:
        return mIsValidStoreForReuseSequence.at(requestId);
    }

    void resetReuseState()
    {
        std::lock_guard<std::mutex> lock(mCachedBlocksRootMutex);
        mCachedBlocksRoot
            = std::make_shared<KVCacheBlock>(KVCacheBlock::kCachedBlocksRootId, tensorrt_llm::kernels::KVCacheIndex{0});
    }

private:
    //! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
    void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);

@@ -1347,6 +1354,14 @@ public:
        return mWindowBlockManagers.at(windowSize).isSequenceValidForStoreForReuse(requestId);
    }

    void resetReuseState()
    {
        for (auto& [windowSize, manager] : mWindowBlockManagers)
        {
            manager.resetReuseState();
        }
    }

private:
    [[nodiscard]] WindowBlockManager const& windowManagerByLayer(SizeType32 layerIdx) const
    {

@@ -1533,6 +1548,7 @@ public:

    virtual void refreshBlocks() = 0;
    virtual void flushIterationEvents() = 0;
    virtual void resetReuseState() = 0;

    [[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);

@@ -1913,6 +1929,11 @@ public:
        return mBlockManager.findBlocksInReuseTreeByBlockKey(blockKey, windowSize);
    }

    void resetReuseState() override
    {
        mBlockManager.resetReuseState();
    }

    /// @brief Finds the maximum attention window that can be used on a sequence, given some kv-cache block capacity.
    ///
    /// @param inputLength The number of input tokens in the sequence.
@@ -482,7 +482,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
        .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
            nb::call_guard<nb::gil_scoped_release>())
        .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, nb::call_guard<nb::gil_scoped_release>())
        .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, nb::call_guard<nb::gil_scoped_release>());
        .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, nb::call_guard<nb::gil_scoped_release>())
        .def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, nb::call_guard<nb::gil_scoped_release>());

    nb::bind_vector<CacheBlockIds>(m, "CacheBlockIds")
        .def("__getstate__", [](CacheBlockIds const& v) { return nb::make_tuple(v); })

@@ -486,7 +486,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
        .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
            py::call_guard<py::gil_scoped_release>())
        .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, py::call_guard<py::gil_scoped_release>())
        .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, py::call_guard<py::gil_scoped_release>());
        .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, py::call_guard<py::gil_scoped_release>())
        .def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, py::call_guard<py::gil_scoped_release>());

    py::enum_<tbk::CacheType>(m, "CacheType")
        .value("SELF", tbk::CacheType::kSELF)
examples/llm-api/rlhf/rlhf_utils.py (new file, 77 lines)
@@ -0,0 +1,77 @@
import torch

from tensorrt_llm._ray_utils import control_action_decorator
from tensorrt_llm._torch.utils import get_device_uuid
from tensorrt_llm.logger import logger


class WorkerExtension:
    """Worker extension class for extending TensorRT-LLM Ray workers with custom functionality.

    This class can be injected into tensorrt_llm.LLM() by specifying it via the
    ray_worker_extension_cls parameter in LLMArgs when using orchestrator_type='ray'.
    The extension methods will be available on each Ray worker and can be called via
    the LLM's collective RPC mechanism.

    Examples:
        Creating an LLM with worker extension:

        >>> llm = LLM(
        ...     model=model_dir,
        ...     orchestrator_type="ray",
        ...     ray_worker_extension_cls="rlhf_utils.WorkerExtension",
        ... )

        Calling extension methods via collective RPC:

        >>> llm._collective_rpc("update_weights", args=(ipc_handles,))
    """

    @control_action_decorator
    def update_weights(self, ipc_handles: dict):
        """Update model weights from IPC (Inter-Process Communication) handles.

        This method receives shared memory handles from another process (typically FSDP training),
        reconstructs tensors from these handles, and loads them into the TensorRT-LLM model.
        Uses the control_action_decorator to ensure all active requests are finished before
        updating weights.

        Args:
            ipc_handles: Dictionary mapping device UUIDs to lists of (param_name, tensor_handle) tuples.
                Each tensor_handle is a tuple of (func, args) for reconstructing the tensor.

        Raises:
            ValueError: If the current device's UUID is not found in ipc_handles.
            Exception: Re-raises any exception encountered during weight update.
        """
        try:
            logger.info("Update weights from IPC handles")
            device_uuid = get_device_uuid(self.device_id)

            if device_uuid not in ipc_handles:
                raise ValueError(f"Device UUID {device_uuid} not found in ipc_handles")

            weights = {}
            all_handles = ipc_handles[device_uuid]

            for param_name, tensor_handle in all_handles:
                func, args = tensor_handle
                list_args = list(args)
                list_args[6] = self.device_id  # Set target device
                tensor = func(*list_args)
                weights[param_name] = tensor

            self.engine.model_engine.model.load_weights(weights)
            torch.cuda.synchronize()
            self.engine.reset_prefix_cache()

        except Exception as e:
            logger.error("Encountered an error in update_weights")
            raise e

    def check_weights_updated(self):
        """Check if the weights are updated to 0."""
        weights_updated = True
        for name, p in self.engine.model_engine.model.named_parameters():
            weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p))
        return weights_updated
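For context, update_weights above expects handles produced with torch.multiprocessing's tensor reduction, as done in the tests added by this commit. A minimal trainer-side sketch (build_ipc_handles and the model object are illustrative, not part of this change):

import torch
from torch.multiprocessing.reductions import reduce_tensor

from tensorrt_llm._torch.utils import get_device_uuid


def build_ipc_handles(model: torch.nn.Module, device: int = 0) -> dict:
    # reduce_tensor() returns the (rebuild_func, args) pair that update_weights
    # re-invokes on the worker; key the handle list by this GPU's UUID.
    handles = [(name, reduce_tensor(p.detach())) for name, p in model.named_parameters()]
    return {get_device_uuid(device): handles}


# ipc_handles = build_ipc_handles(trainer_model)              # trainer_model is hypothetical
# llm._collective_rpc("update_weights", args=(ipc_handles,))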
@@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from contextlib import contextmanager
from typing import Callable

try:
    import ray

@@ -26,3 +28,16 @@ def unwrap_ray_errors():
        yield
    except ray.exceptions.RayTaskError as e:
        raise e.as_instanceof_cause() from e


def control_action_decorator(func: Callable) -> Callable:
    """
    Decorator that wraps a method to use control_action context manager.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        with self.engine.control_action():
            return func(self, *args, **kwargs)

    return wrapper
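control_action_decorator is meant for worker-extension methods that must run while the executor loop is idle. A small illustrative sketch (the extension class and method name are hypothetical):

from tensorrt_llm._ray_utils import control_action_decorator


class MyExtension:

    @control_action_decorator
    def swap_adapters(self, adapter_path: str):
        # Entered only after all in-flight requests have drained; the executor
        # loop resumes automatically once this method returns.
        ...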
@ -871,6 +871,8 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
|
||||
for new_name in params_map[names[-1]]:
|
||||
fw = filter_weights('.'.join(names[:-1] + [new_name]),
|
||||
weights)
|
||||
if not fw:
|
||||
continue
|
||||
if new_name in ['k_proj', 'v_proj']:
|
||||
num_kv_heads_list = [num_kv_heads
|
||||
] * len(fw) if isinstance(
|
||||
@ -887,23 +889,29 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
|
||||
}
|
||||
|
||||
module_weights.append(fw)
|
||||
module.load_weights(weights=module_weights)
|
||||
# Note: module_weights may be empty after filtering (e.g., in streaming weight updates)
|
||||
if module_weights:
|
||||
module.load_weights(weights=module_weights)
|
||||
|
||||
else:
|
||||
module_weights = filter_weights(name, weights)
|
||||
if hasattr(module, 'load_weights'):
|
||||
module.load_weights(weights=[module_weights])
|
||||
else:
|
||||
for n, p in module._parameters.items():
|
||||
if p is not None:
|
||||
# Note: module_weights may be empty after filtering (e.g., in streaming weight updates)
|
||||
if module_weights:
|
||||
if hasattr(module, 'load_weights'):
|
||||
module.load_weights(weights=[module_weights])
|
||||
else:
|
||||
for n, p in module.named_parameters(recurse=False):
|
||||
p.data.copy_(module_weights[n][:])
|
||||
|
||||
if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
|
||||
"True") in ["True", "true", "1", "yes", "y"]:
|
||||
for name, module in tqdm(list(model.named_modules()),
|
||||
for name, module in tqdm(list(
|
||||
model.named_modules(remove_duplicate=False)),
|
||||
desc="Loading weights"):
|
||||
load_single_module(name, module)
|
||||
else:
|
||||
all_modules = dict(model.named_modules())
|
||||
# remove_duplicate=False ensures original modules sharing weights with next_layer_layernorm are not skipped
|
||||
all_modules = dict(model.named_modules(remove_duplicate=False))
|
||||
serial_load_modules = []
|
||||
if preload_weight_modules is not None:
|
||||
for module in preload_weight_modules:
|
||||
@ -919,10 +927,13 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
|
||||
del all_modules[module]
|
||||
pbar.close()
|
||||
|
||||
pbar = tqdm(list(model.named_modules()),
|
||||
pbar = tqdm(list(model.named_modules(remove_duplicate=False)),
|
||||
desc="Loading weights concurrently")
|
||||
args_list = [(name, module) for name, module in model.named_modules()
|
||||
if name not in serial_load_modules]
|
||||
args_list = [
|
||||
(name, module)
|
||||
for name, module in model.named_modules(remove_duplicate=False)
|
||||
if name not in serial_load_modules
|
||||
]
|
||||
run_concurrently(load_single_module, args_list, pbar=pbar)
|
||||
|
||||
|
||||
@ -950,31 +961,36 @@ def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM],
|
||||
if weight_mapper.does_require_special_handling(module_name):
|
||||
module_weights = weight_mapper.apply_callbacks(
|
||||
module, module_name, module_names_breakdown, weights)
|
||||
module.load_weights(weights=module_weights)
|
||||
# Note: module_weights may be empty after filtering (e.g., in streaming weight updates)
|
||||
if module_weights:
|
||||
module.load_weights(weights=module_weights)
|
||||
else:
|
||||
module_weights = weight_mapper.filter_weights(name, weights)
|
||||
if weight_mapper.is_special_instance_module(module):
|
||||
weight_mapper.handle_special_instance_module(
|
||||
module, module_name, module_weights)
|
||||
|
||||
elif hasattr(module, 'load_weights'):
|
||||
if "linear_attn.conv1d" in name:
|
||||
module_weights['weight'] = module_weights[
|
||||
'weight'].squeeze(dim=1)
|
||||
module.load_weights(weights=[module_weights])
|
||||
else:
|
||||
for n, p in module._parameters.items():
|
||||
if p is not None:
|
||||
# Note: module_weights may be empty after filtering (e.g., in streaming weight updates)
|
||||
if module_weights:
|
||||
if weight_mapper.is_special_instance_module(module):
|
||||
weight_mapper.handle_special_instance_module(
|
||||
module, module_name, module_weights)
|
||||
elif hasattr(module, 'load_weights'):
|
||||
if module_weights:
|
||||
if "linear_attn.conv1d" in name:
|
||||
module_weights['weight'] = module_weights[
|
||||
'weight'].squeeze(dim=1)
|
||||
module.load_weights(weights=[module_weights])
|
||||
else:
|
||||
for n, p in module.named_parameters(recurse=False):
|
||||
weight_mapper.handle_manual_copy(
|
||||
module_name, module_weights, n, p)
|
||||
|
||||
if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
|
||||
"True") in ["True", "true", "1", "yes", "y"]:
|
||||
for name, module in tqdm(list(model.named_modules()),
|
||||
for name, module in tqdm(list(
|
||||
model.named_modules(remove_duplicate=False)),
|
||||
desc="Loading weights"):
|
||||
load_single_module(name, module)
|
||||
else:
|
||||
all_modules = dict(model.named_modules())
|
||||
# remove_duplicate=False ensures original modules sharing weights with next_layer_layernorm are not skipped
|
||||
all_modules = dict(model.named_modules(remove_duplicate=False))
|
||||
serial_load_modules = []
|
||||
if preload_weight_modules is not None:
|
||||
for module in preload_weight_modules:
|
||||
@ -990,8 +1006,11 @@ def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM],
|
||||
del all_modules[module]
|
||||
pbar.close()
|
||||
|
||||
pbar = tqdm(list(model.named_modules()),
|
||||
pbar = tqdm(list(model.named_modules(remove_duplicate=False)),
|
||||
desc="Loading weights concurrently")
|
||||
args_list = [(name, module) for name, module in model.named_modules()
|
||||
if name not in serial_load_modules]
|
||||
args_list = [
|
||||
(name, module)
|
||||
for name, module in model.named_modules(remove_duplicate=False)
|
||||
if name not in serial_load_modules
|
||||
]
|
||||
run_concurrently(load_single_module, args_list, pbar=pbar)
|
||||
|
||||
@ -670,6 +670,7 @@ def create_py_executor_instance(
|
||||
peft_cache_config: Optional[PeftCacheConfig] = None,
|
||||
scheduler_config: Optional[SchedulerConfig] = None,
|
||||
cache_transceiver_config: Optional[CacheTransceiverConfig] = None,
|
||||
virtual_memory_pools: Optional[dict] = None,
|
||||
) -> PyExecutor:
|
||||
kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)
|
||||
|
||||
@ -818,7 +819,8 @@ def create_py_executor_instance(
|
||||
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
|
||||
kv_connector_manager=kv_connector_manager,
|
||||
max_seq_len=max_seq_len,
|
||||
peft_cache_config=peft_cache_config)
|
||||
peft_cache_config=peft_cache_config,
|
||||
virtual_memory_pools=virtual_memory_pools)
|
||||
|
||||
|
||||
def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
|
||||
|
||||
@ -108,6 +108,9 @@ class PyTorchConfig:
|
||||
# If true, ONLY the vision encoder part of the full model is loaded/executed.
|
||||
mm_encoder_only: bool = False
|
||||
|
||||
# Enable extra setup to support sleep feature.
|
||||
enable_sleep: bool = False
|
||||
|
||||
# If true, adjust PyTorch CUDA memory fraction to correspond to the
|
||||
# total GPU memory minus the statically allocated engine memory.
|
||||
# If false, set the PyTorch CUDA memory fraction to 1.0.
|
||||
|
||||
@ -18,6 +18,7 @@ from .llm_request import (ExecutorRequest, LlmRequest,
|
||||
executor_request_to_llm_request)
|
||||
|
||||
SHUTDOWN_REQUEST_ID = -1
|
||||
CONTROL_REQUEST_ID = -2
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -35,7 +36,12 @@ class RequestQueueItem:
|
||||
|
||||
@property
|
||||
def is_normal_request(self):
|
||||
return not (self.is_shutdown_request or self.is_canceled_request)
|
||||
return not (self.is_shutdown_request or self.is_canceled_request
|
||||
or self.is_control_request)
|
||||
|
||||
@property
|
||||
def is_control_request(self):
|
||||
return self.id == CONTROL_REQUEST_ID
|
||||
|
||||
|
||||
class ExecutorRequestQueue:
|
||||
@ -68,6 +74,8 @@ class ExecutorRequestQueue:
|
||||
self.new_active_requests_queue_latency_ms = 0
|
||||
self.is_shutdown = False
|
||||
self.should_exclude_last_generation_logits = False
|
||||
self.control_requests: List[RequestQueueItem] = []
|
||||
self.request_accumulated: List[RequestQueueItem] = []
|
||||
|
||||
self._disable_mpi = mpi_disabled()
|
||||
|
||||
@ -251,6 +259,10 @@ class ExecutorRequestQueue:
|
||||
self.request_queue.put(
|
||||
RequestQueueItem(req_id, is_canceled_request=True))
|
||||
|
||||
def enqueue_control_request(self):
|
||||
with self.enqueue_lock:
|
||||
self.request_queue.put(RequestQueueItem(id=CONTROL_REQUEST_ID))
|
||||
|
||||
def enqueue_shutdown_request(self):
|
||||
with self.enqueue_lock:
|
||||
self.request_queue.put(RequestQueueItem(SHUTDOWN_REQUEST_ID))
|
||||
@ -268,6 +280,10 @@ class ExecutorRequestQueue:
|
||||
all_ranks_num_active_requests: Optional[List[int]] = None
|
||||
) -> List[RequestQueueItem]:
|
||||
"""Common logic for fetching and processing requests from the queue."""
|
||||
# Block new request processing while control requests are pending.
|
||||
# Control requests must be handled exclusively to ensure proper synchronization.
|
||||
if len(self.control_requests) != 0:
|
||||
return []
|
||||
# Calculate timeout
|
||||
idle = (total_num_active_requests == 0) and len(self.waiting_queue) == 0
|
||||
if idle:
|
||||
@ -281,7 +297,13 @@ class ExecutorRequestQueue:
|
||||
# Fetch requests from rank 0
|
||||
new_requests = []
|
||||
if self.dist.rank == 0:
|
||||
new_requests = self._get_from_request_queue(timeout)
|
||||
# Process accumulated requests that were queued during control request handling.
|
||||
if len(self.request_accumulated) != 0:
|
||||
new_requests.extend(self.request_accumulated)
|
||||
self.request_accumulated.clear()
|
||||
# Reset timeout to 0 to avoid hanging when no new requests are available
|
||||
timeout = datetime.timedelta(0)
|
||||
new_requests.extend(self._get_from_request_queue(timeout))
|
||||
|
||||
# Broadcast requests and handle Python objects
|
||||
new_requests, py_request_objects = self._handle_request_broadcasting(
|
||||
@ -465,12 +487,17 @@ class ExecutorRequestQueue:
|
||||
new_requests: List[RequestQueueItem]) -> List[RequestQueueItem]:
|
||||
"""Validate and filter requests, handling shutdown signals."""
|
||||
valid_new_requests = []
|
||||
for req_item in new_requests:
|
||||
for idx, req_item in enumerate(new_requests):
|
||||
if req_item.is_shutdown_request:
|
||||
self.is_shutdown = True
|
||||
break
|
||||
elif req_item.is_canceled_request:
|
||||
self.canceled_req_ids.append(req_item.id)
|
||||
elif req_item.is_control_request:
|
||||
self.control_requests.append(req_item)
|
||||
if self.dist.rank == 0:
|
||||
self.request_accumulated.extend(new_requests[idx + 1:])
|
||||
break
|
||||
else:
|
||||
valid_new_requests.append(req_item)
|
||||
|
||||
|
||||
@ -128,7 +128,8 @@ class PyExecutor:
|
||||
start_worker: bool = True,
|
||||
kv_connector_manager: Optional[KvCacheConnectorManager] = None,
|
||||
max_seq_len: Optional[int] = None,
|
||||
peft_cache_config: Optional[PeftCacheConfig] = None):
|
||||
peft_cache_config: Optional[PeftCacheConfig] = None,
|
||||
virtual_memory_pools: Optional[dict] = None):
|
||||
super(PyExecutor, self).__init__()
|
||||
self.device_id = torch.cuda.current_device()
|
||||
self.global_rank = dist.rank
|
||||
@ -151,6 +152,7 @@ class PyExecutor:
|
||||
self.guided_decoder = guided_decoder
|
||||
self.dist = dist
|
||||
self.disable_overlap_scheduler = disable_overlap_scheduler
|
||||
self.virtual_memory_pools = virtual_memory_pools
|
||||
|
||||
# enqueue and _fetch_new_requests used data
|
||||
self.active = True
|
||||
@ -256,6 +258,8 @@ class PyExecutor:
|
||||
)
|
||||
self.executor_request_queue.set_exclude_last_generation_logits(
|
||||
self.disable_overlap_scheduler, self.dist.pp_size)
|
||||
self.control_request_barrier = threading.Event()
|
||||
self.control_action_done = threading.Event()
|
||||
|
||||
self.stats_lock = threading.Lock()
|
||||
self.stats = []
|
||||
@ -441,6 +445,10 @@ class PyExecutor:
|
||||
del self.model_engine
|
||||
if self.draft_model_engine is not None:
|
||||
del self.draft_model_engine
|
||||
if self.virtual_memory_pools is not None:
|
||||
keys = list(self.virtual_memory_pools.keys())
|
||||
for key in keys:
|
||||
del self.virtual_memory_pools[key]
|
||||
|
||||
def can_enqueue_requests(self) -> bool:
|
||||
"""
|
||||
@ -788,6 +796,8 @@ class PyExecutor:
|
||||
if self.should_stop_processing:
|
||||
break
|
||||
|
||||
self._handle_control_request()
|
||||
|
||||
if self.kv_cache_transceiver:
|
||||
self._check_disagg_gen_transfer_status()
|
||||
|
||||
@ -1124,6 +1134,8 @@ class PyExecutor:
|
||||
iter_start_time = time.time()
|
||||
|
||||
scheduled_batch, iter_stats = self._prepare_and_schedule_batch()
|
||||
self._handle_control_request()
|
||||
|
||||
if scheduled_batch is None:
|
||||
break
|
||||
|
||||
@ -1257,6 +1269,47 @@ class PyExecutor:
|
||||
logger.error(f"Encountered an error in decode: {error_msg}")
|
||||
self._handle_errors(error_msg)
|
||||
|
||||
def _handle_control_request(self):
|
||||
if len(self.active_requests) == 0 and \
|
||||
self.executor_request_queue.get_waiting_queue_size() == 0 and \
|
||||
len(self.executor_request_queue.control_requests) > 0:
|
||||
assert len(self.executor_request_queue.control_requests) == 1, (
|
||||
f"Expected exactly one control request to be processed at a time, "
|
||||
f"but found {len(self.executor_request_queue.control_requests)} control requests. "
|
||||
f"This may indicate a race condition or improper control request handling."
|
||||
)
|
||||
self.executor_request_queue.control_requests.pop(0)
|
||||
self.control_request_barrier.set()
|
||||
self.control_action_done.wait()
|
||||
self.control_action_done.clear()
|
||||
|
||||
@contextmanager
|
||||
def control_action(self):
|
||||
"""
|
||||
Context manager for synchronized control actions.
|
||||
|
||||
Usage:
|
||||
with control_action():
|
||||
# Eventloop thread has finished all previous requests and paused
|
||||
do some actions here
|
||||
# Eventloop thread resumes automatically after exiting
|
||||
"""
|
||||
|
||||
if self.dist.rank == 0:
|
||||
self.executor_request_queue.enqueue_control_request()
|
||||
|
||||
# Wait for worker to finish all previous requests
|
||||
self.control_request_barrier.wait()
|
||||
|
||||
try:
|
||||
# Yield control to the with block
|
||||
# Worker is now paused, safe to execute actions
|
||||
yield self
|
||||
finally:
|
||||
# Cleanup: signal worker to resume
|
||||
self.control_action_done.set()
|
||||
self.control_request_barrier.clear()
|
||||
|
||||
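The handshake between control_action() and _handle_control_request() is a two-event rendezvous: the caller blocks on control_request_barrier until the loop has drained its work, and the loop blocks on control_action_done until the caller finishes. A self-contained sketch of that pattern with simplified names (not the executor itself):

import threading
from contextlib import contextmanager

barrier = threading.Event()   # set by the loop once it is idle at a safe point
done = threading.Event()      # set by the caller when its action has finished


def loop_step(pending_control: bool) -> None:
    # Mirrors _handle_control_request(): yield only when nothing is in flight.
    if pending_control:
        barrier.set()         # unblock the caller
        done.wait()           # pause the loop until the caller is done
        done.clear()


@contextmanager
def control_action():
    # Mirrors PyExecutor.control_action(): wait for the loop, act, then resume it.
    barrier.wait()
    try:
        yield
    finally:
        done.set()
        barrier.clear()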
def _executor_loop_overlap(self):
|
||||
torch.cuda.set_device(self.device_id)
|
||||
# ensure the context is created, otherwise, some MPI calls will fail.
|
||||
@ -1271,6 +1324,8 @@ class PyExecutor:
|
||||
iter_start_time = time.time()
|
||||
|
||||
scheduled_batch, iter_stats = self._prepare_and_schedule_batch()
|
||||
self._handle_control_request()
|
||||
|
||||
if scheduled_batch is None:
|
||||
break
|
||||
# In gen-only benchmarking mode, wait until the number of scheduled generation
|
||||
@ -2376,6 +2431,9 @@ class PyExecutor:
|
||||
if hasattr(self.drafter, "guided_decoder"):
|
||||
self.guided_decoder.rollback_draft_tokens()
|
||||
|
||||
def reset_prefix_cache(self):
|
||||
self.kv_cache_manager.reset_reuse_state()
|
||||
|
||||
|
||||
class DisaggPPTerminationHandler:
|
||||
"""Handles termination synchronization across pipeline parallel ranks under disaggregated serving.
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import copy
|
||||
import enum
|
||||
import importlib
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
@ -29,6 +28,8 @@ from ..attention_backend.interface import AttentionRuntimeFeatures
|
||||
from ..distributed import MPIDist, TorchDist
|
||||
from ..speculative import (get_num_extra_kv_tokens, get_spec_drafter,
|
||||
get_spec_resource_manager)
|
||||
from ..virtual_memory import ExecutorMemoryType, RestoreMode
|
||||
from ..virtual_memory import scope as virtual_memory_scope
|
||||
from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
|
||||
create_py_executor_instance, instantiate_sampler, is_mla,
|
||||
validate_feature_combination)
|
||||
@ -40,25 +41,12 @@ from .model_engine import PyTorchModelEngine
|
||||
from .py_executor import PyExecutor
|
||||
|
||||
|
||||
class _ExecutorCreationStage(enum.Enum):
|
||||
SAMPLER = "Sampler"
|
||||
DRAFTER = "Drafter"
|
||||
GUIDED_DECODER = "Guided decoder"
|
||||
INIT_KV_CACHE = "Initial KV cache (temporary for KV cache size estimation)"
|
||||
INIT_EXTRA_RESOURCES = "Additional executor resources (temporary for KV cache size estimation)"
|
||||
MODEL_EXTRA = "Model resources created during usage"
|
||||
EXTRA_RESOURCES = "Additional executor resources"
|
||||
KV_CACHE = "KV cache"
|
||||
MODEL_ENGINE_MAIN = "Model"
|
||||
MODEL_ENGINE_DRAFT = "Draft model for speculative decoding"
|
||||
|
||||
|
||||
class _ExecutorMemoryMonitor():
|
||||
class _ExecutorMemoryMonitor:
|
||||
"""Currently this focuses on tracking memory usage and related errors."""
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _GpuMemoryUsageSample:
|
||||
creation_stage: _ExecutorCreationStage
|
||||
creation_stage: ExecutorMemoryType
|
||||
free_gpu_memory_bytes_pre: int
|
||||
free_gpu_memory_bytes_post: int
|
||||
|
||||
@ -70,47 +58,74 @@ class _ExecutorMemoryMonitor():
|
||||
def _bytes_to_gib(bytes: int) -> float:
|
||||
return bytes / (1024)**3
|
||||
|
||||
memory_type_friendly_names = {
|
||||
ExecutorMemoryType.SAMPLER:
|
||||
"Sampler",
|
||||
ExecutorMemoryType.DRAFTER:
|
||||
"Drafter",
|
||||
ExecutorMemoryType.GUIDED_DECODER:
|
||||
"Guided Decoder",
|
||||
ExecutorMemoryType.SPEC_RESOURCES:
|
||||
"Speculative decoding resources",
|
||||
ExecutorMemoryType.INIT_KV_CACHE:
|
||||
"Initial KV Cache (temporary for KV cache size estimation)",
|
||||
ExecutorMemoryType.INIT_EXTRA_RESOURCES:
|
||||
"Additional executor resources (temporary for KV cache size estimation)",
|
||||
ExecutorMemoryType.MODEL_EXTRA:
|
||||
"Model resources created during usage",
|
||||
ExecutorMemoryType.EXTRA_RESOURCES:
|
||||
"Additional executor resources",
|
||||
ExecutorMemoryType.KV_CACHE:
|
||||
"KV cache",
|
||||
ExecutorMemoryType.MODEL_ENGINE_MAIN:
|
||||
"Model",
|
||||
ExecutorMemoryType.MODEL_ENGINE_DRAFT:
|
||||
"Draft model for speculative decoding",
|
||||
}
|
||||
|
||||
# Suggestion to reduce component memory usage
|
||||
memory_type_tuning_suggestion = {
|
||||
ExecutorMemoryType.SAMPLER:
|
||||
"reduce max_seq_len and/or max_attention_window_size",
|
||||
ExecutorMemoryType.DRAFTER:
|
||||
"reduce max_seq_len and/or max_draft_len",
|
||||
ExecutorMemoryType.SPEC_RESOURCES:
|
||||
"reduce max_seq_len and/or max_batch_size",
|
||||
ExecutorMemoryType.KV_CACHE:
|
||||
"reduce free_gpu_memory_fraction",
|
||||
ExecutorMemoryType.INIT_KV_CACHE:
|
||||
"reduce max_num_tokens",
|
||||
ExecutorMemoryType.MODEL_ENGINE_MAIN:
|
||||
("reduce max_num_tokens and/or shard the model weights across GPUs by enabling "
|
||||
"pipeline and/or tensor parallelism"),
|
||||
ExecutorMemoryType.MODEL_ENGINE_DRAFT:
|
||||
("reduce max_num_tokens and/or shard the model weights across GPUs by enabling "
|
||||
"pipeline and/or tensor parallelism"),
|
||||
ExecutorMemoryType.INIT_EXTRA_RESOURCES:
|
||||
"reduce max_num_tokens",
|
||||
ExecutorMemoryType.EXTRA_RESOURCES:
|
||||
"reduce max_num_tokens",
|
||||
ExecutorMemoryType.MODEL_EXTRA:
|
||||
"reduce max_num_tokens",
|
||||
}
|
||||
|
||||
def _maybe_explain_if_oom(self, e: Exception, *,
|
||||
current_stage: _ExecutorCreationStage,
|
||||
current_stage: ExecutorMemoryType,
|
||||
free_gpu_memory_bytes_pre: int) -> Optional[str]:
|
||||
if isinstance(e, torch.OutOfMemoryError) or "out of memory" in str(e):
|
||||
msg = "Executor creation failed due to insufficient GPU memory."
|
||||
elif (isinstance(e, RuntimeError) and "Failed, NCCL error" in str(e)
|
||||
and "unhandled cuda error (run with NCCL_DEBUG=INFO for details)"
|
||||
in str(e)):
|
||||
msg = (f"Executor creation failed with NCCL error: {str(e)}")
|
||||
msg = f"Executor creation failed with NCCL error: {str(e)}"
|
||||
return msg
|
||||
else:
|
||||
return None
|
||||
|
||||
# how to reduce component memory usage
|
||||
tuning_knobs = {
|
||||
_ExecutorCreationStage.SAMPLER:
|
||||
"reduce max_seq_len and/or max_attention_window_size",
|
||||
_ExecutorCreationStage.DRAFTER:
|
||||
"reduce max_seq_len and/or max_draft_len",
|
||||
_ExecutorCreationStage.KV_CACHE:
|
||||
"reduce free_gpu_memory_fraction",
|
||||
_ExecutorCreationStage.INIT_KV_CACHE:
|
||||
"reduce max_num_tokens",
|
||||
_ExecutorCreationStage.MODEL_ENGINE_MAIN:
|
||||
("reduce max_num_tokens and/or shard the model weights across GPUs by enabling "
|
||||
"pipeline and/or tensor parallelism"),
|
||||
_ExecutorCreationStage.MODEL_ENGINE_DRAFT:
|
||||
("reduce max_num_tokens and/or shard the model weights across GPUs by enabling "
|
||||
"pipeline and/or tensor parallelism"),
|
||||
_ExecutorCreationStage.INIT_EXTRA_RESOURCES:
|
||||
"reduce max_num_tokens",
|
||||
_ExecutorCreationStage.EXTRA_RESOURCES:
|
||||
"reduce max_num_tokens",
|
||||
_ExecutorCreationStage.MODEL_EXTRA:
|
||||
"reduce max_num_tokens",
|
||||
}
|
||||
|
||||
msg = "\n".join([
|
||||
msg,
|
||||
"",
|
||||
f"The following component could not be created: {current_stage.value}",
|
||||
f"The following component could not be created: {self.memory_type_friendly_names[current_stage]}",
|
||||
f"Total GPU memory (GiB): {self._bytes_to_gib(self._total_gpu_memory_bytes):.2f}",
|
||||
f"Free GPU memory before component creation attempt (GiB): {self._bytes_to_gib(free_gpu_memory_bytes_pre):.2f}",
|
||||
"",
|
||||
@ -122,15 +137,15 @@ class _ExecutorMemoryMonitor():
|
||||
("Please refer to the TensorRT LLM documentation for information on how "
|
||||
"to control the memory usage through TensorRT LLM configuration options. "
|
||||
"Possible options include:"),
|
||||
*(f" {stage.value}: {tuning_knobs[stage]}"
|
||||
*(f" {stage.value}: {self.memory_type_tuning_suggestion[stage]}"
|
||||
for stage in chain((sample.creation_stage
|
||||
for sample in self._samples), [current_stage])
|
||||
if stage in tuning_knobs),
|
||||
if stage in self.memory_type_tuning_suggestion),
|
||||
])
|
||||
return msg
|
||||
|
||||
@contextmanager
|
||||
def observe_creation_stage(self, current_stage: _ExecutorCreationStage):
|
||||
def observe_creation_stage(self, current_stage: ExecutorMemoryType):
|
||||
"""Catches OOM and prints instructive message."""
|
||||
|
||||
free_gpu_memory_bytes_pre = torch.cuda.mem_get_info()[0]
|
||||
@ -295,6 +310,9 @@ def create_py_executor(
|
||||
else:
|
||||
dist = MPIDist(mapping=mapping)
|
||||
|
||||
vm_pools = {}
|
||||
enable_sleep = pytorch_backend_config.enable_sleep
|
||||
|
||||
cache_transceiver_config = llm_args.cache_transceiver_config
|
||||
|
||||
has_draft_model_engine = False
|
||||
@ -313,8 +331,23 @@ def create_py_executor(
|
||||
logger.info("ATTENTION RUNTIME FEATURES: ", attn_runtime_features)
|
||||
|
||||
mem_monitor = _ExecutorMemoryMonitor()
|
||||
with mem_monitor.observe_creation_stage(
|
||||
_ExecutorCreationStage.MODEL_ENGINE_MAIN):
|
||||
|
||||
@contextmanager
|
||||
def allocation_scope(current_stage: ExecutorMemoryType,
|
||||
restore_mode: RestoreMode):
|
||||
with mem_monitor.observe_creation_stage(current_stage):
|
||||
stage = current_stage.value
|
||||
if not enable_sleep or stage.startswith("_no_capture"):
|
||||
yield
|
||||
else:
|
||||
with virtual_memory_scope(stage, restore_mode) as memory_pool:
|
||||
if stage in vm_pools:
|
||||
del vm_pools[stage]
|
||||
vm_pools[stage] = memory_pool
|
||||
yield
|
||||
|
||||
with allocation_scope(ExecutorMemoryType.MODEL_ENGINE_MAIN,
|
||||
RestoreMode.PINNED):
|
||||
model_engine = PyTorchModelEngine(
|
||||
model_path=checkpoint_dir,
|
||||
llm_args=llm_args,
|
||||
@ -328,8 +361,8 @@ def create_py_executor(
|
||||
pytorch_backend_config.sampler_type)
|
||||
|
||||
if has_draft_model_engine:
|
||||
with mem_monitor.observe_creation_stage(
|
||||
_ExecutorCreationStage.MODEL_ENGINE_DRAFT):
|
||||
with allocation_scope(ExecutorMemoryType.MODEL_ENGINE_DRAFT,
|
||||
RestoreMode.PINNED):
|
||||
draft_spec_config = copy.copy(spec_config)
|
||||
|
||||
use_chain_drafter = (
|
||||
@ -447,10 +480,10 @@ def create_py_executor(
|
||||
else:
|
||||
ctx_chunk_config = None
|
||||
|
||||
with mem_monitor.observe_creation_stage(
|
||||
_ExecutorCreationStage.GUIDED_DECODER):
|
||||
guided_decoder: Optional[GuidedDecoder] = None
|
||||
if guided_decoding_config is not None:
|
||||
guided_decoder: Optional[GuidedDecoder] = None
|
||||
if guided_decoding_config is not None:
|
||||
with allocation_scope(ExecutorMemoryType.GUIDED_DECODER,
|
||||
RestoreMode.PINNED):
|
||||
if mapping.is_last_pp_rank():
|
||||
kwargs = {
|
||||
"guided_decoding_config": guided_decoding_config,
|
||||
@ -479,7 +512,7 @@ def create_py_executor(
|
||||
f"Guided decoding is not supported for speculative decoding mode: {spec_config.spec_dec_mode.name}."
|
||||
)
|
||||
|
||||
with mem_monitor.observe_creation_stage(_ExecutorCreationStage.SAMPLER):
|
||||
with allocation_scope(ExecutorMemoryType.SAMPLER, RestoreMode.PINNED):
|
||||
sampler = instantiate_sampler(model_engine,
|
||||
pytorch_backend_config,
|
||||
mapping,
|
||||
@ -565,9 +598,9 @@ def create_py_executor(
|
||||
sparse_attention_config=sparse_attention_config,
|
||||
)
|
||||
estimating_kv_cache = kv_cache_creator.try_prepare_estimation()
|
||||
with mem_monitor.observe_creation_stage(
|
||||
_ExecutorCreationStage.INIT_KV_CACHE
|
||||
if estimating_kv_cache else _ExecutorCreationStage.KV_CACHE):
|
||||
with allocation_scope(
|
||||
ExecutorMemoryType.INIT_KV_CACHE if estimating_kv_cache else
|
||||
ExecutorMemoryType.KV_CACHE, RestoreMode.NONE):
|
||||
kv_cache_creator.build_managers(resources, estimating_kv_cache)
|
||||
# Originally, max_seq_len might be mutated inside build_managers as field of executor config.
|
||||
# Since now, we are changing kv_cache_creator._max_seq_len instead. Restore max_seq_len here.
|
||||
@ -577,23 +610,26 @@ def create_py_executor(
|
||||
# Resource managers for speculative decoding
|
||||
# For user-specified drafters, use extra_resource_managers in PyTorchBackend config
|
||||
# to provide a resource manager if required.
|
||||
spec_resource_manager = get_spec_resource_manager(model_engine,
|
||||
draft_model_engine)
|
||||
|
||||
with allocation_scope(ExecutorMemoryType.SPEC_RESOURCES,
|
||||
RestoreMode.PINNED):
|
||||
spec_resource_manager = get_spec_resource_manager(
|
||||
model_engine, draft_model_engine)
|
||||
if spec_resource_manager is not None:
|
||||
resources[
|
||||
ResourceManagerType.SPEC_RESOURCE_MANAGER] = spec_resource_manager
|
||||
|
||||
# Drafter for speculative decoding
|
||||
with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER):
|
||||
with allocation_scope(ExecutorMemoryType.DRAFTER, RestoreMode.PINNED):
|
||||
drafter = get_spec_drafter(model_engine,
|
||||
draft_model_engine,
|
||||
sampler,
|
||||
spec_resource_manager=spec_resource_manager,
|
||||
guided_decoder=guided_decoder)
|
||||
|
||||
with mem_monitor.observe_creation_stage(
|
||||
_ExecutorCreationStage.INIT_EXTRA_RESOURCES
|
||||
if estimating_kv_cache else _ExecutorCreationStage.EXTRA_RESOURCES):
|
||||
with allocation_scope(
|
||||
ExecutorMemoryType.INIT_EXTRA_RESOURCES if estimating_kv_cache else
|
||||
ExecutorMemoryType.EXTRA_RESOURCES, RestoreMode.PINNED):
|
||||
py_executor = create_py_executor_instance(
|
||||
dist=dist,
|
||||
resources=resources,
|
||||
@ -616,6 +652,7 @@ def create_py_executor(
|
||||
peft_cache_config=peft_cache_config,
|
||||
scheduler_config=scheduler_config,
|
||||
cache_transceiver_config=cache_transceiver_config,
|
||||
virtual_memory_pools=vm_pools if not estimating_kv_cache else None,
|
||||
)
|
||||
# Originally, peft_cache_config might be mutated inside
|
||||
# create_py_executor_instance. Restore it here.
|
||||
@ -623,14 +660,13 @@ def create_py_executor(
|
||||
|
||||
if estimating_kv_cache:
|
||||
assert kv_cache_creator is not None
|
||||
with mem_monitor.observe_creation_stage(
|
||||
_ExecutorCreationStage.MODEL_EXTRA):
|
||||
with allocation_scope(ExecutorMemoryType.MODEL_EXTRA,
|
||||
RestoreMode.PINNED):
|
||||
kv_cache_creator.configure_kv_cache_capacity(py_executor)
|
||||
kv_cache_creator.teardown_managers(resources)
|
||||
del py_executor # free before constructing new
|
||||
|
||||
with mem_monitor.observe_creation_stage(
|
||||
_ExecutorCreationStage.KV_CACHE):
|
||||
with allocation_scope(ExecutorMemoryType.KV_CACHE, RestoreMode.NONE):
|
||||
# Before estimating KV cache size, a minimal KV cache has been allocated using
|
||||
# create_kv_cache_manager above, which caps kv_cache_creator.max_seq_len. Restoring
|
||||
# the original value before creating the final KV cache.
|
||||
@ -649,8 +685,8 @@ def create_py_executor(
|
||||
eng._release_cuda_graphs()
|
||||
eng.attn_metadata = None
|
||||
|
||||
with mem_monitor.observe_creation_stage(
|
||||
_ExecutorCreationStage.EXTRA_RESOURCES):
|
||||
with allocation_scope(ExecutorMemoryType.EXTRA_RESOURCES,
|
||||
RestoreMode.PINNED):
|
||||
py_executor = create_py_executor_instance(
|
||||
dist=dist,
|
||||
resources=resources,
|
||||
@ -673,6 +709,7 @@ def create_py_executor(
|
||||
peft_cache_config=peft_cache_config,
|
||||
scheduler_config=scheduler_config,
|
||||
cache_transceiver_config=cache_transceiver_config,
|
||||
virtual_memory_pools=vm_pools,
|
||||
)
|
||||
|
||||
_adjust_torch_mem_fraction(pytorch_backend_config)
|
||||
|
||||
@ -1164,6 +1164,10 @@ class KVCacheManager(BaseResourceManager):
|
||||
else:
|
||||
return None
|
||||
|
||||
def reset_reuse_state(self):
|
||||
"""Reset the reuse state of the KV cache manager."""
|
||||
self.impl.reset_reuse_state()
|
||||
|
||||
|
||||
class SlotManager:
|
||||
|
||||
|
||||
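The new unit test further down exercises this path directly; condensed from that test (variable names come from it), the essential check is that after reset_reuse_state() an identical prompt no longer increases reused_blocks:

kv_cache_manager.reset_reuse_state()                       # drop the block-reuse tree
baseline = kv_cache_manager.get_kv_cache_stats().reused_blocks
# ... submit the same prompt again via add_sequence() ...
assert kv_cache_manager.get_kv_cache_stats().reused_blocks == baseline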
@@ -1,8 +1,9 @@
import functools
from contextlib import contextmanager
from typing import Generator
from typing import Generator, List

import torch
from strenum import StrEnum

from tensorrt_llm.bindings.internal.runtime import \
    CudaVirtualMemoryAllocatorRestoreMode as RestoreMode

@@ -66,6 +67,32 @@ def maybe_scope(
        yield


class ExecutorMemoryType(StrEnum):
    SAMPLER = "sampler"
    DRAFTER = "drafter"
    GUIDED_DECODER = "guided_decoder"
    SPEC_RESOURCES = "spec_resource_manager"
    INIT_KV_CACHE = "_no_capture_init_kv_cache"
    INIT_EXTRA_RESOURCES = "_no_capture_init_extra_resources"
    MODEL_EXTRA = "_no_capture_model_extra"  # TODO: remove _no_capture once torch fixes the crash in torch.cuda.empty_cache()
    EXTRA_RESOURCES = "executor_extra"
    KV_CACHE = "kv_cache"
    MODEL_ENGINE_MAIN = "model"
    MODEL_ENGINE_DRAFT = "draft_model"


def verify_sleep_wakeup_tags(tags_strs: List[str]) -> List[ExecutorMemoryType]:
    tags = []
    for tag_str in tags_strs:
        try:
            tags.append(ExecutorMemoryType(tag_str))
        except ValueError:
            raise ValueError(
                f"Unknown memory tag '{tag_str}'. "
                f"Valid tags are: {[t.value for t in ExecutorMemoryType]}")
    return tags


def release_with_tag(*tags: str) -> int:
    """Release virtual memory allocated with given tags
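verify_sleep_wakeup_tags maps user-facing tag strings to ExecutorMemoryType members and rejects unknown names; release_with_tag and materialize_with_tag then act on the matching virtual-memory pools. A minimal sketch, assuming the executor was created with enable_sleep=True so pools exist for these tags:

from tensorrt_llm._torch.virtual_memory import (materialize_with_tag,
                                                release_with_tag,
                                                verify_sleep_wakeup_tags)

tags = verify_sleep_wakeup_tags(["model", "draft_model"])  # -> [MODEL_ENGINE_MAIN, MODEL_ENGINE_DRAFT]
release_with_tag(*tags)        # unmap the physical memory behind the tagged pools
materialize_with_tag(*tags)    # map physical memory back; contents restored per the pool's RestoreMode

# verify_sleep_wakeup_tags(["weights"])  # would raise ValueError: unknown memory tag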
@@ -2,11 +2,16 @@ import importlib
import os
from pathlib import Path
from queue import Queue
from typing import Any, Optional, Type, Union
from typing import Any, List, Optional, Type, Union

import ray
import torch

from tensorrt_llm._ray_utils import control_action_decorator
from tensorrt_llm._torch.virtual_memory import (materialize_with_tag,
                                                release_with_tag,
                                                verify_sleep_wakeup_tags)

from ..bindings import executor as tllm
from ..builder import Engine
from ..llmapi.llm_args import BaseLlmArgs

@@ -230,6 +235,38 @@ class RayGPUWorker(BaseWorker):

        logger.debug(f"Worker {self.rank} shutdown done.")

    @control_action_decorator
    def sleep(self, sleep_tags: List[str]):
        if not self.llm_args.enable_sleep:
            raise ValueError(
                "Sleep feature is not enabled, please set enable_sleep=True in the LLM arguments."
            )
        try:
            tags = verify_sleep_wakeup_tags(sleep_tags)
            logger.info(f"Sleep: {tags}")
            torch.cuda.synchronize()
            release_with_tag(*tags)
            torch.cuda.synchronize()
        except Exception as e:
            logger.error(f"Encountered an error in sleep: {e}")
            raise e

    @control_action_decorator
    def wakeup(self, wakeup_tags: List[str]):
        if not self.llm_args.enable_sleep:
            raise ValueError(
                "Sleep feature is not enabled, please set enable_sleep=True in the LLM arguments."
            )
        try:
            tags = verify_sleep_wakeup_tags(wakeup_tags)
            logger.info(f"Wakeup: {tags}")
            torch.cuda.synchronize()
            materialize_with_tag(*tags)
            torch.cuda.synchronize()
        except Exception as e:
            logger.error(f"Encountered an error in wakeup: {e}")
            raise e

    def __enter__(self):
        return self
@ -2587,6 +2587,13 @@ class TorchLlmArgs(BaseLlmArgs):
|
||||
"Allows users to extend the functions of the RayGPUWorker class.",
|
||||
status="prototype")
|
||||
|
||||
enable_sleep: bool = Field(
|
||||
default=False,
|
||||
description=
|
||||
"Enable LLM sleep feature. Sleep feature requires extra setup that may slowdown model loading."
|
||||
"Only enable it if you intend to use this feature.",
|
||||
status="prototype")
|
||||
|
||||
# PrivateVars
|
||||
_quant_config: Optional[QuantConfig] = PrivateAttr(default=None)
|
||||
|
||||
@ -2901,6 +2908,7 @@ class TorchLlmArgs(BaseLlmArgs):
|
||||
batch_wait_timeout_ms=self.batch_wait_timeout_ms,
|
||||
batch_wait_timeout_iters=self.batch_wait_timeout_iters,
|
||||
batch_wait_max_tokens_ratio=self.batch_wait_max_tokens_ratio,
|
||||
enable_sleep=self.enable_sleep,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -22,6 +22,7 @@ from tensorrt_llm.bindings.internal.batch_manager import \
|
||||
from tensorrt_llm.llmapi.llm_args import KvCacheConfig, PeftCacheConfig
|
||||
from tensorrt_llm.lora_helper import LoraConfig
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
from tensorrt_llm.sampling_params import SamplingParams
|
||||
|
||||
DataType = tensorrt_llm.bindings.DataType
|
||||
LoraModule = tensorrt_llm.bindings.LoraModule
|
||||
@ -689,6 +690,80 @@ class TestResourceManager(unittest.TestCase):
|
||||
finally:
|
||||
manager.shutdown()
|
||||
|
||||
@staticmethod
|
||||
def create_llm_request(id, input_tokens, new_tokens=1):
|
||||
sampling_params = SamplingParams()
|
||||
req = LlmRequest(request_id=id,
|
||||
max_new_tokens=new_tokens,
|
||||
input_tokens=input_tokens,
|
||||
sampling_config=tensorrt_llm.bindings.SamplingConfig(
|
||||
sampling_params._get_sampling_config()),
|
||||
is_streaming=False)
|
||||
return req
|
||||
|
||||
def test_kv_cache_reset_reuse_state(self):
|
||||
|
||||
global_kvcache_config = KvCacheConfig(free_gpu_memory_fraction=0.4,
|
||||
event_buffer_max_size=1024,
|
||||
enable_block_reuse=True,
|
||||
onboard_blocks=True,
|
||||
max_tokens=256)
|
||||
|
||||
kv_cache_manager = KVCacheManager(
|
||||
kv_cache_config=global_kvcache_config,
|
||||
kv_cache_type=tensorrt_llm.bindings.internal.batch_manager.
|
||||
CacheType.SELF,
|
||||
num_layers=2,
|
||||
num_kv_heads=2,
|
||||
head_dim=128,
|
||||
tokens_per_block=64,
|
||||
max_seq_len=1024,
|
||||
max_batch_size=1,
|
||||
mapping=Mapping(),
|
||||
)
|
||||
|
||||
# First request: Add sequence and store blocks for reuse
|
||||
req1 = self.create_llm_request(0, [1, 2, 3, 4, 5])
|
||||
kv_cache_manager.impl.add_sequence(req1.py_request_id, req1.prompt_len,
|
||||
1, req1)
|
||||
|
||||
stats_initial = kv_cache_manager.get_kv_cache_stats()
|
||||
initial_reused_blocks = stats_initial.reused_blocks
|
||||
|
||||
kv_cache_manager.free_resources(req1)
|
||||
|
||||
# Second request with same tokens - should reuse blocks from the reuse tree
|
||||
req2 = self.create_llm_request(1, [1, 2, 3, 4, 5])
|
||||
kv_cache_manager.impl.add_sequence(req2.py_request_id, req2.prompt_len,
|
||||
1, req2)
|
||||
|
||||
stats_after_reuse = kv_cache_manager.get_kv_cache_stats()
|
||||
self.assertGreater(
|
||||
stats_after_reuse.reused_blocks, initial_reused_blocks,
|
||||
f"Second request should reuse blocks. "
|
||||
f"reused_blocks before: {initial_reused_blocks}, after: {stats_after_reuse.reused_blocks}"
|
||||
)
|
||||
|
||||
kv_cache_manager.free_resources(req2)
|
||||
|
||||
# Reset reuse state
|
||||
kv_cache_manager.reset_reuse_state()
|
||||
stats_after_reset = kv_cache_manager.get_kv_cache_stats()
|
||||
reused_blocks_after_reset = stats_after_reset.reused_blocks
|
||||
|
||||
# Third request with same tokens - should NOT reuse blocks after reset
|
||||
req3 = self.create_llm_request(2, [1, 2, 3, 4, 5])
|
||||
kv_cache_manager.impl.add_sequence(req3.py_request_id, req3.prompt_len,
|
||||
1, req3)
|
||||
|
||||
stats_after_third = kv_cache_manager.get_kv_cache_stats()
|
||||
self.assertEqual(
|
||||
stats_after_third.reused_blocks, reused_blocks_after_reset,
|
||||
f"Third request should NOT reuse blocks after reset. "
|
||||
f"reused_blocks after reset: {reused_blocks_after_reset}, after third request: {stats_after_third.reused_blocks}"
|
||||
)
|
||||
kv_cache_manager.free_resources(req3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -1,10 +1,8 @@
|
||||
import gc
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import pynvml
|
||||
import pytest
|
||||
import torch
|
||||
from utils.util import get_current_process_gpu_memory
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm._torch import virtual_memory
|
||||
@ -25,56 +23,6 @@ def cuda_sync_fixture():
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def memory_info_available():
|
||||
"""
|
||||
Checks if NVML can get per-process memory information.
|
||||
"""
|
||||
|
||||
# Allocate a small tensor to test memory tracking
|
||||
tensor = torch.zeros(4096, dtype=torch.int32, device='cuda')
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Try to get memory usage
|
||||
usage = get_current_process_memory_info()
|
||||
|
||||
# Clean up
|
||||
del tensor
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if usage == 0:
|
||||
warnings.warn("Per process memory information unavailable.")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def nvml_init():
|
||||
pynvml.nvmlInit()
|
||||
|
||||
|
||||
def get_current_process_memory_info() -> int:
|
||||
"""
|
||||
Returns GPU memory usage for current process in bytes.
|
||||
"""
|
||||
# Get current process ID
|
||||
current_pid = os.getpid()
|
||||
# Get device handle for GPU 0
|
||||
device_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
|
||||
|
||||
# Get running processes
|
||||
processes = pynvml.nvmlDeviceGetComputeRunningProcesses(device_handle)
|
||||
|
||||
# Find current process
|
||||
for process in processes:
|
||||
if process.pid == current_pid:
|
||||
return process.usedGpuMemory
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def clean_cache():
|
||||
gc.collect()
|
||||
@ -84,16 +32,16 @@ def clean_cache():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def test_basic(memory_info_available):
|
||||
memory_usage_begin = get_current_process_memory_info()
|
||||
def test_basic(process_gpu_memory_info_available):
|
||||
memory_usage_begin = get_current_process_gpu_memory()
|
||||
|
||||
alloc_size = 256 * 1024 * 1024
|
||||
tag = "test_tag"
|
||||
|
||||
with virtual_memory.scope(tag) as pool:
|
||||
tensor = torch.full([alloc_size], 42, dtype=torch.int8, device='cuda')
|
||||
memory_usage_materialized = get_current_process_memory_info()
|
||||
if memory_info_available:
|
||||
memory_usage_materialized = get_current_process_gpu_memory()
|
||||
if process_gpu_memory_info_available:
|
||||
assert memory_usage_begin + alloc_size == memory_usage_materialized
|
||||
|
||||
assert tensor[0].item() == 42
|
||||
@ -101,15 +49,15 @@ def test_basic(memory_info_available):
|
||||
torch.cuda.synchronize()
|
||||
virtual_memory.release_with_tag(tag)
|
||||
|
||||
memory_usage_released = get_current_process_memory_info()
|
||||
if memory_info_available:
|
||||
memory_usage_released = get_current_process_gpu_memory()
|
||||
if process_gpu_memory_info_available:
|
||||
assert memory_usage_begin == memory_usage_released
|
||||
|
||||
torch.cuda.synchronize()
|
||||
virtual_memory.materialize_with_tag(tag)
|
||||
|
||||
memory_usage_rematerialized = get_current_process_memory_info()
|
||||
if memory_info_available:
|
||||
memory_usage_rematerialized = get_current_process_gpu_memory()
|
||||
if process_gpu_memory_info_available:
|
||||
assert memory_usage_begin + alloc_size == memory_usage_rematerialized
|
||||
|
||||
torch.fill_(tensor, 24)
|
||||
@ -118,8 +66,8 @@ def test_basic(memory_info_available):
|
||||
del tensor
|
||||
del pool
|
||||
|
||||
memory_usage_end = get_current_process_memory_info()
|
||||
if memory_info_available:
|
||||
memory_usage_end = get_current_process_gpu_memory()
|
||||
if process_gpu_memory_info_available:
|
||||
assert memory_usage_begin == memory_usage_end
|
||||
|
||||
|
||||
@ -146,7 +94,7 @@ def test_restore():
|
||||
del pool
|
||||
|
||||
|
||||
def test_kv_cache_manager(memory_info_available):
|
||||
def test_kv_cache_manager(process_gpu_memory_info_available):
|
||||
kv_cache_params = {
|
||||
"kv_cache_config": KvCacheConfig(max_tokens=1024),
|
||||
"kv_cache_type": CacheType.SELF,
|
||||
@ -164,7 +112,7 @@ def test_kv_cache_manager(memory_info_available):
|
||||
mgr.shutdown()
|
||||
del mgr
|
||||
|
||||
memory_usage_begin = get_current_process_memory_info()
|
||||
memory_usage_begin = get_current_process_gpu_memory()
|
||||
|
||||
tag = "test_tag"
|
||||
cache_size = torch.empty(
|
||||
@ -182,35 +130,35 @@ def test_kv_cache_manager(memory_info_available):
|
||||
|
||||
with virtual_memory.scope(tag) as pool:
|
||||
mgr = KVCacheManager(**kv_cache_params)
|
||||
memory_usage_materialized = get_current_process_memory_info()
|
||||
if memory_info_available:
|
||||
memory_usage_materialized = get_current_process_gpu_memory()
|
||||
if process_gpu_memory_info_available:
|
||||
assert memory_usage_begin + alloc_size == memory_usage_materialized
|
||||
|
||||
torch.cuda.synchronize()
|
||||
virtual_memory.release_with_tag(tag)
|
||||
|
||||
memory_usage_released = get_current_process_memory_info()
|
||||
if memory_info_available:
|
||||
memory_usage_released = get_current_process_gpu_memory()
|
||||
if process_gpu_memory_info_available:
|
||||
assert memory_usage_begin == memory_usage_released
|
||||
|
||||
torch.cuda.synchronize()
|
||||
virtual_memory.materialize_with_tag(tag)
|
||||
|
||||
memory_usage_rematerialized = get_current_process_memory_info()
|
||||
if memory_info_available:
|
||||
memory_usage_rematerialized = get_current_process_gpu_memory()
|
||||
if process_gpu_memory_info_available:
|
||||
assert memory_usage_begin + alloc_size == memory_usage_rematerialized
|
||||
|
||||
mgr.shutdown()
|
||||
del mgr
|
||||
del pool
|
||||
|
||||
memory_usage_end = get_current_process_memory_info()
|
||||
if memory_info_available:
|
||||
memory_usage_end = get_current_process_gpu_memory()
|
||||
if process_gpu_memory_info_available:
|
||||
assert memory_usage_begin == memory_usage_end
|
||||
|
||||
|
||||
@pytest.mark.skip("https://nvbugspro.nvidia.com/bug/5458911")
|
||||
def test_cuda_graph(memory_info_available):
|
||||
def test_cuda_graph(process_gpu_memory_info_available):
|
||||
|
||||
def work(input: torch.Tensor) -> torch.Tensor:
|
||||
intermediate = input + input
|
||||
@ -233,13 +181,13 @@ def test_cuda_graph(memory_info_available):
|
||||
torch.cuda.synchronize()
|
||||
assert static_output[0].item() == 3.0
|
||||
|
||||
memory_usage_before = get_current_process_memory_info()
|
||||
memory_usage_before = get_current_process_gpu_memory()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
virtual_memory.release_with_tag(tag)
|
||||
|
||||
memory_usage_released = get_current_process_memory_info()
|
||||
if memory_info_available:
|
||||
memory_usage_released = get_current_process_gpu_memory()
|
||||
if process_gpu_memory_info_available:
|
||||
assert memory_usage_released < memory_usage_before
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from utils.cpp_paths import llm_root # noqa: F401
|
||||
|
||||
from tensorrt_llm._utils import mpi_disabled
|
||||
|
||||
@ -21,3 +23,13 @@ if not mpi_disabled():
|
||||
pytest.skip(
|
||||
"Ray tests are only tested in Ray CI stage or with --run-ray flag",
|
||||
allow_module_level=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def add_worker_extension_path(llm_root: Path):
|
||||
worker_extension_path = str(llm_root / "examples" / "llm-api" / "rlhf")
|
||||
original_python_path = os.environ.get('PYTHONPATH', '')
|
||||
os.environ['PYTHONPATH'] = os.pathsep.join(
|
||||
filter(None, [worker_extension_path, original_python_path]))
|
||||
yield
|
||||
os.environ['PYTHONPATH'] = original_python_path
|
||||
|
||||
@ -0,0 +1,51 @@
|
||||
from utils.llm_data import llm_models_root
|
||||
from utils.util import get_current_process_gpu_memory
|
||||
|
||||
from tensorrt_llm import LLM
|
||||
from tensorrt_llm._torch.virtual_memory import ExecutorMemoryType
|
||||
from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams
|
||||
|
||||
|
||||
def test_llm_sleep(process_gpu_memory_info_available):
|
||||
llama_model_path = str(llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0")
|
||||
kv_cache_config = KvCacheConfig(enable_block_reuse=False, max_tokens=4096)
|
||||
|
||||
llm = LLM(
|
||||
model=llama_model_path,
|
||||
enable_sleep=True,
|
||||
cuda_graph_config=None, # CUDA Graph unsupported yet
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
|
||||
with llm:
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
generated_before_sleep = [output.outputs[0].text for output in outputs]
|
||||
|
||||
memory_usage_active = get_current_process_gpu_memory(True)
|
||||
|
||||
llm._collective_rpc("sleep", ([ExecutorMemoryType.MODEL_ENGINE_MAIN],))
|
||||
|
||||
memory_usage_sleep = get_current_process_gpu_memory(True)
|
||||
if process_gpu_memory_info_available:
|
||||
assert memory_usage_sleep < memory_usage_active
|
||||
|
||||
llm._collective_rpc("wakeup", ([ExecutorMemoryType.MODEL_ENGINE_MAIN],))
|
||||
|
||||
memory_usage_wakeup = get_current_process_gpu_memory(True)
|
||||
if process_gpu_memory_info_available:
|
||||
assert memory_usage_wakeup > memory_usage_sleep
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
generated_after_sleep = [output.outputs[0].text for output in outputs]
|
||||
|
||||
for before, after in zip(generated_before_sleep, generated_after_sleep, strict=True):
|
||||
assert before == after, "Generated result mismatch before and after sleep"
|
||||
@ -0,0 +1,182 @@
from typing import List

import torch
from torch.multiprocessing.reductions import reduce_tensor
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.llm_data import llm_models_root

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams


class HFModel:
    def __init__(self, model_name: str):
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16
        ).to("cuda")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.cuda_device = torch.cuda.current_device()
        self.all_weights = {}
        self.device_uuid = [HFModel.get_device_uuid(i) for i in range(torch.cuda.device_count())]

    @staticmethod
    def get_device_uuid(cuda_device: int):
        from tensorrt_llm._torch.utils import get_device_uuid

        return get_device_uuid(cuda_device)

    def flip_weights(self):
        for _, p in self.model.named_parameters():
            p.data = -p.data

        self._replicate_weights()

    def _replicate_weights(self):
        model_weights = []
        for n, p in self.model.named_parameters():
            model_weights.append((n, p.detach().clone()))

        self.all_weights[self.cuda_device] = model_weights
        for i in range(torch.cuda.device_count()):
            if i != self.cuda_device:
                cur_weights = []
                for n, p in self.all_weights[self.cuda_device]:
                    cur_weights.append((n, p.to("cuda:" + str(i))))
                self.all_weights[i] = cur_weights

    def get_weight_ipc_handles(self, cuda_device: int = None):
        ret = {}
        device_list = (
            list(range(torch.cuda.device_count())) if cuda_device is None else [cuda_device]
        )
        for device in device_list:
            all_handles = []
            for item in self.all_weights[device]:
                name, p = item
                handle = reduce_tensor(p)
                all_handles.append((name, handle))
            ret[self.device_uuid[device]] = all_handles
        return ret

    def generate_batch_incremental(
        self, original_prompts: List[str], generated_token_ids_list: List[List[int]]
    ):
        """
        Generate tokens incrementally for each prompt in the batch: [prompt, prompt+token0, prompt+token0+token1, ...]
        """
        logits_list = []

        for i in range(len(original_prompts)):
            base_token_ids = self.tokenizer.encode(original_prompts[i], return_tensors="pt")[0].to(
                "cuda"
            )
            cur_logits = []
            for j in range(len(generated_token_ids_list[i])):
                if j > 0:
                    cur_gen_tokens = torch.tensor(generated_token_ids_list[i][:j]).to("cuda")
                    cur_token_ids = torch.cat([base_token_ids, cur_gen_tokens], dim=-1)
                else:
                    cur_token_ids = base_token_ids

                ret = self.model.generate(
                    input_ids=cur_token_ids.unsqueeze(0).cuda(),
                    max_new_tokens=1,
                    return_dict_in_generate=True,
                    output_scores=True,
                )

                cur_logits.append(ret["scores"][0])
            cur_logits = torch.stack(cur_logits, dim=0)
            logits_list.append(cur_logits.squeeze(1))

        return logits_list


def extract_tokens_from_outputs(outputs):
    """Extract individual tokens from LLM outputs using token IDs directly"""
    tokens_list = []
    for output in outputs:
        # Get token IDs directly from the output
        token_ids = output.outputs[0].token_ids
        tokens_list.append(token_ids)
    return tokens_list


def compare_logits(
    logits_list: List[torch.Tensor],
    ref_logits_list: List[torch.Tensor],
    topk: int = 20,
    threshold: float = 0.85,
):
    assert len(logits_list) == len(ref_logits_list)

    for i in range(len(logits_list)):
        assert logits_list[i].shape == ref_logits_list[i].shape
        lhs_idx = torch.topk(logits_list[i], topk, dim=-1).indices
        rhs_idx = torch.topk(ref_logits_list[i], topk, dim=-1).indices
        # Token wise comparison
        ratios = []
        for j in range(lhs_idx.shape[0]):
            lhs_idx_j = lhs_idx[j].tolist()
            rhs_idx_j = rhs_idx[j].tolist()
            overlap = set(lhs_idx_j) & set(rhs_idx_j)
            ratios.append(len(overlap) / len(lhs_idx_j))

        mean_ratio = sum(ratios) / len(ratios)
        assert mean_ratio > threshold, (
            f"Prompt {i}: overlap ratio: {mean_ratio:.2%} is less than {threshold:.2%}"
        )


def run_generate(llm, hf_model, prompts, sampling_params):
    outputs = llm.generate(prompts, sampling_params)
    llm_logits = []
    for output in outputs:
        llm_logits.append(output.outputs[0].generation_logits)

    generated_token_ids_list = extract_tokens_from_outputs(outputs)
    ref_logits = hf_model.generate_batch_incremental(prompts, generated_token_ids_list)
    return llm_logits, ref_logits


def test_llm_update_weights(add_worker_extension_path):
    llama_model_path = str(llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0")
    kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)

    hf_model = HFModel(llama_model_path)

    llm = LLM(
        model=llama_model_path,
        ray_worker_extension_cls="rlhf_utils.WorkerExtension",
        tensor_parallel_size=1,
        pipeline_parallel_size=1,
        kv_cache_config=kv_cache_config,
    )

    # Generate texts from the prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    sampling_params = SamplingParams(temperature=0, return_generation_logits=True)

    results = []

    # Stage 1: Generate with original model
    results.append(run_generate(llm, hf_model, prompts, sampling_params))
    llm_logits, ref_logits = results[0]
    compare_logits(llm_logits, ref_logits)

    # Stage 2: Test update with flipped weights
    hf_model.flip_weights()
    ipc_handles = hf_model.get_weight_ipc_handles()

    llm._collective_rpc("update_weights", (ipc_handles,))
    results.append(run_generate(llm, hf_model, prompts, sampling_params))
    llm_logits, ref_logits = results[1]

    # Compare the logits for this phase since output should be random
    compare_logits(llm_logits, ref_logits)
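The sender side above packs every parameter as a (name, reduce_tensor(tensor)) pair keyed by device UUID. The receiving rlhf_utils.WorkerExtension from examples/llm-api/rlhf is not shown in this diff; the sketch below is a hypothetical receiver, assuming only the handle format produced by get_weight_ipc_handles, showing how torch.multiprocessing reduction handles are typically rebuilt in the consuming process:

def rebuild_ipc_weights(handles_for_this_device):
    # Hypothetical helper: each handle produced by reduce_tensor(p) is a
    # (rebuild_fn, args) pair; calling rebuild_fn(*args) maps the sender's
    # CUDA allocation into this process without copying the data.
    rebuilt = {}
    for name, (rebuild_fn, args) in handles_for_this_device:
        rebuilt[name] = rebuild_fn(*args)
    return rebuilt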
@ -195,6 +195,10 @@ methods:
        annotation: Optional[str]
        default: null
        status: prototype
      enable_sleep:
        annotation: bool
        default: False
        status: prototype
    return_annotation: None
  generate:
    parameters:
@ -16,6 +16,7 @@
import os
import sys
import traceback
import warnings
from functools import partial
from typing import Any

@ -24,6 +25,7 @@ import pytest
import torch
import tqdm
from mpi4py.futures import MPIPoolExecutor
from utils.util import get_current_process_gpu_memory

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from integration.defs import test_list_parser
@ -316,3 +318,28 @@ def _maybe_force_ray(request, monkeypatch, ray_mode):
                raising=False)
    except Exception:
        pass


@pytest.fixture(scope="module")
def process_gpu_memory_info_available():
    """
    Checks if NVML can get per-process memory information.
    """

    # Allocate a small tensor to test memory tracking
    tensor = torch.zeros(4096, dtype=torch.int32, device='cuda')
    torch.cuda.synchronize()

    # Try to get memory usage
    usage = get_current_process_gpu_memory()

    # Clean up
    del tensor
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

    if usage == 0:
        warnings.warn("Per process memory information unavailable.")
        return False

    return True
@ -22,6 +22,7 @@ from difflib import SequenceMatcher
from pathlib import Path
from typing import Any, Generator

import psutil
import pynvml
import pytest
import tensorrt as trt
@ -515,3 +516,33 @@ def assert_no_cuda_sync(

    sleep_ctl.cancel()
    scope_finished_event.synchronize()


_pynvmlInited = False


def get_current_process_gpu_memory(include_subprocess: bool = False) -> int:
    """
    Returns GPU memory usage for current process in bytes.
    """
    global _pynvmlInited
    if not _pynvmlInited:
        pynvml.nvmlInit()
        _pynvmlInited = True

    # Get current process ID
    targets = [os.getpid()]
    if include_subprocess:
        targets.extend(
            p.pid for p in psutil.Process(targets[0]).children(recursive=True))
    targets = frozenset(targets)

    # Get device handle for GPU 0
    device_handle = pynvml.nvmlDeviceGetHandleByIndex(0)

    # Get running processes
    processes = pynvml.nvmlDeviceGetComputeRunningProcesses(device_handle)

    # Find current process
    return sum(process.usedGpuMemory for process in processes
               if process.pid in targets)
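A short note on the design, with a hypothetical usage sketch. NVML reports memory the driver has actually reserved for each process, so memory released back to the driver (for example while an executor is asleep) is visible here, whereas allocator-level counters such as torch.cuda.memory_allocated() in the calling process would not reflect it; include_subprocess=True additionally folds in spawned worker processes.

# Hypothetical usage: measure driver-level GPU memory of this process tree
# around an action that is expected to release GPU resources.
before = get_current_process_gpu_memory(include_subprocess=True)
# ... release GPU resources here ...
after = get_current_process_gpu_memory(include_subprocess=True)
print(f"Released {(before - after) / 2**20:.1f} MiB at the driver level")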