Skip to content

Commit fa47f06

Browse files
committed
orthophoto merge: parallelize the block loop (in-order writer + max_workers)
With boundless reads gated (previous commit), parallelize the per-block blend loop. Blocks are computed in a ThreadPoolExecutor and written from a single thread in strict block order with a bounded look-ahead (cap = 2 * max_workers), so writes to the compressed, tiled GeoTIFF stay sequential and incrementally flushable, and memory use stays small. GDAL's global block cache is set to 512 MiB for the duration of the merge and restored on exit. Each worker thread opens its own source handles, because GDAL/rasterio datasets are not thread-safe. The --merge-skip-blending option is preserved. The new parameter is wired from stages/splitmerge.py as max_workers=args.max_concurrency. With max_workers <= 1 the blocks are computed and written serially, and the output is byte-for-byte identical to that of the original serial loop.
1 parent 7371b32 commit fa47f06

2 files changed

Lines changed: 156 additions & 38 deletions

File tree

opendm/orthophoto.py

Lines changed: 155 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
import os
2+
import threading
3+
import contextlib
4+
from collections import deque
5+
from concurrent.futures import ThreadPoolExecutor
26
from opendm import log
37
from opendm import system
48
from opendm.cropper import Cropper
@@ -18,6 +22,45 @@
1822
from osgeo import ogr
1923

2024

25+
@contextlib.contextmanager
def _bounded_gdal_cache(nbytes):
    """Temporarily cap GDAL's global block cache, restoring it on exit.

    A small cache keeps the parallel merge's output-tile flushes prompt and
    cheap. The previous limit is restored on the way out, even if the body of
    the ``with`` block raises, so the rest of the pipeline (notably the COG
    conversion) is not left with a shrunken cache.

    The limit is only ever lowered: if the current limit is already at or
    below ``nbytes`` it is left as it is, so this helper never enlarges the
    cache on a machine whose configured limit is smaller than ``nbytes``.

    Args:
        nbytes: upper bound, in bytes, for the block cache while the context
            is active.

    Yields:
        None. The cache limit is in effect for the duration of the ``with``
        block.
    """
    # NOTE(review): this adjusts the cache of the GDAL library that backs
    # ``osgeo.gdal``; confirm rasterio is linked against the same libgdal in
    # this build, otherwise the cap does not reach rasterio's reads/writes.
    prev = gdal.GetCacheMax()
    # min() makes this a true cap: an unconditional SetCacheMax(nbytes) would
    # raise the limit whenever the existing one is already below nbytes.
    gdal.SetCacheMax(min(prev, nbytes))
    try:
        yield
    finally:
        gdal.SetCacheMax(prev)
39+
40+
41+
def _read_window_gated(ds, src_window, dst_shape, dtype):
    """Read ``src_window`` from ``ds`` into a fresh ``dst_shape`` array.

    Intended as a drop-in for a ``boundless=True`` read with a 0 fill that
    stays off rasterio's boundless VRT path whenever it can (that path builds
    and serializes a VRT on every read, which is a large per-read overhead on
    big merges).

    Three cases, decided from the window's position relative to the dataset:

    - the window lies completely outside the dataset: nothing is read and an
      all-zero array is returned;
    - the window lies completely inside the dataset: an ordinary
      (``boundless=False``) read is issued;
    - the window straddles a dataset edge: a ``boundless=True`` read is
      issued, so the padding behaviour of the boundless path is kept.

    Args:
        ds: open dataset exposing ``height``, ``width`` and ``read``.
        src_window: ``((row_start, row_stop), (col_start, col_stop))`` in the
            dataset's pixel coordinates.
        dst_shape: shape of the array to allocate and read into.
        dtype: dtype of the allocated array.

    Returns:
        The zero-filled array itself when the window misses the dataset,
        otherwise whatever ``ds.read`` returns for the ``out`` buffer.
    """
    (row_start, row_stop), (col_start, col_stop) = src_window
    buf = np.zeros(dst_shape, dtype=dtype)
    n_rows, n_cols = ds.height, ds.width

    # No overlap with the dataset at all: the zero buffer is the answer.
    misses_dataset = (
        row_stop <= 0
        or col_stop <= 0
        or row_start >= n_rows
        or col_start >= n_cols
    )
    if misses_dataset:
        return buf

    # Only a window that pokes past an edge needs the (slow) boundless path.
    fits_inside = (
        row_start >= 0
        and col_start >= 0
        and row_stop <= n_rows
        and col_stop <= n_cols
    )
    return ds.read(
        out=buf,
        window=src_window,
        boundless=not fits_inside,
        masked=False,
    )
62+
63+
2164
def get_orthophoto_vars(args):
2265
return {
2366
'TILED': 'NO' if args.orthophoto_no_tiled else 'YES',
@@ -265,32 +308,25 @@ def feather_raster(input_raster, output_raster, blend_distance=20):
265308

266309
return output_raster
267310

268-
def _read_window_gated(ds, src_window, dst_shape, dtype):
269-
"""Read ``src_window`` into a ``dst_shape`` array — equivalent to a
270-
``boundless=True`` read with 0 nodata fill, but avoiding rasterio's boundless
271-
VRT path for the common cases. ``boundless=True`` builds a VRT and serializes
272-
it via Python's ElementTree (``_serialize_xml``) on every read, which is a
273-
large per-read overhead on big merges. Behaviour:
274-
- window fully outside the dataset -> all zeros (no read), matching the 0
275-
nodata fill a boundless read would produce here;
276-
- window fully inside -> plain non-boundless read (no VRT), identical to a
277-
boundless read when no out-of-bounds padding is needed;
278-
- window partially overlapping the edge -> fall back to boundless (rare;
279-
only the true border blocks), preserving exact fill behaviour.
280-
"""
281-
(r0, r1), (c0, c1) = src_window
282-
out = np.zeros(dst_shape, dtype=dtype)
283-
height, width = ds.height, ds.width
284-
if r1 <= 0 or c1 <= 0 or r0 >= height or c0 >= width:
285-
return out
286-
if r0 >= 0 and c0 >= 0 and r1 <= height and c1 <= width:
287-
return ds.read(out=out, window=src_window, boundless=False, masked=False)
288-
return ds.read(out=out, window=src_window, boundless=True, masked=False)
289-
290-
def merge(input_ortho_and_ortho_cuts, output_orthophoto, orthophoto_vars={}, merge_skip_blending=False):
311+
def merge(input_ortho_and_ortho_cuts, output_orthophoto, orthophoto_vars={}, merge_skip_blending=False, max_workers=1):
291312
"""
292313
Based on https://github.com/mapbox/rio-merge-rgba/
293314
Merge orthophotos around cutlines using a blend buffer.
315+
316+
Each output block is an independent pure function of its source windows and
317+
the fixed source ordering, so blocks are processed in parallel. With
318+
max_workers <= 1 (the default) processing is strictly serial and the output
319+
is byte-for-byte identical to the original single-threaded loop.
320+
321+
Args:
322+
input_ortho_and_ortho_cuts: iterable of (orthophoto_path, cut_path) pairs.
323+
output_orthophoto: path for the merged output GeoTIFF.
324+
orthophoto_vars: rasterio profile overrides (TILED, COMPRESS, etc.).
325+
merge_skip_blending: if True, skip the feather/cutline blend passes
326+
(ODM #1934 --merge-skip-blending); only the first naive-copy pass runs.
327+
max_workers: number of parallel worker threads (default 1 = serial).
328+
Returns:
329+
The output_orthophoto path, or None if there were no valid inputs.
294330
"""
295331
inputs = []
296332
bounds=None
@@ -315,6 +351,7 @@ def merge(input_ortho_and_ortho_cuts, output_orthophoto, orthophoto_vars={}, mer
315351
profile = first.profile
316352
num_bands = first.meta['count'] - 1 # minus alpha
317353
colorinterp = first.colorinterp
354+
dst_count = first.count
318355

319356
log.ODM_INFO("%s valid orthophoto rasters to merge" % len(inputs))
320357
sources = [(rasterio.open(o), rasterio.open(c)) for o,c in inputs]
@@ -330,6 +367,10 @@ def merge(input_ortho_and_ortho_cuts, output_orthophoto, orthophoto_vars={}, mer
330367
if src.profile["count"] < 2:
331368
raise ValueError("Inputs must be at least 2-band rasters")
332369
dst_w, dst_s, dst_e, dst_n = min(xs), min(ys), max(xs), max(ys)
370+
# Close the pre-scan handles; they are unused in the parallel block loop.
371+
for s, c in sources:
372+
s.close()
373+
c.close()
333374
log.ODM_INFO("Output bounds: %r %r %r %r" % (dst_w, dst_s, dst_e, dst_n))
334375

335376
output_transform = Affine.translation(dst_w, dst_n)
@@ -360,23 +401,63 @@ def merge(input_ortho_and_ortho_cuts, output_orthophoto, orthophoto_vars={}, mer
360401
if merge_skip_blending:
361402
log.ODM_INFO("Skipping second and third pass orthophoto blending, as --merge-skip-blending passed")
362403

363-
# create destination file
364-
with rasterio.open(output_orthophoto, "w", **profile) as dstrast:
404+
# create destination file. Cap GDAL's global block cache for the merge (restored on
405+
# exit, see _bounded_gdal_cache): the default (5% of RAM) lets dirty output tiles
406+
# accumulate, so a source read can trigger a large eviction/flush under GDAL's global
407+
# lock that stalls the parallel readers; a small cache keeps flushes prompt and cheap.
408+
cache_bytes = 512 * 1024 * 1024
409+
with _bounded_gdal_cache(cache_bytes), \
410+
rasterio.open(output_orthophoto, "w", **profile) as dstrast:
365411
dstrast.colorinterp = colorinterp
366-
for idx, dst_window in dstrast.block_windows():
367-
left, bottom, right, top = dstrast.window_bounds(dst_window)
412+
413+
# Each output block is an independent function of its source windows and
414+
# the source ordering, so blocks can be COMPUTED in parallel. But writes
415+
# to a single compressed, tiled GeoTIFF must happen in row-major (block)
416+
# order: out-of-order writes cannot be flushed incrementally, so GDAL
417+
# hoards every dirty block in RAM until it thrashes or OOMs. So we compute
418+
# in a thread pool and write from one thread in strict block order, with a
419+
# small bounded look-ahead for backpressure. max_workers <= 1 is a plain
420+
# serial compute+write loop, identical to the original.
421+
tls = threading.local()
422+
block_windows = [(dst_window, dstrast.window_bounds(dst_window))
423+
for _, dst_window in dstrast.block_windows()]
424+
total_blocks = len(block_windows)
425+
log_every = max(1, total_blocks // 20)
426+
427+
opened_sources = []
428+
opened_lock = threading.Lock()
429+
430+
def get_sources():
431+
"""Return this thread's (ortho, cut) rasterio dataset handles.
432+
433+
GDAL/rasterio handles are not safe to share across threads, so each
434+
worker thread lazily opens and caches its own set on first use and
435+
registers it in opened_sources for cleanup after the parallel run.
436+
"""
437+
srcs = getattr(tls, "sources", None)
438+
if srcs is None:
439+
srcs = [(rasterio.open(o), rasterio.open(c)) for o, c in inputs]
440+
tls.sources = srcs
441+
with opened_lock:
442+
opened_sources.append(srcs)
443+
return srcs
444+
445+
def compute_block(item):
446+
"""Compute one output block (read + 3 blend passes); return the array.
447+
448+
Does NOT write — writing happens in block order on the main thread.
449+
"""
450+
dst_window, (left, bottom, right, top) = item
451+
local_sources = get_sources()
368452

369453
blocksize = dst_window.width
370454
dst_rows, dst_cols = (dst_window.height, dst_window.width)
371-
372-
# initialize array destined for the block
373-
dst_count = first.count
374455
dst_shape = (dst_count, dst_rows, dst_cols)
375456

376457
dstarr = np.zeros(dst_shape, dtype=dtype)
377458

378459
# First pass, write all rasters naively without blending
379-
for src, _ in sources:
460+
for src, _ in local_sources:
380461
src_window = tuple(zip(rowcol(
381462
src.transform, left, top, op=round, precision=precision
382463
), rowcol(
@@ -395,14 +476,13 @@ def merge(input_ortho_and_ortho_cuts, output_orthophoto, orthophoto_vars={}, mer
395476
if np.count_nonzero(dstarr[-1]) == blocksize:
396477
break
397478

398-
# Skip expensive blending operations if flag passed
479+
# Skip the feather/cutline blend passes if requested (ODM #1934)
399480
if merge_skip_blending:
400-
dstrast.write(dstarr, window=dst_window)
401-
continue
481+
return dstarr
402482

403483
# Second pass, write all feathered rasters
404484
# blending the edges
405-
for src, _ in sources:
485+
for src, _ in local_sources:
406486
src_window = tuple(zip(rowcol(
407487
src.transform, left, top, op=round, precision=precision
408488
), rowcol(
@@ -416,14 +496,14 @@ def merge(input_ortho_and_ortho_cuts, output_orthophoto, orthophoto_vars={}, mer
416496
blended = temp[-1] / 255.0 * temp[b] + (1 - temp[-1] / 255.0) * dstarr[b]
417497
np.copyto(dstarr[b], blended, casting='unsafe', where=where)
418498
dstarr[-1][where] = 255.0
419-
499+
420500
# check if dest has any nodata pixels available
421501
if np.count_nonzero(dstarr[-1]) == blocksize:
422502
break
423503

424504
# Third pass, write cut rasters
425505
# blending the cutlines
426-
for _, cut in sources:
506+
for _, cut in local_sources:
427507
src_window = tuple(zip(rowcol(
428508
cut.transform, left, top, op=round, precision=precision
429509
), rowcol(
@@ -438,6 +518,44 @@ def merge(input_ortho_and_ortho_cuts, output_orthophoto, orthophoto_vars={}, mer
438518
blended = temp[-1] / 255.0 * temp[b] + (1 - temp[-1] / 255.0) * dstarr[b]
439519
np.copyto(dstarr[b], blended, casting='unsafe', where=temp[-1]!=0)
440520

521+
return dstarr
522+
523+
def write_block(idx, dst_window, dstarr):
441524
dstrast.write(dstarr, window=dst_window)
525+
if (idx + 1) % log_every == 0:
526+
log.ODM_INFO("Merging orthophoto: %s / %s blocks" % (idx + 1, total_blocks))
527+
528+
if max_workers <= 1:
529+
# Serial: compute and write each block in order (original behavior).
530+
for idx, item in enumerate(block_windows):
531+
write_block(idx, item[0], compute_block(item))
532+
else:
533+
# Parallel compute; one in-order writer with bounded look-ahead so at
534+
# most `cap` blocks are in flight — keeps memory small and gives GDAL
535+
# strictly sequential, incrementally-flushable writes.
536+
cap = max_workers * 2
537+
with ThreadPoolExecutor(max_workers=max_workers) as ex:
538+
items = iter(enumerate(block_windows))
539+
pending = deque()
540+
for _ in range(cap):
541+
nxt = next(items, None)
542+
if nxt is None:
543+
break
544+
idx, item = nxt
545+
pending.append((idx, item[0], ex.submit(compute_block, item)))
546+
while pending:
547+
widx, wwin, fut = pending.popleft()
548+
dstarr = fut.result()
549+
write_block(widx, wwin, dstarr)
550+
nxt = next(items, None)
551+
if nxt is not None:
552+
idx, item = nxt
553+
pending.append((idx, item[0], ex.submit(compute_block, item)))
554+
555+
# Close all thread-local source handles opened during the run.
556+
for srcs in opened_sources:
557+
for s, c in srcs:
558+
s.close()
559+
c.close()
442560

443561
return output_orthophoto

stages/splitmerge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def process(self, args, outputs):
265265
os.remove(tree.odm_orthophoto_tif)
266266

267267
orthophoto_vars = orthophoto.get_orthophoto_vars(args)
268-
orthophoto.merge(all_orthos_and_ortho_cuts, tree.odm_orthophoto_tif, orthophoto_vars, args.merge_skip_blending)
268+
orthophoto.merge(all_orthos_and_ortho_cuts, tree.odm_orthophoto_tif, orthophoto_vars, args.merge_skip_blending, max_workers=args.max_concurrency)
269269
orthophoto.post_orthophoto_steps(args, merged_bounds_file, tree.odm_orthophoto_tif, tree.orthophoto_tiles, args.orthophoto_resolution,
270270
reconstruction, tree, False)
271271
elif len(all_orthos_and_ortho_cuts) == 1:

0 commit comments

Comments
 (0)