Skip to content

Commit 8dc2383

Browse files
authored
Merge pull request #24 from brews/internal_refactor
Move core logic out of __init__
2 parents aef2941 + 7c6acbd commit 8dc2383

3 files changed

Lines changed: 238 additions & 211 deletions

File tree

src/isku/__init__.py

Lines changed: 8 additions & 211 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
from collections.abc import Callable
2-
from dataclasses import dataclass
3-
from typing import Protocol
4-
5-
import xarray as xr
1+
from isku.extract import (
2+
ExtractionTemplate,
3+
GridWeightingRegions,
4+
RegionExtractor,
5+
build_extraction_template,
6+
extract_regions,
7+
)
8+
from isku.project import ProjectionTemplate, build_projection_template, project
69

710
__all__ = [
811
"ExtractionTemplate",
@@ -14,209 +17,3 @@
1417
"extract_regions",
1518
"project",
1619
]
17-
18-
19-
class ExtractionTemplate(Protocol):
20-
"""
21-
Template for pre and post region extraction transformation
22-
23-
See Also
24-
--------
25-
build_extraction_template: Quickly build extraction template from functions for regionalization with pre/post transformations.
26-
extract_regions: Apply a template to extract a new regionalized dataset from gridded data.
27-
RegionExtractor: Protocol for regionalizing, or extracting regions from a dataset.
28-
"""
29-
30-
def pre_extract(self, ds: xr.Dataset) -> xr.Dataset:
31-
"""
32-
Transform dataset before region extraction
33-
"""
34-
...
35-
36-
def post_extract(self, ds: xr.Dataset) -> xr.Dataset:
37-
"""
38-
Transform dataset after region extraction
39-
"""
40-
...
41-
42-
43-
class RegionExtractor(Protocol):
44-
"""
45-
Protocol for extracting regions from gridded data
46-
47-
See Also
48-
--------
49-
extract_regions: Apply a template to extract a new regionalized dataset from gridded data with pre/post transformations.
50-
ExtractionTemplate: Technical protocol for a workflow with pre/post regionalization transformations.
51-
"""
52-
53-
def extract_regions(self, ds: xr.Dataset) -> xr.Dataset:
54-
"""
55-
Extract and aggregate gridded dataset points into regionalized dataset
56-
"""
57-
...
58-
59-
60-
# This dataclass is a quick and simple way to get a concrete instance of the protocol.
61-
@dataclass(frozen=True)
62-
class _SimpleExtractionTemplate(ExtractionTemplate):
63-
pre_extract: Callable[[xr.Dataset], xr.Dataset]
64-
post_extract: Callable[[xr.Dataset], xr.Dataset]
65-
66-
67-
def build_extraction_template(
68-
*, pre: Callable[[xr.Dataset], xr.Dataset], post: Callable[[xr.Dataset], xr.Dataset]
69-
) -> ExtractionTemplate:
70-
"""
71-
Build a template of transformation steps applied to input gridded data, pre/post regionalization, to create a derived variable as output
72-
73-
This function is a quick and simple way to build an ExtractionTemplate from two simple functions.
74-
75-
These steps should be general. They may contain logic for sanity checks
76-
on inputs and outputs, calculating derived variables and climate indices,
77-
adding or checking metadata or units. Avoid including logic for cleaning,
78-
or harmonizing input data, especially if it is specific to a single
79-
project's use case. Generally avoid using a single strategy to output
80-
multiple unrelated variables.
81-
82-
See Also
83-
--------
84-
extract_regions: Apply a template to extract a new regionalized dataset from gridded data.
85-
build_extraction_template: Quickly build extraction template from functions for regionalization.
86-
ExtractionTemplate: The underlying protocol for a workflow that extracts a regionalized dataset.
87-
"""
88-
return _SimpleExtractionTemplate(pre_extract=pre, post_extract=post)
89-
90-
91-
# Use class for segment weights because we're making assumptions/enforcements about the weight data's content and interactions...
92-
class GridWeightingRegions(RegionExtractor):
93-
"""
94-
Regions that can be extracted from regularly-gridded data after weighting grid points
95-
96-
'weights' dataset must have "lat", "lon", "weight", and "region" variables.
97-
98-
Raises
99-
------
100-
ValueError
101-
If 'weights' is missing "lat", "lon", "weight" or "region" variables.
102-
103-
See Also
104-
--------
105-
extract_regions: Extract new regionalized dataset.
106-
build_extraction_template: Quickly build extraction template from functions for regionalization.
107-
RegionExtractor: Protocol for regionalizing, or extracting regions from a dataset.
108-
"""
109-
110-
def __init__(self, weights: xr.Dataset):
111-
target_variables = ("lat", "lon", "weight", "region")
112-
missing_variables = [v for v in target_variables if v not in weights.variables]
113-
if missing_variables:
114-
raise ValueError(
115-
f"input weights is missing required {missing_variables} variable(s)"
116-
)
117-
self._data = weights
118-
119-
def extract_regions(self, ds: xr.Dataset) -> xr.Dataset:
120-
"""
121-
Regionalize input gridded data after multiplying 'ds' by weights and summing the product within each region.
122-
123-
'ds' must have "lat", "lon" coordinates exactly matching "lat", "lon" in weights.
124-
"""
125-
# TODO: See how this errors in different common scenarios. What happens on the
126-
# unhappy path?
127-
region_sel = ds.sel(lat=self._data["lat"], lon=self._data["lon"])
128-
out = (region_sel * self._data["weight"]).groupby(self._data["region"]).sum()
129-
# TODO: Maybe drop lat/lon and set 'region' as dim/coord? I feel like we can do
130-
# this because we're asking weights to strictly match input's lat/lon. Maybe
131-
# make this a req of segment weights we're reading in?
132-
return out
133-
134-
135-
def extract_regions(
136-
ds: xr.Dataset, *, template: ExtractionTemplate, regions: RegionExtractor
137-
) -> xr.Dataset:
138-
"""
139-
Use transformations in 'template' to extract 'regions' from gridded dataset, 'ds', returning a regionalized dataset
140-
141-
This function specifically does not just regionalize through zonal aggregation. It uses 'template' to apply pre/post regionalization transformations to create new datasets and variables.
142-
143-
See Also
144-
--------
145-
build_extraction_template: Quickly build extraction workflow from functions for regionalization.
146-
"""
147-
return template.post_extract(regions.extract_regions(template.pre_extract(ds)))
148-
149-
150-
class ProjectionTemplate(Protocol):
151-
"""
152-
Template for projecting a model with pre and post processing.
153-
154-
See Also
155-
--------
156-
build_projection_template: Build a projection template from simple functions.
157-
"""
158-
159-
def pre_project(self, d: xr.Dataset) -> xr.Dataset:
160-
"""
161-
Pre-process a dataset before projection
162-
"""
163-
...
164-
165-
def project(self, d: xr.Dataset) -> xr.Dataset:
166-
"""
167-
Create a projection from a dataset
168-
"""
169-
...
170-
171-
def post_project(self, d: xr.Dataset) -> xr.Dataset:
172-
"""
173-
Process a projected dataset
174-
"""
175-
...
176-
177-
178-
# This dataclass is a quick and simple way to get a concrete instance of the protocol.
179-
@dataclass(frozen=True)
180-
class _SimpleProjectionTemplate(ProjectionTemplate):
181-
pre_project: Callable[[xr.Dataset], xr.Dataset]
182-
project: Callable[[xr.Dataset], xr.Dataset]
183-
post_project: Callable[[xr.Dataset], xr.Dataset]
184-
185-
186-
def build_projection_template(
187-
*,
188-
pre: Callable[[xr.Dataset], xr.Dataset],
189-
project: Callable[[xr.Dataset], xr.Dataset],
190-
post: Callable[[xr.Dataset], xr.Dataset],
191-
) -> ProjectionTemplate:
192-
"""
193-
Use simple functions to quickly build a model to project effects, impacts and/or damages.
194-
195-
This function is a quick and simple way to build a ProjectionTemplate from three simple functions.
196-
197-
See Also
198-
--------
199-
project: Apply a projection template to a dataset.
200-
ProjectionTemplate: Technical ProjectionTemplate protocol.
201-
"""
202-
return _SimpleProjectionTemplate(
203-
pre_project=pre,
204-
project=project,
205-
post_project=post,
206-
)
207-
208-
209-
def project(d: xr.Dataset, *, model: ProjectionTemplate) -> xr.Dataset:
210-
"""
211-
Project a dataset of predictors, 'd', with 'model' to return a projected dataset
212-
213-
See Also
214-
--------
215-
build_projection_template: Build a projection template from simple functions.
216-
ProjectionTemplate: Technical ProjectionTemplate protocol.
217-
"""
218-
preprocessed = model.pre_project(d)
219-
projected = model.project(preprocessed)
220-
postprocessed = model.post_project(projected)
221-
222-
return postprocessed

src/isku/extract.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
from collections.abc import Callable
2+
from dataclasses import dataclass
3+
from typing import Protocol
4+
5+
import xarray as xr
6+
7+
__all__ = [
8+
"ExtractionTemplate",
9+
"GridWeightingRegions",
10+
"RegionExtractor",
11+
"build_extraction_template",
12+
"extract_regions",
13+
]
14+
15+
16+
class ExtractionTemplate(Protocol):
17+
"""
18+
Template for pre and post region extraction transformation
19+
20+
See Also
21+
--------
22+
build_extraction_template: Quickly build extraction template from functions for regionalization with pre/post transformations.
23+
extract_regions: Apply a template to extract a new regionalized dataset from gridded data.
24+
RegionExtractor: Protocol for regionalizing, or extracting regions from a dataset.
25+
"""
26+
27+
def pre_extract(self, ds: xr.Dataset) -> xr.Dataset:
28+
"""
29+
Transform dataset before region extraction
30+
"""
31+
...
32+
33+
def post_extract(self, ds: xr.Dataset) -> xr.Dataset:
34+
"""
35+
Transform dataset after region extraction
36+
"""
37+
...
38+
39+
40+
class RegionExtractor(Protocol):
41+
"""
42+
Protocol for extracting regions from gridded data
43+
44+
See Also
45+
--------
46+
extract_regions: Apply a template to extract a new regionalized dataset from gridded data with pre/post transformations.
47+
ExtractionTemplate: Technical protocol for a workflow with pre/post regionalization transformations.
48+
"""
49+
50+
def extract_regions(self, ds: xr.Dataset) -> xr.Dataset:
51+
"""
52+
Extract and aggregate gridded dataset points into regionalized dataset
53+
"""
54+
...
55+
56+
57+
# This dataclass is a quick and simple way to get a concrete instance of the protocol.
58+
@dataclass(frozen=True)
59+
class _SimpleExtractionTemplate(ExtractionTemplate):
60+
pre_extract: Callable[[xr.Dataset], xr.Dataset]
61+
post_extract: Callable[[xr.Dataset], xr.Dataset]
62+
63+
64+
def build_extraction_template(
65+
*, pre: Callable[[xr.Dataset], xr.Dataset], post: Callable[[xr.Dataset], xr.Dataset]
66+
) -> ExtractionTemplate:
67+
"""
68+
Build a template of transformation steps applied to input gridded data, pre/post regionalization, to create a derived variable as output
69+
70+
This function is a quick and simple way to build an ExtractionTemplate from two simple functions.
71+
72+
These steps should be general. They may contain logic for sanity checks
73+
on inputs and outputs, calculating derived variables and climate indices,
74+
adding or checking metadata or units. Avoid including logic for cleaning,
75+
or harmonizing input data, especially if it is specific to a single
76+
project's use case. Generally avoid using a single strategy to output
77+
multiple unrelated variables.
78+
79+
See Also
80+
--------
81+
extract_regions: Apply a template to extract a new regionalized dataset from gridded data.
82+
build_extraction_template: Quickly build extraction template from functions for regionalization.
83+
ExtractionTemplate: The underlying protocol for a workflow that extracts a regionalized dataset.
84+
"""
85+
return _SimpleExtractionTemplate(pre_extract=pre, post_extract=post)
86+
87+
88+
# Use class for segment weights because we're making assumptions/enforcements about the weight data's content and interactions...
89+
class GridWeightingRegions(RegionExtractor):
90+
"""
91+
Regions that can be extracted from regularly-gridded data after weighting grid points
92+
93+
'weights' dataset must have "lat", "lon", "weight", and "region" variables.
94+
95+
Raises
96+
------
97+
ValueError
98+
If 'weights' is missing "lat", "lon", "weight" or "region" variables.
99+
100+
See Also
101+
--------
102+
extract_regions: Extract new regionalized dataset.
103+
build_extraction_template: Quickly build extraction template from functions for regionalization.
104+
RegionExtractor: Protocol for regionalizing, or extracting regions from a dataset.
105+
"""
106+
107+
def __init__(self, weights: xr.Dataset):
108+
target_variables = ("lat", "lon", "weight", "region")
109+
missing_variables = [v for v in target_variables if v not in weights.variables]
110+
if missing_variables:
111+
raise ValueError(
112+
f"input weights is missing required {missing_variables} variable(s)"
113+
)
114+
self._data = weights
115+
116+
def extract_regions(self, ds: xr.Dataset) -> xr.Dataset:
117+
"""
118+
Regionalize input gridded data after multiplying 'ds' by weights and summing the product within each region.
119+
120+
'ds' must have "lat", "lon" coordinates exactly matching "lat", "lon" in weights.
121+
"""
122+
# TODO: See how this errors in different common scenarios. What happens on the
123+
# unhappy path?
124+
region_sel = ds.sel(lat=self._data["lat"], lon=self._data["lon"])
125+
out = (region_sel * self._data["weight"]).groupby(self._data["region"]).sum()
126+
# TODO: Maybe drop lat/lon and set 'region' as dim/coord? I feel like we can do
127+
# this because we're asking weights to strictly match input's lat/lon. Maybe
128+
# make this a req of segment weights we're reading in?
129+
return out
130+
131+
132+
def extract_regions(
133+
ds: xr.Dataset, *, template: ExtractionTemplate, regions: RegionExtractor
134+
) -> xr.Dataset:
135+
"""
136+
Use transformations in 'template' to extract 'regions' from gridded dataset, 'ds', returning a regionalized dataset
137+
138+
This function specifically does not just regionalize through zonal aggregation. It uses 'template' to apply pre/post regionalization transformations to create new datasets and variables.
139+
140+
See Also
141+
--------
142+
build_extraction_template: Quickly build extraction workflow from functions for regionalization.
143+
"""
144+
return template.post_extract(regions.extract_regions(template.pre_extract(ds)))

0 commit comments

Comments
 (0)