mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[feat] Add explicit /start_weight_update and /finish_weight_update APIs for weight transfer (#39212)
This commit is contained in:
@@ -4,10 +4,12 @@ vLLM provides a pluggable weight transfer system for synchronizing model weights
|
||||
|
||||
## Architecture
|
||||
|
||||
The weight transfer system follows a **two-phase protocol** with a pluggable backend design:
|
||||
The weight transfer system follows a **four-phase protocol** with a pluggable backend design:
|
||||
|
||||
1. **Initialization** (`init_weight_transfer_engine`): Establishes the communication channel between the trainer and inference workers. Called once before the training loop begins.
|
||||
2. **Weight Update** (`update_weights`): Transfers updated weights from the trainer to the inference engine. Called after each training step (or batch of steps).
|
||||
2. **Start** (`start_weight_update`): Prepares the inference engine for a weight update.
|
||||
3. **Weight Update** (`update_weights`): Transfers updated weights from the trainer to the inference engine. May be called one or more times (e.g., for chunked transfers).
|
||||
4. **Finish** (`finish_weight_update`): Finalizes the weight update (e.g., runs post-processing for checkpoint-format weights). Called once after all weights have been transferred.
|
||||
|
||||
## Available Backends
|
||||
|
||||
@@ -48,7 +50,9 @@ When running vLLM as an HTTP server, the following endpoints are available for w
|
||||
| Endpoint | Method | Description |
|
||||
| -------- | ------ | ----------- |
|
||||
| `/init_weight_transfer_engine` | POST | Initialize the weight transfer engine with backend-specific info |
|
||||
| `/update_weights` | POST | Trigger a weight update with backend-specific metadata |
|
||||
| `/start_weight_update` | POST | Start a weight update |
|
||||
| `/update_weights` | POST | Transfer a batch of weights with backend-specific metadata |
|
||||
| `/finish_weight_update` | POST | Finish the weight update and run post-processing |
|
||||
| `/pause` | POST | Pause generation before weight sync to handle inflight requests |
|
||||
| `/resume` | POST | Resume generation after weight sync |
|
||||
| `/get_world_size` | GET | Get the number of inference workers (useful for NCCL world size calculation) |
|
||||
@@ -64,11 +68,17 @@ Both backends provide static methods that the trainer calls to send weights. The
|
||||
# 1. Initialize the transfer engine (backend-specific)
|
||||
EngineClass.trainer_init(init_info)
|
||||
|
||||
# 2. Send weights to inference workers
|
||||
# 2. Start weight update on inference side
|
||||
llm.start_weight_update(is_checkpoint_format=True)
|
||||
|
||||
# 3. Send weights to inference workers
|
||||
EngineClass.trainer_send_weights(
|
||||
iterator=model.named_parameters(),
|
||||
trainer_args=backend_specific_args,
|
||||
)
|
||||
|
||||
# 4. Finish weight update on inference side
|
||||
llm.finish_weight_update()
|
||||
```
|
||||
|
||||
See the [NCCL](nccl.md) and [IPC](ipc.md) pages for backend-specific trainer APIs and full examples.
|
||||
|
||||
@@ -43,16 +43,14 @@ update_request = WeightTransferUpdateRequest(
|
||||
|
||||
### WeightTransferUpdateInfo
|
||||
|
||||
The base `WeightTransferUpdateInfo` includes an `is_checkpoint_format` flag:
|
||||
The base `WeightTransferUpdateInfo` is a marker class for backend-specific update info:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class WeightTransferUpdateInfo(ABC):
|
||||
is_checkpoint_format: bool = True
|
||||
pass
|
||||
```
|
||||
|
||||
When `is_checkpoint_format=True` (the default), vLLM applies layerwise weight processing (repacking, renaming, etc.) on the received weights before loading them. Set to `False` if the trainer has already converted weights to the kernel format expected by the model.
|
||||
|
||||
## Implementing a Custom Engine
|
||||
|
||||
To create a custom weight transfer backend:
|
||||
|
||||
@@ -38,11 +38,15 @@ trainer_args = IPCTrainerSendWeightsArgs(
|
||||
mode="ray",
|
||||
llm_handle=llm_actor_handle,
|
||||
)
|
||||
|
||||
# start
|
||||
ray.get(llm_actor_handle.start_weight_update.remote(is_checkpoint_format=True))
|
||||
# send weights
|
||||
IPCWeightTransferEngine.trainer_send_weights(
|
||||
iterator=model.named_parameters(),
|
||||
trainer_args=trainer_args,
|
||||
)
|
||||
# finish
|
||||
ray.get(llm_actor_handle.finish_weight_update.remote())
|
||||
```
|
||||
|
||||
In Ray mode, the engine calls `llm_handle.update_weights.remote(...)` directly, passing the IPC handles via Ray's serialization.
|
||||
@@ -57,13 +61,23 @@ trainer_args = IPCTrainerSendWeightsArgs(
|
||||
url="http://localhost:8000",
|
||||
)
|
||||
|
||||
# start
|
||||
base_url = "http://localhost:8000"
|
||||
url = f"{base_url}/start_weight_update"
|
||||
response = requests.post(url, json={"is_checkpoint_format": True}, timeout=60)
|
||||
response.raise_for_status()
|
||||
# send weights
|
||||
IPCWeightTransferEngine.trainer_send_weights(
|
||||
iterator=model.named_parameters(),
|
||||
trainer_args=trainer_args,
|
||||
)
|
||||
# finish
|
||||
url = f"{base_url}/finish_weight_update"
|
||||
response = requests.post(url, json={}, timeout=60)
|
||||
response.raise_for_status()
|
||||
```
|
||||
|
||||
In HTTP mode, IPC handles are pickled, base64-encoded, and sent as JSON to the `/update_weights` endpoint.
|
||||
In HTTP mode, IPC handles are pickled, base64-encoded, and sent as JSON to the `/update_weights` endpoint. As with Ray mode, you must call `start_weight_update` before and `finish_weight_update` after.
|
||||
|
||||
See [`IPCTrainerSendWeightsArgs`](https://github.com/vllm-project/vllm/blob/main/vllm/distributed/weight_transfer/ipc_engine.py) for the full list of configurable fields.
|
||||
|
||||
|
||||
@@ -84,11 +84,15 @@ Both the trainer (`NCCLTrainerSendWeightsArgs`) and inference side (`NCCLWeightT
|
||||
|
||||
## Receiving Weights (Inference Side)
|
||||
|
||||
The inference side triggers weight reception by calling `update_weights`:
|
||||
The inference side triggers weight reception using the four-phase protocol — `init_weight_transfer_engine`, `start_weight_update`, `update_weights`, `finish_weight_update`. The init phase is shown [above](#initialization). The remaining three steps are:
|
||||
|
||||
```python
|
||||
from vllm.distributed.weight_transfer.base import WeightTransferUpdateRequest
|
||||
|
||||
# 1. Start the weight update
|
||||
llm.start_weight_update(is_checkpoint_format=True)
|
||||
|
||||
# 2. Receive weights (can be called multiple times for chunked transfers)
|
||||
llm.update_weights(
|
||||
WeightTransferUpdateRequest(
|
||||
update_info=dict(
|
||||
@@ -99,10 +103,15 @@ llm.update_weights(
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# 3. Finish the weight update
|
||||
llm.finish_weight_update()
|
||||
```
|
||||
|
||||
The `names`, `dtype_names`, and `shapes` lists describe each parameter. These must match the order in which the trainer iterates over its parameters.
|
||||
|
||||
`start_weight_update` must be called before `update_weights`, and `finish_weight_update` must be called after all weight chunks have been transferred. The `is_checkpoint_format` flag controls whether layerwise reload processing is applied (`True` for checkpoint-format weights, `False` for pre-processed kernel-format weights).
|
||||
|
||||
## Examples
|
||||
|
||||
- [RLHF with NCCL weight syncing (offline, Ray)](../../examples/rl/rlhf_nccl.md) - Trainer on one GPU, 2x tensor-parallel vLLM engine on two others, with packed NCCL weight broadcast
|
||||
|
||||
@@ -310,6 +310,8 @@ gen_futures = [
|
||||
|
||||
ray.get(llm.pause_after_n_tokens.remote())
|
||||
|
||||
ray.get(llm.start_weight_update.remote(is_checkpoint_format=True))
|
||||
|
||||
inference_handle = llm.update_weights.remote(
|
||||
WeightTransferUpdateRequest(
|
||||
update_info=asdict(
|
||||
@@ -325,6 +327,8 @@ inference_handle = llm.update_weights.remote(
|
||||
train_handle = train_model.broadcast_weights.remote(packed=True)
|
||||
ray.get([train_handle, inference_handle])
|
||||
|
||||
ray.get(llm.finish_weight_update.remote())
|
||||
|
||||
ray.get(llm.resume_generation.remote())
|
||||
results = ray.get(gen_futures)
|
||||
|
||||
|
||||
@@ -80,6 +80,24 @@ def init_weight_transfer_engine(base_url: str) -> None:
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def start_weight_update(
|
||||
base_url: str,
|
||||
is_checkpoint_format: bool = True,
|
||||
) -> None:
|
||||
"""Start a weight update via HTTP endpoint."""
|
||||
url = f"{base_url}/start_weight_update"
|
||||
payload = {"is_checkpoint_format": is_checkpoint_format}
|
||||
response = requests.post(url, json=payload, timeout=60)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def finish_weight_update(base_url: str) -> None:
|
||||
"""Finish a weight update via HTTP endpoint."""
|
||||
url = f"{base_url}/finish_weight_update"
|
||||
response = requests.post(url, json={}, timeout=60)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def pause_generation(base_url: str) -> None:
|
||||
"""Pause generation via HTTP endpoint."""
|
||||
url = f"{base_url}/pause"
|
||||
@@ -151,7 +169,9 @@ def main():
|
||||
# Pause generation before weight sync
|
||||
pause_generation(BASE_URL)
|
||||
|
||||
# Broadcast weights via IPC handles using HTTP mode
|
||||
# Start weight update, broadcast via IPC, then finish
|
||||
start_weight_update(BASE_URL, is_checkpoint_format=False)
|
||||
|
||||
print("Broadcasting weights via CUDA IPC (HTTP)...")
|
||||
trainer_args = IPCTrainerSendWeightsArgs(mode="http", url=BASE_URL)
|
||||
IPCWeightTransferEngine.trainer_send_weights(
|
||||
@@ -159,6 +179,8 @@ def main():
|
||||
trainer_args=trainer_args,
|
||||
)
|
||||
|
||||
finish_weight_update(BASE_URL)
|
||||
|
||||
# Resume generation after weight sync
|
||||
resume_generation(BASE_URL)
|
||||
|
||||
|
||||
@@ -83,6 +83,17 @@ def init_weight_transfer_engine(
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def start_weight_update(
|
||||
base_url: str,
|
||||
is_checkpoint_format: bool = True,
|
||||
) -> None:
|
||||
"""Start a weight update via HTTP endpoint."""
|
||||
url = f"{base_url}/start_weight_update"
|
||||
payload = {"is_checkpoint_format": is_checkpoint_format}
|
||||
response = requests.post(url, json=payload, timeout=60)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def update_weights(
|
||||
base_url: str,
|
||||
names: list[str],
|
||||
@@ -104,6 +115,13 @@ def update_weights(
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def finish_weight_update(base_url: str) -> None:
|
||||
"""Finish a weight update via HTTP endpoint."""
|
||||
url = f"{base_url}/finish_weight_update"
|
||||
response = requests.post(url, json={}, timeout=60)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def pause_generation(base_url: str) -> None:
|
||||
"""Pause generation via HTTP endpoint."""
|
||||
url = f"{base_url}/pause"
|
||||
@@ -204,6 +222,9 @@ def main():
|
||||
dtype_names.append(str(p.dtype).split(".")[-1])
|
||||
shapes.append(list(p.shape))
|
||||
|
||||
# Start weight update
|
||||
start_weight_update(BASE_URL, is_checkpoint_format=True)
|
||||
|
||||
# Start the update_weights call in a separate thread since it will block
|
||||
# waiting for NCCL broadcasts
|
||||
# packed=True enables efficient batched tensor broadcasting
|
||||
@@ -227,6 +248,9 @@ def main():
|
||||
# Wait for update_weights to complete
|
||||
update_thread.join()
|
||||
|
||||
# Finish weight update
|
||||
finish_weight_update(BASE_URL)
|
||||
|
||||
# Resume generation after weight sync
|
||||
resume_generation(BASE_URL)
|
||||
|
||||
|
||||
@@ -134,8 +134,10 @@ for output in outputs:
|
||||
ray.get(llm.sleep.remote(level=0))
|
||||
|
||||
ray.get(train_model.init_weight_transfer.remote())
|
||||
# Synchronize the updated weights to the inference engine using batched API.
|
||||
# Start weight update, sync weights, then finish
|
||||
ray.get(llm.start_weight_update.remote(is_checkpoint_format=True))
|
||||
ray.get(train_model.broadcast_weights.remote(llm))
|
||||
ray.get(llm.finish_weight_update.remote())
|
||||
|
||||
ray.get(llm.wake_up.remote(tags=["scheduling"]))
|
||||
|
||||
|
||||
@@ -186,6 +186,9 @@ ray.get([train_handle, inference_handle])
|
||||
# Collect all weight metadata from the training actor
|
||||
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())
|
||||
|
||||
# Start weight update
|
||||
ray.get(llm.start_weight_update.remote(is_checkpoint_format=True))
|
||||
|
||||
# Issue update_weights call with NCCL-specific update info
|
||||
# packed=True enables efficient batched tensor broadcasting
|
||||
inference_handle = llm.update_weights.remote(
|
||||
@@ -203,6 +206,9 @@ inference_handle = llm.update_weights.remote(
|
||||
train_handle = train_model.broadcast_weights.remote(packed=True)
|
||||
ray.get([train_handle, inference_handle])
|
||||
|
||||
# Finish weight update
|
||||
ray.get(llm.finish_weight_update.remote())
|
||||
|
||||
ray.get(llm.wake_up.remote(tags=["scheduling"]))
|
||||
|
||||
# Generate text with the updated model. The output is expected to be normal
|
||||
|
||||
@@ -298,6 +298,9 @@ async def main():
|
||||
names, dtype_names, shapes = ray.get(fsdp_workers[0].get_weight_metadata.remote())
|
||||
print(f"[sync] Got metadata for {len(names)} parameters.")
|
||||
|
||||
print("[sync] Starting weight update...")
|
||||
await engine.start_weight_update(is_checkpoint_format=True)
|
||||
|
||||
print("[sync] Broadcasting weights from FSDP → vLLM...")
|
||||
broadcast_handles = [
|
||||
w.gather_and_broadcast_weights.remote(packed=True) for w in fsdp_workers
|
||||
@@ -315,6 +318,8 @@ async def main():
|
||||
)
|
||||
)
|
||||
ray.get(broadcast_handles)
|
||||
|
||||
await engine.finish_weight_update()
|
||||
print("[sync] Weight broadcast complete.")
|
||||
|
||||
print("[sync] Resuming generation...")
|
||||
|
||||
@@ -133,7 +133,9 @@ def test_openapi_stateless(case: Case):
|
||||
# (weight_transfer_config) and are meant to be stateful.
|
||||
if case.operation.path in (
|
||||
"/init_weight_transfer_engine",
|
||||
"/start_weight_update",
|
||||
"/update_weights",
|
||||
"/finish_weight_update",
|
||||
):
|
||||
return
|
||||
|
||||
|
||||
@@ -199,6 +199,9 @@ def test_update_weights_calls_engine():
|
||||
WeightTransferInitRequest(init_info={"test_param": "init"})
|
||||
)
|
||||
|
||||
# Start weight update (required before update_weights)
|
||||
llm.start_weight_update(is_checkpoint_format=True)
|
||||
|
||||
# Call update_weights
|
||||
test_names = ["layer.weight", "layer.bias"]
|
||||
test_dtypes = ["float32", "float32"]
|
||||
@@ -229,10 +232,14 @@ def test_update_weights_calls_engine():
|
||||
assert dtypes == test_dtypes
|
||||
assert shapes == test_shapes
|
||||
|
||||
# Finish weight update
|
||||
llm.finish_weight_update()
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
def test_full_weight_transfer_flow():
|
||||
"""Test the complete weight transfer flow: init -> update."""
|
||||
"""Test the complete weight transfer flow:
|
||||
init -> start -> update -> finish."""
|
||||
if torch.accelerator.device_count() < 1:
|
||||
pytest.skip("Need at least 1 GPU for this test")
|
||||
|
||||
@@ -253,12 +260,15 @@ def test_full_weight_transfer_flow():
|
||||
weight_transfer_config=WeightTransferConfig(backend="nccl"),
|
||||
)
|
||||
|
||||
# Step 1: Initialize
|
||||
# Step 1: Initialize weight transfer engine
|
||||
llm.init_weight_transfer_engine(
|
||||
WeightTransferInitRequest(init_info={"test_param": "flow_test"})
|
||||
)
|
||||
|
||||
# Step 2: Update weights
|
||||
# Step 2: Start weight update
|
||||
llm.start_weight_update(is_checkpoint_format=True)
|
||||
|
||||
# Step 3: Update weights
|
||||
llm.update_weights(
|
||||
WeightTransferUpdateRequest(
|
||||
update_info={
|
||||
@@ -269,6 +279,9 @@ def test_full_weight_transfer_flow():
|
||||
)
|
||||
)
|
||||
|
||||
# Step 4: Finish weight update
|
||||
llm.finish_weight_update()
|
||||
|
||||
# Verify the full flow completed
|
||||
def check_flow(self):
|
||||
engine = self.weight_transfer_engine
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterator
|
||||
from dataclasses import KW_ONLY, dataclass, field
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import torch
|
||||
@@ -28,11 +28,7 @@ class WeightTransferInitInfo(ABC): # noqa: B024
|
||||
class WeightTransferUpdateInfo(ABC): # noqa: B024
|
||||
"""Base class for backend-specific weight update info."""
|
||||
|
||||
_: KW_ONLY
|
||||
is_checkpoint_format: bool = True
|
||||
"""Set to True if weights are in checkpoint/original model format and need
|
||||
layerwise processing. Set to False if weights have already been processed
|
||||
into kernel format (repacking, renaming, etc.)."""
|
||||
pass
|
||||
|
||||
|
||||
# API-level request classes (accept dicts for backend-agnostic serialization)
|
||||
|
||||
@@ -211,6 +211,11 @@ class IPCWeightTransferEngine(
|
||||
- 'ray': Sends weights via Ray RPC to a Ray-based LLM handle
|
||||
- 'http': Sends weights via HTTP POST to a vLLM HTTP server
|
||||
|
||||
.. note::
|
||||
This method calls ``update_weights`` internally. The caller must
|
||||
call ``start_weight_update`` before and ``finish_weight_update``
|
||||
after this method.
|
||||
|
||||
Args:
|
||||
iterator: Iterator of model parameters. Returns (name, tensor) tuples.
|
||||
Tensors should be on the same GPU as the inference workers.
|
||||
|
||||
@@ -230,6 +230,14 @@ class EngineClient(ABC):
|
||||
"""Initialize weight transfer for RL training."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def start_weight_update(self, is_checkpoint_format: bool = True) -> None:
|
||||
"""Start a new weight update."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
|
||||
"""Batched weight update for RL training."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def finish_weight_update(self) -> None:
|
||||
"""Finish the current weight update."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1908,6 +1908,20 @@ class LLM:
|
||||
"init_weight_transfer_engine", kwargs={"init_info": init_info_dict}
|
||||
)
|
||||
|
||||
def start_weight_update(self, is_checkpoint_format: bool = True) -> None:
|
||||
"""
|
||||
Start a new weight update.
|
||||
|
||||
Args:
|
||||
is_checkpoint_format: Whether incoming weights are in checkpoint
|
||||
format (need layerwise processing) or kernel format (direct
|
||||
copy).
|
||||
"""
|
||||
self.llm_engine.collective_rpc(
|
||||
"start_weight_update",
|
||||
kwargs={"is_checkpoint_format": is_checkpoint_format},
|
||||
)
|
||||
|
||||
def update_weights(self, request: WeightTransferUpdateRequest | dict) -> None:
|
||||
"""
|
||||
Update the weights of the model.
|
||||
@@ -1923,6 +1937,12 @@ class LLM:
|
||||
"update_weights", kwargs={"update_info": update_info_dict}
|
||||
)
|
||||
|
||||
def finish_weight_update(self) -> None:
|
||||
"""
|
||||
Finish the current weight update.
|
||||
"""
|
||||
self.llm_engine.collective_rpc("finish_weight_update")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return a transformers-style hierarchical view of the model."""
|
||||
# Cache the result to avoid repeated collective_rpc calls
|
||||
|
||||
@@ -128,6 +128,19 @@ async def init_weight_transfer_engine(raw_request: Request):
|
||||
return JSONResponse(content={"message": "Weight transfer initialized"})
|
||||
|
||||
|
||||
@router.post("/start_weight_update")
|
||||
async def start_weight_update(raw_request: Request):
|
||||
try:
|
||||
body = await raw_request.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904
|
||||
is_checkpoint_format = body.get("is_checkpoint_format", True)
|
||||
await engine_client(raw_request).start_weight_update(
|
||||
is_checkpoint_format=is_checkpoint_format
|
||||
)
|
||||
return JSONResponse(content={"message": "Weight update started"})
|
||||
|
||||
|
||||
@router.post("/update_weights")
|
||||
async def update_weights(raw_request: Request):
|
||||
try:
|
||||
@@ -146,6 +159,12 @@ async def update_weights(raw_request: Request):
|
||||
return JSONResponse(content={"message": "Weights updated"})
|
||||
|
||||
|
||||
@router.post("/finish_weight_update")
|
||||
async def finish_weight_update(raw_request: Request):
|
||||
await engine_client(raw_request).finish_weight_update()
|
||||
return JSONResponse(content={"message": "Weight update finished"})
|
||||
|
||||
|
||||
@router.get("/get_world_size")
|
||||
async def get_world_size(
|
||||
raw_request: Request,
|
||||
|
||||
@@ -1049,6 +1049,13 @@ class AsyncLLM(EngineClient):
|
||||
"init_weight_transfer_engine", kwargs={"init_info": init_info_dict}
|
||||
)
|
||||
|
||||
async def start_weight_update(self, is_checkpoint_format: bool = True) -> None:
|
||||
"""Start a new weight update."""
|
||||
await self.collective_rpc(
|
||||
"start_weight_update",
|
||||
kwargs={"is_checkpoint_format": is_checkpoint_format},
|
||||
)
|
||||
|
||||
async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
|
||||
"""
|
||||
Batched weight update for RL training.
|
||||
@@ -1067,3 +1074,7 @@ class AsyncLLM(EngineClient):
|
||||
await self.collective_rpc(
|
||||
"update_weights", kwargs={"update_info": update_info_dict}
|
||||
)
|
||||
|
||||
async def finish_weight_update(self) -> None:
|
||||
"""Finish the current weight update."""
|
||||
await self.collective_rpc("finish_weight_update")
|
||||
|
||||
@@ -140,6 +140,8 @@ class Worker(WorkerBase):
|
||||
if self.vllm_config.weight_transfer_config is not None
|
||||
else None
|
||||
)
|
||||
self._weight_update_active = False
|
||||
self._is_checkpoint_format = True
|
||||
|
||||
# Torch/CUDA profiler. Enabled and configured through profiler_config.
|
||||
# Profiler wrapper is created lazily in profile() when start is called,
|
||||
@@ -966,6 +968,13 @@ class Worker(WorkerBase):
|
||||
model_config=self.model_config,
|
||||
)
|
||||
|
||||
def _check_weight_transfer_engine(self) -> None:
|
||||
if self.weight_transfer_engine is None:
|
||||
raise RuntimeError(
|
||||
"Weight transfer not configured. "
|
||||
"Please set weight_transfer_config to enable weight transfer."
|
||||
)
|
||||
|
||||
def init_weight_transfer_engine(self, init_info: dict) -> None:
|
||||
"""
|
||||
Initialize weight transfer mechanism.
|
||||
@@ -974,26 +983,62 @@ class Worker(WorkerBase):
|
||||
Args:
|
||||
init_info: Dictionary containing backend-specific initialization info
|
||||
"""
|
||||
if self.weight_transfer_engine is None:
|
||||
raise RuntimeError(
|
||||
"Weight transfer not configured. "
|
||||
"Please set weight_transfer_config to enable weight transfer."
|
||||
)
|
||||
self._check_weight_transfer_engine()
|
||||
assert self.weight_transfer_engine is not None
|
||||
# Parse dict into backend-specific typed dataclass
|
||||
typed_init_info = self.weight_transfer_engine.parse_init_info(init_info)
|
||||
self.weight_transfer_engine.init_transfer_engine(typed_init_info)
|
||||
|
||||
def start_weight_update(self, is_checkpoint_format: bool = True) -> None:
|
||||
"""
|
||||
Start a new weight update.
|
||||
|
||||
Prepares the model for receiving weights. For checkpoint format,
|
||||
this initializes state for layerwise processing. For kernel format, this is
|
||||
a no-op but must still be called for consistency.
|
||||
|
||||
Args:
|
||||
is_checkpoint_format: Whether incoming weights are in checkpoint
|
||||
format (need layerwise processing) or kernel format (direct
|
||||
copy). Stored as state for finish_weight_update.
|
||||
"""
|
||||
self._check_weight_transfer_engine()
|
||||
|
||||
if self._weight_update_active:
|
||||
raise RuntimeError(
|
||||
"start_weight_update called while a weight update is "
|
||||
"already active. Call finish_weight_update first."
|
||||
)
|
||||
|
||||
if is_checkpoint_format:
|
||||
from vllm.model_executor.model_loader.reload import (
|
||||
initialize_layerwise_reload,
|
||||
)
|
||||
|
||||
model = self.model_runner.model
|
||||
with torch.device(self.device):
|
||||
initialize_layerwise_reload(model)
|
||||
|
||||
# Store state so update_weights/finish_weight_update can check
|
||||
self._is_checkpoint_format = is_checkpoint_format
|
||||
self._weight_update_active = True
|
||||
|
||||
def update_weights(self, update_info: dict) -> None:
|
||||
"""
|
||||
Batched weight update from the trainer.
|
||||
Receive weights from the trainer (one or more chunks).
|
||||
|
||||
start_weight_update must be called before update_weights and
|
||||
finish_weight_update must be called after.
|
||||
|
||||
Args:
|
||||
update_info: Dictionary containing backend-specific update info
|
||||
"""
|
||||
if self.weight_transfer_engine is None:
|
||||
self._check_weight_transfer_engine()
|
||||
assert self.weight_transfer_engine is not None
|
||||
|
||||
if not self._weight_update_active:
|
||||
raise RuntimeError(
|
||||
"Weight transfer not configured. "
|
||||
"Please set weight_transfer_config to enable weight transfer."
|
||||
"start_weight_update must be called before update_weights."
|
||||
)
|
||||
|
||||
# Parse dict into backend-specific typed dataclass
|
||||
@@ -1001,38 +1046,59 @@ class Worker(WorkerBase):
|
||||
|
||||
model = self.model_runner.model
|
||||
|
||||
if typed_update_info.is_checkpoint_format:
|
||||
from vllm.model_executor.model_loader.reload import (
|
||||
finalize_layerwise_reload,
|
||||
initialize_layerwise_reload,
|
||||
)
|
||||
|
||||
# Use layerwise reload pattern for checkpoint format weights
|
||||
with torch.device(self.device):
|
||||
initialize_layerwise_reload(model)
|
||||
with torch.device(self.device):
|
||||
if self._is_checkpoint_format:
|
||||
self.weight_transfer_engine.receive_weights(
|
||||
typed_update_info,
|
||||
load_weights=model.load_weights,
|
||||
)
|
||||
finalize_layerwise_reload(model, self.model_config)
|
||||
else:
|
||||
# Weights are already in kernel format, copy directly
|
||||
def load_weights_direct(
|
||||
weights: list[tuple[str, torch.Tensor]],
|
||||
) -> None:
|
||||
for name, weight in weights:
|
||||
param = model.get_parameter(name)
|
||||
param.copy_(weight)
|
||||
else:
|
||||
# Weights are already in kernel format, copy directly
|
||||
def load_weights_direct(
|
||||
weights: list[tuple[str, torch.Tensor]],
|
||||
) -> None:
|
||||
for name, weight in weights:
|
||||
param = model.get_parameter(name)
|
||||
param.copy_(weight)
|
||||
|
||||
self.weight_transfer_engine.receive_weights(
|
||||
typed_update_info,
|
||||
load_weights=load_weights_direct,
|
||||
)
|
||||
self.weight_transfer_engine.receive_weights(
|
||||
typed_update_info,
|
||||
load_weights=load_weights_direct,
|
||||
)
|
||||
|
||||
# NCCL broadcast/packed path are asynchronous.
|
||||
# Sync here so the next step uses the new weights.
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
def finish_weight_update(self) -> None:
|
||||
"""
|
||||
Finish the current weight update.
|
||||
|
||||
For checkpoint format, this runs layerwise postprocessing.
|
||||
Uses the is_checkpoint_format state stored by start_weight_update.
|
||||
"""
|
||||
self._check_weight_transfer_engine()
|
||||
|
||||
if not self._weight_update_active:
|
||||
raise RuntimeError(
|
||||
"start_weight_update must be called before finish_weight_update."
|
||||
)
|
||||
|
||||
is_checkpoint_format = self._is_checkpoint_format
|
||||
|
||||
if is_checkpoint_format:
|
||||
from vllm.model_executor.model_loader.reload import (
|
||||
finalize_layerwise_reload,
|
||||
)
|
||||
|
||||
model = self.model_runner.model
|
||||
with torch.device(self.device):
|
||||
finalize_layerwise_reload(model, self.model_config)
|
||||
|
||||
# Reset state
|
||||
self._weight_update_active = False
|
||||
self._is_checkpoint_format = True
|
||||
|
||||
def shutdown(self) -> None:
|
||||
# has_kv_transfer_group can be None during interpreter shutdown.
|
||||
if ensure_kv_transfer_shutdown is not None:
|
||||
|
||||
Reference in New Issue
Block a user