[#10244][feat] AutoDeploy: separate prefill/decode in flashinfer (#10252)

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
This commit is contained in:
Lucas Liebenwein 2025-12-31 17:01:24 -05:00 committed by GitHub
parent 9085021aa4
commit 1bbe71b3ed
25 changed files with 441 additions and 330 deletions

View File

@ -10,19 +10,7 @@ and operates on a purely functional paradigm that is compatible with the torch c
"""
from abc import ABC, abstractmethod
from typing import (
Callable,
Dict,
List,
Literal,
Optional,
Protocol,
Sequence,
Set,
Tuple,
Type,
Union,
)
from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union
import torch
from pydantic import BaseModel, ConfigDict, Field, field_validator
@ -36,6 +24,10 @@ from ..utils.logger import ad_logger
Constant = Union[int, float, str, None]
class PrepareMetadataHostCallable(Protocol):
def __call__(self, **sequence_info_args: torch.Tensor) -> None: ...
class InputBuffer:
"""Manages contiguous memory buffers for efficient host-to-device transfers.
@ -388,6 +380,9 @@ class SequenceInfo:
- _mask_scatter_indices: [m_0, m_1, ..., m_{s_total-1}]
Mask scatter indices used by the overlap scheduler to scatter results back.
NOTE: all tensors are also accessible as host tensors with the suffix "_host". For example,
the tensor "batch_info" is accessible as "batch_info_host" on the host.
################################################################################################
Here are a couple of notes to emphasize this notation:
@ -508,6 +503,9 @@ class SequenceInfo:
# Create the InputBuffer that manages contiguous host and device memory
# Starts on default device; use to() to move to target device
self._input_buffer = InputBuffer(tensor_specs)
self._available_args = set(self._input_buffer.tensor_names) | {
f"{name}_host" for name in self._input_buffer.tensor_names
}
# Initialize args_list from tensor specs
self._args_list: Dict[str, List[int]] = {
@ -515,9 +513,7 @@ class SequenceInfo:
}
self._active_args = ("input_ids", "position_ids")
self._shapeable_args = ("input_ids", "position_ids")
# Args that should be returned from host (pinned memory) instead of device in _named_args
self._host_return_args = ("batch_info", "logits_gather_info")
self._shapeable_args = ("input_ids", "position_ids", "input_ids_host", "position_ids_host")
############################################################################################
# EXTRA TENSOR FIELDS ######################################################################
@ -525,7 +521,7 @@ class SequenceInfo:
############################################################################################
# HOST PREPARE FOR ATTENTION FORWARD #######################################################
self._host_prepare_functions: set[Callable[[SequenceInfo], None]] = set()
self._host_prepare_functions: List[Tuple[PrepareMetadataHostCallable, List[str]]] = []
# call reset once to set a consistent initial state
self.reset()
@ -558,14 +554,13 @@ class SequenceInfo:
def _get_arg(self, name: str) -> torch.Tensor:
"""Get the argument from the input buffer either on device or host."""
if name in self._host_return_args:
arg = self._input_buffer.get_host_view(name)
if name.endswith("_host"):
arg = self._input_buffer.get_host_view(name.replace("_host", ""))
else:
arg = self._input_buffer.get_view(name)
return self._shape_for_forward(arg) if name in self._shapeable_args else arg
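A minimal, self-contained sketch of the "_host" suffix lookup shown above: names ending in "_host" resolve to the pinned host view of the underlying buffer, everything else to the device view. TinyBuffer and the tensor names below are illustrative stand-ins, not the real InputBuffer API.

import torch

class TinyBuffer:
    """Stand-in for InputBuffer: one pinned host tensor and one device tensor per name."""

    def __init__(self, specs):
        pin = torch.cuda.is_available()
        dev = "cuda" if pin else "cpu"
        self._host = {
            n: torch.zeros(s, dtype=torch.int32, pin_memory=pin) for n, s in specs.items()
        }
        self._dev = {n: t.to(dev) for n, t in self._host.items()}

    def get_host_view(self, name):
        return self._host[name]

    def get_view(self, name):
        return self._dev[name]

def get_arg(buf: TinyBuffer, name: str) -> torch.Tensor:
    # names ending in "_host" return the pinned host copy, everything else the device copy
    if name.endswith("_host"):
        return buf.get_host_view(name.replace("_host", ""))
    return buf.get_view(name)

buf = TinyBuffer({"batch_info": 3, "cu_seqlen": 9})
print(get_arg(buf, "batch_info").device, get_arg(buf, "batch_info_host").device)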
def _named_args(self, include_extra_args: bool = True) -> Dict[str, torch.Tensor]:
# Build args dict, using host views for _host_return_args, device views otherwise
args = {k: self._get_arg(k) for k in self._active_args}
# check other args to include
@ -577,7 +572,7 @@ class SequenceInfo:
@property
def available_args(self) -> Set[str]:
"""Return a list of available arguments."""
return set(self._input_buffer.tensor_names)
return self._available_args
@property
def named_args(self) -> Dict[str, torch.Tensor]:
@ -697,68 +692,6 @@ class SequenceInfo:
pages_per_seq = [len(p) for p in page_assignments]
return cache_loc_flat, pages_per_seq
# TODO: remove after updating all cached backends
@classmethod
def _get_sanitized_seq_len(
cls, input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
) -> torch.Tensor:
"""Sanitize sequence lengths.
We want to cover the following scenarios with this function:
1. Pre-fill:
input_ids: [1, s_total, ...]
seq_len: [s_0, s_1, ..., s_{b-1}, 0, 0, ..., 0]
---> returns [s_0, s_1, ..., s_{b-1}]
2. Decode:
input_ids: [b, 1, ...]
seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0]
|---- b ----|--- (max_batch_size - b) ---|
--> returns [1,] * b
3. Decode in Cudagraph:
input_ids: [b_cudagraph, 1, ...]
seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0]
|---- b ----|--- (max_batch_size - b) ---|
--> returns [1,] * b_cudagraph
Here b <= b_cudagraph. We want to make sure that the seq_len is one-padded to
b_cudagraph.
# TODO: I could see one possible issue with this approach in the future.
# If we have b < b_cudagraph we now one-pad. However, we don't pad the cache location
# information. What could happen is that, for the padded sequences, the cache location
# tensors point to allocated pages. This could lead to a situation where we write into
# allocated cache pages polluting the cache of other sequences. Now this is not an issue
# if we write the dummy sequences into unallocated cache pages... One fix could be to
# pad not only the seq len but also pad the cache locations by just repeating the last
# valid cache location in the batch. This would ensure that the dummy sequences just
# repeat valid computation...
"""
_, s = input_or_position_ids.shape[:2]
num_seq = cls._get_sanitized_num_sequences(input_or_position_ids, seq_len)
if s > 1:
return seq_len[:num_seq].clone()
else:
return torch.ones(num_seq, dtype=seq_len.dtype, device=seq_len.device)
@staticmethod
def _get_sanitized_num_sequences(
input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
) -> int:
"""Get number of sequences.
We make sure that this function is compatible with both torch graph capture and cudagraph.
Both can be a bit temperamental when trying to extract the number of sequences from a tensor
with max_batch_size or max_batch_size*max_seq_len.
"""
b, s = input_or_position_ids.shape[:2]
if s > 1:
num_seq = torch.sum(seq_len > 0)
assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded"
else:
num_seq = b
return num_seq
def activate_arg(self, arg_name: str) -> bool:
"""Activate a desired argument.
@ -869,7 +802,7 @@ class SequenceInfo:
self._args_list[name] = tnsr_like.copy()
# Only store to buffer when the argument is active or force_copy is True
if not (name in self._active_args or force_copy):
if not (name in self._active_args or f"{name}_host" in self._active_args or force_copy):
return
# Store to the InputBuffer's pinned host memory
@ -1090,12 +1023,12 @@ class SequenceInfo:
def maybe_gather_and_squeeze_logits(self, logits: torch.Tensor) -> torch.Tensor:
"""Maybe gather the logits if logits have not been gathered yet."""
num_tokens = logits.shape[0] * logits.shape[1]
num_tokens_to_gather, gather_required = self._get_arg("logits_gather_info").tolist()
num_tokens_to_gather, gather_required = self._get_arg("logits_gather_info_host").tolist()
if gather_required and num_tokens_to_gather < num_tokens:
logits = torch.ops.auto_deploy.gather_logits_before_lm_head(
logits,
self._get_arg("logits_gather_indices"),
self._get_arg("logits_gather_info"),
self._get_arg("logits_gather_info_host"),
)
return logits.squeeze(int(self.is_generate))
@ -1105,13 +1038,13 @@ class SequenceInfo:
return list(torch.split(t_squeezed, self.seq_len))
def register_host_prepare_for_attention_forward(
self, host_function: Callable[["SequenceInfo"], None]
self, host_function: PrepareMetadataHostCallable, args: List[str]
):
self._host_prepare_functions.add(host_function)
self._host_prepare_functions.append((host_function, args))
def run_host_prepare_for_attention_forward(self) -> None:
for host_function in self._host_prepare_functions:
host_function(self)
for host_function, args in self._host_prepare_functions:
host_function(**{arg: self._get_arg(arg) for arg in args})
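For illustration, a minimal sketch of the register/run pattern above: each host function is stored together with the list of SequenceInfo argument names it needs and is later called with exactly those tensors by keyword. HostPrep and my_host_prepare are hypothetical stand-ins, not the real SequenceInfo class.

from typing import Callable, Dict, List, Tuple
import torch

class HostPrep:
    def __init__(self, tensors: Dict[str, torch.Tensor]):
        self._tensors = tensors
        self._fns: List[Tuple[Callable[..., None], List[str]]] = []

    def register(self, fn: Callable[..., None], args: List[str]) -> None:
        # store the function together with the argument names it expects
        self._fns.append((fn, args))

    def run(self) -> None:
        # dispatch each host function with the tensors it asked for, by keyword
        for fn, args in self._fns:
            fn(**{a: self._tensors[a] for a in args})

def my_host_prepare(batch_info_host: torch.Tensor) -> None:
    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
    print(num_prefill, num_prefill_tokens, num_decode)

prep = HostPrep({"batch_info_host": torch.tensor([2, 10, 3], dtype=torch.int32)})
prep.register(my_host_prepare, ["batch_info_host"])
prep.run()  # prints: 2 10 3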
class MHACallable(Protocol):
@ -1123,14 +1056,7 @@ class MHACallable(Protocol):
class PrepareMetadataCallable(Protocol):
def __call__(
self,
position_ids: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
cache_loc: torch.Tensor,
pages_per_seq: torch.Tensor,
slot_idx: torch.Tensor,
page_size: int,
self, *sequence_info_args_and_constants: Union[torch.Tensor, Constant]
) -> List[torch.Tensor]: ...
@ -1291,13 +1217,14 @@ class AttentionDescriptor(ABC):
return []
@classmethod
def host_prepare_for_forward(cls, sequence_info: SequenceInfo):
"""Perform host-side preparation for the forward pass for the attention op.
def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]:
"""Get function that performs host-side prep for the forward pass for the attention op.
This method is responsible for preparing the attention op for the forward pass.
This function is not expected to be graph capturable or compatible with cuda graphs.
This function is not expected to be graph capturable or compatible with cuda graphs. It can
use any argument from the SequenceInfo interface as an input argument.
"""
return
return None
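As an illustration of the new descriptor hook, a hedged sketch of a backend exposing its host-prepare callable; MyBackend and my_backend_prepare_host are hypothetical names, and the actual argument wiring happens in the insert-cached-attention transform further below.

from typing import Callable, Optional
import torch

def my_backend_prepare_host(batch_info_host: torch.Tensor) -> None:
    # purely host-side work; not expected to be graph capturable or cuda-graph compatible
    num_prefill, _, num_decode = batch_info_host.tolist()
    if num_prefill == 0:
        pass  # e.g., pre-plan a decode-only kernel here

class MyBackend:
    @classmethod
    def get_host_prepare_metadata_function(cls) -> Optional[Callable[..., None]]:
        return my_backend_prepare_host

print(MyBackend.get_host_prepare_metadata_function() is my_backend_prepare_host)  # True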
class AttentionRegistry:

View File

@ -35,7 +35,7 @@ def fla_cached_delta_rule(
v: torch.Tensor,
beta: torch.Tensor,
# STANDARD METADATA
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
cu_seqlen: torch.Tensor,
slot_idx: torch.Tensor,
use_initial_states: torch.Tensor,
@ -58,7 +58,7 @@ def fla_cached_delta_rule(
y = torch.empty_like(v, memory_format=torch.contiguous_format)
y_flat = y.view(b * s, num_heads, -1)
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
num_seq = num_prefill + num_decode
# clean up metadata
@ -120,7 +120,7 @@ def fla_cached_delta_rule_fake(
v: torch.Tensor,
beta: torch.Tensor,
# STANDARD METADATA
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
cu_seqlen: torch.Tensor,
slot_idx: torch.Tensor,
use_initial_states: torch.Tensor,
@ -160,7 +160,7 @@ class FlaDeltaBackend(AttentionDescriptor):
@classmethod
def get_standard_metadata_args(cls) -> List[str]:
return ["batch_info", "cu_seqlen", "slot_idx", "use_initial_states"]
return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"]
@classmethod
def get_cache_initializers(

View File

@ -21,6 +21,7 @@ from .attention_interface import (
Constant,
MHACallable,
PrepareMetadataCallable,
PrepareMetadataHostCallable,
SequenceInfo,
)
@ -183,7 +184,6 @@ class PlanParams:
n_kv_heads: int
head_dim: int
num_seq: int
is_generate: bool
page_size: int
q_dtype: torch.dtype
kv_dtype: torch.dtype
@ -205,14 +205,16 @@ class _FlashInferPlanner:
cached_cuda_graph_decode_wrappers: Dict[
PlanParams, flashinfer.BatchDecodeWithPagedKVCacheWrapper
]
plan_params: Optional[PlanParams]
plan_params_prefill: Optional[PlanParams]
plan_params_decode: Optional[PlanParams]
def __init__(self):
self.workspace_buffer = None
self.prefill_wrapper = None
self.decode_wrapper = None
self.cached_cuda_graph_decode_wrappers = {}
self.plan_params = None
self.plan_params_prefill = None
self.plan_params_decode = None
def _init_decode_wrapper(
self,
@ -253,7 +255,8 @@ class _FlashInferPlanner:
self.decode_wrapper = self._init_decode_wrapper()
def reset(self) -> None:
self.plan_params = None
self.plan_params_prefill = None
self.plan_params_decode = None
def plan_generate_only(
self,
@ -279,9 +282,46 @@ class _FlashInferPlanner:
sm_scale=plan_params.sm_scale,
)
def plan(
def plan_prefill(
self,
qo_indptr_host: torch.Tensor,
kv_page_indptr_host: torch.Tensor,
kv_page_indices: torch.Tensor,
kv_last_page_len_host: torch.Tensor,
kv_lens_arr_host: torch.Tensor,
plan_params: PlanParams,
) -> None:
# check for re-planning
if plan_params != self.plan_params_prefill:
# plan prefill
# NOTE (lucaslie): we use the host versions here. The plan actually needs both the host
# and device versions. Unfortunately, there is no good way to access the plan API and
# provide both, even though we have both available. I have decided to use the host
# versions here to ensure a non-blocking invocation of plan, whereas the other way around
# would trigger a blocking copy to CPU. This way we trigger a non-blocking copy to device
# (note that this is safe since we have pinned CPU memory for all our host-side arguments).
self.prefill_wrapper.plan(
qo_indptr_host,
kv_page_indptr_host,
kv_page_indices,
kv_last_page_len_host,
plan_params.n_heads, # Q heads
plan_params.n_kv_heads, # KV heads
plan_params.head_dim,
plan_params.page_size,
causal=plan_params.causal,
q_data_type=plan_params.q_dtype,
kv_data_type=plan_params.kv_dtype,
sm_scale=plan_params.sm_scale,
seq_lens=kv_lens_arr_host,
)
self.plan_params_prefill = plan_params
# return prefill wrapper
return self.prefill_wrapper
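A small sketch of the host-vs-device tradeoff described in the NOTE above, assuming a CUDA device is present: a pinned host tensor can be pushed to the device with non_blocking=True, while pulling a device tensor back to the CPU blocks until the copy completes.

import torch

if torch.cuda.is_available():
    pinned = torch.arange(8, dtype=torch.int32).pin_memory()  # host-side, pinned memory
    on_gpu = pinned.to("cuda", non_blocking=True)  # async H2D copy; does not block the host
    back = on_gpu.cpu()  # D2H copy; blocks until the data is available on the host
    torch.cuda.synchronize()
    print(on_gpu.device, back.device)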
def plan_decode(
self,
qo_indptr: torch.Tensor,
kv_page_indptr: torch.Tensor,
kv_page_indices: torch.Tensor,
kv_last_page_len: torch.Tensor,
@ -323,34 +363,16 @@ class _FlashInferPlanner:
_plan_decode(self.cached_cuda_graph_decode_wrappers[plan_params])
# check if we are in cuda graph capture and just return the pre-cached decode wrapper
if torch.cuda.is_current_stream_capturing() or cuda_graph_state.in_warm_up():
assert plan_params.is_generate, "Only generate is supported during cuda graph capture."
wrapper = self.cached_cuda_graph_decode_wrappers[plan_params]
return wrapper
# check for re-planning
if plan_params != self.plan_params:
if plan_params.is_generate:
_plan_decode(self.decode_wrapper)
else:
# plan prefill
self.prefill_wrapper.plan(
qo_indptr,
kv_page_indptr,
kv_page_indices,
kv_last_page_len,
plan_params.n_heads, # Q heads
plan_params.n_kv_heads, # KV heads
plan_params.head_dim,
plan_params.page_size,
causal=plan_params.causal,
q_data_type=plan_params.q_dtype,
kv_data_type=plan_params.kv_dtype,
sm_scale=plan_params.sm_scale,
)
self.plan_params = plan_params
if plan_params != self.plan_params_decode:
_plan_decode(self.decode_wrapper)
self.plan_params_decode = plan_params
# return desired wrapper
return self.decode_wrapper if plan_params.is_generate else self.prefill_wrapper
# return decode wrapper
return self.decode_wrapper
_GlobalFlashInferPlanner = _FlashInferPlanner()
@ -359,7 +381,7 @@ _GlobalFlashInferPlanner = _FlashInferPlanner()
@torch.library.custom_op("auto_deploy::flashinfer_attention_prepare_metadata", mutates_args=())
def prepare_flashinfer_metadata(
position_ids: torch.Tensor,
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
cu_seqlen: torch.Tensor,
seq_len_with_cache: torch.Tensor,
) -> List[torch.Tensor]:
@ -370,7 +392,7 @@ def prepare_flashinfer_metadata(
to understand the convention.
"""
# retrieve host-side metadata
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
num_seq = num_prefill + num_decode
num_tokens = num_prefill_tokens + num_decode
@ -393,7 +415,7 @@ def prepare_flashinfer_metadata(
@prepare_flashinfer_metadata.register_fake
def prepare_flashinfer_metadata_fake(
position_ids: torch.Tensor,
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
cu_seqlen: torch.Tensor,
seq_len_with_cache: torch.Tensor,
):
@ -404,6 +426,23 @@ def prepare_flashinfer_metadata_fake(
)
def prepare_flashinfer_metadata_host(
batch_info_host: torch.Tensor,
cu_num_pages_host: torch.Tensor,
cache_loc_host: torch.Tensor,
last_page_len_host: torch.Tensor,
) -> None:
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
if num_prefill == 0:
_GlobalFlashInferPlanner.plan_generate_only(
num_decode,
cu_num_pages_host[: num_decode + 1],
cache_loc_host,
last_page_len_host[:num_decode],
)
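For reference, a hedged helper that builds the batch_info_host layout used throughout this diff, [num_prefill, num_prefill_tokens, num_decode], under the assumption that prefill sequences (length > 1) come first and each decode sequence contributes one token; make_batch_info_host is an illustrative name, not part of this change.

from typing import List
import torch

def make_batch_info_host(seq_lens: List[int]) -> torch.Tensor:
    prefill_lens = [s for s in seq_lens if s > 1]
    num_decode = sum(1 for s in seq_lens if s == 1)
    return torch.tensor([len(prefill_lens), sum(prefill_lens), num_decode], dtype=torch.int32)

print(make_batch_info_host([5, 3, 1, 1]).tolist())  # [2, 8, 2]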
@torch.library.custom_op("auto_deploy::flashinfer_attention_mha_with_cache", mutates_args=())
def flashinfer_mha_with_cache(
# Q, K, V
@ -411,11 +450,14 @@ def flashinfer_mha_with_cache(
k: torch.Tensor,
v: torch.Tensor,
# STANDARD METADATA
batch_info: torch.Tensor,
cu_seqlen: torch.Tensor,
batch_info_host: torch.Tensor,
cu_seqlen_host: torch.Tensor,
cu_num_pages: torch.Tensor,
cu_num_pages_host: torch.Tensor,
cache_loc: torch.Tensor,
last_page_len: torch.Tensor,
last_page_len_host: torch.Tensor,
seq_len_with_cache_host: torch.Tensor,
# EXTRA METADATA
flashinfer_batch_indices: torch.Tensor,
flashinfer_positions: torch.Tensor,
@ -439,32 +481,13 @@ def flashinfer_mha_with_cache(
v = v.reshape(b * s, -1, head_dim)
# convert to flashinfer-style metadata
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
num_seq = num_prefill + num_decode
qo_indptr = cu_seqlen[: num_seq + 1]
paged_kv_indptr = cu_num_pages[: num_seq + 1]
# NOTE: it is okay to have cache_loc here without truncation. paged_kv_indptr will be
# truncated and will point to the correct sub range of cache_loc.
paged_kv_indices = cache_loc
paged_kv_last_page_len = last_page_len[:num_seq]
num_total_tokens = num_prefill_tokens + num_decode
n_heads = q.shape[1]
n_kv_heads = k.shape[1]
pp = PlanParams(
n_heads=n_heads,
n_kv_heads=n_kv_heads,
head_dim=head_dim,
num_seq=len(qo_indptr) - 1,
is_generate=(s == 1),
page_size=k_cache.shape[1],
q_dtype=q.dtype,
kv_dtype=k_cache.dtype,
sm_scale=scale,
)
# Assuming k_scale = v_scale = 1.0
k_scale, v_scale = 1.0, 1.0
# k = (k / k_scale).to(torch.float8_e4m3fn) if k_scale != 1.0, same for v
@ -473,28 +496,91 @@ def flashinfer_mha_with_cache(
v = v.to(torch.float8_e4m3fn)
flashinfer.page.append_paged_kv_cache(
k,
v,
flashinfer_batch_indices,
flashinfer_positions,
(k_cache, v_cache),
paged_kv_indices,
paged_kv_indptr,
paged_kv_last_page_len,
append_key=k,
append_value=v,
batch_indices=flashinfer_batch_indices,
positions=flashinfer_positions,
paged_kv_cache=(k_cache, v_cache),
kv_indices=cache_loc,
kv_indptr=cu_num_pages[: num_seq + 1],
kv_last_page_len=last_page_len[:num_seq],
)
# run the flashinfer planner and obtain the correct wrapper
wrapper = _GlobalFlashInferPlanner.plan(
qo_indptr,
paged_kv_indptr,
paged_kv_indices,
paged_kv_last_page_len,
pp,
)
# check if we need to re-combine outputs
if num_prefill > 0 and num_decode > 0:
y = torch.empty_like(q)
else:
y = None
y = wrapper.run(
q, (k_cache, v_cache), k_scale=k_scale, v_scale=v_scale, enable_pdl=get_env_enable_pdl()
)
# now run split prefill, decode
if num_prefill > 0:
q_prefill = q[:num_prefill_tokens]
pp_prefill = PlanParams(
n_heads=n_heads,
n_kv_heads=n_kv_heads,
head_dim=head_dim,
num_seq=num_prefill,
page_size=k_cache.shape[1],
q_dtype=q_prefill.dtype,
kv_dtype=k_cache.dtype,
sm_scale=scale,
)
wrapper_prefill = _GlobalFlashInferPlanner.plan_prefill(
qo_indptr_host=cu_seqlen_host[: num_prefill + 1],
kv_page_indptr_host=cu_num_pages_host[: num_prefill + 1],
kv_page_indices=cache_loc,
kv_last_page_len_host=last_page_len_host[:num_prefill],
kv_lens_arr_host=seq_len_with_cache_host[:num_prefill],
plan_params=pp_prefill,
)
y_prefill = wrapper_prefill.run(
q_prefill,
(k_cache, v_cache),
k_scale=k_scale,
v_scale=v_scale,
enable_pdl=get_env_enable_pdl(),
)
if y is not None:
y[:num_prefill_tokens] = y_prefill
else:
y = y_prefill
if num_decode > 0:
q_decode = q[num_prefill_tokens:num_total_tokens]
pp_decode = PlanParams(
n_heads=n_heads,
n_kv_heads=n_kv_heads,
head_dim=head_dim,
num_seq=num_decode,
page_size=k_cache.shape[1],
q_dtype=q_decode.dtype,
kv_dtype=k_cache.dtype,
sm_scale=scale,
)
# run the flashinfer planner and obtain the correct wrapper
wrapper_decode = _GlobalFlashInferPlanner.plan_decode(
kv_page_indptr=cu_num_pages[num_prefill : num_seq + 1],
kv_page_indices=cache_loc,
kv_last_page_len=last_page_len[num_prefill:num_seq],
plan_params=pp_decode,
)
y_decode = wrapper_decode.run(
q_decode,
(k_cache, v_cache),
k_scale=k_scale,
v_scale=v_scale,
enable_pdl=get_env_enable_pdl(),
)
if y is not None:
y[num_prefill_tokens:num_total_tokens] = y_decode
else:
y = y_decode
return y.view(q_shape_og) # [b,s,n*h_d] or [b,s, n, h_d]
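A plain-tensor sketch of the recombination logic above, assuming the flattened token layout where the first num_prefill_tokens rows belong to prefill sequences and the next num_decode rows to decode sequences (no flashinfer involved; the scaling stands in for the wrapper.run calls).

import torch

num_prefill, num_prefill_tokens, num_decode = 2, 8, 3
num_total_tokens = num_prefill_tokens + num_decode
q = torch.randn(num_total_tokens, 4, 16)  # flattened [tokens, heads, head_dim]

# allocate a combined output only when both phases are present
y = torch.empty_like(q) if (num_prefill > 0 and num_decode > 0) else None

y_prefill = q[:num_prefill_tokens] * 2.0  # stand-in for the prefill wrapper.run output
y_decode = q[num_prefill_tokens:num_total_tokens] * 3.0  # stand-in for the decode output

if y is not None:
    y[:num_prefill_tokens] = y_prefill
    y[num_prefill_tokens:num_total_tokens] = y_decode
else:
    y = y_prefill if num_prefill > 0 else y_decode
print(y.shape)  # torch.Size([11, 4, 16])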
@ -506,11 +592,14 @@ def flashinfer_mha_with_cache_fake(
k: torch.Tensor,
v: torch.Tensor,
# STANDARD METADATA
batch_info: torch.Tensor,
cu_seqlen: torch.Tensor,
batch_info_host: torch.Tensor,
cu_seqlen_host: torch.Tensor,
cu_num_pages: torch.Tensor,
cu_num_pages_host: torch.Tensor,
cache_loc: torch.Tensor,
last_page_len: torch.Tensor,
last_page_len_host: torch.Tensor,
seq_len_with_cache_host: torch.Tensor,
# EXTRA METADATA
flashinfer_batch_indices: torch.Tensor,
flashinfer_positions: torch.Tensor,
@ -559,7 +648,16 @@ class FlashInferAttention(AttentionDescriptor):
@classmethod
def get_standard_metadata_args(cls) -> List[str]:
return ["batch_info", "cu_seqlen", "cu_num_pages", "cache_loc", "last_page_len"]
return [
"batch_info_host",
"cu_seqlen_host",
"cu_num_pages",
"cu_num_pages_host",
"cache_loc",
"last_page_len",
"last_page_len_host",
"seq_len_with_cache_host",
]
@classmethod
def get_prepare_extra_metadata_info(
@ -600,18 +698,8 @@ class FlashInferAttention(AttentionDescriptor):
return {"workspace_buffer": _init_workspace}
@classmethod
def host_prepare_for_forward(cls, sequence_info: SequenceInfo):
batch_info = sequence_info._input_buffer.get_host_view("batch_info")
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
# Call plan for generate-only batches.
if num_prefill == 0:
_GlobalFlashInferPlanner.plan_generate_only(
num_decode,
sequence_info._input_buffer.get_host_view("cu_num_pages")[: num_decode + 1],
sequence_info._input_buffer.get_host_view("cache_loc"),
sequence_info._input_buffer.get_host_view("last_page_len")[:num_decode],
)
return
def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]:
return prepare_flashinfer_metadata_host
@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:

View File

@ -53,7 +53,7 @@ def _cuda_cached_causal_conv1d(
weight: torch.Tensor, # [c_out, c_in/groups, k] but we expect depthwise use: [c_in, k]
bias: Optional[torch.Tensor],
# STANDARD METADATA
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
cu_seqlen: torch.Tensor,
slot_idx: torch.Tensor,
use_initial_states: torch.Tensor,
@ -80,7 +80,7 @@ def _cuda_cached_causal_conv1d(
"""
b, s = input.shape[:2]
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
num_seq = num_prefill + num_decode
num_total_tokens = num_prefill_tokens + num_decode
@ -138,7 +138,7 @@ def _cuda_cached_causal_conv1d_fake(
weight: torch.Tensor, # [c_out, c_in/groups, k] but we expect depthwise use: [c_in, k]
bias: Optional[torch.Tensor],
# STANDARD METADATA
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
cu_seqlen: torch.Tensor,
slot_idx: torch.Tensor,
use_initial_states: torch.Tensor,
@ -189,7 +189,7 @@ class CudaBackendCausalConv(AttentionDescriptor):
@classmethod
def get_standard_metadata_args(cls) -> List[str]:
return ["batch_info", "cu_seqlen", "slot_idx", "use_initial_states"]
return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"]
@classmethod
def get_cache_initializers(

View File

@ -147,7 +147,7 @@ def _torch_cached_causal_conv1d(
weight: torch.Tensor, # [c_out, c_in/groups, k]
bias: Optional[torch.Tensor],
# STANDARD METADATA
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
seq_len: torch.Tensor,
cu_seqlen: torch.Tensor,
slot_idx: torch.Tensor,
@ -174,7 +174,7 @@ def _torch_cached_causal_conv1d(
num_seq = seq_len.shape[0]
# get cleaned up metadata
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
num_seq = num_prefill + num_decode
seq_len = seq_len[:num_seq]
seq_start = cu_seqlen[:num_seq]
@ -247,7 +247,7 @@ def _torch_cached_causal_conv1d_fake(
weight: torch.Tensor, # [c_out, c_in/groups, k]
bias: Optional[torch.Tensor],
# STANDARD METADATA
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
seq_len: torch.Tensor,
cu_seqlen: torch.Tensor,
slot_idx: torch.Tensor,
@ -296,7 +296,7 @@ class TorchBackendCausalConv(AttentionDescriptor):
@classmethod
def get_standard_metadata_args(cls) -> List[str]:
return ["batch_info", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"]
return ["batch_info_host", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"]
@classmethod
def get_cache_initializers(

View File

@ -121,7 +121,7 @@ def _torch_cached_ssm(
dt: torch.Tensor, # [b, s, num_heads]
dt_bias: torch.Tensor, # [num_heads]
# STANDARD METADATA
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
seq_len: torch.Tensor,
cu_seqlen: torch.Tensor,
slot_idx: torch.Tensor,
@ -145,7 +145,7 @@ def _torch_cached_ssm(
num_seq = seq_len.shape[0]
# get cleaned up metadata
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
num_seq = num_prefill + num_decode
seq_len = seq_len[:num_seq]
seq_start = cu_seqlen[:num_seq]
@ -246,7 +246,7 @@ def _torch_cached_ssm_fake(
dt: torch.Tensor, # [b, s, num_heads]
dt_bias: torch.Tensor, # [num_heads]
# STANDARD METADATA
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
seq_len: torch.Tensor,
cu_seqlen: torch.Tensor,
slot_idx: torch.Tensor,
@ -293,7 +293,7 @@ class TorchBackendSSM(AttentionDescriptor):
@classmethod
def get_standard_metadata_args(cls) -> List[str]:
return ["batch_info", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"]
return ["batch_info_host", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"]
@classmethod
def get_cache_initializers(

View File

@ -42,7 +42,7 @@ from ..attention_interface import (
def _triton_ssm_prepare_metadata(
# INPUTS
position_ids: torch.Tensor,
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
seq_len: torch.Tensor,
cu_seqlen: torch.Tensor,
# EXTRA METADATA PROVIDED BY THE DESCRIPTOR
@ -53,7 +53,7 @@ def _triton_ssm_prepare_metadata(
Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized).
"""
device = cu_seqlen.device
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
if num_prefill > 0:
chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(
@ -74,7 +74,7 @@ def _triton_ssm_prepare_metadata(
def _triton_ssm_prepare_metadata_fake(
# INPUTS
position_ids: torch.Tensor,
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
seq_len: torch.Tensor,
cu_seqlen: torch.Tensor,
# EXTRA METADATA PROVIDED BY THE DESCRIPTOR
@ -110,7 +110,7 @@ def _triton_cached_ssm(
dt: torch.Tensor, # [b, s, num_heads]
dt_bias: torch.Tensor, # [num_heads]
# STANDARD METADATA
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
cu_seqlen: torch.Tensor,
slot_idx: torch.Tensor,
use_initial_states: torch.Tensor,
@ -140,7 +140,7 @@ def _triton_cached_ssm(
ssm_state_size = B.shape[3]
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
num_seq = num_prefill + num_decode
num_total_tokens = num_prefill_tokens + num_decode
@ -245,7 +245,7 @@ def _triton_cached_ssm_fake(
dt: torch.Tensor, # [b, s, num_heads]
dt_bias: torch.Tensor, # [num_heads]
# STANDARD METADATA
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
cu_seqlen: torch.Tensor,
slot_idx: torch.Tensor,
use_initial_states: torch.Tensor,
@ -294,7 +294,7 @@ class TritonBackendSSM(AttentionDescriptor):
@classmethod
def get_standard_metadata_args(cls) -> List[str]:
return ["batch_info", "cu_seqlen", "slot_idx", "use_initial_states"]
return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"]
@classmethod
def get_prepare_extra_metadata_info(

View File

@ -31,7 +31,7 @@ def fused_flattened_mla_with_cache(
kv: torch.Tensor,
k_pe: torch.Tensor,
# STANDARD METADATA
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
cache_loc: torch.Tensor,
@ -55,7 +55,7 @@ def fused_flattened_mla_with_cache(
# and number of tokens per sequence are encoded in seq_len and seq_start.
# check for sequence info and truncate metadata
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
num_seq = num_prefill + num_decode
seq_len = seq_len[:num_seq]
@ -166,7 +166,7 @@ def fused_flattened_mla_with_cache_fake(
kv: torch.Tensor,
k_pe: torch.Tensor,
# STANDARD METADATA
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
cache_loc: torch.Tensor,
@ -212,7 +212,7 @@ class MultiHeadLatentAttention(AttentionDescriptor):
@classmethod
def get_standard_metadata_args(cls) -> List[str]:
return ["batch_info", "seq_len", "input_pos", "cache_loc", "cu_seqlen"]
return ["batch_info_host", "seq_len", "input_pos", "cache_loc", "cu_seqlen"]
@classmethod
def get_cache_initializers(

View File

@ -253,7 +253,7 @@ def torch_backend_mha_with_cache(
k: torch.Tensor,
v: torch.Tensor,
# STANDARD METADATA
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
cache_loc: torch.Tensor,
@ -278,7 +278,7 @@ def torch_backend_mha_with_cache(
b, s = q.shape[:2]
# get cleaned up metadata
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
num_seq = num_prefill + num_decode
seq_len = seq_len[:num_seq]
input_pos = input_pos[:num_seq]
@ -352,7 +352,7 @@ def torch_backend_mha_with_cache_fake(
k: torch.Tensor,
v: torch.Tensor,
# STANDARD METADATA
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
cache_loc: torch.Tensor,
@ -400,7 +400,7 @@ class TorchBackendAttention(AttentionDescriptor):
@classmethod
def get_standard_metadata_args(cls) -> List[str]:
return ["batch_info", "seq_len", "input_pos", "cache_loc", "cu_seqlen"]
return ["batch_info_host", "seq_len", "input_pos", "cache_loc", "cu_seqlen"]
@classmethod
def get_cache_initializers(

View File

@ -5,14 +5,14 @@ import torch
def gather_logits_before_lm_head(
hidden_states: torch.Tensor,
logits_gather_indices: torch.Tensor, # long tensor
logits_gather_info: torch.Tensor, # int tensor
logits_gather_info_host: torch.Tensor, # int tensor
) -> torch.Tensor:
"""Gather hidden states using logits_gather_indices before LM head.
Args:
hidden_states: Hidden states tensor [b, 1, hidden] or [1, s_total, hidden]
logits_gather_indices: indices for gathering logits.
logits_gather_info: info for gathering logits.
logits_gather_info_host: info for gathering logits.
Returns:
Gathered and flattened hidden states [num_gathered_tokens, hidden]
"""
@ -21,7 +21,7 @@ def gather_logits_before_lm_head(
hidden_states = hidden_states.squeeze(int(is_decode_only))
# info object
num_tokens_to_gather, gather_required = logits_gather_info.tolist()
num_tokens_to_gather, gather_required = logits_gather_info_host.tolist()
if gather_required:
out = hidden_states.index_select(0, logits_gather_indices[:num_tokens_to_gather])
@ -34,7 +34,7 @@ def gather_logits_before_lm_head(
def gather_logits_before_lm_head_fake(
hidden_states: torch.Tensor,
logits_gather_indices: torch.Tensor,
logits_gather_info: torch.Tensor,
logits_gather_info_host: torch.Tensor,
) -> torch.Tensor:
# NOTE: shape is not correct in fake mode
# see https://github.com/NVIDIA/TensorRT-LLM/issues/9878

View File

@ -188,7 +188,7 @@ def flattened_mha_with_cache(
k: torch.Tensor,
v: torch.Tensor,
# STANDARD METADATA
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
cache_loc: torch.Tensor,
@ -210,7 +210,7 @@ def flattened_mha_with_cache(
NOTE: this op can also handle seq_len==0, which might be useful for CUDAGRAPH.
"""
# check for sequence info and truncate metadata
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
num_seq = num_prefill + num_decode
seq_len = seq_len[:num_seq]
@ -290,7 +290,7 @@ def flattened_mha_fake(
k: torch.Tensor,
v: torch.Tensor,
# STANDARD METADATA
batch_info: torch.Tensor,
batch_info_host: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
cache_loc: torch.Tensor,
@ -337,7 +337,7 @@ class TritonAttention(AttentionDescriptor):
@classmethod
def get_standard_metadata_args(cls) -> List[str]:
return ["batch_info", "seq_len", "input_pos", "cache_loc", "cu_seqlen"]
return ["batch_info_host", "seq_len", "input_pos", "cache_loc", "cu_seqlen"]
@classmethod
def get_cache_initializers(

View File

@ -49,15 +49,15 @@ def _bamba_mixer_torch_forward(
)
slot_idx_t = torch.arange(batch_size, device=input_states.device, dtype=torch.long)
use_initial_states_t = torch.zeros(batch_size, device=input_states.device, dtype=torch.bool)
# batch_info: [num_prefill, num_prefill_tokens, num_decode]
# batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
# For context phase (seq_len > 1): [batch_size, batch_size * seq_len, 0]
# For generate phase (seq_len == 1): [0, 0, batch_size]
if seq_len == 1:
batch_info_t = torch.tensor(
batch_info_host_t = torch.tensor(
[0, 0, batch_size], device=input_states.device, dtype=torch.int32
)
else:
batch_info_t = torch.tensor(
batch_info_host_t = torch.tensor(
[batch_size, batch_size * seq_len, 0], device=input_states.device, dtype=torch.int32
)
if use_caching:
@ -68,7 +68,7 @@ def _bamba_mixer_torch_forward(
self.conv1d.weight,
self.conv1d.bias,
# STANDARD METADATA
batch_info_t,
batch_info_host_t,
seq_len_t,
cu_seqlen_t,
slot_idx_t,
@ -123,7 +123,7 @@ def _bamba_mixer_torch_forward(
dt=dt,
dt_bias=self.dt_bias,
# STANDARD METADATA
batch_info=batch_info_t,
batch_info_host=batch_info_host_t,
seq_len=seq_len_t,
cu_seqlen=cu_seqlen_t,
slot_idx=slot_idx_t,

View File

@ -379,7 +379,7 @@ def maybe_pad_for_cuda_graph(func):
# check if we have a dummy request to use
if self.padding_dummy_request is None:
ad_logger.error("No CUDA graph padding possible due to missing dummy request.")
ad_logger.info("No CUDA graph padding possible due to missing dummy request.")
return _call_func()
# pad the scheduled requests with the dummy request

View File

@ -67,12 +67,14 @@ class GatherLogitsBeforeLmHeadTransform(BaseTransform):
# Add logits_gather_mask as input in the graph and the sequence info interface
logits_gather_indices_node = self._add_or_retrieve_input(gm, cm, "logits_gather_indices")
logits_gather_info_node = self._add_or_retrieve_input(gm, cm, "logits_gather_info")
logits_gather_info_host_node = self._add_or_retrieve_input(
gm, cm, "logits_gather_info_host"
)
with gm.graph.inserting_after(node_to_gather):
gathered_node = gm.graph.call_function(
torch.ops.auto_deploy.gather_logits_before_lm_head.default,
args=(node_to_gather, logits_gather_indices_node, logits_gather_info_node),
args=(node_to_gather, logits_gather_indices_node, logits_gather_info_host_node),
)
node_to_gather.replace_all_uses_with(gathered_node)
gathered_node.replace_input_with(gathered_node, node_to_gather)

View File

@ -1,5 +1,6 @@
"""Graph transformation to automatically add kv cache into fused MHA op."""
import inspect
import operator
from typing import Dict, List, Optional, Tuple, Type
@ -106,6 +107,23 @@ class InsertCachedAttention(BaseTransform):
gm, prep_meta_op, inputs_for_prep_meta, const_args, num_meta_out
)
def _process_metadata_host(self, cm: CachedSequenceInterface):
"""Process the host-side prepare metadata function."""
prep_meta_host_op = self.attn_descriptor.get_host_prepare_metadata_function()
if prep_meta_host_op is None:
return
# analyze the args of the host-side prepare metadata function using inspect
sig = inspect.signature(prep_meta_host_op)
args = sig.parameters.keys()
# check if all args are available in the cached sequence interface
unavailable_args = args - cm.info.available_args
assert not unavailable_args, f"Missing args in SequenceInfo: {unavailable_args=}"
# add the host-side prepare metadata function to the graph
cm.info.register_host_prepare_for_attention_forward(prep_meta_host_op, list(args))
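A standalone sketch of the inspect-based wiring in _process_metadata_host: the host function's parameter names are read from its signature and checked against the set of available SequenceInfo argument names before registration. The available_args set and prepare_host function below are illustrative.

import inspect
import torch

def prepare_host(batch_info_host: torch.Tensor, cu_num_pages_host: torch.Tensor) -> None:
    pass  # host-side planning would go here

available_args = {"batch_info_host", "cu_num_pages_host", "cache_loc_host", "last_page_len_host"}

sig = inspect.signature(prepare_host)
args = list(sig.parameters.keys())
unavailable_args = set(args) - available_args
assert not unavailable_args, f"Missing args in SequenceInfo: {unavailable_args=}"
print(args)  # ['batch_info_host', 'cu_num_pages_host']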
def _process_cache_node(self, gm: GraphModule, cache_name: str) -> Node:
"""Process the cache nodes by inserting a cached attention replacement op."""
return add_graph_input(gm, cache_name)
@ -173,6 +191,9 @@ class InsertCachedAttention(BaseTransform):
# insert metadata computation and extract each argument as a node
meta_nodes_extra = self._process_metadata_extra(gm, cm, source_attn_nodes[0])
# Register host-side prepare_metadata function for attention descriptor.
self._process_metadata_host(cm)
buffer_in_lookup: Dict[str, Node] = {}
# replace fused attention node with attention node that has kv cache
@ -213,11 +234,7 @@ class InsertCachedAttention(BaseTransform):
buffer_in_nodes,
constants,
)
# Attention descriptor should register its host function with SequenceInfo.
# This function will be called before graph invocation.
cm.info.register_host_prepare_for_attention_forward(
attn_descriptor.host_prepare_for_forward
)
num_cached_attn_replacements += 1
info = TransformInfo(

View File

@ -220,7 +220,6 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
@pytest.mark.skip_less_device_memory(32000)
def test_fp8(self):
kwargs = self.get_default_kwargs()
sampling_params = self.get_default_sampling_params()
with AutoDeployLLM(model=self.MODEL_PATH_FP8,
tokenizer=self.MODEL_PATH_FP8,
**kwargs) as llm:
@ -228,8 +227,8 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
llm.args.quant_config.quant_algo = QuantAlgo.FP8
llm.args.quant_config.kv_cache_quant_algo = QuantAlgo.FP8
task = MMLU(self.MODEL_NAME)
task.evaluate(llm, sampling_params=sampling_params)
# task = MMLU(self.MODEL_NAME)
# task.evaluate(llm, sampling_params=sampling_params)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

View File

@ -40,13 +40,13 @@ class TorchAttentionReference:
0, batch_size * seq_len, seq_len, device=q.device, dtype=torch.int32
)
# Create batch_info: [num_prefill, num_prefill_tokens, num_decode]
# Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
# For context phase (seq_len > 1): [batch_size, batch_size * seq_len, 0]
# For generate phase (seq_len == 1): [0, 0, batch_size]
if seq_len == 1:
batch_info = torch.tensor([0, 0, batch_size], device=q.device, dtype=torch.int32)
batch_info_host = torch.tensor([0, 0, batch_size], device=q.device, dtype=torch.int32)
else:
batch_info = torch.tensor(
batch_info_host = torch.tensor(
[batch_size, batch_size * seq_len, 0], device=q.device, dtype=torch.int32
)
@ -60,7 +60,7 @@ class TorchAttentionReference:
q_flat,
k_flat,
v_flat,
batch_info,
batch_info_host,
seq_len_tensor,
input_positions,
cache_loc,
@ -84,7 +84,7 @@ class TorchAttentionReference:
q,
k,
v,
batch_info,
batch_info_host,
seq_len,
input_positions,
cache_loc,
@ -101,7 +101,7 @@ class TorchAttentionReference:
q,
k,
v,
batch_info,
batch_info_host,
seq_len,
input_positions,
cache_loc,
@ -144,15 +144,15 @@ class TorchAttentionReference:
k_flat = k_new.view(1, batch_size, -1)
v_flat = v_new.view(1, batch_size, -1)
# Create batch_info for decode phase: [num_prefill, num_prefill_tokens, num_decode]
batch_info = torch.tensor([0, 0, batch_size], device=q.device, dtype=torch.int32)
# Create batch_info_host for decode phase: [num_prefill, num_prefill_tokens, num_decode]
batch_info_host = torch.tensor([0, 0, batch_size], device=q.device, dtype=torch.int32)
# Call torch backend via custom op registry
output_flat = torch.ops.auto_deploy.torch_cached_attention_with_cache(
q_flat,
k_flat,
v_flat,
batch_info,
batch_info_host,
seq_len,
input_positions,
cache_loc,
@ -170,7 +170,7 @@ class TorchAttentionReference:
q,
k,
v,
batch_info,
batch_info_host,
seq_len,
input_positions,
cache_loc,
@ -189,7 +189,7 @@ class TorchAttentionReference:
q,
k,
v,
batch_info,
batch_info_host,
seq_len,
input_positions,
cache_loc,

View File

@ -125,9 +125,9 @@ def test_flat_gqa_op(
k = torch.randn(1, seq_len.sum(), n_kv_heads * D_HEAD, **dtype_kwargs)
v = torch.randn(1, seq_len.sum(), n_kv_heads * D_HEAD, **dtype_kwargs)
# create batch_info: [num_prefill, num_prefill_tokens, num_decode]
# create batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
num_prefill_tokens = seq_len[:num_context].sum()
batch_info = torch.tensor([num_context, num_prefill_tokens, num_generate], **int_kwargs)
batch_info_host = torch.tensor([num_context, num_prefill_tokens, num_generate], **int_kwargs)
# run op
output = torch.ops.auto_deploy.triton_attention_flattened_mha_with_cache(
@ -136,7 +136,7 @@ def test_flat_gqa_op(
k,
v,
# STANDARD METADATA
batch_info,
batch_info_host,
seq_len,
input_positions,
cache_loc,
@ -150,7 +150,7 @@ def test_flat_gqa_op(
# Use torch backend as clean reference
ref_flat = TorchAttentionReference.flattened_mha_with_cache(
q, k, v, batch_info, seq_len, input_positions, cache_loc, seq_start, k_cache, v_cache
q, k, v, batch_info_host, seq_len, input_positions, cache_loc, seq_start, k_cache, v_cache
)
assert torch.allclose(

View File

@ -59,9 +59,9 @@ def test_generate_only_with_slot_mapping_cuda(conv_env):
# Metadata (not used in generate-only op entry, but required by the interface)
cu_seqlen = torch.zeros(batch, device=device, dtype=torch.int32)
use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool)
# batch_info: [num_prefill, num_prefill_tokens, num_decode]
# batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
# For generate-only: num_decode = batch, num_prefill = 0
batch_info = torch.tensor([0, 0, batch], device=device, dtype=torch.int32)
batch_info_host = torch.tensor([0, 0, batch], device=device, dtype=torch.int32)
# Snapshot caches for reference before running op (op mutates caches)
gathered_before = conv_state_cache.clone().index_select(0, slot_idx)
x_ref = x.clone()
@ -72,7 +72,7 @@ def test_generate_only_with_slot_mapping_cuda(conv_env):
w,
b,
# STANDARD METADATA
batch_info,
batch_info_host,
cu_seqlen,
slot_idx,
use_initial_states,

View File

@ -64,6 +64,12 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
paged_kv_indices = torch.arange(BATCH_SIZE).int().to(device)
paged_kv_last_page_len = offsets + seq_len_tensor
# Host copies of metadata
qo_indptr_host = qo_indptr.cpu()
paged_kv_indptr_host = paged_kv_indptr.cpu()
paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
# Q,K,V are computed using GEMM.
q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
k = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@ -88,8 +94,8 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
),
BATCH_SIZE * SEQ_LEN,
)
# Create batch_info: [num_prefill, num_prefill_tokens, num_decode]
batch_info = torch.tensor(
# Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
batch_info_host = torch.tensor(
[BATCH_SIZE, BATCH_SIZE * SEQ_LEN, 0], dtype=torch.int32, device=device
)
flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
@ -98,11 +104,14 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
k,
v,
# STANDARD METADATA
batch_info,
qo_indptr,
batch_info_host,
qo_indptr_host,
paged_kv_indptr,
paged_kv_indptr_host,
paged_kv_indices,
paged_kv_last_page_len,
paged_kv_last_page_len_host,
seq_len_with_cache_host,
# EXTRA METADATA
batch_indices,
positions,
@ -168,6 +177,12 @@ def test_flashinfer_attention_op_decode(
paged_kv_indices = torch.arange(BATCH_SIZE).int().to(device)
paged_kv_last_page_len = offsets + seq_len_tensor
# Host copies of metadata
qo_indptr_host = qo_indptr.cpu()
paged_kv_indptr_host = paged_kv_indptr.cpu()
paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
# Q,K,V are computed using GEMM.
q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
k = torch.ones(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@ -224,20 +239,23 @@ def test_flashinfer_attention_op_decode(
),
BATCH_SIZE * SEQ_LEN,
)
# Create batch_info: [num_prefill, num_prefill_tokens, num_decode]
# Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
# For decode phase: num_decode = BATCH_SIZE, num_prefill = 0
batch_info = torch.tensor([0, 0, BATCH_SIZE], dtype=torch.int32, device=device)
batch_info_host = torch.tensor([0, 0, BATCH_SIZE], dtype=torch.int32, device=device)
flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
# Q, K, V
q,
k,
v,
# STANDARD METADATA
batch_info,
qo_indptr,
batch_info_host,
qo_indptr_host,
paged_kv_indptr,
paged_kv_indptr_host,
paged_kv_indices,
paged_kv_last_page_len,
paged_kv_last_page_len_host,
seq_len_with_cache_host,
# EXTRA METADATA
batch_indices,
positions,
@ -323,6 +341,12 @@ def test_flashinfer_attention_context_and_generate(
paged_kv_indices = torch.arange(BATCH_SIZE).int().to(device)
paged_kv_last_page_len = offsets + seq_len_tensor
# Host copies of metadata
qo_indptr_host = qo_indptr.cpu()
paged_kv_indptr_host = paged_kv_indptr.cpu()
paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
# Q,K,V for prefill phase
q_1 = torch.randn(BATCH_SIZE, PREFILL_SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
k_1 = torch.randn(BATCH_SIZE, PREFILL_SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@ -347,8 +371,8 @@ def test_flashinfer_attention_context_and_generate(
),
BATCH_SIZE * PREFILL_SEQ_LEN,
)
# Create batch_info: [num_prefill, num_prefill_tokens, num_decode]
batch_info = torch.tensor(
# Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
batch_info_host = torch.tensor(
[BATCH_SIZE, BATCH_SIZE * PREFILL_SEQ_LEN, 0], dtype=torch.int32, device=device
)
flashinfer_output_1 = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
@ -357,11 +381,14 @@ def test_flashinfer_attention_context_and_generate(
k_1,
v_1,
# STANDARD METADATA
batch_info,
qo_indptr,
batch_info_host,
qo_indptr_host,
paged_kv_indptr,
paged_kv_indptr_host,
paged_kv_indices,
paged_kv_last_page_len,
paged_kv_last_page_len_host,
seq_len_with_cache_host,
# EXTRA METADATA
batch_indices,
positions,
@ -415,6 +442,12 @@ def test_flashinfer_attention_context_and_generate(
paged_kv_indices = torch.arange(BATCH_SIZE).int().to(device)
paged_kv_last_page_len = offsets + seq_len_tensor
# Host copies of metadata
qo_indptr_host = qo_indptr.cpu()
paged_kv_indptr_host = paged_kv_indptr.cpu()
paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
# Q,K,V are computed using GEMM.
q_3 = torch.randn(BATCH_SIZE, 1, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
k_3 = torch.randn(BATCH_SIZE, 1, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@ -430,19 +463,22 @@ def test_flashinfer_attention_context_and_generate(
),
BATCH_SIZE * 1,
)
# Create batch_info: [num_prefill, num_prefill_tokens, num_decode]
batch_info = torch.tensor([0, 0, BATCH_SIZE], dtype=torch.int32, device=device)
# Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
batch_info_host = torch.tensor([0, 0, BATCH_SIZE], dtype=torch.int32, device=device)
flashinfer_output_3 = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
# Q, K, V
q_3,
k_3,
v_3,
# STANDARD METADATA
batch_info,
qo_indptr,
batch_info_host,
qo_indptr_host,
paged_kv_indptr,
paged_kv_indptr_host,
paged_kv_indices,
paged_kv_last_page_len,
paged_kv_last_page_len_host,
seq_len_with_cache_host,
# EXTRA METADATA
batch_indices,
positions,
@ -519,6 +555,12 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
paged_kv_indices = torch.arange(BATCH_SIZE).int().to(device)
paged_kv_last_page_len = offsets + seq_len_tensor
# Host copies of metadata
qo_indptr_host = qo_indptr.cpu()
paged_kv_indptr_host = paged_kv_indptr.cpu()
paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
# Q,K,V are computed using GEMM.
q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
k = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@ -543,8 +585,8 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
),
BATCH_SIZE * SEQ_LEN,
)
# Create batch_info: [num_prefill, num_prefill_tokens, num_decode]
batch_info = torch.tensor(
# Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
batch_info_host = torch.tensor(
[BATCH_SIZE, BATCH_SIZE * SEQ_LEN, 0], dtype=torch.int32, device=device
)
flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
@ -553,11 +595,14 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
k,
v,
# STANDARD METADATA
batch_info,
qo_indptr,
batch_info_host,
qo_indptr_host,
paged_kv_indptr,
paged_kv_indptr_host,
paged_kv_indices,
paged_kv_last_page_len,
paged_kv_last_page_len_host,
seq_len_with_cache_host,
# EXTRA METADATA
batch_indices,
positions,
@ -642,6 +687,12 @@ def test_flashinfer_attention_with_fp8_cache(
paged_kv_indices = torch.arange(BATCH_SIZE).int().to(device)
paged_kv_last_page_len = offsets + seq_len_tensor
# Host copies of metadata
qo_indptr_host = qo_indptr.cpu()
paged_kv_indptr_host = paged_kv_indptr.cpu()
paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
# Q,K,V are computed using GEMM, in fp16
q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
k = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@ -696,8 +747,8 @@ def test_flashinfer_attention_with_fp8_cache(
),
BATCH_SIZE * SEQ_LEN,
)
# Create batch_info: [num_prefill, num_prefill_tokens, num_decode]
batch_info = torch.tensor(
# Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
batch_info_host = torch.tensor(
[BATCH_SIZE, BATCH_SIZE * SEQ_LEN, 0], dtype=torch.int32, device=device
)
flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
@ -706,11 +757,14 @@ def test_flashinfer_attention_with_fp8_cache(
k,
v,
# STANDARD METADATA
batch_info,
qo_indptr,
batch_info_host,
qo_indptr_host,
paged_kv_indptr,
paged_kv_indptr_host,
paged_kv_indices,
paged_kv_last_page_len,
paged_kv_last_page_len_host,
seq_len_with_cache_host,
# EXTRA METADATA
batch_indices,
positions,
@ -787,6 +841,12 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
)
paged_kv_last_page_len = ((offsets + seq_len_tensor - 1) % PAGE_SIZE) + 1
# Host copies of metadata
qo_indptr_host = qo_indptr.cpu()
paged_kv_indptr_host = paged_kv_indptr.cpu()
paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
# make sure planner is initialized
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
_GlobalFlashInferPlanner.init_workspace(workspace)
@ -798,19 +858,22 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
),
SEQ_LEN,
)
# Create batch_info: [num_prefill, num_prefill_tokens, num_decode]
batch_info = torch.tensor([BATCH_SIZE, SEQ_LEN, 0], dtype=torch.int32, device=device)
# Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
batch_info_host = torch.tensor([BATCH_SIZE, SEQ_LEN, 0], dtype=torch.int32, device=device)
flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
# Q, K, V
q,
k,
v,
# STANDARD METADATA
batch_info,
qo_indptr,
batch_info_host,
qo_indptr_host,
paged_kv_indptr,
paged_kv_indptr_host,
paged_kv_indices,
paged_kv_last_page_len,
paged_kv_last_page_len_host,
seq_len_with_cache_host,
# EXTRA METADATA
batch_indices,
positions,
@ -875,6 +938,12 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
)
paged_kv_last_page_len2 = ((offsets2 + seq_len_tensor2 - 1) % PAGE_SIZE) + 1
# Host copies of metadata
qo_indptr2_host = qo_indptr2.cpu()
paged_kv_indptr2_host = paged_kv_indptr2.cpu()
paged_kv_last_page_len2_host = paged_kv_last_page_len2.cpu()
seq_len_with_cache2_host = (offsets2 + seq_len_tensor2).cpu()
# Create FlashInferAttention class before calling the custom op
_GlobalFlashInferPlanner.reset()
@ -885,19 +954,22 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
),
BATCH_SIZE * 1,
)
# Create batch_info: [num_prefill, num_prefill_tokens, num_decode]
batch_info = torch.tensor([0, 0, BATCH_SIZE], dtype=torch.int32, device=device)
# Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
batch_info_host = torch.tensor([0, 0, BATCH_SIZE], dtype=torch.int32, device=device)
flashinfer_output_gen = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
# Q, K, V
q_gen,
k_gen,
v_gen,
# STANDARD METADATA
batch_info,
qo_indptr2,
batch_info_host,
qo_indptr2_host,
paged_kv_indptr2,
paged_kv_indptr2_host,
paged_kv_indices2,
paged_kv_last_page_len2,
paged_kv_last_page_len2_host,
seq_len_with_cache2_host,
# EXTRA METADATA
batch_indices,
positions,

View File

@ -246,14 +246,16 @@ class TestTorchBackendAttention:
if seq_len == 1:
# Generate phase: [num_prefill, num_prefill_tokens, num_decode]
batch_info = torch.tensor([0, 0, batch_size], device=self.device, dtype=torch.int32)
batch_info_host = torch.tensor(
[0, 0, batch_size], device=self.device, dtype=torch.int32
)
seq_start = torch.arange(batch_size, device=self.device, dtype=torch.int32)
q_flat = q.view(batch_size, seq_len, -1)
k_flat = k.view(batch_size, seq_len, -1)
v_flat = v.view(batch_size, seq_len, -1)
else:
# Context phase: [num_prefill, num_prefill_tokens, num_decode]
batch_info = torch.tensor(
batch_info_host = torch.tensor(
[batch_size, batch_size * seq_len, 0], device=self.device, dtype=torch.int32
)
seq_start = torch.arange(
@ -267,7 +269,7 @@ class TestTorchBackendAttention:
"q": q_flat,
"k": k_flat,
"v": v_flat,
"batch_info": batch_info,
"batch_info_host": batch_info_host,
"seq_len": seq_len_tensor,
"input_pos": input_positions,
"cache_loc": cache_loc,
@ -286,7 +288,7 @@ class TestTorchBackendAttention:
data["k"],
data["v"],
# STANDARD METADATA
data["batch_info"],
data["batch_info_host"],
data["seq_len"],
data["input_pos"],
data["cache_loc"],

View File

@ -59,9 +59,9 @@ def test_generate_only_with_slot_mapping(conv_env):
# Snapshot caches for reference before running op (op mutates caches)
gathered_before = conv_state_cache.clone().index_select(0, slot_idx)
use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool)
# batch_info: [num_prefill, num_prefill_tokens, num_decode]
# batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
# For generate-only: num_decode = batch, num_prefill = 0
batch_info = torch.tensor([0, 0, batch], device=device, dtype=torch.int32)
batch_info_host = torch.tensor([0, 0, batch], device=device, dtype=torch.int32)
# Run cached op
y = torch.ops.auto_deploy.torch_cached_causal_conv1d(
# INPUTS
@ -69,7 +69,7 @@ def test_generate_only_with_slot_mapping(conv_env):
w,
b,
# STANDARD METADATA
batch_info,
batch_info_host,
seq_len,
cu_seqlen,
slot_idx,
@ -124,18 +124,20 @@ def test_context_flattened_and_state_writeback(conv_env):
seq_len = torch.tensor(lens, device=device, dtype=torch.int32)
cu_seqlen = torch.tensor([0, lens[0]], device=device, dtype=torch.int32)
use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool)
# batch_info: [num_prefill, num_prefill_tokens, num_decode]
# batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
# For context/prefill phase: num_prefill = len(lens), num_decode = 0
num_seqs = len(lens)
num_prefill_tokens = sum(lens)
batch_info = torch.tensor([num_seqs, num_prefill_tokens, 0], device=device, dtype=torch.int32)
batch_info_host = torch.tensor(
[num_seqs, num_prefill_tokens, 0], device=device, dtype=torch.int32
)
y = torch.ops.auto_deploy.torch_cached_causal_conv1d(
# INPUTS
x,
w,
b,
# STANDARD METADATA
batch_info,
batch_info_host,
seq_len,
cu_seqlen,
slot_idx,

View File

@ -66,9 +66,9 @@ def test_generate_only_with_slot_mapping(mamba_env):
seq_len = torch.ones(batch, device=device, dtype=torch.int32)
cu_seqlen = torch.zeros(batch, device=device, dtype=torch.int32)
use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool)
# batch_info: [num_prefill, num_prefill_tokens, num_decode]
# batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
# For generate-only: num_decode = batch, num_prefill = 0
batch_info = torch.tensor([0, 0, batch], device=device, dtype=torch.int32)
batch_info_host = torch.tensor([0, 0, batch], device=device, dtype=torch.int32)
# Snapshot caches for reference before running op (op mutates caches)
gathered_before = ssm_state_cache.clone().index_select(0, slot_idx)
@ -83,7 +83,7 @@ def test_generate_only_with_slot_mapping(mamba_env):
dt,
dt_bias,
# STANDARD METADATA
batch_info,
batch_info_host,
seq_len,
cu_seqlen,
slot_idx,
@ -141,11 +141,13 @@ def test_context_flattened_and_state_writeback(mamba_env):
seq_len = torch.tensor(lens, device=device, dtype=torch.int32)
cu_seqlen = torch.tensor([0, lens[0]], device=device, dtype=torch.int32)
use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool)
# batch_info: [num_prefill, num_prefill_tokens, num_decode]
# batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
# For context/prefill phase: num_prefill = len(lens), num_decode = 0
num_seqs = len(lens)
num_prefill_tokens = sum(lens)
batch_info = torch.tensor([num_seqs, num_prefill_tokens, 0], device=device, dtype=torch.int32)
batch_info_host = torch.tensor(
[num_seqs, num_prefill_tokens, 0], device=device, dtype=torch.int32
)
y = torch.ops.auto_deploy.torch_cached_ssm(
# INPUTS
hidden_states,
@ -156,7 +158,7 @@ def test_context_flattened_and_state_writeback(mamba_env):
dt,
dt_bias,
# STANDARD METADATA
batch_info,
batch_info_host,
seq_len,
cu_seqlen,
slot_idx,

View File

@ -134,8 +134,8 @@ def test_triton_context_flattened_and_state_writeback(mamba_env):
torch.arange(len(lens), device=device, dtype=torch.int32),
seq_len,
).view(1, -1)
# batch_info: [num_prefill, num_prefill_tokens, num_decode]
batch_info_tensor = torch.tensor([len(lens), sum(lens), 0], dtype=torch.int32, device=device)
# batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
batch_info_host = torch.tensor([len(lens), sum(lens), 0], dtype=torch.int32, device=device)
# Torch reference
y_torch = torch.ops.auto_deploy.torch_cached_ssm(
hidden_states,
@ -146,7 +146,7 @@ def test_triton_context_flattened_and_state_writeback(mamba_env):
dt,
dt_bias,
# STANDARD METADATA
batch_info_tensor,
batch_info_host,
seq_len,
cu_seqlen,
slot_idx,
@ -168,7 +168,7 @@ def test_triton_context_flattened_and_state_writeback(mamba_env):
dt,
dt_bias,
# STANDARD METADATA
batch_info_tensor,
batch_info_host,
cu_seqlens,
slot_idx,
use_initial_states,

View File

@ -56,10 +56,10 @@ class TestGatherLogitsBeforeLmHeadOp:
# Create gather info: num_tokens_to_gather=batch_size, gather_required=0 (False)
logits_gather_indices = torch.arange(batch_size, dtype=torch.long, device="cuda")
logits_gather_info = torch.tensor([batch_size, 0], dtype=torch.int32, device="cuda")
logits_gather_info_host = torch.tensor([batch_size, 0], dtype=torch.int32, device="cpu")
output = torch.ops.auto_deploy.gather_logits_before_lm_head.default(
hidden_states, logits_gather_indices, logits_gather_info
hidden_states, logits_gather_indices, logits_gather_info_host
)
# Should return [batch, 1, hidden] for generate format (3D shape preserved)
@ -82,10 +82,10 @@ class TestGatherLogitsBeforeLmHeadOp:
gather_indices = torch.arange(0, num_gather, dtype=torch.long, device="cuda")
# Create gather info: num_tokens_to_gather=num_gather, gather_required=1 (True)
logits_gather_info = torch.tensor([num_gather, 1], dtype=torch.int32, device="cuda")
logits_gather_info_host = torch.tensor([num_gather, 1], dtype=torch.int32, device="cpu")
output = torch.ops.auto_deploy.gather_logits_before_lm_head.default(
hidden_states, gather_indices, logits_gather_info
hidden_states, gather_indices, logits_gather_info_host
)
# Should return [1, num_gather, hidden] for packed format (3D shape preserved)
@ -105,15 +105,15 @@ class TestGatherLogitsBeforeLmHeadOp:
# Create gather info
logits_gather_indices = torch.arange(batch_size, dtype=torch.long, device="cuda")
logits_gather_info = torch.tensor([batch_size, 0], dtype=torch.int32, device="cuda")
logits_gather_info_host = torch.tensor([batch_size, 0], dtype=torch.int32, device="cpu")
# Use fake implementation directly
with FakeTensorMode() as mode:
hidden_states_fake = mode.from_tensor(hidden_states)
logits_gather_indices_fake = mode.from_tensor(logits_gather_indices)
logits_gather_info_fake = mode.from_tensor(logits_gather_info)
logits_gather_info_host_fake = mode.from_tensor(logits_gather_info_host)
output = torch.ops.auto_deploy.gather_logits_before_lm_head.default(
hidden_states_fake, logits_gather_indices_fake, logits_gather_info_fake
hidden_states_fake, logits_gather_indices_fake, logits_gather_info_host_fake
)
# Should return [batch, 1, hidden_size] (fake returns empty_like which preserves 3D shape)
@ -132,15 +132,15 @@ class TestGatherLogitsBeforeLmHeadOp:
# Create gather info
logits_gather_indices = torch.arange(num_gather, dtype=torch.long, device="cuda")
logits_gather_info = torch.tensor([num_gather, 1], dtype=torch.int32, device="cuda")
logits_gather_info_host = torch.tensor([num_gather, 1], dtype=torch.int32, device="cpu")
# Use fake implementation directly
with FakeTensorMode() as mode:
hidden_states_fake = mode.from_tensor(hidden_states)
logits_gather_indices_fake = mode.from_tensor(logits_gather_indices)
logits_gather_info_fake = mode.from_tensor(logits_gather_info)
logits_gather_info_host_fake = mode.from_tensor(logits_gather_info_host)
output = torch.ops.auto_deploy.gather_logits_before_lm_head.default(
hidden_states_fake, logits_gather_indices_fake, logits_gather_info_fake
hidden_states_fake, logits_gather_indices_fake, logits_gather_info_host_fake
)
# The fake implementation returns empty_like which preserves input shape [1, total_tokens, hidden]
@ -217,13 +217,13 @@ class TestGatherLogitsBeforeLmHeadTransform:
# Test forward pass
# We must pass the new graph inputs manually since we are running the graph directly
logits_gather_indices = torch.arange(batch_size, dtype=torch.long, device="cuda")
logits_gather_info = torch.tensor([batch_size, 0], dtype=torch.int32, device="cuda")
logits_gather_info_host = torch.tensor([batch_size, 0], dtype=torch.int32, device="cpu")
output = gm_transformed(
hidden_states,
logit_gather_ids,
seq_len,
logits_gather_indices=logits_gather_indices,
logits_gather_info=logits_gather_info,
logits_gather_info_host=logits_gather_info_host,
)
# Output should be [batch_size, 1, vocab_size] since gather now returns 3D
assert output.shape == (batch_size, 1, vocab_size)
@ -278,13 +278,13 @@ class TestGatherLogitsBeforeLmHeadTransform:
# We must pass the new graph inputs manually since we are running the graph directly
num_gather = len(logit_gather_ids)
logits_gather_indices = logit_gather_ids
logits_gather_info = torch.tensor([num_gather, 1], dtype=torch.int32, device="cuda")
logits_gather_info_host = torch.tensor([num_gather, 1], dtype=torch.int32, device="cpu")
output = gm_transformed(
hidden_states,
logit_gather_ids_padded,
seq_len,
logits_gather_indices=logits_gather_indices,
logits_gather_info=logits_gather_info,
logits_gather_info_host=logits_gather_info_host,
)
# Output should be [1, num_gather, vocab_size] since gather now returns 3D
assert output.shape == (1, num_gather, vocab_size)