mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
change context params and disagg params
Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
This commit is contained in:
parent
7e88212d24
commit
ead4fc3336
@ -442,11 +442,16 @@ class ContextPhaseParams
|
||||
public:
|
||||
using RequestIdType = std::uint64_t;
|
||||
|
||||
ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, std::optional<VecTokens> draftTokens);
|
||||
ContextPhaseParams(
|
||||
VecTokens firstGenTokens, RequestIdType reqId, void* state, std::optional<VecTokens> draftTokens);
|
||||
ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, std::optional<VecTokens> draftTokens,
|
||||
std::optional<std::int64_t> disaggId = std::nullopt, std::optional<SizeType32> ctxDpRank = std::nullopt,
|
||||
std::optional<std::string> disaggInfoEndpoint = std::nullopt);
|
||||
ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, void* state, std::optional<VecTokens> draftTokens,
|
||||
std::optional<std::int64_t> disaggId = std::nullopt, std::optional<SizeType32> ctxDpRank = std::nullopt,
|
||||
std::optional<std::string> disaggInfoEndpoint = std::nullopt);
|
||||
ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, std::vector<char> const& serializedState,
|
||||
std::optional<VecTokens> draftTokens);
|
||||
std::optional<VecTokens> draftTokens, std::optional<std::int64_t> disaggId = std::nullopt,
|
||||
std::optional<SizeType32> ctxDpRank = std::nullopt,
|
||||
std::optional<std::string> disaggInfoEndpoint = std::nullopt);
|
||||
|
||||
ContextPhaseParams(ContextPhaseParams const&);
|
||||
ContextPhaseParams(ContextPhaseParams&&) noexcept;
|
||||
@ -457,15 +462,24 @@ public:
|
||||
[[nodiscard]] bool operator==(ContextPhaseParams const&) const noexcept;
|
||||
|
||||
[[nodiscard]] VecTokens const& getFirstGenTokens() const& noexcept;
|
||||
void setFirstGenTokens(VecTokens firstGenTokens) noexcept;
|
||||
[[nodiscard]] std::optional<VecTokens> const& getDraftTokens() const& noexcept;
|
||||
void setDraftTokens(std::optional<VecTokens> draftTokens) noexcept;
|
||||
[[nodiscard]] VecTokens popFirstGenTokens() && noexcept;
|
||||
[[nodiscard]] RequestIdType getReqId() const noexcept;
|
||||
|
||||
void setReqId(RequestIdType reqId) noexcept;
|
||||
[[nodiscard]] void const* getState() const noexcept;
|
||||
[[nodiscard]] void* getState() noexcept;
|
||||
[[nodiscard]] void* releaseState() noexcept;
|
||||
[[nodiscard]] std::vector<char> getSerializedState() const noexcept;
|
||||
|
||||
[[nodiscard]] std::optional<std::int64_t> getDisaggId() const noexcept;
|
||||
void setDisaggId(std::optional<std::int64_t> disaggId) noexcept;
|
||||
[[nodiscard]] std::optional<SizeType32> getCtxDpRank() const noexcept;
|
||||
void setCtxDpRank(std::optional<SizeType32> ctxDpRank) noexcept;
|
||||
[[nodiscard]] std::optional<std::string> const& getDisaggInfoEndpoint() const noexcept;
|
||||
void setDisaggInfoEndpoint(std::optional<std::string> disaggInfoEndpoint) noexcept;
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
static void deleter(void const* data);
|
||||
@ -482,6 +496,15 @@ private:
|
||||
|
||||
/// @brief The draft tokens generated by context executor
|
||||
std::optional<VecTokens> mDraftTokens;
|
||||
|
||||
/// @brief The disaggregated id
|
||||
std::optional<std::int64_t> mDisaggId;
|
||||
|
||||
/// @brief The context phase data parallel rank
|
||||
std::optional<SizeType32> mCtxDpRank;
|
||||
|
||||
/// @brief The disaggregated info endpoint
|
||||
std::optional<std::string> mDisaggInfoEndpoint;
|
||||
};
|
||||
|
||||
/// @brief Configuration for speculative decoding (both draft and target models)
|
||||
|
||||
@ -99,13 +99,16 @@ std::optional<executor::Result> LlmRequest::createResult(bool useFastLogits, int
|
||||
}
|
||||
if (!hasDraftTokens())
|
||||
{
|
||||
result.contextPhaseParams = executor::ContextPhaseParams{
|
||||
std::move(firstGenTokens), mRequestId, mContextPhaseParams.value().releaseState(), std::nullopt};
|
||||
result.contextPhaseParams = executor::ContextPhaseParams{std::move(firstGenTokens), mRequestId,
|
||||
mContextPhaseParams.value().releaseState(), std::nullopt, mContextPhaseParams.value().getDisaggId(),
|
||||
mContextPhaseParams.value().getCtxDpRank(), mContextPhaseParams.value().getDisaggInfoEndpoint()};
|
||||
}
|
||||
else
|
||||
{
|
||||
result.contextPhaseParams = executor::ContextPhaseParams{
|
||||
std::move(firstGenTokens), mRequestId, mContextPhaseParams.value().releaseState(), *getDraftTokens()};
|
||||
result.contextPhaseParams = executor::ContextPhaseParams{std::move(firstGenTokens), mRequestId,
|
||||
mContextPhaseParams.value().releaseState(), *getDraftTokens(),
|
||||
mContextPhaseParams.value().getDisaggId(), mContextPhaseParams.value().getCtxDpRank(),
|
||||
mContextPhaseParams.value().getDisaggInfoEndpoint()};
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -27,28 +27,41 @@ namespace su = tensorrt_llm::executor::serialize_utils;
|
||||
namespace tensorrt_llm::executor
|
||||
{
|
||||
|
||||
ContextPhaseParams::ContextPhaseParams(
|
||||
VecTokens firstGenTokens, RequestIdType reqId, void* state, std::optional<VecTokens> draftTokens)
|
||||
ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, void* state,
|
||||
std::optional<VecTokens> draftTokens, std::optional<std::int64_t> disaggId, std::optional<SizeType32> ctxDpRank,
|
||||
std::optional<std::string> disaggInfoEndpoint)
|
||||
: mReqId{reqId}
|
||||
, mFirstGenTokens{std::move(firstGenTokens)}
|
||||
, mState{StatePtr{state, deleter}}
|
||||
, mDraftTokens{std::move(draftTokens)}
|
||||
{
|
||||
}
|
||||
|
||||
ContextPhaseParams::ContextPhaseParams(
|
||||
VecTokens firstGenTokens, RequestIdType reqId, std::optional<VecTokens> draftTokens)
|
||||
: mReqId{reqId}
|
||||
, mFirstGenTokens{std::move(firstGenTokens)}
|
||||
, mDraftTokens{std::move(draftTokens)}
|
||||
, mDisaggId{std::move(disaggId)}
|
||||
, mCtxDpRank{ctxDpRank}
|
||||
, mDisaggInfoEndpoint{std::move(disaggInfoEndpoint)}
|
||||
{
|
||||
}
|
||||
|
||||
ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId,
|
||||
std::vector<char> const& serializedState, std::optional<VecTokens> draftTokens)
|
||||
std::optional<VecTokens> draftTokens, std::optional<std::int64_t> disaggId, std::optional<SizeType32> ctxDpRank,
|
||||
std::optional<std::string> disaggInfoEndpoint)
|
||||
: mReqId{reqId}
|
||||
, mFirstGenTokens{std::move(firstGenTokens)}
|
||||
, mDraftTokens{std::move(draftTokens)}
|
||||
, mDisaggId{std::move(disaggId)}
|
||||
, mCtxDpRank{ctxDpRank}
|
||||
, mDisaggInfoEndpoint{std::move(disaggInfoEndpoint)}
|
||||
{
|
||||
}
|
||||
|
||||
ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId,
|
||||
std::vector<char> const& serializedState, std::optional<VecTokens> draftTokens,
|
||||
std::optional<std::int64_t> disaggId, std::optional<SizeType32> ctxDpRank,
|
||||
std::optional<std::string> disaggInfoEndpoint)
|
||||
: mReqId{reqId}
|
||||
, mFirstGenTokens{std::move(firstGenTokens)}
|
||||
, mDraftTokens{std::move(draftTokens)}
|
||||
, mDisaggId{std::move(disaggId)}
|
||||
, mCtxDpRank{ctxDpRank}
|
||||
, mDisaggInfoEndpoint{std::move(disaggInfoEndpoint)}
|
||||
{
|
||||
|
||||
su::VectorWrapBuf<char> strbuf(const_cast<std::vector<char>&>(serializedState));
|
||||
@ -60,12 +73,15 @@ ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType r
|
||||
}
|
||||
|
||||
ContextPhaseParams::ContextPhaseParams(ContextPhaseParams const& other)
|
||||
: mReqId{other.mReqId}
|
||||
, mFirstGenTokens{other.mFirstGenTokens}
|
||||
, mDraftTokens{other.mDraftTokens}
|
||||
, mDisaggId{other.mDisaggId}
|
||||
, mCtxDpRank{other.mCtxDpRank}
|
||||
, mDisaggInfoEndpoint{other.mDisaggInfoEndpoint}
|
||||
{
|
||||
// Since the internal header files implement the destructor while using the declaration of this
|
||||
// type, a `unique_ptr` with a custom destructor member is used here.
|
||||
mReqId = other.mReqId;
|
||||
mFirstGenTokens = other.mFirstGenTokens;
|
||||
mDraftTokens = other.mDraftTokens;
|
||||
if (other.mState)
|
||||
{
|
||||
auto* otherState = static_cast<DataTransceiverState*>(other.mState.get());
|
||||
@ -90,11 +106,21 @@ VecTokens const& ContextPhaseParams::getFirstGenTokens() const& noexcept
|
||||
return mFirstGenTokens;
|
||||
}
|
||||
|
||||
void ContextPhaseParams::setFirstGenTokens(VecTokens firstGenTokens) noexcept
|
||||
{
|
||||
mFirstGenTokens = std::move(firstGenTokens);
|
||||
}
|
||||
|
||||
std::optional<VecTokens> const& ContextPhaseParams::getDraftTokens() const& noexcept
|
||||
{
|
||||
return mDraftTokens;
|
||||
}
|
||||
|
||||
void ContextPhaseParams::setDraftTokens(std::optional<VecTokens> draftTokens) noexcept
|
||||
{
|
||||
mDraftTokens = std::move(draftTokens);
|
||||
}
|
||||
|
||||
VecTokens ContextPhaseParams::popFirstGenTokens() && noexcept
|
||||
{
|
||||
return std::move(mFirstGenTokens);
|
||||
@ -105,6 +131,11 @@ ContextPhaseParams::RequestIdType ContextPhaseParams::getReqId() const noexcept
|
||||
return mReqId;
|
||||
}
|
||||
|
||||
void ContextPhaseParams::setReqId(RequestIdType reqId) noexcept
|
||||
{
|
||||
mReqId = reqId;
|
||||
}
|
||||
|
||||
void const* ContextPhaseParams::getState() const noexcept
|
||||
{
|
||||
return mState.get();
|
||||
@ -125,6 +156,36 @@ void* ContextPhaseParams::releaseState() noexcept
|
||||
return mState.release();
|
||||
}
|
||||
|
||||
std::optional<std::int64_t> ContextPhaseParams::getDisaggId() const noexcept
|
||||
{
|
||||
return mDisaggId;
|
||||
}
|
||||
|
||||
void ContextPhaseParams::setDisaggId(std::optional<std::int64_t> disaggId) noexcept
|
||||
{
|
||||
mDisaggId = disaggId;
|
||||
}
|
||||
|
||||
std::optional<SizeType32> ContextPhaseParams::getCtxDpRank() const noexcept
|
||||
{
|
||||
return mCtxDpRank;
|
||||
}
|
||||
|
||||
void ContextPhaseParams::setCtxDpRank(std::optional<SizeType32> ctxDpRank) noexcept
|
||||
{
|
||||
mCtxDpRank = ctxDpRank;
|
||||
}
|
||||
|
||||
std::optional<std::string> const& ContextPhaseParams::getDisaggInfoEndpoint() const noexcept
|
||||
{
|
||||
return mDisaggInfoEndpoint;
|
||||
}
|
||||
|
||||
void ContextPhaseParams::setDisaggInfoEndpoint(std::optional<std::string> disaggInfoEndpoint) noexcept
|
||||
{
|
||||
mDisaggInfoEndpoint = std::move(disaggInfoEndpoint);
|
||||
}
|
||||
|
||||
void ContextPhaseParams::deleter(void const* data)
|
||||
{
|
||||
using StateT = DataTransceiverState const;
|
||||
@ -134,7 +195,8 @@ void ContextPhaseParams::deleter(void const* data)
|
||||
bool ContextPhaseParams::operator==(ContextPhaseParams const& other) const noexcept
|
||||
{
|
||||
if (mFirstGenTokens != other.mFirstGenTokens || mReqId != other.mReqId || mDraftTokens != other.mDraftTokens
|
||||
|| static_cast<bool>(mState) != static_cast<bool>(other.mState))
|
||||
|| mDisaggId != other.mDisaggId || mDisaggInfoEndpoint != other.mDisaggInfoEndpoint
|
||||
|| mCtxDpRank != other.mCtxDpRank || static_cast<bool>(mState) != static_cast<bool>(other.mState))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -652,14 +652,19 @@ ContextPhaseParams Serialization::deserializeContextPhaseParams(std::istream& is
|
||||
auto reqId = su::deserialize<decltype(ContextPhaseParams::mReqId)>(is);
|
||||
auto firstGenTokens = su::deserialize<decltype(ContextPhaseParams::mFirstGenTokens)>(is);
|
||||
auto draftTokens = su::deserialize<decltype(ContextPhaseParams::mDraftTokens)>(is);
|
||||
auto disaggId = su::deserialize<decltype(ContextPhaseParams::mDisaggId)>(is);
|
||||
auto ctxDpRank = su::deserialize<decltype(ContextPhaseParams::mCtxDpRank)>(is);
|
||||
auto disaggInfoEndpoint = su::deserialize<decltype(ContextPhaseParams::mDisaggInfoEndpoint)>(is);
|
||||
auto hasState = su::deserialize<bool>(is);
|
||||
if (hasState)
|
||||
{
|
||||
auto state = std::make_unique<DataTransceiverState>();
|
||||
*state = deserializeDataTransceiverState(is);
|
||||
return ContextPhaseParams{std::move(firstGenTokens), reqId, state.release(), std::move(draftTokens)};
|
||||
return ContextPhaseParams{std::move(firstGenTokens), reqId, state.release(), std::move(draftTokens),
|
||||
std::move(disaggId), ctxDpRank, std::move(disaggInfoEndpoint)};
|
||||
}
|
||||
return ContextPhaseParams{std::move(firstGenTokens), reqId, nullptr, std::move(draftTokens)};
|
||||
return ContextPhaseParams{std::move(firstGenTokens), reqId, std::move(draftTokens), std::move(disaggId), ctxDpRank,
|
||||
std::move(disaggInfoEndpoint)};
|
||||
}
|
||||
|
||||
void Serialization::serialize(ContextPhaseParams const& contextPhaseParams, std::ostream& os)
|
||||
@ -667,6 +672,9 @@ void Serialization::serialize(ContextPhaseParams const& contextPhaseParams, std:
|
||||
su::serialize(contextPhaseParams.mReqId, os);
|
||||
su::serialize(contextPhaseParams.mFirstGenTokens, os);
|
||||
su::serialize(contextPhaseParams.mDraftTokens, os);
|
||||
su::serialize(contextPhaseParams.mDisaggId, os);
|
||||
su::serialize(contextPhaseParams.mCtxDpRank, os);
|
||||
su::serialize(contextPhaseParams.mDisaggInfoEndpoint, os);
|
||||
su::serialize(static_cast<bool>(contextPhaseParams.mState), os);
|
||||
if (contextPhaseParams.mState)
|
||||
{
|
||||
@ -680,6 +688,9 @@ size_t Serialization::serializedSize(ContextPhaseParams const& contextPhaseParam
|
||||
totalSize += su::serializedSize(contextPhaseParams.mReqId);
|
||||
totalSize += su::serializedSize(contextPhaseParams.mFirstGenTokens);
|
||||
totalSize += su::serializedSize(contextPhaseParams.mDraftTokens);
|
||||
totalSize += su::serializedSize(contextPhaseParams.mDisaggId);
|
||||
totalSize += su::serializedSize(contextPhaseParams.mCtxDpRank);
|
||||
totalSize += su::serializedSize(contextPhaseParams.mDisaggInfoEndpoint);
|
||||
totalSize += su::serializedSize(bool{});
|
||||
if (contextPhaseParams.mState)
|
||||
{
|
||||
|
||||
@ -165,7 +165,7 @@ void initBindings(nb::module_& m)
|
||||
"context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
|
||||
.def_prop_ro("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)
|
||||
.def_prop_rw("guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams)
|
||||
.def_prop_ro("context_phase_params", &GenLlmReq::getContextPhaseParams)
|
||||
.def_prop_rw("context_phase_params", &GenLlmReq::getContextPhaseParams, &GenLlmReq::setContextPhaseParams)
|
||||
.def_prop_ro("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
|
||||
.def_prop_ro("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest)
|
||||
.def_prop_ro("is_generation_complete_state", &GenLlmReq::isGenerationCompleteState)
|
||||
|
||||
@ -442,14 +442,16 @@ void initRequestBindings(nb::module_& m)
|
||||
{
|
||||
auto serializedState = self.getSerializedState();
|
||||
return nb::make_tuple(self.getFirstGenTokens(), self.getReqId(),
|
||||
nb::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens());
|
||||
nb::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens(), self.getDisaggId(),
|
||||
self.getCtxDpRank(), self.getDisaggInfoEndpoint());
|
||||
}
|
||||
return nb::make_tuple(self.getFirstGenTokens(), self.getReqId(), nb::none(), self.getDraftTokens());
|
||||
return nb::make_tuple(self.getFirstGenTokens(), self.getReqId(), nb::none(), self.getDraftTokens(),
|
||||
self.getDisaggId(), self.getCtxDpRank(), self.getDisaggInfoEndpoint());
|
||||
};
|
||||
|
||||
auto ContextPhaseParamsSetState = [](tle::ContextPhaseParams& contextPhaseParams, nb::tuple const& state)
|
||||
{
|
||||
if (state.size() != 4)
|
||||
if (state.size() != 7)
|
||||
{
|
||||
throw std::runtime_error("Invalid ContextPhaseParams state!");
|
||||
}
|
||||
@ -460,13 +462,15 @@ void initRequestBindings(nb::module_& m)
|
||||
new (&contextPhaseParams) tle::ContextPhaseParams(nb::cast<VecTokens>(state[0]),
|
||||
nb::cast<tle::ContextPhaseParams::RequestIdType>(state[1]),
|
||||
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()),
|
||||
nb::cast<std::optional<VecTokens>>(state[3]));
|
||||
nb::cast<std::optional<VecTokens>>(state[3]), nb::cast<std::optional<std::int64_t>>(state[4]),
|
||||
nb::cast<std::optional<SizeType32>>(state[5]), nb::cast<std::optional<std::string>>(state[6]));
|
||||
}
|
||||
else
|
||||
{
|
||||
new (&contextPhaseParams) tle::ContextPhaseParams(nb::cast<VecTokens>(state[0]),
|
||||
nb::cast<tle::ContextPhaseParams::RequestIdType>(state[1]),
|
||||
nb::cast<std::optional<VecTokens>>(state[3]));
|
||||
nb::cast<std::optional<VecTokens>>(state[3]), nb::cast<std::optional<std::int64_t>>(state[4]),
|
||||
nb::cast<std::optional<SizeType32>>(state[5]), nb::cast<std::optional<std::string>>(state[6]));
|
||||
}
|
||||
};
|
||||
|
||||
@ -475,25 +479,37 @@ void initRequestBindings(nb::module_& m)
|
||||
"__init__",
|
||||
[](tle::ContextPhaseParams& self, VecTokens const& first_gen_tokens,
|
||||
tle::ContextPhaseParams::RequestIdType req_id, std::optional<nb::bytes> const& opaque_state,
|
||||
std::optional<VecTokens> const& draft_tokens)
|
||||
std::optional<VecTokens> const& draft_tokens, std::optional<std::int64_t> const& disagg_id,
|
||||
std::optional<SizeType32> const& ctx_dp_rank, std::optional<std::string> const& disagg_info_endpoint)
|
||||
{
|
||||
if (opaque_state)
|
||||
{
|
||||
auto opaque_state_str_view
|
||||
= std::string_view(opaque_state.value().c_str(), opaque_state.value().size());
|
||||
new (&self) tle::ContextPhaseParams(first_gen_tokens, req_id,
|
||||
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()), draft_tokens);
|
||||
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()), draft_tokens,
|
||||
disagg_id, ctx_dp_rank, disagg_info_endpoint);
|
||||
}
|
||||
else
|
||||
{
|
||||
new (&self) tle::ContextPhaseParams(first_gen_tokens, req_id, draft_tokens);
|
||||
new (&self) tle::ContextPhaseParams(
|
||||
first_gen_tokens, req_id, draft_tokens, disagg_id, ctx_dp_rank, disagg_info_endpoint);
|
||||
}
|
||||
},
|
||||
nb::arg("first_gen_tokens"), nb::arg("req_id"), nb::arg("opaque_state").none(),
|
||||
nb::arg("draft_tokens").none())
|
||||
.def_prop_ro("first_gen_tokens", [](tle::ContextPhaseParams const& self) { return self.getFirstGenTokens(); })
|
||||
.def_prop_ro("draft_tokens", [](tle::ContextPhaseParams const& self) { return self.getDraftTokens(); })
|
||||
.def_prop_ro("req_id", &tle::ContextPhaseParams::getReqId)
|
||||
nb::arg("draft_tokens").none(), nb::arg("disagg_id").none(), nb::arg("ctx_dp_rank").none(),
|
||||
nb::arg("disagg_info_endpoint").none())
|
||||
.def_prop_rw(
|
||||
"first_gen_tokens", [](tle::ContextPhaseParams const& self) { return self.getFirstGenTokens(); },
|
||||
[](tle::ContextPhaseParams& self, VecTokens const& tokens) { self.setFirstGenTokens(tokens); })
|
||||
.def_prop_rw(
|
||||
"draft_tokens", [](tle::ContextPhaseParams const& self) { return self.getDraftTokens(); },
|
||||
[](tle::ContextPhaseParams& self, std::optional<VecTokens> const& tokens) { self.setDraftTokens(tokens); })
|
||||
.def_prop_rw("req_id", &tle::ContextPhaseParams::getReqId, &tle::ContextPhaseParams::setReqId)
|
||||
.def_prop_rw("disagg_id", &tle::ContextPhaseParams::getDisaggId, &tle::ContextPhaseParams::setDisaggId)
|
||||
.def_prop_rw("ctx_dp_rank", &tle::ContextPhaseParams::getCtxDpRank, &tle::ContextPhaseParams::setCtxDpRank)
|
||||
.def_prop_rw("disagg_info_endpoint", &tle::ContextPhaseParams::getDisaggInfoEndpoint,
|
||||
&tle::ContextPhaseParams::setDisaggInfoEndpoint)
|
||||
.def_prop_ro("opaque_state",
|
||||
[](tle::ContextPhaseParams const& self)
|
||||
{
|
||||
|
||||
@ -170,7 +170,7 @@ void initBindings(pybind11::module_& m)
|
||||
.def_property_readonly("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)
|
||||
.def_property(
|
||||
"guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams)
|
||||
.def_property_readonly("context_phase_params", &GenLlmReq::getContextPhaseParams)
|
||||
.def_property("context_phase_params", &GenLlmReq::getContextPhaseParams, &GenLlmReq::setContextPhaseParams)
|
||||
.def_property_readonly("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
|
||||
.def_property_readonly("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest)
|
||||
.def_property_readonly("is_generation_complete_state", &GenLlmReq::isGenerationCompleteState)
|
||||
|
||||
@ -411,14 +411,16 @@ void initRequestBindings(pybind11::module_& m)
|
||||
{
|
||||
auto serializedState = self.getSerializedState();
|
||||
return py::make_tuple(self.getFirstGenTokens(), self.getReqId(),
|
||||
py::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens());
|
||||
py::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens(), self.getDisaggId(),
|
||||
self.getCtxDpRank(), self.getDisaggInfoEndpoint());
|
||||
}
|
||||
return py::make_tuple(self.getFirstGenTokens(), self.getReqId(), py::none(), self.getDraftTokens());
|
||||
return py::make_tuple(self.getFirstGenTokens(), self.getReqId(), py::none(), self.getDraftTokens(),
|
||||
self.getDisaggId(), self.getCtxDpRank(), self.getDisaggInfoEndpoint());
|
||||
};
|
||||
|
||||
auto ContextPhaseParamsSetState = [](py::tuple const& state)
|
||||
{
|
||||
if (state.size() != 4)
|
||||
if (state.size() != 7)
|
||||
{
|
||||
throw std::runtime_error("Invalid ContextPhaseParams state!");
|
||||
}
|
||||
@ -429,28 +431,44 @@ void initRequestBindings(pybind11::module_& m)
|
||||
return std::make_unique<tle::ContextPhaseParams>(state[0].cast<VecTokens>(),
|
||||
state[1].cast<tle::ContextPhaseParams::RequestIdType>(),
|
||||
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()),
|
||||
state[3].cast<std::optional<VecTokens>>());
|
||||
state[3].cast<std::optional<VecTokens>>(), state[4].cast<std::optional<std::int64_t>>(),
|
||||
state[5].cast<std::optional<SizeType32>>(), state[6].cast<std::optional<std::string>>());
|
||||
}
|
||||
return std::make_unique<tle::ContextPhaseParams>(state[0].cast<VecTokens>(),
|
||||
state[1].cast<tle::ContextPhaseParams::RequestIdType>(), state[3].cast<std::optional<VecTokens>>());
|
||||
state[1].cast<tle::ContextPhaseParams::RequestIdType>(), state[3].cast<std::optional<VecTokens>>(),
|
||||
state[4].cast<std::optional<std::int64_t>>(), state[5].cast<std::optional<SizeType32>>(),
|
||||
state[6].cast<std::optional<std::string>>());
|
||||
};
|
||||
|
||||
py::class_<tle::ContextPhaseParams>(m, "ContextPhaseParams")
|
||||
.def(py::init(
|
||||
[](VecTokens const& first_gen_tokens, tle::ContextPhaseParams::RequestIdType req_id,
|
||||
std::optional<py::bytes> const& opaque_state, std::optional<VecTokens> const& draft_tokens)
|
||||
{
|
||||
if (opaque_state)
|
||||
{
|
||||
auto opaque_state_str_view = std::string_view(opaque_state.value().cast<std::string_view>());
|
||||
return std::make_unique<tle::ContextPhaseParams>(first_gen_tokens, req_id,
|
||||
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()), draft_tokens);
|
||||
}
|
||||
return std::make_unique<tle::ContextPhaseParams>(first_gen_tokens, req_id, draft_tokens);
|
||||
}))
|
||||
.def_property_readonly("first_gen_tokens", &tle::ContextPhaseParams::getFirstGenTokens)
|
||||
.def_property_readonly("draft_tokens", &tle::ContextPhaseParams::getDraftTokens)
|
||||
.def_property_readonly("req_id", &tle::ContextPhaseParams::getReqId)
|
||||
[](VecTokens const& first_gen_tokens, tle::ContextPhaseParams::RequestIdType req_id,
|
||||
std::optional<py::bytes> const& opaque_state, std::optional<VecTokens> const& draft_tokens,
|
||||
std::optional<std::int64_t> const& disagg_id, std::optional<SizeType32> const& ctx_dp_rank,
|
||||
std::optional<std::string> const& disagg_info_endpoint)
|
||||
{
|
||||
if (opaque_state)
|
||||
{
|
||||
auto opaque_state_str_view = std::string_view(opaque_state.value().cast<std::string_view>());
|
||||
return std::make_unique<tle::ContextPhaseParams>(first_gen_tokens, req_id,
|
||||
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()),
|
||||
draft_tokens, disagg_id, ctx_dp_rank, disagg_info_endpoint);
|
||||
}
|
||||
return std::make_unique<tle::ContextPhaseParams>(
|
||||
first_gen_tokens, req_id, draft_tokens, disagg_id, ctx_dp_rank, disagg_info_endpoint);
|
||||
}),
|
||||
py::arg("first_gen_tokens"), py::arg("req_id"), py::arg("opaque_state") = py::none(),
|
||||
py::arg("draft_tokens") = py::none(), py::arg("disagg_id") = py::none(),
|
||||
py::arg("ctx_dp_rank") = py::none(), py::arg("disagg_info_endpoint") = py::none())
|
||||
.def_property("first_gen_tokens", &tle::ContextPhaseParams::getFirstGenTokens,
|
||||
&tle::ContextPhaseParams::setFirstGenTokens)
|
||||
.def_property(
|
||||
"draft_tokens", &tle::ContextPhaseParams::getDraftTokens, &tle::ContextPhaseParams::setDraftTokens)
|
||||
.def_property("req_id", &tle::ContextPhaseParams::getReqId, &tle::ContextPhaseParams::setReqId)
|
||||
.def_property("disagg_id", &tle::ContextPhaseParams::getDisaggId, &tle::ContextPhaseParams::setDisaggId)
|
||||
.def_property("ctx_dp_rank", &tle::ContextPhaseParams::getCtxDpRank, &tle::ContextPhaseParams::setCtxDpRank)
|
||||
.def_property("disagg_info_endpoint", &tle::ContextPhaseParams::getDisaggInfoEndpoint,
|
||||
&tle::ContextPhaseParams::setDisaggInfoEndpoint)
|
||||
.def_property_readonly("opaque_state",
|
||||
[](tle::ContextPhaseParams const& self)
|
||||
{
|
||||
|
||||
@ -465,10 +465,13 @@ class ExecutorRequestQueue:
|
||||
new_requests, "py_scheduling_params")
|
||||
py_num_logprobs = self._collect_py_objects_from_requests(
|
||||
new_requests, "py_num_logprobs")
|
||||
py_disaggregated_params = self._collect_py_objects_from_requests(
|
||||
new_requests, "py_disaggregated_params")
|
||||
py_request_objects = tuple(
|
||||
filter(None, [
|
||||
py_logits_post_processors, py_multimodal_data,
|
||||
py_scheduling_params, py_num_logprobs
|
||||
py_scheduling_params, py_num_logprobs,
|
||||
py_disaggregated_params
|
||||
]))
|
||||
else:
|
||||
py_request_objects = None
|
||||
|
||||
@ -121,7 +121,11 @@ class BindKvCacheTransceiver(KvCacheTransceiver):
|
||||
cache_transceiver_config._to_pybind())
|
||||
|
||||
def respond_and_send_async(self, req: LlmRequest):
|
||||
return self.impl.respond_and_send_async(req)
|
||||
self.impl.respond_and_send_async(req)
|
||||
if (req.py_disaggregated_params is not None
|
||||
and req.py_disaggregated_params.disagg_id is not None):
|
||||
req.context_phase_params.disagg_id = req.py_disaggregated_params.disagg_id
|
||||
return
|
||||
|
||||
def request_and_receive_sync(self, req: LlmRequest):
|
||||
return self.impl.request_and_receive_sync(req)
|
||||
|
||||
@ -566,6 +566,8 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
|
||||
# currently, keep py_stop_words_list as python list, rather than tensor.
|
||||
self.py_stop_words_list = stop_words_list
|
||||
|
||||
self.py_disaggregated_params = None
|
||||
|
||||
self.py_result = PyResult(
|
||||
prompt_len=self.py_prompt_len,
|
||||
max_new_tokens=self.py_max_new_tokens,
|
||||
@ -826,6 +828,10 @@ def executor_request_to_llm_request(
|
||||
py_multimodal_data=getattr(executor_request, "py_multimodal_data",
|
||||
None),
|
||||
kv_cache_retention_config=executor_request.kv_cache_retention_config)
|
||||
|
||||
llm_request.py_disaggregated_params = getattr(executor_request,
|
||||
"py_disaggregated_params",
|
||||
None)
|
||||
if child_req_ids:
|
||||
for child_id in child_req_ids:
|
||||
llm_request.create_child_request(child_id)
|
||||
|
||||
@ -32,7 +32,9 @@ class DisaggregatedParams:
|
||||
ctx_request_id: Optional[int] = None
|
||||
opaque_state: Optional[bytes] = None
|
||||
draft_tokens: Optional[List[int]] = None
|
||||
|
||||
disagg_id: Optional[int] = None
|
||||
ctx_dp_rank: Optional[int] = None
|
||||
ctx_info_endpoint: Optional[List[str]] = None
|
||||
# E-P Disaggregated Params
|
||||
multimodal_embedding_handles: Optional[List[Dict[str, Any]]] = (
|
||||
None # multimodal embedding handles should be a list of cudaIPC handles for each mm_embedding
|
||||
@ -45,7 +47,13 @@ class DisaggregatedParams:
|
||||
|
||||
def get_context_phase_params(self) -> tllme.ContextPhaseParams:
|
||||
return tllme.ContextPhaseParams(
|
||||
self.first_gen_tokens, self.ctx_request_id, self.opaque_state, self.draft_tokens
|
||||
self.first_gen_tokens,
|
||||
self.ctx_request_id,
|
||||
self.opaque_state,
|
||||
self.draft_tokens,
|
||||
self.disagg_id,
|
||||
self.ctx_dp_rank,
|
||||
self.ctx_info_endpoint,
|
||||
)
|
||||
|
||||
def get_request_type(self) -> tllme.RequestType:
|
||||
|
||||
@ -563,6 +563,9 @@ class BaseWorker(GenerationExecutor):
|
||||
executor_request.py_num_logprobs = request.sampling_params.logprobs
|
||||
executor_request.py_lora_path = py_lora_path
|
||||
|
||||
# here we add executor_request.py_disaggregated_params= request.disaggregated_params for python cache transceiver
|
||||
if self._is_pytorch_backend and request.disaggregated_params is not None:
|
||||
executor_request.py_disaggregated_params = request.disaggregated_params
|
||||
if self._is_pytorch_backend and request.multimodal_params is not None:
|
||||
if request.multimodal_params.multimodal_data is not None:
|
||||
# NOTE: Deserialize SharedTensor handle to actual tensor
|
||||
|
||||
@ -427,6 +427,9 @@ class GenerationResultBase:
|
||||
ctx_request_id=context_phase_params.req_id,
|
||||
opaque_state=context_phase_params.opaque_state,
|
||||
draft_tokens=context_phase_params.draft_tokens,
|
||||
disagg_id=context_phase_params.disagg_id,
|
||||
ctx_dp_rank=context_phase_params.ctx_dp_rank,
|
||||
ctx_info_endpoint=context_phase_params.disagg_info_endpoint,
|
||||
multimodal_embedding_handles=None,
|
||||
)
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from tensorrt_llm.llmapi.disagg_utils import (
|
||||
@ -142,7 +143,13 @@ class OpenAIDisaggregatedService(OpenAIService):
|
||||
|
||||
def _get_ctx_request(self, request: UCompletionRequest) -> UCompletionRequest:
|
||||
ctx_request = copy.deepcopy(request)
|
||||
ctx_request.disaggregated_params = DisaggregatedParams(request_type="context_only")
|
||||
unique_disagg_id = (
|
||||
uuid.uuid4().int & 0x7FFFFFFFFFFFFFFF
|
||||
) # Generate positive int64 from uuid
|
||||
ctx_request.disaggregated_params = DisaggregatedParams(
|
||||
request_type="context_only",
|
||||
disagg_id=unique_disagg_id,
|
||||
)
|
||||
ctx_request.stream = False
|
||||
ctx_request.stream_options = None
|
||||
return ctx_request
|
||||
@ -304,4 +311,8 @@ class OpenAIDisaggregatedService(OpenAIService):
|
||||
raise ValueError("Context server did not return disaggregated params")
|
||||
if ctx_response.choices[0].disaggregated_params.ctx_request_id is None:
|
||||
raise ValueError("Invalid disaggregated params in context phase response.")
|
||||
if ctx_response.choices[0].disaggregated_params.disagg_id is None:
|
||||
raise ValueError(
|
||||
"Invalid disaggregated params in context phase response. disagg_id is None"
|
||||
)
|
||||
return ctx_response
|
||||
|
||||
@ -117,6 +117,9 @@ class DisaggregatedParams(OpenAIBaseModel):
|
||||
ctx_request_id: Optional[int] = None
|
||||
encoded_opaque_state: Optional[str] = None
|
||||
draft_tokens: Optional[List[int]] = None
|
||||
disagg_id: Optional[int] = None
|
||||
ctx_dp_rank: Optional[int] = None
|
||||
ctx_info_endpoint: Optional[str] = None
|
||||
|
||||
|
||||
class ErrorResponse(OpenAIBaseModel):
|
||||
@ -1000,7 +1003,10 @@ def to_disaggregated_params(
|
||||
ctx_request_id=tllm_disagg_params.ctx_request_id,
|
||||
encoded_opaque_state=encode_opaque_state(
|
||||
tllm_disagg_params.opaque_state),
|
||||
draft_tokens=tllm_disagg_params.draft_tokens)
|
||||
draft_tokens=tllm_disagg_params.draft_tokens,
|
||||
disagg_id=tllm_disagg_params.disagg_id,
|
||||
ctx_dp_rank=tllm_disagg_params.ctx_dp_rank,
|
||||
ctx_info_endpoint=tllm_disagg_params.ctx_info_endpoint)
|
||||
|
||||
|
||||
def to_llm_disaggregated_params(
|
||||
@ -1013,7 +1019,10 @@ def to_llm_disaggregated_params(
|
||||
ctx_request_id=disaggregated_params.ctx_request_id,
|
||||
opaque_state=decode_opaque_state(
|
||||
disaggregated_params.encoded_opaque_state),
|
||||
draft_tokens=disaggregated_params.draft_tokens)
|
||||
draft_tokens=disaggregated_params.draft_tokens,
|
||||
disagg_id=disaggregated_params.disagg_id,
|
||||
ctx_dp_rank=disaggregated_params.ctx_dp_rank,
|
||||
ctx_info_endpoint=disaggregated_params.ctx_info_endpoint)
|
||||
|
||||
|
||||
UCompletionRequest = Union[CompletionRequest, ChatCompletionRequest]
|
||||
|
||||
@ -1198,9 +1198,9 @@ def test_result_pickle():
|
||||
result.sequence_index = 1
|
||||
result.is_sequence_final = True
|
||||
result.decoding_iter = 1
|
||||
result.context_phase_params = trtllm.ContextPhaseParams([1, 2], 123,
|
||||
bytes([0, 1]),
|
||||
[10, 20, 30])
|
||||
result.context_phase_params = trtllm.ContextPhaseParams(
|
||||
[1, 2], 123, bytes([0, 1]), [10, 20, 30], 13579, 1,
|
||||
"disagg_info_endpoint_24680")
|
||||
result.request_perf_metrics = trtllm.RequestPerfMetrics()
|
||||
result.request_perf_metrics.last_iter = 33
|
||||
result_str = pickle.dumps(result)
|
||||
@ -1220,6 +1220,9 @@ def test_result_pickle():
|
||||
assert result.context_phase_params.first_gen_tokens == result_copy.context_phase_params.first_gen_tokens
|
||||
assert result.context_phase_params.draft_tokens == result_copy.context_phase_params.draft_tokens
|
||||
assert result.context_phase_params.opaque_state == result_copy.context_phase_params.opaque_state
|
||||
assert result.context_phase_params.disagg_id == result_copy.context_phase_params.disagg_id
|
||||
assert result.context_phase_params.ctx_dp_rank == result_copy.context_phase_params.ctx_dp_rank
|
||||
assert result.context_phase_params.disagg_info_endpoint == result_copy.context_phase_params.disagg_info_endpoint
|
||||
assert result.request_perf_metrics.last_iter == result_copy.request_perf_metrics.last_iter
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user