[TRTLLM-9942][feat] new request states and kvcache transceiver APIs in generation-first disagg (#10406)
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
parent 3bc17e1aa3
commit 93db0d5e18
@@ -49,6 +49,8 @@ enum class LlmRequestState : int32_t
     kUNKNOWN = 0, ///< Unknown state
     kENCODER_INIT = 1, ///< Encoder phase starts (for encoder-decoder models)

+    kDISAGG_CONTEXT_WAIT_SCHEDULER = 7, ///< Waiting for the scheduler to schedule the context-only request,
+    /// e.g. in gen-first mode when the generation request is not scheduled yet
     kDISAGG_GENERATION_INIT = 8, ///< New generation request arrived at the generation model
     kDISAGG_GENERATION_TRANS_IN_PROGRESS = 9, ///< Transmitting the kv cache

@@ -65,6 +67,7 @@ enum class LlmRequestState : int32_t
     kDISAGG_CONTEXT_TRANS_IN_PROGRESS = 21, ///< Context-only request transmitting the kv cache,
     /// after computation finished
     kDISAGG_CONTEXT_COMPLETE = 22, ///< Context-only request finished kv cache transmission
+    kDISAGG_GENERATION_WAIT_TOKENS = 23, ///< Generation-only request waiting for ctx/draft tokens to be received

     // error states
     kDISAGG_TRANS_ERROR = -1, ///< Error occurred during kv cache transmission
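Reviewer note: the two added states bracket the new generation-first handshake. DISAGG_CONTEXT_WAIT_SCHEDULER parks a context-only request until the matching generation-side kvcache request arrives; DISAGG_GENERATION_WAIT_TOKENS holds a generation-only request until ctx/draft tokens land. A minimal illustrative check in Python, assuming the enum is importable from tensorrt_llm.bindings as registered in the binding hunks below (the helper itself is hypothetical, not part of this commit):

from tensorrt_llm.bindings import LlmRequestState

# Hypothetical helper: requests parked in either of the new wait states
# must be skipped by the scheduler until their counterpart arrives.
def is_blocked_on_disagg_handshake(state: LlmRequestState) -> bool:
    return state in (
        LlmRequestState.DISAGG_CONTEXT_WAIT_SCHEDULER,  # ctx waits for gen kvcache request
        LlmRequestState.DISAGG_GENERATION_WAIT_TOKENS,  # gen waits for ctx/draft tokens
    )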
@@ -1511,15 +1514,17 @@ public:
     {
         switch (mState)
         {
-        case batch_manager::LlmRequestState::kENCODER_INIT: return executor::RequestStage::kENCODER_IN_PROGRESS; break;
-        case batch_manager::LlmRequestState::kCONTEXT_INIT: return executor::RequestStage::kCONTEXT_IN_PROGRESS; break;
+        case batch_manager::LlmRequestState::kENCODER_INIT: return executor::RequestStage::kENCODER_IN_PROGRESS;
+        case batch_manager::LlmRequestState::kCONTEXT_INIT:
+        case batch_manager::LlmRequestState::kDISAGG_CONTEXT_WAIT_SCHEDULER:
+            return executor::RequestStage::kCONTEXT_IN_PROGRESS;
         case batch_manager::LlmRequestState::kGENERATION_IN_PROGRESS:
         case batch_manager::LlmRequestState::kGENERATION_TO_COMPLETE:
         case batch_manager::LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE:
         case batch_manager::LlmRequestState::kDISAGG_GENERATION_INIT:
         case batch_manager::LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS:
+        case batch_manager::LlmRequestState::kDISAGG_GENERATION_WAIT_TOKENS:
             return executor::RequestStage::kGENERATION_IN_PROGRESS;
-            break;
         default: TLLM_LOG_ERROR("Unexpected request state."); return executor::RequestStage::kGENERATION_COMPLETE;
         }
     }
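Reviewer note: the switch now reports a stage for every in-flight disagg state instead of falling into the error default. The same mapping, mirrored as an illustrative Python dict (this assumes RequestStage is exposed under tensorrt_llm.bindings.executor; it is not code from the commit):

from tensorrt_llm.bindings import LlmRequestState
from tensorrt_llm.bindings.executor import RequestStage

# Illustrative mirror of the C++ switch above: the two new states map to the
# context and generation stages respectively.
STATE_TO_STAGE = {
    LlmRequestState.ENCODER_INIT: RequestStage.ENCODER_IN_PROGRESS,
    LlmRequestState.CONTEXT_INIT: RequestStage.CONTEXT_IN_PROGRESS,
    LlmRequestState.DISAGG_CONTEXT_WAIT_SCHEDULER: RequestStage.CONTEXT_IN_PROGRESS,
    LlmRequestState.GENERATION_IN_PROGRESS: RequestStage.GENERATION_IN_PROGRESS,
    LlmRequestState.GENERATION_TO_COMPLETE: RequestStage.GENERATION_IN_PROGRESS,
    LlmRequestState.DISAGG_GENERATION_TRANS_COMPLETE: RequestStage.GENERATION_IN_PROGRESS,
    LlmRequestState.DISAGG_GENERATION_INIT: RequestStage.GENERATION_IN_PROGRESS,
    LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS: RequestStage.GENERATION_IN_PROGRESS,
    LlmRequestState.DISAGG_GENERATION_WAIT_TOKENS: RequestStage.GENERATION_IN_PROGRESS,
}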
@@ -481,7 +481,9 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
         .value("DISAGG_GENERATION_TRANS_IN_PROGRESS", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS)
         .value("DISAGG_GENERATION_TRANS_COMPLETE", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE)
         .value("DISAGG_CONTEXT_INIT_AND_TRANS", tb::LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS)
-        .value("DISAGG_TRANS_ERROR", tb::LlmRequestState::kDISAGG_TRANS_ERROR);
+        .value("DISAGG_TRANS_ERROR", tb::LlmRequestState::kDISAGG_TRANS_ERROR)
+        .value("DISAGG_CONTEXT_WAIT_SCHEDULER", tb::LlmRequestState::kDISAGG_CONTEXT_WAIT_SCHEDULER)
+        .value("DISAGG_GENERATION_WAIT_TOKENS", tb::LlmRequestState::kDISAGG_GENERATION_WAIT_TOKENS);

     nb::class_<tr::MemoryCounters>(m, "MemoryCounters")
         .def_static("instance", &tr::MemoryCounters::getInstance, nb::rv_policy::reference)
@@ -470,7 +470,9 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
         .value("DISAGG_GENERATION_TRANS_IN_PROGRESS", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS)
         .value("DISAGG_TRANS_ERROR", tb::LlmRequestState::kDISAGG_TRANS_ERROR)
         .value("DISAGG_GENERATION_TRANS_COMPLETE", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE)
-        .value("DISAGG_CONTEXT_INIT_AND_TRANS", tb::LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS);
+        .value("DISAGG_CONTEXT_INIT_AND_TRANS", tb::LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS)
+        .value("DISAGG_CONTEXT_WAIT_SCHEDULER", tb::LlmRequestState::kDISAGG_CONTEXT_WAIT_SCHEDULER)
+        .value("DISAGG_GENERATION_WAIT_TOKENS", tb::LlmRequestState::kDISAGG_GENERATION_WAIT_TOKENS);

     py::class_<tr::MemoryCounters>(m, "MemoryCounters")
         .def_static("instance", &tr::MemoryCounters::getInstance, py::return_value_policy::reference)
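Reviewer note: the nanobind and pybind11 registrations must stay in lockstep, and both now expose the two new members. A quick smoke check (assuming the standard tensorrt_llm.bindings module layout; the integer values follow the C++ enum above):

from tensorrt_llm.bindings import LlmRequestState

# Both new members should be visible and carry the C++ enum values.
assert int(LlmRequestState.DISAGG_CONTEXT_WAIT_SCHEDULER) == 7
assert int(LlmRequestState.DISAGG_GENERATION_WAIT_TOKENS) == 23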
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 from os import getenv
+from typing import List

 import tensorrt_llm
 from tensorrt_llm import logger
@@ -95,6 +96,24 @@ class KvCacheTransceiver(ABC):
     def cancel_request(self, req: LlmRequest):
         raise NotImplementedError

+    @abstractmethod
+    def prepare_context_request(self, requests: List[LlmRequest]):
+        """
+        Prepare the context requests for the cache transceiver in generation-first mode.
+        This method should set a context request's state to DISAGG_CONTEXT_WAIT_SCHEDULER
+        so that it won't be scheduled while the corresponding generation kvcache request
+        has not yet been received; otherwise set it to CONTEXT_INIT.
+        """
+        ...
+
+    @abstractmethod
+    def get_context_state(self):
+        """
+        Return the opaque context request state, which will be attached to the generation request.
+        The generation server will use this state to fetch the kvcache in generation-first mode.
+        """
+        ...
+

 class BindKvCacheTransceiver(KvCacheTransceiver):
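Reviewer note: to make the docstring contract concrete, here is a minimal sketch of what a generation-first-capable subclass might look like. The backend helpers are hypothetical placeholders, not part of this commit, and the sketch assumes it lives alongside KvCacheTransceiver and LlmRequest in this module:

from typing import List

from tensorrt_llm.bindings import LlmRequestState

class GenFirstKvCacheTransceiver(KvCacheTransceiver):
    """Hypothetical subclass honoring the new gen-first contract."""

    def prepare_context_request(self, requests: List[LlmRequest]):
        for req in requests:
            # Park the request until the matching generation-side kvcache
            # request has arrived; otherwise it is free to be scheduled.
            if self._gen_kvcache_request_received(req):  # hypothetical helper
                req.state = LlmRequestState.CONTEXT_INIT
            else:
                req.state = LlmRequestState.DISAGG_CONTEXT_WAIT_SCHEDULER

    def get_context_state(self):
        # Opaque state blob attached to the generation request, which the
        # generation server uses to pull the kvcache in generation-first mode.
        return self._serialized_context_state  # hypothetical attribute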
@@ -141,6 +160,12 @@ class BindKvCacheTransceiver(KvCacheTransceiver):
     def cancel_request(self, req: LlmRequest):
         return self.impl.cancel_request(req)

+    def prepare_context_request(self, requests: List[LlmRequest]):
+        raise NotImplementedError
+
+    def get_context_state(self):
+        raise NotImplementedError
+

 class CacheTransBufferManager:
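Reviewer note: BindKvCacheTransceiver deliberately raises NotImplementedError for both new hooks until the native binding implements them, so callers probing for generation-first support should guard the calls. An illustrative pattern, not from the commit:

def supports_gen_first(transceiver: KvCacheTransceiver) -> bool:
    # Probe rather than call blindly: the bound C++ transceiver does not
    # implement the generation-first hooks yet.
    try:
        transceiver.prepare_context_request([])
    except NotImplementedError:
        return False
    return True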