/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <stdexcept>

#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11050)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif

#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/kernels/samplingTopKKernels.h"

using namespace tensorrt_llm::common;

namespace tensorrt_llm
{
namespace kernels
{

__global__ void curandInitialize(curandState_t* state, const int size, const unsigned long long randomSeed)
{
    if (threadIdx.x + blockIdx.x * blockDim.x < size)
    {
        curand_init(randomSeed, 0, 0, &state[blockIdx.x * blockDim.x + threadIdx.x]);
    }
}

void invokeCurandInitialize(
    curandState_t* state, const size_t batchSize, const unsigned long long randomSeed, cudaStream_t stream)
{
    dim3 block(256);
    dim3 grid((int) (ceil(batchSize * 1.0 / 256)));
    curandInitialize<<<grid, block, 0, stream>>>(state, batchSize, randomSeed);
}

__global__ void curandBatchInitialize(curandState_t* states, const int size, const unsigned long long* randomSeeds)
{
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < size)
    {
        curand_init(randomSeeds[idx], 0, 0, &states[idx]);
    }
}

void invokeCurandBatchInitialize(
    curandState_t* states, const size_t batchSize, const unsigned long long* randomSeeds, cudaStream_t stream)
{
    dim3 block(256);
    dim3 grid((int) (ceil(batchSize * 1.0 / 256)));
    curandBatchInitialize<<<grid, block, 0, stream>>>(states, batchSize, randomSeeds);
}

template <typename T>
__global__ void addBiasEndMask(
    T* logits, const T* bias, const int* endIds, const bool* finished, const int vocabSize, const int vocabSizePadded)
{
    int bid = blockIdx.x;
    bool finish = finished != nullptr ? finished[bid] : false;
    int offset = bid * vocabSizePadded;

    const bool IS_FP16 = std::is_same<T, half>::value;
    const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
    for (int tid = threadIdx.x; tid < vocabSizePadded; tid += blockDim.x)
    {
        if (tid >= vocabSize)
        {
            // Padded positions beyond the real vocabulary can never be sampled.
            logits[offset + tid] = -MAX_T_VAL;
        }
        else if (finish)
        {
            // Finished sequences may only produce their end token.
            logits[offset + tid] = (tid == endIds[bid]) ? MAX_T_VAL : -MAX_T_VAL;
        }
        else
        {
            if (bias != nullptr)
            {
                logits[offset + tid] += bias[tid];
            }
        }
    }
}

template <typename T>
void invokeAddBiasEndMask(T* logits, const T* bias, const int* endIds, const bool* finished, const int batchSize,
    const int vocabSize, const int vocabSizePadded, cudaStream_t stream)
{
    dim3 grid(batchSize);
    // vocabSizePadded is usually very large (e.g., 30000, 7000, ...), so cap the block size at 1024.
    dim3 block(min(vocabSizePadded, 1024));
    addBiasEndMask<T><<<grid, block, 0, stream>>>(logits, bias, endIds, finished, vocabSize, vocabSizePadded);
}

template void invokeAddBiasEndMask(float* logits, const float* bias, const int* endIds, const bool* finished,
    const int batchSize, const int vocabSize, const int vocabSizePadded, cudaStream_t stream);
template void invokeAddBiasEndMask(half* logits, const half* bias, const int* endIds, const bool* finished,
    const int batchSize, const int vocabSize, const int vocabSizePadded, cudaStream_t stream);

// Stage 1: BLOCKS_PER_BEAM_ blocks per batch entry each scan a slice of the vocabulary and write their local
// top-k candidate ids/values into the temporary buffers.
template <typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_>
__global__ void topKStage1(const T* __restrict logProbs, T* tmpLogProbs, int* topKTmpIdBuf, T* topKTmpValBuf,
    const bool* finished, const int maxTopK, const int* topKs, const int vocabSize, const int* endIds,
    const bool* skipDecode)
{
    typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE_> BlockReduce;
    __shared__ typename BlockReduce::TempStorage tempStorage;

    const int tid = threadIdx.x;
    const int bid = blockIdx.x;

    const int batchId = bid / BLOCKS_PER_BEAM_; // row id for logProbs
    if (skipDecode != nullptr && skipDecode[batchId])
    {
        return;
    }
    const int blockLane = bid % BLOCKS_PER_BEAM_;                // block id for a beam
    const int k = (topKs != nullptr) ? topKs[batchId] : maxTopK; // batchId = batch index
    const int tmpLogBufIndex = batchId * vocabSize;
    const int tmpTopKBufIndex = batchId * BLOCKS_PER_BEAM_ * maxTopK + blockLane * k;

    TopK_2<T> partial;
    const bool IS_FP16 = std::is_same<T, half>::value;
    const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;

    if (finished != nullptr && finished[batchId] == true)
    {
        if (tid < k)
        {
            const int index = tmpTopKBufIndex + tid;
            if (blockLane == 0 && tid == 0)
            {
                const int endId = endIds[batchId];
                topKTmpIdBuf[index] = tmpLogBufIndex + endId;
                topKTmpValBuf[index] = logProbs[tmpLogBufIndex + endId];
            }
            else
            {
                topKTmpIdBuf[index] = -1;
                topKTmpValBuf[index] = -MAX_T_VAL;
            }
        }
        return;
    }

    for (int elemId = tid + blockLane * BLOCK_SIZE_; elemId < vocabSize; elemId += BLOCK_SIZE_ * BLOCKS_PER_BEAM_)
    {
        int index = elemId + tmpLogBufIndex;
        tmpLogProbs[index] = logProbs[index];
    }

    for (int ite = 0; ite < k; ite++)
    {
        partial.init();
#pragma unroll
        for (int elemId = tid + blockLane * BLOCK_SIZE_; elemId < vocabSize; elemId += BLOCK_SIZE_ * BLOCKS_PER_BEAM_)
        {
            int index = elemId + tmpLogBufIndex;
            partial.insert(tmpLogProbs[index], index);
        }

        TopK_2<T> total = BlockReduce(tempStorage).Reduce(partial, reduce_topk_op_2<T>);

        if (tid == 0)
        {
            const int index = tmpTopKBufIndex + ite;
            topKTmpIdBuf[index] = total.p;
            topKTmpValBuf[index] = total.u;
            if (total.p >= 0)
            {
                // Exclude the element just selected from the next iteration.
                tmpLogProbs[total.p] = -MAX_T_VAL;
            }
        }
        __syncthreads();
    }
}

// Stage 2: one block per batch entry merges the BLOCKS_PER_BEAM_ candidate lists and samples one token from the
// merged top-k set (optionally further truncated by top-p).
template <typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_>
__global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTmpValBuf, int** ids,
    int* sequenceLengths, bool* finished, float* cumLogProbs, float* outputLogProbs, const int maxTopK,
    const int* topKs, const float topP, const float* topPs, curandState_t* curandstate, const int* endIds,
    const int vocabSize, const bool* skipDecode)
{
    const bool IS_FP16 = std::is_same<T, half>::value;
    const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;

    const int tid = threadIdx.x;
    const int batchId = blockIdx.x;
    if (skipDecode != nullptr && skipDecode[batchId])
    {
        return;
    }

    const int k = (topKs != nullptr) ? topKs[batchId] : maxTopK;
    const float probThreshold = (topPs != nullptr) ? topPs[batchId] : topP;
    const int size = k * BLOCKS_PER_BEAM_;
    const int stride = maxTopK * BLOCKS_PER_BEAM_;

    typedef cub::BlockReduce<TopK_2<float>, BLOCK_SIZE_> BlockReduce;
    __shared__ typename BlockReduce::TempStorage tempStorage;
    extern __shared__ char array[];
    __shared__ float s_sum;
    T* s_val = topKTmpValBuf + batchId * stride;
    int* s_id = reinterpret_cast<int*>(array);
    if (tid == 0)
    {
        s_sum = 0.0f;
    }
    TopK_2<float> partial;

    if (finished != nullptr && finished[batchId] == true)
    {
        return;
    }

    float* s_val2 = reinterpret_cast<float*>(s_id + k);
    float maxLogit;
    for (int ite = 0; ite < k; ite++)
    {
        partial.init();
#pragma unroll
        for (int i = tid; i < size; i += BLOCK_SIZE_)
        {
            partial.insert((float) s_val[i], i);
        }

        TopK_2<float> total = BlockReduce(tempStorage).Reduce(partial, reduce_topk_op_2<float>);

        if (tid == 0)
        {
            if (ite == 0)
            {
                maxLogit = total.u;
            }
            s_id[ite] = total.p;
            s_val[total.p] = -MAX_T_VAL;

            // when cumLogProbs are computed, topKTmpValBuf (logits_buf_) are
            // already pre-processed by softmax_kernel
            if (cumLogProbs == nullptr && outputLogProbs == nullptr)
            {
                total.u = __expf(total.u - maxLogit);
            }
            s_val2[ite] = total.u;
            s_sum += total.u;
        }
        __syncthreads();
    }

    if (tid == 0)
    {
        float randNum = (float) curand_uniform(curandstate + blockIdx.x) * probThreshold * s_sum;
        for (int i = 0; i < k; i++)
        {
            float expLogit = s_val2[i];
            randNum = randNum - expLogit;
            if (randNum <= 0.0f || i == k - 1)
            {
                ids[batchId][sequenceLengths[batchId]] = topKTmpIdBuf[batchId * stride + s_id[i]] % vocabSize;
                if (cumLogProbs != nullptr || outputLogProbs != nullptr)
                {
                    float logProb = logf(expLogit);
                    if (cumLogProbs != nullptr)
                    {
                        cumLogProbs[batchId] += logProb;
                    }
                    if (outputLogProbs != nullptr)
                    {
                        // 'outputLogProbs' is the probability induced by the top-k
                        // sampling. We normalize the probability 'expLogit' of the
                        // selected token by the probability 's_sum' of the set of top-k
                        // tokens, meaning the logProb is the probability of the selected
                        // token, conditioned on the event that it is selected, i.e.,
                        // logProb = log P(i | i is in top-k) = log(expLogit / s_sum).
                        outputLogProbs[batchId] = logProb - logf(s_sum);
                    }
                }
                break;
            }
        }
        if (sequenceLengths != nullptr && finished != nullptr)
        {
            const int seqLen = sequenceLengths[batchId];
            if (ids[batchId][seqLen] == endIds[batchId])
            {
                finished[batchId] = true;
                // Do not increase the sequence length when EOS is generated. The sequence length should only
                // count tokens to be outputted.
            }
            else
            {
                finished[batchId] = false;
                sequenceLengths[batchId] += 1;
            }
        }
    }
}

// Dynamic shared memory for stage 2 holds K_MAX candidate ids (int) plus K_MAX candidate values (float).
#define CASE_K(K_MAX, BLOCK_SIZE_1_, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_)                                                 \
    topKStage1<T, BLOCK_SIZE_1_, BLOCKS_PER_BEAM_>                                                                    \
        <<<batchSize * BLOCKS_PER_BEAM_, BLOCK_SIZE_1_, 0, stream>>>(logProbs, tempLogProbs, topKTmpIdBuf,            \
            topKTmpValBuf, finished, maxTopK, topKs, vocabSize, endIds, skipDecode);                                  \
    topKStage2Sampling<T, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_>                                                            \
        <<<batchSize, BLOCK_SIZE_2_, K_MAX * sizeof(int) + K_MAX * sizeof(float), stream>>>(topKTmpIdBuf,             \
            topKTmpValBuf, ids, sequenceLengths, finished, cumLogProbs, outputLogProbs, maxTopK, topKs, topP, topPs,  \
            curandstate, endIds, vocabSize, skipDecode);                                                              \
    break;

template <typename T>
void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** ids,
    int* sequenceLengths, bool* finished, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate,
    const int maxTopK, const int* topKs, const float topP, const float* topPs, const int vocabSizePadded,
    const int* endIds, cudaStream_t stream, const int batchSize, const bool* skipDecode)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    // Do not allow ambiguous inputs: the scalar topP and the per-batch topPs must not both be set.
    assert(topP == 1.0f || topPs == nullptr);
    const int vocabSize = vocabSizePadded;
    const int maxBlockPerBeam = 8;
    int tempLogProbsBufSize = batchSize * vocabSize;               // type T
    int topKTmpIdsBufSize = batchSize * maxTopK * maxBlockPerBeam; // type int
    int topKTmpValBufSize = batchSize * maxTopK * maxBlockPerBeam; // type T

    // prevent memory misaligned address
    tempLogProbsBufSize = (int) (ceil(tempLogProbsBufSize / 4.)) * 4;
    topKTmpIdsBufSize = (int) (ceil(topKTmpIdsBufSize / 4.)) * 4;
    topKTmpValBufSize = (int) (ceil(topKTmpValBufSize / 4.)) * 4;

    if (workspace == nullptr)
    {
        // Size-query pass: report the required workspace size and return without launching kernels.
        workspaceSize
            = sizeof(T) * tempLogProbsBufSize + sizeof(int) * topKTmpIdsBufSize + sizeof(T) * topKTmpValBufSize;
        return;
    }

    T* tempLogProbs = (T*) workspace;
    int* topKTmpIdBuf = (int*) (tempLogProbs + tempLogProbsBufSize);
    T* topKTmpValBuf = (T*) (topKTmpIdBuf + topKTmpIdsBufSize);

    // Compute ceil(log2(maxTopK)) to pick the kernel configuration.
    int logMaxTopK(0);
    int recursor(maxTopK - 1);
    while (recursor >>= 1)
        ++logMaxTopK;
    switch (logMaxTopK)
    {
    case 0:
    case 1:
    case 2:
    case 3: // 0 < maxTopK <= 16
        CASE_K(16, 128, 128, 8);
    case 4: // 16 < maxTopK <= 32
        CASE_K(32, 256, 128, 8);
    case 5: // 32 < maxTopK <= 64
        CASE_K(64, 256, 256, 8);
    case 6:
    case 7:
    case 8:
    case 9: // 64 < maxTopK <= 1024
        CASE_K(1024, 256, 256, 8);
    default: throw std::domain_error(fmtstr("top-k kernel supports 1<=k<=1024 but got k=%d", maxTopK));
    }
}

#undef CASE_K

template void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const float* logProbs, int** ids,
    int* sequenceLengths, bool* finished_buf, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate,
    const int maxTopK, const int* topKs, const float topP, const float* topPs, const int vocabSizePadded,
    const int* endIds, cudaStream_t stream, const int batchSize, const bool* skipDecode);

template void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const half* logProbs, int** ids,
    int* sequenceLengths, bool* finished_buf, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate,
    const int maxTopK, const int* topKs, const float topP, const float* topPs, const int vocabSizePadded,
    const int* endIds, cudaStream_t stream, const int batchSize, const bool* skipDecode);

template <typename T>
void invokeTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** ids, int* sequenceLengths,
    bool* finished_buf, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, const int topK,
    const float topP, const int vocabSizePadded, const int* endIds, cudaStream_t stream, const int batchSize,
    const bool* skipDecode)
{
    invokeBatchTopKSampling(workspace, workspaceSize, logProbs, ids, sequenceLengths, finished_buf, cumLogProbs,
        outputLogProbs, curandstate, topK, nullptr, topP, nullptr, vocabSizePadded, endIds, stream, batchSize,
        skipDecode);
}

template void invokeTopKSampling(void* workspace, size_t& workspaceSize, const float* logProbs, int** ids,
    int* sequenceLengths, bool* finished_buf, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate,
    const int topK, const float topP, const int vocabSizePadded, const int* endIds, cudaStream_t stream,
    const int batchSize, const bool* skipDecode);

template void invokeTopKSampling(void* workspace, size_t& workspaceSize, const half* logProbs, int** ids,
    int* sequenceLengths, bool* finished_buf, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate,
    const int topK, const float topP, const int vocabSizePadded, const int* endIds, cudaStream_t stream,
    const int batchSize, const bool* skipDecode);

} // namespace kernels
} // namespace tensorrt_llm
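
// ---------------------------------------------------------------------------
// Usage sketch (illustration only, not part of the library). The entry points
// above follow a two-call convention: calling invokeTopKSampling() with
// workspace == nullptr only writes the required byte count into workspaceSize
// and returns, so callers query the size first, allocate, and then call again.
// The names dLogits, dOutputIdPtrs, dSeqLens, dFinished and dEndIds below are
// placeholder device buffers assumed for this example; batchSize, topK,
// vocabSizePadded and stream are assumed to be set up by the caller.
//
//   using namespace tensorrt_llm::kernels;
//
//   size_t workspaceSize = 0;
//   // Size-query pass: only batchSize, topK and vocabSizePadded are used to compute the size,
//   // so the data pointers may be nullptr. The explicit <float> is needed because T cannot be
//   // deduced from a nullptr logProbs argument.
//   invokeTopKSampling<float>(nullptr, workspaceSize, /*logProbs=*/nullptr, /*ids=*/nullptr,
//       /*sequenceLengths=*/nullptr, /*finished_buf=*/nullptr, /*cumLogProbs=*/nullptr,
//       /*outputLogProbs=*/nullptr, /*curandstate=*/nullptr, topK, /*topP=*/1.0f,
//       vocabSizePadded, /*endIds=*/nullptr, stream, batchSize, /*skipDecode=*/nullptr);
//
//   void* dWorkspace = nullptr;
//   cudaMalloc(&dWorkspace, workspaceSize);
//
//   // One curand state per batch entry, seeded once before decoding starts.
//   curandState_t* dCurandStates = nullptr;
//   cudaMalloc(&dCurandStates, sizeof(curandState_t) * batchSize);
//   invokeCurandInitialize(dCurandStates, batchSize, /*randomSeed=*/1234ULL, stream);
//
//   // Sampling pass: writes one token id per batch entry into ids[b][sequenceLengths[b]] and
//   // advances the sequence lengths / finished flags.
//   invokeTopKSampling<float>(dWorkspace, workspaceSize, dLogits, dOutputIdPtrs, dSeqLens,
//       dFinished, /*cumLogProbs=*/nullptr, /*outputLogProbs=*/nullptr, dCurandStates, topK,
//       /*topP=*/1.0f, vocabSizePadded, dEndIds, stream, batchSize, /*skipDecode=*/nullptr);
// ---------------------------------------------------------------------------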