/*
 * Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.cpp
 * Copyright (c) 2024, Tri Dao.
 *
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/cudaUtils.h"
|
|
#include "tensorrt_llm/kernels/causalConv1d/causalConv1d.h"
|
|
|
|
#include <ATen/cuda/CUDAContext.h>
|
|
#include <c10/cuda/CUDAGuard.h>
|
|
#include <torch/all.h>
|
|
|
|
TRTLLM_NAMESPACE_BEGIN
|
|
|
|
namespace torch_ext
|
|
{
|
|
|
|
#define CHECK_SHAPE(x, ...) \
|
|
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
|
|
|
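// Dispatch helper: instantiates the callable __VA_ARGS__ with input_t and weight_t bound to float, half, or
// nv_bfloat16 according to ITYPE, and reports an error for any other input dtype.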
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
    if (ITYPE == at::ScalarType::Half) \
    { \
        using input_t = half; \
        using weight_t = half; \
        __VA_ARGS__(); \
    } \
    else if (ITYPE == at::ScalarType::BFloat16) \
    { \
        using input_t = nv_bfloat16; \
        using weight_t = nv_bfloat16; \
        __VA_ARGS__(); \
    } \
    else if (ITYPE == at::ScalarType::Float) \
    { \
        using input_t = float; \
        using weight_t = float; \
        __VA_ARGS__(); \
    } \
    else \
    { \
        AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
    }

void set_conv_params_fwd(tensorrt_llm::kernels::causal_conv1d::ConvParamsBase& params,
    // sizes
    const size_t batch, const size_t dim, const size_t seqlen, const size_t width,
    // device pointers
    const at::Tensor x, const at::Tensor weight, const at::Tensor out, std::optional<at::Tensor> const& bias,
    bool silu_activation, int64_t pad_slot_id, std::optional<at::Tensor> const& query_start_loc = std::nullopt,
    std::optional<at::Tensor> const& cache_indices = std::nullopt,
    std::optional<at::Tensor> const& has_initial_state = std::nullopt)
{

    // Reset the parameters
    memset(&params, 0, sizeof(params));

    params.batch = batch;
    params.dim = dim;
    params.seqlen = seqlen;
    params.width = width;
    params.pad_slot_id = pad_slot_id;

    params.silu_activation = silu_activation;

    // Set the pointers and strides.
    params.x_ptr = x.data_ptr();
    params.weight_ptr = weight.data_ptr();
    params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr;
    params.out_ptr = out.data_ptr();
    // All strides are in elements, not bytes.
    params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr;
    params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
    params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
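    // When query_start_loc is provided (varlen path), x and out are packed as (dim, total_tokens) instead of
    // (batch, dim, seqlen), so the batch/channel/sequence strides come from different tensor dimensions.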
    bool const varlen = params.query_start_loc_ptr != nullptr;
    params.x_batch_stride = x.stride(varlen ? 1 : 0);
    params.x_c_stride = x.stride(varlen ? 0 : 1);
    params.x_l_stride = x.stride(varlen ? 1 : -1);
    params.weight_c_stride = weight.stride(0);
    params.weight_width_stride = weight.stride(1);
    params.out_batch_stride = out.stride(varlen ? 1 : 0);
    params.out_c_stride = out.stride(varlen ? 0 : 1);
    params.out_l_stride = out.stride(varlen ? 1 : -1);
}

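// Forward causal 1D convolution. The output is written in place into x (out aliases x below). Supports both the
// batched layout (batch, dim, seqlen) and the varlen packed layout (dim, total_tokens) selected by query_start_loc,
// with optional bias, SiLU activation, initial-state handling, and an optional conv_states cache tensor.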
void causalConv1dFwd(at::Tensor const& x, at::Tensor const& weight, std::optional<at::Tensor> const& bias_,
    std::optional<at::Tensor> const& conv_states, std::optional<at::Tensor> const& query_start_loc,
    std::optional<at::Tensor> const& cache_indices, std::optional<at::Tensor> const& has_initial_state,
    bool silu_activation,
    // used to identify padding entries when cache_indices is provided;
    // for padded entries the kernel returns early
    int64_t pad_slot_id)
{
    auto input_type = x.scalar_type();
    auto weight_type = weight.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half
        || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half
        || weight_type == at::ScalarType::BFloat16);

    TORCH_CHECK(x.is_cuda());
    TORCH_CHECK(weight.is_cuda());

    bool const varlen = query_start_loc.has_value();
    auto const sizes = x.sizes();
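    // query_start_loc, when present, holds batch_size + 1 sequence start offsets, so the batch size is its length
    // minus one and x uses the packed (dim, total_tokens) layout checked below.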
    int const batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0];
    int const dim = varlen ? sizes[0] : sizes[1];
    int const seqlen = varlen ? sizes[1] : sizes[2];
    int const width = weight.size(-1);
    if (varlen)
    {
        CHECK_SHAPE(x, dim, seqlen);
    }
    else
    {
        CHECK_SHAPE(x, batch_size, dim, seqlen);
    }
    CHECK_SHAPE(weight, dim, width);

    if (bias_.has_value())
    {
        auto bias = bias_.value();
        TORCH_CHECK(bias.scalar_type() == weight_type);
        TORCH_CHECK(bias.is_cuda());
        TORCH_CHECK(bias.stride(-1) == 1);
        CHECK_SHAPE(bias, dim);
    }

    if (has_initial_state.has_value())
    {
        auto has_initial_state_ = has_initial_state.value();
        TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
        TORCH_CHECK(has_initial_state_.is_cuda());
        CHECK_SHAPE(has_initial_state_, batch_size);
    }

    if (query_start_loc.has_value())
    {
        auto query_start_loc_ = query_start_loc.value();
        TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
        TORCH_CHECK(query_start_loc_.is_cuda());
    }

    if (cache_indices.has_value())
    {
        auto cache_indices_ = cache_indices.value();
        TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
        TORCH_CHECK(cache_indices_.is_cuda());
        CHECK_SHAPE(cache_indices_, batch_size);
    }

    at::Tensor out = x;

    tensorrt_llm::kernels::causal_conv1d::ConvParamsBase params;
    set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, bias_, silu_activation, pad_slot_id,
        query_start_loc, cache_indices, has_initial_state);

    if (conv_states.has_value())
    {
        auto conv_states_ = conv_states.value();
        TORCH_CHECK(conv_states_.scalar_type() == input_type);
        TORCH_CHECK(conv_states_.is_cuda());
        params.conv_states_ptr = conv_states_.data_ptr();
        params.conv_states_batch_stride = conv_states_.stride(0);
        params.conv_states_c_stride = conv_states_.stride(1);
        params.conv_states_l_stride = conv_states_.stride(2);
    }
    else
    {
        params.conv_states_ptr = nullptr;
    }

    // Otherwise the kernel will be launched from the cuda:0 device.
    // Static cast to signed char (AKA c10::DeviceIndex - the input to CUDAGuard) to avoid a compiler warning about
    // narrowing.
    at::cuda::CUDAGuard device_guard{static_cast<signed char>(x.get_device())};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd",
        [&] { tensorrt_llm::kernels::causal_conv1d::causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream); });
}

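// Decode-time causal conv1d update. x of shape (batch, dim, seqlen) is convolved against the per-sequence
// conv_state cache, the result is written in place into x, and conv_state is updated in place. When
// conv_state_indices is provided, it selects which rows of a larger conv_state pool belong to this batch, and
// entries equal to pad_slot_id mark padding for which the kernel returns early.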
void causalConv1dUpdate(at::Tensor const& x, at::Tensor const& conv_state, at::Tensor const& weight,
    std::optional<at::Tensor> const& bias_, bool silu_activation, std::optional<at::Tensor> const& cache_seqlens_,
    std::optional<at::Tensor> const& conv_state_indices_,
    // used to identify padding entries when conv_state_indices is provided;
    // for padded entries the kernel returns early
    int64_t pad_slot_id)
{
    auto input_type = x.scalar_type();
    auto weight_type = weight.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half
        || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half
        || weight_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == input_type,
        "weight type must equal input type, other variations are disabled due to binary size limitations");
    TORCH_CHECK(conv_state.scalar_type() == input_type);

    TORCH_CHECK(x.is_cuda());
    TORCH_CHECK(conv_state.is_cuda());
    TORCH_CHECK(weight.is_cuda());

    auto const sizes = x.sizes();
    int const batch_size = sizes[0];
    int const dim = sizes[1];
    int const seqlen = sizes[2];
    int const width = weight.size(-1);
    int const conv_state_len = conv_state.size(2);
    TORCH_CHECK(conv_state_len >= width - 1);

    CHECK_SHAPE(x, batch_size, dim, seqlen);
    CHECK_SHAPE(weight, dim, width);

    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");

    if (bias_.has_value())
    {
        auto bias = bias_.value();
        TORCH_CHECK(bias.scalar_type() == weight_type);
        TORCH_CHECK(bias.is_cuda());
        TORCH_CHECK(bias.stride(-1) == 1);
        CHECK_SHAPE(bias, dim);
    }

    at::Tensor out = x;

    tensorrt_llm::kernels::causal_conv1d::ConvParamsBase params;
    set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, bias_, silu_activation, pad_slot_id);
    params.conv_state_ptr = conv_state.data_ptr();
    params.conv_state_len = conv_state_len;
    // All strides are in elements, not bytes.
    params.conv_state_batch_stride = conv_state.stride(0);
    params.conv_state_c_stride = conv_state.stride(1);
    params.conv_state_l_stride = conv_state.stride(2);

    if (cache_seqlens_.has_value())
    {
        auto cache_seqlens = cache_seqlens_.value();
        TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
        TORCH_CHECK(cache_seqlens.is_cuda());
        TORCH_CHECK(cache_seqlens.stride(-1) == 1);
        CHECK_SHAPE(cache_seqlens, batch_size);
        params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
    }
    else
    {
        params.cache_seqlens = nullptr;
    }

    if (conv_state_indices_.has_value())
    {
        auto conv_state_indices = conv_state_indices_.value();
        TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32);
        TORCH_CHECK(conv_state_indices.is_cuda());
        TORCH_CHECK(conv_state_indices.stride(0) == 1);
        CHECK_SHAPE(conv_state_indices, batch_size);

        int conv_state_entries = conv_state.size(0);
        CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len);

        params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
    }
    else
    {
        CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
        params.conv_state_indices_ptr = nullptr;
    }

    // Otherwise the kernel will be launched from the cuda:0 device.
    // Static cast to signed char (AKA c10::DeviceIndex - the input to CUDAGuard) to avoid a compiler warning about
    // narrowing.
    at::cuda::CUDAGuard device_guard{static_cast<signed char>(x.get_device())};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update",
        [&] { tensorrt_llm::kernels::causal_conv1d::causal_conv1d_update_cuda<input_t, weight_t>(params, stream); });
}

} // namespace torch_ext

TRTLLM_NAMESPACE_END

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "causal_conv1d_fwd(Tensor! x,"
        "Tensor! weight,"
        "Tensor? bias_,"
        "Tensor!? conv_states,"
        "Tensor? query_start_loc,"
        "Tensor? cache_indices,"
        "Tensor? has_initial_state,"
        "bool silu_activation,"
        "int pad_slot_id) -> ()");

    m.def(
        "causal_conv1d_update(Tensor! x,"
        "Tensor! conv_state,"
        "Tensor! weight,"
        "Tensor? bias_,"
        "bool silu_activation,"
        "Tensor? cache_seqlens_,"
        "Tensor? conv_state_indices,"
        "int pad_slot_id) -> ()");
}

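// Schema notes: the Tensor! annotations mark arguments the ops mutate in place (x, conv_states, conv_state). As an
// illustrative sketch, once this extension is loaded the registered ops are reachable from Python as
// torch.ops.trtllm.causal_conv1d_fwd(...) and torch.ops.trtllm.causal_conv1d_update(...).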
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("causal_conv1d_fwd", &tensorrt_llm::torch_ext::causalConv1dFwd);
    m.impl("causal_conv1d_update", &tensorrt_llm::torch_ext::causalConv1dUpdate);
}