diff --git a/requirements.txt b/requirements.txt
index f4ecfa23e7..88bf8552f7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -28,7 +28,7 @@ torchvision
 nvidia-modelopt[torch]~=0.31.0
 nvidia-nccl-cu12
 nvidia-cuda-nvrtc-cu12
-transformers~=4.51.1
+transformers==4.53.1
 pydantic>=2.9.1
 pydantic-settings
 pillow==10.3.0
diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py
index 132b1bfecd..66a5ef3c39 100644
--- a/tensorrt_llm/_torch/attention_backend/interface.py
+++ b/tensorrt_llm/_torch/attention_backend/interface.py
@@ -359,8 +359,9 @@ class RopeParams:
         # get rotary parameters.
         hidden_size = config.hidden_size
         num_attention_heads = config.num_attention_heads
-        head_dim = getattr(config, 'head_dim',
-                           hidden_size // num_attention_heads)
+        head_dim = getattr(config, 'head_dim', None)
+        if not isinstance(head_dim, int):
+            head_dim = hidden_size // num_attention_heads
         rope_scaling = getattr(config, 'rope_scaling', None)
         rope_params.max_positions = config.max_position_embeddings
         rope_params.theta = getattr(config, 'rope_theta', 10000.0)
diff --git a/tensorrt_llm/_torch/models/modeling_gemma3vl.py b/tensorrt_llm/_torch/models/modeling_gemma3vl.py
index ceb6b01a9c..52234643a5 100644
--- a/tensorrt_llm/_torch/models/modeling_gemma3vl.py
+++ b/tensorrt_llm/_torch/models/modeling_gemma3vl.py
@@ -181,6 +181,3 @@ class Gemma3Model(PreTrainedModel):
         logits = self.llm.forward(attn_metadata, input_ids, position_ids,
                                   inputs_embeds, return_context_logits)
         return logits
-
-
-AutoModel.register(Gemma3Config, Gemma3Model)
diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py
index 851b350b36..8af484ce1a 100644
--- a/tensorrt_llm/_torch/models/modeling_llava_next.py
+++ b/tensorrt_llm/_torch/models/modeling_llava_next.py
@@ -287,6 +287,3 @@ class LlavaNextModel(PreTrainedModel):
         logits = self.llm.forward(attn_metadata, input_ids, position_ids,
                                   inputs_embeds, return_context_logits)
         return logits
-
-
-AutoModel.register(LlavaNextConfig, LlavaNextModel)
diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py
index 90a66b907f..2d63a4bbf9 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py
@@ -179,6 +179,8 @@ class Qwen2VLInputProcessorBase(InputProcessor):
         # Calculate temporal position IDs based on model type
         if hasattr(model_config.vision_config, 'tokens_per_second'):
             # Qwen2_5_VL style temporal position calculation
+            if isinstance(second_per_grid_t, torch.Tensor):
+                second_per_grid_t = second_per_grid_t.item()
             range_tensor = torch.arange(llm_grid_t).view(-1, 1)
             expanded_range = range_tensor.expand(
                 -1, llm_grid_h * llm_grid_w)
@@ -273,6 +275,8 @@ class Qwen2VLInputProcessorBase(InputProcessor):
             do_rescale = False
         if videos and isinstance(videos[0][0], torch.Tensor):
             do_rescale = False
+            # transformers==4.53.1 does not support GPU video tensors in the Qwen2VL processor.
+            videos = [[frame.to("cpu") for frame in video] for video in videos]
         return self.processor(text=[text],
                               images=images,
                               videos=videos,
diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index 8ddb5f06ac..3bc949b486 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -67,8 +67,9 @@ class Attention(nn.Module):
         config = config or ModelConfig()
         self.hidden_size = hidden_size
         self.num_heads = num_attention_heads
-        self.head_dim = getattr(config.pretrained_config, "head_dim",
-                                self.hidden_size // self.num_heads)
+        self.head_dim = getattr(config.pretrained_config, 'head_dim', None)
+        if not isinstance(self.head_dim, int):
+            self.head_dim = self.hidden_size // self.num_heads
         self.num_key_value_heads = num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = max_position_embeddings
diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index 29c4413b67..91a7d3ec00 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -75,11 +75,10 @@ class KvCacheCreator:
             head_dim = config.kv_lora_rank + config.qk_rope_head_dim
             kv_factor = 1
         else:
-            head_dim = getattr(
-                config,
-                "head_dim",
-                config.hidden_size // config.num_attention_heads,
-            ) * num_key_value_heads // tp_size
+            _head_dim = getattr(config, 'head_dim', None)
+            if not isinstance(_head_dim, int):
+                _head_dim = config.hidden_size // config.num_attention_heads
+            head_dim = _head_dim * num_key_value_heads // tp_size

         # provide at least 1 layer to prevent division by zero cache size
         num_attention_layers = max(
@@ -281,8 +280,9 @@ class KvCacheCreator:
         num_attention_heads = config.num_attention_heads
         num_key_value_heads = getattr(config, 'num_key_value_heads',
                                       num_attention_heads)
-        head_dim = getattr(config, "head_dim",
-                           hidden_size // num_attention_heads)
+        head_dim = getattr(config, "head_dim", None)
+        if not isinstance(head_dim, int):
+            head_dim = hidden_size // num_attention_heads

         if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache(
         ):
diff --git a/tests/unittest/_torch/modeling/test_modeling_mllama.py b/tests/unittest/_torch/modeling/test_modeling_mllama.py
index b3338e148d..665e28919b 100644
--- a/tests/unittest/_torch/modeling/test_modeling_mllama.py
+++ b/tests/unittest/_torch/modeling/test_modeling_mllama.py
@@ -1,3 +1,4 @@
+import re
 import unittest
 from copy import deepcopy

@@ -255,6 +256,24 @@ LLAMA_3_2_11B_VISION_CONFIG = {
 }


+def convert_weights_names(weights: dict) -> dict:
+    # Since transformers version >= 4.52.0, the default model architecture has changed.
+    # We need to convert the weight names accordingly to match TRTLLM naming.
+    _checkpoint_conversion_mapping = {
+        "^model.language_model": "language_model.model",
+        "^model.vision_model": "vision_model",
+        "^model.multi_modal_projector": "multi_modal_projector",
+        "^lm_head": "language_model.lm_head",
+    }
+    converted_weights = {}
+    for weight_name, weight_value in weights.items():
+        new_name = weight_name
+        for pattern, replacement in _checkpoint_conversion_mapping.items():
+            new_name = re.sub(pattern, replacement, new_name)
+        converted_weights[new_name] = weight_value
+    return converted_weights
+
+
 class TestMLlama(unittest.TestCase):

     @parameterized.expand([
@@ -301,7 +320,8 @@ class TestMLlama(unittest.TestCase):
         mllama = MllamaForConditionalGeneration(
             ModelConfig(pretrained_config=mllama_config,
                         attn_backend=backend)).to(dtype).to(device)
-        mllama.load_weights(hf_mllama.state_dict())
+        weights = convert_weights_names(hf_mllama.state_dict())
+        mllama.load_weights(weights)

         # KV cache setup
         num_blocks = 1
diff --git a/tests/unittest/trt/attention/test_gpt_attention.py b/tests/unittest/trt/attention/test_gpt_attention.py
index 1015caecec..afc592cb7e 100644
--- a/tests/unittest/trt/attention/test_gpt_attention.py
+++ b/tests/unittest/trt/attention/test_gpt_attention.py
@@ -1230,11 +1230,14 @@ class TestFunctional(unittest.TestCase):
            else:
                attention_packed_mask = None
            if attention_type == 'gpt2_attention':
-                torch_output, torch_present = attention(
-                    input_tensor,
-                    layer_past=None,
-                    use_cache=True,
-                    attention_mask=attention_mask)
+                # gpt2 uses DynamicCache
+                torch_present = DynamicCache.from_legacy_cache(
+                    torch_present)
+                torch_output = attention(input_tensor,
+                                         past_key_value=torch_present,
+                                         use_cache=True,
+                                         attention_mask=attention_mask)[0]
+                torch_present = torch_present.to_legacy_cache()
            elif attention_type == 'llama_attention':
                position_embeddings = rotary_emb(input_tensor, position_ids)
                attention_mask = attention_mask + AttentionMaskConverter._make_causal_mask(
@@ -1277,7 +1280,7 @@ class TestFunctional(unittest.TestCase):

            torch.cuda.synchronize()

-            if attention_type == 'llama_attention':
+            if attention_type in ['llama_attention', 'gpt2_attention']:
                kv_dequant_scale, kv_quant_scale = get_kv_quant_scale(
                    torch_present[0])
            else:
@@ -1322,7 +1325,7 @@ class TestFunctional(unittest.TestCase):
                torch_output[:, :in_len // 2, :].to(
                    torch.float32).cpu().numpy(),
                atol=5e-3)
-            if attention_type == 'llama_attention':
+            if attention_type in ['llama_attention', 'gpt2_attention']:
                verify_kv_cache(torch_present[0])
            else:
                verify_kv_cache(torch_present)
@@ -1374,11 +1377,14 @@ class TestFunctional(unittest.TestCase):

                # torch execution
                if attention_type == 'gpt2_attention':
-                    torch_output, torch_present = attention(
-                        input_tensor,
-                        layer_past=torch_present,
-                        use_cache=True,
-                        attention_mask=attention_mask)
+                    # gpt2 uses DynamicCache
+                    torch_present = DynamicCache.from_legacy_cache(
+                        torch_present)
+                    torch_output = attention(input_tensor,
+                                             past_key_value=torch_present,
+                                             use_cache=True,
+                                             attention_mask=attention_mask)[0]
+                    torch_present = torch_present.to_legacy_cache()
                elif attention_type == 'llama_attention':
                    position_embeddings = rotary_emb(input_tensor, position_ids)
                    attention_mask = attention_mask + AttentionMaskConverter._make_causal_mask(
diff --git a/tests/unittest/trt/attention/test_gpt_attention_IFB.py b/tests/unittest/trt/attention/test_gpt_attention_IFB.py
index 4e9537b5a5..41e0e015de 100644
--- a/tests/unittest/trt/attention/test_gpt_attention_IFB.py
+++ b/tests/unittest/trt/attention/test_gpt_attention_IFB.py
@@ -754,11 +754,11 @@ class TestFunctional(unittest.TestCase):
                dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype),
                tgt_len=(in_len if step == 0 else 1))
            if attention_type == 'gpt2_attention':
-                torch_output, torch_present = attention(
-                    input,
-                    layer_past=layer_past,
-                    use_cache=True,
-                    attention_mask=attention_mask)
+                torch_output = attention(input,
+                                         past_key_value=layer_past,
+                                         use_cache=True,
+                                         attention_mask=attention_mask)[0]
+                torch_present = layer_past
            elif attention_type == 'llama_attention':
                position_embeddings = rotary_emb(input, position_ids)
                attention_mask = attention_mask + AttentionMaskConverter._make_causal_mask(
@@ -1003,8 +1003,8 @@ class TestFunctional(unittest.TestCase):
                torch_in = input_tensor[:, offset:offset_next, :].reshape(
                    (local_beam_width, input_length, hidden_size))

-                # llama uses DynamicCache
-                if attention_type == 'llama_attention':
+                # llama/gpt2 uses DynamicCache
+                if attention_type in ['llama_attention', 'gpt2_attention']:
                    past_key_values = DynamicCache.from_legacy_cache(
                        torch_cache_list[req_idx])
                else:
@@ -1014,8 +1014,8 @@ class TestFunctional(unittest.TestCase):
                    step, torch_in, ctx_attention_mask_list[req_idx], req_idx,
                    past_key_values)

-                # llama uses DynamicCache
-                if attention_type == 'llama_attention':
+                # llama/gpt2 uses DynamicCache
+                if attention_type in ['llama_attention', 'gpt2_attention']:
                    torch_cache_list[req_idx] = past_key_values.to_legacy_cache(
                    )
                    past_key_values = torch_cache_list[req_idx][0]
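
A note on the recurring head_dim change above (attention_backend/interface.py, modules/attention.py, pyexecutor/_util.py): some newer Hugging Face configs define head_dim explicitly as None, in which case the old two-argument getattr returns None instead of the computed fallback. Below is a minimal, self-contained sketch of that failure mode and of the guard the patch adopts; the _Cfg class is a hypothetical stand-in for a PretrainedConfig, not code from this repository.

```python
# Sketch of the head_dim fallback pattern used throughout this patch.
class _Cfg:
    hidden_size = 4096
    num_attention_heads = 32
    head_dim = None  # attribute exists, but carries no usable value


cfg = _Cfg()

# getattr only falls back when the attribute is missing, not when it is None.
naive = getattr(cfg, 'head_dim', cfg.hidden_size // cfg.num_attention_heads)
assert naive is None

# The patch's guard: treat anything that is not an int as "unset".
head_dim = getattr(cfg, 'head_dim', None)
if not isinstance(head_dim, int):
    head_dim = cfg.hidden_size // cfg.num_attention_heads
assert head_dim == 128
```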
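The Qwen2-VL hunks normalize two processor inputs for transformers 4.53.1: second_per_grid_t can arrive as a 0-d tensor and is unwrapped to a Python scalar, and video frames are copied to CPU because the processor rejects GPU tensors. A small sketch of that normalization with made-up frame tensors, not the real preprocessing path:

```python
import torch

# second_per_grid_t may be a 0-d tensor in newer processors; unwrap it.
second_per_grid_t = torch.tensor(0.5)  # hypothetical value
if isinstance(second_per_grid_t, torch.Tensor):
    second_per_grid_t = second_per_grid_t.item()  # -> plain Python float

# Videos are a list of clips, each clip a list of frame tensors that may
# live on the GPU; the processor in transformers 4.53.1 expects CPU tensors.
videos = [[torch.rand(3, 224, 224) for _ in range(4)]]  # dummy frames
videos = [[frame.to("cpu") for frame in video] for video in videos]
```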
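The convert_weights_names helper added to the mllama test undoes the checkpoint re-layout that transformers >= 4.52 applies to MllamaForConditionalGeneration (submodules nested under a top-level model prefix, lm_head left at the root), so the state dict again matches the TRT-LLM module names. A hypothetical before/after example of the key renaming, with dummy weight values:

```python
import re

import torch

_checkpoint_conversion_mapping = {
    "^model.language_model": "language_model.model",
    "^model.vision_model": "vision_model",
    "^model.multi_modal_projector": "multi_modal_projector",
    "^lm_head": "language_model.lm_head",
}

# Hypothetical HF-style state-dict keys; values are placeholders.
weights = {
    "model.language_model.layers.0.self_attn.q_proj.weight": torch.zeros(1),
    "model.vision_model.patch_embedding.weight": torch.zeros(1),
    "lm_head.weight": torch.zeros(1),
}

converted = {}
for name, value in weights.items():
    for pattern, replacement in _checkpoint_conversion_mapping.items():
        name = re.sub(pattern, replacement, name)
    converted[name] = value

assert "language_model.model.layers.0.self_attn.q_proj.weight" in converted
assert "vision_model.patch_embedding.weight" in converted
assert "language_model.lm_head.weight" in converted
```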
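In transformers 4.53, GPT-2's attention forward takes a Cache object via past_key_value instead of the old layer_past tuple, which is why the attention tests now wrap their legacy (key, value) tuples in DynamicCache and unwrap them afterwards. A standalone sketch of that round-trip; the single-layer cache and the shapes are made up for illustration and do not mirror the tests' actual setup:

```python
import torch
from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) pair per layer,
# each of shape [batch, num_heads, seq_len, head_dim].
legacy = ((torch.rand(1, 12, 4, 64), torch.rand(1, 12, 4, 64)),)

cache = DynamicCache.from_legacy_cache(legacy)

# New keys/values for the current step are appended per layer
# (the model does this internally when past_key_value is passed in).
new_k = torch.rand(1, 12, 1, 64)
new_v = torch.rand(1, 12, 1, 64)
cache.update(new_k, new_v, layer_idx=0)

# Convert back to the tuple format the rest of the test still consumes.
legacy_again = cache.to_legacy_cache()
assert legacy_again[0][0].shape[-2] == 5  # 4 cached tokens + 1 new token
```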