renew repo, shrink down git sizes

This commit is contained in:
Your Name
2025-02-22 12:25:49 +08:00
commit 7bdff0d271
99 changed files with 11758 additions and 0 deletions
+10
View File
@@ -0,0 +1,10 @@
checkpoints/
__pycache__/
.DS_Store
*.egg-info/
dist/
vendor/
eval_results/
*.webui_secret_key
+39
View File
@@ -0,0 +1,39 @@
from namo.api.vl import VLInfer
import os
from termcolor import colored
import torch
def chat():
    """Run an interactive terminal chat against the Namo VL model.

    Each line typed by the user is either an image path (when its first
    space-separated token exists on disk), which replaces the pending image,
    or a text prompt, which is sent to the model together with the pending
    image when there is one. The pending image is cleared after it has been
    sent once.

    Typing ``exit``/``quit`` or pressing Ctrl-C / Ctrl-D ends the session
    cleanly instead of surfacing a traceback.
    """
    model = VLInfer(
        model_type="namo", device="cuda:0" if torch.cuda.is_available() else "cpu"
    )
    # crt_input[0]: pending image path, crt_input[1]: latest text prompt.
    crt_input = ["images/cats.jpg", None]
    while True:
        try:
            img_or_txt = input(colored("\nUser (txt/img_path): ", "cyan")).strip()
        except (EOFError, KeyboardInterrupt):
            # input() raises these on Ctrl-D / Ctrl-C; treat both as "quit".
            print(colored("\nSystem: Session ended.", "red"))
            break
        if img_or_txt.lower() in ("exit", "quit"):
            break

        if os.path.exists(img_or_txt.split(" ")[0]):
            crt_input[0] = img_or_txt
            print(colored("System: Image updated.", "green"))
            continue
        crt_input[1] = img_or_txt

        if crt_input[0] and crt_input[1]:
            print(colored("Assistant:", "green"), end=" ")
            model.generate(images=crt_input[0], prompt=crt_input[1], verbose=False)
            # The image has been consumed; follow-up turns are text only.
            crt_input[0] = None
        elif not crt_input[0] and crt_input[1]:
            # pure text
            print(colored("Assistant:", "green"), end=" ")
            model.generate(images=None, prompt=crt_input[1], verbose=False)
        else:
            print(
                colored("System: Please provide either an image or text input.", "red")
            )


if __name__ == "__main__":
    chat()
+199
View File
@@ -0,0 +1,199 @@
import argparse
import json
import os
import random
import torch
from termcolor import colored
from transformers import TextStreamer
from namo.models.namo import NamoForCausalLM
from namo.models.configuration_namo import NamoConfig
from namo.utils.infer_utils import load_multi_images_maybe
from namo.utils.hf_utils import find_and_merge_lora_adapters
from namo.utils.process_utils import tokenizer_image_token
from loguru import logger
def load_model_simple(model_path):
    """Load a Namo model from ``model_path``.

    When ``model_path`` contains ``non_lora_trainables.bin`` the directory is
    treated as a LoRA checkpoint: the model is built from its config, the
    non-LoRA weights are loaded (non-strictly) and the LoRA adapters are
    merged via ``find_and_merge_lora_adapters``. Otherwise the directory is
    loaded directly with ``NamoForCausalLM.from_pretrained``.

    Returns:
        The loaded model.
    """
    non_lora_bin = os.path.join(model_path, "non_lora_trainables.bin")
    if not os.path.exists(non_lora_bin):
        return NamoForCausalLM.from_pretrained(model_path)

    logger.info(f"loading lora: {model_path}")
    config = NamoConfig.from_pretrained(model_path)
    model = NamoForCausalLM(config=config)
    # map_location="cpu" lets tensors saved from a GPU be read on CPU-only
    # hosts; load_state_dict copies them into the model's own parameters.
    non_lora = torch.load(non_lora_bin, map_location="cpu")
    non_lora = {k.replace("base_model.model.", ""): v for k, v in non_lora.items()}
    model.load_state_dict(non_lora, strict=False)
    return find_and_merge_lora_adapters(model, model_path)
def main():
    """Parse the command line, load the model, then run eval or the chat CLI."""
    arg_parser = argparse.ArgumentParser()
    for flag, options in (
        ("--model-path", {"default": "checkpoints/namo-500m"}),
        ("--eval", {"action": "store_true"}),
        ("--debug", {"action": "store_true"}),
    ):
        arg_parser.add_argument(flag, **options)
    cli = arg_parser.parse_args()

    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    model = load_model_simple(cli.model_path)
    model.eval().to(device)
    image_processor = model.get_vision_tower().image_processor
    tokenizer = model.get_namo().tokenizer

    if not cli.eval:
        run_cli(model, image_processor, tokenizer, device)
        return

    # Evaluation mode: answer one question for every entry of the eval file.
    with open("images/evals.json") as eval_file:
        for entry in json.load(eval_file):
            handle_eval_item(
                entry, model, image_processor, tokenizer, device, cli.debug
            )
def handle_eval_item(item, model, image_processor, tokenizer, device, debug=False):
    """Answer one evaluation entry and stream the answer to stdout.

    Args:
        item: Mapping with the keys ``"image"``, ``"question1"`` and
            ``"question2"``; one of the two questions is picked at random.
        model: Model passed through to ``generate_response``.
        image_processor: Processor used to turn the image(s) into pixels.
            Its ``size["shortest_edge"]`` is set to 448 as a side effect.
        tokenizer: Tokenizer providing the chat template.
        device: Device the pixel tensor is moved to.
        debug: When true, log the pixel tensor shape.
    """
    image_path = item["image"]
    question = random.choice([item["question1"], item["question2"]])

    images = load_multi_images_maybe(image_path)
    image_processor.size["shortest_edge"] = 448
    pixel_values = (
        image_processor.preprocess(images, do_resize=True, return_tensors="pt")[
            "pixel_values"
        ]
        .to(device)
        .to(model.dtype)
    )
    if debug:
        logger.info(f"pixel_values: {pixel_values.shape}")

    chat = [
        {"role": "system", "content": "Follow instructions carefully."},
        {"role": "user", "content": f"<image>\n{question}"},
    ]
    prompt = (
        tokenizer.apply_chat_template(chat, tokenize=False) + "<|im_start|>assistant\n"
    )
    # Tokenisation is done inside generate_response; the previous local
    # `input_ids` computed here was never used and has been removed.

    print(f"\nImage: {image_path}\nQ: {question}\n", end="")
    print(colored("AI: ", "yellow"), end="")
    generate_response(model, tokenizer, pixel_values, prompt)
    print("\n")
def run_cli(model, image_processor, tokenizer, device):
    """Interactive chat loop for the command-line tool.

    A default image is loaded first. Every input line is then either a path
    to an existing file (loads that image and restarts the conversation) or
    a text prompt that is answered against the current image. ``exit`` /
    ``quit``, Ctrl-C or Ctrl-D end the session.

    Args:
        model: Model used for generation.
        image_processor: Processor used by ``process_image``; its
            ``size["shortest_edge"]`` is set to 448 as a side effect.
        tokenizer: Tokenizer providing the chat template.
        device: Device the pixel tensors are moved to.
    """
    DEFAULT_IMAGE = "images/cats.jpg"

    def _fresh_history(image_ref):
        # A new conversation: system prompt plus a user turn holding only
        # the image, to which the first text prompt is attached later.
        return [
            {"role": "system", "content": "Respond carefully."},
            {
                "role": "user",
                "content": [{"type": "image_url", "image_url": image_ref}],
            },
        ]

    # Configure the resize target before the first image is processed so the
    # default image goes through the same settings as images loaded later.
    image_processor.size["shortest_edge"] = 448
    current_pixels = process_image(DEFAULT_IMAGE, image_processor, model, device)
    messages = _fresh_history(DEFAULT_IMAGE)

    while True:
        try:
            user_input = input(colored("\nUser (txt/img_path): ", "green")).strip()
            if not user_input:
                continue
            if user_input.lower() in ("exit", "quit"):
                break

            if os.path.exists(user_input):
                current_pixels = process_image(
                    user_input, image_processor, model, device
                )
                messages = _fresh_history(user_input)
                print(colored("System: New image loaded", "yellow"))
                continue

            last_user_msg = next(
                (m for m in reversed(messages) if m["role"] == "user"), None
            )
            if last_user_msg and not has_text_content(last_user_msg):
                # The image-only turn is still waiting for its question.
                last_user_msg["content"].append({"type": "text", "text": user_input})
            else:
                messages.append(
                    {"role": "user", "content": [{"type": "text", "text": user_input}]}
                )

            prompt = build_chat_prompt(messages, tokenizer)
            print(colored("Assistant: ", "blue"), end="")
            response = generate_response(model, tokenizer, current_pixels, prompt)
            messages.append(
                {"role": "assistant", "content": [{"type": "text", "text": response}]}
            )
        except (KeyboardInterrupt, EOFError):
            # Ctrl-C, or end of input (Ctrl-D / closed stdin).
            print(colored("\nSession ended.", "red"))
            break
def process_image(path, processor, model, device):
    """Load the image(s) at ``path`` and return their pixel tensor.

    The tensor comes from ``processor.preprocess`` and is moved to ``device``
    and then cast to ``model.dtype``.
    """
    loaded = load_multi_images_maybe(path)
    encoded = processor.preprocess(loaded, return_tensors="pt")
    on_device = encoded["pixel_values"].to(device)
    return on_device.to(model.dtype)
def build_chat_prompt(messages, tokenizer):
    """Render structured chat ``messages`` into a prompt string.

    System messages are passed through unchanged. For every other message the
    content parts are flattened into one newline-joined string: ``image_url``
    parts become ``<image>`` and ``text`` parts contribute their text. The
    tokenizer's chat template is applied and an assistant header appended.
    """

    def _flatten(parts):
        pieces = []
        for part in parts:
            kind = part["type"]
            if kind == "image_url":
                pieces.append("<image>")
            elif kind == "text":
                pieces.append(part["text"])
        return "\n".join(pieces)

    flattened = [
        (
            entry
            if entry["role"] == "system"
            else {"role": entry["role"], "content": _flatten(entry["content"])}
        )
        for entry in messages
    ]
    templated = tokenizer.apply_chat_template(flattened, tokenize=False)
    return templated + "<|im_start|>assistant\n"
def generate_response(model, tokenizer, pixels, prompt):
    """Generate a reply for ``prompt`` and return it as stripped text.

    Tokens are streamed to stdout through a ``TextStreamer`` while the model
    generates; decoding is greedy and capped at 360 new tokens. The
    tokenizer's pad token id is used both as EOS and as padding id.
    """
    token_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
    input_ids = token_ids.unsqueeze(0).to(model.device)

    pad_id = tokenizer.pad_token_id
    generation_kwargs = {
        "pixel_values": pixels,
        "input_ids": input_ids,
        "do_sample": False,
        "max_new_tokens": 360,
        "streamer": TextStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        ),
        "eos_token_id": pad_id,
        "pad_token_id": pad_id,
    }
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        output_ids = model.generate(**generation_kwargs)

    decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return decoded.strip()
def has_text_content(message):
    """Return True when any content part of ``message`` has type ``"text"``."""
    for part in message["content"]:
        if part["type"] == "text":
            return True
    return False


if __name__ == "__main__":
    main()
Binary file not shown.

After

Width:  |  Height:  |  Size: 50 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 13 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 261 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 61 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 276 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 32 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 58 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 188 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 76 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 34 KiB

View File
View File
+35
View File
@@ -0,0 +1,35 @@
import os
from loguru import logger
import torch
class VLBase:
    """Common base for vision-language model wrappers.

    Subclasses implement ``load_model`` and ``load_processor``; both are
    called from ``__init__``, in that order.
    """

    def __init__(self, model_path=None, processor_path=None, device="auto"):
        """Resolve the device, then load the model and its processor.

        Args:
            model_path: Passed to ``load_model``.
            processor_path: Passed to ``load_processor``; falls back to
                ``model_path`` when not given.
            device: ``"auto"`` picks CUDA, then MPS, then CPU. Any other
                value is used as given.
        """
        self.device = self._resolve_device(device)
        self.model = self.load_model(model_path)
        # Use the explicit processor path when provided, else the model path.
        # (The condition was previously inverted, so an explicit processor
        # path was replaced by the model path and the fallback never applied.)
        self.processor = self.load_processor(
            processor_path if processor_path is not None else model_path
        )
        self.history_msgs = []

    @staticmethod
    def _resolve_device(device):
        """Return ``device`` unless it is ``"auto"``/``None``, then autodetect."""
        if device is not None and device != "auto":
            return device
        if torch.cuda.is_available():
            return "cuda:0"
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    def load_model(self, model_path):
        """Load and return the model. Must be implemented by subclasses."""
        raise NotImplementedError

    def load_processor(self, processor_path):
        """Load and return the processor. Must be implemented by subclasses."""
        raise NotImplementedError

    def stream_chat_with_images(self):
        """Streaming chat entry point. Must be implemented by subclasses."""
        raise NotImplementedError

    def generate(self, prompt, image, verbose):
        """Placeholder generation hook; does nothing in the base class."""
        pass
+346
View File
@@ -0,0 +1,346 @@
import os
import threading
from typing import AsyncGenerator
from namo.api.base import VLBase
from loguru import logger
import torch
from termcolor import colored
from transformers import TextStreamer
from transformers import AutoProcessor
from namo.models.namo import NamoForCausalLM
from namo.models.configuration_namo import NamoConfig
from namo.utils.infer_utils import CallbackStreamer, load_multi_images_maybe
from namo.utils.process_utils import convert_image_tags, tokenizer_image_token
from loguru import logger
from huggingface_hub import hf_hub_download, snapshot_download
from namo.utils.process_utils import smart_resize_v1
from transformers.models.clip.image_processing_clip import CLIPImageProcessor
from namo.utils.infer_utils import url_to_image
from transformers import TextIteratorStreamer
class NamoVL(VLBase):
def __init__(
    self,
    model_path=None,
    processor_path=None,
    device="auto",
    system_msg="You are Namo small VLM model, trained by NAMO. You can look images and with great OCR ability.",
):
    """Load the model via the base class and start the chat history.

    The history starts with a single system turn built from ``system_msg``;
    the same turn is kept as ``default_sys``.
    """
    super().__init__(model_path, processor_path, device)
    system_turn = {"role": "system", "content": system_msg}
    self.default_sys = system_turn
    self.history_msgs = [system_turn]
def load_model(self, model_path):
# Load the Namo checkpoint, switch it to eval mode, move it to self.device
# and return it.
# NOTE(review): indentation was lost in this view, so it cannot be told
# whether the existence check / download below runs only for the default
# path or for any `model_path` that does not exist - confirm against the repo.
if model_path is None:
# No path given: fall back to the default local checkpoint directory.
model_path = "checkpoints/Namo-500M-V1"
if not os.path.exists(model_path):
logger.info(f"downloading model from huggingface into: {model_path}")
# Fetch the lucasjin/Namo-500M-V1 repository into `model_path`.
snapshot_download(
repo_id="lucasjin/Namo-500M-V1",
local_dir=model_path,
local_dir_use_symlinks=False,
)
model = NamoForCausalLM.from_pretrained(
model_path,
torch_dtype="auto",
# device_map="auto"
)
model.eval().to(self.device)
logger.info(f"model loaded from: {model_path}")
return model
def load_processor(self, processor_path):
    """Take the image processor and tokenizer from the loaded model.

    ``processor_path`` is accepted for interface compatibility but unused:
    both objects are read from ``self.model``. Also stored as
    ``self.image_processor`` / ``self.tokenizer``.

    Returns:
        The vision tower's image processor.
    """
    processor = self.model.get_vision_tower().image_processor
    self.image_processor = processor
    self.tokenizer = self.model.get_namo().tokenizer

    tok = self.tokenizer
    if tok.pad_token_id is None:
        # pad_token_id must be a single int. encode() returns a list of ids
        # (and fails when there is no pad token at all), so resolve the id
        # from the pad token, or reuse the EOS token as padding.
        if tok.pad_token is not None:
            tok.pad_token_id = tok.convert_tokens_to_ids(tok.pad_token)
        else:
            tok.pad_token = tok.eos_token
    return processor
def build_chat_prompt(self, messages, tokenizer):
    """Render chat ``messages`` into the prompt string fed to the model.

    * System messages are passed through unchanged.
    * Assistant messages are emitted as ``{"role", "content"}`` dicts with
      plain-text content (previously the bare content string was appended,
      which is not a valid chat-template entry).
    * User messages with string content are used as written.
    * User messages with a list of parts: if the (last) text part already
      contains exactly as many ``<image>`` tags as there are ``image_url``
      parts, that text is used untouched; otherwise one ``<image>`` tag per
      image is placed before the text and the result is passed through
      ``convert_image_tags``.

    Returns:
        The templated conversation followed by an assistant header.
    """
    converted = []
    for msg in messages:
        role = msg["role"]
        content = msg["content"]

        if role == "system":
            converted.append(msg)
            continue

        if role == "assistant":
            if content is None:
                content = ""
            elif not isinstance(content, str):
                content = "".join(
                    part["text"] for part in content if part["type"] == "text"
                )
            converted.append({"role": role, "content": content})
            continue

        if isinstance(content, str):
            # No structured image parts: keep the text (and any tags the
            # caller already wrote) exactly as given.
            converted.append({"role": role, "content": content})
            continue

        parts = []
        imgs_num = 0
        txt = ""
        for item in content:
            if item["type"] == "image_url":
                parts.append("<image>")
                imgs_num += 1
            elif item["type"] == "text":
                parts.append(item["text"] + "\n")
                txt = item["text"]

        if txt.count("<image>") == imgs_num:
            # Text already carries its own image tags: do not convert them.
            rendered = txt
        else:
            rendered = convert_image_tags("".join(parts))
        converted.append({"role": role, "content": rendered})

    return (
        tokenizer.apply_chat_template(converted, tokenize=False)
        + "<|im_start|>assistant\n"
    )
def get_history_images(self):
    """Return every ``image_url`` value found in the history, oldest first.

    Messages whose content is a plain string are skipped.
    """
    return [
        part["image_url"]
        for msg in self.history_msgs
        if not isinstance(msg["content"], str)
        for part in msg["content"]
        if part["type"] == "image_url"
    ]
@staticmethod
def msg_has_img(msg):
if isinstance(msg["content"], list):
return any(
[
c["type"] == "image_url" and c["image_url"] is not None
for c in msg["content"]
]
)
return False
def remove_history_images(self):
    """Rebuild the history with all ``image_url`` parts stripped out.

    Messages that carry an image are replaced by shallow copies whose content
    omits the image parts; all other messages are kept as the same objects.
    Message order is preserved.
    """
    cleaned = []
    for msg in self.history_msgs:
        if not self.msg_has_img(msg):
            cleaned.append(msg)
            continue
        without_images = msg.copy()
        without_images["content"] = [
            part for part in msg["content"] if part["type"] != "image_url"
        ]
        cleaned.append(without_images)
    self.history_msgs = cleaned
def get_images_history_or_none(self):
    """Return the history's ``image_url`` values, or None when there are none.

    Only messages whose content is a list are inspected.
    """
    collected = [
        part["image_url"]
        for msg in self.history_msgs
        if isinstance(msg["content"], list)
        for part in msg["content"]
        if part["type"] == "image_url"
    ]
    return collected or None
def generate(
    self,
    prompt,
    images,
    stream=True,
    max_size=700,
    verbose=False,
    prevent_more_image=True,
    keep_history=True,
):
    """Answer ``prompt`` (optionally about ``images``) and return the reply.

    Args:
        prompt: User text for this turn.
        images: Image reference(s) understood by ``load_multi_images_maybe``,
            or None for a text-only turn.
        stream: Forwarded to ``generate_response``.
        max_size: Written to ``image_processor.size["longest_edge"]`` before
            new images are preprocessed.
        verbose: Log pixel shapes and print the history and final prompt.
        prevent_more_image: With history enabled, drop images of earlier
            turns when new images arrive so only the new ones are used.
        keep_history: Keep the conversation across calls. When false and
            images are given, the history is reset to the system message.

    Returns:
        The response returned by ``generate_response``.
    """

    def _to_pixels(imgs):
        # One pixel tensor per image, on the model's device and dtype.
        return [
            self.image_processor.preprocess(
                img,
                return_tensors="pt",
            )["pixel_values"]
            .to(self.model.device)
            .to(self.model.dtype)
            for img in imgs
        ]

    if images is not None:
        crt_images = load_multi_images_maybe(images)
        if keep_history:
            if prevent_more_image:
                # will delete previous all images.
                self.remove_history_images()
                images_in = crt_images
            else:
                logger.warning(
                    "you have set prevent_more_image=False, current more can not handle history have many images, the result would be wrose."
                )
                # History images come first: the prompt is rendered in
                # chronological order and this turn is appended last, so the
                # pixel list has to follow the same order as the image tags.
                images_in = self.get_history_images() + crt_images
        else:
            self.history_msgs = [self.default_sys]
            images_in = crt_images

        self.image_processor.size["longest_edge"] = max_size
        pixel_values = _to_pixels(images_in)
        user_content = [
            {"type": "image_url", "image_url": img} for img in crt_images
        ] + [{"type": "text", "text": prompt}]
    else:
        # Text-only turn: reuse images still present in the history. There
        # may be none, in which case no pixels are passed at all.
        his_images = self.get_images_history_or_none() if keep_history else None
        pixel_values = _to_pixels(his_images) if his_images else None
        user_content = [{"type": "text", "text": prompt}]

    self.history_msgs.append({"role": "user", "content": user_content})

    if verbose and pixel_values is not None:
        logger.info(f"pixel_values: {[t.shape for t in pixel_values]}")

    if keep_history and len(self.history_msgs) > 6:
        # remove on first pair from history (indices 1 and 2; index 0 is the
        # system message).
        # NOTE(review): if that pair held the image, pixel_values computed
        # above still contains it while the prompt no longer does - confirm.
        self.history_msgs = [
            msg for i, msg in enumerate(self.history_msgs) if i != 1 and i != 2
        ]

    if verbose and len(self.history_msgs) > 0:
        print(self.history_msgs)

    input_templated = self.build_chat_prompt(self.history_msgs, self.tokenizer)
    if verbose:
        print(input_templated)

    response = self.generate_response(
        self.model,
        self.tokenizer,
        pixel_values,
        input_templated,
        stream=stream,
    )
    if keep_history:
        self.history_msgs.append(
            {"role": "assistant", "content": response},
        )
    return response
def generate_response(
    self, model, tokenizer, pixels, prompt, stream=True, return_generator=False
):
    """Run generation for an already-templated ``prompt``.

    Args:
        model: Model whose ``generate`` is called.
        tokenizer: Tokenizer used for encoding, decoding and streaming.
        pixels: Pixel tensors for the prompt's images, or None.
        prompt: Prompt text.
        stream: Print tokens to stdout while generating (ignored when
            ``return_generator`` is true).
        return_generator: Run generation in a background thread and return
            an iterator over text chunks instead of the final text.

    Returns:
        A generator of text chunks when ``return_generator`` is true,
        otherwise the decoded, stripped response string.
    """
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )

    if return_generator:
        # skip_prompt matches the stdout streamer so that prompt tokens are
        # never echoed back to the consumer of the stream.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
    elif stream:
        streamer = TextStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
    else:
        streamer = None

    gen_args = {
        "pixel_values": pixels,
        "input_ids": input_ids,
        "max_new_tokens": 460,
        "do_sample": False,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
        "streamer": streamer,
    }

    if return_generator:
        # Generation blocks until finished, so it runs in a worker thread
        # while the caller drains the streamer.
        # NOTE(review): unlike the blocking path below, this path does not
        # run under torch.autocast - confirm whether that is intended.
        thread = threading.Thread(
            target=model.generate,
            kwargs=gen_args,
        )
        thread.start()
        return (new_text for new_text in streamer)

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        output_ids = model.generate(**gen_args)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
def chat_with_request(
    self, messages, stream=True, prevent_more_image=True, verbose=False
):
    """Answer an OpenAI-style ``messages`` list without touching the history.

    Args:
        messages: Chat messages; image parts are dicts of the form
            ``{"type": "image_url", "image_url": {"url": ...}}``.
        stream: Return a generator of text chunks instead of the full text.
        prevent_more_image: Keep only the images of the most recent
            image-bearing message and strip image parts from earlier ones.
        verbose: Print the rendered prompt and loaded images when images
            are present.

    Returns:
        A generator of text chunks when ``stream`` is true, otherwise the
        response string.

    Raises:
        ValueError: If ``prevent_more_image`` is set and the retained
            message carries more than one image.
    """
    has_img = [self.msg_has_img(msg) for msg in messages]
    last_img_pos = max((i for i, flag in enumerate(has_img) if flag), default=None)

    # Walk the messages in chronological order so `images` lines up with the
    # order in which the image tags appear in the rendered prompt.
    messages_new = []
    images = []
    for pos, msg in enumerate(messages):
        if not has_img[pos]:
            messages_new.append(msg)
        elif prevent_more_image and pos != last_img_pos:
            stripped = msg.copy()
            stripped["content"] = [
                itm for itm in msg["content"] if itm["type"] != "image_url"
            ]
            messages_new.append(stripped)
        else:
            for itm in msg["content"]:
                if itm["type"] == "image_url":
                    images.append(url_to_image(itm["image_url"]["url"]))
            messages_new.append(msg)

    if prevent_more_image and len(images) > 1:
        # Raised explicitly (not asserted) so the check survives `python -O`.
        raise ValueError("if prevent more image, images at each iter should be 1.")

    if len(images) > 0:
        pixel_values = [
            self.image_processor.preprocess(
                img,
                return_tensors="pt",
            )["pixel_values"]
            .to(self.model.device)
            .to(self.model.dtype)
            for img in images
        ]
    else:
        pixel_values = None

    input_templated = self.build_chat_prompt(messages_new, self.tokenizer)
    if verbose and pixel_values is not None:
        print(input_templated)
        print(images)

    if stream:
        return self.generate_response(
            self.model,
            self.tokenizer,
            pixel_values,
            input_templated,
            return_generator=True,
        )
    return self.generate_response(
        self.model, self.tokenizer, pixel_values, input_templated, stream=False
    )
def stream_chat_with_request(self, messages):
    """Yield the text chunks produced by ``chat_with_request`` in stream mode."""
    yield from self.chat_with_request(messages, stream=True)
+322
View File
@@ -0,0 +1,322 @@
import json
import time
import uuid
import torch
import uvicorn
import os
from uvicorn.config import Config
from uvicorn import Server
from pydantic import BaseModel, Field
from fastapi import BackgroundTasks, FastAPI, HTTPException, Request, Response
from contextlib import asynccontextmanager
from starlette.responses import StreamingResponse
from typing import Any, Dict
from coreai.serve.api_schema import (
DeltaMessage,
ChatCompletionRequest,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionResponse,
ChatMessage,
ModelCard,
ModelList,
ModelPermission,
CompletionUsage,
)
import requests
from PIL import Image
import base64
from io import BytesIO
from loguru import logger
import uuid
# Sentinel object; not referenced anywhere else in the visible part of this module.
_TEXT_COMPLETION_CMD = object()
# Model wrapper shared by the request handlers; assigned in start_server().
global_model = None
# The names below are declared `global` in create_chat_completion but are not
# otherwise read or written in the visible code.
source_prefix = "You are a helpful assistant."
local_doc_qa = None
conv = None
debug_mode = False
@asynccontextmanager
async def lifespan(app: FastAPI):  # collects GPU memory
    """Application lifespan hook: release cached CUDA memory after the yield."""
    yield
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)
def load_image(image_file):
    """Load an RGB PIL image from an HTTP(S) URL or base64-encoded data.

    Args:
        image_file: Either an ``http://`` / ``https://`` URL, a
            ``data:...;base64,...`` URI, or a bare base64 string.

    Returns:
        The decoded image converted to RGB.

    Raises:
        requests.HTTPError: If the download returns an error status.
    """
    # Match the full scheme: a bare "http" prefix is also a valid start of a
    # base64 payload, whereas ":" never occurs in base64.
    if image_file.startswith(("http://", "https://")):
        # A timeout keeps a stalled remote host from blocking the request.
        response = requests.get(image_file, timeout=30)
        response.raise_for_status()
        raw = response.content
    else:
        if image_file.startswith("data:"):
            # Data URI: the payload follows the first comma.
            image_file = image_file.split(",", 1)[-1]
        raw = base64.b64decode(image_file)
    return Image.open(BytesIO(raw)).convert("RGB")
def add_extra_stop_words(stop_words):
    """Return ``stop_words`` plus variants with leading newlines removed.

    A falsy input is returned unchanged. Otherwise a new list is built: the
    original words first, then each non-empty newline-stripped variant that
    is not already present.
    """
    if not stop_words:
        return stop_words
    expanded = list(stop_words)
    for word in stop_words:
        bare = word.lstrip("\n")
        if bare and bare not in expanded:
            expanded.append(bare)
    return expanded
def trim_stop_words(response, stop_words):
    """Cut ``response`` at the first occurrence of each stop word, in order.

    Each stop word is searched in the text left over from the previous cuts.
    A falsy ``stop_words`` leaves the response untouched.
    """
    if not stop_words:
        return response
    for marker in stop_words:
        cut = response.find(marker)
        if cut >= 0:
            response = response[:cut]
    return response
async def text_complete_last_message_vllm(
    history, stop_words_ids, gen_kwargs, tokenizer, model, request_id
):
    """Continue the last assistant message of ``history`` and return the text.

    ``history`` is a sequence of ``(query, response)`` pairs rendered in
    ChatML form after a fixed system prompt. The closing ``<|im_end|>`` of
    the rendered prompt is removed so that the model continues from it. The
    text of the last result yielded by ``model.generate`` is returned, cut at
    ``<|endoftext|>`` / ``<|im_end|>``.
    """
    im_start, im_end = "<|im_start|>", "<|im_end|>"

    segments = [f"{im_start}system\nYou are a helpful assistant.{im_end}"]
    for query, response in history:
        query = query.lstrip("\n").rstrip()
        response = response.lstrip("\n").rstrip()
        segments.append(f"\n{im_start}user\n{query}{im_end}")
        segments.append(f"\n{im_start}assistant\n{response}{im_end}")
    prompt = "".join(segments)[: -len(im_end)]

    # Stop-id list: <|im_end|> first, then any caller-supplied ids. It is
    # assembled here but not forwarded to model.generate below.
    merged_stop_ids = [tokenizer.encode(im_end)]
    if stop_words_ids:
        merged_stop_ids.extend(stop_words_ids)
    stop_words_ids = merged_stop_ids

    output = ""
    async for request_output in model.generate(prompt, gen_kwargs, request_id):
        _echoed_prompt = request_output.prompt
        output = request_output.outputs[-1].text
    # assert output.startswith(prompt)
    # output = output[len(prompt) :]
    output = trim_stop_words(output, ["<|endoftext|>", im_end])
    # print(f"<completion>\n{prompt}\n<!-- *** -->\n{output}\n</completion>")
    return output
@app.get("/")
async def root():
    """Landing endpoint returning a static greeting payload."""
    greeting = "Hello World, did you using frp get your local service out?"
    return {"message": greeting}
@app.get("/v1/models")
async def show_available_models():
    """Return the fixed list of advertised model ids, sorted by name."""
    # global model_name_runing
    model_ids = sorted(["namo-500m", "namo-700m", "gpt4o", "o1", "internvl2-8b"])
    cards = [
        ModelCard(id=model_id, root=model_id, permission=[ModelPermission()])
        for model_id in model_ids
    ]
    return ModelList(data=cards)
def get_response_msg_auto_stream(msg, model_id, stream=False):
    """Wrap a fixed text ``msg`` as a chat-completion reply.

    Args:
        msg: Assistant text to send.
        model_id: Value reported in the ``model`` field.
        stream: When true, return a server-sent-event ``StreamingResponse``
            built from ``get_info_msg_stream``; otherwise a complete
            ``ChatCompletionResponse``.
    """
    if stream:
        gen = get_info_msg_stream(msg, model_id=model_id)
        return StreamingResponse(gen, media_type="text/event-stream")

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=msg),
        finish_reason="stop",
    )
    return ChatCompletionResponse(
        id=str(uuid.uuid4()),
        # Unix time in seconds, matching create_chat_completion (the value
        # used to be in milliseconds here).
        created=int(time.time()),
        model=model_id,
        choices=[choice_data],
        object="chat.completion",
    )
async def get_info_msg_stream(content: str, model_id: str):
    """Yield a two-event SSE stream carrying a fixed text message.

    The first event delivers ``content`` as a delta; the second is an empty
    delta with ``finish_reason="stop"``. Both events share one completion id
    and one ``created`` timestamp (Unix seconds) so that clients can
    correlate them as a single completion.
    """
    stream_id = str(uuid.uuid4())
    created = int(time.time())

    def _event(delta, finish_reason):
        chunk = ChatCompletionResponse(
            id=stream_id,
            created=created,
            model=model_id,
            choices=[
                ChatCompletionResponseStreamChoice(
                    index=0, delta=delta, finish_reason=finish_reason
                )
            ],
            object="chat.completion.chunk",
        )
        return "data: {}\n\n".format(chunk.model_dump_json(exclude_unset=True))

    yield _event(DeltaMessage(content=content), None)
    yield _event(DeltaMessage(), "stop")
def _map_content(content):
    """Split message ``content`` into ``(text, image)``.

    Args:
        content: Either a plain string, or a list of parts exposing
            ``.type`` and, depending on the type, ``.text`` or
            ``.image_url.url``.

    Returns:
        ``(text, None)`` when there is no image part. Otherwise
        ``("<image> " + text, image)`` where ``image`` is the first image
        part loaded through ``load_image``. ``text`` is the first text part,
        or ``""`` when there is none.
    """
    if isinstance(content, str):
        return content, None
    if not content:
        return "", None

    # Defaults keep a missing text/image part from raising StopIteration.
    text = next((itm.text for itm in content if itm.type == "text"), "")
    img_url = next(
        (itm.image_url.url for itm in content if itm.type == "image_url"), None
    )
    if img_url is None:
        # Return the text itself rather than the part object, so the first
        # tuple element is a string in every branch.
        return text, None
    return "<image> " + text, load_image(img_url)
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    """Handle a chat-completion request.

    Streams server-sent events when ``request.stream`` is set, otherwise
    returns one complete chat-completion payload. Token usage is not
    measured and is reported as zero.

    Raises:
        HTTPException: 503 when no model has been loaded yet.
    """
    if global_model is None:
        # Without this check the request fails later with an AttributeError
        # surfaced as an opaque 500.
        raise HTTPException(status_code=503, detail="model is not loaded")

    t_id = int(time.time())
    r_id = f"chatcmpl-{t_id}"

    if request.stream:
        response = stream_response(request, gen_kwargs=None)
        return StreamingResponse(
            response,
            media_type="text/event-stream",
        )

    text = global_model.chat_with_request(
        request.model_dump()["messages"], stream=False
    )
    vis_chat_resp = {
        "id": r_id,
        "object": "chat.completion",  # chat.completions.chunk for stream
        "created": t_id,
        # "model": global_model.model_name,
        "model": "namo",
        "system_fingerprint": "fp_111111111",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": text,
                },
                "logprobs": None,
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        },
    }
    logger.debug(f"Response: {vis_chat_resp}")
    return vis_chat_resp
def stream_response(
    request,
    gen_kwargs: Dict[str, Any],
):
    """Yield server-sent events for a streaming chat completion.

    One ``chat.completion.chunk`` event is emitted per text piece produced
    by the model, followed by a final event with ``finish_reason="stop"``
    and a usage estimate. All events of one call share the same id and
    ``created`` timestamp (Unix seconds).

    Args:
        request: Request object exposing ``model_dump()["messages"]``.
        gen_kwargs: Currently unused.
    """
    # prompt_txt_num = len(gen_kwargs["inputs"])
    prompt_txt_num = 10  # placeholder: the prompt length is not measured
    stream_id = str(uuid.uuid4())
    created = int(time.time())
    completion_chars = 0

    response_generator = global_model.stream_chat_with_request(
        request.model_dump()["messages"]
    )
    for new_text in response_generator:
        if new_text is None:
            continue
        chunk = ChatCompletionResponse(
            id=stream_id,
            created=created,
            model="namo",
            choices=[
                ChatCompletionResponseStreamChoice(
                    index=0,
                    delta=DeltaMessage(role="assistant", content=new_text),
                    finish_reason=None,
                )
            ],
            object="chat.completion.chunk",
        )
        completion_chars += len(new_text)
        yield "data: {}\n\n".format(chunk.model_dump_json())

    # Rough token estimate derived from character counts.
    prompt_txt_num *= 1.33
    completion_txt_num = completion_chars * 1.33

    chunk = ChatCompletionResponse(
        id=stream_id,
        created=created,
        model="namo",
        choices=[
            ChatCompletionResponseStreamChoice(
                index=0,
                delta=DeltaMessage(role="assistant", content=""),
                finish_reason="stop",
            )
        ],
        object="chat.completion.chunk",
        usage=CompletionUsage(
            prompt_tokens=int(prompt_txt_num),
            completion_tokens=int(completion_txt_num),
            total_tokens=int(completion_txt_num + prompt_txt_num),
        ),
    )
    yield "data: {}\n\n".format(chunk.model_dump_json())
def start_server(model="namo", ip="127.0.0.1", port=8080):
    """Load the requested model and serve the API with uvicorn.

    Args:
        model: A local checkpoint path, or one of the short names
            ``"minicpm"``, ``"internvl2"``, ``"qwen2vl"``. Any other name,
            and ``None``, selects the default Namo checkpoint.
        ip: Host address to bind.
        port: Port to listen on.

    Raises:
        ValueError: If the resolved path matches no known model family.
    """
    global global_model

    if model is None:
        # The CLI passes None when --model is omitted; os.path.exists(None)
        # would raise a TypeError.
        model = "namo"

    if os.path.exists(model):
        model_path = model
    else:
        known_checkpoints = {
            "minicpm": "checkpoints/minicpm_v2_6",
            "internvl2": "checkpoints/internvl2-8b/",
            "qwen2vl": "checkpoints/qwen2-vl-7b/",
        }
        model_path = known_checkpoints.get(model, "checkpoints/Namo-500M-V1/")

    if (
        "internvl" in model_path
        or "minicpm" in model_path
        or "qwen2-vl" in model_path
    ):
        logger.warning("not supported for now")
    elif "namo" in model_path.lower():
        from namo.api.namo import NamoVL

        global_model = NamoVL(model_path=model_path, device="auto")
        logger.success("namo model initiated!")
    else:
        # The exception used to be constructed but never raised.
        raise ValueError(f"unsupported model: {model_path}")

    uvicorn.run(app, host=ip, port=port, workers=1)
+93
View File
@@ -0,0 +1,93 @@
from transformers import (
AutoTokenizer,
AutoProcessor,
)
try:
from qwen_vl_utils import process_vision_info
from transformers.models.qwen2_5_vl import Qwen2_5_VLModel
from transformers import Qwen2_5_VLForConditionalGeneration
except ImportError as e:
pass
from namo.api.base import VLBase
from loguru import logger
class Qwen2_5_VL(VLBase):
    """Qwen2.5-VL backend for the shared vision-language interface."""

    def __init__(self, model_path=None, processor_path=None, device="auto"):
        super().__init__(model_path, processor_path, device)
        # default: Load the model on the available device(s)

    def load_model(self, model_path):
        """Load the Qwen2.5-VL model; defaults to the local 3B checkpoint."""
        if model_path is None:
            model_path = "checkpoints/Qwen2.5-VL-3B-Instruct"
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path, torch_dtype="auto", device_map="auto"
        )
        model.to(self.device)
        logger.info(f"model loaded from: {model_path}")
        return model

    def load_processor(self, processor_path):
        """Load the processor; defaults to the local 3B checkpoint."""
        if processor_path is None:
            processor_path = "checkpoints/Qwen2.5-VL-3B-Instruct"
        processor = AutoProcessor.from_pretrained(processor_path)
        return processor

    def get_msg(self, text, image=None):
        """Build one user message with ``text`` and an optional ``image``."""
        content = [{"type": "text", "text": text}]
        if image is not None:
            content.insert(0, {"type": "image", "image": image})
        return {"role": "user", "content": content}

    def generate(
        self,
        prompt,
        images,
        stream=True,
        max_size=700,
        verbose=False,
        prevent_more_image=True,
    ):
        """Answer ``prompt`` about ``images`` in a single-turn conversation.

        ``stream``, ``max_size``, ``verbose`` and ``prevent_more_image`` are
        accepted for interface compatibility and are not used. Generation is
        capped at 128 new tokens.

        Returns:
            The decoded answer text. The decoded batch is also printed.
        """
        msg = self.get_msg(prompt, images)
        messages = [msg]

        # Preparation for inference
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.device)

        # Inference: Generation of the output
        generated_ids = self.model.generate(**inputs, max_new_tokens=128)
        # Drop the prompt tokens so only newly generated ids are decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        print(output_text)
        # Hand the answer back as well; it used to be printed only.
        return output_text[0] if output_text else ""
+22
View File
@@ -0,0 +1,22 @@
"""
A unified API interface support various VL models
"""
from namo.api.qwen2_5_vl import Qwen2_5_VL
from .namo import NamoVL
class VLInfer:
    """Thin facade that selects a vision-language backend by ``model_type``."""

    def __init__(self, model_type="qwen2.5-vl", device="auto"):
        """Instantiate the backend matching ``model_type`` (case-insensitive).

        Raises:
            ValueError: If ``model_type`` names no known backend.
        """
        kind = model_type.lower()
        if "qwen2.5-vl" in kind:
            # Forward the device so both backends honour the caller's choice.
            self.model = Qwen2_5_VL(device=device)
        elif "namo" in kind:
            self.model = NamoVL(device=device)
        else:
            raise ValueError(f"Unknown model type: {model_type}")

    def generate(self, prompt, images, verbose=False):
        """Delegate to the backend's ``generate`` and return its result."""
        return self.model.generate(prompt, images, verbose=verbose)
+52
View File
@@ -0,0 +1,52 @@
import argparse
from namo.api.openai import start_server
def handle_chat(args):
    """Handler for the ``chat`` subcommand; only prints a placeholder notice."""
    banner = f"Starting chat with model: {args.model}"
    notice = "Chat functionality is under development."
    print(banner, notice, sep="\n")
def handle_server(args):
    """Handler for the ``server`` subcommand: announce and start the server."""
    print("Starting the server...")
    server_options = {"ip": args.ip, "port": args.port, "model": args.model}
    start_server(**server_options)
def main():
    """Entry point of the Namo CLI.

    Defines the ``chat`` and ``server`` subcommands and dispatches to the
    handler registered for the chosen one; prints the help text when no
    subcommand is given.
    """
    parser = argparse.ArgumentParser(
        description="Namo CLI: A tool for chat and server management."
    )
    subparsers = parser.add_subparsers(dest="command", help="Available subcommands")

    chat_parser = subparsers.add_parser("chat", help="Start a chat session")
    chat_parser.add_argument(
        "--model",
        type=str,
        default="namo",
        help="Type of model or model local path.",
    )
    chat_parser.set_defaults(func=handle_chat)

    server_parser = subparsers.add_parser("server", help="Start the server")
    server_parser.add_argument(
        "--ip",
        type=str,
        default="127.0.0.1",
        help="The IP address to bind the server to",
    )
    server_parser.add_argument(
        "--port", type=int, default=8000, help="The port to run the server on"
    )
    # Default to "namo": with no default argparse yields None, which
    # start_server passes to os.path.exists and crashes on.
    server_parser.add_argument(
        "--model",
        type=str,
        default="namo",
        help="Type of model or model local path.",
    )
    server_parser.set_defaults(func=handle_server)

    args = parser.parse_args()
    if hasattr(args, "func"):
        args.func(args)
    else:
        parser.print_help()


if __name__ == "__main__":
    main()
+78
View File
@@ -0,0 +1,78 @@
from dataclasses import dataclass, field
from typing import Optional, List
import transformers
@dataclass
class ModelArguments:
    """Model-related options: checkpoint paths and multimodal switches."""

    # Checkpoint locations.
    llm_model_path: Optional[str] = "checkpoints/Qwen2.5-0.5B-Instruct"
    ve_model_path: Optional[str] = "checkpoints/aimv2-large-patch14-224"
    ae_model_path: Optional[str] = field(default=None)
    pretrain_model_path: Optional[str] = field(default=None)
    version: Optional[str] = "v0"
    freeze_backbone: bool = False
    tune_conn_ve_llm: bool = False
    mm_vision_select_layer: Optional[int] = -1
    pretrain_conn_ve_llm_path: Optional[str] = None
    pretrain_stage_1_5: Optional[str] = None
    conn_ve_llm_type: Optional[str] = "linear"
    mm_use_im_start_end: bool = False
    mm_use_im_patch_token: bool = True
    mm_patch_merge_type: Optional[str] = "flat"
    mm_vision_select_feature: Optional[str] = "patch"
    new_img_size: Optional[int] = None
    max_img_size: Optional[int] = None
    normalized_before_model: Optional[bool] = True
    unfreeze_ve: bool = False
    unfreeze_ve_layer_index: Optional[int] = None
    s2: bool = False
    s2_scales: Optional[str] = "384,768"
    s2_max_split_size: int = 384
@dataclass
class DataArguments:
    """Data-related options: dataset location and image/video handling."""

    data_path: Optional[List[str]] = field(
        default=None, metadata={"help": "Path to the training data."}
    )
    lazy_preprocess: bool = field(default=False)
    is_multimodal: bool = field(default=False)
    image_folder: Optional[str] = None
    image_aspect_ratio: str = field(default="square")
    video_frames_num: int = 16
    video_fps: Optional[int] = 1
    dynamic_size: bool = field(default=False)
    native_size: bool = field(default=False)
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` extended with project options.

    Adds quantization, LoRA and per-module learning-rate settings, and
    changes the defaults of ``optim`` and ``remove_unused_columns``.
    """

    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    remove_unused_columns: bool = False
    freeze_conn_ve_llm: bool = False
    mpt_attn_impl: Optional[str] = "triton"
    model_max_length: int = 512

    # Quantization.
    double_quant: bool = field(
        default=True,
        metadata={
            "help": "Compress the quantization statistics through double quantization."
        },
    )
    quant_type: str = field(
        default="nf4",
        metadata={
            "help": "Quantization data type to use. Should be one of `fp4` or `nf4`."
        },
    )
    bits: int = field(default=16, metadata={"help": "How many bits to use."})

    # LoRA.
    lora_enable: bool = field(default=False)
    lora_r: int = field(default=64)
    lora_alpha: int = field(default=16)
    lora_dropout: float = field(default=0.05)
    lora_weight_path: str = field(default="")
    lora_bias: str = field(default="none")

    # Optional per-module learning rates.
    conn_ve_llm_lr: Optional[float] = field(default=None)
    ve_lr: Optional[float] = field(default=None)
    group_by_modality_length: bool = False
    use_dora: bool = field(default=False)
+530
View File
@@ -0,0 +1,530 @@
import copy
from dataclasses import dataclass
import json
import os
from typing import Dict, Sequence
from PIL import Image
import jsonlines
import random
import torch
from torch.utils.data import Dataset
import transformers
from namo.dataargs import DataArguments
from namo.models.symbols import (
DEFAULT_IM_END_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IMAGE_TOKEN,
DEFAULT_VIDEO_TOKEN,
IGNORE_INDEX,
)
from namo.utils.process_utils import (
convert_image_tags,
get_suitable_size_hw,
process_video_fixed_frames,
resize_pad_images_to_target,
)
from namo.utils.utils import rank0_print
from namo.utils import convs as conversation_lib
from namo.utils.process_template import *
def is_dynamic_size_input(keys):
    """Return True when *keys* contains both "shortest_edge" and "longest_edge"."""
    required = ("shortest_edge", "longest_edge")
    return all(name in keys for name in required)
def expand2square(pil_img, background_color):
    """Pad *pil_img* to a square canvas filled with *background_color*.

    A square input is returned unchanged (same object). Otherwise the image
    is pasted centred along its shorter axis onto a new canvas whose side
    equals the longer edge.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # One of the two offsets is always zero: only the shorter axis is padded.
    canvas.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return canvas
def get_suitable_size(images, longest_edge=800):
    """Return the largest edge found across *images*, capped at *longest_edge*.

    An entry may itself be a list, in which case only its first element is
    inspected. Each inspected object must expose a ``.size`` of (width, height).
    """
    largest = 0
    for entry in images:
        candidate = entry[0] if isinstance(entry, list) else entry
        width, height = candidate.size
        largest = max(largest, width, height)
    return min(largest, longest_edge)
def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict:
# Normalise the <image>/<video> placeholders inside conversation turns, in place,
# and return `sources`. When data_args.is_multimodal is false, `sources` is returned untouched.
# NOTE(review): each element of `sources` is iterated as a list of turn dicts with
# "from"/"value" keys, so the `Sequence[str]` / `Dict` annotations look inaccurate — confirm.
is_multimodal = data_args.is_multimodal
if not is_multimodal:
return sources
for source in sources:
for sentence in source:
if sentence["from"] != "gpt":
sentence["value"] = sentence["value"].lstrip("\n")
# possibly avoids an <image> or <video> token existing in the second or a later turn
if DEFAULT_IMAGE_TOKEN in sentence["value"]:
images_num = sentence["value"].count(DEFAULT_IMAGE_TOKEN)
if images_num == 1:
# single image: strip the token and re-insert it at the very start, followed by a newline
sentence["value"] = (
sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
)
sentence["value"] = (
DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
)
sentence["value"] = sentence["value"].strip()
if "mmtag" in conversation_lib.default_conversation.version:
sentence["value"] = sentence["value"].replace(
DEFAULT_IMAGE_TOKEN,
"<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>",
)
else:
# multiple images: the tokens keep their original positions and are
# rewritten by convert_image_tags, e.g.
# <image> <image> into: Image1<image>\nImage2<image>\n
sentence["value"] = convert_image_tags(sentence["value"])
# print(f'multi images {images_num} {sentence["value"]}')
elif DEFAULT_VIDEO_TOKEN in sentence["value"]:
# print(f'video {sentence}')
# replace <video> by video_frames_num copies of '<image> ', one per frame
sentence["value"] = (
sentence["value"].replace(DEFAULT_VIDEO_TOKEN, "").strip()
)
sentence["value"] = (
"video sequence frames in order:\n"
+ f"{DEFAULT_IMAGE_TOKEN} " * data_args.video_frames_num
+ "\n"
+ sentence["value"]
)
# make <image> <image> into: 1<image>\n2<image>\n
sentence["value"] = convert_image_tags(sentence["value"])
# optionally wrap every image token in the start/end marker tokens
# NOTE(review): mm_use_im_start_end is not declared on DataArguments in this file;
# presumably it is attached to data_args by the training script — verify.
replace_token = DEFAULT_IMAGE_TOKEN
if data_args.mm_use_im_start_end:
replace_token = (
DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
)
sentence["value"] = sentence["value"].replace(
DEFAULT_IMAGE_TOKEN, replace_token
)
return sources
def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
) -> Dict:
    """Tokenize *sources* with the routine matching the active conversation template.

    The routine is chosen from ``conversation_lib.default_conversation``:
    the PLAIN separator style takes precedence, otherwise the template
    ``version`` ("qwen", "llama3", "mistral", "gemma") decides.

    Args:
        sources: Conversations to tokenize.
        tokenizer: Tokenizer forwarded to the template-specific routine.
        has_image: Forwarded to every routine except the plain one.

    Returns:
        Whatever the selected template-specific routine returns.

    Raises:
        ValueError: If the active template is not handled. Previously this
            case fell through and returned ``None``, which only failed later
            with an unrelated-looking error at the call site.
    """
    conv = conversation_lib.default_conversation
    if conv.sep_style == conversation_lib.SeparatorStyle.PLAIN:
        return preprocess_plain(sources, tokenizer)

    version = conv.version
    # Names are resolved lazily, branch by branch, so a template routine that
    # is not provided by the star import only fails when it is selected.
    if version == "qwen":
        return preprocess_qwen(sources, tokenizer, has_image=has_image)
    if version == "llama3":
        return preprocess_llama3(sources, tokenizer, has_image=has_image)
    if version == "mistral":
        return preprocess_mistral(sources, tokenizer, has_image=has_image)
    if version == "gemma":
        return preprocess_gemma(sources, tokenizer, has_image=has_image)

    raise ValueError(
        f"Unsupported conversation template version {version!r}; expected the "
        "PLAIN separator style or one of: qwen, llama3, mistral, gemma."
    )
class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(
    self,
    data_path: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    data_args: DataArguments,
):
    """Load every annotation file and keep the records that have conversations.

    Args:
        data_path: Annotation files to load. Files ending in "jsonl" are read
            with ``jsonlines``; everything else is parsed as a JSON document.
            A single path given as a plain string is accepted as well.
        tokenizer: Tokenizer stored for use when samples are built.
        data_args: Data options stored for use when samples are built.

    Each kept record gets two extra keys: "id" (its position in the merged
    list before shuffling) and "ds" (the file's base name up to the first "."
    and with any "_train..." suffix removed). The merged list is shuffled.
    """
    super().__init__()

    # A bare string would otherwise be iterated character by character.
    if isinstance(data_path, str):
        data_path = [data_path]

    list_data_dict = []
    for data in data_path:
        if data.endswith("jsonl"):
            with jsonlines.open(data, mode="r") as reader:
                raw_data = list(reader)
        else:
            # Context manager so the handle is closed deterministically.
            with open(data, "r") as f:
                raw_data = json.load(f)

        ds_name = os.path.basename(data).split(".")[0].split("_train")[0]
        for record in raw_data:
            if "conversations" in record.keys():
                record["id"] = len(list_data_dict)
                record["ds"] = ds_name
                list_data_dict.append(record)

    rank0_print("Formatting inputs...Skip in lazy mode")
    self.tokenizer = tokenizer
    self.list_data_dict = list_data_dict
    random.shuffle(self.list_data_dict)
    self.data_args = data_args
def __len__(self):
return len(self.list_data_dict)
@property
def lengths(self):
length_list = []
for sample in self.list_data_dict:
img_tokens = 128 if "image" in sample else 0
length_list.append(
sum(len(conv["value"].split()) for conv in sample["conversations"])
+ img_tokens
)
return length_list
@property
def modality_lengths(self):
length_list = []
for sample in self.list_data_dict:
cur_len = sum(
len(conv["value"].split()) for conv in sample["conversations"]
)
cur_len = cur_len if "image" in sample or "video" in sample else -cur_len
length_list.append(cur_len)
return length_list
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
attempt, max_attempt = 0, 10
while attempt < max_attempt:
try:
# sample an item
data_dict = self._sample_item(i)
# if data_dict is not None:
break
except Exception as e:
attempt += 1
print(f"Error in loading {i}, retrying...")
import traceback
print(e)
traceback.print_exc()
i = random.randint(0, len(self.list_data_dict) - 1)
return data_dict
def _sample_item(self, i) -> Dict[str, torch.Tensor]:
# Build one training example for record `i`: load and optionally preprocess its
# image(s) or video frames, normalise the placeholder tokens, tokenize the
# conversation, and return a dict with "input_ids", "labels" and, where applicable, "image".
# Raises ValueError when a loaded image is too small (see is_valid_image below).
image = None
sources = self.list_data_dict[i]
if isinstance(i, int):
sources = [sources]
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
if "image" in sources[0]:
# image_file = self.list_data_dict[i]['image']
# image_folder = self.data_args.image_folder
# processor = self.data_args.image_processor
# image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
image_file = self.list_data_dict[i]["image"]
ds = self.list_data_dict[i]["ds"]
image_folder = self.data_args.image_folder
# NOTE(review): image_processor is not a declared DataArguments field in this file;
# presumably attached to data_args by the training script — verify.
processor = self.data_args.image_processor
# todo: consider multiple images input.
# For the dataset names matched below, image paths are joined directly onto image_folder.
if (
("llava" in ds and "llavar" not in ds and "llava_recap" not in ds)
or "sharegpt4v_instruct" in ds
or "sharegpt4v_" in ds
or "share-captioner" in ds
or "gemini" in ds
or "bunny_695k" in ds
or "allava_laion" in ds
or "multi_llava" in ds
or "Cambrian7M" in ds
or "c7s-" in ds
or "ureader_tr" in ds
):
if isinstance(image_file, list):
image = [
Image.open(os.path.join(image_folder, img_f)).convert("RGB")
for img_f in image_file
]
else:
image = Image.open(os.path.join(image_folder, image_file)).convert(
"RGB"
)
else:
# Every other dataset keeps its images in a sub-directory of image_folder;
# some dataset names are first mapped onto a different sub-directory.
if "llavar" in ds:
ds = "llavar"
elif "bunny" in ds and "bunny_695k" not in ds:
ds = "bunny_pretrain_laion_2m"
elif "qa_" in ds:
ds = "qa_data"
elif "sharegpt4o" in ds:
ds = "sharegpt4o/images"
elif "mathv360k_cot" in ds:
ds = "mathv360k_cot/images"
if isinstance(image_file, list):
image = [
Image.open(os.path.join(image_folder, ds, img_f)).convert("RGB")
for img_f in image_file
]
else:
image = Image.open(
os.path.join(image_folder, ds, image_file)
).convert("RGB")
# todo: checking image validness here.
def is_valid_image(img):
width, height = img.size
# both sides must be larger than 14 pixels
if width > 14 and height > 14:
return True
else:
return False
# Reject the whole sample as soon as one image is too small.
if isinstance(image, list):
for img in image:
if not is_valid_image(img):
rank0_print(f"Invalid image found, passing... {img.size}")
raise ValueError(f"Invalid image found {img.size}")
else:
if not is_valid_image(image):
rank0_print(f"Invalid image found, passing... {image.size}")
raise ValueError(f"Invalid image found {image.size}")
if self.data_args.image_aspect_ratio == "pad":
# Pad to a square (filled with the processor's mean colour) and preprocess here
# only when neither the processor size nor data_args requests dynamic sizing.
if (
not (
processor.size and is_dynamic_size_input(processor.size.keys())
)
and not self.data_args.dynamic_size
):
# for navit we dont need pad
if isinstance(image, list):
image = [
expand2square(
i, tuple(int(x * 255) for x in processor.image_mean)
)
for i in image
]
# only preprocess item by item when fixed sizes, otherwise do it in batch
image = processor.preprocess(image, return_tensors="pt")[
"pixel_values"
]
else:
image = expand2square(
image, tuple(int(x * 255) for x in processor.image_mean)
)
# only preprocess item by item when fixed sizes, otherwise do it in batch
image = processor.preprocess(image, return_tensors="pt")[
"pixel_values"
][0]
else:
# print(processor.size)
# can multiple images be handled here?
if (
not (
processor.size and is_dynamic_size_input(processor.size.keys())
)
and not self.data_args.dynamic_size
):
# only preprocess item by item when fixed sizes, otherwise do it in batch
if isinstance(image, list):
image = processor.preprocess(image, return_tensors="pt")[
"pixel_values"
]
else:
image = processor.preprocess(image, return_tensors="pt")[
"pixel_values"
][0]
# multi-image sample whose <image> token count differs from the number of
# image files: drop all tokens and prepend one token per file.
if isinstance(image_file, list) and sources[0]["conversations"][0][
"value"
].count(DEFAULT_IMAGE_TOKEN) != len(image_file):
a = sources[0]["conversations"][0]["value"].replace(
DEFAULT_IMAGE_TOKEN, ""
)
sources[0]["conversations"][0]["value"] = (
f"{DEFAULT_IMAGE_TOKEN} " * len(image_file) + a
)
elif (
isinstance(image_file, str)
and sources[0]["conversations"][0]["value"].count(DEFAULT_IMAGE_TOKEN)
> 1
):
# sometimes a single-image sample carries several <image> tags; keep only one
print(f"data single turn but got multiple <image>: {sources}")
a = sources[0]["conversations"][0]["value"].replace(
DEFAULT_IMAGE_TOKEN, ""
)
sources[0]["conversations"][0]["value"] = f"{DEFAULT_IMAGE_TOKEN} " + a
# deep copy so the placeholder rewriting in preprocess_multimodal does not touch the stored record
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]), self.data_args
)
elif "image" not in sources[0] and "video" in sources[0]:
# print("video sample")
# Video sample: the file is looked up under image_folder/<ds>/ and its frames
# are preprocessed together by the image processor.
video_file = self.list_data_dict[i]["video"]
ds = self.list_data_dict[i]["ds"]
image_folder = self.data_args.image_folder
video_file = os.path.join(image_folder, ds, video_file)
video = process_video_fixed_frames(
video_file, self.data_args.video_fps, self.data_args.video_frames_num
)
processor = self.data_args.image_processor
image = processor.preprocess(video, return_tensors="pt")["pixel_values"]
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]), self.data_args
)
else:
# Text-only sample: just copy the conversations.
sources = copy.deepcopy([e["conversations"] for e in sources])
# print(f'sources : {sources}')
data_dict = preprocess(
sources,
self.tokenizer,
has_image=(
"image" in self.list_data_dict[i] or "video" in self.list_data_dict[i]
),
)
if isinstance(i, int):
# unwrap the single-element batch produced by preprocess
data_dict = dict(
input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]
)
# image exist in the data
if "image" in self.list_data_dict[i] or "video" in self.list_data_dict[i]:
data_dict["image"] = image
elif self.data_args.is_multimodal:
# image does not exist in the data, but the model is multimodal:
# attach an all-zero tensor or a black 448x448 PIL image as a placeholder
# crop_size = self.data_args.image_processor.crop_size
# data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
if hasattr(self.data_args.image_processor, "crop_size"):
crop_size = self.data_args.image_processor.crop_size
if self.data_args.dynamic_size:
data_dict["image"] = Image.new("RGB", (448, 448), (0, 0, 0))
else:
data_dict["image"] = torch.zeros(
3, crop_size["height"], crop_size["width"]
)
else:
processor = self.data_args.image_processor
size = processor.size
if not (
processor.size and is_dynamic_size_input(processor.size.keys())
):
data_dict["image"] = torch.zeros(3, size["height"], size["width"])
else:
# fake for pure text in navit
data_dict["image"] = Image.new("RGB", (448, 448), (0, 0, 0))
# print(data_dict)
return data_dict
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
# tokenizer supplies pad_token_id and model_max_length; data_args supplies the
# image processor and the dynamic/native size switches read below.
tokenizer: transformers.PreTrainedTokenizer
data_args: DataArguments
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
# Pad input_ids / labels to a common length, truncate them to
# tokenizer.model_max_length, build the attention mask, and attach image
# data under "pixel_values" when the first instance carries an "image".
input_ids, labels = tuple(
[instance[key] for instance in instances] for key in ("input_ids", "labels")
)
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
# padded label positions use IGNORE_INDEX
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=IGNORE_INDEX
)
input_ids = input_ids[:, : self.tokenizer.model_max_length]
labels = labels[:, : self.tokenizer.model_max_length]
batch = dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
# Only the first instance is inspected to decide how images are batched.
if "image" in instances[0]:
if isinstance(instances[0]["image"], torch.Tensor):
# Already-preprocessed tensors: concatenate into one tensor when all
# trailing two dimensions agree, otherwise pass the list through as-is.
images = [instance["image"] for instance in instances]
if all(
x is not None and x.shape[-2:] == images[0].shape[-2:]
for x in images
):
# 3-D tensors get a leading dimension first; other ranks are concatenated directly
batch["pixel_values"] = torch.cat(
[i.unsqueeze(0) if len(i.shape) == 3 else i for i in images],
dim=0,
)
else:
batch["pixel_values"] = images
else:
# handle various sizes image inputs
# flatten per-instance image lists into one flat list of images
images = [instance["image"] for instance in instances]
images = [
item
for sublist in images
for item in (sublist if isinstance(sublist, list) else [sublist])
]
if self.data_args.dynamic_size:
# size = get_suitable_size(images)
if self.data_args.native_size:
# each image is preprocessed on its own; the result is a list
batch["pixel_values"] = [
self.data_args.image_processor.preprocess(
img,
return_tensors="pt",
)["pixel_values"]
for img in images
]
else:
# NOTE(review): longest_edge is not a declared DataArguments field in
# this file; presumably set on data_args elsewhere — verify.
size = get_suitable_size_hw(
images, longest_edge=self.data_args.longest_edge
)
images = resize_pad_images_to_target(images, size)
images_tensor = self.data_args.image_processor.preprocess(
images,
return_tensors="pt",
# size={"width": size, "height": size},
)
batch["pixel_values"] = images_tensor["pixel_values"]
else:
images_tensor = self.data_args.image_processor.preprocess(
images, return_tensors="pt"
)
batch["pixel_values"] = images_tensor["pixel_values"]
# NOTE(review): images_tensor is only assigned in the non-tensor, non-native-size
# paths above; the original nesting of this check is not recoverable here — confirm
# it cannot run on a path where images_tensor was never assigned.
if not isinstance(batch["pixel_values"], list):
if "pixel_attention_mask" in images_tensor:
batch["pixel_attention_mask"] = images_tensor[
"pixel_attention_mask"
][0]
# this goes to NaViT
# print('does it got pixeltteionamsk? ------------>')
return batch
def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer,
    data_args,
    model_args,
) -> Dict:
    """Make dataset and collator for supervised fine-tuning.

    Also selects the global default conversation template from
    ``model_args.version``, using "vicuna_v1" when that version is unknown.
    """
    templates = conversation_lib.conv_templates
    template_key = (
        model_args.version if model_args.version in templates else "vicuna_v1"
    )
    conversation_lib.default_conversation = templates[template_key]

    dataset = LazySupervisedDataset(
        tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args
    )
    collator = DataCollatorForSupervisedDataset(
        tokenizer=tokenizer, data_args=data_args
    )
    return {
        "train_dataset": dataset,
        "eval_dataset": None,
        "data_collator": collator,
    }
+745
View File
@@ -0,0 +1,745 @@
import io
import json
import os
import random
from copy import deepcopy

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from loguru import logger
from PIL import Image
from torch.utils.data import ConcatDataset, WeightedRandomSampler
from torch.utils.data import Dataset
from torchvision.transforms.functional import InterpolationMode

from namo.dataset import expand2square
from namo.models.symbols import IGNORE_INDEX
class WeightedConcatDataset(ConcatDataset):
    """ConcatDataset that also carries a weighted random sampler.

    Iterating the dataset yields the indices drawn by the sampler, which
    draws ``total_size`` indices with replacement according to *weights*.
    """

    def __init__(self, datasets, weights):
        super().__init__(datasets)
        self.weights = torch.DoubleTensor(weights)
        # Combined number of items over all member datasets.
        self.total_size = sum(map(len, datasets))
        self.sampler = WeightedRandomSampler(
            weights=self.weights,
            num_samples=self.total_size,
            replacement=True,
        )

    def __iter__(self):
        yield_from_sampler = iter(self.sampler)
        return yield_from_sampler

    def __len__(self):
        return self.total_size
def dpo_concat_pad_data_collator(features, pad_id=0):
    """Collate DPO features: pad chosen/rejected sequences and batch the rest.

    For both the "chosen_" and "rejected_" prefixes, every feature's
    ``input_ids`` is right-padded with *pad_id* and its ``labels`` with
    ``IGNORE_INDEX`` up to the longest ``input_ids`` in the batch, and an
    attention mask (``input_ids != pad_id``) is stored. The feature dicts are
    updated in place.

    The keys of the first feature then decide what is batched:
    "pixel_values" and "image_flags" are concatenated along the first axis,
    every other non-None, non-string value is stacked.

    Args:
        features: Non-empty list of per-sample dicts.
        pad_id: Token id used for padding ``input_ids``.

    Returns:
        Dict mapping each batched key to a tensor.
    """
    first = features[0]
    batch = {}

    for prefix in ("chosen_", "rejected_"):
        ids_key = f"{prefix}input_ids"
        labels_key = f"{prefix}labels"
        mask_key = f"{prefix}attention_mask"

        max_item_length = max(feat[ids_key].shape[0] for feat in features)
        for feat in features:
            padded_ids = torch.LongTensor([pad_id] * max_item_length)
            padded_ids[: feat[ids_key].shape[0]] = feat[ids_key]
            feat[ids_key] = padded_ids

            padded_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
            padded_labels[: feat[labels_key].shape[0]] = feat[labels_key]
            feat[labels_key] = padded_labels

            feat[mask_key] = feat[ids_key].ne(pad_id)

    # The first element is used to figure out which keys/values exist and are
    # not None for this model.
    concat_keys = ("pixel_values", "image_flags")
    for k, v in first.items():
        values = [f[k] for f in features]
        if k in concat_keys:
            if isinstance(v, np.ndarray):
                # torch.concat only accepts tensors, so join the arrays with
                # NumPy along the first axis and convert once.
                batch[k] = torch.tensor(np.concatenate(values))
            else:
                batch[k] = torch.concat(values)
        elif v is not None and not isinstance(v, str):
            if isinstance(v, torch.Tensor):
                batch[k] = torch.stack(values)
            elif isinstance(v, np.ndarray):
                batch[k] = torch.tensor(np.stack(values))
            else:
                batch[k] = torch.tensor(values)
    return batch
def simulate_jpeg_degradation(quality):
    """Return a function that re-encodes an image as JPEG at *quality*.

    The returned callable converts its input to RGB, saves it as JPEG into an
    in-memory buffer and returns the decoded copy.
    """

    def jpeg_degrade(img):
        buffer = io.BytesIO()
        try:
            rgb = img.convert("RGB")
            rgb.save(buffer, format="JPEG", quality=quality)
            buffer.seek(0)  # rewind before decoding what was just written
            # .copy() forces the pixels to load before the buffer is closed
            return Image.open(buffer).copy()
        finally:
            buffer.close()

    return jpeg_degrade
# JPEG quality levels 75..100 (inclusive) and one degradation function per level.
qualities = [*range(75, 101)]
jpeg_degrade_functions = dict(
    zip(qualities, map(simulate_jpeg_degradation, qualities))
)
def build_transform(is_train, input_size, pad2square=False, normalize_type="imagenet"):
    """Build the image transform pipeline.

    Every pipeline converts to RGB, resizes to ``(input_size, input_size)``
    with bicubic interpolation, converts to a tensor and normalizes. In
    between, training adds a randomly chosen JPEG re-encoding step, and
    evaluation adds square padding unless ``pad2square is False``.

    Raises:
        NotImplementedError: If *normalize_type* is not "imagenet", "clip"
            or "siglip".
    """
    if normalize_type == "imagenet":
        MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    elif normalize_type == "clip":
        MEAN, STD = CLIP_MEAN, CLIP_STD
    elif normalize_type == "siglip":
        MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
    else:
        raise NotImplementedError

    if is_train:  # use data augmentation
        extra_steps = [
            T.RandomChoice(
                [T.Lambda(jpeg_degrade_functions[quality]) for quality in qualities]
            )
        ]
    elif pad2square is False:  # now we use this transform function by default
        extra_steps = []
    else:
        extra_steps = [
            T.Lambda(
                lambda img: expand2square(img, tuple(int(x * 255) for x in MEAN))
            )
        ]

    to_rgb = T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img)
    resize = T.Resize(
        (input_size, input_size), interpolation=InterpolationMode.BICUBIC
    )
    steps = [to_rgb, *extra_steps, resize, T.ToTensor(), T.Normalize(mean=MEAN, std=STD)]
    return T.Compose(steps)
class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(
    self,
    template_name,
    meta,
    tokenizer,
    tcs_loader,
    ds_name,
    num_image_token,
    image_size=448,
    is_train=True,
    pad2square=False,
    group_by_length=False,
    dynamic_image_size=False,
    use_thumbnail=False,
    min_dynamic_patch=1,
    max_dynamic_patch=12,
    min_num_frame=8,  # for video data
    max_num_frame=32,  # for video data
    sampling_method="rand",  # for video data
    repeat_time=1,
    normalize_type="imagenet",
    random_seed=0,
):
    """Read the jsonl annotation file and store the dataset options.

    Args:
        template_name: Conversation template name used to pick the
            preprocessing function.
        meta: Dict with "annotation" (path to a jsonl file) and "root"
            (prefix for media paths).
        tokenizer: Tokenizer used for preprocessing and length estimation.
        tcs_loader: Optional loader used for "s3://" images and for videos.
        ds_name: Dataset name forwarded to the preprocessing function.
        num_image_token: Number of tokens one image patch expands to.
        repeat_time: Below 1, a random subset of that fraction is kept;
            above 1, the data is repeated ``int(repeat_time)`` times.
        random_seed: Seed for the shuffle of the raw lines.

    The remaining arguments are stored on the instance unchanged.

    When *group_by_length* is set, ``self.length`` holds one estimate per
    sample: the precomputed "length" field if present, otherwise the
    tokenized conversation length plus
    ``num_image_token * (max_dynamic_patch + use_thumbnail)``. Estimates are
    cached by conversation character length.
    """
    super().__init__()
    self.ds_name = ds_name
    self.tokenizer = tokenizer
    self.template_name = template_name
    self.num_image_token = num_image_token
    logger.info(f"[Dataset] num_image_token: {num_image_token}")
    logger.info(f"[Dataset] dynamic_image_size: {dynamic_image_size}")
    logger.info(f"[Dataset] use_thumbnail: {use_thumbnail}")
    logger.info(
        f"[Dataset] min_dynamic_patch: {min_dynamic_patch}, max_dynamic_patch: {max_dynamic_patch}"
    )

    self.image_size = image_size
    self.is_train = is_train
    self.pad2square = pad2square
    self.max_num_frame = max_num_frame
    self.min_num_frame = min_num_frame
    self.sampling_method = sampling_method

    logger.info("Formatting inputs...Skip in lazy mode")
    assert meta["annotation"].endswith(
        "jsonl"
    ), f'annotation must be jsonl, but got {meta["annotation"]}'

    with open(meta["annotation"], "r") as f:
        self.raw_data = f.readlines()
        if repeat_time < 1:
            # If repeat_time is less than 1, select a portion of the data
            self.raw_data = random.sample(
                self.raw_data, k=int(len(self.raw_data) * repeat_time)
            )
        if repeat_time > 1:
            # Repeat the list if repeat_time is greater than 1
            repeat_time = int(repeat_time)
            self.raw_data = self.raw_data * repeat_time

    self.rng = np.random.default_rng(seed=random_seed)
    self.rng.shuffle(self.raw_data)

    self.root = meta["root"]
    self.cached_data_dict = {}
    self.tcs_loader = tcs_loader
    self.group_by_length = group_by_length
    self.dynamic_image_size = dynamic_image_size
    self.use_thumbnail = use_thumbnail
    self.min_dynamic_patch = min_dynamic_patch
    self.max_dynamic_patch = max_dynamic_patch
    self.normalize_type = normalize_type

    # If the precomputed length does not exist, roughly estimate the length of
    # each sample to improve the efficiency of group_by_length.
    if self.group_by_length:
        # Cache keyed by character length to avoid re-tokenizing.
        self.conv2length = {}
        self.length = []
        for data_item in self.raw_data:
            data_item = json.loads(data_item)
            if "length" in data_item:
                # Use precomputed length if available
                token_length = data_item["length"]
            else:
                conversations = "\n".join(
                    [temp["value"] for temp in data_item["conversations"]]
                )
                str_length = len(conversations)
                if str_length not in self.conv2length:
                    text_tokens = tokenizer(
                        conversations,
                        return_tensors="pt",
                        padding=False,
                        truncation=False,
                    ).input_ids.size(1)
                    self.conv2length[str_length] = (
                        text_tokens
                        + num_image_token * (max_dynamic_patch + use_thumbnail)
                    )
                # Always read back the cached estimate so a cache miss and a
                # cache hit report the same value (image tokens included).
                token_length = self.conv2length[str_length]
            self.length.append(token_length)
def __len__(self):
return len(self.raw_data)
def get_preprocess_function(self):
    """Return the preprocessing function matching ``self.template_name``.

    Unrecognised template names fall back to the generic ``preprocess``.
    """
    name = self.template_name
    # Early returns keep name resolution lazy: only the selected function is looked up.
    if name == "Hermes-2":
        return preprocess_mpt
    if name == "internlm2-chat":
        return preprocess_internlm
    if name == "phi3-chat":
        return preprocess_phi3
    if name == "internvl2_5":
        return preprocess_internvl2_5
    return preprocess
def load_image(self, image_path):
# Load the image using tcs_loader if available, otherwise use PIL
if self.tcs_loader is not None and "s3://" in image_path:
return self.tcs_loader(image_path)
return Image.open(image_path).convert("RGB")
def get_image_path(self, image_path):
if image_path.startswith("s3://"): # for ceph
image_path = self.root + image_path
else: # for local image
image_path = os.path.join(self.root, image_path)
return image_path
def get_transform(self):
    """Build the image transform from the options stored on this dataset."""
    options = {
        "is_train": self.is_train,
        "input_size": self.image_size,
        "pad2square": self.pad2square,
        "normalize_type": self.normalize_type,
    }
    return build_transform(**options)
@staticmethod
def get_longest_common_prefix_index(tensor1, tensor2):
min_len = min(len(tensor1), len(tensor2))
for i in range(min_len):
if tensor1[i] != tensor2[i]:
return i
return min_len
def multi_modal_get_item(self, data_item):
    """Build a single-image DPO sample (chosen and rejected) from *data_item*.

    Prepends an "<image>" placeholder to the question when it has none
    (updating *data_item* in place), loads and transforms the image, and
    tokenizes the chosen and rejected conversations.
    """
    to_tensor = self.get_transform()

    # Ensure the question contains an image placeholder
    if "<image>" not in data_item["question"]:
        data_item["question"] = "<image>\n" + data_item["question"]

    # Resolve and load the image (tcs_loader if available, otherwise PIL)
    full_path = self.get_image_path(data_item["image"])
    picture = self.load_image(full_path)

    if self.dynamic_image_size:
        # Split the image into patches dynamically
        tiles = dynamic_preprocess(
            picture,
            min_num=self.min_dynamic_patch,
            max_num=self.max_dynamic_patch,
            image_size=self.image_size,
            use_thumbnail=self.use_thumbnail,
        )
    else:
        # The original image is the only patch
        tiles = [picture]

    pixel_values = torch.stack([to_tensor(tile) for tile in tiles])
    num_patches = pixel_values.size(0)
    if not self.dynamic_image_size:
        assert (
            num_patches == 1
        ), f"The number of patches should be 1, but got {num_patches}."

    preprocess_function = self.get_preprocess_function()

    # Same pipeline for both answers; "chosen" is processed before "rejected".
    ret = {}
    for side in ("chosen", "rejected"):
        dialogue = [
            {"from": "human", "value": data_item["question"]},
            {"from": "gpt", "value": data_item[side]},
        ]
        encoded = preprocess_function(
            self.template_name,
            [deepcopy(dialogue)],
            self.tokenizer,
            [self.num_image_token * num_patches],
            group_by_length=True,
            ds_name=self.ds_name,
        )
        ret[f"{side}_input_ids"] = encoded["input_ids"][0]
        ret[f"{side}_labels"] = encoded["labels"][0]
        ret[f"{side}_attention_mask"] = encoded["attention_mask"][0]

    ret["pixel_values"] = pixel_values
    ret["image_flags"] = torch.tensor([1] * num_patches, dtype=torch.long)
    return ret
def multi_modal_multi_image_get_item(self, data_item):
    """Build a multi-image DPO sample (chosen and rejected) from *data_item*.

    Every image listed in ``data_item["image"]`` is loaded; with dynamic
    image size the per-image patch budget is ``max_dynamic_patch`` divided
    by the number of images (at least 1).
    """
    to_tensor = self.get_transform()

    image_paths = data_item["image"]
    num_image = len(image_paths)
    tiles, num_tiles = [], []
    for relative_path in image_paths:
        # Resolve and load the image (tcs_loader if available, otherwise PIL)
        picture = self.load_image(self.get_image_path(relative_path))
        if self.dynamic_image_size:
            # Split this image into patches dynamically
            patches = dynamic_preprocess(
                picture,
                min_num=self.min_dynamic_patch,
                max_num=max(1, self.max_dynamic_patch // num_image),
                image_size=self.image_size,
                use_thumbnail=self.use_thumbnail,
            )
            tiles += patches
            num_tiles.append(len(patches))
        else:
            # The original image is a single patch
            tiles.append(picture)
            num_tiles.append(1)

    pixel_values = torch.stack([to_tensor(tile) for tile in tiles])
    num_patches = pixel_values.size(0)

    preprocess_function = self.get_preprocess_function()
    num_image_tokens = [self.num_image_token * count for count in num_tiles]

    # Same pipeline for both answers; "chosen" is processed before "rejected".
    ret = {}
    for side in ("chosen", "rejected"):
        dialogue = [
            {"from": "human", "value": data_item["question"]},
            {"from": "gpt", "value": data_item[side]},
        ]
        encoded = preprocess_function(
            self.template_name,
            [deepcopy(dialogue)],
            self.tokenizer,
            num_image_tokens,
            group_by_length=self.group_by_length,
            ds_name=self.ds_name,
            num_image=num_image,
        )
        ret[f"{side}_input_ids"] = encoded["input_ids"][0]
        ret[f"{side}_labels"] = encoded["labels"][0]
        ret[f"{side}_attention_mask"] = encoded["attention_mask"][0]

    ret["pixel_values"] = pixel_values
    ret["image_flags"] = torch.tensor([1] * num_patches, dtype=torch.long)
    return ret
def video_get_item(self, data_item):
    """Build a video DPO sample (chosen and rejected) from *data_item*.

    Prepends a "<video>" placeholder to the question when it has none, loads
    frames through ``self.tcs_loader``, replaces "<video>\\n" with one
    "Frame{k}: <image>" line per frame (updating *data_item* in place), and
    tokenizes the chosen and rejected conversations.

    ``use_packed_ds`` is forwarded to the preprocessing function only when
    the instance actually defines it; ``__init__`` does not set it, so
    reading it unconditionally raised AttributeError. The path join relies
    on the module-level ``import os``.
    """
    transform = self.get_transform()

    # Ensure the question contains a video placeholder
    if "<video>" not in data_item["question"]:
        data_item["question"] = "<video>\n" + data_item["question"]

    video_file = data_item["video"]
    video_path = os.path.join(self.root, video_file)

    # Load the video frames using tcs_loader
    # TODO: Load videos without using tcsloader.
    image_list = self.tcs_loader(
        video_path,
        image_type="video",
        max_num_frames=self.max_num_frame,
        min_num_frames=self.min_num_frame,
        sample=self.sampling_method,
        clip=data_item.get("clip", None),
    )

    # One "Frame{k}: <image>" marker per frame
    special_tokens = "\n".join(
        ["Frame{}: <image>".format(i + 1) for i in range(len(image_list))]
    )
    data_item["question"] = data_item["question"].replace(
        "<video>\n", special_tokens
    )

    # Transform each frame and stack them into a tensor
    pixel_values = torch.stack([transform(image) for image in image_list])
    num_patches = pixel_values.size(0)

    preprocess_function = self.get_preprocess_function()
    num_image_tokens = [self.num_image_token] * num_patches

    # Sibling methods never pass use_packed_ds; keep passing it here only if
    # something has set the attribute on this instance.
    extra_kwargs = {}
    if hasattr(self, "use_packed_ds"):
        extra_kwargs["use_packed_ds"] = self.use_packed_ds

    ret = {}
    for side in ("chosen", "rejected"):
        conversations = [
            {"from": "human", "value": data_item["question"]},
            {"from": "gpt", "value": data_item[side]},
        ]
        encoded = preprocess_function(
            self.template_name,
            [deepcopy(conversations)],
            self.tokenizer,
            num_image_tokens,
            group_by_length=True,
            ds_name=self.ds_name,
            num_image=num_patches,
            **extra_kwargs,
        )
        ret[f"{side}_input_ids"] = encoded["input_ids"][0]
        ret[f"{side}_labels"] = encoded["labels"][0]
        ret[f"{side}_attention_mask"] = encoded["attention_mask"][0]

    ret["pixel_values"] = pixel_values
    ret["image_flags"] = torch.tensor([1] * num_patches, dtype=torch.long)
    return ret
def pure_text_get_item(self, data_item):
    """Build a chosen/rejected training sample for a text-only record.

    A blank white 224x224 placeholder image is pushed through the same
    patching and transform pipeline as real images, and its patches are
    flagged with ``image_flags == 0``.

    Args:
        data_item: Mapping with the keys ``"question"``, ``"chosen"`` and
            ``"rejected"``.

    Returns:
        dict with the tokenized chosen/rejected conversations
        (``*_input_ids``, ``*_labels``, ``*_attention_mask``), the
        placeholder ``pixel_values`` and all-zero ``image_flags``.
    """
    to_tensor = self.get_transform()

    # Placeholder picture: text-only samples still carry exactly one patch.
    placeholder = Image.new("RGB", (224, 224), (255, 255, 255))
    tiles = dynamic_preprocess(
        placeholder,
        min_num=self.min_dynamic_patch,
        max_num=1,
        image_size=self.image_size,
        use_thumbnail=self.use_thumbnail,
    )
    pixel_values = torch.stack([to_tensor(tile) for tile in tiles])
    num_patches = pixel_values.size(0)
    assert (
        num_patches == 1
    ), f"The number of patches should be 1, but got {num_patches}."

    preprocess_function = self.get_preprocess_function()

    def _tokenize(answer_key):
        # One human turn (the question) followed by one gpt turn (the answer).
        conversation = [
            {"from": "human", "value": data_item["question"]},
            {"from": "gpt", "value": data_item[answer_key]},
        ]
        return preprocess_function(
            self.template_name,
            [deepcopy(conversation)],
            self.tokenizer,
            [self.num_image_token * num_patches],
            text_only=True,
            group_by_length=True,
            ds_name=self.ds_name,
        )

    chosen_ret = _tokenize("chosen")
    rejected_ret = _tokenize("rejected")

    return {
        "chosen_input_ids": chosen_ret["input_ids"][0],
        "chosen_labels": chosen_ret["labels"][0],
        "chosen_attention_mask": chosen_ret["attention_mask"][0],
        "rejected_input_ids": rejected_ret["input_ids"][0],
        "rejected_labels": rejected_ret["labels"][0],
        "rejected_attention_mask": rejected_ret["attention_mask"][0],
        "pixel_values": pixel_values,
        "image_flags": torch.tensor([0] * num_patches, dtype=torch.long),
    }
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
    """Return the processed sample at index ``i``, resampling on failure.

    The record is decoded from ``self.raw_data[i]`` (a JSON string) and
    dispatched by content: a non-empty ``"image"`` list goes to the
    multi-image loader, any other non-empty ``"image"`` to the single-image
    loader, a non-empty ``"video"`` to the video loader, and everything
    else to the text-only loader.

    If loading raises, the failure is reported and a random index is tried
    instead. The diagnostic reporting is best-effort: an error while
    describing the broken record is printed and never interrupts the retry
    loop.

    Raises:
        StopIteration: after more than ``max_try`` failed attempts
            (kept from the original implementation).
    """
    i = i % len(self.raw_data)
    try_cnt, max_try = 0, 10
    while True:
        if try_cnt > max_try:
            raise StopIteration
        try:
            data_item = json.loads(self.raw_data[i])
            if "image" in data_item and len(data_item["image"]) != 0:
                if isinstance(data_item["image"], list):
                    ret = self.multi_modal_multi_image_get_item(data_item)
                else:
                    ret = self.multi_modal_get_item(data_item)
            elif (
                "video" in data_item
                and data_item["video"] is not None
                and data_item["video"] != ""
            ):
                ret = self.video_get_item(data_item)
            else:
                ret = self.pure_text_get_item(data_item)
            break
        except Exception as e:
            try_cnt += 1
            print(e, self.ds_name, flush=True)
            if not isinstance(e, (UnidentifiedImageError, FileNotFoundError)):
                traceback.print_exc()
            try:
                # Best-effort: name the media file of the broken record.
                # Guarded because the record may be malformed (invalid
                # JSON, null "video", non-string "image", ...).
                data_item = json.loads(self.raw_data[i])
                if "image" in data_item:
                    if isinstance(data_item["image"], list):
                        images = [self.root + item for item in data_item["image"]]
                        print(
                            f"Failed to load image: {images}, the dataset is: {self.ds_name}"
                        )
                    else:
                        if data_item["image"].startswith("s3://"):
                            data_path = self.root + data_item["image"]
                        else:
                            data_path = os.path.join(self.root, data_item["image"])
                        print(
                            f"Failed to load image: {data_path}, the dataset is: {self.ds_name}"
                        )
                elif "video" in data_item:
                    data_path = os.path.join(self.root, data_item["video"])
                    print(
                        f"Failed to load video: {data_path}, the dataset is: {self.ds_name}"
                    )
            except Exception as report_error:
                print(
                    f"Could not describe failed sample {i}: {report_error}, "
                    f"the dataset is: {self.ds_name}",
                    flush=True,
                )
            # Retry with a different, randomly chosen record.
            i = random.randint(0, len(self.raw_data) - 1)
    return ret
def build_datasets(
    data_args,
    tokenizer,
    tcs_loader,
    model,
    group_by_length=False,
    dynamic_image_size=False,
    use_thumbnail=False,
    min_dynamic_patch=1,
    max_dynamic_patch=12,
    min_num_frame=8,
    max_num_frame=32,
    normalize_type="imagenet",
):
    """Build the training dataset from the meta file at ``data_args.meta_path``.

    The meta file is a JSON object mapping a dataset name to its settings.
    One ``LazySupervisedDataset`` is created per entry; a per-dataset
    ``"max_dynamic_patch"`` entry overrides the ``max_dynamic_patch``
    argument, and the entry's position in the file is used as its
    ``random_seed``.

    Returns:
        A ``WeightedConcatDataset`` when ``data_args.use_data_resampling``
        is set (weights proportional to the square root of each dataset's
        length), otherwise a plain ``ConcatDataset``.
    """
    # Use a context manager so the meta file handle is always closed.
    with open(data_args.meta_path) as meta_file:
        ds_collections = json.load(meta_file)

    datasets = []
    lengths = []
    for ds_idx, (ds_name, ds_meta) in enumerate(ds_collections.items()):
        repeat_time = ds_meta["repeat_time"]
        if "max_dynamic_patch" in ds_meta:
            max_num = ds_meta["max_dynamic_patch"]
            logger.info(
                f"max_dynamic_patch is set to {max_num} according to the meta file"
            )
        else:
            max_num = max_dynamic_patch
        dataset = LazySupervisedDataset(
            data_args.conv_style,
            ds_meta,
            tokenizer,
            tcs_loader,
            ds_name=ds_name,
            num_image_token=model.num_image_token,
            image_size=data_args.force_image_size,
            is_train=ds_meta.get("data_augment", False),
            pad2square=data_args.pad2square,
            group_by_length=group_by_length,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_num,
            min_num_frame=min_num_frame,
            max_num_frame=max_num_frame,
            repeat_time=repeat_time,
            normalize_type=normalize_type,
            random_seed=ds_idx,
        )
        logger.info(f"Add dataset: {ds_name} with length: {len(dataset)}")
        datasets.append(dataset)
        if data_args.use_data_resampling:
            # Square-root scaling dampens the dominance of large datasets.
            lengths.append(math.sqrt(len(dataset)))
        else:
            lengths.append(len(dataset))

    if data_args.use_data_resampling:
        total_length = sum(lengths)
        weights = [length / total_length for length in lengths]
        train_dataset = WeightedConcatDataset(datasets, weights)
    else:
        train_dataset = ConcatDataset(datasets)
    return train_dataset
View File
View File
+4
View File
@@ -0,0 +1,4 @@
"""
Moonshine is a modified variant of Whisper.
It has fewer parameters, but does not support multilingual input for now.
"""
+3
View File
@@ -0,0 +1,3 @@
"""
We use Whisper.
"""
+95
View File
@@ -0,0 +1,95 @@
from typing import Any
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import AutoConfig, CONFIG_MAPPING
class NamoConfig(PretrainedConfig):
    """Configuration for the Namo model, with text, vision and audio sub-configs.

    Each sub-config may be given as a ready config object, as a ``dict`` of
    keyword arguments (its ``"model_type"`` key selects the config class from
    ``CONFIG_MAPPING``), or as ``None`` to get a default.

    Defaults when a sub-config is ``None``:
        * vision: ``clip_vision_model`` with the hard-coded sizes below,
        * audio: ``whisper`` with library defaults,
        * text: ``qwen2`` with library defaults.

    Defaults for a ``dict`` without ``"model_type"``:
        ``clip_vision_model`` (vision), ``whisper`` (audio), ``llama`` (text).

    Raises:
        ValueError: if ``vision_feature_select_strategy`` is not ``"same"``
            or ``"patch"``.
    """

    model_type = "namo"
    is_composition = False
    sub_configs = {
        "text_config": AutoConfig,
        "vision_config": AutoConfig,
        "audio_config": AutoConfig,
    }

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        audio_config=None,
        ignore_index=-100,
        image_token_index=-200,
        vision_feature_select_strategy="same",
        vision_feature_layer=-2,
        image_seq_length=576,
        new_img_size=None,
        shortest_edge=None,
        longest_edge=None,
        unfreeze_ve=True,
        multimodal_projector_bias=True,
        conn_ve_llm_type="mlp2x_gelu",
        **kwargs,
    ):
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.image_seq_length = image_seq_length
        self.new_img_size = new_img_size
        self.shortest_edge = shortest_edge
        self.longest_edge = longest_edge
        self.unfreeze_ve = unfreeze_ve
        self.conn_ve_llm_type = conn_ve_llm_type

        if vision_feature_select_strategy not in ["same", "patch"]:
            raise ValueError(
                "vision_feature_select_strategy should be one of 'same', 'patch'. "
                f"Got: {vision_feature_select_strategy}"
            )
        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.vision_feature_layer = vision_feature_layer

        if vision_config is None:
            vision_config = CONFIG_MAPPING["clip_vision_model"](
                intermediate_size=4096,
                hidden_size=1024,
                patch_size=14,
                image_size=336,
                num_hidden_layers=24,
                num_attention_heads=16,
                vocab_size=32000,
                projection_dim=768,
            )
        else:
            vision_config = self._resolve_sub_config(
                vision_config, "clip_vision_model"
            )
        self.vision_config = vision_config

        if audio_config is None:
            audio_config = CONFIG_MAPPING["whisper"]()
        else:
            audio_config = self._resolve_sub_config(audio_config, "whisper")
        self.audio_config = audio_config

        # NOTE(review): a dict without "model_type" falls back to "llama",
        # while a missing text_config defaults to "qwen2" — confirm intended.
        if text_config is None:
            text_config = CONFIG_MAPPING["qwen2"]()
        else:
            text_config = self._resolve_sub_config(text_config, "llama")
        self.text_config = text_config

        self.multimodal_projector_bias = multimodal_projector_bias
        super().__init__(**kwargs)

    @staticmethod
    def _resolve_sub_config(sub_config, default_model_type):
        """Turn a dict sub-config into a config object; pass others through.

        The dict is copied first, so the caller's dict is never modified.
        """
        if not isinstance(sub_config, dict):
            return sub_config
        resolved = dict(sub_config)
        resolved.setdefault("model_type", default_model_type)
        return CONFIG_MAPPING[resolved["model_type"]](**resolved)
+470
View File
@@ -0,0 +1,470 @@
import torch
from namo.models.symbols import (
DEFAULT_IM_END_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IMAGE_PATCH_TOKEN,
IMAGE_TOKEN_INDEX,
IGNORE_INDEX,
AUDIO_TOKEN_INDEX,
)
from namo.utils.process_utils import unpad_image, get_anyres_image_grid_shape
from .meta_vision import NamoMetaVisionForCausalLM
class NamoMetaOmniForCausalLM(NamoMetaVisionForCausalLM):
def get_audio_encoder(self):
    """Return the audio encoder exposed by the wrapped model."""
    wrapped_model = self.get_model()
    return wrapped_model.get_audio_encoder()
def get_audio_projector(self):
    """Return the ``audio_projector`` attribute of the wrapped model."""
    wrapped_model = self.get_model()
    return wrapped_model.audio_projector
def encode_audio(self, audio, audio_lengths):
    """Encode a batch of audio inputs and project them for the language model.

    Args:
        audio: Batched audio tensor; it is permuted with ``(0, 2, 1)``
            before being passed to the encoder.
        audio_lengths: Per-sample lengths. They are halved (rounding up)
            for the Whisper encoder and then floor-divided by the
            projector's ``k`` attribute.

    Returns:
        list of tensors, one per sample, each truncated to that sample's
        resulting length along the first dimension.

    Raises:
        ValueError: if ``config.audio_encoder_type`` does not contain
            ``"whisper"`` or ``config.audio_projector_type`` is not
            ``"linear"``.
    """
    audio_encoder_type = self.config.audio_encoder_type
    audio_encoder = self.get_audio_encoder()
    if "whisper" in audio_encoder_type.lower():
        encoder_outs = audio_encoder(audio.permute(0, 2, 1))
        audio_lengths = (audio_lengths + 1) // 2
    else:
        # Report the configured type name, not the encoder object's repr.
        raise ValueError(f"Unknown audio encoder: {audio_encoder_type}")

    audio_projector_type = self.config.audio_projector_type
    audio_projector = self.get_audio_projector()
    if audio_projector_type == "linear":
        encoder_outs = audio_projector(encoder_outs)
        audio_lengths = audio_lengths // audio_projector.k
    else:
        raise ValueError(f"Unknown audio projector: {audio_projector_type}")

    # Drop the padded tail of every sample.
    audio_features = [
        sample_out[: audio_lengths[i]] for i, sample_out in enumerate(encoder_outs)
    ]
    return audio_features
def prepare_inputs_labels_for_multimodal(
    self,
    input_ids,
    position_ids,
    attention_mask,
    past_key_values,
    labels,
    images,
    image_sizes=None,
    audio=None,
    audio_lengths=None,
    pixel_attention_mask=None,
):
    """Splice image and audio features into the token embedding sequence.

    Every ``IMAGE_TOKEN_INDEX`` / ``AUDIO_TOKEN_INDEX`` placeholder in
    ``input_ids`` is replaced by the matching encoded feature block, the
    corresponding label positions are filled with ``IGNORE_INDEX``, and the
    batch is re-padded to a common length.

    Feature bookkeeping: a sample without any image placeholder still
    consumes one entry of the image features (as an empty slice), and a
    sample without any audio placeholder still consumes one entry of the
    audio features. The final assertions check that all features were
    consumed.

    Returns:
        ``(input_ids, position_ids, attention_mask, past_key_values,
        inputs_embeds, labels)``. On the early-exit paths the inputs are
        returned unchanged with ``inputs_embeds=None``; otherwise
        ``input_ids`` is ``None`` and ``inputs_embeds`` holds the merged
        embeddings.
    """
    vision_tower = self.get_vision_tower()
    if vision_tower is None or images is None or input_ids.shape[1] == 1:
        return (
            input_ids,
            position_ids,
            attention_mask,
            past_key_values,
            None,
            labels,
        )

    audio_encoder = self.get_audio_encoder()
    # NOTE(review): the original condition was
    # `audio_encoder is None or audio_encoder is None or input_ids.shape[1] == 1`;
    # the duplicate operand and the already-handled length check are dropped.
    # The duplicate possibly meant `audio is None` — confirm with the caller.
    if audio_encoder is None:
        return (
            input_ids,
            position_ids,
            attention_mask,
            past_key_values,
            None,
            labels,
        )

    if type(images) is list or images.ndim == 5:
        if type(images) is list:
            images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
        concat_images = torch.cat([image for image in images], dim=0)
        image_features = self.encode_images(concat_images)
        split_sizes = [image.shape[0] for image in images]
        image_features = torch.split(image_features, split_sizes, dim=0)
        mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        if mm_patch_merge_type == "flat":
            image_features = [x.flatten(0, 1) for x in image_features]
        elif mm_patch_merge_type.startswith("spatial"):
            new_image_features = []
            for image_idx, image_feature in enumerate(image_features):
                if image_feature.shape[0] > 1:
                    # First entry is the base view, the rest are grid patches.
                    base_image_feature = image_feature[0]
                    image_feature = image_feature[1:]
                    height = width = self.get_vision_tower().num_patches_per_side
                    assert height * width == base_image_feature.shape[0]
                    if image_aspect_ratio == "anyres":
                        (
                            num_patch_width,
                            num_patch_height,
                        ) = get_anyres_image_grid_shape(
                            image_sizes[image_idx],
                            self.config.image_grid_pinpoints,
                            self.get_vision_tower().config.image_size,
                        )
                        image_feature = image_feature.view(
                            num_patch_height, num_patch_width, height, width, -1
                        )
                    else:
                        raise NotImplementedError
                    if "unpad" in mm_patch_merge_type:
                        image_feature = image_feature.permute(
                            4, 0, 2, 1, 3
                        ).contiguous()
                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                        image_feature = unpad_image(
                            image_feature, image_sizes[image_idx]
                        )
                        image_feature = torch.cat(
                            (
                                image_feature,
                                self.model.image_newline[:, None, None]
                                .expand(*image_feature.shape[:-1], 1)
                                .to(image_feature.device),
                            ),
                            dim=-1,
                        )
                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                    else:
                        image_feature = image_feature.permute(
                            0, 2, 1, 3, 4
                        ).contiguous()
                        image_feature = image_feature.flatten(0, 3)
                    image_feature = torch.cat(
                        (base_image_feature, image_feature), dim=0
                    )
                else:
                    image_feature = image_feature[0]
                    if "unpad" in mm_patch_merge_type:
                        image_feature = torch.cat(
                            (
                                image_feature,
                                self.model.image_newline[None].to(
                                    image_feature.device
                                ),
                            ),
                            dim=0,
                        )
                new_image_features.append(image_feature)
            image_features = new_image_features
        else:
            raise ValueError(
                f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}"
            )
    else:
        image_features = self.encode_images(
            images, pixel_attention_mask=pixel_attention_mask
        )

    # `encode_audio` returns a list with one feature tensor per audio input.
    audio_features = self.encode_audio(audio, audio_lengths)

    # TODO: image start / end is not implemented here to support pretraining.
    if getattr(self.config, "tune_conn_ve_llm", False) and getattr(
        self.config, "mm_use_im_start_end", False
    ):
        raise NotImplementedError

    # Fill in dummy tensors for missing optional inputs so the code below
    # does not have to special-case None; the originals are remembered so
    # the return value can restore None where the caller passed None.
    _labels = labels
    _position_ids = position_ids
    _attention_mask = attention_mask
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    else:
        attention_mask = attention_mask.bool()
    if position_ids is None:
        position_ids = torch.arange(
            0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
        )
    if labels is None:
        labels = torch.full_like(input_ids, IGNORE_INDEX)

    # Remove the padding using attention_mask.
    input_ids = [
        cur_input_ids[cur_attention_mask]
        for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
    ]
    labels = [
        cur_labels[cur_attention_mask]
        for cur_labels, cur_attention_mask in zip(labels, attention_mask)
    ]

    # Supported mixes: I->T, A->T, T->T, AI->T, TI->T, AT->T, AIT->T
    new_input_embeds = []
    new_labels = []
    cur_image_idx = 0
    cur_audio_idx = 0
    for batch_idx, cur_input_ids in enumerate(input_ids):
        num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
        num_audio_frames = (cur_input_ids == AUDIO_TOKEN_INDEX).sum()
        if num_images == 0 and num_audio_frames == 0:
            # Pure text: consume one image and one audio feature as empty
            # slices so the feature indices stay aligned with the batch.
            cur_image_features = image_features[cur_image_idx]
            cur_audio_features = audio_features[cur_audio_idx]
            cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
            cur_input_embeds = torch.cat(
                [
                    cur_input_embeds_1,
                    cur_image_features[0:0],
                    cur_audio_features[0:0],
                ],
                dim=0,
            )
            new_input_embeds.append(cur_input_embeds)
            new_labels.append(labels[batch_idx])
            cur_image_idx += 1
            cur_audio_idx += 1
            continue

        # Placeholder positions, with sentinels at both ends so that every
        # text segment between two placeholders is a simple slice.
        image_audio_token_indices = (
            [-1]
            + torch.where(
                (cur_input_ids == IMAGE_TOKEN_INDEX)
                | (cur_input_ids == AUDIO_TOKEN_INDEX)
            )[0].tolist()
            + [cur_input_ids.shape[0]]
        )
        cur_input_ids_noim_noau = []
        cur_labels = labels[batch_idx]
        cur_labels_noim_noau = []
        for i in range(len(image_audio_token_indices) - 1):
            cur_input_ids_noim_noau.append(
                cur_input_ids[
                    image_audio_token_indices[i]
                    + 1 : image_audio_token_indices[i + 1]
                ]
            )
            cur_labels_noim_noau.append(
                cur_labels[
                    image_audio_token_indices[i]
                    + 1 : image_audio_token_indices[i + 1]
                ]
            )
        split_sizes = [x.shape[0] for x in cur_labels_noim_noau]
        cur_input_embeds = self.get_model().embed_tokens(
            torch.cat(cur_input_ids_noim_noau)
        )
        cur_input_embeds_no_im_no_au = torch.split(
            cur_input_embeds, split_sizes, dim=0
        )
        cur_new_input_embeds = []
        cur_new_labels = []
        for i in range(num_images + num_audio_frames + 1):
            cur_new_input_embeds.append(cur_input_embeds_no_im_no_au[i])
            cur_new_labels.append(cur_labels_noim_noau[i])
            if i < num_images + num_audio_frames:
                if (
                    cur_input_ids[image_audio_token_indices[i + 1]]
                    == IMAGE_TOKEN_INDEX
                ):
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(
                        torch.full(
                            (cur_image_features.shape[0],),
                            IGNORE_INDEX,
                            device=cur_labels.device,
                            dtype=cur_labels.dtype,
                        )
                    )
                elif (
                    cur_input_ids[image_audio_token_indices[i + 1]]
                    == AUDIO_TOKEN_INDEX
                ):
                    cur_audio_features = audio_features[cur_audio_idx]
                    cur_audio_idx += 1
                    cur_new_input_embeds.append(cur_audio_features)
                    cur_new_labels.append(
                        torch.full(
                            (cur_audio_features.shape[0],),
                            IGNORE_INDEX,
                            device=cur_labels.device,
                            dtype=cur_labels.dtype,
                        )
                    )
                else:
                    raise ValueError

        # A sample that used only one modality still consumes one feature of
        # the other modality (as an empty slice).
        if num_images != 0 and num_audio_frames == 0:
            cur_audio_features = audio_features[cur_audio_idx]
            cur_audio_idx += 1
            cur_new_input_embeds.append(cur_audio_features[0:0])
        elif num_images == 0 and num_audio_frames != 0:
            cur_image_features = image_features[cur_image_idx]
            cur_image_idx += 1
            cur_new_input_embeds.append(cur_image_features[0:0])

        cur_new_input_embeds = torch.cat(cur_new_input_embeds)
        cur_new_labels = torch.cat(cur_new_labels)
        new_input_embeds.append(cur_new_input_embeds)
        new_labels.append(cur_new_labels)

    # `len` works whether the features are a list/tuple or a batched tensor.
    assert cur_image_idx == len(image_features)
    assert cur_audio_idx == len(audio_features)

    # Truncate sequences to max length as image embeddings can make the
    # sequence longer.
    tokenizer_model_max_length = getattr(
        self.config, "tokenizer_model_max_length", None
    )
    if tokenizer_model_max_length is not None:
        new_input_embeds = [
            x[:tokenizer_model_max_length] for x in new_input_embeds
        ]
        new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

    # Re-pad everything to the longest sequence in the batch.
    max_len = max(x.shape[0] for x in new_input_embeds)
    batch_size = len(new_input_embeds)
    new_input_embeds_padded = []
    new_labels_padded = torch.full(
        (batch_size, max_len),
        IGNORE_INDEX,
        dtype=new_labels[0].dtype,
        device=new_labels[0].device,
    )
    attention_mask = torch.zeros(
        (batch_size, max_len),
        dtype=attention_mask.dtype,
        device=attention_mask.device,
    )
    position_ids = torch.zeros(
        (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
    )
    for i, (cur_new_embed, cur_new_labels) in enumerate(
        zip(new_input_embeds, new_labels)
    ):
        cur_len = cur_new_embed.shape[0]
        if getattr(self.config, "tokenizer_padding_side", "right") == "left":
            new_input_embeds_padded.append(
                torch.cat(
                    (
                        torch.zeros(
                            (max_len - cur_len, cur_new_embed.shape[1]),
                            dtype=cur_new_embed.dtype,
                            device=cur_new_embed.device,
                        ),
                        cur_new_embed,
                    ),
                    dim=0,
                )
            )
            if cur_len > 0:
                new_labels_padded[i, -cur_len:] = cur_new_labels
                attention_mask[i, -cur_len:] = True
                position_ids[i, -cur_len:] = torch.arange(
                    0, cur_len, dtype=position_ids.dtype, device=position_ids.device
                )
        else:
            new_input_embeds_padded.append(
                torch.cat(
                    (
                        cur_new_embed,
                        torch.zeros(
                            (max_len - cur_len, cur_new_embed.shape[1]),
                            dtype=cur_new_embed.dtype,
                            device=cur_new_embed.device,
                        ),
                    ),
                    dim=0,
                )
            )
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(
                    0, cur_len, dtype=position_ids.dtype, device=position_ids.device
                )
    new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

    # Give back None for every optional input the caller passed as None.
    if _labels is None:
        new_labels = None
    else:
        new_labels = new_labels_padded
    if _attention_mask is None:
        attention_mask = None
    else:
        attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
    if _position_ids is None:
        position_ids = None

    return (
        None,
        position_ids,
        attention_mask,
        past_key_values,
        new_input_embeds,
        new_labels,
    )
def initialize_vision_tokenizer(self, model_args, tokenizer):
    """Register the multimodal special tokens and set up their embeddings.

    * With ``mm_use_im_patch_token`` the image-patch and audio tokens are
      added and the embedding matrix is resized.
    * With ``mm_use_im_start_end`` the image start/end tokens are added;
      newly created rows of the input and output embeddings are set to the
      mean of the previously existing rows. When ``tune_conn_ve_llm`` is
      set, the input embeddings become trainable and the output embeddings
      are frozen. When ``pretrain_mm_mlp_adapter`` is set, the new input
      rows are loaded from that checkpoint.
    * Otherwise, with the patch token and ``tune_conn_ve_llm`` both set,
      input and output embeddings are frozen.

    Raises:
        ValueError: if the checkpoint's ``model.embed_tokens.weight`` has an
            unexpected shape.
    """
    # NOTE(review): DEFAULT_AUDIOTOKEN is not among the names imported from
    # namo.models.symbols at the top of this file — confirm it is defined.

    def _set_embedding_grads(train_inputs):
        for param in self.get_input_embeddings().parameters():
            param.requires_grad = train_inputs
        for param in self.get_output_embeddings().parameters():
            param.requires_grad = False

    if model_args.mm_use_im_patch_token:
        tokenizer.add_tokens(
            [DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_AUDIOTOKEN], special_tokens=True
        )
        self.resize_token_embeddings(len(tokenizer))

    if not model_args.mm_use_im_start_end:
        if model_args.mm_use_im_patch_token and model_args.tune_conn_ve_llm:
            _set_embedding_grads(False)
        return

    num_new_tokens = tokenizer.add_tokens(
        [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
    )
    self.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = self.get_input_embeddings().weight.data
        output_embeddings = self.get_output_embeddings().weight.data
        # New rows start out as the average of all pre-existing rows.
        input_embeddings[-num_new_tokens:] = input_embeddings[
            :-num_new_tokens
        ].mean(dim=0, keepdim=True)
        output_embeddings[-num_new_tokens:] = output_embeddings[
            :-num_new_tokens
        ].mean(dim=0, keepdim=True)

    if model_args.tune_conn_ve_llm:
        _set_embedding_grads(True)

    if model_args.pretrain_mm_mlp_adapter:
        adapter_state = torch.load(
            model_args.pretrain_mm_mlp_adapter, map_location="cpu"
        )
        embed_tokens_weight = adapter_state["model.embed_tokens.weight"]
        assert num_new_tokens == 2
        if input_embeddings.shape == embed_tokens_weight.shape:
            input_embeddings[-num_new_tokens:] = embed_tokens_weight[
                -num_new_tokens:
            ]
        elif embed_tokens_weight.shape[0] == num_new_tokens:
            input_embeddings[-num_new_tokens:] = embed_tokens_weight
        else:
            raise ValueError(
                f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}."
            )
+393
View File
@@ -0,0 +1,393 @@
from abc import abstractmethod
from abc import ABC
import torch
from namo.models.symbols import (
DEFAULT_IM_END_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IMAGE_PATCH_TOKEN,
IMAGE_TOKEN_INDEX,
IGNORE_INDEX,
)
from namo.utils.process_utils import unpad_image, get_anyres_image_grid_shape
class NamoMetaVisionForCausalLM(ABC):
@abstractmethod
def get_namo(self):
    """Return the underlying Namo model; must be provided by subclasses."""
def get_vision_tower(self):
    """Return the vision tower of the underlying Namo model."""
    namo = self.get_namo()
    return namo.get_vision_tower()
def encode_images(self, images, pixel_attention_mask=None):
    """Run images through the vision tower, then the vision-to-LLM connector.

    Args:
        images: Input passed to the vision tower as its first argument.
        pixel_attention_mask: Optional mask passed to the vision tower as
            its second argument.

    Returns:
        The output of ``conn_ve_llm`` applied to the vision tower output.
    """
    vision_tower = self.get_namo().get_vision_tower()
    tower_output = vision_tower(images, pixel_attention_mask)
    return self.get_namo().conn_ve_llm(tower_output)
def prepare_inputs_labels_for_multimodal(
    self,
    input_ids,
    position_ids,
    attention_mask,
    past_key_values,
    labels,
    images,
    image_sizes=None,
    pixel_attention_mask=None,
):
    """Splice encoded image features into the token embedding sequence.

    Every ``IMAGE_TOKEN_INDEX`` placeholder in ``input_ids`` is replaced by
    the next entry of the encoded image features, the matching label
    positions are filled with ``IGNORE_INDEX``, and the batch is re-padded
    to a common length.

    Returns:
        ``(input_ids, position_ids, attention_mask, past_key_values,
        inputs_embeds, labels)``. If there is no vision tower, no images, or
        ``input_ids`` has sequence length 1, the inputs are returned
        unchanged with ``inputs_embeds=None``; otherwise ``input_ids`` is
        ``None`` and ``inputs_embeds`` holds the merged embeddings.
    """
    vision_tower = self.get_vision_tower()
    # Nothing to merge: hand the inputs back untouched.
    if vision_tower is None or images is None or input_ids.shape[1] == 1:
        return (
            input_ids,
            position_ids,
            attention_mask,
            past_key_values,
            None,
            labels,
        )
    if type(images) is list or images.ndim == 5:
        if type(images) is list:
            # Give 3-D entries a leading dimension.
            images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
        # concat_images = torch.cat([image for image in images], dim=0)
        image_features = self.encode_images(images)
        mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "none")
        image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        if mm_patch_merge_type == "flat":
            # Split per input image, then flatten each group's first two dims.
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            image_features = [x.flatten(0, 1) for x in image_features]
        elif mm_patch_merge_type.startswith("spatial"):
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            new_image_features = []
            for image_idx, image_feature in enumerate(image_features):
                if image_feature.shape[0] > 1:
                    # First entry is kept as the base feature; the rest are
                    # reshaped onto a patch grid below.
                    base_image_feature = image_feature[0]
                    image_feature = image_feature[1:]
                    height = width = self.get_vision_tower().num_patches_per_side
                    assert height * width == base_image_feature.shape[0]
                    if image_aspect_ratio == "anyres":
                        (
                            num_patch_width,
                            num_patch_height,
                        ) = get_anyres_image_grid_shape(
                            image_sizes[image_idx],
                            self.config.image_grid_pinpoints,
                            self.get_vision_tower().config.image_size,
                        )
                        image_feature = image_feature.view(
                            num_patch_height, num_patch_width, height, width, -1
                        )
                    else:
                        raise NotImplementedError
                    if "unpad" in mm_patch_merge_type:
                        image_feature = image_feature.permute(
                            4, 0, 2, 1, 3
                        ).contiguous()
                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                        image_feature = unpad_image(
                            image_feature, image_sizes[image_idx]
                        )
                        # Append the `image_newline` vector along the last dim.
                        image_feature = torch.cat(
                            (
                                image_feature,
                                self.model.image_newline[:, None, None]
                                .expand(*image_feature.shape[:-1], 1)
                                .to(image_feature.device),
                            ),
                            dim=-1,
                        )
                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                    else:
                        image_feature = image_feature.permute(
                            0, 2, 1, 3, 4
                        ).contiguous()
                        image_feature = image_feature.flatten(0, 3)
                    image_feature = torch.cat(
                        (base_image_feature, image_feature), dim=0
                    )
                else:
                    image_feature = image_feature[0]
                    if "unpad" in mm_patch_merge_type:
                        image_feature = torch.cat(
                            (
                                image_feature,
                                self.model.image_newline[None].to(
                                    image_feature.device
                                ),
                            ),
                            dim=0,
                        )
                new_image_features.append(image_feature)
            image_features = new_image_features
        else:
            # No merging: expected to already be a list of image tokens; only
            # a leading singleton dimension is squeezed away.
            image_features = [f.squeeze(0) for f in image_features]
    else:
        image_features = self.encode_images(
            images, pixel_attention_mask=pixel_attention_mask
        )
    # for video test
    # image_features = self.temporal_aggregation(image_features)
    # print(f"imgfahture: {image_features[0].shape}")
    # TODO: image start / end is not implemented here to support pretraining.
    if getattr(self.config, "tune_conn_ve_llm", False) and getattr(
        self.config, "mm_use_im_start_end", False
    ):
        raise NotImplementedError
    # Let's just add dummy tensors if they do not exist,
    # it is a headache to deal with None all the time.
    # But it is not ideal, and if you have a better idea,
    # please open an issue / submit a PR, thanks.
    # The originals are remembered so None can be restored on return.
    _labels = labels
    _position_ids = position_ids
    _attention_mask = attention_mask
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    else:
        attention_mask = attention_mask.bool()
    if position_ids is None:
        position_ids = torch.arange(
            0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
        )
    if labels is None:
        labels = torch.full_like(input_ids, IGNORE_INDEX)
    # remove the padding using attention_mask -- FIXME
    # NOTE(review): `_input_ids` is not read again in this method.
    _input_ids = input_ids
    input_ids = [
        cur_input_ids[cur_attention_mask]
        for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
    ]
    labels = [
        cur_labels[cur_attention_mask]
        for cur_labels, cur_attention_mask in zip(labels, attention_mask)
    ]
    new_input_embeds = []
    new_labels = []
    cur_image_idx = 0
    for batch_idx, cur_input_ids in enumerate(input_ids):
        num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
        if num_images == 0:
            # Text-only sample: still consumes one image feature, appended as
            # an empty slice so the feature index advances.
            cur_image_features = image_features[cur_image_idx]
            cur_input_embeds_1 = (
                self.get_namo().get_llm().embed_tokens(cur_input_ids)
            )
            cur_input_embeds = torch.cat(
                [cur_input_embeds_1, cur_image_features[0:0]], dim=0
            )
            new_input_embeds.append(cur_input_embeds)
            new_labels.append(labels[batch_idx])
            cur_image_idx += 1
            continue
        # Placeholder positions with sentinels at both ends, so each text
        # segment between placeholders is a plain slice.
        image_token_indices = (
            [-1]
            + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
            + [cur_input_ids.shape[0]]
        )
        cur_input_ids_noim = []
        cur_labels = labels[batch_idx]
        cur_labels_noim = []
        for i in range(len(image_token_indices) - 1):
            cur_input_ids_noim.append(
                cur_input_ids[
                    image_token_indices[i] + 1 : image_token_indices[i + 1]
                ]
            )
            cur_labels_noim.append(
                cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]
            )
        split_sizes = [x.shape[0] for x in cur_labels_noim]
        # Embed all text segments in one call, then split them apart again.
        cur_input_embeds = (
            self.get_namo().get_llm().embed_tokens(torch.cat(cur_input_ids_noim))
        )
        cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
        cur_new_input_embeds = []
        cur_new_labels = []
        # Interleave: text segment, image features, text segment, ...
        for i in range(num_images + 1):
            cur_new_input_embeds.append(cur_input_embeds_no_im[i])
            cur_new_labels.append(cur_labels_noim[i])
            if i < num_images:
                cur_image_features = image_features[cur_image_idx]
                cur_image_idx += 1
                cur_new_input_embeds.append(cur_image_features)
                # Image positions are filled with IGNORE_INDEX in the labels.
                cur_new_labels.append(
                    torch.full(
                        (cur_image_features.shape[0],),
                        IGNORE_INDEX,
                        device=cur_labels.device,
                        dtype=cur_labels.dtype,
                    )
                )
        # Move every piece to this module's device before concatenating.
        cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
        cur_new_input_embeds = torch.cat(cur_new_input_embeds)
        cur_new_labels = torch.cat(cur_new_labels)
        new_input_embeds.append(cur_new_input_embeds)
        new_labels.append(cur_new_labels)
    # Truncate sequences to max length as image embeddings can make the sequence longer
    tokenizer_model_max_length = getattr(
        self.config, "tokenizer_model_max_length", None
    )
    if tokenizer_model_max_length is not None:
        new_input_embeds = [
            x[:tokenizer_model_max_length] for x in new_input_embeds
        ]
        new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
    # Combine them: re-pad everything to the longest sequence in the batch.
    max_len = max(x.shape[0] for x in new_input_embeds)
    batch_size = len(new_input_embeds)
    new_input_embeds_padded = []
    new_labels_padded = torch.full(
        (batch_size, max_len),
        IGNORE_INDEX,
        dtype=new_labels[0].dtype,
        device=new_labels[0].device,
    )
    attention_mask = torch.zeros(
        (batch_size, max_len),
        dtype=attention_mask.dtype,
        device=attention_mask.device,
    )
    position_ids = torch.zeros(
        (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
    )
    for i, (cur_new_embed, cur_new_labels) in enumerate(
        zip(new_input_embeds, new_labels)
    ):
        cur_len = cur_new_embed.shape[0]
        if getattr(self.config, "tokenizer_padding_side", "right") == "left":
            # Left padding: zeros first, real embeddings at the end.
            new_input_embeds_padded.append(
                torch.cat(
                    (
                        torch.zeros(
                            (max_len - cur_len, cur_new_embed.shape[1]),
                            dtype=cur_new_embed.dtype,
                            device=cur_new_embed.device,
                        ),
                        cur_new_embed,
                    ),
                    dim=0,
                )
            )
            if cur_len > 0:
                new_labels_padded[i, -cur_len:] = cur_new_labels
                attention_mask[i, -cur_len:] = True
                position_ids[i, -cur_len:] = torch.arange(
                    0, cur_len, dtype=position_ids.dtype, device=position_ids.device
                )
        else:
            # Right padding (default): real embeddings first, zeros after.
            new_input_embeds_padded.append(
                torch.cat(
                    (
                        cur_new_embed,
                        torch.zeros(
                            (max_len - cur_len, cur_new_embed.shape[1]),
                            dtype=cur_new_embed.dtype,
                            device=cur_new_embed.device,
                        ),
                    ),
                    dim=0,
                )
            )
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(
                    0, cur_len, dtype=position_ids.dtype, device=position_ids.device
                )
    new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
    # Give back None for every optional input the caller passed as None.
    if _labels is None:
        new_labels = None
    else:
        new_labels = new_labels_padded
    if _attention_mask is None:
        attention_mask = None
    else:
        attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
    if _position_ids is None:
        position_ids = None
    return (
        None,
        position_ids,
        attention_mask,
        past_key_values,
        new_input_embeds,
        new_labels,
    )
def initialize_vision_tokenizer(self, model_args, tokenizer):
    """Register the image special tokens and set up their embeddings.

    * With ``mm_use_im_patch_token`` the image-patch token is added and the
      embedding matrix is resized.
    * With ``mm_use_im_start_end`` the image start/end tokens are added;
      newly created rows of the input and output embeddings are set to the
      mean of the previously existing rows. When ``tune_conn_ve_llm`` is
      set, the input embeddings become trainable and the output embeddings
      are frozen. When ``pretrain_mm_mlp_adapter`` is set, the new input
      rows are loaded from that checkpoint.
    * Otherwise, with the patch token and ``tune_conn_ve_llm`` both set,
      input and output embeddings are frozen.

    Raises:
        ValueError: if the checkpoint's ``model.embed_tokens.weight`` has an
            unexpected shape.
    """

    def _set_embedding_grads(train_inputs):
        for param in self.get_input_embeddings().parameters():
            param.requires_grad = train_inputs
        for param in self.get_output_embeddings().parameters():
            param.requires_grad = False

    if model_args.mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        self.resize_token_embeddings(len(tokenizer))

    if not model_args.mm_use_im_start_end:
        if model_args.mm_use_im_patch_token and model_args.tune_conn_ve_llm:
            _set_embedding_grads(False)
        return

    num_new_tokens = tokenizer.add_tokens(
        [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
    )
    self.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = self.get_input_embeddings().weight.data
        output_embeddings = self.get_output_embeddings().weight.data
        # New rows start out as the average of all pre-existing rows.
        input_embeddings[-num_new_tokens:] = input_embeddings[
            :-num_new_tokens
        ].mean(dim=0, keepdim=True)
        output_embeddings[-num_new_tokens:] = output_embeddings[
            :-num_new_tokens
        ].mean(dim=0, keepdim=True)

    if model_args.tune_conn_ve_llm:
        _set_embedding_grads(True)

    if model_args.pretrain_mm_mlp_adapter:
        adapter_state = torch.load(
            model_args.pretrain_mm_mlp_adapter, map_location="cpu"
        )
        embed_tokens_weight = adapter_state["model.embed_tokens.weight"]
        assert num_new_tokens == 2
        if input_embeddings.shape == embed_tokens_weight.shape:
            input_embeddings[-num_new_tokens:] = embed_tokens_weight[
                -num_new_tokens:
            ]
        elif embed_tokens_weight.shape[0] == num_new_tokens:
            input_embeddings[-num_new_tokens:] = embed_tokens_weight
        else:
            raise ValueError(
                f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}."
            )
View File
@@ -0,0 +1,40 @@
from torch.nn import LayerNorm
from torch import nn
import torch
from timm.models.regnet import RegStage
from timm.layers import LayerNorm, LayerNorm2d
class GLU(nn.Module):
    """Gated projection block.

    Projects ``in_features`` to ``hidden_size``, applies LayerNorm and GELU,
    then a SiLU-gated feed-forward of width ``ffn_hidden_size`` that maps
    back to ``hidden_size``. All linear layers are bias-free.
    """

    def __init__(self, hidden_size, ffn_hidden_size, in_features):
        super().__init__()

        def _linear(n_in, n_out):
            return nn.Linear(n_in, n_out, bias=False)

        # Layers are created in this fixed order so that seeded
        # initialisation and state-dict key order stay unchanged.
        self.linear_proj = _linear(in_features, hidden_size)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.act1 = nn.GELU()
        self.act2 = nn.functional.silu
        self.dense_h_to_4h = _linear(hidden_size, ffn_hidden_size)
        self.gate_proj = _linear(hidden_size, ffn_hidden_size)
        self.dense_4h_to_h = _linear(ffn_hidden_size, hidden_size)

    def forward(self, x):
        """Return the gated projection of ``x``."""
        hidden = self.act1(self.norm1(self.linear_proj(x)))
        gate = self.act2(self.gate_proj(hidden))
        up = self.dense_h_to_4h(hidden)
        return self.dense_4h_to_h(gate * up)
class MlpGLU(nn.Module):
    """Connector made of a single ``GLU`` whose inner width is 4x the output width."""

    def __init__(self, in_hidden_size, out_hidden_size):
        super().__init__()
        # Inner width is four times the output width, e.g. 3584 -> 14336.
        self.linear_proj = GLU(
            hidden_size=out_hidden_size,
            ffn_hidden_size=4 * out_hidden_size,
            in_features=in_hidden_size,
        )

    def forward(self, x, attention_mask: torch.Tensor = None):
        """Project ``x`` through the GLU; ``attention_mask`` is accepted but unused."""
        return self.linear_proj(x)
+45
View File
@@ -0,0 +1,45 @@
import re
import torch.nn as nn
class MLP(nn.Module):
    """GELU MLP projector configured from a type string.

    ``projector_type`` must look like ``mlp<N>x_gelu``; the marker ``_Norm``
    may appear in it to request a LayerNorm after each of the first ``N``
    linear layers.  When ``output_dim`` differs from ``hidden_dim`` a final
    linear layer maps to ``output_dim``.
    """

    def __init__(self, projector_type, input_dim, hidden_dim, output_dim=None):
        super().__init__()
        self.projector_type = projector_type
        # A falsy output_dim (None or 0) falls back to hidden_dim.
        self.output_dim = output_dim or hidden_dim
        self.mlp_depth, self.use_norm = self._parse_projector_type()
        self.layers = self._build_layers(input_dim, hidden_dim)

    def _parse_projector_type(self):
        """Return ``(depth, use_norm)`` parsed from ``self.projector_type``.

        Raises:
            ValueError: if the type (with ``_Norm`` removed) is not ``mlp<N>x_gelu``.
        """
        name = self.projector_type
        parsed = re.match(r"^mlp(\d+)x_gelu$", name.replace("_Norm", ""))
        if parsed is None:
            raise ValueError(f"Invalid projector_type: {self.projector_type}")
        return int(parsed.group(1)), "_Norm" in name

    def _build_layers(self, input_dim, hidden_dim):
        """Assemble the layer stack as an ``nn.Sequential``."""

        def optional_norm():
            return [nn.LayerNorm(hidden_dim)] if self.use_norm else []

        stack = [nn.Linear(input_dim, hidden_dim), *optional_norm()]
        for _ in range(self.mlp_depth - 1):
            stack += [nn.GELU(), nn.Linear(hidden_dim, hidden_dim), *optional_norm()]
        if hidden_dim != self.output_dim:
            stack.append(nn.Linear(hidden_dim, self.output_dim))
        return nn.Sequential(*stack)

    def forward(self, x):
        """Apply the MLP; a 4-D input has dims 1 and 2 merged first."""
        if x.dim() == 4:
            x = x.flatten(1, 2)
        return self.layers(x)
@@ -0,0 +1,111 @@
from functools import partial
import re
from torch import nn
import torch
import torch.nn.functional as F
from transformers.activations import ACT2FN
from .components import GLU
from namo.utils.utils import rank0_print
class PixelShuffleLayer(nn.Module):
    """Space-to-depth downsampling for channel-last feature maps.

    Every ``k x k`` spatial neighbourhood (``k = 1 / scale_factor``) is folded
    into the channel dimension, so an input of shape ``(n, w, h, c)`` becomes
    ``(n, w / k, h / k, c * k * k)``.  Spatial sizes that are not a multiple
    of ``k`` are zero-padded at the end of the axis first.
    """

    def __init__(self):
        # The original spelled this ``__init`` so it was never invoked.
        super().__init__()

    def forward(self, x, scale_factor=0.5):
        """Downsample ``x`` by ``scale_factor``.

        Args:
            x: 4-D tensor laid out as ``(n, w, h, c)``.
            scale_factor: reciprocal of an integer block size (0.5 -> 2x2).

        Returns:
            Contiguous tensor of shape ``(n, w // k, h // k, c * k * k)``
            computed on the (possibly padded) input.

        Raises:
            ValueError: if ``scale_factor`` does not give a block size >= 1.
        """
        # Derive the integer block size once; all size arithmetic below stays
        # in integers instead of repeating float multiplications/divisions.
        factor = round(1 / scale_factor)
        if factor < 1:
            raise ValueError(
                f"scale_factor must be the reciprocal of a positive integer, got {scale_factor}"
            )

        n, w, h, c = x.size()
        # Amount needed to reach the next multiple of ``factor`` (0 if aligned).
        pad_w = -w % factor
        pad_h = -h % factor
        if pad_w or pad_h:
            # F.pad pads trailing dims first, so go channel-first: (n, c, w, h).
            x = x.permute(0, 3, 1, 2)
            x = F.pad(x, (0, pad_h, 0, pad_w))
            x = x.permute(0, 2, 3, 1)
            w += pad_w
            h += pad_h

        new_w, new_h = w // factor, h // factor
        # Fold ``factor`` neighbours along h into channels ...
        x = x.reshape(n, w, new_h, c * factor)
        x = x.permute(0, 2, 1, 3).contiguous()
        # ... then ``factor`` neighbours along w.
        x = x.view(n, new_h, new_w, c * factor * factor)
        # Restore the (n, w, h, c) axis order.
        return x.permute(0, 2, 1, 3).contiguous()
class PixelShuffleConnector(nn.Module):
    """Pixel-shuffle downsampling followed by a ``GLU`` projection.

    The shuffle multiplies the channel count by ``down_rate ** 2``; the GLU
    takes that width as input and as its inner width, and emits
    ``out_hidden_size`` features per token.
    """

    def __init__(self, in_hidden_size, out_hidden_size, down_rate=2, conv_before=False):
        super().__init__()
        # Channel width after the shuffle; also used as the GLU inner width.
        merged_size = in_hidden_size * (down_rate**2)
        ffn_hidden_size = merged_size
        self.linear_proj = GLU(
            hidden_size=out_hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            in_features=merged_size,
        )
        rank0_print(
            f"==> pixelshuffle: ffn_hidden_size: {ffn_hidden_size}, in_features: {in_hidden_size * (down_rate**2)}"
        )
        self.down_rate = down_rate
        self.downsample = PixelShuffleLayer()
        self.conv_before = conv_before
        if conv_before:
            # Optional stride-2 convolution, applied only to 4-D inputs.
            self.conv = nn.Conv2d(
                in_channels=in_hidden_size,
                out_channels=in_hidden_size,
                kernel_size=2,
                stride=2,
            )

    def forward(self, x, attention_mask: torch.Tensor = None):
        """Downsample and project ``x``.

        Args:
            x: rank-3 ``(b, s, h)`` tokens, reshaped to a square grid of side
                ``int(s ** 0.5)``, or a rank-4 channel-last feature map.
            attention_mask: accepted but unused.

        Returns:
            Rank-3 tensor with the spatial dims flattened and the last dim
            projected by the GLU.

        Raises:
            ValueError: if ``x`` is neither rank 3 nor rank 4.
        """
        rank = x.dim()
        if rank not in (3, 4):
            raise ValueError(f"Unsupported input shape: {x.shape}, either rank 3 or 4")

        if rank == 3:
            batch, seq_len, width = x.shape
            side = int(seq_len**0.5)
            x = x.reshape(batch, side, side, width)
        elif self.conv_before:
            # Conv2d is channel-first; convert, convolve, convert back.
            x = self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

        x = self.downsample(x, scale_factor=1 / self.down_rate)
        # Flatten the spatial grid back into a token sequence.
        x = x.reshape(x.shape[0], -1, x.shape[-1])
        return self.linear_proj(x)
def get_pixel_shuffle(type_name, mm_hidden_size, hidden_size):
    """Build a pixel-shuffle connector from a name such as ``pixelshuffle_2x``.

    Args:
        type_name: must match ``pixelshuffle_<N>x``; ``N`` is the downsample rate.
        mm_hidden_size: input feature width passed to the connector.
        hidden_size: output feature width passed to the connector.

    Returns:
        An ``nn.Sequential`` holding one ``PixelShuffleConnector``.

    Raises:
        ValueError: if ``type_name`` does not match the expected pattern.
    """
    parsed = re.match(r"^pixelshuffle_(\d+)x$", type_name)
    if parsed is None:
        raise ValueError(f"Unknown resampler type: {type_name}")

    down_rate = int(parsed.group(1))
    rank0_print(
        f"==> conn_ve_llm type: {type_name}, downsample rate: {down_rate}, {mm_hidden_size}->{hidden_size}"
    )
    connector = PixelShuffleConnector(
        in_hidden_size=mm_hidden_size,
        out_hidden_size=hidden_size,
        down_rate=down_rate,
    )
    return nn.Sequential(connector)
View File
+55
View File
@@ -0,0 +1,55 @@
from typing import List, Tuple, Union
from torch import nn
import torch
import re
from namo.models.modal_adapt.adapt_ve.mlp import MLP
from namo.models.modal_adapt.adapt_ve.components import MlpGLU
from namo.models.modal_adapt.adapt_ve.pixelshuffle import get_pixel_shuffle
from namo.utils.utils import rank0_print
class ConnVE(nn.Module):
    """Connector between the vision encoder and the language model.

    The connector is chosen by ``config.conn_ve_llm_type``:

    * ``"identity"`` / ``"linear"``: ``nn.Identity`` / a single ``nn.Linear``;
    * names containing ``"gelu"``: an ``MLP`` projector;
    * names containing ``"pixelshuffle"``: the result of ``get_pixel_shuffle``;
    * names containing ``"glu"``: an ``MlpGLU`` wrapped in ``nn.Sequential``.

    Raises:
        NotImplementedError: for names containing ``"ovis"``.
        ValueError: for any other unknown name.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        type_name = self.config.conn_ve_llm_type
        llm_hidden_size = config.text_config.hidden_size

        # Only touch ``intermediate_size`` when ``hidden_size`` is missing;
        # passing it as a ``getattr`` default evaluated it unconditionally and
        # failed on vision configs that define ``hidden_size`` alone.
        vision_config = config.vision_config
        if hasattr(vision_config, "hidden_size"):
            ve_hidden_size = vision_config.hidden_size
        else:
            ve_hidden_size = vision_config.intermediate_size

        rank0_print(f"==> current conn type: {type_name}")
        if type_name == "identity":
            modules = nn.Identity()
        elif type_name == "linear":
            modules = nn.Linear(ve_hidden_size, llm_hidden_size)
        elif "gelu" in type_name:
            modules = MLP(type_name, ve_hidden_size, llm_hidden_size)
        elif "pixelshuffle" in type_name:
            modules = get_pixel_shuffle(type_name, ve_hidden_size, llm_hidden_size)
        elif "ovis" in type_name:
            # Previously this only printed and then crashed with an
            # UnboundLocalError on ``modules``; fail with a clear error.
            raise NotImplementedError(f"{type_name} is not supported")
        elif "glu" in type_name:
            rank0_print("==> Using MLP GLU.")
            modules = nn.Sequential(
                MlpGLU(in_hidden_size=ve_hidden_size, out_hidden_size=llm_hidden_size)
            )
        else:
            raise ValueError(f"Unknown projector type: {type_name}")
        self.layers = modules

    def forward(
        self,
        x_or_tuple: Union[
            Tuple[torch.Tensor, torch.Tensor], torch.Tensor, List[torch.Tensor]
        ],
    ):
        """Apply the connector.

        A ``list`` input is mapped element-wise and a list is returned; any
        other input is passed to the connector as a whole.
        """
        x = x_or_tuple
        if isinstance(x, list):
            return [self.layers(item) for item in x]
        return self.layers(x)
+192
View File
@@ -0,0 +1,192 @@
import torch
from typing import List, Optional, Tuple, Union
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from transformers import (
PreTrainedModel,
GenerationMixin,
AutoModelForCausalLM,
AutoModel,
AutoTokenizer,
)
from torch import nn
from namo.models.meta_vision import NamoMetaVisionForCausalLM
from namo.utils.utils import rank0_print, load_conn_weights
from namo.models.configuration_namo import NamoConfig
from namo.models.modal_adapt.conn_ve import ConnVE
from namo.models.vision.ve import get_ve
from namo.utils.hf_utils import auto_load_model, auto_load_tokenizer, SimpleForCausalLM
class NamoPretrainedModel(PreTrainedModel):
    """Common ``PreTrainedModel`` base for the Namo model classes.

    It only declares class-level settings; all behaviour is inherited.
    """

    # Configuration class associated with Namo models.
    config_class = NamoConfig
    # Prefix used for the base model attribute (``NamoForCausalLM`` stores its
    # ``NamoModel`` under ``self.namo``).
    base_model_prefix = "namo"
    # Capability flags — presumably consumed by transformers' PreTrainedModel
    # machinery (SDPA / FlashAttention-2 / cache classes); confirm against the
    # installed transformers version.
    _supports_sdpa = True
    _supports_flash_attn_2 = True
    _supports_cache_class = True
    supports_gradient_checkpointing = True
class NamoModel(NamoPretrainedModel):
    """Container for the language model, tokenizer, vision encoder and connector."""

    _no_split_modules = []

    def __init__(self, config: NamoConfig):
        super().__init__(config)
        self.config = config
        # Language model and tokenizer come from the project's auto-loaders.
        self.llm = auto_load_model(config.text_config)
        self.tokenizer = auto_load_tokenizer(config)
        # The vision encoder is always trained, so it is loaded eagerly.
        self.ve = get_ve(config, delay_load=False)
        # Connector from vision-encoder features to the language model.
        self.conn_ve_llm = ConnVE(config)

    def get_llm(self):
        """Return the wrapped language model."""
        return self.llm

    def get_vision_tower(self):
        """Return the vision encoder, or ``None`` if ``ve`` is not set."""
        return getattr(self, "ve", None)

    def load_conn_ve_llm_weights(self, pretrain_mm_mlp_adapter):
        """Load connector weights from ``pretrain_mm_mlp_adapter`` via ``load_conn_weights``."""
        load_conn_weights(pretrain_mm_mlp_adapter, self.conn_ve_llm, "conn_ve_llm")
class NamoForCausalLM(
    NamoPretrainedModel, NamoMetaVisionForCausalLM, SimpleForCausalLM
):
    """Causal-LM wrapper that owns a ``NamoModel`` and an ``lm_head``.

    ``forward`` and ``generate`` first turn token ids plus optional image
    inputs into input embeddings through
    ``prepare_inputs_labels_for_multimodal`` (defined outside this class,
    presumably on ``NamoMetaVisionForCausalLM``) and then delegate to the
    parent implementation.
    """

    _no_split_modules = []

    def __init__(self, config):
        # Two explicit initialiser calls: first the Namo base with the full
        # config, then whatever class follows SimpleForCausalLM in the MRO,
        # with only the text config. NOTE(review): the second call skips
        # SimpleForCausalLM.__init__ itself — confirm this is intended.
        NamoPretrainedModel.__init__(self, config)
        super(SimpleForCausalLM, self).__init__(config.text_config)
        self.config = config
        # The LLM lives inside ``self.namo`` and is exposed through the
        # ``model`` property below instead of a second attribute, to avoid a
        # duplicated shared tensor.
        self.namo = NamoModel(config)
        self.vocab_size = config.text_config.vocab_size
        self.lm_head = nn.Linear(
            config.text_config.hidden_size, config.text_config.vocab_size, bias=False
        )
        self.post_init()

    def get_namo(self):
        """Return the wrapped ``NamoModel``."""
        return self.namo

    @property
    def model(self):
        """The language model held by the wrapped ``NamoModel``."""
        return self.namo.llm

    def get_model(self):
        """Return the same object as the ``model`` property."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        pixel_attention_mask: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        cache_position=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the language model on text plus optional image inputs.

        When ``inputs_embeds`` is not supplied, the text and image inputs are
        merged by ``prepare_inputs_labels_for_multimodal``, which returns the
        updated ids, positions, mask, cache, embeddings and labels. The image
        arguments and ``cache_position`` are not forwarded to the parent
        ``forward``.
        """
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                pixel_values,
                image_sizes,
                pixel_attention_mask,
            )
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        pixel_attention_mask: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate from token ids and optional image inputs.

        Generation is always driven by ``inputs_embeds``: with images the
        embeddings come from ``prepare_inputs_labels_for_multimodal``,
        otherwise from the language model's ``embed_tokens``. ``input_ids``
        itself is not passed on to the parent ``generate``.

        Raises:
            NotImplementedError: if the caller passes ``inputs_embeds``.
        """
        # Taken out of kwargs so they are passed explicitly (and only once).
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")
        if pixel_values is not None:
            # Only positions, mask and embeddings are used from the result;
            # ``inputs`` is assigned but never read afterwards.
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                None,
                None,
                pixel_values,
                image_sizes=image_sizes,
                pixel_attention_mask=pixel_attention_mask,
            )
        else:
            # Text-only path: embed the ids directly.
            inputs_embeds = self.get_model().embed_tokens(input_ids)
        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
        """Build per-step generation inputs, carrying the image arguments along.

        The image-related kwargs are removed before calling the parent
        implementation and added back to its result when they are not ``None``.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        pixel_attention_mask = kwargs.pop("pixel_attention_mask", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
        if pixel_values is not None:
            inputs["pixel_values"] = pixel_values
        if image_sizes is not None:
            inputs["image_sizes"] = image_sizes
        if pixel_attention_mask is not None:
            inputs["pixel_attention_mask"] = pixel_attention_mask
        return inputs
+12
View File
@@ -0,0 +1,12 @@
# Model constants shared across the package.

# Sentinel indices. All are negative, so presumably they cannot collide with
# real vocabulary ids — confirm against the code that consumes them.
# NOTE(review): -100 looks like the usual "ignore this label" value — confirm.
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
AUDIO_TOKEN_INDEX = -300

# Placeholder strings for non-text inputs.
DEFAULT_IMAGE_TOKEN = "<image>"
# NOTE(review): name lacks the underscore used by the sibling constants
# (DEFAULT_AUDIO_TOKEN would be consistent); kept as-is because it may be
# imported elsewhere.
DEFAULT_AUDIOTOKEN = "<audio>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"
DEFAULT_VIDEO_TOKEN = "<video>"
+9
View File
@@ -0,0 +1,9 @@
from transformers import AutoConfig, AutoModel
from namo.models.vision.aimv2.configuration_aimv2 import AIMv2Config
from namo.models.vision.aimv2.modeling_aimv2 import AIMv2Model
from namo.models.vision.aimv2.modeling_aimv2_native import (
AIMv2Model as AIMv2ModelNative,
)
from namo.processor import *
# Import-time side effect: register the "aimv2" model type with transformers'
# AutoConfig so that it maps to AIMv2Config.
AutoConfig.register("aimv2", AIMv2Config)
@@ -0,0 +1,62 @@
from typing import Any
from transformers.configuration_utils import PretrainedConfig
__all__ = ["AIMv2Config"]
class AIMv2Config(PretrainedConfig):
    """Configuration for an [`AIMv2Model`].

    Instantiating a configuration with the defaults yields a configuration
    similar to that of
    [apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224).

    Args:
        hidden_size: Dimension of the hidden representations.
        intermediate_size: Dimension of the SwiGLU representations.
        num_hidden_layers: Number of hidden layers in the Transformer.
        num_attention_heads: Number of attention heads for each attention
            layer in the Transformer.
        num_channels: Number of input channels.
        image_size: Image size.
        patch_size: Patch size.
        rms_norm_eps: Epsilon value used for the RMS normalization layer.
        attention_dropout: Dropout ratio for attention probabilities.
        projection_dropout: Dropout ratio for the projection layer after the
            attention.
        qkv_bias: Whether to add a bias to the queries, keys and values.
        use_bias: Whether to add a bias in the feed-forward and projection
            layers.
        kwargs: Keyword arguments forwarded to [`PretrainedConfig`].
    """

    # Model-type identifier for this configuration.
    model_type: str = "aimv2"

    def __init__(
        self,
        hidden_size: int = 1024,
        intermediate_size: int = 2816,
        num_hidden_layers: int = 24,
        num_attention_heads: int = 8,
        num_channels: int = 3,
        image_size: int = 224,
        patch_size: int = 14,
        rms_norm_eps: float = 1e-5,
        attention_dropout: float = 0.0,
        projection_dropout: float = 0.0,
        qkv_bias: bool = False,
        use_bias: bool = False,
        **kwargs: Any,
    ):
        # Base-class options are handled by PretrainedConfig; every
        # model-specific argument is then stored unchanged as an attribute.
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.attention_dropout = attention_dropout
        self.rms_norm_eps = rms_norm_eps
        self.projection_dropout = projection_dropout
        self.qkv_bias = qkv_bias
        self.use_bias = use_bias
+191
View File
@@ -0,0 +1,191 @@
from typing import Optional, Tuple, Union
import torch
from .configuration_aimv2 import AIMv2Config
from torch import nn
from torch.nn import functional as F
from transformers.modeling_outputs import BaseModelOutputWithNoAttention
from transformers.modeling_utils import PreTrainedModel
__all__ = ["AIMv2Model"]
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Divide ``x`` by the RMS of its last dimension (``eps`` added under the root)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise in float32, cast back to the input dtype, then apply the scale."""
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight

    def extra_repr(self) -> str:
        return f"{tuple(self.weight.shape)}, eps={self.eps}"
class AIMv2SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward block: ``fc2(silu(fc1(x)) * fc3(x))``."""

    def __init__(self, config: AIMv2Config):
        super().__init__()
        width = config.hidden_size
        inner = config.intermediate_size
        use_bias = config.use_bias

        self.fc1 = nn.Linear(width, inner, bias=use_bias)  # gate branch
        self.fc2 = nn.Linear(inner, width, bias=use_bias)  # output projection
        self.fc3 = nn.Linear(width, inner, bias=use_bias)  # value branch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.fc1(x))
        return self.fc2(gate * self.fc3(x))
class AIMv2PatchEmbed(nn.Module):
    """Turn an image into a sequence of RMS-normalised patch embeddings."""

    def __init__(self, config: AIMv2Config):
        super().__init__()
        # Kernel equals stride, so patches do not overlap.
        patch = (config.patch_size, config.patch_size)
        self.proj = nn.Conv2d(
            config.num_channels,
            config.hidden_size,
            kernel_size=patch,
            stride=patch,
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feature_map = self.proj(x)
        # Flatten the spatial grid and move channels last.
        tokens = feature_map.flatten(2).transpose(1, 2)
        return self.norm(tokens)
class AIMv2ViTPreprocessor(nn.Module):
    """Patchify an image and add a learned, zero-initialised position embedding."""

    def __init__(self, config: AIMv2Config):
        super().__init__()
        grid_side = config.image_size // config.patch_size

        self.patchifier = AIMv2PatchEmbed(config)
        self.pos_embed = nn.Parameter(
            torch.zeros((1, grid_side * grid_side, config.hidden_size))
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tokens = self.patchifier(x)
        _, num_tokens, _ = tokens.shape
        positions = self.pos_embed.to(tokens.device)
        # Only the first ``num_tokens`` positions are used.
        return tokens + positions[:, :num_tokens]
class AIMv2Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection."""

    def __init__(self, config: AIMv2Config):
        super().__init__()
        width = config.hidden_size

        self.num_heads = config.num_attention_heads
        self.qkv = nn.Linear(width, width * 3, bias=config.qkv_bias)
        # NOTE: ``attn_drop`` is constructed but not applied in ``forward``.
        self.attn_drop = nn.Dropout(config.attention_dropout)
        self.proj = nn.Linear(width, width, bias=config.use_bias)
        self.proj_drop = nn.Dropout(config.projection_dropout)

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        batch, seq_len, width = x.shape
        head_dim = width // self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, head_dim), then split into q, k, v.
        fused = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        attended = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        # Merge the heads back: (B, heads, N, head_dim) -> (B, N, C).
        merged = attended.transpose(1, 2).contiguous().reshape(batch, seq_len, width)
        return self.proj_drop(self.proj(merged))
class AIMv2Block(nn.Module):
    """Pre-norm Transformer block: attention and SwiGLU MLP, each with a residual."""

    def __init__(self, config: AIMv2Config):
        super().__init__()
        eps = config.rms_norm_eps
        self.attn = AIMv2Attention(config)
        self.norm_1 = RMSNorm(config.hidden_size, eps=eps)
        self.mlp = AIMv2SwiGLUFFN(config)
        self.norm_2 = RMSNorm(config.hidden_size, eps=eps)

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        attn_out = self.attn(self.norm_1(x), mask)
        x = x + attn_out
        return x + self.mlp(self.norm_2(x))
class AIMv2Transformer(nn.Module):
    """Stack of ``AIMv2Block`` layers followed by a final RMSNorm."""

    def __init__(self, config: AIMv2Config):
        super().__init__()
        layers = [AIMv2Block(config) for _ in range(config.num_hidden_layers)]
        self.blocks = nn.ModuleList(layers)
        self.post_trunk_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        tokens: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]:
        """Return ``(normalised_tokens, hidden_states)``.

        ``hidden_states`` is a tuple with the output of every block (taken
        before the final norm) when ``output_hidden_states`` is truthy,
        otherwise ``None``.
        """
        collected = [] if output_hidden_states else None
        for block in self.blocks:
            tokens = block(tokens, mask)
            if collected is not None:
                collected.append(tokens)

        hidden_states = None if collected is None else tuple(collected)
        return self.post_trunk_norm(tokens), hidden_states
class AIMv2PretrainedModel(PreTrainedModel):
    """``PreTrainedModel`` base for AIMv2; declares class-level settings only."""

    # Configuration class associated with AIMv2 models.
    config_class = AIMv2Config
    base_model_prefix = "aimv2"
    # Name of the primary input argument of ``AIMv2Model.forward``.
    main_input_name = "pixel_values"
    # Module class names listed here are presumably kept on a single device
    # when the model is sharded — confirm against transformers' device-map docs.
    _no_split_modules = ["AIMv2ViTPreprocessor", "AIMv2Block"]
    _supports_sdpa = True
class AIMv2Model(AIMv2PretrainedModel):
    """AIMv2 vision encoder: patch preprocessor followed by the Transformer trunk."""

    def __init__(self, config: AIMv2Config):
        super().__init__(config)
        self.preprocessor = AIMv2ViTPreprocessor(config)
        self.trunk = AIMv2Transformer(config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[
        Tuple[torch.Tensor],
        Tuple[torch.Tensor, Tuple[torch.Tensor, ...]],
        BaseModelOutputWithNoAttention,
    ]:
        """Encode ``pixel_values``.

        ``output_hidden_states`` and ``return_dict`` fall back to the config
        when ``None``.  Returns a ``BaseModelOutputWithNoAttention`` when
        ``return_dict`` is truthy, otherwise a tuple ``(last_hidden_state,)``
        with the per-block hidden states appended when requested.
        """
        want_hidden = (
            self.config.output_hidden_states
            if output_hidden_states is None
            else output_hidden_states
        )
        as_dict = self.config.use_return_dict if return_dict is None else return_dict

        tokens = self.preprocessor(pixel_values)
        tokens, hidden_states = self.trunk(
            tokens, mask, output_hidden_states=want_hidden
        )

        if as_dict:
            return BaseModelOutputWithNoAttention(
                last_hidden_state=tokens,
                hidden_states=hidden_states,
            )
        if want_hidden:
            return (tokens, hidden_states)
        return (tokens,)
@@ -0,0 +1,225 @@
import math
from typing import Optional, Tuple, Union
import torch
from .configuration_aimv2 import AIMv2Config
from torch import nn
from torch.nn import functional as F
from transformers.modeling_outputs import BaseModelOutputWithNoAttention
from transformers.modeling_utils import PreTrainedModel
__all__ = ["AIMv2Model"]
def _get_1d_sincos_pos_embed_from_grid(
embed_dim: int, pos: torch.Tensor
) -> torch.Tensor:
omega = torch.arange(embed_dim // 2).float()
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D / 2,)
pos = pos.reshape(-1) # (M,)
out = pos[:, None] * omega[None, :] # (M, D / 2), outer product
emb_sin, emb_cos = torch.sin(out), torch.cos(out) # (M, D / 2)
emb = torch.concatenate([emb_sin, emb_cos], dim=1) # (M, D)
return emb
def get_sincos_pos_embed(h: int, w: int, embed_dim: int) -> torch.Tensor:
    """Return a fixed 2-D sine/cosine position embedding of shape ``(h * w, embed_dim)``.

    Each of the two grid coordinates is encoded with ``embed_dim // 2``
    features and the two encodings are concatenated.  ``embed_dim`` must be
    even.
    """
    assert embed_dim % 2 == 0, embed_dim

    rows = torch.arange(h).float()
    cols = torch.arange(w).float()
    # "xy" indexing yields two (h, w) coordinate planes.
    planes = torch.meshgrid(cols, rows, indexing="xy")
    grid = torch.stack(planes, dim=0).reshape([2, 1, h, w])

    half_dim = embed_dim // 2
    emb_first = _get_1d_sincos_pos_embed_from_grid(half_dim, grid[0])
    emb_second = _get_1d_sincos_pos_embed_from_grid(half_dim, grid[1])
    return torch.concatenate([emb_first, emb_second], dim=1)  # (H * W, D)
class RMSNorm(nn.Module):
    """RMS normalisation of the last dimension, scaled by a learned weight."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by the reciprocal root of its mean square plus ``eps``."""
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise in float32, cast back to ``x``'s dtype, then apply the weight."""
        return self._norm(x.float()).type_as(x) * self.weight

    def extra_repr(self) -> str:
        return f"{tuple(self.weight.shape)}, eps={self.eps}"
class AIMv2SwiGLUFFN(nn.Module):
    """Feed-forward block computing ``fc2(silu(fc1(x)) * fc3(x))``."""

    def __init__(self, config: AIMv2Config):
        super().__init__()
        dim, hidden_dim = config.hidden_size, config.intermediate_size
        with_bias = config.use_bias

        self.fc1 = nn.Linear(dim, hidden_dim, bias=with_bias)
        self.fc2 = nn.Linear(hidden_dim, dim, bias=with_bias)
        self.fc3 = nn.Linear(dim, hidden_dim, bias=with_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # fc1 is the SiLU-activated gate, fc3 the linear value branch.
        gated = F.silu(self.fc1(x)) * self.fc3(x)
        return self.fc2(gated)
class AIMv2PatchEmbed(nn.Module):
    """Non-overlapping patch embedding followed by RMSNorm."""

    def __init__(self, config: AIMv2Config):
        super().__init__()
        patch_shape = (config.patch_size, config.patch_size)
        self.proj = nn.Conv2d(
            config.num_channels,
            config.hidden_size,
            kernel_size=patch_shape,
            stride=patch_shape,
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Conv output -> flatten the spatial dims -> channels last.
        embedded = self.proj(x).flatten(2)
        return self.norm(embedded.transpose(1, 2))
class AIMv2ViTPreprocessor(nn.Module):
    """Patchify an image and add a sine/cosine position embedding.

    The embedding is recomputed on every call from the input's height and
    width, so it is not tied to ``config.image_size``.
    """

    def __init__(self, config: AIMv2Config):
        super().__init__()
        self.patch_h = config.patch_size
        self.patch_w = config.patch_size
        self.embed_dim = config.hidden_size
        self.patchifier = AIMv2PatchEmbed(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _batch, _channels, height, width = x.shape
        tokens = self.patchifier(x)
        pos_embed = get_sincos_pos_embed(
            height // self.patch_h,
            width // self.patch_w,
            embed_dim=self.embed_dim,
        )
        # The embedding is built on the default device; move it to the tokens'.
        return tokens + pos_embed.to(tokens.device)
class AIMv2Attention(nn.Module):
    """Multi-head self-attention built on ``F.scaled_dot_product_attention``."""

    def __init__(self, config: AIMv2Config):
        super().__init__()
        dim = config.hidden_size

        self.num_heads = config.num_attention_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=config.qkv_bias)
        # NOTE: ``attn_drop`` is created here but never used in ``forward``.
        self.attn_drop = nn.Dropout(config.attention_dropout)
        self.proj = nn.Linear(dim, dim, bias=config.use_bias)
        self.proj_drop = nn.Dropout(config.projection_dropout)

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        B, N, C = x.shape
        per_head = C // self.num_heads

        # Split the fused projection into q, k, v of shape (B, heads, N, per_head).
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, per_head)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        out = out.transpose(1, 2).contiguous().reshape(B, N, C)
        out = self.proj(out)
        return self.proj_drop(out)
class AIMv2Block(nn.Module):
    """Transformer block with pre-normalised attention and MLP residual branches."""

    def __init__(self, config: AIMv2Config):
        super().__init__()
        hidden, eps = config.hidden_size, config.rms_norm_eps
        self.attn = AIMv2Attention(config)
        self.norm_1 = RMSNorm(hidden, eps=eps)
        self.mlp = AIMv2SwiGLUFFN(config)
        self.norm_2 = RMSNorm(hidden, eps=eps)

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        after_attn = x + self.attn(self.norm_1(x), mask)
        return after_attn + self.mlp(self.norm_2(after_attn))
class AIMv2Transformer(nn.Module):
    """Sequence of ``AIMv2Block`` layers with a trailing RMSNorm."""

    def __init__(self, config: AIMv2Config):
        super().__init__()
        self.blocks = nn.ModuleList(
            [AIMv2Block(config) for _ in range(config.num_hidden_layers)]
        )
        self.post_trunk_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        tokens: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]:
        """Run all blocks and return ``(normalised_tokens, hidden_states)``.

        ``hidden_states`` holds each block's output (before the final norm)
        as a tuple when ``output_hidden_states`` is truthy, else ``None``.
        """
        per_block = [] if output_hidden_states else None
        for layer in self.blocks:
            tokens = layer(tokens, mask)
            if per_block is not None:
                per_block.append(tokens)

        tokens = self.post_trunk_norm(tokens)
        if per_block is None:
            return tokens, None
        return tokens, tuple(per_block)
class AIMv2PretrainedModel(PreTrainedModel):
    """``PreTrainedModel`` base for the native-resolution AIMv2 model (settings only)."""

    # Configuration class associated with AIMv2 models.
    config_class = AIMv2Config
    base_model_prefix = "aimv2"
    # Name of the primary input argument of ``AIMv2Model.forward``.
    main_input_name = "pixel_values"
    # Presumably the module classes that must not be split across devices —
    # confirm against transformers' device-map documentation.
    _no_split_modules = ["AIMv2ViTPreprocessor", "AIMv2Block"]
    _supports_sdpa = True
class AIMv2Model(AIMv2PretrainedModel):
    """AIMv2 vision encoder using sine/cosine position embeddings.

    Besides the usual outputs, the second-to-last hidden state is reshaped
    from ``(B, T, H)`` into a ``(B, grid_h, grid_w, H)`` patch grid when
    hidden states are requested.
    """

    def __init__(self, config: AIMv2Config):
        super().__init__(config)
        self.preprocessor = AIMv2ViTPreprocessor(config)
        self.trunk = AIMv2Transformer(config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[
        Tuple[torch.Tensor],
        Tuple[torch.Tensor, Tuple[torch.Tensor, ...]],
        BaseModelOutputWithNoAttention,
    ]:
        """Encode ``pixel_values``.

        Args:
            pixel_values: image batch; its last two dims are height and width.
            mask: optional attention mask forwarded to every block.
            output_hidden_states: collect per-block outputs; defaults to the config.
            return_dict: return a ``BaseModelOutputWithNoAttention`` instead of
                a tuple; defaults to the config.

        Returns:
            The final normalised tokens and, when requested, a list of hidden
            states whose second-to-last entry is reshaped to the patch grid.
            ``hidden_states`` is ``None`` when not requested.
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        x = self.preprocessor(pixel_values)
        # The sin/cos position embedding is float32 and promotes the tokens;
        # cast back so bf16/fp16 inputs keep their dtype.
        x = x.to(pixel_values.dtype)
        x, hidden_states = self.trunk(
            x, mask, output_hidden_states=output_hidden_states
        )

        # The trunk returns None when hidden states are not requested; the
        # unconditional ``list(hidden_states)`` raised TypeError in that case.
        if hidden_states is not None:
            hidden_states = list(hidden_states)
            # A second-to-last entry only exists with at least two blocks.
            if len(hidden_states) >= 2:
                height, width = pixel_values.shape[-2:]
                grid_h = height // self.preprocessor.patch_h
                grid_w = width // self.preprocessor.patch_w
                penultimate = hidden_states[-2]
                batch, _, channels = penultimate.shape
                hidden_states[-2] = penultimate.reshape(batch, grid_h, grid_w, channels)

        if not return_dict:
            res = (x,)
            res += (hidden_states,) if output_hidden_states else ()
            return res
        return BaseModelOutputWithNoAttention(
            last_hidden_state=x,
            hidden_states=hidden_states,
        )
+739
View File
@@ -0,0 +1,739 @@
# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Florence-2 model."""
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import math
import torch
import torch.utils.checkpoint
from torch import nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch.nn import CrossEntropyLoss
from collections import OrderedDict
from einops import rearrange
from timm.models.layers import DropPath, trunc_normal_
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
logging,
replace_return_docstrings,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
)
from .configuration_florence2 import Florence2Config
from .configuration_florence2 import Florence2LanguageConfig
from .configuration_florence2 import Florence2VisionConfig
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
if is_flash_attn_2_available():
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Florence2Config"
class LearnedAbsolutePositionEmbedding2D(nn.Module):
    """
    Learned 2-D positional embeddings for grids of up to ``num_pos`` rows and columns.

    Row embeddings use ``embedding_dim // 2`` features and column embeddings
    the remainder, so each position gets ``embedding_dim`` features in total.
    """

    def __init__(self, embedding_dim=256, num_pos=50):
        super().__init__()
        row_dim = embedding_dim // 2
        self.row_embeddings = nn.Embedding(num_pos, row_dim)
        self.column_embeddings = nn.Embedding(num_pos, embedding_dim - row_dim)

    def forward(self, pixel_values):
        """
        pixel_values: (batch_size, height, width, num_channels)
        returns: (batch_size, height, width, embedding_dim), with the column
        embedding first and the row embedding after it along the last dim.
        """
        if pixel_values.dim() != 4:
            raise ValueError("pixel_values must be a 4D tensor")

        batch_size, height, width = pixel_values.shape[:3]
        device = pixel_values.device
        col_emb = self.column_embeddings(torch.arange(width, device=device))
        row_emb = self.row_embeddings(torch.arange(height, device=device))

        # (height, width, embedding_dim): tile columns over rows and rows over columns.
        grid = torch.cat(
            [
                col_emb.unsqueeze(0).repeat(height, 1, 1),
                row_emb.unsqueeze(1).repeat(1, width, 1),
            ],
            dim=-1,
        )
        # Channel-first, add and tile the batch dim, then back to channel-last.
        tiled = grid.permute(2, 0, 1).unsqueeze(0).repeat(batch_size, 1, 1, 1)
        return tiled.permute(0, 2, 3, 1)
class PositionalEmbeddingCosine1D(nn.Module):
    """
    Fixed sinusoidal positional encoding, precomputed up to ``max_seq_len``.

    It follows closely the encoder from the link below:
    https://pytorch.org/tutorials/beginner/translation_transformer.html

    Args:
        embed_dim: The dimension of the embeddings.
        max_seq_len: The maximum length to precompute the positional encodings.
    """

    def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        # One inverse frequency per sin/cos pair: 10000 ** (-2i / embed_dim).
        log_base = math.log(10000)
        inv_freq = torch.exp(
            -log_base * torch.arange(0, self.embed_dim, 2) / self.embed_dim
        )
        # Row p holds position p multiplied by every inverse frequency.
        positions = torch.arange(0, self.max_seq_len).reshape(self.max_seq_len, 1)
        angles = positions * inv_freq

        table = torch.zeros((self.max_seq_len, self.embed_dim))
        table[:, 0::2] = torch.sin(angles)  # even feature indices
        table[:, 1::2] = torch.cos(angles)  # odd feature indices
        # Stored as a buffer: follows the module across devices, not trainable.
        self.register_buffer("pos_idx_to_embed", table)

    def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
        """
        Args:
            seq_embeds: The sequence embeddings in order. Allowed size:
                1. [T, D], where T is the length of the sequence, and D is the
                frame embedding dimension.
                2. [B, T, D], where B is the batch size and T and D are the
                same as above.

        Returns the positional embeddings for the first T positions, shaped
        [T, D] for a 2-D input and [1, T, D] for a 3-D input.
        """
        rank = seq_embeds.dim()
        assert 2 <= rank <= 3
        seq_len = seq_embeds.size(-2)
        assert seq_len <= self.max_seq_len

        pos_embeds = self.pos_idx_to_embed[0:seq_len, :]
        if rank == 3:
            # Add a leading dim of size 1 for batched input.
            pos_embeds = pos_embeds.view((1, pos_embeds.size(0), pos_embeds.size(1)))
        return pos_embeds
class LearnedAbsolutePositionEmbedding1D(nn.Module):
    """
    Learnable absolute positional embeddings for 1D sequences.

    Args:
        embedding_dim: The dimension of the embeddings.
        num_pos: The number of positions in the embedding table, i.e. the
            maximum supported sequence length.
    """

    def __init__(self, embedding_dim: int = 512, num_pos: int = 1024) -> None:
        super(LearnedAbsolutePositionEmbedding1D, self).__init__()
        self.embeddings = nn.Embedding(num_pos, embedding_dim)
        self.num_pos = num_pos

    def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
        """
        Look up the positional embeddings for positions ``0..T-1``.

        Only the shape and device of ``seq_embeds`` are used.

        Args:
            seq_embeds: The sequence embeddings in order. Allowed size:
                1. [T, D], where T is the length of the sequence, and D is the
                   frame embedding dimension.
                2. [B, T, D], where B is the batch size and T and D are the
                   same as above.

        Returns:
            The positional embeddings (not added to the input), shaped
            [T, D] for a 2D input or [1, T, D] for a 3D input.

        Raises:
            AssertionError: If the input is not 2D/3D or T exceeds ``num_pos``.
        """
        shape_len = len(seq_embeds.shape)
        assert 2 <= shape_len <= 3
        len_seq = seq_embeds.size(-2)
        assert len_seq <= self.num_pos
        # Build the indices directly on the input's device instead of creating
        # them on CPU and copying.
        position_ids = torch.arange(len_seq, device=seq_embeds.device)
        # [T, D]
        pos_embeds = self.embeddings(position_ids)
        # Adapt the looked-up positional embeddings to the input rank.
        if shape_len == 3:
            pos_embeds = pos_embeds.view((1, pos_embeds.size(0), pos_embeds.size(1)))
        return pos_embeds
class MySequential(nn.Sequential):
    """Sequential container whose stages may accept and return several values."""

    def forward(self, *inputs):
        """Run every child in order, unpacking plain-tuple results as arguments.

        The positional arguments arrive as a tuple, so the first child is
        always called with them unpacked. After that, a child's result is
        unpacked for the next child only when it is exactly a ``tuple``;
        anything else is passed through as a single argument. The last
        child's result is returned as is (the argument tuple itself when the
        container is empty).
        """
        state = inputs
        for stage in self._modules.values():
            # Exact type check on purpose: tuple subclasses are not unpacked.
            state = stage(*state) if type(state) is tuple else stage(state)
        return state
class PreNorm(nn.Module):
    """Residual wrapper: ``x + drop_path(fn(norm(x)))``.

    Args:
        norm: Module applied to the input before ``fn``, or ``None`` to skip
            normalization.
        fn: Module called as ``fn(x, *args, **kwargs)``; it must return a
            ``(tensor, size)`` pair.
        drop_path: Optional module applied to the output of ``fn`` before the
            residual addition, or ``None`` to skip it.
    """

    def __init__(self, norm, fn, drop_path=None):
        super().__init__()
        self.norm = norm
        self.fn = fn
        self.drop_path = drop_path

    def forward(self, x, *args, **kwargs):
        """Apply the wrapped module with a residual connection.

        Returns:
            A ``(tensor, size)`` pair where ``tensor`` is the input plus the
            (optionally drop-path'ed) output of ``fn`` and ``size`` is
            whatever ``fn`` returned as its second element.
        """
        shortcut = x
        # Identity comparison with None instead of `!=` / truthiness, so a
        # module that defines __len__ or __eq__ cannot change the branch.
        if self.norm is not None:
            x = self.norm(x)
        x, size = self.fn(x, *args, **kwargs)
        if self.drop_path is not None:
            x = self.drop_path(x)
        x = shortcut + x
        return x, size
class Mlp(nn.Module):
    """Two-layer feed-forward network: ``fc1 -> act -> fc2``.

    Args:
        in_features: Size of the input feature dimension.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: Size of the output feature dimension; falls back to
            ``in_features`` when falsy.
        act_layer: Zero-argument callable building the activation module.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
    ):
        super().__init__()
        hidden_dim = hidden_features or in_features
        out_dim = out_features or in_features
        # Named layers keep the parameter names "net.fc1.*" / "net.fc2.*".
        layers = OrderedDict()
        layers["fc1"] = nn.Linear(in_features, hidden_dim)
        layers["act"] = act_layer()
        layers["fc2"] = nn.Linear(hidden_dim, out_dim)
        self.net = nn.Sequential(layers)

    def forward(self, x, size):
        """Apply the network to ``x``; ``size`` is returned unchanged."""
        projected = self.net(x)
        return projected, size
class DepthWiseConv2d(nn.Module):
    """Depth-wise 2D convolution applied to a flattened token sequence.

    Args:
        dim_in: Number of channels; also used as the output channel count and
            as the convolution group count.
        kernel_size: Convolution kernel size.
        padding: Convolution padding.
        stride: Convolution stride.
        bias: Whether the convolution has a bias term.
    """

    def __init__(
        self,
        dim_in,
        kernel_size,
        padding,
        stride,
        bias=True,
    ):
        super().__init__()
        conv_kwargs = dict(
            kernel_size=kernel_size,
            padding=padding,
            groups=dim_in,
            stride=stride,
            bias=bias,
        )
        self.dw = nn.Conv2d(dim_in, dim_in, **conv_kwargs)

    def forward(self, x, size):
        """Convolve tokens laid out on an ``H x W`` grid.

        Args:
            x: Tensor of shape ``[B, N, C]`` with ``N == H * W``.
            size: ``(H, W)`` grid size of the tokens.

        Returns:
            ``(tokens, (H_out, W_out))`` where ``tokens`` has shape
            ``[B, H_out * W_out, C]``.
        """
        batch, num_tokens, channels = x.shape
        height, width = size
        assert num_tokens == height * width
        # [B, N, C] -> [B, C, H, W] for the convolution.
        feature_map = x.transpose(1, 2).view(batch, channels, height, width)
        feature_map = self.dw(feature_map)
        out_size = (feature_map.size(-2), feature_map.size(-1))
        # Back to token layout: [B, C, H', W'] -> [B, H' * W', C].
        tokens = feature_map.flatten(2).transpose(1, 2)
        return tokens, out_size
class ConvEmbed(nn.Module):
    """Image to Patch Embedding

    Projects an image (``[B, C, H, W]``) or a token sequence (``[B, H*W, C]``)
    with a strided convolution and returns tokens plus the new grid size.

    Args:
        patch_size: Convolution kernel size.
        in_chans: Number of input channels.
        embed_dim: Number of output channels.
        stride: Convolution stride.
        padding: Convolution padding.
        norm_layer: Callable building the norm from a feature size, or a falsy
            value for no normalization.
        pre_norm: If True the norm is sized for ``in_chans`` and applied
            before the projection (token input only); otherwise it is sized
            for ``embed_dim`` and applied after the projection.
    """

    def __init__(
        self,
        patch_size=7,
        in_chans=3,
        embed_dim=64,
        stride=4,
        padding=2,
        norm_layer=None,
        pre_norm=True,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=padding
        )
        if norm_layer:
            # The norm sees the input channels when it runs before the conv.
            self.norm = norm_layer(in_chans if pre_norm else embed_dim)
        else:
            self.norm = None
        self.pre_norm = pre_norm

    def forward(self, x, size):
        """Embed ``x`` and return ``(tokens, (H_out, W_out))``.

        Args:
            x: Either a 4D feature map or a 3D token tensor laid out on the
                ``size`` grid.
            size: ``(H, W)`` grid of the incoming tokens; only used to
                un-flatten a 3D input.
        """
        H, W = size
        is_token_input = len(x.size()) == 3
        if is_token_input:
            if self.norm and self.pre_norm:
                x = self.norm(x)
            x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
        x = self.proj(x)
        _, _, H, W = x.shape
        x = rearrange(x, "b c h w -> b (h w) c")
        if self.norm and not self.pre_norm:
            x = self.norm(x)
        return x, (H, W)
class ChannelAttention(nn.Module):
    """Grouped attention computed across channels instead of tokens.

    Args:
        dim: Feature dimension of the tokens.
        groups: Number of channel groups; ``dim`` is split evenly across them.
        qkv_bias: Whether the fused q/k/v projection has a bias.
    """

    def __init__(self, dim, groups=8, qkv_bias=True):
        super().__init__()
        self.groups = groups
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, size):
        """Attend over channels within each group.

        Args:
            x: Tensor of shape ``[B, N, C]``.
            size: Passed through unchanged.

        Returns:
            ``(tensor of shape [B, N, C], size)``.
        """
        B, N, C = x.shape
        group_dim = C // self.groups
        # [B, N, 3C] -> [3, B, groups, N, group_dim]
        fused = self.qkv(x).reshape(B, N, 3, self.groups, group_dim)
        fused = fused.permute(2, 0, 3, 1, 4)
        query, key, value = fused[0], fused[1], fused[2]
        # The query is scaled by the token count, not by the head width.
        query = query * (float(N) ** -0.5)
        # [B, groups, group_dim, group_dim] channel-to-channel scores.
        channel_scores = query.transpose(-1, -2) @ key
        channel_probs = channel_scores.softmax(dim=-1)
        mixed = (channel_probs @ value.transpose(-1, -2)).transpose(-1, -2)
        mixed = mixed.transpose(1, 2).reshape(B, N, C)
        return self.proj(mixed), size
class ChannelBlock(nn.Module):
    """Channel-attention block: optional depth-wise conv, channel attention,
    optional depth-wise conv, then an MLP, each wrapped with a residual.

    Args:
        dim: Token feature dimension.
        groups: Number of channel groups for the attention.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the attention q/k/v projection has a bias.
        drop_path_rate: Rate passed to ``DropPath``; an identity is used when
            it is not greater than 0.
        act_layer: Activation constructor for the MLP.
        norm_layer: Constructor of the norm applied before attention and MLP.
        conv_at_attn: If True, add a depth-wise conv before the attention.
        conv_at_ffn: If True, add a depth-wise conv before the MLP.
    """

    def __init__(
        self,
        dim,
        groups,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_path_rate=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        conv_at_attn=True,
        conv_at_ffn=True,
    ):
        super().__init__()
        if drop_path_rate > 0.0:
            drop_path = DropPath(drop_path_rate)
        else:
            drop_path = nn.Identity()

        # Attribute assignment order is kept: it fixes the sub-module order.
        if conv_at_attn:
            self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))
        else:
            self.conv1 = None

        attention = ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias)
        self.channel_attn = PreNorm(norm_layer(dim), attention, drop_path)

        if conv_at_ffn:
            self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))
        else:
            self.conv2 = None

        feed_forward = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
        )
        # The same drop-path module instance is shared by attention and MLP.
        self.ffn = PreNorm(norm_layer(dim), feed_forward, drop_path)

    def forward(self, x, size):
        """Run the enabled stages in order and return ``(x, size)``."""
        stages = (self.conv1, self.channel_attn, self.conv2, self.ffn)
        for stage in stages:
            # Disabled convolutions are stored as None and skipped.
            if stage:
                x, size = stage(x, size)
        return x, size
def window_partition(x, window_size: int):
    """Split a ``[B, H, W, C]`` map into non-overlapping square windows.

    ``H`` and ``W`` are expected to be multiples of ``window_size``.

    Returns:
        Tensor of shape ``[B * (H // ws) * (W // ws), ws, ws, C]`` with the
        windows ordered by batch, then window row, then window column.
    """
    B, H, W, C = x.shape
    rows, cols = H // window_size, W // window_size
    grid = x.view(B, rows, window_size, cols, window_size, C)
    # Bring the two window-index axes together before flattening them.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, window_size, window_size, C)
def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
    """Reassemble square windows into a ``[B, H, W, C]`` feature map.

    Args:
        windows: Tensor of shape ``[B * (H // ws) * (W // ws), ws, ws, C]``.
        batch_size: ``B``. It is passed in explicitly rather than derived from
            ``windows.shape[0]``, because the derived value is treated as a
            constant and breaks ONNX conversion with dynamic axes.
        window_size: Side length ``ws`` of each window.
        H: Height of the reassembled map.
        W: Width of the reassembled map.
    """
    rows, cols = H // window_size, W // window_size
    grid = windows.view(batch_size, rows, cols, window_size, window_size, -1)
    # Interleave window indices with in-window offsets, then merge them.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(batch_size, H, W, -1)
class WindowAttention(nn.Module):
    """Multi-head self-attention restricted to non-overlapping windows.

    Args:
        dim: Token feature dimension.
        num_heads: Number of attention heads.
        window_size: Side length of the square attention window.
        qkv_bias: Whether the fused q/k/v projection has a bias.
    """

    def __init__(self, dim, num_heads, window_size, qkv_bias=True):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        per_head = dim // num_heads
        self.scale = float(per_head) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, size):
        """Apply windowed attention to tokens laid out on an ``H x W`` grid.

        The grid is zero-padded on the right/bottom up to a multiple of the
        window size, attention runs inside each window, and the padding is
        cropped away again afterwards.

        Args:
            x: Tensor of shape ``[B, H * W, C]``.
            size: ``(H, W)``.

        Returns:
            ``(tensor of shape [B, H * W, C], size)``.
        """
        H, W = size
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        win = self.window_size

        x = x.view(B, H, W, C)
        # Only right/bottom padding is ever needed.
        pad_l = pad_t = 0
        pad_r = (win - W % win) % win
        pad_b = (win - H % win) % win
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # [num_windows * B, win * win, C]
        x = window_partition(x, win)
        x = x.view(-1, win * win, C)

        # W-MSA/SW-MSA
        B_, N, C = x.shape
        fused = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        fused = fused.permute(2, 0, 3, 1, 4)
        query, key, value = fused[0], fused[1], fused[2]
        query = query * self.scale
        scores = self.softmax(query @ key.transpose(-2, -1))
        x = (scores @ value).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)

        # merge windows
        x = x.view(-1, win, win, C)
        x = window_reverse(x, B, win, Hp, Wp)
        if pad_r > 0 or pad_b > 0:
            # Drop the rows/columns that were added as padding.
            x = x[:, :H, :W, :].contiguous()
        return x.view(B, H * W, C), size
class SpatialBlock(nn.Module):
    """Window-attention block: optional depth-wise conv, window attention,
    optional depth-wise conv, then an MLP, each wrapped with a residual.

    Args:
        dim: Token feature dimension.
        num_heads: Number of attention heads.
        window_size: Side length of the attention window.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the attention q/k/v projection has a bias.
        drop_path_rate: Rate passed to ``DropPath``; an identity is used when
            it is not greater than 0.
        act_layer: Activation constructor for the MLP.
        norm_layer: Constructor of the norm applied before attention and MLP.
        conv_at_attn: If True, add a depth-wise conv before the attention.
        conv_at_ffn: If True, add a depth-wise conv before the MLP.
    """

    def __init__(
        self,
        dim,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_path_rate=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        conv_at_attn=True,
        conv_at_ffn=True,
    ):
        super().__init__()
        if drop_path_rate > 0.0:
            drop_path = DropPath(drop_path_rate)
        else:
            drop_path = nn.Identity()

        # Attribute assignment order is kept: it fixes the sub-module order.
        if conv_at_attn:
            self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))
        else:
            self.conv1 = None

        attention = WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias)
        self.window_attn = PreNorm(norm_layer(dim), attention, drop_path)

        if conv_at_ffn:
            self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))
        else:
            self.conv2 = None

        feed_forward = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
        )
        # The same drop-path module instance is shared by attention and MLP.
        self.ffn = PreNorm(norm_layer(dim), feed_forward, drop_path)

    def forward(self, x, size):
        """Run the enabled stages in order and return ``(x, size)``."""
        stages = (self.conv1, self.window_attn, self.conv2, self.ffn)
        for stage in stages:
            # Disabled convolutions are stored as None and skipped.
            if stage:
                x, size = stage(x, size)
        return x, size
class DaViT(nn.Module):
    """DaViT: Dual-Attention Transformer

    Each stage is a ``ConvEmbed`` followed by ``depths[i]`` pairs of
    (``SpatialBlock``, ``ChannelBlock``).

    Args:
        in_chans (int): Number of input image channels. Default: 3.
        num_classes (int): Number of classes for classification head. Default: 1000.
            A value <= 0 replaces the head with an identity.
        depths (tuple(int)): Number of spatial/channel block pairs in each stage. Default: (1, 1, 3, 1).
        patch_size (tuple(int)): Patch size of convolution in different stages. Default: (7, 2, 2, 2).
        patch_stride (tuple(int)): Patch stride of convolution in different stages. Default: (4, 2, 2, 2).
        patch_padding (tuple(int)): Patch padding of convolution in different stages. Default: (3, 0, 0, 0).
        patch_prenorm (tuple(bool)): If True, perform norm before convolution layer. Default: (False, False, False, False).
        embed_dims (tuple(int)): Patch embedding dimension in different stages. Default: (64, 128, 192, 256).
        num_heads (tuple(int)): Number of spatial attention heads in different stages. Default: (3, 6, 12, 24).
        num_groups (tuple(int)): Number of channel groups in different stages. Default: (3, 6, 12, 24).
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True.
        drop_path_rate (float): Stochastic depth rate. Default: 0.1.
        norm_layer (nn.Module): Normalization layer used by the patch embeddings and the
            final norm. Default: nn.LayerNorm. It is not forwarded to the spatial/channel
            blocks, which use their own default.
        enable_checkpoint (bool): If True, enable checkpointing. Default: False.
        conv_at_attn (bool): If True, perform depthwise convolution before attention layer. Default: True.
        conv_at_ffn (bool): If True, perform depthwise convolution before ffn layer. Default: True.
    """

    def __init__(
        self,
        in_chans=3,
        num_classes=1000,
        depths=(1, 1, 3, 1),
        patch_size=(7, 2, 2, 2),
        patch_stride=(4, 2, 2, 2),
        patch_padding=(3, 0, 0, 0),
        patch_prenorm=(False, False, False, False),
        embed_dims=(64, 128, 192, 256),
        num_heads=(3, 6, 12, 24),
        num_groups=(3, 6, 12, 24),
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        enable_checkpoint=False,
        conv_at_attn=True,
        conv_at_ffn=True,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.num_stages = len(self.embed_dims)
        self.enable_checkpoint = enable_checkpoint
        # Every per-stage tuple checked here must have one entry per stage.
        assert self.num_stages == len(self.num_heads) == len(self.num_groups)
        num_stages = len(embed_dims)
        # One drop-path rate per block (two blocks per depth unit), growing
        # linearly from 0 to drop_path_rate across the whole network.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths) * 2)]
        # Index into dpr of the first block of the current stage.
        depth_offset = 0
        convs = []
        blocks = []
        for i in range(num_stages):
            # Patch embedding: the first stage reads the image channels, later
            # stages read the previous stage's embedding dimension.
            conv_embed = ConvEmbed(
                patch_size=patch_size[i],
                stride=patch_stride[i],
                padding=patch_padding[i],
                in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
                embed_dim=self.embed_dims[i],
                norm_layer=norm_layer,
                pre_norm=patch_prenorm[i],
            )
            convs.append(conv_embed)
            # depths[i] (spatial_block, channel_block) pairs; the names given
            # here become part of the parameter names in the state dict.
            block = MySequential(
                *[
                    MySequential(
                        OrderedDict(
                            [
                                (
                                    "spatial_block",
                                    SpatialBlock(
                                        embed_dims[i],
                                        num_heads[i],
                                        window_size,
                                        drop_path_rate=dpr[depth_offset + j * 2],
                                        qkv_bias=qkv_bias,
                                        mlp_ratio=mlp_ratio,
                                        conv_at_attn=conv_at_attn,
                                        conv_at_ffn=conv_at_ffn,
                                    ),
                                ),
                                (
                                    "channel_block",
                                    ChannelBlock(
                                        embed_dims[i],
                                        num_groups[i],
                                        drop_path_rate=dpr[depth_offset + j * 2 + 1],
                                        qkv_bias=qkv_bias,
                                        mlp_ratio=mlp_ratio,
                                        conv_at_attn=conv_at_attn,
                                        conv_at_ffn=conv_at_ffn,
                                    ),
                                ),
                            ]
                        )
                    )
                    for j in range(depths[i])
                ]
            )
            blocks.append(block)
            depth_offset += depths[i] * 2
        self.convs = nn.ModuleList(convs)
        self.blocks = nn.ModuleList(blocks)
        self.norms = norm_layer(self.embed_dims[-1])
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = (
            nn.Linear(self.embed_dims[-1], num_classes)
            if num_classes > 0
            else nn.Identity()
        )
        self.apply(self._init_weights)

    @property
    def dim_out(self):
        """Embedding dimension of the last stage."""
        return self.embed_dims[-1]

    def _init_weights(self, m):
        """Initialize one sub-module; used as the callback of ``self.apply``.

        Linear weights use a truncated normal (std 0.02), conv weights a
        normal (std 0.02); biases are zeroed; LayerNorm/BatchNorm2d get
        weight 1 and bias 0.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, std=0.02)
            # Zero the bias only when the conv actually has one.
            for name, _ in m.named_parameters():
                if name in ["bias"]:
                    nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0)

    def forward_features_unpool(self, x):
        """
        forward until avg pooling

        Args:
            x (_type_): input image tensor

        Returns:
            The token tensor produced by the last stage (before pooling and
            the final norm).
        """
        # Spatial size is read from dims 2 and 3 of the input.
        input_size = (x.size(2), x.size(3))
        for conv, block in zip(self.convs, self.blocks):
            x, input_size = conv(x, input_size)
            if self.enable_checkpoint:
                x, input_size = checkpoint.checkpoint(block, x, input_size)
            else:
                x, input_size = block(x, input_size)
        return x

    def forward_features(self, x):
        """Return the pooled and normalized feature vector for each input."""
        x = self.forward_features_unpool(x)
        # (batch_size, num_tokens, token_dim)
        x = self.avgpool(x.transpose(1, 2))
        # (batch_size, token_dim, 1) after pooling over the token axis
        x = torch.flatten(x, 1)
        x = self.norms(x)
        return x

    def forward(self, x):
        """Apply the backbone and then the classification head."""
        x = self.forward_features(x)
        x = self.head(x)
        return x

    @classmethod
    def from_config(cls, config):
        """Build a model from a config object.

        Only the attributes read below are taken from ``config`` (note that
        ``embed_dims`` comes from ``config.dim_embed``); every other
        constructor argument keeps its default.
        """
        return cls(
            depths=config.depths,
            embed_dims=config.dim_embed,
            num_heads=config.num_heads,
            num_groups=config.num_groups,
            patch_size=config.patch_size,
            patch_stride=config.patch_stride,
            patch_padding=config.patch_padding,
            patch_prenorm=config.patch_prenorm,
            drop_path_rate=config.drop_path_rate,
            window_size=config.window_size,
        )
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+14
View File
@@ -0,0 +1,14 @@
from namo.models.vision.ve_aim import AimV2VE
from namo.models.vision.ve_siglip_navit import SigLipNavitVE
from torch import nn
def get_ve(config, **kwargs):
    """Build the vision encoder selected by the vision model's name.

    The lower-cased ``config.vision_config._name_or_path`` decides the class:
    a name containing "siglip" but not "navit" gives ``SigLipNavitVE``;
    otherwise a name containing "aim" gives ``AimV2VE``.

    Args:
        config: Model config exposing ``vision_config._name_or_path``.
        **kwargs: Forwarded to the encoder constructor.

    Raises:
        ValueError: If the name matches neither rule.
    """
    type_name = config.vision_config._name_or_path.lower()
    wants_siglip = "siglip" in type_name and "navit" not in type_name
    if wants_siglip:
        return SigLipNavitVE(config, **kwargs)
    if "aim" in type_name:
        return AimV2VE(config, **kwargs)
    raise ValueError(f"Unsupported vision model: {type_name}")
+97
View File
@@ -0,0 +1,97 @@
import math
import torch
from torch import nn
import torchvision
from transformers import (
AutoProcessor,
AutoModelForCausalLM,
CLIPImageProcessor,
AutoConfig,
AutoModel,
)
import os
from transformers import AutoModel
from .ve_base import BaseVE
from loguru import logger
from namo.utils.utils import is_main_process
from . import AIMv2Model
from . import AIMv2ModelNative
from . import AIMv2Config
from .aimv2.modeling_aimv2_native import RMSNorm
class VLPatchMerger(nn.Module):
    """Normalize patch features, merge neighbouring patches and project them.

    Args:
        dim: Output feature dimension of the MLP.
        context_dim: Feature dimension of each incoming patch.
        spatial_merge_size: Side of the square patch group that is merged;
            ``spatial_merge_size ** 2`` patch features are concatenated.
    """

    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        patches_per_group = spatial_merge_size**2
        self.hidden_size = context_dim * patches_per_group
        self.ln_q = RMSNorm(context_dim, eps=1e-6)
        mlp_layers = [
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, dim),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the merged features, shaped ``[-1, dim]``."""
        normed = self.ln_q(x)
        # Consecutive patch features are concatenated into one merged row.
        merged = normed.view(-1, self.hidden_size)
        return self.mlp(merged)
class AimV2VE(BaseVE):
    """Vision encoder backed by an AIMv2 model (regular or "native" variant)."""

    def _load_vision_tower(self):
        """Create ``self.vision_tower``.

        The "native" model class is used when the tower name contains
        "native". If the tower name is an existing path the weights are
        loaded from it; otherwise the model is built from
        ``self.vision_config`` only.
        """
        # other models can be customized here, normally AutoModel can handle well
        use_native = "native" in self.vision_tower_name
        tower_cls = AIMv2ModelNative if use_native else AIMv2Model

        if not os.path.exists(self.vision_tower_name):
            if is_main_process():
                logger.info(f"creating AIMv2 model: {self.vision_tower_name}")
            self.vision_tower = tower_cls(config=self.vision_config)
            return

        if is_main_process():
            if use_native:
                logger.info(
                    f"loading AIMv2-native pretrain model: {self.vision_tower_name} {self.torch_dtype}"
                )
            else:
                logger.info(f"loading AIMv2 pretrain model: {self.vision_tower_name}")
        self.vision_tower = tower_cls.from_pretrained(
            self.vision_tower_name,
            ignore_mismatched_sizes=True,
            torch_dtype=self.torch_dtype,
        )
        # TODO: check whether vision_tower_name exists and fall back to
        # config._name_or_path before loading a dedicated image processor;
        # adding a VLPatchMerger after the vision tower is also still open.

    def feature_select(self, image_forward_outs):
        """Pick the configured hidden state and optionally drop the first token.

        Raises:
            ValueError: If ``self.select_feature`` is not a known strategy.
        """
        image_features = image_forward_outs.hidden_states[self.select_layer]
        strategy = self.select_feature
        if strategy == "patch":
            # Everything except the first token along dim 1.
            return image_features[:, 1:]
        if strategy in ["cls_patch", "same", "all", "default"]:
            return image_features
        raise ValueError(f"Invalid select feature: {self.select_feature}")

    def forward(self, images, image_sizes=None):
        """Encode ``images``; ``image_sizes`` is accepted but not used."""
        return self.basic_forward(images)
+144
View File
@@ -0,0 +1,144 @@
import os
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, AutoImageProcessor
import torchvision
from loguru import logger
from namo.utils.utils import rank0_print, is_main_process
class BaseVE(nn.Module):
    """Common base for vision encoders.

    Handles config bookkeeping, image-processor loading and the shared
    forward path. Subclasses implement ``_load_vision_tower`` and
    ``feature_select``.
    """

    def __init__(self, config, **kwargs):
        """Read settings from ``config`` and load (or defer) the model.

        Args:
            config: Model config; the attributes read below must exist.
            **kwargs: ``delay_load`` (default False) skips loading the vision
                tower unless ``config.unfreeze_ve`` is set.
        """
        super().__init__()
        delay_load = kwargs.pop("delay_load", False)
        self.vision_config = config.vision_config
        self.model_name_or_path = config._name_or_path
        self.torch_dtype = config.vision_config.torch_dtype
        self.is_loaded = False
        # NOTE(review): the name is lower-cased and later also used as a
        # filesystem path; confirm that is safe on case-sensitive filesystems.
        self.vision_tower_name = config.vision_config._name_or_path.lower()
        self.select_layer = config.vision_feature_layer
        self.select_feature = config.vision_feature_select_strategy
        self.new_img_size = config.new_img_size
        self.unfreeze_ve = config.unfreeze_ve
        self.longest_edge = config.longest_edge
        self.shortest_edge = config.shortest_edge

        # actually delay_load doesn't needed anymore
        if not delay_load or self.unfreeze_ve:
            self.load_model()
        else:
            # Deferred: keep only the config and the image processor around.
            self.cfg_only = AutoConfig.from_pretrained(
                self.vision_tower_name, trust_remote_code=True
            )
            if self.new_img_size:
                logger.info(f"Update using new image size: {self.cfg_only.image_size}")
            self._load_image_processor()

    def load_model(self):
        """Load the image processor and the vision tower, then mark as loaded."""
        on_main = is_main_process()
        if on_main:
            logger.info(f"Loading base components for {self.vision_tower_name}")
        self._load_image_processor()
        self._load_vision_tower()
        if self.new_img_size:
            if is_main_process():
                logger.info(f"set new image size {self.new_img_size}")
            square = {"height": self.new_img_size, "width": self.new_img_size}
            self.image_processor.size = square
        if not self.unfreeze_ve:
            # Frozen encoder: no gradients for the tower parameters.
            self.vision_tower.requires_grad_(False)
        self.is_loaded = True

    def _load_image_processor(self):
        """Load ``self.image_processor`` and apply edge-size overrides.

        The processor is read from the vision tower path if it exists,
        otherwise from the model path.

        Raises:
            FileNotFoundError: If neither path exists.
        """
        logger.info(f"loading image processor from: {self.vision_tower_name}")
        candidates = (self.vision_tower_name, self.model_name_or_path)
        processor_path = next(
            (path for path in candidates if os.path.exists(path)), None
        )
        if processor_path is None:
            raise FileNotFoundError(
                f"No processor found for either {self.vision_tower_name} or {self.model_name_or_path}"
            )
        self.image_processor = AutoImageProcessor.from_pretrained(
            processor_path, trust_remote_code=True
        )
        if is_main_process():
            logger.info(f"==> self.longest_edge {self.longest_edge}")

        if self.longest_edge is not None:
            self.image_processor.do_resize = True
            # override from default image processor
            processor_size = self.image_processor.size
            processor_size["longest_edge"] = self.longest_edge
            processor_size["shortest_edge"] = 42  # hard coded here
            if is_main_process():
                logger.info(f"==> override longest_edge {self.image_processor.size}")

        if self.shortest_edge is not None:
            self.image_processor.size["shortest_edge"] = self.shortest_edge
            if is_main_process():
                logger.info(f"override shortest_edge {self.shortest_edge}")

    def _load_vision_tower(self):
        """Create ``self.vision_tower``; must be implemented by subclasses."""
        raise NotImplementedError

    def feature_select(self, image_forward_outs):
        """Select features from the tower outputs; implemented by subclasses."""
        raise NotImplementedError

    def basic_forward(self, images):
        """Encode a batch tensor, or a list of images one by one."""
        if not isinstance(images, list):
            return self._process_batch(images)
        return self._process_image_list(images)

    def _process_image_list(self, images):
        """Encode each image separately and return the list of features."""
        collected = []
        for image in images:
            pixels = image.to(device=self.device, dtype=self.dtype)
            if len(pixels.shape) < 4:
                # Add a leading batch axis for inputs with fewer than 4 dims.
                pixels = pixels.unsqueeze(0)
            collected.append(self._get_features(pixels))
        return collected

    def _process_batch(self, images):
        """Encode an already batched tensor."""
        pixels = images.to(device=self.device, dtype=self.dtype)
        return self._get_features(pixels)

    def _get_features(self, inputs):
        """Run the tower with hidden states and select the configured features."""
        tower_out = self.vision_tower(inputs, output_hidden_states=True)
        selected = self.feature_select(tower_out)
        return selected.to(inputs.dtype)

    @property
    def dtype(self):
        """dtype of the first vision tower parameter."""
        first_param = next(self.vision_tower.parameters())
        return first_param.dtype

    @property
    def device(self):
        """Device of the first vision tower parameter."""
        first_param = next(self.vision_tower.parameters())
        return first_param.device

    @property
    def config(self):
        """Tower config once loaded, otherwise the config-only fallback."""
        if self.is_loaded:
            return self.vision_tower.config
        return self.cfg_only

    @property
    def hidden_size(self):
        """Hidden size reported by the active config."""
        return self.config.hidden_size

    @property
    def num_patches(self):
        """Patch count of a square image at the configured image size."""
        per_side = self.config.image_size // self.config.patch_size
        return per_side**2

    @property
    def image_size_auto(self):
        """The overriding image size if set, else the config image size."""
        return self.new_img_size or self.config.image_size

    def save_pretrained(self, model_path):
        """Save the vision tower to ``model_path``."""
        self.vision_tower.save_pretrained(model_path)
+33
View File
@@ -0,0 +1,33 @@
import math
import torch
from torch import nn
import torchvision
from transformers import (
AutoModel,
)
from transformers import AutoModel
from .ve_base import BaseVE
from loguru import logger
class SigLipNavitVE(BaseVE):
    """Vision encoder that loads a SigLIP model through ``AutoModel``."""

    def _load_vision_tower(self):
        """Load ``self.vision_tower`` and disable center cropping.

        Relies on ``self.image_processor`` having been set before this is
        called.
        """
        # The message previously said "AIMv2" (copied from the AIMv2 encoder),
        # which was misleading when debugging model loading.
        logger.info(f"Loading SigLIP specific model: {self.vision_tower_name}")
        # other models can be customized here, normally AutoModel can handle well
        self.vision_tower = AutoModel.from_pretrained(
            self.vision_tower_name, ignore_mismatched_sizes=True, trust_remote_code=True
        )
        self.image_processor.do_center_crop = False

    def feature_select(self, image_forward_outs):
        """Pick the configured hidden state and optionally drop the first token.

        Raises:
            ValueError: If ``self.select_feature`` is not "patch", "cls_patch"
                or "same".
        """
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == "patch":
            # Everything except the first token along dim 1.
            return image_features[:, 1:]
        elif self.select_feature in ["cls_patch", "same"]:
            return image_features
        else:
            raise ValueError(f"Invalid select feature: {self.select_feature}")

    def forward(self, images, image_sizes=None):
        """Encode ``images``; ``image_sizes`` is accepted but not used."""
        return self.basic_forward(images)
+6
View File
@@ -0,0 +1,6 @@
from .image_processing_namo import NamoImageProcessor
from transformers import AutoImageProcessor
# Import-time side effect: register NamoImageProcessor with transformers'
# AutoImageProcessor.
# NOTE(review): NamoImageProcessor is passed both as the first positional
# argument and as slow_image_processor_class; confirm the first argument is
# not expected to be a config class.
AutoImageProcessor.register(
    NamoImageProcessor, slow_image_processor_class=NamoImageProcessor
)
+99
View File
@@ -0,0 +1,99 @@
from typing import Dict, List, Optional, Union
import numpy as np
from transformers.image_utils import ImageInput, is_valid_image
from transformers.image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
)
from transformers.utils import is_vision_available, logging
from transformers import CLIPImageProcessor
from transformers.image_transforms import resize
logger = logging.get_logger(__name__)
if is_vision_available():
import PIL
class NamoImageProcessor(CLIPImageProcessor):
    """CLIP image processor with a size-aware, patch-aligned ``resize``."""

    model_input_names = ["pixel_values"]

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Override the original CLIP resize logic.

        The image is scaled down to the longest edge if it is too large and
        scaled up to the shortest edge if it is too small; within the limits
        the scale is left at 1. Both sides are then floored to a multiple of
        28 (never below 28).

        Args:
            image: Image array to resize.
            size: May contain "shortest_edge" (default 28) and "longest_edge"
                (default 714).
            resample: Resampling filter forwarded to ``resize``.
            data_format: Output channel layout forwarded to ``resize``.
            input_data_format: Channel layout of ``image``. When it is
                channels-first, height/width are read from the last two axes;
                otherwise (including ``None``) from the first two.
            **kwargs: Forwarded to ``resize``.

        Returns:
            The resized image.
        """
        minimal_divider = 28
        config_shortest = size.get("shortest_edge", minimal_divider)
        config_longest = size.get("longest_edge", 714)

        # Respect the declared layout: for (C, H, W) input, shape[:2] would
        # be (C, H) and give a wrong target size.
        if input_data_format == ChannelDimension.FIRST:
            orig_height, orig_width = image.shape[-2:]
        else:
            orig_height, orig_width = image.shape[:2]

        current_shortest = min(orig_height, orig_width)
        current_longest = max(orig_height, orig_width)

        # Determine the appropriate scaling factor.
        if current_longest > config_longest:
            # If the image is too large, scale down using the longest edge.
            scale = config_longest / current_longest
            if current_shortest * scale < config_shortest:
                # If the shortest edge gets too small after scaling, scale to
                # the shortest edge instead.
                scale = config_shortest / current_shortest
        elif current_shortest < config_shortest:
            # If the image is too small, scale up using the shortest edge.
            scale = config_shortest / current_shortest
        else:
            # Already within the limits: only the alignment below applies.
            scale = 1.0

        new_height = int(round(orig_height * scale))
        new_width = int(round(orig_width * scale))

        # If the longest edge still exceeds config_longest, clamp it. This
        # distorts the aspect ratio, but should not affect detections.
        if max(new_height, new_width) > config_longest:
            if new_width > new_height:
                new_width = config_longest
            else:
                new_height = config_longest

        # Ensure both sides are divisible by 28 (14 * 2); keep at least one
        # full unit so a side shorter than 28 cannot collapse to 0.
        new_height = max(
            minimal_divider, (new_height // minimal_divider) * minimal_divider
        )
        new_width = max(
            minimal_divider, (new_width // minimal_divider) * minimal_divider
        )

        return resize(
            image,
            size=(new_height, new_width),
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
__all__ = ["NamoImageProcessor"]
View File
+85
View File
@@ -0,0 +1,85 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom evaluation tasks for LightEval."""
from lighteval.metrics.dynamic_metrics import (
ExprExtractionConfig,
LatexExtractionConfig,
multilingual_extractive_match_metric,
)
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
from lighteval.utils.language import Language
# Shared answer-matching metric used by every task defined in this module.
# Gold answers are extracted with the LaTeX config only; predictions are tried
# with both the plain-expression and the LaTeX config, and `max` is used as the
# aggregation function.
metric = multilingual_extractive_match_metric(
    language=Language.ENGLISH,
    fallback_mode="first_match",
    precision=5,
    gold_extraction_target=(LatexExtractionConfig(),),
    pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
    aggregation_function=max,
)
def prompt_fn(line, task_name: str = None):
    """Turn one dataset row into a ``Doc``.

    The row's "problem" becomes the query and its "solution" the single
    (gold) choice. Assumes the model is either prompted to emit
    \\boxed{answer} or does so automatically.
    """
    question = line["problem"]
    reference = line["solution"]
    doc = Doc(
        task_name=task_name,
        query=question,
        choices=[reference],
        gold_index=0,
    )
    return doc
# Define tasks
# Both tasks share the prompt function and metric defined in this module, use
# no few-shot examples and allow up to 32768 generated tokens.

# AIME 2024: evaluated on the "train" split of HuggingFaceH4/aime_2024.
aime24 = LightevalTaskConfig(
    name="aime24",
    suite=["custom"],
    prompt_function=prompt_fn,
    hf_repo="HuggingFaceH4/aime_2024",
    hf_subset="default",
    hf_avail_splits=["train"],
    evaluation_splits=["train"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,
    metric=[metric],
    version=1,
)
# MATH-500: evaluated on the "test" split of HuggingFaceH4/MATH-500.
math_500 = LightevalTaskConfig(
    name="math_500",
    suite=["custom"],
    prompt_function=prompt_fn,
    hf_repo="HuggingFaceH4/MATH-500",
    hf_subset="default",
    hf_avail_splits=["test"],
    evaluation_splits=["test"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,
    metric=[metric],
    version=1,
)
# Add tasks to the table
# TASKS_TABLE lists every custom task defined in this module.
TASKS_TABLE = [aime24, math_500]

# MODULE LOGIC
if __name__ == "__main__":
    # Task configs are objects, not dicts: read the name as an attribute
    # (subscripting with t["name"] fails on a non-subscriptable config).
    print([t.name for t in TASKS_TABLE])
    print(len(TASKS_TABLE))
+162
View File
@@ -0,0 +1,162 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration
def build_distilabel_pipeline(
    model: str,
    base_url: str = "http://localhost:8000/v1",
    prompt_column: Optional[str] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_new_tokens: int = 8192,
    num_generations: int = 1,
) -> Pipeline:
    """Build a Ray-backed distilabel pipeline with one text-generation step.

    Args:
        model: Model name sent to the OpenAI-compatible server.
        base_url: Base URL of the OpenAI-compatible server.
        prompt_column: Dataset column mapped to the step's "instruction"
            input; no mapping is applied when None.
        temperature: Sampling temperature; omitted from the request when None.
        top_p: Top-p value; omitted from the request when None.
        max_new_tokens: Maximum number of new tokens per generation.
        num_generations: Number of generations per input.

    Returns:
        The constructed pipeline (not yet run).
    """
    # Optional sampling settings are only sent when explicitly provided.
    optional_sampling = {"temperature": temperature, "top_p": top_p}
    generation_kwargs = {"max_new_tokens": max_new_tokens}
    generation_kwargs.update(
        {key: value for key, value in optional_sampling.items() if value is not None}
    )

    if prompt_column is not None:
        input_mappings = {"instruction": prompt_column}
    else:
        input_mappings = {}

    with Pipeline().ray() as pipeline:
        llm = OpenAILLM(
            base_url=base_url,
            api_key="something",
            model=model,
            # thinking can take some time...
            timeout=10 * 60,
            generation_kwargs=generation_kwargs,
        )
        TextGeneration(
            llm=llm,
            input_mappings=input_mappings,
            input_batch_size=64,  # on 4 nodes bs ~60+ leads to preemption due to KV cache exhaustion
            num_generations=num_generations,
        )
    return pipeline
if __name__ == "__main__":
    # Command-line entry point: load a dataset, generate responses for it with
    # the distilabel pipeline, and optionally push the result to the Hub.
    import argparse
    from datasets import load_dataset

    parser = argparse.ArgumentParser(
        description="Run distilabel pipeline for generating responses with DeepSeek R1"
    )
    parser.add_argument(
        "--hf-dataset",
        type=str,
        required=True,
        help="HuggingFace dataset to load",
    )
    parser.add_argument(
        "--hf-dataset-config",
        type=str,
        required=False,
        help="Dataset config to use",
    )
    parser.add_argument(
        "--hf-dataset-split",
        type=str,
        default="train",
        help="Dataset split to use",
    )
    parser.add_argument("--prompt-column", type=str, default="prompt")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Model name to use for generation",
    )
    parser.add_argument(
        "--vllm-server-url",
        type=str,
        default="http://localhost:8000/v1",
        help="URL of the vLLM server",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="Temperature for generation",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        help="Top-p value for generation",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=8192,
        help="Maximum number of new tokens to generate",
    )
    parser.add_argument(
        "--num-generations",
        type=int,
        default=1,
        help="Number of generations per problem",
    )
    parser.add_argument(
        "--hf-output-dataset",
        type=str,
        required=False,
        help="HuggingFace repo to push results to",
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="Whether to make the output dataset private when pushing to HF Hub",
    )
    args = parser.parse_args()

    print("\nRunning with arguments:")
    for arg, value in vars(args).items():
        print(f" {arg}: {value}")
    print()

    print(
        f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset..."
    )
    # Pass the requested config through. It used to be parsed and printed but
    # never handed to load_dataset. When the flag is omitted it is None, which
    # matches the previous call.
    dataset = load_dataset(
        args.hf_dataset, args.hf_dataset_config, split=args.hf_dataset_split
    )
    print("Dataset loaded!")

    pipeline = build_distilabel_pipeline(
        model=args.model,
        base_url=args.vllm_server_url,
        prompt_column=args.prompt_column,
        temperature=args.temperature,
        top_p=args.top_p,
        max_new_tokens=args.max_new_tokens,
        num_generations=args.num_generations,
    )

    print("Running generation pipeline...")
    distiset = pipeline.run(dataset=dataset, use_cache=False)
    print("Generation pipeline finished!")

    if args.hf_output_dataset:
        print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
        distiset.push_to_hub(args.hf_output_dataset, private=args.private)
        print("Dataset pushed!")
+223
View File
@@ -0,0 +1,223 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset, load_from_disk
from transformers import Qwen2VLForConditionalGeneration
from math_verify import parse, verify
from .trainer import Qwen2VLGRPOTrainer
from trl import (
GRPOConfig,
GRPOTrainer,
ModelConfig,
ScriptArguments,
TrlParser,
get_peft_config,
)
@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Command-line options specific to the GRPO training script.

    Args:
        reward_funcs (`list[str]`):
            Names of the reward functions to apply. Possible values: 'accuracy', 'format'.
        max_pixels (`int`, *optional*, defaults to `12845056`):
            Maximum number of pixels for the image.
        min_pixels (`int`, *optional*, defaults to `3136`):
            Minimum number of pixels for the image.
    """

    reward_funcs: list[str] = field(
        metadata={
            "help": "List of reward functions. Possible values: 'accuracy', 'format'"
        },
        default_factory=lambda: ["accuracy", "format"],
    )
    max_pixels: Optional[int] = field(
        metadata={"help": "Maximum number of pixels for the image"},
        default=12845056,
    )
    min_pixels: Optional[int] = field(
        metadata={"help": "Minimum number of pixels for the image"},
        default=3136,
    )
def accuracy_reward(completions, solution, **kwargs):
    """Score each completion 1.0 if it matches its reference solution, otherwise 0.0.

    Two checks are tried in order:

    1. Symbolic verification through ``parse`` / ``verify``.
    2. Exact string comparison of the text inside ``<answer>...</answer>`` tags
       (or of the whole stripped string when no such tags are present).

    When the environment variable ``DEBUG_MODE`` equals ``"true"`` and ``LOG_PATH``
    is set, every scored sample is appended to the file at ``LOG_PATH``.

    Args:
        completions: One entry per sample; only ``completion[0]["content"]`` is read.
        solution: Reference solutions, paired with ``completions`` by position.
        **kwargs: Ignored; accepted so extra dataset columns can be passed through.

    Returns:
        A list of floats (0.0 or 1.0), one per (completion, solution) pair.
    """
    # DOTALL lets an answer span several lines; without it such answers were never extracted.
    answer_pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    # Debug logging needs both switches; a missing LOG_PATH used to crash in open(None).
    debug_log_path = (
        os.getenv("LOG_PATH") if os.getenv("DEBUG_MODE") == "true" else None
    )
    for content, sol in zip(contents, solution):
        reward = 0.0
        # Try symbolic verification first
        try:
            answer = parse(content)
            if float(verify(answer, parse(sol))) > 0:
                reward = 1.0
        except Exception:
            pass  # Best effort: fall through to string matching if verification fails

        # If symbolic verification failed, try string matching
        if reward == 0.0:
            try:
                sol_match = answer_pattern.search(sol)
                ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
                content_match = answer_pattern.search(content)
                student_answer = (
                    content_match.group(1).strip() if content_match else content.strip()
                )
                if student_answer == ground_truth:
                    reward = 1.0
            except Exception:
                pass  # Keep reward as 0.0 if both methods fail

        rewards.append(reward)
        if debug_log_path:
            with open(debug_log_path, "a", encoding="utf-8") as f:
                f.write(
                    f"------------- {current_time} Accuracy reward: {reward} -------------\n"
                )
                f.write(f"Content: {content}\n")
                f.write(f"Solution: {sol}\n")
    return rewards
def format_reward(completions, **kwargs):
    """Score 1.0 for completions that follow the ``<think>…</think><answer>…</answer>`` layout.

    The completion must start with a ``<think>`` block, followed by optional
    whitespace and an ``<answer>`` block. The match is anchored at the start only,
    so trailing text after ``</answer>`` is tolerated.

    Args:
        completions: One entry per sample; only ``completion[0]["content"]`` is read.
        **kwargs: Ignored; accepted so extra dataset columns can be passed through.

    Returns:
        A list of floats (0.0 or 1.0), one per completion.
    """
    # DOTALL is required: reasoning text normally spans several lines, and without
    # the flag "." stops at the first newline so such completions always scored 0.
    pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    return [
        1.0 if pattern.match(completion[0]["content"]) else 0.0
        for completion in completions
    ]
# Name -> reward callable. `main` looks up each entry of `script_args.reward_funcs` here.
reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": format_reward,
}
# System message placed first in the text-only prompt built by `make_conversation`;
# it tells the model to wrap its reasoning in <think> tags and its answer in <answer> tags.
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)
def main(script_args, training_args, model_args):
    """Run GRPO training: build prompts, train, save, and optionally push to the Hub.

    Args:
        script_args: Script-level options (dataset name/config/splits, reward function
            names, pixel limits).
        training_args: Trainer configuration passed through to the trainer.
        model_args: Model options (model path, attention implementation, PEFT settings).
    """
    # Resolve the requested reward function names to callables.
    reward_funcs = [reward_funcs_registry[name] for name in script_args.reward_funcs]

    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    question_template = "{Question} Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."

    def make_conversation(example):
        """Text-only prompt: system instructions followed by the problem statement."""
        system_turn = {"role": "system", "content": SYSTEM_PROMPT}
        user_turn = {"role": "user", "content": example["problem"]}
        return {"prompt": [system_turn, user_turn]}

    def make_conversation_image(example):
        """Image prompt: a single user turn holding an image slot and the templated question."""
        text_part = {
            "type": "text",
            "text": question_template.format(Question=example["problem"]),
        }
        user_turn = {"role": "user", "content": [{"type": "image"}, text_part]}
        return {"prompt": [user_turn]}

    train_split = script_args.dataset_train_split
    if "image" not in dataset[train_split].features:
        print("no image in dataset")
        dataset = dataset.map(make_conversation)
        dataset = dataset.remove_columns("messages")
    else:
        print("has image in dataset")
        dataset = dataset.map(make_conversation_image)

    # Only hand over an evaluation split when evaluation is enabled.
    eval_dataset = None
    if training_args.eval_strategy != "no":
        eval_dataset = dataset[script_args.dataset_test_split]

    trainer = Qwen2VLGRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset[train_split],
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
    )

    trainer.train()

    # Save locally, then push to the Hub if requested.
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
if __name__ == "__main__":
    # Parse CLI flags / config file into (script, training, model) argument groups and run.
    cli_parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    parsed_args = cli_parser.parse_args_and_config()
    main(*parsed_args)
+107
View File
@@ -0,0 +1,107 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Supervised fine-tuning script for decoder language models.
Usage:
# On 1 node of 8 x H100s
accelerate launch --config_file=configs/zero3.yaml src/open_r1/sft.py \
--model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
--dataset_name HuggingFaceH4/Bespoke-Stratos-17k \
--learning_rate 2.0e-5 \
--num_train_epochs 1 \
--packing \
--max_seq_length 4096 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--gradient_checkpointing \
--bf16 \
--logging_steps 5 \
--eval_strategy steps \
--eval_steps 100 \
--output_dir data/Qwen2.5-1.5B-Open-R1-Distill
"""
from datasets import load_dataset
from transformers import AutoTokenizer
from trl import (
ModelConfig,
ScriptArguments,
SFTConfig,
SFTTrainer,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
def main(script_args, training_args, model_args):
    """Run supervised fine-tuning: configure the model and tokenizer, train, save, optionally push.

    Args:
        script_args: Script-level options (dataset name/config/splits).
        training_args: SFT trainer configuration; its `model_init_kwargs` is set here.
        model_args: Model options (path, revision, dtype, attention, quantization, PEFT).
    """
    # --- Model init kwargs ---
    quantization_config = get_quantization_config(model_args)
    device_map = None
    if quantization_config is not None:
        device_map = get_kbit_device_map()
    training_args.model_init_kwargs = {
        "revision": model_args.model_revision,
        "trust_remote_code": model_args.trust_remote_code,
        "attn_implementation": model_args.attn_implementation,
        "torch_dtype": model_args.torch_dtype,
        # The KV cache is switched off whenever gradient checkpointing is on.
        "use_cache": not training_args.gradient_checkpointing,
        "device_map": device_map,
        "quantization_config": quantization_config,
    }

    # --- Tokenizer ---
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        use_fast=True,
    )
    tokenizer.pad_token = tokenizer.eos_token

    # --- Dataset ---
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
    eval_dataset = None
    if training_args.eval_strategy != "no":
        eval_dataset = dataset[script_args.dataset_test_split]

    # --- Training ---
    trainer = SFTTrainer(
        model=model_args.model_name_or_path,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        peft_config=get_peft_config(model_args),
    )
    trainer.train()

    # Save locally, then push to the Hub if requested.
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
if __name__ == "__main__":
    # Parse CLI flags / config file into (script, training, model) argument groups and run.
    cli_parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    parsed_args = cli_parser.parse_args_and_config()
    main(*parsed_args)
+4
View File
@@ -0,0 +1,4 @@
from .grpo_trainer import Qwen2VLGRPOTrainer
__all__ = ["Qwen2VLGRPOTrainer"]
+705
View File
@@ -0,0 +1,705 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import textwrap
from collections import defaultdict
from typing import Any, Callable, Optional, Union
import torch
import torch.utils.data
import transformers
from datasets import Dataset, IterableDataset
from packaging import version
from transformers import (
AriaForConditionalGeneration,
AriaProcessor,
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoProcessor,
AutoTokenizer,
GenerationConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
Qwen2VLForConditionalGeneration,
Trainer,
TrainerCallback,
is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available
from trl.data_utils import (
apply_chat_template,
is_conversational,
maybe_apply_chat_template,
)
from trl.models import (
create_reference_model,
prepare_deepspeed,
unwrap_model_for_generation,
)
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.utils import generate_model_card, get_comet_experiment_url
import copy
if is_peft_available():
from peft import PeftConfig, get_peft_model
if is_wandb_available():
import wandb
# A reward function is a callable that takes a list of prompts and a list of completions and returns a list of
# float rewards. A string is treated as a model ID and loaded as a pretrained sequence-classification model; an
# already-instantiated `PreTrainedModel` is accepted as well.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
class Qwen2VLGRPOTrainer(Trainer):
"""
Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
Example:
```python
from datasets import load_dataset
from trl import GRPOTrainer
dataset = load_dataset("trl-lib/tldr", split="train")
trainer = GRPOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_funcs="weqweasdas/RM-Gemma-2B",
train_dataset=dataset,
)
trainer.train()
```
Args:
model (`Union[str, PreTrainedModel]`):
Model to be trained. Can be either:
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
a path to a *directory* containing model weights saved using
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
in `args.model_init_kwargs`.
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
functions with the prompts and completions and sum the rewards. Can be either:
- A single reward function, such as:
- A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
path to a *directory* containing model weights saved using
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
keyword arguments in `args.model_init_kwargs`.
- A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
- A custom reward function: The function is provided with the prompts and the generated completions,
plus any additional columns in the dataset. It should return a list of rewards. For more details, see
[Using a custom reward function](#using-a-custom-reward-function).
- A list of reward functions, where each item can independently be any of the above types. Mixing different
types within the list (e.g., a string model ID and a custom reward function) is allowed.
args ([`GRPOConfig`], *optional*, defaults to `None`):
Configuration for this trainer. If `None`, a default configuration is used.
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
Dataset to use for training. It must include the columns `"prompt"` and `"image"`. Any additional columns
are forwarded to custom reward functions as keyword arguments. The format of the samples can be either:
- [Standard](dataset_formats#standard): Each sample contains plain text.
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
and content).
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
Processing class used to process the data. The padding side must be set to "left". If `None`, the
processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
- A single processing class: Used when `reward_funcs` contains only one reward function.
- A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
`None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
the corresponding entries in `reward_processing_classes` are ignored.
callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
List of callbacks to customize the training loop. Will add those to the list of default callbacks
detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
method.
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
"""
def __init__(
    self,
    model: Union[str, PreTrainedModel],
    reward_funcs: Union[RewardFunc, list[RewardFunc]],
    args: GRPOConfig = None,
    train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
    eval_dataset: Optional[
        Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]
    ] = None,
    processing_class: Optional[PreTrainedTokenizerBase] = None,
    reward_processing_classes: Optional[
        Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]
    ] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[
        Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
    ] = (None, None),
    peft_config: Optional["PeftConfig"] = None,
    max_pixels: Optional[int] = 12845056,
    min_pixels: Optional[int] = 3136,
    attn_implementation: str = "flash_attention_2",
):
    """Set up the policy model, reference model, processing class and reward functions.

    Args:
        model: Model ID / path (loaded here) or an already-instantiated model.
        reward_funcs: One reward function or a list of them; strings are loaded as
            sequence-classification models with `num_labels=1`.
        args: Trainer configuration; a default `GRPOConfig` is created when `None`.
        train_dataset: Training dataset, passed to `Trainer`.
        eval_dataset: Evaluation dataset(s), passed to `Trainer`.
        processing_class: Processor/tokenizer for the policy model. When `None`, an
            `AutoProcessor` (model IDs containing "Qwen2-VL" or "Aria") or a
            left-padding `AutoTokenizer` is loaded.
        reward_processing_classes: Tokenizer(s) for model-based reward functions.
        callbacks: Extra trainer callbacks.
        optimizers: `(optimizer, scheduler)` pair passed to `Trainer`.
        peft_config: PEFT configuration; when given, the model is wrapped and no
            separate reference model is kept (unless DeepSpeed ZeRO-3 is enabled).
        max_pixels: Upper pixel bound set on the image processor. Only applied when
            the processor is created here for a "Qwen2-VL" model.
        min_pixels: Lower pixel bound, applied under the same condition.
        attn_implementation: Attention backend injected into the model init kwargs.

    Raises:
        ValueError: On an invalid `torch_dtype`, when `model_init_kwargs` is given
            together with an instantiated model, or when the number of reward
            processing classes does not match the number of reward functions.
        ImportError: When `peft_config` is given but PEFT is not installed.
    """
    # Args
    if args is None:
        model_name = model if isinstance(model, str) else model.config._name_or_path
        model_name = model_name.split("/")[-1]
        args = GRPOConfig(f"{model_name}-GRPO")

    # Models
    # Trained model
    # Work on a copy: the keys added below (attn_implementation, resolved torch_dtype,
    # use_cache) must not leak back into the caller's `args.model_init_kwargs`.
    model_init_kwargs = dict(args.model_init_kwargs or {})
    model_init_kwargs["attn_implementation"] = attn_implementation
    if isinstance(model, str):
        model_id = model
        torch_dtype = model_init_kwargs.get("torch_dtype")
        if (
            isinstance(torch_dtype, torch.dtype)
            or torch_dtype == "auto"
            or torch_dtype is None
        ):
            pass  # torch_dtype is already a torch.dtype or "auto" or None
        elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
            torch_dtype = getattr(torch, torch_dtype)
            model_init_kwargs["torch_dtype"] = torch_dtype
        else:
            raise ValueError(
                "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
            )
        # Disable caching if gradient checkpointing is enabled (not supported)
        model_init_kwargs["use_cache"] = (
            False
            if args.gradient_checkpointing
            else model_init_kwargs.get("use_cache")
        )
        if "Qwen2-VL" in model_id:
            model = Qwen2VLForConditionalGeneration.from_pretrained(
                model, **model_init_kwargs
            )
        elif "Aria" in model_id:
            # `use_cache` is not forwarded when loading Aria models.
            model_init_kwargs.pop("use_cache")
            model = AriaForConditionalGeneration.from_pretrained(
                model, **model_init_kwargs
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
    else:
        model_id = model.config._name_or_path
        if args.model_init_kwargs is not None:
            raise ValueError(
                "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
                "This argument can only be used when the `model` argument is a string."
            )

    if peft_config is not None:
        # `get_peft_model` is only imported when PEFT is installed; fail with a clear message otherwise.
        if not is_peft_available():
            raise ImportError(
                "A `peft_config` was passed but PEFT is not installed. Install it with `pip install peft`."
            )
        model = get_peft_model(model, peft_config)

    # Reference model
    if is_deepspeed_zero3_enabled():
        if "Qwen2-VL" in model_id:
            self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(
                model_id, **model_init_kwargs
            )
        elif "Aria" in model_id:
            self.ref_model = AriaForConditionalGeneration.from_pretrained(
                model_id, **model_init_kwargs
            )
        else:
            self.ref_model = AutoModelForCausalLM.from_pretrained(
                model_id, **model_init_kwargs
            )
    elif peft_config is None:
        # If PEFT configuration is not provided, create a reference model based on the initial model.
        self.ref_model = create_reference_model(model)
    else:
        # If PEFT is used, the reference model is not needed since the adapter can be disabled
        # to revert to the initial model.
        self.ref_model = None

    # Processing class
    if processing_class is None:
        if "Qwen2-VL" in model_id or "Aria" in model_id:
            processing_class = AutoProcessor.from_pretrained(model_id)
            # Mirror the tokenizer's special-token ids on the processor; `compute_loss`
            # reads `processing_class.eos_token_id`.
            processing_class.pad_token_id = processing_class.tokenizer.pad_token_id
            processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
            if "Qwen2-VL" in model_id:
                processing_class.image_processor.max_pixels = max_pixels
                processing_class.image_processor.min_pixels = min_pixels
        else:
            processing_class = AutoTokenizer.from_pretrained(
                model.config._name_or_path, padding_side="left"
            )
    # Resolve the pad id for every path. It used to be assigned only inside the branch
    # above, so a caller-supplied `processing_class` crashed with UnboundLocalError below.
    # Processors expose the tokenizer as `.tokenizer`; plain tokenizers carry the id directly.
    pad_token_id = getattr(processing_class, "tokenizer", processing_class).pad_token_id

    # Reward functions
    if not isinstance(reward_funcs, list):
        reward_funcs = [reward_funcs]
    for i, reward_func in enumerate(reward_funcs):
        if isinstance(reward_func, str):
            reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
                reward_func, num_labels=1, **model_init_kwargs
            )
    self.reward_funcs = reward_funcs

    # Reward processing class
    if reward_processing_classes is None:
        reward_processing_classes = [None] * len(reward_funcs)
    elif not isinstance(reward_processing_classes, list):
        reward_processing_classes = [reward_processing_classes]
    else:
        if len(reward_processing_classes) != len(reward_funcs):
            raise ValueError(
                "The number of reward processing classes must match the number of reward functions."
            )

    for i, (reward_processing_class, reward_func) in enumerate(
        zip(reward_processing_classes, reward_funcs)
    ):
        if isinstance(reward_func, PreTrainedModel):
            if reward_processing_class is None:
                reward_processing_class = AutoTokenizer.from_pretrained(
                    reward_func.config._name_or_path
                )
            if reward_processing_class.pad_token_id is None:
                reward_processing_class.pad_token = (
                    reward_processing_class.eos_token
                )
            # The reward model computes the reward for the latest non-padded token in the input sequence.
            # So it's important to set the pad token ID to the padding token ID of the processing class.
            reward_func.config.pad_token_id = reward_processing_class.pad_token_id
            reward_processing_classes[i] = reward_processing_class
    self.reward_processing_classes = reward_processing_classes

    # Data collator
    def data_collator(features):  # No data collation is needed in GRPO
        return features

    # Training arguments
    self.max_prompt_length = args.max_prompt_length
    self.max_completion_length = (
        args.max_completion_length
    )  # = |o_i| in the GRPO paper
    self.num_generations = args.num_generations  # = G in the GRPO paper
    self.generation_config = GenerationConfig(
        max_new_tokens=self.max_completion_length,
        do_sample=True,
        temperature=1,  # HACK
        num_return_sequences=self.num_generations,
        pad_token_id=pad_token_id,
    )
    self.beta = args.beta

    # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
    # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
    # "input_ids" key. Instead, the available key is "prompt". As a result, the trainer issues the warning:
    # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
    # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
    # This acts as a flag to indicate that the warning has already been issued.
    model.warnings_issued["estimate_tokens"] = True

    # Initialize the metrics
    self._metrics = defaultdict(list)

    super().__init__(
        model=model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        callbacks=callbacks,
        optimizers=optimizers,
    )

    # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
    # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
    # self.model_accepts_loss_kwargs to False to enable scaling.
    self.model_accepts_loss_kwargs = False

    if self.ref_model is not None:
        if self.is_deepspeed_enabled:
            self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
        else:
            self.ref_model = self.accelerator.prepare_model(
                self.ref_model, evaluation_mode=True
            )

    for i, reward_func in enumerate(self.reward_funcs):
        if isinstance(reward_func, PreTrainedModel):
            self.reward_funcs[i] = self.accelerator.prepare_model(
                reward_func, evaluation_mode=True
            )
def _set_signature_columns_if_needed(self):
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
# By default, this method sets `self._signature_columns` to the model's expected inputs.
# In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
# Instead, we set them to the columns expected by the `training_step` method, hence the override.
if self._signature_columns is None:
self._signature_columns = ["prompt"]
# Trainer "prepares" the inputs before calling `compute_loss`. It converts to tensor and move to device.
# Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
def _prepare_inputs(
self, inputs: dict[str, Union[torch.Tensor, Any]]
) -> dict[str, Union[torch.Tensor, Any]]:
return inputs
def compute_loss(
    self, model, inputs, return_outputs=False, num_items_in_batch=None
):
    """Sample completions for every prompt, score them, and return the GRPO loss.

    Steps: tokenize prompts together with their images, sample `num_generations`
    completions per prompt, compute per-token log-probabilities under the policy and
    the reference model, score the completions with every reward function, normalise
    the rewards within each prompt's group, and combine the advantage term with the
    KL penalty (weighted by `self.beta`). Metrics are appended to `self._metrics`.

    Args:
        model: The policy model being trained.
        inputs: List of dataset examples (dicts). Each must provide `"prompt"` and
            `"image"`; every other key is forwarded to custom reward functions.
        return_outputs: Must be False.
        num_items_in_batch: Accepted for interface compatibility; unused.

    Returns:
        The scalar loss tensor.

    Raises:
        ValueError: If `return_outputs` is True.
    """
    if return_outputs:
        raise ValueError("The GRPOTrainer does not support returning outputs")

    prompts = [x["prompt"] for x in inputs]
    prompts_text = [
        maybe_apply_chat_template(example, self.processing_class)["prompt"]
        for example in inputs
    ]
    images = [x["image"] for x in inputs]
    prompt_inputs = self.processing_class(
        text=prompts_text,
        images=images,
        return_tensors="pt",
        padding=True,
        padding_side="left",
        add_special_tokens=False,
    )
    prompt_inputs = super()._prepare_inputs(prompt_inputs)

    if self.max_prompt_length is not None:
        # Prompts are left-padded, so keeping the last tokens keeps the end of the prompt.
        prompt_inputs["input_ids"] = prompt_inputs["input_ids"][
            :, -self.max_prompt_length :
        ]
        prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][
            :, -self.max_prompt_length :
        ]

    # Generate completions
    with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
        # Sample one completion per prompt, `num_generations` times, instead of asking
        # `generate` for several return sequences in a single call.
        num_generations = self.generation_config.num_return_sequences
        temp_generation_config = copy.deepcopy(self.generation_config)
        temp_generation_config.num_return_sequences = 1

        all_completions = [
            unwrapped_model.generate(
                **prompt_inputs, generation_config=temp_generation_config
            )
            for _ in range(num_generations)
        ]

        # Right-pad every round to the longest one. The pad id comes from the generation
        # config (set in `__init__`), which works for processors and plain tokenizers alike.
        pad_token_id = self.generation_config.pad_token_id
        max_length = max(completion.size(1) for completion in all_completions)
        padded_completions = []
        for completion in all_completions:
            missing = max_length - completion.size(1)
            if missing > 0:
                padding = torch.full(
                    (completion.size(0), missing),
                    pad_token_id,
                    dtype=completion.dtype,
                    device=completion.device,
                )
                completion = torch.cat([completion, padding], dim=1)
            padded_completions.append(completion)

        # Each round has shape (B, L). Stack to (B, G, L) and flatten to (B*G, L) so that the
        # G completions of one prompt are adjacent. The prompt repetition, the reward kwargs and
        # `rewards.view(-1, G)` below all assume this prompt-major order; concatenating along
        # dim 0 produced generation-major order and mispaired rewards whenever B > 1.
        prompt_completion_ids = torch.stack(padded_completions, dim=1).flatten(0, 1)

    prompt_length = prompt_inputs["input_ids"].size(1)
    completion_ids = prompt_completion_ids[:, prompt_length:]

    # Get the per-token log probabilities for the completions for the model and the reference model
    # NOTE(review): the forward pass below receives only token ids — no attention mask and none of the
    # image tensors from `prompt_inputs`. Confirm this is intended for vision-language models.
    def get_per_token_logps(model, input_ids):
        logits = model(input_ids).logits  # (B, L, V)
        logits = logits[
            :, :-1, :
        ]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
        input_ids = input_ids[
            :, 1:
        ]  # (B, L-1), exclude the first input ID since we don't have logits for it
        # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
        per_token_logps = []
        for logits_row, input_ids_row in zip(logits, input_ids):
            log_probs = logits_row.log_softmax(dim=-1)
            token_log_prob = torch.gather(
                log_probs, dim=1, index=input_ids_row.unsqueeze(1)
            ).squeeze(1)
            per_token_logps.append(token_log_prob)
        return torch.stack(per_token_logps)

    per_token_logps = get_per_token_logps(model, prompt_completion_ids)
    # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
    per_token_logps = per_token_logps[:, prompt_length - 1 :]

    with torch.inference_mode():
        if self.ref_model is not None:
            ref_per_token_logps = get_per_token_logps(
                self.ref_model, prompt_completion_ids
            )
        else:
            # No separate reference model: disable the adapter to recover the initial model.
            with self.accelerator.unwrap_model(model).disable_adapter():
                ref_per_token_logps = get_per_token_logps(
                    model, prompt_completion_ids
                )
    ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]

    # Compute the KL divergence between the model and the reference model
    per_token_kl = (
        torch.exp(ref_per_token_logps - per_token_logps)
        - (ref_per_token_logps - per_token_logps)
        - 1
    )

    # Mask everything after the first EOS token
    is_eos = completion_ids == self.processing_class.eos_token_id
    device = self.accelerator.device
    eos_idx = torch.full(
        (is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device
    )
    eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
    sequence_indices = torch.arange(is_eos.size(1), device=device).expand(
        is_eos.size(0), -1
    )
    completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

    # Decode the generated completions
    completions = self.processing_class.batch_decode(
        completion_ids, skip_special_tokens=True
    )
    if is_conversational(inputs[0]):
        completions = [
            [{"role": "assistant", "content": completion}]
            for completion in completions
        ]

    # Compute the rewards
    # Repeat each prompt G times (prompt-major) to line up with the completions.
    prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]

    rewards_per_func = torch.zeros(
        len(prompts), len(self.reward_funcs), device=device
    )
    for i, (reward_func, reward_processing_class) in enumerate(
        zip(self.reward_funcs, self.reward_processing_classes)
    ):
        if isinstance(reward_func, PreTrainedModel):
            if is_conversational(inputs[0]):
                messages = [
                    {"messages": p + c} for p, c in zip(prompts, completions)
                ]
                texts = [
                    apply_chat_template(x, reward_processing_class)["text"]
                    for x in messages
                ]
            else:
                texts = [p + c for p, c in zip(prompts, completions)]
            reward_inputs = reward_processing_class(
                texts,
                return_tensors="pt",
                padding=True,
                padding_side="right",
                add_special_tokens=False,
            )
            reward_inputs = super()._prepare_inputs(reward_inputs)
            with torch.inference_mode():
                rewards_per_func[:, i] = reward_func(**reward_inputs).logits[
                    :, 0
                ]  # Shape (B*G,)
        else:
            # Repeat all input columns (but "prompt" and "completion") to match the number of generations
            reward_kwargs = {
                key: []
                for key in inputs[0].keys()
                if key not in ["prompt", "completion"]
            }
            for key in reward_kwargs:
                for example in inputs:
                    # Repeat each value in the column for `num_generations` times
                    reward_kwargs[key].extend([example[key]] * self.num_generations)
            output_reward_func = reward_func(
                prompts=prompts, completions=completions, **reward_kwargs
            )
            rewards_per_func[:, i] = torch.tensor(
                output_reward_func, dtype=torch.float32, device=device
            )

    # Sum the rewards from all reward functions
    rewards = rewards_per_func.sum(dim=1)

    # Compute grouped-wise rewards
    mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
    std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

    # Normalize the rewards to compute the advantages
    mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
        self.num_generations, dim=0
    )
    std_grouped_rewards = std_grouped_rewards.repeat_interleave(
        self.num_generations, dim=0
    )
    advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

    # x - x.detach() allows for preserving gradients from x
    per_token_loss = torch.exp(
        per_token_logps - per_token_logps.detach()
    ) * advantages.unsqueeze(1)
    per_token_loss = -(per_token_loss - self.beta * per_token_kl)
    loss = (
        (per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)
    ).mean()

    # Log the metrics
    completion_length = (
        self.accelerator.gather_for_metrics(completion_mask.sum(1))
        .float()
        .mean()
        .item()
    )
    self._metrics["completion_length"].append(completion_length)

    reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
    for i, reward_func in enumerate(self.reward_funcs):
        if isinstance(reward_func, PreTrainedModel):
            reward_func_name = reward_func.config._name_or_path.split("/")[-1]
        else:
            reward_func_name = reward_func.__name__
        self._metrics[f"rewards/{reward_func_name}"].append(
            reward_per_func[i].item()
        )

    self._metrics["reward"].append(
        self.accelerator.gather_for_metrics(rewards).mean().item()
    )
    self._metrics["reward_std"].append(
        self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item()
    )

    mean_kl = (
        (per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)
    ).mean()
    self._metrics["kl"].append(
        self.accelerator.gather_for_metrics(mean_kl).mean().item()
    )

    return loss
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
    """Merge the averaged buffered metrics into ``logs`` and forward them to the parent logger.

    Every list stored in ``self._metrics`` is reduced to its arithmetic mean,
    the means are layered on top of ``logs`` (buffered metrics win on key
    clashes), and the buffer is emptied afterwards.

    Args:
        logs: Values the caller wants to log.
        start_time: Passed through to the parent ``log`` only when the
            installed ``transformers`` is at least ``4.47.0.dev0``.
    """
    averaged = {}
    for metric_name, values in self._metrics.items():
        averaged[metric_name] = sum(values) / len(values)

    merged = dict(logs)
    merged.update(averaged)

    # The parent signature gained ``start_time`` in transformers 4.47.
    parent_takes_start_time = version.parse(transformers.__version__) >= version.parse(
        "4.47.0.dev0"
    )
    if parent_takes_start_time:
        super().log(merged, start_time)
    else:
        super().log(merged)
    self._metrics.clear()
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    Only the world-zero process writes the card; every other process returns
    immediately. The card is saved as `README.md` inside `self.args.output_dir`.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card. The caller's list is
            never modified.
    """
    if not self.is_world_process_zero():
        return

    # A local directory path is not a meaningful "base model" reference.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(
        self.model.config._name_or_path
    ):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    # Normalise to a fresh list so that appending below cannot leak into the
    # list object the caller passed in.
    if not tags:
        tags = []
    elif isinstance(tags, str):
        tags = [tags]
    else:
        tags = list(tags)

    if hasattr(self.model.config, "unsloth_version"):
        tags.append("unsloth")

    # The entry is closed with "}" so the emitted BibTeX is well-formed.
    citation = textwrap.dedent(
        """\
        @article{zhihong2024deepseekmath,
            title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
            author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
            year = 2024,
            eprint = {arXiv:2402.03300},
        }
        """
    )

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=(
            wandb.run.get_url()
            if is_wandb_available() and wandb.run is not None
            else None
        ),
        comet_url=get_comet_experiment_url(),
        trainer_name="GRPO",
        trainer_citation=citation,
        paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
        paper_id="2402.03300",
    )
    model_card.save(os.path.join(self.args.output_dir, "README.md"))
+339
View File
@@ -0,0 +1,339 @@
"""
Code referenced from InternVL mDPO
"""
from copy import deepcopy
from typing import Dict, List, Literal, Optional, Tuple, Union
import deepspeed
import torch
from torch import nn
from torch.utils.data import ConcatDataset
from trl import DPOTrainer
from trl.trainer.utils import RunningMoments, pad_to_length
def _map(self, *args, **kwargs):
return self
ConcatDataset.map = _map
class MultimodalDPOTrainer(DPOTrainer):
    def __init__(self, *args, **kwargs):
        """Initialise the parent trainer, then add running moments for combined BCO losses.

        ``self.running`` is created only when ``loss_type`` mentions
        ``"bco_pair"`` without being exactly ``"bco_pair"``.
        """
        super().__init__(*args, **kwargs)
        is_plain_bco = self.loss_type == "bco_pair"
        if not is_plain_bco and "bco_pair" in self.loss_type:
            self.running = RunningMoments(self.accelerator)
@staticmethod
def concatenated_inputs(
    batch: Dict[str, Union[List, torch.LongTensor]],
    is_encoder_decoder: bool = False,
    is_vision_model: bool = False,
    label_pad_token_id: int = -100,
    padding_value: int = 0,
    device: Optional[torch.device] = None,
) -> Dict[str, torch.LongTensor]:
    """Stack the chosen and rejected tensors of ``batch`` along dim 0.

    Every tensor whose key starts with ``chosen``/``rejected`` is padded to a
    common length and stored under the same key with the prefix replaced by
    ``concatenated`` (chosen rows first, rejected rows after).

    Args:
        batch: A batch of data. Must contain the keys 'chosen_input_ids' and
            'rejected_input_ids' (or the '*_labels' keys for encoder-decoder models).
        is_encoder_decoder: Whether the model is an encoder-decoder model.
        is_vision_model: Accepted for signature compatibility; not read here.
        label_pad_token_id: The label pad token id.
        padding_value: The padding value to use for the concatenated inputs_ids.
        device: The device for the concatenated inputs.

    Returns:
        A dictionary with the ``concatenated_*`` tensors, plus ``pixel_values``
        and ``image_flags`` (each repeated twice) when the batch has
        ``pixel_values``.
    """
    length_field = "labels" if is_encoder_decoder else "input_ids"
    max_length = max(
        batch[f"chosen_{length_field}"].shape[1],
        batch[f"rejected_{length_field}"].shape[1],
    )

    concatenated_batch = {}
    # All chosen keys are handled before any rejected key, so each rejected
    # tensor is appended below its already-stored chosen counterpart.
    for prefix in ("chosen", "rejected"):
        for key in batch:
            tensor = batch[key]
            if not (key.startswith(prefix) and isinstance(tensor, torch.Tensor)):
                continue
            if "labels" in key or is_encoder_decoder:
                pad_value = label_pad_token_id
            elif key.endswith("_input_ids"):
                pad_value = padding_value
            elif key.endswith("_attention_mask"):
                pad_value = 0
            target_key = key.replace(prefix, "concatenated")
            if prefix == "chosen":
                concatenated_batch[target_key] = pad_to_length(
                    tensor, max_length, pad_value=pad_value
                )
            else:
                chosen_part = concatenated_batch[target_key]
                rejected_part = pad_to_length(tensor, max_length, pad_value=pad_value)
                concatenated_batch[target_key] = torch.cat(
                    (chosen_part, rejected_part), dim=0
                ).to(device=device)

    if is_encoder_decoder:
        # The prompt is shared by both halves, so it is simply duplicated.
        for field in ("input_ids", "attention_mask"):
            concatenated_batch[f"concatenated_{field}"] = (
                batch[f"prompt_{field}"].repeat(2, 1).to(device=device)
            )

    if "pixel_values" in batch:
        concatenated_batch["pixel_values"] = batch["pixel_values"].repeat(2, 1, 1, 1)
        concatenated_batch["image_flags"] = batch["image_flags"].repeat(2)
    return concatenated_batch
def concatenated_forward(
    self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
) -> Tuple[torch.FloatTensor, ...]:
    """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

    We do this to avoid doing two forward passes, because it's faster for FSDP.

    Args:
        model: The module to run; called with keyword arguments only.
        batch: Batch holding the ``chosen_*`` / ``rejected_*`` tensors together
            with ``pixel_values`` and ``image_flags``.

    Returns:
        ``(chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)``
        and, when ``self.aux_loss_enabled`` is set, ``outputs.aux_loss`` as a
        sixth element.
    """
    concatenated_batch = self.concatenated_inputs(
        batch,
        is_encoder_decoder=self.is_encoder_decoder,
        is_vision_model=self.is_vision_model,
        label_pad_token_id=self.label_pad_token_id,
        padding_value=self.padding_value,
        device=self.accelerator.device,
    )
    len_chosen = batch["chosen_labels"].shape[0]

    # Each keyword is set in exactly one place so the model call can never
    # receive the same argument twice.
    model_kwargs = {
        "pixel_values": concatenated_batch["pixel_values"],
        "image_flags": concatenated_batch["image_flags"],
    }
    if self.is_encoder_decoder:
        model_kwargs["labels"] = concatenated_batch["concatenated_labels"]
        model_kwargs["decoder_input_ids"] = concatenated_batch.pop(
            "concatenated_decoder_input_ids", None
        )
    # concatenated_inputs does not produce a pixel attention mask, so it is
    # forwarded only when something else put one into the batch.
    if self.is_vision_model and "pixel_attention_mask" in concatenated_batch:
        model_kwargs["pixel_attention_mask"] = concatenated_batch[
            "pixel_attention_mask"
        ]
    if self.aux_loss_enabled:
        model_kwargs["output_router_logits"] = True

    outputs = model(
        input_ids=concatenated_batch["concatenated_input_ids"],
        attention_mask=concatenated_batch["concatenated_attention_mask"],
        use_cache=False,
        **model_kwargs,
    )
    all_logits = outputs.logits

    # Summed log-probs; the IPO length normalisation is applied further down.
    all_logps, size_completion = self.get_batch_logps(
        all_logits,
        concatenated_batch["concatenated_labels"],
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
    )

    def cross_entropy_loss(logits, labels):
        if not self.is_encoder_decoder:
            # Shift so that tokens < n predict n
            logits = logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()
        # Positions padded with the label pad id must not contribute, also
        # when that id differs from the CrossEntropyLoss default.
        loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_token_id)
        # Flatten the tokens
        logits = logits.view(-1, logits.shape[-1])
        labels = labels.view(-1)
        # Enable model parallelism
        labels = labels.to(logits.device)
        return loss_fct(logits, labels)

    labels = concatenated_batch["concatenated_labels"].clone()
    # The NLL term is computed on the chosen half only.
    nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

    if self.loss_type == "ipo":
        all_logps = all_logps / size_completion

    chosen_logps = all_logps[:len_chosen]
    rejected_logps = all_logps[len_chosen:]
    chosen_logits = all_logits[:len_chosen]
    rejected_logits = all_logits[len_chosen:]

    result = (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
    if self.aux_loss_enabled:
        return result + (outputs.aux_loss,)
    return result
def _prepare_deepspeed_orig(self, model):
    """Wrap ``model`` in a DeepSpeed engine for evaluation-only use.

    Works on a deep copy of the accelerator's DeepSpeed config. Stage 3 is
    kept as is; any other ZeRO stage is replaced by stage 0.
    Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473

    Returns:
        The initialised engine, switched to eval mode.
    """
    plugin = self.accelerator.state.deepspeed_plugin
    ds_config = deepcopy(plugin.deepspeed_config)

    zero_section = ds_config["zero_optimization"]
    # If ZeRO-3 is used, we shard both the active and reference model.
    # Otherwise the reference model is assumed to fit in memory and is
    # initialised on each device with ZeRO disabled (stage 0).
    if zero_section["stage"] != 3:
        zero_section["stage"] = 0

    engine, *_ = deepspeed.initialize(model=model, config=ds_config)
    engine.eval()
    return engine
def _prepare_deepspeed(self, model):
    """Prepare ``model`` for inference, using DeepSpeed only under ZeRO stage 3.

    With stage 3 the work is delegated to ``_prepare_deepspeed_orig``. For any
    other stage the parameters are frozen, the model is put in eval mode and
    moved to the accelerator device without a DeepSpeed engine.
    """
    ds_config = self.accelerator.state.deepspeed_plugin.deepspeed_config
    uses_zero3 = ds_config["zero_optimization"]["stage"] == 3
    if uses_zero3:
        print("Enable DPOTrainer._prepare_deepspeed")
        return self._prepare_deepspeed_orig(model)

    print("Disable DPOTrainer._prepare_deepspeed")
    for weight in model.parameters():
        weight.requires_grad = False
    model.eval()
    return model.to(self.accelerator.device)
def get_batch_loss_metrics(
    self,
    model,
    batch: Dict[str, Union[List, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the DPO loss and other metrics for the given batch of inputs for train or test.

    Args:
        model: Policy model passed to ``concatenated_forward``.
        batch: Batch with chosen/rejected inputs; may carry precomputed
            ``reference_chosen_logps`` / ``reference_rejected_logps``.
        train_eval: ``"eval"`` prefixes every metric key with ``eval_``.

    Returns:
        ``(loss, metrics)`` where ``loss`` is the mean loss (plus the weighted
        router auxiliary loss when ``self.aux_loss_enabled``) and ``metrics``
        maps metric names to CPU tensors.
    """
    metrics = {}

    forward_output = self.concatenated_forward(model, batch)
    (
        policy_chosen_logps,
        policy_rejected_logps,
        policy_chosen_logits,
        policy_rejected_logits,
        policy_nll_loss,
    ) = forward_output[:5]
    if self.aux_loss_enabled:
        aux_loss = forward_output[5]

    # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
    if (
        "reference_chosen_logps" in batch
        and "reference_rejected_logps" in batch
        and self.args.rpo_alpha is not None
    ):
        reference_chosen_logps = batch["reference_chosen_logps"]
        reference_rejected_logps = batch["reference_rejected_logps"]
    else:
        # Only the two log-prob tensors are needed. Slicing keeps the unpack
        # valid whether the forward pass returns 5 values or 6 (aux loss).
        with torch.no_grad():
            if self.ref_model is None:
                with self.null_ref_context():
                    reference_chosen_logps, reference_rejected_logps = (
                        self.concatenated_forward(self.model, batch)[:2]
                    )
            else:
                reference_chosen_logps, reference_rejected_logps = (
                    self.concatenated_forward(self.ref_model, batch)[:2]
                )

    if "," in self.loss_type:
        # Several comma-separated losses: evaluate each by temporarily
        # switching ``self.loss_type`` and combine with per-loss weights.
        combined_loss_type = self.loss_type
        losses, chosen_rewards, rejected_rewards = 0, 0, 0
        try:
            for curr_type in combined_loss_type.split(","):
                self.loss_type = curr_type
                curr_losses, curr_chosen_rewards, curr_rejected_rewards = self.dpo_loss(
                    policy_chosen_logps,
                    policy_rejected_logps,
                    reference_chosen_logps,
                    reference_rejected_logps,
                )
                curr_weight = getattr(self.args, f"{curr_type}_loss_weight")
                losses = losses + curr_losses * curr_weight
                chosen_rewards = chosen_rewards + curr_chosen_rewards * curr_weight
                rejected_rewards = (
                    rejected_rewards + curr_rejected_rewards * curr_weight
                )
        finally:
            # Restore the configured value even if a loss computation fails.
            self.loss_type = combined_loss_type
    else:
        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
        )

    reward_accuracies = (chosen_rewards > rejected_rewards).float()

    use_nll = self.args.rpo_alpha is not None
    if use_nll:
        losses = losses + policy_nll_loss * self.args.rpo_alpha

    prefix = "eval_" if train_eval == "eval" else ""
    metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
    metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
    metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
    metrics[f"{prefix}rewards/margins"] = (
        (chosen_rewards - rejected_rewards).mean().cpu()
    )
    metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
    metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
    metrics[f"{prefix}logits/rejected"] = (
        policy_rejected_logits.detach().mean().cpu()
    )
    metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
    if use_nll:
        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()

    if self.aux_loss_enabled:
        return (
            losses.mean()
            + getattr(model.config, "router_aux_loss_coef", 0.0) * aux_loss,
            metrics,
        )
    return losses.mean(), metrics
+572
View File
@@ -0,0 +1,572 @@
import os
import torch
import torch.nn as nn
from torch.utils.data import Sampler
from namo.utils.utils import rank0_print
try:
from trl.trainer import DPOTrainer
from trl.trainer.utils import DPODataCollatorWithPadding
except ImportError as e:
DPOTrainer = object
from transformers import Trainer
from transformers.trainer import (
is_sagemaker_mp_enabled,
get_parameter_names,
has_length,
ALL_LAYERNORM_LAYERS,
logger,
TRAINER_STATE_NAME,
)
from typing import List, Optional
def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU clone of ``param``, gathering it first if it is ZeRO-managed.

    A parameter counts as ZeRO-managed when it has a ``ds_id`` attribute; it is
    then cloned inside ``zero.GatheredParameters``. If its status is
    ``NOT_AVAILABLE`` and ``ignore_status`` is false, ``name`` is printed as a
    diagnostic. Plain tensors are cloned directly.
    """
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if not hasattr(param, "ds_id"):
        return param.detach().cpu().clone()

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        print(name, "no ignore status")
    with zero.GatheredParameters([param]):
        gathered = param.data.detach().cpu().clone()
    return gathered
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    """Collect CPU copies of the parameters whose name contains any of ``keys_to_match``.

    Args:
        named_params: Iterable of ``(name, parameter)`` pairs.
        keys_to_match: Substrings; a parameter is kept when at least one of
            them occurs in its name.

    Returns:
        Dict mapping each kept name to the result of ``maybe_zero_3`` for it.
    """
    selected = {}
    for param_name, tensor in named_params:
        if any(fragment in param_name for fragment in keys_to_match):
            selected[param_name] = tensor

    gathered = {}
    for param_name, tensor in selected.items():
        gathered[param_name] = maybe_zero_3(
            tensor, ignore_status=True, name=param_name
        ).cpu()
    return gathered
def split_to_even_chunks(indices, lengths, num_chunks):
"""
Split a list of indices into `chunks` chunks of roughly equal lengths.
"""
if len(indices) % num_chunks != 0:
return [indices[i::num_chunks] for i in range(num_chunks)]
num_indices_per_chunk = len(indices) // num_chunks
chunks = [[] for _ in range(num_chunks)]
chunks_lengths = [0 for _ in range(num_chunks)]
for index in indices:
shortest_chunk = chunks_lengths.index(min(chunks_lengths))
chunks[shortest_chunk].append(index)
chunks_lengths[shortest_chunk] += lengths[index]
if len(chunks[shortest_chunk]) == num_indices_per_chunk:
chunks_lengths[shortest_chunk] = float("inf")
return chunks
def get_modality_length_grouped_indices(
    lengths, batch_size, world_size, generator=None
):
    """Order sample indices so that megabatches are length-grouped within one modality.

    The sign of each entry in ``lengths`` separates the two groups (positive
    vs. negative); the code names them ``mm`` and ``lang``. When all entries
    share one sign, this is plain ``get_length_grouped_indices``. Otherwise
    each group is length-grouped on its own, cut into megabatches of
    ``world_size * batch_size``, the full megabatches of both groups are
    shuffled together, and the two trailing megabatches are merged, sorted and
    appended last.
    """
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    assert all(l != 0 for l in lengths), "Should not have zero length."
    if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
        # all samples are in the same modality
        return get_length_grouped_indices(
            lengths, batch_size, world_size, generator=generator
        )

    mm_pairs = [(idx, size) for idx, size in enumerate(lengths) if size > 0]
    lang_pairs = [(idx, -size) for idx, size in enumerate(lengths) if size < 0]
    mm_indices, mm_lengths = zip(*mm_pairs)
    lang_indices, lang_lengths = zip(*lang_pairs)

    def grouped(original_indices, sizes):
        # Positions come back relative to ``sizes``; map them to dataset indices.
        order = get_length_grouped_indices(
            sizes, batch_size, world_size, generator=None
        )
        return [original_indices[pos] for pos in order]

    mm_shuffle = grouped(mm_indices, mm_lengths)
    lang_shuffle = grouped(lang_indices, lang_lengths)

    megabatch_size = world_size * batch_size

    def windows(sequence):
        return [
            sequence[start : start + megabatch_size]
            for start in range(0, len(sequence), megabatch_size)
        ]

    mm_megabatches = windows(mm_shuffle)
    lang_megabatches = windows(lang_shuffle)

    leftovers = mm_megabatches[-1] + lang_megabatches[-1]
    full_megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]

    ordered = []
    for pos in torch.randperm(len(full_megabatches), generator=generator):
        ordered.extend(full_megabatches[pos])
    if len(leftovers) > 0:
        ordered.extend(sorted(leftovers))
    return ordered
def get_length_grouped_indices(
    lengths, batch_size, world_size, generator=None, merge=True
):
    """Return a random index order in which each megabatch is sorted and split by length.

    A random permutation of ``range(len(lengths))`` is cut into megabatches of
    ``world_size * batch_size`` indices. Each megabatch is sorted by descending
    length and split into ``world_size`` chunks with ``split_to_even_chunks``;
    the chunks are flattened back in order. ``merge`` is accepted but not used.
    """
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    permutation = torch.randperm(len(lengths), generator=generator)
    megabatch_size = world_size * batch_size

    flat_indices = []
    for start in range(0, len(lengths), megabatch_size):
        window = permutation[start : start + megabatch_size].tolist()
        window.sort(key=lambda i: lengths[i], reverse=True)
        for chunk in split_to_even_chunks(window, lengths, world_size):
            flat_indices.extend(chunk)
    return flat_indices
class LengthGroupedSampler(Sampler):
    r"""
    Sampler that yields dataset indices grouped by similar length, with some
    randomness kept in the ordering.

    With ``group_by_modality`` the order comes from
    ``get_modality_length_grouped_indices``, otherwise from
    ``get_length_grouped_indices``.
    """

    def __init__(
        self,
        batch_size: int,
        world_size: int,
        lengths: Optional[List[int]] = None,
        generator=None,
        group_by_modality: bool = False,
    ):
        """Store the sampling configuration.

        Raises:
            ValueError: If ``lengths`` is ``None``.
        """
        if lengths is None:
            raise ValueError("Lengths must be provided.")

        self.batch_size = batch_size
        self.world_size = world_size
        self.lengths = lengths
        self.generator = generator
        self.group_by_modality = group_by_modality

    def __len__(self):
        """Number of samples, i.e. the number of entries in ``lengths``."""
        return len(self.lengths)

    def __iter__(self):
        """Build a fresh index order and return an iterator over it."""
        if self.group_by_modality:
            build_order = get_modality_length_grouped_indices
        else:
            build_order = get_length_grouped_indices
        order = build_order(
            self.lengths, self.batch_size, self.world_size, generator=self.generator
        )
        return iter(order)
class NamoTrainer(Trainer):
    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        """Pick the training sampler.

        Returns ``None`` when there is no sized training dataset. With
        ``args.group_by_modality_length`` a modality-aware
        ``LengthGroupedSampler`` is built from the dataset's
        ``modality_lengths``; otherwise the parent's sampler is used.
        """
        dataset = self.train_dataset
        if dataset is None or not has_length(dataset):
            return None

        if not self.args.group_by_modality_length:
            return super()._get_train_sampler()

        lengths = dataset.modality_lengths
        # The gradient-accumulation factor is folded into the world size
        # handed to the sampler.
        effective_world_size = (
            self.args.world_size * self.args.gradient_accumulation_steps
        )
        return LengthGroupedSampler(
            self.args.train_batch_size,
            world_size=effective_world_size,
            lengths=lengths,
            group_by_modality=True,
        )
def create_optimizer(self):
    """
    Setup the optimizer.

    We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
    Trainer's init through `optimizers`, or subclass and override this method in a subclass.

    Trainable parameters are split by two independent criteria:

    * weight decay: names returned by ``get_parameter_names`` that do not
      contain ``"bias"`` get ``args.weight_decay``, everything else ``0.0``;
    * learning rate: names containing ``"conn_ve_llm"`` use
      ``args.conn_ve_llm_lr`` and names containing ``".ve."`` use
      ``args.ve_lr`` (each only when that argument is not ``None``; the
      connector match wins when both apply). All remaining parameters use the
      optimizer's default learning rate.

    Every parameter ends up in exactly one group. Groups are emitted in a
    fixed order for each combination of the two learning-rate arguments, and
    a group is emitted even when it is empty.

    Returns:
        The optimizer stored in ``self.optimizer``.
    """
    if is_sagemaker_mp_enabled():
        return super().create_optimizer()

    opt_model = self.model

    if self.optimizer is None:
        # Sets give O(1) membership tests while bucketing the parameters.
        decay_names = {
            name
            for name in get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            if "bias" not in name
        }
        conn_lr = self.args.conn_ve_llm_lr
        ve_lr = self.args.ve_lr
        weight_decay = self.args.weight_decay

        # (learning-rate family, uses weight decay) for every emitted group.
        if conn_lr is not None and ve_lr is not None:
            layout = [
                ("base", True),
                ("ve", True),
                ("base", False),
                ("ve", False),
                ("conn", True),
                ("conn", False),
            ]
        elif conn_lr is not None:
            layout = [("base", True), ("base", False), ("conn", True), ("conn", False)]
        elif ve_lr is not None:
            # The non-vision parameters belong in the two trailing groups;
            # selecting the vision parameters again there would duplicate
            # them and leave the rest of the model without an optimizer.
            layout = [("ve", True), ("ve", False), ("base", True), ("base", False)]
        else:
            layout = [("base", True), ("base", False)]

        family_lr = {"conn": conn_lr, "ve": ve_lr}
        buckets = {slot: [] for slot in layout}
        for name, param in opt_model.named_parameters():
            if not param.requires_grad:
                continue
            if conn_lr is not None and "conn_ve_llm" in name:
                family = "conn"
            elif ve_lr is not None and ".ve." in name:
                family = "ve"
            else:
                family = "base"
            buckets[(family, name in decay_names)].append(param)

        optimizer_grouped_parameters = []
        for family, uses_decay in layout:
            group = {
                "params": buckets[(family, uses_decay)],
                "weight_decay": weight_decay if uses_decay else 0.0,
            }
            if family in family_lr:
                group["lr"] = family_lr[family]
            optimizer_grouped_parameters.append(group)

        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
            self.args
        )
        self.optimizer = optimizer_cls(
            optimizer_grouped_parameters, **optimizer_kwargs
        )

        if optimizer_cls.__name__ == "Adam8bit":
            import bitsandbytes

            manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

            skipped = 0
            for module in opt_model.modules():
                if isinstance(module, nn.Embedding):
                    # Keyed by data pointer so shared storage is counted once.
                    skipped += sum(
                        {
                            p.data_ptr(): p.numel() for p in module.parameters()
                        }.values()
                    )
                    logger.info(f"skipped {module}: {skipped/2**20}M params")
                    manager.register_module_override(
                        module, "weight", {"optim_bits": 32}
                    )
                    logger.debug(f"bitsandbytes: will optimize {module} in fp32")
            logger.info(f"skipped: {skipped/2**20}M params")

    return self.optimizer
def _save_checkpoint(self, model, trial, metrics=None):
    """Write a checkpoint; in connector-tuning mode also dump the adapter weights.

    Without ``args.tune_conn_ve_llm`` the CUDA cache is emptied and the parent
    implementation is called (``metrics`` is not forwarded). With it, the full
    model is saved for resuming, the parameters matching the adapter keys are
    written to ``conn_ve_llm.bin``, and the trainer state is saved before old
    checkpoints are rotated.
    """
    if not getattr(self.args, "tune_conn_ve_llm", False):
        # call empty cache here?
        torch.cuda.empty_cache()
        super(NamoTrainer, self)._save_checkpoint(model, trial)
        return

    from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

    run_dir = self._get_output_dir(trial=trial)
    output_dir = os.path.join(
        run_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
    )
    rank0_print(f"saving conn_ve_llm to: {output_dir}")

    # save all for mm adaptor resume
    self.save_model(output_dir, _internal_call=True)

    # Only save Adapter
    keys_to_match = ["conn_ve_llm", "vision_resampler"]
    if getattr(self.args, "use_im_start_end", False):
        keys_to_match += ["embed_tokens", "embed_in"]
    weight_to_save = get_mm_adapter_state_maybe_zero_3(
        self.model.named_parameters(), keys_to_match
    )

    if self.args.local_rank in (0, -1):
        self.model.config.save_pretrained(output_dir)
        torch.save(weight_to_save, os.path.join(output_dir, "conn_ve_llm.bin"))

    # NOTE(review): the steps below are placed outside the rank guard here;
    # the source's nesting is not visible in this view — confirm.
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
    self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
    self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)
def _save(self, output_dir: Optional[str] = None, state_dict=None):
    """Delegate to the parent ``_save`` unless ``args.tune_conn_ve_llm`` is set, in which case do nothing."""
    if getattr(self.args, "tune_conn_ve_llm", False):
        return
    super(NamoTrainer, self)._save(output_dir, state_dict)
class NamoDPOTrainer(DPOTrainer):
    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        """Pick the training sampler.

        Returns ``None`` when there is no sized training dataset. With
        ``args.group_by_modality_length`` a modality-aware
        ``LengthGroupedSampler`` is built from the dataset's
        ``modality_lengths``; otherwise the parent's sampler is used.
        """
        dataset = self.train_dataset
        if dataset is None or not has_length(dataset):
            return None

        if not self.args.group_by_modality_length:
            return super()._get_train_sampler()

        lengths = dataset.modality_lengths
        # TODO: seems that we should not have gradient_accumulation_steps in
        # the batch size, so the plain train batch size is used.
        return LengthGroupedSampler(
            self.args.train_batch_size,
            world_size=self.args.world_size,
            lengths=lengths,
            group_by_modality=True,
        )
def _save_checkpoint(self, model, trial, metrics=None):
    """Write a checkpoint in one of three ways, depending on the training arguments.

    * Adapter-only tuning (``tune_conn_ve_llm``, or ``mm_tunable_parts`` naming
      exactly one part that mentions ``mm_mlp_adapter`` or
      ``mm_vision_resampler``): save the config and ``conn_ve_llm.bin``.
    * ``args.lora_enable``: hand the unwrapped model to ``save_my_lora_ckpt``.
    * Otherwise: defer to the parent (``metrics`` is not forwarded).
    """

    def adapter_only():
        if getattr(self.args, "tune_conn_ve_llm", False):
            return True
        if not hasattr(self.args, "mm_tunable_parts"):
            return False
        parts = self.args.mm_tunable_parts
        if len(parts.split(",")) != 1:
            return False
        return "mm_mlp_adapter" in parts or "mm_vision_resampler" in parts

    def checkpoint_dir():
        from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

        folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
        return os.path.join(self._get_output_dir(trial=trial), folder)

    if adapter_only():
        output_dir = checkpoint_dir()
        rank0_print(f"saving conn_ve_llm weights: {output_dir}")

        # Only save Adapter
        keys_to_match = ["conn_ve_llm", "vision_resampler"]
        if getattr(self.args, "use_im_start_end", False):
            keys_to_match += ["embed_tokens", "embed_in"]
        weight_to_save = get_mm_adapter_state_maybe_zero_3(
            self.model.named_parameters(), keys_to_match
        )

        if self.args.local_rank in (0, -1):
            self.model.config.save_pretrained(output_dir)
            torch.save(weight_to_save, os.path.join(output_dir, "conn_ve_llm.bin"))
    elif self.args.lora_enable:
        output_dir = checkpoint_dir()
        from transformers.modeling_utils import unwrap_model

        unwrapped_model = unwrap_model(model)
        # NOTE(review): save_my_lora_ckpt is not defined in the visible part
        # of this class — confirm it is provided elsewhere.
        self.save_my_lora_ckpt(output_dir, self.args, unwrapped_model)
    else:
        super(NamoDPOTrainer, self)._save_checkpoint(model, trial)
def _save(self, output_dir: Optional[str] = None, state_dict=None):
    """Delegate to the parent ``_save`` unless ``args.tune_conn_ve_llm`` is set, in which case do nothing."""
    if getattr(self.args, "tune_conn_ve_llm", False):
        return
    super(NamoDPOTrainer, self)._save(output_dir, state_dict)
View File
+682
View File
@@ -0,0 +1,682 @@
import dataclasses
from enum import auto, Enum
from typing import List, Tuple
import base64
from io import BytesIO
from PIL import Image
class SeparatorStyle(Enum):
    """Different separator style."""

    # Values are spelled out; they are the numbers ``auto()`` assigns for
    # this declaration order (1, 2, 3, ...).
    SINGLE = 1
    TWO = 2
    MPT = 3
    PLAIN = 4
    LLAMA_2 = 5
    GEMMA = 6
    LLAMA3 = 7
    LLAMA2 = 8
    PHI3 = 9
@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    The fields describe one prompt template (system text, role labels,
    separators) together with the messages exchanged so far.
    """

    # System text placed at the start of the rendered prompt.
    system: str
    # Role labels; index 0 and index 1 are the two alternating speakers.
    roles: List[str]
    # History as [role, message] entries. A message is either text or a
    # tuple whose first element is the text (image turns).
    messages: List[List[str]]
    # Number of leading messages skipped when collecting images or building
    # the chatbot view.
    offset: int
    # Selects the prompt layout used when rendering the history.
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    # Primary separator.
    sep: str = "###"
    # Secondary separator, used by the layouts that alternate two separators.
    sep2: str = None
    # Template version tag; the prompt builder checks it for "mmtag".
    version: str = "Unknown"
    # presumably the string at which generation should stop — confirm with callers
    stop_str: str = None
    # Stops generation if meeting any token in this list
    stop_token_ids: List[int] = None
    # Flag not read by the visible methods of this class.
    skip_next: bool = False
def get_prompt(self):
    """Render the conversation into a single prompt string for ``self.sep_style``.

    A message may be plain text or a tuple whose first element is the text
    (image turns); only the text is rendered. An empty message renders just
    the role header, leaving the prompt open for the next reply.

    When the first message is a tuple, the SINGLE, TWO, MPT, LLAMA_2, GEMMA
    and PLAIN layouts work on a copy of the history in which its ``<image>``
    tag is moved to the front (or, for "mmtag" versions, replaced by a
    separate ``<Image><image></Image>`` exchange). The LLAMA3, LLAMA2 and
    PHI3 layouts render the stored messages as they are.

    Returns:
        The prompt string.

    Raises:
        ValueError: If ``self.sep_style`` is not a handled style.
    """

    def text_of(message):
        # Image turns are tuples with the text first; accept any tuple length.
        return message[0] if type(message) is tuple else message

    messages = self.messages
    if len(messages) > 0 and type(messages[0][1]) is tuple:
        # Work on a copy so the stored history keeps its original tuple.
        messages = self.messages.copy()
        init_role, init_msg = messages[0].copy()
        init_msg = init_msg[0].replace("<image>", "").strip()
        if "mmtag" in self.version:
            messages[0] = (init_role, init_msg)
            messages.insert(0, (self.roles[0], "<Image><image></Image>"))
            messages.insert(1, (self.roles[1], "Received."))
        else:
            messages[0] = (init_role, "<image>\n" + init_msg)

    style = self.sep_style

    if style == SeparatorStyle.SINGLE:
        ret = self.system + self.sep
        for role, message in messages:
            if message:
                ret += role + ": " + text_of(message) + self.sep
            else:
                ret += role + ":"

    elif style == SeparatorStyle.TWO:
        seps = [self.sep, self.sep2]
        ret = self.system + seps[0]
        for i, (role, message) in enumerate(messages):
            if message:
                ret += role + ": " + text_of(message) + seps[i % 2]
            else:
                ret += role + ":"

    elif style == SeparatorStyle.MPT:
        ret = self.system + self.sep
        for role, message in messages:
            if message:
                ret += role + text_of(message) + self.sep
            else:
                ret += role

    elif style == SeparatorStyle.LLAMA_2:

        def wrap_sys(msg):
            return f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg

        def wrap_inst(msg):
            return f"[INST] {msg} [/INST]"

        ret = ""
        for i, (role, message) in enumerate(messages):
            if i == 0:
                assert message, "first message should not be none"
                assert role == self.roles[0], "first message should come from user"
            if not message:
                continue
            message = text_of(message)
            if i == 0:
                message = wrap_sys(self.system) + message
            if i % 2 == 0:
                ret += self.sep + wrap_inst(message)
            else:
                ret += " " + message + " " + self.sep2
        ret = ret.lstrip(self.sep)

    elif style == SeparatorStyle.GEMMA:
        seps = [self.sep, self.sep2]
        ret = self.system
        for role, message in messages:
            if message:
                ret += seps[0] + role + "\n" + text_of(message) + seps[1] + "\n"
            else:
                ret += seps[0] + role + "\n"

    elif style == SeparatorStyle.PLAIN:
        seps = [self.sep, self.sep2]
        ret = self.system
        for i, (role, message) in enumerate(messages):
            if message:
                ret += text_of(message) + seps[i % 2]

    elif style == SeparatorStyle.LLAMA3:
        ret = "<|begin_of_text|>" + self.system
        for role, message in self.messages:
            ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
            if message:
                # Unwrap image tuples first; a tuple has no ``strip``.
                ret += f"{text_of(message).strip()}<|eot_id|>"

    elif style == SeparatorStyle.LLAMA2:
        seps = [self.sep, self.sep2]
        ret = self.system
        for i, (role, message) in enumerate(self.messages):
            # The tag comes from the turn position, not the stored role.
            tag = self.roles[i % 2]
            if not message:
                ret += tag
            elif i == 0:
                ret += text_of(message) + " "
            else:
                ret += tag + " " + text_of(message) + seps[i % 2]

    elif style == SeparatorStyle.PHI3:
        ret = "<|endoftext|>" + self.system
        for role, message in self.messages:
            if message:
                ret += role + text_of(message) + self.sep
            else:
                ret += role

    else:
        raise ValueError(f"Invalid style: {self.sep_style}")

    return ret
def append_message(self, role, message):
    """Add one ``[role, message]`` entry to the end of the history."""
    entry = [role, message]
    self.messages.append(entry)
def process_image(
    self,
    image,
    image_process_mode,
    return_pil=False,
    image_format="PNG",
    max_len=1344,
    min_len=672,
):
    """Apply the requested preprocessing and size limit to ``image``.

    Args:
        image: Image object exposing ``size``, ``mode``, ``resize`` and ``save``.
        image_process_mode: ``"Pad"`` pads to a square on a fixed background
            colour, ``"Resize"`` resizes to 336x336, ``"Default"`` and
            ``"Crop"`` leave the image as is.
        return_pil: Return the image object instead of a base64 string.
        image_format: Format used when encoding to base64.
        max_len: Images whose longer side exceeds this are shrunk.
        min_len: Upper bound for the shorter side of a shrunk image.

    Returns:
        The processed image, or its base64-encoded bytes as ``str``.

    Raises:
        ValueError: If ``image_process_mode`` is not one of the modes above.
    """
    if image_process_mode == "Pad":
        width, height = image.size
        if width != height:
            # Centre the picture on a square canvas of the longer side.
            side = max(width, height)
            if width > height:
                corner = (0, (width - height) // 2)
            else:
                corner = ((height - width) // 2, 0)
            canvas = Image.new(image.mode, (side, side), (122, 116, 104))
            canvas.paste(image, corner)
            image = canvas
    elif image_process_mode == "Resize":
        image = image.resize((336, 336))
    elif image_process_mode not in ["Default", "Crop"]:
        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")

    if max(image.size) > max_len:
        long_side, short_side = max(image.size), min(image.size)
        aspect_ratio = long_side / short_side
        shortest_edge = int(min(max_len / aspect_ratio, min_len, short_side))
        longest_edge = int(shortest_edge * aspect_ratio)
        current_w, current_h = image.size
        # Keep the orientation: the longer edge stays on the longer axis.
        if current_h > current_w:
            target = (shortest_edge, longest_edge)
        else:
            target = (longest_edge, shortest_edge)
        image = image.resize(target)

    if return_pil:
        return image

    buffer = BytesIO()
    image.save(buffer, format=image_format)
    return base64.b64encode(buffer.getvalue()).decode()
def get_images(self, return_pil=False):
    """Collect the processed images attached to even-indexed turns after ``offset``.

    A turn carries an image when its message is a
    ``(text, image, image_process_mode)`` tuple; each such image is passed
    through ``process_image`` with its own processing mode.
    """
    collected = []
    for index, (_role, content) in enumerate(self.messages[self.offset :]):
        if index % 2 != 0 or type(content) is not tuple:
            continue
        _text, picture, mode = content
        collected.append(self.process_image(picture, mode, return_pil=return_pil))
    return collected
def to_gradio_chatbot(self):
    """Convert the turns after ``offset`` into ``[user, assistant]`` pairs.

    Even-indexed turns open a new pair whose second slot is ``None`` until the
    following odd-indexed turn fills it in. Image-bearing turns get an inline
    ``<img>`` tag with base64 data embedded in the text.
    """

    def _img_tag(img_b64_str):
        return f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'

    history = []
    for index, (_role, content) in enumerate(self.messages[self.offset :]):
        if index % 2 == 0:
            if type(content) is tuple:
                text, picture, _mode = content
                encoded = self.process_image(
                    picture, "Default", return_pil=False, image_format="JPEG"
                )
                content = _img_tag(encoded) + text.replace("<image>", "").strip()
            history.append([content, None])
            continue

        # Odd-indexed turn: complete the pair opened by the previous turn.
        if type(content) is tuple and len(content) == 2:
            text, encoded = content
            content = text.strip() + _img_tag(encoded)
        history[-1][-1] = content
    return history
def copy(self):
    """Return a new ``Conversation`` with the same settings.

    Each ``[role, message]`` entry is rebuilt as a fresh list so appending to
    or editing the copy's turns does not touch this instance. The stop
    criteria (``stop_str`` / ``stop_token_ids``) that templates may define are
    carried over as well, so a copied template keeps its stop configuration.
    """
    return Conversation(
        system=self.system,
        roles=self.roles,
        messages=[[role, message] for role, message in self.messages],
        offset=self.offset,
        sep_style=self.sep_style,
        sep=self.sep,
        sep2=self.sep2,
        version=self.version,
        stop_str=self.stop_str,
        stop_token_ids=self.stop_token_ids,
    )
def dict(self):
    """Return the conversation as a plain dictionary.

    When the conversation holds images, tuple messages are reduced to their
    first element (the text); otherwise the message list is returned as-is.
    """
    if len(self.get_images()) > 0:
        messages = [
            [role, msg[0] if type(msg) is tuple else msg]
            for role, msg in self.messages
        ]
    else:
        messages = self.messages
    return {
        "system": self.system,
        "roles": self.roles,
        "messages": messages,
        "offset": self.offset,
        "sep": self.sep,
        "sep2": self.sep2,
    }
# Vicuna v0 template, registered in conv_templates under "default" and "v0".
# It ships with one seeded Human/Assistant exchange; offset=2 makes the
# messages[self.offset:] slices in get_images()/to_gradio_chatbot() skip it.
conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        (
            "Human",
            "What are the key differences between renewable and non-renewable energy sources?",
        ),
        (
            "Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
        ),
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)
# Vicuna v1 template, registered in conv_templates under "v1" and "vicuna_v1";
# it is also the module-level default_conversation.
conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)
# Llama-2 chat template with the long safety-oriented system prompt,
# registered in conv_templates under "llama_2".
# NOTE(review): this uses SeparatorStyle.LLAMA_2, whereas conv_mistral uses
# SeparatorStyle.LLAMA2 — confirm both enum members exist and are intended.
conv_llama_2 = Conversation(
    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)
# Same separators as conv_llama_2 but with a vision-assistant system prompt;
# registered in conv_templates under "llava_llama_2".
conv_llava_llama_2 = Conversation(
    system="You are a helpful language and vision assistant. "
    "You are able to understand the visual content that the user provides, "
    "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)
# ChatML-style template (<|im_start|> role headers, <|im_end|> separator),
# registered in conv_templates under "mpt". The role strings already contain
# the <|im_start|> marker and a trailing newline.
conv_mpt = Conversation(
    system="""<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)
# Plain template: empty system prompt and empty role names, newline separator.
# Registered in conv_templates under "plain" and "v0_plain".
conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)
# LLaVA v0 template: same system prompt and separators as conv_vicuna_v0 but
# without the seeded example exchange. Registered under "llava_v0".
conv_llava_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)
# v0 variant whose system prompt announces <Image>...</Image> tagged visual
# content. Registered under "v0_mmtag".
# NOTE(review): the adjacent string literals below join as
# "...natural language.The visual content..." (no space between sentences);
# left untouched because this is prompt text.
conv_llava_v0_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
    "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("Human", "Assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
    version="v0_mmtag",
)
# LLaVA v1 template (USER/ASSISTANT roles, two-separator style).
# Registered in conv_templates under "llava_v1".
conv_llava_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)
# Same fields as conv_vicuna_v1 except for version="imgsp_v1".
# Registered in conv_templates under "imgsp_v1".
conv_vicuna_imgsp_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="imgsp_v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)
# Same fields as conv_llava_plain except for version="plain_guided".
# Registered in conv_templates under "plain_guided".
conv_llava_plain_guided = Conversation(
    system="",
    roles=("", ""),
    version="plain_guided",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)
# v1 variant whose system prompt announces <Image>...</Image> tagged visual
# content. Registered in conv_templates under "v1_mmtag".
# NOTE(review): as in conv_llava_v0_mmtag, the adjacent literals join as
# "...natural language.The visual content..." without a space.
conv_llava_v1_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
    "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
    version="v1_mmtag",
)
# Phi-2 template: Vicuna-v1-style prompt with "<|endoftext|>" as sep2.
# Registered in conv_templates under "phi_2".
conv_phi_2 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="phi2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="<|endoftext|>",
)
# Disabled earlier variant, kept for reference; it differs from the active
# definition below only in using an empty `sep`.
# conv_mistral_instruct = Conversation(
#     system="",
#     roles=("USER", "ASSISTANT"),
#     version="llama_v2",
#     messages=(),
#     offset=0,
#     sep_style=SeparatorStyle.LLAMA_2,
#     sep="",
#     sep2="</s>",
# )
# Mistral-instruct template with an empty system prompt, registered in
# conv_templates under "mistral_instruct".
# NOTE(review): uses SeparatorStyle.LLAMA_2 (conv_mistral uses LLAMA2) —
# confirm both enum members exist.
conv_mistral_instruct = Conversation(
    system="",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)
# Disabled earlier variant, kept for reference; it put the turn markers in
# the role strings and carried explicit stop criteria.
# conv_gemma = Conversation(
#     system="You are a helpful assistant.",
#     # system="",
#     roles=("<start_of_turn>user\n", "<start_of_turn>model\n"),
#     version="gemma",
#     messages=(),
#     offset=0,
#     sep_style=SeparatorStyle.GEMMA,
#     sep="<end_of_turn>\n",
#     stop_str=["<end_of_turn>"],
#     stop_token_ids=[1, 107],  # <eos> <end_of_turn>
# )
# Active Gemma template: bare "user"/"model" role names with the turn markers
# supplied through sep / sep2. Registered in conv_templates under "gemma".
conv_gemma = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("user", "model"),
    version="gemma",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.GEMMA,
    sep="<start_of_turn>",
    sep2="<end_of_turn>",
)
# ChatML template with a minimal "Answer the questions." system prompt.
# Registered in conv_templates under "chatml_direct" and "mistral_direct".
conv_chatml_direct = Conversation(
    system="""<|im_start|>system
Answer the questions.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)
# Qwen template: ChatML layout (MPT separator style) plus explicit stop
# string and stop token ids. Registered in conv_templates under "qwen".
conv_qwen = Conversation(
    system="""<|im_start|>system
You should follow the instructions carefully and explain your answers in detail.""",
    # system = None,
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="qwen",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.MPT,
    stop_str="<|im_end|>",
    sep="<|im_end|>",
    sep2="<|endoftext|>",
    stop_token_ids=[151643, 151646, 151645],
)
# Experimental Llama-3 template that reuses the MPT separator style: the
# header tokens are baked into the role strings and "<|eot_id|>\n" is the
# separator. Registered in conv_templates under "llama3_test".
conv_llama3_test = Conversation(
    system="""<|start_header_id|>system<|end_header_id|>\n\nYou should follow the instructions carefully and explain your answers in detail.""",
    # system = None,
    roles=(
        "<|start_header_id|>user<|end_header_id|>\n\n",
        "<|start_header_id|>assistant<|end_header_id|>\n\n",
    ),
    version="llama3",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|eot_id|>\n",
    sep2="<|eot_id|>",
)
# Llama-3 template using the dedicated LLAMA3 separator style with bare
# "user"/"assistant" role names; the system string already ends with
# <|eot_id|>. Registered in conv_templates under "llama3".
conv_llama3 = Conversation(
    version="llama3",
    system="<|start_header_id|>system<|end_header_id|>\n\nYou should follow the instructions carefully and explain your answers in detail.<|eot_id|>",
    roles=("user", "assistant"),
    sep_style=SeparatorStyle.LLAMA3,
    messages=[],
    offset=0,
    sep="",
    stop_str="<|eot_id|>",
    stop_token_ids=[128001, 128009, 128007],
)
# InternLM2 chat template in ChatML layout (MPT separator style) with
# explicit stop token ids. Registered in conv_templates under "internlm2-chat".
conv_internlm2 = Conversation(
    version="internlm2-chat",
    system="<|im_start|>system\nYou are an AI assistant whose name is InternLM (书生·浦语).",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    sep_style=SeparatorStyle.MPT,
    messages=[],
    offset=0,
    sep="<|im_end|>",
    stop_token_ids=[2, 92543, 92542],
)
# GLM-4 template: "[gMASK]<sop>" prefix, <|user|>/<|assistant|> role tags and
# an empty separator, rendered with the MPT separator style.
# Registered in conv_templates under "glm4".
conv_glm4 = Conversation(
    system="""[gMASK]<sop><|system|>
You should follow the instructions carefully and explain your answers in detail.""",
    # system = None,
    roles=("<|user|>\n", "<|assistant|>\n"),
    version="glm4",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="",
    sep2="<|endoftext|>",
    stop_token_ids=[151336, 151337],  # <|user|> <|assistant|>
)
# Phi-3 instruct template with an empty system prompt. With the PHI3
# separator style, get_prompt() starts the prompt with "<|endoftext|>" and
# emits role + message + sep for each turn (role only for an empty message).
# Registered in conv_templates under "phi3".
conv_phi3_instruct = Conversation(
    # system="""<|system|>\nYou are a helpful AI assistant.""",
    system="",
    roles=("<|user|>\n", "<|assistant|>\n"),
    version="phi3",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.PHI3,
    sep="<|end|>\n",
)
# Mistral template. With the LLAMA2 separator style, get_prompt() appends the
# first message directly after the system string (which already opens with
# "[INST]") and renders later turns as tag + " " + message + separator.
# Registered in conv_templates under "mistral".
conv_mistral = Conversation(
    version="mistral",
    system="[INST] You should follow the instructions carefully and explain your answers in detail.\n",
    roles=("[INST]", "[/INST]"),
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.LLAMA2,
    sep=" ",
    sep2="</s>",
    stop_token_ids=[
        2,
    ],
)
# Template used when callers do not pick one explicitly.
default_conversation = conv_vicuna_v1

# Registry mapping template names to Conversation templates. Several names
# are aliases for the same template object. Each key appears exactly once
# (a repeated "gemma" entry was removed; it mapped to the same object, so the
# resulting dictionary is unchanged).
conv_templates = {
    "default": conv_vicuna_v0,
    "v0": conv_vicuna_v0,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "phi_2": conv_phi_2,
    "gemma": conv_gemma,
    "llama_2": conv_llama_2,
    "imgsp_v1": conv_vicuna_imgsp_v1,
    "plain_guided": conv_llava_plain_guided,
    "mistral_instruct": conv_mistral_instruct,
    "chatml_direct": conv_chatml_direct,
    "mistral_direct": conv_chatml_direct,
    "plain": conv_llava_plain,
    "v0_plain": conv_llava_plain,
    "llava_v0": conv_llava_v0,
    "v0_mmtag": conv_llava_v0_mmtag,
    "llava_v1": conv_llava_v1,
    "v1_mmtag": conv_llava_v1_mmtag,
    "llava_llama_2": conv_llava_llama_2,
    "mpt": conv_mpt,
    "qwen": conv_qwen,
    "llama3_test": conv_llama3_test,
    "llama3": conv_llama3,
    "internlm2-chat": conv_internlm2,
    "glm4": conv_glm4,
    "phi3": conv_phi3_instruct,
    "mistral": conv_mistral,
}
if __name__ == "__main__":
    # Manual smoke test: render a short multi-turn dialogue with one template
    # and print the resulting prompt. Swap the key below to try another
    # template, e.g. "chatml_direct", "phi3", "mistral" or "qwen".
    conv = conv_templates["gemma"]
    print(conv)
    conv.messages = []
    demo_turns = (
        "hello",
        "am tallen! nice to meet u.",
        "Nice to see u, how dod u do?",
        "我很好,请问有什么可以帮你的吗?",
        "介绍一下你自己",
        "我是ChatGPT,一个人工智能助手",
        "你会python嘛?",
        None,  # open assistant turn left for the model to complete
    )
    # Turns alternate between the first and the second role.
    for turn_index, text in enumerate(demo_turns):
        conv.append_message(conv.roles[turn_index % 2], text)
    a = conv.get_prompt()
    print(a)
+237
View File
@@ -0,0 +1,237 @@
import os
from typing import List, Optional, Tuple, Union
import torch
from transformers import AutoModel, AutoTokenizer
from transformers import PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch import nn
from .utils import rank0_print
from pathlib import Path
import json
from transformers import TrainerState
from peft import PeftModel
import glob
from loguru import logger
def auto_load_model(config):
    """Build a model for ``config`` through ``AutoModel``.

    If ``config._name_or_path`` exists on disk the weights are loaded from
    there using ``config.torch_dtype``; otherwise the model is instantiated
    from the config alone.
    """
    source = config._name_or_path
    if not os.path.exists(source):
        return AutoModel.from_config(config=config, trust_remote_code=True)
    return AutoModel.from_pretrained(
        source, torch_dtype=config.torch_dtype, trust_remote_code=True
    )
def auto_load_tokenizer(config):
    """Load the tokenizer that belongs to ``config``.

    The tokenizer normally comes from ``config._name_or_path``. When the
    config has a ``text_config`` whose ``_name_or_path`` is set and exists on
    disk, that path is used instead.
    """
    tokenizer_source = config._name_or_path
    text_config = getattr(config, "text_config", None)
    text_source = getattr(text_config, "_name_or_path", None)
    if text_source and os.path.exists(text_source):
        tokenizer_source = text_source
    return AutoTokenizer.from_pretrained(tokenizer_source, trust_remote_code=True)
def try_resume_conn_weights(model, output_dir, weight_file_name="conn_ve_llm.bin"):
    """Reload connector weights and trainer state from the newest checkpoint.

    Searches ``output_dir`` for ``checkpoint-*`` entries and selects the one
    with the largest integer suffix. If that checkpoint contains
    ``weight_file_name``, the weights are loaded into ``model.namo`` (only
    when the file name contains ``"ve_llm"``) and ``trainer_state.json`` is
    read from the same checkpoint.

    Args:
        model: Object exposing ``namo.load_conn_ve_llm_weights(path)``.
        output_dir: Directory that holds the ``checkpoint-*`` entries.
        weight_file_name: Weight file expected inside the checkpoint.

    Returns:
        The loaded ``TrainerState`` when ``trainer_state.json`` exists,
        otherwise ``None`` (including when no checkpoint or no weight file
        is found).

    Raises:
        ValueError: If a ``checkpoint-*`` name does not end in an integer.
    """
    # NOTE(review): indentation was lost in the reviewed view of this file;
    # the nesting of the "resumed" log line and of the trainer-state lookup
    # below is reconstructed — confirm against the repository.
    checkpoint_dirs = list(Path(output_dir).glob("checkpoint-*"))
    if checkpoint_dirs:
        # Order by the numeric step suffix so "checkpoint-10" sorts after
        # "checkpoint-9".
        sorted_checkpoints = sorted(
            checkpoint_dirs, key=lambda x: int(x.stem.split("-")[-1])
        )
        latest_checkpoint = sorted_checkpoints[-1]
        weights_path = latest_checkpoint / weight_file_name
        if weights_path.exists():
            if "ve_llm" in weight_file_name:
                model.namo.load_conn_ve_llm_weights(weights_path)
                rank0_print(f"resumed conn_ve_llm weights from: {weights_path}")
            state_path = latest_checkpoint / "trainer_state.json"
            if state_path.exists():
                state = TrainerState.load_from_json(state_path)
                epoch = state.epoch
                global_step = state.global_step
                max_steps = state.max_steps
                rank0_print(
                    f"tainer state resumed: {state_path}, epoch: {epoch} {global_step}/{max_steps}"
                )
                return state
            # Weights were found but there is no trainer state to resume.
            return None
        else:
            rank0_print(f"not resumed as: {weights_path} not found.")
def get_latest_checkpoint(output_dir, weights_file_name=None):
    """Resolve the checkpoint (or weight file) to load from ``output_dir``.

    Resolution order:
      1. ``output_dir`` is itself a file: return it unchanged.
      2. ``weights_file_name`` exists directly inside ``output_dir``: return
         that joined path.
      3. Otherwise pick the ``checkpoint-*`` entry with the largest integer
         suffix and return it, or ``weights_file_name`` inside it when a
         file name was given.

    Returns ``None`` when no ``checkpoint-*`` entry exists.
    """
    if os.path.isfile(output_dir):
        return output_dir

    if weights_file_name is not None:
        direct_hit = os.path.join(output_dir, weights_file_name)
        if os.path.exists(direct_hit):
            return direct_hit

    candidates = list(Path(output_dir).glob("checkpoint-*"))
    if not candidates:
        return None

    def _step_of(path):
        # "checkpoint-1200" -> 1200
        return int(path.stem.split("-")[-1])

    newest = sorted(candidates, key=_step_of)[-1]
    if weights_file_name is None:
        return newest

    weights_path = newest / weights_file_name
    rank0_print(f"==> loading conn middle checkpoint from: {weights_path}")
    return weights_path
def find_and_merge_lora_adapters(model, model_path):
    """Merge LoRA adapter weights found under ``model_path`` into ``model``.

    The adapter location is chosen as follows:
      * if ``model_path`` directly contains ``*.safetensors`` files, the
        adapter is loaded from ``model_path`` itself;
      * otherwise the most recently created ``checkpoint-*`` entry under
        ``model_path`` is used.

    Args:
        model: Base model to wrap with the adapter.
        model_path: Directory to search for adapter files or checkpoints.

    Returns:
        The model with the adapter merged and unloaded, or ``model`` unchanged
        when no adapter location was found.
    """

    def find_latest_checkpoint(model_path):
        checkpoints = glob.glob(os.path.join(model_path, "checkpoint-*"))
        if not checkpoints:
            return None
        return max(checkpoints, key=os.path.getctime)

    lora_path = None
    # Assume the adapter is stored in safetensors format.
    lora_adapters = glob.glob(os.path.join(model_path, "*.safetensors"))
    if lora_adapters:
        # PeftModel.from_pretrained expects the adapter *directory* (the one
        # holding the adapter config), not the path of a weight file, so pass
        # the directory that contains the matched file.
        lora_path = os.path.dirname(lora_adapters[0]) or "."
    else:
        lora_path = find_latest_checkpoint(model_path)

    if lora_path:
        logger.info(f"Merging LoRA adapters: {lora_path}")
        model = PeftModel.from_pretrained(model, lora_path)
        model = model.merge_and_unload()
    return model
class SimpleForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal-LM shell whose decoder and LM head are attached after construction.

    ``model`` and ``lm_head`` start out as ``None`` placeholders; a concrete
    model has to assign both before ``forward`` can be used.
    """

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}

    def __init__(self, config):
        super().__init__(config)
        self.model = None
        self.vocab_size = config.vocab_size
        # Placeholder only (e.g. nn.Linear(hidden_size, vocab_size, bias=False));
        # the real model must install its own head.
        self.lm_head = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding module."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Install the decoder that ``forward`` delegates to."""
        self.model = decoder

    def get_decoder(self):
        """Return the installed decoder."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""Run the decoder, project to vocabulary logits and optionally compute the loss.

        Args:
            labels: Optional targets. When given, the loss is computed with
                ``self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)``.
            num_logits_to_keep: Logits are computed for the last
                ``num_logits_to_keep`` positions only; ``0`` keeps every
                position.

        The remaining arguments are forwarded to the decoder unchanged.
        ``output_attentions``, ``output_hidden_states`` and ``return_dict``
        fall back to the values in ``self.config`` when left as ``None``.

        Returns:
            A ``CausalLMOutputWithPast`` when ``return_dict`` resolves to
            true, otherwise a tuple ``(loss?, logits, *decoder_extras)`` where
            the loss is present only if labels were supplied.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        # Only the trailing positions are projected; a slice starting at -0
        # covers the whole sequence.
        hidden_states = outputs[0]
        kept_states = hidden_states[:, -num_logits_to_keep:, :]
        logits = self.lm_head(kept_states)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            if loss is not None:
                return (loss,) + output
            return output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+51
View File
@@ -0,0 +1,51 @@
import requests
from PIL import Image
import base64
from threading import Thread
import io
from transformers import TextStreamer
try:
from datauri import DataURI
except ImportError:
pass
def load_image(image_file, timeout=30):
    """Load an image from an HTTP(S) URL or a local path and convert it to RGB.

    Args:
        image_file: ``http://`` / ``https://`` URL, or a local file path.
        timeout: Seconds to wait for the server before giving up; passed to
            ``requests.get``.

    Returns:
        The image converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # Match the full scheme so local files whose names merely start with
    # "http" are still opened from disk.
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file, timeout=timeout)
        # Fail on 4xx/5xx here instead of handing an error page to PIL.
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content)).convert("RGB")
    return Image.open(image_file).convert("RGB")
def load_multi_images_maybe(image_files, splitter=" "):
    """Load one or several images and return them as a list.

    ``image_files`` is either a single string of locations joined by
    ``splitter`` or an iterable of locations; every location is loaded with
    ``load_image``.
    """
    locations = (
        image_files.split(splitter) if isinstance(image_files, str) else image_files
    )
    loaded = []
    for location in locations:
        loaded.append(load_image(location))
    return loaded
def url_to_image(img_url: str, timeout: float = 30) -> Image.Image:
    """Decode an image given as an HTTP(S) URL, a ``data:`` URI or raw base64.

    Args:
        img_url: ``http://`` / ``https://`` URL, ``data:`` URI, or a bare
            base64-encoded image payload.
        timeout: Seconds to wait for the server before giving up; passed to
            ``requests.get``.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        ValueError: If a non-base64 ``data:`` URI is given while the optional
            ``datauri`` package is not installed.
    """
    # Match the full scheme: ":" is not a base64 character, so a raw base64
    # payload that happens to start with "http" is no longer mistaken for a URL.
    if img_url.startswith(("http://", "https://")):
        response = requests.get(img_url, timeout=timeout)
        # Fail on 4xx/5xx here instead of handing an error page to PIL.
        response.raise_for_status()
        img_data = response.content
    elif img_url.startswith("data:"):
        # `datauri` is an optional import at the top of this file; fall back
        # to a minimal base64 parser when it is unavailable.
        data_uri_cls = globals().get("DataURI")
        if data_uri_cls is not None:
            img_data = data_uri_cls(img_url).data
        else:
            header, comma, payload = img_url.partition(",")
            if not comma or not header.endswith(";base64"):
                raise ValueError(
                    "only base64 data URIs are supported when the 'datauri' "
                    "package is not installed"
                )
            img_data = base64.b64decode(payload)
    else:
        img_data = base64.b64decode(img_url)
    return Image.open(io.BytesIO(img_data)).convert("RGB")
class CallbackStreamer(TextStreamer):
    """``TextStreamer`` that also hands every finalized text chunk to a callback."""

    def __init__(self, tokenizer, callback=None, **kwargs):
        super().__init__(tokenizer, **kwargs)
        # Optional callable invoked with each finalized chunk of text.
        self.callback = callback

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Notify the callback (if any), then defer to the parent streamer."""
        notify = self.callback
        if notify is not None:
            notify(text)
        super().on_finalized_text(text, stream_end)
+910
View File
@@ -0,0 +1,910 @@
import copy
import transformers
from namo.utils import convs as conversation_lib
import torch
from typing import Dict, Sequence
from namo.utils.process_utils import tokenizer_image_token
from namo.models.symbols import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX
from transformers import AutoTokenizer
def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Build inputs and labels for two-turn (image prompt, answer) samples.

    Each source must contain exactly two entries and the first must mention
    the image token. The first entry's ``"value"`` is overwritten in place
    with the bare image token, and the prompt becomes
    ``image token + answer + default separator``. In the labels, the
    positions covering the image prompt are set to ``IGNORE_INDEX``.

    Returns:
        ``dict(input_ids=..., labels=...)`` holding one tensor per source.
    """
    prompts = []
    for pair in sources:
        assert len(pair) == 2
        question, answer = pair[0], pair[1]
        assert DEFAULT_IMAGE_TOKEN in question["value"]
        # The textual part of the question is dropped; only the image token stays.
        question["value"] = DEFAULT_IMAGE_TOKEN
        prompts.append(
            question["value"]
            + answer["value"]
            + conversation_lib.default_conversation.sep
        )

    # tokenize conversations
    input_ids = [
        tokenizer_image_token(text, tokenizer, return_tensors="pt") for text in prompts
    ]

    targets = copy.deepcopy(input_ids)
    for label_row, pair in zip(targets, sources):
        prefix_len = len(tokenizer_image_token(pair[0]["value"], tokenizer))
        label_row[:prefix_len] = IGNORE_INDEX

    return {"input_ids": input_ids, "labels": targets}
def preprocess_qwen(
    sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False
) -> Dict:
    """Tokenize ChatML-style (Qwen) conversations and mask non-assistant tokens.

    Every source is rendered with a copy of the default conversation template
    (which must use the MPT separator style), tokenized, and turned into
    labels in which the instruction part of each round is replaced by
    ``IGNORE_INDEX`` so only assistant replies contribute to the loss.

    Args:
        sources: Iterable of conversations; each is a list of dicts with
            ``"from"`` (``"human"`` / ``"gpt"``) and ``"value"`` keys.
        tokenizer: Tokenizer used for the prompts.
        has_image: When true, prompts are tokenized with
            ``tokenizer_image_token`` and stacked; otherwise the tokenizer is
            called directly with padding and truncation.

    Returns:
        ``dict(input_ids=..., labels=...)``.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations. The same tokenization flavour is reused below
    # to measure round lengths, so pick the counting helper here once.
    if has_image:
        input_ids = torch.stack(
            [
                tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
                for prompt in conversations
            ],
            dim=0,
        )

        def count_tokens(text):
            return len(tokenizer_image_token(text, tokenizer))

    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

        def count_tokens(text):
            return len(tokenizer(text).input_ids)

    targets = input_ids.clone()
    assert conv.sep_style == conversation_lib.SeparatorStyle.MPT

    # Mask targets
    sep = conv.sep + conv.roles[1]  # <|im_end|><|im_start|>assistant\n
    sep2 = conv.sep + conv.roles[0]  # <|im_end|><|im_start|>user\n
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(sep2)
        # Splitting on sep2 also separates the system prompt from the first
        # user turn; merge the two back into the first round.
        if len(rounds) > 1:
            rounds[0:2] = [sep2.join(rounds[0:2])]
        n_rounds = len(rounds)

        cur_len = 0
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            # Tokens removed by the split have to be added back to the counts.
            if n_rounds == 1:
                # no need to compensate
                round_extra, instruction_extra = 0, 0
            elif i == 0:
                # for <|im_end|>
                round_extra, instruction_extra = 1, 0
            elif i == n_rounds - 1:
                # for <|im_start|>user\n (last round)
                round_extra, instruction_extra = 3, 3
            else:
                # for <|im_start|>user\n .. <|im_end|>
                round_extra, instruction_extra = 4, 3

            round_len = count_tokens(rou) + round_extra
            instruction_len = count_tokens(parts[0]) + instruction_extra

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                # Length bookkeeping disagrees with the tokenizer: drop the
                # whole sample from the loss rather than train on bad labels.
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                    f"{conversations}"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
def preprocess_llama3(
    sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False
) -> Dict:
    """Tokenize Llama-3 style conversations and mask non-assistant tokens.

    Every source is rendered with a copy of the default conversation template
    (which must use the LLAMA3 separator style), tokenized without adding
    special tokens (the template already contains them), and turned into
    labels in which the instruction part of each round is replaced by
    ``IGNORE_INDEX`` so only assistant replies contribute to the loss.

    Args:
        sources: Iterable of conversations; each is a list of dicts with
            ``"from"`` (``"human"`` / ``"gpt"``) and ``"value"`` keys.
        tokenizer: Tokenizer used for the prompts.
        has_image: When true, prompts are tokenized with
            ``tokenizer_image_token`` and stacked; otherwise the tokenizer is
            called directly with padding and truncation.

    Returns:
        ``dict(input_ids=..., labels=...)``.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    # Note: Llama3 has a bos token while Qwen's is null; add_special_tokens is
    # disabled because the template already carries the special tokens.
    # The same tokenization flavour is reused below to measure round lengths.
    if has_image:
        input_ids = torch.stack(
            [
                tokenizer_image_token(
                    prompt, tokenizer, return_tensors="pt", add_special_tokens=False
                )
                for prompt in conversations
            ],
            dim=0,
        )

        def count_tokens(text):
            return len(tokenizer_image_token(text, tokenizer, add_special_tokens=False))

    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
            add_special_tokens=False,
        ).input_ids

        def count_tokens(text):
            return len(tokenizer(text, add_special_tokens=False).input_ids)

    targets = input_ids.clone()
    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA3

    # Mask targets
    # <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n
    sep = f"<|eot_id|><|start_header_id|>{conv.roles[1]}<|end_header_id|>\n\n"
    # <|eot_id|><|start_header_id|>user<|end_header_id|>\n\n
    sep2 = f"<|eot_id|><|start_header_id|>{conv.roles[0]}<|end_header_id|>\n\n"

    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(sep2)
        # Splitting on sep2 also separates the system prompt from the first
        # user turn; merge the two back into the first round.
        if len(rounds) > 1:
            rounds[0:2] = [sep2.join(rounds[0:2])]
        n_rounds = len(rounds)

        cur_len = 0
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            # Tokens removed by the split have to be added back to the counts.
            if n_rounds == 1:
                # no need to compensate
                round_extra, instruction_extra = 0, 0
            elif i == 0:
                # for <|eot_id|>
                round_extra, instruction_extra = 1, 0
            elif i == n_rounds - 1:
                # for <|start_header_id|>user<|end_header_id|>\n\n
                # (last round; its <|eot_id|> is already present)
                round_extra, instruction_extra = 4, 4
            else:
                # for <|start_header_id|>user<|end_header_id|>\n\n .. <|eot_id|>
                round_extra, instruction_extra = 5, 4

            round_len = count_tokens(rou) + round_extra
            instruction_len = count_tokens(parts[0]) + instruction_extra

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                # Length bookkeeping disagrees with the tokenizer: drop the
                # whole sample from the loss rather than train on bad labels.
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
def preprocess_mistral(
    sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False
) -> Dict:
    """Tokenize conversations with the default template and mask the prompts.

    Each source is rendered through ``conversation_lib.default_conversation``,
    tokenized, and copied into ``labels``. For every round the instruction part
    (everything up to and including ``conv.roles[1]``) is set to
    ``IGNORE_INDEX`` so only the answers contribute to the loss.

    Rendered prompts look like (per the original inline example)::

        [INST] <system>\\nhello [/INST] answer</s>[INST] next question [/INST] answer</s>

    Args:
        sources: list of conversations; each is a list of dicts with ``"from"``
            (``"human"``/``"gpt"``) and ``"value"`` keys.
        tokenizer: tokenizer used for both encoding and length accounting.
        has_image: when True, prompts are encoded with ``tokenizer_image_token``.

    Returns:
        dict with ``input_ids`` and ``labels`` tensors.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Render every conversation through the prompt template.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Drop a leading message that is not from the human side.
            source = source[1:]
        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize the rendered prompts.
    if has_image:
        encoded = [
            tokenizer_image_token(
                prompt, tokenizer, add_special_tokens=False, return_tensors="pt"
            )
            for prompt in conversations
        ]
        input_ids = torch.stack(encoded, dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
            add_special_tokens=False,
        ).input_ids
    targets = input_ids.clone()

    def count_tokens(text):
        # Token count of `text`, using the same encoder as the full prompt.
        if has_image:
            return len(
                tokenizer_image_token(text, tokenizer, add_special_tokens=False)
            )
        return len(tokenizer(text, add_special_tokens=False).input_ids)

    # Original note: '[/INST]' -> [1032, 4, 1032]; Mistral merges a space with
    # the token that follows it, which is a pitfall for length accounting.
    sep = conv.roles[1]
    # Original note: '</s>' [2, 3, 1032]
    sep2 = conv.sep2
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
        # Empty pieces are filtered out here, so no round below is "".
        rounds = [piece for piece in conversation.split(sep2) if piece != ""]
        cur_len = 0
        target[:cur_len] = IGNORE_INDEX
        for rou in rounds:
            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # Splitting on sep2 removed it from every round; per the original
            # comment only that single id has to be added back (+1). The same
            # compensation applies to single- and multi-round conversations,
            # with or without images.
            round_len = count_tokens(rou) + 1
            instruction_len = count_tokens(parts[0])
            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX
        if cur_len < tokenizer.model_max_length and cur_len != total_len:
            # Length accounting disagrees with the tokenizer: drop the sample.
            target[:] = IGNORE_INDEX
            print(
                f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                f" (ignored)"
                f"{conversations}"
            )
    return dict(
        input_ids=input_ids,
        labels=targets,
    )
def preprocess_gemma2(
    sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False
) -> Dict:
    """Tokenize conversations with the default template and mask the prompts.

    Prompts are rendered through ``conversation_lib.default_conversation``,
    encoded without special tokens, and copied into ``labels``. Each
    conversation is split into rounds on ``conv.sep + conv.roles[0]`` and the
    instruction part of every round is set to ``IGNORE_INDEX``.

    Rendered prompts look like (per the original inline example)::

        You are a helpful assistant.<start_of_turn>user
        hello<end_of_turn>
        <start_of_turn>model
        answer<end_of_turn>
        <start_of_turn>user
        ...

    Args:
        sources: list of conversations; each is a list of dicts with ``"from"``
            (``"human"``/``"gpt"``) and ``"value"`` keys.
        tokenizer: tokenizer used for both encoding and length accounting.
        has_image: when True, prompts are encoded with ``tokenizer_image_token``.

    Returns:
        dict with ``input_ids`` and ``labels`` tensors.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Render every conversation through the prompt template.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Drop a leading message that is not from the human side.
            source = source[1:]
        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    add_special_tokens = False

    # Tokenize the rendered prompts.
    if has_image:
        encoded = [
            tokenizer_image_token(
                prompt,
                tokenizer,
                add_special_tokens=add_special_tokens,
                return_tensors="pt",
            )
            for prompt in conversations
        ]
        input_ids = torch.stack(encoded, dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
            add_special_tokens=add_special_tokens,
        ).input_ids
    targets = input_ids.clone()

    def count_tokens(text):
        # Token count of `text`, using the same encoder as the full prompt.
        if has_image:
            return len(
                tokenizer_image_token(
                    text, tokenizer, add_special_tokens=add_special_tokens
                )
            )
        return len(tokenizer(text, add_special_tokens=add_special_tokens).input_ids)

    # Original annotations: sep is '<start_of_turn>model\n' (3 tokens) and
    # sep2 is '<end_of_turn>\n<start_of_turn>user\n' (5 tokens).
    sep = conv.roles[1]
    sep2 = conv.sep + conv.roles[0]
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
        # Empty pieces are filtered out here, so no round below is "".
        rounds = [piece for piece in conversation.split(sep2) if piece != ""]
        last = len(rounds) - 1
        cur_len = 0
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # split(sep2) removed the separator from the rounds, so token
            # counts are padded depending on where the round sits. The values
            # are identical for image and text-only prompts.
            # NOTE(review): presumably these count the sep2 tokens missing
            # before/after the round; confirm against the tokenizer.
            if len(rounds) == 1:
                round_extra, instruction_extra = 0, 0
            elif i == 0:
                round_extra, instruction_extra = 2, 0
            elif i == last:
                round_extra, instruction_extra = 3, 3
            else:
                round_extra, instruction_extra = 5, 4
            round_len = count_tokens(rou) + round_extra
            instruction_len = count_tokens(parts[0]) + instruction_extra
            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX
        if cur_len < tokenizer.model_max_length and cur_len != total_len:
            # Length accounting disagrees with the tokenizer: drop the sample.
            target[:] = IGNORE_INDEX
            print(
                f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                f" (ignored)"
                f"{conversations}"
            )
    return dict(
        input_ids=input_ids,
        labels=targets,
    )
def preprocess_gemma(
    sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False
) -> Dict:
    """Tokenize GEMMA-style conversations and mask the prompt tokens.

    Prompts are rendered through ``conversation_lib.default_conversation``
    (which must use ``SeparatorStyle.GEMMA``), stripped, and encoded with the
    tokenizer's default special-token handling. The first token and the
    instruction part of every round are set to ``IGNORE_INDEX`` in ``labels``.

    Args:
        sources: list of conversations; each is a list of dicts with ``"from"``
            (``"human"``/``"gpt"``) and ``"value"`` keys.
        tokenizer: tokenizer used for both encoding and length accounting.
        has_image: when True, prompts are encoded with ``tokenizer_image_token``.

    Returns:
        dict with ``input_ids`` and ``labels`` tensors.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Render every conversation through the prompt template.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Drop a leading message that is not from the human side.
            source = source[1:]
        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt().strip())

    # Tokenize the rendered prompts.
    if has_image:
        encoded = [
            tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
            for prompt in conversations
        ]
        input_ids = torch.stack(encoded, dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids
    targets = input_ids.clone()
    assert conv.sep_style == conversation_lib.SeparatorStyle.GEMMA

    def count_tokens(text):
        # Token count of `text` minus one; the original notes the -1 is for <bos>.
        if has_image:
            return len(tokenizer_image_token(text, tokenizer)) - 1
        return len(tokenizer(text).input_ids) - 1

    sep = conv.sep + conv.roles[1] + "\n"  # <start_of_turn>model\n
    round_sep = "\n" + conv.sep + conv.roles[0] + "\n"  # \n<start_of_turn>user\n
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
        rounds = conversation.split(round_sep)
        # Accounting starts at 1 and the first position is always masked.
        cur_len = 1
        if cur_len > 0:
            target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break
            if i != 0:
                # Put back the separator that split() removed.
                rou = round_sep + rou
            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            round_len = count_tokens(rou)
            instruction_len = count_tokens(parts[0])
            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX
        if cur_len < tokenizer.model_max_length and cur_len != total_len:
            # Length accounting disagrees with the tokenizer: drop the sample.
            target[:] = IGNORE_INDEX
            print(conversation)
            print(
                f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                f" (ignored)"
            )
    return dict(
        input_ids=input_ids,
        labels=targets,
    )
if __name__ == "__main__":
    # Manual smoke test: load a local tokenizer and run one preprocess function
    # on a small two-round sample conversation.
    # conversation_lib.default_conversation = conversation_lib.conv_templates['mistral']
    # tokenizer_path = 'checkpoints/mistral-nemo-instruct/'
    conversation_lib.default_conversation = conversation_lib.conv_templates["gemma"]
    tokenizer_path = "checkpoints/gemma-2-9b-it/"
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path, trust_remote_code=True, use_fast=False
    )
    print(tokenizer.pad_token_id)
    print(tokenizer)
    if tokenizer.pad_token_id is None:
        # pad_token_id must be a single integer id. encode() returns a list of
        # ids, so look the token up directly instead.
        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids("<pad>")
    sources = [
        [
            {"from": "human", "value": "<video>\nWhat is happening in the video?"},
            {
                "from": "gpt",
                "value": "In the video, a man is sitting at a table with cups and glasses. He then proceeds to grab a cup of coffee and an apple while looking at his phone. Later on, he is seen looking at his phone again while sitting at the same table near a window. The man then has some cups in front of him and continues looking at his phone at the same table. Finally, he is seen looking at his phone again while sitting at the table.",
            },
            {
                "from": "human",
                "value": "Is the man interacting with any other objects or people?",
            },
            {
                "from": "gpt",
                "value": "The man is not seen interacting with any other objects or people in the video. He is just sitting at the table and using his phone.",
            },
            # {
            #     "from": "human",
            #     "value": "Why do you think the man is looking at his phone so much?"
            # },
            # {
            #     "from": "gpt",
            #     "value": "It is not clear why the man is looking at his phone so often. We cannot infer his intentions as the video doesn't provide much context."
            # }
        ]
    ]
    # NOTE(review): the "gemma" template and tokenizer are selected above but
    # the mistral preprocess function is called; confirm which pairing is meant.
    preprocess_mistral(sources, tokenizer, True)
+273
View File
@@ -0,0 +1,273 @@
import ast
import random
import re
import numpy as np
from namo.models.symbols import IMAGE_TOKEN_INDEX
import torch
from PIL import Image
from loguru import logger
try:
from decord import VideoReader, cpu
except ImportError as e:
pass
try:
from moviepy.editor import VideoFileClip
except ImportError as e:
pass
def tokenizer_image_token(
    prompt,
    tokenizer,
    image_token_index=IMAGE_TOKEN_INDEX,
    return_tensors=None,
    add_special_tokens=True,
):
    """Tokenize ``prompt``, putting ``image_token_index`` at each ``<image>``.

    The prompt is split on the literal ``<image>`` and each piece is tokenized
    on its own. If the first piece starts with the tokenizer's BOS id, that id
    is kept once at the front and the leading id of every piece is dropped.

    Args:
        prompt: text that may contain ``<image>`` placeholders.
        tokenizer: callable returning an object with ``input_ids``; must expose
            ``bos_token_id``.
        image_token_index: id inserted between consecutive pieces.
        return_tensors: ``None`` for a list of ints, ``"pt"`` for a long tensor.
        add_special_tokens: forwarded to the tokenizer for every piece.

    Returns:
        list of ids, or a 1-D ``torch.long`` tensor when ``return_tensors="pt"``.

    Raises:
        ValueError: if ``return_tensors`` is neither ``None`` nor ``"pt"``.
    """
    chunks = [
        tokenizer(piece, add_special_tokens=add_special_tokens).input_ids
        for piece in prompt.split("<image>")
    ]

    starts_with_bos = (
        len(chunks) > 0
        and len(chunks[0]) > 0
        and chunks[0][0] == tokenizer.bos_token_id
    )
    # When a BOS is present, emit it once and skip the first id of every chunk.
    skip = 1 if starts_with_bos else 0
    input_ids = [chunks[0][0]] if starts_with_bos else []

    for position, chunk in enumerate(chunks):
        if position > 0:
            # One image placeholder between consecutive text chunks.
            input_ids.append(image_token_index)
        input_ids.extend(chunk[skip:])

    if return_tensors is None:
        return input_ids
    if return_tensors == "pt":
        return torch.tensor(input_ids, dtype=torch.long)
    raise ValueError(f"Unsupported tensor type: {return_tensors}")
def unpad_image(tensor, original_size):
    """Crop the symmetric padding off a padded-and-resized image tensor.

    Args:
        tensor (torch.Tensor): image tensor; the last two dimensions are read
            as (height, width).
        original_size (tuple): size of the image before padding, unpacked as
            (width, height).

    Returns:
        torch.Tensor: the tensor cropped along height or width so that its
        aspect ratio approximates the original one.
    """
    orig_w, orig_h = original_size
    cur_h, cur_w = tensor.shape[1:]

    if orig_w / orig_h > cur_w / cur_h:
        # Original is relatively wider: padding was added above and below.
        content_h = int(orig_h * (cur_w / orig_w))
        pad = (cur_h - content_h) // 2
        return tensor[:, pad : cur_h - pad, :]

    # Otherwise padding was added left and right.
    content_w = int(orig_w * (cur_h / orig_h))
    pad = (cur_w - content_w) // 2
    return tensor[:, :, pad : cur_w - pad]
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (list | tuple | str): Possible resolutions, either as a
            sequence of (width, height) pairs or as the string representation
            of such a list.
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).
    """
    # Accept any list/tuple (including subclasses) directly; only strings are
    # parsed. Previously a tuple fell through to literal_eval and raised.
    if isinstance(grid_pinpoints, (list, tuple)):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size
def select_best_resolution(original_size, possible_resolutions):
    """
    Pick the candidate resolution that preserves the most of the original image.

    Candidates are ranked by effective resolution (pixels of the original that
    survive scaling into the candidate, capped at the original pixel count);
    ties are broken by the least wasted area. The earliest candidate wins on a
    full tie.

    Args:
        original_size (tuple): The original image size as (width, height).
        possible_resolutions (list): Candidate resolutions as
            [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The best candidate as (width, height), or None when there are
        no candidates.
    """
    src_w, src_h = original_size
    src_area = src_w * src_h

    best_fit = None
    # (effective resolution, negated waste): larger is better on both fields.
    best_score = (0, float("-inf"))

    for width, height in possible_resolutions:
        scale = min(width / src_w, height / src_h)
        scaled_w = int(src_w * scale)
        scaled_h = int(src_h * scale)
        effective = min(scaled_w * scaled_h, src_area)
        wasted = (width * height) - effective

        score = (effective, -wasted)
        if score > best_score:
            best_score = score
            best_fit = (width, height)

    return best_fit
def process_video_fixed_frames(video_file, fps, num_frames):
    """Decode a video and return exactly ``num_frames`` sampled frames.

    Frames are first thinned to roughly ``fps`` frames per second, then:

    * if more than ``num_frames`` candidates remain, the candidates are cut
      into ``num_frames`` equal chunks and one frame is drawn at random from
      each chunk;
    * otherwise candidate indices are interpolated (repeating frames) up to
      ``num_frames``.

    Files ending in ``webm`` are decoded with moviepy's ``VideoFileClip``;
    everything else with decord's ``VideoReader``.

    Args:
        video_file: path to the video file.
        fps: target sampling rate used for the initial thinning.
        num_frames: number of frames to return.

    Returns:
        numpy array holding the selected frames.

    Raises:
        ValueError: if the video yields no frames.
    """

    def sample_frames(frame_indices):
        total_frames = len(frame_indices)
        if total_frames == 0:
            raise ValueError(f"no frames could be read from {video_file}")
        if total_frames > num_frames:
            chunk_size = total_frames // num_frames
            return [
                random.sample(
                    frame_indices[
                        i * chunk_size : min((i + 1) * chunk_size, total_frames)
                    ],
                    1,
                )[0]
                for i in range(num_frames)
            ]
        # Use the helper's own argument rather than reaching into the
        # enclosing scope for the caller's variable.
        return np.interp(
            np.linspace(0, total_frames - 1, num_frames),
            np.arange(total_frames),
            frame_indices,
        ).astype(int)

    def sampling_step(native_fps):
        # round() yields 0 when fps is more than twice the native rate, and
        # range() rejects a zero step; never go below every frame.
        return max(1, round(native_fps / fps))

    if video_file.endswith("webm"):
        video_webm = VideoFileClip(video_file)
        video_frames = np.array(list(video_webm.iter_frames()))
        frame_idx = list(range(0, len(video_frames), sampling_step(video_webm.fps)))
        return video_frames[sample_frames(frame_idx)]

    vr = VideoReader(video_file, ctx=cpu(0))
    frame_idx = list(range(0, len(vr), sampling_step(vr.get_avg_fps())))
    return vr.get_batch(sample_frames(frame_idx)).asnumpy()
def convert_image_tags(input_string):
    """Number the ``<image>`` placeholders when a prompt has more than one.

    With zero or one placeholder the string is returned untouched. Otherwise
    every ``<image>`` is rewritten as ``"\\nImage{k}:<image>"`` (k starting at
    1) and the result is stripped of surrounding whitespace.
    """
    pieces = input_string.split("<image>")
    if len(pieces) - 1 <= 1:
        return input_string

    labelled = [pieces[0]]
    for number, tail in enumerate(pieces[1:], start=1):
        labelled.append(f"\nImage{number}:<image>")
        labelled.append(tail)
    return "".join(labelled).strip()
def get_suitable_size_hw(images, longest_edge=800):
    """Choose one (height, width) that suits a group of images.

    The 75% of images whose height/width ratio is closest to 1 are kept (at
    least one image is always kept). The target aspect ratio is the median
    ratio of that subset and the size is derived from the 75th percentile of
    their heights and widths. Both sides are rounded to multiples of 14,
    capped so the longer side does not exceed ``longest_edge``, and floored
    at 392.

    Args:
        images: non-empty sequence of objects exposing ``height`` and ``width``.
        longest_edge: upper bound for the longer side before the 392 floor.

    Returns:
        tuple: ``(H, W)`` target size.

    Raises:
        ValueError: if ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError("images must contain at least one image")

    hs = [img.height for img in images]
    ws = [img.width for img in images]
    ratios = [h / w for h, w in zip(hs, ws)]
    sorted_indices = sorted(range(len(ratios)), key=lambda i: abs(ratios[i] - 1))
    # int(1 * 0.75) == 0 would leave nothing to compute statistics on.
    k = max(1, int(len(images) * 0.75))
    selected = sorted_indices[:k]
    target_ratio = np.median([ratios[i] for i in selected])
    h_q3 = np.percentile([hs[i] for i in selected], 75)
    w_q3 = np.percentile([ws[i] for i in selected], 75)

    # Split the combined edge budget according to the target ratio (H = ratio * W).
    sum_hw = h_q3 + w_q3
    W_initial = sum_hw / (target_ratio + 1)
    H_initial = target_ratio * W_initial
    H = int(round(H_initial / 14) * 14)
    W = int(round(W_initial / 14) * 14)

    if max(H, W) > longest_edge:
        # Clamp the longer side and re-derive the other from the ratio.
        new_max = (longest_edge // 14) * 14
        if H > W:
            H = new_max
            W = int(round(H / target_ratio / 14) * 14)
        else:
            W = new_max
            H = int(round(W * target_ratio / 14) * 14)

    H, W = max(392, H), max(392, W)
    return (H, W)
def resize_pad_images_to_target(images, target_size_hw):
    """Resize each image to fit the target size and pad it to exactly that size.

    Every image is scaled, keeping its aspect ratio, so that one side matches
    the target; it is then pasted at the top-left corner of a zero-filled
    canvas of the target size. If the scaled size degenerates to a
    non-positive dimension, the image is stretched to the target instead.

    Args:
        images: iterable of PIL images.
        target_size_hw: ``(height, width)`` of the output images.

    Returns:
        list of PIL images, one per input image.
    """
    H_target, W_target = target_size_hw
    processed_images = []
    for img in images:
        W, H = img.width, img.height
        if W / H >= W_target / H_target:
            # Relatively wider than the target: match the width, derive height.
            new_w = W_target
            new_h = int(round(H * (new_w / W)))
        else:
            # Relatively taller than the target: match the height, derive width.
            new_h = H_target
            new_w = int(round(W * (new_h / H)))

        if new_w <= 0 or new_h <= 0:
            logger.info(
                f"unexpected new_h: {new_h} new_w: {new_w} got. forcely resize into target {H_target}x{W_target}."
            )
            processed_images.append(
                img.resize((W_target, H_target), Image.BILINEAR)
            )
            continue

        # Scale the image, then paste it onto a blank canvas of the target size.
        scaled = img.resize((new_w, new_h), Image.BILINEAR)
        canvas = Image.new(img.mode, (W_target, H_target), color=0)
        canvas.paste(scaled, (0, 0))
        processed_images.append(canvas)
    return processed_images
def smart_resize_v1(images, longest_edge=800):
    """Bring a group of images to one common size.

    A single image is returned unchanged (the same list object). For any
    other count, a shared size is chosen with ``get_suitable_size_hw`` and
    every image is resized and padded to it.
    """
    if len(images) != 1:
        shared_hw = get_suitable_size_hw(images, longest_edge)
        return resize_pad_images_to_target(images, shared_hw)
    return images
+204
View File
@@ -0,0 +1,204 @@
import datetime
import logging
import logging.handlers
import mimetypes
import os
import sys
import torch
import torch.distributed as dist
import requests
import transformers
from loguru import logger
def disable_torch_init():
    """
    Replace ``reset_parameters`` on ``torch.nn.Linear`` and
    ``torch.nn.LayerNorm`` with a no-op, skipping the default initialization
    to speed up model creation.
    """
    import torch

    def _skip_init(self):
        return None

    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        layer_cls.reset_parameters = _skip_init
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available and initialized."""
    return dist.is_available() and dist.is_initialized()
def get_rank():
    """Return the distributed rank, or 0 when no process group is active."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
def is_main_process():
    """Return True when this process has rank 0."""
    rank = get_rank()
    return rank == 0
def rank0_print(*args):
    """Print ``args`` on the main process only; do nothing on other ranks."""
    if not is_main_process():
        return
    print(*args)
def is_image(url_or_path):
    """Guess whether a URL or path refers to an image.

    Only the text before the first space is inspected. The decision is based
    on the guessed MIME type, with a fallback for names ending in ``webp``.
    """
    candidate = url_or_path.split(" ")[0]
    # Existing local files are resolved to an absolute path first.
    if os.path.exists(candidate):
        candidate = os.path.abspath(candidate)
    mimetype, _encoding = mimetypes.guess_type(candidate)
    if mimetype and mimetype.startswith("image"):
        return True
    return candidate.endswith("webp")
def load_conn_weights(conn_model_path, model_conn, module_key="conn_ve_llm"):
    """Load connector weights from a checkpoint file into ``model_conn``.

    Keys containing ``"<module_key>."`` are selected from the checkpoint and
    everything up to and including the first such prefix is stripped, so the
    remainder lines up with ``model_conn``'s own parameter names. Loading is
    non-strict. If it fails, a second attempt is made without any key
    containing ``layers.0`` or ``layers.1``.

    Args:
        conn_model_path: path to a checkpoint readable by ``torch.load``.
        model_conn: module whose ``load_state_dict`` receives the weights.
        module_key: name of the connector sub-module inside the checkpoint keys.
    """
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(conn_model_path, map_location="cpu")

    # Match on the dotted prefix and split once: a key that merely contains
    # the keyword (no following dot) is skipped instead of raising IndexError,
    # and a key repeating the keyword keeps its full remainder.
    prefix = module_key + "."
    mm_projector_weights = {
        k.split(prefix, 1)[1]: v for k, v in checkpoint.items() if prefix in k
    }

    try:
        model_conn.load_state_dict(mm_projector_weights, strict=False)
        if is_main_process():
            logger.info(f"conn weights loaded from: {conn_model_path}")
    except Exception as e:
        # Best-effort fallback: retry without the first two layers.
        print(f"got error load state dict: {e}")
        model_conn.load_state_dict(
            {
                k: v
                for k, v in mm_projector_weights.items()
                if "layers.1" not in k and "layers.0" not in k
            },
            strict=False,
        )
        print(f"{module_key} partially loaded!")
def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU clone of ``param``.

    Parameters carrying a ``ds_id`` attribute are gathered with DeepSpeed's
    ``zero.GatheredParameters`` before being cloned. A warning is logged when
    such a parameter's status is ``NOT_AVAILABLE``, unless ``ignore_status``
    is set.
    """
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if not hasattr(param, "ds_id"):
        # Ordinary tensor: nothing to gather.
        return param.detach().cpu().clone()

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        logging.warning(
            f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}"
        )
    with zero.GatheredParameters([param]):
        gathered = param.data.detach().cpu().clone()
    return gathered
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Collect LoRA (and optionally bias) parameters as CPU clones.

    Args:
        named_params: iterable of ``(name, parameter)`` pairs.
        bias: which bias terms to include:
            ``"none"`` - only parameters whose name contains ``"lora_"``;
            ``"all"`` - LoRA parameters plus every name containing ``"bias"``;
            ``"lora_only"`` - LoRA parameters plus the bias of each module
            that has LoRA parameters.

    Returns:
        dict mapping parameter name to the result of ``maybe_zero_3``.

    Raises:
        NotImplementedError: for any other ``bias`` value.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # Name of the bias that sits next to this LoRA weight.
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Iterate (name, tensor) pairs and test each candidate's own name.
        # The previous loop unpacked bare dict keys and checked a stale
        # `bias_name` left over from the loop above.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """Collect the non-LoRA parameters as CPU clones.

    Parameters whose name contains ``"lora_"`` are skipped. With
    ``require_grad_only`` set, parameters that do not require gradients are
    skipped as well.
    """
    kept = {}
    for name, tensor in named_params:
        if "lora_" in name:
            continue
        kept[name] = tensor
    if require_grad_only:
        kept = {name: tensor for name, tensor in kept.items() if tensor.requires_grad}
    return {
        name: maybe_zero_3(tensor, ignore_status=True).cpu()
        for name, tensor in kept.items()
    }
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    """Collect parameters whose name contains any of ``keys_to_match``.

    Returns a dict of CPU clones produced by ``maybe_zero_3``.
    """
    matched = {}
    for name, tensor in named_params:
        if any(key in name for key in keys_to_match):
            matched[name] = tensor
    return {
        name: maybe_zero_3(tensor, ignore_status=True).cpu()
        for name, tensor in matched.items()
    }
def find_all_linear_names(model):
    """List the leaf names of ``torch.nn.Linear`` modules for LoRA targeting.

    Modules whose qualified name contains one of the multimodal keywords are
    skipped, and ``lm_head`` is removed from the result.
    """
    # Example of what gets collected for an LLM:
    # {'up_proj', 'v_proj', 'gate_proj', 'k_proj', 'down_proj', 'q_proj', 'o_proj', 'lm_head'}
    multimodal_keywords = ["conn_ve_llm", "ve", "vision_resampler"]
    # multimodal_keywords = ['mm_projector', 'vision_resampler']
    lora_module_names = set()
    for name, module in model.named_modules():
        is_multimodal = any(keyword in name for keyword in multimodal_keywords)
        if is_multimodal or not isinstance(module, torch.nn.Linear):
            continue
        # Keep only the last component of the dotted module path.
        lora_module_names.add(name.split(".")[-1])
    rank0_print(f"==> lora modules: {lora_module_names}")
    # lm_head is excluded; the original notes this is needed for 16-bit.
    lora_module_names.discard("lm_head")
    return list(lora_module_names)
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk.

    Three paths, checked in order:

    1. ``trainer.args.tune_conn_ve_llm`` is set: only the connector weights
       (plus the embedding weights when ``use_im_start_end`` is set) are
       saved. For a ``checkpoint-*`` directory they go to
       ``<parent>/conn_ve_llm/<checkpoint>.bin``, otherwise to
       ``<output_dir>/conn_ve_llm.bin``. Only local rank 0 or -1 writes.
    2. DeepSpeed is active: delegate to ``trainer.save_model``.
    3. Otherwise: move the state dict to CPU and save it via ``trainer._save``
       when ``trainer.args.should_save`` is true.
    """
    args = trainer.args

    if getattr(args, "tune_conn_ve_llm", False):
        # Only save Adapter
        keys_to_match = ["conn_ve_llm"]
        if getattr(args, "use_im_start_end", False):
            keys_to_match += ["embed_tokens", "embed_in"]

        weight_to_save = get_mm_adapter_state_maybe_zero_3(
            trainer.model.named_parameters(), keys_to_match
        )
        trainer.model.config.save_pretrained(output_dir)

        current_folder = output_dir.split("/")[-1]
        parent_folder = os.path.dirname(output_dir)
        if args.local_rank in (0, -1):
            if current_folder.startswith("checkpoint-"):
                # Intermediate checkpoints are collected in one sibling folder.
                mm_projector_folder = os.path.join(parent_folder, "conn_ve_llm")
                os.makedirs(mm_projector_folder, exist_ok=True)
                target_path = os.path.join(
                    mm_projector_folder, f"{current_folder}.bin"
                )
            else:
                target_path = os.path.join(output_dir, "conn_ve_llm.bin")
            torch.save(weight_to_save, target_path)
        return

    if trainer.deepspeed:
        torch.cuda.synchronize()
        trainer.save_model(output_dir)
        return

    state_dict = trainer.model.state_dict()
    if args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
+21
View File
@@ -0,0 +1,21 @@
from datetime import datetime
major_num = 1
__version__ = "0.0.4"
short_version = __version__


def parse_version_info(version_str):
    """Turn a dotted version string into a tuple.

    Purely numeric components become ints. A component containing ``rc``
    (e.g. ``"0rc1"``) contributes the integer before ``rc`` followed by the
    string ``"rc<suffix>"``. Any other component is skipped.
    """
    parsed = []
    for component in version_str.split("."):
        if component.isdigit():
            parsed.append(int(component))
            continue
        if "rc" in component:
            pieces = component.split("rc")
            parsed += [int(pieces[0]), f"rc{pieces[1]}"]
    return tuple(parsed)


version_info = parse_version_info(__version__)
+19
View File
@@ -0,0 +1,19 @@
import torch
def load(pretrained_model_path):
    """
    Placeholder loader: intended to return model, tokenizer, image_processor.

    Not implemented yet; currently returns None.
    """
    return None
def generate(model, tokenizer, prompt, images):
    """Placeholder generation entry point; not implemented, returns None."""
    return None
if __name__ == "__main__":
    # load() and generate() in this file are unimplemented placeholders, so
    # this only exercises the call path. generate() takes four required
    # positional parameters; passing `model` alone raised TypeError.
    model = load("")
    output = generate(model, None, "", None)
+142
View File
@@ -0,0 +1,142 @@
<div align='center'>
<img src='https://img2023.cnblogs.com/blog/3572323/202502/3572323-20250221161349804-1020036173.png' style="border-radius: 15px;" />
<h1>Namo R1</h1>
</div>
<p align="center">
🤗 <a href="https://huggingface.co/lucasjin/Namo-500M-V1">Namo-500M-V1</a>&nbsp;&nbsp; | &nbsp;&nbsp;🐝 <a href="https://github.com/lucasjinreal/Namo-R1/issues/new">Community</a>
</p>
> **You**: *I don't have GPUs to run VLMs.* **Namo R1:** Hold my beer.... let's do this on CPU.
**Namo R1** 🔥🔥 surpasses SmolVLM and Moondream2 among models of the same size! We keep evolving, and more advanced models are in training!
## Introduction
We are excited to open-source **Namo**, an extremely small yet mighty MLLM. While numerous MLLMs exist, few offer true extensibility or fully open-source their training data, model architectures, and training schedulers - critical components for reproducible AI research.
The AI community has largely overlooked the potential of compact MLLMs, despite their demonstrated efficiency advantages. Our analysis reveals significant untapped potential in sub-billion parameter models, particularly for edge deployment and specialized applications. To address this gap, we're releasing Namo R1, a foundational 500M parameter model trained from scratch using innovative architectural choices.
Key innovations include:
1. **CPU friendly:** Even on CPUs, Namo R1 can run very fast;
2. **Omni-modal Scalability:** Native support for future expansion into audio (ASR/TTS) and cross-modal fusion;
3. **Training Transparency:** Full disclosure of data curation processes and dynamic curriculum scheduling techniques.
👇 Video Demo Runs on **CPU**:
<video src='https://github.com/user-attachments/assets/eb353124-509e-4b87-8a0d-b0b37b5efba2'></video>
## Updates
- **`2025.02.21`**: more to come...!
- **`2025.02.21`**: 🔥🔥 The first version is ready to open: unleash MLLM power that runs on CPU!
- **`2025.02.17`**: Namo R1 started training.
## Results
The results will keep being updated as new models are trained.
| Model | MMB-EN-T | MMB-CN-T | Size |
| -------------------- | -------------- | -------------- | ---- |
| Namo-500M | **68.8** | **48.7** | 500M |
| Namo-700M | training | training | 700M |
| Namo-500M-R1 | training | training | 500M |
| Namo-700M-R1 | training | training | 700M |
| SmolVLM-500M | 53.8 | 35.4 | 500M |
| SmolVLM-Instruct-DPO | 67.5 | 49.8 | 2.3B |
| Moondream1 | 62.3 | 19.8 | 1.9B |
| Moondream2 | 70 | 28.7 | 1.9B |
⚠️ Currently, the testing has only been conducted on a limited number of benchmarks. In the near future, more metrics will be reported. Even so, we've observed significant improvements compared to other small models.
## Get Started
#### Install & Run in Cli
All you need to do is:
```shell
pip install -U namo
```
A simple demo would be:
```python
from namo.api.vl import VLInfer
# model will download automatically
model = VLInfer(model_type='namo')
# default will have streaming
model.generate('what is this?', 'images/cats.jpg', stream=True)
```
That's all!
For multi-turn CLI chat in the terminal, you can run `python demo.py`. (Running the Namo CLI directly in your terminal will be available later.)
#### OpenAI server & Run in OpenWebUI
```shell
namo server --model checkpoints/Namo-500M-V1
```
Then you will have an OpenAI-like server running locally.
## Showcases
**Namo-500M**, our first small series of models, is capable of performing remarkable tasks such as multilingual OCR, general concept understanding, image captioning, and more. And it has only 500 million parameters! You can run it directly on a CPU!
<details>
<summary><strong>📁 Show more real use cases</strong></summary>
![img](https://img2023.cnblogs.com/blog/3572323/202502/3572323-20250220172027839-313683339.png)
![img](https://img2023.cnblogs.com/blog/3572323/202502/3572323-20250220173348864-1017625952.png)
![img](https://img2023.cnblogs.com/blog/3572323/202502/3572323-20250220172131111-556988890.png)
![img](https://img2023.cnblogs.com/blog/3572323/202502/3572323-20250220172105348-2075807231.png)
![img](https://img2023.cnblogs.com/blog/3572323/202502/3572323-20250220172241158-980404927.png)
![img](https://img2023.cnblogs.com/blog/3572323/202502/3572323-20250220172453851-1606010737.png)
![img](https://img2023.cnblogs.com/blog/3572323/202502/3572323-20250220172546006-49136083.png)
![img](https://img2023.cnblogs.com/blog/3572323/202502/3572323-20250220173000613-625271011.png)
</details>
## Features of Namo R1
In contrast to open-source VLMs like Qwen2.5-3B and MiniCPM, the Namo series offers the following features that enable anyone to train their own VLMs from scratch:
- **Extremely Small**: Our first series has only 500 million parameters, yet is powerful on various tasks.
- **OCR Capability**: With just a 500M model, you can perform multilingual OCR, covering not only Chinese and English but also Japanese and other languages.
- **Dynamic Resolution**: We support native dynamic resolution as input, making it robust for images of **any ratio**.
- **Fully Open Source**: We open-source all model code, including training steps and scripts!
- **R1 Support**: Yes, we now support R1 for post-training.
Above all, we are also ready to help when you want to train your MLLM from scratch on any task!
## Roadmap
We are still actively training new models. Here are a few things that are on the way:
- Speech model;
- Vision model with more decent vision encoders, such as SigLip2;
- TTS ability;
- Slightly larger models, up to 7B;
## Trouble Shooting
1. Got error when using deepspeed: ` AssertionError: no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage 2` ?
Please upgrade transformers to 4.48+ and use latest deepspeed.
## Copyright
All rights reserved by the Namo authors; code is released under the MIT License.
+3
View File
@@ -0,0 +1,3 @@
latext/
latext*/
invoices*/
+99
View File
@@ -0,0 +1,99 @@
"""
convert anyword into llava like format
"""
import sys
import json
# Only the last assignment takes effect; the two earlier values are
# overwritten immediately.
# NOTE(review): TYPE is not read anywhere in the visible part of this script —
# confirm whether it is still needed.
TYPE = "ego"
TYPE = "webvid"
TYPE = "videochat2"
def convert_json(input_json, json_file_path):
    """Convert OCR annotation entries into LLaVA-style conversation samples.

    Args:
        input_json: Parsed JSON object holding a ``"data_list"`` list. Entries
            with an ``"annotations"`` list (each annotation providing
            ``"text"`` and ``"language"``) and an ``"img_name"`` are converted.
        json_file_path: Path the JSON was loaded from. When it contains
            ``"laion"``, entries with any non-``"Latin"`` annotation are dropped.

    Returns:
        A list of dicts with ``"id"``, ``"image"`` and a two-turn
        ``"conversations"`` list: the OCR prompt followed by the annotation
        texts joined with newlines.
    """
    data = input_json["data_list"]
    print(f"processing data list: {len(data)}")
    latin_only = "laion" in json_file_path

    output_data = []
    for item in data:
        if "annotations" not in item:
            # Nothing to convert. Skip explicitly so that `new_item` and
            # `should_use` from a previous iteration are never reused (which
            # would re-append the previous sample, or raise NameError on the
            # very first entry).
            print(item)
            continue

        annotations = item["annotations"]
        ocr_text = [annotation["text"] for annotation in annotations]
        all_lans = [annotation["language"] for annotation in annotations]

        # For laion files keep only samples whose annotations are all Latin.
        if latin_only and any(lang != "Latin" for lang in all_lans):
            continue

        output_data.append(
            {
                "id": item["img_name"],
                "image": f"images/{item['img_name']}",
                "conversations": [
                    {
                        "from": "human",
                        "value": "<image>\nPlease provide precise OCR result of the image.",
                    },
                    {"from": "gpt", "value": "\n".join(ocr_text)},
                ],
            }
        )
    return output_data
def main():
    """CLI entry point: convert one annotation JSON file into ``<name>_ocr.json``.

    Expects exactly one argument, the path of the input JSON file. Exits with
    status 1 on a usage error, a missing file, or invalid JSON.
    """
    import os

    if len(sys.argv) != 2:
        print("Usage: python script.py <path_to_json_file>")
        sys.exit(1)
    json_file_path = sys.argv[1]
    try:
        with open(json_file_path, "r", encoding="utf-8") as file:
            input_json = json.load(file)
    except FileNotFoundError:
        print(f"Error: The file {json_file_path} does not exist.")
        sys.exit(1)
    except json.JSONDecodeError:
        print(f"Error: The file {json_file_path} is not a valid JSON file.")
        sys.exit(1)

    converted_data = convert_json(input_json, json_file_path)
    print(f"All {len(converted_data)} samples")

    # Drop the real extension instead of blindly cutting five characters, so
    # inputs that do not end in ".json" still get a sensible output name.
    stem, _ = os.path.splitext(json_file_path)
    file_path = f"{stem}_ocr.json"
    # The dump keeps non-ASCII text (ensure_ascii=False), so pin the encoding
    # rather than relying on the platform default.
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(converted_data, f, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    main()
+77
View File
@@ -0,0 +1,77 @@
"""
convert
https://huggingface.co/datasets/amaye15/invoices-google-ocr/
https://huggingface.co/datasets/mychen76/invoices-and-receipts_ocr_v2
"""
import json
from datasets import load_dataset
import uuid
import os
import ast
import sys
# Collects every converted sample; dumped to JSON at the end of the script.
all_res = []
# First CLI argument: dataset root; parquet shards are read from <root>/data/.
root_path = sys.argv[1]
ds = load_dataset("parquet", data_files=os.path.join(root_path, "data/*.parquet"))
def build_item(item, idx):
    """Convert one invoice record into a LLaVA-style OCR sample and save its image.

    Args:
        item: Dataset record. The image is read from ``"image"`` (or
            ``"pixel_values"`` when absent). The text comes from
            ``item["ocr"][0]["text"]``, or from the ``"ocr_words"`` list
            literal stored inside the JSON string ``item["raw_data"]``.
        idx: Identifier used for the sample id and the image file name.

    Returns:
        The sample dict, or ``None`` when ``item["ocr"]`` is empty.
    """
    # Resolve the text first: a record that is going to be skipped must not
    # leave an image on disk. The caller reuses the same idx for the next
    # record, and the existence check below would then keep the skipped
    # record's image under the next sample's name.
    if "ocr" in item:
        if len(item["ocr"]) < 1:
            return None
        text = item["ocr"][0]["text"]
    else:
        raw_data = json.loads(item["raw_data"])
        words = ast.literal_eval(raw_data["ocr_words"])
        text = "\n".join(words)

    new_item = {
        "id": f"{idx}",
        "image": f"images/{idx}.jpg",
        "conversations": [],
    }

    os.makedirs("invoices-ocr/images", exist_ok=True)
    source = item["image"] if "image" in item else item["pixel_values"]
    image = source.convert("RGB")
    img_f = os.path.join("invoices-ocr", new_item["image"])
    # Do not rewrite images that are already on disk.
    if not os.path.exists(img_f):
        image.save(img_f)

    new_item["conversations"].append(
        {
            "from": "human",
            "value": "<image>\nRead all text content visible on image in order.",
        }
    )
    new_item["conversations"].append({"from": "gpt", "value": text})
    return new_item
print(ds.items())
# normpath drops a trailing slash first; otherwise basename("path/to/ds/")
# is "" and every id and the output file name lose the dataset name.
ds_name = os.path.basename(os.path.normpath(root_path))
# idx only advances for kept samples, so sample ids stay contiguous.
idx = 0
for item in ds["train"]:
    new_item = build_item(item, f"invoices_{ds_name}_{idx}")
    if new_item is None:
        continue
    all_res.append(new_item)
    idx += 1
print(f"all res: {len(all_res)}")
file_path = f"invoices_ocr_{ds_name}.json"
# The dump keeps non-ASCII text, so pin the encoding explicitly.
with open(file_path, "w", encoding="utf-8") as f:
    json.dump(all_res, f, ensure_ascii=False, indent=2)
+68
View File
@@ -0,0 +1,68 @@
"""
convert
https://huggingface.co/datasets/linxy/LaTeX_OCR
"""
import json
from datasets import load_dataset
import uuid
import os
# Collects every converted sample; dumped to latex_ocr.json at the end.
all_res = []
# Local parquet shards, read relative to the current working directory.
ds = load_dataset("parquet", data_files="latex/full/*.parquet")
def build_item(item, idx):
    """Convert one LaTeX-OCR record into a LLaVA-style sample and save its image.

    Args:
        item: Dataset record with an ``"image"`` (object offering
            ``convert``/``save``) and the target ``"text"``.
        idx: Identifier used for the sample id and the image file name.

    Returns:
        Dict with ``"id"``, ``"image"`` (path relative to ``latex_ocr/``) and a
        two-turn ``"conversations"`` list.
    """
    new_item = {
        "id": f"{idx}",
        "image": f"images/{idx}.jpg",
        "conversations": [],
    }
    os.makedirs("latex_ocr/images", exist_ok=True)
    # Keep the RGB copy local instead of writing it back into the caller's
    # record.
    rgb_image = item["image"].convert("RGB")
    rgb_image.save(os.path.join("latex_ocr", new_item["image"]))
    new_item["conversations"].append(
        {"from": "human", "value": "<image>\nPlease convert this into latex format."}
    )
    new_item["conversations"].append({"from": "gpt", "value": item["text"]})
    return new_item
print(ds.items())
# Only the "train" splits are exported; the loops over the "test" splits were
# disabled in the original script.
for idx, item in enumerate(ds["train"]):
    all_res.append(build_item(item, f"latex-ocr-full-train-{idx}"))

# Second source: the "human_handwrite" config, streamed from the hub.
ds = load_dataset("linxy/LaTeX_OCR", "human_handwrite", streaming=True)
for idx, item in enumerate(ds["train"]):
    all_res.append(build_item(item, f"latex-ocr-human_handwrite-train-{idx}"))

print(f"all res: {len(all_res)}")
file_path = "latex_ocr.json"
# The dump keeps non-ASCII text, so pin the encoding explicitly.
with open(file_path, "w", encoding="utf-8") as f:
    json.dump(all_res, f, ensure_ascii=False, indent=2)
+4
View File
@@ -0,0 +1,4 @@
"""
HuggingFaceM4/LLaVAR-Instruct-16K
"""
+99
View File
@@ -0,0 +1,99 @@
"""
sampling vary data
"""
import json
import os
import sys
import shutil
from PIL import Image
# Path of the annotation JSON file (first CLI argument); the image folders are
# looked up next to it.
a = sys.argv[1]
def cn():
    """Copy sample images from ``pdf_cn_30w`` and write ``*_samples.json``.

    Reads the annotation list from the module-level path ``a``. Entries whose
    index is below 100 and whose image exists are copied to
    ``pdf_cn_30w_samples`` and collected into the samples file.
    """
    img_root = os.path.join(os.path.dirname(a), "pdf_cn_30w")
    sample_img_root = os.path.join(os.path.dirname(a), "pdf_cn_30w_samples")
    os.makedirs(sample_img_root, exist_ok=True)
    with open(a, "r", encoding="utf-8") as f:
        res = json.load(f)

    samples = []
    for i, itm in enumerate(res):
        img_f = os.path.join(img_root, itm["image"])
        if not os.path.exists(img_f):
            print(f"{img_f} not found")
            # Skip it, as en() does; copying a missing file would raise.
            continue
        if i < 100:
            target_img_f = os.path.join(sample_img_root, itm["image"])
            os.makedirs(os.path.dirname(target_img_f), exist_ok=True)
            shutil.copy(img_f, target_img_f)
            samples.append(itm)
    print(f"done {len(res)}")

    file_path = a.replace(".json", "_samples.json")
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(samples, f, ensure_ascii=False, indent=2)
def en():
    """Sample and filter the entries whose images live under ``pdf_en_30w``.

    Reads the annotation list from the module-level path ``a`` and writes:
    ``*_samples.json`` with the existing entries among the first 100 indices
    (their images are copied to ``pdf_en_30w_samples``), and ``*_subset.json``
    with every entry whose image exists.
    """
    img_root = os.path.join(os.path.dirname(a), "pdf_en_30w")
    sample_img_root = os.path.join(os.path.dirname(a), "pdf_en_30w_samples")
    os.makedirs(sample_img_root, exist_ok=True)
    # Context manager so the input handle is closed deterministically.
    with open(a, "r", encoding="utf-8") as f:
        res = json.load(f)

    samples = []
    new_res = []
    for i, itm in enumerate(res):
        img_f = os.path.join(img_root, itm["image"])
        if not os.path.exists(img_f):
            print(f"{img_f} not found")
            continue
        new_res.append(itm)
        if i < 100:
            target_img_f = os.path.join(sample_img_root, itm["image"])
            os.makedirs(os.path.dirname(target_img_f), exist_ok=True)
            shutil.copy(img_f, target_img_f)
            samples.append(itm)
    print(f"done {len(res)}")

    file_path = a.replace(".json", "_samples.json")
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(samples, f, ensure_ascii=False, indent=2)

    file_path = a.replace(".json", "_subset.json")
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(new_res, f, ensure_ascii=False, indent=2)
    print(f"done {len(new_res)}")
def cn_subset():
    """Keep the entries whose image has a side shorter than 660 pixels.

    Reads the annotation list from the module-level path ``a``, opens each
    image under ``pdf_cn_30w`` and writes the entries with width or height
    below 660 to ``*_subset.json``.
    """
    img_root = os.path.join(os.path.dirname(a), "pdf_cn_30w")
    with open(a, "r", encoding="utf-8") as f:
        res = json.load(f)

    samples = []
    for itm in res:
        img_f = os.path.join(img_root, itm["image"])
        if not os.path.exists(img_f):
            print(f"{img_f} not found")
            # Opening a missing file would raise; skip the entry instead.
            continue
        # Close each image right after reading its size so file handles do not
        # pile up while iterating over the whole list.
        with Image.open(img_f) as image:
            width, height = image.size[0], image.size[1]
        if width < 660 or height < 660:
            samples.append(itm)
    print(f"done {len(res)}")
    print(f"done {len(samples)}")

    file_path = a.replace(".json", "_subset.json")
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(samples, f, ensure_ascii=False, indent=2)
# Select the job to run by (un)commenting; currently only cn_subset() runs.
# en()
cn_subset()
+56
View File
@@ -0,0 +1,56 @@
"""
convert
https://huggingface.co/datasets/linxy/LaTeX_OCR
"""
import json
from datasets import load_dataset
import uuid
import os
import sys
# Collects every converted sample; dumped to ZhEn_latex_ocr.json at the end.
all_res = []
# First CLI argument: dataset root; parquet shards are read from <root>/data/.
root_path = sys.argv[1]
ds = load_dataset("parquet", data_files=os.path.join(root_path, "data/*.parquet"))
def build_item(item, idx):
    """Convert one record into a LLaVA-style LaTeX-OCR sample and save its image.

    Args:
        item: Dataset record with an ``"image"`` (object offering
            ``convert``/``save``) and the target ``"text"``.
        idx: Identifier used for the sample id and the image file name.

    Returns:
        Dict with ``"id"``, ``"image"`` (path relative to ``ZhEn_latex_ocr/``)
        and a two-turn ``"conversations"`` list.
    """
    new_item = {
        "id": f"{idx}",
        "image": f"images/{idx}.jpg",
        "conversations": [],
    }
    os.makedirs("ZhEn_latex_ocr/images", exist_ok=True)
    # Keep the RGB copy local instead of writing it back into the caller's
    # record.
    rgb_image = item["image"].convert("RGB")
    rgb_image.save(os.path.join("ZhEn_latex_ocr", new_item["image"]))
    new_item["conversations"].append(
        {
            "from": "human",
            "value": "<image>\nPlease convert all text in image into precise latex format.",
        }
    )
    new_item["conversations"].append({"from": "gpt", "value": item["text"]})
    return new_item
print(ds.items())
for idx, item in enumerate(ds["train"]):
    all_res.append(build_item(item, f"zh-en-latex-ocr-full-train-{idx}"))
print(f"all res: {len(all_res)}")
file_path = "ZhEn_latex_ocr.json"
# The dump keeps non-ASCII text, so pin the encoding explicitly.
with open(file_path, "w", encoding="utf-8") as f:
    json.dump(all_res, f, ensure_ascii=False, indent=2)
+63
View File
@@ -0,0 +1,63 @@
import torch
from transformers import AutoModel, AutoTokenizer
import argparse
import os
def convert_and_save_bf16(model_path, output_dir=None):
    """Load a Hugging Face model in bfloat16 and save it as sharded safetensors.

    Args:
        model_path: Local directory or hub name passed to
            ``AutoModel.from_pretrained``.
        output_dir: Destination directory. Defaults to ``model_path`` with any
            trailing slash removed and ``"_bf16"`` appended.

    Raises:
        Exception: Any failure while loading or saving the model is reported
            on stdout and re-raised. A tokenizer failure is only reported.
    """
    try:
        if output_dir is None:
            # Strip only the trailing slash: strip("/") would also remove the
            # leading one and turn an absolute path into a relative one.
            output_dir = model_path.rstrip("/") + "_bf16"
        os.makedirs(output_dir, exist_ok=True)

        print(f"⏳ 正在加载原始模型来自: {model_path}")
        model = AutoModel.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,  # load directly as BF16
            # device_map="auto",  # automatic device placement (disabled)
            low_cpu_mem_usage=True,  # reduce peak CPU memory while loading
            trust_remote_code=True,
        )

        print("🔧 正在转换模型权重到BF16...")
        model = model.to(torch.bfloat16)

        print(f"💾 正在保存BF16模型到: {output_dir}")
        model.save_pretrained(
            output_dir,
            safe_serialization=True,  # write safetensors files
            max_shard_size="6GB",  # shard size limit
        )

        # Save the tokenizer too; best effort, since not every checkpoint
        # ships one.
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                model_path, trust_remote_code=True
            )
            tokenizer.save_pretrained(output_dir)
        except Exception as e:
            # Report the reason instead of hiding it.
            print(f"passing save tokenzier: {e}")

        print("✅ 转换完成!保存内容:")
        # safe_serialization=True produces *.safetensors, not pytorch_model*.bin.
        print(f" - 模型权重: {output_dir}/*.safetensors")
        print(f" - 配置文件: {output_dir}/config.json")
        print(f" - Tokenizer文件: {output_dir}/tokenizer.*")
    except Exception as e:
        print(f"❌ 错误发生: {str(e)}")
        raise
if __name__ == "__main__":
    # CLI: positional model path plus an optional --output_dir; when omitted,
    # convert_and_save_bf16 derives the output directory from the model path.
    parser = argparse.ArgumentParser(description="转换HF模型到BF16格式")
    parser.add_argument(
        "model_path",
        type=str,
        help="输入模型路径(本地目录或HF Hub名称)",
    )
    parser.add_argument("--output_dir")
    args = parser.parse_args()
    convert_and_save_bf16(args.model_path, args.output_dir)
+74
View File
@@ -0,0 +1,74 @@
from transformers import AutoConfig
from transformers import TextStreamer
from namo.models.namo import NamoForCausalLM
from namo.models.configuration_namo import NamoConfig
from namo.utils.infer_utils import load_multi_images_maybe
from namo.utils.process_utils import tokenizer_image_token
import torch
from loguru import logger
import sys
"""
<|im_start|>system\nYou should follow the instructions carefully and explain your answers in detail.<|im_end|><|im_start|>user\n<imag
e>\nDescribe the following image.<|im_end|><|im_start|>assistant\n
"""
# Inference-only script: disable autograd globally.
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Checkpoint directory: first CLI argument, or the default local checkpoint.
if len(sys.argv) == 1:
    model_path = "checkpoints/namo-500m"
else:
    model_path = sys.argv[1]
logger.info(f"load namo from: {model_path}")
namo_model = NamoForCausalLM.from_pretrained(model_path).to(device)
logger.success("namo model all loaded.")
# The image processor is taken from the model's vision tower.
image_processor = namo_model.get_vision_tower().image_processor
# images = load_multi_images_maybe("images/cats.jpg")
images = load_multi_images_maybe("images/kobe.jpg")
# Preprocess, then move the tensor to the model's device and dtype.
pixel_values = (
    image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
    .to(namo_model.device)
    .to(namo_model.dtype)
)
print(f"pixel_values: {pixel_values.shape}")
tokenizer = namo_model.get_namo().tokenizer
chat = [
    {
        "role": "system",
        "content": "You should follow the instructions carefully and explain your answers in detail.",
    },
    {"role": "user", "content": "<image>\nDescribe the following image."},
]
# The assistant header is appended by hand after rendering the chat template.
prompt = tokenizer.apply_chat_template(chat, tokenize=False) + "<|im_start|>assistant\n"
print(prompt)
# Tokenize the prompt and add a batch dimension.
# NOTE(review): presumably tokenizer_image_token gives the "<image>" tag
# special handling — confirm against namo.utils.process_utils.
input_ids = (
    tokenizer_image_token(
        prompt,
        tokenizer,
        return_tensors="pt",
    )
    .unsqueeze(0)
    .to(namo_model.device)
)
print(input_ids)
# Stream generated text to stdout, skipping the prompt and special tokens.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# NOTE(review): autocast is hard-coded to device_type="cuda" although `device`
# may be "cpu" — confirm the intended behaviour on CPU-only machines.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    output_ids = namo_model.generate(
        pixel_values=pixel_values,
        input_ids=input_ids,
        do_sample=False,
        max_new_tokens=360,
        streamer=streamer,
        use_cache=True,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
outputs = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
print(f"final output:\n{outputs}")
+9
View File
@@ -0,0 +1,9 @@
from namo.api.vl import VLInfer
# Build the inference wrapper for the "qwen2.5-vl" model type.
vl = VLInfer(model_type="qwen2.5-vl")
# Alternative prompts for the same image; enable one at a time.
# vl.generate("what is funny in this image?", "images/extreme_ironing.jpg")
# vl.generate("Outline the position of each car, output in json format", "images/extreme_ironing.jpg")
# vl.generate("Locate the person ironing cloth", "images/extreme_ironing.jpg")
# vl.generate("Point the blue shirt", "images/extreme_ironing.jpg")
vl.generate("Point all the cars in image", "images/extreme_ironing.jpg")
+5
View File
@@ -0,0 +1,5 @@
# Download the base LLM and vision encoders into ./checkpoints.
# -p: do not fail when checkpoints/ already exists (re-runs).
mkdir -p checkpoints
# Abort if the directory cannot be entered, so downloads never land elsewhere.
cd checkpoints/ || exit 1
huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir Qwen2.5-0.5B-Instruct
huggingface-cli download lucasjin/aimv2-large-patch14-224 --local-dir aimv2-large-patch14-224
huggingface-cli download lucasjin/aimv2-large-patch14-native --local-dir aimv2-large-patch14-native
+313
View File
@@ -0,0 +1,313 @@
"""
Runs evaluation with VLMEvalKit.
Namo's result can be easily replicated with it.
"""
import torch
import torch.distributed as dist
from vlmeval.smp import *
from vlmeval.inference import infer_data_job
from vlmeval.config import supported_VLM
from vlmeval.utils.result_transfer import MMMU_result_transfer, MMTBench_result_transfer
from vlmeval.dataset import build_dataset
from functools import partial
from vlmeval.vlm import Namo
from loguru import logger
def parse_args():
    """Parse the command-line options of the evaluation runner.

    Returns:
        argparse.Namespace with ``data`` (list of dataset names), ``model``
        (a single string), the video-only options ``nframe``, ``pack`` and
        ``use_subtitle``, and the general options ``work_dir``, ``mode``,
        ``nproc``, ``retry``, ``judge``, ``ignore``, ``verbose``, ``rerun``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, nargs="+", required=True)
    parser.add_argument("--model", type=str, required=True)
    # Args that only apply to Video Dataset
    parser.add_argument("--nframe", type=int, default=8)
    parser.add_argument("--pack", action="store_true")
    # main() reads args.use_subtitle on the video paths; without this option
    # that access raises AttributeError.
    parser.add_argument("--use-subtitle", action="store_true")
    parser.add_argument(
        "--work-dir",
        type=str,
        default="./eval_results/",
        help="select the output directory",
    )
    parser.add_argument("--mode", type=str, default="all", choices=["all", "infer"])
    parser.add_argument("--nproc", type=int, default=4, help="Parallel API calling")
    parser.add_argument(
        "--retry", type=int, default=None, help="retry numbers for API VLMs"
    )
    parser.add_argument("--judge", type=str, default=None)
    parser.add_argument("--ignore", action="store_true", help="Ignore failed indices. ")
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--rerun", action="store_true")
    args = parser.parse_args()
    return args
def main():
    """Run inference and, in "all" mode, evaluation for every model x dataset pair.

    For each dataset the function builds the dataset (on rank 0 first when
    running distributed), derives the result file name, runs the matching
    inference job, assembles the judge arguments and, on rank 0 with
    ``--mode all``, calls ``dataset.evaluate`` and logs the results. A failure
    in one model/dataset combination is logged and the next one is tried.
    """
    import glob

    logger = get_logger("RUN")
    args = parse_args()
    assert len(args.data), "--data should be a list of data files"
    # insert local model
    if "namo-500m" in args.model.lower():
        supported_VLM.update({"Namo-500M": partial(Namo, model_path=args.model)})
        args.model = ["Namo-500M"]
    else:
        # --model is parsed as a single string; wrap it so the loop below
        # iterates over model names instead of over its characters.
        args.model = [args.model]
    logger.info(f"eval on model: {args.model}")

    # Tolerate a parser that does not define --use-subtitle.
    use_subtitle = getattr(args, "use_subtitle", False)

    if args.retry is not None:
        for k, v in supported_VLM.items():
            if hasattr(v, "keywords") and "retry" in v.keywords:
                v.keywords["retry"] = args.retry
                supported_VLM[k] = v
            if hasattr(v, "keywords") and "verbose" in v.keywords:
                v.keywords["verbose"] = args.verbose
                supported_VLM[k] = v

    rank, world_size = get_rank_and_world_size()
    if world_size > 1:
        local_rank = os.environ.get("LOCAL_RANK", 0)
        torch.cuda.set_device(int(local_rank))
        dist.init_process_group(
            backend="nccl", timeout=datetime.timedelta(seconds=10800)
        )

    for _, model_name in enumerate(args.model):
        model = None
        pred_root = osp.join(args.work_dir, model_name)
        os.makedirs(pred_root, exist_ok=True)

        for _, dataset_name in enumerate(args.data):
            try:
                dataset_kwargs = {}
                if dataset_name in [
                    "MMLongBench_DOC",
                    "DUDE",
                    "DUDE_MINI",
                    "SLIDEVQA",
                    "SLIDEVQA_MINI",
                ]:
                    dataset_kwargs["model"] = model_name
                if dataset_name == "MMBench-Video":
                    dataset_kwargs["pack"] = args.pack
                if dataset_name == "Video-MME":
                    dataset_kwargs["use_subtitle"] = use_subtitle

                # If distributed, first build the dataset on the main process for doing preparation works
                if world_size > 1:
                    dataset = (
                        build_dataset(dataset_name, **dataset_kwargs)
                        if rank == 0
                        else None
                    )
                    dist.barrier()
                    dataset_list = [dataset]
                    dist.broadcast_object_list(dataset_list, src=0)
                    dataset = dataset_list[0]
                else:
                    dataset = build_dataset(dataset_name, **dataset_kwargs)
                if dataset is None:
                    logger.error(
                        f"Dataset {dataset_name} is not valid, will be skipped. "
                    )
                    continue

                result_file = f"{pred_root}/{model_name}_{dataset_name}.xlsx"
                if dataset_name in ["MMBench-Video"]:
                    packstr = "pack" if args.pack else "nopack"
                    result_file = f"{pred_root}/{model_name}_{dataset_name}_{args.nframe}frame_{packstr}.xlsx"
                elif dataset.MODALITY == "VIDEO":
                    if args.pack:
                        logger.info(
                            f"{dataset_name} not support Pack Mode, directly change to unpack"
                        )
                        args.pack = False
                    packstr = "pack" if args.pack else "nopack"
                    result_file = f"{pred_root}/{model_name}_{dataset_name}_{args.nframe}frame_{packstr}.xlsx"
                    if dataset_name in ["Video-MME"]:
                        subtitlestr = "subs" if use_subtitle else "nosubs"
                        result_file = result_file.replace(
                            ".xlsx", f"_{subtitlestr}.xlsx"
                        )

                if dataset.TYPE == "MT":
                    result_file = result_file.replace(".xlsx", ".tsv")

                if osp.exists(result_file) and args.rerun:
                    # Remove stale judge outputs without going through a shell:
                    # the names come from the command line.
                    for keyword in ["openai", "gpt", "auxmatch"]:
                        prefix = f"{pred_root}/{model_name}_{dataset_name}_{keyword}"
                        for stale in glob.glob(glob.escape(prefix) + "*"):
                            if osp.isfile(stale):
                                os.remove(stale)

                if model is None:
                    model = model_name  # which is only a name

                # Perform the Inference
                # NOTE(review): infer_data_job_video / infer_data_job_mt are
                # not imported explicitly in this file — confirm that
                # `from vlmeval.smp import *` provides them.
                if dataset.MODALITY == "VIDEO":
                    model = infer_data_job_video(
                        model,
                        work_dir=pred_root,
                        model_name=model_name,
                        dataset=dataset,
                        nframe=args.nframe,
                        pack=args.pack,
                        verbose=args.verbose,
                        subtitle=use_subtitle,
                        api_nproc=args.nproc,
                    )
                elif dataset.TYPE == "MT":
                    model = infer_data_job_mt(
                        model,
                        work_dir=pred_root,
                        model_name=model_name,
                        dataset=dataset,
                        verbose=args.verbose,
                        api_nproc=args.nproc,
                        ignore_failed=args.ignore,
                    )
                else:
                    model = infer_data_job(
                        model,
                        work_dir=pred_root,
                        model_name=model_name,
                        dataset=dataset,
                        verbose=args.verbose,
                        api_nproc=args.nproc,
                        ignore_failed=args.ignore,
                    )

                # Set the judge kwargs first before evaluation or dumping
                judge_kwargs = {
                    "nproc": args.nproc,
                    "verbose": args.verbose,
                }
                if args.retry is not None:
                    judge_kwargs["retry"] = args.retry
                if args.judge is not None:
                    judge_kwargs["model"] = args.judge
                else:
                    if dataset.TYPE in ["MCQ", "Y/N"] or listinstr(
                        ["MathVerse"], dataset_name
                    ):
                        judge_kwargs["model"] = "chatgpt-0125"
                    elif listinstr(
                        [
                            "MMVet",
                            "MathVista",
                            "LLaVABench",
                            "MMBench-Video",
                            "MathVision",
                        ],
                        dataset_name,
                    ):
                        judge_kwargs["model"] = "gpt-4-turbo"
                    elif listinstr(
                        [
                            "MMLongBench",
                            "MMDU",
                            "DUDE",
                            "DUDE_MINI",
                            "SLIDEVQA",
                            "SLIDEVQA_MINI",
                        ],
                        dataset_name,
                    ):
                        judge_kwargs["model"] = "gpt-4o"
                if "OPENAI_API_KEY_JUDGE" in os.environ and len(
                    os.environ["OPENAI_API_KEY_JUDGE"]
                ):
                    judge_kwargs["key"] = os.environ["OPENAI_API_KEY_JUDGE"]
                if "OPENAI_API_BASE_JUDGE" in os.environ and len(
                    os.environ["OPENAI_API_BASE_JUDGE"]
                ):
                    judge_kwargs["api_base"] = os.environ["OPENAI_API_BASE_JUDGE"]

                # Datasets that are only exported or reported, not evaluated
                # locally.
                if rank == 0:
                    if dataset_name in ["MMMU_TEST"]:
                        result_json = MMMU_result_transfer(result_file)
                        logger.info(
                            f"Transfer MMMU_TEST result to json for official evaluation, "
                            f"json file saved in {result_json}"
                        )  # noqa: E501
                        continue
                    elif "MMT-Bench_ALL" in dataset_name:
                        submission_file = MMTBench_result_transfer(
                            result_file, **judge_kwargs
                        )
                        logger.info(
                            f"Extract options from prediction of MMT-Bench FULL split for official evaluation "
                            f"(https://eval.ai/web/challenges/challenge-page/2328/overview), "
                            f"submission file saved in {submission_file}"
                        )  # noqa: E501
                        continue
                    elif "MLLMGuard_DS" in dataset_name:
                        logger.info(
                            "The evaluation of MLLMGuard_DS is not supported yet. "
                        )  # noqa: E501
                        continue
                    elif "AesBench_TEST" == dataset_name:
                        logger.info(
                            f"The results are saved in {result_file}. "
                            f"Please send it to the AesBench Team via huangyipo@hotmail.com."
                        )  # noqa: E501
                        continue

                if dataset_name in [
                    "MMBench_TEST_CN",
                    "MMBench_TEST_EN",
                    "MMBench",
                    "MMBench_CN",
                    "MMBench_TEST_CN_V11",
                    "MMBench_TEST_EN_V11",
                    "MMBench_V11",
                    "MMBench_CN_V11",
                ]:
                    if not MMBenchOfficialServer(dataset_name):
                        logger.error(
                            f"Can not evaluate {dataset_name} on non-official servers, "
                            "will skip the evaluation. "
                        )
                        continue

                eval_proxy = os.environ.get("EVAL_PROXY", None)
                old_proxy = os.environ.get("HTTP_PROXY", "")

                if rank == 0 and args.mode == "all":
                    if eval_proxy is not None:
                        proxy_set(eval_proxy)

                    eval_results = dataset.evaluate(result_file, **judge_kwargs)
                    if eval_results is not None:
                        assert isinstance(eval_results, dict) or isinstance(
                            eval_results, pd.DataFrame
                        )
                        logger.info(
                            f"The evaluation of model {model_name} x dataset {dataset_name} has finished! "
                        )
                        logger.info("Evaluation Results:")
                    if isinstance(eval_results, dict):
                        logger.info("\n" + json.dumps(eval_results, indent=4))
                    elif isinstance(eval_results, pd.DataFrame):
                        if len(eval_results) < len(eval_results.columns):
                            eval_results = eval_results.T
                        logger.info("\n" + tabulate(eval_results))

                    if eval_proxy is not None:
                        proxy_set(old_proxy)
            except Exception as e:
                logger.exception(
                    f"Model {model_name} x Dataset {dataset_name} combination failed: {e}, "
                    "skipping this combination."
                )
                continue

    if world_size > 1:
        dist.barrier()


if __name__ == "__main__":
    load_env()
    main()
+51
View File
@@ -0,0 +1,51 @@
"""
Extracting VE from a base trained model
before sft.
"""
from transformers import AutoConfig
from transformers import TextStreamer
from namo.models.namo import NamoForCausalLM
from namo.models.configuration_namo import NamoConfig
from namo.utils.infer_utils import load_multi_images_maybe
from namo.utils.process_utils import tokenizer_image_token
import torch
from loguru import logger
import sys
# Weights are only loaded and re-saved: disable autograd globally.
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
def main():
    """Load a trained Namo checkpoint and save its vision encoder separately.

    The checkpoint path is the first CLI argument. The vision tower and its
    image processor are written to a directory chosen from the vision tower
    name (or, for the 3B variant, from the checkpoint path). Unsupported
    vision towers are reported and nothing is saved.
    """
    if len(sys.argv) < 2:
        print("provide the pretrained model path please.")
        sys.exit(1)
    model_path = sys.argv[1]
    logger.info(f"load namo from: {model_path}")
    namo_model = NamoForCausalLM.from_pretrained(model_path).to(device)
    logger.success("namo model all loaded.")

    ve = namo_model.get_vision_tower()
    image_processor = ve.image_processor

    # NOTE(review): the second branch matches on model_path rather than on
    # ve.vision_tower_name — confirm this asymmetry is intended.
    if "aimv2-large-patch14-native" in ve.vision_tower_name:
        save_model_path = "checkpoints/aimv2-l-native-trained-base"
    elif "aimv2-3b-p14" in model_path:
        save_model_path = "checkpoints/aimv2-3b-p14-trained-base"
    else:
        # No target directory is defined for this tower; stop here instead of
        # hitting an unbound `save_model_path` below.
        logger.error(f"unsupported vision model type: {ve.vision_tower_name}")
        return

    ve.save_pretrained(save_model_path)
    image_processor.save_pretrained(save_model_path)
    logger.success(f"ve should be saved into: {save_model_path}")


if __name__ == "__main__":
    main()
+41
View File
@@ -0,0 +1,41 @@
"""
We need to keep only the vary data entries whose images exist.
"""
import json
import os
import sys
def filter_json_by_image(input_json_path, image_root, output_json_path):
    """Write the entries of a JSON list whose image file exists under image_root.

    Args:
        input_json_path: JSON file containing a list of dicts.
        image_root: Directory that each entry's ``"image"`` path is joined to.
        output_json_path: Destination for the filtered list.
    """
    # Load the full annotation list.
    with open(input_json_path, "r", encoding="utf-8") as src:
        entries = json.load(src)

    def _has_image(entry):
        # Entries without a (non-empty) image path are dropped as well.
        relative_path = entry.get("image")
        if not relative_path:
            return False
        return os.path.exists(os.path.join(image_root, relative_path))

    filtered_data = [entry for entry in entries if _has_image(entry)]

    # Write the surviving entries to the new file.
    with open(output_json_path, "w", encoding="utf-8") as dst:
        json.dump(filtered_data, dst, indent=2, ensure_ascii=False)
    print(f"samples all: {len(filtered_data)}")
# Example usage: the input JSON is the first CLI argument, images are looked
# up under "data/", and the result is written next to the input file.
input_json_path = sys.argv[1]
image_root = "data/"
output_json_path = os.path.join(os.path.dirname(input_json_path), "vary_filtered.json")
filter_json_by_image(input_json_path, image_root, output_json_path)
+6
View File
@@ -0,0 +1,6 @@
"""
Load the non-LoRA weights and LoRA weights and merge them into the pretrained model.
"""
+54
View File
@@ -0,0 +1,54 @@
"""
test a simple multimodal request
if you have run a Namo model in your local:
namo server --model namo
"""
from openai import OpenAI
import base64
import os
# Configure the client for the local API server.
client = OpenAI(
    base_url="http://127.0.0.1:8000/v1",  # adjust to the actual API path
    api_key="sk-r4536ybrtb",  # can be left empty if no authentication is needed
)
def encode_image(image_path):
    """Return the contents of the local file at image_path as a base64 string."""
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
img_f = "images/cats.jpg"
# Multimodal request (text + image), streamed back chunk by chunk.
response = client.chat.completions.create(
    model="gpt-4-vision-preview",  # adjust to the actually deployed model
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "请描述这张图片的内容"},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{encode_image(img_f)}"
                    },
                },
            ],
        }
    ],
    max_tokens=300,
    stream=True,
)
print(response)
for chunk in response:
    # A streamed chunk may carry no choices or no text (for example the final
    # chunk); printing delta.content unconditionally would emit "None".
    if not chunk.choices:
        continue
    content = chunk.choices[0].delta.content
    if content:
        print(content, end="", flush=True)
Executable
+6
View File
@@ -0,0 +1,6 @@
# Stage and commit everything with a fixed message, then push the main branch
# to both the "origin" and "github" remotes.
# autopep8 -r ./minigemini/ -i
git add .
git commit -am 'add'
git push origin main
git push github main
+6
View File
@@ -0,0 +1,6 @@
from namo.utils.process_utils import convert_image_tags
# Manual check: print what convert_image_tags returns for a prompt containing
# two "<image>" tags (the single-tag variant is kept commented out).
# a = convert_image_tags('what in these images?\n<image>')
a = convert_image_tags("what in these images?\n<image>\n<image>")
print(a)
+26
View File
@@ -0,0 +1,26 @@
from transformers.models.qwen2_5_vl import Qwen2_5_VLImageProcessor
from PIL import Image
from namo.processor.image_processing_namo import NamoImageProcessor
def load_images():
    """Open the fixed set of sample images and return them as a list."""
    sample_paths = (
        "images/3001910.jpg",
        "images/3001927.jpg",
        "images/3001980.jpg",
        "images/grey.jpg",
    )
    return [Image.open(path) for path in sample_paths]
imgs = load_images()
# The Qwen2.5-VL processor is kept commented out as an alternative.
# processor = Qwen2_5_VLImageProcessor.from_pretrained(
#     "checkpoints/Qwen2.5-VL-3B-Instruct"
# )
processor = NamoImageProcessor.from_pretrained("checkpoints/Namo-500M-V1")
# Preprocess all images at once and show the shape of each result.
inputs = processor.preprocess(images=imgs)
print(inputs)
print([i.shape for i in inputs["pixel_values"]])
+82
View File
@@ -0,0 +1,82 @@
from transformers import AutoConfig
from transformers import TextStreamer
from namo.models.namo import NamoForCausalLM
from namo.models.configuration_namo import NamoConfig
from namo.utils.infer_utils import load_multi_images_maybe
from namo.utils.process_utils import tokenizer_image_token
import torch
# Assemble a Namo model from a text config and a vision config, load
# connector weights, run one sanity generation, then save everything as a
# single checkpoint.
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
text_config = AutoConfig.from_pretrained(
    "checkpoints/Qwen2.5-0.5B-Instruct", trust_remote_code=True
)
vision_config = AutoConfig.from_pretrained(
    # "checkpoints/aimv2-large-patch14-native", trust_remote_code=True
    "checkpoints/aimv2-l-native-trained-base",
    trust_remote_code=True,
)
config = NamoConfig(text_config=text_config, vision_config=vision_config)
namo_model = NamoForCausalLM(config=config).to(device)
# NOTE(review): presumably this loads the connector (and possibly vision
# encoder / LLM) weights from a training checkpoint — confirm against
# load_conn_ve_llm_weights.
namo_model.namo.load_conn_ve_llm_weights(
    "checkpoints/namo-qwen2-500m-aimv2-native-conn-ve-mlp2x_gelu/checkpoint-2500/conn_ve_llm.bin"
)
# The image processor is taken from the model's vision tower.
image_processor = namo_model.get_vision_tower().image_processor
images = load_multi_images_maybe("images/cats.jpg")
# Preprocess, then cast to the model's dtype and move to its device.
pixel_values = (
    image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
    .to(namo_model.dtype)
    .to(namo_model.device)
)
print(f"pixel_values: {pixel_values.shape}")
tokenizer = namo_model.get_namo().tokenizer
chat = [
    {
        "role": "system",
        "content": "You should follow the instructions carefully and explain your answers in detail.",
    },
    {"role": "user", "content": "<image>\nDescribe the following image."},
]
# The assistant header is appended by hand after rendering the chat template.
prompt = tokenizer.apply_chat_template(chat, tokenize=False) + "<|im_start|>assistant\n"
print(prompt)
# Tokenize the prompt and add a batch dimension.
input_ids = (
    tokenizer_image_token(
        prompt,
        # "hello, how are you.",
        tokenizer,
        return_tensors="pt",
    )
    .unsqueeze(0)
    .to(namo_model.device)
)
print(input_ids)
# Stream generated text to stdout, skipping the prompt and special tokens.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# NOTE(review): autocast is hard-coded to device_type="cuda" although `device`
# may be "cpu" — confirm the intended behaviour on CPU-only machines.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    output_ids = namo_model.generate(
        pixel_values=pixel_values,
        input_ids=input_ids,
        do_sample=False,
        max_new_tokens=100,
        streamer=streamer,
        use_cache=True,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
outputs = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
print(f"final output:\n{outputs}")
# namo_model.generate(pixel_values=None, input_ids=input_ids, max_new_tokens=300)
# Save model, config, tokenizer and image processor into one directory.
model_path = "checkpoints/namo-500m-native"
namo_model.save_pretrained(model_path)
config.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)
image_processor.save_pretrained(model_path)
+12
View File
@@ -0,0 +1,12 @@
python3 setup.py check
# -f: do not complain when build/ or dist/ does not exist yet (fresh clone).
sudo rm -rf build/
sudo rm -rf dist/
# pypi interface are not valid any longer
# python3 setup.py sdist
# python3 setup.py sdist upload -r pypi
# using twine instead; upload only when the sdist build succeeded
python3 setup.py sdist && twine upload dist/*
+15
View File
@@ -0,0 +1,15 @@
import os
import open_webui
# Report where the Open WebUI package is installed.
package_file = open_webui.__file__
print("Open WebUI is installed at:", package_file)

# Candidate location of the database: <package dir>/data/webui.db
package_dir = os.path.dirname(package_file)
db_path = os.path.join(package_dir, "data", "webui.db")
print("Potential path to webui.db:", db_path)

# Tell the user whether a file actually exists at the candidate location.
status = (
    "webui.db found at:" if os.path.exists(db_path) else "webui.db not found at:"
)
print(status, db_path)
+8
View File
@@ -0,0 +1,8 @@
from namo.r1.grpo import main
from trl import TrlParser
if __name__ == "__main__":
    # The three argument classes handed to the parser were never imported,
    # so the script failed with NameError before parsing anything. Import
    # them here, right next to their only use.
    from trl import GRPOConfig, ModelConfig

    # NOTE(review): GRPOScriptArguments is assumed to be defined alongside
    # `main` in namo.r1.grpo — confirm the module path.
    from namo.r1.grpo import GRPOScriptArguments

    # Parse CLI flags (and an optional config file) into the three dataclasses
    # and hand them to the GRPO training entry point.
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
+335
View File
@@ -0,0 +1,335 @@
"""
Code referenced from:
InternVL mDPO
"""
import math
import os
import random
import shutil
import sys
import traceback
from copy import copy, deepcopy
from dataclasses import dataclass, field
from typing import Dict, Literal, Optional
from loguru import logger
import numpy as np
import json
import torch
import torch.distributed as dist
import transformers
from namo.dataargs import DataArguments, ModelArguments
from namo.dataset_dpo import dpo_concat_pad_data_collator, WeightedConcatDataset
from namo.dataset_dpo import build_datasets
from namo.models.configuration_namo import NamoConfig
from namo.models.namo import NamoForCausalLM
from namo.tainer_mdpo import MultimodalDPOTrainer
from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError
from torch.utils.data import Dataset
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
HfArgumentParser,
Trainer,
TrainingArguments,
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils.logging import (
enable_default_handler,
enable_explicit_format,
set_verbosity,
)
from trl import DPOConfig as DPOConfigTRL
from namo.models.symbols import IGNORE_INDEX
from namo.utils.hf_utils import get_latest_checkpoint
from namo.utils.utils import find_all_linear_names, rank0_print
# Size constants used to raise Pillow's PNG text-chunk limit below.
MegaByte = 1 << 20
MaximumDecompressedSize = 1024

# Relax Pillow's defaults for training data: allow truncated files to load,
# remove the pixel-count limit, and allow PNG text chunks of up to
# MaximumDecompressedSize megabytes.
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte

# Explicitly enable parallelism in the tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
class DPOConfig(DPOConfigTRL):
    """TRL ``DPOConfig`` subclass that re-annotates ``loss_type``.

    The annotation enumerates the accepted loss names, including the
    combined ``"sigmoid,bco_pair"`` entry; the class-level default is
    ``"sigmoid"``.
    """

    # NOTE(review): this class carries no @dataclass decorator and is not
    # referenced elsewhere in the visible part of this file — confirm the
    # override takes effect wherever it is meant to be used.
    loss_type: Literal[
        "sigmoid", "hinge", "ipo", "bco_pair", "sppo_hard",
        "nca_pair", "robust", "aot", "aot_pair", "exo_pair",
        "sigmoid,bco_pair",
    ] = "sigmoid"
def main(attn_implementation="flash_attention_2"):
    """Build the Namo policy and reference models and run multimodal DPO training.

    Command-line arguments are parsed into ``ModelArguments``,
    ``DataArguments`` and ``TrainingArguments``. The policy model is loaded
    from ``model_args.pretrain_model_path`` when that is set; otherwise it is
    constructed from the LLM and vision-encoder configs. An independent copy
    of the freshly loaded model serves as the frozen DPO reference. Optional
    gradient checkpointing, LoRA adapters and vision/connector settings are
    applied to the policy before the dataset and ``MultimodalDPOTrainer`` are
    created.

    Args:
        attn_implementation: Attention backend name forwarded to the text
            config and ``NamoConfig`` when the model is built from configs
            (it is not used when loading a pretrained whole model).

    NOTE(review): the parser below is given ``TrainingArguments`` as imported
    from ``transformers``, yet fields such as ``bits``, ``lora_enable`` and
    ``tune_conn_ve_llm`` are read from it — confirm that the intended
    training-arguments class provides them.
    """
    global local_rank

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    local_rank = training_args.local_rank
    compute_dtype = (
        torch.float16
        if training_args.fp16
        else (torch.bfloat16 if training_args.bf16 else torch.float32)
    )

    # Extra config kwargs used only for 4-bit / 8-bit quantized loading.
    bnb_model_from_pretrained_args = {}
    if training_args.bits in [4, 8]:
        from transformers import BitsAndBytesConfig

        bnb_model_from_pretrained_args.update(
            dict(
                device_map={"": training_args.device},
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                quantization_config=BitsAndBytesConfig(
                    load_in_4bit=training_args.bits == 4,
                    load_in_8bit=training_args.bits == 8,
                    llm_int8_skip_modules=["mm_projector"],
                    llm_int8_threshold=6.0,
                    llm_int8_has_fp16_weight=False,
                    bnb_4bit_compute_dtype=compute_dtype,
                    bnb_4bit_use_double_quant=training_args.double_quant,
                    bnb_4bit_quant_type=training_args.quant_type,  # {'fp4', 'nf4'}
                ),
            )
        )

    if model_args.pretrain_model_path is not None:
        pretrain_model_path = get_latest_checkpoint(model_args.pretrain_model_path)
        rank0_print(f"==> finetune from pretrained whole model: {pretrain_model_path}")
        model = NamoForCausalLM.from_pretrained(pretrain_model_path)
        rank0_print("==> pretrained model loaded.")
    else:
        text_config = AutoConfig.from_pretrained(
            model_args.llm_model_path,
            trust_remote_code=True,
            attn_implementation=attn_implementation,
            torch_dtype=compute_dtype,
        )
        vision_config = AutoConfig.from_pretrained(
            model_args.ve_model_path,
            trust_remote_code=True,
            torch_dtype=compute_dtype,
        )
        config = NamoConfig(
            text_config=text_config,
            vision_config=vision_config,
            attn_implementation=attn_implementation,
            torch_dtype=compute_dtype,
            conn_ve_llm_type=model_args.conn_ve_llm_type,
            longest_edge=model_args.max_img_size,
            **bnb_model_from_pretrained_args,
        )
        model = NamoForCausalLM(config=config)

    # The reference model must be an independent snapshot of the initial
    # policy. A shallow copy shares the parameter/submodule containers with
    # `model`, so the "reference" would follow every training update (and
    # receive the LoRA adapters injected below); use a deep copy instead.
    ref_model = deepcopy(model)

    rank0_print(f"==> current model dtype: {model.dtype}, set is: {compute_dtype}")
    tokenizer = model.get_namo().tokenizer
    model.config.use_cache = False

    if model_args.freeze_backbone:
        model.model.requires_grad_(False)

    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            try:
                model.enable_input_require_grads()
            except Exception as e:
                # Best effort: report the failure and continue.
                print(f"enable_input_require_grads: {e}")
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if training_args.lora_enable:
        from peft import LoraConfig, get_peft_model

        lora_config = LoraConfig(
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            target_modules=find_all_linear_names(model),
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            task_type="CAUSAL_LM",
            use_dora=training_args.use_dora,
        )
        if training_args.bits == 16:
            if training_args.bf16:
                model.to(torch.bfloat16)
            if training_args.fp16:
                model.to(torch.float16)
        rank0_print("Adding LoRA adapters...")
        model = get_peft_model(model, lora_config)

    # Make sure the tokenizer has a usable padding token.
    if tokenizer.unk_token is not None:
        tokenizer.pad_token = tokenizer.unk_token
    if tokenizer.pad_token_id is None:
        rank0_print(f"tokenizer.pad_token: {tokenizer.pad_token}")
        # NOTE(review): `tokenizer.encode` is expected to return a list of ids
        # rather than a single id — confirm these assignments are intended.
        if "mistral" in model_args.model_name_or_path.lower():
            # important for mistral models
            tokenizer.pad_token_id = tokenizer.encode("<pad>")
        else:
            tokenizer.pad_token_id = tokenizer.encode(
                tokenizer.pad_token
                if tokenizer.pad_token is not None
                else tokenizer.eos_token
            )
    rank0_print(f"pad_token_id: {tokenizer.pad_token_id}")

    if (
        model_args.ve_model_path is not None
        or model_args.pretrain_model_path is not None
    ):
        logger.info("preparing ve model args...")
        # model.get_model().initialize_vision_modules(
        #     model_args=model_args, fsdp=training_args.fsdp
        # )
        vision_tower = model.get_vision_tower()
        vision_tower.to(
            dtype=torch.bfloat16 if training_args.bf16 else torch.float16,
            device=training_args.device,
        )

        model.config.unfreeze_ve = training_args.unfreeze_ve = model_args.unfreeze_ve
        if training_args.unfreeze_ve:
            for p in model.get_vision_tower().parameters():
                p.requires_grad = True

        # Image-size settings are mirrored onto the config, the data args and
        # the image processor so that all three agree.
        model.config.new_img_size = model_args.new_img_size
        model.config.longest_edge = data_args.longest_edge = model_args.max_img_size
        model.config.dynamic_size = data_args.dynamic_size
        vision_tower.image_processor.size["longest_edge"] = data_args.longest_edge

        data_args.image_processor = vision_tower.image_processor
        data_args.is_multimodal = True

        model.config.image_aspect_ratio = data_args.image_aspect_ratio
        model.config.video_fps = data_args.video_fps
        model.config.video_frames_num = data_args.video_frames_num
        model.config.tokenizer_padding_side = tokenizer.padding_side
        model.config.tokenizer_model_max_length = tokenizer.model_max_length

        model.config.tune_conn_ve_llm = training_args.tune_conn_ve_llm = (
            model_args.tune_conn_ve_llm
        )
        if model_args.tune_conn_ve_llm:
            # Train only the vision-to-LLM connector.
            model.requires_grad_(False)
            for p in model.get_namo().conn_ve_llm.parameters():
                p.requires_grad = True

        model.config.freeze_conn_ve_llm = training_args.freeze_conn_ve_llm
        if training_args.freeze_conn_ve_llm:
            for p in model.get_namo().conn_ve_llm.parameters():
                p.requires_grad = False

        if training_args.bits in [4, 8]:
            model.get_namo().conn_ve_llm.to(
                dtype=compute_dtype, device=training_args.device
            )

        model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = (
            model_args.mm_use_im_start_end
        )
        model.config.conn_ve_llm_lr = training_args.conn_ve_llm_lr
        model.config.s2 = model_args.s2
        model.config.s2_scales = model_args.s2_scales
        model.config.s2_max_split_size = model_args.s2_max_split_size
        training_args.use_im_start_end = model_args.mm_use_im_start_end
        model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
        model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)

    logger.info("Model load finished.")

    train_dataset = build_datasets(
        data_args,
        tokenizer,
        None,
        model,
        group_by_length=training_args.group_by_length,
        dynamic_image_size=data_args.dynamic_image_size,
        use_thumbnail=data_args.use_thumbnail,
        min_dynamic_patch=data_args.min_dynamic_patch,
        max_dynamic_patch=data_args.max_dynamic_patch,
        normalize_type=data_args.normalize_type,
        min_num_frame=data_args.min_num_frame,
        max_num_frame=data_args.max_num_frame,
    )

    def _freeze_params(module):
        for param in module.parameters():
            param.requires_grad = False

    ref_model.eval()
    # Safe now that the reference is a deep copy: freezing it no longer
    # touches the policy's parameters.
    _freeze_params(ref_model)

    # set seed for torch dataloaders
    set_seed(training_args.seed)

    trainer = MultimodalDPOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=None,
        tokenizer=tokenizer,
        data_collator=dpo_concat_pad_data_collator,
    )

    # Training
    if training_args.do_train:
        print(
            f"[Memory Usage before training] {torch.cuda.memory_allocated()/1024/1024/1024:.2f}GB"
        )
        output_dir = training_args.output_dir

        # Resume from the newest checkpoint in output_dir when one exists and
        # start from scratch otherwise. Passing `True` unconditionally makes
        # the very first run fail because there is nothing to resume from.
        last_checkpoint = None
        if output_dir and os.path.isdir(output_dir):
            last_checkpoint = get_last_checkpoint(output_dir)
        train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics
        try:
            metrics["train_samples"] = len(train_dataset)
        except TypeError:
            # Dataset has no defined length (e.g. an iterable dataset).
            metrics["train_samples"] = -1

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

        # Copy auxiliary modeling/config sources next to the saved weights
        # when the source model directory ships them. Skipped when no source
        # directory is configured so a finished run does not crash here.
        model_dir = getattr(model_args, "model_name_or_path", None)
        if model_dir:
            for filename in [
                "conversation.py",
                "modeling_internvl_chat.py",
                "modeling_intern_vit.py",
                "modeling_internlm2.py",
                "configuration_internvl_chat.py",
                "configuration_intern_vit.py",
                "configuration_internlm2.py",
            ]:
                source_file = os.path.join(model_dir, filename)
                if os.path.exists(source_file):
                    shutil.copy(source_file, output_dir)
# Script entry point: run DPO training with main()'s default arguments.
if __name__ == "__main__":
    main()