Skip to content

Commit 42c3285

Browse files
committed
Improve memory use in area_potential_report rule
1 parent 5718c8a commit 42c3285

7 files changed

Lines changed: 81 additions & 48 deletions

File tree

config/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ tiny_files: False
66
# A good option is "epsg:8857" (WGS 84 / Equal Earth Greenwich) for global coverage
77
buffer_crs: "epsg:8857"
88

9-
split_by: country_id
9+
split_by: country_id # typically country_id or shape_id
1010

1111
land_cover_types:
1212
POST_FLOODING: FARM

tests/integration/test_config.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,5 @@ module_area_potentials:
100100
protected: 0
101101
shapes_buffer:
102102
land: 10000 # meters
103+
104+
overrides: {}

workflow/rules/process.smk

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,23 @@ rule aggregate_area_potential:
9898
"""
9999

100100

101+
rule plot_aggregated_area_potential:
102+
message:
103+
"Plot aggregated area potential for the tech {wildcards.tech} in {wildcards.shape}."
104+
input:
105+
rules.aggregate_area_potential.output.aggregated_area_potential,
106+
output:
107+
report(
108+
"results/{shape}/area_potential_{tech}.png", category="area_potential_plot"
109+
),
110+
log:
111+
"logs/{shape}/plot_aggregated_area_potential_{tech}.log",
112+
conda:
113+
"../envs/default.yaml"
114+
script:
115+
"../scripts/tif_to_png.py"
116+
117+
101118
rule area_potential_report:
102119
message:
103120
"Generate an overview report of the area potential for all techs in shapes {wildcards.shape}."
@@ -107,16 +124,16 @@ rule area_potential_report:
107124
"results/{{shape}}/area_potential_{tech}.tif",
108125
tech=config["techs"].keys(),
109126
),
127+
area_potential_plots=expand(
128+
"results/{{shape}}/area_potential_{tech}.png",
129+
tech=config["techs"].keys(),
130+
),
110131
output:
111132
csv="results/{shape}/area_potential_report.csv",
112133
html=report(
113134
"results/{shape}/area_potential_report.html",
114135
category="area_potential_report_table",
115136
),
116-
png=report(
117-
"results/{shape}/area_potential_report.png",
118-
category="area_potential_report_image",
119-
),
120137
log:
121138
"logs/{shape}/area_potential_report.log",
122139
conda:

workflow/scripts/report.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,42 +3,39 @@
33
import geopandas as gpd
44
import pandas as pd
55
import rioxarray as rxr
6-
import script_utils
76
import xarray as xr
87
from resample import _rasterize_regions
98

109

11-
def report(shapes, area_potentials, csv_path, html_path, png_path):
10+
def report(shapes, area_potentials, csv_path, html_path):
1211
"""Generate a report summarizing area potentials for different technologies."""
1312
shapes = gpd.read_parquet(shapes)
1413

15-
ds_inputs = xr.Dataset()
16-
17-
# Collect the area potentials from the input files
18-
for area_potential in area_potentials:
19-
ds_inputs[area_potential] = rxr.open_rasterio(
20-
area_potential, mask_and_scale=True
21-
)
22-
23-
ds_inputs["regions"] = (
24-
("y", "x"),
25-
_rasterize_regions(shapes, ds_inputs[area_potential]),
14+
print("Generating reference raster and rasterizing regions...")
15+
reference_raster = rxr.open_rasterio(area_potentials[0])
16+
regions = xr.DataArray(
17+
_rasterize_regions(shapes, reference_raster),
18+
dims=("y", "x"),
19+
coords={"y": reference_raster.y, "x": reference_raster.x},
2620
)
21+
# regions = xr.DataArray(("y", "x"), _rasterize_regions(shapes, reference_raster))
22+
del reference_raster
2723

28-
script_utils.plot_all_dataset_variables(
29-
ds_inputs, ncols=2, savefig=png_path, categorical_vars=["regions"]
30-
)
31-
32-
ds_inputs = ds_inputs.squeeze().drop_vars(["band", "spatial_ref"])
33-
24+
# Collect the area potentials from the input files
3425
# Group the area potentials by regions, sum them up, and collect the resulting Series
3526
# into a DataFrame, where each column corresponds to a technology's area potential,
3627
# and the index corresponds to the regions.
3728
dataframes = []
38-
for area_potential in area_potentials:
39-
dataframes.append(
40-
ds_inputs[area_potential].groupby(ds_inputs["regions"]).sum().to_pandas()
29+
for area_potential_file in area_potentials:
30+
print(f"Processing area potential file: {area_potential_file}")
31+
da_area_potential = (
32+
rxr.open_rasterio(area_potential_file, mask_and_scale=True)
33+
.squeeze()
34+
.drop_vars(["band", "spatial_ref"])
4135
)
36+
dataframes.append(da_area_potential.groupby(regions).sum().to_pandas())
37+
del da_area_potential
38+
4239
df = pd.concat(dataframes, axis=1)
4340

4441
# Add metadata columns from shapes in front of the data columns
@@ -64,5 +61,4 @@ def report(shapes, area_potentials, csv_path, html_path, png_path):
6461
snakemake.input.area_potentials,
6562
snakemake.output.csv,
6663
snakemake.output.html,
67-
snakemake.output.png,
6864
)

workflow/scripts/resample.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -289,26 +289,6 @@ def resample_inputs(
289289
resampled.to_netcdf(output_path, encoding=netcdf4_encoding)
290290

291291
print("Saving image to plot path:", plot_path)
292-
# If needed, resample `resampled` to fit within a maximum of `max_pixels` pixels
293-
max_pixels = 5000000
294-
total_pixels = resampled.sizes["y"] * resampled.sizes["x"]
295-
if total_pixels > max_pixels:
296-
# Calculate the new resolution to fit within the max_pixels limit
297-
resolution_multiplier = 1 / math.sqrt(total_pixels / max_pixels)
298-
new_y_size = int(resampled.sizes["y"] * resolution_multiplier)
299-
new_x_size = int(resampled.sizes["x"] * resolution_multiplier)
300-
print(
301-
f"Resampling old size {resampled.sizes['y']} x {resampled.sizes['x']} "
302-
f"to new size: {new_y_size} x {new_x_size} "
303-
f"to fit within {max_pixels} pixels."
304-
)
305-
306-
resampled = resampled.coarsen(
307-
x=round(resampled.sizes["x"] / new_x_size),
308-
y=round(resampled.sizes["y"] / new_y_size),
309-
boundary="trim",
310-
).mean()
311-
312292
script_utils.plot_all_dataset_variables(resampled, ncols=3, savefig=plot_path)
313293

314294

workflow/scripts/script_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,26 @@ def random_categorical_cmap(n, base_cmap="tab20", seed=42):
2929

3030
def plot_all_dataset_variables(ds, ncols=2, savefig=None, categorical_vars=[]):
3131
"""Plot all variables in an xarray dataset on a grid of plots."""
32+
# If needed, downsample `ds` by coarsening so that it has at most `max_pixels` pixels
33+
max_pixels = 5000000
34+
total_pixels = ds.sizes["y"] * ds.sizes["x"]
35+
if total_pixels > max_pixels:
36+
# Calculate the new resolution to fit within the max_pixels limit
37+
resolution_multiplier = 1 / math.sqrt(total_pixels / max_pixels)
38+
new_y_size = int(ds.sizes["y"] * resolution_multiplier)
39+
new_x_size = int(ds.sizes["x"] * resolution_multiplier)
40+
print(
41+
f"Resampling old size {ds.sizes['y']} x {ds.sizes['x']} "
42+
f"to new size: {new_y_size} x {new_x_size} "
43+
f"to fit within {max_pixels} pixels."
44+
)
45+
46+
ds = ds.coarsen(
47+
x=round(ds.sizes["x"] / new_x_size),
48+
y=round(ds.sizes["y"] / new_y_size),
49+
boundary="trim",
50+
).mean()
51+
3252
# Drop dimensionless variables
3353
ds = ds.drop_vars(lambda x: [v for v, da in x.variables.items() if not da.ndim])
3454

workflow/scripts/tif_to_png.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""Plot the contents of a TIF raster file and save the figure as a PNG image."""
2+
3+
import rioxarray as rxr
4+
from script_utils import plot_all_dataset_variables
5+
6+
7+
def tif_to_png(tif_file_in, png_file_out):
8+
"""Plot the raster stored in a TIF file and save the resulting figure as a PNG."""
9+
ds = rxr.open_rasterio(tif_file_in, mask_and_scale=True).to_dataset(
10+
name=tif_file_in
11+
)
12+
plot_all_dataset_variables(
13+
ds, ncols=2, savefig=png_file_out, categorical_vars=["regions"]
14+
)
15+
16+
17+
if __name__ == "__main__":
18+
tif_to_png(snakemake.input[0], snakemake.output[0])

0 commit comments

Comments
 (0)