Skip to content

Commit c4338ed

Browse files
authored
Merge pull request #79 from KenCao2007/add-wsmthp
Add WSMTHP model and WSM-specific T handling
2 parents 19f6270 + 829c897 commit c4338ed

5 files changed

Lines changed: 800 additions & 7 deletions

File tree

easy_tpp/config_factory/runner_config.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def update_config(self):
118118
# update base config => model config
119119
model_id = self.base_config.model_id
120120
self.model_config.model_id = model_id
121+
self._maybe_set_max_observed_time()
121122

122123
run = current_stage
123124
use_torch = self.base_config.backend == Backend.Torch
@@ -141,6 +142,42 @@ def update_config(self):
141142

142143
return
143144

145+
def _maybe_set_max_observed_time(self):
146+
"""Resolve the observation-window end T for WSM models."""
147+
if self.base_config.model_id != 'WSMTHP':
148+
return
149+
150+
model_specs = self.model_config.model_specs
151+
t_mode = str(model_specs.get('T_mode', 'train_global')).lower()
152+
153+
if t_mode == 'manual':
154+
py_assert(model_specs.get('max_observed_time') is not None,
155+
ValueError,
156+
'WSMTHP with T_mode=manual requires model_specs.max_observed_time.')
157+
return
158+
159+
if t_mode == 'batch':
160+
model_specs['max_observed_time'] = None
161+
logger.info('WSMTHP uses batch-wise T (T_mode=batch).')
162+
return
163+
164+
py_assert(t_mode == 'train_global',
165+
ValueError,
166+
f'Unsupported WSMTHP T_mode: {t_mode}. Use manual, train_global, or batch.')
167+
168+
from easy_tpp.preprocess.data_loader import TPPDataLoader
169+
170+
data_loader = TPPDataLoader(
171+
data_config=self.data_config,
172+
backend=self.base_config.backend,
173+
batch_size=self.trainer_config.batch_size,
174+
shuffle=False,
175+
)
176+
max_observed_time = data_loader.get_max_event_time('train')
177+
if max_observed_time is not None:
178+
model_specs['max_observed_time'] = float(max_observed_time)
179+
logger.info(f'Auto-set model_specs.max_observed_time={max_observed_time} from train split (T_mode=train_global)')
180+
144181
def get_metric_functions(self):
145182
return MetricsHelper.get_metrics_callback_from_names(self.trainer_config.metrics)
146183

easy_tpp/model/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from easy_tpp.model.torch_model.torch_s2p2 import S2P2 as TorchS2P2
1010
from easy_tpp.model.torch_model.torch_sahp import SAHP as TorchSAHP
1111
from easy_tpp.model.torch_model.torch_thp import THP as TorchTHP
12+
from easy_tpp.model.torch_model.torch_wsm_thp import WSMTHP as TorchWSMTHP
1213

1314
__all__ = ['TorchBaseModel',
1415
'TorchNHP',
@@ -20,4 +21,5 @@
2021
'TorchODETPP',
2122
'TorchRMTPP',
2223
'TorchANHN',
23-
'TorchS2P2']
24+
'TorchS2P2',
25+
'TorchWSMTHP']

0 commit comments

Comments
 (0)