Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[None][feat] AutoDeploy: add triton backend for causal conv (#11124)
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Parent: d160439ef9
Commit: 9644f024bd
@@ -0,0 +1,104 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from typing import List

import torch
from torch._ops import OpOverloadPacket
from torch.fx import Node

from .....llmapi.llm_args import KvCacheConfig
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
    AttentionDescriptor,
    AttentionLayout,
    CausalConvResourceHandler,
    Constant,
    MHACallable,
    ResourceHandlerDict,
)


class BaseCausalConvDescriptor(AttentionDescriptor):
    """Base class for causal conv1d backends.

    Provides shared implementations for:
    - get_attention_layout
    - get_num_qkv_args
    - get_source_attention_op
    - get_standard_metadata_args
    - get_cache_initializers
    - get_constants

    Subclasses must implement:
    - get_cached_attention_op
    """

    @classmethod
    def get_attention_layout(cls) -> AttentionLayout:
        # Hidden states follow [b, s, c]
        return "bsnd"

    @classmethod
    def get_num_qkv_args(cls) -> int:
        # torch_causal_conv1d signature has 3 relevant tensor arguments
        # TODO: bias can be optional!! How to handle None bias here?
        return 3

    @classmethod
    def get_source_attention_op(cls) -> OpOverloadPacket:
        return torch.ops.auto_deploy.torch_causal_conv1d.default

    @classmethod
    @abstractmethod
    def get_cached_attention_op(cls) -> MHACallable:
        """Return the cached attention op for this backend.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    @classmethod
    def get_standard_metadata_args(cls) -> List[str]:
        return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"]

    @classmethod
    def get_cache_initializers(
        cls, source_attn_node: Node, cache_config: KvCacheConfig
    ) -> ResourceHandlerDict:
        inp_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
        w_fake: torch.Tensor = source_attn_node.args[1].meta["val"]

        in_channels = inp_fake.shape[-1]
        kernel_size = w_fake.shape[-1]

        # NOTE: cuda backend stores kernel_size - 1 elements in state.
        # CausalConvResourceHandler.state_shape = (conv_dim, d_conv - 1), so d_conv = kernel_size.
        # Ensure d_conv >= 1 (state_shape[-1] >= 0).
        conv_state_handler = CausalConvResourceHandler(
            conv_dim=in_channels,
            d_conv=max(1, kernel_size),  # state_shape[-1] = d_conv - 1 = kernel_size - 1
            dtype=cls.resolve_cache_dtype("auto", inp_fake.dtype),
        )
        return {"conv_state_cache": conv_state_handler}

    @classmethod
    def get_constants(cls, source_attn_node: Node) -> List[Constant]:
        stride, padding, dilation, groups, padding_mode = extract_op_args(
            source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode"
        )
        # None is for activation parameter, which may not exist in the source node (added by fusion later)
        return [stride, padding, dilation, groups, padding_mode, None]
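For illustration, here is a minimal sketch (not part of this commit) of how another backend could plug into BaseCausalConvDescriptor: only get_cached_attention_op needs to be provided, and registration follows the same AttentionRegistry pattern used by the CUDA and Triton descriptors below. The backend name "ref_causal_conv", the class, and the placeholder wrapper are hypothetical; the relative imports assume the sketch lives in the same package as the real backends.

# Hypothetical example, not part of this commit.
import torch

from ..attention_interface import AttentionRegistry, MHACallable
from .causal_conv_common import BaseCausalConvDescriptor


def _ref_cached_causal_conv1d_wrapper(input: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    # Placeholder: a real backend would call its cached causal-conv custom op here,
    # updating the slot-indexed conv state cache in-place, and then return `input`.
    return input


@AttentionRegistry.register("ref_causal_conv")  # hypothetical backend name
class RefBackendCausalConv(BaseCausalConvDescriptor):
    """Everything except the cached op is inherited from the base class."""

    @classmethod
    def get_cached_attention_op(cls) -> MHACallable:
        return _ref_cached_causal_conv1d_wrapper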
@@ -24,26 +24,15 @@ The flattened cached op integrates with the auto_deploy attention interface
 and updates a slot-indexed convolution state cache internally.
 """
 
-from typing import List, Optional
+from typing import Optional
 
 import torch
-from torch._ops import OpOverloadPacket
-from torch.fx import Node
 
 from tensorrt_llm._torch.modules.mamba import PAD_SLOT_ID
 from tensorrt_llm._torch.modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 
-from .....llmapi.llm_args import KvCacheConfig
-from ...utils.node_utils import extract_op_args
-from ..attention_interface import (
-    AttentionDescriptor,
-    AttentionLayout,
-    AttentionRegistry,
-    CausalConvResourceHandler,
-    Constant,
-    MHACallable,
-    ResourceHandlerDict,
-)
+from ..attention_interface import AttentionRegistry, MHACallable
+from .causal_conv_common import BaseCausalConvDescriptor
 
 
 @torch.library.custom_op("auto_deploy::cuda_cached_causal_conv1d", mutates_args={"input"})
@@ -163,54 +152,13 @@ def cuda_cached_causal_conv1d_wrapper(input: torch.Tensor, *args, **kwargs) -> t
 
 
 @AttentionRegistry.register("cuda_causal_conv")
-class CudaBackendCausalConv(AttentionDescriptor):
-    @classmethod
-    def get_attention_layout(cls) -> AttentionLayout:
-        # Hidden states follow [b, s, c]
-        return "bsnd"
+class CudaBackendCausalConv(BaseCausalConvDescriptor):
+    """CUDA-backed causal conv1d attention descriptor.
 
-    @classmethod
-    def get_num_qkv_args(cls) -> int:
-        # torch_causal_conv1d signature has 3 relevant tensor arguments
-        # TODO: bias can be optional!! How to handle None bias here?
-        return 3
-
-    @classmethod
-    def get_source_attention_op(cls) -> OpOverloadPacket:
-        return torch.ops.auto_deploy.torch_causal_conv1d.default
+    Inherits shared methods from BaseCausalConvDescriptor.
+    Only overrides get_cached_attention_op to return the CUDA wrapper.
+    """
 
     @classmethod
     def get_cached_attention_op(cls) -> MHACallable:
         return cuda_cached_causal_conv1d_wrapper
-
-    @classmethod
-    def get_standard_metadata_args(cls) -> List[str]:
-        return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"]
-
-    @classmethod
-    def get_cache_initializers(
-        cls, source_attn_node: Node, cache_config: KvCacheConfig
-    ) -> ResourceHandlerDict:
-        inp_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
-        w_fake: torch.Tensor = source_attn_node.args[1].meta["val"]
-
-        in_channels = inp_fake.shape[-1]
-        kernel_size = w_fake.shape[-1]
-
-        # NOTE: cuda backend stores kernel_size - 1 elements in state.
-        # CausalConvResourceHandler.state_shape = (conv_dim, d_conv - 1), so d_conv = kernel_size.
-        # Ensure d_conv >= 1 (state_shape[-1] >= 0).
-        conv_state_handler = CausalConvResourceHandler(
-            conv_dim=in_channels,
-            d_conv=max(1, kernel_size),  # state_shape[-1] = d_conv - 1 = kernel_size - 1
-            dtype=cls.resolve_cache_dtype("auto", inp_fake.dtype),
-        )
-        return {"conv_state_cache": conv_state_handler}
-
-    @classmethod
-    def get_constants(cls, source_attn_node: Node) -> List[Constant]:
-        stride, padding, dilation, groups, padding_mode = extract_op_args(
-            source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode"
-        )
-        # None is for activation parameter, which may not exist in the source node (added by fusion later)
-        return [stride, padding, dilation, groups, padding_mode, None]
@@ -0,0 +1,180 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Triton-backed cached causal conv1d custom ops and attention descriptor.

This mirrors `cuda_backend_causal_conv.py` but uses Triton kernels instead of CUDA:
- Prefill uses Triton `causal_conv1d_fn`
- Decode uses Triton `causal_conv1d_update`

The flattened cached op integrates with the auto_deploy attention interface
and updates a slot-indexed convolution state cache internally.
"""

from typing import List, Optional

import torch

from tensorrt_llm._torch.modules.mamba import PAD_SLOT_ID
from tensorrt_llm._torch.modules.mamba.causal_conv1d_triton import (
    causal_conv1d_fn,
    causal_conv1d_update,
)

from ..attention_interface import AttentionRegistry, MHACallable
from .causal_conv_common import BaseCausalConvDescriptor


@torch.library.custom_op("auto_deploy::triton_cached_causal_conv1d", mutates_args={"input"})
def _triton_cached_causal_conv1d(
    # INPUTS (dense but may be flattened across sequences)
    input: torch.Tensor,  # [b, s, c_in]
    weight: torch.Tensor,  # [c_out, c_in/groups, k] but we expect depthwise use: [c_in, k]
    bias: Optional[torch.Tensor],
    # STANDARD METADATA
    batch_info_host: torch.Tensor,
    seq_len: torch.Tensor,
    cu_seqlen: torch.Tensor,
    slot_idx: torch.Tensor,
    use_initial_states: torch.Tensor,
    # EXTRA METADATA
    #
    # CACHES
    conv_state_cache: torch.Tensor,  # [max_batch_size, c_in, k-1]
    # CONSTANTS
    stride: int,
    padding: int,
    dilation: int,
    groups: int,
    padding_mode: str,
    activation: Optional[str],
) -> None:
    """Flattened cached causal conv that respects slot-indexed state caches (Triton backend).

    Supports two layouts from the attention interface:
    - Generate-only: input is [b, 1, c_in]. We'll gather caches using slot_idx[:b].
    - Flattened context/mixed: input is [1, total_s, c_in] and seq_len/seq_start
      describe per-sequence segments. We'll process each segment and scatter final states to caches.

    NOTE: This op modifies `input` in-place.
    """
    b, s = input.shape[:2]

    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
    num_seq = num_prefill + num_decode
    num_total_tokens = num_prefill_tokens + num_decode

    # Flatten tokens
    bs = b * s
    inp_flat = input.reshape(bs, *input.shape[2:])  # [total_s, C_in]

    # Prepare weight as [dim, width] (depthwise)
    if weight.ndim == 3:
        assert weight.shape[-2] == 1
        w2d = weight.squeeze(-2)
    else:
        w2d = weight

    # PREFILL: concatenate all prefill tokens and run one varlen forward
    if num_prefill > 0:
        # x_varlen: (dim, cu_seq_len)
        x_varlen = inp_flat[:num_prefill_tokens].transpose(0, 1).contiguous()

        prefill_cu_seqlen = cu_seqlen[: num_prefill + 1]
        seq_lens_cpu = seq_len[:num_prefill].tolist()

        # Run varlen conv; updates conv_state_cache in-place per cache_indices
        # Note: Triton kernel returns a new tensor (not in-place like CUDA)
        y_varlen = causal_conv1d_fn(
            x_varlen,
            w2d,
            bias,
            conv_state_cache,
            prefill_cu_seqlen,
            seq_lens_cpu,
            cache_indices=slot_idx[:num_prefill].to(torch.int32),
            has_initial_state=use_initial_states[:num_prefill],
            activation=activation,
            pad_slot_id=PAD_SLOT_ID,
        )  # (dim, total_prefill_tokens)
        # Scatter outputs back to input buffer
        inp_flat[:num_prefill_tokens] = y_varlen.transpose(0, 1)

    # DECODE: batch update for single-token sequences
    if num_decode > 0:
        x_decode = inp_flat[num_prefill_tokens:num_total_tokens]  # [num_decode, C_in]

        # Note: Triton causal_conv1d_update returns a new tensor (not in-place like CUDA version)
        # so we need to capture the output and write it back
        y_decode = causal_conv1d_update(
            x_decode,  # [batch, dim]
            conv_state_cache,
            w2d,
            bias,
            activation=activation,
            cache_seqlens=None,
            conv_state_indices=slot_idx[num_prefill:num_seq].to(torch.int32),
            pad_slot_id=PAD_SLOT_ID,
        )
        inp_flat[num_prefill_tokens:num_total_tokens] = y_decode


@_triton_cached_causal_conv1d.register_fake
def _triton_cached_causal_conv1d_fake(
    # INPUTS (dense but may be flattened across sequences)
    input: torch.Tensor,  # [b, s, c_in]
    weight: torch.Tensor,  # [c_out, c_in/groups, k] but we expect depthwise use: [c_in, k]
    bias: Optional[torch.Tensor],
    # STANDARD METADATA
    batch_info_host: torch.Tensor,
    seq_len: torch.Tensor,
    cu_seqlen: torch.Tensor,
    slot_idx: torch.Tensor,
    use_initial_states: torch.Tensor,
    # EXTRA METADATA
    #
    # CACHES
    conv_state_cache: torch.Tensor,  # [max_batch_size, c_in, k-1]
    # CONSTANTS
    stride: int,
    padding: int,
    dilation: int,
    groups: int,
    padding_mode: str,
    activation: Optional[str],
) -> None:
    pass


def triton_cached_causal_conv1d_wrapper(input: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    torch.ops.auto_deploy.triton_cached_causal_conv1d(input, *args, **kwargs)
    return input


@AttentionRegistry.register("triton_causal_conv")
class TritonBackendCausalConv(BaseCausalConvDescriptor):
    """Triton-backed causal conv1d attention descriptor.

    Inherits shared methods from BaseCausalConvDescriptor.
    Overrides get_standard_metadata_args to include seq_len (used directly by Triton kernel).
    """

    @classmethod
    def get_standard_metadata_args(cls) -> List[str]:
        return ["batch_info_host", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"]

    @classmethod
    def get_cached_attention_op(cls) -> MHACallable:
        return triton_cached_causal_conv1d_wrapper
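For intuition, here is a small pure-PyTorch sketch (not part of this commit, and not the actual Triton or CUDA kernel) of the single-token decode update that `causal_conv1d_update` performs for a depthwise causal conv: the cache holds the last k-1 inputs per slot, the new token completes a window of k, and the cached window then shifts by one. Tensor shapes follow the comments in the op above; the helper name is made up.

import torch
import torch.nn.functional as F


def reference_conv_decode_update(x, conv_state, weight, bias=None, activation=None):
    """Reference depthwise causal-conv decode step (illustrative only).

    x:          [batch, dim]       one new token per sequence
    conv_state: [batch, dim, k-1]  last k-1 inputs per sequence (updated in-place)
    weight:     [dim, k]           depthwise filter
    """
    # Complete a window of k inputs: cached k-1 values plus the new token.
    window = torch.cat([conv_state, x.unsqueeze(-1)], dim=-1)  # [batch, dim, k]
    out = (window * weight.unsqueeze(0)).sum(dim=-1)  # depthwise dot product -> [batch, dim]
    if bias is not None:
        out = out + bias
    if activation == "silu":
        out = F.silu(out)
    # Shift the cached window by one token for the next decode step.
    conv_state.copy_(window[..., 1:])
    return out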
@@ -0,0 +1,386 @@
"""Unit tests for Triton-backed cached causal conv1d custom ops.

Covers:
- Generate-only path comparing Triton vs CUDA backend
- Context (flattened) path comparing Triton vs CUDA backend
- Ensures numerical consistency between backends
"""

import pytest
import torch

import tensorrt_llm._torch.auto_deploy  # noqa: F401


def _random_params_depthwise(device, dtype, batch, seq, channels, k):
    x = torch.randn(batch, seq, channels, device=device, dtype=dtype)
    # Depthwise: out_channels == in_channels, groups == channels, weight [C, 1, K]
    weight = torch.randn(channels, 1, k, device=device, dtype=dtype)
    bias = torch.randn(channels, device=device, dtype=dtype)
    stride = 1
    padding = k - 1
    dilation = 1
    groups = channels
    padding_mode = "zeros"
    return x, weight, bias, stride, padding, dilation, groups, padding_mode


@pytest.fixture
def conv_env():
    device = "cuda"
    dtype = torch.float16
    atol = 5e-2
    rtol = 5e-2
    torch.manual_seed(123)
    torch.cuda.empty_cache()
    return {"device": device, "dtype": dtype, "atol": atol, "rtol": rtol}


def test_generate_only_triton_vs_cuda(conv_env):
    """Test that Triton backend matches CUDA backend for generate-only (decode) path."""
    device = conv_env["device"]
    dtype = conv_env["dtype"]

    batch, seq = 1, 1
    c, k = 2, 3
    x, w, b, s, p, d, g, pm = _random_params_depthwise(device, dtype, batch, seq, c, k)

    # Slot mapping with arbitrary order within max_batch_size
    max_batch_size = 2
    slot_idx = torch.tensor([0], device=device, dtype=torch.int32)
    # Cache holds K-1 entries (TRT update kernel contract)
    conv_state_cache_cuda = torch.randn(
        max_batch_size,
        c,
        k - 1,
        device=device,
        dtype=dtype,
    )
    # Clone for Triton backend
    conv_state_cache_triton = conv_state_cache_cuda.clone()

    # Metadata
    cu_seqlen = torch.zeros(batch, device=device, dtype=torch.int32)
    seq_len = torch.zeros(batch, device=device, dtype=torch.int32)
    use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool)
    # batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
    # For generate-only: num_decode = batch, num_prefill = 0
    batch_info_host = torch.tensor([0, 0, batch], device=device, dtype=torch.int32)

    # Clone inputs for each backend
    x_cuda = x.clone()
    x_triton = x.clone()

    # Run CUDA backend
    torch.ops.auto_deploy.cuda_cached_causal_conv1d(
        x_cuda,
        w,
        b,
        batch_info_host,
        cu_seqlen,
        slot_idx,
        use_initial_states,
        conv_state_cache_cuda,
        s,
        p,
        d,
        g,
        pm,
        None,
    )

    # Run Triton backend (includes seq_len parameter)
    torch.ops.auto_deploy.triton_cached_causal_conv1d(
        x_triton,
        w,
        b,
        batch_info_host,
        seq_len,
        cu_seqlen,
        slot_idx,
        use_initial_states,
        conv_state_cache_triton,
        s,
        p,
        d,
        g,
        pm,
        None,
    )

    # Compare outputs
    assert x_cuda.shape == x_triton.shape
    assert torch.allclose(x_cuda, x_triton, atol=conv_env["atol"], rtol=conv_env["rtol"]), (
        f"Output mismatch: max diff = {(x_cuda - x_triton).abs().max()}"
    )

    # Compare cache states
    assert torch.allclose(
        conv_state_cache_cuda, conv_state_cache_triton, atol=conv_env["atol"], rtol=conv_env["rtol"]
    ), (
        f"Cache state mismatch: max diff = {(conv_state_cache_cuda - conv_state_cache_triton).abs().max()}"
    )


def test_context_flattened_triton_vs_cuda(conv_env):
    """Test that Triton backend matches CUDA backend for context (prefill) path."""
    device = conv_env["device"]
    dtype = conv_env["dtype"]

    # Two short sequences with lengths 2 and 1, flattened to [1,3]
    lens = [2, 1]
    total = sum(lens)
    batch, seq = 1, total
    c, k = 2, 3
    x, w, b, s, p, d, g, pm = _random_params_depthwise(device, dtype, batch, seq, c, k)

    max_batch_size = 2
    slot_idx = torch.tensor([1, 0], device=device, dtype=torch.int32)
    conv_state_cache_cuda = torch.randn(
        max_batch_size,
        c,
        k - 1,
        device=device,
        dtype=dtype,
    )
    # Clone for Triton backend
    conv_state_cache_triton = conv_state_cache_cuda.clone()

    # batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
    num_prefill = len(lens)
    batch_info_host = torch.tensor([num_prefill, total, 0], device=device, dtype=torch.int32)
    cu_seqlen = torch.tensor([0, lens[0], total], device=device, dtype=torch.int32)
    seq_len = torch.tensor(lens, device=device, dtype=torch.int32)
    use_initial_states = torch.zeros(num_prefill, device=device, dtype=torch.bool)

    # Clone inputs for each backend
    x_cuda = x.clone()
    x_triton = x.clone()

    # Run CUDA backend
    torch.ops.auto_deploy.cuda_cached_causal_conv1d(
        x_cuda,
        w,
        b,
        batch_info_host,
        cu_seqlen,
        slot_idx,
        use_initial_states,
        conv_state_cache_cuda,
        s,
        p,
        d,
        g,
        pm,
        None,
    )

    # Run Triton backend (includes seq_len parameter)
    torch.ops.auto_deploy.triton_cached_causal_conv1d(
        x_triton,
        w,
        b,
        batch_info_host,
        seq_len,
        cu_seqlen,
        slot_idx,
        use_initial_states,
        conv_state_cache_triton,
        s,
        p,
        d,
        g,
        pm,
        None,
    )

    # Compare outputs
    assert x_cuda.shape == x_triton.shape
    assert torch.allclose(x_cuda, x_triton, atol=conv_env["atol"], rtol=conv_env["rtol"]), (
        f"Output mismatch: max diff = {(x_cuda - x_triton).abs().max()}"
    )

    # Compare cache states
    assert torch.allclose(
        conv_state_cache_cuda, conv_state_cache_triton, atol=conv_env["atol"], rtol=conv_env["rtol"]
    ), (
        f"Cache state mismatch: max diff = {(conv_state_cache_cuda - conv_state_cache_triton).abs().max()}"
    )


def test_mixed_prefill_decode_triton_vs_cuda(conv_env):
    """Test that Triton backend matches CUDA backend for mixed prefill + decode batch."""
    device = conv_env["device"]
    dtype = conv_env["dtype"]

    # Mixed batch: 1 prefill sequence (len=3) + 2 decode tokens
    prefill_lens = [3]
    num_prefill = len(prefill_lens)
    num_prefill_tokens = sum(prefill_lens)
    num_decode = 2

    total_tokens = num_prefill_tokens + num_decode
    batch, seq = 1, total_tokens
    c, k = 4, 3
    x, w, b, s, p, d, g, pm = _random_params_depthwise(device, dtype, batch, seq, c, k)

    max_batch_size = 4
    # Slot indices: first for prefill, then for decode
    slot_idx = torch.tensor([0, 1, 2], device=device, dtype=torch.int32)
    conv_state_cache_cuda = torch.randn(
        max_batch_size,
        c,
        k - 1,
        device=device,
        dtype=dtype,
    )
    conv_state_cache_triton = conv_state_cache_cuda.clone()

    # batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
    batch_info_host = torch.tensor(
        [num_prefill, num_prefill_tokens, num_decode], device=device, dtype=torch.int32
    )
    cu_seqlen = torch.tensor([0, prefill_lens[0]], device=device, dtype=torch.int32)
    seq_len = torch.tensor(prefill_lens, device=device, dtype=torch.int32)
    use_initial_states = torch.zeros(num_prefill, device=device, dtype=torch.bool)

    # Clone inputs
    x_cuda = x.clone()
    x_triton = x.clone()

    # Run CUDA backend
    torch.ops.auto_deploy.cuda_cached_causal_conv1d(
        x_cuda,
        w,
        b,
        batch_info_host,
        cu_seqlen,
        slot_idx,
        use_initial_states,
        conv_state_cache_cuda,
        s,
        p,
        d,
        g,
        pm,
        None,
    )

    # Run Triton backend
    torch.ops.auto_deploy.triton_cached_causal_conv1d(
        x_triton,
        w,
        b,
        batch_info_host,
        seq_len,
        cu_seqlen,
        slot_idx,
        use_initial_states,
        conv_state_cache_triton,
        s,
        p,
        d,
        g,
        pm,
        None,
    )

    # Compare outputs
    assert x_cuda.shape == x_triton.shape
    assert torch.allclose(x_cuda, x_triton, atol=conv_env["atol"], rtol=conv_env["rtol"]), (
        f"Output mismatch: max diff = {(x_cuda - x_triton).abs().max()}"
    )

    # Compare cache states
    assert torch.allclose(
        conv_state_cache_cuda, conv_state_cache_triton, atol=conv_env["atol"], rtol=conv_env["rtol"]
    ), (
        f"Cache state mismatch: max diff = {(conv_state_cache_cuda - conv_state_cache_triton).abs().max()}"
    )


def test_larger_batch_triton_vs_cuda(conv_env):
    """Test with larger batch and longer sequences."""
    device = conv_env["device"]
    dtype = conv_env["dtype"]

    # Multiple prefill sequences
    lens = [5, 8, 3, 10]
    total = sum(lens)
    batch, seq = 1, total
    c, k = 16, 4
    x, w, b, s, p, d, g, pm = _random_params_depthwise(device, dtype, batch, seq, c, k)

    max_batch_size = 8
    slot_idx = torch.tensor([3, 1, 5, 0], device=device, dtype=torch.int32)
    conv_state_cache_cuda = torch.randn(
        max_batch_size,
        c,
        k - 1,
        device=device,
        dtype=dtype,
    )
    conv_state_cache_triton = conv_state_cache_cuda.clone()

    num_prefill = len(lens)
    batch_info_host = torch.tensor([num_prefill, total, 0], device=device, dtype=torch.int32)

    # Build cumulative sequence lengths
    cu_seqlen_list = [0]
    for ln in lens:
        cu_seqlen_list.append(cu_seqlen_list[-1] + ln)
    cu_seqlen = torch.tensor(cu_seqlen_list, device=device, dtype=torch.int32)
    seq_len = torch.tensor(lens, device=device, dtype=torch.int32)
    use_initial_states = torch.zeros(num_prefill, device=device, dtype=torch.bool)

    x_cuda = x.clone()
    x_triton = x.clone()

    # Run CUDA backend
    torch.ops.auto_deploy.cuda_cached_causal_conv1d(
        x_cuda,
        w,
        b,
        batch_info_host,
        cu_seqlen,
        slot_idx,
        use_initial_states,
        conv_state_cache_cuda,
        s,
        p,
        d,
        g,
        pm,
        None,
    )

    # Run Triton backend
    torch.ops.auto_deploy.triton_cached_causal_conv1d(
        x_triton,
        w,
        b,
        batch_info_host,
        seq_len,
        cu_seqlen,
        slot_idx,
        use_initial_states,
        conv_state_cache_triton,
        s,
        p,
        d,
        g,
        pm,
        None,
    )

    # Compare outputs
    assert torch.allclose(x_cuda, x_triton, atol=conv_env["atol"], rtol=conv_env["rtol"]), (
        f"Output mismatch: max diff = {(x_cuda - x_triton).abs().max()}"
    )

    # Compare cache states
    assert torch.allclose(
        conv_state_cache_cuda, conv_state_cache_triton, atol=conv_env["atol"], rtol=conv_env["rtol"]
    ), (
        f"Cache state mismatch: max diff = {(conv_state_cache_cuda - conv_state_cache_triton).abs().max()}"
    )
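The tests above all build the standard metadata by hand. A small helper like the following (a convenience sketch, not part of this commit; the function name is made up) captures the layout they use: batch_info_host is [num_prefill, num_prefill_tokens, num_decode], cu_seqlen holds prefix sums over the prefill lengths starting at 0, seq_len holds the prefill lengths, and slot_idx lists one cache slot per sequence with prefill slots first, then decode slots.

from typing import List, Tuple

import torch


def build_causal_conv_metadata(
    prefill_lens: List[int], num_decode: int, slots: List[int], device: str = "cuda"
) -> Tuple[torch.Tensor, ...]:
    """Build (batch_info_host, seq_len, cu_seqlen, slot_idx, use_initial_states). Sketch only."""
    num_prefill = len(prefill_lens)
    num_prefill_tokens = sum(prefill_lens)
    batch_info_host = torch.tensor(
        [num_prefill, num_prefill_tokens, num_decode], device=device, dtype=torch.int32
    )
    # Cumulative prefill sequence lengths, starting at 0.
    cu = [0]
    for ln in prefill_lens:
        cu.append(cu[-1] + ln)
    cu_seqlen = torch.tensor(cu, device=device, dtype=torch.int32)
    seq_len = torch.tensor(prefill_lens, device=device, dtype=torch.int32)
    # One cache slot per sequence: prefill slots first, then decode slots.
    slot_idx = torch.tensor(slots, device=device, dtype=torch.int32)
    use_initial_states = torch.zeros(num_prefill, device=device, dtype=torch.bool)
    return batch_info_host, seq_len, cu_seqlen, slot_idx, use_initial_states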