/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
#include "tensorrt_llm/common/cudaDriverWrapper.h"
#include "tensorrt_llm/common/cudaFp8Utils.h"
#if ENABLE_FP4
#include <cuda_fp4.h>
#endif
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/tllmException.h"

#include <algorithm>
#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iomanip>
#include <iostream>
#include <optional>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#ifndef _WIN32 // Linux
#include <sys/sysinfo.h>
#endif // not WIN32

#ifdef _WIN32 // Windows
#include <windows.h>
#undef ERROR // A Windows header file defines ERROR as 0, but it's used in our logger.h enum. Logging breaks without
             // this undef.
#endif // WIN32

TRTLLM_NAMESPACE_BEGIN

namespace common
{

// workspace for cublas gemm : 32MB
#define CUBLAS_WORKSPACE_SIZE 33554432

typedef struct __align__(4)
{
    half x, y, z, w;
} half4;

/* **************************** type definition ***************************** */

enum CublasDataType
{
    FLOAT_DATATYPE = 0,
    HALF_DATATYPE = 1,
    BFLOAT16_DATATYPE = 2,
    INT8_DATATYPE = 3,
    FP8_DATATYPE = 4
};

enum TRTLLMCudaDataType
{
    FP32 = 0,
    FP16 = 1,
    BF16 = 2,
    INT8 = 3,
    FP8 = 4
};

enum class OperationType
{
    FP32,
    FP16,
    BF16,
    INT8,
    FP8
};

/* **************************** debug tools ********************************* */

static char const* _cudaGetErrorEnum(cudaError_t error)
{
    return cudaGetErrorString(error);
}

static char const* _cudaGetErrorEnum(cublasStatus_t error)
{
    switch (error)
    {
    case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
    case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
    case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
    case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
    case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
    case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
    case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
    case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
    case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
    case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR";
    }
    return "";
}

template <typename T>
void check(T ptr, char const* const func, char const* const file, int const line)
{
    if (ptr)
    {
        throw TllmException(file, line,
            fmtstr("[TensorRT-LLM][ERROR] CUDA runtime error in %s: %s", func, _cudaGetErrorEnum(ptr)).c_str());
    }
}

template <typename T>
void checkEx(T ptr, std::initializer_list<T> const& validReturns, char const* const func, char const* const file,
    int const line)
{
    if (std::all_of(std::begin(validReturns), std::end(validReturns), [&ptr](T const& t) { return t != ptr; }))
    {
        throw TllmException(file, line,
            fmtstr("[TensorRT-LLM][ERROR] CUDA runtime error in %s: %s", func, _cudaGetErrorEnum(ptr)).c_str());
    }
}

#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
#define check_cuda_error_2(val, file, line) check((val), #val, file, line)
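
// Usage sketch for the checkers above (illustrative; `buffer`, `nBytes`, and `handle` are assumed to exist in the
// calling code):
//
//   check_cuda_error(cudaMalloc(&buffer, nBytes));                               // throws TllmException on error
//   check_cuda_error(cublasSetWorkspace(handle, buffer, CUBLAS_WORKSPACE_SIZE)); // also works for cublasStatus_t
//   checkEx(cudaFree(buffer), {cudaSuccess, cudaErrorCudartUnloading}, "cudaFree(buffer)", __FILE__, __LINE__);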

inline std::optional<bool> isCudaLaunchBlocking()
{
    thread_local bool firstCall = true;
    thread_local std::optional<bool> result = std::nullopt;

    if (firstCall)
    {
        char const* env = std::getenv("CUDA_LAUNCH_BLOCKING");
        if (env != nullptr && std::string(env) == "1")
        {
            result = true;
        }
        else
        {
            result = false;
        }
        firstCall = false;
    }
    return result;
}

inline bool isCapturing(cudaStream_t stream)
{
    cudaStreamCaptureStatus status;
    check_cuda_error(cudaStreamIsCapturing(stream, &status));
    return status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive;
}

inline bool doCheckError(cudaStream_t stream)
{
    auto const cudaLaunchBlocking = isCudaLaunchBlocking();
    if (cudaLaunchBlocking.has_value() && cudaLaunchBlocking.value())
    {
        return !isCapturing(stream);
    }
#ifndef NDEBUG
    // Debug builds will sync when we're not capturing unless explicitly disabled.
    bool const checkError = cudaLaunchBlocking.value_or(!isCapturing(stream));
#else
    bool const checkError = cudaLaunchBlocking.value_or(false);
#endif
    return checkError;
}

inline void syncAndCheck(cudaStream_t stream, char const* const file, int const line)
{
    if (doCheckError(stream))
    {
        cudaStreamSynchronize(stream);
        check(cudaGetLastError(), "cudaGetLastError", file, line);
    }
}

#define sync_check_cuda_error(stream) tensorrt_llm::common::syncAndCheck(stream, __FILE__, __LINE__)

#define PRINT_FUNC_NAME_() \
    do \
    { \
        std::cout << "[TensorRT-LLM][CALL] " << __FUNCTION__ << " " << std::endl; \
    } while (0)

// clang-format off

template <typename T> struct packed_type;
template <> struct packed_type<float>  { using type = float; }; // we don't need to pack float by default
template <> struct packed_type<half>   { using type = half2; };
#ifdef ENABLE_BF16
template<> struct packed_type<__nv_bfloat16> { using type = __nv_bfloat162; };
#endif
#ifdef ENABLE_FP8
template<> struct packed_type<__nv_fp8_e4m3> { using type = __nv_fp8x2_e4m3; };
#endif

template <typename T> struct num_elems;
template <> struct num_elems<float>  { static constexpr int value = 1; };
template <> struct num_elems<float2> { static constexpr int value = 2; };
template <> struct num_elems<float4> { static constexpr int value = 4; };
template <> struct num_elems<half>   { static constexpr int value = 1; };
template <> struct num_elems<half2>  { static constexpr int value = 2; };
#ifdef ENABLE_BF16
template <> struct num_elems<__nv_bfloat16>  { static constexpr int value = 1; };
template <> struct num_elems<__nv_bfloat162> { static constexpr int value = 2; };
#endif
#ifdef ENABLE_FP8
template <> struct num_elems<__nv_fp8_e4m3>   { static constexpr int value = 1; };
template <> struct num_elems<__nv_fp8x2_e4m3> { static constexpr int value = 2; };
#endif

template <typename T, int num> struct packed_as;
template <typename T> struct packed_as<T, 1> { using type = T; };
template<> struct packed_as<half, 2>    { using type = half2; };
template<> struct packed_as<float, 2>   { using type = float2; };
template<> struct packed_as<int8_t, 2>  { using type = int16_t; };
template<> struct packed_as<int32_t, 2> { using type = int2; };
template<> struct packed_as<half2, 1>   { using type = half; };
template<> struct packed_as<float2, 1>  { using type = float; };
#ifdef ENABLE_BF16
template<> struct packed_as<__nv_bfloat16, 2>  { using type = __nv_bfloat162; };
template<> struct packed_as<__nv_bfloat162, 1> { using type = __nv_bfloat16; };
#endif
#ifdef ENABLE_FP8
template<> struct packed_as<__nv_fp8_e4m3, 2>   { using type = __nv_fp8x2_e4m3; };
template<> struct packed_as<__nv_fp8x2_e4m3, 1> { using type = __nv_fp8_e4m3; };
template<> struct packed_as<__nv_fp8_e5m2, 2>   { using type = __nv_fp8x2_e5m2; };
template<> struct packed_as<__nv_fp8x2_e5m2, 1> { using type = __nv_fp8_e5m2; };
#endif
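
// Usage sketch for the packing traits above (illustrative; a hypothetical vectorized elementwise kernel):
//
//   template <typename T>
//   __global__ void scaleKernel(T* data, float scale)
//   {
//       using PackedT = typename packed_as<T, 2>::type;   // e.g. half -> half2, float -> float2
//       constexpr int kElemsPerAccess = num_elems<PackedT>::value;
//       PackedT* packedData = reinterpret_cast<PackedT*>(data);
//       // ... each thread then loads/stores kElemsPerAccess elements per access ...
//   }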

inline __device__ float2 operator*(float2 a, float2 b) { return make_float2(a.x * b.x, a.y * b.y); }
inline __device__ float2 operator+(float2 a, float2 b) { return make_float2(a.x + b.x, a.y + b.y); }
inline __device__ float2 operator-(float2 a, float2 b) { return make_float2(a.x - b.x, a.y - b.y); }

inline __device__ float2 operator*(float2 a, float b) { return make_float2(a.x * b, a.y * b); }
inline __device__ float2 operator+(float2 a, float b) { return make_float2(a.x + b, a.y + b); }
inline __device__ float2 operator-(float2 a, float b) { return make_float2(a.x - b, a.y - b); }

// clang-format on

template <typename T>
struct CudaDataType
{
};

template <>
struct CudaDataType<float>
{
    static constexpr cudaDataType_t value = cudaDataType::CUDA_R_32F;
};

template <>
struct CudaDataType<half>
{
    static constexpr cudaDataType_t value = cudaDataType::CUDA_R_16F;
};

#ifdef ENABLE_BF16
template <>
struct CudaDataType<__nv_bfloat16>
{
    static constexpr cudaDataType_t value = cudaDataType::CUDA_R_16BF;
};
#endif

/// @brief Get the SM version of the current device.
/// @param queryRealSmArch Whether to query the real SM architecture. For example, query the real SM arch when doing
/// LUT tuning, and the mapped SM arch when reusing SM120 code on SM121 devices.
/// @return The SM version of the current device.
inline int getSMVersion(bool queryRealSmArch = false)
{
    int device{-1};
    check_cuda_error(cudaGetDevice(&device));
    int sm_major = 0;
    int sm_minor = 0;
    check_cuda_error(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
    check_cuda_error(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
    int sm = sm_major * 10 + sm_minor;
    if (sm == 121 && !queryRealSmArch)
    {
        return 120;
    }
    return sm;
}

inline bool isSM100Family()
{
    int const sm = getSMVersion();
    return sm == 100 || sm == 103; // To be continued...
}

inline int getDevice()
{
    int deviceID{0};
    check_cuda_error(cudaGetDevice(&deviceID));
    return deviceID;
}

inline int getDeviceCount()
{
    int count{0};
    check_cuda_error(cudaGetDeviceCount(&count));
    return count;
}

/// @brief Identifies the memory type of the given pointer.
template <typename T>
cudaMemoryType getPtrCudaMemoryType(T* ptr)
{
    cudaPointerAttributes attributes{};
    check_cuda_error(cudaPointerGetAttributes(&attributes, ptr));
    return attributes.type;
}

/// Get the memory info.
/// \return The free and total amount of memory, in bytes.
inline std::tuple<size_t, size_t> getDeviceMemoryInfo(bool const useUvm)
{
    if (useUvm)
    {
        size_t freeSysMem = 0;
        size_t totalSysMem = 0;
#ifndef _WIN32 // Linux
        struct sysinfo info{};
        sysinfo(&info);
        totalSysMem = info.totalram * info.mem_unit;
        freeSysMem = info.freeram * info.mem_unit;
#else // Windows
        MEMORYSTATUSEX memInfo;
        memInfo.dwLength = sizeof(memInfo);
        GlobalMemoryStatusEx(&memInfo);
        totalSysMem = memInfo.ullTotalPhys;
        freeSysMem = memInfo.ullAvailPhys;
#endif // WIN32
        TLLM_LOG_INFO("Using UVM based system memory for KV cache, total memory %0.2f GB, available memory %0.2f GB",
            ((double) totalSysMem / 1e9), ((double) freeSysMem / 1e9));
        return {freeSysMem, totalSysMem};
    }

    size_t free = 0;
    size_t total = 0;
    check_cuda_error(cudaMemGetInfo(&free, &total));
    TLLM_LOG_DEBUG("Using GPU memory for KV cache, total memory %0.2f GB, available memory %0.2f GB",
        ((double) total / 1e9), ((double) free / 1e9));
    return {free, total};
}
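
// Usage sketch for the device-query helpers above (illustrative):
//
//   int const sm = getSMVersion();           // e.g. 90 on SM90, 120 on SM120 and (by default) on SM121
//   auto const [freeMem, totalMem] = getDeviceMemoryInfo(/*useUvm=*/false);
//   TLLM_LOG_INFO("sm=%d, free=%zu B, total=%zu B", sm, freeMem, totalMem);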

/// @brief Gets the memory allocation granularity for the current device.
///
/// @return size_t The size of the smallest difference in memory size supported by the current device.
inline size_t getAllocationGranularity()
{
    auto const currentDevice = getDevice();
    ::CUmemAllocationProp prop = {};

    prop.type = ::CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.location.type = ::CU_MEM_LOCATION_TYPE_DEVICE;
    prop.location.id = currentDevice;
    prop.requestedHandleTypes = ::CU_MEM_HANDLE_TYPE_NONE;

    // Get the minimum granularity supported for allocation with cuMemCreate()
    size_t granularity = 0;
    TLLM_CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
    return granularity;
}

inline int getMultiProcessorCount()
{
    int nSM{0};
    int deviceID{0};
    check_cuda_error(cudaGetDevice(&deviceID));
    check_cuda_error(cudaDeviceGetAttribute(&nSM, cudaDevAttrMultiProcessorCount, deviceID));
    return nSM;
}

inline int getMaxSharedMemoryPerSM()
{
    int nByteMaxSharedMemoryPerSM{0};
    int deviceID{0};
    check_cuda_error(cudaGetDevice(&deviceID));
    check_cuda_error(
        cudaDeviceGetAttribute(&nByteMaxSharedMemoryPerSM, cudaDevAttrMaxSharedMemoryPerMultiprocessor, deviceID));
    return nByteMaxSharedMemoryPerSM;
}

inline int getMaxSharedMemoryPerBlockOptin()
{
    int nByteMaxSharedMemoryPerBlockOptin{0};
    int deviceID{0};
    check_cuda_error(cudaGetDevice(&deviceID));
    check_cuda_error(
        cudaDeviceGetAttribute(&nByteMaxSharedMemoryPerBlockOptin, cudaDevAttrMaxSharedMemoryPerBlockOptin, deviceID));
    return nByteMaxSharedMemoryPerBlockOptin;
}

template <typename T>
inline int getMaxActiveBlocksPerSM(T kernel, int blockSize, size_t dynamicSMemSize)
{
    static std::unordered_map<T, int> cache;
    auto it = cache.find(kernel);
    if (it != cache.end())
    {
        return it->second;
    }
    int numBlocks;
    check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocks, kernel, blockSize, dynamicSMemSize));
    cache[kernel] = numBlocks;
    return numBlocks;
}

template <typename T1, typename T2>
inline size_t divUp(T1 const& a, T2 const& b)
{
    auto const tmp_a = static_cast<size_t>(a);
    auto const tmp_b = static_cast<size_t>(b);
    return (tmp_a + tmp_b - 1) / tmp_b;
}

inline int roundUp(int a, int b)
{
    return divUp(a, b) * b;
}

template <typename T, typename U, typename = std::enable_if_t<std::is_integral<T>::value>,
    typename = std::enable_if_t<std::is_integral<U>::value>>
auto constexpr ceilDiv(T numerator, U denominator)
{
    return (numerator + denominator - 1) / denominator;
}

template <typename T>
void printArrayInfo(T const* ptr, uint64_t nElement = 1, std::string name = "", bool const bPrintElement = false)
{
    if (ptr == nullptr)
    {
        TLLM_LOG_WARNING("%s is a nullptr, skip!", name.c_str());
        return;
    }
    cudaDeviceSynchronize();
    check_cuda_error(cudaGetLastError());

    bool const isDevicePtr = (getPtrCudaMemoryType(ptr) == cudaMemoryTypeDevice);
    size_t sizeInByte = sizeof(T) * nElement;
    TLLM_LOG_TRACE("addr=%p, location=%s, sizeof(T)=%lu, nElement=%lu, sizeInByte=%lu\n", ptr,
        (isDevicePtr ? "Device" : "Host"), sizeof(T), nElement, sizeInByte);

    T* tmp = const_cast<T*>(ptr);
    std::vector<T> tmpVec;
    // For device pointer
    if (isDevicePtr)
    {
        tmpVec.resize(nElement);
        tmp = tmpVec.data(); // Note `data()` is not supported for vector<bool>
        check_cuda_error(cudaMemcpy(tmp, ptr, sizeInByte, cudaMemcpyDeviceToHost));
        cudaDeviceSynchronize();
    }

    size_t nInf = 0;
    size_t nNaN = 0;
    size_t nZero = 0;
    double sum = 0.0;
    double sqrSum = 0.0;
    double absSum = 0.0;
    float allMax = -1.0e6f;
    float allMin = 1.0e6f;
    float allSad = 0.0f; // Sum Abs of Difference, to distinguish A and its transpose
    float old = 0.0f;
    for (uint64_t i = 0; i < nElement; i++)
    {
        float val = (float) tmp[i];
        if (std::isinf(val))
        {
            nInf++;
            continue;
        }
        if (std::isnan(val))
        {
            nNaN++;
            continue;
        }
        nZero += (val == 0.0f);
        sum += val;
        sqrSum += val * val;
        absSum += fabsf(val);
        allMax = std::max(allMax, val);
        allMin = std::min(allMin, val);
        allSad += fabsf(val - old);
        old = val;
    }
    float avg = sum / nElement;
    float std = sqrtf(sqrSum / nElement - avg * avg);

    TLLM_LOG_INFO("%s", name.c_str());
    TLLM_LOG_INFO("size=%lu, nInf=%zu, nNaN=%zu, nZero=%zu", nElement, nInf, nNaN, nZero);
    TLLM_LOG_INFO("avg=%f, absSum: %f, std=%f, max=%f, min=%f, sad=%f", avg, absSum, std, allMax, allMin, allSad);

    if (bPrintElement)
    {
        uint64_t constexpr nHead = 5;
        std::stringstream ss;
        ss << std::setw(10) << std::fixed << std::setprecision(3);
        for (uint64_t i = 0; i < std::min(nElement, nHead); ++i)
        {
            ss << (float) tmp[i] << ", ";
        }
        if (nElement > nHead)
        {
            ss << " ... ";
            for (uint64_t i = nElement - nHead; i < nElement; ++i)
            {
                ss << (float) tmp[i] << ", ";
            }
        }
        TLLM_LOG_INFO("%s", ss.str().c_str());
    }
    cudaDeviceSynchronize();
    check_cuda_error(cudaGetLastError());
}

template void printArrayInfo(float const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
template void printArrayInfo(half const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
#ifdef ENABLE_BF16
template void printArrayInfo(__nv_bfloat16 const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
#endif
#ifdef ENABLE_FP8
template void printArrayInfo(__nv_fp8_e4m3 const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
#endif
#ifdef ENABLE_FP4
template void printArrayInfo(__nv_fp4_e2m1 const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
#endif
template void printArrayInfo(uint32_t const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
template void printArrayInfo(uint64_t const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
template void printArrayInfo(int const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
template void printArrayInfo(uint8_t const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);

template <typename T>
void printToStream(T const* ptr, int const nElement, FILE* strm)
{
    bool const split_rows = (strm == stdout);
    if (ptr == nullptr)
    {
        TLLM_LOG_WARNING("Nullptr, skip!\n");
        return;
    }
    std::vector<T> tmp(nElement, 0);
    check_cuda_error(cudaMemcpy(tmp.data(), ptr, sizeof(T) * nElement, cudaMemcpyDeviceToHost));
    for (int i = 0; i < nElement; ++i)
    {
        fprintf(strm, "%f, ", static_cast<float>(tmp[i]));
        if (split_rows && ((i + 1) % 10) == 0)
            fprintf(strm, "\n");
    }
    if (!split_rows || (nElement % 10) != 0)
    {
        fprintf(strm, "\n");
    }
}

template <typename T>
void printToScreen(T const* ptr, int const nElement)
{
    printToStream(ptr, nElement, stdout);
}

template <typename T>
void print2dToStream(T const* ptr, int const nRow, int const nCol, int const nStride, FILE* strm)
{
    if (ptr == nullptr)
    {
        TLLM_LOG_WARNING("Nullptr, skip!\n");
        return;
    }
    for (int ri = 0; ri < nRow; ++ri)
    {
        T const* tmp = ptr + ri * nStride;
        printToStream(tmp, nCol, strm);
    }
    fprintf(strm, "\n");
}

template <typename T>
void print2dToScreen(T const* ptr, int const nRow, int const nCol, int const nStride)
{
    print2dToStream(ptr, nRow, nCol, nStride, stdout);
}

template <typename T>
void print2dToFile(std::string fname, T const* ptr, int const nRow, int const nCol, int const nStride)
{
    FILE* fp = fopen(fname.c_str(), "wt");
    if (fp != nullptr)
    {
        print2dToStream(ptr, nRow, nCol, nStride, fp);
        fclose(fp);
    }
}

__host__ __device__ inline void print_float_(float x)
{
    printf("%7.3f ", x);
}

__host__ __device__ inline void print_element_(float x)
{
    print_float_(x);
}

__host__ __device__ inline void print_element_(half x)
{
    print_float_((float) x);
}

#ifdef ENABLE_BF16
__host__ __device__ inline void print_element_(__nv_bfloat16 x)
{
    print_float_((float) x);
}
#endif

#ifdef ENABLE_FP8
__host__ __device__ inline void print_element_(__nv_fp8_e4m3 x)
{
    print_float_((float) x);
}
#endif

__host__ __device__ inline void print_element_(bool ui)
{
    printf("%7" PRIu32 " ", (unsigned int) ui);
}

__host__ __device__ inline void print_element_(uint8_t ui)
{
    printf("%7" PRIu32 " ", (unsigned int) ui);
}

__host__ __device__ inline void print_element_(uint32_t ul)
{
    printf("%7" PRIu32 " ", ul);
}

__host__ __device__ inline void print_element_(uint64_t ull)
{
    printf("%7" PRIu64 " ", ull);
}

__host__ __device__ inline void print_element_(int32_t il)
{
    printf("%7" PRId32 " ", il);
}

__host__ __device__ inline void print_element_(int64_t ill)
{
    printf("%7" PRId64 " ", ill);
}

template <typename T>
__host__ __device__ inline void print_elements(T const* ptr, int nRow, int nCol, int nStride)
{
    for (int iRow = -1; iRow < nRow; ++iRow)
    {
        if (iRow >= 0)
        {
            printf("%07d|", iRow);
        }
        else
        {
            printf("       |"); // heading row
        }
        for (int iCol = 0; iCol < nCol; iCol += 1)
        {
            if (iRow >= 0)
            {
                print_element_(ptr[iRow * nStride + iCol]);
            }
            else
            {
                printf("%7d|", iCol); // heading column
            }
        }
        printf("\n");
    }
    printf("\n");
}

template <typename T>
inline void printMatrix(T const* ptr, int nRow, int nCol, int nStride)
{
    // `nRow` is length of row dimension
    // `nStride` is length of column dimension
    // `nCol` (<= nStride) is length for print per row
    if (ptr == nullptr)
    {
        TLLM_LOG_WARNING("Nullptr, skip!\n");
        return;
    }
    cudaDeviceSynchronize();
    check_cuda_error(cudaGetLastError());

    bool const isDevicePtr = (getPtrCudaMemoryType(ptr) == cudaMemoryTypeDevice);
    size_t sizeInByte = sizeof(T) * nRow * nStride;
    TLLM_LOG_TRACE("addr=%p, location=%s, sizeof(T)=%lu, nRow=%d, nStride=%d, sizeInByte=%lu\n", ptr,
        (isDevicePtr ? "Device" : "Host"), sizeof(T), nRow, nStride, sizeInByte);

    if (isDevicePtr)
    {
        std::vector<T> tmpVec;
        tmpVec.resize(nRow * nStride);
        T* tmp = tmpVec.data(); // Note `data()` is not supported for vector<bool>
        check_cuda_error(cudaMemcpy(tmp, ptr, sizeInByte, cudaMemcpyDeviceToHost));
        cudaDeviceSynchronize();
        check_cuda_error(cudaGetLastError());
        print_elements(tmp, nRow, nCol, nStride);
    }
    else
    {
        print_elements(ptr, nRow, nCol, nStride);
    }
}

template void printMatrix(float const* ptr, int nRow, int nCol, int nStride);
template void printMatrix(half const* ptr, int nRow, int nCol, int nStride);
#ifdef ENABLE_BF16
template void printMatrix(__nv_bfloat16 const* ptr, int nRow, int nCol, int nStride);
#endif
#ifdef ENABLE_FP8
template void printMatrix(__nv_fp8_e4m3 const* ptr, int nRow, int nCol, int nStride);
#endif
template void printMatrix(uint32_t const* ptr, int nRow, int nCol, int nStride);
template void printMatrix(uint64_t const* ptr, int nRow, int nCol, int nStride);
template void printMatrix(int const* ptr, int nRow, int nCol, int nStride);
template void printMatrix(uint8_t const* ptr, int nRow, int nCol, int nStride);

template <typename T>
__device__ inline void printMatrixDevice(T const* ptr, int nRow, int nCol, int nStride)
{
    // `nRow` is length of row dimension
    // `nStride` is length of column dimension
    // `nCol` (<= nStride) is length for print per row
    // Can be called inside kernels by one single thread
    if (ptr == nullptr)
    {
        printf("Nullptr, skip!\n");
        return;
    }
    size_t sizeInByte = sizeof(T) * nRow * nStride;
    printf("addr=%p, sizeof(T)=%lu, nRow=%d, nStride=%d, sizeInByte=%lu\n", ptr, sizeof(T), nRow, nStride, sizeInByte);
    print_elements(ptr, nRow, nCol, nStride);
}

template __device__ void printMatrixDevice(float const* ptr, int nRow, int nCol, int nStride);
template __device__ void printMatrixDevice(half const* ptr, int nRow, int nCol, int nStride);
#ifdef ENABLE_BF16
template __device__ void printMatrixDevice(__nv_bfloat16 const* ptr, int nRow, int nCol, int nStride);
#endif
#ifdef ENABLE_FP8
template __device__ void printMatrixDevice(__nv_fp8_e4m3 const* ptr, int nRow, int nCol, int nStride);
#endif
template __device__ void printMatrixDevice(uint32_t const* ptr, int nRow, int nCol, int nStride);
template __device__ void printMatrixDevice(uint64_t const* ptr, int nRow, int nCol, int nStride);
template __device__ void printMatrixDevice(int const* ptr, int nRow, int nCol, int nStride);
template __device__ void printMatrixDevice(uint8_t const* ptr, int nRow, int nCol, int nStride);

#ifndef CUDA_CALL
#define CUDA_CALL(answer) \
    { \
        gpuAssert((answer), __FILE__, __LINE__); \
    }

inline void gpuAssert(cudaError_t code, char const* file, int line, bool abort = true)
{
    if (code != cudaSuccess)
    {
        fprintf(stderr, "CUDA error: %s @ %s:%d\n", cudaGetErrorString(code), file, line);
        if (abort)
            exit(code);
    }
}

inline void gpuAssert(CUresult code, char const* file, int line, bool abort = true)
{
    if (code != CUresult::CUDA_SUCCESS)
    {
        char const* buf = "Unknown error";
        assert(cuGetErrorString(code, &buf) == CUresult::CUDA_SUCCESS);
        fprintf(stderr, "Driver API error: %s @ %s:%d\n", buf, file, line);
        if (abort)
            exit(code);
    }
}
#endif

template <typename T>
struct UpperType;

template <>
struct UpperType<int8_t>
{
    using Type = int;
};

template <>
struct UpperType<uint8_t>
{
    using Type = uint32_t;
};

template <>
struct UpperType<int32_t>
{
    using Type = int;
};

template <>
struct UpperType<__nv_bfloat16>
{
    using Type = double;
};

template <>
struct UpperType<half>
{
    using Type = double;
};

template <>
struct UpperType<float>
{
    using Type = double;
};

extern "C"
{
    __device__ uint32_t __nvvm_get_smem_pointer(void* ptr);
}

__forceinline__ __device__ void issue_stas(uint32_t dist_barrier_ptr, uint32_t dist_buffer_ptr, uint32_t d0)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
    asm volatile("st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.b32 [%0], %2, [%1];\n\t"
                 :
                 : "r"(dist_buffer_ptr), "r"(dist_barrier_ptr), "r"(d0));
#endif
}

__forceinline__ __device__ void issue_stas(uint32_t dist_barrier_ptr, uint32_t dist_buffer_ptr, uint64_t d0)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
    asm volatile("st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.b64 [%0], %2, [%1];\n\t"
                 :
                 : "r"(dist_buffer_ptr), "r"(dist_barrier_ptr), "l"(d0));
#endif
}

__forceinline__ __device__ void issue_stas(
    uint32_t dist_barrier_ptr, uint32_t dist_buffer_ptr, uint32_t d0, uint32_t d1)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
    asm volatile("st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.b32 [%0], {%2, %3}, [%1];\n\t"
                 :
                 : "r"(dist_buffer_ptr), "r"(dist_barrier_ptr), "r"(d0), "r"(d1));
#endif
}

__forceinline__ __device__ void issue_stas(
    uint32_t dist_barrier_ptr, uint32_t dist_buffer_ptr, uint32_t d0, uint32_t d1, uint32_t d2, uint32_t d3)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
    asm volatile("st.async.shared::cluster.mbarrier::complete_tx::bytes.v4.b32 [%0], {%2, %3, %4, %5}, [%1];\n\t"
                 :
                 : "r"(dist_buffer_ptr), "r"(dist_barrier_ptr), "r"(d0), "r"(d1), "r"(d2), "r"(d3));
#endif
}

inline __device__ uint32_t elect_one_sync()
{
    uint32_t pred = 0;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
#if (defined(__CUDA_ARCH_FEAT_SM90_ALL))
    uint32_t laneid = 0;
    asm volatile(
        "\n\
    {\n\
        .reg .b32 %%rx;\n\
        .reg .pred %%px;\n\
        elect.sync %%rx|%%px, %2;\n\
        @%%px mov.s32 %1, 1;\n\
        mov.s32 %0, %%rx;\n\
    }\n\
"
        : "+r"(laneid), "+r"(pred)
        : "r"(0xFFFFFFFF));
#endif
#endif
    return pred;
}

__forceinline__ __device__ uint32_t get_smem_pointer(void const* ptr)
{
    return __nvvm_get_smem_pointer(const_cast<void*>(ptr));
}

__forceinline__ __device__ void bar_create(void* bar_ptr, int init_count)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
    unsigned smem_ptr = get_smem_pointer(bar_ptr);
    asm volatile(
        "{\n\t"
        "mbarrier.init.shared.b64 [%1], %0; \n\t"
        "}"
        :
        : "r"(init_count), "r"(smem_ptr));
#endif
}

struct Arrive_wait
{
public:
    __forceinline__ __device__ Arrive_wait()
    {
        bar_base_ = NULL;
    }

    __forceinline__ __device__ Arrive_wait(uint64_t* bar_base, int id = 0)
    {
        bar_base_ = bar_base;
        id_ = id;
    }

    __forceinline__ __device__ int bar_peek(int id, unsigned int bar_phase)
    {
        uint32_t result32{};
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
        auto* bar_ptr = bar_base_ + id;
        unsigned smem_ptr = get_smem_pointer(bar_ptr);
        asm volatile(
            "{\n\t"
            ".reg .pred P1; \n\t"
            "mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t"
            "selp.b32 %0, 1, 0, P1; \n\t"
            "}"
            : "=r"(result32)
            : "r"(smem_ptr), "r"(bar_phase));
#endif
        return result32;
    }

    __forceinline__ __device__ int bar_peek(int id, unsigned int bar_phase, int pred)
    {
        uint32_t result32{};
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
        auto* bar_ptr = bar_base_ + id;
        unsigned smem_ptr = get_smem_pointer(bar_ptr);
        asm volatile(
            "{\n\t"
            ".reg .pred P1; \n\t"
            ".reg .pred P2;\n\t"
            "setp.eq.u32 P2, %3, 1;\n\t"
            "@P2 mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t"
            "selp.b32 %0, 1, 0, P1; \n\t"
            "}"
            : "=r"(result32)
            : "r"(smem_ptr), "r"(bar_phase), "r"(pred));
#endif
        return result32;
    }

    __forceinline__ __device__ void bar_wait(int id, unsigned int bar_phase)
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
        auto* bar_ptr = bar_base_ + id;
        unsigned smem_ptr = get_smem_pointer(bar_ptr);
        asm volatile(
            "{\n\t"
            ".reg .pred P1; \n\t"
            "LAB_WAIT: \n\t"
            "mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 P1, [%0], %1; \n\t"
            "@P1 bra.uni DONE; \n\t"
            "bra.uni LAB_WAIT; \n\t"
            "DONE: \n\t"
            "}"
            :
            : "r"(smem_ptr), "r"(bar_phase));
#endif
    }

    __forceinline__ __device__ void bar_arrive_dsmem(int const& id)
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
        auto* bar_ptr = bar_base_ + id;
        asm volatile(
            "{\n\t"
            "mbarrier.arrive.b64 _, [%0];\n\t"
            "}"
            :
            : "l"(bar_ptr));
#endif
    }

    __forceinline__ __device__ void bar_arrive_dsmem(int const& id, uint32_t const& pred)
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
        asm volatile(
            "{\n\t"
            " .reg .pred p;\n\t"
            " .reg .s64 addr;\n\t"
            " .reg .b64 tmp;\n\t"
            " setp.eq.u32 p, %2, 1;\n\t"
            " mul.wide.s32 tmp, %0, 8;\n\t"
            " add.s64 addr, tmp, %1;\n\t"
            "@p mbarrier.arrive.b64 _, [addr];\n\t"
            "}"
            :
            : "r"(id), "l"(bar_base_), "r"(pred));
#endif
    }

    // Sets up the base address for arrival with the correct ctaid in cga
    __forceinline__ __device__ void set_bar_base_dsmem(uint32_t const& cta_id)
    {
        bar_base_ = reinterpret_cast<uint64_t*>(
            (reinterpret_cast<uint64_t>(bar_base_) & 0xFFFFFFFFF0FFFFFFULL) + (cta_id << 24));
    }

    __forceinline__ __device__ void bar_arrive_normal(int id, bool flag = true)
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
        if (flag == true)
        {
            uint64_t* bar_ptr = reinterpret_cast<uint64_t*>(bar_base_ + id);
            unsigned smem_ptr = get_smem_pointer(bar_ptr);
            asm volatile(
                "{\n\t"
                ".reg .b64 state; \n\t"
                "mbarrier.arrive.shared.b64 state, [%0];\n\t"
                "}"
                :
                : "r"(smem_ptr));
        }
#endif
    }

    __forceinline__ __device__ void bar_arrive_set_transactioncnt(int id, int expected_copy_bytes)
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
        auto* bar_ptr = bar_base_ + id;
        unsigned smem_ptr = get_smem_pointer(bar_ptr);
        asm volatile(
            "{\n\t"
            "mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1; \n\t"
            "}"
            :
            : "r"(smem_ptr), "r"(expected_copy_bytes));
#endif
    }

    __forceinline__ __device__ void bar_arrive_set_transactioncnt(int id, int expected_copy_bytes, uint32_t pred)
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
        auto* bar_ptr = bar_base_ + id;
        unsigned smem_ptr = get_smem_pointer(bar_ptr);
        asm volatile(
            "{\n\t"
            ".reg .pred p;\n\t"
            "setp.eq.u32 p, %2, 1;\n\t"
            "@p mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1; \n\t"
            "}"
            :
            : "r"(smem_ptr), "r"(expected_copy_bytes), "r"(pred));
#endif
    }

    __forceinline__ __device__ uint64_t* bar_base()
    {
        return bar_base_;
    }

    __forceinline__ __device__ uint64_t* get_bar_addr(int id)
    {
        return bar_base_ + id;
    }

private:
    // smem barrier base pointer
    uint64_t* bar_base_;
    // barrier id
    int id_;
};

__forceinline__ __device__ void cga_sync()
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    asm volatile("barrier.cluster.sync;\n" : :);
#endif
}

__forceinline__ __device__ void cga_arrive()
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    asm volatile("barrier.cluster.arrive.aligned;\n" : :);
#endif
}

__forceinline__ __device__ void cga_wait()
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    asm volatile("barrier.cluster.wait.aligned;\n" : :);
#endif
}

inline __device__ void fence_view_async_shared()
{
    // only compiles on sm90+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    asm volatile("fence.proxy.async.shared::cta;\n" : :);
#endif
}

template <typename T>
__forceinline__ __device__ T* get_DSMEM_ptr(T* localAddress, uint32_t destCtaId)
{
    T* dsmemAddress
        = reinterpret_cast<T*>(((unsigned long long int) localAddress & 0xFFFFFFFFF0FFFFFFULL) + (destCtaId << 24));
    return dsmemAddress;
}

template <typename T>
__forceinline__ __device__ void write_DSMEM_Address(T* localAddress, uint32_t destCtaId, T val)
{
    T* dsmemAddress = get_DSMEM_ptr(localAddress, destCtaId);
    *dsmemAddress = val;
}

__forceinline__ __device__ void arrive_barrier(uint64_t* p_barrier, uint32_t arrive_cnt = 1)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
    asm volatile("{mbarrier.arrive.shared.b64 _, [%0],%1;\n\t}" : : "l"(p_barrier), "r"(arrive_cnt));
#endif
}

__forceinline__ __device__ void arrive_DSMEM_barrier(uint64_t* p_barrier, uint32_t ctaid)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
    uint64_t* p_barrier_remote = get_DSMEM_ptr(p_barrier, ctaid);
    asm volatile("{mbarrier.arrive.b64 _, [%0];\n\t}" : : "l"(p_barrier_remote));
#endif
}

__forceinline__ __device__ void arrive_DSMEM_barrier_and_set_tx_cnt(
    uint64_t* p_barrier, uint32_t ctaid, uint32_t expected_copy_bytes)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
    uint32_t p_bar = static_cast<uint32_t>(reinterpret_cast<uint64_t>(get_DSMEM_ptr(p_barrier, ctaid)));
    asm volatile("{mbarrier.arrive.expect_tx.b64 _, [%0], %1; \n\t}" ::"r"(p_bar), "r"(expected_copy_bytes));
#endif
}

template <bool barSetTxCnt>
__forceinline__ __device__ void stas(uint32_t* p_data, uint64_t* p_barrier, uint32_t ctaid, uint32_t const& wrdat)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
    if (barSetTxCnt)
        arrive_DSMEM_barrier_and_set_tx_cnt(p_barrier, ctaid, sizeof(uint32_t));
    uint32_t buffer_ptr = static_cast<uint32_t>(reinterpret_cast<uint64_t>(p_data));
    uint32_t barrier_ptr = static_cast<uint32_t>(reinterpret_cast<uint64_t>(p_barrier));
    uint32_t buffer_ptr_, barrier_ptr_;
    asm volatile(
        "{\n\t"
        "setctarank.shared.u32 %0, %2, %4;\n\t"
        "setctarank.shared.u32 %1, %3, %4;\n\t"
        "}"
        : "=r"(buffer_ptr_), "=r"(barrier_ptr_)
        : "r"(buffer_ptr), "r"(barrier_ptr), "r"(ctaid));
    issue_stas(buffer_ptr_, barrier_ptr_, wrdat);
#endif
}

template <bool barSetTxCnt>
__forceinline__ __device__ void stas(uint64_t* p_data, uint64_t* p_barrier, uint32_t ctaid, uint64_t const& wrdat)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
    if (barSetTxCnt)
        arrive_DSMEM_barrier_and_set_tx_cnt(p_barrier, ctaid, sizeof(uint64_t));
    uint32_t buffer_ptr = static_cast<uint32_t>(reinterpret_cast<uint64_t>(p_data));
    uint32_t barrier_ptr = static_cast<uint32_t>(reinterpret_cast<uint64_t>(p_barrier));
    uint32_t buffer_ptr_, barrier_ptr_;
    asm volatile(
        "{\n\t"
        "setctarank.shared.u32 %0, %2, %4;\n\t"
        "setctarank.shared.u32 %1, %3, %4;\n\t"
        "}"
        : "=r"(buffer_ptr_), "=r"(barrier_ptr_)
        : "r"(buffer_ptr), "r"(barrier_ptr), "r"(ctaid));
    issue_stas(buffer_ptr_, barrier_ptr_, wrdat);
#endif
}

template <bool barSetTxCnt>
__forceinline__ __device__ void stas(uint64_t* p_data, uint64_t* p_barrier, uint32_t ctaid, uint32_t const wrdat0,
    uint32_t const wrdat1, uint32_t const wrdat2, uint32_t const wrdat3)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
    if (barSetTxCnt)
        arrive_DSMEM_barrier_and_set_tx_cnt(p_barrier, ctaid, 4 * sizeof(uint32_t));
    uint32_t buffer_ptr = static_cast<uint32_t>(reinterpret_cast<uint64_t>(p_data));
    uint32_t barrier_ptr = static_cast<uint32_t>(reinterpret_cast<uint64_t>(p_barrier));
    uint32_t buffer_ptr_, barrier_ptr_;
    asm volatile(
        "{\n\t"
        "setctarank.shared.u32 %0, %2, %4;\n\t"
        "setctarank.shared.u32 %1, %3, %4;\n\t"
        "}"
        : "=r"(buffer_ptr_), "=r"(barrier_ptr_)
        : "r"(buffer_ptr), "r"(barrier_ptr), "r"(ctaid));
    issue_stas(buffer_ptr_, barrier_ptr_, wrdat0, wrdat1, wrdat2, wrdat3);
#endif
}
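
// Usage sketch for the DSMEM store helpers above (illustrative; assumes a cluster kernel in which every CTA owns a
// __shared__ uint32_t `mailbox` and a __shared__ uint64_t `mailboxBar` initialized via bar_create()):
//
//   // Push one word into the shared memory of cluster CTA `peer`; with barSetTxCnt == true the call also arrives
//   // on the peer's barrier and credits sizeof(uint32_t) transaction bytes, so the peer's wait completes only
//   // once the data has landed.
//   stas<true>(&mailbox, &mailboxBar, peer, payload);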

template <typename T, bool barSetTxCnt, bool assumeAligned>
__forceinline__ __device__ void stas(T* p_data, uint64_t* p_barrier, uint32_t ctaid, T const& wrdat)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
    static_assert(sizeof(T) % 4 == 0);
    if (barSetTxCnt)
        arrive_DSMEM_barrier_and_set_tx_cnt(p_barrier, ctaid, sizeof(T));
    uint32_t buffer_ptr = static_cast<uint32_t>(reinterpret_cast<uint64_t>(p_data));
    uint32_t barrier_ptr = static_cast<uint32_t>(reinterpret_cast<uint64_t>(p_barrier));
    uint32_t buffer_ptr_, barrier_ptr_;
    asm volatile(
        "{\n\t"
        "setctarank.shared.u32 %0, %2, %4;\n\t"
        "setctarank.shared.u32 %1, %3, %4;\n\t"
        "}"
        : "=r"(buffer_ptr_), "=r"(barrier_ptr_)
        : "r"(buffer_ptr), "r"(barrier_ptr), "r"(ctaid));
    uint32_t const* p_wrdat_b32 = reinterpret_cast<uint32_t const*>(&wrdat);
    for (uint32_t offset = 0; offset < sizeof(T);)
    {
        if constexpr (assumeAligned)
        {
            if (offset + 16 <= sizeof(T))
            {
                // Use write_async_v4_b32
                issue_stas(buffer_ptr_ + offset, barrier_ptr_, p_wrdat_b32[offset / 4], p_wrdat_b32[offset / 4 + 1],
                    p_wrdat_b32[offset / 4 + 2], p_wrdat_b32[offset / 4 + 3]);
                offset += 16;
            }
            else if (offset + 8 <= sizeof(T) && (buffer_ptr + offset) % 8 == 0)
            {
                // Use write_async_v2_b32
                issue_stas(buffer_ptr_ + offset, barrier_ptr_, p_wrdat_b32[offset / 4], p_wrdat_b32[offset / 4 + 1]);
                offset += 8;
            }
            else
            {
                issue_stas(buffer_ptr_ + offset, barrier_ptr_, p_wrdat_b32[offset / 4]);
                offset += 4;
            }
        }
        else
        {
            issue_stas(buffer_ptr_ + offset, barrier_ptr_, p_wrdat_b32[offset / 4]);
            offset += 4;
        }
    }
#endif
}

struct OrderedMutex
{
    uint64_t barriers[2];

    __device__ void init(int tid0, int threads0, int threads1)
    {
        if (tid0)
        {
            bar_create(&barriers[0], threads0);
            bar_create(&barriers[1], threads1);
        }
    }

    OrderedMutex() = default;
    OrderedMutex(OrderedMutex const& other) = delete;
};

class OrderedMutexAccessor
{
public:
    struct State
    {
        int phase = 0;
    };

private:
    int _phase;
    int _id;
    Arrive_wait _barriers;

public:
    __device__ OrderedMutexAccessor(OrderedMutex& m, int id, State state)
        : _phase(state.phase)
        , _id(id)
        , _barriers(m.barriers)
    {
    }

    __device__ void arrive()
    {
        _barriers.bar_arrive_normal(_id);
    }

    __device__ void wait()
    {
        _barriers.bar_wait(_id ^ 1, _phase);
        _phase ^= 1;
    }

    __device__ State exportState()
    {
        return {.phase = _phase};
    }
};

template <typename T, T VALUE>
struct ConstExprWrapper
{
    static constexpr T value = VALUE;
};

template <int VALUE>
using ConstInt = ConstExprWrapper<int, VALUE>;

template <bool VALUE>
using ConstBool = ConstExprWrapper<bool, VALUE>;

template <typename T>
struct TmaDescType;

template <>
struct TmaDescType<__nv_bfloat16>
{
    static constexpr auto value = CUtensorMapDataType_enum::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
};

template <>
struct TmaDescType<float>
{
    static constexpr auto value = CUtensorMapDataType_enum::CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
};

#define DEFINE_MEMBER_CHECKER(member) \
    template <typename T, typename V = bool> \
    struct has_##member : std::false_type \
    { \
    }; \
    template <typename T> \
    struct has_##member<T, \
        typename std::enable_if<!std::is_same<decltype(std::declval<T>().member), void>::value, bool>::type> \
        : std::true_type \
    { \
    };

#define HAS_MEMBER(C, member) has_##member<C>::value

DEFINE_MEMBER_CHECKER(output)
DEFINE_MEMBER_CHECKER(residual)
DEFINE_MEMBER_CHECKER(bias)
DEFINE_MEMBER_CHECKER(deq)
DEFINE_MEMBER_CHECKER(qua)
DEFINE_MEMBER_CHECKER(high_preciecion_normed_output)
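
// Usage sketch for the member checkers above (illustrative; `Params` is a hypothetical epilogue parameter struct):
//
//   template <typename Params>
//   __device__ void applyEpilogue(Params const& params)
//   {
//       if constexpr (HAS_MEMBER(Params, bias))
//       {
//           // params.bias is only referenced when the member actually exists
//       }
//   }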

} // namespace common

TRTLLM_NAMESPACE_END

/*
 * Macros compliant with TensorRT coding conventions
 */
#define TLLM_CUDA_CHECK(stat) \
    do \
    { \
        tensorrt_llm::common::check((stat), #stat, __FILE__, __LINE__); \
    } while (0)

// We use a singleton memory pool, and the order of destructors depends on the compiler implementation. We found that
// cudaFree/cudaFreeHost can be called after CUDA runtime destruction on Windows, which raises a
// cudaErrorCudartUnloading error. It is safe to ignore this error because the CUDA runtime has already exited, so
// memory leaks are no longer a concern at that point.
#define TLLM_CUDA_CHECK_FREE_RESOURCE(stat) \
    do \
    { \
        tensorrt_llm::common::checkEx((stat), {cudaSuccess, cudaErrorCudartUnloading}, #stat, __FILE__, __LINE__); \
    } while (0)
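
// Usage sketch for the macros above (illustrative; `mBuffer`, `mBytes`, and `mStream` are hypothetical members):
//
//   TLLM_CUDA_CHECK(cudaMemsetAsync(mBuffer, 0, mBytes, mStream));
//   // In destructors that may run during process teardown, tolerate cudaErrorCudartUnloading:
//   TLLM_CUDA_CHECK_FREE_RESOURCE(cudaFree(mBuffer));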