mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Perf] Improve multimodal item handling from O(n) to O(log n) per step (#44212)
Signed-off-by: Andy Lo <andy@mistral.ai>
This commit is contained in:
@@ -43,6 +43,7 @@ def test_basic_allocate_and_reuse():
|
||||
|
||||
assert cache.check_and_update_cache(req, 0)
|
||||
assert "r1" in cache.cached["imgA"]
|
||||
assert cache.request_cached_ids["r1"] == {0}
|
||||
assert cache.num_free_slots == 6
|
||||
|
||||
# Free twice to bring refcount to 0.
|
||||
@@ -50,6 +51,7 @@ def test_basic_allocate_and_reuse():
|
||||
cache.free_encoder_input(req, 0)
|
||||
|
||||
assert not cache.cached["imgA"]
|
||||
assert "r1" not in cache.request_cached_ids
|
||||
assert "imgA" in cache.freeable
|
||||
assert cache.num_freeable_slots == 10
|
||||
assert cache.num_free_slots == 6
|
||||
@@ -63,10 +65,12 @@ def test_freeing_decreases_refcount_and_moves_to_freeable():
|
||||
manager.allocate(req, 0)
|
||||
|
||||
assert len(manager.cached["img3"]) == 1
|
||||
assert manager.request_cached_ids["req2"] == {0}
|
||||
|
||||
manager.free_encoder_input(req, 0)
|
||||
|
||||
assert not manager.cached["img3"]
|
||||
assert "req2" not in manager.request_cached_ids
|
||||
assert "img3" in manager.freeable
|
||||
assert manager.num_freeable_slots == 10
|
||||
|
||||
@@ -83,11 +87,13 @@ def test_free_request_frees_all_inputs():
|
||||
|
||||
assert len(manager.cached["a"]) == 1
|
||||
assert len(manager.cached["b"]) == 1
|
||||
assert manager.request_cached_ids["req3"] == {0, 1}
|
||||
|
||||
manager.free(req)
|
||||
|
||||
assert not manager.cached["a"]
|
||||
assert not manager.cached["b"]
|
||||
assert "req3" not in manager.request_cached_ids
|
||||
assert "a" in manager.freeable
|
||||
assert "b" in manager.freeable
|
||||
assert manager.num_freeable_slots == 10
|
||||
@@ -108,6 +114,7 @@ def test_eviction_when_cache_is_full():
|
||||
|
||||
# 'x' should have been evicted.
|
||||
assert "x" not in manager.cached
|
||||
assert "req1" not in manager.request_cached_ids
|
||||
assert "x" in manager.get_freed_mm_hashes()
|
||||
|
||||
|
||||
@@ -137,6 +144,7 @@ def test_has_cache_restores_from_freeable():
|
||||
# Should restore from freeable.
|
||||
assert manager.check_and_update_cache(req, 0)
|
||||
assert len(manager.cached["imgZ"]) == 1
|
||||
assert manager.request_cached_ids["reqY"] == {0}
|
||||
assert "imgZ" not in manager.freeable
|
||||
assert manager.num_freeable_slots == 6
|
||||
|
||||
@@ -205,6 +213,7 @@ def test_encoder_cache_with_is_embed_mask():
|
||||
|
||||
assert manager.num_free_slots == 92
|
||||
assert "img1" in manager.cached
|
||||
assert manager.request_cached_ids["r1"] == {0}
|
||||
|
||||
old_size = 100
|
||||
new_size = request.mm_features[0].mm_position.get_num_embeds()
|
||||
@@ -276,6 +285,7 @@ def test_reset_clears_all_state():
|
||||
manager.reset()
|
||||
|
||||
assert len(manager.cached) == 0
|
||||
assert len(manager.request_cached_ids) == 0
|
||||
assert len(manager.freeable) == 0
|
||||
assert len(manager.freed) == 0
|
||||
assert manager.num_free_slots == 20
|
||||
@@ -298,6 +308,26 @@ def test_reset_allows_fresh_allocations():
|
||||
assert manager.num_free_slots == 2
|
||||
assert "img2" in manager.cached
|
||||
assert "img1" not in manager.cached
|
||||
assert manager.request_cached_ids["req2"] == {0}
|
||||
assert "req1" not in manager.request_cached_ids
|
||||
|
||||
|
||||
def test_free_request_with_duplicate_mm_hashes():
    """A request with two inputs sharing one mm_hash is fully released.

    Once the first input is freed, cached[mm_hash] is already empty; freeing
    the second input must nevertheless drop its input_id so that no entry
    for the request is left behind in request_cached_ids.
    """
    shared_hash = "imgA"
    manager = EncoderCacheManager(cache_size=20)
    req = MockRequest("r1", [shared_hash, shared_hash], [4, 4])

    # Input 0 performs the allocation; input 1 carries the same hash, so the
    # cache lookup for it succeeds without a second allocation.
    manager.allocate(req, 0)
    assert manager.check_and_update_cache(req, 1)
    assert manager.request_cached_ids["r1"] == {0, 1}

    # Releasing the request must remove its bookkeeping entry entirely.
    manager.free(req)
    assert "r1" not in manager.request_cached_ids
|
||||
|
||||
|
||||
def test_encoder_decoder_cache_manager_reset():
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import bisect
|
||||
import mimetypes
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator, Sequence
|
||||
@@ -18,6 +19,7 @@ from vllm.utils.import_utils import LazyLoader
|
||||
from .hasher import MultiModalHasher
|
||||
from .inputs import (
|
||||
BatchedTensorInputs,
|
||||
MultiModalFeatureSpec,
|
||||
MultiModalFieldElem,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalSharedField,
|
||||
@@ -109,6 +111,29 @@ def encode_video_url(
|
||||
return f"data:{mimetype};base64,{video_b64}"
|
||||
|
||||
|
||||
def get_mm_features_in_window(
|
||||
mm_features: list[MultiModalFeatureSpec],
|
||||
start: int,
|
||||
end: int,
|
||||
) -> tuple[int, int]:
|
||||
"""Return (lo, hi) indices for features overlapping [start, end).
|
||||
|
||||
Assumes mm_features are sorted by offset and non-overlapping, so
|
||||
offset + length is also sorted.
|
||||
"""
|
||||
lo = bisect.bisect_left(
|
||||
mm_features,
|
||||
start + 1,
|
||||
key=lambda f: f.mm_position.offset + f.mm_position.length,
|
||||
)
|
||||
hi = bisect.bisect_left(
|
||||
mm_features,
|
||||
end,
|
||||
key=lambda f: f.mm_position.offset,
|
||||
)
|
||||
return lo, hi
|
||||
|
||||
|
||||
def argsort_mm_positions(
|
||||
mm_positions: MultiModalPlaceholders,
|
||||
) -> list[tuple[str, int]]:
|
||||
|
||||
@@ -44,7 +44,7 @@ class MistralCommonFeatureExtractor:
|
||||
if not self.audio_encoder.audio_config.is_streaming:
|
||||
audio = self.audio_encoder.pad(audio, self.sampling_rate)
|
||||
|
||||
audios_processed.append(torch.tensor(audio))
|
||||
audios_processed.append(torch.from_numpy(audio))
|
||||
|
||||
return BatchFeature(
|
||||
{"audio_arrays": audios_processed}, tensor_type=return_tensors
|
||||
|
||||
@@ -71,6 +71,8 @@ class EncoderCacheManager:
|
||||
|
||||
# mm_hash of mm_data => ids of requests that reference the mm_data
|
||||
self.cached: dict[str, set[str]] = {}
|
||||
# request_id => set of input_ids cached for that request
|
||||
self.request_cached_ids: dict[str, set[int]] = {}
|
||||
|
||||
# mm_hash of mm_data => num_encoder_embeds of the mm_data
|
||||
self.freeable: OrderedDict[str, int] = OrderedDict()
|
||||
@@ -83,6 +85,7 @@ class EncoderCacheManager:
|
||||
Called when model weights are updated to invalidate stale embeddings.
|
||||
"""
|
||||
self.cached.clear()
|
||||
self.request_cached_ids.clear()
|
||||
self.freeable.clear()
|
||||
self.freed.clear()
|
||||
self.num_free_slots = self.cache_size
|
||||
@@ -114,6 +117,7 @@ class EncoderCacheManager:
|
||||
self.num_freeable_slots -= num_encoder_embeds
|
||||
|
||||
self.cached[mm_hash].add(request.request_id)
|
||||
self.request_cached_ids.setdefault(request.request_id, set()).add(input_id)
|
||||
return True
|
||||
|
||||
def can_allocate(
|
||||
@@ -201,22 +205,13 @@ class EncoderCacheManager:
|
||||
assert self.num_freeable_slots >= num_encoder_embeds
|
||||
|
||||
self.cached[mm_hash].add(request_id)
|
||||
self.request_cached_ids.setdefault(request_id, set()).add(input_id)
|
||||
self.num_free_slots -= num_encoder_embeds
|
||||
self.num_freeable_slots -= num_encoder_embeds
|
||||
|
||||
def get_cached_input_ids(self, request: Request) -> set[int]:
|
||||
"""Get all cached multimodal input IDs for a request.
|
||||
|
||||
Returns the set of input IDs whose `mm_hash` exists in the cache map.
|
||||
This includes entries that are currently unreferenced (and thus present
|
||||
in `freeable`); for such entries, freeing for this request will be a
|
||||
no-op.
|
||||
"""
|
||||
return {
|
||||
input_id
|
||||
for input_id in range(len(request.mm_features))
|
||||
if request.mm_features[input_id].identifier in self.cached
|
||||
}
|
||||
"""Get all cached multimodal input IDs for a request."""
|
||||
return self.request_cached_ids.get(request.request_id, set())
|
||||
|
||||
def free_encoder_input(self, request: Request, input_id: int) -> None:
|
||||
"""Free the request's reference to the encoder input (`mm_data`)
|
||||
@@ -230,6 +225,12 @@ class EncoderCacheManager:
|
||||
"""
|
||||
req_id = request.request_id
|
||||
mm_hash = request.mm_features[input_id].identifier
|
||||
# Always clean up request_cached_ids, even if the mm_hash was
|
||||
# already evicted from cache (e.g. by can_allocate).
|
||||
if req_id in self.request_cached_ids:
|
||||
self.request_cached_ids[req_id].discard(input_id)
|
||||
if not self.request_cached_ids[req_id]:
|
||||
del self.request_cached_ids[req_id]
|
||||
# The mm_hash not in cache or the req_id set is empty
|
||||
if not self.cached.get(mm_hash, None):
|
||||
return
|
||||
@@ -248,8 +249,7 @@ class EncoderCacheManager:
|
||||
|
||||
Typically called when a request is finished, cancelled, or aborted.
|
||||
"""
|
||||
input_ids = self.get_cached_input_ids(request)
|
||||
for input_id in input_ids:
|
||||
for input_id in list(self.get_cached_input_ids(request)):
|
||||
self.free_encoder_input(request, input_id)
|
||||
|
||||
def get_freed_mm_hashes(self) -> list[str]:
|
||||
|
||||
@@ -29,6 +29,7 @@ from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
|
||||
)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.multimodal.encoder_budget import MultiModalBudget
|
||||
from vllm.multimodal.utils import get_mm_features_in_window
|
||||
from vllm.v1.core.encoder_cache_manager import (
|
||||
EncoderCacheManager,
|
||||
EncoderDecoderCacheManager,
|
||||
@@ -1163,22 +1164,23 @@ class Scheduler(SchedulerInterface):
|
||||
# trackers for accounting at the encoder input level.
|
||||
mm_hashes_to_schedule = set()
|
||||
num_embeds_to_schedule = 0
|
||||
for i, mm_feature in enumerate(mm_features):
|
||||
|
||||
lo, hi = get_mm_features_in_window(
|
||||
mm_features,
|
||||
start=num_computed_tokens,
|
||||
end=num_computed_tokens + num_new_tokens + shift_computed_tokens,
|
||||
)
|
||||
# For encoder-decoder, all inputs sit at start_pos=0, so lo=0 always.
|
||||
if self.is_encoder_decoder:
|
||||
lo = 0
|
||||
|
||||
for i in range(lo, hi):
|
||||
mm_feature = mm_features[i]
|
||||
start_pos = mm_feature.mm_position.offset
|
||||
num_encoder_tokens = mm_feature.mm_position.length
|
||||
num_encoder_embeds = mm_feature.mm_position.get_num_embeds()
|
||||
item_identifier = mm_feature.identifier
|
||||
|
||||
# The encoder output is needed if the two ranges overlap:
|
||||
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
|
||||
# [start_pos, start_pos + num_encoder_tokens)
|
||||
if (
|
||||
start_pos
|
||||
>= num_computed_tokens + num_new_tokens + shift_computed_tokens
|
||||
):
|
||||
# The encoder input is not needed in this step.
|
||||
break
|
||||
|
||||
if self.is_encoder_decoder and num_computed_tokens > 0:
|
||||
assert start_pos == 0, (
|
||||
"Encoder input should be processed at the beginning of "
|
||||
@@ -1194,10 +1196,6 @@ class Scheduler(SchedulerInterface):
|
||||
# decoder tokens (num_computed_tokens > 0), then we know we
|
||||
# already calculated encoder inputs and can skip here.
|
||||
continue
|
||||
elif start_pos + num_encoder_tokens <= num_computed_tokens:
|
||||
# The encoder input is already computed and stored
|
||||
# in the decoder's KV cache.
|
||||
continue
|
||||
|
||||
if not self.is_encoder_decoder:
|
||||
# We are not using the encoder cache for encoder-decoder models,
|
||||
|
||||
@@ -104,7 +104,7 @@ from vllm.multimodal.inputs import (
|
||||
MultiModalKwargsItem,
|
||||
PlaceholderRange,
|
||||
)
|
||||
from vllm.multimodal.utils import group_and_batch_mm_kwargs
|
||||
from vllm.multimodal.utils import get_mm_features_in_window, group_and_batch_mm_kwargs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingType
|
||||
@@ -3100,23 +3100,18 @@ class GPUModelRunner(
|
||||
req_state = self.requests[req_id]
|
||||
num_computed_tokens = req_state.num_computed_tokens + shift_computed_tokens
|
||||
|
||||
for mm_feature in req_state.mm_features:
|
||||
mm_features = req_state.mm_features
|
||||
lo, hi = get_mm_features_in_window(
|
||||
mm_features,
|
||||
start=num_computed_tokens,
|
||||
end=num_computed_tokens + num_scheduled_tokens,
|
||||
)
|
||||
for i in range(lo, hi):
|
||||
mm_feature = mm_features[i]
|
||||
pos_info = mm_feature.mm_position
|
||||
start_pos = pos_info.offset
|
||||
num_encoder_tokens = pos_info.length
|
||||
|
||||
# The encoder output is needed if the two ranges overlap:
|
||||
# [num_computed_tokens,
|
||||
# num_computed_tokens + num_scheduled_tokens) and
|
||||
# [start_pos, start_pos + num_encoder_tokens)
|
||||
if start_pos >= num_computed_tokens + num_scheduled_tokens:
|
||||
# The encoder output is not needed in this step.
|
||||
break
|
||||
if start_pos + num_encoder_tokens <= num_computed_tokens:
|
||||
# The encoder output is already processed and stored
|
||||
# in the decoder's KV cache.
|
||||
continue
|
||||
|
||||
start_idx = max(num_computed_tokens - start_pos, 0)
|
||||
end_idx = min(
|
||||
num_computed_tokens - start_pos + num_scheduled_tokens,
|
||||
|
||||
Reference in New Issue
Block a user