/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
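
// Torch binding for the Mamba conv1d kernels: registers the trtllm::mamba_conv1d custom op
// and dispatches each request to the context (prefill) or generation (decode) kernel.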
#include "tensorrt_llm/common/cudaUtils.h"
|
|
#include "tensorrt_llm/kernels/mambaConv1dKernels.h"
|
|
#include "tensorrt_llm/thop/thUtils.h"
|
|
|
|

namespace th = torch;
namespace tk = tensorrt_llm::kernels;

TRTLLM_NAMESPACE_BEGIN

namespace torch_ext
{

std::tuple<th::Tensor, th::Tensor> mamba_conv1d(th::Tensor const& input, th::Tensor const& conv_weight,
    th::Tensor const& conv_bias, th::Tensor const& conv_state, th::Tensor const& host_request_types,
    th::Tensor const& last_token_ids, th::optional<th::Tensor> host_context_lengths,
    th::optional<th::Tensor> slot_mapping, int64_t const dim, int64_t const dconv, int64_t const pre_stride,
    int64_t const post_stride, bool const remove_padding, bool const apply_silu, bool const is_paged_state)
{
    // tensors info: [shapes] x [dtype]
    // input: [batch_size, seq_len, dim] or [num_tokens, dim] for remove_padding x [float16, float32, bfloat16]
    // conv_weight: [1, dconv, dim] x [float16, float32, bfloat16]
    // conv_bias: [dim] x [float16, float32, bfloat16]
    // conv_state: [batch_size, dconv-1, dim] x [float16, float32, bfloat16]
    // host_request_types: [batch_size] x [int32]
    // host_context_lengths: [batch_size] x [int32] for remove_padding
    // last_token_ids: [batch_size] x [int32]
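    // slot_mapping: [batch_size] x [int32] for is_paged_state (shape presumed from its use as a
    //     per-sequence slot index below; not documented in the original comment)
    // outputs: out has its feature dimension reduced to dim - pre_stride - post_stride;
    //     state_out has the same shape as conv_state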

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    tk::MambaConv1dParamsBase params;

    auto host_request_sizes = host_request_types.sizes();
    auto input_sizes = input.sizes();

    int batch_size = host_request_sizes[0];
    int max_seqlen;

    if (remove_padding)
    {
        if (!host_context_lengths.has_value())
        {
            throw std::invalid_argument("host_context_lengths must be provided when remove_padding is enabled");
        }
        // With padding removed, input is flattened to [num_tokens, dim], so the max
        // sequence length must come from the per-request host_context_lengths.
        max_seqlen = host_context_lengths.value().max().item<int>();
    }
    else
    {
        max_seqlen = input_sizes[1];
    }
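
    // Only the first entry of host_request_types is inspected below, which presumes the batch is
    // homogeneous: all context (prefill) requests or all generation (decode) requests.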
    // req_type=0 -> context (prefill)
    // req_type=1 -> generation (decode)
    auto req_type = host_request_types[0].item<int>();
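
    // pre_stride and post_stride trim that many leading/trailing channels from the feature
    // dimension, so the output feature size is input_sizes[idx] - pre_stride - post_stride.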
    int idx = (remove_padding) ? 1 : 2;
    int64_t out_dim = input_sizes[idx] - pre_stride - post_stride;

    std::vector<int64_t> out_shape;
    if (remove_padding)
    {
        out_shape = {input_sizes[0], out_dim};
    }
    else
    {
        out_shape = {input_sizes[0], input_sizes[1], out_dim};
    }

    auto out = torch::empty(out_shape, input.options());
    auto state_out = torch::empty_like(conv_state);

    params.batch = batch_size;
    params.dim = dim;
    params.max_seqlen = max_seqlen;
    params.dconv = dconv;
    params.pre_stride = pre_stride;
    params.post_stride = post_stride;
    params.remove_padding = remove_padding;
    params.apply_silu = apply_silu;

    // Set the data pointers.
    params.in_ptr = input.data_ptr();
    params.weight_ptr = conv_weight.data_ptr();
    params.bias_ptr = conv_bias.data_ptr();
    params.out_ptr = out.data_ptr();
    params.last_token_ids_ptr = static_cast<int const*>(last_token_ids.const_data_ptr());
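
    // In paged mode, conv_state is treated as holding a single device pointer to the shared state
    // pool (hence the double indirection below); slot_mapping then selects each sequence's slot in
    // that pool, and the state is updated in place rather than written to state_out.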
    if (is_paged_state)
    {
        if (!slot_mapping.has_value())
        {
            throw std::invalid_argument("slot_mapping must be provided when paged state is enabled");
        }

        params.state_in_ptr = *reinterpret_cast<void**>(const_cast<void*>(conv_state.data_ptr()));
        params.state_out_ptr = *reinterpret_cast<void**>(const_cast<void*>(conv_state.data_ptr()));
        params.state_slot_mapping_ptr = static_cast<int const*>(slot_mapping.value().const_data_ptr());
    }
    else
    {
        params.state_in_ptr = conv_state.data_ptr();
        params.state_out_ptr = state_out.data_ptr();
        params.state_slot_mapping_ptr = nullptr;
    }

    c10::ScalarType dtype = input.scalar_type();
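
    // Dispatch on request type and dtype: the context kernel consumes whole prompts (prefill),
    // while the generation kernel advances one token per sequence (decode).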
    if (req_type == 0)
    {
        switch (dtype)
        {
        case torch::kFloat16:
            // Handle Float16
            tk::invokeMambaConv1dContext<half>(params, stream);
            break;
        case torch::kFloat32:
            // Handle Float32
            tk::invokeMambaConv1dContext<float>(params, stream);
            break;
        case torch::kBFloat16:
            // Handle BFloat16
            tk::invokeMambaConv1dContext<__nv_bfloat16>(params, stream);
            break;
        default:
            // Handle other data types
            throw std::invalid_argument("Invalid dtype, only supports float16, float32, and bfloat16");
            break;
        }
    }
    else
    {
        switch (dtype)
        {
        case torch::kFloat16:
            // Handle Float16
            tk::invokeMambaConv1dGeneration<half>(params, stream);
            break;
        case torch::kFloat32:
            // Handle Float32
            tk::invokeMambaConv1dGeneration<float>(params, stream);
            break;
        case torch::kBFloat16:
            // Handle BFloat16
            tk::invokeMambaConv1dGeneration<__nv_bfloat16>(params, stream);
            break;
        default:
            // Handle other data types
            throw std::invalid_argument("Invalid dtype, only supports float16, float32, and bfloat16");
            break;
        }
    }

    sync_check_cuda_error(stream);
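
    // Paged mode updated the state pool in place, so conv_state doubles as the state output;
    // otherwise the freshly written state_out is returned.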
    if (is_paged_state)
    {
        return std::make_tuple(out, conv_state);
    }
    else
    {
        return std::make_tuple(out, state_out);
    }
}

} // namespace torch_ext

TRTLLM_NAMESPACE_END

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "mamba_conv1d(Tensor input, Tensor conv_weight, "
        "Tensor conv_bias, Tensor conv_state, "
        "Tensor host_request_types, Tensor last_token_ids, "
        "Tensor? host_context_lengths, Tensor? slot_mapping, "
        "int dim, int dconv, int pre_stride, int post_stride, "
        "bool remove_padding, bool apply_silu, "
        "bool is_paged_state) -> (Tensor, Tensor)");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("mamba_conv1d", &tensorrt_llm::torch_ext::mamba_conv1d);
}
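
// Illustrative Python call of the registered op, matching the schema above; the argument
// tensors and surrounding runtime setup are hypothetical:
//
//   out, new_state = torch.ops.trtllm.mamba_conv1d(
//       input, conv_weight, conv_bias, conv_state, host_request_types, last_token_ids,
//       host_context_lengths, slot_mapping, dim, dconv, pre_stride, post_stride,
//       remove_padding, apply_silu, is_paged_state)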