[Dreambooth] Editable number of class images (#2251)
* [Dreambooth] Editable number of class images * 'class_num=None' bug fix --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -454,6 +454,7 @@ class DreamBoothDataset(Dataset):
|
|||||||
tokenizer,
|
tokenizer,
|
||||||
class_data_root=None,
|
class_data_root=None,
|
||||||
class_prompt=None,
|
class_prompt=None,
|
||||||
|
class_num=None,
|
||||||
size=512,
|
size=512,
|
||||||
center_crop=False,
|
center_crop=False,
|
||||||
):
|
):
|
||||||
@@ -474,6 +475,9 @@ class DreamBoothDataset(Dataset):
|
|||||||
self.class_data_root = Path(class_data_root)
|
self.class_data_root = Path(class_data_root)
|
||||||
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
||||||
self.class_images_path = list(self.class_data_root.iterdir())
|
self.class_images_path = list(self.class_data_root.iterdir())
|
||||||
|
if class_num is not None:
|
||||||
|
self.num_class_images = min(len(self.class_images_path), class_num)
|
||||||
|
else:
|
||||||
self.num_class_images = len(self.class_images_path)
|
self.num_class_images = len(self.class_images_path)
|
||||||
self._length = max(self.num_class_images, self.num_instance_images)
|
self._length = max(self.num_class_images, self.num_instance_images)
|
||||||
self.class_prompt = class_prompt
|
self.class_prompt = class_prompt
|
||||||
@@ -814,6 +818,7 @@ def main(args):
|
|||||||
instance_prompt=args.instance_prompt,
|
instance_prompt=args.instance_prompt,
|
||||||
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
||||||
class_prompt=args.class_prompt,
|
class_prompt=args.class_prompt,
|
||||||
|
class_num=args.num_class_images,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
size=args.resolution,
|
size=args.resolution,
|
||||||
center_crop=args.center_crop,
|
center_crop=args.center_crop,
|
||||||
|
|||||||
@@ -231,6 +231,7 @@ class DreamBoothDataset(Dataset):
|
|||||||
tokenizer,
|
tokenizer,
|
||||||
class_data_root=None,
|
class_data_root=None,
|
||||||
class_prompt=None,
|
class_prompt=None,
|
||||||
|
class_num=None,
|
||||||
size=512,
|
size=512,
|
||||||
center_crop=False,
|
center_crop=False,
|
||||||
):
|
):
|
||||||
@@ -251,6 +252,9 @@ class DreamBoothDataset(Dataset):
|
|||||||
self.class_data_root = Path(class_data_root)
|
self.class_data_root = Path(class_data_root)
|
||||||
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
||||||
self.class_images_path = list(self.class_data_root.iterdir())
|
self.class_images_path = list(self.class_data_root.iterdir())
|
||||||
|
if class_num is not None:
|
||||||
|
self.num_class_images = min(len(self.class_images_path), class_num)
|
||||||
|
else:
|
||||||
self.num_class_images = len(self.class_images_path)
|
self.num_class_images = len(self.class_images_path)
|
||||||
self._length = max(self.num_class_images, self.num_instance_images)
|
self._length = max(self.num_class_images, self.num_instance_images)
|
||||||
self.class_prompt = class_prompt
|
self.class_prompt = class_prompt
|
||||||
@@ -419,6 +423,7 @@ def main():
|
|||||||
instance_prompt=args.instance_prompt,
|
instance_prompt=args.instance_prompt,
|
||||||
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
||||||
class_prompt=args.class_prompt,
|
class_prompt=args.class_prompt,
|
||||||
|
class_num=args.num_class_images,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
size=args.resolution,
|
size=args.resolution,
|
||||||
center_crop=args.center_crop,
|
center_crop=args.center_crop,
|
||||||
|
|||||||
@@ -417,6 +417,7 @@ class DreamBoothDataset(Dataset):
|
|||||||
tokenizer,
|
tokenizer,
|
||||||
class_data_root=None,
|
class_data_root=None,
|
||||||
class_prompt=None,
|
class_prompt=None,
|
||||||
|
class_num=None,
|
||||||
size=512,
|
size=512,
|
||||||
center_crop=False,
|
center_crop=False,
|
||||||
):
|
):
|
||||||
@@ -437,6 +438,9 @@ class DreamBoothDataset(Dataset):
|
|||||||
self.class_data_root = Path(class_data_root)
|
self.class_data_root = Path(class_data_root)
|
||||||
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
||||||
self.class_images_path = list(self.class_data_root.iterdir())
|
self.class_images_path = list(self.class_data_root.iterdir())
|
||||||
|
if class_num is not None:
|
||||||
|
self.num_class_images = min(len(self.class_images_path), class_num)
|
||||||
|
else:
|
||||||
self.num_class_images = len(self.class_images_path)
|
self.num_class_images = len(self.class_images_path)
|
||||||
self._length = max(self.num_class_images, self.num_instance_images)
|
self._length = max(self.num_class_images, self.num_instance_images)
|
||||||
self.class_prompt = class_prompt
|
self.class_prompt = class_prompt
|
||||||
@@ -771,6 +775,7 @@ def main(args):
|
|||||||
instance_prompt=args.instance_prompt,
|
instance_prompt=args.instance_prompt,
|
||||||
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
||||||
class_prompt=args.class_prompt,
|
class_prompt=args.class_prompt,
|
||||||
|
class_num=args.num_class_images,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
size=args.resolution,
|
size=args.resolution,
|
||||||
center_crop=args.center_crop,
|
center_crop=args.center_crop,
|
||||||
|
|||||||
Reference in New Issue
Block a user