This commit is contained in:
Lizhi Zhou 2026-01-13 21:25:09 +08:00 committed by GitHub
commit f8cd142413
15 changed files with 203 additions and 45 deletions

View File

@@ -684,6 +684,7 @@ public:
/// finish reason. The request may exceed this time slightly, but at most by 1 forward pass (in pipeline parallelism
/// that may involve multiple micro-batches). A request can be timed-out before ever being scheduled.
/// @param cacheSaltID Salt ID for KV cache blocks, limiting KV cache reuse to requests with the same salt string.
/// @param disaggRequestId Disaggregated request ID. If set, the corresponding context and generation requests share the same underlying request ID.
Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming = false,
SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(),
std::optional<SizeType32> const& endId = std::nullopt, std::optional<SizeType32> const& padId = std::nullopt,
@@ -711,7 +712,8 @@ public:
std::optional<GuidedDecodingParams> guidedDecodingParams = std::nullopt,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt);
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt,
std::optional<IdType> disaggRequestId = std::nullopt);
/// @brief Requests using this logits postprocessor name are dispatched to the batched logits postprocessor
static auto constexpr kBatchedPostProcessorName = "batched";
@@ -761,6 +763,7 @@ public:
[[nodiscard]] std::optional<MillisecondsType> getAllottedTimeMs() const;
[[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID() const;
[[nodiscard]] std::optional<std::vector<std::string>> getAdditionalOutputNames() const;
[[nodiscard]] std::optional<IdType> getDisaggRequestId() const;
void setStreaming(bool streaming);
void setSamplingConfig(SamplingConfig const& config);
@@ -796,6 +799,7 @@ public:
void setLanguageAdapterUid(SizeType32 languageAdapterUid);
void setAllottedTimeMs(MillisecondsType allottedTimeMs);
void setCacheSaltID(CacheSaltIDType cacheSaltID);
void setDisaggRequestId(IdType disaggRequestId);
private:
friend class Serialization;

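A minimal usage sketch of the new field as it surfaces through the Python bindings added later in this commit (import path assumed for this sketch; the argument and property names match the nb::arg/py::arg declarations below):

    from tensorrt_llm.bindings import executor as tle  # import path assumed

    req = tle.Request(
        input_token_ids=[1, 2, 3],
        max_tokens=16,
        disagg_request_id=12345,  # new optional argument, defaults to None
    )
    assert req.disagg_request_id == 12345
    req.disagg_request_id = 67890  # read-write property, backed by setDisaggRequestId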
View File

@@ -40,7 +40,8 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming,
std::optional<SizeType32> encoderOutputLength, std::optional<Tensor> crossAttentionMask,
SizeType32 numReturnSequences, std::optional<EagleConfig> eagleConfig, std::optional<Tensor> skipCrossAttnBlocks,
std::optional<GuidedDecodingParams> guidedDecodingParams, std::optional<SizeType32> languageAdapterUid,
std::optional<MillisecondsType> allottedTimeMs, std::optional<CacheSaltIDType> cacheSaltID)
std::optional<MillisecondsType> allottedTimeMs, std::optional<CacheSaltIDType> cacheSaltID,
std::optional<IdType> disaggRequestId)
: mImpl(std::make_unique<Impl>(std::move(inputTokenIds), maxTokens, streaming, samplingConfig, outputConfig, endId,
padId, std::move(positionIds), std::move(badWords), std::move(stopWords), std::move(embeddingBias),
std::move(externalDraftTokensConfig), std::move(pTuningConfig), std::move(multimodalInput),
@@ -49,7 +50,7 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming,
std::move(encoderInputTokenIds), clientId, returnAllGeneratedTokens, priority, type,
std::move(contextPhaseParams), std::move(encoderInputFeatures), encoderOutputLength, crossAttentionMask,
numReturnSequences, eagleConfig, skipCrossAttnBlocks, std::move(guidedDecodingParams), languageAdapterUid,
allottedTimeMs, cacheSaltID))
allottedTimeMs, cacheSaltID, disaggRequestId))
{
}
@@ -253,6 +254,11 @@ std::optional<CacheSaltIDType> Request::getCacheSaltID() const
return mImpl->getCacheSaltID();
}
std::optional<IdType> Request::getDisaggRequestId() const
{
return mImpl->getDisaggRequestId();
}
void Request::setStreaming(bool streaming)
{
mImpl->setStreaming(streaming);
@@ -310,12 +316,12 @@ void Request::setPromptTuningConfig(PromptTuningConfig const& pTuningConfig)
void Request::setMultimodalEmbedding(Tensor const& multimodalEmbedding)
{
return mImpl->setMultimodalEmbedding(multimodalEmbedding);
mImpl->setMultimodalEmbedding(multimodalEmbedding);
}
void Request::setMultimodalInput(MultimodalInput const& multimodalInput)
{
return mImpl->setMultimodalInput(multimodalInput);
mImpl->setMultimodalInput(multimodalInput);
}
void Request::setMropeConfig(MropeConfig const& mRopeConfig)
@@ -400,7 +406,7 @@ void Request::setEagleConfig(std::optional<EagleConfig> const& eagleConfig)
void Request::setSkipCrossAttnBlocks(Tensor skipCrossAttnBlocks)
{
return mImpl->setSkipCrossAttnBlocks(skipCrossAttnBlocks);
mImpl->setSkipCrossAttnBlocks(skipCrossAttnBlocks);
}
void Request::setGuidedDecodingParams(GuidedDecodingParams const& guidedDecodingParams)
@@ -410,16 +416,21 @@ void Request::setGuidedDecodingParams(GuidedDecodingParams const& guidedDecoding
void Request::setAllottedTimeMs(MillisecondsType allottedTimeMs)
{
return mImpl->setAllottedTimeMs(allottedTimeMs);
mImpl->setAllottedTimeMs(allottedTimeMs);
}
void Request::setLanguageAdapterUid(SizeType32 languageAdapterUid)
{
return mImpl->setLanguageAdapterUid(languageAdapterUid);
mImpl->setLanguageAdapterUid(languageAdapterUid);
}
void Request::setCacheSaltID(CacheSaltIDType cacheSaltID)
{
return mImpl->setCacheSaltID(cacheSaltID);
mImpl->setCacheSaltID(cacheSaltID);
}
void Request::setDisaggRequestId(IdType disaggRequestId)
{
mImpl->setDisaggRequestId(disaggRequestId);
}
} // namespace tensorrt_llm::executor

View File

@@ -48,7 +48,7 @@ public:
std::optional<Tensor> crossAttentionMask, SizeType32 numReturnSequences, std::optional<EagleConfig> eagleConfig,
std::optional<Tensor> skipCrossAttnBlocks, std::optional<GuidedDecodingParams> guidedDecodingParams,
std::optional<SizeType32> languageAdapterUid, std::optional<MillisecondsType> allottedTimeMs,
std::optional<CacheSaltIDType> cacheSaltID)
std::optional<CacheSaltIDType> cacheSaltID, std::optional<IdType> disaggRequestId)
: mInputTokenIds(std::move(inputTokenIds))
, mMaxNewTokens(maxNewTokens)
, mStreaming(streaming)
@@ -86,6 +86,7 @@ public:
, mLanguageAdapterUid(languageAdapterUid)
, mAllottedTimeMs(allottedTimeMs)
, mCacheSaltID(cacheSaltID)
, mDisaggRequestId(disaggRequestId)
{
validate();
}
@@ -302,6 +303,11 @@ public:
return mCacheSaltID;
}
[[nodiscard]] std::optional<IdType> getDisaggRequestId() const
{
return mDisaggRequestId;
}
void setStreaming(bool streaming)
{
mStreaming = streaming;
@@ -481,6 +487,11 @@ public:
mCacheSaltID = cacheSaltID;
}
void setDisaggRequestId(IdType disaggRequestId)
{
mDisaggRequestId = disaggRequestId;
}
private:
void validate()
{
@@ -555,6 +566,7 @@ private:
lambda(mLanguageAdapterUid);
lambda(mAllottedTimeMs ? std::make_optional(mAllottedTimeMs->count()) : std::nullopt);
lambda(mCacheSaltID);
lambda(mDisaggRequestId);
}
VecTokens mInputTokenIds;
@@ -594,6 +606,7 @@ private:
std::optional<SizeType32> mLanguageAdapterUid;
std::optional<MillisecondsType> mAllottedTimeMs;
std::optional<CacheSaltIDType> mCacheSaltID;
std::optional<IdType> mDisaggRequestId;
};
} // namespace tensorrt_llm::executor

View File

@@ -578,11 +578,11 @@ void initRequestBindings(nb::module_& m)
self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(),
self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(),
self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(),
self.getGuidedDecodingParams(), self.getCacheSaltID());
self.getGuidedDecodingParams(), self.getCacheSaltID(), self.getDisaggRequestId());
};
auto requestSetstate = [](tle::Request& self, nb::tuple const& state)
{
if (state.size() != 34)
if (state.size() != 35)
{
throw std::runtime_error("Invalid Request state!");
}
@@ -606,8 +606,8 @@ void initRequestBindings(nb::module_& m)
nb::cast<std::optional<tle::Tensor>>(state[27]), nb::cast<std::optional<SizeType32>>(state[28]),
nb::cast<std::optional<tle::Tensor>>(state[29]), 1, nb::cast<std::optional<tle::EagleConfig>>(state[30]),
nb::cast<std::optional<tle::Tensor>>(state[31]),
nb::cast<std::optional<tle::GuidedDecodingParams>>(state[32]),
nb::cast<std::optional<tle::CacheSaltIDType>>(state[33]));
nb::cast<std::optional<tle::GuidedDecodingParams>>(state[32]), std::nullopt, std::nullopt,
nb::cast<std::optional<tle::CacheSaltIDType>>(state[33]), nb::cast<std::optional<tle::IdType>>(state[34]));
};
nb::class_<tle::Request> request(m, "Request", nb::dynamic_attr());
@@ -648,7 +648,8 @@ void initRequestBindings(nb::module_& m)
std::optional<tle::GuidedDecodingParams>, // guidedDecodingParams
std::optional<tle::SizeType32>, // languageAdapterUid
std::optional<tle::MillisecondsType>, // allottedTimeMs
std::optional<tle::CacheSaltIDType> // cacheSaltID
std::optional<tle::CacheSaltIDType>, // cacheSaltID
std::optional<tle::IdType> // disaggRequestId
>(),
// clang-format off
nb::arg("input_token_ids"),
@@ -688,8 +689,9 @@ void initRequestBindings(nb::module_& m)
nb::arg("guided_decoding_params") = nb::none(),
nb::arg("language_adapter_uid") = nb::none(),
nb::arg("allotted_time_ms") = nb::none(),
nb::arg("cache_salt_id") = nb::none()
) // clang-format on
nb::arg("cache_salt_id") = nb::none(),
nb::arg("disagg_request_id") = nb::none()
) // clang-format on
.def_prop_ro("input_token_ids", &tle::Request::getInputTokenIds)
.def_prop_ro("max_tokens", &tle::Request::getMaxTokens)
.def_prop_rw("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming)
@@ -733,6 +735,7 @@ void initRequestBindings(nb::module_& m)
.def_prop_rw("allotted_time_ms", &tle::Request::getAllottedTimeMs, &tle::Request::setAllottedTimeMs)
.def_prop_rw("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID)
.def_prop_rw("context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams)
.def_prop_rw("disagg_request_id", &tle::Request::getDisaggRequestId, &tle::Request::setDisaggRequestId)
.def("__getstate__", requestGetstate)
.def("__setstate__", requestSetstate);
request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName;
@@ -929,8 +932,8 @@ void initRequestBindings(nb::module_& m)
{
throw std::runtime_error("Invalid Request state!");
}
new (&response) tle::Response(
nb::cast<SizeType32>(state[0]), nb::cast<tle::Result>(state[1]), nb::cast<SizeType32>(state[2]));
new (&response)
tle::Response(nb::cast<IdType>(state[0]), nb::cast<tle::Result>(state[1]), nb::cast<IdType>(state[2]));
};
nb::class_<tle::Response>(m, "Response")

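Because the pickled state tuple grows from 34 to 35 entries, a pickle round-trip now preserves the new field. A hedged sketch (import path assumed):

    import pickle

    from tensorrt_llm.bindings import executor as tle  # import path assumed

    req = tle.Request(input_token_ids=[1, 2, 3], max_tokens=8, disagg_request_id=42)
    restored = pickle.loads(pickle.dumps(req))
    assert restored.disagg_request_id == 42  # carried as state[34] in the tuple above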
View File

@@ -85,8 +85,7 @@ public:
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt,
std::optional<TimePoint> arrivalTime = std::nullopt)
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt, std::optional<TimePoint> arrivalTime = std::nullopt)
: Base(requestId, //
maxNewTokens, //
std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)), //

View File

@@ -531,11 +531,11 @@ void initRequestBindings(pybind11::module_& m)
self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(),
self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(),
self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(),
self.getGuidedDecodingParams(), self.getCacheSaltID());
self.getGuidedDecodingParams(), self.getCacheSaltID(), self.getDisaggRequestId());
};
auto requestSetstate = [](py::tuple const& state)
{
if (state.size() != 34)
if (state.size() != 35)
{
throw std::runtime_error("Invalid Request state!");
}
@@ -556,7 +556,8 @@ void initRequestBindings(pybind11::module_& m)
state[27].cast<std::optional<tle::Tensor>>(), state[28].cast<std::optional<SizeType32>>(),
state[29].cast<std::optional<tle::Tensor>>(), 1, state[30].cast<std::optional<tle::EagleConfig>>(),
state[31].cast<std::optional<tle::Tensor>>(), state[32].cast<std::optional<tle::GuidedDecodingParams>>(),
state[33].cast<std::optional<tle::CacheSaltIDType>>());
std::nullopt, std::nullopt, state[33].cast<std::optional<tle::CacheSaltIDType>>(),
state[34].cast<std::optional<tle::IdType>>());
};
py::class_<tle::Request> request(m, "Request", pybind11::dynamic_attr());
@@ -597,7 +598,8 @@ void initRequestBindings(pybind11::module_& m)
std::optional<tle::GuidedDecodingParams>, // guidedDecodingParams
std::optional<tle::SizeType32>, // languageAdapterUid
std::optional<tle::MillisecondsType>, // allottedTimeMs
std::optional<tle::CacheSaltIDType> // cacheSaltID
std::optional<tle::CacheSaltIDType>, // cacheSaltID
std::optional<tle::IdType> // disaggRequestId
>(),
// clang-format off
py::arg("input_token_ids"),
@@ -638,8 +640,9 @@ void initRequestBindings(pybind11::module_& m)
py::arg("guided_decoding_params") = py::none(),
py::arg("language_adapter_uid") = py::none(),
py::arg("allotted_time_ms") = py::none(),
py::arg("cache_salt_id") = py::none()
) // clang-format on
py::arg("cache_salt_id") = py::none(),
py::arg("disagg_request_id") = py::none()
) // clang-format on
.def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds)
.def_property_readonly("max_tokens", &tle::Request::getMaxTokens)
.def_property("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming)
@@ -686,6 +689,7 @@ void initRequestBindings(pybind11::module_& m)
.def_property("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID)
.def_property(
"context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams)
.def_property("disagg_request_id", &tle::Request::getDisaggRequestId, &tle::Request::setDisaggRequestId)
.def(py::pickle(requestGetstate, requestSetstate));
request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName;
@@ -870,7 +874,7 @@ void initRequestBindings(pybind11::module_& m)
throw std::runtime_error("Invalid Request state!");
}
return std::make_unique<tle::Response>(
state[0].cast<SizeType32>(), state[1].cast<tle::Result>(), state[2].cast<SizeType32>());
state[0].cast<IdType>(), state[1].cast<tle::Result>(), state[2].cast<IdType>());
};
py::class_<tle::Response>(m, "Response")

View File

@@ -11,6 +11,7 @@ from typing import Dict, Iterable, List, Optional, Tuple
import torch
from tensorrt_llm._utils import mpi_disabled, nvtx_range
from tensorrt_llm.llmapi.disagg_utils import get_local_request_id
from tensorrt_llm.mapping import CpType
from ..distributed import Distributed
@@ -206,10 +207,15 @@ class ExecutorRequestQueue:
return False
def _get_request_id(self):
# (next_request_id + 1) % UINT64_MAX
def _get_request_id(self, request: Optional[ExecutorRequest] = None):
# if request has a disagg_request_id, use it as request id so that
# corresponding context and generation requests have the same request id
if request and request.disagg_request_id and isinstance(
request.disagg_request_id, int):
return request.disagg_request_id
current_id = self.next_request_id
self.next_request_id = (self.next_request_id + 1) & ((1 << 64) - 1)
self.next_request_id = get_local_request_id(current_id)
return current_id
def _generate_child_request_ids(
@@ -236,7 +242,7 @@
assert self.active, "PyExecutor has already been shutdown."
start_time = time.time()
for request, query in requests_and_queries:
req_id = self._get_request_id()
req_id = self._get_request_id(request)
if self.enable_iter_perf_stats:
self.start_times[req_id] = start_time
child_req_ids = self._generate_child_request_ids(request)

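The id-resolution rule above, condensed into a standalone sketch (function and variable names invented for illustration; MIN_GLOBAL_ID matches the constant introduced in disagg_utils later in this commit):

    MIN_GLOBAL_ID = 1 << 42  # local ids stay below this; disagg ids start at it

    def resolve_request_id(next_local_id, disagg_request_id=None):
        """Prefer a caller-supplied disagg id; otherwise hand out a local id."""
        if isinstance(disagg_request_id, int) and disagg_request_id:
            return disagg_request_id, next_local_id  # local counter untouched
        return next_local_id, (next_local_id + 1) & (MIN_GLOBAL_ID - 1)

    assert resolve_request_id(7) == (7, 8)  # local path: id 7, counter advances
    assert resolve_request_id(7, (1 << 42) + 5) == ((1 << 42) + 5, 7)  # global path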
View File

@@ -21,6 +21,8 @@ class DisaggregatedParams:
ctx_request_id (int): The context request id
opaque_state(bytes): Any additional state needing to be exchanged between context and gen instances
draft_tokens (List[int]): The draft tokens of the generation request
disagg_request_id (int): The disaggregated request id. If set, both context and generation requests will use it
as the underlying request id.
multimodal_embedding_handles (List[Dict[str, Any]]): The resulting multimodal embedding handles from ViT.
multimodal_hashes (List[List[int]]): The multimodal hashes of each multimodal item in the request.
@@ -32,6 +34,8 @@
ctx_request_id: Optional[int] = None
opaque_state: Optional[bytes] = None
draft_tokens: Optional[List[int]] = None
# If disagg_request_id is set, both context and generation requests will use it as the underlying request id.
disagg_request_id: Optional[int] = None
# E-P Disaggregated Params
multimodal_embedding_handles: Optional[List[Dict[str, Any]]] = (
@@ -44,8 +48,12 @@
mrope_position_deltas_handle: Optional[Dict[str, Any]] = None
def get_context_phase_params(self) -> tllme.ContextPhaseParams:
# Prefer disagg_request_id over ctx_request_id
request_id = (
self.disagg_request_id if self.disagg_request_id is not None else self.ctx_request_id
)
return tllme.ContextPhaseParams(
self.first_gen_tokens, self.ctx_request_id, self.opaque_state, self.draft_tokens
self.first_gen_tokens, request_id, self.opaque_state, self.draft_tokens
)
def get_request_type(self) -> tllme.RequestType:

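A short usage sketch of the preference rule in get_context_phase_params (import path assumed; the id value is arbitrary):

    from tensorrt_llm.llmapi import DisaggregatedParams  # import path assumed

    params = DisaggregatedParams(
        request_type="context_only",
        ctx_request_id=11,
        disagg_request_id=(1 << 42) + 7,
    )
    # get_context_phase_params() builds ContextPhaseParams with
    # request_id == disagg_request_id; ctx_request_id is only the
    # fallback when disagg_request_id is None.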
View File

@@ -433,6 +433,8 @@ class BaseWorker(GenerationExecutor):
context_phase_params = None
request_type = tllm.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION
disagg_request_id = 0
if request.disaggregated_params is not None:
assert (
not self._is_pytorch_backend
@@ -441,6 +443,7 @@
== "context_and_generation"
), "kv_cache_transceiver is disabled, please set 'cache_transceiver_config: backend:<backend_type>` in config file for disaggregated serving"
request_type = request.disaggregated_params.get_request_type()
disagg_request_id = request.disaggregated_params.disagg_request_id
if request_type == tllm.RequestType.REQUEST_TYPE_GENERATION_ONLY:
context_phase_params = request.disaggregated_params.get_context_phase_params(
)
@@ -559,7 +562,8 @@
kv_cache_retention_config=request.kv_cache_retention_config,
context_phase_params=context_phase_params,
type=request_type,
cache_salt_id=request.cache_salt_id)
cache_salt_id=request.cache_salt_id,
disagg_request_id=disagg_request_id)
executor_request.py_num_logprobs = request.sampling_params.logprobs
executor_request.py_lora_path = py_lora_path

View File

@@ -212,7 +212,7 @@ class GenerationExecutor(ABC):
return futures
def _get_next_client_id(self):
def _get_next_client_id(self) -> int:
# Wrap around at 2**64 (uint64 overflow semantics)
self._last_client_id = (self._last_client_id + 1) & ((1 << 64) - 1)
return self._last_client_id

View File

@@ -233,7 +233,8 @@ class RayExecutor(RpcExecutorMixin, GenerationExecutor):
Low-level API to the executor. Return a "future" GenerationResult
which can be waited. Forwards the request to the workers through RPC.
"""
request.set_id(self._get_next_client_id())
if request.id is None:
request.set_id(self._get_next_client_id())
logprob_params = self._get_logprob_params(request)
with nvtx_range_debug("rpc_submit"):

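The effect of the new guard in isolation, as a toy sketch (the stub class is invented for illustration):

    class StubRequest:
        def __init__(self, id=None):
            self.id = id

        def set_id(self, id):
            self.id = id

    def assign_client_id(request, next_client_id):
        if request.id is None:  # keep an id assigned upstream, e.g. by the disagg flow
            request.set_id(next_client_id)
        return request.id

    assert assign_client_id(StubRequest(id=99), next_client_id=1) == 99  # preserved
    assert assign_client_id(StubRequest(), next_client_id=1) == 1  # newly assigned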
View File

@@ -1,4 +1,7 @@
import logging
import threading
import time
import uuid
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any, Dict, List, Literal, Optional, Tuple
@@ -76,6 +79,9 @@ class DisaggServerConfig():
max_retries: int = 1
perf_metrics_max_requests: int = 0
disagg_cluster_config: Optional[DisaggClusterConfig] = None
node_id: int = uuid.getnode() % 1021  # Assume one disagg-server per machine; MAC address modulo 1021, the largest 10-bit prime
# If this causes collisions, users can set node_id manually within range [0, 1023] in config
@dataclass
@@ -331,3 +337,48 @@
with open(metadata_server_config_file, 'r') as file:
config = yaml.safe_load(file)
return MetadataServerConfig(**config)
MIN_GLOBAL_ID = 1 << 42
# The GIL may be removed in future Python versions, so use a lock to protect the counter
_global_disagg_request_id_lock = threading.Lock()
_global_disagg_request_id_counter = 0
def get_global_disagg_request_id(machine_id: int) -> int:
"""
a snowflake global disagg request id that doesn't guarantee monotonicity
0: positive integer
1-41 41 bits: timestamp_ms
42-51 10 bits: machine_id
52-63 12 bits: counter
"""
global _global_disagg_request_id_lock
global _global_disagg_request_id_counter
COUNTER_BITS = 12
MACHINE_ID_BITS = 10
COUNTER_MASK = (1 << COUNTER_BITS) - 1
MAX_INT64 = (1 << 63) - 1
if machine_id not in range(0, 1 << MACHINE_ID_BITS):
raise ValueError(
f"machine_id must be in range [0, {(1 << MACHINE_ID_BITS) - 1}]")
timestamp_ms = int(time.monotonic() * 1000)
with _global_disagg_request_id_lock:
counter = _global_disagg_request_id_counter & COUNTER_MASK
_global_disagg_request_id_counter += 1
# Rotate in [MIN_GLOBAL_ID, MAX_INT64)
# [0, MIN_GLOBAL_ID) is reserved for local ids
global_id = (timestamp_ms << (MACHINE_ID_BITS + COUNTER_BITS)) | (
machine_id << COUNTER_BITS) | counter
global_id_int64 = global_id % (MAX_INT64 - MIN_GLOBAL_ID) + MIN_GLOBAL_ID
return global_id_int64
def get_local_request_id(last_id: int) -> int:
""" increment the last_id by 1 and mod by MIN_GLOBAL_ID """
return (last_id + 1) & (MIN_GLOBAL_ID - 1)

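A worked decomposition of the raw composed id, before the final modulo rotation into [MIN_GLOBAL_ID, MAX_INT64) (the constants mirror the function above; the helper is invented for illustration):

    COUNTER_BITS, MACHINE_ID_BITS = 12, 10

    def split_raw_global_id(raw_id):
        """Split a raw composed id into (timestamp_ms, machine_id, counter)."""
        counter = raw_id & ((1 << COUNTER_BITS) - 1)
        machine_id = (raw_id >> COUNTER_BITS) & ((1 << MACHINE_ID_BITS) - 1)
        timestamp_ms = raw_id >> (COUNTER_BITS + MACHINE_ID_BITS)
        return timestamp_ms, machine_id, counter

    raw = (1_234_567 << 22) | (17 << 12) | 42
    assert split_raw_global_id(raw) == (1_234_567, 17, 42)
    # Capacity: 12 counter bits allow 4096 ids per node per millisecond before
    # the counter wraps; 41 timestamp bits cover roughly 69.7 years of ms.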
View File

@@ -23,6 +23,7 @@ from tensorrt_llm.llmapi.disagg_utils import (
DisaggServerConfig,
MetadataServerConfig,
ServerRole,
get_global_disagg_request_id,
)
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.cluster_storage import ClusterStorage, WatchEventType
@@ -115,14 +116,15 @@ class OpenAIDisaggregatedService(OpenAIService):
need_ctx = need_ctx and not await self._check_gen_only_disagg(request)
ctx_response = None
gen_req = request
disagg_request_id = get_global_disagg_request_id(self._config.node_id)
if need_ctx:
ctx_req = self._get_ctx_request(request)
ctx_req = self._get_ctx_request(request, disagg_request_id)
# ctx generator is empty
ctx_response = await self._ctx_client.send_request(
ctx_req, server=reserved_ctx_server, hooks=hooks
)
await self._verify_ctx_response(ctx_response)
gen_req = self._get_gen_request(request, ctx_response)
gen_req = self._get_gen_request(request, ctx_response, disagg_request_id)
if ctx_response is None or self._need_gen(ctx_response):
return await self._gen_client.send_request(
gen_req, server=reserved_gen_server, hooks=hooks
@@ -140,9 +142,13 @@
return False
return True
def _get_ctx_request(self, request: UCompletionRequest) -> UCompletionRequest:
def _get_ctx_request(
self, request: UCompletionRequest, disagg_request_id: Optional[int]
) -> UCompletionRequest:
ctx_request = copy.deepcopy(request)
ctx_request.disaggregated_params = DisaggregatedParams(request_type="context_only")
ctx_request.disaggregated_params = DisaggregatedParams(
request_type="context_only", disagg_request_id=disagg_request_id
)
ctx_request.stream = False
ctx_request.stream_options = None
return ctx_request
@@ -151,6 +157,7 @@
self,
request: UCompletionRequest,
ctx_response: UCompletionResponse,
disagg_request_id: Optional[int],
) -> UCompletionRequest:
request.disaggregated_params = ctx_response.choices[0].disaggregated_params
request.disaggregated_params.request_type = "generation_only"

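End to end, one snowflake id is minted per incoming request and stamped on both halves, so the context and generation executors resolve to the same request id. A compact sketch of that flow (SimpleNamespace stands in for the real request objects):

    from types import SimpleNamespace

    from tensorrt_llm.llmapi.disagg_utils import get_global_disagg_request_id

    def make_disagg_pair(node_id):
        """Mint one id and attach it to a context/generation request pair."""
        rid = get_global_disagg_request_id(node_id)
        ctx = SimpleNamespace(request_type="context_only", disagg_request_id=rid)
        gen = SimpleNamespace(request_type="generation_only", disagg_request_id=rid)
        return ctx, gen

    ctx, gen = make_disagg_pair(node_id=3)
    assert ctx.disagg_request_id == gen.disagg_request_id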
View File

@@ -117,6 +117,7 @@ class DisaggregatedParams(OpenAIBaseModel):
ctx_request_id: Optional[int] = None
encoded_opaque_state: Optional[str] = None
draft_tokens: Optional[List[int]] = None
disagg_request_id: Optional[int] = None
class ErrorResponse(OpenAIBaseModel):
@@ -1000,7 +1001,8 @@
ctx_request_id=tllm_disagg_params.ctx_request_id,
encoded_opaque_state=encode_opaque_state(
tllm_disagg_params.opaque_state),
draft_tokens=tllm_disagg_params.draft_tokens)
draft_tokens=tllm_disagg_params.draft_tokens,
disagg_request_id=tllm_disagg_params.disagg_request_id)
def to_llm_disaggregated_params(
@@ -1013,7 +1015,8 @@
ctx_request_id=disaggregated_params.ctx_request_id,
opaque_state=decode_opaque_state(
disaggregated_params.encoded_opaque_state),
draft_tokens=disaggregated_params.draft_tokens)
draft_tokens=disaggregated_params.draft_tokens,
disagg_request_id=disaggregated_params.disagg_request_id)
UCompletionRequest = Union[CompletionRequest, ChatCompletionRequest]

View File

@@ -1,11 +1,13 @@
from concurrent.futures import ThreadPoolExecutor
import pytest
import yaml
# isort: off
from tensorrt_llm.llmapi.disagg_utils import (
CtxGenServerConfig, DisaggServerConfig, extract_ctx_gen_cfgs,
extract_router_config, extract_disagg_cfg, get_server_configs_dict,
parse_disagg_config_file)
MIN_GLOBAL_ID, CtxGenServerConfig, DisaggServerConfig, extract_ctx_gen_cfgs,
extract_router_config, extract_disagg_cfg, get_global_disagg_request_id,
get_local_request_id, get_server_configs_dict, parse_disagg_config_file)
# isort: on
@@ -155,3 +157,45 @@ def test_get_server_configs_dict():
assert len(server_dict) == 2
assert ("host1", 8001) in server_dict
assert ("host2", 8002) in server_dict
# test get_global_disagg_request_id
@pytest.mark.parametrize("multithread", [True, False],
ids=["multithread", "singlethread"])
def test_get_global_disagg_request_id(multithread):
num_iters = 10000
def get_ids(node_ids):
all_node_ids = [[] for _ in range(len(node_ids))]
for _ in range(num_iters):
for i, node_id in enumerate(node_ids):
all_node_ids[i].append(get_global_disagg_request_id(node_id))
return all_node_ids
node_ids = list(range(10))
if multithread:
with ThreadPoolExecutor(max_workers=len(node_ids)) as executor:
all_node_ids = [
ids[0] for ids in executor.map(get_ids, [[i] for i in node_ids])
]
else:
all_node_ids = get_ids(node_ids)
all_ids = set(i for ids in all_node_ids for i in ids)
assert len(all_ids) == num_iters * len(node_ids)
assert all(MIN_GLOBAL_ID <= i < ((1 << 63) - 1) for i in all_ids)
def test_get_local_request_id():
last_id = MIN_GLOBAL_ID - 100
ids = set()
for i in range(1000):
last_id = get_local_request_id(last_id)
assert last_id >= 0
assert last_id < MIN_GLOBAL_ID
ids.add(last_id)
assert len(ids) == 1000
assert min(ids) == 0
assert max(ids) == MIN_GLOBAL_ID - 1
assert max(ids) - min(ids) > (
1 << 40) # ensure there is enough space for local ids