Add Ascend NPU support for SDXL fine-tuning and fix the model saving bug when using DeepSpeed. (#7816)
* Add Ascend NPU support for SDXL fine-tuning and fix the model saving bug when using DeepSpeed. * fix check code quality * Decouple the NPU flash attention and make it an independent module. * add doc and unit tests for npu flash attention. --------- Co-authored-by: mhh001 <mahonghao1@huawei.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
committed by
GitHub
parent
3e35628873
commit
58237364b1
@@ -55,3 +55,6 @@ An attention processor is a class for applying different types of attention mech
|
|||||||
|
|
||||||
## XFormersAttnProcessor
|
## XFormersAttnProcessor
|
||||||
[[autodoc]] models.attention_processor.XFormersAttnProcessor
|
[[autodoc]] models.attention_processor.XFormersAttnProcessor
|
||||||
|
|
||||||
|
## AttnProcessorNPU
|
||||||
|
[[autodoc]] models.attention_processor.AttnProcessorNPU
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ import torch.utils.checkpoint
|
|||||||
import transformers
|
import transformers
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import ProjectConfiguration, set_seed
|
from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from huggingface_hub import create_repo, upload_folder
|
from huggingface_hub import create_repo, upload_folder
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@@ -53,7 +53,7 @@ from diffusers import (
|
|||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
|
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
|
||||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||||
from diffusers.utils.import_utils import is_xformers_available
|
from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
|
||||||
from diffusers.utils.torch_utils import is_compiled_module
|
from diffusers.utils.torch_utils import is_compiled_module
|
||||||
|
|
||||||
|
|
||||||
@@ -64,6 +64,8 @@ if is_wandb_available():
|
|||||||
check_min_version("0.28.0.dev0")
|
check_min_version("0.28.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
if is_torch_npu_available():
|
||||||
|
torch.npu.config.allow_internal_format = False
|
||||||
|
|
||||||
|
|
||||||
def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
|
def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
|
||||||
@@ -471,6 +473,9 @@ def parse_args(input_args=None):
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--set_grads_to_none",
|
"--set_grads_to_none",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@@ -936,6 +941,13 @@ def main(args):
|
|||||||
text_encoder_two.requires_grad_(False)
|
text_encoder_two.requires_grad_(False)
|
||||||
controlnet.train()
|
controlnet.train()
|
||||||
|
|
||||||
|
if args.enable_npu_flash_attention:
|
||||||
|
if is_torch_npu_available():
|
||||||
|
logger.info("npu flash attention enabled.")
|
||||||
|
unet.enable_npu_flash_attention()
|
||||||
|
else:
|
||||||
|
raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
|
||||||
|
|
||||||
if args.enable_xformers_memory_efficient_attention:
|
if args.enable_xformers_memory_efficient_attention:
|
||||||
if is_xformers_available():
|
if is_xformers_available():
|
||||||
import xformers
|
import xformers
|
||||||
@@ -1235,7 +1247,8 @@ def main(args):
|
|||||||
progress_bar.update(1)
|
progress_bar.update(1)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
if accelerator.is_main_process:
|
# DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
|
||||||
|
if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
|
||||||
if global_step % args.checkpointing_steps == 0:
|
if global_step % args.checkpointing_steps == 0:
|
||||||
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
||||||
if args.checkpoints_total_limit is not None:
|
if args.checkpoints_total_limit is not None:
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ import torch.utils.checkpoint
|
|||||||
import transformers
|
import transformers
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
from accelerate.utils import DistributedDataParallelKwargs, DistributedType, ProjectConfiguration, set_seed
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from huggingface_hub import create_repo, upload_folder
|
from huggingface_hub import create_repo, upload_folder
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@@ -60,7 +60,7 @@ from diffusers.utils import (
|
|||||||
is_wandb_available,
|
is_wandb_available,
|
||||||
)
|
)
|
||||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||||
from diffusers.utils.import_utils import is_xformers_available
|
from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
|
||||||
from diffusers.utils.torch_utils import is_compiled_module
|
from diffusers.utils.torch_utils import is_compiled_module
|
||||||
|
|
||||||
|
|
||||||
@@ -68,6 +68,8 @@ from diffusers.utils.torch_utils import is_compiled_module
|
|||||||
check_min_version("0.28.0.dev0")
|
check_min_version("0.28.0.dev0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
if is_torch_npu_available():
|
||||||
|
torch.npu.config.allow_internal_format = False
|
||||||
|
|
||||||
|
|
||||||
def save_model_card(
|
def save_model_card(
|
||||||
@@ -419,6 +421,9 @@ def parse_args(input_args=None):
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
|
||||||
|
)
|
||||||
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
|
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--rank",
|
"--rank",
|
||||||
@@ -623,6 +628,13 @@ def main(args):
|
|||||||
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
|
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
|
||||||
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
|
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
if args.enable_npu_flash_attention:
|
||||||
|
if is_torch_npu_available():
|
||||||
|
logger.info("npu flash attention enabled.")
|
||||||
|
unet.enable_npu_flash_attention()
|
||||||
|
else:
|
||||||
|
raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
|
||||||
|
|
||||||
if args.enable_xformers_memory_efficient_attention:
|
if args.enable_xformers_memory_efficient_attention:
|
||||||
if is_xformers_available():
|
if is_xformers_available():
|
||||||
import xformers
|
import xformers
|
||||||
@@ -1149,7 +1161,8 @@ def main(args):
|
|||||||
accelerator.log({"train_loss": train_loss}, step=global_step)
|
accelerator.log({"train_loss": train_loss}, step=global_step)
|
||||||
train_loss = 0.0
|
train_loss = 0.0
|
||||||
|
|
||||||
if accelerator.is_main_process:
|
# DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
|
||||||
|
if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
|
||||||
if global_step % args.checkpointing_steps == 0:
|
if global_step % args.checkpointing_steps == 0:
|
||||||
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
||||||
if args.checkpoints_total_limit is not None:
|
if args.checkpoints_total_limit is not None:
|
||||||
|
|||||||
@@ -18,8 +18,12 @@ import torch.nn.functional as F
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from ..utils import deprecate
|
from ..utils import deprecate
|
||||||
|
from ..utils.import_utils import is_torch_npu_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_npu_available():
|
||||||
|
import torch_npu
|
||||||
|
|
||||||
ACTIVATION_FUNCTIONS = {
|
ACTIVATION_FUNCTIONS = {
|
||||||
"swish": nn.SiLU(),
|
"swish": nn.SiLU(),
|
||||||
"silu": nn.SiLU(),
|
"silu": nn.SiLU(),
|
||||||
@@ -98,9 +102,13 @@ class GEGLU(nn.Module):
|
|||||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||||
deprecate("scale", "1.0.0", deprecation_message)
|
deprecate("scale", "1.0.0", deprecation_message)
|
||||||
|
hidden_states = self.proj(hidden_states)
|
||||||
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
if is_torch_npu_available():
|
||||||
return hidden_states * self.gelu(gate)
|
# using torch_npu.npu_geglu can run faster and save memory on NPU.
|
||||||
|
return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
|
||||||
|
else:
|
||||||
|
hidden_states, gate = hidden_states.chunk(2, dim=-1)
|
||||||
|
return hidden_states * self.gelu(gate)
|
||||||
|
|
||||||
|
|
||||||
class ApproximateGELU(nn.Module):
|
class ApproximateGELU(nn.Module):
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import inspect
|
import inspect
|
||||||
|
import math
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
@@ -21,13 +22,15 @@ from torch import nn
|
|||||||
|
|
||||||
from ..image_processor import IPAdapterMaskProcessor
|
from ..image_processor import IPAdapterMaskProcessor
|
||||||
from ..utils import deprecate, logging
|
from ..utils import deprecate, logging
|
||||||
from ..utils.import_utils import is_xformers_available
|
from ..utils.import_utils import is_torch_npu_available, is_xformers_available
|
||||||
from ..utils.torch_utils import maybe_allow_in_graph
|
from ..utils.torch_utils import maybe_allow_in_graph
|
||||||
from .lora import LoRALinearLayer
|
from .lora import LoRALinearLayer
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
if is_torch_npu_available():
|
||||||
|
import torch_npu
|
||||||
|
|
||||||
if is_xformers_available():
|
if is_xformers_available():
|
||||||
import xformers
|
import xformers
|
||||||
@@ -209,6 +212,23 @@ class Attention(nn.Module):
|
|||||||
)
|
)
|
||||||
self.set_processor(processor)
|
self.set_processor(processor)
|
||||||
|
|
||||||
|
def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
|
||||||
|
r"""
|
||||||
|
Set whether to use npu flash attention from `torch_npu` or not.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if use_npu_flash_attention:
|
||||||
|
processor = AttnProcessorNPU()
|
||||||
|
else:
|
||||||
|
# set attention processor
|
||||||
|
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
||||||
|
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
||||||
|
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
||||||
|
processor = (
|
||||||
|
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
||||||
|
)
|
||||||
|
self.set_processor(processor)
|
||||||
|
|
||||||
def set_use_memory_efficient_attention_xformers(
|
def set_use_memory_efficient_attention_xformers(
|
||||||
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
|
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -1207,6 +1227,116 @@ class XFormersAttnProcessor:
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class AttnProcessorNPU:
|
||||||
|
|
||||||
|
r"""
|
||||||
|
Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
|
||||||
|
fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
|
||||||
|
not significant.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
if not is_torch_npu_available():
|
||||||
|
raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
attn: Attention,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
temb: Optional[torch.FloatTensor] = None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||||
|
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||||
|
deprecate("scale", "1.0.0", deprecation_message)
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
if attn.spatial_norm is not None:
|
||||||
|
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||||
|
|
||||||
|
input_ndim = hidden_states.ndim
|
||||||
|
|
||||||
|
if input_ndim == 4:
|
||||||
|
batch_size, channel, height, width = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||||
|
|
||||||
|
batch_size, sequence_length, _ = (
|
||||||
|
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||||
|
# scaled_dot_product_attention expects attention_mask shape to be
|
||||||
|
# (batch, heads, source_length, target_length)
|
||||||
|
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||||
|
|
||||||
|
if attn.group_norm is not None:
|
||||||
|
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||||
|
|
||||||
|
query = attn.to_q(hidden_states)
|
||||||
|
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
encoder_hidden_states = hidden_states
|
||||||
|
elif attn.norm_cross:
|
||||||
|
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||||
|
|
||||||
|
key = attn.to_k(encoder_hidden_states)
|
||||||
|
value = attn.to_v(encoder_hidden_states)
|
||||||
|
|
||||||
|
inner_dim = key.shape[-1]
|
||||||
|
head_dim = inner_dim // attn.heads
|
||||||
|
|
||||||
|
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||||
|
if query.dtype in (torch.float16, torch.bfloat16):
|
||||||
|
hidden_states = torch_npu.npu_fusion_attention(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
attn.heads,
|
||||||
|
input_layout="BNSD",
|
||||||
|
pse=None,
|
||||||
|
atten_mask=attention_mask,
|
||||||
|
scale=1.0 / math.sqrt(query.shape[-1]),
|
||||||
|
pre_tockens=65536,
|
||||||
|
next_tockens=65536,
|
||||||
|
keep_prob=1.0,
|
||||||
|
sync=False,
|
||||||
|
inner_precise=0,
|
||||||
|
)[0]
|
||||||
|
else:
|
||||||
|
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||||
|
hidden_states = F.scaled_dot_product_attention(
|
||||||
|
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
|
hidden_states = hidden_states.to(query.dtype)
|
||||||
|
|
||||||
|
# linear proj
|
||||||
|
hidden_states = attn.to_out[0](hidden_states)
|
||||||
|
# dropout
|
||||||
|
hidden_states = attn.to_out[1](hidden_states)
|
||||||
|
|
||||||
|
if input_ndim == 4:
|
||||||
|
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||||
|
|
||||||
|
if attn.residual_connection:
|
||||||
|
hidden_states = hidden_states + residual
|
||||||
|
|
||||||
|
hidden_states = hidden_states / attn.rescale_output_factor
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class AttnProcessor2_0:
|
class AttnProcessor2_0:
|
||||||
r"""
|
r"""
|
||||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||||
|
|||||||
@@ -272,6 +272,36 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
if self._supports_gradient_checkpointing:
|
if self._supports_gradient_checkpointing:
|
||||||
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
||||||
|
|
||||||
|
def set_use_npu_flash_attention(self, valid: bool) -> None:
|
||||||
|
r"""
|
||||||
|
Set the switch for the npu flash attention.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def fn_recursive_set_npu_flash_attention(module: torch.nn.Module):
|
||||||
|
if hasattr(module, "set_use_npu_flash_attention"):
|
||||||
|
module.set_use_npu_flash_attention(valid)
|
||||||
|
|
||||||
|
for child in module.children():
|
||||||
|
fn_recursive_set_npu_flash_attention(child)
|
||||||
|
|
||||||
|
for module in self.children():
|
||||||
|
if isinstance(module, torch.nn.Module):
|
||||||
|
fn_recursive_set_npu_flash_attention(module)
|
||||||
|
|
||||||
|
def enable_npu_flash_attention(self) -> None:
|
||||||
|
r"""
|
||||||
|
Enable npu flash attention from torch_npu
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.set_use_npu_flash_attention(True)
|
||||||
|
|
||||||
|
def disable_npu_flash_attention(self) -> None:
|
||||||
|
r"""
|
||||||
|
disable npu flash attention from torch_npu
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.set_use_npu_flash_attention(False)
|
||||||
|
|
||||||
def set_use_memory_efficient_attention_xformers(
|
def set_use_memory_efficient_attention_xformers(
|
||||||
self, valid: bool, attention_op: Optional[Callable] = None
|
self, valid: bool, attention_op: Optional[Callable] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
@@ -30,9 +30,14 @@ from huggingface_hub.utils import is_jinja_available
|
|||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
from diffusers.models import UNet2DConditionModel
|
from diffusers.models import UNet2DConditionModel
|
||||||
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0, XFormersAttnProcessor
|
from diffusers.models.attention_processor import (
|
||||||
|
AttnProcessor,
|
||||||
|
AttnProcessor2_0,
|
||||||
|
AttnProcessorNPU,
|
||||||
|
XFormersAttnProcessor,
|
||||||
|
)
|
||||||
from diffusers.training_utils import EMAModel
|
from diffusers.training_utils import EMAModel
|
||||||
from diffusers.utils import is_xformers_available, logging
|
from diffusers.utils import is_torch_npu_available, is_xformers_available, logging
|
||||||
from diffusers.utils.testing_utils import (
|
from diffusers.utils.testing_utils import (
|
||||||
CaptureLogger,
|
CaptureLogger,
|
||||||
get_python_version,
|
get_python_version,
|
||||||
@@ -300,6 +305,53 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
assert str(error.exception) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"
|
assert str(error.exception) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"
|
||||||
|
|
||||||
|
@unittest.skipIf(
|
||||||
|
torch_device != "npu" or not is_torch_npu_available(),
|
||||||
|
reason="torch npu flash attention is only available with NPU and `torch_npu` installed",
|
||||||
|
)
|
||||||
|
def test_set_torch_npu_flash_attn_processor_determinism(self):
|
||||||
|
torch.use_deterministic_algorithms(False)
|
||||||
|
if self.forward_requires_fresh_args:
|
||||||
|
model = self.model_class(**self.init_dict)
|
||||||
|
else:
|
||||||
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
if not hasattr(model, "set_attn_processor"):
|
||||||
|
# If not has `set_attn_processor`, skip test
|
||||||
|
return
|
||||||
|
|
||||||
|
model.set_default_attn_processor()
|
||||||
|
assert all(type(proc) == AttnProcessorNPU for proc in model.attn_processors.values())
|
||||||
|
with torch.no_grad():
|
||||||
|
if self.forward_requires_fresh_args:
|
||||||
|
output = model(**self.inputs_dict(0))[0]
|
||||||
|
else:
|
||||||
|
output = model(**inputs_dict)[0]
|
||||||
|
|
||||||
|
model.enable_npu_flash_attention()
|
||||||
|
assert all(type(proc) == AttnProcessorNPU for proc in model.attn_processors.values())
|
||||||
|
with torch.no_grad():
|
||||||
|
if self.forward_requires_fresh_args:
|
||||||
|
output_2 = model(**self.inputs_dict(0))[0]
|
||||||
|
else:
|
||||||
|
output_2 = model(**inputs_dict)[0]
|
||||||
|
|
||||||
|
model.set_attn_processor(AttnProcessorNPU())
|
||||||
|
assert all(type(proc) == AttnProcessorNPU for proc in model.attn_processors.values())
|
||||||
|
with torch.no_grad():
|
||||||
|
if self.forward_requires_fresh_args:
|
||||||
|
output_3 = model(**self.inputs_dict(0))[0]
|
||||||
|
else:
|
||||||
|
output_3 = model(**inputs_dict)[0]
|
||||||
|
|
||||||
|
torch.use_deterministic_algorithms(True)
|
||||||
|
|
||||||
|
assert torch.allclose(output, output_2, atol=self.base_precision)
|
||||||
|
assert torch.allclose(output, output_3, atol=self.base_precision)
|
||||||
|
assert torch.allclose(output_2, output_3, atol=self.base_precision)
|
||||||
|
|
||||||
@unittest.skipIf(
|
@unittest.skipIf(
|
||||||
torch_device != "cuda" or not is_xformers_available(),
|
torch_device != "cuda" or not is_xformers_available(),
|
||||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||||
|
|||||||
Reference in New Issue
Block a user