[AutoDeploy][perf] Further optimize flashinfer backend in AutoDeploy (#4024)

* reuse batch_indices, positions across layers

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>

* fix flashinfer unit tests

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>

* simplify call to get_batch_indices_positions

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>

* fix call to get_batch_indices_positions

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>

---------

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Suyog Gupta 2025-05-05 19:46:36 -07:00 committed by GitHub
parent 5c0f554b9e
commit ac2ab9ba36
3 changed files with 103 additions and 13 deletions

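The core change (the first bullet in the commit message): batch_indices and positions, which flashinfer needs to append new K/V entries to the paged cache, used to be recomputed inside the attention op for every layer. They are now computed once per forward pass in prepare_flashinfer_metadata and handed to every layer. A minimal sketch of the hoisted computation, using only the flashinfer calls that appear in the diff below (the wrapper function name is illustrative):

import flashinfer
import torch


def compute_append_indices(
    qo_indptr: torch.Tensor,
    paged_kv_indptr: torch.Tensor,
    paged_kv_last_page_len: torch.Tensor,
    page_size: int,
    total_tokens: int,
):
    # Computed once per forward pass; every layer reuses the result for its
    # flashinfer.page.append_paged_kv_cache call instead of recomputing it.
    seq_lens = flashinfer.get_seq_lens(paged_kv_indptr, paged_kv_last_page_len, page_size)
    return flashinfer.get_batch_indices_positions(qo_indptr, seq_lens, total_tokens)

In the real op, total_tokens is position_ids.numel() (the number of tokens in the flattened batch), so the per-layer work shrinks to the cache append plus the attention kernel itself.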

@@ -186,8 +186,23 @@ def prepare_flashinfer_metadata(
paged_kv_last_page_len = ((offsets + seq_len - 1) % page_size) + 1
# Compute batch_indices and positions so that they can be reused for kv cache appends
# for all the layers
batch_indices, positions = flashinfer.get_batch_indices_positions(
qo_indptr,
flashinfer.get_seq_lens(paged_kv_indptr, paged_kv_last_page_len, page_size),
position_ids.numel(),
)
# return metadata
return (qo_indptr, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len)
return (
qo_indptr,
paged_kv_indptr,
paged_kv_indices,
paged_kv_last_page_len,
batch_indices,
positions,
)
@prepare_flashinfer_metadata.register_fake
@@ -195,11 +210,15 @@ def prepare_flashinfer_metadata_fake(
input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, page_size
):
qo_indptr = torch.empty(len(seq_len) + 1, dtype=seq_len.dtype, device=seq_len.device)
batch_indices = torch.empty_like(cache_loc)
positions = torch.empty_like(cache_loc)
return (
qo_indptr, # qo_indptr
torch.empty_like(qo_indptr), # paged_kv_indptr
torch.empty_like(cache_loc), # paged_kv_indices
torch.empty_like(seq_len), # paged_kv_last_page_len
batch_indices, # batch_indices
positions, # positions
)
@@ -214,6 +233,8 @@ def flashinfer_mha_with_cache(
paged_kv_indptr: torch.Tensor,
paged_kv_indices: torch.Tensor,
paged_kv_last_page_len: torch.Tensor,
batch_indices: torch.Tensor,
positions: torch.Tensor,
# CACHES
k_cache: torch.Tensor,
v_cache: torch.Tensor,
@@ -254,13 +275,6 @@ def flashinfer_mha_with_cache(
k = (k / k_scale).to(torch.float8_e4m3fn)
v = (v / v_scale).to(torch.float8_e4m3fn)
# Append to kv cache
batch_indices, positions = flashinfer.get_batch_indices_positions(
qo_indptr,
flashinfer.get_seq_lens(paged_kv_indptr, paged_kv_last_page_len, pp.page_size),
q.shape[0],
)
flashinfer.page.append_paged_kv_cache(
k,
v,
@@ -296,6 +310,8 @@ def flashinfer_mha_with_cache_fake(
paged_kv_indptr: torch.Tensor,
paged_kv_indices: torch.Tensor,
paged_kv_last_page_len: torch.Tensor,
batch_indices: torch.Tensor,
positions: torch.Tensor,
# CACHES
k_cache: torch.Tensor,
v_cache: torch.Tensor,
@@ -341,7 +357,7 @@ class FlashInferAttention(AttentionDescriptor):
@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
return torch.ops.attention.prepare_flashinfer_metadata, 4
return torch.ops.attention.prepare_flashinfer_metadata, 6
@classmethod
def get_cache_initializers(

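Net effect of the changes above: prepare_flashinfer_metadata now returns six tensors instead of four, the registered fake mirrors that arity, and flashinfer_mha_with_cache accepts the two extra tensors. A hedged caller-side sketch of how the new tuple threads through (argument order taken from the updated unit tests further down; the trailing cache-layout and scale arguments are collapsed into extra_args because the diff does not show them in full):

import torch


def run_attention_layer(q, k, v, metadata, k_cache, v_cache, *extra_args):
    (
        qo_indptr,
        paged_kv_indptr,
        paged_kv_indices,
        paged_kv_last_page_len,
        batch_indices,  # new: computed once in prepare_flashinfer_metadata
        positions,  # new: reused by every layer's kv-cache append
    ) = metadata
    return torch.ops.attention.flashinfer_mha_with_cache(
        q, k, v,
        qo_indptr, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len,
        batch_indices, positions,
        k_cache, v_cache,
        *extra_args,
    )

The count returned by get_prepare_metadata_op changes from 4 to 6 for the same reason, presumably so the framework knows how many metadata tensors to route into each attention call.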

@@ -4,6 +4,8 @@ from typing import Dict, List, Optional, Tuple
import torch
from torch._prims_common import DeviceLikeType
from tensorrt_llm._utils import nvtx_range
from ...._utils import mpi_rank, mpi_world_size
from ....bindings.executor import ExecutorConfig
from ....bindings.internal.batch_manager import CacheType
@@ -136,6 +138,7 @@ class ADEngine(ModelEngine):
# start fresh with fixed seed
torch.manual_seed(1234)
@nvtx_range("ad_prepare_inputs")
def _prepare_inputs(
self, scheduled_requests: ScheduledRequests, resource_manager: ResourceManager
) -> bool:

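The engine-side change is independent of the flashinfer rework: it only wraps ADEngine._prepare_inputs in an NVTX range so input preparation shows up as a named span when profiling. A minimal sketch of the pattern, assuming nvtx_range from tensorrt_llm._utils is usable as a decorator taking a label string (which is exactly how the diff uses it); the class and method body here are placeholders:

from tensorrt_llm._utils import nvtx_range


class ExampleEngine:
    @nvtx_range("ad_prepare_inputs")  # appears as a named NVTX range in Nsight timelines
    def _prepare_inputs(self, scheduled_requests, resource_manager) -> bool:
        # placeholder body: the real method builds the batched model inputs here
        return True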

@@ -1,3 +1,4 @@
import flashinfer
import pytest
import torch
@@ -80,6 +81,13 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
_GlobalFlashInferPlanner.init_workspace(workspace)
batch_indices, positions = flashinfer.get_batch_indices_positions(
qo_indptr,
flashinfer.get_seq_lens(
paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
),
BATCH_SIZE * SEQ_LEN,
)
flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
# Q, K, V
q,
@@ -90,6 +98,8 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
paged_kv_indptr,
paged_kv_indices,
paged_kv_last_page_len,
batch_indices,
positions,
# CACHES
k_cache,
v_cache,
@@ -196,6 +206,13 @@ def test_flashinfer_attention_op_decode(
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
_GlobalFlashInferPlanner.init_workspace(workspace)
batch_indices, positions = flashinfer.get_batch_indices_positions(
qo_indptr,
flashinfer.get_seq_lens(
paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
),
BATCH_SIZE * SEQ_LEN,
)
flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
# Q, K, V
q,
@@ -206,6 +223,8 @@ def test_flashinfer_attention_op_decode(
paged_kv_indptr,
paged_kv_indices,
paged_kv_last_page_len,
batch_indices,
positions,
# CACHES
k_cache,
v_cache,
@@ -303,7 +322,13 @@ def test_flashinfer_attention_context_and_generate(
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
_GlobalFlashInferPlanner.init_workspace(workspace)
# Generate output
batch_indices, positions = flashinfer.get_batch_indices_positions(
qo_indptr,
flashinfer.get_seq_lens(
paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
),
BATCH_SIZE * PREFILL_SEQ_LEN,
)
flashinfer_output_1 = torch.ops.attention.flashinfer_mha_with_cache(
# Q, K, V
q_1,
@@ -314,6 +339,8 @@ def test_flashinfer_attention_context_and_generate(
paged_kv_indptr,
paged_kv_indices,
paged_kv_last_page_len,
batch_indices,
positions,
# CACHES
k_cache,
v_cache,
@@ -370,6 +397,13 @@ def test_flashinfer_attention_context_and_generate(
# Create FlashInferAttention class before calling the custom op
_GlobalFlashInferPlanner.reset()
batch_indices, positions = flashinfer.get_batch_indices_positions(
qo_indptr,
flashinfer.get_seq_lens(
paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
),
BATCH_SIZE * 1,
)
flashinfer_output_3 = torch.ops.attention.flashinfer_mha_with_cache(
# Q, K, V
q_3,
@@ -380,6 +414,8 @@ def test_flashinfer_attention_context_and_generate(
paged_kv_indptr,
paged_kv_indices,
paged_kv_last_page_len,
batch_indices,
positions,
# CACHES
k_cache,
v_cache,
@@ -415,7 +451,6 @@ def test_flashinfer_attention_context_and_generate(
)
@pytest.mark.skip(reason="https://nvbugspro.nvidia.com/bug/5095416")
@pytest.mark.parametrize(
"seq",
[
@@ -471,6 +506,13 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
_GlobalFlashInferPlanner.init_workspace(workspace)
batch_indices, positions = flashinfer.get_batch_indices_positions(
qo_indptr,
flashinfer.get_seq_lens(
paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
),
BATCH_SIZE * SEQ_LEN,
)
flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
# Q, K, V
q,
@@ -481,6 +523,8 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
paged_kv_indptr,
paged_kv_indices,
paged_kv_last_page_len,
batch_indices,
positions,
# CACHES
k_cache,
v_cache,
@@ -609,7 +653,14 @@ def test_flashinfer_attention_with_fp8_cache(
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
_GlobalFlashInferPlanner.init_workspace(workspace)
y = torch.ops.attention.flashinfer_mha_with_cache(
batch_indices, positions = flashinfer.get_batch_indices_positions(
qo_indptr,
flashinfer.get_seq_lens(
paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
),
BATCH_SIZE * SEQ_LEN,
)
flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
# Q, K, V
q,
k,
@@ -619,6 +670,8 @@ def test_flashinfer_attention_with_fp8_cache(
paged_kv_indptr,
paged_kv_indices,
paged_kv_last_page_len,
batch_indices,
positions,
# CACHES
k_cache,
v_cache,
@@ -630,7 +683,7 @@ def test_flashinfer_attention_with_fp8_cache(
V_SCALE,
)
y = y.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD)
y = flashinfer_output.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD)
q = q.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD)
ref = _attention_with_fp8_kv_cache(
@@ -697,6 +750,13 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
_GlobalFlashInferPlanner.init_workspace(workspace)
batch_indices, positions = flashinfer.get_batch_indices_positions(
qo_indptr,
flashinfer.get_seq_lens(
paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
),
BATCH_SIZE * SEQ_LEN,
)
flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
# Q, K, V
q,
@@ -707,6 +767,8 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
paged_kv_indptr,
paged_kv_indices,
paged_kv_last_page_len,
batch_indices,
positions,
# CACHES
k_cache,
v_cache,
@@ -771,6 +833,13 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
# Create FlashInferAttention class before calling the custom op
_GlobalFlashInferPlanner.reset()
batch_indices, positions = flashinfer.get_batch_indices_positions(
qo_indptr2,
flashinfer.get_seq_lens(
paged_kv_indptr2, paged_kv_last_page_len2, page_size=k_cache.shape[1]
),
BATCH_SIZE * 1,
)
flashinfer_output_gen = torch.ops.attention.flashinfer_mha_with_cache(
# Q, K, V
q_gen,
@@ -781,6 +850,8 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
paged_kv_indptr2,
paged_kv_indices2,
paged_kv_last_page_len2,
batch_indices,
positions,
# CACHES
k_cache,
v_cache,
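
Every updated call site in the tests repeats the same preamble to build batch_indices and positions before invoking the custom op. A hypothetical helper (not part of this commit) that would factor that out; the two flashinfer calls are copied from the test diff, and page_size comes from the cache layout (k_cache.shape[1]) exactly as the tests do:

import flashinfer
import torch


def make_batch_indices_positions(
    qo_indptr: torch.Tensor,
    paged_kv_indptr: torch.Tensor,
    paged_kv_last_page_len: torch.Tensor,
    k_cache: torch.Tensor,
    num_tokens: int,
):
    # num_tokens is the flattened token count, e.g. BATCH_SIZE * SEQ_LEN in the
    # context tests and BATCH_SIZE * 1 in the decode/generate tests.
    return flashinfer.get_batch_indices_positions(
        qo_indptr,
        flashinfer.get_seq_lens(
            paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
        ),
        num_tokens,
    )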