|
2 | 2 | from typing import Callable, List, Literal, Tuple, Union |
3 | 3 |
|
4 | 4 | import numpy as np |
| 5 | +from torch.nn import CrossEntropyLoss |
| 6 | +from torch.optim import Adam |
| 7 | + |
5 | 8 | from basicts.data import UEADataset |
6 | 9 | from basicts.runners.callback import BasicTSCallback |
7 | 10 | from basicts.runners.taskflow import (BasicTSClassificationTaskFlow, |
8 | 11 | BasicTSTaskFlow) |
9 | | -from torch.nn import CrossEntropyLoss |
10 | | -from torch.optim import Adam |
11 | 12 |
|
12 | 13 | from .base_config import BasicTSConfig |
13 | 14 | from .model_config import BasicTSModelConfig |
@@ -99,9 +100,11 @@ class BasicTSClassificationConfig(BasicTSConfig): |
99 | 100 |
|
100 | 101 | # Dataset settings |
101 | 102 | dataset_type: type = field(default=UEADataset, metadata={"help": "Dataset type."}) |
102 | | - dataset_params: Union[dict, None] = field(default=None, metadata={"help": "Dataset parameters."}) |
| 103 | + dataset_params: Union[dict, None] = field( |
| 104 | + default_factory=lambda: {"memmap": False}, |
| 105 | + metadata={"help": "Dataset parameters."}) |
103 | 106 | use_timestamps: bool = field(default=False, metadata={"help": "Whether to use timestamps as supplementary."}) |
104 | | - memmap: bool = field(default=False, metadata={"help": "Whether to use memmap to load datasets."}) |
| 107 | + memmap: bool = field(default=None, metadata={"help": "Whether to use memmap to load datasets."}) |
105 | 108 | null_val: float = field(default=np.nan, metadata={"help": "Null value."}) |
106 | 109 | null_to_num: float = field(default=0.0, metadata={"help": "Null value to number."}) |
107 | 110 |
|
@@ -148,7 +151,7 @@ class BasicTSClassificationConfig(BasicTSConfig): |
148 | 151 | optimizer_params: dict = field( |
149 | 152 | default_factory=lambda: {"lr": 2e-4, "weight_decay": 5e-4}, |
150 | 153 | metadata={"help": "Optimizer parameters."}) |
151 | | - lr: float = field(default=2e-4, metadata={"help": "Learning rate."}) |
| 154 | + lr: float = field(default=None, metadata={"help": "Learning rate."}) |
152 | 155 |
|
153 | 156 | # Learning rate scheduler |
154 | 157 | lr_scheduler: Union[type, None] = field(default=None, metadata={"help": "Learning rate scheduler type."}) |
|
0 commit comments