TensorRT-LLM/cpp/tensorrt_llm/batch_manager/updateDecoderBuffers.cpp
Daniel Cámpora 1299f27c74
fix: Fix C++ decoder synchronization in PyTorch (#3106)

* Use updateDecoderBuffers in the Python decoder.
* Fix synchronize in the TRTLLM decoder.
* Enable by default.
* Use guided_decoder to set up seq slots and free them.
* Always use decode_async and update_requests.
* Update decoder buffers.
* Fix speculative decoding tests.
* Send new_tensors_host instead of assuming a dict.
* Make the default False in enable_trtllm_decoder.
* Partially fix MTP; partially fix py_executor.
* Update request states before sending the disagg context cache.
* Fix the disagg test for the torch decoder.
* Add isend_tensor_list and recv_tensor_list for sending tensors_host.
* Formatting.
* Fix rebase.
* Add the disagg serving case to the guided decoder.
* Get overlap scheduling to work.
* Update cutlass to main.
* Update after rebasing.
* Formatting.
* Update to use decode_async and update_requests.
* Properly pass information to update_requests.
* Formatting.
* Bring disaggregated serving a step closer to working.
* Fix rebase.
* Fix rebase and format.
* Copy new device tokens more pythonically.
* Restore MTP add dummy reqs.
* Add OrderedDict import to py_executor.
* Formatting.
* Add seq slot manager; add test.
* Use single-tensor transmission except when a list of tensors is received.
* Add TRTLLMDecoder allocation to estimate max KV cache tokens.
* Add stream synchronization.
* Formatting.
* Adapt the decoder memory calculation to the chosen decoder. Recognize the decoder option passed in ExecutorConfig. Run the overlap scheduler test on TinyLlama.
* Format.
* Add decoder creation to estimate max KV cache tokens.
* Formatting.
* Update the UCXX submodule in line with main.

---------

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
2025-04-23 23:55:27 +08:00

/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tensorrt_llm/batch_manager/updateDecoderBuffers.h"
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/common/nvtxUtils.h"
#include "tensorrt_llm/runtime/iTensor.h"
namespace tensorrt_llm::batch_manager
{
using BufferManager = tensorrt_llm::runtime::BufferManager;
using TensorPtr = runtime::ITensor::SharedPtr;
using ITensor = runtime::ITensor;
using SizeType32 = tensorrt_llm::runtime::SizeType32;
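
// Schedules device-to-host copies of the decoder state (new tokens, sequence lengths, finish state and,
// optionally, log probs and draft-token buffers) on the copy stream, and returns a CUDA event recorded
// after those copies so callers can synchronize before reading the host buffers.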
runtime::CudaEvent UpdateDecoderBuffers::operator()(runtime::ModelConfig const& modelConfig,
    DecoderBuffers& decoderBuffers, DecoderOutputBuffers& decoderOutputBuffers,
    runtime::BufferManager const& copyBufferManager, runtime::GptDecoderBatched const& decoder, bool returnLogProbs,
    runtime::CudaEvent const& decoderFinishEvent) const
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(updateDecoderBuffers);

    // Chain copy after decoder event, using a different stream
    copyBufferManager.getStream().wait(decoderFinishEvent);
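
    // Copy new output tokens, sequence lengths, finished-request counts and finish reasons to the host buffers.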
    copyBufferManager.copy(*decoder.getDecoderState().getAllNewTokens(), *decoderOutputBuffers.newOutputTokensHost);
    copyBufferManager.copy(
        *decoder.getDecoderState().getJointDecodingOutput().lengths, *decoderOutputBuffers.sequenceLengthsHost);

    auto const finishedSumDevice = decoder.getDecoderState().getFinishedSum();
    copyBufferManager.copy(*finishedSumDevice, *decoderOutputBuffers.finishedSumHost);
    auto const finishReasonsDevice = decoder.getDecoderState().getFinishReasons();
    copyBufferManager.copy(*finishReasonsDevice, *decoderOutputBuffers.finishReasonsHost);
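
    // Log-prob buffers are only copied when log probs were requested.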
    if (returnLogProbs)
    {
        copyBufferManager.copy(*decoder.getDecoderState().getCumLogProbs(), *decoderOutputBuffers.cumLogProbsHost);
        copyBufferManager.copy(*decoder.getDecoderState().getLogProbs(), *decoderOutputBuffers.logProbsHost);
    }
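
    // Speculative decoding modes that predict draft tokens also need the next draft tokens
    // (and their lengths, for variable draft lengths) copied to the host.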
    if (modelConfig.getSpeculativeDecodingMode().predictsDraftTokens())
    {
        // TODO(rkobus): keep data on device for next iteration
        decoderBuffers.draftBuffers.nextDraftTokensDevice = decoder.getDecoderState().getNextDraftTokens();
        copyBufferManager.copy(
            *decoderBuffers.draftBuffers.nextDraftTokensDevice, *decoderBuffers.draftBuffers.nextDraftTokensHost);

        if (modelConfig.getSpeculativeDecodingMode().variableDraftLength())
        {
            decoderBuffers.draftBuffers.nextDraftTokensLengthsDevice
                = decoder.getDecoderState().getNextDraftTokensLengths();
            decoderBuffers.draftBuffers.prevDraftTokensLengthsDevice
                = decoder.getDecoderState().getPrevDraftTokensLengths();
            copyBufferManager.copy(*decoderBuffers.draftBuffers.nextDraftTokensLengthsDevice,
                *decoderBuffers.draftBuffers.nextDraftTokensLengthsHost);
            copyBufferManager.copy(*decoderBuffers.draftBuffers.prevDraftTokensLengthsDevice,
                *decoderBuffers.draftBuffers.prevDraftTokensLengthsHost);
        }
    }
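
    // For KV cache rewind, the accepted-lengths cumulative sum and packed accepted paths stay on the
    // device; only the buffer handles are updated here, no host copy is issued.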
    if (modelConfig.getSpeculativeDecodingMode().needsKVCacheRewind())
    {
        decoderBuffers.draftBuffers.acceptedLengthsCumSumDevice = decoder.getDecoderState().getAcceptedLengthsCumSum();
        decoderBuffers.draftBuffers.acceptedPackedPathsDevice = decoder.getDecoderState().getAcceptedPackedPaths();
    }

    runtime::CudaEvent copyEvent{};
    copyBufferManager.getStream().record(copyEvent);
    // Store the event for later sync. Sync stream before calling next decoder. Sync host before updating requests.

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return copyEvent;
}
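
// A minimal caller-side sketch (hypothetical, not part of this translation unit): the returned event is
// recorded on the copy stream, so the host must synchronize on it before reading the *Host buffers, e.g.
//
//     auto copyEvent = UpdateDecoderBuffers{}(modelConfig, decoderBuffers, decoderOutputBuffers,
//         copyBufferManager, decoder, /*returnLogProbs=*/true, decoderFinishEvent);
//     copyEvent.synchronize(); // host-side wait before updating requests from the host buffers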
} // namespace tensorrt_llm::batch_manager