Skip to content

Commit 0826ac8

Browse files
Performance enhancement of image statistics. Not calculating in RGB mode, as not needed. Batching for cal_on_everything() agains OOM.
1 parent 287cad3 commit 0826ac8

2 files changed

Lines changed: 70 additions & 12 deletions

File tree

detectree2/preprocessing/tiling.py

Lines changed: 69 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -673,22 +673,79 @@ def calculate_image_statistics(file_path,
673673
def calc_on_everything():
674674
logger.info("Processing entire image...")
675675
band_stats = []
676+
677+
# Define chunk size for reading (e.g. 2048 rows)
678+
chunk_height = 2048
679+
676680
for band_idx in range(1, src.count + 1):
677681
if band_idx - 1 in ignore_bands_indices:
678682
continue
679-
band = src.read(band_idx).astype(float)
680-
# Mask out bad values
681-
mask = (np.isnan(band) | np.isin(band, values_to_ignore))
682-
valid_data = band[~mask]
683-
684-
if valid_data.size > 0:
685-
min_val, max_val = np.percentile(valid_data, [1, 99])
683+
684+
# Accumulators for exact stats
685+
total_count = 0
686+
total_sum = 0.0
687+
total_sum_sq = 0.0
688+
global_min = float('inf')
689+
global_max = float('-inf')
690+
691+
# Buffer for percentiles
692+
percentile_buffer = []
693+
buffer_size = 0
694+
MAX_BUFFER = 5_000_000 # 5 million pixels ~ 40MB
695+
696+
for row_off in tqdm(range(0, height, chunk_height), desc=f"Calculating stats for band {band_idx}", leave=False):
697+
h = min(chunk_height, height - row_off)
698+
window = rasterio.windows.Window(0, row_off, width, h)
699+
700+
band_chunk = src.read(band_idx, window=window).astype(float)
701+
702+
# Mask out bad values
703+
mask = (np.isnan(band_chunk) | np.isin(band_chunk, values_to_ignore))
704+
valid_chunk = band_chunk[~mask]
705+
706+
if valid_chunk.size > 0:
707+
# Update exact stats
708+
c_min = np.min(valid_chunk)
709+
c_max = np.max(valid_chunk)
710+
c_sum = np.sum(valid_chunk)
711+
c_sum_sq = np.sum(valid_chunk ** 2)
712+
c_count = valid_chunk.size
713+
714+
if c_min < global_min: global_min = c_min
715+
if c_max > global_max: global_max = c_max
716+
total_sum += c_sum
717+
total_sum_sq += c_sum_sq
718+
total_count += c_count
719+
720+
# Update percentile buffer
721+
percentile_buffer.append(valid_chunk)
722+
buffer_size += c_count
723+
724+
if buffer_size > MAX_BUFFER:
725+
merged = np.concatenate(percentile_buffer)
726+
# Downsample to keep memory usage low
727+
merged = merged[::2]
728+
percentile_buffer = [merged]
729+
buffer_size = merged.size
730+
731+
if total_count > 0:
732+
# Finalize percentiles
733+
if percentile_buffer:
734+
final_buffer = np.concatenate(percentile_buffer)
735+
min_val, max_val = np.percentile(final_buffer, [1, 99])
736+
else:
737+
min_val, max_val = global_min, global_max
738+
739+
mean_val = total_sum / total_count
740+
# Variance = (SumSq / N) - Mean^2
741+
variance = (total_sum_sq / total_count) - (mean_val ** 2)
742+
std_dev = np.sqrt(max(0, variance))
686743

687744
stats = {
688-
"mean": float(np.mean(valid_data)),
745+
"mean": float(mean_val),
689746
"min": float(min_val),
690747
"max": float(max_val),
691-
"std_dev": float(np.std(valid_data)),
748+
"std_dev": float(std_dev),
692749
}
693750
else:
694751
stats = {
@@ -835,11 +892,12 @@ def tile_data(
835892

836893
tile_coordinates = _calculate_tile_placements(img_path, buffer, tile_width, tile_height, crowns, tile_placement,
837894
overlapping_tiles)
895+
838896
image_statistics = calculate_image_statistics(img_path,
839897
values_to_ignore=additional_nodata,
840898
mode=mode,
841-
ignore_bands_indices=ignore_bands_indices)
842-
899+
ignore_bands_indices=ignore_bands_indices) if mode == "ms" else None # Only needed for multispectral data
900+
843901
tile_args = [
844902
(img_path, out_dir, buffer, tile_width, tile_height, dtype_bool, minx, miny, crs, tilename, crowns, threshold,
845903
nan_threshold, mode, class_column, mask_gdf, additional_nodata, image_statistics, ignore_bands_indices,

docs/source/tutorials/02_data_preparation.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ The ``tile_data`` function exposes many parameters to control how tiles are crea
8888

8989
- ``nan_threshold``: The maximum proportion of a tile that can be NaN (or other no-data values) before it is discarded.
9090

91-
- ``use_convex_mask``: When ``True``, this creates a tight "wrapper" polygon (a convex hull) around all the training crowns within a tile. Any pixels outside this wrapper are masked out. This is a way to reduce noise by forcing the model to ignore parts of the tile that are far from any labeled object.
91+
- ``use_convex_mask``: When ``True``, this creates a tight "wrapper" polygon (a convex hull) around all the training crowns within a tile. Any pixels outside this wrapper are masked out. This is a way to reduce noise by forcing the model to ignore parts of the tile that are far from any labeled object. Note that the masked out area counts towards the ``nan_threshold``, so you may need to increase ``nan_threshold`` when using this option.
9292

9393
- ``enhance_rgb_contrast``: When ``True`` (for RGB images only), this applies a percentile contrast stretch. It calculates the 0.2 and 99.8 percentile pixel values and rescales the image to a 1-255 range. This is effective for normalizing hazy, dark, or washed-out imagery. It allows the model to more easily differentiate between tree crowns. 0 is reserved for masked-out areas.
9494

0 commit comments

Comments
 (0)