/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/kernels/speculativeDecoding/draftTokenTreeKernels.h"
#include "tensorrt_llm/kernels/speculativeDecoding/mtpKernels.h"
#include "tensorrt_llm/runtime/torchUtils.h"

namespace th = torch;
namespace tl = tensorrt_llm;
namespace tk = tensorrt_llm::kernels;

namespace torch_ext
{

////////////////////////////////////////////////////////////////////////////////////////////////////////////

std::tuple<th::Tensor, th::Tensor> mtp_prepare_drafter_inputs_op(th::Tensor& inputIds, th::Tensor& seqLens,
    th::Tensor& mtpPastHiddenStatesPtrs, th::Tensor& mtpPastTokensPtrs, th::Tensor& hiddenStates,
    th::Tensor& acceptedTokens, th::Tensor& numAcceptedTokens, th::Tensor& returnInputIds,
    th::Tensor& returnHiddenStates, int64_t numMTPModules, int64_t batchSize, int64_t numContextRequest,
    int64_t hiddenSize)
{
    auto dataType = hiddenStates.scalar_type();

    // Check
    auto inputIdsSizes = inputIds.sizes();
    auto hiddenStatesSizes = hiddenStates.sizes();
    TLLM_CHECK(inputIdsSizes[0] == hiddenStatesSizes[0]);

    auto seqLensSizes = seqLens.sizes();
    TLLM_CHECK(seqLensSizes[0] == batchSize);

    auto stream = at::cuda::getCurrentCUDAStream(hiddenStates.get_device());

    // Fill params
    tk::MTPPrepareDrafterInputsParam params;
    params.numMTPModules = numMTPModules;
    params.batchSize = batchSize;
    params.numContextRequest = numContextRequest;
    params.hiddenSize = hiddenSize;
    params.inputIds = reinterpret_cast<int*>(inputIds.data_ptr());
    params.seqLens = reinterpret_cast<int*>(seqLens.data_ptr());
    params.mtpPastHiddenStatesPtrs = reinterpret_cast<void**>(mtpPastHiddenStatesPtrs.data_ptr());
    params.mtpPastTokensPtrs = reinterpret_cast<int**>(mtpPastTokensPtrs.data_ptr());
    params.hiddenStates = reinterpret_cast<void*>(hiddenStates.data_ptr());
    params.acceptedTokens = reinterpret_cast<int*>(acceptedTokens.data_ptr());
    params.numAcceptedTokens = reinterpret_cast<int*>(numAcceptedTokens.data_ptr());
    params.returnInputIds = reinterpret_cast<int*>(returnInputIds.data_ptr());
    params.returnHiddenStates = reinterpret_cast<void*>(returnHiddenStates.data_ptr());

    switch (dataType)
    {
    case torch::kFloat16:
        // Handle Float16
        tk::invokeMTPPrepareDrafterInputs<half>(params, stream);
        break;
    case torch::kFloat32:
        // Handle Float32
        tk::invokeMTPPrepareDrafterInputs<float>(params, stream);
        break;
    case torch::kBFloat16:
        // Handle BFloat16
        tk::invokeMTPPrepareDrafterInputs<__nv_bfloat16>(params, stream);
        break;
    default:
        // Handle other data types
        throw std::invalid_argument("Invalid dtype, only supports float16, float32, and bfloat16");
        break;
    }

    return std::make_tuple(returnInputIds, returnHiddenStates);
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////

std::tuple<th::Tensor, th::Tensor> mtp_sampling_and_accepted_draft_tokens_op(th::Tensor& logits,
    th::Tensor& draftTokens, th::Tensor& targetTokens, int64_t numMTPModules, int64_t batchSize,
    int64_t numContextRequest, int64_t vocabSize)
{
    int const numGenerationRequest = batchSize - numContextRequest;
    auto dataType = logits.scalar_type();

    // Check
    auto logitsSizes = logits.sizes();
    TORCH_CHECK(logitsSizes.size() == 2, "logits must be a 2D Tensor");
    TLLM_CHECK(logitsSizes[0] == (numContextRequest + numGenerationRequest * (numMTPModules + 1)));

    auto draftTokensSizes = draftTokens.sizes();
    TORCH_CHECK(draftTokensSizes.size() == 1);
    TLLM_CHECK(draftTokensSizes[0] == (numGenerationRequest * numMTPModules));

    auto stream = at::cuda::getCurrentCUDAStream(logits.get_device());

    auto acceptedTokens = torch::ones(
        {batchSize, numMTPModules + 1}, at::TensorOptions().dtype(torch::kInt32).device(logits.device()));
    auto numAcceptedTokens
        = torch::ones({batchSize}, at::TensorOptions().dtype(torch::kInt32).device(logits.device()));

    // Fill params
    tk::MTPSampleAndAcceptDraftTokensParam params;
    params.numMTPModules = numMTPModules;
    params.batchSize = batchSize;
    params.numContextRequest = numContextRequest;
    params.vocabSize = vocabSize;
    params.draftTokens = reinterpret_cast<int*>(draftTokens.data_ptr());
    params.targetTokens = reinterpret_cast<int*>(targetTokens.data_ptr());
    params.acceptedTokens = reinterpret_cast<int*>(acceptedTokens.data_ptr());
    params.numAcceptedTokens = reinterpret_cast<int*>(numAcceptedTokens.data_ptr());
    params.logits = logits.data_ptr();

    switch (dataType)
    {
    case torch::kFloat16:
        // Handle Float16
        tk::invokeMTPSampleAndAcceptDraftTokens<half>(params, stream);
        break;
    case torch::kFloat32:
        // Handle Float32
        tk::invokeMTPSampleAndAcceptDraftTokens<float>(params, stream);
        break;
    case torch::kBFloat16:
        // Handle BFloat16
        tk::invokeMTPSampleAndAcceptDraftTokens<__nv_bfloat16>(params, stream);
        break;
    default:
        // Handle other data types
        throw std::invalid_argument("Invalid dtype, only supports float16, float32, and bfloat16");
        break;
    }

    return std::make_tuple(acceptedTokens, numAcceptedTokens);
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////

std::tuple<th::Tensor, th::Tensor> mtp_update_hidden_states_op(th::Tensor& inputIds, th::Tensor& seqLens,
    th::Tensor& targetModelHiddenStates, th::Tensor& mtpPastHiddenStatesPtrs, th::Tensor& mtpPastTokensPtrs,
    th::Tensor& numAcceptedTokens, int64_t numMTPModules, int64_t batchSize, int64_t numContextRequest,
    int64_t hiddenSize)
{
    auto dataType = targetModelHiddenStates.scalar_type();

    // Check
    auto inputIdsSizes = inputIds.sizes();
    auto targetModelHiddenStatesSize = targetModelHiddenStates.sizes();
    TLLM_CHECK(inputIdsSizes[0] == targetModelHiddenStatesSize[0]);

    auto numAcceptedTokensSize = numAcceptedTokens.sizes();
    TLLM_CHECK(numAcceptedTokensSize[0] == batchSize);

    auto stream = at::cuda::getCurrentCUDAStream(targetModelHiddenStates.get_device());

    // Fill params
    tk::MTPUpdateHiddenStatesParam params;
    params.numMTPModules = numMTPModules;
    params.batchSize = batchSize;
    params.numContextRequest = numContextRequest;
    params.hiddenSize = hiddenSize;
    params.inputIds = reinterpret_cast<int*>(inputIds.data_ptr());
    params.seqLens = reinterpret_cast<int*>(seqLens.data_ptr());
    params.targetModelHiddenStates = targetModelHiddenStates.data_ptr();
    params.mtpPastHiddenStatesPtrs = reinterpret_cast<void**>(mtpPastHiddenStatesPtrs.data_ptr());
    params.mtpPastTokensPtrs = reinterpret_cast<int**>(mtpPastTokensPtrs.data_ptr());
    params.numAcceptedTokens = reinterpret_cast<int*>(numAcceptedTokens.data_ptr());

    switch (dataType)
    {
    case torch::kFloat16:
        // Handle Float16
        tk::invokeMTPUpdateHiddenStates<half>(params, stream);
        break;
    case torch::kFloat32:
        // Handle Float32
        tk::invokeMTPUpdateHiddenStates<float>(params, stream);
        break;
    case torch::kBFloat16:
        // Handle BFloat16
        tk::invokeMTPUpdateHiddenStates<__nv_bfloat16>(params, stream);
        break;
    default:
        // Handle other data types
        throw std::invalid_argument("Invalid dtype, only supports float16, float32, and bfloat16");
        break;
    }

    return std::make_tuple(mtpPastHiddenStatesPtrs, mtpPastTokensPtrs);
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////

std::tuple<th::Tensor, th::Tensor> mtp_relaxed_acceptance_op(th::Tensor& reqSlotIds, th::Tensor& topKValue,
    th::Tensor& topKIndices, th::Tensor& draftTokens, th::Tensor& mtpRelaxedDelta, th::Tensor& numAcceptedTokens,
    th::Tensor& acceptedTokens, int64_t const numMTPModules, int64_t const batchSize,
    int64_t const numContextRequest, int64_t const relaxedTopK, double const relaxedDelta,
    int64_t const beginThinkingTokens, int64_t const endThinkingTokens)
{
    auto dataType = topKValue.scalar_type();

    // Check
    auto numGenerationRequest = batchSize - numContextRequest;
    auto topKValueSizes = topKValue.sizes();
    TLLM_CHECK(topKValueSizes[0] == numGenerationRequest);
    TLLM_CHECK(topKValueSizes[1] == numMTPModules + 1);
    TLLM_CHECK(topKValueSizes[2] == relaxedTopK);

    auto draftTokensSizes = draftTokens.sizes();
    TLLM_CHECK(draftTokensSizes[0] == numGenerationRequest);

    auto numAcceptedTokensSize = numAcceptedTokens.sizes();
    TLLM_CHECK(numAcceptedTokensSize[0] == batchSize);

    auto stream = at::cuda::getCurrentCUDAStream(numAcceptedTokens.get_device());

    // Fill params
    tk::MTPRelaxedAcceptanceParam params;
    params.numMTPModules = numMTPModules;
    params.batchSize = batchSize;
    params.numContextRequest = numContextRequest;
    params.relaxedTopK = relaxedTopK;
    params.relaxedDelta = (float) relaxedDelta;
    params.beginThinkingTokens = beginThinkingTokens;
    params.endThinkingTokens = endThinkingTokens;
    params.reqSlotIds = reinterpret_cast<int*>(reqSlotIds.data_ptr());
    params.topKValue = reinterpret_cast<void*>(topKValue.data_ptr());
    params.topKIndices = reinterpret_cast<int64_t*>(topKIndices.data_ptr());
    params.draftTokens = reinterpret_cast<int*>(draftTokens.data_ptr());
    params.mtpRelaxedDelta = reinterpret_cast<float*>(mtpRelaxedDelta.data_ptr());
    params.numAcceptedTokens = reinterpret_cast<int*>(numAcceptedTokens.data_ptr());
    params.acceptedTokens = reinterpret_cast<int*>(acceptedTokens.data_ptr());

    switch (dataType)
    {
    case torch::kFloat16:
        // Handle Float16
        tk::invokeMTPRelaxedAcceptance<half>(params, stream);
        break;
    case torch::kFloat32:
        // Handle Float32
        tk::invokeMTPRelaxedAcceptance<float>(params, stream);
        break;
    case torch::kBFloat16:
        // Handle BFloat16
        tk::invokeMTPRelaxedAcceptance<__nv_bfloat16>(params, stream);
        break;
    default:
        // Handle other data types
        throw std::invalid_argument("Invalid dtype, only supports float16, float32, and bfloat16");
        break;
    }

    return std::make_tuple(acceptedTokens, numAcceptedTokens);
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////

void extract_real_draft_tokens_op(th::Tensor newDraftTokens, th::Tensor draftTokensBuffer,
    th::Tensor tokensGatherIdxForDrafterModel, th::Tensor topKList, th::Tensor draftTokensIndicesCumsum,
    int64_t curDraftIdx, int64_t batchSize, int64_t maxDraftLen, int64_t maxTotalDraftTokens, int64_t maxTopK)
{
    // Args:
    //     curDraftIdx: int
    //     batchSize: int
    //     maxTotalDraftTokens: int
    //     maxTopK: int
    //     tokensGatherIdxForDrafterModel: Tensor, int32, indices of the draft tokens that need to be expanded
    //         this layer
    //         shape: [numTokensExpandThisLayer]
    //     topKList: Tensor, int32, top-k value for each expandable token
    //         shape: [numTokensExpandThisLayer]
    //     draftTokensIndicesCumsum: Tensor, int32, the cumulative sum of the write-back indices for each draft layer
    //         shape: [maxDraftLen + 1]
    //     newDraftTokens: Tensor, int64, the new draft tokens. We only need to extract this layer's tokens and
    //         write them back to the draftTokensBuffer
    //         shape: [batchSize, maxTotalDraftTokens + 1 if curDraftIdx > 0 else 1, maxTopK]
    //     draftTokensBuffer: Tensor, int64, the buffer to store the real draft tokens
    //         shape: [maxBatchSize, maxTotalDraftTokens + 1]

    // Check the data types
    TLLM_CHECK(tokensGatherIdxForDrafterModel.scalar_type() == torch::kInt32);
    TLLM_CHECK(topKList.scalar_type() == torch::kInt32);
    TLLM_CHECK(draftTokensIndicesCumsum.scalar_type() == torch::kInt32);
    TLLM_CHECK(newDraftTokens.scalar_type() == torch::kInt64);
    TLLM_CHECK(draftTokensBuffer.scalar_type() == torch::kInt64);

    // Check the shape of 'tokensGatherIdxForDrafterModel' and 'topKList'
    auto numTokensExpandThisLayer = tokensGatherIdxForDrafterModel.size(0);
    TLLM_CHECK(numTokensExpandThisLayer > 0);
    TLLM_CHECK(topKList.size(0) == numTokensExpandThisLayer);

    // Check the shape of 'draftTokensIndicesCumsum'
    TLLM_CHECK(draftTokensIndicesCumsum.size(0) == maxDraftLen + 1);

    // Check the shape of 'newDraftTokens'
    TLLM_CHECK(newDraftTokens.size(0) == batchSize);
    if (curDraftIdx == 0)
    {
        TLLM_CHECK(newDraftTokens.size(1) == 1);
        TLLM_CHECK(newDraftTokens.size(2) == maxTopK);
    }
    else
    {
        TLLM_CHECK(newDraftTokens.size(1) == maxTotalDraftTokens + 1);
        TLLM_CHECK(newDraftTokens.size(2) == maxTopK);
    }

    // Check the shape of 'draftTokensBuffer'
    TLLM_CHECK(draftTokensBuffer.size(1) == maxTotalDraftTokens + 1);

    auto stream = at::cuda::getCurrentCUDAStream(newDraftTokens.get_device());

    // Fill params
    tk::ExtractRealDraftTokensParam params;
    params.curDraftIdx = curDraftIdx;
    params.batchSize = batchSize;
    params.maxDraftLen = maxDraftLen;
    params.maxTotalDraftTokens = maxTotalDraftTokens;
    params.maxTopK = maxTopK;
    params.numTokensExpandThisLayer = numTokensExpandThisLayer;
    params.tokensGatherIdxForDrafterModel = reinterpret_cast<int*>(tokensGatherIdxForDrafterModel.data_ptr());
    params.topKList = reinterpret_cast<int*>(topKList.data_ptr());
    params.draftTokensIndicesCumsum = reinterpret_cast<int*>(draftTokensIndicesCumsum.data_ptr());
    params.newDraftTokens = reinterpret_cast<int64_t*>(newDraftTokens.data_ptr());
    params.draftTokensBuffer = reinterpret_cast<int64_t*>(draftTokensBuffer.data_ptr());

    tk::invokeExtractRealDraftTokens(params, stream);
}

} // end namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "mtp_prepare_drafter_inputs_op(Tensor inputIds, Tensor seqLens, Tensor "
        "mtpPastHiddenStatesPtrs, Tensor mtpPastTokensPtrs, Tensor hiddenStates, "
        "Tensor acceptedTokens, Tensor numAcceptedTokens, Tensor returnInputIds, Tensor returnHiddenStates, "
        "int numMTPModules, int batchSize, int numContextRequest, "
        "int hiddenSize) -> (Tensor, Tensor)");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("mtp_prepare_drafter_inputs_op", &torch_ext::mtp_prepare_drafter_inputs_op);
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "mtp_sampling_and_accepted_draft_tokens_op(Tensor logits, Tensor draftTokens, Tensor "
        "targetTokens, int numMTPModules, "
        "int batchSize, int numContextRequest, int vocabSize) -> (Tensor, Tensor)");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("mtp_sampling_and_accepted_draft_tokens_op", &torch_ext::mtp_sampling_and_accepted_draft_tokens_op);
}
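// Illustrative only: a minimal sketch of how the sampling/acceptance op above could be exercised directly from
// C++ (e.g. in a unit test). Tensor shapes follow the checks inside
// torch_ext::mtp_sampling_and_accepted_draft_tokens_op; the targetTokens layout and all concrete values here
// are placeholder assumptions, not part of this file's contract.
//
//   int64_t const numMTPModules = 1, batchSize = 3, numContextRequest = 1, vocabSize = 32000;
//   int64_t const numGen = batchSize - numContextRequest;
//   auto cuda = at::TensorOptions().device(torch::kCUDA);
//   auto logits = torch::randn(
//       {numContextRequest + numGen * (numMTPModules + 1), vocabSize}, cuda.dtype(torch::kFloat32));
//   auto draftTokens = torch::zeros({numGen * numMTPModules}, cuda.dtype(torch::kInt32));
//   auto targetTokens = torch::zeros({numGen * (numMTPModules + 1)}, cuda.dtype(torch::kInt32)); // assumed layout
//   auto [acceptedTokens, numAcceptedTokens] = torch_ext::mtp_sampling_and_accepted_draft_tokens_op(
//       logits, draftTokens, targetTokens, numMTPModules, batchSize, numContextRequest, vocabSize);
//   // acceptedTokens: [batchSize, numMTPModules + 1], numAcceptedTokens: [batchSize], both int32.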
////////////////////////////////////////////////////////////////////////////////////////////////////////////

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "mtp_update_hidden_states_op(Tensor inputIds, Tensor seqLens, Tensor targetModelHiddenStates, "
        "Tensor mtpPastHiddenStatesPtrs, Tensor mtpPastTokensPtrs, Tensor numAcceptedTokens, "
        "int numMTPModules, int batchSize, int numContextRequest, int hiddenSize) -> (Tensor, Tensor)");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("mtp_update_hidden_states_op", &torch_ext::mtp_update_hidden_states_op);
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "mtp_relaxed_acceptance_op(Tensor reqSlotIds, Tensor topKValue, Tensor topKIndices, Tensor draftTokens, "
        "Tensor mtpRelaxedDelta, Tensor numAcceptedTokens, Tensor acceptedTokens, "
        "int numMTPModules, int batchSize, int numContextRequest, int relaxedTopK, "
        "float relaxedDelta, int beginThinkingTokens, int endThinkingTokens) -> (Tensor, Tensor)");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("mtp_relaxed_acceptance_op", &torch_ext::mtp_relaxed_acceptance_op);
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "extract_real_draft_tokens_op(Tensor newDraftTokens, Tensor draftTokensBuffer, "
        "Tensor tokensGatherIdxForDrafterModel, Tensor topKList, Tensor draftTokensIndicesCumsum, "
        "int curDraftIdx, int batchSize, int maxDraftLen, int maxTotalDraftTokens, int maxTopK) -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("extract_real_draft_tokens_op", &torch_ext::extract_real_draft_tokens_op);
}
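// Illustrative only: a minimal sketch of calling extract_real_draft_tokens_op for the first draft layer
// (curDraftIdx == 0). Shapes follow the Args comment and checks in torch_ext::extract_real_draft_tokens_op;
// the concrete sizes and index values are placeholder assumptions chosen only to satisfy those checks.
//
//   int64_t const curDraftIdx = 0, batchSize = 2, maxDraftLen = 3, maxTotalDraftTokens = 7, maxTopK = 4;
//   auto cuda = at::TensorOptions().device(torch::kCUDA);
//   auto tokensGatherIdx = torch::zeros({1}, cuda.dtype(torch::kInt32));              // [numTokensExpandThisLayer]
//   auto topKList = torch::full({1}, maxTopK, cuda.dtype(torch::kInt32));             // [numTokensExpandThisLayer]
//   auto indicesCumsum = torch::zeros({maxDraftLen + 1}, cuda.dtype(torch::kInt32));  // [maxDraftLen + 1]
//   auto newDraftTokens = torch::zeros({batchSize, 1, maxTopK}, cuda.dtype(torch::kInt64));         // layer 0 shape
//   auto draftTokensBuffer = torch::zeros({batchSize, maxTotalDraftTokens + 1}, cuda.dtype(torch::kInt64));
//   torch_ext::extract_real_draft_tokens_op(newDraftTokens, draftTokensBuffer, tokensGatherIdx, topKList,
//       indicesCumsum, curDraftIdx, batchSize, maxDraftLen, maxTotalDraftTokens, maxTopK);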