fix custom diffusion training with concept list (#6710)

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Andrew Ishutin
2024-01-26 06:17:51 +03:00
committed by GitHub
parent 7c1c705f60
commit 5b93338235
@@ -753,7 +753,7 @@ def main(args):
num_new_images = args.num_class_images - cur_class_images num_new_images = args.num_class_images - cur_class_images
logger.info(f"Number of class images to sample: {num_new_images}.") logger.info(f"Number of class images to sample: {num_new_images}.")
sample_dataset = PromptDataset(args.class_prompt, num_new_images) sample_dataset = PromptDataset(concept["class_prompt"], num_new_images)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
sample_dataloader = accelerator.prepare(sample_dataloader) sample_dataloader = accelerator.prepare(sample_dataloader)