mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[MM][CG] Support ViT CG for Qwen2.5-VL (#40830)
Signed-off-by: John Calderon <jcalderon@nvidia.com>
This commit is contained in:
@@ -86,9 +86,11 @@ Models opt-in to encoder CUDA Graphs by implementing the [SupportsEncoderCudaGra
|
||||
| Architecture | Models | CG for Image | CG for Video |
|
||||
| ------------ | ------ | ------------ | ------------ |
|
||||
| `Qwen3VLForConditionalGeneration` | `Qwen3-VL` | ✅︎ | ✅︎ |
|
||||
| `Qwen2_5_VLForConditionalGeneration` | `Qwen2.5-VL` | ✅︎ | ✅︎ |
|
||||
|
||||
!!! note
|
||||
Encoder CUDA Graphs have currently been tested with `--mm-encoder-attn-backend=FLASH_ATTN` and `--mm-encoder-attn-backend=FLASHINFER` on Blackwell GPUs.
|
||||
For Qwen2.5-VL, only FA2 and FA3 have been tested.
|
||||
|
||||
## Configuration
|
||||
|
||||
|
||||
@@ -2466,6 +2466,7 @@ MODELS_NEED_VIDEO_METADATA = [
|
||||
MODELS_SUPPORT_VIT_CUDA_GRAPH = [
|
||||
"qwen3_vl",
|
||||
"qwen3_vl_moe",
|
||||
"qwen2_5_vl",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.multimodal.video import sample_frames_from_video
|
||||
|
||||
from ....conftest import VIDEO_ASSETS
|
||||
@@ -11,6 +12,7 @@ models = ["Qwen/Qwen2.5-VL-3B-Instruct"]
|
||||
target_dtype = "bfloat16"
|
||||
|
||||
VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>"
|
||||
IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
|
||||
|
||||
def qwen2_5_vl_chat_template(*query):
|
||||
@@ -28,6 +30,25 @@ VIDEO_PROMPTS = VIDEO_ASSETS.prompts(
|
||||
)
|
||||
|
||||
|
||||
WINDOW_ATTN_IMAGE_PROMPT = qwen2_5_vl_chat_template(
|
||||
IMAGE_PLACEHOLDER,
|
||||
"Describe the image.",
|
||||
)
|
||||
|
||||
|
||||
def _window_attention_regression_image():
    """Return the "hato" image asset resized to half its width and height.

    Image from regression issue:
    https://github.com/vllm-project/vllm/issues/15122
    """
    source = ImageAsset("hato").pil_image
    half_size = (source.width // 2, source.height // 2)
    return source.resize(half_size)
|
||||
|
||||
|
||||
def _encoder_cudagraph_config(*, max_vision_items: int) -> dict:
|
||||
return {
|
||||
"cudagraph_mm_encoder": True,
|
||||
"encoder_cudagraph_max_vision_items_per_batch": max_vision_items,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.core_model
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize("video_pruning_rate", [0.0, 0.75])
|
||||
@@ -146,3 +167,77 @@ def test_qwen2_5_vl_evs_batched_videos(
|
||||
|
||||
# Ensure the output is a string
|
||||
assert isinstance(output_text, str)
|
||||
|
||||
|
||||
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
def test_qwen2_5_vl_window_attention_image(
    vllm_runner,
    model,
    dtype: str,
    max_tokens: int,
    use_bytecode_hook: bool,
    monkeypatch,
) -> None:
    """Regression test for the Qwen2.5 window-attention image path.

    Runs a single image prompt with encoder CUDA graphs enabled, once with
    VLLM_USE_BYTECODE_HOOK set to "1" and once with it set to "0".
    """
    hook_value = "1" if use_bytecode_hook else "0"
    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", hook_value)

    regression_image = _window_attention_regression_image()
    runner_kwargs = {
        "runner": "generate",
        "max_model_len": 4096,
        "dtype": dtype,
        "limit_mm_per_prompt": {"image": 1},
        "compilation_config": _encoder_cudagraph_config(max_vision_items=1),
    }

    with vllm_runner(model, **runner_kwargs) as vllm_model:
        outputs = vllm_model.generate_greedy(
            [WINDOW_ATTN_IMAGE_PROMPT],
            max_tokens,
            images=[[regression_image]],
        )

        # Exactly one prompt went in, so exactly one result must come out.
        assert len(outputs) == 1
        output_ids, output_text = outputs[0]
        # The model must have produced a non-empty textual answer.
        assert len(output_ids) > 0
        assert len(output_text) > 0
        assert isinstance(output_text, str)
|
||||
|
||||
|
||||
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
def test_qwen2_5_vl_window_attention_image_batch(
    vllm_runner,
    model,
    dtype: str,
    max_tokens: int,
    use_bytecode_hook: bool,
    monkeypatch,
) -> None:
    """Regression test for window attention with a small image batch.

    Sends two identical single-image prompts with encoder CUDA graphs
    enabled and sized for two vision items per batch.
    """
    hook_value = "1" if use_bytecode_hook else "0"
    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", hook_value)

    batch_size = 2
    shared_image = _window_attention_regression_image()
    prompts = [WINDOW_ATTN_IMAGE_PROMPT for _ in range(batch_size)]
    images = [[shared_image] for _ in range(batch_size)]

    runner_kwargs = {
        "runner": "generate",
        "max_model_len": 4096,
        "max_num_seqs": 2,
        "dtype": dtype,
        "limit_mm_per_prompt": {"image": 1},
        "compilation_config": _encoder_cudagraph_config(max_vision_items=2),
    }

    with vllm_runner(model, **runner_kwargs) as vllm_model:
        outputs = vllm_model.generate_greedy(prompts, max_tokens, images=images)

        # One result per prompt, each with a non-empty textual answer.
        assert len(outputs) == 2
        for output_ids, output_text in outputs:
            assert len(output_ids) > 0
            assert len(output_text) > 0
            assert isinstance(output_text, str)
|
||||
|
||||
@@ -54,7 +54,18 @@ MODEL_CONFIGS: dict[str, VitCudagraphTestConfig] = {
|
||||
needs_video_metadata=True,
|
||||
marks=[pytest.mark.core_model],
|
||||
),
|
||||
# TODO: Add more models below.
|
||||
"qwen2_5_vl": VitCudagraphTestConfig(
|
||||
model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
image_prompt=qwen_vl_chat_template(
|
||||
"<|vision_start|><|image_pad|><|vision_end|>What is in this image?"
|
||||
),
|
||||
video_prompt=qwen_vl_chat_template(
|
||||
"<|vision_start|><|video_pad|><|vision_end|>"
|
||||
"Describe this video in one sentence."
|
||||
),
|
||||
needs_video_metadata=False,
|
||||
marks=[pytest.mark.core_model],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -85,11 +85,13 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.worker.encoder_cudagraph_defs import EncoderCudaGraphReplayBuffers
|
||||
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
SupportsEncoderCudaGraph,
|
||||
SupportsLoRA,
|
||||
SupportsMRoPE,
|
||||
SupportsMultiModal,
|
||||
@@ -771,22 +773,54 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
inv[perm] = torch.arange(perm.numel(), device=perm.device, dtype=perm.dtype)
|
||||
return inv
|
||||
|
||||
def forward(
|
||||
def prepare_encoder_metadata(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: list[list[int]],
|
||||
) -> torch.Tensor:
|
||||
*,
|
||||
max_batch_size: int | None = None,
|
||||
max_frames_per_batch: int | None = None,
|
||||
max_window_seqs_per_batch: int | None = None,
|
||||
max_seqlen_override: int | None = None,
|
||||
max_seqlen_window_override: int | None = None,
|
||||
device: torch.device | None = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Compute encoder metadata from grid_thw.
|
||||
|
||||
Shared by the eager forward path, CUDA graph capture, and
|
||||
CUDA graph replay to avoid duplicated implementation.
|
||||
|
||||
Args:
|
||||
grid_thw: Grid configurations as list of [t, h, w].
|
||||
max_batch_size: If set, pad cu_seqlens to this size
|
||||
(needed for CUDA graph capture/replay).
|
||||
max_frames_per_batch: If set, overrides max_batch_size for
|
||||
cu_seqlens padding. For video inputs each item contributes
|
||||
T attention sequences (frames); this sizes the buffer to
|
||||
the total frame budget so video replays never overflow.
|
||||
max_window_seqs_per_batch: If set, pad cu_window_seqlens to this
|
||||
number of window sequences. This keeps cu_window_seqlens shape
|
||||
stable across capture/replay for CUDA graph safety.
|
||||
max_seqlen_override: If set, use this value for max_seqlen
|
||||
instead of computing from cu_seqlens (needed for CUDA
|
||||
graph capture to cover worst-case replay scenarios).
|
||||
max_seqlen_window_override: If set, use this value for
|
||||
window-attention max_seqlen instead of computing from
|
||||
cu_window_seqlens (needed for CUDA graph capture to
|
||||
cover worst-case replay scenarios).
|
||||
device: Device to place tensors on. Defaults to self.device.
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
device = self.device
|
||||
metadata: dict[str, torch.Tensor] = {}
|
||||
|
||||
# patchify
|
||||
seq_len, _ = x.size()
|
||||
rotary_pos_emb_cos = []
|
||||
rotary_pos_emb_sin = []
|
||||
window_index: list = []
|
||||
cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int32)]
|
||||
cu_seqlens: list = []
|
||||
|
||||
hidden_states = x.to(device=self.device, dtype=self.dtype)
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
|
||||
window_index_id = 0
|
||||
cu_window_seqlens_last = 0
|
||||
for t, h, w in grid_thw:
|
||||
@@ -825,23 +859,99 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
|
||||
|
||||
# Pad cu_seqlens to the required number of sequences.
|
||||
# For videos each item contributes T frames = T attention sequences,
|
||||
# so the total can exceed max_batch_size. max_frames_per_batch
|
||||
# overrides the pad target when set.
|
||||
pad_to = (
|
||||
max_frames_per_batch if max_frames_per_batch is not None else max_batch_size
|
||||
)
|
||||
if pad_to is not None:
|
||||
num_seqs = len(cu_seqlens) - 1
|
||||
if num_seqs < pad_to:
|
||||
cu_seqlens = torch.cat(
|
||||
(
|
||||
cu_seqlens,
|
||||
torch.full(
|
||||
(pad_to - num_seqs,),
|
||||
cu_seqlens[-1],
|
||||
dtype=cu_seqlens.dtype,
|
||||
device=cu_seqlens.device,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Pad cu_window_seqlens to a stable number of window sequences.
|
||||
# Like cu_seqlens, we repeat the last cumulative offset so padded
|
||||
# entries represent empty sequences.
|
||||
if max_window_seqs_per_batch is not None:
|
||||
num_window_seqs = len(cu_window_seqlens) - 1
|
||||
if num_window_seqs < max_window_seqs_per_batch:
|
||||
cu_window_seqlens = torch.cat(
|
||||
(
|
||||
cu_window_seqlens,
|
||||
torch.full(
|
||||
(max_window_seqs_per_batch - num_window_seqs,),
|
||||
cu_window_seqlens[-1],
|
||||
dtype=cu_window_seqlens.dtype,
|
||||
device=cu_window_seqlens.device,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# transformers
|
||||
# pre-compute seqlens for window/full attn to reduce cuMemcpy operations
|
||||
max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||
max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens)
|
||||
if max_seqlen_override is None:
|
||||
max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||
else:
|
||||
max_seqlen_full = torch.tensor(max_seqlen_override, dtype=torch.int32)
|
||||
if max_seqlen_window_override is None:
|
||||
max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens)
|
||||
else:
|
||||
max_seqlen_window = torch.tensor(
|
||||
max_seqlen_window_override, dtype=torch.int32
|
||||
)
|
||||
|
||||
cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
|
||||
cu_window_seqlens = cu_window_seqlens.to(device=self.device, non_blocking=True)
|
||||
rotary_pos_emb_cos = rotary_pos_emb_cos.to(
|
||||
device=self.device, non_blocking=True
|
||||
)
|
||||
rotary_pos_emb_sin = rotary_pos_emb_sin.to(
|
||||
device=self.device, non_blocking=True
|
||||
)
|
||||
window_index = window_index.to(device=hidden_states.device, non_blocking=True)
|
||||
reverse_indices = reverse_indices.to(
|
||||
device=hidden_states.device, non_blocking=True
|
||||
)
|
||||
cu_seqlens = cu_seqlens.to(device=device, non_blocking=True)
|
||||
cu_window_seqlens = cu_window_seqlens.to(device=device, non_blocking=True)
|
||||
rotary_pos_emb_cos = rotary_pos_emb_cos.to(device=device, non_blocking=True)
|
||||
rotary_pos_emb_sin = rotary_pos_emb_sin.to(device=device, non_blocking=True)
|
||||
window_index = window_index.to(device=device, non_blocking=True)
|
||||
reverse_indices = reverse_indices.to(device=device, non_blocking=True)
|
||||
|
||||
metadata["rotary_pos_emb_cos"] = rotary_pos_emb_cos
|
||||
metadata["rotary_pos_emb_sin"] = rotary_pos_emb_sin
|
||||
metadata["window_index"] = window_index
|
||||
metadata["reverse_indices"] = reverse_indices
|
||||
metadata["cu_seqlens"] = cu_seqlens
|
||||
metadata["cu_window_seqlens"] = cu_window_seqlens
|
||||
metadata["max_seqlen_full"] = max_seqlen_full
|
||||
metadata["max_seqlen_window"] = max_seqlen_window
|
||||
|
||||
return metadata
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: list[list[int]],
|
||||
*,
|
||||
encoder_metadata: dict[str, torch.Tensor] | None = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = x.to(device=self.device, dtype=self.dtype)
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
|
||||
seq_len = hidden_states.shape[0]
|
||||
if encoder_metadata is None:
|
||||
encoder_metadata = self.prepare_encoder_metadata(grid_thw)
|
||||
|
||||
rotary_pos_emb_cos = encoder_metadata["rotary_pos_emb_cos"]
|
||||
rotary_pos_emb_sin = encoder_metadata["rotary_pos_emb_sin"]
|
||||
window_index = encoder_metadata["window_index"]
|
||||
reverse_indices = encoder_metadata["reverse_indices"]
|
||||
cu_seqlens = encoder_metadata["cu_seqlens"]
|
||||
cu_window_seqlens = encoder_metadata["cu_window_seqlens"]
|
||||
max_seqlen_full = encoder_metadata["max_seqlen_full"]
|
||||
max_seqlen_window = encoder_metadata["max_seqlen_window"]
|
||||
|
||||
hidden_states = hidden_states.reshape(
|
||||
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
|
||||
@@ -1003,6 +1113,7 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
|
||||
class Qwen2_5_VLForConditionalGeneration(
|
||||
nn.Module,
|
||||
SupportsMultiModal,
|
||||
SupportsEncoderCudaGraph,
|
||||
SupportsLoRA,
|
||||
SupportsPP,
|
||||
SupportsQuant,
|
||||
@@ -1124,6 +1235,7 @@ class Qwen2_5_VLForConditionalGeneration(
|
||||
|
||||
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
|
||||
self.config = config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.vllm_config = vllm_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.video_pruning_rate = multimodal_config.video_pruning_rate
|
||||
@@ -1447,6 +1559,302 @@ class Qwen2_5_VLForConditionalGeneration(
|
||||
multimodal_embeddings += tuple(video_embeddings)
|
||||
return multimodal_embeddings
|
||||
|
||||
# -- SupportsEncoderCudaGraph protocol methods --
|
||||
|
||||
def get_encoder_cudagraph_config(self):
    """Describe how the encoder CUDA graph path should drive this model.

    Returns an ``EncoderCudaGraphConfig`` listing the supported modalities,
    the mm_kwargs key holding the pixel input for each modality, the names
    of the metadata buffers, and the encoder output hidden size.
    """
    from vllm.v1.worker.encoder_cudagraph_defs import (
        EncoderCudaGraphConfig,
    )

    # NOTE: With EVS pruning enabled, embed_multimodal() post-processes the
    # multimodal embeddings (appends positions for images; prunes and
    # appends positions for videos). The encoder CUDA graph path bypasses
    # that hook, so no modality is advertised in that case to keep the
    # eager and CUDA graph embedding formats consistent.
    if self.is_multimodal_pruning_enabled:
        modalities = []
    else:
        modalities = ["image", "video"]

    input_keys = {
        "image": "pixel_values",
        "video": "pixel_values_videos",
    }
    buffer_keys = [
        "rotary_pos_emb_cos",
        "rotary_pos_emb_sin",
        "window_index",
        "reverse_indices",
        "cu_seqlens",
        "cu_window_seqlens",
        "max_seqlen_full",
        "max_seqlen_window",
    ]

    return EncoderCudaGraphConfig(
        modalities=modalities,
        input_key_by_modality=input_keys,
        buffer_keys=buffer_keys,
        out_hidden_size=self.visual.out_hidden_size,
    )
|
||||
|
||||
def get_input_modality(
    self,
    mm_kwargs: dict[str, Any],
) -> str:
    """Classify ``mm_kwargs`` as image or video input.

    Returns ``"image"`` when the ``image_grid_thw`` key is present and
    ``"video"`` in every other case.
    """
    return "image" if "image_grid_thw" in mm_kwargs else "video"
|
||||
|
||||
def get_max_frames_per_video(self) -> int:
    """Return the frame count the processing info reports as yielding the
    most features, given the model's max length and per-prompt video limit.
    """
    processing_info = MULTIMODAL_REGISTRY.get_processing_info(self.model_config)
    max_model_len = self.model_config.max_model_len
    video_limit = self.multimodal_config.get_limit_per_prompt("video")
    return processing_info.get_num_frames_with_most_features(
        seq_len=max_model_len,
        mm_counts={"video": video_limit},
    )
|
||||
|
||||
def get_encoder_cudagraph_budget_range(
    self,
    vllm_config: VllmConfig,
) -> tuple[int, int]:
    """Return the ``(min, max)`` token budget for encoder CUDA graphs.

    The lower bound is a fixed estimate of the smallest encoder input:
    a 224x224 image gives 16x16 patches (patch_size=14), which
    spatial_merge_size=2 reduces to 8x8 = 64 tokens. The upper bound is
    the smaller of the scheduler's max_num_batched_tokens and the model's
    max_model_len.
    """
    batched_token_cap = vllm_config.scheduler_config.max_num_batched_tokens
    model_len_cap = self.model_config.max_model_len
    smallest_budget = 64
    largest_budget = min(batched_token_cap, model_len_cap)
    return (smallest_budget, largest_budget)
|
||||
|
||||
def _get_pixel_values_by_modality(
|
||||
self,
|
||||
mm_kwargs: dict[str, Any],
|
||||
) -> torch.Tensor:
|
||||
if self.get_input_modality(mm_kwargs) == "image":
|
||||
pixel_values = mm_kwargs["pixel_values"]
|
||||
else:
|
||||
pixel_values = mm_kwargs["pixel_values_videos"]
|
||||
return pixel_values
|
||||
|
||||
def _get_grid_thw_by_modality(
|
||||
self,
|
||||
mm_kwargs: dict[str, Any],
|
||||
) -> list[tuple[int, int, int]]:
|
||||
grid_thw_key = f"{self.get_input_modality(mm_kwargs)}_grid_thw"
|
||||
grid_thw = mm_kwargs[grid_thw_key]
|
||||
if not isinstance(grid_thw, list):
|
||||
grid_thw = grid_thw.tolist()
|
||||
return grid_thw
|
||||
|
||||
def get_encoder_cudagraph_num_items(
    self,
    mm_kwargs: dict[str, Any],
) -> int:
    """Count the items in ``mm_kwargs``: one per grid_thw entry."""
    grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
    return len(grid_thw)
|
||||
|
||||
def get_encoder_cudagraph_per_item_output_tokens(
    self,
    mm_kwargs: dict[str, Any],
) -> list[int]:
    """Return, per item, ``t * (h // m) * (w // m)`` for its grid
    ``[t, h, w]``, where ``m`` is the visual spatial merge size.
    """
    merge = self.visual.spatial_merge_size
    token_counts: list[int] = []
    for frames, height, width in self._get_grid_thw_by_modality(mm_kwargs):
        token_counts.append(frames * (height // merge) * (width // merge))
    return token_counts
|
||||
|
||||
def get_encoder_cudagraph_per_item_input_sizes(
    self,
    mm_kwargs: dict[str, Any],
) -> list[int]:
    """Return, per item, the product ``t * h * w`` of its grid entry."""
    patch_counts: list[int] = []
    for frames, height, width in self._get_grid_thw_by_modality(mm_kwargs):
        patch_counts.append(frames * height * width)
    return patch_counts
|
||||
|
||||
def select_encoder_cudagraph_items(
    self,
    mm_kwargs: dict[str, Any],
    indices: list[int],
) -> dict[str, Any]:
    """Build a new mm_kwargs dict holding only the items at ``indices``.

    Args:
        mm_kwargs: Inputs holding the pixel values and grid_thw of a single
            modality.
        indices: Positions of the items to keep, in output order.

    Returns:
        A dict with the modality's pixel-values key and grid_thw key. With
        no indices, the pixel values are an empty leading slice and the
        grid is an empty list.
    """
    grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
    pixel_values = self._get_pixel_values_by_modality(mm_kwargs)

    if self.get_input_modality(mm_kwargs) == "image":
        pixel_key, grid_key = "pixel_values", "image_grid_thw"
    else:
        pixel_key, grid_key = "pixel_values_videos", "video_grid_thw"

    # torch.cat rejects an empty list, so the no-selection case is handled
    # up front.
    if len(indices) == 0:
        return {pixel_key: pixel_values[:0], grid_key: []}

    # boundaries[i]:boundaries[i + 1] is the row range of item i inside the
    # concatenated pixel_values.
    boundaries = [0]
    for frames, height, width in grid_thw:
        boundaries.append(boundaries[-1] + frames * height * width)

    chunks = [pixel_values[boundaries[i] : boundaries[i + 1]] for i in indices]
    return {
        pixel_key: torch.cat(chunks),
        grid_key: [grid_thw[i] for i in indices],
    }
|
||||
|
||||
def prepare_encoder_cudagraph_capture_inputs(
    self,
    token_budget: int,
    max_batch_size: int,
    max_frames_per_batch: int,
    device: torch.device,
    dtype: torch.dtype,
):
    """Build dummy inputs and metadata buffers for CUDA graph capture.

    Args:
        token_budget: Output-token budget the captured graph must cover.
        max_batch_size: Number of dummy items placed in the capture batch.
        max_frames_per_batch: Frame budget; forwarded to the metadata
            builder and used to choose the frames per dummy item.
        device: Device for the dummy pixel values and metadata.
        dtype: dtype of the dummy pixel values.

    Returns:
        ``EncoderCudaGraphCaptureInputs`` holding image-keyed dummy
        mm_kwargs and the metadata buffers.
    """
    from vllm.v1.worker.encoder_cudagraph_defs import (
        EncoderCudaGraphCaptureInputs,
    )

    visual = self.visual
    merge = visual.spatial_merge_size
    window_seq_cap = min(
        self.vllm_config.scheduler_config.max_num_batched_tokens,
        self.model_config.max_model_len,
    )

    # Ceil (not floor) so the total captured capacity is never smaller than
    # token_budget when it is not divisible by max_batch_size (e.g. a 324
    # budget with max_batch_size=8). Floor would under-allocate the input
    # buffer and can fail the replay copy for valid single-item batches.
    tokens_per_item = (token_budget + max_batch_size - 1) // max_batch_size

    # With more than one frame per item, use a video-style layout (T > 1):
    # cu_seqlens has one entry per frame, so it is sized for video replays
    # from the start instead of relying solely on padding. Otherwise fall
    # back to the image layout (T = 1).
    frames_per_item = max_frames_per_batch // max_batch_size
    frames = frames_per_item if frames_per_item > 1 else 1
    # Ceil again so frames * tokens_per_frame >= tokens_per_item; with
    # frames == 1 this is exactly tokens_per_item.
    tokens_per_frame = (tokens_per_item + frames - 1) // frames
    grid_config = [
        [frames, merge, tokens_per_frame * merge] for _ in range(max_batch_size)
    ]

    # Dummy pixel values: one row per patch, each row a flattened patch.
    patch_embed = visual.patch_embed
    channels = patch_embed.proj.in_channels
    patch_size = patch_embed.patch_size
    temporal_patch_size = patch_embed.temporal_patch_size
    num_patches = 0
    for t, h, w in grid_config:
        num_patches += t * h * w
    row_width = channels * temporal_patch_size * patch_size * patch_size
    dummy_pixel_values = torch.randn(
        num_patches, row_width, device=device, dtype=dtype
    )

    # max_seqlen.item() is baked into the CUDA graph rather than replayed,
    # so the captured values must cover any replay scenario.
    # Full attention worst case: a single item consuming the whole budget,
    # i.e. token_budget * spatial_merge_size^2.
    full_seqlen_cap = token_budget * (merge**2)
    # Window attention: the window side in merged-token units is
    # window_size // spatial_merge_size // patch_size; squaring it and
    # multiplying by spatial_merge_size^2 maps back to the unmerged sequence
    # length used by the attention kernels.
    merged_window_side = (
        visual.window_size // visual.spatial_merge_size // visual.patch_size
    )
    window_seqlen_cap = merged_window_side**2 * (merge**2)

    buffers = visual.prepare_encoder_metadata(
        grid_config,
        max_batch_size=max_batch_size,
        max_frames_per_batch=max_frames_per_batch,
        max_window_seqs_per_batch=window_seq_cap,
        max_seqlen_override=full_seqlen_cap,
        max_seqlen_window_override=window_seqlen_cap,
        device=device,
    )

    # The image-keyed dummy input is also used for video capture, since both
    # share the shape [num_patches, C*T*P*P].
    return EncoderCudaGraphCaptureInputs(
        mm_kwargs={
            "pixel_values": dummy_pixel_values,
            "image_grid_thw": grid_config,
        },
        buffers=buffers,
    )
|
||||
|
||||
def prepare_encoder_cudagraph_replay_buffers(
    self,
    mm_kwargs: dict[str, Any],
    max_batch_size: int,
    max_frames_per_batch: int,
):
    """Compute the encoder metadata buffers for one CUDA graph replay.

    Image input is padded using ``max_batch_size``; any other input is
    padded using ``max_frames_per_batch``. Both pass the same cap on window
    sequences.
    """
    modality = self.get_input_modality(mm_kwargs)
    grid_thw_list = self._get_grid_thw_by_modality(mm_kwargs)

    # Only one of the two padding targets is forwarded, chosen by modality.
    if modality == "image":
        padding_kwargs = {"max_batch_size": max_batch_size}
    else:
        padding_kwargs = {"max_frames_per_batch": max_frames_per_batch}

    window_seq_cap = min(
        self.vllm_config.scheduler_config.max_num_batched_tokens,
        self.model_config.max_model_len,
    )
    buffers = self.visual.prepare_encoder_metadata(
        grid_thw_list,
        max_window_seqs_per_batch=window_seq_cap,
        **padding_kwargs,
    )
    return EncoderCudaGraphReplayBuffers(buffers=buffers)
|
||||
|
||||
def encoder_cudagraph_forward(
    self,
    mm_kwargs: dict[str, Any],
    buffers: dict[str, torch.Tensor],
) -> torch.Tensor:
    """Run the visual encoder with precomputed metadata.

    ``buffers`` is passed to the encoder as ``encoder_metadata``.
    """
    pixel_values = self._get_pixel_values_by_modality(mm_kwargs)
    grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
    encoder = self.visual
    return encoder(pixel_values, grid_thw, encoder_metadata=buffers)
|
||||
|
||||
def encoder_eager_forward(
    self,
    mm_kwargs: dict[str, Any],
) -> torch.Tensor:
    """Run the visual encoder without supplying precomputed metadata."""
    pixel_values = self._get_pixel_values_by_modality(mm_kwargs)
    grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
    encoder = self.visual
    return encoder(pixel_values, grid_thw)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
|
||||
Reference in New Issue
Block a user