Skip to content

Commit 2d729f2

Browse files
committed
Fix: Corrected the bug where metrics were averaged without accounting for missing values
1 parent f3b725b commit 2d729f2

5 files changed

Lines changed: 211 additions & 18 deletions

File tree

basicts/metrics/__init__.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from .cls_metrics import accuracy, f1_score, precision, recall
12
from .corr import masked_corr
23
from .mae import masked_mae
34
from .mape import masked_mape
5+
from .metric_meter import AvgMeter, RMSEMeter
46
from .mse import masked_mse
57
from .r_square import masked_r2
68
from .rmse import masked_rmse
@@ -15,17 +17,32 @@
1517
'WAPE': masked_wape,
1618
'SMAPE': masked_smape,
1719
'R2': masked_r2,
18-
'CORR': masked_corr
20+
'CORR': masked_corr,
21+
"accuracy": accuracy,
22+
"precision": precision,
23+
"recall": recall,
24+
"f1": f1_score
1925
}
2026

27+
METRIC_METER = {
28+
'RMSE': RMSEMeter,
29+
'default': AvgMeter
30+
}
31+
2132
__all__ = [
2233
'masked_mae',
2334
'masked_mse',
2435
'masked_rmse',
36+
'incremental_masked_rmse',
2537
'masked_mape',
2638
'masked_wape',
2739
'masked_smape',
2840
'masked_r2',
2941
'masked_corr',
30-
'ALL_METRICS'
42+
'accuracy',
43+
'precision',
44+
'recall',
45+
'f1_score',
46+
'ALL_METRICS',
47+
'METRIC_METER'
3148
]

basicts/metrics/metric_meter.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
class AvgMeter:
    """Weighted running-average meter.

    Accumulates ``value * n`` and ``n`` across updates and reports their ratio.
    """

    def __init__(self):
        self._sum: float = 0.
        self._count: int = 0

    def reset(self):
        """Clear the accumulated sum and count."""

        self._sum = 0.
        self._count = 0

    def update(self, value: float, n: int = 1):
        """Add one weighted observation to the running average.

        Args:
            value (float): observed value.
            n (int): weight of the observation. Defaults to 1.
        """

        # A zero-weight observation must not contribute. Skipping it also keeps
        # a NaN/inf ``value`` from poisoning the sum, since ``nan * 0`` is NaN.
        if n == 0:
            return

        self._sum += value * n
        self._count += n

    @property
    def value(self) -> float:
        """Get the weighted average.

        Returns:
            float: ``sum / count``, or ``0.0`` when nothing has been accumulated.
        """

        return self._sum / self._count if self._count != 0 else 0.0
36+
37+
38+
class RMSEMeter:
    """RMSE meter.

    This meter maintains a weighted sum of **MSE** (each incoming RMSE value is
    squared before accumulation) and calculates **RMSE** in post-processing,
    i.e. the square root is taken only when ``value`` is read.
    """

    def __init__(self):
        self._mse: float = 0.
        self._count: int = 0

    def reset(self):
        """Clear the accumulated squared-error sum and count."""

        self._mse = 0.
        self._count = 0

    def update(self, value: float, n: int = 1):
        """Add one weighted RMSE observation.

        Args:
            value (float): RMSE value; it is squared before being accumulated.
            n (int): weight of the observation. Defaults to 1.
        """

        # A zero-weight observation must not contribute. Skipping it also keeps
        # a NaN/inf ``value`` from poisoning the sum, since ``nan * 0`` is NaN.
        if n == 0:
            return

        self._mse += value ** 2 * n
        self._count += n

    @property
    def value(self) -> float:
        """Get the aggregated RMSE.

        Returns:
            float: square root of the weighted mean of the accumulated squared
                values, or ``0.0`` when nothing has been accumulated.
        """

        mse = self._mse / self._count if self._count != 0 else 0

        return mse ** 0.5

basicts/runners/base_epoch_runner.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from easytorch.core.checkpoint import (backup_last_ckpt, clear_ckpt, load_ckpt,
1111
save_ckpt)
1212
from easytorch.core.data_loader import build_data_loader, build_data_loader_ddp
13-
from easytorch.core.meter_pool import MeterPool
1413
from easytorch.device import to_device
1514
from easytorch.utils import (TimePredictor, get_local_rank, get_logger,
1615
is_master, master_only, set_env)
@@ -22,7 +21,7 @@
2221
from torch.utils.tensorboard import SummaryWriter
2322
from tqdm import tqdm
2423

25-
from ..utils import get_dataset_name
24+
from ..utils import MeterPool, get_dataset_name
2625
from . import optim
2726

2827

@@ -597,7 +596,7 @@ def inference_pipeline(self, cfg: Optional[Dict] = None, input_data: Union[str,
597596
result = self.inference(save_result_path=output_data_file_path)
598597

599598
inference_end_time = time.time()
600-
self.update_epoch_meter('inference/time', inference_end_time - inference_start_time)
599+
self.update_epoch_meter('inference/time', 'inference', inference_end_time - inference_start_time)
601600

602601
self.print_epoch_meters('inference')
603602

@@ -924,7 +923,7 @@ def save_best_model(self, epoch: int, metric_name: str, greater_best: bool = Tru
924923
`False` means lower value is best, such as `loss`. Defaults to True.
925924
"""
926925

927-
metric = self.meter_pool.get_avg(metric_name)
926+
metric = self.meter_pool.get_value(metric_name)
928927
best_metric = self.best_metrics.get(metric_name)
929928
if best_metric is None or (metric > best_metric if greater_best else metric < best_metric):
930929
self.best_metrics[metric_name] = metric

basicts/runners/base_tsf_runner.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(self, cfg: Dict):
7474
# define metrics
7575
self.metrics = cfg.get('METRICS', {}).get('FUNCS', {
7676
'MAE': masked_mae,
77-
'RMSE': masked_rmse,
77+
'RMSE': masked_rmse,
7878
'MAPE': masked_mape,
7979
'WAPE': masked_wape,
8080
'MSE': masked_mse
@@ -376,7 +376,7 @@ def train_iters(self, epoch: int, iter_index: int, data: Union[torch.Tensor, Tup
376376

377377
for metric_name, metric_func in self.metrics.items():
378378
metric_item = self.metric_forward(metric_func, forward_return)
379-
self.update_epoch_meter(f'train/{metric_name}', metric_item.item())
379+
self.update_epoch_meter(f'train/{metric_name}', metric_item.item(), weight)
380380
return loss
381381

382382
def val_iters(self, iter_index: int, data: Union[torch.Tensor, Tuple]):
@@ -432,22 +432,23 @@ def test(self, train_epoch: Optional[int] = None, save_metrics: bool = False, sa
432432
for i in self.evaluation_horizons:
433433
pred_h = pred[:, i, :, :]
434434
target_h = target[:, i, :, :]
435+
weight_h = self._get_metric_weight(target_h)
435436

436437
for metric_name, metric_func in self.metrics.items():
437438
if metric_name.lower() == 'mase':
438439
continue # MASE needs to be calculated after all horizons
439440
metric_val = self.metric_forward(metric_func, {'prediction': pred_h, 'target': target_h})
440-
self.update_epoch_meter(f'test/{metric_name}@h{i+1}', metric_val.item(), weight)
441+
self.update_epoch_meter(f'test/{metric_name}@h{i+1}', metric_val.item(), weight_h)
441442

442443
for metric_name, metric_func in self.metrics.items():
443444
metric_item = self.metric_forward(metric_func, {'prediction': pred, 'target': target})
444445
self.update_epoch_meter(f'test/{metric_name}', metric_item.item(), weight)
445446

446447
if save_metrics:
447448
metrics_results = {}
448-
metrics_results['overall'] = {k: self.meter_pool.get_avg(f'test/{k}') for k in self.metrics.keys()}
449+
metrics_results['overall'] = {k: self.meter_pool.get_value(f'test/{k}') for k in self.metrics.keys()}
449450
for i in self.evaluation_horizons:
450-
metrics_results[f'horizon_{i+1}'] = {k: self.meter_pool.get_avg(f'test/{k}@h{i+1}') for k in self.metrics.keys()}
451+
metrics_results[f'horizon_{i+1}'] = {k: self.meter_pool.get_value(f'test/{k}@h{i+1}') for k in self.metrics.keys()}
451452

452453
# save metrics_results to self.ckpt_save_dir/test_metrics.json
453454
with open(os.path.join(self.ckpt_save_dir, 'test_metrics.json'), 'w') as f:
@@ -553,18 +554,14 @@ def _save_test_results(self, batch_idx: int, batch_data: Dict[str, np.ndarray])
553554
def _get_metric_weight(self, x: torch.Tensor) -> int:
554555
"""
555556
Get the weight for calculating metrics.
556-
1. Since the last batch may be smaller (`drop_last=False`), it is necessary to perform a weighted average based on the batch size.
557-
2. Since the number of valid values in each batch may vary, a weighted average based on the valid value count is also required.
558-
Valid value count is the total count minus the number of missing values.
559-
The weight is the product of the batch size and the valid value count.
557+
Since the number of valid values in each batch may vary, it is necessary to perform a weighted average based on the valid value count.
558+
The valid value count is the total count minus the number of missing values.
560559
"""
561560

562-
batch_size = x.shape[0]
563-
564561
if self.null_val == np.nan:
565562
valid_num = (~torch.isnan(x)).sum().item()
566563
else:
567564
eps = 5e-5
568565
valid_num = (~torch.isclose(x, torch.tensor(self.null_val).expand_as(x).to(x.device), atol=eps, rtol=0.0)).sum().item()
569566

570-
return batch_size * valid_num
567+
return valid_num

basicts/utils/meter_pool.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import logging
2+
from typing import Any, Dict, Tuple, Union
3+
4+
from torch.utils.tensorboard import SummaryWriter
5+
6+
from ..metrics import METRIC_METER
7+
8+
9+
class MeterPool:
    """Meter container.

    Maps each meter name to a record holding the meter object, its
    registration index, its output format, its type and its plot flag.
    """

    def __init__(self):
        self._pool: Dict[str, Dict[str, Any]] = {}

    def register(self, name: str, meter_type: str, fmt: str = '{:f}', plt: bool = True):
        """Add a meter to meter pool.

        The meter class is looked up in ``METRIC_METER`` by the metric part of
        ``name``; metrics without a dedicated entry use the ``'default'`` meter.

        Args:
            name (str): meter name, e.g. ``type/metric`` or ``type/metric@h{i}``.
            meter_type (str): meter type.
            fmt (str): meter output format.
            plt (bool): set ```True``` to plot it in tensorboard
                when calling ```plt_meters```.

        Raises:
            ValueError: if a meter called ``name`` is already registered.
        """

        if name in self._pool:
            raise ValueError(f'Meter {name} already existed.')

        # name: type/metric or type/metric@h{i}
        # A name without any '/' is treated as a bare metric name instead of
        # raising IndexError on the missing second component.
        parts = name.split('/')
        metric_part = parts[1] if len(parts) > 1 else parts[0]
        metric = metric_part.split('@')[0]  # get the metric name
        meter_cls = METRIC_METER.get(metric, METRIC_METER['default'])

        self._pool[name] = {
            'meter': meter_cls(),
            'index': len(self._pool),
            'format': fmt,
            'type': meter_type,
            'plt': plt
        }

    def update(self, name: str, value: Union[float, Tuple[float]], n: int = 1):
        """Update the meter registered under ``name``.

        Args:
            name (str): meter name.
            value (Union[float, Tuple[float]]): value.
            n: (int): num.
        """

        self._pool[name]['meter'].update(value, n)

    def get_value(self, name: str) -> float:
        """Get the current value of a meter.

        Args:
            name (str): meter name.

        Returns:
            value (float): the ``value`` property of the underlying meter.
        """

        return self._pool[name]['meter'].value

    def print_meters(self, meter_type: str, logger: logging.Logger = None):
        """Print the specified type of meters.

        Meters are listed in registration order.

        Args:
            meter_type (str): meter type
            logger (logging.Logger): logger; when ``None`` the line is printed
                to stdout instead.
        """

        # Single pass ordered by registration index, replacing a nested scan
        # over every (index, entry) pair.
        ordered = sorted(self._pool.items(), key=lambda item: item[1]['index'])
        print_list = [
            ('{}: ' + entry['format']).format(name, entry['meter'].value)
            for name, entry in ordered
            if entry['type'] == meter_type
        ]
        print_str = 'Result <{}>: [{}]'.format(meter_type, ', '.join(print_list))
        if logger is None:
            print(print_str)
        else:
            logger.info(print_str)

    def plt_meters(self, meter_type: str, step: int, tensorboard_writer: SummaryWriter):
        """Plot the specified type of meters in tensorboard.

        Args:
            meter_type (str): meter type.
            step (int): Global step value to record
            tensorboard_writer (SummaryWriter): tensorboard SummaryWriter
        """

        for name, entry in self._pool.items():
            if entry['plt'] and entry['type'] == meter_type:
                tensorboard_writer.add_scalar(name, entry['meter'].value, global_step=step)
        tensorboard_writer.flush()

    def reset(self):
        """Reset all meters.
        """

        for entry in self._pool.values():
            entry['meter'].reset()

0 commit comments

Comments
 (0)