99import torchvision .transforms .v2 as v2
1010from PIL import Image
1111from torchvision .transforms import ColorJitter , InterpolationMode
12+ import cv2
1213
1314try :
1415 import kornia .morphology as _kmorph
# Module-level RNG dedicated to the augmentations in this file; it is seeded
# (only) via set_transform_seed(), keeping augmentation randomness independent
# of the global `random` state.
_AUG_RNG = random.Random()
2223
2324
# Constants for the RandStain transformation.
#
# For every dataset and every LAB channel ("L", "A", "B") we store the
# distribution of the per-image channel mean ("avg") and of the per-image
# channel standard deviation ("std"), each summarized by its own mean, std,
# and distribution family ("norm" for Gaussian, "laplace" for Laplacian).


def _randstain_channel_stats(avg, std):
    """Expand two (mean, std, distribution) triples into the nested dict layout."""
    return {
        "avg": {"mean": avg[0], "std": avg[1], "distribution": avg[2]},
        "std": {"mean": std[0], "std": std[1], "distribution": std[2]},
    }


# Raw statistics per dataset, channels ordered (L, A, B); each channel entry is
# ((avg_mean, avg_std, avg_dist), (std_mean, std_std, std_dist)).
_RANDSTAIN_RAW = {
    "bach": (
        ((176.212, 13.677, "laplace"), (38.368, 8.739, "laplace")),
        ((150.281, 8.22, "laplace"), (10.942, 3.087, "norm")),
        ((101.909, 8.559, "laplace"), (12.358, 3.998, "laplace")),
    ),
    "ccrcc": (
        ((162.086, 23.691, "norm"), (45.72, 9.992, "norm")),
        ((152.298, 7.65, "norm"), (10.916, 2.58, "norm")),
        ((117.695, 4.044, "norm"), (9.185, 1.973, "norm")),
    ),
    "crc": (
        ((160.658, 28.507, "laplace"), (35.602, 12.938, "laplace")),
        ((155.721, 8.59, "laplace"), (8.644, 3.222, "norm")),
        ((113.101, 4.914, "laplace"), (5.326, 1.834, "laplace")),
    ),
    "esca": (
        ((162.977, 28.455, "laplace"), (38.103, 9.602, "norm")),
        ((153.457, 10.034, "laplace"), (9.525, 3.285, "norm")),
        ((112.414, 7.494, "norm"), (5.663, 1.883, "norm")),
    ),
    "patch_camelyon": (
        ((157.616, 40.322, "norm"), (47.41, 12.984, "norm")),
        ((151.256, 10.978, "norm"), (7.997, 3.214, "laplace")),
        ((113.587, 12.046, "laplace"), (6.327, 2.779, "laplace")),
    ),
    "tcga_crc_msi": (
        ((157.412, 17.274, "norm"), (41.626, 8.558, "norm")),
        ((155.497, 4.807, "norm"), (8.973, 2.735, "norm")),
        ((113.043, 4.678, "laplace"), (5.587, 1.552, "laplace")),
    ),
    "tcga_tils": (
        ((159.268, 33.309, "norm"), (40.325, 12.098, "norm")),
        ((151.63, 9.875, "norm"), (8.519, 3.292, "norm")),
        ((117.799, 6.768, "norm"), (7.612, 2.546, "laplace")),
    ),
    "tcga_uniform": (
        ((140.328, 26.043, "norm"), (42.271, 8.964, "norm")),
        ((156.3, 7.71, "norm"), (7.451, 2.719, "norm")),
        ((114.37, 5.652, "norm"), (5.814, 1.605, "norm")),
    ),
    "wilds": (
        ((169.551, 32.673, "norm"), (34.248, 11.67, "norm")),
        ((148.907, 7.386, "norm"), (6.937, 2.683, "laplace")),
        ((116.235, 5.245, "laplace"), (5.577, 1.418, "norm")),
    ),
    "break_his": (
        ((184.174, 15.589, "laplace"), (25.219, 7.311, "norm")),
        ((149.7, 12.966, "laplace"), (7.763, 3.242, "norm")),
        ((116.526, 7.479, "laplace"), (5.346, 1.618, "norm")),
    ),
    "mhist": (
        ((179.178, 16.974, "norm"), (51.886, 5.499, "norm")),
        ((142.941, 4.153, "norm"), (9.835, 1.666, "norm")),
        ((114.176, 3.819, "norm"), (9.391, 1.522, "norm")),
    ),
}

randstain_constants = {
    dataset: {
        channel: _randstain_channel_stats(avg_stats, std_stats)
        for channel, (avg_stats, std_stats) in zip(("L", "A", "B"), channels)
    }
    for dataset, channels in _RANDSTAIN_RAW.items()
}
182+
183+
24184def set_transform_seed (seed : int ) -> None :
25185 """
26186 Seed only the augmentation RNG. Call this once per‐dataset before you apply any of the get_invariance_transforms().
@@ -245,7 +405,7 @@ def _random_hed(sigma: float = 0.025) -> Callable[[_Image], _Image]:
245405 def _inner (img : _Image ):
246406 M = torch .tensor (
247407 np .array (
248- [[0.651 , 0.701 , 0.290 ], [0.269 , 0.568 , 0.778 ], [0.633 , - 0.713 , 0.302 ]],
408+ [[0.65 , 0.70 , 0.29 ], [0.07 , 0.99 , 0.11 ], [0.27 , 0.57 , 0.78 ]],
249409 dtype = "float32" ,
250410 )
251411 )
@@ -454,7 +614,130 @@ def _inner(img):
454614 return _inner
455615
456616
457- def get_invariance_transforms () -> Dict [str , Callable [[_Image ], _Image ]]:
def _random_randstain(
    dataset_name: str, std_hyper: float = -0.3
) -> Callable[[_Image], Tuple[_Image, Dict]]:
    """
    RandStain transformation.
    Code inspired from https://github.com/yiqings/RandStainNA/blob/master/randstainna.py

    Builds a closure that recolors an H&E-style image by matching its LAB
    channel statistics to target statistics sampled from per-dataset
    distributions (``randstain_constants``).

    Args:
        dataset_name: Key into ``randstain_constants`` selecting the
            per-dataset LAB statistics.
        std_hyper: Multiplicative adjustment of the sampling spread; the
            sampling scale is ``stat_std * (1 + std_hyper)``, so the default
            of -0.3 shrinks the spread by 30%.

    Returns:
        A callable mapping an image (PIL.Image, torch.Tensor, or np.ndarray)
        to a ``(transformed_image, params)`` tuple, where ``params`` records
        the RNG seed and the sampled per-channel targets.

        NOTE(review): the other transforms registered alongside this one are
        typed ``Callable[[_Image], _Image]``; this closure returns a tuple —
        confirm callers unpack ``(image, params)``.

    Raises:
        ValueError: If ``dataset_name`` is not in ``randstain_constants``.
    """

    if dataset_name not in randstain_constants:
        raise ValueError(
            f"Unknown dataset_name '{dataset_name}'. "
            f"Available: {list(randstain_constants.keys())}"
        )

    stats = randstain_constants[dataset_name]

    def _inner(img: _Image) -> Tuple[_Image, Dict]:

        # Draw a per-call seed from the module augmentation RNG so results are
        # reproducible under set_transform_seed(), while numpy does the actual
        # normal/laplace sampling below.
        rng_seed = _AUG_RNG.randint(0, 2**20 - 1)
        rng = np.random.RandomState(rng_seed)

        # ---- normalize the input container to a uint8 RGB HxWx3 array ----
        if isinstance(img, Image.Image):
            rgb = np.array(img)
            if rgb.dtype != np.uint8:
                rgb = np.clip(rgb, 0, 255).astype(np.uint8)
            container = "pil"
        elif isinstance(img, torch.Tensor):
            # assumes _to_tensor yields a CHW float tensor in [0, 1] — TODO confirm
            t = _to_tensor(img)
            rgb = (
                (t.permute(1, 2, 0).cpu().numpy() * 255.0)
                .round()
                .clip(0, 255)
                .astype(np.uint8)
            )
            container = "tensor"
        elif isinstance(img, np.ndarray):
            rgb = img
            if rgb.dtype != np.uint8:
                # Heuristic: values <= ~1.0 are treated as normalized floats.
                if rgb.max() <= 1.0 + 1e-6:
                    rgb = rgb * 255.0
                rgb = np.round(rgb).clip(0, 255).astype(np.uint8)
            container = "ndarray"
        else:
            raise TypeError(
                "Unsupported image type; expected PIL.Image, torch.Tensor, or np.ndarray."
            )

        # ---- RGB -> LAB ----
        # OpenCV 8-bit LAB puts all three channels in [0, 255], matching the
        # scale of the dataset constants.
        lab = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB).astype(np.float32)

        # Per-channel mean/std of this image (population std over all pixels).
        flat = lab.reshape(-1, 3)
        img_avgs = flat.mean(axis=0)
        img_stds = flat.std(axis=0)
        # Floor the stds so the later division cannot blow up on flat images.
        img_stds = np.clip(img_stds, 1e-4, 255.0)

        # ---- sample target avgs/stds per channel (L, A, B) ----
        tar_avgs = []
        tar_stds = []
        sampled = {}
        for i, ch in enumerate(("L", "A", "B")):
            ch_stats = stats[ch]
            # avg: sample the target channel mean from the dataset's
            # norm/laplace distribution, spread scaled by (1 + std_hyper).
            loc_avg = ch_stats["avg"]["mean"]
            scale_avg = ch_stats["avg"]["std"] * (1.0 + std_hyper)
            dist_avg = ch_stats["avg"]["distribution"].lower()
            if dist_avg in ("norm", "normal"):
                tavg = float(rng.normal(loc=loc_avg, scale=max(1e-8, scale_avg)))
            else:
                tavg = float(rng.laplace(loc=loc_avg, scale=max(1e-8, scale_avg)))
            # std: same scheme for the target channel standard deviation.
            loc_std = ch_stats["std"]["mean"]
            scale_std = ch_stats["std"]["std"] * (1.0 + std_hyper)
            dist_std = ch_stats["std"]["distribution"].lower()
            if dist_std in ("norm", "normal"):
                tstd = float(rng.normal(loc=loc_std, scale=max(1e-8, scale_std)))
            else:
                tstd = float(rng.laplace(loc=loc_std, scale=max(1e-8, scale_std)))
            # A sampled std can come out <= 0; clamp to keep the transfer valid.
            tstd = max(1e-4, tstd)

            tar_avgs.append(tavg)
            tar_stds.append(tstd)
            # Record everything needed to reproduce/inspect this draw.
            sampled[ch] = {
                "target_avg": tavg,
                "target_std": tstd,
                "avg_loc": loc_avg,
                "avg_scale": scale_avg,
                "avg_dist": dist_avg,
                "std_loc": loc_std,
                "std_scale": scale_std,
                "std_dist": dist_std,
            }

        tar_avgs = np.array(tar_avgs, dtype=np.float32)
        tar_stds = np.array(tar_stds, dtype=np.float32)

        # Reinhard-style statistics transfer: re-center and re-scale each LAB
        # channel from the image's stats to the sampled target stats.
        out_lab = (lab - img_avgs) * (tar_stds / img_stds) + tar_avgs
        out_lab = np.clip(out_lab, 0.0, 255.0).astype(np.uint8)

        # ---- LAB -> RGB ----
        out_rgb = cv2.cvtColor(out_lab, cv2.COLOR_LAB2RGB)

        # ---- return in the same container the caller gave us ----
        if container == "pil":
            out = Image.fromarray(out_rgb)
        elif container == "tensor":
            t = torch.from_numpy(out_rgb).permute(2, 0, 1).float() / 255.0
            out = _from_tensor(t, img)  # use existing to match dtype
        else:  # ndarray
            # Non-uint8 ndarray input is returned normalized in its own dtype.
            if isinstance(img, np.ndarray) and img.dtype != np.uint8:
                out = (out_rgb.astype(np.float32) / 255.0).astype(img.dtype)
            else:
                out = out_rgb

        params = {
            "seed": rng_seed,
            "sampled": sampled,
        }
        return out, params

    return _inner
736+
737+
738+ def get_invariance_transforms (
739+ dataset_name : str ,
740+ ) -> Dict [str , Callable [[_Image ], _Image ]]:
458741 """
459742 Getting dictionary of available data augmentation transformations.
460743
@@ -469,6 +752,7 @@ def get_invariance_transforms() -> Dict[str, Callable[[_Image], _Image]]:
469752 "random_color_jitter" : _random_color_jitter (),
470753 "random_gamma" : _random_gamma (),
471754 "random_hed" : _random_hed (),
755+ "random_randstain" : _random_randstain (dataset_name ),
472756 "random_cutout" : _random_cutout (),
473757 "random_dilation" : _random_dilation (),
474758 "random_erosion" : _random_erosion (),
0 commit comments