-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path03_train.py
More file actions
executable file
·103 lines (91 loc) · 3.69 KB
/
03_train.py
File metadata and controls
executable file
·103 lines (91 loc) · 3.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Base training script.

Trains a 2D/3D (C)FNO model on a dataset described by a YAML
configuration file, checkpointing the model every few epochs.
"""
import os
import sys
# Make the current working directory importable so the local `cfno`
# package is found when the script is run from the repository root.
# NOTE(review): index 2 is unusual -- sys.path.insert(0, ...) or
# insert(1, ...) is the common idiom; confirm the position is intentional.
sys.path.insert(2, os.getcwd())
import argparse
import torch
from cfno.training.pySDC import FourierNeuralOp
from cfno.utils import readConfig
# -----------------------------------------------------------------------------
# Script parameters
# -----------------------------------------------------------------------------
# Command-line interface; defaults are shown in --help via the
# ArgumentDefaultsHelpFormatter, and may later be overridden by the
# "train" section of the YAML configuration file.
parser = argparse.ArgumentParser(
    description='Train a 2D FNO model on a given dataset',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
addArg = parser.add_argument
addArg("--trainDir", default="trainDir", help="directory to store training results")
addArg("--nEpochs", default=200, type=int, help="number of epochs to train on")
addArg("--checkpoint", default="model.pt", help="name of the file storing the model")
addArg("--saveEvery", default=100, type=int, help="save checkpoint every [...] epochs")
addArg("--ndim", default=2, type=int, help="FNO2D or 3D")
addArg("--data_aug", action="store_true", help='Add noisy data per batch while training')
addArg("--model_class", default="CFNO", help="CFNO or FNO")
addArg("--savePermanent", action="store_true", help="save permanent checkpoint into [...]_epochs[...].pt files")
# store_false => args.noTensorboard defaults to True (tensorboard on) and
# becomes False when the flag is passed on the command line.
addArg("--noTensorboard", action="store_false", help="do not use tensorboard for losses output (only native)")
addArg("--lossesFile", default=FourierNeuralOp.LOSSES_FILE, help='base text file to write the loss')
addArg("--physicsLossesFile", default=None, help="text file to write individual loss contributions for the physics loss")
addArg("--logPrint", action="store_true", help='print loss after each optimizer step')
addArg("--config", default="config.yaml", help="configuration file")
args = parser.parse_args()
# Read the YAML configuration; its optional "train" section overrides the
# command-line argument values.
config = readConfig(args.config)
if "train" in config:
    args.__dict__.update(**config.train)
# These sections are mandatory for building the trainer.
sections = ["data", "model", "optim", "lr_scheduler", "parallel_strategy"]
for name in sections:
    assert name in config, f"config file needs a {name} section"
# trainer class configs, "loss" parameter uses default if not specified
configs = {name: config.get(name) for name in (sections + ["loss"])}
nEpochs = args.nEpochs
saveEvery = args.saveEvery
savePermanent = args.savePermanent
checkpoint = args.checkpoint
# --noTensorboard is declared with action="store_false", so
# args.noTensorboard is True by default (tensorboard enabled) and False
# when the flag is given.  The previous `if not args.noTensorboard:
# USE_TENSORBOARD = True` inverted that -- passing --noTensorboard
# actually *enabled* tensorboard.  Assign the value directly instead.
FourierNeuralOp.USE_TENSORBOARD = args.noTensorboard
# -----------------------------------------------------------------------------
# Script execution
# -----------------------------------------------------------------------------
FourierNeuralOp.TRAIN_DIR = args.trainDir
FourierNeuralOp.LOSSES_FILE = args.lossesFile
FourierNeuralOp.LOG_PRINT = args.logPrint
FourierNeuralOp.PHYSICS_LOSSES_FILE = args.physicsLossesFile
model = FourierNeuralOp(**configs,
                        checkpoint=checkpoint,
                        model_class=args.model_class,
                        ndim=args.ndim,
                        data_aug=args.data_aug)


def _saveCheckpoint():
    """Save the current model state after a training chunk.

    With --savePermanent each save goes to a distinct, epoch-tagged file;
    otherwise the single file named by --checkpoint is overwritten.  (The
    original code hard-coded "model.pt" here, silently ignoring a
    non-default --checkpoint value.)
    """
    if savePermanent:
        model.save(f"model_epochs{model.epochs}.pt")
    else:
        model.save(checkpoint)


# Train in chunks of `saveEvery` epochs, checkpointing after each chunk,
# then run any remainder so that exactly nEpochs epochs are trained.
# Guard against --saveEvery 0, which would raise ZeroDivisionError below.
saveEvery = max(1, min(nEpochs, saveEvery))
nChunks = nEpochs // saveEvery
lastChunk = nEpochs % saveEvery
for _ in range(nChunks):
    model.learn(saveEvery)
    _saveCheckpoint()
if lastChunk > 0:
    model.learn(lastChunk)
    _saveCheckpoint()

# Tear down the distributed process group if one was set up (e.g. for a
# parallel training strategy), so the script exits cleanly.
if torch.distributed.is_initialized():
    torch.distributed.destroy_process_group()