Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-04 18:21:52 +08:00
[TRTLLM-9904][feat] Changes for future KVCacheV2 MTP support (#11029)
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
This commit is contained in:
parent 6506d63466
commit ef268e2062
@@ -1555,8 +1555,14 @@ public:
     void setContextCurrentPosition(SizeType32 contextCurrentPosition)
     {
-        mContextCurrentPositionDraft = contextCurrentPosition;
-        mContextCurrentPositionTarget = contextCurrentPosition;
+        if (mUseDraftModel)
+        {
+            mContextCurrentPositionDraft = contextCurrentPosition;
+        }
+        else
+        {
+            mContextCurrentPositionTarget = contextCurrentPosition;
+        }
     }

     /// When chunked, the position of the current chunk is returned. Otherwise, only the beginning
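The effect of this hunk is that setContextCurrentPosition now updates only the position belonging to the model the request is running (draft vs. target) instead of overwriting both. A minimal Python sketch of that routing follows; the class and attribute names are illustrative stand-ins, not the actual C++ members.

class ContextPosition:
    # Illustrative stand-in for the request state touched above; attribute
    # names are assumptions, not the real mContextCurrentPosition* members.
    def __init__(self, use_draft_model: bool):
        self.use_draft_model = use_draft_model
        self.draft_position = 0
        self.target_position = 0

    def set_context_current_position(self, position: int) -> None:
        # Only the side this request belongs to is advanced, mirroring the
        # if (mUseDraftModel) branch introduced in the C++ change.
        if self.use_draft_model:
            self.draft_position = position
        else:
            self.target_position = position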
@@ -166,6 +166,8 @@ void initBindings(nb::module_& m)
+        .def_prop_rw(
+            "context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
         .def_prop_ro("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)
         .def("set_prepopulated_prompt_len", &GenLlmReq::setPrepopulatedPromptLen, nb::arg("prepopulated_prompt_len"),
             nb::arg("kv_tokens_per_block"))
         .def_prop_rw("guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams)
         .def_prop_rw("context_phase_params", &GenLlmReq::getContextPhaseParams, &GenLlmReq::setContextPhaseParams)
         .def_prop_ro("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
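The binding hunk makes context_current_position a read/write property from Python, pairing the new C++ setter with the existing getter. A hedged usage sketch follows; the request object and helper are hypothetical, and only the property accesses correspond to the bound methods.

# Hypothetical: obtain a bound GenLlmReq/LlmRequest instance from the runtime.
req = make_llm_request_for_test()
pos = req.context_current_position   # GenLlmReq::getContextCurrentPosition
req.context_current_position = pos   # GenLlmReq::setContextCurrentPosition (newly writable)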
@@ -1072,8 +1072,14 @@ class PyTorchModelEngine(ModelEngine):
         available_tokens = kv_cache_manager.get_num_available_tokens(draft_len)

         # Add one dummy request with the maximum possible sequence length.
-        max_seq_len = self.max_seq_len if max_seq_len is None else max_seq_len
-        token_num = max(1, min(available_tokens, max_seq_len - 1))
+        max_seq_len = min(
+            self.max_seq_len if max_seq_len is None else max_seq_len,
+            kv_cache_manager.max_seq_len)
+        token_num = max(
+            1,
+            min(
+                available_tokens, max_seq_len - 1 -
+                get_num_extra_kv_tokens(self.spec_config) - draft_len))
         model_config = self.model.model_config.pretrained_config
         max_position_embeddings = getattr(model_config,
                                           'max_position_embeddings', None)
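This hunk clamps the dummy warmup request both to the KV cache manager's own max_seq_len and to the token budget left after reserving extra KV tokens and draft tokens for speculative decoding. A self-contained sketch of that clamp follows, with made-up inputs; the function below is an illustrative stand-in, not get_num_extra_kv_tokens or the engine method itself.

def dummy_request_token_num(available_tokens: int, max_seq_len: int,
                            kv_cache_max_seq_len: int,
                            num_extra_kv_tokens: int, draft_len: int) -> int:
    # The dummy request may not exceed the KV cache manager's own limit ...
    effective_max = min(max_seq_len, kv_cache_max_seq_len)
    # ... and must leave room for extra KV tokens and draft tokens, while
    # still containing at least one token.
    return max(1, min(available_tokens,
                      effective_max - 1 - num_extra_kv_tokens - draft_len))

# Example with illustrative numbers: a 2048-token KV cache limit, 1 extra KV
# token, and 3 draft tokens leave 2043 tokens for the dummy request.
print(dummy_request_token_num(8192, 4096, 2048, 1, 3))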