From 6157f30b0619df17b06c858b2872ac7df8edd3d7 Mon Sep 17 00:00:00 2001 From: Bala Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> Date: Tue, 17 Feb 2026 13:24:18 -0500 Subject: [PATCH] [#11318][infra] AutoDeploy: Add fused rope kernel - triton_rope_on_interleaved_qk_inputs (#11327) Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> --- .../_torch/auto_deploy/custom_ops/README.md | 129 +++++++++++- .../custom_ops/rope/triton_rope.py | 109 ++++++++++- .../custom_ops/rope/triton_rope_kernel.py | 185 ++++++++++++++++++ .../auto_deploy/transform/library/rope.py | 120 +++++++++++- .../custom_ops/rope/test_triton_rope.py | 117 ++++++++++- .../library/test_rope_transformation.py | 61 ++++++ 6 files changed, 717 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md b/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md index addc6cc222..5263bc79df 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md @@ -4,4 +4,131 @@ All AutoDeploy custom operators follow the following naming convention: `torch.ops.auto_deploy.__` -The table below lists the operators ordered by their backend. +The table below lists the operators grouped by category. + +### Available Custom Operators + +#### Attention + +| Operator Name | Description | +|--------------|-------------| +| `torch.ops.auto_deploy.torch_attention` | Grouped SDPA implementation with `bsnd` and `bnsd` layout supported | +| `torch.ops.auto_deploy.torch_attention_sdpa` | Standard scaled dot-product attention (SDPA) implementation | +| `torch.ops.auto_deploy.torch_attention_repeat_kv` | KV repetition for grouped-query attention | +| `torch.ops.auto_deploy.torch_cached_attention_with_cache` | PyTorch backend attention with KV cache management | +| `torch.ops.auto_deploy.flashinfer_attention_mha_with_cache` | FlashInfer multi-head attention with KV cache support | +| `torch.ops.auto_deploy.flashinfer_attention_prepare_metadata` | FlashInfer attention metadata preparation | +| `torch.ops.auto_deploy.triton_attention_flattened_mha_with_cache` | Triton flattened MHA with cache | +| `torch.ops.auto_deploy.torch_onnx_attention_plugin` | Fused attention with RoPE placeholder for ONNX export | +| `torch.ops.auto_deploy.torch_onnx_gather_nd` | N-dimensional gather operation for ONNX export | + +#### MLA (Multi-head Latent Attention) + +| Operator Name | Description | +|--------------|-------------| +| `torch.ops.auto_deploy.torch_mla` | Multi-head Latent Attention (MLA) implementation | +| `torch.ops.auto_deploy.torch_cached_mla_with_cache` | PyTorch backend cached MLA with KV cache | +| `torch.ops.auto_deploy.flashinfer_mla_with_cache` | FlashInfer MLA with cache | +| `torch.ops.auto_deploy.flashinfer_mla_prepare_metadata` | FlashInfer MLA metadata preparation | + +#### RoPE (Rotary Position Embedding) + +| Operator Name | Description | +|--------------|-------------| +| `torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin` | RoPE with explicit cosine/sine | +| `torch.ops.auto_deploy.torch_rope_with_complex_freqs` | RoPE with complex frequencies | +| `torch.ops.auto_deploy.torch_rope_with_qk_interleaving` | RoPE with QK interleaving | +| `torch.ops.auto_deploy.triton_rope_with_input_pos` | Triton RoPE with input positions | +| `torch.ops.auto_deploy.triton_rope_on_flattened_inputs` | Triton RoPE on flattened inputs | +| `torch.ops.auto_deploy.triton_rope_on_interleaved_qk_inputs` | Triton fused RoPE on 
interleaved QK inputs (position lookup + de-interleave + RoPE) | +| `torch.ops.auto_deploy.flashinfer_rope` | FlashInfer RoPE implementation | + +#### Linear + +| Operator Name | Description | +|--------------|-------------| +| `torch.ops.auto_deploy.torch_linear_simple` | Simple linear layer wrapper (avoids view ops in export graph) | +| `torch.ops.auto_deploy.torch_moe_router` | MoE router: linear projection + top-k + softmax + scatter | + +#### MoE (Mixture of Experts) + +| Operator Name | Description | +|--------------|-------------| +| `torch.ops.auto_deploy.torch_moe` | Mixture of Experts implementation (PyTorch backend) | +| `torch.ops.auto_deploy.torch_moe_fused` | Fused Mixture of Experts implementation (PyTorch backend) | +| `torch.ops.auto_deploy.torch_moe_dense_mlp` | Dense MLP implementation for MoE (PyTorch backend) | +| `torch.ops.auto_deploy.torch_quant_fp8_moe` | FP8 quantized MoE (PyTorch backend) | +| `torch.ops.auto_deploy.torch_quant_nvfp4_moe` | NVFP4 quantized MoE (PyTorch backend) | +| `torch.ops.auto_deploy.triton_moe_fused` | Fused MoE (Triton backend) | +| `torch.ops.auto_deploy.triton_quant_fp8_moe` | FP8 quantized MoE (Triton backend) | +| `torch.ops.auto_deploy.triton_mxfp4_moe` | MXFP4 MoE with triton-kernels matmul_ogs | +| `torch.ops.auto_deploy.triton_mxfp4_moe_ep` | MXFP4 MoE with Expert Parallelism (triton-kernels) | +| `torch.ops.auto_deploy.trtllm_moe_fused` | Fused MoE (TRT-LLM backend) | +| `torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused` | FP8 quantized fused MoE (TRT-LLM backend) | +| `torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused` | NVFP4 quantized fused MoE (TRT-LLM backend) | + +#### Quantization + +| Operator Name | Description | +|--------------|-------------| +| `torch.ops.auto_deploy.torch_quant_fn` | Generic quantization function that scales, rounds, and clamps input values | +| `torch.ops.auto_deploy.torch_quant_fp8_linear` | FP8 quantized linear layer (PyTorch backend) | +| `torch.ops.auto_deploy.torch_quant_nvfp4_linear` | NVFP4 quantized linear layer (PyTorch backend) | +| `torch.ops.auto_deploy.torch_quant_fp8_bmm` | FP8 quantized batch matrix multiply (PyTorch backend) | +| `torch.ops.auto_deploy.trtllm_quant_fp8_linear` | FP8 quantized linear layer (TRT-LLM backend) | +| `torch.ops.auto_deploy.torch_fake_quant_fp8_linear` | Fake FP8 quantized linear (for calibration/simulation) | +| `torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear` | Fake NVFP4 quantized linear (for calibration/simulation) | +| `torch.ops.auto_deploy.torch_fake_quant_int4_linear` | Fake INT4 quantized linear (for calibration/simulation) | +| `torch.ops.auto_deploy.torch_fake_quant_int4_gptq_linear` | Fake INT4 GPTQ quantized linear (for calibration/simulation) | + +#### Normalization + +| Operator Name | Description | +|--------------|-------------| +| `torch.ops.auto_deploy.torch_rmsnorm` | RMSNorm (PyTorch backend) | +| `torch.ops.auto_deploy.torch_rmsnorm_gated` | Gated RMSNorm with optional SiLU gating (PyTorch backend) | +| `torch.ops.auto_deploy.triton_rms_norm` | RMSNorm (Triton backend) | +| `torch.ops.auto_deploy.triton_rmsnorm_gated` | Gated RMSNorm with optional SiLU gating (Triton backend) | +| `torch.ops.auto_deploy.flashinfer_rms_norm` | RMSNorm (FlashInfer backend) | +| `torch.ops.auto_deploy.flashinfer_fused_add_rms_norm_inplace` | Fused residual add + RMSNorm in-place (FlashInfer backend) | +| `torch.ops.auto_deploy.sharded_rmsnorm` | RMSNorm for tensor-parallel sharded activations (uses all-reduce) | +| 
`torch.ops.auto_deploy.torch_l2norm` | L2 normalization (PyTorch backend) | +| `torch.ops.auto_deploy.fla_l2norm` | L2 normalization (FLA Triton kernel backend) | + +#### Mamba (SSM + Causal Conv) + +| Operator Name | Description | +|--------------|-------------| +| `torch.ops.auto_deploy.torch_ssm` | State Space Model (SSM) computation (PyTorch backend) | +| `torch.ops.auto_deploy.torch_cached_ssm` | Cached SSM with state management (PyTorch backend) | +| `torch.ops.auto_deploy.triton_cached_ssm` | Cached SSM with state management (Triton backend) | +| `torch.ops.auto_deploy.flashinfer_cached_ssm` | Cached SSM with state management (FlashInfer backend) | +| `torch.ops.auto_deploy.mamba_ssm_prepare_metadata` | Mamba SSM metadata preparation (chunk indices, offsets, seq_idx) | +| `torch.ops.auto_deploy.torch_causal_conv1d` | Causal 1D convolution (PyTorch backend) | +| `torch.ops.auto_deploy.torch_cached_causal_conv1d` | Cached causal 1D convolution (PyTorch backend) | +| `torch.ops.auto_deploy.triton_cached_causal_conv1d` | Cached causal 1D convolution (Triton backend) | +| `torch.ops.auto_deploy.cuda_cached_causal_conv1d` | Cached causal 1D convolution (CUDA backend) | + +#### FLA (Flash Linear Attention) + +| Operator Name | Description | +|--------------|-------------| +| `torch.ops.auto_deploy.fla_delta_rule` | FLA chunked delta rule computation | +| `torch.ops.auto_deploy.fla_cached_delta_rule` | FLA cached delta rule with state management | + +#### Distributed + +| Operator Name | Description | +|--------------|-------------| +| `torch.ops.auto_deploy.torch_dist_all_gather` | All-gather (PyTorch backend, demollm mode) | +| `torch.ops.auto_deploy.torch_dist_all_reduce` | All-reduce (PyTorch backend, demollm mode) | +| `torch.ops.auto_deploy.trtllm_dist_all_gather` | All-gather (TRT-LLM backend, MPI mode) | +| `torch.ops.auto_deploy.trtllm_dist_all_reduce` | All-reduce (TRT-LLM backend, MPI mode) | +| `torch.ops.dist.trtllm_fused_allreduce_residual_rmsnorm` | Fused all-reduce + residual add + RMSNorm (TRT-LLM backend, MPI mode) | + +#### Utilities + +| Operator Name | Description | +|--------------|-------------| +| `torch.ops.auto_deploy.triton_utils_fused_gather_scatter` | Triton fused gather + scatter for overlap scheduling input_ids reordering | +| `torch.ops.auto_deploy.gather_logits_before_lm_head` | Gather hidden states using logits indices before LM head | diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/rope/triton_rope.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/rope/triton_rope.py index 3c5d79c0f0..59dd7f899e 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/rope/triton_rope.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/rope/triton_rope.py @@ -12,11 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Tuple import torch import triton -from .triton_rope_kernel import rope_fwd_flattened_kernel, rope_fwd_kernel +from .triton_rope_kernel import ( + rope_fwd_flattened_kernel, + rope_fwd_interleaved_kernel, + rope_fwd_kernel, +) @torch.library.custom_op("auto_deploy::triton_rope_with_input_pos", mutates_args=()) @@ -147,3 +152,105 @@ def apply_rope_on_flattened_inputs( @apply_rope_on_flattened_inputs.register_fake def apply_rope_on_flattened_inputs_fake(x, freqs_cis, input_pos, seq_lens, seq_start_indices): return torch.empty_like(x) + + +@torch.library.custom_op("auto_deploy::triton_rope_on_interleaved_qk_inputs", mutates_args=()) +def apply_rope_on_interleaved_qk_inputs( + q: torch.Tensor, + k: torch.Tensor, + cos_cache: torch.Tensor, + sin_cache: torch.Tensor, + position_ids: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Fused RoPE for DeepSeek-style interleaved Q/K inputs. + + This kernel fuses: + 1. Position ID lookup for cos/sin from cache + 2. De-interleaving of Q/K + 3. RoPE application + + Args: + q: Query tensor with interleaved layout [B, S, H_Q, D] + Input: [q0_r, q0_i, q1_r, q1_i, ...] + k: Key tensor with interleaved layout [B, S, H_K, D] + Typically H_K=1 for MQA + cos_cache: Cosine cache [max_seq_len, D] + sin_cache: Sine cache [max_seq_len, D] + position_ids: Position indices [B, S] + + Returns: + Tuple of (q_rotated, k_rotated) with layout [B, S, H, D] + Output: [y0, y1, ..., y_{D/2-1}, y_{D/2}, ..., y_{D-1}] + where y_first = a*cos - b*sin, y_second = b*cos + a*sin + """ + assert q.dim() == 4, f"Q must be 4D [B, S, H, D], got {q.dim()}D" + assert k.dim() == 4, f"K must be 4D [B, S, H, D], got {k.dim()}D" + assert q.shape[-1] % 2 == 0, "Head dimension must be even" + assert q.shape[-1] == k.shape[-1], "Q and K must have same head dimension" + assert cos_cache.shape == sin_cache.shape, "cos and sin cache must have same shape" + + B, S, H_Q, D = q.shape + _, _, H_K, _ = k.shape + assert k.shape[0] == B and k.shape[1] == S, "Q and K must have same batch and seq dims" + assert position_ids.shape == (B, S), f"position_ids must be [B, S], got {position_ids.shape}" + assert H_Q >= H_K, f"H_Q ({H_Q}) must be >= H_K ({H_K}) for grid sizing" + + # Allocate contiguous outputs + # The kernel computes contiguous strides internally for output writes + q_out = torch.empty_like(q) + k_out = torch.empty_like(k) + + # Block sizes + BLOCK_SIZE_H = 32 + BLOCK_SIZE_S = min(triton.next_power_of_2(S), 32) + + # Grid: (B, cdiv(H_Q, BLOCK_SIZE_H), cdiv(S, BLOCK_SIZE_S)) + # H_Q >= H_K is enforced above; K heads are masked within each block + grid = ( + B, + triton.cdiv(H_Q, BLOCK_SIZE_H), + triton.cdiv(S, BLOCK_SIZE_S), + ) + + rope_fwd_interleaved_kernel[grid]( + q, + k, + cos_cache, + sin_cache, + position_ids, + q_out, + k_out, + B, + S, + H_Q, + H_K, + D, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + position_ids.stride(0), + position_ids.stride(1), + cos_cache.stride(0), + cos_cache.stride(1), + BLOCK_SIZE_H=BLOCK_SIZE_H, + BLOCK_SIZE_S=BLOCK_SIZE_S, + ) + + return q_out, k_out + + +@apply_rope_on_interleaved_qk_inputs.register_fake +def apply_rope_on_interleaved_qk_inputs_fake( + q: torch.Tensor, + k: torch.Tensor, + cos_cache: torch.Tensor, + sin_cache: torch.Tensor, + position_ids: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(q), torch.empty_like(k) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/rope/triton_rope_kernel.py 
b/tensorrt_llm/_torch/auto_deploy/custom_ops/rope/triton_rope_kernel.py index 4139d3a9d9..29526a12e9 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/rope/triton_rope_kernel.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/rope/triton_rope_kernel.py @@ -17,6 +17,191 @@ import triton import triton.language as tl +@triton.jit +def rope_fwd_interleaved_kernel( + q_ptr, # [B, S, H_Q, D] input Q (interleaved layout) + k_ptr, # [B, S, H_K, D] input K (interleaved layout) + cos_cache_ptr, # [max_seq_len, D] cos cache + sin_cache_ptr, # [max_seq_len, D] sin cache + position_ids_ptr, # [B, S] position IDs + q_out_ptr, # [B, S, H_Q, D] output Q + k_out_ptr, # [B, S, H_K, D] output K + B, # batch size + S, # sequence length + H_Q, # number of Q heads + H_K, # number of K heads (typically 1 for MQA) + D: tl.constexpr, # head dimension + stride_qb, # Q batch stride + stride_qs, # Q seq stride + stride_qh, # Q head stride + stride_qd, # Q dim stride + stride_kb, # K batch stride + stride_ks, # K seq stride + stride_kh, # K head stride + stride_kd, # K dim stride + stride_pos_b, # position_ids batch stride + stride_pos_s, # position_ids seq stride + stride_cache_s, # cache seq stride + stride_cache_d, # cache dim stride + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_S: tl.constexpr, +): + """ + Fused RoPE kernel for DeepSeek-style interleaved Q/K inputs. + + Input layout (interleaved): [q0_r, q0_i, q1_r, q1_i, ...] + After de-interleave: [q0_r, q1_r, ..., q0_i, q1_i, ...] + + This kernel does: + 1. Position ID lookup for cos/sin + 2. Reads even and odd indices directly from the interleaved input using strided loads + 3. RoPE application directly on the separated a/b values: y_first = a*cos - b*sin, y_second = b*cos + a*sin + 4. Writes the output in contiguous "half-split" layout — first half and second half stored separately + + Grid: (B, cdiv(H_Q, BLOCK_SIZE_H), cdiv(S, BLOCK_SIZE_S)) + """ + D2: tl.constexpr = D // 2 + D2_PADDED: tl.constexpr = triton.next_power_of_2(D2) + + # Program IDs + batch_id = tl.program_id(0) + head_block_id = tl.program_id(1) + seq_block_id = tl.program_id(2) + + # Head offsets and mask + head_offsets = head_block_id * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + head_mask = head_offsets < H_Q + + # Sequence offsets and mask + seq_offsets = seq_block_id * BLOCK_SIZE_S + tl.arange(0, BLOCK_SIZE_S) + seq_mask = seq_offsets < S + + # Dimension offsets for half the head dim (we read pairs) + dim_offsets = tl.arange(0, D2_PADDED) + dim_mask = dim_offsets < D2 + + # ========== LOAD POSITION IDS ========== + # position_ids: [B, S] + pos_ptr = position_ids_ptr + batch_id * stride_pos_b + seq_offsets * stride_pos_s + pos_ids = tl.load(pos_ptr, mask=seq_mask, other=0) # [BLOCK_SIZE_S] + + # ========== LOAD COS/SIN FROM CACHE ========== + # cos_cache, sin_cache: [max_seq_len, D] + # For each position, load the corresponding cos/sin values + # We need cos/sin for both halves of head_dim + cache_offsets = ( + pos_ids[:, None] * stride_cache_s + dim_offsets[None, :] * stride_cache_d + ) # [BLOCK_SIZE_S, D2_PADDED] + + cache_mask = seq_mask[:, None] & dim_mask[None, :] # [BLOCK_SIZE_S, D2_PADDED] + + cos_first = tl.load(cos_cache_ptr + cache_offsets, mask=cache_mask) # [BLOCK_SIZE_S, D2_PADDED] + sin_first = tl.load(sin_cache_ptr + cache_offsets, mask=cache_mask) # [BLOCK_SIZE_S, D2_PADDED] + + # Second half of cos/sin (offset by D2) + cache_offsets_second = ( + pos_ids[:, None] * stride_cache_s + (dim_offsets[None, :] + D2) * stride_cache_d + ) + cos_second = tl.load(cos_cache_ptr + 
cache_offsets_second, mask=cache_mask) + sin_second = tl.load(sin_cache_ptr + cache_offsets_second, mask=cache_mask) + + # ========== PROCESS Q ========== + # Q layout: [B, S, H, D] with interleaved D + # Read even indices (a) and odd indices (b) for de-interleaving + # Input: [q0_r, q0_i, q1_r, q1_i, ...] -> a=[q0_r, q1_r, ...], b=[q0_i, q1_i, ...] + + q_base = batch_id * stride_qb + + # Compute offsets for reading interleaved data + # even_offsets: positions 0, 2, 4, ... (stride 2) + # odd_offsets: positions 1, 3, 5, ... (stride 2) + q_offsets_base = ( + seq_offsets[:, None, None] * stride_qs + + head_offsets[None, :, None] * stride_qh + + dim_offsets[None, None, :] * 2 * stride_qd # stride 2 for interleaved + ) # [BLOCK_SIZE_S, BLOCK_SIZE_H, D2_PADDED] + + even_offsets = q_base + q_offsets_base + odd_offsets = q_base + q_offsets_base + stride_qd + + # Combined mask + load_mask = seq_mask[:, None, None] & head_mask[None, :, None] & dim_mask[None, None, :] + + # Load Q values (even = a, odd = b) + q_a = tl.load(q_ptr + even_offsets, mask=load_mask) # [BLOCK_SIZE_S, BLOCK_SIZE_H, D2_PADDED] + q_b = tl.load(q_ptr + odd_offsets, mask=load_mask) # [BLOCK_SIZE_S, BLOCK_SIZE_H, D2_PADDED] + + # Broadcast cos/sin for heads: [BLOCK_SIZE_S, D2_PADDED] -> [BLOCK_SIZE_S, 1, D2_PADDED] + cos_first_bc = cos_first[:, None, :] + sin_first_bc = sin_first[:, None, :] + cos_second_bc = cos_second[:, None, :] + sin_second_bc = sin_second[:, None, :] + + # Apply RoPE formula + # y_first_half = a * cos - b * sin + # y_second_half = b * cos + a * sin + q_y1 = q_a * cos_first_bc - q_b * sin_first_bc + q_y2 = q_b * cos_second_bc + q_a * sin_second_bc + + # Store Q output (CONTIGUOUS layout) + # Output layout: [B, S, H_Q, D] with first half = y1, second half = y2 + # Compute contiguous strides: stride_b=S*H_Q*D, stride_s=H_Q*D, stride_h=D, stride_d=1 + q_out_stride_b = S * H_Q * D + q_out_stride_s = H_Q * D + q_out_stride_h = D + q_out_offsets_first = ( + batch_id * q_out_stride_b + + seq_offsets[:, None, None] * q_out_stride_s + + head_offsets[None, :, None] * q_out_stride_h + + dim_offsets[None, None, :] # stride_d = 1 + ) + q_out_offsets_second = q_out_offsets_first + D2 # D2 * 1 + + tl.store(q_out_ptr + q_out_offsets_first, q_y1, mask=load_mask) + tl.store(q_out_ptr + q_out_offsets_second, q_y2, mask=load_mask) + + # ========== PROCESS K ========== + # K typically has H_K=1 for MQA, but we handle general case + # Use head_offsets < H_K for K's mask + head_mask_k = head_offsets < H_K + + k_base = batch_id * stride_kb + + k_offsets_base = ( + seq_offsets[:, None, None] * stride_ks + + head_offsets[None, :, None] * stride_kh + + dim_offsets[None, None, :] * 2 * stride_kd + ) + + k_even_offsets = k_base + k_offsets_base + k_odd_offsets = k_base + k_offsets_base + stride_kd + + load_mask_k = seq_mask[:, None, None] & head_mask_k[None, :, None] & dim_mask[None, None, :] + + k_a = tl.load(k_ptr + k_even_offsets, mask=load_mask_k) + k_b = tl.load(k_ptr + k_odd_offsets, mask=load_mask_k) + + k_y1 = k_a * cos_first_bc - k_b * sin_first_bc + k_y2 = k_b * cos_second_bc + k_a * sin_second_bc + + # Store K output (CONTIGUOUS layout) + # Output layout: [B, S, H_K, D] with first half = y1, second half = y2 + # Compute contiguous strides: stride_b=S*H_K*D, stride_s=H_K*D, stride_h=D, stride_d=1 + k_out_stride_b = S * H_K * D + k_out_stride_s = H_K * D + k_out_stride_h = D + k_out_offsets_first = ( + batch_id * k_out_stride_b + + seq_offsets[:, None, None] * k_out_stride_s + + head_offsets[None, :, None] * k_out_stride_h + + 
dim_offsets[None, None, :] # stride_d = 1 + ) + k_out_offsets_second = k_out_offsets_first + D2 # D2 * 1 + + tl.store(k_out_ptr + k_out_offsets_first, k_y1, mask=load_mask_k) + tl.store(k_out_ptr + k_out_offsets_second, k_y2, mask=load_mask_k) + + @triton.jit def rope_fwd_kernel( x_ptr, diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/rope.py b/tensorrt_llm/_torch/auto_deploy/transform/library/rope.py index 811dd34af5..4e9318f0c1 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/rope.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/rope.py @@ -363,7 +363,11 @@ class MatchRopeLayout(BaseTransform): class OptimizeRope(BaseTransform): """ Scan the FX graph and replace calls to the torch-reference RoPE ops with - the optimized `rope::flashinfer` kernel. + optimized kernels: + - ``torch_rope_with_explicit_cos_sin`` → ``flashinfer_rope`` + - ``torch_rope_with_complex_freqs`` → ``flashinfer_rope`` + - ``torch_rope_with_qk_interleaving`` → ``triton_rope_on_interleaved_qk_inputs`` + Precomputes positional IDs and the fused cosine-sine cache as explicit nodes, and reuses those nodes when possible. """ @@ -385,6 +389,9 @@ class OptimizeRope(BaseTransform): _optimize_explicit(graph, node, rope_flash_cache, rope_position_ids_cache) elif is_op(node, torch.ops.auto_deploy.torch_rope_with_complex_freqs): _optimize_complex(graph, node, rope_flash_cache, rope_position_ids_cache) + elif is_op(node, torch.ops.auto_deploy.torch_rope_with_qk_interleaving): + if not _optimize_interleaved(graph, node): + continue else: continue num_rope_optimizations += 1 @@ -546,6 +553,117 @@ def _optimize_complex( graph.erase_node(node) +def _trace_back_index(node: Node) -> Optional[Tuple[Node, Node]]: + """Trace back from a node to find an aten.index.Tensor producer. + + If ``node`` was produced by ``aten.index.Tensor(cache, [position_ids])``, + return ``(cache_node, position_ids_node)``. Otherwise return ``None``. + """ + if not is_op(node, torch.ops.aten.index.Tensor): + return None + cache_node = node.args[0] + indices = node.args[1] + if not isinstance(indices, (list, tuple)) or len(indices) != 1: + return None + position_ids_node = indices[0] + if not isinstance(cache_node, Node) or not isinstance(position_ids_node, Node): + return None + return cache_node, position_ids_node + + +def _validate_interleaved_rope_inputs(q_node: Node, k_node: Node) -> bool: + """Validate q/k inputs for the interleaved triton RoPE kernel. + + Relaxed compared to ``_validate_rope_inputs``: requires even head_dim + (not head_dim % 64 == 0) since the triton kernel pads to next-power-of-2. + """ + for node in (q_node, k_node): + fake_val = node.meta.get("val", None) + if fake_val is None: + return False + + # dtype must be half-precision + if fake_val.dtype not in (torch.float16, torch.bfloat16): + return False + + # Must be at least 4-D + if len(fake_val.shape) < 4: + return False + + # head_dim must be even + head_dim = fake_val.shape[-1] + if isinstance(head_dim, int) and head_dim % 2 != 0: + return False + + # BSND layout: dim 1 (S) should be symbolic, dim 2 (N) should be static + if not isinstance(fake_val.shape[1], torch.SymInt): + return False + if not isinstance(fake_val.shape[2], int): + return False + + return True + + +def _optimize_interleaved(graph: torch.fx.Graph, node: Node) -> bool: + """Replace ``torch_rope_with_qk_interleaving`` with ``triton_rope_on_interleaved_qk_inputs``. 
+ + Traces back from cos/sin nodes to find the original ``aten.index.Tensor`` + producer and extracts the cached cos/sin tensors and position_ids, passing + them directly to the triton kernel (which fuses the position lookup). + + Returns True if the replacement was made, False if skipped. + """ + # --- extract arguments --------------------------------------------------- + q_node, k_node, cos_node, sin_node, *rest = node.args + q_rope_old, k_rope_old = extract_output_tuple(node, 2) + if q_rope_old is None or k_rope_old is None: + return False + + # --- validate inputs ----------------------------------------------------- + if not _validate_interleaved_rope_inputs(q_node, k_node): + return False + + # --- trace back cos/sin to find cache + position_ids --------------------- + cos_traced = _trace_back_index(cos_node) + if cos_traced is None: + return False + cos_cache_node, cos_position_ids_node = cos_traced + + sin_traced = _trace_back_index(sin_node) + if sin_traced is None: + return False + sin_cache_node, sin_position_ids_node = sin_traced + + # Both cos and sin must use the same position_ids + if cos_position_ids_node is not sin_position_ids_node: + return False + + position_ids_node = cos_position_ids_node + + # --- insert the triton op ------------------------------------------------ + with graph.inserting_before(node): + triton_node = graph.call_function( + torch.ops.auto_deploy.triton_rope_on_interleaved_qk_inputs, + args=(q_node, k_node, cos_cache_node, sin_cache_node, position_ids_node), + ) + + with graph.inserting_after(triton_node): + q_rope_new = graph.call_function(operator.getitem, args=(triton_node, 0)) + k_rope_new = graph.call_function(operator.getitem, args=(triton_node, 1)) + + # --- rewire outputs ------------------------------------------------------ + q_rope_new.meta["val"] = q_rope_old.meta.get("val", None) + k_rope_new.meta["val"] = k_rope_old.meta.get("val", None) + + q_rope_old.replace_all_uses_with(q_rope_new) + k_rope_old.replace_all_uses_with(k_rope_new) + + graph.erase_node(q_rope_old) + graph.erase_node(k_rope_old) + + return True + + def _match_input_interleave_pattern(node: Node) -> Optional[Dict[str, Node]]: """ Detect DeepSeek-style interleave on Q/K: diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/rope/test_triton_rope.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/rope/test_triton_rope.py index d9879562e8..748394648b 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/rope/test_triton_rope.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/rope/test_triton_rope.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple import pytest import torch @@ -8,6 +8,19 @@ from _custom_op_utils import torch_rope_reference from tensorrt_llm._torch.auto_deploy.custom_ops.rope import triton_rope # noqa: F401 +def _precompute_cos_sin_cache( + max_seq_len: int, head_dim: int, rope_theta: float = 10000.0, dtype: torch.dtype = torch.float16 +) -> Tuple[torch.Tensor, torch.Tensor]: + """Precompute cos and sin cache for RoPE (DeepSeek-style).""" + inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)) + positions = torch.arange(max_seq_len, dtype=torch.float32) + freqs = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # [max_seq_len, head_dim//2] + emb = torch.cat((freqs, freqs), dim=-1) # [max_seq_len, head_dim] + cos_cache = emb.cos().to(dtype) + sin_cache = emb.sin().to(dtype) + return cos_cache, sin_cache + + def 
_precompute_freqs_cis( seq_len: int, head_dim: int, rope_theta: Optional[float] = None ) -> torch.Tensor: @@ -73,3 +86,105 @@ def test_rope_flattened(d_head): y_reshaped = y.unflatten(-1, (2, N_ELEM // 2)).transpose(-2, -1).flatten(-2).contiguous() assert torch.allclose(y_ref.cpu(), y_reshaped.cpu(), atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize( + "batch_size,seq_len,num_q_heads,num_k_heads,head_dim", + [ + (1, 4, 8, 8, 64), # Standard case + (2, 16, 20, 1, 128), # GLM-4 style: non-power-of-2 heads, MQA + (1, 1, 8, 8, 64), # Single token (decode) + (4, 32, 16, 2, 96), # GQA with non-standard head_dim + ], +) +def test_triton_rope_on_interleaved_qk_inputs( + batch_size: int, seq_len: int, num_q_heads: int, num_k_heads: int, head_dim: int +): + """ + Test that triton_rope_on_interleaved_qk_inputs produces the same output as + the PyTorch reference (index + torch_rope_with_qk_interleaving). + """ + device = "cuda" + dtype = torch.bfloat16 + max_seq_len = 1024 + + # Create random inputs with interleaved layout [B, S, H, D] + q = torch.randn(batch_size, seq_len, num_q_heads, head_dim, device=device, dtype=dtype) + k = torch.randn(batch_size, seq_len, num_k_heads, head_dim, device=device, dtype=dtype) + + # Precompute cos/sin cache + cos_cache, sin_cache = _precompute_cos_sin_cache(max_seq_len, head_dim, dtype=dtype) + cos_cache = cos_cache.to(device) + sin_cache = sin_cache.to(device) + + # Random position IDs (not necessarily sequential) + position_ids = torch.randint(0, max_seq_len - seq_len, (batch_size,), device=device) + position_ids = position_ids.unsqueeze(1) + torch.arange(seq_len, device=device).unsqueeze(0) + # position_ids: [B, S] + + # ========== PyTorch Reference ========== + # Step 1: Index cos/sin with position_ids + cos_indexed = cos_cache[position_ids] # [B, S, D] + sin_indexed = sin_cache[position_ids] # [B, S, D] + + # Step 2: Apply PyTorch rope with qk interleaving + # unsqueeze_dim=2 for [B, S, H, D] layout + q_ref, k_ref = torch.ops.auto_deploy.torch_rope_with_qk_interleaving( + q, k, cos_indexed, sin_indexed, unsqueeze_dim=2 + ) + + # ========== Triton Implementation ========== + q_triton, k_triton = torch.ops.auto_deploy.triton_rope_on_interleaved_qk_inputs( + q, k, cos_cache, sin_cache, position_ids + ) + + # ========== Compare Outputs ========== + # Use relative tolerance for bf16 + atol = 1e-2 + rtol = 1e-2 + + assert torch.allclose(q_ref, q_triton, atol=atol, rtol=rtol), ( + f"Q mismatch: max diff = {(q_ref - q_triton).abs().max().item()}" + ) + assert torch.allclose(k_ref, k_triton, atol=atol, rtol=rtol), ( + f"K mismatch: max diff = {(k_ref - k_triton).abs().max().item()}" + ) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_triton_rope_interleaved_dtype_consistency(dtype): + """Test that the Triton kernel works correctly with different dtypes.""" + device = "cuda" + batch_size, seq_len, num_heads, head_dim = 2, 8, 8, 64 + max_seq_len = 1024 + + q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + k = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + + cos_cache, sin_cache = _precompute_cos_sin_cache(max_seq_len, head_dim, dtype=dtype) + cos_cache = cos_cache.to(device) + sin_cache = sin_cache.to(device) + + position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1) + + # PyTorch reference + cos_indexed = cos_cache[position_ids] + sin_indexed = sin_cache[position_ids] + q_ref, k_ref = 
torch.ops.auto_deploy.torch_rope_with_qk_interleaving( + q, k, cos_indexed, sin_indexed, unsqueeze_dim=2 + ) + + # Triton + q_triton, k_triton = torch.ops.auto_deploy.triton_rope_on_interleaved_qk_inputs( + q, k, cos_cache, sin_cache, position_ids + ) + + # Verify outputs match + atol = 1e-2 + rtol = 1e-2 + assert torch.allclose(q_ref, q_triton, atol=atol, rtol=rtol) + assert torch.allclose(k_ref, k_triton, atol=atol, rtol=rtol) + + # Verify output dtype is preserved + assert q_triton.dtype == dtype + assert k_triton.dtype == dtype diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py index 291cd377bd..e636803fe3 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py @@ -490,3 +490,64 @@ def test_match_and_layout_deepseek(layout, num_heads, num_kv_heads, mode, target None, # check_num_matches False, # skip_output_assert ) + + +@pytest.mark.parametrize( + "num_heads,num_kv_heads", + [ + (8, 8), # Standard MHA + (8, 1), # MQA (DeepSeek-style) + ], +) +@torch.inference_mode() +def test_optimize_interleaved_rope(num_heads, num_kv_heads): + """Test that optimize_rope replaces torch_rope_with_qk_interleaving + with triton_rope_on_interleaved_qk_inputs by tracing back through + aten.index.Tensor to find the cached cos/sin and position_ids.""" + batch, seq, hid = 4, 12, 512 + model = DSModel(hid, 16, num_heads, num_kv_heads, layout="BSND", mode="optimize") + model = model.to("cuda", torch.float16) + + x = torch.randn(batch, seq, hid, device="cuda", dtype=torch.float16) + dynamic_shapes = model.get_dynamic_shapes() + gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True) + + # Verify the graph contains torch_rope_with_qk_interleaving before optimization + assert any( + is_op(n, torch.ops.auto_deploy.torch_rope_with_qk_interleaving) for n in gm.graph.nodes + ), "Expected torch_rope_with_qk_interleaving in graph before optimization" + + gm_transformed = InferenceOptimizer( + None, + { + "optimize_rope": { + "stage": "pattern_matcher", + }, + }, + )(None, gm) + gm_transformed.to("cuda") + + def checker(gm): + has_triton = any( + is_op(n, torch.ops.auto_deploy.triton_rope_on_interleaved_qk_inputs) + for n in gm.graph.nodes + ) + no_torch_rope = not any( + is_op(n, torch.ops.auto_deploy.torch_rope_with_qk_interleaving) for n in gm.graph.nodes + ) + return has_triton and no_torch_rope + + run_test_transformed_gm( + model, + x, + gm_transformed, + checker, + lambda num_p: num_p, + 1e-2, # atol + 1e-2, # rtol + True, # test_load_hook + True, # strict_loading + dynamic_shapes, # dynamic_shapes + None, # check_num_matches + False, # skip_output_assert + )
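
---

Reviewer note (not part of the patch): a minimal sketch of how the new fused op is expected to be called relative to the unfused reference path, mirroring `test_triton_rope_on_interleaved_qk_inputs` above. The import, cache construction, and shapes are taken from the test file; the CUDA device and an installed `tensorrt_llm` package are assumptions.

```python
# Sketch only: compares the fused Triton op added in this patch against the
# unfused reference path (cache gather + torch_rope_with_qk_interleaving).
import torch

# Importing the module registers the custom op, as done in the unit test.
from tensorrt_llm._torch.auto_deploy.custom_ops.rope import triton_rope  # noqa: F401

B, S, H_Q, H_K, D, MAX_SEQ_LEN = 2, 16, 20, 1, 128, 1024  # GLM-4 style MQA case from the test
device, dtype = "cuda", torch.bfloat16

# Q/K arrive in DeepSeek-style interleaved layout [B, S, H, D]: [r0, i0, r1, i1, ...]
q = torch.randn(B, S, H_Q, D, device=device, dtype=dtype)
k = torch.randn(B, S, H_K, D, device=device, dtype=dtype)

# Precompute the cos/sin cache over all positions, [MAX_SEQ_LEN, D].
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
freqs = torch.arange(MAX_SEQ_LEN, dtype=torch.float32)[:, None] * inv_freq[None, :]
emb = torch.cat((freqs, freqs), dim=-1)
cos_cache = emb.cos().to(device, dtype)
sin_cache = emb.sin().to(device, dtype)

position_ids = torch.arange(S, device=device)[None, :].expand(B, -1)  # [B, S]

# Unfused reference: gather cos/sin by position, then apply the torch reference op.
q_ref, k_ref = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(
    q, k, cos_cache[position_ids], sin_cache[position_ids], unsqueeze_dim=2
)

# Fused Triton op: position lookup + de-interleave + RoPE in a single kernel.
q_out, k_out = torch.ops.auto_deploy.triton_rope_on_interleaved_qk_inputs(
    q, k, cos_cache, sin_cache, position_ids
)

# Tolerances follow the unit test (bf16).
assert torch.allclose(q_ref, q_out, atol=1e-2, rtol=1e-2)
assert torch.allclose(k_ref, k_out, atol=1e-2, rtol=1e-2)
```

This is also the shape of the rewrite that `_optimize_interleaved` performs on the FX graph: it traces the `cos`/`sin` inputs back through `aten.index.Tensor` to recover the caches and `position_ids`, then passes those directly to the fused kernel instead of keeping the gather in the graph.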