make style
This commit is contained in:
+2
-2
@@ -45,7 +45,6 @@ def preprocess_image(image):
|
|||||||
|
|
||||||
|
|
||||||
def preprocess_mask(mask, scale_factor=8):
|
def preprocess_mask(mask, scale_factor=8):
|
||||||
|
|
||||||
if not isinstance(mask, torch.FloatTensor):
|
if not isinstance(mask, torch.FloatTensor):
|
||||||
mask = mask.convert("L")
|
mask = mask.convert("L")
|
||||||
w, h = mask.size
|
w, h = mask.size
|
||||||
@@ -65,7 +64,8 @@ def preprocess_mask(mask, scale_factor=8):
|
|||||||
mask = mask.permute(0, 3, 1, 2)
|
mask = mask.permute(0, 3, 1, 2)
|
||||||
elif mask.shape[1] not in valid_mask_channel_sizes:
|
elif mask.shape[1] not in valid_mask_channel_sizes:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension, but received mask of shape {tuple(mask.shape)}"
|
f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension,"
|
||||||
|
f" but received mask of shape {tuple(mask.shape)}"
|
||||||
)
|
)
|
||||||
# (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
|
# (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
|
||||||
mask = mask.mean(dim=1, keepdim=True)
|
mask = mask.mean(dim=1, keepdim=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user