[https://nvbugs/5787453][fix] Better align MLA chunking with indexer chunking when chunked prefill enabled for DSV32 (#10552)

This commit is contained in:
Chang Liu 2026-01-09 00:49:39 -08:00 committed by GitHub
parent 4a09acd012
commit 78bb245554
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1047,12 +1047,11 @@ class Indexer(nn.Module):
# Indexer should just process the current MLA chunk as a single chunk
has_mla_chunked_prefill = (
metadata.enable_context_mla_with_cached_kv
and host_cached_tokens.sum().item() > 0
and metadata.runtime_features.chunked_prefill)
if has_mla_chunked_prefill:
# The MLA has already split the sequence, here just process what's given (as a single chunk)
# Cached token info is derived from metadata.host_ctx_cached_token_indptr in prepare_one_prefill_chunk
# MLA chunked prefill is active - use single-chunk pattern for
# indexer prefill chunks.
chunk_specs = [(i, 0, host_seq_lens[i].item(),
host_seq_lens[:i].sum().item() if i > 0 else 0)
for i in range(num_contexts)]
@ -1063,7 +1062,8 @@ class Indexer(nn.Module):
)
]
else:
# Normal mode: use indexer's own chunking logic to prevent L^2 complexity when long-sequence is used.
# Use indexer's own chunking logic to prevent L^2 complexity of indexer MQA logits computation for long sequences.
# This is only used when MLA chunked prefill is not enabled.
chunk_groups = split_prefill_chunks(
host_seq_lens,
metadata.indexer_max_chunk_size,