Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from tests.test_annotation_stores import cell_polygon
from tiatoolbox import utils
from tiatoolbox.annotation.storage import SQLiteStore
from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.utils import misc
from tiatoolbox.utils.exceptions import FileNotSupportedError
Expand Down Expand Up @@ -734,6 +735,7 @@ def test_sub_pixel_read_incorrect_read_func_return() -> None:
image = np.ones((10, 10))

def read_func(*args: tuple, **kwargs: dict) -> np.ndarray: # noqa: ARG001
"""Dummy read function for tests."""
return np.ones((5, 5))

with pytest.raises(ValueError, match="incorrect size"):
Expand All @@ -752,6 +754,7 @@ def test_sub_pixel_read_empty_read_func_return() -> None:
image = np.ones((10, 10))

def read_func(*args: tuple, **kwargs: dict) -> np.ndarray: # noqa: ARG001
"""Dummy read function for tests."""
return np.ones((0, 0))

with pytest.raises(ValueError, match="is empty"):
Expand Down Expand Up @@ -1642,3 +1645,69 @@ def test_imwrite(tmp_path: Path) -> NoReturn:
tmp_path / "thisfolderdoesnotexist" / "test_imwrite.jpg",
img,
)


def test_patch_pred_store() -> None:
"""Test patch_pred_store."""
# Define a mock patch_output
patch_output = {
"predictions": [1, 0, 1],
"coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)],
"other": "other",
}

store = misc.patch_pred_store(patch_output, (1.0, 1.0))

# Check that its an SQLiteStore containing the expected annotations
assert isinstance(store, SQLiteStore)
assert len(store) == 3
for annotation in store.values():
assert annotation.geometry.area == 1
assert annotation.properties["type"] in [0, 1]
assert "other" not in annotation.properties

patch_output.pop("coordinates")
# check correct error is raised if coordinates are missing
with pytest.raises(ValueError, match="coordinates"):
misc.patch_pred_store(patch_output, (1.0, 1.0))


def test_patch_pred_store_cdict() -> None:
"""Test patch_pred_store with a class dict."""
# Define a mock patch_output
patch_output = {
"predictions": [1, 0, 1],
"coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)],
"probabilities": [[0.1, 0.9], [0.9, 0.1], [0.4, 0.6]],
"labels": [1, 0, 1],
"other": "other",
}
class_dict = {0: "class0", 1: "class1"}
store = misc.patch_pred_store(patch_output, (1.0, 1.0), class_dict=class_dict)

# Check that its an SQLiteStore containing the expected annotations
assert isinstance(store, SQLiteStore)
assert len(store) == 3
for annotation in store.values():
assert annotation.geometry.area == 1
assert annotation.properties["label"] in ["class0", "class1"]
assert annotation.properties["type"] in ["class0", "class1"]
assert "other" not in annotation.properties


def test_patch_pred_store_sf() -> None:
"""Test patch_pred_store with scale factor."""
# Define a mock patch_output
patch_output = {
"predictions": [1, 0, 1],
"coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)],
"probabilities": [[0.1, 0.9], [0.9, 0.1], [0.4, 0.6]],
"labels": [1, 0, 1],
}
store = misc.patch_pred_store(patch_output, (2.0, 2.0))

# Check that its an SQLiteStore containing the expected annotations
assert isinstance(store, SQLiteStore)
assert len(store) == 3
for annotation in store.values():
assert annotation.geometry.area == 4
68 changes: 66 additions & 2 deletions tiatoolbox/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import yaml
from filelock import FileLock
from shapely.affinity import translate
from shapely.geometry import Polygon
from shapely.geometry import shape as feature2geometry
from skimage import exposure

Expand Down Expand Up @@ -860,7 +861,8 @@ def select_device(*, on_gpu: bool) -> str:
"""Selects the appropriate device as requested.

Args:
on_gpu (bool): Selects gpu if True.
on_gpu (bool):
Selects gpu if True.

Returns:
str:
Expand All @@ -883,7 +885,6 @@ def model_to(model: torch.nn.Module, *, on_gpu: bool) -> torch.nn.Module:
Returns:
torch.nn.Module:
The model after being moved to cpu/gpu.

"""
if on_gpu: # DataParallel work only for cuda
model = torch.nn.DataParallel(model)
Expand Down Expand Up @@ -1194,3 +1195,66 @@ def add_from_dat(

logger.info("Added %d annotations.", len(anns))
store.append_many(anns)


def patch_pred_store(
patch_output: dict,
scale_factor: tuple[int, int],
class_dict: dict | None = None,
) -> AnnotationStore:
"""Create an SQLiteStore containing Annotations for each patch.

Args:
patch_output (dict): A dictionary of patch prediction information. Important
keys are "probabilities", "predictions", "coordinates", and "labels".
scale_factor (tuple[int, int]): The scale factor to use when loading the
annotations. All coordinates will be multiplied by this factor to allow
conversion of annotations saved at non-baseline resolution to baseline.
Should be model_mpp/slide_mpp.
class_dict (dict): Optional dictionary mapping class indices to class names.

Returns:
SQLiteStore: An SQLiteStore containing Annotations for each patch.

"""
if "coordinates" not in patch_output:
# we cant create annotations without coordinates
msg = "Patch output must contain coordinates."
raise ValueError(msg)
# get relevant keys
class_probs = patch_output.get("probabilities", [])
preds = patch_output.get("predictions", [])
patch_coords = np.array(patch_output.get("coordinates", []))
if not np.all(np.array(scale_factor) == 1):
patch_coords = patch_coords * (np.tile(scale_factor, 2)) # to baseline mpp
labels = patch_output.get("labels", [])
# get classes to consider
if len(class_probs) == 0:
classes_predicted = np.unique(preds).tolist()
else:
classes_predicted = range(len(class_probs[0]))
if class_dict is None:
# if no class dict create a default one
class_dict = {i: i for i in np.unique(preds + labels).tolist()}
annotations = []
# find what keys we need to save
keys = ["predictions"]
keys = keys + [key for key in ["probabilities", "labels"] if key in patch_output]

# put patch predictions into a store
annotations = []
for i, pred in enumerate(preds):
if "probabilities" in keys:
props = {
f"prob_{class_dict[j]}": class_probs[i][j] for j in classes_predicted
}
else:
props = {}
if "labels" in keys:
props["label"] = class_dict[labels[i]]
props["type"] = class_dict[pred]
annotations.append(Annotation(Polygon.from_bounds(*patch_coords[i]), props))
store = SQLiteStore()
keys = store.append_many(annotations, [str(i) for i in range(len(annotations))])

return store