Skip to content

Commit 67b6624

Browse files
author
Donglai Wei
committed
update read_images
1 parent f438659 commit 67b6624

7 files changed

Lines changed: 924 additions & 151 deletions

File tree

connectomics/data/io/io.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,18 @@
99
"""
1010

1111
from __future__ import annotations
12-
from typing import Optional, List, Union
13-
import os
12+
1413
import glob
14+
import os
1515
import pickle
16+
from typing import List, Optional, Union
17+
1618
import h5py
17-
import numpy as np
1819
import imageio
1920
import nibabel as nib
21+
import numpy as np
2022

21-
# Avoid PIL "IOError: image file truncated"
22-
from PIL import ImageFile
23-
24-
ImageFile.LOAD_TRUNCATED_IMAGES = True
25-
23+
from .utils import rgb_to_seg
2624

2725
# =============================================================================
2826
# HDF5 I/O
@@ -107,7 +105,9 @@ def list_hdf5_datasets(filename: str) -> List[str]:
107105
SUPPORTED_IMAGE_FORMATS = ["png", "tif", "tiff", "jpg", "jpeg"]
108106

109107

110-
def read_image(filename: str, add_channel: bool = False) -> Optional[np.ndarray]:
108+
def read_image(
109+
filename: str, add_channel: bool = False, image_type: str = "image"
110+
) -> Optional[np.ndarray]:
111111
"""Read a single image file.
112112
113113
Args:
@@ -121,12 +121,14 @@ def read_image(filename: str, add_channel: bool = False) -> Optional[np.ndarray]
121121
return None
122122

123123
image = imageio.imread(filename)
124+
if image_type == "seg" and image.ndim == 3:
125+
image = rgb_to_seg(image)
124126
if add_channel and image.ndim == 2:
125127
image = image[:, :, None]
126128
return image
127129

128130

129-
def read_images(filename_pattern: str) -> np.ndarray:
131+
def read_images(filename_pattern: str, image_type: str = "image") -> np.ndarray:
130132
"""Read multiple images from a filename pattern.
131133
132134
Args:
@@ -143,17 +145,11 @@ def read_images(filename_pattern: str) -> np.ndarray:
143145
raise ValueError(f"No files found matching pattern: {filename_pattern}")
144146

145147
# Determine array shape from first image
146-
first_image = imageio.imread(file_list[0])
147-
if first_image.ndim == 2:
148-
data = np.zeros((len(file_list), *first_image.shape), dtype=first_image.dtype)
149-
elif first_image.ndim == 3:
150-
data = np.zeros((len(file_list), *first_image.shape), dtype=first_image.dtype)
151-
else:
152-
raise ValueError(f"Unsupported image dimensions: {first_image.ndim}D")
153-
148+
first_image = read_image(file_list[0], image_type=image_type)
149+
data = np.zeros((len(file_list), *first_image.shape), dtype=first_image.dtype)
154150
# Load all images
155151
for i, filepath in enumerate(file_list):
156-
data[i] = imageio.imread(filepath)
152+
data[i] = read_image(filepath, image_type=image_type)
157153

158154
return data
159155

@@ -171,7 +167,7 @@ def read_image_as_volume(filename: str, drop_channel: bool = False) -> np.ndarra
171167
Raises:
172168
ValueError: If file format is not supported
173169
"""
174-
image_suffix = filename[filename.rfind(".") + 1:].lower()
170+
image_suffix = filename[filename.rfind(".") + 1 :].lower()
175171
if image_suffix not in SUPPORTED_IMAGE_FORMATS:
176172
raise ValueError(
177173
f"Unsupported format: {image_suffix}. Supported formats: {SUPPORTED_IMAGE_FORMATS}"
@@ -281,7 +277,7 @@ def read_volume(
281277
if filename.endswith(".nii.gz"):
282278
image_suffix = "nii.gz"
283279
else:
284-
image_suffix = filename[filename.rfind(".") + 1:].lower()
280+
image_suffix = filename[filename.rfind(".") + 1 :].lower()
285281

286282
if image_suffix in ["h5", "hdf5"]:
287283
data = read_hdf5(filename, dataset)
@@ -420,7 +416,7 @@ def get_vol_shape(filename: str, dataset: Optional[str] = None) -> tuple:
420416
if filename.endswith(".nii.gz"):
421417
image_suffix = "nii.gz"
422418
else:
423-
image_suffix = filename[filename.rfind(".") + 1:].lower()
419+
image_suffix = filename[filename.rfind(".") + 1 :].lower()
424420

425421
if image_suffix in ["h5", "hdf5"]:
426422
# HDF5: Read shape from metadata (no data loading)

connectomics/metrics/metrics_seg.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,24 +48,55 @@ class AdaptedRandError(torchmetrics.Metric):
4848
4949
This wrapper lets us accumulate scores during Lightning `test_step` without
5050
manual numpy↔torch conversions in the training loop.
51+
52+
Args:
53+
return_all_stats: If True, also compute and return precision and recall
54+
dist_sync_on_step: Whether to sync across distributed processes on each step
5155
"""
5256

5357
full_state_update: bool = False
5458

55-
def __init__(self, dist_sync_on_step: bool = False) -> None:
59+
def __init__(self, return_all_stats: bool = False, dist_sync_on_step: bool = False) -> None:
5660
super().__init__(dist_sync_on_step=dist_sync_on_step)
61+
self.return_all_stats = return_all_stats
62+
5763
self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
5864
self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
5965

66+
if return_all_stats:
67+
self.add_state("total_precision", default=torch.tensor(0.0), dist_reduce_fx="sum")
68+
self.add_state("total_recall", default=torch.tensor(0.0), dist_reduce_fx="sum")
69+
6070
def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
6171
# Move to CPU and numpy for the underlying implementation
6272
preds_np = preds.detach().cpu().numpy()
6373
target_np = target.detach().cpu().numpy()
64-
score = float(adapted_rand(preds_np, target_np))
65-
self.total += torch.tensor(score, device=self.total.device)
74+
75+
if self.return_all_stats:
76+
are, precision, recall = adapted_rand(preds_np, target_np, all_stats=True)
77+
self.total += torch.tensor(are, device=self.total.device)
78+
self.total_precision += torch.tensor(precision, device=self.total_precision.device)
79+
self.total_recall += torch.tensor(recall, device=self.total_recall.device)
80+
else:
81+
score = float(adapted_rand(preds_np, target_np, all_stats=False))
82+
self.total += torch.tensor(score, device=self.total.device)
83+
6684
self.count += 1
6785

6886
def compute(self) -> torch.Tensor:
6987
if self.count == 0:
88+
if self.return_all_stats:
89+
return {
90+
"adapted_rand_error": torch.tensor(0.0, device=self.total.device),
91+
"adapted_rand_precision": torch.tensor(0.0, device=self.total.device),
92+
"adapted_rand_recall": torch.tensor(0.0, device=self.total.device),
93+
}
7094
return torch.tensor(0.0, device=self.total.device)
95+
96+
if self.return_all_stats:
97+
return {
98+
"adapted_rand_error": self.total / self.count,
99+
"adapted_rand_precision": self.total_precision / self.count,
100+
"adapted_rand_recall": self.total_recall / self.count,
101+
}
71102
return self.total / self.count

connectomics/training/lit/model.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,8 @@ def _setup_test_metrics(self):
194194
self.test_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes).to(self.device)
195195
if 'adapted_rand' in metrics:
196196
from ...metrics.metrics_seg import AdaptedRandError
197-
self.test_adapted_rand = AdaptedRandError().to(self.device)
197+
# Enable return_all_stats to also compute precision and recall
198+
self.test_adapted_rand = AdaptedRandError(return_all_stats=True).to(self.device)
198199

199200
def _invert_save_prediction_transform(self, data: np.ndarray) -> np.ndarray:
200201
"""
@@ -349,14 +350,30 @@ def _compute_test_metrics(self, decoded_predictions: np.ndarray, labels: torch.T
349350
# Adapted Rand Error is for instance segmentation
350351
if hasattr(self, "test_adapted_rand") and isinstance(self.test_adapted_rand, torchmetrics.Metric):
351352
from ...metrics.metrics_seg import AdaptedRandError
352-
per_volume_metric = AdaptedRandError().to(self.device)
353+
# Use return_all_stats=True to get precision and recall
354+
per_volume_metric = AdaptedRandError(return_all_stats=True).to(self.device)
353355
per_volume_metric.update(pred_instances.cpu(), labels_instances.cpu())
354-
adapted_rand_value = per_volume_metric.compute()
355-
print(f" {volume_prefix}Adapted Rand Error: {adapted_rand_value.item():.6f}")
356+
adapted_rand_stats = per_volume_metric.compute()
357+
358+
# Print per-volume metrics
359+
if isinstance(adapted_rand_stats, dict):
360+
print(f" {volume_prefix}Adapted Rand Error: {adapted_rand_stats['adapted_rand_error'].item():.6f}")
361+
print(f" {volume_prefix}Adapted Rand Precision: {adapted_rand_stats['adapted_rand_precision'].item():.6f}")
362+
print(f" {volume_prefix}Adapted Rand Recall: {adapted_rand_stats['adapted_rand_recall'].item():.6f}")
363+
else:
364+
print(f" {volume_prefix}Adapted Rand Error: {adapted_rand_stats.item():.6f}")
356365

357366
# Update running metric for epoch-level aggregation
358367
self.test_adapted_rand.update(pred_instances.cpu(), labels_instances.cpu())
359-
self.log("test_adapted_rand", self.test_adapted_rand, on_step=False, on_epoch=True, prog_bar=True, logger=True)
368+
369+
# Log metrics - handle both dict and tensor return values
370+
epoch_stats = self.test_adapted_rand.compute()
371+
if isinstance(epoch_stats, dict):
372+
self.log("test_adapted_rand", epoch_stats['adapted_rand_error'], on_step=False, on_epoch=True, prog_bar=True, logger=True)
373+
self.log("test_adapted_rand_precision", epoch_stats['adapted_rand_precision'], on_step=False, on_epoch=True, prog_bar=True, logger=True)
374+
self.log("test_adapted_rand_recall", epoch_stats['adapted_rand_recall'], on_step=False, on_epoch=True, prog_bar=True, logger=True)
375+
else:
376+
self.log("test_adapted_rand", epoch_stats, on_step=False, on_epoch=True, prog_bar=True, logger=True)
360377

361378
else:
362379
# For binary/semantic segmentation: binarize predictions

scripts/images_to_h5.py

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,18 @@
2525
Note: Use quotes around the input pattern to prevent shell expansion.
2626
"""
2727

28-
import sys
2928
import os
30-
from pathlib import Path
31-
from connectomics.data.io import read_volume, write_hdf5
29+
import sys
30+
31+
from connectomics.data.io import read_images, write_hdf5
3232

3333

3434
def main():
3535
"""Main conversion function."""
3636
if len(sys.argv) < 3:
37-
print("Usage: python scripts/images_to_h5.py <input_pattern> <output_file.h5> [dataset_key]")
37+
print(
38+
"Usage: python scripts/images_to_h5.py <input_pattern> <output_file.h5> [image_type] [dataset_key]"
39+
)
3840
print("")
3941
print("Examples:")
4042
print(' python scripts/images_to_h5.py "datasets/images/*.tiff" output.h5')
@@ -46,48 +48,27 @@ def main():
4648

4749
input_pattern = sys.argv[1]
4850
output_file = sys.argv[2]
49-
dataset_key = sys.argv[3] if len(sys.argv) > 3 else "main"
51+
image_type = sys.argv[3] if len(sys.argv) > 3 else "image"
52+
dataset_key = sys.argv[4] if len(sys.argv) > 4 else "main"
5053

5154
# Ensure output directory exists
5255
output_dir = os.path.dirname(output_file)
5356
if output_dir and not os.path.exists(output_dir):
5457
print(f"Creating output directory: {output_dir}")
5558
os.makedirs(output_dir, exist_ok=True)
5659

57-
# Detect file format from pattern
58-
pattern_lower = input_pattern.lower()
59-
if any(ext in pattern_lower for ext in ['.tif', '.tiff']):
60-
format_name = "TIFF"
61-
elif '.png' in pattern_lower:
62-
format_name = "PNG"
63-
elif any(ext in pattern_lower for ext in ['.jpg', '.jpeg']):
64-
format_name = "JPEG"
65-
else:
66-
format_name = "image"
67-
68-
print(f"Reading {format_name} files matching: {input_pattern}")
69-
print("This may take a while for large volumes...")
70-
7160
# Read all image files as a 3D volume
72-
try:
73-
volume = read_volume(input_pattern)
74-
except Exception as e:
75-
print(f"Error reading images: {e}")
76-
print("\nTips:")
77-
print(" - Check that the file pattern is correct")
78-
print(" - Ensure all images have the same dimensions")
79-
print(" - Verify the image files are readable")
80-
sys.exit(1)
61+
volume = read_images(input_pattern, image_type=image_type)
8162

82-
print(f"\n{'='*60}")
83-
print(f"Volume Information:")
84-
print(f"{'='*60}")
63+
print(f"\n{'=' * 60}")
64+
print("Volume Information:")
65+
print(f"{'=' * 60}")
8566
print(f" Shape: {volume.shape}")
8667
print(f" Data type: {volume.dtype}")
8768
print(f" Size: {volume.nbytes / (1024**3):.2f} GB")
8869
print(f" Min value: {volume.min()}")
8970
print(f" Max value: {volume.max()}")
90-
print(f"{'='*60}")
71+
print(f"{'=' * 60}")
9172

9273
print(f"\nSaving to: {output_file}")
9374
print(f"Dataset key: '{dataset_key}'")

0 commit comments

Comments
 (0)