Commit 5d91aa3

Add RandStain transform
1 parent 3e79d76 commit 5d91aa3

3 files changed

Lines changed: 288 additions & 3 deletions


setup.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
         "mkdocstrings-python==1.16.10",
         "mkdocs-material==9.6.11",
         "numpy==1.24.3",
+        "opencv-python==4.10.0.84",
         "omegaconf==2.3.0",
         "openpyxl==3.1.5",
         "pandas==2.2.3",

src/thunder/tasks/transformation_invariance.py

Lines changed: 1 addition & 1 deletion
@@ -160,7 +160,7 @@ def transformation_invariance(
 
     set_transform_seed(dataset_seed)
 
-    invariance_transforms = get_invariance_transforms()
+    invariance_transforms = get_invariance_transforms(dataset_name)
 
     # ---------------------------------------------------------------------
     # Main evaluation loop: compute similarities between original and transformed embeddings

src/thunder/utils/transforms.py

Lines changed: 286 additions & 2 deletions
@@ -9,6 +9,7 @@
 import torchvision.transforms.v2 as v2
 from PIL import Image
 from torchvision.transforms import ColorJitter, InterpolationMode
+import cv2
 
 try:
     import kornia.morphology as _kmorph
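Note: the new opencv-python dependency appears to be used only for the RGB/LAB round trip in the RandStain transform. A detail worth keeping in mind when reading the per-channel constants in the next hunk: for 8-bit images OpenCV rescales L to 0-255 and offsets a and b by +128, so the statistics presumably live in that 0-255 space rather than in CIE units. A standalone illustration (not part of the diff):

import cv2
import numpy as np

# OpenCV's 8-bit LAB convention: L is scaled to 0-255, a and b are offset by +128.
white = np.full((2, 2, 3), 255, dtype=np.uint8)  # pure-white RGB patch
lab = cv2.cvtColor(white, cv2.COLOR_RGB2LAB)
print(lab[0, 0])  # -> [255 128 128]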
@@ -21,6 +22,165 @@
 _AUG_RNG = random.Random()
 
 
+# Constants for RandStain transformation
+randstain_constants = {
+    "bach": {
+        "L": {
+            "avg": {"mean": 176.212, "std": 13.677, "distribution": "laplace"},
+            "std": {"mean": 38.368, "std": 8.739, "distribution": "laplace"},
+        },
+        "A": {
+            "avg": {"mean": 150.281, "std": 8.22, "distribution": "laplace"},
+            "std": {"mean": 10.942, "std": 3.087, "distribution": "norm"},
+        },
+        "B": {
+            "avg": {"mean": 101.909, "std": 8.559, "distribution": "laplace"},
+            "std": {"mean": 12.358, "std": 3.998, "distribution": "laplace"},
+        },
+    },
+    "ccrcc": {
+        "L": {
+            "avg": {"mean": 162.086, "std": 23.691, "distribution": "norm"},
+            "std": {"mean": 45.72, "std": 9.992, "distribution": "norm"},
+        },
+        "A": {
+            "avg": {"mean": 152.298, "std": 7.65, "distribution": "norm"},
+            "std": {"mean": 10.916, "std": 2.58, "distribution": "norm"},
+        },
+        "B": {
+            "avg": {"mean": 117.695, "std": 4.044, "distribution": "norm"},
+            "std": {"mean": 9.185, "std": 1.973, "distribution": "norm"},
+        },
+    },
+    "crc": {
+        "L": {
+            "avg": {"mean": 160.658, "std": 28.507, "distribution": "laplace"},
+            "std": {"mean": 35.602, "std": 12.938, "distribution": "laplace"},
+        },
+        "A": {
+            "avg": {"mean": 155.721, "std": 8.59, "distribution": "laplace"},
+            "std": {"mean": 8.644, "std": 3.222, "distribution": "norm"},
+        },
+        "B": {
+            "avg": {"mean": 113.101, "std": 4.914, "distribution": "laplace"},
+            "std": {"mean": 5.326, "std": 1.834, "distribution": "laplace"},
+        },
+    },
+    "esca": {
+        "L": {
+            "avg": {"mean": 162.977, "std": 28.455, "distribution": "laplace"},
+            "std": {"mean": 38.103, "std": 9.602, "distribution": "norm"},
+        },
+        "A": {
+            "avg": {"mean": 153.457, "std": 10.034, "distribution": "laplace"},
+            "std": {"mean": 9.525, "std": 3.285, "distribution": "norm"},
+        },
+        "B": {
+            "avg": {"mean": 112.414, "std": 7.494, "distribution": "norm"},
+            "std": {"mean": 5.663, "std": 1.883, "distribution": "norm"},
+        },
+    },
+    "patch_camelyon": {
+        "L": {
+            "avg": {"mean": 157.616, "std": 40.322, "distribution": "norm"},
+            "std": {"mean": 47.41, "std": 12.984, "distribution": "norm"},
+        },
+        "A": {
+            "avg": {"mean": 151.256, "std": 10.978, "distribution": "norm"},
+            "std": {"mean": 7.997, "std": 3.214, "distribution": "laplace"},
+        },
+        "B": {
+            "avg": {"mean": 113.587, "std": 12.046, "distribution": "laplace"},
+            "std": {"mean": 6.327, "std": 2.779, "distribution": "laplace"},
+        },
+    },
+    "tcga_crc_msi": {
+        "L": {
+            "avg": {"mean": 157.412, "std": 17.274, "distribution": "norm"},
+            "std": {"mean": 41.626, "std": 8.558, "distribution": "norm"},
+        },
+        "A": {
+            "avg": {"mean": 155.497, "std": 4.807, "distribution": "norm"},
+            "std": {"mean": 8.973, "std": 2.735, "distribution": "norm"},
+        },
+        "B": {
+            "avg": {"mean": 113.043, "std": 4.678, "distribution": "laplace"},
+            "std": {"mean": 5.587, "std": 1.552, "distribution": "laplace"},
+        },
+    },
+    "tcga_tils": {
+        "L": {
+            "avg": {"mean": 159.268, "std": 33.309, "distribution": "norm"},
+            "std": {"mean": 40.325, "std": 12.098, "distribution": "norm"},
+        },
+        "A": {
+            "avg": {"mean": 151.63, "std": 9.875, "distribution": "norm"},
+            "std": {"mean": 8.519, "std": 3.292, "distribution": "norm"},
+        },
+        "B": {
+            "avg": {"mean": 117.799, "std": 6.768, "distribution": "norm"},
+            "std": {"mean": 7.612, "std": 2.546, "distribution": "laplace"},
+        },
+    },
+    "tcga_uniform": {
+        "L": {
+            "avg": {"mean": 140.328, "std": 26.043, "distribution": "norm"},
+            "std": {"mean": 42.271, "std": 8.964, "distribution": "norm"},
+        },
+        "A": {
+            "avg": {"mean": 156.3, "std": 7.71, "distribution": "norm"},
+            "std": {"mean": 7.451, "std": 2.719, "distribution": "norm"},
+        },
+        "B": {
+            "avg": {"mean": 114.37, "std": 5.652, "distribution": "norm"},
+            "std": {"mean": 5.814, "std": 1.605, "distribution": "norm"},
+        },
+    },
+    "wilds": {
+        "L": {
+            "avg": {"mean": 169.551, "std": 32.673, "distribution": "norm"},
+            "std": {"mean": 34.248, "std": 11.67, "distribution": "norm"},
+        },
+        "A": {
+            "avg": {"mean": 148.907, "std": 7.386, "distribution": "norm"},
+            "std": {"mean": 6.937, "std": 2.683, "distribution": "laplace"},
+        },
+        "B": {
+            "avg": {"mean": 116.235, "std": 5.245, "distribution": "laplace"},
+            "std": {"mean": 5.577, "std": 1.418, "distribution": "norm"},
+        },
+    },
+    "break_his": {
+        "L": {
+            "avg": {"mean": 184.174, "std": 15.589, "distribution": "laplace"},
+            "std": {"mean": 25.219, "std": 7.311, "distribution": "norm"},
+        },
+        "A": {
+            "avg": {"mean": 149.7, "std": 12.966, "distribution": "laplace"},
+            "std": {"mean": 7.763, "std": 3.242, "distribution": "norm"},
+        },
+        "B": {
+            "avg": {"mean": 116.526, "std": 7.479, "distribution": "laplace"},
+            "std": {"mean": 5.346, "std": 1.618, "distribution": "norm"},
+        },
+    },
+    "mhist": {
+        "L": {
+            "avg": {"mean": 179.178, "std": 16.974, "distribution": "norm"},
+            "std": {"mean": 51.886, "std": 5.499, "distribution": "norm"},
+        },
+        "A": {
+            "avg": {"mean": 142.941, "std": 4.153, "distribution": "norm"},
+            "std": {"mean": 9.835, "std": 1.666, "distribution": "norm"},
+        },
+        "B": {
+            "avg": {"mean": 114.176, "std": 3.819, "distribution": "norm"},
+            "std": {"mean": 9.391, "std": 1.522, "distribution": "norm"},
+        },
+    },
+}
+
+
 def set_transform_seed(seed: int) -> None:
     """
     Seed only the augmentation RNG. Call this once per‐dataset before you apply any of the get_invariance_transforms().
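For each dataset and each LAB channel, the dict above stores the mean and spread of the per-image channel average ("avg") and channel standard deviation ("std"), plus the distribution family fitted to each. A rough sketch of how one target value could be drawn from a single entry, assuming numpy and the std_hyper = -0.3 default that _random_randstain uses further down in this diff (illustrative only, not code from the commit):

import numpy as np

# Hypothetical example: draw a target mean for the L channel of "bach".
entry = {"mean": 176.212, "std": 13.677, "distribution": "laplace"}
std_hyper = -0.3                    # shrinks the sampling spread by 30%
rng = np.random.RandomState(0)      # placeholder seed
scale = entry["std"] * (1.0 + std_hyper)
if entry["distribution"] in ("norm", "normal"):
    target_l_avg = rng.normal(loc=entry["mean"], scale=scale)
else:
    target_l_avg = rng.laplace(loc=entry["mean"], scale=scale)
print(target_l_avg)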
@@ -245,7 +405,7 @@ def _random_hed(sigma: float = 0.025) -> Callable[[_Image], _Image]:
     def _inner(img: _Image):
         M = torch.tensor(
             np.array(
-                [[0.651, 0.701, 0.290], [0.269, 0.568, 0.778], [0.633, -0.713, 0.302]],
+                [[0.65, 0.70, 0.29], [0.07, 0.99, 0.11], [0.27, 0.57, 0.78]],
                 dtype="float32",
             )
         )
@@ -454,7 +614,130 @@ def _inner(img):
     return _inner
 
 
-def get_invariance_transforms() -> Dict[str, Callable[[_Image], _Image]]:
+def _random_randstain(
+    dataset_name: str, std_hyper: float = -0.3
+) -> Callable[[_Image], Tuple[_Image, Dict]]:
+    """
+    RandStain transformation.
+    Code inspired from https://github.com/yiqings/RandStainNA/blob/master/randstainna.py
+    """
+
+    if dataset_name not in randstain_constants:
+        raise ValueError(
+            f"Unknown dataset_name '{dataset_name}'. "
+            f"Available: {list(randstain_constants.keys())}"
+        )
+
+    stats = randstain_constants[dataset_name]
+
+    def _inner(img: _Image) -> Tuple[_Image, Dict]:
+
+        rng_seed = _AUG_RNG.randint(0, 2**20 - 1)
+        rng = np.random.RandomState(rng_seed)
+
+        if isinstance(img, Image.Image):
+            rgb = np.array(img)
+            if rgb.dtype != np.uint8:
+                rgb = np.clip(rgb, 0, 255).astype(np.uint8)
+            container = "pil"
+        elif isinstance(img, torch.Tensor):
+            t = _to_tensor(img)
+            rgb = (
+                (t.permute(1, 2, 0).cpu().numpy() * 255.0)
+                .round()
+                .clip(0, 255)
+                .astype(np.uint8)
+            )
+            container = "tensor"
+        elif isinstance(img, np.ndarray):
+            rgb = img
+            if rgb.dtype != np.uint8:
+                if rgb.max() <= 1.0 + 1e-6:
+                    rgb = rgb * 255.0
+                rgb = np.round(rgb).clip(0, 255).astype(np.uint8)
+            container = "ndarray"
+        else:
+            raise TypeError(
+                "Unsupported image type; expected PIL.Image, torch.Tensor, or np.ndarray."
+            )
+
+        # ---- RGB -> LAB ----
+        lab = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB).astype(np.float32)
+
+        flat = lab.reshape(-1, 3)
+        img_avgs = flat.mean(axis=0)
+        img_stds = flat.std(axis=0)
+        img_stds = np.clip(img_stds, 1e-4, 255.0)
+
+        # ---- sample target avgs/stds per channel (L, A, B) ----
+        tar_avgs = []
+        tar_stds = []
+        sampled = {}
+        for i, ch in enumerate(("L", "A", "B")):
+            ch_stats = stats[ch]
+            # avg
+            loc_avg = ch_stats["avg"]["mean"]
+            scale_avg = ch_stats["avg"]["std"] * (1.0 + std_hyper)
+            dist_avg = ch_stats["avg"]["distribution"].lower()
+            if dist_avg in ("norm", "normal"):
+                tavg = float(rng.normal(loc=loc_avg, scale=max(1e-8, scale_avg)))
+            else:
+                tavg = float(rng.laplace(loc=loc_avg, scale=max(1e-8, scale_avg)))
+            # std
+            loc_std = ch_stats["std"]["mean"]
+            scale_std = ch_stats["std"]["std"] * (1.0 + std_hyper)
+            dist_std = ch_stats["std"]["distribution"].lower()
+            if dist_std in ("norm", "normal"):
+                tstd = float(rng.normal(loc=loc_std, scale=max(1e-8, scale_std)))
+            else:
+                tstd = float(rng.laplace(loc=loc_std, scale=max(1e-8, scale_std)))
+            tstd = max(1e-4, tstd)
+
+            tar_avgs.append(tavg)
+            tar_stds.append(tstd)
+            sampled[ch] = {
+                "target_avg": tavg,
+                "target_std": tstd,
+                "avg_loc": loc_avg,
+                "avg_scale": scale_avg,
+                "avg_dist": dist_avg,
+                "std_loc": loc_std,
+                "std_scale": scale_std,
+                "std_dist": dist_std,
+            }
+
+        tar_avgs = np.array(tar_avgs, dtype=np.float32)
+        tar_stds = np.array(tar_stds, dtype=np.float32)
+
+        out_lab = (lab - img_avgs) * (tar_stds / img_stds) + tar_avgs
+        out_lab = np.clip(out_lab, 0.0, 255.0).astype(np.uint8)
+
+        # ---- LAB -> RGB ----
+        out_rgb = cv2.cvtColor(out_lab, cv2.COLOR_LAB2RGB)
+
+        if container == "pil":
+            out = Image.fromarray(out_rgb)
+        elif container == "tensor":
+            t = torch.from_numpy(out_rgb).permute(2, 0, 1).float() / 255.0
+            out = _from_tensor(t, img)  # use existing helper to match dtype
+        else:  # ndarray
+            if isinstance(img, np.ndarray) and img.dtype != np.uint8:
+                out = (out_rgb.astype(np.float32) / 255.0).astype(img.dtype)
+            else:
+                out = out_rgb
+
+        params = {
+            "seed": rng_seed,
+            "sampled": sampled,
+        }
+        return out, params
+
+    return _inner
+
+
+def get_invariance_transforms(
+    dataset_name: str,
+) -> Dict[str, Callable[[_Image], _Image]]:
     """
     Getting dictionary of available data augmentation transformations.
 
@@ -469,6 +752,7 @@ def get_invariance_transforms() -> Dict[str, Callable[[_Image], _Image]]:
         "random_color_jitter": _random_color_jitter(),
         "random_gamma": _random_gamma(),
         "random_hed": _random_hed(),
+        "random_randstain": _random_randstain(dataset_name),
         "random_cutout": _random_cutout(),
         "random_dilation": _random_dilation(),
         "random_erosion": _random_erosion(),

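For reference, a minimal usage sketch of the new transform, assuming the module is importable as thunder.utils.transforms and using a hypothetical input file. Note that, unlike the other entries, the "random_randstain" callable as written returns an (image, params) tuple:

from PIL import Image
from thunder.utils.transforms import get_invariance_transforms, set_transform_seed

set_transform_seed(0)                            # placeholder seed
transforms = get_invariance_transforms("bach")   # any key of randstain_constants
img = Image.open("patch.png").convert("RGB")     # hypothetical input patch
augmented, params = transforms["random_randstain"](img)
augmented.save("patch_randstain.png")
print(params["sampled"]["L"]["target_avg"])      # sampled target L-channel mean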