Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
Update transformers to 4.53.0 (#5747)
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
Parent: 74dca0aa7b
Commit: 3f7cedec7c
@@ -28,7 +28,7 @@ torchvision
 nvidia-modelopt[torch]~=0.31.0
 nvidia-nccl-cu12
 nvidia-cuda-nvrtc-cu12
-transformers~=4.51.1
+transformers==4.53.1
 pydantic>=2.9.1
 pydantic-settings
 pillow==10.3.0
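The dependency pin tightens from a compatible-release range to an exact version. As a quick sanity check (not part of the commit, standard library only), the installed package can be compared against the new pin:

    from importlib.metadata import version

    installed = version("transformers")
    assert installed == "4.53.1", f"expected transformers==4.53.1, found {installed}"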
@@ -359,8 +359,9 @@ class RopeParams:
         # get rotary parameters.
         hidden_size = config.hidden_size
         num_attention_heads = config.num_attention_heads
-        head_dim = getattr(config, 'head_dim',
-                           hidden_size // num_attention_heads)
+        head_dim = getattr(config, 'head_dim', None)
+        if not isinstance(head_dim, int):
+            head_dim = hidden_size // num_attention_heads
         rope_scaling = getattr(config, 'rope_scaling', None)
         rope_params.max_positions = config.max_position_embeddings
         rope_params.theta = getattr(config, 'rope_theta', 10000.0)
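The same fallback pattern recurs throughout this commit: recent transformers configs may define head_dim but leave it set to None, so a getattr with a computed default no longer covers every case. A minimal sketch of the idea; the helper name resolve_head_dim is illustrative, not from the repo:

    def resolve_head_dim(config) -> int:
        # head_dim can be absent *or* present-but-None in newer HF configs.
        head_dim = getattr(config, 'head_dim', None)
        if not isinstance(head_dim, int):
            head_dim = config.hidden_size // config.num_attention_heads
        return head_dim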
@@ -181,6 +181,3 @@ class Gemma3Model(PreTrainedModel):
         logits = self.llm.forward(attn_metadata, input_ids, position_ids,
                                   inputs_embeds, return_context_logits)
         return logits
-
-
-AutoModel.register(Gemma3Config, Gemma3Model)
@@ -287,6 +287,3 @@ class LlavaNextModel(PreTrainedModel):
         logits = self.llm.forward(attn_metadata, input_ids, position_ids,
                                   inputs_embeds, return_context_logits)
         return logits
-
-
-AutoModel.register(LlavaNextConfig, LlavaNextModel)
@@ -179,6 +179,8 @@ class Qwen2VLInputProcessorBase(InputProcessor):
         # Calculate temporal position IDs based on model type
         if hasattr(model_config.vision_config, 'tokens_per_second'):
             # Qwen2_5_VL style temporal position calculation
+            if isinstance(second_per_grid_t, torch.Tensor):
+                second_per_grid_t = second_per_grid_t.item()
             range_tensor = torch.arange(llm_grid_t).view(-1, 1)
             expanded_range = range_tensor.expand(
                 -1, llm_grid_h * llm_grid_w)
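A small standalone sketch of the guard added above, assuming second_per_grid_t can arrive either as a Python number or as a 0-d torch.Tensor; .item() normalizes it to a plain scalar before it is mixed with Python ints:

    import torch

    second_per_grid_t = torch.tensor(2.0)  # illustrative value
    if isinstance(second_per_grid_t, torch.Tensor):
        second_per_grid_t = second_per_grid_t.item()  # -> 2.0, a Python float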
@@ -273,6 +275,8 @@ class Qwen2VLInputProcessorBase(InputProcessor):
             do_rescale = False
         if videos and isinstance(videos[0][0], torch.Tensor):
             do_rescale = False
+            # transformers=4.53.1 does not support GPU video tensors in Qwen2VL processor.
+            videos = [[frame.to("cpu") for frame in video] for video in videos]
         return self.processor(text=[text],
                               images=images,
                               videos=videos,
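The workaround copies GPU video frames back to host memory before calling the Hugging Face processor, because transformers 4.53.1's Qwen2VL processor does not accept GPU video tensors. A standalone sketch with an illustrative videos structure (a list of videos, each a list of frame tensors):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    videos = [[torch.rand(3, 224, 224, device=device)]]  # one video, one frame
    if videos and isinstance(videos[0][0], torch.Tensor):
        videos = [[frame.to("cpu") for frame in video] for video in videos]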
@@ -67,8 +67,9 @@ class Attention(nn.Module):
         config = config or ModelConfig()
         self.hidden_size = hidden_size
         self.num_heads = num_attention_heads
-        self.head_dim = getattr(config.pretrained_config, "head_dim",
-                                self.hidden_size // self.num_heads)
+        self.head_dim = getattr(config.pretrained_config, 'head_dim', None)
+        if not isinstance(self.head_dim, int):
+            self.head_dim = self.hidden_size // self.num_heads
         self.num_key_value_heads = num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = max_position_embeddings
@@ -75,11 +75,10 @@ class KvCacheCreator:
             head_dim = config.kv_lora_rank + config.qk_rope_head_dim
             kv_factor = 1
         else:
-            head_dim = getattr(
-                config,
-                "head_dim",
-                config.hidden_size // config.num_attention_heads,
-            ) * num_key_value_heads // tp_size
+            _head_dim = getattr(config, 'head_dim', None)
+            if not isinstance(_head_dim, int):
+                _head_dim = config.hidden_size // config.num_attention_heads
+            head_dim = _head_dim * num_key_value_heads // tp_size

         # provide at least 1 layer to prevent division by zero cache size
         num_attention_layers = max(
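A worked example of the per-rank KV width computed above, with illustrative numbers for a GQA model whose config does not set head_dim (hidden_size 4096, 32 attention heads, 8 KV heads, tensor-parallel size 2):

    hidden_size, num_attention_heads = 4096, 32
    num_key_value_heads, tp_size = 8, 2

    _head_dim = None  # what getattr(config, 'head_dim', None) may return
    if not isinstance(_head_dim, int):
        _head_dim = hidden_size // num_attention_heads  # 4096 // 32 = 128
    head_dim = _head_dim * num_key_value_heads // tp_size  # 128 * 8 // 2 = 512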
@@ -281,8 +280,9 @@ class KvCacheCreator:
         num_attention_heads = config.num_attention_heads
         num_key_value_heads = getattr(config, 'num_key_value_heads',
                                       num_attention_heads)
-        head_dim = getattr(config, "head_dim",
-                           hidden_size // num_attention_heads)
+        head_dim = getattr(config, "head_dim", None)
+        if not isinstance(head_dim, int):
+            head_dim = hidden_size // num_attention_heads

         if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache(
         ):
@@ -1,3 +1,4 @@
+import re
 import unittest
 from copy import deepcopy

@@ -255,6 +256,24 @@ LLAMA_3_2_11B_VISION_CONFIG = {
 }


+def convert_weights_names(weights: dict) -> dict:
+    # Since transformers version >= 4.52.0, the default model architecture is changed.
+    # We need to convert the weight names accordingly to match TRTLLM naming.
+    _checkpoint_conversion_mapping = {
+        "^model.language_model": "language_model.model",
+        "^model.vision_model": "vision_model",
+        "^model.multi_modal_projector": "multi_modal_projector",
+        "^lm_head": "language_model.lm_head",
+    }
+    converted_weights = {}
+    for weight_name, weight_value in weights.items():
+        new_name = weight_name
+        for pattern, replacement in _checkpoint_conversion_mapping.items():
+            new_name = re.sub(pattern, replacement, new_name)
+        converted_weights[new_name] = weight_value
+    return converted_weights
+
+
 class TestMLlama(unittest.TestCase):

     @parameterized.expand([
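An illustrative use of the convert_weights_names helper added above (assumed to be in scope, as it is in the test file); the key follows the post-4.52 checkpoint layout and the tensor value is a stand-in:

    import torch

    weights = {"model.language_model.layers.0.self_attn.q_proj.weight": torch.zeros(1)}
    converted = convert_weights_names(weights)
    assert "language_model.model.layers.0.self_attn.q_proj.weight" in converted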
@@ -301,7 +320,8 @@ class TestMLlama(unittest.TestCase):
         mllama = MllamaForConditionalGeneration(
             ModelConfig(pretrained_config=mllama_config,
                         attn_backend=backend)).to(dtype).to(device)
-        mllama.load_weights(hf_mllama.state_dict())
+        weights = convert_weights_names(hf_mllama.state_dict())
+        mllama.load_weights(weights)

         # KV cache setup
         num_blocks = 1
@@ -1230,11 +1230,14 @@ class TestFunctional(unittest.TestCase):
         else:
             attention_packed_mask = None
         if attention_type == 'gpt2_attention':
-            torch_output, torch_present = attention(
-                input_tensor,
-                layer_past=None,
-                use_cache=True,
-                attention_mask=attention_mask)
+            # gpt2 uses DynamicCache
+            torch_present = DynamicCache.from_legacy_cache(
+                torch_present)
+            torch_output = attention(input_tensor,
+                                     past_key_value=torch_present,
+                                     use_cache=True,
+                                     attention_mask=attention_mask)[0]
+            torch_present = torch_present.to_legacy_cache()
         elif attention_type == 'llama_attention':
             position_embeddings = rotary_emb(input_tensor, position_ids)
             attention_mask = attention_mask + AttentionMaskConverter._make_causal_mask(
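The GPT-2 branch now mirrors the Llama path: the Hugging Face attention module is fed a Cache object through past_key_value, while the rest of the test keeps working with the legacy tuple-of-(key, value) format. A minimal sketch of that round trip with stand-in tensor shapes:

    import torch
    from transformers import DynamicCache

    legacy = None  # no prior context on the first step
    cache = DynamicCache.from_legacy_cache(legacy)  # empty DynamicCache
    # ... run attention(..., past_key_value=cache, use_cache=True) here ...
    cache.update(torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8), layer_idx=0)
    legacy = cache.to_legacy_cache()  # back to ((key, value),)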
@@ -1277,7 +1280,7 @@ class TestFunctional(unittest.TestCase):

         torch.cuda.synchronize()

-        if attention_type == 'llama_attention':
+        if attention_type in ['llama_attention', 'gpt2_attention']:
             kv_dequant_scale, kv_quant_scale = get_kv_quant_scale(
                 torch_present[0])
         else:
@@ -1322,7 +1325,7 @@ class TestFunctional(unittest.TestCase):
                 torch_output[:, :in_len // 2, :].to(
                     torch.float32).cpu().numpy(),
                 atol=5e-3)
-        if attention_type == 'llama_attention':
+        if attention_type in ['llama_attention', 'gpt2_attention']:
             verify_kv_cache(torch_present[0])
         else:
             verify_kv_cache(torch_present)
@@ -1374,11 +1377,14 @@ class TestFunctional(unittest.TestCase):

             # torch execution
             if attention_type == 'gpt2_attention':
-                torch_output, torch_present = attention(
-                    input_tensor,
-                    layer_past=torch_present,
-                    use_cache=True,
-                    attention_mask=attention_mask)
+                # gpt2 uses DynamicCache
+                torch_present = DynamicCache.from_legacy_cache(
+                    torch_present)
+                torch_output = attention(input_tensor,
+                                         past_key_value=torch_present,
+                                         use_cache=True,
+                                         attention_mask=attention_mask)[0]
+                torch_present = torch_present.to_legacy_cache()
             elif attention_type == 'llama_attention':
                 position_embeddings = rotary_emb(input_tensor, position_ids)
                 attention_mask = attention_mask + AttentionMaskConverter._make_causal_mask(
@@ -754,11 +754,11 @@ class TestFunctional(unittest.TestCase):
                 dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype),
                 tgt_len=(in_len if step == 0 else 1))
             if attention_type == 'gpt2_attention':
-                torch_output, torch_present = attention(
-                    input,
-                    layer_past=layer_past,
-                    use_cache=True,
-                    attention_mask=attention_mask)
+                torch_output = attention(input,
+                                         past_key_value=layer_past,
+                                         use_cache=True,
+                                         attention_mask=attention_mask)[0]
+                torch_present = layer_past
             elif attention_type == 'llama_attention':
                 position_embeddings = rotary_emb(input, position_ids)
                 attention_mask = attention_mask + AttentionMaskConverter._make_causal_mask(
@@ -1003,8 +1003,8 @@ class TestFunctional(unittest.TestCase):
                 torch_in = input_tensor[:, offset:offset_next, :].reshape(
                     (local_beam_width, input_length, hidden_size))

-                # llama uses DynamicCache
-                if attention_type == 'llama_attention':
+                # llama/gpt2 uses DynamicCache
+                if attention_type in ['llama_attention', 'gpt2_attention']:
                     past_key_values = DynamicCache.from_legacy_cache(
                         torch_cache_list[req_idx])
                 else:
@@ -1014,8 +1014,8 @@ class TestFunctional(unittest.TestCase):
                     step, torch_in, ctx_attention_mask_list[req_idx], req_idx,
                     past_key_values)

-                # llama uses DynamicCache
-                if attention_type == 'llama_attention':
+                # llama/gpt2 uses DynamicCache
+                if attention_type in ['llama_attention', 'gpt2_attention']:
                     torch_cache_list[req_idx] = past_key_values.to_legacy_cache(
                     )
                     past_key_values = torch_cache_list[req_idx][0]