TensorRT-LLM/cpp/tensorrt_llm/thop/groupRmsNormOp.cpp
Commit 873c7532fd by Simeng Liu
feat: Add group_rms_norm kernel to normalize multiple inputs in a single operator. (#3438)
* feat: Add group_rms_norm kernel to normalize multiple inputs in a single operator.

Previously, the RMSNorm implementation supported only a single input tensor. With group_rms_norm, multiple tensors can be normalized together:
```python
input_a, input_b, ... = group_rms_norm([input_a, input_b, ...])
```
All input tensors must share the same batch dimension. The kernel partitions work by dynamically assigning warp groups to each input in proportion to its last dimension, improving launch efficiency and reducing overhead compared with launching one normalization kernel per input.
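For illustration, a minimal sketch of calling the op directly; the op name and argument order follow the binding and registration in the file below, while the shapes, dtypes, and values are arbitrary:
```python
import torch

# Two 2D CUDA tensors that share the batch dimension but differ in hidden size.
a = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
b = torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")
out_a, out_b = torch.empty_like(a), torch.empty_like(b)

# Optional per-input RMSNorm weights; each must match its input's last dimension.
w_a = torch.ones(4096, dtype=torch.bfloat16, device="cuda")
w_b = torch.ones(1024, dtype=torch.bfloat16, device="cuda")

# One fused launch normalizes both tensors, writing into out_a / out_b.
torch.ops.trtllm.group_rms_norm([a, b], [out_a, out_b], [w_a, w_b], 1e-5, 0.0)
```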

This PR provides two implementations:

- GroupRMSNormKernel: optimized for small-to-medium batch sizes.
- GroupRMSNormKernelLargeBatch: contains additional optimizations for large batch sizes.

Both kernels are currently exposed as separate custom PyTorch ops. A future PR will implement heuristic-based kernel selection and expose a unified interface; until then, callers can choose between the two ops themselves, as in the sketch below.
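A possible interim wrapper (the function name and the batch-size cutoff are illustrative placeholders, not a tuned heuristic; the two-input requirement comes from the large-batch kernel's checks):
```python
import torch

def group_rms_norm_auto(inputs, outputs, weights, eps=1e-5, weight_bias=0.0):
    # Illustrative cutoff only; the tuned heuristic is left to the follow-up PR.
    # The large-batch kernel additionally requires exactly two inputs.
    if len(inputs) == 2 and inputs[0].shape[0] >= 256:
        torch.ops.trtllm.group_rms_norm_large_batch(inputs, outputs, weights, eps, weight_bias)
    else:
        torch.ops.trtllm.group_rms_norm(inputs, outputs, weights, eps, weight_bias)
```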

Signed-off-by: Simeng Liu <simengl@nvidia.com>

* Resolve review comments and fix a typo in IS_FLASHINFER_AVAILABLE

Signed-off-by: Simeng Liu <simengl@nvidia.com>

---------

Signed-off-by: Simeng Liu <simengl@nvidia.com>
Date: 2025-05-02 13:25:30 +08:00


/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/dataType.h"
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/kernels/groupRmsNormKernels/groupRmsNormKernels.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/thop/thUtils.h"
#include <ATen/cuda/CUDAContext.h>
#include <cstdint>
#include <cuda_runtime.h>
#include <stdexcept>
#include <torch/extension.h>
#include <vector>

namespace torch_ext
{

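// Fused RMSNorm over a group of (up to two) 2D tensors sharing a batch dimension;
// normalized results are written into the pre-allocated `outputs` tensors.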
void groupRMSNorm(torch::TensorList const& inputs, torch::TensorList const& outputs, torch::TensorList const& weights,
    double eps, double weight_bias)
{
    TORCH_CHECK(!inputs.empty(), "Input tensor list cannot be empty.");
    TORCH_CHECK(!outputs.empty(), "Output tensor list cannot be empty.");
    uint32_t const num_inputs = inputs.size();
    TORCH_CHECK(num_inputs <= 2, "Only up to 2 inputs are supported.");

    auto const first_input = inputs[0];
    TORCH_CHECK(first_input.dim() == 2, "Inputs must be 2D tensors [batch_size, hidden_dim].");
    TORCH_CHECK(first_input.sizes()[0] > 0, "Batch size must be greater than 0.");
    TORCH_CHECK(first_input.is_cuda(), "Inputs must be CUDA tensors.");
    uint32_t const batch_size = first_input.sizes()[0];
    auto const dtype = first_input.scalar_type();
    auto const device = first_input.device();

    if (!weights.empty())
    {
        TORCH_CHECK(weights.size() == num_inputs, "Weights list size must match inputs list size.");
    }

    for (size_t i = 0; i < num_inputs; ++i)
    {
        TORCH_CHECK(inputs[i].sizes()[0] == batch_size, "Inputs must have the same batch size.");
        TORCH_CHECK(inputs[i].dim() == 2, "Inputs must be 2D tensors [batch_size, hidden_dim].");
        TORCH_CHECK(inputs[i].device() == device, "Inputs must be on the same device.");
        TORCH_CHECK(inputs[i].scalar_type() == dtype, "Inputs must be of the same type.");
        TORCH_CHECK(outputs[i].dim() == 2, "Outputs must be 2D tensors [batch_size, hidden_dim].");
        TORCH_CHECK(outputs[i].device() == device, "Outputs must be on the same device.");
        TORCH_CHECK(outputs[i].scalar_type() == dtype, "Outputs must be of the same type.");
        TORCH_CHECK(outputs[i].sizes()[0] == batch_size, "Outputs must have the same batch size.");
        TORCH_CHECK(
            outputs[i].sizes()[1] == inputs[i].sizes()[1], "Outputs and inputs must have the same last dimension.");
        TORCH_CHECK(
            inputs[i].strides()[0] == outputs[i].strides()[0], "Inputs and outputs must have the same stride.");
        TORCH_CHECK(inputs[i].strides()[1] == 1, "Inputs must be contiguous along the last dimension.");
        TORCH_CHECK(outputs[i].strides()[1] == 1, "Outputs must be contiguous along the last dimension.");
        if (!weights.empty())
        {
            TORCH_CHECK(inputs[i].sizes()[1] == weights[i].sizes()[0],
                "Inputs and weights must have the same last dimension.");
        }
    }
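
// Populates GroupRMSParams<n> for a compile-time input count n and invokes the templated
// launcher; written as a macro so the same body can be stamped out for n = 1 and n = 2.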
#define DISPATCH_INPUT_SIZES(n)                                                                                        \
    {                                                                                                                  \
        tensorrt_llm::kernels::group_rms_norm::GroupRMSParams<n> params;                                               \
        for (size_t i = 0; i < n; ++i)                                                                                 \
        {                                                                                                              \
            params.inputs[i] = reinterpret_cast<float4*>(inputs[i].data_ptr());                                        \
            params.input_last_dims[i] = inputs[i].sizes()[1];                                                          \
            params.input_strides[i] = inputs[i].strides()[0];                                                          \
            params.output_strides[i] = outputs[i].strides()[0];                                                        \
            if (!weights.empty())                                                                                      \
            {                                                                                                          \
                params.weights[i] = reinterpret_cast<float4 const*>(weights[i].data_ptr());                            \
            }                                                                                                          \
            params.outputs[i] = reinterpret_cast<float4*>(outputs[i].mutable_data_ptr());                              \
        }                                                                                                              \
        /* Set remaining params */                                                                                     \
        params.batch_size = batch_size;                                                                                \
        params.num_inputs = n;                                                                                         \
        params.eps = static_cast<float>(eps);                                                                          \
        params.weight_bias = static_cast<float>(weight_bias);                                                          \
        params.enable_weights = !weights.empty();                                                                      \
        params.stream = at::cuda::getCurrentCUDAStream(inputs[0].get_device());                                        \
        /* Handle dtype conversion */                                                                                  \
        switch (dtype)                                                                                                 \
        {                                                                                                              \
        case torch::ScalarType::Half: params.dtype = nvinfer1::DataType::kHALF; break;                                 \
        case torch::ScalarType::BFloat16: params.dtype = nvinfer1::DataType::kBF16; break;                             \
        case torch::ScalarType::Float: params.dtype = nvinfer1::DataType::kFLOAT; break;                               \
        default: TORCH_CHECK(false, "Unsupported data type");                                                          \
        }                                                                                                              \
        tensorrt_llm::kernels::group_rms_norm::GroupRMSNormKernelLauncher<n>(params);                                  \
        break;                                                                                                         \
    }

    switch (num_inputs)
    {
    case 1: DISPATCH_INPUT_SIZES(1)
    case 2: DISPATCH_INPUT_SIZES(2)
    default: TORCH_CHECK(false, "Unsupported number of inputs (max 2)");
    }
#undef DISPATCH_INPUT_SIZES
}
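
// Large-batch variant: requires exactly two inputs (checked below) and dispatches to
// GroupRMSNormKernelLargeBatchLauncher instead of the default launcher.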
void groupRMSNormLargeBatch(torch::TensorList const& inputs, torch::TensorList const& outputs,
    torch::TensorList const& weights, double eps, double weight_bias)
{
    TORCH_CHECK(!inputs.empty(), "Input tensor list cannot be empty.");
    TORCH_CHECK(!outputs.empty(), "Output tensor list cannot be empty.");
    TORCH_CHECK(inputs.size() == 2, "groupRMSNormLargeBatch requires exactly 2 input tensors.");

    auto const first_input = inputs[0];
    TORCH_CHECK(first_input.dim() == 2, "Inputs must be 2D tensors [batch_size, hidden_dim].");
    TORCH_CHECK(first_input.sizes()[0] > 0, "Batch size must be greater than 0.");
    TORCH_CHECK(first_input.is_cuda(), "Inputs must be CUDA tensors.");
    uint32_t const batch_size = first_input.sizes()[0];
    auto const dtype = first_input.scalar_type();
    auto const device = first_input.device();
    uint32_t const num_inputs = inputs.size();

    for (size_t i = 0; i < num_inputs; ++i)
    {
        TORCH_CHECK(inputs[i].sizes()[0] == batch_size, "Inputs must have the same batch size.");
        TORCH_CHECK(inputs[i].dim() == 2, "Inputs must be 2D tensors [batch_size, hidden_dim].");
        TORCH_CHECK(inputs[i].device() == device, "Inputs must be on the same device.");
        TORCH_CHECK(inputs[i].scalar_type() == dtype, "Inputs must be of the same type.");
        TORCH_CHECK(outputs[i].dim() == 2, "Outputs must be 2D tensors [batch_size, hidden_dim].");
        TORCH_CHECK(outputs[i].device() == device, "Outputs must be on the same device.");
        TORCH_CHECK(outputs[i].scalar_type() == dtype, "Outputs must be of the same type.");
        TORCH_CHECK(outputs[i].sizes()[0] == batch_size, "Outputs must have the same batch size.");
        TORCH_CHECK(
            outputs[i].sizes()[1] == inputs[i].sizes()[1], "Outputs and inputs must have the same last dimension.");
        TORCH_CHECK(
            inputs[i].strides()[0] == outputs[i].strides()[0], "Inputs and outputs must have the same stride.");
        TORCH_CHECK(inputs[i].strides()[1] == 1, "Inputs must be contiguous along the last dimension.");
        TORCH_CHECK(outputs[i].strides()[1] == 1, "Outputs must be contiguous along the last dimension.");
        if (!weights.empty())
        {
            TORCH_CHECK(inputs[i].sizes()[1] == weights[i].sizes()[0],
                "Inputs and weights must have the same last dimension.");
        }
    }
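
    // Tensor pointers are reinterpreted as float4 so the kernels can use vectorized 16-byte
    // accesses; the contiguity checks above guarantee the last dimension is densely packed.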
    tensorrt_llm::kernels::group_rms_norm::GroupRMSParams<2> params;
    for (size_t i = 0; i < 2; ++i)
    {
        params.inputs[i] = reinterpret_cast<float4*>(inputs[i].data_ptr());
        params.input_last_dims[i] = inputs[i].sizes()[1];
        params.input_strides[i] = inputs[i].strides()[0];
        params.output_strides[i] = outputs[i].strides()[0];
        params.outputs[i] = reinterpret_cast<float4*>(outputs[i].mutable_data_ptr());
        if (!weights.empty())
        {
            params.weights[i] = reinterpret_cast<float4 const*>(weights[i].data_ptr());
        }
    }

    // Set remaining params
    params.batch_size = batch_size;
    params.num_inputs = 2;
    params.eps = static_cast<float>(eps);
    params.weight_bias = static_cast<float>(weight_bias);
    params.enable_weights = !weights.empty();
    params.stream = at::cuda::getCurrentCUDAStream(inputs[0].get_device());

    // Handle dtype conversion
    switch (dtype)
    {
    case torch::ScalarType::Half: params.dtype = nvinfer1::DataType::kHALF; break;
    case torch::ScalarType::BFloat16: params.dtype = nvinfer1::DataType::kBF16; break;
    case torch::ScalarType::Float: params.dtype = nvinfer1::DataType::kFLOAT; break;
    default: TORCH_CHECK(false, "Unsupported data type");
    }

    tensorrt_llm::kernels::group_rms_norm::GroupRMSNormKernelLargeBatchLauncher<2>(params);
}

} // namespace torch_ext

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("group_rms_norm", &torch_ext::groupRMSNorm);
    // TODO: Add heuristics for when to use the large batch kernel and merge into one torch Op.
    m.impl("group_rms_norm_large_batch", &torch_ext::groupRMSNormLargeBatch);
}