#include "tensorrt_llm/kernels/preQuantScaleKernel.h" namespace tensorrt_llm { namespace kernels { namespace { template struct Vec2Type; template <> struct Vec2Type { using type = half2; }; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) template <> struct Vec2Type<__nv_bfloat16> { using type = __nv_bfloat162; }; #endif }; // namespace template __global__ void apply_per_channel_scale(T* smoothed_act, const T* act, const T* per_channel_scale, int rows, int cols) { static constexpr int kElems = sizeof(AccessType) / sizeof(T); T scale[kElems], act_vec[kElems]; int col_offset = blockIdx.x * blockDim.x + threadIdx.x; int row_offset = blockIdx.y; if (col_offset * kElems >= cols || row_offset * kProcessRows >= rows) return; act += row_offset * kProcessRows * cols; smoothed_act += row_offset * kProcessRows * cols; *reinterpret_cast(scale) = reinterpret_cast(per_channel_scale)[col_offset]; #pragma unroll for (int i = 0; i < kProcessRows; ++i) { *reinterpret_cast(act_vec) = reinterpret_cast(act + i * cols)[col_offset]; if constexpr ((std::is_same_v #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) || std::is_same_v #endif ) &&(kElems % 2 == 0)) { using Vec2 = typename Vec2Type::type; #pragma unroll for (int j = 0; j < kElems; j += 2) { *reinterpret_cast(act_vec + j) = __hmul2(*reinterpret_cast(act_vec + j), *reinterpret_cast(scale + j)); } } else { #pragma unroll for (int j = 0; j < kElems; ++j) { act_vec[j] = static_cast(static_cast(act_vec[j]) * static_cast(scale[j])); } } reinterpret_cast(smoothed_act + i * cols)[col_offset] = *reinterpret_cast(act_vec); } } template void apply_per_channel_scale_kernel_launcher_( T* smoothed_act, const T* act, const T* per_channel_scale, int rows, int cols, cudaStream_t stream = 0) { static constexpr int kElems = sizeof(AccessType) / sizeof(T); dim3 block(128); dim3 grid((cols / kElems + block.x - 1) / block.x, (rows + kProcessRows - 1) / kProcessRows); apply_per_channel_scale <<>>(smoothed_act, act, per_channel_scale, rows, cols); } template void apply_per_channel_scale_kernel_launcher( T* smoothed_act, const T* act, const T* per_channel_scale, int rows, int cols, cudaStream_t stream) { int elems = rows * cols; if (elems < 2048 * 2048) { apply_per_channel_scale_kernel_launcher_( smoothed_act, act, per_channel_scale, rows, cols, stream); } else if (elems < 4096 * 4096) { apply_per_channel_scale_kernel_launcher_( smoothed_act, act, per_channel_scale, rows, cols, stream); } else if (elems < 8192 * 8192) { apply_per_channel_scale_kernel_launcher_( smoothed_act, act, per_channel_scale, rows, cols, stream); } else { apply_per_channel_scale_kernel_launcher_( smoothed_act, act, per_channel_scale, rows, cols, stream); } } #define INSTANTIATE_PREQUANT_SCALE(T) \ template void apply_per_channel_scale_kernel_launcher( \ T * smoothed_act, const T* act, const T* per_channel_scale, int rows, int cols, cudaStream_t stream) INSTANTIATE_PREQUANT_SCALE(half); #if defined(ENABLE_BF16) INSTANTIATE_PREQUANT_SCALE(__nv_bfloat16); #endif } // namespace kernels } // namespace tensorrt_llm