From ef268e20622fb1dfde93087e3f13611c02b71937 Mon Sep 17 00:00:00 2001
From: Jin Li <59594262+liji-nv@users.noreply.github.com>
Date: Fri, 30 Jan 2026 14:49:17 +0800
Subject: [PATCH] [TRTLLM-9904][feat] Changes for future KVCacheV2 MTP support (#11029)

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
---
 cpp/include/tensorrt_llm/batch_manager/llmRequest.h  | 10 ++++++++--
 cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp |  2 ++
 tensorrt_llm/_torch/pyexecutor/model_engine.py       | 10 ++++++++--
 3 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
index 2d6e792281..1d05c42e20 100644
--- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
+++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -1555,8 +1555,14 @@ public:
 
     void setContextCurrentPosition(SizeType32 contextCurrentPosition)
     {
-        mContextCurrentPositionDraft = contextCurrentPosition;
-        mContextCurrentPositionTarget = contextCurrentPosition;
+        if (mUseDraftModel)
+        {
+            mContextCurrentPositionDraft = contextCurrentPosition;
+        }
+        else
+        {
+            mContextCurrentPositionTarget = contextCurrentPosition;
+        }
     }
 
     /// When chunked, the position of the current chunk is returned. Otherwise, only the beginning
diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
index 87f8c2c3cc..a56fc38d00 100644
--- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
@@ -166,6 +166,8 @@ void initBindings(nb::module_& m)
         .def_prop_rw(
             "context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
         .def_prop_ro("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)
+        .def("set_prepopulated_prompt_len", &GenLlmReq::setPrepopulatedPromptLen, nb::arg("prepopulated_prompt_len"),
+            nb::arg("kv_tokens_per_block"))
         .def_prop_rw("guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams)
         .def_prop_rw("context_phase_params", &GenLlmReq::getContextPhaseParams, &GenLlmReq::setContextPhaseParams)
         .def_prop_ro("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index e2fb2514e3..603ccc1e67 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -1072,8 +1072,14 @@ class PyTorchModelEngine(ModelEngine):
         available_tokens = kv_cache_manager.get_num_available_tokens(draft_len)
 
         # Add one dummy request with the maximum possible sequence length.
-        max_seq_len = self.max_seq_len if max_seq_len is None else max_seq_len
-        token_num = max(1, min(available_tokens, max_seq_len - 1))
+        max_seq_len = min(
+            self.max_seq_len if max_seq_len is None else max_seq_len,
+            kv_cache_manager.max_seq_len)
+        token_num = max(
+            1,
+            min(
+                available_tokens, max_seq_len - 1 -
+                get_num_extra_kv_tokens(self.spec_config) - draft_len))
         model_config = self.model.model_config.pretrained_config
         max_position_embeddings = getattr(model_config,
                                           'max_position_embeddings', None)
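
For reference, below is a minimal standalone sketch of the dummy-request sizing logic in the model_engine.py hunk above, assuming plain integers stand in for the real KVCacheManager, spec_config, and engine objects; the function name and the numbers in the usage example are illustrative assumptions, not values from the patch.

# Standalone sketch of the revised dummy-request token sizing, assuming
# plain integers in place of the real KVCacheManager / spec_config objects.

def dummy_request_token_num(available_tokens: int,
                            engine_max_seq_len: int,
                            kv_cache_max_seq_len: int,
                            extra_kv_tokens: int,
                            draft_len: int,
                            max_seq_len: int | None = None) -> int:
    # Clamp the requested length by both the engine limit and the KV cache limit.
    max_seq_len = min(
        engine_max_seq_len if max_seq_len is None else max_seq_len,
        kv_cache_max_seq_len)
    # Reserve room for the extra KV tokens used by speculative decoding and for
    # the draft tokens themselves; never return fewer than one token.
    return max(
        1,
        min(available_tokens,
            max_seq_len - 1 - extra_kv_tokens - draft_len))


if __name__ == "__main__":
    # Illustrative numbers only (assumed, not taken from the patch).
    print(dummy_request_token_num(available_tokens=4096,
                                  engine_max_seq_len=8192,
                                  kv_cache_max_seq_len=2048,
                                  extra_kv_tokens=1,
                                  draft_len=3))  # -> 2043

Compared with the old computation, the change caps the dummy sequence length at the KV cache manager's own maximum and subtracts the speculative-decoding overhead, so the warm-up request cannot ask for more KV cache blocks than the manager can actually provide.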