TensorRT-LLMs/tensorrt_llm/llmapi/rlhf_utils.py
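
"""Utilities for RLHF workflows: streaming weight updates on TensorRT-LLM Ray workers."""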

from typing import Optional

import torch

from tensorrt_llm._ray_utils import control_action_decorator
from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer
from tensorrt_llm._torch.utils import get_device_uuid
from tensorrt_llm.logger import logger


class WorkerExtension:
"""Worker extension class for extending TensorRT-LLM Ray workers with custom functionality.
This class can be injected into tensorrt_llm.LLM() by specifying it via the
ray_worker_extension_cls parameter in LLMArgs when using orchestrator_type='ray'.
The extension methods will be available on each Ray worker and can be called via
the LLM's collective RPC mechanism.
Examples:
Creating an LLM with worker extension:
>>> llm = LLM(
... model=model_dir,
... orchestrator_type="ray",
... ray_worker_extension_cls="rlhf_utils.WorkerExtension",
... )
Calling extension methods via collective RPC:
>>> llm._collective_rpc("update_weights", args=(ipc_handles,))
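
        Streaming an update: weights can be sent in multiple partial chunks and
        then finalized by passing None (the chunking below is illustrative):

        >>> for chunk in ipc_handle_chunks:  # hypothetical per-chunk payloads
        ...     llm._collective_rpc("update_weights", args=(chunk,))
        >>> llm._collective_rpc("update_weights", args=(None,))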
"""

    @control_action_decorator
    def update_weights(self, ipc_handles: Optional[dict] = None):
        """Update model weights from IPC (Inter-Process Communication) handles.

        This method receives shared-memory handles from another process (typically FSDP
        training), reconstructs tensors from those handles, and loads them into the
        TensorRT-LLM model. The control_action_decorator ensures all active requests
        have finished before the weights are updated.

        Args:
            ipc_handles: Dictionary mapping device UUIDs to lists of
                (param_name, tensor_handle) tuples, where each tensor_handle is a
                (func, args) pair for reconstructing the tensor. Passing None signals
                that a streaming update is complete and triggers finalization.
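
                Example structure (names and values are illustrative only):

                    {
                        "GPU-<uuid>": [
                            ("model.layers.0.mlp.weight", (rebuild_func, rebuild_args)),
                            ...,
                        ],
                    }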

        Raises:
            ValueError: If the current device's UUID is not found in ipc_handles.
            Exception: Re-raises any exception encountered during the weight update.
        """
        try:
            if ipc_handles is not None:
                logger.info("Update weights from IPC handles")
                device_uuid = get_device_uuid(self.device_id)
                if device_uuid not in ipc_handles:
                    raise ValueError(f"Device UUID {device_uuid} not found in ipc_handles")

                # Rebuild each shared tensor from its (func, args) handle on this
                # worker's device.
                weights = {}
                all_handles = ipc_handles[device_uuid]
                for param_name, tensor_handle in all_handles:
                    func, args = tensor_handle
                    list_args = list(args)
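                    # Handle layout assumption: for CUDA tensors reduced with
                    # torch.multiprocessing.reductions.reduce_tensor, args[6] is
                    # the storage device index.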
                    list_args[6] = self.device_id  # Retarget the tensor to this device
                    tensor = func(*list_args)
                    weights[param_name] = tensor

                logger.info(f"Received {len(weights)} weight tensor(s)")
                self.engine.model_engine.model_loader.reload(
                    self.engine.model_engine.model, weights, allow_partial_loading=True
                )
            else:
                # A None payload marks the end of a streaming update: run the
                # modules' post-load hooks, finalize the MoE load balancer if one
                # is attached, and reset the prefix cache.
                logger.info("Finalize update weights")
                for module in self.engine.model_engine.model.modules():
                    if hasattr(module, "post_load_weights") and not getattr(
                        module, "_weights_removed", False
                    ):
                        module.post_load_weights()

                moe_load_balancer = getattr(self.engine.model_engine, "moe_load_balancer", None)
                if isinstance(moe_load_balancer, MoeLoadBalancer):
                    moe_load_balancer.register_weight_slots_after_to_cuda()
                    logger.info("moe_load_balancer finalizing model...")
                    moe_load_balancer.finalize_model()
                    logger.info("moe_load_balancer finalize model done")

                self.engine.reset_prefix_cache()
        except Exception as e:
            logger.error(f"Encountered an error in update_weights: {e}")
            raise

    def check_weights_updated(self):
        """Check whether all model weights have been updated to zero."""
        return all(
            torch.allclose(p, torch.zeros_like(p))
            for _, p in self.engine.model_engine.model.named_parameters()
        )
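

# --- Illustrative sketch (an assumption, not part of the TensorRT-LLM API) ---
# How a trainer process might build the `ipc_handles` payload that
# WorkerExtension.update_weights consumes. The helper name `build_ipc_handles`
# is hypothetical and introduced here for illustration only.
def build_ipc_handles(named_params, device_id: int) -> dict:
    """Build {device_uuid: [(param_name, (rebuild_func, rebuild_args)), ...]}."""
    from torch.multiprocessing.reductions import reduce_tensor

    # reduce_tensor returns a (rebuild_func, rebuild_args) pair; the worker
    # replays it as func(*args) to map the shared CUDA memory locally.
    handles = [(name, reduce_tensor(p.detach())) for name, p in named_params]
    return {get_device_uuid(device_id): handles}


# Typical streaming flow on the LLM side: send one or more chunks, then pass
# None to finalize (see WorkerExtension.update_weights):
#
#     llm._collective_rpc("update_weights", args=(build_ipc_handles(params, 0),))
#     llm._collective_rpc("update_weights", args=(None,))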