@@ -56,13 +56,18 @@ def tensor(x: np.ndarray, **kwargs) -> torch.Tensor:
5656 # Define model
5757 # Only parameters need requires_grad = True
5858 y = tensor (obs_sub ['changed' ].values )
59- t1 = torch .zeros (obs_sub .shape [0 ], 1 , dtype = dtype , device = device )
6059 t2 = tensor (obs_sub [['tag_years' ]].values )
60+ t1 = torch .zeros_like (t2 )
6161 # Estimand: (log) lambda, log of the rate parameter
6262 def simple_model_fun (params : torch .Tensor ) -> torch .Tensor :
6363 return torch .exp (params )
6464 starting_params = tensor (np .array ([0.0 ]), requires_grad = True )
6565
66+ # def pseudo_varying_model_fun(params: torch.Tensor) -> callable:
67+ # def f_t(t: torch.Tensor) -> torch.Tensor:
68+ # return torch.exp(params).repeat(t.shape).reshape(list(t.shape) + [-1])
69+ # return f_t
70+
6671 simple_model = ModelFitter (
6772 event_rate_type = 'constant' ,
6873 event_rate_fun = simple_model_fun ,
@@ -78,7 +83,9 @@ def simple_model_fun(params: torch.Tensor) -> torch.Tensor:
7883 simple_model .fit ()
7984 simple_model .generate_parameter_draws (n_draws = N_DRAWS )
8085 fitted_params = simple_model .get_parameter_table ().assign (parameter = 'log_lambda' )
81- predictions = simple_model .predict (t2 = tensor (np .arange (11 ))).assign (units = 'years' )
86+ predictions = simple_model .predict (
87+ t2 = tensor (np .arange (11 )).reshape (- 1 , 1 )
88+ ).assign (units = 'years' )
8289
8390 # Save results
8491 fitted_params .to_csv (MODEL_DIR / f"fitted_params{ model_suffix } .csv" , index = False )
0 commit comments