@@ -74,7 +74,7 @@ def __init__(self, cfg: Dict):
7474 # define metrics
7575 self .metrics = cfg .get ('METRICS' , {}).get ('FUNCS' , {
7676 'MAE' : masked_mae ,
77- 'RMSE' : masked_rmse ,
77+ 'RMSE' : masked_rmse ,
7878 'MAPE' : masked_mape ,
7979 'WAPE' : masked_wape ,
8080 'MSE' : masked_mse
@@ -376,7 +376,7 @@ def train_iters(self, epoch: int, iter_index: int, data: Union[torch.Tensor, Tup
376376
377377 for metric_name , metric_func in self .metrics .items ():
378378 metric_item = self .metric_forward (metric_func , forward_return )
379- self .update_epoch_meter (f'train/{ metric_name } ' , metric_item .item ())
379+ self .update_epoch_meter (f'train/{ metric_name } ' , metric_item .item (), weight )
380380 return loss
381381
382382 def val_iters (self , iter_index : int , data : Union [torch .Tensor , Tuple ]):
@@ -432,22 +432,23 @@ def test(self, train_epoch: Optional[int] = None, save_metrics: bool = False, sa
432432 for i in self .evaluation_horizons :
433433 pred_h = pred [:, i , :, :]
434434 target_h = target [:, i , :, :]
435+ weight_h = self ._get_metric_weight (target_h )
435436
436437 for metric_name , metric_func in self .metrics .items ():
437438 if metric_name .lower () == 'mase' :
438439 continue # MASE needs to be calculated after all horizons
439440 metric_val = self .metric_forward (metric_func , {'prediction' : pred_h , 'target' : target_h })
440- self .update_epoch_meter (f'test/{ metric_name } @h{ i + 1 } ' , metric_val .item (), weight )
441+ self .update_epoch_meter (f'test/{ metric_name } @h{ i + 1 } ' , metric_val .item (), weight_h )
441442
442443 for metric_name , metric_func in self .metrics .items ():
443444 metric_item = self .metric_forward (metric_func , {'prediction' : pred , 'target' : target })
444445 self .update_epoch_meter (f'test/{ metric_name } ' , metric_item .item (), weight )
445446
446447 if save_metrics :
447448 metrics_results = {}
448- metrics_results ['overall' ] = {k : self .meter_pool .get_avg (f'test/{ k } ' ) for k in self .metrics .keys ()}
449+ metrics_results ['overall' ] = {k : self .meter_pool .get_value (f'test/{ k } ' ) for k in self .metrics .keys ()}
449450 for i in self .evaluation_horizons :
450- metrics_results [f'horizon_{ i + 1 } ' ] = {k : self .meter_pool .get_avg (f'test/{ k } @h{ i + 1 } ' ) for k in self .metrics .keys ()}
451+ metrics_results [f'horizon_{ i + 1 } ' ] = {k : self .meter_pool .get_value (f'test/{ k } @h{ i + 1 } ' ) for k in self .metrics .keys ()}
451452
452453 # save metrics_results to self.ckpt_save_dir/test_metrics.json
453454 with open (os .path .join (self .ckpt_save_dir , 'test_metrics.json' ), 'w' ) as f :
@@ -553,18 +554,14 @@ def _save_test_results(self, batch_idx: int, batch_data: Dict[str, np.ndarray])
553554 def _get_metric_weight (self , x : torch .Tensor ) -> int :
554555 """
555556 Get the weight for calculating metrics.
556- 1. Since the last batch may be smaller (`drop_last=False`), it is necessary to perform a weighted average based on the batch size.
557- 2. Since the number of valid values in each batch may vary, a weighted average based on the valid value count is also required.
558- Valid value count is the total count minus the number of missing values.
559- The weight is the product of the batch size and the valid value count.
557+ Since the number of valid values in each batch may vary, it is necessary to perform a weighted average based on the valid value count.
558+ The valid value count is the total count minus the number of missing values.
560559 """
561560
562- batch_size = x .shape [0 ]
563-
564561 if self .null_val == np .nan :
565562 valid_num = (~ torch .isnan (x )).sum ().item ()
566563 else :
567564 eps = 5e-5
568565 valid_num = (~ torch .isclose (x , torch .tensor (self .null_val ).expand_as (x ).to (x .device ), atol = eps , rtol = 0.0 )).sum ().item ()
569566
570- return batch_size * valid_num
567+ return valid_num
0 commit comments