Skip to content

Commit f1486eb

Browse files
authored
Merge pull request #2486 from SCIInstitute/deepssm_refactor2
Refactors DeepSSM for reliability, memory efficiency, and testability.
2 parents 9a650bd + ac9897a commit f1486eb

43 files changed

Lines changed: 1802 additions & 237 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

Applications/shapeworks/Command.h

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ class Command {
3333
/// Returns the description reported by this command's parser.
const std::string desc() const {
  return parser.description();
}
3434

3535
/// parses the arguments for this command, saving them in the parser and returning the leftovers
36-
std::vector<std::string> parse_args(const std::vector<std::string> &arguments);
36+
virtual std::vector<std::string> parse_args(const std::vector<std::string> &arguments);
3737

3838
/// calls execute for this command using the parsed args, returning system exit value
3939
int run(SharedCommandData &sharedData);
@@ -108,6 +108,12 @@ class DeepSSMCommandGroup : public Command
108108
public:
109109
/// Reports this command's type string, "DeepSSM".
const std::string type() override {
  return std::string{"DeepSSM"};
}
110110

111+
// DeepSSM is a terminal command - don't pass remaining args to other commands
112+
std::vector<std::string> parse_args(const std::vector<std::string> &arguments) override {
  // Run the base-class parsing so the options are stored in the parser, but
  // deliberately discard the leftover arguments it returns.
  static_cast<void>(Command::parse_args(arguments));

  // Always report "nothing left over": DeepSSM consumes every argument.
  return std::vector<std::string>{};
}
116+
111117
private:
112118
};
113119

Applications/shapeworks/Commands.cpp

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
#include <ShapeworksUtils.h>
1313
#include <Utils/StringUtils.h>
1414

15-
#include <QApplication>
15+
#include <QCoreApplication>
1616
#include <boost/filesystem.hpp>
1717

1818
namespace shapeworks {
@@ -371,18 +371,23 @@ void DeepSSMCommand::buildParser() {
371371
.set_default(0)
372372
.help("Number of data loader workers (default: 0)");
373373

374+
parser.add_option("--aug_processes")
375+
.action("store")
376+
.type("int")
377+
.set_default(0)
378+
.help("Number of augmentation processes (default: 0 = use all cores)");
379+
374380
Command::buildParser();
375381
}
376382

377383
bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData& sharedData) {
378-
// Create a non-gui QApplication instance
379-
int argc = 3;
380-
char* argv[3];
384+
// QCoreApplication provides the event loop needed for PythonWorker's QThread,
385+
// without requiring Qt platform plugins (which may not be available on headless CI).
386+
int argc = 1;
387+
char* argv[1];
381388
argv[0] = const_cast<char*>("shapeworks");
382-
argv[1] = const_cast<char*>("-platform");
383-
argv[2] = const_cast<char*>("offscreen");
384389

385-
QApplication app(argc, argv);
390+
QCoreApplication app(argc, argv);
386391

387392
// Handle project file: either from --name or first positional argument
388393
std::string project_file;
@@ -413,12 +418,14 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData&
413418
bool do_test = options.is_set("test") || options.is_set("all");
414419

415420
int num_workers = static_cast<int>(options.get("num_workers"));
421+
int aug_processes = static_cast<int>(options.get("aug_processes"));
416422

417423
std::cout << "Prep step: " << (do_prep ? "on" : "off") << "\n";
418424
std::cout << "Augment step: " << (do_augment ? "on" : "off") << "\n";
419425
std::cout << "Train step: " << (do_train ? "on" : "off") << "\n";
420426
std::cout << "Test step: " << (do_test ? "on" : "off") << "\n";
421427
std::cout << "Num dataloader workers: " << num_workers << "\n";
428+
// Any non-positive value means "use all cores", matching the selection made in
// DeepSSMJob::run_augmentation(), so the count printed here is the one actually used.
std::cout << "Augmentation processes: " << (aug_processes > 0 ? aug_processes : QThread::idealThreadCount()) << "\n";
422429

423430
if (!do_prep && !do_augment && !do_train && !do_test) {
424431
do_prep = true;
@@ -472,6 +479,7 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData&
472479
if (do_augment) {
473480
std::cout << "Running DeepSSM data augmentation...\n";
474481
auto job = QSharedPointer<DeepSSMJob>::create(project, DeepSSMJob::JobType::DeepSSM_AugmentationType);
482+
job->set_aug_processes(aug_processes);
475483
python_worker.run_job(job);
476484
if (!wait_for_job(job)) {
477485
return false;
@@ -499,7 +507,7 @@ bool DeepSSMCommand::execute(const optparse::Values& options, SharedCommandData&
499507

500508
project->save();
501509

502-
return false;
510+
return true;
503511
}
504512

505513
} // namespace shapeworks

Examples/Python/RunUseCase.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -69,8 +69,14 @@
6969
parser.add_argument("--tiny_test", help="Run as a short test", action="store_true")
7070
parser.add_argument("--verify", help="Run as a full test", action="store_true")
7171
parser.add_argument("--clean", help="Run from scratch, ignoring intermediate stages", action="store_true")
72+
parser.add_argument("--exact_check", help="Save or verify exact values for refactoring verification (platform-specific)",
73+
choices=["save", "verify"])
7274
args = parser.parse_args()
7375

76+
# Validate deep_ssm-specific arguments
77+
if args.exact_check and args.use_case != "deep_ssm":
78+
parser.error("--exact_check is only supported for the deep_ssm use case")
79+
7480
type = ""
7581
if args.tiny_test:
7682
type = "tiny_test_"

Examples/Python/deep_ssm.py

Lines changed: 45 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -64,8 +64,8 @@ def Run_Pipeline(args):
6464
This data is comprised of femur meshes and corresponding hip CT scans.
6565
"""
6666

67-
if platform.system() == "Darwin":
68-
# On MacOS, CPU PyTorch is hanging with parallel
67+
if platform.system() != "Linux":
68+
# CPU PyTorch hangs with OpenMP parallelism on macOS and Windows
6969
os.environ['OMP_NUM_THREADS'] = "1"
7070
# If running a tiny_test, then download subset of the data
7171
if args.tiny_test:
@@ -396,6 +396,7 @@ def Run_Pipeline(args):
396396
"c_lat": 6.3
397397
}
398398
}
399+
399400
if args.tiny_test:
400401
model_parameters["trainer"]["epochs"] = 1
401402
# Save config file
@@ -436,17 +437,17 @@ def Run_Pipeline(args):
436437
val_world_particles.append(project_path + subjects[index].get_world_particle_filenames()[0])
437438
val_mesh_files.append(project_path + subjects[index].get_groomed_filenames()[0])
438439

439-
val_out_dir = output_directory + model_name + '/validation_predictions/'
440440
predicted_val_world_particles = DeepSSMUtils.testDeepSSM(config_file, loader='validation')
441441
print("Validation world predictions saved.")
442-
# Generate local predictions
443-
local_val_prediction_dir = val_out_dir + 'local_predictions/'
442+
# Generate local predictions - create directory next to world_predictions
443+
world_pred_dir = os.path.dirname(predicted_val_world_particles[0])
444+
local_val_prediction_dir = world_pred_dir.replace("world_predictions", "local_predictions")
444445
if not os.path.exists(local_val_prediction_dir):
445446
os.makedirs(local_val_prediction_dir)
446447
predicted_val_local_particles = []
447448
for particle_file, transform in zip(predicted_val_world_particles, val_transforms):
448449
particles = np.loadtxt(particle_file)
449-
local_particle_file = particle_file.replace("world_predictions/", "local_predictions/")
450+
local_particle_file = particle_file.replace("world_predictions", "local_predictions")
450451
local_particles = sw.utils.transformParticles(particles, transform, inverse=True)
451452
np.savetxt(local_particle_file, local_particles)
452453
predicted_val_local_particles.append(local_particle_file)
@@ -462,6 +463,8 @@ def Run_Pipeline(args):
462463
template_mesh = project_path + subjects[reference_index].get_groomed_filenames()[0]
463464
template_particles = project_path + subjects[reference_index].get_local_particle_filenames()[0]
464465
# Get distance between clipped true and predicted meshes
466+
# Get the validation output directory from the predictions path
467+
val_out_dir = os.path.dirname(local_val_prediction_dir.rstrip('/')) + '/'
465468
mean_dist = DeepSSMUtils.analyzeMeshDistance(predicted_val_local_particles, val_mesh_files,
466469
template_particles, template_mesh, val_out_dir,
467470
planes=val_planes)
@@ -500,17 +503,17 @@ def Run_Pipeline(args):
500503
with open(plane_file) as json_file:
501504
test_planes.append(json.load(json_file)['planes'][0]['points'])
502505

503-
test_out_dir = output_directory + model_name + '/test_predictions/'
504506
predicted_test_world_particles = DeepSSMUtils.testDeepSSM(config_file, loader='test')
505507
print("Test world predictions saved.")
506-
# Generate local predictions
507-
local_test_prediction_dir = test_out_dir + 'local_predictions/'
508+
# Generate local predictions - create directory next to world_predictions
509+
world_pred_dir = os.path.dirname(predicted_test_world_particles[0])
510+
local_test_prediction_dir = world_pred_dir.replace("world_predictions", "local_predictions")
508511
if not os.path.exists(local_test_prediction_dir):
509512
os.makedirs(local_test_prediction_dir)
510513
predicted_test_local_particles = []
511514
for particle_file, transform in zip(predicted_test_world_particles, test_transforms):
512515
particles = np.loadtxt(particle_file)
513-
local_particle_file = particle_file.replace("world_predictions/", "local_predictions/")
516+
local_particle_file = particle_file.replace("world_predictions", "local_predictions")
514517
local_particles = sw.utils.transformParticles(particles, transform, inverse=True)
515518
np.savetxt(local_particle_file, local_particles)
516519
predicted_test_local_particles.append(local_particle_file)
@@ -524,28 +527,53 @@ def Run_Pipeline(args):
524527
template_mesh = project_path + subjects[reference_index].get_groomed_filenames()[0]
525528
template_particles = project_path + subjects[reference_index].get_local_particle_filenames()[0]
526529

530+
# Get the test output directory from the predictions path
531+
test_out_dir = os.path.dirname(local_test_prediction_dir.rstrip('/')) + '/'
527532
mean_dist = DeepSSMUtils.analyzeMeshDistance(predicted_test_local_particles, test_mesh_files,
528533
template_particles, template_mesh, test_out_dir,
529534
planes=test_planes)
530535
print("Test mean mesh surface-to-surface distance: " + str(mean_dist))
531536

532-
DeepSSMUtils.process_test_predictions(project, config_file)
533-
537+
final_mean_dist = DeepSSMUtils.process_test_predictions(project, config_file)
538+
534539
# If tiny test or verify, check results and exit
535-
check_results(args, mean_dist)
540+
check_results(args, final_mean_dist, output_directory)
536541

537542
open(status_dir + "step_12.txt", 'w').close()
538543

539544
print("All steps complete")
540545

541546

542547
# Verification
543-
def check_results(args, mean_dist):
548+
def check_results(args, mean_dist, output_directory):
544549
if args.tiny_test:
545550
print("\nVerifying use case results.")
546-
if not math.isclose(mean_dist, 10, rel_tol=1):
547-
print("Test failed.")
548-
exit(-1)
551+
552+
exact_check_file = output_directory + "exact_check_value.txt"
553+
554+
# Exact check for refactoring verification (platform-specific)
555+
if args.exact_check == "save":
556+
with open(exact_check_file, "w") as f:
557+
f.write(str(mean_dist))
558+
print(f"Saved exact check value to: {exact_check_file}")
559+
print(f"Value: {mean_dist}")
560+
elif args.exact_check == "verify":
561+
if not os.path.exists(exact_check_file):
562+
print(f"Error: No saved value found at {exact_check_file}")
563+
print("Run with --exact_check save first to create baseline.")
564+
exit(-1)
565+
with open(exact_check_file, "r") as f:
566+
expected_mean_dist = float(f.read().strip())
567+
if mean_dist != expected_mean_dist:
568+
print(f"Exact check failed: expected {expected_mean_dist}, got {mean_dist}")
569+
exit(-1)
570+
print(f"Exact check passed: {mean_dist}")
571+
else:
572+
# Relaxed check for CI/cross-platform
573+
if not math.isclose(mean_dist, 10, rel_tol=1):
574+
print("Test failed.")
575+
exit(-1)
576+
549577
print("Done with test, verification succeeded.")
550578
exit(0)
551579
else:

Libs/Application/DeepSSM/DeepSSMJob.cpp

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -241,11 +241,12 @@ void DeepSSMJob::run_augmentation() {
241241
py::module py_deep_ssm_utils = py::module::import("DeepSSMUtils");
242242
py::object run_data_aug = py_deep_ssm_utils.attr("run_data_augmentation");
243243

244+
int processes = aug_processes_ > 0 ? aug_processes_ : QThread::idealThreadCount();
245+
244246
int aug_dims = run_data_aug(project_, params.get_aug_num_samples(),
245247
0 /* num dims, set to zero to allow percent variability to be used */,
246248
params.get_aug_percent_variability(), sampler_type.toStdString(), 0 /* mixture_num */,
247-
QThread::idealThreadCount() /* processes */
248-
)
249+
processes)
249250
.cast<int>();
250251

251252
params.set_training_num_dims(aug_dims);
@@ -394,6 +395,12 @@ void DeepSSMJob::set_num_dataloader_workers(int num_workers) { num_dataloader_wo
394395
//---------------------------------------------------------------------------
395396
int DeepSSMJob::get_num_dataloader_workers() {
  // Plain accessor for the stored data loader worker count.
  return num_dataloader_workers_;
}
396397

398+
//---------------------------------------------------------------------------
399+
void DeepSSMJob::set_aug_processes(int processes) {
  // Stored as given; run_augmentation() falls back to QThread::idealThreadCount()
  // when the stored value is not positive.
  aug_processes_ = processes;
}
400+
401+
//---------------------------------------------------------------------------
402+
int DeepSSMJob::get_aug_processes() {
  // Returns the raw stored value; no core-count substitution is applied here.
  return aug_processes_;
}
403+
397404
//---------------------------------------------------------------------------
398405
void DeepSSMJob::update_prep_stage(PrepStep step) {
399406
/*

Libs/Application/DeepSSM/DeepSSMJob.h

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,9 @@ class DeepSSMJob : public Job {
5555
void set_num_dataloader_workers(int num_workers);
5656
int get_num_dataloader_workers();
5757

58+
void set_aug_processes(int processes);
59+
int get_aug_processes();
60+
5861
void set_prep_step(DeepSSMJob::PrepStep step) {
5962
std::lock_guard<std::mutex> lock(mutex_);
6063
prep_step_ = step;
@@ -72,6 +75,7 @@ class DeepSSMJob : public Job {
7275
DeepSSMJob::PrepStep prep_step_{DeepSSMJob::NOT_STARTED};
7376

7477
int num_dataloader_workers_{0};
78+
int aug_processes_{0};
7579

7680
// mutex
7781
std::mutex mutex_;

Libs/Application/Job/PythonWorker.cpp

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -207,11 +207,10 @@ bool PythonWorker::init() {
207207
path = QString::fromStdString(line);
208208
}
209209
file.close();
210+
qputenv("PATH", path.toUtf8());
211+
SW_LOG("Setting PATH for Python to: " + path.toStdString());
210212
}
211213

212-
qputenv("PATH", path.toUtf8());
213-
SW_LOG("Setting PATH for Python to: " + path.toStdString());
214-
215214
// Python 3.8+ requires explicit DLL directory registration
216215
// PATH environment variable is no longer used for DLL search
217216
SetDefaultDllDirectories(LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_USER_DIRS);

Libs/Groom/Groom.cpp

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -186,7 +186,17 @@ bool Groom::image_pipeline(std::shared_ptr<Subject> subject, size_t domain) {
186186
std::string groomed_name = get_output_filename(original, DomainType::Image);
187187

188188
if (params.get_convert_to_mesh()) {
189+
// Use isovalue 0.0 for distance transforms (the zero level set is the surface)
189190
Mesh mesh = image.toMesh(0.0);
191+
if (mesh.numPoints() == 0) {
192+
throw std::runtime_error("Empty mesh generated from segmentation - segmentation may have no valid data");
193+
}
194+
// Check for valid cells
195+
auto poly_data = mesh.getVTKMesh();
196+
if (poly_data->GetNumberOfCells() == 0) {
197+
throw std::runtime_error("Mesh has no cells - segmentation may have no valid surface");
198+
}
199+
SW_DEBUG("Mesh after toMesh: {} points, {} cells", poly_data->GetNumberOfPoints(), poly_data->GetNumberOfCells());
190200
run_mesh_pipeline(mesh, params, original);
191201
groomed_name = get_output_filename(original, DomainType::Mesh);
192202
// save the groomed mesh
@@ -239,6 +249,9 @@ bool Groom::run_image_pipeline(Image& image, GroomParameters params) {
239249
// crop
240250
if (params.get_crop()) {
241251
PhysicalRegion region = image.physicalBoundingBox(0.5);
252+
if (!region.valid()) {
253+
throw std::runtime_error("Empty segmentation - no voxels found above threshold for cropping");
254+
}
242255
image.crop(region);
243256
increment_progress();
244257
}
@@ -1336,7 +1349,7 @@ std::vector<std::vector<double>> Groom::get_icp_transforms(const std::vector<Mes
13361349
matrix->Identity();
13371350

13381351
Mesh source = meshes[i];
1339-
if (source.getVTKMesh()->GetNumberOfPoints() != 0) {
1352+
if (source.getVTKMesh()->GetNumberOfPoints() != 0 && reference.getVTKMesh()->GetNumberOfPoints() != 0) {
13401353
// create copies for thread safety
13411354
auto poly_data1 = vtkSmartPointer<vtkPolyData>::New();
13421355
poly_data1->DeepCopy(source.getVTKMesh());

0 commit comments

Comments
 (0)