From 9ae705af1ba798d053c5cb1a556f6a6c3035a5aa Mon Sep 17 00:00:00 2001
From: Bo Li <22713281+bobboli@users.noreply.github.com>
Date: Fri, 23 May 2025 15:31:04 +0800
Subject: [PATCH] perf: Add fused q_norm/k_norm/RoPE for Qwen3. (#4482)

* Add Julien's original kernel.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Get rid of UpdateKVCache functionality.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Add kernels.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Add torch OP.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Update cmake.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Torch OP must use double as argument dtype.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Add unittest.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Add unittest.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Fix misaligned access when head_dim=64. In this case, numElemsPerThread=2,
  numVecPerThread=0. But the store code incorrectly performs a vectorized
  store, so some threads (e.g., lane 1) issue stores to addresses that are
  not aligned to 64 bits.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Remove unroll (the compiler can do that). Clean up code.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Add switch for interleave.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Refactor vectorized load/store.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Implement is_neox. Result not correct yet.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Fix is_neox=True.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

* Add q_weight and k_weight.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

---------

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
---
 .../kernels/fusedQKNormRopeKernel.cu        | 270 ++++++++++++++++++
 .../kernels/fusedQKNormRopeKernel.h         |  44 +++
 cpp/tensorrt_llm/thop/CMakeLists.txt        |   1 +
 cpp/tensorrt_llm/thop/fusedQKNormRopeOp.cpp |  89 ++++++
 .../_torch/thop/test_fused_qk_norm_rope.py  | 162 +++++++++++
 5 files changed, 566 insertions(+)
 create mode 100644 cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
 create mode 100644 cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.h
 create mode 100644 cpp/tensorrt_llm/thop/fusedQKNormRopeOp.cpp
 create mode 100644 tests/unittest/_torch/thop/test_fused_qk_norm_rope.py

diff --git a/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu b/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
new file mode 100644
index 0000000000..9ce057ee7b
--- /dev/null
+++ b/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
@@ -0,0 +1,270 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fusedQKNormRopeKernel.h"
+#include "tensorrt_llm/common/cudaUtils.h"
+#include "tensorrt_llm/common/mathUtils.h"
+#include "tensorrt_llm/common/reduceKernelUtils.cuh"
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <cmath>
+#include <cstdint>
+
+namespace tensorrt_llm::common
+{
+// Specialization for packed_as used in this kernel.
+template <>
+struct packed_as<uint, 1>
+{
+    using type = uint;
+};
+
+template <>
+struct packed_as<uint, 2>
+{
+    using type = uint2;
+};
+
+template <>
+struct packed_as<uint, 4>
+{
+    using type = uint4;
+};
+} // namespace tensorrt_llm::common
+
+namespace tensorrt_llm::kernels
+{
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Perform per-head QK Norm and RoPE in a single kernel.
+// head_dim: the dimension of each head.
+// interleave: interleave=!is_neox.
+template <int head_dim, bool interleave>
+__global__ void fusedQKNormRopeKernel(
+    __nv_bfloat16* qkv,            // Combined QKV tensor [num_tokens, (num_heads_q+num_heads_k+num_heads_v)*head_dim]
+    int const num_heads_q,         // Number of query heads
+    int const num_heads_k,         // Number of key heads
+    int const num_heads_v,         // Number of value heads
+    float const eps,               // Epsilon for RMS normalization
+    __nv_bfloat16 const* q_weight, // RMSNorm weights for query
+    __nv_bfloat16 const* k_weight, // RMSNorm weights for key
+    float const base,              // Base for RoPE computation
+    int const* position_ids,       // Position IDs for RoPE
+    int const num_tokens           // Number of tokens
+)
+{
+    int const warpsPerBlock = blockDim.x / 32;
+    int const warpId = threadIdx.x / 32;
+    int const laneId = threadIdx.x % 32;
+
+    // Calculate global warp index to determine which head/token this warp processes
+    int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId;
+
+    // Total number of attention heads (Q and K)
+    int const total_qk_heads = num_heads_q + num_heads_k;
+
+    // Determine which token and head type (Q or K) this warp processes
+    int const tokenIdx = globalWarpIdx / total_qk_heads;
+    int const localHeadIdx = globalWarpIdx % total_qk_heads;
+
+    // Skip if this warp is assigned beyond the number of tokens
+    if (tokenIdx >= num_tokens)
+        return;
+
+    bool const isQ = localHeadIdx < num_heads_q;
+    int const headIdx = isQ ? localHeadIdx : localHeadIdx - num_heads_q;
+
+    int const num_heads = num_heads_q + num_heads_k + num_heads_v;
+
+    static_assert(head_dim % (32 * 2) == 0,
+        "head_dim must be divisible by 64 (each warp processes one head, and each thread gets an even number of "
+        "elements)");
+    constexpr int numElemsPerThread = head_dim / 32;
+    float elements[numElemsPerThread];
+    constexpr int elemSizeBytes = numElemsPerThread * sizeof(__nv_bfloat16);
+    static_assert(elemSizeBytes % 4 == 0, "elemSizeBytes must be a multiple of 4");
+    constexpr int vecSize = elemSizeBytes / 4; // Use packed_as to perform loading/saving.
+    using vec_T = typename tensorrt_llm::common::packed_as<uint, vecSize>::type;
+
+    int offsetWarp; // Offset for the warp
+    if (isQ)
+    {
+        // Q segment: token offset + head offset within Q segment
+        offsetWarp = tokenIdx * num_heads * head_dim + headIdx * head_dim;
+    }
+    else
+    {
+        // K segment: token offset + entire Q segment + head offset within K segment
+        offsetWarp = tokenIdx * num_heads * head_dim + num_heads_q * head_dim + headIdx * head_dim;
+    }
+    int offsetThread = offsetWarp + laneId * numElemsPerThread;
+
+    // Sum of squares for RMSNorm
+    float sumOfSquares = 0.0f;
+
+    // Load.
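+    // Vectorized load: each thread reads its numElemsPerThread contiguous bf16 values with a single
+    // vec_T access, unpacking them to float while accumulating the per-head sum of squares for RMSNorm.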
+    {
+        vec_T vec = *reinterpret_cast<vec_T*>(&qkv[offsetThread]);
+        for (int i = 0; i < vecSize; i++)
+        {
+            float2 vals = __bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(reinterpret_cast<uint*>(&vec) + i));
+            sumOfSquares += vals.x * vals.x;
+            sumOfSquares += vals.y * vals.y;
+
+            elements[2 * i] = vals.x;
+            elements[2 * i + 1] = vals.y;
+        }
+    }
+
+    // Reduce sum across warp using the utility function
+    sumOfSquares = tensorrt_llm::common::warpReduceSum(sumOfSquares);
+
+    // Compute RMS normalization factor
+    float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);
+
+    // Normalize elements
+    for (int i = 0; i < numElemsPerThread; i++)
+    {
+        int dim = laneId * numElemsPerThread + i;
+        float weight = isQ ? __bfloat162float(q_weight[dim]) : __bfloat162float(k_weight[dim]);
+        elements[i] *= rms_rcp * weight;
+    }
+
+    // Apply RoPE to normalized elements
+    float elements2[numElemsPerThread]; // Additional buffer required for RoPE.
+    float cos_vals[numElemsPerThread];
+    float sin_vals[numElemsPerThread];
+
+    float pos_id = static_cast<float>(position_ids[tokenIdx]);
+
+    // TODO: cos sin calculation could be halved.
+    if constexpr (interleave)
+    {
+        // Perform interleaving. Fill cos_vals and sin_vals.
+        for (int i = 0; i < numElemsPerThread; i++)
+        {
+            if (i % 2 == 0)
+            {
+                elements2[i] = -elements[i + 1];
+            }
+            else
+            {
+                elements2[i] = elements[i - 1];
+            }
+
+            int dim_idx = laneId * numElemsPerThread + i;
+            int half_dim = dim_idx / 2;
+            float freq = powf(base, -2.0f * half_dim / static_cast<float>(head_dim));
+            float theta = pos_id * freq;
+            __sincosf(theta, &sin_vals[i], &cos_vals[i]);
+        }
+    }
+    else
+    {
+        // Before data exchange within the warp, we need to sync.
+        __syncwarp();
+        // Get the data from the other half of the warp. Fill cos_vals and sin_vals.
+        for (int i = 0; i < numElemsPerThread; i++)
+        {
+            elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16);
+            if (laneId < 16)
+            {
+                elements2[i] = -elements2[i];
+            }
+
+            int dim_idx = laneId * numElemsPerThread + i;
+            dim_idx = (dim_idx * 2) % head_dim;
+            int half_dim = dim_idx / 2;
+            float freq = powf(base, -2.0f * half_dim / static_cast<float>(head_dim));
+            float theta = pos_id * freq;
+            __sincosf(theta, &sin_vals[i], &cos_vals[i]);
+        }
+        // __shfl_xor_sync does not provide a memory fence. Need to sync again.
+        __syncwarp();
+    }
+
+    for (int i = 0; i < numElemsPerThread; i++)
+    {
+        elements[i] = elements[i] * cos_vals[i] + elements2[i] * sin_vals[i];
+    }
+
+    // Store.
+    {
+        vec_T vec;
+        for (int i = 0; i < vecSize; i++)
+        {
+            __nv_bfloat162 vals = __float22bfloat162_rn(make_float2(elements[2 * i], elements[2 * i + 1]));
+            reinterpret_cast<__nv_bfloat162&>(*(reinterpret_cast<uint*>(&vec) + i)) = vals;
+        }
+        vec_T* outputPtr = reinterpret_cast<vec_T*>(&qkv[offsetThread]);
+        *outputPtr = vec;
+    }
+}
+
+// Borrowed from
+// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
+#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...)                                                               \
+    if (interleave)                                                                                                    \
+    {                                                                                                                  \
+        const bool INTERLEAVE = true;                                                                                  \
+        __VA_ARGS__                                                                                                    \
+    }                                                                                                                  \
+    else                                                                                                               \
+    {                                                                                                                  \
+        const bool INTERLEAVE = false;                                                                                 \
+        __VA_ARGS__                                                                                                    \
+    }
+
+void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_q, int const num_heads_k,
+    int const num_heads_v, int const head_dim, float const eps, void const* q_weight, void const* k_weight,
+    float const base, bool const interleave, int const* position_ids, cudaStream_t stream)
+{
+    constexpr int blockSize = 256;
+
+    int const warpsPerBlock = blockSize / 32;
+    int const totalQKHeads = num_heads_q + num_heads_k;
+    int const totalWarps = num_tokens * totalQKHeads;
+
+    int const gridSize = common::divUp(totalWarps, warpsPerBlock);
+    dim3 gridDim(gridSize);
+    dim3 blockDim(blockSize);
+
+    // Head dimensions should be a multiple of 64.
+    // Add more cases as needed.
+    switch (head_dim)
+    {
+    case 64:
+        DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
+            fusedQKNormRopeKernel<64, INTERLEAVE>
+                <<<gridDim, blockDim, 0, stream>>>(reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k,
+                    num_heads_v, eps, reinterpret_cast<__nv_bfloat16 const*>(q_weight),
+                    reinterpret_cast<__nv_bfloat16 const*>(k_weight), base, position_ids, num_tokens);
+        });
+        break;
+    case 128:
+        DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
+            fusedQKNormRopeKernel<128, INTERLEAVE>
+                <<<gridDim, blockDim, 0, stream>>>(reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k,
+                    num_heads_v, eps, reinterpret_cast<__nv_bfloat16 const*>(q_weight),
+                    reinterpret_cast<__nv_bfloat16 const*>(k_weight), base, position_ids, num_tokens);
+        });
+        break;
+    default: TLLM_THROW("Unsupported head dimension for fusedQKNormRope: %d", head_dim);
+    }
+}
+} // namespace tensorrt_llm::kernels
diff --git a/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.h b/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.h
new file mode 100644
index 0000000000..09146a8d03
--- /dev/null
+++ b/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <cuda_runtime.h>
+
+namespace tensorrt_llm
+{
+namespace kernels
+{
+
+// Perform fused QK Normalization and RoPE in a single CUDA kernel.
+// This function efficiently applies RMS normalization and RoPE embeddings to query and key tensors.
+void launchFusedQKNormRope(
+    void* qkv,               // Combined QKV tensor [num_tokens, (num_heads_q+num_heads_k+num_heads_v)*head_dim]
+    int const num_tokens,    // Number of tokens
+    int const num_heads_q,   // Number of query heads
+    int const num_heads_k,   // Number of key heads
+    int const num_heads_v,   // Number of value heads
+    int const head_dim,      // Dimension per head
+    float const eps,         // Epsilon for RMS normalization
+    void const* q_weight,    // RMSNorm weights for query [head_dim]
+    void const* k_weight,    // RMSNorm weights for key [head_dim]
+    float const base,        // Base for RoPE computation
+    bool const interleave,   // Whether RoPE is applied in interleave mode (non-Neox style)
+    int const* position_ids, // Position IDs for RoPE [num_tokens]
+    cudaStream_t stream);    // CUDA stream
+
+} // namespace kernels
+} // namespace tensorrt_llm
diff --git a/cpp/tensorrt_llm/thop/CMakeLists.txt b/cpp/tensorrt_llm/thop/CMakeLists.txt
index a3a8a134db..62f4fac2de 100644
--- a/cpp/tensorrt_llm/thop/CMakeLists.txt
+++ b/cpp/tensorrt_llm/thop/CMakeLists.txt
@@ -55,6 +55,7 @@ add_library(
   fp4BatchedQuantize.cpp
   fp8BlockScalingGemm.cpp
   fp8Quantize.cpp
+  fusedQKNormRopeOp.cpp
   fusedTopkSoftmax.cpp
   gatherTreeOp.cpp
   groupRmsNormOp.cpp
diff --git a/cpp/tensorrt_llm/thop/fusedQKNormRopeOp.cpp b/cpp/tensorrt_llm/thop/fusedQKNormRopeOp.cpp
new file mode 100644
index 0000000000..0692ee57a7
--- /dev/null
+++ b/cpp/tensorrt_llm/thop/fusedQKNormRopeOp.cpp
@@ -0,0 +1,89 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorrt_llm/kernels/fusedQKNormRopeKernel.h"
+#include "tensorrt_llm/thop/thUtils.h"
+
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda_bf16.h>
+
+namespace torch_ext
+{
+
+// Function for fused QK Norm and RoPE
+// This operator applies RMS normalization and RoPE to Q and K tensors in a single CUDA kernel.
+// The OP performs operations in-place on the input qkv tensor.
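+// Layout of qkv (per token): the first num_heads_q*head_dim elements hold Q, followed by
+// num_heads_k*head_dim elements for K and num_heads_v*head_dim elements for V; only Q and K are modified.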
+void fused_qk_norm_rope(
+    torch::Tensor& qkv,          // Combined QKV tensor [num_tokens, (num_heads_q+num_heads_k+num_heads_v)*head_dim]
+    int64_t num_heads_q,         // Number of query heads
+    int64_t num_heads_k,         // Number of key heads
+    int64_t num_heads_v,         // Number of value heads
+    int64_t head_dim,            // Dimension per head
+    double eps,                  // Epsilon for RMS normalization
+    torch::Tensor& q_weight,     // RMSNorm weights for query [head_dim]
+    torch::Tensor& k_weight,     // RMSNorm weights for key [head_dim]
+    double base,                 // Base for RoPE computation
+    bool is_neox,                // Whether RoPE is applied in Neox style
+    torch::Tensor& position_ids  // Position IDs for RoPE [num_tokens]
+)
+{
+    // Input validation
+    TORCH_CHECK(qkv.dim() == 2, "QKV tensor must be 2D: [num_tokens, (num_heads_q+num_heads_k+num_heads_v)*head_dim]");
+    TORCH_CHECK(position_ids.dim() == 1, "Position IDs must be 1D: [num_tokens]");
+    TORCH_CHECK(q_weight.dim() == 1, "Query weights must be 1D: [head_dim]");
+    TORCH_CHECK(k_weight.dim() == 1, "Key weights must be 1D: [head_dim]");
+    TORCH_CHECK(q_weight.size(0) == head_dim, "Query weights size must match head dimension");
+    TORCH_CHECK(k_weight.size(0) == head_dim, "Key weights size must match head dimension");
+
+    CHECK_INPUT(qkv, torch::kBFloat16);
+    CHECK_INPUT(position_ids, torch::kInt32);
+    CHECK_INPUT(q_weight, torch::kBFloat16);
+    CHECK_INPUT(k_weight, torch::kBFloat16);
+
+    int64_t num_tokens = qkv.size(0);
+    TORCH_CHECK(position_ids.size(0) == num_tokens, "Number of tokens in position_ids must match QKV");
+
+    int64_t total_heads = num_heads_q + num_heads_k + num_heads_v;
+    TORCH_CHECK(
+        qkv.size(1) == total_heads * head_dim, "QKV tensor size must match total number of heads and head dimension");
+
+    auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device());
+
+    tensorrt_llm::kernels::launchFusedQKNormRope(reinterpret_cast<__nv_bfloat16*>(qkv.data_ptr()),
+        static_cast<int>(num_tokens), static_cast<int>(num_heads_q), static_cast<int>(num_heads_k),
+        static_cast<int>(num_heads_v), static_cast<int>(head_dim), static_cast<float>(eps),
+        reinterpret_cast<__nv_bfloat16*>(q_weight.data_ptr()), reinterpret_cast<__nv_bfloat16*>(k_weight.data_ptr()),
+        static_cast<float>(base),
+        !is_neox, // interleave
+        reinterpret_cast<int const*>(position_ids.data_ptr()), stream);
+}
+
+// Register the PyTorch operators
+TORCH_LIBRARY_FRAGMENT(trtllm, m)
+{
+    m.def(
+        "fused_qk_norm_rope(Tensor qkv, int num_heads_q, int num_heads_k, int num_heads_v, int head_dim, float eps, "
+        "Tensor q_weight, Tensor k_weight, float base, bool is_neox, Tensor position_ids) -> ()",
+        &fused_qk_norm_rope);
+}
+
+// Register the CUDA implementation
+TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
+{
+    m.impl("fused_qk_norm_rope", &fused_qk_norm_rope);
+}
+
+} // namespace torch_ext
diff --git a/tests/unittest/_torch/thop/test_fused_qk_norm_rope.py b/tests/unittest/_torch/thop/test_fused_qk_norm_rope.py
new file mode 100644
index 0000000000..ad76e9705e
--- /dev/null
+++ b/tests/unittest/_torch/thop/test_fused_qk_norm_rope.py
@@ -0,0 +1,162 @@
+import pytest
+import torch
+
+from tensorrt_llm._torch.attention_backend.interface import RopeParams
+from tensorrt_llm._torch.modules.rms_norm import RMSNorm
+from tensorrt_llm._torch.modules.rotary_embedding import RotaryEmbedding
+
+
+@torch.inference_mode()
+def torch_ref_rms_norm_rope(qkv, num_heads_q, num_heads_k, num_heads_v,
+                            head_dim, eps, q_weight, k_weight, base, is_neox,
+                            position_ids):
+    """
+    PyTorch reference implementation of RMSNorm+RoPE for verification.
+
+    Uses TensorRT-LLM's own RMSNorm and RotaryEmbedding modules to ensure consistency
+    with the expected behavior of the fused kernel.
+
+    Args:
+        qkv: Combined QKV tensor of shape [num_tokens, hidden_size]
+        num_heads_q: Number of query heads
+        num_heads_k: Number of key heads
+        num_heads_v: Number of value heads (unused for normalization/RoPE but needed for tensor splitting)
+        head_dim: Dimension of each head
+        eps: Epsilon value for RMS normalization
+        q_weight: RMSNorm weights for query [head_dim]
+        k_weight: RMSNorm weights for key [head_dim]
+        base: Base value for RoPE calculations
+        is_neox: Whether to use NeoX style RoPE
+        position_ids: Position IDs for RoPE of shape [num_tokens]
+
+    Returns:
+        Combined tensor with Q and K parts normalized and RoPE applied
+    """
+    # Get input shape information
+    num_tokens = qkv.shape[0]
+    hidden_size = qkv.shape[1]
+
+    # Calculate dimensions for Q, K, V segments
+    q_size = num_heads_q * head_dim
+    k_size = num_heads_k * head_dim
+    v_size = num_heads_v * head_dim
+
+    # Verify dimensions match
+    assert hidden_size == q_size + k_size + v_size, f"Hidden size {hidden_size} doesn't match Q+K+V dimensions {q_size + k_size + v_size}"
+
+    # Split the tensor into Q, K, V parts
+    q = qkv[:, :q_size]
+    k = qkv[:, q_size:q_size + k_size]
+    v = qkv[:, q_size + k_size:]
+
+    # Create and apply RMSNorm modules with custom weights
+    q_norm = RMSNorm(hidden_size=head_dim, eps=eps).to(qkv.device).to(qkv.dtype)
+    k_norm = RMSNorm(hidden_size=head_dim, eps=eps).to(qkv.device).to(qkv.dtype)
+
+    # Set the weights to the provided weights
+    q_norm.weight.data.copy_(q_weight)
+    k_norm.weight.data.copy_(k_weight)
+
+    # Apply RMSNorm to Q and K
+    q_normalized = q_norm(q.reshape(num_tokens * num_heads_q,
+                                    head_dim)).reshape(num_tokens, q_size)
+    k_normalized = k_norm(k.reshape(num_tokens * num_heads_k,
+                                    head_dim)).reshape(num_tokens, k_size)
+
+    # Create and apply RotaryEmbedding module
+    rope_params = RopeParams(
+        dim=head_dim,  # Set the rotary dimension to match the head dimension
+        theta=base,  # Base value for RoPE calculations
+        max_positions=8192  # Large enough for the position IDs used in these tests
+    )
+    rotary_emb = RotaryEmbedding(rope_params=rope_params,
+                                 head_dim=head_dim,
+                                 is_neox=is_neox).to(qkv.device)
+
+    # Apply RoPE to the normalized Q and K
+    [q_rope, k_rope] = rotary_emb(position_ids, [q_normalized, k_normalized])
+
+    # Combine Q, K, V back together
+    result = torch.cat([q_rope, k_rope, v], dim=1)
+
+    return result
+
+
+head_dims = [64, 128]
+# (Q heads, K heads, V heads)
+num_heads_groups = [
+    (16, 8, 8),  # Qwen3-0.6B, Qwen3-1.7B
+    (32, 8, 8),  # Qwen3-4B, Qwen3-8B, Qwen3-30B-A3B
+    (40, 8, 8),  # Qwen3-14B
+    (64, 8, 8)  # Qwen3-32B, Qwen3-235B-A22B
+]
+num_tokens_list = [1, 3, 8, 32, 256]
+is_neox_list = [False, True]
+dtypes = [torch.bfloat16]  # TODO: support float16
+
+
+@pytest.mark.parametrize("head_dim", head_dims)
+@pytest.mark.parametrize("num_heads_group", num_heads_groups)
+@pytest.mark.parametrize("num_tokens", num_tokens_list)
+@pytest.mark.parametrize("is_neox", is_neox_list)
+@pytest.mark.parametrize("dtype", dtypes)
+def test_fused_qk_norm_rope(head_dim, num_heads_group, num_tokens, is_neox,
+                            dtype):
+    """
+    Test the fused QK RMSNorm + RoPE operation with various configurations.
+
+    This test verifies that the fused kernel correctly applies:
+    1. RMSNorm to both query (Q) and key (K) portions of the QKV tensor
+    2. Rotary Position Embeddings (RoPE) to the normalized Q and K
+    3. Leaves the value (V) portion unchanged
+
+    Args:
+        head_dim: Dimension of each attention head
+        num_heads_group: Tuple of (num_heads_q, num_heads_k, num_heads_v)
+        num_tokens: Number of tokens to process
+        is_neox: Whether to use NeoX-style RoPE
+        dtype: Data type (currently bfloat16 only)
+    """
+    device = "cuda"
+    torch_dtype = dtype
+
+    # Unpack head counts
+    num_heads_q, num_heads_k, num_heads_v = num_heads_group
+
+    # Calculate total hidden dimension
+    hidden_size = (num_heads_q + num_heads_k + num_heads_v) * head_dim
+
+    # Generate random inputs directly as 2D [num_tokens, hidden_size]
+    torch.random.manual_seed(0)
+    qkv = torch.randn(num_tokens, hidden_size, dtype=torch_dtype, device=device)
+    qkv_copy = qkv.clone()
+
+    # Generate position IDs with +100 offset to test decoding scenarios
+    position_ids = torch.arange(num_tokens, dtype=torch.int32,
+                                device=device) + 100
+
+    # Generate random weights for RMSNorm
+    q_weight = torch.randn(head_dim, dtype=torch_dtype, device=device) * 5.0
+    k_weight = torch.randn(head_dim, dtype=torch_dtype, device=device) * 5.0
+
+    # Set RMSNorm and RoPE parameters
+    eps = 1e-5
+    base = 10000.0
+
+    # Run the custom fusedQKNormRope operation
+    torch.ops.trtllm.fused_qk_norm_rope(qkv, num_heads_q, num_heads_k,
+                                        num_heads_v, head_dim, eps, q_weight,
+                                        k_weight, base, is_neox, position_ids)
+    output = qkv  # This op is inplace
+
+    # Compute reference output using TensorRT-LLM modules
+    ref_output = torch_ref_rms_norm_rope(qkv_copy, num_heads_q, num_heads_k,
+                                         num_heads_v, head_dim, eps, q_weight,
+                                         k_weight, base, is_neox, position_ids)
+
+    # Compare outputs from custom kernel vs reference implementation
+    torch.testing.assert_close(
+        output,
+        ref_output,
+        rtol=5e-2,
+        atol=1e-1,
+    )
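
For reference while reviewing: the snippet below is a minimal, self-contained PyTorch sketch of the
per-head math the fused kernel performs (RMSNorm over head_dim, then RoPE on the normalized values),
for both the interleaved (is_neox=False) and NeoX (is_neox=True) layouts. The helper name
`rms_norm_rope_ref` and its defaults are illustrative only; they are not part of this patch or of the
TensorRT-LLM API.

    # Illustrative sketch only; not part of the patch. It computes in float32 and returns a new
    # tensor, whereas the fused kernel updates the bf16 qkv buffer in place.
    import torch

    def rms_norm_rope_ref(x, weight, pos_ids, base=10000.0, eps=1e-6, interleave=True):
        # x: [num_tokens, head_dim], weight: [head_dim], pos_ids: [num_tokens] (integer positions)
        x = x.float()
        # RMSNorm over the head dimension, followed by the per-channel weight.
        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight.float()
        head_dim = x.shape[-1]
        # One frequency per rotated pair: inv_freq[j] = base^(-2j/head_dim).
        inv_freq = base ** (-torch.arange(0, head_dim, 2, dtype=torch.float32, device=x.device) / head_dim)
        theta = pos_ids.float()[:, None] * inv_freq[None, :]  # [num_tokens, head_dim // 2]
        cos, sin = theta.cos(), theta.sin()
        out = torch.empty_like(x)
        if interleave:  # is_neox == False: channel 2j pairs with channel 2j + 1
            x1, x2 = x[..., 0::2], x[..., 1::2]
            out[..., 0::2] = x1 * cos - x2 * sin
            out[..., 1::2] = x1 * sin + x2 * cos
        else:  # is_neox == True: channel j pairs with channel j + head_dim // 2
            x1, x2 = x[..., :head_dim // 2], x[..., head_dim // 2:]
            out[..., :head_dim // 2] = x1 * cos - x2 * sin
            out[..., head_dim // 2:] = x1 * sin + x2 * cos
        return out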