|
2 | 2 |
|
3 | 3 | import warnings |
4 | 4 | from datetime import datetime |
5 | | -from itertools import chain |
| 5 | +from itertools import chain, product |
6 | 6 | from typing import Literal, TypedDict, get_args |
7 | 7 |
|
8 | 8 | import cf_xarray # noqa: F401 |
|
17 | 17 |
|
18 | 18 | from xcdat import bounds # noqa: F401 |
19 | 19 | from xcdat._logger import _setup_custom_logger |
20 | | -from xcdat.axis import get_dim_coords |
| 20 | +from xcdat.axis import get_dim_coords, get_dim_keys |
21 | 21 | from xcdat.dataset import _get_data_var |
22 | 22 |
|
23 | 23 | logger = _setup_custom_logger(__name__) |
@@ -2086,6 +2086,210 @@ def _calculate_departures( |
2086 | 2086 | return ds_departs |
2087 | 2087 |
|
2088 | 2088 |
|
def compute_monthly_average(self, data_var):
    """
    Compute monthly averages of a data variable.

    The method works on a copy of the dataset. It first validates that
    every time bound is ordered low-to-high, then generates the target
    monthly bounds, derives the temporal weights of each source time
    step against those bounds, and finally computes the weighted
    average.

    Parameters
    ----------
    data_var : str
        The key of the data variable.

    Returns
    -------
    xr.Dataset
        Dataset with the computed monthly average.

    Notes
    -----
    The monthly bounds are generated from January through December of
    every year present in the source time axis, even when the source
    dataset starts after January or ends before December. A potential
    enhancement would be to tailor the bounds to the source dataset
    (e.g., if the source starts in March 2010, the monthly dataset
    would also start in March 2010).
    """
    dataset_copy = self._dataset.copy()
    accessor = dataset_copy.temporal

    # Fail early if any source time bound is not ordered low-to-high.
    accessor.ensure_bounds_order()

    # Only the bounds are needed here; the centered time axis that is
    # returned alongside them is not used.
    _monthly_axis, monthly_bounds = accessor.generate_monthly_bounds()

    # Weight of each source time step for each target month.
    month_weights = accessor.get_temporal_weights(monthly_bounds)

    return accessor._experimental_averager(data_var, month_weights, monthly_bounds)
| 2125 | + |
| 2126 | + |
def _experimental_averager(self, data_var, weights, target_bnds):
    """
    Calculate time period averages for a set of weights and bounds.

    Parameters
    ----------
    data_var : str
        The key of the data variable.
    weights : xr.DataArray
        The weight of each source time slice that should be used to
        compute a temporal average for each target time slice. It must
        carry a ``target_time`` dimension, which is renamed to the
        dataset's time dimension in the result.
    target_bnds : xr.DataArray
        The time bounds for the target time slices.

    Returns
    -------
    xr.Dataset
        The dataset with the computed temporal averages and the target
        time bounds stored under ``<time key>_bnds``.
    """
    source = self._dataset.copy()
    time_dim = get_dim_keys(source, 'T')

    # Weighted mean over the source time dimension; attributes are kept.
    weighted_var = source[data_var].weighted(weights)
    with xr.set_options(keep_attrs=True):
        averaged = weighted_var.mean(dim=time_dim)

    # The averaged array is indexed by "target_time"; give it the original
    # time dimension name and restore the original dimension order.
    averaged = averaged.rename({'target_time': time_dim})
    averaged = averaged.transpose(*source[data_var].dims)

    # Dropping the original time dimension removes every variable that
    # depends on it. The new time dimension and its coordinates come back
    # with the averaged variable when it is assigned below.
    result = source.copy().drop_dims(time_dim)
    result[data_var] = averaged
    result[time_dim + '_bnds'] = target_bnds
    return result
| 2172 | + |
| 2173 | + |
def get_temporal_weights(self, target_bnds):
    """Compute the temporal weights for a set of target time bounds.

    For every target interval, each source time bound is clipped to that
    interval and the remaining length is used as the weight. A source time
    step that does not overlap the target interval collapses to a
    zero-length interval and therefore gets a weight of zero.

    Parameters
    ----------
    target_bnds : xr.DataArray
        The bounds for target time averages. It must carry a coordinate
        named after the dataset's time dimension key.

    Returns
    -------
    xr.DataArray
        The temporal weights that should be applied to the source data
        to produce time averaged data corresponding to the target time
        bounds. Dimensions are ``("target_time", <time key>)``.
    """
    ds = self._dataset.copy()

    # Time key and source time bounds. The key is used consistently below
    # (instead of a hard-coded "time") so that the weights align with the
    # source data when the time dimension has a different name.
    time_key = get_dim_keys(ds, 'T')
    source_bnds = ds.cf.get_bounds(time_key).values
    target_time = target_bnds[time_key]

    # Preallocate the weight matrix: one row per target interval and one
    # column per source time step.
    weights = np.zeros((len(target_bnds), len(ds[time_key])))

    for i, tbnd in enumerate(target_bnds.values):
        # Clip the source bounds so they fit within the target bounds.
        sbnds = source_bnds.copy()
        sbnds[:, 0] = np.maximum(sbnds[:, 0], tbnd[0])
        sbnds[:, 1] = np.minimum(sbnds[:, 1], tbnd[1])

        # Source intervals entirely outside the target range: pin both
        # edges to the same target edge so the interval has zero length.
        sbnds[:, 0] = np.minimum(sbnds[:, 0], tbnd[1])
        sbnds[:, 1] = np.maximum(sbnds[:, 1], tbnd[0])

        # The weight is the length of the clipped interval.
        # NOTE(review): values are stored in a float array after the
        # timedelta64[ns] cast, so presumably nanoseconds -- confirm.
        weights[i, :] = (sbnds[:, 1] - sbnds[:, 0]).astype("timedelta64[ns]")

    return xr.DataArray(
        data=weights,
        dims=['target_time', time_key],
        coords={
            'target_time': target_time.values,
            time_key: ds[time_key].values,
        },
    )
| 2220 | + |
| 2221 | + |
def generate_monthly_bounds(self):
    """Generate monthly time bounds and the corresponding time axis.

    This method will generate monthly time bounds, e.g.,
        [["2010-01-01 00:00:00", "2010-02-01 00:00:00"],
         ["2010-02-01 00:00:00", "2010-03-01 00:00:00"],
         ["2010-03-01 00:00:00", "2010-04-01 00:00:00"],
         ...]

    and a time axis, e.g.,
        ["2010-01-16 12:00:00",
         "2010-02-15 00:00:00",
         "2010-03-16 12:00:00",
         ...]

    for a dataset. The arrays cover January through December of every
    year that appears in the original time axis, in ascending order.

    Returns
    -------
    monthly_time : xr.DataArray
        The centered time axis corresponding to the generated bounds,
        with its ``bounds`` attribute set to ``<time key>_bnds``.
    monthly_bnds : xr.DataArray
        The generated monthly bounds.

    Notes
    -----
    The time values must expose a ``.year`` attribute and their type must
    be constructible as ``type(year, month, day)``.
    NOTE(review): this presumably means cftime/datetime objects rather
    than ``numpy.datetime64`` -- confirm against callers.
    """
    ds = self._dataset.copy()
    time_key = get_dim_keys(ds, 'T')
    time_values = ds[time_key].values

    # Sort the distinct years: iterating a set gives no ordering guarantee,
    # which could otherwise produce a non-monotonic time axis.
    years = sorted({t.year for t in time_values})
    # The generated time objects use the same type as the source axis.
    ttype = type(time_values[0])

    # Create the target time bounds and the centered time axis.
    bnds_values = []
    time_centers = []
    for year, month in product(years, range(1, 13)):
        lower_bnd = ttype(year, month, 1)
        upper_bnd = ds.bounds._add_months_to_timestep(lower_bnd, ttype, 1)
        bnds_values.append([lower_bnd, upper_bnd])
        time_centers.append(lower_bnd + (upper_bnd - lower_bnd) / 2.)

    # Generate the xarray DataArray objects.
    time_axis = xr.DataArray(
        data=time_centers,
        dims=[time_key],
        coords={time_key: time_centers},
    )
    time_axis.encoding = ds[time_key].encoding

    monthly_bnds = xr.DataArray(
        data=bnds_values,
        dims=[time_key, 'bnds'],
        coords={time_key: time_axis},
    )
    monthly_bnds.encoding = ds[time_key].encoding

    # Return the time axis that carries the "bounds" attribute (it used to
    # be computed and then discarded). The encoding is re-applied because
    # assign_attrs returns a new object.
    monthly_time = time_axis.assign_attrs({'bounds': time_key + '_bnds'})
    monthly_time.encoding = ds[time_key].encoding

    return monthly_time, monthly_bnds
| 2276 | + |
| 2277 | + |
def ensure_bounds_order(self):
    """Check that every time bound is ordered [earlier, later].

    The bounds of the "T" axis are read from a copy of the dataset and
    each pair is checked; nothing is modified or returned.

    Raises
    ------
    ValueError
        If any lower bound is greater than or equal to its upper bound.
    """
    bounds_values = self._dataset.copy().bounds.get_bounds("T").values
    if any(pair[0] >= pair[1] for pair in bounds_values):
        raise ValueError('Time bounds are not ordered from low-to-high')
| 2291 | + |
| 2292 | + |
2089 | 2293 | def _infer_freq(time_coords: xr.DataArray) -> Frequency: |
2090 | 2294 | """Infers the time frequency from the coordinates. |
2091 | 2295 |
|
|
0 commit comments