Skip to content

Commit d07c8d9

Browse files
committed
Updated random seeding
1 parent 5ce8047 commit d07c8d9

2 files changed

Lines changed: 5 additions & 3 deletions

File tree

src/thunder/benchmark.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,6 @@ def run_benchmark(cfg: DictConfig, model_cls: Callable = None) -> None:
119119
from .utils.dice_loss import multiclass_dice_loss
120120
from .utils.utils import set_seed
121121

122-
# Setting the random seed
123-
set_seed(UtilsConstants.DEFAULT_SEED.value)
124-
125122
# Getting device
126123
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
127124

@@ -177,6 +174,9 @@ def run_benchmark(cfg: DictConfig, model_cls: Callable = None) -> None:
177174
shutil.rmtree(res_folder)
178175
os.makedirs(res_folder)
179176

177+
# Setting the random seed
178+
set_seed(UtilsConstants.DEFAULT_SEED.value)
179+
180180
if task_type in ["linear_probing", "segmentation"]:
181181
# Model checkpoints folder
182182
ckpt_folder = os.path.join(

src/thunder/utils/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ def set_seed(seed: int) -> None:
192192
torch.cuda.manual_seed(seed)
193193
torch.backends.cudnn.deterministic = True
194194
torch.backends.cudnn.benchmark = False
195+
torch.use_deterministic_algorithms(True)
196+
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
195197

196198

197199
def wb_mask(

0 commit comments

Comments
 (0)