/*
 * Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/logitsBitmask.h"

using namespace tensorrt_llm::common;
using namespace tensorrt_llm::runtime;

namespace tensorrt_llm
{
namespace kernels
{

namespace
{
int32_t constexpr kBitsPerMaskElement = 32;
int32_t constexpr kThreadsPerBlock = 256;

template <typename T>
__device__ T negativeInfinity()
{
    return -INFINITY;
}

template <>
__device__ half negativeInfinity<half>()
{
    return -CUDART_INF_FP16;
}

template <>
__device__ __nv_bfloat16 negativeInfinity<__nv_bfloat16>()
{
    return -CUDART_INF_BF16;
}

template <typename T, typename PackedT>
__device__ PackedT packedNegativeInfinity()
{
    int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
    T packed[kAlignment];
#pragma unroll
    for (int i = 0; i < kAlignment; i++)
    {
        packed[i] = negativeInfinity<T>();
    }
    return *reinterpret_cast<PackedT*>(packed);
}

// Set logits to -inf wherever the corresponding bitmask bit is 0 (a 1 bit keeps the logit).
// Each block covers kThreadsPerBlock * kBitsPerThread consecutive vocab entries of one batch
// row (blockIdx.y), reading and writing logits in PackedT-sized chunks.
template <typename T, typename PackedT, int32_t kBitsPerThread>
__global__ void __launch_bounds__(kThreadsPerBlock) logitsBitmaskKernel(
    T** __restrict__ logits, uint32_t const** __restrict__ bitmask, int32_t vocabSizePadded, int32_t bitmaskSize)
{
    int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
    uint32_t constexpr kPackedMask = (1 << kAlignment) - 1;

    int const batchIdx = blockIdx.y;

    int const blockOffset = blockIdx.x * kThreadsPerBlock * kBitsPerThread;
    T* logitsGmemPtr = logits[batchIdx] + blockOffset;

    uint32_t const* bitmaskGmemPtr = bitmask[batchIdx] + blockOffset / kBitsPerMaskElement;
    int const bitmaskInnerIdx = threadIdx.x % (kBitsPerMaskElement / kAlignment);
    T logitsReg[kAlignment];

#pragma unroll
    for (int offset = threadIdx.x * kAlignment; offset < kThreadsPerBlock * kBitsPerThread;
         offset += kThreadsPerBlock * kAlignment)
    {
        if (blockOffset + offset >= vocabSizePadded)
        {
            break;
        }

        uint32_t const bitmaskVal
            = (~bitmaskGmemPtr[offset / kBitsPerMaskElement] >> (bitmaskInnerIdx * kAlignment)) & kPackedMask;

        if (bitmaskVal == 0)
        {
            // All kAlignment tokens are allowed; leave the logits untouched.
            continue;
        }

        if (bitmaskVal == kPackedMask)
        {
            // All kAlignment tokens are masked; write -inf with a single packed store.
            *reinterpret_cast<PackedT*>(logitsGmemPtr + offset) = packedNegativeInfinity<T, PackedT>();
            continue;
        }

        *reinterpret_cast<PackedT*>(logitsReg) = *reinterpret_cast<PackedT*>(logitsGmemPtr + offset);
#pragma unroll
        for (int i = 0; i < kAlignment; i++)
        {
            if (((bitmaskVal >> i) & 1))
            {
                logitsReg[i] = negativeInfinity<T>();
            }
        }
        *reinterpret_cast<PackedT*>(logitsGmemPtr + offset) = *reinterpret_cast<PackedT*>(logitsReg);
    }
}

// Pick the compile-time kBitsPerThread (vocab entries handled per thread); it must be at least
// the packed load width kAlignment.
template <typename T, typename PackedT>
void logitsBitmaskDispatchToBitsPerThread(
    T** logits, uint32_t const** bitmask, int32_t batchSize, int32_t vocabSizePadded, cudaStream_t stream)
{
    int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
    int32_t const numBlocksPerRow = ceilDiv(2048 / kThreadsPerBlock * 128, batchSize);
    int32_t const numBitsPerThread = ceilDiv(vocabSizePadded, kThreadsPerBlock * numBlocksPerRow);
    int32_t bitmaskSize = ceilDiv(vocabSizePadded, kBitsPerMaskElement);

    dim3 const block(kThreadsPerBlock);

    if (numBitsPerThread <= 4 && kAlignment <= 4)
    {
        dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 4), batchSize);
        logitsBitmaskKernel<T, PackedT, 4><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
    }
    else if (numBitsPerThread <= 8 && kAlignment <= 8)
    {
        dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 8), batchSize);
        logitsBitmaskKernel<T, PackedT, 8><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
    }
    else if (numBitsPerThread <= 16 && kAlignment <= 16)
    {
        dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 16), batchSize);
        logitsBitmaskKernel<T, PackedT, 16><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
    }
    else
    {
        dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 32), batchSize);
        logitsBitmaskKernel<T, PackedT, 32><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
    }
}
} // namespace

template <typename T>
void invokeLogitsBitmask(
    T** logits, uint32_t const** bitmask, int32_t batchSize, int32_t vocabSizePadded, cudaStream_t stream)
{
    // Dispatch to the widest PackedT (float4/float2/float or scalar T) whose element count
    // evenly divides vocabSizePadded.
    if (vocabSizePadded % (sizeof(float4) / sizeof(T)) == 0)
    {
        logitsBitmaskDispatchToBitsPerThread<T, float4>(logits, bitmask, batchSize, vocabSizePadded, stream);
    }
    else if (vocabSizePadded % (sizeof(float2) / sizeof(T)) == 0)
    {
        logitsBitmaskDispatchToBitsPerThread<T, float2>(logits, bitmask, batchSize, vocabSizePadded, stream);
    }
    else if (vocabSizePadded % (sizeof(float) / sizeof(T)) == 0)
    {
        logitsBitmaskDispatchToBitsPerThread<T, float>(logits, bitmask, batchSize, vocabSizePadded, stream);
    }
    else
    {
        logitsBitmaskDispatchToBitsPerThread<T, T>(logits, bitmask, batchSize, vocabSizePadded, stream);
    }
}

template void invokeLogitsBitmask<float>(
    float** logits, uint32_t const** bitmask, int32_t batchSize, int32_t vocabSizePadded, cudaStream_t stream);
template void invokeLogitsBitmask<half>(
    half** logits, uint32_t const** bitmask, int32_t batchSize, int32_t vocabSizePadded, cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeLogitsBitmask<__nv_bfloat16>(
    __nv_bfloat16** logits, uint32_t const** bitmask, int32_t batchSize, int32_t vocabSizePadded, cudaStream_t stream);
#endif

} // namespace kernels
} // namespace tensorrt_llm
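
/*
 * Illustrative usage sketch (not part of this file): one way a host-side caller might drive
 * invokeLogitsBitmask for a half-precision batch. The wrapper name, buffer handling, and the
 * assumption that each logits/bitmask row already resides in device memory are assumptions made
 * for the example, not APIs defined here. Note that the kernel dereferences logits[batchIdx] and
 * bitmask[batchIdx] on the device, so the arrays of row pointers must themselves live in device
 * memory. A 1 bit in the mask keeps a token; a 0 bit forces its logit to -inf.
 *
 *   #include "tensorrt_llm/kernels/logitsBitmask.h"
 *   #include <cuda_fp16.h>
 *   #include <cuda_runtime.h>
 *   #include <vector>
 *
 *   void applyTokenBitmask(std::vector<half*> const& logitsRows,      // device pointers, one per request
 *       std::vector<uint32_t const*> const& bitmaskRows,              // device pointers, one per request
 *       int32_t vocabSizePadded, cudaStream_t stream)
 *   {
 *       int32_t const batchSize = static_cast<int32_t>(logitsRows.size());
 *
 *       // Stage the row-pointer arrays in device memory.
 *       half** logitsPtrsDev = nullptr;
 *       uint32_t const** bitmaskPtrsDev = nullptr;
 *       cudaMalloc(reinterpret_cast<void**>(&logitsPtrsDev), batchSize * sizeof(half*));
 *       cudaMalloc(reinterpret_cast<void**>(&bitmaskPtrsDev), batchSize * sizeof(uint32_t const*));
 *       cudaMemcpyAsync(logitsPtrsDev, logitsRows.data(), batchSize * sizeof(half*),
 *           cudaMemcpyHostToDevice, stream);
 *       cudaMemcpyAsync(bitmaskPtrsDev, bitmaskRows.data(), batchSize * sizeof(uint32_t const*),
 *           cudaMemcpyHostToDevice, stream);
 *
 *       tensorrt_llm::kernels::invokeLogitsBitmask<half>(
 *           logitsPtrsDev, bitmaskPtrsDev, batchSize, vocabSizePadded, stream);
 *
 *       cudaStreamSynchronize(stream);
 *       cudaFree(logitsPtrsDev);
 *       cudaFree(bitmaskPtrsDev);
 *   }
 */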