[feat] Add explicit /start_weight_update and /finish_weight_update APIs for weight transfer (#39212)

Sumanth R Hegde
2026-05-08 18:03:33 -07:00
committed by GitHub
parent 30f519e947
commit e3b65a5ba0
19 changed files with 287 additions and 53 deletions
+14 -4
@@ -4,10 +4,12 @@ vLLM provides a pluggable weight transfer system for synchronizing model weights
## Architecture
The weight transfer system follows a **two-phase protocol** with a pluggable backend design:
The weight transfer system follows a **four-phase protocol** with a pluggable backend design:
1. **Initialization** (`init_weight_transfer_engine`): Establishes the communication channel between the trainer and inference workers. Called once before the training loop begins.
2. **Weight Update** (`update_weights`): Transfers updated weights from the trainer to the inference engine. Called after each training step (or batch of steps).
2. **Start** (`start_weight_update`): Prepares the inference engine for a weight update.
3. **Weight Update** (`update_weights`): Transfers updated weights from the trainer to the inference engine. May be called one or more times (e.g., for chunked transfers).
4. **Finish** (`finish_weight_update`): Finalizes the weight update (e.g., runs post-processing for checkpoint-format weights). Called once after all weights have been transferred.
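
The ordering of the four phases can be illustrated with a toy, in-memory stand-in. This is a minimal sketch, not vLLM code: `WeightUpdateSession` and its fields are hypothetical names, and the rule that initialization must precede `start_weight_update` is an assumption of the sketch rather than something the worker is shown to enforce.

```python
class WeightUpdateSession:
    """Hypothetical stand-in for the inference side; tracks call order only."""

    def __init__(self) -> None:
        self.initialized = False
        self.active = False
        self.chunks_received = 0

    def init_weight_transfer_engine(self) -> None:
        # Phase 1: called once before the training loop begins.
        self.initialized = True

    def start_weight_update(self, is_checkpoint_format: bool = True) -> None:
        # Phase 2: open an update; reject a second start while one is active.
        if not self.initialized:
            raise RuntimeError("init_weight_transfer_engine must be called first.")
        if self.active:
            raise RuntimeError("A weight update is already active.")
        self.active = True
        self.chunks_received = 0

    def update_weights(self, chunk: list[str]) -> None:
        # Phase 3: may be called one or more times between start and finish.
        if not self.active:
            raise RuntimeError(
                "start_weight_update must be called before update_weights."
            )
        self.chunks_received += 1

    def finish_weight_update(self) -> None:
        # Phase 4: close the update so a new one can be started later.
        if not self.active:
            raise RuntimeError(
                "start_weight_update must be called before finish_weight_update."
            )
        self.active = False


session = WeightUpdateSession()
session.init_weight_transfer_engine()
session.start_weight_update(is_checkpoint_format=True)
for chunk in (["layer.0.weight"], ["layer.1.weight"]):
    session.update_weights(chunk)
session.finish_weight_update()
```

Calling `update_weights` or `finish_weight_update` on this toy object outside an active update raises `RuntimeError`, mirroring the ordering requirement described in the list above.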
## Available Backends
@@ -48,7 +50,9 @@ When running vLLM as an HTTP server, the following endpoints are available for w
| Endpoint | Method | Description |
| -------- | ------ | ----------- |
| `/init_weight_transfer_engine` | POST | Initialize the weight transfer engine with backend-specific info |
| `/update_weights` | POST | Trigger a weight update with backend-specific metadata |
| `/start_weight_update` | POST | Start a weight update |
| `/update_weights` | POST | Transfer a batch of weights with backend-specific metadata |
| `/finish_weight_update` | POST | Finish the weight update and run post-processing |
| `/pause` | POST | Pause generation before weight sync to handle inflight requests |
| `/resume` | POST | Resume generation after weight sync |
| `/get_world_size` | GET | Get the number of inference workers (useful for NCCL world size calculation) |
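
As a rough illustration of how these endpoints might be sequenced during one sync, the sketch below drives them through an injected `post` callable so it can run without a server. The helper name `sync_weights_over_http`, the empty JSON bodies for `/pause` and `/resume`, and the opaque `update_payloads` are assumptions for illustration; the real `/update_weights` body is backend-specific, and with the IPC backend `trainer_send_weights` issues that request itself.

```python
from typing import Any, Callable

# Hypothetical signature: post(url, json_body) performs one HTTP POST.
PostFn = Callable[[str, dict[str, Any]], None]


def sync_weights_over_http(
    base_url: str,
    post: PostFn,
    update_payloads: list[dict[str, Any]],
    is_checkpoint_format: bool = True,
) -> None:
    """Issue the weight-sync endpoints in the documented order."""
    post(f"{base_url}/pause", {})
    post(
        f"{base_url}/start_weight_update",
        {"is_checkpoint_format": is_checkpoint_format},
    )
    # One POST per batch of weights; payload contents are backend-specific.
    for payload in update_payloads:
        post(f"{base_url}/update_weights", payload)
    post(f"{base_url}/finish_weight_update", {})
    post(f"{base_url}/resume", {})


# Dry run with a recording stub instead of a real HTTP client.
calls: list[tuple[str, dict[str, Any]]] = []
sync_weights_over_http(
    "http://localhost:8000",
    post=lambda url, body: calls.append((url, body)),
    update_payloads=[{"update_info": {}}, {"update_info": {}}],
)
```

Against a live server, `post` could wrap `requests.post(url, json=body, timeout=60)` followed by `raise_for_status()`, in the same way as the example scripts in this change.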
@@ -64,11 +68,17 @@ Both backends provide static methods that the trainer calls to send weights. The
# 1. Initialize the transfer engine (backend-specific)
EngineClass.trainer_init(init_info)
# 2. Send weights to inference workers
# 2. Start weight update on inference side
llm.start_weight_update(is_checkpoint_format=True)
# 3. Send weights to inference workers
EngineClass.trainer_send_weights(
iterator=model.named_parameters(),
trainer_args=backend_specific_args,
)
# 4. Finish weight update on inference side
llm.finish_weight_update()
```
See the [NCCL](nccl.md) and [IPC](ipc.md) pages for backend-specific trainer APIs and full examples.
+2 -4
@@ -43,16 +43,14 @@ update_request = WeightTransferUpdateRequest(
### WeightTransferUpdateInfo
The base `WeightTransferUpdateInfo` includes an `is_checkpoint_format` flag:
The base `WeightTransferUpdateInfo` is a marker class for backend-specific update info:
```python
@dataclass
class WeightTransferUpdateInfo(ABC):
is_checkpoint_format: bool = True
pass
```
When `is_checkpoint_format=True` (the default), vLLM applies layerwise weight processing (repacking, renaming, etc.) on the received weights before loading them. Set to `False` if the trainer has already converted weights to the kernel format expected by the model.
## Implementing a Custom Engine
To create a custom weight transfer backend:
+16 -2
@@ -38,11 +38,15 @@ trainer_args = IPCTrainerSendWeightsArgs(
mode="ray",
llm_handle=llm_actor_handle,
)
# start
ray.get(llm_actor_handle.start_weight_update.remote(is_checkpoint_format=True))
# send weights
IPCWeightTransferEngine.trainer_send_weights(
iterator=model.named_parameters(),
trainer_args=trainer_args,
)
# finish
ray.get(llm_actor_handle.finish_weight_update.remote())
```
In Ray mode, the engine calls `llm_handle.update_weights.remote(...)` directly, passing the IPC handles via Ray's serialization.
@@ -57,13 +61,23 @@ trainer_args = IPCTrainerSendWeightsArgs(
url="http://localhost:8000",
)
# start
base_url = "http://localhost:8000"
url = f"{base_url}/start_weight_update"
response = requests.post(url, json={"is_checkpoint_format": True}, timeout=60)
response.raise_for_status()
# send weights
IPCWeightTransferEngine.trainer_send_weights(
iterator=model.named_parameters(),
trainer_args=trainer_args,
)
# finish
url = f"{base_url}/finish_weight_update"
response = requests.post(url, json={}, timeout=60)
response.raise_for_status()
```
In HTTP mode, IPC handles are pickled, base64-encoded, and sent as JSON to the `/update_weights` endpoint.
In HTTP mode, IPC handles are pickled, base64-encoded, and sent as JSON to the `/update_weights` endpoint. As with Ray mode, you must call `start_weight_update` before and `finish_weight_update` after.
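
The pickle-then-base64 step can be sketched in isolation. This is only an illustration of the encoding round trip: the helper names, the `"payload"` key, and the `fake_handles` dictionary are hypothetical stand-ins, not the field names or handle objects the IPC engine actually sends.

```python
import base64
import json
import pickle


def encode_for_json(obj: object) -> str:
    """Pickle an object and wrap the bytes as ASCII text that is JSON-safe."""
    return base64.b64encode(pickle.dumps(obj)).decode("ascii")


def decode_from_json(text: str) -> object:
    """Reverse encode_for_json. Only unpickle data from a trusted sender."""
    return pickle.loads(base64.b64decode(text.encode("ascii")))


# Stand-in for real IPC handles, which are opaque picklable objects.
fake_handles = {"layer.weight": ("device-0", 1024)}

body = json.dumps({"payload": encode_for_json(fake_handles)})
restored = decode_from_json(json.loads(body)["payload"])
```

Because unpickling can execute arbitrary code, this kind of payload is only appropriate between processes that trust each other, such as a trainer and an inference server under the same operator.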
See [`IPCTrainerSendWeightsArgs`](https://github.com/vllm-project/vllm/blob/main/vllm/distributed/weight_transfer/ipc_engine.py) for the full list of configurable fields.
+10 -1
@@ -84,11 +84,15 @@ Both the trainer (`NCCLTrainerSendWeightsArgs`) and inference side (`NCCLWeightT
## Receiving Weights (Inference Side)
The inference side triggers weight reception by calling `update_weights`:
The inference side triggers weight reception using the four-phase protocol — `init_weight_transfer_engine`, `start_weight_update`, `update_weights`, `finish_weight_update`. The init phase is shown [above](#initialization). The remaining three steps are:
```python
from vllm.distributed.weight_transfer.base import WeightTransferUpdateRequest
# 1. Start the weight update
llm.start_weight_update(is_checkpoint_format=True)
# 2. Receive weights (can be called multiple times for chunked transfers)
llm.update_weights(
WeightTransferUpdateRequest(
update_info=dict(
@@ -99,10 +103,15 @@ llm.update_weights(
)
)
)
# 3. Finish the weight update
llm.finish_weight_update()
```
The `names`, `dtype_names`, and `shapes` lists describe each parameter. These must match the order in which the trainer iterates over its parameters.
`start_weight_update` must be called before `update_weights`, and `finish_weight_update` must be called after all weight chunks have been transferred. The `is_checkpoint_format` flag controls whether layerwise reload processing is applied (`True` for checkpoint-format weights, `False` for pre-processed kernel-format weights).
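
For chunked transfers, the per-parameter metadata has to be split consistently so that each `update_weights` call describes one contiguous slice. The helper below is a minimal sketch under that assumption; `chunk_metadata` is a hypothetical name, any other fields a backend needs in `update_info` are omitted, and the trainer would have to send tensors in matching chunks and order.

```python
from collections.abc import Iterator


def chunk_metadata(
    names: list[str],
    dtype_names: list[str],
    shapes: list[list[int]],
    chunk_size: int,
) -> Iterator[dict[str, list]]:
    """Yield aligned slices of the three metadata lists, in original order."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not (len(names) == len(dtype_names) == len(shapes)):
        raise ValueError("names, dtype_names, and shapes must have equal length")
    for start in range(0, len(names), chunk_size):
        end = start + chunk_size
        yield {
            "names": names[start:end],
            "dtype_names": dtype_names[start:end],
            "shapes": shapes[start:end],
        }


chunks = list(
    chunk_metadata(
        names=["a.weight", "a.bias", "b.weight"],
        dtype_names=["float32", "float32", "float32"],
        shapes=[[4, 4], [4], [2, 4]],
        chunk_size=2,
    )
)
```

Each yielded dictionary could then be merged into the backend-specific `update_info` for one `update_weights` call, with `start_weight_update` issued once before the first chunk and `finish_weight_update` once after the last.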
## Examples
- [RLHF with NCCL weight syncing (offline, Ray)](../../examples/rl/rlhf_nccl.md) - Trainer on one GPU, 2x tensor-parallel vLLM engine on two others, with packed NCCL weight broadcast
+4
@@ -310,6 +310,8 @@ gen_futures = [
ray.get(llm.pause_after_n_tokens.remote())
ray.get(llm.start_weight_update.remote(is_checkpoint_format=True))
inference_handle = llm.update_weights.remote(
WeightTransferUpdateRequest(
update_info=asdict(
@@ -325,6 +327,8 @@ inference_handle = llm.update_weights.remote(
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])
ray.get(llm.finish_weight_update.remote())
ray.get(llm.resume_generation.remote())
results = ray.get(gen_futures)
+23 -1
@@ -80,6 +80,24 @@ def init_weight_transfer_engine(base_url: str) -> None:
response.raise_for_status()
def start_weight_update(
base_url: str,
is_checkpoint_format: bool = True,
) -> None:
"""Start a weight update via HTTP endpoint."""
url = f"{base_url}/start_weight_update"
payload = {"is_checkpoint_format": is_checkpoint_format}
response = requests.post(url, json=payload, timeout=60)
response.raise_for_status()
def finish_weight_update(base_url: str) -> None:
"""Finish a weight update via HTTP endpoint."""
url = f"{base_url}/finish_weight_update"
response = requests.post(url, json={}, timeout=60)
response.raise_for_status()
def pause_generation(base_url: str) -> None:
"""Pause generation via HTTP endpoint."""
url = f"{base_url}/pause"
@@ -151,7 +169,9 @@ def main():
# Pause generation before weight sync
pause_generation(BASE_URL)
# Broadcast weights via IPC handles using HTTP mode
# Start weight update, broadcast via IPC, then finish
start_weight_update(BASE_URL, is_checkpoint_format=False)
print("Broadcasting weights via CUDA IPC (HTTP)...")
trainer_args = IPCTrainerSendWeightsArgs(mode="http", url=BASE_URL)
IPCWeightTransferEngine.trainer_send_weights(
@@ -159,6 +179,8 @@ def main():
trainer_args=trainer_args,
)
finish_weight_update(BASE_URL)
# Resume generation after weight sync
resume_generation(BASE_URL)
+24
@@ -83,6 +83,17 @@ def init_weight_transfer_engine(
response.raise_for_status()
def start_weight_update(
base_url: str,
is_checkpoint_format: bool = True,
) -> None:
"""Start a weight update via HTTP endpoint."""
url = f"{base_url}/start_weight_update"
payload = {"is_checkpoint_format": is_checkpoint_format}
response = requests.post(url, json=payload, timeout=60)
response.raise_for_status()
def update_weights(
base_url: str,
names: list[str],
@@ -104,6 +115,13 @@ def update_weights(
response.raise_for_status()
def finish_weight_update(base_url: str) -> None:
"""Finish a weight update via HTTP endpoint."""
url = f"{base_url}/finish_weight_update"
response = requests.post(url, json={}, timeout=60)
response.raise_for_status()
def pause_generation(base_url: str) -> None:
"""Pause generation via HTTP endpoint."""
url = f"{base_url}/pause"
@@ -204,6 +222,9 @@ def main():
dtype_names.append(str(p.dtype).split(".")[-1])
shapes.append(list(p.shape))
# Start weight update
start_weight_update(BASE_URL, is_checkpoint_format=True)
# Start the update_weights call in a separate thread since it will block
# waiting for NCCL broadcasts
# packed=True enables efficient batched tensor broadcasting
@@ -227,6 +248,9 @@ def main():
# Wait for update_weights to complete
update_thread.join()
# Finish weight update
finish_weight_update(BASE_URL)
# Resume generation after weight sync
resume_generation(BASE_URL)
+3 -1
@@ -134,8 +134,10 @@ for output in outputs:
ray.get(llm.sleep.remote(level=0))
ray.get(train_model.init_weight_transfer.remote())
# Synchronize the updated weights to the inference engine using batched API.
# Start weight update, sync weights, then finish
ray.get(llm.start_weight_update.remote(is_checkpoint_format=True))
ray.get(train_model.broadcast_weights.remote(llm))
ray.get(llm.finish_weight_update.remote())
ray.get(llm.wake_up.remote(tags=["scheduling"]))
+6
@@ -186,6 +186,9 @@ ray.get([train_handle, inference_handle])
# Collect all weight metadata from the training actor
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())
# Start weight update
ray.get(llm.start_weight_update.remote(is_checkpoint_format=True))
# Issue update_weights call with NCCL-specific update info
# packed=True enables efficient batched tensor broadcasting
inference_handle = llm.update_weights.remote(
@@ -203,6 +206,9 @@ inference_handle = llm.update_weights.remote(
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])
# Finish weight update
ray.get(llm.finish_weight_update.remote())
ray.get(llm.wake_up.remote(tags=["scheduling"]))
# Generate text with the updated model. The output is expected to be normal
+5
@@ -298,6 +298,9 @@ async def main():
names, dtype_names, shapes = ray.get(fsdp_workers[0].get_weight_metadata.remote())
print(f"[sync] Got metadata for {len(names)} parameters.")
print("[sync] Starting weight update...")
await engine.start_weight_update(is_checkpoint_format=True)
print("[sync] Broadcasting weights from FSDP → vLLM...")
broadcast_handles = [
w.gather_and_broadcast_weights.remote(packed=True) for w in fsdp_workers
@@ -315,6 +318,8 @@ async def main():
)
)
ray.get(broadcast_handles)
await engine.finish_weight_update()
print("[sync] Weight broadcast complete.")
print("[sync] Resuming generation...")
@@ -133,7 +133,9 @@ def test_openapi_stateless(case: Case):
# (weight_transfer_config) and are meant to be stateful.
if case.operation.path in (
"/init_weight_transfer_engine",
"/start_weight_update",
"/update_weights",
"/finish_weight_update",
):
return
@@ -199,6 +199,9 @@ def test_update_weights_calls_engine():
WeightTransferInitRequest(init_info={"test_param": "init"})
)
# Start weight update (required before update_weights)
llm.start_weight_update(is_checkpoint_format=True)
# Call update_weights
test_names = ["layer.weight", "layer.bias"]
test_dtypes = ["float32", "float32"]
@@ -229,10 +232,14 @@ def test_update_weights_calls_engine():
assert dtypes == test_dtypes
assert shapes == test_shapes
# Finish weight update
llm.finish_weight_update()
@create_new_process_for_each_test()
def test_full_weight_transfer_flow():
"""Test the complete weight transfer flow: init -> update."""
"""Test the complete weight transfer flow:
init -> start -> update -> finish."""
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
@@ -253,12 +260,15 @@ def test_full_weight_transfer_flow():
weight_transfer_config=WeightTransferConfig(backend="nccl"),
)
# Step 1: Initialize
# Step 1: Initialize weight transfer engine
llm.init_weight_transfer_engine(
WeightTransferInitRequest(init_info={"test_param": "flow_test"})
)
# Step 2: Update weights
# Step 2: Start weight update
llm.start_weight_update(is_checkpoint_format=True)
# Step 3: Update weights
llm.update_weights(
WeightTransferUpdateRequest(
update_info={
@@ -269,6 +279,9 @@ def test_full_weight_transfer_flow():
)
)
# Step 4: Finish weight update
llm.finish_weight_update()
# Verify the full flow completed
def check_flow(self):
engine = self.weight_transfer_engine
+2 -6
@@ -4,7 +4,7 @@
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterator
from dataclasses import KW_ONLY, dataclass, field
from dataclasses import dataclass, field
from typing import Any, Generic, TypeVar
import torch
@@ -28,11 +28,7 @@ class WeightTransferInitInfo(ABC): # noqa: B024
class WeightTransferUpdateInfo(ABC): # noqa: B024
"""Base class for backend-specific weight update info."""
_: KW_ONLY
is_checkpoint_format: bool = True
"""Set to True if weights are in checkpoint/original model format and need
layerwise processing. Set to False if weights have already been processed
into kernel format (repacking, renaming, etc.)."""
pass
# API-level request classes (accept dicts for backend-agnostic serialization)
@@ -211,6 +211,11 @@ class IPCWeightTransferEngine(
- 'ray': Sends weights via Ray RPC to a Ray-based LLM handle
- 'http': Sends weights via HTTP POST to a vLLM HTTP server
.. note::
This method calls ``update_weights`` internally. The caller must
call ``start_weight_update`` before and ``finish_weight_update``
after this method.
Args:
iterator: Iterator of model parameters. Returns (name, tensor) tuples.
Tensors should be on the same GPU as the inference workers.
+8
@@ -230,6 +230,14 @@ class EngineClient(ABC):
"""Initialize weight transfer for RL training."""
raise NotImplementedError
async def start_weight_update(self, is_checkpoint_format: bool = True) -> None:
"""Start a new weight update."""
raise NotImplementedError
async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
"""Batched weight update for RL training."""
raise NotImplementedError
async def finish_weight_update(self) -> None:
"""Finish the current weight update."""
raise NotImplementedError
+20
@@ -1908,6 +1908,20 @@ class LLM:
"init_weight_transfer_engine", kwargs={"init_info": init_info_dict}
)
def start_weight_update(self, is_checkpoint_format: bool = True) -> None:
"""
Start a new weight update.
Args:
is_checkpoint_format: Whether incoming weights are in checkpoint
format (need layerwise processing) or kernel format (direct
copy).
"""
self.llm_engine.collective_rpc(
"start_weight_update",
kwargs={"is_checkpoint_format": is_checkpoint_format},
)
def update_weights(self, request: WeightTransferUpdateRequest | dict) -> None:
"""
Update the weights of the model.
@@ -1923,6 +1937,12 @@ class LLM:
"update_weights", kwargs={"update_info": update_info_dict}
)
def finish_weight_update(self) -> None:
"""
Finish the current weight update.
"""
self.llm_engine.collective_rpc("finish_weight_update")
def __repr__(self) -> str:
"""Return a transformers-style hierarchical view of the model."""
# Cache the result to avoid repeated collective_rpc calls
+19
@@ -128,6 +128,19 @@ async def init_weight_transfer_engine(raw_request: Request):
return JSONResponse(content={"message": "Weight transfer initialized"})
@router.post("/start_weight_update")
async def start_weight_update(raw_request: Request):
try:
body = await raw_request.json()
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904
is_checkpoint_format = body.get("is_checkpoint_format", True)
await engine_client(raw_request).start_weight_update(
is_checkpoint_format=is_checkpoint_format
)
return JSONResponse(content={"message": "Weight update started"})
@router.post("/update_weights")
async def update_weights(raw_request: Request):
try:
@@ -146,6 +159,12 @@ async def update_weights(raw_request: Request):
return JSONResponse(content={"message": "Weights updated"})
@router.post("/finish_weight_update")
async def finish_weight_update(raw_request: Request):
await engine_client(raw_request).finish_weight_update()
return JSONResponse(content={"message": "Weight update finished"})
@router.get("/get_world_size")
async def get_world_size(
raw_request: Request,
+11
@@ -1049,6 +1049,13 @@ class AsyncLLM(EngineClient):
"init_weight_transfer_engine", kwargs={"init_info": init_info_dict}
)
async def start_weight_update(self, is_checkpoint_format: bool = True) -> None:
"""Start a new weight update."""
await self.collective_rpc(
"start_weight_update",
kwargs={"is_checkpoint_format": is_checkpoint_format},
)
async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
"""
Batched weight update for RL training.
@@ -1067,3 +1074,7 @@ class AsyncLLM(EngineClient):
await self.collective_rpc(
"update_weights", kwargs={"update_info": update_info_dict}
)
async def finish_weight_update(self) -> None:
"""Finish the current weight update."""
await self.collective_rpc("finish_weight_update")
+97 -31
@@ -140,6 +140,8 @@ class Worker(WorkerBase):
if self.vllm_config.weight_transfer_config is not None
else None
)
self._weight_update_active = False
self._is_checkpoint_format = True
# Torch/CUDA profiler. Enabled and configured through profiler_config.
# Profiler wrapper is created lazily in profile() when start is called,
@@ -966,6 +968,13 @@ class Worker(WorkerBase):
model_config=self.model_config,
)
def _check_weight_transfer_engine(self) -> None:
if self.weight_transfer_engine is None:
raise RuntimeError(
"Weight transfer not configured. "
"Please set weight_transfer_config to enable weight transfer."
)
def init_weight_transfer_engine(self, init_info: dict) -> None:
"""
Initialize weight transfer mechanism.
@@ -974,26 +983,62 @@ class Worker(WorkerBase):
Args:
init_info: Dictionary containing backend-specific initialization info
"""
if self.weight_transfer_engine is None:
raise RuntimeError(
"Weight transfer not configured. "
"Please set weight_transfer_config to enable weight transfer."
)
self._check_weight_transfer_engine()
assert self.weight_transfer_engine is not None
# Parse dict into backend-specific typed dataclass
typed_init_info = self.weight_transfer_engine.parse_init_info(init_info)
self.weight_transfer_engine.init_transfer_engine(typed_init_info)
def start_weight_update(self, is_checkpoint_format: bool = True) -> None:
"""
Start a new weight update.
Prepares the model for receiving weights. For checkpoint format,
this initializes state for layerwise processing. For kernel format, this is
a no-op but must still be called for consistency.
Args:
is_checkpoint_format: Whether incoming weights are in checkpoint
format (need layerwise processing) or kernel format (direct
copy). Stored as state for finish_weight_update.
"""
self._check_weight_transfer_engine()
if self._weight_update_active:
raise RuntimeError(
"start_weight_update called while a weight update is "
"already active. Call finish_weight_update first."
)
if is_checkpoint_format:
from vllm.model_executor.model_loader.reload import (
initialize_layerwise_reload,
)
model = self.model_runner.model
with torch.device(self.device):
initialize_layerwise_reload(model)
# Store state so update_weights/finish_weight_update can check
self._is_checkpoint_format = is_checkpoint_format
self._weight_update_active = True
def update_weights(self, update_info: dict) -> None:
"""
Batched weight update from the trainer.
Receive weights from the trainer (one or more chunks).
start_weight_update must be called before update_weights and
finish_weight_update must be called after.
Args:
update_info: Dictionary containing backend-specific update info
"""
if self.weight_transfer_engine is None:
self._check_weight_transfer_engine()
assert self.weight_transfer_engine is not None
if not self._weight_update_active:
raise RuntimeError(
"Weight transfer not configured. "
"Please set weight_transfer_config to enable weight transfer."
"start_weight_update must be called before update_weights."
)
# Parse dict into backend-specific typed dataclass
@@ -1001,38 +1046,59 @@ class Worker(WorkerBase):
model = self.model_runner.model
if typed_update_info.is_checkpoint_format:
from vllm.model_executor.model_loader.reload import (
finalize_layerwise_reload,
initialize_layerwise_reload,
)
# Use layerwise reload pattern for checkpoint format weights
with torch.device(self.device):
initialize_layerwise_reload(model)
with torch.device(self.device):
if self._is_checkpoint_format:
self.weight_transfer_engine.receive_weights(
typed_update_info,
load_weights=model.load_weights,
)
finalize_layerwise_reload(model, self.model_config)
else:
# Weights are already in kernel format, copy directly
def load_weights_direct(
weights: list[tuple[str, torch.Tensor]],
) -> None:
for name, weight in weights:
param = model.get_parameter(name)
param.copy_(weight)
else:
# Weights are already in kernel format, copy directly
def load_weights_direct(
weights: list[tuple[str, torch.Tensor]],
) -> None:
for name, weight in weights:
param = model.get_parameter(name)
param.copy_(weight)
self.weight_transfer_engine.receive_weights(
typed_update_info,
load_weights=load_weights_direct,
)
self.weight_transfer_engine.receive_weights(
typed_update_info,
load_weights=load_weights_direct,
)
# NCCL broadcast/packed path are asynchronous.
# Sync here so the next step uses the new weights.
torch.accelerator.synchronize()
def finish_weight_update(self) -> None:
"""
Finish the current weight update.
For checkpoint format, this runs layerwise postprocessing.
Uses the is_checkpoint_format state stored by start_weight_update.
"""
self._check_weight_transfer_engine()
if not self._weight_update_active:
raise RuntimeError(
"start_weight_update must be called before finish_weight_update."
)
is_checkpoint_format = self._is_checkpoint_format
if is_checkpoint_format:
from vllm.model_executor.model_loader.reload import (
finalize_layerwise_reload,
)
model = self.model_runner.model
with torch.device(self.device):
finalize_layerwise_reload(model, self.model_config)
# Reset state
self._weight_update_active = False
self._is_checkpoint_format = True
def shutdown(self) -> None:
# has_kv_transfer_group can be None during interpreter shutdown.
if ensure_kv_transfer_shutdown is not None: