Skip to content

Commit b04717a

Browse files
committed
fix io funcs for on-the-fly quantiles, fix storage_options default, add other selectors to prep_sliiders, allow other name options for dims of input stores, improve performance of create_template_dataarray, drop now unnecessary save_to_zarr_region wrapper
1 parent ac47952 commit b04717a

1 file changed

Lines changed: 90 additions & 126 deletions

File tree

pyCIAM/io.py

Lines changed: 90 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def prep_sliiders(
3636
selectors={},
3737
calc_popdens_with_wetland_area=True,
3838
expand_exposure=True,
39-
storage_options={},
39+
storage_options=None,
4040
):
4141
"""Import the SLIIDERS dataset (or a different dataset formatted analogously).
4242
@@ -102,7 +102,10 @@ def prep_sliiders(
102102
"""
103103
inputs_all = xr.open_zarr(
104104
str(input_store), chunks=None, storage_options=storage_options
105-
).sel(selectors, drop=True)
105+
)
106+
inputs_all = inputs_all.sel(
107+
{k: v for k, v in selectors.items() if k in inputs_all.dims}, drop=True
108+
)
106109

107110
inputs = inputs_all.sel({seg_var: seg_vals})
108111
inputs = _s2d(inputs).assign(constants)
@@ -193,30 +196,28 @@ def _load_scenario_mc(
193196
include_cc=True,
194197
quantiles=None,
195198
ncc_name="ncc",
196-
storage_options={},
199+
storage_options=None,
197200
):
198201
scen_mc_filter = xr.open_zarr(
199202
str(slr_store), chunks=None, storage_options=storage_options
200203
)[["scenario", mc_dim]]
201204
if quantiles is not None:
202205
if mc_dim == "quantile":
203-
scen_mc_filter = scen_mc_filter.sel(quantile=quantiles)
206+
scen_mc_filter = scen_mc_filter.sel(quantile=quantiles).sortby(
207+
["scenario", mc_dim]
208+
)
204209
else:
205-
scen_mc_filter = scen_mc_filter.quantile(quantiles, dim=mc_dim)
206-
207-
scen_mc_filter = (
208-
scen_mc_filter.to_dataframe().sort_values(["scenario", mc_dim]).index
209-
)
210+
scen_mc_filter = scen_mc_filter.scenario.sortby("scenario")
211+
scen_mc_filter = scen_mc_filter.to_dataframe().index
210212

211213
if include_ncc:
212-
scen_mc_filter = scen_mc_filter.append(
213-
pd.MultiIndex.from_product(
214-
(
215-
[ncc_name],
216-
scen_mc_filter.get_level_values(mc_dim).unique().sort_values(),
217-
),
218-
names=["scenario", mc_dim],
219-
)
214+
scen_mc_filter = scen_mc_filter.union(
215+
pd.DataFrame(index=scen_mc_filter)
216+
.reset_index()
217+
.assign(scenario=ncc_name)
218+
.drop_duplicates()
219+
.set_index(scen_mc_filter.names)
220+
.index
220221
)
221222

222223
if not include_cc:
@@ -236,9 +237,10 @@ def _load_lslr_for_ciam(
236237
mc_dim="mc_sample_id",
237238
lsl_var="lsl_msl05",
238239
lsl_ncc_var="lsl_ncc_msl05",
240+
site_id_dim="site_id",
239241
ncc_name="ncc",
240242
slr_0_year=2005,
241-
storage_options={},
243+
storage_options=None,
242244
quantiles=None,
243245
):
244246
if scen_mc_filter is None:
@@ -253,7 +255,10 @@ def _load_lslr_for_ciam(
253255
)
254256

255257
wcc = scen_mc_filter.get_level_values("scenario") != ncc_name
256-
scen_mc_ncc = scen_mc_filter[~wcc].droplevel("scenario").values
258+
if mc_dim in scen_mc_filter.names:
259+
scen_mc_ncc = scen_mc_filter[~wcc].droplevel("scenario").values
260+
else:
261+
scen_mc_ncc = None
257262
scen_mc_xr_wcc = (
258263
scen_mc_filter[wcc]
259264
.to_frame()
@@ -266,44 +271,56 @@ def _load_lslr_for_ciam(
266271

267272
# select the nearest SLR locations to the passed locations
268273
slr = _s2d(
269-
slr.sel(site_id=get_nearest_slrs(slr, lonlats).to_xarray()).drop("site_id")
274+
slr.sel({site_id_dim: get_nearest_slrs(slr, lonlats).to_xarray()}).drop(
275+
site_id_dim
276+
)
270277
).drop(["lat", "lon"], errors="ignore")
271278

279+
# convert to meters
280+
for v in slr.data_vars:
281+
if "units" in slr[v].attrs:
282+
slr[v] = slr[v].pint.quantify().pint.to("meters").pint.dequantify()
283+
272284
# select only the scenarios we wish to model
273285
if len(scen_mc_xr_wcc.scen_mc):
274286
slr_out = (
275287
slr[lsl_var]
276-
.sel({"scenario": scen_mc_xr_wcc.scenario, mc_dim: scen_mc_xr_wcc[mc_dim]})
277-
.set_index(scen_mc=["scenario", mc_dim])
288+
.sel({k: scen_mc_xr_wcc[k] for k in scen_mc_xr_wcc.data_vars})
289+
.set_index(scen_mc=list(scen_mc_xr_wcc.data_vars.keys()))
278290
)
279291
else:
280292
slr_out = xr.DataArray(
281293
[],
282294
dims=("scen_mc",),
283295
coords={
284-
"scen_mc": pd.MultiIndex.from_tuples([], names=["scenario", mc_dim])
296+
"scen_mc": pd.MultiIndex.from_tuples(
297+
[], names=list(scen_mc_xr_wcc.data_vars.keys())
298+
)
299+
if mc_dim in scen_mc_xr_wcc.data_vars
300+
else pd.Index([], name="scen_mc")
285301
},
286302
)
287303

288-
if len(scen_mc_ncc):
289-
slr_ncc = (
290-
slr[lsl_ncc_var]
291-
.sel({mc_dim: scen_mc_ncc})
292-
.expand_dims(scenario=[ncc_name])
293-
.stack(scen_mc=["scenario", mc_dim])
294-
)
304+
if include_ncc:
305+
slr_ncc = slr[lsl_ncc_var]
306+
stack_dims = ["scenario"]
307+
if scen_mc_ncc is not None:
308+
slr_ncc = slr_ncc.sel({mc_dim: scen_mc_ncc})
309+
stack_dims.append(mc_dim)
310+
311+
slr_ncc = slr_ncc.expand_dims(scenario=[ncc_name])
312+
if len(stack_dims) > 1:
313+
slr_ncc = slr_ncc.stack(scen_mc=stack_dims)
314+
else:
315+
slr_ncc = slr_ncc.rename({stack_dims[0]: "scen_mc"})
295316
slr_out = xr.concat((slr_out, slr_ncc), dim="scen_mc").sel(
296317
scen_mc=scen_mc_filter
297318
)
298319

299-
if "units" in slr_out.attrs:
300-
ix_names = slr_out.indexes["scen_mc"].names
301-
# hack to avoid pint destroying multi-indexed coords
302-
slr_out = (
303-
slr_out.pint.quantify()
304-
.pint.to("meters")
305-
.pint.dequantify()
306-
.set_index(scen_mc=ix_names)
320+
if quantiles is not None and mc_dim != "quantile":
321+
slr_out = slr_out.quantile(quantiles, dim=mc_dim)
322+
slr_out = slr_out.rename(scen_mc="scenario").stack(
323+
scen_mc=["scenario", "quantile"]
307324
)
308325

309326
# add on base year where slr is 0
@@ -347,14 +364,19 @@ def create_template_dataarray(dims, coords, chunks, dtype="float32", name=None):
347364
An empty dask-backed DataArray.
348365
"""
349366
lens = {k: len(v) for k, v in coords.items()}
350-
return xr.DataArray(
351-
da.empty(
352-
[lens[k] for k in dims], chunks=[chunks[k] for k in dims], dtype=dtype
353-
),
367+
out = xr.DataArray(
368+
da.empty([lens[k] for k in dims], chunks=[-1] * len(dims), dtype=dtype),
354369
dims=dims,
355370
coords={k: v for k, v in coords.items() if k in dims},
356371
name=name,
357372
)
373+
out.encoding["chunks"] = [chunks[k] for k in dims]
374+
if np.issubdtype(np.dtype(dtype), np.integer):
375+
fill_value = np.iinfo(dtype).max
376+
else:
377+
fill_value = "NaN"
378+
out.encoding["fill_value"] = fill_value
379+
return out
358380

359381

360382
def create_template_dataset(var_dims, coords, chunks, dtypes):
@@ -400,7 +422,7 @@ def check_finished_zarr_workflow(
400422
varname=None,
401423
final_selector={},
402424
mask=None,
403-
storage_options={},
425+
storage_options=None,
404426
):
405427
"""Check if a workflow that writes to a particular region of a zarr store has
406428
already run. This is useful when running pyCIAM in "probabilistic" mode across a
@@ -484,80 +506,6 @@ def check_finished_zarr_workflow(
484506
return finished
485507

486508

487-
def save_to_zarr_region(ds_in, store, already_aligned=False, storage_options={}):
488-
"""Wrapper around :py:method:`xarray.Dataset.to_zarr` when specifying the `region`
489-
kwarg. This function allows you to avoid boilerplate to figure out the integer slice
490-
objects needed to pass as `region` when calling `:py:meth:xarray.Dataset.to_zarr`.
491-
492-
Parameters
493-
----------
494-
ds_in : :py:class:`xarray.Dataset` or :py:class:`xarray.DataArray`
495-
Dataset or DataArray to save to a specific region of a Zarr store
496-
store : Path-like
497-
Path to Zarr store
498-
already_aligned : bool, default False
499-
If True, assume that the coordinates of `ds_in` are already ordered the same
500-
way as those of `store`. May save some computation, but will miss-attribute
501-
values to coordinates if set to True when coords are not aligned.
502-
storage_options : dict, optional
503-
Passed to :py:function:`xarray.open_zarr`
504-
505-
Returns
506-
-------
507-
None :
508-
No return value but `ds_in` is saved to the appropriate region of `store`.
509-
510-
Raises
511-
------
512-
ValueError
513-
If `ds_in` is an unnamed DataArray and `store` has more than one variable.
514-
AssertionError
515-
If any coordinate values of `ds_in` are not contiguous within `store`.
516-
"""
517-
ds_out = xr.open_zarr(str(store), chunks=None, storage_options=storage_options)
518-
519-
# convert dataarray to dataset if needed
520-
if isinstance(ds_in, xr.DataArray):
521-
if ds_in.name is not None:
522-
ds_in = ds_in.to_dataset()
523-
else:
524-
if len(ds_out.data_vars) != 1:
525-
raise ValueError(
526-
"``ds_in`` is an unnamed DataArray and ``store`` has more than one "
527-
"variable."
528-
)
529-
ds_in = ds_in.to_dataset(name=list(ds_out.data_vars)[0])
530-
531-
# align
532-
for v in ds_in.data_vars:
533-
ds_in[v] = ds_in[v].transpose(*ds_out[v].dims).astype(ds_out[v].dtype)
534-
535-
# find appropriate regions
536-
alignment_dims = {}
537-
regions = {}
538-
for r in ds_in.dims:
539-
if len(ds_in[r]) == len(ds_out[r]):
540-
alignment_dims[r] = ds_out[r].values
541-
continue
542-
alignment_dims[r] = [v for v in ds_out[r].values if v in ds_in[r].values]
543-
valid_ixs = np.arange(len(ds_out[r]))[ds_out[r].isin(alignment_dims[r]).values]
544-
n_valid = len(valid_ixs)
545-
st = valid_ixs[0]
546-
end = valid_ixs[-1]
547-
assert end - st == n_valid - 1, (
548-
f"Indices are not continuous along dimension {r}"
549-
)
550-
regions[r] = slice(st, end + 1)
551-
552-
# align coords
553-
if not already_aligned:
554-
ds_in = ds_in.sel(alignment_dims)
555-
556-
ds_in.drop_vars(ds_in.coords).to_zarr(
557-
str(store), region=regions, storage_options=storage_options
558-
)
559-
560-
561509
def get_nearest_slrs(slr_ds, lonlats, x1="seg_lon", y1="seg_lat"):
562510
unique_lonlats = lonlats[[x1, y1]].drop_duplicates()
563511
slr_lonlat = slr_ds[["lon", "lat"]].to_dataframe()
@@ -582,9 +530,11 @@ def load_ciam_inputs(
582530
input_store,
583531
slr_store,
584532
params,
585-
seg_vals,
533+
selectors={},
586534
slr_names=None,
587535
seg_var="seg",
536+
lsl_var="lsl_msl05",
537+
slr_site_id_dim="site_id",
588538
surge_lookup_store=None,
589539
ssp=None,
590540
iam=None,
@@ -593,7 +543,7 @@ def load_ciam_inputs(
593543
include_cc=True,
594544
mc_dim="mc_sample_id",
595545
quantiles=None,
596-
storage_options={},
546+
storage_options=None,
597547
):
598548
"""Load, process, and format all inputs needed to run pyCIAM.
599549
@@ -608,9 +558,8 @@ def load_ciam_inputs(
608558
params : dict
609559
Dictionary of model parameters, typically loaded from a JSON file. See
610560
:file:`../params.json` for an example of the required parameters.
611-
seg_vals : list of str
612-
Defines the subset of regions (along dimension `seg_var`) that the function
613-
will prep. Subsets are used to run CIAM in parallel.
561+
selectors : list of str
562+
Defines the subset of regions (along dimension `seg_var`) and/or scenario that the function will prep. Subsets are used to run CIAM in parallel.
614563
slr_names : list of str, optional
615564
If `slr_store` is a list of multiple SLR datasets, this must be a list of the
616565
same length providing names for each SLR dataset. This is used as a suffix for
@@ -620,6 +569,10 @@ def load_ciam_inputs(
620569
seg_var : str, default "seg_var"
621570
The name of the dimension in `input_store` along which the function will
622571
subset using `seg_vals`
572+
lsl_var : str, default "lsl_msl05"
573+
The name of the variable in ``slr_store`` containing local SLR values
574+
slr_site_id_dim : str, default "site_id"
575+
The name of the location dimension in ``slr_store``.
623576
surge_lookup_store : Path-like, optional
624577
If not None, will also load and process data from an ESL impacts lookup table
625578
(see `lookup.create_surge_lookup`). If included in a call to
@@ -678,14 +631,14 @@ def load_ciam_inputs(
678631
If `ssp` or `iam` is specified and the corresponding variables are not
679632
present in the Zarr store located at `input_store`.
680633
"""
681-
selectors = {"year": slice(params.model_start, None)}
634+
selectors = {"year": slice(params.model_start, None), **selectors}
682635
if ssp is not None:
683636
selectors["ssp"] = ssp
684637
if iam is not None:
685638
selectors["iam"] = iam
686639
inputs = prep_sliiders(
687640
input_store,
688-
seg_vals,
641+
selectors[seg_var],
689642
# dropping the "refA_scenario_selectors" b/c this doesn't need to be added to
690643
# the input dataset object
691644
constants=params[params.map(type) != dict].to_dict(), # noqa: E721
@@ -704,7 +657,7 @@ def load_ciam_inputs(
704657
xr.open_zarr(
705658
str(surge_lookup_store), chunks=None, storage_options=storage_options
706659
)
707-
.sel({seg_var: seg_vals})
660+
.sel({seg_var: selectors["seg"]})
708661
.load()
709662
)
710663
if seg_var != "seg":
@@ -730,20 +683,31 @@ def load_ciam_inputs(
730683
include_ncc=include_ncc,
731684
include_cc=include_cc,
732685
ncc_name=ncc_names[sx],
686+
lsl_var=lsl_var,
733687
mc_dim=mc_dim,
734688
quantiles=quantiles,
689+
site_id_dim=slr_site_id_dim,
735690
storage_options=storage_options,
736691
)
737692
for sx, s in enumerate(slr_store)
738693
],
739694
dim="scen_mc",
740695
)
696+
if scen_mc_filter is None:
697+
slr = slr.unstack("scen_mc")
698+
699+
slr = slr.sel({k: v for k, v in selectors.items() if k in slr.dims})
741700

742701
return inputs, slr, surge
743702

744703

745704
def load_diaz_inputs(
746-
input_store, seg_vals, params, include_ncc=True, include_cc=True, storage_options={}
705+
input_store,
706+
seg_vals,
707+
params,
708+
include_ncc=True,
709+
include_cc=True,
710+
storage_options=None,
747711
):
748712
"""Load the original inputs used in Diaz 2016.
749713

0 commit comments

Comments
 (0)