/*
 * Adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_common.h
 * Copyright (c) 2023, Tri Dao.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * Not a contribution
 * Changes made by NVIDIA CORPORATION & AFFILIATES or otherwise documented as
 * NVIDIA-proprietary are not a contribution and subject to the following terms and conditions:
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>

namespace tensorrt_llm
{
namespace kernels
{

#define MAX_DSTATE 256

// Element-wise addition for CUDA built-in vector types.
inline __device__ float2 operator+(const float2& a, const float2& b)
{
    return {a.x + b.x, a.y + b.y};
}

inline __device__ float3 operator+(const float3& a, const float3& b)
{
    return {a.x + b.x, a.y + b.y, a.z + b.z};
}

inline __device__ float4 operator+(const float4& a, const float4& b)
{
    return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
/// @param COND       - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ...        - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)                                                                             \
    [&]                                                                                                                \
    {                                                                                                                  \
        if (COND)                                                                                                      \
        {                                                                                                              \
            static constexpr bool CONST_NAME = true;                                                                   \
            return __VA_ARGS__();                                                                                      \
        }                                                                                                              \
        else                                                                                                           \
        {                                                                                                              \
            static constexpr bool CONST_NAME = false;                                                                  \
            return __VA_ARGS__();                                                                                      \
        }                                                                                                              \
    }()

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int BYTES>
struct BytesToType
{
};

template <>
struct BytesToType<16>
{
    using Type = uint4;
    static_assert(sizeof(Type) == 16);
};

template <>
struct BytesToType<8>
{
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template <>
struct BytesToType<4>
{
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template <>
struct BytesToType<2>
{
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template <>
struct BytesToType<1>
{
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename scalar_t, int N>
struct Converter
{
    static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N])
    {
#pragma unroll
        for (int i = 0; i < N; ++i)
        {
            dst[i] = src[i];
        }
    }
};

template <int N>
struct Converter<half, N>
{
    static inline __device__ void to_float(const half (&src)[N], float (&dst)[N])
    {
        static_assert(N % 2 == 0);
        auto& src2 = reinterpret_cast<const half2(&)[N / 2]>(src);
        auto& dst2 = reinterpret_cast<float2(&)[N / 2]>(dst);
#pragma unroll
        for (int i = 0; i < N / 2; ++i)
        {
            dst2[i] = __half22float2(src2[i]);
        }
    }
};

#if __CUDA_ARCH__ >= 800
template <int N>
struct Converter<__nv_bfloat16, N>
{
    static inline __device__ void to_float(const __nv_bfloat16 (&src)[N], float (&dst)[N])
    {
        static_assert(N % 2 == 0);
        auto& src2 = reinterpret_cast<const __nv_bfloat162(&)[N / 2]>(src);
        auto& dst2 = reinterpret_cast<float2(&)[N / 2]>(dst);
#pragma unroll
        for (int i = 0; i < N / 2; ++i)
        {
            dst2[i] = __bfloat1622float2(src2[i]);
        }
    }
};
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename scalar_t>
struct SSMScanOp;

// Each scan element packs the coefficients (a, b) of an affine state update h -> a * h + b.
// Composing the update (a0, b0) followed by (a1, b1) yields (a1 * a0, a1 * b0 + b1),
// which is associative and therefore suitable for a block-wide scan.
template <>
struct SSMScanOp<float>
{
    __device__ __forceinline__ float2 operator()(const float2& ab0, const float2& ab1) const
    {
        return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
    }
};

// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations.
template <typename scalar_t>
struct SSMScanPrefixCallbackOp
{
    using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
    scan_t running_prefix;

    // Constructor
    __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_)
        : running_prefix(running_prefix_)
    {
    }

    // Callback operator to be entered by the first warp of threads in the block.
    // Thread-0 is responsible for returning a value for seeding the block-wide scan.
    __device__ scan_t operator()(scan_t block_aggregate)
    {
        scan_t old_prefix = running_prefix;
        running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
        return old_prefix;
    }
};
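/// Illustrative usage sketch (an assumption, not an API defined in this header): the prefix callback
/// is meant to be handed to a CUB block-wide inclusive scan together with SSMScanOp, so the running
/// state carries over between consecutive chunks of the sequence. BlockScanT, smem_scan and
/// thread_data below are placeholders for kernel-side definitions, not names declared in this file.
/// ```
/// SSMScanPrefixCallbackOp<float> prefix_op(running_prefix); // state left over from the previous chunk
/// // thread_data holds this thread's float2 (a, b) pairs.
/// BlockScanT(smem_scan).InclusiveScan(thread_data, thread_data, SSMScanOp<float>(), prefix_op);
/// // prefix_op.running_prefix now holds the combined state after this chunk.
/// ```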
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Ktraits>
inline __device__ void load_input(typename Ktraits::input_t* u, typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
    typename Ktraits::BlockLoadT::TempStorage& smem_load, int seqlen)
{
    if constexpr (Ktraits::kIsEvenLen)
    {
        auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
        using vec_t = typename Ktraits::vec_t;
        Ktraits::BlockLoadVecT(smem_load_vec)
            .Load(reinterpret_cast<vec_t*>(u), reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals));
    }
    else
    {
        Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
    }
}

template <typename Ktraits>
inline __device__ void load_weight(typename Ktraits::input_t* Bvar,
    typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
    typename Ktraits::BlockLoadWeightT::TempStorage& smem_load_weight, int seqlen)
{
    constexpr int kNItems = Ktraits::kNItems;
    typename Ktraits::input_t B_vals_load[kNItems];
    if constexpr (Ktraits::kIsEvenLen)
    {
        auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
        using vec_t = typename Ktraits::vec_t;
        Ktraits::BlockLoadWeightVecT(smem_load_weight_vec)
            .Load(reinterpret_cast<vec_t*>(Bvar), reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load));
    }
    else
    {
        Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
    }
    // #pragma unroll
    // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
    Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
}

template <typename Ktraits>
inline __device__ void store_output(typename Ktraits::input_t* out, const float (&out_vals)[Ktraits::kNItems],
    typename Ktraits::BlockStoreT::TempStorage& smem_store, int seqlen)
{
    typename Ktraits::input_t write_vals[Ktraits::kNItems];
#pragma unroll
    for (int i = 0; i < Ktraits::kNItems; ++i)
    {
        write_vals[i] = out_vals[i];
    }
    if constexpr (Ktraits::kIsEvenLen)
    {
        auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
        using vec_t = typename Ktraits::vec_t;
        Ktraits::BlockStoreVecT(smem_store_vec)
            .Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals));
    }
    else
    {
        Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
    }
}

} // namespace kernels
} // namespace tensorrt_llm