Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Applications/shapeworks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ target_include_directories(shapeworks_exe PUBLIC

target_link_libraries(shapeworks_exe
Mesh ${VTK_LIBRARIES} Optimize Utils trimesh2 Particles
pybind11::embed Project Image Groom Analyze
pybind11::embed Project Image Groom Analyze Application
)

message(STATUS "opt libs ${OPTIMIZE_LIBRARIES}")
Expand Down
9 changes: 9 additions & 0 deletions Applications/shapeworks/Command.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,15 @@ class ParticleSystemCommand : public Command
private:
};

class DeepSSMCommandGroup : public Command
{
public:
const std::string type() override { return "DeepSSM"; }

private:
};


class ShapeworksCommand : public Command
{
public:
Expand Down
168 changes: 164 additions & 4 deletions Applications/shapeworks/Commands.cpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
#include "Commands.h"

#include <Analyze/Analyze.h>
#include <Application/DeepSSM/DeepSSMJob.h>
#include <Application/Job/PythonWorker.h>
#include <Groom/Groom.h>
#include <Logging.h>
#include <Optimize/Optimize.h>
#include <Optimize/OptimizeParameterFile.h>
#include <Optimize/OptimizeParameters.h>
#include <Profiling.h>
#include <ShapeworksUtils.h>
#include <Utils/StringUtils.h>

#include <QApplication>
#include <boost/filesystem.hpp>

#include <Profiling.h>

namespace shapeworks {

// boilerplate for a command. Copy this to start a new command
Expand Down Expand Up @@ -43,8 +45,6 @@ bool Example::execute(const optparse::Values &options, SharedCommandData &shared
}
#endif



///////////////////////////////////////////////////////////////////////////////
// Seed
///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -331,4 +331,164 @@ bool ConvertProjectCommand::execute(const optparse::Values& options, SharedComma
return false;
}
}

///////////////////////////////////////////////////////////////////////////////
// DeepSSM
///////////////////////////////////////////////////////////////////////////////
void DeepSSMCommand::buildParser() {
const std::string prog = "deepssm";
const std::string desc = "run deepssm steps";
parser.prog(prog).description(desc);

parser.add_option("--name").action("store").type("string").set_default("").help(
"Path to input project file (xlsx or swproj).");

// Create a vector of choices first
std::vector<std::string> prep_choices = {"all", "groom_training", "optimize_training", "optimize_validation",
"groom_images"};

// --prep option with choices
parser.add_option("--prep")
.action("store")
.type("choice")
.choices(prep_choices.begin(), prep_choices.end())
//.set_default("all")
.help("Preparation step to run");

// Boolean flag options
parser.add_option("--augment").action("store_true").help("Run data augmentation");

parser.add_option("--train").action("store_true").help("Run training");

parser.add_option("--test").action("store_true").help("Run testing");

parser.add_option("--all").action("store_true").help("Run all steps");

Command::buildParser();
}

bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData& sharedData) {
// Create a non-gui QApplication instance
int argc = 3;
char* argv[3];
argv[0] = const_cast<char*>("shapeworks");
argv[1] = const_cast<char*>("-platform");
argv[2] = const_cast<char*>("offscreen");

QApplication app(argc, argv);

// Handle project file: either from --name or first positional argument
std::string project_file;
if (options.is_set_by_user("name")) {
// User explicitly provided --name
project_file = options["name"];
} else if (!parser.args().empty()) {
// Use first positional argument
project_file = parser.args()[0];
} else {
// No project file provided at all
parser.error("Project file must be provided either as --name or as a positional argument");
}

// Handle prep option with manual default
std::string prep_step;
if (options.is_set_by_user("prep")) {
prep_step = options["prep"];
} else {
prep_step = "all"; // Manual default
}

std::cout << "DeepSSM: Using project file: " << project_file << std::endl;

bool do_prep = options.is_set("prep") || options.is_set("all");
bool do_augment = options.is_set("augment") || options.is_set("all");
bool do_train = options.is_set("train") || options.is_set("all");
bool do_test = options.is_set("test") || options.is_set("all");

std::cout << "Prep step: " << (do_prep ? "on" : "off") << "\n";
std::cout << "Augment step: " << (do_augment ? "on" : "off") << "\n";
std::cout << "Train step: " << (do_train ? "on" : "off") << "\n";
std::cout << "Test step: " << (do_test ? "on" : "off") << "\n";

if (!do_prep && !do_augment && !do_train && !do_test) {
do_prep = true;
do_augment = true;
do_train = true;
do_test = true;
}

ProjectHandle project = std::make_shared<Project>();
project->load(project_file);

PythonWorker python_worker;
python_worker.set_cli_mode(true);

auto wait_for_job = [&](auto job) {
// This lambda will block until the job is complete
while (!job->is_complete()) {
QCoreApplication::processEvents();
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (job->is_aborted()) {
return false;
}
}
return true;
};

if (do_prep) {
auto job = QSharedPointer<DeepSSMJob>::create(project, DeepSSMJob::JobType::DeepSSM_PrepType);
if (prep_step == "all") {
job->set_prep_step(DeepSSMJob::PrepStep::NOT_STARTED);
} else if (prep_step == "groom_training") {
job->set_prep_step(DeepSSMJob::PrepStep::GROOM_TRAINING);
} else if (prep_step == "optimize_training") {
job->set_prep_step(DeepSSMJob::PrepStep::OPTIMIZE_TRAINING);
} else if (prep_step == "optimize_validation") {
job->set_prep_step(DeepSSMJob::PrepStep::OPTIMIZE_VALIDATION);
} else if (prep_step == "groom_images") {
job->set_prep_step(DeepSSMJob::PrepStep::GROOM_IMAGES);
} else {
SW_ERROR("Unknown prep step: {}", prep_step);
return false;
}
std::cout << "Running DeepSSM preparation step...\n";
python_worker.run_job(job);
if (!wait_for_job(job)) {
return false;
}
std::cout << "DeepSSM preparation step completed.\n";
}
if (do_augment) {
std::cout << "Running DeepSSM data augmentation...\n";
auto job = QSharedPointer<DeepSSMJob>::create(project, DeepSSMJob::JobType::DeepSSM_AugmentationType);
python_worker.run_job(job);
if (!wait_for_job(job)) {
return false;
}
std::cout << "DeepSSM data augmentation completed.\n";
}
if (do_train) {
std::cout << "Running DeepSSM training...\n";
auto job = QSharedPointer<DeepSSMJob>::create(project, DeepSSMJob::JobType::DeepSSM_TrainingType);
python_worker.run_job(job);
if (!wait_for_job(job)) {
return false;
}
std::cout << "DeepSSM training completed.\n";
}
if (do_test) {
std::cout << "Running DeepSSM testing...\n";
auto job = QSharedPointer<DeepSSMJob>::create(project, DeepSSMJob::JobType::DeepSSM_TestingType);
python_worker.run_job(job);
if (!wait_for_job(job)) {
return false;
}
std::cout << "DeepSSM testing completed.\n";
}

project->save();

return false;
}

} // namespace shapeworks
1 change: 1 addition & 0 deletions Applications/shapeworks/Commands.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,6 @@ COMMAND_DECLARE(OptimizeCommand, OptimizeCommandGroup);
COMMAND_DECLARE(GroomCommand, GroomCommandGroup);
COMMAND_DECLARE(AnalyzeCommand, AnalyzeCommandGroup);
COMMAND_DECLARE(ConvertProjectCommand, ProjectCommandGroup);
COMMAND_DECLARE(DeepSSMCommand, DeepSSMCommandGroup);

} // shapeworks
1 change: 1 addition & 0 deletions Applications/shapeworks/shapeworks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ int main(int argc, char *argv[])
shapeworks.addCommand(GroomCommand::getCommand());
shapeworks.addCommand(AnalyzeCommand::getCommand());
shapeworks.addCommand(ConvertProjectCommand::getCommand());
shapeworks.addCommand(DeepSSMCommand::getCommand());

try {
TIME_START("shapeworks");
Expand Down
50 changes: 50 additions & 0 deletions Libs/Application/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
SET(APPLICATION_MOC_HDRS
DeepSSM/DeepSSMJob.h
Job/Job.h
Job/PythonWorker.h
ShapeWorksVtkOutputWindow.h
)

qt5_wrap_cpp( APPLICATION_MOC_SRCS ${APPLICATION_MOC_HDRS} )

SET(Application_headers
)

add_library(Application STATIC
DeepSSM/DeepSSMJob.cpp
Job/Job.cpp
Job/PythonWorker.cpp
ShapeWorksVtkOutputWindow.cpp
${APPLICATION_MOC_SRCS}
)

target_include_directories(Application PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>
$<INSTALL_INTERFACE:include>)

set(SW_PYTHON_LIBS pybind11::embed)

if (APPLE)
include_directories(${_Python3_INCLUDE_DIR})
set(SW_PYTHON_LIBS "")
endif(APPLE)

target_link_libraries(Application PUBLIC
Groom
Mesh
Utils
Particles
Project
${SW_PYTHON_LIBS}
)

# set
set_target_properties(Application PROPERTIES PUBLIC_HEADER
"${Application_headers}")

install(TARGETS Application EXPORT ShapeWorksTargets
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
PUBLIC_HEADER DESTINATION include/Application
)
Loading
Loading