diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index 1e5cc16a05..44806b37b0 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -442,11 +442,15 @@ class ContextPhaseParams public: using RequestIdType = std::uint64_t; - ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, std::optional<VecTokens> draftTokens); - ContextPhaseParams( - VecTokens firstGenTokens, RequestIdType reqId, void* state, std::optional<VecTokens> draftTokens); + ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, std::optional<VecTokens> draftTokens, + std::optional<SizeType32> ctxDpRank = std::nullopt, + std::optional<std::string> disaggInfoEndpoint = std::nullopt); + ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, void* state, std::optional<VecTokens> draftTokens, + std::optional<SizeType32> ctxDpRank = std::nullopt, + std::optional<std::string> disaggInfoEndpoint = std::nullopt); ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, std::vector<char> const& serializedState, - std::optional<VecTokens> draftTokens); + std::optional<VecTokens> draftTokens, std::optional<SizeType32> ctxDpRank = std::nullopt, + std::optional<std::string> disaggInfoEndpoint = std::nullopt); ContextPhaseParams(ContextPhaseParams const&); ContextPhaseParams(ContextPhaseParams&&) noexcept; @@ -457,15 +461,22 @@ public: [[nodiscard]] bool operator==(ContextPhaseParams const&) const noexcept; [[nodiscard]] VecTokens const& getFirstGenTokens() const& noexcept; + void setFirstGenTokens(VecTokens const& firstGenTokens) noexcept; [[nodiscard]] std::optional<VecTokens> const& getDraftTokens() const& noexcept; + void setDraftTokens(std::optional<VecTokens> const& draftTokens) noexcept; [[nodiscard]] VecTokens popFirstGenTokens() && noexcept; [[nodiscard]] RequestIdType getReqId() const noexcept; - + void setReqId(RequestIdType const& reqId) noexcept; [[nodiscard]] void const* getState() const noexcept; [[nodiscard]] void* getState() noexcept; [[nodiscard]] void* releaseState() noexcept; [[nodiscard]] std::vector<char> 
getSerializedState() const noexcept; + [[nodiscard]] std::optional<SizeType32> getCtxDpRank() const noexcept; + void setCtxDpRank(std::optional<SizeType32> const& ctxDpRank) noexcept; + [[nodiscard]] std::optional<std::string> const& getDisaggInfoEndpoint() const noexcept; + void setDisaggInfoEndpoint(std::optional<std::string> const& disaggInfoEndpoint) noexcept; + private: friend class Serialization; static void deleter(void const* data); @@ -482,6 +493,12 @@ private: /// @brief The draft tokens generated by context executor std::optional<VecTokens> mDraftTokens; + + /// @brief The context phase data parallel rank + std::optional<SizeType32> mCtxDpRank; + + /// @brief The disaggregated info endpoint + std::optional<std::string> mDisaggInfoEndpoint; }; /// @brief Configuration for speculative decoding (both draft and target models) diff --git a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp index e664021db0..63eba22146 100644 --- a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp @@ -99,13 +99,15 @@ std::optional<executor::Result> LlmRequest::createResult(bool useFastLogits, int } if (!hasDraftTokens()) { - result.contextPhaseParams = executor::ContextPhaseParams{ - std::move(firstGenTokens), mRequestId, mContextPhaseParams.value().releaseState(), std::nullopt}; + result.contextPhaseParams = executor::ContextPhaseParams{std::move(firstGenTokens), mRequestId, + mContextPhaseParams.value().releaseState(), std::nullopt, mContextPhaseParams.value().getCtxDpRank(), + mContextPhaseParams.value().getDisaggInfoEndpoint()}; } else { - result.contextPhaseParams = executor::ContextPhaseParams{ - std::move(firstGenTokens), mRequestId, mContextPhaseParams.value().releaseState(), *getDraftTokens()}; + result.contextPhaseParams = executor::ContextPhaseParams{std::move(firstGenTokens), mRequestId, + mContextPhaseParams.value().releaseState(), *getDraftTokens(), + mContextPhaseParams.value().getCtxDpRank(), mContextPhaseParams.value().getDisaggInfoEndpoint()}; } } diff --git 
a/cpp/tensorrt_llm/executor/contextPhaseParams.cpp b/cpp/tensorrt_llm/executor/contextPhaseParams.cpp index 4e66d567bb..5ebd984b43 100644 --- a/cpp/tensorrt_llm/executor/contextPhaseParams.cpp +++ b/cpp/tensorrt_llm/executor/contextPhaseParams.cpp @@ -27,28 +27,37 @@ namespace su = tensorrt_llm::executor::serialize_utils; namespace tensorrt_llm::executor { -ContextPhaseParams::ContextPhaseParams( - VecTokens firstGenTokens, RequestIdType reqId, void* state, std::optional draftTokens) +ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, void* state, + std::optional draftTokens, std::optional ctxDpRank, + std::optional disaggInfoEndpoint) : mReqId{reqId} , mFirstGenTokens{std::move(firstGenTokens)} , mState{StatePtr{state, deleter}} , mDraftTokens{std::move(draftTokens)} -{ -} - -ContextPhaseParams::ContextPhaseParams( - VecTokens firstGenTokens, RequestIdType reqId, std::optional draftTokens) - : mReqId{reqId} - , mFirstGenTokens{std::move(firstGenTokens)} - , mDraftTokens{std::move(draftTokens)} + , mCtxDpRank{ctxDpRank} + , mDisaggInfoEndpoint{std::move(disaggInfoEndpoint)} { } ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, - std::vector const& serializedState, std::optional draftTokens) + std::optional draftTokens, std::optional ctxDpRank, + std::optional disaggInfoEndpoint) : mReqId{reqId} , mFirstGenTokens{std::move(firstGenTokens)} , mDraftTokens{std::move(draftTokens)} + , mCtxDpRank{ctxDpRank} + , mDisaggInfoEndpoint{std::move(disaggInfoEndpoint)} +{ +} + +ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, + std::vector const& serializedState, std::optional draftTokens, std::optional ctxDpRank, + std::optional disaggInfoEndpoint) + : mReqId{reqId} + , mFirstGenTokens{std::move(firstGenTokens)} + , mDraftTokens{std::move(draftTokens)} + , mCtxDpRank{ctxDpRank} + , mDisaggInfoEndpoint{std::move(disaggInfoEndpoint)} { su::VectorWrapBuf 
strbuf(const_cast&>(serializedState)); @@ -60,12 +69,14 @@ ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType r } ContextPhaseParams::ContextPhaseParams(ContextPhaseParams const& other) + : mReqId{other.mReqId} + , mFirstGenTokens{other.mFirstGenTokens} + , mDraftTokens{other.mDraftTokens} + , mCtxDpRank{other.mCtxDpRank} + , mDisaggInfoEndpoint{other.mDisaggInfoEndpoint} { // Since the internal header files implement the destructor while using the declaration of this // type, a `unique_ptr` with a custom destructor member is used here. - mReqId = other.mReqId; - mFirstGenTokens = other.mFirstGenTokens; - mDraftTokens = other.mDraftTokens; if (other.mState) { auto* otherState = static_cast(other.mState.get()); @@ -90,11 +101,21 @@ VecTokens const& ContextPhaseParams::getFirstGenTokens() const& noexcept return mFirstGenTokens; } +void ContextPhaseParams::setFirstGenTokens(VecTokens const& firstGenTokens) noexcept +{ + mFirstGenTokens = firstGenTokens; +} + std::optional const& ContextPhaseParams::getDraftTokens() const& noexcept { return mDraftTokens; } +void ContextPhaseParams::setDraftTokens(std::optional const& draftTokens) noexcept +{ + mDraftTokens = draftTokens; +} + VecTokens ContextPhaseParams::popFirstGenTokens() && noexcept { return std::move(mFirstGenTokens); @@ -105,6 +126,11 @@ ContextPhaseParams::RequestIdType ContextPhaseParams::getReqId() const noexcept return mReqId; } +void ContextPhaseParams::setReqId(RequestIdType const& reqId) noexcept +{ + mReqId = reqId; +} + void const* ContextPhaseParams::getState() const noexcept { return mState.get(); @@ -125,6 +151,26 @@ void* ContextPhaseParams::releaseState() noexcept return mState.release(); } +std::optional ContextPhaseParams::getCtxDpRank() const noexcept +{ + return mCtxDpRank; +} + +void ContextPhaseParams::setCtxDpRank(std::optional const& ctxDpRank) noexcept +{ + mCtxDpRank = ctxDpRank; +} + +std::optional const& ContextPhaseParams::getDisaggInfoEndpoint() const 
noexcept +{ + return mDisaggInfoEndpoint; +} + +void ContextPhaseParams::setDisaggInfoEndpoint(std::optional const& disaggInfoEndpoint) noexcept +{ + mDisaggInfoEndpoint = disaggInfoEndpoint; +} + void ContextPhaseParams::deleter(void const* data) { using StateT = DataTransceiverState const; @@ -134,6 +180,7 @@ void ContextPhaseParams::deleter(void const* data) bool ContextPhaseParams::operator==(ContextPhaseParams const& other) const noexcept { if (mFirstGenTokens != other.mFirstGenTokens || mReqId != other.mReqId || mDraftTokens != other.mDraftTokens + || mDisaggInfoEndpoint != other.mDisaggInfoEndpoint || mCtxDpRank != other.mCtxDpRank || static_cast(mState) != static_cast(other.mState)) { return false; diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp index 2c417fbeea..79d015d585 100644 --- a/cpp/tensorrt_llm/executor/serialization.cpp +++ b/cpp/tensorrt_llm/executor/serialization.cpp @@ -652,14 +652,18 @@ ContextPhaseParams Serialization::deserializeContextPhaseParams(std::istream& is auto reqId = su::deserialize(is); auto firstGenTokens = su::deserialize(is); auto draftTokens = su::deserialize(is); + auto ctxDpRank = su::deserialize(is); + auto disaggInfoEndpoint = su::deserialize(is); auto hasState = su::deserialize(is); if (hasState) { auto state = std::make_unique(); *state = deserializeDataTransceiverState(is); - return ContextPhaseParams{std::move(firstGenTokens), reqId, state.release(), std::move(draftTokens)}; + return ContextPhaseParams{std::move(firstGenTokens), reqId, state.release(), std::move(draftTokens), ctxDpRank, + std::move(disaggInfoEndpoint)}; } - return ContextPhaseParams{std::move(firstGenTokens), reqId, nullptr, std::move(draftTokens)}; + return ContextPhaseParams{ + std::move(firstGenTokens), reqId, std::move(draftTokens), ctxDpRank, std::move(disaggInfoEndpoint)}; } void Serialization::serialize(ContextPhaseParams const& contextPhaseParams, std::ostream& os) @@ -667,6 +671,8 @@ void 
Serialization::serialize(ContextPhaseParams const& contextPhaseParams, std: su::serialize(contextPhaseParams.mReqId, os); su::serialize(contextPhaseParams.mFirstGenTokens, os); su::serialize(contextPhaseParams.mDraftTokens, os); + su::serialize(contextPhaseParams.mCtxDpRank, os); + su::serialize(contextPhaseParams.mDisaggInfoEndpoint, os); su::serialize(static_cast(contextPhaseParams.mState), os); if (contextPhaseParams.mState) { @@ -680,6 +686,8 @@ size_t Serialization::serializedSize(ContextPhaseParams const& contextPhaseParam totalSize += su::serializedSize(contextPhaseParams.mReqId); totalSize += su::serializedSize(contextPhaseParams.mFirstGenTokens); totalSize += su::serializedSize(contextPhaseParams.mDraftTokens); + totalSize += su::serializedSize(contextPhaseParams.mCtxDpRank); + totalSize += su::serializedSize(contextPhaseParams.mDisaggInfoEndpoint); totalSize += su::serializedSize(bool{}); if (contextPhaseParams.mState) { diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp index 72a94944d5..87f8c2c3cc 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp @@ -167,7 +167,7 @@ void initBindings(nb::module_& m) "context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition) .def_prop_ro("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen) .def_prop_rw("guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams) - .def_prop_ro("context_phase_params", &GenLlmReq::getContextPhaseParams) + .def_prop_rw("context_phase_params", &GenLlmReq::getContextPhaseParams, &GenLlmReq::setContextPhaseParams) .def_prop_ro("is_context_only_request", &GenLlmReq::isContextOnlyRequest) .def_prop_ro("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest) .def_prop_ro("is_generation_to_complete_state", &GenLlmReq::isGenerationToCompleteState) 
diff --git a/cpp/tensorrt_llm/nanobind/executor/request.cpp b/cpp/tensorrt_llm/nanobind/executor/request.cpp index 4a53516bf8..aa75a5fb60 100644 --- a/cpp/tensorrt_llm/nanobind/executor/request.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/request.cpp @@ -442,14 +442,16 @@ void initRequestBindings(nb::module_& m) { auto serializedState = self.getSerializedState(); return nb::make_tuple(self.getFirstGenTokens(), self.getReqId(), - nb::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens()); + nb::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens(), self.getCtxDpRank(), + self.getDisaggInfoEndpoint()); } - return nb::make_tuple(self.getFirstGenTokens(), self.getReqId(), nb::none(), self.getDraftTokens()); + return nb::make_tuple(self.getFirstGenTokens(), self.getReqId(), nb::none(), self.getDraftTokens(), + self.getCtxDpRank(), self.getDisaggInfoEndpoint()); }; auto ContextPhaseParamsSetState = [](tle::ContextPhaseParams& contextPhaseParams, nb::tuple const& state) { - if (state.size() != 4) + if (state.size() != 6) { throw std::runtime_error("Invalid ContextPhaseParams state!"); } @@ -460,13 +462,15 @@ void initRequestBindings(nb::module_& m) new (&contextPhaseParams) tle::ContextPhaseParams(nb::cast(state[0]), nb::cast(state[1]), std::vector(opaque_state_str_view.begin(), opaque_state_str_view.end()), - nb::cast>(state[3])); + nb::cast>(state[3]), nb::cast>(state[4]), + nb::cast>(state[5])); } else { new (&contextPhaseParams) tle::ContextPhaseParams(nb::cast(state[0]), nb::cast(state[1]), - nb::cast>(state[3])); + nb::cast>(state[3]), nb::cast>(state[4]), + nb::cast>(state[5])); } }; @@ -475,25 +479,35 @@ void initRequestBindings(nb::module_& m) "__init__", [](tle::ContextPhaseParams& self, VecTokens const& first_gen_tokens, tle::ContextPhaseParams::RequestIdType req_id, std::optional const& opaque_state, - std::optional const& draft_tokens) + std::optional const& draft_tokens, std::optional const& ctx_dp_rank, + 
std::optional const& disagg_info_endpoint) { if (opaque_state) { auto opaque_state_str_view = std::string_view(opaque_state.value().c_str(), opaque_state.value().size()); new (&self) tle::ContextPhaseParams(first_gen_tokens, req_id, - std::vector(opaque_state_str_view.begin(), opaque_state_str_view.end()), draft_tokens); + std::vector(opaque_state_str_view.begin(), opaque_state_str_view.end()), draft_tokens, + ctx_dp_rank, disagg_info_endpoint); } else { - new (&self) tle::ContextPhaseParams(first_gen_tokens, req_id, draft_tokens); + new (&self) tle::ContextPhaseParams( + first_gen_tokens, req_id, draft_tokens, ctx_dp_rank, disagg_info_endpoint); } }, nb::arg("first_gen_tokens"), nb::arg("req_id"), nb::arg("opaque_state").none(), - nb::arg("draft_tokens").none()) - .def_prop_ro("first_gen_tokens", [](tle::ContextPhaseParams const& self) { return self.getFirstGenTokens(); }) - .def_prop_ro("draft_tokens", [](tle::ContextPhaseParams const& self) { return self.getDraftTokens(); }) - .def_prop_ro("req_id", &tle::ContextPhaseParams::getReqId) + nb::arg("draft_tokens").none(), nb::arg("ctx_dp_rank").none(), nb::arg("disagg_info_endpoint").none()) + .def_prop_rw( + "first_gen_tokens", [](tle::ContextPhaseParams const& self) { return self.getFirstGenTokens(); }, + [](tle::ContextPhaseParams& self, VecTokens const& tokens) { self.setFirstGenTokens(tokens); }) + .def_prop_rw( + "draft_tokens", [](tle::ContextPhaseParams const& self) { return self.getDraftTokens(); }, + [](tle::ContextPhaseParams& self, std::optional const& tokens) { self.setDraftTokens(tokens); }) + .def_prop_rw("req_id", &tle::ContextPhaseParams::getReqId, &tle::ContextPhaseParams::setReqId) + .def_prop_rw("ctx_dp_rank", &tle::ContextPhaseParams::getCtxDpRank, &tle::ContextPhaseParams::setCtxDpRank) + .def_prop_rw("disagg_info_endpoint", &tle::ContextPhaseParams::getDisaggInfoEndpoint, + &tle::ContextPhaseParams::setDisaggInfoEndpoint) .def_prop_ro("opaque_state", [](tle::ContextPhaseParams const& self) { 
diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp index 657b7c6f36..f3f8daad12 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp @@ -172,7 +172,7 @@ void initBindings(pybind11::module_& m) .def_property_readonly("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen) .def_property( "guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams) - .def_property_readonly("context_phase_params", &GenLlmReq::getContextPhaseParams) + .def_property("context_phase_params", &GenLlmReq::getContextPhaseParams, &GenLlmReq::setContextPhaseParams) .def_property_readonly("is_context_only_request", &GenLlmReq::isContextOnlyRequest) .def_property_readonly("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest) .def_property_readonly("is_generation_to_complete_state", &GenLlmReq::isGenerationToCompleteState) diff --git a/cpp/tensorrt_llm/pybind/executor/request.cpp b/cpp/tensorrt_llm/pybind/executor/request.cpp index 78bb650fbc..390dcb1eb3 100644 --- a/cpp/tensorrt_llm/pybind/executor/request.cpp +++ b/cpp/tensorrt_llm/pybind/executor/request.cpp @@ -411,14 +411,16 @@ void initRequestBindings(pybind11::module_& m) { auto serializedState = self.getSerializedState(); return py::make_tuple(self.getFirstGenTokens(), self.getReqId(), - py::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens()); + py::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens(), self.getCtxDpRank(), + self.getDisaggInfoEndpoint()); } - return py::make_tuple(self.getFirstGenTokens(), self.getReqId(), py::none(), self.getDraftTokens()); + return py::make_tuple(self.getFirstGenTokens(), self.getReqId(), py::none(), self.getDraftTokens(), + self.getCtxDpRank(), self.getDisaggInfoEndpoint()); }; auto ContextPhaseParamsSetState = [](py::tuple const& state) { - if (state.size() != 
4) + if (state.size() != 6) { throw std::runtime_error("Invalid ContextPhaseParams state!"); } @@ -429,28 +431,42 @@ void initRequestBindings(pybind11::module_& m) return std::make_unique(state[0].cast(), state[1].cast(), std::vector(opaque_state_str_view.begin(), opaque_state_str_view.end()), - state[3].cast>()); + state[3].cast>(), state[4].cast>(), + state[5].cast>()); } return std::make_unique(state[0].cast(), - state[1].cast(), state[3].cast>()); + state[1].cast(), state[3].cast>(), + state[4].cast>(), state[5].cast>()); }; py::class_(m, "ContextPhaseParams") .def(py::init( - [](VecTokens const& first_gen_tokens, tle::ContextPhaseParams::RequestIdType req_id, - std::optional const& opaque_state, std::optional const& draft_tokens) - { - if (opaque_state) - { - auto opaque_state_str_view = std::string_view(opaque_state.value().cast()); - return std::make_unique(first_gen_tokens, req_id, - std::vector(opaque_state_str_view.begin(), opaque_state_str_view.end()), draft_tokens); - } - return std::make_unique(first_gen_tokens, req_id, draft_tokens); - })) - .def_property_readonly("first_gen_tokens", &tle::ContextPhaseParams::getFirstGenTokens) - .def_property_readonly("draft_tokens", &tle::ContextPhaseParams::getDraftTokens) - .def_property_readonly("req_id", &tle::ContextPhaseParams::getReqId) + [](VecTokens const& first_gen_tokens, tle::ContextPhaseParams::RequestIdType req_id, + std::optional const& opaque_state, std::optional const& draft_tokens, + std::optional const& ctx_dp_rank, + std::optional const& disagg_info_endpoint) + { + if (opaque_state) + { + auto opaque_state_str_view = std::string_view(opaque_state.value().cast()); + return std::make_unique(first_gen_tokens, req_id, + std::vector(opaque_state_str_view.begin(), opaque_state_str_view.end()), + draft_tokens, ctx_dp_rank, disagg_info_endpoint); + } + return std::make_unique( + first_gen_tokens, req_id, draft_tokens, ctx_dp_rank, disagg_info_endpoint); + }), + py::arg("first_gen_tokens"), 
py::arg("req_id"), py::arg("opaque_state") = py::none(), + py::arg("draft_tokens") = py::none(), py::arg("ctx_dp_rank") = py::none(), + py::arg("disagg_info_endpoint") = py::none()) + .def_property("first_gen_tokens", &tle::ContextPhaseParams::getFirstGenTokens, + &tle::ContextPhaseParams::setFirstGenTokens) + .def_property( + "draft_tokens", &tle::ContextPhaseParams::getDraftTokens, &tle::ContextPhaseParams::setDraftTokens) + .def_property("req_id", &tle::ContextPhaseParams::getReqId, &tle::ContextPhaseParams::setReqId) + .def_property("ctx_dp_rank", &tle::ContextPhaseParams::getCtxDpRank, &tle::ContextPhaseParams::setCtxDpRank) + .def_property("disagg_info_endpoint", &tle::ContextPhaseParams::getDisaggInfoEndpoint, + &tle::ContextPhaseParams::setDisaggInfoEndpoint) .def_property_readonly("opaque_state", [](tle::ContextPhaseParams const& self) { diff --git a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp index b5a6353f64..a2aefd90e6 100644 --- a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp +++ b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp @@ -751,6 +751,34 @@ TEST(SerializeUtilsTest, ContextPhaseParams) EXPECT_EQ(state2, stateCopy); } + + // Test with ctxDpRank and disaggInfoEndpoint + { + auto state = std::make_unique(); + state->setCommState(texec::kv_cache::CommState{{10, 20}}); + auto stats + = texec::ContextPhaseParams({10, 20, 30}, 2, state.release(), VecTokens{5, 6}, 3, "http://127.0.0.1:8080"); + auto stats2 = serializeDeserialize(stats); + EXPECT_EQ(stats, stats2); + EXPECT_EQ(stats.getCtxDpRank(), 3); + EXPECT_EQ(stats.getDisaggInfoEndpoint(), "http://127.0.0.1:8080"); + } + + { + auto stats = texec::ContextPhaseParams({1, 2}, 1, std::nullopt, 5, std::nullopt); + auto stats2 = serializeDeserialize(stats); + EXPECT_EQ(stats, stats2); + EXPECT_EQ(stats.getCtxDpRank(), 5); + EXPECT_EQ(stats.getDisaggInfoEndpoint(), std::nullopt); + } + + { + auto stats = 
texec::ContextPhaseParams({1, 2}, 1, std::nullopt, std::nullopt, "endpoint://test"); + auto stats2 = serializeDeserialize(stats); + EXPECT_EQ(stats, stats2); + EXPECT_EQ(stats.getCtxDpRank(), std::nullopt); + EXPECT_EQ(stats.getDisaggInfoEndpoint(), "endpoint://test"); + } } TEST(SerializeUtilsTest, SpeculativeDecodingFastLogitsInfo) diff --git a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py index 30e813513b..60b45f4e94 100644 --- a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py +++ b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py @@ -482,10 +482,13 @@ class ExecutorRequestQueue: new_requests, "py_scheduling_params") py_num_logprobs = self._collect_py_objects_from_requests( new_requests, "py_num_logprobs") + py_disaggregated_params = self._collect_py_objects_from_requests( + new_requests, "py_disaggregated_params") py_request_objects = tuple( filter(None, [ py_logits_post_processors, py_multimodal_data, - py_scheduling_params, py_num_logprobs + py_scheduling_params, py_num_logprobs, + py_disaggregated_params ])) else: py_request_objects = None diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 62a386d585..f48d724658 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -567,6 +567,7 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest): self.py_logprobs_mode = LogprobMode( logprobs_mode) # handle passed a raw string + self.py_disaggregated_params = None self.py_result = PyResult( prompt_len=self.py_prompt_len, @@ -831,6 +832,10 @@ def executor_request_to_llm_request( logprobs_mode=getattr(executor_request, "py_logprobs_mode", LogprobMode.RAW), ) + + llm_request.py_disaggregated_params = getattr(executor_request, + "py_disaggregated_params", + None) if child_req_ids: for child_id in child_req_ids: 
llm_request.create_child_request(child_id) diff --git a/tensorrt_llm/disaggregated_params.py b/tensorrt_llm/disaggregated_params.py index b98780bac7..a394002582 100644 --- a/tensorrt_llm/disaggregated_params.py +++ b/tensorrt_llm/disaggregated_params.py @@ -36,7 +36,8 @@ class DisaggregatedParams: draft_tokens: Optional[List[int]] = None # If disagg_request_id is set, both context and generation requests will use it as underlying request id. disagg_request_id: Optional[int] = None - + ctx_dp_rank: Optional[int] = None + ctx_info_endpoint: Optional[str] = None # E-P Disaggregated Params multimodal_embedding_handles: Optional[List[Dict[str, Any]]] = ( None # multimodal embedding handles should be a list of cudaIPC handles for each mm_embedding @@ -53,7 +54,12 @@ class DisaggregatedParams: self.disagg_request_id if self.disagg_request_id is not None else self.ctx_request_id ) return tllme.ContextPhaseParams( - self.first_gen_tokens, request_id, self.opaque_state, self.draft_tokens + self.first_gen_tokens, + request_id, + self.opaque_state, + self.draft_tokens, + self.ctx_dp_rank, + self.ctx_info_endpoint, ) def get_request_type(self) -> tllme.RequestType: diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index 0b6a417793..d121f520c3 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -566,6 +566,9 @@ class BaseWorker(GenerationExecutor): executor_request.py_lora_path = py_lora_path executor_request.py_logprobs_mode = request.sampling_params.logprobs_mode + # Forward the request's disaggregated params to the executor request so the Python cache transceiver can use them. + if self._is_pytorch_backend and request.disaggregated_params is not None: + executor_request.py_disaggregated_params = request.disaggregated_params if self._is_pytorch_backend and request.multimodal_params is not None: if request.multimodal_params.multimodal_data is not None: # NOTE: Deserialize SharedTensor handle 
to actual tensor diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 26e0a3c5fa..0642c1aff1 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -427,6 +427,8 @@ class GenerationResultBase: ctx_request_id=context_phase_params.req_id, opaque_state=context_phase_params.opaque_state, draft_tokens=context_phase_params.draft_tokens, + ctx_dp_rank=context_phase_params.ctx_dp_rank, + ctx_info_endpoint=context_phase_params.disagg_info_endpoint, multimodal_embedding_handles=None, ) diff --git a/tensorrt_llm/serve/openai_disagg_service.py b/tensorrt_llm/serve/openai_disagg_service.py index 0f43a6147d..b9e88fe5f5 100644 --- a/tensorrt_llm/serve/openai_disagg_service.py +++ b/tensorrt_llm/serve/openai_disagg_service.py @@ -313,4 +313,8 @@ class OpenAIDisaggregatedService(OpenAIService): raise ValueError("Context server did not return disaggregated params") if ctx_response.choices[0].disaggregated_params.ctx_request_id is None: raise ValueError("Invalid disaggregated params in context phase response.") + if ctx_response.choices[0].disaggregated_params.disagg_request_id is None: + raise ValueError( + "Invalid disaggregated params in context phase response. 
disagg_request_id is None" + ) return ctx_response diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py index a909212024..877688732a 100644 --- a/tensorrt_llm/serve/openai_protocol.py +++ b/tensorrt_llm/serve/openai_protocol.py @@ -119,6 +119,8 @@ class DisaggregatedParams(OpenAIBaseModel): encoded_opaque_state: Optional[str] = None draft_tokens: Optional[List[int]] = None disagg_request_id: Optional[int] = None + ctx_dp_rank: Optional[int] = None + ctx_info_endpoint: Optional[str] = None class ErrorResponse(OpenAIBaseModel): @@ -1091,7 +1093,9 @@ def to_disaggregated_params( encoded_opaque_state=encode_opaque_state( tllm_disagg_params.opaque_state), draft_tokens=tllm_disagg_params.draft_tokens, - disagg_request_id=tllm_disagg_params.disagg_request_id) + disagg_request_id=tllm_disagg_params.disagg_request_id, + ctx_dp_rank=tllm_disagg_params.ctx_dp_rank, + ctx_info_endpoint=tllm_disagg_params.ctx_info_endpoint) def to_llm_disaggregated_params( @@ -1105,7 +1109,9 @@ def to_llm_disaggregated_params( opaque_state=decode_opaque_state( disaggregated_params.encoded_opaque_state), draft_tokens=disaggregated_params.draft_tokens, - disagg_request_id=disaggregated_params.disagg_request_id) + disagg_request_id=disaggregated_params.disagg_request_id, + ctx_dp_rank=disaggregated_params.ctx_dp_rank, + ctx_info_endpoint=disaggregated_params.ctx_info_endpoint) UCompletionRequest = Union[CompletionRequest, ChatCompletionRequest] diff --git a/tests/unittest/bindings/test_executor_bindings.py b/tests/unittest/bindings/test_executor_bindings.py index c913531298..884247f91c 100644 --- a/tests/unittest/bindings/test_executor_bindings.py +++ b/tests/unittest/bindings/test_executor_bindings.py @@ -1198,9 +1198,9 @@ def test_result_pickle(): result.sequence_index = 1 result.is_sequence_final = True result.decoding_iter = 1 - result.context_phase_params = trtllm.ContextPhaseParams([1, 2], 123, - bytes([0, 1]), - [10, 20, 30]) + 
result.context_phase_params = trtllm.ContextPhaseParams( + [1, 2], 123, bytes([0, + 1]), [10, 20, 30], 1, "disagg_info_endpoint_24680") result.request_perf_metrics = trtllm.RequestPerfMetrics() result.request_perf_metrics.last_iter = 33 result_str = pickle.dumps(result) @@ -1220,6 +1220,8 @@ def test_result_pickle(): assert result.context_phase_params.first_gen_tokens == result_copy.context_phase_params.first_gen_tokens assert result.context_phase_params.draft_tokens == result_copy.context_phase_params.draft_tokens assert result.context_phase_params.opaque_state == result_copy.context_phase_params.opaque_state + assert result.context_phase_params.ctx_dp_rank == result_copy.context_phase_params.ctx_dp_rank + assert result.context_phase_params.disagg_info_endpoint == result_copy.context_phase_params.disagg_info_endpoint assert result.request_perf_metrics.last_iter == result_copy.request_perf_metrics.last_iter