-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
72 lines (61 loc) · 2.71 KB
/
main.py
File metadata and controls
72 lines (61 loc) · 2.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import pandas as pd
from glob import glob
from torch.utils.data import DataLoader
from torch import optim
from torchvision import transforms
from utils import set_seed, initialize_device
from dataset import BrainMRIDataset
from model import UNet
from training import train_node_model
from data_utils import split_data_into_nodes
def _image_path_for_mask(mask_path):
    """Return the image path that pairs with ``mask_path``.

    Only the last ``"_mask"`` occurrence (the one in the file name) is removed,
    so a directory whose name happens to contain ``"_mask"`` is left intact.
    """
    head, sep, tail = mask_path.rpartition("_mask")
    return head + tail if sep else mask_path


def _load_node_dataframes(mask_pattern, num_nodes):
    """Find mask files, pair each with its image, and split the pairs across nodes.

    Args:
        mask_pattern: glob pattern matching the mask files.
        num_nodes: number of nodes the data is split into.

    Returns:
        A list of ``num_nodes`` DataFrames, each with ``image_filename`` and
        ``mask_images`` columns.

    Raises:
        FileNotFoundError: if ``mask_pattern`` matches no files.
        ValueError: if the split does not produce exactly ``num_nodes`` chunks.
    """
    # glob() returns files in filesystem-dependent order; sorting makes the
    # input handed to split_data_into_nodes identical across machines/runs.
    mask_images = sorted(glob(mask_pattern))
    if not mask_images:
        raise FileNotFoundError(f"No mask files match {mask_pattern!r}")
    image_filenames = [_image_path_for_mask(path) for path in mask_images]

    chunks_images, chunks_masks = split_data_into_nodes(image_filenames, mask_images, num_nodes)
    dfs = [
        pd.DataFrame({'image_filename': images, 'mask_images': masks})
        for images, masks in zip(chunks_images, chunks_masks)
    ]
    # Each node below gets exactly one DataFrame; fail here with a clear
    # message rather than with an IndexError while building datasets.
    if len(dfs) != num_nodes:
        raise ValueError(f"Expected {num_nodes} data chunks, got {len(dfs)}")
    return dfs


def _place_models_on_device(models, device):
    """Move every model to ``device``; wrap in DataParallel when >1 GPU is visible."""
    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        print(f"Using {gpu_count} GPUs for training.")
        return [torch.nn.DataParallel(model).to(device) for model in models]
    if gpu_count == 1:
        print("Only one GPU available. Using single GPU.")
    else:
        # Report the no-GPU fallback explicitly instead of claiming a GPU.
        print(f"No GPU available. Using {device}.")
    return [model.to(device) for model in models]


def main():
    """Train one UNet per data node, round-robin, for a fixed number of iterations.

    Each iteration runs ``epochs_per_iteration`` epochs; in every epoch each
    node's model is trained once on its own data loader via ``train_node_model``.
    """
    set_seed()
    device = initialize_device()

    # Parameters
    num_nodes = 15
    batch_size = 16
    learning_rate = 1e-3
    total_iterations = 150
    epochs_per_iteration = 5
    image_h, image_w = 256, 256

    # Load the data and split it into one DataFrame per node.
    dfs = _load_node_dataframes(r'./Dataset/kaggle_3m/*/*_mask*', num_nodes)

    # One independent model per node.
    models = _place_models_on_device([UNet() for _ in range(num_nodes)], device)

    train_transforms = transforms.Compose([
        transforms.RandomRotation(20),
        transforms.RandomResizedCrop((image_h, image_w), scale=(0.95, 1.05)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    # NOTE(review): the same *random* pipeline is used for images and masks.
    # Unless BrainMRIDataset re-seeds between the two calls, each one draws its
    # own rotation/crop/flip and the mask no longer lines up with its image —
    # confirm in dataset.py.
    datasets = [
        BrainMRIDataset(df, image_transform=train_transforms, mask_transform=train_transforms)
        for df in dfs
    ]
    loaders = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in datasets]
    optimizers = [optim.Adam(model.parameters(), lr=learning_rate) for model in models]
    schedulers = [optim.lr_scheduler.StepLR(opt, step_size=30, gamma=0.1) for opt in optimizers]

    nodes = list(zip(models, loaders, optimizers, schedulers))
    for iteration in range(total_iterations):
        print(f"\n===== Iteration {iteration + 1}/{total_iterations} =====")
        for epoch in range(epochs_per_iteration):
            print(f"-- Epoch {epoch + 1}/{epochs_per_iteration} --")
            for node_id, (model, loader, optimizer, scheduler) in enumerate(nodes):
                train_node_model(node_id, model, loader, optimizer, scheduler, device)
    print("Training completed.")


if __name__ == "__main__":
    main()