perf: Add fused q_norm/k_norm/RoPE for Qwen3. (#4482)
* Add Julien's original kernel.
* Get rid of UpdateKVCache functionality.
* Add kernels.
* Add torch OP.
* Update cmake.
* Torch OP must use double as argument dtype.
* Add unittest.
* Fix misaligned access when head_dim=64. In this case, numElemsPerThread=2 and numVecPerThread=0, but the store code incorrectly performed a vectorized store, so some threads (e.g., lane 1) issued stores to addresses not aligned to 64 bits.
* Remove unroll (the compiler can do that). Clean up code.
* Add switch for interleave.
* Refactor vectorized load/store.
* Implement is_neox. Result not correct yet.
* Fix is_neox=True.
* Add q_weight and k_weight.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
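For quick orientation, a minimal usage sketch of the new in-place op (our illustration, distilled from the unit test added below; the Qwen3-8B-like head counts, eps, and base are example values):

    import torch
    import tensorrt_llm  # assumed to register the trtllm op library

    # Example Qwen3-8B-like layout: 32 Q heads, 8 K heads, 8 V heads, head_dim=128.
    num_tokens, head_dim = 4, 128
    num_heads_q, num_heads_k, num_heads_v = 32, 8, 8
    hidden_size = (num_heads_q + num_heads_k + num_heads_v) * head_dim

    qkv = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")
    q_weight = torch.ones(head_dim, dtype=torch.bfloat16, device="cuda")
    k_weight = torch.ones(head_dim, dtype=torch.bfloat16, device="cuda")
    position_ids = torch.arange(num_tokens, dtype=torch.int32, device="cuda")

    # Applies RMSNorm + RoPE to the Q and K segments of qkv in place; V is untouched.
    torch.ops.trtllm.fused_qk_norm_rope(qkv, num_heads_q, num_heads_k, num_heads_v,
                                        head_dim, 1e-6, q_weight, k_weight,
                                        10000.0, True, position_ids)  # is_neox=True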
parent: 6527c055cf
commit: 9ae705af1b
270	cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu	Normal file
@@ -0,0 +1,270 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "fusedQKNormRopeKernel.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/mathUtils.h"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include <cmath>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>

namespace tensorrt_llm::common
{
// Specializations of packed_as used in this kernel.
template <>
struct packed_as<uint, 1>
{
    using type = uint;
};

template <>
struct packed_as<uint, 2>
{
    using type = uint2;
};

template <>
struct packed_as<uint, 4>
{
    using type = uint4;
};
} // namespace tensorrt_llm::common

namespace tensorrt_llm::kernels
{

////////////////////////////////////////////////////////////////////////////////////////////////////

// Perform per-head QK Norm and RoPE in a single kernel.
// head_dim: the dimension of each head.
// interleave: interleave = !is_neox.
template <int head_dim, bool interleave>
__global__ void fusedQKNormRopeKernel(
    __nv_bfloat16* qkv,            // Combined QKV tensor [num_tokens, (num_heads_q+num_heads_k+num_heads_v)*head_dim]
    int const num_heads_q,         // Number of query heads
    int const num_heads_k,         // Number of key heads
    int const num_heads_v,         // Number of value heads
    float const eps,               // Epsilon for RMS normalization
    __nv_bfloat16 const* q_weight, // RMSNorm weights for query
    __nv_bfloat16 const* k_weight, // RMSNorm weights for key
    float const base,              // Base for RoPE computation
    int const* position_ids,       // Position IDs for RoPE
    int const num_tokens           // Number of tokens
)
{
    int const warpsPerBlock = blockDim.x / 32;
    int const warpId = threadIdx.x / 32;
    int const laneId = threadIdx.x % 32;

    // Calculate the global warp index to determine which head/token this warp processes.
    int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId;

    // Total number of attention heads (Q and K).
    int const total_qk_heads = num_heads_q + num_heads_k;

    // Determine which token and head type (Q or K) this warp processes.
    int const tokenIdx = globalWarpIdx / total_qk_heads;
    int const localHeadIdx = globalWarpIdx % total_qk_heads;

    // Skip if this warp is assigned beyond the number of tokens.
    if (tokenIdx >= num_tokens)
        return;

    bool const isQ = localHeadIdx < num_heads_q;
    int const headIdx = isQ ? localHeadIdx : localHeadIdx - num_heads_q;

    int const num_heads = num_heads_q + num_heads_k + num_heads_v;

    static_assert(head_dim % (32 * 2) == 0,
        "head_dim must be divisible by 64 (each warp processes one head, and each thread gets an even number of "
        "elements)");
    constexpr int numElemsPerThread = head_dim / 32;
    float elements[numElemsPerThread];
    constexpr int elemSizeBytes = numElemsPerThread * sizeof(__nv_bfloat16);
    static_assert(elemSizeBytes % 4 == 0, "elemSizeBytes must be a multiple of 4");
    constexpr int vecSize = elemSizeBytes / 4; // Use packed_as<uint, vecSize> to perform loads/stores.
    using vec_T = typename tensorrt_llm::common::packed_as<uint, vecSize>::type;

    int offsetWarp; // Offset for the warp
    if (isQ)
    {
        // Q segment: token offset + head offset within Q segment.
        offsetWarp = tokenIdx * num_heads * head_dim + headIdx * head_dim;
    }
    else
    {
        // K segment: token offset + entire Q segment + head offset within K segment.
        offsetWarp = tokenIdx * num_heads * head_dim + num_heads_q * head_dim + headIdx * head_dim;
    }
    int offsetThread = offsetWarp + laneId * numElemsPerThread;

    // Sum of squares for RMSNorm.
    float sumOfSquares = 0.0f;

    // Load.
    {
        vec_T vec = *reinterpret_cast<vec_T const*>(&qkv[offsetThread]);
        for (int i = 0; i < vecSize; i++)
        {
            float2 vals = __bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(reinterpret_cast<uint*>(&vec) + i));
            sumOfSquares += vals.x * vals.x;
            sumOfSquares += vals.y * vals.y;

            elements[2 * i] = vals.x;
            elements[2 * i + 1] = vals.y;
        }
    }

    // Reduce the sum across the warp using the utility function.
    sumOfSquares = tensorrt_llm::common::warpReduceSum(sumOfSquares);

    // Compute the RMS normalization factor.
    float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);

    // Normalize elements.
    for (int i = 0; i < numElemsPerThread; i++)
    {
        int dim = laneId * numElemsPerThread + i;
        float weight = isQ ? __bfloat162float(q_weight[dim]) : __bfloat162float(k_weight[dim]);
        elements[i] *= rms_rcp * weight;
    }

    // Apply RoPE to the normalized elements.
    float elements2[numElemsPerThread]; // Additional buffer required for RoPE.
    float cos_vals[numElemsPerThread];
    float sin_vals[numElemsPerThread];

    float pos_id = static_cast<float>(position_ids[tokenIdx]);

    // TODO: The cos/sin calculation could be halved.
    if constexpr (interleave)
    {
        // Perform interleaving. Fill cos_vals and sin_vals.
        for (int i = 0; i < numElemsPerThread; i++)
        {
            if (i % 2 == 0)
            {
                elements2[i] = -elements[i + 1];
            }
            else
            {
                elements2[i] = elements[i - 1];
            }

            int dim_idx = laneId * numElemsPerThread + i;
            int half_dim = dim_idx / 2;
            float freq = powf(base, -2.0f * half_dim / static_cast<float>(head_dim));
            float theta = pos_id * freq;
            __sincosf(theta, &sin_vals[i], &cos_vals[i]);
        }
    }
    else
    {
        // Before exchanging data within the warp, we need to sync.
        __syncwarp();
        // Get the data from the other half of the warp. Fill cos_vals and sin_vals.
        for (int i = 0; i < numElemsPerThread; i++)
        {
            elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16);
            if (laneId < 16)
            {
                elements2[i] = -elements2[i];
            }

            int dim_idx = laneId * numElemsPerThread + i;
            dim_idx = (dim_idx * 2) % head_dim;
            int half_dim = dim_idx / 2;
            float freq = powf(base, -2.0f * half_dim / static_cast<float>(head_dim));
            float theta = pos_id * freq;
            __sincosf(theta, &sin_vals[i], &cos_vals[i]);
        }
        // __shfl_xor_sync does not provide a memory fence. Need to sync again.
        __syncwarp();
    }

    for (int i = 0; i < numElemsPerThread; i++)
    {
        elements[i] = elements[i] * cos_vals[i] + elements2[i] * sin_vals[i];
    }

    // Store.
    {
        vec_T vec;
        for (int i = 0; i < vecSize; i++)
        {
            __nv_bfloat162 vals = __float22bfloat162_rn(make_float2(elements[2 * i], elements[2 * i + 1]));
            reinterpret_cast<__nv_bfloat162&>(*(reinterpret_cast<uint*>(&vec) + i)) = vals;
        }
        vec_T* outputPtr = reinterpret_cast<vec_T*>(&qkv[offsetThread]);
        *outputPtr = vec;
    }
}

// Borrowed from
// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...)                                                               \
    if (interleave)                                                                                                    \
    {                                                                                                                  \
        const bool INTERLEAVE = true;                                                                                  \
        __VA_ARGS__                                                                                                    \
    }                                                                                                                  \
    else                                                                                                               \
    {                                                                                                                  \
        const bool INTERLEAVE = false;                                                                                 \
        __VA_ARGS__                                                                                                    \
    }

void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_q, int const num_heads_k,
    int const num_heads_v, int const head_dim, float const eps, void const* q_weight, void const* k_weight,
    float const base, bool const interleave, int const* position_ids, cudaStream_t stream)
{
    constexpr int blockSize = 256;

    int const warpsPerBlock = blockSize / 32;
    int const totalQKHeads = num_heads_q + num_heads_k;
    int const totalWarps = num_tokens * totalQKHeads;

    int const gridSize = common::divUp(totalWarps, warpsPerBlock);
    dim3 gridDim(gridSize);
    dim3 blockDim(blockSize);

    // Head dimensions must be a multiple of 64.
    // Add more cases as needed.
    switch (head_dim)
    {
    case 64:
        DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
            fusedQKNormRopeKernel<64, INTERLEAVE>
                <<<gridDim, blockDim, 0, stream>>>(reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k,
                    num_heads_v, eps, reinterpret_cast<__nv_bfloat16 const*>(q_weight),
                    reinterpret_cast<__nv_bfloat16 const*>(k_weight), base, position_ids, num_tokens);
        });
        break;
    case 128:
        DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
            fusedQKNormRopeKernel<128, INTERLEAVE>
                <<<gridDim, blockDim, 0, stream>>>(reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k,
                    num_heads_v, eps, reinterpret_cast<__nv_bfloat16 const*>(q_weight),
                    reinterpret_cast<__nv_bfloat16 const*>(k_weight), base, position_ids, num_tokens);
        });
        break;
    default: TLLM_THROW("Unsupported head dimension for fusedQKNormRope: %d", head_dim);
    }
}
} // namespace tensorrt_llm::kernels
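To make the two RoPE layouts the kernel dispatches on concrete, here is a NumPy sketch (our illustration, not part of the commit) of the rotation applied to one normalized head vector; it mirrors the kernel's freq = base^(-2*half_dim/head_dim) and out = x*cos + x_pair*sin formulation:

    import numpy as np

    def rope_reference(x, pos, base, interleave):
        """Rotate one normalized head vector x[head_dim] the way the kernel does."""
        d = x.shape[0]
        out = np.empty_like(x)
        for k in range(d // 2):
            theta = pos * base ** (-2.0 * k / d)
            c, s = np.cos(theta), np.sin(theta)
            if interleave:  # interleave=True (!is_neox): pairs are (2k, 2k+1)
                out[2 * k] = x[2 * k] * c - x[2 * k + 1] * s
                out[2 * k + 1] = x[2 * k + 1] * c + x[2 * k] * s
            else:  # NeoX: pairs are (k, k + d/2), i.e. the two halves of the head
                out[k] = x[k] * c - x[k + d // 2] * s
                out[k + d // 2] = x[k + d // 2] * c + x[k] * s
        return out

In the NeoX branch the pair partner lives in the other half of the head, which is why the kernel exchanges values across lanes with __shfl_xor_sync(..., 16) instead of reading a neighboring register.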
44	cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.h	Normal file
@@ -0,0 +1,44 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cuda_runtime.h>

namespace tensorrt_llm
{
namespace kernels
{

// Perform fused QK normalization and RoPE in a single CUDA kernel.
// This function efficiently applies RMS normalization and RoPE embeddings to the query and key tensors.
void launchFusedQKNormRope(
    void* qkv,               // Combined QKV tensor [num_tokens, (num_heads_q+num_heads_k+num_heads_v)*head_dim]
    int const num_tokens,    // Number of tokens
    int const num_heads_q,   // Number of query heads
    int const num_heads_k,   // Number of key heads
    int const num_heads_v,   // Number of value heads
    int const head_dim,      // Dimension per head
    float const eps,         // Epsilon for RMS normalization
    void const* q_weight,    // RMSNorm weights for query [head_dim]
    void const* k_weight,    // RMSNorm weights for key [head_dim]
    float const base,        // Base for RoPE computation
    bool const interleave,   // Whether RoPE is applied in interleave mode (non-NeoX style)
    int const* position_ids, // Position IDs for RoPE [num_tokens]
    cudaStream_t stream);    // CUDA stream

} // namespace kernels
} // namespace tensorrt_llm
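For completeness, a host-side sketch of calling this entry point directly (a hypothetical wrapper, not part of the commit; the head counts and hyperparameters are example Qwen3-8B-like values, and all pointers are assumed to be device pointers with the layout documented above):

    #include "tensorrt_llm/kernels/fusedQKNormRopeKernel.h"
    #include <cuda_bf16.h>

    void runFusedQKNormRopeExample(__nv_bfloat16* qkvDev, __nv_bfloat16 const* qWeightDev,
        __nv_bfloat16 const* kWeightDev, int const* posIdsDev, int numTokens, cudaStream_t stream)
    {
        // Launch the fused kernel on the caller's stream; the QKV buffer is updated in place.
        tensorrt_llm::kernels::launchFusedQKNormRope(qkvDev, numTokens,
            /*num_heads_q=*/32, /*num_heads_k=*/8, /*num_heads_v=*/8,
            /*head_dim=*/128, /*eps=*/1e-6f, qWeightDev, kWeightDev,
            /*base=*/10000.f, /*interleave=*/false, posIdsDev, stream);
    }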
@@ -55,6 +55,7 @@ add_library(
   fp4BatchedQuantize.cpp
   fp8BlockScalingGemm.cpp
   fp8Quantize.cpp
+  fusedQKNormRopeOp.cpp
   fusedTopkSoftmax.cpp
   gatherTreeOp.cpp
   groupRmsNormOp.cpp
89	cpp/tensorrt_llm/thop/fusedQKNormRopeOp.cpp	Normal file
@@ -0,0 +1,89 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/kernels/fusedQKNormRopeKernel.h"
#include "tensorrt_llm/thop/thUtils.h"

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

namespace torch_ext
{

// Function for fused QK Norm and RoPE.
// This operator applies RMS normalization and RoPE to the Q and K tensors in a single CUDA kernel.
// The op performs its operations in place on the input qkv tensor.
void fused_qk_norm_rope(
    torch::Tensor& qkv,          // Combined QKV tensor [num_tokens, (num_heads_q+num_heads_k+num_heads_v)*head_dim]
    int64_t num_heads_q,         // Number of query heads
    int64_t num_heads_k,         // Number of key heads
    int64_t num_heads_v,         // Number of value heads
    int64_t head_dim,            // Dimension per head
    double eps,                  // Epsilon for RMS normalization
    torch::Tensor& q_weight,     // RMSNorm weights for query [head_dim]
    torch::Tensor& k_weight,     // RMSNorm weights for key [head_dim]
    double base,                 // Base for RoPE computation
    bool is_neox,                // Whether RoPE is applied in NeoX style
    torch::Tensor& position_ids  // Position IDs for RoPE [num_tokens]
)
{
    // Input validation.
    TORCH_CHECK(qkv.dim() == 2, "QKV tensor must be 2D: [num_tokens, (num_heads_q+num_heads_k+num_heads_v)*head_dim]");
    TORCH_CHECK(position_ids.dim() == 1, "Position IDs must be 1D: [num_tokens]");
    TORCH_CHECK(q_weight.dim() == 1, "Query weights must be 1D: [head_dim]");
    TORCH_CHECK(k_weight.dim() == 1, "Key weights must be 1D: [head_dim]");
    TORCH_CHECK(q_weight.size(0) == head_dim, "Query weights size must match head dimension");
    TORCH_CHECK(k_weight.size(0) == head_dim, "Key weights size must match head dimension");

    CHECK_INPUT(qkv, torch::kBFloat16);
    CHECK_INPUT(position_ids, torch::kInt32);
    CHECK_INPUT(q_weight, torch::kBFloat16);
    CHECK_INPUT(k_weight, torch::kBFloat16);

    int64_t num_tokens = qkv.size(0);
    TORCH_CHECK(position_ids.size(0) == num_tokens, "Number of tokens in position_ids must match QKV");

    int64_t total_heads = num_heads_q + num_heads_k + num_heads_v;
    TORCH_CHECK(
        qkv.size(1) == total_heads * head_dim, "QKV tensor size must match total number of heads and head dimension");

    auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device());

    tensorrt_llm::kernels::launchFusedQKNormRope(reinterpret_cast<__nv_bfloat16*>(qkv.data_ptr()),
        static_cast<int>(num_tokens), static_cast<int>(num_heads_q), static_cast<int>(num_heads_k),
        static_cast<int>(num_heads_v), static_cast<int>(head_dim), static_cast<float>(eps),
        reinterpret_cast<__nv_bfloat16*>(q_weight.data_ptr()), reinterpret_cast<__nv_bfloat16*>(k_weight.data_ptr()),
        static_cast<float>(base),
        !is_neox, // interleave
        reinterpret_cast<int const*>(position_ids.data_ptr()), stream);
}

// Register the PyTorch operator schema.
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "fused_qk_norm_rope(Tensor qkv, int num_heads_q, int num_heads_k, int num_heads_v, int head_dim, float eps, "
        "Tensor q_weight, Tensor k_weight, float base, bool is_neox, Tensor position_ids) -> ()",
        &fused_qk_norm_rope);
}

// Register the CUDA implementation.
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("fused_qk_norm_rope", &fused_qk_norm_rope);
}

} // namespace torch_ext
162	tests/unittest/_torch/thop/test_fused_qk_norm_rope.py	Normal file
@@ -0,0 +1,162 @@
import pytest
import torch

from tensorrt_llm._torch.attention_backend.interface import RopeParams
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm._torch.modules.rotary_embedding import RotaryEmbedding


@torch.inference_mode()
def torch_ref_rms_norm_rope(qkv, num_heads_q, num_heads_k, num_heads_v,
                            head_dim, eps, q_weight, k_weight, base, is_neox,
                            position_ids):
    """
    PyTorch reference implementation of RMSNorm+RoPE for verification.

    Uses TensorRT-LLM's own RMSNorm and RotaryEmbedding modules to ensure consistency
    with the expected behavior of the fused kernel.

    Args:
        qkv: Combined QKV tensor of shape [num_tokens, hidden_size]
        num_heads_q: Number of query heads
        num_heads_k: Number of key heads
        num_heads_v: Number of value heads (unused for normalization/RoPE but needed for tensor splitting)
        head_dim: Dimension of each head
        eps: Epsilon value for RMS normalization
        q_weight: RMSNorm weights for query [head_dim]
        k_weight: RMSNorm weights for key [head_dim]
        base: Base value for RoPE calculations
        is_neox: Whether to use NeoX-style RoPE
        position_ids: Position IDs for RoPE of shape [num_tokens]

    Returns:
        Combined tensor with the Q and K parts normalized and RoPE applied
    """
    # Get input shape information
    num_tokens = qkv.shape[0]
    hidden_size = qkv.shape[1]

    # Calculate dimensions for the Q, K, V segments
    q_size = num_heads_q * head_dim
    k_size = num_heads_k * head_dim
    v_size = num_heads_v * head_dim

    # Verify dimensions match
    assert hidden_size == q_size + k_size + v_size, f"Hidden size {hidden_size} doesn't match Q+K+V dimensions {q_size + k_size + v_size}"

    # Split the tensor into Q, K, V parts
    q = qkv[:, :q_size]
    k = qkv[:, q_size:q_size + k_size]
    v = qkv[:, q_size + k_size:]

    # Create RMSNorm modules and apply them with custom weights
    q_norm = RMSNorm(hidden_size=head_dim, eps=eps).to(qkv.device).to(qkv.dtype)
    k_norm = RMSNorm(hidden_size=head_dim, eps=eps).to(qkv.device).to(qkv.dtype)

    # Set the weights to the provided weights
    q_norm.weight.data.copy_(q_weight)
    k_norm.weight.data.copy_(k_weight)

    # Apply RMSNorm to Q and K
    q_normalized = q_norm(q.reshape(num_tokens * num_heads_q,
                                    head_dim)).reshape(num_tokens, q_size)
    k_normalized = k_norm(k.reshape(num_tokens * num_heads_k,
                                    head_dim)).reshape(num_tokens, k_size)

    # Create and apply the RotaryEmbedding module
    rope_params = RopeParams(
        dim=head_dim,  # Set the rotary dimension to match the head dimension
        theta=base,  # Base value for RoPE calculations
        max_positions=8192  # Large enough for any position ID used in this test
    )
    rotary_emb = RotaryEmbedding(rope_params=rope_params,
                                 head_dim=head_dim,
                                 is_neox=is_neox).to(qkv.device)

    # Apply RoPE to the normalized Q and K
    [q_rope, k_rope] = rotary_emb(position_ids, [q_normalized, k_normalized])

    # Combine Q, K, V back together
    result = torch.cat([q_rope, k_rope, v], dim=1)

    return result


head_dims = [64, 128]
# (Q heads, K heads, V heads)
num_heads_groups = [
    (16, 8, 8),  # Qwen3-0.6B, Qwen3-1.7B
    (32, 8, 8),  # Qwen3-4B, Qwen3-8B, Qwen3-30B-A3B
    (40, 8, 8),  # Qwen3-14B
    (64, 8, 8)  # Qwen3-32B, Qwen3-235B-A22B
]
num_tokens_list = [1, 3, 8, 32, 256]
is_neox_list = [False, True]
dtypes = [torch.bfloat16]  # TODO: support float16


@pytest.mark.parametrize("head_dim", head_dims)
@pytest.mark.parametrize("num_heads_group", num_heads_groups)
@pytest.mark.parametrize("num_tokens", num_tokens_list)
@pytest.mark.parametrize("is_neox", is_neox_list)
@pytest.mark.parametrize("dtype", dtypes)
def test_fused_qk_norm_rope(head_dim, num_heads_group, num_tokens, is_neox,
                            dtype):
    """
    Test the fused QK RMSNorm + RoPE operation with various configurations.

    This test verifies that the fused kernel correctly:
    1. Applies RMSNorm to both the query (Q) and key (K) portions of the QKV tensor
    2. Applies Rotary Position Embeddings (RoPE) to the normalized Q and K
    3. Leaves the value (V) portion unchanged

    Args:
        head_dim: Dimension of each attention head
        num_heads_group: Tuple of (num_heads_q, num_heads_k, num_heads_v)
        num_tokens: Number of tokens to process
        is_neox: Whether to use NeoX-style RoPE
        dtype: Data type (currently bfloat16; float16 is a TODO)
    """
    device = "cuda"
    torch_dtype = dtype

    # Unpack head counts
    num_heads_q, num_heads_k, num_heads_v = num_heads_group

    # Calculate the total hidden dimension
    hidden_size = (num_heads_q + num_heads_k + num_heads_v) * head_dim

    # Generate random inputs directly as 2D [num_tokens, hidden_size]
    torch.random.manual_seed(0)
    qkv = torch.randn(num_tokens, hidden_size, dtype=torch_dtype, device=device)
    qkv_copy = qkv.clone()

    # Generate position IDs with a +100 offset to test decoding scenarios
    position_ids = torch.arange(num_tokens, dtype=torch.int32,
                                device=device) + 100

    # Generate random weights for RMSNorm
    q_weight = torch.randn(head_dim, dtype=torch_dtype, device=device) * 5.0
    k_weight = torch.randn(head_dim, dtype=torch_dtype, device=device) * 5.0

    # Set the RMSNorm and RoPE parameters
    eps = 1e-5
    base = 10000.0

    # Run the custom fusedQKNormRope operation
    torch.ops.trtllm.fused_qk_norm_rope(qkv, num_heads_q, num_heads_k,
                                        num_heads_v, head_dim, eps, q_weight,
                                        k_weight, base, is_neox, position_ids)
    output = qkv  # This op is in-place

    # Compute the reference output using TensorRT-LLM modules
    ref_output = torch_ref_rms_norm_rope(qkv_copy, num_heads_q, num_heads_k,
                                         num_heads_v, head_dim, eps, q_weight,
                                         k_weight, base, is_neox, position_ids)

    # Compare outputs from the custom kernel vs the reference implementation
    torch.testing.assert_close(
        output,
        ref_output,
        rtol=5e-2,
        atol=1e-1,
    )