@@ -42,14 +42,19 @@ def __init__(
4242 ckpt_path : Optional [str ] = None ):
4343
4444 super ().__init__ ()
45+
46+ # config
4547 self .r_u = r_u
4648 self .r_a = r_a
47- self .estimator = estimator (estimator_config )
49+ self .estimator = estimator
50+ self .estimator_config = estimator_config
4851 self .ckpt_path = ckpt_path
4952
50- if self .r_a is not None and self .estimator is None :
53+ self .estimation_model = self .estimator (estimator_config )
54+
55+ if self .r_a is not None and self .estimation_model is None :
5156 raise RuntimeError ("Anomaly mask ratio is set but estimation model is not provided." )
52- if self .estimator is not None and self .ckpt_path is None :
57+ if self .estimation_model is not None and self .ckpt_path is None :
5358 raise RuntimeError ("Estimation model is set but checkpoint path is not provided." )
5459
5560 self .history_residual : torch .Tensor = None
@@ -59,7 +64,7 @@ def __init__(
5964 def on_train_start (self , runner : "BasicTSRunner" ):
6065 runner .logger .info (f"Use selective learning with r_u={ self .r_u } , r_a={ self .r_a } ." )
6166 self ._load_estimator (runner )
62- self .estimator .eval ()
67+ self .estimation_model .eval ()
6368 self .num_samples = len (runner .train_data_loader .dataset )
6469 runner .train_data_loader = _DataLoaderWithIndex (runner .train_data_loader )
6570
@@ -86,7 +91,7 @@ def on_compute_loss(self, runner: "BasicTSRunner", **kwargs):
8691 # Anomaly mask
8792 if self .r_a is not None :
8893 with torch .no_grad ():
89- est_foward_return = runner ._forward (self .estimator , data , step = 0 )
94+ est_foward_return = runner ._forward (self .estimation_model , data , step = 0 )
9095 residual_lb = torch .abs (est_foward_return ["prediction" ] - forward_return ["targets" ])
9196 dist = residual - residual_lb
9297 thresholds = torch .quantile (
@@ -103,24 +108,24 @@ def on_epoch_end(self, runner: "BasicTSRunner", **kwargs):
103108
104109 def _load_estimator (self , runner : "BasicTSRunner" ):
105110
106- runner .logger .info (f"Building estimation model { self .estimator .__class__ .__name__ } ." )
107- self .estimator = to_device (self .estimator )
111+ runner .logger .info (f"Building estimation model { self .estimation_model .__class__ .__name__ } ." )
112+ self .estimation_model = to_device (self .estimation_model )
108113
109114 # DDP
110115 if torch .distributed .is_initialized ():
111- self .estimator = DDP (
112- self .estimator ,
116+ self .estimation_model = DDP (
117+ self .estimation_model ,
113118 device_ids = [get_local_rank ()],
114119 find_unused_parameters = runner .cfg .ddp_find_unused_parameters
115120 )
116121
117122 # load model weights
118123 try :
119124 checkpoint_dict = load_ckpt (None , ckpt_path = self .ckpt_path , logger = runner .logger )
120- if isinstance (self .estimator , DDP ):
121- self .estimator .module .load_state_dict (checkpoint_dict ["model_state_dict" ])
125+ if isinstance (self .estimation_model , DDP ):
126+ self .estimation_model .module .load_state_dict (checkpoint_dict ["model_state_dict" ])
122127 else :
123- self .estimator .load_state_dict (checkpoint_dict ["model_state_dict" ])
128+ self .estimation_model .load_state_dict (checkpoint_dict ["model_state_dict" ])
124129 except (IndexError , OSError ) as e :
125130 raise OSError (f"Ckpt file { self .ckpt_path } does not exist" ) from e
126131
0 commit comments