[PERF] Move calculation of Qwen2-VL's rotary_cos_sin to the LLM worker process (#6004)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@centml.ai>
Author: Vadim Gimpelson <vadim.gimpelson@centml.ai>
Date: 2025-07-31 04:35:24 +04:00 (committed by GitHub)
Commit: 25cd4f215e
Parent: 0e16d1f070


@@ -10,6 +10,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from tensorrt_llm.inputs.multimodal import MultimodalParams
from ..._utils import nvtx_range_debug
from ...functional import RopeEmbeddingUtils, RotaryScalingType
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
register_input_processor)
@@ -41,7 +42,6 @@ class Qwen2VLInputProcessorBase(InputProcessor):
trust_remote_code=trust_remote_code)
self.tllm_multimodal_token_id = self.model_config.vocab_size + 1
self._post_init_()
@classmethod
def get_rope_index(
@@ -217,22 +217,6 @@ class Qwen2VLInputProcessorBase(InputProcessor):
mrope_position_deltas, device=input_ids.device).unsqueeze(1)
return position_ids, mrope_position_deltas
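
For context: get_rope_index returns one row of position ids per M-RoPE axis (temporal, height, width). A minimal sketch of the text-only degenerate case, where all three axes collapse to ordinary 1-D positions; this is hypothetical standalone code assuming no image/video grids and no padding:

import torch

input_ids = torch.tensor([[101, 2023, 2003, 102]])  # (bsz=1, seq=4)
seq_len = input_ids.shape[1]
# One identical row of positions per axis: shape (3, bsz, seq).
position_ids = torch.arange(seq_len).view(1, 1, -1).expand(3, 1, -1)
# No vision tokens, so the per-sequence position offset is zero.
mrope_position_deltas = torch.zeros(1, 1, dtype=torch.long)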
def _post_init_(self):
_, rotary_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
num_pos=self.model_config.max_position_embeddings,
dim=int(self.model_config.hidden_size /
self.model_config.num_attention_heads),
theta=float(self.model_config.rope_theta),
scale_type=RotaryScalingType.mrope)
self.rotary_cos_sin = torch.from_numpy(rotary_cos_sin)
self.rotary_cos_sin = self.rotary_cos_sin.reshape(
self.model_config.max_position_embeddings,
int(self.model_config.hidden_size /
self.model_config.num_attention_heads / 2), 2)
self.cos_ori = self.rotary_cos_sin[:, :, 0]
self.sin_ori = self.rotary_cos_sin[:, :, 1]
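
The table this removed hook built is a standard RoPE cos/sin cache. A minimal sketch of an equivalent construction in plain PyTorch, assuming the plugin utility follows the usual RoPE frequency formula (the real create_sinusoidal_positions_for_attention_plugin additionally handles rotary scaling types):

import torch

def make_rotary_cos_sin(num_pos: int, head_dim: int, theta: float) -> torch.Tensor:
    # Inverse frequency per rotary pair: theta ** (-2i / head_dim).
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(num_pos).float(), inv_freq)  # (num_pos, head_dim // 2)
    # Same layout as the reshape above: (num_pos, head_dim // 2, 2).
    return torch.stack((angles.cos(), angles.sin()), dim=-1)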
def get_num_tokens_per_image(
self,
*,
@@ -304,30 +288,8 @@ class Qwen2VLInputProcessorBase(InputProcessor):
self.model_config, input_ids, image_grid_thw, video_grid_thw,
attention_mask, second_per_grid_ts)
mrope_position_ids = mrope_position_ids.transpose(1, 0)
mrope_position_ids_padding = torch.zeros(
mrope_position_ids.shape[:-1] +
(self.model_config.max_position_embeddings, ),
dtype=torch.int32,
device=input_ids.device)
mrope_position_ids_padding[:, :, :mrope_position_ids.
shape[-1]] = mrope_position_ids
cos = self.cos_ori[mrope_position_ids_padding]
sin = self.sin_ori[mrope_position_ids_padding]
mrope_section = [16, 24, 24]
cos = torch.cat([
m[:, i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))
],
dim=-1).unsqueeze(-1)
sin = torch.cat([
m[:, i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))
],
dim=-1).unsqueeze(-1)
concat_cos_sin = torch.concatenate((cos, sin), axis=-1)
concat_cos_sin = concat_cos_sin.reshape(concat_cos_sin.shape[0], -1)
mrope_config = {}
mrope_config['mrope_rotary_cos_sin'] = concat_cos_sin.to('cpu')
mrope_config['mrope_position_ids'] = mrope_position_ids.to('cpu')
mrope_config['mrope_position_deltas'] = mrope_position_deltas.to(
'cpu').to(torch.int32)
return mrope_config
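
The [16, 24, 24] list is Qwen2-VL's mrope_section: the head_dim // 2 = 64 frequency slots are partitioned so the first 16 are rotated by the temporal position ids and the next two 24-slot groups by the height and width ids. A standalone sketch of that interleave (hypothetical helper; after the cos_ori gather above, cos has shape (bsz, 3, max_pos, 64)):

def interleave_mrope(cos: torch.Tensor, mrope_section=(16, 24, 24)) -> torch.Tensor:
    # Split the frequency axis into one group per M-RoPE axis...
    chunks = cos.split(list(mrope_section), dim=-1)
    # ...then take group i from axis i (0 = temporal, 1 = height, 2 = width).
    picked = [m[:, i % 3] for i, m in enumerate(chunks)]
    return torch.cat(picked, dim=-1)  # (bsz, max_pos, 64)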
@@ -340,10 +302,9 @@ class Qwen2VLInputProcessorBase(InputProcessor):
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
text_prompt, mm_data, mm_processor_kwargs = inputs.get("prompt"), \
inputs.get("multi_modal_data", {}), inputs.get("mm_processor_kwargs", {})
processed_inputs = self._preprocess(text_prompt, mm_data,
mm_processor_kwargs)
with nvtx_range_debug("transformers input preprocess"):
processed_inputs = self._preprocess(text_prompt, mm_data,
mm_processor_kwargs)
if not mm_data:
fused_input_ids = processed_inputs['input_ids']
return fused_input_ids.to(torch.int32).tolist(), {}
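
nvtx_range_debug only tags the HuggingFace preprocessing span for profiling. A rough equivalent using PyTorch's built-in NVTX bindings, which makes the span visible as a named range in Nsight Systems (run_preprocess is a placeholder for the _preprocess call):

import torch

torch.cuda.nvtx.range_push("transformers input preprocess")
processed_inputs = run_preprocess()  # placeholder for self._preprocess(...)
torch.cuda.nvtx.range_pop()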
@@ -513,8 +474,25 @@ class Qwen2VLModelBase(PreTrainedModel):
self.post_config()
self.is_loaded = True
def init_rotary_cos_sin_ori(self):
_, rotary_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
num_pos=self.model_config.pretrained_config.max_position_embeddings,
dim=int(self.model_config.pretrained_config.hidden_size /
self.model_config.pretrained_config.num_attention_heads),
theta=float(self.model_config.pretrained_config.rope_theta),
scale_type=RotaryScalingType.mrope)
self.rotary_cos_sin = torch.from_numpy(rotary_cos_sin).to(self.device)
self.rotary_cos_sin = self.rotary_cos_sin.reshape(
self.model_config.pretrained_config.max_position_embeddings,
int(self.model_config.pretrained_config.hidden_size /
self.model_config.pretrained_config.num_attention_heads / 2), 2)
self.cos_ori = self.rotary_cos_sin[:, :, 0]
self.sin_ori = self.rotary_cos_sin[:, :, 1]
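
A quick shape check connecting this table to the [16, 24, 24] section split used in add_rotary_cos_sin below (config numbers are illustrative, taken from Qwen2-VL-7B):

hidden_size, num_attention_heads = 3584, 28    # illustrative Qwen2-VL-7B values
head_dim = hidden_size // num_attention_heads  # 128
# cos_ori / sin_ori each hold head_dim // 2 = 64 frequency slots, exactly
# what mrope_section = [16, 24, 24] partitions.
assert head_dim // 2 == sum([16, 24, 24])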
def load_weights(self, weights):
self.llm.load_weights(weights)
self.init_rotary_cos_sin_ori()
def infer_max_seq_len(self) -> int:
return self.llm.infer_max_seq_len()
@@ -566,6 +544,49 @@ class Qwen2VLModelBase(PreTrainedModel):
return batched_mrope_config
def add_rotary_cos_sin(self, multimodal_params: List[MultimodalParams]):
for param in multimodal_params:
mrope_config = param.multimodal_data.get('mrope_config')
if mrope_config:
mrope_position_ids = mrope_config.get('mrope_position_ids',
None)
if mrope_position_ids is None:
continue
mrope_position_ids = mrope_position_ids.transpose(1, 0)
mrope_position_ids_padding = torch.zeros(
mrope_position_ids.shape[:-1] +
(self.model_config.pretrained_config.
max_position_embeddings, ),
dtype=torch.int32,
device=mrope_position_ids.device)
mrope_position_ids_padding[:, :, :mrope_position_ids.
shape[-1]] = mrope_position_ids
mrope_position_ids_padding = mrope_position_ids_padding.to(
self.cos_ori.device)
cos = self.cos_ori[mrope_position_ids_padding]
sin = self.sin_ori[mrope_position_ids_padding]
mrope_section = [16, 24, 24]
cos = torch.cat([
m[:, i % 3]
for i, m in enumerate(cos.split(mrope_section, dim=-1))
],
dim=-1).unsqueeze(-1)
sin = torch.cat([
m[:, i % 3]
for i, m in enumerate(sin.split(mrope_section, dim=-1))
],
dim=-1).unsqueeze(-1)
concat_cos_sin = torch.concatenate((cos, sin), axis=-1)
concat_cos_sin = concat_cos_sin.reshape(concat_cos_sin.shape[0],
-1)
mrope_config['mrope_rotary_cos_sin'] = concat_cos_sin.to(
self.device)
return multimodal_params
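
With this pass, only small integer position-id tensors cross the process boundary; the large cos/sin tensor is now materialized on the worker's GPU. A toy illustration of one request's payload before and after the call (hypothetical shapes, seq = 32):

mrope_config = {
    # Built on CPU by the input processor: (3 M-RoPE axes, bsz, seq).
    'mrope_position_ids': torch.zeros(3, 1, 32, dtype=torch.int32),
    'mrope_position_deltas': torch.zeros(1, 1, dtype=torch.int32),
}
# add_rotary_cos_sin(...) then attaches, on self.device:
#   mrope_config['mrope_rotary_cos_sin']  # (bsz, max_position_embeddings * head_dim)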
@torch.inference_mode()
def forward(
self,
@@ -585,6 +606,8 @@ class Qwen2VLModelBase(PreTrainedModel):
)
multimodal_params = kwargs.get("multimodal_params", [])
multimodal_params = self.add_rotary_cos_sin(multimodal_params)
mm_embeds = []
mrope_config = {}