From 1f196a09fea0bb62308a31b66f1c398ff851959d Mon Sep 17 00:00:00 2001 From: Juan Carrasquilla <68667541+JC-swEng@users.noreply.github.com> Date: Thu, 1 Sep 2022 04:31:02 -0500 Subject: [PATCH] Changed variable name from "h" to "hidden_states" (#285) * Changed variable name from "h" to "hidden_states" Per issue #198 , changed variable name from "h" to "hidden_states" in the forward function only. I am happy to change any other variable names, please advise recommended new names. * Update src/diffusers/models/resnet.py Co-authored-by: Patrick von Platen Co-authored-by: Patrick von Platen --- src/diffusers/models/resnet.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index acce7b574e..50382bcab3 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -328,39 +328,39 @@ class ResnetBlock2D(nn.Module): if self.use_nin_shortcut: self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - def forward(self, x, temb, hey=False): - h = x + def forward(self, x, temb): + hidden_states = x # make sure hidden states is in float32 # when running in half-precision - h = self.norm1(h.float()).type(h.dtype) - h = self.nonlinearity(h) + hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: x = self.upsample(x) - h = self.upsample(h) + hidden_states = self.upsample(hidden_states) elif self.downsample is not None: x = self.downsample(x) - h = self.downsample(h) + hidden_states = self.downsample(hidden_states) - h = self.conv1(h) + hidden_states = self.conv1(hidden_states) if temb is not None: temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] - h = h + temb + hidden_states = hidden_states + temb # make sure hidden states is in float32 # when running in half-precision - h = self.norm2(h.float()).type(h.dtype) - h = self.nonlinearity(h) + hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.nonlinearity(hidden_states) - h = self.dropout(h) - h = self.conv2(h) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: x = self.conv_shortcut(x) - out = (x + h) / self.output_scale_factor + out = (x + hidden_states) / self.output_scale_factor return out