add 4K support for Sana
This commit is contained in:
Junsong Chen
2025-01-09 05:58:11 +08:00
committed by GitHub
parent b13cdbb294
commit c0964571fc
2 changed files with 54 additions and 5 deletions
+8 -4
View File
@@ -25,6 +25,7 @@ from diffusers.utils.import_utils import is_accelerate_available
# Context manager factory used when instantiating the model: `init_empty_weights`
# when accelerate is available, otherwise a no-op `nullcontext`.
# `is_accelerate_available` is a function and must be called; the bare function
# object is always truthy, which made the `nullcontext` fallback unreachable.
CTX = init_empty_weights if is_accelerate_available() else nullcontext
ckpt_ids = [
"Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
"Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth",
@@ -89,7 +90,10 @@ def main(args):
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
# scheduler
flow_shift = 3.0
if args.image_size == 4096:
flow_shift = 6.0
else:
flow_shift = 3.0
# model config
if args.model_type == "SanaMS_1600M_P1_D20":
@@ -99,7 +103,7 @@ def main(args):
else:
raise ValueError(f"{args.model_type} is not supported.")
# Positional embedding interpolation scale.
interpolation_scale = {512: None, 1024: None, 2048: 1.0}
interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
for depth in range(layer_num):
# Transformer blocks.
@@ -272,9 +276,9 @@ if __name__ == "__main__":
"--image_size",
default=1024,
type=int,
choices=[512, 1024, 2048],
choices=[512, 1024, 2048, 4096],
required=False,
help="Image size of pretrained model, 512, 1024 or 2048.",
help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
)
parser.add_argument(
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]