change context params and disagg params

Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
This commit is contained in:
Chuang Zhu 2026-01-07 09:27:47 +00:00
parent 7e88212d24
commit ead4fc3336
No known key found for this signature in database
17 changed files with 252 additions and 69 deletions

View File

@ -442,11 +442,16 @@ class ContextPhaseParams
public:
using RequestIdType = std::uint64_t;
ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, std::optional<VecTokens> draftTokens);
ContextPhaseParams(
VecTokens firstGenTokens, RequestIdType reqId, void* state, std::optional<VecTokens> draftTokens);
ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, std::optional<VecTokens> draftTokens,
std::optional<std::int64_t> disaggId = std::nullopt, std::optional<SizeType32> ctxDpRank = std::nullopt,
std::optional<std::string> disaggInfoEndpoint = std::nullopt);
ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, void* state, std::optional<VecTokens> draftTokens,
std::optional<std::int64_t> disaggId = std::nullopt, std::optional<SizeType32> ctxDpRank = std::nullopt,
std::optional<std::string> disaggInfoEndpoint = std::nullopt);
ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, std::vector<char> const& serializedState,
std::optional<VecTokens> draftTokens);
std::optional<VecTokens> draftTokens, std::optional<std::int64_t> disaggId = std::nullopt,
std::optional<SizeType32> ctxDpRank = std::nullopt,
std::optional<std::string> disaggInfoEndpoint = std::nullopt);
ContextPhaseParams(ContextPhaseParams const&);
ContextPhaseParams(ContextPhaseParams&&) noexcept;
@ -457,15 +462,24 @@ public:
[[nodiscard]] bool operator==(ContextPhaseParams const&) const noexcept;
[[nodiscard]] VecTokens const& getFirstGenTokens() const& noexcept;
void setFirstGenTokens(VecTokens firstGenTokens) noexcept;
[[nodiscard]] std::optional<VecTokens> const& getDraftTokens() const& noexcept;
void setDraftTokens(std::optional<VecTokens> draftTokens) noexcept;
[[nodiscard]] VecTokens popFirstGenTokens() && noexcept;
[[nodiscard]] RequestIdType getReqId() const noexcept;
void setReqId(RequestIdType reqId) noexcept;
[[nodiscard]] void const* getState() const noexcept;
[[nodiscard]] void* getState() noexcept;
[[nodiscard]] void* releaseState() noexcept;
[[nodiscard]] std::vector<char> getSerializedState() const noexcept;
[[nodiscard]] std::optional<std::int64_t> getDisaggId() const noexcept;
void setDisaggId(std::optional<std::int64_t> disaggId) noexcept;
[[nodiscard]] std::optional<SizeType32> getCtxDpRank() const noexcept;
void setCtxDpRank(std::optional<SizeType32> ctxDpRank) noexcept;
[[nodiscard]] std::optional<std::string> const& getDisaggInfoEndpoint() const noexcept;
void setDisaggInfoEndpoint(std::optional<std::string> disaggInfoEndpoint) noexcept;
private:
friend class Serialization;
static void deleter(void const* data);
@ -482,6 +496,15 @@ private:
/// @brief The draft tokens generated by context executor
std::optional<VecTokens> mDraftTokens;
/// @brief The disaggregated id
std::optional<std::int64_t> mDisaggId;
/// @brief The context phase data parallel rank
std::optional<SizeType32> mCtxDpRank;
/// @brief The disaggregated info endpoint
std::optional<std::string> mDisaggInfoEndpoint;
};
/// @brief Configuration for speculative decoding (both draft and target models)

View File

@ -99,13 +99,16 @@ std::optional<executor::Result> LlmRequest::createResult(bool useFastLogits, int
}
if (!hasDraftTokens())
{
result.contextPhaseParams = executor::ContextPhaseParams{
std::move(firstGenTokens), mRequestId, mContextPhaseParams.value().releaseState(), std::nullopt};
result.contextPhaseParams = executor::ContextPhaseParams{std::move(firstGenTokens), mRequestId,
mContextPhaseParams.value().releaseState(), std::nullopt, mContextPhaseParams.value().getDisaggId(),
mContextPhaseParams.value().getCtxDpRank(), mContextPhaseParams.value().getDisaggInfoEndpoint()};
}
else
{
result.contextPhaseParams = executor::ContextPhaseParams{
std::move(firstGenTokens), mRequestId, mContextPhaseParams.value().releaseState(), *getDraftTokens()};
result.contextPhaseParams = executor::ContextPhaseParams{std::move(firstGenTokens), mRequestId,
mContextPhaseParams.value().releaseState(), *getDraftTokens(),
mContextPhaseParams.value().getDisaggId(), mContextPhaseParams.value().getCtxDpRank(),
mContextPhaseParams.value().getDisaggInfoEndpoint()};
}
}

View File

@ -27,28 +27,41 @@ namespace su = tensorrt_llm::executor::serialize_utils;
namespace tensorrt_llm::executor
{
ContextPhaseParams::ContextPhaseParams(
VecTokens firstGenTokens, RequestIdType reqId, void* state, std::optional<VecTokens> draftTokens)
ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, void* state,
std::optional<VecTokens> draftTokens, std::optional<std::int64_t> disaggId, std::optional<SizeType32> ctxDpRank,
std::optional<std::string> disaggInfoEndpoint)
: mReqId{reqId}
, mFirstGenTokens{std::move(firstGenTokens)}
, mState{StatePtr{state, deleter}}
, mDraftTokens{std::move(draftTokens)}
{
}
ContextPhaseParams::ContextPhaseParams(
VecTokens firstGenTokens, RequestIdType reqId, std::optional<VecTokens> draftTokens)
: mReqId{reqId}
, mFirstGenTokens{std::move(firstGenTokens)}
, mDraftTokens{std::move(draftTokens)}
, mDisaggId{std::move(disaggId)}
, mCtxDpRank{ctxDpRank}
, mDisaggInfoEndpoint{std::move(disaggInfoEndpoint)}
{
}
ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId,
std::vector<char> const& serializedState, std::optional<VecTokens> draftTokens)
std::optional<VecTokens> draftTokens, std::optional<std::int64_t> disaggId, std::optional<SizeType32> ctxDpRank,
std::optional<std::string> disaggInfoEndpoint)
: mReqId{reqId}
, mFirstGenTokens{std::move(firstGenTokens)}
, mDraftTokens{std::move(draftTokens)}
, mDisaggId{std::move(disaggId)}
, mCtxDpRank{ctxDpRank}
, mDisaggInfoEndpoint{std::move(disaggInfoEndpoint)}
{
}
ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId,
std::vector<char> const& serializedState, std::optional<VecTokens> draftTokens,
std::optional<std::int64_t> disaggId, std::optional<SizeType32> ctxDpRank,
std::optional<std::string> disaggInfoEndpoint)
: mReqId{reqId}
, mFirstGenTokens{std::move(firstGenTokens)}
, mDraftTokens{std::move(draftTokens)}
, mDisaggId{std::move(disaggId)}
, mCtxDpRank{ctxDpRank}
, mDisaggInfoEndpoint{std::move(disaggInfoEndpoint)}
{
su::VectorWrapBuf<char> strbuf(const_cast<std::vector<char>&>(serializedState));
@ -60,12 +73,15 @@ ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType r
}
ContextPhaseParams::ContextPhaseParams(ContextPhaseParams const& other)
: mReqId{other.mReqId}
, mFirstGenTokens{other.mFirstGenTokens}
, mDraftTokens{other.mDraftTokens}
, mDisaggId{other.mDisaggId}
, mCtxDpRank{other.mCtxDpRank}
, mDisaggInfoEndpoint{other.mDisaggInfoEndpoint}
{
// Since the internal header files implement the destructor while using the declaration of this
// type, a `unique_ptr` with a custom destructor member is used here.
mReqId = other.mReqId;
mFirstGenTokens = other.mFirstGenTokens;
mDraftTokens = other.mDraftTokens;
if (other.mState)
{
auto* otherState = static_cast<DataTransceiverState*>(other.mState.get());
@ -90,11 +106,21 @@ VecTokens const& ContextPhaseParams::getFirstGenTokens() const& noexcept
return mFirstGenTokens;
}
void ContextPhaseParams::setFirstGenTokens(VecTokens firstGenTokens) noexcept
{
mFirstGenTokens = std::move(firstGenTokens);
}
std::optional<VecTokens> const& ContextPhaseParams::getDraftTokens() const& noexcept
{
return mDraftTokens;
}
void ContextPhaseParams::setDraftTokens(std::optional<VecTokens> draftTokens) noexcept
{
mDraftTokens = std::move(draftTokens);
}
VecTokens ContextPhaseParams::popFirstGenTokens() && noexcept
{
return std::move(mFirstGenTokens);
@ -105,6 +131,11 @@ ContextPhaseParams::RequestIdType ContextPhaseParams::getReqId() const noexcept
return mReqId;
}
void ContextPhaseParams::setReqId(RequestIdType reqId) noexcept
{
mReqId = reqId;
}
void const* ContextPhaseParams::getState() const noexcept
{
return mState.get();
@ -125,6 +156,36 @@ void* ContextPhaseParams::releaseState() noexcept
return mState.release();
}
std::optional<std::int64_t> ContextPhaseParams::getDisaggId() const noexcept
{
return mDisaggId;
}
void ContextPhaseParams::setDisaggId(std::optional<std::int64_t> disaggId) noexcept
{
mDisaggId = disaggId;
}
std::optional<SizeType32> ContextPhaseParams::getCtxDpRank() const noexcept
{
return mCtxDpRank;
}
void ContextPhaseParams::setCtxDpRank(std::optional<SizeType32> ctxDpRank) noexcept
{
mCtxDpRank = ctxDpRank;
}
std::optional<std::string> const& ContextPhaseParams::getDisaggInfoEndpoint() const noexcept
{
return mDisaggInfoEndpoint;
}
void ContextPhaseParams::setDisaggInfoEndpoint(std::optional<std::string> disaggInfoEndpoint) noexcept
{
mDisaggInfoEndpoint = std::move(disaggInfoEndpoint);
}
void ContextPhaseParams::deleter(void const* data)
{
using StateT = DataTransceiverState const;
@ -134,7 +195,8 @@ void ContextPhaseParams::deleter(void const* data)
bool ContextPhaseParams::operator==(ContextPhaseParams const& other) const noexcept
{
if (mFirstGenTokens != other.mFirstGenTokens || mReqId != other.mReqId || mDraftTokens != other.mDraftTokens
|| static_cast<bool>(mState) != static_cast<bool>(other.mState))
|| mDisaggId != other.mDisaggId || mDisaggInfoEndpoint != other.mDisaggInfoEndpoint
|| mCtxDpRank != other.mCtxDpRank || static_cast<bool>(mState) != static_cast<bool>(other.mState))
{
return false;
}

View File

@ -652,14 +652,19 @@ ContextPhaseParams Serialization::deserializeContextPhaseParams(std::istream& is
auto reqId = su::deserialize<decltype(ContextPhaseParams::mReqId)>(is);
auto firstGenTokens = su::deserialize<decltype(ContextPhaseParams::mFirstGenTokens)>(is);
auto draftTokens = su::deserialize<decltype(ContextPhaseParams::mDraftTokens)>(is);
auto disaggId = su::deserialize<decltype(ContextPhaseParams::mDisaggId)>(is);
auto ctxDpRank = su::deserialize<decltype(ContextPhaseParams::mCtxDpRank)>(is);
auto disaggInfoEndpoint = su::deserialize<decltype(ContextPhaseParams::mDisaggInfoEndpoint)>(is);
auto hasState = su::deserialize<bool>(is);
if (hasState)
{
auto state = std::make_unique<DataTransceiverState>();
*state = deserializeDataTransceiverState(is);
return ContextPhaseParams{std::move(firstGenTokens), reqId, state.release(), std::move(draftTokens)};
return ContextPhaseParams{std::move(firstGenTokens), reqId, state.release(), std::move(draftTokens),
std::move(disaggId), ctxDpRank, std::move(disaggInfoEndpoint)};
}
return ContextPhaseParams{std::move(firstGenTokens), reqId, nullptr, std::move(draftTokens)};
return ContextPhaseParams{std::move(firstGenTokens), reqId, std::move(draftTokens), std::move(disaggId), ctxDpRank,
std::move(disaggInfoEndpoint)};
}
void Serialization::serialize(ContextPhaseParams const& contextPhaseParams, std::ostream& os)
@ -667,6 +672,9 @@ void Serialization::serialize(ContextPhaseParams const& contextPhaseParams, std:
su::serialize(contextPhaseParams.mReqId, os);
su::serialize(contextPhaseParams.mFirstGenTokens, os);
su::serialize(contextPhaseParams.mDraftTokens, os);
su::serialize(contextPhaseParams.mDisaggId, os);
su::serialize(contextPhaseParams.mCtxDpRank, os);
su::serialize(contextPhaseParams.mDisaggInfoEndpoint, os);
su::serialize(static_cast<bool>(contextPhaseParams.mState), os);
if (contextPhaseParams.mState)
{
@ -680,6 +688,9 @@ size_t Serialization::serializedSize(ContextPhaseParams const& contextPhaseParam
totalSize += su::serializedSize(contextPhaseParams.mReqId);
totalSize += su::serializedSize(contextPhaseParams.mFirstGenTokens);
totalSize += su::serializedSize(contextPhaseParams.mDraftTokens);
totalSize += su::serializedSize(contextPhaseParams.mDisaggId);
totalSize += su::serializedSize(contextPhaseParams.mCtxDpRank);
totalSize += su::serializedSize(contextPhaseParams.mDisaggInfoEndpoint);
totalSize += su::serializedSize(bool{});
if (contextPhaseParams.mState)
{

View File

@ -165,7 +165,7 @@ void initBindings(nb::module_& m)
"context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
.def_prop_ro("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)
.def_prop_rw("guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams)
.def_prop_ro("context_phase_params", &GenLlmReq::getContextPhaseParams)
.def_prop_rw("context_phase_params", &GenLlmReq::getContextPhaseParams, &GenLlmReq::setContextPhaseParams)
.def_prop_ro("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
.def_prop_ro("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest)
.def_prop_ro("is_generation_complete_state", &GenLlmReq::isGenerationCompleteState)

View File

@ -442,14 +442,16 @@ void initRequestBindings(nb::module_& m)
{
auto serializedState = self.getSerializedState();
return nb::make_tuple(self.getFirstGenTokens(), self.getReqId(),
nb::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens());
nb::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens(), self.getDisaggId(),
self.getCtxDpRank(), self.getDisaggInfoEndpoint());
}
return nb::make_tuple(self.getFirstGenTokens(), self.getReqId(), nb::none(), self.getDraftTokens());
return nb::make_tuple(self.getFirstGenTokens(), self.getReqId(), nb::none(), self.getDraftTokens(),
self.getDisaggId(), self.getCtxDpRank(), self.getDisaggInfoEndpoint());
};
auto ContextPhaseParamsSetState = [](tle::ContextPhaseParams& contextPhaseParams, nb::tuple const& state)
{
if (state.size() != 4)
if (state.size() != 7)
{
throw std::runtime_error("Invalid ContextPhaseParams state!");
}
@ -460,13 +462,15 @@ void initRequestBindings(nb::module_& m)
new (&contextPhaseParams) tle::ContextPhaseParams(nb::cast<VecTokens>(state[0]),
nb::cast<tle::ContextPhaseParams::RequestIdType>(state[1]),
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()),
nb::cast<std::optional<VecTokens>>(state[3]));
nb::cast<std::optional<VecTokens>>(state[3]), nb::cast<std::optional<std::int64_t>>(state[4]),
nb::cast<std::optional<SizeType32>>(state[5]), nb::cast<std::optional<std::string>>(state[6]));
}
else
{
new (&contextPhaseParams) tle::ContextPhaseParams(nb::cast<VecTokens>(state[0]),
nb::cast<tle::ContextPhaseParams::RequestIdType>(state[1]),
nb::cast<std::optional<VecTokens>>(state[3]));
nb::cast<std::optional<VecTokens>>(state[3]), nb::cast<std::optional<std::int64_t>>(state[4]),
nb::cast<std::optional<SizeType32>>(state[5]), nb::cast<std::optional<std::string>>(state[6]));
}
};
@ -475,25 +479,37 @@ void initRequestBindings(nb::module_& m)
"__init__",
[](tle::ContextPhaseParams& self, VecTokens const& first_gen_tokens,
tle::ContextPhaseParams::RequestIdType req_id, std::optional<nb::bytes> const& opaque_state,
std::optional<VecTokens> const& draft_tokens)
std::optional<VecTokens> const& draft_tokens, std::optional<std::int64_t> const& disagg_id,
std::optional<SizeType32> const& ctx_dp_rank, std::optional<std::string> const& disagg_info_endpoint)
{
if (opaque_state)
{
auto opaque_state_str_view
= std::string_view(opaque_state.value().c_str(), opaque_state.value().size());
new (&self) tle::ContextPhaseParams(first_gen_tokens, req_id,
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()), draft_tokens);
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()), draft_tokens,
disagg_id, ctx_dp_rank, disagg_info_endpoint);
}
else
{
new (&self) tle::ContextPhaseParams(first_gen_tokens, req_id, draft_tokens);
new (&self) tle::ContextPhaseParams(
first_gen_tokens, req_id, draft_tokens, disagg_id, ctx_dp_rank, disagg_info_endpoint);
}
},
nb::arg("first_gen_tokens"), nb::arg("req_id"), nb::arg("opaque_state").none(),
nb::arg("draft_tokens").none())
.def_prop_ro("first_gen_tokens", [](tle::ContextPhaseParams const& self) { return self.getFirstGenTokens(); })
.def_prop_ro("draft_tokens", [](tle::ContextPhaseParams const& self) { return self.getDraftTokens(); })
.def_prop_ro("req_id", &tle::ContextPhaseParams::getReqId)
nb::arg("draft_tokens").none(), nb::arg("disagg_id").none(), nb::arg("ctx_dp_rank").none(),
nb::arg("disagg_info_endpoint").none())
.def_prop_rw(
"first_gen_tokens", [](tle::ContextPhaseParams const& self) { return self.getFirstGenTokens(); },
[](tle::ContextPhaseParams& self, VecTokens const& tokens) { self.setFirstGenTokens(tokens); })
.def_prop_rw(
"draft_tokens", [](tle::ContextPhaseParams const& self) { return self.getDraftTokens(); },
[](tle::ContextPhaseParams& self, std::optional<VecTokens> const& tokens) { self.setDraftTokens(tokens); })
.def_prop_rw("req_id", &tle::ContextPhaseParams::getReqId, &tle::ContextPhaseParams::setReqId)
.def_prop_rw("disagg_id", &tle::ContextPhaseParams::getDisaggId, &tle::ContextPhaseParams::setDisaggId)
.def_prop_rw("ctx_dp_rank", &tle::ContextPhaseParams::getCtxDpRank, &tle::ContextPhaseParams::setCtxDpRank)
.def_prop_rw("disagg_info_endpoint", &tle::ContextPhaseParams::getDisaggInfoEndpoint,
&tle::ContextPhaseParams::setDisaggInfoEndpoint)
.def_prop_ro("opaque_state",
[](tle::ContextPhaseParams const& self)
{

View File

@ -170,7 +170,7 @@ void initBindings(pybind11::module_& m)
.def_property_readonly("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)
.def_property(
"guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams)
.def_property_readonly("context_phase_params", &GenLlmReq::getContextPhaseParams)
.def_property("context_phase_params", &GenLlmReq::getContextPhaseParams, &GenLlmReq::setContextPhaseParams)
.def_property_readonly("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
.def_property_readonly("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest)
.def_property_readonly("is_generation_complete_state", &GenLlmReq::isGenerationCompleteState)

View File

@ -411,14 +411,16 @@ void initRequestBindings(pybind11::module_& m)
{
auto serializedState = self.getSerializedState();
return py::make_tuple(self.getFirstGenTokens(), self.getReqId(),
py::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens());
py::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens(), self.getDisaggId(),
self.getCtxDpRank(), self.getDisaggInfoEndpoint());
}
return py::make_tuple(self.getFirstGenTokens(), self.getReqId(), py::none(), self.getDraftTokens());
return py::make_tuple(self.getFirstGenTokens(), self.getReqId(), py::none(), self.getDraftTokens(),
self.getDisaggId(), self.getCtxDpRank(), self.getDisaggInfoEndpoint());
};
auto ContextPhaseParamsSetState = [](py::tuple const& state)
{
if (state.size() != 4)
if (state.size() != 7)
{
throw std::runtime_error("Invalid ContextPhaseParams state!");
}
@ -429,28 +431,44 @@ void initRequestBindings(pybind11::module_& m)
return std::make_unique<tle::ContextPhaseParams>(state[0].cast<VecTokens>(),
state[1].cast<tle::ContextPhaseParams::RequestIdType>(),
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()),
state[3].cast<std::optional<VecTokens>>());
state[3].cast<std::optional<VecTokens>>(), state[4].cast<std::optional<std::int64_t>>(),
state[5].cast<std::optional<SizeType32>>(), state[6].cast<std::optional<std::string>>());
}
return std::make_unique<tle::ContextPhaseParams>(state[0].cast<VecTokens>(),
state[1].cast<tle::ContextPhaseParams::RequestIdType>(), state[3].cast<std::optional<VecTokens>>());
state[1].cast<tle::ContextPhaseParams::RequestIdType>(), state[3].cast<std::optional<VecTokens>>(),
state[4].cast<std::optional<std::int64_t>>(), state[5].cast<std::optional<SizeType32>>(),
state[6].cast<std::optional<std::string>>());
};
py::class_<tle::ContextPhaseParams>(m, "ContextPhaseParams")
.def(py::init(
[](VecTokens const& first_gen_tokens, tle::ContextPhaseParams::RequestIdType req_id,
std::optional<py::bytes> const& opaque_state, std::optional<VecTokens> const& draft_tokens)
{
if (opaque_state)
{
auto opaque_state_str_view = std::string_view(opaque_state.value().cast<std::string_view>());
return std::make_unique<tle::ContextPhaseParams>(first_gen_tokens, req_id,
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()), draft_tokens);
}
return std::make_unique<tle::ContextPhaseParams>(first_gen_tokens, req_id, draft_tokens);
}))
.def_property_readonly("first_gen_tokens", &tle::ContextPhaseParams::getFirstGenTokens)
.def_property_readonly("draft_tokens", &tle::ContextPhaseParams::getDraftTokens)
.def_property_readonly("req_id", &tle::ContextPhaseParams::getReqId)
[](VecTokens const& first_gen_tokens, tle::ContextPhaseParams::RequestIdType req_id,
std::optional<py::bytes> const& opaque_state, std::optional<VecTokens> const& draft_tokens,
std::optional<std::int64_t> const& disagg_id, std::optional<SizeType32> const& ctx_dp_rank,
std::optional<std::string> const& disagg_info_endpoint)
{
if (opaque_state)
{
auto opaque_state_str_view = std::string_view(opaque_state.value().cast<std::string_view>());
return std::make_unique<tle::ContextPhaseParams>(first_gen_tokens, req_id,
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()),
draft_tokens, disagg_id, ctx_dp_rank, disagg_info_endpoint);
}
return std::make_unique<tle::ContextPhaseParams>(
first_gen_tokens, req_id, draft_tokens, disagg_id, ctx_dp_rank, disagg_info_endpoint);
}),
py::arg("first_gen_tokens"), py::arg("req_id"), py::arg("opaque_state") = py::none(),
py::arg("draft_tokens") = py::none(), py::arg("disagg_id") = py::none(),
py::arg("ctx_dp_rank") = py::none(), py::arg("disagg_info_endpoint") = py::none())
.def_property("first_gen_tokens", &tle::ContextPhaseParams::getFirstGenTokens,
&tle::ContextPhaseParams::setFirstGenTokens)
.def_property(
"draft_tokens", &tle::ContextPhaseParams::getDraftTokens, &tle::ContextPhaseParams::setDraftTokens)
.def_property("req_id", &tle::ContextPhaseParams::getReqId, &tle::ContextPhaseParams::setReqId)
.def_property("disagg_id", &tle::ContextPhaseParams::getDisaggId, &tle::ContextPhaseParams::setDisaggId)
.def_property("ctx_dp_rank", &tle::ContextPhaseParams::getCtxDpRank, &tle::ContextPhaseParams::setCtxDpRank)
.def_property("disagg_info_endpoint", &tle::ContextPhaseParams::getDisaggInfoEndpoint,
&tle::ContextPhaseParams::setDisaggInfoEndpoint)
.def_property_readonly("opaque_state",
[](tle::ContextPhaseParams const& self)
{

View File

@ -465,10 +465,13 @@ class ExecutorRequestQueue:
new_requests, "py_scheduling_params")
py_num_logprobs = self._collect_py_objects_from_requests(
new_requests, "py_num_logprobs")
py_disaggregated_params = self._collect_py_objects_from_requests(
new_requests, "py_disaggregated_params")
py_request_objects = tuple(
filter(None, [
py_logits_post_processors, py_multimodal_data,
py_scheduling_params, py_num_logprobs
py_scheduling_params, py_num_logprobs,
py_disaggregated_params
]))
else:
py_request_objects = None

View File

@ -121,7 +121,11 @@ class BindKvCacheTransceiver(KvCacheTransceiver):
cache_transceiver_config._to_pybind())
def respond_and_send_async(self, req: LlmRequest):
return self.impl.respond_and_send_async(req)
self.impl.respond_and_send_async(req)
if (req.py_disaggregated_params is not None
and req.py_disaggregated_params.disagg_id is not None):
req.context_phase_params.disagg_id = req.py_disaggregated_params.disagg_id
return
def request_and_receive_sync(self, req: LlmRequest):
return self.impl.request_and_receive_sync(req)

View File

@ -566,6 +566,8 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
# currently, keep py_stop_words_list as python list, rather than tensor.
self.py_stop_words_list = stop_words_list
self.py_disaggregated_params = None
self.py_result = PyResult(
prompt_len=self.py_prompt_len,
max_new_tokens=self.py_max_new_tokens,
@ -826,6 +828,10 @@ def executor_request_to_llm_request(
py_multimodal_data=getattr(executor_request, "py_multimodal_data",
None),
kv_cache_retention_config=executor_request.kv_cache_retention_config)
llm_request.py_disaggregated_params = getattr(executor_request,
"py_disaggregated_params",
None)
if child_req_ids:
for child_id in child_req_ids:
llm_request.create_child_request(child_id)

View File

@ -32,7 +32,9 @@ class DisaggregatedParams:
ctx_request_id: Optional[int] = None
opaque_state: Optional[bytes] = None
draft_tokens: Optional[List[int]] = None
disagg_id: Optional[int] = None
ctx_dp_rank: Optional[int] = None
ctx_info_endpoint: Optional[List[str]] = None
# E-P Disaggregated Params
multimodal_embedding_handles: Optional[List[Dict[str, Any]]] = (
None # multimodal embedding handles should be a list of cudaIPC handles for each mm_embedding
@ -45,7 +47,13 @@ class DisaggregatedParams:
def get_context_phase_params(self) -> tllme.ContextPhaseParams:
return tllme.ContextPhaseParams(
self.first_gen_tokens, self.ctx_request_id, self.opaque_state, self.draft_tokens
self.first_gen_tokens,
self.ctx_request_id,
self.opaque_state,
self.draft_tokens,
self.disagg_id,
self.ctx_dp_rank,
self.ctx_info_endpoint,
)
def get_request_type(self) -> tllme.RequestType:

View File

@ -563,6 +563,9 @@ class BaseWorker(GenerationExecutor):
executor_request.py_num_logprobs = request.sampling_params.logprobs
executor_request.py_lora_path = py_lora_path
# here we add executor_request.py_disaggregated_params= request.disaggregated_params for python cache transceiver
if self._is_pytorch_backend and request.disaggregated_params is not None:
executor_request.py_disaggregated_params = request.disaggregated_params
if self._is_pytorch_backend and request.multimodal_params is not None:
if request.multimodal_params.multimodal_data is not None:
# NOTE: Deserialize SharedTensor handle to actual tensor

View File

@ -427,6 +427,9 @@ class GenerationResultBase:
ctx_request_id=context_phase_params.req_id,
opaque_state=context_phase_params.opaque_state,
draft_tokens=context_phase_params.draft_tokens,
disagg_id=context_phase_params.disagg_id,
ctx_dp_rank=context_phase_params.ctx_dp_rank,
ctx_info_endpoint=context_phase_params.disagg_info_endpoint,
multimodal_embedding_handles=None,
)

View File

@ -15,6 +15,7 @@
import asyncio
import copy
import os
import uuid
from typing import Any, Callable, Dict, Optional
from tensorrt_llm.llmapi.disagg_utils import (
@ -142,7 +143,13 @@ class OpenAIDisaggregatedService(OpenAIService):
def _get_ctx_request(self, request: UCompletionRequest) -> UCompletionRequest:
ctx_request = copy.deepcopy(request)
ctx_request.disaggregated_params = DisaggregatedParams(request_type="context_only")
unique_disagg_id = (
uuid.uuid4().int & 0x7FFFFFFFFFFFFFFF
) # Generate positive int64 from uuid
ctx_request.disaggregated_params = DisaggregatedParams(
request_type="context_only",
disagg_id=unique_disagg_id,
)
ctx_request.stream = False
ctx_request.stream_options = None
return ctx_request
@ -304,4 +311,8 @@ class OpenAIDisaggregatedService(OpenAIService):
raise ValueError("Context server did not return disaggregated params")
if ctx_response.choices[0].disaggregated_params.ctx_request_id is None:
raise ValueError("Invalid disaggregated params in context phase response.")
if ctx_response.choices[0].disaggregated_params.disagg_id is None:
raise ValueError(
"Invalid disaggregated params in context phase response. disagg_id is None"
)
return ctx_response

View File

@ -117,6 +117,9 @@ class DisaggregatedParams(OpenAIBaseModel):
ctx_request_id: Optional[int] = None
encoded_opaque_state: Optional[str] = None
draft_tokens: Optional[List[int]] = None
disagg_id: Optional[int] = None
ctx_dp_rank: Optional[int] = None
ctx_info_endpoint: Optional[str] = None
class ErrorResponse(OpenAIBaseModel):
@ -1000,7 +1003,10 @@ def to_disaggregated_params(
ctx_request_id=tllm_disagg_params.ctx_request_id,
encoded_opaque_state=encode_opaque_state(
tllm_disagg_params.opaque_state),
draft_tokens=tllm_disagg_params.draft_tokens)
draft_tokens=tllm_disagg_params.draft_tokens,
disagg_id=tllm_disagg_params.disagg_id,
ctx_dp_rank=tllm_disagg_params.ctx_dp_rank,
ctx_info_endpoint=tllm_disagg_params.ctx_info_endpoint)
def to_llm_disaggregated_params(
@ -1013,7 +1019,10 @@ def to_llm_disaggregated_params(
ctx_request_id=disaggregated_params.ctx_request_id,
opaque_state=decode_opaque_state(
disaggregated_params.encoded_opaque_state),
draft_tokens=disaggregated_params.draft_tokens)
draft_tokens=disaggregated_params.draft_tokens,
disagg_id=disaggregated_params.disagg_id,
ctx_dp_rank=disaggregated_params.ctx_dp_rank,
ctx_info_endpoint=disaggregated_params.ctx_info_endpoint)
UCompletionRequest = Union[CompletionRequest, ChatCompletionRequest]

View File

@ -1198,9 +1198,9 @@ def test_result_pickle():
result.sequence_index = 1
result.is_sequence_final = True
result.decoding_iter = 1
result.context_phase_params = trtllm.ContextPhaseParams([1, 2], 123,
bytes([0, 1]),
[10, 20, 30])
result.context_phase_params = trtllm.ContextPhaseParams(
[1, 2], 123, bytes([0, 1]), [10, 20, 30], 13579, 1,
"disagg_info_endpoint_24680")
result.request_perf_metrics = trtllm.RequestPerfMetrics()
result.request_perf_metrics.last_iter = 33
result_str = pickle.dumps(result)
@ -1220,6 +1220,9 @@ def test_result_pickle():
assert result.context_phase_params.first_gen_tokens == result_copy.context_phase_params.first_gen_tokens
assert result.context_phase_params.draft_tokens == result_copy.context_phase_params.draft_tokens
assert result.context_phase_params.opaque_state == result_copy.context_phase_params.opaque_state
assert result.context_phase_params.disagg_id == result_copy.context_phase_params.disagg_id
assert result.context_phase_params.ctx_dp_rank == result_copy.context_phase_params.ctx_dp_rank
assert result.context_phase_params.disagg_info_endpoint == result_copy.context_phase_params.disagg_info_endpoint
assert result.request_perf_metrics.last_iter == result_copy.request_perf_metrics.last_iter