diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h
index 787fa0bb7e..1e5cc16a05 100644
--- a/cpp/include/tensorrt_llm/executor/executor.h
+++ b/cpp/include/tensorrt_llm/executor/executor.h
@@ -684,6 +684,7 @@ public:
     /// finish reason. The request may exceed this time slightly, but at most by 1 forward pass (in pipeline parallelism
     /// that may involve multiple micro-batches). A request can be timed-out before ever being scheduled.
     /// @param cacheSaltID Salt ID for KV cache blocks to limit the kv cache reuse to the requests with the same string.
+    /// @param disaggRequestId Disaggregated request ID. If set, the paired context and generation requests share it.
     Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming = false,
         SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(),
         std::optional<SizeType32> const& endId = std::nullopt, std::optional<SizeType32> const& padId = std::nullopt,
@@ -711,7 +712,8 @@ public:
         std::optional<GuidedDecodingParams> guidedDecodingParams = std::nullopt,
         std::optional<SizeType32> languageAdapterUid = std::nullopt,
         std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
-        std::optional<CacheSaltIDType> cacheSaltID = std::nullopt);
+        std::optional<CacheSaltIDType> cacheSaltID = std::nullopt,
+        std::optional<IdType> disaggRequestId = std::nullopt);

     /// @brief This logits postprocessor name will dispatch to the batched logits postprocessor
     static auto constexpr kBatchedPostProcessorName = "batched";
@@ -761,6 +763,7 @@ public:
     [[nodiscard]] std::optional<MillisecondsType> getAllottedTimeMs() const;
     [[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID() const;
     [[nodiscard]] std::optional<std::vector<std::string>> getAdditionalOutputNames() const;
+    [[nodiscard]] std::optional<IdType> getDisaggRequestId() const;

     void setStreaming(bool streaming);
     void setSamplingConfig(SamplingConfig const& config);
@@ -796,6 +799,7 @@ public:
     void setLanguageAdapterUid(SizeType32 languageAdapterUid);
     void setAllottedTimeMs(MillisecondsType allottedTimeMs);
     void setCacheSaltID(CacheSaltIDType cacheSaltID);
+    void setDisaggRequestId(IdType disaggRequestId);

 private:
     friend class Serialization;
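Note: the new field is surfaced to Python through the bindings later in this diff. A minimal usage sketch, assuming a build that includes this change and the usual `tensorrt_llm.bindings.executor` module path (argument names as defined in the bindings below):

```python
# Hypothetical sketch; exercises only the fields added in this diff.
from tensorrt_llm.bindings import executor as tle

# disagg_request_id is optional; when omitted the executor assigns a local id.
req = tle.Request(input_token_ids=[1, 2, 3], max_tokens=16,
                  disagg_request_id=(1 << 42) + 7)
assert req.disagg_request_id == (1 << 42) + 7

# The setter added in executor.h is exposed as a read-write property.
req2 = tle.Request(input_token_ids=[4, 5, 6], max_tokens=16)
req2.disagg_request_id = (1 << 42) + 8
```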
diff --git a/cpp/tensorrt_llm/executor/request.cpp b/cpp/tensorrt_llm/executor/request.cpp
index 987eeef894..5ac62d3fcb 100644
--- a/cpp/tensorrt_llm/executor/request.cpp
+++ b/cpp/tensorrt_llm/executor/request.cpp
@@ -40,7 +40,8 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming,
     std::optional<SizeType32> encoderOutputLength, std::optional<Tensor> crossAttentionMask,
     SizeType32 numReturnSequences, std::optional<EagleConfig> eagleConfig, std::optional<Tensor> skipCrossAttnBlocks,
     std::optional<GuidedDecodingParams> guidedDecodingParams, std::optional<SizeType32> languageAdapterUid,
-    std::optional<MillisecondsType> allottedTimeMs, std::optional<CacheSaltIDType> cacheSaltID)
+    std::optional<MillisecondsType> allottedTimeMs, std::optional<CacheSaltIDType> cacheSaltID,
+    std::optional<IdType> disaggRequestId)
     : mImpl(std::make_unique<Impl>(std::move(inputTokenIds), maxTokens, streaming, samplingConfig, outputConfig,
         endId, padId, std::move(positionIds), std::move(badWords), std::move(stopWords), std::move(embeddingBias),
         std::move(externalDraftTokensConfig), std::move(pTuningConfig), std::move(multimodalInput),
@@ -49,7 +50,7 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming,
         std::move(encoderInputTokenIds), clientId, returnAllGeneratedTokens, priority, type,
         std::move(contextPhaseParams), std::move(encoderInputFeatures), encoderOutputLength, crossAttentionMask,
         numReturnSequences, eagleConfig, skipCrossAttnBlocks, std::move(guidedDecodingParams), languageAdapterUid,
-        allottedTimeMs, cacheSaltID))
+        allottedTimeMs, cacheSaltID, disaggRequestId))
 {
 }

@@ -253,6 +254,11 @@ std::optional<CacheSaltIDType> Request::getCacheSaltID() const
     return mImpl->getCacheSaltID();
 }

+std::optional<IdType> Request::getDisaggRequestId() const
+{
+    return mImpl->getDisaggRequestId();
+}
+
 void Request::setStreaming(bool streaming)
 {
     mImpl->setStreaming(streaming);
@@ -310,12 +316,12 @@ void Request::setPromptTuningConfig(PromptTuningConfig const& pTuningConfig)

 void Request::setMultimodalEmbedding(Tensor const& multimodalEmbedding)
 {
-    return mImpl->setMultimodalEmbedding(multimodalEmbedding);
+    mImpl->setMultimodalEmbedding(multimodalEmbedding);
 }

 void Request::setMultimodalInput(MultimodalInput const& multimodalInput)
 {
-    return mImpl->setMultimodalInput(multimodalInput);
+    mImpl->setMultimodalInput(multimodalInput);
 }

 void Request::setMropeConfig(MropeConfig const& mRopeConfig)
@@ -400,7 +406,7 @@ void Request::setEagleConfig(std::optional<EagleConfig> const& eagleConfig)

 void Request::setSkipCrossAttnBlocks(Tensor skipCrossAttnBlocks)
 {
-    return mImpl->setSkipCrossAttnBlocks(skipCrossAttnBlocks);
+    mImpl->setSkipCrossAttnBlocks(skipCrossAttnBlocks);
 }

 void Request::setGuidedDecodingParams(GuidedDecodingParams const& guidedDecodingParams)
@@ -410,16 +416,21 @@ void Request::setGuidedDecodingParams(GuidedDecodingParams const& guidedDecoding

 void Request::setAllottedTimeMs(MillisecondsType allottedTimeMs)
 {
-    return mImpl->setAllottedTimeMs(allottedTimeMs);
+    mImpl->setAllottedTimeMs(allottedTimeMs);
 }

 void Request::setLanguageAdapterUid(SizeType32 languageAdapterUid)
 {
-    return mImpl->setLanguageAdapterUid(languageAdapterUid);
+    mImpl->setLanguageAdapterUid(languageAdapterUid);
 }

 void Request::setCacheSaltID(CacheSaltIDType cacheSaltID)
 {
-    return mImpl->setCacheSaltID(cacheSaltID);
+    mImpl->setCacheSaltID(cacheSaltID);
+}
+
+void Request::setDisaggRequestId(IdType disaggRequestId)
+{
+    mImpl->setDisaggRequestId(disaggRequestId);
 }
 } // namespace tensorrt_llm::executor
diff --git a/cpp/tensorrt_llm/executor/requestImpl.h b/cpp/tensorrt_llm/executor/requestImpl.h
index 94de53a781..281f81d462 100644
--- a/cpp/tensorrt_llm/executor/requestImpl.h
+++ b/cpp/tensorrt_llm/executor/requestImpl.h
@@ -48,7 +48,7 @@ public:
         std::optional<Tensor> crossAttentionMask, SizeType32 numReturnSequences, std::optional<EagleConfig> eagleConfig,
         std::optional<Tensor> skipCrossAttnBlocks, std::optional<GuidedDecodingParams> guidedDecodingParams,
         std::optional<SizeType32> languageAdapterUid, std::optional<MillisecondsType> allottedTimeMs,
-        std::optional<CacheSaltIDType> cacheSaltID)
+        std::optional<CacheSaltIDType> cacheSaltID, std::optional<IdType> disaggRequestId)
         : mInputTokenIds(std::move(inputTokenIds))
         , mMaxNewTokens(maxNewTokens)
         , mStreaming(streaming)
@@ -86,6 +86,7 @@ public:
         , mLanguageAdapterUid(languageAdapterUid)
         , mAllottedTimeMs(allottedTimeMs)
         , mCacheSaltID(cacheSaltID)
+        , mDisaggRequestId(disaggRequestId)
     {
         validate();
     }
@@ -302,6 +303,11 @@ public:
         return mCacheSaltID;
     }

+    [[nodiscard]] std::optional<IdType> getDisaggRequestId() const
+    {
+        return mDisaggRequestId;
+    }
+
     void setStreaming(bool streaming)
     {
         mStreaming = streaming;
@@ -481,6 +487,11 @@ public:
         mCacheSaltID = cacheSaltID;
     }

+    void setDisaggRequestId(IdType disaggRequestId)
+    {
+        mDisaggRequestId = disaggRequestId;
+    }
+
 private:
     void validate()
     {
@@ -555,6 +566,7 @@ private:
         lambda(mLanguageAdapterUid);
         lambda(mAllottedTimeMs ? std::make_optional(mAllottedTimeMs->count()) : std::nullopt);
         lambda(mCacheSaltID);
+        lambda(mDisaggRequestId);
     }

     VecTokens mInputTokenIds;
@@ -594,6 +606,7 @@ private:
     std::optional<SizeType32> mLanguageAdapterUid;
     std::optional<MillisecondsType> mAllottedTimeMs;
     std::optional<CacheSaltIDType> mCacheSaltID;
+    std::optional<IdType> mDisaggRequestId;
 };
 } // namespace tensorrt_llm::executor
diff --git a/cpp/tensorrt_llm/nanobind/executor/request.cpp b/cpp/tensorrt_llm/nanobind/executor/request.cpp
index db05409d86..4a53516bf8 100644
--- a/cpp/tensorrt_llm/nanobind/executor/request.cpp
+++ b/cpp/tensorrt_llm/nanobind/executor/request.cpp
@@ -578,11 +578,11 @@ void initRequestBindings(nb::module_& m)
             self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(),
             self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(),
             self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(),
-            self.getGuidedDecodingParams(), self.getCacheSaltID());
+            self.getGuidedDecodingParams(), self.getCacheSaltID(), self.getDisaggRequestId());
     };

     auto requestSetstate = [](tle::Request& self, nb::tuple const& state)
     {
-        if (state.size() != 34)
+        if (state.size() != 35)
         {
             throw std::runtime_error("Invalid Request state!");
         }
@@ -606,8 +606,8 @@ void initRequestBindings(nb::module_& m)
             nb::cast<std::optional<tle::Tensor>>(state[27]), nb::cast<std::optional<SizeType32>>(state[28]),
             nb::cast<std::optional<tle::Tensor>>(state[29]), 1, nb::cast<std::optional<tle::EagleConfig>>(state[30]),
             nb::cast<std::optional<tle::Tensor>>(state[31]),
-            nb::cast<std::optional<tle::GuidedDecodingParams>>(state[32]),
-            nb::cast<std::optional<tle::CacheSaltIDType>>(state[33]));
+            nb::cast<std::optional<tle::GuidedDecodingParams>>(state[32]), std::nullopt, std::nullopt,
+            nb::cast<std::optional<tle::CacheSaltIDType>>(state[33]), nb::cast<std::optional<tle::IdType>>(state[34]));
     };

     nb::class_<tle::Request> request(m, "Request", nb::dynamic_attr());
@@ -648,7 +648,8 @@ void initRequestBindings(nb::module_& m)
             std::optional<tle::GuidedDecodingParams>, // guidedDecodingParams
             std::optional<SizeType32>,                // languageAdapterUid
             std::optional<tle::MillisecondsType>,     // allottedTimeMs
-            std::optional<tle::CacheSaltIDType>       // cacheSaltID
+            std::optional<tle::CacheSaltIDType>,      // cacheSaltID
+            std::optional<tle::IdType>                // disaggRequestId
             >(),
             // clang-format off
             nb::arg("input_token_ids"),
             nb::arg("guided_decoding_params") = nb::none(),
             nb::arg("language_adapter_uid") = nb::none(),
             nb::arg("allotted_time_ms") = nb::none(),
-            nb::arg("cache_salt_id") = nb::none()
-        ) // clang-format on
+            nb::arg("cache_salt_id") = nb::none(),
+            nb::arg("disagg_request_id") = nb::none()
+        ) // clang-format on
         .def_prop_ro("input_token_ids", &tle::Request::getInputTokenIds)
         .def_prop_ro("max_tokens", &tle::Request::getMaxTokens)
         .def_prop_rw("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming)
@@ -733,6 +735,7 @@ void initRequestBindings(nb::module_& m)
         .def_prop_rw("allotted_time_ms", &tle::Request::getAllottedTimeMs, &tle::Request::setAllottedTimeMs)
         .def_prop_rw("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID)
         .def_prop_rw("context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams)
+        .def_prop_rw("disagg_request_id", &tle::Request::getDisaggRequestId, &tle::Request::setDisaggRequestId)
         .def("__getstate__", requestGetstate)
         .def("__setstate__", requestSetstate);
     request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName;
@@ -929,8 +932,8 @@ void initRequestBindings(nb::module_& m)
         {
             throw std::runtime_error("Invalid Request state!");
         }
-        new (&response) tle::Response(
-            nb::cast<tle::IdType>(state[0]), nb::cast<tle::Result>(state[1]), nb::cast<std::optional<tle::IdType>>(state[2]));
+        new (&response)
+            tle::Response(nb::cast<tle::IdType>(state[0]), nb::cast<tle::Result>(state[1]), nb::cast<std::optional<tle::IdType>>(state[2]));
     };

     nb::class_<tle::Response>(m, "Response")
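Note: both binding layers bump the pickled state from 34 to 35 entries so `disagg_request_id` survives serialization; `language_adapter_uid` and `allotted_time_ms` are restored as `std::nullopt` because they are not part of the pickled tuple. A sketch, assuming a build with this change:

```python
import pickle

from tensorrt_llm.bindings import executor as tle

req = tle.Request(input_token_ids=[1, 2, 3], max_tokens=8,
                  disagg_request_id=(1 << 42) + 7)
restored = pickle.loads(pickle.dumps(req))  # round-trips the 35-entry state
assert restored.disagg_request_id == (1 << 42) + 7
```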
diff --git a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h
index b43fb8dd07..19b214b001 100644
--- a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h
+++ b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h
@@ -85,8 +85,7 @@ public:
         std::optional<SizeType32> languageAdapterUid = std::nullopt,
         std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
         std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
-        std::optional<CacheSaltIDType> cacheSaltID = std::nullopt,
-        std::optional<double> arrivalTime = std::nullopt)
+        std::optional<CacheSaltIDType> cacheSaltID = std::nullopt, std::optional<double> arrivalTime = std::nullopt)
         : Base(requestId, //
             maxNewTokens, //
             std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)), //
diff --git a/cpp/tensorrt_llm/pybind/executor/request.cpp b/cpp/tensorrt_llm/pybind/executor/request.cpp
index 2e9dae860e..78bb650fbc 100644
--- a/cpp/tensorrt_llm/pybind/executor/request.cpp
+++ b/cpp/tensorrt_llm/pybind/executor/request.cpp
@@ -531,11 +531,11 @@ void initRequestBindings(pybind11::module_& m)
             self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(),
             self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(),
             self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(),
-            self.getGuidedDecodingParams(), self.getCacheSaltID());
+            self.getGuidedDecodingParams(), self.getCacheSaltID(), self.getDisaggRequestId());
     };

     auto requestSetstate = [](py::tuple const& state)
     {
-        if (state.size() != 34)
+        if (state.size() != 35)
         {
             throw std::runtime_error("Invalid Request state!");
         }
@@ -556,7 +556,8 @@ void initRequestBindings(pybind11::module_& m)
             state[27].cast<std::optional<tle::Tensor>>(), state[28].cast<std::optional<SizeType32>>(),
             state[29].cast<std::optional<tle::Tensor>>(), 1, state[30].cast<std::optional<tle::EagleConfig>>(),
             state[31].cast<std::optional<tle::Tensor>>(), state[32].cast<std::optional<tle::GuidedDecodingParams>>(),
-            state[33].cast<std::optional<tle::CacheSaltIDType>>());
+            std::nullopt, std::nullopt, state[33].cast<std::optional<tle::CacheSaltIDType>>(),
+            state[34].cast<std::optional<tle::IdType>>());
     };

     py::class_<tle::Request> request(m, "Request", pybind11::dynamic_attr());
@@ -597,7 +598,8 @@ void initRequestBindings(pybind11::module_& m)
             std::optional<tle::GuidedDecodingParams>, // guidedDecodingParams
             std::optional<SizeType32>,                // languageAdapterUid
             std::optional<tle::MillisecondsType>,     // allottedTimeMs
-            std::optional<tle::CacheSaltIDType>       // cacheSaltID
+            std::optional<tle::CacheSaltIDType>,      // cacheSaltID
+            std::optional<tle::IdType>                // disaggRequestId
             >(),
             // clang-format off
             py::arg("input_token_ids"),
             py::arg("guided_decoding_params") = py::none(),
             py::arg("language_adapter_uid") = py::none(),
             py::arg("allotted_time_ms") = py::none(),
-            py::arg("cache_salt_id") = py::none()
-        ) // clang-format on
+            py::arg("cache_salt_id") = py::none(),
+            py::arg("disagg_request_id") = py::none()
+        ) // clang-format on
         .def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds)
         .def_property_readonly("max_tokens", &tle::Request::getMaxTokens)
         .def_property("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming)
@@ -686,6 +689,7 @@ void initRequestBindings(pybind11::module_& m)
         .def_property("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID)
         .def_property(
             "context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams)
+        .def_property("disagg_request_id", &tle::Request::getDisaggRequestId, &tle::Request::setDisaggRequestId)
         .def(py::pickle(requestGetstate, requestSetstate));
     request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName;
@@ -870,7 +874,7 @@ void initRequestBindings(pybind11::module_& m)
         {
             throw std::runtime_error("Invalid Request state!");
         }
         return std::make_unique<tle::Response>(
-            state[0].cast<tle::IdType>(), state[1].cast<tle::Result>(),
-            state[2].cast<std::optional<tle::IdType>>());
+            state[0].cast<tle::IdType>(), state[1].cast<tle::Result>(), state[2].cast<std::optional<tle::IdType>>());
     };

     py::class_<tle::Response>(m, "Response")
diff --git a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
index cb42186520..9ea2492840 100644
--- a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
+++ b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
@@ -11,6 +11,7 @@ from typing import Dict, Iterable, List, Optional, Tuple
 import torch

 from tensorrt_llm._utils import mpi_disabled, nvtx_range
+from tensorrt_llm.llmapi.disagg_utils import get_local_request_id
 from tensorrt_llm.mapping import CpType

 from ..distributed import Distributed
@@ -206,10 +207,15 @@ class ExecutorRequestQueue:

         return False

-    def _get_request_id(self):
-        # (next_request_id + 1) % UINT64_MAX
+    def _get_request_id(self, request: Optional[ExecutorRequest] = None):
+        # If the request carries a disagg_request_id, reuse it as the request id
+        # so the corresponding context and generation requests share the same id.
+        if request and request.disagg_request_id and isinstance(
+                request.disagg_request_id, int):
+            return request.disagg_request_id
+
         current_id = self.next_request_id
-        self.next_request_id = (self.next_request_id + 1) & ((1 << 64) - 1)
+        self.next_request_id = get_local_request_id(current_id)
         return current_id

     def _generate_child_request_ids(
@@ -236,7 +242,7 @@ class ExecutorRequestQueue:
         assert self.active, "PyExecutor has already been shutdown."
         start_time = time.time()
         for request, query in requests_and_queries:
-            req_id = self._get_request_id()
+            req_id = self._get_request_id(request)
             if self.enable_iter_perf_stats:
                 self.start_times[req_id] = start_time
             child_req_ids = self._generate_child_request_ids(request)
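Note: the id-assignment rule above, restated as a self-contained sketch (`_FakeRequest` and `assign_id` are hypothetical stand-ins for `ExecutorRequest` and the queue; `get_local_request_id` is the real helper added in this diff):

```python
from typing import Optional

from tensorrt_llm.llmapi.disagg_utils import get_local_request_id


class _FakeRequest:
    """Hypothetical stand-in modeling only the field the queue inspects."""

    def __init__(self, disagg_request_id: Optional[int] = None):
        self.disagg_request_id = disagg_request_id


next_request_id = 0


def assign_id(request: Optional[_FakeRequest] = None) -> int:
    global next_request_id
    # A global disagg id wins, so the context and generation halves of one
    # disaggregated request map to the same underlying request id.
    if request and request.disagg_request_id and isinstance(
            request.disagg_request_id, int):
        return request.disagg_request_id
    current_id = next_request_id
    next_request_id = get_local_request_id(current_id)  # wraps below 1 << 42
    return current_id


assert assign_id(_FakeRequest((1 << 42) + 7)) == (1 << 42) + 7  # global id wins
assert assign_id(_FakeRequest()) == 0  # falls back to the local counter
```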
diff --git a/tensorrt_llm/disaggregated_params.py b/tensorrt_llm/disaggregated_params.py
index 4c0680bc94..b98780bac7 100644
--- a/tensorrt_llm/disaggregated_params.py
+++ b/tensorrt_llm/disaggregated_params.py
@@ -21,6 +21,8 @@ class DisaggregatedParams:
         ctx_request_id (int): The context request id
         opaque_state(bytes): Any additional state needing to be exchanged between context and gen instances
         draft_tokens (List[int]): The draft tokens of the generation request
+        disagg_request_id (int): The disaggregated request id. If set, both the context and generation
+            requests use it as the underlying request id.
         multimodal_embedding_handles (List[Dict[str, Any]]): The resulting multimodal embedding handles from ViT
         multimodal_hashes (List[List[int]]): The multimodal hashes of each multimodal item in the request

@@ -32,6 +34,8 @@ class DisaggregatedParams:
     ctx_request_id: Optional[int] = None
     opaque_state: Optional[bytes] = None
     draft_tokens: Optional[List[int]] = None
+    # If disagg_request_id is set, both the context and generation requests use it as the underlying request id.
+    disagg_request_id: Optional[int] = None

     # E-P Disaggregated Params
     multimodal_embedding_handles: Optional[List[Dict[str, Any]]] = (
@@ -44,8 +48,12 @@ class DisaggregatedParams:
     mrope_position_deltas_handle: Optional[Dict[str, Any]] = None

     def get_context_phase_params(self) -> tllme.ContextPhaseParams:
+        # Prefer disagg_request_id over ctx_request_id
+        request_id = (
+            self.disagg_request_id if self.disagg_request_id is not None else self.ctx_request_id
+        )
         return tllme.ContextPhaseParams(
-            self.first_gen_tokens, self.ctx_request_id, self.opaque_state, self.draft_tokens
+            self.first_gen_tokens, request_id, self.opaque_state, self.draft_tokens
         )

     def get_request_type(self) -> tllme.RequestType:
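Note: `get_context_phase_params` now prefers the global id over the context-server id. A small sketch of the selection rule (field values are illustrative):

```python
from tensorrt_llm.disaggregated_params import DisaggregatedParams

params = DisaggregatedParams(request_type="generation_only",
                             first_gen_tokens=[42],
                             ctx_request_id=11,
                             disagg_request_id=(1 << 42) + 7)
# The id stamped into tllme.ContextPhaseParams resolves to disagg_request_id
# when present, and falls back to ctx_request_id otherwise.
request_id = (params.disagg_request_id
              if params.disagg_request_id is not None else params.ctx_request_id)
assert request_id == (1 << 42) + 7
```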
""" - request.set_id(self._get_next_client_id()) + if request.id is None: + request.set_id(self._get_next_client_id()) logprob_params = self._get_logprob_params(request) with nvtx_range_debug("rpc_submit"): diff --git a/tensorrt_llm/llmapi/disagg_utils.py b/tensorrt_llm/llmapi/disagg_utils.py index 5a42c731f3..d7f8fffe0d 100644 --- a/tensorrt_llm/llmapi/disagg_utils.py +++ b/tensorrt_llm/llmapi/disagg_utils.py @@ -1,4 +1,7 @@ import logging +import threading +import time +import uuid from dataclasses import dataclass, field from enum import IntEnum from typing import Any, Dict, List, Literal, Optional, Tuple @@ -76,6 +79,9 @@ class DisaggServerConfig(): max_retries: int = 1 perf_metrics_max_requests: int = 0 disagg_cluster_config: Optional[DisaggClusterConfig] = None + node_id: int = uuid.getnode( + ) % 1021 # Assuming only one disagg-server is running on a machine, moding mac by the largest 10-bit prime + # If this causes collisions, users can set node_id manually within range [0, 1023] in config @dataclass @@ -331,3 +337,48 @@ def parse_metadata_server_config_file( with open(metadata_server_config_file, 'r') as file: config = yaml.safe_load(file) return MetadataServerConfig(**config) + + +MIN_GLOBAL_ID = 1 << 42 + +# Consider GIL being removed in the future, use a lock to protect the counter +_global_disagg_request_id_lock = threading.Lock() +_global_disagg_request_id_counter = 0 + + +def get_global_disagg_request_id(machine_id: int) -> int: + """ + a snowflake global disagg request id that doesn't guarantee monotonicity + 0: positive integer + 1-41 41 bits: timestamp_ms + 42-51 10 bits: machine_id + 52-63 12 bits: counter + """ + global _global_disagg_request_id_lock + global _global_disagg_request_id_counter + + COUNTER_BITS = 12 + MACHINE_ID_BITS = 10 + COUNTER_MASK = (1 << COUNTER_BITS) - 1 + MAX_INT64 = (1 << 63) - 1 + + if machine_id not in range(0, (1 << MACHINE_ID_BITS) - 1): + raise ValueError( + f"machine_id must be in range [0, {(1 << MACHINE_ID_BITS) - 1})") + + timestamp_ms = int(time.monotonic() * 1000) + with _global_disagg_request_id_lock: + counter = _global_disagg_request_id_counter & COUNTER_MASK + _global_disagg_request_id_counter += 1 + + # Rotate in [MIN_GLOBAL_ID, MAX_INT64) + # [0, MIN_GLOBAL_ID) is reserved for local ids + global_id = (timestamp_ms << (MACHINE_ID_BITS + COUNTER_BITS)) | ( + machine_id << COUNTER_BITS) | counter + global_id_int64 = global_id % (MAX_INT64 - MIN_GLOBAL_ID) + MIN_GLOBAL_ID + return global_id_int64 + + +def get_local_request_id(last_id: int) -> int: + """ increment the last_id by 1 and mod by MIN_GLOBAL_ID """ + return (last_id + 1) & (MIN_GLOBAL_ID - 1) diff --git a/tensorrt_llm/serve/openai_disagg_service.py b/tensorrt_llm/serve/openai_disagg_service.py index a0012bd6d3..c1fb8f2af5 100644 --- a/tensorrt_llm/serve/openai_disagg_service.py +++ b/tensorrt_llm/serve/openai_disagg_service.py @@ -23,6 +23,7 @@ from tensorrt_llm.llmapi.disagg_utils import ( DisaggServerConfig, MetadataServerConfig, ServerRole, + get_global_disagg_request_id, ) from tensorrt_llm.logger import logger from tensorrt_llm.serve.cluster_storage import ClusterStorage, WatchEventType @@ -115,14 +116,15 @@ class OpenAIDisaggregatedService(OpenAIService): need_ctx = need_ctx and not await self._check_gen_only_disagg(request) ctx_response = None gen_req = request + disagg_request_id = get_global_disagg_request_id(self._config.node_id) if need_ctx: - ctx_req = self._get_ctx_request(request) + ctx_req = self._get_ctx_request(request, disagg_request_id) # ctx generator 
diff --git a/tensorrt_llm/serve/openai_disagg_service.py b/tensorrt_llm/serve/openai_disagg_service.py
index a0012bd6d3..c1fb8f2af5 100644
--- a/tensorrt_llm/serve/openai_disagg_service.py
+++ b/tensorrt_llm/serve/openai_disagg_service.py
@@ -23,6 +23,7 @@ from tensorrt_llm.llmapi.disagg_utils import (
     DisaggServerConfig,
     MetadataServerConfig,
     ServerRole,
+    get_global_disagg_request_id,
 )
 from tensorrt_llm.logger import logger
 from tensorrt_llm.serve.cluster_storage import ClusterStorage, WatchEventType
@@ -115,14 +116,15 @@ class OpenAIDisaggregatedService(OpenAIService):
         need_ctx = need_ctx and not await self._check_gen_only_disagg(request)
         ctx_response = None
         gen_req = request
+        disagg_request_id = get_global_disagg_request_id(self._config.node_id)
         if need_ctx:
-            ctx_req = self._get_ctx_request(request)
+            ctx_req = self._get_ctx_request(request, disagg_request_id)
             # ctx generator is empty
             ctx_response = await self._ctx_client.send_request(
                 ctx_req, server=reserved_ctx_server, hooks=hooks
             )
             await self._verify_ctx_response(ctx_response)
-            gen_req = self._get_gen_request(request, ctx_response)
+            gen_req = self._get_gen_request(request, ctx_response, disagg_request_id)

         if ctx_response is None or self._need_gen(ctx_response):
             return await self._gen_client.send_request(
                 gen_req, server=reserved_gen_server, hooks=hooks
@@ -140,9 +142,13 @@ class OpenAIDisaggregatedService(OpenAIService):
             return False
         return True

-    def _get_ctx_request(self, request: UCompletionRequest) -> UCompletionRequest:
+    def _get_ctx_request(
+        self, request: UCompletionRequest, disagg_request_id: Optional[int]
+    ) -> UCompletionRequest:
         ctx_request = copy.deepcopy(request)
-        ctx_request.disaggregated_params = DisaggregatedParams(request_type="context_only")
+        ctx_request.disaggregated_params = DisaggregatedParams(
+            request_type="context_only", disagg_request_id=disagg_request_id
+        )
         ctx_request.stream = False
         ctx_request.stream_options = None
         return ctx_request
@@ -151,6 +157,7 @@ class OpenAIDisaggregatedService(OpenAIService):
         self,
         request: UCompletionRequest,
         ctx_response: UCompletionResponse,
+        disagg_request_id: Optional[int],
     ) -> UCompletionRequest:
         request.disaggregated_params = ctx_response.choices[0].disaggregated_params
         request.disaggregated_params.request_type = "generation_only"
diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py
index 8ddda27cd7..ea27bb053d 100644
--- a/tensorrt_llm/serve/openai_protocol.py
+++ b/tensorrt_llm/serve/openai_protocol.py
@@ -117,6 +117,7 @@ class DisaggregatedParams(OpenAIBaseModel):
     ctx_request_id: Optional[int] = None
     encoded_opaque_state: Optional[str] = None
     draft_tokens: Optional[List[int]] = None
+    disagg_request_id: Optional[int] = None


 class ErrorResponse(OpenAIBaseModel):
@@ -1000,7 +1001,8 @@ def to_disaggregated_params(
         ctx_request_id=tllm_disagg_params.ctx_request_id,
         encoded_opaque_state=encode_opaque_state(
             tllm_disagg_params.opaque_state),
-        draft_tokens=tllm_disagg_params.draft_tokens)
+        draft_tokens=tllm_disagg_params.draft_tokens,
+        disagg_request_id=tllm_disagg_params.disagg_request_id)


 def to_llm_disaggregated_params(
@@ -1013,7 +1015,8 @@ def to_llm_disaggregated_params(
         ctx_request_id=disaggregated_params.ctx_request_id,
         opaque_state=decode_opaque_state(
             disaggregated_params.encoded_opaque_state),
-        draft_tokens=disaggregated_params.draft_tokens)
+        draft_tokens=disaggregated_params.draft_tokens,
+        disagg_request_id=disaggregated_params.disagg_request_id)


 UCompletionRequest = Union[CompletionRequest, ChatCompletionRequest]
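Note: the wire-level field added to the OpenAI protocol is copied verbatim by both converters, so the id survives the HTTP hop between the disagg server and the workers. A sketch (assuming `decode_opaque_state` tolerates an absent opaque state):

```python
from tensorrt_llm.serve.openai_protocol import (DisaggregatedParams,
                                                to_llm_disaggregated_params)

wire = DisaggregatedParams(request_type="context_only",
                           disagg_request_id=(1 << 42) + 7)
llm_params = to_llm_disaggregated_params(wire)
assert llm_params.disagg_request_id == (1 << 42) + 7
```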
diff --git a/tests/unittest/disaggregated/test_disagg_utils.py b/tests/unittest/disaggregated/test_disagg_utils.py
index 0180d7e6d6..8d4be0f213 100644
--- a/tests/unittest/disaggregated/test_disagg_utils.py
+++ b/tests/unittest/disaggregated/test_disagg_utils.py
@@ -1,11 +1,13 @@
+from concurrent.futures import ThreadPoolExecutor
+
 import pytest
 import yaml

 # isort: off
 from tensorrt_llm.llmapi.disagg_utils import (
-    CtxGenServerConfig, DisaggServerConfig, extract_ctx_gen_cfgs,
-    extract_router_config, extract_disagg_cfg, get_server_configs_dict,
-    parse_disagg_config_file)
+    MIN_GLOBAL_ID, CtxGenServerConfig, DisaggServerConfig, extract_ctx_gen_cfgs,
+    extract_router_config, extract_disagg_cfg, get_global_disagg_request_id,
+    get_local_request_id, get_server_configs_dict, parse_disagg_config_file)
 # isort: on

@@ -155,3 +157,45 @@ def test_get_server_configs_dict():
     assert len(server_dict) == 2
     assert ("host1", 8001) in server_dict
     assert ("host2", 8002) in server_dict
+
+
+@pytest.mark.parametrize("multithread", [True, False],
+                         ids=["multithread", "singlethread"])
+def test_get_global_disagg_request_id(multithread):
+    num_iters = 10000
+
+    def get_ids(node_ids):
+        all_node_ids = [[] for _ in range(len(node_ids))]
+        for _ in range(num_iters):
+            for i, node_id in enumerate(node_ids):
+                all_node_ids[i].append(get_global_disagg_request_id(node_id))
+        return all_node_ids
+
+    node_ids = list(range(10))
+    if multithread:
+        with ThreadPoolExecutor(max_workers=len(node_ids)) as executor:
+            all_node_ids = [
+                ids[0] for ids in executor.map(get_ids, [[i] for i in node_ids])
+            ]
+    else:
+        all_node_ids = get_ids(node_ids)
+
+    all_ids = set(i for ids in all_node_ids for i in ids)
+    assert len(all_ids) == num_iters * len(node_ids)
+    assert all(MIN_GLOBAL_ID <= i < ((1 << 63) - 1) for i in all_ids)
+
+
+def test_get_local_request_id():
+    last_id = MIN_GLOBAL_ID - 100
+    ids = set()
+    for _ in range(1000):
+        last_id = get_local_request_id(last_id)
+        assert last_id >= 0
+        assert last_id < MIN_GLOBAL_ID
+        ids.add(last_id)
+    assert len(ids) == 1000
+    assert min(ids) == 0
+    assert max(ids) == MIN_GLOBAL_ID - 1
+    assert max(ids) - min(ids) > (
+        1 << 40)  # ensure there is enough space for local ids