@@ -118,6 +118,7 @@ def update_config(self):
118118 # update base config => model config
119119 model_id = self .base_config .model_id
120120 self .model_config .model_id = model_id
121+ self ._maybe_set_max_observed_time ()
121122
122123 run = current_stage
123124 use_torch = self .base_config .backend == Backend .Torch
@@ -141,6 +142,42 @@ def update_config(self):
141142
142143 return
143144
145+ def _maybe_set_max_observed_time (self ):
146+ """Resolve the observation-window end T for WSM models."""
147+ if self .base_config .model_id != 'WSMTHP' :
148+ return
149+
150+ model_specs = self .model_config .model_specs
151+ t_mode = str (model_specs .get ('T_mode' , 'train_global' )).lower ()
152+
153+ if t_mode == 'manual' :
154+ py_assert (model_specs .get ('max_observed_time' ) is not None ,
155+ ValueError ,
156+ 'WSMTHP with T_mode=manual requires model_specs.max_observed_time.' )
157+ return
158+
159+ if t_mode == 'batch' :
160+ model_specs ['max_observed_time' ] = None
161+ logger .info ('WSMTHP uses batch-wise T (T_mode=batch).' )
162+ return
163+
164+ py_assert (t_mode == 'train_global' ,
165+ ValueError ,
166+ f'Unsupported WSMTHP T_mode: { t_mode } . Use manual, train_global, or batch.' )
167+
168+ from easy_tpp .preprocess .data_loader import TPPDataLoader
169+
170+ data_loader = TPPDataLoader (
171+ data_config = self .data_config ,
172+ backend = self .base_config .backend ,
173+ batch_size = self .trainer_config .batch_size ,
174+ shuffle = False ,
175+ )
176+ max_observed_time = data_loader .get_max_event_time ('train' )
177+ if max_observed_time is not None :
178+ model_specs ['max_observed_time' ] = float (max_observed_time )
179+ logger .info (f'Auto-set model_specs.max_observed_time={ max_observed_time } from train split (T_mode=train_global)' )
180+
144181 def get_metric_functions (self ):
145182 return MetricsHelper .get_metrics_callback_from_names (self .trainer_config .metrics )
146183
0 commit comments