[Bugfix] Split attention groups by num_heads_q for spec-decode drafts (#43543)

Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
This commit is contained in:
Luciano Martins
2026-05-26 21:11:01 -03:00
committed by GitHub
parent e19b9b1045
commit dede691c95
2 changed files with 26 additions and 5 deletions
+1
View File
@@ -32,6 +32,7 @@ steps:
source_file_dependencies:
- vllm/v1/spec_decode/
- vllm/v1/worker/gpu/spec_decode/
- vllm/v1/attention/backends/
- vllm/transformers_utils/configs/speculators/
- tests/v1/e2e/spec_decode/
commands:
+25 -5
View File
@@ -6556,8 +6556,21 @@ class GPUModelRunner(
assert len(self.attn_groups) == 0, "Attention backends are already initialized"
class AttentionGroupKey(NamedTuple):
"""Deduplication key for attention groups within a KV cache group.
Splits on per-rank ``num_heads_q`` in addition to backend + spec
so layers with different Q-head counts (e.g. a spec-decode draft
with fewer attention heads than its target) get separate metadata
builders. The builders' scratch (e.g. ``softmax_segm_*`` in
``triton_attn``, ``num_qo_heads`` in FlashInfer) is sized by
``num_heads_q`` and assumes uniformity within the group; see
``get_num_attention_heads_from_layers`` in
``vllm/v1/attention/backends/utils.py``.
"""
attn_backend: type[AttentionBackend]
kv_cache_spec: KVCacheSpec
num_heads_q: int
def get_attn_backends_for_group(
kv_cache_group_spec: KVCacheGroupSpec,
@@ -6586,9 +6599,16 @@ class GPUModelRunner(
layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec
if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name]
key = (full_cls_name, layer_kv_cache_spec)
# Non-Attention layer types (e.g. Mamba1, ShortConv) do not
# expose ``num_heads``; fall back to 0 so they cluster as
# before. Such layers never coexist with Attention in a
# single KV cache group (different KVCacheSpec), so the
# fallback can never spuriously merge them with attention
# layers.
num_heads_q = getattr(layers[layer_name], "num_heads", 0)
key = (full_cls_name, layer_kv_cache_spec, num_heads_q)
attn_backends[key] = AttentionGroupKey(
attn_backend, layer_kv_cache_spec
attn_backend, layer_kv_cache_spec, num_heads_q
)
attn_backend_layers[key].append(layer_name)
return (
@@ -6601,11 +6621,11 @@ class GPUModelRunner(
kv_cache_group_id: int,
) -> list[AttentionGroup]:
attn_groups: list[AttentionGroup] = []
for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
for key, layer_names in attn_backends_map.items():
attn_group = AttentionGroup(
attn_backend,
key.attn_backend,
layer_names,
kv_cache_spec,
key.kv_cache_spec,
kv_cache_group_id,
)