Skip to content

Commit 5a8952d

Browse files
committed
Add BoundaryStore: memmap-backed boundary strip storage for dask tile sweeps
1 parent 0417a33 commit 5a8952d

File tree

1 file changed

+104
-0
lines changed

1 file changed

+104
-0
lines changed

xrspatial/_boundary_store.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""Memmap-backed boundary strip storage for dask tile sweeps.
2+
3+
Stores top/bottom/left/right boundary strips for a 2D tile grid in
4+
flat numpy memory-mapped files, avoiding O(N_tiles) in-memory nested
5+
lists that can OOM on very large inputs.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import os
11+
import shutil
12+
import tempfile
13+
14+
import numpy as np
15+
16+
17+
class BoundaryStore:
    """Disk-backed storage for boundary strips of a tiled 2D grid.

    Each of the four sides (top, bottom, left, right) is stored as a
    single contiguous ``float64`` memmap file inside a private temporary
    directory. Strip lookup is O(1) via precomputed cumulative offsets,
    and ``get`` returns a view into the memmap rather than a copy.

    The store can be used as a context manager; leaving the ``with``
    block calls ``close``.

    Parameters
    ----------
    chunks_y : iterable of int
        Tile heights (one per tile row). Consumed exactly once, so a
        one-shot iterator is accepted.
    chunks_x : iterable of int
        Tile widths (one per tile column). Consumed exactly once, so a
        one-shot iterator is accepted.
    fill_value : float, optional
        Initial fill value for all strips (default 0.0).
    """

    # Names of the four per-side memmaps; each lives in ``self._<name>``.
    _SIDES = ('top', 'bottom', 'left', 'right')

    def __init__(self, chunks_y, chunks_x, fill_value=0.0):
        # Establish the full attribute set before anything that can fail,
        # so close() and __del__ are safe on a half-built instance.
        self._closed = False
        self._tmpdir = None
        for name in self._SIDES:
            setattr(self, f'_{name}', None)

        # Materialise the inputs once and use only the tuples afterwards;
        # the arguments may be iterators that cannot be traversed twice.
        self._chunks_y = tuple(chunks_y)
        self._chunks_x = tuple(chunks_x)
        n_ty = len(self._chunks_y)
        n_tx = len(self._chunks_x)
        total_h = sum(self._chunks_y)
        total_w = sum(self._chunks_x)

        # Cumulative offsets for O(1) strip lookup:
        # tile column ix spans [_cum_x[ix], _cum_x[ix + 1]) and
        # tile row iy spans [_cum_y[iy], _cum_y[iy + 1]).
        self._cum_x = np.zeros(n_tx + 1, dtype=np.int64)
        self._cum_x[1:] = np.cumsum(self._chunks_x, dtype=np.int64)
        self._cum_y = np.zeros(n_ty + 1, dtype=np.int64)
        self._cum_y[1:] = np.cumsum(self._chunks_y, dtype=np.int64)

        # top/bottom: strip length = chunks_x[ix], indexed by (iy, ix)
        #   shape (n_ty, total_w) -- row iy holds all top/bottom strips
        # left/right: strip length = chunks_y[iy], indexed by (iy, ix)
        #   shape (n_tx, total_h) -- row ix holds all left/right strips
        layout = [('top', (n_ty, total_w)),
                  ('bottom', (n_ty, total_w)),
                  ('left', (n_tx, total_h)),
                  ('right', (n_tx, total_h))]

        # Create the directory only after the cheap validation above has
        # succeeded, so bad arguments never leave a directory behind.
        self._tmpdir = tempfile.mkdtemp(prefix='xrs_bdry_')
        try:
            for name, shape in layout:
                path = os.path.join(self._tmpdir, f'{name}.dat')
                mm = np.memmap(path, dtype=np.float64, mode='w+',
                               shape=shape)
                mm[:] = fill_value
                mm.flush()
                setattr(self, f'_{name}', mm)
        except BaseException:
            # Do not leak the temp directory (or partially written files)
            # when creation fails or is interrupted; then re-raise.
            self.close()
            raise

    def get(self, side, iy, ix):
        """Return a memmap view of the boundary strip for tile (iy, ix).

        Parameters
        ----------
        side : {'top', 'bottom', 'left', 'right'}
            Which boundary of the tile to fetch.
        iy, ix : int
            Tile row and tile column index.

        Returns
        -------
        numpy.memmap
            1-D view of length ``chunks_x[ix]`` for top/bottom or
            ``chunks_y[iy]`` for left/right. Writing to it writes to
            the store.

        Raises
        ------
        ValueError
            If *side* is not one of the four known names, or the store
            has already been closed.
        """
        if side in ('top', 'bottom'):
            # One memmap row per tile row; columns are split by tile column.
            row, offsets, tile = iy, self._cum_x, ix
        elif side in ('left', 'right'):
            # One memmap row per tile column; columns are split by tile row.
            row, offsets, tile = ix, self._cum_y, iy
        else:
            raise ValueError(f"Unknown side: {side!r}")

        store = getattr(self, f'_{side}')
        if store is None:
            # close() replaces the memmaps with None.
            raise ValueError("BoundaryStore is closed")
        return store[row, offsets[tile]:offsets[tile + 1]]

    def set(self, side, iy, ix, data):
        """Write *data* into the boundary strip for tile (iy, ix).

        *data* is assigned with normal NumPy slice-assignment rules, so
        it must be broadcastable to the strip length. Raises
        ``ValueError`` under the same conditions as ``get``.
        """
        self.get(side, iy, ix)[:] = data

    def close(self):
        """Release the memmaps and remove the temporary files.

        Safe to call more than once, and safe on an instance whose
        ``__init__`` did not complete. The data is discarded, so nothing
        is flushed first. Removal of the temporary directory is
        best-effort: filesystem errors are ignored.
        """
        # A missing flag means __init__ never ran, so there is nothing
        # to clean up; treat that the same as already closed.
        if getattr(self, '_closed', True):
            return
        self._closed = True

        # Drop our references so the underlying mappings can be released
        # once no outstanding views remain.
        for name in self._SIDES:
            setattr(self, f'_{name}', None)

        tmpdir = getattr(self, '_tmpdir', None)
        if tmpdir is not None:
            shutil.rmtree(tmpdir, ignore_errors=True)

    def __del__(self):
        # Last-resort cleanup if the caller never closed the store.
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        # Returns None, so exceptions from the with-body propagate.
        self.close()

0 commit comments

Comments
 (0)