Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
refactor: simplify forward methods in GptDecoderBatched (#3076)
* refactor: Remove ForwardType enum from GptDecoderBatched

  - Remove ForwardType enum from GptDecoderBatched
  - Simplify forwardDispatch and forwardDecoder methods

  Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Remove forwardDecoder method from GptDecoderBatched

  - Eliminate the forwardDecoder method to streamline the decoding process.
  - Update forwardDispatch to directly call forwardAsync when input batch size is greater than zero.

  Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Move event handling from forwardDispatch to forwardAsync

  Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

---------

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
commit 3c3629c52a
parent 94dd456bd0
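The "event handling" moved by the third change is the usual CUDA two-stream handshake: record an event on one stream, make the other stream wait on it, enqueue the work, then hand completion back the same way. Below is a minimal standalone sketch of that pattern using the raw CUDA runtime API; the repo wraps these calls behind its CudaStream/CudaEvent helpers, and the function and stream names here are illustrative, not taken from the codebase.

    // Sketch of the two-stream handshake this commit centralizes in forwardAsync.
    // Illustrative only: names and structure are assumptions, not repo code.
    #include <cuda_runtime.h>

    void crossStreamHandshake(cudaStream_t runtimeStream, cudaStream_t decoderStream)
    {
        cudaEvent_t eventStart;
        cudaEventCreateWithFlags(&eventStart, cudaEventDisableTiming);
        cudaEventRecord(eventStart, runtimeStream);        // mark current tail of the runtime stream
        cudaStreamWaitEvent(decoderStream, eventStart, 0); // decoder stream waits for it

        // ... enqueue all decoding work on decoderStream here ...

        cudaEvent_t eventDone;
        cudaEventCreateWithFlags(&eventDone, cudaEventDisableTiming);
        cudaEventRecord(eventDone, decoderStream);         // mark the end of the decoder work
        cudaStreamWaitEvent(runtimeStream, eventDone, 0);  // runtime stream resumes after it

        cudaEventDestroy(eventStart); // safe: destruction is deferred until the event completes
        cudaEventDestroy(eventDone);
    }

Before this commit, forwardDispatch issued this bracket only when forwardType == ForwardType::kASYNC; hoisting it into forwardAsync, as the diff below shows, makes the flag, and with it the enum, unnecessary.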
@@ -46,12 +46,6 @@ public:
     using TensorPtr = ITensor::SharedPtr;
     using SharedConstPtr = ITensor::SharedConstPtr;
 
-    enum class ForwardType
-    {
-        kASYNC,
-        kSYNC
-    };
-
     GptDecoderBatched(
         CudaStreamPtr stream, SpeculativeDecodingMode const& speculativeDecodingMode, nvinfer1::DataType dtype);
 
@@ -99,14 +93,11 @@ private:
     void setEagleInputs(decoder_batch::Input const& input);
 
     //! @brief Calls decoders for tokens per engine step
-    void forwardDispatch(decoder_batch::Output& output, decoder_batch::Input const& input, ForwardType forwardType);
+    void forwardDispatch(decoder_batch::Output& output, decoder_batch::Input const& input);
 
     //! @brief Prepare Input and Output for decoder step
     void prepareForward(SizeType32 step, decoder_batch::Output& output, decoder_batch::Input const& input);
 
-    //! @brief Calls decoder for whole batch
-    void forwardDecoder(DecodingOutput& output, DecodingInput const& input, ForwardType forwardType);
-
 private:
     CudaStreamPtr mRuntimeStream;
     CudaStreamPtr mDecoderStream;
 
@@ -180,18 +180,9 @@ T maxOfActiveSlots(std::vector<T> const& values, std::vector<bool> const& active
 }
 } // namespace
 
-void GptDecoderBatched::forwardDispatch(
-    decoder_batch::Output& output, decoder_batch::Input const& input, ForwardType forwardType)
+void GptDecoderBatched::forwardDispatch(decoder_batch::Output& output, decoder_batch::Input const& input)
 {
-    auto eventStart = CudaEvent{};
-    mRuntimeStream->record(eventStart);
-
-    bool const async = forwardType == ForwardType::kASYNC;
-
-    if (async)
-    {
-        mDecoderStream->wait(eventStart.get());
-    }
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
 
     auto const maxDecodingEngineTokens
         = maxOfActiveSlots(mDecoderState->getJointDecodingInput().numDecodingEngineTokens, input.active);
@@ -199,15 +190,14 @@ void GptDecoderBatched::forwardDispatch(
     for (SizeType32 si = 0; si < maxDecodingEngineTokens; si += mDecoderState->getMaxDecodingDecoderTokens())
     {
         prepareForward(si, output, input);
-        forwardDecoder(mDecoderState->getJointDecodingOutput(), mDecoderState->getJointDecodingInput(), forwardType);
+
+        if (mDecoderState->getJointDecodingInput().batchSize > 0)
+        {
+            mDecoder->forwardAsync(mDecoderState->getJointDecodingOutput(), mDecoderState->getJointDecodingInput());
+        }
     }
 
-    if (async)
-    {
-        CudaEvent event{};
-        mDecoderStream->record(event);
-        mRuntimeStream->wait(event);
-    }
     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
 }
 
@@ -215,7 +205,15 @@ GptDecoderBatched::DecoderFinishedEventPtr GptDecoderBatched::forwardAsync(
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
 
-    forwardDispatch(output, input, ForwardType::kASYNC);
+    auto eventStart = CudaEvent{};
+    mRuntimeStream->record(eventStart);
+    mDecoderStream->wait(eventStart.get());
+
+    forwardDispatch(output, input);
+
+    CudaEvent event{};
+    mDecoderStream->record(event);
+    mRuntimeStream->wait(event);
 
     CudaEvent eventStop{};
     mRuntimeStream->record(eventStop);
@@ -332,29 +330,6 @@ void GptDecoderBatched::prepareForward(
     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
 }
 
-void GptDecoderBatched::forwardDecoder(DecodingOutput& output, DecodingInput const& input, ForwardType forwardType)
-{
-    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
-
-    if (input.batchSize > 0)
-    {
-        if (forwardType == ForwardType::kASYNC)
-        {
-            mDecoder->forwardAsync(output, input);
-        }
-        else if (forwardType == ForwardType::kSYNC)
-        {
-            mDecoder->forwardSync(output, input);
-        }
-        else
-        {
-            TLLM_THROW("Unknown ForwardType");
-        }
-    }
-
-    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
-}
-
 void GptDecoderBatched::forward(decoder_batch::Output& output, decoder_batch::Input const& input)
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
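The loop that survives in forwardDispatch is easier to read in isolation: prepareForward slices inputs and outputs for each step, and the decoder runs only when the joint input actually has a non-empty batch. A host-only sketch with the decoder stubbed out; the stub name and the step and batch counts below are made-up illustration values, not repo code:

    #include <cstdint>
    #include <cstdio>

    using SizeType32 = std::int32_t; // matches the alias used in the diff

    // Stub standing in for mDecoder->forwardAsync(...); illustration only.
    void decoderForwardAsync(SizeType32 step, SizeType32 batchSize)
    {
        std::printf("decoder step %d, batch size %d\n", step, batchSize);
    }

    int main()
    {
        SizeType32 const maxDecodingEngineTokens = 4;  // most tokens any active slot needs this call
        SizeType32 const maxDecodingDecoderTokens = 1; // tokens the decoder consumes per invocation
        SizeType32 const batchSize = 8;                // a batch size of 0 skips the decoder entirely

        for (SizeType32 si = 0; si < maxDecodingEngineTokens; si += maxDecodingDecoderTokens)
        {
            // prepareForward(si, output, input) would slice tensors for this step.
            if (batchSize > 0) // replaces the indirection through the removed forwardDecoder
            {
                decoderForwardAsync(si, batchSize);
            }
        }
        return 0;
    }

With a decoder step of 1 this visits every engine step individually; speculative-decoding modes with larger decoder steps stride through the same range in fewer calls, which is what the si += mDecoderState->getMaxDecodingDecoderTokens() increment in the diff expresses.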