Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
a28c7c9
init
selmanozleyen Jun 30, 2025
225d593
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 30, 2025
e549b4b
fix mypy linterrors
selmanozleyen Jun 30, 2025
2aad72b
update the location and the design
selmanozleyen Jul 3, 2025
ef74057
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2025
d6e22cb
update docs
selmanozleyen Jul 3, 2025
46c41db
Merge branch 'feature/filter_operations_on_label' of https://github.c…
selmanozleyen Jul 3, 2025
80d95a2
make coverage 100/100 because why not
selmanozleyen Jul 3, 2025
4438605
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2025
4c927ee
fixed type annotation
selmanozleyen Jul 10, 2025
e9e0da2
dont compute eagerly use. delete other instance key for consistency
selmanozleyen Jul 10, 2025
7534c91
update the tests and make sure we use match_element_to_table
selmanozleyen Jul 14, 2025
b4901cb
Merge branch 'main' into feature/filter_operations_on_label
selmanozleyen Aug 21, 2025
b21a0a1
Merge branch 'main' into feature/filter_operations_on_label
selmanozleyen Oct 20, 2025
631fe2a
Merge branch 'main' into feature/filter_operations_on_label
selmanozleyen Nov 3, 2025
908f7b4
wip rewrite tests using existing APIs
LucaMarconato May 11, 2026
6f7e468
Merge branch 'main' into feature/filter_operations_on_label
LucaMarconato May 11, 2026
359d553
test passing without using subset_sdata_by_table_mask()
LucaMarconato May 11, 2026
71e27ef
Remove _filter_by_instance_ids and _get_scale_factors; refactor tests…
LucaMarconato May 11, 2026
264ad2a
Add filter_label_pixels flag to match_sdata_to_table and filter_by_ta…
LucaMarconato May 11, 2026
7f81db0
Change filter_label_pixels default to None; False silences the warning
LucaMarconato May 11, 2026
6cab962
Move and consolidate label-filtering tests into test_relational_query…
LucaMarconato May 11, 2026
8d6552e
Add 3D labels guard to _filter_labels_element; fix annsel list predic…
LucaMarconato May 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions src/spatialdata/_core/query/masking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from functools import partial

import numpy as np
import xarray as xr
from geopandas import GeoDataFrame
from xarray.core.dataarray import DataArray
from xarray.core.datatree import DataTree

from spatialdata.models import Labels2DModel, ShapesModel


def _mask_block(block: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray:
# Use apply_ufunc for efficient processing
# Create a copy to avoid modifying read-only array
result = block.copy()
result[np.isin(result, ids_to_remove)] = 0
return result


def _set_instance_ids_in_labels_to_zero(image: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray:
    """
    Return a new ``DataArray`` in which the given instance ids are replaced by zero.

    The masking kernel is applied over the ``("y", "x")`` core dimensions via
    ``xr.apply_ufunc`` and the result is computed eagerly.

    Parameters
    ----------
    image
        The labels image to mask.
    ids_to_remove
        The instance ids to set to zero.

    Returns
    -------
    A ``DataArray`` holding the computed data, with the coordinates, dimensions
    and a shallow copy of the attributes of ``image``.
    """
    zero_out = partial(_mask_block, ids_to_remove=ids_to_remove)
    spatial_dims = ["y", "x"]

    # NOTE(review): ``.compute()`` materializes the whole masked image in memory;
    # this is intentional here so that the returned object holds concrete data.
    materialized = xr.apply_ufunc(
        zero_out,
        image,
        input_core_dims=[spatial_dims],
        output_core_dims=[spatial_dims],
        vectorize=True,
        dask="parallelized",
        dask_gufunc_kwargs={"allow_rechunk": True},
        output_dtypes=[image.dtype],
        dataset_fill_value=0,
    ).compute()

    # Rebuild the array from the computed data while carrying over the metadata
    # of the input (the attributes are copied so the input is left untouched).
    return xr.DataArray(
        data=materialized.data,
        coords=image.coords,
        dims=image.dims,
        attrs=dict(image.attrs),
    )


def _get_scale_factors(labels_element: Labels2DModel) -> list[tuple[float, float]]:
scales = list(labels_element.keys())

# Calculate relative scale factors between consecutive scales
scale_factors = []
for i in range(len(scales) - 1):
y_size_current = labels_element[scales[i]].image.shape[0]
x_size_current = labels_element[scales[i]].image.shape[1]
y_size_next = labels_element[scales[i + 1]].image.shape[0]
x_size_next = labels_element[scales[i + 1]].image.shape[1]
y_factor = y_size_current / y_size_next
x_factor = x_size_current / x_size_next

scale_factors.append((y_factor, x_factor))

return scale_factors


def filter_shapesmodel_by_instance_ids(element: ShapesModel, ids_to_remove: list[int] | list[str]) -> GeoDataFrame:
    """
    Filter a ShapesModel by instance ids.

    Rows whose index value is contained in ``ids_to_remove`` are dropped and the
    remaining rows are passed through ``ShapesModel.parse``.

    Parameters
    ----------
    element
        The ShapesModel to filter. Its index holds the instance ids.
    ids_to_remove
        The instance ids (index values) to remove. Ids that are not present in
        the index are ignored.

    Returns
    -------
    The result of ``ShapesModel.parse`` on the rows that were kept.
    """
    # Boolean indexing builds a new frame, so the input element is left as is.
    element2: GeoDataFrame = element[~element.index.isin(ids_to_remove)]  # type: ignore[index, attr-defined]
    return ShapesModel.parse(element2)


def filter_labels2dmodel_by_instance_ids(element: Labels2DModel, ids_to_remove: list[int]) -> DataArray | DataTree:
    """
    Filter a Labels2DModel by instance ids.

    The pixels belonging to the given instance ids are set to zero. Both
    single-scale (``DataArray``) and multiscale (``DataTree``) elements are
    supported.

    Parameters
    ----------
    element
        The Labels2DModel to filter.
    ids_to_remove
        The instance ids to remove.

    Returns
    -------
    The filtered Labels2DModel.

    Raises
    ------
    ValueError
        If ``element`` is neither a ``DataArray`` nor a ``DataTree``.
    """
    if isinstance(element, xr.DataArray):
        masked_single_scale = _set_instance_ids_in_labels_to_zero(element, ids_to_remove)
        return Labels2DModel.parse(masked_single_scale)

    if not isinstance(element, DataTree):
        raise ValueError(f"Unknown element type: {type(element)}")

    # Multiscale case: only the first scale is masked; the remaining scales are
    # rebuilt by the parser from the scale factors measured on the input.
    first_scale = list(element.keys())[0]
    # NOTE(review): only the first (y) ratio of each pair is used and it is
    # truncated with ``int``; presumably the scales are isotropic with integer
    # factors - confirm.
    per_level_factors = [int(y_ratio) for y_ratio, _ in _get_scale_factors(element)]
    masked_first_scale = _set_instance_ids_in_labels_to_zero(element[first_scale].image, ids_to_remove)

    return Labels2DModel.parse(data=masked_first_scale, scale_factors=per_level_factors)
42 changes: 42 additions & 0 deletions tests/core/query/test_masking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import numpy as np

from spatialdata._core.query.masking import filter_labels2dmodel_by_instance_ids, filter_shapesmodel_by_instance_ids
from spatialdata.datasets import blobs_annotating_element


def test_filter_labels2dmodel_by_instance_ids():
    """Check label filtering on single-scale and multiscale labels, and that the input is untouched."""
    ids_to_remove = [2, 3]

    sdata = blobs_annotating_element("blobs_labels")
    labels_element = sdata["blobs_labels"]
    all_instance_ids = set(sdata.tables["table"].obs["instance_id"].unique())
    expected_remaining_ids = all_instance_ids - set(ids_to_remove)

    filtered_labels_element = filter_labels2dmodel_by_instance_ids(labels_element, ids_to_remove)

    # because 0 is the background, we expect the filtered ids to be the instance ids that are not 0
    filtered_ids = set(np.unique(filtered_labels_element.data.compute())) - {0}
    assert filtered_ids == expected_remaining_ids
    # check that the original labels have not been modified
    preserved_ids = set(np.unique(labels_element.data.compute()))
    assert preserved_ids == all_instance_ids | {0}

    sdata.tables["table"].uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels"
    # item assignment (not attribute assignment) so that the column itself is
    # guaranteed to be the thing that gets overwritten
    sdata.tables["table"].obs["region"] = "blobs_multiscale_labels"
    labels_element = sdata["blobs_multiscale_labels"]
    filtered_labels_element = filter_labels2dmodel_by_instance_ids(labels_element, ids_to_remove)

    for scale in labels_element:
        filtered_ids = set(np.unique(filtered_labels_element[scale].image.compute())) - {0}
        assert filtered_ids == expected_remaining_ids
        # check that the original labels have not been modified
        preserved_ids = set(np.unique(labels_element[scale].image.compute()))
        assert preserved_ids == all_instance_ids | {0}


def test_filter_shapesmodel_by_instance_ids():
    """Removing ids 2 and 3 from the blobs circles must leave exactly ids 0, 1 and 4."""
    circles = blobs_annotating_element("blobs_circles")["blobs_circles"]

    remaining_circles = filter_shapesmodel_by_instance_ids(circles, [2, 3])
    remaining_ids = set(remaining_circles.index.tolist())

    assert remaining_ids == {0, 1, 4}
Loading