|
|
|
@@ -131,11 +131,15 @@ class SDFunctionTesterMixin:
|
|
|
|
|
|
|
|
|
|
inputs = self.get_dummy_inputs(torch_device)
|
|
|
|
|
inputs["return_dict"] = False
|
|
|
|
|
inputs["output_type"] = "np"
|
|
|
|
|
|
|
|
|
|
output = pipe(**inputs)[0]
|
|
|
|
|
|
|
|
|
|
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
|
|
|
|
|
inputs = self.get_dummy_inputs(torch_device)
|
|
|
|
|
inputs["return_dict"] = False
|
|
|
|
|
inputs["output_type"] = "np"
|
|
|
|
|
|
|
|
|
|
output_freeu = pipe(**inputs)[0]
|
|
|
|
|
|
|
|
|
|
assert not np.allclose(
|
|
|
|
@@ -150,6 +154,8 @@ class SDFunctionTesterMixin:
|
|
|
|
|
|
|
|
|
|
inputs = self.get_dummy_inputs(torch_device)
|
|
|
|
|
inputs["return_dict"] = False
|
|
|
|
|
inputs["output_type"] = "np"
|
|
|
|
|
|
|
|
|
|
output = pipe(**inputs)[0]
|
|
|
|
|
|
|
|
|
|
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
|
|
|
|
@@ -162,6 +168,8 @@ class SDFunctionTesterMixin:
|
|
|
|
|
|
|
|
|
|
inputs = self.get_dummy_inputs(torch_device)
|
|
|
|
|
inputs["return_dict"] = False
|
|
|
|
|
inputs["output_type"] = "np"
|
|
|
|
|
|
|
|
|
|
output_no_freeu = pipe(**inputs)[0]
|
|
|
|
|
assert np.allclose(
|
|
|
|
|
output, output_no_freeu, atol=1e-2
|
|
|
|
@@ -1144,24 +1152,20 @@ class PipelineTesterMixin:
|
|
|
|
|
self.assertLess(
|
|
|
|
|
max_diff, expected_max_diff, "running CPU offloading 2nd time should not affect the inference results"
|
|
|
|
|
)
|
|
|
|
|
offloaded_modules = {
|
|
|
|
|
k: v
|
|
|
|
|
offloaded_modules = [
|
|
|
|
|
v
|
|
|
|
|
for k, v in pipe.components.items()
|
|
|
|
|
if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
|
|
|
|
|
}
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
all(v.device.type == "cpu" for v in offloaded_modules.values()),
|
|
|
|
|
f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'cpu']}",
|
|
|
|
|
]
|
|
|
|
|
(
|
|
|
|
|
self.assertTrue(all(v.device.type == "cpu" for v in offloaded_modules)),
|
|
|
|
|
f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
offloaded_modules_with_incorrect_hooks = {}
|
|
|
|
|
for k, v in offloaded_modules.items():
|
|
|
|
|
if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.CpuOffload):
|
|
|
|
|
offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)
|
|
|
|
|
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
len(offloaded_modules_with_incorrect_hooks) == 0,
|
|
|
|
|
f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}",
|
|
|
|
|
offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")]
|
|
|
|
|
(
|
|
|
|
|
self.assertTrue(all(isinstance(v, accelerate.hooks.CpuOffload) for v in offloaded_modules_with_hooks)),
|
|
|
|
|
f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v, accelerate.hooks.CpuOffload)]}",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@unittest.skipIf(
|
|
|
|
@@ -1193,23 +1197,22 @@ class PipelineTesterMixin:
|
|
|
|
|
self.assertLess(
|
|
|
|
|
max_diff, expected_max_diff, "running sequential offloading second time should have the inference results"
|
|
|
|
|
)
|
|
|
|
|
offloaded_modules = {
|
|
|
|
|
k: v
|
|
|
|
|
offloaded_modules = [
|
|
|
|
|
v
|
|
|
|
|
for k, v in pipe.components.items()
|
|
|
|
|
if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
|
|
|
|
|
}
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
all(v.device.type == "meta" for v in offloaded_modules.values()),
|
|
|
|
|
f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'meta']}",
|
|
|
|
|
]
|
|
|
|
|
(
|
|
|
|
|
self.assertTrue(all(v.device.type == "meta" for v in offloaded_modules)),
|
|
|
|
|
f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'meta']}",
|
|
|
|
|
)
|
|
|
|
|
offloaded_modules_with_incorrect_hooks = {}
|
|
|
|
|
for k, v in offloaded_modules.items():
|
|
|
|
|
if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook):
|
|
|
|
|
offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)
|
|
|
|
|
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
len(offloaded_modules_with_incorrect_hooks) == 0,
|
|
|
|
|
f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}",
|
|
|
|
|
offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")]
|
|
|
|
|
(
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
all(isinstance(v, accelerate.hooks.AlignDevicesHook) for v in offloaded_modules_with_hooks)
|
|
|
|
|
),
|
|
|
|
|
f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v, accelerate.hooks.AlignDevicesHook)]}",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@unittest.skipIf(
|
|
|
|
|