* Add Julien's original kernel.
* Get rid of the UpdateKVCache functionality.
* Add kernels.
* Add torch OP.
* Update cmake.
* Torch OP must use double as the argument dtype.
* Add unittest.
* Add unittest.
* Fix misaligned access when head_dim=64. In this case, numElemsPerThread=2 and numVecPerThread=0, but the store code incorrectly performed a vectorized store, so some threads (e.g., lane 1) issued stores to addresses not aligned to 64 bits (see the sketch below).
* Remove unroll (the compiler can do that). Clean up code.
* Add switch for interleave.
* Refactor vectorized load/store.
* Implement is_neox. Result not correct yet.
* Fix is_neox=True.
* Add q_weight and k_weight.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
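A minimal sketch of the alignment hazard behind the head_dim=64 fix. The function and variable names (storeSlice, elemsPerVec) and the half dtype are illustrative assumptions, not the actual kernel code; only numElemsPerThread and numVecPerThread come from the commit message:

#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Sketch: each thread writes numElemsPerThread consecutive half elements.
// An 8-byte (64-bit) vectorized store is only valid when the per-thread slice
// is itself a whole number of 8-byte vectors; otherwise some lanes' addresses
// land at 4-byte offsets and a float2 store would be misaligned.
__device__ void storeSlice(half* dst, half const* src, int numElemsPerThread)
{
    int const elemsPerVec = sizeof(float2) / sizeof(half); // 4 halves per 64-bit vector
    int const numVecPerThread = numElemsPerThread / elemsPerVec;

    if (numVecPerThread > 0)
    {
        // Slices are whole multiples of 8 bytes, so every lane stays 8-byte aligned.
        for (int v = 0; v < numVecPerThread; ++v)
        {
            reinterpret_cast<float2*>(dst)[v] = reinterpret_cast<float2 const*>(src)[v];
        }
    }
    else
    {
        // The head_dim=64 case from the fix: numElemsPerThread=2, numVecPerThread=0.
        // Odd lanes sit at 4-byte offsets, so a 64-bit vectorized store would be
        // misaligned; fall back to scalar stores.
        for (int e = 0; e < numElemsPerThread; ++e)
        {
            dst[e] = src[e];
        }
    }
}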
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cuda_runtime.h>

namespace tensorrt_llm
{
namespace kernels
{
// Perform fused QK normalization and RoPE in a single CUDA kernel.
// Applies RMS normalization followed by RoPE embeddings to the query and key
// heads of the combined QKV tensor in one pass over the data.
void launchFusedQKNormRope(
    void* qkv,                  // Combined QKV tensor [num_tokens, (num_heads_q+num_heads_k+num_heads_v)*head_dim]
    int const num_tokens,       // Number of tokens
    int const num_heads_q,      // Number of query heads
    int const num_heads_k,      // Number of key heads
    int const num_heads_v,      // Number of value heads
    int const head_dim,         // Dimension per head
    float const eps,            // Epsilon for RMS normalization
    void const* q_weight,       // RMSNorm weights for query [head_dim]
    void const* k_weight,       // RMSNorm weights for key [head_dim]
    float const base,           // Base for RoPE computation
    bool const interleave,      // Whether RoPE is applied in interleaved (non-NeoX) mode
    int const* position_ids,    // Position IDs for RoPE [num_tokens]
    cudaStream_t stream);       // CUDA stream
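
// Pairing convention implied by `interleave` (assumed from the flag's
// description above; the kernel source is authoritative). For a head vector x
// of length head_dim:
//   interleave=true  (interleaved, GPT-J style): rotate adjacent pairs (x[2i], x[2i+1])
//   interleave=false (NeoX style):               rotate split halves  (x[i], x[i + head_dim/2])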

} // namespace kernels
} // namespace tensorrt_llm
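For context, a minimal host-side usage sketch of the declared launcher. The header name, the example GQA head counts, and the hyperparameter values are illustrative assumptions, not taken from this file:

#include <cuda_runtime.h>

#include "fusedQKNormRopeKernel.h" // assumed name for the header above

// Launch fused QK RMSNorm + RoPE over a combined QKV buffer already on device.
void runFusedQKNormRopeExample(void* d_qkv, void const* d_q_weight,
    void const* d_k_weight, int const* d_position_ids, cudaStream_t stream)
{
    int const num_tokens = 128;
    int const num_heads_q = 32; // example GQA layout: more Q heads than K/V heads
    int const num_heads_k = 8;
    int const num_heads_v = 8;
    int const head_dim = 128;

    // d_qkv holds [num_tokens, (32 + 8 + 8) * 128] elements; only the Q and K
    // head slices are normalized and rotated by the kernel.
    tensorrt_llm::kernels::launchFusedQKNormRope(d_qkv, num_tokens, num_heads_q,
        num_heads_k, num_heads_v, head_dim,
        /*eps=*/1e-6f, d_q_weight, d_k_weight,
        /*base=*/10000.f,
        /*interleave=*/false, // NeoX-style rotation
        d_position_ids, stream);
}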