/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Out-of-line member definitions for tensorrt_llm::executor::ContextPhaseParams.
//
// NOTE(review): in this copy of the file the header name of the last #include and every
// angle-bracket template argument list (after std::optional, std::vector, su::VectorWrapBuf,
// const_cast, std::make_unique, static_cast and std::default_delete) are missing. They look
// like they were stripped by a text-extraction step. Restore them from version control and
// check them against the class declaration before building -- do not guess them here.

#include "tensorrt_llm/executor/dataTransceiverState.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/serialization.h"
#include "tensorrt_llm/executor/serializeUtils.h"
// NOTE(review): header name missing on the next directive (see note above).
#include

namespace su = tensorrt_llm::executor::serialize_utils;

namespace tensorrt_llm::executor
{

// Builds the parameters around an existing state object supplied as an untyped pointer.
//
// `state` is wrapped in a StatePtr whose deleter is ContextPhaseParams::deleter, so this object
// becomes responsible for destroying it. `deleter` treats the pointee as a DataTransceiverState,
// so `state` presumably has to point to one (or be null) -- TODO confirm against callers.
// `firstGenTokens`, `draftTokens` and `disaggInfoEndpoint` are taken by value and moved into the
// members; `reqId` and `ctxDpRank` are copied.
ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, void* state,
    std::optional draftTokens, std::optional ctxDpRank, std::optional disaggInfoEndpoint)
    : mReqId{reqId}
    , mFirstGenTokens{std::move(firstGenTokens)}
    , mState{StatePtr{state, deleter}}
    , mDraftTokens{std::move(draftTokens)}
    , mCtxDpRank{ctxDpRank}
    , mDisaggInfoEndpoint{std::move(disaggInfoEndpoint)}
{
}

// Builds the parameters without any state object.
//
// mState is not named in the initializer list, so it gets whatever default the class declaration
// gives it (presumably an empty StatePtr -- confirm in the header).
ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, std::optional draftTokens,
    std::optional ctxDpRank, std::optional disaggInfoEndpoint)
    : mReqId{reqId}
    , mFirstGenTokens{std::move(firstGenTokens)}
    , mDraftTokens{std::move(draftTokens)}
    , mCtxDpRank{ctxDpRank}
    , mDisaggInfoEndpoint{std::move(disaggInfoEndpoint)}
{
}

// Builds the parameters from a serialized state buffer.
//
// The buffer is exposed as a std::istream and passed to
// Serialization::deserializeDataTransceiverState; the result is moved into a heap allocation
// that mState then owns (with ContextPhaseParams::deleter as its deleter).
ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId,
    std::vector const& serializedState, std::optional draftTokens, std::optional ctxDpRank,
    std::optional disaggInfoEndpoint)
    : mReqId{reqId}
    , mFirstGenTokens{std::move(firstGenTokens)}
    , mDraftTokens{std::move(draftTokens)}
    , mCtxDpRank{ctxDpRank}
    , mDisaggInfoEndpoint{std::move(disaggInfoEndpoint)}
{
    // Constness is cast away only so the vector can be handed to su::VectorWrapBuf.
    // NOTE(review): presumably the buffer is only read through the istream -- confirm that
    // VectorWrapBuf never writes to it, since the caller passed it as const.
    su::VectorWrapBuf strbuf(const_cast&>(serializedState));
    std::istream is(&strbuf);
    auto dataTransceiverState = Serialization::deserializeDataTransceiverState(is);
    // Allocate through make_unique, then release into StatePtr so that `deleter` frees it.
    auto dataTransceiverStatePtr = std::make_unique(std::move(dataTransceiverState));
    mState = StatePtr{dataTransceiverStatePtr.release(), deleter};
}

// Copy constructor: copies the plain members and, when `other` holds a state object, allocates
// an independent copy of it. When `other` has no state, mState is left at its default.
ContextPhaseParams::ContextPhaseParams(ContextPhaseParams const& other)
    : mReqId{other.mReqId}
    , mFirstGenTokens{other.mFirstGenTokens}
    , mDraftTokens{other.mDraftTokens}
    , mCtxDpRank{other.mCtxDpRank}
    , mDisaggInfoEndpoint{other.mDisaggInfoEndpoint}
{
    // Since the internal header files implement the destructor while using the declaration of this
    // type, a `unique_ptr` with a custom destructor member is used here.
    if (other.mState)
    {
        // The state is stored untyped; cast it back before copy-constructing the new object.
        auto* otherState = static_cast(other.mState.get());
        mState = StatePtr{std::make_unique(*otherState).release(), deleter};
    }
}

// Move constructor: member-wise move.
ContextPhaseParams::ContextPhaseParams(ContextPhaseParams&&) noexcept = default;

// Copy assignment: builds a temporary copy of `other` (reusing the copy constructor's state
// duplication) and move-assigns that temporary into *this.
ContextPhaseParams& ContextPhaseParams::operator=(ContextPhaseParams const& other)
{
    *this = ContextPhaseParams{other};
    return *this;
}

// Move assignment: member-wise move.
ContextPhaseParams& ContextPhaseParams::operator=(ContextPhaseParams&&) noexcept = default;

// Defaulted here, in the implementation file, rather than in the class declaration.
ContextPhaseParams::~ContextPhaseParams() = default;

// Returns a reference to the stored first generation tokens (lvalue objects only).
VecTokens const& ContextPhaseParams::getFirstGenTokens() const& noexcept
{
    return mFirstGenTokens;
}

// Replaces the stored first generation tokens with a copy of `firstGenTokens`.
void ContextPhaseParams::setFirstGenTokens(VecTokens const& firstGenTokens) noexcept
{
    mFirstGenTokens = firstGenTokens;
}

// Returns a reference to the stored optional draft tokens (lvalue objects only).
std::optional const& ContextPhaseParams::getDraftTokens() const& noexcept
{
    return mDraftTokens;
}

// Replaces the stored optional draft tokens with a copy of `draftTokens`.
void ContextPhaseParams::setDraftTokens(std::optional const& draftTokens) noexcept
{
    mDraftTokens = draftTokens;
}

// Moves the first generation tokens out of an rvalue ContextPhaseParams and returns them.
VecTokens ContextPhaseParams::popFirstGenTokens() && noexcept
{
    return std::move(mFirstGenTokens);
}

// Returns the stored request id.
ContextPhaseParams::RequestIdType ContextPhaseParams::getReqId() const noexcept
{
    return mReqId;
}

// Overwrites the stored request id.
void ContextPhaseParams::setReqId(RequestIdType const& reqId) noexcept
{
    mReqId = reqId;
}

// Returns the untyped state pointer held by mState without giving up ownership.
void const* ContextPhaseParams::getState() const noexcept
{
    return mState.get();
}

// Non-const overload of getState(); ownership stays with this object.
void* ContextPhaseParams::getState() noexcept
{
    return mState.get();
}

// Returns the result of Serialization::serialize applied to the held state object.
//
// NOTE(review): mState.get() is dereferenced without a null check, unlike the copy constructor
// and operator== which both test mState first. It looks like callers must make sure a state is
// present (e.g. getState() != nullptr) before calling this -- confirm, or add a guard.
// The function is declared noexcept, so an exception escaping Serialization::serialize would
// terminate the program.
std::vector ContextPhaseParams::getSerializedState() const noexcept
{
    return Serialization::serialize(*static_cast(mState.get()));
}

// Releases the state pointer from mState and returns it; afterwards this object no longer holds
// (or destroys) the state, so the caller is responsible for it.
void* ContextPhaseParams::releaseState() noexcept
{
    return mState.release();
}

// Returns a copy of the stored optional context DP rank.
std::optional ContextPhaseParams::getCtxDpRank() const noexcept
{
    return mCtxDpRank;
}

// Overwrites the stored optional context DP rank.
void ContextPhaseParams::setCtxDpRank(std::optional const& ctxDpRank) noexcept
{
    mCtxDpRank = ctxDpRank;
}

// Returns a reference to the stored optional disaggregated-info endpoint.
std::optional const& ContextPhaseParams::getDisaggInfoEndpoint() const noexcept
{
    return mDisaggInfoEndpoint;
}

// Replaces the stored optional disaggregated-info endpoint with a copy of `disaggInfoEndpoint`.
void ContextPhaseParams::setDisaggInfoEndpoint(std::optional const& disaggInfoEndpoint) noexcept
{
    mDisaggInfoEndpoint = disaggInfoEndpoint;
}

// Deleter used by every StatePtr created in this file: casts the untyped pointer back and hands
// it to std::default_delete. StateT aliases `DataTransceiverState const`.
void ContextPhaseParams::deleter(void const* data)
{
    using StateT = DataTransceiverState const;
    std::default_delete()(static_cast(data));
}

// Equality: all plain members must match and both sides must agree on whether a state object is
// present; when both hold one, the pointed-to state objects are compared by value.
bool ContextPhaseParams::operator==(ContextPhaseParams const& other) const noexcept
{
    if (mFirstGenTokens != other.mFirstGenTokens || mReqId != other.mReqId || mDraftTokens != other.mDraftTokens
        || mDisaggInfoEndpoint != other.mDisaggInfoEndpoint || mCtxDpRank != other.mCtxDpRank
        || static_cast(mState) != static_cast(other.mState))
    {
        return false;
    }
    // Reaching here means both sides agree on state presence, so a null mState implies
    // other.mState is null too and the short-circuit avoids dereferencing either pointer.
    return !mState || *static_cast(mState.get()) == *static_cast(other.mState.get());
}

} // namespace tensorrt_llm::executor