/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include "tensorrt_llm/kernels/speculativeDecoding/kvCacheUpdateKernels.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"

#include <algorithm>
#include <cstdint>
#include <cub/cub.cuh>
#include <limits>

using namespace tensorrt_llm::runtime;

namespace tc = tensorrt_llm::common;

namespace tensorrt_llm::runtime::kernels
{

namespace
{

template <typename T>
__global__ void fill(T* data, std::size_t size, T const value)
{
    auto const tidx = (static_cast<std::size_t>(blockIdx.x) * blockDim.x) + threadIdx.x;
    auto const stride = static_cast<std::size_t>(blockDim.x) * gridDim.x;

    for (auto idx = tidx; idx < size; idx += stride)
    {
        data[idx] = value;
    }
}

//! @param data expected shape [indicesRange, size]
//! @param indices expected shape [gridDim.y]
//! @param size
//! @param values expected shape [gridDim.y]
template <typename T>
__global__ void fillBatch(T* data, std::int32_t const* indices, std::size_t size, T const* values)
{
    auto const batchIdx = indices[blockIdx.y];
    T const value = values[blockIdx.y];
    auto const tidx = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    auto const stride = static_cast<std::size_t>(blockDim.x) * gridDim.x;
    auto const startIdx = batchIdx * size;
    auto const endIdx = startIdx + size;

    for (auto idx = startIdx + tidx; idx < endIdx; idx += stride)
    {
        data[idx] = value;
    }
}

template <typename T>
void invokeFillBatch(IBuffer& buffer, IBuffer const& slotIndices, std::size_t slotStride, IBuffer const& values,
    CudaStream const& stream)
{
    auto data = bufferCast<T>(buffer);
    auto const* const indices = bufferCast<std::int32_t>(slotIndices);
    auto fillValues = bufferCast<T>(values);
    auto numSlots = slotIndices.getSize();
    auto const size = slotStride;
    dim3 const blockSize{256};
    std::size_t const gridx{tc::ceilDiv(size, blockSize.x)};
    std::size_t const gridMax{std::numeric_limits<std::uint32_t>::max()};
    dim3 const gridSize{static_cast<std::uint32_t>(std::min(gridx, gridMax)), static_cast<std::uint32_t>(numSlots)};
    fillBatch<<<gridSize, blockSize, 0, stream.get()>>>(data, indices, size, fillValues);
}
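// Usage sketch for invokeFillBatch above (hypothetical buffer names, for illustration only): given a [numSlots]
// int32 buffer of row indices and a [numSlots] buffer of per-slot fill values, each selected row of length
// slotStride in `buffer` is filled with its corresponding value, one gridDim.y slice per slot:
//   invokeFillBatch<float>(buffer, slotIndices, /*slotStride=*/rowSize, values, stream);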
//! @param data expected shape [gridDim.y, size]
//! @param indices expected shape [gridDim.y]
//! @param size
//! @param values expected shape [indicesRange, size]
template <typename T>
__global__ void gatherBatch(T* data, T const* values, std::int32_t const* indices, std::size_t size)
{
    auto const tidx = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    auto const stride = static_cast<std::size_t>(blockDim.x) * gridDim.x;

    for (auto idx = tidx; idx < size; idx += stride)
    {
        auto const batchIdx = blockIdx.y;
        auto const slotIdx = indices[blockIdx.y];
        // Row-major indexing: row batchIdx of data gathers row slotIdx of values.
        data[batchIdx * size + idx] = values[slotIdx * size + idx];
    }
}

template <typename T>
void invokeGatherBatch(IBuffer& buffer, IBuffer const& values, IBuffer const& slotIndices, std::size_t slotStride,
    CudaStream const& stream)
{
    auto data = bufferCast<T>(buffer);
    auto const* const indices = bufferCast<std::int32_t>(slotIndices);
    auto sparseValues = bufferCast<T>(values);
    auto numSlots = slotIndices.getSize();
    auto const size = slotStride;
    dim3 const blockSize{256};
    std::size_t const gridx{tc::ceilDiv(size, blockSize.x)};
    std::size_t const gridMax{std::numeric_limits<std::uint32_t>::max()};
    dim3 const gridSize{static_cast<std::uint32_t>(std::min(gridx, gridMax)), static_cast<std::uint32_t>(numSlots)};
    gatherBatch<<<gridSize, blockSize, 0, stream.get()>>>(data, sparseValues, indices, size);
}

template <typename VecT>
__global__ void copyBatch(uint8_t const* srcData, uint8_t* dstData, SizeType64 const* srcOffsets,
    SizeType64 const* dstOffsets, SizeType64 const* sizes, SizeType64 const dataTypeSize)
{
    constexpr auto VEC_ELTS = static_cast<SizeType64>(sizeof(VecT));
    SizeType64 const srcStartIdx = srcOffsets[blockIdx.y] * dataTypeSize;
    SizeType64 const dstStartIdx = dstOffsets[blockIdx.y] * dataTypeSize;
    SizeType64 const size = sizes[blockIdx.y] * dataTypeSize;
    SizeType64 const tidx = (static_cast<SizeType64>(blockIdx.x) * blockDim.x + threadIdx.x) * VEC_ELTS;
    SizeType64 const stride = static_cast<SizeType64>(blockDim.x) * gridDim.x * VEC_ELTS;
    SizeType64 const srcEndIdx = srcStartIdx + size;

    SizeType64 srcIdx = srcStartIdx + tidx;
    SizeType64 dstIdx = dstStartIdx + tidx;

    for (; srcIdx < srcEndIdx; srcIdx += stride, dstIdx += stride)
    {
        *reinterpret_cast<VecT*>(&dstData[dstIdx]) = *reinterpret_cast<VecT const*>(&srcData[srcIdx]);
    }
}

template <typename T>
__global__ void add(T* data, std::size_t size, T const value)
{
    auto const tidx = (static_cast<std::size_t>(blockIdx.x) * blockDim.x) + threadIdx.x;
    auto const stride = static_cast<std::size_t>(blockDim.x) * gridDim.x;

    for (auto idx = tidx; idx < size; idx += stride)
    {
        data[idx] += value;
    }
}

template <typename T>
__global__ void reduceSum(T* output, T const* input, std::size_t size)
{
    T threadSum = 0;
    for (auto index = threadIdx.x; index < size; index += blockDim.x)
    {
        threadSum += input[index];
    }

    T blockSum = 0;
    if (blockDim.x <= 32)
    {
        blockSum = tc::warpReduceSum(threadSum);
    }
    else
    {
        blockSum = tc::blockReduceSum(threadSum);
    }
    __syncthreads();

    if (threadIdx.x == 0)
    {
        *output = blockSum;
    }
}

template <typename T>
void invokeReduce(IBuffer& output, IBuffer const& input, CudaStream const& stream)
{
    TLLM_CHECK_WITH_INFO(input.getDataType() == output.getDataType(), "Input and output have different data types");
    TLLM_CHECK_WITH_INFO(output.getSize() == 1, common::fmtstr("Output size (%ld) has to be 1", output.getSize()));

    auto outputPtr = bufferCast<T>(output);
    auto inputPtr = bufferCast<T>(input);
    auto const size = input.getSize();

    dim3 const blockSize{std::min(512U, static_cast<std::uint32_t>(size))};
    dim3 const gridSize{1};

    reduceSum<<<gridSize, blockSize, 0, stream.get()>>>(outputPtr, inputPtr, size);
}
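// transposeWithOutputOffset transposes the [nbInputRows, inputRowSize] input into the output, writing
// input[batchIdx][tokenIdx] to output[tokenIdx][outputOffset + batchIdx], i.e. the input occupies a block of
// nbInputRows columns of the output starting at column outputOffset.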
__global__ void transposeWithOutputOffset(SizeType32* output, SizeType32 const* input, SizeType32 const nbInputRows,
    SizeType32 const inputRowSize, SizeType32 const outputRowSize, SizeType32 const outputOffset)
{
    SizeType32 const tidx = (blockIdx.x * blockDim.x) + threadIdx.x;
    SizeType32 const tidy = (blockIdx.y * blockDim.y) + threadIdx.y;

    for (SizeType32 batchIdx = tidy; batchIdx < nbInputRows; batchIdx += blockDim.y * gridDim.y)
    {
        for (SizeType32 tokenIdx = tidx; tokenIdx < inputRowSize; tokenIdx += blockDim.x * gridDim.x)
        {
            auto const inputIdx = (batchIdx * inputRowSize) + tokenIdx;
            auto const outputIdx = (tokenIdx * outputRowSize) + outputOffset + batchIdx;
            output[outputIdx] = input[inputIdx];
        }
    }
}

__global__ void buildAttentionMask(SizeType32* attentionMask, SizeType32 const size, SizeType32 const padId)
{
    SizeType32 const tid = (blockIdx.x * blockDim.x) + threadIdx.x;

    for (SizeType32 i = tid; i < size; i += blockDim.x * gridDim.x)
    {
        auto const x = attentionMask[i];
        attentionMask[i] = (x != padId);
    }
}

__global__ void extendAttentionMask(
    SizeType32* newMask, SizeType32 const* oldMask, SizeType32 const batchSize, SizeType32 const seqLength)
{
    SizeType32 const tidx = (blockIdx.x * blockDim.x) + threadIdx.x;
    SizeType32 const tidy = (blockIdx.y * blockDim.y) + threadIdx.y;

    for (SizeType32 batchIdx = tidy; batchIdx < batchSize; batchIdx += blockDim.y * gridDim.y)
    {
        for (SizeType32 tokenIdx = tidx; tokenIdx < seqLength + 1; tokenIdx += blockDim.x * gridDim.x)
        {
            SizeType32 const oldIndex = (batchIdx * seqLength) + tokenIdx;
            SizeType32 const newIndex = (batchIdx * (seqLength + 1)) + tokenIdx;
            newMask[newIndex] = (tokenIdx < seqLength) ? oldMask[oldIndex] : 1;
        }
    }
}

__global__ void copyInputToOutput(TokenIdType* outputIds, TokenIdType const* inputIds, SizeType32 const* inputLengths,
    TokenIdType const padId, SizeType32 const batchSize, SizeType32 const beamWidth, SizeType32 const maxInputLength,
    SizeType32 const maxSeqLength)
{
    SizeType32 const tidx = (blockIdx.x * blockDim.x) + threadIdx.x;
    SizeType32 const tidy = (blockIdx.y * blockDim.y) + threadIdx.y;

    for (SizeType32 batchIdx = tidy; batchIdx < batchSize; batchIdx += blockDim.y * gridDim.y)
    {
        auto const inputLength = inputLengths[batchIdx];
        for (SizeType32 tokenIdx = tidx; tokenIdx < maxInputLength; tokenIdx += blockDim.x * gridDim.x)
        {
            auto const value = (tokenIdx < inputLength) ? inputIds[(batchIdx * maxInputLength) + tokenIdx] : padId;
            for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
            {
                auto const outputIdx = tc::flat_index3(batchIdx, beamIdx, tokenIdx, beamWidth, maxSeqLength);
                outputIds[outputIdx] = value;
            }
        }
    }
}
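// For copyInputToOutput above, tc::flat_index3(batchIdx, beamIdx, tokenIdx, beamWidth, maxSeqLength) resolves to
// (batchIdx * beamWidth + beamIdx) * maxSeqLength + tokenIdx, so every beam of a sequence receives the same
// (possibly padded) copy of that sequence's input ids.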
// In the following kernel, we launch a grid with batchSize blocks of threads. Each thread block
// copies the logits from the "logits" tensor to the "lastTokenLogits" tensor for the last token
// of each sequence.
//
// TODO: Enable vector copies for higher BW utilization.
template <typename T>
__global__ void gatherLastTokenLogitsKernel(
    T* lastTokenLogits, T const* logits, int const* lastTokenIds, int beamWidth, int vocabSizePadded)
{
    // This sequence.
    int const seqIdx = blockIdx.x;

    // Find the index of the last token in that sequence.
    // Since lastTokenIds holds accumulated lengths rather than real token ids, subtract 1.
    // For lengths [11, 23], the last-token indices are 10 and 22.
    int const lastTokenIdx = lastTokenIds[seqIdx] - 1;

    // The output pointer.
    T* lastTokenLogitsPtr = &lastTokenLogits[seqIdx * beamWidth * vocabSizePadded];
    // The input pointer.
    T const* logitsPtr = &logits[lastTokenIdx * vocabSizePadded];

    // The threads in the block collaborate to copy the logits.
    for (int idx = threadIdx.x; idx < vocabSizePadded; idx += blockDim.x)
    {
        T value = logitsPtr[idx];
        for (int beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
        {
            lastTokenLogitsPtr[(beamIdx * vocabSizePadded) + idx] = value;
        }
    }
}

template <typename T>
void invokeGatherLastTokenLogits(
    ITensor& output, ITensor const& input, ITensor const& lastTokenIds, CudaStream const& stream)
{
    auto const& outputShape = output.getShape();
    auto const batchSize = static_cast<SizeType32>(outputShape.d[0]);
    auto const beamWidth = static_cast<SizeType32>(outputShape.d[1]);
    auto const vocabSizePadded = static_cast<SizeType32>(outputShape.d[2]);
    auto const& inputShape = input.getShape();
    TLLM_CHECK_WITH_INFO(inputShape.d[0] == batchSize, "Invalid input shape: dim[0]");
    TLLM_CHECK_WITH_INFO(inputShape.d[2] == vocabSizePadded, "Invalid input shape: dim[2]");

    dim3 const blockSize{256, 1};
    dim3 const gridSize{static_cast<std::uint32_t>(batchSize), 1};
    gatherLastTokenLogitsKernel<T><<<gridSize, blockSize, 0, stream.get()>>>(bufferCast<T>(output),
        bufferCast<T>(input), bufferCast<SizeType32>(lastTokenIds), static_cast<SizeType32>(beamWidth),
        vocabSizePadded);
}
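// copyPackedInputToOutput is the packed-input variant of copyInputToOutput: per-sequence token ranges come from
// the prefix sums in inputOffsets ([batchSize + 1] entries) rather than from per-sequence lengths, and positions
// beyond each sequence's length are filled with padId.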
__global__ void copyPackedInputToOutput(TokenIdType* outputIds, TokenIdType const* inputIds,
    SizeType32 const* inputOffsets, TokenIdType const padId, SizeType32 const batchSize, SizeType32 const beamWidth,
    SizeType32 const maxInputLength, SizeType32 const maxSeqLength)
{
    SizeType32 const tidx = (blockIdx.x * blockDim.x) + threadIdx.x;
    SizeType32 const tidy = (blockIdx.y * blockDim.y) + threadIdx.y;

    for (SizeType32 batchIdx = tidy; batchIdx < batchSize; batchIdx += blockDim.y * gridDim.y)
    {
        auto const tokenBegin = inputOffsets[batchIdx];
        auto const tokenEnd = inputOffsets[batchIdx + 1];
        auto const inputLength = tokenEnd - tokenBegin;

        for (SizeType32 tokenIdx = tidx; tokenIdx < maxInputLength; tokenIdx += blockDim.x * gridDim.x)
        {
            auto const value = (tokenIdx < inputLength) ? inputIds[tokenBegin + tokenIdx] : padId;
            for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
            {
                auto const outputIdx = tc::flat_index3(batchIdx, beamIdx, tokenIdx, beamWidth, maxSeqLength);
                outputIds[outputIdx] = value;
            }
        }
    }
}

template <typename T>
__global__ void scatterTensor(T* output, T const* input, std::uint32_t const batchSize,
    std::uint32_t const inputRowSize, std::size_t const outputRowSize, std::uint32_t const beamWidth)
{
    auto const tidx = (static_cast<std::size_t>(blockIdx.x) * blockDim.x) + threadIdx.x;
    auto const tidy = (static_cast<std::size_t>(blockIdx.y) * blockDim.y) + threadIdx.y;
    auto const stridex = static_cast<std::size_t>(blockDim.x) * gridDim.x;
    auto const stridey = static_cast<std::size_t>(blockDim.y) * gridDim.y;

    for (auto batchIdx = tidy; batchIdx < batchSize; batchIdx += stridey)
    {
        for (auto columnIdx = tidx; columnIdx < inputRowSize; columnIdx += stridex)
        {
            auto const inputIdx = (batchIdx * inputRowSize) + columnIdx;
            auto const value = input[inputIdx];
            std::size_t constexpr beamIdx{0};
            auto const outputIdx = ((batchIdx * beamWidth + beamIdx) * outputRowSize) + columnIdx;
            output[outputIdx] = value;
        }
    }
}

template <typename T>
__global__ void tileTensor(T* output, T const* input, std::uint32_t const batchSize, std::size_t const inputRowSize,
    std::size_t const outputRowSize, std::uint32_t const beamWidth)
{
    auto const tidx = (static_cast<std::size_t>(blockIdx.x) * blockDim.x) + threadIdx.x;
    auto const tidy = (static_cast<std::size_t>(blockIdx.y) * blockDim.y) + threadIdx.y;
    auto const stridex = static_cast<std::size_t>(blockDim.x) * gridDim.x;
    auto const stridey = static_cast<std::size_t>(blockDim.y) * gridDim.y;

    for (auto batchIdx = tidy; batchIdx < batchSize; batchIdx += stridey)
    {
        for (auto columnIdx = tidx; columnIdx < inputRowSize; columnIdx += stridex)
        {
            auto const inputIdx = (batchIdx * inputRowSize) + columnIdx;
            auto const value = input[inputIdx];
            for (std::size_t beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
            {
                auto const outputIdx = ((batchIdx * beamWidth + beamIdx) * outputRowSize) + columnIdx;
                output[outputIdx] = value;
            }
        }
    }
}

template <typename T>
void invokeScatterTensor(ITensor& output, ITensor const& input, SizeType32 beamWidth, CudaStream const& stream)
{
    auto const& inputShape = input.getShape();
    auto const nbInputRows = static_cast<std::uint32_t>(inputShape.d[0]);
    auto const inputRowSize = input.getSize() / static_cast<std::size_t>(nbInputRows);
    auto const& outputShape = output.getShape();
    auto const nbOutputRows = static_cast<std::uint32_t>(outputShape.d[0]);
    auto const outputRowSize = output.getSize() / static_cast<std::size_t>(nbOutputRows);

    TLLM_CHECK_WITH_INFO(nbOutputRows == beamWidth * nbInputRows,
        common::fmtstr(
            "nbOutputRows (%d) must be beamWidth (%d) times nbInputRows (%d)", nbOutputRows, beamWidth, nbInputRows));
    TLLM_CHECK_WITH_INFO(outputRowSize >= inputRowSize,
        common::fmtstr("output row size (%ld) must be at least input row size (%ld)", outputRowSize, inputRowSize));

    dim3 const blockSize{256, 1};
    std::size_t const gridx{tc::ceilDiv(inputRowSize, blockSize.x)};
    std::size_t const gridMax{std::numeric_limits<std::uint32_t>::max()};
    dim3 const gridSize{static_cast<std::uint32_t>(std::min(gridx, gridMax)), nbInputRows};
    scatterTensor<T><<<gridSize, blockSize, 0, stream.get()>>>(bufferCast<T>(output), bufferCast<T>(input),
        nbInputRows, inputRowSize, outputRowSize, static_cast<std::uint32_t>(beamWidth));
}
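// scatterTensor above writes each input row only to beam 0 of its output slot, whereas tileTensor (invoked below)
// replicates each input row across all beamWidth beams.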
template <typename T>
void invokeTileTensor(ITensor& output, ITensor const& input, SizeType32 const beamWidth, CudaStream const& stream)
{
    auto const& inputShape = input.getShape();
    auto const nbInputRows = static_cast<std::uint32_t>(inputShape.d[0]);
    auto const inputRowSize = input.getSize() / static_cast<std::size_t>(nbInputRows);
    auto const& outputShape = output.getShape();
    auto const nbOutputRows = static_cast<std::uint32_t>(outputShape.d[0]);
    auto const outputRowSize = output.getSize() / static_cast<std::size_t>(nbOutputRows);

    TLLM_CHECK_WITH_INFO(nbOutputRows == beamWidth * nbInputRows,
        common::fmtstr(
            "nbOutputRows (%d) must be beamWidth (%d) times nbInputRows (%d)", nbOutputRows, beamWidth, nbInputRows));
    TLLM_CHECK_WITH_INFO(outputRowSize >= inputRowSize,
        common::fmtstr("output row size (%ld) must be at least input row size (%ld)", outputRowSize, inputRowSize));

    dim3 const blockSize{256, 1};
    std::size_t const gridx{tc::ceilDiv(inputRowSize, blockSize.x)};
    std::size_t const gridMax{std::numeric_limits<std::uint32_t>::max()};
    dim3 const gridSize{static_cast<std::uint32_t>(std::min(gridx, gridMax)), nbInputRows};
    tileTensor<T><<<gridSize, blockSize, 0, stream.get()>>>(bufferCast<T>(output), bufferCast<T>(input), nbInputRows,
        inputRowSize, outputRowSize, static_cast<std::uint32_t>(beamWidth));
}
// In the following kernel, we launch a grid with (microBatchSize * beamWidth, outputLen) blocks of threads. Each
// thread block copies a `vocabSizePadded` length logits tensor from the "inputLogits (microBatchSize, beamWidth,
// vocabSizePadded)" to the "outputGenerationLogits (batchSize, beamWidth, outputLen, vocabSizePadded)".
template <typename T>
__global__ void mergeLogitsFragmentsKernel(T* output, T** fragmentsVector, int const outputLen, int firstBatchSlotIdx,
    int beamWidth, int vocabSizePadded, int stepOffset)
{
    // output: shape: [batchSize, beamWidth, outputLen, vocabSize]
    // inputVector.at(i): shape: [microBatchSize, beamWidth, vocabSize]

    // Current step.
    int const curStep = blockIdx.y;

    // The batch slot index of this thread block relative to the micro batch.
    int const relativeBatchSlotIdx = blockIdx.x / beamWidth;

    // The absolute batch slot index within batchSize.
    int const absoluteBatchSlotIdx = firstBatchSlotIdx + relativeBatchSlotIdx;

    // The beam index that this thread block processes.
    int const mbeamIdx = blockIdx.x % beamWidth;

    // The output pointer.
    unsigned int const outputOffset
        = (absoluteBatchSlotIdx * beamWidth * outputLen + mbeamIdx * outputLen + curStep + stepOffset)
        * vocabSizePadded;
    T* outputPtr = &output[outputOffset];

    unsigned int const inputOffset = (relativeBatchSlotIdx * beamWidth + mbeamIdx) * vocabSizePadded;
    // The input pointer.
    T const* inputPtr = &fragmentsVector[curStep][inputOffset];

    // The threads in the block collaborate to copy the logits.
    for (int idx = threadIdx.x; idx < vocabSizePadded; idx += blockDim.x)
    {
        outputPtr[idx] = inputPtr[idx];
    }
}

template <typename T>
void invokeMergeLogitsFragments(BufferManager const& bufferManager, ITensor& output,
    std::vector<ITensor::SharedPtr> const& fragmentsVector, ITensor& cachePointerDevice, ITensor& cachePointerHost,
    SizeType32 firstBatchSlotIdx, SizeType32 microBatchSize, SizeType32 beamWidth, CudaStream const& stream,
    int stepOffset)
{
    size_t const fragmentsVectorSize = fragmentsVector.size();

    auto cachePointerHostPtr = bufferCast<T*>(cachePointerHost);

    for (int i = 0; i < fragmentsVectorSize; i++)
    {
        cachePointerHostPtr[i] = bufferCast<T>(*fragmentsVector.at(i));
    }
    bufferManager.copy(cachePointerHost, cachePointerDevice);

    dim3 const blockSize(256);
    dim3 const gridSize{(unsigned int) (microBatchSize * beamWidth), (unsigned int) (fragmentsVectorSize)};

    auto const& outputShape = output.getShape();
    auto const vocabSizePadded = static_cast<SizeType32>(outputShape.d[outputShape.nbDims - 1]);
    auto const outputLen = static_cast<SizeType32>(outputShape.d[outputShape.nbDims - 2]);

    TLLM_CHECK_WITH_INFO(outputLen >= fragmentsVectorSize, "Fragments size does not match outputLen size");

    mergeLogitsFragmentsKernel<T><<<gridSize, blockSize, 0, stream.get()>>>(bufferCast<T>(output),
        bufferCast<T*>(cachePointerDevice), outputLen, firstBatchSlotIdx, beamWidth, vocabSizePadded, stepOffset);
}

void invokeCopyPackedInputToOutput(ITensor& outputIds, ITensor const& inputIds, ITensor const& inputOffsets,
    SizeType32 const maxInputLength, TokenIdType const padId, CudaStream const& stream)
{
    TLLM_CHECK_WITH_INFO(
        inputIds.getDataType() == outputIds.getDataType(), "Input and output have different data types");
    auto const& outputShape = outputIds.getShape();
    TLLM_CHECK_WITH_INFO(
        outputShape.nbDims == 3, common::fmtstr("Output shape must have 3 dimensions, but has %d", outputShape.nbDims));
    auto const batchSize = static_cast<SizeType32>(inputOffsets.getSize()) - 1;
    SizeType32 const beamWidth = outputShape.d[1];
    SizeType32 const maxSeqLength = outputShape.d[2];
    TLLM_CHECK_WITH_INFO(batchSize == outputShape.d[0],
        common::fmtstr("Output ids batch size (" FMT_DIM ") does not match inputOffsets batch size (%d)",
            outputShape.d[0], batchSize));
    TLLM_CHECK_WITH_INFO(maxInputLength < maxSeqLength,
        common::fmtstr(
            "Output sequence length (%d) has to be larger than max input length (%d)", maxSeqLength, maxInputLength));

    dim3 const blockSize(256, 1);
    dim3 const gridSize((maxInputLength + blockSize.x - 1) / blockSize.x, batchSize);
    copyPackedInputToOutput<<<gridSize, blockSize, 0, stream.get()>>>(bufferCast<TokenIdType>(outputIds),
        bufferCast<TokenIdType>(inputIds), bufferCast<SizeType32>(inputOffsets), padId, batchSize, beamWidth,
        maxInputLength, maxSeqLength);
}

} // namespace

template <typename T>
void invokeFill(IBuffer& buffer, T const value, CudaStream const& stream)
{
    auto data = bufferCast<T>(buffer);
    auto const size = buffer.getSize();
    dim3 const blockSize{256};
    std::size_t const gridx{tc::ceilDiv(size, blockSize.x)};
    std::size_t const gridMax{std::numeric_limits<std::uint32_t>::max()};
    dim3 const gridSize{static_cast<std::uint32_t>(std::min(gridx, gridMax))};

    fill<<<gridSize, blockSize, 0, stream.get()>>>(data, size, value);
}

// template instantiation
template void invokeFill(IBuffer&, SizeType64, CudaStream const&);
template void invokeFill(IBuffer&, std::int32_t, CudaStream const&);
template void invokeFill(IBuffer&, std::int8_t, CudaStream const&);
template void invokeFill(IBuffer&, std::uint8_t, CudaStream const&);
template void invokeFill(IBuffer&, bool, CudaStream const&);
template void invokeFill(IBuffer&, half, CudaStream const&);
template void invokeFill(IBuffer&, float, CudaStream const&);
#ifdef ENABLE_BF16
template void invokeFill(IBuffer&, __nv_bfloat16, CudaStream const&);
#endif // ENABLE_BF16
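// Usage sketch for invokeFill (hypothetical manager and shape, for illustration only): zero-initialize an int32
// sequence-length buffer on the stream owned by a BufferManager:
//   auto lengths = manager.gpu(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32);
//   kernels::invokeFill(*lengths, 0, manager.getStream());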
void invokeFillBatch(IBuffer& buffer, IBuffer const& slotIndices, std::size_t slotStride, IBuffer const& values,
    CudaStream const& stream)
{
    switch (buffer.getDataType())
    {
    case nvinfer1::DataType::kINT32:
        invokeFillBatch<std::int32_t>(buffer, slotIndices, slotStride, values, stream);
        break;
    case nvinfer1::DataType::kINT8:
        invokeFillBatch<std::int8_t>(buffer, slotIndices, slotStride, values, stream);
        break;
    case nvinfer1::DataType::kFLOAT:
        invokeFillBatch<float>(buffer, slotIndices, slotStride, values, stream);
        break;
    default: TLLM_THROW("data type not supported");
    }
}

void invokeGatherBatch(IBuffer& buffer, IBuffer const& values, IBuffer const& slotIndices, std::size_t slotStride,
    CudaStream const& stream)
{
    switch (buffer.getDataType())
    {
    case nvinfer1::DataType::kINT32:
        invokeGatherBatch<std::int32_t>(buffer, values, slotIndices, slotStride, stream);
        break;
    case nvinfer1::DataType::kINT8:
        invokeGatherBatch<std::int8_t>(buffer, values, slotIndices, slotStride, stream);
        break;
    case nvinfer1::DataType::kFLOAT:
        invokeGatherBatch<float>(buffer, values, slotIndices, slotStride, stream);
        break;
    default: TLLM_THROW("data type not supported");
    }
}

void invokeCopyBatch(IBuffer const& srcBuffer, IBuffer& dstBuffer, IBuffer const& srcOffsets, IBuffer const& dstOffsets,
    IBuffer const& sizes, std::size_t maxStride, CudaStream const& stream)
{
    auto const* srcDataPtr = reinterpret_cast<uint8_t const*>(srcBuffer.data());
    auto* dstDataPtr = reinterpret_cast<uint8_t*>(dstBuffer.data());
    auto const* srcOffsetsPtr = bufferCast<SizeType64>(srcOffsets);
    auto const* dstOffsetsPtr = bufferCast<SizeType64>(dstOffsets);
    auto const* sizesPtr = bufferCast<SizeType64>(sizes);
    auto numSlots = srcOffsets.getSize();
    auto const size = maxStride;
    auto const dataTypeSize = BufferDataType(srcBuffer.getDataType()).getSize();
    auto const copyRowSizeInBytes = size * dataTypeSize;

    // Pick the widest vector type whose size divides the element size, so each thread moves more bytes per access.
    auto copyBatchInvocation = copyBatch<uint8_t>;
    auto vectorSize = 1;
    if (dataTypeSize % 16 == 0)
    {
        vectorSize = 16;
        copyBatchInvocation = copyBatch<uint4>;
    }
    else if (dataTypeSize % 8 == 0)
    {
        vectorSize = 8;
        copyBatchInvocation = copyBatch<uint2>;
    }
    else if (dataTypeSize % 4 == 0)
    {
        vectorSize = 4;
        copyBatchInvocation = copyBatch<uint32_t>;
    }
    else if (dataTypeSize % 2 == 0)
    {
        vectorSize = 2;
        copyBatchInvocation = copyBatch<uint16_t>;
    }

    dim3 const blockSize{256};
    std::size_t const gridx{tc::ceilDiv(copyRowSizeInBytes / vectorSize, blockSize.x)};
    std::size_t const gridMax{std::numeric_limits<std::uint32_t>::max()};
    dim3 const gridSize{static_cast<std::uint32_t>(std::min(gridx, gridMax)), static_cast<std::uint32_t>(numSlots)};
    copyBatchInvocation<<<gridSize, blockSize, 0, stream.get()>>>(
        srcDataPtr, dstDataPtr, srcOffsetsPtr, dstOffsetsPtr, sizesPtr, static_cast<SizeType64>(dataTypeSize));
}

template <typename T>
void invokeAdd(IBuffer& buffer, T const value, CudaStream const& stream)
{
    auto data = bufferCast<T>(buffer);
    auto const size = buffer.getSize();
    dim3 const blockSize{256};
    std::size_t const gridx{tc::ceilDiv(size, blockSize.x)};
    std::size_t const gridMax{std::numeric_limits<std::uint32_t>::max()};
    dim3 const gridSize{static_cast<std::uint32_t>(std::min(gridx, gridMax))};

    add<<<gridSize, blockSize, 0, stream.get()>>>(data, size, value);
}

template void invokeAdd(IBuffer&, std::int32_t, CudaStream const&);
template void invokeAdd(IBuffer&, std::int8_t, CudaStream const&);
template void invokeAdd(IBuffer&, float, CudaStream const&);
not supported"); } } void invokeTransposeWithOutputOffset( ITensor& output, ITensor const& input, SizeType32 const outputOffset, CudaStream const& stream) { TLLM_CHECK_WITH_INFO(input.getDataType() == output.getDataType(), "Input and output have different data types"); auto const& inputShape = input.getShape(); TLLM_CHECK_WITH_INFO( inputShape.nbDims == 2, common::fmtstr("Input shape must have 2 dimensions, but has %d", inputShape.nbDims)); SizeType32 const nbInputRows = inputShape.d[0]; SizeType32 const inputRowSize = inputShape.d[1]; auto const& outputShape = output.getShape(); TLLM_CHECK_WITH_INFO( outputShape.nbDims == 2, common::fmtstr("Output shape must have 2 dimensions, but has %d", outputShape.nbDims)); SizeType32 const nbOutputRows = outputShape.d[0]; SizeType32 const outputRowSize = outputShape.d[1]; TLLM_CHECK_WITH_INFO(inputRowSize == nbOutputRows, common::fmtstr("Input dim 1 (%d) and output dim 0 (%d) differ", inputRowSize, nbOutputRows)); TLLM_CHECK_WITH_INFO(outputOffset + nbInputRows <= outputRowSize, common::fmtstr("Input (%d rows) does not fit into output (%d columns, offset %d)", nbInputRows, inputRowSize, outputOffset)); dim3 const blockSize(256, 1); dim3 const gridSize((inputRowSize + blockSize.x - 1) / blockSize.x, nbInputRows); transposeWithOutputOffset<<>>(bufferCast(output), bufferCast(input), nbInputRows, inputRowSize, outputRowSize, outputOffset); } void invokeInclusiveSum(IBuffer& output, IBuffer const& input, BufferManager const& manager, CudaStream const& stream) { auto const size = input.getSize(); auto const* inputData = bufferCast(input); auto* outputData = bufferCast(output); std::size_t tempStorageBytes{0}; cub::DeviceScan::InclusiveSum(nullptr, tempStorageBytes, inputData, outputData, size, stream.get()); auto tempStorage = manager.gpu(tempStorageBytes, nvinfer1::DataType::kUINT8); auto* tempStorageData = bufferCast(*tempStorage); cub::DeviceScan::InclusiveSum(tempStorageData, tempStorageBytes, inputData, outputData, size, stream.get()); } void invokeBuildAttentionMask(ITensor& attentionMask, SizeType32 const padId, CudaStream const& stream) { TLLM_CHECK_WITH_INFO( TRTDataType::value == attentionMask.getDataType(), "attentionMask has wrong data type"); auto const size = attentionMask.getSize(); dim3 const blockSize(256); dim3 const gridSize((size + blockSize.x - 1) / blockSize.x); buildAttentionMask<<>>(bufferCast(attentionMask), size, padId); } void invokeExtendAttentionMask(ITensor& newMask, ITensor const& oldMask, CudaStream const& stream) { TLLM_CHECK_WITH_INFO(TRTDataType::value == newMask.getDataType(), "attentionMask has wrong data type"); TLLM_CHECK_WITH_INFO(TRTDataType::value == oldMask.getDataType(), "attentionMask has wrong data type"); auto const& shape = oldMask.getShape(); SizeType32 const batchSize = shape.d[0]; SizeType32 const seqLength = shape.d[1]; dim3 const blockSize(256, 1); dim3 const gridSize((seqLength + blockSize.x - 1) / blockSize.x, batchSize); extendAttentionMask<<>>( bufferCast(newMask), bufferCast(oldMask), batchSize, seqLength); } void invokeCopyInputToOutput(ITensor& outputIds, ITensor const& inputIds, ITensor const& inputLengths, TokenIdType const padId, CudaStream const& stream) { TLLM_CHECK_WITH_INFO( inputIds.getDataType() == outputIds.getDataType(), "Input and output have different data types"); auto const& inputShape = inputIds.getShape(); auto const& outputShape = outputIds.getShape(); TLLM_CHECK_WITH_INFO( outputShape.nbDims == 3, common::fmtstr("Output shape must have 3 dimensions, but has %d", 
void invokeCopyInputToOutput(ITensor& outputIds, ITensor const& inputIds, ITensor const& inputLengths,
    TokenIdType const padId, CudaStream const& stream)
{
    TLLM_CHECK_WITH_INFO(
        inputIds.getDataType() == outputIds.getDataType(), "Input and output have different data types");
    auto const& inputShape = inputIds.getShape();
    auto const& outputShape = outputIds.getShape();
    TLLM_CHECK_WITH_INFO(
        outputShape.nbDims == 3, common::fmtstr("Output shape must have 3 dimensions, but has %d", outputShape.nbDims));

    auto const batchSize = static_cast<SizeType32>(inputLengths.getSize());
    SizeType32 const maxInputLength = inputShape.d[inputShape.nbDims - 1];
    SizeType32 const beamWidth = outputShape.d[1];
    SizeType32 const maxSeqLength = outputShape.d[2];

    auto const inputBatchSize = inputIds.getSize() / maxInputLength;
    TLLM_CHECK_WITH_INFO(std::size_t(batchSize) == inputBatchSize,
        common::fmtstr("Input ids batch size (%ld) does not match inputLengths size (%ld)", inputBatchSize,
            std::size_t(batchSize)));
    TLLM_CHECK_WITH_INFO(batchSize == outputShape.d[0],
        common::fmtstr(
            "Output ids batch size (" FMT_DIM ") does not match inputLengths size (%d)", outputShape.d[0], batchSize));
    TLLM_CHECK_WITH_INFO(maxInputLength < maxSeqLength,
        common::fmtstr(
            "Output sequence length (%d) has to be larger than max input length (%d)", maxSeqLength, maxInputLength));

    dim3 const blockSize(256, 1);
    dim3 const gridSize((maxInputLength + blockSize.x - 1) / blockSize.x, batchSize);
    copyInputToOutput<<<gridSize, blockSize, 0, stream.get()>>>(bufferCast<TokenIdType>(outputIds),
        bufferCast<TokenIdType>(inputIds), bufferCast<SizeType32>(inputLengths), padId, batchSize, beamWidth,
        maxInputLength, maxSeqLength);
}

void initOutputIds(ITensor& outputIds, ITensor const& inputIds, ITensor const& inputLengths,
    ITensor const& inputOffsets, TokenIdType const padId, TokenIdType const endId, SizeType32 const maxInputLength,
    bool const inputPacked, CudaStream const& stream)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    kernels::invokeFill(outputIds, endId, stream);

    if (inputPacked)
    {
        invokeCopyPackedInputToOutput(outputIds, inputIds, inputOffsets, maxInputLength, padId, stream);
    }
    else
    {
        kernels::invokeCopyInputToOutput(outputIds, inputIds, inputLengths, padId, stream);
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void scatterTensor(ITensor& output, ITensor const& input, SizeType32 beamWidth, CudaStream const& stream)
{
    switch (input.getDataType())
    {
    case nvinfer1::DataType::kINT32: invokeScatterTensor<std::int32_t>(output, input, beamWidth, stream); break;
    case nvinfer1::DataType::kFLOAT: invokeScatterTensor<float>(output, input, beamWidth, stream); break;
    case nvinfer1::DataType::kHALF: invokeScatterTensor<half>(output, input, beamWidth, stream); break;
    case nvinfer1::DataType::kINT8: invokeScatterTensor<std::int8_t>(output, input, beamWidth, stream); break;
#ifdef ENABLE_FP8
    case nvinfer1::DataType::kFP8: invokeScatterTensor<__nv_fp8_e4m3>(output, input, beamWidth, stream); break;
#endif // ENABLE_FP8
    default: TLLM_THROW("data type not supported");
    }
}

void tileTensor(ITensor& output, ITensor const& input, SizeType32 beamWidth, CudaStream const& stream)
{
    switch (input.getDataType())
    {
    case nvinfer1::DataType::kINT32: invokeTileTensor<std::int32_t>(output, input, beamWidth, stream); break;
    case nvinfer1::DataType::kFLOAT: invokeTileTensor<float>(output, input, beamWidth, stream); break;
    case nvinfer1::DataType::kHALF: invokeTileTensor<half>(output, input, beamWidth, stream); break;
#ifdef ENABLE_BF16
    case nvinfer1::DataType::kBF16: invokeTileTensor<__nv_bfloat16>(output, input, beamWidth, stream); break;
#endif // ENABLE_BF16
    case nvinfer1::DataType::kINT8: invokeTileTensor<std::int8_t>(output, input, beamWidth, stream); break;
#ifdef ENABLE_FP8
    case nvinfer1::DataType::kFP8: invokeTileTensor<__nv_fp8_e4m3>(output, input, beamWidth, stream); break;
#endif // ENABLE_FP8
    default: TLLM_THROW("data type not supported");
    }
}
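// gatherLastTokenLogits copies, for every sequence, the logits at the position given by lastTokenIds (accumulated
// sequence lengths, so the last-token index is the stored value minus 1) into all beams of the output tensor,
// dispatching on the logits data type.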
void gatherLastTokenLogits(ITensor& output, ITensor const& input, ITensor const& lastTokenIds, CudaStream const& stream)
{
    switch (input.getDataType())
    {
    case nvinfer1::DataType::kFLOAT: invokeGatherLastTokenLogits<float>(output, input, lastTokenIds, stream); break;
    case nvinfer1::DataType::kHALF: invokeGatherLastTokenLogits<half>(output, input, lastTokenIds, stream); break;
#ifdef ENABLE_BF16
    case nvinfer1::DataType::kBF16:
        invokeGatherLastTokenLogits<__nv_bfloat16>(output, input, lastTokenIds, stream);
        break;
#endif // ENABLE_BF16
#ifdef ENABLE_FP8
    case nvinfer1::DataType::kFP8:
        invokeGatherLastTokenLogits<__nv_fp8_e4m3>(output, input, lastTokenIds, stream);
        break;
#endif // ENABLE_FP8
    default: TLLM_THROW("data type not supported");
    }
}

void mergeLogitsFragments(BufferManager const& bufferManager, ITensor& output,
    std::vector<ITensor::SharedPtr> const& fragmentsVector, ITensor& cachePointerDevice, ITensor& cachePointerHost,
    SizeType32 firstBatchSlotIdx, SizeType32 const microBatchSize, SizeType32 const beamWidth,
    CudaStream const& stream, int stepOffset)
{
    switch (output.getDataType())
    {
    case nvinfer1::DataType::kFLOAT:
        invokeMergeLogitsFragments<float>(bufferManager, output, fragmentsVector, cachePointerDevice, cachePointerHost,
            firstBatchSlotIdx, microBatchSize, beamWidth, stream, stepOffset);
        break;
    case nvinfer1::DataType::kHALF:
        invokeMergeLogitsFragments<half>(bufferManager, output, fragmentsVector, cachePointerDevice, cachePointerHost,
            firstBatchSlotIdx, microBatchSize, beamWidth, stream, stepOffset);
        break;
#ifdef ENABLE_BF16
    case nvinfer1::DataType::kBF16:
        invokeMergeLogitsFragments<__nv_bfloat16>(bufferManager, output, fragmentsVector, cachePointerDevice,
            cachePointerHost, firstBatchSlotIdx, microBatchSize, beamWidth, stream, stepOffset);
        break;
#endif // ENABLE_BF16
#ifdef ENABLE_FP8
    case nvinfer1::DataType::kFP8:
        invokeMergeLogitsFragments<__nv_fp8_e4m3>(bufferManager, output, fragmentsVector, cachePointerDevice,
            cachePointerHost, firstBatchSlotIdx, microBatchSize, beamWidth, stream, stepOffset);
        break;
#endif // ENABLE_FP8
    default: TLLM_THROW("data type not supported");
    }
}

void invokeUpdateKVBlockArrayDraftTokenLocation(ITensor const& seqAcceptedDraftTokenOffsets,
    ITensor const& packedAcceptedDraftTokensIndices, ITensor const& pastKeyValueLengths, void* const* pointerArray,
    ::tensorrt_llm::kernels::KVCacheIndex const* offsetArray, SizeType32 layerCount, SizeType32 seqCount,
    SizeType32 numKVHeads, SizeType32 sizeInBytesPerKVHead, SizeType32 rewindDraftTokenCommonCount,
    SizeType32 const* rewindDraftTokenSeparateAdjustments, ITensor const& seqSlotRemapping, ITensor const& batchSlots,
    SizeType32 maxKVCacheLen, SizeType32 maxBlocksPerSeq, SizeType32 tokensPerBlock, bool canUseOneMoreBlock,
    cudaStream_t stream)
{
    tensorrt_llm::kernels::speculative_decoding::updateKVBlockArrayDraftTokenLocation(
        bufferCast<SizeType32>(seqAcceptedDraftTokenOffsets), bufferCast<SizeType32>(packedAcceptedDraftTokensIndices),
        bufferCast<SizeType32>(pastKeyValueLengths), pointerArray, offsetArray, layerCount, seqCount, numKVHeads,
        sizeInBytesPerKVHead, rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments,
        bufferCast<SizeType32>(seqSlotRemapping), bufferCast<SizeType32>(batchSlots), maxKVCacheLen, maxBlocksPerSeq,
        tokensPerBlock, canUseOneMoreBlock, stream);
}

} // namespace tensorrt_llm::runtime::kernels