11from __future__ import annotations
22
3- import itertools
4- from typing import Literal
5-
6- import dask .array as da
73import geopandas as gpd
84import numpy as np
95import pandas as pd
1713from spatialdata .transformations import get_transformation , set_transformation
1814
1915from squidpy ._utils import _yx_from_shape
20-
21- from ._utils import _get_element_data
16+ from squidpy .experimental .im ._utils import (
17+ TileGrid ,
18+ _get_element_data ,
19+ _get_mask_materialized ,
20+ _save_tile_grid_to_shapes ,
21+ )
2222
2323__all__ = ["make_tiles" , "make_tiles_from_spots" ]
2424
2525
26- class _TileGrid :
27- """Immutable tile grid definition with cached bounds and centroids."""
28-
29- def __init__ (
30- self ,
31- H : int ,
32- W : int ,
33- tile_size : Literal ["auto" ] | tuple [int , int ] = "auto" ,
34- target_tiles : int = 100 ,
35- offset_y : int = 0 ,
36- offset_x : int = 0 ,
37- ):
38- self .H = H
39- self .W = W
40- if tile_size == "auto" :
41- size = max (min (self .H // target_tiles , self .W // target_tiles ), 100 )
42- self .ty = int (size )
43- self .tx = int (size )
44- else :
45- self .ty = int (tile_size [0 ])
46- self .tx = int (tile_size [1 ])
47- self .offset_y = offset_y
48- self .offset_x = offset_x
49- # Calculate number of tiles needed to cover entire image, accounting for offset
50- # The grid starts at offset_y, offset_x (can be negative)
51- # We need tiles from min(0, offset_y) to at least H
52- # So total coverage needed is from min(0, offset_y) to H
53- grid_start_y = min (0 , self .offset_y )
54- grid_start_x = min (0 , self .offset_x )
55- total_h_needed = self .H - grid_start_y
56- total_w_needed = self .W - grid_start_x
57- self .tiles_y = (total_h_needed + self .ty - 1 ) // self .ty
58- self .tiles_x = (total_w_needed + self .tx - 1 ) // self .tx
59- # Cache immutable derived values
60- self ._indices = np .array ([[iy , ix ] for iy in range (self .tiles_y ) for ix in range (self .tiles_x )], dtype = int )
61- self ._names = [f"tile_x{ ix } _y{ iy } " for iy in range (self .tiles_y ) for ix in range (self .tiles_x )]
62- self ._bounds = self ._compute_bounds ()
63- self ._centroids_polys = self ._compute_centroids_and_polygons ()
64-
65- def indices (self ) -> np .ndarray :
66- return self ._indices
67-
68- def names (self ) -> list [str ]:
69- return self ._names
70-
71- def bounds (self ) -> np .ndarray :
72- return self ._bounds
73-
74- def _compute_bounds (self ) -> np .ndarray :
75- b : list [list [int ]] = []
76- for iy , ix in itertools .product (range (self .tiles_y ), range (self .tiles_x )):
77- y0 = iy * self .ty + self .offset_y
78- x0 = ix * self .tx + self .offset_x
79- y1 = ((iy + 1 ) * self .ty + self .offset_y ) if iy < self .tiles_y - 1 else self .H
80- x1 = ((ix + 1 ) * self .tx + self .offset_x ) if ix < self .tiles_x - 1 else self .W
81- # Clamp bounds to image dimensions
82- y0 = max (0 , min (y0 , self .H ))
83- x0 = max (0 , min (x0 , self .W ))
84- y1 = max (0 , min (y1 , self .H ))
85- x1 = max (0 , min (x1 , self .W ))
86- b .append ([y0 , x0 , y1 , x1 ])
87- return np .array (b , dtype = int )
88-
89- def centroids_and_polygons (self ) -> tuple [np .ndarray , list [Polygon ]]:
90- return self ._centroids_polys
91-
92- def _compute_centroids_and_polygons (self ) -> tuple [np .ndarray , list [Polygon ]]:
93- cents : list [list [float ]] = []
94- polys : list [Polygon ] = []
95- for y0 , x0 , y1 , x1 in self ._bounds :
96- cy = (y0 + y1 ) / 2
97- cx = (x0 + x1 ) / 2
98- cents .append ([cy , cx ])
99- polys .append (Polygon ([(x0 , y0 ), (x1 , y0 ), (x1 , y1 ), (x0 , y1 ), (x0 , y0 )]))
100- return np .array (cents , dtype = float ), polys
101-
102- def rechunk_and_pad (self , arr_yx : da .Array ) -> da .Array :
103- if arr_yx .ndim != 2 :
104- raise ValueError ("Expected a 2D array shaped (y, x)." )
105- pad_y = self .tiles_y * self .ty - int (arr_yx .shape [0 ])
106- pad_x = self .tiles_x * self .tx - int (arr_yx .shape [1 ])
107- a = arr_yx .rechunk ((self .ty , self .tx ))
108- return da .pad (a , ((0 , pad_y ), (0 , pad_x )), mode = "edge" ) if (pad_y > 0 or pad_x > 0 ) else a
109-
110- def coarsen (self , arr_yx : da .Array , reduce : Literal ["mean" , "sum" ] = "mean" ) -> da .Array :
111- reducer = np .mean if reduce == "mean" else np .sum
112- return da .coarsen (reducer , arr_yx , {0 : self .ty , 1 : self .tx }, trim_excess = False )
113-
114-
11526class _SpotTileGrid :
11627 """Tile container for Visium spots, used with ``_filter_tiles``."""
11728
@@ -204,34 +115,12 @@ def _choose_label_scale_for_image(label_node: Labels2DModel, target_hw: tuple[in
204115
205116def _save_tiles_to_shapes (
206117 sdata : sd .SpatialData ,
207- tg : _TileGrid ,
118+ tg : TileGrid ,
208119 image_key : str ,
209120 shapes_key : str ,
210121) -> None :
211122 """Save a TileGrid to sdata.shapes as a GeoDataFrame."""
212- tile_indices = tg .indices ()
213- pixel_bounds = tg .bounds ()
214- _ , polys = tg .centroids_and_polygons ()
215-
216- tile_gdf = gpd .GeoDataFrame (
217- {
218- "tile_id" : tg .names (),
219- "tile_y" : tile_indices [:, 0 ],
220- "tile_x" : tile_indices [:, 1 ],
221- "pixel_y0" : pixel_bounds [:, 0 ],
222- "pixel_x0" : pixel_bounds [:, 1 ],
223- "pixel_y1" : pixel_bounds [:, 2 ],
224- "pixel_x1" : pixel_bounds [:, 3 ],
225- "geometry" : polys ,
226- },
227- geometry = "geometry" ,
228- )
229-
230- sdata .shapes [shapes_key ] = ShapesModel .parse (tile_gdf )
231- # we know that a) the element exists and b) it has at least an Identity transformation
232- transformations = get_transformation (sdata .images [image_key ], get_all = True )
233- set_transformation (sdata .shapes [shapes_key ], transformations , set_all = True )
234- logger .info (f"Saved tile grid as 'sdata.shapes[\" { shapes_key } \" ]'" )
123+ _save_tile_grid_to_shapes (sdata , tg , shapes_key , copy_transforms_from_key = image_key )
235124
236125
237126def _save_spot_tiles_to_shapes (
@@ -366,7 +255,7 @@ def make_tiles(
366255 mask_key_for_grid = default_mask_key
367256 else :
368257 try :
369- from ._detect_tissue import detect_tissue
258+ from squidpy . experimental . im ._detect_tissue import detect_tissue
370259
371260 detect_tissue (
372261 sdata ,
@@ -411,7 +300,7 @@ def make_tiles(
411300 classification_mask_key ,
412301 )
413302 try :
414- from ._detect_tissue import detect_tissue
303+ from squidpy . experimental . im ._detect_tissue import detect_tissue
415304
416305 detect_tissue (
417306 sdata ,
@@ -558,7 +447,7 @@ def make_tiles_from_spots(
558447 classification_mask_key ,
559448 )
560449 try :
561- from ._detect_tissue import detect_tissue
450+ from squidpy . experimental . im ._detect_tissue import detect_tissue
562451
563452 detect_tissue (
564453 sdata ,
@@ -633,7 +522,7 @@ def make_tiles_from_spots(
633522
634523def _filter_tiles (
635524 sdata : sd .SpatialData ,
636- tg : _TileGrid ,
525+ tg : TileGrid ,
637526 image_key : str | None ,
638527 * ,
639528 tissue_mask_key : str | None = None ,
@@ -686,7 +575,7 @@ def _filter_tiles(
686575 raise ValueError ("tissue_mask_key must be provided when image_key is None." )
687576 if mask_key not in sdata .labels :
688577 raise KeyError (f"Tissue mask '{ mask_key } ' not found in sdata.labels." )
689- mask = _get_mask_from_labels (sdata , mask_key , scale )
578+ mask = _get_mask_materialized (sdata , mask_key , scale )
690579 H_mask , W_mask = mask .shape
691580
692581 # Check tissue coverage for each tile
@@ -751,7 +640,7 @@ def _make_tiles(
751640 tile_size : tuple [int , int ] = (224 , 224 ),
752641 center_grid_on_tissue : bool = False ,
753642 scale : str = "auto" ,
754- ) -> _TileGrid :
643+ ) -> TileGrid :
755644 """Construct a tile grid for an image, optionally centered on a tissue mask."""
756645 # Validate image key
757646 if image_key not in sdata .images :
@@ -764,7 +653,7 @@ def _make_tiles(
764653
765654 # Path 1: Regular grid starting from top-left
766655 if not center_grid_on_tissue or image_mask_key is None :
767- return _TileGrid (H , W , tile_size = tile_size )
656+ return TileGrid (H , W , tile_size = tile_size )
768657
769658 # Path 2: Center grid on tissue mask centroid
770659 if image_mask_key not in sdata .labels :
@@ -806,7 +695,7 @@ def _make_tiles(
806695 mask_bool = mask > 0
807696 if not mask_bool .any ():
808697 logger .warning ("Mask is empty. Using regular grid starting from top-left." )
809- return _TileGrid (H , W , tile_size = tile_size )
698+ return TileGrid (H , W , tile_size = tile_size )
810699
811700 # Calculate centroid using center of mass
812701 y_coords , x_coords = np .where (mask_bool )
@@ -821,7 +710,7 @@ def _make_tiles(
821710 offset_y = int (round (centroid_y - tile_center_y_standard ))
822711 offset_x = int (round (centroid_x - tile_center_x_standard ))
823712
824- return _TileGrid (H , W , tile_size = tile_size , offset_y = offset_y , offset_x = offset_x )
713+ return TileGrid (H , W , tile_size = tile_size , offset_y = offset_y , offset_x = offset_x )
825714
826715
827716def _get_spot_coordinates (
@@ -877,27 +766,3 @@ def _derive_tile_size_from_spots(coords: np.ndarray) -> tuple[int, int]:
877766 )
878767 side = max (1 , int (np .floor (row_spacing )))
879768 return side , side
880-
881-
882- def _get_mask_from_labels (sdata : sd .SpatialData , mask_key : str , scale : str ) -> np .ndarray :
883- """Extract a 2D mask array from ``sdata.labels`` at the requested scale."""
884- if mask_key not in sdata .labels :
885- raise KeyError (f"Mask key '{ mask_key } ' not found in sdata.labels" )
886-
887- label_node = sdata .labels [mask_key ]
888- mask_da = _get_element_data (label_node , scale , "label" , mask_key )
889-
890- if is_dask_collection (mask_da ):
891- mask_da = mask_da .compute ()
892-
893- if isinstance (mask_da , xr .DataArray ):
894- mask = np .asarray (mask_da .data )
895- else :
896- mask = np .asarray (mask_da )
897-
898- if mask .ndim > 2 :
899- mask = mask .squeeze ()
900- if mask .ndim != 2 :
901- raise ValueError (f"Expected 2D mask with shape (y, x), got shape { mask .shape } " )
902-
903- return mask
0 commit comments