/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cuda_fp16.h>
#include <type_traits>

TRTLLM_NAMESPACE_BEGIN
namespace common
{
#ifdef ENABLE_FP8

constexpr int CTA_SIZE = 256;

template <bool QUANTIZE>
__inline__ __device__ float scale(float a, float b)
{
    return QUANTIZE ? a / b : a * b;
}

template <typename T_OUT, typename T_S, typename T_IN, QuantizeMode QUANTIZE_MODE, bool QUANTIZE>
__global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaGridDependencySynchronize();
#endif

    for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < numel; i += blockDim.x * gridDim.x)
    {
        if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL)
        {
            output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[i % lda])));
        }
        else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN)
        {
            output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[i / lda])));
        }
        else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR)
        {
            output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[0])));
        }
    }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}

template <typename T_OUT, typename T_S, typename T_IN>
void invokeQuantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda,
    QuantizeMode quantize_mode, cudaStream_t stream)
{
    dim3 grid(1024);
    dim3 block(CTA_SIZE);
    cudaLaunchConfig_t config;
    config.gridDim = grid;
    config.blockDim = block;
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.numAttrs = 1;
    config.attrs = attrs;
    if (quantize_mode == QuantizeMode::PER_CHANNEL)
    {
        cudaLaunchKernelEx(&config, scaleMatrix<T_OUT, T_S, T_IN, QuantizeMode::PER_CHANNEL, true>, output,
            input_scale, input, numel, lda);
    }
    else if (quantize_mode == QuantizeMode::PER_TOKEN)
    {
        cudaLaunchKernelEx(&config, scaleMatrix<T_OUT, T_S, T_IN, QuantizeMode::PER_TOKEN, true>, output, input_scale,
            input, numel, lda);
    }
    else if (quantize_mode == QuantizeMode::PER_TENSOR)
    {
        cudaLaunchKernelEx(&config, scaleMatrix<T_OUT, T_S, T_IN, QuantizeMode::PER_TENSOR, true>, output, input_scale,
            input, numel, lda);
    }

    sync_check_cuda_error(stream);
}
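// Usage sketch for invokeQuantizeMatrix (hypothetical buffer names and shapes; invokeDequantizeMatrix
// below follows the same calling convention but multiplies by the scale instead of dividing):
//
//   // Quantize a [num_tokens, hidden_dim] half activation matrix to FP8 with one scale per token.
//   invokeQuantizeMatrix(dst_fp8, per_token_scales, src_half, int64_t(num_tokens) * hidden_dim,
//       hidden_dim, QuantizeMode::PER_TOKEN, stream);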
template <typename T_OUT, typename T_S, typename T_IN>
void invokeDequantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda,
    QuantizeMode quantize_mode, cudaStream_t stream)
{
    dim3 grid(1024);
    dim3 block(CTA_SIZE);
    cudaLaunchConfig_t config;
    config.gridDim = grid;
    config.blockDim = block;
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.numAttrs = 1;
    config.attrs = attrs;
    if (quantize_mode == QuantizeMode::PER_CHANNEL)
    {
        cudaLaunchKernelEx(&config, scaleMatrix<T_OUT, T_S, T_IN, QuantizeMode::PER_CHANNEL, false>, output,
            input_scale, input, numel, lda);
    }
    else if (quantize_mode == QuantizeMode::PER_TOKEN)
    {
        cudaLaunchKernelEx(&config, scaleMatrix<T_OUT, T_S, T_IN, QuantizeMode::PER_TOKEN, false>, output, input_scale,
            input, numel, lda);
    }
    else if (quantize_mode == QuantizeMode::PER_TENSOR)
    {
        cudaLaunchKernelEx(&config, scaleMatrix<T_OUT, T_S, T_IN, QuantizeMode::PER_TENSOR, false>, output,
            input_scale, input, numel, lda);
    }

    sync_check_cuda_error(stream);
}

template <typename T_FAKE, typename T_OUT, typename T_IN>
__global__ void fakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel)
{
    for (int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < numel; tid += blockDim.x * gridDim.x)
    {
        // Round-trip through T_FAKE to emulate its precision loss while keeping the output in T_OUT.
        T_FAKE tmp = (T_FAKE) (static_cast<float>(src[tid]));
        dst[tid] = (T_OUT) (static_cast<float>(tmp));
    }
}

template <typename T_FAKE, typename T_OUT, typename T_IN>
void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaStream_t stream)
{
    fakeQuantize<T_FAKE><<<1024, CTA_SIZE, 0, stream>>>(dst, src, numel);
    sync_check_cuda_error(stream);
}

template void invokeFakeQuantize<__nv_fp8_e4m3, float, float>(
    float* dst, float const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<float, float, __nv_fp8_e4m3>(
    float* dst, __nv_fp8_e4m3 const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<__nv_fp8_e4m3, half, half>(
    half* dst, half const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>(
    __nv_bfloat16* dst, __nv_bfloat16 const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<half, half, float>(
    half* dst, float const* src, const int64_t numel, cudaStream_t stream);
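// The atomicMaxExtd overloads below provide an atomic max for element types CUDA does not cover
// natively; they are used to accumulate per-channel and per-tensor absolute maxima across blocks.
// The float overload exploits the fact that non-negative IEEE-754 floats compare in the same order
// as their unsigned-integer bit patterns, so a plain atomicMax on the reinterpreted bits suffices
// (hence the assert on a non-negative value). The half/bfloat16 overloads fall back to a
// compare-and-swap loop, while atomicMaxExtdV2 uses the packed atom.global.v2.*.max.noftz PTX
// instruction available on SM90+, updating only the 16-bit lane selected by the pointer's offset
// within its 4-byte-aligned word.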
__device__ float atomicMaxExtd(float* address, float val)
{
    assert(val >= 0);
    unsigned int* address_as_u = reinterpret_cast<unsigned int*>(address);
    unsigned int old = atomicMax(address_as_u, __float_as_uint(val));
    return __uint_as_float(old);
}

template <typename T>
inline __device__ T atomicMaxExtdV2(T* address, T val)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    static_assert(std::is_same_v<T, half> | std::is_same_v<T, __nv_bfloat16>, "T needs to be either half or bfloat16");

    // The address in 64 bits.
    uint64_t address_u64 = reinterpret_cast<uint64_t>(address);

    // Pack the input value into 32 bits.
    union
    {
        T v[2];
        uint16_t u[2];
    } old, tmp = {};

    int const loc = (address_u64 & 0x2) >> 1;
    tmp.v[loc] = val;

    // 4B aligned pointer.
    auto aligned_address = reinterpret_cast<uint32_t*>(address_u64 & ~0x3ull);

    if constexpr (std::is_same_v<T, half>)
    {
        asm volatile("atom.global.v2.f16.max.noftz {%0, %1}, [%2], {%3, %4};"
                     : "=h"(old.u[0]), "=h"(old.u[1])
                     : "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1]));
    }

    if constexpr (std::is_same_v<T, __nv_bfloat16>)
    {
        asm volatile("atom.global.v2.bf16.max.noftz {%0, %1}, [%2], {%3, %4};"
                     : "=h"(old.u[0]), "=h"(old.u[1])
                     : "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1]));
    }

    // Return the correct half.
    return old.v[loc];
#endif
}

__device__ half atomicMaxExtd(half* address, half val)
{
    unsigned short int* address_as_u = reinterpret_cast<unsigned short int*>(address);
    unsigned short int old = *address_as_u, assumed;

    while (val > __ushort_as_half(old))
    {
        assumed = old;
        old = atomicCAS(address_as_u, assumed, __half_as_ushort(val));
    }

    return __ushort_as_half(old);
}

__device__ __nv_bfloat16 atomicMaxExtd(__nv_bfloat16* address, __nv_bfloat16 val)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
    unsigned short int* address_as_u = reinterpret_cast<unsigned short int*>(address);
    unsigned short int old = *address_as_u, assumed;

    while (val > __ushort_as_bfloat16(old))
    {
        assumed = old;
        old = atomicCAS(address_as_u, assumed, __bfloat16_as_ushort(val));
    }

    return __ushort_as_bfloat16(old);
#else
    assert(0);
    asm volatile("brkpt;\n" ::);
    return __nv_bfloat16(0);
#endif
}

template <typename T_S, typename T_W, QuantizeMode QUANTIZE_MODE>
__global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t size, const int64_t n)
{
    constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
    if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL)
    {
        for (int64_t col = threadIdx.x; col < n; col += blockDim.x)
        {
            float max = 0.f;
            for (int64_t i = col + n * blockIdx.x; i < size; i += gridDim.x * n)
            {
                auto val = fabs(static_cast<float>(weights[i]));
                max = max > val ? max : val;
            }
            auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
            if constexpr (std::is_same_v<T_S, float>)
            {
                atomicMaxExtd(quant_ptr + col, scale);
            }
            else
            {
                auto const address_u64 = reinterpret_cast<uint64_t>(quant_ptr + col);
                if ((col == 0 && address_u64 % 4 != 0) || (col == n - 1 && address_u64 % 4 == 0))
                    atomicMaxExtd(quant_ptr + col, scale);
                else
                    atomicMaxExtdV2(quant_ptr + col, scale);
            }
#else  // Vector atomics require __CUDA_ARCH__ >= 900
            atomicMaxExtd(quant_ptr + col, scale);
#endif
        }
    }
    else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN)
    {
        auto const nrows = size / n;
        for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x)
        {
            float max = 0.f;
            for (int64_t i = threadIdx.x; i < n; i += blockDim.x)
            {
                auto val = fabs(static_cast<float>(weights[row * n + i]));
                max = max > val ? max : val;
            }
            max = blockReduceMax<float>(max);
            if (threadIdx.x == 0)
            {
                auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
                quant_ptr[row] = scale;
            }
        }
    }
    else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR)
    {
        float max = 0.f;
        for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < size; i += gridDim.x * blockDim.x)
        {
            auto val = fabs(static_cast<float>(weights[i]));
            max = max > val ? max : val;
        }
        max = blockReduceMax<float>(max);
        if (threadIdx.x == 0)
        {
            auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
            atomicMaxExtd(quant_ptr, scale);
        }
    }
}
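// The scale produced above is s = max(amax / FP8_E4M3_MAX, min_scaling_factor), where FP8_E4M3_MAX
// is the largest finite E4M3 value (448). For example, a channel whose absolute maximum is 8.96 gets
// s = 0.02, so quantization (x / s) maps it onto the representable E4M3 range; the floor of
// 1 / (448 * 512) keeps all-zero slices from producing a zero scale and a divide-by-zero on the
// quantize path.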
template <typename T_S, typename T_W>
void invokeComputeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t numel, const int64_t lda,
    QuantizeMode quantize_mode, cudaStream_t stream)
{
    if (quantize_mode == QuantizeMode::PER_TOKEN)
    {
        dim3 block(CTA_SIZE);
        dim3 grid(numel / lda);
        computeFP8QuantizeScale<T_S, T_W, QuantizeMode::PER_TOKEN>
            <<<grid, block, 0, stream>>>(quant_ptr, weights, numel, lda);
    }
    else if (quantize_mode == QuantizeMode::PER_CHANNEL)
    {
        dim3 block(CTA_SIZE);
        dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE);
        cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream);
        sync_check_cuda_error(stream);
        computeFP8QuantizeScale<T_S, T_W, QuantizeMode::PER_CHANNEL>
            <<<grid, block, 0, stream>>>(quant_ptr, weights, numel, lda);
    }
    else if (quantize_mode == QuantizeMode::PER_TENSOR)
    {
        dim3 block(1024);
        dim3 grid(1024);
        cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream);
        sync_check_cuda_error(stream);
        computeFP8QuantizeScale<T_S, T_W, QuantizeMode::PER_TENSOR>
            <<<grid, block, 0, stream>>>(quant_ptr, weights, numel, lda);
    }
    sync_check_cuda_error(stream);
}

#define DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(type_scale, type_in)                                                  \
    template void invokeComputeFP8QuantizeScale<type_scale, type_in>(type_scale * input_scale,                        \
        type_in const* weights, int64_t numel, int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream);

DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(half, half);
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, half);
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, float);
#ifdef ENABLE_BF16
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(__nv_bfloat16, __nv_bfloat16);
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, __nv_bfloat16);
#endif

template <typename T_OUT, typename T_S, typename T_IN>
__global__ void dynamicQuantizeMatrixPerToken(
    T_OUT* output, T_S* quant_ptr, T_IN const* input, int64_t numel, int64_t lda)
{
    extern __shared__ __align__(sizeof(float)) char _shmem[];
    T_IN* shmem = reinterpret_cast<T_IN*>(_shmem);
    constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
    auto const nrows = numel / lda;
    for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x)
    {
        float max = 0.f;
        for (int64_t i = threadIdx.x; i < lda; i += blockDim.x)
        {
            auto const in = input[row * lda + i];
            shmem[i] = in;
            auto val = fabs(static_cast<float>(in));
            max = max > val ? max : val;
        }
        max = blockAllReduceMax<float>(max); // __syncthreads() called so we can read shmem
        auto const s = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
        for (int64_t i = threadIdx.x; i < lda; i += blockDim.x)
        {
            // true means we are quantizing
            output[row * lda + i] = (T_OUT) scale<true>(static_cast<float>(shmem[i]), static_cast<float>(s));
        }
        if (threadIdx.x == 0)
        {
            quant_ptr[row] = s;
        }
    }
}
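// invokeComputeScalesAndQuantizeMatrix (below) fuses scale computation and quantization where it can.
// For PER_TOKEN it prefers dynamicQuantizeMatrixPerToken, which stages each row in dynamic shared
// memory, block-reduces the row's absolute maximum, and writes both the per-token scale and the
// quantized row in a single pass. Rows needing more than the default 48 KB of dynamic shared memory
// require an explicit opt-in via cudaFuncSetAttribute; if that opt-in fails, and for PER_CHANNEL /
// PER_TENSOR, it falls back to the two-pass path: computeFP8QuantizeScale followed by
// invokeQuantizeMatrix.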
template <typename T_OUT, typename T_S, typename T_IN>
void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T_IN* input, const int64_t numel,
    const int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream)
{
    if (quantize_mode == QuantizeMode::PER_TOKEN)
    {
        dim3 grid(numel / lda);
        bool use_shmem = true;
        auto const shmem_size = lda * sizeof(T_IN);
        if (shmem_size >= (48 << 10))
        {
            cudaError_t ret = cudaFuncSetAttribute(dynamicQuantizeMatrixPerToken<T_OUT, T_S, T_IN>,
                cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
            use_shmem = ret == cudaSuccess;
        }
        if (use_shmem)
        {
            // ensure the threadblock is as large as possible to increase occupancy
            dim3 block(std::min((lda + 31) / 32 * 32, static_cast<int64_t>(1024)));
            dynamicQuantizeMatrixPerToken<T_OUT, T_S, T_IN>
                <<<grid, block, shmem_size, stream>>>(output, quant_ptr, input, numel, lda);
        }
        else
        {
            dim3 block(CTA_SIZE);
            computeFP8QuantizeScale<T_S, T_IN, QuantizeMode::PER_TOKEN>
                <<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda);
            sync_check_cuda_error(stream);
            invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream);
        }
    }
    else if (quantize_mode == QuantizeMode::PER_CHANNEL)
    {
        dim3 block(CTA_SIZE);
        dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE);
        cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream);
        sync_check_cuda_error(stream);
        computeFP8QuantizeScale<T_S, T_IN, QuantizeMode::PER_CHANNEL>
            <<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda);
        sync_check_cuda_error(stream);
        invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream);
    }
    else if (quantize_mode == QuantizeMode::PER_TENSOR)
    {
        dim3 block(1024);
        dim3 grid(1024);
        cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream);
        sync_check_cuda_error(stream);
        computeFP8QuantizeScale<T_S, T_IN, QuantizeMode::PER_TENSOR>
            <<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda);
        sync_check_cuda_error(stream);
        invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream);
    }
    sync_check_cuda_error(stream);
}

#define DEFINE_INVOKE_QUANTIZE_MATRIX(type_out, type_scale, type_in)                                                   \
    template void invokeQuantizeMatrix<type_out, type_scale, type_in>(type_out * output,                              \
        type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda,                              \
        QuantizeMode quantize_mode, cudaStream_t stream);                                                              \
    template void invokeDequantizeMatrix<type_out, type_scale, type_in>(type_out * output,                            \
        type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda,                              \
        QuantizeMode quantize_mode, cudaStream_t stream);                                                              \
    template void invokeComputeScalesAndQuantizeMatrix<type_out, type_scale, type_in>(type_out * output,              \
        type_scale * input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode,       \
        cudaStream_t stream);

#ifdef ENABLE_FP8
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, float);
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, half);
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, half, half);
DEFINE_INVOKE_QUANTIZE_MATRIX(half, half, __nv_fp8_e4m3);
DEFINE_INVOKE_QUANTIZE_MATRIX(float, float, __nv_fp8_e4m3);
DEFINE_INVOKE_QUANTIZE_MATRIX(half, float, __nv_fp8_e4m3);
#ifdef ENABLE_BF16
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, __nv_bfloat16);
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16);
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_bfloat16, __nv_bfloat16, __nv_fp8_e4m3);
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_bfloat16, float, __nv_fp8_e4m3);
#endif
#endif

#endif // ENABLE_FP8
} // namespace common
TRTLLM_NAMESPACE_END