[Pipelines] Enable Wan VACE to run since single transformer (#12428)
* update * update * update * update * update
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -19,9 +20,15 @@ import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
WanVACEPipeline,
|
||||
WanVACETransformer3DModel,
|
||||
)
|
||||
|
||||
from ...testing_utils import enable_full_determinism
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
@@ -212,3 +219,81 @@ class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
)
|
||||
def test_save_load_float16(self):
|
||||
pass
|
||||
|
||||
def test_inference_with_only_transformer(self):
|
||||
components = self.get_dummy_components()
|
||||
components["transformer_2"] = None
|
||||
components["boundary_ratio"] = 0.0
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
video = pipe(**inputs).frames[0]
|
||||
assert video.shape == (17, 3, 16, 16)
|
||||
|
||||
def test_inference_with_only_transformer_2(self):
|
||||
components = self.get_dummy_components()
|
||||
components["transformer_2"] = components["transformer"]
|
||||
components["transformer"] = None
|
||||
|
||||
# FlowMatchEulerDiscreteScheduler doesn't support running low noise only scheduler
|
||||
# because starting timestep t == 1000 == boundary_timestep
|
||||
components["scheduler"] = UniPCMultistepScheduler(
|
||||
prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
|
||||
)
|
||||
|
||||
components["boundary_ratio"] = 1.0
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
video = pipe(**inputs).frames[0]
|
||||
assert video.shape == (17, 3, 16, 16)
|
||||
|
||||
def test_save_load_optional_components(self, expected_max_difference=1e-4):
|
||||
optional_component = ["transformer"]
|
||||
|
||||
components = self.get_dummy_components()
|
||||
components["transformer_2"] = components["transformer"]
|
||||
# FlowMatchEulerDiscreteScheduler doesn't support running low noise only scheduler
|
||||
# because starting timestep t == 1000 == boundary_timestep
|
||||
components["scheduler"] = UniPCMultistepScheduler(
|
||||
prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
|
||||
)
|
||||
for component in optional_component:
|
||||
components[component] = None
|
||||
|
||||
components["boundary_ratio"] = 1.0
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir, safe_serialization=False)
|
||||
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
|
||||
for component in pipe_loaded.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe_loaded.to(torch_device)
|
||||
pipe_loaded.set_progress_bar_config(disable=None)
|
||||
|
||||
for component in optional_component:
|
||||
assert getattr(pipe_loaded, component) is None, f"`{component}` did not stay set to None after loading."
|
||||
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output_loaded = pipe_loaded(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
|
||||
assert max_diff < expected_max_difference, "Outputs exceed expecpted maximum difference"
|
||||
|
||||
Reference in New Issue
Block a user