|
1 | | -from dataclasses import dataclass |
| 1 | +from dataclasses import dataclass, asdict |
| 2 | +from json import dump, load |
2 | 3 | from os import PathLike |
3 | | -from typing import Sequence, override, Callable, Self |
| 4 | +from typing import Sequence, override, Callable, Self, Any |
4 | 5 |
|
5 | 6 | import numpy as np |
6 | 7 | import torch |
7 | | -from pandas import DataFrame |
| 8 | +from rich.console import Console |
| 9 | +from rich.progress import Progress, SpinnerColumn |
8 | 10 | from torch import nn |
9 | 11 |
|
10 | 12 | from mipcandy.data.dataset import SupervisedDataset |
@@ -37,6 +39,9 @@ def center_of_foreground(self) -> tuple[int, int] | tuple[int, int, int]: |
37 | 39 | round((self.foreground_bbox[3] + self.foreground_bbox[2]) * .5)) |
38 | 40 | return r if len(self.shape) == 2 else r + (round((self.foreground_bbox[5] + self.foreground_bbox[4]) * .5),) |
39 | 41 |
|
def to_dict(self) -> dict[str, tuple[int, ...]]:
    """
    Return the fields of this annotation as a plain dictionary.

    ``asdict`` copies element objects through unchanged, so anything that is not a built-in ``int`` (for
    example the 0-d tensors obtained by iterating ``label.unique()``) would end up in the result and make
    it impossible to serialise with ``json``. Elements of every tuple/list field are therefore coerced
    with ``int`` so that the result matches the declared return type.

    :return: mapping of field name to a tuple of built-in ints; non-sequence fields are passed through
    """
    return {
        name: tuple(int(item) for item in value) if isinstance(value, (tuple, list)) else value
        for name, value in asdict(self).items()
    }
| 44 | + |
40 | 45 |
|
41 | 46 | class InspectionAnnotations(HasDevice, Sequence[InspectionAnnotation]): |
42 | 47 | def __init__(self, dataset: SupervisedDataset, background: int, *annotations: InspectionAnnotation, |
@@ -65,10 +70,8 @@ def __len__(self) -> int: |
65 | 70 | return len(self._annotations) |
66 | 71 |
|
67 | 72 | def save(self, path: str | PathLike[str]) -> None: |
68 | | - r = [] |
69 | | - for annotation in self._annotations: |
70 | | - r.append({"foreground_bbox": annotation.foreground_bbox, "ids": annotation.ids}) |
71 | | - DataFrame(r).to_csv(path, index=False) |
| 73 | + with open(path, "w") as f: |
| 74 | + dump({"background": self._background, "annotations": self._annotations}, f) |
72 | 75 |
|
73 | 76 | def _get_shapes(self, get_shape: Callable[[InspectionAnnotation], tuple[int, ...]]) -> tuple[ |
74 | 77 | tuple[int, ...] | None, tuple[int, ...], tuple[int, ...]]: |
@@ -208,27 +211,31 @@ def crop_roi(self, i: int, *, percentile: float = .95) -> tuple[torch.Tensor, to |
208 | 211 | return crop(image.unsqueeze(0), roi).squeeze(0), crop(label.unsqueeze(0), roi).squeeze(0) |
209 | 212 |
|
210 | 213 |
|
211 | | -def load_inspection_annotations(path: str | PathLike[str]) -> InspectionAnnotations: |
212 | | - df = DataFrame.from_csv(path) |
213 | | - return InspectionAnnotations(*( |
214 | | - InspectionAnnotation( |
215 | | - tuple(row["shape"]), format_bbox(row["foreground_bbox"]), tuple(row["ids"]) |
216 | | - ) for _, row in df.iterrows() |
| 214 | +def _lists_to_tuples(pairs: Sequence[tuple[str, Any]]) -> dict[str, Any]: |
| 215 | + return {k: tuple(v) if isinstance(v, list) else v for k, v in pairs} |
| 216 | + |
| 217 | + |
def load_inspection_annotations(path: str | PathLike[str], dataset: SupervisedDataset) -> InspectionAnnotations:
    """
    Load inspection annotations from a JSON file.

    The file must contain an object with the keys ``"background"`` and ``"annotations"``; every entry of
    ``"annotations"`` is passed as keyword arguments to ``InspectionAnnotation``. JSON arrays are decoded
    as tuples.

    :param path: JSON file to read
    :param dataset: the dataset the annotations belong to
    :return: the annotations bound to ``dataset``
    :raises KeyError: if ``"background"`` or ``"annotations"`` is missing from the file
    """
    with open(path) as f:
        obj = load(f, object_pairs_hook=_lists_to_tuples)
    # `inspect` builds its result on the dataset's device; do the same here so both paths agree.
    return InspectionAnnotations(
        dataset, obj["background"], *(InspectionAnnotation(**row) for row in obj["annotations"]),
        device=dataset.device()
    )
218 | 224 |
|
219 | 225 |
|
def inspect(dataset: SupervisedDataset, *, background: int = 0,
            console: Console | None = None) -> InspectionAnnotations:
    """
    Scan every label of ``dataset`` and record its shape, foreground bounding box and label ids.

    :param dataset: dataset yielding ``(image, label)`` pairs
    :param background: label value that is treated as background
    :param console: console used to render the progress bar; a new ``Console`` is created when omitted
    :return: one annotation per dataset item, bound to ``dataset`` and placed on its device
    """
    # A `Console()` default would be created once at import time and shared by all calls.
    if console is None:
        console = Console()
    annotations = []
    with Progress(*Progress.get_default_columns(), SpinnerColumn(), console=console) as progress:
        task = progress.add_task("Inspecting dataset...", total=len(dataset))
        for _, label in dataset:
            progress.update(task, description=f"Inspecting dataset {tuple(label.shape)}")
            # NOTE(review): a label without any foreground yields no indices and the reductions below
            # fail - confirm how such samples should be handled.
            indices = (label != background).nonzero()
            mins = indices.min(dim=0)[0].tolist()
            maxs = indices.max(dim=0)[0].tolist()
            # Index 0 of `label` is skipped; the box is built from the remaining axes.
            bbox = (mins[1], maxs[1], mins[2], maxs[2])
            annotations.append(InspectionAnnotation(
                # Plain tuples and Python numbers instead of `torch.Size` / 0-d tensors keep the annotation
                # serialisable.
                tuple(label.shape[1:]),
                bbox if label.ndim == 3 else bbox + (mins[3], maxs[3]),
                tuple(label.unique().tolist()),
            ))
            # Advance only once the sample has actually been processed.
            progress.update(task, advance=1)
    return InspectionAnnotations(dataset, background, *annotations, device=dataset.device())
233 | 240 |
|
234 | 241 |
|
|
0 commit comments