parent
9c3820d05a
commit
836f3f35c2
@ -3,7 +3,7 @@ import argparse
|
||||
import OmegaConf
|
||||
import torch
|
||||
|
||||
from diffusers import UNetLDMModel, VQModel, LatentDiffusionUncondPipeline, DDIMScheduler
|
||||
from diffusers import UNetLDMModel, VQModel, LDMPipeline, DDIMScheduler
|
||||
|
||||
def convert_ldm_original(checkpoint_path, config_path, output_path):
|
||||
config = OmegaConf.load(config_path)
|
||||
@ -41,7 +41,7 @@ def convert_ldm_original(checkpoint_path, config_path, output_path):
|
||||
clip_sample=False,
|
||||
)
|
||||
|
||||
pipeline = LatentDiffusionUncondPipeline(vqvae, unet, noise_scheduler)
|
||||
pipeline = LDMPipeline(vqvae, unet, noise_scheduler)
|
||||
pipeline.save_pretrained(output_path)
|
||||
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
import argparse
|
||||
import json
|
||||
import torch
|
||||
from diffusers import VQModel, DDPMScheduler, UNet2DModel, LatentDiffusionUncondPipeline
|
||||
from diffusers import VQModel, DDPMScheduler, UNet2DModel, LDMPipeline
|
||||
|
||||
|
||||
def shave_segments(path, n_shave_prefix_segments=1):
|
||||
@ -326,7 +326,7 @@ if __name__ == "__main__":
|
||||
scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
|
||||
vqvae = VQModel.from_pretrained("/".join(args.checkpoint_path.split("/")[:-1]))
|
||||
|
||||
pipe = LatentDiffusionUncondPipeline(unet=model, scheduler=scheduler, vae=vqvae)
|
||||
pipe = LDMPipeline(unet=model, scheduler=scheduler, vae=vqvae)
|
||||
pipe.save_pretrained(args.dump_path)
|
||||
except:
|
||||
model.save_pretrained(args.dump_path)
|
||||
|
||||
@ -9,11 +9,11 @@ __version__ = "0.0.4"
|
||||
from .modeling_utils import ModelMixin
|
||||
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
|
||||
from .pipeline_utils import DiffusionPipeline
|
||||
from .pipelines import DDIMPipeline, DDPMPipeline, LatentDiffusionUncondPipeline, PNDMPipeline, ScoreSdeVePipeline
|
||||
from .pipelines import DDIMPipeline, DDPMPipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
|
||||
from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from .pipelines import LatentDiffusionPipeline
|
||||
from .pipelines import LDMTextToImagePipeline
|
||||
else:
|
||||
from .utils.dummy_transformers_objects import *
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
from ..utils import is_inflect_available, is_transformers_available, is_unidecode_available
|
||||
from .ddim import DDIMPipeline
|
||||
from .ddpm import DDPMPipeline
|
||||
from .latent_diffusion_uncond import LatentDiffusionUncondPipeline
|
||||
from .latent_diffusion_uncond import LDMPipeline
|
||||
from .pndm import PNDMPipeline
|
||||
from .score_sde_ve import ScoreSdeVePipeline
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from .latent_diffusion import LatentDiffusionPipeline
|
||||
from .latent_diffusion import LDMTextToImagePipeline
|
||||
|
||||
@ -2,4 +2,4 @@ from ...utils import is_transformers_available
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from .pipeline_latent_diffusion import LatentDiffusionPipeline, LDMBertModel
|
||||
from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline
|
||||
|
||||
@ -14,7 +14,7 @@ from transformers.utils import logging
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
class LatentDiffusionPipeline(DiffusionPipeline):
|
||||
class LDMTextToImagePipeline(DiffusionPipeline):
|
||||
def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
|
||||
super().__init__()
|
||||
scheduler = scheduler.set_format("pt")
|
||||
|
||||
@ -1 +1 @@
|
||||
from .pipeline_latent_diffusion_uncond import LatentDiffusionUncondPipeline
|
||||
from .pipeline_latent_diffusion_uncond import LDMPipeline
|
||||
|
||||
@ -5,7 +5,7 @@ from tqdm.auto import tqdm
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
class LatentDiffusionUncondPipeline(DiffusionPipeline):
|
||||
class LDMPipeline(DiffusionPipeline):
|
||||
def __init__(self, vqvae, unet, scheduler):
|
||||
super().__init__()
|
||||
scheduler = scheduler.set_format("pt")
|
||||
|
||||
@ -3,42 +3,7 @@
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class GlideSuperResUNetModel(metaclass=DummyObject):
|
||||
_backends = ["transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["transformers"])
|
||||
|
||||
|
||||
class GlideTextToImageUNetModel(metaclass=DummyObject):
|
||||
_backends = ["transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["transformers"])
|
||||
|
||||
|
||||
class GlideUNetModel(metaclass=DummyObject):
|
||||
_backends = ["transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["transformers"])
|
||||
|
||||
|
||||
class UNetGradTTSModel(metaclass=DummyObject):
|
||||
_backends = ["transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["transformers"])
|
||||
|
||||
|
||||
class GlidePipeline(metaclass=DummyObject):
|
||||
_backends = ["transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["transformers"])
|
||||
|
||||
|
||||
class LatentDiffusionPipeline(metaclass=DummyObject):
|
||||
class LDMTextToImagePipeline(metaclass=DummyObject):
|
||||
_backends = ["transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
||||
@ -29,8 +29,8 @@ from diffusers import (
|
||||
DDIMScheduler,
|
||||
DDPMPipeline,
|
||||
DDPMScheduler,
|
||||
LatentDiffusionPipeline,
|
||||
LatentDiffusionUncondPipeline,
|
||||
LDMPipeline,
|
||||
LDMTextToImagePipeline,
|
||||
PNDMPipeline,
|
||||
PNDMScheduler,
|
||||
ScoreSdeVePipeline,
|
||||
@ -826,7 +826,7 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_ldm_text2img(self):
|
||||
ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")
|
||||
ldm = LDMTextToImagePipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.manual_seed(0)
|
||||
@ -842,7 +842,7 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_ldm_text2img_fast(self):
|
||||
ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")
|
||||
ldm = LDMTextToImagePipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.manual_seed(0)
|
||||
@ -877,7 +877,7 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_ldm_uncond(self):
|
||||
ldm = LatentDiffusionUncondPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-celebahq-256")
|
||||
ldm = LDMPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-celebahq-256")
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user