diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py index eb6ed39fef..25439616a9 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py @@ -113,7 +113,7 @@ class _FlashInferPlanner: self.prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( self.workspace_buffer, self.kv_layout, - backend="fa2", + backend="fa2" if torch.cuda.get_device_capability(0) == (9, 0) else "auto", ) self.decode_wrapper = self._init_decode_wrapper() diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py new file mode 100644 index 0000000000..5779ad36d2 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py @@ -0,0 +1,246 @@ +from typing import List + +import torch +from torch.fx import Node + +from .....llmapi.llm_args import KvCacheConfig +from ..attention_interface import ( + AttentionRegistry, + MHACallable, + ResourceHandler, + ResourceHandlerDict, + SequenceInfo, +) +from .mamba_backend_common import ( + BaseBackendSSM, + _flatten_ssm_inputs, + _prepare_ssm_decode_inputs, + _run_ssm_prefill, +) + + +@torch.library.custom_op("auto_deploy::flashinfer_cached_ssm", mutates_args={}) +def _flashinfer_cached_ssm( + # INPUTS (dense but may be flattened across sequences) + hidden_states: torch.Tensor, # [b, s, num_heads, head_dim] + A: torch.Tensor, # [num_heads] + B: torch.Tensor, # [b, s, n_groups, ssm_state_size] + C: torch.Tensor, # [b, s, n_groups, ssm_state_size] + D: torch.Tensor, # [num_heads] + dt: torch.Tensor, # [b, s, num_heads] + dt_bias: torch.Tensor, # [num_heads] + # STANDARD METADATA + batch_info_host: torch.Tensor, + cu_seqlen: torch.Tensor, + slot_idx: torch.Tensor, + use_initial_states: torch.Tensor, + # EXTRA METADATA + chunk_indices: torch.Tensor, # [num_logical_chunks] + chunk_offsets: torch.Tensor, # [num_logical_chunks] + seq_idx_prefill: torch.Tensor, # [1, num_prefill_tokens] + # CACHES + ssm_state_cache: torch.Tensor, # [max_batch_size, num_heads, head_dim, ssm_state_size] + # CONSTANTS + time_step_limit: List[float], + chunk_size: int, +) -> torch.Tensor: + b, s, num_heads, head_dim, bs, hs_flat, B_flat, C_flat, dt_flat = _flatten_ssm_inputs( + hidden_states, B, C, dt + ) + ssm_state_size = B.shape[3] + num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() + num_seq = num_prefill + num_decode + num_total_tokens = num_prefill_tokens + num_decode + # Preallocate output tensor to avoid memcpy cost for merging prefill + # and decode outputs + preallocated_ssm_out = torch.empty( + [bs, num_heads, head_dim], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + preallocated_ssm_out_p = preallocated_ssm_out[:num_prefill_tokens] + + num_prefill, num_prefill_tokens, num_total_tokens, num_seq = _run_ssm_prefill( + hs_flat, + B_flat, + C_flat, + dt_flat, + A, + D, + dt_bias, + batch_info_host, + cu_seqlen, + slot_idx, + use_initial_states, + chunk_indices, + chunk_offsets, + seq_idx_prefill, + ssm_state_cache, + time_step_limit, + chunk_size, + preallocated_ssm_out_p.unsqueeze(0), + ) + + num_decode = num_total_tokens - num_prefill_tokens + decode_inputs = _prepare_ssm_decode_inputs( + hs_flat, + B_flat, + C_flat, + dt_flat, + A, + D, + dt_bias, + slot_idx, + num_prefill, + num_prefill_tokens, + num_seq, + 
num_total_tokens, + num_heads, + head_dim, + ssm_state_size, + ) + + y_decode = None + if decode_inputs is not None: + ( + slot_idx_decode, + x_decode, + B_decode, + C_decode, + dt_hp, + dt_bias_hp, + A_full, + D_full, + ) = decode_inputs + + import flashinfer + + slot_idx_decode_i32 = slot_idx_decode.to(torch.int32) + y_decode = flashinfer.mamba.selective_state_update( + ssm_state_cache, + x_decode, + dt_hp, + A_full, + B_decode, + C_decode, + D=D_full, + z=None, + dt_bias=dt_bias_hp, + dt_softplus=True, + state_batch_indices=slot_idx_decode_i32, + ) + preallocated_ssm_out[num_prefill_tokens:num_total_tokens].copy_(y_decode) + if num_total_tokens > 0: + return ( + preallocated_ssm_out[:num_total_tokens] + .view(b, s, num_heads, head_dim) + .to(hidden_states.dtype) + ) + else: + return torch.empty_like(hidden_states) + + +@_flashinfer_cached_ssm.register_fake +def _flashinfer_cached_ssm_fake( + # INPUTS (dense but may be flattened across sequences) + hidden_states: torch.Tensor, # [b, s, num_heads, head_dim] + A: torch.Tensor, # [num_heads] + B: torch.Tensor, # [b, s, n_groups, ssm_state_size] + C: torch.Tensor, # [b, s, n_groups, ssm_state_size] + D: torch.Tensor, # [num_heads] + dt: torch.Tensor, # [b, s, num_heads] + dt_bias: torch.Tensor, # [num_heads] + # STANDARD METADATA + batch_info_host: torch.Tensor, + cu_seqlen: torch.Tensor, + slot_idx: torch.Tensor, + use_initial_states: torch.Tensor, + # EXTRA METADATA + chunk_indices: torch.Tensor, # [num_logical_chunks] + chunk_offsets: torch.Tensor, # [num_logical_chunks] + seq_idx_prefill: torch.Tensor, # [1, num_prefill_tokens] + # CACHES + ssm_state_cache: torch.Tensor, # [max_batch_size, num_heads, head_dim, ssm_state_size] + # CONSTANTS + time_step_limit: List[float], + chunk_size: int, +): + # Return a correctly-shaped tensor for tracing with fake tensors + return torch.empty_like( + hidden_states, + memory_format=torch.contiguous_format, + dtype=hidden_states.dtype, + ) + + +# Flashinfer's selective_state_update kernel only supports these head dimensions +FLASHINFER_SUPPORTED_HEAD_DIMS = [64, 128] + + +class FlashInferStateResourceHandler(ResourceHandler): + """Handler for flashinfer SSM state resources. + + Unlike the default StateResourceHandler which uses byte-level pooling (resulting + in non-contiguous strided views), this handler allocates a separate contiguous + buffer. This is required because flashinfer's selective_state_update kernel + requires the entire state tensor to be contiguous. + """ + + def __init__(self, *state_shape: int, dtype: torch.dtype) -> None: + self.state_shape = state_shape + self.dtype = dtype + + def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor: + """Allocate a contiguous state buffer for flashinfer.""" + return torch.empty( + sequence_info.max_num_state_slots, + *self.state_shape, + device=sequence_info.device, + dtype=self.dtype, + ) + + +@AttentionRegistry.register("flashinfer_ssm") +class FlashinferBackendSSM(BaseBackendSSM): + @classmethod + def get_cached_attention_op(cls) -> MHACallable: + return torch.ops.auto_deploy.flashinfer_cached_ssm.default + + @classmethod + def get_cache_initializers( + cls, source_attn_node: Node, cache_config: KvCacheConfig + ) -> ResourceHandlerDict: + """Get cache initializers using FlashInferStateResourceHandler. + + We use a custom handler that allocates contiguous buffers directly, + instead of the default StateResourceHandler which creates non-contiguous + views from a shared byte buffer. 
This is required because flashinfer's + selective_state_update kernel requires contiguous state tensors. + """ + # Shapes from fake tensors + hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"] + B_fake: torch.Tensor = source_attn_node.args[2].meta["val"] + + num_heads = hs_fake.shape[-2] + head_dim = hs_fake.shape[-1] + + # Validate head_dim is supported by flashinfer + if head_dim not in FLASHINFER_SUPPORTED_HEAD_DIMS: + raise ValueError( + f"Flashinfer SSM backend only supports head_dim in {FLASHINFER_SUPPORTED_HEAD_DIMS}, " + f"but got head_dim={head_dim}. Consider using 'triton_ssm' backend instead." + ) + + if B_fake.ndim >= 4: + ssm_state_size = B_fake.shape[-1] + else: + ssm_state_size = max(1, B_fake.shape[-1]) + + # Extract ssm_state_dtype from cache_config or hs_fake + ssm_state_dtype = cls.resolve_cache_dtype(cache_config.mamba_ssm_cache_dtype, hs_fake.dtype) + + return { + "ssm_state_cache": FlashInferStateResourceHandler( + num_heads, head_dim, ssm_state_size, dtype=ssm_state_dtype + ) + } diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/mamba_backend_common.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/mamba_backend_common.py new file mode 100644 index 0000000000..1ab0849d71 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/mamba_backend_common.py @@ -0,0 +1,291 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple + +import torch +from torch._ops import OpOverloadPacket +from torch.fx import Node + +from tensorrt_llm._torch.modules.mamba.mamba2_metadata import cu_seqlens_to_chunk_indices_offsets +from tensorrt_llm._torch.modules.mamba.ssd_combined import mamba_chunk_scan_combined + +from .....llmapi.llm_args import KvCacheConfig +from ...utils.node_utils import extract_op_args +from ..attention_interface import ( + AttentionDescriptor, + AttentionLayout, + Constant, + PrepareMetadataCallable, + ResourceHandlerDict, + StateResourceHandler, +) + + +@torch.library.custom_op("auto_deploy::mamba_ssm_prepare_metadata", mutates_args=()) +def _mamba_ssm_prepare_metadata( + # INPUTS + position_ids: torch.Tensor, + batch_info_host: torch.Tensor, + seq_len: torch.Tensor, + cu_seqlen: torch.Tensor, + # EXTRA METADATA PROVIDED BY THE DESCRIPTOR + chunk_size: int, +) -> List[torch.Tensor]: + """Prepare metadata for cached SSM transform. + + Returns a tuple of (chunk_indices, chunk_offsets, seq_idx_prefill). 
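+
+    chunk_indices and chunk_offsets are computed by cu_seqlens_to_chunk_indices_offsets
+    over the prefill slice of cu_seqlen and describe how the concatenated prefill tokens
+    map onto logical chunks of size chunk_size; seq_idx_prefill assigns each prefill
+    token its sequence index. All three come back as empty tensors when the batch
+    contains no prefill sequences.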
+ """ + device = cu_seqlen.device + num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() + + if num_prefill > 0: + chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets( + cu_seqlen[: num_prefill + 1], chunk_size + ) + seq_idx_prefill = torch.repeat_interleave( + torch.arange(num_prefill, device=device, dtype=torch.int32), seq_len[:num_prefill] + ).view(1, -1) + else: + chunk_indices = torch.empty(0, dtype=torch.int32, device=device) + chunk_offsets = torch.empty(0, dtype=torch.int32, device=device) + seq_idx_prefill = torch.empty(1, 0, dtype=torch.int32, device=device) + + return (chunk_indices, chunk_offsets, seq_idx_prefill) + + +@_mamba_ssm_prepare_metadata.register_fake +def _mamba_ssm_prepare_metadata_fake( + # INPUTS + position_ids: torch.Tensor, + batch_info_host: torch.Tensor, + seq_len: torch.Tensor, + cu_seqlen: torch.Tensor, + # EXTRA METADATA PROVIDED BY THE DESCRIPTOR + chunk_size: int, +): + b, s = position_ids.shape[:2] + num_tokens = b * s + device = cu_seqlen.device + dtype = torch.int32 + if s > 1: + # NOTE: this is only an upper bound for the shape in this case... + return ( + torch.empty(num_tokens, dtype=dtype, device=device), # chunk_indices + torch.empty(num_tokens, dtype=dtype, device=device), # chunk_offsets + torch.empty(1, num_tokens, dtype=dtype, device=device), # seq_idx_prefill + ) + else: + return ( + torch.empty(0, dtype=dtype, device=device), # chunk_indices + torch.empty(0, dtype=dtype, device=device), # chunk_offsets + torch.empty(1, 0, dtype=dtype, device=device), # seq_idx_prefill + ) + + +def _flatten_ssm_inputs( + hidden_states: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + dt: torch.Tensor, +) -> Tuple[int, int, int, int, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + b, s, num_heads, head_dim = hidden_states.shape + bs = b * s + hs_flat = hidden_states.reshape(bs, *hidden_states.shape[2:]) # [bs, H, D] + B_flat = B.reshape(bs, *B.shape[2:]) # [bs, G, N] + C_flat = C.reshape(bs, *C.shape[2:]) # [bs, G, N] + dt_flat = dt.reshape(bs, dt.shape[2]) # [bs, H] + return b, s, num_heads, head_dim, bs, hs_flat, B_flat, C_flat, dt_flat + + +def _run_ssm_prefill( + hs_flat: torch.Tensor, + B_flat: torch.Tensor, + C_flat: torch.Tensor, + dt_flat: torch.Tensor, + A: torch.Tensor, + D: torch.Tensor, + dt_bias: torch.Tensor, + batch_info_host: torch.Tensor, + cu_seqlen: torch.Tensor, + slot_idx: torch.Tensor, + use_initial_states: torch.Tensor, + chunk_indices: torch.Tensor, + chunk_offsets: torch.Tensor, + seq_idx_prefill: torch.Tensor, + ssm_state_cache: torch.Tensor, + time_step_limit: List[float], + chunk_size: int, + out: Optional[torch.Tensor] = None, +) -> Tuple[Optional[torch.Tensor], int, int, int, int]: + num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() + num_seq = num_prefill + num_decode + num_total_tokens = num_prefill_tokens + num_decode + + if num_prefill <= 0: + return num_prefill, num_prefill_tokens, num_total_tokens, num_seq + + hs_prefill = hs_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H, D] + B_prefill = B_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, G, N] + C_prefill = C_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, G, N] + dt_prefill = dt_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H] + + initial_states = None + if torch.any(use_initial_states[:num_prefill]): + initial_states = torch.where( + use_initial_states[:num_prefill, None, None, None], + ssm_state_cache[slot_idx[:num_prefill]], + 0, + ) + else: + chunk_indices = None + chunk_offsets = 
None + + varlen_states = mamba_chunk_scan_combined( + hs_prefill, + dt_prefill, + A, + B_prefill, + C_prefill, + chunk_size=chunk_size, + D=D, + z=None, + dt_bias=dt_bias, + initial_states=initial_states, + seq_idx=seq_idx_prefill, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + cu_seqlens=cu_seqlen[: num_prefill + 1], + dt_softplus=True, + dt_limit=(time_step_limit[0], time_step_limit[1]), + return_final_states=False, + return_varlen_states=True, + state_dtype=ssm_state_cache.dtype, + out=out, + ) + + ssm_state_cache.index_copy_( + 0, slot_idx[:num_prefill].long(), varlen_states.to(ssm_state_cache.dtype) + ) + return num_prefill, num_prefill_tokens, num_total_tokens, num_seq + + +def _prepare_ssm_decode_inputs( + hs_flat: torch.Tensor, + B_flat: torch.Tensor, + C_flat: torch.Tensor, + dt_flat: torch.Tensor, + A: torch.Tensor, + D: torch.Tensor, + dt_bias: torch.Tensor, + slot_idx: torch.Tensor, + num_prefill: int, + num_prefill_tokens: int, + num_seq: int, + num_total_tokens: int, + num_heads: int, + head_dim: int, + ssm_state_size: int, +) -> Optional[ + Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ] +]: + num_decode = num_total_tokens - num_prefill_tokens + if num_decode <= 0: + return None + + slot_idx_decode = slot_idx[num_prefill:num_seq] + x_decode = hs_flat[num_prefill_tokens:num_total_tokens] # [nd, H, D] + B_decode = B_flat[num_prefill_tokens:num_total_tokens] # [nd, G, N] + C_decode = C_flat[num_prefill_tokens:num_total_tokens] # [nd, G, N] + dt_decode = dt_flat[num_prefill_tokens:num_total_tokens] # [nd, H] + + dt_hp = dt_decode[:, :, None].expand(-1, num_heads, head_dim) + dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim) + A_full = A[..., None, None].expand(num_heads, head_dim, ssm_state_size) + D_full = D[..., None].expand(num_heads, head_dim) + + return slot_idx_decode, x_decode, B_decode, C_decode, dt_hp, dt_bias_hp, A_full, D_full + + +class BaseBackendSSM(AttentionDescriptor): + @classmethod + def get_attention_layout(cls) -> AttentionLayout: + # Hidden states follow [b, s, n, d] + return "bsnd" + + @classmethod + def get_num_qkv_args(cls) -> int: + # torch_ssm_transform signature has 7 node/state arguments + return 7 + + @classmethod + def get_source_attention_op(cls) -> OpOverloadPacket: + # Keep source op unchanged (used for uncached pre-export) + return torch.ops.auto_deploy.torch_ssm + + @classmethod + def get_standard_metadata_args(cls) -> List[str]: + return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"] + + @classmethod + def get_prepare_extra_metadata_info( + cls, any_source_attn_node: Node + ) -> Tuple[PrepareMetadataCallable, int, List[Constant]]: + return ( + torch.ops.auto_deploy.mamba_ssm_prepare_metadata.default, + 3, # chunk_indices, chunk_offsets, seq_idx_prefill + extract_op_args(any_source_attn_node, "chunk_size"), + ) + + @classmethod + def get_cache_initializers( + cls, source_attn_node: Node, cache_config: KvCacheConfig + ) -> ResourceHandlerDict: + # Shapes from fake tensors + hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"] + B_fake: torch.Tensor = source_attn_node.args[2].meta["val"] + + num_heads = hs_fake.shape[-2] + head_dim = hs_fake.shape[-1] + + if B_fake.ndim >= 4: + ssm_state_size = B_fake.shape[-1] + else: + ssm_state_size = max(1, B_fake.shape[-1]) + + # extract ssm_state_dtype from cache_config or hs_fake + ssm_state_dtype = cls.resolve_cache_dtype(cache_config.mamba_ssm_cache_dtype, 
hs_fake.dtype) + + return { + "ssm_state_cache": StateResourceHandler( + num_heads, head_dim, ssm_state_size, dtype=ssm_state_dtype + ) + } + + @classmethod + def get_constants(cls, source_attn_node: Node) -> List[Constant]: + time_step_limit, chunk_size = extract_op_args( + source_attn_node, "time_step_limit", "chunk_size" + ) + return [time_step_limit, chunk_size] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py index 5f6ecdee9a..35937d50cf 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py @@ -13,92 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple +from typing import List import torch -from torch._ops import OpOverloadPacket -from torch.fx import Node -# Triton kernels -from tensorrt_llm._torch.modules.mamba.mamba2_metadata import cu_seqlens_to_chunk_indices_offsets from tensorrt_llm._torch.modules.mamba.selective_state_update import selective_state_update -from tensorrt_llm._torch.modules.mamba.ssd_combined import mamba_chunk_scan_combined -from .....llmapi.llm_args import KvCacheConfig -from ...utils.node_utils import extract_op_args -from ..attention_interface import ( - AttentionDescriptor, - AttentionLayout, - AttentionRegistry, - Constant, - MHACallable, - PrepareMetadataCallable, - ResourceHandlerDict, - StateResourceHandler, +from ..attention_interface import AttentionRegistry, MHACallable +from .mamba_backend_common import ( + BaseBackendSSM, + _flatten_ssm_inputs, + _prepare_ssm_decode_inputs, + _run_ssm_prefill, ) -@torch.library.custom_op("auto_deploy::triton_ssm_prepare_metadata", mutates_args=()) -def _triton_ssm_prepare_metadata( - # INPUTS - position_ids: torch.Tensor, - batch_info_host: torch.Tensor, - seq_len: torch.Tensor, - cu_seqlen: torch.Tensor, - # EXTRA METADATA PROVIDED BY THE DESCRIPTOR - chunk_size: int, -) -> List[torch.Tensor]: - """Prepare metadata for cached SSM transform. - - Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized). - """ - device = cu_seqlen.device - num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() - - if num_prefill > 0: - chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets( - cu_seqlen[: num_prefill + 1], chunk_size - ) - seq_idx_prefill = torch.repeat_interleave( - torch.arange(num_prefill, device=device, dtype=torch.int32), seq_len[:num_prefill] - ).view(1, -1) - else: - chunk_indices = torch.empty(0, dtype=torch.int32, device=device) - chunk_offsets = torch.empty(0, dtype=torch.int32, device=device) - seq_idx_prefill = torch.empty(1, 0, dtype=torch.int32, device=device) - - return (chunk_indices, chunk_offsets, seq_idx_prefill) - - -@_triton_ssm_prepare_metadata.register_fake -def _triton_ssm_prepare_metadata_fake( - # INPUTS - position_ids: torch.Tensor, - batch_info_host: torch.Tensor, - seq_len: torch.Tensor, - cu_seqlen: torch.Tensor, - # EXTRA METADATA PROVIDED BY THE DESCRIPTOR - chunk_size: int, -): - b, s = position_ids.shape[:2] - num_tokens = b * s - device = cu_seqlen.device - dtype = torch.int32 - if s > 1: - # NOTE: this is only an upper bound for the shape in this case... 
- return ( - torch.empty(num_tokens, dtype=dtype, device=device), # chunk_indices - torch.empty(num_tokens, dtype=dtype, device=device), # chunk_offsets - torch.empty(1, num_tokens, dtype=dtype, device=device), # seq_idx_prefill - ) - else: - return ( - torch.empty(0, dtype=dtype, device=device), # chunk_indices - torch.empty(0, dtype=dtype, device=device), # chunk_offsets - torch.empty(1, 0, dtype=dtype, device=device), # seq_idx_prefill - ) - - @torch.library.custom_op("auto_deploy::triton_cached_ssm", mutates_args={}) def _triton_cached_ssm( # INPUTS (dense but may be flattened across sequences) @@ -124,26 +53,10 @@ def _triton_cached_ssm( time_step_limit: List[float], chunk_size: int, ) -> torch.Tensor: - """Flattened cached SSM transform op that respects slot-indexed state caches. - - Split mixed batches into prefill (seq_len>1) and decode (seq_len==1): - - Prefill: run one varlen combined scan over concatenated prefill tokens and update final states per slot. - - Decode: batch single-token updates with selective_state_update and update states per slot. - """ - b, s, num_heads, head_dim = hidden_states.shape - # Flatten tokens for indexing/scatter - bs = b * s - hs_flat = hidden_states.reshape(bs, *hidden_states.shape[2:]) # [bs, H, D] - B_flat = B.reshape(bs, *B.shape[2:]) # [bs, G, N] - C_flat = C.reshape(bs, *C.shape[2:]) # [bs, G, N] - dt_flat = dt.reshape(bs, dt.shape[2]) # [bs, H] - + b, s, num_heads, head_dim, bs, hs_flat, B_flat, C_flat, dt_flat = _flatten_ssm_inputs( + hidden_states, B, C, dt + ) ssm_state_size = B.shape[3] - - num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() - num_seq = num_prefill + num_decode - num_total_tokens = num_prefill_tokens + num_decode - # Preallocate output tensor to avoid memcpy cost for merging prefill # and decode outputs preallocated_ssm_out = torch.empty( @@ -151,68 +64,63 @@ def _triton_cached_ssm( dtype=hidden_states.dtype, device=hidden_states.device, ) + num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() + num_seq = num_prefill + num_decode + num_total_tokens = num_prefill_tokens + num_decode preallocated_ssm_out_p = preallocated_ssm_out[:num_prefill_tokens] preallocated_ssm_out_d = preallocated_ssm_out[num_prefill_tokens:num_total_tokens] - # Prefill: concatenate tokens at the front and run combined scan - if num_prefill > 0: - hs_prefill = hs_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H, D] - B_prefill = B_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, G, N] - C_prefill = C_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, G, N] - dt_prefill = dt_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H] + num_prefill, num_prefill_tokens, num_total_tokens, num_seq = _run_ssm_prefill( + hs_flat, + B_flat, + C_flat, + dt_flat, + A, + D, + dt_bias, + batch_info_host, + cu_seqlen, + slot_idx, + use_initial_states, + chunk_indices, + chunk_offsets, + seq_idx_prefill, + ssm_state_cache, + time_step_limit, + chunk_size, + preallocated_ssm_out_p.unsqueeze(0), + ) - initial_states = None - if torch.any(use_initial_states[:num_prefill]): - initial_states = torch.where( - use_initial_states[:num_prefill, None, None, None], - ssm_state_cache[slot_idx[:num_prefill]], - 0, - ) - else: - chunk_indices = None - chunk_offsets = None - - varlen_states = mamba_chunk_scan_combined( - hs_prefill, - dt_prefill, - A, - B_prefill, - C_prefill, - chunk_size=chunk_size, - D=D, - z=None, - dt_bias=dt_bias, - initial_states=initial_states, - seq_idx=seq_idx_prefill, - chunk_indices=chunk_indices, - 
chunk_offsets=chunk_offsets, - cu_seqlens=cu_seqlen[: num_prefill + 1], - dt_softplus=True, - dt_limit=(time_step_limit[0], time_step_limit[1]), - return_final_states=False, - return_varlen_states=True, - out=preallocated_ssm_out_p.unsqueeze(0), - state_dtype=ssm_state_cache.dtype, - ) - - ssm_state_cache.index_copy_( - 0, slot_idx[:num_prefill].long(), varlen_states.to(ssm_state_cache.dtype) - ) - - # Decode: batch single-token updates via selective_state_update - if num_decode > 0: - slot_idx_decode = slot_idx[num_prefill:num_seq] - - x_decode = hs_flat[num_prefill_tokens:num_total_tokens] # [nd, H, D] - B_decode = B_flat[num_prefill_tokens:num_total_tokens] # [nd, G, N] - C_decode = C_flat[num_prefill_tokens:num_total_tokens] # [nd, G, N] - dt_decode = dt_flat[num_prefill_tokens:num_total_tokens] # [nd, H] - - dt_hp = dt_decode[:, :, None].expand(-1, num_heads, head_dim) - dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim) - A_full = A[..., None, None].expand(num_heads, head_dim, ssm_state_size) - D_full = D[..., None].expand(num_heads, head_dim) + num_decode = num_total_tokens - num_prefill_tokens + decode_inputs = _prepare_ssm_decode_inputs( + hs_flat, + B_flat, + C_flat, + dt_flat, + A, + D, + dt_bias, + slot_idx, + num_prefill, + num_prefill_tokens, + num_seq, + num_total_tokens, + num_heads, + head_dim, + ssm_state_size, + ) + if decode_inputs is not None: + ( + slot_idx_decode, + x_decode, + B_decode, + C_decode, + dt_hp, + dt_bias_hp, + A_full, + D_full, + ) = decode_inputs selective_state_update( ssm_state_cache, x_decode, @@ -228,7 +136,6 @@ def _triton_cached_ssm( out=preallocated_ssm_out_d, ) - # Return the preallocated output reshaped to original dimensions if num_total_tokens > 0: return ( preallocated_ssm_out[:num_total_tokens] @@ -273,68 +180,7 @@ def _triton_cached_ssm_fake( @AttentionRegistry.register("triton_ssm") -class TritonBackendSSM(AttentionDescriptor): - @classmethod - def get_attention_layout(cls) -> AttentionLayout: - # Hidden states follow [b, s, n, d] - return "bsnd" - - @classmethod - def get_num_qkv_args(cls) -> int: - # torch_ssm_transform signature has 7 node/state arguments - return 7 - - @classmethod - def get_source_attention_op(cls) -> OpOverloadPacket: - # Keep source op unchanged (used for uncached pre-export) - return torch.ops.auto_deploy.torch_ssm - +class TritonBackendSSM(BaseBackendSSM): @classmethod def get_cached_attention_op(cls) -> MHACallable: return torch.ops.auto_deploy.triton_cached_ssm.default - - @classmethod - def get_standard_metadata_args(cls) -> List[str]: - return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"] - - @classmethod - def get_prepare_extra_metadata_info( - cls, any_source_attn_node: Node - ) -> Tuple[PrepareMetadataCallable, int, List[Constant]]: - return ( - torch.ops.auto_deploy.triton_ssm_prepare_metadata.default, - 3, # chunk_indices, chunk_offsets, seq_idx_prefill - extract_op_args(any_source_attn_node, "chunk_size"), - ) - - @classmethod - def get_cache_initializers( - cls, source_attn_node: Node, cache_config: KvCacheConfig - ) -> ResourceHandlerDict: - # Shapes from fake tensors - hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"] - B_fake: torch.Tensor = source_attn_node.args[2].meta["val"] - - num_heads = hs_fake.shape[-2] - head_dim = hs_fake.shape[-1] - - if B_fake.ndim >= 4: - ssm_state_size = B_fake.shape[-1] - else: - ssm_state_size = max(1, B_fake.shape[-1]) - - # extract ssm_state_dtype from cache_config or hs_fake - ssm_state_dtype = 
cls.resolve_cache_dtype(cache_config.mamba_ssm_cache_dtype, hs_fake.dtype) - - return { - "ssm_state_cache": StateResourceHandler( - num_heads, head_dim, ssm_state_size, dtype=ssm_state_dtype - ) - } - - @classmethod - def get_constants(cls, source_attn_node: Node) -> List[Constant]: - time_step_limit, chunk_size = extract_op_args( - source_attn_node, "time_step_limit", "chunk_size" - ) - return [time_step_limit, chunk_size] diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py index d46820ce90..1ae49622a0 100644 --- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py +++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @@ -143,8 +143,14 @@ class TestNemotronH(LlmapiAccuracyTestHarness): @pytest.mark.skip_less_device_memory(32000) @pytest.mark.parametrize("enable_chunked_prefill", [False, True]) - def test_auto_dtype(self, enable_chunked_prefill): + @pytest.mark.parametrize("ssm_backend", ["triton_ssm", "flashinfer_ssm"]) + def test_auto_dtype(self, enable_chunked_prefill, ssm_backend): kwargs = self.get_default_kwargs(enable_chunked_prefill) + kwargs.setdefault("transforms", {}) + insert_ssm_cfg = {"backend": ssm_backend} + if ssm_backend == "flashinfer_ssm": + insert_ssm_cfg["cache_config"] = {"mamba_dtype": "bfloat16"} + kwargs["transforms"]["insert_cached_ssm_attention"] = insert_ssm_cfg sampling_params = self.get_default_sampling_params() with AutoDeployLLM(model=self.MODEL_PATH, tokenizer=self.MODEL_PATH, diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index bb12a3302e..bec6a89f90 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -434,8 +434,9 @@ l0_h100: - unittest/_torch/auto_deploy/unit/singlegpu - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-1] - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[True-1] - - accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[False] - - accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[True] + - accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[triton_ssm-False] + - accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[flashinfer_ssm-False] + - accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[triton_ssm-True] - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[1] - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16[1] - examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[draft_target] diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_mamba_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_mamba_cached_op.py new file mode 100644 index 0000000000..1ebe3a948a --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_mamba_cached_op.py @@ -0,0 +1,103 @@ +import pytest +import torch + +import tensorrt_llm._torch.auto_deploy # noqa: F401 +from tests.unittest._torch.auto_deploy.unit.singlegpu.custom_ops.test_triton_mamba_cached_op import ( + _random_params, +) + + +@pytest.fixture +def mamba_env(): + device = "cuda" + dtype = torch.bfloat16 + atol = 1e-3 + rtol = 1e-3 + torch.manual_seed(42) + torch.cuda.empty_cache() + return {"device": device, "dtype": dtype, "atol": atol, "rtol": rtol} + + +def 
test_flashinfer_decode_matches_triton(mamba_env): + device = mamba_env["device"] + dtype = mamba_env["dtype"] + atol = mamba_env["atol"] + rtol = mamba_env["rtol"] + + batch, seq = 2, 1 + num_heads, head_dim = 2, 64 + n_groups, ssm_state_size = 2, 64 + (hidden_states, A, B, C, D, dt, dt_bias, time_step_limit, chunk_size) = _random_params( + device, dtype, batch, seq, num_heads, head_dim, n_groups, ssm_state_size + ) + + max_batch_size = 4 + slot_idx = torch.tensor([0, 2], device=device, dtype=torch.int32) + ssm_state_cache_triton = torch.randn( + max_batch_size, num_heads, head_dim, ssm_state_size, device=device, dtype=dtype + ) + ssm_state_cache_flashinfer = ssm_state_cache_triton.clone() + + # batch_info_host: [num_prefill, num_prefill_tokens, num_decode] + batch_info_host = torch.tensor([0, 0, batch], device=device, dtype=torch.int32) + cu_seqlen = torch.zeros(batch + 1, device=device, dtype=torch.int32) + use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool) + + y_triton = torch.ops.auto_deploy.triton_cached_ssm( + hidden_states, + A, + B, + C, + D, + dt, + dt_bias, + # STANDARD METADATA + batch_info_host, + cu_seqlen, + slot_idx, + use_initial_states, + # EXTRA METADATA + None, # chunk indices + None, # chunk offsets + None, # seq_idx_prefill + # CACHES + ssm_state_cache_triton, + # CONSTANTS + time_step_limit, + chunk_size, + ) + + y_flashinfer = torch.ops.auto_deploy.flashinfer_cached_ssm( + hidden_states, + A, + B, + C, + D, + dt, + dt_bias, + # STANDARD METADATA + batch_info_host, + cu_seqlen, + slot_idx, + use_initial_states, + # EXTRA METADATA + None, # chunk indices + None, # chunk offsets + None, # seq_idx_prefill + # CACHES + ssm_state_cache_flashinfer, + # CONSTANTS + time_step_limit, + chunk_size, + ) + + assert y_triton.shape == hidden_states.shape + assert y_flashinfer.shape == hidden_states.shape + assert torch.isfinite(y_flashinfer).all() + assert torch.allclose(y_flashinfer, y_triton.to(y_flashinfer.dtype), atol=atol, rtol=rtol) + + after_triton = ssm_state_cache_triton.index_select(0, slot_idx) + after_flashinfer = ssm_state_cache_flashinfer.index_select(0, slot_idx) + assert torch.allclose( + after_flashinfer.to(after_triton.dtype), after_triton, atol=atol, rtol=rtol + ) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py index fa26211b95..c41bf1a601 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py @@ -991,14 +991,5 @@ def test_stack_nvfp4_moe_weights_transform_relu2(hidden_size, intermediate_size) # Run the transformed graph transformed_output = gm(x, selected_experts, routing_weights) - # Get the registered parameters after transform - fc1_act_scale = getattr(gm, "nvfp4_moe_w3_w1_input_scale_stacked_0", None) - fc1_alpha = getattr(gm, "nvfp4_moe_w1_alpha_stacked_0", None) - if fc1_act_scale is not None: - print(f"fc1_act_scale (after transform): {fc1_act_scale}, shape: {fc1_act_scale.shape}") - if fc1_alpha is not None: - print(f"fc1_alpha (after transform): {fc1_alpha}, shape: {fc1_alpha.shape}") - - # Should be close for FP4 quantization (gated MLP may have slightly larger diff due to alpha handling) tol = 1e-3 torch.testing.assert_close(ref_output, transformed_output, rtol=tol, atol=tol) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py 
b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py index 2ef8033349..dd56511748 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py @@ -184,6 +184,7 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs): "transforms": { "insert_cached_attention": {"backend": "flashinfer"}, "compile_model": {"backend": "torch-simple"}, + "insert_cached_ssm_attention": {"backend": "triton_ssm"}, }, }, ), @@ -192,6 +193,7 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs): { "transforms": { "multi_stream_moe": {"stage": "compile", "enabled": True}, + "insert_cached_ssm_attention": {"backend": "triton_ssm"}, # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/9878 "compile_model": {"backend": "torch-cudagraph"}, },
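
On the configuration side, the backend is selected through the insert_cached_ssm_attention transform, as exercised by the updated accuracy test and the l0_h100 entries above. The sketch below is illustrative only: the model path is a placeholder and AutoDeployLLM is assumed to be imported as in tests/integration/defs/accuracy/test_llm_api_autodeploy.py; other kwargs from get_default_kwargs are omitted.

# Sketch only: mirrors how test_auto_dtype wires up the SSM backend selection.
kwargs = {
    "transforms": {
        "insert_cached_ssm_attention": {
            "backend": "flashinfer_ssm",                  # or "triton_ssm"
            # The accuracy test only sets cache_config for flashinfer_ssm.
            "cache_config": {"mamba_dtype": "bfloat16"},
        },
    },
}
llm = AutoDeployLLM(model="path/to/nemotron-h", tokenizer="path/to/nemotron-h", **kwargs)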
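For a quick smoke test of the cached op itself, a decode-only call can be issued directly against the registered custom op, loosely following test_flashinfer_decode_matches_triton. All sizes, dtypes, chunk_size, and time_step_limit below are illustrative assumptions (the in-tree test builds its inputs via _random_params); the empty chunk_indices/chunk_offsets/seq_idx_prefill tensors are what mamba_ssm_prepare_metadata returns when there are no prefill sequences. Running it requires CUDA and a flashinfer installation.

import torch

import tensorrt_llm._torch.auto_deploy  # noqa: F401  (registers the auto_deploy custom ops)

device, dtype = "cuda", torch.bfloat16
b, s = 2, 1                                   # decode-only: one token per sequence
num_heads, head_dim, n_groups, state_size = 2, 64, 2, 64
max_slots = 4

hidden_states = torch.randn(b, s, num_heads, head_dim, device=device, dtype=dtype)
A = -torch.rand(num_heads, device=device, dtype=torch.float32)   # fp32 params assumed
B = torch.randn(b, s, n_groups, state_size, device=device, dtype=dtype)
C = torch.randn(b, s, n_groups, state_size, device=device, dtype=dtype)
D = torch.rand(num_heads, device=device, dtype=torch.float32)
dt = torch.rand(b, s, num_heads, device=device, dtype=dtype)
dt_bias = torch.rand(num_heads, device=device, dtype=torch.float32)

ssm_state_cache = torch.zeros(
    max_slots, num_heads, head_dim, state_size, device=device, dtype=dtype
)

# STANDARD METADATA: [num_prefill, num_prefill_tokens, num_decode]
batch_info_host = torch.tensor([0, 0, b], device=device, dtype=torch.int32)
cu_seqlen = torch.zeros(b + 1, device=device, dtype=torch.int32)
slot_idx = torch.tensor([0, 2], device=device, dtype=torch.int32)
use_initial_states = torch.zeros(b, device=device, dtype=torch.bool)

# EXTRA METADATA: empty tensors, as mamba_ssm_prepare_metadata returns for num_prefill == 0
chunk_indices = torch.empty(0, device=device, dtype=torch.int32)
chunk_offsets = torch.empty(0, device=device, dtype=torch.int32)
seq_idx_prefill = torch.empty(1, 0, device=device, dtype=torch.int32)

y = torch.ops.auto_deploy.flashinfer_cached_ssm(
    hidden_states, A, B, C, D, dt, dt_bias,
    batch_info_host, cu_seqlen, slot_idx, use_initial_states,
    chunk_indices, chunk_offsets, seq_idx_prefill,
    ssm_state_cache,
    [0.0, float("inf")],  # time_step_limit (illustrative)
    128,                  # chunk_size (unused on the decode-only path)
)
assert y.shape == hidden_states.shape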
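Finally, the contiguity argument for FlashInferStateResourceHandler can be seen directly from allocate(): it hands back one dedicated torch.empty buffer per cache rather than a strided view into a pooled byte buffer. A rough sketch follows; the SequenceInfo stand-in is a hypothetical duck-typed placeholder exposing only the two attributes allocate() reads, and the sizes are arbitrary.

from types import SimpleNamespace

import torch

from tensorrt_llm._torch.auto_deploy.custom_ops.mamba.flashinfer_backend_mamba import (
    FlashInferStateResourceHandler,
)

# Hypothetical stand-in for SequenceInfo; only max_num_state_slots and device are used here.
seq_info = SimpleNamespace(max_num_state_slots=8, device="cuda")

handler = FlashInferStateResourceHandler(2, 64, 64, dtype=torch.bfloat16)  # num_heads, head_dim, ssm_state_size
cache = handler.allocate(seq_info)

assert cache.shape == (8, 2, 64, 64)
assert cache.is_contiguous()  # the layout flashinfer's selective_state_update expects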