@@ -22,7 +22,12 @@ def __init__(self, func, args_list, dataset, x_bounds, x_shifts):
2222 self .x_shifts = x_shifts
2323 self .n_dim = len (x_bounds )
2424 self .n_data_points = len (dataset )
25- self .output = []
25+ self .grid_results = []
26+ self .x_spacing_results = []
27+ self .log_like_results = []
28+ self .ess_results = []
29+ self .func_evals = []
30+ self .grid_size = []
2631
2732
2833 def calc_naive_grid (self , grid_resolution , data_tempering_index , n_processes = 4 ):
@@ -40,7 +45,7 @@ def calc_naive_grid(self, grid_resolution, data_tempering_index, n_processes=4):
4045 log_likelihoods , rel_prob , weights , ess = gt .eval_grid_points (grid , self .func , self .args_list [data_tempering_index ], n_processes )
4146 return grid , x_spacing , log_likelihoods , rel_prob , weights , ess
4247
43- def initialize (self , init_grid_resolution , init_data_size , ess_min , n_processes = 4 , max_iter = 100 ):
48+ def initialize (self , init_grid_resolution , init_data_size , ess_min , n_processes = 4 , max_iter = 100 , store_results = False ):
4449 """Iteratively update the initial grid until the ESS is greater than the specified minimum ESS.
4550
4651 Args:
@@ -49,6 +54,8 @@ def initialize(self, init_grid_resolution, init_data_size, ess_min, n_processes=
4954 ess_min (float): The minimum effective sample size (ESS) to be used for initialization.
5055 n_processes (int): The number of parallel processes to use when evaluating the grid. Defaults to 4.
5156 max_iter (int): The maximum number of iterations to use when updating the grid. Defaults to 100.
57+ store_results (bool, optional): Stores results at each tempering stage (may use a lot of memory for large grids / datasets)
58+ Defaults to False.
5259
5360 Returns:
5461 Tuple[int, int, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, float]: A tuple containing the following elements:
@@ -67,16 +74,25 @@ def initialize(self, init_grid_resolution, init_data_size, ess_min, n_processes=
6774
6875 grid_resolution = init_grid_resolution
6976 iter = 0
77+ f_evals = np .size (log_likelihoods )
7078 with tqdm (total = None , desc = "Initializing..." ) as pbar :
7179 while ess < ess_min and iter < max_iter :
7280 grid_resolution = grid_resolution + 1
7381 grid , x_spacing , log_likelihoods , rel_prob , weights , ess = self .calc_naive_grid (grid_resolution , data_tempering_index , n_processes )
7482 iter = iter + 1
83+ f_evals = f_evals + np .size (log_likelihoods )
7584 pbar .update ()
76- pbar .set_description (f"Intialization: data_size={ data_tempering_index + 1 } , grid_resolution={ grid_resolution } , n_grid_points={ np .shape (grid )[0 ]} , ESS={ ess } " )
85+ pbar .set_description (f"Intialization: data_size={ data_tempering_index + 1 } , grid_resolution={ grid_resolution } , n_grid_points={ np .shape (grid )[0 ]} , ESS={ ess } , init_func_evals:{ f_evals } " )
86+ if store_results :
87+ self .grid_results .append (grid )
88+ self .x_spacing_results .append (x_spacing )
89+ self .log_like_results .append (log_likelihoods )
90+ self .ess_results .append (ess )
91+ self .func_evals .append (f_evals )
92+ self .grid_size .append (np .shape (grid )[0 ])
7793 return grid_resolution , data_tempering_index + 1 , grid , x_spacing , log_likelihoods , rel_prob , weights , ess
7894
79- def initialize_and_sample (self , init_grid_resolution , init_data_size , ess_min , delta , n_processes = 4 , max_iter = 100 , store_grid = False ):
95+ def initialize_and_sample (self , init_grid_resolution , init_data_size , ess_min , delta , n_processes = 4 , max_iter = 100 , store_results = False ):
8096 """ Initializes the grid and performs the data tempering (iterative batching) to obtain posterior samples.
8197
8298 Args:
@@ -88,7 +104,7 @@ def initialize_and_sample(self, init_grid_resolution, init_data_size, ess_min, d
88104 Defaults to 4.
89105 max_iter (int, optional): The maximum number of iterations for updating the grid.
90106 Defaults to 100.
91- store_grid (bool, optional): Stores reduced grid at each tempering stage (may use a lot of memory for large grids / datasets)
107+ store_results (bool, optional): Stores results at each tempering stage (may use a lot of memory for large grids / datasets)
92108 Defaults to False.
93109
94110 Returns:
@@ -104,47 +120,18 @@ def initialize_and_sample(self, init_grid_resolution, init_data_size, ess_min, d
104120 """
105121
106122 # generate initial (uniform) grid samples
107- grid_resolution , data_size , grid , x_spacing , log_likelihoods , rel_prob , weights , ess = self .initialize (init_grid_resolution , init_data_size , ess_min , n_processes , max_iter )
108- # gt.plot_2d_scatter(grid,
109- # f"initial grid: {data_size} data points, {np.shape(grid)[0]} grid point",
110- # r'$x_1$',
111- # r'$x_2$',
112- # self.x_bounds[0],
113- # self.x_bounds[1],
114- # 5,
115- # 0.1,
116- # "initial_grid.png"
117- # )
118- # # remove low probability samples
119- # gt.plot_2d_scatter(grid,
120- # f"initial grid reduced: {data_size} data points, {np.shape(grid)[0]} grid point",
121- # r'$x_1$',
122- # r'$x_2$',
123- # self.x_bounds[0],
124- # self.x_bounds[1],
125- # 5,
126- # 0.1,
127- # "initial_grid_reduced.png"
128- # )
123+ grid_resolution , data_size , grid , x_spacing , log_likelihoods , rel_prob , weights , ess = self .initialize (init_grid_resolution , init_data_size , ess_min , n_processes , max_iter , store_results )
129124
130125 # iterate through remaining data (i.e. data tempering)
131126 data_tempering_index = data_size - 1
132127 pbar = tqdm (range (data_tempering_index , self .n_data_points - 1 ),desc = 'Processing:' )
133128 for i in pbar :
129+ f_evals = 0
134130 data_tempering_index = i + 1
135131 args = self .args_list [data_tempering_index ]
136132 grid = gt .add_grid_points (grid , self .x_bounds , self .x_shifts , x_spacing ) # expand
137133 log_likelihoods , rel_prob , weights , ess = gt .eval_grid_points (grid , self .func , args , n_processes )
138- # gt.plot_2d_scatter(grid,
139- # f"expanded grid: {data_tempering_index+1} data points, {np.shape(grid)[0]} grid points",
140- # r'$x_1$',
141- # r'$x_2$',
142- # self.x_bounds[0],
143- # self.x_bounds[1],
144- # 5,
145- # 0.1,
146- # f"expanded_grid_{i}.png"
147- # )
134+ f_evals = f_evals + np .size (log_likelihoods )
148135
149136 # ensure ess is high enough for added datapoint, if not, make grid finer
150137 iter = 0
@@ -153,34 +140,20 @@ def initialize_and_sample(self, init_grid_resolution, init_data_size, ess_min, d
153140 x_spacing = gt .update_x_spacing (x_spacing ,2 ) # make grid spacing finer by 2x
154141 grid = gt .add_grid_points (grid , self .x_bounds , self .x_shifts , x_spacing , prev_x_spacing ) # expand and pack
155142 log_likelihoods , rel_prob , weights , ess = gt .eval_grid_points (grid , self .func , args , n_processes )
156- iter = iter + 1
157- # gt.plot_2d_scatter(grid,
158- # f"packed grid: {data_tempering_index+1} data points, {np.shape(grid)[0]} grid points",
159- # r'$x_1$',
160- # r'$x_2$',
161- # self.x_bounds[0],
162- # self.x_bounds[1],
163- # 5,
164- # 0.1,
165- # f"packed_grid_{i}_{iter}.png"
166- # )
143+ iter = iter + 1
144+ f_evals = f_evals + np .size (log_likelihoods )
167145
168146 grid = gt .reduce_grid_points (grid ,weights ,delta )
169- # gt.plot_2d_scatter(grid,
170- # f"reduced grid: {data_tempering_index+1} data points, {np.shape(grid)[0]} grid points",
171- # r'$x_1$',
172- # r'$x_2$',
173- # self.x_bounds[0],
174- # self.x_bounds[1],
175- # 5,
176- # 0.1,
177- # f"reduced_grid_{i}.png"
178- # )
179- # self.output.append(grid)
180- if store_grid :
181- self .output .append (grid )
182-
183- pbar .set_description (f"Processing: data_size={ data_tempering_index + 1 } , n_grid_points={ np .shape (grid )[0 ]} , ESS={ ess } " )
147+
148+ if store_results :
149+ self .grid_results .append (grid )
150+ self .x_spacing_results .append (x_spacing )
151+ self .log_like_results .append (log_likelihoods )
152+ self .ess_results .append (ess )
153+ self .func_evals .append (f_evals )
154+ self .grid_size .append (np .shape (grid )[0 ])
155+
156+ pbar .set_description (f"Processing: data_size={ data_tempering_index + 1 } , n_grid_points={ np .shape (grid )[0 ]} , ESS={ ess } , func_evals:{ f_evals } " )
184157 return grid_resolution , data_size , grid , x_spacing , log_likelihoods , rel_prob , weights , ess
185158
186159
0 commit comments