diff --git a/Applications/shapeworks/CMakeLists.txt b/Applications/shapeworks/CMakeLists.txt index b6cafe7f4c7..eded1b33e89 100644 --- a/Applications/shapeworks/CMakeLists.txt +++ b/Applications/shapeworks/CMakeLists.txt @@ -23,7 +23,7 @@ target_include_directories(shapeworks_exe PUBLIC target_link_libraries(shapeworks_exe Mesh ${VTK_LIBRARIES} Optimize Utils trimesh2 Particles - pybind11::embed Project Image Groom Analyze + pybind11::embed Project Image Groom Analyze Application ) message(STATUS "opt libs ${OPTIMIZE_LIBRARIES}") diff --git a/Applications/shapeworks/Command.h b/Applications/shapeworks/Command.h index 881e448e575..8c6db366ef0 100644 --- a/Applications/shapeworks/Command.h +++ b/Applications/shapeworks/Command.h @@ -103,6 +103,15 @@ class ParticleSystemCommand : public Command private: }; +class DeepSSMCommandGroup : public Command +{ +public: + const std::string type() override { return "DeepSSM"; } + +private: +}; + + class ShapeworksCommand : public Command { public: diff --git a/Applications/shapeworks/Commands.cpp b/Applications/shapeworks/Commands.cpp index b94955282be..f595c5837e5 100644 --- a/Applications/shapeworks/Commands.cpp +++ b/Applications/shapeworks/Commands.cpp @@ -1,18 +1,20 @@ #include "Commands.h" #include +#include +#include #include #include #include #include #include +#include #include #include +#include #include -#include - namespace shapeworks { // boilerplate for a command. Copy this to start a new command @@ -43,8 +45,6 @@ bool Example::execute(const optparse::Values &options, SharedCommandData &shared } #endif - - /////////////////////////////////////////////////////////////////////////////// // Seed /////////////////////////////////////////////////////////////////////////////// @@ -331,4 +331,164 @@ bool ConvertProjectCommand::execute(const optparse::Values& options, SharedComma return false; } } + +/////////////////////////////////////////////////////////////////////////////// +// DeepSSM +/////////////////////////////////////////////////////////////////////////////// +void DeepSSMCommand::buildParser() { + const std::string prog = "deepssm"; + const std::string desc = "run deepssm steps"; + parser.prog(prog).description(desc); + + parser.add_option("--name").action("store").type("string").set_default("").help( + "Path to input project file (xlsx or swproj)."); + + // Create a vector of choices first + std::vector prep_choices = {"all", "groom_training", "optimize_training", "optimize_validation", + "groom_images"}; + + // --prep option with choices + parser.add_option("--prep") + .action("store") + .type("choice") + .choices(prep_choices.begin(), prep_choices.end()) + //.set_default("all") + .help("Preparation step to run"); + + // Boolean flag options + parser.add_option("--augment").action("store_true").help("Run data augmentation"); + + parser.add_option("--train").action("store_true").help("Run training"); + + parser.add_option("--test").action("store_true").help("Run testing"); + + parser.add_option("--all").action("store_true").help("Run all steps"); + + Command::buildParser(); +} + +bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData& sharedData) { + // Create a non-gui QApplication instance + int argc = 3; + char* argv[3]; + argv[0] = const_cast("shapeworks"); + argv[1] = const_cast("-platform"); + argv[2] = const_cast("offscreen"); + + QApplication app(argc, argv); + + // Handle project file: either from --name or first positional argument + std::string project_file; + if (options.is_set_by_user("name")) { + // User explicitly provided --name + project_file = options["name"]; + } else if (!parser.args().empty()) { + // Use first positional argument + project_file = parser.args()[0]; + } else { + // No project file provided at all + parser.error("Project file must be provided either as --name or as a positional argument"); + } + + // Handle prep option with manual default + std::string prep_step; + if (options.is_set_by_user("prep")) { + prep_step = options["prep"]; + } else { + prep_step = "all"; // Manual default + } + + std::cout << "DeepSSM: Using project file: " << project_file << std::endl; + + bool do_prep = options.is_set("prep") || options.is_set("all"); + bool do_augment = options.is_set("augment") || options.is_set("all"); + bool do_train = options.is_set("train") || options.is_set("all"); + bool do_test = options.is_set("test") || options.is_set("all"); + + std::cout << "Prep step: " << (do_prep ? "on" : "off") << "\n"; + std::cout << "Augment step: " << (do_augment ? "on" : "off") << "\n"; + std::cout << "Train step: " << (do_train ? "on" : "off") << "\n"; + std::cout << "Test step: " << (do_test ? "on" : "off") << "\n"; + + if (!do_prep && !do_augment && !do_train && !do_test) { + do_prep = true; + do_augment = true; + do_train = true; + do_test = true; + } + + ProjectHandle project = std::make_shared(); + project->load(project_file); + + PythonWorker python_worker; + python_worker.set_cli_mode(true); + + auto wait_for_job = [&](auto job) { + // This lambda will block until the job is complete + while (!job->is_complete()) { + QCoreApplication::processEvents(); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + if (job->is_aborted()) { + return false; + } + } + return true; + }; + + if (do_prep) { + auto job = QSharedPointer::create(project, DeepSSMJob::JobType::DeepSSM_PrepType); + if (prep_step == "all") { + job->set_prep_step(DeepSSMJob::PrepStep::NOT_STARTED); + } else if (prep_step == "groom_training") { + job->set_prep_step(DeepSSMJob::PrepStep::GROOM_TRAINING); + } else if (prep_step == "optimize_training") { + job->set_prep_step(DeepSSMJob::PrepStep::OPTIMIZE_TRAINING); + } else if (prep_step == "optimize_validation") { + job->set_prep_step(DeepSSMJob::PrepStep::OPTIMIZE_VALIDATION); + } else if (prep_step == "groom_images") { + job->set_prep_step(DeepSSMJob::PrepStep::GROOM_IMAGES); + } else { + SW_ERROR("Unknown prep step: {}", prep_step); + return false; + } + std::cout << "Running DeepSSM preparation step...\n"; + python_worker.run_job(job); + if (!wait_for_job(job)) { + return false; + } + std::cout << "DeepSSM preparation step completed.\n"; + } + if (do_augment) { + std::cout << "Running DeepSSM data augmentation...\n"; + auto job = QSharedPointer::create(project, DeepSSMJob::JobType::DeepSSM_AugmentationType); + python_worker.run_job(job); + if (!wait_for_job(job)) { + return false; + } + std::cout << "DeepSSM data augmentation completed.\n"; + } + if (do_train) { + std::cout << "Running DeepSSM training...\n"; + auto job = QSharedPointer::create(project, DeepSSMJob::JobType::DeepSSM_TrainingType); + python_worker.run_job(job); + if (!wait_for_job(job)) { + return false; + } + std::cout << "DeepSSM training completed.\n"; + } + if (do_test) { + std::cout << "Running DeepSSM testing...\n"; + auto job = QSharedPointer::create(project, DeepSSMJob::JobType::DeepSSM_TestingType); + python_worker.run_job(job); + if (!wait_for_job(job)) { + return false; + } + std::cout << "DeepSSM testing completed.\n"; + } + + project->save(); + + return false; +} + } // namespace shapeworks diff --git a/Applications/shapeworks/Commands.h b/Applications/shapeworks/Commands.h index 631a96643c3..3cbfc6fb2b8 100644 --- a/Applications/shapeworks/Commands.h +++ b/Applications/shapeworks/Commands.h @@ -101,5 +101,6 @@ COMMAND_DECLARE(OptimizeCommand, OptimizeCommandGroup); COMMAND_DECLARE(GroomCommand, GroomCommandGroup); COMMAND_DECLARE(AnalyzeCommand, AnalyzeCommandGroup); COMMAND_DECLARE(ConvertProjectCommand, ProjectCommandGroup); +COMMAND_DECLARE(DeepSSMCommand, DeepSSMCommandGroup); } // shapeworks diff --git a/Applications/shapeworks/shapeworks.cpp b/Applications/shapeworks/shapeworks.cpp index 63cd6ad7513..06866ac3f54 100644 --- a/Applications/shapeworks/shapeworks.cpp +++ b/Applications/shapeworks/shapeworks.cpp @@ -110,6 +110,7 @@ int main(int argc, char *argv[]) shapeworks.addCommand(GroomCommand::getCommand()); shapeworks.addCommand(AnalyzeCommand::getCommand()); shapeworks.addCommand(ConvertProjectCommand::getCommand()); + shapeworks.addCommand(DeepSSMCommand::getCommand()); try { TIME_START("shapeworks"); diff --git a/Libs/Application/CMakeLists.txt b/Libs/Application/CMakeLists.txt new file mode 100644 index 00000000000..bab42215742 --- /dev/null +++ b/Libs/Application/CMakeLists.txt @@ -0,0 +1,50 @@ +SET(APPLICATION_MOC_HDRS + DeepSSM/DeepSSMJob.h + Job/Job.h + Job/PythonWorker.h + ShapeWorksVtkOutputWindow.h +) + +qt5_wrap_cpp( APPLICATION_MOC_SRCS ${APPLICATION_MOC_HDRS} ) + +SET(Application_headers + ) + +add_library(Application STATIC + DeepSSM/DeepSSMJob.cpp + Job/Job.cpp + Job/PythonWorker.cpp + ShapeWorksVtkOutputWindow.cpp + ${APPLICATION_MOC_SRCS} + ) + +target_include_directories(Application PUBLIC + $ + $) + +set(SW_PYTHON_LIBS pybind11::embed) + +if (APPLE) + include_directories(${_Python3_INCLUDE_DIR}) + set(SW_PYTHON_LIBS "") +endif(APPLE) + +target_link_libraries(Application PUBLIC + Groom + Mesh + Utils + Particles + Project + ${SW_PYTHON_LIBS} + ) + +# set +set_target_properties(Application PROPERTIES PUBLIC_HEADER + "${Application_headers}") + +install(TARGETS Application EXPORT ShapeWorksTargets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + PUBLIC_HEADER DESTINATION include/Application + ) diff --git a/Studio/DeepSSM/DeepSSMJob.cpp b/Libs/Application/DeepSSM/DeepSSMJob.cpp similarity index 83% rename from Studio/DeepSSM/DeepSSMJob.cpp rename to Libs/Application/DeepSSM/DeepSSMJob.cpp index 78989469dfe..7de61cd6437 100644 --- a/Studio/DeepSSM/DeepSSMJob.cpp +++ b/Libs/Application/DeepSSM/DeepSSMJob.cpp @@ -2,25 +2,17 @@ #include #include -#include "qdir.h" namespace py = pybind11; using namespace pybind11::literals; // to bring in the `_a` literal -// std -#include -#include -#include -#include - // qt #include #include #include // shapeworks -#include -#include -#include +#include "DeepSSMJob.h" +#include #include #include #include @@ -30,11 +22,8 @@ using namespace pybind11::literals; // to bring in the `_a` literal namespace shapeworks { //--------------------------------------------------------------------------- -DeepSSMJob::DeepSSMJob(QSharedPointer session, DeepSSMTool::ToolMode tool_mode, - DeepSSMTool::PrepStep prep_step) - : session_(session), tool_mode_(tool_mode), prep_step_(prep_step) { - project_ = session_->get_project(); -} +DeepSSMJob::DeepSSMJob(std::shared_ptr project, DeepSSMJob::JobType tool_mode, DeepSSMJob::PrepStep prep_step) + : project_(project), job_type_(tool_mode), prep_step_(prep_step) {} //--------------------------------------------------------------------------- DeepSSMJob::~DeepSSMJob() {} @@ -42,17 +31,17 @@ DeepSSMJob::~DeepSSMJob() {} //--------------------------------------------------------------------------- void DeepSSMJob::run() { try { - switch (tool_mode_) { - case DeepSSMTool::ToolMode::DeepSSM_PrepType: + switch (job_type_) { + case DeepSSMJob::JobType::DeepSSM_PrepType: run_prep(); break; - case DeepSSMTool::ToolMode::DeepSSM_AugmentationType: + case DeepSSMJob::JobType::DeepSSM_AugmentationType: run_augmentation(); break; - case DeepSSMTool::ToolMode::DeepSSM_TrainingType: + case DeepSSMJob::JobType::DeepSSM_TrainingType: run_training(); break; - case DeepSSMTool::ToolMode::DeepSSM_TestingType: + case DeepSSMJob::JobType::DeepSSM_TestingType: run_testing(); break; } @@ -63,17 +52,17 @@ void DeepSSMJob::run() { //--------------------------------------------------------------------------- QString DeepSSMJob::name() { - switch (tool_mode_) { - case DeepSSMTool::ToolMode::DeepSSM_PrepType: + switch (job_type_) { + case DeepSSMJob::JobType::DeepSSM_PrepType: return "DeepSSM: Prep"; break; - case DeepSSMTool::ToolMode::DeepSSM_AugmentationType: + case DeepSSMJob::JobType::DeepSSM_AugmentationType: return "DeepSSM: Augmentation"; break; - case DeepSSMTool::ToolMode::DeepSSM_TrainingType: + case DeepSSMJob::JobType::DeepSSM_TrainingType: return "DeepSSM: Training"; break; - case DeepSSMTool::ToolMode::DeepSSM_TestingType: + case DeepSSMJob::JobType::DeepSSM_TestingType: return "DeepSSM: Testing"; break; } @@ -85,7 +74,6 @@ QString DeepSSMJob::name() { void DeepSSMJob::run_prep() { // groom training auto subjects = project_->get_subjects(); - auto shapes = session_->get_shapes(); SW_LOG("DeepSSM: Grooming Training Data"); py::module py_deep_ssm_utils = py::module::import("DeepSSMUtils"); @@ -96,7 +84,7 @@ void DeepSSMJob::run_prep() { params.set_training_step_complete(false); params.save_to_project(); - if (prep_step_ == DeepSSMTool::PrepStep::NOT_STARTED || prep_step_ == DeepSSMTool::PrepStep::GROOM_TRAINING) { + if (prep_step_ == DeepSSMJob::PrepStep::NOT_STARTED || prep_step_ == DeepSSMJob::PrepStep::GROOM_TRAINING) { SW_LOG("Creating Split..."); ///////////////////////////////////////////////////////// /// Step 1. Create Split @@ -107,13 +95,12 @@ void DeepSSMJob::run_prep() { py::object create_split = py_deep_ssm_utils.attr("create_split"); create_split(project_, train_split, val_split, test_split); - int num_train = DeepSSMTool::get_split(project_, DeepSSMTool::SplitType::TRAIN).size(); - int num_val = DeepSSMTool::get_split(project_, DeepSSMTool::SplitType::VAL).size(); - int num_test = DeepSSMTool::get_split(project_, DeepSSMTool::SplitType::TEST).size(); + int num_train = get_split(project_, SplitType::TRAIN).size(); + int num_val = get_split(project_, SplitType::VAL).size(); + int num_test = get_split(project_, SplitType::TEST).size(); if (num_train == 0 || num_val == 0) { SW_ERROR("DeepSSM: Not enough subjects in training and validation. Please check split."); abort(); - //return; } if (is_aborted()) { @@ -123,7 +110,7 @@ void DeepSSMJob::run_prep() { ///////////////////////////////////////////////////////// /// Step 2. Groom Training Shapes ///////////////////////////////////////////////////////// - update_prep_stage(DeepSSMTool::PrepStep::GROOM_TRAINING); + update_prep_stage(DeepSSMJob::PrepStep::GROOM_TRAINING); py::object groom_training_shapes = py_deep_ssm_utils.attr("groom_training_shapes"); QElapsedTimer timer; @@ -153,11 +140,11 @@ void DeepSSMJob::run_prep() { } } - if (prep_step_ == DeepSSMTool::PrepStep::NOT_STARTED || prep_step_ == DeepSSMTool::PrepStep::OPTIMIZE_TRAINING) { + if (prep_step_ == DeepSSMJob::PrepStep::NOT_STARTED || prep_step_ == DeepSSMJob::PrepStep::OPTIMIZE_TRAINING) { ///////////////////////////////////////////////////////// /// Step 3. Optimize Training Particles ///////////////////////////////////////////////////////// - update_prep_stage(DeepSSMTool::PrepStep::OPTIMIZE_TRAINING); + update_prep_stage(DeepSSMJob::PrepStep::OPTIMIZE_TRAINING); QElapsedTimer timer; timer.start(); py::object optimize_training_particles = py_deep_ssm_utils.attr("optimize_training_particles"); @@ -171,11 +158,11 @@ void DeepSSMJob::run_prep() { } } - if (prep_step_ == DeepSSMTool::PrepStep::NOT_STARTED || prep_step_ == DeepSSMTool::PrepStep::OPTIMIZE_VALIDATION) { + if (prep_step_ == DeepSSMJob::PrepStep::NOT_STARTED || prep_step_ == DeepSSMJob::PrepStep::OPTIMIZE_VALIDATION) { ///////////////////////////////////////////////////////// /// Step 6. Optimize Validation Particles with Fixed Domains ///////////////////////////////////////////////////////// - update_prep_stage(DeepSSMTool::PrepStep::OPTIMIZE_VALIDATION); + update_prep_stage(DeepSSMJob::PrepStep::OPTIMIZE_VALIDATION); py::object prep_project_for_val_particles = py_deep_ssm_utils.attr("prep_project_for_val_particles"); prep_project_for_val_particles(project_); @@ -199,12 +186,12 @@ void DeepSSMJob::run_prep() { SW_LOG("DeepSSM: Optimize Validation Particles complete. Duration: {} seconds", duration.toStdString()); } - if (prep_step_ == DeepSSMTool::PrepStep::NOT_STARTED || prep_step_ == DeepSSMTool::PrepStep::GROOM_IMAGES) { + if (prep_step_ == DeepSSMJob::PrepStep::NOT_STARTED || prep_step_ == DeepSSMJob::PrepStep::GROOM_IMAGES) { ///////////////////////////////////////////////////////// /// Step 4. Groom Training Images ///////////////////////////////////////////////////////// - update_prep_stage(DeepSSMTool::PrepStep::GROOM_IMAGES); + update_prep_stage(DeepSSMJob::PrepStep::GROOM_IMAGES); QElapsedTimer timer; timer.start(); py::object groom_training_images = py_deep_ssm_utils.attr("groom_training_images"); @@ -221,7 +208,7 @@ void DeepSSMJob::run_prep() { ///////////////////////////////////////////////////////// timer.start(); py::object groom_val_test_images = py_deep_ssm_utils.attr("groom_val_test_images"); - groom_val_test_images(project_, DeepSSMTool::get_split(project_, DeepSSMTool::SplitType::VAL)); + groom_val_test_images(project_, get_split(project_, SplitType::VAL)); project_->save(); duration = QString::number(timer.elapsed() / 1000.0, 'f', 1); SW_LOG("DeepSSM: Groom Validation Images complete. Duration: {} seconds", duration.toStdString()); @@ -232,7 +219,7 @@ void DeepSSMJob::run_prep() { } ///////////////////////////////////////////////////////// - update_prep_stage(DeepSSMTool::PrepStep::DONE); + update_prep_stage(DeepSSMJob::PrepStep::DONE); params.set_prep_step_complete(true); params.set_aug_step_complete(false); params.set_training_step_complete(false); @@ -328,7 +315,7 @@ void DeepSSMJob::run_testing() { py::module py_deep_ssm_utils = py::module::import("DeepSSMUtils"); - std::vector test_indices = DeepSSMTool::get_split(project_, DeepSSMTool::SplitType::TEST); + std::vector test_indices = get_split(project_, SplitType::TEST); // Groom Test Images SW_MESSAGE("Grooming Test Images"); @@ -371,7 +358,37 @@ void DeepSSMJob::run_testing() { void DeepSSMJob::python_message(std::string str) { SW_LOG(str); } //--------------------------------------------------------------------------- -void DeepSSMJob::update_prep_stage(DeepSSMTool::PrepStep step) { +std::vector DeepSSMJob::get_split(ProjectHandle project, SplitType split_type) { + auto subjects = project->get_subjects(); + + std::vector list; + + for (int id = 0; id < subjects.size(); id++) { + auto extra_values = subjects[id]->get_extra_values(); + + std::string split = extra_values["split"]; + + if (split_type == DeepSSMJob::SplitType::TRAIN) { + if (split != "train") { + continue; + } + } else if (split_type == DeepSSMJob::SplitType::VAL) { + if (split != "val") { + continue; + } + } else if (split_type == DeepSSMJob::SplitType::TEST) { + if (split != "test") { + continue; + } + } + + list.push_back(id); + } + return list; +} + +//--------------------------------------------------------------------------- +void DeepSSMJob::update_prep_stage(PrepStep step) { /* std::lock_guard lock(mutex_); diff --git a/Libs/Application/DeepSSM/DeepSSMJob.h b/Libs/Application/DeepSSM/DeepSSMJob.h new file mode 100644 index 00000000000..f3c26dbcb94 --- /dev/null +++ b/Libs/Application/DeepSSM/DeepSSMJob.h @@ -0,0 +1,73 @@ +#pragma once + +#include +#include + +#include + +namespace shapeworks { + +//! Qt Wrapper for DeepSSM +/*! + * The DeepSSMJob class wraps the functionality for DeepSSM as a Studio Job object + * + */ +class DeepSSMJob : public Job { + Q_OBJECT; + + public: + enum class JobType { + DeepSSM_PrepType = 0, + DeepSSM_AugmentationType = 1, + DeepSSM_TrainingType = 2, + DeepSSM_TestingType = 3 + }; + + enum PrepStep { + NOT_STARTED = 0, + GROOM_TRAINING = 1, + OPTIMIZE_TRAINING = 2, + OPTIMIZE_VALIDATION = 3, + GROOM_IMAGES = 4, + DONE = 5 + }; + + enum class SplitType { TRAIN, VAL, TEST }; + + DeepSSMJob(std::shared_ptr project, DeepSSMJob::JobType tool_mode, + DeepSSMJob::PrepStep prep_step = DeepSSMJob::NOT_STARTED); + ~DeepSSMJob(); + + void run() override; + + QString name() override; + + void run_prep(); + void run_augmentation(); + void run_training(); + void run_testing(); + + void python_message(std::string str); + + static std::vector get_split(ProjectHandle project, DeepSSMJob::SplitType split_type); + + void set_prep_step(DeepSSMJob::PrepStep step) { + std::lock_guard lock(mutex_); + prep_step_ = step; + } + + private: + void update_prep_stage(DeepSSMJob::PrepStep step); + void process_test_results(); + + std::shared_ptr project_; + + DeepSSMJob::JobType job_type_; + + QString prep_message_; + DeepSSMJob::PrepStep prep_step_{DeepSSMJob::NOT_STARTED}; + + // mutex + std::mutex mutex_; +}; +} // namespace shapeworks diff --git a/Studio/Job/Job.cpp b/Libs/Application/Job/Job.cpp similarity index 100% rename from Studio/Job/Job.cpp rename to Libs/Application/Job/Job.cpp diff --git a/Studio/Job/Job.h b/Libs/Application/Job/Job.h similarity index 94% rename from Studio/Job/Job.h rename to Libs/Application/Job/Job.h index 00f9976c9c3..24bbea6ddb9 100644 --- a/Studio/Job/Job.h +++ b/Libs/Application/Job/Job.h @@ -1,6 +1,7 @@ #pragma once #include #include +#include namespace shapeworks { @@ -62,3 +63,6 @@ class Job : public QObject { QElapsedTimer timer_; }; } // namespace shapeworks + + +Q_DECLARE_METATYPE(QSharedPointer); diff --git a/Studio/Python/PythonWorker.cpp b/Libs/Application/Job/PythonWorker.cpp similarity index 91% rename from Studio/Python/PythonWorker.cpp rename to Libs/Application/Job/PythonWorker.cpp index e46058ba9ca..e2169cb8628 100644 --- a/Studio/Python/PythonWorker.cpp +++ b/Libs/Application/Job/PythonWorker.cpp @@ -7,15 +7,14 @@ namespace py = pybind11; using namespace pybind11::literals; // to bring in the `_a` literal +#include #include -#include -#include #include #include #include +#include #include -#include namespace shapeworks { @@ -38,11 +37,18 @@ class PythonLogger { bool check_abort() { return aborted_; } + bool is_cli_mode() { return is_cli_mode_; } + + void set_cli_mode(bool cli) { is_cli_mode_ = cli; } + private: std::function callback_; std::function progress_callback_; std::atomic aborted_{false}; + + //! Command line interface mode + std::atomic is_cli_mode_{false}; }; //--------------------------------------------------------------------------- @@ -51,13 +57,16 @@ PYBIND11_EMBEDDED_MODULE(logger, m) { .def(py::init<>()) .def("log", &PythonLogger::cpp_log) .def("check_abort", &PythonLogger::check_abort) - .def("progress", &PythonLogger::cpp_progress); + .def("progress", &PythonLogger::cpp_progress) + .def("is_cli_mode", &PythonLogger::is_cli_mode); }; //--------------------------------------------------------------------------- PythonWorker::PythonWorker() { python_logger_ = QSharedPointer::create(); + qRegisterMetaType>("QSharedPointer"); + // create singular Python thread and move this object to the new thread thread_ = new QThread(this); moveToThread(thread_); @@ -67,15 +76,20 @@ PythonWorker::PythonWorker() { //--------------------------------------------------------------------------- PythonWorker::~PythonWorker() { end_python(); - thread_->wait(); - delete thread_; + if (thread_) { + thread_->wait(); + delete thread_; + } } //--------------------------------------------------------------------------- -void PythonWorker::set_vtk_output_window(vtkSmartPointer output_window) { +void PythonWorker::set_vtk_output_window(vtkSmartPointer output_window) { studio_vtk_output_window_ = output_window; } +//--------------------------------------------------------------------------- +void PythonWorker::set_cli_mode(bool cli_mode) { python_logger_->set_cli_mode(cli_mode); } + //--------------------------------------------------------------------------- void PythonWorker::start_job(QSharedPointer job) { if (init()) { @@ -109,6 +123,9 @@ void PythonWorker::run_job(QSharedPointer job) { QMetaObject::invokeMethod(this, "start_job", Qt::QueuedConnection, Q_ARG(QSharedPointer, job)); } +//--------------------------------------------------------------------------- +void PythonWorker::set_current_job(QSharedPointer job) { current_job_ = job; } + //--------------------------------------------------------------------------- bool PythonWorker::init() { std::string script = "install_shapeworks.sh"; diff --git a/Studio/Python/PythonWorker.h b/Libs/Application/Job/PythonWorker.h similarity index 73% rename from Studio/Python/PythonWorker.h rename to Libs/Application/Job/PythonWorker.h index 71775971124..4044185e530 100644 --- a/Studio/Python/PythonWorker.h +++ b/Libs/Application/Job/PythonWorker.h @@ -9,7 +9,7 @@ // studio #include -#include +#include namespace shapeworks { class PythonLogger; @@ -23,9 +23,11 @@ class PythonWorker : public QObject { PythonWorker(); ~PythonWorker(); - void set_vtk_output_window(vtkSmartPointer output_window); + void set_vtk_output_window(vtkSmartPointer output_window); + void set_cli_mode(bool cli_mode); void run_job(QSharedPointer job); + void set_current_job(QSharedPointer job); void incoming_python_message(std::string message_string); void incoming_python_progress(double value, std::string message); @@ -51,12 +53,12 @@ class PythonWorker : public QObject { bool initialized_ = false; bool initialized_success_ = false; - vtkSmartPointer studio_vtk_output_window_; + vtkSmartPointer studio_vtk_output_window_; QSharedPointer python_logger_; QSharedPointer current_job_; - QThread* thread_; + QThread* thread_{nullptr}; }; } // namespace shapeworks diff --git a/Studio/Visualization/StudioVtkOutputWindow.cpp b/Libs/Application/ShapeWorksVtkOutputWindow.cpp similarity index 53% rename from Studio/Visualization/StudioVtkOutputWindow.cpp rename to Libs/Application/ShapeWorksVtkOutputWindow.cpp index cd9f794c12b..6219b26640d 100644 --- a/Studio/Visualization/StudioVtkOutputWindow.cpp +++ b/Libs/Application/ShapeWorksVtkOutputWindow.cpp @@ -1,26 +1,26 @@ -#include "StudioVtkOutputWindow.h" +#include "ShapeWorksVtkOutputWindow.h" #include #include namespace shapeworks { -vtkStandardNewMacro(StudioVtkOutputWindow); +vtkStandardNewMacro(ShapeWorksVtkOutputWindow); //--------------------------------------------------------------------------- -StudioVtkOutputWindow::StudioVtkOutputWindow() {} +ShapeWorksVtkOutputWindow::ShapeWorksVtkOutputWindow() {} //--------------------------------------------------------------------------- -void StudioVtkOutputWindow::DisplayErrorText(const char* text) { SW_ERROR(text); } +void ShapeWorksVtkOutputWindow::DisplayErrorText(const char* text) { SW_ERROR(text); } //--------------------------------------------------------------------------- -void StudioVtkOutputWindow::DisplayWarningText(const char* text) { SW_WARN(text); } +void ShapeWorksVtkOutputWindow::DisplayWarningText(const char* text) { SW_WARN(text); } //--------------------------------------------------------------------------- -void StudioVtkOutputWindow::DisplayGenericWarningText(const char* text) { SW_WARN(text); } +void ShapeWorksVtkOutputWindow::DisplayGenericWarningText(const char* text) { SW_WARN(text); } //--------------------------------------------------------------------------- -void StudioVtkOutputWindow::DisplayDebugText(const char* text) { SW_DEBUG(text); } +void ShapeWorksVtkOutputWindow::DisplayDebugText(const char* text) { SW_DEBUG(text); } //--------------------------------------------------------------------------- diff --git a/Studio/Visualization/StudioVtkOutputWindow.h b/Libs/Application/ShapeWorksVtkOutputWindow.h similarity index 68% rename from Studio/Visualization/StudioVtkOutputWindow.h rename to Libs/Application/ShapeWorksVtkOutputWindow.h index 721c55ff50f..0ff930710c7 100644 --- a/Studio/Visualization/StudioVtkOutputWindow.h +++ b/Libs/Application/ShapeWorksVtkOutputWindow.h @@ -7,15 +7,15 @@ namespace shapeworks { //! Implementation of vtkOutputWindow to capture and display VTK error messages -class StudioVtkOutputWindow : public QObject, public vtkOutputWindow { +class ShapeWorksVtkOutputWindow : public QObject, public vtkOutputWindow { Q_OBJECT; public: - static StudioVtkOutputWindow* New(); + static ShapeWorksVtkOutputWindow* New(); - vtkTypeMacro(StudioVtkOutputWindow, vtkOutputWindow); + vtkTypeMacro(ShapeWorksVtkOutputWindow, vtkOutputWindow); - StudioVtkOutputWindow(); + ShapeWorksVtkOutputWindow(); void DisplayErrorText(const char* text) override; void DisplayWarningText(const char* text) override; diff --git a/Libs/CMakeLists.txt b/Libs/CMakeLists.txt index 75adb39e983..5cba523b616 100644 --- a/Libs/CMakeLists.txt +++ b/Libs/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(Application) add_subdirectory(Common) add_subdirectory(Mesh) add_subdirectory(Image) diff --git a/Libs/Optimize/Optimize.h b/Libs/Optimize/Optimize.h index 3f86a2f059a..d09d2833059 100644 --- a/Libs/Optimize/Optimize.h +++ b/Libs/Optimize/Optimize.h @@ -5,7 +5,6 @@ #endif // std -#include #include #include diff --git a/Libs/Project/CMakeLists.txt b/Libs/Project/CMakeLists.txt index fdf03cb29a6..7cc14dd63a8 100644 --- a/Libs/Project/CMakeLists.txt +++ b/Libs/Project/CMakeLists.txt @@ -1,6 +1,7 @@ set(SOURCES ExcelProjectReader.cpp ExcelProjectWriter.cpp + DeepSSMParameters.cpp JsonProjectReader.cpp JsonProjectWriter.cpp Parameters.cpp @@ -12,6 +13,7 @@ set(SOURCES ) set(HEADERS + DeepSSMParameters.h Project.h Subject.h Parameters.h diff --git a/Studio/DeepSSM/DeepSSMParameters.cpp b/Libs/Project/DeepSSMParameters.cpp similarity index 100% rename from Studio/DeepSSM/DeepSSMParameters.cpp rename to Libs/Project/DeepSSMParameters.cpp diff --git a/Studio/DeepSSM/DeepSSMParameters.h b/Libs/Project/DeepSSMParameters.h similarity index 100% rename from Studio/DeepSSM/DeepSSMParameters.h rename to Libs/Project/DeepSSMParameters.h diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py index 6dc7b5bfdb4..18a48e8a4a8 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py @@ -538,6 +538,7 @@ def process_test_predictions(project, config_file): template_particles, template_mesh, pred_dir) print("Distances: ", distances) + print("Mean distance: ", np.mean(distances)) # write to csv file in deepssm_dir csv_file = f"{deepssm_dir}/test_distances.csv" diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py index 38f97cf83f3..4ed97fcc492 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py @@ -44,6 +44,8 @@ def log_print(logger, values): print(values[i], end=' ') else: print('%.5f' % values[i], end=' ') + # print a new line + print() # csv format string_values = [str(i) for i in values] @@ -167,7 +169,8 @@ def supervised_train(config_file): if sw_check_abort(): sw_message("Aborted") return - sw_message(f"Epoch {e}/{num_epochs}") + if not sw_is_cli_mode(): + sw_message(f"Epoch {e}/{num_epochs}") sw_progress(e / (num_epochs + 1)) torch.cuda.empty_cache() @@ -229,6 +232,7 @@ def supervised_train(config_file): log_print(logger, ["Base_Training", e, last_learning_rate, train_mr_MSE, train_rel_err, val_mr_MSE, val_rel_err, time.time() - t0]) + # plot epochs.append(e) plot_train_losses.append(train_mr_MSE) diff --git a/Python/shapeworks/shapeworks/utils.py b/Python/shapeworks/shapeworks/utils.py index 41f4e7a0bf8..0db266b6f10 100644 --- a/Python/shapeworks/shapeworks/utils.py +++ b/Python/shapeworks/shapeworks/utils.py @@ -198,6 +198,13 @@ def sw_check_abort(): else: return False +def sw_is_cli_mode(): + """Check if the current mode is CLI mode""" + global sw_logger + if sw_logger is not None: + return sw_logger.is_cli_mode() + else: + return False def sw_progress(progress, message=""): """If sw_logger is set, use it, otherwise do nothing""" diff --git a/Studio/Analysis/AnalysisTool.cpp b/Studio/Analysis/AnalysisTool.cpp index d952fb19bed..0e9d6b52f93 100644 --- a/Studio/Analysis/AnalysisTool.cpp +++ b/Studio/Analysis/AnalysisTool.cpp @@ -11,8 +11,8 @@ #include #include #include +#include #include -#include #include #include #include diff --git a/Studio/Analysis/ShapeScalarPanel.cpp b/Studio/Analysis/ShapeScalarPanel.cpp index e5b49790945..0e420df1004 100644 --- a/Studio/Analysis/ShapeScalarPanel.cpp +++ b/Studio/Analysis/ShapeScalarPanel.cpp @@ -10,8 +10,8 @@ #include #include #include +#include #include -#include #include #include #include diff --git a/Studio/CMakeLists.txt b/Studio/CMakeLists.txt index fb6f135bd26..b7cfafd1c73 100644 --- a/Studio/CMakeLists.txt +++ b/Studio/CMakeLists.txt @@ -102,7 +102,6 @@ SET(STUDIO_DATA_MOC_HDRS ) SET(STUDIO_JOB_SRCS - Job/Job.cpp Job/GroupPvalueJob.cpp Job/ParticleAreaJob.cpp Job/NetworkAnalysisJob.cpp @@ -112,7 +111,6 @@ SET(STUDIO_JOB_SRCS ) SET(STUDIO_JOB_MOC_HDRS - Job/Job.h Job/GroupPvalueJob.h Job/NetworkAnalysisJob.h Job/ParticleAreaJob.h @@ -128,13 +126,6 @@ SET(STUDIO_GROOM_MOC_HDRS Groom/GroomTool.h ) -SET(STUDIO_PYTHON_SRCS - Python/PythonWorker.cpp - ) -SET(STUDIO_PYTHON_MOC_HDRS - Python/PythonWorker.h - ) - SET(STUDIO_OPTIMIZE_SRCS Optimize/OptimizeTool.cpp Optimize/QOptimize.cpp @@ -159,12 +150,9 @@ SET(STUDIO_ANALYSIS_MOC_HDRS SET(STUDIO_DEEPSSM_SRCS DeepSSM/DeepSSMTool.cpp - DeepSSM/DeepSSMParameters.cpp - DeepSSM/DeepSSMJob.cpp ) SET(STUDIO_DEEPSSM_MOC_HDRS DeepSSM/DeepSSMTool.h - DeepSSM/DeepSSMJob.h ) @@ -233,14 +221,12 @@ SET(STUDIO_VISUALIZATION_SRCS Visualization/MeshSlice.cpp Visualization/Viewer.cpp Visualization/Visualizer.cpp - Visualization/StudioVtkOutputWindow.cpp Visualization/StudioHandleWidget.cpp ) SET(STUDIO_VISUALIZATION_MOC_HDRS Visualization/Lightbox.h Visualization/ParticleColors.h - Visualization/StudioVtkOutputWindow.h Visualization/Visualizer.h ) @@ -251,7 +237,6 @@ SET(STUDIO_SRCS ${STUDIO_DEEPSSM_SRCS} ${STUDIO_MONAILABEL_SRCS} ${STUDIO_GROOM_SRCS} - ${STUDIO_PYTHON_SRCS} ${STUDIO_INTERFACE_SRCS} ${STUDIO_OPTIMIZE_SRCS} ${STUDIO_UTILS_SRCS} @@ -265,7 +250,6 @@ SET(STUDIO_MOC_HDRS ${STUDIO_DEEPSSM_MOC_HDRS} ${STUDIO_MONAILABEL_MOC_HDRS} ${STUDIO_GROOM_MOC_HDRS} - ${STUDIO_PYTHON_MOC_HDRS} ${STUDIO_INTERFACE_MOC_HDRS} ${STUDIO_OPTIMIZE_MOC_HDRS} ${STUDIO_UTILS_MOC_HDRS} @@ -385,6 +369,7 @@ TARGET_LINK_LIBRARIES(ShapeWorksStudio ${ITK_LIBRARIES} ${OPENGL_LIBRARIES} tinyxml + Application Alignment Analyze OptimizeLibraries diff --git a/Studio/DeepSSM/DeepSSMJob.h b/Studio/DeepSSM/DeepSSMJob.h deleted file mode 100644 index 19ab46ad99d..00000000000 --- a/Studio/DeepSSM/DeepSSMJob.h +++ /dev/null @@ -1,50 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -namespace shapeworks { - -//! Qt Wrapper for DeepSSM -/*! - * The DeepSSMJob class wraps the functionality for DeepSSM as a Studio Job object - * - */ -class DeepSSMJob : public Job { - Q_OBJECT; - - public: - DeepSSMJob(QSharedPointer session, DeepSSMTool::ToolMode tool_mode, - DeepSSMTool::PrepStep prep_step = DeepSSMTool::NOT_STARTED); - ~DeepSSMJob(); - - void run() override; - - QString name() override; - - void run_prep(); - void run_augmentation(); - void run_training(); - void run_testing(); - - void python_message(std::string str); - - private: - void update_prep_stage(DeepSSMTool::PrepStep step); - void process_test_results(); - - QSharedPointer session_; - ProjectHandle project_; - - DeepSSMTool::ToolMode tool_mode_; - - QString prep_message_; - DeepSSMTool::PrepStep prep_step_{DeepSSMTool::NOT_STARTED}; - - // mutex - std::mutex mutex_; -}; -} // namespace shapeworks diff --git a/Studio/DeepSSM/DeepSSMTool.cpp b/Studio/DeepSSM/DeepSSMTool.cpp index b474318bc2d..4a1ba2abe0b 100644 --- a/Studio/DeepSSM/DeepSSMTool.cpp +++ b/Studio/DeepSSM/DeepSSMTool.cpp @@ -7,16 +7,16 @@ #include // shapeworks +#include #include #include -#include // studio #include #include #include #include -#include +#include #include #include #include @@ -124,16 +124,16 @@ DeepSSMTool::DeepSSMTool(Preferences& prefs) : preferences_(prefs) { void DeepSSMTool::tab_changed(int tab) { switch (tab) { case 0: - current_tool_ = DeepSSMTool::ToolMode::DeepSSM_PrepType; + current_tool_ = DeepSSMJob::JobType::DeepSSM_PrepType; break; case 1: - current_tool_ = DeepSSMTool::ToolMode::DeepSSM_AugmentationType; + current_tool_ = DeepSSMJob::JobType::DeepSSM_AugmentationType; break; case 2: - current_tool_ = DeepSSMTool::ToolMode::DeepSSM_TrainingType; + current_tool_ = DeepSSMJob::JobType::DeepSSM_TrainingType; break; case 3: - current_tool_ = DeepSSMTool::ToolMode::DeepSSM_TestingType; + current_tool_ = DeepSSMJob::JobType::DeepSSM_TestingType; break; } update_panels(); @@ -160,6 +160,8 @@ void DeepSSMTool::load_params() { ui_->validation_split->setText(QString::number(params.get_validation_split())); ui_->testing_split->setText(QString::number(params.get_testing_split())); + ui_->training_split->setText(QString::number(params.get_training_split())); + update_split(); auto spacing = params.get_spacing(); ui_->spacing_x->setText(QString::number(spacing[0])); @@ -249,7 +251,7 @@ void DeepSSMTool::run_clicked() { } else { session_->trigger_save(); if (ui_->run_all->isChecked()) { - run_tool(DeepSSMTool::ToolMode::DeepSSM_PrepType); + run_tool(DeepSSMJob::JobType::DeepSSM_PrepType); } else { run_tool(current_tool_); } @@ -258,7 +260,7 @@ void DeepSSMTool::run_clicked() { //--------------------------------------------------------------------------- void DeepSSMTool::run_prep_clicked(int step) { - prep_step_ = static_cast(step); + prep_step_ = static_cast(step); run_clicked(); } @@ -266,16 +268,16 @@ void DeepSSMTool::run_prep_clicked(int step) { void DeepSSMTool::handle_thread_complete() { try { if (!deep_ssm_->is_aborted()) { - if (current_tool_ == DeepSSMTool::ToolMode::DeepSSM_PrepType) { + if (current_tool_ == DeepSSMJob::JobType::DeepSSM_PrepType) { auto params = DeepSSMParameters(session_->get_project()); params.set_prep_stage(static_cast(prep_step_)); - if (prep_step_ == DeepSSMTool::PrepStep::NOT_STARTED || prep_step_ == DeepSSMTool::PrepStep::GROOM_IMAGES) { + if (prep_step_ == DeepSSMJob::PrepStep::NOT_STARTED || prep_step_ == DeepSSMJob::PrepStep::GROOM_IMAGES) { params.set_prep_step_complete(true); - params.set_prep_stage(static_cast(DeepSSMTool::PrepStep::DONE)); + params.set_prep_stage(static_cast(DeepSSMJob::PrepStep::DONE)); } params.save_to_project(); update_panels(); - prep_step_ = DeepSSMTool::PrepStep::NOT_STARTED; + prep_step_ = DeepSSMJob::PrepStep::NOT_STARTED; } } Q_EMIT progress(100); @@ -286,12 +288,12 @@ void DeepSSMTool::handle_thread_complete() { if (!deep_ssm_->is_aborted()) { if (ui_->run_all->isChecked()) { - if (current_tool_ == ToolMode::DeepSSM_PrepType) { - run_tool(DeepSSMTool::ToolMode::DeepSSM_AugmentationType); - } else if (current_tool_ == ToolMode::DeepSSM_AugmentationType) { - run_tool(DeepSSMTool::ToolMode::DeepSSM_TrainingType); - } else if (current_tool_ == ToolMode::DeepSSM_TrainingType) { - run_tool(DeepSSMTool::ToolMode::DeepSSM_TestingType); + if (current_tool_ == DeepSSMJob::JobType::DeepSSM_PrepType) { + run_tool(DeepSSMJob::JobType::DeepSSM_AugmentationType); + } else if (current_tool_ == DeepSSMJob::JobType::DeepSSM_AugmentationType) { + run_tool(DeepSSMJob::JobType::DeepSSM_TrainingType); + } else if (current_tool_ == DeepSSMJob::JobType::DeepSSM_TrainingType) { + run_tool(DeepSSMJob::JobType::DeepSSM_TestingType); } } } @@ -302,7 +304,7 @@ void DeepSSMTool::handle_thread_complete() { //--------------------------------------------------------------------------- void DeepSSMTool::handle_progress(int val, QString message) { - if (current_tool_ == DeepSSMTool::ToolMode::DeepSSM_PrepType) { + if (current_tool_ == DeepSSMJob::JobType::DeepSSM_PrepType) { //?? TODO ui_->prep_text_edit->setText(deep_ssm_->get_prep_message()); //?? TODO ui_->prep_text_edit->setEnabled(true); } @@ -333,20 +335,20 @@ void DeepSSMTool::update_panels() { ui_->training_panel->hide(); bool enabled = true; switch (current_tool_) { - case DeepSSMTool::ToolMode::DeepSSM_PrepType: + case DeepSSMJob::JobType::DeepSSM_PrepType: string = "All Prep Stages"; break; - case DeepSSMTool::ToolMode::DeepSSM_AugmentationType: + case DeepSSMJob::JobType::DeepSSM_AugmentationType: string = "Data Augmentation"; ui_->data_panel->show(); enabled = params.get_prep_step_complete(); break; - case DeepSSMTool::ToolMode::DeepSSM_TrainingType: + case DeepSSMJob::JobType::DeepSSM_TrainingType: string = "Training"; ui_->training_panel->show(); enabled = params.get_aug_step_complete(); break; - case DeepSSMTool::ToolMode::DeepSSM_TestingType: + case DeepSSMJob::JobType::DeepSSM_TestingType: string = "Testing"; enabled = params.get_training_step_complete(); break; @@ -406,7 +408,7 @@ void DeepSSMTool::update_split() { //--------------------------------------------------------------------------- void DeepSSMTool::handle_new_mesh() { - if (current_tool_ == DeepSSMTool::ToolMode::DeepSSM_TestingType) { + if (current_tool_ == DeepSSMJob::JobType::DeepSSM_TestingType) { update_testing_meshes(); } } @@ -604,8 +606,8 @@ void DeepSSMTool::show_training_meshes() { //--------------------------------------------------------------------------- void DeepSSMTool::show_testing_meshes() { shapes_.clear(); - deep_ssm_ = QSharedPointer::create(session_, DeepSSMTool::ToolMode::DeepSSM_TestingType); - auto id_list = get_split(session_->get_project(), SplitType::TEST); + deep_ssm_ = QSharedPointer::create(session_->get_project(), DeepSSMJob::JobType::DeepSSM_TestingType); + auto id_list = DeepSSMJob::get_split(session_->get_project(), DeepSSMJob::SplitType::TEST); auto subjects = session_->get_project()->get_subjects(); auto shapes = session_->get_shapes(); @@ -664,8 +666,8 @@ void DeepSSMTool::show_testing_meshes() { //--------------------------------------------------------------------------- void DeepSSMTool::update_testing_meshes() { try { - deep_ssm_ = QSharedPointer::create(session_, DeepSSMTool::ToolMode::DeepSSM_TestingType); - auto id_list = get_split(session_->get_project(), SplitType::TEST); + deep_ssm_ = QSharedPointer::create(session_->get_project(), DeepSSMJob::JobType::DeepSSM_TestingType); + auto id_list = DeepSSMJob::get_split(session_->get_project(), DeepSSMJob::SplitType::TEST); auto subjects = session_->get_project()->get_subjects(); auto shapes = session_->get_shapes(); @@ -685,17 +687,17 @@ void DeepSSMTool::update_meshes() { return; } switch (current_tool_) { - case DeepSSMTool::ToolMode::DeepSSM_PrepType: + case DeepSSMJob::JobType::DeepSSM_PrepType: shapes_.clear(); Q_EMIT update_view(); break; - case DeepSSMTool::ToolMode::DeepSSM_AugmentationType: + case DeepSSMJob::JobType::DeepSSM_AugmentationType: show_augmentation_meshes(); break; - case DeepSSMTool::ToolMode::DeepSSM_TrainingType: + case DeepSSMJob::JobType::DeepSSM_TrainingType: show_training_meshes(); break; - case DeepSSMTool::ToolMode::DeepSSM_TestingType: + case DeepSSMJob::JobType::DeepSSM_TestingType: show_testing_meshes(); break; } @@ -835,43 +837,13 @@ void DeepSSMTool::resizeEvent(QResizeEvent* event) { //--------------------------------------------------------------------------- std::string DeepSSMTool::get_display_feature() { - if (current_tool_ == DeepSSMTool::ToolMode::DeepSSM_TrainingType || - current_tool_ == DeepSSMTool::ToolMode::DeepSSM_TestingType) { + if (current_tool_ == DeepSSMJob::JobType::DeepSSM_TrainingType || + current_tool_ == DeepSSMJob::JobType::DeepSSM_TestingType) { return "deepssm_error"; } return ""; } -//--------------------------------------------------------------------------- -std::vector DeepSSMTool::get_split(ProjectHandle project, SplitType split_type) { - auto subjects = project->get_subjects(); - - std::vector list; - - for (int id = 0; id < subjects.size(); id++) { - auto extra_values = subjects[id]->get_extra_values(); - - std::string split = extra_values["split"]; - - if (split_type == SplitType::TRAIN) { - if (split != "train") { - continue; - } - } else if (split_type == SplitType::VAL) { - if (split != "val") { - continue; - } - } else if (split_type == SplitType::TEST) { - if (split != "test") { - continue; - } - } - - list.push_back(id); - } - return list; -} - //--------------------------------------------------------------------------- void DeepSSMTool::restore_defaults() { // need to save values from the other pages @@ -880,16 +852,16 @@ void DeepSSMTool::restore_defaults() { auto params = DeepSSMParameters(session_->get_project()); switch (current_tool_) { - case DeepSSMTool::ToolMode::DeepSSM_PrepType: + case DeepSSMJob::JobType::DeepSSM_PrepType: params.restore_split_defaults(); break; - case DeepSSMTool::ToolMode::DeepSSM_AugmentationType: + case DeepSSMJob::JobType::DeepSSM_AugmentationType: params.restore_augmentation_defaults(); break; - case DeepSSMTool::ToolMode::DeepSSM_TrainingType: + case DeepSSMJob::JobType::DeepSSM_TrainingType: params.restore_training_defaults(); break; - case DeepSSMTool::ToolMode::DeepSSM_TestingType: + case DeepSSMJob::JobType::DeepSSM_TestingType: // params.restore_inference_defaults(); break; } @@ -900,17 +872,17 @@ void DeepSSMTool::restore_defaults() { } //--------------------------------------------------------------------------- -void DeepSSMTool::run_tool(DeepSSMTool::ToolMode type) { - current_tool_ = type; +void DeepSSMTool::run_tool(DeepSSMJob::JobType job_type) { + current_tool_ = job_type; Q_EMIT progress(-1); - if (type == DeepSSMTool::ToolMode::DeepSSM_AugmentationType) { + if (job_type == DeepSSMJob::JobType::DeepSSM_AugmentationType) { ui_->tab_widget->setCurrentIndex(1); SW_LOG("Please Wait: Running Data Augmentation..."); // clean QFile("deepssm/augmentation/TotalData.csv").remove(); - } else if (type == DeepSSMTool::ToolMode::DeepSSM_TrainingType) { + } else if (job_type == DeepSSMJob::JobType::DeepSSM_TrainingType) { ui_->tab_widget->setCurrentIndex(2); SW_LOG("Please Wait: Running Training..."); @@ -919,17 +891,15 @@ void DeepSSMTool::run_tool(DeepSSMTool::ToolMode type) { dir.removeRecursively(); show_training_meshes(); - } else if (type == DeepSSMTool::ToolMode::DeepSSM_TestingType) { + } else if (job_type == DeepSSMJob::JobType::DeepSSM_TestingType) { ui_->tab_widget->setCurrentIndex(3); SW_LOG("Please Wait: Running Testing..."); - } else if (type == DeepSSMTool::ToolMode::DeepSSM_PrepType) { + } else if (job_type == DeepSSMJob::JobType::DeepSSM_PrepType) { ui_->tab_widget->setCurrentIndex(0); // check that there are at least 1 subject in test/val/train each - - SW_LOG("Please Wait: Running Groom/Optimize..."); } else { SW_ERROR("Unknown tool mode"); @@ -947,7 +917,7 @@ void DeepSSMTool::run_tool(DeepSSMTool::ToolMode type) { store_params(); - deep_ssm_ = QSharedPointer::create(session_, type, prep_step_); + deep_ssm_ = QSharedPointer::create(session_->get_project(), job_type, prep_step_); connect(deep_ssm_.data(), &DeepSSMJob::progress, this, &DeepSSMTool::handle_progress); connect(deep_ssm_.data(), &DeepSSMJob::finished, this, &DeepSSMTool::handle_thread_complete); diff --git a/Studio/DeepSSM/DeepSSMTool.h b/Studio/DeepSSM/DeepSSMTool.h index efd4bc306a7..3e36dce6334 100644 --- a/Studio/DeepSSM/DeepSSMTool.h +++ b/Studio/DeepSSM/DeepSSMTool.h @@ -5,6 +5,8 @@ #include #include +#include + // studio #include #include @@ -26,23 +28,6 @@ class DeepSSMTool : public QWidget { Q_OBJECT; public: - enum class ToolMode { - DeepSSM_PrepType = 0, - DeepSSM_AugmentationType = 1, - DeepSSM_TrainingType = 2, - DeepSSM_TestingType = 3 - }; - - enum PrepStep { - NOT_STARTED = 0, - GROOM_TRAINING = 1, - OPTIMIZE_TRAINING = 2, - OPTIMIZE_VALIDATION = 3, - GROOM_IMAGES = 4, - DONE = 5 - }; - - enum class SplitType { TRAIN, VAL, TEST }; DeepSSMTool(Preferences& prefs); ~DeepSSMTool(); @@ -67,7 +52,6 @@ class DeepSSMTool : public QWidget { std::string get_display_feature(); - static std::vector get_split(ProjectHandle project, SplitType split_type); public Q_SLOTS: @@ -96,7 +80,7 @@ class DeepSSMTool : public QWidget { private: void update_meshes(); - void run_tool(DeepSSMTool::ToolMode type); + void run_tool(DeepSSMJob::JobType type); void show_augmentation_meshes(); void update_tables(); void show_training_meshes(); @@ -120,10 +104,10 @@ class DeepSSMTool : public QWidget { Ui_DeepSSMTool* ui_; QSharedPointer session_; ShapeWorksStudioApp* app_; - PrepStep prep_step_ = PrepStep::NOT_STARTED; + DeepSSMJob::PrepStep prep_step_ = DeepSSMJob::PrepStep::NOT_STARTED; bool tool_is_running_ = false; - DeepSSMTool::ToolMode current_tool_ = DeepSSMTool::ToolMode::DeepSSM_AugmentationType; + DeepSSMJob::JobType current_tool_ = DeepSSMJob::JobType::DeepSSM_AugmentationType; QSharedPointer deep_ssm_; QElapsedTimer timer_; diff --git a/Studio/Interface/ShapeWorksStudioApp.cpp b/Studio/Interface/ShapeWorksStudioApp.cpp index 771c390d3d5..046eb7f8842 100644 --- a/Studio/Interface/ShapeWorksStudioApp.cpp +++ b/Studio/Interface/ShapeWorksStudioApp.cpp @@ -41,8 +41,8 @@ #include #include #include +#include #include -#include #include #include #include @@ -69,7 +69,7 @@ ShapeWorksStudioApp::ShapeWorksStudioApp() { status_bar_ = new StatusBarWidget(this); connect(status_bar_, &StatusBarWidget::toggle_log_window, this, &ShapeWorksStudioApp::toggle_log_window); - studio_vtk_output_window_ = vtkSmartPointer::New(); + studio_vtk_output_window_ = vtkSmartPointer::New(); vtkOutputWindow::SetInstance(studio_vtk_output_window_); logger_.register_callbacks(); diff --git a/Studio/Interface/ShapeWorksStudioApp.h b/Studio/Interface/ShapeWorksStudioApp.h index fc84eced2be..339459158c6 100644 --- a/Studio/Interface/ShapeWorksStudioApp.h +++ b/Studio/Interface/ShapeWorksStudioApp.h @@ -6,7 +6,7 @@ #include #include #include -#include +#include #include #include @@ -234,7 +234,7 @@ class ShapeWorksStudioApp : public QMainWindow { QSharedPointer visualizer_; QSharedPointer preferences_window_; CompareWidget* compare_widget_ = nullptr; - vtkSmartPointer studio_vtk_output_window_; + vtkSmartPointer studio_vtk_output_window_; // all the preferences Preferences preferences_; diff --git a/Studio/Job/ShapeScalarJob.cpp b/Studio/Job/ShapeScalarJob.cpp index 4db36bcfab7..acde309d7a1 100644 --- a/Studio/Job/ShapeScalarJob.cpp +++ b/Studio/Job/ShapeScalarJob.cpp @@ -1,6 +1,6 @@ #include +#include #include -#include #include #include #include diff --git a/Studio/ShapeWorksMONAI/MonaiLabelTool.cpp b/Studio/ShapeWorksMONAI/MonaiLabelTool.cpp index 7f535c97651..a9e996cc198 100644 --- a/Studio/ShapeWorksMONAI/MonaiLabelTool.cpp +++ b/Studio/ShapeWorksMONAI/MonaiLabelTool.cpp @@ -3,8 +3,8 @@ #include #include #include +#include #include -#include #include #include #include