tensorrt_llm/inputs/evs.py

# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
import torch


def compute_retained_tokens_count(video_size: torch.LongTensor,
                                  spatial_merge_size: int,
                                  pruning_ratio: float) -> int:
"""
Compute the number of retained tokens for a given video.
Method ensures that we retain all the tokens from the first frame
regardless of the pruning rate.
Args:
video_size: The size of the video in the format of (T, H, W).
spatial_merge_size: The size of the spatial merge.
pruning_ratio: The pruning ratio.
Returns:
The number of retained tokens.
"""
    # Note on why map(int, ...) is used here:
    # In vLLM, a rounding discrepancy was observed between a Tensor input and
    # a tuple-of-ints input. The tuple of ints came from the preprocessing
    # stage, while the actual forward() received a Tensor. The explicit cast
    # keeps the number of output tokens the same in both cases.
    T, H, W = map(int, video_size)
    min_num_tokens = (H // spatial_merge_size) * (W // spatial_merge_size)
    evs_num_tokens = int(T * min_num_tokens * (1 - pruning_ratio))
    return max(min_num_tokens, evs_num_tokens)
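
# Worked example (illustrative numbers, not from the source): for a video of
# size (T=8, H=28, W=28) with spatial_merge_size=2 and pruning_ratio=0.75:
#   min_num_tokens = (28 // 2) * (28 // 2) = 196       # one merged frame
#   evs_num_tokens = int(8 * 196 * (1 - 0.75)) = 392
#   retained tokens = max(196, 392) = 392
# The max() guarantees at least one full frame survives even at high ratios.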


def compute_retention_mask(
    video_embeds: torch.Tensor,
    video_size: torch.LongTensor,
    spatial_merge_size: int,
    pruning_ratio: float,
    flatten_output: bool = True,
) -> torch.Tensor:
"""
Computes the retention mask for input video embeddings.
Args:
video_embeds (`torch.Tensor`): The input video embeddings
of shape `(T * H * W // spatial_merge_size ^ 2, hidden_size)`
or shape `(T, H * W // spatial_merge_size ^ 2, hidden_size)`.
video_size (`torch.LongTensor` of shape `(3)`):
The temporal, height and width of video.
spatial_merge_size: Size reduction for rows & cols dimensions.
pruning_ratio: (`float`): Pruning ratio factor [0,1)
flatten_output: (`bool`): Whether to flatten the output mask.
Returns:
`torch.Tensor`: The retention mask for the video embeddings of
`(T * H * W // spatial_merge_size ^ 2)` shape.
"""
    T, H, W = video_size
    # Use reshape instead of einops to avoid graph breaks.
    video_embeds = video_embeds.reshape(
        T,
        H // spatial_merge_size,
        W // spatial_merge_size,
        video_embeds.size(-1),
    )

    # Core EVS: score each token by its cosine dissimilarity to the token at
    # the same spatial location in the previous frame.
    similarity = torch.nn.functional.cosine_similarity(video_embeds[1:, ...],
                                                       video_embeds[:-1, ...],
                                                       dim=-1)
    dissimilarity = 1 - similarity

    # Always include all tokens from the first frame: give them a sentinel
    # score (255) larger than any real dissimilarity, which lies in [0, 2].
    dissimilarity = torch.cat(
        [255 * torch.ones_like(video_embeds[:1, :, :, 0]), dissimilarity],
        dim=0)

    # Keep the most dissimilar tokens; the stable sort makes the selection
    # deterministic when scores tie.
    dissimilarity_flat = dissimilarity.view(-1)
    order = torch.argsort(dissimilarity_flat,
                          dim=-1,
                          descending=True,
                          stable=True)
    retain_num_tokens = compute_retained_tokens_count(video_size,
                                                      spatial_merge_size,
                                                      pruning_ratio)
    topk_indices = order[:retain_num_tokens]

    retention_mask = torch.zeros_like(dissimilarity_flat, dtype=torch.bool)
    retention_mask[topk_indices] = True
    retention_mask = retention_mask.reshape(dissimilarity.size())

    mask = retention_mask.view(-1) if flatten_output else retention_mask
    return mask
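

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; the shapes and values below
    # are assumptions for demonstration, not part of the module's API).
    T, H, W = 4, 8, 8
    merge_size, hidden_size, ratio = 2, 16, 0.5
    video_size = torch.tensor([T, H, W], dtype=torch.long)
    video_embeds = torch.randn(
        T * (H // merge_size) * (W // merge_size), hidden_size)

    mask = compute_retention_mask(video_embeds, video_size, merge_size, ratio)
    # The mask retains exactly the count reported by
    # compute_retained_tokens_count, and indexes the flattened embeddings.
    assert mask.sum().item() == compute_retained_tokens_count(
        video_size, merge_size, ratio)
    pruned_embeds = video_embeds[mask]
    print(f"kept {pruned_embeds.size(0)} of {video_embeds.size(0)} tokens")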