fix train_dreambooth_lora_sd3.py loading hook (#9107)
This commit is contained in:
@@ -1271,7 +1271,7 @@ def main(args):
|
|||||||
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
|
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
|
||||||
|
|
||||||
transformer_state_dict = {
|
transformer_state_dict = {
|
||||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
|
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||||
}
|
}
|
||||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||||
|
|||||||
Reference in New Issue
Block a user