Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[AutoDeploy][perf] Further optimize flashinfer backend in AutoDeploy (#4024)
* reuse batch_indices, positions across layers
* fix flashinfer unit tests
* simplify call to get_batch_indices_positions
* fix call to get_batch_indices_positions

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
This commit is contained in:
parent 5c0f554b9e
commit ac2ab9ba36
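The whole optimization rests on one observation: for a given batch, the scatter indices that FlashInfer needs to append new K/V entries into the paged cache depend only on the batch layout, not on the layer, so they can be computed once per forward pass instead of once per attention layer. A minimal sketch of that shared computation, using only the flashinfer helpers that appear in the diff below (the wrapper function name and its argument list are illustrative, not part of the actual change):

import flashinfer


def shared_kv_append_indices(qo_indptr, paged_kv_indptr, paged_kv_last_page_len, page_size, nnz):
    # Per-sequence KV lengths implied by the paged layout (full pages plus the last, partial page).
    seq_lens = flashinfer.get_seq_lens(paged_kv_indptr, paged_kv_last_page_len, page_size)
    # batch_indices[i]: which sequence token i belongs to; positions[i]: its slot in that sequence.
    # Both depend only on the batch layout, so one call can serve every layer's cache append.
    return flashinfer.get_batch_indices_positions(qo_indptr, seq_lens, nnz)

In the diff, prepare_flashinfer_metadata performs this computation once (with nnz = position_ids.numel()) and returns the two tensors alongside the existing paged-KV metadata; flashinfer_mha_with_cache then consumes them instead of recomputing them in every layer.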
@@ -186,8 +186,23 @@ def prepare_flashinfer_metadata(
 
     paged_kv_last_page_len = ((offsets + seq_len - 1) % page_size) + 1
 
+    # Compute batch_indices and positions so that they can be reused for kv cache appends
+    # for all the layers
+    batch_indices, positions = flashinfer.get_batch_indices_positions(
+        qo_indptr,
+        flashinfer.get_seq_lens(paged_kv_indptr, paged_kv_last_page_len, page_size),
+        position_ids.numel(),
+    )
+
     # return metadata
-    return (qo_indptr, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len)
+    return (
+        qo_indptr,
+        paged_kv_indptr,
+        paged_kv_indices,
+        paged_kv_last_page_len,
+        batch_indices,
+        positions,
+    )
 
 
 @prepare_flashinfer_metadata.register_fake

@@ -195,11 +210,15 @@ def prepare_flashinfer_metadata_fake(
     input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, page_size
 ):
     qo_indptr = torch.empty(len(seq_len) + 1, dtype=seq_len.dtype, device=seq_len.device)
+    batch_indices = torch.empty_like(cache_loc)
+    positions = torch.empty_like(cache_loc)
     return (
         qo_indptr,  # qo_indptr
         torch.empty_like(qo_indptr),  # paged_kv_indptr
         torch.empty_like(cache_loc),  # paged_kv_indices
         torch.empty_like(seq_len),  # paged_kv_last_page_len
+        batch_indices,  # batch_indices
+        positions,  # positions
     )
 
 

@@ -214,6 +233,8 @@ def flashinfer_mha_with_cache(
     paged_kv_indptr: torch.Tensor,
     paged_kv_indices: torch.Tensor,
     paged_kv_last_page_len: torch.Tensor,
+    batch_indices: torch.Tensor,
+    positions: torch.Tensor,
     # CACHES
     k_cache: torch.Tensor,
     v_cache: torch.Tensor,

@@ -254,13 +275,6 @@ def flashinfer_mha_with_cache(
         k = (k / k_scale).to(torch.float8_e4m3fn)
         v = (v / v_scale).to(torch.float8_e4m3fn)
 
-    # Append to kv cache
-    batch_indices, positions = flashinfer.get_batch_indices_positions(
-        qo_indptr,
-        flashinfer.get_seq_lens(paged_kv_indptr, paged_kv_last_page_len, pp.page_size),
-        q.shape[0],
-    )
-
     flashinfer.page.append_paged_kv_cache(
         k,
         v,

@@ -296,6 +310,8 @@ def flashinfer_mha_with_cache_fake(
     paged_kv_indptr: torch.Tensor,
     paged_kv_indices: torch.Tensor,
     paged_kv_last_page_len: torch.Tensor,
+    batch_indices: torch.Tensor,
+    positions: torch.Tensor,
     # CACHES
     k_cache: torch.Tensor,
     v_cache: torch.Tensor,

@@ -341,7 +357,7 @@ class FlashInferAttention(AttentionDescriptor):
 
     @classmethod
     def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
-        return torch.ops.attention.prepare_flashinfer_metadata, 4
+        return torch.ops.attention.prepare_flashinfer_metadata, 6
 
     @classmethod
     def get_cache_initializers(
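Taken together, the hunks above change the metadata contract: prepare_flashinfer_metadata now returns six tensors instead of four, and get_prepare_metadata_op advertises 6 so the runtime threads the two extra tensors through to every attention call. A rough sketch of the resulting flow; the variable bindings are illustrative, and the argument names are taken from the fake registration shown above:

# Computed once per forward pass by the graph's metadata op ...
(
    qo_indptr,
    paged_kv_indptr,
    paged_kv_indices,
    paged_kv_last_page_len,
    batch_indices,
    positions,
) = torch.ops.attention.prepare_flashinfer_metadata(
    input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, page_size
)
# ... then batch_indices and positions are passed, unchanged, into
# torch.ops.attention.flashinfer_mha_with_cache for every layer, where they feed
# flashinfer.page.append_paged_kv_cache directly instead of being recomputed per layer.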
@@ -4,6 +4,8 @@ from typing import Dict, List, Optional, Tuple
 import torch
 from torch._prims_common import DeviceLikeType
 
+from tensorrt_llm._utils import nvtx_range
+
 from ...._utils import mpi_rank, mpi_world_size
 from ....bindings.executor import ExecutorConfig
 from ....bindings.internal.batch_manager import CacheType

@@ -136,6 +138,7 @@ class ADEngine(ModelEngine):
         # start fresh with fixed seed
         torch.manual_seed(1234)
 
+    @nvtx_range("ad_prepare_inputs")
     def _prepare_inputs(
         self, scheduled_requests: ScheduledRequests, resource_manager: ResourceManager
     ) -> bool:
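The engine-side change is pure instrumentation: an import of nvtx_range plus a decorator on _prepare_inputs so that input preparation shows up as a named range in Nsight traces. A minimal sketch of the same pattern (the class and method below are illustrative stand-ins, not code from the repository):

from tensorrt_llm._utils import nvtx_range


class DemoEngine:
    @nvtx_range("demo_prepare_inputs")  # range label as it appears in the profiler timeline
    def _prepare_inputs(self, scheduled_requests, resource_manager):
        # Body unchanged; the decorator only brackets the call with an NVTX push/pop.
        return True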
@@ -1,3 +1,4 @@
+import flashinfer
 import pytest
 import torch
 

@@ -80,6 +81,13 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
     workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
     _GlobalFlashInferPlanner.init_workspace(workspace)
 
+    batch_indices, positions = flashinfer.get_batch_indices_positions(
+        qo_indptr,
+        flashinfer.get_seq_lens(
+            paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
+        ),
+        BATCH_SIZE * SEQ_LEN,
+    )
     flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
         # Q, K, V
         q,

@@ -90,6 +98,8 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
         paged_kv_indptr,
         paged_kv_indices,
         paged_kv_last_page_len,
+        batch_indices,
+        positions,
         # CACHES
         k_cache,
         v_cache,

@@ -196,6 +206,13 @@ def test_flashinfer_attention_op_decode(
     workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
     _GlobalFlashInferPlanner.init_workspace(workspace)
 
+    batch_indices, positions = flashinfer.get_batch_indices_positions(
+        qo_indptr,
+        flashinfer.get_seq_lens(
+            paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
+        ),
+        BATCH_SIZE * SEQ_LEN,
+    )
     flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
         # Q, K, V
         q,

@@ -206,6 +223,8 @@ def test_flashinfer_attention_op_decode(
         paged_kv_indptr,
         paged_kv_indices,
         paged_kv_last_page_len,
+        batch_indices,
+        positions,
         # CACHES
         k_cache,
         v_cache,

@@ -303,7 +322,13 @@ def test_flashinfer_attention_context_and_generate(
     workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
     _GlobalFlashInferPlanner.init_workspace(workspace)
 
     # Generate output
+    batch_indices, positions = flashinfer.get_batch_indices_positions(
+        qo_indptr,
+        flashinfer.get_seq_lens(
+            paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
+        ),
+        BATCH_SIZE * PREFILL_SEQ_LEN,
+    )
     flashinfer_output_1 = torch.ops.attention.flashinfer_mha_with_cache(
         # Q, K, V
         q_1,

@@ -314,6 +339,8 @@ def test_flashinfer_attention_context_and_generate(
         paged_kv_indptr,
         paged_kv_indices,
         paged_kv_last_page_len,
+        batch_indices,
+        positions,
         # CACHES
         k_cache,
         v_cache,

@@ -370,6 +397,13 @@ def test_flashinfer_attention_context_and_generate(
     # Create FlashInferAttention class before calling the custom op
     _GlobalFlashInferPlanner.reset()
 
+    batch_indices, positions = flashinfer.get_batch_indices_positions(
+        qo_indptr,
+        flashinfer.get_seq_lens(
+            paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
+        ),
+        BATCH_SIZE * 1,
+    )
     flashinfer_output_3 = torch.ops.attention.flashinfer_mha_with_cache(
         # Q, K, V
         q_3,

@@ -380,6 +414,8 @@ def test_flashinfer_attention_context_and_generate(
         paged_kv_indptr,
         paged_kv_indices,
         paged_kv_last_page_len,
+        batch_indices,
+        positions,
         # CACHES
         k_cache,
         v_cache,

@@ -415,7 +451,6 @@ def test_flashinfer_attention_context_and_generate(
     )
 
 
-@pytest.mark.skip(reason="https://nvbugspro.nvidia.com/bug/5095416")
 @pytest.mark.parametrize(
     "seq",
     [

@@ -471,6 +506,13 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
     workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
     _GlobalFlashInferPlanner.init_workspace(workspace)
 
+    batch_indices, positions = flashinfer.get_batch_indices_positions(
+        qo_indptr,
+        flashinfer.get_seq_lens(
+            paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
+        ),
+        BATCH_SIZE * SEQ_LEN,
+    )
     flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
         # Q, K, V
         q,

@@ -481,6 +523,8 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
         paged_kv_indptr,
         paged_kv_indices,
         paged_kv_last_page_len,
+        batch_indices,
+        positions,
         # CACHES
         k_cache,
         v_cache,

@@ -609,7 +653,14 @@ def test_flashinfer_attention_with_fp8_cache(
     workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
     _GlobalFlashInferPlanner.init_workspace(workspace)
 
-    y = torch.ops.attention.flashinfer_mha_with_cache(
+    batch_indices, positions = flashinfer.get_batch_indices_positions(
+        qo_indptr,
+        flashinfer.get_seq_lens(
+            paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
+        ),
+        BATCH_SIZE * SEQ_LEN,
+    )
+    flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
         # Q, K, V
         q,
         k,

@@ -619,6 +670,8 @@ def test_flashinfer_attention_with_fp8_cache(
         paged_kv_indptr,
         paged_kv_indices,
         paged_kv_last_page_len,
+        batch_indices,
+        positions,
         # CACHES
         k_cache,
         v_cache,

@@ -630,7 +683,7 @@ def test_flashinfer_attention_with_fp8_cache(
         V_SCALE,
     )
 
-    y = y.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD)
+    y = flashinfer_output.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD)
     q = q.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD)
 
     ref = _attention_with_fp8_kv_cache(

@@ -697,6 +750,13 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
     workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
     _GlobalFlashInferPlanner.init_workspace(workspace)
 
+    batch_indices, positions = flashinfer.get_batch_indices_positions(
+        qo_indptr,
+        flashinfer.get_seq_lens(
+            paged_kv_indptr, paged_kv_last_page_len, page_size=k_cache.shape[1]
+        ),
+        BATCH_SIZE * SEQ_LEN,
+    )
     flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
         # Q, K, V
         q,

@@ -707,6 +767,8 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
         paged_kv_indptr,
         paged_kv_indices,
         paged_kv_last_page_len,
+        batch_indices,
+        positions,
         # CACHES
         k_cache,
         v_cache,

@@ -771,6 +833,13 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
     # Create FlashInferAttention class before calling the custom op
     _GlobalFlashInferPlanner.reset()
 
+    batch_indices, positions = flashinfer.get_batch_indices_positions(
+        qo_indptr2,
+        flashinfer.get_seq_lens(
+            paged_kv_indptr2, paged_kv_last_page_len2, page_size=k_cache.shape[1]
+        ),
+        BATCH_SIZE * 1,
+    )
     flashinfer_output_gen = torch.ops.attention.flashinfer_mha_with_cache(
         # Q, K, V
         q_gen,

@@ -781,6 +850,8 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
         paged_kv_indptr2,
         paged_kv_indices2,
         paged_kv_last_page_len2,
+        batch_indices,
+        positions,
         # CACHES
         k_cache,
         v_cache,