Fix test for consistency decoder. (#7746)

update
This commit is contained in:
Dhruv Nair 2024-04-24 12:28:11 +05:30 committed by GitHub
parent 88018fcf20
commit 9ef43f38d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1153,5 +1153,5 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
with torch.no_grad():
for shape in shapes:
image = torch.zeros(shape, device=torch_device)
image = torch.zeros(shape, device=torch_device, dtype=pipe.vae.dtype)
pipe.vae.decode(image)