Skip to content

Commit 17127a5

Browse files
committed
Implemented TensorFlowMultiMacenkoNormalizer
1 parent 141e4b9 commit 17127a5

3 files changed

Lines changed: 129 additions & 2 deletions

File tree

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
def MultiMacenkoNormalizer(backend="torch", **kwargs):
22
if backend == "numpy":
3-
raise NotImplementedError("MultiMacenkoNormalizer is not implemented for NumPy backend")
3+
from torchstain.numpy.normalizers import NumpyMultiMacenkoNormalizer
4+
return NumpyMultiMacenkoNormalizer(**kwargs)
45
elif backend == "torch":
56
from torchstain.torch.normalizers import TorchMultiMacenkoNormalizer
67
return TorchMultiMacenkoNormalizer(**kwargs)
78
elif backend == "tensorflow":
8-
raise NotImplementedError("MultiMacenkoNormalizer is not implemented for TensorFlow backend")
9+
from torchstain.tf.normalizers import TensorFlowMultiMacenkoNormalizer
10+
return TensorFlowMultiMacenkoNormalizer(**kwargs)
911
else:
1012
raise Exception(f"Unsupported backend {backend}")
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from torchstain.tf.normalizers.macenko import TensorFlowMacenkoNormalizer
22
from torchstain.tf.normalizers.reinhard import TensorFlowReinhardNormalizer
3+
from torchstain.tf.normalizers.multitarget import TensorFlowMultiMacenkoNormalizer
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import tensorflow as tf
2+
from torchstain.tf.utils import cov, percentile, solveLS
3+
4+
class TensorFlowMultiMacenkoNormalizer:
5+
def __init__(self, norm_mode="avg-post"):
6+
self.norm_mode = norm_mode
7+
self.HERef = tf.constant([[0.5626, 0.2159],
8+
[0.7201, 0.8012],
9+
[0.4062, 0.5581]], dtype=tf.float32)
10+
self.maxCRef = tf.constant([1.9705, 1.0308], dtype=tf.float32)
11+
12+
def __convert_rgb2od(self, I, Io, beta):
13+
I = tf.transpose(I, perm=[1, 2, 0]) # Shape: (height, width, 3)
14+
OD = -tf.math.log((tf.reshape(I, [-1, tf.shape(I)[-1]]) + 1) / Io)
15+
ODhat = tf.boolean_mask(OD, ~tf.reduce_any(OD < beta, axis=1))
16+
17+
if tf.size(ODhat) == 0:
18+
raise ValueError("ODhat is empty. Check image values and beta threshold.")
19+
20+
return OD, ODhat
21+
22+
def __find_phi_bounds(self, ODhat, eigvecs, alpha):
23+
That = tf.matmul(ODhat, eigvecs)
24+
phi = tf.math.atan2(That[:, 1], That[:, 0])
25+
26+
minPhi = percentile(phi, alpha)
27+
maxPhi = percentile(phi, 100 - alpha)
28+
return minPhi, maxPhi
29+
30+
def __find_HE_from_bounds(self, eigvecs, minPhi, maxPhi):
31+
# Expand minPhi and maxPhi to have a second dimension
32+
vMin = tf.matmul(eigvecs, tf.stack([tf.cos(minPhi), tf.sin(minPhi)])[:, tf.newaxis])
33+
vMax = tf.matmul(eigvecs, tf.stack([tf.cos(maxPhi), tf.sin(maxPhi)])[:, tf.newaxis])
34+
35+
# Concatenate along the last dimension and return
36+
HE = tf.where(vMin[0] > vMax[0],
37+
tf.concat([vMin, vMax], axis=1),
38+
tf.concat([vMax, vMin], axis=1))
39+
return HE
40+
41+
def __find_HE(self, ODhat, eigvecs, alpha):
42+
minPhi, maxPhi = self.__find_phi_bounds(ODhat, eigvecs, alpha)
43+
return self.__find_HE_from_bounds(eigvecs, minPhi, maxPhi)
44+
45+
def __find_concentration(self, OD, HE):
46+
# Solve linear system using the provided solveLS function
47+
Y = tf.transpose(OD)
48+
C = solveLS(HE, Y)
49+
return C
50+
51+
def __compute_matrices_single(self, I, Io, alpha, beta):
52+
OD, ODhat = self.__convert_rgb2od(I, Io, beta)
53+
54+
cov_matrix = cov(tf.transpose(ODhat)) # cov expects shape (dims, samples)
55+
eigvals, eigvecs = tf.linalg.eigh(cov_matrix)
56+
eigvecs = tf.gather(eigvecs, [1, 2], axis=1)
57+
58+
HE = self.__find_HE(ODhat, eigvecs, alpha)
59+
C = self.__find_concentration(OD, HE)
60+
maxC = tf.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)])
61+
return HE, C, maxC
62+
63+
def fit(self, Is, Io=240, alpha=1, beta=0.15):
64+
if not isinstance(Is, list) or len(Is) == 0:
65+
raise ValueError("Input images should be a non-empty list of tensors.")
66+
67+
for i, I in enumerate(Is):
68+
if not isinstance(I, tf.Tensor):
69+
raise ValueError(f"Image at index {i} is not a TensorFlow tensor.")
70+
if I.ndim != 3 or I.shape[0] != 3:
71+
raise ValueError(f"Image at index {i} should have shape (3, height, width).")
72+
73+
if self.norm_mode == "avg-post":
74+
HEs, _, maxCs = zip(*[self.__compute_matrices_single(I, Io, alpha, beta) for I in Is])
75+
self.HERef = tf.reduce_mean(tf.stack(HEs), axis=0)
76+
self.maxCRef = tf.reduce_mean(tf.stack(maxCs), axis=0)
77+
elif self.norm_mode == "concat":
78+
ODs, ODhats = zip(*[self.__convert_rgb2od(I, Io, beta) for I in Is])
79+
OD = tf.concat(ODs, axis=0)
80+
ODhat = tf.concat(ODhats, axis=0)
81+
82+
cov_matrix = cov(tf.transpose(ODhat)) # cov expects shape (dims, samples)
83+
eigvals, eigvecs = tf.linalg.eigh(cov_matrix)
84+
eigvecs = tf.gather(eigvecs, [1, 2], axis=1)
85+
86+
HE = self.__find_HE(ODhat, eigvecs, alpha)
87+
C = self.__find_concentration(OD, HE)
88+
maxCs = tf.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)])
89+
90+
self.HERef = HE
91+
self.maxCRef = maxCs
92+
else:
93+
raise ValueError("Unsupported normalization mode.")
94+
95+
def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True):
96+
c, h, w = I.shape
97+
98+
HE, C, maxC = self.__compute_matrices_single(I, Io, alpha, beta)
99+
100+
# Ensure maxCRef and maxC are broadcastable
101+
scaling_factors = (self.maxCRef / maxC)[:, tf.newaxis] # Shape: [2, 1]
102+
C = scaling_factors * C[:2, :] # Use only the first two rows of C
103+
104+
# Reconstruct the normalized image
105+
Inorm = Io * tf.exp(-tf.linalg.matmul(self.HERef, C))
106+
Inorm = tf.clip_by_value(Inorm, 0, 255)
107+
Inorm = tf.transpose(tf.reshape(Inorm, [c, h, w]), perm=[1, 2, 0]) # Convert back to HWC format
108+
Inorm = tf.cast(Inorm, tf.int32)
109+
110+
H, E = None, None
111+
112+
if stains:
113+
# Extract the H and E components
114+
H = Io * tf.exp(-tf.linalg.matmul(self.HERef[:, 0:1], C[0:1, :]))
115+
H = tf.clip_by_value(H, 0, 255)
116+
H = tf.transpose(tf.reshape(H, [c, h, w]), perm=[1, 2, 0])
117+
H = tf.cast(H, tf.int32)
118+
119+
E = Io * tf.exp(-tf.linalg.matmul(self.HERef[:, 1:2], C[1:2, :]))
120+
E = tf.clip_by_value(E, 0, 255)
121+
E = tf.transpose(tf.reshape(E, [c, h, w]), perm=[1, 2, 0])
122+
E = tf.cast(E, tf.int32)
123+
124+
return Inorm, H, E

0 commit comments

Comments
 (0)