Correction for non-integral image resolutions with quantizations other than float32 (#7356)
* Correction for non-integral image resolutions with quantizations other than float32. * Support for training, and use of diffusers-style casting.
This commit is contained in:
@@ -521,9 +521,11 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
|
|||||||
if isinstance(block, SDCascadeResBlock):
|
if isinstance(block, SDCascadeResBlock):
|
||||||
skip = level_outputs[i] if k == 0 and i > 0 else None
|
skip = level_outputs[i] if k == 0 and i > 0 else None
|
||||||
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
|
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
|
||||||
|
orig_type = x.dtype
|
||||||
x = torch.nn.functional.interpolate(
|
x = torch.nn.functional.interpolate(
|
||||||
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
|
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
|
||||||
)
|
)
|
||||||
|
x = x.to(orig_type)
|
||||||
x = torch.utils.checkpoint.checkpoint(
|
x = torch.utils.checkpoint.checkpoint(
|
||||||
create_custom_forward(block), x, skip, use_reentrant=False
|
create_custom_forward(block), x, skip, use_reentrant=False
|
||||||
)
|
)
|
||||||
@@ -547,9 +549,11 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
|
|||||||
if isinstance(block, SDCascadeResBlock):
|
if isinstance(block, SDCascadeResBlock):
|
||||||
skip = level_outputs[i] if k == 0 and i > 0 else None
|
skip = level_outputs[i] if k == 0 and i > 0 else None
|
||||||
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
|
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
|
||||||
|
orig_type = x.dtype
|
||||||
x = torch.nn.functional.interpolate(
|
x = torch.nn.functional.interpolate(
|
||||||
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
|
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
|
||||||
)
|
)
|
||||||
|
x = x.to(orig_type)
|
||||||
x = block(x, skip)
|
x = block(x, skip)
|
||||||
elif isinstance(block, SDCascadeAttnBlock):
|
elif isinstance(block, SDCascadeAttnBlock):
|
||||||
x = block(x, clip)
|
x = block(x, clip)
|
||||||
|
|||||||
Reference in New Issue
Block a user