mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-04 18:21:52 +08:00
[TRTLLM-9527][feat] change context params and disagg params (step3) (#10495)
Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
This commit is contained in:
parent
fae4985797
commit
d6f76d2fae
@ -442,11 +442,15 @@ class ContextPhaseParams
|
||||
public:
|
||||
using RequestIdType = std::uint64_t;
|
||||
|
||||
ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, std::optional<VecTokens> draftTokens);
|
||||
ContextPhaseParams(
|
||||
VecTokens firstGenTokens, RequestIdType reqId, void* state, std::optional<VecTokens> draftTokens);
|
||||
ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, std::optional<VecTokens> draftTokens,
|
||||
std::optional<SizeType32> ctxDpRank = std::nullopt,
|
||||
std::optional<std::string> disaggInfoEndpoint = std::nullopt);
|
||||
ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, void* state, std::optional<VecTokens> draftTokens,
|
||||
std::optional<SizeType32> ctxDpRank = std::nullopt,
|
||||
std::optional<std::string> disaggInfoEndpoint = std::nullopt);
|
||||
ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, std::vector<char> const& serializedState,
|
||||
std::optional<VecTokens> draftTokens);
|
||||
std::optional<VecTokens> draftTokens, std::optional<SizeType32> ctxDpRank = std::nullopt,
|
||||
std::optional<std::string> disaggInfoEndpoint = std::nullopt);
|
||||
|
||||
ContextPhaseParams(ContextPhaseParams const&);
|
||||
ContextPhaseParams(ContextPhaseParams&&) noexcept;
|
||||
@ -457,15 +461,22 @@ public:
|
||||
[[nodiscard]] bool operator==(ContextPhaseParams const&) const noexcept;
|
||||
|
||||
[[nodiscard]] VecTokens const& getFirstGenTokens() const& noexcept;
|
||||
void setFirstGenTokens(VecTokens const& firstGenTokens) noexcept;
|
||||
[[nodiscard]] std::optional<VecTokens> const& getDraftTokens() const& noexcept;
|
||||
void setDraftTokens(std::optional<VecTokens> const& draftTokens) noexcept;
|
||||
[[nodiscard]] VecTokens popFirstGenTokens() && noexcept;
|
||||
[[nodiscard]] RequestIdType getReqId() const noexcept;
|
||||
|
||||
void setReqId(RequestIdType const& reqId) noexcept;
|
||||
[[nodiscard]] void const* getState() const noexcept;
|
||||
[[nodiscard]] void* getState() noexcept;
|
||||
[[nodiscard]] void* releaseState() noexcept;
|
||||
[[nodiscard]] std::vector<char> getSerializedState() const noexcept;
|
||||
|
||||
[[nodiscard]] std::optional<SizeType32> getCtxDpRank() const noexcept;
|
||||
void setCtxDpRank(std::optional<SizeType32> const& ctxDpRank) noexcept;
|
||||
[[nodiscard]] std::optional<std::string> const& getDisaggInfoEndpoint() const noexcept;
|
||||
void setDisaggInfoEndpoint(std::optional<std::string> const& disaggInfoEndpoint) noexcept;
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
static void deleter(void const* data);
|
||||
@ -482,6 +493,12 @@ private:
|
||||
|
||||
/// @brief The draft tokens generated by context executor
|
||||
std::optional<VecTokens> mDraftTokens;
|
||||
|
||||
/// @brief The context phase data parallel rank
|
||||
std::optional<SizeType32> mCtxDpRank;
|
||||
|
||||
/// @brief The disaggregated info endpoint
|
||||
std::optional<std::string> mDisaggInfoEndpoint;
|
||||
};
|
||||
|
||||
/// @brief Configuration for speculative decoding (both draft and target models)
|
||||
|
||||
@ -99,13 +99,15 @@ std::optional<executor::Result> LlmRequest::createResult(bool useFastLogits, int
|
||||
}
|
||||
if (!hasDraftTokens())
|
||||
{
|
||||
result.contextPhaseParams = executor::ContextPhaseParams{
|
||||
std::move(firstGenTokens), mRequestId, mContextPhaseParams.value().releaseState(), std::nullopt};
|
||||
result.contextPhaseParams = executor::ContextPhaseParams{std::move(firstGenTokens), mRequestId,
|
||||
mContextPhaseParams.value().releaseState(), std::nullopt, mContextPhaseParams.value().getCtxDpRank(),
|
||||
mContextPhaseParams.value().getDisaggInfoEndpoint()};
|
||||
}
|
||||
else
|
||||
{
|
||||
result.contextPhaseParams = executor::ContextPhaseParams{
|
||||
std::move(firstGenTokens), mRequestId, mContextPhaseParams.value().releaseState(), *getDraftTokens()};
|
||||
result.contextPhaseParams = executor::ContextPhaseParams{std::move(firstGenTokens), mRequestId,
|
||||
mContextPhaseParams.value().releaseState(), *getDraftTokens(),
|
||||
mContextPhaseParams.value().getCtxDpRank(), mContextPhaseParams.value().getDisaggInfoEndpoint()};
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -27,28 +27,37 @@ namespace su = tensorrt_llm::executor::serialize_utils;
|
||||
namespace tensorrt_llm::executor
|
||||
{
|
||||
|
||||
ContextPhaseParams::ContextPhaseParams(
|
||||
VecTokens firstGenTokens, RequestIdType reqId, void* state, std::optional<VecTokens> draftTokens)
|
||||
ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, void* state,
|
||||
std::optional<VecTokens> draftTokens, std::optional<SizeType32> ctxDpRank,
|
||||
std::optional<std::string> disaggInfoEndpoint)
|
||||
: mReqId{reqId}
|
||||
, mFirstGenTokens{std::move(firstGenTokens)}
|
||||
, mState{StatePtr{state, deleter}}
|
||||
, mDraftTokens{std::move(draftTokens)}
|
||||
{
|
||||
}
|
||||
|
||||
ContextPhaseParams::ContextPhaseParams(
|
||||
VecTokens firstGenTokens, RequestIdType reqId, std::optional<VecTokens> draftTokens)
|
||||
: mReqId{reqId}
|
||||
, mFirstGenTokens{std::move(firstGenTokens)}
|
||||
, mDraftTokens{std::move(draftTokens)}
|
||||
, mCtxDpRank{ctxDpRank}
|
||||
, mDisaggInfoEndpoint{std::move(disaggInfoEndpoint)}
|
||||
{
|
||||
}
|
||||
|
||||
ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId,
|
||||
std::vector<char> const& serializedState, std::optional<VecTokens> draftTokens)
|
||||
std::optional<VecTokens> draftTokens, std::optional<SizeType32> ctxDpRank,
|
||||
std::optional<std::string> disaggInfoEndpoint)
|
||||
: mReqId{reqId}
|
||||
, mFirstGenTokens{std::move(firstGenTokens)}
|
||||
, mDraftTokens{std::move(draftTokens)}
|
||||
, mCtxDpRank{ctxDpRank}
|
||||
, mDisaggInfoEndpoint{std::move(disaggInfoEndpoint)}
|
||||
{
|
||||
}
|
||||
|
||||
ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId,
|
||||
std::vector<char> const& serializedState, std::optional<VecTokens> draftTokens, std::optional<SizeType32> ctxDpRank,
|
||||
std::optional<std::string> disaggInfoEndpoint)
|
||||
: mReqId{reqId}
|
||||
, mFirstGenTokens{std::move(firstGenTokens)}
|
||||
, mDraftTokens{std::move(draftTokens)}
|
||||
, mCtxDpRank{ctxDpRank}
|
||||
, mDisaggInfoEndpoint{std::move(disaggInfoEndpoint)}
|
||||
{
|
||||
|
||||
su::VectorWrapBuf<char> strbuf(const_cast<std::vector<char>&>(serializedState));
|
||||
@ -60,12 +69,14 @@ ContextPhaseParams::ContextPhaseParams(VecTokens firstGenTokens, RequestIdType r
|
||||
}
|
||||
|
||||
ContextPhaseParams::ContextPhaseParams(ContextPhaseParams const& other)
|
||||
: mReqId{other.mReqId}
|
||||
, mFirstGenTokens{other.mFirstGenTokens}
|
||||
, mDraftTokens{other.mDraftTokens}
|
||||
, mCtxDpRank{other.mCtxDpRank}
|
||||
, mDisaggInfoEndpoint{other.mDisaggInfoEndpoint}
|
||||
{
|
||||
// Since the internal header files implement the destructor while using the declaration of this
|
||||
// type, a `unique_ptr` with a custom destructor member is used here.
|
||||
mReqId = other.mReqId;
|
||||
mFirstGenTokens = other.mFirstGenTokens;
|
||||
mDraftTokens = other.mDraftTokens;
|
||||
if (other.mState)
|
||||
{
|
||||
auto* otherState = static_cast<DataTransceiverState*>(other.mState.get());
|
||||
@ -90,11 +101,21 @@ VecTokens const& ContextPhaseParams::getFirstGenTokens() const& noexcept
|
||||
return mFirstGenTokens;
|
||||
}
|
||||
|
||||
/// @brief Replaces the stored first generation tokens with a copy of @p firstGenTokens.
/// @param firstGenTokens Tokens that are copy-assigned into mFirstGenTokens.
/// NOTE(review): declared noexcept, but this is a copy assignment of a token container. If VecTokens
/// is a std::vector (its definition is not visible here -- confirm), the copy can allocate and an
/// allocation failure would call std::terminate instead of propagating.
void ContextPhaseParams::setFirstGenTokens(VecTokens const& firstGenTokens) noexcept
|
||||
{
|
||||
mFirstGenTokens = firstGenTokens;
|
||||
}
|
||||
|
||||
/// @brief Read-only access to the optional draft tokens.
/// @return A const reference to the mDraftTokens member (no copy is made). The function is
/// lvalue-ref-qualified (`const&`), so it is not selected for rvalue objects.
std::optional<VecTokens> const& ContextPhaseParams::getDraftTokens() const& noexcept
|
||||
{
|
||||
return mDraftTokens;
|
||||
}
|
||||
|
||||
/// @brief Replaces the stored draft tokens with a copy of @p draftTokens.
/// @param draftTokens New value; an empty optional clears the stored draft tokens.
/// NOTE(review): declared noexcept, but copy-assigning an engaged optional copies the contained
/// VecTokens. If VecTokens is a std::vector (not visible here -- confirm), that copy can allocate and
/// an allocation failure would call std::terminate.
void ContextPhaseParams::setDraftTokens(std::optional<VecTokens> const& draftTokens) noexcept
|
||||
{
|
||||
mDraftTokens = draftTokens;
|
||||
}
|
||||
|
||||
VecTokens ContextPhaseParams::popFirstGenTokens() && noexcept
|
||||
{
|
||||
return std::move(mFirstGenTokens);
|
||||
@ -105,6 +126,11 @@ ContextPhaseParams::RequestIdType ContextPhaseParams::getReqId() const noexcept
|
||||
return mReqId;
|
||||
}
|
||||
|
||||
/// @brief Overwrites the stored request id.
/// @param reqId New request id; copied into mReqId.
void ContextPhaseParams::setReqId(RequestIdType const& reqId) noexcept
|
||||
{
|
||||
mReqId = reqId;
|
||||
}
|
||||
|
||||
void const* ContextPhaseParams::getState() const noexcept
|
||||
{
|
||||
return mState.get();
|
||||
@ -125,6 +151,26 @@ void* ContextPhaseParams::releaseState() noexcept
|
||||
return mState.release();
|
||||
}
|
||||
|
||||
/// @brief Returns the stored context-phase data parallel rank.
/// @return A copy of mCtxDpRank; the optional is empty when no rank has been set.
std::optional<SizeType32> ContextPhaseParams::getCtxDpRank() const noexcept
|
||||
{
|
||||
return mCtxDpRank;
|
||||
}
|
||||
|
||||
/// @brief Overwrites the stored context-phase data parallel rank.
/// @param ctxDpRank New value; an empty optional clears the stored rank.
void ContextPhaseParams::setCtxDpRank(std::optional<SizeType32> const& ctxDpRank) noexcept
|
||||
{
|
||||
mCtxDpRank = ctxDpRank;
|
||||
}
|
||||
|
||||
/// @brief Read-only access to the optional disaggregated info endpoint.
/// @return A const reference to the mDisaggInfoEndpoint member (no copy is made), so the reference
/// must not be used after this object is destroyed.
std::optional<std::string> const& ContextPhaseParams::getDisaggInfoEndpoint() const noexcept
|
||||
{
|
||||
return mDisaggInfoEndpoint;
|
||||
}
|
||||
|
||||
/// @brief Replaces the stored disaggregated info endpoint with a copy of @p disaggInfoEndpoint.
/// @param disaggInfoEndpoint New value; an empty optional clears the stored endpoint.
/// NOTE(review): declared noexcept, but copy-assigning an engaged std::optional<std::string> copies
/// the string, which can allocate; an allocation failure here would call std::terminate rather than
/// propagate std::bad_alloc.
void ContextPhaseParams::setDisaggInfoEndpoint(std::optional<std::string> const& disaggInfoEndpoint) noexcept
|
||||
{
|
||||
mDisaggInfoEndpoint = disaggInfoEndpoint;
|
||||
}
|
||||
|
||||
void ContextPhaseParams::deleter(void const* data)
|
||||
{
|
||||
using StateT = DataTransceiverState const;
|
||||
@ -134,6 +180,7 @@ void ContextPhaseParams::deleter(void const* data)
|
||||
bool ContextPhaseParams::operator==(ContextPhaseParams const& other) const noexcept
|
||||
{
|
||||
if (mFirstGenTokens != other.mFirstGenTokens || mReqId != other.mReqId || mDraftTokens != other.mDraftTokens
|
||||
|| mDisaggInfoEndpoint != other.mDisaggInfoEndpoint || mCtxDpRank != other.mCtxDpRank
|
||||
|| static_cast<bool>(mState) != static_cast<bool>(other.mState))
|
||||
{
|
||||
return false;
|
||||
|
||||
@ -652,14 +652,18 @@ ContextPhaseParams Serialization::deserializeContextPhaseParams(std::istream& is
|
||||
auto reqId = su::deserialize<decltype(ContextPhaseParams::mReqId)>(is);
|
||||
auto firstGenTokens = su::deserialize<decltype(ContextPhaseParams::mFirstGenTokens)>(is);
|
||||
auto draftTokens = su::deserialize<decltype(ContextPhaseParams::mDraftTokens)>(is);
|
||||
auto ctxDpRank = su::deserialize<decltype(ContextPhaseParams::mCtxDpRank)>(is);
|
||||
auto disaggInfoEndpoint = su::deserialize<decltype(ContextPhaseParams::mDisaggInfoEndpoint)>(is);
|
||||
auto hasState = su::deserialize<bool>(is);
|
||||
if (hasState)
|
||||
{
|
||||
auto state = std::make_unique<DataTransceiverState>();
|
||||
*state = deserializeDataTransceiverState(is);
|
||||
return ContextPhaseParams{std::move(firstGenTokens), reqId, state.release(), std::move(draftTokens)};
|
||||
return ContextPhaseParams{std::move(firstGenTokens), reqId, state.release(), std::move(draftTokens), ctxDpRank,
|
||||
std::move(disaggInfoEndpoint)};
|
||||
}
|
||||
return ContextPhaseParams{std::move(firstGenTokens), reqId, nullptr, std::move(draftTokens)};
|
||||
return ContextPhaseParams{
|
||||
std::move(firstGenTokens), reqId, std::move(draftTokens), ctxDpRank, std::move(disaggInfoEndpoint)};
|
||||
}
|
||||
|
||||
void Serialization::serialize(ContextPhaseParams const& contextPhaseParams, std::ostream& os)
|
||||
@ -667,6 +671,8 @@ void Serialization::serialize(ContextPhaseParams const& contextPhaseParams, std:
|
||||
su::serialize(contextPhaseParams.mReqId, os);
|
||||
su::serialize(contextPhaseParams.mFirstGenTokens, os);
|
||||
su::serialize(contextPhaseParams.mDraftTokens, os);
|
||||
su::serialize(contextPhaseParams.mCtxDpRank, os);
|
||||
su::serialize(contextPhaseParams.mDisaggInfoEndpoint, os);
|
||||
su::serialize(static_cast<bool>(contextPhaseParams.mState), os);
|
||||
if (contextPhaseParams.mState)
|
||||
{
|
||||
@ -680,6 +686,8 @@ size_t Serialization::serializedSize(ContextPhaseParams const& contextPhaseParam
|
||||
totalSize += su::serializedSize(contextPhaseParams.mReqId);
|
||||
totalSize += su::serializedSize(contextPhaseParams.mFirstGenTokens);
|
||||
totalSize += su::serializedSize(contextPhaseParams.mDraftTokens);
|
||||
totalSize += su::serializedSize(contextPhaseParams.mCtxDpRank);
|
||||
totalSize += su::serializedSize(contextPhaseParams.mDisaggInfoEndpoint);
|
||||
totalSize += su::serializedSize(bool{});
|
||||
if (contextPhaseParams.mState)
|
||||
{
|
||||
|
||||
@ -167,7 +167,7 @@ void initBindings(nb::module_& m)
|
||||
"context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
|
||||
.def_prop_ro("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)
|
||||
.def_prop_rw("guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams)
|
||||
.def_prop_ro("context_phase_params", &GenLlmReq::getContextPhaseParams)
|
||||
.def_prop_rw("context_phase_params", &GenLlmReq::getContextPhaseParams, &GenLlmReq::setContextPhaseParams)
|
||||
.def_prop_ro("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
|
||||
.def_prop_ro("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest)
|
||||
.def_prop_ro("is_generation_to_complete_state", &GenLlmReq::isGenerationToCompleteState)
|
||||
|
||||
@ -442,14 +442,16 @@ void initRequestBindings(nb::module_& m)
|
||||
{
|
||||
auto serializedState = self.getSerializedState();
|
||||
return nb::make_tuple(self.getFirstGenTokens(), self.getReqId(),
|
||||
nb::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens());
|
||||
nb::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens(), self.getCtxDpRank(),
|
||||
self.getDisaggInfoEndpoint());
|
||||
}
|
||||
return nb::make_tuple(self.getFirstGenTokens(), self.getReqId(), nb::none(), self.getDraftTokens());
|
||||
return nb::make_tuple(self.getFirstGenTokens(), self.getReqId(), nb::none(), self.getDraftTokens(),
|
||||
self.getCtxDpRank(), self.getDisaggInfoEndpoint());
|
||||
};
|
||||
|
||||
auto ContextPhaseParamsSetState = [](tle::ContextPhaseParams& contextPhaseParams, nb::tuple const& state)
|
||||
{
|
||||
if (state.size() != 4)
|
||||
if (state.size() != 6)
|
||||
{
|
||||
throw std::runtime_error("Invalid ContextPhaseParams state!");
|
||||
}
|
||||
@ -460,13 +462,15 @@ void initRequestBindings(nb::module_& m)
|
||||
new (&contextPhaseParams) tle::ContextPhaseParams(nb::cast<VecTokens>(state[0]),
|
||||
nb::cast<tle::ContextPhaseParams::RequestIdType>(state[1]),
|
||||
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()),
|
||||
nb::cast<std::optional<VecTokens>>(state[3]));
|
||||
nb::cast<std::optional<VecTokens>>(state[3]), nb::cast<std::optional<SizeType32>>(state[4]),
|
||||
nb::cast<std::optional<std::string>>(state[5]));
|
||||
}
|
||||
else
|
||||
{
|
||||
new (&contextPhaseParams) tle::ContextPhaseParams(nb::cast<VecTokens>(state[0]),
|
||||
nb::cast<tle::ContextPhaseParams::RequestIdType>(state[1]),
|
||||
nb::cast<std::optional<VecTokens>>(state[3]));
|
||||
nb::cast<std::optional<VecTokens>>(state[3]), nb::cast<std::optional<SizeType32>>(state[4]),
|
||||
nb::cast<std::optional<std::string>>(state[5]));
|
||||
}
|
||||
};
|
||||
|
||||
@ -475,25 +479,35 @@ void initRequestBindings(nb::module_& m)
|
||||
"__init__",
|
||||
[](tle::ContextPhaseParams& self, VecTokens const& first_gen_tokens,
|
||||
tle::ContextPhaseParams::RequestIdType req_id, std::optional<nb::bytes> const& opaque_state,
|
||||
std::optional<VecTokens> const& draft_tokens)
|
||||
std::optional<VecTokens> const& draft_tokens, std::optional<SizeType32> const& ctx_dp_rank,
|
||||
std::optional<std::string> const& disagg_info_endpoint)
|
||||
{
|
||||
if (opaque_state)
|
||||
{
|
||||
auto opaque_state_str_view
|
||||
= std::string_view(opaque_state.value().c_str(), opaque_state.value().size());
|
||||
new (&self) tle::ContextPhaseParams(first_gen_tokens, req_id,
|
||||
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()), draft_tokens);
|
||||
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()), draft_tokens,
|
||||
ctx_dp_rank, disagg_info_endpoint);
|
||||
}
|
||||
else
|
||||
{
|
||||
new (&self) tle::ContextPhaseParams(first_gen_tokens, req_id, draft_tokens);
|
||||
new (&self) tle::ContextPhaseParams(
|
||||
first_gen_tokens, req_id, draft_tokens, ctx_dp_rank, disagg_info_endpoint);
|
||||
}
|
||||
},
|
||||
nb::arg("first_gen_tokens"), nb::arg("req_id"), nb::arg("opaque_state").none(),
|
||||
nb::arg("draft_tokens").none())
|
||||
.def_prop_ro("first_gen_tokens", [](tle::ContextPhaseParams const& self) { return self.getFirstGenTokens(); })
|
||||
.def_prop_ro("draft_tokens", [](tle::ContextPhaseParams const& self) { return self.getDraftTokens(); })
|
||||
.def_prop_ro("req_id", &tle::ContextPhaseParams::getReqId)
|
||||
nb::arg("draft_tokens").none(), nb::arg("ctx_dp_rank").none(), nb::arg("disagg_info_endpoint").none())
|
||||
.def_prop_rw(
|
||||
"first_gen_tokens", [](tle::ContextPhaseParams const& self) { return self.getFirstGenTokens(); },
|
||||
[](tle::ContextPhaseParams& self, VecTokens const& tokens) { self.setFirstGenTokens(tokens); })
|
||||
.def_prop_rw(
|
||||
"draft_tokens", [](tle::ContextPhaseParams const& self) { return self.getDraftTokens(); },
|
||||
[](tle::ContextPhaseParams& self, std::optional<VecTokens> const& tokens) { self.setDraftTokens(tokens); })
|
||||
.def_prop_rw("req_id", &tle::ContextPhaseParams::getReqId, &tle::ContextPhaseParams::setReqId)
|
||||
.def_prop_rw("ctx_dp_rank", &tle::ContextPhaseParams::getCtxDpRank, &tle::ContextPhaseParams::setCtxDpRank)
|
||||
.def_prop_rw("disagg_info_endpoint", &tle::ContextPhaseParams::getDisaggInfoEndpoint,
|
||||
&tle::ContextPhaseParams::setDisaggInfoEndpoint)
|
||||
.def_prop_ro("opaque_state",
|
||||
[](tle::ContextPhaseParams const& self)
|
||||
{
|
||||
|
||||
@ -172,7 +172,7 @@ void initBindings(pybind11::module_& m)
|
||||
.def_property_readonly("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)
|
||||
.def_property(
|
||||
"guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams)
|
||||
.def_property_readonly("context_phase_params", &GenLlmReq::getContextPhaseParams)
|
||||
.def_property("context_phase_params", &GenLlmReq::getContextPhaseParams, &GenLlmReq::setContextPhaseParams)
|
||||
.def_property_readonly("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
|
||||
.def_property_readonly("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest)
|
||||
.def_property_readonly("is_generation_to_complete_state", &GenLlmReq::isGenerationToCompleteState)
|
||||
|
||||
@ -411,14 +411,16 @@ void initRequestBindings(pybind11::module_& m)
|
||||
{
|
||||
auto serializedState = self.getSerializedState();
|
||||
return py::make_tuple(self.getFirstGenTokens(), self.getReqId(),
|
||||
py::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens());
|
||||
py::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens(), self.getCtxDpRank(),
|
||||
self.getDisaggInfoEndpoint());
|
||||
}
|
||||
return py::make_tuple(self.getFirstGenTokens(), self.getReqId(), py::none(), self.getDraftTokens());
|
||||
return py::make_tuple(self.getFirstGenTokens(), self.getReqId(), py::none(), self.getDraftTokens(),
|
||||
self.getCtxDpRank(), self.getDisaggInfoEndpoint());
|
||||
};
|
||||
|
||||
auto ContextPhaseParamsSetState = [](py::tuple const& state)
|
||||
{
|
||||
if (state.size() != 4)
|
||||
if (state.size() != 6)
|
||||
{
|
||||
throw std::runtime_error("Invalid ContextPhaseParams state!");
|
||||
}
|
||||
@ -429,28 +431,42 @@ void initRequestBindings(pybind11::module_& m)
|
||||
return std::make_unique<tle::ContextPhaseParams>(state[0].cast<VecTokens>(),
|
||||
state[1].cast<tle::ContextPhaseParams::RequestIdType>(),
|
||||
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()),
|
||||
state[3].cast<std::optional<VecTokens>>());
|
||||
state[3].cast<std::optional<VecTokens>>(), state[4].cast<std::optional<SizeType32>>(),
|
||||
state[5].cast<std::optional<std::string>>());
|
||||
}
|
||||
return std::make_unique<tle::ContextPhaseParams>(state[0].cast<VecTokens>(),
|
||||
state[1].cast<tle::ContextPhaseParams::RequestIdType>(), state[3].cast<std::optional<VecTokens>>());
|
||||
state[1].cast<tle::ContextPhaseParams::RequestIdType>(), state[3].cast<std::optional<VecTokens>>(),
|
||||
state[4].cast<std::optional<SizeType32>>(), state[5].cast<std::optional<std::string>>());
|
||||
};
|
||||
|
||||
py::class_<tle::ContextPhaseParams>(m, "ContextPhaseParams")
|
||||
.def(py::init(
|
||||
[](VecTokens const& first_gen_tokens, tle::ContextPhaseParams::RequestIdType req_id,
|
||||
std::optional<py::bytes> const& opaque_state, std::optional<VecTokens> const& draft_tokens)
|
||||
{
|
||||
if (opaque_state)
|
||||
{
|
||||
auto opaque_state_str_view = std::string_view(opaque_state.value().cast<std::string_view>());
|
||||
return std::make_unique<tle::ContextPhaseParams>(first_gen_tokens, req_id,
|
||||
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()), draft_tokens);
|
||||
}
|
||||
return std::make_unique<tle::ContextPhaseParams>(first_gen_tokens, req_id, draft_tokens);
|
||||
}))
|
||||
.def_property_readonly("first_gen_tokens", &tle::ContextPhaseParams::getFirstGenTokens)
|
||||
.def_property_readonly("draft_tokens", &tle::ContextPhaseParams::getDraftTokens)
|
||||
.def_property_readonly("req_id", &tle::ContextPhaseParams::getReqId)
|
||||
[](VecTokens const& first_gen_tokens, tle::ContextPhaseParams::RequestIdType req_id,
|
||||
std::optional<py::bytes> const& opaque_state, std::optional<VecTokens> const& draft_tokens,
|
||||
std::optional<SizeType32> const& ctx_dp_rank,
|
||||
std::optional<std::string> const& disagg_info_endpoint)
|
||||
{
|
||||
if (opaque_state)
|
||||
{
|
||||
auto opaque_state_str_view = std::string_view(opaque_state.value().cast<std::string_view>());
|
||||
return std::make_unique<tle::ContextPhaseParams>(first_gen_tokens, req_id,
|
||||
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()),
|
||||
draft_tokens, ctx_dp_rank, disagg_info_endpoint);
|
||||
}
|
||||
return std::make_unique<tle::ContextPhaseParams>(
|
||||
first_gen_tokens, req_id, draft_tokens, ctx_dp_rank, disagg_info_endpoint);
|
||||
}),
|
||||
py::arg("first_gen_tokens"), py::arg("req_id"), py::arg("opaque_state") = py::none(),
|
||||
py::arg("draft_tokens") = py::none(), py::arg("ctx_dp_rank") = py::none(),
|
||||
py::arg("disagg_info_endpoint") = py::none())
|
||||
.def_property("first_gen_tokens", &tle::ContextPhaseParams::getFirstGenTokens,
|
||||
&tle::ContextPhaseParams::setFirstGenTokens)
|
||||
.def_property(
|
||||
"draft_tokens", &tle::ContextPhaseParams::getDraftTokens, &tle::ContextPhaseParams::setDraftTokens)
|
||||
.def_property("req_id", &tle::ContextPhaseParams::getReqId, &tle::ContextPhaseParams::setReqId)
|
||||
.def_property("ctx_dp_rank", &tle::ContextPhaseParams::getCtxDpRank, &tle::ContextPhaseParams::setCtxDpRank)
|
||||
.def_property("disagg_info_endpoint", &tle::ContextPhaseParams::getDisaggInfoEndpoint,
|
||||
&tle::ContextPhaseParams::setDisaggInfoEndpoint)
|
||||
.def_property_readonly("opaque_state",
|
||||
[](tle::ContextPhaseParams const& self)
|
||||
{
|
||||
|
||||
@ -751,6 +751,34 @@ TEST(SerializeUtilsTest, ContextPhaseParams)
|
||||
|
||||
EXPECT_EQ(state2, stateCopy);
|
||||
}
|
||||
|
||||
// Test with ctxDpRank and disaggInfoEndpoint
|
||||
{
|
||||
auto state = std::make_unique<texec::DataTransceiverState>();
|
||||
state->setCommState(texec::kv_cache::CommState{{10, 20}});
|
||||
auto stats
|
||||
= texec::ContextPhaseParams({10, 20, 30}, 2, state.release(), VecTokens{5, 6}, 3, "http://127.0.0.1:8080");
|
||||
auto stats2 = serializeDeserialize(stats);
|
||||
EXPECT_EQ(stats, stats2);
|
||||
EXPECT_EQ(stats.getCtxDpRank(), 3);
|
||||
EXPECT_EQ(stats.getDisaggInfoEndpoint(), "http://127.0.0.1:8080");
|
||||
}
|
||||
|
||||
{
|
||||
auto stats = texec::ContextPhaseParams({1, 2}, 1, std::nullopt, 5, std::nullopt);
|
||||
auto stats2 = serializeDeserialize(stats);
|
||||
EXPECT_EQ(stats, stats2);
|
||||
EXPECT_EQ(stats.getCtxDpRank(), 5);
|
||||
EXPECT_EQ(stats.getDisaggInfoEndpoint(), std::nullopt);
|
||||
}
|
||||
|
||||
{
|
||||
auto stats = texec::ContextPhaseParams({1, 2}, 1, std::nullopt, std::nullopt, "endpoint://test");
|
||||
auto stats2 = serializeDeserialize(stats);
|
||||
EXPECT_EQ(stats, stats2);
|
||||
EXPECT_EQ(stats.getCtxDpRank(), std::nullopt);
|
||||
EXPECT_EQ(stats.getDisaggInfoEndpoint(), "endpoint://test");
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SerializeUtilsTest, SpeculativeDecodingFastLogitsInfo)
|
||||
|
||||
@ -482,10 +482,13 @@ class ExecutorRequestQueue:
|
||||
new_requests, "py_scheduling_params")
|
||||
py_num_logprobs = self._collect_py_objects_from_requests(
|
||||
new_requests, "py_num_logprobs")
|
||||
py_disaggregated_params = self._collect_py_objects_from_requests(
|
||||
new_requests, "py_disaggregated_params")
|
||||
py_request_objects = tuple(
|
||||
filter(None, [
|
||||
py_logits_post_processors, py_multimodal_data,
|
||||
py_scheduling_params, py_num_logprobs
|
||||
py_scheduling_params, py_num_logprobs,
|
||||
py_disaggregated_params
|
||||
]))
|
||||
else:
|
||||
py_request_objects = None
|
||||
|
||||
@ -567,6 +567,7 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
|
||||
|
||||
self.py_logprobs_mode = LogprobMode(
|
||||
logprobs_mode) # handle passed a raw string
|
||||
self.py_disaggregated_params = None
|
||||
|
||||
self.py_result = PyResult(
|
||||
prompt_len=self.py_prompt_len,
|
||||
@ -831,6 +832,10 @@ def executor_request_to_llm_request(
|
||||
logprobs_mode=getattr(executor_request, "py_logprobs_mode",
|
||||
LogprobMode.RAW),
|
||||
)
|
||||
|
||||
llm_request.py_disaggregated_params = getattr(executor_request,
|
||||
"py_disaggregated_params",
|
||||
None)
|
||||
if child_req_ids:
|
||||
for child_id in child_req_ids:
|
||||
llm_request.create_child_request(child_id)
|
||||
|
||||
@ -36,7 +36,8 @@ class DisaggregatedParams:
|
||||
draft_tokens: Optional[List[int]] = None
|
||||
# If disagg_request_id is set, both context and generation requests will use it as underlying request id.
|
||||
disagg_request_id: Optional[int] = None
|
||||
|
||||
ctx_dp_rank: Optional[int] = None
|
||||
ctx_info_endpoint: Optional[List[str]] = None
|
||||
# E-P Disaggregated Params
|
||||
multimodal_embedding_handles: Optional[List[Dict[str, Any]]] = (
|
||||
None # multimodal embedding handles should be a list of cudaIPC handles for each mm_embedding
|
||||
@ -53,7 +54,12 @@ class DisaggregatedParams:
|
||||
self.disagg_request_id if self.disagg_request_id is not None else self.ctx_request_id
|
||||
)
|
||||
return tllme.ContextPhaseParams(
|
||||
self.first_gen_tokens, request_id, self.opaque_state, self.draft_tokens
|
||||
self.first_gen_tokens,
|
||||
request_id,
|
||||
self.opaque_state,
|
||||
self.draft_tokens,
|
||||
self.ctx_dp_rank,
|
||||
self.ctx_info_endpoint,
|
||||
)
|
||||
|
||||
def get_request_type(self) -> tllme.RequestType:
|
||||
|
||||
@ -566,6 +566,9 @@ class BaseWorker(GenerationExecutor):
|
||||
executor_request.py_lora_path = py_lora_path
|
||||
executor_request.py_logprobs_mode = request.sampling_params.logprobs_mode
|
||||
|
||||
# Attach request.disaggregated_params to the executor request (as py_disaggregated_params) for the Python cache transceiver.
|
||||
if self._is_pytorch_backend and request.disaggregated_params is not None:
|
||||
executor_request.py_disaggregated_params = request.disaggregated_params
|
||||
if self._is_pytorch_backend and request.multimodal_params is not None:
|
||||
if request.multimodal_params.multimodal_data is not None:
|
||||
# NOTE: Deserialize SharedTensor handle to actual tensor
|
||||
|
||||
@ -427,6 +427,8 @@ class GenerationResultBase:
|
||||
ctx_request_id=context_phase_params.req_id,
|
||||
opaque_state=context_phase_params.opaque_state,
|
||||
draft_tokens=context_phase_params.draft_tokens,
|
||||
ctx_dp_rank=context_phase_params.ctx_dp_rank,
|
||||
ctx_info_endpoint=context_phase_params.disagg_info_endpoint,
|
||||
multimodal_embedding_handles=None,
|
||||
)
|
||||
|
||||
|
||||
@ -313,4 +313,8 @@ class OpenAIDisaggregatedService(OpenAIService):
|
||||
raise ValueError("Context server did not return disaggregated params")
|
||||
if ctx_response.choices[0].disaggregated_params.ctx_request_id is None:
|
||||
raise ValueError("Invalid disaggregated params in context phase response.")
|
||||
if ctx_response.choices[0].disaggregated_params.disagg_request_id is None:
|
||||
raise ValueError(
|
||||
"Invalid disaggregated params in context phase response. disagg_request_id is None"
|
||||
)
|
||||
return ctx_response
|
||||
|
||||
@ -119,6 +119,8 @@ class DisaggregatedParams(OpenAIBaseModel):
|
||||
encoded_opaque_state: Optional[str] = None
|
||||
draft_tokens: Optional[List[int]] = None
|
||||
disagg_request_id: Optional[int] = None
|
||||
ctx_dp_rank: Optional[int] = None
|
||||
ctx_info_endpoint: Optional[str] = None
|
||||
|
||||
|
||||
class ErrorResponse(OpenAIBaseModel):
|
||||
@ -1091,7 +1093,9 @@ def to_disaggregated_params(
|
||||
encoded_opaque_state=encode_opaque_state(
|
||||
tllm_disagg_params.opaque_state),
|
||||
draft_tokens=tllm_disagg_params.draft_tokens,
|
||||
disagg_request_id=tllm_disagg_params.disagg_request_id)
|
||||
disagg_request_id=tllm_disagg_params.disagg_request_id,
|
||||
ctx_dp_rank=tllm_disagg_params.ctx_dp_rank,
|
||||
ctx_info_endpoint=tllm_disagg_params.ctx_info_endpoint)
|
||||
|
||||
|
||||
def to_llm_disaggregated_params(
|
||||
@ -1105,7 +1109,9 @@ def to_llm_disaggregated_params(
|
||||
opaque_state=decode_opaque_state(
|
||||
disaggregated_params.encoded_opaque_state),
|
||||
draft_tokens=disaggregated_params.draft_tokens,
|
||||
disagg_request_id=disaggregated_params.disagg_request_id)
|
||||
disagg_request_id=disaggregated_params.disagg_request_id,
|
||||
ctx_dp_rank=disaggregated_params.ctx_dp_rank,
|
||||
ctx_info_endpoint=disaggregated_params.ctx_info_endpoint)
|
||||
|
||||
|
||||
UCompletionRequest = Union[CompletionRequest, ChatCompletionRequest]
|
||||
|
||||
@ -1198,9 +1198,9 @@ def test_result_pickle():
|
||||
result.sequence_index = 1
|
||||
result.is_sequence_final = True
|
||||
result.decoding_iter = 1
|
||||
result.context_phase_params = trtllm.ContextPhaseParams([1, 2], 123,
|
||||
bytes([0, 1]),
|
||||
[10, 20, 30])
|
||||
result.context_phase_params = trtllm.ContextPhaseParams(
|
||||
[1, 2], 123, bytes([0,
|
||||
1]), [10, 20, 30], 1, "disagg_info_endpoint_24680")
|
||||
result.request_perf_metrics = trtllm.RequestPerfMetrics()
|
||||
result.request_perf_metrics.last_iter = 33
|
||||
result_str = pickle.dumps(result)
|
||||
@ -1220,6 +1220,8 @@ def test_result_pickle():
|
||||
assert result.context_phase_params.first_gen_tokens == result_copy.context_phase_params.first_gen_tokens
|
||||
assert result.context_phase_params.draft_tokens == result_copy.context_phase_params.draft_tokens
|
||||
assert result.context_phase_params.opaque_state == result_copy.context_phase_params.opaque_state
|
||||
assert result.context_phase_params.ctx_dp_rank == result_copy.context_phase_params.ctx_dp_rank
|
||||
assert result.context_phase_params.disagg_info_endpoint == result_copy.context_phase_params.disagg_info_endpoint
|
||||
assert result.request_perf_metrics.last_iter == result_copy.request_perf_metrics.last_iter
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user