Merge ce44fabe8e into 6df2c8a074 (commit f8cd142413)

This commit threads a new optional disaggregated request ID (disaggRequestId in C++, disagg_request_id in Python) through the executor Request API, its nanobind/pybind11 bindings, and the disaggregated serving path, so that the context and generation phases of one request share a single underlying request id.
@@ -684,6 +684,7 @@ public:
     /// finish reason. The request may exceed this time slightly, but at most by 1 forward pass (in pipeline parallelism
     /// that may involve multiple micro-batches). A request can be timed-out before ever being scheduled.
     /// @param cacheSaltID Salt ID for KV cache blocks to limit the kv cache reuse to the requests with the same string.
+    /// @param disaggRequestId Disaggregated request ID.
     Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming = false,
         SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(),
         std::optional<SizeType32> const& endId = std::nullopt, std::optional<SizeType32> const& padId = std::nullopt,
@@ -711,7 +712,8 @@ public:
         std::optional<GuidedDecodingParams> guidedDecodingParams = std::nullopt,
         std::optional<SizeType32> languageAdapterUid = std::nullopt,
         std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
-        std::optional<CacheSaltIDType> cacheSaltID = std::nullopt);
+        std::optional<CacheSaltIDType> cacheSaltID = std::nullopt,
+        std::optional<IdType> disaggRequestId = std::nullopt);

     /// @brief This logits postprocessor name will dispatch to the batched logits postprocessor
     static auto constexpr kBatchedPostProcessorName = "batched";
@@ -761,6 +763,7 @@ public:
     [[nodiscard]] std::optional<MillisecondsType> getAllottedTimeMs() const;
     [[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID() const;
     [[nodiscard]] std::optional<std::vector<std::string>> getAdditionalOutputNames() const;
+    [[nodiscard]] std::optional<IdType> getDisaggRequestId() const;

     void setStreaming(bool streaming);
     void setSamplingConfig(SamplingConfig const& config);
@@ -796,6 +799,7 @@ public:
     void setLanguageAdapterUid(SizeType32 languageAdapterUid);
     void setAllottedTimeMs(MillisecondsType allottedTimeMs);
     void setCacheSaltID(CacheSaltIDType cacheSaltID);
+    void setDisaggRequestId(IdType disaggRequestId);

 private:
     friend class Serialization;

@@ -40,7 +40,8 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming,
     std::optional<SizeType32> encoderOutputLength, std::optional<Tensor> crossAttentionMask,
     SizeType32 numReturnSequences, std::optional<EagleConfig> eagleConfig, std::optional<Tensor> skipCrossAttnBlocks,
     std::optional<GuidedDecodingParams> guidedDecodingParams, std::optional<SizeType32> languageAdapterUid,
-    std::optional<MillisecondsType> allottedTimeMs, std::optional<CacheSaltIDType> cacheSaltID)
+    std::optional<MillisecondsType> allottedTimeMs, std::optional<CacheSaltIDType> cacheSaltID,
+    std::optional<IdType> disaggRequestId)
     : mImpl(std::make_unique<Impl>(std::move(inputTokenIds), maxTokens, streaming, samplingConfig, outputConfig, endId,
         padId, std::move(positionIds), std::move(badWords), std::move(stopWords), std::move(embeddingBias),
         std::move(externalDraftTokensConfig), std::move(pTuningConfig), std::move(multimodalInput),
@@ -49,7 +50,7 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming,
         std::move(encoderInputTokenIds), clientId, returnAllGeneratedTokens, priority, type,
         std::move(contextPhaseParams), std::move(encoderInputFeatures), encoderOutputLength, crossAttentionMask,
         numReturnSequences, eagleConfig, skipCrossAttnBlocks, std::move(guidedDecodingParams), languageAdapterUid,
-        allottedTimeMs, cacheSaltID))
+        allottedTimeMs, cacheSaltID, disaggRequestId))
 {
 }

@@ -253,6 +254,11 @@ std::optional<CacheSaltIDType> Request::getCacheSaltID() const
     return mImpl->getCacheSaltID();
 }

+std::optional<IdType> Request::getDisaggRequestId() const
+{
+    return mImpl->getDisaggRequestId();
+}
+
 void Request::setStreaming(bool streaming)
 {
     mImpl->setStreaming(streaming);
@@ -310,12 +316,12 @@ void Request::setPromptTuningConfig(PromptTuningConfig const& pTuningConfig)

 void Request::setMultimodalEmbedding(Tensor const& multimodalEmbedding)
 {
-    return mImpl->setMultimodalEmbedding(multimodalEmbedding);
+    mImpl->setMultimodalEmbedding(multimodalEmbedding);
 }

 void Request::setMultimodalInput(MultimodalInput const& multimodalInput)
 {
-    return mImpl->setMultimodalInput(multimodalInput);
+    mImpl->setMultimodalInput(multimodalInput);
 }

 void Request::setMropeConfig(MropeConfig const& mRopeConfig)
@@ -400,7 +406,7 @@ void Request::setEagleConfig(std::optional<EagleConfig> const& eagleConfig)

 void Request::setSkipCrossAttnBlocks(Tensor skipCrossAttnBlocks)
 {
-    return mImpl->setSkipCrossAttnBlocks(skipCrossAttnBlocks);
+    mImpl->setSkipCrossAttnBlocks(skipCrossAttnBlocks);
 }

 void Request::setGuidedDecodingParams(GuidedDecodingParams const& guidedDecodingParams)
@@ -410,16 +416,21 @@ void Request::setGuidedDecodingParams(GuidedDecodingParams const& guidedDecodingParams)

 void Request::setAllottedTimeMs(MillisecondsType allottedTimeMs)
 {
-    return mImpl->setAllottedTimeMs(allottedTimeMs);
+    mImpl->setAllottedTimeMs(allottedTimeMs);
 }

 void Request::setLanguageAdapterUid(SizeType32 languageAdapterUid)
 {
-    return mImpl->setLanguageAdapterUid(languageAdapterUid);
+    mImpl->setLanguageAdapterUid(languageAdapterUid);
 }

 void Request::setCacheSaltID(CacheSaltIDType cacheSaltID)
 {
-    return mImpl->setCacheSaltID(cacheSaltID);
+    mImpl->setCacheSaltID(cacheSaltID);
 }

+void Request::setDisaggRequestId(IdType disaggRequestId)
+{
+    mImpl->setDisaggRequestId(disaggRequestId);
+}
+
 } // namespace tensorrt_llm::executor

@@ -48,7 +48,7 @@ public:
         std::optional<Tensor> crossAttentionMask, SizeType32 numReturnSequences, std::optional<EagleConfig> eagleConfig,
         std::optional<Tensor> skipCrossAttnBlocks, std::optional<GuidedDecodingParams> guidedDecodingParams,
         std::optional<SizeType32> languageAdapterUid, std::optional<MillisecondsType> allottedTimeMs,
-        std::optional<CacheSaltIDType> cacheSaltID)
+        std::optional<CacheSaltIDType> cacheSaltID, std::optional<IdType> disaggRequestId)
         : mInputTokenIds(std::move(inputTokenIds))
         , mMaxNewTokens(maxNewTokens)
         , mStreaming(streaming)
@@ -86,6 +86,7 @@ public:
         , mLanguageAdapterUid(languageAdapterUid)
         , mAllottedTimeMs(allottedTimeMs)
         , mCacheSaltID(cacheSaltID)
+        , mDisaggRequestId(disaggRequestId)
     {
         validate();
     }
@@ -302,6 +303,11 @@ public:
         return mCacheSaltID;
     }

+    [[nodiscard]] std::optional<IdType> getDisaggRequestId() const
+    {
+        return mDisaggRequestId;
+    }
+
     void setStreaming(bool streaming)
     {
         mStreaming = streaming;
@@ -481,6 +487,11 @@ public:
         mCacheSaltID = cacheSaltID;
     }

+    void setDisaggRequestId(IdType disaggRequestId)
+    {
+        mDisaggRequestId = disaggRequestId;
+    }
+
 private:
     void validate()
     {
@@ -555,6 +566,7 @@ private:
         lambda(mLanguageAdapterUid);
         lambda(mAllottedTimeMs ? std::make_optional(mAllottedTimeMs->count()) : std::nullopt);
         lambda(mCacheSaltID);
+        lambda(mDisaggRequestId);
     }

     VecTokens mInputTokenIds;
@@ -594,6 +606,7 @@ private:
     std::optional<SizeType32> mLanguageAdapterUid;
     std::optional<MillisecondsType> mAllottedTimeMs;
     std::optional<CacheSaltIDType> mCacheSaltID;
+    std::optional<IdType> mDisaggRequestId;
 };

 } // namespace tensorrt_llm::executor

@@ -578,11 +578,11 @@ void initRequestBindings(nb::module_& m)
             self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(),
             self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(),
             self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(),
-            self.getGuidedDecodingParams(), self.getCacheSaltID());
+            self.getGuidedDecodingParams(), self.getCacheSaltID(), self.getDisaggRequestId());
     };
     auto requestSetstate = [](tle::Request& self, nb::tuple const& state)
     {
-        if (state.size() != 34)
+        if (state.size() != 35)
         {
             throw std::runtime_error("Invalid Request state!");
         }
@@ -606,8 +606,8 @@ void initRequestBindings(nb::module_& m)
             nb::cast<std::optional<tle::Tensor>>(state[27]), nb::cast<std::optional<SizeType32>>(state[28]),
             nb::cast<std::optional<tle::Tensor>>(state[29]), 1, nb::cast<std::optional<tle::EagleConfig>>(state[30]),
             nb::cast<std::optional<tle::Tensor>>(state[31]),
-            nb::cast<std::optional<tle::GuidedDecodingParams>>(state[32]),
-            nb::cast<std::optional<tle::CacheSaltIDType>>(state[33]));
+            nb::cast<std::optional<tle::GuidedDecodingParams>>(state[32]), std::nullopt, std::nullopt,
+            nb::cast<std::optional<tle::CacheSaltIDType>>(state[33]), nb::cast<std::optional<tle::IdType>>(state[34]));
     };

     nb::class_<tle::Request> request(m, "Request", nb::dynamic_attr());
@@ -648,7 +648,8 @@ void initRequestBindings(nb::module_& m)
             std::optional<tle::GuidedDecodingParams>, // guidedDecodingParams
             std::optional<tle::SizeType32>,           // languageAdapterUid
             std::optional<tle::MillisecondsType>,     // allottedTimeMs
-            std::optional<tle::CacheSaltIDType>       // cacheSaltID
+            std::optional<tle::CacheSaltIDType>,      // cacheSaltID
+            std::optional<tle::IdType>                // disaggRequestId
             >(),
         // clang-format off
         nb::arg("input_token_ids"),
@@ -688,8 +689,9 @@ void initRequestBindings(nb::module_& m)
         nb::arg("guided_decoding_params") = nb::none(),
         nb::arg("language_adapter_uid") = nb::none(),
         nb::arg("allotted_time_ms") = nb::none(),
-        nb::arg("cache_salt_id") = nb::none()
-        ) // clang-format on
+        nb::arg("cache_salt_id") = nb::none(),
+        nb::arg("disagg_request_id") = nb::none()
+        ) // clang-format on
         .def_prop_ro("input_token_ids", &tle::Request::getInputTokenIds)
         .def_prop_ro("max_tokens", &tle::Request::getMaxTokens)
         .def_prop_rw("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming)
@@ -733,6 +735,7 @@ void initRequestBindings(nb::module_& m)
         .def_prop_rw("allotted_time_ms", &tle::Request::getAllottedTimeMs, &tle::Request::setAllottedTimeMs)
         .def_prop_rw("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID)
         .def_prop_rw("context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams)
+        .def_prop_rw("disagg_request_id", &tle::Request::getDisaggRequestId, &tle::Request::setDisaggRequestId)
         .def("__getstate__", requestGetstate)
         .def("__setstate__", requestSetstate);
     request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName;
@@ -929,8 +932,8 @@ void initRequestBindings(nb::module_& m)
         {
             throw std::runtime_error("Invalid Request state!");
         }
-        new (&response) tle::Response(
-            nb::cast<SizeType32>(state[0]), nb::cast<tle::Result>(state[1]), nb::cast<SizeType32>(state[2]));
+        new (&response)
+            tle::Response(nb::cast<IdType>(state[0]), nb::cast<tle::Result>(state[1]), nb::cast<IdType>(state[2]));
     };

     nb::class_<tle::Response>(m, "Response")

@@ -85,8 +85,7 @@ public:
         std::optional<SizeType32> languageAdapterUid = std::nullopt,
         std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
         std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
-        std::optional<CacheSaltIDType> cacheSaltID = std::nullopt,
-        std::optional<TimePoint> arrivalTime = std::nullopt)
+        std::optional<CacheSaltIDType> cacheSaltID = std::nullopt, std::optional<TimePoint> arrivalTime = std::nullopt)
         : Base(requestId, //
             maxNewTokens, //
             std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)), //

@@ -531,11 +531,11 @@ void initRequestBindings(pybind11::module_& m)
             self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(),
             self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(),
             self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(),
-            self.getGuidedDecodingParams(), self.getCacheSaltID());
+            self.getGuidedDecodingParams(), self.getCacheSaltID(), self.getDisaggRequestId());
     };
     auto requestSetstate = [](py::tuple const& state)
     {
-        if (state.size() != 34)
+        if (state.size() != 35)
         {
             throw std::runtime_error("Invalid Request state!");
         }
@@ -556,7 +556,8 @@ void initRequestBindings(pybind11::module_& m)
             state[27].cast<std::optional<tle::Tensor>>(), state[28].cast<std::optional<SizeType32>>(),
             state[29].cast<std::optional<tle::Tensor>>(), 1, state[30].cast<std::optional<tle::EagleConfig>>(),
             state[31].cast<std::optional<tle::Tensor>>(), state[32].cast<std::optional<tle::GuidedDecodingParams>>(),
-            state[33].cast<std::optional<tle::CacheSaltIDType>>());
+            std::nullopt, std::nullopt, state[33].cast<std::optional<tle::CacheSaltIDType>>(),
+            state[34].cast<std::optional<tle::IdType>>());
     };

     py::class_<tle::Request> request(m, "Request", pybind11::dynamic_attr());
@@ -597,7 +598,8 @@ void initRequestBindings(pybind11::module_& m)
             std::optional<tle::GuidedDecodingParams>, // guidedDecodingParams
             std::optional<tle::SizeType32>,           // languageAdapterUid
             std::optional<tle::MillisecondsType>,     // allottedTimeMs
-            std::optional<tle::CacheSaltIDType>       // cacheSaltID
+            std::optional<tle::CacheSaltIDType>,      // cacheSaltID
+            std::optional<tle::IdType>                // disaggRequestId
             >(),
         // clang-format off
         py::arg("input_token_ids"),
@@ -638,8 +640,9 @@ void initRequestBindings(pybind11::module_& m)
         py::arg("guided_decoding_params") = py::none(),
         py::arg("language_adapter_uid") = py::none(),
         py::arg("allotted_time_ms") = py::none(),
-        py::arg("cache_salt_id") = py::none()
-        ) // clang-format on
+        py::arg("cache_salt_id") = py::none(),
+        py::arg("disagg_request_id") = py::none()
+        ) // clang-format on
         .def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds)
         .def_property_readonly("max_tokens", &tle::Request::getMaxTokens)
         .def_property("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming)
@@ -686,6 +689,7 @@ void initRequestBindings(pybind11::module_& m)
         .def_property("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID)
         .def_property(
             "context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams)
+        .def_property("disagg_request_id", &tle::Request::getDisaggRequestId, &tle::Request::setDisaggRequestId)
         .def(py::pickle(requestGetstate, requestSetstate));
     request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName;

@@ -870,7 +874,7 @@ void initRequestBindings(pybind11::module_& m)
             throw std::runtime_error("Invalid Request state!");
         }
         return std::make_unique<tle::Response>(
-            state[0].cast<SizeType32>(), state[1].cast<tle::Result>(), state[2].cast<SizeType32>());
+            state[0].cast<IdType>(), state[1].cast<tle::Result>(), state[2].cast<IdType>());
     };

     py::class_<tle::Response>(m, "Response")

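Taken together, the nanobind and pybind11 changes above expose the new field as a disagg_request_id keyword argument and a read-write property on Request, and grow the pickle state from 34 to 35 entries. A minimal usage sketch against the updated bindings; the token ids and id value below are placeholders, not values from this commit:

# Sketch of the new binding surface; token ids and the id value are
# placeholder assumptions.
import tensorrt_llm.bindings.executor as trtllm

req = trtllm.Request(input_token_ids=[1, 2, 3],
                     max_tokens=16,
                     disagg_request_id=1 << 50)
assert req.disagg_request_id == 1 << 50  # exposed as a read-write property
req.disagg_request_id = (1 << 50) + 1    # setter added by this change
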
@@ -11,6 +11,7 @@ from typing import Dict, Iterable, List, Optional, Tuple
 import torch

 from tensorrt_llm._utils import mpi_disabled, nvtx_range
+from tensorrt_llm.llmapi.disagg_utils import get_local_request_id
 from tensorrt_llm.mapping import CpType

 from ..distributed import Distributed
@@ -206,10 +207,15 @@ class ExecutorRequestQueue:

         return False

-    def _get_request_id(self):
-        # (next_request_id + 1) % UINT64_MAX
+    def _get_request_id(self, request: Optional[ExecutorRequest] = None):
+        # If the request carries a disagg_request_id, use it as the request id
+        # so that the corresponding context and generation requests share the
+        # same request id.
+        if request and request.disagg_request_id and isinstance(
+                request.disagg_request_id, int):
+            return request.disagg_request_id
+
         current_id = self.next_request_id
-        self.next_request_id = (self.next_request_id + 1) & ((1 << 64) - 1)
+        self.next_request_id = get_local_request_id(current_id)
         return current_id

     def _generate_child_request_ids(
@@ -236,7 +242,7 @@ class ExecutorRequestQueue:
         assert self.active, "PyExecutor has already been shutdown."
         start_time = time.time()
         for request, query in requests_and_queries:
-            req_id = self._get_request_id()
+            req_id = self._get_request_id(request)
             if self.enable_iter_perf_stats:
                 self.start_times[req_id] = start_time
             child_req_ids = self._generate_child_request_ids(request)

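The queue now prefers a disaggregated id when the incoming executor request carries one; otherwise it hands out local ids from a counter that wraps below MIN_GLOBAL_ID instead of below 2**64. A standalone sketch of that selection rule; the function and variable names are illustrative, not the class's actual state:

# Illustrative model of _get_request_id's behavior, assuming MIN_GLOBAL_ID
# from disagg_utils; the real method reads request.disagg_request_id.
MIN_GLOBAL_ID = 1 << 42

def pick_request_id(next_local_id, disagg_request_id=None):
    if isinstance(disagg_request_id, int) and disagg_request_id:
        # Context and generation requests carrying the same disaggregated id
        # get registered under the same request id; the counter is untouched.
        return disagg_request_id, next_local_id
    # Local ids rotate in [0, MIN_GLOBAL_ID) and can never collide with
    # global ids, which live in [MIN_GLOBAL_ID, 2**63 - 1).
    return next_local_id, (next_local_id + 1) & (MIN_GLOBAL_ID - 1)

assert pick_request_id(7) == (7, 8)
assert pick_request_id(8, 1 << 50) == (1 << 50, 8)
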
@@ -21,6 +21,8 @@ class DisaggregatedParams:
         ctx_request_id (int): The context request id
         opaque_state (bytes): Any additional state needing to be exchanged between context and gen instances
         draft_tokens (List[int]): The draft tokens of the generation request
+        disagg_request_id (int): The disaggregated request id. If set, both context and generation requests
+            will use it as the underlying request id.

         multimodal_embedding_handles (List[Dict[str, Any]]): The resulting multimodal embedding handles from ViT.
         multimodal_hashes (List[List[int]]): The multimodal hashes of each multimodal item in the request.
@@ -32,6 +34,8 @@ class DisaggregatedParams:
     ctx_request_id: Optional[int] = None
     opaque_state: Optional[bytes] = None
     draft_tokens: Optional[List[int]] = None
+    # If disagg_request_id is set, both context and generation requests will use it as the underlying request id.
+    disagg_request_id: Optional[int] = None

     # E-P Disaggregated Params
     multimodal_embedding_handles: Optional[List[Dict[str, Any]]] = (
@@ -44,8 +48,12 @@ class DisaggregatedParams:
     mrope_position_deltas_handle: Optional[Dict[str, Any]] = None

     def get_context_phase_params(self) -> tllme.ContextPhaseParams:
+        # Prefer disagg_request_id over ctx_request_id
+        request_id = (
+            self.disagg_request_id if self.disagg_request_id is not None else self.ctx_request_id
+        )
         return tllme.ContextPhaseParams(
-            self.first_gen_tokens, self.ctx_request_id, self.opaque_state, self.draft_tokens
+            self.first_gen_tokens, request_id, self.opaque_state, self.draft_tokens
        )

     def get_request_type(self) -> tllme.RequestType:

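Worth noting in get_context_phase_params: when both ids are present the disaggregated id wins, so the generation phase is linked to the context phase under the shared id. A small sketch of the fallback rule; the field values are hypothetical and the module path is taken from the file shown above:

# Hypothetical values; only the fallback rule comes from the code above.
from tensorrt_llm.disaggregated_params import DisaggregatedParams

params = DisaggregatedParams(request_type="generation_only",
                             ctx_request_id=7,
                             disagg_request_id=1 << 52)
request_id = (params.disagg_request_id
              if params.disagg_request_id is not None else params.ctx_request_id)
assert request_id == 1 << 52  # disagg_request_id is preferred over ctx_request_id
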
@@ -433,6 +433,8 @@ class BaseWorker(GenerationExecutor):

         context_phase_params = None
         request_type = tllm.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION
+        disagg_request_id = 0
+
         if request.disaggregated_params is not None:
             assert (
                 not self._is_pytorch_backend
@@ -441,6 +443,7 @@ class BaseWorker(GenerationExecutor):
                 == "context_and_generation"
             ), "kv_cache_transceiver is disabled, please set 'cache_transceiver_config: backend: <backend_type>' in the config file for disaggregated serving"
             request_type = request.disaggregated_params.get_request_type()
+            disagg_request_id = request.disaggregated_params.disagg_request_id
             if request_type == tllm.RequestType.REQUEST_TYPE_GENERATION_ONLY:
                 context_phase_params = request.disaggregated_params.get_context_phase_params(
                 )
@@ -559,7 +562,8 @@ class BaseWorker(GenerationExecutor):
             kv_cache_retention_config=request.kv_cache_retention_config,
             context_phase_params=context_phase_params,
             type=request_type,
-            cache_salt_id=request.cache_salt_id)
+            cache_salt_id=request.cache_salt_id,
+            disagg_request_id=disagg_request_id)
         executor_request.py_num_logprobs = request.sampling_params.logprobs
         executor_request.py_lora_path = py_lora_path

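The worker initializes disagg_request_id to 0, which appears to act as a "no disaggregated id" default: non-disaggregated requests pass 0 down to the executor request, and only requests arriving with disaggregated params override it. A condensed sketch of that resolution; the helper name is invented for illustration:

# Invented helper name; mirrors the plumbing above.
def resolve_disagg_request_id(disaggregated_params) -> int:
    disagg_request_id = 0  # default when no disaggregated params are attached
    if disaggregated_params is not None:
        disagg_request_id = disaggregated_params.disagg_request_id
    return disagg_request_id
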
@@ -212,7 +212,7 @@ class GenerationExecutor(ABC):

         return futures

-    def _get_next_client_id(self):
+    def _get_next_client_id(self) -> int:
         # (self._last_client_id + 1) % UINT64_MAX
         self._last_client_id = (self._last_client_id + 1) & ((1 << 64) - 1)
         return self._last_client_id

@@ -233,7 +233,8 @@ class RayExecutor(RpcExecutorMixin, GenerationExecutor):
         Low-level API to the executor. Return a "future" GenerationResult
         which can be waited. Forwards the request to the workers through RPC.
         """
-        request.set_id(self._get_next_client_id())
+        if request.id is None:
+            request.set_id(self._get_next_client_id())
         logprob_params = self._get_logprob_params(request)

         with nvtx_range_debug("rpc_submit"):

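The Ray executor change makes id assignment conditional: an id already set upstream (for example, a disaggregated id) is no longer overwritten. A toy illustration of the guard; the class below is a stand-in, not the real request type:

# Stand-in request type, purely to illustrate the guard.
class FakeRequest:
    def __init__(self, id=None):
        self.id = id

    def set_id(self, new_id):
        self.id = new_id

req = FakeRequest(id=999)   # id pre-assigned upstream
if req.id is None:          # the new guard
    req.set_id(1)
assert req.id == 999        # the pre-assigned id survives
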
@@ -1,4 +1,7 @@
 import logging
+import threading
+import time
+import uuid
 from dataclasses import dataclass, field
 from enum import IntEnum
 from typing import Any, Dict, List, Literal, Optional, Tuple
@@ -76,6 +79,9 @@ class DisaggServerConfig():
     max_retries: int = 1
     perf_metrics_max_requests: int = 0
     disagg_cluster_config: Optional[DisaggClusterConfig] = None
+    node_id: int = uuid.getnode(
+    ) % 1021  # Assuming only one disagg-server runs per machine; mod the MAC by the largest 10-bit prime
+    # If this causes collisions, users can set node_id manually within range [0, 1023] in the config


 @dataclass

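The default node_id is derived from the host's MAC address, reduced modulo 1021 so it fits the 10-bit machine_id field of the global id (1021 is the largest prime below 2**10, presumably chosen so that structured MAC addresses spread more evenly). A quick check of the derivation:

# Demonstrates the default node_id derivation used above.
import uuid

mac = uuid.getnode()         # 48-bit MAC, or a random fallback per CPython docs
node_id = mac % 1021
assert 0 <= node_id <= 1020  # always representable in 10 bits
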
@@ -331,3 +337,48 @@ def parse_metadata_server_config_file(
     with open(metadata_server_config_file, 'r') as file:
         config = yaml.safe_load(file)
     return MetadataServerConfig(**config)
+
+
+MIN_GLOBAL_ID = 1 << 42
+
+# The GIL may be removed in the future, so use a lock to protect the counter
+_global_disagg_request_id_lock = threading.Lock()
+_global_disagg_request_id_counter = 0
+
+
+def get_global_disagg_request_id(machine_id: int) -> int:
+    """
+    A snowflake-style global disagg request id that doesn't guarantee monotonicity.
+    bit 0: sign bit, always 0 (positive integer)
+    bits 1-41 (41 bits): timestamp_ms
+    bits 42-51 (10 bits): machine_id
+    bits 52-63 (12 bits): counter
+    """
+    global _global_disagg_request_id_lock
+    global _global_disagg_request_id_counter
+
+    COUNTER_BITS = 12
+    MACHINE_ID_BITS = 10
+    COUNTER_MASK = (1 << COUNTER_BITS) - 1
+    MAX_INT64 = (1 << 63) - 1
+
+    if machine_id not in range(0, (1 << MACHINE_ID_BITS) - 1):
+        raise ValueError(
+            f"machine_id must be in range [0, {(1 << MACHINE_ID_BITS) - 1})")
+
+    timestamp_ms = int(time.monotonic() * 1000)
+    with _global_disagg_request_id_lock:
+        counter = _global_disagg_request_id_counter & COUNTER_MASK
+        _global_disagg_request_id_counter += 1
+
+    # Rotate in [MIN_GLOBAL_ID, MAX_INT64)
+    # [0, MIN_GLOBAL_ID) is reserved for local ids
+    global_id = (timestamp_ms << (MACHINE_ID_BITS + COUNTER_BITS)) | (
+        machine_id << COUNTER_BITS) | counter
+    global_id_int64 = global_id % (MAX_INT64 - MIN_GLOBAL_ID) + MIN_GLOBAL_ID
+    return global_id_int64
+
+
+def get_local_request_id(last_id: int) -> int:
+    """Increment last_id by 1 and wrap it modulo MIN_GLOBAL_ID."""
+    return (last_id + 1) & (MIN_GLOBAL_ID - 1)

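A compact sketch of the packing performed above, under the documented layout (sign bit, 41-bit millisecond timestamp, 10-bit machine id, 12-bit counter); the final rotation keeps every global id inside [MIN_GLOBAL_ID, 2**63 - 1), the range reserved for disaggregated ids, while local ids occupy [0, MIN_GLOBAL_ID). The values below are made up:

# Mirrors the bit packing in get_global_disagg_request_id; values are made up.
MIN_GLOBAL_ID = 1 << 42
MAX_INT64 = (1 << 63) - 1
MACHINE_ID_BITS, COUNTER_BITS = 10, 12

def pack_global_id(timestamp_ms: int, machine_id: int, counter: int) -> int:
    raw = ((timestamp_ms << (MACHINE_ID_BITS + COUNTER_BITS))
           | (machine_id << COUNTER_BITS)
           | (counter & ((1 << COUNTER_BITS) - 1)))
    # Rotate into [MIN_GLOBAL_ID, MAX_INT64) so global ids never collide
    # with local ids.
    return raw % (MAX_INT64 - MIN_GLOBAL_ID) + MIN_GLOBAL_ID

gid = pack_global_id(timestamp_ms=1_000_000, machine_id=517, counter=9)
assert MIN_GLOBAL_ID <= gid < MAX_INT64
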
@@ -23,6 +23,7 @@ from tensorrt_llm.llmapi.disagg_utils import (
     DisaggServerConfig,
     MetadataServerConfig,
     ServerRole,
+    get_global_disagg_request_id,
 )
 from tensorrt_llm.logger import logger
 from tensorrt_llm.serve.cluster_storage import ClusterStorage, WatchEventType
@@ -115,14 +116,15 @@ class OpenAIDisaggregatedService(OpenAIService):
         need_ctx = need_ctx and not await self._check_gen_only_disagg(request)
         ctx_response = None
         gen_req = request
+        disagg_request_id = get_global_disagg_request_id(self._config.node_id)
         if need_ctx:
-            ctx_req = self._get_ctx_request(request)
+            ctx_req = self._get_ctx_request(request, disagg_request_id)
             # ctx generator is empty
             ctx_response = await self._ctx_client.send_request(
                 ctx_req, server=reserved_ctx_server, hooks=hooks
             )
             await self._verify_ctx_response(ctx_response)
-            gen_req = self._get_gen_request(request, ctx_response)
+            gen_req = self._get_gen_request(request, ctx_response, disagg_request_id)
         if ctx_response is None or self._need_gen(ctx_response):
             return await self._gen_client.send_request(
                 gen_req, server=reserved_gen_server, hooks=hooks
@@ -140,9 +142,13 @@ class OpenAIDisaggregatedService(OpenAIService):
             return False
         return True

-    def _get_ctx_request(self, request: UCompletionRequest) -> UCompletionRequest:
+    def _get_ctx_request(
+        self, request: UCompletionRequest, disagg_request_id: Optional[int]
+    ) -> UCompletionRequest:
         ctx_request = copy.deepcopy(request)
-        ctx_request.disaggregated_params = DisaggregatedParams(request_type="context_only")
+        ctx_request.disaggregated_params = DisaggregatedParams(
+            request_type="context_only", disagg_request_id=disagg_request_id
+        )
         ctx_request.stream = False
         ctx_request.stream_options = None
         return ctx_request
@@ -151,6 +157,7 @@ class OpenAIDisaggregatedService(OpenAIService):
         self,
         request: UCompletionRequest,
         ctx_response: UCompletionResponse,
+        disagg_request_id: Optional[int],
     ) -> UCompletionRequest:
         request.disaggregated_params = ctx_response.choices[0].disaggregated_params
         request.disaggregated_params.request_type = "generation_only"

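End to end, the service now mints one global id per incoming OpenAI request and stamps it on both phases, so the context-only and generation-only requests are tracked under the same underlying request id on their respective executors. A schematic of the flow, with synchronous pseudo-calls and abbreviated names standing in for the real async RPC methods:

# Schematic only; the real methods are async and RPC-backed.
disagg_request_id = get_global_disagg_request_id(node_id)   # minted once
ctx_req = get_ctx_request(request, disagg_request_id)       # context_only
ctx_response = ctx_client_send(ctx_req)
gen_req = get_gen_request(request, ctx_response, disagg_request_id)
gen_response = gen_client_send(gen_req)                     # generation_only
# Both executors register the request under the same disagg_request_id.
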
@@ -117,6 +117,7 @@ class DisaggregatedParams(OpenAIBaseModel):
     ctx_request_id: Optional[int] = None
     encoded_opaque_state: Optional[str] = None
     draft_tokens: Optional[List[int]] = None
+    disagg_request_id: Optional[int] = None


 class ErrorResponse(OpenAIBaseModel):
@@ -1000,7 +1001,8 @@ def to_disaggregated_params(
         ctx_request_id=tllm_disagg_params.ctx_request_id,
         encoded_opaque_state=encode_opaque_state(
             tllm_disagg_params.opaque_state),
-        draft_tokens=tllm_disagg_params.draft_tokens)
+        draft_tokens=tllm_disagg_params.draft_tokens,
+        disagg_request_id=tllm_disagg_params.disagg_request_id)


 def to_llm_disaggregated_params(
@@ -1013,7 +1015,8 @@ def to_llm_disaggregated_params(
         ctx_request_id=disaggregated_params.ctx_request_id,
         opaque_state=decode_opaque_state(
             disaggregated_params.encoded_opaque_state),
-        draft_tokens=disaggregated_params.draft_tokens)
+        draft_tokens=disaggregated_params.draft_tokens,
+        disagg_request_id=disaggregated_params.disagg_request_id)


 UCompletionRequest = Union[CompletionRequest, ChatCompletionRequest]

@@ -1,11 +1,13 @@
+from concurrent.futures import ThreadPoolExecutor
+
 import pytest
 import yaml

 # isort: off
 from tensorrt_llm.llmapi.disagg_utils import (
-    CtxGenServerConfig, DisaggServerConfig, extract_ctx_gen_cfgs,
-    extract_router_config, extract_disagg_cfg, get_server_configs_dict,
-    parse_disagg_config_file)
+    MIN_GLOBAL_ID, CtxGenServerConfig, DisaggServerConfig, extract_ctx_gen_cfgs,
+    extract_router_config, extract_disagg_cfg, get_global_disagg_request_id,
+    get_local_request_id, get_server_configs_dict, parse_disagg_config_file)
 # isort: on


@@ -155,3 +157,45 @@ def test_get_server_configs_dict():
     assert len(server_dict) == 2
     assert ("host1", 8001) in server_dict
     assert ("host2", 8002) in server_dict
+
+
+# test get_global_disagg_request_id
+@pytest.mark.parametrize("multithread", [True, False],
+                         ids=["multithread", "singlethread"])
+def test_get_global_disagg_request_id(multithread):
+    iter = 10000
+
+    def get_ids(node_ids):
+        all_node_ids = [[] for _ in range(len(node_ids))]
+        for _ in range(iter):
+            for i, node_id in enumerate(node_ids):
+                all_node_ids[i].append(get_global_disagg_request_id(node_id))
+        return all_node_ids
+
+    node_ids = list(range(10))
+    if multithread:
+        with ThreadPoolExecutor(max_workers=len(node_ids)) as executor:
+            all_node_ids = [
+                ids[0] for ids in executor.map(get_ids, [[i] for i in node_ids])
+            ]
+    else:
+        all_node_ids = get_ids(node_ids)
+
+    all_ids = set(i for ids in all_node_ids for i in ids)
+    assert len(all_ids) == iter * len(node_ids)
+    assert all(id >= MIN_GLOBAL_ID and id < ((1 << 63) - 1) for id in all_ids)
+
+
+def test_get_local_request_id():
+    last_id = MIN_GLOBAL_ID - 100
+    ids = set()
+    for i in range(1000):
+        last_id = get_local_request_id(last_id)
+        assert last_id >= 0
+        assert last_id < MIN_GLOBAL_ID
+        ids.add(last_id)
+    assert len(ids) == 1000
+    assert min(ids) == 0
+    assert max(ids) == MIN_GLOBAL_ID - 1
+    assert max(ids) - min(ids) > (
+        1 << 40)  # ensure there is enough space for local ids