[TRTLLM-6420][feat] add support for Eclairv2 model - cherry-pick changes and minor fix (#6493)
Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
parent d06675071e
commit 97787883c3
@@ -14,10 +14,11 @@ import safetensors
 from helper import (convert_weight_to_dtype, fairseq_sin_pos_embedding,
                     fuse_qkv_one_layer, reshape, split)
 from transformers import (AutoModelForSeq2SeqLM, Blip2ForConditionalGeneration,
-                          MBartForConditionalGeneration,
+                          MBartForConditionalGeneration, NougatProcessor,
                           Pix2StructForConditionalGeneration,
                           T5ForConditionalGeneration, VisionEncoderDecoderModel)
 
+from tensorrt_llm._utils import pad_vocab_size
 from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType,
                                      MLPType)
 from tensorrt_llm.layers import LanguageAdapterConfig
@@ -30,6 +31,9 @@ layernorm_type_map = {i.name: i.value for i in LayerNormType}
 layernorm_position_map = {i.name: i.value for i in LayerNormPositionType}
 mlp_type_map = {i.name: i.value for i in MLPType}
 
+# Constants for specific model configurations
+ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS = 20000
+
 
 def copy_args_to_component_config(component_config, args):
     for arg in vars(args):
@@ -619,14 +623,19 @@ def parse_bart_config(args, hf_model):
     config = configparser.ConfigParser()
 
     config['decoder'] = dict()
-    for key, val in hf_model.model.decoder.config.to_dict().items():
-        config["decoder"][key] = f"{val}"
+    if args.eclair_radio:
+        for key, val in hf_model.config.to_dict().items():
+            config["decoder"][key] = f"{val}"
+    else:
+        for key, val in hf_model.model.decoder.config.to_dict().items():
+            config["decoder"][key] = f"{val}"
     config["decoder"]["q_scaling"] = '1'
     config["decoder"]["rescale_before_lm_head"] = str(False)
     config['decoder']['has_model_final_layernorm'] = str(
-        args.nougat or isinstance(hf_model, MBartForConditionalGeneration))
+        args.nougat or args.eclair_radio
+        or isinstance(hf_model, MBartForConditionalGeneration))
 
-    if args.nougat:
+    if args.nougat or args.eclair_radio:
         # These flags are true for mbart decoders, but missing in HF config
         config['decoder']['normalize_before'] = str(True)
         config['decoder']['normalize_embeddings'] = str(True)
@@ -763,10 +772,14 @@ def parse_bart_config(args, hf_model):
         return component_config
 
     encoder_config = None
-    if not args.nougat:
+    if not (args.nougat or args.eclair_radio):
         encoder_config = parse_bart_config_by_component(config, "encoder", args)
     decoder_config = parse_bart_config_by_component(config, "decoder", args)
 
+    # Override n_positions for eclair_radio model
+    if args.eclair_radio:
+        decoder_config.n_positions = ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS
+
     return encoder_config, decoder_config
 
 
@@ -952,11 +965,22 @@ def convert_bart_weights_to_tllm_safetensors(config, component, params):
                      (hidden_size * 3 // mapping.tp_size)))
 
     if component == 'decoder':
+        import torch
+        lm_head_weights = params['lm_head.weight'].clone().detach()
+        vocab_size = config.vocab_size
+        if params['lm_head.weight'].shape[0] % mapping.tp_size != 0:
+            vocab_size_padded = pad_vocab_size(config.vocab_size,
+                                               mapping.tp_size)
+            pad_width = vocab_size_padded - config.vocab_size
+
+            lm_head_weights = torch.nn.functional.pad(lm_head_weights,
+                                                      (0, 0, 0, pad_width),
+                                                      'constant',
+                                                      value=0)
+            vocab_size = vocab_size_padded
         weights['lm_head.weight'] = reshape(
-            split(params['lm_head.weight'],
-                  mapping.tp_size,
-                  mapping.tp_rank,
-                  dim=0), (config.vocab_size // mapping.tp_size, hidden_size))
+            split(lm_head_weights, mapping.tp_size, mapping.tp_rank, dim=0),
+            (vocab_size // mapping.tp_size, hidden_size))
 
     if config.has_model_final_layernorm:
         weights['transformer.ln_f.weight'] = params[
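Note on the lm_head padding introduced above: the rows of `lm_head.weight` must divide evenly across the `mapping.tp_size` ranks, so the vocabulary is rounded up and the extra rows are zero-filled before the split. A minimal standalone sketch of that arithmetic (illustrative only; `pad_to_multiple` is a hypothetical stand-in for TensorRT-LLM's `pad_vocab_size`, and the sizes are made up):

```python
import torch


def pad_to_multiple(vocab_size: int, tp_size: int) -> int:
    # Round vocab_size up to the next multiple of tp_size.
    return ((vocab_size + tp_size - 1) // tp_size) * tp_size


vocab_size, hidden_size, tp_size = 50265, 1024, 4
lm_head = torch.randn(vocab_size, hidden_size)

vocab_size_padded = pad_to_multiple(vocab_size, tp_size)   # 50268
pad_width = vocab_size_padded - vocab_size                  # 3 zero rows
lm_head = torch.nn.functional.pad(lm_head, (0, 0, 0, pad_width))

# The per-rank split now comes out even, matching the reshape in the hunk above.
shards = torch.chunk(lm_head, tp_size, dim=0)
assert all(s.shape == (vocab_size_padded // tp_size, hidden_size) for s in shards)
```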
@@ -1479,6 +1503,113 @@ def get_model(args):
         if args.nougat:
             model = VisionEncoderDecoderModel.from_pretrained(args.model_dir)
             model = model.get_decoder()
+        elif args.eclair_radio:
+            import torch
+
+            class RadioWithNeck(torch.nn.Module):
+
+                def __init__(self):
+                    super().__init__()
+
+                    self.model_encoder = torch.hub.load("NVlabs/RADIO",
+                                                        "radio_model",
+                                                        version="radio_v2.5-h")
+                    self.model_encoder.summary_idxs = torch.tensor(4)
+
+                    self.conv1 = torch.nn.Conv1d(1280, 1024, 1)
+                    self.layer_norm1 = torch.nn.LayerNorm(
+                        1024, eps=1e-6, elementwise_affine=True)
+                    self.conv2 = torch.nn.Conv2d(1024,
+                                                 1024,
+                                                 kernel_size=(1, 4),
+                                                 stride=(1, 4),
+                                                 padding=0,
+                                                 bias=False)
+                    self.layer_norm2 = torch.nn.LayerNorm(
+                        1024, eps=1e-6, elementwise_affine=True)
+
+                def forward(self, pixel_values):
+                    _, feature = self.model_encoder(pixel_values)
+                    output = self.conv1(feature.permute(0, 2,
+                                                        1)).permute(0, 2, 1)
+                    output = self.layer_norm1(output).permute(0, 2, 1)
+
+                    b, d, _ = output.shape
+                    h = pixel_values.shape[-2] // 16
+                    w = pixel_values.shape[-1] // 16
+                    output = self.conv2(output.reshape(b, d, h, w))
+                    output = output.flatten(-2, -1).permute(0, 2, 1)
+                    output = self.layer_norm2(output)
+                    return output
+
+            def get_processor():
+                processor = NougatProcessor.from_pretrained(
+                    "facebook/nougat-base")
+
+                special_tokens = {
+                    "output_plain_index": "<output_plain>",
+                    "output_markdown_index": "<output_markdown>",
+                    "output_no_text_index": "<output_no_text>",
+                    "output_ocr_index": "<output_ocr>",
+                    "predict_bbox_index": "<predict_bbox>",
+                    "no_bbox_index": "<no_bbox>",
+                    "bbox_start_index": "<bbox>",  # not used but can keep
+                    # "bbox_end_index": "</bbox>",  # not used but can keep
+                    "no_class_index": "<no_classes>",
+                    "predict_classes_index": "<predict_classes>",
+                }
+                for key, special_t in special_tokens.items():
+                    processor.tokenizer.add_special_tokens(
+                        {"additional_special_tokens": [special_t]})
+                    setattr(processor.tokenizer, key,
+                            processor.tokenizer.encode(special_t)[1])
+
+                # Add regular tokens for boxes
+                processor.tokenizer.add_tokens(
+                    [f"<x_{x_i}>" for x_i in range(1024)])
+                processor.tokenizer.add_tokens(
+                    [f"<y_{y_i}>" for y_i in range(1280)])
+                # Add regular tokens for classes
+                # "<class_{class_i}>"
+                possible_classes = [
+                    "Text", "Title", "Section-header", "List-item", "TOC",
+                    "Bibliography", "Footnote", "Page-header", "Page-footer",
+                    "Picture", "Formula", "Page-number", "Table", "Caption"
+                ]
+                processor.tokenizer.add_tokens(
+                    [f"<class_{cls}>" for cls in possible_classes])
+                return processor
+
+            processor = get_processor()
+            model = VisionEncoderDecoderModel.from_pretrained(
+                "facebook/nougat-base")
+            model.encoder = RadioWithNeck()
+            model.decoder.resize_token_embeddings(len(processor.tokenizer),
+                                                  pad_to_multiple_of=64)
+            model.config.decoder_start_token_id = processor.tokenizer.eos_token_id  # 2
+            model.config.pad_token_id = processor.tokenizer.pad_token_id  # 1
+            from transformers.models.mbart.modeling_mbart import \
+                MBartLearnedPositionalEmbedding
+            _, d_model = model.device, model.config.decoder.d_model
+
+            with torch.inference_mode():
+                # Inspect checkpoint shapes
+                safetensors.torch.load_model(model,
+                                             os.path.join(
+                                                 args.model_dir,
+                                                 "model.safetensors"),
+                                             strict=False)
+                model.decoder.model.decoder.embed_positions = MBartLearnedPositionalEmbedding(
+                    ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS, d_model)
+                model.decoder.model.decoder.embed_positions.weight.data.zero_()
+                model.decoder.model.decoder.embed_positions.weight.requires_grad_(
+                    True)
+                model.decoder.lm_head.weight = model.decoder.get_input_embeddings(
+                ).weight
+
+            model.eval()
+            model = model.get_decoder()
+
         else:
             model = AutoModelForSeq2SeqLM.from_pretrained(args.model_dir)
     elif args.model_type == "pix2struct":
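To make the new encoder neck concrete, here is a hedged shape walk-through of `RadioWithNeck.forward`, using a dummy feature tensor in place of the RADIO backbone and assuming the 2048x1648 input used by the engine builder plus the 16-pixel patch grid implied by the `h = ... // 16` lines:

```python
import torch

B, H, W = 1, 2048, 1648
h, w = H // 16, W // 16                        # 128 x 103 patch grid
feature = torch.randn(B, h * w, 1280)          # stand-in for RADIO features

x = torch.nn.Conv1d(1280, 1024, 1)(feature.permute(0, 2, 1)).permute(0, 2, 1)
x = torch.nn.LayerNorm(1024)(x).permute(0, 2, 1)            # (B, 1024, h*w)
x = torch.nn.Conv2d(1024, 1024, kernel_size=(1, 4), stride=(1, 4),
                    bias=False)(x.reshape(B, 1024, h, w))   # merge 4 columns
x = x.flatten(-2, -1).permute(0, 2, 1)
print(x.shape)                                 # torch.Size([1, 3200, 1024])
```

The (1, 4) convolution merges every four horizontal patch tokens into one, so the mBART decoder cross-attends over roughly a quarter of the raw patch sequence.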
@@ -1522,14 +1653,23 @@ def convert_checkpoint(args):
     quant_algo = None
 
     model_type = args.model_type if args.model_type != "blip2" else "t5"
-    encoder_config, decoder_config = globals()[f'parse_{model_type}_config'](
-        args, model)
+    parse_config_mapper = {
+        't5': parse_t5_config,
+        'pix2struct': parse_pix2struct_config,
+        'blip2': parse_t5_config,  # blip2 uses t5 config parser
+        'language_adapter': parse_language_adapter_config,
+        'nmt': parse_nmt_config,
+        'bart': parse_bart_config,
+    }
+    encoder_config, decoder_config = parse_config_mapper[model_type](args,
+                                                                     model)
 
     additional_settings = ["gated_act"]
     if model_type == 'language_adapter':
         additional_settings += ["residual_scaling", "language_adapter_config"]
 
-    if not args.nougat and args.model_type != "pix2struct":
+    if not (args.nougat
+            or args.eclair_radio) and args.model_type != "pix2struct":
         tllm_encoder_config = {
             'architecture': "EncoderModel",
             'dtype': args.dtype,
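The explicit `parse_config_mapper` table replaces the old `globals()[f'parse_{model_type}_config']` lookup. A small illustrative sketch (placeholder parser body, not the real one) of why the table form is easier to guard:

```python
def parse_t5_config(args, model):
    return None, None  # placeholder body for the sketch


parse_config_mapper = {'t5': parse_t5_config}

model_type = 't5'
if model_type not in parse_config_mapper:
    # An unknown type now fails with an explicit message instead of a bare
    # KeyError/NameError from a dynamic globals() lookup.
    raise ValueError(f"Unsupported model type '{model_type}'; "
                     f"expected one of {sorted(parse_config_mapper)}")
encoder_config, decoder_config = parse_config_mapper[model_type](None, None)
```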
@@ -1664,7 +1804,8 @@ def convert_checkpoint(args):
         decoder_convert_args["sin_pos_embedding"] = sin_pos_embedding
 
     if args.workers == 1:
-        if not args.nougat and args.model_type != "pix2struct":
+        if not (args.nougat
+                or args.eclair_radio) and args.model_type != "pix2struct":
             convert(0, world_size, args, tllm_encoder_config,
                     encoder_convert_args, encoder_saved_dir)
         convert(0, world_size, args, tllm_decoder_config, decoder_convert_args,
@@ -1674,7 +1815,8 @@ def convert_checkpoint(args):
             args.workers = world_size
         LOGGER.info(f'Convert checkpoint using {args.workers} workers.')
         import torch.multiprocessing as mp
-        if not args.nougat and args.model_type != "pix2struct":
+        if not (args.nougat
+                or args.eclair_radio) and args.model_type != "pix2struct":
             mp.spawn(convert,
                      nprocs=args.workers,
                      args=(world_size, args, tllm_encoder_config,
@@ -1736,6 +1878,9 @@ if __name__ == "__main__":
     parser.add_argument("--nougat",
                         action="store_true",
                         help="Model which uses vision encoder + mbart decoder")
+    parser.add_argument("--eclair_radio",
+                        action="store_true",
+                        help="Model which uses vision encoder + mbart decoder")
     parser.add_argument("--verbose",
                         action="store_true",
                         help="Provide verbose messages")
examples/models/core/multimodal/requirements-eclair.txt (new file)
@@ -0,0 +1 @@
+timm
@@ -20,7 +20,8 @@ import tensorrt as trt
 import torch
 
 from tensorrt_llm._common import default_net
-from tensorrt_llm._utils import numpy_to_torch, str_dtype_to_torch
+from tensorrt_llm._utils import (numpy_to_torch, pad_vocab_size,
+                                 str_dtype_to_torch)
 from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType,
                                      MLPType, PositionEmbeddingType, Tensor,
                                      assertion, cast, gather_last_token_logits,
@@ -1156,9 +1157,11 @@ class DecoderModel(PretrainedModel):
         self.transformer.assign_module(decoder_layers, "layers")
 
         if self.mapping.is_last_pp_rank():
+            vocab_size_padded = pad_vocab_size(self.config.vocab_size,
+                                               self.config.mapping.tp_size)
             self.lm_head = ColumnLinear(
                 self.config.hidden_size,
-                self.config.vocab_size,
+                vocab_size_padded,
                 bias=False if not hasattr(self.config, "has_lm_head_bias") else
                 self.config.has_lm_head_bias,
                 dtype=self.config.dtype,
@@ -1208,7 +1211,6 @@ class DecoderModel(PretrainedModel):
         config.set_if_not_exist('num_buckets', None)
         config.set_if_not_exist('max_distance', None)
         config.set_if_not_exist('relative_attention', False)
-        config.set_if_not_exist('residual_scaling', 1.0)
 
     def forward(self,
                 decoder_input_ids: Tensor,
@@ -25,25 +25,25 @@ from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
 import torch.nn as nn
 import torch.nn.functional as F
 from PIL import Image
-from safetensors.torch import save_file
+from safetensors.torch import load_model, save_file
 from transformers import CLIPImageProcessor
 
 from ..runtime.session import Session
 
 
 def add_multimodal_arguments(parser):
-    parser.add_argument('--model_type',
-                        type=str,
-                        default=None,
-                        choices=[
-                            'blip2', 'llava', 'llava_next', 'llava_onevision',
-                            'llava_onevision_lmms', 'vila', 'nougat', 'cogvlm',
-                            'fuyu', 'pix2struct', 'neva', 'kosmos-2',
-                            'video-neva', 'phi-3-vision', 'phi-4-multimodal',
-                            'mllama', 'internvl', 'qwen2_vl',
-                            'internlm-xcomposer2', 'qwen2_audio', 'pixtral'
-                        ],
-                        help="Model type")
+    parser.add_argument(
+        '--model_type',
+        type=str,
+        default=None,
+        choices=[
+            'blip2', 'llava', 'llava_next', 'llava_onevision',
+            'llava_onevision_lmms', 'vila', 'nougat', 'cogvlm', 'fuyu',
+            'pix2struct', 'neva', 'kosmos-2', 'video-neva', 'phi-3-vision',
+            'phi-4-multimodal', 'mllama', 'internvl', 'qwen2_vl',
+            'internlm-xcomposer2', 'qwen2_audio', 'pixtral', 'eclair'
+        ],
+        help="Model type")
     parser.add_argument(
         '--model_path',
         type=str,
@@ -144,6 +144,8 @@ class MultimodalEngineBuilder:
             build_qwen2_audio_engine(args)
         elif args.model_type == "pixtral":
             build_pixtral_engine(args)
+        elif args.model_type == "eclair":
+            build_eclair_engine(args)
         else:
             raise RuntimeError(f"Invalid model type {args.model_type}")
 
@@ -1739,3 +1741,78 @@ def build_pixtral_engine(args):
                      max_batch_size=args.max_batch_size,
                      engine_name=f"model.engine",
                      dtype=torch.bfloat16)
+
+
+def build_eclair_engine(args):
+
+    class RadioWithNeck(torch.nn.Module):
+
+        def __init__(self):
+            super().__init__()
+
+            try:
+                self.model_encoder = torch.hub.load("NVlabs/RADIO",
+                                                    "radio_model",
+                                                    version="radio_v2.5-h")
+            except Exception as e:
+                raise RuntimeError(
+                    f"Failed to load RADIO model from torch.hub: {e}")
+            self.model_encoder.summary_idxs = torch.tensor(4)
+
+            self.conv1 = torch.nn.Conv1d(1280, 1024, 1)
+            self.layer_norm1 = torch.nn.LayerNorm(1024,
+                                                  eps=1e-6,
+                                                  elementwise_affine=True)
+            self.conv2 = torch.nn.Conv2d(1024,
+                                         1024,
+                                         kernel_size=(1, 4),
+                                         stride=(1, 4),
+                                         padding=0,
+                                         bias=False)
+            self.layer_norm2 = torch.nn.LayerNorm(1024,
+                                                  eps=1e-6,
+                                                  elementwise_affine=True)
+
+        @torch.no_grad
+        def forward(self, pixel_values):
+            _, feature = self.model_encoder(pixel_values)
+            output = self.conv1(feature.permute(0, 2, 1)).permute(0, 2, 1)
+            output = self.layer_norm1(output).permute(0, 2, 1)
+
+            b, d, _ = output.shape
+            h = pixel_values.shape[-2] // 16
+            w = pixel_values.shape[-1] // 16
+            output = self.conv2(output.reshape(b, d, h, w))
+            output = output.flatten(-2, -1).permute(0, 2, 1)
+            output = self.layer_norm2(output)
+            return output
+
+    processor = NougatProcessor.from_pretrained(args.model_path)
+    model = VisionEncoderDecoderModel.from_pretrained("facebook/nougat-base")
+    model.encoder = RadioWithNeck()
+    model.decoder.resize_token_embeddings(len(processor.tokenizer))
+    model.config.decoder_start_token_id = processor.tokenizer.eos_token_id  # 2
+    model.config.pad_token_id = processor.tokenizer.pad_token_id  # 1
+    checkpoint_path = os.path.join(args.model_path, "model.safetensors")
+    if not os.path.exists(checkpoint_path):
+        raise FileNotFoundError(
+            f"Model checkpoint not found at {checkpoint_path}")
+    load_model(model, checkpoint_path)
+
+    wrapper = model.encoder.to(args.device)
+    # temporary fix due to TRT onnx export bug
+    for block in wrapper.model_encoder.model.blocks:
+        block.attn.fused_attn = False
+
+    image = torch.randn((1, 3, 2048, 1648),
+                        device=args.device,
+                        dtype=torch.bfloat16)
+    export_onnx(wrapper, image, f'{args.output_dir}/onnx')
+    build_trt_engine(
+        args.model_type,
+        [image.shape[1], image.shape[2], image.shape[3]],  # [3, H, W]
+        f'{args.output_dir}/onnx',
+        args.output_dir,
+        args.max_batch_size,
+        dtype=torch.bfloat16,
+        engine_name='visual_encoder.engine')
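A note on the `block.attn.fused_attn = False` workaround in `build_eclair_engine`: timm's attention blocks can run either through `torch.nn.functional.scaled_dot_product_attention` or an explicit matmul-plus-softmax path, and that flag selects between them. The sketch below (simplified; not timm's actual code) shows the two paths compute the same result, so falling back for the ONNX export does not change the encoder's outputs:

```python
import torch
import torch.nn.functional as F


def attention(q, k, v, fused: bool):
    if fused:
        # Default path: fused kernel that the "TRT onnx export bug" comment
        # above says trips the export for this model.
        return F.scaled_dot_product_attention(q, k, v)
    # Fallback path: explicit matmul + softmax, which exports cleanly.
    attn = (q @ k.transpose(-2, -1)) * q.shape[-1]**-0.5
    return attn.softmax(dim=-1) @ v


q = k = v = torch.randn(1, 8, 16, 64)
assert torch.allclose(attention(q, k, v, True),
                      attention(q, k, v, False), atol=1e-5)
```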