Skip to content

Commit e12b2c4

Browse files
committed
refactor(risktraj): lays better foundations for MeasureAppraiser
1 parent b75ab91 commit e12b2c4

2 files changed

Lines changed: 120 additions & 44 deletions

File tree

climada/trajectories/risk_trajectory.py

Lines changed: 57 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ def _reset_metrics(self):
105105
self._risk_components_metrics = None
106106
self._aai_per_group_metrics = None
107107
self._all_risk_metrics = None
108-
self._metrics_up_to_date = False
109108

110109
@property
111110
def default_rp(self):
@@ -249,48 +248,67 @@ def _generic_metrics(
249248

250249
return getattr(self, attr_name)
251250

252-
def aai_metrics(self, npv=True):
253-
return self._generic_metrics(
254-
npv=npv, metric_name="aai", metric_meth="calc_aai_metric"
251+
def _compute_metrics(
    self, metric_name, metric_meth, total=False, npv=True, *args, **kwargs
):
    """Compute one family of metrics, optionally aggregated per period.

    Delegates to ``self._generic_metrics`` and, when ``total`` is truthy,
    hands the result to ``self._per_period_risk`` before returning it.

    Parameters
    ----------
    metric_name
        Forwarded to ``_generic_metrics`` as ``metric_name``
        (e.g. ``"aai"``).
    metric_meth
        Forwarded to ``_generic_metrics`` as ``metric_meth``
        (e.g. ``"calc_aai_metric"``).
    total : bool, optional
        If truthy, return ``self._per_period_risk(result)`` instead of the
        raw result. Default is False.
    npv : bool, optional
        Forwarded to ``_generic_metrics`` as ``npv``. Default is True.
    *args, **kwargs
        Forwarded unchanged to ``_generic_metrics``.

    Returns
    -------
    Whatever ``_generic_metrics`` returns, or the output of
    ``_per_period_risk`` applied to it when ``total`` is truthy.
    """
    metrics = self._generic_metrics(
        *args,
        npv=npv,
        metric_name=metric_name,
        metric_meth=metric_meth,
        **kwargs,
    )
    return self._per_period_risk(metrics) if total else metrics
256261

257-
def return_periods_metrics(self, return_periods=None, npv=True):
262+
def aai_metrics(self, total=False, npv=True, *args, **kwargs):
    """Return the ``"aai"`` metrics via ``self._compute_metrics``.

    Parameters
    ----------
    total : bool, optional
        Forwarded to ``_compute_metrics`` as ``total``. Default is False.
    npv : bool, optional
        Forwarded to ``_compute_metrics`` as ``npv``. Default is True.
    *args, **kwargs
        Forwarded unchanged to ``_compute_metrics``.

    Returns
    -------
    The value returned by ``_compute_metrics`` for ``metric_name="aai"``
    and ``metric_meth="calc_aai_metric"``.
    """
    return self._compute_metrics(
        *args,
        metric_name="aai",
        metric_meth="calc_aai_metric",
        total=total,
        npv=npv,
        **kwargs,
    )
271+
272+
def return_periods_metrics(
    self, total=False, return_periods=None, npv=True, *args, **kwargs
):
    """Return the ``"return_periods"`` metrics via ``self._compute_metrics``.

    Parameters
    ----------
    total : bool, optional
        Forwarded to ``_compute_metrics`` as ``total``. Default is False.
    return_periods : optional
        Return periods to evaluate. When falsy (``None`` or empty),
        ``self.default_rp`` is used instead.
    npv : bool, optional
        Forwarded to ``_compute_metrics`` as ``npv``. Default is True.
    *args, **kwargs
        Forwarded unchanged to ``_compute_metrics``.

    Returns
    -------
    The value returned by ``_compute_metrics`` for
    ``metric_name="return_periods"`` and
    ``metric_meth="calc_return_periods_metric"``.
    """
    return_periods = return_periods if return_periods else self.default_rp
    return self._compute_metrics(
        *args,
        # ``total`` must be forwarded explicitly; otherwise the caller's
        # request for per-period totals is silently dropped.
        total=total,
        npv=npv,
        metric_name="return_periods",
        metric_meth="calc_return_periods_metric",
        return_periods=return_periods,
        **kwargs,
    )
265284

266-
def aai_per_group_metrics(self, npv=True):
267-
return self._generic_metrics(
285+
def aai_per_group_metrics(self, npv=True, *args, **kwargs):
    """Return the ``"aai_per_group"`` metrics via ``self._compute_metrics``.

    Parameters
    ----------
    npv : bool, optional
        Forwarded to ``_compute_metrics`` as ``npv``. Default is True.
    *args, **kwargs
        Forwarded unchanged to ``_compute_metrics``.

    Returns
    -------
    The value returned by ``_compute_metrics`` for
    ``metric_name="aai_per_group"`` and
    ``metric_meth="calc_aai_per_group_metric"``.
    """
    return self._compute_metrics(
        *args,
        metric_name="aai_per_group",
        metric_meth="calc_aai_per_group_metric",
        npv=npv,
        **kwargs,
    )
272293

273-
def risk_components_metrics(self, npv=True):
274-
return self._generic_metrics(
294+
def risk_components_metrics(self, npv=True, *args, **kwargs):
    """Return the ``"risk_components"`` metrics via ``self._compute_metrics``.

    Parameters
    ----------
    npv : bool, optional
        Forwarded to ``_compute_metrics`` as ``npv``. Default is True.
    *args, **kwargs
        Forwarded unchanged to ``_compute_metrics``.

    Returns
    -------
    The value returned by ``_compute_metrics`` for
    ``metric_name="risk_components"`` and
    ``metric_meth="calc_risk_components_metric"``.
    """
    return self._compute_metrics(
        *args,
        metric_name="risk_components",
        metric_meth="calc_risk_components_metric",
        npv=npv,
        **kwargs,
    )
279302

280303
def all_risk_metrics(
281-
self, return_periods=[50, 100, 500], npv=True
304+
self, return_periods=[50, 100, 500], npv=True, *args, **kwargs
282305
) -> pd.DataFrame | pd.Series:
283-
if not self._metrics_up_to_date or self._all_risk_metrics is None:
284-
aai = self.aai_metrics(npv)
285-
rp = self.return_periods_metrics(return_periods, npv)
286-
aai_per_group = self.aai_per_group_metrics(npv)
287-
risk_components = self.risk_components_metrics(npv)
288-
self._all_risk_metrics = pd.concat(
289-
[aai, rp, aai_per_group, risk_components]
290-
)
291-
self._metrics_up_to_date = True
292306

293-
return self._all_risk_metrics
307+
aai = self.aai_metrics(npv, *args, **kwargs)
308+
rp = self.return_periods_metrics(return_periods, npv, *args, **kwargs)
309+
aai_per_group = self.aai_per_group_metrics(npv, *args, **kwargs)
310+
risk_components = self.risk_components_metrics(npv, *args, **kwargs)
311+
return pd.concat([aai, rp, aai_per_group, risk_components])
294312

295313
@staticmethod
296314
def _get_risk_periods(
@@ -321,7 +339,7 @@ def identify_continuous_periods(group, time_unit):
321339
return group
322340

323341
grouper = cls._grouper
324-
if "group" in df.columns:
342+
if "group" in df.columns and "group" not in grouper:
325343
grouper = ["group"] + grouper
326344

327345
df_sorted = df.sort_values(by=cls._grouper + ["date"])
@@ -330,14 +348,20 @@ def identify_continuous_periods(group, time_unit):
330348
identify_continuous_periods, time_unit
331349
)
332350

351+
if isinstance(colname, str):
352+
colname = [colname]
353+
354+
agg_dict = {
355+
"start_date": pd.NamedAgg(column="date", aggfunc="min"),
356+
"end_date": pd.NamedAgg(column="date", aggfunc="max"),
357+
}
358+
for col in colname:
359+
agg_dict[col] = pd.NamedAgg(column=col, aggfunc="sum")
333360
# Group by the identified periods and calculate start and end dates
361+
print(df_periods)
334362
df_periods = (
335363
df_periods.groupby(grouper + ["period_id"], dropna=False)
336-
.agg(
337-
start_date=pd.NamedAgg(column="date", aggfunc="min"),
338-
end_date=pd.NamedAgg(column="date", aggfunc="max"),
339-
total=pd.NamedAgg(column=colname, aggfunc="sum"),
340-
)
364+
.agg(**agg_dict)
341365
.reset_index()
342366
)
343367

@@ -346,28 +370,22 @@ def identify_continuous_periods(group, time_unit):
346370
+ " to "
347371
+ df_periods["end_date"].astype(str)
348372
)
349-
df_periods = df_periods.rename(columns={"total": f"{colname}"})
373+
# df_periods = df_periods.rename(columns={"total": f"{colname}"})
350374
df_periods = df_periods.drop(["period_id", "start_date", "end_date"], axis=1)
351375
return df_periods[
352376
["period"] + [col for col in df_periods.columns if col != "period"]
353377
]
354378

355379
@property
356-
def per_date_risk_metrics(self) -> pd.DataFrame | pd.Series:
380+
def per_date_risk_metrics(self, *args, **kwargs) -> pd.DataFrame | pd.Series:
357381
"""Returns a tidy dataframe of the risk metrics for all dates."""
358-
return self._prepare_risk_metrics(total=False, npv=True)
382+
return self.all_risk_metrics(*args, **kwargs)
359383

360384
@property
361-
def total_risk_metrics(self) -> pd.DataFrame | pd.Series:
385+
def total_risk_metrics(self, *args, **kwargs) -> pd.DataFrame | pd.Series:
362386
"""Returns a tidy dataframe of the risk metrics with the total for each different period."""
363-
return self._prepare_risk_metrics(total=True, npv=True)
364-
365-
def _prepare_risk_metrics(self, total=False, npv=True) -> pd.DataFrame | pd.Series:
366-
df = self.all_risk_metrics(npv=npv)
367-
if total:
368-
return self._per_period_risk(df)
369-
370-
return df
387+
df = self.all_risk_metrics(*args, **kwargs)
388+
return self._per_period_risk(df)
371389

372390
def _calc_waterfall_plot_data(self, start_date=None, end_date=None, npv=True):
373391
start_date = self.start_date if start_date is None else start_date

climada/trajectories/riskperiod.py

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ def lazy_property(method):
4646
@property
4747
def _lazy(self):
4848
if getattr(self, attr_name) is None:
49+
meas_n = self.measure.name if self.measure else "no_measure"
50+
LOGGER.debug(
51+
f"Computing {method.__name__} for {self._snapshot0.date}-{self._snapshot1.date} with {meas_n}."
52+
)
4953
setattr(self, attr_name, method(self))
5054
return getattr(self, attr_name)
5155

@@ -67,13 +71,14 @@ def __init__(
6771
risk_transf_cover: float | None = None,
6872
calc_residual: bool = False,
6973
):
74+
LOGGER.info("Instantiating new CalcRiskPeriod.")
7075
self._snapshot0 = snapshot0
7176
self._snapshot1 = snapshot1
72-
self.date_idx = pd.date_range(
73-
snapshot0.date,
74-
snapshot1.date,
77+
self.date_idx = CalcRiskPeriod._set_date_idx(
78+
date1=snapshot0.date,
79+
date2=snapshot1.date,
7580
periods=time_points,
76-
freq=interval_freq, # type: ignore
81+
freq=interval_freq,
7782
name="date",
7883
)
7984
self.interpolation_strategy = interpolation_strategy or LinearInterpolation()
@@ -89,13 +94,66 @@ def __init__(
8994
self._group_id_E1 = self.snapshot1.exposure.gdf["group_id"].values
9095

9196
def _reset_impact_data(self):
    """Set every stored impact-related attribute back to ``None``.

    NOTE(review): these look like the backing fields of the lazily
    computed properties in this module; confirm against ``lazy_property``.
    """
    for stored_attr in (
        "_impacts_arrays",
        "_imp_mats_H0",
        "_imp_mats_H1",
        "_imp_mats_E0",
        "_imp_mats_E1",
        "_per_date_eai_H0",
        "_per_date_eai_H1",
        "_per_date_aai_H0",
        "_per_date_aai_H1",
        "_per_date_return_periods_H0",
        "_per_date_return_periods_H1",
    ):
        setattr(self, stored_attr, None)
98103

104+
@staticmethod
105+
def _set_date_idx(
106+
date1: str | pd.Timestamp,
107+
date2: str | pd.Timestamp,
108+
periods: int | None = None,
109+
freq: str | None = None,
110+
name: str | None = None,
111+
) -> pd.DatetimeIndex:
112+
"""
113+
Generate a date range index based on the provided parameters.
114+
115+
Parameters
116+
----------
117+
date1 : str or pd.Timestamp
118+
The start date of the date range.
119+
date2 : str or pd.Timestamp
120+
The end date of the date range.
121+
periods : int, optional
122+
Number of date points to generate. If None, `freq` must be provided.
123+
freq : str, optional
124+
Frequency string for the date range. If None, `periods` must be provided.
125+
name : str, optional
126+
Name of the resulting date range index.
127+
128+
Returns
129+
-------
130+
pd.DatetimeIndex
131+
A DatetimeIndex representing the date range.
132+
133+
Raises
134+
------
135+
ValueError
136+
If the number of periods and frequency given to date_range are inconsistent.
137+
"""
138+
if periods is not None and freq is not None:
139+
points = None
140+
else:
141+
points = periods
142+
143+
ret = pd.date_range(
144+
date1,
145+
date2,
146+
periods=points,
147+
freq=freq, # type: ignore
148+
name=name,
149+
)
150+
if periods is not None and len(ret) != periods:
151+
raise ValueError(
152+
"Number of periods and frequency given to date_range are inconsistant"
153+
)
154+
155+
return ret
156+
99157
@property
100158
def snapshot0(self):
101159
return self._snapshot0

0 commit comments

Comments
 (0)