mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
* Use updateDecoderBuffers in python decoder. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Fix synchronize in trtllm decoder. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Enable by default. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Use guided_decoder to setup seqslots and free them. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Use always decode_async and update_requests. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Update decoder buffers. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Fix speculative decoding tests. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Send new_tensors_host instead of assuming dict. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Make default False in enable_trtllm_decoder. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Partially fix mtp, partially fix py_executor. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Update request states before sending disagg ctx cache. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Fix disagg test for torch decoder. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Make isend_tensor_list and recv_tensor_list for sending the tensors_host. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Formatting. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Fix rebase. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Add disagg serving case to guided decoder. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Get overlap scheduling to work. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Update cutlass to main. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Update after rebasing. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Formatting. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Update to use decode async and update requests. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Properly pass information to update_requests Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Formatting. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Make disaggregated serving a step closer to working. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Fix rebase. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Fix rebase and format. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Copy new device tokens more pythonic. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Restore MTP add dummy reqs. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Add ordereddict import to py_executor. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Formatting. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Added seq slot manager. Add test. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Use transmission for single tensor except when list of tensors is received. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Add TRTLLMDecoder allocation to estimate max kv cache tokens. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Add stream synchronization Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Formatting. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Make memory calculation of decoder adapt to the chosen decoder. Recognize decoder option passed in executorconfig. Make overlap scheduler test run on TinyLlama. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Format Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Add decoder creation to estimate max kv. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Formatting. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Update submodule UCXX inline with main. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --------- Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
91 lines
4.2 KiB
C++
91 lines
4.2 KiB
C++
/*
|
|
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "tensorrt_llm/batch_manager/updateDecoderBuffers.h"
|
|
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
|
|
#include "tensorrt_llm/common/nvtxUtils.h"
|
|
#include "tensorrt_llm/runtime/iTensor.h"
|
|
|
|
namespace tensorrt_llm::batch_manager
|
|
{
|
|
|
|
using BufferManager = tensorrt_llm::runtime::BufferManager;
|
|
using TensorPtr = runtime::ITensor::SharedPtr;
|
|
using ITensor = runtime::ITensor;
|
|
using SizeType32 = tensorrt_llm::runtime::SizeType32;
|
|
|
|
runtime::CudaEvent UpdateDecoderBuffers::operator()(runtime::ModelConfig const& modelConfig,
|
|
DecoderBuffers& decoderBuffers, DecoderOutputBuffers& decoderOutputBuffers,
|
|
runtime::BufferManager const& copyBufferManager, runtime::GptDecoderBatched const& decoder, bool returnLogProbs,
|
|
runtime::CudaEvent const& decoderFinishEvent) const
|
|
{
|
|
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
|
NVTX3_SCOPED_RANGE(updateDecoderBuffers);
|
|
|
|
// Chain copy after decoder event, using a different stream
|
|
copyBufferManager.getStream().wait(decoderFinishEvent);
|
|
|
|
copyBufferManager.copy(*decoder.getDecoderState().getAllNewTokens(), *decoderOutputBuffers.newOutputTokensHost);
|
|
copyBufferManager.copy(
|
|
*decoder.getDecoderState().getJointDecodingOutput().lengths, *decoderOutputBuffers.sequenceLengthsHost);
|
|
|
|
auto const finishedSumDevice = decoder.getDecoderState().getFinishedSum();
|
|
copyBufferManager.copy(*finishedSumDevice, *decoderOutputBuffers.finishedSumHost);
|
|
auto const finishReasonsDevice = decoder.getDecoderState().getFinishReasons();
|
|
copyBufferManager.copy(*finishReasonsDevice, *decoderOutputBuffers.finishReasonsHost);
|
|
|
|
if (returnLogProbs)
|
|
{
|
|
copyBufferManager.copy(*decoder.getDecoderState().getCumLogProbs(), *decoderOutputBuffers.cumLogProbsHost);
|
|
copyBufferManager.copy(*decoder.getDecoderState().getLogProbs(), *decoderOutputBuffers.logProbsHost);
|
|
}
|
|
|
|
if (modelConfig.getSpeculativeDecodingMode().predictsDraftTokens())
|
|
{
|
|
// TODO(rkobus): keep data on device for next iteration
|
|
decoderBuffers.draftBuffers.nextDraftTokensDevice = decoder.getDecoderState().getNextDraftTokens();
|
|
copyBufferManager.copy(
|
|
*decoderBuffers.draftBuffers.nextDraftTokensDevice, *decoderBuffers.draftBuffers.nextDraftTokensHost);
|
|
|
|
if (modelConfig.getSpeculativeDecodingMode().variableDraftLength())
|
|
{
|
|
decoderBuffers.draftBuffers.nextDraftTokensLengthsDevice
|
|
= decoder.getDecoderState().getNextDraftTokensLengths();
|
|
decoderBuffers.draftBuffers.prevDraftTokensLengthsDevice
|
|
= decoder.getDecoderState().getPrevDraftTokensLengths();
|
|
copyBufferManager.copy(*decoderBuffers.draftBuffers.nextDraftTokensLengthsDevice,
|
|
*decoderBuffers.draftBuffers.nextDraftTokensLengthsHost);
|
|
copyBufferManager.copy(*decoderBuffers.draftBuffers.prevDraftTokensLengthsDevice,
|
|
*decoderBuffers.draftBuffers.prevDraftTokensLengthsHost);
|
|
}
|
|
}
|
|
|
|
if (modelConfig.getSpeculativeDecodingMode().needsKVCacheRewind())
|
|
{
|
|
decoderBuffers.draftBuffers.acceptedLengthsCumSumDevice = decoder.getDecoderState().getAcceptedLengthsCumSum();
|
|
decoderBuffers.draftBuffers.acceptedPackedPathsDevice = decoder.getDecoderState().getAcceptedPackedPaths();
|
|
}
|
|
|
|
runtime::CudaEvent copyEvent{};
|
|
copyBufferManager.getStream().record(copyEvent);
|
|
// Store the event for later sync. Sync stream before calling next decoder. Sync host before updating requests.
|
|
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
|
return copyEvent;
|
|
}
|
|
|
|
} // namespace tensorrt_llm::batch_manager
|