Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 50a96ca719 | |||
| 5ca062e011 | |||
| 619e3ab6f6 | |||
| 9e2804f720 |
@@ -59,7 +59,7 @@ jobs:
|
||||
|
||||
- name: Run fast PyTorch LoRA CPU tests with PEFT backend
|
||||
run: |
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/lora/test_lora_layers_peft.py
|
||||
|
||||
@@ -33,6 +33,9 @@ model = AutoencoderKL.from_single_file(url)
|
||||
## AutoencoderKL
|
||||
|
||||
[[autodoc]] AutoencoderKL
|
||||
- decode
|
||||
- encode
|
||||
- all
|
||||
|
||||
## AutoencoderKLOutput
|
||||
|
||||
|
||||
@@ -1279,7 +1279,7 @@ def main(args):
|
||||
for name, param in text_encoder_one.named_parameters():
|
||||
if "token_embedding" in name:
|
||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
||||
param = param.to(dtype=torch.float32)
|
||||
param.data = param.to(dtype=torch.float32)
|
||||
param.requires_grad = True
|
||||
text_lora_parameters_one.append(param)
|
||||
else:
|
||||
@@ -1288,7 +1288,7 @@ def main(args):
|
||||
for name, param in text_encoder_two.named_parameters():
|
||||
if "token_embedding" in name:
|
||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
||||
param = param.to(dtype=torch.float32)
|
||||
param.data = param.to(dtype=torch.float32)
|
||||
param.requires_grad = True
|
||||
text_lora_parameters_two.append(param)
|
||||
else:
|
||||
@@ -1725,19 +1725,19 @@ def main(args):
|
||||
num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
|
||||
elif args.train_text_encoder_ti: # args.train_text_encoder_ti
|
||||
num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs)
|
||||
|
||||
# flag used for textual inversion
|
||||
pivoted = False
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
# if performing any kind of optimization of text_encoder params
|
||||
if args.train_text_encoder or args.train_text_encoder_ti:
|
||||
if epoch == num_train_epochs_text_encoder:
|
||||
print("PIVOT HALFWAY", epoch)
|
||||
# stopping optimization of text_encoder params
|
||||
# re setting the optimizer to optimize only on unet params
|
||||
optimizer.param_groups[1]["lr"] = 0.0
|
||||
optimizer.param_groups[2]["lr"] = 0.0
|
||||
# this flag is used to reset the optimizer to optimize only on unet params
|
||||
pivoted = True
|
||||
|
||||
else:
|
||||
# still optimizng the text encoder
|
||||
# still optimizing the text encoder
|
||||
text_encoder_one.train()
|
||||
text_encoder_two.train()
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
@@ -1747,6 +1747,12 @@ def main(args):
|
||||
|
||||
unet.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
if pivoted:
|
||||
# stopping optimization of text_encoder params
|
||||
# re setting the optimizer to optimize only on unet params
|
||||
optimizer.param_groups[1]["lr"] = 0.0
|
||||
optimizer.param_groups[2]["lr"] = 0.0
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
prompts = batch["prompts"]
|
||||
# encode batch prompts when custom prompts are provided for each image -
|
||||
@@ -1885,8 +1891,7 @@ def main(args):
|
||||
|
||||
# every step, we reset the embeddings to the original embeddings.
|
||||
if args.train_text_encoder_ti:
|
||||
for idx, text_encoder in enumerate(text_encoders):
|
||||
embedding_handler.retract_embeddings()
|
||||
embedding_handler.retract_embeddings()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
|
||||
@@ -519,10 +519,10 @@ def export_to_video(video_frames: List[np.ndarray], output_video_path: str = Non
|
||||
|
||||
|
||||
def load_hf_numpy(path) -> np.ndarray:
|
||||
if not path.startswith("http://") or path.startswith("https://"):
|
||||
path = Path(
|
||||
"https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main", urllib.parse.quote(path)
|
||||
).as_posix()
|
||||
base_url = "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main"
|
||||
|
||||
if not path.startswith("http://") and not path.startswith("https://"):
|
||||
path = os.path.join(base_url, urllib.parse.quote(path))
|
||||
|
||||
return load_numpy(path)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user