1+ import tensorflow as tf
2+ from torchstain .tf .utils import cov , percentile , solveLS
3+
4+ class TensorFlowMultiMacenkoNormalizer :
5+ def __init__ (self , norm_mode = "avg-post" ):
6+ self .norm_mode = norm_mode
7+ self .HERef = tf .constant ([[0.5626 , 0.2159 ],
8+ [0.7201 , 0.8012 ],
9+ [0.4062 , 0.5581 ]], dtype = tf .float32 )
10+ self .maxCRef = tf .constant ([1.9705 , 1.0308 ], dtype = tf .float32 )
11+
12+ def __convert_rgb2od (self , I , Io , beta ):
13+ I = tf .transpose (I , perm = [1 , 2 , 0 ]) # Shape: (height, width, 3)
14+ OD = - tf .math .log ((tf .reshape (I , [- 1 , tf .shape (I )[- 1 ]]) + 1 ) / Io )
15+ ODhat = tf .boolean_mask (OD , ~ tf .reduce_any (OD < beta , axis = 1 ))
16+
17+ if tf .size (ODhat ) == 0 :
18+ raise ValueError ("ODhat is empty. Check image values and beta threshold." )
19+
20+ return OD , ODhat
21+
22+ def __find_phi_bounds (self , ODhat , eigvecs , alpha ):
23+ That = tf .matmul (ODhat , eigvecs )
24+ phi = tf .math .atan2 (That [:, 1 ], That [:, 0 ])
25+
26+ minPhi = percentile (phi , alpha )
27+ maxPhi = percentile (phi , 100 - alpha )
28+ return minPhi , maxPhi
29+
30+ def __find_HE_from_bounds (self , eigvecs , minPhi , maxPhi ):
31+ # Expand minPhi and maxPhi to have a second dimension
32+ vMin = tf .matmul (eigvecs , tf .stack ([tf .cos (minPhi ), tf .sin (minPhi )])[:, tf .newaxis ])
33+ vMax = tf .matmul (eigvecs , tf .stack ([tf .cos (maxPhi ), tf .sin (maxPhi )])[:, tf .newaxis ])
34+
35+ # Concatenate along the last dimension and return
36+ HE = tf .where (vMin [0 ] > vMax [0 ],
37+ tf .concat ([vMin , vMax ], axis = 1 ),
38+ tf .concat ([vMax , vMin ], axis = 1 ))
39+ return HE
40+
41+ def __find_HE (self , ODhat , eigvecs , alpha ):
42+ minPhi , maxPhi = self .__find_phi_bounds (ODhat , eigvecs , alpha )
43+ return self .__find_HE_from_bounds (eigvecs , minPhi , maxPhi )
44+
45+ def __find_concentration (self , OD , HE ):
46+ # Solve linear system using the provided solveLS function
47+ Y = tf .transpose (OD )
48+ C = solveLS (HE , Y )
49+ return C
50+
51+ def __compute_matrices_single (self , I , Io , alpha , beta ):
52+ OD , ODhat = self .__convert_rgb2od (I , Io , beta )
53+
54+ cov_matrix = cov (tf .transpose (ODhat )) # cov expects shape (dims, samples)
55+ eigvals , eigvecs = tf .linalg .eigh (cov_matrix )
56+ eigvecs = tf .gather (eigvecs , [1 , 2 ], axis = 1 )
57+
58+ HE = self .__find_HE (ODhat , eigvecs , alpha )
59+ C = self .__find_concentration (OD , HE )
60+ maxC = tf .stack ([percentile (C [0 , :], 99 ), percentile (C [1 , :], 99 )])
61+ return HE , C , maxC
62+
63+ def fit (self , Is , Io = 240 , alpha = 1 , beta = 0.15 ):
64+ if not isinstance (Is , list ) or len (Is ) == 0 :
65+ raise ValueError ("Input images should be a non-empty list of tensors." )
66+
67+ for i , I in enumerate (Is ):
68+ if not isinstance (I , tf .Tensor ):
69+ raise ValueError (f"Image at index { i } is not a TensorFlow tensor." )
70+ if I .ndim != 3 or I .shape [0 ] != 3 :
71+ raise ValueError (f"Image at index { i } should have shape (3, height, width)." )
72+
73+ if self .norm_mode == "avg-post" :
74+ HEs , _ , maxCs = zip (* [self .__compute_matrices_single (I , Io , alpha , beta ) for I in Is ])
75+ self .HERef = tf .reduce_mean (tf .stack (HEs ), axis = 0 )
76+ self .maxCRef = tf .reduce_mean (tf .stack (maxCs ), axis = 0 )
77+ elif self .norm_mode == "concat" :
78+ ODs , ODhats = zip (* [self .__convert_rgb2od (I , Io , beta ) for I in Is ])
79+ OD = tf .concat (ODs , axis = 0 )
80+ ODhat = tf .concat (ODhats , axis = 0 )
81+
82+ cov_matrix = cov (tf .transpose (ODhat )) # cov expects shape (dims, samples)
83+ eigvals , eigvecs = tf .linalg .eigh (cov_matrix )
84+ eigvecs = tf .gather (eigvecs , [1 , 2 ], axis = 1 )
85+
86+ HE = self .__find_HE (ODhat , eigvecs , alpha )
87+ C = self .__find_concentration (OD , HE )
88+ maxCs = tf .stack ([percentile (C [0 , :], 99 ), percentile (C [1 , :], 99 )])
89+
90+ self .HERef = HE
91+ self .maxCRef = maxCs
92+ else :
93+ raise ValueError ("Unsupported normalization mode." )
94+
95+ def normalize (self , I , Io = 240 , alpha = 1 , beta = 0.15 , stains = True ):
96+ c , h , w = I .shape
97+
98+ HE , C , maxC = self .__compute_matrices_single (I , Io , alpha , beta )
99+
100+ # Ensure maxCRef and maxC are broadcastable
101+ scaling_factors = (self .maxCRef / maxC )[:, tf .newaxis ] # Shape: [2, 1]
102+ C = scaling_factors * C [:2 , :] # Use only the first two rows of C
103+
104+ # Reconstruct the normalized image
105+ Inorm = Io * tf .exp (- tf .linalg .matmul (self .HERef , C ))
106+ Inorm = tf .clip_by_value (Inorm , 0 , 255 )
107+ Inorm = tf .transpose (tf .reshape (Inorm , [c , h , w ]), perm = [1 , 2 , 0 ]) # Convert back to HWC format
108+ Inorm = tf .cast (Inorm , tf .int32 )
109+
110+ H , E = None , None
111+
112+ if stains :
113+ # Extract the H and E components
114+ H = Io * tf .exp (- tf .linalg .matmul (self .HERef [:, 0 :1 ], C [0 :1 , :]))
115+ H = tf .clip_by_value (H , 0 , 255 )
116+ H = tf .transpose (tf .reshape (H , [c , h , w ]), perm = [1 , 2 , 0 ])
117+ H = tf .cast (H , tf .int32 )
118+
119+ E = Io * tf .exp (- tf .linalg .matmul (self .HERef [:, 1 :2 ], C [1 :2 , :]))
120+ E = tf .clip_by_value (E , 0 , 255 )
121+ E = tf .transpose (tf .reshape (E , [c , h , w ]), perm = [1 , 2 , 0 ])
122+ E = tf .cast (E , tf .int32 )
123+
124+ return Inorm , H , E
0 commit comments