[Perf] Improve multimodal item handling from O(n) to O(log n) per step (#44212)

Signed-off-by: Andy Lo <andy@mistral.ai>
This commit is contained in:
Andy Lo
2026-06-03 12:00:26 +01:00
committed by GitHub
parent 1fa9ea09f6
commit 95b1615ec9
6 changed files with 92 additions and 44 deletions
@@ -43,6 +43,7 @@ def test_basic_allocate_and_reuse():
assert cache.check_and_update_cache(req, 0)
assert "r1" in cache.cached["imgA"]
assert cache.request_cached_ids["r1"] == {0}
assert cache.num_free_slots == 6
# Free twice to bring refcount to 0.
@@ -50,6 +51,7 @@ def test_basic_allocate_and_reuse():
cache.free_encoder_input(req, 0)
assert not cache.cached["imgA"]
assert "r1" not in cache.request_cached_ids
assert "imgA" in cache.freeable
assert cache.num_freeable_slots == 10
assert cache.num_free_slots == 6
@@ -63,10 +65,12 @@ def test_freeing_decreases_refcount_and_moves_to_freeable():
manager.allocate(req, 0)
assert len(manager.cached["img3"]) == 1
assert manager.request_cached_ids["req2"] == {0}
manager.free_encoder_input(req, 0)
assert not manager.cached["img3"]
assert "req2" not in manager.request_cached_ids
assert "img3" in manager.freeable
assert manager.num_freeable_slots == 10
@@ -83,11 +87,13 @@ def test_free_request_frees_all_inputs():
assert len(manager.cached["a"]) == 1
assert len(manager.cached["b"]) == 1
assert manager.request_cached_ids["req3"] == {0, 1}
manager.free(req)
assert not manager.cached["a"]
assert not manager.cached["b"]
assert "req3" not in manager.request_cached_ids
assert "a" in manager.freeable
assert "b" in manager.freeable
assert manager.num_freeable_slots == 10
@@ -108,6 +114,7 @@ def test_eviction_when_cache_is_full():
# 'x' should have been evicted.
assert "x" not in manager.cached
assert "req1" not in manager.request_cached_ids
assert "x" in manager.get_freed_mm_hashes()
@@ -137,6 +144,7 @@ def test_has_cache_restores_from_freeable():
# Should restore from freeable.
assert manager.check_and_update_cache(req, 0)
assert len(manager.cached["imgZ"]) == 1
assert manager.request_cached_ids["reqY"] == {0}
assert "imgZ" not in manager.freeable
assert manager.num_freeable_slots == 6
@@ -205,6 +213,7 @@ def test_encoder_cache_with_is_embed_mask():
assert manager.num_free_slots == 92
assert "img1" in manager.cached
assert manager.request_cached_ids["r1"] == {0}
old_size = 100
new_size = request.mm_features[0].mm_position.get_num_embeds()
@@ -276,6 +285,7 @@ def test_reset_clears_all_state():
manager.reset()
assert len(manager.cached) == 0
assert len(manager.request_cached_ids) == 0
assert len(manager.freeable) == 0
assert len(manager.freed) == 0
assert manager.num_free_slots == 20
@@ -298,6 +308,26 @@ def test_reset_allows_fresh_allocations():
assert manager.num_free_slots == 2
assert "img2" in manager.cached
assert "img1" not in manager.cached
assert manager.request_cached_ids["req2"] == {0}
assert "req1" not in manager.request_cached_ids
def test_free_request_with_duplicate_mm_hashes():
    """A request with two inputs sharing one mm_hash is fully untracked by free().

    Once the first free_encoder_input call has run, cached[mm_hash] is
    already empty. The second call must nevertheless drop the remaining
    input_id, so that no entry for the request is left behind in
    request_cached_ids.
    """
    cache = EncoderCacheManager(cache_size=20)
    request = MockRequest("r1", ["imgA", "imgA"], [4, 4])

    cache.allocate(request, 0)
    # Input 1 reuses the hash of input 0, so it is reported as cached.
    assert cache.check_and_update_cache(request, 1)
    assert cache.request_cached_ids["r1"] == {0, 1}

    cache.free(request)
    assert "r1" not in cache.request_cached_ids
def test_encoder_decoder_cache_manager_reset():
+25
View File
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import bisect
import mimetypes
from collections import defaultdict
from collections.abc import Generator, Sequence
@@ -18,6 +19,7 @@ from vllm.utils.import_utils import LazyLoader
from .hasher import MultiModalHasher
from .inputs import (
BatchedTensorInputs,
MultiModalFeatureSpec,
MultiModalFieldElem,
MultiModalKwargsItem,
MultiModalSharedField,
@@ -109,6 +111,29 @@ def encode_video_url(
return f"data:{mimetype};base64,{video_b64}"
def get_mm_features_in_window(
mm_features: list[MultiModalFeatureSpec],
start: int,
end: int,
) -> tuple[int, int]:
"""Return (lo, hi) indices for features overlapping [start, end).
Assumes mm_features are sorted by offset and non-overlapping, so
offset + length is also sorted.
"""
lo = bisect.bisect_left(
mm_features,
start + 1,
key=lambda f: f.mm_position.offset + f.mm_position.length,
)
hi = bisect.bisect_left(
mm_features,
end,
key=lambda f: f.mm_position.offset,
)
return lo, hi
def argsort_mm_positions(
mm_positions: MultiModalPlaceholders,
) -> list[tuple[str, int]]:
@@ -44,7 +44,7 @@ class MistralCommonFeatureExtractor:
if not self.audio_encoder.audio_config.is_streaming:
audio = self.audio_encoder.pad(audio, self.sampling_rate)
audios_processed.append(torch.tensor(audio))
audios_processed.append(torch.from_numpy(audio))
return BatchFeature(
{"audio_arrays": audios_processed}, tensor_type=return_tensors
+14 -14
View File
@@ -71,6 +71,8 @@ class EncoderCacheManager:
# mm_hash of mm_data => ids of requests that reference the mm_data
self.cached: dict[str, set[str]] = {}
# request_id => set of input_ids cached for that request
self.request_cached_ids: dict[str, set[int]] = {}
# mm_hash of mm_data => num_encoder_embeds of the mm_data
self.freeable: OrderedDict[str, int] = OrderedDict()
@@ -83,6 +85,7 @@ class EncoderCacheManager:
Called when model weights are updated to invalidate stale embeddings.
"""
self.cached.clear()
self.request_cached_ids.clear()
self.freeable.clear()
self.freed.clear()
self.num_free_slots = self.cache_size
@@ -114,6 +117,7 @@ class EncoderCacheManager:
self.num_freeable_slots -= num_encoder_embeds
self.cached[mm_hash].add(request.request_id)
self.request_cached_ids.setdefault(request.request_id, set()).add(input_id)
return True
def can_allocate(
@@ -201,22 +205,13 @@ class EncoderCacheManager:
assert self.num_freeable_slots >= num_encoder_embeds
self.cached[mm_hash].add(request_id)
self.request_cached_ids.setdefault(request_id, set()).add(input_id)
self.num_free_slots -= num_encoder_embeds
self.num_freeable_slots -= num_encoder_embeds
def get_cached_input_ids(self, request: Request) -> set[int]:
    """Get all cached multimodal input IDs for a request.

    The IDs are read from `request_cached_ids` (request_id => set of
    input_ids cached for that request) with a single dictionary lookup,
    rather than by scanning every entry of `request.mm_features`.

    Args:
        request: The request whose cached input IDs are wanted.

    Returns:
        The set of input IDs recorded for `request.request_id`, or an empty
        set when nothing is recorded for it. For a tracked request this is
        the manager's own set rather than a copy, so take a snapshot (e.g.
        `list(...)`) before iterating if inputs may be freed during the
        iteration.
    """
    return self.request_cached_ids.get(request.request_id, set())
def free_encoder_input(self, request: Request, input_id: int) -> None:
"""Free the request's reference to the encoder input (`mm_data`)
@@ -230,6 +225,12 @@ class EncoderCacheManager:
"""
req_id = request.request_id
mm_hash = request.mm_features[input_id].identifier
# Always clean up request_cached_ids, even if the mm_hash was
# already evicted from cache (e.g. by can_allocate).
if req_id in self.request_cached_ids:
self.request_cached_ids[req_id].discard(input_id)
if not self.request_cached_ids[req_id]:
del self.request_cached_ids[req_id]
# The mm_hash not in cache or the req_id set is empty
if not self.cached.get(mm_hash, None):
return
@@ -248,8 +249,7 @@ class EncoderCacheManager:
Typically called when a request is finished, cancelled, or aborted.
"""
input_ids = self.get_cached_input_ids(request)
for input_id in input_ids:
for input_id in list(self.get_cached_input_ids(request)):
self.free_encoder_input(request, input_id)
def get_freed_mm_hashes(self) -> list[str]:
+13 -15
View File
@@ -29,6 +29,7 @@ from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
)
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.encoder_budget import MultiModalBudget
from vllm.multimodal.utils import get_mm_features_in_window
from vllm.v1.core.encoder_cache_manager import (
EncoderCacheManager,
EncoderDecoderCacheManager,
@@ -1163,22 +1164,23 @@ class Scheduler(SchedulerInterface):
# trackers for accounting at the encoder input level.
mm_hashes_to_schedule = set()
num_embeds_to_schedule = 0
for i, mm_feature in enumerate(mm_features):
lo, hi = get_mm_features_in_window(
mm_features,
start=num_computed_tokens,
end=num_computed_tokens + num_new_tokens + shift_computed_tokens,
)
# For encoder-decoder, all inputs sit at start_pos=0, so lo=0 always.
if self.is_encoder_decoder:
lo = 0
for i in range(lo, hi):
mm_feature = mm_features[i]
start_pos = mm_feature.mm_position.offset
num_encoder_tokens = mm_feature.mm_position.length
num_encoder_embeds = mm_feature.mm_position.get_num_embeds()
item_identifier = mm_feature.identifier
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
# [start_pos, start_pos + num_encoder_tokens)
if (
start_pos
>= num_computed_tokens + num_new_tokens + shift_computed_tokens
):
# The encoder input is not needed in this step.
break
if self.is_encoder_decoder and num_computed_tokens > 0:
assert start_pos == 0, (
"Encoder input should be processed at the beginning of "
@@ -1194,10 +1196,6 @@ class Scheduler(SchedulerInterface):
# decoder tokens (num_computed_tokens > 0), then we know we
# already calculated encoder inputs and can skip here.
continue
elif start_pos + num_encoder_tokens <= num_computed_tokens:
# The encoder input is already computed and stored
# in the decoder's KV cache.
continue
if not self.is_encoder_decoder:
# We are not using the encoder cache for encoder-decoder models,
+9 -14
View File
@@ -104,7 +104,7 @@ from vllm.multimodal.inputs import (
MultiModalKwargsItem,
PlaceholderRange,
)
from vllm.multimodal.utils import group_and_batch_mm_kwargs
from vllm.multimodal.utils import get_mm_features_in_window, group_and_batch_mm_kwargs
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
@@ -3100,23 +3100,18 @@ class GPUModelRunner(
req_state = self.requests[req_id]
num_computed_tokens = req_state.num_computed_tokens + shift_computed_tokens
for mm_feature in req_state.mm_features:
mm_features = req_state.mm_features
lo, hi = get_mm_features_in_window(
mm_features,
start=num_computed_tokens,
end=num_computed_tokens + num_scheduled_tokens,
)
for i in range(lo, hi):
mm_feature = mm_features[i]
pos_info = mm_feature.mm_position
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
# num_computed_tokens + num_scheduled_tokens) and
# [start_pos, start_pos + num_encoder_tokens)
if start_pos >= num_computed_tokens + num_scheduled_tokens:
# The encoder output is not needed in this step.
break
if start_pos + num_encoder_tokens <= num_computed_tokens:
# The encoder output is already processed and stored
# in the decoder's KV cache.
continue
start_idx = max(num_computed_tokens - start_pos, 0)
end_idx = min(
num_computed_tokens - start_pos + num_scheduled_tokens,