Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 92199ff3ac | |||
| 04e9323055 | |||
| 9a09162baf | |||
| 33a8a3be0c | |||
| 58743c3ee7 | |||
| 50c0b786d2 |
@@ -278,29 +278,6 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
|
||||
}
|
||||
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan2.2-VACE-Fun-14B":
|
||||
config = {
|
||||
"model_id": "alibaba-pai/Wan2.2-VACE-Fun-A14B",
|
||||
"diffusers_config": {
|
||||
"added_kv_proj_dim": None,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"in_channels": 16,
|
||||
"num_attention_heads": 40,
|
||||
"num_layers": 40,
|
||||
"out_channels": 16,
|
||||
"patch_size": [1, 2, 2],
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"text_dim": 4096,
|
||||
"vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
|
||||
"vace_in_channels": 96,
|
||||
},
|
||||
}
|
||||
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan2.2-I2V-14B-720p":
|
||||
config = {
|
||||
"model_id": "Wan-AI/Wan2.2-I2V-A14B",
|
||||
@@ -998,17 +975,7 @@ if __name__ == "__main__":
|
||||
image_encoder=image_encoder,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
elif "Wan2.2-VACE" in args.model_type:
|
||||
pipe = WanVACEPipeline(
|
||||
transformer=transformer,
|
||||
transformer_2=transformer_2,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
boundary_ratio=0.875,
|
||||
)
|
||||
elif "Wan-VACE" in args.model_type:
|
||||
elif "VACE" in args.model_type:
|
||||
pipe = WanVACEPipeline(
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
|
||||
@@ -207,7 +207,6 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
|
||||
"wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"},
|
||||
"wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"},
|
||||
"wan-2-2-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"},
|
||||
"hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
|
||||
"cosmos-1.0-t2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"},
|
||||
"cosmos-1.0-t2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Text2World"},
|
||||
@@ -733,10 +732,7 @@ def infer_diffusers_model_type(checkpoint):
|
||||
elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
|
||||
model_type = "wan-t2v-14B"
|
||||
else:
|
||||
if "img_emb.proj.0.bias" in checkpoint:
|
||||
model_type = "wan-i2v-14B"
|
||||
else:
|
||||
model_type = "wan-2-2-i2v-14B"
|
||||
model_type = "wan-i2v-14B"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
|
||||
# All Wan models use the same VAE so we can use the same default model repo to fetch the config
|
||||
|
||||
@@ -17,10 +17,11 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils import deprecate
|
||||
from ..utils.import_utils import is_torch_npu_available, is_torch_version
|
||||
from ..utils import deprecate, get_logger, is_torch_npu_available, is_torch_version
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
if is_torch_npu_available():
|
||||
import torch_npu
|
||||
|
||||
@@ -31,6 +32,7 @@ ACT2CLS = {
|
||||
"gelu": nn.GELU,
|
||||
"relu": nn.ReLU,
|
||||
}
|
||||
KERNELS_REPO_ID = "kernels-community/activation"
|
||||
|
||||
|
||||
def get_activation(act_fn: str) -> nn.Module:
|
||||
@@ -90,6 +92,27 @@ class GELU(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# TODO: validation checks / consider making Python classes of activations like `transformers`
|
||||
# All of these are temporary for now.
|
||||
class CUDAOptimizedGELU(GELU):
|
||||
def __init__(self, *args, **kwargs):
|
||||
from kernels import get_kernel
|
||||
|
||||
activation = get_kernel("kernels-community/activation", revision="add_more_act")
|
||||
approximate = kwargs.get("approximate", "none")
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
if approximate == "none":
|
||||
self.act_fn = activation.layers.Gelu()
|
||||
elif approximate == "tanh":
|
||||
self.act_fn = activation.layers.GeluTanh()
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
r"""
|
||||
A [variant](https://huggingface.co/papers/2002.05202) of the gated linear unit activation function.
|
||||
|
||||
@@ -20,11 +20,20 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils import is_torch_npu_available, is_torch_version
|
||||
from ..utils import is_kernels_available, is_torch_npu_available, is_torch_version
|
||||
from ..utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
|
||||
from ..utils.kernels_utils import use_kernel_forward_from_hub
|
||||
from .activations import get_activation
|
||||
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
|
||||
|
||||
if is_kernels_available() and DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
from kernels import get_kernel
|
||||
|
||||
activation = get_kernel("kernels-community/activation", revision="add_more_act")
|
||||
silu_kernel = activation.layers.Silu
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
r"""
|
||||
Norm layer modified to incorporate timestep embeddings.
|
||||
@@ -57,7 +66,10 @@ class AdaLayerNorm(nn.Module):
|
||||
else:
|
||||
self.emb = None
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
self.silu = silu_kernel()
|
||||
else:
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, output_dim)
|
||||
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
||||
|
||||
@@ -144,7 +156,10 @@ class AdaLayerNormZero(nn.Module):
|
||||
else:
|
||||
self.emb = None
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
self.silu = silu_kernel()
|
||||
else:
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
||||
@@ -183,7 +198,10 @@ class AdaLayerNormZeroSingle(nn.Module):
|
||||
def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
self.silu = silu_kernel()
|
||||
else:
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
||||
@@ -335,7 +353,10 @@ class AdaLayerNormContinuous(nn.Module):
|
||||
norm_type="layer_norm",
|
||||
):
|
||||
super().__init__()
|
||||
self.silu = nn.SiLU()
|
||||
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
self.silu = silu_kernel()
|
||||
else:
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
||||
@@ -508,6 +529,7 @@ else:
|
||||
return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
|
||||
|
||||
|
||||
@use_kernel_forward_from_hub("RMSNorm")
|
||||
class RMSNorm(nn.Module):
|
||||
r"""
|
||||
RMS Norm as introduced in https://huggingface.co/papers/1910.07467 by Zhang et al.
|
||||
|
||||
@@ -22,7 +22,8 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, is_kernels_available, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -40,6 +41,12 @@ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNo
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
if is_kernels_available() and DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
from kernels import get_kernel
|
||||
|
||||
activation = get_kernel("kernels-community/activation", revision="add_more_act")
|
||||
gelu_tanh_kernel = activation.layers.GeluTanh
|
||||
|
||||
|
||||
def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
|
||||
query = attn.to_q(hidden_states)
|
||||
@@ -300,8 +307,14 @@ class FluxAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
self.added_kv_proj_dim = added_kv_proj_dim
|
||||
self.added_proj_bias = added_proj_bias
|
||||
|
||||
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
from ..normalization import RMSNorm
|
||||
|
||||
self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
else:
|
||||
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
@@ -312,8 +325,14 @@ class FluxAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
self.to_out.append(torch.nn.Dropout(dropout))
|
||||
|
||||
if added_kv_proj_dim is not None:
|
||||
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
|
||||
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
|
||||
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
from ..normalization import RMSNorm
|
||||
|
||||
self.norm_added_q = RMSNorm(dim_head, eps=eps)
|
||||
self.norm_added_k = RMSNorm(dim_head, eps=eps)
|
||||
else:
|
||||
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
|
||||
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
|
||||
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
@@ -351,6 +370,11 @@ class FluxSingleTransformerBlock(nn.Module):
|
||||
self.norm = AdaLayerNormZeroSingle(dim)
|
||||
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
|
||||
self.act_mlp = nn.GELU(approximate="tanh")
|
||||
# if not DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
# self.act_mlp = nn.GELU(approximate="tanh")
|
||||
# else:
|
||||
# self.act_mlp = gelu_tanh_kernel()
|
||||
|
||||
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
||||
|
||||
self.attn = FluxAttention(
|
||||
|
||||
@@ -152,26 +152,16 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
transformer ([`WanVACETransformer3DModel`]):
|
||||
transformer ([`WanTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the input latents.
|
||||
transformer_2 ([`WanVACETransformer3DModel`], *optional*):
|
||||
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
|
||||
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
|
||||
`transformer` is used.
|
||||
scheduler ([`UniPCMultistepScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
boundary_ratio (`float`, *optional*, defaults to `None`):
|
||||
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
|
||||
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
|
||||
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
|
||||
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
_optional_components = ["transformer_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -180,8 +170,6 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
transformer: WanVACETransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
transformer_2: WanVACETransformer3DModel = None,
|
||||
boundary_ratio: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -190,10 +178,9 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
transformer_2=transformer_2,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.register_to_config(boundary_ratio=boundary_ratio)
|
||||
|
||||
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
@@ -334,7 +321,6 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
video=None,
|
||||
mask=None,
|
||||
reference_images=None,
|
||||
guidance_scale_2=None,
|
||||
):
|
||||
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
|
||||
if height % base != 0 or width % base != 0:
|
||||
@@ -346,8 +332,6 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
if self.config.boundary_ratio is None and guidance_scale_2 is not None:
|
||||
raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
@@ -683,7 +667,6 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
num_frames: int = 81,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 5.0,
|
||||
guidance_scale_2: Optional[float] = None,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
@@ -745,10 +728,6 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
guidance_scale_2 (`float`, *optional*, defaults to `None`):
|
||||
Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
|
||||
`boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
|
||||
and the pipeline's `boundary_ratio` are not None.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
@@ -814,7 +793,6 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
video,
|
||||
mask,
|
||||
reference_images,
|
||||
guidance_scale_2,
|
||||
)
|
||||
|
||||
if num_frames % self.vae_scale_factor_temporal != 1:
|
||||
@@ -824,11 +802,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
|
||||
num_frames = max(num_frames, 1)
|
||||
|
||||
if self.config.boundary_ratio is not None and guidance_scale_2 is None:
|
||||
guidance_scale_2 = guidance_scale
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._guidance_scale_2 = guidance_scale_2
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
@@ -922,53 +896,36 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
if self.config.boundary_ratio is not None:
|
||||
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
|
||||
else:
|
||||
boundary_timestep = None
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
|
||||
if boundary_timestep is None or t >= boundary_timestep:
|
||||
# wan2.1 or high-noise stage in wan2.2
|
||||
current_model = self.transformer
|
||||
current_guidance_scale = guidance_scale
|
||||
else:
|
||||
# low-noise stage in wan2.2
|
||||
current_model = self.transformer_2
|
||||
current_guidance_scale = guidance_scale_2
|
||||
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
with current_model.cache_context("cond"):
|
||||
noise_pred = current_model(
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
control_hidden_states=conditioning_latents,
|
||||
control_hidden_states_scale=conditioning_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
control_hidden_states=conditioning_latents,
|
||||
control_hidden_states_scale=conditioning_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
with current_model.cache_context("uncond"):
|
||||
noise_uncond = current_model(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
control_hidden_states=conditioning_latents,
|
||||
control_hidden_states_scale=conditioning_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Union
|
||||
|
||||
from ..utils import get_logger
|
||||
from .import_utils import is_kernels_available
|
||||
|
||||
@@ -21,3 +23,42 @@ def _get_fa3_from_hub():
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if is_kernels_available():
|
||||
from kernels import (
|
||||
Device,
|
||||
LayerRepository,
|
||||
register_kernel_mapping,
|
||||
replace_kernel_forward_from_hub,
|
||||
use_kernel_forward_from_hub,
|
||||
)
|
||||
|
||||
_KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = {
|
||||
"RMSNorm": {
|
||||
"cuda": LayerRepository(repo_id="kernels-community/liger_kernels", layer_name="LigerRMSNorm"),
|
||||
},
|
||||
}
|
||||
|
||||
register_kernel_mapping(_KERNEL_MAPPING)
|
||||
|
||||
else:
|
||||
# Stub to make decorators int transformers work when `kernels`
|
||||
# is not installed.
|
||||
def use_kernel_forward_from_hub(*args, **kwargs):
|
||||
def decorator(cls):
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
class LayerRepository:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")
|
||||
|
||||
def replace_kernel_forward_from_hub(*args, **kwargs):
|
||||
raise RuntimeError(
|
||||
"replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
|
||||
)
|
||||
|
||||
def register_kernel_mapping(*args, **kwargs):
|
||||
raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
|
||||
|
||||
@@ -87,7 +87,6 @@ class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"transformer_2": None,
|
||||
}
|
||||
return components
|
||||
|
||||
|
||||
Reference in New Issue
Block a user