diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index dd3f5423fd..24779ef6b0 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -212,6 +212,7 @@ endif()
 include_directories(
   SYSTEM
   ${CUDAToolkit_INCLUDE_DIRS}
+  ${CUDAToolkit_INCLUDE_DIRS}/cccl
   ${CUDNN_ROOT_DIR}/include
   $
   ${3RDPARTY_DIR}/cutlass/include
diff --git a/cpp/include/tensorrt_llm/deep_gemm/tma_utils.cuh b/cpp/include/tensorrt_llm/deep_gemm/tma_utils.cuh
index 411d744760..33ddfd31ec 100644
--- a/cpp/include/tensorrt_llm/deep_gemm/tma_utils.cuh
+++ b/cpp/include/tensorrt_llm/deep_gemm/tma_utils.cuh
@@ -95,7 +95,7 @@ constexpr CUtensorMapDataType get_CUtensorMapDataType()
     }
 }
 
-PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled()
+PFN_cuTensorMapEncodeTiled_v12000 get_cuTensorMapEncodeTiled()
 {
     // Get pointer to `cuTensorMapEncodeTiled`
     cudaDriverEntryPointQueryResult driver_status;
@@ -110,12 +110,12 @@ PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled()
     if (driver_status != cudaDriverEntryPointSuccess)
         throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
 
-    return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
+    return reinterpret_cast<PFN_cuTensorMapEncodeTiled_v12000>(cuTensorMapEncodeTiled_ptr);
 }
 
 template <typename T>
 CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2], uint64_t stride_in_bytes,
-    uint32_t smem_dim[2], CUtensorMapSwizzle swizzle_type, PFN_cuTensorMapEncodeTiled encode_func = nullptr)
+    uint32_t smem_dim[2], CUtensorMapSwizzle swizzle_type, PFN_cuTensorMapEncodeTiled_v12000 encode_func = nullptr)
 {
     CUtensorMap tensor_map{};
     constexpr uint32_t rank = 2;
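Review note: both TMA-utils headers make the same change, so a condensed sketch of the pattern may help. CUDA 12 declares driver entry points through explicitly versioned typedefs in `cudaTypedefs.h` (`PFN_cuTensorMapEncodeTiled_v12000` is the CUDA 12.0 ABI variant), and newer CCCL/CTK header combinations no longer guarantee the unversioned alias. The following is a minimal standalone sketch of the lookup, not the patch itself; error handling is simplified.

```cpp
// Minimal sketch: fetch the driver's cuTensorMapEncodeTiled through the
// CUDA-12 versioned typedef, which pins the function-pointer ABI.
#include <cuda_runtime.h>
#include <cudaTypedefs.h>
#include <stdexcept>

inline PFN_cuTensorMapEncodeTiled_v12000 getEncodeTiledFn()
{
    void* fn = nullptr;
    cudaDriverEntryPointQueryResult status;
    cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &fn, cudaEnableDefault, &status);
    if (status != cudaDriverEntryPointSuccess)
        throw std::runtime_error("cuTensorMapEncodeTiled not available");
    return reinterpret_cast<PFN_cuTensorMapEncodeTiled_v12000>(fn);
}
```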
diff --git a/cpp/tensorrt_llm/kernels/beamSearchKernels.cu b/cpp/tensorrt_llm/kernels/beamSearchKernels.cu
index 97c35478bc..ff5f5347b4 100644
--- a/cpp/tensorrt_llm/kernels/beamSearchKernels.cu
+++ b/cpp/tensorrt_llm/kernels/beamSearchKernels.cu
@@ -134,15 +134,14 @@ void invokeUpdateCacheIndirection(int* tgtCI, int const* srcCI, BeamHypotheses&
     sync_check_cuda_error(stream);
 }
 
-template <typename T>
-__global__ void addCumLogProbs(T* __restrict pStage1LogProbs, float const* __restrict cumLogProbs,
+__global__ void addCumLogProbs(float* __restrict pStage1LogProbs, float const* __restrict cumLogProbs,
     FinishedState const* finished, int const* endIds, float const* diversityRates,
     runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM)
 {
     int const bid = blockIdx.x; // Index of request in batch
     runtime::SizeType32 const slot = batchSlots[bid];
     float const diversityRate{diversityRates[slot]};
-    T* pLocalLogProbs = pStage1LogProbs + bid * nBMIn * nBMOut * 2;
+    float* pLocalLogProbs = pStage1LogProbs + bid * nBMIn * nBMOut * 2;
 
     for (int i = threadIdx.x; i < nBMIn * nBMOut * 2; i += blockDim.x)
     {
@@ -160,13 +159,30 @@ __global__ void addCumLogProbs(T* __restrict pStage1LogProbs, float const* __res
     return;
 }
 
-template __global__ void addCumLogProbs(float* __restrict pStage1LogProbs, float const* __restrict cumLogProbs,
+__global__ void addCumLogProbs(half* __restrict pStage1LogProbs, float const* __restrict cumLogProbs,
     FinishedState const* finished, int const* endIds, float const* diversityRates,
-    runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM);
+    runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM)
+{
+    int const bid = blockIdx.x; // Index of request in batch
+    runtime::SizeType32 const slot = batchSlots[bid];
+    float const diversityRate{diversityRates[slot]};
+    half* pLocalLogProbs = pStage1LogProbs + bid * nBMIn * nBMOut * 2;
 
-template __global__ void addCumLogProbs(half* __restrict pStage1LogProbs, float const* __restrict cumLogProbs,
-    FinishedState const* finished, int const* endIds, float const* diversityRates,
-    runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM);
+    for (int i = threadIdx.x; i < nBMIn * nBMOut * 2; i += blockDim.x)
+    {
+        int const iBMIn = i / (nBMOut * 2);
+        if (finished[slot * nBMIn + iBMIn].isFinished())
+        {
+            pLocalLogProbs[i] += (i == endIds[slot]) ? 1.0f : 0.0f;
+        }
+        else
+        {
+            // nBM is used in VBWS since `cumLogProbs` is initialized with kMaxBeamWidth earlier than BeamSearchLayer
+            pLocalLogProbs[i] += cumLogProbs[slot * nBM + iBMIn] + diversityRate * iBMIn;
+        }
+    }
+    return;
+}
 
 __global__ void gatherId(int const* __restrict pStage1Id, int* __restrict pStage2Id, size_t const nBS,
     size_t const nBMIn, size_t const nBMOut, size_t const nV)
diff --git a/cpp/tensorrt_llm/kernels/beamSearchKernels.h b/cpp/tensorrt_llm/kernels/beamSearchKernels.h
index 10a285af90..ebf41d7787 100644
--- a/cpp/tensorrt_llm/kernels/beamSearchKernels.h
+++ b/cpp/tensorrt_llm/kernels/beamSearchKernels.h
@@ -130,8 +130,11 @@ void invokeTopkBeamSearch(T const* logProbs, T const* bias, void* workspace, Bea
 void invokeUpdateCacheIndirection(int* tgtCI, int const* srcCI, BeamHypotheses& bh,
     runtime::SizeType32 const maxAttentionWindow, runtime::SizeType32 sinkTokenLength, cudaStream_t stream);
 
-template <typename T>
-__global__ void addCumLogProbs(T* __restrict pStage1Probs, float const* __restrict cumLogProbs,
+__global__ void addCumLogProbs(float* __restrict pStage1LogProbs, float const* __restrict cumLogProbs,
+    FinishedState const* finished, int const* endIds, float const* diversityRates,
+    runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM);
+
+__global__ void addCumLogProbs(half* __restrict pStage1LogProbs, float const* __restrict cumLogProbs,
     FinishedState const* finished, int const* endIds, float const* diversityRates,
     runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM);
 
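Review note: the `addCumLogProbs<T>` template and its two explicit instantiations are replaced by concrete `float`/`half` overloads, presumably so the `__global__` declarations in `beamSearchKernels.h` stand on their own under the updated CCCL toolchain. A minimal sketch of the overload pattern with a hypothetical kernel name (`scaleInPlace` is not from the patch):

```cpp
// Sketch: one concrete __global__ overload per element type instead of a
// template plus `template __global__ void ...;` instantiation lines.
#include <cuda_fp16.h>

// Header side: concrete declarations, no template machinery.
__global__ void scaleInPlace(float* data, float factor, int n);
__global__ void scaleInPlace(half* data, float factor, int n);

// .cu side: each overload is a complete definition.
__global__ void scaleInPlace(float* data, float factor, int n)
{
    int const i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        data[i] *= factor;
}

__global__ void scaleInPlace(half* data, float factor, int n)
{
    int const i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        data[i] = __float2half(__half2float(data[i]) * factor); // compute in fp32
}
```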
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_tma_utils.cuh b/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_tma_utils.cuh
index b105368af0..18911feb7c 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_tma_utils.cuh
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_tma_utils.cuh
@@ -84,7 +84,7 @@ inline CUtensorMapDataType get_CUtensorMapDataType()
     }
 }
 
-PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled()
+PFN_cuTensorMapEncodeTiled_v12000 get_cuTensorMapEncodeTiled()
 {
     // Get pointer to cuTensorMapEncodeTiled
     cudaDriverEntryPointQueryResult driver_status;
@@ -101,12 +101,12 @@ PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled()
         throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
     }
 
-    return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
+    return reinterpret_cast<PFN_cuTensorMapEncodeTiled_v12000>(cuTensorMapEncodeTiled_ptr);
 }
 
 template <typename data_type>
 CUtensorMap make_2d_tma_copy_desc(data_type* global_address, uint64_t gmem_dim[2], uint64_t stride_in_bytes,
-    uint32_t smem_dim[2], CUtensorMapSwizzle swizzle_type, PFN_cuTensorMapEncodeTiled encode_func = nullptr)
+    uint32_t smem_dim[2], CUtensorMapSwizzle swizzle_type, PFN_cuTensorMapEncodeTiled_v12000 encode_func = nullptr)
 {
     CUtensorMap tensor_map{};
     constexpr uint32_t rank = 2;
diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h
index 744029c177..8536b940a7 100644
--- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h
+++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h
@@ -2597,7 +2597,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
     __shared__ typename BlockReduce::TempStorage temp_storage;
     // Obtain a segment of consecutive items that are blocked across threads (final_max from above)
     // Compute the block-wide max for thread0
-    final_max = BlockReduce(temp_storage).Reduce(thread_partial_max, cub::Max(), gridDim.z);
+    final_max = BlockReduce(temp_storage).Reduce(thread_partial_max, cuda::maximum(), gridDim.z);
 
     __shared__ float final_max_smem;
     if (tidx == 0)
diff --git a/cpp/tensorrt_llm/kernels/sageAttentionKernels.cu b/cpp/tensorrt_llm/kernels/sageAttentionKernels.cu
index 80a12b41ce..e45a7bb97f 100644
--- a/cpp/tensorrt_llm/kernels/sageAttentionKernels.cu
+++ b/cpp/tensorrt_llm/kernels/sageAttentionKernels.cu
@@ -250,7 +250,7 @@ __global__ void sage_quant_kernel(void const* q, void const* k, void const* v, i
 
     // Compute the block-wide max for thread0
     // cuda::maximum<>{}
-    float aggregate = BlockReduce(temp_storage).Reduce(local_amax, cub::Max{});
+    float aggregate = BlockReduce(temp_storage).Reduce(local_amax, cuda::maximum{});
 
     if (row_id == 0 && col_id == 0)
         s_block_amax = static_cast(aggregate);
@@ -429,7 +429,7 @@ __global__ void sage_quant_kernel(void const* q, void const* k, void const* v, i
 
     // Compute the block-wide max for thread0
     // cuda::maximum<>{}
-    float aggregate = BlockReduce(temp_storage).Reduce(local_amax, cub::Max{});
+    float aggregate = BlockReduce(temp_storage).Reduce(local_amax, cuda::maximum{});
 
     if (row_id == 0 && col_id == 0)
         s_block_amax = static_cast(aggregate);
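Review note: `cub::Max`/`cub::Min` were deprecated in CCCL 2.x and removed in CCCL 3.0; `cuda::maximum` from `<cuda/functional>` is the documented replacement and is a drop-in binary functor for `cub::BlockReduce::Reduce`, which is the swap made throughout this patch. A minimal self-contained sketch with an illustrative kernel (not from the patch):

```cpp
// Sketch: block-wide max reduction using cuda::maximum instead of cub::Max,
// assuming CCCL 2.8+ / CUDA 12.x.
#include <cub/block/block_reduce.cuh>
#include <cuda/functional>

template <int BLOCK_SIZE>
__global__ void blockMaxKernel(float const* in, float* out)
{
    using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
    __shared__ typename BlockReduce::TempStorage tempStorage;

    // cuda::maximum<>{} behaves like the removed cub::Max{}: a callable
    // binary functor returning the larger of its two operands.
    float blockMax = BlockReduce(tempStorage).Reduce(in[blockIdx.x * BLOCK_SIZE + threadIdx.x], cuda::maximum<>{});

    // Only thread 0 holds the valid block-wide aggregate.
    if (threadIdx.x == 0)
        out[blockIdx.x] = blockMax;
}
```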
diff --git a/cpp/tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.cu b/cpp/tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.cu
index b3a90bea5f..e963033855 100644
--- a/cpp/tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.cu
+++ b/cpp/tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.cu
@@ -504,7 +504,7 @@ __global__ void prepareGenEagleNetInputsKernel(SizeType32* nextSequenceLengths,
     BlockScan(tempStorage.scan).ExclusiveSum(numNextLogits, outputLastIndicesBase);
     // Sync because tempStorage is reused.
     __syncthreads();
 
-    auto const maxGenLength = BlockReduce(tempStorage.reduce).Reduce(nextDraftLen, cub::Max());
+    auto const maxGenLength = BlockReduce(tempStorage.reduce).Reduce(nextDraftLen, cuda::maximum());
     // Thread 0 has the result.
     if (bid == 0)
diff --git a/cpp/tensorrt_llm/kernels/topkLastDim.cu b/cpp/tensorrt_llm/kernels/topkLastDim.cu
index 3371ab4a0f..b13cd00b8f 100644
--- a/cpp/tensorrt_llm/kernels/topkLastDim.cu
+++ b/cpp/tensorrt_llm/kernels/topkLastDim.cu
@@ -25,6 +25,8 @@
 #include "topkLastDim.h"
 #include
 #include
+#include <thrust/iterator/counting_iterator.h>
+#include <thrust/iterator/transform_iterator.h>
 
 namespace tensorrt_llm
 {
@@ -1221,9 +1223,9 @@ void standalone_stable_radix_topk_(void* buf, size_t& buf_size, T const* in, Idx
     IdxT* sort_in_idx = nullptr;
 
     air_topk_stable::ComputeOffset<IdxT> computeoffset(k);
-    cub::CountingInputIterator<IdxT> counting_iter(0);
-    cub::TransformInputIterator<IdxT, air_topk_stable::ComputeOffset<IdxT>, cub::CountingInputIterator<IdxT>>
-        transform_iter(counting_iter, computeoffset);
+    thrust::counting_iterator<IdxT> counting_iter(0);
+    thrust::transform_iterator<air_topk_stable::ComputeOffset<IdxT>, thrust::counting_iterator<IdxT>> transform_iter(
+        counting_iter, computeoffset);
     cub::DeviceSegmentedSort::SortPairs(NULL, temp_storage_bytes, out_idx, out_idx, out, out, k * batch_size,
         batch_size, transform_iter, transform_iter + 1, stream);
     if (sorted)
@@ -1348,9 +1350,9 @@ void standalone_stable_radix_topk_one_block_(void* buf, size_t& buf_size, T cons
     const IdxT buf_len = air_topk_stable::calc_buf_len(len);
 
     air_topk_stable::ComputeOffset<IdxT> computeoffset(k);
-    cub::CountingInputIterator<IdxT> counting_iter(0);
-    cub::TransformInputIterator<IdxT, air_topk_stable::ComputeOffset<IdxT>, cub::CountingInputIterator<IdxT>>
-        transform_iter(counting_iter, computeoffset);
+    thrust::counting_iterator<IdxT> counting_iter(0);
+    thrust::transform_iterator<air_topk_stable::ComputeOffset<IdxT>, thrust::counting_iterator<IdxT>> transform_iter(
+        counting_iter, computeoffset);
     cub::DeviceSegmentedSort::SortPairs(NULL, temp_storage_bytes, out_idx, out_idx, out, out, k * batch_size,
         batch_size, transform_iter, transform_iter + 1, stream);
 
diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.cu
index ad5cd15fdd..ba850c45a2 100644
--- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.cu
+++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.cu
@@ -154,7 +154,7 @@ __global__ void activationDeepSeekKernel(KernelParams params)
     float constexpr E4m3MaxVal{448.f};
 
     // Compute the absolute max
-    float aMax = BlockReduce(temp_storage).Reduce(fabsf(out), cub::Max());
+    float aMax = BlockReduce(temp_storage).Reduce(fabsf(out), cuda::maximum());
     if (threadIdx.x == 0)
     {
         s_scaleOut = aMax / E4m3MaxVal;
@@ -657,7 +657,7 @@ __global__ void finalizeDeepSeekKernel(KernelParams params)
     float constexpr E4m3MaxVal{448.f};
 
     // Compute the absolute max
-    float aMax = BlockReduce(temp_storage).Reduce(fabsf(acc), cub::Max());
+    float aMax = BlockReduce(temp_storage).Reduce(fabsf(acc), cuda::maximum());
 
     if (threadIdx.x == 0)
     {
diff --git a/cpp/tensorrt_llm/runtime/utils/debugUtils.cu b/cpp/tensorrt_llm/runtime/utils/debugUtils.cu
index 7f1c8d8dfc..661dacd9a7 100644
--- a/cpp/tensorrt_llm/runtime/utils/debugUtils.cu
+++ b/cpp/tensorrt_llm/runtime/utils/debugUtils.cu
@@ -54,7 +54,7 @@ __global__ void checkTensorInvalidKernel(T const* data, std::size_t size, int* f
     __shared__ typename BlockReduceT::TempStorage tempStorage;
 
     // Compute block-wide maximum
-    int blockFound = BlockReduceT(tempStorage).Reduce(found, cub::Max());
+    int blockFound = BlockReduceT(tempStorage).Reduce(found, cuda::maximum());
 
     // Have thread 0 write out block's result
     if (threadIdx.x == 0)
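Review note: `cub::CountingInputIterator` and `cub::TransformInputIterator` were likewise removed in CCCL 3.0, with the Thrust fancy iterators as the documented replacements, which is the swap made in `topkLastDim.cu`. A small host-side sketch of the replacement; `ComputeOffset` below is a simplified stand-in for `air_topk_stable::ComputeOffset`, and `k = 8` is illustrative:

```cpp
// Sketch: thrust::counting_iterator + thrust::transform_iterator producing
// segment offsets on the fly, as used for cub::DeviceSegmentedSort.
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <cstdio>

struct ComputeOffset
{
    int k;
    __host__ __device__ int operator()(int i) const
    {
        return i * k; // segment i starts at offset i * k
    }
};

int main()
{
    thrust::counting_iterator<int> counting_iter(0);
    thrust::transform_iterator<ComputeOffset, thrust::counting_iterator<int>> transform_iter(
        counting_iter, ComputeOffset{8});

    // transform_iter[i] == i * 8, so transform_iter and transform_iter + 1
    // can serve as the begin/end segment-offset iterators that
    // cub::DeviceSegmentedSort::SortPairs expects.
    std::printf("%d %d\n", transform_iter[0], transform_iter[1]); // prints: 0 8
    return 0;
}
```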