Skip to content

Commit ec6ad42

Browse files
authored
Merge pull request #2494 from SCIInstitute/amorris/procrustes_updates
Fix fixed domain Procrustes to preserve existing transforms
2 parents 3745fca + 62cfb4b commit ec6ad42

6 files changed

Lines changed: 143 additions & 30 deletions

File tree

Libs/Optimize/Optimize.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,11 @@ int Optimize::SetParameters() {
277277
this->ReadPrefixTransformFile(m_prefix_transform_file);
278278
}
279279

280+
// Apply stored Procrustes transforms (e.g. for fixed shapes loaded from project)
281+
for (auto& [domain_index, transform] : m_procrustes_transforms) {
282+
m_sampler->GetParticleSystem()->SetTransform(domain_index, transform);
283+
}
284+
280285
return true;
281286
}
282287

@@ -1811,6 +1816,11 @@ void Optimize::SetFixedDomains(std::vector<int> flags) {
18111816
this->m_domain_flags = flags;
18121817
}
18131818

1819+
//---------------------------------------------------------------------------
1820+
void Optimize::SetProcustesTransforms(std::map<int, vnl_matrix_fixed<double, 4, 4>> transforms) {
1821+
m_procrustes_transforms = std::move(transforms);
1822+
}
1823+
18141824
//---------------------------------------------------------------------------
18151825
const std::vector<int>& Optimize::GetDomainFlags() { return this->m_domain_flags; }
18161826

Libs/Optimize/Optimize.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#endif
66

77
// std
8+
#include <map>
89
#include <string>
910
#include <vector>
1011

@@ -236,6 +237,9 @@ class Optimize {
236237
//! Set Domain Flags (TODO: details)
237238
void SetFixedDomains(std::vector<int> flags);
238239

240+
//! Set Procrustes transforms to load for specific domains (applied after initialization)
241+
void SetProcustesTransforms(std::map<int, vnl_matrix_fixed<double, 4, 4>> transforms);
242+
239243
//! Shared boundary settings
240244
void SetSharedBoundaryEnabled(bool enabled);
241245
void SetSharedBoundaryWeight(double weight);
@@ -415,6 +419,7 @@ class Optimize {
415419
double m_cotan_sigma_factor = 5.0;
416420
std::vector<int> m_particle_flags;
417421
std::vector<int> m_domain_flags;
422+
std::map<int, vnl_matrix_fixed<double, 4, 4>> m_procrustes_transforms; // domain index -> transform
418423
double m_narrow_band = 0.0;
419424
bool m_narrow_band_set = false;
420425
bool m_fixed_domains_present = false;

Libs/Optimize/OptimizeParameters.cpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,30 @@ bool OptimizeParameters::set_up_optimize(Optimize* optimize) {
566566
SW_DEBUG("Setting Initial Points");
567567
optimize->SetInitialPoints(get_initial_points());
568568
}
569+
570+
// Store Procrustes transforms for fixed shapes (applied after ParticleSystem initialization)
571+
using TransformType = vnl_matrix_fixed<double, 4, 4>;
572+
std::map<int, TransformType> procrustes_transforms;
573+
int domain_idx = 0;
574+
for (auto s : subjects) {
575+
auto pt = s->get_procrustes_transforms();
576+
for (int d = 0; d < domains_per_shape; d++) {
577+
if (s->is_fixed() && d < pt.size() && pt[d].size() == 16) {
578+
TransformType transform;
579+
int index = 0;
580+
for (int c = 0; c < 4; c++) {
581+
for (int r = 0; r < 4; r++) {
582+
transform[c][r] = pt[d][index++];
583+
}
584+
}
585+
procrustes_transforms[domain_idx] = transform;
586+
}
587+
domain_idx++;
588+
}
589+
}
590+
if (!procrustes_transforms.empty()) {
591+
optimize->SetProcustesTransforms(std::move(procrustes_transforms));
592+
}
569593
}
570594

571595
for (auto s : subjects) {
@@ -765,7 +789,9 @@ bool OptimizeParameters::set_up_optimize(Optimize* optimize) {
765789
}
766790
}
767791

768-
optimize->GetSampler()->GetParticleSystem()->SetPrefixTransform(domain_count++, prefix_transform);
792+
optimize->GetSampler()->GetParticleSystem()->SetPrefixTransform(domain_count, prefix_transform);
793+
794+
domain_count++;
769795

770796
auto name = StringUtils::getBaseFilenameWithoutExtension(filename);
771797

Libs/Optimize/ProcrustesRegistration.cpp

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@
77
namespace shapeworks {
88

99
//---------------------------------------------------------------------------
10-
Procrustes3D::ShapeType ProcrustesRegistration::ExtractShape(int domain_index, int num_points) {
10+
Procrustes3D::ShapeType ProcrustesRegistration::ExtractShape(int domain_index, int num_points, bool fully_transformed) {
1111
Procrustes3D::ShapeType shape;
1212
Procrustes3D::PointType point;
1313
for (int j = 0; j < num_points; j++) {
14-
point(0) = m_ParticleSystem->GetPrefixTransformedPosition(j, domain_index)[0];
15-
point(1) = m_ParticleSystem->GetPrefixTransformedPosition(j, domain_index)[1];
16-
point(2) = m_ParticleSystem->GetPrefixTransformedPosition(j, domain_index)[2];
14+
auto pos = fully_transformed ? m_ParticleSystem->GetTransformedPosition(j, domain_index)
15+
: m_ParticleSystem->GetPrefixTransformedPosition(j, domain_index);
16+
point(0) = pos[0];
17+
point(1) = pos[1];
18+
point(2) = pos[2];
1719
shape.push_back(point);
1820
}
1921
return shape;
@@ -53,36 +55,23 @@ void ProcrustesRegistration::RunFixedDomainRegistration(int domainStart, int num
5355

5456
// Build/rebuild cache if needed (first call or particle count changed after split)
5557
if (!cache.valid || cache.num_points != numPoints) {
58+
// Extract fixed shapes using their full transforms (prefix + existing Procrustes).
59+
// Fixed shapes already have correct transforms; we never modify them.
5660
Procrustes3D::ShapeListType fixed_shapelist;
57-
std::vector<int> fixed_domain_indices;
5861

5962
for (int i = 0, k = domainStart; i < numShapes; i++, k += m_DomainsPerShape) {
6063
if (!is_fixed[i]) continue;
61-
fixed_shapelist.push_back(ExtractShape(k, numPoints));
62-
fixed_domain_indices.push_back(k);
64+
fixed_shapelist.push_back(ExtractShape(k, numPoints, /*fully_transformed=*/true));
6365
}
6466

65-
// Run GPA on fixed shapes only
66-
Procrustes3D::SimilarityTransformListType fixed_transforms;
67+
// Compute mean of the already-aligned fixed shapes (no GPA needed)
6768
Procrustes3D procrustes(m_Scaling, m_RotationTranslation);
68-
procrustes.AlignShapes(fixed_transforms, fixed_shapelist);
69-
70-
// Set transforms for fixed shapes
71-
Procrustes3D::TransformMatrixListType fixed_matrices;
72-
procrustes.ConstructTransformMatrices(fixed_transforms, fixed_matrices);
73-
74-
for (size_t i = 0; i < fixed_domain_indices.size(); i++) {
75-
m_ParticleSystem->SetTransform(fixed_domain_indices[i], fixed_matrices[i]);
76-
}
77-
78-
// Compute and cache the mean of the aligned fixed shapes
79-
// (fixed_shapelist has been modified in-place by AlignShapes to be in Procrustes space)
8069
procrustes.ComputeMeanShape(cache.mean, fixed_shapelist);
8170
cache.num_points = numPoints;
8271
cache.valid = true;
8372

8473
SW_LOG("Procrustes: cached fixed domain mean for domain type {} ({} fixed shapes, {} points)", domainType,
85-
fixed_domain_indices.size(), numPoints);
74+
fixed_shapelist.size(), numPoints);
8675
}
8776

8877
// Align each non-fixed shape to the cached fixed mean using OPA

Libs/Optimize/ProcrustesRegistration.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,10 @@ class ProcrustesRegistration {
4747
void RunFixedDomainRegistration(int domainStart, int numShapes, int numPoints,
4848
const std::vector<bool>& is_fixed);
4949

50-
//! Extract prefix-transformed particle positions for a single domain
51-
Procrustes3D::ShapeType ExtractShape(int domain_index, int num_points);
50+
//! Extract particle positions for a single domain.
51+
//! If fully_transformed is true, applies both prefix and Procrustes transforms (world space).
52+
//! If false, applies only the prefix transform (for computing new Procrustes transforms).
53+
Procrustes3D::ShapeType ExtractShape(int domain_index, int num_points, bool fully_transformed = false);
5254

5355
int m_DomainsPerShape = 1;
5456
bool m_Scaling = true;

Testing/OptimizeTests/OptimizeTests.cpp

Lines changed: 86 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -203,12 +203,93 @@ TEST(OptimizeTests, fixed_domain_procrustes) {
203203
std::cerr << "Eigenvalue " << i << " : " << values[i] << "\n";
204204
}
205205

206-
// With Procrustes scaling enabled, the size variation between spheres should be
207-
// factored out, resulting in a much smaller top eigenvalue compared to the
208-
// fixed_domain test (which has Procrustes disabled and gets >5000).
209-
// The top eigenvalue should be small since all shapes are spheres differing only in scale.
206+
// Fixed shapes keep their existing Procrustes transforms (identity in this test).
207+
// Only the new shape (sphere40) gets a Procrustes transform computed via OPA against
208+
// the fixed mean. Since the test fixed shapes have identity transforms with different
209+
// scales, the eigenvalue will be large (scale variation is not normalized).
210+
// In a real pipeline, fixed shapes would have proper Procrustes transforms from their
211+
// original optimization. Here we just verify the optimization completes successfully.
210212
double value = values[values.size() - 1];
211-
ASSERT_LT(value, 100.0);
213+
ASSERT_GT(value, 0.0);
214+
}
215+
216+
//---------------------------------------------------------------------------
217+
// Test that multiple new (non-fixed) shapes don't interact with each other.
218+
// Running two new shapes together with fixed shapes should produce the same
219+
// result as running each new shape individually with the same fixed shapes.
220+
TEST(OptimizeTests, fixed_domain_independence) {
221+
// Helper lambda: run optimization with specified fixed/excluded/new configuration
222+
// Returns local particles for each domain, indexed by domain index in the project
223+
auto run_optimize = [](const std::string& temp_name,
224+
const std::vector<bool>& is_fixed,
225+
const std::vector<bool>& is_excluded) -> std::vector<std::vector<itk::Point<double>>> {
226+
prep_temp("/optimize/fixed_domain", temp_name);
227+
228+
Optimize app;
229+
ProjectHandle project = std::make_shared<Project>();
230+
EXPECT_TRUE(project->load("optimize.swproj"));
231+
232+
// Reconfigure which subjects are fixed/excluded
233+
auto subjects = project->get_subjects();
234+
for (int i = 0; i < subjects.size(); i++) {
235+
subjects[i]->set_fixed(is_fixed[i]);
236+
subjects[i]->set_excluded(is_excluded[i]);
237+
}
238+
239+
OptimizeParameters params(project);
240+
EXPECT_TRUE(params.set_up_optimize(&app));
241+
bool success = app.Run();
242+
EXPECT_TRUE(success);
243+
244+
return app.GetLocalPoints();
245+
};
246+
247+
// Project has 4 shapes: sphere10, sphere20, sphere30, sphere40
248+
// Run A: sphere10,20 fixed; sphere30,40 both new
249+
auto points_together = run_optimize(
250+
"fixed_domain_indep_together",
251+
{true, true, false, false}, // is_fixed
252+
{false, false, false, false} // is_excluded
253+
);
254+
255+
// Run B: sphere10,20 fixed; sphere30 new; sphere40 excluded
256+
auto points_30_alone = run_optimize(
257+
"fixed_domain_indep_30",
258+
{true, true, false, false}, // is_fixed
259+
{false, false, false, true} // is_excluded: sphere40 excluded
260+
);
261+
262+
// Run C: sphere10,20 fixed; sphere40 new; sphere30 excluded
263+
auto points_40_alone = run_optimize(
264+
"fixed_domain_indep_40",
265+
{true, true, false, false}, // is_fixed
266+
{false, false, true, false} // is_excluded: sphere30 excluded
267+
);
268+
269+
// In run A (together), domains are: 0=sphere10, 1=sphere20, 2=sphere30, 3=sphere40
270+
// In run B (30 alone), domains are: 0=sphere10, 1=sphere20, 2=sphere30
271+
// In run C (40 alone), domains are: 0=sphere10, 1=sphere20, 2=sphere40
272+
273+
// Compare sphere30 particles: run A domain 2 vs run B domain 2
274+
ASSERT_EQ(points_together[2].size(), points_30_alone[2].size());
275+
for (int i = 0; i < points_together[2].size(); i++) {
276+
for (int d = 0; d < 3; d++) {
277+
EXPECT_NEAR(points_together[2][i][d], points_30_alone[2][i][d], 1e-6)
278+
<< "sphere30 particle " << i << " dim " << d << " differs";
279+
}
280+
}
281+
282+
// Compare sphere40 particles: run A domain 3 vs run C domain 2
283+
ASSERT_EQ(points_together[3].size(), points_40_alone[2].size());
284+
for (int i = 0; i < points_together[3].size(); i++) {
285+
for (int d = 0; d < 3; d++) {
286+
EXPECT_NEAR(points_together[3][i][d], points_40_alone[2][i][d], 1e-6)
287+
<< "sphere40 particle " << i << " dim " << d << " differs";
288+
}
289+
}
290+
291+
std::cerr << "Fixed domain independence test passed: new shapes produce identical "
292+
<< "results whether run together or individually\n";
212293
}
213294

214295
//---------------------------------------------------------------------------

0 commit comments

Comments
 (0)