Fix test for consistency decoder. (#7746)

update
This commit is contained in:
Dhruv Nair
2024-04-24 12:28:11 +05:30
committed by GitHub
parent 88018fcf20
commit 9ef43f38d4
+1 -1
View File
@@ -1153,5 +1153,5 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
with torch.no_grad():
for shape in shapes:
image = torch.zeros(shape, device=torch_device)
image = torch.zeros(shape, device=torch_device, dtype=pipe.vae.dtype)
pipe.vae.decode(image)