CogView4 (supports different length c and uc) (#10649)

* init

* encode with glm

* draft schedule

* feat(scheduler): Add CogView scheduler implementation

* feat(embeddings): add CogView 2D rotary positional embedding

* 1

* Update pipeline_cogview4.py

* fix the timestep init and sigma

* update latent

* draft patch(not work)

* fix

* [WIP][cogview4]: implement initial CogView4 pipeline

Implement the basic CogView4 pipeline structure with the following changes:
- Add CogView4 pipeline implementation
- Implement DDIM scheduler for CogView4
- Add CogView3Plus transformer architecture
- Update embedding models

Current limitations:
- CFG implementation uses padding for sequence length alignment
- Need to verify transformer inference alignment with Megatron

TODO:
- Consider separate forward passes for condition/uncondition
  instead of padding approach

* [WIP][cogview4][refactor]: Split condition/uncondition forward pass in CogView4 pipeline

Split the forward pass for conditional and unconditional predictions in the CogView4 pipeline to match the original implementation. The noise prediction is now done separately for each case before combining them for guidance. However, the results still need improvement.

This is a work in progress as the generated images are not yet matching expected quality.

* use with -2 hidden state

* remove text_projector

* 1

* [WIP] Add tensor-reload to align input from transformer block

* [WIP] for older glm

* use with cogview4 transformers forward twice of u and uc

* Update convert_cogview4_to_diffusers.py

* remove this

* use main example

* change back

* reset

* setback

* back

* back 4

* Fix qkv conversion logic for CogView4 to Diffusers format

* back5

* revert to sat to cogview4 version

* update a new convert from megatron

* [WIP][cogview4]: implement CogView4 attention processor

Add CogView4AttnProcessor class for implementing scaled dot-product attention
with rotary embeddings for the CogVideoX model. This processor concatenates
encoder and hidden states, applies QKV projections and RoPE, but does not
include spatial normalization.

TODO:
- Fix incorrect QKV projection weights
- Resolve ~25% error in RoPE implementation compared to Megatron

* [cogview4] implement CogView4 transformer block

Implement CogView4 transformer block following the Megatron architecture:
- Add multi-modulate and multi-gate mechanisms for adaptive layer normalization
- Implement dual-stream attention with encoder-decoder structure
- Add feed-forward network with GELU activation
- Support rotary position embeddings for image tokens

The implementation follows the original CogView4 architecture while adapting
it to work within the diffusers framework.

* with new attn

* [bugfix] fix dimension mismatch in CogView4 attention

* [cogview4][WIP]: update final normalization in CogView4 transformer

Refactored the final normalization layer in CogView4 transformer to use separate layernorm and AdaLN operations instead of combined AdaLayerNormContinuous. This matches the original implementation but needs validation.

Needs verification against reference implementation.

* 1

* put back

* Update transformer_cogview4.py

* change time_shift

* Update pipeline_cogview4.py

* change timesteps

* fix

* change text_encoder_id

* [cogview4][rope] align RoPE implementation with Megatron

- Implement apply_rope method in attention processor to match Megatron's implementation
- Update position embeddings to ensure compatibility with Megatron-style rotary embeddings
- Ensure consistent rotary position encoding across attention layers

This change improves compatibility with Megatron-based models and provides
better alignment with the original implementation's positional encoding approach.

* [cogview4][bugfix] apply silu activation to time embeddings in CogView4

Applied silu activation to time embeddings before splitting into conditional
and unconditional parts in CogView4Transformer2DModel. This matches the
original implementation and helps ensure correct time conditioning behavior.

* [cogview4][chore] clean up pipeline code

- Remove commented out code and debug statements
- Remove unused retrieve_timesteps function
- Clean up code formatting and documentation

This commit focuses on code cleanup in the CogView4 pipeline implementation, removing unnecessary commented code and improving readability without changing functionality.

* [cogview4][scheduler] Implement CogView4 scheduler and pipeline

* now It work

* add timestep

* batch

* change convert scipt

* refactor pt. 1; make style

* refactor pt. 2

* refactor pt. 3

* add tests

* make fix-copies

* update toctree.yml

* use flow match scheduler instead of custom

* remove scheduling_cogview.py

* add tiktoken to test dependencies

* Update src/diffusers/models/embeddings.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* apply suggestions from review

* use diffusers apply_rotary_emb

* update flow match scheduler to accept timesteps

* fix comment

* apply review sugestions

* Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

---------

Co-authored-by: 三洋三洋 <1258009915@qq.com>
Co-authored-by: OleehyO <leehy0357@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
Yuxuan Zhang
2025-02-16 00:16:48 +08:00
committed by GitHub
parent 69f919d8b5
commit d90cd3621d
24 changed files with 2262 additions and 18 deletions
+243
View File
@@ -0,0 +1,243 @@
"""
Convert a CogView4 checkpoint from SAT(https://github.com/THUDM/SwissArmyTransformer) to the Diffusers format.
(deprecated Since 2025-02-07 and will remove it in later CogView4 version)
This script converts a CogView4 checkpoint to the Diffusers format, which can then be used
with the Diffusers library.
Example usage:
python scripts/convert_cogview4_to_diffusers.py \
--transformer_checkpoint_path 'your path/cogview4_6b/1/mp_rank_00_model_states.pt' \
--vae_checkpoint_path 'your path/cogview4_6b/imagekl_ch16.pt' \
--output_path "THUDM/CogView4-6B" \
--dtype "bf16"
Arguments:
--transformer_checkpoint_path: Path to Transformer state dict.
--vae_checkpoint_path: Path to VAE state dict.
--output_path: The path to save the converted model.
--push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
--text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used
--dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered.
Default is "bf16" because CogView4 uses bfloat16 for Training.
Note: You must provide either --original_state_dict_repo_id or --checkpoint_path.
"""
import argparse
from contextlib import nullcontext
import torch
from accelerate import init_empty_weights
from transformers import GlmForCausalLM, PreTrainedTokenizerFast
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available() else nullcontext
parser = argparse.ArgumentParser()
parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
parser.add_argument("--vae_checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", required=True, type=str)
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
parser.add_argument("--dtype", type=str, default="bf16")
args = parser.parse_args()
# this is specific to `AdaLayerNormContinuous`:
# diffusers implementation split the linear projection into the scale, shift while CogView4 split it tino shift, scale
def swap_scale_shift(weight, dim):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
def convert_cogview4_transformer_checkpoint_to_diffusers(ckpt_path):
original_state_dict = torch.load(ckpt_path, map_location="cpu")
original_state_dict = original_state_dict["module"]
original_state_dict = {k.replace("model.diffusion_model.", ""): v for k, v in original_state_dict.items()}
new_state_dict = {}
# Convert patch_embed
new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
new_state_dict["patch_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
new_state_dict["patch_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")
# Convert time_condition_embed
new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
"time_embed.0.weight"
)
new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
"time_embed.0.bias"
)
new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
"time_embed.2.weight"
)
new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
"time_embed.2.bias"
)
new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = original_state_dict.pop(
"label_emb.0.0.weight"
)
new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = original_state_dict.pop(
"label_emb.0.0.bias"
)
new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = original_state_dict.pop(
"label_emb.0.2.weight"
)
new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = original_state_dict.pop(
"label_emb.0.2.bias"
)
# Convert transformer blocks, for cogview4 is 28 blocks
for i in range(28):
block_prefix = f"transformer_blocks.{i}."
old_prefix = f"transformer.layers.{i}."
adaln_prefix = f"mixins.adaln.adaln_modules.{i}."
new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")
qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
q, k, v = qkv_weight.chunk(3, dim=0)
q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
old_prefix + "attention.dense.weight"
)
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(
old_prefix + "attention.dense.bias"
)
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(
old_prefix + "mlp.dense_h_to_4h.weight"
)
new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop(
old_prefix + "mlp.dense_h_to_4h.bias"
)
new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(
old_prefix + "mlp.dense_4h_to_h.weight"
)
new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias")
# Convert final norm and projection
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0
)
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(
original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0
)
new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight")
new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias")
return new_state_dict
def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
def main(args):
if args.dtype == "fp16":
dtype = torch.float16
elif args.dtype == "bf16":
dtype = torch.bfloat16
elif args.dtype == "fp32":
dtype = torch.float32
else:
raise ValueError(f"Unsupported dtype: {args.dtype}")
transformer = None
vae = None
if args.transformer_checkpoint_path is not None:
converted_transformer_state_dict = convert_cogview4_transformer_checkpoint_to_diffusers(
args.transformer_checkpoint_path
)
transformer = CogView4Transformer2DModel(
patch_size=2,
in_channels=16,
num_layers=28,
attention_head_dim=128,
num_attention_heads=32,
out_channels=16,
text_embed_dim=4096,
time_embed_dim=512,
condition_dim=256,
pos_embed_max_size=128,
)
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
if dtype is not None:
# Original checkpoint data type will be preserved
transformer = transformer.to(dtype=dtype)
if args.vae_checkpoint_path is not None:
vae_config = {
"in_channels": 3,
"out_channels": 3,
"down_block_types": ("DownEncoderBlock2D",) * 4,
"up_block_types": ("UpDecoderBlock2D",) * 4,
"block_out_channels": (128, 512, 1024, 1024),
"layers_per_block": 3,
"act_fn": "silu",
"latent_channels": 16,
"norm_num_groups": 32,
"sample_size": 1024,
"scaling_factor": 1.0,
"force_upcast": True,
"use_quant_conv": False,
"use_post_quant_conv": False,
"mid_block_add_attention": False,
}
converted_vae_state_dict = convert_cogview4_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_state_dict, strict=True)
if dtype is not None:
vae = vae.to(dtype=dtype)
text_encoder_id = "THUDM/glm-4-9b-hf"
tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
text_encoder = GlmForCausalLM.from_pretrained(
text_encoder_id,
cache_dir=args.text_encoder_cache_dir,
torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
)
for param in text_encoder.parameters():
param.data = param.data.contiguous()
scheduler = FlowMatchEulerDiscreteScheduler(
base_shift=0.25, max_shift=0.75, base_image_seq_len=256, use_dynamic_shifting=True, time_shift_type="linear"
)
pipe = CogView4Pipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=transformer,
scheduler=scheduler,
)
# This is necessary for users with insufficient memory, such as those using Colab and notebooks, as it can
# save some memory used for model loading.
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
if __name__ == "__main__":
main(args)
@@ -0,0 +1,366 @@
"""
Convert a CogView4 checkpoint from Megatron to the Diffusers format.
Example usage:
python scripts/convert_cogview4_to_diffusers.py \
--transformer_checkpoint_path 'your path/cogview4_6b/mp_rank_00/model_optim_rng.pt' \
--vae_checkpoint_path 'your path/cogview4_6b/imagekl_ch16.pt' \
--output_path "THUDM/CogView4-6B" \
--dtype "bf16"
Arguments:
--transformer_checkpoint_path: Path to Transformer state dict.
--vae_checkpoint_path: Path to VAE state dict.
--output_path: The path to save the converted model.
--push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
--text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used.
--dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered.
Default is "bf16" because CogView4 uses bfloat16 for training.
Note: You must provide either --transformer_checkpoint_path or --vae_checkpoint_path.
"""
import argparse
import torch
from tqdm import tqdm
from transformers import GlmForCausalLM, PreTrainedTokenizerFast
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
parser = argparse.ArgumentParser()
parser.add_argument(
"--transformer_checkpoint_path",
default=None,
type=str,
help="Path to Megatron (not SAT) Transformer checkpoint, e.g., 'model_optim_rng.pt'.",
)
parser.add_argument(
"--vae_checkpoint_path",
default=None,
type=str,
help="(Optional) Path to VAE checkpoint, e.g., 'imagekl_ch16.pt'.",
)
parser.add_argument(
"--output_path",
required=True,
type=str,
help="Directory to save the final Diffusers format pipeline.",
)
parser.add_argument(
"--push_to_hub",
action="store_true",
default=False,
help="Whether to push the converted model to the HuggingFace Hub.",
)
parser.add_argument(
"--text_encoder_cache_dir",
type=str,
default=None,
help="Specify the cache directory for the text encoder.",
)
parser.add_argument(
"--dtype",
type=str,
default="bf16",
choices=["fp16", "bf16", "fp32"],
help="Data type to save the model in.",
)
parser.add_argument(
"--num_layers",
type=int,
default=28,
help="Number of Transformer layers (e.g., 28, 48...).",
)
parser.add_argument(
"--num_heads",
type=int,
default=32,
help="Number of attention heads.",
)
parser.add_argument(
"--hidden_size",
type=int,
default=4096,
help="Transformer hidden dimension size.",
)
parser.add_argument(
"--attention_head_dim",
type=int,
default=128,
help="Dimension of each attention head.",
)
parser.add_argument(
"--time_embed_dim",
type=int,
default=512,
help="Dimension of time embeddings.",
)
parser.add_argument(
"--condition_dim",
type=int,
default=256,
help="Dimension of condition embeddings.",
)
parser.add_argument(
"--pos_embed_max_size",
type=int,
default=128,
help="Maximum size for positional embeddings.",
)
args = parser.parse_args()
def swap_scale_shift(weight, dim):
"""
Swap the scale and shift components in the weight tensor.
Args:
weight (torch.Tensor): The original weight tensor.
dim (int): The dimension along which to split.
Returns:
torch.Tensor: The modified weight tensor with scale and shift swapped.
"""
shift, scale = weight.chunk(2, dim=dim)
new_weight = torch.cat([scale, shift], dim=dim)
return new_weight
def convert_megatron_transformer_checkpoint_to_diffusers(
ckpt_path: str,
num_layers: int,
num_heads: int,
hidden_size: int,
):
"""
Convert a Megatron Transformer checkpoint to Diffusers format.
Args:
ckpt_path (str): Path to the Megatron Transformer checkpoint.
num_layers (int): Number of Transformer layers.
num_heads (int): Number of attention heads.
hidden_size (int): Hidden size of the Transformer.
Returns:
dict: The converted state dictionary compatible with Diffusers.
"""
ckpt = torch.load(ckpt_path, map_location="cpu")
mega = ckpt["model"]
new_state_dict = {}
# Patch Embedding
new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(hidden_size, 64)
new_state_dict["patch_embed.proj.bias"] = mega["encoder_expand_linear.bias"]
new_state_dict["patch_embed.text_proj.weight"] = mega["text_projector.weight"]
new_state_dict["patch_embed.text_proj.bias"] = mega["text_projector.bias"]
# Time Condition Embedding
new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = mega[
"time_embedding.time_embed.0.weight"
]
new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = mega["time_embedding.time_embed.0.bias"]
new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = mega[
"time_embedding.time_embed.2.weight"
]
new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = mega["time_embedding.time_embed.2.bias"]
new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = mega[
"label_embedding.label_embed.0.weight"
]
new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = mega[
"label_embedding.label_embed.0.bias"
]
new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = mega[
"label_embedding.label_embed.2.weight"
]
new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = mega[
"label_embedding.label_embed.2.bias"
]
# Convert each Transformer layer
for i in tqdm(range(num_layers), desc="Converting layers (Megatron->Diffusers)"):
block_prefix = f"transformer_blocks.{i}."
# AdaLayerNorm
new_state_dict[block_prefix + "norm1.linear.weight"] = swap_scale_shift(
mega[f"decoder.layers.{i}.adaln.weight"], dim=0
)
new_state_dict[block_prefix + "norm1.linear.bias"] = swap_scale_shift(
mega[f"decoder.layers.{i}.adaln.bias"], dim=0
)
# QKV
qkv_weight = mega[f"decoder.layers.{i}.self_attention.linear_qkv.weight"]
qkv_bias = mega[f"decoder.layers.{i}.self_attention.linear_qkv.bias"]
# Reshape to match SAT logic
qkv_weight = qkv_weight.view(num_heads, 3, hidden_size // num_heads, hidden_size)
qkv_weight = qkv_weight.permute(1, 0, 2, 3).reshape(3 * hidden_size, hidden_size)
qkv_bias = qkv_bias.view(num_heads, 3, hidden_size // num_heads)
qkv_bias = qkv_bias.permute(1, 0, 2).reshape(3 * hidden_size)
# Assign to Diffusers keys
q, k, v = torch.chunk(qkv_weight, 3, dim=0)
qb, kb, vb = torch.chunk(qkv_bias, 3, dim=0)
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
new_state_dict[block_prefix + "attn1.to_q.bias"] = qb
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
new_state_dict[block_prefix + "attn1.to_k.bias"] = kb
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
new_state_dict[block_prefix + "attn1.to_v.bias"] = vb
# Attention Output
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = mega[
f"decoder.layers.{i}.self_attention.linear_proj.weight"
].T
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = mega[
f"decoder.layers.{i}.self_attention.linear_proj.bias"
]
# MLP
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = mega[f"decoder.layers.{i}.mlp.linear_fc1.weight"]
new_state_dict[block_prefix + "ff.net.0.proj.bias"] = mega[f"decoder.layers.{i}.mlp.linear_fc1.bias"]
new_state_dict[block_prefix + "ff.net.2.weight"] = mega[f"decoder.layers.{i}.mlp.linear_fc2.weight"]
new_state_dict[block_prefix + "ff.net.2.bias"] = mega[f"decoder.layers.{i}.mlp.linear_fc2.bias"]
# Final Layers
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(mega["adaln_final.weight"], dim=0)
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(mega["adaln_final.bias"], dim=0)
new_state_dict["proj_out.weight"] = mega["output_projector.weight"]
new_state_dict["proj_out.bias"] = mega["output_projector.bias"]
return new_state_dict
def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
"""
Convert a CogView4 VAE checkpoint to Diffusers format.
Args:
ckpt_path (str): Path to the VAE checkpoint.
vae_config (dict): Configuration dictionary for the VAE.
Returns:
dict: The converted VAE state dictionary compatible with Diffusers.
"""
original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
def main(args):
"""
Main function to convert CogView4 checkpoints to Diffusers format.
Args:
args (argparse.Namespace): Parsed command-line arguments.
"""
# Determine the desired data type
if args.dtype == "fp16":
dtype = torch.float16
elif args.dtype == "bf16":
dtype = torch.bfloat16
elif args.dtype == "fp32":
dtype = torch.float32
else:
raise ValueError(f"Unsupported dtype: {args.dtype}")
transformer = None
vae = None
# Convert Transformer checkpoint if provided
if args.transformer_checkpoint_path is not None:
converted_transformer_state_dict = convert_megatron_transformer_checkpoint_to_diffusers(
ckpt_path=args.transformer_checkpoint_path,
num_layers=args.num_layers,
num_heads=args.num_heads,
hidden_size=args.hidden_size,
)
transformer = CogView4Transformer2DModel(
patch_size=2,
in_channels=16,
num_layers=args.num_layers,
attention_head_dim=args.attention_head_dim,
num_attention_heads=args.num_heads,
out_channels=16,
text_embed_dim=args.hidden_size,
time_embed_dim=args.time_embed_dim,
condition_dim=args.condition_dim,
pos_embed_max_size=args.pos_embed_max_size,
)
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
# Convert to the specified dtype
if dtype is not None:
transformer = transformer.to(dtype=dtype)
# Convert VAE checkpoint if provided
if args.vae_checkpoint_path is not None:
vae_config = {
"in_channels": 3,
"out_channels": 3,
"down_block_types": ("DownEncoderBlock2D",) * 4,
"up_block_types": ("UpDecoderBlock2D",) * 4,
"block_out_channels": (128, 512, 1024, 1024),
"layers_per_block": 3,
"act_fn": "silu",
"latent_channels": 16,
"norm_num_groups": 32,
"sample_size": 1024,
"scaling_factor": 1.0,
"force_upcast": True,
"use_quant_conv": False,
"use_post_quant_conv": False,
"mid_block_add_attention": False,
}
converted_vae_state_dict = convert_cogview4_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_state_dict, strict=True)
if dtype is not None:
vae = vae.to(dtype=dtype)
# Load the text encoder and tokenizer
text_encoder_id = "THUDM/glm-4-9b-hf"
tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
text_encoder = GlmForCausalLM.from_pretrained(
text_encoder_id,
cache_dir=args.text_encoder_cache_dir,
torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
)
for param in text_encoder.parameters():
param.data = param.data.contiguous()
# Initialize the scheduler
scheduler = FlowMatchEulerDiscreteScheduler(
base_shift=0.25, max_shift=0.75, base_image_seq_len=256, use_dynamic_shifting=True, time_shift_type="linear"
)
# Create the pipeline
pipe = CogView4Pipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=transformer,
scheduler=scheduler,
)
# Save the converted pipeline
pipe.save_pretrained(
args.output_path,
safe_serialization=True,
max_shard_size="5GB",
push_to_hub=args.push_to_hub,
)
if __name__ == "__main__":
main(args)