TensorRT-LLM/cpp/tensorrt_llm/kernels/selectiveScanCommon.h

/*
* Adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_common.h
* Copyright (c) 2023, Tri Dao.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Not a contribution
* Changes made by NVIDIA CORPORATION & AFFILIATES or otherwise documented as
* NVIDIA-proprietary are not a contribution and subject to the following terms and conditions:
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: LicenseRef-NvidiaProprietary
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <type_traits> // std::conditional_t / std::is_same_v (used by SSMScanPrefixCallbackOp)

namespace tensorrt_llm
{
namespace kernels
{
#define MAX_DSTATE 256
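// Presumably an upper bound on the SSM state dimension (dstate) used to size per-thread
// and shared-memory buffers in the selective-scan kernels; this header does not use it directly.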
inline __device__ float2 operator+(const float2& a, const float2& b)
{
return {a.x + b.x, a.y + b.y};
}
inline __device__ float3 operator+(const float3& a, const float3& b)
{
return {a.x + b.x, a.y + b.y, a.z + b.z};
}
inline __device__ float4 operator+(const float4& a, const float4& b)
{
return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
}
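// CUDA's built-in vector types do not come with arithmetic operators, so the element-wise
// operator+ overloads above let float2/float3/float4 values be summed directly
// (e.g. when accumulating packed scan state).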
////////////////////////////////////////////////////////////////////////////////////////////////////
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] \
{ \
if (COND) \
{ \
static constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} \
else \
{ \
static constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
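// Illustrative dispatch example (run_selective_scan, params and kChunkSize are hypothetical,
// not part of this header): the lambda is invoked with CONST_NAME bound to a compile-time
// constant, so each runtime value of COND gets its own template instantiation.
//
//   bool is_even_len = (params.seqlen % kChunkSize == 0);
//   BOOL_SWITCH(is_even_len, kIsEvenLen, [&] { run_selective_scan<kIsEvenLen>(params, stream); });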
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int BYTES>
struct BytesToType
{
};
template <>
struct BytesToType<16>
{
using Type = uint4;
static_assert(sizeof(Type) == 16);
};
template <>
struct BytesToType<8>
{
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};
template <>
struct BytesToType<4>
{
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};
template <>
struct BytesToType<2>
{
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};
template <>
struct BytesToType<1>
{
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};
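// Usage sketch (illustrative; input_t, kNItems and kNLoads are assumptions, not defined here):
// map a desired access width in bytes to an unsigned type of that size, so a packed array can be
// moved with a single vectorized load/store.
//
//   using vec_t = typename BytesToType<sizeof(input_t) * kNItems / kNLoads>::Type;
//   // e.g. half x 8 items in 1 load -> BytesToType<16>::Type == uint4 (one 128-bit access)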
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename scalar_t, int N>
struct Converter
{
static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N])
{
#pragma unroll
for (int i = 0; i < N; ++i)
{
dst[i] = src[i];
}
}
};
template <int N>
struct Converter<half, N>
{
static inline __device__ void to_float(const half (&src)[N], float (&dst)[N])
{
static_assert(N % 2 == 0);
auto& src2 = reinterpret_cast<const half2(&)[N / 2]>(src);
auto& dst2 = reinterpret_cast<float2(&)[N / 2]>(dst);
#pragma unroll
for (int i = 0; i < N / 2; ++i)
{
dst2[i] = __half22float2(src2[i]);
}
}
};
#if __CUDA_ARCH__ >= 800
template <int N>
struct Converter<__nv_bfloat16, N>
{
static inline __device__ void to_float(const __nv_bfloat16 (&src)[N], float (&dst)[N])
{
static_assert(N % 2 == 0);
auto& src2 = reinterpret_cast<const nv_bfloat162(&)[N / 2]>(src);
auto& dst2 = reinterpret_cast<float2(&)[N / 2]>(dst);
#pragma unroll
for (int i = 0; i < N / 2; ++i)
{
dst2[i] = __bfloat1622float2(src2[i]);
}
}
};
#endif
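// Example (illustrative): widen a register array of 4 half values to float. The specializations
// above use packed __half22float2 / __bfloat1622float2 conversions, which is why N must be even
// for the half / bfloat16 cases.
//
//   half h_vals[4];
//   float f_vals[4];
//   Converter<half, 4>::to_float(h_vals, f_vals);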
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename scalar_t>
struct SSMScanOp;
template <>
struct SSMScanOp<float>
{
__device__ __forceinline__ float2 operator()(const float2& ab0, const float2& ab1) const
{
return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
}
};
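// Why this operator is associative: a pair (a, b) encodes the affine update x -> a * x + b of the
// SSM recurrence x_t = a_t * x_{t-1} + b_t. Composing (a0, b0) followed by (a1, b1) gives
// x -> a1 * (a0 * x + b0) + b1 = (a1 * a0) * x + (a1 * b0 + b1), which is exactly the pair
// returned above, so the operator can be used with block-wide prefix scans.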
// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations.
template <typename scalar_t>
struct SSMScanPrefixCallbackOp
{
using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
scan_t running_prefix;
// Constructor
__device__ SSMScanPrefixCallbackOp(scan_t running_prefix_)
: running_prefix(running_prefix_)
{
}
// Callback operator to be entered by the first warp of threads in the block.
// Thread-0 is responsible for returning a value for seeding the block-wide scan.
__device__ scan_t operator()(scan_t block_aggregate)
{
scan_t old_prefix = running_prefix;
running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
return old_prefix;
}
};
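// Usage sketch with a CUB block scan (illustrative; BlockScanT, smem_scan, thread_data and the
// initial prefix are assumptions, not defined in this header):
//
//   SSMScanPrefixCallbackOp<float> prefix_op(make_float2(1.f, 0.f)); // identity: x -> 1 * x + 0
//   BlockScanT(smem_scan).InclusiveScan(thread_data, thread_data, SSMScanOp<float>(), prefix_op);
//   // prefix_op.running_prefix now carries the aggregate into the next tile of the sequence.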
////////////////////////////////////////////////////////////////////////////////////////////////////
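// The load/store helpers below are written against a kernel-traits type Ktraits. Based on how
// they are used here, Ktraits is expected to provide (illustrative list, not a definition from
// this header): input_t and weight_t element types, the per-thread tile sizes kNItems and
// kNLoads, a kIsEvenLen flag, a vectorized type vec_t, and CUB collective types
// BlockLoadT / BlockLoadVecT / BlockLoadWeightT / BlockLoadWeightVecT / BlockStoreT /
// BlockStoreVecT together with their TempStorage.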
template <typename Ktraits>
inline __device__ void load_input(typename Ktraits::input_t* u, typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
typename Ktraits::BlockLoadT::TempStorage& smem_load, int seqlen)
{
if constexpr (Ktraits::kIsEvenLen)
{
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
using vec_t = typename Ktraits::vec_t;
Ktraits::BlockLoadVecT(smem_load_vec)
.Load(reinterpret_cast<vec_t*>(u), reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals));
}
else
{
Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
}
}
template <typename Ktraits>
inline __device__ void load_weight(typename Ktraits::input_t* Bvar,
typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
typename Ktraits::BlockLoadWeightT::TempStorage& smem_load_weight, int seqlen)
{
constexpr int kNItems = Ktraits::kNItems;
typename Ktraits::input_t B_vals_load[kNItems];
if constexpr (Ktraits::kIsEvenLen)
{
auto& smem_load_weight_vec
= reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
using vec_t = typename Ktraits::vec_t;
Ktraits::BlockLoadWeightVecT(smem_load_weight_vec)
.Load(reinterpret_cast<vec_t*>(Bvar), reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load));
}
else
{
Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
}
// #pragma unroll
// for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
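// Converter performs the same element-wise widening as the commented-out loop above, but uses
// packed half2 / bfloat162 -> float2 conversions where available.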
Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
}
template <typename Ktraits>
inline __device__ void store_output(typename Ktraits::input_t* out, const float (&out_vals)[Ktraits::kNItems],
typename Ktraits::BlockStoreT::TempStorage& smem_store, int seqlen)
{
typename Ktraits::input_t write_vals[Ktraits::kNItems];
#pragma unroll
for (int i = 0; i < Ktraits::kNItems; ++i)
{
write_vals[i] = out_vals[i];
}
if constexpr (Ktraits::kIsEvenLen)
{
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
using vec_t = typename Ktraits::vec_t;
Ktraits::BlockStoreVecT(smem_store_vec)
.Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals));
}
else
{
Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
}
}
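// End-to-end sketch of how these helpers are typically combined inside a scan kernel
// (illustrative only; the surrounding kernel, the smem_ union and the pointer offsets are
// assumptions, not part of this header):
//
//   typename Ktraits::input_t u_vals[Ktraits::kNItems];
//   typename Ktraits::weight_t B_vals[Ktraits::kNItems];
//   float out_vals[Ktraits::kNItems];
//   load_input<Ktraits>(u, u_vals, smem_.load, seqlen);
//   __syncthreads(); // shared memory is commonly reused between the collectives
//   load_weight<Ktraits>(Bvar, B_vals, smem_.load_weight, seqlen);
//   // ... block-wide scan using SSMScanOp / SSMScanPrefixCallbackOp fills out_vals ...
//   __syncthreads();
//   store_output<Ktraits>(out, out_vals, smem_.store, seqlen);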
} // namespace kernels
} // namespace tensorrt_llm