/*
 * Adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_common.h
 * Copyright (c) 2023, Tri Dao.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * Not a contribution
 * Changes made by NVIDIA CORPORATION & AFFILIATES or otherwise documented as
 * NVIDIA-proprietary are not a contribution and subject to the following terms and conditions:
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <cstdint>     // fixed-width integer types used by BytesToType
#include <type_traits> // std::conditional_t / std::is_same_v used by SSMScanPrefixCallbackOp

namespace tensorrt_llm
{
namespace kernels
{

#define MAX_DSTATE 256

inline __device__ float2 operator+(const float2& a, const float2& b)
{
    return {a.x + b.x, a.y + b.y};
}

inline __device__ float3 operator+(const float3& a, const float3& b)
{
    return {a.x + b.x, a.y + b.y, a.z + b.z};
}

inline __device__ float4 operator+(const float4& a, const float4& b)
{
    return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h

/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)                 \
    [&]                                                    \
    {                                                      \
        if (COND)                                          \
        {                                                  \
            static constexpr bool CONST_NAME = true;       \
            return __VA_ARGS__();                          \
        }                                                  \
        else                                               \
        {                                                  \
            static constexpr bool CONST_NAME = false;      \
            return __VA_ARGS__();                          \
        }                                                  \
    }()

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int BYTES>
struct BytesToType
{
};

template <>
struct BytesToType<16>
{
    using Type = uint4;
    static_assert(sizeof(Type) == 16);
};

template <>
struct BytesToType<8>
{
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template <>
struct BytesToType<4>
{
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template <>
struct BytesToType<2>
{
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template <>
struct BytesToType<1>
{
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};
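
// Usage sketch (illustrative; vec_t, input_t and kNItems are assumed names, not defined in this header):
// kernel traits typically select their vectorized access type by mapping the per-thread byte count onto one
// of the specializations above, e.g.
//   using vec_t = typename BytesToType<sizeof(input_t) * kNItems>::Type;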

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename scalar_t, int N>
struct Converter
{
    static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N])
    {
#pragma unroll
        for (int i = 0; i < N; ++i)
        {
            dst[i] = src[i];
        }
    }
};

template <int N>
struct Converter<half, N>
{
    static inline __device__ void to_float(const half (&src)[N], float (&dst)[N])
    {
        static_assert(N % 2 == 0);
        auto& src2 = reinterpret_cast<const half2(&)[N / 2]>(src);
        auto& dst2 = reinterpret_cast<float2(&)[N / 2]>(dst);
#pragma unroll
        for (int i = 0; i < N / 2; ++i)
        {
            dst2[i] = __half22float2(src2[i]);
        }
    }
};

#if __CUDA_ARCH__ >= 800
template <int N>
struct Converter<__nv_bfloat16, N>
{
    static inline __device__ void to_float(const __nv_bfloat16 (&src)[N], float (&dst)[N])
    {
        static_assert(N % 2 == 0);
        auto& src2 = reinterpret_cast<const nv_bfloat162(&)[N / 2]>(src);
        auto& dst2 = reinterpret_cast<float2(&)[N / 2]>(dst);
#pragma unroll
        for (int i = 0; i < N / 2; ++i)
        {
            dst2[i] = __bfloat1622float2(src2[i]);
        }
    }
};
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename scalar_t>
struct SSMScanOp;

template <>
struct SSMScanOp<float>
{
    __device__ __forceinline__ float2 operator()(const float2& ab0, const float2& ab1) const
    {
        return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
    }
};
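
// Note: SSMScanOp is the associative combine for the first-order recurrence h_t = a_t * h_{t-1} + b_t, with
// each element packed as (a, b) in a float2. Combining (a0, b0) with (a1, b1) yields (a1 * a0, a1 * b0 + b1),
// the coefficients of applying the recurrence twice, so the operator can be used inside a block-wide parallel
// scan. Its identity element is (1, 0).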

// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations.
template <typename scalar_t>
struct SSMScanPrefixCallbackOp
{
    using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
    scan_t running_prefix;

    // Constructor
    __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_)
        : running_prefix(running_prefix_)
    {
    }

    // Callback operator to be entered by the first warp of threads in the block.
    // Thread-0 is responsible for returning a value for seeding the block-wide scan.
    __device__ scan_t operator()(scan_t block_aggregate)
    {
        scan_t old_prefix = running_prefix;
        running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
        return old_prefix;
    }
};
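
// Usage sketch (illustrative; BlockScanT, smem_scan and thread_data are assumed names for a cub::BlockScan
// instantiation over scan_t, its shared-memory temp storage, and the per-thread (a, b) pairs):
//
//   SSMScanPrefixCallbackOp<float> prefix_op(make_float2(1.f, 0.f)); // start from the identity (a, b) = (1, 0)
//   Ktraits::BlockScanT(smem_scan).InclusiveScan(thread_data, thread_data, SSMScanOp<float>(), prefix_op);
//
// cub invokes prefix_op from the first warp with the block aggregate; the value returned by thread 0 seeds the
// block-wide scan, so the running prefix chains consecutive sequence chunks into one long scan.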

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Ktraits>
inline __device__ void load_input(typename Ktraits::input_t* u, typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
    typename Ktraits::BlockLoadT::TempStorage& smem_load, int seqlen)
{
    if constexpr (Ktraits::kIsEvenLen)
    {
        auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
        using vec_t = typename Ktraits::vec_t;
        Ktraits::BlockLoadVecT(smem_load_vec)
            .Load(reinterpret_cast<vec_t*>(u), reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals));
    }
    else
    {
        Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
    }
}

template <typename Ktraits>
inline __device__ void load_weight(typename Ktraits::input_t* Bvar,
    typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
    typename Ktraits::BlockLoadWeightT::TempStorage& smem_load_weight, int seqlen)
{
    constexpr int kNItems = Ktraits::kNItems;
    typename Ktraits::input_t B_vals_load[kNItems];
    if constexpr (Ktraits::kIsEvenLen)
    {
        auto& smem_load_weight_vec
            = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
        using vec_t = typename Ktraits::vec_t;
        Ktraits::BlockLoadWeightVecT(smem_load_weight_vec)
            .Load(reinterpret_cast<vec_t*>(Bvar), reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load));
    }
    else
    {
        Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
    }
    // #pragma unroll
    // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
    Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
}

template <typename Ktraits>
inline __device__ void store_output(typename Ktraits::input_t* out, const float (&out_vals)[Ktraits::kNItems],
    typename Ktraits::BlockStoreT::TempStorage& smem_store, int seqlen)
{
    typename Ktraits::input_t write_vals[Ktraits::kNItems];
#pragma unroll
    for (int i = 0; i < Ktraits::kNItems; ++i)
    {
        write_vals[i] = out_vals[i];
    }
    if constexpr (Ktraits::kIsEvenLen)
    {
        auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
        using vec_t = typename Ktraits::vec_t;
        Ktraits::BlockStoreVecT(smem_store_vec)
            .Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals));
    }
    else
    {
        Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
    }
}

} // namespace kernels
} // namespace tensorrt_llm