This repository was archived by the owner on May 13, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
62 lines (47 loc) · 1.96 KB
/
Copy pathtrain.py
File metadata and controls
62 lines (47 loc) · 1.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Dirty workaround for module import, which violates PEP8: E402
import os
import sys
import copy
import pathlib
import argparse
import warnings
from typing import Dict
sys.path.append(os.path.join(os.path.dirname(os.getcwd()), "../../"))
import torch
from torch.nn import Module
from torch.utils.data import DataLoader
from AI.src.modeling import build_model
from AI.src.data import build_dataloader
from AI.src.optimizer import build_optimizer
from AI.src.losses import LossWrapper
from AI.src.metrics import MetricWrapper
from AI.src.runner import Trainer
from AI.src.utils import DotDict, ConfigReader
# Configure PyTorch CPU threading at import time: 64 threads for intra-op
# parallelism and 64 threads for inter-op parallelism.
# NOTE(review): 64 is hard-coded and independent of the host's core count —
# confirm it suits the machines this script is run on.
torch.set_num_threads(64)
torch.set_num_interop_threads(64)
# Warning filters: report each matching warning only once, but silence
# DeprecationWarning entirely. filterwarnings() inserts at the front of the
# filter list, so the "ignore" rule added last is matched before "once".
warnings.filterwarnings("once")
warnings.filterwarnings("ignore", category=DeprecationWarning)
# Default config file for each supported platform, keyed by the value of
# ``sys.platform``. The paths are relative, so they resolve against the
# current working directory of the process.
DEFAULT_CONFIG_PATH: Dict[str, pathlib.Path] = {
    platform_name: pathlib.Path(config_file)
    for platform_name, config_file in (
        ("linux", "../../config/single/linux.json"),
        ("win32", "../../config/single/windows.json"),
    )
}
def main(args: argparse.Namespace) -> None:
    """Assemble the training components from a config file and run training.

    The config is loaded once via ``ConfigReader``. Every builder
    (dataloaders, model, optimizer/scheduler, loss, metrics) receives its own
    deep copy of it, while the ``Trainer`` receives the originally loaded
    config object. Training is started with ``Trainer.fit()``.

    Args:
        args: Parsed command-line arguments; only ``args.config`` is read and
            it is passed to ``ConfigReader``.
    """
    cfg: DotDict = ConfigReader(args.config).config

    def cfg_copy():
        # Return an independent deep copy of the loaded config.
        return copy.deepcopy(cfg)

    # Construction order is kept exactly as before: dataloaders, model,
    # optimizer/scheduler, loss, metrics.
    train_loader: DataLoader = build_dataloader(cfg_copy(), "train")
    val_loader: DataLoader = build_dataloader(cfg_copy(), "val")

    net: Module = build_model(cfg_copy())
    optimizer, sched = build_optimizer(cfg_copy(), net)

    criterion: LossWrapper = LossWrapper(cfg_copy())
    evaluator: MetricWrapper = MetricWrapper(cfg_copy())

    Trainer(
        cfg,
        net,
        optimizer,
        sched,
        criterion,
        evaluator,
        train_loader,
        val_loader,
    ).fit()
if __name__ == "__main__":
    # Script entry point: parse ``--config`` and hand the result to main().
    #
    # The platform default is looked up with .get() so that a platform
    # without an entry in DEFAULT_CONFIG_PATH (e.g. "darwin") no longer raises
    # KeyError while the parser is being built. On such platforms there is no
    # default, so --config becomes a required argument instead.
    default_config = DEFAULT_CONFIG_PATH.get(sys.platform)

    argument_parser = argparse.ArgumentParser()
    # NOTE: argparse applies ``type`` only to string values, so the default
    # stays a pathlib.Path while a user-supplied value is a str.
    argument_parser.add_argument(
        "--config",
        default=default_config,
        required=default_config is None,
        type=str,
        help="Path to config file",
    )
    parsed_args = argument_parser.parse_args()
    main(parsed_args)