Skip to content

Commit 80038df

Browse files
authored
load_inspection_annotations() is completely non-functional (#86)
* Refactored `InspectionAnnotations` save/load methods to streamline CSV handling. (#81) * Enhanced `load_inspection_annotations` to support `background` parameter and improved row parsing with `literal_eval()`. (#81) * Refactored row parsing in `load_inspection_annotations()` with `literal_eval()`. (#81) * Enhanced `inspect()` with progress tracking using `rich.progress` and added `console` parameter for customizable output. (#81) * Added `console` parameter to `Progress` instances in `training.py` for consistent output handling. (#81) * Enhanced progress display in `inspect()` by adding `SpinnerColumn` for improved visual feedback with `rich.progress`. (#81) * Refactored `InspectionAnnotations` save/load methods to use JSON format instead of CSV for correct serialization. Added `_lists_to_tuples` helper for parsing JSON objects. (#81) * Updated `save` and `load_inspection_annotations` methods in `InspectionAnnotations` to include `background` in serialization and deserialization logic. (#81)
1 parent b473ba1 commit 80038df

2 files changed

Lines changed: 33 additions & 26 deletions

File tree

mipcandy/data/inspection.py

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
from dataclasses import dataclass
1+
from dataclasses import dataclass, asdict
2+
from json import dump, load
23
from os import PathLike
3-
from typing import Sequence, override, Callable, Self
4+
from typing import Sequence, override, Callable, Self, Any
45

56
import numpy as np
67
import torch
7-
from pandas import DataFrame
8+
from rich.console import Console
9+
from rich.progress import Progress, SpinnerColumn
810
from torch import nn
911

1012
from mipcandy.data.dataset import SupervisedDataset
@@ -37,6 +39,9 @@ def center_of_foreground(self) -> tuple[int, int] | tuple[int, int, int]:
3739
round((self.foreground_bbox[3] + self.foreground_bbox[2]) * .5))
3840
return r if len(self.shape) == 2 else r + (round((self.foreground_bbox[5] + self.foreground_bbox[4]) * .5),)
3941

42+
def to_dict(self) -> dict[str, tuple[int, ...]]:
43+
return asdict(self)
44+
4045

4146
class InspectionAnnotations(HasDevice, Sequence[InspectionAnnotation]):
4247
def __init__(self, dataset: SupervisedDataset, background: int, *annotations: InspectionAnnotation,
@@ -65,10 +70,8 @@ def __len__(self) -> int:
6570
return len(self._annotations)
6671

6772
def save(self, path: str | PathLike[str]) -> None:
68-
r = []
69-
for annotation in self._annotations:
70-
r.append({"foreground_bbox": annotation.foreground_bbox, "ids": annotation.ids})
71-
DataFrame(r).to_csv(path, index=False)
73+
with open(path, "w") as f:
74+
dump({"background": self._background, "annotations": self._annotations}, f)
7275

7376
def _get_shapes(self, get_shape: Callable[[InspectionAnnotation], tuple[int, ...]]) -> tuple[
7477
tuple[int, ...] | None, tuple[int, ...], tuple[int, ...]]:
@@ -208,27 +211,31 @@ def crop_roi(self, i: int, *, percentile: float = .95) -> tuple[torch.Tensor, to
208211
return crop(image.unsqueeze(0), roi).squeeze(0), crop(label.unsqueeze(0), roi).squeeze(0)
209212

210213

211-
def load_inspection_annotations(path: str | PathLike[str]) -> InspectionAnnotations:
212-
df = DataFrame.from_csv(path)
213-
return InspectionAnnotations(*(
214-
InspectionAnnotation(
215-
tuple(row["shape"]), format_bbox(row["foreground_bbox"]), tuple(row["ids"])
216-
) for _, row in df.iterrows()
214+
def _lists_to_tuples(pairs: Sequence[tuple[str, Any]]) -> dict[str, Any]:
215+
return {k: tuple(v) if isinstance(v, list) else v for k, v in pairs}
216+
217+
218+
def load_inspection_annotations(path: str | PathLike[str], dataset: SupervisedDataset) -> InspectionAnnotations:
219+
with open(path) as f:
220+
obj = load(f, object_pairs_hook=_lists_to_tuples)
221+
return InspectionAnnotations(dataset, obj["background"], *(
222+
InspectionAnnotation(**row) for row in obj["annotations"]
217223
))
218224

219225

220-
def inspect(dataset: SupervisedDataset, *, background: int = 0) -> InspectionAnnotations:
226+
def inspect(dataset: SupervisedDataset, *, background: int = 0, console: Console = Console()) -> InspectionAnnotations:
221227
r = []
222-
for _, label in dataset:
223-
indices = (label != background).nonzero()
224-
mins = indices.min(dim=0)[0].tolist()
225-
maxs = indices.max(dim=0)[0].tolist()
226-
bbox = (mins[1], maxs[1], mins[2], maxs[2])
227-
r.append(InspectionAnnotation(
228-
label.shape[1:],
229-
bbox if label.ndim == 3 else bbox + (mins[3], maxs[3]),
230-
tuple(label.unique())
231-
))
228+
with Progress(*Progress.get_default_columns(), SpinnerColumn(), console=console) as progress:
229+
task = progress.add_task("Inspecting dataset...", total=len(dataset))
230+
for _, label in dataset:
231+
progress.update(task, advance=1, description=f"Inspecting dataset {tuple(label.shape)}")
232+
indices = (label != background).nonzero()
233+
mins = indices.min(dim=0)[0].tolist()
234+
maxs = indices.max(dim=0)[0].tolist()
235+
bbox = (mins[1], maxs[1], mins[2], maxs[2])
236+
r.append(InspectionAnnotation(
237+
label.shape[1:], bbox if label.ndim == 3 else bbox + (mins[3], maxs[3]), tuple(label.unique())
238+
))
232239
return InspectionAnnotations(dataset, background, *r, device=dataset.device())
233240

234241

mipcandy/training.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def train_epoch(self, epoch: int, toolbox: TrainerToolbox) -> None:
306306
toolbox.model.train()
307307
if toolbox.ema:
308308
toolbox.ema.train()
309-
with Progress(*Progress.get_default_columns(), SpinnerColumn()) as progress:
309+
with Progress(*Progress.get_default_columns(), SpinnerColumn(), console=self._console) as progress:
310310
epoch_prog = progress.add_task(f"Epoch {epoch}", total=len(self._dataloader))
311311
for images, labels in self._dataloader:
312312
images, labels = images.to(self._device), labels.to(self._device)
@@ -439,7 +439,7 @@ def validate(self, toolbox: TrainerToolbox) -> tuple[float, dict[str, list[float
439439
worst_score = float("+inf")
440440
metrics = {}
441441
num_cases = len(self._validation_dataloader)
442-
with Progress(*Progress.get_default_columns(), SpinnerColumn()) as progress:
442+
with Progress(*Progress.get_default_columns(), SpinnerColumn(), console=self._console) as progress:
443443
val_prog = progress.add_task(f"Validating", total=num_cases)
444444
for image, label in self._validation_dataloader:
445445
image, label = image.to(self._device), label.to(self._device)

0 commit comments

Comments
 (0)