|
1 | | -from collections.abc import Callable |
2 | | -from dataclasses import dataclass |
3 | | -from typing import Protocol |
4 | | - |
5 | | -import xarray as xr |
| 1 | +from isku.extract import ( |
| 2 | + ExtractionTemplate, |
| 3 | + GridWeightingRegions, |
| 4 | + RegionExtractor, |
| 5 | + build_extraction_template, |
| 6 | + extract_regions, |
| 7 | +) |
| 8 | +from isku.project import ProjectionTemplate, build_projection_template, project |
6 | 9 |
|
7 | 10 | __all__ = [ |
8 | 11 | "ExtractionTemplate", |
|
14 | 17 | "extract_regions", |
15 | 18 | "project", |
16 | 19 | ] |
17 | | - |
18 | | - |
class ExtractionTemplate(Protocol):
    """
    Pair of dataset transformations applied around region extraction

    An implementation supplies one step that runs on the dataset before
    regions are extracted and one step that runs on the regionalized result.

    See Also
    --------
    build_extraction_template: Build a template from two plain functions.
    extract_regions: Apply a template while extracting regions from gridded data.
    RegionExtractor: Protocol for the region extraction step itself.
    """

    def pre_extract(self, ds: xr.Dataset) -> xr.Dataset:
        """
        Transform the dataset before regions are extracted
        """

    def post_extract(self, ds: xr.Dataset) -> xr.Dataset:
        """
        Transform the dataset after regions are extracted
        """
41 | | - |
42 | | - |
class RegionExtractor(Protocol):
    """
    Protocol for objects that turn gridded data into regionalized data

    See Also
    --------
    extract_regions: Run a region extractor between the pre/post steps of a template.
    ExtractionTemplate: Protocol for the pre/post regionalization transformations.
    """

    def extract_regions(self, ds: xr.Dataset) -> xr.Dataset:
        """
        Aggregate the points of a gridded dataset into a regionalized dataset
        """
58 | | - |
59 | | - |
# This dataclass is a quick and simple way to get a concrete instance of the protocol.
# Each field holds a plain callable under the same name as the protocol method it
# stands in for, so calling that attribute on an instance invokes the stored callable.
# NOTE(review): the field names match methods inherited from ExtractionTemplate;
# confirm dataclass does not pick up those inherited stubs as field defaults.
@dataclass(frozen=True)
class _SimpleExtractionTemplate(ExtractionTemplate):
    pre_extract: Callable[[xr.Dataset], xr.Dataset]
    post_extract: Callable[[xr.Dataset], xr.Dataset]
65 | | - |
66 | | - |
def build_extraction_template(
    *, pre: Callable[[xr.Dataset], xr.Dataset], post: Callable[[xr.Dataset], xr.Dataset]
) -> ExtractionTemplate:
    """
    Build an extraction template from a pre-regionalization and a post-regionalization function

    'pre' becomes the template's ``pre_extract`` step and 'post' becomes its
    ``post_extract`` step. Both arguments are keyword-only.

    These steps should be general. They may contain logic for sanity checks
    on inputs and outputs, calculating derived variables and climate indices,
    adding or checking metadata or units. Avoid including logic for cleaning,
    or harmonizing input data, especially if it is specific to a single
    project's usecase. Generally avoid using a single strategy to output
    multiple unrelated variables.

    See Also
    --------
    extract_regions: Apply a template to extract a new regionalized dataset from gridded data.
    ExtractionTemplate: The underlying protocol for a workflow that extracts a regionalized dataset.
    """
    template = _SimpleExtractionTemplate(pre_extract=pre, post_extract=post)
    return template
89 | | - |
90 | | - |
# Use class for segment weights because we're making assumptions/enforcements about the weight data's content and interactions...
class GridWeightingRegions(RegionExtractor):
    """
    Region extractor that weights points of regularly-gridded data and sums them per region

    The 'weights' dataset must contain "lat", "lon", "weight" and "region".

    Raises
    ------
    ValueError
        If 'weights' lacks any of the "lat", "lon", "weight" or "region" variables.

    See Also
    --------
    extract_regions: Extract new regionalized dataset.
    build_extraction_template: Quickly build extraction template from functions for regionalization.
    RegionExtractor: Protocol for regionalizing, or extracting regions from a dataset.
    """

    def __init__(self, weights: xr.Dataset):
        required = ("lat", "lon", "weight", "region")
        available = weights.variables
        # Keep the order of 'required' so the error message lists names predictably.
        absent = [name for name in required if name not in available]
        if absent:
            raise ValueError(f"input weights is missing required {absent} variable(s)")
        self._data = weights

    def extract_regions(self, ds: xr.Dataset) -> xr.Dataset:
        """
        Select 'ds' at the weights' lat/lon, multiply by the weights and sum the product within each region.

        'ds' must have "lat", "lon" coordinates exactly matching "lat", "lon" in weights.
        """
        weights = self._data
        # TODO: See how this errors in different common scenarios. What happens on the
        #   unhappy path?
        selected = ds.sel(lat=weights["lat"], lon=weights["lon"])
        weighted = selected * weights["weight"]
        # TODO: Maybe drop lat/lon and set 'region' as dim/coord? I feel like we can do
        #   this because we're asking weights to strictly match input's lat/lon. Maybe
        #   make this a req of segment weights we're reading in?
        return weighted.groupby(weights["region"]).sum()
133 | | - |
134 | | - |
def extract_regions(
    ds: xr.Dataset, *, template: ExtractionTemplate, regions: RegionExtractor
) -> xr.Dataset:
    """
    Extract 'regions' from gridded dataset 'ds', wrapped in the transformations of 'template'

    The steps run in order: ``template.pre_extract``, then
    ``regions.extract_regions``, then ``template.post_extract``. The result of
    the last step is returned. This is more than zonal aggregation: the
    pre/post steps let the template create new datasets and variables.

    See Also
    --------
    build_extraction_template: Quickly build extraction workflow from functions for regionalization.
    """
    prepared = template.pre_extract(ds)
    regionalized = regions.extract_regions(prepared)
    return template.post_extract(regionalized)
148 | | - |
149 | | - |
class ProjectionTemplate(Protocol):
    """
    Three-step template for projecting a model: pre-process, project, post-process

    See Also
    --------
    build_projection_template: Build a projection template from simple functions.
    """

    def pre_project(self, d: xr.Dataset) -> xr.Dataset:
        """
        Prepare a dataset ahead of projection
        """

    def project(self, d: xr.Dataset) -> xr.Dataset:
        """
        Produce a projection from a dataset
        """

    def post_project(self, d: xr.Dataset) -> xr.Dataset:
        """
        Process a dataset that has been projected
        """
176 | | - |
177 | | - |
# This dataclass is a quick and simple way to get a concrete instance of the protocol.
# Each field holds a plain callable under the same name as the protocol method it
# stands in for, so calling that attribute on an instance invokes the stored callable.
# NOTE(review): the field names match methods inherited from ProjectionTemplate;
# confirm dataclass does not pick up those inherited stubs as field defaults.
@dataclass(frozen=True)
class _SimpleProjectionTemplate(ProjectionTemplate):
    pre_project: Callable[[xr.Dataset], xr.Dataset]
    project: Callable[[xr.Dataset], xr.Dataset]
    post_project: Callable[[xr.Dataset], xr.Dataset]
184 | | - |
185 | | - |
def build_projection_template(
    *,
    pre: Callable[[xr.Dataset], xr.Dataset],
    project: Callable[[xr.Dataset], xr.Dataset],
    post: Callable[[xr.Dataset], xr.Dataset],
) -> ProjectionTemplate:
    """
    Build a model for projecting effects, impacts and/or damages from three simple functions.

    'pre', 'project' and 'post' become the template's ``pre_project``,
    ``project`` and ``post_project`` steps respectively. All arguments are
    keyword-only.

    See Also
    --------
    project: Apply a projection template to a dataset.
    ProjectionTemplate: Technical ProjectionTemplate protocol.
    """
    template = _SimpleProjectionTemplate(pre_project=pre, project=project, post_project=post)
    return template
207 | | - |
208 | | - |
def project(d: xr.Dataset, *, model: ProjectionTemplate) -> xr.Dataset:
    """
    Run the predictors dataset 'd' through 'model' and return the projected dataset

    The steps of 'model' run in order: ``pre_project``, ``project``, then
    ``post_project``. The output of the last step is returned.

    See Also
    --------
    build_projection_template: Build a projection template from simple functions.
    ProjectionTemplate: Technical ProjectionTemplate protocol.
    """
    return model.post_project(model.project(model.pre_project(d)))
0 commit comments