[Perf] triton bilinear_pos_embed kernel for ViT (#37948)

Signed-off-by: Zhanda Zhu <zhandazhu@gmail.com>
This commit is contained in:
Zhanda Zhu
2026-04-01 04:52:02 -04:00
committed by GitHub
parent 4f6eed3bd4
commit c75a313824
3 changed files with 491 additions and 54 deletions
@@ -0,0 +1,162 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Benchmarks the fused Triton bilinear position-embedding kernel against
# the pure-PyTorch (native) implementation used in Qwen3-VL ViT models.
#
# == Usage Examples ==
#
# Default benchmark:
# python3 benchmark_vit_bilinear_pos_embed.py
#
# Custom parameters:
# python3 benchmark_vit_bilinear_pos_embed.py --hidden-dim 1152 \
# --num-grid-per-side 48 --save-path ./configs/vit_pos_embed/
import itertools
import torch
from vllm.model_executor.models.qwen3_vl import (
pos_embed_interpolate_native,
triton_pos_embed_interpolate,
)
from vllm.triton_utils import HAS_TRITON, triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
# Spatial (h, w) grids to sweep: five square sizes followed by two
# non-square ones.
h_w_configs = [(side, side) for side in (16, 32, 48, 64, 128)] + [
    (32, 48),
    (60, 80),
]
# Temporal sizes to sweep (only t=1 is benchmarked).
t_range = [1]
# Every (t, (h, w)) combination, with t varying slowest.
configs = [(t, h_w) for t in t_range for h_w in h_w_configs]
def get_benchmark(
    num_grid_per_side: int,
    spatial_merge_size: int,
    hidden_dim: int,
    dtype: torch.dtype,
    device: str,
):
    """Build a Triton ``perf_report`` benchmark for the two implementations.

    The returned object sweeps every ``(t, (h, w))`` entry of the
    module-level ``configs`` list and times ``pos_embed_interpolate_native``
    against ``triton_pos_embed_interpolate`` on the same random embedding
    table.

    Args:
        num_grid_per_side: Side length of the square source embedding grid;
            the table has ``num_grid_per_side ** 2`` rows.
        spatial_merge_size: Merge size forwarded to both implementations.
        hidden_dim: Number of columns in the embedding table.
        dtype: Dtype of the embedding table and of the requested output.
        device: Torch device string on which the table is allocated.

    Returns:
        The ``perf_report``-decorated benchmark; call ``.run(...)`` on it.
    """

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["t", "h_w"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["native", "triton"],
            line_names=["Native (PyTorch)", "Triton"],
            styles=[("blue", "-"), ("red", "-")],
            ylabel="us",
            plot_name=(
                f"vit-bilinear-pos-embed-"
                f"grid{num_grid_per_side}-"
                f"dim{hidden_dim}-"
                f"{dtype}"
            ),
            args={},
        )
    )
    def benchmark(t, h_w, provider):
        h, w = h_w

        # Fixed seed so both providers are timed on identical weights.
        torch.manual_seed(42)
        embed_weight = (
            torch.randn(
                num_grid_per_side * num_grid_per_side,
                hidden_dim,
                device=device,
                dtype=dtype,
            )
            * 0.25
        )

        if provider == "native":
            interpolate_fn = pos_embed_interpolate_native
        else:
            assert HAS_TRITON, "Triton not available"
            interpolate_fn = triton_pos_embed_interpolate

        # Results come back in the order of ``quantiles``:
        # median, 20th percentile (low), 80th percentile (high).
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: interpolate_fn(
                embed_weight,
                t,
                h,
                w,
                num_grid_per_side,
                spatial_merge_size,
                dtype,
            ),
            quantiles=[0.5, 0.2, 0.8],
        )

        # perf_report unpacks the result as (value, min, max). For a time
        # metric the low quantile is the minimum, so the order must not be
        # flipped the way it is for throughput metrics. Convert ms -> us.
        return 1000 * ms, 1000 * min_ms, 1000 * max_ms

    return benchmark
if __name__ == "__main__":
    # Command-line driver: parse the model-shape options, build the
    # benchmark and run it, printing the table and saving results.
    parser = FlexibleArgumentParser(
        description="Benchmark bilinear position embedding interpolation."
    )

    # Integer options, listed as (flag, default, help text).
    int_options = (
        (
            "--num-grid-per-side",
            48,
            "Position embedding grid size (default: 48 for Qwen3-VL)",
        ),
        (
            "--spatial-merge-size",
            2,
            "Spatial merge size (default: 2)",
        ),
        (
            "--hidden-dim",
            1152,
            "Embedding hidden dimension (default: 1152 for Qwen3-VL)",
        ),
    )
    for flag, default_value, help_text in int_options:
        parser.add_argument(flag, type=int, default=default_value, help=help_text)

    parser.add_argument(
        "--device",
        type=str,
        choices=["cuda:0", "cuda:1"],
        default="cuda:0",
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default="./vit_pos_embed/",
    )
    cli_args = parser.parse_args()

    # The benchmark is always run in bfloat16.
    bench = get_benchmark(
        cli_args.num_grid_per_side,
        cli_args.spatial_merge_size,
        cli_args.hidden_dim,
        torch.bfloat16,
        cli_args.device,
    )
    bench.run(print_data=True, save_path=cli_args.save_path)
@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Accuracy tests for the fused Triton bilinear position-embedding kernel.
Compares ``triton_pos_embed_interpolate`` against the pure-PyTorch
``pos_embed_interpolate_native`` across a variety of grid shapes and dtypes.
"""
import pytest
import torch
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm.model_executor.models.qwen3_vl import (
pos_embed_interpolate_native,
triton_pos_embed_interpolate,
)
# Dtypes every test is parametrized over.
DTYPES = [torch.float32, torch.bfloat16]
# Qwen3-VL default
NUM_GRID_PER_SIDE = 48
SPATIAL_MERGE_SIZE = 2
HIDDEN_DIM = 1152
# (t, h, w) grids: 4 square + 4 non-square, with every h and w divisible by
# spatial_merge_size=2.
SQUARE_GRIDS = [(1, side, side) for side in (4, 16, 32, 48)]
NON_SQUARE_GRIDS = [
    (1, height, width)
    for height, width in ((8, 16), (14, 20), (32, 48), (60, 80))
]
ALL_GRIDS = [*SQUARE_GRIDS, *NON_SQUARE_GRIDS]
@pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
@pytest.mark.parametrize("dtype", DTYPES, ids=lambda d: str(d).split(".")[-1])
@pytest.mark.parametrize(
    "grid_thw",
    ALL_GRIDS,
    ids=[f"{t}x{h}x{w}" for t, h, w in ALL_GRIDS],
)
def test_triton_matches_native(
    grid_thw: tuple[int, int, int],
    dtype: torch.dtype,
) -> None:
    """Check the Triton kernel against the native PyTorch path.

    Both implementations receive the same seeded random embedding table and
    must return tensors of equal shape whose values agree within a
    dtype-specific tolerance.
    """
    t, h, w = grid_thw

    # Scale to match real Qwen3-VL pos_embed weight distribution (std~0.23).
    torch.manual_seed(42)
    embed_weight = torch.randn(
        NUM_GRID_PER_SIDE * NUM_GRID_PER_SIDE,
        HIDDEN_DIM,
        device="cuda",
        dtype=dtype,
    ).mul(0.25)

    # Identical positional arguments for both implementations.
    call_args = (embed_weight, t, h, w, NUM_GRID_PER_SIDE, SPATIAL_MERGE_SIZE, dtype)
    native_out = pos_embed_interpolate_native(*call_args)
    triton_out = triton_pos_embed_interpolate(*call_args)

    assert native_out.shape == triton_out.shape, (
        f"Shape mismatch: native {native_out.shape} vs triton {triton_out.shape}"
    )

    # The Triton path multiplies by a precomputed h/w scale while the native
    # path uses torch.linspace; that can produce single-ULP differences in a
    # handful of output elements, hence the non-zero tolerances.
    tolerances = {
        torch.float32: (5e-5, 1e-5),
        torch.bfloat16: (1e-2, 1e-2),
    }
    atol, rtol = tolerances[dtype]
    torch.testing.assert_close(triton_out, native_out, atol=atol, rtol=rtol)
@pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
@pytest.mark.parametrize("dtype", DTYPES, ids=lambda d: str(d).split(".")[-1])
def test_temporal_repeat(dtype: torch.dtype) -> None:
    """Check that t > 1 tiles the t == 1 output exactly.

    The Triton result for ``t = 3`` must be bit-identical to the ``t = 1``
    result repeated three times along the first dimension.
    """
    h = w = 16
    t_single, t_multi = 1, 3

    # Scale to match real Qwen3-VL pos_embed weight distribution (std~0.23).
    torch.manual_seed(42)
    embed_weight = torch.randn(
        NUM_GRID_PER_SIDE * NUM_GRID_PER_SIDE,
        HIDDEN_DIM,
        device="cuda",
        dtype=dtype,
    ).mul(0.25)

    def run_kernel(t: int) -> torch.Tensor:
        # Only the temporal size differs between the two launches.
        return triton_pos_embed_interpolate(
            embed_weight,
            t,
            h,
            w,
            NUM_GRID_PER_SIDE,
            SPATIAL_MERGE_SIZE,
            dtype,
        )

    out_single = run_kernel(t_single)
    out_multi = run_kernel(t_multi)

    # Zero tolerance: every repeat must match exactly.
    torch.testing.assert_close(
        out_multi, out_single.repeat(t_multi, 1), atol=0, rtol=0
    )
+209 -54
View File
@@ -96,6 +96,7 @@ from vllm.multimodal.processing import (
from vllm.sequence import IntermediateTensors
from vllm.tokenizers.protocol import TokenizerLike
from vllm.tokenizers.registry import cached_tokenizer_from_config
from vllm.triton_utils import HAS_TRITON, tl, triton
from vllm.utils.collection_utils import is_list_of
from vllm.utils.math_utils import round_up
@@ -145,6 +146,201 @@ logger = init_logger(__name__)
# of the maximum size.
DUMMY_VIDEO_NUM_FRAMES = 2048
# ---------------------------------------------------------------------------
# Triton kernel: fused bilinear position-embedding interpolation
# ---------------------------------------------------------------------------
# Replaces many small eager-mode CUDA kernels with a single launch.
# The spatial-merge reorder is baked into the index math so the output
# is ready to be added to the patch embeddings directly.
# ---------------------------------------------------------------------------
if HAS_TRITON:

    @triton.jit
    def _bilinear_pos_embed_kernel(
        embed_ptr,
        output_ptr,
        H,
        W,
        h_scale,
        w_scale,
        NUM_GRID: tl.constexpr,
        M_SIZE: tl.constexpr,
        HIDDEN_DIM: tl.constexpr,
        BLOCK_D: tl.constexpr,
    ):
        """Fused bilinear pos-embed interpolation with spatial-merge reorder.

        Program ``pid`` writes ``HIDDEN_DIM`` values starting at element
        ``pid * HIDDEN_DIM`` of the output. The spatial position is derived
        from ``pid % (H * W)``, so the same spatial pattern repeats every
        ``H * W`` programs.

        Args:
            embed_ptr: Source embedding table. Grid cell ``(r, c)`` is read
                from ``HIDDEN_DIM`` consecutive elements starting at
                ``(r * NUM_GRID + c) * HIDDEN_DIM``.
            output_ptr: Destination buffer; its element type also selects the
                precision of the multiply-accumulate.
            H: Output grid height.
            W: Output grid width.
            h_scale: Factor mapping an output row to a fractional source row.
            w_scale: Factor mapping an output column to a fractional source
                column.
            NUM_GRID: Side length of the square source grid.
            M_SIZE: Side length of a spatial-merge block.
            HIDDEN_DIM: Number of elements per embedding row.
            BLOCK_D: Number of columns handled per loop iteration; columns at
                or beyond ``HIDDEN_DIM`` are masked off.
        """
        pid = tl.program_id(0)

        # Position within one H x W frame (the pattern repeats per frame).
        total_spatial = H * W
        spatial_idx = pid % total_spatial

        # Undo the spatial-merge ordering: consecutive indices walk an
        # M_SIZE x M_SIZE block row-major, and blocks are laid out row-major
        # across the grid.
        num_blocks_w = W // M_SIZE
        block_idx = spatial_idx // (M_SIZE * M_SIZE)
        local_idx = spatial_idx % (M_SIZE * M_SIZE)
        br = block_idx // num_blocks_w
        bc = block_idx % num_blocks_w
        lr = local_idx // M_SIZE
        lc = local_idx % M_SIZE

        row = br * M_SIZE + lr
        col = bc * M_SIZE + lc

        # Fractional coordinates in the source grid.
        h_frac = row.to(tl.float32) * h_scale
        w_frac = col.to(tl.float32) * w_scale

        # Lower neighbours, and upper neighbours clamped to the last cell.
        hf = tl.math.floor(h_frac).to(tl.int32)
        wf = tl.math.floor(w_frac).to(tl.int32)
        hc = tl.minimum(hf + 1, NUM_GRID - 1)
        wc = tl.minimum(wf + 1, NUM_GRID - 1)

        dh = h_frac - hf.to(tl.float32)
        dw = w_frac - wf.to(tl.float32)

        # Bilinear weights; w11 is reused so dh * dw is computed only once.
        w11 = dh * dw
        w10 = dh - w11
        w01 = dw - w11
        w00 = 1.0 - dh - w01

        off00 = (hf * NUM_GRID + wf) * HIDDEN_DIM
        off01 = (hf * NUM_GRID + wc) * HIDDEN_DIM
        off10 = (hc * NUM_GRID + wf) * HIDDEN_DIM
        off11 = (hc * NUM_GRID + wc) * HIDDEN_DIM

        # Promote to 64-bit before scaling: pid * HIDDEN_DIM computed in
        # 32-bit would wrap once the output exceeds 2**31 elements.
        out_off = pid.to(tl.int64) * HIDDEN_DIM

        # Cast weights to output dtype so the multiply-accumulate stays
        # in the same precision as the native PyTorch implementation.
        out_dtype = output_ptr.dtype.element_ty
        w00_c = w00.to(out_dtype)
        w01_c = w01.to(out_dtype)
        w10_c = w10.to(out_dtype)
        w11_c = w11.to(out_dtype)

        for d in tl.range(0, HIDDEN_DIM, BLOCK_D):
            cols = d + tl.arange(0, BLOCK_D)
            mask = cols < HIDDEN_DIM

            e00 = tl.load(embed_ptr + off00 + cols, mask=mask)
            e01 = tl.load(embed_ptr + off01 + cols, mask=mask)
            e10 = tl.load(embed_ptr + off10 + cols, mask=mask)
            e11 = tl.load(embed_ptr + off11 + cols, mask=mask)

            val = w00_c * e00 + w01_c * e01 + w10_c * e10 + w11_c * e11
            tl.store(output_ptr + out_off + cols, val, mask=mask)
def triton_pos_embed_interpolate(
    embed_weight: torch.Tensor,
    t: int,
    h: int,
    w: int,
    num_grid_per_side: int,
    m_size: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Run the fused Triton interpolation kernel for a single (t, h, w) grid.

    One kernel program is launched per output row.

    Args:
        embed_weight: Embedding table; its second dimension is the hidden
            size.
        t: Temporal size of the grid.
        h: Grid height; must be divisible by ``m_size``.
        w: Grid width; must be divisible by ``m_size``.
        num_grid_per_side: Side length of the source embedding grid.
        m_size: Spatial merge size.
        dtype: Dtype of the returned tensor.

    Returns:
        A ``(t * h * w, hidden_dim)`` tensor allocated on
        ``embed_weight``'s device, holding the bilinearly-interpolated
        position embeddings in spatial-merge order.
    """
    assert h % m_size == 0 and w % m_size == 0, (
        f"h={h} and w={w} must be divisible by m_size={m_size}"
    )

    hidden_dim = embed_weight.shape[1]
    num_rows = t * h * w

    # Source-grid step per output row / column. A single row or column maps
    # to coordinate 0, which also avoids dividing by zero.
    grid_extent = float(num_grid_per_side - 1)
    if h > 1:
        h_scale = grid_extent / float(h - 1)
    else:
        h_scale = 0.0
    if w > 1:
        w_scale = grid_extent / float(w - 1)
    else:
        w_scale = 0.0

    output = torch.empty(
        (num_rows, hidden_dim),
        dtype=dtype,
        device=embed_weight.device,
    )

    # Block width is the hidden size rounded up to a power of two.
    block_d = triton.next_power_of_2(hidden_dim)

    launch_grid = (num_rows,)
    _bilinear_pos_embed_kernel[launch_grid](
        embed_weight,
        output,
        h,
        w,
        h_scale,
        w_scale,
        num_grid_per_side,
        m_size,
        hidden_dim,
        block_d,
    )
    return output
def pos_embed_interpolate_native(
    embed_weight: torch.Tensor,
    t: int,
    h: int,
    w: int,
    num_grid_per_side: int,
    m_size: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Bilinearly interpolate position embeddings with eager PyTorch ops.

    Args:
        embed_weight: Embedding table indexed by ``row * num_grid_per_side +
            col``; its second dimension is the hidden size.
        t: Temporal size; the spatial result is repeated ``t`` times.
        h: Target grid height; must be divisible by ``m_size``.
        w: Target grid width; must be divisible by ``m_size``.
        num_grid_per_side: Side length of the source embedding grid.
        m_size: Spatial merge size used to reorder the output rows.
        dtype: Dtype of the interpolation weights and of the result.

    Returns:
        A ``(t * h * w, hidden_dim)`` tensor with the interpolated
        embeddings in spatial-merge order.
    """
    assert h % m_size == 0 and w % m_size == 0, (
        f"h={h} and w={w} must be divisible by m_size={m_size}"
    )

    device = embed_weight.device
    hidden_dim = embed_weight.shape[1]
    last_cell = num_grid_per_side - 1

    def axis_terms(size: int):
        # Evenly spaced source coordinates along one axis, their lower and
        # (clamped) upper neighbour cells, and the fractional remainder.
        coords = torch.linspace(
            0, last_cell, size, dtype=torch.float32, device=device
        )
        lower = coords.to(torch.long)
        upper = torch.clamp(lower + 1, max=last_cell)
        return lower, upper, coords - lower

    h_floor, h_ceil, dh = axis_terms(h)
    w_floor, w_ceil, dw = axis_terms(w)

    # Bilinear weights on the full (h, w) grid; w11 is reused so that the
    # dh * dw product is computed only once.
    dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
    w11 = dh_grid * dw_grid
    w10 = dh_grid - w11
    w01 = dw_grid - w11
    w00 = 1 - dh_grid - w01
    weights = torch.stack((w00, w01, w10, w11), dim=0).reshape(4, -1, 1)
    weights = weights.to(dtype=dtype)

    # Flat table index of each of the four corners, in the same
    # (00, 01, 10, 11) order as the weights.
    corners = (
        (h_floor, w_floor),
        (h_floor, w_ceil),
        (h_ceil, w_floor),
        (h_ceil, w_ceil),
    )
    indices = torch.stack(
        [
            (rows * num_grid_per_side)[:, None] + cols[None, :]
            for rows, cols in corners
        ]
    ).reshape(4, -1)

    # Advanced indexing returns a fresh tensor, so scaling it in place
    # leaves ``embed_weight`` untouched.
    gathered = embed_weight[indices]
    gathered.mul_(weights)
    blended = gathered.sum(dim=0)

    # Regroup rows so each m_size x m_size block is contiguous.
    merged = (
        blended.reshape(h // m_size, m_size, w // m_size, m_size, hidden_dim)
        .permute(0, 2, 1, 3, 4)
        .reshape(1, -1, hidden_dim)
    )

    # Every temporal slice shares the same spatial embeddings.
    tiled = merged.expand(t, -1, -1).reshape(-1, hidden_dim)
    return tiled.to(dtype=dtype)
class Qwen3_VisionPatchEmbed(nn.Module):
def __init__(
@@ -470,63 +666,22 @@ class Qwen3_VisionTransformer(nn.Module):
return cos_combined, sin_combined
def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
    """Interpolate position embeddings for every grid and concatenate them.

    Uses the fused Triton kernel when Triton is available and the eager
    PyTorch implementation otherwise.

    Args:
        grid_thw: One ``[t, h, w]`` entry per grid.

    Returns:
        The per-grid results concatenated along dimension 0, in the order
        the grids were given.
    """
    interpolate_fn = (
        triton_pos_embed_interpolate if HAS_TRITON else pos_embed_interpolate_native
    )

    outputs = []
    for t, h, w in grid_thw:
        outputs.append(
            interpolate_fn(
                self.pos_embed.weight,
                t,
                h,
                w,
                self.num_grid_per_side,
                self.spatial_merge_size,
                self.dtype,
            )
        )
    return torch.cat(outputs, dim=0)
def prepare_encoder_metadata(