Skip to content

Commit dc1af1c

Browse files
committed
Replace direct np.random.* calls with np.random.RandomState instances
Replace global `np.random` function calls with proper `RandomState` instances for reproducibility: - transforms/utils.py: Replace `np.random.random.__self__` (3 sites) with `np.random.RandomState()` in generate_pos_neg_label_crop_centers, weighted_patch_samples, and get_extreme_points - transforms/signal/array.py: Replace `np.random.choice` (2 sites) with `self.R.choice` in SignalRandAddSine and SignalRandAddSquarePulsePartial (classes already inherit Randomizable) - data/synthetic.py: Replace `np.random.random.__self__` (2 sites) with `np.random.RandomState()` in create_test_image_2d/3d - data/utils.py: Replace `np.random.randint` fallback with `np.random.RandomState().randint` in get_random_patch - utils/ordering.py: Replace `np.random.shuffle` with `np.random.RandomState().shuffle` in random ordering Ref #6888 Signed-off-by: SexyERIC0723 <haoyuwang144@gmail.com>
1 parent a8176f1 commit dc1af1c

5 files changed

Lines changed: 9 additions & 9 deletions

File tree

monai/data/synthetic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def create_test_image_2d(
6262
raise ValueError(f"the minimal size {min_size} of the image should be larger than `2 * rad_max` 2x{rad_max}.")
6363

6464
image = np.zeros((height, width))
65-
rs: np.random.RandomState = np.random.random.__self__ if random_state is None else random_state # type: ignore
65+
rs: np.random.RandomState = np.random.RandomState() if random_state is None else random_state
6666

6767
for _ in range(num_objs):
6868
x = rs.randint(rad_max, height - rad_max)
@@ -139,7 +139,7 @@ def create_test_image_3d(
139139
raise ValueError(f"the minimal size {min_size} of the image should be larger than `2 * rad_max` 2x{rad_max}.")
140140

141141
image = np.zeros((height, width, depth))
142-
rs: np.random.RandomState = np.random.random.__self__ if random_state is None else random_state # type: ignore
142+
rs: np.random.RandomState = np.random.RandomState() if random_state is None else random_state
143143

144144
for _ in range(num_objs):
145145
x = rs.randint(rad_max, height - rad_max)

monai/data/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def get_random_patch(
120120
"""
121121

122122
# choose the minimal corner of the patch
123-
rand_int = np.random.randint if rand_state is None else rand_state.randint
123+
rand_int = np.random.RandomState().randint if rand_state is None else rand_state.randint
124124
min_corner = tuple(rand_int(0, ms - ps + 1) if ms > ps else 0 for ms, ps in zip(dims, patch_size))
125125

126126
# create the slices for each dimension which define the patch in the source array

monai/transforms/signal/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:
273273
data = convert_to_tensor(self.freqs * time_partial)
274274
sine_partial = self.magnitude * torch.sin(data)
275275

276-
loc = np.random.choice(range(length))
276+
loc = self.R.choice(range(length))
277277
signal = paste(signal, sine_partial, (loc,))
278278

279279
return signal
@@ -354,7 +354,7 @@ def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:
354354
time_partial = np.arange(0, round(self.fracs * length), 1)
355355
squaredpulse_partial = self.magnitude * squarepulse(self.freqs * time_partial)
356356

357-
loc = np.random.choice(range(length))
357+
loc = self.R.choice(range(length))
358358
signal = paste(signal, squaredpulse_partial, (loc,))
359359

360360
return signal

monai/transforms/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,7 @@ def generate_pos_neg_label_crop_centers(
666666
667667
"""
668668
if rand_state is None:
669-
rand_state = np.random.random.__self__ # type: ignore
669+
rand_state = np.random.RandomState()
670670

671671
centers = []
672672
fg_indices = np.asarray(fg_indices) if isinstance(fg_indices, Sequence) else fg_indices
@@ -721,7 +721,7 @@ def generate_label_classes_crop_centers(
721721
722722
"""
723723
if rand_state is None:
724-
rand_state = np.random.random.__self__ # type: ignore
724+
rand_state = np.random.RandomState()
725725

726726
if num_samples < 1:
727727
raise ValueError(f"num_samples must be an int number and greater than 0, got {num_samples}.")
@@ -1585,7 +1585,7 @@ def get_extreme_points(
15851585
"""
15861586
check_non_lazy_pending_ops(img, name="get_extreme_points")
15871587
if rand_state is None:
1588-
rand_state = np.random.random.__self__ # type: ignore
1588+
rand_state = np.random.RandomState()
15891589
indices = where(img != background)
15901590
if np.size(indices[0]) == 0:
15911591
raise ValueError("get_extreme_points: no foreground object in mask!")

monai/utils/ordering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,6 @@ def random_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray:
202202
idx.append((r, c))
203203

204204
idx_np = np.array(idx)
205-
np.random.shuffle(idx_np)
205+
np.random.RandomState().shuffle(idx_np)
206206

207207
return idx_np

0 commit comments

Comments
 (0)