Skip to content

Commit 59ef035

Browse files
committed
Overlay model estimates on POI change plots.
1 parent 3593d2f commit 59ef035

1 file changed

Lines changed: 67 additions & 1 deletion

File tree

scripts/osm_data/data_viz.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,12 @@
4444

4545
SAVE_DIR = config.get_dir_path("osm_data")
4646
VIZ_DIR = SAVE_DIR / "viz"
47-
OSM_KEYS = config.get("download", "download_keys")
47+
OSM_KEYS = config.get("download", "osm", "filter_keys")
4848
TAG_KEY = config.get("osm_data", "tag_key")
4949
END_DATE = pd.Timestamp(config.get("download", "osm", "end_date"), tz='UTC')
50+
MODEL_BASE = config.get_dir_path("model_output").parent
51+
MODEL_STUB = config.get("osm_data", "apply_model", "model_stub")
52+
ADJ_FACTOR = 1.0
5053

5154
max_days = 365 * 10
5255
VIZ_DIR.mkdir(parents=True, exist_ok=True)
@@ -79,12 +82,40 @@ def fig_save(
7982
**kwargs
8083
)
8184

85+
def get_preds_dict(model_stub: str | None, adj_factor: float = 1.0) -> dict[str, pd.DataFrame]:
86+
"""
87+
Load model predictions from the model output directory.
88+
"""
89+
model_output_dir = config.get_dir_path("model_output")
90+
if model_stub is None:
91+
return dict()
92+
def get_preds_df(model_stub: str, subset: str | None = None) -> Path:
93+
if subset is None:
94+
preds_fp = MODEL_BASE / f"{model_stub}_constant/predictions.csv"
95+
else:
96+
preds_fp = MODEL_BASE / f"{model_stub}_by_{subset}/predictions.csv"
97+
if not preds_fp.exists():
98+
return None
99+
return pd.read_csv(preds_fp).assign(
100+
year = pd.col('t2'),
101+
conf_mean = (1.0 - pd.col('p_mean')) * adj_factor,
102+
conf_lower = (1.0 - pd.col('p_upper')) * adj_factor,
103+
conf_upper = (1.0 - pd.col('p_lower')) * adj_factor,
104+
)
105+
preds = dict()
106+
preds["constant"] = get_preds_df(model_stub)
107+
for subset in OSM_KEYS:
108+
preds[subset] = get_preds_df(model_stub, subset)
109+
return preds
110+
82111

83112
# ----------------------------------------------------------------------------------------
84113
# Main workflow
85114
# ----------------------------------------------------------------------------------------
86115

87116
if __name__ == "__main__":
117+
# Read model predictions
118+
preds = get_preds_dict(MODEL_STUB, adj_factor = ADJ_FACTOR)
88119
# Read observations
89120
# Drop the first observation for each POI (when the POI was first added) - the last
90121
# observation timestamp will be missing for these rows
@@ -127,6 +158,7 @@ def fig_save(
127158
)
128159
# Format changes
129160
to_plot_df = pd.concat([changed_tags, unchanged_tags])
161+
to_plot_df['final_obs'] = np.inf
130162
# Create a plot for all tags
131163
fig = change_plot_create(
132164
observations = to_plot_df,
@@ -140,6 +172,20 @@ def fig_save(
140172
)
141173
fig_save(fig, stub = f"osm_changes_{TAG_KEY}_all")
142174

175+
if 'constant' in preds:
176+
fig = change_plot_create(
177+
observations = to_plot_df,
178+
predictions = preds['constant'],
179+
no_change_col = 'no_change',
180+
change_col = 'change',
181+
final_observation_col = 'final_obs',
182+
day_range = max_days,
183+
title = f"Stability of the `{TAG_KEY}` tag over time",
184+
x_label = "Years since tag",
185+
y_label = "Proportion remaining unchanged",
186+
)
187+
fig_save(fig, stub = f"osm_changes_{TAG_KEY}_all_preds")
188+
143189
# Create multi-panel plots for the top tags in each OSM category
144190
TOP_N_TYPES = config.get("osm_data", "top_n_types")
145191
for subtype in OSM_KEYS:
@@ -159,3 +205,23 @@ def fig_save(
159205
day_range = max_days,
160206
)
161207
fig_save(fig = fig, stub = f"osm_changes_{TAG_KEY}_{subtype}")
208+
209+
if subtype in preds:
210+
top_n_tags = to_plot_df[subtype].value_counts().head(TOP_N_TYPES).index
211+
pred_groups = preds[subtype]['group_name'].unique().tolist()
212+
keep_preds = list(set(top_n_tags) & set(pred_groups))
213+
for pred_tag in keep_preds:
214+
print(f"Plotting {subtype}={pred_tag}")
215+
fig = change_plot_create(
216+
observations = to_plot_df.query(f"{subtype} == @pred_tag"),
217+
predictions = preds[subtype].query(f"group_name == @pred_tag"),
218+
no_change_col = 'no_change',
219+
change_col = 'change',
220+
final_observation_col = 'final_obs',
221+
day_range = max_days,
222+
title = f"Stability of the `{TAG_KEY}` tag over time: {pred_tag}",
223+
x_label = "Years since tag",
224+
y_label = "Proportion remaining unchanged",
225+
)
226+
fig_save(fig, stub = f"osm_changes_{TAG_KEY}_{subtype}_preds_{pred_tag}")
227+

0 commit comments

Comments
 (0)