Changed variable name from "h" to "hidden_states" (#285)
* Changed variable name from "h" to "hidden_states" Per issue #198 , changed variable name from "h" to "hidden_states" in the forward function only. I am happy to change any other variable names, please advise recommended new names. * Update src/diffusers/models/resnet.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
committed by
GitHub
parent
034673bbeb
commit
1f196a09fe
@@ -328,39 +328,39 @@ class ResnetBlock2D(nn.Module):
|
||||
if self.use_nin_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x, temb, hey=False):
|
||||
h = x
|
||||
def forward(self, x, temb):
|
||||
hidden_states = x
|
||||
|
||||
# make sure hidden states is in float32
|
||||
# when running in half-precision
|
||||
h = self.norm1(h.float()).type(h.dtype)
|
||||
h = self.nonlinearity(h)
|
||||
hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
if self.upsample is not None:
|
||||
x = self.upsample(x)
|
||||
h = self.upsample(h)
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
elif self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
h = self.downsample(h)
|
||||
hidden_states = self.downsample(hidden_states)
|
||||
|
||||
h = self.conv1(h)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
if temb is not None:
|
||||
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
|
||||
h = h + temb
|
||||
hidden_states = hidden_states + temb
|
||||
|
||||
# make sure hidden states is in float32
|
||||
# when running in half-precision
|
||||
h = self.norm2(h.float()).type(h.dtype)
|
||||
h = self.nonlinearity(h)
|
||||
hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
x = self.conv_shortcut(x)
|
||||
|
||||
out = (x + h) / self.output_scale_factor
|
||||
out = (x + hidden_states) / self.output_scale_factor
|
||||
|
||||
return out
|
||||
|
||||
|
||||
Reference in New Issue
Block a user