Skip to content

Commit 1bba9f6

Browse files
add src align+features+finalds
1 parent 5f386da commit 1bba9f6

48 files changed

Lines changed: 33815 additions & 0 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/_3_check_alignment/_3-1_build_patches.py

Lines changed: 507 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
"""Group all npz in one unique file"""
2+
3+
4+
import os
5+
from scipy.sparse import load_npz, vstack, save_npz
6+
from tqdm import tqdm
7+
import argparse
8+
import gc
9+
import re
10+
import numpy as np
11+
12+
13+
14+
def consolidate_chunks(folder_path, chunk_prefix, output_file):
    """
    Consolidate multiple sparse chunk files into a single sparse .npz file.

    Files named ``<chunk_prefix>_<index>.npz`` are stacked vertically in
    increasing numeric index order. Files that carry the prefix and the
    ``.npz`` extension but no parsable index are still stacked, after all
    indexed chunks and in name order; a warning is printed for each of them.
    The chunk files themselves are left on disk.

    Args:
        folder_path (str): Path to the folder containing chunk files.
        chunk_prefix (str): Prefix of the chunk files (e.g., 'masks_chunk', 'masks_cells_chunk').
        output_file (str): Path to save the consolidated .npz file.

    Returns:
        None. Nothing is written when no chunk file is found.
    """
    print(f"\n-> Consolidating chunks with prefix '{chunk_prefix}' in {folder_path}...")

    # Escape the prefix so regex metacharacters in it are matched literally.
    index_pattern = re.compile(rf"{re.escape(chunk_prefix)}_(\d+)\.npz$")

    candidate_names = [
        name for name in os.listdir(folder_path)
        if name.startswith(chunk_prefix) and name.endswith(".npz")
    ]

    if not candidate_names:
        print(f"No chunks found with prefix '{chunk_prefix}'. Skipping.")
        return

    def sort_key(file_name):
        # Numeric index first; un-indexed files sort last, ties broken by name.
        match = index_pattern.search(file_name)
        index = int(match.group(1)) if match else float('inf')
        return index, file_name

    ordered_names = sorted(candidate_names, key=sort_key)
    for name in ordered_names:
        if index_pattern.search(name) is None:
            print(f"[WARNING] '{name}' has no numeric chunk index; it is stacked after the indexed chunks.")

    chunk_files = [os.path.join(folder_path, name) for name in ordered_names]

    # Load all sparse chunks. Every chunk stays referenced by the list until
    # vstack has run, so collecting garbage per chunk would free nothing.
    sparse_matrices = [
        load_npz(chunk_file)
        for chunk_file in tqdm(chunk_files, desc=f"Loading {chunk_prefix}", unit="chunk")
    ]

    # Combine into a single sparse matrix
    print(f"-> Combining {len(sparse_matrices)} chunks...")
    final_sparse_matrix = vstack(sparse_matrices)

    # Save the combined sparse matrix
    print("-> Saving...")
    save_npz(output_file, final_sparse_matrix)

    # Release the (potentially large) matrices before returning to the caller.
    del sparse_matrices, final_sparse_matrix
    gc.collect()

    print("Done.")
72+
73+
74+
75+
def consolidate_npy_chunks(folder_path, file_prefix, output_file):
    """
    Consolidate multiple .npy chunk files into a single .npy file.

    Only files named ``<file_prefix>_<index>.npy`` are used; they are
    concatenated along axis 0 in increasing numeric index order. Files that
    carry the prefix and the ``.npy`` extension but no numeric index are
    reported and ignored instead of crashing the sort.
    The chunk files themselves are left on disk.

    Args:
        folder_path (str): Path to the folder containing chunk files.
        file_prefix (str): Prefix of the chunk files (e.g., 'images_chunk').
        output_file (str): Path to save the consolidated .npy file.

    Returns:
        None. Nothing is written when no usable chunk file is found.
    """
    print(f"\n-> Consolidating .npy chunks with prefix '{file_prefix}' in {folder_path}...")

    # Escape the prefix so regex metacharacters in it are matched literally.
    index_pattern = re.compile(rf"{re.escape(file_prefix)}_(\d+)\.npy$")

    # Collect (index, path) pairs for every chunk file with a parsable index.
    indexed_chunks = []
    for name in sorted(os.listdir(folder_path)):
        if not (name.startswith(file_prefix) and name.endswith(".npy")):
            continue
        match = index_pattern.search(name)
        if match is None:
            print(f"[WARNING] Ignoring '{name}': no numeric chunk index in its name.")
            continue
        indexed_chunks.append((int(match.group(1)), os.path.join(folder_path, name)))

    if not indexed_chunks:
        print(f"No .npy chunks found with prefix '{file_prefix}'. Skipping.")
        return

    chunk_files = [path for _, path in sorted(indexed_chunks)]

    # Load and combine all chunks
    arrays = [
        np.load(chunk_file)
        for chunk_file in tqdm(chunk_files, desc=f"Loading {file_prefix}", unit="chunk")
    ]

    # Concatenate and save the final array
    final_array = np.concatenate(arrays, axis=0)
    np.save(output_file, final_array)

    # Release the (potentially large) arrays before returning to the caller.
    del arrays, final_array
    gc.collect()

    print("Done.")
119+
120+
121+
122+
123+
def consolidate_nested_chunks(folder_path, chunk_prefix, output_file):
    """
    Consolidate nested sparse chunk files into a single sparse .npz file.

    Files named ``<chunk_prefix>_<i>_chunk_<j>.npz`` are stacked vertically,
    ordered by ``i`` and then by ``j`` (both compared numerically). Files that
    carry the prefix and the ``.npz`` extension but do not follow the nested
    naming scheme are still stacked, after all nested chunks and in name
    order; a warning is printed for each of them.
    The chunk files themselves are left on disk.

    Args:
        folder_path (str): Path to the folder containing chunk files.
        chunk_prefix (str): Prefix of the chunk files (e.g., 'masks_chunk').
        output_file (str): Path to save the consolidated .npz file.

    Returns:
        None. Nothing is written when no chunk file is found.
    """
    print(f"\n-> Consolidating nested chunks with prefix '{chunk_prefix}' in {folder_path}...")

    # Escape the prefix so regex metacharacters in it are matched literally.
    indices_pattern = re.compile(rf"{re.escape(chunk_prefix)}_(\d+)_chunk_(\d+)\.npz$")

    candidate_names = [
        name for name in os.listdir(folder_path)
        if name.startswith(chunk_prefix) and name.endswith(".npz")
    ]

    if not candidate_names:
        print(f"No nested chunks found with prefix '{chunk_prefix}'. Skipping.")
        return

    def sort_key(file_name):
        # (i, j) first; files outside the nested scheme sort last, ties broken by name.
        match = indices_pattern.search(file_name)
        if match:
            return int(match.group(1)), int(match.group(2)), file_name
        return float('inf'), float('inf'), file_name

    ordered_names = sorted(candidate_names, key=sort_key)
    for name in ordered_names:
        if indices_pattern.search(name) is None:
            print(f"[WARNING] '{name}' does not follow the nested chunk naming; it is stacked after the nested chunks.")

    nested_chunk_files = [os.path.join(folder_path, name) for name in ordered_names]

    # Load all nested sparse chunks
    sparse_matrices = [
        load_npz(nested_chunk_file)
        for nested_chunk_file in tqdm(nested_chunk_files, desc=f"Loading {chunk_prefix}", unit="nested_chunk")
    ]

    # Combine into a single sparse matrix
    print(f"-> Combining {len(sparse_matrices)} nested chunks...")
    final_sparse_matrix = vstack(sparse_matrices)

    # Save the combined sparse matrix
    print("-> Saving...")
    save_npz(output_file, final_sparse_matrix)

    # Release the (potentially large) matrices before returning to the caller.
    del sparse_matrices, final_sparse_matrix
    gc.collect()

    print("Done.")
180+
181+
182+
183+
184+
def process_slide_folders(slide_ids, folder_name):
    """
    Process all slide folders to consolidate chunked outputs into single files.

    For each slide folder:
      * if ``images.npy`` exists, only the flat mask chunks
        (``masks_chunk``, ``masks_cells_chunk``) are consolidated;
      * otherwise, if ``images_chunk*.npy`` files exist, the image, type and
        patch-id chunks are consolidated and the nested mask chunks are merged;
      * otherwise the slide is reported and skipped.

    Args:
        slide_ids (list): List of slide IDs to process.
        folder_name (str): Path to the parent folder containing slide subfolders.
    """
    for slide_id in slide_ids:
        print(f"\n===== PROCESSING SLIDE: {slide_id} =====")
        slide_folder = os.path.join(folder_name, slide_id)

        if not os.path.exists(slide_folder):
            print(f"Slide folder '{slide_folder}' does not exist. Skipping.")
            continue

        # Check for images.npy or chunked images; only their presence matters here.
        has_single_images_file = os.path.exists(os.path.join(slide_folder, "images.npy"))
        has_chunked_images = any(
            name.startswith("images_chunk") and name.endswith(".npy")
            for name in os.listdir(slide_folder)
        )

        if has_single_images_file:
            print(f"[INFO] Single 'images.npy' file detected for slide {slide_id}.")
            consolidate_chunks(slide_folder, "masks_chunk", os.path.join(slide_folder, "masks.npz"))
            consolidate_chunks(slide_folder, "masks_cells_chunk", os.path.join(slide_folder, "masks_cells.npz"))
        elif has_chunked_images:
            print(f"[INFO] Chunked 'images_chunk' files detected for slide {slide_id}.")
            consolidate_npy_chunks(slide_folder, "images_chunk", os.path.join(slide_folder, "images.npy"))
            consolidate_npy_chunks(slide_folder, "types_chunk", os.path.join(slide_folder, "types.npy"))
            consolidate_npy_chunks(slide_folder, "patch_ids_chunk", os.path.join(slide_folder, "patch_ids.npy"))
            consolidate_nested_chunks(slide_folder, "masks_chunk", os.path.join(slide_folder, "masks.npz"))
            consolidate_nested_chunks(slide_folder, "masks_cells_chunk", os.path.join(slide_folder, "masks_cells.npz"))
        else:
            # Make the no-op visible instead of silently moving to the next slide.
            print(f"[WARNING] No 'images.npy' or 'images_chunk' files found for slide {slide_id}. Skipping.")

    print("\nAll slides processed successfully.")
219+
220+
221+
222+
223+
if __name__ == "__main__":
    # Command-line entry point: parse the slide IDs and the parent folder,
    # then consolidate the chunk files of every requested slide.
    cli = argparse.ArgumentParser(description="Consolidate sparse mask chunks for multiple slides.")

    cli.add_argument(
        "--slide_ids",
        type=str,
        nargs="+",
        required=True,
        help="List of slide IDs to process.",
    )
    cli.add_argument(
        "--folder_name",
        type=str,
        default="/Volumes/DD_FGS/MICS/data_HE2CellType/CT_DS/check_align_patches/patches_xenium",
        help="Parent folder containing slide subfolders.",
    )

    options = cli.parse_args()

    process_slide_folders(options.slide_ids, options.folder_name)
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
"""
2+
Before running this script, you need to apply CellVit on the patches to predict the comparison segmentation mask:
3+
In HE2CT folder:
4+
- Prepare dataset using cell_segmentation/datasets/prepare_pannuke.py using ['Neoplastic','Inflammatory','Connective','Dead','Epithelial'] as classes
5+
- Make the Macenko normalization using cell_segmentation/datasets/macenko_normalization.py (OR NOT ????)
6+
- Convert the dataset to zip using cell_segmentation/datasets/convert_into_zip.py
7+
- Apply CellVit inference (file cell_segmentation/inference/inference_cellvit_experiment_pannuke.py):
8+
- modifying config.yaml
9+
- de-commenting the part to get the predictions for instance_map and pixel count predictions for pannuke label for gt in inference_cellvit_experiment_pannuke.py (cf. See where there is CHOOSE OR NOT)
10+
- In case we want to use real cell type mask instead of fake one : Dynamically adapt mask to handle PanNuke categories in datasets/pannuke.py by uncommenting type_map[type_map != 0] = 1 at the end of the load_maskfile function (cf. CHOOSE)
11+
- and --cell_tokens nucleus will be useful for H&E features after
12+
- and using inference_cellvit_experiment_pannuke.py with in the terminal : python3 cell_segmentation/inference/inference_cellvit_experiment_pannuke.py --run_dir /Volumes/DD_FGS/MICS/data_HE2CellType/CT_DS/check_align_patches/apply_cellvit/output_cellvit/heart_s0 --checkpoint_name CellViT-SAM-H-x40.pth --gpu mps --magnification 40 --cell_tokens nucleus
13+
OR use ruche with slurm_cellvit_checkalign.sh
14+
Then using this file, add metrics for each patch in the sdata object.
15+
!!!! If the slide was too big to run inference in a single pass, first run the script optional_group_output_cellvit.py to group the CellVit outputs into one single file for the given slide. !!!!
16+
"""
17+
18+
import argparse
19+
import os
20+
import json
21+
import pandas as pd
22+
import spatialdata as sd
23+
24+
25+
26+
def open_json_metrics(output_cellvit_folder, slide_id):
    """
    Load the CellVit inference results of one slide.

    Args:
        output_cellvit_folder (str): Folder holding one subfolder per slide.
        slide_id (str): Name of the slide subfolder to read from.

    Returns:
        The parsed content of ``<output_cellvit_folder>/<slide_id>/inference_results.json``.
    """
    relative_path = f'{slide_id}/inference_results.json'
    with open(os.path.join(output_cellvit_folder, relative_path), 'r') as handle:
        return json.load(handle)
33+
34+
35+
36+
37+
def build_df_metrics(metric_json_file):
    """
    Turn the per-image metrics of a CellVit result dictionary into a table.

    Args:
        metric_json_file (dict): Parsed inference results; its 'image_metrics'
            entry maps each image name to a dictionary of metric values.

    Returns:
        pandas.DataFrame: One row per image, with the image name in an
        'image' column followed by one column per metric.
    """
    per_image = pd.DataFrame.from_dict(metric_json_file['image_metrics'], orient='index')

    # Move the image names out of the index into a regular 'image' column.
    return per_image.reset_index().rename(columns={'index': 'image'})
45+
46+
47+
48+
49+
def add_metrics_in_sdata(sdata, df_metrics):
    """
    Attach per-patch metrics to ``sdata.shapes['he_patches']`` and save it.

    The patch id of each metrics row is taken from its image name
    (``'<patch_id>.png'``). Metrics are joined on 'patch_id'; patches without
    a metrics row get -1 for every metric. Existing 'Dice', 'Jaccard' and
    'bPQ' columns are overwritten. The element is then deleted from disk and
    written again.

    Args:
        sdata: Object exposing ``shapes['he_patches']`` (a table with a
            'patch_id' key), ``delete_element_from_disk`` and ``write_element``.
        df_metrics (pandas.DataFrame): Table with the columns 'image', 'Dice',
            'Jaccard' and 'bPQ'. It is not modified.

    Raises:
        ValueError: If several metrics rows map to the same patch, which would
            make the join ambiguous.
    """
    metric_columns = ['Dice', 'Jaccard', 'bPQ']

    # Work on a copy so the caller's DataFrame does not gain a 'patch_id' column.
    metrics = df_metrics[['image'] + metric_columns].copy()
    # regex=False: '.png' must be removed literally, not read as a regular expression.
    metrics['patch_id'] = metrics['image'].str.replace('.png', '', regex=False).astype(int)

    he_patches = sdata.shapes['he_patches']

    # Drop stale metric columns first so the merge cannot produce _x/_y suffixes.
    merged = he_patches.drop(columns=metric_columns, errors='ignore').merge(
        metrics[['patch_id'] + metric_columns], on='patch_id', how='left'
    )

    if len(merged) != len(he_patches):
        raise ValueError("Several metric rows map to the same patch_id; cannot attach metrics unambiguously.")

    # merge() returns a fresh 0..n-1 index, so assign by position (to_numpy)
    # rather than by index label, which would misalign on a non-default index.
    for column in metric_columns:
        # -1 will correspond to no cell in xenium mask
        he_patches[column] = merged[column].fillna(-1).to_numpy()

    print(he_patches.head())
    print("\n\n")
    print(sdata)

    print("\n\nSaving on disk...")
    sdata.delete_element_from_disk("he_patches")
    sdata.write_element("he_patches")
    print("Done.")
71+
72+
73+
74+
75+
def main(args):
    """
    Load the CellVit metrics of one slide and store them in its sdata object.

    Args:
        args: Parsed command-line arguments providing ``slide_id``,
            ``output_cellvit_folder`` and ``sdata_folder``.
    """
    print(f"\n==== Proccessing {args.slide_id} ====")

    # Get metrics output from CellVit
    print("Loading metrics...")
    metric_json_file = open_json_metrics(args.output_cellvit_folder, args.slide_id)
    df_metrics = build_df_metrics(metric_json_file)

    # Load sdata
    print("Loading sdata...")
    sdata_path = os.path.join(args.sdata_folder, f'sdata_{args.slide_id}.zarr')
    sdata = sd.read_zarr(sdata_path, selection=('shapes',))

    # Remove metric columns left over from a previous run. Each column is
    # checked on its own so a missing one does not leave the others in place.
    he_patches = sdata.shapes['he_patches']
    for column in ('Dice', 'Jaccard', 'bPQ'):
        if column in he_patches.columns:
            del he_patches[column]

    print(sdata.shapes['he_patches'].head())

    # Add metrics in sdata
    print("\n\nAdding metrics in sdata...")
    add_metrics_in_sdata(sdata, df_metrics)
101+
102+
103+
104+
if __name__ == "__main__":
    # Command-line entry point: collect the slide id and the two data folders,
    # then hand the parsed options to main().
    cli = argparse.ArgumentParser(description="Add metrics to check alignment patches in sdata")

    cli.add_argument(
        "--slide_id",
        type=str,
        default="heart_s0",
        help="Slide id",
    )
    cli.add_argument(
        "--output_cellvit_folder",
        type=str,
        default="/Volumes/DD_FGS/MICS/data_HE2CellType/CT_DS/check_align_patches/apply_cellvit/output_cellvit",
        help="Output folder of CellVit for align checking",
    )
    cli.add_argument(
        "--sdata_folder",
        type=str,
        default="/Volumes/SAUV_FGS/MICS/data_HE2CellType/CT_DS/sdata_final",
        help="Folder containing final sdata",
    )

    main(cli.parse_args())

0 commit comments

Comments
 (0)