[TRTLLM-9904][feat] Changes for future KVCacheV2 MTP support (#11029)

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Authored by Jin Li on 2026-01-30 14:49:17 +08:00; committed by GitHub
parent 6506d63466
commit ef268e2062
3 changed files with 18 additions and 4 deletions


@@ -1555,8 +1555,14 @@ public:
     void setContextCurrentPosition(SizeType32 contextCurrentPosition)
     {
-        mContextCurrentPositionDraft = contextCurrentPosition;
-        mContextCurrentPositionTarget = contextCurrentPosition;
+        if (mUseDraftModel)
+        {
+            mContextCurrentPositionDraft = contextCurrentPosition;
+        }
+        else
+        {
+            mContextCurrentPositionTarget = contextCurrentPosition;
+        }
     }
     /// When chunked, the position of the current chunk is returned. Otherwise, only the beginning
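
The change above makes the position setter respect which model the request is currently driving: the removed lines overwrote both the draft and the target position counters together, whereas the new code updates only one of them, presumably so the two KV caches can advance independently in the KVCacheV2 MTP path named in the title. A minimal Python sketch of the new semantics (the names mirror the C++ fields mUseDraftModel and mContextCurrentPositionDraft/Target; this is an illustration, not a real TensorRT-LLM API):

# Sketch only: mirrors the C++ branching in setContextCurrentPosition.
class RequestPositions:
    def __init__(self, use_draft_model: bool):
        self.use_draft_model = use_draft_model
        self.context_current_position_draft = 0
        self.context_current_position_target = 0

    def set_context_current_position(self, pos: int) -> None:
        # Only the counter of the model this request is bound to advances;
        # the other model's position is left untouched.
        if self.use_draft_model:
            self.context_current_position_draft = pos
        else:
            self.context_current_position_target = pos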


@@ -166,6 +166,8 @@ void initBindings(nb::module_& m)
         .def_prop_rw(
             "context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
         .def_prop_ro("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)
+        .def("set_prepopulated_prompt_len", &GenLlmReq::setPrepopulatedPromptLen, nb::arg("prepopulated_prompt_len"),
+            nb::arg("kv_tokens_per_block"))
         .def_prop_rw("guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams)
         .def_prop_rw("context_phase_params", &GenLlmReq::getContextPhaseParams, &GenLlmReq::setContextPhaseParams)
        .def_prop_ro("is_context_only_request", &GenLlmReq::isContextOnlyRequest)


@@ -1072,8 +1072,14 @@ class PyTorchModelEngine(ModelEngine):
         available_tokens = kv_cache_manager.get_num_available_tokens(draft_len)
         # Add one dummy request with the maximum possible sequence length.
-        max_seq_len = self.max_seq_len if max_seq_len is None else max_seq_len
-        token_num = max(1, min(available_tokens, max_seq_len - 1))
+        max_seq_len = min(
+            self.max_seq_len if max_seq_len is None else max_seq_len,
+            kv_cache_manager.max_seq_len)
+        token_num = max(
+            1,
+            min(
+                available_tokens, max_seq_len - 1 -
+                get_num_extra_kv_tokens(self.spec_config) - draft_len))
         model_config = self.model.model_config.pretrained_config
         max_position_embeddings = getattr(model_config,
             'max_position_embeddings', None)
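
The dummy-request sizing now clamps max_seq_len to what the KV-cache manager supports and budgets for speculative-decoding overhead: one slot (the `- 1`), the extra KV tokens the speculative config reserves, and the draft tokens themselves. A standalone sketch of the new rule, with get_num_extra_kv_tokens folded into an extra_kv_tokens parameter and illustrative example numbers:

def dummy_request_token_num(available_tokens: int, max_seq_len: int,
                            kv_cache_max_seq_len: int, extra_kv_tokens: int,
                            draft_len: int) -> int:
    # Never size the dummy request beyond what the KV-cache manager can hold.
    max_seq_len = min(max_seq_len, kv_cache_max_seq_len)
    # Reserve one position plus the speculative bookkeeping tokens and the
    # draft tokens, matching the subtraction in the diff above.
    budget = max_seq_len - 1 - extra_kv_tokens - draft_len
    return max(1, min(available_tokens, budget))

# Example: 4096-token cache, engine limit 8192, 3 extra KV tokens, 3 draft tokens.
print(dummy_request_token_num(4096, 8192, 4096, 3, 3))  # -> 4089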