Skip to content

Commit 2cf1ebf

Browse files
committed
update selective learning to support loading its configuration from JSON
1 parent 76b4891 commit 2cf1ebf

3 files changed

Lines changed: 48 additions & 13 deletions

File tree

src/basicts/runners/callback/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from .curriculum_learrning import CurriculumLearning
55
from .early_stopping import EarlyStopping
66
from .grad_accumulation import GradAccumulation
7-
from .koopa_mask_init import KoopaMaskInitCallbackFullTrain
87
from .no_bp import NoBP
98
from .selective_learning import SelectiveLearning
109

src/basicts/runners/callback/selective_learning.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,19 @@ def __init__(
4242
ckpt_path: Optional[str] = None):
4343

4444
super().__init__()
45+
46+
# config
4547
self.r_u = r_u
4648
self.r_a = r_a
47-
self.estimator = estimator(estimator_config)
49+
self.estimator = estimator
50+
self.estimator_config = estimator_config
4851
self.ckpt_path = ckpt_path
4952

50-
if self.r_a is not None and self.estimator is None:
53+
self.estimation_model = self.estimator(estimator_config)
54+
55+
if self.r_a is not None and self.estimation_model is None:
5156
raise RuntimeError("Anomaly mask ratio is set but estimation model is not provided.")
52-
if self.estimator is not None and self.ckpt_path is None:
57+
if self.estimation_model is not None and self.ckpt_path is None:
5358
raise RuntimeError("Estimation model is set but checkpoint path is not provided.")
5459

5560
self.history_residual: torch.Tensor = None
@@ -59,7 +64,7 @@ def __init__(
5964
def on_train_start(self, runner: "BasicTSRunner"):
    """Set up selective learning right before the training loop begins."""
    runner.logger.info(f"Use selective learning with r_u={self.r_u}, r_a={self.r_a}.")
    # Restore checkpoint weights into the estimation model, then switch it
    # to eval mode — it is only ever used for inference during training.
    self._load_estimator(runner)
    self.estimation_model.eval()
    # Total number of training samples in the dataset.
    self.num_samples = len(runner.train_data_loader.dataset)
    # Wrap the loader; presumably each batch then also carries its sample
    # indices — confirm against _DataLoaderWithIndex.
    runner.train_data_loader = _DataLoaderWithIndex(runner.train_data_loader)
6570

@@ -86,7 +91,7 @@ def on_compute_loss(self, runner: "BasicTSRunner", **kwargs):
8691
# Anomaly mask
8792
if self.r_a is not None:
8893
with torch.no_grad():
89-
est_foward_return = runner._forward(self.estimator, data, step=0)
94+
est_foward_return = runner._forward(self.estimation_model, data, step=0)
9095
residual_lb = torch.abs(est_foward_return["prediction"] - forward_return["targets"])
9196
dist = residual - residual_lb
9297
thresholds = torch.quantile(
@@ -103,24 +108,24 @@ def on_epoch_end(self, runner: "BasicTSRunner", **kwargs):
103108

104109
def _load_estimator(self, runner: "BasicTSRunner"):
    """Move the estimation model to the compute device, wrap it for DDP
    when distributed training is active, and load its checkpoint weights.

    Raises:
        OSError: if the checkpoint at ``self.ckpt_path`` cannot be read.
    """
    runner.logger.info(f"Building estimation model {self.estimation_model.__class__.__name__}.")
    self.estimation_model = to_device(self.estimation_model)

    # DDP: only wrap when a process group has actually been initialized.
    if torch.distributed.is_initialized():
        self.estimation_model = DDP(
            self.estimation_model,
            device_ids=[get_local_rank()],
            find_unused_parameters=runner.cfg.ddp_find_unused_parameters
        )

    # load model weights
    try:
        checkpoint_dict = load_ckpt(None, ckpt_path=self.ckpt_path, logger=runner.logger)
        # A DDP wrapper keeps the real module under `.module`; load into it directly.
        target = self.estimation_model.module if isinstance(self.estimation_model, DDP) else self.estimation_model
        target.load_state_dict(checkpoint_dict["model_state_dict"])
    except (IndexError, OSError) as e:
        raise OSError(f"Ckpt file {self.ckpt_path} does not exist") from e
126131

src/test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
from basicts.models.DLinear import DLinear, DLinearConfig
from basicts.models.iTransformer import iTransformerConfig, iTransformerForForecasting
from basicts import BasicTSLauncher
from basicts.configs import BasicTSForecastingConfig
from basicts.runners.callback import SelectiveLearning


if __name__ == "__main__":

    # Selective-learning callback: the estimator class and its config are
    # passed separately; the callback instantiates and loads the estimator.
    cb = SelectiveLearning(
        r_u=0.3,
        r_a=0.3,
        estimator=DLinear,
        estimator_config=DLinearConfig(input_len=336, output_len=336),
        ckpt_path="checkpoints/DLinear/ETTh1_100_336_336/1f037d3a0fb4a6de40ce3dcb2656b136/DLinear_best_val_MSE.pt"
    )

    # Train an iTransformer forecaster on ETTh1 (336 -> 336) with the
    # selective-learning callback attached.
    BasicTSLauncher.launch_training(
        BasicTSForecastingConfig(
            model=iTransformerForForecasting,
            input_len=336,
            output_len=336,
            use_timestamps=False,
            model_config=iTransformerConfig(
                input_len=336,
                output_len=336,
                num_features=7),
            dataset_name="ETTh1",
            gpus="0",
            callbacks=[cb],
        ))

0 commit comments

Comments
 (0)