diff --git a/Libs/Groom/Groom.cpp b/Libs/Groom/Groom.cpp index e627522f8e..36787b0661 100644 --- a/Libs/Groom/Groom.cpp +++ b/Libs/Groom/Groom.cpp @@ -518,6 +518,11 @@ bool Groom::run_alignment() { size_t num_domains = project_->get_number_of_domains_per_subject(); auto subjects = project_->get_subjects(); + if (subjects.empty()) { + SW_ERROR("No subjects to groom"); + return false; + } + auto base_params = GroomParameters(project_); bool global_icp = false; @@ -564,6 +569,10 @@ bool Groom::run_alignment() { if (global_icp) { Mesh reference_mesh = vtkSmartPointer::New(); + if (reference_meshes.empty()) { + SW_ERROR("No reference meshes available"); + return false; + } if (reference_index < 0 || reference_index >= reference_meshes.size()) { reference_index = MeshUtils::findReferenceMesh(reference_meshes, subset_size); if (reference_index < 0 || reference_index >= reference_meshes.size()) { @@ -632,6 +641,11 @@ bool Groom::run_alignment() { reference_index = params.get_alignment_reference(); subset_size = params.get_alignment_subset_size(); + if (reference_meshes.empty()) { + SW_ERROR("No reference meshes available"); + return false; + } + Mesh reference_mesh = vtkSmartPointer::New(); if (reference_index < 0 || reference_index >= subjects.size()) { reference_index = MeshUtils::findReferenceMesh(reference_meshes, subset_size); diff --git a/Studio/DeepSSM/DeepSSMJob.cpp b/Studio/DeepSSM/DeepSSMJob.cpp index 8a7308e169..78989469df 100644 --- a/Studio/DeepSSM/DeepSSMJob.cpp +++ b/Studio/DeepSSM/DeepSSMJob.cpp @@ -107,6 +107,15 @@ void DeepSSMJob::run_prep() { py::object create_split = py_deep_ssm_utils.attr("create_split"); create_split(project_, train_split, val_split, test_split); + int num_train = DeepSSMTool::get_split(project_, DeepSSMTool::SplitType::TRAIN).size(); + int num_val = DeepSSMTool::get_split(project_, DeepSSMTool::SplitType::VAL).size(); + int num_test = DeepSSMTool::get_split(project_, DeepSSMTool::SplitType::TEST).size(); + if (num_train == 0 || num_val == 0) { + SW_ERROR("DeepSSM: Not enough subjects in training and validation. Please check split."); + abort(); + //return; + } + if (is_aborted()) { return; } diff --git a/Studio/DeepSSM/DeepSSMTool.cpp b/Studio/DeepSSM/DeepSSMTool.cpp index a78a7131ac..b474318bc2 100644 --- a/Studio/DeepSSM/DeepSSMTool.cpp +++ b/Studio/DeepSSM/DeepSSMTool.cpp @@ -94,6 +94,7 @@ DeepSSMTool::DeepSSMTool(Preferences& prefs) : preferences_(prefs) { QIntValidator* zero_to_hundred = new QIntValidator(0, 100, this); ui_->validation_split->setValidator(zero_to_hundred); ui_->testing_split->setValidator(zero_to_hundred); + ui_->training_split->setValidator(zero_to_hundred); QDoubleValidator* double_validator = new QDoubleValidator(0, 100, 4, this); double_validator->setNotation(QDoubleValidator::StandardNotation); @@ -102,8 +103,8 @@ DeepSSMTool::DeepSSMTool(Preferences& prefs) : preferences_(prefs) { ui_->spacing_y->setValidator(double_validator); ui_->spacing_z->setValidator(double_validator); + connect(ui_->training_split, &QLineEdit::editingFinished, this, &DeepSSMTool::update_split); connect(ui_->validation_split, &QLineEdit::editingFinished, this, &DeepSSMTool::update_split); - connect(ui_->testing_split, &QLineEdit::editingFinished, this, &DeepSSMTool::update_split); ui_->tl_net_options->setVisible(false); @@ -198,6 +199,7 @@ void DeepSSMTool::load_params() { void DeepSSMTool::store_params() { auto params = DeepSSMParameters(session_->get_project()); + params.set_training_split(ui_->training_split->text().toDouble()); params.set_validation_split(ui_->validation_split->text().toDouble()); params.set_testing_split(ui_->testing_split->text().toDouble()); @@ -382,22 +384,24 @@ void DeepSSMTool::update_panels() { //--------------------------------------------------------------------------- void DeepSSMTool::update_split() { + double training = ui_->training_split->text().toDouble(); double testing = ui_->testing_split->text().toDouble(); double validation = ui_->validation_split->text().toDouble(); + training = std::max(std::min(training, 100), 0); testing = std::max(std::min(testing, 100), 0); validation = std::max(std::min(validation, 100), 0); - if (testing + validation > 100) { - if (testing > validation) { - validation = 100 - testing; + if (training + validation > 100) { + if (training > validation) { + validation = 100 - training; } else { - testing = 100 - validation; + training = 100 - validation; } } - ui_->testing_split->setText(QString::number(testing)); + ui_->training_split->setText(QString::number(training)); ui_->validation_split->setText(QString::number(validation)); - ui_->training_split->setText(QString::number(100 - testing - validation)); + ui_->testing_split->setText(QString::number(100 - training - validation)); } //--------------------------------------------------------------------------- @@ -922,6 +926,10 @@ void DeepSSMTool::run_tool(DeepSSMTool::ToolMode type) { } else if (type == DeepSSMTool::ToolMode::DeepSSM_PrepType) { ui_->tab_widget->setCurrentIndex(0); + // check that there are at least 1 subject in test/val/train each + + + SW_LOG("Please Wait: Running Groom/Optimize..."); } else { SW_ERROR("Unknown tool mode"); diff --git a/Studio/DeepSSM/DeepSSMTool.ui b/Studio/DeepSSM/DeepSSMTool.ui index b3d464c02d..0dc84f12c0 100644 --- a/Studio/DeepSSM/DeepSSMTool.ui +++ b/Studio/DeepSSM/DeepSSMTool.ui @@ -579,8 +579,11 @@ Preserved Split - + + + false + 0 @@ -590,10 +593,13 @@ Preserved Qt::AlignCenter + + true + - - + + 0 @@ -601,11 +607,30 @@ Preserved - Validation + Training - + + + + true + + + + 0 + 0 + + + + Qt::AlignCenter + + + false + + + + @@ -618,35 +643,22 @@ Preserved - + % - - - - - 0 - 0 - - - - Qt::AlignCenter - - - - - + + % - - + + 0 @@ -654,15 +666,12 @@ Preserved - Training + Validation - - - - false - + + 0 @@ -672,13 +681,10 @@ Preserved Qt::AlignCenter - - true - - - + + %