@@ -64,8 +64,8 @@ def Run_Pipeline(args):
6464 This data is comprised of femur meshes and corresponding hip CT scans.
6565 """
6666
67- if platform .system () == "Darwin " :
68- # On MacOS, CPU PyTorch is hanging with parallel
67+ if platform .system () != "Linux " :
68+ # CPU PyTorch hangs with OpenMP parallelism on macOS and Windows
6969 os .environ ['OMP_NUM_THREADS' ] = "1"
7070 # If running a tiny_test, then download subset of the data
7171 if args .tiny_test :
@@ -396,6 +396,7 @@ def Run_Pipeline(args):
396396 "c_lat" : 6.3
397397 }
398398 }
399+
399400 if args .tiny_test :
400401 model_parameters ["trainer" ]["epochs" ] = 1
401402 # Save config file
@@ -436,17 +437,17 @@ def Run_Pipeline(args):
436437 val_world_particles .append (project_path + subjects [index ].get_world_particle_filenames ()[0 ])
437438 val_mesh_files .append (project_path + subjects [index ].get_groomed_filenames ()[0 ])
438439
439- val_out_dir = output_directory + model_name + '/validation_predictions/'
440440 predicted_val_world_particles = DeepSSMUtils .testDeepSSM (config_file , loader = 'validation' )
441441 print ("Validation world predictions saved." )
442- # Generate local predictions
443- local_val_prediction_dir = val_out_dir + 'local_predictions/'
442+ # Generate local predictions - create directory next to world_predictions
443+ world_pred_dir = os .path .dirname (predicted_val_world_particles [0 ])
444+ local_val_prediction_dir = world_pred_dir .replace ("world_predictions" , "local_predictions" )
444445 if not os .path .exists (local_val_prediction_dir ):
445446 os .makedirs (local_val_prediction_dir )
446447 predicted_val_local_particles = []
447448 for particle_file , transform in zip (predicted_val_world_particles , val_transforms ):
448449 particles = np .loadtxt (particle_file )
449- local_particle_file = particle_file .replace ("world_predictions/ " , "local_predictions/ " )
450+ local_particle_file = particle_file .replace ("world_predictions" , "local_predictions" )
450451 local_particles = sw .utils .transformParticles (particles , transform , inverse = True )
451452 np .savetxt (local_particle_file , local_particles )
452453 predicted_val_local_particles .append (local_particle_file )
@@ -462,6 +463,8 @@ def Run_Pipeline(args):
462463 template_mesh = project_path + subjects [reference_index ].get_groomed_filenames ()[0 ]
463464 template_particles = project_path + subjects [reference_index ].get_local_particle_filenames ()[0 ]
464465 # Get distance between clipped true and predicted meshes
466+ # Get the validation output directory from the predictions path
467+ val_out_dir = os .path .dirname (local_val_prediction_dir .rstrip ('/' )) + '/'
465468 mean_dist = DeepSSMUtils .analyzeMeshDistance (predicted_val_local_particles , val_mesh_files ,
466469 template_particles , template_mesh , val_out_dir ,
467470 planes = val_planes )
@@ -500,17 +503,17 @@ def Run_Pipeline(args):
500503 with open (plane_file ) as json_file :
501504 test_planes .append (json .load (json_file )['planes' ][0 ]['points' ])
502505
503- test_out_dir = output_directory + model_name + '/test_predictions/'
504506 predicted_test_world_particles = DeepSSMUtils .testDeepSSM (config_file , loader = 'test' )
505507 print ("Test world predictions saved." )
506- # Generate local predictions
507- local_test_prediction_dir = test_out_dir + 'local_predictions/'
508+ # Generate local predictions - create directory next to world_predictions
509+ world_pred_dir = os .path .dirname (predicted_test_world_particles [0 ])
510+ local_test_prediction_dir = world_pred_dir .replace ("world_predictions" , "local_predictions" )
508511 if not os .path .exists (local_test_prediction_dir ):
509512 os .makedirs (local_test_prediction_dir )
510513 predicted_test_local_particles = []
511514 for particle_file , transform in zip (predicted_test_world_particles , test_transforms ):
512515 particles = np .loadtxt (particle_file )
513- local_particle_file = particle_file .replace ("world_predictions/ " , "local_predictions/ " )
516+ local_particle_file = particle_file .replace ("world_predictions" , "local_predictions" )
514517 local_particles = sw .utils .transformParticles (particles , transform , inverse = True )
515518 np .savetxt (local_particle_file , local_particles )
516519 predicted_test_local_particles .append (local_particle_file )
@@ -524,28 +527,53 @@ def Run_Pipeline(args):
524527 template_mesh = project_path + subjects [reference_index ].get_groomed_filenames ()[0 ]
525528 template_particles = project_path + subjects [reference_index ].get_local_particle_filenames ()[0 ]
526529
530+ # Get the test output directory from the predictions path
531+ test_out_dir = os .path .dirname (local_test_prediction_dir .rstrip ('/' )) + '/'
527532 mean_dist = DeepSSMUtils .analyzeMeshDistance (predicted_test_local_particles , test_mesh_files ,
528533 template_particles , template_mesh , test_out_dir ,
529534 planes = test_planes )
530535 print ("Test mean mesh surface-to-surface distance: " + str (mean_dist ))
531536
532- DeepSSMUtils .process_test_predictions (project , config_file )
533-
537+ final_mean_dist = DeepSSMUtils .process_test_predictions (project , config_file )
538+
534539 # If tiny test or verify, check results and exit
535- check_results (args , mean_dist )
540+ check_results (args , final_mean_dist , output_directory )
536541
537542 open (status_dir + "step_12.txt" , 'w' ).close ()
538543
539544 print ("All steps complete" )
540545
541546
542547# Verification
543- def check_results (args , mean_dist ):
548+ def check_results (args , mean_dist , output_directory ):
544549 if args .tiny_test :
545550 print ("\n Verifying use case results." )
546- if not math .isclose (mean_dist , 10 , rel_tol = 1 ):
547- print ("Test failed." )
548- exit (- 1 )
551+
552+ exact_check_file = output_directory + "exact_check_value.txt"
553+
554+ # Exact check for refactoring verification (platform-specific)
555+ if args .exact_check == "save" :
556+ with open (exact_check_file , "w" ) as f :
557+ f .write (str (mean_dist ))
558+ print (f"Saved exact check value to: { exact_check_file } " )
559+ print (f"Value: { mean_dist } " )
560+ elif args .exact_check == "verify" :
561+ if not os .path .exists (exact_check_file ):
562+ print (f"Error: No saved value found at { exact_check_file } " )
563+ print ("Run with --exact_check save first to create baseline." )
564+ exit (- 1 )
565+ with open (exact_check_file , "r" ) as f :
566+ expected_mean_dist = float (f .read ().strip ())
567+ if mean_dist != expected_mean_dist :
568+ print (f"Exact check failed: expected { expected_mean_dist } , got { mean_dist } " )
569+ exit (- 1 )
570+ print (f"Exact check passed: { mean_dist } " )
571+ else :
572+ # Relaxed check for CI/cross-platform
573+ if not math .isclose (mean_dist , 10 , rel_tol = 1 ):
574+ print ("Test failed." )
575+ exit (- 1 )
576+
549577 print ("Done with test, verification succeeded." )
550578 exit (0 )
551579 else :
0 commit comments