Skip to content

Commit 5f23f7d

Browse files
committed
fix: 🐛 a bug in utsf runner
1 parent 9a10e45 commit 5f23f7d

2 files changed

Lines changed: 5 additions & 4 deletions

File tree

basicts/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .launcher import launch_evaluation, launch_inference, launch_training
22
from .runners import BaseEpochRunner
33

4-
__version__ = '0.5.5'
4+
__version__ = '0.5.6'
55

66
__all__ = ['__version__', 'launch_training', 'launch_evaluation', 'BaseEpochRunner', 'launch_inference']

basicts/runners/base_utsf_runner.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __init__(self, cfg: Dict):
5757
self.grad_accumulation_steps = cfg.get('TRAIN.GRAD_ACCUMULATION_STEPS',1)
5858

5959
# inference settings
60-
self.generation_params = cfg['INFERENCE'].get('GENERATION_PARAMS', {})
60+
self.generation_params = cfg.get('INFERENCE', {}).get('GENERATION_PARAMS', {})
6161
self.forward_features = [0] # do not use time features
6262
self.target_features = [0] # do not use time features
6363

@@ -358,8 +358,9 @@ def val_iters(self, iter_index: int, data: Union[torch.Tensor, Tuple]):
358358
data = self.preprocessing(data)
359359
# TODO: consider using amp for validation
360360
# with self.ctx:
361-
forward_return = self.forward(data=data, iter_num=iter_index, train=False)
362-
forward_return = self.postprocessing(forward_return)
361+
with self.amp_context:
362+
forward_return = self.forward(data=data, iter_num=iter_index, train=False)
363+
forward_return = self.postprocessing(forward_return)
363364
loss = self.metric_forward(self.loss, forward_return)
364365
self.update_iteration_meter('val/loss', loss)
365366

0 commit comments

Comments
 (0)