mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Feat][RL] IPC weight sync optimizations: multigpu support and chunked packed tensors (#37476)
Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: hao-aaron <ahao@anyscale.com>
This commit is contained in:
@@ -1,21 +1,37 @@
|
||||
# IPC Engine
|
||||
|
||||
The IPC weight transfer engine uses **CUDA IPC** (Inter-Process Communication) handles to share GPU memory directly between the trainer and inference workers on the **same node and same GPU**. This avoids any data copying, making it a efficient option when colocating training and inference.
|
||||
The IPC weight transfer engine uses **CUDA IPC** (Inter-Process Communication) handles to share GPU memory directly between the trainer and inference workers on the **same GPU**. This avoids any data copying, making it the most efficient option when colocating training and inference. Multi-GPU setups are supported — weights are all gathered by each GPU and are extracted by the correct colocated process.
|
||||
|
||||
## When to Use IPC
|
||||
|
||||
- Training and inference on the **same GPU** (colocated)
|
||||
- You want to minimize memory overhead by sharing tensors in-place
|
||||
- Training and inference share the **same GPU(s)** (colocated)
|
||||
|
||||
## How It Works
|
||||
|
||||
1. The trainer creates CUDA tensors for each weight and generates IPC handles using `torch.multiprocessing.reductions.reduce_tensor`.
|
||||
2. IPC handles are sent to the inference engine via **Ray.remote()** or **HTTP POST**.
|
||||
3. The inference worker reconstructs the tensors from the handles, reading directly from the trainer's GPU memory.
|
||||
1. The trainer creates CUDA tensors for each weight and generates IPC handles using `torch.multiprocessing.reductions.reduce_tensor`. In multi-GPU setups (e.g. FSDP), each trainer rank must all-gather the full tensor for each layer onto its own GPU before generating the IPC handle.
|
||||
2. IPC handles for each gpu are sent to the inference engine via **Ray**, **HTTP**, or a **custom callable**. Each rank only reads the handle corresponding to its own GPU.
|
||||
3. The inference worker reconstructs the tensors from the handles using `rebuild_cuda_tensor`, reading directly from the trainer's GPU memory.
|
||||
|
||||
!!! warning
|
||||
IPC handles involve sending serialized Python objects. When using HTTP transport, you must set `VLLM_ALLOW_INSECURE_SERIALIZATION=1` on both the server and client. This is because IPC handles are pickled and base64-encoded for HTTP transmission.
|
||||
|
||||
## Packed (Chunked) Transfer
|
||||
|
||||
By default, all weights are sent in a single API call. For large models, this requires the full model to reside in GPU memory on both sides simultaneously. Setting `packed=True` enables **chunked transfer** with bounded GPU memory:
|
||||
|
||||
- Weights are concatenated into fixed-size packed buffers (controlled by `packed_buffer_size_bytes`).
|
||||
- Each chunk is sent as a separate `update_weights` call within a single `start_weight_update` / `finish_weight_update` bracket, so the layerwise reload pass is initialized once at the start and finalized once at the end regardless of chunk count.
|
||||
- After each chunk is consumed, the GPU memory for that chunk can be reclaimed.
|
||||
|
||||
```python
|
||||
trainer_args = IPCTrainerSendWeightsArgs(
|
||||
send_mode="ray",
|
||||
llm_handle=llm_actor_handle,
|
||||
packed=True,
|
||||
packed_buffer_size_bytes=256 * 1024 * 1024, # 256 MB chunks
|
||||
)
|
||||
```
|
||||
|
||||
## Initialization
|
||||
|
||||
The IPC backend requires no initialization on either side. The `init_transfer_engine` call is a no-op for IPC.
|
||||
@@ -35,7 +51,7 @@ from vllm.distributed.weight_transfer.ipc_engine import (
|
||||
)
|
||||
|
||||
trainer_args = IPCTrainerSendWeightsArgs(
|
||||
mode="ray",
|
||||
send_mode="ray",
|
||||
llm_handle=llm_actor_handle,
|
||||
)
|
||||
# start
|
||||
@@ -57,7 +73,7 @@ Used when vLLM is running as an HTTP server:
|
||||
|
||||
```python
|
||||
trainer_args = IPCTrainerSendWeightsArgs(
|
||||
mode="http",
|
||||
send_mode="http",
|
||||
url="http://localhost:8000",
|
||||
)
|
||||
|
||||
@@ -77,7 +93,22 @@ response = requests.post(url, json={}, timeout=60)
|
||||
response.raise_for_status()
|
||||
```
|
||||
|
||||
In HTTP mode, IPC handles are pickled, base64-encoded, and sent as JSON to the `/update_weights` endpoint. As with Ray mode, you must call `start_weight_update` before and `finish_weight_update` after.
|
||||
In HTTP mode, IPC handles are pickled, base64-encoded, and sent as JSON to the `/update_weights` endpoint. Because the worker deserializes the payload via `pickle.loads`, the vLLM server must be started with `VLLM_ALLOW_INSECURE_SERIALIZATION=1`.
|
||||
|
||||
```python
|
||||
def my_custom_sender(update_info: IPCWeightTransferUpdateInfo):
|
||||
# Custom logic to deliver update_info to vLLM
|
||||
...
|
||||
|
||||
trainer_args = IPCTrainerSendWeightsArgs(
|
||||
send_mode=my_custom_sender,
|
||||
)
|
||||
|
||||
IPCWeightTransferEngine.trainer_send_weights(
|
||||
iterator=model.named_parameters(),
|
||||
trainer_args=trainer_args,
|
||||
)
|
||||
```
|
||||
|
||||
See [`IPCTrainerSendWeightsArgs`](https://github.com/vllm-project/vllm/blob/main/vllm/distributed/weight_transfer/ipc_engine.py) for the full list of configurable fields.
|
||||
|
||||
|
||||
@@ -173,7 +173,7 @@ def main():
|
||||
start_weight_update(BASE_URL, is_checkpoint_format=False)
|
||||
|
||||
print("Broadcasting weights via CUDA IPC (HTTP)...")
|
||||
trainer_args = IPCTrainerSendWeightsArgs(mode="http", url=BASE_URL)
|
||||
trainer_args = IPCTrainerSendWeightsArgs(send_mode="http", url=BASE_URL)
|
||||
IPCWeightTransferEngine.trainer_send_weights(
|
||||
iterator=train_model.named_parameters(),
|
||||
trainer_args=trainer_args,
|
||||
|
||||
@@ -70,10 +70,14 @@ class TrainModel:
|
||||
self.llm_handle.init_weight_transfer_engine.remote(dict(init_info=dict()))
|
||||
)
|
||||
|
||||
def broadcast_weights(self, llm_handle: ray.actor.ActorHandle):
|
||||
def broadcast_weights(
|
||||
self, llm_handle: ray.actor.ActorHandle, packed: bool = False
|
||||
):
|
||||
"""Broadcast weights to the inference engine using IPC."""
|
||||
self.llm_handle = llm_handle
|
||||
trainer_args = IPCTrainerSendWeightsArgs(mode="ray", llm_handle=llm_handle)
|
||||
trainer_args = IPCTrainerSendWeightsArgs(
|
||||
send_mode="ray", llm_handle=llm_handle, packed=packed
|
||||
)
|
||||
IPCWeightTransferEngine.trainer_send_weights(
|
||||
iterator=self.train_model.named_parameters(),
|
||||
trainer_args=trainer_args,
|
||||
@@ -141,10 +145,10 @@ ray.get(llm.finish_weight_update.remote())
|
||||
|
||||
ray.get(llm.wake_up.remote(tags=["scheduling"]))
|
||||
|
||||
# Generate text with the updated model.
|
||||
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
|
||||
outputs_packed = ray.get(llm.generate.remote(prompts, sampling_params))
|
||||
print("-" * 50)
|
||||
for output in outputs_updated:
|
||||
print("Results after packed/chunked IPC weight sync:")
|
||||
for output in outputs_packed:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||
|
||||
@@ -0,0 +1,425 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
RLHF with FSDP2 training and vLLM expert-parallel inference using **CUDA IPC**
|
||||
weight transfer and **packed** tensors.
|
||||
|
||||
Layout (4 GPUs, TP=1, DP=4, EP):
|
||||
* One Ray placement group per GPU.
|
||||
* Each PG holds one FSDP training worker and one vLLM ``LLM`` instance
|
||||
(sync API) using fractional GPUs so both fit on the same device.
|
||||
* The 4 ``LLM`` instances form a DP group via env-var-based SPMD
|
||||
coordination (``VLLM_DP_RANK``, ``VLLM_DP_SIZE``, etc.), the same
|
||||
mechanism used by ``examples/offline_inference/data_parallel.py``.
|
||||
* A ``DataParallelInferenceEngine`` actor spawns all 4 LLM actors,
|
||||
waits for initialization, and orchestrates generation / weight-sync.
|
||||
|
||||
Uses the built-in ``ray`` send_mode: each FSDP worker calls
|
||||
``trainer_send_weights`` targeting its colocated LLM actor.
|
||||
|
||||
This example was run on 4xH100.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import asdict
|
||||
|
||||
import ray
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from huggingface_hub import snapshot_download
|
||||
from ray.util.placement_group import placement_group
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
from torch.distributed._tensor import DTensor
|
||||
from torch.distributed.fsdp import fully_shard
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import WeightTransferConfig
|
||||
from vllm.distributed.weight_transfer.ipc_engine import (
|
||||
IPCTrainerSendWeightsArgs,
|
||||
IPCWeightTransferEngine,
|
||||
IPCWeightTransferInitInfo,
|
||||
)
|
||||
from vllm.utils.network_utils import get_ip, get_open_port
|
||||
|
||||
TRAIN_GPU_FRACTION = float(os.environ.get("RLHF_IPC_TRAIN_GPU_FRACTION", "0.42"))
|
||||
VLLM_GPU_FRACTION = float(os.environ.get("RLHF_IPC_VLLM_GPU_FRACTION", "0.42"))
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen3-30B-A3B"
|
||||
|
||||
FSDP_WORLD_SIZE = 4
|
||||
INFERENCE_TP_SIZE = 1
|
||||
INFERENCE_DP_SIZE = 4
|
||||
|
||||
|
||||
class MyLLM(LLM):
|
||||
"""LLM subclass that configures DP env vars for SPMD coordination."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
dp_rank: int = 0,
|
||||
dp_size: int = 1,
|
||||
dp_master_ip: str = "127.0.0.1",
|
||||
dp_master_port: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(VLLM_GPU_FRACTION)
|
||||
os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0"
|
||||
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"
|
||||
|
||||
os.environ["VLLM_DP_RANK"] = str(dp_rank)
|
||||
os.environ["VLLM_DP_RANK_LOCAL"] = str(dp_rank)
|
||||
os.environ["VLLM_DP_SIZE"] = str(dp_size)
|
||||
os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
|
||||
os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def ready(self):
|
||||
return True
|
||||
|
||||
|
||||
@ray.remote(num_cpus=0, num_gpus=TRAIN_GPU_FRACTION)
|
||||
class FSDPTrainWorker:
|
||||
"""One FSDP2 worker per GPU; colocated with vLLM DP rank via placement group."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
rank: int,
|
||||
fsdp_world_size: int,
|
||||
fsdp_master_addr: str,
|
||||
fsdp_master_port: int,
|
||||
):
|
||||
self.rank = rank
|
||||
|
||||
os.environ["MASTER_ADDR"] = fsdp_master_addr
|
||||
os.environ["MASTER_PORT"] = str(fsdp_master_port)
|
||||
|
||||
dist.init_process_group(backend="nccl", rank=rank, world_size=fsdp_world_size)
|
||||
torch.accelerator.set_device_index(0)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
self.weight_names = [n for n, _ in model.named_parameters()]
|
||||
self.weight_dtype_names = [
|
||||
str(p.dtype).split(".")[-1] for _, p in model.named_parameters()
|
||||
]
|
||||
self.weight_shapes = [list(p.shape) for _, p in model.named_parameters()]
|
||||
|
||||
for layer in model.model.layers:
|
||||
fully_shard(layer)
|
||||
fully_shard(model)
|
||||
|
||||
self.model = model
|
||||
|
||||
def get_rank(self):
|
||||
return self.rank
|
||||
|
||||
def get_weight_metadata(self):
|
||||
return self.weight_names, self.weight_dtype_names, self.weight_shapes
|
||||
|
||||
def gather_and_broadcast_weights_ipc(self, llm_handle, packed: bool = True):
|
||||
"""All-gather full params; all ranks create IPC handles, rank 0 sends.
|
||||
|
||||
All ranks must call trainer_send_weights so they participate in the
|
||||
all_gather_object collective inside _all_gather_and_merge_handles.
|
||||
Only rank 0 actually sends the payload to vLLM (gated by _is_rank_zero).
|
||||
"""
|
||||
|
||||
def _full_param_iter():
|
||||
# HF's Qwen3MoeExperts (and other recent HF MoE impls) packs
|
||||
# all experts into two fused 3-D tensors per layer:
|
||||
# experts.gate_up_proj shape (E, 2*I, H)
|
||||
# experts.down_proj shape (E, H, I)
|
||||
# vLLM's Qwen3MoE load_weights still expects the older
|
||||
# per-expert HF layout (experts.<i>.gate_proj.weight,
|
||||
# experts.<i>.up_proj.weight, experts.<i>.down_proj.weight),
|
||||
# so we un-fuse on the fly. Split order matches HF's forward:
|
||||
# gate, up = linear(x, gate_up_proj[i]).chunk(2, dim=-1)
|
||||
# → rows [:I] of gate_up_proj[i] are gate, rows [I:] are up.
|
||||
params = self.model.state_dict()
|
||||
for name in list(params.keys()):
|
||||
param = params.pop(name)
|
||||
if isinstance(param, DTensor):
|
||||
tensor = param.full_tensor().detach().contiguous()
|
||||
else:
|
||||
tensor = param.detach().contiguous()
|
||||
del param
|
||||
|
||||
if name.endswith(".experts.gate_up_proj") and tensor.dim() == 3:
|
||||
prefix = name[: -len(".gate_up_proj")]
|
||||
num_experts, two_inter, _ = tensor.shape
|
||||
inter = two_inter // 2
|
||||
for i in range(num_experts):
|
||||
expert = tensor[i]
|
||||
yield (
|
||||
f"{prefix}.{i}.gate_proj.weight",
|
||||
expert[:inter].contiguous(),
|
||||
)
|
||||
yield (
|
||||
f"{prefix}.{i}.up_proj.weight",
|
||||
expert[inter:].contiguous(),
|
||||
)
|
||||
del tensor
|
||||
elif name.endswith(".experts.down_proj") and tensor.dim() == 3:
|
||||
prefix = name[: -len(".down_proj")]
|
||||
num_experts = tensor.shape[0]
|
||||
for i in range(num_experts):
|
||||
yield (
|
||||
f"{prefix}.{i}.down_proj.weight",
|
||||
tensor[i].contiguous(),
|
||||
)
|
||||
del tensor
|
||||
else:
|
||||
yield name, tensor
|
||||
|
||||
trainer_args = IPCTrainerSendWeightsArgs(
|
||||
send_mode="ray",
|
||||
llm_handle=llm_handle,
|
||||
packed=packed,
|
||||
packed_buffer_size_bytes=1024 * 1024 * 1024, # 1 GB
|
||||
)
|
||||
IPCWeightTransferEngine.trainer_send_weights(
|
||||
iterator=_full_param_iter(),
|
||||
trainer_args=trainer_args,
|
||||
)
|
||||
|
||||
|
||||
@ray.remote(num_cpus=1)
|
||||
class DataParallelInferenceEngine:
|
||||
"""Manages a pool of DP-sharded vLLM LLM actors.
|
||||
|
||||
Spawns one MyLLM actor per placement group, waits for all engines to
|
||||
finish initializing, and exposes generation / weight-sync helpers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
pgs: list,
|
||||
dp_master_ip: str,
|
||||
dp_master_port: int,
|
||||
):
|
||||
dp_size = len(pgs)
|
||||
self.llm_actors = []
|
||||
for r in range(dp_size):
|
||||
sched = PlacementGroupSchedulingStrategy(
|
||||
placement_group=pgs[r],
|
||||
placement_group_capture_child_tasks=True,
|
||||
)
|
||||
actor = (
|
||||
ray.remote(num_cpus=0, num_gpus=0)(MyLLM)
|
||||
.options(scheduling_strategy=sched)
|
||||
.remote(
|
||||
model=model,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=INFERENCE_TP_SIZE,
|
||||
distributed_executor_backend="ray",
|
||||
enable_expert_parallel=True,
|
||||
gpu_memory_utilization=0.35,
|
||||
weight_transfer_config=WeightTransferConfig(backend="ipc"),
|
||||
enable_sleep_mode=True,
|
||||
load_format="dummy",
|
||||
dp_rank=r,
|
||||
dp_size=dp_size,
|
||||
dp_master_ip=dp_master_ip,
|
||||
dp_master_port=dp_master_port,
|
||||
)
|
||||
)
|
||||
self.llm_actors.append(actor)
|
||||
|
||||
ray.get([actor.ready.remote() for actor in self.llm_actors])
|
||||
|
||||
def get_llm_actors(self):
|
||||
return self.llm_actors
|
||||
|
||||
def generate(self, prompts: list[str], sampling_params):
|
||||
"""Distribute prompts round-robin across DP ranks and collect results."""
|
||||
dp_size = len(self.llm_actors)
|
||||
per_rank: list[list[str]] = [[] for _ in range(dp_size)]
|
||||
indices: list[list[int]] = [[] for _ in range(dp_size)]
|
||||
|
||||
for i, prompt in enumerate(prompts):
|
||||
rank = i % dp_size
|
||||
per_rank[rank].append(prompt)
|
||||
indices[rank].append(i)
|
||||
|
||||
refs = [
|
||||
actor.generate.remote(per_rank[r], sampling_params)
|
||||
for r, actor in enumerate(self.llm_actors)
|
||||
if per_rank[r]
|
||||
]
|
||||
all_outputs = ray.get(refs)
|
||||
|
||||
ordered = [None] * len(prompts)
|
||||
rank_idx = 0
|
||||
for r in range(dp_size):
|
||||
if per_rank[r]:
|
||||
for local_i, orig_i in enumerate(indices[r]):
|
||||
ordered[orig_i] = all_outputs[rank_idx][local_i]
|
||||
rank_idx += 1
|
||||
return ordered
|
||||
|
||||
def init_weight_transfer(self):
|
||||
ray.get(
|
||||
[
|
||||
actor.init_weight_transfer_engine.remote(
|
||||
dict(init_info=asdict(IPCWeightTransferInitInfo()))
|
||||
)
|
||||
for actor in self.llm_actors
|
||||
]
|
||||
)
|
||||
|
||||
def start_weight_update(self, is_checkpoint_format: bool = True):
|
||||
ray.get(
|
||||
[
|
||||
actor.start_weight_update.remote(
|
||||
is_checkpoint_format=is_checkpoint_format
|
||||
)
|
||||
for actor in self.llm_actors
|
||||
]
|
||||
)
|
||||
|
||||
def finish_weight_update(self):
|
||||
ray.get([actor.finish_weight_update.remote() for actor in self.llm_actors])
|
||||
|
||||
def sleep(self, level: int = 0):
|
||||
ray.get([actor.sleep.remote(level=level) for actor in self.llm_actors])
|
||||
|
||||
def wake_up(self, tags: list[str] | None = None):
|
||||
ray.get([actor.wake_up.remote(tags=tags) for actor in self.llm_actors])
|
||||
|
||||
|
||||
def main():
|
||||
ray.init(
|
||||
runtime_env={
|
||||
"env_vars": {
|
||||
"VLLM_ALLOW_INSECURE_SERIALIZATION": "1",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
assert TRAIN_GPU_FRACTION + VLLM_GPU_FRACTION <= 1.0, (
|
||||
"Train + vLLM GPU fractions must sum to at most 1.0 per bundle."
|
||||
)
|
||||
|
||||
local_model_path = snapshot_download(MODEL_NAME)
|
||||
print(f"[init] Model downloaded to {local_model_path}")
|
||||
|
||||
fsdp_master_addr = get_ip()
|
||||
fsdp_master_port = get_open_port()
|
||||
dp_master_port = get_open_port()
|
||||
dp_master_ip = get_ip()
|
||||
|
||||
# Create one placement group per DP rank (one GPU each).
|
||||
pgs = []
|
||||
for _ in range(INFERENCE_DP_SIZE):
|
||||
pg = placement_group([{"GPU": 1, "CPU": 1}])
|
||||
pgs.append(pg)
|
||||
ray.get([pg.ready() for pg in pgs])
|
||||
print(f"[init] {len(pgs)} placement groups ready.")
|
||||
|
||||
# Launch FSDP training workers, one per PG.
|
||||
scheduling = [
|
||||
PlacementGroupSchedulingStrategy(
|
||||
placement_group=pgs[r],
|
||||
placement_group_capture_child_tasks=True,
|
||||
)
|
||||
for r in range(FSDP_WORLD_SIZE)
|
||||
]
|
||||
|
||||
fsdp_workers = [
|
||||
FSDPTrainWorker.options(scheduling_strategy=scheduling[r]).remote(
|
||||
local_model_path,
|
||||
r,
|
||||
FSDP_WORLD_SIZE,
|
||||
fsdp_master_addr,
|
||||
fsdp_master_port,
|
||||
)
|
||||
for r in range(FSDP_WORLD_SIZE)
|
||||
]
|
||||
ray.get([w.get_rank.remote() for w in fsdp_workers])
|
||||
print(f"[init] {FSDP_WORLD_SIZE} FSDP workers ready.")
|
||||
|
||||
# Launch DP inference engine (spawns and initializes all LLM actors).
|
||||
inference_engine = DataParallelInferenceEngine.remote(
|
||||
model=local_model_path,
|
||||
pgs=pgs,
|
||||
dp_master_ip=dp_master_ip,
|
||||
dp_master_port=dp_master_port,
|
||||
)
|
||||
llm_actors = ray.get(inference_engine.get_llm_actors.remote())
|
||||
print(f"[init] {INFERENCE_DP_SIZE} LLM actors ready.")
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
|
||||
print("[generate] Generating with dummy weights...")
|
||||
outputs = ray.get(inference_engine.generate.remote(prompts, sampling_params))
|
||||
print("-" * 60)
|
||||
print("BEFORE weight sync (dummy weights):")
|
||||
print("-" * 60)
|
||||
for output in outputs:
|
||||
print(f"Prompt: {output.prompt!r}")
|
||||
print(f"Generated: {output.outputs[0].text!r}")
|
||||
print("-" * 60)
|
||||
|
||||
# --- Weight transfer ---
|
||||
print("[transfer] Initializing IPC weight transfer...")
|
||||
ray.get(inference_engine.init_weight_transfer.remote())
|
||||
|
||||
# Two-phase sleep/wake pattern:
|
||||
# 1. sleep(level=1) — offload weights to CPU, discard KV cache
|
||||
# 2. wake_up(tags=["weights"]) — bring weights back to GPU (KV cache still free)
|
||||
# 3. IPC weight transfer — overwrite weights, plenty of room without KV cache
|
||||
# 4. wake_up(tags=["kv_cache"]) — re-allocate KV cache for inference
|
||||
print("[sync] Sleeping engines (offload weights + free KV cache)...")
|
||||
ray.get(inference_engine.sleep.remote(level=1))
|
||||
|
||||
print("[sync] Waking weights (KV cache stays free)...")
|
||||
ray.get(inference_engine.wake_up.remote(tags=["weights"]))
|
||||
|
||||
print("[sync] Starting weight update...")
|
||||
ray.get(inference_engine.start_weight_update.remote(is_checkpoint_format=True))
|
||||
|
||||
print("[sync] Packed IPC transfer FSDP → vLLM...")
|
||||
ray.get(
|
||||
[
|
||||
w.gather_and_broadcast_weights_ipc.remote(llm_actors, packed=True)
|
||||
for w in fsdp_workers
|
||||
]
|
||||
)
|
||||
|
||||
ray.get(inference_engine.finish_weight_update.remote())
|
||||
print("[sync] Weight transfer complete.")
|
||||
|
||||
print("[sync] Waking KV cache + scheduling...")
|
||||
ray.get(inference_engine.wake_up.remote(tags=["kv_cache", "scheduling"]))
|
||||
|
||||
print("[generate] Generating with synced weights...")
|
||||
outputs_updated = ray.get(
|
||||
inference_engine.generate.remote(prompts, sampling_params)
|
||||
)
|
||||
print("-" * 60)
|
||||
print("AFTER weight sync (real weights):")
|
||||
print("-" * 60)
|
||||
for output in outputs_updated:
|
||||
print(f"Prompt: {output.prompt!r}")
|
||||
print(f"Generated: {output.outputs[0].text!r}")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for packed tensor broadcasting functionality.
|
||||
|
||||
Unit tests for packed_broadcast_producer and packed_broadcast_consumer.
|
||||
Unit tests for packed_nccl_broadcast_producer and packed_nccl_broadcast_consumer.
|
||||
These utilities enable efficient batched tensor transfer over NCCL.
|
||||
"""
|
||||
|
||||
@@ -11,8 +11,12 @@ import torch
|
||||
|
||||
from vllm.distributed.weight_transfer.nccl_engine import NCCLWeightTransferUpdateInfo
|
||||
from vllm.distributed.weight_transfer.packed_tensor import (
|
||||
packed_broadcast_consumer,
|
||||
packed_broadcast_producer,
|
||||
pack_tensors,
|
||||
packed_ipc_consumer,
|
||||
packed_ipc_producer,
|
||||
packed_nccl_broadcast_consumer,
|
||||
packed_nccl_broadcast_producer,
|
||||
unpack_tensor,
|
||||
)
|
||||
|
||||
|
||||
@@ -90,91 +94,18 @@ class TestNCCLWeightTransferUpdateInfoPacked:
|
||||
assert info.packed is True
|
||||
|
||||
|
||||
# --- Unit Tests: packed_broadcast_producer ---
|
||||
# --- Unit Tests: packed_nccl_broadcast_producer ---
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
class TestPackedBroadcastProducer:
|
||||
"""Test packed_broadcast_producer function."""
|
||||
|
||||
def test_producer_broadcasts_tensors(self):
|
||||
"""Test that producer broadcasts all tensors."""
|
||||
params = create_mock_model_params()
|
||||
params_cuda = [(name, tensor.cuda()) for name, tensor in params]
|
||||
|
||||
mock_group = MockCommunicationGroup()
|
||||
|
||||
# Use a small target size to force multiple batches
|
||||
packed_broadcast_producer(
|
||||
iterator=iter(params_cuda),
|
||||
group=mock_group,
|
||||
src=0,
|
||||
post_iter_func=lambda x: x[1],
|
||||
buffer_size_bytes=500,
|
||||
)
|
||||
|
||||
# Should have broadcasted some tensors
|
||||
assert mock_group.broadcast_count > 0
|
||||
assert len(mock_group.broadcasted_tensors) > 0
|
||||
|
||||
def test_producer_single_large_tensor(self):
|
||||
"""Test with a single tensor larger than target size."""
|
||||
# Create a large tensor
|
||||
large_tensor = torch.randn(1000, 1000, dtype=torch.float32).cuda()
|
||||
params = [("large_weight", large_tensor)]
|
||||
|
||||
mock_group = MockCommunicationGroup()
|
||||
|
||||
# Small target size to force the tensor to exceed it
|
||||
packed_broadcast_producer(
|
||||
iterator=iter(params),
|
||||
group=mock_group,
|
||||
src=0,
|
||||
post_iter_func=lambda x: x[1],
|
||||
buffer_size_bytes=100,
|
||||
)
|
||||
|
||||
# Should still broadcast the tensor (at least 1 broadcast)
|
||||
assert mock_group.broadcast_count >= 1
|
||||
assert len(mock_group.broadcasted_tensors) >= 1
|
||||
|
||||
# Verify the total broadcasted size matches the tensor
|
||||
expected_size = large_tensor.numel() * large_tensor.element_size()
|
||||
actual_size = sum(t.numel() for t in mock_group.broadcasted_tensors)
|
||||
assert actual_size == expected_size
|
||||
|
||||
def test_producer_multiple_batches(self):
|
||||
"""Test that tensors are properly batched when exceeding target size."""
|
||||
# Create many small tensors
|
||||
params = [
|
||||
(f"weight_{i}", torch.randn(10, 10, dtype=torch.float32).cuda())
|
||||
for i in range(20)
|
||||
]
|
||||
|
||||
mock_group = MockCommunicationGroup()
|
||||
|
||||
# Small target size to force multiple batches
|
||||
packed_broadcast_producer(
|
||||
iterator=iter(params),
|
||||
group=mock_group,
|
||||
src=0,
|
||||
post_iter_func=lambda x: x[1],
|
||||
buffer_size_bytes=2000,
|
||||
)
|
||||
|
||||
# Should have multiple broadcasts
|
||||
assert mock_group.broadcast_count > 1
|
||||
|
||||
# Total size should match sum of all tensors
|
||||
expected_total = sum(t.numel() * t.element_size() for _, t in params)
|
||||
actual_total = sum(t.numel() for t in mock_group.broadcasted_tensors)
|
||||
assert actual_total == expected_total
|
||||
"""Test packed_nccl_broadcast_producer function."""
|
||||
|
||||
def test_producer_empty_iterator(self):
|
||||
"""Test producer handles empty iterator gracefully."""
|
||||
mock_group = MockCommunicationGroup()
|
||||
|
||||
packed_broadcast_producer(
|
||||
packed_nccl_broadcast_producer(
|
||||
iterator=iter([]),
|
||||
group=mock_group,
|
||||
src=0,
|
||||
@@ -186,64 +117,6 @@ class TestPackedBroadcastProducer:
|
||||
assert mock_group.broadcast_count == 0
|
||||
|
||||
|
||||
# --- Unit Tests: packed_broadcast_consumer ---
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
class TestPackedBroadcastConsumer:
|
||||
"""Test packed_broadcast_consumer function."""
|
||||
|
||||
def test_consumer_receives_tensors(self):
|
||||
"""Test that consumer receives and unpacks tensors."""
|
||||
params = create_mock_model_params()
|
||||
params_cuda = [(name, tensor.cuda()) for name, tensor in params]
|
||||
|
||||
buffer_size = 2000
|
||||
|
||||
# First, run producer to get the broadcasted tensors
|
||||
producer_group = MockCommunicationGroup()
|
||||
|
||||
packed_broadcast_producer(
|
||||
iterator=iter(params_cuda),
|
||||
group=producer_group,
|
||||
src=0,
|
||||
post_iter_func=lambda x: x[1],
|
||||
buffer_size_bytes=buffer_size,
|
||||
)
|
||||
|
||||
# Now run consumer with the broadcasted tensors
|
||||
consumer_group = MockConsumerCommunicationGroup(
|
||||
producer_group.broadcasted_tensors
|
||||
)
|
||||
|
||||
state_dict_info = create_state_dict_info(params_cuda)
|
||||
|
||||
unpacked_tensors = {}
|
||||
|
||||
def post_unpack_func(tensor_list):
|
||||
for name, tensor in tensor_list:
|
||||
unpacked_tensors[name] = tensor.clone()
|
||||
|
||||
packed_broadcast_consumer(
|
||||
iterator=iter(state_dict_info.items()),
|
||||
group=consumer_group,
|
||||
src=0,
|
||||
post_unpack_func=post_unpack_func,
|
||||
buffer_size_bytes=buffer_size,
|
||||
)
|
||||
|
||||
# Verify all parameters were unpacked
|
||||
assert len(unpacked_tensors) == len(params)
|
||||
|
||||
# Verify each tensor matches the original
|
||||
for name, original_tensor in params_cuda:
|
||||
assert name in unpacked_tensors
|
||||
unpacked = unpacked_tensors[name]
|
||||
assert unpacked.shape == original_tensor.shape
|
||||
assert unpacked.dtype == original_tensor.dtype
|
||||
assert torch.allclose(unpacked, original_tensor, rtol=1e-5, atol=1e-7)
|
||||
|
||||
|
||||
# --- Integration Tests: Producer-Consumer Roundtrip ---
|
||||
|
||||
|
||||
@@ -260,7 +133,7 @@ class TestPackedBroadcastRoundtrip:
|
||||
buffer_size = 1000
|
||||
producer_group = MockCommunicationGroup()
|
||||
|
||||
packed_broadcast_producer(
|
||||
packed_nccl_broadcast_producer(
|
||||
iterator=iter(params_cuda),
|
||||
group=producer_group,
|
||||
src=0,
|
||||
@@ -279,7 +152,7 @@ class TestPackedBroadcastRoundtrip:
|
||||
for name, tensor in tensor_list:
|
||||
unpacked_tensors[name] = tensor.clone()
|
||||
|
||||
packed_broadcast_consumer(
|
||||
packed_nccl_broadcast_consumer(
|
||||
iterator=iter(state_dict_info.items()),
|
||||
group=consumer_group,
|
||||
src=0,
|
||||
@@ -306,7 +179,7 @@ class TestPackedBroadcastRoundtrip:
|
||||
buffer_size = 500
|
||||
producer_group = MockCommunicationGroup()
|
||||
|
||||
packed_broadcast_producer(
|
||||
packed_nccl_broadcast_producer(
|
||||
iterator=iter(params),
|
||||
group=producer_group,
|
||||
src=0,
|
||||
@@ -325,7 +198,7 @@ class TestPackedBroadcastRoundtrip:
|
||||
for name, tensor in tensor_list:
|
||||
unpacked_tensors[name] = tensor.clone()
|
||||
|
||||
packed_broadcast_consumer(
|
||||
packed_nccl_broadcast_consumer(
|
||||
iterator=iter(state_dict_info.items()),
|
||||
group=consumer_group,
|
||||
src=0,
|
||||
@@ -341,7 +214,7 @@ class TestPackedBroadcastRoundtrip:
|
||||
assert unpacked.dtype == original_tensor.dtype
|
||||
assert torch.allclose(unpacked, original_tensor, rtol=1e-4, atol=1e-6)
|
||||
|
||||
@pytest.mark.parametrize("target_size", [100, 1000, 10000, 100000])
|
||||
@pytest.mark.parametrize("target_size", [100, 100000])
|
||||
def test_roundtrip_different_batch_sizes(self, target_size):
|
||||
"""Test roundtrip with different target batch sizes."""
|
||||
params = create_mock_model_params(num_layers=5)
|
||||
@@ -349,7 +222,7 @@ class TestPackedBroadcastRoundtrip:
|
||||
|
||||
producer_group = MockCommunicationGroup()
|
||||
|
||||
packed_broadcast_producer(
|
||||
packed_nccl_broadcast_producer(
|
||||
iterator=iter(params_cuda),
|
||||
group=producer_group,
|
||||
src=0,
|
||||
@@ -368,7 +241,7 @@ class TestPackedBroadcastRoundtrip:
|
||||
for name, tensor in tensor_list:
|
||||
unpacked_tensors[name] = tensor.clone()
|
||||
|
||||
packed_broadcast_consumer(
|
||||
packed_nccl_broadcast_consumer(
|
||||
iterator=iter(state_dict_info.items()),
|
||||
group=consumer_group,
|
||||
src=0,
|
||||
@@ -407,7 +280,7 @@ class TestPackedBroadcastRoundtrip:
|
||||
buffer_size = 500
|
||||
producer_group = MockCommunicationGroup()
|
||||
|
||||
packed_broadcast_producer(
|
||||
packed_nccl_broadcast_producer(
|
||||
iterator=iter(params),
|
||||
group=producer_group,
|
||||
src=0,
|
||||
@@ -426,7 +299,7 @@ class TestPackedBroadcastRoundtrip:
|
||||
for name, tensor in tensor_list:
|
||||
unpacked_tensors[name] = tensor.clone()
|
||||
|
||||
packed_broadcast_consumer(
|
||||
packed_nccl_broadcast_consumer(
|
||||
iterator=iter(state_dict_info.items()),
|
||||
group=consumer_group,
|
||||
src=0,
|
||||
@@ -441,3 +314,462 @@ class TestPackedBroadcastRoundtrip:
|
||||
assert unpacked.shape == original_tensor.shape
|
||||
assert unpacked.dtype == original_tensor.dtype
|
||||
assert torch.allclose(unpacked, original_tensor, rtol=1e-4, atol=1e-6)
|
||||
|
||||
|
||||
# --- Unit Tests: unpack_tensor ---
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
class TestUnpackTensor:
|
||||
"""Test the shared unpack_tensor function."""
|
||||
|
||||
def test_unpack_produces_independent_copies(self):
|
||||
"""Verify unpacked tensors don't share memory with packed buffer."""
|
||||
original = torch.randn(10, dtype=torch.float32).cuda()
|
||||
packed = original.contiguous().view(torch.uint8).view(-1)
|
||||
|
||||
result = unpack_tensor(
|
||||
packed,
|
||||
names=["w"],
|
||||
shapes=[[10]],
|
||||
dtypes=[torch.float32],
|
||||
tensor_sizes=[packed.numel()],
|
||||
)
|
||||
|
||||
# Mutate the packed buffer
|
||||
packed.zero_()
|
||||
|
||||
# Unpacked tensor should be unaffected
|
||||
assert torch.allclose(result[0][1], original)
|
||||
|
||||
|
||||
# --- Unit Tests: pack_tensors ---
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
class TestPackTensors:
|
||||
"""Test the shared pack_tensors function."""
|
||||
|
||||
def test_pack_basic(self):
|
||||
"""Test packing a few tensors into one buffer."""
|
||||
params = [
|
||||
("w1", torch.randn(10, 20, dtype=torch.float32).cuda()),
|
||||
("w2", torch.randn(5, dtype=torch.float16).cuda()),
|
||||
]
|
||||
|
||||
chunk = pack_tensors(
|
||||
iterator=iter(params),
|
||||
post_iter_func=lambda x: x[1],
|
||||
buffer_size_bytes=10_000_000,
|
||||
)
|
||||
|
||||
assert chunk is not None
|
||||
assert len(chunk.names) == 2
|
||||
assert chunk.names == ["w1", "w2"]
|
||||
assert chunk.shapes == [[10, 20], [5]]
|
||||
assert chunk.dtypes == [torch.float32, torch.float16]
|
||||
assert chunk.packed_tensor.dtype == torch.uint8
|
||||
|
||||
def test_pack_respects_buffer_limit(self):
|
||||
"""Test that packing stops when buffer_size_bytes is exceeded."""
|
||||
params = [
|
||||
(f"w{i}", torch.randn(100, 100, dtype=torch.float32).cuda())
|
||||
for i in range(10)
|
||||
]
|
||||
|
||||
chunk = pack_tensors(
|
||||
iterator=iter(params),
|
||||
post_iter_func=lambda x: x[1],
|
||||
buffer_size_bytes=50_000,
|
||||
)
|
||||
|
||||
assert chunk is not None
|
||||
assert len(chunk.names) < 10
|
||||
|
||||
def test_pack_empty_iterator(self):
|
||||
"""Test that an empty iterator returns None."""
|
||||
chunk = pack_tensors(
|
||||
iterator=iter([]),
|
||||
post_iter_func=lambda x: x[1],
|
||||
buffer_size_bytes=1000,
|
||||
)
|
||||
assert chunk is None
|
||||
|
||||
def test_pack_single_tensor_larger_than_buffer_warns(self):
|
||||
"""Test that a tensor exceeding buffer_size_bytes emits a warning."""
|
||||
big = torch.randn(1000, 1000, dtype=torch.float32).cuda()
|
||||
params = [("big", big)]
|
||||
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
chunk = pack_tensors(
|
||||
iterator=iter(params),
|
||||
post_iter_func=lambda x: x[1],
|
||||
buffer_size_bytes=100,
|
||||
)
|
||||
assert chunk is not None
|
||||
assert len(chunk.names) == 1
|
||||
assert any("exceeds buffer_size_bytes" in str(wi.message) for wi in w)
|
||||
|
||||
def test_pack_unpack_roundtrip(self):
|
||||
"""Test pack then unpack produces identical tensors."""
|
||||
params = [
|
||||
("a", torch.randn(8, 16, dtype=torch.float32).cuda()),
|
||||
("b", torch.randn(4, dtype=torch.float16).cuda()),
|
||||
("c", torch.randn(3, 5, 7, dtype=torch.bfloat16).cuda()),
|
||||
]
|
||||
|
||||
chunk = pack_tensors(
|
||||
iterator=iter(params),
|
||||
post_iter_func=lambda x: x[1],
|
||||
buffer_size_bytes=10_000_000,
|
||||
)
|
||||
|
||||
assert chunk is not None
|
||||
result = unpack_tensor(
|
||||
chunk.packed_tensor,
|
||||
chunk.names,
|
||||
chunk.shapes,
|
||||
chunk.dtypes,
|
||||
chunk.tensor_sizes,
|
||||
)
|
||||
|
||||
assert len(result) == len(params)
|
||||
for (orig_name, orig_tensor), (res_name, res_tensor) in zip(params, result):
|
||||
assert orig_name == res_name
|
||||
assert res_tensor.shape == orig_tensor.shape
|
||||
assert res_tensor.dtype == orig_tensor.dtype
|
||||
assert torch.allclose(res_tensor, orig_tensor, rtol=1e-4, atol=1e-6)
|
||||
|
||||
def test_pack_multiple_chunks(self):
|
||||
"""Test consuming an iterator across multiple pack_tensors calls."""
|
||||
params = [
|
||||
(f"w{i}", torch.randn(50, 50, dtype=torch.float32).cuda()) for i in range(6)
|
||||
]
|
||||
it = iter(params)
|
||||
|
||||
all_names = []
|
||||
chunks = []
|
||||
while True:
|
||||
chunk = pack_tensors(it, lambda x: x[1], buffer_size_bytes=12_000)
|
||||
if chunk is None:
|
||||
break
|
||||
chunks.append(chunk)
|
||||
all_names.extend(chunk.names)
|
||||
|
||||
assert len(chunks) > 1
|
||||
assert all_names == [f"w{i}" for i in range(6)]
|
||||
|
||||
|
||||
# --- Unit Tests: packed_ipc_producer ---
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
class TestPackedIpcProducer:
|
||||
"""Test the packed_ipc_producer generator."""
|
||||
|
||||
def test_producer_yields_chunks(self):
|
||||
"""Test that the producer yields PackedIpcChunk objects."""
|
||||
params = [
|
||||
(f"w{i}", torch.randn(50, 50, dtype=torch.float32).cuda()) for i in range(6)
|
||||
]
|
||||
|
||||
chunks = list(
|
||||
packed_ipc_producer(
|
||||
iterator=iter(params),
|
||||
gpu_uuid="test-uuid",
|
||||
post_iter_func=lambda x: x[1],
|
||||
buffer_size_bytes=12_000,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(chunks) > 1
|
||||
|
||||
def test_producer_ipc_handle_has_uuid(self):
|
||||
"""Test that each chunk's ipc_handle is keyed by the given UUID."""
|
||||
params = [("w", torch.randn(10, dtype=torch.float32).cuda())]
|
||||
|
||||
chunks = list(
|
||||
packed_ipc_producer(
|
||||
iterator=iter(params),
|
||||
gpu_uuid="my-gpu-uuid",
|
||||
post_iter_func=lambda x: x[1],
|
||||
buffer_size_bytes=10_000_000,
|
||||
)
|
||||
)
|
||||
|
||||
assert "my-gpu-uuid" in chunks[0].ipc_handle
|
||||
|
||||
def test_producer_dtype_names_are_strings(self):
|
||||
"""Test that dtype_names are string representations."""
|
||||
params = [
|
||||
("a", torch.randn(10, dtype=torch.float32).cuda()),
|
||||
("b", torch.randn(10, dtype=torch.float16).cuda()),
|
||||
]
|
||||
|
||||
chunks = list(
|
||||
packed_ipc_producer(
|
||||
iterator=iter(params),
|
||||
gpu_uuid="uuid",
|
||||
post_iter_func=lambda x: x[1],
|
||||
buffer_size_bytes=10_000_000,
|
||||
)
|
||||
)
|
||||
|
||||
assert chunks[0].dtype_names == ["float32", "float16"]
|
||||
|
||||
def test_producer_empty_iterator(self):
|
||||
"""Test producer with empty iterator yields nothing."""
|
||||
chunks = list(
|
||||
packed_ipc_producer(
|
||||
iterator=iter([]),
|
||||
gpu_uuid="uuid",
|
||||
post_iter_func=lambda x: x[1],
|
||||
buffer_size_bytes=1000,
|
||||
)
|
||||
)
|
||||
assert len(chunks) == 0
|
||||
|
||||
|
||||
# --- Integration Tests: IPC Producer-Consumer Roundtrip ---
|
||||
|
||||
|
||||
def _ipc_consumer_worker(cmd_q, ack_q, result_q, done_event, device_index):
|
||||
"""Worker that consumes chunks streamed one at a time from the parent.
|
||||
|
||||
CUDA IPC requires the consumer to be in a separate process from the
|
||||
producer. The producer reuses a single IPC buffer between chunks, so
|
||||
the parent must wait for our ack (sent after we copy the chunk to
|
||||
CPU) before advancing the producer.
|
||||
"""
|
||||
try:
|
||||
torch.accelerator.set_device_index(device_index)
|
||||
all_results = []
|
||||
while True:
|
||||
cd = cmd_q.get()
|
||||
if cd is None:
|
||||
break
|
||||
result = packed_ipc_consumer(
|
||||
ipc_handle=cd["ipc_handle"],
|
||||
names=cd["names"],
|
||||
shapes=cd["shapes"],
|
||||
dtype_names=cd["dtype_names"],
|
||||
tensor_sizes=cd["tensor_sizes"],
|
||||
device_index=device_index,
|
||||
)
|
||||
# .cpu() forces a GPU→CPU copy off the shared IPC buffer, so
|
||||
# the producer is free to overwrite it once we ack.
|
||||
all_results.extend([(name, tensor.cpu()) for name, tensor in result])
|
||||
del result
|
||||
ack_q.put("ack")
|
||||
result_q.put(("ok", all_results))
|
||||
except Exception as e:
|
||||
result_q.put(("error", str(e)))
|
||||
# Keep the process alive until the parent has finished reading from
|
||||
# the result queue — torch serializes CPU tensors via fd sharing,
|
||||
# which requires this process's resource-sharer server to be running.
|
||||
done_event.wait(timeout=60)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
class TestPackedIpcRoundtrip:
|
||||
"""Test IPC producer-consumer roundtrip using real CUDA IPC.
|
||||
|
||||
These tests spawn a child process for the consumer because
|
||||
rebuild_cuda_tensor requires a separate process from the one that
|
||||
called reduce_tensor.
|
||||
"""
|
||||
|
||||
def _get_gpu_uuid(self) -> str:
|
||||
device_index = torch.cuda.current_device()
|
||||
props = torch.cuda.get_device_properties(device_index)
|
||||
return str(props.uuid)
|
||||
|
||||
def _run_roundtrip(self, chunk_iter, device_index, timeout=30):
|
||||
"""Stream chunks through a child consumer one at a time.
|
||||
|
||||
``packed_ipc_producer`` reuses a single IPC buffer for every
|
||||
chunk, so the producer must not be advanced until the consumer
|
||||
has finished reading the current chunk. We enforce that with an
|
||||
ack queue: the consumer puts ``"ack"`` after it has copied the
|
||||
chunk to CPU, and only then do we pull the next chunk from the
|
||||
generator.
|
||||
|
||||
Returns ``(num_chunks, results)``.
|
||||
"""
|
||||
import multiprocessing as mp
|
||||
|
||||
ctx = mp.get_context("spawn")
|
||||
cmd_q = ctx.Queue()
|
||||
ack_q = ctx.Queue()
|
||||
result_q = ctx.Queue()
|
||||
done_event = ctx.Event()
|
||||
proc = ctx.Process(
|
||||
target=_ipc_consumer_worker,
|
||||
args=(cmd_q, ack_q, result_q, done_event, device_index),
|
||||
)
|
||||
proc.start()
|
||||
|
||||
num_chunks = 0
|
||||
try:
|
||||
for chunk in chunk_iter:
|
||||
cmd_q.put(
|
||||
{
|
||||
"ipc_handle": chunk.ipc_handle,
|
||||
"names": chunk.names,
|
||||
"shapes": chunk.shapes,
|
||||
"dtype_names": chunk.dtype_names,
|
||||
"tensor_sizes": chunk.tensor_sizes,
|
||||
}
|
||||
)
|
||||
if ack_q.get(timeout=timeout) != "ack":
|
||||
raise RuntimeError("Consumer did not ack chunk")
|
||||
num_chunks += 1
|
||||
cmd_q.put(None)
|
||||
status, payload = result_q.get(timeout=timeout)
|
||||
finally:
|
||||
done_event.set()
|
||||
proc.join(timeout=10)
|
||||
if proc.is_alive():
|
||||
proc.kill()
|
||||
|
||||
if status == "error":
|
||||
raise RuntimeError(f"Consumer process failed: {payload}")
|
||||
# Reclaim IPC-shared memory now that the child has released it
|
||||
torch.cuda.ipc_collect()
|
||||
return num_chunks, payload
|
||||
|
||||
def test_roundtrip_basic(self):
|
||||
"""Test basic IPC producer -> consumer roundtrip."""
|
||||
params = [
|
||||
("w1", torch.randn(10, 20, dtype=torch.float32).cuda()),
|
||||
("w2", torch.randn(5, dtype=torch.float16).cuda()),
|
||||
]
|
||||
gpu_uuid = self._get_gpu_uuid()
|
||||
device_index = torch.cuda.current_device()
|
||||
|
||||
num_chunks, result = self._run_roundtrip(
|
||||
packed_ipc_producer(
|
||||
iterator=iter(params),
|
||||
gpu_uuid=gpu_uuid,
|
||||
post_iter_func=lambda x: x[1],
|
||||
buffer_size_bytes=10_000_000,
|
||||
),
|
||||
device_index,
|
||||
)
|
||||
|
||||
assert num_chunks == 1
|
||||
assert len(result) == 2
|
||||
for (orig_name, orig_tensor), (res_name, res_tensor) in zip(params, result):
|
||||
assert orig_name == res_name
|
||||
assert res_tensor.shape == orig_tensor.shape
|
||||
assert res_tensor.dtype == orig_tensor.dtype
|
||||
assert torch.allclose(res_tensor, orig_tensor.cpu(), rtol=1e-4, atol=1e-6)
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
|
||||
def test_roundtrip_dtypes(self, dtype):
|
||||
"""Test IPC roundtrip with different dtypes."""
|
||||
params = create_mock_model_params(num_layers=2, dtype=dtype)
|
||||
params_cuda = [(n, t.cuda()) for n, t in params]
|
||||
gpu_uuid = self._get_gpu_uuid()
|
||||
device_index = torch.cuda.current_device()
|
||||
|
||||
_, result = self._run_roundtrip(
|
||||
packed_ipc_producer(
|
||||
iterator=iter(params_cuda),
|
||||
gpu_uuid=gpu_uuid,
|
||||
post_iter_func=lambda x: x[1],
|
||||
buffer_size_bytes=10_000_000,
|
||||
),
|
||||
device_index,
|
||||
)
|
||||
|
||||
assert len(result) == len(params_cuda)
|
||||
for (orig_name, orig_tensor), (res_name, res_tensor) in zip(
|
||||
params_cuda, result
|
||||
):
|
||||
assert orig_name == res_name
|
||||
assert res_tensor.dtype == dtype
|
||||
assert torch.allclose(res_tensor, orig_tensor.cpu(), rtol=1e-4, atol=1e-6)
|
||||
|
||||
def test_roundtrip_multiple_chunks(self):
|
||||
"""Test IPC roundtrip across multiple chunks."""
|
||||
params = [
|
||||
(f"layer{i}.weight", torch.randn(100, 100, dtype=torch.float32).cuda())
|
||||
for i in range(8)
|
||||
]
|
||||
gpu_uuid = self._get_gpu_uuid()
|
||||
device_index = torch.cuda.current_device()
|
||||
|
||||
num_chunks, result = self._run_roundtrip(
|
||||
packed_ipc_producer(
|
||||
iterator=iter(params),
|
||||
gpu_uuid=gpu_uuid,
|
||||
post_iter_func=lambda x: x[1],
|
||||
buffer_size_bytes=50_000,
|
||||
),
|
||||
device_index,
|
||||
)
|
||||
|
||||
assert num_chunks > 1
|
||||
assert len(result) == len(params)
|
||||
for (orig_name, orig_tensor), (res_name, res_tensor) in zip(params, result):
|
||||
assert orig_name == res_name
|
||||
assert torch.allclose(res_tensor, orig_tensor.cpu(), rtol=1e-5, atol=1e-7)
|
||||
|
||||
def test_roundtrip_non_contiguous(self):
|
||||
"""Test IPC roundtrip with non-contiguous tensors."""
|
||||
params = [
|
||||
("transposed", torch.randn(20, 10, dtype=torch.float32).cuda().T),
|
||||
("sliced", torch.randn(40, 30, dtype=torch.float16).cuda()[::2, ::2]),
|
||||
]
|
||||
gpu_uuid = self._get_gpu_uuid()
|
||||
device_index = torch.cuda.current_device()
|
||||
|
||||
for _, t in params:
|
||||
assert not t.is_contiguous()
|
||||
|
||||
_, result = self._run_roundtrip(
|
||||
packed_ipc_producer(
|
||||
iterator=iter(params),
|
||||
gpu_uuid=gpu_uuid,
|
||||
post_iter_func=lambda x: x[1],
|
||||
buffer_size_bytes=10_000_000,
|
||||
),
|
||||
device_index,
|
||||
)
|
||||
|
||||
for (orig_name, orig_tensor), (res_name, res_tensor) in zip(params, result):
|
||||
assert orig_name == res_name
|
||||
assert res_tensor.shape == orig_tensor.shape
|
||||
assert res_tensor.dtype == orig_tensor.dtype
|
||||
assert torch.allclose(res_tensor, orig_tensor.cpu(), rtol=1e-4, atol=1e-6)
|
||||
|
||||
def test_consumer_wrong_uuid_raises(self):
|
||||
"""Test that consumer raises ValueError for unknown GPU UUID."""
|
||||
params = [("w", torch.randn(10, dtype=torch.float32).cuda())]
|
||||
gpu_uuid = self._get_gpu_uuid()
|
||||
|
||||
chunks = list(
|
||||
packed_ipc_producer(
|
||||
iterator=iter(params),
|
||||
gpu_uuid=gpu_uuid,
|
||||
post_iter_func=lambda x: x[1],
|
||||
buffer_size_bytes=10_000_000,
|
||||
)
|
||||
)
|
||||
|
||||
c = chunks[0]
|
||||
fake_handle = {"fake-uuid-12345": c.ipc_handle[gpu_uuid]}
|
||||
|
||||
with pytest.raises(ValueError, match="IPC handle not found"):
|
||||
packed_ipc_consumer(
|
||||
ipc_handle=fake_handle,
|
||||
names=c.names,
|
||||
shapes=c.shapes,
|
||||
dtype_names=c.dtype_names,
|
||||
tensor_sizes=c.tensor_sizes,
|
||||
device_index=torch.cuda.current_device(),
|
||||
)
|
||||
|
||||
@@ -389,7 +389,7 @@ class TestIPCWeightTransferUpdateInfoValidation:
|
||||
|
||||
# Create a dummy tensor and IPC handle
|
||||
dummy_tensor = torch.ones(10, 10, device="cuda:0")
|
||||
ipc_handle = reduce_tensor(dummy_tensor)
|
||||
_, ipc_handle = reduce_tensor(dummy_tensor)
|
||||
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
|
||||
ipc_handles = [{gpu_uuid: ipc_handle}]
|
||||
|
||||
@@ -410,7 +410,7 @@ class TestIPCWeightTransferUpdateInfoValidation:
|
||||
pytest.skip("Need at least 1 GPU for this test")
|
||||
|
||||
dummy_tensor = torch.ones(10, 10, device="cuda:0")
|
||||
ipc_handle = reduce_tensor(dummy_tensor)
|
||||
_, ipc_handle = reduce_tensor(dummy_tensor)
|
||||
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
|
||||
ipc_handles = [{gpu_uuid: ipc_handle}, {gpu_uuid: ipc_handle}]
|
||||
|
||||
@@ -428,7 +428,7 @@ class TestIPCWeightTransferUpdateInfoValidation:
|
||||
pytest.skip("Need at least 1 GPU for this test")
|
||||
|
||||
dummy_tensor = torch.ones(10, 10, device="cuda:0")
|
||||
ipc_handle = reduce_tensor(dummy_tensor)
|
||||
_, ipc_handle = reduce_tensor(dummy_tensor)
|
||||
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
|
||||
ipc_handles = [{gpu_uuid: ipc_handle}, {gpu_uuid: ipc_handle}]
|
||||
|
||||
@@ -446,7 +446,7 @@ class TestIPCWeightTransferUpdateInfoValidation:
|
||||
pytest.skip("Need at least 1 GPU for this test")
|
||||
|
||||
dummy_tensor = torch.ones(10, 10, device="cuda:0")
|
||||
ipc_handle = reduce_tensor(dummy_tensor)
|
||||
_, ipc_handle = reduce_tensor(dummy_tensor)
|
||||
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
|
||||
ipc_handles = [{gpu_uuid: ipc_handle}] # Only one handle
|
||||
|
||||
@@ -458,65 +458,9 @@ class TestIPCWeightTransferUpdateInfoValidation:
|
||||
ipc_handles=ipc_handles,
|
||||
)
|
||||
|
||||
def test_valid_update_info_from_pickled(self, monkeypatch):
|
||||
"""Test creating IPCWeightTransferUpdateInfo from pickled handles."""
|
||||
if torch.accelerator.device_count() < 1:
|
||||
pytest.skip("Need at least 1 GPU for this test")
|
||||
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
dummy_tensor = torch.ones(10, 10, device="cuda:0")
|
||||
ipc_handle = reduce_tensor(dummy_tensor)
|
||||
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
|
||||
ipc_handles = [{gpu_uuid: ipc_handle}]
|
||||
|
||||
pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")
|
||||
|
||||
info = IPCWeightTransferUpdateInfo(
|
||||
names=["layer.weight"],
|
||||
dtype_names=["float32"],
|
||||
shapes=[[10, 10]],
|
||||
ipc_handles_pickled=pickled,
|
||||
)
|
||||
assert info.ipc_handles == ipc_handles
|
||||
assert info.ipc_handles_pickled is None
|
||||
|
||||
def test_pickled_requires_insecure_serialization_flag(self, monkeypatch):
|
||||
"""Test that pickled handles are rejected unless env flag is enabled."""
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0")
|
||||
|
||||
with pytest.raises(ValueError, match="VLLM_ALLOW_INSECURE_SERIALIZATION=1"):
|
||||
IPCWeightTransferUpdateInfo(
|
||||
names=[],
|
||||
dtype_names=[],
|
||||
shapes=[],
|
||||
ipc_handles_pickled=base64.b64encode(pickle.dumps([])).decode("utf-8"),
|
||||
)
|
||||
|
||||
def test_both_handles_and_pickled_raises(self):
|
||||
"""Test that providing both ipc_handles and ipc_handles_pickled raises."""
|
||||
if torch.accelerator.device_count() < 1:
|
||||
pytest.skip("Need at least 1 GPU for this test")
|
||||
|
||||
dummy_tensor = torch.ones(10, 10, device="cuda:0")
|
||||
ipc_handle = reduce_tensor(dummy_tensor)
|
||||
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
|
||||
ipc_handles = [{gpu_uuid: ipc_handle}]
|
||||
|
||||
pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")
|
||||
|
||||
with pytest.raises(ValueError, match="Cannot specify both"):
|
||||
IPCWeightTransferUpdateInfo(
|
||||
names=["layer.weight"],
|
||||
dtype_names=["float32"],
|
||||
shapes=[[10, 10]],
|
||||
ipc_handles=ipc_handles,
|
||||
ipc_handles_pickled=pickled,
|
||||
)
|
||||
|
||||
def test_neither_handles_nor_pickled_raises(self):
|
||||
"""Test that providing neither ipc_handles nor ipc_handles_pickled raises."""
|
||||
with pytest.raises(ValueError, match="must be provided"):
|
||||
def test_missing_ipc_handles_raises(self):
|
||||
"""Test that omitting ipc_handles raises TypeError."""
|
||||
with pytest.raises(TypeError):
|
||||
IPCWeightTransferUpdateInfo(
|
||||
names=["layer.weight"],
|
||||
dtype_names=["float32"],
|
||||
@@ -552,10 +496,10 @@ class TestIPCEngineParsing:
|
||||
# Create dummy IPC handles
|
||||
dummy_tensor1 = torch.ones(100, 100, device="cuda:0")
|
||||
dummy_tensor2 = torch.ones(50, device="cuda:0")
|
||||
ipc_handle1 = reduce_tensor(dummy_tensor1)
|
||||
ipc_handle2 = reduce_tensor(dummy_tensor2)
|
||||
_, ipc_args1 = reduce_tensor(dummy_tensor1)
|
||||
_, ipc_args2 = reduce_tensor(dummy_tensor2)
|
||||
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
|
||||
ipc_handles = [{gpu_uuid: ipc_handle1}, {gpu_uuid: ipc_handle2}]
|
||||
ipc_handles = [{gpu_uuid: ipc_args1}, {gpu_uuid: ipc_args2}]
|
||||
|
||||
update_info = engine.parse_update_info(
|
||||
{
|
||||
@@ -585,10 +529,10 @@ class TestIPCEngineParsing:
|
||||
|
||||
dummy_tensor1 = torch.ones(100, 100, device="cuda:0")
|
||||
dummy_tensor2 = torch.ones(50, device="cuda:0")
|
||||
ipc_handle1 = reduce_tensor(dummy_tensor1)
|
||||
ipc_handle2 = reduce_tensor(dummy_tensor2)
|
||||
_, ipc_args1 = reduce_tensor(dummy_tensor1)
|
||||
_, ipc_args2 = reduce_tensor(dummy_tensor2)
|
||||
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
|
||||
ipc_handles = [{gpu_uuid: ipc_handle1}, {gpu_uuid: ipc_handle2}]
|
||||
ipc_handles = [{gpu_uuid: ipc_args1}, {gpu_uuid: ipc_args2}]
|
||||
|
||||
pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")
|
||||
|
||||
@@ -604,10 +548,36 @@ class TestIPCEngineParsing:
|
||||
assert isinstance(update_info, IPCWeightTransferUpdateInfo)
|
||||
assert update_info.names == ["w1", "w2"]
|
||||
assert len(update_info.ipc_handles) == 2
|
||||
assert update_info.ipc_handles_pickled is None
|
||||
assert gpu_uuid in update_info.ipc_handles[0]
|
||||
assert gpu_uuid in update_info.ipc_handles[1]
|
||||
|
||||
def test_parse_update_info_both_handles_and_pickled_raises(self):
|
||||
"""Test that providing both ipc_handles and ipc_handles_pickled raises."""
|
||||
if torch.accelerator.device_count() < 1:
|
||||
pytest.skip("Need at least 1 GPU for this test")
|
||||
|
||||
config = WeightTransferConfig(backend="ipc")
|
||||
parallel_config = create_mock_parallel_config()
|
||||
engine = IPCWeightTransferEngine(config, parallel_config)
|
||||
|
||||
dummy_tensor = torch.ones(10, 10, device="cuda:0")
|
||||
_, ipc_handle = reduce_tensor(dummy_tensor)
|
||||
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
|
||||
ipc_handles = [{gpu_uuid: ipc_handle}]
|
||||
|
||||
pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")
|
||||
|
||||
with pytest.raises(ValueError, match="Cannot specify both"):
|
||||
engine.parse_update_info(
|
||||
{
|
||||
"names": ["layer.weight"],
|
||||
"dtype_names": ["float32"],
|
||||
"shapes": [[10, 10]],
|
||||
"ipc_handles": ipc_handles,
|
||||
"ipc_handles_pickled": pickled,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# --- Integration Test: IPC Weight Transfer Between Ray Tasks ---
|
||||
|
||||
@@ -629,13 +599,15 @@ class TrainerActor:
|
||||
self.tensor.fill_(42.0) # Fill with 42 to verify correct transfer
|
||||
|
||||
# Create IPC handle (tensor must stay alive for IPC to work)
|
||||
ipc_handle = reduce_tensor(self.tensor)
|
||||
# reduce_tensor returns (rebuild_func, args); we only send args
|
||||
# since the receiver imports rebuild_cuda_tensor directly.
|
||||
_, ipc_args = reduce_tensor(self.tensor)
|
||||
gpu_uuid = get_physical_gpu_id(0)
|
||||
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
self.ipc_handle_dict = {
|
||||
"ipc_handle": ipc_handle,
|
||||
"ipc_handle": ipc_args,
|
||||
"gpu_uuid": gpu_uuid,
|
||||
"shape": tensor_shape,
|
||||
"dtype": tensor_dtype,
|
||||
@@ -652,6 +624,12 @@ def inference_receive_ipc_tensor(
|
||||
mode: str = "ray",
|
||||
) -> dict:
|
||||
"""Inference task that receives tensor via IPCWeightTransferEngine."""
|
||||
import os
|
||||
|
||||
# Worker-side: ipc_handles_pickled is deserialized via pickle.
|
||||
if mode == "http":
|
||||
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
@@ -684,7 +662,6 @@ def inference_receive_ipc_tensor(
|
||||
# Clone tensor to keep it after engine cleans up
|
||||
received_tensors.append((name, tensor.clone()))
|
||||
|
||||
# Build update dict and go through parse_update_info (exercises __post_init__)
|
||||
ipc_handles = [{ipc_handle_dict["gpu_uuid"]: ipc_handle_dict["ipc_handle"]}]
|
||||
|
||||
if mode == "ray":
|
||||
@@ -695,6 +672,7 @@ def inference_receive_ipc_tensor(
|
||||
"ipc_handles": ipc_handles,
|
||||
}
|
||||
elif mode == "http":
|
||||
# Simulate HTTP transport: pickle + base64 encode handles
|
||||
pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")
|
||||
update_dict = {
|
||||
"names": ["test.weight"],
|
||||
@@ -743,7 +721,8 @@ def test_ipc_weight_transfer_between_processes(mode: str):
|
||||
|
||||
Parametrized over transport modes:
|
||||
- 'ray': ipc_handles passed directly.
|
||||
- 'http': ipc_handles pickled + base64-encoded, unpickled via __post_init__.
|
||||
- 'http': ipc_handles pickled + base64-encoded, deserialized in
|
||||
parse_update_info before constructing the dataclass.
|
||||
|
||||
IPC requires same-GPU access, so we use a placement group to co-locate
|
||||
the trainer actor and inference task on the same GPU.
|
||||
@@ -801,7 +780,7 @@ def test_ipc_receive_weights_missing_gpu_uuid_raises():
|
||||
|
||||
# Create IPC handle with wrong GPU UUID
|
||||
dummy_tensor = torch.ones(10, 10, device="cuda:0")
|
||||
ipc_handle = reduce_tensor(dummy_tensor)
|
||||
_, ipc_handle = reduce_tensor(dummy_tensor)
|
||||
wrong_uuid = "wrong-uuid-12345"
|
||||
ipc_handles = [{wrong_uuid: ipc_handle}]
|
||||
|
||||
|
||||
@@ -15,7 +15,12 @@ _TORCH_CUDA_PATTERNS = [
|
||||
r"\bcuda_device_count_stateless\(\)\b",
|
||||
]
|
||||
|
||||
ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"}
|
||||
ALLOWED_FILES = {
|
||||
"vllm/platforms/",
|
||||
"vllm/device_allocator/",
|
||||
"vllm/distributed/weight_transfer/ipc_engine.py",
|
||||
"tests/distributed/test_packed_tensor.py",
|
||||
}
|
||||
|
||||
|
||||
def scan_file(path: str) -> int:
|
||||
|
||||
@@ -8,9 +8,10 @@ from dataclasses import asdict, dataclass
|
||||
from typing import Any
|
||||
|
||||
import pybase64 as base64
|
||||
import ray
|
||||
import requests
|
||||
import torch
|
||||
from torch.multiprocessing.reductions import reduce_tensor
|
||||
from torch.multiprocessing.reductions import rebuild_cuda_tensor, reduce_tensor
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
@@ -20,27 +21,43 @@ from vllm.distributed.weight_transfer.base import (
|
||||
WeightTransferInitInfo,
|
||||
WeightTransferUpdateInfo,
|
||||
)
|
||||
from vllm.distributed.weight_transfer.packed_tensor import (
|
||||
DEFAULT_PACKED_BUFFER_SIZE_BYTES,
|
||||
packed_ipc_consumer,
|
||||
packed_ipc_producer,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPCTrainerSendWeightsArgs:
|
||||
"""Arguments for IPC trainer_send_weights method."""
|
||||
|
||||
mode: str
|
||||
"""Transport mode: 'http' or 'ray'."""
|
||||
send_mode: str | Callable[["IPCWeightTransferUpdateInfo"], None]
|
||||
"""How to send updates to vLLM. Either a string ('ray' or 'http') for
|
||||
built-in transports, or a callable that receives an
|
||||
IPCWeightTransferUpdateInfo and performs the send."""
|
||||
llm_handle: Any = None
|
||||
"""Ray ObjectRef to LLM handle (required for 'ray' mode)."""
|
||||
"""Ray actor handle or list of handles (required for 'ray' send_mode)."""
|
||||
url: str | None = None
|
||||
"""Base URL for HTTP endpoint (required for 'http' mode)."""
|
||||
"""Base URL for HTTP endpoint (required for 'http' send_mode)."""
|
||||
packed: bool = False
|
||||
"""Whether to use packed tensor transfer for bounded-memory chunking."""
|
||||
packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
|
||||
"""Size in bytes for each packed tensor buffer when packed=True."""
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate that required arguments are provided for the selected mode."""
|
||||
if self.mode == "ray" and self.llm_handle is None:
|
||||
raise ValueError("llm_handle is required for 'ray' mode")
|
||||
if self.mode == "http" and self.url is None:
|
||||
raise ValueError("url is required for 'http' mode")
|
||||
if self.mode not in ("ray", "http"):
|
||||
raise ValueError(f"mode must be 'ray' or 'http', got {self.mode}")
|
||||
if callable(self.send_mode):
|
||||
return
|
||||
if self.send_mode == "ray" and self.llm_handle is None:
|
||||
raise ValueError("llm_handle is required for 'ray' send_mode")
|
||||
if self.send_mode == "http" and self.url is None:
|
||||
raise ValueError("url is required for 'http' send_mode")
|
||||
if self.send_mode not in ("ray", "http"):
|
||||
raise ValueError(
|
||||
f"send_mode must be 'ray', 'http', or a callable, "
|
||||
f"got {self.send_mode!r}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -52,44 +69,22 @@ class IPCWeightTransferInitInfo(WeightTransferInitInfo):
|
||||
|
||||
@dataclass
|
||||
class IPCWeightTransferUpdateInfo(WeightTransferUpdateInfo):
|
||||
"""Update info for IPC weight transfer backend.
|
||||
|
||||
Accepts IPC handles either directly via ``ipc_handles`` (Ray transport)
|
||||
or as a base64-encoded pickle via ``ipc_handles_pickled`` (HTTP transport).
|
||||
Exactly one of the two must be provided; if ``ipc_handles_pickled`` is set
|
||||
it is unpickled into ``ipc_handles`` during ``__post_init__``.
|
||||
"""
|
||||
"""Update info for IPC weight transfer backend."""
|
||||
|
||||
names: list[str]
|
||||
dtype_names: list[str]
|
||||
shapes: list[list[int]]
|
||||
ipc_handles: list[dict[str, tuple[Callable, tuple]]] | None = None
|
||||
"""IPC handles mapping physical GPU UUID to (func, args) tuple.
|
||||
Each handle is a dictionary mapping GPU UUID strings to IPC handle tuples."""
|
||||
ipc_handles_pickled: str | None = None
|
||||
"""Base64-encoded pickled IPC handles, used for HTTP transport."""
|
||||
ipc_handles: list[dict[str, tuple]] | dict[str, tuple]
|
||||
"""IPC handles mapping physical GPU UUID to rebuild_cuda_tensor args.
|
||||
For non-packed mode: list of per-parameter handle dicts.
|
||||
For packed mode: single handle dict for the packed buffer."""
|
||||
tensor_sizes: list[int] | None = None
|
||||
"""Per-parameter sizes in bytes within the packed buffer.
|
||||
Required when packed=True, unused otherwise."""
|
||||
packed: bool = False
|
||||
"""Whether this update uses packed tensor format."""
|
||||
|
||||
def __post_init__(self):
|
||||
if self.ipc_handles_pickled is not None:
|
||||
if self.ipc_handles is not None:
|
||||
raise ValueError(
|
||||
"Cannot specify both `ipc_handles` and `ipc_handles_pickled`"
|
||||
)
|
||||
|
||||
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||
raise ValueError(
|
||||
"Refusing to deserialize `ipc_handles_pickled` without "
|
||||
"VLLM_ALLOW_INSECURE_SERIALIZATION=1"
|
||||
)
|
||||
|
||||
self.ipc_handles = pickle.loads(base64.b64decode(self.ipc_handles_pickled))
|
||||
self.ipc_handles_pickled = None
|
||||
|
||||
if self.ipc_handles is None:
|
||||
raise ValueError(
|
||||
"Either `ipc_handles` or `ipc_handles_pickled` must be provided"
|
||||
)
|
||||
|
||||
num_params = len(self.names)
|
||||
if len(self.dtype_names) != num_params:
|
||||
raise ValueError(
|
||||
@@ -101,11 +96,17 @@ class IPCWeightTransferUpdateInfo(WeightTransferUpdateInfo):
|
||||
f"`shapes` should be of the same size as `names`: "
|
||||
f"got {len(self.shapes)} and {len(self.names)}"
|
||||
)
|
||||
if len(self.ipc_handles) != num_params:
|
||||
if (
|
||||
not self.packed
|
||||
and isinstance(self.ipc_handles, list)
|
||||
and len(self.ipc_handles) != num_params
|
||||
):
|
||||
raise ValueError(
|
||||
f"`ipc_handles` should be of the same size as `names`: "
|
||||
f"got {len(self.ipc_handles)} and {len(self.names)}"
|
||||
)
|
||||
if self.packed and self.tensor_sizes is None:
|
||||
raise ValueError("`tensor_sizes` is required when packed=True")
|
||||
|
||||
|
||||
class IPCWeightTransferEngine(
|
||||
@@ -135,6 +136,36 @@ class IPCWeightTransferEngine(
|
||||
"""
|
||||
super().__init__(config, parallel_config)
|
||||
|
||||
def parse_update_info(
|
||||
self, update_dict: dict[str, Any]
|
||||
) -> IPCWeightTransferUpdateInfo:
|
||||
"""Parse update dict, deserializing pickled IPC handles if present.
|
||||
|
||||
HTTP transport sends IPC handles as a base64-encoded pickle under the
|
||||
key ``ipc_handles_pickled``. This method deserializes them back into
|
||||
``ipc_handles`` before constructing the typed dataclass, keeping
|
||||
serialization concerns out of the dataclass itself.
|
||||
|
||||
Requires ``VLLM_ALLOW_INSECURE_SERIALIZATION=1`` because the
|
||||
payload is deserialized via ``pickle.loads``.
|
||||
"""
|
||||
if "ipc_handles_pickled" in update_dict:
|
||||
if "ipc_handles" in update_dict:
|
||||
raise ValueError(
|
||||
"Cannot specify both `ipc_handles` and `ipc_handles_pickled`"
|
||||
)
|
||||
|
||||
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||
raise ValueError(
|
||||
"Refusing to deserialize `ipc_handles_pickled` without "
|
||||
"VLLM_ALLOW_INSECURE_SERIALIZATION=1"
|
||||
)
|
||||
|
||||
pickled = update_dict.pop("ipc_handles_pickled")
|
||||
update_dict["ipc_handles"] = pickle.loads(base64.b64decode(pickled))
|
||||
|
||||
return super().parse_update_info(update_dict)
|
||||
|
||||
def init_transfer_engine(self, init_info: IPCWeightTransferInitInfo) -> None:
|
||||
"""
|
||||
Initialize the weight transfer mechanism.
|
||||
@@ -157,46 +188,52 @@ class IPCWeightTransferEngine(
|
||||
Args:
|
||||
update_info: IPC update info containing parameter names, dtypes, shapes,
|
||||
and IPC handles. Each IPC handle is a mapping between physical
|
||||
GPU UUID and the IPC handle tuple (func, args).
|
||||
GPU UUID and the rebuild_cuda_tensor args tuple.
|
||||
load_weights: Callable that loads weights into the model. Called
|
||||
incrementally for each weight to avoid OOM.
|
||||
"""
|
||||
assert update_info.ipc_handles is not None
|
||||
weights = []
|
||||
for name, _dtype_name, _shape, ipc_handle in zip(
|
||||
update_info.names,
|
||||
update_info.dtype_names,
|
||||
update_info.shapes,
|
||||
update_info.ipc_handles,
|
||||
):
|
||||
device_index = torch.accelerator.current_device_index()
|
||||
props = torch.cuda.get_device_properties(device_index)
|
||||
physical_gpu_id = str(props.uuid)
|
||||
device_index = torch.accelerator.current_device_index()
|
||||
|
||||
if physical_gpu_id not in ipc_handle:
|
||||
raise ValueError(
|
||||
f"IPC handle not found for GPU UUID {physical_gpu_id}. "
|
||||
f"Available UUIDs: {list(ipc_handle.keys())}"
|
||||
)
|
||||
if update_info.packed:
|
||||
assert update_info.tensor_sizes is not None
|
||||
assert isinstance(update_info.ipc_handles, dict)
|
||||
weights = packed_ipc_consumer(
|
||||
ipc_handle=update_info.ipc_handles,
|
||||
names=update_info.names,
|
||||
shapes=update_info.shapes,
|
||||
dtype_names=update_info.dtype_names,
|
||||
tensor_sizes=update_info.tensor_sizes,
|
||||
device_index=device_index,
|
||||
)
|
||||
load_weights(weights)
|
||||
else:
|
||||
assert isinstance(update_info.ipc_handles, list)
|
||||
weights = []
|
||||
for name, ipc_handle in zip(
|
||||
update_info.names,
|
||||
update_info.ipc_handles,
|
||||
):
|
||||
props = torch.cuda.get_device_properties(device_index)
|
||||
physical_gpu_id = str(props.uuid)
|
||||
|
||||
handle = ipc_handle[physical_gpu_id]
|
||||
if physical_gpu_id not in ipc_handle:
|
||||
raise ValueError(
|
||||
f"IPC handle not found for GPU UUID "
|
||||
f"{physical_gpu_id}. "
|
||||
f"Available UUIDs: {list(ipc_handle.keys())}"
|
||||
)
|
||||
|
||||
func, args = handle
|
||||
list_args = list(args) # type: ignore
|
||||
# Index 6 is the device_index parameter in torch's
|
||||
# IPC handle tuple (rebuild_cuda_tensor). Update it
|
||||
# to the current device since the logical index can
|
||||
# differ between sender and receiver.
|
||||
list_args[6] = device_index
|
||||
weight = func(*list_args) # type: ignore
|
||||
weights.append((name, weight))
|
||||
args = ipc_handle[physical_gpu_id]
|
||||
list_args = list(args)
|
||||
# Index 6 of the args from reduce_tensor is the device_index.
|
||||
# We need to overwrite it with the receiver's device index.
|
||||
list_args[6] = device_index
|
||||
weight = rebuild_cuda_tensor(*list_args)
|
||||
weights.append((name, weight))
|
||||
|
||||
load_weights(weights)
|
||||
load_weights(weights)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""
|
||||
Shutdown the weight transfer engine.
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@@ -204,12 +241,17 @@ class IPCWeightTransferEngine(
|
||||
iterator: Iterator[tuple[str, torch.Tensor]],
|
||||
trainer_args: dict[str, Any] | IPCTrainerSendWeightsArgs,
|
||||
) -> None:
|
||||
"""
|
||||
Send weights from trainer to inference workers via CUDA IPC.
|
||||
"""Send weights from trainer to inference workers via CUDA IPC.
|
||||
|
||||
Supports two modes:
|
||||
- 'ray': Sends weights via Ray RPC to a Ray-based LLM handle
|
||||
- 'http': Sends weights via HTTP POST to a vLLM HTTP server
|
||||
Supports two transport modes ('ray' and 'http') and two transfer
|
||||
strategies:
|
||||
- Non-packed (default): all weights in a single API call.
|
||||
- Packed (packed=True): chunked transfer with bounded GPU memory.
|
||||
|
||||
For multi-GPU training, all ranks must call this method in
|
||||
parallel. IPC handles are all-gathered across ranks and merged
|
||||
so that each vLLM worker can find its own GPU UUID. Only rank 0
|
||||
sends the payload to vLLM.
|
||||
|
||||
.. note::
|
||||
This method calls ``update_weights`` internally. The caller must
|
||||
@@ -217,88 +259,195 @@ class IPCWeightTransferEngine(
|
||||
after this method.
|
||||
|
||||
Args:
|
||||
iterator: Iterator of model parameters. Returns (name, tensor) tuples.
|
||||
Tensors should be on the same GPU as the inference workers.
|
||||
trainer_args: Dictionary containing IPC-specific arguments.
|
||||
Should contain keys from IPCTrainerSendWeightsArgs:
|
||||
- mode: 'ray' or 'http'
|
||||
- llm_handle: Ray ObjectRef (for 'ray' mode)
|
||||
- url: Base URL string (for 'http' mode)
|
||||
|
||||
Example (Ray mode):
|
||||
>>> from vllm.distributed.weight_transfer.ipc_engine import (
|
||||
... IPCWeightTransferEngine,
|
||||
... IPCTrainerSendWeightsArgs,
|
||||
... )
|
||||
>>> param_iter = ((n, p) for n, p in model.named_parameters())
|
||||
>>> args = IPCTrainerSendWeightsArgs(mode="ray", llm_handle=llm_handle)
|
||||
>>> IPCWeightTransferEngine.trainer_send_weights(param_iter, asdict(args))
|
||||
|
||||
Example (HTTP mode):
|
||||
>>> args = IPCTrainerSendWeightsArgs(
|
||||
... mode="http", url="http://localhost:8000"
|
||||
... )
|
||||
>>> IPCWeightTransferEngine.trainer_send_weights(param_iter, asdict(args))
|
||||
iterator: Iterator of (name, tensor) pairs. For multi-GPU,
|
||||
each rank should yield the full tensor on its own GPU
|
||||
(e.g. via FSDP full_tensor()).
|
||||
trainer_args: IPCTrainerSendWeightsArgs or equivalent dict.
|
||||
"""
|
||||
# Parse trainer args - accept either dict or dataclass instance
|
||||
if isinstance(trainer_args, dict):
|
||||
args = IPCTrainerSendWeightsArgs(**trainer_args)
|
||||
else:
|
||||
args = trainer_args
|
||||
|
||||
# Get physical GPU UUID
|
||||
args = (
|
||||
IPCTrainerSendWeightsArgs(**trainer_args)
|
||||
if isinstance(trainer_args, dict)
|
||||
else trainer_args
|
||||
)
|
||||
device_index = torch.accelerator.current_device_index()
|
||||
props = torch.cuda.get_device_properties(device_index)
|
||||
gpu_uuid = str(props.uuid)
|
||||
gpu_uuid = str(torch.cuda.get_device_properties(device_index).uuid)
|
||||
if args.packed:
|
||||
IPCWeightTransferEngine._send_packed(iterator, args, gpu_uuid)
|
||||
else:
|
||||
IPCWeightTransferEngine._send_unpacked(iterator, args, gpu_uuid)
|
||||
|
||||
# Collect weight metadata and create IPC handles
|
||||
names = []
|
||||
dtype_names = []
|
||||
shapes = []
|
||||
ipc_handles = []
|
||||
@staticmethod
|
||||
def _is_rank_zero() -> bool:
|
||||
"""Return True if this is rank 0 or no distributed group exists."""
|
||||
if not torch.distributed.is_initialized():
|
||||
return True
|
||||
return torch.distributed.get_rank() == 0
|
||||
|
||||
@staticmethod
|
||||
def _all_gather_and_merge_handles(
|
||||
handles: list[dict[str, tuple]],
|
||||
) -> list[dict[str, tuple]]:
|
||||
"""All-gather and merge IPC handle dicts across ranks in one call.
|
||||
|
||||
Each rank contributes a list of {gpu_uuid: ipc_args} dicts (one
|
||||
per parameter or one per chunk). A single all_gather_object
|
||||
collects every rank's full list, then rank 0 merges per-index so
|
||||
each dict maps every GPU UUID to its args.
|
||||
|
||||
Non-rank-0 returns a list of empty dicts.
|
||||
No-op (returns handles unchanged) when no distributed group exists.
|
||||
"""
|
||||
if (
|
||||
not torch.distributed.is_initialized()
|
||||
or torch.distributed.get_world_size() == 1
|
||||
):
|
||||
return handles
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
gathered: list[list[dict[str, tuple]] | None] = [None] * world_size
|
||||
torch.distributed.all_gather_object(gathered, handles)
|
||||
torch.distributed.barrier()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
merged: list[dict[str, tuple]] = []
|
||||
for param_idx in range(len(handles)):
|
||||
m: dict[str, tuple] = {}
|
||||
for rank_handles in gathered:
|
||||
if rank_handles is not None:
|
||||
m.update(rank_handles[param_idx])
|
||||
merged.append(m)
|
||||
return merged
|
||||
return [{} for _ in handles]
|
||||
|
||||
@staticmethod
|
||||
def _post_send_sync() -> None:
|
||||
"""Barrier + ipc_collect after a send; no-op if single-GPU."""
|
||||
if (
|
||||
torch.distributed.is_initialized()
|
||||
and torch.distributed.get_world_size() > 1
|
||||
):
|
||||
torch.distributed.barrier()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
@staticmethod
|
||||
def _send_unpacked(
|
||||
iterator: Iterator[tuple[str, torch.Tensor]],
|
||||
args: IPCTrainerSendWeightsArgs,
|
||||
gpu_uuid: str,
|
||||
) -> None:
|
||||
"""Send all weights in a single API call (non-packed mode)."""
|
||||
names: list[str] = []
|
||||
dtype_names: list[str] = []
|
||||
shapes: list[list[int]] = []
|
||||
ipc_handles: list[dict[str, tuple]] = []
|
||||
# Hold strong refs to every contiguous copy until the send + post-send
|
||||
# sync completes. reduce_tensor's returned args do NOT keep storage
|
||||
# alive, and non-contiguous inputs allocate fresh storage in
|
||||
# .contiguous() that would otherwise be GC'd before the consumer opens
|
||||
# the IPC handle.
|
||||
weight_refs: list[torch.Tensor] = []
|
||||
|
||||
for name, tensor in iterator:
|
||||
names.append(name)
|
||||
dtype_names.append(str(tensor.dtype).split(".")[-1])
|
||||
shapes.append(list(tensor.shape))
|
||||
|
||||
# Create IPC handle for this weight tensor
|
||||
# The tensor must remain in memory for IPC to work
|
||||
weight = tensor.detach().contiguous()
|
||||
ipc_handle = reduce_tensor(weight)
|
||||
ipc_handles.append({gpu_uuid: ipc_handle})
|
||||
weight_refs.append(weight)
|
||||
_, ipc_args = reduce_tensor(weight)
|
||||
ipc_handles.append({gpu_uuid: ipc_args})
|
||||
|
||||
# Send weights based on mode
|
||||
if args.mode == "ray":
|
||||
# Ray mode: send via Ray RPC
|
||||
import ray
|
||||
ipc_handles = IPCWeightTransferEngine._all_gather_and_merge_handles(ipc_handles)
|
||||
|
||||
update_info = asdict(
|
||||
IPCWeightTransferUpdateInfo(
|
||||
names=names,
|
||||
dtype_names=dtype_names,
|
||||
shapes=shapes,
|
||||
ipc_handles=ipc_handles,
|
||||
if IPCWeightTransferEngine._is_rank_zero():
|
||||
IPCWeightTransferEngine._do_send(
|
||||
args=args,
|
||||
names=names,
|
||||
dtype_names=dtype_names,
|
||||
shapes=shapes,
|
||||
ipc_handles=ipc_handles,
|
||||
)
|
||||
|
||||
IPCWeightTransferEngine._post_send_sync()
|
||||
|
||||
@staticmethod
|
||||
def _send_packed(
|
||||
iterator: Iterator[tuple[str, torch.Tensor]],
|
||||
args: IPCTrainerSendWeightsArgs,
|
||||
gpu_uuid: str,
|
||||
) -> None:
|
||||
"""Send weights in bounded-memory chunks (packed mode)."""
|
||||
post_iter_func: Callable = lambda item: item[1]
|
||||
|
||||
for chunk in packed_ipc_producer(
|
||||
iterator=iterator,
|
||||
gpu_uuid=gpu_uuid,
|
||||
post_iter_func=post_iter_func,
|
||||
buffer_size_bytes=args.packed_buffer_size_bytes,
|
||||
):
|
||||
ipc_handle = IPCWeightTransferEngine._all_gather_and_merge_handles(
|
||||
[chunk.ipc_handle]
|
||||
)[0]
|
||||
|
||||
if IPCWeightTransferEngine._is_rank_zero():
|
||||
IPCWeightTransferEngine._do_send(
|
||||
args=args,
|
||||
names=chunk.names,
|
||||
dtype_names=chunk.dtype_names,
|
||||
shapes=chunk.shapes,
|
||||
ipc_handles=ipc_handle,
|
||||
tensor_sizes=chunk.tensor_sizes,
|
||||
packed=True,
|
||||
)
|
||||
|
||||
IPCWeightTransferEngine._post_send_sync()
|
||||
|
||||
@staticmethod
|
||||
def _do_send(
|
||||
args: IPCTrainerSendWeightsArgs,
|
||||
names: list[str],
|
||||
dtype_names: list[str],
|
||||
shapes: list[list[int]],
|
||||
ipc_handles: list[dict[str, tuple]] | dict[str, tuple],
|
||||
tensor_sizes: list[int] | None = None,
|
||||
packed: bool = False,
|
||||
) -> None:
|
||||
"""Send a single update payload via the configured transport."""
|
||||
update_fields: dict[str, Any] = {
|
||||
"names": names,
|
||||
"dtype_names": dtype_names,
|
||||
"shapes": shapes,
|
||||
"packed": packed,
|
||||
}
|
||||
if tensor_sizes is not None:
|
||||
update_fields["tensor_sizes"] = tensor_sizes
|
||||
|
||||
update_fields["ipc_handles"] = ipc_handles
|
||||
update_info = IPCWeightTransferUpdateInfo(**update_fields)
|
||||
|
||||
if callable(args.send_mode):
|
||||
args.send_mode(update_info)
|
||||
elif args.send_mode == "ray":
|
||||
handles = (
|
||||
args.llm_handle
|
||||
if isinstance(args.llm_handle, list)
|
||||
else [args.llm_handle]
|
||||
)
|
||||
ray.get(
|
||||
args.llm_handle.update_weights.remote(dict(update_info=update_info))
|
||||
[
|
||||
h.update_weights.remote(dict(update_info=asdict(update_info)))
|
||||
for h in handles
|
||||
]
|
||||
)
|
||||
elif args.mode == "http":
|
||||
# HTTP mode: send via HTTP POST with pickled handles
|
||||
# Pickle and base64 encode IPC handles for HTTP transmission
|
||||
elif args.send_mode == "http":
|
||||
pickled_handles = base64.b64encode(pickle.dumps(ipc_handles)).decode(
|
||||
"utf-8"
|
||||
)
|
||||
http_fields = {k: v for k, v in update_fields.items() if k != "ipc_handles"}
|
||||
http_fields["ipc_handles_pickled"] = pickled_handles
|
||||
|
||||
url = f"{args.url}/update_weights"
|
||||
payload = {
|
||||
"update_info": {
|
||||
"names": names,
|
||||
"dtype_names": dtype_names,
|
||||
"shapes": shapes,
|
||||
"ipc_handles_pickled": pickled_handles,
|
||||
}
|
||||
}
|
||||
payload = {"update_info": http_fields}
|
||||
response = requests.post(url, json=payload, timeout=300)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -21,7 +21,7 @@ from vllm.distributed.weight_transfer.base import (
|
||||
from vllm.distributed.weight_transfer.packed_tensor import (
|
||||
DEFAULT_PACKED_BUFFER_SIZE_BYTES,
|
||||
DEFAULT_PACKED_NUM_BUFFERS,
|
||||
packed_broadcast_consumer,
|
||||
packed_nccl_broadcast_consumer,
|
||||
)
|
||||
|
||||
|
||||
@@ -184,7 +184,7 @@ class NCCLWeightTransferEngine(
|
||||
dtype = getattr(torch, dtype_name)
|
||||
yield (name, (shape, dtype))
|
||||
|
||||
packed_broadcast_consumer(
|
||||
packed_nccl_broadcast_consumer(
|
||||
iterator=state_dict_info_iterator(),
|
||||
group=self.model_update_group,
|
||||
src=0,
|
||||
@@ -247,10 +247,10 @@ class NCCLWeightTransferEngine(
|
||||
if args.packed:
|
||||
# Use packed tensor broadcasting for efficiency
|
||||
from vllm.distributed.weight_transfer.packed_tensor import (
|
||||
packed_broadcast_producer,
|
||||
packed_nccl_broadcast_producer,
|
||||
)
|
||||
|
||||
packed_broadcast_producer(
|
||||
packed_nccl_broadcast_producer(
|
||||
iterator=iterator,
|
||||
group=args.group,
|
||||
src=args.src,
|
||||
|
||||
@@ -4,9 +4,11 @@
|
||||
|
||||
import math
|
||||
from collections.abc import Callable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.multiprocessing.reductions import reduce_tensor
|
||||
|
||||
# Default values for packed tensor configuration.
|
||||
# These are imported by NCCLWeightTransferUpdateInfo and trainer_send_weights.
|
||||
@@ -14,7 +16,124 @@ DEFAULT_PACKED_BUFFER_SIZE_BYTES = 1024 * 1024 * 1024 # 1GB
|
||||
DEFAULT_PACKED_NUM_BUFFERS = 2
|
||||
|
||||
|
||||
def packed_broadcast_producer(
|
||||
def unpack_tensor(
|
||||
packed_tensor: torch.Tensor,
|
||||
names: list[str],
|
||||
shapes: list[list[int]],
|
||||
dtypes: list[torch.dtype],
|
||||
tensor_sizes: list[int],
|
||||
) -> list[tuple[str, torch.Tensor]]:
|
||||
"""Unpack a packed uint8 tensor into a list of named tensors.
|
||||
|
||||
The returned tensors are **views** of ``packed_tensor`` (the
|
||||
``.contiguous()`` call is a no-op on already-contiguous row-slices).
|
||||
If ``packed_tensor`` lives in storage that may be reused — e.g. a
|
||||
reused CUDA IPC buffer — callers must clone the results before the
|
||||
underlying storage is overwritten.
|
||||
|
||||
Args:
|
||||
packed_tensor: The packed torch.uint8 tensor to unpack
|
||||
names: List of tensor names
|
||||
shapes: List of tensor shapes
|
||||
dtypes: List of tensor dtypes
|
||||
tensor_sizes: List of tensor sizes in bytes
|
||||
"""
|
||||
unpacked_tensors = packed_tensor.split(tensor_sizes)
|
||||
|
||||
return [
|
||||
(name, tensor.contiguous().view(dtype).view(*shape))
|
||||
for name, shape, dtype, tensor in zip(names, shapes, dtypes, unpacked_tensors)
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PackedChunk:
|
||||
"""Result of packing tensors into a single contiguous uint8 buffer."""
|
||||
|
||||
packed_tensor: torch.Tensor
|
||||
names: list[str]
|
||||
shapes: list[list[int]]
|
||||
dtypes: list[torch.dtype]
|
||||
tensor_sizes: list[int]
|
||||
|
||||
|
||||
def pack_tensors(
|
||||
iterator: Iterator[tuple[str, torch.Tensor]],
|
||||
post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor],
|
||||
buffer_size_bytes: int,
|
||||
tensor_list: list[torch.Tensor] | None = None,
|
||||
current_size: int = 0,
|
||||
) -> PackedChunk | None:
|
||||
"""Pack tensors from an iterator into a single contiguous uint8 buffer.
|
||||
|
||||
Consumes from the iterator until the accumulated size exceeds
|
||||
buffer_size_bytes or the iterator is exhausted, then returns a
|
||||
PackedChunk. Returns None if no tensors were consumed.
|
||||
|
||||
Args:
|
||||
iterator: Iterator of (name, tensor) pairs
|
||||
post_iter_func: Applied to each item before linearizing to uint8
|
||||
buffer_size_bytes: Max bytes before flushing
|
||||
tensor_list: Pre-existing tensor list to append to (for NCCL
|
||||
multi-buffer reuse). If None, a fresh list is created.
|
||||
current_size: Byte count already accumulated in tensor_list
|
||||
"""
|
||||
if tensor_list is None:
|
||||
tensor_list = []
|
||||
|
||||
names: list[str] = []
|
||||
shapes: list[list[int]] = []
|
||||
dtypes: list[torch.dtype] = []
|
||||
tensor_sizes: list[int] = []
|
||||
total_bytes = current_size
|
||||
|
||||
while True:
|
||||
try:
|
||||
item = next(iterator)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
name, orig_tensor = item
|
||||
# Apply post processing and convert to linearized uint8 tensor
|
||||
tensor = post_iter_func(item).contiguous().view(torch.uint8).view(-1)
|
||||
|
||||
if tensor.numel() > buffer_size_bytes:
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
f"Tensor '{name}' has size {tensor.numel()} bytes, which "
|
||||
f"exceeds buffer_size_bytes={buffer_size_bytes}.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
tensor_list.append(tensor)
|
||||
names.append(name)
|
||||
shapes.append(list(orig_tensor.shape))
|
||||
dtypes.append(orig_tensor.dtype)
|
||||
tensor_sizes.append(tensor.numel())
|
||||
total_bytes += tensor.numel()
|
||||
|
||||
if total_bytes > buffer_size_bytes:
|
||||
break
|
||||
|
||||
if not tensor_list:
|
||||
return None
|
||||
|
||||
packed = torch.cat(tensor_list, dim=0)
|
||||
del tensor_list
|
||||
return PackedChunk(
|
||||
packed_tensor=packed,
|
||||
names=names,
|
||||
shapes=shapes,
|
||||
dtypes=dtypes,
|
||||
tensor_sizes=tensor_sizes,
|
||||
)
|
||||
|
||||
|
||||
# ── NCCL packed broadcast ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def packed_nccl_broadcast_producer(
|
||||
iterator: Iterator[tuple[str, torch.Tensor]],
|
||||
group: Any,
|
||||
src: int,
|
||||
@@ -36,57 +155,31 @@ def packed_broadcast_producer(
|
||||
Both producer and consumer must use the same value.
|
||||
|
||||
"""
|
||||
target_packed_tensor_size = buffer_size_bytes
|
||||
|
||||
streams = [torch.cuda.Stream() for _ in range(num_buffers)]
|
||||
# Keep references to in-flight chunks so their packed_tensors
|
||||
# aren't freed while an async broadcast is still reading them.
|
||||
in_flight: list[PackedChunk | None] = [None] * num_buffers
|
||||
buffer_idx = 0
|
||||
|
||||
packing_tensor_list: list[list[torch.Tensor]] = [[] for _ in range(num_buffers)]
|
||||
packing_tensor_sizes: list[int] = [0 for _ in range(num_buffers)]
|
||||
packed_tensors: list[torch.Tensor] = [
|
||||
torch.empty(0, dtype=torch.uint8, device="cuda") for _ in range(num_buffers)
|
||||
]
|
||||
|
||||
while True:
|
||||
# Synchronize the current stream
|
||||
streams[buffer_idx].synchronize()
|
||||
# Previous chunk on this buffer slot is now safe to free
|
||||
in_flight[buffer_idx] = None
|
||||
# Start tasks for the new buffer in a new stream
|
||||
with torch.cuda.stream(streams[buffer_idx]):
|
||||
try:
|
||||
# Initialize the packing tensor list and sizes
|
||||
packing_tensor_list[buffer_idx] = []
|
||||
packing_tensor_sizes[buffer_idx] = 0
|
||||
# Pack the tensors
|
||||
while True:
|
||||
# Apply post processing and convert to linearized uint8 tensor
|
||||
tensor = (
|
||||
post_iter_func(next(iterator))
|
||||
.contiguous()
|
||||
.view(torch.uint8)
|
||||
.view(-1)
|
||||
)
|
||||
packing_tensor_list[buffer_idx].append(tensor)
|
||||
packing_tensor_sizes[buffer_idx] += tensor.numel()
|
||||
if packing_tensor_sizes[buffer_idx] > target_packed_tensor_size:
|
||||
break
|
||||
# Pack the tensors and call broadcast collective
|
||||
packed_tensors[buffer_idx] = torch.cat(
|
||||
packing_tensor_list[buffer_idx], dim=0
|
||||
)
|
||||
group.broadcast(packed_tensors[buffer_idx], src=src)
|
||||
# Move to the next buffer
|
||||
buffer_idx = (buffer_idx + 1) % num_buffers
|
||||
except StopIteration:
|
||||
# Do the last broadcast if there are remaining tensors
|
||||
if len(packing_tensor_list[buffer_idx]) > 0:
|
||||
packed_tensors[buffer_idx] = torch.cat(
|
||||
packing_tensor_list[buffer_idx], dim=0
|
||||
)
|
||||
group.broadcast(packed_tensors[buffer_idx], src=src)
|
||||
chunk = pack_tensors(iterator, post_iter_func, buffer_size_bytes)
|
||||
if chunk is None:
|
||||
break
|
||||
# Pack the tensors and call broadcast collective
|
||||
group.broadcast(chunk.packed_tensor, src=src)
|
||||
# Hold reference until this stream is synchronized
|
||||
in_flight[buffer_idx] = chunk
|
||||
# Move to the next buffer
|
||||
buffer_idx = (buffer_idx + 1) % num_buffers
|
||||
|
||||
|
||||
def packed_broadcast_consumer(
|
||||
def packed_nccl_broadcast_consumer(
|
||||
iterator: Iterator[tuple[str, tuple[list[int], torch.dtype]]],
|
||||
group: Any,
|
||||
src: int,
|
||||
@@ -108,37 +201,6 @@ def packed_broadcast_consumer(
|
||||
Both producer and consumer must use the same value.
|
||||
|
||||
"""
|
||||
|
||||
def unpack_tensor(
|
||||
packed_tensor: torch.Tensor,
|
||||
names: list[str],
|
||||
shapes: list[list[int]],
|
||||
dtypes: list[torch.dtype],
|
||||
tensor_sizes: list[int],
|
||||
) -> list[tuple[str, torch.Tensor]]:
|
||||
"""Unpack a single tensor into a list of tensors.
|
||||
|
||||
Args:
|
||||
packed_tensor: The packed torch.uint8 tensor to unpack
|
||||
names: List of tensor names
|
||||
shapes: List of tensor shapes
|
||||
dtypes: List of tensor dtypes
|
||||
tensor_sizes: List of tensor sizes in bytes
|
||||
|
||||
Returns:
|
||||
unpacked List[(name, tensor)]
|
||||
"""
|
||||
unpacked_tensors = packed_tensor.split(tensor_sizes)
|
||||
|
||||
unpacked_list = [
|
||||
(name, tensor.contiguous().view(dtype).view(*shape))
|
||||
for name, shape, dtype, tensor in zip(
|
||||
names, shapes, dtypes, unpacked_tensors
|
||||
)
|
||||
]
|
||||
|
||||
return unpacked_list
|
||||
|
||||
target_packed_tensor_size = buffer_size_bytes
|
||||
|
||||
streams = [torch.cuda.Stream() for _ in range(num_buffers)]
|
||||
@@ -214,3 +276,152 @@ def packed_broadcast_consumer(
|
||||
)
|
||||
)
|
||||
break
|
||||
|
||||
|
||||
# ── IPC packed transfer ────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class PackedIpcChunk:
|
||||
"""Metadata and IPC handle for a single packed chunk."""
|
||||
|
||||
names: list[str]
|
||||
shapes: list[list[int]]
|
||||
dtype_names: list[str]
|
||||
tensor_sizes: list[int]
|
||||
ipc_handle: dict[str, tuple]
|
||||
|
||||
|
||||
def packed_ipc_producer(
|
||||
iterator: Iterator[tuple[str, torch.Tensor]],
|
||||
gpu_uuid: str,
|
||||
post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor],
|
||||
buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
|
||||
) -> Iterator[PackedIpcChunk]:
|
||||
"""Pack tensors into a reusable IPC buffer and yield handles.
|
||||
|
||||
Allocates a single GPU buffer of ``buffer_size_bytes`` and registers
|
||||
it for IPC once via ``reduce_tensor``. Each chunk's packed data is
|
||||
copied into this buffer before yielding, so only one IPC-shared
|
||||
allocation exists for the lifetime of the transfer.
|
||||
|
||||
Callers **must** ensure the consumer has finished reading the buffer
|
||||
(e.g. ``ray.get`` returned) before resuming the generator for the
|
||||
next chunk.
|
||||
|
||||
Args:
|
||||
iterator: Iterator of (name, tensor) pairs.
|
||||
gpu_uuid: Physical GPU UUID string for this rank.
|
||||
post_iter_func: Applied to each (name, tensor) before packing.
|
||||
buffer_size_bytes: Exact capacity of the reusable IPC buffer.
|
||||
Every chunk is guaranteed to fit within this size. A
|
||||
``ValueError`` is raised if any single tensor exceeds it.
|
||||
"""
|
||||
ipc_buffer = torch.empty(buffer_size_bytes, dtype=torch.uint8, device="cuda")
|
||||
_, ipc_args = reduce_tensor(ipc_buffer)
|
||||
|
||||
names: list[str] = []
|
||||
shapes: list[list[int]] = []
|
||||
dtypes: list[torch.dtype] = []
|
||||
tensor_sizes: list[int] = []
|
||||
total_bytes = 0
|
||||
|
||||
for name, orig_tensor in iterator:
|
||||
flat = (
|
||||
post_iter_func((name, orig_tensor)).contiguous().view(torch.uint8).view(-1)
|
||||
)
|
||||
|
||||
if flat.numel() > buffer_size_bytes:
|
||||
raise ValueError(
|
||||
f"Tensor '{name}' has size {flat.numel()} bytes, "
|
||||
f"which exceeds buffer_size_bytes={buffer_size_bytes}. "
|
||||
f"Increase buffer_size_bytes to at least {flat.numel()}."
|
||||
)
|
||||
|
||||
if total_bytes and total_bytes + flat.numel() > buffer_size_bytes:
|
||||
# Drain queued copies so the consumer sees a fully-written buffer.
|
||||
torch.cuda.current_stream().synchronize()
|
||||
yield PackedIpcChunk(
|
||||
names=names,
|
||||
shapes=shapes,
|
||||
dtype_names=[str(d).split(".")[-1] for d in dtypes],
|
||||
tensor_sizes=tensor_sizes,
|
||||
ipc_handle={gpu_uuid: ipc_args},
|
||||
)
|
||||
# Rebind to fresh lists so the yielded chunk's metadata is
|
||||
# not mutated while the consumer is still reading.
|
||||
names, shapes, dtypes, tensor_sizes = [], [], [], []
|
||||
total_bytes = 0
|
||||
|
||||
ipc_buffer[total_bytes : total_bytes + flat.numel()].copy_(flat)
|
||||
names.append(name)
|
||||
shapes.append(list(orig_tensor.shape))
|
||||
dtypes.append(orig_tensor.dtype)
|
||||
tensor_sizes.append(flat.numel())
|
||||
total_bytes += flat.numel()
|
||||
|
||||
if total_bytes:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
yield PackedIpcChunk(
|
||||
names=names,
|
||||
shapes=shapes,
|
||||
dtype_names=[str(d).split(".")[-1] for d in dtypes],
|
||||
tensor_sizes=tensor_sizes,
|
||||
ipc_handle={gpu_uuid: ipc_args},
|
||||
)
|
||||
|
||||
|
||||
def packed_ipc_consumer(
|
||||
ipc_handle: dict[str, tuple],
|
||||
names: list[str],
|
||||
shapes: list[list[int]],
|
||||
dtype_names: list[str],
|
||||
tensor_sizes: list[int],
|
||||
device_index: int,
|
||||
) -> list[tuple[str, torch.Tensor]]:
|
||||
"""Unpack a single packed IPC chunk into named tensors.
|
||||
|
||||
Reconstructs the packed buffer via rebuild_cuda_tensor, unpacks
|
||||
into individual tensors, and clones each into independent storage
|
||||
before returning.
|
||||
|
||||
The clone is intentional: the producer reuses one IPC buffer across
|
||||
chunks, so any tensor view that aliases the buffer would observe the
|
||||
*next* chunk's bytes as soon as the producer's generator is resumed.
|
||||
Callers that retain references past their own update_weights call
|
||||
(notably vLLM's layerwise reload, which buffers ``bound_args`` for
|
||||
replay in ``_layerwise_process``) would otherwise replay against
|
||||
stale data and silently corrupt multi-chunk weight transfers.
|
||||
|
||||
Args:
|
||||
ipc_handle: Mapping of GPU UUID to rebuild_cuda_tensor args tuple
|
||||
names: Parameter names in the packed buffer
|
||||
shapes: Parameter shapes
|
||||
dtype_names: Parameter dtype name strings (e.g. "float16")
|
||||
tensor_sizes: Size in bytes of each parameter in the packed buffer
|
||||
device_index: Local CUDA device index
|
||||
"""
|
||||
from torch.multiprocessing.reductions import rebuild_cuda_tensor
|
||||
|
||||
props = torch.cuda.get_device_properties(device_index)
|
||||
physical_gpu_id = str(props.uuid)
|
||||
|
||||
if physical_gpu_id not in ipc_handle:
|
||||
raise ValueError(
|
||||
f"IPC handle not found for GPU UUID {physical_gpu_id}. "
|
||||
f"Available UUIDs: {list(ipc_handle.keys())}"
|
||||
)
|
||||
|
||||
args = ipc_handle[physical_gpu_id]
|
||||
list_args = list(args)
|
||||
list_args[6] = device_index
|
||||
packed = rebuild_cuda_tensor(*list_args)
|
||||
|
||||
content_size = sum(tensor_sizes)
|
||||
packed = packed[:content_size]
|
||||
|
||||
dtypes = [getattr(torch, dn) for dn in dtype_names]
|
||||
return [
|
||||
(name, t.clone())
|
||||
for name, t in unpack_tensor(packed, names, shapes, dtypes, tensor_sizes)
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user