[TRTLLM-8511][feat] Add update_weights and sleep_wakeup support for rl integration (#8302)

Signed-off-by: shuyix <219646547+shuyixiong@users.noreply.github.com>
Co-authored-by: Liwei Ma <liweim@nvidia.com>
Co-authored-by: Jonas Yang CN <joyang@nvidia.com>
shuyixiong 2025-11-05 02:19:24 +08:00 committed by GitHub
parent e2b2675120
commit 70e4d72ffa
23 changed files with 852 additions and 185 deletions

View File

@ -871,6 +871,13 @@ public:
return mIsValidStoreForReuseSequence.at(requestId);
}
void resetReuseState()
{
std::lock_guard<std::mutex> lock(mCachedBlocksRootMutex);
mCachedBlocksRoot
= std::make_shared<KVCacheBlock>(KVCacheBlock::kCachedBlocksRootId, tensorrt_llm::kernels::KVCacheIndex{0});
}
private:
//! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
@ -1347,6 +1354,14 @@ public:
return mWindowBlockManagers.at(windowSize).isSequenceValidForStoreForReuse(requestId);
}
void resetReuseState()
{
for (auto& [windowSize, manager] : mWindowBlockManagers)
{
manager.resetReuseState();
}
}
private:
[[nodiscard]] WindowBlockManager const& windowManagerByLayer(SizeType32 layerIdx) const
{
@ -1533,6 +1548,7 @@ public:
virtual void refreshBlocks() = 0;
virtual void flushIterationEvents() = 0;
virtual void resetReuseState() = 0;
[[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);
@ -1913,6 +1929,11 @@ public:
return mBlockManager.findBlocksInReuseTreeByBlockKey(blockKey, windowSize);
}
void resetReuseState() override
{
mBlockManager.resetReuseState();
}
/// @brief Finds the maximum attention window that can be used on a sequence, given some kv-cache block capacity.
///
/// @param inputLength The number of input tokens in the sequence.

View File

@ -482,7 +482,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
.def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
nb::call_guard<nb::gil_scoped_release>())
.def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, nb::call_guard<nb::gil_scoped_release>())
.def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, nb::call_guard<nb::gil_scoped_release>());
.def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, nb::call_guard<nb::gil_scoped_release>())
.def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, nb::call_guard<nb::gil_scoped_release>());
nb::bind_vector<CacheBlockIds>(m, "CacheBlockIds")
.def("__getstate__", [](CacheBlockIds const& v) { return nb::make_tuple(v); })

View File

@ -486,7 +486,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
.def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
py::call_guard<py::gil_scoped_release>())
.def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, py::call_guard<py::gil_scoped_release>())
.def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, py::call_guard<py::gil_scoped_release>());
.def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, py::call_guard<py::gil_scoped_release>())
.def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, py::call_guard<py::gil_scoped_release>());
py::enum_<tbk::CacheType>(m, "CacheType")
.value("SELF", tbk::CacheType::kSELF)

View File

@ -0,0 +1,77 @@
import torch
from tensorrt_llm._ray_utils import control_action_decorator
from tensorrt_llm._torch.utils import get_device_uuid
from tensorrt_llm.logger import logger
class WorkerExtension:
"""Worker extension class for extending TensorRT-LLM Ray workers with custom functionality.
This class can be injected into tensorrt_llm.LLM() by specifying it via the
ray_worker_extension_cls parameter in LLMArgs when using orchestrator_type='ray'.
The extension methods will be available on each Ray worker and can be called via
the LLM's collective RPC mechanism.
Examples:
Creating an LLM with worker extension:
>>> llm = LLM(
... model=model_dir,
... orchestrator_type="ray",
... ray_worker_extension_cls="rlhf_utils.WorkerExtension",
... )
Calling extension methods via collective RPC:
>>> llm._collective_rpc("update_weights", args=(ipc_handles,))
"""
@control_action_decorator
def update_weights(self, ipc_handles: dict):
"""Update model weights from IPC (Inter-Process Communication) handles.
This method receives shared memory handles from another process (typically FSDP training),
reconstructs tensors from these handles, and loads them into the TensorRT-LLM model.
Uses the control_action_decorator to ensure all active requests are finished before
updating weights.
Args:
ipc_handles: Dictionary mapping device UUIDs to lists of (param_name, tensor_handle) tuples.
Each tensor_handle is a tuple of (func, args) for reconstructing the tensor.
Raises:
ValueError: If the current device's UUID is not found in ipc_handles.
Exception: Re-raises any exception encountered during weight update.
"""
try:
logger.info("Update weights from IPC handles")
device_uuid = get_device_uuid(self.device_id)
if device_uuid not in ipc_handles:
raise ValueError(f"Device UUID {device_uuid} not found in ipc_handles")
weights = {}
all_handles = ipc_handles[device_uuid]
for param_name, tensor_handle in all_handles:
func, args = tensor_handle
list_args = list(args)
list_args[6] = self.device_id # Set target device
tensor = func(*list_args)
weights[param_name] = tensor
self.engine.model_engine.model.load_weights(weights)
torch.cuda.synchronize()
self.engine.reset_prefix_cache()
except Exception as e:
logger.error("Encountered an error in update_weights")
raise e
def check_weights_updated(self):
"""Check if the weights are updated to 0."""
weights_updated = True
for name, p in self.engine.model_engine.model.named_parameters():
weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p))
return weights_updated
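
For reference, a minimal sketch of the trainer side that produces the ipc_handles dict described above: device UUID mapped to a list of (param_name, (rebuild_func, args)) handles. reduce_tensor and get_device_uuid mirror calls used elsewhere in this change; build_ipc_handles itself is an illustrative helper, not part of this PR.

# Sketch only: builds ipc_handles in the shape update_weights expects.
import torch
from torch.multiprocessing.reductions import reduce_tensor
from tensorrt_llm._torch.utils import get_device_uuid


def build_ipc_handles(model: torch.nn.Module, device_index: int = 0) -> dict:
    handles = []
    for name, param in model.named_parameters():
        # reduce_tensor returns (rebuild_func, args); the worker replays it and
        # patches the device index before reconstructing the tensor.
        handles.append((name, reduce_tensor(param.detach())))
    return {get_device_uuid(device_index): handles}


# Usage from the trainer, assuming an LLM created with orchestrator_type="ray":
# llm._collective_rpc("update_weights", args=(build_ipc_handles(hf_model),))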

View File

@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from contextlib import contextmanager
from typing import Callable
try:
import ray
@ -26,3 +28,16 @@ def unwrap_ray_errors():
yield
except ray.exceptions.RayTaskError as e:
raise e.as_instanceof_cause() from e
def control_action_decorator(func: Callable) -> Callable:
"""
Decorator that runs the wrapped method inside the engine's control_action() context manager.
"""
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
with self.engine.control_action():
return func(self, *args, **kwargs)
return wrapper

View File

@ -871,6 +871,8 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
for new_name in params_map[names[-1]]:
fw = filter_weights('.'.join(names[:-1] + [new_name]),
weights)
if not fw:
continue
if new_name in ['k_proj', 'v_proj']:
num_kv_heads_list = [num_kv_heads
] * len(fw) if isinstance(
@ -887,23 +889,29 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
}
module_weights.append(fw)
module.load_weights(weights=module_weights)
# Note: module_weights may be empty after filtering (e.g., in streaming weight updates)
if module_weights:
module.load_weights(weights=module_weights)
else:
module_weights = filter_weights(name, weights)
if hasattr(module, 'load_weights'):
module.load_weights(weights=[module_weights])
else:
for n, p in module._parameters.items():
if p is not None:
# Note: module_weights may be empty after filtering (e.g., in streaming weight updates)
if module_weights:
if hasattr(module, 'load_weights'):
module.load_weights(weights=[module_weights])
else:
for n, p in module.named_parameters(recurse=False):
p.data.copy_(module_weights[n][:])
if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
"True") in ["True", "true", "1", "yes", "y"]:
for name, module in tqdm(list(model.named_modules()),
for name, module in tqdm(list(
model.named_modules(remove_duplicate=False)),
desc="Loading weights"):
load_single_module(name, module)
else:
all_modules = dict(model.named_modules())
# remove_duplicate=False ensures original modules sharing weights with next_layer_layernorm are not skipped
all_modules = dict(model.named_modules(remove_duplicate=False))
serial_load_modules = []
if preload_weight_modules is not None:
for module in preload_weight_modules:
@ -919,10 +927,13 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
del all_modules[module]
pbar.close()
pbar = tqdm(list(model.named_modules()),
pbar = tqdm(list(model.named_modules(remove_duplicate=False)),
desc="Loading weights concurrently")
args_list = [(name, module) for name, module in model.named_modules()
if name not in serial_load_modules]
args_list = [
(name, module)
for name, module in model.named_modules(remove_duplicate=False)
if name not in serial_load_modules
]
run_concurrently(load_single_module, args_list, pbar=pbar)
@ -950,31 +961,36 @@ def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM],
if weight_mapper.does_require_special_handling(module_name):
module_weights = weight_mapper.apply_callbacks(
module, module_name, module_names_breakdown, weights)
module.load_weights(weights=module_weights)
# Note: module_weights may be empty after filtering (e.g., in streaming weight updates)
if module_weights:
module.load_weights(weights=module_weights)
else:
module_weights = weight_mapper.filter_weights(name, weights)
if weight_mapper.is_special_instance_module(module):
weight_mapper.handle_special_instance_module(
module, module_name, module_weights)
elif hasattr(module, 'load_weights'):
if "linear_attn.conv1d" in name:
module_weights['weight'] = module_weights[
'weight'].squeeze(dim=1)
module.load_weights(weights=[module_weights])
else:
for n, p in module._parameters.items():
if p is not None:
# Note: module_weights may be empty after filtering (e.g., in streaming weight updates)
if module_weights:
if weight_mapper.is_special_instance_module(module):
weight_mapper.handle_special_instance_module(
module, module_name, module_weights)
elif hasattr(module, 'load_weights'):
if module_weights:
if "linear_attn.conv1d" in name:
module_weights['weight'] = module_weights[
'weight'].squeeze(dim=1)
module.load_weights(weights=[module_weights])
else:
for n, p in module.named_parameters(recurse=False):
weight_mapper.handle_manual_copy(
module_name, module_weights, n, p)
if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
"True") in ["True", "true", "1", "yes", "y"]:
for name, module in tqdm(list(model.named_modules()),
for name, module in tqdm(list(
model.named_modules(remove_duplicate=False)),
desc="Loading weights"):
load_single_module(name, module)
else:
all_modules = dict(model.named_modules())
# remove_duplicate=False ensures original modules sharing weights with next_layer_layernorm are not skipped
all_modules = dict(model.named_modules(remove_duplicate=False))
serial_load_modules = []
if preload_weight_modules is not None:
for module in preload_weight_modules:
@ -990,8 +1006,11 @@ def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM],
del all_modules[module]
pbar.close()
pbar = tqdm(list(model.named_modules()),
pbar = tqdm(list(model.named_modules(remove_duplicate=False)),
desc="Loading weights concurrently")
args_list = [(name, module) for name, module in model.named_modules()
if name not in serial_load_modules]
args_list = [
(name, module)
for name, module in model.named_modules(remove_duplicate=False)
if name not in serial_load_modules
]
run_concurrently(load_single_module, args_list, pbar=pbar)
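
A small, self-contained illustration (not part of this change) of why the switch to named_modules(remove_duplicate=False) matters: with the default deduplication, a module object registered under two names is yielded only once, so the second owner (e.g. a layer aliased as next_layer_layernorm) would be skipped during weight loading.

# Minimal demonstration of PyTorch's named_modules deduplication behavior.
import torch.nn as nn

shared = nn.Linear(8, 8)
model = nn.ModuleDict({"a": shared, "b": shared})

print(len(list(model.named_modules())))                        # 2: root + one copy of `shared`
print(len(list(model.named_modules(remove_duplicate=False))))  # 3: root + both registrations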

View File

@ -670,6 +670,7 @@ def create_py_executor_instance(
peft_cache_config: Optional[PeftCacheConfig] = None,
scheduler_config: Optional[SchedulerConfig] = None,
cache_transceiver_config: Optional[CacheTransceiverConfig] = None,
virtual_memory_pools: Optional[dict] = None,
) -> PyExecutor:
kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)
@ -818,7 +819,8 @@ def create_py_executor_instance(
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
kv_connector_manager=kv_connector_manager,
max_seq_len=max_seq_len,
peft_cache_config=peft_cache_config)
peft_cache_config=peft_cache_config,
virtual_memory_pools=virtual_memory_pools)
def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,

View File

@ -108,6 +108,9 @@ class PyTorchConfig:
# If true, ONLY the vision encoder part of the full model is loaded/executed.
mm_encoder_only: bool = False
# Enable extra setup to support sleep feature.
enable_sleep: bool = False
# If true, adjust PyTorch CUDA memory fraction to correspond to the
# total GPU memory minus the statically allocated engine memory.
# If false, set the PyTorch CUDA memory fraction to 1.0.

View File

@ -18,6 +18,7 @@ from .llm_request import (ExecutorRequest, LlmRequest,
executor_request_to_llm_request)
SHUTDOWN_REQUEST_ID = -1
CONTROL_REQUEST_ID = -2
@dataclasses.dataclass
@ -35,7 +36,12 @@ class RequestQueueItem:
@property
def is_normal_request(self):
return not (self.is_shutdown_request or self.is_canceled_request)
return not (self.is_shutdown_request or self.is_canceled_request
or self.is_control_request)
@property
def is_control_request(self):
return self.id == CONTROL_REQUEST_ID
class ExecutorRequestQueue:
@ -68,6 +74,8 @@ class ExecutorRequestQueue:
self.new_active_requests_queue_latency_ms = 0
self.is_shutdown = False
self.should_exclude_last_generation_logits = False
self.control_requests: List[RequestQueueItem] = []
self.request_accumulated: List[RequestQueueItem] = []
self._disable_mpi = mpi_disabled()
@ -251,6 +259,10 @@ class ExecutorRequestQueue:
self.request_queue.put(
RequestQueueItem(req_id, is_canceled_request=True))
def enqueue_control_request(self):
with self.enqueue_lock:
self.request_queue.put(RequestQueueItem(id=CONTROL_REQUEST_ID))
def enqueue_shutdown_request(self):
with self.enqueue_lock:
self.request_queue.put(RequestQueueItem(SHUTDOWN_REQUEST_ID))
@ -268,6 +280,10 @@ class ExecutorRequestQueue:
all_ranks_num_active_requests: Optional[List[int]] = None
) -> List[RequestQueueItem]:
"""Common logic for fetching and processing requests from the queue."""
# Block new request processing while control requests are pending.
# Control requests must be handled exclusively to ensure proper synchronization.
if len(self.control_requests) != 0:
return []
# Calculate timeout
idle = (total_num_active_requests == 0) and len(self.waiting_queue) == 0
if idle:
@ -281,7 +297,13 @@ class ExecutorRequestQueue:
# Fetch requests from rank 0
new_requests = []
if self.dist.rank == 0:
new_requests = self._get_from_request_queue(timeout)
# Process accumulated requests that were queued during control request handling.
if len(self.request_accumulated) != 0:
new_requests.extend(self.request_accumulated)
self.request_accumulated.clear()
# Reset timeout to 0 to avoid hanging when no new requests are available
timeout = datetime.timedelta(0)
new_requests.extend(self._get_from_request_queue(timeout))
# Broadcast requests and handle Python objects
new_requests, py_request_objects = self._handle_request_broadcasting(
@ -465,12 +487,17 @@ class ExecutorRequestQueue:
new_requests: List[RequestQueueItem]) -> List[RequestQueueItem]:
"""Validate and filter requests, handling shutdown signals."""
valid_new_requests = []
for req_item in new_requests:
for idx, req_item in enumerate(new_requests):
if req_item.is_shutdown_request:
self.is_shutdown = True
break
elif req_item.is_canceled_request:
self.canceled_req_ids.append(req_item.id)
elif req_item.is_control_request:
self.control_requests.append(req_item)
if self.dist.rank == 0:
self.request_accumulated.extend(new_requests[idx + 1:])
break
else:
valid_new_requests.append(req_item)
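
The new classification logic above can be summarized with a toy, self-contained model: once a control request is seen, the remaining items from the same fetch are set aside in request_accumulated and replayed after the control action completes. The dataclass below is a simplified stand-in for RequestQueueItem.

# Toy model of the classify-and-accumulate behavior (not the real class).
from dataclasses import dataclass

CONTROL_REQUEST_ID = -2

@dataclass
class Item:
    id: int

    @property
    def is_control_request(self):
        return self.id == CONTROL_REQUEST_ID

control_requests, request_accumulated, valid_new_requests = [], [], []
fetched = [Item(7), Item(CONTROL_REQUEST_ID), Item(8), Item(9)]
for idx, req_item in enumerate(fetched):
    if req_item.is_control_request:
        control_requests.append(req_item)
        request_accumulated.extend(fetched[idx + 1:])  # replayed after the control action
        break
    valid_new_requests.append(req_item)

print(valid_new_requests)     # [Item(id=7)]
print(request_accumulated)    # [Item(id=8), Item(id=9)]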

View File

@ -128,7 +128,8 @@ class PyExecutor:
start_worker: bool = True,
kv_connector_manager: Optional[KvCacheConnectorManager] = None,
max_seq_len: Optional[int] = None,
peft_cache_config: Optional[PeftCacheConfig] = None):
peft_cache_config: Optional[PeftCacheConfig] = None,
virtual_memory_pools: Optional[dict] = None):
super(PyExecutor, self).__init__()
self.device_id = torch.cuda.current_device()
self.global_rank = dist.rank
@ -151,6 +152,7 @@ class PyExecutor:
self.guided_decoder = guided_decoder
self.dist = dist
self.disable_overlap_scheduler = disable_overlap_scheduler
self.virtual_memory_pools = virtual_memory_pools
# enqueue and _fetch_new_requests used data
self.active = True
@ -256,6 +258,8 @@ class PyExecutor:
)
self.executor_request_queue.set_exclude_last_generation_logits(
self.disable_overlap_scheduler, self.dist.pp_size)
self.control_request_barrier = threading.Event()
self.control_action_done = threading.Event()
self.stats_lock = threading.Lock()
self.stats = []
@ -441,6 +445,10 @@ class PyExecutor:
del self.model_engine
if self.draft_model_engine is not None:
del self.draft_model_engine
if self.virtual_memory_pools is not None:
keys = list(self.virtual_memory_pools.keys())
for key in keys:
del self.virtual_memory_pools[key]
def can_enqueue_requests(self) -> bool:
"""
@ -788,6 +796,8 @@ class PyExecutor:
if self.should_stop_processing:
break
self._handle_control_request()
if self.kv_cache_transceiver:
self._check_disagg_gen_transfer_status()
@ -1124,6 +1134,8 @@ class PyExecutor:
iter_start_time = time.time()
scheduled_batch, iter_stats = self._prepare_and_schedule_batch()
self._handle_control_request()
if scheduled_batch is None:
break
@ -1257,6 +1269,47 @@ class PyExecutor:
logger.error(f"Encountered an error in decode: {error_msg}")
self._handle_errors(error_msg)
def _handle_control_request(self):
if len(self.active_requests) == 0 and \
self.executor_request_queue.get_waiting_queue_size() == 0 and \
len(self.executor_request_queue.control_requests) > 0:
assert len(self.executor_request_queue.control_requests) == 1, (
f"Expected exactly one control request to be processed at a time, "
f"but found {len(self.executor_request_queue.control_requests)} control requests. "
f"This may indicate a race condition or improper control request handling."
)
self.executor_request_queue.control_requests.pop(0)
self.control_request_barrier.set()
self.control_action_done.wait()
self.control_action_done.clear()
@contextmanager
def control_action(self):
"""
Context manager for synchronized control actions.
Usage:
with control_action():
# The executor loop has finished all in-flight requests and is paused
... perform control actions here (e.g. update weights, release memory) ...
# The executor loop resumes automatically after exiting the block
"""
if self.dist.rank == 0:
self.executor_request_queue.enqueue_control_request()
# Wait for worker to finish all previous requests
self.control_request_barrier.wait()
try:
# Yield control to the with block
# Worker is now paused, safe to execute actions
yield self
finally:
# Cleanup: signal worker to resume
self.control_action_done.set()
self.control_request_barrier.clear()
def _executor_loop_overlap(self):
torch.cuda.set_device(self.device_id)
# ensure the context is created, otherwise, some MPI calls will fail.
@ -1271,6 +1324,8 @@ class PyExecutor:
iter_start_time = time.time()
scheduled_batch, iter_stats = self._prepare_and_schedule_batch()
self._handle_control_request()
if scheduled_batch is None:
break
# In gen-only benchmarking mode, wait until the number of scheduled generation
@ -2376,6 +2431,9 @@ class PyExecutor:
if hasattr(self.drafter, "guided_decoder"):
self.guided_decoder.rollback_draft_tokens()
def reset_prefix_cache(self):
self.kv_cache_manager.reset_reuse_state()
class DisaggPPTerminationHandler:
"""Handles termination synchronization across pipeline parallel ranks under disaggregated serving.

View File

@ -1,5 +1,4 @@
import copy
import enum
import importlib
import os
from concurrent.futures import ThreadPoolExecutor
@ -29,6 +28,8 @@ from ..attention_backend.interface import AttentionRuntimeFeatures
from ..distributed import MPIDist, TorchDist
from ..speculative import (get_num_extra_kv_tokens, get_spec_drafter,
get_spec_resource_manager)
from ..virtual_memory import ExecutorMemoryType, RestoreMode
from ..virtual_memory import scope as virtual_memory_scope
from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
create_py_executor_instance, instantiate_sampler, is_mla,
validate_feature_combination)
@ -40,25 +41,12 @@ from .model_engine import PyTorchModelEngine
from .py_executor import PyExecutor
class _ExecutorCreationStage(enum.Enum):
SAMPLER = "Sampler"
DRAFTER = "Drafter"
GUIDED_DECODER = "Guided decoder"
INIT_KV_CACHE = "Initial KV cache (temporary for KV cache size estimation)"
INIT_EXTRA_RESOURCES = "Additional executor resources (temporary for KV cache size estimation)"
MODEL_EXTRA = "Model resources created during usage"
EXTRA_RESOURCES = "Additional executor resources"
KV_CACHE = "KV cache"
MODEL_ENGINE_MAIN = "Model"
MODEL_ENGINE_DRAFT = "Draft model for speculative decoding"
class _ExecutorMemoryMonitor():
class _ExecutorMemoryMonitor:
"""Currently this focuses on tracking memory usage and related errors."""
@dataclass(frozen=True)
class _GpuMemoryUsageSample:
creation_stage: _ExecutorCreationStage
creation_stage: ExecutorMemoryType
free_gpu_memory_bytes_pre: int
free_gpu_memory_bytes_post: int
@ -70,47 +58,74 @@ class _ExecutorMemoryMonitor():
def _bytes_to_gib(bytes: int) -> float:
return bytes / (1024)**3
memory_type_friendly_names = {
ExecutorMemoryType.SAMPLER:
"Sampler",
ExecutorMemoryType.DRAFTER:
"Drafter",
ExecutorMemoryType.GUIDED_DECODER:
"Guided Decoder",
ExecutorMemoryType.SPEC_RESOURCES:
"Speculative decoding resources",
ExecutorMemoryType.INIT_KV_CACHE:
"Initial KV Cache (temporary for KV cache size estimation)",
ExecutorMemoryType.INIT_EXTRA_RESOURCES:
"Additional executor resources (temporary for KV cache size estimation)",
ExecutorMemoryType.MODEL_EXTRA:
"Model resources created during usage",
ExecutorMemoryType.EXTRA_RESOURCES:
"Additional executor resources",
ExecutorMemoryType.KV_CACHE:
"KV cache",
ExecutorMemoryType.MODEL_ENGINE_MAIN:
"Model",
ExecutorMemoryType.MODEL_ENGINE_DRAFT:
"Draft model for speculative decoding",
}
# Suggestions for reducing each component's memory usage
memory_type_tuning_suggestion = {
ExecutorMemoryType.SAMPLER:
"reduce max_seq_len and/or max_attention_window_size",
ExecutorMemoryType.DRAFTER:
"reduce max_seq_len and/or max_draft_len",
ExecutorMemoryType.SPEC_RESOURCES:
"reduce max_seq_len and/or max_batch_size",
ExecutorMemoryType.KV_CACHE:
"reduce free_gpu_memory_fraction",
ExecutorMemoryType.INIT_KV_CACHE:
"reduce max_num_tokens",
ExecutorMemoryType.MODEL_ENGINE_MAIN:
("reduce max_num_tokens and/or shard the model weights across GPUs by enabling "
"pipeline and/or tensor parallelism"),
ExecutorMemoryType.MODEL_ENGINE_DRAFT:
("reduce max_num_tokens and/or shard the model weights across GPUs by enabling "
"pipeline and/or tensor parallelism"),
ExecutorMemoryType.INIT_EXTRA_RESOURCES:
"reduce max_num_tokens",
ExecutorMemoryType.EXTRA_RESOURCES:
"reduce max_num_tokens",
ExecutorMemoryType.MODEL_EXTRA:
"reduce max_num_tokens",
}
def _maybe_explain_if_oom(self, e: Exception, *,
current_stage: _ExecutorCreationStage,
current_stage: ExecutorMemoryType,
free_gpu_memory_bytes_pre: int) -> Optional[str]:
if isinstance(e, torch.OutOfMemoryError) or "out of memory" in str(e):
msg = "Executor creation failed due to insufficient GPU memory."
elif (isinstance(e, RuntimeError) and "Failed, NCCL error" in str(e)
and "unhandled cuda error (run with NCCL_DEBUG=INFO for details)"
in str(e)):
msg = (f"Executor creation failed with NCCL error: {str(e)}")
msg = f"Executor creation failed with NCCL error: {str(e)}"
return msg
else:
return None
# how to reduce component memory usage
tuning_knobs = {
_ExecutorCreationStage.SAMPLER:
"reduce max_seq_len and/or max_attention_window_size",
_ExecutorCreationStage.DRAFTER:
"reduce max_seq_len and/or max_draft_len",
_ExecutorCreationStage.KV_CACHE:
"reduce free_gpu_memory_fraction",
_ExecutorCreationStage.INIT_KV_CACHE:
"reduce max_num_tokens",
_ExecutorCreationStage.MODEL_ENGINE_MAIN:
("reduce max_num_tokens and/or shard the model weights across GPUs by enabling "
"pipeline and/or tensor parallelism"),
_ExecutorCreationStage.MODEL_ENGINE_DRAFT:
("reduce max_num_tokens and/or shard the model weights across GPUs by enabling "
"pipeline and/or tensor parallelism"),
_ExecutorCreationStage.INIT_EXTRA_RESOURCES:
"reduce max_num_tokens",
_ExecutorCreationStage.EXTRA_RESOURCES:
"reduce max_num_tokens",
_ExecutorCreationStage.MODEL_EXTRA:
"reduce max_num_tokens",
}
msg = "\n".join([
msg,
"",
f"The following component could not be created: {current_stage.value}",
f"The following component could not be created: {self.memory_type_friendly_names[current_stage]}",
f"Total GPU memory (GiB): {self._bytes_to_gib(self._total_gpu_memory_bytes):.2f}",
f"Free GPU memory before component creation attempt (GiB): {self._bytes_to_gib(free_gpu_memory_bytes_pre):.2f}",
"",
@ -122,15 +137,15 @@ class _ExecutorMemoryMonitor():
("Please refer to the TensorRT LLM documentation for information on how "
"to control the memory usage through TensorRT LLM configuration options. "
"Possible options include:"),
*(f" {stage.value}: {tuning_knobs[stage]}"
*(f" {stage.value}: {self.memory_type_tuning_suggestion[stage]}"
for stage in chain((sample.creation_stage
for sample in self._samples), [current_stage])
if stage in tuning_knobs),
if stage in self.memory_type_tuning_suggestion),
])
return msg
@contextmanager
def observe_creation_stage(self, current_stage: _ExecutorCreationStage):
def observe_creation_stage(self, current_stage: ExecutorMemoryType):
"""Catches OOM and prints instructive message."""
free_gpu_memory_bytes_pre = torch.cuda.mem_get_info()[0]
@ -295,6 +310,9 @@ def create_py_executor(
else:
dist = MPIDist(mapping=mapping)
vm_pools = {}
enable_sleep = pytorch_backend_config.enable_sleep
cache_transceiver_config = llm_args.cache_transceiver_config
has_draft_model_engine = False
@ -313,8 +331,23 @@ def create_py_executor(
logger.info("ATTENTION RUNTIME FEATURES: ", attn_runtime_features)
mem_monitor = _ExecutorMemoryMonitor()
with mem_monitor.observe_creation_stage(
_ExecutorCreationStage.MODEL_ENGINE_MAIN):
@contextmanager
def allocation_scope(current_stage: ExecutorMemoryType,
restore_mode: RestoreMode):
with mem_monitor.observe_creation_stage(current_stage):
stage = current_stage.value
if not enable_sleep or stage.startswith("_no_capture"):
yield
else:
with virtual_memory_scope(stage, restore_mode) as memory_pool:
if stage in vm_pools:
del vm_pools[stage]
vm_pools[stage] = memory_pool
yield
with allocation_scope(ExecutorMemoryType.MODEL_ENGINE_MAIN,
RestoreMode.PINNED):
model_engine = PyTorchModelEngine(
model_path=checkpoint_dir,
llm_args=llm_args,
@ -328,8 +361,8 @@ def create_py_executor(
pytorch_backend_config.sampler_type)
if has_draft_model_engine:
with mem_monitor.observe_creation_stage(
_ExecutorCreationStage.MODEL_ENGINE_DRAFT):
with allocation_scope(ExecutorMemoryType.MODEL_ENGINE_DRAFT,
RestoreMode.PINNED):
draft_spec_config = copy.copy(spec_config)
use_chain_drafter = (
@ -447,10 +480,10 @@ def create_py_executor(
else:
ctx_chunk_config = None
with mem_monitor.observe_creation_stage(
_ExecutorCreationStage.GUIDED_DECODER):
guided_decoder: Optional[GuidedDecoder] = None
if guided_decoding_config is not None:
guided_decoder: Optional[GuidedDecoder] = None
if guided_decoding_config is not None:
with allocation_scope(ExecutorMemoryType.GUIDED_DECODER,
RestoreMode.PINNED):
if mapping.is_last_pp_rank():
kwargs = {
"guided_decoding_config": guided_decoding_config,
@ -479,7 +512,7 @@ def create_py_executor(
f"Guided decoding is not supported for speculative decoding mode: {spec_config.spec_dec_mode.name}."
)
with mem_monitor.observe_creation_stage(_ExecutorCreationStage.SAMPLER):
with allocation_scope(ExecutorMemoryType.SAMPLER, RestoreMode.PINNED):
sampler = instantiate_sampler(model_engine,
pytorch_backend_config,
mapping,
@ -565,9 +598,9 @@ def create_py_executor(
sparse_attention_config=sparse_attention_config,
)
estimating_kv_cache = kv_cache_creator.try_prepare_estimation()
with mem_monitor.observe_creation_stage(
_ExecutorCreationStage.INIT_KV_CACHE
if estimating_kv_cache else _ExecutorCreationStage.KV_CACHE):
with allocation_scope(
ExecutorMemoryType.INIT_KV_CACHE if estimating_kv_cache else
ExecutorMemoryType.KV_CACHE, RestoreMode.NONE):
kv_cache_creator.build_managers(resources, estimating_kv_cache)
# Originally, max_seq_len might be mutated inside build_managers as field of executor config.
# Since now, we are changing kv_cache_creator._max_seq_len instead. Restore max_seq_len here.
@ -577,23 +610,26 @@ def create_py_executor(
# Resource managers for speculative decoding
# For user-specified drafters, use extra_resource_managers in PyTorchBackend config
# to provide a resource manager if required.
spec_resource_manager = get_spec_resource_manager(model_engine,
draft_model_engine)
with allocation_scope(ExecutorMemoryType.SPEC_RESOURCES,
RestoreMode.PINNED):
spec_resource_manager = get_spec_resource_manager(
model_engine, draft_model_engine)
if spec_resource_manager is not None:
resources[
ResourceManagerType.SPEC_RESOURCE_MANAGER] = spec_resource_manager
# Drafter for speculative decoding
with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER):
with allocation_scope(ExecutorMemoryType.DRAFTER, RestoreMode.PINNED):
drafter = get_spec_drafter(model_engine,
draft_model_engine,
sampler,
spec_resource_manager=spec_resource_manager,
guided_decoder=guided_decoder)
with mem_monitor.observe_creation_stage(
_ExecutorCreationStage.INIT_EXTRA_RESOURCES
if estimating_kv_cache else _ExecutorCreationStage.EXTRA_RESOURCES):
with allocation_scope(
ExecutorMemoryType.INIT_EXTRA_RESOURCES if estimating_kv_cache else
ExecutorMemoryType.EXTRA_RESOURCES, RestoreMode.PINNED):
py_executor = create_py_executor_instance(
dist=dist,
resources=resources,
@ -616,6 +652,7 @@ def create_py_executor(
peft_cache_config=peft_cache_config,
scheduler_config=scheduler_config,
cache_transceiver_config=cache_transceiver_config,
virtual_memory_pools=vm_pools if not estimating_kv_cache else None,
)
# Originally, peft_cache_config might be mutated inside
# create_py_executor_instance. Restore it here.
@ -623,14 +660,13 @@ def create_py_executor(
if estimating_kv_cache:
assert kv_cache_creator is not None
with mem_monitor.observe_creation_stage(
_ExecutorCreationStage.MODEL_EXTRA):
with allocation_scope(ExecutorMemoryType.MODEL_EXTRA,
RestoreMode.PINNED):
kv_cache_creator.configure_kv_cache_capacity(py_executor)
kv_cache_creator.teardown_managers(resources)
del py_executor # free before constructing new
with mem_monitor.observe_creation_stage(
_ExecutorCreationStage.KV_CACHE):
with allocation_scope(ExecutorMemoryType.KV_CACHE, RestoreMode.NONE):
# Before estimating KV cache size, a minimal KV cache has been allocated using
# create_kv_cache_manager above, which caps kv_cache_creator.max_seq_len. Restoring
# the original value before creating the final KV cache.
@ -649,8 +685,8 @@ def create_py_executor(
eng._release_cuda_graphs()
eng.attn_metadata = None
with mem_monitor.observe_creation_stage(
_ExecutorCreationStage.EXTRA_RESOURCES):
with allocation_scope(ExecutorMemoryType.EXTRA_RESOURCES,
RestoreMode.PINNED):
py_executor = create_py_executor_instance(
dist=dist,
resources=resources,
@ -673,6 +709,7 @@ def create_py_executor(
peft_cache_config=peft_cache_config,
scheduler_config=scheduler_config,
cache_transceiver_config=cache_transceiver_config,
virtual_memory_pools=vm_pools,
)
_adjust_torch_mem_fraction(pytorch_backend_config)
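
A simplified, standalone rendering of the allocation_scope pattern introduced here: every creation stage is wrapped in the OOM monitor and, when enable_sleep is on and the stage is capturable, in a tagged virtual-memory scope whose pool is recorded so sleep/wakeup can later release and rematerialize it. observe_stage and vm_scope below are trivial stand-ins for mem_monitor.observe_creation_stage and tensorrt_llm._torch.virtual_memory.scope.

# Sketch of the allocation_scope pattern with stub context managers.
from contextlib import contextmanager

vm_pools = {}
enable_sleep = True

@contextmanager
def observe_stage(stage):           # stand-in for the OOM monitor
    yield

@contextmanager
def vm_scope(tag, restore_mode):    # stand-in for virtual_memory.scope
    yield f"pool<{tag}>"

@contextmanager
def allocation_scope(stage, restore_mode):
    with observe_stage(stage):
        if not enable_sleep or stage.startswith("_no_capture"):
            yield
        else:
            with vm_scope(stage, restore_mode) as pool:
                vm_pools[stage] = pool   # kept so sleep()/wakeup() can find it
                yield

with allocation_scope("model", "PINNED"):
    pass                             # model engine would be constructed here
print(vm_pools)                      # {'model': 'pool<model>'}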

View File

@ -1164,6 +1164,10 @@ class KVCacheManager(BaseResourceManager):
else:
return None
def reset_reuse_state(self):
"""Reset the reuse state of the KV cache manager."""
self.impl.reset_reuse_state()
class SlotManager:

View File

@ -1,8 +1,9 @@
import functools
from contextlib import contextmanager
from typing import Generator
from typing import Generator, List
import torch
from strenum import StrEnum
from tensorrt_llm.bindings.internal.runtime import \
CudaVirtualMemoryAllocatorRestoreMode as RestoreMode
@ -66,6 +67,32 @@ def maybe_scope(
yield
class ExecutorMemoryType(StrEnum):
SAMPLER = "sampler"
DRAFTER = "drafter"
GUIDED_DECODER = "guided_decoder"
SPEC_RESOURCES = "spec_resource_manager"
INIT_KV_CACHE = "_no_capture_init_kv_cache"
INIT_EXTRA_RESOURCES = "_no_capture_init_extra_resources"
MODEL_EXTRA = "_no_capture_model_extra"  # TODO: remove _no_capture once torch fixes the crash in torch.cuda.empty_cache()
EXTRA_RESOURCES = "executor_extra"
KV_CACHE = "kv_cache"
MODEL_ENGINE_MAIN = "model"
MODEL_ENGINE_DRAFT = "draft_model"
def verify_sleep_wakeup_tags(tags_strs: List[str]) -> List[ExecutorMemoryType]:
tags = []
for tag_str in tags_strs:
try:
tags.append(ExecutorMemoryType(tag_str))
except ValueError:
raise ValueError(
f"Unknown memory tag '{tag_str}'."
f"Valid tags are: {[t.value for t in ExecutorMemoryType]}")
return tags
def release_with_tag(*tags: str) -> int:
"""Release virtual memory allocated with given tags

View File

@ -2,11 +2,16 @@ import importlib
import os
from pathlib import Path
from queue import Queue
from typing import Any, Optional, Type, Union
from typing import Any, List, Optional, Type, Union
import ray
import torch
from tensorrt_llm._ray_utils import control_action_decorator
from tensorrt_llm._torch.virtual_memory import (materialize_with_tag,
release_with_tag,
verify_sleep_wakeup_tags)
from ..bindings import executor as tllm
from ..builder import Engine
from ..llmapi.llm_args import BaseLlmArgs
@ -230,6 +235,38 @@ class RayGPUWorker(BaseWorker):
logger.debug(f"Worker {self.rank} shutdown done.")
@control_action_decorator
def sleep(self, sleep_tags: List[str]):
if not self.llm_args.enable_sleep:
raise ValueError(
"Sleep feature is not enabled, please set enable_sleep=True in the LLM arguments."
)
try:
tags = verify_sleep_wakeup_tags(sleep_tags)
logger.info(f"Sleep: {tags}")
torch.cuda.synchronize()
release_with_tag(*tags)
torch.cuda.synchronize()
except Exception as e:
logger.error(f"Encountered an error in sleep: {e}")
raise e
@control_action_decorator
def wakeup(self, wakeup_tags: List[str]):
if not self.llm_args.enable_sleep:
raise ValueError(
"Sleep feature is not enabled, please set enable_sleep=True in the LLM arguments."
)
try:
tags = verify_sleep_wakeup_tags(wakeup_tags)
logger.info(f"Wakeup: {tags}")
torch.cuda.synchronize()
materialize_with_tag(*tags)
torch.cuda.synchronize()
except Exception as e:
logger.error(f"Encountered an error in wakeup")
raise e
def __enter__(self):
return self
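
A hedged usage sketch of the new worker methods, assuming an llm created with orchestrator_type="ray" and enable_sleep=True (see the TorchLlmArgs change below). The tag strings are ExecutorMemoryType values; "model" and "kv_cache" release the model weight and KV cache pools respectively.

# Usage sketch only; `llm` is assumed to exist.
sleep_tags = ["model", "kv_cache"]

llm._collective_rpc("sleep", args=(sleep_tags,))    # tagged pools are released from GPU memory
# ... run something else on the freed memory ...
llm._collective_rpc("wakeup", args=(sleep_tags,))   # pools are rematerialized before serving resumes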

View File

@ -2587,6 +2587,13 @@ class TorchLlmArgs(BaseLlmArgs):
"Allows users to extend the functions of the RayGPUWorker class.",
status="prototype")
enable_sleep: bool = Field(
default=False,
description=
"Enable LLM sleep feature. Sleep feature requires extra setup that may slowdown model loading."
"Only enable it if you intend to use this feature.",
status="prototype")
# PrivateVars
_quant_config: Optional[QuantConfig] = PrivateAttr(default=None)
@ -2901,6 +2908,7 @@ class TorchLlmArgs(BaseLlmArgs):
batch_wait_timeout_ms=self.batch_wait_timeout_ms,
batch_wait_timeout_iters=self.batch_wait_timeout_iters,
batch_wait_max_tokens_ratio=self.batch_wait_max_tokens_ratio,
enable_sleep=self.enable_sleep,
)

View File

@ -22,6 +22,7 @@ from tensorrt_llm.bindings.internal.batch_manager import \
from tensorrt_llm.llmapi.llm_args import KvCacheConfig, PeftCacheConfig
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.sampling_params import SamplingParams
DataType = tensorrt_llm.bindings.DataType
LoraModule = tensorrt_llm.bindings.LoraModule
@ -689,6 +690,80 @@ class TestResourceManager(unittest.TestCase):
finally:
manager.shutdown()
@staticmethod
def create_llm_request(id, input_tokens, new_tokens=1):
sampling_params = SamplingParams()
req = LlmRequest(request_id=id,
max_new_tokens=new_tokens,
input_tokens=input_tokens,
sampling_config=tensorrt_llm.bindings.SamplingConfig(
sampling_params._get_sampling_config()),
is_streaming=False)
return req
def test_kv_cache_reset_reuse_state(self):
global_kvcache_config = KvCacheConfig(free_gpu_memory_fraction=0.4,
event_buffer_max_size=1024,
enable_block_reuse=True,
onboard_blocks=True,
max_tokens=256)
kv_cache_manager = KVCacheManager(
kv_cache_config=global_kvcache_config,
kv_cache_type=tensorrt_llm.bindings.internal.batch_manager.
CacheType.SELF,
num_layers=2,
num_kv_heads=2,
head_dim=128,
tokens_per_block=64,
max_seq_len=1024,
max_batch_size=1,
mapping=Mapping(),
)
# First request: Add sequence and store blocks for reuse
req1 = self.create_llm_request(0, [1, 2, 3, 4, 5])
kv_cache_manager.impl.add_sequence(req1.py_request_id, req1.prompt_len,
1, req1)
stats_initial = kv_cache_manager.get_kv_cache_stats()
initial_reused_blocks = stats_initial.reused_blocks
kv_cache_manager.free_resources(req1)
# Second request with same tokens - should reuse blocks from the reuse tree
req2 = self.create_llm_request(1, [1, 2, 3, 4, 5])
kv_cache_manager.impl.add_sequence(req2.py_request_id, req2.prompt_len,
1, req2)
stats_after_reuse = kv_cache_manager.get_kv_cache_stats()
self.assertGreater(
stats_after_reuse.reused_blocks, initial_reused_blocks,
f"Second request should reuse blocks. "
f"reused_blocks before: {initial_reused_blocks}, after: {stats_after_reuse.reused_blocks}"
)
kv_cache_manager.free_resources(req2)
# Reset reuse state
kv_cache_manager.reset_reuse_state()
stats_after_reset = kv_cache_manager.get_kv_cache_stats()
reused_blocks_after_reset = stats_after_reset.reused_blocks
# Third request with same tokens - should NOT reuse blocks after reset
req3 = self.create_llm_request(2, [1, 2, 3, 4, 5])
kv_cache_manager.impl.add_sequence(req3.py_request_id, req3.prompt_len,
1, req3)
stats_after_third = kv_cache_manager.get_kv_cache_stats()
self.assertEqual(
stats_after_third.reused_blocks, reused_blocks_after_reset,
f"Third request should NOT reuse blocks after reset. "
f"reused_blocks after reset: {reused_blocks_after_reset}, after third request: {stats_after_third.reused_blocks}"
)
kv_cache_manager.free_resources(req3)
if __name__ == "__main__":
unittest.main()

View File

@ -1,10 +1,8 @@
import gc
import os
import warnings
import pynvml
import pytest
import torch
from utils.util import get_current_process_gpu_memory
import tensorrt_llm
from tensorrt_llm._torch import virtual_memory
@ -25,56 +23,6 @@ def cuda_sync_fixture():
torch.cuda.synchronize()
@pytest.fixture(scope="module")
def memory_info_available():
"""
Checks if NVML can get per-process memory information.
"""
# Allocate a small tensor to test memory tracking
tensor = torch.zeros(4096, dtype=torch.int32, device='cuda')
torch.cuda.synchronize()
# Try to get memory usage
usage = get_current_process_memory_info()
# Clean up
del tensor
torch.cuda.synchronize()
torch.cuda.empty_cache()
if usage == 0:
warnings.warn("Per process memory information unavailable.")
return False
return True
@pytest.fixture(scope="module", autouse=True)
def nvml_init():
pynvml.nvmlInit()
def get_current_process_memory_info() -> int:
"""
Returns GPU memory usage for current process in bytes.
"""
# Get current process ID
current_pid = os.getpid()
# Get device handle for GPU 0
device_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
# Get running processes
processes = pynvml.nvmlDeviceGetComputeRunningProcesses(device_handle)
# Find current process
for process in processes:
if process.pid == current_pid:
return process.usedGpuMemory
return 0
@pytest.fixture(scope="function", autouse=True)
def clean_cache():
gc.collect()
@ -84,16 +32,16 @@ def clean_cache():
torch.cuda.empty_cache()
def test_basic(memory_info_available):
memory_usage_begin = get_current_process_memory_info()
def test_basic(process_gpu_memory_info_available):
memory_usage_begin = get_current_process_gpu_memory()
alloc_size = 256 * 1024 * 1024
tag = "test_tag"
with virtual_memory.scope(tag) as pool:
tensor = torch.full([alloc_size], 42, dtype=torch.int8, device='cuda')
memory_usage_materialized = get_current_process_memory_info()
if memory_info_available:
memory_usage_materialized = get_current_process_gpu_memory()
if process_gpu_memory_info_available:
assert memory_usage_begin + alloc_size == memory_usage_materialized
assert tensor[0].item() == 42
@ -101,15 +49,15 @@ def test_basic(memory_info_available):
torch.cuda.synchronize()
virtual_memory.release_with_tag(tag)
memory_usage_released = get_current_process_memory_info()
if memory_info_available:
memory_usage_released = get_current_process_gpu_memory()
if process_gpu_memory_info_available:
assert memory_usage_begin == memory_usage_released
torch.cuda.synchronize()
virtual_memory.materialize_with_tag(tag)
memory_usage_rematerialized = get_current_process_memory_info()
if memory_info_available:
memory_usage_rematerialized = get_current_process_gpu_memory()
if process_gpu_memory_info_available:
assert memory_usage_begin + alloc_size == memory_usage_rematerialized
torch.fill_(tensor, 24)
@ -118,8 +66,8 @@ def test_basic(memory_info_available):
del tensor
del pool
memory_usage_end = get_current_process_memory_info()
if memory_info_available:
memory_usage_end = get_current_process_gpu_memory()
if process_gpu_memory_info_available:
assert memory_usage_begin == memory_usage_end
@ -146,7 +94,7 @@ def test_restore():
del pool
def test_kv_cache_manager(memory_info_available):
def test_kv_cache_manager(process_gpu_memory_info_available):
kv_cache_params = {
"kv_cache_config": KvCacheConfig(max_tokens=1024),
"kv_cache_type": CacheType.SELF,
@ -164,7 +112,7 @@ def test_kv_cache_manager(memory_info_available):
mgr.shutdown()
del mgr
memory_usage_begin = get_current_process_memory_info()
memory_usage_begin = get_current_process_gpu_memory()
tag = "test_tag"
cache_size = torch.empty(
@ -182,35 +130,35 @@ def test_kv_cache_manager(memory_info_available):
with virtual_memory.scope(tag) as pool:
mgr = KVCacheManager(**kv_cache_params)
memory_usage_materialized = get_current_process_memory_info()
if memory_info_available:
memory_usage_materialized = get_current_process_gpu_memory()
if process_gpu_memory_info_available:
assert memory_usage_begin + alloc_size == memory_usage_materialized
torch.cuda.synchronize()
virtual_memory.release_with_tag(tag)
memory_usage_released = get_current_process_memory_info()
if memory_info_available:
memory_usage_released = get_current_process_gpu_memory()
if process_gpu_memory_info_available:
assert memory_usage_begin == memory_usage_released
torch.cuda.synchronize()
virtual_memory.materialize_with_tag(tag)
memory_usage_rematerialized = get_current_process_memory_info()
if memory_info_available:
memory_usage_rematerialized = get_current_process_gpu_memory()
if process_gpu_memory_info_available:
assert memory_usage_begin + alloc_size == memory_usage_rematerialized
mgr.shutdown()
del mgr
del pool
memory_usage_end = get_current_process_memory_info()
if memory_info_available:
memory_usage_end = get_current_process_gpu_memory()
if process_gpu_memory_info_available:
assert memory_usage_begin == memory_usage_end
@pytest.mark.skip("https://nvbugspro.nvidia.com/bug/5458911")
def test_cuda_graph(memory_info_available):
def test_cuda_graph(process_gpu_memory_info_available):
def work(input: torch.Tensor) -> torch.Tensor:
intermediate = input + input
@ -233,13 +181,13 @@ def test_cuda_graph(memory_info_available):
torch.cuda.synchronize()
assert static_output[0].item() == 3.0
memory_usage_before = get_current_process_memory_info()
memory_usage_before = get_current_process_gpu_memory()
torch.cuda.synchronize()
virtual_memory.release_with_tag(tag)
memory_usage_released = get_current_process_memory_info()
if memory_info_available:
memory_usage_released = get_current_process_gpu_memory()
if process_gpu_memory_info_available:
assert memory_usage_released < memory_usage_before
torch.cuda.synchronize()

View File

@ -1,7 +1,9 @@
import os
import sys
from pathlib import Path
import pytest
from utils.cpp_paths import llm_root # noqa: F401
from tensorrt_llm._utils import mpi_disabled
@ -21,3 +23,13 @@ if not mpi_disabled():
pytest.skip(
"Ray tests are only tested in Ray CI stage or with --run-ray flag",
allow_module_level=True)
@pytest.fixture(scope="function")
def add_worker_extension_path(llm_root: Path):
worker_extension_path = str(llm_root / "examples" / "llm-api" / "rlhf")
original_python_path = os.environ.get('PYTHONPATH', '')
os.environ['PYTHONPATH'] = os.pathsep.join(
filter(None, [worker_extension_path, original_python_path]))
yield
os.environ['PYTHONPATH'] = original_python_path

View File

@ -0,0 +1,51 @@
from utils.llm_data import llm_models_root
from utils.util import get_current_process_gpu_memory
from tensorrt_llm import LLM
from tensorrt_llm._torch.virtual_memory import ExecutorMemoryType
from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams
def test_llm_sleep(process_gpu_memory_info_available):
llama_model_path = str(llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0")
kv_cache_config = KvCacheConfig(enable_block_reuse=False, max_tokens=4096)
llm = LLM(
model=llama_model_path,
enable_sleep=True,
cuda_graph_config=None,  # CUDA graphs are not yet supported with the sleep feature
kv_cache_config=kv_cache_config,
)
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
with llm:
outputs = llm.generate(prompts, sampling_params)
generated_before_sleep = [output.outputs[0].text for output in outputs]
memory_usage_active = get_current_process_gpu_memory(True)
llm._collective_rpc("sleep", ([ExecutorMemoryType.MODEL_ENGINE_MAIN],))
memory_usage_sleep = get_current_process_gpu_memory(True)
if process_gpu_memory_info_available:
assert memory_usage_sleep < memory_usage_active
llm._collective_rpc("wakeup", ([ExecutorMemoryType.MODEL_ENGINE_MAIN],))
memory_usage_wakeup = get_current_process_gpu_memory(True)
if process_gpu_memory_info_available:
assert memory_usage_wakeup > memory_usage_sleep
outputs = llm.generate(prompts, sampling_params)
generated_after_sleep = [output.outputs[0].text for output in outputs]
for before, after in zip(generated_before_sleep, generated_after_sleep, strict=True):
assert before == after, "Generated result mismatch before and after sleep"

View File

@ -0,0 +1,182 @@
from typing import List
import torch
from torch.multiprocessing.reductions import reduce_tensor
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.llm_data import llm_models_root
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams
class HFModel:
def __init__(self, model_name: str):
self.model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=torch.bfloat16
).to("cuda")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.cuda_device = torch.cuda.current_device()
self.all_weights = {}
self.device_uuid = [HFModel.get_device_uuid(i) for i in range(torch.cuda.device_count())]
@staticmethod
def get_device_uuid(cuda_device: int):
from tensorrt_llm._torch.utils import get_device_uuid
return get_device_uuid(cuda_device)
def flip_weights(self):
for _, p in self.model.named_parameters():
p.data = -p.data
self._replicate_weights()
def _replicate_weights(self):
model_weights = []
for n, p in self.model.named_parameters():
model_weights.append((n, p.detach().clone()))
self.all_weights[self.cuda_device] = model_weights
for i in range(torch.cuda.device_count()):
if i != self.cuda_device:
cur_weights = []
for n, p in self.all_weights[self.cuda_device]:
cur_weights.append((n, p.to("cuda:" + str(i))))
self.all_weights[i] = cur_weights
def get_weight_ipc_handles(self, cuda_device: int = None):
ret = {}
device_list = (
list(range(torch.cuda.device_count())) if cuda_device is None else [cuda_device]
)
for device in device_list:
all_handles = []
for item in self.all_weights[device]:
name, p = item
handle = reduce_tensor(p)
all_handles.append((name, handle))
ret[self.device_uuid[device]] = all_handles
return ret
def generate_batch_incremental(
self, original_prompts: List[str], generated_token_ids_list: List[List[int]]
):
"""
Generate tokens incrementally for each prompt in the batch: [prompt, prompt+token0, prompt+token0+token1, ...]
"""
logits_list = []
for i in range(len(original_prompts)):
base_token_ids = self.tokenizer.encode(original_prompts[i], return_tensors="pt")[0].to(
"cuda"
)
cur_logits = []
for j in range(len(generated_token_ids_list[i])):
if j > 0:
cur_gen_tokens = torch.tensor(generated_token_ids_list[i][:j]).to("cuda")
cur_token_ids = torch.cat([base_token_ids, cur_gen_tokens], dim=-1)
else:
cur_token_ids = base_token_ids
ret = self.model.generate(
input_ids=cur_token_ids.unsqueeze(0).cuda(),
max_new_tokens=1,
return_dict_in_generate=True,
output_scores=True,
)
cur_logits.append(ret["scores"][0])
cur_logits = torch.stack(cur_logits, dim=0)
logits_list.append(cur_logits.squeeze(1))
return logits_list
def extract_tokens_from_outputs(outputs):
"""Extract individual tokens from LLM outputs using token IDs directly"""
tokens_list = []
for output in outputs:
# Get token IDs directly from the output
token_ids = output.outputs[0].token_ids
tokens_list.append(token_ids)
return tokens_list
def compare_logits(
logits_list: List[torch.Tensor],
ref_logits_list: List[torch.Tensor],
topk: int = 20,
threshold: float = 0.85,
):
assert len(logits_list) == len(ref_logits_list)
for i in range(len(logits_list)):
assert logits_list[i].shape == ref_logits_list[i].shape
lhs_idx = torch.topk(logits_list[i], topk, dim=-1).indices
rhs_idx = torch.topk(ref_logits_list[i], topk, dim=-1).indices
# Token wise comparison
ratios = []
for j in range(lhs_idx.shape[0]):
lhs_idx_j = lhs_idx[j].tolist()
rhs_idx_j = rhs_idx[j].tolist()
overlap = set(lhs_idx_j) & set(rhs_idx_j)
ratios.append(len(overlap) / len(lhs_idx_j))
mean_ratio = sum(ratios) / len(ratios)
assert mean_ratio > threshold, (
f"Prompt {i}: overlap ratio: {mean_ratio:.2%} is less than {threshold:.2%}"
)
def run_generate(llm, hf_model, prompts, sampling_params):
outputs = llm.generate(prompts, sampling_params)
llm_logits = []
for output in outputs:
llm_logits.append(output.outputs[0].generation_logits)
generated_token_ids_list = extract_tokens_from_outputs(outputs)
ref_logits = hf_model.generate_batch_incremental(prompts, generated_token_ids_list)
return llm_logits, ref_logits
def test_llm_update_weights(add_worker_extension_path):
llama_model_path = str(llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0")
kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
hf_model = HFModel(llama_model_path)
llm = LLM(
model=llama_model_path,
ray_worker_extension_cls="rlhf_utils.WorkerExtension",
tensor_parallel_size=1,
pipeline_parallel_size=1,
kv_cache_config=kv_cache_config,
)
# Generate texts from the prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0, return_generation_logits=True)
results = []
# Stage 1: Generate with original model
results.append(run_generate(llm, hf_model, prompts, sampling_params))
llm_logits, ref_logits = results[0]
compare_logits(llm_logits, ref_logits)
# Stage 2: Test update with flipped weights
hf_model.flip_weights()
ipc_handles = hf_model.get_weight_ipc_handles()
llm._collective_rpc("update_weights", (ipc_handles,))
results.append(run_generate(llm, hf_model, prompts, sampling_params))
llm_logits, ref_logits = results[1]
# Compare logits against the HF reference; with flipped weights the generated text itself is not meaningful
compare_logits(llm_logits, ref_logits)

View File

@ -195,6 +195,10 @@ methods:
annotation: Optional[str]
default: null
status: prototype
enable_sleep:
annotation: bool
default: False
status: prototype
return_annotation: None
generate:
parameters:

View File

@ -16,6 +16,7 @@
import os
import sys
import traceback
import warnings
from functools import partial
from typing import Any
@ -24,6 +25,7 @@ import pytest
import torch
import tqdm
from mpi4py.futures import MPIPoolExecutor
from utils.util import get_current_process_gpu_memory
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from integration.defs import test_list_parser
@ -316,3 +318,28 @@ def _maybe_force_ray(request, monkeypatch, ray_mode):
raising=False)
except Exception:
pass
@pytest.fixture(scope="module")
def process_gpu_memory_info_available():
"""
Checks if NVML can get per-process memory information.
"""
# Allocate a small tensor to test memory tracking
tensor = torch.zeros(4096, dtype=torch.int32, device='cuda')
torch.cuda.synchronize()
# Try to get memory usage
usage = get_current_process_gpu_memory()
# Clean up
del tensor
torch.cuda.synchronize()
torch.cuda.empty_cache()
if usage == 0:
warnings.warn("Per process memory information unavailable.")
return False
return True

View File

@ -22,6 +22,7 @@ from difflib import SequenceMatcher
from pathlib import Path
from typing import Any, Generator
import psutil
import pynvml
import pytest
import tensorrt as trt
@ -515,3 +516,33 @@ def assert_no_cuda_sync(
sleep_ctl.cancel()
scope_finished_event.synchronize()
_pynvmlInited = False
def get_current_process_gpu_memory(include_subprocess: bool = False) -> int:
"""
Returns GPU memory usage on GPU 0, in bytes, for the current process (and optionally its child processes).
"""
global _pynvmlInited
if not _pynvmlInited:
pynvml.nvmlInit()
_pynvmlInited = True
# Get current process ID
targets = [os.getpid()]
if include_subprocess:
targets.extend(
p.pid for p in psutil.Process(targets[0]).children(recursive=True))
targets = frozenset(targets)
# Get device handle for GPU 0
device_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
# Get running processes
processes = pynvml.nvmlDeviceGetComputeRunningProcesses(device_handle)
# Sum usage across the target processes
return sum(process.usedGpuMemory for process in processes
if process.pid in targets)
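
Usage sketch for the helper above: include_subprocess=True is what test_llm_sleep relies on, since with the Ray orchestrator the GPU allocations live in spawned worker processes rather than in the test process itself.

# Usage sketch from test code.
from utils.util import get_current_process_gpu_memory

mem_bytes = get_current_process_gpu_memory(include_subprocess=True)
print(f"GPU 0 memory used by this process tree: {mem_bytes / 2**30:.2f} GiB")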