From 93db0d5e185c24da61e56e42c85287c3a61fa0de Mon Sep 17 00:00:00 2001 From: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> Date: Thu, 15 Jan 2026 19:18:21 +0800 Subject: [PATCH] [TRTLLM-9942][feat] new request states and kvcache transceiver APIs in generation-first disagg (#10406) Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> --- .../tensorrt_llm/batch_manager/llmRequest.h | 11 +++++--- cpp/tensorrt_llm/nanobind/bindings.cpp | 4 ++- cpp/tensorrt_llm/pybind/bindings.cpp | 4 ++- .../_torch/pyexecutor/kv_cache_transceiver.py | 25 +++++++++++++++++++ 4 files changed, 39 insertions(+), 5 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index 5757c57362..d3c967b7eb 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -49,6 +49,8 @@ enum class LlmRequestState : int32_t kUNKNOWN = 0, ///< Unknown state kENCODER_INIT = 1, ///< Encoder phase starts (for encoder-decoder models) + kDISAGG_CONTEXT_WAIT_SCHEDULER = 7, ///< Waiting for scheduler to schedule the context-only request + /// e.g. in gen-first mode when generation request is not scheduled yet kDISAGG_GENERATION_INIT = 8, ///< New Generation request arrived at generation model kDISAGG_GENERATION_TRANS_IN_PROGRESS = 9, ///< Transmitting the kv cache @@ -65,6 +67,7 @@ enum class LlmRequestState : int32_t kDISAGG_CONTEXT_TRANS_IN_PROGRESS = 21, ///< Waiting context-only request transmitting the kv cache, /// after computation finished kDISAGG_CONTEXT_COMPLETE = 22, ///< Context-only request finished kv cache transmission. + kDISAGG_GENERATION_WAIT_TOKENS = 23, ///< Generation-only request waiting for ctx/draft tokens to be received // error states kDISAGG_TRANS_ERROR = -1, ///< Error occurred during kv cache transmission @@ -1511,15 +1514,17 @@ public: { switch (mState) { - case batch_manager::LlmRequestState::kENCODER_INIT: return executor::RequestStage::kENCODER_IN_PROGRESS; break; - case batch_manager::LlmRequestState::kCONTEXT_INIT: return executor::RequestStage::kCONTEXT_IN_PROGRESS; break; + case batch_manager::LlmRequestState::kENCODER_INIT: return executor::RequestStage::kENCODER_IN_PROGRESS; + case batch_manager::LlmRequestState::kCONTEXT_INIT: + case batch_manager::LlmRequestState::kDISAGG_CONTEXT_WAIT_SCHEDULER: + return executor::RequestStage::kCONTEXT_IN_PROGRESS; case batch_manager::LlmRequestState::kGENERATION_IN_PROGRESS: case batch_manager::LlmRequestState::kGENERATION_TO_COMPLETE: case batch_manager::LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE: case batch_manager::LlmRequestState::kDISAGG_GENERATION_INIT: case batch_manager::LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS: + case batch_manager::LlmRequestState::kDISAGG_GENERATION_WAIT_TOKENS: return executor::RequestStage::kGENERATION_IN_PROGRESS; - break; default: TLLM_LOG_ERROR("Unexpected request state."); return executor::RequestStage::kGENERATION_COMPLETE; } } diff --git a/cpp/tensorrt_llm/nanobind/bindings.cpp b/cpp/tensorrt_llm/nanobind/bindings.cpp index d9c2687c7f..06ea1d7a2a 100644 --- a/cpp/tensorrt_llm/nanobind/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/bindings.cpp @@ -481,7 +481,9 @@ NB_MODULE(TRTLLM_NB_MODULE, m) .value("DISAGG_GENERATION_TRANS_IN_PROGRESS", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS) .value("DISAGG_GENERATION_TRANS_COMPLETE", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE) .value("DISAGG_CONTEXT_INIT_AND_TRANS", tb::LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS) - .value("DISAGG_TRANS_ERROR", tb::LlmRequestState::kDISAGG_TRANS_ERROR); + .value("DISAGG_TRANS_ERROR", tb::LlmRequestState::kDISAGG_TRANS_ERROR) + .value("DISAGG_CONTEXT_WAIT_SCHEDULER", tb::LlmRequestState::kDISAGG_CONTEXT_WAIT_SCHEDULER) + .value("DISAGG_GENERATION_WAIT_TOKENS", tb::LlmRequestState::kDISAGG_GENERATION_WAIT_TOKENS); nb::class_(m, "MemoryCounters") .def_static("instance", &tr::MemoryCounters::getInstance, nb::rv_policy::reference) diff --git a/cpp/tensorrt_llm/pybind/bindings.cpp b/cpp/tensorrt_llm/pybind/bindings.cpp index 0a04d5ad19..cb4f34b722 100644 --- a/cpp/tensorrt_llm/pybind/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/bindings.cpp @@ -470,7 +470,9 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) .value("DISAGG_GENERATION_TRANS_IN_PROGRESS", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS) .value("DISAGG_TRANS_ERROR", tb::LlmRequestState::kDISAGG_TRANS_ERROR) .value("DISAGG_GENERATION_TRANS_COMPLETE", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE) - .value("DISAGG_CONTEXT_INIT_AND_TRANS", tb::LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS); + .value("DISAGG_CONTEXT_INIT_AND_TRANS", tb::LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS) + .value("DISAGG_CONTEXT_WAIT_SCHEDULER", tb::LlmRequestState::kDISAGG_CONTEXT_WAIT_SCHEDULER) + .value("DISAGG_GENERATION_WAIT_TOKENS", tb::LlmRequestState::kDISAGG_GENERATION_WAIT_TOKENS); py::class_(m, "MemoryCounters") .def_static("instance", &tr::MemoryCounters::getInstance, py::return_value_policy::reference) diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py index 5616be7708..8fe669456e 100644 --- a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py +++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from os import getenv +from typing import List import tensorrt_llm from tensorrt_llm import logger @@ -95,6 +96,24 @@ class KvCacheTransceiver(ABC): def cancel_request(self, req: LlmRequest): raise NotImplementedError + @abstractmethod + def prepare_context_request(self, requests: List[LlmRequest]): + """ + Prepare the context request for the cache transceiver in generation-first mode. + This method should set the context request state to DISAGG_CONTEXT_WAIT_SCHEDULER + so that it won't be scheduled if the responding generation kvcache request is not + yet received otherwise set it to CONTEXT_INIT. + """ + ... + + @abstractmethod + def get_context_state(self): + """ + Return the opaque context request state, which will be attached to the generation request. + The generation server will use this state to get kvcache in generation-first mode. + """ + ... + class BindKvCacheTransceiver(KvCacheTransceiver): @@ -141,6 +160,12 @@ class BindKvCacheTransceiver(KvCacheTransceiver): def cancel_request(self, req: LlmRequest): return self.impl.cancel_request(req) + def prepare_context_request(self, requests: List[LlmRequest]): + raise NotImplementedError + + def get_context_state(self): + raise NotImplementedError + class CacheTransBufferManager: