-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmodel_training_LOO_subject.py
More file actions
98 lines (72 loc) · 3.74 KB
/
Copy pathmodel_training_LOO_subject.py
File metadata and controls
98 lines (72 loc) · 3.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Copyright (C) 2024 ETH Zurich. All rights reserved.
# Author: Carlos Santos, ETH Zurich
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0.
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# Imports
import os
import numpy as np
import scipy.io as sio
import torch
# Custom imports
from train import train_encoder, train_classifier
from utils import read_model_lines, get_LOO_sets
from parser_file import parse_train, log_arguments
#############################################################
# Script to train models with leave-one-subject-out (LOSO) cross-validation.
# For each model and each held-out subject, train an encoder-decoder and then an encoder-classifier.
#############################################################
def main():
    """Train every listed model with leave-one-subject-out (LOSO) folds.

    The model names and their feature sets are read from ``Trainings.txt``
    inside ``args.all_model_dir``. For each model, and for each entry of
    ``args.data_loading_dir`` taken as the held-out subject, an encoder is
    trained first and a classifier is then trained from the returned encoder
    weights. The last loss reported by each training call is stored per
    subject and written to ``loss_<model>.mat`` in the model's directory.

    Raises:
        ValueError: If ``Trainings.txt`` does not exist in
            ``args.all_model_dir``.
    """
    # Load argparse arguments
    args = parse_train()
    log_arguments(args)

    # Connect to GPU if available
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    print(device)

    # Seed
    torch.manual_seed(314)
    torch.cuda.manual_seed(314)

    # Read training models from saving directory
    trainings_file = os.path.join(args.all_model_dir, 'Trainings.txt')
    if not os.path.exists(trainings_file):
        raise ValueError("Trainings file not found.")

    # Divide lines into training names and features
    trainings, features = read_model_lines(trainings_file)

    # One LOSO fold per entry found in the data directory
    subject_ids = os.listdir(args.data_loading_dir)

    for i, training in enumerate(trainings):  # for each model
        # Indexed lookup (rather than zip) so a length mismatch between
        # trainings and features still fails loudly.
        feature_names = features[i]

        # Create model directory - for current model save
        model_dir = os.path.join(args.all_model_dir, training)
        os.makedirs(model_dir, exist_ok=True)

        # Per-model loss vectors, one entry per held-out subject
        enc_mat = np.zeros(len(subject_ids))
        cla_mat = np.zeros(len(subject_ids))

        for j, subject_id in enumerate(subject_ids):  # for each subject, LOSO
            print(f'{training}. LOO Subject: {subject_id}')

            # Directory - for current model and fold save
            subject_fold = os.path.join(model_dir, subject_id)
            os.makedirs(subject_fold, exist_ok=True)

            # Training log file; the context manager guarantees the handle is
            # closed (and flushed) even if a training call raises.
            log_file_name = os.path.join(subject_fold, 'log_file.txt')
            with open(log_file_name, 'w') as log_file:
                # Training and validation sets for the held-out subject
                train_set, val_set = get_LOO_sets(args.data_loading_dir, subject_id)

                # Train encoder
                encoder_weights, last_loss_encoder = train_encoder(
                    device, subject_fold, feature_names, train_set, val_set,
                    log_file, args, weight_init=False,
                )
                enc_mat[j] = last_loss_encoder

                # Train classifier starting from the encoder weights above
                last_loss_classifier = train_classifier(
                    device, subject_fold, feature_names, train_set, val_set,
                    log_file, encoder_weights, args, weight_init=True,
                )
                cla_mat[j] = last_loss_classifier

        # Save encoder and classifier loss matrices
        sio.savemat(
            os.path.join(model_dir, 'loss_' + training + '.mat'),
            {'enc_mat': enc_mat, 'cla_mat': cla_mat},
        )
# Run the training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()