Skip to content

Commit 4da1d3d

Browse files
committed
Update test script to use new ModelFitter class.
1 parent d5a3fdf commit 4da1d3d

1 file changed

Lines changed: 62 additions & 100 deletions

File tree

exploratory/models/pytorch_simple.py

Lines changed: 62 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -10,116 +10,78 @@
1010
import numpy as np
1111
import pandas as pd
1212
import torch
13-
import torchmin
14-
import plotnine as gg
1513
from pathlib import Path
1614

17-
from openpois.models.base_model import BaseModel, EventRate
15+
from openpois.models.model_fitter import ModelFitter
16+
from openpois.models.setup import pytorch_setup, prepare_data_for_model
1817

1918
# Globals
2019
DATA_VERSION = "20260129"
2120
MODEL_VERSION = "20260212"
2221
DATA_DIR = Path("~/data/openpois").expanduser() / DATA_VERSION
2322
MODEL_DIR = Path("~/data/openpois").expanduser() / MODEL_VERSION
2423
TAG_KEY = "name"
25-
GROUP_KEY = "leisure"
24+
GROUP_KEY = None
2625
GROUP_VALUES = ["park"]
27-
28-
# Load data
29-
observations_df = pd.read_csv(DATA_DIR / f"osm_observations_{TAG_KEY}.csv")
30-
31-
# Ensure model directory exists
32-
MODEL_DIR.mkdir(parents = True, exist_ok = True)
33-
model_suffix = f"_simple_{TAG_KEY}"
34-
if GROUP_KEY is not None:
35-
model_suffix += f"_{GROUP_KEY}"
36-
if GROUP_VALUES is not None:
37-
model_suffix += f"_{'-'.join(GROUP_VALUES)}"
38-
# Device setup
39-
DTYPE = torch.float64
40-
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
41-
print("Running on", DEVICE)
42-
torch.set_default_device(DEVICE)
43-
44-
45-
## Input data preparation --------------------------------------------------------------->
46-
47-
# If a group key was set, subset to those observations
48-
if GROUP_KEY is not None:
49-
keep_ids = observations_df.dropna(subset = [GROUP_KEY]).id.unique().tolist()
50-
observations_df = observations_df.query('id in @keep_ids')
51-
# If a group values were set, subset to those observations
52-
if GROUP_VALUES is not None:
53-
keep_ids = observations_df.loc[
54-
observations_df[GROUP_KEY].isin(GROUP_VALUES), 'id'
55-
].unique().tolist()
56-
observations_df = observations_df.query('id in @keep_ids')
57-
58-
timestamp_cols = ['obs_timestamp', 'last_obs_timestamp', 'last_tag_timestamp']
59-
for timestamp_col in timestamp_cols:
60-
observations_df[timestamp_col] = pd.to_datetime(observations_df[timestamp_col])
61-
observations_df = observations_df.assign(
62-
tag_days = (pd.col('obs_timestamp') - pd.col('last_tag_timestamp')).dt.days,
63-
tag_years = pd.col('tag_days') / 365
64-
)
65-
obs_sub = (observations_df
66-
.dropna(subset = ['tag_years', 'changed'])
67-
.query('tag_years > 1e-6')
68-
)
69-
70-
71-
## Define model ------------------------------------------------------------------------->
72-
73-
# Only parameters need requires_grad=True; data tensors must not, or memory explodes
74-
y = torch.tensor(obs_sub['changed'].values, dtype=DTYPE, device=DEVICE)
75-
X = torch.zeros(obs_sub.shape[0], 1, dtype=DTYPE, device=DEVICE)
76-
t1 = torch.zeros(obs_sub.shape[0], 1, dtype=DTYPE, device=DEVICE)
77-
t2 = torch.tensor(obs_sub[['tag_years']].values, dtype=DTYPE, device=DEVICE)
78-
# Estimand: (log) lambda, log of the rate parameter
79-
starting_params = torch.tensor(
80-
np.array([0.0]),
81-
dtype = DTYPE,
82-
device = DEVICE,
83-
requires_grad = True,
84-
)
85-
86-
def simple_model_fun(params, covariates = None):
87-
return torch.exp(params)
88-
89-
simple_model = BaseModel(
90-
event_rate = EventRate(
91-
type = 'constant',
92-
fun = simple_model_fun,
93-
),
94-
params = starting_params,
95-
covariates = X,
96-
target = y,
97-
t1 = t1,
98-
t2 = t2,
99-
verbose = True
100-
)
101-
simple_model.fit()
102-
103-
m1 = simple_model.get_results().assign(parameter = 'log_lambda')
104-
m2 = (
105-
m1
106-
.copy()
107-
.assign(
108-
parameter = 'lambda',
109-
estimate = np.exp(pd.col('estimate')),
110-
std_err = pd.col('estimate') * pd.col('std_err')
26+
N_DRAWS = 250
27+
SAVE_FULL_MODEL = False
28+
29+
30+
if __name__ == "__main__":
31+
# Ensure model directory exists
32+
MODEL_DIR.mkdir(parents = True, exist_ok = True)
33+
model_suffix = f"_simple_{TAG_KEY}"
34+
if GROUP_KEY is not None:
35+
model_suffix += f"_{GROUP_KEY}"
36+
if GROUP_VALUES is not None:
37+
model_suffix += f"_{'-'.join(GROUP_VALUES)}"
38+
39+
# Device setup
40+
dtype = torch.float64
41+
device = pytorch_setup()
42+
def tensor(x: np.ndarray, **kwargs) -> torch.Tensor:
43+
"""Convenience function to create a tensor with default dtype and device."""
44+
return torch.tensor(x, dtype = dtype, device = device, **kwargs)
45+
46+
# Data preparation
47+
observations_df = pd.read_csv(DATA_DIR / f"osm_observations_{TAG_KEY}.csv")
48+
obs_sub = prepare_data_for_model(
49+
data = observations_df,
50+
group_key = GROUP_KEY,
51+
group_values = GROUP_VALUES,
52+
t1_col = 'last_tag_timestamp',
53+
t2_col = 'obs_timestamp',
11154
)
112-
)
113-
model_results = pd.concat([m1, m2])
11455

115-
predictions = simple_model.predict(
116-
t2 = torch.tensor(np.arange(11), dtype = DTYPE, device = DEVICE),
117-
covariates = None,
118-
).assign(units = 'years')
119-
predictions.to_csv(MODEL_DIR / f"predictions{model_suffix}.csv", index = False)
120-
121-
122-
## Run model and save results ----------------------------------------------------------->
56+
# Define model
57+
# Only parameters need requires_grad = True
58+
y = tensor(obs_sub['changed'].values)
59+
t1 = torch.zeros(obs_sub.shape[0], 1, dtype=dtype, device=device)
60+
t2 = tensor(obs_sub[['tag_years']].values)
61+
# Estimand: (log) lambda, log of the rate parameter
62+
def simple_model_fun(params: torch.Tensor) -> torch.Tensor:
63+
return torch.exp(params)
64+
starting_params = tensor(np.array([0.0]), requires_grad = True)
65+
66+
simple_model = ModelFitter(
67+
event_rate_type = 'constant',
68+
event_rate_fun = simple_model_fun,
69+
params = starting_params,
70+
target = y,
71+
data = {},
72+
t1 = t1,
73+
t2 = t2,
74+
verbose = True
75+
)
12376

124-
model_results.to_csv(MODEL_DIR / f"fitted_params{model_suffix}.csv", index = False)
125-
torch.save(simple_model, MODEL_DIR / f"fitted_params{model_suffix}.pt")
77+
# Run the model and get predictions
78+
simple_model.fit()
79+
simple_model.generate_parameter_draws(n_draws = N_DRAWS)
80+
fitted_params = simple_model.get_parameter_table().assign(parameter = 'log_lambda')
81+
predictions = simple_model.predict(t2 = tensor(np.arange(11))).assign(units = 'years')
82+
83+
# Save results
84+
fitted_params.to_csv(MODEL_DIR / f"fitted_params{model_suffix}.csv", index = False)
85+
predictions.to_csv(MODEL_DIR / f"predictions{model_suffix}.csv", index = False)
86+
if SAVE_FULL_MODEL:
87+
torch.save(simple_model, MODEL_DIR / f"fitted_params{model_suffix}.pt")

0 commit comments

Comments
 (0)