mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
Use CU_MEMCPY_SRC_ACCESS_ORDER_ANY for batch KV cache swaps (#39306)
Signed-off-by: Itay Etelis <itay.etelis@ibm.com> Signed-off-by: Itay Etelis <etelis2019@gmail.com> Signed-off-by: Itay Etelis <92247226+Etelis@users.noreply.github.com> Co-authored-by: Itay Etelis <itay.etelis@ibm.com> Co-authored-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: Itay Etelis <etelis2019@gmail.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
+2
-1
@@ -12,7 +12,8 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||
|
||||
void swap_blocks_batch(const torch::Tensor& src_ptrs,
|
||||
const torch::Tensor& dst_ptrs,
|
||||
const torch::Tensor& sizes);
|
||||
const torch::Tensor& sizes,
|
||||
bool is_src_access_order_any);
|
||||
|
||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||
|
||||
@@ -77,7 +77,8 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||
|
||||
void swap_blocks_batch(const torch::Tensor& src_ptrs,
|
||||
const torch::Tensor& dst_ptrs,
|
||||
const torch::Tensor& sizes) {
|
||||
const torch::Tensor& sizes,
|
||||
bool is_src_access_order_any) {
|
||||
TORCH_CHECK(src_ptrs.device().is_cpu(), "src_ptrs must be on CPU");
|
||||
TORCH_CHECK(dst_ptrs.device().is_cpu(), "dst_ptrs must be on CPU");
|
||||
TORCH_CHECK(sizes.device().is_cpu(), "sizes must be on CPU");
|
||||
@@ -124,7 +125,12 @@ void swap_blocks_batch(const torch::Tensor& src_ptrs,
|
||||
|
||||
if (batch_fn != nullptr) {
|
||||
CUmemcpyAttributes attr = {};
|
||||
attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
|
||||
// ANY lets the DMA engine prefetch source bytes out of stream order,
|
||||
// which is only safe when no GPU stream is concurrently writing the
|
||||
// source.
|
||||
attr.srcAccessOrder = is_src_access_order_any
|
||||
? CU_MEMCPY_SRC_ACCESS_ORDER_ANY
|
||||
: CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
|
||||
size_t attrs_idx = 0;
|
||||
size_t fail_idx = 0;
|
||||
CUresult result = batch_fn(reinterpret_cast<CUdeviceptr*>(dst_data),
|
||||
|
||||
@@ -553,7 +553,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
// Batch swap: submit all block copies in a single driver call.
|
||||
cache_ops.def(
|
||||
"swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
|
||||
" Tensor sizes) -> ()");
|
||||
" Tensor sizes,"
|
||||
" bool is_src_access_order_any=False) -> ()");
|
||||
cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
|
||||
|
||||
// Reshape the key and value tensors and cache them.
|
||||
|
||||
+10
-1
@@ -2805,6 +2805,7 @@ def swap_blocks_batch(
|
||||
src_ptrs: torch.Tensor,
|
||||
dst_ptrs: torch.Tensor,
|
||||
sizes: torch.Tensor,
|
||||
is_src_access_order_any: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Batch version of swap_blocks: submit all copies in a single driver call.
|
||||
@@ -2813,8 +2814,16 @@ def swap_blocks_batch(
|
||||
of sizes[i] bytes. All three tensors must be int64 CPU tensors.
|
||||
On CUDA 12.8+ this uses cuMemcpyBatchAsync for minimal submission
|
||||
overhead; on older CUDA it falls back to a loop of cudaMemcpyAsync.
|
||||
|
||||
is_src_access_order_any: if True, pass CU_MEMCPY_SRC_ACCESS_ORDER_ANY to
|
||||
cuMemcpyBatchAsync, letting the DMA engine prefetch source bytes
|
||||
out of stream order. Only safe when no GPU stream is concurrently
|
||||
writing to the source. Defaults to False (STREAM ordering), which
|
||||
is always safe.
|
||||
"""
|
||||
torch.ops._C_cache_ops.swap_blocks_batch(src_ptrs, dst_ptrs, sizes)
|
||||
torch.ops._C_cache_ops.swap_blocks_batch(
|
||||
src_ptrs, dst_ptrs, sizes, is_src_access_order_any
|
||||
)
|
||||
|
||||
|
||||
def convert_fp8(
|
||||
|
||||
@@ -313,10 +313,22 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
|
||||
last_event = last_transfer.end_event
|
||||
# assure job will start only after the previous one completes
|
||||
stream.wait_event(last_event)
|
||||
# CPU->GPU reads from host pinned memory, which is never written
|
||||
# by a concurrent GPU stream, so CU_MEMCPY_SRC_ACCESS_ORDER_ANY is
|
||||
# safe and lets the driver pipeline source reads. GPU->CPU reads
|
||||
# from the live GPU KV cache, which the compute stream keeps
|
||||
# writing; we must keep STREAM ordering so source reads are gated
|
||||
# by the transfer stream's wait_stream(compute) barrier.
|
||||
is_src_access_order_any = not self.gpu_to_cpu
|
||||
with torch.cuda.stream(stream):
|
||||
start_event.record(stream)
|
||||
if num_copy_ops > 0:
|
||||
ops.swap_blocks_batch(batch_src, batch_dst, batch_sizes)
|
||||
ops.swap_blocks_batch(
|
||||
batch_src,
|
||||
batch_dst,
|
||||
batch_sizes,
|
||||
is_src_access_order_any=is_src_access_order_any,
|
||||
)
|
||||
end_event.record(stream)
|
||||
|
||||
self._transfer_events[job_id] = end_event
|
||||
|
||||
Reference in New Issue
Block a user