mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Perf] triton bilinear_pos_embed kernel for ViT (#37948)
Signed-off-by: Zhanda Zhu <zhandazhu@gmail.com>
This commit is contained in:
@@ -0,0 +1,162 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Benchmarks the fused Triton bilinear position-embedding kernel against
|
||||
# the pure-PyTorch (native) implementation used in Qwen3-VL ViT models.
|
||||
#
|
||||
# == Usage Examples ==
|
||||
#
|
||||
# Default benchmark:
|
||||
# python3 benchmark_vit_bilinear_pos_embed.py
|
||||
#
|
||||
# Custom parameters:
|
||||
# python3 benchmark_vit_bilinear_pos_embed.py --hidden-dim 1152 \
|
||||
# --num-grid-per-side 48 --save-path ./configs/vit_pos_embed/
|
||||
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.qwen3_vl import (
|
||||
pos_embed_interpolate_native,
|
||||
triton_pos_embed_interpolate,
|
||||
)
|
||||
from vllm.triton_utils import HAS_TRITON, triton
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
# Patch-grid shapes (h, w) swept by the benchmark: five square grids
# followed by two non-square ones.
h_w_configs = [(side, side) for side in (16, 32, 48, 64, 128)]
h_w_configs += [(32, 48), (60, 80)]

# Temporal extents to sweep (single frame only).
t_range = [1]

# Every (t, (h, w)) combination, iterated t-major.
configs = list(itertools.product(t_range, h_w_configs))
|
||||
|
||||
|
||||
def get_benchmark(
    num_grid_per_side: int,
    spatial_merge_size: int,
    hidden_dim: int,
    dtype: torch.dtype,
    device: str,
):
    """Build the perf-report benchmark comparing native vs. Triton pos-embed.

    Args:
        num_grid_per_side: Side length of the learned position-embedding
            grid; the embedding table has ``num_grid_per_side ** 2`` rows.
        spatial_merge_size: Merge block size forwarded to both
            implementations.
        hidden_dim: Number of columns of the embedding table.
        dtype: Element type of the random embedding table and of the output.
        device: Torch device string the embedding table is allocated on.

    Returns:
        The ``triton.testing.perf_report``-decorated benchmark; call
        ``.run(...)`` on it to sweep every entry of the module-level
        ``configs`` for both providers.
    """

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["t", "h_w"],
            x_vals=[list(cfg) for cfg in configs],
            line_arg="provider",
            line_vals=["native", "triton"],
            line_names=["Native (PyTorch)", "Triton"],
            styles=[("blue", "-"), ("red", "-")],
            ylabel="us",
            plot_name=(
                "vit-bilinear-pos-embed-"
                f"grid{num_grid_per_side}-"
                f"dim{hidden_dim}-"
                f"{dtype}"
            ),
            args={},
        )
    )
    def benchmark(t, h_w, provider):
        """Time one provider on a single (t, (h, w)) grid; returns microseconds."""
        h, w = h_w

        # Fixed seed so both providers are timed on identical weights.
        torch.manual_seed(42)
        embed_weight = (
            torch.randn(
                num_grid_per_side * num_grid_per_side,
                hidden_dim,
                device=device,
                dtype=dtype,
            )
            * 0.25
        )

        if provider == "native":
            interpolate_fn = pos_embed_interpolate_native
        else:
            assert HAS_TRITON, "Triton not available"
            interpolate_fn = triton_pos_embed_interpolate

        # do_bench returns one timing (in ms) per requested quantile, in the
        # order given: median, 20th percentile, 80th percentile.
        median_ms, low_ms, high_ms = triton.testing.do_bench(
            lambda: interpolate_fn(
                embed_weight,
                t,
                h,
                w,
                num_grid_per_side,
                spatial_merge_size,
                dtype,
            ),
            quantiles=[0.5, 0.2, 0.8],
        )

        # perf_report unpacks the result as (y, y_min, y_max). The metric is a
        # time, so the low quantile is the minimum and the high quantile the
        # maximum; the previous (ms, max, min) order is only right for
        # throughput-style metrics where larger time means smaller value.
        return 1000 * median_ms, 1000 * low_ms, 1000 * high_ms

    return benchmark
|
||||
|
||||
|
||||
if __name__ == "__main__":
    cli = FlexibleArgumentParser(
        description="Benchmark bilinear position embedding interpolation."
    )

    # Integer model-shape options as (flag, default, help text), registered
    # in this order.
    int_options = (
        (
            "--num-grid-per-side",
            48,
            "Position embedding grid size (default: 48 for Qwen3-VL)",
        ),
        (
            "--spatial-merge-size",
            2,
            "Spatial merge size (default: 2)",
        ),
        (
            "--hidden-dim",
            1152,
            "Embedding hidden dimension (default: 1152 for Qwen3-VL)",
        ),
    )
    for flag, default_value, help_text in int_options:
        cli.add_argument(flag, type=int, default=default_value, help=help_text)

    cli.add_argument(
        "--device",
        type=str,
        choices=["cuda:0", "cuda:1"],
        default="cuda:0",
    )
    cli.add_argument("--save-path", type=str, default="./vit_pos_embed/")
    opts = cli.parse_args()

    # The sweep is always run in bfloat16; there is no CLI switch for dtype.
    report = get_benchmark(
        opts.num_grid_per_side,
        opts.spatial_merge_size,
        opts.hidden_dim,
        torch.bfloat16,
        opts.device,
    )
    report.run(print_data=True, save_path=opts.save_path)
|
||||
@@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Accuracy tests for the fused Triton bilinear position-embedding kernel.
|
||||
|
||||
Compares ``triton_pos_embed_interpolate`` against the pure-PyTorch
|
||||
``pos_embed_interpolate_native`` across a variety of grid shapes and dtypes.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.model_executor.models.qwen3_vl import (
|
||||
pos_embed_interpolate_native,
|
||||
triton_pos_embed_interpolate,
|
||||
)
|
||||
|
||||
|
||||
# Element types the Triton kernel is checked with.
DTYPES = [torch.float32, torch.bfloat16]

# Qwen3-VL default
NUM_GRID_PER_SIDE, SPATIAL_MERGE_SIZE, HIDDEN_DIM = 48, 2, 1152

# (t, h, w) grids: four square and four non-square, with every h and w
# divisible by SPATIAL_MERGE_SIZE.
SQUARE_GRIDS = [(1, side, side) for side in (4, 16, 32, 48)]
NON_SQUARE_GRIDS = [(1, 8, 16), (1, 14, 20), (1, 32, 48), (1, 60, 80)]
ALL_GRIDS = [*SQUARE_GRIDS, *NON_SQUARE_GRIDS]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
@pytest.mark.parametrize("dtype", DTYPES, ids=lambda d: str(d).split(".")[-1])
@pytest.mark.parametrize(
    "grid_thw",
    ALL_GRIDS,
    ids=[f"{t}x{h}x{w}" for t, h, w in ALL_GRIDS],
)
def test_triton_matches_native(
    grid_thw: tuple[int, int, int],
    dtype: torch.dtype,
) -> None:
    """Check the Triton kernel against the eager PyTorch reference.

    Both implementations receive the same seeded random embedding table and
    must agree in shape and, within a per-dtype tolerance, in value.
    """
    t, h, w = grid_thw

    # The 0.25 factor mimics the spread of real Qwen3-VL pos_embed weights
    # (std~0.23).
    torch.manual_seed(42)
    embed_weight = (
        torch.randn(
            NUM_GRID_PER_SIDE * NUM_GRID_PER_SIDE,
            HIDDEN_DIM,
            device="cuda",
            dtype=dtype,
        )
        * 0.25
    )

    # Identical positional arguments for both implementations.
    call_args = (embed_weight, t, h, w, NUM_GRID_PER_SIDE, SPATIAL_MERGE_SIZE, dtype)
    native_out = pos_embed_interpolate_native(*call_args)
    triton_out = triton_pos_embed_interpolate(*call_args)

    assert native_out.shape == triton_out.shape, (
        f"Shape mismatch: native {native_out.shape} vs triton {triton_out.shape}"
    )

    # The kernel multiplies by a precomputed h/w scale while the native path
    # uses torch.linspace, so a handful of elements may differ by one ULP.
    # Tolerances are stored as (atol, rtol) per dtype.
    tolerances = {
        torch.float32: (5e-5, 1e-5),
        torch.bfloat16: (1e-2, 1e-2),
    }
    atol, rtol = tolerances[dtype]
    torch.testing.assert_close(triton_out, native_out, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
@pytest.mark.parametrize("dtype", DTYPES, ids=lambda d: str(d).split(".")[-1])
def test_temporal_repeat(dtype: torch.dtype) -> None:
    """Check that t > 1 yields the t == 1 output tiled t times, bit-exactly."""
    h = w = 16
    t_single, t_multi = 1, 3

    # The 0.25 factor mimics the spread of real Qwen3-VL pos_embed weights
    # (std~0.23).
    torch.manual_seed(42)
    embed_weight = (
        torch.randn(
            NUM_GRID_PER_SIDE * NUM_GRID_PER_SIDE,
            HIDDEN_DIM,
            device="cuda",
            dtype=dtype,
        )
        * 0.25
    )

    def interpolate(num_frames: int) -> torch.Tensor:
        # Same weights and spatial grid; only the temporal extent varies.
        return triton_pos_embed_interpolate(
            embed_weight,
            num_frames,
            h,
            w,
            NUM_GRID_PER_SIDE,
            SPATIAL_MERGE_SIZE,
            dtype,
        )

    out_single = interpolate(t_single)
    out_multi = interpolate(t_multi)

    # Zero tolerance: every frame must be an exact copy of the single frame.
    torch.testing.assert_close(
        out_multi,
        out_single.repeat(t_multi, 1),
        atol=0,
        rtol=0,
    )
|
||||
@@ -96,6 +96,7 @@ from vllm.multimodal.processing import (
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tokenizers.protocol import TokenizerLike
|
||||
from vllm.tokenizers.registry import cached_tokenizer_from_config
|
||||
from vllm.triton_utils import HAS_TRITON, tl, triton
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
@@ -145,6 +146,201 @@ logger = init_logger(__name__)
|
||||
# of the maximum size.
|
||||
DUMMY_VIDEO_NUM_FRAMES = 2048
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Triton kernel: fused bilinear position-embedding interpolation
|
||||
# ---------------------------------------------------------------------------
|
||||
# Replaces many small eager-mode CUDA kernels with a single launch.
|
||||
# The spatial-merge reorder is baked into the index math so the output
|
||||
# is ready to be added to the patch embeddings directly.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if HAS_TRITON:

    @triton.jit
    def _bilinear_pos_embed_kernel(
        embed_ptr,
        output_ptr,
        H,
        W,
        h_scale,
        w_scale,
        NUM_GRID: tl.constexpr,
        M_SIZE: tl.constexpr,
        HIDDEN_DIM: tl.constexpr,
        BLOCK_D: tl.constexpr,
    ):
        """Fused bilinear pos-embed interpolation with spatial-merge reorder.

        Each program writes one output row of HIDDEN_DIM elements. The
        embedding table is addressed as a dense row-major
        (NUM_GRID * NUM_GRID, HIDDEN_DIM) buffer.
        """
        pid = tl.program_id(0)
        total_spatial = H * W
        # Rows of different temporal frames share the same spatial pattern.
        spatial_idx = pid % total_spatial

        # Decode the merge-ordered index: M_SIZE x M_SIZE blocks are laid out
        # row-major, and positions inside a block are row-major as well.
        num_blocks_w = W // M_SIZE
        block_idx = spatial_idx // (M_SIZE * M_SIZE)
        local_idx = spatial_idx % (M_SIZE * M_SIZE)
        br = block_idx // num_blocks_w
        bc = block_idx % num_blocks_w
        lr = local_idx // M_SIZE
        lc = local_idx % M_SIZE
        row = br * M_SIZE + lr
        col = bc * M_SIZE + lc

        # Fractional coordinates of this output position in the source grid.
        h_frac = row.to(tl.float32) * h_scale
        w_frac = col.to(tl.float32) * w_scale

        hf = tl.math.floor(h_frac).to(tl.int32)
        wf = tl.math.floor(w_frac).to(tl.int32)
        # Clamp the upper neighbours so the last row/column stays in bounds.
        hc = tl.minimum(hf + 1, NUM_GRID - 1)
        wc = tl.minimum(wf + 1, NUM_GRID - 1)

        dh = h_frac - hf.to(tl.float32)
        dw = w_frac - wf.to(tl.float32)
        # Bilinear weights, reusing w11 = dh * dw instead of four products.
        w11 = dh * dw
        w10 = dh - w11
        w01 = dw - w11
        w00 = 1.0 - dh - w01

        off00 = (hf * NUM_GRID + wf) * HIDDEN_DIM
        off01 = (hf * NUM_GRID + wc) * HIDDEN_DIM
        off10 = (hc * NUM_GRID + wf) * HIDDEN_DIM
        off11 = (hc * NUM_GRID + wc) * HIDDEN_DIM
        # pid is a 32-bit value; widen before scaling so the output offset
        # does not wrap once t * h * w * HIDDEN_DIM reaches 2**31.
        out_off = pid.to(tl.int64) * HIDDEN_DIM

        # Cast weights to output dtype so the multiply-accumulate stays
        # in the same precision as the native PyTorch implementation.
        out_dtype = output_ptr.dtype.element_ty
        w00_c = w00.to(out_dtype)
        w01_c = w01.to(out_dtype)
        w10_c = w10.to(out_dtype)
        w11_c = w11.to(out_dtype)

        for d in tl.range(0, HIDDEN_DIM, BLOCK_D):
            cols = d + tl.arange(0, BLOCK_D)
            # BLOCK_D may exceed HIDDEN_DIM; mask off the padding lanes.
            mask = cols < HIDDEN_DIM

            e00 = tl.load(embed_ptr + off00 + cols, mask=mask)
            e01 = tl.load(embed_ptr + off01 + cols, mask=mask)
            e10 = tl.load(embed_ptr + off10 + cols, mask=mask)
            e11 = tl.load(embed_ptr + off11 + cols, mask=mask)

            val = w00_c * e00 + w01_c * e01 + w10_c * e10 + w11_c * e11

            tl.store(output_ptr + out_off + cols, val, mask=mask)

    def triton_pos_embed_interpolate(
        embed_weight: torch.Tensor,
        t: int,
        h: int,
        w: int,
        num_grid_per_side: int,
        m_size: int,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Launch the fused Triton kernel for one (t,h,w) grid.

        Args:
            embed_weight: 2-D embedding table; its second dimension is the
                hidden size. The kernel indexes it as a dense row-major
                buffer with ``num_grid_per_side ** 2`` rows.
            t: Number of temporal frames; the spatial pattern is written once
                per frame.
            h: Target grid height, must be divisible by ``m_size``.
            w: Target grid width, must be divisible by ``m_size``.
            num_grid_per_side: Side length of the source embedding grid.
            m_size: Spatial merge block size.
            dtype: Element type of the returned tensor.

        Returns:
            A tensor of shape ``(t * h * w, hidden_dim)`` with the
            bilinearly-interpolated position embeddings in spatial-merge
            order, allocated on ``embed_weight``'s device.

        Raises:
            AssertionError: If ``h`` or ``w`` is not divisible by ``m_size``.
        """
        assert h % m_size == 0 and w % m_size == 0, (
            f"h={h} and w={w} must be divisible by m_size={m_size}"
        )
        hidden_dim = embed_weight.shape[1]
        total_out = t * h * w
        output = torch.empty(
            total_out,
            hidden_dim,
            device=embed_weight.device,
            dtype=dtype,
        )

        # Step between neighbouring target positions measured in source-grid
        # units; a single-point axis maps to source index 0.
        h_scale = float(num_grid_per_side - 1) / float(h - 1) if h > 1 else 0.0
        w_scale = float(num_grid_per_side - 1) / float(w - 1) if w > 1 else 0.0

        # One block spans a whole row, so the kernel's column loop runs once.
        BLOCK_D = triton.next_power_of_2(hidden_dim)

        # One program per output row.
        _bilinear_pos_embed_kernel[(total_out,)](
            embed_weight,
            output,
            h,
            w,
            h_scale,
            w_scale,
            num_grid_per_side,
            m_size,
            hidden_dim,
            BLOCK_D,
        )
        return output
|
||||
|
||||
|
||||
def pos_embed_interpolate_native(
    embed_weight: torch.Tensor,
    t: int,
    h: int,
    w: int,
    num_grid_per_side: int,
    m_size: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Bilinearly resample the position-embedding grid with eager PyTorch ops.

    The ``num_grid_per_side`` x ``num_grid_per_side`` grid stored row-wise in
    ``embed_weight`` is sampled at ``h`` x ``w`` evenly spaced positions, the
    rows are reordered into ``m_size`` x ``m_size`` merge blocks, and the
    result is repeated for each of the ``t`` frames.

    Returns a ``(t * h * w, hidden_dim)`` tensor of type ``dtype``.
    Raises ``AssertionError`` when ``h`` or ``w`` is not a multiple of
    ``m_size``.
    """
    assert h % m_size == 0 and w % m_size == 0, (
        f"h={h} and w={w} must be divisible by m_size={m_size}"
    )
    last_index = num_grid_per_side - 1
    width = embed_weight.shape[1]

    def _sample_axis(num_points: int):
        # Lower neighbour, clamped upper neighbour and fractional offset of
        # each of `num_points` evenly spaced source-grid coordinates.
        coords = torch.linspace(
            0,
            last_index,
            num_points,
            dtype=torch.float32,
            device=embed_weight.device,
        )
        lower = coords.to(torch.long)
        upper = torch.clamp(lower + 1, max=last_index)
        return lower, upper, coords - lower

    row_lo, row_hi, row_frac = _sample_axis(h)
    col_lo, col_hi, col_frac = _sample_axis(w)

    # Broadcast the per-axis fractions to an (h, w) grid.
    frac_y = row_frac[:, None]
    frac_x = col_frac[None, :]

    # Bilinear weights, reusing the frac_y * frac_x product for all four.
    w11 = frac_y * frac_x
    w10 = frac_y - w11
    w01 = frac_x - w11
    w00 = 1 - frac_y - w01

    # Flat table row for each corner, in the same order as the weights:
    # (lo, lo), (lo, hi), (hi, lo), (hi, hi).
    corner_pairs = (
        (row_lo, col_lo),
        (row_lo, col_hi),
        (row_hi, col_lo),
        (row_hi, col_hi),
    )
    indices = torch.stack(
        [rows[:, None] * num_grid_per_side + cols[None, :] for rows, cols in corner_pairs]
    ).reshape(4, -1)
    weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
    weights = weights.to(dtype=dtype)

    # Advanced indexing yields a fresh tensor, so scaling it in place leaves
    # embed_weight untouched.
    gathered = embed_weight[indices]
    gathered.mul_(weights)
    blended = gathered.sum(dim=0)

    # Regroup row-major (h, w) positions into m_size x m_size merge blocks.
    blended = blended.reshape(h // m_size, m_size, w // m_size, m_size, width)
    blended = blended.permute(0, 2, 1, 3, 4).reshape(1, -1, width)

    # Every temporal frame reuses the same spatial embeddings.
    tiled = blended.expand(t, -1, -1).reshape(-1, width)
    return tiled.to(dtype=dtype)
|
||||
|
||||
|
||||
class Qwen3_VisionPatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
@@ -470,63 +666,22 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
return cos_combined, sin_combined
|
||||
|
||||
def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
    """Interpolate position embeddings for every (t, h, w) grid.

    Each grid is handled by the fused Triton kernel when Triton is
    available and by the eager PyTorch implementation otherwise; both take
    the same arguments.

    Args:
        grid_thw: One ``(t, h, w)`` triple per item.

    Returns:
        The per-grid results concatenated along dim 0.
    """
    interpolate_fn = (
        triton_pos_embed_interpolate if HAS_TRITON else pos_embed_interpolate_native
    )
    outputs = [
        interpolate_fn(
            self.pos_embed.weight,
            t,
            h,
            w,
            self.num_grid_per_side,
            self.spatial_merge_size,
            self.dtype,
        )
        for t, h, w in grid_thw
    ]
    return torch.cat(outputs, dim=0)
|
||||
|
||||
def prepare_encoder_metadata(
|
||||
|
||||
Reference in New Issue
Block a user