Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[PERF] Move calculation of Qwen2-VL's rotary_cos_sin to the LLM worker process (#6004)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@centml.ai>
parent: 0e16d1f070
commit: 25cd4f215e
@@ -10,6 +10,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize

from tensorrt_llm.inputs.multimodal import MultimodalParams

from ..._utils import nvtx_range_debug
from ...functional import RopeEmbeddingUtils, RotaryScalingType
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
                       register_input_processor)
@@ -41,7 +42,6 @@ class Qwen2VLInputProcessorBase(InputProcessor):
            trust_remote_code=trust_remote_code)

        self.tllm_multimodal_token_id = self.model_config.vocab_size + 1
        self._post_init_()

    @classmethod
    def get_rope_index(
@@ -217,22 +217,6 @@ class Qwen2VLInputProcessorBase(InputProcessor):
            mrope_position_deltas, device=input_ids.device).unsqueeze(1)
        return position_ids, mrope_position_deltas

    def _post_init_(self):
        _, rotary_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
            num_pos=self.model_config.max_position_embeddings,
            dim=int(self.model_config.hidden_size /
                    self.model_config.num_attention_heads),
            theta=float(self.model_config.rope_theta),
            scale_type=RotaryScalingType.mrope)
        self.rotary_cos_sin = torch.from_numpy(rotary_cos_sin)
        self.rotary_cos_sin = self.rotary_cos_sin.reshape(
            self.model_config.max_position_embeddings,
            int(self.model_config.hidden_size /
                self.model_config.num_attention_heads / 2), 2)

        self.cos_ori = self.rotary_cos_sin[:, :, 0]
        self.sin_ori = self.rotary_cos_sin[:, :, 1]

    def get_num_tokens_per_image(
        self,
        *,
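As an aside (not part of this commit), the table that `_post_init_` builds is just a sinusoidal lookup table indexed by position. A minimal sketch of an equivalent construction, assuming the standard RoPE inverse-frequency formula and illustrative sizes (`num_pos=32768` and `head_dim=128` are placeholders, not values read from the config); the real table comes from `RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin`:

import numpy as np

def build_cos_sin_table(num_pos: int, head_dim: int, theta: float = 10000.0):
    # Standard RoPE inverse frequencies; assumed to match what the
    # attention-plugin helper produces.
    inv_freq = 1.0 / (theta ** (np.arange(0, head_dim, 2) / head_dim))
    angles = np.outer(np.arange(num_pos), inv_freq)   # (num_pos, head_dim // 2)
    # Pair cos and sin per frequency, mirroring the (num_pos, head_dim // 2, 2)
    # reshape above, where [..., 0] is cos and [..., 1] is sin.
    return np.stack([np.cos(angles), np.sin(angles)], axis=-1).astype(np.float32)

table = build_cos_sin_table(num_pos=32768, head_dim=128)
cos_ori, sin_ori = table[..., 0], table[..., 1]       # each (num_pos, head_dim // 2)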
@@ -304,30 +288,8 @@ class Qwen2VLInputProcessorBase(InputProcessor):
            self.model_config, input_ids, image_grid_thw, video_grid_thw,
            attention_mask, second_per_grid_ts)

        mrope_position_ids = mrope_position_ids.transpose(1, 0)
        mrope_position_ids_padding = torch.zeros(
            mrope_position_ids.shape[:-1] +
            (self.model_config.max_position_embeddings, ),
            dtype=torch.int32,
            device=input_ids.device)
        mrope_position_ids_padding[:, :, :mrope_position_ids.
                                   shape[-1]] = mrope_position_ids
        cos = self.cos_ori[mrope_position_ids_padding]
        sin = self.sin_ori[mrope_position_ids_padding]

        mrope_section = [16, 24, 24]
        cos = torch.cat([
            m[:, i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))
        ],
                        dim=-1).unsqueeze(-1)
        sin = torch.cat([
            m[:, i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))
        ],
                        dim=-1).unsqueeze(-1)
        concat_cos_sin = torch.concatenate((cos, sin), axis=-1)
        concat_cos_sin = concat_cos_sin.reshape(concat_cos_sin.shape[0], -1)
        mrope_config = {}
        mrope_config['mrope_rotary_cos_sin'] = concat_cos_sin.to('cpu')
        mrope_config['mrope_position_ids'] = mrope_position_ids.to('cpu')
        mrope_config['mrope_position_deltas'] = mrope_position_deltas.to(
            'cpu').to(torch.int32)
        return mrope_config
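The `mrope_section = [16, 24, 24]` interleave above is the part that is easy to misread: the 64 rotary frequency pairs are split into three bands, and each band takes its angles from a different position stream (temporal, height, width). A small shape-only sketch with made-up sizes; the `(batch, 3, seq)` reading of the indexed cos tensor is my interpretation of the diff, not something stated in it:

import torch

batch, seq, half_dim = 1, 8, 64          # illustrative sizes; 16 + 24 + 24 == 64
mrope_section = [16, 24, 24]

# Stand-in for cos_ori[mrope_position_ids_padding]: one row of angles per
# position stream (dim 1 holds the temporal/height/width axes).
cos = torch.randn(batch, 3, seq, half_dim)

# Band 0 takes stream 0, band 1 takes stream 1, band 2 takes stream 2.
merged = torch.cat(
    [m[:, i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))],
    dim=-1)

print(merged.shape)                      # torch.Size([1, 8, 64])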
@@ -340,10 +302,9 @@ class Qwen2VLInputProcessorBase(InputProcessor):
    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
        text_prompt, mm_data, mm_processor_kwargs = inputs.get("prompt"), \
            inputs.get("multi_modal_data", {}), inputs.get("mm_processor_kwargs", {})

        processed_inputs = self._preprocess(text_prompt, mm_data,
                                            mm_processor_kwargs)

        with nvtx_range_debug("transformers input preprocess"):
            processed_inputs = self._preprocess(text_prompt, mm_data,
                                                mm_processor_kwargs)
        if not mm_data:
            fused_input_ids = processed_inputs['input_ids']
            return fused_input_ids.to(torch.int32).tolist(), {}
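`nvtx_range_debug` is TensorRT-LLM's own helper; wrapping the call simply makes the HuggingFace preprocessing step show up as a named range in Nsight traces. For readers outside the codebase, a roughly equivalent hand-rolled context manager using PyTorch's NVTX bindings (an illustration, not the actual helper) looks like this:

from contextlib import contextmanager

import torch

@contextmanager
def nvtx_range(msg: str):
    # Push/pop an NVTX range so the region is visible in Nsight Systems.
    torch.cuda.nvtx.range_push(msg)
    try:
        yield
    finally:
        torch.cuda.nvtx.range_pop()

with nvtx_range("transformers input preprocess"):
    pass  # the wrapped preprocessing work goes here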
@@ -513,8 +474,25 @@ class Qwen2VLModelBase(PreTrainedModel):
        self.post_config()
        self.is_loaded = True

    def init_rotary_cos_sin_ori(self):
        _, rotary_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
            num_pos=self.model_config.pretrained_config.max_position_embeddings,
            dim=int(self.model_config.pretrained_config.hidden_size /
                    self.model_config.pretrained_config.num_attention_heads),
            theta=float(self.model_config.pretrained_config.rope_theta),
            scale_type=RotaryScalingType.mrope)
        self.rotary_cos_sin = torch.from_numpy(rotary_cos_sin).to(self.device)
        self.rotary_cos_sin = self.rotary_cos_sin.reshape(
            self.model_config.pretrained_config.max_position_embeddings,
            int(self.model_config.pretrained_config.hidden_size /
                self.model_config.pretrained_config.num_attention_heads / 2), 2)

        self.cos_ori = self.rotary_cos_sin[:, :, 0]
        self.sin_ori = self.rotary_cos_sin[:, :, 1]

    def load_weights(self, weights):
        self.llm.load_weights(weights)
        self.init_rotary_cos_sin_ori()

    def infer_max_seq_len(self) -> int:
        return self.llm.infer_max_seq_len()
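One reason this is cheap to do once per worker inside `load_weights`: the cos/sin table is small, so keeping it resident on the device costs little. A back-of-the-envelope check with illustrative numbers (not values read from any config, and assuming float32 storage):

# e.g. 32768 positions x 64 frequency pairs x 2 values (cos, sin) x 4 bytes
table_bytes = 32768 * 64 * 2 * 4
print(f"{table_bytes / 2**20:.0f} MiB")   # 16 MiB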
@@ -566,6 +544,49 @@ class Qwen2VLModelBase(PreTrainedModel):

        return batched_mrope_config

    def add_rotary_cos_sin(self, multimodal_params: List[MultimodalParams]):
        for param in multimodal_params:
            mrope_config = param.multimodal_data.get('mrope_config')
            if mrope_config:
                mrope_position_ids = mrope_config.get('mrope_position_ids',
                                                      None)
                if mrope_position_ids is None:
                    continue
                mrope_position_ids = mrope_position_ids.transpose(1, 0)
                mrope_position_ids_padding = torch.zeros(
                    mrope_position_ids.shape[:-1] +
                    (self.model_config.pretrained_config.
                     max_position_embeddings, ),
                    dtype=torch.int32,
                    device=mrope_position_ids.device)
                mrope_position_ids_padding[:, :, :mrope_position_ids.
                                           shape[-1]] = mrope_position_ids

                mrope_position_ids_padding = mrope_position_ids_padding.to(
                    self.cos_ori.device)
                cos = self.cos_ori[mrope_position_ids_padding]
                sin = self.sin_ori[mrope_position_ids_padding]

                mrope_section = [16, 24, 24]
                cos = torch.cat([
                    m[:, i % 3]
                    for i, m in enumerate(cos.split(mrope_section, dim=-1))
                ],
                                dim=-1).unsqueeze(-1)
                sin = torch.cat([
                    m[:, i % 3]
                    for i, m in enumerate(sin.split(mrope_section, dim=-1))
                ],
                                dim=-1).unsqueeze(-1)
                concat_cos_sin = torch.concatenate((cos, sin), axis=-1)
                concat_cos_sin = concat_cos_sin.reshape(concat_cos_sin.shape[0],
                                                        -1)

                mrope_config['mrope_rotary_cos_sin'] = concat_cos_sin.to(
                    self.device)

        return multimodal_params

    @torch.inference_mode()
    def forward(
        self,
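Putting the new worker-side path together, the shapes flow as follows. This is a dummy-data walk-through of `add_rotary_cos_sin` as I read the code above, with made-up sizes and random tables instead of the real `cos_ori`/`sin_ori` (the diff builds the padding with int32 ids; torch.long is used here only for portability across PyTorch versions):

import torch

max_pos, half_dim, seq, batch = 64, 64, 8, 1      # illustrative sizes
cos_ori = torch.randn(max_pos, half_dim)
sin_ori = torch.randn(max_pos, half_dim)

# What the input processor now ships: integer position ids per t/h/w stream.
mrope_position_ids = torch.randint(0, seq, (3, batch, seq), dtype=torch.long)

ids = mrope_position_ids.transpose(1, 0)          # (batch, 3, seq)
padding = torch.zeros(ids.shape[:-1] + (max_pos, ), dtype=torch.long)
padding[:, :, :ids.shape[-1]] = ids               # (batch, 3, max_pos)

cos = cos_ori[padding]                            # (batch, 3, max_pos, half_dim)
sin = sin_ori[padding]

mrope_section = [16, 24, 24]
cos = torch.cat([m[:, i % 3] for i, m in enumerate(cos.split(mrope_section, -1))],
                dim=-1).unsqueeze(-1)             # (batch, max_pos, half_dim, 1)
sin = torch.cat([m[:, i % 3] for i, m in enumerate(sin.split(mrope_section, -1))],
                dim=-1).unsqueeze(-1)

out = torch.cat((cos, sin), dim=-1)               # (batch, max_pos, half_dim, 2)
out = out.reshape(out.shape[0], -1)               # (batch, max_pos * half_dim * 2)
print(out.shape)                                  # torch.Size([1, 8192])

The net effect of the commit: the input processor only ships the compact integer `mrope_position_ids` (and position deltas) to the worker, and this gather-and-concat now runs next to the model instead of on the CPU-bound input-processing path.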
@@ -585,6 +606,8 @@ class Qwen2VLModelBase(PreTrainedModel):
        )

        multimodal_params = kwargs.get("multimodal_params", [])
        multimodal_params = self.add_rotary_cos_sin(multimodal_params)

        mm_embeds = []
        mrope_config = {}