[TRTLLM-9942][feat] new request states and kvcache transceiver APIs in generation-first disagg (#10406)

Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
This commit is contained in:
Lizhi Zhou 2026-01-15 19:18:21 +08:00 committed by GitHub
parent 3bc17e1aa3
commit 93db0d5e18
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 39 additions and 5 deletions

View File

@ -49,6 +49,8 @@ enum class LlmRequestState : int32_t
kUNKNOWN = 0, ///< Unknown state
kENCODER_INIT = 1, ///< Encoder phase starts (for encoder-decoder models)
kDISAGG_CONTEXT_WAIT_SCHEDULER = 7, ///< Waiting for scheduler to schedule the context-only request
/// e.g. in gen-first mode when generation request is not scheduled yet
kDISAGG_GENERATION_INIT = 8, ///< New Generation request arrived at generation model
kDISAGG_GENERATION_TRANS_IN_PROGRESS = 9, ///< Transmitting the kv cache
@ -65,6 +67,7 @@ enum class LlmRequestState : int32_t
kDISAGG_CONTEXT_TRANS_IN_PROGRESS = 21, ///< Waiting context-only request transmitting the kv cache,
/// after computation finished
kDISAGG_CONTEXT_COMPLETE = 22, ///< Context-only request finished kv cache transmission.
kDISAGG_GENERATION_WAIT_TOKENS = 23, ///< Generation-only request waiting for ctx/draft tokens to be received
// error states
kDISAGG_TRANS_ERROR = -1, ///< Error occurred during kv cache transmission
@ -1511,15 +1514,17 @@ public:
{
switch (mState)
{
case batch_manager::LlmRequestState::kENCODER_INIT: return executor::RequestStage::kENCODER_IN_PROGRESS; break;
case batch_manager::LlmRequestState::kCONTEXT_INIT: return executor::RequestStage::kCONTEXT_IN_PROGRESS; break;
case batch_manager::LlmRequestState::kENCODER_INIT: return executor::RequestStage::kENCODER_IN_PROGRESS;
case batch_manager::LlmRequestState::kCONTEXT_INIT:
case batch_manager::LlmRequestState::kDISAGG_CONTEXT_WAIT_SCHEDULER:
return executor::RequestStage::kCONTEXT_IN_PROGRESS;
case batch_manager::LlmRequestState::kGENERATION_IN_PROGRESS:
case batch_manager::LlmRequestState::kGENERATION_TO_COMPLETE:
case batch_manager::LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE:
case batch_manager::LlmRequestState::kDISAGG_GENERATION_INIT:
case batch_manager::LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS:
case batch_manager::LlmRequestState::kDISAGG_GENERATION_WAIT_TOKENS:
return executor::RequestStage::kGENERATION_IN_PROGRESS;
break;
default: TLLM_LOG_ERROR("Unexpected request state."); return executor::RequestStage::kGENERATION_COMPLETE;
}
}

View File

@ -481,7 +481,9 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
.value("DISAGG_GENERATION_TRANS_IN_PROGRESS", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS)
.value("DISAGG_GENERATION_TRANS_COMPLETE", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE)
.value("DISAGG_CONTEXT_INIT_AND_TRANS", tb::LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS)
.value("DISAGG_TRANS_ERROR", tb::LlmRequestState::kDISAGG_TRANS_ERROR);
.value("DISAGG_TRANS_ERROR", tb::LlmRequestState::kDISAGG_TRANS_ERROR)
.value("DISAGG_CONTEXT_WAIT_SCHEDULER", tb::LlmRequestState::kDISAGG_CONTEXT_WAIT_SCHEDULER)
.value("DISAGG_GENERATION_WAIT_TOKENS", tb::LlmRequestState::kDISAGG_GENERATION_WAIT_TOKENS);
nb::class_<tr::MemoryCounters>(m, "MemoryCounters")
.def_static("instance", &tr::MemoryCounters::getInstance, nb::rv_policy::reference)

View File

@ -470,7 +470,9 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.value("DISAGG_GENERATION_TRANS_IN_PROGRESS", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS)
.value("DISAGG_TRANS_ERROR", tb::LlmRequestState::kDISAGG_TRANS_ERROR)
.value("DISAGG_GENERATION_TRANS_COMPLETE", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE)
.value("DISAGG_CONTEXT_INIT_AND_TRANS", tb::LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS);
.value("DISAGG_CONTEXT_INIT_AND_TRANS", tb::LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS)
.value("DISAGG_CONTEXT_WAIT_SCHEDULER", tb::LlmRequestState::kDISAGG_CONTEXT_WAIT_SCHEDULER)
.value("DISAGG_GENERATION_WAIT_TOKENS", tb::LlmRequestState::kDISAGG_GENERATION_WAIT_TOKENS);
py::class_<tr::MemoryCounters>(m, "MemoryCounters")
.def_static("instance", &tr::MemoryCounters::getInstance, py::return_value_policy::reference)

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from os import getenv
from typing import List
import tensorrt_llm
from tensorrt_llm import logger
@ -95,6 +96,24 @@ class KvCacheTransceiver(ABC):
def cancel_request(self, req: LlmRequest):
raise NotImplementedError
@abstractmethod
def prepare_context_request(self, requests: List[LlmRequest]):
"""
Prepare context requests for the cache transceiver in generation-first mode.

Implementations should set each context request's state to
DISAGG_CONTEXT_WAIT_SCHEDULER when the corresponding generation kvcache
request has not been received yet, so that the context request is not
scheduled; otherwise they should set the state to CONTEXT_INIT.

`requests` is the list of context requests to prepare.
"""
...
@abstractmethod
def get_context_state(self):
"""
Return the opaque context request state that is attached to the generation request.

In generation-first mode, the generation server uses this state to get the kvcache.
"""
...
class BindKvCacheTransceiver(KvCacheTransceiver):
@ -141,6 +160,12 @@ class BindKvCacheTransceiver(KvCacheTransceiver):
# Forward the cancellation of `req` to the wrapped implementation (self.impl)
# and hand back whatever that call returns.
def cancel_request(self, req: LlmRequest):
return self.impl.cancel_request(req)
# Not implemented for this transceiver: the call unconditionally raises
# NotImplementedError regardless of the requests passed in.
def prepare_context_request(self, requests: List[LlmRequest]):
raise NotImplementedError
# Not implemented for this transceiver: the call unconditionally raises
# NotImplementedError instead of returning a context state.
def get_context_state(self):
raise NotImplementedError
class CacheTransBufferManager: