/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/kernels/parallelDecoding/kvCacheUpdateKernels.h"
#include "tensorrt_llm/thop/thUtils.h"

#include <torch/extension.h>

namespace th = torch;

namespace torch_ext
{

void updateKVCacheDraftTokenLocation(torch::Tensor seqAcceptedDraftTokenOffsetsTensor,
    torch::Tensor packedAcceptedDraftTokensIndicesTensor, torch::Tensor pastKeyValueLengthsTensor,
    bool usePagedKVCache, int64_t layerCount, int64_t numKVHeads, int64_t headSizeInBytes,
    int64_t rewindDraftTokenCount, int64_t maxKVCacheLen, th::optional<torch::Tensor> rewindDraftTokenTensor,
    th::optional<std::vector<torch::Tensor>> pastKeyValueListOpt = th::nullopt,
    th::optional<torch::Tensor> pointerArrayOpt = th::nullopt,
    th::optional<torch::Tensor> offsetArrayOpt = th::nullopt,
    th::optional<int64_t> maxBlocksPerSeqOpt = th::nullopt,
    th::optional<int64_t> tokensPerBlockOpt = th::nullopt,
    th::optional<int64_t> stream_ptr = th::nullopt)
{
    TLLM_CHECK_WITH_INFO(
        at::cuda::is_available(), "update_kv_cache_draft_token_location should be called with cuda enabled.");

    // Run on the caller-provided stream if one was passed, otherwise on the
    // current CUDA stream.
    cudaStream_t stream;
    if (stream_ptr.has_value())
    {
        stream = reinterpret_cast<cudaStream_t>(stream_ptr.value());
    }
    else
    {
        stream = at::cuda::getCurrentCUDAStream();
    }

    TLLM_CHECK_WITH_INFO(seqAcceptedDraftTokenOffsetsTensor.dim() == 1
            && seqAcceptedDraftTokenOffsetsTensor.scalar_type() == torch::kInt,
        "accepted_draft_token_offsets tensor should be 1D int tensor.");
    // The offsets tensor holds seqCount + 1 boundaries, one pair per sequence.
    int seqCount = seqAcceptedDraftTokenOffsetsTensor.size(0) - 1;
    TLLM_CHECK_WITH_INFO(seqCount > 0, "seqCount should be larger than 0");
    TLLM_CHECK_WITH_INFO(packedAcceptedDraftTokensIndicesTensor.dim() == 1
            && packedAcceptedDraftTokensIndicesTensor.scalar_type() == torch::kInt,
        "packed_accepted_draft_tokens_indices tensor should be 1D int tensor.");
    TLLM_CHECK_WITH_INFO(pastKeyValueLengthsTensor.dim() == 1 && pastKeyValueLengthsTensor.size(0) == seqCount
            && pastKeyValueLengthsTensor.scalar_type() == torch::kInt,
        "past_key_value_lengths tensor should be 1D int tensor with same length as seqCount");

    // Optional per-sequence rewind counts; when absent, the scalar
    // rewindDraftTokenCount applies to every sequence.
    int* rewindDraftTokenTensorPtr = nullptr;
    if (rewindDraftTokenTensor.has_value())
    {
        TLLM_CHECK_WITH_INFO(rewindDraftTokenTensor.value().dim() == 1
                && rewindDraftTokenTensor.value().size(0) == seqCount
                && rewindDraftTokenTensor.value().scalar_type() == torch::kInt,
            "rewindDraftTokenTensor should be 1D int tensor same length as seqCount");
        rewindDraftTokenTensorPtr = rewindDraftTokenTensor.value().data_ptr<int>();
    }

    if (usePagedKVCache)
    {
        TLLM_CHECK_WITH_INFO(
            pointerArrayOpt.has_value(), "pool_pointer_array should be set when using paged KV cache.");
        TLLM_CHECK_WITH_INFO(
            offsetArrayOpt.has_value(), "block_offset_array should be set when using paged KV cache.");
        TLLM_CHECK_WITH_INFO(
            maxBlocksPerSeqOpt.has_value(), "max_blocks_per_seq should be set when using paged KV cache.");
        TLLM_CHECK_WITH_INFO(
            tokensPerBlockOpt.has_value(), "tokens_per_block should be set when using paged KV cache.");
        auto const& pointerArray = pointerArrayOpt.value();
        auto const& offsetArray = offsetArrayOpt.value();
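        // Paged path: the kernel resolves each accepted draft token through
        // the pool pointer array and per-sequence block offsets, moves its
        // K/V entries to the token's final position, and rewinds the cache
        // by the requested draft-token count.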
        tensorrt_llm::kernels::parallel_decoding::updateKVBlockArrayDraftTokenLocation(
            seqAcceptedDraftTokenOffsetsTensor.data_ptr<int>(),
            packedAcceptedDraftTokensIndicesTensor.data_ptr<int>(), pastKeyValueLengthsTensor.data_ptr<int>(),
            reinterpret_cast<void**>(pointerArray.data_ptr()), offsetArray.data_ptr<int32_t>(), layerCount, seqCount,
            numKVHeads, headSizeInBytes, rewindDraftTokenCount, rewindDraftTokenTensorPtr, nullptr, maxKVCacheLen,
            maxBlocksPerSeqOpt.value(), tokensPerBlockOpt.value(), stream);
    }
    else
    {
        TLLM_CHECK_WITH_INFO(
            pastKeyValueListOpt.has_value(), "block_pointer_array should be set when using linear KV cache.");
        // Linear path: collect the raw per-layer KV buffer pointers for the
        // kernel.
        std::vector<int8_t*> pastKeyValueList;
        pastKeyValueList.reserve(layerCount);
        for (auto& pastKeyValueTensor : pastKeyValueListOpt.value())
        {
            pastKeyValueList.push_back(static_cast<int8_t*>(pastKeyValueTensor.data_ptr()));
        }
        tensorrt_llm::kernels::parallel_decoding::updateLinearKVCacheDraftTokenLocation(
            seqAcceptedDraftTokenOffsetsTensor.data_ptr<int>(),
            packedAcceptedDraftTokensIndicesTensor.data_ptr<int>(), pastKeyValueLengthsTensor.data_ptr<int>(),
            pastKeyValueList.data(), layerCount, seqCount, numKVHeads, headSizeInBytes, rewindDraftTokenCount,
            rewindDraftTokenTensorPtr, nullptr, maxKVCacheLen, stream);
    }
}

} // namespace torch_ext

static auto update_kv_cache_draft_token_location
    = torch::RegisterOperators("tensorrt_llm::update_kv_cache_draft_token_location",
        &torch_ext::updateKVCacheDraftTokenLocation);
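// Usage sketch (an assumption about the call site, not part of this file):
// once the extension library is loaded, the registered operator should be
// reachable from Python through the torch op registry, with all arguments
// passed positionally. Tensor names below are hypothetical:
//
//   torch.ops.tensorrt_llm.update_kv_cache_draft_token_location(
//       accepted_offsets,       # int32 tensor, shape [seq_count + 1]
//       packed_token_indices,   # int32 tensor, packed per-sequence indices
//       past_kv_lengths,        # int32 tensor, shape [seq_count]
//       True,                   # use paged KV cache
//       layer_count, num_kv_heads, head_size_in_bytes,
//       rewind_draft_token_count, max_kv_cache_len,
//       None,                   # per-sequence rewind counts (optional)
//       None,                   # past_key_value_list (linear cache only)
//       pointer_array, offset_array, max_blocks_per_seq, tokens_per_block,
//       stream_ptr)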