Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions Libs/Groom/Groom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,11 @@ bool Groom::run_alignment() {
size_t num_domains = project_->get_number_of_domains_per_subject();
auto subjects = project_->get_subjects();

if (subjects.empty()) {
SW_ERROR("No subjects to groom");
return false;
}

auto base_params = GroomParameters(project_);

bool global_icp = false;
Expand Down Expand Up @@ -564,6 +569,10 @@ bool Groom::run_alignment() {

if (global_icp) {
Mesh reference_mesh = vtkSmartPointer<vtkPolyData>::New();
if (reference_meshes.empty()) {
SW_ERROR("No reference meshes available");
return false;
}
if (reference_index < 0 || reference_index >= reference_meshes.size()) {
reference_index = MeshUtils::findReferenceMesh(reference_meshes, subset_size);
if (reference_index < 0 || reference_index >= reference_meshes.size()) {
Expand Down Expand Up @@ -632,6 +641,11 @@ bool Groom::run_alignment() {
reference_index = params.get_alignment_reference();
subset_size = params.get_alignment_subset_size();

if (reference_meshes.empty()) {
SW_ERROR("No reference meshes available");
return false;
}

Mesh reference_mesh = vtkSmartPointer<vtkPolyData>::New();
if (reference_index < 0 || reference_index >= subjects.size()) {
reference_index = MeshUtils::findReferenceMesh(reference_meshes, subset_size);
Expand Down
9 changes: 9 additions & 0 deletions Studio/DeepSSM/DeepSSMJob.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,15 @@ void DeepSSMJob::run_prep() {
py::object create_split = py_deep_ssm_utils.attr("create_split");
create_split(project_, train_split, val_split, test_split);

int num_train = DeepSSMTool::get_split(project_, DeepSSMTool::SplitType::TRAIN).size();
int num_val = DeepSSMTool::get_split(project_, DeepSSMTool::SplitType::VAL).size();
int num_test = DeepSSMTool::get_split(project_, DeepSSMTool::SplitType::TEST).size();
if (num_train == 0 || num_val == 0) {
SW_ERROR("DeepSSM: Not enough subjects in training and validation. Please check split.");
abort();
//return;
}

if (is_aborted()) {
return;
}
Expand Down
22 changes: 15 additions & 7 deletions Studio/DeepSSM/DeepSSMTool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ DeepSSMTool::DeepSSMTool(Preferences& prefs) : preferences_(prefs) {
QIntValidator* zero_to_hundred = new QIntValidator(0, 100, this);
ui_->validation_split->setValidator(zero_to_hundred);
ui_->testing_split->setValidator(zero_to_hundred);
ui_->training_split->setValidator(zero_to_hundred);

QDoubleValidator* double_validator = new QDoubleValidator(0, 100, 4, this);
double_validator->setNotation(QDoubleValidator::StandardNotation);
Expand All @@ -102,8 +103,8 @@ DeepSSMTool::DeepSSMTool(Preferences& prefs) : preferences_(prefs) {
ui_->spacing_y->setValidator(double_validator);
ui_->spacing_z->setValidator(double_validator);

connect(ui_->training_split, &QLineEdit::editingFinished, this, &DeepSSMTool::update_split);
connect(ui_->validation_split, &QLineEdit::editingFinished, this, &DeepSSMTool::update_split);
connect(ui_->testing_split, &QLineEdit::editingFinished, this, &DeepSSMTool::update_split);

ui_->tl_net_options->setVisible(false);

Expand Down Expand Up @@ -198,6 +199,7 @@ void DeepSSMTool::load_params() {
void DeepSSMTool::store_params() {
auto params = DeepSSMParameters(session_->get_project());

params.set_training_split(ui_->training_split->text().toDouble());
params.set_validation_split(ui_->validation_split->text().toDouble());
params.set_testing_split(ui_->testing_split->text().toDouble());

Expand Down Expand Up @@ -382,22 +384,24 @@ void DeepSSMTool::update_panels() {

//---------------------------------------------------------------------------
void DeepSSMTool::update_split() {
double training = ui_->training_split->text().toDouble();
double testing = ui_->testing_split->text().toDouble();
double validation = ui_->validation_split->text().toDouble();
training = std::max<double>(std::min<double>(training, 100), 0);
testing = std::max<double>(std::min<double>(testing, 100), 0);
validation = std::max<double>(std::min<double>(validation, 100), 0);

if (testing + validation > 100) {
if (testing > validation) {
validation = 100 - testing;
if (training + validation > 100) {
if (training > validation) {
validation = 100 - training;
} else {
testing = 100 - validation;
training = 100 - validation;
}
}

ui_->testing_split->setText(QString::number(testing));
ui_->training_split->setText(QString::number(training));
ui_->validation_split->setText(QString::number(validation));
ui_->training_split->setText(QString::number(100 - testing - validation));
ui_->testing_split->setText(QString::number(100 - training - validation));
}

//---------------------------------------------------------------------------
Expand Down Expand Up @@ -922,6 +926,10 @@ void DeepSSMTool::run_tool(DeepSSMTool::ToolMode type) {
} else if (type == DeepSSMTool::ToolMode::DeepSSM_PrepType) {
ui_->tab_widget->setCurrentIndex(0);

// check that there are at least 1 subject in test/val/train each



SW_LOG("Please Wait: Running Groom/Optimize...");
} else {
SW_ERROR("Unknown tool mode");
Expand Down
74 changes: 40 additions & 34 deletions Studio/DeepSSM/DeepSSMTool.ui
Original file line number Diff line number Diff line change
Expand Up @@ -579,8 +579,11 @@ Preserved</string>
<string>Split</string>
</property>
<layout class="QGridLayout" name="gridLayout_14">
<item row="0" column="1">
<item row="2" column="1">
<widget class="QLineEdit" name="testing_split">
<property name="enabled">
<bool>false</bool>
</property>
<property name="sizePolicy">
<sizepolicy hsizetype="Expanding" vsizetype="Fixed">
<horstretch>0</horstretch>
Expand All @@ -590,22 +593,44 @@ Preserved</string>
<property name="alignment">
<set>Qt::AlignCenter</set>
</property>
<property name="readOnly">
<bool>true</bool>
</property>
</widget>
</item>
<item row="1" column="0">
<widget class="QLabel" name="label_12">
<item row="0" column="0">
<widget class="QLabel" name="label_33">
<property name="sizePolicy">
<sizepolicy hsizetype="Expanding" vsizetype="Preferred">
<horstretch>0</horstretch>
<verstretch>0</verstretch>
</sizepolicy>
</property>
<property name="text">
<string>Validation</string>
<string>Training</string>
</property>
</widget>
</item>
<item row="0" column="0">
<item row="0" column="1">
<widget class="QLineEdit" name="training_split">
<property name="enabled">
<bool>true</bool>
</property>
<property name="sizePolicy">
<sizepolicy hsizetype="Expanding" vsizetype="Fixed">
<horstretch>0</horstretch>
<verstretch>0</verstretch>
</sizepolicy>
</property>
<property name="alignment">
<set>Qt::AlignCenter</set>
</property>
<property name="readOnly">
<bool>false</bool>
</property>
</widget>
</item>
<item row="2" column="0">
<widget class="QLabel" name="label_11">
<property name="sizePolicy">
<sizepolicy hsizetype="Expanding" vsizetype="Preferred">
Expand All @@ -618,51 +643,35 @@ Preserved</string>
</property>
</widget>
</item>
<item row="0" column="2">
<item row="2" column="2">
<widget class="QLabel" name="label_14">
<property name="text">
<string>%</string>
</property>
</widget>
</item>
<item row="1" column="1">
<widget class="QLineEdit" name="validation_split">
<property name="sizePolicy">
<sizepolicy hsizetype="Expanding" vsizetype="Fixed">
<horstretch>0</horstretch>
<verstretch>0</verstretch>
</sizepolicy>
</property>
<property name="alignment">
<set>Qt::AlignCenter</set>
</property>
</widget>
</item>
<item row="1" column="2">
<widget class="QLabel" name="label_15">
<item row="0" column="2">
<widget class="QLabel" name="label_34">
<property name="text">
<string>%</string>
</property>
</widget>
</item>
<item row="2" column="0">
<widget class="QLabel" name="label_33">
<item row="1" column="0">
<widget class="QLabel" name="label_12">
<property name="sizePolicy">
<sizepolicy hsizetype="Expanding" vsizetype="Preferred">
<horstretch>0</horstretch>
<verstretch>0</verstretch>
</sizepolicy>
</property>
<property name="text">
<string>Training</string>
<string>Validation</string>
</property>
</widget>
</item>
<item row="2" column="1">
<widget class="QLineEdit" name="training_split">
<property name="enabled">
<bool>false</bool>
</property>
<item row="1" column="1">
<widget class="QLineEdit" name="validation_split">
<property name="sizePolicy">
<sizepolicy hsizetype="Expanding" vsizetype="Fixed">
<horstretch>0</horstretch>
Expand All @@ -672,13 +681,10 @@ Preserved</string>
<property name="alignment">
<set>Qt::AlignCenter</set>
</property>
<property name="readOnly">
<bool>true</bool>
</property>
</widget>
</item>
<item row="2" column="2">
<widget class="QLabel" name="label_34">
<item row="1" column="2">
<widget class="QLabel" name="label_15">
<property name="text">
<string>%</string>
</property>
Expand Down
Loading