[TRTLLM-6420][feat] add support for Eclairv2 model - cherry-pick changes and minor fix (#6493)

Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
Yibin Li 2025-08-08 18:40:48 -07:00 committed by GitHub
parent d06675071e
commit 97787883c3
4 changed files with 256 additions and 31 deletions

View File

@@ -14,10 +14,11 @@ import safetensors
from helper import (convert_weight_to_dtype, fairseq_sin_pos_embedding,
fuse_qkv_one_layer, reshape, split)
from transformers import (AutoModelForSeq2SeqLM, Blip2ForConditionalGeneration,
MBartForConditionalGeneration,
MBartForConditionalGeneration, NougatProcessor,
Pix2StructForConditionalGeneration,
T5ForConditionalGeneration, VisionEncoderDecoderModel)
from tensorrt_llm._utils import pad_vocab_size
from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType,
MLPType)
from tensorrt_llm.layers import LanguageAdapterConfig
@@ -30,6 +31,9 @@ layernorm_type_map = {i.name: i.value for i in LayerNormType}
layernorm_position_map = {i.name: i.value for i in LayerNormPositionType}
mlp_type_map = {i.name: i.value for i in MLPType}
# Constants for specific model configurations
ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS = 20000
def copy_args_to_component_config(component_config, args):
for arg in vars(args):
@@ -619,14 +623,19 @@ def parse_bart_config(args, hf_model)
config = configparser.ConfigParser()
config['decoder'] = dict()
for key, val in hf_model.model.decoder.config.to_dict().items():
config["decoder"][key] = f"{val}"
if args.eclair_radio:
for key, val in hf_model.config.to_dict().items():
config["decoder"][key] = f"{val}"
else:
for key, val in hf_model.model.decoder.config.to_dict().items():
config["decoder"][key] = f"{val}"
config["decoder"]["q_scaling"] = '1'
config["decoder"]["rescale_before_lm_head"] = str(False)
config['decoder']['has_model_final_layernorm'] = str(
args.nougat or isinstance(hf_model, MBartForConditionalGeneration))
args.nougat or args.eclair_radio
or isinstance(hf_model, MBartForConditionalGeneration))
if args.nougat:
if args.nougat or args.eclair_radio:
# These flags are true for mbart decoders, but missing in HF config
config['decoder']['normalize_before'] = str(True)
config['decoder']['normalize_embeddings'] = str(True)
@@ -763,10 +772,14 @@ def parse_bart_config(args, hf_model)
return component_config
encoder_config = None
if not args.nougat:
if not (args.nougat or args.eclair_radio):
encoder_config = parse_bart_config_by_component(config, "encoder", args)
decoder_config = parse_bart_config_by_component(config, "decoder", args)
# Override n_positions for eclair_radio model
if args.eclair_radio:
decoder_config.n_positions = ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS
return encoder_config, decoder_config
@@ -952,11 +965,22 @@ def convert_bart_weights_to_tllm_safetensors(config, component, params):
(hidden_size * 3 // mapping.tp_size)))
if component == 'decoder':
import torch
lm_head_weights = params['lm_head.weight'].clone().detach()
vocab_size = config.vocab_size
if params['lm_head.weight'].shape[0] % mapping.tp_size != 0:
vocab_size_padded = pad_vocab_size(config.vocab_size,
mapping.tp_size)
pad_width = vocab_size_padded - config.vocab_size
lm_head_weights = torch.nn.functional.pad(lm_head_weights,
(0, 0, 0, pad_width),
'constant',
value=0)
vocab_size = vocab_size_padded
weights['lm_head.weight'] = reshape(
split(params['lm_head.weight'],
mapping.tp_size,
mapping.tp_rank,
dim=0), (config.vocab_size // mapping.tp_size, hidden_size))
split(lm_head_weights, mapping.tp_size, mapping.tp_rank, dim=0),
(vocab_size // mapping.tp_size, hidden_size))
if config.has_model_final_layernorm:
weights['transformer.ln_f.weight'] = params[
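A minimal standalone sketch of the lm_head padding introduced in this hunk, assuming pad_vocab_size simply rounds the vocabulary up to the next multiple of tp_size (the helper below is a hypothetical reimplementation for illustration; sizes are made up):

import math
import torch

def pad_vocab_size(vocab_size: int, tp_size: int) -> int:
    # assumed behavior of tensorrt_llm._utils.pad_vocab_size: round up to a multiple of tp_size
    return int(math.ceil(vocab_size / tp_size)) * tp_size

tp_size = 4
lm_head = torch.randn(50265, 1024)                   # illustrative vocab_size x hidden_size
padded = pad_vocab_size(lm_head.shape[0], tp_size)   # 50268
if lm_head.shape[0] % tp_size != 0:
    # pad extra rows with zeros so every tensor-parallel rank gets an equal slice
    lm_head = torch.nn.functional.pad(lm_head, (0, 0, 0, padded - lm_head.shape[0]))
shards = torch.chunk(lm_head, tp_size, dim=0)        # padded // tp_size rows per rank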
@@ -1479,6 +1503,113 @@ def get_model(args):
if args.nougat:
model = VisionEncoderDecoderModel.from_pretrained(args.model_dir)
model = model.get_decoder()
elif args.eclair_radio:
import torch
class RadioWithNeck(torch.nn.Module):
def __init__(self):
super().__init__()
self.model_encoder = torch.hub.load("NVlabs/RADIO",
"radio_model",
version="radio_v2.5-h")
self.model_encoder.summary_idxs = torch.tensor(4)
self.conv1 = torch.nn.Conv1d(1280, 1024, 1)
self.layer_norm1 = torch.nn.LayerNorm(
1024, eps=1e-6, elementwise_affine=True)
self.conv2 = torch.nn.Conv2d(1024,
1024,
kernel_size=(1, 4),
stride=(1, 4),
padding=0,
bias=False)
self.layer_norm2 = torch.nn.LayerNorm(
1024, eps=1e-6, elementwise_affine=True)
def forward(self, pixel_values):
_, feature = self.model_encoder(pixel_values)
output = self.conv1(feature.permute(0, 2,
1)).permute(0, 2, 1)
output = self.layer_norm1(output).permute(0, 2, 1)
b, d, _ = output.shape
h = pixel_values.shape[-2] // 16
w = pixel_values.shape[-1] // 16
output = self.conv2(output.reshape(b, d, h, w))
output = output.flatten(-2, -1).permute(0, 2, 1)
output = self.layer_norm2(output)
return output
def get_processor():
processor = NougatProcessor.from_pretrained(
"facebook/nougat-base")
special_tokens = {
"output_plain_index": "<output_plain>",
"output_markdown_index": "<output_markdown>",
"output_no_text_index": "<output_no_text>",
"output_ocr_index": "<output_ocr>",
"predict_bbox_index": "<predict_bbox>",
"no_bbox_index": "<no_bbox>",
"bbox_start_index": "<bbox>", # not used but can keep
# "bbox_end_index": "</bbox>", # not used but can keep
"no_class_index": "<no_classes>",
"predict_classes_index": "<predict_classes>",
}
for key, special_t in special_tokens.items():
processor.tokenizer.add_special_tokens(
{"additional_special_tokens": [special_t]})
setattr(processor.tokenizer, key,
processor.tokenizer.encode(special_t)[1])
# Add regular tokens for boxes
processor.tokenizer.add_tokens(
[f"<x_{x_i}>" for x_i in range(1024)])
processor.tokenizer.add_tokens(
[f"<y_{y_i}>" for y_i in range(1280)])
# Add regular tokens for classes
#"<class_{class_i}>"
possible_classes = [
"Text", "Title", "Section-header", "List-item", "TOC",
"Bibliography", "Footnote", "Page-header", "Page-footer",
"Picture", "Formula", "Page-number", "Table", "Caption"
]
processor.tokenizer.add_tokens(
[f"<class_{cls}>" for cls in possible_classes])
return processor
processor = get_processor()
model = VisionEncoderDecoderModel.from_pretrained(
"facebook/nougat-base")
model.encoder = RadioWithNeck()
model.decoder.resize_token_embeddings(len(processor.tokenizer),
pad_to_multiple_of=64)
model.config.decoder_start_token_id = processor.tokenizer.eos_token_id # 2
model.config.pad_token_id = processor.tokenizer.pad_token_id # 1
from transformers.models.mbart.modeling_mbart import \
MBartLearnedPositionalEmbedding
_, d_model = model.device, model.config.decoder.d_model
with torch.inference_mode():
# Inspect checkpoint shapes
safetensors.torch.load_model(model,
os.path.join(
args.model_dir,
"model.safetensors"),
strict=False)
model.decoder.model.decoder.embed_positions = MBartLearnedPositionalEmbedding(
ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS, d_model)
model.decoder.model.decoder.embed_positions.weight.data.zero_()
model.decoder.model.decoder.embed_positions.weight.requires_grad_(
True)
model.decoder.lm_head.weight = model.decoder.get_input_embeddings(
).weight
model.eval()
model = model.get_decoder()
else:
model = AutoModelForSeq2SeqLM.from_pretrained(args.model_dir)
elif args.model_type == "pix2struct":
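For orientation, a rough shape walk-through of the RadioWithNeck neck defined above, using the 2048x1648 input resolution that the engine builder later exports with and RADIO's 16-pixel patch grid; the RADIO backbone itself is replaced by a random stand-in feature map, so the numbers are purely illustrative:

import torch

b, c_radio, d_model = 1, 1280, 1024
h, w = 2048 // 16, 1648 // 16                  # 128 x 103 patch grid -> 13184 RADIO tokens
feature = torch.randn(b, h * w, c_radio)        # stand-in for the RADIO feature output

conv1 = torch.nn.Conv1d(c_radio, d_model, 1)    # 1280 -> 1024 channels
ln1 = torch.nn.LayerNorm(d_model, eps=1e-6)
conv2 = torch.nn.Conv2d(d_model, d_model, kernel_size=(1, 4), stride=(1, 4), bias=False)
ln2 = torch.nn.LayerNorm(d_model, eps=1e-6)

x = conv1(feature.permute(0, 2, 1)).permute(0, 2, 1)   # (1, 13184, 1024)
x = ln1(x).permute(0, 2, 1)                             # (1, 1024, 13184)
x = conv2(x.reshape(b, d_model, h, w))                  # width 103 -> 25: merges 4 horizontal patches
x = ln2(x.flatten(-2, -1).permute(0, 2, 1))             # (1, 128 * 25, 1024) = 3200 cross-attention tokens
print(x.shape)                                           # torch.Size([1, 3200, 1024])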
@@ -1522,14 +1653,23 @@ def convert_checkpoint(args):
quant_algo = None
model_type = args.model_type if args.model_type != "blip2" else "t5"
encoder_config, decoder_config = globals()[f'parse_{model_type}_config'](
args, model)
parse_config_mapper = {
't5': parse_t5_config,
'pix2struct': parse_pix2struct_config,
'blip2': parse_t5_config, # blip2 uses t5 config parser
'language_adapter': parse_language_adapter_config,
'nmt': parse_nmt_config,
'bart': parse_bart_config,
}
encoder_config, decoder_config = parse_config_mapper[model_type](args,
model)
additional_settings = ["gated_act"]
if model_type == 'language_adapter':
additional_settings += ["residual_scaling", "language_adapter_config"]
if not args.nougat and args.model_type != "pix2struct":
if not (args.nougat
or args.eclair_radio) and args.model_type != "pix2struct":
tllm_encoder_config = {
'architecture': "EncoderModel",
'dtype': args.dtype,
@@ -1664,7 +1804,8 @@ def convert_checkpoint(args):
decoder_convert_args["sin_pos_embedding"] = sin_pos_embedding
if args.workers == 1:
if not args.nougat and args.model_type != "pix2struct":
if not (args.nougat
or args.eclair_radio) and args.model_type != "pix2struct":
convert(0, world_size, args, tllm_encoder_config,
encoder_convert_args, encoder_saved_dir)
convert(0, world_size, args, tllm_decoder_config, decoder_convert_args,
@@ -1674,7 +1815,8 @@ def convert_checkpoint(args):
args.workers = world_size
LOGGER.info(f'Convert checkpoint using {args.workers} workers.')
import torch.multiprocessing as mp
if not args.nougat and args.model_type != "pix2struct":
if not (args.nougat
or args.eclair_radio) and args.model_type != "pix2struct":
mp.spawn(convert,
nprocs=args.workers,
args=(world_size, args, tllm_encoder_config,
@@ -1736,6 +1878,9 @@ if __name__ == "__main__":
parser.add_argument("--nougat",
action="store_true",
help="Model which uses vision encoder + mbart decoder")
parser.add_argument("--eclair_radio",
action="store_true",
help="Model which uses vision encoder + mbart decoder")
parser.add_argument("--verbose",
action="store_true",
help="Provide verbose messages")

View File

@@ -0,0 +1 @@
timm
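This new requirement presumably covers the RADIO backbone that the converter and the engine builder load via torch.hub above (stated here as an assumption).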

View File

@@ -20,7 +20,8 @@ import tensorrt as trt
import torch
from tensorrt_llm._common import default_net
from tensorrt_llm._utils import numpy_to_torch, str_dtype_to_torch
from tensorrt_llm._utils import (numpy_to_torch, pad_vocab_size,
str_dtype_to_torch)
from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType,
MLPType, PositionEmbeddingType, Tensor,
assertion, cast, gather_last_token_logits,
@@ -1156,9 +1157,11 @@ class DecoderModel(PretrainedModel):
self.transformer.assign_module(decoder_layers, "layers")
if self.mapping.is_last_pp_rank():
vocab_size_padded = pad_vocab_size(self.config.vocab_size,
self.config.mapping.tp_size)
self.lm_head = ColumnLinear(
self.config.hidden_size,
self.config.vocab_size,
vocab_size_padded,
bias=False if not hasattr(self.config, "has_lm_head_bias") else
self.config.has_lm_head_bias,
dtype=self.config.dtype,
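Padding the lm_head output dimension here keeps the runtime ColumnLinear shape consistent with the padded lm_head.weight written out by the checkpoint converter above, so tensor-parallel sharding divides evenly even when the vocabulary size is not already a multiple of tp_size.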
@@ -1208,7 +1211,6 @@ class DecoderModel(PretrainedModel):
config.set_if_not_exist('num_buckets', None)
config.set_if_not_exist('max_distance', None)
config.set_if_not_exist('relative_attention', False)
config.set_if_not_exist('residual_scaling', 1.0)
def forward(self,
decoder_input_ids: Tensor,

View File

@@ -25,25 +25,25 @@ from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from safetensors.torch import save_file
from safetensors.torch import load_model, save_file
from transformers import CLIPImageProcessor
from ..runtime.session import Session
def add_multimodal_arguments(parser):
parser.add_argument('--model_type',
type=str,
default=None,
choices=[
'blip2', 'llava', 'llava_next', 'llava_onevision',
'llava_onevision_lmms', 'vila', 'nougat', 'cogvlm',
'fuyu', 'pix2struct', 'neva', 'kosmos-2',
'video-neva', 'phi-3-vision', 'phi-4-multimodal',
'mllama', 'internvl', 'qwen2_vl',
'internlm-xcomposer2', 'qwen2_audio', 'pixtral'
],
help="Model type")
parser.add_argument(
'--model_type',
type=str,
default=None,
choices=[
'blip2', 'llava', 'llava_next', 'llava_onevision',
'llava_onevision_lmms', 'vila', 'nougat', 'cogvlm', 'fuyu',
'pix2struct', 'neva', 'kosmos-2', 'video-neva', 'phi-3-vision',
'phi-4-multimodal', 'mllama', 'internvl', 'qwen2_vl',
'internlm-xcomposer2', 'qwen2_audio', 'pixtral', 'eclair'
],
help="Model type")
parser.add_argument(
'--model_path',
type=str,
@@ -144,6 +144,8 @@ class MultimodalEngineBuilder:
build_qwen2_audio_engine(args)
elif args.model_type == "pixtral":
build_pixtral_engine(args)
elif args.model_type == "eclair":
build_eclair_engine(args)
else:
raise RuntimeError(f"Invalid model type {args.model_type}")
@@ -1739,3 +1741,78 @@ def build_pixtral_engine(args):
max_batch_size=args.max_batch_size,
engine_name=f"model.engine",
dtype=torch.bfloat16)
def build_eclair_engine(args):
class RadioWithNeck(torch.nn.Module):
def __init__(self):
super().__init__()
try:
self.model_encoder = torch.hub.load("NVlabs/RADIO",
"radio_model",
version="radio_v2.5-h")
except Exception as e:
raise RuntimeError(
f"Failed to load RADIO model from torch.hub: {e}")
self.model_encoder.summary_idxs = torch.tensor(4)
self.conv1 = torch.nn.Conv1d(1280, 1024, 1)
self.layer_norm1 = torch.nn.LayerNorm(1024,
eps=1e-6,
elementwise_affine=True)
self.conv2 = torch.nn.Conv2d(1024,
1024,
kernel_size=(1, 4),
stride=(1, 4),
padding=0,
bias=False)
self.layer_norm2 = torch.nn.LayerNorm(1024,
eps=1e-6,
elementwise_affine=True)
@torch.no_grad
def forward(self, pixel_values):
_, feature = self.model_encoder(pixel_values)
output = self.conv1(feature.permute(0, 2, 1)).permute(0, 2, 1)
output = self.layer_norm1(output).permute(0, 2, 1)
b, d, _ = output.shape
h = pixel_values.shape[-2] // 16
w = pixel_values.shape[-1] // 16
output = self.conv2(output.reshape(b, d, h, w))
output = output.flatten(-2, -1).permute(0, 2, 1)
output = self.layer_norm2(output)
return output
processor = NougatProcessor.from_pretrained(args.model_path)
model = VisionEncoderDecoderModel.from_pretrained("facebook/nougat-base")
model.encoder = RadioWithNeck()
model.decoder.resize_token_embeddings(len(processor.tokenizer))
model.config.decoder_start_token_id = processor.tokenizer.eos_token_id # 2
model.config.pad_token_id = processor.tokenizer.pad_token_id # 1
checkpoint_path = os.path.join(args.model_path, "model.safetensors")
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(
f"Model checkpoint not found at {checkpoint_path}")
load_model(model, checkpoint_path)
wrapper = model.encoder.to(args.device)
# temporary fix due to TRT onnx export bug
for block in wrapper.model_encoder.model.blocks:
block.attn.fused_attn = False
image = torch.randn((1, 3, 2048, 1648),
device=args.device,
dtype=torch.bfloat16)
export_onnx(wrapper, image, f'{args.output_dir}/onnx')
build_trt_engine(
args.model_type,
[image.shape[1], image.shape[2], image.shape[3]], # [3, H, W]
f'{args.output_dir}/onnx',
args.output_dir,
args.max_batch_size,
dtype=torch.bfloat16,
engine_name='visual_encoder.engine')
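A hedged usage sketch for the new builder, driving build_eclair_engine directly with the argument attributes it reads above (assuming the function is importable from the builder module); paths, device, and batch size are illustrative, and in practice the namespace comes from add_multimodal_arguments:

from argparse import Namespace

# Illustrative values only; the real args normally come from the CLI parser above.
args = Namespace(
    model_type="eclair",
    model_path="/path/to/eclair_checkpoint",   # must contain model.safetensors and the Nougat processor files
    output_dir="/tmp/eclair_visual_engine",
    max_batch_size=1,
    device="cuda:0",
)
build_eclair_engine(args)  # exports ONNX, then builds visual_encoder.engine under output_dir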