[Feat][RL] IPC weight sync optimizations: multigpu support and chunked packed tensors (#37476)

Signed-off-by: ahao-anyscale <ahao@anyscale.com>
Signed-off-by: hao-aaron <ahao@anyscale.com>
Aaron Hao
2026-05-15 07:53:06 -07:00
committed by GitHub
parent 0fe7550254
commit e0a45f1455
10 changed files with 1591 additions and 455 deletions
+40 -9
@@ -1,21 +1,37 @@
# IPC Engine
The IPC weight transfer engine uses **CUDA IPC** (Inter-Process Communication) handles to share GPU memory directly between the trainer and inference workers on the **same node and same GPU**. This avoids any data copying, making it an efficient option when colocating training and inference.
The IPC weight transfer engine uses **CUDA IPC** (Inter-Process Communication) handles to share GPU memory directly between the trainer and inference workers on the **same GPU**. This avoids any data copying, making it the most efficient option when colocating training and inference. Multi-GPU setups are supported: the weights are all-gathered onto each GPU, and the inference process colocated on that GPU reads them.
## When to Use IPC
- Training and inference on the **same GPU** (colocated)
- You want to minimize memory overhead by sharing tensors in-place
- Training and inference share the **same GPU(s)** (colocated)
## How It Works
1. The trainer creates CUDA tensors for each weight and generates IPC handles using `torch.multiprocessing.reductions.reduce_tensor`.
2. IPC handles are sent to the inference engine via **Ray.remote()** or **HTTP POST**.
3. The inference worker reconstructs the tensors from the handles, reading directly from the trainer's GPU memory.
1. The trainer creates CUDA tensors for each weight and generates IPC handles using `torch.multiprocessing.reductions.reduce_tensor`. In multi-GPU setups (e.g. FSDP), each trainer rank must all-gather the full tensor for each layer onto its own GPU before generating the IPC handle.
2. IPC handles for each GPU are sent to the inference engine via **Ray**, **HTTP**, or a **custom callable**. Each rank only reads the handle corresponding to its own GPU.
3. The inference worker reconstructs the tensors from the handles using `rebuild_cuda_tensor`, reading directly from the trainer's GPU memory.
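The per-GPU lookup in step 2 can be pictured with a small sketch. It assumes, purely for illustration, that handles are collected into a dictionary keyed by GPU UUID; `merge_handles`, `select_handle`, and the placeholder handle strings are hypothetical and are not vLLM APIs.

```python
# Hypothetical sketch: each trainer rank publishes one handle per GPU,
# and each inference worker picks the entry for the GPU it runs on.
# The strings below stand in for real CUDA IPC handles.


def merge_handles(per_rank_handles: list[dict[str, str]]) -> dict[str, str]:
    """Combine {gpu_uuid: handle} dicts gathered from every trainer rank."""
    merged: dict[str, str] = {}
    for handles in per_rank_handles:
        merged.update(handles)
    return merged


def select_handle(merged: dict[str, str], my_gpu_uuid: str) -> str:
    """Return the handle that lives on this worker's own GPU."""
    if my_gpu_uuid not in merged:
        raise KeyError(f"no IPC handle published for GPU {my_gpu_uuid}")
    return merged[my_gpu_uuid]


gathered = [{"GPU-aaaa": "handle-from-rank-0"}, {"GPU-bbbb": "handle-from-rank-1"}]
merged = merge_handles(gathered)
local_handle = select_handle(merged, "GPU-bbbb")
```

Failing loudly on a missing UUID is a design choice of this sketch: a worker that silently falls back to another GPU's handle would be reading memory that is not on its own device.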
!!! warning
IPC handles involve sending serialized Python objects. When using HTTP transport, you must set `VLLM_ALLOW_INSECURE_SERIALIZATION=1` on both the server and client. This is because IPC handles are pickled and base64-encoded for HTTP transmission.
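To make the warning concrete, the following sketch shows the general pickle-then-base64 pattern using only the standard library. The payload dictionary and the `update_info` field name are placeholders chosen for this illustration, not the actual request schema of the vLLM server.

```python
import base64
import json
import pickle

# Placeholder payload standing in for serialized IPC handle data.
payload = {"names": ["layer.weight"], "handle": ("rebuild_fn", (1, 2, 3))}

# Sender: pickle -> base64 -> JSON-safe string.
encoded = base64.b64encode(pickle.dumps(payload)).decode("ascii")
body = json.dumps({"update_info": encoded})

# Receiver: reverse the steps. pickle.loads can execute arbitrary code
# from an untrusted sender, which is why this path needs an explicit opt-in.
decoded = pickle.loads(base64.b64decode(json.loads(body)["update_info"]))
```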
## Packed (Chunked) Transfer
By default, all weights are sent in a single API call. For large models, this requires the full model to reside in GPU memory on both sides simultaneously. Setting `packed=True` enables **chunked transfer** with bounded GPU memory:
- Weights are concatenated into fixed-size packed buffers (controlled by `packed_buffer_size_bytes`).
- Each chunk is sent as a separate `update_weights` call within a single `start_weight_update` / `finish_weight_update` bracket, so the layerwise reload pass is initialized once at the start and finalized once at the end regardless of chunk count.
- After each chunk is consumed, the GPU memory for that chunk can be reclaimed.
```python
trainer_args = IPCTrainerSendWeightsArgs(
send_mode="ray",
llm_handle=llm_actor_handle,
packed=True,
packed_buffer_size_bytes=256 * 1024 * 1024, # 256 MB chunks
)
```
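The bounded-memory idea can be sketched without a GPU. The code below greedily groups byte blobs into chunks of at most `buffer_size_bytes`, placing an oversized item in a chunk of its own. `pack_chunks` is a hypothetical helper written for this illustration, and the exact boundary rules of vLLM's packing code may differ.

```python
from collections.abc import Iterable, Iterator


def pack_chunks(
    items: Iterable[tuple[str, bytes]], buffer_size_bytes: int
) -> Iterator[tuple[list[str], bytes]]:
    """Yield (names, packed_bytes) chunks that fit the buffer when possible."""
    names: list[str] = []
    buffer = bytearray()
    for name, blob in items:
        # Flush the current chunk if adding this blob would overflow it.
        if names and len(buffer) + len(blob) > buffer_size_bytes:
            yield names, bytes(buffer)
            names, buffer = [], bytearray()
        names.append(name)
        buffer.extend(blob)
    if names:
        yield names, bytes(buffer)


weights = [("a", b"\x00" * 40), ("b", b"\x00" * 40), ("c", b"\x00" * 200)]
chunks = list(pack_chunks(weights, buffer_size_bytes=100))
```

Here `a` and `b` share one chunk, while `c` exceeds the limit and is emitted alone, so only one chunk's worth of data needs to be staged at a time.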
## Initialization
The IPC backend requires no initialization on either side. The `init_transfer_engine` call is a no-op for IPC.
@@ -35,7 +51,7 @@ from vllm.distributed.weight_transfer.ipc_engine import (
)
trainer_args = IPCTrainerSendWeightsArgs(
mode="ray",
send_mode="ray",
llm_handle=llm_actor_handle,
)
# start
@@ -57,7 +73,7 @@ Used when vLLM is running as an HTTP server:
```python
trainer_args = IPCTrainerSendWeightsArgs(
mode="http",
send_mode="http",
url="http://localhost:8000",
)
@@ -77,7 +93,22 @@ response = requests.post(url, json={}, timeout=60)
response.raise_for_status()
```
In HTTP mode, IPC handles are pickled, base64-encoded, and sent as JSON to the `/update_weights` endpoint. As with Ray mode, you must call `start_weight_update` before and `finish_weight_update` after.
In HTTP mode, IPC handles are pickled, base64-encoded, and sent as JSON to the `/update_weights` endpoint. Because the worker deserializes the payload via `pickle.loads`, the vLLM server must be started with `VLLM_ALLOW_INSECURE_SERIALIZATION=1`.
```python
def my_custom_sender(update_info: IPCWeightTransferUpdateInfo):
# Custom logic to deliver update_info to vLLM
...
trainer_args = IPCTrainerSendWeightsArgs(
send_mode=my_custom_sender,
)
IPCWeightTransferEngine.trainer_send_weights(
iterator=model.named_parameters(),
trainer_args=trainer_args,
)
```
See [`IPCTrainerSendWeightsArgs`](https://github.com/vllm-project/vllm/blob/main/vllm/distributed/weight_transfer/ipc_engine.py) for the full list of configurable fields.
+1 -1
@@ -173,7 +173,7 @@ def main():
start_weight_update(BASE_URL, is_checkpoint_format=False)
print("Broadcasting weights via CUDA IPC (HTTP)...")
trainer_args = IPCTrainerSendWeightsArgs(mode="http", url=BASE_URL)
trainer_args = IPCTrainerSendWeightsArgs(send_mode="http", url=BASE_URL)
IPCWeightTransferEngine.trainer_send_weights(
iterator=train_model.named_parameters(),
trainer_args=trainer_args,
+9 -5
@@ -70,10 +70,14 @@ class TrainModel:
self.llm_handle.init_weight_transfer_engine.remote(dict(init_info=dict()))
)
def broadcast_weights(self, llm_handle: ray.actor.ActorHandle):
def broadcast_weights(
self, llm_handle: ray.actor.ActorHandle, packed: bool = False
):
"""Broadcast weights to the inference engine using IPC."""
self.llm_handle = llm_handle
trainer_args = IPCTrainerSendWeightsArgs(mode="ray", llm_handle=llm_handle)
trainer_args = IPCTrainerSendWeightsArgs(
send_mode="ray", llm_handle=llm_handle, packed=packed
)
IPCWeightTransferEngine.trainer_send_weights(
iterator=self.train_model.named_parameters(),
trainer_args=trainer_args,
@@ -141,10 +145,10 @@ ray.get(llm.finish_weight_update.remote())
ray.get(llm.wake_up.remote(tags=["scheduling"]))
# Generate text with the updated model.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
outputs_packed = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated:
print("Results after packed/chunked IPC weight sync:")
for output in outputs_packed:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+425
@@ -0,0 +1,425 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
RLHF with FSDP2 training and vLLM expert-parallel inference using **CUDA IPC**
weight transfer and **packed** tensors.
Layout (4 GPUs, TP=1, DP=4, EP):
* One Ray placement group per GPU.
* Each PG holds one FSDP training worker and one vLLM ``LLM`` instance
(sync API) using fractional GPUs so both fit on the same device.
* The 4 ``LLM`` instances form a DP group via env-var-based SPMD
coordination (``VLLM_DP_RANK``, ``VLLM_DP_SIZE``, etc.), the same
mechanism used by ``examples/offline_inference/data_parallel.py``.
* A ``DataParallelInferenceEngine`` actor spawns all 4 LLM actors,
waits for initialization, and orchestrates generation / weight-sync.
Uses the built-in ``ray`` send_mode: each FSDP worker calls
``trainer_send_weights`` targeting its colocated LLM actor.
This example was run on 4xH100.
"""
from __future__ import annotations
import os
from dataclasses import asdict
import ray
import torch
import torch.distributed as dist
from huggingface_hub import snapshot_download
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from torch.distributed._tensor import DTensor
from torch.distributed.fsdp import fully_shard
from transformers import AutoModelForCausalLM
from vllm import LLM, SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.ipc_engine import (
IPCTrainerSendWeightsArgs,
IPCWeightTransferEngine,
IPCWeightTransferInitInfo,
)
from vllm.utils.network_utils import get_ip, get_open_port
TRAIN_GPU_FRACTION = float(os.environ.get("RLHF_IPC_TRAIN_GPU_FRACTION", "0.42"))
VLLM_GPU_FRACTION = float(os.environ.get("RLHF_IPC_VLLM_GPU_FRACTION", "0.42"))
MODEL_NAME = "Qwen/Qwen3-30B-A3B"
FSDP_WORLD_SIZE = 4
INFERENCE_TP_SIZE = 1
INFERENCE_DP_SIZE = 4
class MyLLM(LLM):
"""LLM subclass that configures DP env vars for SPMD coordination."""
def __init__(
self,
*args,
dp_rank: int = 0,
dp_size: int = 1,
dp_master_ip: str = "127.0.0.1",
dp_master_port: int = 0,
**kwargs,
):
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(VLLM_GPU_FRACTION)
os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0"
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"
os.environ["VLLM_DP_RANK"] = str(dp_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size)
os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
super().__init__(*args, **kwargs)
def ready(self):
return True
@ray.remote(num_cpus=0, num_gpus=TRAIN_GPU_FRACTION)
class FSDPTrainWorker:
"""One FSDP2 worker per GPU; colocated with vLLM DP rank via placement group."""
def __init__(
self,
model_name: str,
rank: int,
fsdp_world_size: int,
fsdp_master_addr: str,
fsdp_master_port: int,
):
self.rank = rank
os.environ["MASTER_ADDR"] = fsdp_master_addr
os.environ["MASTER_PORT"] = str(fsdp_master_port)
dist.init_process_group(backend="nccl", rank=rank, world_size=fsdp_world_size)
torch.accelerator.set_device_index(0)
model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=torch.bfloat16
)
self.weight_names = [n for n, _ in model.named_parameters()]
self.weight_dtype_names = [
str(p.dtype).split(".")[-1] for _, p in model.named_parameters()
]
self.weight_shapes = [list(p.shape) for _, p in model.named_parameters()]
for layer in model.model.layers:
fully_shard(layer)
fully_shard(model)
self.model = model
def get_rank(self):
return self.rank
def get_weight_metadata(self):
return self.weight_names, self.weight_dtype_names, self.weight_shapes
def gather_and_broadcast_weights_ipc(self, llm_handle, packed: bool = True):
"""All-gather full params; all ranks create IPC handles, rank 0 sends.
All ranks must call trainer_send_weights so they participate in the
all_gather_object collective inside _all_gather_and_merge_handles.
Only rank 0 actually sends the payload to vLLM (gated by _is_rank_zero).
"""
def _full_param_iter():
# HF's Qwen3MoeExperts (and other recent HF MoE impls) packs
# all experts into two fused 3-D tensors per layer:
# experts.gate_up_proj shape (E, 2*I, H)
# experts.down_proj shape (E, H, I)
# vLLM's Qwen3MoE load_weights still expects the older
# per-expert HF layout (experts.<i>.gate_proj.weight,
# experts.<i>.up_proj.weight, experts.<i>.down_proj.weight),
# so we un-fuse on the fly. Split order matches HF's forward:
# gate, up = linear(x, gate_up_proj[i]).chunk(2, dim=-1)
# → rows [:I] of gate_up_proj[i] are gate, rows [I:] are up.
params = self.model.state_dict()
for name in list(params.keys()):
param = params.pop(name)
if isinstance(param, DTensor):
tensor = param.full_tensor().detach().contiguous()
else:
tensor = param.detach().contiguous()
del param
if name.endswith(".experts.gate_up_proj") and tensor.dim() == 3:
prefix = name[: -len(".gate_up_proj")]
num_experts, two_inter, _ = tensor.shape
inter = two_inter // 2
for i in range(num_experts):
expert = tensor[i]
yield (
f"{prefix}.{i}.gate_proj.weight",
expert[:inter].contiguous(),
)
yield (
f"{prefix}.{i}.up_proj.weight",
expert[inter:].contiguous(),
)
del tensor
elif name.endswith(".experts.down_proj") and tensor.dim() == 3:
prefix = name[: -len(".down_proj")]
num_experts = tensor.shape[0]
for i in range(num_experts):
yield (
f"{prefix}.{i}.down_proj.weight",
tensor[i].contiguous(),
)
del tensor
else:
yield name, tensor
trainer_args = IPCTrainerSendWeightsArgs(
send_mode="ray",
llm_handle=llm_handle,
packed=packed,
packed_buffer_size_bytes=1024 * 1024 * 1024, # 1 GB
)
IPCWeightTransferEngine.trainer_send_weights(
iterator=_full_param_iter(),
trainer_args=trainer_args,
)
@ray.remote(num_cpus=1)
class DataParallelInferenceEngine:
"""Manages a pool of DP-sharded vLLM LLM actors.
Spawns one MyLLM actor per placement group, waits for all engines to
finish initializing, and exposes generation / weight-sync helpers.
"""
def __init__(
self,
model: str,
pgs: list,
dp_master_ip: str,
dp_master_port: int,
):
dp_size = len(pgs)
self.llm_actors = []
for r in range(dp_size):
sched = PlacementGroupSchedulingStrategy(
placement_group=pgs[r],
placement_group_capture_child_tasks=True,
)
actor = (
ray.remote(num_cpus=0, num_gpus=0)(MyLLM)
.options(scheduling_strategy=sched)
.remote(
model=model,
enforce_eager=True,
tensor_parallel_size=INFERENCE_TP_SIZE,
distributed_executor_backend="ray",
enable_expert_parallel=True,
gpu_memory_utilization=0.35,
weight_transfer_config=WeightTransferConfig(backend="ipc"),
enable_sleep_mode=True,
load_format="dummy",
dp_rank=r,
dp_size=dp_size,
dp_master_ip=dp_master_ip,
dp_master_port=dp_master_port,
)
)
self.llm_actors.append(actor)
ray.get([actor.ready.remote() for actor in self.llm_actors])
def get_llm_actors(self):
return self.llm_actors
def generate(self, prompts: list[str], sampling_params):
"""Distribute prompts round-robin across DP ranks and collect results."""
dp_size = len(self.llm_actors)
per_rank: list[list[str]] = [[] for _ in range(dp_size)]
indices: list[list[int]] = [[] for _ in range(dp_size)]
for i, prompt in enumerate(prompts):
rank = i % dp_size
per_rank[rank].append(prompt)
indices[rank].append(i)
refs = [
actor.generate.remote(per_rank[r], sampling_params)
for r, actor in enumerate(self.llm_actors)
if per_rank[r]
]
all_outputs = ray.get(refs)
ordered = [None] * len(prompts)
rank_idx = 0
for r in range(dp_size):
if per_rank[r]:
for local_i, orig_i in enumerate(indices[r]):
ordered[orig_i] = all_outputs[rank_idx][local_i]
rank_idx += 1
return ordered
def init_weight_transfer(self):
ray.get(
[
actor.init_weight_transfer_engine.remote(
dict(init_info=asdict(IPCWeightTransferInitInfo()))
)
for actor in self.llm_actors
]
)
def start_weight_update(self, is_checkpoint_format: bool = True):
ray.get(
[
actor.start_weight_update.remote(
is_checkpoint_format=is_checkpoint_format
)
for actor in self.llm_actors
]
)
def finish_weight_update(self):
ray.get([actor.finish_weight_update.remote() for actor in self.llm_actors])
def sleep(self, level: int = 0):
ray.get([actor.sleep.remote(level=level) for actor in self.llm_actors])
def wake_up(self, tags: list[str] | None = None):
ray.get([actor.wake_up.remote(tags=tags) for actor in self.llm_actors])
def main():
ray.init(
runtime_env={
"env_vars": {
"VLLM_ALLOW_INSECURE_SERIALIZATION": "1",
}
}
)
assert TRAIN_GPU_FRACTION + VLLM_GPU_FRACTION <= 1.0, (
"Train + vLLM GPU fractions must sum to at most 1.0 per bundle."
)
local_model_path = snapshot_download(MODEL_NAME)
print(f"[init] Model downloaded to {local_model_path}")
fsdp_master_addr = get_ip()
fsdp_master_port = get_open_port()
dp_master_port = get_open_port()
dp_master_ip = get_ip()
# Create one placement group per DP rank (one GPU each).
pgs = []
for _ in range(INFERENCE_DP_SIZE):
pg = placement_group([{"GPU": 1, "CPU": 1}])
pgs.append(pg)
ray.get([pg.ready() for pg in pgs])
print(f"[init] {len(pgs)} placement groups ready.")
# Launch FSDP training workers, one per PG.
scheduling = [
PlacementGroupSchedulingStrategy(
placement_group=pgs[r],
placement_group_capture_child_tasks=True,
)
for r in range(FSDP_WORLD_SIZE)
]
fsdp_workers = [
FSDPTrainWorker.options(scheduling_strategy=scheduling[r]).remote(
local_model_path,
r,
FSDP_WORLD_SIZE,
fsdp_master_addr,
fsdp_master_port,
)
for r in range(FSDP_WORLD_SIZE)
]
ray.get([w.get_rank.remote() for w in fsdp_workers])
print(f"[init] {FSDP_WORLD_SIZE} FSDP workers ready.")
# Launch DP inference engine (spawns and initializes all LLM actors).
inference_engine = DataParallelInferenceEngine.remote(
model=local_model_path,
pgs=pgs,
dp_master_ip=dp_master_ip,
dp_master_port=dp_master_port,
)
llm_actors = ray.get(inference_engine.get_llm_actors.remote())
print(f"[init] {INFERENCE_DP_SIZE} LLM actors ready.")
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
print("[generate] Generating with dummy weights...")
outputs = ray.get(inference_engine.generate.remote(prompts, sampling_params))
print("-" * 60)
print("BEFORE weight sync (dummy weights):")
print("-" * 60)
for output in outputs:
print(f"Prompt: {output.prompt!r}")
print(f"Generated: {output.outputs[0].text!r}")
print("-" * 60)
# --- Weight transfer ---
print("[transfer] Initializing IPC weight transfer...")
ray.get(inference_engine.init_weight_transfer.remote())
# Two-phase sleep/wake pattern:
# 1. sleep(level=1) — offload weights to CPU, discard KV cache
# 2. wake_up(tags=["weights"]) — bring weights back to GPU (KV cache still free)
# 3. IPC weight transfer — overwrite weights, plenty of room without KV cache
# 4. wake_up(tags=["kv_cache"]) — re-allocate KV cache for inference
print("[sync] Sleeping engines (offload weights + free KV cache)...")
ray.get(inference_engine.sleep.remote(level=1))
print("[sync] Waking weights (KV cache stays free)...")
ray.get(inference_engine.wake_up.remote(tags=["weights"]))
print("[sync] Starting weight update...")
ray.get(inference_engine.start_weight_update.remote(is_checkpoint_format=True))
print("[sync] Packed IPC transfer FSDP → vLLM...")
ray.get(
[
w.gather_and_broadcast_weights_ipc.remote(llm_actors, packed=True)
for w in fsdp_workers
]
)
ray.get(inference_engine.finish_weight_update.remote())
print("[sync] Weight transfer complete.")
print("[sync] Waking KV cache + scheduling...")
ray.get(inference_engine.wake_up.remote(tags=["kv_cache", "scheduling"]))
print("[generate] Generating with synced weights...")
outputs_updated = ray.get(
inference_engine.generate.remote(prompts, sampling_params)
)
print("-" * 60)
print("AFTER weight sync (real weights):")
print("-" * 60)
for output in outputs_updated:
print(f"Prompt: {output.prompt!r}")
print(f"Generated: {output.outputs[0].text!r}")
print("-" * 60)
if __name__ == "__main__":
main()
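The round-robin dispatch and order restoration performed by `DataParallelInferenceEngine.generate` in the file above can be isolated in a few lines. This is a simplified sketch with no Ray or vLLM dependency; `split_round_robin` and `restore_order` are hypothetical names, and upper-casing stands in for generation.

```python
def split_round_robin(prompts: list[str], dp_size: int):
    """Assign prompt i to rank i % dp_size, remembering original positions."""
    per_rank: list[list[str]] = [[] for _ in range(dp_size)]
    indices: list[list[int]] = [[] for _ in range(dp_size)]
    for i, prompt in enumerate(prompts):
        per_rank[i % dp_size].append(prompt)
        indices[i % dp_size].append(i)
    return per_rank, indices


def restore_order(per_rank_outputs, indices, total: int):
    """Place each rank's outputs back at the prompts' original positions."""
    ordered = [None] * total
    for outputs, idxs in zip(per_rank_outputs, indices):
        for out, orig_i in zip(outputs, idxs):
            ordered[orig_i] = out
    return ordered


prompts = ["p0", "p1", "p2", "p3", "p4"]
per_rank, indices = split_round_robin(prompts, dp_size=2)
# Stand-in for per-rank generation.
fake_outputs = [[p.upper() for p in rank_prompts] for rank_prompts in per_rank]
ordered = restore_order(fake_outputs, indices, len(prompts))
```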
+478 -146
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for packed tensor broadcasting functionality.
Unit tests for packed_broadcast_producer and packed_broadcast_consumer.
Unit tests for packed_nccl_broadcast_producer and packed_nccl_broadcast_consumer.
These utilities enable efficient batched tensor transfer over NCCL.
"""
@@ -11,8 +11,12 @@ import torch
from vllm.distributed.weight_transfer.nccl_engine import NCCLWeightTransferUpdateInfo
from vllm.distributed.weight_transfer.packed_tensor import (
packed_broadcast_consumer,
packed_broadcast_producer,
pack_tensors,
packed_ipc_consumer,
packed_ipc_producer,
packed_nccl_broadcast_consumer,
packed_nccl_broadcast_producer,
unpack_tensor,
)
@@ -90,91 +94,18 @@ class TestNCCLWeightTransferUpdateInfoPacked:
assert info.packed is True
# --- Unit Tests: packed_broadcast_producer ---
# --- Unit Tests: packed_nccl_broadcast_producer ---
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestPackedBroadcastProducer:
"""Test packed_broadcast_producer function."""
def test_producer_broadcasts_tensors(self):
"""Test that producer broadcasts all tensors."""
params = create_mock_model_params()
params_cuda = [(name, tensor.cuda()) for name, tensor in params]
mock_group = MockCommunicationGroup()
# Use a small target size to force multiple batches
packed_broadcast_producer(
iterator=iter(params_cuda),
group=mock_group,
src=0,
post_iter_func=lambda x: x[1],
buffer_size_bytes=500,
)
# Should have broadcasted some tensors
assert mock_group.broadcast_count > 0
assert len(mock_group.broadcasted_tensors) > 0
def test_producer_single_large_tensor(self):
"""Test with a single tensor larger than target size."""
# Create a large tensor
large_tensor = torch.randn(1000, 1000, dtype=torch.float32).cuda()
params = [("large_weight", large_tensor)]
mock_group = MockCommunicationGroup()
# Small target size to force the tensor to exceed it
packed_broadcast_producer(
iterator=iter(params),
group=mock_group,
src=0,
post_iter_func=lambda x: x[1],
buffer_size_bytes=100,
)
# Should still broadcast the tensor (at least 1 broadcast)
assert mock_group.broadcast_count >= 1
assert len(mock_group.broadcasted_tensors) >= 1
# Verify the total broadcasted size matches the tensor
expected_size = large_tensor.numel() * large_tensor.element_size()
actual_size = sum(t.numel() for t in mock_group.broadcasted_tensors)
assert actual_size == expected_size
def test_producer_multiple_batches(self):
"""Test that tensors are properly batched when exceeding target size."""
# Create many small tensors
params = [
(f"weight_{i}", torch.randn(10, 10, dtype=torch.float32).cuda())
for i in range(20)
]
mock_group = MockCommunicationGroup()
# Small target size to force multiple batches
packed_broadcast_producer(
iterator=iter(params),
group=mock_group,
src=0,
post_iter_func=lambda x: x[1],
buffer_size_bytes=2000,
)
# Should have multiple broadcasts
assert mock_group.broadcast_count > 1
# Total size should match sum of all tensors
expected_total = sum(t.numel() * t.element_size() for _, t in params)
actual_total = sum(t.numel() for t in mock_group.broadcasted_tensors)
assert actual_total == expected_total
"""Test packed_nccl_broadcast_producer function."""
def test_producer_empty_iterator(self):
"""Test producer handles empty iterator gracefully."""
mock_group = MockCommunicationGroup()
packed_broadcast_producer(
packed_nccl_broadcast_producer(
iterator=iter([]),
group=mock_group,
src=0,
@@ -186,64 +117,6 @@ class TestPackedBroadcastProducer:
assert mock_group.broadcast_count == 0
# --- Unit Tests: packed_broadcast_consumer ---
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestPackedBroadcastConsumer:
"""Test packed_broadcast_consumer function."""
def test_consumer_receives_tensors(self):
"""Test that consumer receives and unpacks tensors."""
params = create_mock_model_params()
params_cuda = [(name, tensor.cuda()) for name, tensor in params]
buffer_size = 2000
# First, run producer to get the broadcasted tensors
producer_group = MockCommunicationGroup()
packed_broadcast_producer(
iterator=iter(params_cuda),
group=producer_group,
src=0,
post_iter_func=lambda x: x[1],
buffer_size_bytes=buffer_size,
)
# Now run consumer with the broadcasted tensors
consumer_group = MockConsumerCommunicationGroup(
producer_group.broadcasted_tensors
)
state_dict_info = create_state_dict_info(params_cuda)
unpacked_tensors = {}
def post_unpack_func(tensor_list):
for name, tensor in tensor_list:
unpacked_tensors[name] = tensor.clone()
packed_broadcast_consumer(
iterator=iter(state_dict_info.items()),
group=consumer_group,
src=0,
post_unpack_func=post_unpack_func,
buffer_size_bytes=buffer_size,
)
# Verify all parameters were unpacked
assert len(unpacked_tensors) == len(params)
# Verify each tensor matches the original
for name, original_tensor in params_cuda:
assert name in unpacked_tensors
unpacked = unpacked_tensors[name]
assert unpacked.shape == original_tensor.shape
assert unpacked.dtype == original_tensor.dtype
assert torch.allclose(unpacked, original_tensor, rtol=1e-5, atol=1e-7)
# --- Integration Tests: Producer-Consumer Roundtrip ---
@@ -260,7 +133,7 @@ class TestPackedBroadcastRoundtrip:
buffer_size = 1000
producer_group = MockCommunicationGroup()
packed_broadcast_producer(
packed_nccl_broadcast_producer(
iterator=iter(params_cuda),
group=producer_group,
src=0,
@@ -279,7 +152,7 @@ class TestPackedBroadcastRoundtrip:
for name, tensor in tensor_list:
unpacked_tensors[name] = tensor.clone()
packed_broadcast_consumer(
packed_nccl_broadcast_consumer(
iterator=iter(state_dict_info.items()),
group=consumer_group,
src=0,
@@ -306,7 +179,7 @@ class TestPackedBroadcastRoundtrip:
buffer_size = 500
producer_group = MockCommunicationGroup()
packed_broadcast_producer(
packed_nccl_broadcast_producer(
iterator=iter(params),
group=producer_group,
src=0,
@@ -325,7 +198,7 @@ class TestPackedBroadcastRoundtrip:
for name, tensor in tensor_list:
unpacked_tensors[name] = tensor.clone()
packed_broadcast_consumer(
packed_nccl_broadcast_consumer(
iterator=iter(state_dict_info.items()),
group=consumer_group,
src=0,
@@ -341,7 +214,7 @@ class TestPackedBroadcastRoundtrip:
assert unpacked.dtype == original_tensor.dtype
assert torch.allclose(unpacked, original_tensor, rtol=1e-4, atol=1e-6)
@pytest.mark.parametrize("target_size", [100, 1000, 10000, 100000])
@pytest.mark.parametrize("target_size", [100, 100000])
def test_roundtrip_different_batch_sizes(self, target_size):
"""Test roundtrip with different target batch sizes."""
params = create_mock_model_params(num_layers=5)
@@ -349,7 +222,7 @@ class TestPackedBroadcastRoundtrip:
producer_group = MockCommunicationGroup()
packed_broadcast_producer(
packed_nccl_broadcast_producer(
iterator=iter(params_cuda),
group=producer_group,
src=0,
@@ -368,7 +241,7 @@ class TestPackedBroadcastRoundtrip:
for name, tensor in tensor_list:
unpacked_tensors[name] = tensor.clone()
packed_broadcast_consumer(
packed_nccl_broadcast_consumer(
iterator=iter(state_dict_info.items()),
group=consumer_group,
src=0,
@@ -407,7 +280,7 @@ class TestPackedBroadcastRoundtrip:
buffer_size = 500
producer_group = MockCommunicationGroup()
packed_broadcast_producer(
packed_nccl_broadcast_producer(
iterator=iter(params),
group=producer_group,
src=0,
@@ -426,7 +299,7 @@ class TestPackedBroadcastRoundtrip:
for name, tensor in tensor_list:
unpacked_tensors[name] = tensor.clone()
packed_broadcast_consumer(
packed_nccl_broadcast_consumer(
iterator=iter(state_dict_info.items()),
group=consumer_group,
src=0,
@@ -441,3 +314,462 @@ class TestPackedBroadcastRoundtrip:
assert unpacked.shape == original_tensor.shape
assert unpacked.dtype == original_tensor.dtype
assert torch.allclose(unpacked, original_tensor, rtol=1e-4, atol=1e-6)
# --- Unit Tests: unpack_tensor ---
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestUnpackTensor:
"""Test the shared unpack_tensor function."""
def test_unpack_produces_independent_copies(self):
"""Verify unpacked tensors don't share memory with packed buffer."""
original = torch.randn(10, dtype=torch.float32).cuda()
packed = original.contiguous().view(torch.uint8).view(-1)
result = unpack_tensor(
packed,
names=["w"],
shapes=[[10]],
dtypes=[torch.float32],
tensor_sizes=[packed.numel()],
)
# Mutate the packed buffer
packed.zero_()
# Unpacked tensor should be unaffected
assert torch.allclose(result[0][1], original)
# --- Unit Tests: pack_tensors ---
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestPackTensors:
"""Test the shared pack_tensors function."""
def test_pack_basic(self):
"""Test packing a few tensors into one buffer."""
params = [
("w1", torch.randn(10, 20, dtype=torch.float32).cuda()),
("w2", torch.randn(5, dtype=torch.float16).cuda()),
]
chunk = pack_tensors(
iterator=iter(params),
post_iter_func=lambda x: x[1],
buffer_size_bytes=10_000_000,
)
assert chunk is not None
assert len(chunk.names) == 2
assert chunk.names == ["w1", "w2"]
assert chunk.shapes == [[10, 20], [5]]
assert chunk.dtypes == [torch.float32, torch.float16]
assert chunk.packed_tensor.dtype == torch.uint8
def test_pack_respects_buffer_limit(self):
"""Test that packing stops when buffer_size_bytes is exceeded."""
params = [
(f"w{i}", torch.randn(100, 100, dtype=torch.float32).cuda())
for i in range(10)
]
chunk = pack_tensors(
iterator=iter(params),
post_iter_func=lambda x: x[1],
buffer_size_bytes=50_000,
)
assert chunk is not None
assert len(chunk.names) < 10
def test_pack_empty_iterator(self):
"""Test that an empty iterator returns None."""
chunk = pack_tensors(
iterator=iter([]),
post_iter_func=lambda x: x[1],
buffer_size_bytes=1000,
)
assert chunk is None
def test_pack_single_tensor_larger_than_buffer_warns(self):
"""Test that a tensor exceeding buffer_size_bytes emits a warning."""
big = torch.randn(1000, 1000, dtype=torch.float32).cuda()
params = [("big", big)]
import warnings
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
chunk = pack_tensors(
iterator=iter(params),
post_iter_func=lambda x: x[1],
buffer_size_bytes=100,
)
assert chunk is not None
assert len(chunk.names) == 1
assert any("exceeds buffer_size_bytes" in str(wi.message) for wi in w)
def test_pack_unpack_roundtrip(self):
"""Test pack then unpack produces identical tensors."""
params = [
("a", torch.randn(8, 16, dtype=torch.float32).cuda()),
("b", torch.randn(4, dtype=torch.float16).cuda()),
("c", torch.randn(3, 5, 7, dtype=torch.bfloat16).cuda()),
]
chunk = pack_tensors(
iterator=iter(params),
post_iter_func=lambda x: x[1],
buffer_size_bytes=10_000_000,
)
assert chunk is not None
result = unpack_tensor(
chunk.packed_tensor,
chunk.names,
chunk.shapes,
chunk.dtypes,
chunk.tensor_sizes,
)
assert len(result) == len(params)
for (orig_name, orig_tensor), (res_name, res_tensor) in zip(params, result):
assert orig_name == res_name
assert res_tensor.shape == orig_tensor.shape
assert res_tensor.dtype == orig_tensor.dtype
assert torch.allclose(res_tensor, orig_tensor, rtol=1e-4, atol=1e-6)
def test_pack_multiple_chunks(self):
"""Test consuming an iterator across multiple pack_tensors calls."""
params = [
(f"w{i}", torch.randn(50, 50, dtype=torch.float32).cuda()) for i in range(6)
]
it = iter(params)
all_names = []
chunks = []
while True:
chunk = pack_tensors(it, lambda x: x[1], buffer_size_bytes=12_000)
if chunk is None:
break
chunks.append(chunk)
all_names.extend(chunk.names)
assert len(chunks) > 1
assert all_names == [f"w{i}" for i in range(6)]
# --- Unit Tests: packed_ipc_producer ---
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestPackedIpcProducer:
"""Test the packed_ipc_producer generator."""
def test_producer_yields_chunks(self):
"""Test that the producer yields PackedIpcChunk objects."""
params = [
(f"w{i}", torch.randn(50, 50, dtype=torch.float32).cuda()) for i in range(6)
]
chunks = list(
packed_ipc_producer(
iterator=iter(params),
gpu_uuid="test-uuid",
post_iter_func=lambda x: x[1],
buffer_size_bytes=12_000,
)
)
assert len(chunks) > 1
def test_producer_ipc_handle_has_uuid(self):
"""Test that each chunk's ipc_handle is keyed by the given UUID."""
params = [("w", torch.randn(10, dtype=torch.float32).cuda())]
chunks = list(
packed_ipc_producer(
iterator=iter(params),
gpu_uuid="my-gpu-uuid",
post_iter_func=lambda x: x[1],
buffer_size_bytes=10_000_000,
)
)
assert "my-gpu-uuid" in chunks[0].ipc_handle
def test_producer_dtype_names_are_strings(self):
"""Test that dtype_names are string representations."""
params = [
("a", torch.randn(10, dtype=torch.float32).cuda()),
("b", torch.randn(10, dtype=torch.float16).cuda()),
]
chunks = list(
packed_ipc_producer(
iterator=iter(params),
gpu_uuid="uuid",
post_iter_func=lambda x: x[1],
buffer_size_bytes=10_000_000,
)
)
assert chunks[0].dtype_names == ["float32", "float16"]
def test_producer_empty_iterator(self):
"""Test producer with empty iterator yields nothing."""
chunks = list(
packed_ipc_producer(
iterator=iter([]),
gpu_uuid="uuid",
post_iter_func=lambda x: x[1],
buffer_size_bytes=1000,
)
)
assert len(chunks) == 0
# --- Integration Tests: IPC Producer-Consumer Roundtrip ---
def _ipc_consumer_worker(cmd_q, ack_q, result_q, done_event, device_index):
"""Worker that consumes chunks streamed one at a time from the parent.
CUDA IPC requires the consumer to be in a separate process from the
producer. The producer reuses a single IPC buffer between chunks, so
the parent must wait for our ack (sent after we copy the chunk to
CPU) before advancing the producer.
"""
try:
torch.accelerator.set_device_index(device_index)
all_results = []
while True:
cd = cmd_q.get()
if cd is None:
break
result = packed_ipc_consumer(
ipc_handle=cd["ipc_handle"],
names=cd["names"],
shapes=cd["shapes"],
dtype_names=cd["dtype_names"],
tensor_sizes=cd["tensor_sizes"],
device_index=device_index,
)
# .cpu() forces a GPU→CPU copy off the shared IPC buffer, so
# the producer is free to overwrite it once we ack.
all_results.extend([(name, tensor.cpu()) for name, tensor in result])
del result
ack_q.put("ack")
result_q.put(("ok", all_results))
except Exception as e:
result_q.put(("error", str(e)))
# Keep the process alive until the parent has finished reading from
# the result queue — torch serializes CPU tensors via fd sharing,
# which requires this process's resource-sharer server to be running.
done_event.wait(timeout=60)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestPackedIpcRoundtrip:
"""Test IPC producer-consumer roundtrip using real CUDA IPC.
These tests spawn a child process for the consumer because
rebuild_cuda_tensor requires a separate process from the one that
called reduce_tensor.
"""
def _get_gpu_uuid(self) -> str:
device_index = torch.cuda.current_device()
props = torch.cuda.get_device_properties(device_index)
return str(props.uuid)
def _run_roundtrip(self, chunk_iter, device_index, timeout=30):
"""Stream chunks through a child consumer one at a time.
``packed_ipc_producer`` reuses a single IPC buffer for every
chunk, so the producer must not be advanced until the consumer
has finished reading the current chunk. We enforce that with an
ack queue: the consumer puts ``"ack"`` after it has copied the
chunk to CPU, and only then do we pull the next chunk from the
generator.
Returns ``(num_chunks, results)``.
"""
import multiprocessing as mp
ctx = mp.get_context("spawn")
cmd_q = ctx.Queue()
ack_q = ctx.Queue()
result_q = ctx.Queue()
done_event = ctx.Event()
proc = ctx.Process(
target=_ipc_consumer_worker,
args=(cmd_q, ack_q, result_q, done_event, device_index),
)
proc.start()
num_chunks = 0
try:
for chunk in chunk_iter:
cmd_q.put(
{
"ipc_handle": chunk.ipc_handle,
"names": chunk.names,
"shapes": chunk.shapes,
"dtype_names": chunk.dtype_names,
"tensor_sizes": chunk.tensor_sizes,
}
)
if ack_q.get(timeout=timeout) != "ack":
raise RuntimeError("Consumer did not ack chunk")
num_chunks += 1
cmd_q.put(None)
status, payload = result_q.get(timeout=timeout)
finally:
done_event.set()
proc.join(timeout=10)
if proc.is_alive():
proc.kill()
if status == "error":
raise RuntimeError(f"Consumer process failed: {payload}")
# Reclaim IPC-shared memory now that the child has released it
torch.cuda.ipc_collect()
return num_chunks, payload
def test_roundtrip_basic(self):
"""Test basic IPC producer -> consumer roundtrip."""
params = [
("w1", torch.randn(10, 20, dtype=torch.float32).cuda()),
("w2", torch.randn(5, dtype=torch.float16).cuda()),
]
gpu_uuid = self._get_gpu_uuid()
device_index = torch.cuda.current_device()
num_chunks, result = self._run_roundtrip(
packed_ipc_producer(
iterator=iter(params),
gpu_uuid=gpu_uuid,
post_iter_func=lambda x: x[1],
buffer_size_bytes=10_000_000,
),
device_index,
)
assert num_chunks == 1
assert len(result) == 2
for (orig_name, orig_tensor), (res_name, res_tensor) in zip(params, result):
assert orig_name == res_name
assert res_tensor.shape == orig_tensor.shape
assert res_tensor.dtype == orig_tensor.dtype
assert torch.allclose(res_tensor, orig_tensor.cpu(), rtol=1e-4, atol=1e-6)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_roundtrip_dtypes(self, dtype):
"""Test IPC roundtrip with different dtypes."""
params = create_mock_model_params(num_layers=2, dtype=dtype)
params_cuda = [(n, t.cuda()) for n, t in params]
gpu_uuid = self._get_gpu_uuid()
device_index = torch.cuda.current_device()
_, result = self._run_roundtrip(
packed_ipc_producer(
iterator=iter(params_cuda),
gpu_uuid=gpu_uuid,
post_iter_func=lambda x: x[1],
buffer_size_bytes=10_000_000,
),
device_index,
)
assert len(result) == len(params_cuda)
for (orig_name, orig_tensor), (res_name, res_tensor) in zip(
params_cuda, result
):
assert orig_name == res_name
assert res_tensor.dtype == dtype
assert torch.allclose(res_tensor, orig_tensor.cpu(), rtol=1e-4, atol=1e-6)
def test_roundtrip_multiple_chunks(self):
"""Test IPC roundtrip across multiple chunks."""
params = [
(f"layer{i}.weight", torch.randn(100, 100, dtype=torch.float32).cuda())
for i in range(8)
]
gpu_uuid = self._get_gpu_uuid()
device_index = torch.cuda.current_device()
num_chunks, result = self._run_roundtrip(
packed_ipc_producer(
iterator=iter(params),
gpu_uuid=gpu_uuid,
post_iter_func=lambda x: x[1],
buffer_size_bytes=50_000,
),
device_index,
)
assert num_chunks > 1
assert len(result) == len(params)
for (orig_name, orig_tensor), (res_name, res_tensor) in zip(params, result):
assert orig_name == res_name
assert torch.allclose(res_tensor, orig_tensor.cpu(), rtol=1e-5, atol=1e-7)
def test_roundtrip_non_contiguous(self):
"""Test IPC roundtrip with non-contiguous tensors."""
params = [
("transposed", torch.randn(20, 10, dtype=torch.float32).cuda().T),
("sliced", torch.randn(40, 30, dtype=torch.float16).cuda()[::2, ::2]),
]
gpu_uuid = self._get_gpu_uuid()
device_index = torch.cuda.current_device()
for _, t in params:
assert not t.is_contiguous()
_, result = self._run_roundtrip(
packed_ipc_producer(
iterator=iter(params),
gpu_uuid=gpu_uuid,
post_iter_func=lambda x: x[1],
buffer_size_bytes=10_000_000,
),
device_index,
)
for (orig_name, orig_tensor), (res_name, res_tensor) in zip(params, result):
assert orig_name == res_name
assert res_tensor.shape == orig_tensor.shape
assert res_tensor.dtype == orig_tensor.dtype
assert torch.allclose(res_tensor, orig_tensor.cpu(), rtol=1e-4, atol=1e-6)
def test_consumer_wrong_uuid_raises(self):
"""Test that consumer raises ValueError for unknown GPU UUID."""
params = [("w", torch.randn(10, dtype=torch.float32).cuda())]
gpu_uuid = self._get_gpu_uuid()
chunks = list(
packed_ipc_producer(
iterator=iter(params),
gpu_uuid=gpu_uuid,
post_iter_func=lambda x: x[1],
buffer_size_bytes=10_000_000,
)
)
c = chunks[0]
fake_handle = {"fake-uuid-12345": c.ipc_handle[gpu_uuid]}
with pytest.raises(ValueError, match="IPC handle not found"):
packed_ipc_consumer(
ipc_handle=fake_handle,
names=c.names,
shapes=c.shapes,
dtype_names=c.dtype_names,
tensor_sizes=c.tensor_sizes,
device_index=torch.cuda.current_device(),
)
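Note on the ack handshake in `_run_roundtrip` and `_ipc_consumer_worker`: because the producer reuses a single buffer for every chunk, it must not advance until the consumer has copied the current chunk out. The sketch below shows only that ordering. It is an illustration under loud assumptions: it uses threads and a plain Python list in place of separate processes and a CUDA IPC buffer, and `consumer` and `stream_chunks` are hypothetical names that do not exist in the test file.

```python
import queue
import threading


def consumer(cmd_q, ack_q, results):
    """Copy each chunk out of the shared buffer, then ack (hypothetical helper)."""
    while True:
        buf = cmd_q.get()
        if buf is None:  # sentinel: the producer has no more chunks
            break
        results.append(list(buf))  # private copy, analogous to tensor.cpu()
        ack_q.put("ack")  # only now may the producer overwrite buf


def stream_chunks(chunks, timeout=5.0):
    """Send chunks through one reused buffer, waiting for an ack per chunk."""
    cmd_q, ack_q, results = queue.Queue(), queue.Queue(), []
    worker = threading.Thread(target=consumer, args=(cmd_q, ack_q, results))
    worker.start()
    shared = []  # stand-in for the single reused IPC buffer (assumption)
    try:
        for chunk in chunks:
            shared[:] = chunk  # overwrite the buffer in place
            cmd_q.put(shared)
            if ack_q.get(timeout=timeout) != "ack":
                raise RuntimeError("consumer did not ack chunk")
    finally:
        cmd_q.put(None)
        worker.join(timeout=timeout)
    return results
```

Dropping the `ack_q.get` call would let the loop overwrite `shared` while the consumer is still reading it, which is the race the real test avoids by pulling the next chunk from the generator only after the ack arrives.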
+54 -75
@@ -389,7 +389,7 @@ class TestIPCWeightTransferUpdateInfoValidation:
# Create a dummy tensor and IPC handle
dummy_tensor = torch.ones(10, 10, device="cuda:0")
ipc_handle = reduce_tensor(dummy_tensor)
_, ipc_handle = reduce_tensor(dummy_tensor)
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
ipc_handles = [{gpu_uuid: ipc_handle}]
@@ -410,7 +410,7 @@ class TestIPCWeightTransferUpdateInfoValidation:
pytest.skip("Need at least 1 GPU for this test")
dummy_tensor = torch.ones(10, 10, device="cuda:0")
ipc_handle = reduce_tensor(dummy_tensor)
_, ipc_handle = reduce_tensor(dummy_tensor)
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
ipc_handles = [{gpu_uuid: ipc_handle}, {gpu_uuid: ipc_handle}]
@@ -428,7 +428,7 @@ class TestIPCWeightTransferUpdateInfoValidation:
pytest.skip("Need at least 1 GPU for this test")
dummy_tensor = torch.ones(10, 10, device="cuda:0")
ipc_handle = reduce_tensor(dummy_tensor)
_, ipc_handle = reduce_tensor(dummy_tensor)
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
ipc_handles = [{gpu_uuid: ipc_handle}, {gpu_uuid: ipc_handle}]
@@ -446,7 +446,7 @@ class TestIPCWeightTransferUpdateInfoValidation:
pytest.skip("Need at least 1 GPU for this test")
dummy_tensor = torch.ones(10, 10, device="cuda:0")
ipc_handle = reduce_tensor(dummy_tensor)
_, ipc_handle = reduce_tensor(dummy_tensor)
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
ipc_handles = [{gpu_uuid: ipc_handle}] # Only one handle
@@ -458,65 +458,9 @@ class TestIPCWeightTransferUpdateInfoValidation:
ipc_handles=ipc_handles,
)
def test_valid_update_info_from_pickled(self, monkeypatch):
"""Test creating IPCWeightTransferUpdateInfo from pickled handles."""
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
dummy_tensor = torch.ones(10, 10, device="cuda:0")
ipc_handle = reduce_tensor(dummy_tensor)
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
ipc_handles = [{gpu_uuid: ipc_handle}]
pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")
info = IPCWeightTransferUpdateInfo(
names=["layer.weight"],
dtype_names=["float32"],
shapes=[[10, 10]],
ipc_handles_pickled=pickled,
)
assert info.ipc_handles == ipc_handles
assert info.ipc_handles_pickled is None
def test_pickled_requires_insecure_serialization_flag(self, monkeypatch):
"""Test that pickled handles are rejected unless env flag is enabled."""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0")
with pytest.raises(ValueError, match="VLLM_ALLOW_INSECURE_SERIALIZATION=1"):
IPCWeightTransferUpdateInfo(
names=[],
dtype_names=[],
shapes=[],
ipc_handles_pickled=base64.b64encode(pickle.dumps([])).decode("utf-8"),
)
def test_both_handles_and_pickled_raises(self):
"""Test that providing both ipc_handles and ipc_handles_pickled raises."""
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
dummy_tensor = torch.ones(10, 10, device="cuda:0")
ipc_handle = reduce_tensor(dummy_tensor)
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
ipc_handles = [{gpu_uuid: ipc_handle}]
pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")
with pytest.raises(ValueError, match="Cannot specify both"):
IPCWeightTransferUpdateInfo(
names=["layer.weight"],
dtype_names=["float32"],
shapes=[[10, 10]],
ipc_handles=ipc_handles,
ipc_handles_pickled=pickled,
)
def test_neither_handles_nor_pickled_raises(self):
"""Test that providing neither ipc_handles nor ipc_handles_pickled raises."""
with pytest.raises(ValueError, match="must be provided"):
def test_missing_ipc_handles_raises(self):
"""Test that omitting ipc_handles raises TypeError."""
with pytest.raises(TypeError):
IPCWeightTransferUpdateInfo(
names=["layer.weight"],
dtype_names=["float32"],
@@ -552,10 +496,10 @@ class TestIPCEngineParsing:
# Create dummy IPC handles
dummy_tensor1 = torch.ones(100, 100, device="cuda:0")
dummy_tensor2 = torch.ones(50, device="cuda:0")
ipc_handle1 = reduce_tensor(dummy_tensor1)
ipc_handle2 = reduce_tensor(dummy_tensor2)
_, ipc_args1 = reduce_tensor(dummy_tensor1)
_, ipc_args2 = reduce_tensor(dummy_tensor2)
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
ipc_handles = [{gpu_uuid: ipc_handle1}, {gpu_uuid: ipc_handle2}]
ipc_handles = [{gpu_uuid: ipc_args1}, {gpu_uuid: ipc_args2}]
update_info = engine.parse_update_info(
{
@@ -585,10 +529,10 @@ class TestIPCEngineParsing:
dummy_tensor1 = torch.ones(100, 100, device="cuda:0")
dummy_tensor2 = torch.ones(50, device="cuda:0")
ipc_handle1 = reduce_tensor(dummy_tensor1)
ipc_handle2 = reduce_tensor(dummy_tensor2)
_, ipc_args1 = reduce_tensor(dummy_tensor1)
_, ipc_args2 = reduce_tensor(dummy_tensor2)
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
ipc_handles = [{gpu_uuid: ipc_handle1}, {gpu_uuid: ipc_handle2}]
ipc_handles = [{gpu_uuid: ipc_args1}, {gpu_uuid: ipc_args2}]
pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")
@@ -604,10 +548,36 @@ class TestIPCEngineParsing:
assert isinstance(update_info, IPCWeightTransferUpdateInfo)
assert update_info.names == ["w1", "w2"]
assert len(update_info.ipc_handles) == 2
assert update_info.ipc_handles_pickled is None
assert gpu_uuid in update_info.ipc_handles[0]
assert gpu_uuid in update_info.ipc_handles[1]
def test_parse_update_info_both_handles_and_pickled_raises(self):
"""Test that providing both ipc_handles and ipc_handles_pickled raises."""
if torch.accelerator.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
config = WeightTransferConfig(backend="ipc")
parallel_config = create_mock_parallel_config()
engine = IPCWeightTransferEngine(config, parallel_config)
dummy_tensor = torch.ones(10, 10, device="cuda:0")
_, ipc_handle = reduce_tensor(dummy_tensor)
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
ipc_handles = [{gpu_uuid: ipc_handle}]
pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")
with pytest.raises(ValueError, match="Cannot specify both"):
engine.parse_update_info(
{
"names": ["layer.weight"],
"dtype_names": ["float32"],
"shapes": [[10, 10]],
"ipc_handles": ipc_handles,
"ipc_handles_pickled": pickled,
}
)
# --- Integration Test: IPC Weight Transfer Between Ray Tasks ---
@@ -629,13 +599,15 @@ class TrainerActor:
self.tensor.fill_(42.0) # Fill with 42 to verify correct transfer
# Create IPC handle (tensor must stay alive for IPC to work)
ipc_handle = reduce_tensor(self.tensor)
# reduce_tensor returns (rebuild_func, args); we only send args
# since the receiver imports rebuild_cuda_tensor directly.
_, ipc_args = reduce_tensor(self.tensor)
gpu_uuid = get_physical_gpu_id(0)
torch.accelerator.synchronize()
self.ipc_handle_dict = {
"ipc_handle": ipc_handle,
"ipc_handle": ipc_args,
"gpu_uuid": gpu_uuid,
"shape": tensor_shape,
"dtype": tensor_dtype,
@@ -652,6 +624,12 @@ def inference_receive_ipc_tensor(
mode: str = "ray",
) -> dict:
"""Inference task that receives tensor via IPCWeightTransferEngine."""
import os
# Worker-side: ipc_handles_pickled is deserialized via pickle.
if mode == "http":
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"
from unittest.mock import MagicMock
import torch
@@ -684,7 +662,6 @@ def inference_receive_ipc_tensor(
# Clone tensor to keep it after engine cleans up
received_tensors.append((name, tensor.clone()))
# Build update dict and go through parse_update_info (exercises __post_init__)
ipc_handles = [{ipc_handle_dict["gpu_uuid"]: ipc_handle_dict["ipc_handle"]}]
if mode == "ray":
@@ -695,6 +672,7 @@ def inference_receive_ipc_tensor(
"ipc_handles": ipc_handles,
}
elif mode == "http":
# Simulate HTTP transport: pickle + base64 encode handles
pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")
update_dict = {
"names": ["test.weight"],
@@ -743,7 +721,8 @@ def test_ipc_weight_transfer_between_processes(mode: str):
Parametrized over transport modes:
- 'ray': ipc_handles passed directly.
- 'http': ipc_handles pickled + base64-encoded, unpickled via __post_init__.
- 'http': ipc_handles pickled + base64-encoded, deserialized in
parse_update_info before constructing the dataclass.
IPC requires same-GPU access, so we use a placement group to co-locate
the trainer actor and inference task on the same GPU.
@@ -801,7 +780,7 @@ def test_ipc_receive_weights_missing_gpu_uuid_raises():
# Create IPC handle with wrong GPU UUID
dummy_tensor = torch.ones(10, 10, device="cuda:0")
ipc_handle = reduce_tensor(dummy_tensor)
_, ipc_handle = reduce_tensor(dummy_tensor)
wrong_uuid = "wrong-uuid-12345"
ipc_handles = [{wrong_uuid: ipc_handle}]
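Note on the HTTP path exercised by these tests: handles are pickled and base64-encoded by the sender, then decoded before the update-info dataclass is built, and decoding is gated by an opt-in flag. The sketch below mirrors that flow using only the standard library. It is an approximation, not the engine's code: it uses stdlib `base64` where the source imports `pybase64`, the `allow_insecure` parameter stands in for `envs.VLLM_ALLOW_INSECURE_SERIALIZATION`, the handle tuple is a placeholder rather than real `reduce_tensor` args, and `encode_handles` and `decode_update` are hypothetical names.

```python
import base64
import pickle


def encode_handles(ipc_handles):
    """Serialize handle dicts into a JSON-safe string (sender side)."""
    return base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")


def decode_update(update_dict, allow_insecure):
    """Return a copy of update_dict with pickled handles decoded (receiver side)."""
    update = dict(update_dict)
    if "ipc_handles_pickled" in update:
        if "ipc_handles" in update:
            raise ValueError(
                "Cannot specify both `ipc_handles` and `ipc_handles_pickled`"
            )
        if not allow_insecure:
            # pickle.loads can execute arbitrary code, so require an opt-in.
            raise ValueError("Refusing to deserialize `ipc_handles_pickled`")
        pickled = update.pop("ipc_handles_pickled")
        update["ipc_handles"] = pickle.loads(base64.b64decode(pickled))
    return update


# Placeholder handle: a GPU UUID string mapped to an opaque args tuple.
handles = [{"GPU-placeholder": ("storage", 0, 100)}]
payload = {"names": ["w"], "ipc_handles_pickled": encode_handles(handles)}
decoded = decode_update(payload, allow_insecure=True)
```

Keeping the decode step outside the dataclass means the dataclass only ever sees `ipc_handles`, which matches the intent of the tests above that moved from `__post_init__` to `parse_update_info`.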
+6 -1
@@ -15,7 +15,12 @@ _TORCH_CUDA_PATTERNS = [
r"\bcuda_device_count_stateless\(\)\b",
]
ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"}
ALLOWED_FILES = {
"vllm/platforms/",
"vllm/device_allocator/",
"vllm/distributed/weight_transfer/ipc_engine.py",
"tests/distributed/test_packed_tensor.py",
}
def scan_file(path: str) -> int:
+291 -142
@@ -8,9 +8,10 @@ from dataclasses import asdict, dataclass
from typing import Any
import pybase64 as base64
import ray
import requests
import torch
from torch.multiprocessing.reductions import reduce_tensor
from torch.multiprocessing.reductions import rebuild_cuda_tensor, reduce_tensor
from vllm import envs
from vllm.config.parallel import ParallelConfig
@@ -20,27 +21,43 @@ from vllm.distributed.weight_transfer.base import (
WeightTransferInitInfo,
WeightTransferUpdateInfo,
)
from vllm.distributed.weight_transfer.packed_tensor import (
DEFAULT_PACKED_BUFFER_SIZE_BYTES,
packed_ipc_consumer,
packed_ipc_producer,
)
@dataclass
class IPCTrainerSendWeightsArgs:
"""Arguments for IPC trainer_send_weights method."""
mode: str
"""Transport mode: 'http' or 'ray'."""
send_mode: str | Callable[["IPCWeightTransferUpdateInfo"], None]
"""How to send updates to vLLM. Either a string ('ray' or 'http') for
built-in transports, or a callable that receives an
IPCWeightTransferUpdateInfo and performs the send."""
llm_handle: Any = None
"""Ray ObjectRef to LLM handle (required for 'ray' mode)."""
"""Ray actor handle or list of handles (required for 'ray' send_mode)."""
url: str | None = None
"""Base URL for HTTP endpoint (required for 'http' mode)."""
"""Base URL for HTTP endpoint (required for 'http' send_mode)."""
packed: bool = False
"""Whether to use packed tensor transfer for bounded-memory chunking."""
packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
"""Size in bytes for each packed tensor buffer when packed=True."""
def __post_init__(self):
"""Validate that required arguments are provided for the selected mode."""
if self.mode == "ray" and self.llm_handle is None:
raise ValueError("llm_handle is required for 'ray' mode")
if self.mode == "http" and self.url is None:
raise ValueError("url is required for 'http' mode")
if self.mode not in ("ray", "http"):
raise ValueError(f"mode must be 'ray' or 'http', got {self.mode}")
if callable(self.send_mode):
return
if self.send_mode == "ray" and self.llm_handle is None:
raise ValueError("llm_handle is required for 'ray' send_mode")
if self.send_mode == "http" and self.url is None:
raise ValueError("url is required for 'http' send_mode")
if self.send_mode not in ("ray", "http"):
raise ValueError(
f"send_mode must be 'ray', 'http', or a callable, "
f"got {self.send_mode!r}"
)
@dataclass
@@ -52,44 +69,22 @@ class IPCWeightTransferInitInfo(WeightTransferInitInfo):
@dataclass
class IPCWeightTransferUpdateInfo(WeightTransferUpdateInfo):
"""Update info for IPC weight transfer backend.
Accepts IPC handles either directly via ``ipc_handles`` (Ray transport)
or as a base64-encoded pickle via ``ipc_handles_pickled`` (HTTP transport).
Exactly one of the two must be provided; if ``ipc_handles_pickled`` is set
it is unpickled into ``ipc_handles`` during ``__post_init__``.
"""
"""Update info for IPC weight transfer backend."""
names: list[str]
dtype_names: list[str]
shapes: list[list[int]]
ipc_handles: list[dict[str, tuple[Callable, tuple]]] | None = None
"""IPC handles mapping physical GPU UUID to (func, args) tuple.
Each handle is a dictionary mapping GPU UUID strings to IPC handle tuples."""
ipc_handles_pickled: str | None = None
"""Base64-encoded pickled IPC handles, used for HTTP transport."""
ipc_handles: list[dict[str, tuple]] | dict[str, tuple]
"""IPC handles mapping physical GPU UUID to rebuild_cuda_tensor args.
For non-packed mode: list of per-parameter handle dicts.
For packed mode: single handle dict for the packed buffer."""
tensor_sizes: list[int] | None = None
"""Per-parameter sizes in bytes within the packed buffer.
Required when packed=True, unused otherwise."""
packed: bool = False
"""Whether this update uses packed tensor format."""
def __post_init__(self):
if self.ipc_handles_pickled is not None:
if self.ipc_handles is not None:
raise ValueError(
"Cannot specify both `ipc_handles` and `ipc_handles_pickled`"
)
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
raise ValueError(
"Refusing to deserialize `ipc_handles_pickled` without "
"VLLM_ALLOW_INSECURE_SERIALIZATION=1"
)
self.ipc_handles = pickle.loads(base64.b64decode(self.ipc_handles_pickled))
self.ipc_handles_pickled = None
if self.ipc_handles is None:
raise ValueError(
"Either `ipc_handles` or `ipc_handles_pickled` must be provided"
)
num_params = len(self.names)
if len(self.dtype_names) != num_params:
raise ValueError(
@@ -101,11 +96,17 @@ class IPCWeightTransferUpdateInfo(WeightTransferUpdateInfo):
f"`shapes` should be of the same size as `names`: "
f"got {len(self.shapes)} and {len(self.names)}"
)
if len(self.ipc_handles) != num_params:
if (
not self.packed
and isinstance(self.ipc_handles, list)
and len(self.ipc_handles) != num_params
):
raise ValueError(
f"`ipc_handles` should be of the same size as `names`: "
f"got {len(self.ipc_handles)} and {len(self.names)}"
)
if self.packed and self.tensor_sizes is None:
raise ValueError("`tensor_sizes` is required when packed=True")
class IPCWeightTransferEngine(
@@ -135,6 +136,36 @@ class IPCWeightTransferEngine(
"""
super().__init__(config, parallel_config)
def parse_update_info(
self, update_dict: dict[str, Any]
) -> IPCWeightTransferUpdateInfo:
"""Parse update dict, deserializing pickled IPC handles if present.
HTTP transport sends IPC handles as a base64-encoded pickle under the
key ``ipc_handles_pickled``. This method deserializes them back into
``ipc_handles`` before constructing the typed dataclass, keeping
serialization concerns out of the dataclass itself.
Requires ``VLLM_ALLOW_INSECURE_SERIALIZATION=1`` because the
payload is deserialized via ``pickle.loads``.
"""
if "ipc_handles_pickled" in update_dict:
if "ipc_handles" in update_dict:
raise ValueError(
"Cannot specify both `ipc_handles` and `ipc_handles_pickled`"
)
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
raise ValueError(
"Refusing to deserialize `ipc_handles_pickled` without "
"VLLM_ALLOW_INSECURE_SERIALIZATION=1"
)
pickled = update_dict.pop("ipc_handles_pickled")
update_dict["ipc_handles"] = pickle.loads(base64.b64decode(pickled))
return super().parse_update_info(update_dict)
def init_transfer_engine(self, init_info: IPCWeightTransferInitInfo) -> None:
"""
Initialize the weight transfer mechanism.
@@ -157,46 +188,52 @@ class IPCWeightTransferEngine(
Args:
update_info: IPC update info containing parameter names, dtypes, shapes,
and IPC handles. Each IPC handle is a mapping between physical
GPU UUID and the IPC handle tuple (func, args).
GPU UUID and the rebuild_cuda_tensor args tuple.
load_weights: Callable that loads weights into the model. Called
incrementally for each weight to avoid OOM.
"""
assert update_info.ipc_handles is not None
weights = []
for name, _dtype_name, _shape, ipc_handle in zip(
update_info.names,
update_info.dtype_names,
update_info.shapes,
update_info.ipc_handles,
):
device_index = torch.accelerator.current_device_index()
props = torch.cuda.get_device_properties(device_index)
physical_gpu_id = str(props.uuid)
device_index = torch.accelerator.current_device_index()
if physical_gpu_id not in ipc_handle:
raise ValueError(
f"IPC handle not found for GPU UUID {physical_gpu_id}. "
f"Available UUIDs: {list(ipc_handle.keys())}"
)
if update_info.packed:
assert update_info.tensor_sizes is not None
assert isinstance(update_info.ipc_handles, dict)
weights = packed_ipc_consumer(
ipc_handle=update_info.ipc_handles,
names=update_info.names,
shapes=update_info.shapes,
dtype_names=update_info.dtype_names,
tensor_sizes=update_info.tensor_sizes,
device_index=device_index,
)
load_weights(weights)
else:
assert isinstance(update_info.ipc_handles, list)
weights = []
for name, ipc_handle in zip(
update_info.names,
update_info.ipc_handles,
):
props = torch.cuda.get_device_properties(device_index)
physical_gpu_id = str(props.uuid)
handle = ipc_handle[physical_gpu_id]
if physical_gpu_id not in ipc_handle:
raise ValueError(
f"IPC handle not found for GPU UUID "
f"{physical_gpu_id}. "
f"Available UUIDs: {list(ipc_handle.keys())}"
)
func, args = handle
list_args = list(args) # type: ignore
# Index 6 is the device_index parameter in torch's
# IPC handle tuple (rebuild_cuda_tensor). Update it
# to the current device since the logical index can
# differ between sender and receiver.
list_args[6] = device_index
weight = func(*list_args) # type: ignore
weights.append((name, weight))
args = ipc_handle[physical_gpu_id]
list_args = list(args)
# Index 6 of the args from reduce_tensor is the device_index.
# We need to overwrite it with the receiver's device index.
list_args[6] = device_index
weight = rebuild_cuda_tensor(*list_args)
weights.append((name, weight))
load_weights(weights)
load_weights(weights)
def shutdown(self) -> None:
"""
Shutdown the weight transfer engine.
"""
pass
@staticmethod
@@ -204,12 +241,17 @@ class IPCWeightTransferEngine(
iterator: Iterator[tuple[str, torch.Tensor]],
trainer_args: dict[str, Any] | IPCTrainerSendWeightsArgs,
) -> None:
"""
Send weights from trainer to inference workers via CUDA IPC.
"""Send weights from trainer to inference workers via CUDA IPC.
Supports two modes:
- 'ray': Sends weights via Ray RPC to a Ray-based LLM handle
- 'http': Sends weights via HTTP POST to a vLLM HTTP server
Supports two built-in transport modes ('ray' and 'http'), or a custom
callable passed as ``send_mode``, and two transfer strategies:
- Non-packed (default): all weights in a single API call.
- Packed (packed=True): chunked transfer with bounded GPU memory.
For multi-GPU training, all ranks must call this method in
parallel. IPC handles are all-gathered across ranks and merged
so that each vLLM worker can find its own GPU UUID. Only rank 0
sends the payload to vLLM.
.. note::
This method calls ``update_weights`` internally. The caller must
@@ -217,88 +259,195 @@ class IPCWeightTransferEngine(
after this method.
Args:
iterator: Iterator of model parameters. Returns (name, tensor) tuples.
Tensors should be on the same GPU as the inference workers.
trainer_args: Dictionary containing IPC-specific arguments.
Should contain keys from IPCTrainerSendWeightsArgs:
- mode: 'ray' or 'http'
- llm_handle: Ray ObjectRef (for 'ray' mode)
- url: Base URL string (for 'http' mode)
Example (Ray mode):
>>> from vllm.distributed.weight_transfer.ipc_engine import (
... IPCWeightTransferEngine,
... IPCTrainerSendWeightsArgs,
... )
>>> param_iter = ((n, p) for n, p in model.named_parameters())
>>> args = IPCTrainerSendWeightsArgs(mode="ray", llm_handle=llm_handle)
>>> IPCWeightTransferEngine.trainer_send_weights(param_iter, asdict(args))
Example (HTTP mode):
>>> args = IPCTrainerSendWeightsArgs(
... mode="http", url="http://localhost:8000"
... )
>>> IPCWeightTransferEngine.trainer_send_weights(param_iter, asdict(args))
iterator: Iterator of (name, tensor) pairs. For multi-GPU,
each rank should yield the full tensor on its own GPU
(e.g. via FSDP full_tensor()).
trainer_args: IPCTrainerSendWeightsArgs or equivalent dict.
"""
# Parse trainer args - accept either dict or dataclass instance
if isinstance(trainer_args, dict):
args = IPCTrainerSendWeightsArgs(**trainer_args)
else:
args = trainer_args
# Get physical GPU UUID
args = (
IPCTrainerSendWeightsArgs(**trainer_args)
if isinstance(trainer_args, dict)
else trainer_args
)
device_index = torch.accelerator.current_device_index()
props = torch.cuda.get_device_properties(device_index)
gpu_uuid = str(props.uuid)
gpu_uuid = str(torch.cuda.get_device_properties(device_index).uuid)
if args.packed:
IPCWeightTransferEngine._send_packed(iterator, args, gpu_uuid)
else:
IPCWeightTransferEngine._send_unpacked(iterator, args, gpu_uuid)
# Collect weight metadata and create IPC handles
names = []
dtype_names = []
shapes = []
ipc_handles = []
@staticmethod
def _is_rank_zero() -> bool:
"""Return True if this is rank 0 or no distributed group exists."""
if not torch.distributed.is_initialized():
return True
return torch.distributed.get_rank() == 0
@staticmethod
def _all_gather_and_merge_handles(
handles: list[dict[str, tuple]],
) -> list[dict[str, tuple]]:
"""All-gather and merge IPC handle dicts across ranks in one call.
Each rank contributes a list of {gpu_uuid: ipc_args} dicts (one
per parameter or one per chunk). A single all_gather_object
collects every rank's full list, then rank 0 merges per-index so
each dict maps every GPU UUID to its args.
Non-rank-0 returns a list of empty dicts.
No-op (returns handles unchanged) when no distributed group exists.
"""
if (
not torch.distributed.is_initialized()
or torch.distributed.get_world_size() == 1
):
return handles
world_size = torch.distributed.get_world_size()
gathered: list[list[dict[str, tuple]] | None] = [None] * world_size
torch.distributed.all_gather_object(gathered, handles)
torch.distributed.barrier()
torch.cuda.synchronize()
if torch.distributed.get_rank() == 0:
merged: list[dict[str, tuple]] = []
for param_idx in range(len(handles)):
m: dict[str, tuple] = {}
for rank_handles in gathered:
if rank_handles is not None:
m.update(rank_handles[param_idx])
merged.append(m)
return merged
return [{} for _ in handles]
@staticmethod
def _post_send_sync() -> None:
"""Barrier + ipc_collect after a send; no-op if single-GPU."""
if (
torch.distributed.is_initialized()
and torch.distributed.get_world_size() > 1
):
torch.distributed.barrier()
torch.cuda.ipc_collect()
@staticmethod
def _send_unpacked(
iterator: Iterator[tuple[str, torch.Tensor]],
args: IPCTrainerSendWeightsArgs,
gpu_uuid: str,
) -> None:
"""Send all weights in a single API call (non-packed mode)."""
names: list[str] = []
dtype_names: list[str] = []
shapes: list[list[int]] = []
ipc_handles: list[dict[str, tuple]] = []
# Hold strong refs to every contiguous copy until the send + post-send
# sync completes. reduce_tensor's returned args do NOT keep storage
# alive, and non-contiguous inputs allocate fresh storage in
# .contiguous() that would otherwise be GC'd before the consumer opens
# the IPC handle.
weight_refs: list[torch.Tensor] = []
for name, tensor in iterator:
names.append(name)
dtype_names.append(str(tensor.dtype).split(".")[-1])
shapes.append(list(tensor.shape))
# Create IPC handle for this weight tensor
# The tensor must remain in memory for IPC to work
weight = tensor.detach().contiguous()
ipc_handle = reduce_tensor(weight)
ipc_handles.append({gpu_uuid: ipc_handle})
weight_refs.append(weight)
_, ipc_args = reduce_tensor(weight)
ipc_handles.append({gpu_uuid: ipc_args})
# Send weights based on mode
if args.mode == "ray":
# Ray mode: send via Ray RPC
import ray
ipc_handles = IPCWeightTransferEngine._all_gather_and_merge_handles(ipc_handles)
update_info = asdict(
IPCWeightTransferUpdateInfo(
names=names,
dtype_names=dtype_names,
shapes=shapes,
ipc_handles=ipc_handles,
if IPCWeightTransferEngine._is_rank_zero():
IPCWeightTransferEngine._do_send(
args=args,
names=names,
dtype_names=dtype_names,
shapes=shapes,
ipc_handles=ipc_handles,
)
IPCWeightTransferEngine._post_send_sync()
@staticmethod
def _send_packed(
iterator: Iterator[tuple[str, torch.Tensor]],
args: IPCTrainerSendWeightsArgs,
gpu_uuid: str,
) -> None:
"""Send weights in bounded-memory chunks (packed mode)."""
post_iter_func: Callable = lambda item: item[1]
for chunk in packed_ipc_producer(
iterator=iterator,
gpu_uuid=gpu_uuid,
post_iter_func=post_iter_func,
buffer_size_bytes=args.packed_buffer_size_bytes,
):
ipc_handle = IPCWeightTransferEngine._all_gather_and_merge_handles(
[chunk.ipc_handle]
)[0]
if IPCWeightTransferEngine._is_rank_zero():
IPCWeightTransferEngine._do_send(
args=args,
names=chunk.names,
dtype_names=chunk.dtype_names,
shapes=chunk.shapes,
ipc_handles=ipc_handle,
tensor_sizes=chunk.tensor_sizes,
packed=True,
)
IPCWeightTransferEngine._post_send_sync()
@staticmethod
def _do_send(
args: IPCTrainerSendWeightsArgs,
names: list[str],
dtype_names: list[str],
shapes: list[list[int]],
ipc_handles: list[dict[str, tuple]] | dict[str, tuple],
tensor_sizes: list[int] | None = None,
packed: bool = False,
) -> None:
"""Send a single update payload via the configured transport."""
update_fields: dict[str, Any] = {
"names": names,
"dtype_names": dtype_names,
"shapes": shapes,
"packed": packed,
}
if tensor_sizes is not None:
update_fields["tensor_sizes"] = tensor_sizes
update_fields["ipc_handles"] = ipc_handles
update_info = IPCWeightTransferUpdateInfo(**update_fields)
if callable(args.send_mode):
args.send_mode(update_info)
elif args.send_mode == "ray":
handles = (
args.llm_handle
if isinstance(args.llm_handle, list)
else [args.llm_handle]
)
ray.get(
args.llm_handle.update_weights.remote(dict(update_info=update_info))
[
h.update_weights.remote(dict(update_info=asdict(update_info)))
for h in handles
]
)
elif args.mode == "http":
# HTTP mode: send via HTTP POST with pickled handles
# Pickle and base64 encode IPC handles for HTTP transmission
elif args.send_mode == "http":
pickled_handles = base64.b64encode(pickle.dumps(ipc_handles)).decode(
"utf-8"
)
http_fields = {k: v for k, v in update_fields.items() if k != "ipc_handles"}
http_fields["ipc_handles_pickled"] = pickled_handles
url = f"{args.url}/update_weights"
payload = {
"update_info": {
"names": names,
"dtype_names": dtype_names,
"shapes": shapes,
"ipc_handles_pickled": pickled_handles,
}
}
payload = {"update_info": http_fields}
response = requests.post(url, json=payload, timeout=300)
response.raise_for_status()
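
The HTTP branch above pickles the handle mapping and base64-encodes it so that it fits inside a JSON body. The round trip can be sketched with the standard library alone. The handle contents and the `decode_handles` helper below are hypothetical placeholders; they are not real `reduce_tensor` output and not the server's actual decoding code.

```python
import base64
import json
import pickle

# Hypothetical stand-in for a {gpu_uuid: rebuild-args tuple} mapping.
ipc_handles = {"GPU-0000": (b"\x00\x01", 1024, 0)}


def encode_handles(handles: dict) -> str:
    """Pickle the mapping, then base64-encode it into an ASCII-safe string."""
    return base64.b64encode(pickle.dumps(handles)).decode("utf-8")


def decode_handles(encoded: str) -> dict:
    """Inverse of encode_handles (assumed receiver-side step)."""
    return pickle.loads(base64.b64decode(encoded))


# Build a JSON payload shaped like the one posted above, then read it back.
payload = {
    "update_info": {
        "names": ["w"],
        "ipc_handles_pickled": encode_handles(ipc_handles),
    }
}
wire = json.dumps(payload)
restored = decode_handles(json.loads(wire)["update_info"]["ipc_handles_pickled"])
```

Since `pickle.loads` can execute arbitrary code, this encoding only makes sense between mutually trusted processes.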
@@ -21,7 +21,7 @@ from vllm.distributed.weight_transfer.base import (
from vllm.distributed.weight_transfer.packed_tensor import (
DEFAULT_PACKED_BUFFER_SIZE_BYTES,
DEFAULT_PACKED_NUM_BUFFERS,
packed_broadcast_consumer,
packed_nccl_broadcast_consumer,
)
@@ -184,7 +184,7 @@ class NCCLWeightTransferEngine(
dtype = getattr(torch, dtype_name)
yield (name, (shape, dtype))
packed_broadcast_consumer(
packed_nccl_broadcast_consumer(
iterator=state_dict_info_iterator(),
group=self.model_update_group,
src=0,
@@ -247,10 +247,10 @@ class NCCLWeightTransferEngine(
if args.packed:
# Use packed tensor broadcasting for efficiency
from vllm.distributed.weight_transfer.packed_tensor import (
packed_broadcast_producer,
packed_nccl_broadcast_producer,
)
packed_broadcast_producer(
packed_nccl_broadcast_producer(
iterator=iterator,
group=args.group,
src=args.src,
+283 -72
@@ -4,9 +4,11 @@
import math
from collections.abc import Callable, Iterator
from dataclasses import dataclass
from typing import Any
import torch
from torch.multiprocessing.reductions import reduce_tensor
# Default values for packed tensor configuration.
# These are imported by NCCLWeightTransferUpdateInfo and trainer_send_weights.
@@ -14,7 +16,124 @@ DEFAULT_PACKED_BUFFER_SIZE_BYTES = 1024 * 1024 * 1024 # 1GB
DEFAULT_PACKED_NUM_BUFFERS = 2
def packed_broadcast_producer(
def unpack_tensor(
packed_tensor: torch.Tensor,
names: list[str],
shapes: list[list[int]],
dtypes: list[torch.dtype],
tensor_sizes: list[int],
) -> list[tuple[str, torch.Tensor]]:
"""Unpack a packed uint8 tensor into a list of named tensors.
The returned tensors are **views** of ``packed_tensor`` (the
``.contiguous()`` call is a no-op on already-contiguous row-slices).
If ``packed_tensor`` lives in storage that may be reused — e.g. a
reused CUDA IPC buffer — callers must clone the results before the
underlying storage is overwritten.
Args:
packed_tensor: The packed torch.uint8 tensor to unpack
names: List of tensor names
shapes: List of tensor shapes
dtypes: List of tensor dtypes
tensor_sizes: List of tensor sizes in bytes
"""
unpacked_tensors = packed_tensor.split(tensor_sizes)
return [
(name, tensor.contiguous().view(dtype).view(*shape))
for name, shape, dtype, tensor in zip(names, shapes, dtypes, unpacked_tensors)
]
@dataclass
class PackedChunk:
"""Result of packing tensors into a single contiguous uint8 buffer."""
packed_tensor: torch.Tensor
names: list[str]
shapes: list[list[int]]
dtypes: list[torch.dtype]
tensor_sizes: list[int]
def pack_tensors(
iterator: Iterator[tuple[str, torch.Tensor]],
post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor],
buffer_size_bytes: int,
tensor_list: list[torch.Tensor] | None = None,
current_size: int = 0,
) -> PackedChunk | None:
"""Pack tensors from an iterator into a single contiguous uint8 buffer.
Consumes from the iterator until the accumulated size exceeds
buffer_size_bytes or the iterator is exhausted, then returns a
PackedChunk. Returns None if no tensors were consumed.
Args:
iterator: Iterator of (name, tensor) pairs
post_iter_func: Applied to each item before linearizing to uint8
buffer_size_bytes: Max bytes before flushing
tensor_list: Pre-existing tensor list to append to (for NCCL
multi-buffer reuse). If None, a fresh list is created.
current_size: Byte count already accumulated in tensor_list
"""
if tensor_list is None:
tensor_list = []
names: list[str] = []
shapes: list[list[int]] = []
dtypes: list[torch.dtype] = []
tensor_sizes: list[int] = []
total_bytes = current_size
while True:
try:
item = next(iterator)
except StopIteration:
break
name, orig_tensor = item
# Apply post processing and convert to linearized uint8 tensor
tensor = post_iter_func(item).contiguous().view(torch.uint8).view(-1)
if tensor.numel() > buffer_size_bytes:
import warnings
warnings.warn(
f"Tensor '{name}' has size {tensor.numel()} bytes, which "
f"exceeds buffer_size_bytes={buffer_size_bytes}.",
stacklevel=2,
)
tensor_list.append(tensor)
names.append(name)
shapes.append(list(orig_tensor.shape))
dtypes.append(orig_tensor.dtype)
tensor_sizes.append(tensor.numel())
total_bytes += tensor.numel()
if total_bytes > buffer_size_bytes:
break
if not tensor_list:
return None
packed = torch.cat(tensor_list, dim=0)
del tensor_list
return PackedChunk(
packed_tensor=packed,
names=names,
shapes=shapes,
dtypes=dtypes,
tensor_sizes=tensor_sizes,
)
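
The flush rule in `pack_tensors` is a soft limit: a tensor is appended first and the size check happens afterwards, so a chunk can overshoot `buffer_size_bytes`. Below is a minimal sketch of that grouping rule on plain byte counts, with no tensors involved. The function name `soft_limit_chunks` is hypothetical and used only for illustration.

```python
from collections.abc import Iterator


def soft_limit_chunks(sizes: Iterator[int], buffer_size_bytes: int) -> list[list[int]]:
    """Group sizes, flushing *after* the running total exceeds the limit."""
    chunks: list[list[int]] = []
    while True:
        current: list[int] = []
        total = 0
        # `sizes` is an iterator, so each pass resumes where the last one stopped.
        for size in sizes:
            current.append(size)
            total += size
            if total > buffer_size_bytes:
                break
        if not current:
            # Nothing consumed: the iterator is exhausted.
            return chunks
        chunks.append(current)


chunks = soft_limit_chunks(iter([4, 4, 4, 4, 4]), buffer_size_bytes=10)
# The first chunk holds 12 bytes, which is above the 10-byte limit.
```

This overshoot is acceptable when each chunk is a freshly concatenated tensor, but it would not fit a fixed-size buffer.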
# ── NCCL packed broadcast ──────────────────────────────────────────────
def packed_nccl_broadcast_producer(
iterator: Iterator[tuple[str, torch.Tensor]],
group: Any,
src: int,
@@ -36,57 +155,31 @@ def packed_broadcast_producer(
Both producer and consumer must use the same value.
"""
target_packed_tensor_size = buffer_size_bytes
streams = [torch.cuda.Stream() for _ in range(num_buffers)]
# Keep references to in-flight chunks so their packed_tensors
# aren't freed while an async broadcast is still reading them.
in_flight: list[PackedChunk | None] = [None] * num_buffers
buffer_idx = 0
packing_tensor_list: list[list[torch.Tensor]] = [[] for _ in range(num_buffers)]
packing_tensor_sizes: list[int] = [0 for _ in range(num_buffers)]
packed_tensors: list[torch.Tensor] = [
torch.empty(0, dtype=torch.uint8, device="cuda") for _ in range(num_buffers)
]
while True:
# Synchronize the current stream
streams[buffer_idx].synchronize()
# Previous chunk on this buffer slot is now safe to free
in_flight[buffer_idx] = None
# Start tasks for the new buffer in a new stream
with torch.cuda.stream(streams[buffer_idx]):
try:
# Initialize the packing tensor list and sizes
packing_tensor_list[buffer_idx] = []
packing_tensor_sizes[buffer_idx] = 0
# Pack the tensors
while True:
# Apply post processing and convert to linearized uint8 tensor
tensor = (
post_iter_func(next(iterator))
.contiguous()
.view(torch.uint8)
.view(-1)
)
packing_tensor_list[buffer_idx].append(tensor)
packing_tensor_sizes[buffer_idx] += tensor.numel()
if packing_tensor_sizes[buffer_idx] > target_packed_tensor_size:
break
# Pack the tensors and call broadcast collective
packed_tensors[buffer_idx] = torch.cat(
packing_tensor_list[buffer_idx], dim=0
)
group.broadcast(packed_tensors[buffer_idx], src=src)
# Move to the next buffer
buffer_idx = (buffer_idx + 1) % num_buffers
except StopIteration:
# Do the last broadcast if there are remaining tensors
if len(packing_tensor_list[buffer_idx]) > 0:
packed_tensors[buffer_idx] = torch.cat(
packing_tensor_list[buffer_idx], dim=0
)
group.broadcast(packed_tensors[buffer_idx], src=src)
chunk = pack_tensors(iterator, post_iter_func, buffer_size_bytes)
if chunk is None:
break
# Pack the tensors and call broadcast collective
group.broadcast(chunk.packed_tensor, src=src)
# Hold reference until this stream is synchronized
in_flight[buffer_idx] = chunk
# Move to the next buffer
buffer_idx = (buffer_idx + 1) % num_buffers
def packed_broadcast_consumer(
def packed_nccl_broadcast_consumer(
iterator: Iterator[tuple[str, tuple[list[int], torch.dtype]]],
group: Any,
src: int,
@@ -108,37 +201,6 @@ def packed_broadcast_consumer(
Both producer and consumer must use the same value.
"""
def unpack_tensor(
packed_tensor: torch.Tensor,
names: list[str],
shapes: list[list[int]],
dtypes: list[torch.dtype],
tensor_sizes: list[int],
) -> list[tuple[str, torch.Tensor]]:
"""Unpack a single tensor into a list of tensors.
Args:
packed_tensor: The packed torch.uint8 tensor to unpack
names: List of tensor names
shapes: List of tensor shapes
dtypes: List of tensor dtypes
tensor_sizes: List of tensor sizes in bytes
Returns:
unpacked List[(name, tensor)]
"""
unpacked_tensors = packed_tensor.split(tensor_sizes)
unpacked_list = [
(name, tensor.contiguous().view(dtype).view(*shape))
for name, shape, dtype, tensor in zip(
names, shapes, dtypes, unpacked_tensors
)
]
return unpacked_list
target_packed_tensor_size = buffer_size_bytes
streams = [torch.cuda.Stream() for _ in range(num_buffers)]
@@ -214,3 +276,152 @@ def packed_broadcast_consumer(
)
)
break
# ── IPC packed transfer ────────────────────────────────────────────────
@dataclass
class PackedIpcChunk:
"""Metadata and IPC handle for a single packed chunk."""
names: list[str]
shapes: list[list[int]]
dtype_names: list[str]
tensor_sizes: list[int]
ipc_handle: dict[str, tuple]
def packed_ipc_producer(
iterator: Iterator[tuple[str, torch.Tensor]],
gpu_uuid: str,
post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor],
buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
) -> Iterator[PackedIpcChunk]:
"""Pack tensors into a reusable IPC buffer and yield handles.
Allocates a single GPU buffer of ``buffer_size_bytes`` and registers
it for IPC once via ``reduce_tensor``. Each chunk's packed data is
copied into this buffer before yielding, so only one IPC-shared
allocation exists for the lifetime of the transfer.
Callers **must** ensure the consumer has finished reading the buffer
(e.g. ``ray.get`` returned) before resuming the generator for the
next chunk.
Args:
iterator: Iterator of (name, tensor) pairs.
gpu_uuid: Physical GPU UUID string for this rank.
post_iter_func: Applied to each (name, tensor) before packing.
buffer_size_bytes: Exact capacity of the reusable IPC buffer.
Every chunk is guaranteed to fit within this size. A
``ValueError`` is raised if any single tensor exceeds it.
"""
ipc_buffer = torch.empty(buffer_size_bytes, dtype=torch.uint8, device="cuda")
_, ipc_args = reduce_tensor(ipc_buffer)
names: list[str] = []
shapes: list[list[int]] = []
dtypes: list[torch.dtype] = []
tensor_sizes: list[int] = []
total_bytes = 0
for name, orig_tensor in iterator:
flat = (
post_iter_func((name, orig_tensor)).contiguous().view(torch.uint8).view(-1)
)
if flat.numel() > buffer_size_bytes:
raise ValueError(
f"Tensor '{name}' has size {flat.numel()} bytes, "
f"which exceeds buffer_size_bytes={buffer_size_bytes}. "
f"Increase buffer_size_bytes to at least {flat.numel()}."
)
if total_bytes and total_bytes + flat.numel() > buffer_size_bytes:
# Drain queued copies so the consumer sees a fully-written buffer.
torch.cuda.current_stream().synchronize()
yield PackedIpcChunk(
names=names,
shapes=shapes,
dtype_names=[str(d).split(".")[-1] for d in dtypes],
tensor_sizes=tensor_sizes,
ipc_handle={gpu_uuid: ipc_args},
)
# Rebind to fresh lists so the yielded chunk's metadata is
# not mutated while the consumer is still reading.
names, shapes, dtypes, tensor_sizes = [], [], [], []
total_bytes = 0
ipc_buffer[total_bytes : total_bytes + flat.numel()].copy_(flat)
names.append(name)
shapes.append(list(orig_tensor.shape))
dtypes.append(orig_tensor.dtype)
tensor_sizes.append(flat.numel())
total_bytes += flat.numel()
if total_bytes:
torch.cuda.current_stream().synchronize()
yield PackedIpcChunk(
names=names,
shapes=shapes,
dtype_names=[str(d).split(".")[-1] for d in dtypes],
tensor_sizes=tensor_sizes,
ipc_handle={gpu_uuid: ipc_args},
)
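
Unlike `pack_tensors`, the producer above flushes before the buffer would overflow, and it reuses a single allocation for every chunk. The following is a CPU-only sketch of the same control flow, assuming a `bytearray` in place of the CUDA IPC buffer. The name `hard_limit_chunks` and the tuple it yields are hypothetical.

```python
from collections.abc import Iterable, Iterator


def hard_limit_chunks(
    items: Iterable[tuple[str, bytes]], buffer_size_bytes: int
) -> Iterator[tuple[list[str], list[int], bytearray]]:
    """Yield (names, sizes, buffer); the same buffer object is reused each time."""
    buffer = bytearray(buffer_size_bytes)  # allocated once
    names: list[str] = []
    sizes: list[int] = []
    total = 0
    for name, data in items:
        if len(data) > buffer_size_bytes:
            raise ValueError(f"'{name}' does not fit in {buffer_size_bytes} bytes")
        if total and total + len(data) > buffer_size_bytes:
            yield names, sizes, buffer
            # Fresh lists so the yielded metadata is not mutated afterwards.
            names, sizes = [], []
            total = 0
        buffer[total : total + len(data)] = data
        names.append(name)
        sizes.append(len(data))
        total += len(data)
    if total:
        yield names, sizes, buffer


items = [("a", b"aaaa"), ("b", b"bbbb"), ("c", b"cccc")]
received = []
for chunk_names, chunk_sizes, buf in hard_limit_chunks(items, buffer_size_bytes=8):
    # Copy out before resuming the generator, which overwrites the buffer.
    received.append((chunk_names, bytes(buf[: sum(chunk_sizes)])))
```

The consumer copies the bytes out before the loop resumes the generator. This mirrors the requirement that the real consumer finish reading before the next chunk is produced.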
def packed_ipc_consumer(
ipc_handle: dict[str, tuple],
names: list[str],
shapes: list[list[int]],
dtype_names: list[str],
tensor_sizes: list[int],
device_index: int,
) -> list[tuple[str, torch.Tensor]]:
"""Unpack a single packed IPC chunk into named tensors.
Reconstructs the packed buffer via rebuild_cuda_tensor, unpacks
into individual tensors, and clones each into independent storage
before returning.
The clone is intentional: the producer reuses one IPC buffer across
chunks, so any tensor view that aliases the buffer would observe the
*next* chunk's bytes as soon as the producer's generator is resumed.
Callers that retain references past their own update_weights call
(notably vLLM's layerwise reload, which buffers ``bound_args`` for
replay in ``_layerwise_process``) would otherwise replay against
stale data and silently corrupt multi-chunk weight transfers.
Args:
ipc_handle: Mapping of GPU UUID to rebuild_cuda_tensor args tuple
names: Parameter names in the packed buffer
shapes: Parameter shapes
dtype_names: Parameter dtype name strings (e.g. "float16")
tensor_sizes: Size in bytes of each parameter in the packed buffer
device_index: Local CUDA device index
"""
from torch.multiprocessing.reductions import rebuild_cuda_tensor
props = torch.cuda.get_device_properties(device_index)
physical_gpu_id = str(props.uuid)
if physical_gpu_id not in ipc_handle:
raise ValueError(
f"IPC handle not found for GPU UUID {physical_gpu_id}. "
f"Available UUIDs: {list(ipc_handle.keys())}"
)
args = ipc_handle[physical_gpu_id]
list_args = list(args)
list_args[6] = device_index
packed = rebuild_cuda_tensor(*list_args)
content_size = sum(tensor_sizes)
packed = packed[:content_size]
dtypes = [getattr(torch, dn) for dn in dtype_names]
return [
(name, t.clone())
for name, t in unpack_tensor(packed, names, shapes, dtypes, tensor_sizes)
]
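
The reason for the final `clone()` can be shown without a GPU. The sketch below is an analogy only: a `memoryview` over a `bytearray` plays the role of a tensor view into the shared IPC buffer, and `tolist()` plays the role of `clone()`. None of it is vLLM or PyTorch API.

```python
import struct

# Shared, reusable storage holding the first "chunk": two float32 values.
buffer = bytearray(struct.pack("2f", 1.0, 2.0))

view = memoryview(buffer).cast("f")  # aliases the buffer, like an unpacked view
snapshot = view.tolist()             # independent copy, like .clone()

# The producer writes the next chunk into the same storage, in place.
struct.pack_into("2f", buffer, 0, 3.0, 4.0)

aliased_now = view.tolist()  # reflects the new chunk
```

After the overwrite, `aliased_now` holds the second chunk's values while `snapshot` still holds the first. That is the failure mode the docstring describes for callers that keep views past their own `update_weights` call.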