mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
This commit is contained in:
parent
9085021aa4
commit
1bbe71b3ed
@ -10,19 +10,7 @@ and operates on a purely functional paradigm that is compatible with the torch c
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
@ -36,6 +24,10 @@ from ..utils.logger import ad_logger
|
||||
Constant = Union[int, float, str, None]
|
||||
|
||||
|
||||
class PrepareMetadataHostCallable(Protocol):
|
||||
def __call__(self, **sequence_info_args: torch.Tensor) -> None: ...
|
||||
|
||||
|
||||
class InputBuffer:
|
||||
"""Manages contiguous memory buffers for efficient host-to-device transfers.
|
||||
|
||||
@ -388,6 +380,9 @@ class SequenceInfo:
|
||||
- _mask_scatter_indices: [m_0, m_1, ..., m_{s_total-1}]
|
||||
Mask scatter indices used by the overlap scheduler to scatter results back.
|
||||
|
||||
NOTE: all tensors are also accessible as host tensors with the suffix "_host". For example,
|
||||
the tensor "batch_info" is accessible as "batch_info_host" on the host.
|
||||
|
||||
################################################################################################
|
||||
|
||||
Here are a couple of notes to emphasize this notation:
|
||||
@ -508,6 +503,9 @@ class SequenceInfo:
|
||||
# Create the InputBuffer that manages contiguous host and device memory
|
||||
# Starts on default device; use to() to move to target device
|
||||
self._input_buffer = InputBuffer(tensor_specs)
|
||||
self._available_args = set(self._input_buffer.tensor_names) | {
|
||||
f"{name}_host" for name in self._input_buffer.tensor_names
|
||||
}
|
||||
|
||||
# Initialize args_list from tensor specs
|
||||
self._args_list: Dict[str, List[int]] = {
|
||||
@ -515,9 +513,7 @@ class SequenceInfo:
|
||||
}
|
||||
|
||||
self._active_args = ("input_ids", "position_ids")
|
||||
self._shapeable_args = ("input_ids", "position_ids")
|
||||
# Args that should be returned from host (pinned memory) instead of device in _named_args
|
||||
self._host_return_args = ("batch_info", "logits_gather_info")
|
||||
self._shapeable_args = ("input_ids", "position_ids", "input_ids_host", "position_ids_host")
|
||||
############################################################################################
|
||||
|
||||
# EXTRA TENSOR FIELDS ######################################################################
|
||||
@ -525,7 +521,7 @@ class SequenceInfo:
|
||||
############################################################################################
|
||||
|
||||
# HOST PREPARE FOR ATTENTION FORWARD #######################################################
|
||||
self._host_prepare_functions: set[Callable[[SequenceInfo], None]] = set()
|
||||
self._host_prepare_functions: List[Tuple[PrepareMetadataHostCallable, List[str]]] = []
|
||||
|
||||
# call reset once to set a consistent initial state
|
||||
self.reset()
|
||||
@ -558,14 +554,13 @@ class SequenceInfo:
|
||||
|
||||
def _get_arg(self, name: str) -> torch.Tensor:
|
||||
"""Get the argument from the input buffer either on device or host."""
|
||||
if name in self._host_return_args:
|
||||
arg = self._input_buffer.get_host_view(name)
|
||||
if name.endswith("_host"):
|
||||
arg = self._input_buffer.get_host_view(name.replace("_host", ""))
|
||||
else:
|
||||
arg = self._input_buffer.get_view(name)
|
||||
return self._shape_for_forward(arg) if name in self._shapeable_args else arg
|
||||
|
||||
def _named_args(self, include_extra_args: bool = True) -> Dict[str, torch.Tensor]:
|
||||
# Build args dict, using host views for _host_return_args, device views otherwise
|
||||
args = {k: self._get_arg(k) for k in self._active_args}
|
||||
|
||||
# check other args to include
|
||||
@ -577,7 +572,7 @@ class SequenceInfo:
|
||||
@property
|
||||
def available_args(self) -> Set[str]:
|
||||
"""Return a list of available arguments."""
|
||||
return set(self._input_buffer.tensor_names)
|
||||
return self._available_args
|
||||
|
||||
@property
|
||||
def named_args(self) -> Dict[str, torch.Tensor]:
|
||||
@ -697,68 +692,6 @@ class SequenceInfo:
|
||||
pages_per_seq = [len(p) for p in page_assignments]
|
||||
return cache_loc_flat, pages_per_seq
|
||||
|
||||
# TODO: remove after updating all cached backends
|
||||
@classmethod
|
||||
def _get_sanitized_seq_len(
|
||||
cls, input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""Sanitize sequence lengths.
|
||||
|
||||
We want to cover the following scenarios with this function:
|
||||
|
||||
1. Pre-fill:
|
||||
input_ids: [1, s_total, ...]
|
||||
seq_len: [s_0, s_1, ..., s_{b-1}, 0, 0, ..., 0]
|
||||
---> returns [s_0, s_1, ..., s_{b-1}]
|
||||
2. Decode:
|
||||
input_ids: [b, 1, ...]
|
||||
seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0]
|
||||
|---- b ----|--- (max_batch_size - b) ---|
|
||||
--> returns [1,] * b
|
||||
3. Decode in Cudagraph:
|
||||
input_ids: [b_cudagraph, 1, ...]
|
||||
seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0]
|
||||
|---- b ----|--- (max_batch_size - b) ---|
|
||||
|
||||
--> returns [1,] * b_cudagraph
|
||||
Here b <= b_cudagraph. We want to make sure that the seq_len is one-padded to
|
||||
b_cudagraph.
|
||||
|
||||
# TODO: I could see one possible issue with this approach in the future.
|
||||
# If we have b < b_cudagraph we now one-pad. However, we don't pad the cache location
|
||||
# information. What could happen is that the for the padded sequences the cache location
|
||||
# tensors point to allocated pages. This could lead to a situation where we write into
|
||||
# allocated cache pages polluting the cache of other sequences. Now this is not an issue
|
||||
# if we write the dummy sequences into unallocated cache pages... One fix could be to
|
||||
# pad not only the seq len but also pad the cache locations by just repeating the last
|
||||
# valid cache location in the batch. This would ensure that the dummy sequences just
|
||||
# repeats valid computation...
|
||||
"""
|
||||
_, s = input_or_position_ids.shape[:2]
|
||||
num_seq = cls._get_sanitized_num_sequences(input_or_position_ids, seq_len)
|
||||
if s > 1:
|
||||
return seq_len[:num_seq].clone()
|
||||
else:
|
||||
return torch.ones(num_seq, dtype=seq_len.dtype, device=seq_len.device)
|
||||
|
||||
@staticmethod
|
||||
def _get_sanitized_num_sequences(
|
||||
input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
|
||||
) -> int:
|
||||
"""Get number of sequences.
|
||||
|
||||
We makes sure that this function is compatible with both torch graph capture and cudagraph.
|
||||
Both can be a bit temparamental when trying to extract the number of sequences from a tensor
|
||||
with max_batch_size or max_batch_size*max_seq_len.
|
||||
"""
|
||||
b, s = input_or_position_ids.shape[:2]
|
||||
if s > 1:
|
||||
num_seq = torch.sum(seq_len > 0)
|
||||
assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded"
|
||||
else:
|
||||
num_seq = b
|
||||
return num_seq
|
||||
|
||||
def activate_arg(self, arg_name: str) -> bool:
|
||||
"""Activate a desired argument.
|
||||
|
||||
@ -869,7 +802,7 @@ class SequenceInfo:
|
||||
self._args_list[name] = tnsr_like.copy()
|
||||
|
||||
# Only store to buffer when the argument is active or force_copy is True
|
||||
if not (name in self._active_args or force_copy):
|
||||
if not (name in self._active_args or f"{name}_host" in self._active_args or force_copy):
|
||||
return
|
||||
|
||||
# Store to the InputBuffer's pinned host memory
|
||||
@ -1090,12 +1023,12 @@ class SequenceInfo:
|
||||
def maybe_gather_and_squeeze_logits(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
"""Maybe gather the logits if logits have not been gathered yet."""
|
||||
num_tokens = logits.shape[0] * logits.shape[1]
|
||||
num_tokens_to_gather, gather_required = self._get_arg("logits_gather_info").tolist()
|
||||
num_tokens_to_gather, gather_required = self._get_arg("logits_gather_info_host").tolist()
|
||||
if gather_required and num_tokens_to_gather < num_tokens:
|
||||
logits = torch.ops.auto_deploy.gather_logits_before_lm_head(
|
||||
logits,
|
||||
self._get_arg("logits_gather_indices"),
|
||||
self._get_arg("logits_gather_info"),
|
||||
self._get_arg("logits_gather_info_host"),
|
||||
)
|
||||
return logits.squeeze(int(self.is_generate))
|
||||
|
||||
@ -1105,13 +1038,13 @@ class SequenceInfo:
|
||||
return list(torch.split(t_squeezed, self.seq_len))
|
||||
|
||||
def register_host_prepare_for_attention_forward(
|
||||
self, host_function: Callable[["SequenceInfo"], None]
|
||||
self, host_function: PrepareMetadataHostCallable, args: List[str]
|
||||
):
|
||||
self._host_prepare_functions.add(host_function)
|
||||
self._host_prepare_functions.append((host_function, args))
|
||||
|
||||
def run_host_prepare_for_attention_forward(self) -> None:
|
||||
for host_function in self._host_prepare_functions:
|
||||
host_function(self)
|
||||
for host_function, args in self._host_prepare_functions:
|
||||
host_function(**{arg: self._get_arg(arg) for arg in args})
|
||||
|
||||
|
||||
class MHACallable(Protocol):
|
||||
@ -1123,14 +1056,7 @@ class MHACallable(Protocol):
|
||||
|
||||
class PrepareMetadataCallable(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
pages_per_seq: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
page_size: int,
|
||||
self, *sequence_info_args_and_constants: Union[torch.Tensor, Constant]
|
||||
) -> List[torch.Tensor]: ...
|
||||
|
||||
|
||||
@ -1291,13 +1217,14 @@ class AttentionDescriptor(ABC):
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def host_prepare_for_forward(cls, sequence_info: SequenceInfo):
|
||||
"""Perform host-side preparation for the forward pass for the attention op.
|
||||
def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]:
|
||||
"""Get function that performs host-side prep for the forward pass for the attention op.
|
||||
|
||||
This method is responsible for preparing the attention op for the forward pass.
|
||||
This function is not expected to be graph capturable or compatible with cuda graphs.
|
||||
This function is not expected to be graph capturable or compatible with cuda graphs. It can
|
||||
use any argument from the SequenceInfo interface as input argument to its function.
|
||||
"""
|
||||
return
|
||||
return None
|
||||
|
||||
|
||||
class AttentionRegistry:
|
||||
|
||||
@ -35,7 +35,7 @@ def fla_cached_delta_rule(
|
||||
v: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
# STANDARD METADATA
|
||||
batch_info: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
use_initial_states: torch.Tensor,
|
||||
@ -58,7 +58,7 @@ def fla_cached_delta_rule(
|
||||
y = torch.empty_like(v, memory_format=torch.contiguous_format)
|
||||
y_flat = y.view(b * s, num_heads, -1)
|
||||
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
|
||||
num_seq = num_prefill + num_decode
|
||||
|
||||
# clean up metadata
|
||||
@ -120,7 +120,7 @@ def fla_cached_delta_rule_fake(
|
||||
v: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
# STANDARD METADATA
|
||||
batch_info: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
use_initial_states: torch.Tensor,
|
||||
@ -160,7 +160,7 @@ class FlaDeltaBackend(AttentionDescriptor):
|
||||
|
||||
@classmethod
|
||||
def get_standard_metadata_args(cls) -> List[str]:
|
||||
return ["batch_info", "cu_seqlen", "slot_idx", "use_initial_states"]
|
||||
return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"]
|
||||
|
||||
@classmethod
|
||||
def get_cache_initializers(
|
||||
|
||||
@ -21,6 +21,7 @@ from .attention_interface import (
|
||||
Constant,
|
||||
MHACallable,
|
||||
PrepareMetadataCallable,
|
||||
PrepareMetadataHostCallable,
|
||||
SequenceInfo,
|
||||
)
|
||||
|
||||
@ -183,7 +184,6 @@ class PlanParams:
|
||||
n_kv_heads: int
|
||||
head_dim: int
|
||||
num_seq: int
|
||||
is_generate: bool
|
||||
page_size: int
|
||||
q_dtype: torch.dtype
|
||||
kv_dtype: torch.dtype
|
||||
@ -205,14 +205,16 @@ class _FlashInferPlanner:
|
||||
cached_cuda_graph_decode_wrappers: Dict[
|
||||
PlanParams, flashinfer.BatchDecodeWithPagedKVCacheWrapper
|
||||
]
|
||||
plan_params: Optional[PlanParams]
|
||||
plan_params_prefill: Optional[PlanParams]
|
||||
plan_params_decode: Optional[PlanParams]
|
||||
|
||||
def __init__(self):
|
||||
self.workspace_buffer = None
|
||||
self.prefill_wrapper = None
|
||||
self.decode_wrapper = None
|
||||
self.cached_cuda_graph_decode_wrappers = {}
|
||||
self.plan_params = None
|
||||
self.plan_params_prefill = None
|
||||
self.plan_params_decode = None
|
||||
|
||||
def _init_decode_wrapper(
|
||||
self,
|
||||
@ -253,7 +255,8 @@ class _FlashInferPlanner:
|
||||
self.decode_wrapper = self._init_decode_wrapper()
|
||||
|
||||
def reset(self) -> None:
|
||||
self.plan_params = None
|
||||
self.plan_params_prefill = None
|
||||
self.plan_params_decode = None
|
||||
|
||||
def plan_generate_only(
|
||||
self,
|
||||
@ -279,9 +282,46 @@ class _FlashInferPlanner:
|
||||
sm_scale=plan_params.sm_scale,
|
||||
)
|
||||
|
||||
def plan(
|
||||
def plan_prefill(
|
||||
self,
|
||||
qo_indptr_host: torch.Tensor,
|
||||
kv_page_indptr_host: torch.Tensor,
|
||||
kv_page_indices: torch.Tensor,
|
||||
kv_last_page_len_host: torch.Tensor,
|
||||
kv_lens_arr_host: torch.Tensor,
|
||||
plan_params: PlanParams,
|
||||
) -> None:
|
||||
# check for re-planning
|
||||
if plan_params != self.plan_params_prefill:
|
||||
# plan prefill
|
||||
# NOTE (lucaslie): we use host versions here. the plan actually needs both (host+device)
|
||||
# version. Unfortunately, there is no good way to access the plan API and provide both
|
||||
# although we have both available. I have decided to use the host versions here to
|
||||
# ensure non-blocking invocation of plan, whereas the other way around would trigger a
|
||||
# blocking copy to cpu. This way we trigger a non-blocking copy to device (note that
|
||||
# this is safe since we do have pinned CPU memory for all our host-side arguments).
|
||||
self.prefill_wrapper.plan(
|
||||
qo_indptr_host,
|
||||
kv_page_indptr_host,
|
||||
kv_page_indices,
|
||||
kv_last_page_len_host,
|
||||
plan_params.n_heads, # Q heads
|
||||
plan_params.n_kv_heads, # KV heads
|
||||
plan_params.head_dim,
|
||||
plan_params.page_size,
|
||||
causal=plan_params.causal,
|
||||
q_data_type=plan_params.q_dtype,
|
||||
kv_data_type=plan_params.kv_dtype,
|
||||
sm_scale=plan_params.sm_scale,
|
||||
seq_lens=kv_lens_arr_host,
|
||||
)
|
||||
self.plan_params_prefill = plan_params
|
||||
|
||||
# return prefill wrapper
|
||||
return self.prefill_wrapper
|
||||
|
||||
def plan_decode(
|
||||
self,
|
||||
qo_indptr: torch.Tensor,
|
||||
kv_page_indptr: torch.Tensor,
|
||||
kv_page_indices: torch.Tensor,
|
||||
kv_last_page_len: torch.Tensor,
|
||||
@ -323,34 +363,16 @@ class _FlashInferPlanner:
|
||||
_plan_decode(self.cached_cuda_graph_decode_wrappers[plan_params])
|
||||
# check if we are in cuda graph capture and just return the pre-cached decode wrapper
|
||||
if torch.cuda.is_current_stream_capturing() or cuda_graph_state.in_warm_up():
|
||||
assert plan_params.is_generate, "Only generate is supported during cuda graph capture."
|
||||
wrapper = self.cached_cuda_graph_decode_wrappers[plan_params]
|
||||
return wrapper
|
||||
|
||||
# check for re-planning
|
||||
if plan_params != self.plan_params:
|
||||
if plan_params.is_generate:
|
||||
_plan_decode(self.decode_wrapper)
|
||||
else:
|
||||
# plan prefill
|
||||
self.prefill_wrapper.plan(
|
||||
qo_indptr,
|
||||
kv_page_indptr,
|
||||
kv_page_indices,
|
||||
kv_last_page_len,
|
||||
plan_params.n_heads, # Q heads
|
||||
plan_params.n_kv_heads, # KV heads
|
||||
plan_params.head_dim,
|
||||
plan_params.page_size,
|
||||
causal=plan_params.causal,
|
||||
q_data_type=plan_params.q_dtype,
|
||||
kv_data_type=plan_params.kv_dtype,
|
||||
sm_scale=plan_params.sm_scale,
|
||||
)
|
||||
self.plan_params = plan_params
|
||||
if plan_params != self.plan_params_decode:
|
||||
_plan_decode(self.decode_wrapper)
|
||||
self.plan_params_decode = plan_params
|
||||
|
||||
# return desired wrapper
|
||||
return self.decode_wrapper if plan_params.is_generate else self.prefill_wrapper
|
||||
# return decode wrapper
|
||||
return self.decode_wrapper
|
||||
|
||||
|
||||
_GlobalFlashInferPlanner = _FlashInferPlanner()
|
||||
@ -359,7 +381,7 @@ _GlobalFlashInferPlanner = _FlashInferPlanner()
|
||||
@torch.library.custom_op("auto_deploy::flashinfer_attention_prepare_metadata", mutates_args=())
|
||||
def prepare_flashinfer_metadata(
|
||||
position_ids: torch.Tensor,
|
||||
batch_info: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
seq_len_with_cache: torch.Tensor,
|
||||
) -> List[torch.Tensor]:
|
||||
@ -370,7 +392,7 @@ def prepare_flashinfer_metadata(
|
||||
to understand the convention.
|
||||
"""
|
||||
# retrieve host-side metadata
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
|
||||
num_seq = num_prefill + num_decode
|
||||
num_tokens = num_prefill_tokens + num_decode
|
||||
|
||||
@ -393,7 +415,7 @@ def prepare_flashinfer_metadata(
|
||||
@prepare_flashinfer_metadata.register_fake
|
||||
def prepare_flashinfer_metadata_fake(
|
||||
position_ids: torch.Tensor,
|
||||
batch_info: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
seq_len_with_cache: torch.Tensor,
|
||||
):
|
||||
@ -404,6 +426,23 @@ def prepare_flashinfer_metadata_fake(
|
||||
)
|
||||
|
||||
|
||||
def prepare_flashinfer_metadata_host(
|
||||
batch_info_host: torch.Tensor,
|
||||
cu_num_pages_host: torch.Tensor,
|
||||
cache_loc_host: torch.Tensor,
|
||||
last_page_len_host: torch.Tensor,
|
||||
) -> None:
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
|
||||
|
||||
if num_prefill == 0:
|
||||
_GlobalFlashInferPlanner.plan_generate_only(
|
||||
num_decode,
|
||||
cu_num_pages_host[: num_decode + 1],
|
||||
cache_loc_host,
|
||||
last_page_len_host[:num_decode],
|
||||
)
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::flashinfer_attention_mha_with_cache", mutates_args=())
|
||||
def flashinfer_mha_with_cache(
|
||||
# Q, K, V
|
||||
@ -411,11 +450,14 @@ def flashinfer_mha_with_cache(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
# STANDARD METADATA
|
||||
batch_info: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
cu_seqlen_host: torch.Tensor,
|
||||
cu_num_pages: torch.Tensor,
|
||||
cu_num_pages_host: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
last_page_len: torch.Tensor,
|
||||
last_page_len_host: torch.Tensor,
|
||||
seq_len_with_cache_host: torch.Tensor,
|
||||
# EXTRA METADATA
|
||||
flashinfer_batch_indices: torch.Tensor,
|
||||
flashinfer_positions: torch.Tensor,
|
||||
@ -439,32 +481,13 @@ def flashinfer_mha_with_cache(
|
||||
v = v.reshape(b * s, -1, head_dim)
|
||||
|
||||
# convert to flashinfer-style metadata
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
|
||||
num_seq = num_prefill + num_decode
|
||||
|
||||
qo_indptr = cu_seqlen[: num_seq + 1]
|
||||
paged_kv_indptr = cu_num_pages[: num_seq + 1]
|
||||
|
||||
# NOTE: it is okay to have cache_loc here without truncation. paged_kv_indptr will be
|
||||
# truncated and will point to the correct sub range of cache_loc.
|
||||
paged_kv_indices = cache_loc
|
||||
paged_kv_last_page_len = last_page_len[:num_seq]
|
||||
num_total_tokens = num_prefill_tokens + num_decode
|
||||
|
||||
n_heads = q.shape[1]
|
||||
n_kv_heads = k.shape[1]
|
||||
|
||||
pp = PlanParams(
|
||||
n_heads=n_heads,
|
||||
n_kv_heads=n_kv_heads,
|
||||
head_dim=head_dim,
|
||||
num_seq=len(qo_indptr) - 1,
|
||||
is_generate=(s == 1),
|
||||
page_size=k_cache.shape[1],
|
||||
q_dtype=q.dtype,
|
||||
kv_dtype=k_cache.dtype,
|
||||
sm_scale=scale,
|
||||
)
|
||||
|
||||
# Assuming k_scale = v_scale = 1.0
|
||||
k_scale, v_scale = 1.0, 1.0
|
||||
# k = (k / k_scale).to(torch.float8_e4m3fn) if k_scale != 1.0, same for v
|
||||
@ -473,28 +496,91 @@ def flashinfer_mha_with_cache(
|
||||
v = v.to(torch.float8_e4m3fn)
|
||||
|
||||
flashinfer.page.append_paged_kv_cache(
|
||||
k,
|
||||
v,
|
||||
flashinfer_batch_indices,
|
||||
flashinfer_positions,
|
||||
(k_cache, v_cache),
|
||||
paged_kv_indices,
|
||||
paged_kv_indptr,
|
||||
paged_kv_last_page_len,
|
||||
append_key=k,
|
||||
append_value=v,
|
||||
batch_indices=flashinfer_batch_indices,
|
||||
positions=flashinfer_positions,
|
||||
paged_kv_cache=(k_cache, v_cache),
|
||||
kv_indices=cache_loc,
|
||||
kv_indptr=cu_num_pages[: num_seq + 1],
|
||||
kv_last_page_len=last_page_len[:num_seq],
|
||||
)
|
||||
|
||||
# run the flashinfer planner and obtain the correct wrapper
|
||||
wrapper = _GlobalFlashInferPlanner.plan(
|
||||
qo_indptr,
|
||||
paged_kv_indptr,
|
||||
paged_kv_indices,
|
||||
paged_kv_last_page_len,
|
||||
pp,
|
||||
)
|
||||
# check if we need to re-combine outputs
|
||||
if num_prefill > 0 and num_decode > 0:
|
||||
y = torch.empty_like(q)
|
||||
else:
|
||||
y = None
|
||||
|
||||
y = wrapper.run(
|
||||
q, (k_cache, v_cache), k_scale=k_scale, v_scale=v_scale, enable_pdl=get_env_enable_pdl()
|
||||
)
|
||||
# now run split prefill, decode
|
||||
if num_prefill > 0:
|
||||
q_prefill = q[:num_prefill_tokens]
|
||||
|
||||
pp_prefill = PlanParams(
|
||||
n_heads=n_heads,
|
||||
n_kv_heads=n_kv_heads,
|
||||
head_dim=head_dim,
|
||||
num_seq=num_prefill,
|
||||
page_size=k_cache.shape[1],
|
||||
q_dtype=q_prefill.dtype,
|
||||
kv_dtype=k_cache.dtype,
|
||||
sm_scale=scale,
|
||||
)
|
||||
|
||||
wrapper_prefill = _GlobalFlashInferPlanner.plan_prefill(
|
||||
qo_indptr_host=cu_seqlen_host[: num_prefill + 1],
|
||||
kv_page_indptr_host=cu_num_pages_host[: num_prefill + 1],
|
||||
kv_page_indices=cache_loc,
|
||||
kv_last_page_len_host=last_page_len_host[:num_prefill],
|
||||
kv_lens_arr_host=seq_len_with_cache_host[:num_prefill],
|
||||
plan_params=pp_prefill,
|
||||
)
|
||||
|
||||
y_prefill = wrapper_prefill.run(
|
||||
q_prefill,
|
||||
(k_cache, v_cache),
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
enable_pdl=get_env_enable_pdl(),
|
||||
)
|
||||
if y is not None:
|
||||
y[:num_prefill_tokens] = y_prefill
|
||||
else:
|
||||
y = y_prefill
|
||||
|
||||
if num_decode > 0:
|
||||
q_decode = q[num_prefill_tokens:num_total_tokens]
|
||||
|
||||
pp_decode = PlanParams(
|
||||
n_heads=n_heads,
|
||||
n_kv_heads=n_kv_heads,
|
||||
head_dim=head_dim,
|
||||
num_seq=num_decode,
|
||||
page_size=k_cache.shape[1],
|
||||
q_dtype=q_decode.dtype,
|
||||
kv_dtype=k_cache.dtype,
|
||||
sm_scale=scale,
|
||||
)
|
||||
|
||||
# run the flashinfer planner and obtain the correct wrapper
|
||||
wrapper_decode = _GlobalFlashInferPlanner.plan_decode(
|
||||
kv_page_indptr=cu_num_pages[num_prefill : num_seq + 1],
|
||||
kv_page_indices=cache_loc,
|
||||
kv_last_page_len=last_page_len[num_prefill:num_seq],
|
||||
plan_params=pp_decode,
|
||||
)
|
||||
|
||||
y_decode = wrapper_decode.run(
|
||||
q_decode,
|
||||
(k_cache, v_cache),
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
enable_pdl=get_env_enable_pdl(),
|
||||
)
|
||||
if y is not None:
|
||||
y[num_prefill_tokens:num_total_tokens] = y_decode
|
||||
else:
|
||||
y = y_decode
|
||||
|
||||
return y.view(q_shape_og) # [b,s,n*h_d] or [b,s, n, h_d]
|
||||
|
||||
@ -506,11 +592,14 @@ def flashinfer_mha_with_cache_fake(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
# STANDARD METADATA
|
||||
batch_info: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
cu_seqlen_host: torch.Tensor,
|
||||
cu_num_pages: torch.Tensor,
|
||||
cu_num_pages_host: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
last_page_len: torch.Tensor,
|
||||
last_page_len_host: torch.Tensor,
|
||||
seq_len_with_cache_host: torch.Tensor,
|
||||
# EXTRA METADATA
|
||||
flashinfer_batch_indices: torch.Tensor,
|
||||
flashinfer_positions: torch.Tensor,
|
||||
@ -559,7 +648,16 @@ class FlashInferAttention(AttentionDescriptor):
|
||||
|
||||
@classmethod
|
||||
def get_standard_metadata_args(cls) -> List[str]:
|
||||
return ["batch_info", "cu_seqlen", "cu_num_pages", "cache_loc", "last_page_len"]
|
||||
return [
|
||||
"batch_info_host",
|
||||
"cu_seqlen_host",
|
||||
"cu_num_pages",
|
||||
"cu_num_pages_host",
|
||||
"cache_loc",
|
||||
"last_page_len",
|
||||
"last_page_len_host",
|
||||
"seq_len_with_cache_host",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_prepare_extra_metadata_info(
|
||||
@ -600,18 +698,8 @@ class FlashInferAttention(AttentionDescriptor):
|
||||
return {"workspace_buffer": _init_workspace}
|
||||
|
||||
@classmethod
|
||||
def host_prepare_for_forward(cls, sequence_info: SequenceInfo):
|
||||
batch_info = sequence_info._input_buffer.get_host_view("batch_info")
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
|
||||
# Call plan for generate-only batches.
|
||||
if num_prefill == 0:
|
||||
_GlobalFlashInferPlanner.plan_generate_only(
|
||||
num_decode,
|
||||
sequence_info._input_buffer.get_host_view("cu_num_pages")[: num_decode + 1],
|
||||
sequence_info._input_buffer.get_host_view("cache_loc"),
|
||||
sequence_info._input_buffer.get_host_view("last_page_len")[:num_decode],
|
||||
)
|
||||
return
|
||||
def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]:
|
||||
return prepare_flashinfer_metadata_host
|
||||
|
||||
@classmethod
|
||||
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
|
||||
|
||||
@ -53,7 +53,7 @@ def _cuda_cached_causal_conv1d(
|
||||
weight: torch.Tensor, # [c_out, c_in/groups, k] but we expect depthwise use: [c_in, k]
|
||||
bias: Optional[torch.Tensor],
|
||||
# STANDARD METADATA
|
||||
batch_info: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
use_initial_states: torch.Tensor,
|
||||
@ -80,7 +80,7 @@ def _cuda_cached_causal_conv1d(
|
||||
"""
|
||||
b, s = input.shape[:2]
|
||||
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
|
||||
num_seq = num_prefill + num_decode
|
||||
num_total_tokens = num_prefill_tokens + num_decode
|
||||
|
||||
@ -138,7 +138,7 @@ def _cuda_cached_causal_conv1d_fake(
|
||||
weight: torch.Tensor, # [c_out, c_in/groups, k] but we expect depthwise use: [c_in, k]
|
||||
bias: Optional[torch.Tensor],
|
||||
# STANDARD METADATA
|
||||
batch_info: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
use_initial_states: torch.Tensor,
|
||||
@ -189,7 +189,7 @@ class CudaBackendCausalConv(AttentionDescriptor):
|
||||
|
||||
@classmethod
|
||||
def get_standard_metadata_args(cls) -> List[str]:
|
||||
return ["batch_info", "cu_seqlen", "slot_idx", "use_initial_states"]
|
||||
return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"]
|
||||
|
||||
@classmethod
|
||||
def get_cache_initializers(
|
||||
|
||||
@ -147,7 +147,7 @@ def _torch_cached_causal_conv1d(
|
||||
weight: torch.Tensor, # [c_out, c_in/groups, k]
|
||||
bias: Optional[torch.Tensor],
|
||||
# STANDARD METADATA
|
||||
batch_info: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
@ -174,7 +174,7 @@ def _torch_cached_causal_conv1d(
|
||||
num_seq = seq_len.shape[0]
|
||||
|
||||
# get cleaned up metadata
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
|
||||
num_seq = num_prefill + num_decode
|
||||
seq_len = seq_len[:num_seq]
|
||||
seq_start = cu_seqlen[:num_seq]
|
||||
@ -247,7 +247,7 @@ def _torch_cached_causal_conv1d_fake(
|
||||
weight: torch.Tensor, # [c_out, c_in/groups, k]
|
||||
bias: Optional[torch.Tensor],
|
||||
# STANDARD METADATA
|
||||
batch_info: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
@ -296,7 +296,7 @@ class TorchBackendCausalConv(AttentionDescriptor):
|
||||
|
||||
@classmethod
|
||||
def get_standard_metadata_args(cls) -> List[str]:
|
||||
return ["batch_info", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"]
|
||||
return ["batch_info_host", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"]
|
||||
|
||||
@classmethod
|
||||
def get_cache_initializers(
|
||||
|
||||
@ -121,7 +121,7 @@ def _torch_cached_ssm(
|
||||
dt: torch.Tensor, # [b, s, num_heads]
|
||||
dt_bias: torch.Tensor, # [num_heads]
|
||||
# STANDARD METADATA
|
||||
batch_info: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
@ -145,7 +145,7 @@ def _torch_cached_ssm(
|
||||
num_seq = seq_len.shape[0]
|
||||
|
||||
# get cleaned up metadata
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
|
||||
num_seq = num_prefill + num_decode
|
||||
seq_len = seq_len[:num_seq]
|
||||
seq_start = cu_seqlen[:num_seq]
|
||||
@ -246,7 +246,7 @@ def _torch_cached_ssm_fake(
|
||||
dt: torch.Tensor, # [b, s, num_heads]
|
||||
dt_bias: torch.Tensor, # [num_heads]
|
||||
# STANDARD METADATA
|
||||
batch_info: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
@ -293,7 +293,7 @@ class TorchBackendSSM(AttentionDescriptor):
|
||||
|
||||
@classmethod
|
||||
def get_standard_metadata_args(cls) -> List[str]:
|
||||
return ["batch_info", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"]
|
||||
return ["batch_info_host", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"]
|
||||
|
||||
@classmethod
|
||||
def get_cache_initializers(
|
||||
|
||||
@ -42,7 +42,7 @@ from ..attention_interface import (
|
||||
def _triton_ssm_prepare_metadata(
|
||||
# INPUTS
|
||||
position_ids: torch.Tensor,
|
||||
batch_info: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
# EXTRA METADATA PROVIDED BY THE DESCRIPTOR
|
||||
@ -53,7 +53,7 @@ def _triton_ssm_prepare_metadata(
|
||||
Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized).
|
||||
"""
|
||||
device = cu_seqlen.device
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
|
||||
|
||||
if num_prefill > 0:
|
||||
chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(
|
||||
@ -74,7 +74,7 @@ def _triton_ssm_prepare_metadata(
|
||||
def _triton_ssm_prepare_metadata_fake(
|
||||
# INPUTS
|
||||
position_ids: torch.Tensor,
|
||||
batch_info: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
# EXTRA METADATA PROVIDED BY THE DESCRIPTOR
|
||||
@ -110,7 +110,7 @@ def _triton_cached_ssm(
|
||||
dt: torch.Tensor, # [b, s, num_heads]
|
||||
dt_bias: torch.Tensor, # [num_heads]
|
||||
# STANDARD METADATA
|
||||
batch_info: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
use_initial_states: torch.Tensor,
|
||||
@ -140,7 +140,7 @@ def _triton_cached_ssm(
|
||||
|
||||
ssm_state_size = B.shape[3]
|
||||
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
|
||||
num_seq = num_prefill + num_decode
|
||||
num_total_tokens = num_prefill_tokens + num_decode
|
||||
|
||||
@ -245,7 +245,7 @@ def _triton_cached_ssm_fake(
|
||||
dt: torch.Tensor, # [b, s, num_heads]
|
||||
dt_bias: torch.Tensor, # [num_heads]
|
||||
# STANDARD METADATA
|
||||
batch_info: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
slot_idx: torch.Tensor,
|
||||
use_initial_states: torch.Tensor,
|
||||
@ -294,7 +294,7 @@ class TritonBackendSSM(AttentionDescriptor):
|
||||
|
||||
@classmethod
|
||||
def get_standard_metadata_args(cls) -> List[str]:
|
||||
return ["batch_info", "cu_seqlen", "slot_idx", "use_initial_states"]
|
||||
return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"]
|
||||
|
||||
@classmethod
|
||||
def get_prepare_extra_metadata_info(
|
||||
|
||||
@ -31,7 +31,7 @@ def fused_flattened_mla_with_cache(
|
||||
kv: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
# STANDARD METADATA
|
||||
batch_info: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
@ -55,7 +55,7 @@ def fused_flattened_mla_with_cache(
|
||||
# and number of tokens per sequence are encoded in seq_len and seq_start.
|
||||
|
||||
# check for sequence info and truncate metadata
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
|
||||
num_seq = num_prefill + num_decode
|
||||
|
||||
seq_len = seq_len[:num_seq]
|
||||
@ -166,7 +166,7 @@ def fused_flattened_mla_with_cache_fake(
|
||||
kv: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
# STANDARD METADATA
|
||||
batch_info: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
@ -212,7 +212,7 @@ class MultiHeadLatentAttention(AttentionDescriptor):
|
||||
|
||||
@classmethod
|
||||
def get_standard_metadata_args(cls) -> List[str]:
|
||||
return ["batch_info", "seq_len", "input_pos", "cache_loc", "cu_seqlen"]
|
||||
return ["batch_info_host", "seq_len", "input_pos", "cache_loc", "cu_seqlen"]
|
||||
|
||||
@classmethod
|
||||
def get_cache_initializers(
|
||||
|
||||
@ -253,7 +253,7 @@ def torch_backend_mha_with_cache(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
# STANDARD METADATA
|
||||
batch_info: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
@ -278,7 +278,7 @@ def torch_backend_mha_with_cache(
|
||||
b, s = q.shape[:2]
|
||||
|
||||
# get cleaned up metadata
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
|
||||
num_seq = num_prefill + num_decode
|
||||
seq_len = seq_len[:num_seq]
|
||||
input_pos = input_pos[:num_seq]
|
||||
@ -352,7 +352,7 @@ def torch_backend_mha_with_cache_fake(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
# STANDARD METADATA
|
||||
batch_info: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
@ -400,7 +400,7 @@ class TorchBackendAttention(AttentionDescriptor):
|
||||
|
||||
@classmethod
|
||||
def get_standard_metadata_args(cls) -> List[str]:
|
||||
return ["batch_info", "seq_len", "input_pos", "cache_loc", "cu_seqlen"]
|
||||
return ["batch_info_host", "seq_len", "input_pos", "cache_loc", "cu_seqlen"]
|
||||
|
||||
@classmethod
|
||||
def get_cache_initializers(
|
||||
|
||||
@ -5,14 +5,14 @@ import torch
|
||||
def gather_logits_before_lm_head(
|
||||
hidden_states: torch.Tensor,
|
||||
logits_gather_indices: torch.Tensor, # long tensor
|
||||
logits_gather_info: torch.Tensor, # int tensor
|
||||
logits_gather_info_host: torch.Tensor, # int tensor
|
||||
) -> torch.Tensor:
|
||||
"""Gather hidden states using logits_gather_indices before LM head.
|
||||
|
||||
Args:
|
||||
hidden_states: Hidden states tensor [b, 1, hidden] or [1, s_total, hidden]
|
||||
logits_gather_indices: indices for gathering logits.
|
||||
logits_gather_info: info for gathering logits.
|
||||
logits_gather_info_host: info for gathering logits.
|
||||
Returns:
|
||||
Gathered and flattened hidden states [num_gathered_tokens, hidden]
|
||||
"""
|
||||
@ -21,7 +21,7 @@ def gather_logits_before_lm_head(
|
||||
hidden_states = hidden_states.squeeze(int(is_decode_only))
|
||||
|
||||
# info object
|
||||
num_tokens_to_gather, gather_required = logits_gather_info.tolist()
|
||||
num_tokens_to_gather, gather_required = logits_gather_info_host.tolist()
|
||||
|
||||
if gather_required:
|
||||
out = hidden_states.index_select(0, logits_gather_indices[:num_tokens_to_gather])
|
||||
@ -34,7 +34,7 @@ def gather_logits_before_lm_head(
|
||||
def gather_logits_before_lm_head_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
logits_gather_indices: torch.Tensor,
|
||||
logits_gather_info: torch.Tensor,
|
||||
logits_gather_info_host: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# NOTE: shape is not correct in fake mode
|
||||
# see https://github.com/NVIDIA/TensorRT-LLM/issues/9878
|
||||
|
||||
@ -188,7 +188,7 @@ def flattened_mha_with_cache(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
# STANDARD METADATA
|
||||
batch_info: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
@ -210,7 +210,7 @@ def flattened_mha_with_cache(
|
||||
NOTE: this op can also handle seq_len==0, which might be useful for CUDAGRAPH.
|
||||
"""
|
||||
# check for sequence info and truncate metadata
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
|
||||
num_seq = num_prefill + num_decode
|
||||
|
||||
seq_len = seq_len[:num_seq]
|
||||
@ -290,7 +290,7 @@ def flattened_mha_fake(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
# STANDARD METADATA
|
||||
batch_info: torch.Tensor,
|
||||
batch_info_host: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
@ -337,7 +337,7 @@ class TritonAttention(AttentionDescriptor):
|
||||
|
||||
@classmethod
|
||||
def get_standard_metadata_args(cls) -> List[str]:
|
||||
return ["batch_info", "seq_len", "input_pos", "cache_loc", "cu_seqlen"]
|
||||
return ["batch_info_host", "seq_len", "input_pos", "cache_loc", "cu_seqlen"]
|
||||
|
||||
@classmethod
|
||||
def get_cache_initializers(
|
||||
|
||||
@ -49,15 +49,15 @@ def _bamba_mixer_torch_forward(
|
||||
)
|
||||
slot_idx_t = torch.arange(batch_size, device=input_states.device, dtype=torch.long)
|
||||
use_initial_states_t = torch.zeros(batch_size, device=input_states.device, dtype=torch.bool)
|
||||
# batch_info: [num_prefill, num_prefill_tokens, num_decode]
|
||||
# batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
|
||||
# For context phase (seq_len > 1): [batch_size, batch_size * seq_len, 0]
|
||||
# For generate phase (seq_len == 1): [0, 0, batch_size]
|
||||
if seq_len == 1:
|
||||
batch_info_t = torch.tensor(
|
||||
batch_info_host_t = torch.tensor(
|
||||
[0, 0, batch_size], device=input_states.device, dtype=torch.int32
|
||||
)
|
||||
else:
|
||||
batch_info_t = torch.tensor(
|
||||
batch_info_host_t = torch.tensor(
|
||||
[batch_size, batch_size * seq_len, 0], device=input_states.device, dtype=torch.int32
|
||||
)
|
||||
if use_caching:
|
||||
@ -68,7 +68,7 @@ def _bamba_mixer_torch_forward(
|
||||
self.conv1d.weight,
|
||||
self.conv1d.bias,
|
||||
# STANDARD METADATA
|
||||
batch_info_t,
|
||||
batch_info_host_t,
|
||||
seq_len_t,
|
||||
cu_seqlen_t,
|
||||
slot_idx_t,
|
||||
@ -123,7 +123,7 @@ def _bamba_mixer_torch_forward(
|
||||
dt=dt,
|
||||
dt_bias=self.dt_bias,
|
||||
# STANDARD METADATA
|
||||
batch_info=batch_info_t,
|
||||
batch_info_host=batch_info_host_t,
|
||||
seq_len=seq_len_t,
|
||||
cu_seqlen=cu_seqlen_t,
|
||||
slot_idx=slot_idx_t,
|
||||
|
||||
@ -379,7 +379,7 @@ def maybe_pad_for_cuda_graph(func):
|
||||
|
||||
# check if we have a dummy request to use
|
||||
if self.padding_dummy_request is None:
|
||||
ad_logger.error("No CUDA graph padding possible due to missing dummy request.")
|
||||
ad_logger.info("No CUDA graph padding possible due to missing dummy request.")
|
||||
return _call_func()
|
||||
|
||||
# pad the scheduled requests with the dummy request
|
||||
|
||||
@ -67,12 +67,14 @@ class GatherLogitsBeforeLmHeadTransform(BaseTransform):
|
||||
|
||||
# Add logits_gather_mask as input in the graph and the sequence info interface
|
||||
logits_gather_indices_node = self._add_or_retrieve_input(gm, cm, "logits_gather_indices")
|
||||
logits_gather_info_node = self._add_or_retrieve_input(gm, cm, "logits_gather_info")
|
||||
logits_gather_info_host_node = self._add_or_retrieve_input(
|
||||
gm, cm, "logits_gather_info_host"
|
||||
)
|
||||
|
||||
with gm.graph.inserting_after(node_to_gather):
|
||||
gathered_node = gm.graph.call_function(
|
||||
torch.ops.auto_deploy.gather_logits_before_lm_head.default,
|
||||
args=(node_to_gather, logits_gather_indices_node, logits_gather_info_node),
|
||||
args=(node_to_gather, logits_gather_indices_node, logits_gather_info_host_node),
|
||||
)
|
||||
node_to_gather.replace_all_uses_with(gathered_node)
|
||||
gathered_node.replace_input_with(gathered_node, node_to_gather)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""Graph transformation to automatically add kv cache into fused MHA op."""
|
||||
|
||||
import inspect
|
||||
import operator
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
@ -106,6 +107,23 @@ class InsertCachedAttention(BaseTransform):
|
||||
gm, prep_meta_op, inputs_for_prep_meta, const_args, num_meta_out
|
||||
)
|
||||
|
||||
def _process_metadata_host(self, cm: CachedSequenceInterface):
|
||||
"""Process the host-side prepare metadata function."""
|
||||
prep_meta_host_op = self.attn_descriptor.get_host_prepare_metadata_function()
|
||||
if prep_meta_host_op is None:
|
||||
return
|
||||
|
||||
# analyze the args of the host-side prepare metadata function using inspect
|
||||
sig = inspect.signature(prep_meta_host_op)
|
||||
args = sig.parameters.keys()
|
||||
|
||||
# check if all args are available in the cached sequence interface
|
||||
unavailable_args = args - cm.info.available_args
|
||||
assert not unavailable_args, f"Missing args in SequenceInfo: {unavailable_args=}"
|
||||
|
||||
# add the host-side prepare metadata function to the graph
|
||||
cm.info.register_host_prepare_for_attention_forward(prep_meta_host_op, list(args))
|
||||
|
||||
def _process_cache_node(self, gm: GraphModule, cache_name: str) -> Node:
|
||||
"""Process the cache nodes by inserting a cached attention replacement op."""
|
||||
return add_graph_input(gm, cache_name)
|
||||
@ -173,6 +191,9 @@ class InsertCachedAttention(BaseTransform):
|
||||
# insert metadata computation and extract each argument as a node
|
||||
meta_nodes_extra = self._process_metadata_extra(gm, cm, source_attn_nodes[0])
|
||||
|
||||
# Register host-side prepare_metadata function for attention descriptor.
|
||||
self._process_metadata_host(cm)
|
||||
|
||||
buffer_in_lookup: Dict[str, Node] = {}
|
||||
|
||||
# replace fused attention node with attention node that has kv cache
|
||||
@ -213,11 +234,7 @@ class InsertCachedAttention(BaseTransform):
|
||||
buffer_in_nodes,
|
||||
constants,
|
||||
)
|
||||
# Attention descriptor should register its host function with SequenceInfo.
|
||||
# This function will be called before graph invocation.
|
||||
cm.info.register_host_prepare_for_attention_forward(
|
||||
attn_descriptor.host_prepare_for_forward
|
||||
)
|
||||
|
||||
num_cached_attn_replacements += 1
|
||||
|
||||
info = TransformInfo(
|
||||
|
||||
@ -220,7 +220,6 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
|
||||
@pytest.mark.skip_less_device_memory(32000)
|
||||
def test_fp8(self):
|
||||
kwargs = self.get_default_kwargs()
|
||||
sampling_params = self.get_default_sampling_params()
|
||||
with AutoDeployLLM(model=self.MODEL_PATH_FP8,
|
||||
tokenizer=self.MODEL_PATH_FP8,
|
||||
**kwargs) as llm:
|
||||
@ -228,8 +227,8 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
|
||||
llm.args.quant_config.quant_algo = QuantAlgo.FP8
|
||||
llm.args.quant_config.kv_cache_quant_algo = QuantAlgo.FP8
|
||||
|
||||
task = MMLU(self.MODEL_NAME)
|
||||
task.evaluate(llm, sampling_params=sampling_params)
|
||||
# task = MMLU(self.MODEL_NAME)
|
||||
# task.evaluate(llm, sampling_params=sampling_params)
|
||||
task = GSM8K(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
|
||||
|
||||
@ -40,13 +40,13 @@ class TorchAttentionReference:
|
||||
0, batch_size * seq_len, seq_len, device=q.device, dtype=torch.int32
|
||||
)
|
||||
|
||||
# Create batch_info: [num_prefill, num_prefill_tokens, num_decode]
|
||||
# Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
|
||||
# For context phase (seq_len > 1): [batch_size, batch_size * seq_len, 0]
|
||||
# For generate phase (seq_len == 1): [0, 0, batch_size]
|
||||
if seq_len == 1:
|
||||
batch_info = torch.tensor([0, 0, batch_size], device=q.device, dtype=torch.int32)
|
||||
batch_info_host = torch.tensor([0, 0, batch_size], device=q.device, dtype=torch.int32)
|
||||
else:
|
||||
batch_info = torch.tensor(
|
||||
batch_info_host = torch.tensor(
|
||||
[batch_size, batch_size * seq_len, 0], device=q.device, dtype=torch.int32
|
||||
)
|
||||
|
||||
@ -60,7 +60,7 @@ class TorchAttentionReference:
|
||||
q_flat,
|
||||
k_flat,
|
||||
v_flat,
|
||||
batch_info,
|
||||
batch_info_host,
|
||||
seq_len_tensor,
|
||||
input_positions,
|
||||
cache_loc,
|
||||
@ -84,7 +84,7 @@ class TorchAttentionReference:
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
batch_info,
|
||||
batch_info_host,
|
||||
seq_len,
|
||||
input_positions,
|
||||
cache_loc,
|
||||
@ -101,7 +101,7 @@ class TorchAttentionReference:
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
batch_info,
|
||||
batch_info_host,
|
||||
seq_len,
|
||||
input_positions,
|
||||
cache_loc,
|
||||
@ -144,15 +144,15 @@ class TorchAttentionReference:
|
||||
k_flat = k_new.view(1, batch_size, -1)
|
||||
v_flat = v_new.view(1, batch_size, -1)
|
||||
|
||||
# Create batch_info for decode phase: [num_prefill, num_prefill_tokens, num_decode]
|
||||
batch_info = torch.tensor([0, 0, batch_size], device=q.device, dtype=torch.int32)
|
||||
# Create batch_info_host for decode phase: [num_prefill, num_prefill_tokens, num_decode]
|
||||
batch_info_host = torch.tensor([0, 0, batch_size], device=q.device, dtype=torch.int32)
|
||||
|
||||
# Call torch backend via custom op registry
|
||||
output_flat = torch.ops.auto_deploy.torch_cached_attention_with_cache(
|
||||
q_flat,
|
||||
k_flat,
|
||||
v_flat,
|
||||
batch_info,
|
||||
batch_info_host,
|
||||
seq_len,
|
||||
input_positions,
|
||||
cache_loc,
|
||||
@ -170,7 +170,7 @@ class TorchAttentionReference:
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
batch_info,
|
||||
batch_info_host,
|
||||
seq_len,
|
||||
input_positions,
|
||||
cache_loc,
|
||||
@ -189,7 +189,7 @@ class TorchAttentionReference:
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
batch_info,
|
||||
batch_info_host,
|
||||
seq_len,
|
||||
input_positions,
|
||||
cache_loc,
|
||||
|
||||
@ -125,9 +125,9 @@ def test_flat_gqa_op(
|
||||
k = torch.randn(1, seq_len.sum(), n_kv_heads * D_HEAD, **dtype_kwargs)
|
||||
v = torch.randn(1, seq_len.sum(), n_kv_heads * D_HEAD, **dtype_kwargs)
|
||||
|
||||
# create batch_info: [num_prefill, num_prefill_tokens, num_decode]
|
||||
# create batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
|
||||
num_prefill_tokens = seq_len[:num_context].sum()
|
||||
batch_info = torch.tensor([num_context, num_prefill_tokens, num_generate], **int_kwargs)
|
||||
batch_info_host = torch.tensor([num_context, num_prefill_tokens, num_generate], **int_kwargs)
|
||||
|
||||
# run op
|
||||
output = torch.ops.auto_deploy.triton_attention_flattened_mha_with_cache(
|
||||
@ -136,7 +136,7 @@ def test_flat_gqa_op(
|
||||
k,
|
||||
v,
|
||||
# STANDARD METADATA
|
||||
batch_info,
|
||||
batch_info_host,
|
||||
seq_len,
|
||||
input_positions,
|
||||
cache_loc,
|
||||
@ -150,7 +150,7 @@ def test_flat_gqa_op(
|
||||
|
||||
# Use torch backend as clean reference
|
||||
ref_flat = TorchAttentionReference.flattened_mha_with_cache(
|
||||
q, k, v, batch_info, seq_len, input_positions, cache_loc, seq_start, k_cache, v_cache
|
||||
q, k, v, batch_info_host, seq_len, input_positions, cache_loc, seq_start, k_cache, v_cache
|
||||
)
|
||||
|
||||
assert torch.allclose(
|
||||
|
||||
@ -59,9 +59,9 @@ def test_generate_only_with_slot_mapping_cuda(conv_env):
|
||||
# Metadata (not used in generate-only op entry, but required by the interface)
|
||||
cu_seqlen = torch.zeros(batch, device=device, dtype=torch.int32)
|
||||
use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool)
|
||||
# batch_info: [num_prefill, num_prefill_tokens, num_decode]
|
||||
# batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
|
||||
# For generate-only: num_decode = batch, num_prefill = 0
|
||||
batch_info = torch.tensor([0, 0, batch], device=device, dtype=torch.int32)
|
||||
batch_info_host = torch.tensor([0, 0, batch], device=device, dtype=torch.int32)
|
||||
# Snapshot caches for reference before running op (op mutates caches)
|
||||
gathered_before = conv_state_cache.clone().index_select(0, slot_idx)
|
||||
x_ref = x.clone()
|
||||
@ -72,7 +72,7 @@ def test_generate_only_with_slot_mapping_cuda(conv_env):
|
||||
w,
|
||||
b,
|
||||
# STANDARD METADATA
|
||||
batch_info,
|
||||
batch_info_host,
|
||||
cu_seqlen,
|
||||
slot_idx,
|
||||
use_initial_states,
|
||||
|
||||
@ -64,6 +64,12 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
|
||||
paged_kv_indices = torch.arange(BATCH_SIZE).int().to(device)
|
||||
paged_kv_last_page_len = offsets + seq_len_tensor
|
||||
|
||||
# Host copies of metadata
|
||||
qo_indptr_host = qo_indptr.cpu()
|
||||
paged_kv_indptr_host = paged_kv_indptr.cpu()
|
||||
paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
|
||||
seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
|
||||
|
||||
# Q,K,V are computed using GEMM.
|
||||
q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
|
||||
k = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
|
||||
@ -88,8 +94,8 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
|
||||
),
|
||||
BATCH_SIZE * SEQ_LEN,
|
||||
)
|
||||
# Create batch_info: [num_prefill, num_prefill_tokens, num_decode]
|
||||
batch_info = torch.tensor(
|
||||
# Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
|
||||
batch_info_host = torch.tensor(
|
||||
[BATCH_SIZE, BATCH_SIZE * SEQ_LEN, 0], dtype=torch.int32, device=device
|
||||
)
|
||||
flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
|
||||
@ -98,11 +104,14 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
|
||||
k,
|
||||
v,
|
||||
# STANDARD METADATA
|
||||
batch_info,
|
||||
qo_indptr,
|
||||
batch_info_host,
|
||||
qo_indptr_host,
|
||||
paged_kv_indptr,
|
||||
paged_kv_indptr_host,
|
||||
paged_kv_indices,
|
||||
paged_kv_last_page_len,
|
||||
paged_kv_last_page_len_host,
|
||||
seq_len_with_cache_host,
|
||||
# EXTRA METADATA
|
||||
batch_indices,
|
||||
positions,
|
||||
@ -168,6 +177,12 @@ def test_flashinfer_attention_op_decode(
|
||||
paged_kv_indices = torch.arange(BATCH_SIZE).int().to(device)
|
||||
paged_kv_last_page_len = offsets + seq_len_tensor
|
||||
|
||||
# Host copies of metadata
|
||||
qo_indptr_host = qo_indptr.cpu()
|
||||
paged_kv_indptr_host = paged_kv_indptr.cpu()
|
||||
paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
|
||||
seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
|
||||
|
||||
# Q,K,V are computed using GEMM.
|
||||
q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
|
||||
k = torch.ones(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
|
||||
@ -224,20 +239,23 @@ def test_flashinfer_attention_op_decode(
|
||||
),
|
||||
BATCH_SIZE * SEQ_LEN,
|
||||
)
|
||||
# Create batch_info: [num_prefill, num_prefill_tokens, num_decode]
|
||||
# Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
|
||||
# For decode phase: num_decode = BATCH_SIZE, num_prefill = 0
|
||||
batch_info = torch.tensor([0, 0, BATCH_SIZE], dtype=torch.int32, device=device)
batch_info_host = torch.tensor([0, 0, BATCH_SIZE], dtype=torch.int32, device=device)
flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
# Q, K, V
q,
k,
v,
# STANDARD METADATA
batch_info,
qo_indptr,
batch_info_host,
qo_indptr_host,
paged_kv_indptr,
paged_kv_indptr_host,
paged_kv_indices,
paged_kv_last_page_len,
paged_kv_last_page_len_host,
seq_len_with_cache_host,
# EXTRA METADATA
batch_indices,
positions,
@ -323,6 +341,12 @@ def test_flashinfer_attention_context_and_generate(
paged_kv_indices = torch.arange(BATCH_SIZE).int().to(device)
paged_kv_last_page_len = offsets + seq_len_tensor

# Host copies of metadata
qo_indptr_host = qo_indptr.cpu()
paged_kv_indptr_host = paged_kv_indptr.cpu()
paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()

# Q,K,V for prefill phase
q_1 = torch.randn(BATCH_SIZE, PREFILL_SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
k_1 = torch.randn(BATCH_SIZE, PREFILL_SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@ -347,8 +371,8 @@ def test_flashinfer_attention_context_and_generate(
),
BATCH_SIZE * PREFILL_SEQ_LEN,
)
# Create batch_info: [num_prefill, num_prefill_tokens, num_decode]
batch_info = torch.tensor(
# Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
batch_info_host = torch.tensor(
[BATCH_SIZE, BATCH_SIZE * PREFILL_SEQ_LEN, 0], dtype=torch.int32, device=device
)
flashinfer_output_1 = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
@ -357,11 +381,14 @@ def test_flashinfer_attention_context_and_generate(
k_1,
v_1,
# STANDARD METADATA
batch_info,
qo_indptr,
batch_info_host,
qo_indptr_host,
paged_kv_indptr,
paged_kv_indptr_host,
paged_kv_indices,
paged_kv_last_page_len,
paged_kv_last_page_len_host,
seq_len_with_cache_host,
# EXTRA METADATA
batch_indices,
positions,
@ -415,6 +442,12 @@ def test_flashinfer_attention_context_and_generate(
paged_kv_indices = torch.arange(BATCH_SIZE).int().to(device)
paged_kv_last_page_len = offsets + seq_len_tensor

# Host copies of metadata
qo_indptr_host = qo_indptr.cpu()
paged_kv_indptr_host = paged_kv_indptr.cpu()
paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()

# Q,K,V are computed using GEMM.
q_3 = torch.randn(BATCH_SIZE, 1, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
k_3 = torch.randn(BATCH_SIZE, 1, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@ -430,19 +463,22 @@ def test_flashinfer_attention_context_and_generate(
),
BATCH_SIZE * 1,
)
# Create batch_info: [num_prefill, num_prefill_tokens, num_decode]
batch_info = torch.tensor([0, 0, BATCH_SIZE], dtype=torch.int32, device=device)
# Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
batch_info_host = torch.tensor([0, 0, BATCH_SIZE], dtype=torch.int32, device=device)
flashinfer_output_3 = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
# Q, K, V
q_3,
k_3,
v_3,
# STANDARD METADATA
batch_info,
qo_indptr,
batch_info_host,
qo_indptr_host,
paged_kv_indptr,
paged_kv_indptr_host,
paged_kv_indices,
paged_kv_last_page_len,
paged_kv_last_page_len_host,
seq_len_with_cache_host,
# EXTRA METADATA
batch_indices,
positions,
@ -519,6 +555,12 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
paged_kv_indices = torch.arange(BATCH_SIZE).int().to(device)
paged_kv_last_page_len = offsets + seq_len_tensor

# Host copies of metadata
qo_indptr_host = qo_indptr.cpu()
paged_kv_indptr_host = paged_kv_indptr.cpu()
paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()

# Q,K,V are computed using GEMM.
q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
k = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@ -543,8 +585,8 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
),
BATCH_SIZE * SEQ_LEN,
)
# Create batch_info: [num_prefill, num_prefill_tokens, num_decode]
batch_info = torch.tensor(
# Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
batch_info_host = torch.tensor(
[BATCH_SIZE, BATCH_SIZE * SEQ_LEN, 0], dtype=torch.int32, device=device
)
flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
@ -553,11 +595,14 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
k,
v,
# STANDARD METADATA
batch_info,
qo_indptr,
batch_info_host,
qo_indptr_host,
paged_kv_indptr,
paged_kv_indptr_host,
paged_kv_indices,
paged_kv_last_page_len,
paged_kv_last_page_len_host,
seq_len_with_cache_host,
# EXTRA METADATA
batch_indices,
positions,
@ -642,6 +687,12 @@ def test_flashinfer_attention_with_fp8_cache(
paged_kv_indices = torch.arange(BATCH_SIZE).int().to(device)
paged_kv_last_page_len = offsets + seq_len_tensor

# Host copies of metadata
qo_indptr_host = qo_indptr.cpu()
paged_kv_indptr_host = paged_kv_indptr.cpu()
paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()

# Q,K,V are computed using GEMM, in fp16
q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
k = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@ -696,8 +747,8 @@ def test_flashinfer_attention_with_fp8_cache(
),
BATCH_SIZE * SEQ_LEN,
)
# Create batch_info: [num_prefill, num_prefill_tokens, num_decode]
batch_info = torch.tensor(
# Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
batch_info_host = torch.tensor(
[BATCH_SIZE, BATCH_SIZE * SEQ_LEN, 0], dtype=torch.int32, device=device
)
flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
@ -706,11 +757,14 @@ def test_flashinfer_attention_with_fp8_cache(
k,
v,
# STANDARD METADATA
batch_info,
qo_indptr,
batch_info_host,
qo_indptr_host,
paged_kv_indptr,
paged_kv_indptr_host,
paged_kv_indices,
paged_kv_last_page_len,
paged_kv_last_page_len_host,
seq_len_with_cache_host,
# EXTRA METADATA
batch_indices,
positions,
@ -787,6 +841,12 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
)
paged_kv_last_page_len = ((offsets + seq_len_tensor - 1) % PAGE_SIZE) + 1

# Host copies of metadata
qo_indptr_host = qo_indptr.cpu()
paged_kv_indptr_host = paged_kv_indptr.cpu()
paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()

# make sure planner is initialized
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
_GlobalFlashInferPlanner.init_workspace(workspace)
@ -798,19 +858,22 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
),
SEQ_LEN,
)
# Create batch_info: [num_prefill, num_prefill_tokens, num_decode]
batch_info = torch.tensor([BATCH_SIZE, SEQ_LEN, 0], dtype=torch.int32, device=device)
# Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
batch_info_host = torch.tensor([BATCH_SIZE, SEQ_LEN, 0], dtype=torch.int32, device=device)
flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
# Q, K, V
q,
k,
v,
# STANDARD METADATA
batch_info,
qo_indptr,
batch_info_host,
qo_indptr_host,
paged_kv_indptr,
paged_kv_indptr_host,
paged_kv_indices,
paged_kv_last_page_len,
paged_kv_last_page_len_host,
seq_len_with_cache_host,
# EXTRA METADATA
batch_indices,
positions,
@ -875,6 +938,12 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
)
paged_kv_last_page_len2 = ((offsets2 + seq_len_tensor2 - 1) % PAGE_SIZE) + 1

# Host copies of metadata
qo_indptr2_host = qo_indptr2.cpu()
paged_kv_indptr2_host = paged_kv_indptr2.cpu()
paged_kv_last_page_len2_host = paged_kv_last_page_len2.cpu()
seq_len_with_cache2_host = (offsets2 + seq_len_tensor2).cpu()

# Create FlashInferAttention class before calling the custom op
_GlobalFlashInferPlanner.reset()

@ -885,19 +954,22 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
),
BATCH_SIZE * 1,
)
# Create batch_info: [num_prefill, num_prefill_tokens, num_decode]
batch_info = torch.tensor([0, 0, BATCH_SIZE], dtype=torch.int32, device=device)
# Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
batch_info_host = torch.tensor([0, 0, BATCH_SIZE], dtype=torch.int32, device=device)
flashinfer_output_gen = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
# Q, K, V
q_gen,
k_gen,
v_gen,
# STANDARD METADATA
batch_info,
qo_indptr2,
batch_info_host,
qo_indptr2_host,
paged_kv_indptr2,
paged_kv_indptr2_host,
paged_kv_indices2,
paged_kv_last_page_len2,
paged_kv_last_page_len2_host,
seq_len_with_cache2_host,
# EXTRA METADATA
batch_indices,
positions,
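
The flashinfer hunks above all follow the same pattern: every device-side metadata tensor handed to flashinfer_attention_mha_with_cache is now mirrored by a host-side copy with a "_host" suffix. Below is a minimal sketch of that mirroring, assuming plain .cpu() copies exactly as in the tests; the helper name make_host_mirrors is illustrative only and not part of this change.

import torch

def make_host_mirrors(**device_tensors: torch.Tensor) -> dict:
    # Mirror each device metadata tensor as a "<name>_host" CPU copy,
    # matching the naming convention used in the tests above.
    return {f"{name}_host": t.cpu() for name, t in device_tensors.items()}

# Example usage mirroring the tests:
# host = make_host_mirrors(qo_indptr=qo_indptr, paged_kv_indptr=paged_kv_indptr,
#                          paged_kv_last_page_len=paged_kv_last_page_len)
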
@ -246,14 +246,16 @@ class TestTorchBackendAttention:

if seq_len == 1:
# Generate phase: [num_prefill, num_prefill_tokens, num_decode]
batch_info = torch.tensor([0, 0, batch_size], device=self.device, dtype=torch.int32)
batch_info_host = torch.tensor(
[0, 0, batch_size], device=self.device, dtype=torch.int32
)
seq_start = torch.arange(batch_size, device=self.device, dtype=torch.int32)
q_flat = q.view(batch_size, seq_len, -1)
k_flat = k.view(batch_size, seq_len, -1)
v_flat = v.view(batch_size, seq_len, -1)
else:
# Context phase: [num_prefill, num_prefill_tokens, num_decode]
batch_info = torch.tensor(
batch_info_host = torch.tensor(
[batch_size, batch_size * seq_len, 0], device=self.device, dtype=torch.int32
)
seq_start = torch.arange(
@ -267,7 +269,7 @@ class TestTorchBackendAttention:
"q": q_flat,
"k": k_flat,
"v": v_flat,
"batch_info": batch_info,
"batch_info_host": batch_info_host,
"seq_len": seq_len_tensor,
"input_pos": input_positions,
"cache_loc": cache_loc,
@ -286,7 +288,7 @@ class TestTorchBackendAttention:
data["k"],
data["v"],
# STANDARD METADATA
data["batch_info"],
data["batch_info_host"],
data["seq_len"],
data["input_pos"],
data["cache_loc"],
@ -59,9 +59,9 @@ def test_generate_only_with_slot_mapping(conv_env):
# Snapshot caches for reference before running op (op mutates caches)
gathered_before = conv_state_cache.clone().index_select(0, slot_idx)
use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool)
# batch_info: [num_prefill, num_prefill_tokens, num_decode]
# batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
# For generate-only: num_decode = batch, num_prefill = 0
batch_info = torch.tensor([0, 0, batch], device=device, dtype=torch.int32)
batch_info_host = torch.tensor([0, 0, batch], device=device, dtype=torch.int32)
# Run cached op
y = torch.ops.auto_deploy.torch_cached_causal_conv1d(
# INPUTS
@ -69,7 +69,7 @@ def test_generate_only_with_slot_mapping(conv_env):
w,
b,
# STANDARD METADATA
batch_info,
batch_info_host,
seq_len,
cu_seqlen,
slot_idx,
@ -124,18 +124,20 @@ def test_context_flattened_and_state_writeback(conv_env):
seq_len = torch.tensor(lens, device=device, dtype=torch.int32)
cu_seqlen = torch.tensor([0, lens[0]], device=device, dtype=torch.int32)
use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool)
# batch_info: [num_prefill, num_prefill_tokens, num_decode]
# batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
# For context/prefill phase: num_prefill = len(lens), num_decode = 0
num_seqs = len(lens)
num_prefill_tokens = sum(lens)
batch_info = torch.tensor([num_seqs, num_prefill_tokens, 0], device=device, dtype=torch.int32)
batch_info_host = torch.tensor(
[num_seqs, num_prefill_tokens, 0], device=device, dtype=torch.int32
)
y = torch.ops.auto_deploy.torch_cached_causal_conv1d(
# INPUTS
x,
w,
b,
# STANDARD METADATA
batch_info,
batch_info_host,
seq_len,
cu_seqlen,
slot_idx,
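
The causal-conv tests above and the SSM tests that follow build the same three-element batch descriptor, laid out as [num_prefill, num_prefill_tokens, num_decode]. The sketch below summarizes that construction; the helper name build_batch_info_host is illustrative only, and the tests themselves construct the tensor inline (sometimes on the test device despite the _host suffix).

import torch

def build_batch_info_host(num_prefill: int, num_prefill_tokens: int, num_decode: int) -> torch.Tensor:
    # Layout: [num_prefill, num_prefill_tokens, num_decode], int32 on the host.
    return torch.tensor([num_prefill, num_prefill_tokens, num_decode], dtype=torch.int32)

# Generate-only batch of size `batch`: build_batch_info_host(0, 0, batch)
# Prefill over sequences with lengths `lens`: build_batch_info_host(len(lens), sum(lens), 0)
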
@ -66,9 +66,9 @@ def test_generate_only_with_slot_mapping(mamba_env):
seq_len = torch.ones(batch, device=device, dtype=torch.int32)
cu_seqlen = torch.zeros(batch, device=device, dtype=torch.int32)
use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool)
# batch_info: [num_prefill, num_prefill_tokens, num_decode]
# batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
# For generate-only: num_decode = batch, num_prefill = 0
batch_info = torch.tensor([0, 0, batch], device=device, dtype=torch.int32)
batch_info_host = torch.tensor([0, 0, batch], device=device, dtype=torch.int32)
# Snapshot caches for reference before running op (op mutates caches)
gathered_before = ssm_state_cache.clone().index_select(0, slot_idx)

@ -83,7 +83,7 @@ def test_generate_only_with_slot_mapping(mamba_env):
dt,
dt_bias,
# STANDARD METADATA
batch_info,
batch_info_host,
seq_len,
cu_seqlen,
slot_idx,
@ -141,11 +141,13 @@ def test_context_flattened_and_state_writeback(mamba_env):
seq_len = torch.tensor(lens, device=device, dtype=torch.int32)
cu_seqlen = torch.tensor([0, lens[0]], device=device, dtype=torch.int32)
use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool)
# batch_info: [num_prefill, num_prefill_tokens, num_decode]
# batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
# For context/prefill phase: num_prefill = len(lens), num_decode = 0
num_seqs = len(lens)
num_prefill_tokens = sum(lens)
batch_info = torch.tensor([num_seqs, num_prefill_tokens, 0], device=device, dtype=torch.int32)
batch_info_host = torch.tensor(
[num_seqs, num_prefill_tokens, 0], device=device, dtype=torch.int32
)
y = torch.ops.auto_deploy.torch_cached_ssm(
# INPUTS
hidden_states,
@ -156,7 +158,7 @@ def test_context_flattened_and_state_writeback(mamba_env):
dt,
dt_bias,
# STANDARD METADATA
batch_info,
batch_info_host,
seq_len,
cu_seqlen,
slot_idx,
@ -134,8 +134,8 @@ def test_triton_context_flattened_and_state_writeback(mamba_env):
torch.arange(len(lens), device=device, dtype=torch.int32),
seq_len,
).view(1, -1)
# batch_info: [num_prefill, num_prefill_tokens, num_decode]
batch_info_tensor = torch.tensor([len(lens), sum(lens), 0], dtype=torch.int32, device=device)
# batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
batch_info_host = torch.tensor([len(lens), sum(lens), 0], dtype=torch.int32, device=device)
# Torch reference
y_torch = torch.ops.auto_deploy.torch_cached_ssm(
hidden_states,
@ -146,7 +146,7 @@ def test_triton_context_flattened_and_state_writeback(mamba_env):
dt,
dt_bias,
# STANDARD METADATA
batch_info_tensor,
batch_info_host,
seq_len,
cu_seqlen,
slot_idx,
@ -168,7 +168,7 @@ def test_triton_context_flattened_and_state_writeback(mamba_env):
dt,
dt_bias,
# STANDARD METADATA
batch_info_tensor,
batch_info_host,
cu_seqlens,
slot_idx,
use_initial_states,
@ -56,10 +56,10 @@ class TestGatherLogitsBeforeLmHeadOp:

# Create gather info: num_tokens_to_gather=batch_size, gather_required=0 (False)
logits_gather_indices = torch.arange(batch_size, dtype=torch.long, device="cuda")
logits_gather_info = torch.tensor([batch_size, 0], dtype=torch.int32, device="cuda")
logits_gather_info_host = torch.tensor([batch_size, 0], dtype=torch.int32, device="cpu")

output = torch.ops.auto_deploy.gather_logits_before_lm_head.default(
hidden_states, logits_gather_indices, logits_gather_info
hidden_states, logits_gather_indices, logits_gather_info_host
)

# Should return [batch, 1, hidden] for generate format (3D shape preserved)
@ -82,10 +82,10 @@ class TestGatherLogitsBeforeLmHeadOp:
gather_indices = torch.arange(0, num_gather, dtype=torch.long, device="cuda")

# Create gather info: num_tokens_to_gather=num_gather, gather_required=1 (True)
logits_gather_info = torch.tensor([num_gather, 1], dtype=torch.int32, device="cuda")
logits_gather_info_host = torch.tensor([num_gather, 1], dtype=torch.int32, device="cpu")

output = torch.ops.auto_deploy.gather_logits_before_lm_head.default(
hidden_states, gather_indices, logits_gather_info
hidden_states, gather_indices, logits_gather_info_host
)

# Should return [1, num_gather, hidden] for packed format (3D shape preserved)
@ -105,15 +105,15 @@ class TestGatherLogitsBeforeLmHeadOp:

# Create gather info
logits_gather_indices = torch.arange(batch_size, dtype=torch.long, device="cuda")
logits_gather_info = torch.tensor([batch_size, 0], dtype=torch.int32, device="cuda")
logits_gather_info_host = torch.tensor([batch_size, 0], dtype=torch.int32, device="cpu")

# Use fake implementation directly
with FakeTensorMode() as mode:
hidden_states_fake = mode.from_tensor(hidden_states)
logits_gather_indices_fake = mode.from_tensor(logits_gather_indices)
logits_gather_info_fake = mode.from_tensor(logits_gather_info)
logits_gather_info_host_fake = mode.from_tensor(logits_gather_info_host)
output = torch.ops.auto_deploy.gather_logits_before_lm_head.default(
hidden_states_fake, logits_gather_indices_fake, logits_gather_info_fake
hidden_states_fake, logits_gather_indices_fake, logits_gather_info_host_fake
)

# Should return [batch, 1, hidden_size] (fake returns empty_like which preserves 3D shape)
@ -132,15 +132,15 @@ class TestGatherLogitsBeforeLmHeadOp:

# Create gather info
logits_gather_indices = torch.arange(num_gather, dtype=torch.long, device="cuda")
logits_gather_info = torch.tensor([num_gather, 1], dtype=torch.int32, device="cuda")
logits_gather_info_host = torch.tensor([num_gather, 1], dtype=torch.int32, device="cpu")

# Use fake implementation directly
with FakeTensorMode() as mode:
hidden_states_fake = mode.from_tensor(hidden_states)
logits_gather_indices_fake = mode.from_tensor(logits_gather_indices)
logits_gather_info_fake = mode.from_tensor(logits_gather_info)
logits_gather_info_host_fake = mode.from_tensor(logits_gather_info_host)
output = torch.ops.auto_deploy.gather_logits_before_lm_head.default(
hidden_states_fake, logits_gather_indices_fake, logits_gather_info_fake
hidden_states_fake, logits_gather_indices_fake, logits_gather_info_host_fake
)

# The fake implementation returns empty_like which preserves input shape [1, total_tokens, hidden]
@ -217,13 +217,13 @@ class TestGatherLogitsBeforeLmHeadTransform:
# Test forward pass
# We must pass the new graph inputs manually since we are running the graph directly
logits_gather_indices = torch.arange(batch_size, dtype=torch.long, device="cuda")
logits_gather_info = torch.tensor([batch_size, 0], dtype=torch.int32, device="cuda")
logits_gather_info_host = torch.tensor([batch_size, 0], dtype=torch.int32, device="cpu")
output = gm_transformed(
hidden_states,
logit_gather_ids,
seq_len,
logits_gather_indices=logits_gather_indices,
logits_gather_info=logits_gather_info,
logits_gather_info_host=logits_gather_info_host,
)
# Output should be [batch_size, 1, vocab_size] since gather now returns 3D
assert output.shape == (batch_size, 1, vocab_size)
@ -278,13 +278,13 @@ class TestGatherLogitsBeforeLmHeadTransform:
# We must pass the new graph inputs manually since we are running the graph directly
num_gather = len(logit_gather_ids)
logits_gather_indices = logit_gather_ids
logits_gather_info = torch.tensor([num_gather, 1], dtype=torch.int32, device="cuda")
logits_gather_info_host = torch.tensor([num_gather, 1], dtype=torch.int32, device="cpu")
output = gm_transformed(
hidden_states,
logit_gather_ids_padded,
seq_len,
logits_gather_indices=logits_gather_indices,
logits_gather_info=logits_gather_info,
logits_gather_info_host=logits_gather_info_host,
)
# Output should be [1, num_gather, vocab_size] since gather now returns 3D
assert output.shape == (1, num_gather, vocab_size)
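
For reference, the gather-info tensor exercised in these tests packs two host-side integers, laid out as [num_tokens_to_gather, gather_required], as the comments in the hunks above state. A minimal sketch of that construction follows; the helper name build_logits_gather_info_host is illustrative only and not part of this change.

import torch

def build_logits_gather_info_host(num_tokens_to_gather: int, gather_required: bool) -> torch.Tensor:
    # Layout: [num_tokens_to_gather, gather_required], int32 on the host ("cpu").
    return torch.tensor([num_tokens_to_gather, int(gather_required)], dtype=torch.int32, device="cpu")

# Generate phase (one token per sequence): build_logits_gather_info_host(batch_size, False)
# Packed/prefill phase gathering selected tokens: build_logits_gather_info_host(num_gather, True)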