[Onnx] support half-precision and fix bugs for onnx pipelines (#932)
* [Onnx] support half-precision and fix bugs for onnx pipelines * Update convert_stable_diffusion_checkpoint_to_onnx.py * style * fix has_nsfw_concept * Update convert_stable_diffusion_checkpoint_to_onnx.py * fix style
This commit is contained in:
parent
3d02c92187
commit
0b42b074b4
@ -69,8 +69,15 @@ def onnx_export(
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_models(model_path: str, output_path: str, opset: int):
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(model_path)
|
||||
def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
|
||||
dtype = torch.float16 if fp16 else torch.float32
|
||||
if fp16 and torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
elif fp16 and not torch.cuda.is_available():
|
||||
raise ValueError("`float16` model export is only supported on GPUs with CUDA")
|
||||
else:
|
||||
device = "cpu"
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
|
||||
output_path = Path(output_path)
|
||||
|
||||
# TEXT ENCODER
|
||||
@ -84,7 +91,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
|
||||
onnx_export(
|
||||
pipeline.text_encoder,
|
||||
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
||||
model_args=(text_input.input_ids.to(torch.int32)),
|
||||
model_args=(text_input.input_ids.to(device=device, dtype=torch.int32)),
|
||||
output_path=output_path / "text_encoder" / "model.onnx",
|
||||
ordered_input_names=["input_ids"],
|
||||
output_names=["last_hidden_state", "pooler_output"],
|
||||
@ -100,9 +107,9 @@ def convert_models(model_path: str, output_path: str, opset: int):
|
||||
onnx_export(
|
||||
pipeline.unet,
|
||||
model_args=(
|
||||
torch.randn(2, pipeline.unet.in_channels, 64, 64),
|
||||
torch.LongTensor([0, 1]),
|
||||
torch.randn(2, 77, 768),
|
||||
torch.randn(2, pipeline.unet.in_channels, 64, 64).to(device=device, dtype=dtype),
|
||||
torch.LongTensor([0, 1]).to(device=device),
|
||||
torch.randn(2, 77, 768).to(device=device, dtype=dtype),
|
||||
False,
|
||||
),
|
||||
output_path=unet_path,
|
||||
@ -139,7 +146,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
|
||||
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample()
|
||||
onnx_export(
|
||||
vae_encoder,
|
||||
model_args=(torch.randn(1, 3, 512, 512), False),
|
||||
model_args=(torch.randn(1, 3, 512, 512).to(device=device, dtype=dtype), False),
|
||||
output_path=output_path / "vae_encoder" / "model.onnx",
|
||||
ordered_input_names=["sample", "return_dict"],
|
||||
output_names=["latent_sample"],
|
||||
@ -155,7 +162,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
|
||||
vae_decoder.forward = vae_encoder.decode
|
||||
onnx_export(
|
||||
vae_decoder,
|
||||
model_args=(torch.randn(1, 4, 64, 64), False),
|
||||
model_args=(torch.randn(1, 4, 64, 64).to(device=device, dtype=dtype), False),
|
||||
output_path=output_path / "vae_decoder" / "model.onnx",
|
||||
ordered_input_names=["latent_sample", "return_dict"],
|
||||
output_names=["sample"],
|
||||
@ -171,13 +178,16 @@ def convert_models(model_path: str, output_path: str, opset: int):
|
||||
safety_checker.forward = safety_checker.forward_onnx
|
||||
onnx_export(
|
||||
pipeline.safety_checker,
|
||||
model_args=(torch.randn(1, 3, 224, 224), torch.randn(1, 512, 512, 3)),
|
||||
model_args=(
|
||||
torch.randn(1, 3, 224, 224).to(device=device, dtype=dtype),
|
||||
torch.randn(1, 512, 512, 3).to(device=device, dtype=dtype),
|
||||
),
|
||||
output_path=output_path / "safety_checker" / "model.onnx",
|
||||
ordered_input_names=["clip_input", "images"],
|
||||
output_names=["out_images", "has_nsfw_concepts"],
|
||||
dynamic_axes={
|
||||
"clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||
"images": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||
"images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
|
||||
},
|
||||
opset=opset,
|
||||
)
|
||||
@ -221,7 +231,8 @@ if __name__ == "__main__":
|
||||
type=int,
|
||||
help="The version of the ONNX operator set to use.",
|
||||
)
|
||||
parser.add_argument("--fp16", action="store_true", default=False, help="Export the models in `float16` mode")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_models(args.model_path, args.output_path, args.opset)
|
||||
convert_models(args.model_path, args.output_path, args.opset, args.fp16)
|
||||
|
||||
@ -55,7 +55,9 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
guidance_scale: Optional[float] = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[np.random.RandomState] = None,
|
||||
latents: Optional[np.ndarray] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@ -81,6 +83,9 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if generator is None:
|
||||
generator = np.random
|
||||
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
@ -98,6 +103,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
||||
)
|
||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
|
||||
text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
@ -133,6 +139,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
||||
return_tensors="np",
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
|
||||
uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
@ -140,9 +147,10 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
||||
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
|
||||
|
||||
# get the initial random noise unless the user supplied it
|
||||
latents_shape = (batch_size, 4, height // 8, width // 8)
|
||||
latents_dtype = text_embeddings.dtype
|
||||
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
|
||||
if latents is None:
|
||||
latents = np.random.randn(*latents_shape).astype(np.float32)
|
||||
latents = generator.randn(*latents_shape).astype(latents_dtype)
|
||||
elif latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
|
||||
@ -185,13 +193,30 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
||||
callback(i, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae_decoder(latent_sample=latents)[0]
|
||||
# image = self.vae_decoder(latent_sample=latents)[0]
|
||||
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
||||
image = np.concatenate(
|
||||
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
|
||||
)
|
||||
|
||||
image = np.clip(image / 2 + 0.5, 0, 1)
|
||||
image = image.transpose((0, 2, 3, 1))
|
||||
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
|
||||
image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(
|
||||
self.numpy_to_pil(image), return_tensors="np"
|
||||
).pixel_values.astype(image.dtype)
|
||||
# There will throw an error if use safety_checker batchsize>1
|
||||
images, has_nsfw_concept = [], []
|
||||
for i in range(image.shape[0]):
|
||||
image_i, has_nsfw_concept_i = self.safety_checker(
|
||||
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
|
||||
)
|
||||
images.append(image_i)
|
||||
has_nsfw_concept.append(has_nsfw_concept_i[0])
|
||||
image = np.concatenate(images)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
@ -121,6 +121,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[np.random.RandomState] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||
@ -159,6 +160,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`np.random.RandomState`, *optional*):
|
||||
A np.random.RandomState to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@ -197,6 +200,9 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if generator is None:
|
||||
generator = np.random
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
@ -239,7 +245,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
uncond_tokens = [negative_prompt] * batch_size
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
|
||||
else:
|
||||
@ -257,13 +263,15 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
uncond_embeddings = self.text_encoder(input_ids=uncond_input_ids.astype(np.int32))[0]
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt
|
||||
uncond_embeddings = np.repeat(uncond_embeddings, batch_size * num_images_per_prompt, axis=0)
|
||||
uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
|
||||
|
||||
latents_dtype = text_embeddings.dtype
|
||||
init_image = init_image.astype(latents_dtype)
|
||||
# encode the init image into latents and scale the latents
|
||||
init_latents = self.vae_encoder(sample=init_image)[0]
|
||||
init_latents = 0.18215 * init_latents
|
||||
@ -297,7 +305,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
timesteps = np.array([timesteps] * batch_size * num_images_per_prompt)
|
||||
|
||||
# add noise to latents using the timesteps
|
||||
noise = np.random.randn(*init_latents.shape).astype(np.float32)
|
||||
noise = generator.randn(*init_latents.shape).astype(latents_dtype)
|
||||
init_latents = self.scheduler.add_noise(
|
||||
torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps)
|
||||
)
|
||||
@ -341,14 +349,28 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
callback(i, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae_decoder(latent_sample=latents)[0]
|
||||
# image = self.vae_decoder(latent_sample=latents)[0]
|
||||
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
||||
image = np.concatenate(
|
||||
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
|
||||
)
|
||||
|
||||
image = np.clip(image / 2 + 0.5, 0, 1)
|
||||
image = image.transpose((0, 2, 3, 1))
|
||||
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
|
||||
image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
|
||||
safety_checker_input = self.feature_extractor(
|
||||
self.numpy_to_pil(image), return_tensors="np"
|
||||
).pixel_values.astype(image.dtype)
|
||||
# There will throw an error if use safety_checker batchsize>1
|
||||
images, has_nsfw_concept = [], []
|
||||
for i in range(image.shape[0]):
|
||||
image_i, has_nsfw_concept_i = self.safety_checker(
|
||||
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
|
||||
)
|
||||
images.append(image_i)
|
||||
has_nsfw_concept.append(has_nsfw_concept_i[0])
|
||||
image = np.concatenate(images)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
|
||||
|
||||
@ -23,11 +23,11 @@ NUM_LATENT_CHANNELS = 4
|
||||
|
||||
|
||||
def prepare_mask_and_masked_image(image, mask, latents_shape):
|
||||
image = np.array(image.convert("RGB"))
|
||||
image = np.array(image.convert("RGB").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = image.astype(np.float32) / 127.5 - 1.0
|
||||
|
||||
image_mask = np.array(mask.convert("L"))
|
||||
image_mask = np.array(mask.convert("L").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
|
||||
masked_image = image * (image_mask < 127.5)
|
||||
|
||||
mask = mask.resize((latents_shape[1], latents_shape[0]), PIL.Image.NEAREST)
|
||||
@ -138,6 +138,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[np.random.RandomState] = None,
|
||||
latents: Optional[np.ndarray] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@ -180,6 +181,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`np.random.RandomState`, *optional*):
|
||||
A np.random.RandomState to make generation deterministic.
|
||||
latents (`np.ndarray`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
@ -222,6 +225,9 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if generator is None:
|
||||
generator = np.random
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
@ -261,7 +267,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
uncond_tokens = [negative_prompt] * batch_size
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
@ -283,7 +289,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
uncond_embeddings = self.text_encoder(input_ids=uncond_input_ids.astype(np.int32))[0]
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt
|
||||
uncond_embeddings = np.repeat(uncond_embeddings, batch_size * num_images_per_prompt, axis=0)
|
||||
uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
@ -294,7 +300,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
|
||||
latents_dtype = text_embeddings.dtype
|
||||
if latents is None:
|
||||
latents = np.random.randn(*latents_shape).astype(latents_dtype)
|
||||
latents = generator.randn(*latents_shape).astype(latents_dtype)
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
@ -307,6 +313,10 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
masked_image_latents = self.vae_encoder(sample=masked_image)[0]
|
||||
masked_image_latents = 0.18215 * masked_image_latents
|
||||
|
||||
# duplicate mask and masked_image_latents for each generation per prompt
|
||||
mask = mask.repeat(batch_size * num_images_per_prompt, 0)
|
||||
masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 0)
|
||||
|
||||
mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask
|
||||
masked_image_latents = (
|
||||
np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
|
||||
@ -367,14 +377,28 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
callback(i, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae_decoder(latent_sample=latents)[0]
|
||||
# image = self.vae_decoder(latent_sample=latents)[0]
|
||||
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
||||
image = np.concatenate(
|
||||
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
|
||||
)
|
||||
|
||||
image = np.clip(image / 2 + 0.5, 0, 1)
|
||||
image = image.transpose((0, 2, 3, 1))
|
||||
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
|
||||
image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
|
||||
safety_checker_input = self.feature_extractor(
|
||||
self.numpy_to_pil(image), return_tensors="np"
|
||||
).pixel_values.astype(image.dtype)
|
||||
# There will throw an error if use safety_checker batchsize>1
|
||||
images, has_nsfw_concept = [], []
|
||||
for i in range(image.shape[0]):
|
||||
image_i, has_nsfw_concept_i = self.safety_checker(
|
||||
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
|
||||
)
|
||||
images.append(image_i)
|
||||
has_nsfw_concept.append(has_nsfw_concept_i[0])
|
||||
image = np.concatenate(images)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user