[#11318][infra] AutoDeploy: Add fused rope kernel - triton_rope_on_interleaved_qk_inputs (#11327)

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Bala Marimuthu 2026-02-17 13:24:18 -05:00 committed by GitHub
parent ab941afa2e
commit 6157f30b06
6 changed files with 717 additions and 4 deletions

View File

@@ -4,4 +4,131 @@ All AutoDeploy custom operators follow the following naming convention:
`torch.ops.auto_deploy.<kernel_backend>_<op_category>_<op_name>`
The table below lists the operators ordered by their backend.
The table below lists the operators grouped by category.
### Available Custom Operators
#### Attention
| Operator Name | Description |
|--------------|-------------|
| `torch.ops.auto_deploy.torch_attention` | Grouped SDPA implementation with `bsnd` and `bnsd` layout supported |
| `torch.ops.auto_deploy.torch_attention_sdpa` | Standard scaled dot-product attention (SDPA) implementation |
| `torch.ops.auto_deploy.torch_attention_repeat_kv` | KV repetition for grouped-query attention |
| `torch.ops.auto_deploy.torch_cached_attention_with_cache` | PyTorch backend attention with KV cache management |
| `torch.ops.auto_deploy.flashinfer_attention_mha_with_cache` | FlashInfer multi-head attention with KV cache support |
| `torch.ops.auto_deploy.flashinfer_attention_prepare_metadata` | FlashInfer attention metadata preparation |
| `torch.ops.auto_deploy.triton_attention_flattened_mha_with_cache` | Triton flattened MHA with cache |
| `torch.ops.auto_deploy.torch_onnx_attention_plugin` | Fused attention with RoPE placeholder for ONNX export |
| `torch.ops.auto_deploy.torch_onnx_gather_nd` | N-dimensional gather operation for ONNX export |
#### MLA (Multi-head Latent Attention)
| Operator Name | Description |
|--------------|-------------|
| `torch.ops.auto_deploy.torch_mla` | Multi-head Latent Attention (MLA) implementation |
| `torch.ops.auto_deploy.torch_cached_mla_with_cache` | PyTorch backend cached MLA with KV cache |
| `torch.ops.auto_deploy.flashinfer_mla_with_cache` | FlashInfer MLA with cache |
| `torch.ops.auto_deploy.flashinfer_mla_prepare_metadata` | FlashInfer MLA metadata preparation |
#### RoPE (Rotary Position Embedding)
| Operator Name | Description |
|--------------|-------------|
| `torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin` | RoPE with explicit cosine/sine |
| `torch.ops.auto_deploy.torch_rope_with_complex_freqs` | RoPE with complex frequencies |
| `torch.ops.auto_deploy.torch_rope_with_qk_interleaving` | RoPE with QK interleaving |
| `torch.ops.auto_deploy.triton_rope_with_input_pos` | Triton RoPE with input positions |
| `torch.ops.auto_deploy.triton_rope_on_flattened_inputs` | Triton RoPE on flattened inputs |
| `torch.ops.auto_deploy.triton_rope_on_interleaved_qk_inputs` | Triton fused RoPE on interleaved QK inputs (position lookup + de-interleave + RoPE) |
| `torch.ops.auto_deploy.flashinfer_rope` | FlashInfer RoPE implementation |
#### Linear
| Operator Name | Description |
|--------------|-------------|
| `torch.ops.auto_deploy.torch_linear_simple` | Simple linear layer wrapper (avoids view ops in export graph) |
| `torch.ops.auto_deploy.torch_moe_router` | MoE router: linear projection + top-k + softmax + scatter |
#### MoE (Mixture of Experts)
| Operator Name | Description |
|--------------|-------------|
| `torch.ops.auto_deploy.torch_moe` | Mixture of Experts implementation (PyTorch backend) |
| `torch.ops.auto_deploy.torch_moe_fused` | Fused Mixture of Experts implementation (PyTorch backend) |
| `torch.ops.auto_deploy.torch_moe_dense_mlp` | Dense MLP implementation for MoE (PyTorch backend) |
| `torch.ops.auto_deploy.torch_quant_fp8_moe` | FP8 quantized MoE (PyTorch backend) |
| `torch.ops.auto_deploy.torch_quant_nvfp4_moe` | NVFP4 quantized MoE (PyTorch backend) |
| `torch.ops.auto_deploy.triton_moe_fused` | Fused MoE (Triton backend) |
| `torch.ops.auto_deploy.triton_quant_fp8_moe` | FP8 quantized MoE (Triton backend) |
| `torch.ops.auto_deploy.triton_mxfp4_moe` | MXFP4 MoE with triton-kernels matmul_ogs |
| `torch.ops.auto_deploy.triton_mxfp4_moe_ep` | MXFP4 MoE with Expert Parallelism (triton-kernels) |
| `torch.ops.auto_deploy.trtllm_moe_fused` | Fused MoE (TRT-LLM backend) |
| `torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused` | FP8 quantized fused MoE (TRT-LLM backend) |
| `torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused` | NVFP4 quantized fused MoE (TRT-LLM backend) |
#### Quantization
| Operator Name | Description |
|--------------|-------------|
| `torch.ops.auto_deploy.torch_quant_fn` | Generic quantization function that scales, rounds, and clamps input values |
| `torch.ops.auto_deploy.torch_quant_fp8_linear` | FP8 quantized linear layer (PyTorch backend) |
| `torch.ops.auto_deploy.torch_quant_nvfp4_linear` | NVFP4 quantized linear layer (PyTorch backend) |
| `torch.ops.auto_deploy.torch_quant_fp8_bmm` | FP8 quantized batch matrix multiply (PyTorch backend) |
| `torch.ops.auto_deploy.trtllm_quant_fp8_linear` | FP8 quantized linear layer (TRT-LLM backend) |
| `torch.ops.auto_deploy.torch_fake_quant_fp8_linear` | Fake FP8 quantized linear (for calibration/simulation) |
| `torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear` | Fake NVFP4 quantized linear (for calibration/simulation) |
| `torch.ops.auto_deploy.torch_fake_quant_int4_linear` | Fake INT4 quantized linear (for calibration/simulation) |
| `torch.ops.auto_deploy.torch_fake_quant_int4_gptq_linear` | Fake INT4 GPTQ quantized linear (for calibration/simulation) |
#### Normalization
| Operator Name | Description |
|--------------|-------------|
| `torch.ops.auto_deploy.torch_rmsnorm` | RMSNorm (PyTorch backend) |
| `torch.ops.auto_deploy.torch_rmsnorm_gated` | Gated RMSNorm with optional SiLU gating (PyTorch backend) |
| `torch.ops.auto_deploy.triton_rms_norm` | RMSNorm (Triton backend) |
| `torch.ops.auto_deploy.triton_rmsnorm_gated` | Gated RMSNorm with optional SiLU gating (Triton backend) |
| `torch.ops.auto_deploy.flashinfer_rms_norm` | RMSNorm (FlashInfer backend) |
| `torch.ops.auto_deploy.flashinfer_fused_add_rms_norm_inplace` | Fused residual add + RMSNorm in-place (FlashInfer backend) |
| `torch.ops.auto_deploy.sharded_rmsnorm` | RMSNorm for tensor-parallel sharded activations (uses all-reduce) |
| `torch.ops.auto_deploy.torch_l2norm` | L2 normalization (PyTorch backend) |
| `torch.ops.auto_deploy.fla_l2norm` | L2 normalization (FLA Triton kernel backend) |
#### Mamba (SSM + Causal Conv)
| Operator Name | Description |
|--------------|-------------|
| `torch.ops.auto_deploy.torch_ssm` | State Space Model (SSM) computation (PyTorch backend) |
| `torch.ops.auto_deploy.torch_cached_ssm` | Cached SSM with state management (PyTorch backend) |
| `torch.ops.auto_deploy.triton_cached_ssm` | Cached SSM with state management (Triton backend) |
| `torch.ops.auto_deploy.flashinfer_cached_ssm` | Cached SSM with state management (FlashInfer backend) |
| `torch.ops.auto_deploy.mamba_ssm_prepare_metadata` | Mamba SSM metadata preparation (chunk indices, offsets, seq_idx) |
| `torch.ops.auto_deploy.torch_causal_conv1d` | Causal 1D convolution (PyTorch backend) |
| `torch.ops.auto_deploy.torch_cached_causal_conv1d` | Cached causal 1D convolution (PyTorch backend) |
| `torch.ops.auto_deploy.triton_cached_causal_conv1d` | Cached causal 1D convolution (Triton backend) |
| `torch.ops.auto_deploy.cuda_cached_causal_conv1d` | Cached causal 1D convolution (CUDA backend) |
#### FLA (Flash Linear Attention)
| Operator Name | Description |
|--------------|-------------|
| `torch.ops.auto_deploy.fla_delta_rule` | FLA chunked delta rule computation |
| `torch.ops.auto_deploy.fla_cached_delta_rule` | FLA cached delta rule with state management |
#### Distributed
| Operator Name | Description |
|--------------|-------------|
| `torch.ops.auto_deploy.torch_dist_all_gather` | All-gather (PyTorch backend, demollm mode) |
| `torch.ops.auto_deploy.torch_dist_all_reduce` | All-reduce (PyTorch backend, demollm mode) |
| `torch.ops.auto_deploy.trtllm_dist_all_gather` | All-gather (TRT-LLM backend, MPI mode) |
| `torch.ops.auto_deploy.trtllm_dist_all_reduce` | All-reduce (TRT-LLM backend, MPI mode) |
| `torch.ops.dist.trtllm_fused_allreduce_residual_rmsnorm` | Fused all-reduce + residual add + RMSNorm (TRT-LLM backend, MPI mode) |
#### Utilities
| Operator Name | Description |
|--------------|-------------|
| `torch.ops.auto_deploy.triton_utils_fused_gather_scatter` | Triton fused gather + scatter for overlap scheduling input_ids reordering |
| `torch.ops.auto_deploy.gather_logits_before_lm_head` | Gather hidden states using logits indices before LM head |

View File

@@ -12,11 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import triton
from .triton_rope_kernel import rope_fwd_flattened_kernel, rope_fwd_kernel
from .triton_rope_kernel import (
rope_fwd_flattened_kernel,
rope_fwd_interleaved_kernel,
rope_fwd_kernel,
)
@torch.library.custom_op("auto_deploy::triton_rope_with_input_pos", mutates_args=())
@@ -147,3 +152,105 @@ def apply_rope_on_flattened_inputs(
@apply_rope_on_flattened_inputs.register_fake
def apply_rope_on_flattened_inputs_fake(x, freqs_cis, input_pos, seq_lens, seq_start_indices):
return torch.empty_like(x)
@torch.library.custom_op("auto_deploy::triton_rope_on_interleaved_qk_inputs", mutates_args=())
def apply_rope_on_interleaved_qk_inputs(
q: torch.Tensor,
k: torch.Tensor,
cos_cache: torch.Tensor,
sin_cache: torch.Tensor,
position_ids: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Fused RoPE for DeepSeek-style interleaved Q/K inputs.
This kernel fuses:
1. Position ID lookup for cos/sin from cache
2. De-interleaving of Q/K
3. RoPE application
Args:
q: Query tensor with interleaved layout [B, S, H_Q, D]
Input: [q0_r, q0_i, q1_r, q1_i, ...]
k: Key tensor with interleaved layout [B, S, H_K, D]
Typically H_K=1 for MQA
cos_cache: Cosine cache [max_seq_len, D]
sin_cache: Sine cache [max_seq_len, D]
position_ids: Position indices [B, S]
Returns:
Tuple of (q_rotated, k_rotated) with layout [B, S, H, D]
Output: [y0, y1, ..., y_{D/2-1}, y_{D/2}, ..., y_{D-1}]
where y_first = a*cos - b*sin, y_second = b*cos + a*sin
"""
assert q.dim() == 4, f"Q must be 4D [B, S, H, D], got {q.dim()}D"
assert k.dim() == 4, f"K must be 4D [B, S, H, D], got {k.dim()}D"
assert q.shape[-1] % 2 == 0, "Head dimension must be even"
assert q.shape[-1] == k.shape[-1], "Q and K must have same head dimension"
assert cos_cache.shape == sin_cache.shape, "cos and sin cache must have same shape"
B, S, H_Q, D = q.shape
_, _, H_K, _ = k.shape
assert k.shape[0] == B and k.shape[1] == S, "Q and K must have same batch and seq dims"
assert position_ids.shape == (B, S), f"position_ids must be [B, S], got {position_ids.shape}"
assert H_Q >= H_K, f"H_Q ({H_Q}) must be >= H_K ({H_K}) for grid sizing"
# Allocate contiguous outputs
# The kernel computes contiguous strides internally for output writes
q_out = torch.empty_like(q)
k_out = torch.empty_like(k)
# Block sizes
BLOCK_SIZE_H = 32
BLOCK_SIZE_S = min(triton.next_power_of_2(S), 32)
# Grid: (B, cdiv(H_Q, BLOCK_SIZE_H), cdiv(S, BLOCK_SIZE_S))
# H_Q >= H_K is enforced above; K heads are masked within each block
grid = (
B,
triton.cdiv(H_Q, BLOCK_SIZE_H),
triton.cdiv(S, BLOCK_SIZE_S),
)
rope_fwd_interleaved_kernel[grid](
q,
k,
cos_cache,
sin_cache,
position_ids,
q_out,
k_out,
B,
S,
H_Q,
H_K,
D,
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
k.stride(0),
k.stride(1),
k.stride(2),
k.stride(3),
position_ids.stride(0),
position_ids.stride(1),
cos_cache.stride(0),
cos_cache.stride(1),
BLOCK_SIZE_H=BLOCK_SIZE_H,
BLOCK_SIZE_S=BLOCK_SIZE_S,
)
return q_out, k_out
@apply_rope_on_interleaved_qk_inputs.register_fake
def apply_rope_on_interleaved_qk_inputs_fake(
q: torch.Tensor,
k: torch.Tensor,
cos_cache: torch.Tensor,
sin_cache: torch.Tensor,
position_ids: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(q), torch.empty_like(k)
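A minimal usage sketch for the new op, assuming a CUDA device and a DeepSeek-style cos/sin cache built the same way as in the unit tests; shapes follow the docstring above:

```python
import torch
from tensorrt_llm._torch.auto_deploy.custom_ops.rope import triton_rope  # noqa: F401  # registers the op

B, S, H_Q, H_K, D, MAX_SEQ = 2, 16, 20, 1, 128, 1024
q = torch.randn(B, S, H_Q, D, device="cuda", dtype=torch.bfloat16)  # interleaved layout
k = torch.randn(B, S, H_K, D, device="cuda", dtype=torch.bfloat16)  # H_K=1 -> MQA

# DeepSeek-style caches of shape [MAX_SEQ, D], frequencies repeated across both halves.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
freqs = torch.arange(MAX_SEQ, dtype=torch.float32)[:, None] * inv_freq[None, :]
emb = torch.cat((freqs, freqs), dim=-1)
cos_cache = emb.cos().to("cuda", torch.bfloat16)
sin_cache = emb.sin().to("cuda", torch.bfloat16)

position_ids = torch.arange(S, device="cuda").unsqueeze(0).expand(B, -1)  # [B, S]

q_rot, k_rot = torch.ops.auto_deploy.triton_rope_on_interleaved_qk_inputs(
    q, k, cos_cache, sin_cache, position_ids
)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```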

View File

@@ -17,6 +17,191 @@ import triton
import triton.language as tl
@triton.jit
def rope_fwd_interleaved_kernel(
q_ptr, # [B, S, H_Q, D] input Q (interleaved layout)
k_ptr, # [B, S, H_K, D] input K (interleaved layout)
cos_cache_ptr, # [max_seq_len, D] cos cache
sin_cache_ptr, # [max_seq_len, D] sin cache
position_ids_ptr, # [B, S] position IDs
q_out_ptr, # [B, S, H_Q, D] output Q
k_out_ptr, # [B, S, H_K, D] output K
B, # batch size
S, # sequence length
H_Q, # number of Q heads
H_K, # number of K heads (typically 1 for MQA)
D: tl.constexpr, # head dimension
stride_qb, # Q batch stride
stride_qs, # Q seq stride
stride_qh, # Q head stride
stride_qd, # Q dim stride
stride_kb, # K batch stride
stride_ks, # K seq stride
stride_kh, # K head stride
stride_kd, # K dim stride
stride_pos_b, # position_ids batch stride
stride_pos_s, # position_ids seq stride
stride_cache_s, # cache seq stride
stride_cache_d, # cache dim stride
BLOCK_SIZE_H: tl.constexpr,
BLOCK_SIZE_S: tl.constexpr,
):
"""
Fused RoPE kernel for DeepSeek-style interleaved Q/K inputs.
Input layout (interleaved): [q0_r, q0_i, q1_r, q1_i, ...]
After de-interleave: [q0_r, q1_r, ..., q0_i, q1_i, ...]
This kernel:
1. Looks up cos/sin from the caches using the position IDs
2. Reads even and odd indices directly from the interleaved input using strided loads
3. Applies RoPE directly to the separated a/b values: y_first = a*cos - b*sin, y_second = b*cos + a*sin
4. Writes the output in a contiguous "half-split" layout: first half and second half are stored separately
Grid: (B, cdiv(H_Q, BLOCK_SIZE_H), cdiv(S, BLOCK_SIZE_S))
"""
D2: tl.constexpr = D // 2
D2_PADDED: tl.constexpr = triton.next_power_of_2(D2)
# Program IDs
batch_id = tl.program_id(0)
head_block_id = tl.program_id(1)
seq_block_id = tl.program_id(2)
# Head offsets and mask
head_offsets = head_block_id * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
head_mask = head_offsets < H_Q
# Sequence offsets and mask
seq_offsets = seq_block_id * BLOCK_SIZE_S + tl.arange(0, BLOCK_SIZE_S)
seq_mask = seq_offsets < S
# Dimension offsets for half the head dim (we read pairs)
dim_offsets = tl.arange(0, D2_PADDED)
dim_mask = dim_offsets < D2
# ========== LOAD POSITION IDS ==========
# position_ids: [B, S]
pos_ptr = position_ids_ptr + batch_id * stride_pos_b + seq_offsets * stride_pos_s
pos_ids = tl.load(pos_ptr, mask=seq_mask, other=0) # [BLOCK_SIZE_S]
# ========== LOAD COS/SIN FROM CACHE ==========
# cos_cache, sin_cache: [max_seq_len, D]
# For each position, load the corresponding cos/sin values
# We need cos/sin for both halves of head_dim
cache_offsets = (
pos_ids[:, None] * stride_cache_s + dim_offsets[None, :] * stride_cache_d
) # [BLOCK_SIZE_S, D2_PADDED]
cache_mask = seq_mask[:, None] & dim_mask[None, :] # [BLOCK_SIZE_S, D2_PADDED]
cos_first = tl.load(cos_cache_ptr + cache_offsets, mask=cache_mask) # [BLOCK_SIZE_S, D2_PADDED]
sin_first = tl.load(sin_cache_ptr + cache_offsets, mask=cache_mask) # [BLOCK_SIZE_S, D2_PADDED]
# Second half of cos/sin (offset by D2)
cache_offsets_second = (
pos_ids[:, None] * stride_cache_s + (dim_offsets[None, :] + D2) * stride_cache_d
)
cos_second = tl.load(cos_cache_ptr + cache_offsets_second, mask=cache_mask)
sin_second = tl.load(sin_cache_ptr + cache_offsets_second, mask=cache_mask)
# ========== PROCESS Q ==========
# Q layout: [B, S, H, D] with interleaved D
# Read even indices (a) and odd indices (b) for de-interleaving
# Input: [q0_r, q0_i, q1_r, q1_i, ...] -> a=[q0_r, q1_r, ...], b=[q0_i, q1_i, ...]
q_base = batch_id * stride_qb
# Compute offsets for reading interleaved data
# even_offsets: positions 0, 2, 4, ... (stride 2)
# odd_offsets: positions 1, 3, 5, ... (stride 2)
q_offsets_base = (
seq_offsets[:, None, None] * stride_qs
+ head_offsets[None, :, None] * stride_qh
+ dim_offsets[None, None, :] * 2 * stride_qd # stride 2 for interleaved
) # [BLOCK_SIZE_S, BLOCK_SIZE_H, D2_PADDED]
even_offsets = q_base + q_offsets_base
odd_offsets = q_base + q_offsets_base + stride_qd
# Combined mask
load_mask = seq_mask[:, None, None] & head_mask[None, :, None] & dim_mask[None, None, :]
# Load Q values (even = a, odd = b)
q_a = tl.load(q_ptr + even_offsets, mask=load_mask) # [BLOCK_SIZE_S, BLOCK_SIZE_H, D2_PADDED]
q_b = tl.load(q_ptr + odd_offsets, mask=load_mask) # [BLOCK_SIZE_S, BLOCK_SIZE_H, D2_PADDED]
# Broadcast cos/sin for heads: [BLOCK_SIZE_S, D2_PADDED] -> [BLOCK_SIZE_S, 1, D2_PADDED]
cos_first_bc = cos_first[:, None, :]
sin_first_bc = sin_first[:, None, :]
cos_second_bc = cos_second[:, None, :]
sin_second_bc = sin_second[:, None, :]
# Apply RoPE formula
# y_first_half = a * cos - b * sin
# y_second_half = b * cos + a * sin
q_y1 = q_a * cos_first_bc - q_b * sin_first_bc
q_y2 = q_b * cos_second_bc + q_a * sin_second_bc
# Store Q output (CONTIGUOUS layout)
# Output layout: [B, S, H_Q, D] with first half = y1, second half = y2
# Compute contiguous strides: stride_b=S*H_Q*D, stride_s=H_Q*D, stride_h=D, stride_d=1
q_out_stride_b = S * H_Q * D
q_out_stride_s = H_Q * D
q_out_stride_h = D
q_out_offsets_first = (
batch_id * q_out_stride_b
+ seq_offsets[:, None, None] * q_out_stride_s
+ head_offsets[None, :, None] * q_out_stride_h
+ dim_offsets[None, None, :] # stride_d = 1
)
q_out_offsets_second = q_out_offsets_first + D2 # D2 * 1
tl.store(q_out_ptr + q_out_offsets_first, q_y1, mask=load_mask)
tl.store(q_out_ptr + q_out_offsets_second, q_y2, mask=load_mask)
# ========== PROCESS K ==========
# K typically has H_K=1 for MQA, but we handle the general case
# Use head_offsets < H_K for K's mask
head_mask_k = head_offsets < H_K
k_base = batch_id * stride_kb
k_offsets_base = (
seq_offsets[:, None, None] * stride_ks
+ head_offsets[None, :, None] * stride_kh
+ dim_offsets[None, None, :] * 2 * stride_kd
)
k_even_offsets = k_base + k_offsets_base
k_odd_offsets = k_base + k_offsets_base + stride_kd
load_mask_k = seq_mask[:, None, None] & head_mask_k[None, :, None] & dim_mask[None, None, :]
k_a = tl.load(k_ptr + k_even_offsets, mask=load_mask_k)
k_b = tl.load(k_ptr + k_odd_offsets, mask=load_mask_k)
k_y1 = k_a * cos_first_bc - k_b * sin_first_bc
k_y2 = k_b * cos_second_bc + k_a * sin_second_bc
# Store K output (CONTIGUOUS layout)
# Output layout: [B, S, H_K, D] with first half = y1, second half = y2
# Compute contiguous strides: stride_b=S*H_K*D, stride_s=H_K*D, stride_h=D, stride_d=1
k_out_stride_b = S * H_K * D
k_out_stride_s = H_K * D
k_out_stride_h = D
k_out_offsets_first = (
batch_id * k_out_stride_b
+ seq_offsets[:, None, None] * k_out_stride_s
+ head_offsets[None, :, None] * k_out_stride_h
+ dim_offsets[None, None, :] # stride_d = 1
)
k_out_offsets_second = k_out_offsets_first + D2 # D2 * 1
tl.store(k_out_ptr + k_out_offsets_first, k_y1, mask=load_mask_k)
tl.store(k_out_ptr + k_out_offsets_second, k_y2, mask=load_mask_k)
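For reference, a minimal pure-PyTorch sketch of the transformation the kernel above performs (de-interleave the pairs, rotate with the two cache halves, write a contiguous half-split output); the helper name and tensors are illustrative only:

```python
import torch

def rope_interleaved_reference(x, cos_cache, sin_cache, position_ids):
    # x: [B, S, H, D] with interleaved pairs [a0, b0, a1, b1, ...]
    cos = cos_cache[position_ids]  # [B, S, D] -- the position lookup the kernel fuses
    sin = sin_cache[position_ids]
    a = x[..., 0::2]               # even lanes  [B, S, H, D/2]
    b = x[..., 1::2]               # odd lanes   [B, S, H, D/2]
    D = x.shape[-1]
    cos1, cos2 = cos[..., : D // 2].unsqueeze(2), cos[..., D // 2 :].unsqueeze(2)
    sin1, sin2 = sin[..., : D // 2].unsqueeze(2), sin[..., D // 2 :].unsqueeze(2)
    y1 = a * cos1 - b * sin1       # first half of the output
    y2 = b * cos2 + a * sin2       # second half of the output
    return torch.cat([y1, y2], dim=-1)  # contiguous half-split layout [B, S, H, D]
```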
@triton.jit
def rope_fwd_kernel(
x_ptr,

View File

@@ -363,7 +363,11 @@ class MatchRopeLayout(BaseTransform):
class OptimizeRope(BaseTransform):
"""
Scan the FX graph and replace calls to the torch-reference RoPE ops with
the optimized `rope::flashinfer` kernel.
optimized kernels:
- ``torch_rope_with_explicit_cos_sin`` → ``flashinfer_rope``
- ``torch_rope_with_complex_freqs`` → ``flashinfer_rope``
- ``torch_rope_with_qk_interleaving`` → ``triton_rope_on_interleaved_qk_inputs``
Precomputes positional IDs and the fused cosine-sine cache as explicit nodes,
and reuses those nodes when possible.
"""
@@ -385,6 +389,9 @@ class OptimizeRope(BaseTransform):
_optimize_explicit(graph, node, rope_flash_cache, rope_position_ids_cache)
elif is_op(node, torch.ops.auto_deploy.torch_rope_with_complex_freqs):
_optimize_complex(graph, node, rope_flash_cache, rope_position_ids_cache)
elif is_op(node, torch.ops.auto_deploy.torch_rope_with_qk_interleaving):
if not _optimize_interleaved(graph, node):
continue
else:
continue
num_rope_optimizations += 1
@@ -546,6 +553,117 @@
graph.erase_node(node)
def _trace_back_index(node: Node) -> Optional[Tuple[Node, Node]]:
"""Trace back from a node to find an aten.index.Tensor producer.
If ``node`` was produced by ``aten.index.Tensor(cache, [position_ids])``,
return ``(cache_node, position_ids_node)``. Otherwise return ``None``.
"""
if not is_op(node, torch.ops.aten.index.Tensor):
return None
cache_node = node.args[0]
indices = node.args[1]
if not isinstance(indices, (list, tuple)) or len(indices) != 1:
return None
position_ids_node = indices[0]
if not isinstance(cache_node, Node) or not isinstance(position_ids_node, Node):
return None
return cache_node, position_ids_node
def _validate_interleaved_rope_inputs(q_node: Node, k_node: Node) -> bool:
"""Validate q/k inputs for the interleaved triton RoPE kernel.
Relaxed compared to ``_validate_rope_inputs``: requires even head_dim
(not head_dim % 64 == 0) since the triton kernel pads to next-power-of-2.
"""
for node in (q_node, k_node):
fake_val = node.meta.get("val", None)
if fake_val is None:
return False
# dtype must be half-precision
if fake_val.dtype not in (torch.float16, torch.bfloat16):
return False
# Must be at least 4-D
if len(fake_val.shape) < 4:
return False
# head_dim must be even
head_dim = fake_val.shape[-1]
if isinstance(head_dim, int) and head_dim % 2 != 0:
return False
# BSND layout: dim 1 (S) should be symbolic, dim 2 (N) should be static
if not isinstance(fake_val.shape[1], torch.SymInt):
return False
if not isinstance(fake_val.shape[2], int):
return False
return True
def _optimize_interleaved(graph: torch.fx.Graph, node: Node) -> bool:
"""Replace ``torch_rope_with_qk_interleaving`` with ``triton_rope_on_interleaved_qk_inputs``.
Traces back from cos/sin nodes to find the original ``aten.index.Tensor``
producer and extracts the cached cos/sin tensors and position_ids, passing
them directly to the triton kernel (which fuses the position lookup).
Returns True if the replacement was made, False if skipped.
"""
# --- extract arguments ---------------------------------------------------
q_node, k_node, cos_node, sin_node, *rest = node.args
q_rope_old, k_rope_old = extract_output_tuple(node, 2)
if q_rope_old is None or k_rope_old is None:
return False
# --- validate inputs -----------------------------------------------------
if not _validate_interleaved_rope_inputs(q_node, k_node):
return False
# --- trace back cos/sin to find cache + position_ids ---------------------
cos_traced = _trace_back_index(cos_node)
if cos_traced is None:
return False
cos_cache_node, cos_position_ids_node = cos_traced
sin_traced = _trace_back_index(sin_node)
if sin_traced is None:
return False
sin_cache_node, sin_position_ids_node = sin_traced
# Both cos and sin must use the same position_ids
if cos_position_ids_node is not sin_position_ids_node:
return False
position_ids_node = cos_position_ids_node
# --- insert the triton op ------------------------------------------------
with graph.inserting_before(node):
triton_node = graph.call_function(
torch.ops.auto_deploy.triton_rope_on_interleaved_qk_inputs,
args=(q_node, k_node, cos_cache_node, sin_cache_node, position_ids_node),
)
with graph.inserting_after(triton_node):
q_rope_new = graph.call_function(operator.getitem, args=(triton_node, 0))
k_rope_new = graph.call_function(operator.getitem, args=(triton_node, 1))
# --- rewire outputs ------------------------------------------------------
q_rope_new.meta["val"] = q_rope_old.meta.get("val", None)
k_rope_new.meta["val"] = k_rope_old.meta.get("val", None)
q_rope_old.replace_all_uses_with(q_rope_new)
k_rope_old.replace_all_uses_with(k_rope_new)
graph.erase_node(q_rope_old)
graph.erase_node(k_rope_old)
return True
def _match_input_interleave_pattern(node: Node) -> Optional[Dict[str, Node]]:
"""
Detect DeepSeek-style interleave on Q/K:

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Tuple
import pytest
import torch
@@ -8,6 +8,19 @@ from _custom_op_utils import torch_rope_reference
from tensorrt_llm._torch.auto_deploy.custom_ops.rope import triton_rope # noqa: F401
def _precompute_cos_sin_cache(
max_seq_len: int, head_dim: int, rope_theta: float = 10000.0, dtype: torch.dtype = torch.float16
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Precompute cos and sin cache for RoPE (DeepSeek-style)."""
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
positions = torch.arange(max_seq_len, dtype=torch.float32)
freqs = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # [max_seq_len, head_dim//2]
emb = torch.cat((freqs, freqs), dim=-1) # [max_seq_len, head_dim]
cos_cache = emb.cos().to(dtype)
sin_cache = emb.sin().to(dtype)
return cos_cache, sin_cache
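# Note (added for clarity): because ``emb`` concatenates ``freqs`` with itself, the two
# halves of each cache row are identical here; the Triton kernel nonetheless loads both
# halves separately, so caches whose halves differ would also be handled.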
def _precompute_freqs_cis(
seq_len: int, head_dim: int, rope_theta: Optional[float] = None
) -> torch.Tensor:
@@ -73,3 +86,105 @@ def test_rope_flattened(d_head):
y_reshaped = y.unflatten(-1, (2, N_ELEM // 2)).transpose(-2, -1).flatten(-2).contiguous()
assert torch.allclose(y_ref.cpu(), y_reshaped.cpu(), atol=1e-2, rtol=1e-2)
@pytest.mark.parametrize(
"batch_size,seq_len,num_q_heads,num_k_heads,head_dim",
[
(1, 4, 8, 8, 64), # Standard case
(2, 16, 20, 1, 128), # GLM-4 style: non-power-of-2 heads, MQA
(1, 1, 8, 8, 64), # Single token (decode)
(4, 32, 16, 2, 96), # GQA with non-standard head_dim
],
)
def test_triton_rope_on_interleaved_qk_inputs(
batch_size: int, seq_len: int, num_q_heads: int, num_k_heads: int, head_dim: int
):
"""
Test that triton_rope_on_interleaved_qk_inputs produces the same output as
the PyTorch reference (index + torch_rope_with_qk_interleaving).
"""
device = "cuda"
dtype = torch.bfloat16
max_seq_len = 1024
# Create random inputs with interleaved layout [B, S, H, D]
q = torch.randn(batch_size, seq_len, num_q_heads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch_size, seq_len, num_k_heads, head_dim, device=device, dtype=dtype)
# Precompute cos/sin cache
cos_cache, sin_cache = _precompute_cos_sin_cache(max_seq_len, head_dim, dtype=dtype)
cos_cache = cos_cache.to(device)
sin_cache = sin_cache.to(device)
# Random position IDs (not necessarily sequential)
position_ids = torch.randint(0, max_seq_len - seq_len, (batch_size,), device=device)
position_ids = position_ids.unsqueeze(1) + torch.arange(seq_len, device=device).unsqueeze(0)
# position_ids: [B, S]
# ========== PyTorch Reference ==========
# Step 1: Index cos/sin with position_ids
cos_indexed = cos_cache[position_ids] # [B, S, D]
sin_indexed = sin_cache[position_ids] # [B, S, D]
# Step 2: Apply PyTorch rope with qk interleaving
# unsqueeze_dim=2 for [B, S, H, D] layout
q_ref, k_ref = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(
q, k, cos_indexed, sin_indexed, unsqueeze_dim=2
)
# ========== Triton Implementation ==========
q_triton, k_triton = torch.ops.auto_deploy.triton_rope_on_interleaved_qk_inputs(
q, k, cos_cache, sin_cache, position_ids
)
# ========== Compare Outputs ==========
# Use relative tolerance for bf16
atol = 1e-2
rtol = 1e-2
assert torch.allclose(q_ref, q_triton, atol=atol, rtol=rtol), (
f"Q mismatch: max diff = {(q_ref - q_triton).abs().max().item()}"
)
assert torch.allclose(k_ref, k_triton, atol=atol, rtol=rtol), (
f"K mismatch: max diff = {(k_ref - k_triton).abs().max().item()}"
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_triton_rope_interleaved_dtype_consistency(dtype):
"""Test that the Triton kernel works correctly with different dtypes."""
device = "cuda"
batch_size, seq_len, num_heads, head_dim = 2, 8, 8, 64
max_seq_len = 1024
q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
cos_cache, sin_cache = _precompute_cos_sin_cache(max_seq_len, head_dim, dtype=dtype)
cos_cache = cos_cache.to(device)
sin_cache = sin_cache.to(device)
position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
# PyTorch reference
cos_indexed = cos_cache[position_ids]
sin_indexed = sin_cache[position_ids]
q_ref, k_ref = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(
q, k, cos_indexed, sin_indexed, unsqueeze_dim=2
)
# Triton
q_triton, k_triton = torch.ops.auto_deploy.triton_rope_on_interleaved_qk_inputs(
q, k, cos_cache, sin_cache, position_ids
)
# Verify outputs match
atol = 1e-2
rtol = 1e-2
assert torch.allclose(q_ref, q_triton, atol=atol, rtol=rtol)
assert torch.allclose(k_ref, k_triton, atol=atol, rtol=rtol)
# Verify output dtype is preserved
assert q_triton.dtype == dtype
assert k_triton.dtype == dtype

View File

@@ -490,3 +490,64 @@ def test_match_and_layout_deepseek(layout, num_heads, num_kv_heads, mode, target
None, # check_num_matches
False, # skip_output_assert
)
@pytest.mark.parametrize(
"num_heads,num_kv_heads",
[
(8, 8), # Standard MHA
(8, 1), # MQA (DeepSeek-style)
],
)
@torch.inference_mode()
def test_optimize_interleaved_rope(num_heads, num_kv_heads):
"""Test that optimize_rope replaces torch_rope_with_qk_interleaving
with triton_rope_on_interleaved_qk_inputs by tracing back through
aten.index.Tensor to find the cached cos/sin and position_ids."""
batch, seq, hid = 4, 12, 512
model = DSModel(hid, 16, num_heads, num_kv_heads, layout="BSND", mode="optimize")
model = model.to("cuda", torch.float16)
x = torch.randn(batch, seq, hid, device="cuda", dtype=torch.float16)
dynamic_shapes = model.get_dynamic_shapes()
gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True)
# Verify the graph contains torch_rope_with_qk_interleaving before optimization
assert any(
is_op(n, torch.ops.auto_deploy.torch_rope_with_qk_interleaving) for n in gm.graph.nodes
), "Expected torch_rope_with_qk_interleaving in graph before optimization"
gm_transformed = InferenceOptimizer(
None,
{
"optimize_rope": {
"stage": "pattern_matcher",
},
},
)(None, gm)
gm_transformed.to("cuda")
def checker(gm):
has_triton = any(
is_op(n, torch.ops.auto_deploy.triton_rope_on_interleaved_qk_inputs)
for n in gm.graph.nodes
)
no_torch_rope = not any(
is_op(n, torch.ops.auto_deploy.torch_rope_with_qk_interleaving) for n in gm.graph.nodes
)
return has_triton and no_torch_rope
run_test_transformed_gm(
model,
x,
gm_transformed,
checker,
lambda num_p: num_p,
1e-2, # atol
1e-2, # rtol
True, # test_load_hook
True, # strict_loading
dynamic_shapes, # dynamic_shapes
None, # check_num_matches
False, # skip_output_assert
)