-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconfig.py
More file actions
38 lines (32 loc) · 1.52 KB
/
config.py
File metadata and controls
38 lines (32 loc) · 1.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from dataclasses import dataclass
from typing import Union
import pytorch_lightning as pl
from pyrallis import field
from unet.implicit_unet import EncoderSolver
from unet.original_unet import FeaturesUNet
@dataclass
class TrainConfig:
    """Training configuration; fields are declared with ``pyrallis.field``.

    ``dataset_path`` and ``results_path`` have no default and must be supplied.
    The ``model`` property instantiates the network selected by
    ``architecture``.
    """

    # A single path or a list of paths (see the Union annotation).
    dataset_path: Union[list, str] = field()
    # Presumably the directory where outputs are written — confirm with caller.
    results_path: str = field()
    exp_name: str = field(default='default_exp')
    batch_size: int = field(default=32)
    # Presumably the fraction of data held out for validation — confirm.
    val_percent: float = field(default=0.2)
    seed: int = field(default=42)
    devices: int = field(default=1)
    epochs: int = field(default=100)
    workers: int = field(default=4)
    # Architecture to select. Options are:
    # ExplicitEncoderSolver, ImplicitEncoderSolver, OriginalEncoderSolver
    architecture: str = field(default='ImplicitEncoderSolver')
    input_channels: int = field(default=3)
    output_channels: int = field(default=2)
    # Forwarded to EncoderSolver only; OriginalEncoderSolver ignores it.
    small: bool = field(default=False)

    @property
    def model(self) -> pl.LightningModule:
        """Build and return the model named by ``self.architecture``.

        A fresh model instance is constructed on every access.

        Returns:
            The instantiated model for the selected architecture.

        Raises:
            KeyError: if ``architecture`` is not one of the supported names.
        """
        # Zero-argument builders: only the selected architecture is
        # constructed, instead of instantiating all three models each time.
        factory = {
            'ExplicitEncoderSolver': lambda: EncoderSolver(
                self.input_channels, self.output_channels,
                implicit=False, small=self.small),
            # The implicit variant must pass implicit=True; it previously
            # passed False, making it identical to ExplicitEncoderSolver.
            'ImplicitEncoderSolver': lambda: EncoderSolver(
                self.input_channels, self.output_channels,
                implicit=True, small=self.small),
            'OriginalEncoderSolver': lambda: FeaturesUNet(
                self.input_channels, self.output_channels, kernel=3),
        }
        try:
            build = factory[self.architecture]
        except KeyError:
            # Same exception type as before, but name the valid options.
            raise KeyError(
                f'Unknown architecture {self.architecture!r}; '
                f'expected one of {sorted(factory)}') from None
        return build()