Commit 258d16c
perf(geotiff): _write_vrt_tiled uses threaded dask scheduler (#1714) (#1725)
* perf(geotiff): _write_vrt_tiled uses threaded dask scheduler (#1714)

  Each delayed task in _write_vrt_tiled writes one tile to its own output
  path with no shared mutable Python state, so the writes are
  embarrassingly parallel. The prior code called dask.compute with
  scheduler='synchronous', which forced every tile through the calling
  thread one at a time. Switch to scheduler='threads': zlib/zstd/LZW
  release the GIL during compression, so threading delivers real
  wall-time wins on the compression stage.

  Microbench: a 4096x4096 float32 dask DataArray with chunks=256
  (256 output tiles) at zstd compression drops from 0.49s to 0.33s
  (~33% reduction).

  Adds tests covering the scheduler choice, the tile-file inventory, and
  a determinism check that runs the same write twice and compares every
  tile byte-for-byte to catch any race regression.

* Address Copilot lint feedback on #1725

  - Remove unused pytest import in test_vrt_tiled_scheduler_1714.py
  - Use Path.read_bytes() in the tile-byte comparison to avoid leaking
    file descriptors (the previous dict comprehension opened files via
    ``open(p, "rb").read()`` without a context manager)
1 parent 995b688 commit 258d16c
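
A minimal sketch of the scheduler comparison the commit message describes, assuming only that the tasks are independent and GIL-releasing (zlib stands in for the tile writer here; this is illustrative, not the benchmark the 0.49s/0.33s numbers came from):

import time
import zlib

import dask
import numpy as np
from dask import delayed

# 64 independent 1 MiB buffers stand in for output tiles; zlib.compress
# releases the GIL on large inputs, like the codecs named in the commit.
tiles = [np.random.default_rng(i).bytes(1 << 20) for i in range(64)]
tasks = [delayed(zlib.compress)(t, 6) for t in tiles]

for scheduler in ("synchronous", "threads"):
    start = time.perf_counter()
    dask.compute(*tasks, scheduler=scheduler)
    print(f"{scheduler}: {time.perf_counter() - start:.2f}s")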

2 files changed

Lines changed: 120 additions & 2 deletions

File tree

xrspatial/geotiff/__init__.py
test_vrt_tiled_scheduler_1714.py

Lines changed: 10 additions & 2 deletions
@@ -1833,10 +1833,18 @@ def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None,
         col_offset += chunk_w
         row_offset += chunk_h

-    # Execute all dask tasks
+    # Execute all dask tasks.
+    #
+    # Each delayed task is an independent ``_write_single_tile`` call on
+    # a distinct output path, with no shared mutable Python state, so
+    # the writes are embarrassingly parallel. Using ``scheduler='threads'``
+    # lets zlib / zstd / LZW release the GIL during compression and the
+    # OS coalesce concurrent writes; in a 256-tile zstd write on a
+    # 4096x4096 dask DataArray the wall time drops ~33% versus the
+    # ``synchronous`` scheduler this used to call (issue #1714).
     if delayed_tasks:
         import dask
-        dask.compute(*delayed_tasks, scheduler='synchronous')
+        dask.compute(*delayed_tasks, scheduler='threads')

     # Write VRT index with relative paths
     from ._vrt import write_vrt as _write_vrt_fn
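
The GIL-release claim in the new comment can be sanity-checked outside the library with just the standard library. A standalone sketch (not part of this commit): on a multi-core machine the threaded run should come in well under the serial one.

import os
import time
import zlib
from concurrent.futures import ThreadPoolExecutor

buf = os.urandom(32 * 1024 * 1024)  # 32 MiB; zlib drops the GIL on big inputs

start = time.perf_counter()
for _ in range(2):
    zlib.compress(buf, 6)
serial = time.perf_counter() - start

start = time.perf_counter()
with ThreadPoolExecutor(max_workers=2) as pool:
    # Two compressions run concurrently because zlib releases the GIL
    # while the C compression loop is working.
    list(pool.map(lambda b: zlib.compress(b, 6), [buf, buf]))
threaded = time.perf_counter() - start

print(f"serial {serial:.2f}s vs threaded {threaded:.2f}s")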
test_vrt_tiled_scheduler_1714.py

Lines changed: 110 additions & 0 deletions

@@ -0,0 +1,110 @@
+"""Verify ``_write_vrt_tiled`` runs tile writes on dask's threaded scheduler.
+
+Issue #1714: the prior code called ``dask.compute(*delayed_tasks,
+scheduler='synchronous')`` which serialised independent tile writes on
+the calling thread. Switching to the threaded scheduler reduces wall
+time by ~33% on a 256-tile zstd write. These tests pin the new
+scheduler choice and confirm the output is correct (no concurrent-write
+races, all tiles present, content matches).
+"""
+from __future__ import annotations
+
+import glob
+import os
+import tempfile
+from pathlib import Path
+from unittest.mock import patch
+
+import dask
+import dask.array as da
+import numpy as np
+import xarray as xr
+
+from xrspatial.geotiff import to_geotiff
+
+
+def _make_dask_da(h: int = 32, w: int = 32, chunk: int = 8) -> xr.DataArray:
+    """Return a dask-backed 2D DataArray with ``chunk``-sized chunks.
+
+    Using ``da.from_array`` on a pre-built numpy array gives clean
+    ``(chunk, chunk)`` chunking. ``da.arange(...).reshape(...)`` keeps a
+    chunk size of 1 along the new axis, which produces a confusing test
+    setup.
+    """
+    arr = np.arange(h * w, dtype=np.float32).reshape(h, w)
+    return xr.DataArray(
+        da.from_array(arr, chunks=(chunk, chunk)),
+        dims=["y", "x"],
+    )
+
+
+def test_vrt_tiled_uses_threaded_scheduler():
+    """_write_vrt_tiled passes ``scheduler='threads'`` to dask.compute."""
+    da_arr = _make_dask_da()
+    with tempfile.TemporaryDirectory(prefix="vrt_sched_1714_") as td:
+        vrt = os.path.join(td, "sched_check.vrt")
+
+        # Wrap dask.compute so we can record the scheduler kwarg the
+        # writer chose. The dask module is imported inside
+        # _write_vrt_tiled, so we patch on the module object directly.
+        captured = {}
+        real_compute = dask.compute
+
+        def spy(*args, **kwargs):
+            captured["scheduler"] = kwargs.get("scheduler")
+            return real_compute(*args, **kwargs)
+
+        with patch.object(dask, "compute", side_effect=spy) as p:
+            to_geotiff(da_arr, vrt)
+        assert p.called, "_write_vrt_tiled never invoked dask.compute"
+
+        assert captured.get("scheduler") == "threads", (
+            "Expected scheduler='threads' on the VRT-tiled write but "
+            f"got {captured.get('scheduler')!r}"
+        )
+
+
+def test_vrt_tiled_threaded_write_produces_all_tiles():
+    """All expected tile files exist after the threaded write."""
+    da_arr = _make_dask_da(h=32, w=32, chunk=8)  # 4x4 = 16 tiles
+    with tempfile.TemporaryDirectory(prefix="vrt_sched_1714_") as td:
+        vrt = os.path.join(td, "tile_count.vrt")
+        to_geotiff(da_arr, vrt)
+        tiles_dir = os.path.join(td, "tile_count_tiles")
+        tiles = sorted(glob.glob(os.path.join(tiles_dir, "*.tif")))
+        assert len(tiles) == 16, (
+            f"Expected 16 tile files, got {len(tiles)} in {tiles_dir}"
+        )
+
+
+def test_vrt_tiled_threaded_write_is_deterministic():
+    """Threaded scheduler must not introduce write ordering races.
+
+    Each delayed task writes to its own file path, so the threaded
+    scheduler is safe. Run the same write twice and compare byte
+    contents of every tile to catch any accidental race regression.
+    """
+    da_arr = _make_dask_da(h=32, w=32, chunk=8)
+
+    def _write_and_collect(vrt_path: str) -> dict[str, bytes]:
+        to_geotiff(da_arr, vrt_path)
+        stem = os.path.splitext(os.path.basename(vrt_path))[0]
+        tiles_dir = os.path.join(os.path.dirname(vrt_path), stem + "_tiles")
+        return {
+            os.path.basename(p): Path(p).read_bytes()
+            for p in sorted(glob.glob(os.path.join(tiles_dir, "*.tif")))
+        }
+
+    with tempfile.TemporaryDirectory(prefix="vrt_sched_1714_") as td1:
+        with tempfile.TemporaryDirectory(prefix="vrt_sched_1714_") as td2:
+            tiles1 = _write_and_collect(os.path.join(td1, "run1.vrt"))
+            tiles2 = _write_and_collect(os.path.join(td2, "run2.vrt"))
+
+            assert set(tiles1) == set(tiles2), (
+                "Tile file set differs between runs: "
+                f"{set(tiles1) ^ set(tiles2)}"
+            )
+            for name, blob1 in tiles1.items():
+                assert blob1 == tiles2[name], (
+                    f"Tile {name} differs between runs (race condition?)"
+                )
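
A side note on the patching trick in the first test: ``_write_vrt_tiled`` imports dask lazily, but ``import dask`` inside a function returns the cached module object from sys.modules, so an attribute patched on that object is exactly what the late import sees. A minimal standalone demonstration (a sketch, assuming only that dask is installed):

import sys
from unittest.mock import patch

import dask


def late_importer():
    import dask  # resolves to the same cached module object in sys.modules
    return dask.compute


with patch.object(dask, "compute") as fake:
    assert late_importer() is fake  # the lazy import sees the patched attribute

assert sys.modules["dask"] is dask  # one shared module object per process

This is also why the test binds real_compute before entering the patch context: the spy must delegate to the original function, not to the mock that has replaced it.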
