Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6bc78b3902 |
@@ -44,7 +44,7 @@ pipe = StableVideoDiffusionPipeline.from_pretrained(
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# Load the conditioning image
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png")
|
||||
image = image.resize((1024, 576))
|
||||
|
||||
generator = torch.manual_seed(42)
|
||||
@@ -58,6 +58,11 @@ export_to_video(frames, "generated.mp4", fps=7)
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated.mp4" type="video/mp4" />
|
||||
</video>
|
||||
|
||||
| **Source Image** | **Video** |
|
||||
|:------------:|:-----:|
|
||||
|  |  |
|
||||
|
||||
|
||||
<Tip>
|
||||
Since generating videos is more memory intensive we can use the `decode_chunk_size` argument to control how many frames are decoded at once. This will reduce the memory usage. It's recommended to tweak this value based on your GPU memory.
|
||||
Setting `decode_chunk_size=1` will decode one frame at a time and will use the least amount of memory but the video might have some flickering.
|
||||
@@ -120,7 +125,7 @@ pipe = StableVideoDiffusionPipeline.from_pretrained(
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# Load the conditioning image
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png")
|
||||
image = image.resize((1024, 576))
|
||||
|
||||
generator = torch.manual_seed(42)
|
||||
@@ -128,7 +133,5 @@ frames = pipe(image, decode_chunk_size=8, generator=generator, motion_bucket_id=
|
||||
export_to_video(frames, "generated.mp4", fps=7)
|
||||
```
|
||||
|
||||
<video width="1024" height="576" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated_motion.mp4" type="video/mp4">
|
||||
</video>
|
||||

|
||||
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
|
||||
def convert_motion_module(original_state_dict):
|
||||
converted_state_dict = {}
|
||||
for k, v in original_state_dict.items():
|
||||
if "pos_encoder" in k:
|
||||
continue
|
||||
|
||||
else:
|
||||
converted_state_dict[
|
||||
k.replace(".norms.0", ".norm1")
|
||||
.replace(".norms.1", ".norm2")
|
||||
.replace(".ff_norm", ".norm3")
|
||||
.replace(".attention_blocks.0", ".attn1")
|
||||
.replace(".attention_blocks.1", ".attn2")
|
||||
.replace(".temporal_transformer", "")
|
||||
] = v
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ckpt_path", type=str, required=True)
|
||||
parser.add_argument("--output_path", type=str, required=True)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
|
||||
if "state_dict" in state_dict.keys():
|
||||
state_dict = state_dict["state_dict"]
|
||||
|
||||
conv_state_dict = convert_motion_module(state_dict)
|
||||
|
||||
# convert to new format
|
||||
output_dict = {}
|
||||
for module_name, params in conv_state_dict.items():
|
||||
if type(params) is not torch.Tensor:
|
||||
continue
|
||||
output_dict.update({f"unet.{module_name}": params})
|
||||
|
||||
save_file(output_dict, f"{args.output_path}/diffusion_pytorch_model.safetensors")
|
||||
@@ -1,51 +0,0 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import MotionAdapter
|
||||
|
||||
|
||||
def convert_motion_module(original_state_dict):
|
||||
converted_state_dict = {}
|
||||
for k, v in original_state_dict.items():
|
||||
if "pos_encoder" in k:
|
||||
continue
|
||||
|
||||
else:
|
||||
converted_state_dict[
|
||||
k.replace(".norms.0", ".norm1")
|
||||
.replace(".norms.1", ".norm2")
|
||||
.replace(".ff_norm", ".norm3")
|
||||
.replace(".attention_blocks.0", ".attn1")
|
||||
.replace(".attention_blocks.1", ".attn2")
|
||||
.replace(".temporal_transformer", "")
|
||||
] = v
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ckpt_path", type=str, required=True)
|
||||
parser.add_argument("--output_path", type=str, required=True)
|
||||
parser.add_argument("--use_motion_mid_block", action="store_true")
|
||||
parser.add_argument("--motion_max_seq_length", type=int, default=32)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
if "state_dict" in state_dict.keys():
|
||||
state_dict = state_dict["state_dict"]
|
||||
|
||||
conv_state_dict = convert_motion_module(state_dict)
|
||||
adapter = MotionAdapter(
|
||||
use_motion_mid_block=args.use_motion_mid_block, motion_max_seq_length=args.motion_max_seq_length
|
||||
)
|
||||
# skip loading position embeddings
|
||||
adapter.load_state_dict(conv_state_dict, strict=False)
|
||||
adapter.save_pretrained(args.output_path)
|
||||
adapter.save_pretrained(args.output_path, variant="fp16", torch_dtype=torch.float16)
|
||||
Reference in New Issue
Block a user