mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[feat] Add explicit /start_weight_update and /finish_weight_update APIs for weight transfer (#39212)
This commit is contained in:
@@ -310,6 +310,8 @@ gen_futures = [
|
||||
|
||||
ray.get(llm.pause_after_n_tokens.remote())
|
||||
|
||||
ray.get(llm.start_weight_update.remote(is_checkpoint_format=True))
|
||||
|
||||
inference_handle = llm.update_weights.remote(
|
||||
WeightTransferUpdateRequest(
|
||||
update_info=asdict(
|
||||
@@ -325,6 +327,8 @@ inference_handle = llm.update_weights.remote(
|
||||
train_handle = train_model.broadcast_weights.remote(packed=True)
|
||||
ray.get([train_handle, inference_handle])
|
||||
|
||||
ray.get(llm.finish_weight_update.remote())
|
||||
|
||||
ray.get(llm.resume_generation.remote())
|
||||
results = ray.get(gen_futures)
|
||||
|
||||
|
||||
@@ -80,6 +80,24 @@ def init_weight_transfer_engine(base_url: str) -> None:
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def start_weight_update(
|
||||
base_url: str,
|
||||
is_checkpoint_format: bool = True,
|
||||
) -> None:
|
||||
"""Start a weight update via HTTP endpoint."""
|
||||
url = f"{base_url}/start_weight_update"
|
||||
payload = {"is_checkpoint_format": is_checkpoint_format}
|
||||
response = requests.post(url, json=payload, timeout=60)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def finish_weight_update(base_url: str) -> None:
|
||||
"""Finish a weight update via HTTP endpoint."""
|
||||
url = f"{base_url}/finish_weight_update"
|
||||
response = requests.post(url, json={}, timeout=60)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def pause_generation(base_url: str) -> None:
|
||||
"""Pause generation via HTTP endpoint."""
|
||||
url = f"{base_url}/pause"
|
||||
@@ -151,7 +169,9 @@ def main():
|
||||
# Pause generation before weight sync
|
||||
pause_generation(BASE_URL)
|
||||
|
||||
# Broadcast weights via IPC handles using HTTP mode
|
||||
# Start weight update, broadcast via IPC, then finish
|
||||
start_weight_update(BASE_URL, is_checkpoint_format=False)
|
||||
|
||||
print("Broadcasting weights via CUDA IPC (HTTP)...")
|
||||
trainer_args = IPCTrainerSendWeightsArgs(mode="http", url=BASE_URL)
|
||||
IPCWeightTransferEngine.trainer_send_weights(
|
||||
@@ -159,6 +179,8 @@ def main():
|
||||
trainer_args=trainer_args,
|
||||
)
|
||||
|
||||
finish_weight_update(BASE_URL)
|
||||
|
||||
# Resume generation after weight sync
|
||||
resume_generation(BASE_URL)
|
||||
|
||||
|
||||
@@ -83,6 +83,17 @@ def init_weight_transfer_engine(
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def start_weight_update(
|
||||
base_url: str,
|
||||
is_checkpoint_format: bool = True,
|
||||
) -> None:
|
||||
"""Start a weight update via HTTP endpoint."""
|
||||
url = f"{base_url}/start_weight_update"
|
||||
payload = {"is_checkpoint_format": is_checkpoint_format}
|
||||
response = requests.post(url, json=payload, timeout=60)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def update_weights(
|
||||
base_url: str,
|
||||
names: list[str],
|
||||
@@ -104,6 +115,13 @@ def update_weights(
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def finish_weight_update(base_url: str) -> None:
|
||||
"""Finish a weight update via HTTP endpoint."""
|
||||
url = f"{base_url}/finish_weight_update"
|
||||
response = requests.post(url, json={}, timeout=60)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def pause_generation(base_url: str) -> None:
|
||||
"""Pause generation via HTTP endpoint."""
|
||||
url = f"{base_url}/pause"
|
||||
@@ -204,6 +222,9 @@ def main():
|
||||
dtype_names.append(str(p.dtype).split(".")[-1])
|
||||
shapes.append(list(p.shape))
|
||||
|
||||
# Start weight update
|
||||
start_weight_update(BASE_URL, is_checkpoint_format=True)
|
||||
|
||||
# Start the update_weights call in a separate thread since it will block
|
||||
# waiting for NCCL broadcasts
|
||||
# packed=True enables efficient batched tensor broadcasting
|
||||
@@ -227,6 +248,9 @@ def main():
|
||||
# Wait for update_weights to complete
|
||||
update_thread.join()
|
||||
|
||||
# Finish weight update
|
||||
finish_weight_update(BASE_URL)
|
||||
|
||||
# Resume generation after weight sync
|
||||
resume_generation(BASE_URL)
|
||||
|
||||
|
||||
@@ -134,8 +134,10 @@ for output in outputs:
|
||||
ray.get(llm.sleep.remote(level=0))
|
||||
|
||||
ray.get(train_model.init_weight_transfer.remote())
|
||||
# Synchronize the updated weights to the inference engine using batched API.
|
||||
# Start weight update, sync weights, then finish
|
||||
ray.get(llm.start_weight_update.remote(is_checkpoint_format=True))
|
||||
ray.get(train_model.broadcast_weights.remote(llm))
|
||||
ray.get(llm.finish_weight_update.remote())
|
||||
|
||||
ray.get(llm.wake_up.remote(tags=["scheduling"]))
|
||||
|
||||
|
||||
@@ -186,6 +186,9 @@ ray.get([train_handle, inference_handle])
|
||||
# Collect all weight metadata from the training actor
|
||||
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())
|
||||
|
||||
# Start weight update
|
||||
ray.get(llm.start_weight_update.remote(is_checkpoint_format=True))
|
||||
|
||||
# Issue update_weights call with NCCL-specific update info
|
||||
# packed=True enables efficient batched tensor broadcasting
|
||||
inference_handle = llm.update_weights.remote(
|
||||
@@ -203,6 +206,9 @@ inference_handle = llm.update_weights.remote(
|
||||
train_handle = train_model.broadcast_weights.remote(packed=True)
|
||||
ray.get([train_handle, inference_handle])
|
||||
|
||||
# Finish weight update
|
||||
ray.get(llm.finish_weight_update.remote())
|
||||
|
||||
ray.get(llm.wake_up.remote(tags=["scheduling"]))
|
||||
|
||||
# Generate text with the updated model. The output is expected to be normal
|
||||
|
||||
@@ -298,6 +298,9 @@ async def main():
|
||||
names, dtype_names, shapes = ray.get(fsdp_workers[0].get_weight_metadata.remote())
|
||||
print(f"[sync] Got metadata for {len(names)} parameters.")
|
||||
|
||||
print("[sync] Starting weight update...")
|
||||
await engine.start_weight_update(is_checkpoint_format=True)
|
||||
|
||||
print("[sync] Broadcasting weights from FSDP → vLLM...")
|
||||
broadcast_handles = [
|
||||
w.gather_and_broadcast_weights.remote(packed=True) for w in fsdp_workers
|
||||
@@ -315,6 +318,8 @@ async def main():
|
||||
)
|
||||
)
|
||||
ray.get(broadcast_handles)
|
||||
|
||||
await engine.finish_weight_update()
|
||||
print("[sync] Weight broadcast complete.")
|
||||
|
||||
print("[sync] Resuming generation...")
|
||||
|
||||
Reference in New Issue
Block a user