Skip to content

Commit b6506e2

Browse files
committed
Merge branch 'beta'
2 parents 13002af + 9795acc commit b6506e2

108 files changed

Lines changed: 1756 additions & 150 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/python-test.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ name: Pytest
55

66
on:
77
push:
8-
branches: [ "master" ]
8+
branches: [ "master", "release/*" ]
99
pull_request:
10-
branches: [ "master" ]
10+
branches: [ "master", "release/*" ]
1111

1212
permissions:
1313
contents: read
@@ -29,6 +29,7 @@ jobs:
2929
pip install flake8 pytest
3030
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
3131
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
32+
if [ -f tests/requirements.txt ]; then pip install -r tests/requirements.txt; fi
3233
- name: Test with pytest
3334
run: |
3435
pytest

.isort.cfg

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[settings]
2-
src_paths=basicts,tests
3-
skip_glob=baselines/*,assets/*,examples/*
2+
src_paths=src/basicts,tests
3+
skip_glob=baselines/*,assets/*

.pylintrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ ignore=baselines,assets,checkpoints,examples,scripts
1212

1313
# Files or directories matching the regex patterns are skipped. The regex
1414
# matches against base names, not paths.
15-
ignore-patterns=^\.|^_|^.*\.md|^.*\.txt|^.*\.csv|^.*\.CFF|^LICENSE
15+
ignore-patterns=^\.|^_|^.*\.md|^.*\.txt|^.*\.csv|^.*\.CFF|^LICENSE|^.*\.toml
1616

1717
# Pickle collected data for later comparisons.
1818
persistent=no

examples/classification/classification_demo.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from basicts import BasicTSLauncher
22
from basicts.configs import BasicTSClassificationConfig
3-
from basicts.models.iTransformer import iTransformerForClassification, iTransformerConfig
3+
from basicts.models.iTransformer import (iTransformerConfig,
4+
iTransformerForClassification)
45

56

67
def main():

examples/forecasting/forecasting_demo.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from torch.optim.lr_scheduler import MultiStepLR
2+
13
from basicts import BasicTSLauncher
24
from basicts.configs import BasicTSForecastingConfig
3-
from basicts.models.iTransformer import iTransformerForForecasting, iTransformerConfig
4-
from basicts.runners.callback import EarlyStopping, GradientClipping
55
from basicts.metrics import masked_mse
6-
from torch.optim.lr_scheduler import MultiStepLR
6+
from basicts.models.iTransformer import (iTransformerConfig,
7+
iTransformerForForecasting)
8+
from basicts.runners.callback import EarlyStopping, GradientClipping
79

810

911
def main():

examples/imputation/imputation_demo.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from basicts import BasicTSLauncher
22
from basicts.configs import BasicTSImputationConfig
3-
from basicts.models.iTransformer import iTransformerForReconstruction, iTransformerConfig
3+
from basicts.models.iTransformer import (iTransformerConfig,
4+
iTransformerForReconstruction)
45

56

67
def main():

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ dependencies = [
1818
"sympy",
1919
"openpyxl",
2020
"setuptools==59.5.0",
21-
"numpy==1.24.4",
21+
"numpy",
2222
"tqdm==4.67.1",
2323
"tensorboard==2.18.0",
2424
"transformers==4.40.1"

src/basicts/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .launcher import BasicTSLauncher
22

3-
__version__ = '1.0.2'
3+
__version__ = '1.1.0'
44

55
__all__ = ['__version__', 'BasicTSLauncher']

src/basicts/configs/base_config.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,22 @@
99
from functools import partial
1010
from numbers import Number
1111
from types import FunctionType
12-
from typing import Callable, List, Literal, Optional, Tuple, Union
12+
from typing import (TYPE_CHECKING, Callable, List, Literal, Optional, Tuple,
13+
Union)
1314

1415
import numpy as np
1516
import torch
16-
from basicts.runners.callback import BasicTSCallback
17-
from basicts.runners.taskflow import BasicTSTaskFlow
1817
from easydict import EasyDict
1918
from torch.optim.lr_scheduler import LRScheduler
2019

2120
from .model_config import BasicTSModelConfig
2221

22+
# avoid circular imports
23+
if TYPE_CHECKING:
24+
from basicts.runners.callback import BasicTSCallback
25+
from basicts.runners.taskflow import BasicTSTaskFlow
26+
27+
2328

2429
@dataclass(init=False)
2530
class BasicTSConfig(EasyDict):
@@ -35,8 +40,8 @@ class BasicTSConfig(EasyDict):
3540
model_config: BasicTSModelConfig
3641

3742
dataset_name: str
38-
taskflow: BasicTSTaskFlow
39-
callbacks: List[BasicTSCallback]
43+
taskflow: "BasicTSTaskFlow"
44+
callbacks: List["BasicTSCallback"]
4045

4146
############################## General Configuration ##############################
4247

@@ -277,7 +282,7 @@ def _pack_params(self, obj: type, obj_params: Union[dict, None]) -> dict:
277282
elif issubclass(obj, LRScheduler) and k == "optimizer":
278283
continue
279284
# short cut has higher priority than params in config
280-
elif k in self:
285+
elif k in self and self[k] is not None:
281286
obj_params[k] = self[k]
282287
return obj_params
283288

@@ -338,7 +343,7 @@ def _serialize_obj(self, obj: object) -> object:
338343
if not isinstance(is_default, bool):
339344
raise ValueError(f"Parameter {k} of {obj.__class__.__name__} is not serializable.")
340345
if not is_default:
341-
params[k] = repr(v)
346+
params[k] = self._serialize_obj(v)
342347

343348
return {
344349
"name": obj.__class__.__name__,

src/basicts/configs/tsc_config.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
from typing import Callable, List, Literal, Tuple, Union
33

44
import numpy as np
5+
from torch.nn import CrossEntropyLoss
6+
from torch.optim import Adam
7+
58
from basicts.data import UEADataset
69
from basicts.runners.callback import BasicTSCallback
710
from basicts.runners.taskflow import (BasicTSClassificationTaskFlow,
811
BasicTSTaskFlow)
9-
from torch.nn import CrossEntropyLoss
10-
from torch.optim import Adam
1112

1213
from .base_config import BasicTSConfig
1314
from .model_config import BasicTSModelConfig
@@ -99,9 +100,11 @@ class BasicTSClassificationConfig(BasicTSConfig):
99100

100101
# Dataset settings
101102
dataset_type: type = field(default=UEADataset, metadata={"help": "Dataset type."})
102-
dataset_params: Union[dict, None] = field(default=None, metadata={"help": "Dataset parameters."})
103+
dataset_params: Union[dict, None] = field(
104+
default_factory=lambda: {"memmap": False},
105+
metadata={"help": "Dataset parameters."})
103106
use_timestamps: bool = field(default=False, metadata={"help": "Whether to use timestamps as supplementary."})
104-
memmap: bool = field(default=False, metadata={"help": "Whether to use memmap to load datasets."})
107+
memmap: bool = field(default=None, metadata={"help": "Whether to use memmap to load datasets."})
105108
null_val: float = field(default=np.nan, metadata={"help": "Null value."})
106109
null_to_num: float = field(default=0.0, metadata={"help": "Null value to number."})
107110

@@ -148,7 +151,7 @@ class BasicTSClassificationConfig(BasicTSConfig):
148151
optimizer_params: dict = field(
149152
default_factory=lambda: {"lr": 2e-4, "weight_decay": 5e-4},
150153
metadata={"help": "Optimizer parameters."})
151-
lr: float = field(default=2e-4, metadata={"help": "Learning rate."})
154+
lr: float = field(default=None, metadata={"help": "Learning rate."})
152155

153156
# Learning rate scheduler
154157
lr_scheduler: Union[type, None] = field(default=None, metadata={"help": "Learning rate scheduler type."})

0 commit comments

Comments
 (0)