mirror of
https://github.com/lucasjinreal/Namo-R1.git
synced 2026-03-23 17:04:41 +08:00
57 lines
1.6 KiB
Python
57 lines
1.6 KiB
Python
import requests
|
|
from PIL import Image
|
|
import base64
|
|
from threading import Thread
|
|
import io
|
|
import os
|
|
from transformers import TextStreamer
|
|
|
|
try:
|
|
from datauri import DataURI
|
|
except ImportError:
|
|
pass
|
|
|
|
|
|
def load_image(image_file, timeout=30):
    """Load an image from a PIL object, an HTTP(S) URL, or a local source.

    Args:
        image_file: One of
            * a ``PIL.Image.Image`` -- returned unchanged (no RGB conversion),
            * a string starting with ``http://`` or ``https://`` -- downloaded
              with ``requests``,
            * anything else -- handed directly to ``Image.open`` (for example
              a path string, a path-like object or a binary file object).
        timeout: Seconds to wait for the HTTP server before giving up. Only
            used for URL inputs.

    Returns:
        The PIL image. Downloaded and opened images are converted to RGB.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
    """
    if isinstance(image_file, Image.Image):
        return image_file

    # Only strings can be URLs. Requiring the full scheme separator keeps
    # local files whose names merely begin with "http" on the file branch,
    # and lets non-string inputs fall through to Image.open.
    if isinstance(image_file, str) and image_file.startswith(
        ("http://", "https://")
    ):
        response = requests.get(image_file, timeout=timeout)
        # Fail on 4xx/5xx here instead of feeding an error page to PIL.
        response.raise_for_status()
        source = io.BytesIO(response.content)
    else:
        source = image_file

    return Image.open(source).convert("RGB")
|
|
|
|
|
|
def load_multi_images_maybe(image_files, splitter=" "):
    """Load one or several images and always return them as a list.

    Args:
        image_files: One of
            * a single ``PIL.Image.Image`` -- wrapped in a list as-is,
            * a string naming an existing path -- loaded as one image,
            * any other string -- split on ``splitter`` and each non-empty
              piece is loaded with ``load_image``,
            * an iterable of items accepted by ``load_image``.
        splitter: Separator used to split a string holding several
            image locations.

    Returns:
        A list of PIL images, one per input item.
    """
    if isinstance(image_files, Image.Image):
        return [image_files]

    if isinstance(image_files, str):
        if os.path.exists(image_files):
            # A single existing path must be wrapped in a list; iterating the
            # bare string would feed load_image one character at a time.
            images = [image_files]
        else:
            # Skip empty fragments produced by repeated or trailing splitters.
            images = [part for part in image_files.split(splitter) if part]
    else:
        images = image_files

    return [load_image(item) for item in images]
|
|
|
|
|
|
def url_to_image(img_url: str, timeout: float = 30) -> Image.Image:
    """Decode an image given as an HTTP(S) URL, a data URI, or raw base64.

    Args:
        img_url: One of
            * a string starting with ``http://`` or ``https://`` -- downloaded
              with ``requests``,
            * a ``data:`` URI -- decoded with ``DataURI`` when that optional
              import succeeded, otherwise with a standard-library fallback,
            * any other string -- treated as base64-encoded image bytes.
        timeout: Seconds to wait for the HTTP server before giving up. Only
            used for URL inputs.

    Returns:
        The decoded image converted to RGB.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
        ValueError: If a ``data:`` URI has no ``,`` separating header and
            payload (fallback decoder only).
    """
    # ":" is not part of the base64 alphabet, so matching the full scheme
    # separator cannot mistake a base64 payload for a URL.
    if img_url.startswith(("http://", "https://")):
        response = requests.get(img_url, timeout=timeout)
        # Fail on 4xx/5xx here instead of feeding an error page to PIL.
        response.raise_for_status()
        img_data = response.content
    elif img_url.startswith("data:"):
        # The module-level import of DataURI is optional; look it up lazily
        # so a missing package does not surface as a NameError here.
        data_uri_cls = globals().get("DataURI")
        if data_uri_cls is not None:
            img_data = data_uri_cls(img_url).data
        else:
            from urllib.parse import unquote_to_bytes

            header, sep, payload = img_url.partition(",")
            if not sep:
                raise ValueError("malformed data URI: missing ',' separator")
            if header.lower().endswith(";base64"):
                img_data = base64.b64decode(payload)
            else:
                # Non-base64 data URIs carry percent-encoded bytes.
                img_data = unquote_to_bytes(payload)
    else:
        img_data = base64.b64decode(img_url)

    return Image.open(io.BytesIO(img_data)).convert("RGB")
|
|
|
|
|
|
class CallbackStreamer(TextStreamer):
    """A ``TextStreamer`` that also passes each finalized text chunk to a callback."""

    def __init__(self, tokenizer, callback=None, **kwargs):
        """Set up the base streamer and remember the optional callback.

        Args:
            tokenizer: Forwarded to ``TextStreamer.__init__``.
            callback: Optional callable invoked with each finalized text
                chunk; ``None`` disables the notification.
            **kwargs: Forwarded unchanged to ``TextStreamer.__init__``.
        """
        super().__init__(tokenizer, **kwargs)
        self.callback = callback

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Notify the callback (if any), then defer to the parent handler.

        Args:
            text: The finalized text chunk.
            stream_end: Passed through to the parent implementation; it is
                not given to the callback.
        """
        notify = self.callback
        if notify is not None:
            notify(text)
        super().on_finalized_text(text, stream_end)
|