From ead4fc3336df2f734cbdf4a409694ca112470cf1 Mon Sep 17 00:00:00 2001
From: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
Date: Wed, 7 Jan 2026 09:27:47 +0000
Subject: [PATCH] change context params and disagg params

Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
---
 cpp/include/tensorrt_llm/executor/executor.h  | 33 ++++++-
 cpp/tensorrt_llm/batch_manager/llmRequest.cpp | 11 ++-
 .../executor/contextPhaseParams.cpp           | 92 ++++++++++++++++---
 cpp/tensorrt_llm/executor/serialization.cpp   | 15 ++-
 .../nanobind/batch_manager/bindings.cpp       |  2 +-
 .../nanobind/executor/request.cpp             | 40 +++++---
 .../pybind/batch_manager/bindings.cpp         |  2 +-
 cpp/tensorrt_llm/pybind/executor/request.cpp  | 56 +++++++----
 .../pyexecutor/executor_request_queue.py      |  5 +-
 .../_torch/pyexecutor/kv_cache_transceiver.py |  6 +-
 tensorrt_llm/_torch/pyexecutor/llm_request.py |  6 ++
 tensorrt_llm/disaggregated_params.py          | 12 ++-
 tensorrt_llm/executor/base_worker.py          |  3 +
 tensorrt_llm/executor/result.py               |  3 +
 tensorrt_llm/serve/openai_disagg_service.py   | 13 ++-
 tensorrt_llm/serve/openai_protocol.py         | 13 ++-
 .../bindings/test_executor_bindings.py        |  9 +-
 17 files changed, 252 insertions(+), 69 deletions(-)

diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h
index 787fa0bb7e..3015ead0f7 100644
--- a/cpp/include/tensorrt_llm/executor/executor.h
+++ b/cpp/include/tensorrt_llm/executor/executor.h
@@ -442,11 +442,16 @@ class ContextPhaseParams
 public:
     using RequestIdType = std::uint64_t;
-    ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, std::optional draftTokens);
-    ContextPhaseParams(
-        VecTokens firstGenTokens, RequestIdType reqId, void* state, std::optional draftTokens);
+    ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, std::optional draftTokens,
+        std::optional disaggId = std::nullopt, std::optional ctxDpRank = std::nullopt,
+        std::optional disaggInfoEndpoint = std::nullopt);
+    ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, void* state, std::optional draftTokens,
+        std::optional disaggId = std::nullopt, std::optional ctxDpRank = std::nullopt,
+        std::optional disaggInfoEndpoint = std::nullopt);
     ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, std::vector const& serializedState,
-        std::optional draftTokens);
+        std::optional draftTokens, std::optional disaggId = std::nullopt,
+        std::optional ctxDpRank = std::nullopt,
+        std::optional disaggInfoEndpoint = std::nullopt);
     ContextPhaseParams(ContextPhaseParams const&);
     ContextPhaseParams(ContextPhaseParams&&) noexcept;
@@ -457,15 +462,24 @@ public:
     [[nodiscard]] bool operator==(ContextPhaseParams const&) const noexcept;
     [[nodiscard]] VecTokens const& getFirstGenTokens() const& noexcept;
+    void setFirstGenTokens(VecTokens firstGenTokens) noexcept;
     [[nodiscard]] std::optional const& getDraftTokens() const& noexcept;
+    void setDraftTokens(std::optional draftTokens) noexcept;
     [[nodiscard]] VecTokens popFirstGenTokens() && noexcept;
     [[nodiscard]] RequestIdType getReqId() const noexcept;
-
+    void setReqId(RequestIdType reqId) noexcept;
     [[nodiscard]] void const* getState() const noexcept;
     [[nodiscard]] void* getState() noexcept;
     [[nodiscard]] void* releaseState() noexcept;
     [[nodiscard]] std::vector getSerializedState() const noexcept;
+    [[nodiscard]] std::optional getDisaggId() const noexcept;
+    void setDisaggId(std::optional disaggId) noexcept;
+    [[nodiscard]] std::optional getCtxDpRank() const noexcept;
+    void setCtxDpRank(std::optional
ctxDpRank) noexcept; + [[nodiscard]] std::optional const& getDisaggInfoEndpoint() const noexcept; + void setDisaggInfoEndpoint(std::optional disaggInfoEndpoint) noexcept; + private: friend class Serialization; static void deleter(void const* data); @@ -482,6 +496,15 @@ private: /// @brief The draft tokens generated by context executor std::optional mDraftTokens; + + /// @brief The disaggregated id + std::optional mDisaggId; + + /// @brief The context phase data parallel rank + std::optional mCtxDpRank; + + /// @brief The disaggregated info endpoint + std::optional mDisaggInfoEndpoint; }; /// @brief Configuration for speculative decoding (both draft and target models) diff --git a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp index e664021db0..233242400f 100644 --- a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp @@ -99,13 +99,16 @@ std::optional LlmRequest::createResult(bool useFastLogits, int } if (!hasDraftTokens()) { - result.contextPhaseParams = executor::ContextPhaseParams{ - std::move(firstGenTokens), mRequestId, mContextPhaseParams.value().releaseState(), std::nullopt}; + result.contextPhaseParams = executor::ContextPhaseParams{std::move(firstGenTokens), mRequestId, + mContextPhaseParams.value().releaseState(), std::nullopt, mContextPhaseParams.value().getDisaggId(), + mContextPhaseParams.value().getCtxDpRank(), mContextPhaseParams.value().getDisaggInfoEndpoint()}; } else { - result.contextPhaseParams = executor::ContextPhaseParams{ - std::move(firstGenTokens), mRequestId, mContextPhaseParams.value().releaseState(), *getDraftTokens()}; + result.contextPhaseParams = executor::ContextPhaseParams{std::move(firstGenTokens), mRequestId, + mContextPhaseParams.value().releaseState(), *getDraftTokens(), + mContextPhaseParams.value().getDisaggId(), mContextPhaseParams.value().getCtxDpRank(), + mContextPhaseParams.value().getDisaggInfoEndpoint()}; } } diff --git a/cpp/tensorrt_llm/executor/contextPhaseParams.cpp b/cpp/tensorrt_llm/executor/contextPhaseParams.cpp index 4e66d567bb..cdfd5e63e4 100644 --- a/cpp/tensorrt_llm/executor/contextPhaseParams.cpp +++ b/cpp/tensorrt_llm/executor/contextPhaseParams.cpp @@ -27,28 +27,41 @@ namespace su = tensorrt_llm::executor::serialize_utils; namespace tensorrt_llm::executor { -ContextPhaseParams::ContextPhaseParams( - VecTokens firstGenTokens, RequestIdType reqId, void* state, std::optional draftTokens) +ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, void* state, + std::optional draftTokens, std::optional disaggId, std::optional ctxDpRank, + std::optional disaggInfoEndpoint) : mReqId{reqId} , mFirstGenTokens{std::move(firstGenTokens)} , mState{StatePtr{state, deleter}} , mDraftTokens{std::move(draftTokens)} -{ -} - -ContextPhaseParams::ContextPhaseParams( - VecTokens firstGenTokens, RequestIdType reqId, std::optional draftTokens) - : mReqId{reqId} - , mFirstGenTokens{std::move(firstGenTokens)} - , mDraftTokens{std::move(draftTokens)} + , mDisaggId{std::move(disaggId)} + , mCtxDpRank{ctxDpRank} + , mDisaggInfoEndpoint{std::move(disaggInfoEndpoint)} { } ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, - std::vector const& serializedState, std::optional draftTokens) + std::optional draftTokens, std::optional disaggId, std::optional ctxDpRank, + std::optional disaggInfoEndpoint) : mReqId{reqId} , mFirstGenTokens{std::move(firstGenTokens)} , mDraftTokens{std::move(draftTokens)} + , 
mDisaggId{std::move(disaggId)} + , mCtxDpRank{ctxDpRank} + , mDisaggInfoEndpoint{std::move(disaggInfoEndpoint)} +{ +} + +ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, + std::vector const& serializedState, std::optional draftTokens, + std::optional disaggId, std::optional ctxDpRank, + std::optional disaggInfoEndpoint) + : mReqId{reqId} + , mFirstGenTokens{std::move(firstGenTokens)} + , mDraftTokens{std::move(draftTokens)} + , mDisaggId{std::move(disaggId)} + , mCtxDpRank{ctxDpRank} + , mDisaggInfoEndpoint{std::move(disaggInfoEndpoint)} { su::VectorWrapBuf strbuf(const_cast&>(serializedState)); @@ -60,12 +73,15 @@ ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType r } ContextPhaseParams::ContextPhaseParams(ContextPhaseParams const& other) + : mReqId{other.mReqId} + , mFirstGenTokens{other.mFirstGenTokens} + , mDraftTokens{other.mDraftTokens} + , mDisaggId{other.mDisaggId} + , mCtxDpRank{other.mCtxDpRank} + , mDisaggInfoEndpoint{other.mDisaggInfoEndpoint} { // Since the internal header files implement the destructor while using the declaration of this // type, a `unique_ptr` with a custom destructor member is used here. - mReqId = other.mReqId; - mFirstGenTokens = other.mFirstGenTokens; - mDraftTokens = other.mDraftTokens; if (other.mState) { auto* otherState = static_cast(other.mState.get()); @@ -90,11 +106,21 @@ VecTokens const& ContextPhaseParams::getFirstGenTokens() const& noexcept return mFirstGenTokens; } +void ContextPhaseParams::setFirstGenTokens(VecTokens firstGenTokens) noexcept +{ + mFirstGenTokens = std::move(firstGenTokens); +} + std::optional const& ContextPhaseParams::getDraftTokens() const& noexcept { return mDraftTokens; } +void ContextPhaseParams::setDraftTokens(std::optional draftTokens) noexcept +{ + mDraftTokens = std::move(draftTokens); +} + VecTokens ContextPhaseParams::popFirstGenTokens() && noexcept { return std::move(mFirstGenTokens); @@ -105,6 +131,11 @@ ContextPhaseParams::RequestIdType ContextPhaseParams::getReqId() const noexcept return mReqId; } +void ContextPhaseParams::setReqId(RequestIdType reqId) noexcept +{ + mReqId = reqId; +} + void const* ContextPhaseParams::getState() const noexcept { return mState.get(); @@ -125,6 +156,36 @@ void* ContextPhaseParams::releaseState() noexcept return mState.release(); } +std::optional ContextPhaseParams::getDisaggId() const noexcept +{ + return mDisaggId; +} + +void ContextPhaseParams::setDisaggId(std::optional disaggId) noexcept +{ + mDisaggId = disaggId; +} + +std::optional ContextPhaseParams::getCtxDpRank() const noexcept +{ + return mCtxDpRank; +} + +void ContextPhaseParams::setCtxDpRank(std::optional ctxDpRank) noexcept +{ + mCtxDpRank = ctxDpRank; +} + +std::optional const& ContextPhaseParams::getDisaggInfoEndpoint() const noexcept +{ + return mDisaggInfoEndpoint; +} + +void ContextPhaseParams::setDisaggInfoEndpoint(std::optional disaggInfoEndpoint) noexcept +{ + mDisaggInfoEndpoint = std::move(disaggInfoEndpoint); +} + void ContextPhaseParams::deleter(void const* data) { using StateT = DataTransceiverState const; @@ -134,7 +195,8 @@ void ContextPhaseParams::deleter(void const* data) bool ContextPhaseParams::operator==(ContextPhaseParams const& other) const noexcept { if (mFirstGenTokens != other.mFirstGenTokens || mReqId != other.mReqId || mDraftTokens != other.mDraftTokens - || static_cast(mState) != static_cast(other.mState)) + || mDisaggId != other.mDisaggId || mDisaggInfoEndpoint != other.mDisaggInfoEndpoint + || mCtxDpRank != other.mCtxDpRank 
|| static_cast(mState) != static_cast(other.mState)) { return false; } diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp index 8e79563b7d..f74a4b0ff4 100644 --- a/cpp/tensorrt_llm/executor/serialization.cpp +++ b/cpp/tensorrt_llm/executor/serialization.cpp @@ -652,14 +652,19 @@ ContextPhaseParams Serialization::deserializeContextPhaseParams(std::istream& is auto reqId = su::deserialize(is); auto firstGenTokens = su::deserialize(is); auto draftTokens = su::deserialize(is); + auto disaggId = su::deserialize(is); + auto ctxDpRank = su::deserialize(is); + auto disaggInfoEndpoint = su::deserialize(is); auto hasState = su::deserialize(is); if (hasState) { auto state = std::make_unique(); *state = deserializeDataTransceiverState(is); - return ContextPhaseParams{std::move(firstGenTokens), reqId, state.release(), std::move(draftTokens)}; + return ContextPhaseParams{std::move(firstGenTokens), reqId, state.release(), std::move(draftTokens), + std::move(disaggId), ctxDpRank, std::move(disaggInfoEndpoint)}; } - return ContextPhaseParams{std::move(firstGenTokens), reqId, nullptr, std::move(draftTokens)}; + return ContextPhaseParams{std::move(firstGenTokens), reqId, std::move(draftTokens), std::move(disaggId), ctxDpRank, + std::move(disaggInfoEndpoint)}; } void Serialization::serialize(ContextPhaseParams const& contextPhaseParams, std::ostream& os) @@ -667,6 +672,9 @@ void Serialization::serialize(ContextPhaseParams const& contextPhaseParams, std: su::serialize(contextPhaseParams.mReqId, os); su::serialize(contextPhaseParams.mFirstGenTokens, os); su::serialize(contextPhaseParams.mDraftTokens, os); + su::serialize(contextPhaseParams.mDisaggId, os); + su::serialize(contextPhaseParams.mCtxDpRank, os); + su::serialize(contextPhaseParams.mDisaggInfoEndpoint, os); su::serialize(static_cast(contextPhaseParams.mState), os); if (contextPhaseParams.mState) { @@ -680,6 +688,9 @@ size_t Serialization::serializedSize(ContextPhaseParams const& contextPhaseParam totalSize += su::serializedSize(contextPhaseParams.mReqId); totalSize += su::serializedSize(contextPhaseParams.mFirstGenTokens); totalSize += su::serializedSize(contextPhaseParams.mDraftTokens); + totalSize += su::serializedSize(contextPhaseParams.mDisaggId); + totalSize += su::serializedSize(contextPhaseParams.mCtxDpRank); + totalSize += su::serializedSize(contextPhaseParams.mDisaggInfoEndpoint); totalSize += su::serializedSize(bool{}); if (contextPhaseParams.mState) { diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp index 17c27f43be..4d42b4bb36 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp @@ -165,7 +165,7 @@ void initBindings(nb::module_& m) "context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition) .def_prop_ro("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen) .def_prop_rw("guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams) - .def_prop_ro("context_phase_params", &GenLlmReq::getContextPhaseParams) + .def_prop_rw("context_phase_params", &GenLlmReq::getContextPhaseParams, &GenLlmReq::setContextPhaseParams) .def_prop_ro("is_context_only_request", &GenLlmReq::isContextOnlyRequest) .def_prop_ro("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest) .def_prop_ro("is_generation_complete_state", &GenLlmReq::isGenerationCompleteState) diff --git 
a/cpp/tensorrt_llm/nanobind/executor/request.cpp b/cpp/tensorrt_llm/nanobind/executor/request.cpp index db05409d86..cdcf94a790 100644 --- a/cpp/tensorrt_llm/nanobind/executor/request.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/request.cpp @@ -442,14 +442,16 @@ void initRequestBindings(nb::module_& m) { auto serializedState = self.getSerializedState(); return nb::make_tuple(self.getFirstGenTokens(), self.getReqId(), - nb::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens()); + nb::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens(), self.getDisaggId(), + self.getCtxDpRank(), self.getDisaggInfoEndpoint()); } - return nb::make_tuple(self.getFirstGenTokens(), self.getReqId(), nb::none(), self.getDraftTokens()); + return nb::make_tuple(self.getFirstGenTokens(), self.getReqId(), nb::none(), self.getDraftTokens(), + self.getDisaggId(), self.getCtxDpRank(), self.getDisaggInfoEndpoint()); }; auto ContextPhaseParamsSetState = [](tle::ContextPhaseParams& contextPhaseParams, nb::tuple const& state) { - if (state.size() != 4) + if (state.size() != 7) { throw std::runtime_error("Invalid ContextPhaseParams state!"); } @@ -460,13 +462,15 @@ void initRequestBindings(nb::module_& m) new (&contextPhaseParams) tle::ContextPhaseParams(nb::cast(state[0]), nb::cast(state[1]), std::vector(opaque_state_str_view.begin(), opaque_state_str_view.end()), - nb::cast>(state[3])); + nb::cast>(state[3]), nb::cast>(state[4]), + nb::cast>(state[5]), nb::cast>(state[6])); } else { new (&contextPhaseParams) tle::ContextPhaseParams(nb::cast(state[0]), nb::cast(state[1]), - nb::cast>(state[3])); + nb::cast>(state[3]), nb::cast>(state[4]), + nb::cast>(state[5]), nb::cast>(state[6])); } }; @@ -475,25 +479,37 @@ void initRequestBindings(nb::module_& m) "__init__", [](tle::ContextPhaseParams& self, VecTokens const& first_gen_tokens, tle::ContextPhaseParams::RequestIdType req_id, std::optional const& opaque_state, - std::optional const& draft_tokens) + std::optional const& draft_tokens, std::optional const& disagg_id, + std::optional const& ctx_dp_rank, std::optional const& disagg_info_endpoint) { if (opaque_state) { auto opaque_state_str_view = std::string_view(opaque_state.value().c_str(), opaque_state.value().size()); new (&self) tle::ContextPhaseParams(first_gen_tokens, req_id, - std::vector(opaque_state_str_view.begin(), opaque_state_str_view.end()), draft_tokens); + std::vector(opaque_state_str_view.begin(), opaque_state_str_view.end()), draft_tokens, + disagg_id, ctx_dp_rank, disagg_info_endpoint); } else { - new (&self) tle::ContextPhaseParams(first_gen_tokens, req_id, draft_tokens); + new (&self) tle::ContextPhaseParams( + first_gen_tokens, req_id, draft_tokens, disagg_id, ctx_dp_rank, disagg_info_endpoint); } }, nb::arg("first_gen_tokens"), nb::arg("req_id"), nb::arg("opaque_state").none(), - nb::arg("draft_tokens").none()) - .def_prop_ro("first_gen_tokens", [](tle::ContextPhaseParams const& self) { return self.getFirstGenTokens(); }) - .def_prop_ro("draft_tokens", [](tle::ContextPhaseParams const& self) { return self.getDraftTokens(); }) - .def_prop_ro("req_id", &tle::ContextPhaseParams::getReqId) + nb::arg("draft_tokens").none(), nb::arg("disagg_id").none(), nb::arg("ctx_dp_rank").none(), + nb::arg("disagg_info_endpoint").none()) + .def_prop_rw( + "first_gen_tokens", [](tle::ContextPhaseParams const& self) { return self.getFirstGenTokens(); }, + [](tle::ContextPhaseParams& self, VecTokens const& tokens) { self.setFirstGenTokens(tokens); }) + .def_prop_rw( + 
"draft_tokens", [](tle::ContextPhaseParams const& self) { return self.getDraftTokens(); }, + [](tle::ContextPhaseParams& self, std::optional const& tokens) { self.setDraftTokens(tokens); }) + .def_prop_rw("req_id", &tle::ContextPhaseParams::getReqId, &tle::ContextPhaseParams::setReqId) + .def_prop_rw("disagg_id", &tle::ContextPhaseParams::getDisaggId, &tle::ContextPhaseParams::setDisaggId) + .def_prop_rw("ctx_dp_rank", &tle::ContextPhaseParams::getCtxDpRank, &tle::ContextPhaseParams::setCtxDpRank) + .def_prop_rw("disagg_info_endpoint", &tle::ContextPhaseParams::getDisaggInfoEndpoint, + &tle::ContextPhaseParams::setDisaggInfoEndpoint) .def_prop_ro("opaque_state", [](tle::ContextPhaseParams const& self) { diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp index 1d98b0c623..43edadd642 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp @@ -170,7 +170,7 @@ void initBindings(pybind11::module_& m) .def_property_readonly("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen) .def_property( "guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams) - .def_property_readonly("context_phase_params", &GenLlmReq::getContextPhaseParams) + .def_property("context_phase_params", &GenLlmReq::getContextPhaseParams, &GenLlmReq::setContextPhaseParams) .def_property_readonly("is_context_only_request", &GenLlmReq::isContextOnlyRequest) .def_property_readonly("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest) .def_property_readonly("is_generation_complete_state", &GenLlmReq::isGenerationCompleteState) diff --git a/cpp/tensorrt_llm/pybind/executor/request.cpp b/cpp/tensorrt_llm/pybind/executor/request.cpp index 2e9dae860e..3879d0fd70 100644 --- a/cpp/tensorrt_llm/pybind/executor/request.cpp +++ b/cpp/tensorrt_llm/pybind/executor/request.cpp @@ -411,14 +411,16 @@ void initRequestBindings(pybind11::module_& m) { auto serializedState = self.getSerializedState(); return py::make_tuple(self.getFirstGenTokens(), self.getReqId(), - py::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens()); + py::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens(), self.getDisaggId(), + self.getCtxDpRank(), self.getDisaggInfoEndpoint()); } - return py::make_tuple(self.getFirstGenTokens(), self.getReqId(), py::none(), self.getDraftTokens()); + return py::make_tuple(self.getFirstGenTokens(), self.getReqId(), py::none(), self.getDraftTokens(), + self.getDisaggId(), self.getCtxDpRank(), self.getDisaggInfoEndpoint()); }; auto ContextPhaseParamsSetState = [](py::tuple const& state) { - if (state.size() != 4) + if (state.size() != 7) { throw std::runtime_error("Invalid ContextPhaseParams state!"); } @@ -429,28 +431,44 @@ void initRequestBindings(pybind11::module_& m) return std::make_unique(state[0].cast(), state[1].cast(), std::vector(opaque_state_str_view.begin(), opaque_state_str_view.end()), - state[3].cast>()); + state[3].cast>(), state[4].cast>(), + state[5].cast>(), state[6].cast>()); } return std::make_unique(state[0].cast(), - state[1].cast(), state[3].cast>()); + state[1].cast(), state[3].cast>(), + state[4].cast>(), state[5].cast>(), + state[6].cast>()); }; py::class_(m, "ContextPhaseParams") .def(py::init( - [](VecTokens const& first_gen_tokens, tle::ContextPhaseParams::RequestIdType req_id, - std::optional const& opaque_state, std::optional const& draft_tokens) - { - if (opaque_state) - 
{ - auto opaque_state_str_view = std::string_view(opaque_state.value().cast()); - return std::make_unique(first_gen_tokens, req_id, - std::vector(opaque_state_str_view.begin(), opaque_state_str_view.end()), draft_tokens); - } - return std::make_unique(first_gen_tokens, req_id, draft_tokens); - })) - .def_property_readonly("first_gen_tokens", &tle::ContextPhaseParams::getFirstGenTokens) - .def_property_readonly("draft_tokens", &tle::ContextPhaseParams::getDraftTokens) - .def_property_readonly("req_id", &tle::ContextPhaseParams::getReqId) + [](VecTokens const& first_gen_tokens, tle::ContextPhaseParams::RequestIdType req_id, + std::optional const& opaque_state, std::optional const& draft_tokens, + std::optional const& disagg_id, std::optional const& ctx_dp_rank, + std::optional const& disagg_info_endpoint) + { + if (opaque_state) + { + auto opaque_state_str_view = std::string_view(opaque_state.value().cast()); + return std::make_unique(first_gen_tokens, req_id, + std::vector(opaque_state_str_view.begin(), opaque_state_str_view.end()), + draft_tokens, disagg_id, ctx_dp_rank, disagg_info_endpoint); + } + return std::make_unique( + first_gen_tokens, req_id, draft_tokens, disagg_id, ctx_dp_rank, disagg_info_endpoint); + }), + py::arg("first_gen_tokens"), py::arg("req_id"), py::arg("opaque_state") = py::none(), + py::arg("draft_tokens") = py::none(), py::arg("disagg_id") = py::none(), + py::arg("ctx_dp_rank") = py::none(), py::arg("disagg_info_endpoint") = py::none()) + .def_property("first_gen_tokens", &tle::ContextPhaseParams::getFirstGenTokens, + &tle::ContextPhaseParams::setFirstGenTokens) + .def_property( + "draft_tokens", &tle::ContextPhaseParams::getDraftTokens, &tle::ContextPhaseParams::setDraftTokens) + .def_property("req_id", &tle::ContextPhaseParams::getReqId, &tle::ContextPhaseParams::setReqId) + .def_property("disagg_id", &tle::ContextPhaseParams::getDisaggId, &tle::ContextPhaseParams::setDisaggId) + .def_property("ctx_dp_rank", &tle::ContextPhaseParams::getCtxDpRank, &tle::ContextPhaseParams::setCtxDpRank) + .def_property("disagg_info_endpoint", &tle::ContextPhaseParams::getDisaggInfoEndpoint, + &tle::ContextPhaseParams::setDisaggInfoEndpoint) .def_property_readonly("opaque_state", [](tle::ContextPhaseParams const& self) { diff --git a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py index 161282e4c4..13bdb6d270 100644 --- a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py +++ b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py @@ -465,10 +465,13 @@ class ExecutorRequestQueue: new_requests, "py_scheduling_params") py_num_logprobs = self._collect_py_objects_from_requests( new_requests, "py_num_logprobs") + py_disaggregated_params = self._collect_py_objects_from_requests( + new_requests, "py_disaggregated_params") py_request_objects = tuple( filter(None, [ py_logits_post_processors, py_multimodal_data, - py_scheduling_params, py_num_logprobs + py_scheduling_params, py_num_logprobs, + py_disaggregated_params ])) else: py_request_objects = None diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py index 5616be7708..d090f81906 100644 --- a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py +++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py @@ -121,7 +121,11 @@ class BindKvCacheTransceiver(KvCacheTransceiver): cache_transceiver_config._to_pybind()) def respond_and_send_async(self, req: LlmRequest): - return 
self.impl.respond_and_send_async(req) + self.impl.respond_and_send_async(req) + if (req.py_disaggregated_params is not None + and req.py_disaggregated_params.disagg_id is not None): + req.context_phase_params.disagg_id = req.py_disaggregated_params.disagg_id + return def request_and_receive_sync(self, req: LlmRequest): return self.impl.request_and_receive_sync(req) diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 871fe4b9bf..914c1c17af 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -566,6 +566,8 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest): # currently, keep py_stop_words_list as python list, rather than tensor. self.py_stop_words_list = stop_words_list + self.py_disaggregated_params = None + self.py_result = PyResult( prompt_len=self.py_prompt_len, max_new_tokens=self.py_max_new_tokens, @@ -826,6 +828,10 @@ def executor_request_to_llm_request( py_multimodal_data=getattr(executor_request, "py_multimodal_data", None), kv_cache_retention_config=executor_request.kv_cache_retention_config) + + llm_request.py_disaggregated_params = getattr(executor_request, + "py_disaggregated_params", + None) if child_req_ids: for child_id in child_req_ids: llm_request.create_child_request(child_id) diff --git a/tensorrt_llm/disaggregated_params.py b/tensorrt_llm/disaggregated_params.py index 4c0680bc94..a80fe262f9 100644 --- a/tensorrt_llm/disaggregated_params.py +++ b/tensorrt_llm/disaggregated_params.py @@ -32,7 +32,9 @@ class DisaggregatedParams: ctx_request_id: Optional[int] = None opaque_state: Optional[bytes] = None draft_tokens: Optional[List[int]] = None - + disagg_id: Optional[int] = None + ctx_dp_rank: Optional[int] = None + ctx_info_endpoint: Optional[List[str]] = None # E-P Disaggregated Params multimodal_embedding_handles: Optional[List[Dict[str, Any]]] = ( None # multimodal embedding handles should be a list of cudaIPC handles for each mm_embedding @@ -45,7 +47,13 @@ class DisaggregatedParams: def get_context_phase_params(self) -> tllme.ContextPhaseParams: return tllme.ContextPhaseParams( - self.first_gen_tokens, self.ctx_request_id, self.opaque_state, self.draft_tokens + self.first_gen_tokens, + self.ctx_request_id, + self.opaque_state, + self.draft_tokens, + self.disagg_id, + self.ctx_dp_rank, + self.ctx_info_endpoint, ) def get_request_type(self) -> tllme.RequestType: diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index ce050179a2..c7dba8e9f8 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -563,6 +563,9 @@ class BaseWorker(GenerationExecutor): executor_request.py_num_logprobs = request.sampling_params.logprobs executor_request.py_lora_path = py_lora_path + # here we add executor_request.py_disaggregated_params= request.disaggregated_params for python cache transceiver + if self._is_pytorch_backend and request.disaggregated_params is not None: + executor_request.py_disaggregated_params = request.disaggregated_params if self._is_pytorch_backend and request.multimodal_params is not None: if request.multimodal_params.multimodal_data is not None: # NOTE: Deserialize SharedTensor handle to actual tensor diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 8d33d94a7f..36c8cb12d4 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -427,6 +427,9 @@ class GenerationResultBase: 
ctx_request_id=context_phase_params.req_id, opaque_state=context_phase_params.opaque_state, draft_tokens=context_phase_params.draft_tokens, + disagg_id=context_phase_params.disagg_id, + ctx_dp_rank=context_phase_params.ctx_dp_rank, + ctx_info_endpoint=context_phase_params.disagg_info_endpoint, multimodal_embedding_handles=None, ) diff --git a/tensorrt_llm/serve/openai_disagg_service.py b/tensorrt_llm/serve/openai_disagg_service.py index a0012bd6d3..6d11aa8997 100644 --- a/tensorrt_llm/serve/openai_disagg_service.py +++ b/tensorrt_llm/serve/openai_disagg_service.py @@ -15,6 +15,7 @@ import asyncio import copy import os +import uuid from typing import Any, Callable, Dict, Optional from tensorrt_llm.llmapi.disagg_utils import ( @@ -142,7 +143,13 @@ class OpenAIDisaggregatedService(OpenAIService): def _get_ctx_request(self, request: UCompletionRequest) -> UCompletionRequest: ctx_request = copy.deepcopy(request) - ctx_request.disaggregated_params = DisaggregatedParams(request_type="context_only") + unique_disagg_id = ( + uuid.uuid4().int & 0x7FFFFFFFFFFFFFFF + ) # Generate positive int64 from uuid + ctx_request.disaggregated_params = DisaggregatedParams( + request_type="context_only", + disagg_id=unique_disagg_id, + ) ctx_request.stream = False ctx_request.stream_options = None return ctx_request @@ -304,4 +311,8 @@ class OpenAIDisaggregatedService(OpenAIService): raise ValueError("Context server did not return disaggregated params") if ctx_response.choices[0].disaggregated_params.ctx_request_id is None: raise ValueError("Invalid disaggregated params in context phase response.") + if ctx_response.choices[0].disaggregated_params.disagg_id is None: + raise ValueError( + "Invalid disaggregated params in context phase response. disagg_id is None" + ) return ctx_response diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py index 8ddda27cd7..cb4d5efb01 100644 --- a/tensorrt_llm/serve/openai_protocol.py +++ b/tensorrt_llm/serve/openai_protocol.py @@ -117,6 +117,9 @@ class DisaggregatedParams(OpenAIBaseModel): ctx_request_id: Optional[int] = None encoded_opaque_state: Optional[str] = None draft_tokens: Optional[List[int]] = None + disagg_id: Optional[int] = None + ctx_dp_rank: Optional[int] = None + ctx_info_endpoint: Optional[str] = None class ErrorResponse(OpenAIBaseModel): @@ -1000,7 +1003,10 @@ def to_disaggregated_params( ctx_request_id=tllm_disagg_params.ctx_request_id, encoded_opaque_state=encode_opaque_state( tllm_disagg_params.opaque_state), - draft_tokens=tllm_disagg_params.draft_tokens) + draft_tokens=tllm_disagg_params.draft_tokens, + disagg_id=tllm_disagg_params.disagg_id, + ctx_dp_rank=tllm_disagg_params.ctx_dp_rank, + ctx_info_endpoint=tllm_disagg_params.ctx_info_endpoint) def to_llm_disaggregated_params( @@ -1013,7 +1019,10 @@ def to_llm_disaggregated_params( ctx_request_id=disaggregated_params.ctx_request_id, opaque_state=decode_opaque_state( disaggregated_params.encoded_opaque_state), - draft_tokens=disaggregated_params.draft_tokens) + draft_tokens=disaggregated_params.draft_tokens, + disagg_id=disaggregated_params.disagg_id, + ctx_dp_rank=disaggregated_params.ctx_dp_rank, + ctx_info_endpoint=disaggregated_params.ctx_info_endpoint) UCompletionRequest = Union[CompletionRequest, ChatCompletionRequest] diff --git a/tests/unittest/bindings/test_executor_bindings.py b/tests/unittest/bindings/test_executor_bindings.py index c913531298..c0f52f3f5d 100644 --- a/tests/unittest/bindings/test_executor_bindings.py +++ 
b/tests/unittest/bindings/test_executor_bindings.py
@@ -1198,9 +1198,9 @@ def test_result_pickle():
     result.sequence_index = 1
     result.is_sequence_final = True
     result.decoding_iter = 1
-    result.context_phase_params = trtllm.ContextPhaseParams([1, 2], 123,
-                                                            bytes([0, 1]),
-                                                            [10, 20, 30])
+    result.context_phase_params = trtllm.ContextPhaseParams(
+        [1, 2], 123, bytes([0, 1]), [10, 20, 30], 13579, 1,
+        "disagg_info_endpoint_24680")
     result.request_perf_metrics = trtllm.RequestPerfMetrics()
     result.request_perf_metrics.last_iter = 33
     result_str = pickle.dumps(result)
@@ -1220,6 +1220,9 @@ def test_result_pickle():
     assert result.context_phase_params.first_gen_tokens == result_copy.context_phase_params.first_gen_tokens
     assert result.context_phase_params.draft_tokens == result_copy.context_phase_params.draft_tokens
     assert result.context_phase_params.opaque_state == result_copy.context_phase_params.opaque_state
+    assert result.context_phase_params.disagg_id == result_copy.context_phase_params.disagg_id
+    assert result.context_phase_params.ctx_dp_rank == result_copy.context_phase_params.ctx_dp_rank
+    assert result.context_phase_params.disagg_info_endpoint == result_copy.context_phase_params.disagg_info_endpoint
     assert result.request_perf_metrics.last_iter == result_copy.request_perf_metrics.last_iter
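
Editor's note (not part of the patch): the updated test above exercises the new fields only through a pickled Result. For readers of the binding changes in cpp/tensorrt_llm/pybind/executor/request.cpp and cpp/tensorrt_llm/nanobind/executor/request.cpp, a minimal standalone sketch of the extended ContextPhaseParams surface follows. It assumes the import alias the unit tests use and that the 7-element __getstate__/__setstate__ pair added by this patch is also wired into pickle for the standalone object.

    import pickle

    import tensorrt_llm.bindings.executor as trtllm

    # The constructor now takes three extra optional arguments after draft_tokens:
    # disagg_id, ctx_dp_rank and disagg_info_endpoint.
    params = trtllm.ContextPhaseParams(
        [1, 2],                        # first_gen_tokens
        123,                           # req_id
        bytes([0, 1]),                 # opaque_state, as in the unit test above
        [10, 20, 30],                  # draft_tokens
        13579,                         # disagg_id
        1,                             # ctx_dp_rank
        "disagg_info_endpoint_24680",  # disagg_info_endpoint
    )

    # The new fields are exposed as read/write properties (the bindings switch
    # from read-only to read/write in this patch) and survive pickling because
    # the state tuple now has seven elements.
    params.disagg_id = 24680
    restored = pickle.loads(pickle.dumps(params))
    assert restored.disagg_id == 24680
    assert restored.ctx_dp_rank == 1
    assert restored.disagg_info_endpoint == "disagg_info_endpoint_24680"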
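A sketch of the Python-API round trip this patch enables in tensorrt_llm/disaggregated_params.py and tensorrt_llm/executor/result.py: the generation result now surfaces disagg_id, ctx_dp_rank and ctx_info_endpoint, and get_context_phase_params() forwards them back to the executor when the generation_only request is issued. The endpoint string and numeric values below are placeholders, not values produced by a real run.

    from tensorrt_llm.disaggregated_params import DisaggregatedParams

    ctx_params = DisaggregatedParams(
        request_type="generation_only",
        first_gen_tokens=[42],
        ctx_request_id=123,
        opaque_state=None,   # normally the serialized transceiver state from the context worker
        draft_tokens=None,
        disagg_id=13579,     # stable id chosen by the disaggregated server
        ctx_dp_rank=1,       # data-parallel rank that served the context phase
        ctx_info_endpoint="tcp://ctx-worker-0:18000",  # hypothetical endpoint string
    )

    # The three new values pass straight through to the C++ ContextPhaseParams.
    cpp_params = ctx_params.get_context_phase_params()
    assert cpp_params.disagg_id == 13579
    assert cpp_params.ctx_dp_rank == 1
    assert cpp_params.disagg_info_endpoint == "tcp://ctx-worker-0:18000"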
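The change in tensorrt_llm/serve/openai_disagg_service.py derives disagg_id from a UUID and masks it with 0x7FFFFFFFFFFFFFFF so the result always fits in a signed 64-bit field. A self-contained sketch of that id generation and of how the context_only request is tagged:

    import uuid

    from tensorrt_llm.disaggregated_params import DisaggregatedParams

    def make_ctx_disagg_params() -> DisaggregatedParams:
        # uuid4().int is 128 bits wide; masking keeps the low 63 bits, which is
        # always non-negative and representable as int64.
        unique_disagg_id = uuid.uuid4().int & 0x7FFFFFFFFFFFFFFF
        return DisaggregatedParams(request_type="context_only",
                                   disagg_id=unique_disagg_id)

    params = make_ctx_disagg_params()
    assert 0 <= params.disagg_id < 2**63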
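On the wire, the same three fields ride along in the OpenAI-protocol DisaggregatedParams model extended in tensorrt_llm/serve/openai_protocol.py, so the disaggregated server, context server and generation server exchange them as plain JSON. A rough sketch, assuming the model keeps its existing required request_type field and pydantic v2 serialization; the values are placeholders.

    from tensorrt_llm.serve.openai_protocol import DisaggregatedParams

    wire_params = DisaggregatedParams(
        request_type="generation_only",
        ctx_request_id=123,
        encoded_opaque_state=None,   # base64-encoded opaque state when present
        draft_tokens=None,
        disagg_id=13579,
        ctx_dp_rank=1,
        ctx_info_endpoint="tcp://ctx-worker-0:18000",  # hypothetical endpoint
    )
    # Fields left as None are dropped from the JSON payload.
    print(wire_params.model_dump_json(exclude_none=True))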