/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Helper functions for Lamport-based synchronization

#ifndef TRTLLM_CUDA_LAMPORT_UTILS_CUH
#define TRTLLM_CUDA_LAMPORT_UTILS_CUH

#include <array>
#include <cooperative_groups.h>
#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <type_traits>

#include "tensorrt_llm/common/cudaTypeUtils.cuh"

namespace tensorrt_llm::common
{

constexpr uint16_t kNEGZERO_FP16 = 0x8000U;

template <typename T>
union Fp16BitCast
{
    T mFp;
    uint16_t mInt;

    constexpr Fp16BitCast()
        : mInt(0)
    {
    }

    constexpr Fp16BitCast(T val)
        : mFp(val)
    {
    }

    constexpr Fp16BitCast(uint16_t val)
        : mInt(val)
    {
    }
};

template <typename T>
static constexpr __device__ __host__ T negZero()
{
    if constexpr (std::is_same_v<T, float>)
    {
        return -0.0F;
    }
    else if constexpr (std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>)
    {
        return Fp16BitCast<T>(kNEGZERO_FP16).mFp;
    }
    else
    {
        static_assert(sizeof(T) == 0, "negZero not specialized for this type");
    }
    return T{}; // Never reached, but needed for compilation
}

template <typename T>
static inline __device__ bool isNegZero(T val)
{
    if constexpr (std::is_same_v<T, float>)
    {
        return val == 0.F && signbit(val);
    }
    else if constexpr (std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>)
    {
        return Fp16BitCast<T>(val).mInt == kNEGZERO_FP16;
    }
    else
    {
        static_assert(sizeof(T) == 0, "isNegZero not specialized for this type");
    }
    return false; // Never reached, but needed for compilation
}

template <typename PackedType, typename T>
constexpr __device__ __host__ PackedType getPackedLamportInit()
{
    static_assert(sizeof(PackedType) % sizeof(T) == 0, "PackedType size must be divisible by T size");
    constexpr int kNumElements = sizeof(PackedType) / sizeof(T);

    union PackedT
    {
        PackedType mPacked;
        std::array<T, kNumElements> mElements;

        constexpr PackedT()
            : mElements{}
        {
            for (int i = 0; i < kNumElements; i++)
            {
                mElements[i] = negZero<T>();
            }
        }
    };

    PackedT initValue{};
    return initValue.mPacked;
}
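
// Illustrative sketch (hypothetical helper, not part of this header): Lamport-style
// synchronization uses negative zero as the "not yet written" marker, so a consumer can
// spin on a slot until the producer overwrites it with real data. For example,
// getPackedLamportInit<float4, float>() returns a float4 whose four lanes all hold the
// -0.0F bit pattern (0x80000000).
//
// template <typename T>
// __device__ T lamportSpinRead(T const volatile* slot)
// {
//     T val = *slot;
//     while (isNegZero(val))
//     {
//         val = *slot; // Keep polling until a real (non -0.0) value lands
//     }
//     return val;
// }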

// A helper class to get the correct base pointer for a given layout
struct LamportBufferLayout
{
    uint32_t numStages = 1;
    uint32_t bytesPerBuffer = 0;
    static constexpr uint32_t sNumLamportBuffers = 3;

    // Implicitly inlined
    [[nodiscard]] __device__ __host__ size_t getTotalBytes() const
    {
        return numStages * static_cast<size_t>(bytesPerBuffer / numStages) * sNumLamportBuffers;
    }

    // Implicitly inlined
    [[nodiscard]] __device__ __host__ void* getStagePtr(
        void* bufferBasePtr, uint32_t lamportIndex, uint32_t stageIndex) const
    {
        // Typecast to avoid warnings
        return reinterpret_cast<void*>(reinterpret_cast<char*>(bufferBasePtr)
            + static_cast<size_t>(
                (lamportIndex * numStages + stageIndex) * static_cast<size_t>(bytesPerBuffer / numStages)));
    }
};

// Layout of the bufferFlags words backing LamportFlags:
//   [0]    Current index
//   [1]    Dirty index
//   [2]    bytes_per_buffer
//   [3]    Dirty num_stages
//   [4..7] Dirty bytes_to_clear = {stage0, stage1, stage2, stage3} (fixed to 4 stages)
//   [8]    offset_access_ptr
namespace cg = cooperative_groups;

// PackedType is the one used in the kernel for the Lamport buffer (LDG.128 or LDG.64);
// T is the element type stored in the buffer and determines the -0.0 init pattern.
template <typename PackedType, typename T>
struct __attribute__((aligned(32))) LamportFlags
{
public:
    __device__ explicit LamportFlags(uint32_t* bufferFlags, uint32_t numStages = 1)
        : mBufferFlagsPtr(bufferFlags)
        , mFlagAccessPtr(&bufferFlags[8])
    {
        mCurBufferLayout.numStages = numStages;
        uint4 flag = reinterpret_cast<uint4*>(bufferFlags)[0];
        mCurrentIndex = flag.x;
        mDirtyIndex = flag.y;
        // Buffer size is unchanged as the flag should be coupled to each buffer
        mCurBufferLayout.bytesPerBuffer = flag.z;
        mDirtyBufferLayout.bytesPerBuffer = flag.z;
        mDirtyBufferLayout.numStages = flag.w;
        *reinterpret_cast<uint4*>(&mBytesToClear) = reinterpret_cast<uint4*>(bufferFlags)[1];
    }

    // Return the base pointer of the Lamport buffer indexed by mCurrentIndex and the stageIdx
    [[nodiscard]] __device__ void* getCurLamportBuf(void* bufferBasePtr, int stageIdx = 0) const
    {
        return mCurBufferLayout.getStagePtr(bufferBasePtr, mCurrentIndex, stageIdx);
    }

    // Fill the dirty Lamport buffer with the init value; use stageIdx to select the stage to clear, -1 to clear all
    // FIXME: The current kernel may use fewer stages than the dirty numStages; how do we guarantee correctness?
    // CAUTION: This function requires all threads in the grid to participate and ASSUMES a 1D thread block layout!
    __device__ void clearDirtyLamportBuf(void* bufferBasePtr, int stageIdx = -1)
    {
        // Rasterize the threads to 1D for flexible clearing
        uint32_t globalCtaIdx = blockIdx.x * gridDim.y + blockIdx.y;
        uint32_t globalTid = globalCtaIdx * blockDim.x + threadIdx.x;
        uint32_t numThreads = gridDim.x * gridDim.y * blockDim.x;
        if (stageIdx == -1)
        {
            // Clear all stages
            for (uint32_t i = 0; i < mDirtyBufferLayout.numStages; i++)
            {
                clearPackedBuf(bufferBasePtr, globalTid, numThreads, mBytesToClear[i], mDirtyIndex, i);
            }
        }
        else if (static_cast<uint32_t>(stageIdx) < mDirtyBufferLayout.numStages)
        {
            clearPackedBuf(bufferBasePtr, globalTid, numThreads, mBytesToClear[stageIdx], mDirtyIndex, stageIdx);
        }
    }

    __device__ void ctaArrive()
    {
        int tid{0};
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        cg::cluster_group cluster = cg::this_cluster();
        // We update the atomic counter per cluster
        tid = cluster.thread_rank();
        cluster.sync();
#else
        tid = threadIdx.x;
        __syncthreads();
#endif
        if (tid == 0)
        {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
            asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(mFlagAccessPtr), "r"(1) : "memory");
#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700))
            asm volatile("red.release.global.gpu.add.u32 [%0], %1;" ::"l"(mFlagAccessPtr), "r"(1) : "memory");
#else
            atomicAdd(mFlagAccessPtr, 1);
#endif
        }
    }

    __device__ void waitAndUpdate(uint4 bytesToClearPerStage)
    {
        bool isLastCtaT0{false};
        int targetCount{0};
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        cg::grid_group grid = cg::this_grid();
        // Use the first thread instead of the last thread as the last thread may exit early
        isLastCtaT0 = grid.thread_rank() == 0;
        targetCount = grid.num_clusters();
#else
        isLastCtaT0 = threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0;
        targetCount = gridDim.x * gridDim.y;
#endif
        if (isLastCtaT0)
        {
            uint4* flagPtr = reinterpret_cast<uint4*>(mBufferFlagsPtr);
            while (*reinterpret_cast<uint32_t volatile*>(mFlagAccessPtr) < targetCount)
            {
            }
            // 'Current' becomes 'Dirty'
            flagPtr[0] = {(mCurrentIndex + 1) % 3, // Current index
                mCurrentIndex,                     // Dirty index
                mCurBufferLayout.bytesPerBuffer,   // Buffer size
                mCurBufferLayout.numStages};       // Dirty - Number of stages
            flagPtr[1] = bytesToClearPerStage;
            *mFlagAccessPtr = 0;
        }
    }

private:
    uint32_t* mBufferFlagsPtr;
    uint32_t* mFlagAccessPtr;
    uint32_t mCurrentIndex, mDirtyIndex;
    // So that we can access it with uint4
    alignas(16) std::array<uint32_t, 4> mBytesToClear;
    LamportBufferLayout mCurBufferLayout, mDirtyBufferLayout;

    inline __device__ void clearPackedBuf(void* bufferBasePtr, uint32_t globalTid, uint32_t numThreads,
        uint32_t bytesToClear, uint8_t dirtyIndex, uint8_t stageIdx)
    {
        // Round up to the PackedType boundary
        // divUp is shadowed in this context, so the round-up is written out inline here
        uint32_t clearBoundary = (bytesToClear + sizeof(PackedType) - 1) / sizeof(PackedType);
        for (uint32_t packedIdx = globalTid; packedIdx < clearBoundary; packedIdx += numThreads)
        {
            reinterpret_cast<PackedType*>(
                mDirtyBufferLayout.getStagePtr(bufferBasePtr, dirtyIndex, stageIdx))[packedIdx]
                = getPackedLamportInit<PackedType, T>();
        }
    }
};
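
// Illustrative usage sketch (hypothetical kernel, not part of this header): a producer
// kernel typically constructs LamportFlags once per launch, clears whatever the previous
// launch left dirty, writes into the current Lamport buffer, and finally rotates the
// buffers for the next launch. The kernel name, packed type, and byte counts below are
// assumptions for illustration only.
//
// __global__ void exampleProducerKernel(void* lamportBase, uint32_t* bufferFlags, uint32_t stage0Bytes)
// {
//     LamportFlags<float4, float> flags(bufferFlags, /*numStages=*/1);
//
//     // All threads in the grid cooperate to re-initialize the dirty buffer with -0.0 markers.
//     flags.clearDirtyLamportBuf(lamportBase);
//
//     // ... write results into flags.getCurLamportBuf(lamportBase, /*stageIdx=*/0) ...
//
//     // Every CTA (or cluster) signals arrival; thread 0 of the grid then waits for all
//     // arrivals and publishes the next current/dirty indices plus the bytes to clear.
//     flags.ctaArrive();
//     flags.waitAndUpdate({stage0Bytes, 0U, 0U, 0U});
// }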

} // namespace tensorrt_llm::common

#endif // TRTLLM_CUDA_LAMPORT_UTILS_CUH