Use CU_MEMCPY_SRC_ACCESS_ORDER_ANY for batch KV cache swaps (#39306)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Signed-off-by: Itay Etelis <etelis2019@gmail.com>
Signed-off-by: Itay Etelis <92247226+Etelis@users.noreply.github.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: Itay Etelis <etelis2019@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
Itay Etelis
2026-05-10 05:57:09 +03:00
committed by GitHub
parent 0d382ecde8
commit 00b0618a03
5 changed files with 35 additions and 6 deletions
+2 -1
View File
@@ -12,7 +12,8 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
void swap_blocks_batch(const torch::Tensor& src_ptrs,
const torch::Tensor& dst_ptrs,
const torch::Tensor& sizes);
const torch::Tensor& sizes,
bool is_src_access_order_any);
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
+8 -2
View File
@@ -77,7 +77,8 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
void swap_blocks_batch(const torch::Tensor& src_ptrs,
const torch::Tensor& dst_ptrs,
const torch::Tensor& sizes) {
const torch::Tensor& sizes,
bool is_src_access_order_any) {
TORCH_CHECK(src_ptrs.device().is_cpu(), "src_ptrs must be on CPU");
TORCH_CHECK(dst_ptrs.device().is_cpu(), "dst_ptrs must be on CPU");
TORCH_CHECK(sizes.device().is_cpu(), "sizes must be on CPU");
@@ -124,7 +125,12 @@ void swap_blocks_batch(const torch::Tensor& src_ptrs,
if (batch_fn != nullptr) {
CUmemcpyAttributes attr = {};
attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
// ANY lets the DMA engine prefetch source bytes out of stream order,
// which is only safe when no GPU stream is concurrently writing the
// source.
attr.srcAccessOrder = is_src_access_order_any
? CU_MEMCPY_SRC_ACCESS_ORDER_ANY
: CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
size_t attrs_idx = 0;
size_t fail_idx = 0;
CUresult result = batch_fn(reinterpret_cast<CUdeviceptr*>(dst_data),
+2 -1
View File
@@ -553,7 +553,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// Batch swap: submit all block copies in a single driver call.
cache_ops.def(
"swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
" Tensor sizes) -> ()");
" Tensor sizes,"
" bool is_src_access_order_any=False) -> ()");
cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
// Reshape the key and value tensors and cache them.
+10 -1
View File
@@ -2805,6 +2805,7 @@ def swap_blocks_batch(
src_ptrs: torch.Tensor,
dst_ptrs: torch.Tensor,
sizes: torch.Tensor,
is_src_access_order_any: bool = False,
) -> None:
"""
Batch version of swap_blocks: submit all copies in a single driver call.
@@ -2813,8 +2814,16 @@ def swap_blocks_batch(
of sizes[i] bytes. All three tensors must be int64 CPU tensors.
On CUDA 12.8+ this uses cuMemcpyBatchAsync for minimal submission
overhead; on older CUDA it falls back to a loop of cudaMemcpyAsync.
is_src_access_order_any: if True, pass CU_MEMCPY_SRC_ACCESS_ORDER_ANY to
cuMemcpyBatchAsync, letting the DMA engine prefetch source bytes
out of stream order. Only safe when no GPU stream is concurrently
writing to the source. Defaults to False (STREAM ordering), which
is always safe.
"""
torch.ops._C_cache_ops.swap_blocks_batch(src_ptrs, dst_ptrs, sizes)
torch.ops._C_cache_ops.swap_blocks_batch(
src_ptrs, dst_ptrs, sizes, is_src_access_order_any
)
def convert_fp8(
+13 -1
View File
@@ -313,10 +313,22 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
last_event = last_transfer.end_event
# ensure this job starts only after the previous one completes
stream.wait_event(last_event)
# CPU->GPU reads from host pinned memory, which is never written
# by a concurrent GPU stream, so CU_MEMCPY_SRC_ACCESS_ORDER_ANY is
# safe and lets the driver pipeline source reads. GPU->CPU reads
# from the live GPU KV cache, which the compute stream keeps
# writing; we must keep STREAM ordering so source reads are gated
# by the transfer stream's wait_stream(compute) barrier.
is_src_access_order_any = not self.gpu_to_cpu
with torch.cuda.stream(stream):
start_event.record(stream)
if num_copy_ops > 0:
ops.swap_blocks_batch(batch_src, batch_dst, batch_sizes)
ops.swap_blocks_batch(
batch_src,
batch_dst,
batch_sizes,
is_src_access_order_any=is_src_access_order_any,
)
end_event.record(stream)
self._transfer_events[job_id] = end_event