diff --git a/docs/design/cuda_graphs_multimodal.md b/docs/design/cuda_graphs_multimodal.md index e32010232ef..f44ef359df3 100644 --- a/docs/design/cuda_graphs_multimodal.md +++ b/docs/design/cuda_graphs_multimodal.md @@ -86,9 +86,11 @@ Models opt-in to encoder CUDA Graphs by implementing the [SupportsEncoderCudaGra | Architecture | Models | CG for Image | CG for Video | | ------------ | ------ | ------------ | ------------ | | `Qwen3VLForConditionalGeneration` | `Qwen3-VL` | ✅︎ | ✅︎ | +| `Qwen2_5_VLForConditionalGeneration` | `Qwen2.5-VL` | ✅︎ | ✅︎ | !!! note Encoder CUDA Graphs have currently been tested with `--mm-encoder-attn-backend=FLASH_ATTN` and `--mm-encoder-attn-backend=FLASHINFER` on Blackwell GPUs. + For Qwen2.5-VL, only FA2 and FA3 have been tested. ## Configuration diff --git a/examples/generate/multimodal/vision_language_offline.py b/examples/generate/multimodal/vision_language_offline.py index 87d42c036ec..794f20dd0a5 100644 --- a/examples/generate/multimodal/vision_language_offline.py +++ b/examples/generate/multimodal/vision_language_offline.py @@ -2466,6 +2466,7 @@ MODELS_NEED_VIDEO_METADATA = [ MODELS_SUPPORT_VIT_CUDA_GRAPH = [ "qwen3_vl", "qwen3_vl_moe", + "qwen2_5_vl", ] diff --git a/tests/models/multimodal/generation/test_qwen2_5_vl.py b/tests/models/multimodal/generation/test_qwen2_5_vl.py index 3ba665710af..791bb3b3088 100644 --- a/tests/models/multimodal/generation/test_qwen2_5_vl.py +++ b/tests/models/multimodal/generation/test_qwen2_5_vl.py @@ -3,6 +3,7 @@ import pytest +from vllm.assets.image import ImageAsset from vllm.multimodal.video import sample_frames_from_video from ....conftest import VIDEO_ASSETS @@ -11,6 +12,7 @@ models = ["Qwen/Qwen2.5-VL-3B-Instruct"] target_dtype = "bfloat16" VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>" +IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>" def qwen2_5_vl_chat_template(*query): @@ -28,6 +30,25 @@ VIDEO_PROMPTS = VIDEO_ASSETS.prompts( ) +WINDOW_ATTN_IMAGE_PROMPT 
= qwen2_5_vl_chat_template( + IMAGE_PLACEHOLDER, + "Describe the image.", +) + + +def _window_attention_regression_image(): + # image from regression issue: https://github.com/vllm-project/vllm/issues/15122 + image = ImageAsset("hato").pil_image + return image.resize((image.width // 2, image.height // 2)) + + +def _encoder_cudagraph_config(*, max_vision_items: int) -> dict: + return { + "cudagraph_mm_encoder": True, + "encoder_cudagraph_max_vision_items_per_batch": max_vision_items, + } + + @pytest.mark.core_model @pytest.mark.parametrize("model", models) @pytest.mark.parametrize("video_pruning_rate", [0.0, 0.75]) @@ -146,3 +167,77 @@ def test_qwen2_5_vl_evs_batched_videos( # Ensure the output is a string assert isinstance(output_text, str) + + +@pytest.mark.core_model +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("use_bytecode_hook", [True, False]) +def test_qwen2_5_vl_window_attention_image( + vllm_runner, + model, + dtype: str, + max_tokens: int, + use_bytecode_hook: bool, + monkeypatch, +) -> None: + """Regression test for Qwen2.5 window-attention image path.""" + monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0") + + prompt = [WINDOW_ATTN_IMAGE_PROMPT] + images = [[_window_attention_regression_image()]] + + with vllm_runner( + model, + runner="generate", + max_model_len=4096, + dtype=dtype, + limit_mm_per_prompt={"image": 1}, + compilation_config=_encoder_cudagraph_config(max_vision_items=1), + ) as vllm_model: + outputs = vllm_model.generate_greedy(prompt, max_tokens, images=images) + + assert len(outputs) == 1 + output_ids, output_text = outputs[0] + assert len(output_ids) > 0 + assert len(output_text) > 0 + assert isinstance(output_text, str) + + +@pytest.mark.core_model +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) 
+@pytest.mark.parametrize("use_bytecode_hook", [True, False]) +def test_qwen2_5_vl_window_attention_image_batch( + vllm_runner, + model, + dtype: str, + max_tokens: int, + use_bytecode_hook: bool, + monkeypatch, +) -> None: + """Regression test window-attention with a small image batch.""" + monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0") + + image = _window_attention_regression_image() + prompts = [WINDOW_ATTN_IMAGE_PROMPT, WINDOW_ATTN_IMAGE_PROMPT] + images = [[image], [image]] + + with vllm_runner( + model, + runner="generate", + max_model_len=4096, + max_num_seqs=2, + dtype=dtype, + limit_mm_per_prompt={"image": 1}, + compilation_config=_encoder_cudagraph_config(max_vision_items=2), + ) as vllm_model: + outputs = vllm_model.generate_greedy(prompts, max_tokens, images=images) + + assert len(outputs) == 2 + for output_ids, output_text in outputs: + assert len(output_ids) > 0 + assert len(output_text) > 0 + assert isinstance(output_text, str) diff --git a/tests/models/multimodal/generation/test_vit_cudagraph.py b/tests/models/multimodal/generation/test_vit_cudagraph.py index 7adea0771b6..fb7bdfc8625 100644 --- a/tests/models/multimodal/generation/test_vit_cudagraph.py +++ b/tests/models/multimodal/generation/test_vit_cudagraph.py @@ -54,7 +54,18 @@ MODEL_CONFIGS: dict[str, VitCudagraphTestConfig] = { needs_video_metadata=True, marks=[pytest.mark.core_model], ), - # TODO: Add more models below. + "qwen2_5_vl": VitCudagraphTestConfig( + model="Qwen/Qwen2.5-VL-3B-Instruct", + image_prompt=qwen_vl_chat_template( + "<|vision_start|><|image_pad|><|vision_end|>What is in this image?" + ), + video_prompt=qwen_vl_chat_template( + "<|vision_start|><|video_pad|><|vision_end|>" + "Describe this video in one sentence." 
+ ), + needs_video_metadata=False, + marks=[pytest.mark.core_model], + ), } diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index c11684b4b89..54334c91bfa 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -85,11 +85,13 @@ from vllm.sequence import IntermediateTensors from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.v1.attention.backends.registry import AttentionBackendEnum +from vllm.v1.worker.encoder_cudagraph_defs import EncoderCudaGraphReplayBuffers from .interfaces import ( MultiModalEmbeddings, SupportsEagle, SupportsEagle3, + SupportsEncoderCudaGraph, SupportsLoRA, SupportsMRoPE, SupportsMultiModal, @@ -771,22 +773,54 @@ class Qwen2_5_VisionTransformer(nn.Module): inv[perm] = torch.arange(perm.numel(), device=perm.device, dtype=perm.dtype) return inv - def forward( + def prepare_encoder_metadata( self, - x: torch.Tensor, grid_thw: list[list[int]], - ) -> torch.Tensor: + *, + max_batch_size: int | None = None, + max_frames_per_batch: int | None = None, + max_window_seqs_per_batch: int | None = None, + max_seqlen_override: int | None = None, + max_seqlen_window_override: int | None = None, + device: torch.device | None = None, + ) -> dict[str, torch.Tensor]: + """Compute encoder metadata from grid_thw. + + Shared by the eager forward path, CUDA graph capture, and + CUDA graph replay to avoid duplicated implementation. + + Args: + grid_thw: Grid configurations as list of [t, h, w]. + max_batch_size: If set, pad cu_seqlens to this size + (needed for CUDA graph capture/replay). + max_frames_per_batch: If set, overrides max_batch_size for + cu_seqlens padding. For video inputs each item contributes + T attention sequences (frames); this sizes the buffer to + the total frame budget so video replays never overflow. 
+ max_window_seqs_per_batch: If set, pad cu_window_seqlens to this + number of window sequences. This keeps cu_window_seqlens shape + stable across capture/replay for CUDA graph safety. + max_seqlen_override: If set, use this value for max_seqlen + instead of computing from cu_seqlens (needed for CUDA + graph capture to cover worst-case replay scenarios). + max_seqlen_window_override: If set, use this value for + window-attention max_seqlen instead of computing from + cu_window_seqlens (needed for CUDA graph capture to + cover worst-case replay scenarios). + device: Device to place tensors on. Defaults to self.device. + """ + + if device is None: + device = self.device + metadata: dict[str, torch.Tensor] = {} + # patchify - seq_len, _ = x.size() rotary_pos_emb_cos = [] rotary_pos_emb_sin = [] window_index: list = [] cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int32)] cu_seqlens: list = [] - hidden_states = x.to(device=self.device, dtype=self.dtype) - hidden_states = self.patch_embed(hidden_states) - window_index_id = 0 cu_window_seqlens_last = 0 for t, h, w in grid_thw: @@ -825,23 +859,99 @@ class Qwen2_5_VisionTransformer(nn.Module): cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32) cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) + # Pad cu_seqlens to the required number of sequences. + # For videos each item contributes T frames = T attention sequences, + # so the total can exceed max_batch_size. max_frames_per_batch + # overrides the pad target when set. + pad_to = ( + max_frames_per_batch if max_frames_per_batch is not None else max_batch_size + ) + if pad_to is not None: + num_seqs = len(cu_seqlens) - 1 + if num_seqs < pad_to: + cu_seqlens = torch.cat( + ( + cu_seqlens, + torch.full( + (pad_to - num_seqs,), + cu_seqlens[-1], + dtype=cu_seqlens.dtype, + device=cu_seqlens.device, + ), + ) + ) + + # Pad cu_window_seqlens to a stable number of window sequences. 
+ # Like cu_seqlens, we repeat the last cumulative offset so padded + # entries represent empty sequences. + if max_window_seqs_per_batch is not None: + num_window_seqs = len(cu_window_seqlens) - 1 + if num_window_seqs < max_window_seqs_per_batch: + cu_window_seqlens = torch.cat( + ( + cu_window_seqlens, + torch.full( + (max_window_seqs_per_batch - num_window_seqs,), + cu_window_seqlens[-1], + dtype=cu_window_seqlens.dtype, + device=cu_window_seqlens.device, + ), + ) + ) + # transformers # pre-compute seqlens for window/full attn to reduce cuMemcpy operations - max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens) - max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens) + if max_seqlen_override is None: + max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens) + else: + max_seqlen_full = torch.tensor(max_seqlen_override, dtype=torch.int32) + if max_seqlen_window_override is None: + max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens) + else: + max_seqlen_window = torch.tensor( + max_seqlen_window_override, dtype=torch.int32 + ) - cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True) - cu_window_seqlens = cu_window_seqlens.to(device=self.device, non_blocking=True) - rotary_pos_emb_cos = rotary_pos_emb_cos.to( - device=self.device, non_blocking=True - ) - rotary_pos_emb_sin = rotary_pos_emb_sin.to( - device=self.device, non_blocking=True - ) - window_index = window_index.to(device=hidden_states.device, non_blocking=True) - reverse_indices = reverse_indices.to( - device=hidden_states.device, non_blocking=True - ) + cu_seqlens = cu_seqlens.to(device=device, non_blocking=True) + cu_window_seqlens = cu_window_seqlens.to(device=device, non_blocking=True) + rotary_pos_emb_cos = rotary_pos_emb_cos.to(device=device, non_blocking=True) + rotary_pos_emb_sin = rotary_pos_emb_sin.to(device=device, non_blocking=True) + window_index = window_index.to(device=device, non_blocking=True) + reverse_indices = 
reverse_indices.to(device=device, non_blocking=True) + + metadata["rotary_pos_emb_cos"] = rotary_pos_emb_cos + metadata["rotary_pos_emb_sin"] = rotary_pos_emb_sin + metadata["window_index"] = window_index + metadata["reverse_indices"] = reverse_indices + metadata["cu_seqlens"] = cu_seqlens + metadata["cu_window_seqlens"] = cu_window_seqlens + metadata["max_seqlen_full"] = max_seqlen_full + metadata["max_seqlen_window"] = max_seqlen_window + + return metadata + + def forward( + self, + x: torch.Tensor, + grid_thw: list[list[int]], + *, + encoder_metadata: dict[str, torch.Tensor] | None = None, + ) -> torch.Tensor: + hidden_states = x.to(device=self.device, dtype=self.dtype) + hidden_states = self.patch_embed(hidden_states) + + seq_len = hidden_states.shape[0] + if encoder_metadata is None: + encoder_metadata = self.prepare_encoder_metadata(grid_thw) + + rotary_pos_emb_cos = encoder_metadata["rotary_pos_emb_cos"] + rotary_pos_emb_sin = encoder_metadata["rotary_pos_emb_sin"] + window_index = encoder_metadata["window_index"] + reverse_indices = encoder_metadata["reverse_indices"] + cu_seqlens = encoder_metadata["cu_seqlens"] + cu_window_seqlens = encoder_metadata["cu_window_seqlens"] + max_seqlen_full = encoder_metadata["max_seqlen_full"] + max_seqlen_window = encoder_metadata["max_seqlen_window"] hidden_states = hidden_states.reshape( seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 @@ -1003,6 +1113,7 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor): class Qwen2_5_VLForConditionalGeneration( nn.Module, SupportsMultiModal, + SupportsEncoderCudaGraph, SupportsLoRA, SupportsPP, SupportsQuant, @@ -1124,6 +1235,7 @@ class Qwen2_5_VLForConditionalGeneration( self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.config = config + self.model_config = vllm_config.model_config self.vllm_config = vllm_config self.multimodal_config = multimodal_config self.video_pruning_rate = multimodal_config.video_pruning_rate @@ -1447,6 
+1559,302 @@ class Qwen2_5_VLForConditionalGeneration( multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings + # -- SupportsEncoderCudaGraph protocol methods -- + + def get_encoder_cudagraph_config(self): + from vllm.v1.worker.encoder_cudagraph_defs import ( + EncoderCudaGraphConfig, + ) + + # NOTE: With EVS pruning enabled, multimodal embeddings are post-processed + # (append positions for image and prune+append positions for video) in + # embed_multimodal(). The encoder CUDA graph path bypasses that postprocess + # hook, so disable CUDA graph for all modalities to avoid inconsistent + # embedding formats between eager and cudagraph paths. + modalities = [] if self.is_multimodal_pruning_enabled else ["image", "video"] + + return EncoderCudaGraphConfig( + modalities=modalities, + input_key_by_modality={ + "image": "pixel_values", + "video": "pixel_values_videos", + }, + buffer_keys=[ + "rotary_pos_emb_cos", + "rotary_pos_emb_sin", + "window_index", + "reverse_indices", + "cu_seqlens", + "cu_window_seqlens", + "max_seqlen_full", + "max_seqlen_window", + ], + out_hidden_size=self.visual.out_hidden_size, + ) + + def get_input_modality( + self, + mm_kwargs: dict[str, Any], + ) -> str: + if "image_grid_thw" in mm_kwargs: + return "image" + return "video" + + def get_max_frames_per_video(self) -> int: + mm_registry = MULTIMODAL_REGISTRY + info = mm_registry.get_processing_info(self.model_config) + max_frames_per_video = info.get_num_frames_with_most_features( + seq_len=self.model_config.max_model_len, + mm_counts={"video": self.multimodal_config.get_limit_per_prompt("video")}, + ) + return max_frames_per_video + + def get_encoder_cudagraph_budget_range( + self, + vllm_config: VllmConfig, + ) -> tuple[int, int]: + # Min: estimated smallest possible encoder input. 
+ # 224x224 image → 16x16 patches (patch_size=14) + # spatial_merge_size=2 → 8x8 = 64 tokens + min_budget = 64 + # Max: capped by max_num_batched_tokens + max_budget = min( + vllm_config.scheduler_config.max_num_batched_tokens, + self.model_config.max_model_len, + ) + return (min_budget, max_budget) + + def _get_pixel_values_by_modality( + self, + mm_kwargs: dict[str, Any], + ) -> torch.Tensor: + if self.get_input_modality(mm_kwargs) == "image": + pixel_values = mm_kwargs["pixel_values"] + else: + pixel_values = mm_kwargs["pixel_values_videos"] + return pixel_values + + def _get_grid_thw_by_modality( + self, + mm_kwargs: dict[str, Any], + ) -> list[tuple[int, int, int]]: + grid_thw_key = f"{self.get_input_modality(mm_kwargs)}_grid_thw" + grid_thw = mm_kwargs[grid_thw_key] + if not isinstance(grid_thw, list): + grid_thw = grid_thw.tolist() + return grid_thw + + def get_encoder_cudagraph_num_items( + self, + mm_kwargs: dict[str, Any], + ) -> int: + return len(self._get_grid_thw_by_modality(mm_kwargs)) + + def get_encoder_cudagraph_per_item_output_tokens( + self, + mm_kwargs: dict[str, Any], + ) -> list[int]: + m = self.visual.spatial_merge_size + grid_thw = self._get_grid_thw_by_modality(mm_kwargs) + return [t * (h // m) * (w // m) for t, h, w in grid_thw] + + def get_encoder_cudagraph_per_item_input_sizes( + self, + mm_kwargs: dict[str, Any], + ) -> list[int]: + grid_thw = self._get_grid_thw_by_modality(mm_kwargs) + return [t * h * w for t, h, w in grid_thw] + + def select_encoder_cudagraph_items( + self, + mm_kwargs: dict[str, Any], + indices: list[int], + ) -> dict[str, Any]: + grid_thw = self._get_grid_thw_by_modality(mm_kwargs) + pixel_values = self._get_pixel_values_by_modality(mm_kwargs) + + if len(indices) == 0: + if self.get_input_modality(mm_kwargs) == "image": + return { + "pixel_values": pixel_values[:0], + "image_grid_thw": [], + } + else: + return { + "pixel_values_videos": pixel_values[:0], + "video_grid_thw": [], + } + + # Compute cumulative patch 
offsets for slicing pixel_values + patches_per_item = [t * h * w for t, h, w in grid_thw] + cum_patches = [0] + for p in patches_per_item: + cum_patches.append(cum_patches[-1] + p) + + selected_pv = torch.cat( + [pixel_values[cum_patches[i] : cum_patches[i + 1]] for i in indices] + ) + selected_grid = [grid_thw[i] for i in indices] + + if self.get_input_modality(mm_kwargs) == "image": + return { + "pixel_values": selected_pv, + "image_grid_thw": selected_grid, + } + else: + return { + "pixel_values_videos": selected_pv, + "video_grid_thw": selected_grid, + } + + def prepare_encoder_cudagraph_capture_inputs( + self, + token_budget: int, + max_batch_size: int, + max_frames_per_batch: int, + device: torch.device, + dtype: torch.dtype, + ): + from vllm.v1.worker.encoder_cudagraph_defs import ( + EncoderCudaGraphCaptureInputs, + ) + + spatial_merge_size = self.visual.spatial_merge_size + max_window_seqs_per_batch = min( + self.vllm_config.scheduler_config.max_num_batched_tokens, + self.model_config.max_model_len, + ) + # Use ceil here (not floor) so total captured capacity is never smaller + # than token_budget when token_budget is not divisible by max_batch_size + # (e.g., 324 budget with max_batch_size=8). Floor under-allocates + # input_buffer and can fail replay copy for valid single-item batches. + per_mm_item_output = (token_budget + max_batch_size - 1) // max_batch_size + + frames_per_item = max_frames_per_batch // max_batch_size + if frames_per_item > 1: + # Build the capture grid using a video-format layout so that + # cu_seqlens is sized for video replays from the start. + # cu_seqlens has one entry per attention sequence (one per frame), + # so using T > 1 per item makes the buffer large enough without + # relying solely on padding. + # Ceiling ensures frames_per_item * tokens_per_frame >= per_mm_item_output + # so the pixel_values buffer covers any valid single-item replay. 
+ tokens_per_frame = ( + per_mm_item_output + frames_per_item - 1 + ) // frames_per_item + # Video-format grid_config (T=frames_per_item). + grid_config = [ + [ + frames_per_item, + spatial_merge_size, + tokens_per_frame * spatial_merge_size, + ] + for _ in range(max_batch_size) + ] + else: + # Image-format grid_config (T=1). + grid_config = [ + [1, spatial_merge_size, per_mm_item_output * spatial_merge_size] + for _ in range(max_batch_size) + ] + + # Create dummy pixel_values + patch_embed = self.visual.patch_embed + in_channels = patch_embed.proj.in_channels + patch_size = patch_embed.patch_size + temporal_patch_size = patch_embed.temporal_patch_size + total_patches = sum(t * h * w for t, h, w in grid_config) + flattened_patch_size = ( + in_channels * temporal_patch_size * patch_size * patch_size + ) + dummy_pixel_values = torch.randn( + total_patches, flattened_patch_size, device=device, dtype=dtype + ) + + # Override max_seqlen with a safe upper bound for capture. + # max_seqlen.item() gets baked into the CUDA graph (not replayed), + # so the capture value must cover any replay scenario. + # Worst case: 1 item consuming the full budget -> + # seq_len = token_budget * spatial_merge_size^2. + # For window-attention, each local window is bounded by fixed geometry: + # (window_size / patch_size / spatial_merge_size)^2 windows in merged + # token space, multiplied by spatial_merge_size^2 to map back to the + # unmerged sequence length used by attention kernels. 
+ vit_merger_window_size = ( + self.visual.window_size + // self.visual.spatial_merge_size + // self.visual.patch_size + ) + max_seqlen_window_override = vit_merger_window_size**2 * (spatial_merge_size**2) + buffers = self.visual.prepare_encoder_metadata( + grid_config, + max_batch_size=max_batch_size, + max_frames_per_batch=max_frames_per_batch, + max_window_seqs_per_batch=max_window_seqs_per_batch, + max_seqlen_override=token_budget * (spatial_merge_size**2), + max_seqlen_window_override=max_seqlen_window_override, + device=device, + ) + + # Just use image-modality dummy input_buffer for capturing, since it's also + # compatible for video inputs (has the same shape: [num_patches, C*T*P*P]). + mm_kwargs = { + "pixel_values": dummy_pixel_values, + "image_grid_thw": grid_config, + } + + return EncoderCudaGraphCaptureInputs( + mm_kwargs=mm_kwargs, + buffers=buffers, + ) + + def prepare_encoder_cudagraph_replay_buffers( + self, + mm_kwargs: dict[str, Any], + max_batch_size: int, + max_frames_per_batch: int, + ): + modality = self.get_input_modality(mm_kwargs) + grid_thw_list = self._get_grid_thw_by_modality(mm_kwargs) + + if modality == "image": + buffers = self.visual.prepare_encoder_metadata( + grid_thw_list, + max_batch_size=max_batch_size, + max_window_seqs_per_batch=min( + self.vllm_config.scheduler_config.max_num_batched_tokens, + self.model_config.max_model_len, + ), + ) + else: + buffers = self.visual.prepare_encoder_metadata( + grid_thw_list, + max_frames_per_batch=max_frames_per_batch, + max_window_seqs_per_batch=min( + self.vllm_config.scheduler_config.max_num_batched_tokens, + self.model_config.max_model_len, + ), + ) + + return EncoderCudaGraphReplayBuffers(buffers=buffers) + + def encoder_cudagraph_forward( + self, + mm_kwargs: dict[str, Any], + buffers: dict[str, torch.Tensor], + ) -> torch.Tensor: + pixel_values = self._get_pixel_values_by_modality(mm_kwargs) + grid_thw = self._get_grid_thw_by_modality(mm_kwargs) + return self.visual(pixel_values, 
grid_thw, encoder_metadata=buffers) + + def encoder_eager_forward( + self, + mm_kwargs: dict[str, Any], + ) -> torch.Tensor: + pixel_values = self._get_pixel_values_by_modality(mm_kwargs) + grid_thw = self._get_grid_thw_by_modality(mm_kwargs) + return self.visual(pixel_values, grid_thw) + def forward( self, input_ids: torch.Tensor | None,