From 3c3629c52ac2d3b89e9c35475c0b454e47a3784c Mon Sep 17 00:00:00 2001
From: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
Date: Wed, 26 Mar 2025 13:45:04 +0100
Subject: [PATCH] refactor: simplify forward methods in GptDecoderBatched (#3076)

* refactor: Remove ForwardType enum from GptDecoderBatched

- Remove ForwardType enum from GptDecoderBatched
- Simplify forwardDispatch and forwardDecoder methods

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Remove forwardDecoder method from GptDecoderBatched

- Eliminate the forwardDecoder method to streamline the decoding process.
- Update forwardDispatch to directly call forwardAsync when input batch size is greater than zero.

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Move event handling from forwardDispatch to forwardAsync

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

---------

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
---
 .../tensorrt_llm/runtime/gptDecoderBatched.h | 11 +---
 .../runtime/gptDecoderBatched.cpp            | 59 ++++++-------------
 2 files changed, 18 insertions(+), 52 deletions(-)

diff --git a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h
index 298922a749..679a860879 100644
--- a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h
+++ b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h
@@ -46,12 +46,6 @@ public:
     using TensorPtr = ITensor::SharedPtr;
     using SharedConstPtr = ITensor::SharedConstPtr;
 
-    enum class ForwardType
-    {
-        kASYNC,
-        kSYNC
-    };
-
     GptDecoderBatched(
         CudaStreamPtr stream, SpeculativeDecodingMode const& speculativeDecodingMode, nvinfer1::DataType dtype);
 
@@ -99,14 +93,11 @@ private:
     void setEagleInputs(decoder_batch::Input const& input);
 
     //! @brief Calls decoders for tokens per engine step
-    void forwardDispatch(decoder_batch::Output& output, decoder_batch::Input const& input, ForwardType forwardType);
+    void forwardDispatch(decoder_batch::Output& output, decoder_batch::Input const& input);
 
     //! @brief Prepare Input and Output for decoder step
     void prepareForward(SizeType32 step, decoder_batch::Output& output, decoder_batch::Input const& input);
 
-    //! @brief Calls decoder for whole batch
-    void forwardDecoder(DecodingOutput& output, DecodingInput const& input, ForwardType forwardType);
-
 private:
     CudaStreamPtr mRuntimeStream;
     CudaStreamPtr mDecoderStream;

diff --git a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp
index eab75d0964..a91caa7c49 100644
--- a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp
+++ b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp
@@ -180,18 +180,9 @@ T maxOfActiveSlots(std::vector<T> const& values, std::vector<bool> const& active)
 }
 
 } // namespace
 
-void GptDecoderBatched::forwardDispatch(
-    decoder_batch::Output& output, decoder_batch::Input const& input, ForwardType forwardType)
+void GptDecoderBatched::forwardDispatch(decoder_batch::Output& output, decoder_batch::Input const& input)
 {
-    auto eventStart = CudaEvent{};
-    mRuntimeStream->record(eventStart);
-
-    bool const async = forwardType == ForwardType::kASYNC;
-
-    if (async)
-    {
-        mDecoderStream->wait(eventStart.get());
-    }
+    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
 
     auto const maxDecodingEngineTokens
         = maxOfActiveSlots(mDecoderState->getJointDecodingInput().numDecodingEngineTokens, input.active);
@@ -199,15 +190,14 @@ void GptDecoderBatched::forwardDispatch(
     for (SizeType32 si = 0; si < maxDecodingEngineTokens; si += mDecoderState->getMaxDecodingDecoderTokens())
     {
         prepareForward(si, output, input);
-        forwardDecoder(mDecoderState->getJointDecodingOutput(), mDecoderState->getJointDecodingInput(), forwardType);
+
+        if (mDecoderState->getJointDecodingInput().batchSize > 0)
+        {
+            mDecoder->forwardAsync(mDecoderState->getJointDecodingOutput(), mDecoderState->getJointDecodingInput());
+        }
     }
 
-    if (async)
-    {
-        CudaEvent event{};
-        mDecoderStream->record(event);
-        mRuntimeStream->wait(event);
-    }
+    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
 }
 
 GptDecoderBatched::DecoderFinishedEventPtr GptDecoderBatched::forwardAsync(
@@ -215,7 +205,15 @@ GptDecoderBatched::DecoderFinishedEventPtr GptDecoderBatched::forwardAsync(
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
 
-    forwardDispatch(output, input, ForwardType::kASYNC);
+    auto eventStart = CudaEvent{};
+    mRuntimeStream->record(eventStart);
+    mDecoderStream->wait(eventStart.get());
+
+    forwardDispatch(output, input);
+
+    CudaEvent event{};
+    mDecoderStream->record(event);
+    mRuntimeStream->wait(event);
 
     CudaEvent eventStop{};
     mRuntimeStream->record(eventStop);
@@ -332,29 +330,6 @@ void GptDecoderBatched::prepareForward(
     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
 }
 
-void GptDecoderBatched::forwardDecoder(DecodingOutput& output, DecodingInput const& input, ForwardType forwardType)
-{
-    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
-
-    if (input.batchSize > 0)
-    {
-        if (forwardType == ForwardType::kASYNC)
-        {
-            mDecoder->forwardAsync(output, input);
-        }
-        else if (forwardType == ForwardType::kSYNC)
-        {
-            mDecoder->forwardSync(output, input);
-        }
-        else
-        {
-            TLLM_THROW("Unknown ForwardType");
-        }
-    }
-
-    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
-}
-
 void GptDecoderBatched::forward(decoder_batch::Output& output, decoder_batch::Input const& input)
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
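
Note (supplementary context, not part of the patch): after this change, forwardAsync alone owns the cross-stream ordering. It records an event on the runtime stream, makes the decoder stream wait on it, enqueues all decoder steps via forwardDispatch, then records an event on the decoder stream that the runtime stream waits on. This is the standard CUDA fork-join pattern between two streams. Below is a minimal sketch of that pattern using the raw CUDA runtime API instead of TensorRT-LLM's CudaStream/CudaEvent wrappers; forkJoin and enqueueDecoderWork are illustrative names, not symbols from the codebase.

```cpp
#include <cuda_runtime.h>

// Hypothetical stand-in for the decoder kernels that forwardDispatch enqueues.
void enqueueDecoderWork(cudaStream_t /*decoderStream*/)
{
    // Launch decoder kernels on the decoder stream here.
}

// Fork-join between a runtime stream and a decoder stream, mirroring the
// event handling that this patch consolidates in forwardAsync.
void forkJoin(cudaStream_t runtimeStream, cudaStream_t decoderStream)
{
    cudaEvent_t start{};
    cudaEvent_t stop{};
    cudaEventCreateWithFlags(&start, cudaEventDisableTiming);
    cudaEventCreateWithFlags(&stop, cudaEventDisableTiming);

    // Fork: order the decoder stream after all work already queued on the
    // runtime stream (mRuntimeStream->record / mDecoderStream->wait above).
    cudaEventRecord(start, runtimeStream);
    cudaStreamWaitEvent(decoderStream, start, 0);

    enqueueDecoderWork(decoderStream); // forwardDispatch in the patch

    // Join: the runtime stream resumes only once the decoder stream's work
    // completes (mDecoderStream->record / mRuntimeStream->wait above).
    cudaEventRecord(stop, decoderStream);
    cudaStreamWaitEvent(runtimeStream, stop, 0);

    cudaEventDestroy(start);
    cudaEventDestroy(stop);
}
```

Both waits are enqueued on the device, so the host never blocks; this is what lets forwardAsync return its DecoderFinishedEventPtr immediately while decoding proceeds asynchronously on the decoder stream.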