diff --git a/tests/unittest/_torch/misc/test_autotuner.py b/tests/unittest/_torch/misc/test_autotuner.py index 7212f491e6..c5021ef7f9 100644 --- a/tests/unittest/_torch/misc/test_autotuner.py +++ b/tests/unittest/_torch/misc/test_autotuner.py @@ -738,6 +738,8 @@ def _distributed_worker_function(world_size, strategy): runner = DistributedGemmRunner(prefer_tactics=prefer_tactics) config = TuningConfig(distributed_tuning_strategy=strategy) + # Keep temp_dir in function scope to prevent premature garbage collection + temp_dir = None if rank == 0: temp_dir = tempfile.TemporaryDirectory() # rank 0 should broadcast the cache path to all ranks @@ -782,6 +784,7 @@ def _distributed_worker_function(world_size, strategy): else: assert False, f"Unknown strategy: {strategy}" + dist.barrier() return True