From 9644f024bd47313c95b179e565ac07441c42c725 Mon Sep 17 00:00:00 2001 From: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> Date: Thu, 5 Feb 2026 21:33:00 -0800 Subject: [PATCH] [None][feat] AutoDeploy: add triton backend for causal conv (#11124) Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> --- .../custom_ops/mamba/causal_conv_common.py | 104 +++++ .../mamba/cuda_backend_causal_conv.py | 68 +-- .../mamba/triton_backend_causal_conv.py | 180 ++++++++ .../test_triton_causal_conv_cached_op.py | 386 ++++++++++++++++++ 4 files changed, 678 insertions(+), 60 deletions(-) create mode 100644 tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/causal_conv_common.py create mode 100644 tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_causal_conv.py create mode 100644 tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_causal_conv_cached_op.py diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/causal_conv_common.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/causal_conv_common.py new file mode 100644 index 0000000000..22619c992b --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/causal_conv_common.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import abstractmethod +from typing import List + +import torch +from torch._ops import OpOverloadPacket +from torch.fx import Node + +from .....llmapi.llm_args import KvCacheConfig +from ...utils.node_utils import extract_op_args +from ..attention_interface import ( + AttentionDescriptor, + AttentionLayout, + CausalConvResourceHandler, + Constant, + MHACallable, + ResourceHandlerDict, +) + + +class BaseCausalConvDescriptor(AttentionDescriptor): + """Base class for causal conv1d backends. + + Provides shared implementations for: + - get_attention_layout + - get_num_qkv_args + - get_source_attention_op + - get_standard_metadata_args + - get_cache_initializers + - get_constants + + Subclasses must implement: + - get_cached_attention_op + """ + + @classmethod + def get_attention_layout(cls) -> AttentionLayout: + # Hidden states follow [b, s, c] + return "bsnd" + + @classmethod + def get_num_qkv_args(cls) -> int: + # torch_causal_conv1d signature has 3 relevant tensor arguments + # TODO: bias can be optional!! How to handle None bias here? + return 3 + + @classmethod + def get_source_attention_op(cls) -> OpOverloadPacket: + return torch.ops.auto_deploy.torch_causal_conv1d.default + + @classmethod + @abstractmethod + def get_cached_attention_op(cls) -> MHACallable: + """Return the cached attention op for this backend. + + Must be implemented by subclasses. 
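        The returned op must reproduce the math of the source op. For reference,
        a minimal sketch of the uncached depthwise causal conv this family of
        backends implements (illustrative only; assumes groups == channels,
        padding == kernel_size - 1, stride == dilation == 1, no activation):

            import torch
            import torch.nn.functional as F

            def reference_causal_conv1d(x, weight, bias):
                # x: [b, s, c] hidden states, weight: [c, 1, k] depthwise, bias: [c]
                s, c = x.shape[1], x.shape[2]
                k = weight.shape[-1]
                y = F.conv1d(x.transpose(1, 2), weight, bias,
                             stride=1, padding=k - 1, groups=c)
                return y[..., :s].transpose(1, 2)  # trim right padding -> causal [b, s, c]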
+ """ + raise NotImplementedError + + @classmethod + def get_standard_metadata_args(cls) -> List[str]: + return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"] + + @classmethod + def get_cache_initializers( + cls, source_attn_node: Node, cache_config: KvCacheConfig + ) -> ResourceHandlerDict: + inp_fake: torch.Tensor = source_attn_node.args[0].meta["val"] + w_fake: torch.Tensor = source_attn_node.args[1].meta["val"] + + in_channels = inp_fake.shape[-1] + kernel_size = w_fake.shape[-1] + + # NOTE: cuda backend stores kernel_size - 1 elements in state. + # CausalConvResourceHandler.state_shape = (conv_dim, d_conv - 1), so d_conv = kernel_size. + # Ensure d_conv >= 1 (state_shape[-1] >= 0). + conv_state_handler = CausalConvResourceHandler( + conv_dim=in_channels, + d_conv=max(1, kernel_size), # state_shape[-1] = d_conv - 1 = kernel_size - 1 + dtype=cls.resolve_cache_dtype("auto", inp_fake.dtype), + ) + return {"conv_state_cache": conv_state_handler} + + @classmethod + def get_constants(cls, source_attn_node: Node) -> List[Constant]: + stride, padding, dilation, groups, padding_mode = extract_op_args( + source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode" + ) + # None is for activation parameter, which may not exist in the source node (added by fusion later) + return [stride, padding, dilation, groups, padding_mode, None] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py index f1d66b63f1..ebaefbf963 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py @@ -24,26 +24,15 @@ The flattened cached op integrates with the auto_deploy attention interface and updates a slot-indexed convolution state cache internally. """ -from typing import List, Optional +from typing import Optional import torch -from torch._ops import OpOverloadPacket -from torch.fx import Node from tensorrt_llm._torch.modules.mamba import PAD_SLOT_ID from tensorrt_llm._torch.modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update -from .....llmapi.llm_args import KvCacheConfig -from ...utils.node_utils import extract_op_args -from ..attention_interface import ( - AttentionDescriptor, - AttentionLayout, - AttentionRegistry, - CausalConvResourceHandler, - Constant, - MHACallable, - ResourceHandlerDict, -) +from ..attention_interface import AttentionRegistry, MHACallable +from .causal_conv_common import BaseCausalConvDescriptor @torch.library.custom_op("auto_deploy::cuda_cached_causal_conv1d", mutates_args={"input"}) @@ -163,54 +152,13 @@ def cuda_cached_causal_conv1d_wrapper(input: torch.Tensor, *args, **kwargs) -> t @AttentionRegistry.register("cuda_causal_conv") -class CudaBackendCausalConv(AttentionDescriptor): - @classmethod - def get_attention_layout(cls) -> AttentionLayout: - # Hidden states follow [b, s, c] - return "bsnd" +class CudaBackendCausalConv(BaseCausalConvDescriptor): + """CUDA-backed causal conv1d attention descriptor. - @classmethod - def get_num_qkv_args(cls) -> int: - # torch_causal_conv1d signature has 3 relevant tensor arguments - # TODO: bias can be optional!! How to handle None bias here? - return 3 - - @classmethod - def get_source_attention_op(cls) -> OpOverloadPacket: - return torch.ops.auto_deploy.torch_causal_conv1d.default + Inherits shared methods from BaseCausalConvDescriptor. 
+ Only overrides get_cached_attention_op to return the CUDA wrapper. + """ @classmethod def get_cached_attention_op(cls) -> MHACallable: return cuda_cached_causal_conv1d_wrapper - - @classmethod - def get_standard_metadata_args(cls) -> List[str]: - return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"] - - @classmethod - def get_cache_initializers( - cls, source_attn_node: Node, cache_config: KvCacheConfig - ) -> ResourceHandlerDict: - inp_fake: torch.Tensor = source_attn_node.args[0].meta["val"] - w_fake: torch.Tensor = source_attn_node.args[1].meta["val"] - - in_channels = inp_fake.shape[-1] - kernel_size = w_fake.shape[-1] - - # NOTE: cuda backend stores kernel_size - 1 elements in state. - # CausalConvResourceHandler.state_shape = (conv_dim, d_conv - 1), so d_conv = kernel_size. - # Ensure d_conv >= 1 (state_shape[-1] >= 0). - conv_state_handler = CausalConvResourceHandler( - conv_dim=in_channels, - d_conv=max(1, kernel_size), # state_shape[-1] = d_conv - 1 = kernel_size - 1 - dtype=cls.resolve_cache_dtype("auto", inp_fake.dtype), - ) - return {"conv_state_cache": conv_state_handler} - - @classmethod - def get_constants(cls, source_attn_node: Node) -> List[Constant]: - stride, padding, dilation, groups, padding_mode = extract_op_args( - source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode" - ) - # None is for activation parameter, which may not exist in the source node (added by fusion later) - return [stride, padding, dilation, groups, padding_mode, None] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_causal_conv.py new file mode 100644 index 0000000000..993d061248 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_causal_conv.py @@ -0,0 +1,180 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Triton-backed cached causal conv1d custom ops and attention descriptor. + +This mirrors `cuda_backend_causal_conv.py` but uses Triton kernels instead of CUDA: +- Prefill uses Triton `causal_conv1d_fn` +- Decode uses Triton `causal_conv1d_update` + +The flattened cached op integrates with the auto_deploy attention interface +and updates a slot-indexed convolution state cache internally. 
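Metadata layout sketch (illustrative values only, mirroring the mixed
prefill/decode unit test added alongside this backend):

    # one prefill sequence of length 3 followed by two decode tokens,
    # flattened into an input of shape [1, 5, c_in]
    batch_info_host    = torch.tensor([1, 3, 2], dtype=torch.int32)  # [num_prefill, num_prefill_tokens, num_decode]
    seq_len            = torch.tensor([3], dtype=torch.int32)        # per prefill sequence
    cu_seqlen          = torch.tensor([0, 3], dtype=torch.int32)     # prefill token boundaries
    slot_idx           = torch.tensor([0, 1, 2], dtype=torch.int32)  # cache slots: prefill first, then decode
    use_initial_states = torch.tensor([False])                       # per prefill sequence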
+""" + +from typing import List, Optional + +import torch + +from tensorrt_llm._torch.modules.mamba import PAD_SLOT_ID +from tensorrt_llm._torch.modules.mamba.causal_conv1d_triton import ( + causal_conv1d_fn, + causal_conv1d_update, +) + +from ..attention_interface import AttentionRegistry, MHACallable +from .causal_conv_common import BaseCausalConvDescriptor + + +@torch.library.custom_op("auto_deploy::triton_cached_causal_conv1d", mutates_args={"input"}) +def _triton_cached_causal_conv1d( + # INPUTS (dense but may be flattened across sequences) + input: torch.Tensor, # [b, s, c_in] + weight: torch.Tensor, # [c_out, c_in/groups, k] but we expect depthwise use: [c_in, k] + bias: Optional[torch.Tensor], + # STANDARD METADATA + batch_info_host: torch.Tensor, + seq_len: torch.Tensor, + cu_seqlen: torch.Tensor, + slot_idx: torch.Tensor, + use_initial_states: torch.Tensor, + # EXTRA METADATA + # + # CACHES + conv_state_cache: torch.Tensor, # [max_batch_size, c_in, k-1] + # CONSTANTS + stride: int, + padding: int, + dilation: int, + groups: int, + padding_mode: str, + activation: Optional[str], +) -> None: + """Flattened cached causal conv that respects slot-indexed state caches (Triton backend). + + Supports two layouts from the attention interface: + - Generate-only: input is [b, 1, c_in]. We'll gather caches using slot_idx[:b]. + - Flattened context/mixed: input is [1, total_s, c_in] and seq_len/seq_start + describe per-sequence segments. We'll process each segment and scatter final states to caches. + + NOTE: This op modifies `input` in-place. + """ + b, s = input.shape[:2] + + num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() + num_seq = num_prefill + num_decode + num_total_tokens = num_prefill_tokens + num_decode + + # Flatten tokens + bs = b * s + inp_flat = input.reshape(bs, *input.shape[2:]) # [total_s, C_in] + + # Prepare weight as [dim, width] (depthwise) + if weight.ndim == 3: + assert weight.shape[-2] == 1 + w2d = weight.squeeze(-2) + else: + w2d = weight + + # PREFILL: concatenate all prefill tokens and run one varlen forward + if num_prefill > 0: + # x_varlen: (dim, cu_seq_len) + x_varlen = inp_flat[:num_prefill_tokens].transpose(0, 1).contiguous() + + prefill_cu_seqlen = cu_seqlen[: num_prefill + 1] + seq_lens_cpu = seq_len[:num_prefill].tolist() + + # Run varlen conv; updates conv_state_cache in-place per cache_indices + # Note: Triton kernel returns a new tensor (not in-place like CUDA) + y_varlen = causal_conv1d_fn( + x_varlen, + w2d, + bias, + conv_state_cache, + prefill_cu_seqlen, + seq_lens_cpu, + cache_indices=slot_idx[:num_prefill].to(torch.int32), + has_initial_state=use_initial_states[:num_prefill], + activation=activation, + pad_slot_id=PAD_SLOT_ID, + ) # (dim, total_prefill_tokens) + # Scatter outputs back to input buffer + inp_flat[:num_prefill_tokens] = y_varlen.transpose(0, 1) + + # DECODE: batch update for single-token sequences + if num_decode > 0: + x_decode = inp_flat[num_prefill_tokens:num_total_tokens] # [num_decode, C_in] + + # Note: Triton causal_conv1d_update returns a new tensor (not in-place like CUDA version) + # so we need to capture the output and write it back + y_decode = causal_conv1d_update( + x_decode, # [batch, dim] + conv_state_cache, + w2d, + bias, + activation=activation, + cache_seqlens=None, + conv_state_indices=slot_idx[num_prefill:num_seq].to(torch.int32), + pad_slot_id=PAD_SLOT_ID, + ) + inp_flat[num_prefill_tokens:num_total_tokens] = y_decode + + +@_triton_cached_causal_conv1d.register_fake +def 
_triton_cached_causal_conv1d_fake( + # INPUTS (dense but may be flattened across sequences) + input: torch.Tensor, # [b, s, c_in] + weight: torch.Tensor, # [c_out, c_in/groups, k] but we expect depthwise use: [c_in, k] + bias: Optional[torch.Tensor], + # STANDARD METADATA + batch_info_host: torch.Tensor, + seq_len: torch.Tensor, + cu_seqlen: torch.Tensor, + slot_idx: torch.Tensor, + use_initial_states: torch.Tensor, + # EXTRA METADATA + # + # CACHES + conv_state_cache: torch.Tensor, # [max_batch_size, c_in, k-1] + # CONSTANTS + stride: int, + padding: int, + dilation: int, + groups: int, + padding_mode: str, + activation: Optional[str], +) -> None: + pass + + +def triton_cached_causal_conv1d_wrapper(input: torch.Tensor, *args, **kwargs) -> torch.Tensor: + torch.ops.auto_deploy.triton_cached_causal_conv1d(input, *args, **kwargs) + return input + + +@AttentionRegistry.register("triton_causal_conv") +class TritonBackendCausalConv(BaseCausalConvDescriptor): + """Triton-backed causal conv1d attention descriptor. + + Inherits shared methods from BaseCausalConvDescriptor. + Overrides get_standard_metadata_args to include seq_len (used directly by Triton kernel). + """ + + @classmethod + def get_standard_metadata_args(cls) -> List[str]: + return ["batch_info_host", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"] + + @classmethod + def get_cached_attention_op(cls) -> MHACallable: + return triton_cached_causal_conv1d_wrapper diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_causal_conv_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_causal_conv_cached_op.py new file mode 100644 index 0000000000..8980240be3 --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_causal_conv_cached_op.py @@ -0,0 +1,386 @@ +"""Unit tests for Triton-backed cached causal conv1d custom ops. 
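
The two backends take the same arguments except that the Triton op also
receives an explicit seq_len tensor (see TritonBackendCausalConv's
get_standard_metadata_args). Abridged call shapes used throughout, with the
argument lists shortened here to "...":

    torch.ops.auto_deploy.cuda_cached_causal_conv1d(x, w, b, batch_info_host,          cu_seqlen, slot_idx, ...)
    torch.ops.auto_deploy.triton_cached_causal_conv1d(x, w, b, batch_info_host, seq_len, cu_seqlen, slot_idx, ...)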
+ +Covers: +- Generate-only path comparing Triton vs CUDA backend +- Context (flattened) path comparing Triton vs CUDA backend +- Ensures numerical consistency between backends +""" + +import pytest +import torch + +import tensorrt_llm._torch.auto_deploy # noqa: F401 + + +def _random_params_depthwise(device, dtype, batch, seq, channels, k): + x = torch.randn(batch, seq, channels, device=device, dtype=dtype) + # Depthwise: out_channels == in_channels, groups == channels, weight [C, 1, K] + weight = torch.randn(channels, 1, k, device=device, dtype=dtype) + bias = torch.randn(channels, device=device, dtype=dtype) + stride = 1 + padding = k - 1 + dilation = 1 + groups = channels + padding_mode = "zeros" + return x, weight, bias, stride, padding, dilation, groups, padding_mode + + +@pytest.fixture +def conv_env(): + device = "cuda" + dtype = torch.float16 + atol = 5e-2 + rtol = 5e-2 + torch.manual_seed(123) + torch.cuda.empty_cache() + return {"device": device, "dtype": dtype, "atol": atol, "rtol": rtol} + + +def test_generate_only_triton_vs_cuda(conv_env): + """Test that Triton backend matches CUDA backend for generate-only (decode) path.""" + device = conv_env["device"] + dtype = conv_env["dtype"] + + batch, seq = 1, 1 + c, k = 2, 3 + x, w, b, s, p, d, g, pm = _random_params_depthwise(device, dtype, batch, seq, c, k) + + # Slot mapping with arbitrary order within max_batch_size + max_batch_size = 2 + slot_idx = torch.tensor([0], device=device, dtype=torch.int32) + # Cache holds K-1 entries (TRT update kernel contract) + conv_state_cache_cuda = torch.randn( + max_batch_size, + c, + k - 1, + device=device, + dtype=dtype, + ) + # Clone for Triton backend + conv_state_cache_triton = conv_state_cache_cuda.clone() + + # Metadata + cu_seqlen = torch.zeros(batch, device=device, dtype=torch.int32) + seq_len = torch.zeros(batch, device=device, dtype=torch.int32) + use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool) + # batch_info_host: [num_prefill, num_prefill_tokens, num_decode] + # For generate-only: num_decode = batch, num_prefill = 0 + batch_info_host = torch.tensor([0, 0, batch], device=device, dtype=torch.int32) + + # Clone inputs for each backend + x_cuda = x.clone() + x_triton = x.clone() + + # Run CUDA backend + torch.ops.auto_deploy.cuda_cached_causal_conv1d( + x_cuda, + w, + b, + batch_info_host, + cu_seqlen, + slot_idx, + use_initial_states, + conv_state_cache_cuda, + s, + p, + d, + g, + pm, + None, + ) + + # Run Triton backend (includes seq_len parameter) + torch.ops.auto_deploy.triton_cached_causal_conv1d( + x_triton, + w, + b, + batch_info_host, + seq_len, + cu_seqlen, + slot_idx, + use_initial_states, + conv_state_cache_triton, + s, + p, + d, + g, + pm, + None, + ) + + # Compare outputs + assert x_cuda.shape == x_triton.shape + assert torch.allclose(x_cuda, x_triton, atol=conv_env["atol"], rtol=conv_env["rtol"]), ( + f"Output mismatch: max diff = {(x_cuda - x_triton).abs().max()}" + ) + + # Compare cache states + assert torch.allclose( + conv_state_cache_cuda, conv_state_cache_triton, atol=conv_env["atol"], rtol=conv_env["rtol"] + ), ( + f"Cache state mismatch: max diff = {(conv_state_cache_cuda - conv_state_cache_triton).abs().max()}" + ) + + +def test_context_flattened_triton_vs_cuda(conv_env): + """Test that Triton backend matches CUDA backend for context (prefill) path.""" + device = conv_env["device"] + dtype = conv_env["dtype"] + + # Two short sequences with lengths 2 and 1, flattened to [1,3] + lens = [2, 1] + total = sum(lens) + batch, seq = 1, total + c, k 
= 2, 3 + x, w, b, s, p, d, g, pm = _random_params_depthwise(device, dtype, batch, seq, c, k) + + max_batch_size = 2 + slot_idx = torch.tensor([1, 0], device=device, dtype=torch.int32) + conv_state_cache_cuda = torch.randn( + max_batch_size, + c, + k - 1, + device=device, + dtype=dtype, + ) + # Clone for Triton backend + conv_state_cache_triton = conv_state_cache_cuda.clone() + + # batch_info_host: [num_prefill, num_prefill_tokens, num_decode] + num_prefill = len(lens) + batch_info_host = torch.tensor([num_prefill, total, 0], device=device, dtype=torch.int32) + cu_seqlen = torch.tensor([0, lens[0], total], device=device, dtype=torch.int32) + seq_len = torch.tensor(lens, device=device, dtype=torch.int32) + use_initial_states = torch.zeros(num_prefill, device=device, dtype=torch.bool) + + # Clone inputs for each backend + x_cuda = x.clone() + x_triton = x.clone() + + # Run CUDA backend + torch.ops.auto_deploy.cuda_cached_causal_conv1d( + x_cuda, + w, + b, + batch_info_host, + cu_seqlen, + slot_idx, + use_initial_states, + conv_state_cache_cuda, + s, + p, + d, + g, + pm, + None, + ) + + # Run Triton backend (includes seq_len parameter) + torch.ops.auto_deploy.triton_cached_causal_conv1d( + x_triton, + w, + b, + batch_info_host, + seq_len, + cu_seqlen, + slot_idx, + use_initial_states, + conv_state_cache_triton, + s, + p, + d, + g, + pm, + None, + ) + + # Compare outputs + assert x_cuda.shape == x_triton.shape + assert torch.allclose(x_cuda, x_triton, atol=conv_env["atol"], rtol=conv_env["rtol"]), ( + f"Output mismatch: max diff = {(x_cuda - x_triton).abs().max()}" + ) + + # Compare cache states + assert torch.allclose( + conv_state_cache_cuda, conv_state_cache_triton, atol=conv_env["atol"], rtol=conv_env["rtol"] + ), ( + f"Cache state mismatch: max diff = {(conv_state_cache_cuda - conv_state_cache_triton).abs().max()}" + ) + + +def test_mixed_prefill_decode_triton_vs_cuda(conv_env): + """Test that Triton backend matches CUDA backend for mixed prefill + decode batch.""" + device = conv_env["device"] + dtype = conv_env["dtype"] + + # Mixed batch: 1 prefill sequence (len=3) + 2 decode tokens + prefill_lens = [3] + num_prefill = len(prefill_lens) + num_prefill_tokens = sum(prefill_lens) + num_decode = 2 + + total_tokens = num_prefill_tokens + num_decode + batch, seq = 1, total_tokens + c, k = 4, 3 + x, w, b, s, p, d, g, pm = _random_params_depthwise(device, dtype, batch, seq, c, k) + + max_batch_size = 4 + # Slot indices: first for prefill, then for decode + slot_idx = torch.tensor([0, 1, 2], device=device, dtype=torch.int32) + conv_state_cache_cuda = torch.randn( + max_batch_size, + c, + k - 1, + device=device, + dtype=dtype, + ) + conv_state_cache_triton = conv_state_cache_cuda.clone() + + # batch_info_host: [num_prefill, num_prefill_tokens, num_decode] + batch_info_host = torch.tensor( + [num_prefill, num_prefill_tokens, num_decode], device=device, dtype=torch.int32 + ) + cu_seqlen = torch.tensor([0, prefill_lens[0]], device=device, dtype=torch.int32) + seq_len = torch.tensor(prefill_lens, device=device, dtype=torch.int32) + use_initial_states = torch.zeros(num_prefill, device=device, dtype=torch.bool) + + # Clone inputs + x_cuda = x.clone() + x_triton = x.clone() + + # Run CUDA backend + torch.ops.auto_deploy.cuda_cached_causal_conv1d( + x_cuda, + w, + b, + batch_info_host, + cu_seqlen, + slot_idx, + use_initial_states, + conv_state_cache_cuda, + s, + p, + d, + g, + pm, + None, + ) + + # Run Triton backend + torch.ops.auto_deploy.triton_cached_causal_conv1d( + x_triton, + w, + b, + 
batch_info_host, + seq_len, + cu_seqlen, + slot_idx, + use_initial_states, + conv_state_cache_triton, + s, + p, + d, + g, + pm, + None, + ) + + # Compare outputs + assert x_cuda.shape == x_triton.shape + assert torch.allclose(x_cuda, x_triton, atol=conv_env["atol"], rtol=conv_env["rtol"]), ( + f"Output mismatch: max diff = {(x_cuda - x_triton).abs().max()}" + ) + + # Compare cache states + assert torch.allclose( + conv_state_cache_cuda, conv_state_cache_triton, atol=conv_env["atol"], rtol=conv_env["rtol"] + ), ( + f"Cache state mismatch: max diff = {(conv_state_cache_cuda - conv_state_cache_triton).abs().max()}" + ) + + +def test_larger_batch_triton_vs_cuda(conv_env): + """Test with larger batch and longer sequences.""" + device = conv_env["device"] + dtype = conv_env["dtype"] + + # Multiple prefill sequences + lens = [5, 8, 3, 10] + total = sum(lens) + batch, seq = 1, total + c, k = 16, 4 + x, w, b, s, p, d, g, pm = _random_params_depthwise(device, dtype, batch, seq, c, k) + + max_batch_size = 8 + slot_idx = torch.tensor([3, 1, 5, 0], device=device, dtype=torch.int32) + conv_state_cache_cuda = torch.randn( + max_batch_size, + c, + k - 1, + device=device, + dtype=dtype, + ) + conv_state_cache_triton = conv_state_cache_cuda.clone() + + num_prefill = len(lens) + batch_info_host = torch.tensor([num_prefill, total, 0], device=device, dtype=torch.int32) + + # Build cumulative sequence lengths + cu_seqlen_list = [0] + for ln in lens: + cu_seqlen_list.append(cu_seqlen_list[-1] + ln) + cu_seqlen = torch.tensor(cu_seqlen_list, device=device, dtype=torch.int32) + seq_len = torch.tensor(lens, device=device, dtype=torch.int32) + use_initial_states = torch.zeros(num_prefill, device=device, dtype=torch.bool) + + x_cuda = x.clone() + x_triton = x.clone() + + # Run CUDA backend + torch.ops.auto_deploy.cuda_cached_causal_conv1d( + x_cuda, + w, + b, + batch_info_host, + cu_seqlen, + slot_idx, + use_initial_states, + conv_state_cache_cuda, + s, + p, + d, + g, + pm, + None, + ) + + # Run Triton backend + torch.ops.auto_deploy.triton_cached_causal_conv1d( + x_triton, + w, + b, + batch_info_host, + seq_len, + cu_seqlen, + slot_idx, + use_initial_states, + conv_state_cache_triton, + s, + p, + d, + g, + pm, + None, + ) + + # Compare outputs + assert torch.allclose(x_cuda, x_triton, atol=conv_env["atol"], rtol=conv_env["rtol"]), ( + f"Output mismatch: max diff = {(x_cuda - x_triton).abs().max()}" + ) + + # Compare cache states + assert torch.allclose( + conv_state_cache_cuda, conv_state_cache_triton, atol=conv_env["atol"], rtol=conv_env["rtol"] + ), ( + f"Cache state mismatch: max diff = {(conv_state_cache_cuda - conv_state_cache_triton).abs().max()}" + )
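
A quick standalone consistency check, essentially a condensed version of
test_generate_only_triton_vs_cuda above (sketch; assumes a CUDA-capable GPU
and an installed tensorrt_llm with both backends available):

    import torch
    import tensorrt_llm._torch.auto_deploy  # noqa: F401  (registers the custom ops)

    device, dtype = "cuda", torch.float16
    c, k = 8, 4
    x = torch.randn(1, 1, c, device=device, dtype=dtype)  # one decode token
    w = torch.randn(c, 1, k, device=device, dtype=dtype)  # depthwise kernel
    b = torch.randn(c, device=device, dtype=dtype)

    batch_info_host = torch.tensor([0, 0, 1], device=device, dtype=torch.int32)  # decode-only
    seq_len = torch.zeros(1, device=device, dtype=torch.int32)
    cu_seqlen = torch.zeros(1, device=device, dtype=torch.int32)
    slot_idx = torch.tensor([0], device=device, dtype=torch.int32)
    use_initial_states = torch.zeros(1, device=device, dtype=torch.bool)
    cache_cuda = torch.randn(1, c, k - 1, device=device, dtype=dtype)
    cache_triton = cache_cuda.clone()

    x_cuda, x_triton = x.clone(), x.clone()
    torch.ops.auto_deploy.cuda_cached_causal_conv1d(
        x_cuda, w, b, batch_info_host, cu_seqlen, slot_idx, use_initial_states,
        cache_cuda, 1, k - 1, 1, c, "zeros", None,
    )
    torch.ops.auto_deploy.triton_cached_causal_conv1d(
        x_triton, w, b, batch_info_host, seq_len, cu_seqlen, slot_idx, use_initial_states,
        cache_triton, 1, k - 1, 1, c, "zeros", None,
    )
    print("max |cuda - triton| =", (x_cuda - x_triton).abs().max().item())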
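
Adding a third backend follows the same pattern as the CUDA and Triton
descriptors: subclass BaseCausalConvDescriptor and override only
get_cached_attention_op (plus get_standard_metadata_args if the kernel needs
different metadata). A minimal sketch, with "my_causal_conv" and the stub
wrapper as hypothetical placeholders:

    from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import (
        AttentionRegistry,
        MHACallable,
    )
    from tensorrt_llm._torch.auto_deploy.custom_ops.mamba.causal_conv_common import (
        BaseCausalConvDescriptor,
    )

    def my_cached_causal_conv1d_wrapper(input, *args, **kwargs):
        # Hypothetical stand-in: a real backend would call its registered
        # cached op here (cf. the CUDA/Triton wrappers), mutating `input`.
        return input

    @AttentionRegistry.register("my_causal_conv")  # hypothetical backend name
    class MyBackendCausalConv(BaseCausalConvDescriptor):
        @classmethod
        def get_cached_attention_op(cls) -> MHACallable:
            return my_cached_causal_conv1d_wrapper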