-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnn_model.py.com
More file actions
148 lines (124 loc) · 5.08 KB
/
nn_model.py.com
File metadata and controls
148 lines (124 loc) · 5.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Licensed under the MIT License.
# Copyright (c) Microsoft Corporation.
# All the DDP related code changes are marked by triple pound signs ###.
import abc
import torch
import random
import numpy as np
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data.dataloader import DataLoader
from batteryml.data import DataBundle
from .base import BaseModel
###DDP begin add
from torch.nn.parallel import DistributedDataParallel as DDP
import os
import torch.distributed as dist
# Process-group coordinates, normally exported by the distributed launcher
# (e.g. torchrun). Fall back to single-process values when a variable is
# absent so that merely importing this module does not raise KeyError.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
###DDP end add
def seed_worker(worker_id):
    """Seed NumPy's and Python's global RNGs from torch's initial seed.

    Intended to be passed as ``worker_init_fn`` to a ``DataLoader``.

    Args:
        worker_id: Worker index supplied by the caller; not used here.
    """
    # Reduce torch's seed to 32 bits before handing it to the other RNGs.
    derived_seed = torch.initial_seed() % (1 << 32)
    for seed_fn in (np.random.seed, random.seed):
        seed_fn(derived_seed)
class NNModel(BaseModel, nn.Module, abc.ABC):
    """Base class for neural-network models with a DDP-aware training loop.

    ``fit`` calls ``self.forward(**batch, return_loss=True)`` and ``predict``
    calls ``self.forward(**batch)``, so subclasses are expected to provide a
    ``forward`` that accepts the batch keys as keyword arguments.
    """

    def __init__(self,
                 batch_size: int = 32,
                 epochs: int = 10000,
                 workspace: str = None,
                 evaluate_freq: int = 500,
                 checkpoint_freq: int = 1000,
                 train_batch_size: int = None,
                 test_batch_size: int = None,
                 lr: float = 1e-3):
        """Store the training hyper-parameters.

        Args:
            batch_size: Fallback batch size used when ``train_batch_size``
                or ``test_batch_size`` is not given (or is falsy).
            epochs: Number of training epochs.
            workspace: Forwarded to ``BaseModel.__init__``.
            evaluate_freq: Evaluate every ``evaluate_freq`` epochs.
            checkpoint_freq: Checkpoint every ``checkpoint_freq`` epochs.
                ``None`` or the string ``'None'`` disables checkpointing;
                otherwise the value is capped at ``epochs``.
            train_batch_size: Batch size for the training loader.
            test_batch_size: Batch size for the prediction loader.
            lr: Learning rate passed to the Adam optimizer.
        """
        nn.Module.__init__(self)
        BaseModel.__init__(self, workspace)
        self.train_epochs = epochs
        self.evaluate_freq = evaluate_freq
        # The string 'None' is accepted as well as the None object.
        if checkpoint_freq is None or checkpoint_freq == 'None':
            self.checkpoint_freq = None
        else:
            self.checkpoint_freq = min(checkpoint_freq, self.train_epochs)
        self.train_batch_size = train_batch_size or batch_size
        self.test_batch_size = test_batch_size or batch_size
        self.lr = lr

    def fit(self,
            dataset: DataBundle,
            timestamp: str = None,
            seed: int = 0):
        """Train the model on ``dataset.train_data``.

        The model is moved to ``cuda:<local_rank>`` and wrapped in
        ``DistributedDataParallel``. After the last epoch the trainable
        parameters are averaged across all ranks with an all-reduce.

        Args:
            dataset: Bundle providing ``train_data`` and ``evaluate``.
            timestamp: Prefix for checkpoint file names; defaults to
                ``'UnknownTime'``.
            seed: Only used to build the checkpoint file name.
        """
        self.train()
        train_data = dataset.train_data
        loader = DataLoader(
            train_data, self.train_batch_size,
            shuffle=True, worker_init_fn=seed_worker)
        # TODO: support customization of optimizers

        ###DDP begin add
        # The CUDA device index is per-node, so it must be the local rank.
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
        model = self.to(device)
        # device_ids/output_device must name the device the module lives on,
        # i.e. the local rank. The global rank differs from it on multi-node
        # jobs and would point at the wrong (or a non-existent) GPU.
        ddp_model = DDP(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
        )
        ###DDP end add

        optimizer = optim.Adam(ddp_model.parameters(), lr=self.lr)
        timestamp = timestamp or 'UnknownTime'
        latest = None
        for epoch in tqdm(range(self.train_epochs), desc='Traning'):
            self.train()
            for batch in loader:
                # NOTE(review): the forward pass goes through ``self`` rather
                # than ``ddp_model``; gradient synchronisation is presumably
                # not triggered this way, which would explain the explicit
                # parameter averaging after the loop -- confirm.
                loss = self.forward(**batch, return_loss=True)
                if loss == torch.inf:
                    # Diverged: re-initialise the weights and start over with
                    # a fresh optimizer state.
                    reset_parameters(self)
                    optimizer = optim.Adam(ddp_model.parameters(), lr=self.lr)
                else:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            # NOTE(review): every rank runs the checkpoint dump and the
            # evaluation below -- confirm whether rank 0 only is intended.
            if self.checkpoint_freq is not None and \
                    (epoch + 1) % self.checkpoint_freq == 0:
                filename = f'{timestamp}_seed_{seed}_epoch_{epoch+1}.ckpt'
                if self.workspace is not None:
                    self.dump_checkpoint(self.workspace / filename)
                    latest = self.workspace / filename

            if (epoch + 1) % self.evaluate_freq == 0:
                pred = self.predict(dataset)
                score = dataset.evaluate(pred, 'RMSE')
                print(f'[{epoch+1}/{self.train_epochs}] RMSE {score:.2f}', flush=True)

        if self.workspace is not None:
            # NOTE(review): ``latest`` is still None when no checkpoint was
            # written -- confirm link_latest_checkpoint accepts that.
            self.link_latest_checkpoint(latest)

        ###DDP begin add
        # Wait for every rank to finish training, then replace each trainable
        # parameter in place with its mean over all ranks.
        # NOTE(review): this runs after the checkpoints above were written,
        # so the averaged weights are not persisted by this method.
        dist.barrier()
        num_ranks = dist.get_world_size()
        for param in ddp_model.parameters():
            if param.requires_grad:
                dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
                param.data /= num_ranks
        ###DDP end add

    @torch.no_grad()
    def predict(self, dataset: DataBundle, data_type: str='test') -> torch.Tensor:
        """Run the model over one split of ``dataset``.

        Args:
            dataset: Bundle providing ``test_data`` and ``train_data``.
            data_type: ``'test'`` selects ``dataset.test_data``; any other
                value selects ``dataset.train_data``.

        Returns:
            The per-batch outputs of ``forward`` concatenated with
            ``torch.cat``.
        """
        self.eval()
        if data_type == 'test':
            test_data = dataset.test_data
        else:
            test_data = dataset.train_data
        loader = DataLoader(
            test_data, self.test_batch_size,
            shuffle=False, worker_init_fn=seed_worker)
        predictions = torch.cat([self.forward(**batch) for batch in loader])
        return predictions

    def to(self, device: str):
        """Move the module to ``device`` via ``nn.Module.to``."""
        return nn.Module.to(self, device)

    def dump_checkpoint(self, path: str):
        """Save this module's ``state_dict`` to ``path``."""
        torch.save(self.state_dict(), path)

    def load_checkpoint(self, path: str):
        """Load a ``state_dict`` saved by ``dump_checkpoint`` from ``path``."""
        # Load onto CPU so the tensors are not restored onto the device they
        # were saved from; load_state_dict copies them into the existing
        # parameters on whatever device those already occupy.
        self.load_state_dict(torch.load(path, map_location='cpu'))
def reset_parameters(model):
    """Re-initialise every submodule of ``model`` that can reset itself.

    ``model.apply`` visits each submodule; those with a callable
    ``reset_parameters`` attribute have it invoked under ``torch.no_grad``.
    Modules without it are left untouched.
    """

    @torch.no_grad()
    def _reinitialise(module):
        reset_fn = getattr(module, "reset_parameters", None)
        if not callable(reset_fn):
            return
        reset_fn()

    model.apply(_reinitialise)