[None][feat] AutoDeploy: Flashinfer kernels bringup (#10867)

Signed-off-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com>
Chenghao Zhang 2026-01-29 14:59:29 -08:00 committed by GitHub
parent 0ad87895f5
commit e033929221
9 changed files with 716 additions and 230 deletions

View File

@@ -113,7 +113,7 @@ class _FlashInferPlanner:
self.prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
self.kv_layout,
backend="fa2",
backend="fa2" if torch.cuda.get_device_capability(0) == (9, 0) else "auto",
)
self.decode_wrapper = self._init_decode_wrapper()

View File

@@ -0,0 +1,246 @@
from typing import List
import torch
from torch.fx import Node
from .....llmapi.llm_args import KvCacheConfig
from ..attention_interface import (
AttentionRegistry,
MHACallable,
ResourceHandler,
ResourceHandlerDict,
SequenceInfo,
)
from .mamba_backend_common import (
BaseBackendSSM,
_flatten_ssm_inputs,
_prepare_ssm_decode_inputs,
_run_ssm_prefill,
)
@torch.library.custom_op("auto_deploy::flashinfer_cached_ssm", mutates_args={})
def _flashinfer_cached_ssm(
# INPUTS (dense but may be flattened across sequences)
hidden_states: torch.Tensor, # [b, s, num_heads, head_dim]
A: torch.Tensor, # [num_heads]
B: torch.Tensor, # [b, s, n_groups, ssm_state_size]
C: torch.Tensor, # [b, s, n_groups, ssm_state_size]
D: torch.Tensor, # [num_heads]
dt: torch.Tensor, # [b, s, num_heads]
dt_bias: torch.Tensor, # [num_heads]
# STANDARD METADATA
batch_info_host: torch.Tensor,
cu_seqlen: torch.Tensor,
slot_idx: torch.Tensor,
use_initial_states: torch.Tensor,
# EXTRA METADATA
chunk_indices: torch.Tensor, # [num_logical_chunks]
chunk_offsets: torch.Tensor, # [num_logical_chunks]
seq_idx_prefill: torch.Tensor, # [1, num_prefill_tokens]
# CACHES
ssm_state_cache: torch.Tensor, # [max_batch_size, num_heads, head_dim, ssm_state_size]
# CONSTANTS
time_step_limit: List[float],
chunk_size: int,
) -> torch.Tensor:
b, s, num_heads, head_dim, bs, hs_flat, B_flat, C_flat, dt_flat = _flatten_ssm_inputs(
hidden_states, B, C, dt
)
ssm_state_size = B.shape[3]
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
num_seq = num_prefill + num_decode
num_total_tokens = num_prefill_tokens + num_decode
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
preallocated_ssm_out = torch.empty(
[bs, num_heads, head_dim],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
preallocated_ssm_out_p = preallocated_ssm_out[:num_prefill_tokens]
num_prefill, num_prefill_tokens, num_total_tokens, num_seq = _run_ssm_prefill(
hs_flat,
B_flat,
C_flat,
dt_flat,
A,
D,
dt_bias,
batch_info_host,
cu_seqlen,
slot_idx,
use_initial_states,
chunk_indices,
chunk_offsets,
seq_idx_prefill,
ssm_state_cache,
time_step_limit,
chunk_size,
preallocated_ssm_out_p.unsqueeze(0),
)
num_decode = num_total_tokens - num_prefill_tokens
decode_inputs = _prepare_ssm_decode_inputs(
hs_flat,
B_flat,
C_flat,
dt_flat,
A,
D,
dt_bias,
slot_idx,
num_prefill,
num_prefill_tokens,
num_seq,
num_total_tokens,
num_heads,
head_dim,
ssm_state_size,
)
y_decode = None
if decode_inputs is not None:
(
slot_idx_decode,
x_decode,
B_decode,
C_decode,
dt_hp,
dt_bias_hp,
A_full,
D_full,
) = decode_inputs
import flashinfer
slot_idx_decode_i32 = slot_idx_decode.to(torch.int32)
y_decode = flashinfer.mamba.selective_state_update(
ssm_state_cache,
x_decode,
dt_hp,
A_full,
B_decode,
C_decode,
D=D_full,
z=None,
dt_bias=dt_bias_hp,
dt_softplus=True,
state_batch_indices=slot_idx_decode_i32,
)
preallocated_ssm_out[num_prefill_tokens:num_total_tokens].copy_(y_decode)
if num_total_tokens > 0:
return (
preallocated_ssm_out[:num_total_tokens]
.view(b, s, num_heads, head_dim)
.to(hidden_states.dtype)
)
else:
return torch.empty_like(hidden_states)
@_flashinfer_cached_ssm.register_fake
def _flashinfer_cached_ssm_fake(
# INPUTS (dense but may be flattened across sequences)
hidden_states: torch.Tensor, # [b, s, num_heads, head_dim]
A: torch.Tensor, # [num_heads]
B: torch.Tensor, # [b, s, n_groups, ssm_state_size]
C: torch.Tensor, # [b, s, n_groups, ssm_state_size]
D: torch.Tensor, # [num_heads]
dt: torch.Tensor, # [b, s, num_heads]
dt_bias: torch.Tensor, # [num_heads]
# STANDARD METADATA
batch_info_host: torch.Tensor,
cu_seqlen: torch.Tensor,
slot_idx: torch.Tensor,
use_initial_states: torch.Tensor,
# EXTRA METADATA
chunk_indices: torch.Tensor, # [num_logical_chunks]
chunk_offsets: torch.Tensor, # [num_logical_chunks]
seq_idx_prefill: torch.Tensor, # [1, num_prefill_tokens]
# CACHES
ssm_state_cache: torch.Tensor, # [max_batch_size, num_heads, head_dim, ssm_state_size]
# CONSTANTS
time_step_limit: List[float],
chunk_size: int,
):
# Return a correctly-shaped tensor for tracing with fake tensors
return torch.empty_like(
hidden_states,
memory_format=torch.contiguous_format,
dtype=hidden_states.dtype,
)
# Flashinfer's selective_state_update kernel only supports these head dimensions
FLASHINFER_SUPPORTED_HEAD_DIMS = [64, 128]
class FlashInferStateResourceHandler(ResourceHandler):
"""Handler for flashinfer SSM state resources.
Unlike the default StateResourceHandler, which uses byte-level pooling (and thus
hands out non-contiguous strided views), this handler allocates a separate
contiguous buffer, since flashinfer's selective_state_update kernel requires the
entire state tensor to be contiguous.
"""
def __init__(self, *state_shape: int, dtype: torch.dtype) -> None:
self.state_shape = state_shape
self.dtype = dtype
def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor:
"""Allocate a contiguous state buffer for flashinfer."""
return torch.empty(
sequence_info.max_num_state_slots,
*self.state_shape,
device=sequence_info.device,
dtype=self.dtype,
)
@AttentionRegistry.register("flashinfer_ssm")
class FlashinferBackendSSM(BaseBackendSSM):
@classmethod
def get_cached_attention_op(cls) -> MHACallable:
return torch.ops.auto_deploy.flashinfer_cached_ssm.default
@classmethod
def get_cache_initializers(
cls, source_attn_node: Node, cache_config: KvCacheConfig
) -> ResourceHandlerDict:
"""Get cache initializers using FlashInferStateResourceHandler.
We use a custom handler that allocates contiguous buffers directly, rather than
the default StateResourceHandler, which carves non-contiguous views out of a
shared byte buffer; flashinfer's selective_state_update kernel requires
contiguous state tensors.
"""
# Shapes from fake tensors
hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
B_fake: torch.Tensor = source_attn_node.args[2].meta["val"]
num_heads = hs_fake.shape[-2]
head_dim = hs_fake.shape[-1]
# Validate head_dim is supported by flashinfer
if head_dim not in FLASHINFER_SUPPORTED_HEAD_DIMS:
raise ValueError(
f"Flashinfer SSM backend only supports head_dim in {FLASHINFER_SUPPORTED_HEAD_DIMS}, "
f"but got head_dim={head_dim}. Consider using 'triton_ssm' backend instead."
)
if B_fake.ndim >= 4:
ssm_state_size = B_fake.shape[-1]
else:
ssm_state_size = max(1, B_fake.shape[-1])
# Extract ssm_state_dtype from cache_config or hs_fake
ssm_state_dtype = cls.resolve_cache_dtype(cache_config.mamba_ssm_cache_dtype, hs_fake.dtype)
return {
"ssm_state_cache": FlashInferStateResourceHandler(
num_heads, head_dim, ssm_state_size, dtype=ssm_state_dtype
)
}
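
For context, the new backend is opt-in via the AutoDeploy transforms config. A minimal sketch of how it is wired up, mirroring the accuracy-test configuration further down (the llm_kwargs name is illustrative, not an API):

llm_kwargs = {
    "transforms": {
        "insert_cached_ssm_attention": {
            "backend": "flashinfer_ssm",
            # The accuracy test below also pins the SSM state-cache dtype.
            "cache_config": {"mamba_dtype": "bfloat16"},
        },
    },
}
# Passed through to AutoDeployLLM together with model/tokenizer, as in the accuracy test below.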

View File

@@ -0,0 +1,291 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple
import torch
from torch._ops import OpOverloadPacket
from torch.fx import Node
from tensorrt_llm._torch.modules.mamba.mamba2_metadata import cu_seqlens_to_chunk_indices_offsets
from tensorrt_llm._torch.modules.mamba.ssd_combined import mamba_chunk_scan_combined
from .....llmapi.llm_args import KvCacheConfig
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
Constant,
PrepareMetadataCallable,
ResourceHandlerDict,
StateResourceHandler,
)
@torch.library.custom_op("auto_deploy::mamba_ssm_prepare_metadata", mutates_args=())
def _mamba_ssm_prepare_metadata(
# INPUTS
position_ids: torch.Tensor,
batch_info_host: torch.Tensor,
seq_len: torch.Tensor,
cu_seqlen: torch.Tensor,
# EXTRA METADATA PROVIDED BY THE DESCRIPTOR
chunk_size: int,
) -> List[torch.Tensor]:
"""Prepare metadata for cached SSM transform.
Returns a tuple of (chunk_indices, chunk_offsets, seq_idx_prefill).
"""
device = cu_seqlen.device
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
if num_prefill > 0:
chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(
cu_seqlen[: num_prefill + 1], chunk_size
)
seq_idx_prefill = torch.repeat_interleave(
torch.arange(num_prefill, device=device, dtype=torch.int32), seq_len[:num_prefill]
).view(1, -1)
else:
chunk_indices = torch.empty(0, dtype=torch.int32, device=device)
chunk_offsets = torch.empty(0, dtype=torch.int32, device=device)
seq_idx_prefill = torch.empty(1, 0, dtype=torch.int32, device=device)
return (chunk_indices, chunk_offsets, seq_idx_prefill)
@_mamba_ssm_prepare_metadata.register_fake
def _mamba_ssm_prepare_metadata_fake(
# INPUTS
position_ids: torch.Tensor,
batch_info_host: torch.Tensor,
seq_len: torch.Tensor,
cu_seqlen: torch.Tensor,
# EXTRA METADATA PROVIDED BY THE DESCRIPTOR
chunk_size: int,
):
b, s = position_ids.shape[:2]
num_tokens = b * s
device = cu_seqlen.device
dtype = torch.int32
if s > 1:
# NOTE: this is only an upper bound for the shape in this case...
return (
torch.empty(num_tokens, dtype=dtype, device=device), # chunk_indices
torch.empty(num_tokens, dtype=dtype, device=device), # chunk_offsets
torch.empty(1, num_tokens, dtype=dtype, device=device), # seq_idx_prefill
)
else:
return (
torch.empty(0, dtype=dtype, device=device), # chunk_indices
torch.empty(0, dtype=dtype, device=device), # chunk_offsets
torch.empty(1, 0, dtype=dtype, device=device), # seq_idx_prefill
)
def _flatten_ssm_inputs(
hidden_states: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
dt: torch.Tensor,
) -> Tuple[int, int, int, int, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
b, s, num_heads, head_dim = hidden_states.shape
bs = b * s
hs_flat = hidden_states.reshape(bs, *hidden_states.shape[2:]) # [bs, H, D]
B_flat = B.reshape(bs, *B.shape[2:]) # [bs, G, N]
C_flat = C.reshape(bs, *C.shape[2:]) # [bs, G, N]
dt_flat = dt.reshape(bs, dt.shape[2]) # [bs, H]
return b, s, num_heads, head_dim, bs, hs_flat, B_flat, C_flat, dt_flat
def _run_ssm_prefill(
hs_flat: torch.Tensor,
B_flat: torch.Tensor,
C_flat: torch.Tensor,
dt_flat: torch.Tensor,
A: torch.Tensor,
D: torch.Tensor,
dt_bias: torch.Tensor,
batch_info_host: torch.Tensor,
cu_seqlen: torch.Tensor,
slot_idx: torch.Tensor,
use_initial_states: torch.Tensor,
chunk_indices: torch.Tensor,
chunk_offsets: torch.Tensor,
seq_idx_prefill: torch.Tensor,
ssm_state_cache: torch.Tensor,
time_step_limit: List[float],
chunk_size: int,
out: Optional[torch.Tensor] = None,
) -> Tuple[int, int, int, int]:
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
num_seq = num_prefill + num_decode
num_total_tokens = num_prefill_tokens + num_decode
if num_prefill <= 0:
return num_prefill, num_prefill_tokens, num_total_tokens, num_seq
hs_prefill = hs_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H, D]
B_prefill = B_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, G, N]
C_prefill = C_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, G, N]
dt_prefill = dt_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H]
initial_states = None
if torch.any(use_initial_states[:num_prefill]):
initial_states = torch.where(
use_initial_states[:num_prefill, None, None, None],
ssm_state_cache[slot_idx[:num_prefill]],
0,
)
else:
chunk_indices = None
chunk_offsets = None
varlen_states = mamba_chunk_scan_combined(
hs_prefill,
dt_prefill,
A,
B_prefill,
C_prefill,
chunk_size=chunk_size,
D=D,
z=None,
dt_bias=dt_bias,
initial_states=initial_states,
seq_idx=seq_idx_prefill,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
cu_seqlens=cu_seqlen[: num_prefill + 1],
dt_softplus=True,
dt_limit=(time_step_limit[0], time_step_limit[1]),
return_final_states=False,
return_varlen_states=True,
state_dtype=ssm_state_cache.dtype,
out=out,
)
ssm_state_cache.index_copy_(
0, slot_idx[:num_prefill].long(), varlen_states.to(ssm_state_cache.dtype)
)
return num_prefill, num_prefill_tokens, num_total_tokens, num_seq
def _prepare_ssm_decode_inputs(
hs_flat: torch.Tensor,
B_flat: torch.Tensor,
C_flat: torch.Tensor,
dt_flat: torch.Tensor,
A: torch.Tensor,
D: torch.Tensor,
dt_bias: torch.Tensor,
slot_idx: torch.Tensor,
num_prefill: int,
num_prefill_tokens: int,
num_seq: int,
num_total_tokens: int,
num_heads: int,
head_dim: int,
ssm_state_size: int,
) -> Optional[
Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]
]:
num_decode = num_total_tokens - num_prefill_tokens
if num_decode <= 0:
return None
slot_idx_decode = slot_idx[num_prefill:num_seq]
x_decode = hs_flat[num_prefill_tokens:num_total_tokens] # [nd, H, D]
B_decode = B_flat[num_prefill_tokens:num_total_tokens] # [nd, G, N]
C_decode = C_flat[num_prefill_tokens:num_total_tokens] # [nd, G, N]
dt_decode = dt_flat[num_prefill_tokens:num_total_tokens] # [nd, H]
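# selective_state_update expects per-(head, head_dim) dt/dt_bias/D and a per-(head, head_dim, state) A, so broadcast-expand the flat parameters below.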
dt_hp = dt_decode[:, :, None].expand(-1, num_heads, head_dim)
dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim)
A_full = A[..., None, None].expand(num_heads, head_dim, ssm_state_size)
D_full = D[..., None].expand(num_heads, head_dim)
return slot_idx_decode, x_decode, B_decode, C_decode, dt_hp, dt_bias_hp, A_full, D_full
class BaseBackendSSM(AttentionDescriptor):
@classmethod
def get_attention_layout(cls) -> AttentionLayout:
# Hidden states follow [b, s, n, d]
return "bsnd"
@classmethod
def get_num_qkv_args(cls) -> int:
# torch_ssm_transform signature has 7 node/state arguments
return 7
@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
# Keep source op unchanged (used for uncached pre-export)
return torch.ops.auto_deploy.torch_ssm
@classmethod
def get_standard_metadata_args(cls) -> List[str]:
return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"]
@classmethod
def get_prepare_extra_metadata_info(
cls, any_source_attn_node: Node
) -> Tuple[PrepareMetadataCallable, int, List[Constant]]:
return (
torch.ops.auto_deploy.mamba_ssm_prepare_metadata.default,
3, # chunk_indices, chunk_offsets, seq_idx_prefill
extract_op_args(any_source_attn_node, "chunk_size"),
)
@classmethod
def get_cache_initializers(
cls, source_attn_node: Node, cache_config: KvCacheConfig
) -> ResourceHandlerDict:
# Shapes from fake tensors
hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
B_fake: torch.Tensor = source_attn_node.args[2].meta["val"]
num_heads = hs_fake.shape[-2]
head_dim = hs_fake.shape[-1]
if B_fake.ndim >= 4:
ssm_state_size = B_fake.shape[-1]
else:
ssm_state_size = max(1, B_fake.shape[-1])
# extract ssm_state_dtype from cache_config or hs_fake
ssm_state_dtype = cls.resolve_cache_dtype(cache_config.mamba_ssm_cache_dtype, hs_fake.dtype)
return {
"ssm_state_cache": StateResourceHandler(
num_heads, head_dim, ssm_state_size, dtype=ssm_state_dtype
)
}
@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
time_step_limit, chunk_size = extract_op_args(
source_attn_node, "time_step_limit", "chunk_size"
)
return [time_step_limit, chunk_size]
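
With the shared prefill/decode plumbing factored into this module, a concrete SSM backend only has to register itself and point at its cached op (plus any resource-handler overrides, as the flashinfer backend above does). An illustrative sketch only; "example_ssm"/ExampleBackendSSM are hypothetical, and AttentionRegistry/MHACallable come from attention_interface as in the real backend modules:

from ..attention_interface import AttentionRegistry, MHACallable

@AttentionRegistry.register("example_ssm")  # hypothetical registry key
class ExampleBackendSSM(BaseBackendSSM):
    @classmethod
    def get_cached_attention_op(cls) -> MHACallable:
        # Any custom op with the *_cached_ssm signature can be returned here.
        return torch.ops.auto_deploy.triton_cached_ssm.default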

View File

@@ -13,92 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
from typing import List
import torch
from torch._ops import OpOverloadPacket
from torch.fx import Node
# Triton kernels
from tensorrt_llm._torch.modules.mamba.mamba2_metadata import cu_seqlens_to_chunk_indices_offsets
from tensorrt_llm._torch.modules.mamba.selective_state_update import selective_state_update
from tensorrt_llm._torch.modules.mamba.ssd_combined import mamba_chunk_scan_combined
from .....llmapi.llm_args import KvCacheConfig
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
Constant,
MHACallable,
PrepareMetadataCallable,
ResourceHandlerDict,
StateResourceHandler,
from ..attention_interface import AttentionRegistry, MHACallable
from .mamba_backend_common import (
BaseBackendSSM,
_flatten_ssm_inputs,
_prepare_ssm_decode_inputs,
_run_ssm_prefill,
)
@torch.library.custom_op("auto_deploy::triton_ssm_prepare_metadata", mutates_args=())
def _triton_ssm_prepare_metadata(
# INPUTS
position_ids: torch.Tensor,
batch_info_host: torch.Tensor,
seq_len: torch.Tensor,
cu_seqlen: torch.Tensor,
# EXTRA METADATA PROVIDED BY THE DESCRIPTOR
chunk_size: int,
) -> List[torch.Tensor]:
"""Prepare metadata for cached SSM transform.
Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized).
"""
device = cu_seqlen.device
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
if num_prefill > 0:
chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(
cu_seqlen[: num_prefill + 1], chunk_size
)
seq_idx_prefill = torch.repeat_interleave(
torch.arange(num_prefill, device=device, dtype=torch.int32), seq_len[:num_prefill]
).view(1, -1)
else:
chunk_indices = torch.empty(0, dtype=torch.int32, device=device)
chunk_offsets = torch.empty(0, dtype=torch.int32, device=device)
seq_idx_prefill = torch.empty(1, 0, dtype=torch.int32, device=device)
return (chunk_indices, chunk_offsets, seq_idx_prefill)
@_triton_ssm_prepare_metadata.register_fake
def _triton_ssm_prepare_metadata_fake(
# INPUTS
position_ids: torch.Tensor,
batch_info_host: torch.Tensor,
seq_len: torch.Tensor,
cu_seqlen: torch.Tensor,
# EXTRA METADATA PROVIDED BY THE DESCRIPTOR
chunk_size: int,
):
b, s = position_ids.shape[:2]
num_tokens = b * s
device = cu_seqlen.device
dtype = torch.int32
if s > 1:
# NOTE: this is only an upper bound for the shape in this case...
return (
torch.empty(num_tokens, dtype=dtype, device=device), # chunk_indices
torch.empty(num_tokens, dtype=dtype, device=device), # chunk_offsets
torch.empty(1, num_tokens, dtype=dtype, device=device), # seq_idx_prefill
)
else:
return (
torch.empty(0, dtype=dtype, device=device), # chunk_indices
torch.empty(0, dtype=dtype, device=device), # chunk_offsets
torch.empty(1, 0, dtype=dtype, device=device), # seq_idx_prefill
)
@torch.library.custom_op("auto_deploy::triton_cached_ssm", mutates_args={})
def _triton_cached_ssm(
# INPUTS (dense but may be flattened across sequences)
@@ -124,26 +53,10 @@ def _triton_cached_ssm(
time_step_limit: List[float],
chunk_size: int,
) -> torch.Tensor:
"""Flattened cached SSM transform op that respects slot-indexed state caches.
Split mixed batches into prefill (seq_len>1) and decode (seq_len==1):
- Prefill: run one varlen combined scan over concatenated prefill tokens and update final states per slot.
- Decode: batch single-token updates with selective_state_update and update states per slot.
"""
b, s, num_heads, head_dim = hidden_states.shape
# Flatten tokens for indexing/scatter
bs = b * s
hs_flat = hidden_states.reshape(bs, *hidden_states.shape[2:]) # [bs, H, D]
B_flat = B.reshape(bs, *B.shape[2:]) # [bs, G, N]
C_flat = C.reshape(bs, *C.shape[2:]) # [bs, G, N]
dt_flat = dt.reshape(bs, dt.shape[2]) # [bs, H]
b, s, num_heads, head_dim, bs, hs_flat, B_flat, C_flat, dt_flat = _flatten_ssm_inputs(
hidden_states, B, C, dt
)
ssm_state_size = B.shape[3]
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
num_seq = num_prefill + num_decode
num_total_tokens = num_prefill_tokens + num_decode
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
preallocated_ssm_out = torch.empty(
@@ -151,68 +64,63 @@ def _triton_cached_ssm(
dtype=hidden_states.dtype,
device=hidden_states.device,
)
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
num_seq = num_prefill + num_decode
num_total_tokens = num_prefill_tokens + num_decode
preallocated_ssm_out_p = preallocated_ssm_out[:num_prefill_tokens]
preallocated_ssm_out_d = preallocated_ssm_out[num_prefill_tokens:num_total_tokens]
# Prefill: concatenate tokens at the front and run combined scan
if num_prefill > 0:
hs_prefill = hs_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H, D]
B_prefill = B_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, G, N]
C_prefill = C_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, G, N]
dt_prefill = dt_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H]
num_prefill, num_prefill_tokens, num_total_tokens, num_seq = _run_ssm_prefill(
hs_flat,
B_flat,
C_flat,
dt_flat,
A,
D,
dt_bias,
batch_info_host,
cu_seqlen,
slot_idx,
use_initial_states,
chunk_indices,
chunk_offsets,
seq_idx_prefill,
ssm_state_cache,
time_step_limit,
chunk_size,
preallocated_ssm_out_p.unsqueeze(0),
)
initial_states = None
if torch.any(use_initial_states[:num_prefill]):
initial_states = torch.where(
use_initial_states[:num_prefill, None, None, None],
ssm_state_cache[slot_idx[:num_prefill]],
0,
)
else:
chunk_indices = None
chunk_offsets = None
varlen_states = mamba_chunk_scan_combined(
hs_prefill,
dt_prefill,
A,
B_prefill,
C_prefill,
chunk_size=chunk_size,
D=D,
z=None,
dt_bias=dt_bias,
initial_states=initial_states,
seq_idx=seq_idx_prefill,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
cu_seqlens=cu_seqlen[: num_prefill + 1],
dt_softplus=True,
dt_limit=(time_step_limit[0], time_step_limit[1]),
return_final_states=False,
return_varlen_states=True,
out=preallocated_ssm_out_p.unsqueeze(0),
state_dtype=ssm_state_cache.dtype,
)
ssm_state_cache.index_copy_(
0, slot_idx[:num_prefill].long(), varlen_states.to(ssm_state_cache.dtype)
)
# Decode: batch single-token updates via selective_state_update
if num_decode > 0:
slot_idx_decode = slot_idx[num_prefill:num_seq]
x_decode = hs_flat[num_prefill_tokens:num_total_tokens] # [nd, H, D]
B_decode = B_flat[num_prefill_tokens:num_total_tokens] # [nd, G, N]
C_decode = C_flat[num_prefill_tokens:num_total_tokens] # [nd, G, N]
dt_decode = dt_flat[num_prefill_tokens:num_total_tokens] # [nd, H]
dt_hp = dt_decode[:, :, None].expand(-1, num_heads, head_dim)
dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim)
A_full = A[..., None, None].expand(num_heads, head_dim, ssm_state_size)
D_full = D[..., None].expand(num_heads, head_dim)
num_decode = num_total_tokens - num_prefill_tokens
decode_inputs = _prepare_ssm_decode_inputs(
hs_flat,
B_flat,
C_flat,
dt_flat,
A,
D,
dt_bias,
slot_idx,
num_prefill,
num_prefill_tokens,
num_seq,
num_total_tokens,
num_heads,
head_dim,
ssm_state_size,
)
if decode_inputs is not None:
(
slot_idx_decode,
x_decode,
B_decode,
C_decode,
dt_hp,
dt_bias_hp,
A_full,
D_full,
) = decode_inputs
selective_state_update(
ssm_state_cache,
x_decode,
@@ -228,7 +136,6 @@ def _triton_cached_ssm(
out=preallocated_ssm_out_d,
)
# Return the preallocated output reshaped to original dimensions
if num_total_tokens > 0:
return (
preallocated_ssm_out[:num_total_tokens]
@@ -273,68 +180,7 @@ def _triton_cached_ssm_fake(
@AttentionRegistry.register("triton_ssm")
class TritonBackendSSM(AttentionDescriptor):
@classmethod
def get_attention_layout(cls) -> AttentionLayout:
# Hidden states follow [b, s, n, d]
return "bsnd"
@classmethod
def get_num_qkv_args(cls) -> int:
# torch_ssm_transform signature has 7 node/state arguments
return 7
@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
# Keep source op unchanged (used for uncached pre-export)
return torch.ops.auto_deploy.torch_ssm
class TritonBackendSSM(BaseBackendSSM):
@classmethod
def get_cached_attention_op(cls) -> MHACallable:
return torch.ops.auto_deploy.triton_cached_ssm.default
@classmethod
def get_standard_metadata_args(cls) -> List[str]:
return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"]
@classmethod
def get_prepare_extra_metadata_info(
cls, any_source_attn_node: Node
) -> Tuple[PrepareMetadataCallable, int, List[Constant]]:
return (
torch.ops.auto_deploy.triton_ssm_prepare_metadata.default,
3, # chunk_indices, chunk_offsets, seq_idx_prefill
extract_op_args(any_source_attn_node, "chunk_size"),
)
@classmethod
def get_cache_initializers(
cls, source_attn_node: Node, cache_config: KvCacheConfig
) -> ResourceHandlerDict:
# Shapes from fake tensors
hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
B_fake: torch.Tensor = source_attn_node.args[2].meta["val"]
num_heads = hs_fake.shape[-2]
head_dim = hs_fake.shape[-1]
if B_fake.ndim >= 4:
ssm_state_size = B_fake.shape[-1]
else:
ssm_state_size = max(1, B_fake.shape[-1])
# extract ssm_state_dtype from cache_config or hs_fake
ssm_state_dtype = cls.resolve_cache_dtype(cache_config.mamba_ssm_cache_dtype, hs_fake.dtype)
return {
"ssm_state_cache": StateResourceHandler(
num_heads, head_dim, ssm_state_size, dtype=ssm_state_dtype
)
}
@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
time_step_limit, chunk_size = extract_op_args(
source_attn_node, "time_step_limit", "chunk_size"
)
return [time_step_limit, chunk_size]

View File

@@ -143,8 +143,14 @@ class TestNemotronH(LlmapiAccuracyTestHarness):
@pytest.mark.skip_less_device_memory(32000)
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
def test_auto_dtype(self, enable_chunked_prefill):
@pytest.mark.parametrize("ssm_backend", ["triton_ssm", "flashinfer_ssm"])
def test_auto_dtype(self, enable_chunked_prefill, ssm_backend):
kwargs = self.get_default_kwargs(enable_chunked_prefill)
kwargs.setdefault("transforms", {})
insert_ssm_cfg = {"backend": ssm_backend}
if ssm_backend == "flashinfer_ssm":
insert_ssm_cfg["cache_config"] = {"mamba_dtype": "bfloat16"}
kwargs["transforms"]["insert_cached_ssm_attention"] = insert_ssm_cfg
sampling_params = self.get_default_sampling_params()
with AutoDeployLLM(model=self.MODEL_PATH,
tokenizer=self.MODEL_PATH,

View File

@@ -434,8 +434,9 @@ l0_h100:
- unittest/_torch/auto_deploy/unit/singlegpu
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-1]
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[True-1]
- accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[False]
- accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[True]
- accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[triton_ssm-False]
- accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[flashinfer_ssm-False]
- accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[triton_ssm-True]
- accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[1]
- accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16[1]
- examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[draft_target]

View File

@@ -0,0 +1,103 @@
import pytest
import torch
import tensorrt_llm._torch.auto_deploy # noqa: F401
from tests.unittest._torch.auto_deploy.unit.singlegpu.custom_ops.test_triton_mamba_cached_op import (
_random_params,
)
@pytest.fixture
def mamba_env():
device = "cuda"
dtype = torch.bfloat16
atol = 1e-3
rtol = 1e-3
torch.manual_seed(42)
torch.cuda.empty_cache()
return {"device": device, "dtype": dtype, "atol": atol, "rtol": rtol}
def test_flashinfer_decode_matches_triton(mamba_env):
device = mamba_env["device"]
dtype = mamba_env["dtype"]
atol = mamba_env["atol"]
rtol = mamba_env["rtol"]
batch, seq = 2, 1
num_heads, head_dim = 2, 64
n_groups, ssm_state_size = 2, 64
(hidden_states, A, B, C, D, dt, dt_bias, time_step_limit, chunk_size) = _random_params(
device, dtype, batch, seq, num_heads, head_dim, n_groups, ssm_state_size
)
max_batch_size = 4
slot_idx = torch.tensor([0, 2], device=device, dtype=torch.int32)
ssm_state_cache_triton = torch.randn(
max_batch_size, num_heads, head_dim, ssm_state_size, device=device, dtype=dtype
)
ssm_state_cache_flashinfer = ssm_state_cache_triton.clone()
# batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
batch_info_host = torch.tensor([0, 0, batch], device=device, dtype=torch.int32)
cu_seqlen = torch.zeros(batch + 1, device=device, dtype=torch.int32)
use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool)
y_triton = torch.ops.auto_deploy.triton_cached_ssm(
hidden_states,
A,
B,
C,
D,
dt,
dt_bias,
# STANDARD METADATA
batch_info_host,
cu_seqlen,
slot_idx,
use_initial_states,
# EXTRA METADATA
None, # chunk indices
None, # chunk offsets
None, # seq_idx_prefill
# CACHES
ssm_state_cache_triton,
# CONSTANTS
time_step_limit,
chunk_size,
)
y_flashinfer = torch.ops.auto_deploy.flashinfer_cached_ssm(
hidden_states,
A,
B,
C,
D,
dt,
dt_bias,
# STANDARD METADATA
batch_info_host,
cu_seqlen,
slot_idx,
use_initial_states,
# EXTRA METADATA
None, # chunk indices
None, # chunk offsets
None, # seq_idx_prefill
# CACHES
ssm_state_cache_flashinfer,
# CONSTANTS
time_step_limit,
chunk_size,
)
assert y_triton.shape == hidden_states.shape
assert y_flashinfer.shape == hidden_states.shape
assert torch.isfinite(y_flashinfer).all()
assert torch.allclose(y_flashinfer, y_triton.to(y_flashinfer.dtype), atol=atol, rtol=rtol)
after_triton = ssm_state_cache_triton.index_select(0, slot_idx)
after_flashinfer = ssm_state_cache_flashinfer.index_select(0, slot_idx)
assert torch.allclose(
after_flashinfer.to(after_triton.dtype), after_triton, atol=atol, rtol=rtol
)
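
A further consistency check one could layer on top (not part of the committed test): slots never referenced by slot_idx should be left untouched by both ops, so they must still match exactly.

untouched = torch.tensor([1, 3], device=device)  # slots outside slot_idx
assert torch.equal(
    ssm_state_cache_triton.index_select(0, untouched),
    ssm_state_cache_flashinfer.index_select(0, untouched),
)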

View File

@@ -991,14 +991,5 @@ def test_stack_nvfp4_moe_weights_transform_relu2(hidden_size, intermediate_size)
# Run the transformed graph
transformed_output = gm(x, selected_experts, routing_weights)
# Get the registered parameters after transform
fc1_act_scale = getattr(gm, "nvfp4_moe_w3_w1_input_scale_stacked_0", None)
fc1_alpha = getattr(gm, "nvfp4_moe_w1_alpha_stacked_0", None)
if fc1_act_scale is not None:
print(f"fc1_act_scale (after transform): {fc1_act_scale}, shape: {fc1_act_scale.shape}")
if fc1_alpha is not None:
print(f"fc1_alpha (after transform): {fc1_alpha}, shape: {fc1_alpha.shape}")
# Should be close for FP4 quantization (gated MLP may have slightly larger diff due to alpha handling)
tol = 1e-3
torch.testing.assert_close(ref_output, transformed_output, rtol=tol, atol=tol)

View File

@@ -184,6 +184,7 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
"transforms": {
"insert_cached_attention": {"backend": "flashinfer"},
"compile_model": {"backend": "torch-simple"},
"insert_cached_ssm_attention": {"backend": "triton_ssm"},
},
},
),
@@ -192,6 +193,7 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
{
"transforms": {
"multi_stream_moe": {"stage": "compile", "enabled": True},
"insert_cached_ssm_attention": {"backend": "triton_ssm"},
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/9878
"compile_model": {"backend": "torch-cudagraph"},
},