Update transformers to 4.53.0 (#5747)

Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
Authored by Wanli Jiang on 2025-07-10 00:32:24 +08:00, committed by GitHub
parent 74dca0aa7b
commit 3f7cedec7c
10 changed files with 66 additions and 40 deletions


@@ -28,7 +28,7 @@ torchvision
 nvidia-modelopt[torch]~=0.31.0
 nvidia-nccl-cu12
 nvidia-cuda-nvrtc-cu12
-transformers~=4.51.1
+transformers==4.53.1
 pydantic>=2.9.1
 pydantic-settings
 pillow==10.3.0
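
The transformers requirement moves from a compatible-release range (~=4.51.1) to an exact pin. As a quick sanity check, the installed version can be verified at runtime with the standard-library importlib.metadata; this is a minimal sketch, not part of the commit:

# Minimal sketch: confirm the environment actually picked up the exact transformers pin.
from importlib.metadata import version

installed = version("transformers")
expected = "4.53.1"
if installed != expected:
    raise RuntimeError(f"Expected transformers=={expected}, found {installed}; "
                       "re-run pip install -r requirements.txt.")
print(f"transformers {installed} OK")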


@@ -359,8 +359,9 @@ class RopeParams:
         # get rotary parameters.
         hidden_size = config.hidden_size
         num_attention_heads = config.num_attention_heads
-        head_dim = getattr(config, 'head_dim',
-                           hidden_size // num_attention_heads)
+        head_dim = getattr(config, 'head_dim', None)
+        if not isinstance(head_dim, int):
+            head_dim = hidden_size // num_attention_heads
         rope_scaling = getattr(config, 'rope_scaling', None)
         rope_params.max_positions = config.max_position_embeddings
         rope_params.theta = getattr(config, 'rope_theta', 10000.0)
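
The reason for the two-step lookup: in recent transformers releases some configs carry head_dim as an explicit attribute that may be set to None, so a plain getattr default is never used and the code ends up with None. A self-contained sketch of the difference, using a stand-in config object (not the real transformers config class):

# Sketch: getattr's default is ignored when the attribute exists but is None.
from types import SimpleNamespace

config = SimpleNamespace(hidden_size=4096, num_attention_heads=32, head_dim=None)

# Old pattern: returns None because 'head_dim' exists on the config.
old = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
assert old is None

# New pattern: fall back whenever the stored value is not a usable integer.
head_dim = getattr(config, 'head_dim', None)
if not isinstance(head_dim, int):
    head_dim = config.hidden_size // config.num_attention_heads
assert head_dim == 128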


@@ -181,6 +181,3 @@ class Gemma3Model(PreTrainedModel):
         logits = self.llm.forward(attn_metadata, input_ids, position_ids,
                                   inputs_embeds, return_context_logits)
         return logits
-
-
-AutoModel.register(Gemma3Config, Gemma3Model)
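
The module-level AutoModel.register call is removed: transformers >= 4.52 ships its own Gemma3Model (and, below, LlavaNextModel) and already maps the config in its auto classes, so re-registering a second class for the same config fails. If out-of-tree registration were still needed, a guarded variant along these lines could be used; this is a hedged sketch, not what the commit does, and the exact exception raised on duplicate registration may differ across transformers versions:

# Sketch: register a custom model for a config only when transformers does not
# already provide a mapping for it. ValueError is what recent transformers
# versions raise on a duplicate registration (assumption; verify for your version).
from transformers import AutoModel

def register_if_missing(config_cls, model_cls):
    try:
        AutoModel.register(config_cls, model_cls)
    except ValueError:
        # Config already handled by a built-in or previously registered model.
        pass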


@@ -287,6 +287,3 @@ class LlavaNextModel(PreTrainedModel):
         logits = self.llm.forward(attn_metadata, input_ids, position_ids,
                                   inputs_embeds, return_context_logits)
         return logits
-
-
-AutoModel.register(LlavaNextConfig, LlavaNextModel)


@@ -179,6 +179,8 @@ class Qwen2VLInputProcessorBase(InputProcessor):
             # Calculate temporal position IDs based on model type
             if hasattr(model_config.vision_config, 'tokens_per_second'):
                 # Qwen2_5_VL style temporal position calculation
+                if isinstance(second_per_grid_t, torch.Tensor):
+                    second_per_grid_t = second_per_grid_t.item()
                 range_tensor = torch.arange(llm_grid_t).view(-1, 1)
                 expanded_range = range_tensor.expand(
                     -1, llm_grid_h * llm_grid_w)
@@ -273,6 +275,8 @@ class Qwen2VLInputProcessorBase(InputProcessor):
             do_rescale = False
         if videos and isinstance(videos[0][0], torch.Tensor):
             do_rescale = False
+            # transformers=4.53.1 does not support GPU video tensors in Qwen2VL processor.
+            videos = [[frame.to("cpu") for frame in video] for video in videos]
         return self.processor(text=[text],
                               images=images,
                               videos=videos,
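
Two normalizations are added before calling the HF processor: second_per_grid_t can arrive as a 0-d tensor and is converted to a plain Python scalar, and video frames are moved to the CPU because the pinned processor version does not accept GPU video tensors. A stand-alone sketch of both steps with dummy tensors (illustrative shapes, not the real inputs):

# Sketch of the two normalizations, using dummy data.
import torch

# 1) Turn a 0-d tensor into a plain Python scalar before position-id arithmetic.
second_per_grid_t = torch.tensor(2.0)
if isinstance(second_per_grid_t, torch.Tensor):
    second_per_grid_t = second_per_grid_t.item()
assert isinstance(second_per_grid_t, float)

# 2) Move every video frame to the CPU before handing it to the processor.
device = "cuda" if torch.cuda.is_available() else "cpu"
videos = [[torch.rand(3, 224, 224, device=device) for _ in range(4)]]
videos = [[frame.to("cpu") for frame in video] for video in videos]
assert all(frame.device.type == "cpu" for video in videos for frame in video)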


@@ -67,8 +67,9 @@ class Attention(nn.Module):
         config = config or ModelConfig()
         self.hidden_size = hidden_size
         self.num_heads = num_attention_heads
-        self.head_dim = getattr(config.pretrained_config, "head_dim",
-                                self.hidden_size // self.num_heads)
+        self.head_dim = getattr(config.pretrained_config, 'head_dim', None)
+        if not isinstance(self.head_dim, int):
+            self.head_dim = self.hidden_size // self.num_heads
         self.num_key_value_heads = num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = max_position_embeddings


@@ -75,11 +75,10 @@ class KvCacheCreator:
             head_dim = config.kv_lora_rank + config.qk_rope_head_dim
             kv_factor = 1
         else:
-            head_dim = getattr(
-                config,
-                "head_dim",
-                config.hidden_size // config.num_attention_heads,
-            ) * num_key_value_heads // tp_size
+            _head_dim = getattr(config, 'head_dim', None)
+            if not isinstance(_head_dim, int):
+                _head_dim = config.hidden_size // config.num_attention_heads
+            head_dim = _head_dim * num_key_value_heads // tp_size

         # provide at least 1 layer to prevent division by zero cache size
         num_attention_layers = max(
@@ -281,8 +280,9 @@ class KvCacheCreator:
         num_attention_heads = config.num_attention_heads
         num_key_value_heads = getattr(config, 'num_key_value_heads',
                                       num_attention_heads)
-        head_dim = getattr(config, "head_dim",
-                           hidden_size // num_attention_heads)
+        head_dim = getattr(config, "head_dim", None)
+        if not isinstance(head_dim, int):
+            head_dim = hidden_size // num_attention_heads
         if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache(
         ):
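
In the non-MLA branch the cache-sizing head_dim is the per-head dimension multiplied by the number of KV heads local to the rank. A worked example with illustrative numbers (not taken from the commit): hidden_size 4096, 32 attention heads, 8 KV heads, tensor parallelism 2:

# Worked example of the per-rank head_dim used for KV-cache sizing.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8
tp_size = 2

_head_dim = None  # as if config.head_dim were missing or explicitly None
if not isinstance(_head_dim, int):
    _head_dim = hidden_size // num_attention_heads  # 4096 // 32 = 128
head_dim = _head_dim * num_key_value_heads // tp_size  # 128 * 8 // 2 = 512
assert head_dim == 512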


@@ -1,3 +1,4 @@
+import re
 import unittest
 from copy import deepcopy
@@ -255,6 +256,24 @@ LLAMA_3_2_11B_VISION_CONFIG = {
 }


+def convert_weights_names(weights: dict) -> dict:
+    # Since transformers version >= 4.52.0, the default model architecture is changed.
+    # We need to convert the weight names accordingly to match TRTLLM naming.
+    _checkpoint_conversion_mapping = {
+        "^model.language_model": "language_model.model",
+        "^model.vision_model": "vision_model",
+        "^model.multi_modal_projector": "multi_modal_projector",
+        "^lm_head": "language_model.lm_head",
+    }
+    converted_weights = {}
+    for weight_name, weight_value in weights.items():
+        new_name = weight_name
+        for pattern, replacement in _checkpoint_conversion_mapping.items():
+            new_name = re.sub(pattern, replacement, new_name)
+        converted_weights[new_name] = weight_value
+    return converted_weights
+
+
 class TestMLlama(unittest.TestCase):

     @parameterized.expand([
@@ -301,7 +320,8 @@ class TestMLlama(unittest.TestCase):
         mllama = MllamaForConditionalGeneration(
             ModelConfig(pretrained_config=mllama_config,
                         attn_backend=backend)).to(dtype).to(device)
-        mllama.load_weights(hf_mllama.state_dict())
+        weights = convert_weights_names(hf_mllama.state_dict())
+        mllama.load_weights(weights)

         # KV cache setup
         num_blocks = 1
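
A quick usage sketch for the convert_weights_names helper above, on a couple of made-up state-dict keys (dummy values), showing how the post-4.52 HF key layout is mapped back onto the naming the TRT-LLM model expects:

# Sketch: assumes convert_weights_names from the diff above is in scope.
dummy_state_dict = {
    "model.language_model.layers.0.self_attn.q_proj.weight": 0,  # dummy value
    "model.vision_model.patch_embedding.weight": 0,
    "lm_head.weight": 0,
}
renamed = convert_weights_names(dummy_state_dict)
assert "language_model.model.layers.0.self_attn.q_proj.weight" in renamed
assert "vision_model.patch_embedding.weight" in renamed
assert "language_model.lm_head.weight" in renamed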


@@ -1230,11 +1230,14 @@ class TestFunctional(unittest.TestCase):
             else:
                 attention_packed_mask = None
             if attention_type == 'gpt2_attention':
-                torch_output, torch_present = attention(
-                    input_tensor,
-                    layer_past=None,
-                    use_cache=True,
-                    attention_mask=attention_mask)
+                # gpt2 uses DynamicCache
+                torch_present = DynamicCache.from_legacy_cache(
+                    torch_present)
+                torch_output = attention(input_tensor,
+                                         past_key_value=torch_present,
+                                         use_cache=True,
+                                         attention_mask=attention_mask)[0]
+                torch_present = torch_present.to_legacy_cache()
             elif attention_type == 'llama_attention':
                 position_embeddings = rotary_emb(input_tensor, position_ids)
                 attention_mask = attention_mask + AttentionMaskConverter._make_causal_mask(
@@ -1277,7 +1280,7 @@ class TestFunctional(unittest.TestCase):
             torch.cuda.synchronize()

-            if attention_type == 'llama_attention':
+            if attention_type in ['llama_attention', 'gpt2_attention']:
                 kv_dequant_scale, kv_quant_scale = get_kv_quant_scale(
                     torch_present[0])
             else:
@@ -1322,7 +1325,7 @@ class TestFunctional(unittest.TestCase):
                 torch_output[:, :in_len // 2, :].to(
                     torch.float32).cpu().numpy(),
                 atol=5e-3)
-            if attention_type == 'llama_attention':
+            if attention_type in ['llama_attention', 'gpt2_attention']:
                 verify_kv_cache(torch_present[0])
             else:
                 verify_kv_cache(torch_present)
@@ -1374,11 +1377,14 @@ class TestFunctional(unittest.TestCase):

             # torch execution
             if attention_type == 'gpt2_attention':
-                torch_output, torch_present = attention(
-                    input_tensor,
-                    layer_past=torch_present,
-                    use_cache=True,
-                    attention_mask=attention_mask)
+                # gpt2 uses DynamicCache
+                torch_present = DynamicCache.from_legacy_cache(
+                    torch_present)
+                torch_output = attention(input_tensor,
+                                         past_key_value=torch_present,
+                                         use_cache=True,
+                                         attention_mask=attention_mask)[0]
+                torch_present = torch_present.to_legacy_cache()
             elif attention_type == 'llama_attention':
                 position_embeddings = rotary_emb(input_tensor, position_ids)
                 attention_mask = attention_mask + AttentionMaskConverter._make_causal_mask(
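
With the pinned transformers version, the HF GPT-2 attention module takes its cache through past_key_value as a Cache object rather than the old layer_past tuples, so the test wraps the legacy tuple cache with DynamicCache.from_legacy_cache before the call and unwraps it with to_legacy_cache afterwards, keeping the rest of the test's tuple-style indexing unchanged. A stand-alone sketch of that round-trip with dummy KV tensors (shapes are illustrative):

# Sketch: legacy tuple cache <-> DynamicCache round-trip with dummy tensors.
import torch
from transformers import DynamicCache

batch, num_heads, seq_len, head_dim = 1, 2, 4, 8
legacy = tuple(
    (torch.zeros(batch, num_heads, seq_len, head_dim),
     torch.zeros(batch, num_heads, seq_len, head_dim))
    for _ in range(3))  # 3 layers of (key, value)

cache = DynamicCache.from_legacy_cache(legacy)  # wrap for the Cache-based API
# ... attention(..., past_key_value=cache, use_cache=True) would update `cache` ...
legacy_again = cache.to_legacy_cache()          # back to tuple-style indexing

assert len(legacy_again) == 3
assert torch.equal(legacy_again[0][0], legacy[0][0])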


@@ -754,11 +754,11 @@ class TestFunctional(unittest.TestCase):
                 dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype),
                 tgt_len=(in_len if step == 0 else 1))
             if attention_type == 'gpt2_attention':
-                torch_output, torch_present = attention(
-                    input,
-                    layer_past=layer_past,
-                    use_cache=True,
-                    attention_mask=attention_mask)
+                torch_output = attention(input,
+                                         past_key_value=layer_past,
+                                         use_cache=True,
+                                         attention_mask=attention_mask)[0]
+                torch_present = layer_past
             elif attention_type == 'llama_attention':
                 position_embeddings = rotary_emb(input, position_ids)
                 attention_mask = attention_mask + AttentionMaskConverter._make_causal_mask(
@@ -1003,8 +1003,8 @@ class TestFunctional(unittest.TestCase):
                 torch_in = input_tensor[:, offset:offset_next, :].reshape(
                     (local_beam_width, input_length, hidden_size))

-                # llama uses DynamicCache
-                if attention_type == 'llama_attention':
+                # llama/gpt2 uses DynamicCache
+                if attention_type in ['llama_attention', 'gpt2_attention']:
                     past_key_values = DynamicCache.from_legacy_cache(
                         torch_cache_list[req_idx])
                 else:
@@ -1014,8 +1014,8 @@ class TestFunctional(unittest.TestCase):
                     step, torch_in, ctx_attention_mask_list[req_idx], req_idx,
                     past_key_values)

-                # llama uses DynamicCache
-                if attention_type == 'llama_attention':
+                # llama/gpt2 uses DynamicCache
+                if attention_type in ['llama_attention', 'gpt2_attention']:
                     torch_cache_list[req_idx] = past_key_values.to_legacy_cache(
                     )
                     past_key_values = torch_cache_list[req_idx][0]