[feat] Add explicit /start_weight_update and /finish_weight_update APIs for weight transfer (#39212)

This commit is contained in:
Sumanth R Hegde
2026-05-08 18:03:33 -07:00
committed by GitHub
parent 30f519e947
commit e3b65a5ba0
19 changed files with 287 additions and 53 deletions
+4
View File
@@ -310,6 +310,8 @@ gen_futures = [
ray.get(llm.pause_after_n_tokens.remote())
ray.get(llm.start_weight_update.remote(is_checkpoint_format=True))
inference_handle = llm.update_weights.remote(
WeightTransferUpdateRequest(
update_info=asdict(
@@ -325,6 +327,8 @@ inference_handle = llm.update_weights.remote(
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])
ray.get(llm.finish_weight_update.remote())
ray.get(llm.resume_generation.remote())
results = ray.get(gen_futures)
+23 -1
View File
@@ -80,6 +80,24 @@ def init_weight_transfer_engine(base_url: str) -> None:
response.raise_for_status()
def start_weight_update(base_url: str, is_checkpoint_format: bool = True) -> None:
    """Tell the server to begin a weight update.

    Sends a POST request to ``{base_url}/start_weight_update`` whose JSON
    body carries the ``is_checkpoint_format`` flag.

    Args:
        base_url: Root URL of the server; the route is appended after a "/".
        is_checkpoint_format: Value forwarded under the
            ``is_checkpoint_format`` key of the request body.

    Raises:
        requests.HTTPError: If the response has a 4xx/5xx status code.
    """
    reply = requests.post(
        f"{base_url}/start_weight_update",
        json={"is_checkpoint_format": is_checkpoint_format},
        timeout=60,
    )
    # Surface HTTP error statuses as exceptions instead of ignoring them.
    reply.raise_for_status()
def finish_weight_update(base_url: str) -> None:
    """Tell the server that the weight update is finished.

    Sends a POST request with an empty JSON object to
    ``{base_url}/finish_weight_update``.

    Args:
        base_url: Root URL of the server; the route is appended after a "/".

    Raises:
        requests.HTTPError: If the response has a 4xx/5xx status code.
    """
    endpoint = f"{base_url}/finish_weight_update"
    # Post and immediately turn any HTTP error status into an exception.
    requests.post(endpoint, json={}, timeout=60).raise_for_status()
def pause_generation(base_url: str) -> None:
"""Pause generation via HTTP endpoint."""
url = f"{base_url}/pause"
@@ -151,7 +169,9 @@ def main():
# Pause generation before weight sync
pause_generation(BASE_URL)
# Broadcast weights via IPC handles using HTTP mode
# Start weight update, broadcast via IPC, then finish
start_weight_update(BASE_URL, is_checkpoint_format=False)
print("Broadcasting weights via CUDA IPC (HTTP)...")
trainer_args = IPCTrainerSendWeightsArgs(mode="http", url=BASE_URL)
IPCWeightTransferEngine.trainer_send_weights(
@@ -159,6 +179,8 @@ def main():
trainer_args=trainer_args,
)
finish_weight_update(BASE_URL)
# Resume generation after weight sync
resume_generation(BASE_URL)
+24
View File
@@ -83,6 +83,17 @@ def init_weight_transfer_engine(
response.raise_for_status()
def start_weight_update(base_url: str, is_checkpoint_format: bool = True) -> None:
    """Open a weight update on the server through its HTTP endpoint.

    Issues a POST to ``{base_url}/start_weight_update`` with a JSON body
    containing the single key ``is_checkpoint_format``.

    Args:
        base_url: Root URL of the server; the route is appended after a "/".
        is_checkpoint_format: Flag sent as the ``is_checkpoint_format``
            field of the request body.

    Raises:
        requests.HTTPError: If the response has a 4xx/5xx status code.
    """
    body = {"is_checkpoint_format": is_checkpoint_format}
    # Post and immediately turn any HTTP error status into an exception.
    requests.post(
        f"{base_url}/start_weight_update", json=body, timeout=60
    ).raise_for_status()
def update_weights(
base_url: str,
names: list[str],
@@ -104,6 +115,13 @@ def update_weights(
response.raise_for_status()
def finish_weight_update(base_url: str) -> None:
    """Close the weight update on the server through its HTTP endpoint.

    Issues a POST with an empty JSON object to
    ``{base_url}/finish_weight_update``.

    Args:
        base_url: Root URL of the server; the route is appended after a "/".

    Raises:
        requests.HTTPError: If the response has a 4xx/5xx status code.
    """
    reply = requests.post(
        f"{base_url}/finish_weight_update",
        json={},
        timeout=60,
    )
    # Surface HTTP error statuses as exceptions instead of ignoring them.
    reply.raise_for_status()
def pause_generation(base_url: str) -> None:
"""Pause generation via HTTP endpoint."""
url = f"{base_url}/pause"
@@ -204,6 +222,9 @@ def main():
dtype_names.append(str(p.dtype).split(".")[-1])
shapes.append(list(p.shape))
# Start weight update
start_weight_update(BASE_URL, is_checkpoint_format=True)
# Start the update_weights call in a separate thread since it will block
# waiting for NCCL broadcasts
# packed=True enables efficient batched tensor broadcasting
@@ -227,6 +248,9 @@ def main():
# Wait for update_weights to complete
update_thread.join()
# Finish weight update
finish_weight_update(BASE_URL)
# Resume generation after weight sync
resume_generation(BASE_URL)
+3 -1
View File
@@ -134,8 +134,10 @@ for output in outputs:
ray.get(llm.sleep.remote(level=0))
ray.get(train_model.init_weight_transfer.remote())
# Synchronize the updated weights to the inference engine using batched API.
# Start weight update, sync weights, then finish
ray.get(llm.start_weight_update.remote(is_checkpoint_format=True))
ray.get(train_model.broadcast_weights.remote(llm))
ray.get(llm.finish_weight_update.remote())
ray.get(llm.wake_up.remote(tags=["scheduling"]))
+6
View File
@@ -186,6 +186,9 @@ ray.get([train_handle, inference_handle])
# Collect all weight metadata from the training actor
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())
# Start weight update
ray.get(llm.start_weight_update.remote(is_checkpoint_format=True))
# Issue update_weights call with NCCL-specific update info
# packed=True enables efficient batched tensor broadcasting
inference_handle = llm.update_weights.remote(
@@ -203,6 +206,9 @@ inference_handle = llm.update_weights.remote(
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])
# Finish weight update
ray.get(llm.finish_weight_update.remote())
ray.get(llm.wake_up.remote(tags=["scheduling"]))
# Generate text with the updated model. The output is expected to be normal
+5
View File
@@ -298,6 +298,9 @@ async def main():
names, dtype_names, shapes = ray.get(fsdp_workers[0].get_weight_metadata.remote())
print(f"[sync] Got metadata for {len(names)} parameters.")
print("[sync] Starting weight update...")
await engine.start_weight_update(is_checkpoint_format=True)
print("[sync] Broadcasting weights from FSDP → vLLM...")
broadcast_handles = [
w.gather_and_broadcast_weights.remote(packed=True) for w in fsdp_workers
@@ -315,6 +318,8 @@ async def main():
)
)
ray.get(broadcast_handles)
await engine.finish_weight_update()
print("[sync] Weight broadcast complete.")
print("[sync] Resuming generation...")