Changed variable name from "h" to "hidden_states" (#285)
* Changed variable name from "h" to "hidden_states" Per issue #198 , changed variable name from "h" to "hidden_states" in the forward function only. I am happy to change any other variable names, please advise recommended new names. * Update src/diffusers/models/resnet.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
committed by
GitHub
parent
034673bbeb
commit
1f196a09fe
@@ -328,39 +328,39 @@ class ResnetBlock2D(nn.Module):
|
|||||||
if self.use_nin_shortcut:
|
if self.use_nin_shortcut:
|
||||||
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||||
|
|
||||||
def forward(self, x, temb, hey=False):
|
def forward(self, x, temb):
|
||||||
h = x
|
hidden_states = x
|
||||||
|
|
||||||
# make sure hidden states is in float32
|
# make sure hidden states is in float32
|
||||||
# when running in half-precision
|
# when running in half-precision
|
||||||
h = self.norm1(h.float()).type(h.dtype)
|
hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype)
|
||||||
h = self.nonlinearity(h)
|
hidden_states = self.nonlinearity(hidden_states)
|
||||||
|
|
||||||
if self.upsample is not None:
|
if self.upsample is not None:
|
||||||
x = self.upsample(x)
|
x = self.upsample(x)
|
||||||
h = self.upsample(h)
|
hidden_states = self.upsample(hidden_states)
|
||||||
elif self.downsample is not None:
|
elif self.downsample is not None:
|
||||||
x = self.downsample(x)
|
x = self.downsample(x)
|
||||||
h = self.downsample(h)
|
hidden_states = self.downsample(hidden_states)
|
||||||
|
|
||||||
h = self.conv1(h)
|
hidden_states = self.conv1(hidden_states)
|
||||||
|
|
||||||
if temb is not None:
|
if temb is not None:
|
||||||
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
|
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
|
||||||
h = h + temb
|
hidden_states = hidden_states + temb
|
||||||
|
|
||||||
# make sure hidden states is in float32
|
# make sure hidden states is in float32
|
||||||
# when running in half-precision
|
# when running in half-precision
|
||||||
h = self.norm2(h.float()).type(h.dtype)
|
hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype)
|
||||||
h = self.nonlinearity(h)
|
hidden_states = self.nonlinearity(hidden_states)
|
||||||
|
|
||||||
h = self.dropout(h)
|
hidden_states = self.dropout(hidden_states)
|
||||||
h = self.conv2(h)
|
hidden_states = self.conv2(hidden_states)
|
||||||
|
|
||||||
if self.conv_shortcut is not None:
|
if self.conv_shortcut is not None:
|
||||||
x = self.conv_shortcut(x)
|
x = self.conv_shortcut(x)
|
||||||
|
|
||||||
out = (x + h) / self.output_scale_factor
|
out = (x + hidden_states) / self.output_scale_factor
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user