-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprepare_pannuke.py
More file actions
executable file
·241 lines (189 loc) · 9.65 KB
/
prepare_pannuke.py
File metadata and controls
executable file
·241 lines (189 loc) · 9.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
# -*- coding: utf-8 -*-
# Prepare Pannuke Dataset by converting and resorting files
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
import inspect
import os
import sys
# Absolute path of the directory that contains this script.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
# Put the script's parent directory at the front of the module search path ...
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)
# ... and then the grandparent directory as well (`parentdir` is rebound one level up).
# NOTE(review): presumably this lets project packages such as `cell_segmentation`
# be imported when the file is run directly as a script - confirm against the repo layout.
parentdir = os.path.dirname(parentdir)
sys.path.insert(0, parentdir)
import numpy as np
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import argparse
import pandas as pd
from scipy.sparse import csr_matrix, load_npz, save_npz
from math import ceil
import gc
from cell_segmentation.utils.metrics import remap_label
def load_sparse_3d_masks_chunked(file_path, mask_shape, chunk_size):
    """
    Load 3D masks from a sparse .npz file and yield them as dense arrays, chunk by chunk.

    The file must contain a single 2D sparse matrix in which every mask occupies
    ``height * width`` consecutive rows and one column per channel. Rows beyond the
    last complete mask are ignored.

    This is a generator: nothing is loaded or validated until the first item is requested.

    Args:
        file_path (str | Path): Path to the .npz file (readable by ``scipy.sparse.load_npz``).
        mask_shape (tuple): Shape of a single mask as (height, width, channels).
        chunk_size (int): Maximum number of masks per yielded chunk. Must be positive.

    Yields:
        list: Dense masks, each of shape ``mask_shape``, for the current chunk.
            The last chunk may contain fewer than ``chunk_size`` masks.

    Raises:
        ValueError: If ``chunk_size`` is not positive, or if the number of columns
            of the stored matrix does not match ``channels``.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    print("Loading sparse masks...")
    # load_npz returns the matrix in whatever format it was saved in (e.g. COO),
    # which may not support row slicing. CSR does, and slices rows efficiently.
    sparse_matrix = load_npz(file_path).tocsr()

    height, width, channels = mask_shape
    if sparse_matrix.shape[1] != channels:
        raise ValueError(
            f"Sparse matrix has {sparse_matrix.shape[1]} columns, "
            f"but mask_shape {tuple(mask_shape)} expects {channels} channels"
        )

    flat_size = height * width  # Number of rows (pixels) per mask
    total_masks = sparse_matrix.shape[0] // flat_size  # Number of complete masks

    for start_idx in range(0, total_masks, chunk_size):
        end_idx = min(start_idx + chunk_size, total_masks)
        # Densify each mask's row block and restore its (height, width, channels) layout
        yield [
            sparse_matrix[mask_idx * flat_size:(mask_idx + 1) * flat_size]
            .toarray()
            .reshape(mask_shape)
            for mask_idx in range(start_idx, end_idx)
        ]
def save_sparse_maps_single_file(inst_map, type_map, output_path, outname):
    """
    Save instance and type maps as sparse (CSR) components in a single .npz file.

    The file is written to ``<output_path>/labels/<outname>``. Each map is stored
    as its CSR components under the keys ``<prefix>_data``, ``<prefix>_indices``,
    ``<prefix>_indptr`` and ``<prefix>_shape`` with prefix ``inst_map`` or ``type_map``.

    Parameters:
    - inst_map: 2D numpy array, the instance map.
    - type_map: 2D numpy array, the type map.
    - output_path: Path object or string, the dataset directory; the file goes
      into its "labels" sub-directory, which is created if missing.
    - outname: String, the file name. ``numpy.savez`` appends ".npz" when the
      name does not already end with it.
    """
    # Ensure the directory that is actually written to exists
    # (this also creates output_path itself).
    labels_dir = os.path.join(output_path, "labels")
    os.makedirs(labels_dir, exist_ok=True)

    # Convert to sparse format
    inst_map_sparse = csr_matrix(inst_map)
    type_map_sparse = csr_matrix(type_map)

    # Save the components of both sparse matrices into one .npz file
    combined_path = os.path.join(labels_dir, outname)
    np.savez(
        combined_path,
        inst_map_data=inst_map_sparse.data,
        inst_map_indices=inst_map_sparse.indices,
        inst_map_indptr=inst_map_sparse.indptr,
        inst_map_shape=inst_map_sparse.shape,
        type_map_data=type_map_sparse.data,
        type_map_indices=type_map_sparse.indices,
        type_map_indptr=type_map_sparse.indptr,
        type_map_shape=type_map_sparse.shape,
    )
def process_ds(input_path, output_path, list_cat) -> None:
    """
    Convert one dataset folder into per-patch images, label files and summary CSVs.

    Reads ``images.npy``, ``types.npy`` and ``masks.npz`` from ``input_path`` and writes
    into ``output_path``:

    - ``images/<name>_<i>.png``: every image, cast to uint8.
    - ``labels/<name>_<i>.npz``: instance map and type map of every patch
      (see ``save_sparse_maps_single_file``).
    - ``cell_count.csv``: number of distinct non-zero ids per class and patch.
    - ``types.csv``: the entry of ``types.npy`` for every patch.

    ``<name>`` is the last component of ``input_path``.

    Args:
        input_path (str | Path): Folder that contains the input files.
        output_path (str | Path): Folder that receives the converted dataset.
        list_cat (list[str]): Class names; channel ``j`` of a mask belongs to
            ``list_cat[j]``. Masks are read with ``len(list_cat) + 1`` channels;
            the last channel is not used here.
    """
    # Accept plain strings as well as Path objects
    input_path = Path(input_path)
    output_path = Path(output_path)
    slide_name = os.path.basename(input_path)

    print(f"\n==== Processing slide {slide_name} ====")
    print('\nUsing list_cat:', list_cat)
    os.makedirs(os.path.join(output_path, "images"), exist_ok=True)
    os.makedirs(os.path.join(output_path, "labels"), exist_ok=True)

    print("\nLoading large numpy files, this may take a while")
    print("-> Loading images.npy...")
    images = np.load(input_path / "images.npy")
    print("-> Loading types.npy...")
    types = np.load(input_path / "types.npy")
    # NOTE: patch ids are generated from the folder name and the patch index;
    # loading them from a `patch_ids.npy` file was disabled in the original script.
    patch_ids = [f"{slide_name}_{i}" for i in range(len(images))]

    print("\nProcess images")
    for patch_id, image in tqdm(zip(patch_ids, images), total=len(images)):
        im = Image.fromarray(image.astype(np.uint8))
        im.save(output_path / "images" / f"{patch_id}.png")

    cell_count = {}  # image name -> {class name: number of cells}
    save_types = {}  # image name -> entry of types.npy

    print("\nProcess masks")
    patch_size = 256  # patches are assumed to be 256x256 (hard-coded)
    num_classes = len(list_cat)
    mask_shape = (patch_size, patch_size, num_classes + 1)
    chunk_size = 20000
    number_of_chunks = ceil(len(patch_ids) / chunk_size)
    chunk_generator = load_sparse_3d_masks_chunked(input_path / "masks.npz", mask_shape, chunk_size=chunk_size)

    for chunk_idx, masks_chunk in enumerate(chunk_generator):
        for i, mask in tqdm(enumerate(masks_chunk), total=len(masks_chunk), desc=f"Chunk {chunk_idx+1}/{number_of_chunks}"):
            patch_idx = chunk_idx * chunk_size + i
            patch_id = patch_ids[patch_idx]
            image_name = f"{patch_id}.png"

            # Cell count per class: number of distinct non-zero ids in the class layer.
            # Counting non-zero unique values (instead of `len(unique) - 1`) stays
            # correct when a layer contains no background pixel at all.
            cell_count[image_name] = {
                class_name: int(np.count_nonzero(np.unique(mask[:, :, j])))
                for j, class_name in enumerate(list_cat)
            }

            # Store type for the given patch (named `patch_type` to avoid shadowing `type`)
            patch_type = types[patch_idx]
            save_types[image_name] = patch_type

            # Build the 2D instance map and type map from the per-class layers
            inst_map = np.zeros((patch_size, patch_size))
            type_map = np.zeros((patch_size, patch_size)).astype(np.int32)
            num_nuc = 0
            for j in range(num_classes):
                layer = mask[:, :, j]

                # NOTE(review): remap_label is a project helper; presumably it relabels
                # the ids of the layer - confirm. The ids are then shifted by the
                # running maximum so that ids of different classes do not collide.
                layer_res = remap_label(layer)
                inst_map = np.where(layer_res != 0, layer_res + num_nuc, inst_map)
                num_nuc = num_nuc + np.max(layer_res)

                # Pixels covered by class j get type id j + 1 (0 stays background);
                # later classes overwrite earlier ones where layers overlap.
                layer_type = ((j + 1) * np.clip(layer, 0, 1)).astype(np.int32)
                type_map = np.where(layer_type != 0, layer_type, type_map)
            inst_map = remap_label(inst_map)

            save_sparse_maps_single_file(inst_map, type_map, output_path, f"{patch_id}.npz")

            # Drop the maps right away; reference counting frees the arrays, so no
            # gc.collect() per patch is needed (it would only slow down the loop).
            del inst_map
            del type_map

        # Clear the current chunk's data from memory
        del masks_chunk
        gc.collect()  # Force garbage collection once per chunk

    # Save cell count and type for each patch
    print("\nSave cell count and type for each patch")
    cell_count = pd.DataFrame(cell_count).T
    cell_count = cell_count.rename_axis("Image").reset_index()
    cell_count.to_csv(output_path / "cell_count.csv", sep=',', index=False)
    save_types = pd.DataFrame(save_types.items(), columns=["img", "type"])
    save_types.to_csv(output_path / "types.csv", sep=',', index=False)
    print("\nDone.")
# Command-line interface: where to read the raw dataset, where to write the
# prepared one, and which cell-type categories the mask channels stand for.
parser = argparse.ArgumentParser(
    description="Prepare dataset by converting and resorting files",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# Alternative input folders that were used with this script:
#   /Volumes/DD_FGS/MICS/data_HE2CellType/CT_DS/check_align_patches/patches_xenium/heart_s0
#   /Volumes/DD_FGS/MICS/data_HE2CellType/CT_DS/ds_slides_cat/ct_1/heart_s0
parser.add_argument(
    "--input_path",
    default="/Volumes/DD_FGS/MICS/data_HE2CellType/CT_DS/ds_slides_cat/pannuke/fold2",
    type=str,
    help="Input path of the original PanNuke dataset",
)

# Alternative output folders that were used with this script:
#   /Volumes/DD_FGS/MICS/data_HE2CellType/CT_DS/check_align_patches/apply_cellvit/prepared_patches_xenium/heart_s0
#   /Volumes/DD_FGS/MICS/data_HE2CellType/HE2CT/prepared_datasets_cat/ct_1/heart_s0
parser.add_argument(
    "--output_path",
    default="/Volumes/DD_FGS/MICS/data_HE2CellType/HE2CT/prepared_datasets_cat/pannuke/fold2",
    type=str,
    help="Output path to store the processed PanNuke dataset",
)

# Alternative category list that was used with this script:
#   T_NK, B_Plasma, Myeloid, Blood_vessel, Fibroblast_Myofibroblast,
#   Epithelial, Specialized, Melanocyte, Dead
parser.add_argument(
    "--list_cat",
    default=['Neoplastic', 'Inflammatory', 'Connective', 'Dead', 'Epithelial'],
    type=str,
    nargs='+',
    help="List that contains the name of the categories (cell types), e.g. ['Neoplastic','Inflammatory','Connective','Dead','Epithelial']",
)
if __name__ == "__main__":
    # Parse the command line and hand the settings to the conversion routine;
    # both folders are passed on as Path objects.
    configuration = vars(parser.parse_args())
    process_ds(
        Path(configuration["input_path"]),
        Path(configuration["output_path"]),
        configuration["list_cat"],
    )