refactor: simplify forward methods in GptDecoderBatched (#3076)

* refactor: Remove ForwardType enum from GptDecoderBatched

 - Remove ForwardType enum from GptDecoderBatched
 - Simplify forwardDispatch and forwardDecoder methods

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Remove forwardDecoder method from GptDecoderBatched

- Eliminate the forwardDecoder method to streamline the decoding process.
- Update forwardDispatch to directly call forwardAsync when input batch size is greater than zero.

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Move event handling from forwardDispatch to forwardAsync

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

---------

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
This commit is contained in:
Robin Kobus 2025-03-26 13:45:04 +01:00 committed by GitHub
parent 94dd456bd0
commit 3c3629c52a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 52 deletions

View File

@@ -46,12 +46,6 @@ public:
using TensorPtr = ITensor::SharedPtr;
using SharedConstPtr = ITensor::SharedConstPtr;
enum class ForwardType
{
kASYNC,
kSYNC
};
GptDecoderBatched(
CudaStreamPtr stream, SpeculativeDecodingMode const& speculativeDecodingMode, nvinfer1::DataType dtype);
@@ -99,14 +93,11 @@ private:
void setEagleInputs(decoder_batch::Input const& input);
//! @brief Calls decoders for tokens per engine step
void forwardDispatch(decoder_batch::Output& output, decoder_batch::Input const& input, ForwardType forwardType);
void forwardDispatch(decoder_batch::Output& output, decoder_batch::Input const& input);
//! @brief Prepare Input and Output for decoder step
void prepareForward(SizeType32 step, decoder_batch::Output& output, decoder_batch::Input const& input);
//! @brief Calls decoder for whole batch
void forwardDecoder(DecodingOutput& output, DecodingInput const& input, ForwardType forwardType);
private:
CudaStreamPtr mRuntimeStream;
CudaStreamPtr mDecoderStream;

View File

@@ -180,18 +180,9 @@ T maxOfActiveSlots(std::vector<T> const& values, std::vector<bool> const& active
}
} // namespace
void GptDecoderBatched::forwardDispatch(
decoder_batch::Output& output, decoder_batch::Input const& input, ForwardType forwardType)
void GptDecoderBatched::forwardDispatch(decoder_batch::Output& output, decoder_batch::Input const& input)
{
auto eventStart = CudaEvent{};
mRuntimeStream->record(eventStart);
bool const async = forwardType == ForwardType::kASYNC;
if (async)
{
mDecoderStream->wait(eventStart.get());
}
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const maxDecodingEngineTokens
= maxOfActiveSlots(mDecoderState->getJointDecodingInput().numDecodingEngineTokens, input.active);
@@ -199,15 +190,14 @@ void GptDecoderBatched::forwardDispatch(
for (SizeType32 si = 0; si < maxDecodingEngineTokens; si += mDecoderState->getMaxDecodingDecoderTokens())
{
prepareForward(si, output, input);
forwardDecoder(mDecoderState->getJointDecodingOutput(), mDecoderState->getJointDecodingInput(), forwardType);
if (mDecoderState->getJointDecodingInput().batchSize > 0)
{
mDecoder->forwardAsync(mDecoderState->getJointDecodingOutput(), mDecoderState->getJointDecodingInput());
}
}
if (async)
{
CudaEvent event{};
mDecoderStream->record(event);
mRuntimeStream->wait(event);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
GptDecoderBatched::DecoderFinishedEventPtr GptDecoderBatched::forwardAsync(
@@ -215,7 +205,15 @@ GptDecoderBatched::DecoderFinishedEventPtr GptDecoderBatched::forwardAsync(
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
forwardDispatch(output, input, ForwardType::kASYNC);
auto eventStart = CudaEvent{};
mRuntimeStream->record(eventStart);
mDecoderStream->wait(eventStart.get());
forwardDispatch(output, input);
CudaEvent event{};
mDecoderStream->record(event);
mRuntimeStream->wait(event);
CudaEvent eventStop{};
mRuntimeStream->record(eventStop);
@@ -332,29 +330,6 @@ void GptDecoderBatched::prepareForward(
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void GptDecoderBatched::forwardDecoder(DecodingOutput& output, DecodingInput const& input, ForwardType forwardType)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (input.batchSize > 0)
{
if (forwardType == ForwardType::kASYNC)
{
mDecoder->forwardAsync(output, input);
}
else if (forwardType == ForwardType::kSYNC)
{
mDecoder->forwardSync(output, input);
}
else
{
TLLM_THROW("Unknown ForwardType");
}
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void GptDecoderBatched::forward(decoder_batch::Output& output, decoder_batch::Input const& input)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);