[Mochi-1] ensuring to compute the fourier features in FP32 in Mochi encoder (#10031)
compute fourier features in FP32.
This commit is contained in:
parent
6b288ec44d
commit
c96bfa5c80
@ -437,7 +437,8 @@ class FourierFeatures(nn.Module):
|
|||||||
|
|
||||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||||
r"""Forward method of the `FourierFeatures` class."""
|
r"""Forward method of the `FourierFeatures` class."""
|
||||||
|
original_dtype = inputs.dtype
|
||||||
|
inputs = inputs.to(torch.float32)
|
||||||
num_channels = inputs.shape[1]
|
num_channels = inputs.shape[1]
|
||||||
num_freqs = (self.stop - self.start) // self.step
|
num_freqs = (self.stop - self.start) // self.step
|
||||||
|
|
||||||
@ -450,7 +451,7 @@ class FourierFeatures(nn.Module):
|
|||||||
# Scale channels by frequency.
|
# Scale channels by frequency.
|
||||||
h = w * h
|
h = w * h
|
||||||
|
|
||||||
return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1)
|
return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1).to(original_dtype)
|
||||||
|
|
||||||
|
|
||||||
class MochiEncoder3D(nn.Module):
|
class MochiEncoder3D(nn.Module):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user