Skip to content

Commit 80e0c47

Browse files
committed
Confirm that model works with both contant and time-varying functions for the change rate.
1 parent ec9da52 commit 80e0c47

1 file changed

Lines changed: 9 additions & 2 deletions

File tree

exploratory/models/pytorch_simple.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,18 @@ def tensor(x: np.ndarray, **kwargs) -> torch.Tensor:
5656
# Define model
5757
# Only parameters need requires_grad = True
5858
y = tensor(obs_sub['changed'].values)
59-
t1 = torch.zeros(obs_sub.shape[0], 1, dtype=dtype, device=device)
6059
t2 = tensor(obs_sub[['tag_years']].values)
60+
t1 = torch.zeros_like(t2)
6161
# Estimand: (log) lambda, log of the rate parameter
6262
def simple_model_fun(params: torch.Tensor) -> torch.Tensor:
6363
return torch.exp(params)
6464
starting_params = tensor(np.array([0.0]), requires_grad = True)
6565

66+
# def pseudo_varying_model_fun(params: torch.Tensor) -> callable:
67+
# def f_t(t: torch.Tensor) -> torch.Tensor:
68+
# return torch.exp(params).repeat(t.shape).reshape(list(t.shape) + [-1])
69+
# return f_t
70+
6671
simple_model = ModelFitter(
6772
event_rate_type = 'constant',
6873
event_rate_fun = simple_model_fun,
@@ -78,7 +83,9 @@ def simple_model_fun(params: torch.Tensor) -> torch.Tensor:
7883
simple_model.fit()
7984
simple_model.generate_parameter_draws(n_draws = N_DRAWS)
8085
fitted_params = simple_model.get_parameter_table().assign(parameter = 'log_lambda')
81-
predictions = simple_model.predict(t2 = tensor(np.arange(11))).assign(units = 'years')
86+
predictions = simple_model.predict(
87+
t2 = tensor(np.arange(11)).reshape(-1, 1)
88+
).assign(units = 'years')
8289

8390
# Save results
8491
fitted_params.to_csv(MODEL_DIR / f"fitted_params{model_suffix}.csv", index = False)

0 commit comments

Comments
 (0)