[MM][CG] Support ViT CG for Qwen2.5-VL (#40830)

Signed-off-by: John Calderon <jcalderon@nvidia.com>
This commit is contained in:
John Calderon
2026-05-01 23:10:14 -04:00
committed by GitHub
parent c408fdd663
commit 964a4bc2a5
5 changed files with 539 additions and 22 deletions
+2
View File
@@ -86,9 +86,11 @@ Models opt-in to encoder CUDA Graphs by implementing the [SupportsEncoderCudaGra
| Architecture | Models | CG for Image | CG for Video |
| ------------ | ------ | ------------ | ------------ |
| `Qwen3VLForConditionalGeneration` | `Qwen3-VL` | ✅︎ | ✅︎ |
| `Qwen2_5_VLForConditionalGeneration` | `Qwen2.5-VL` | ✅︎ | ✅︎ |
!!! note
Encoder CUDA Graphs have currently been tested with `--mm-encoder-attn-backend=FLASH_ATTN` and `--mm-encoder-attn-backend=FLASHINFER` on Blackwell GPUs.
For Qwen2.5-VL only FA2 and FA3 has been tested.
## Configuration
@@ -2466,6 +2466,7 @@ MODELS_NEED_VIDEO_METADATA = [
MODELS_SUPPORT_VIT_CUDA_GRAPH = [
"qwen3_vl",
"qwen3_vl_moe",
"qwen2_5_vl",
]
@@ -3,6 +3,7 @@
import pytest
from vllm.assets.image import ImageAsset
from vllm.multimodal.video import sample_frames_from_video
from ....conftest import VIDEO_ASSETS
@@ -11,6 +12,7 @@ models = ["Qwen/Qwen2.5-VL-3B-Instruct"]
target_dtype = "bfloat16"
VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>"
IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>"
def qwen2_5_vl_chat_template(*query):
@@ -28,6 +30,25 @@ VIDEO_PROMPTS = VIDEO_ASSETS.prompts(
)
WINDOW_ATTN_IMAGE_PROMPT = qwen2_5_vl_chat_template(
IMAGE_PLACEHOLDER,
"Describe the image.",
)
def _window_attention_regression_image():
# image from regression issue: https://github.com/vllm-project/vllm/issues/15122
image = ImageAsset("hato").pil_image
return image.resize((image.width // 2, image.height // 2))
def _encoder_cudagraph_config(*, max_vision_items: int) -> dict:
return {
"cudagraph_mm_encoder": True,
"encoder_cudagraph_max_vision_items_per_batch": max_vision_items,
}
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("video_pruning_rate", [0.0, 0.75])
@@ -146,3 +167,77 @@ def test_qwen2_5_vl_evs_batched_videos(
# Ensure the output is a string
assert isinstance(output_text, str)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
def test_qwen2_5_vl_window_attention_image(
vllm_runner,
model,
dtype: str,
max_tokens: int,
use_bytecode_hook: bool,
monkeypatch,
) -> None:
"""Regression test for Qwen2.5 window-attention image path."""
monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
prompt = [WINDOW_ATTN_IMAGE_PROMPT]
images = [[_window_attention_regression_image()]]
with vllm_runner(
model,
runner="generate",
max_model_len=4096,
dtype=dtype,
limit_mm_per_prompt={"image": 1},
compilation_config=_encoder_cudagraph_config(max_vision_items=1),
) as vllm_model:
outputs = vllm_model.generate_greedy(prompt, max_tokens, images=images)
assert len(outputs) == 1
output_ids, output_text = outputs[0]
assert len(output_ids) > 0
assert len(output_text) > 0
assert isinstance(output_text, str)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
def test_qwen2_5_vl_window_attention_image_batch(
vllm_runner,
model,
dtype: str,
max_tokens: int,
use_bytecode_hook: bool,
monkeypatch,
) -> None:
"""Regression test window-attention with a small image batch."""
monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
image = _window_attention_regression_image()
prompts = [WINDOW_ATTN_IMAGE_PROMPT, WINDOW_ATTN_IMAGE_PROMPT]
images = [[image], [image]]
with vllm_runner(
model,
runner="generate",
max_model_len=4096,
max_num_seqs=2,
dtype=dtype,
limit_mm_per_prompt={"image": 1},
compilation_config=_encoder_cudagraph_config(max_vision_items=2),
) as vllm_model:
outputs = vllm_model.generate_greedy(prompts, max_tokens, images=images)
assert len(outputs) == 2
for output_ids, output_text in outputs:
assert len(output_ids) > 0
assert len(output_text) > 0
assert isinstance(output_text, str)
@@ -54,7 +54,18 @@ MODEL_CONFIGS: dict[str, VitCudagraphTestConfig] = {
needs_video_metadata=True,
marks=[pytest.mark.core_model],
),
# TODO: Add more models below.
"qwen2_5_vl": VitCudagraphTestConfig(
model="Qwen/Qwen2.5-VL-3B-Instruct",
image_prompt=qwen_vl_chat_template(
"<|vision_start|><|image_pad|><|vision_end|>What is in this image?"
),
video_prompt=qwen_vl_chat_template(
"<|vision_start|><|video_pad|><|vision_end|>"
"Describe this video in one sentence."
),
needs_video_metadata=False,
marks=[pytest.mark.core_model],
),
}
+429 -21
View File
@@ -85,11 +85,13 @@ from vllm.sequence import IntermediateTensors
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.worker.encoder_cudagraph_defs import EncoderCudaGraphReplayBuffers
from .interfaces import (
MultiModalEmbeddings,
SupportsEagle,
SupportsEagle3,
SupportsEncoderCudaGraph,
SupportsLoRA,
SupportsMRoPE,
SupportsMultiModal,
@@ -771,22 +773,54 @@ class Qwen2_5_VisionTransformer(nn.Module):
inv[perm] = torch.arange(perm.numel(), device=perm.device, dtype=perm.dtype)
return inv
def forward(
def prepare_encoder_metadata(
self,
x: torch.Tensor,
grid_thw: list[list[int]],
) -> torch.Tensor:
*,
max_batch_size: int | None = None,
max_frames_per_batch: int | None = None,
max_window_seqs_per_batch: int | None = None,
max_seqlen_override: int | None = None,
max_seqlen_window_override: int | None = None,
device: torch.device | None = None,
) -> dict[str, torch.Tensor]:
"""Compute encoder metadata from grid_thw.
Shared by the eager forward path, CUDA graph capture, and
CUDA graph replay to avoid duplicated implementation.
Args:
grid_thw: Grid configurations as list of [t, h, w].
max_batch_size: If set, pad cu_seqlens to this size
(needed for CUDA graph capture/replay).
max_frames_per_batch: If set, overrides max_batch_size for
cu_seqlens padding. For video inputs each item contributes
T attention sequences (frames); this sizes the buffer to
the total frame budget so video replays never overflow.
max_window_seqs_per_batch: If set, pad cu_window_seqlens to this
number of window sequences. This keeps cu_window_seqlens shape
stable across capture/replay for CUDA graph safety.
max_seqlen_override: If set, use this value for max_seqlen
instead of computing from cu_seqlens (needed for CUDA
graph capture to cover worst-case replay scenarios).
max_seqlen_window_override: If set, use this value for
window-attention max_seqlen instead of computing from
cu_window_seqlens (needed for CUDA graph capture to
cover worst-case replay scenarios).
device: Device to place tensors on. Defaults to self.device.
"""
if device is None:
device = self.device
metadata: dict[str, torch.Tensor] = {}
# patchify
seq_len, _ = x.size()
rotary_pos_emb_cos = []
rotary_pos_emb_sin = []
window_index: list = []
cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int32)]
cu_seqlens: list = []
hidden_states = x.to(device=self.device, dtype=self.dtype)
hidden_states = self.patch_embed(hidden_states)
window_index_id = 0
cu_window_seqlens_last = 0
for t, h, w in grid_thw:
@@ -825,23 +859,99 @@ class Qwen2_5_VisionTransformer(nn.Module):
cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
# Pad cu_seqlens to the required number of sequences.
# For videos each item contributes T frames = T attention sequences,
# so the total can exceed max_batch_size. max_frames_per_batch
# overrides the pad target when set.
pad_to = (
max_frames_per_batch if max_frames_per_batch is not None else max_batch_size
)
if pad_to is not None:
num_seqs = len(cu_seqlens) - 1
if num_seqs < pad_to:
cu_seqlens = torch.cat(
(
cu_seqlens,
torch.full(
(pad_to - num_seqs,),
cu_seqlens[-1],
dtype=cu_seqlens.dtype,
device=cu_seqlens.device,
),
)
)
# Pad cu_window_seqlens to a stable number of window sequences.
# Like cu_seqlens, we repeat the last cumulative offset so padded
# entries represent empty sequences.
if max_window_seqs_per_batch is not None:
num_window_seqs = len(cu_window_seqlens) - 1
if num_window_seqs < max_window_seqs_per_batch:
cu_window_seqlens = torch.cat(
(
cu_window_seqlens,
torch.full(
(max_window_seqs_per_batch - num_window_seqs,),
cu_window_seqlens[-1],
dtype=cu_window_seqlens.dtype,
device=cu_window_seqlens.device,
),
)
)
# transformers
# pre-compute seqlens for window/full attn to reduce cuMemcpy operations
max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens)
max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens)
if max_seqlen_override is None:
max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens)
else:
max_seqlen_full = torch.tensor(max_seqlen_override, dtype=torch.int32)
if max_seqlen_window_override is None:
max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens)
else:
max_seqlen_window = torch.tensor(
max_seqlen_window_override, dtype=torch.int32
)
cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
cu_window_seqlens = cu_window_seqlens.to(device=self.device, non_blocking=True)
rotary_pos_emb_cos = rotary_pos_emb_cos.to(
device=self.device, non_blocking=True
)
rotary_pos_emb_sin = rotary_pos_emb_sin.to(
device=self.device, non_blocking=True
)
window_index = window_index.to(device=hidden_states.device, non_blocking=True)
reverse_indices = reverse_indices.to(
device=hidden_states.device, non_blocking=True
)
cu_seqlens = cu_seqlens.to(device=device, non_blocking=True)
cu_window_seqlens = cu_window_seqlens.to(device=device, non_blocking=True)
rotary_pos_emb_cos = rotary_pos_emb_cos.to(device=device, non_blocking=True)
rotary_pos_emb_sin = rotary_pos_emb_sin.to(device=device, non_blocking=True)
window_index = window_index.to(device=device, non_blocking=True)
reverse_indices = reverse_indices.to(device=device, non_blocking=True)
metadata["rotary_pos_emb_cos"] = rotary_pos_emb_cos
metadata["rotary_pos_emb_sin"] = rotary_pos_emb_sin
metadata["window_index"] = window_index
metadata["reverse_indices"] = reverse_indices
metadata["cu_seqlens"] = cu_seqlens
metadata["cu_window_seqlens"] = cu_window_seqlens
metadata["max_seqlen_full"] = max_seqlen_full
metadata["max_seqlen_window"] = max_seqlen_window
return metadata
def forward(
self,
x: torch.Tensor,
grid_thw: list[list[int]],
*,
encoder_metadata: dict[str, torch.Tensor] | None = None,
) -> torch.Tensor:
hidden_states = x.to(device=self.device, dtype=self.dtype)
hidden_states = self.patch_embed(hidden_states)
seq_len = hidden_states.shape[0]
if encoder_metadata is None:
encoder_metadata = self.prepare_encoder_metadata(grid_thw)
rotary_pos_emb_cos = encoder_metadata["rotary_pos_emb_cos"]
rotary_pos_emb_sin = encoder_metadata["rotary_pos_emb_sin"]
window_index = encoder_metadata["window_index"]
reverse_indices = encoder_metadata["reverse_indices"]
cu_seqlens = encoder_metadata["cu_seqlens"]
cu_window_seqlens = encoder_metadata["cu_window_seqlens"]
max_seqlen_full = encoder_metadata["max_seqlen_full"]
max_seqlen_window = encoder_metadata["max_seqlen_window"]
hidden_states = hidden_states.reshape(
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
@@ -1003,6 +1113,7 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
class Qwen2_5_VLForConditionalGeneration(
nn.Module,
SupportsMultiModal,
SupportsEncoderCudaGraph,
SupportsLoRA,
SupportsPP,
SupportsQuant,
@@ -1124,6 +1235,7 @@ class Qwen2_5_VLForConditionalGeneration(
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.config = config
self.model_config = vllm_config.model_config
self.vllm_config = vllm_config
self.multimodal_config = multimodal_config
self.video_pruning_rate = multimodal_config.video_pruning_rate
@@ -1447,6 +1559,302 @@ class Qwen2_5_VLForConditionalGeneration(
multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings
# -- SupportsEncoderCudaGraph protocol methods --
def get_encoder_cudagraph_config(self):
from vllm.v1.worker.encoder_cudagraph_defs import (
EncoderCudaGraphConfig,
)
# NOTE: With EVS pruning enabled, multimodal embeddings are post-processed
# (append positions for image and prune+append positions for video) in
# embed_multimodal(). The encoder CUDA graph path bypasses that postprocess
# hook, so disable CUDA graph for all modalities to avoid inconsistent
# embedding formats between eager and cudagraph paths.
modalities = [] if self.is_multimodal_pruning_enabled else ["image", "video"]
return EncoderCudaGraphConfig(
modalities=modalities,
input_key_by_modality={
"image": "pixel_values",
"video": "pixel_values_videos",
},
buffer_keys=[
"rotary_pos_emb_cos",
"rotary_pos_emb_sin",
"window_index",
"reverse_indices",
"cu_seqlens",
"cu_window_seqlens",
"max_seqlen_full",
"max_seqlen_window",
],
out_hidden_size=self.visual.out_hidden_size,
)
def get_input_modality(
self,
mm_kwargs: dict[str, Any],
) -> str:
if "image_grid_thw" in mm_kwargs:
return "image"
return "video"
def get_max_frames_per_video(self) -> int:
mm_registry = MULTIMODAL_REGISTRY
info = mm_registry.get_processing_info(self.model_config)
max_frames_per_video = info.get_num_frames_with_most_features(
seq_len=self.model_config.max_model_len,
mm_counts={"video": self.multimodal_config.get_limit_per_prompt("video")},
)
return max_frames_per_video
def get_encoder_cudagraph_budget_range(
self,
vllm_config: VllmConfig,
) -> tuple[int, int]:
# Min: estimated smallest possible encoder input.
# 224x224 image → 16x16 patches (patch_size=14)
# spatial_merge_size=2 → 8x8 = 64 tokens
min_budget = 64
# Max: capped by max_num_batched_tokens
max_budget = min(
vllm_config.scheduler_config.max_num_batched_tokens,
self.model_config.max_model_len,
)
return (min_budget, max_budget)
def _get_pixel_values_by_modality(
self,
mm_kwargs: dict[str, Any],
) -> torch.Tensor:
if self.get_input_modality(mm_kwargs) == "image":
pixel_values = mm_kwargs["pixel_values"]
else:
pixel_values = mm_kwargs["pixel_values_videos"]
return pixel_values
def _get_grid_thw_by_modality(
self,
mm_kwargs: dict[str, Any],
) -> list[tuple[int, int, int]]:
grid_thw_key = f"{self.get_input_modality(mm_kwargs)}_grid_thw"
grid_thw = mm_kwargs[grid_thw_key]
if not isinstance(grid_thw, list):
grid_thw = grid_thw.tolist()
return grid_thw
def get_encoder_cudagraph_num_items(
self,
mm_kwargs: dict[str, Any],
) -> int:
return len(self._get_grid_thw_by_modality(mm_kwargs))
def get_encoder_cudagraph_per_item_output_tokens(
self,
mm_kwargs: dict[str, Any],
) -> list[int]:
m = self.visual.spatial_merge_size
grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
return [t * (h // m) * (w // m) for t, h, w in grid_thw]
def get_encoder_cudagraph_per_item_input_sizes(
self,
mm_kwargs: dict[str, Any],
) -> list[int]:
grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
return [t * h * w for t, h, w in grid_thw]
def select_encoder_cudagraph_items(
self,
mm_kwargs: dict[str, Any],
indices: list[int],
) -> dict[str, Any]:
grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
pixel_values = self._get_pixel_values_by_modality(mm_kwargs)
if len(indices) == 0:
if self.get_input_modality(mm_kwargs) == "image":
return {
"pixel_values": pixel_values[:0],
"image_grid_thw": [],
}
else:
return {
"pixel_values_videos": pixel_values[:0],
"video_grid_thw": [],
}
# Compute cumulative patch offsets for slicing pixel_values
patches_per_item = [t * h * w for t, h, w in grid_thw]
cum_patches = [0]
for p in patches_per_item:
cum_patches.append(cum_patches[-1] + p)
selected_pv = torch.cat(
[pixel_values[cum_patches[i] : cum_patches[i + 1]] for i in indices]
)
selected_grid = [grid_thw[i] for i in indices]
if self.get_input_modality(mm_kwargs) == "image":
return {
"pixel_values": selected_pv,
"image_grid_thw": selected_grid,
}
else:
return {
"pixel_values_videos": selected_pv,
"video_grid_thw": selected_grid,
}
def prepare_encoder_cudagraph_capture_inputs(
self,
token_budget: int,
max_batch_size: int,
max_frames_per_batch: int,
device: torch.device,
dtype: torch.dtype,
):
from vllm.v1.worker.encoder_cudagraph_defs import (
EncoderCudaGraphCaptureInputs,
)
spatial_merge_size = self.visual.spatial_merge_size
max_window_seqs_per_batch = min(
self.vllm_config.scheduler_config.max_num_batched_tokens,
self.model_config.max_model_len,
)
# Use ceil here (not floor) so total captured capacity is never smaller
# than token_budget when token_budget is not divisible by max_batch_size
# (e.g., 324 budget with max_batch_size=8). Floor under-allocates
# input_buffer and can fail replay copy for valid single-item batches.
per_mm_item_output = (token_budget + max_batch_size - 1) // max_batch_size
frames_per_item = max_frames_per_batch // max_batch_size
if frames_per_item > 1:
# Build the capture grid using a video-format layout so that
# cu_seqlens is sized for video replays from the start.
# cu_seqlens has one entry per attention sequence (one per frame),
# so using T > 1 per item makes the buffer large enough without
# relying solely on padding.
# Ceiling ensures frames_per_item * tokens_per_frame >= per_mm_item_output
# so the pixel_values buffer covers any valid single-item replay.
tokens_per_frame = (
per_mm_item_output + frames_per_item - 1
) // frames_per_item
# Video-format grid_config (T=frames_per_item).
grid_config = [
[
frames_per_item,
spatial_merge_size,
tokens_per_frame * spatial_merge_size,
]
for _ in range(max_batch_size)
]
else:
# Image-format grid_config (T=1).
grid_config = [
[1, spatial_merge_size, per_mm_item_output * spatial_merge_size]
for _ in range(max_batch_size)
]
# Create dummy pixel_values
patch_embed = self.visual.patch_embed
in_channels = patch_embed.proj.in_channels
patch_size = patch_embed.patch_size
temporal_patch_size = patch_embed.temporal_patch_size
total_patches = sum(t * h * w for t, h, w in grid_config)
flattened_patch_size = (
in_channels * temporal_patch_size * patch_size * patch_size
)
dummy_pixel_values = torch.randn(
total_patches, flattened_patch_size, device=device, dtype=dtype
)
# Override max_seqlen with a safe upper bound for capture.
# max_seqlen.item() gets baked into the CUDA graph (not replayed),
# so the capture value must cover any replay scenario.
# Worst case: 1 item consuming the full budget ->
# seq_len = token_budget * spatial_merge_size^2.
# For window-attention, each local window is bounded by fixed geometry:
# (window_size / patch_size / spatial_merge_size)^2 windows in merged
# token space, multiplied by spatial_merge_size^2 to map back to the
# unmerged sequence length used by attention kernels.
vit_merger_window_size = (
self.visual.window_size
// self.visual.spatial_merge_size
// self.visual.patch_size
)
max_seqlen_window_override = vit_merger_window_size**2 * (spatial_merge_size**2)
buffers = self.visual.prepare_encoder_metadata(
grid_config,
max_batch_size=max_batch_size,
max_frames_per_batch=max_frames_per_batch,
max_window_seqs_per_batch=max_window_seqs_per_batch,
max_seqlen_override=token_budget * (spatial_merge_size**2),
max_seqlen_window_override=max_seqlen_window_override,
device=device,
)
# Just use image-modality dummy input_buffer for capturing, since it's also
# compatible for video inputs (has the same shape: [num_patches, C*T*P*P]).
mm_kwargs = {
"pixel_values": dummy_pixel_values,
"image_grid_thw": grid_config,
}
return EncoderCudaGraphCaptureInputs(
mm_kwargs=mm_kwargs,
buffers=buffers,
)
def prepare_encoder_cudagraph_replay_buffers(
self,
mm_kwargs: dict[str, Any],
max_batch_size: int,
max_frames_per_batch: int,
):
modality = self.get_input_modality(mm_kwargs)
grid_thw_list = self._get_grid_thw_by_modality(mm_kwargs)
if modality == "image":
buffers = self.visual.prepare_encoder_metadata(
grid_thw_list,
max_batch_size=max_batch_size,
max_window_seqs_per_batch=min(
self.vllm_config.scheduler_config.max_num_batched_tokens,
self.model_config.max_model_len,
),
)
else:
buffers = self.visual.prepare_encoder_metadata(
grid_thw_list,
max_frames_per_batch=max_frames_per_batch,
max_window_seqs_per_batch=min(
self.vllm_config.scheduler_config.max_num_batched_tokens,
self.model_config.max_model_len,
),
)
return EncoderCudaGraphReplayBuffers(buffers=buffers)
def encoder_cudagraph_forward(
self,
mm_kwargs: dict[str, Any],
buffers: dict[str, torch.Tensor],
) -> torch.Tensor:
pixel_values = self._get_pixel_values_by_modality(mm_kwargs)
grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
return self.visual(pixel_values, grid_thw, encoder_metadata=buffers)
def encoder_eager_forward(
self,
mm_kwargs: dict[str, Any],
) -> torch.Tensor:
pixel_values = self._get_pixel_values_by_modality(mm_kwargs)
grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
return self.visual(pixel_values, grid_thw)
def forward(
self,
input_ids: torch.Tensor | None,