mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[https://nvbugs/5787453][fix] Better align MLA chunking with indexer chunking when chunked prefill enabled for DSV32 (#10552)
commit 78bb245554 (parent 4a09acd012)
@@ -1047,12 +1047,11 @@ class Indexer(nn.Module):
-        # Indexer should just process the current MLA chunk as a single chunk
         has_mla_chunked_prefill = (
             metadata.enable_context_mla_with_cached_kv
             and host_cached_tokens.sum().item() > 0
             and metadata.runtime_features.chunked_prefill)
 
         if has_mla_chunked_prefill:
-            # The MLA has already split the sequence, here just process what's given (as a single chunk)
-            # Cached token info is derived from metadata.host_ctx_cached_token_indptr in prepare_one_prefill_chunk
+            # MLA chunked prefill is active - use single-chunk pattern for
+            # indexer prefill chunks.
             chunk_specs = [(i, 0, host_seq_lens[i].item(),
                             host_seq_lens[:i].sum().item() if i > 0 else 0)
                            for i in range(num_contexts)]
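For context on the single-chunk path above: when MLA chunked prefill is active, each MLA chunk arrives already split, so the indexer builds one spec per context sequence instead of re-chunking. Below is a minimal sketch of how those chunk_specs come out, assuming the tuple layout is (sequence index, chunk start, chunk length, cumulative token offset); that field reading is inferred from the diff, and host_seq_lens here is a made-up stand-in for the metadata field.

import torch

# Hypothetical per-sequence lengths for the current MLA prefill chunk.
host_seq_lens = torch.tensor([5, 3, 7])
num_contexts = host_seq_lens.numel()

# One spec per sequence, mirroring the list comprehension in the hunk above:
# (seq_idx, chunk_start, chunk_len, token_offset), where token_offset is the
# total length of all preceding sequences in the batch.
chunk_specs = [(i, 0, host_seq_lens[i].item(),
                host_seq_lens[:i].sum().item() if i > 0 else 0)
               for i in range(num_contexts)]

print(chunk_specs)  # [(0, 0, 5, 0), (1, 0, 3, 5), (2, 0, 7, 8)]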
@@ -1063,7 +1062,8 @@ class Indexer(nn.Module):
                 )
             ]
         else:
-            # Normal mode: use indexer's own chunking logic to prevent L^2 complexity when long-sequence is used.
+            # Use indexer's own chunking logic to prevent L^2 complexity of indexer MQA logits computation for long sequences.
+            # This is only used when MLA chunked prefill is not enabled.
             chunk_groups = split_prefill_chunks(
                 host_seq_lens,
                 metadata.indexer_max_chunk_size,
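On the else branch: computing the indexer's MQA logits over a full length-L prefill in one shot costs O(L^2), so split_prefill_chunks caps each chunk at indexer_max_chunk_size tokens. The helper below is a toy stand-in, not the real TensorRT-LLM implementation; its greedy grouping, signature, and return shape are assumptions, meant only to show the kind of size-capped grouping that keeps per-chunk cost bounded.

import torch

def toy_split_prefill_chunks(host_seq_lens: torch.Tensor, max_chunk_size: int):
    """Greedily pack (seq_idx, start, length) pieces into groups whose
    total token count never exceeds max_chunk_size."""
    groups, current, budget = [], [], max_chunk_size
    for i, remaining in enumerate(host_seq_lens.tolist()):
        start = 0
        while remaining > 0:
            take = min(remaining, budget)  # fill the current group's budget
            current.append((i, start, take))
            start, remaining, budget = start + take, remaining - take, budget - take
            if budget == 0:  # group is full: start a fresh one
                groups.append(current)
                current, budget = [], max_chunk_size
    if current:
        groups.append(current)
    return groups

# Three sequences of 5, 3, and 7 tokens, capped at 6 tokens per chunk group:
print(toy_split_prefill_chunks(torch.tensor([5, 3, 7]), max_chunk_size=6))
# [[(0, 0, 5), (1, 0, 1)], [(1, 1, 2), (2, 0, 4)], [(2, 4, 3)]]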