diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py index 3371bb6fc5..03f15c37b4 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py +++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py @@ -10,6 +10,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize from tensorrt_llm.inputs.multimodal import MultimodalParams +from ..._utils import nvtx_range_debug from ...functional import RopeEmbeddingUtils, RotaryScalingType from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt, register_input_processor) @@ -41,7 +42,6 @@ class Qwen2VLInputProcessorBase(InputProcessor): trust_remote_code=trust_remote_code) self.tllm_multimodal_token_id = self.model_config.vocab_size + 1 - self._post_init_() @classmethod def get_rope_index( @@ -217,22 +217,6 @@ class Qwen2VLInputProcessorBase(InputProcessor): mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas - def _post_init_(self): - _, rotary_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin( - num_pos=self.model_config.max_position_embeddings, - dim=int(self.model_config.hidden_size / - self.model_config.num_attention_heads), - theta=float(self.model_config.rope_theta), - scale_type=RotaryScalingType.mrope) - self.rotary_cos_sin = torch.from_numpy(rotary_cos_sin) - self.rotary_cos_sin = self.rotary_cos_sin.reshape( - self.model_config.max_position_embeddings, - int(self.model_config.hidden_size / - self.model_config.num_attention_heads / 2), 2) - - self.cos_ori = self.rotary_cos_sin[:, :, 0] - self.sin_ori = self.rotary_cos_sin[:, :, 1] - def get_num_tokens_per_image( self, *, @@ -304,30 +288,8 @@ class Qwen2VLInputProcessorBase(InputProcessor): self.model_config, input_ids, image_grid_thw, video_grid_thw, attention_mask, second_per_grid_ts) - mrope_position_ids = mrope_position_ids.transpose(1, 0) - mrope_position_ids_padding = torch.zeros( - mrope_position_ids.shape[:-1] + - (self.model_config.max_position_embeddings, ), - dtype=torch.int32, - device=input_ids.device) - mrope_position_ids_padding[:, :, :mrope_position_ids. - shape[-1]] = mrope_position_ids - cos = self.cos_ori[mrope_position_ids_padding] - sin = self.sin_ori[mrope_position_ids_padding] - - mrope_section = [16, 24, 24] - cos = torch.cat([ - m[:, i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1)) - ], - dim=-1).unsqueeze(-1) - sin = torch.cat([ - m[:, i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1)) - ], - dim=-1).unsqueeze(-1) - concat_cos_sin = torch.concatenate((cos, sin), axis=-1) - concat_cos_sin = concat_cos_sin.reshape(concat_cos_sin.shape[0], -1) mrope_config = {} - mrope_config['mrope_rotary_cos_sin'] = concat_cos_sin.to('cpu') + mrope_config['mrope_position_ids'] = mrope_position_ids.to('cpu') mrope_config['mrope_position_deltas'] = mrope_position_deltas.to( 'cpu').to(torch.int32) return mrope_config @@ -340,10 +302,9 @@ class Qwen2VLInputProcessorBase(InputProcessor): ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: text_prompt, mm_data, mm_processor_kwargs = inputs.get("prompt"), \ inputs.get("multi_modal_data", {}), inputs.get("mm_processor_kwargs", {}) - - processed_inputs = self._preprocess(text_prompt, mm_data, - mm_processor_kwargs) - + with nvtx_range_debug("transformers input preprocess"): + processed_inputs = self._preprocess(text_prompt, mm_data, + mm_processor_kwargs) if not mm_data: fused_input_ids = processed_inputs['input_ids'] return fused_input_ids.to(torch.int32).tolist(), {} @@ -513,8 +474,25 @@ class Qwen2VLModelBase(PreTrainedModel): self.post_config() self.is_loaded = True + def init_rotary_cos_sin_ori(self): + _, rotary_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin( + num_pos=self.model_config.pretrained_config.max_position_embeddings, + dim=int(self.model_config.pretrained_config.hidden_size / + self.model_config.pretrained_config.num_attention_heads), + theta=float(self.model_config.pretrained_config.rope_theta), + scale_type=RotaryScalingType.mrope) + self.rotary_cos_sin = torch.from_numpy(rotary_cos_sin).to(self.device) + self.rotary_cos_sin = self.rotary_cos_sin.reshape( + self.model_config.pretrained_config.max_position_embeddings, + int(self.model_config.pretrained_config.hidden_size / + self.model_config.pretrained_config.num_attention_heads / 2), 2) + + self.cos_ori = self.rotary_cos_sin[:, :, 0] + self.sin_ori = self.rotary_cos_sin[:, :, 1] + def load_weights(self, weights): self.llm.load_weights(weights) + self.init_rotary_cos_sin_ori() def infer_max_seq_len(self) -> int: return self.llm.infer_max_seq_len() @@ -566,6 +544,49 @@ class Qwen2VLModelBase(PreTrainedModel): return batched_mrope_config + def add_rotary_cos_sin(self, multimodal_params: List[MultimodalParams]): + for param in multimodal_params: + mrope_config = param.multimodal_data.get('mrope_config') + if mrope_config: + mrope_position_ids = mrope_config.get('mrope_position_ids', + None) + if mrope_position_ids is None: + continue + mrope_position_ids = mrope_position_ids.transpose(1, 0) + mrope_position_ids_padding = torch.zeros( + mrope_position_ids.shape[:-1] + + (self.model_config.pretrained_config. + max_position_embeddings, ), + dtype=torch.int32, + device=mrope_position_ids.device) + mrope_position_ids_padding[:, :, :mrope_position_ids. + shape[-1]] = mrope_position_ids + + mrope_position_ids_padding = mrope_position_ids_padding.to( + self.cos_ori.device) + cos = self.cos_ori[mrope_position_ids_padding] + sin = self.sin_ori[mrope_position_ids_padding] + + mrope_section = [16, 24, 24] + cos = torch.cat([ + m[:, i % 3] + for i, m in enumerate(cos.split(mrope_section, dim=-1)) + ], + dim=-1).unsqueeze(-1) + sin = torch.cat([ + m[:, i % 3] + for i, m in enumerate(sin.split(mrope_section, dim=-1)) + ], + dim=-1).unsqueeze(-1) + concat_cos_sin = torch.concatenate((cos, sin), axis=-1) + concat_cos_sin = concat_cos_sin.reshape(concat_cos_sin.shape[0], + -1) + + mrope_config['mrope_rotary_cos_sin'] = concat_cos_sin.to( + self.device) + + return multimodal_params + @torch.inference_mode() def forward( self, @@ -585,6 +606,8 @@ class Qwen2VLModelBase(PreTrainedModel): ) multimodal_params = kwargs.get("multimodal_params", []) + multimodal_params = self.add_rotary_cos_sin(multimodal_params) + mm_embeds = [] mrope_config = {}