diff --git a/benchmarks/kernels/benchmark_vit_bilinear_pos_embed.py b/benchmarks/kernels/benchmark_vit_bilinear_pos_embed.py
new file mode 100644
index 00000000000..65171a1b2e1
--- /dev/null
+++ b/benchmarks/kernels/benchmark_vit_bilinear_pos_embed.py
@@ -0,0 +1,165 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Benchmarks the fused Triton bilinear position-embedding kernel against
+# the pure-PyTorch (native) implementation used in Qwen3-VL ViT models.
+#
+# == Usage Examples ==
+#
+# Default benchmark:
+#   python3 benchmark_vit_bilinear_pos_embed.py
+#
+# Custom parameters:
+#   python3 benchmark_vit_bilinear_pos_embed.py --hidden-dim 1152 \
+#       --num-grid-per-side 48 --save-path ./configs/vit_pos_embed/
+
+import itertools
+
+import torch
+
+from vllm.model_executor.models.qwen3_vl import (
+    pos_embed_interpolate_native,
+    triton_pos_embed_interpolate,
+)
+from vllm.triton_utils import HAS_TRITON, triton
+from vllm.utils.argparse_utils import FlexibleArgumentParser
+
+# (h, w) configurations to benchmark
+h_w_configs = [
+    (16, 16),
+    (32, 32),
+    (48, 48),
+    (64, 64),
+    (128, 128),
+    (32, 48),
+    (60, 80),
+]
+
+# Temporal dimensions
+t_range = [1]
+
+configs = list(itertools.product(t_range, h_w_configs))
+
+
+def get_benchmark(
+    num_grid_per_side: int,
+    spatial_merge_size: int,
+    hidden_dim: int,
+    dtype: torch.dtype,
+    device: str,
+):
+    """Build the ``perf_report`` benchmark for one model configuration.
+
+    The returned object sweeps every ``(t, (h, w))`` entry of ``configs`` for
+    both providers; call its ``run`` method to execute the sweep.
+    """
+    # Resolve the provider name once instead of duplicating the timing call.
+    providers = {
+        "native": pos_embed_interpolate_native,
+        "triton": triton_pos_embed_interpolate,
+    }
+
+    @triton.testing.perf_report(
+        triton.testing.Benchmark(
+            x_names=["t", "h_w"],
+            x_vals=[list(config) for config in configs],
+            line_arg="provider",
+            line_vals=["native", "triton"],
+            line_names=["Native (PyTorch)", "Triton"],
+            styles=[("blue", "-"), ("red", "-")],
+            ylabel="us",
+            plot_name=(
+                "vit-bilinear-pos-embed-"
+                f"grid{num_grid_per_side}-"
+                f"dim{hidden_dim}-"
+                f"{dtype}"
+            ),
+            args={},
+        )
+    )
+    def benchmark(t, h_w, provider):
+        """Time one provider on one grid; returns (median, p20, p80) in us."""
+        h, w = h_w
+        if provider == "triton":
+            assert HAS_TRITON, "Triton not available"
+        interpolate_fn = providers[provider]
+
+        torch.manual_seed(42)
+        embed_weight = (
+            torch.randn(
+                num_grid_per_side * num_grid_per_side,
+                hidden_dim,
+                device=device,
+                dtype=dtype,
+            )
+            * 0.25
+        )
+
+        # do_bench returns the timings in the same order as the quantiles.
+        median_ms, p20_ms, p80_ms = triton.testing.do_bench(
+            lambda: interpolate_fn(
+                embed_weight,
+                t,
+                h,
+                w,
+                num_grid_per_side,
+                spatial_merge_size,
+                dtype,
+            ),
+            quantiles=[0.5, 0.2, 0.8],
+        )
+
+        # perf_report expects (value, min, max). Latency is reported directly
+        # (not inverted into a throughput), so the low quantile stays the min.
+        return 1000 * median_ms, 1000 * p20_ms, 1000 * p80_ms
+
+    return benchmark
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser(
+        description="Benchmark bilinear position embedding interpolation."
+    )
+    parser.add_argument(
+        "--num-grid-per-side",
+        type=int,
+        default=48,
+        help="Position embedding grid size (default: 48 for Qwen3-VL)",
+    )
+    parser.add_argument(
+        "--spatial-merge-size",
+        type=int,
+        default=2,
+        help="Spatial merge size (default: 2)",
+    )
+    parser.add_argument(
+        "--hidden-dim",
+        type=int,
+        default=1152,
+        help="Embedding hidden dimension (default: 1152 for Qwen3-VL)",
+    )
+    # Accept any device string rather than only the first two GPUs.
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda:0",
+        help="Device to benchmark on (default: cuda:0)",
+    )
+    parser.add_argument(
+        "--save-path",
+        type=str,
+        default="./vit_pos_embed/",
+        help="Directory for the benchmark report (default: ./vit_pos_embed/)",
+    )
+    args = parser.parse_args()
+
+    dtype = torch.bfloat16
+
+    bench = get_benchmark(
+        args.num_grid_per_side,
+        args.spatial_merge_size,
+        args.hidden_dim,
+        dtype,
+        args.device,
+    )
+    bench.run(print_data=True, save_path=args.save_path)
diff --git a/tests/kernels/core/test_vit_bilinear_pos_embed.py b/tests/kernels/core/test_vit_bilinear_pos_embed.py
new file mode 100644
index 00000000000..66571e3a2fb
--- /dev/null
+++ b/tests/kernels/core/test_vit_bilinear_pos_embed.py
@@ -0,0 +1,120 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Accuracy tests for the fused Triton bilinear position-embedding kernel.
+
+Compares ``triton_pos_embed_interpolate`` against the pure-PyTorch
+``pos_embed_interpolate_native`` across a variety of grid shapes and dtypes.
+"""
+
+import pytest
+import torch
+
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    from vllm.model_executor.models.qwen3_vl import (
+        pos_embed_interpolate_native,
+        triton_pos_embed_interpolate,
+    )
+
+
+DTYPES = [torch.float32, torch.bfloat16]
+# Qwen3-VL default
+NUM_GRID_PER_SIDE = 48
+SPATIAL_MERGE_SIZE = 2
+HIDDEN_DIM = 1152
+
+# 4 square + 4 non-square grids (h, w divisible by spatial_merge_size=2)
+SQUARE_GRIDS = [(1, 4, 4), (1, 16, 16), (1, 32, 32), (1, 48, 48)]
+NON_SQUARE_GRIDS = [(1, 8, 16), (1, 14, 20), (1, 32, 48), (1, 60, 80)]
+ALL_GRIDS = SQUARE_GRIDS + NON_SQUARE_GRIDS
+
+# Every test launches the Triton kernel on DEVICE, so skip unless both
+# Triton and a CUDA device are available.
+DEVICE = "cuda"
+requires_triton_cuda = pytest.mark.skipif(
+    not (HAS_TRITON and torch.cuda.is_available()),
+    reason="Triton and a CUDA device are required",
+)
+
+
+def _make_embed_weight(dtype: torch.dtype) -> torch.Tensor:
+    """Build a deterministic position-embedding table on ``DEVICE``.
+
+    The 0.25 scale matches the real Qwen3-VL pos_embed weight distribution
+    (std~0.23).
+    """
+    torch.manual_seed(42)
+    weight = torch.randn(
+        NUM_GRID_PER_SIDE * NUM_GRID_PER_SIDE,
+        HIDDEN_DIM,
+        device=DEVICE,
+        dtype=dtype,
+    )
+    return weight * 0.25
+
+
+@requires_triton_cuda
+@pytest.mark.parametrize("dtype", DTYPES, ids=lambda d: str(d).split(".")[-1])
+@pytest.mark.parametrize(
+    "grid_thw",
+    ALL_GRIDS,
+    ids=[f"{t}x{h}x{w}" for t, h, w in ALL_GRIDS],
+)
+def test_triton_matches_native(
+    grid_thw: tuple[int, int, int],
+    dtype: torch.dtype,
+) -> None:
+    """Triton kernel output must match the native PyTorch implementation."""
+    t, h, w = grid_thw
+    embed_weight = _make_embed_weight(dtype)
+
+    native_out = pos_embed_interpolate_native(
+        embed_weight, t, h, w, NUM_GRID_PER_SIDE, SPATIAL_MERGE_SIZE, dtype
+    )
+    triton_out = triton_pos_embed_interpolate(
+        embed_weight, t, h, w, NUM_GRID_PER_SIDE, SPATIAL_MERGE_SIZE, dtype
+    )
+
+    assert native_out.shape == triton_out.shape, (
+        f"Shape mismatch: native {native_out.shape} vs triton {triton_out.shape}"
+    )
+
+    # The Triton path multiplies by a precomputed h/w scale while the native
+    # path uses torch.linspace; the two can differ by a single ULP in a
+    # handful of elements, hence the non-zero tolerances.
+    atol = {torch.float32: 5e-5, torch.bfloat16: 1e-2}[dtype]
+    rtol = {torch.float32: 1e-5, torch.bfloat16: 1e-2}[dtype]
+    torch.testing.assert_close(triton_out, native_out, atol=atol, rtol=rtol)
+
+
+@requires_triton_cuda
+@pytest.mark.parametrize("dtype", DTYPES, ids=lambda d: str(d).split(".")[-1])
+def test_temporal_repeat(dtype: torch.dtype) -> None:
+    """Verify temporal dimension t > 1 correctly repeats the spatial pattern."""
+    h, w = 16, 16
+    t_single, t_multi = 1, 3
+    embed_weight = _make_embed_weight(dtype)
+
+    out_single = triton_pos_embed_interpolate(
+        embed_weight,
+        t_single,
+        h,
+        w,
+        NUM_GRID_PER_SIDE,
+        SPATIAL_MERGE_SIZE,
+        dtype,
+    )
+    out_multi = triton_pos_embed_interpolate(
+        embed_weight,
+        t_multi,
+        h,
+        w,
+        NUM_GRID_PER_SIDE,
+        SPATIAL_MERGE_SIZE,
+        dtype,
+    )
+
+    # Frames share one spatial grid, so the copies must be bit-identical.
+    expected = out_single.repeat(t_multi, 1)
+    torch.testing.assert_close(out_multi, expected, atol=0, rtol=0)
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 431371c124d..cb48ceb0c77 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -96,6 +96,7 @@ from vllm.multimodal.processing import (
 from vllm.sequence import IntermediateTensors
 from vllm.tokenizers.protocol import TokenizerLike
 from vllm.tokenizers.registry import cached_tokenizer_from_config
+from vllm.triton_utils import HAS_TRITON, tl, triton
 from vllm.utils.collection_utils import is_list_of
 from vllm.utils.math_utils import round_up
 
@@ -145,6 +146,201 @@ logger = init_logger(__name__)
 # of the maximum size.
 DUMMY_VIDEO_NUM_FRAMES = 2048
 
+# ---------------------------------------------------------------------------
+# Triton kernel: fused bilinear position-embedding interpolation
+# ---------------------------------------------------------------------------
+# Replaces many small eager-mode CUDA kernels with a single launch.
+# The spatial-merge reorder is baked into the index math so the output
+# is ready to be added to the patch embeddings directly.
+# ---------------------------------------------------------------------------
+
+if HAS_TRITON:
+
+    @triton.jit
+    def _bilinear_pos_embed_kernel(
+        embed_ptr,
+        output_ptr,
+        H,
+        W,
+        h_scale,
+        w_scale,
+        NUM_GRID: tl.constexpr,
+        M_SIZE: tl.constexpr,
+        HIDDEN_DIM: tl.constexpr,
+        BLOCK_D: tl.constexpr,
+    ):
+        """Fused bilinear pos-embed interpolation with spatial-merge reorder.
+
+        Each program writes one output row: it maps its index back to a
+        (row, col) cell of the H x W grid and blends the four neighbouring
+        rows of the NUM_GRID x NUM_GRID embedding table.
+        """
+        pid = tl.program_id(0)
+        total_spatial = H * W
+        # Frames repeat the same spatial pattern, so only pid % (H * W) matters.
+        spatial_idx = pid % total_spatial
+
+        # Invert the spatial-merge order: consecutive outputs walk one
+        # M_SIZE x M_SIZE block before moving to the next block in the row.
+        num_blocks_w = W // M_SIZE
+        block_idx = spatial_idx // (M_SIZE * M_SIZE)
+        local_idx = spatial_idx % (M_SIZE * M_SIZE)
+        br = block_idx // num_blocks_w
+        bc = block_idx % num_blocks_w
+        lr = local_idx // M_SIZE
+        lc = local_idx % M_SIZE
+        row = br * M_SIZE + lr
+        col = bc * M_SIZE + lc
+
+        h_frac = row.to(tl.float32) * h_scale
+        w_frac = col.to(tl.float32) * w_scale
+
+        hf = tl.math.floor(h_frac).to(tl.int32)
+        wf = tl.math.floor(w_frac).to(tl.int32)
+        hc = tl.minimum(hf + 1, NUM_GRID - 1)
+        wc = tl.minimum(wf + 1, NUM_GRID - 1)
+
+        dh = h_frac - hf.to(tl.float32)
+        dw = w_frac - wf.to(tl.float32)
+        w11 = dh * dw
+        w10 = dh - w11
+        w01 = dw - w11
+        w00 = 1.0 - dh - w01
+
+        off00 = (hf * NUM_GRID + wf) * HIDDEN_DIM
+        off01 = (hf * NUM_GRID + wc) * HIDDEN_DIM
+        off10 = (hc * NUM_GRID + wf) * HIDDEN_DIM
+        off11 = (hc * NUM_GRID + wc) * HIDDEN_DIM
+        # 64-bit offset: pid * HIDDEN_DIM overflows int32 for very large grids.
+        out_off = pid.to(tl.int64) * HIDDEN_DIM
+
+        # Cast weights to output dtype so the multiply-accumulate stays
+        # in the same precision as the native PyTorch implementation.
+        out_dtype = output_ptr.dtype.element_ty
+        w00_c = w00.to(out_dtype)
+        w01_c = w01.to(out_dtype)
+        w10_c = w10.to(out_dtype)
+        w11_c = w11.to(out_dtype)
+
+        for d in tl.range(0, HIDDEN_DIM, BLOCK_D):
+            cols = d + tl.arange(0, BLOCK_D)
+            mask = cols < HIDDEN_DIM
+
+            e00 = tl.load(embed_ptr + off00 + cols, mask=mask)
+            e01 = tl.load(embed_ptr + off01 + cols, mask=mask)
+            e10 = tl.load(embed_ptr + off10 + cols, mask=mask)
+            e11 = tl.load(embed_ptr + off11 + cols, mask=mask)
+
+            val = w00_c * e00 + w01_c * e01 + w10_c * e10 + w11_c * e11
+
+            tl.store(output_ptr + out_off + cols, val, mask=mask)
+
+    def triton_pos_embed_interpolate(
+        embed_weight: torch.Tensor,
+        t: int,
+        h: int,
+        w: int,
+        num_grid_per_side: int,
+        m_size: int,
+        dtype: torch.dtype,
+    ) -> torch.Tensor:
+        """Launch the fused Triton kernel for one (t,h,w) grid.
+
+        Returns a tensor of shape ``(t * h * w, hidden_dim)`` with the
+        bilinearly-interpolated position embeddings in spatial-merge order.
+        """
+        assert h % m_size == 0 and w % m_size == 0, (
+            f"h={h} and w={w} must be divisible by m_size={m_size}"
+        )
+        hidden_dim = embed_weight.shape[1]
+        total_out = t * h * w
+        output = torch.empty(
+            total_out, hidden_dim, device=embed_weight.device, dtype=dtype
+        )
+
+        # Step between sample points; a single sample sits at index 0.
+        h_scale = float(num_grid_per_side - 1) / float(h - 1) if h > 1 else 0.0
+        w_scale = float(num_grid_per_side - 1) / float(w - 1) if w > 1 else 0.0
+
+        BLOCK_D = triton.next_power_of_2(hidden_dim)
+
+        _bilinear_pos_embed_kernel[(total_out,)](
+            embed_weight,
+            output,
+            h,
+            w,
+            h_scale,
+            w_scale,
+            num_grid_per_side,
+            m_size,
+            hidden_dim,
+            BLOCK_D,
+        )
+        return output
+
+
+def pos_embed_interpolate_native(
+    embed_weight: torch.Tensor,
+    t: int,
+    h: int,
+    w: int,
+    num_grid_per_side: int,
+    m_size: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """Eager PyTorch bilinear position-embedding interpolation.
+
+    Returns a tensor of shape ``(t * h * w, hidden_dim)`` with the
+    bilinearly-interpolated position embeddings in spatial-merge order.
+    """
+    assert h % m_size == 0 and w % m_size == 0, (
+        f"h={h} and w={w} must be divisible by m_size={m_size}"
+    )
+    hidden_dim = embed_weight.shape[1]
+    device = embed_weight.device
+    grid_max = num_grid_per_side - 1
+
+    h_idxs = torch.linspace(0, grid_max, h, dtype=torch.float32, device=device)
+    w_idxs = torch.linspace(0, grid_max, w, dtype=torch.float32, device=device)
+
+    h_floor = h_idxs.to(torch.long)
+    w_floor = w_idxs.to(torch.long)
+    h_ceil = torch.clamp(h_floor + 1, max=grid_max)
+    w_ceil = torch.clamp(w_floor + 1, max=grid_max)
+
+    dh = h_idxs - h_floor
+    dw = w_idxs - w_floor
+
+    dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
+    h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij")
+    h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij")
+
+    # Equivalent to the textbook bilinear weights
+    #   w00 = (1 - dh) * (1 - dw)    w01 = (1 - dh) * dw
+    #   w10 = dh * (1 - dw)          w11 = dh * dw
+    # but reusing w11 avoids recomputing the dh * dw product.
+    w11 = dh_grid * dw_grid
+    w10 = dh_grid - w11
+    w01 = dw_grid - w11
+    w00 = 1 - dh_grid - w01
+
+    h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid])
+    w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid])
+    h_grid_idx = h_grid * num_grid_per_side
+
+    indices = (h_grid_idx + w_grid).reshape(4, -1)
+    weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
+    weights = weights.to(dtype=dtype)
+
+    embeds = embed_weight[indices]
+    embeds *= weights
+    combined = embeds.sum(dim=0)
+
+    combined = combined.reshape(h // m_size, m_size, w // m_size, m_size, hidden_dim)
+    combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim)
+    repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim)
+    return repeated.to(dtype=dtype)
+
 
 class Qwen3_VisionPatchEmbed(nn.Module):
     def __init__(
@@ -470,63 +666,30 @@ class Qwen3_VisionTransformer(nn.Module):
         return cos_combined, sin_combined
 
     def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
-        num_grid_per_side = self.num_grid_per_side
-        m_size = self.spatial_merge_size
-        hidden_dim = self.pos_embed.embedding_dim
-
+        """Return interpolated position embeddings for every grid in *grid_thw*.
+
+        Each ``(t, h, w)`` entry contributes ``t * h * w`` rows, concatenated
+        in order along dim 0.
+        """
+        weight = self.pos_embed.weight
+        # Triton kernels cannot launch on CPU tensors; keep the eager path there.
+        use_triton = HAS_TRITON and weight.device.type != "cpu"
+        interpolate_fn = (
+            triton_pos_embed_interpolate if use_triton else pos_embed_interpolate_native
+        )
         outputs = []
         for t, h, w in grid_thw:
-            h_idxs = torch.linspace(
-                0, num_grid_per_side - 1, h, dtype=torch.float32, device=self.device
+            outputs.append(
+                interpolate_fn(
+                    weight,
+                    t,
+                    h,
+                    w,
+                    self.num_grid_per_side,
+                    self.spatial_merge_size,
+                    self.dtype,
+                )
             )
-            w_idxs = torch.linspace(
-                0, num_grid_per_side - 1, w, dtype=torch.float32, device=self.device
-            )
-
-            h_floor = h_idxs.to(torch.long)
-            w_floor = w_idxs.to(torch.long)
-            h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
-            w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)
-
-            dh = h_idxs - h_floor
-            dw = w_idxs - w_floor
-
-            # Create meshgrid view for all h, w vars
-            dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
-            h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij")
-            h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij")
-
-            # original computation of weights
-            # w00 = (1 - dh_grid) * (1 - dw_grid)
-            # w01 = (1 - dh_grid) * dw_grid
-            # w10 = dh_grid * (1 - dw_grid)
-            # w11 = dh_grid * dw_grid
-            # we reuse w11 here to avoid duplicate
-            # dh_grid * dw_grid computation
-            w11 = dh_grid * dw_grid
-            w10 = dh_grid - w11
-            w01 = dw_grid - w11
-            w00 = 1 - dh_grid - w01
-
-            h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid])
-            w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid])
-            h_grid_idx = h_grid * num_grid_per_side
-
-            indices = (h_grid_idx + w_grid).reshape(4, -1)
-            weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
-            weights = weights.to(dtype=self.dtype)
-
-            embeds = self.pos_embed(indices)
-            embeds *= weights
-            combined = embeds.sum(dim=0)
-
-            combined = combined.reshape(
-                h // m_size, m_size, w // m_size, m_size, hidden_dim
-            )
-            combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim)
-            repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim)
-            outputs.append(repeated)
-
         return torch.cat(outputs, dim=0)
 
     def prepare_encoder_metadata(