Skip to content

Commit 82dbe62

Browse files
authored
Merge pull request #2480 from SCIInstitute/sampling_scale
Add sampling gradient scaling based on surface area normalization
2 parents 15e22e6 + dfd9ca8 commit 82dbe62

17 files changed

Lines changed: 626 additions & 328 deletions

Libs/Groom/Groom.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,14 @@ bool Groom::run_mesh_pipeline(Mesh& mesh, GroomParameters params, const std::str
378378
}
379379

380380
if (params.get_remesh()) {
381+
auto poly_data = mesh.getVTKMesh();
382+
if (poly_data->GetNumberOfCells() == 0 || poly_data->GetCell(0)->GetNumberOfPoints() == 2) {
383+
SW_DEBUG("Number of cells: {}", poly_data->GetNumberOfCells());
384+
if (poly_data->GetNumberOfCells() > 0) {
385+
SW_DEBUG("Number of points in first cell: {}", poly_data->GetCell(0)->GetNumberOfPoints());
386+
}
387+
throw std::runtime_error("malformed mesh, mesh should be triangular");
388+
}
381389
int total_vertices = mesh.getVTKMesh()->GetNumberOfPoints();
382390
int num_vertices = params.get_remesh_num_vertices();
383391
if (params.get_remesh_percent_mode()) {

Libs/Mesh/Mesh.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include <vtkIncrementalPointLocator.h>
3434
#include <vtkKdTreePointLocator.h>
3535
#include <vtkLoopSubdivisionFilter.h>
36+
#include <vtkMassProperties.h>
3637
#include <vtkNew.h>
3738
#include <vtkOBJReader.h>
3839
#include <vtkOBJWriter.h>
@@ -1198,6 +1199,13 @@ Point3 Mesh::centerOfMass() const {
11981199
return center;
11991200
}
12001201

1202+
double Mesh::getSurfaceArea() const {
1203+
auto mass_props = vtkSmartPointer<vtkMassProperties>::New();
1204+
mass_props->SetInputData(this->poly_data_);
1205+
mass_props->Update();
1206+
return mass_props->GetSurfaceArea();
1207+
}
1208+
12011209
Point3 Mesh::getPoint(int id) const {
12021210
if (this->numPoints() < id) {
12031211
throw std::invalid_argument("mesh has fewer indices than requested");

Libs/Mesh/Mesh.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,9 @@ class Mesh {
206206
/// center of mass of mesh
207207
Point3 centerOfMass() const;
208208

209+
/// surface area of mesh
210+
double getSurfaceArea() const;
211+
209212
/// number of points
210213
int numPoints() const { return poly_data_->GetNumberOfPoints(); }
211214

Libs/Optimize/Domain/ContourDomain.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,10 @@ class ContourDomain : public ParticleDomain {
8585
return out;
8686
}
8787

88-
double GetSurfaceArea() const override { throw std::runtime_error("Contours do not have area"); }
88+
double GetSurfaceArea() const override {
89+
// TODO: Implement something analogous for scaling purposes
90+
return 1.0;
91+
}
8992

9093
void DeleteImages() override {
9194
// TODO what?
@@ -100,7 +103,6 @@ class ContourDomain : public ParticleDomain {
100103
const VectorDoubleType& global_direction, double epsilon) const override;
101104

102105
private:
103-
104106
double ComputeLineCoordinate(const double pt[3], int line) const;
105107

106108
// Return the number of lines that consist of i-th point
@@ -136,7 +138,6 @@ class ContourDomain : public ParticleDomain {
136138
void ComputeAvgEdgeLength();
137139

138140
int GetLineForPoint(const double pt[3], int idx, double& closest_distance, double closest_pt[3]) const;
139-
140141
};
141142

142143
} // namespace shapeworks

Libs/Optimize/Domain/ImageDomain.h

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,7 @@ class ImageDomain : public ParticleRegionDomain {
111111
this->UpdateSurfaceArea(I);
112112
}
113113

114-
inline double GetSurfaceArea() const override {
115-
throw std::runtime_error("Surface area is not computed currently.");
114+
double GetSurfaceArea() const override {
116115
return m_SurfaceArea;
117116
}
118117

@@ -137,7 +136,8 @@ class ImageDomain : public ParticleRegionDomain {
137136
} else {
138137
std::ostringstream message;
139138
message << "Domain " << m_DomainID << ": " << m_DomainName << " : Distance transform queried for a Point, " << p
140-
<< ", outside the given image domain. Consider increasing the padding in grooming or the narrow band optimization parameter";
139+
<< ", outside the given image domain. Consider increasing the padding in grooming or the narrow band "
140+
"optimization parameter";
141141
throw std::runtime_error(message.str());
142142
}
143143
}
@@ -179,7 +179,7 @@ class ImageDomain : public ParticleRegionDomain {
179179
openvdb::FloatGrid::Ptr GetVDBImage() const { return m_VDBImage; }
180180

181181
ImageDomain() {}
182-
virtual ~ImageDomain(){};
182+
virtual ~ImageDomain() {};
183183

184184
void PrintSelf(std::ostream& os, itk::Indent indent) const {
185185
ParticleRegionDomain::PrintSelf(os, indent);
@@ -241,20 +241,8 @@ class ImageDomain : public ParticleRegionDomain {
241241
}
242242

243243
void UpdateSurfaceArea(ImageType* I) {
244-
// TODO: This code has been copied from Optimize.cpp. It does not work
245-
/*
246-
typename itk::ImageToVTKImageFilter < ImageType > ::Pointer itk2vtkConnector;
247-
itk2vtkConnector = itk::ImageToVTKImageFilter < ImageType > ::New();
248-
itk2vtkConnector->SetInput(I);
249-
vtkSmartPointer < vtkContourFilter > ls = vtkSmartPointer < vtkContourFilter > ::New();
250-
ls->SetInputData(itk2vtkConnector->GetOutput());
251-
ls->SetValue(0, 0.0);
252-
ls->Update();
253-
vtkSmartPointer < vtkMassProperties > mp = vtkSmartPointer < vtkMassProperties > ::New();
254-
mp->SetInputData(ls->GetOutput());
255-
mp->Update();
256-
m_SurfaceArea = mp->GetSurfaceArea();
257-
*/
244+
Image image(I);
245+
m_SurfaceArea = image.toMesh(0).getSurfaceArea();
258246
}
259247
};
260248

Libs/Optimize/Domain/MeshDomain.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,9 @@ void MeshDomain::SetMesh(std::shared_ptr<Surface> mesh, double geodesic_remesh_p
8585
surface_ = mesh;
8686
sw_mesh_ = std::make_shared<Mesh>(surface_->get_polydata());
8787

88-
if (geodesic_remesh_percent >= 100.0) { // no remeshing
88+
surface_area_ = sw_mesh_->getSurfaceArea();
89+
90+
if (geodesic_remesh_percent >= 100.0) { // no remeshing
8991
geodesics_mesh_ = surface_;
9092
} else {
9193
auto poly_data = surface_->get_polydata();

Libs/Optimize/Domain/MeshDomain.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,7 @@ class MeshDomain : public ParticleDomain {
4545
PointType GetValidLocationNear(PointType p) const override;
4646

4747
double GetSurfaceArea() const override {
48-
// TODO return actual surface area
49-
return 0;
48+
return surface_area_;
5049
}
5150

5251
double GetMaxDiameter() const override;
@@ -86,6 +85,7 @@ class MeshDomain : public ParticleDomain {
8685
std::shared_ptr<Surface> geodesics_mesh_;
8786
std::shared_ptr<Mesh> sw_mesh_;
8887
PointType zero_crossing_point_;
88+
double surface_area_ = 0.0;
8989
};
9090

9191
} // namespace shapeworks

Libs/Optimize/Function/SamplingFunction.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11

22
#include "SamplingFunction.h"
33

4+
#include <set>
5+
46
#include "Libs/Common/Logging.h"
57
#include "Libs/Optimize/Domain/DomainType.h"
68
#include "vnl/vnl_vector_fixed.h"
@@ -85,6 +87,9 @@ std::shared_ptr<VectorFunction> SamplingFunction::clone() {
8587
copy->m_avgKappa = m_avgKappa;
8688
copy->m_IsSharedBoundaryEnabled = m_IsSharedBoundaryEnabled;
8789
copy->m_SharedBoundaryWeight = m_SharedBoundaryWeight;
90+
copy->m_SamplingScale = m_SamplingScale;
91+
copy->m_SamplingAutoScale = m_SamplingAutoScale;
92+
copy->m_SamplingScaleValue = m_SamplingScaleValue;
8893
copy->m_CurrentSigma = m_CurrentSigma;
8994
copy->m_CurrentNeighborhood = m_CurrentNeighborhood;
9095

@@ -244,6 +249,37 @@ SamplingFunction::VectorType SamplingFunction::evaluate(unsigned int idx, unsign
244249
energy = (A * sigma2inv) / m_avgKappa;
245250

246251
gradE = gradE / m_avgKappa;
252+
253+
// Apply sampling scale if enabled
254+
if (m_SamplingScale) {
255+
double scale_factor = 1.0;
256+
257+
if (m_SamplingAutoScale) {
258+
// Get surface area from the domain
259+
double surface_area = system->GetDomain(d)->GetSurfaceArea();
260+
261+
// Reference surface area for scale factor of 1.0
262+
// This constant will be tuned based on experiments
263+
constexpr double reference_surface_area = 12250.0;
264+
265+
// Scale factor is proportional to surface area
266+
scale_factor = surface_area / reference_surface_area;
267+
268+
// Log once per domain using static set
269+
static std::set<int> logged_domains;
270+
if (logged_domains.find(d) == logged_domains.end()) {
271+
logged_domains.insert(d);
272+
SW_DEBUG("SamplingFunction: Auto scale for domain " + std::to_string(d) + ", surface_area = " +
273+
std::to_string(surface_area) + ", scale_factor = " + std::to_string(scale_factor));
274+
}
275+
}
276+
277+
// multiply by scaling value (whether auto is on or off)
278+
scale_factor *= m_SamplingScaleValue;
279+
280+
gradE = gradE * scale_factor;
281+
}
282+
247283
return gradE;
248284
}
249285

Libs/Optimize/Function/SamplingFunction.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,15 @@ class SamplingFunction : public VectorFunction {
5353
void SetSharedBoundaryEnabled(bool enabled) { m_IsSharedBoundaryEnabled = enabled; }
5454
bool GetSharedBoundaryEnabled() const { return m_IsSharedBoundaryEnabled; }
5555

56+
void SetSamplingScale(bool enabled) { m_SamplingScale = enabled; }
57+
bool GetSamplingScale() const { return m_SamplingScale; }
58+
59+
void SetSamplingAutoScale(bool auto_scale) { m_SamplingAutoScale = auto_scale; }
60+
bool GetSamplingAutoScale() const { return m_SamplingAutoScale; }
61+
62+
void SetSamplingScaleValue(double scale_value) { m_SamplingScaleValue = scale_value; }
63+
double GetSamplingScaleValue() const { return m_SamplingScaleValue; }
64+
5665
/**Access the cache of sigma values for each particle position. This cache
5766
is populated by registering this object as an observer of the correct
5867
particle system (see SetParticleSystem).*/
@@ -105,6 +114,9 @@ class SamplingFunction : public VectorFunction {
105114
double m_avgKappa{0};
106115
bool m_IsSharedBoundaryEnabled{false};
107116
double m_SharedBoundaryWeight{1.0};
117+
bool m_SamplingScale{true};
118+
bool m_SamplingAutoScale{true};
119+
double m_SamplingScaleValue{1.0};
108120
double m_CurrentSigma{0.0};
109121
float m_MaxMoveFactor{0};
110122
double m_MinimumNeighborhoodRadius{0};

Libs/Optimize/Optimize.cpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,10 @@ void Optimize::InitializeSampler() {
435435

436436
m_sampler->SetCorrespondenceOn();
437437

438+
m_sampler->SetSamplingScale(m_sampling_scale);
439+
m_sampler->SetSamplingAutoScale(m_sampling_auto_scale);
440+
m_sampler->SetSamplingScaleValue(m_sampling_scale_value);
441+
438442
m_sampler->SetAdaptivityMode();
439443
m_sampler->GetEnsembleEntropyFunction()->SetRecomputeCovarianceInterval(m_recompute_regularization_interval);
440444
m_sampler->GetDisentangledEnsembleEntropyFunction()->SetRecomputeCovarianceInterval(
@@ -1708,7 +1712,11 @@ void Optimize::AddImage(ImageType::Pointer image, std::string name) {
17081712
m_sampler->AddImage(image, this->GetNarrowBand(), name);
17091713
this->m_num_shapes++;
17101714
if (image) {
1711-
this->m_spacing = image->GetSpacing()[0] * 5;
1715+
double new_spacing = image->GetSpacing()[0] * 5;
1716+
if (m_spacing == 0 || new_spacing < this->m_spacing) {
1717+
// pick the smallest spacing
1718+
m_spacing = new_spacing;
1719+
}
17121720
}
17131721
}
17141722

@@ -2019,6 +2027,24 @@ void Optimize::SetSharedBoundaryEnabled(bool enabled) { m_sampler->SetSharedBoun
20192027
//---------------------------------------------------------------------------
20202028
void Optimize::SetSharedBoundaryWeight(double weight) { m_sampler->SetSharedBoundaryWeight(weight); }
20212029

2030+
//---------------------------------------------------------------------------
2031+
void Optimize::SetSamplingScale(bool enabled) { m_sampling_scale = enabled; }
2032+
2033+
//---------------------------------------------------------------------------
2034+
bool Optimize::GetSamplingScale() { return m_sampling_scale; }
2035+
2036+
//---------------------------------------------------------------------------
2037+
void Optimize::SetSamplingAutoScale(bool auto_scale) { m_sampling_auto_scale = auto_scale; }
2038+
2039+
//---------------------------------------------------------------------------
2040+
bool Optimize::GetSamplingAutoScale() { return m_sampling_auto_scale; }
2041+
2042+
//---------------------------------------------------------------------------
2043+
void Optimize::SetSamplingScaleValue(double scale_value) { m_sampling_scale_value = scale_value; }
2044+
2045+
//---------------------------------------------------------------------------
2046+
double Optimize::GetSamplingScaleValue() { return m_sampling_scale_value; }
2047+
20222048
//---------------------------------------------------------------------------
20232049
void Optimize::SetEarlyStoppingConfig(EarlyStoppingConfig config) { m_sampler->SetEarlyStoppingConfig(config); }
20242050

0 commit comments

Comments
 (0)