Skip to content

Commit 3593d2f

Browse files
committed
Allow predictions table to be overlaid on empirical POI change plots.
1 parent 80d223e commit 3593d2f

1 file changed

Lines changed: 23 additions & 2 deletions

File tree

src/openpois/osm/change_plots.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,12 @@ def change_plot_reshape_data(
8484

8585
def change_plot_create(
8686
observations: pd.DataFrame,
87+
predictions: pd.DataFrame | None = None,
8788
no_change_col: str = 'no_change',
8889
change_col: str = 'change',
8990
final_observation_col: str = 'final_obs',
90-
title: str = None,
91-
subtitle: str = None,
91+
title: str | None = None,
92+
subtitle: str | None = None,
9293
x_label: str = '',
9394
y_label: str = '',
9495
day_range: int = 365 * 10,
@@ -99,6 +100,7 @@ def change_plot_create(
99100
Args:
100101
observations: DataFrame with observations. Each row is an iteration of a
101102
tag, with the three columns described below.
103+
predictions: DataFrame with modeled predictions.
102104
no_change_col: Column name for the days elapsed from when the tag was added to
103105
when it was last confirmed (observed unchanged).
104106
change_col: Column name for the days elapsed from when the tag was added to when
@@ -122,6 +124,11 @@ def change_plot_create(
122124
final_observation_col = final_observation_col,
123125
day_range = day_range,
124126
)
127+
if predictions is not None:
128+
if subtitle is not None:
129+
subtitle = f"{subtitle}\nModeled predictions in red"
130+
else:
131+
subtitle = "Modeled predictions in red"
125132
fig = (
126133
gg.ggplot(
127134
data = reshaped,
@@ -150,6 +157,20 @@ def change_plot_create(
150157
) +
151158
gg.theme_bw()
152159
)
160+
if predictions is not None:
161+
p_renamed = predictions.assign(
162+
year = pd.col('t2'),
163+
y = pd.col('conf_mean'),
164+
ymin = pd.col('conf_lower'),
165+
ymax = pd.col('conf_upper'),
166+
)
167+
fig = fig + gg.geom_ribbon(
168+
data = p_renamed,
169+
fill = 'red', alpha = 0.25, linetype = 'dashed', color = '#444444'
170+
) + gg.geom_line(
171+
data = p_renamed, color = '#444444',
172+
mapping = gg.aes(x = 'year', y = 'y')
173+
)
153174
return fig
154175

155176

0 commit comments

Comments
 (0)