make CI happy
This commit is contained in:
@@ -15,7 +15,6 @@
|
|||||||
|
|
||||||
import gc
|
import gc
|
||||||
import math
|
import math
|
||||||
import os
|
|
||||||
import tracemalloc
|
import tracemalloc
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -270,13 +269,6 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def test_gradient_checkpointing(self):
|
def test_gradient_checkpointing(self):
|
||||||
# enable deterministic behavior for gradient checkpointing
|
# enable deterministic behavior for gradient checkpointing
|
||||||
torch.use_deterministic_algorithms(True)
|
|
||||||
|
|
||||||
# from torch docs: "A handful of CUDA operations are nondeterministic if the CUDA version is 10.2 or greater,
|
|
||||||
# unless the environment variable CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8 is set."
|
|
||||||
# https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
|
|
||||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
|
||||||
|
|
||||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
model = self.model_class(**init_dict)
|
model = self.model_class(**init_dict)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
@@ -313,10 +305,6 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
|
|||||||
for name in grad_checkpointed:
|
for name in grad_checkpointed:
|
||||||
self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5))
|
self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5))
|
||||||
|
|
||||||
# disable deterministic behavior for gradient checkpointing
|
|
||||||
del os.environ["CUBLAS_WORKSPACE_CONFIG"]
|
|
||||||
torch.use_deterministic_algorithms(False)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(Patrick) - Re-add this test after having cleaned up LDM
|
# TODO(Patrick) - Re-add this test after having cleaned up LDM
|
||||||
# def test_output_pretrained_spatial_transformer(self):
|
# def test_output_pretrained_spatial_transformer(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user