Skip to content

Commit b62383e

Browse files
pochedls authored and tomvothecoder committed
Prototyping more sophisticated bounds handling for temporal averaging
1 parent 0d7e112 commit b62383e

1 file changed

Lines changed: 206 additions & 2 deletions

File tree

xcdat/temporal.py

Lines changed: 206 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import warnings
44
from datetime import datetime
5-
from itertools import chain
5+
from itertools import chain, product
66
from typing import Literal, TypedDict, get_args
77

88
import cf_xarray # noqa: F401
@@ -17,7 +17,7 @@
1717

1818
from xcdat import bounds # noqa: F401
1919
from xcdat._logger import _setup_custom_logger
20-
from xcdat.axis import get_dim_coords
20+
from xcdat.axis import get_dim_coords, get_dim_keys
2121
from xcdat.dataset import _get_data_var
2222

2323
logger = _setup_custom_logger(__name__)
@@ -2086,6 +2086,210 @@ def _calculate_departures(
20862086
return ds_departs
20872087

20882088

2089+
def compute_monthly_average(self, data_var):
    """
    Average a data variable over calendar months.

    Works on a copy of the dataset and chains four steps: validate the
    ordering of the source time bounds, build the target monthly bounds,
    derive the weight of each source time step within each target month,
    and compute the weighted average.

    Parameters
    ----------
    data_var : str
        The key of the data variable.

    Returns
    -------
    xr.Dataset
        Dataset with the computed monthly average.

    Raises
    ------
    ValueError
        If any source time bound is not ordered from low to high.

    Notes
    -----
    The target months always run January - December, even when the
    source dataset starts after January or ends before December. A
    potential enhancement would be to tailor the target bounds to the
    source dataset (e.g., a source starting in March 2010 would yield a
    monthly dataset that begins in March 2010).
    """
    source = self._dataset.copy()

    # Fail early when the source bounds are not ordered low-to-high.
    source.temporal.ensure_bounds_order()

    # Only the bounds are needed here; the centered time axis is discarded.
    _, monthly_bnds = source.temporal.generate_monthly_bounds()

    # Weight every source time step per target month, then average.
    return source.temporal._experimental_averager(
        data_var,
        source.temporal.get_temporal_weights(monthly_bnds),
        monthly_bnds,
    )
2125+
2126+
2127+
def _experimental_averager(self, data_var, weights, target_bnds):
    """
    Compute weighted time averages for a set of target time periods.

    Parameters
    ----------
    data_var : str
        The key of the data variable.

    weights : xr.DataArray
        The weight of each source time slice used to compute the average
        of each target time slice, laid out as [target_time, source_time].
        The target dimension must be named ``target_time`` because it is
        renamed to the dataset's time key after averaging.

    target_bnds : xr.DataArray
        The time bounds of the target time slices.

    Returns
    -------
    xr.Dataset
        The dataset with the computed temporal averages and the target
        time bounds stored under ``<time key>_bnds``.
    """
    source = self._dataset.copy()
    time_dim = get_dim_keys(source, 'T')
    weighted_var = source[data_var].weighted(weights)

    # Keep the variable's attributes through the reduction.
    with xr.set_options(keep_attrs=True):
        averaged = weighted_var.mean(dim=time_dim)

    # The reduction leaves the "target_time" dimension behind; give it the
    # original time name and restore the original dimension order.
    averaged = averaged.rename({'target_time': time_dim})
    averaged = averaged.transpose(*source[data_var].dims)

    # Dropping the source time dimension removes everything defined on it;
    # the averaged variable then brings the new time axis with it.
    result = source.copy().drop_dims(time_dim)
    result[data_var] = averaged
    result[time_dim + '_bnds'] = target_bnds
    return result
2172+
2173+
2174+
def get_temporal_weights(self, target_bnds):
    """Compute the temporal weights for a set of target time bounds.

    For each target interval, every source time step is weighted by the
    length of the overlap between its own bounds and the target interval.
    Source steps that do not overlap a target interval get a weight of
    zero for that interval.

    Parameters
    ----------
    target_bnds : xr.DataArray
        The bounds for target time averages. It must carry a coordinate
        named after the dataset's time key, holding the target time axis.

    Returns
    -------
    xr.DataArray
        The temporal weights that should be applied to the source data
        to produce time averaged data corresponding to the target time
        bounds. Dimensions are ``("target_time", <time key>)``; values are
        the overlap durations converted from ``timedelta64[ns]``.
    """
    ds = self._dataset.copy()

    # Get time key and source time bounds.
    time_key = get_dim_keys(ds, 'T')
    source_bnds = ds.cf.get_bounds(time_key).values

    # Use the dataset's own time key rather than a hard-coded "time" so
    # datasets whose time dimension has another name are handled too.
    target_time = target_bnds[time_key]

    source_lower = source_bnds[:, 0]
    source_upper = source_bnds[:, 1]

    # Preallocate weight matrix [target step, source step].
    weights = np.zeros((len(target_bnds), len(ds[time_key])))

    for i, tbnd in enumerate(target_bnds.values):
        # Clip each source bound into the target interval. A source step
        # lying entirely outside the interval collapses onto one of its
        # edges, which yields a zero-length overlap.
        lower = np.minimum(np.maximum(source_lower, tbnd[0]), tbnd[1])
        upper = np.maximum(np.minimum(source_upper, tbnd[1]), tbnd[0])

        # Weight is the length of the overlap.
        weights[i, :] = (upper - lower).astype("timedelta64[ns]")

    # The source dimension is named after the time key so the weights
    # align with the data variable when they are applied.
    return xr.DataArray(
        data=weights,
        dims=['target_time', time_key],
        coords={
            'target_time': target_time.values,
            time_key: ds[time_key].values,
        },
    )
2220+
2221+
2222+
def generate_monthly_bounds(self):
    """Generates monthly time bounds and the corresponding time axis
    for a dataset.

    This method will generate monthly time bounds, e.g.,
        [["2010-01-01 00:00:00", "2010-02-01 00:00:00"],
         ["2010-02-01 00:00:00", "2010-03-01 00:00:00"],
         ["2010-03-01 00:00:00", "2010-04-01 00:00:00"],
         ...]

    and a time axis, e.g.,
        ["2010-01-16 12:00:00",
         "2010-02-15 00:00:00",
         "2010-03-16 12:00:00",
         ...]

    for a dataset. Twelve months (January - December) are generated, in
    chronological order, for every year that appears in the dataset's
    time axis.

    Returns
    -------
    monthly_time : xr.DataArray
        The centered time axis corresponding to the generated bounds. Its
        ``bounds`` attribute is set to ``<time key>_bnds``.

    monthly_bnds : xr.DataArray
        The generated monthly bounds.

    Notes
    -----
    The time values must expose a ``.year`` attribute and their type must
    be constructible as ``type(year, month, day)``.
    NOTE(review): this presumably means cftime/datetime-like objects
    rather than ``np.datetime64`` values -- confirm against callers.
    """
    ds = self._dataset.copy()
    time_key = get_dim_keys(ds, 'T')
    time_values = ds[time_key].values

    # Sets are unordered, so sort the years explicitly to guarantee that
    # the generated months are chronological.
    years = sorted({t.year for t in time_values})

    # Reuse the source time type so the generated values match it.
    ttype = type(time_values[0])

    # Create target time bounds and time axis.
    monthly_bnds = []
    monthly_time = []
    for year, month in product(years, range(1, 13)):
        lower_bnd = ttype(year, month, 1)
        upper_bnd = ds.bounds._add_months_to_timestep(lower_bnd, ttype, 1)
        monthly_bnds.append([lower_bnd, upper_bnd])
        # The time axis value is the midpoint of the month.
        monthly_time.append(lower_bnd + (upper_bnd - lower_bnd) / 2.)

    # Generate xarray DataArray objects. The "bounds" attribute is set on
    # the returned time axis so that it points at its bounds variable.
    monthly_time = xr.DataArray(
        data=monthly_time,
        dims=[time_key],
        coords={time_key: monthly_time},
        attrs={'bounds': time_key + '_bnds'},
    )
    monthly_time.encoding = ds[time_key].encoding

    monthly_bnds = xr.DataArray(
        data=monthly_bnds,
        dims=[time_key, 'bnds'],
        coords={time_key: monthly_time},
    )
    monthly_bnds.encoding = ds[time_key].encoding

    return monthly_time, monthly_bnds
2276+
2277+
2278+
def ensure_bounds_order(self):
    """Validate that every time bound is ordered [earlier, later].

    Raises
    ------
    ValueError
        If the lower bound of any time step is equal to or later than
        its upper bound.
    """
    bounds = self._dataset.copy().bounds.get_bounds("T")

    # Stops at the first offending pair, like an explicit loop would.
    if any(pair[0] >= pair[1] for pair in bounds.values):
        raise ValueError('Time bounds are not ordered from low-to-high')
2291+
2292+
20892293
def _infer_freq(time_coords: xr.DataArray) -> Frequency:
20902294
"""Infers the time frequency from the coordinates.
20912295

0 commit comments

Comments
 (0)