try fix for tests
This commit is contained in:
@@ -52,6 +52,8 @@ class LayerwiseCastingHook(ModelHook):
|
||||
_is_stateful = False
|
||||
|
||||
def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.storage_dtype = storage_dtype
|
||||
self.compute_dtype = compute_dtype
|
||||
self.non_blocking = non_blocking
|
||||
|
||||
Reference in New Issue
Block a user