/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cub/cub.cuh>

#include "cumsumLastDim.h"

#include <cub/device/device_scan.cuh>

namespace tensorrt_llm
{
namespace kernels
{
///////////////

// Returns the number of bytes of device temporary storage that cub::DeviceScan::InclusiveSum needs to scan
// `inputLength` elements of type T. Passing a non-null workspace of (at least) this size to invokeCumsumLastDim
// selects its per-row cub::DeviceScan code path.
template <typename T>
size_t invokeComputeCumsumLastDimWorkspaceSize(SizeType32 inputLength)
{
    T* iodata = nullptr;
    size_t tempStorageBytes = 0;
    // With a null temp-storage pointer CUB does no scanning work and only writes the required size.
    cub::DeviceScan::InclusiveSum(nullptr, tempStorageBytes, iodata, iodata, inputLength);
    return tempStorageBytes;
}

#define INSTANTIATE_COMPUTE_CUMSUM_LastDim_WORKSPACE_SIZE_DATA_TYPE(T)                                                 \
    template size_t invokeComputeCumsumLastDimWorkspaceSize<T>(SizeType32 inputLength)

INSTANTIATE_COMPUTE_CUMSUM_LastDim_WORKSPACE_SIZE_DATA_TYPE(int);
INSTANTIATE_COMPUTE_CUMSUM_LastDim_WORKSPACE_SIZE_DATA_TYPE(float);
INSTANTIATE_COMPUTE_CUMSUM_LastDim_WORKSPACE_SIZE_DATA_TYPE(half);
#ifdef ENABLE_BF16
INSTANTIATE_COMPUTE_CUMSUM_LastDim_WORKSPACE_SIZE_DATA_TYPE(__nv_bfloat16);
#endif
#undef INSTANTIATE_COMPUTE_CUMSUM_LastDim_WORKSPACE_SIZE_DATA_TYPE

///////////////

// Inclusive prefix sum of one contiguous row of `length` elements per thread block.
//
// Launch layout: gridDim.x == number of rows (row r starts at d_in + r * length), blockDim.x == THREADS_PER_BLOCK.
// The row is processed in tiles of THREADS_PER_BLOCK * ITEMS_PER_THREAD elements; the running total of all
// previous tiles is carried in `aggregate` and folded into the first element of the next tile before scanning.
// Shared memory: only the static `temp_storage` union below; no dynamic shared memory is needed.
template <typename T, int THREADS_PER_BLOCK, int ITEMS_PER_THREAD, cub::BlockScanAlgorithm ALGORITHM>
__global__ void cumsum_last_dim(T const* d_in, T* d_out, int length)
{
    typedef cub::BlockLoad<T, THREADS_PER_BLOCK, ITEMS_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> BlockLoadT;
    typedef cub::BlockStore<T, THREADS_PER_BLOCK, ITEMS_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> BlockStoreT;
    typedef cub::BlockScan<T, THREADS_PER_BLOCK, ALGORITHM> BlockScanT;

    // Row offset in size_t: blockIdx.x * length can exceed INT_MAX for large batches (32-bit overflow is UB).
    size_t const rowOffset = static_cast<size_t>(blockIdx.x) * static_cast<size_t>(length);
    T const* local_d_in = d_in + rowOffset;
    T* local_d_out = d_out + rowOffset;

    // Shared memory, reused by load, scan and store (hence the barriers between phases).
    __shared__ union TempStorage
    {
        typename BlockLoadT::TempStorage load;
        typename BlockStoreT::TempStorage store;
        typename BlockScanT::TempStorage scan;
    } temp_storage;

    int const tile_size = THREADS_PER_BLOCK * ITEMS_PER_THREAD;
    T aggregate = static_cast<T>(0); // sum of all elements in the tiles processed so far
    T const* cur_d_in = local_d_in;
    T* cur_d_out = local_d_out;

    for (int tile_start = 0; tile_start < length;
         tile_start += tile_size, cur_d_in += tile_size, cur_d_out += tile_size)
    {
        // Written as `length - tile_start` so the comparison cannot overflow near INT_MAX.
        int const remaining = length - tile_start;
        int const cur_tile_size = remaining < tile_size ? remaining : tile_size;

        T data[ITEMS_PER_THREAD]; // Per-thread tile data

        // Load items into a blocked arrangement; out-of-range slots of a partial tile are filled with 0.
        BlockLoadT(temp_storage.load).Load(cur_d_in, data, cur_tile_size, static_cast<T>(0));

        // In a blocked arrangement thread 0's first item is the tile's first element: seed it with the carry.
        if (threadIdx.x == 0)
        {
            data[0] += aggregate;
        }

        __syncthreads(); // temp_storage.load -> temp_storage.scan reuse

        // `aggregate` receives the block-wide total, which already includes the previous carry.
        BlockScanT(temp_storage.scan).InclusiveSum(data, data, aggregate);

        __syncthreads(); // temp_storage.scan -> temp_storage.store reuse

        // Store items from a blocked arrangement (only the valid cur_tile_size elements are written).
        BlockStoreT(temp_storage.store).Store(cur_d_out, data, cur_tile_size);

        __syncthreads(); // temp_storage.store -> next iteration's temp_storage.load reuse
    }
}

///////////////

// Scans each of `batchSize` contiguous rows of `inputLength` elements with its own cub::DeviceScan::InclusiveSum
// call on `stream` (one device-wide scan per row). The single workspace `d_temp_storage` is reused for every row;
// the calls are ordered on the same stream. `tempStorageBytes` should be at least
// invokeComputeCumsumLastDimWorkspaceSize<T>(inputLength). Returned cudaError_t values are not checked here.
template <typename T>
void invokeDeviceScan(SizeType32 batchSize, SizeType32 inputLength, void const* __restrict__ input,
    void* __restrict__ output, void* d_temp_storage, size_t tempStorageBytes, cudaStream_t stream)
{
    T const* inputBase = reinterpret_cast<T const*>(input);
    T* outputBase = reinterpret_cast<T*>(output);
    for (SizeType32 i = 0; i < batchSize; i++)
    {
        // Row offset in size_t: i * inputLength can overflow 32-bit arithmetic for large tensors.
        size_t const rowOffset = static_cast<size_t>(i) * static_cast<size_t>(inputLength);
        T const* inputPtr = inputBase + rowOffset;
        T* outputPtr = outputBase + rowOffset;
        cub::DeviceScan::InclusiveSum(d_temp_storage, tempStorageBytes, inputPtr, outputPtr, inputLength, stream);
    }
}

///////////////

// Computes the inclusive prefix sum along the last dimension of `batchSize` contiguous rows of `inputLength`
// elements each, writing the result to `output` (input and output are declared __restrict__: they must not alias).
//
// - batchSize == 0: no work is issued (empty-tensor support).
// - deviceTempStorage != nullptr: per-row cub::DeviceScan path (see invokeDeviceScan); tempStorageBytes should be
//   at least invokeComputeCumsumLastDimWorkspaceSize<T>(inputLength).
// - otherwise: a single cumsum_last_dim launch with one thread block per row; block size / items-per-thread are
//   picked from inputLength.
// All work is asynchronous on `stream`. No CUDA error status is checked here.
template <typename T>
void invokeCumsumLastDim(SizeType32 batchSize, SizeType32 inputLength, void const* __restrict__ input,
    void* __restrict__ output, void* deviceTempStorage, size_t tempStorageBytes, cudaStream_t stream)
{
    // For empty tensor support
    if (batchSize == 0)
    {
        return;
    }

    if (deviceTempStorage != nullptr) // we need to use DeviceScan
    {
        invokeDeviceScan<T>(batchSize, inputLength, input, output, deviceTempStorage, tempStorageBytes, stream);
        return;
    }

    T const* inputPtr = reinterpret_cast<T const*>(input);
    T* outputPtr = reinterpret_cast<T*>(output);

    // The kernel only uses static shared memory, so no dynamic shared memory is requested (previously
    // sizeof(T) * TPB * ITP bytes were reserved per block and never used, which only lowered occupancy).
    constexpr size_t SHMEM = 0;
    constexpr cub::BlockScanAlgorithm ALG = cub::BLOCK_SCAN_WARP_SCANS;

    // Launch the kernel
    if (inputLength <= 64)
    {
        constexpr int ITP = 1;
        constexpr int TPB = 32;
        cumsum_last_dim<T, TPB, ITP, ALG><<<batchSize, TPB, SHMEM, stream>>>(inputPtr, outputPtr, inputLength);
    }
    else if (inputLength < 512)
    {
        constexpr int ITP = 2;
        constexpr int TPB = 64;
        cumsum_last_dim<T, TPB, ITP, ALG><<<batchSize, TPB, SHMEM, stream>>>(inputPtr, outputPtr, inputLength);
    }
    else
    {
        constexpr int ITP = 8;
        constexpr int TPB = 256;
        cumsum_last_dim<T, TPB, ITP, ALG><<<batchSize, TPB, SHMEM, stream>>>(inputPtr, outputPtr, inputLength);
    }
}

#define INSTANTIATE_CUMSUM_LastDim_DATA_TYPE(T)                                                                        \
    template void invokeCumsumLastDim<T>(SizeType32 batchSize, SizeType32 inputLength, void const* __restrict__ input, \
        void* __restrict__ output, void* workspace, size_t tempStorageBytes, cudaStream_t stream)

INSTANTIATE_CUMSUM_LastDim_DATA_TYPE(int);
INSTANTIATE_CUMSUM_LastDim_DATA_TYPE(float);
INSTANTIATE_CUMSUM_LastDim_DATA_TYPE(half);
#ifdef ENABLE_BF16
INSTANTIATE_CUMSUM_LastDim_DATA_TYPE(__nv_bfloat16);
#endif
#undef INSTANTIATE_CUMSUM_LastDim_DATA_TYPE

} // namespace kernels
} // namespace tensorrt_llm