try fix for tests

This commit is contained in:
Aryan
2025-02-21 08:41:11 +01:00
parent 8546c9ed29
commit d080379e94
+2
View File
@@ -52,6 +52,8 @@ class LayerwiseCastingHook(ModelHook):
_is_stateful = False
def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
super().__init__()
self.storage_dtype = storage_dtype
self.compute_dtype = compute_dtype
self.non_blocking = non_blocking