Skip to content

Commit 2bd0f02

Browse files
committed
Revert global RNG changes that broke reproducibility
The previous commit replaced global np.random.* calls with unseeded np.random.RandomState() instances. This broke 12 tests because callers relying on np.random.seed() for determinism no longer get reproducible results from these public API functions. Revert transforms/utils.py, data/synthetic.py, data/utils.py, and utils/ordering.py back to the global RNG. These utility functions accept an optional rand_state parameter -- when None, they intentionally fall back to the global RNG to respect np.random.seed(). The signal/array.py fix (self.R.choice) is retained because those classes already inherit from Randomizable and use self.R for all other random operations -- np.random.choice was an inconsistency. Signed-off-by: SexyERIC0723 <haoyuwang144@gmail.com>
1 parent dc1af1c commit 2bd0f02

4 files changed

Lines changed: 7 additions & 7 deletions

File tree

monai/data/synthetic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def create_test_image_2d(
6262
raise ValueError(f"the minimal size {min_size} of the image should be larger than `2 * rad_max` 2x{rad_max}.")
6363

6464
image = np.zeros((height, width))
65-
rs: np.random.RandomState = np.random.RandomState() if random_state is None else random_state
65+
rs: np.random.RandomState = np.random.random.__self__ if random_state is None else random_state # type: ignore
6666

6767
for _ in range(num_objs):
6868
x = rs.randint(rad_max, height - rad_max)
@@ -139,7 +139,7 @@ def create_test_image_3d(
139139
raise ValueError(f"the minimal size {min_size} of the image should be larger than `2 * rad_max` 2x{rad_max}.")
140140

141141
image = np.zeros((height, width, depth))
142-
rs: np.random.RandomState = np.random.RandomState() if random_state is None else random_state
142+
rs: np.random.RandomState = np.random.random.__self__ if random_state is None else random_state # type: ignore
143143

144144
for _ in range(num_objs):
145145
x = rs.randint(rad_max, height - rad_max)

monai/data/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def get_random_patch(
120120
"""
121121

122122
# choose the minimal corner of the patch
123-
rand_int = np.random.RandomState().randint if rand_state is None else rand_state.randint
123+
rand_int = np.random.randint if rand_state is None else rand_state.randint
124124
min_corner = tuple(rand_int(0, ms - ps + 1) if ms > ps else 0 for ms, ps in zip(dims, patch_size))
125125

126126
# create the slices for each dimension which define the patch in the source array

monai/transforms/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,7 @@ def generate_pos_neg_label_crop_centers(
666666
667667
"""
668668
if rand_state is None:
669-
rand_state = np.random.RandomState()
669+
rand_state = np.random.random.__self__ # type: ignore
670670

671671
centers = []
672672
fg_indices = np.asarray(fg_indices) if isinstance(fg_indices, Sequence) else fg_indices
@@ -721,7 +721,7 @@ def generate_label_classes_crop_centers(
721721
722722
"""
723723
if rand_state is None:
724-
rand_state = np.random.RandomState()
724+
rand_state = np.random.random.__self__ # type: ignore
725725

726726
if num_samples < 1:
727727
raise ValueError(f"num_samples must be an int number and greater than 0, got {num_samples}.")
@@ -1585,7 +1585,7 @@ def get_extreme_points(
15851585
"""
15861586
check_non_lazy_pending_ops(img, name="get_extreme_points")
15871587
if rand_state is None:
1588-
rand_state = np.random.RandomState()
1588+
rand_state = np.random.random.__self__ # type: ignore
15891589
indices = where(img != background)
15901590
if np.size(indices[0]) == 0:
15911591
raise ValueError("get_extreme_points: no foreground object in mask!")

monai/utils/ordering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,6 @@ def random_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray:
202202
idx.append((r, c))
203203

204204
idx_np = np.array(idx)
205-
np.random.RandomState().shuffle(idx_np)
205+
np.random.shuffle(idx_np)
206206

207207
return idx_np

0 commit comments

Comments
 (0)