From 9e9d31c9b1f16ba1e368677586572e1bf89c6e97 Mon Sep 17 00:00:00 2001 From: Pranjal Sahu Date: Tue, 7 Jun 2022 23:43:04 -0400 Subject: [PATCH 1/5] ENH: Adding PointToPlane metric --- ...tkPointToPlanePointSetToPointSetMetricv4.h | 105 ++++++++++++++++++ ...PointToPlanePointSetToPointSetMetricv4.hxx | 72 ++++++++++++ ...itkEuclideanDistancePointSetMetricTest.cxx | 3 +- 3 files changed, 179 insertions(+), 1 deletion(-) create mode 100644 Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.h create mode 100644 Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.hxx diff --git a/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.h b/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.h new file mode 100644 index 00000000000..7dc3c3c0801 --- /dev/null +++ b/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.h @@ -0,0 +1,105 @@ +/*========================================================================= + * + * Copyright NumFOCUS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + *=========================================================================*/ +#ifndef itkPointToPlanePointSetToPointSetMetricv4_h +#define itkPointToPlanePointSetToPointSetMetricv4_h + +#include "itkPointSetToPointSetMetricv4.h" + +namespace itk +{ +/** \class PointToPlanePointSetToPointSetMetricv4 + * \brief Computes the Euclidan distance metric between two point sets. + * + * Given two point sets the Euclidean distance metric (i.e. ICP) is + * defined to be the aggregate of all shortest distances between all + * possible pairings of points between the two sets. + * + * We only have to handle the individual point case as the parent + * class handles the aggregation. + * + * Reference: + * PJ Besl and ND McKay, "A Method for Registration of 3-D Shapes", + * IEEE PAMI, Vol 14, No. 2, February 1992 + * + * \ingroup ITKMetricsv4 + */ +template +class ITK_TEMPLATE_EXPORT PointToPlanePointSetToPointSetMetricv4 + : public PointSetToPointSetMetricv4 +{ +public: + ITK_DISALLOW_COPY_AND_MOVE(PointToPlanePointSetToPointSetMetricv4); + + /** Standard class type aliases. */ + using Self = PointToPlanePointSetToPointSetMetricv4; + using Superclass = PointSetToPointSetMetricv4; + using Pointer = SmartPointer; + using ConstPointer = SmartPointer; + + /** Method for creation through the object factory. */ + itkNewMacro(Self); + + /** Run-time type information (and related methods). */ + itkTypeMacro(PointToPlanePointSetToPointSetMetricv4, PointSetToPointSetMetricv4); + + /** Types transferred from the base class */ + using typename Superclass::MeasureType; + using typename Superclass::DerivativeType; + using typename Superclass::LocalDerivativeType; + using typename Superclass::PointType; + using typename Superclass::PixelType; + using typename Superclass::PointIdentifier; + + /** + * Calculates the local metric value for a single point. + */ + MeasureType + GetLocalNeighborhoodValue(const PointType &, const PixelType & pixel = 0) const override; + + /** + * Calculates the local value and derivative for a single point. + */ + void + GetLocalNeighborhoodValueAndDerivative(const PointType &, + MeasureType &, + LocalDerivativeType &, + const PixelType & pixel = 0) const override; + +protected: + PointToPlanePointSetToPointSetMetricv4() = default; + ~PointToPlanePointSetToPointSetMetricv4() override = default; + + bool + RequiresFixedPointsLocator() const override + { + return false; + } + + /** PrintSelf function */ + void + PrintSelf(std::ostream & os, Indent indent) const override; +}; +} // end namespace itk + +#ifndef ITK_MANUAL_INSTANTIATION +# include "itkPointToPlanePointSetToPointSetMetricv4.hxx" +#endif + +#endif diff --git a/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.hxx b/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.hxx new file mode 100644 index 00000000000..e60b7812491 --- /dev/null +++ b/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.hxx @@ -0,0 +1,72 @@ +/*========================================================================= + * + * Copyright NumFOCUS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + *=========================================================================*/ +#ifndef itkPointToPlanePointSetToPointSetMetricv4_hxx +#define itkPointToPlanePointSetToPointSetMetricv4_hxx + + +namespace itk +{ + +template +typename PointToPlanePointSetToPointSetMetricv4:: + MeasureType + PointToPlanePointSetToPointSetMetricv4:: + GetLocalNeighborhoodValue(const PointType & point, const PixelType & itkNotUsed(pixel)) const +{ + PointType closestPoint; + closestPoint.Fill(0.0); + + PointIdentifier pointId = this->m_MovingTransformedPointsLocator->FindClosestPoint(point); + closestPoint = this->m_MovingTransformedPointSet->GetPoint(pointId); + + const MeasureType distance = point.EuclideanDistanceTo(closestPoint); + return distance; +} + +template +void +PointToPlanePointSetToPointSetMetricv4:: + GetLocalNeighborhoodValueAndDerivative(const PointType & point, + MeasureType & measure, + LocalDerivativeType & localDerivative, + const PixelType & itkNotUsed(pixel)) const +{ + PointType closestPoint; + closestPoint.Fill(0.0); + + + PointIdentifier pointId = this->m_MovingTransformedPointsLocator->FindClosestPoint(point); + closestPoint = this->m_MovingTransformedPointSet->GetPoint(pointId); + + measure = point.EuclideanDistanceTo(closestPoint); + localDerivative = closestPoint - point; +} + +/** PrintSelf method */ +template +void +PointToPlanePointSetToPointSetMetricv4::PrintSelf( + std::ostream & os, + Indent indent) const +{ + Superclass::PrintSelf(os, indent); +} + +} // end namespace itk + +#endif diff --git a/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricTest.cxx b/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricTest.cxx index 5cdb8acb447..aab39ed4d71 100644 --- a/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricTest.cxx +++ b/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricTest.cxx @@ -17,6 +17,7 @@ *=========================================================================*/ #include "itkEuclideanDistancePointSetToPointSetMetricv4.h" +#include "itkPointToPlanePointSetToPointSetMetricv4.h" #include "itkTranslationTransform.h" #include @@ -75,7 +76,7 @@ itkEuclideanDistancePointSetMetricTestRun() translationTransform->SetIdentity(); // Instantiate the metric - using PointSetMetricType = itk::EuclideanDistancePointSetToPointSetMetricv4; + using PointSetMetricType = itk::PointToPlanePointSetToPointSetMetricv4; auto metric = PointSetMetricType::New(); metric->SetFixedPointSet(fixedPoints); metric->SetMovingPointSet(movingPoints); From 5180af53914605910eed5dbc3535cf3e189f9a9d Mon Sep 17 00:00:00 2001 From: Pranjal Sahu Date: Wed, 8 Jun 2022 00:12:28 -0400 Subject: [PATCH 2/5] ENH: Able to use the pixel data. Now store the gradient and use it for ICP. --- .../itkPointSetToPointSetMetricWithIndexv4.hxx | 16 ++++++++-------- .../itkPointToPlanePointSetToPointSetMetricv4.h | 2 +- ...itkPointToPlanePointSetToPointSetMetricv4.hxx | 7 +++++++ ...eanDistancePointSetMetricRegistrationTest.cxx | 5 +++-- 4 files changed, 19 insertions(+), 11 deletions(-) diff --git a/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx index 03d07be9453..4d981eb1c27 100644 --- a/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx +++ b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx @@ -199,10 +199,10 @@ typename PointSetToPointSetMetricWithIndexv4m_UsePointSetData) { bool doesPointDataExist = this->m_FixedPointSet->GetPointData(index, &pixel); - if (!doesPointDataExist) - { - itkExceptionMacro("The corresponding data for point (pointId = " << index << ") does not exist."); - } + // if (!doesPointDataExist) + // { + // itkExceptionMacro("The corresponding data for point (pointId = " << index << ") does not exist."); + // } } threadValue += this->GetLocalNeighborhoodValueWithIndex(index, fixedTransformedPointSet[index], pixel); } @@ -319,10 +319,10 @@ PointSetToPointSetMetricWithIndexv4m_UsePointSetData) { bool doesPointDataExist = this->m_FixedPointSet->GetPointData(index, &pixel); - if (!doesPointDataExist) - { - itkExceptionMacro("The corresponding data for point with id " << index << " does not exist."); - } + // if (!doesPointDataExist) + // { + // itkExceptionMacro("The corresponding data for point with id " << index << " does not exist."); + // } } if (calculateValue) diff --git a/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.h b/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.h index 7dc3c3c0801..154f8ba35c6 100644 --- a/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.h +++ b/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.h @@ -83,7 +83,7 @@ class ITK_TEMPLATE_EXPORT PointToPlanePointSetToPointSetMetricv4 const PixelType & pixel = 0) const override; protected: - PointToPlanePointSetToPointSetMetricv4() = default; + PointToPlanePointSetToPointSetMetricv4(); ~PointToPlanePointSetToPointSetMetricv4() override = default; bool diff --git a/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.hxx b/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.hxx index e60b7812491..1e2c2362d94 100644 --- a/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.hxx +++ b/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.hxx @@ -22,6 +22,13 @@ namespace itk { +template +PointToPlanePointSetToPointSetMetricv4:: + PointToPlanePointSetToPointSetMetricv4() +{ + this->m_UsePointSetData = true; +} + template typename PointToPlanePointSetToPointSetMetricv4:: MeasureType diff --git a/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricRegistrationTest.cxx b/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricRegistrationTest.cxx index 946ee5e4e29..4f076344201 100644 --- a/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricRegistrationTest.cxx +++ b/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricRegistrationTest.cxx @@ -17,6 +17,7 @@ *=========================================================================*/ #include "itkEuclideanDistancePointSetToPointSetMetricv4.h" +#include "itkPointToPlanePointSetToPointSetMetricv4.h" #include "itkGradientDescentOptimizerv4.h" #include "itkRegistrationParameterScalesFromPhysicalShift.h" #include "itkAffineTransform.h" @@ -225,7 +226,7 @@ itkEuclideanDistancePointSetMetricRegistrationTest(int argc, char * argv[]) // metric using PointSetType = itk::PointSet; - using PointSetMetricType = itk::EuclideanDistancePointSetToPointSetMetricv4; + using PointSetMetricType = itk::PointToPlanePointSetToPointSetMetricv4; auto metric = PointSetMetricType::New(); // transform @@ -291,7 +292,7 @@ itkEuclideanDistancePointSetMetricRegistrationTest(int argc, char * argv[]) displacementTransform->SetDisplacementField(displacementField); // metric - using PointSetMetricType = itk::EuclideanDistancePointSetToPointSetMetricv4; + using PointSetMetricType = itk::PointToPlanePointSetToPointSetMetricv4; auto metric2 = PointSetMetricType::New(); // If we don't set the virtual domain when using a displacement field transform, the // metric takes it from the transform during initialization. From 977f238efa7596ae7f56e9f6d5d8ad01219e5b63 Mon Sep 17 00:00:00 2001 From: Pranjal Sahu Date: Wed, 8 Jun 2022 14:45:43 -0400 Subject: [PATCH 3/5] ENH: Changes to test Point to Plane ICP --- ...itkPointSetToPointSetMetricWithIndexv4.hxx | 4 +- ...tkPointToPlanePointSetToPointSetMetricv4.h | 4 +- ...PointToPlanePointSetToPointSetMetricv4.hxx | 18 +- ...DistancePointSetMetricRegistrationTest.cxx | 195 ++++++++++-------- ...itkEuclideanDistancePointSetMetricTest.cxx | 4 +- 5 files changed, 129 insertions(+), 96 deletions(-) diff --git a/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx index 4d981eb1c27..a4bf6ea8f68 100644 --- a/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx +++ b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx @@ -199,6 +199,7 @@ typename PointSetToPointSetMetricWithIndexv4m_UsePointSetData) { bool doesPointDataExist = this->m_FixedPointSet->GetPointData(index, &pixel); + // std::cout << "point data is " << pixel << std::endl; // if (!doesPointDataExist) // { // itkExceptionMacro("The corresponding data for point (pointId = " << index << ") does not exist."); @@ -302,7 +303,7 @@ PointSetToPointSetMetricWithIndexv4 threadValue; PixelType pixel; - NumericTraits::SetLength(pixel, 1); + // NumericTraits::SetLength(pixel, 1); for (PointIdentifier index = ranges[rangeIndex].first; index < ranges[rangeIndex].second; ++index) { MeasureType pointValue = NumericTraits::ZeroValue(); @@ -319,6 +320,7 @@ PointSetToPointSetMetricWithIndexv4m_UsePointSetData) { bool doesPointDataExist = this->m_FixedPointSet->GetPointData(index, &pixel); + // std::cout << "point data is " << pixel << std::endl; // if (!doesPointDataExist) // { // itkExceptionMacro("The corresponding data for point with id " << index << " does not exist."); diff --git a/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.h b/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.h index 154f8ba35c6..296707cbeea 100644 --- a/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.h +++ b/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.h @@ -71,7 +71,7 @@ class ITK_TEMPLATE_EXPORT PointToPlanePointSetToPointSetMetricv4 * Calculates the local metric value for a single point. */ MeasureType - GetLocalNeighborhoodValue(const PointType &, const PixelType & pixel = 0) const override; + GetLocalNeighborhoodValue(const PointType &, const PixelType & pixel) const override; /** * Calculates the local value and derivative for a single point. @@ -80,7 +80,7 @@ class ITK_TEMPLATE_EXPORT PointToPlanePointSetToPointSetMetricv4 GetLocalNeighborhoodValueAndDerivative(const PointType &, MeasureType &, LocalDerivativeType &, - const PixelType & pixel = 0) const override; + const PixelType & pixel) const override; protected: PointToPlanePointSetToPointSetMetricv4(); diff --git a/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.hxx b/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.hxx index 1e2c2362d94..922c02d99e7 100644 --- a/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.hxx +++ b/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.hxx @@ -51,17 +51,29 @@ PointToPlanePointSetToPointSetMetricv4m_MovingTransformedPointsLocator->FindClosestPoint(point); closestPoint = this->m_MovingTransformedPointSet->GetPoint(pointId); - measure = point.EuclideanDistanceTo(closestPoint); + measure = 0; + auto temp_diff = closestPoint - point; + for (int i = 0; i < pixel.Size(); ++i) + { + measure = measure + temp_diff[i] * pixel[i]; + } + measure = measure * measure; + localDerivative = closestPoint - point; + + // Perform dot product with the normal vector + for (int i = 0; i < pixel.Size(); ++i) + { + localDerivative[i] = localDerivative[i] * pixel[i]; + } } /** PrintSelf method */ diff --git a/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricRegistrationTest.cxx b/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricRegistrationTest.cxx index 4f076344201..f80dfb0b9d3 100644 --- a/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricRegistrationTest.cxx +++ b/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricRegistrationTest.cxx @@ -16,11 +16,12 @@ * *=========================================================================*/ -#include "itkEuclideanDistancePointSetToPointSetMetricv4.h" +//#include "itkEuclideanDistancePointSetToPointSetMetricv4.h" #include "itkPointToPlanePointSetToPointSetMetricv4.h" #include "itkGradientDescentOptimizerv4.h" #include "itkRegistrationParameterScalesFromPhysicalShift.h" #include "itkAffineTransform.h" +#include "itkCenteredRigid2DTransform.h" #include "itkCommand.h" #include "itkMath.h" @@ -86,23 +87,39 @@ itkEuclideanDistancePointSetMetricRegistrationTestRun(unsigned int // Create a few points and apply a small rotation to make the moving point set - float theta = itk::Math::pi / static_cast(180.0) * static_cast(1.0); + int num_of_points = 30; + + float x[num_of_points] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, + 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0 }; + + float y[num_of_points] = { 0.0, 0.096, 0.337, 0.598, 0.727, 0.598, 0.169, -0.491, -1.211, -1.76, + -1.918, -1.552, -0.671, 0.559, 1.84, 2.814, 3.166, 2.715, 1.484, -0.286, + -2.176, -3.695, -4.4, -4.027, -2.576, -0.332, 2.185, 4.34, 5.547, 5.422 }; + + float nx[num_of_points] = { 0.0, -0.166, -0.2437, -0.1918, 0.0, 0.2688, 0.4784, 0.568, 0.5356, 0.3333, + -0.1031, -0.5292, -0.726, -0.7821, -0.7481, -0.5527, 0.0495, 0.6437, 0.8321, 0.8775, + 0.8625, 0.7435, 0.1639, -0.6739, -0.8795, -0.9219, -0.9193, -0.8595, -0.4758, 0.0 }; + + float ny[num_of_points] = { 0.0, 0.9861, 0.9698, 0.9814, 1.0, 0.9632, 0.8781, 0.823, 0.8445, 0.9428, + 0.9947, 0.8485, 0.6877, 0.6231, 0.6636, 0.8334, 0.9988, 0.7653, 0.5546, 0.4796, + 0.506, 0.6687, 0.9865, 0.7388, 0.476, 0.3873, 0.3935, 0.5112, 0.8795, 0.0 }; + + float theta = itk::Math::pi / static_cast(180.0) * static_cast(25.0); PointType fixedPoint; - fixedPoint[0] = static_cast(0.0); - fixedPoint[1] = static_cast(0.0); - fixedPoints->SetPoint(0, fixedPoint); - fixedPoint[0] = pointMax; - fixedPoint[1] = static_cast(0.0); - fixedPoints->SetPoint(1, fixedPoint); - fixedPoint[0] = pointMax; - fixedPoint[1] = pointMax; - fixedPoints->SetPoint(2, fixedPoint); - fixedPoint[0] = static_cast(0.0); - fixedPoint[1] = pointMax; - fixedPoints->SetPoint(3, fixedPoint); - fixedPoint[0] = pointMax / static_cast(2.0); - fixedPoint[1] = pointMax / static_cast(2.0); - fixedPoints->SetPoint(4, fixedPoint); + + using VectorType = itk::Vector; + VectorType v; + + for (int i = 0; i < num_of_points; ++i) + { + fixedPoint[0] = static_cast(x[i]); + fixedPoint[1] = static_cast(y[i]); + fixedPoints->SetPoint(i, fixedPoint); + v[0] = nx[i]; + v[1] = ny[i]; + fixedPoints->SetPointData(i, v); + } + unsigned int numberOfPoints = fixedPoints->GetNumberOfPoints(); PointType movingPoint; @@ -133,13 +150,14 @@ itkEuclideanDistancePointSetMetricRegistrationTestRun(unsigned int using OptimizerType = itk::GradientDescentOptimizerv4; auto optimizer = OptimizerType::New(); optimizer->SetMetric(metric); + optimizer->SetLearningRate(0.001); optimizer->SetNumberOfIterations(numberOfIterations); - optimizer->SetScalesEstimator(shiftScaleEstimator); + // optimizer->SetScalesEstimator(shiftScaleEstimator); optimizer->SetMaximumStepSizeInPhysicalUnits(maximumPhysicalStepSize); using CommandType = itkEuclideanDistancePointSetMetricRegistrationTestCommandIterationUpdate; auto observer = CommandType::New(); - // optimizer->AddObserver( itk::IterationEvent(), observer ); + optimizer->AddObserver(itk::IterationEvent(), observer); // start optimizer->StartOptimization(); @@ -183,7 +201,7 @@ itkEuclideanDistancePointSetMetricRegistrationTestRun(unsigned int movingPoint = movingPoints->GetPoint(n); difference[0] = movingPoint[0] - transformedFixedPoint[0]; difference[1] = movingPoint[1] - transformedFixedPoint[1]; - std::cout << fixedPoints->GetPoint(n) << "\t" << movingPoint << "\t" << transformedFixedPoint << "\t" << difference + std::cout << fixedPoints->GetPoint(n) << "->" << movingPoint << "->" << transformedFixedPoint << "->" << difference << std::endl; if (itk::Math::abs(difference[0]) > tolerance || itk::Math::abs(difference[1]) > tolerance) { @@ -208,7 +226,7 @@ itkEuclideanDistancePointSetMetricRegistrationTest(int argc, char * argv[]) int finalResult = EXIT_SUCCESS; unsigned int numberOfIterations = 100; - auto maximumPhysicalStepSize = static_cast(2.0); + auto maximumPhysicalStepSize = static_cast(1); if (argc > 1) { numberOfIterations = std::stoi(argv[1]); @@ -225,12 +243,13 @@ itkEuclideanDistancePointSetMetricRegistrationTest(int argc, char * argv[]) // // metric - using PointSetType = itk::PointSet; + using PointSetType = itk::PointSet, Dimension>; using PointSetMetricType = itk::PointToPlanePointSetToPointSetMetricv4; auto metric = PointSetMetricType::New(); // transform - using AffineTransformType = itk::AffineTransform; + // using AffineTransformType = itk::AffineTransform; + using AffineTransformType = itk::CenteredRigid2DTransform; auto affineTransform = AffineTransformType::New(); affineTransform->SetIdentity(); std::cout << "XX Test with affine transform: " << std::endl; @@ -243,71 +262,71 @@ itkEuclideanDistancePointSetMetricRegistrationTest(int argc, char * argv[]) std::cerr << "Failed for affine transform." << std::endl; } - // - // Displacement field transform - // - - using DisplacementFieldTransformType = itk::DisplacementFieldTransform; - auto displacementTransform = DisplacementFieldTransformType::New(); - - // Setup the physical space to match the point set virtual domain, - // which is defined by the fixed point set since the fixed transform - // is identity. - using FieldType = DisplacementFieldTransformType::DisplacementFieldType; - using RegionType = FieldType::RegionType; - using RealType = DisplacementFieldTransformType::ScalarType; - - FieldType::SpacingType spacing; - spacing.Fill(static_cast(1.0)); - - FieldType::DirectionType direction; - direction.Fill(static_cast(0.0)); - for (unsigned int d = 0; d < Dimension; ++d) - { - direction[d][d] = static_cast(1.0); - } - - FieldType::PointType origin; - origin.Fill(static_cast(0.0)); - - RegionType::SizeType regionSize; - regionSize.Fill(static_cast(pointMax) + 1); - - RegionType::IndexType regionIndex; - regionIndex.Fill(0); - - RegionType region; - region.SetSize(regionSize); - region.SetIndex(regionIndex); - - auto displacementField = FieldType::New(); - displacementField->SetOrigin(origin); - displacementField->SetDirection(direction); - displacementField->SetSpacing(spacing); - displacementField->SetRegions(region); - displacementField->Allocate(); - DisplacementFieldTransformType::OutputVectorType zeroVector; - zeroVector.Fill(static_cast(0.0)); - displacementField->FillBuffer(zeroVector); - displacementTransform->SetDisplacementField(displacementField); - - // metric - using PointSetMetricType = itk::PointToPlanePointSetToPointSetMetricv4; - auto metric2 = PointSetMetricType::New(); - // If we don't set the virtual domain when using a displacement field transform, the - // metric takes it from the transform during initialization. - // metric2->SetVirtualDomain( spacing, origin, direction, region ); - - std::cout << "XX Testing with displacement field transform." << std::endl; - oneResult = itkEuclideanDistancePointSetMetricRegistrationTestRun( - numberOfIterations, maximumPhysicalStepSize, pointMax, displacementTransform, metric2); - if (oneResult == EXIT_FAILURE) - { - finalResult = EXIT_FAILURE; - std::cerr << "Failed for displacement transform." << std::endl; - } + // // + // // Displacement field transform + // // + + // using DisplacementFieldTransformType = itk::DisplacementFieldTransform; + // auto displacementTransform = DisplacementFieldTransformType::New(); + + // // Setup the physical space to match the point set virtual domain, + // // which is defined by the fixed point set since the fixed transform + // // is identity. + // using FieldType = DisplacementFieldTransformType::DisplacementFieldType; + // using RegionType = FieldType::RegionType; + // using RealType = DisplacementFieldTransformType::ScalarType; + + // FieldType::SpacingType spacing; + // spacing.Fill(static_cast(1.0)); + + // FieldType::DirectionType direction; + // direction.Fill(static_cast(0.0)); + // for (unsigned int d = 0; d < Dimension; ++d) + // { + // direction[d][d] = static_cast(1.0); + // } + + // FieldType::PointType origin; + // origin.Fill(static_cast(0.0)); + + // RegionType::SizeType regionSize; + // regionSize.Fill(static_cast(pointMax) + 1); + + // RegionType::IndexType regionIndex; + // regionIndex.Fill(0); + + // RegionType region; + // region.SetSize(regionSize); + // region.SetIndex(regionIndex); + + // auto displacementField = FieldType::New(); + // displacementField->SetOrigin(origin); + // displacementField->SetDirection(direction); + // displacementField->SetSpacing(spacing); + // displacementField->SetRegions(region); + // displacementField->Allocate(); + // DisplacementFieldTransformType::OutputVectorType zeroVector; + // zeroVector.Fill(static_cast(0.0)); + // displacementField->FillBuffer(zeroVector); + // displacementTransform->SetDisplacementField(displacementField); + + // // metric + // using PointSetMetricType = itk::PointToPlanePointSetToPointSetMetricv4; + // auto metric2 = PointSetMetricType::New(); + // // If we don't set the virtual domain when using a displacement field transform, the + // // metric takes it from the transform during initialization. + // // metric2->SetVirtualDomain( spacing, origin, direction, region ); + + // std::cout << "XX Testing with displacement field transform." << std::endl; + // oneResult = itkEuclideanDistancePointSetMetricRegistrationTestRun( + // numberOfIterations, maximumPhysicalStepSize, pointMax, displacementTransform, metric2); + // if (oneResult == EXIT_FAILURE) + // { + // finalResult = EXIT_FAILURE; + // std::cerr << "Failed for displacement transform." << std::endl; + // } return finalResult; } diff --git a/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricTest.cxx b/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricTest.cxx index aab39ed4d71..4f3d0c603a1 100644 --- a/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricTest.cxx +++ b/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricTest.cxx @@ -27,7 +27,7 @@ template int itkEuclideanDistancePointSetMetricTestRun() { - using PointSetType = itk::PointSet; + using PointSetType = itk::PointSet; using PointType = typename PointSetType::PointType; @@ -76,7 +76,7 @@ itkEuclideanDistancePointSetMetricTestRun() translationTransform->SetIdentity(); // Instantiate the metric - using PointSetMetricType = itk::PointToPlanePointSetToPointSetMetricv4; + using PointSetMetricType = itk::EuclideanDistancePointSetToPointSetMetricv4; auto metric = PointSetMetricType::New(); metric->SetFixedPointSet(fixedPoints); metric->SetMovingPointSet(movingPoints); From 422f7c9ed73710dfd6f393267dea69dcc352d16b Mon Sep 17 00:00:00 2001 From: Pranjal Sahu Date: Wed, 8 Jun 2022 23:05:30 -0400 Subject: [PATCH 4/5] ENH: Adding code where new jacobian is used by taking normal --- ...itkPointSetToPointSetMetricWithIndexv4.hxx | 23 ++++++++++++++++++- ...DistancePointSetMetricRegistrationTest.cxx | 15 ++++++------ 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx index a4bf6ea8f68..502f6dabbde 100644 --- a/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx +++ b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx @@ -354,12 +354,33 @@ PointSetToPointSetMetricWithIndexv4GetMovingTransform()->ComputeJacobianWithRespectToParametersCachedTemporaries( virtualTransformedPointSet[index], jacobian, jacobianCache); + float new_jacobian[numberOfLocalParameters] = { 0 }; + for (NumberOfParametersType par = 0; par < numberOfLocalParameters; ++par) { + // for (DimensionType d = 0; d < PointDimension; ++d) + // { + // auto temp_jd = jacobian(d, par); + // threadLocalTransformDerivative[par] += temp_jd * pointDerivative[d]; + // } + + // Writing new jacobian by taking dot product with the normal + // auto checking_pixel = pixel; + for (DimensionType d = 0; d < PointDimension; ++d) { - threadLocalTransformDerivative[par] += jacobian(d, par) * pointDerivative[d]; + new_jacobian[par] = new_jacobian[par] + jacobian(d, par) * pointDerivative[d]; + // threadLocalTransformDerivative[par] += temp_jd * pointDerivative[d]; } + + // perform dot product summation here of the dot product error + // threadLocalTransformDerivative[par] += temp_jd * (pointDerivative[0] + pointDerivative[1]); + } + + for (NumberOfParametersType par = 0; par < numberOfLocalParameters; ++par) + { + // perform dot product summation here of the dot product error with new jacobian + threadLocalTransformDerivative[par] += new_jacobian[par] * (pointDerivative[0] + pointDerivative[1]); } } // For local-support transforms, store the per-point result diff --git a/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricRegistrationTest.cxx b/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricRegistrationTest.cxx index f80dfb0b9d3..6e9f19ad07a 100644 --- a/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricRegistrationTest.cxx +++ b/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricRegistrationTest.cxx @@ -16,12 +16,12 @@ * *=========================================================================*/ -//#include "itkEuclideanDistancePointSetToPointSetMetricv4.h" +#include "itkEuclideanDistancePointSetToPointSetMetricv4.h" #include "itkPointToPlanePointSetToPointSetMetricv4.h" #include "itkGradientDescentOptimizerv4.h" #include "itkRegistrationParameterScalesFromPhysicalShift.h" #include "itkAffineTransform.h" -#include "itkCenteredRigid2DTransform.h" +#include "itkRigid2DTransform.h" #include "itkCommand.h" #include "itkMath.h" @@ -150,9 +150,9 @@ itkEuclideanDistancePointSetMetricRegistrationTestRun(unsigned int using OptimizerType = itk::GradientDescentOptimizerv4; auto optimizer = OptimizerType::New(); optimizer->SetMetric(metric); - optimizer->SetLearningRate(0.001); + optimizer->SetLearningRate(0.0001); optimizer->SetNumberOfIterations(numberOfIterations); - // optimizer->SetScalesEstimator(shiftScaleEstimator); + optimizer->SetScalesEstimator(shiftScaleEstimator); optimizer->SetMaximumStepSizeInPhysicalUnits(maximumPhysicalStepSize); using CommandType = itkEuclideanDistancePointSetMetricRegistrationTestCommandIterationUpdate; @@ -225,8 +225,8 @@ itkEuclideanDistancePointSetMetricRegistrationTest(int argc, char * argv[]) int finalResult = EXIT_SUCCESS; - unsigned int numberOfIterations = 100; - auto maximumPhysicalStepSize = static_cast(1); + unsigned int numberOfIterations = 500; + auto maximumPhysicalStepSize = static_cast(0.01); if (argc > 1) { numberOfIterations = std::stoi(argv[1]); @@ -245,11 +245,12 @@ itkEuclideanDistancePointSetMetricRegistrationTest(int argc, char * argv[]) // metric using PointSetType = itk::PointSet, Dimension>; using PointSetMetricType = itk::PointToPlanePointSetToPointSetMetricv4; + // using PointSetMetricType = itk::EuclideanDistancePointSetToPointSetMetricv4; auto metric = PointSetMetricType::New(); // transform // using AffineTransformType = itk::AffineTransform; - using AffineTransformType = itk::CenteredRigid2DTransform; + using AffineTransformType = itk::Rigid2DTransform; auto affineTransform = AffineTransformType::New(); affineTransform->SetIdentity(); std::cout << "XX Test with affine transform: " << std::endl; From b4bb504d74265373abdd5033dcd597a256078120 Mon Sep 17 00:00:00 2001 From: Pranjal Sahu Date: Wed, 8 Jun 2022 23:14:58 -0400 Subject: [PATCH 5/5] ENH: Adding better combination for point to plane --- .../itkPointSetToPointSetMetricWithIndexv4.h | 2 +- ...itkPointSetToPointSetMetricWithIndexv4.hxx | 21 +- ...tkPointToPlanePointSetToPointSetMetricv4.h | 111 +++++++++- ...PointToPlanePointSetToPointSetMetricv4.hxx | 198 ++++++++++++++++++ ...DistancePointSetMetricRegistrationTest.cxx | 11 +- 5 files changed, 323 insertions(+), 20 deletions(-) diff --git a/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.h b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.h index a400de127d6..fc34c67f071 100644 --- a/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.h +++ b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.h @@ -370,7 +370,7 @@ class ITK_TEMPLATE_EXPORT PointSetToPointSetMetricWithIndexv4 /** Helper method allows for code reuse while skipping the metric value * calculation when appropriate */ - void + virtual void CalculateValueAndDerivative(MeasureType & calculatedValue, DerivativeType & derivative, bool calculateValue) const; /** diff --git a/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx index 502f6dabbde..62d6dc6a168 100644 --- a/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx +++ b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx @@ -199,11 +199,10 @@ typename PointSetToPointSetMetricWithIndexv4m_UsePointSetData) { bool doesPointDataExist = this->m_FixedPointSet->GetPointData(index, &pixel); - // std::cout << "point data is " << pixel << std::endl; - // if (!doesPointDataExist) - // { - // itkExceptionMacro("The corresponding data for point (pointId = " << index << ") does not exist."); - // } + if (!doesPointDataExist) + { + itkExceptionMacro("The corresponding data for point (pointId = " << index << ") does not exist."); + } } threadValue += this->GetLocalNeighborhoodValueWithIndex(index, fixedTransformedPointSet[index], pixel); } @@ -320,11 +319,10 @@ PointSetToPointSetMetricWithIndexv4m_UsePointSetData) { bool doesPointDataExist = this->m_FixedPointSet->GetPointData(index, &pixel); - // std::cout << "point data is " << pixel << std::endl; - // if (!doesPointDataExist) - // { - // itkExceptionMacro("The corresponding data for point with id " << index << " does not exist."); - // } + if (!doesPointDataExist) + { + itkExceptionMacro("The corresponding data for point with id " << index << " does not exist."); + } } if (calculateValue) @@ -369,8 +367,9 @@ PointSetToPointSetMetricWithIndexv4; + using NeighborsIdentifierType = typename PointsLocatorType::NeighborsIdentifierType; + + using FixedTransformedPointSetType = PointSet; + using MovingTransformedPointSetType = PointSet; + + using DerivativeValueType = typename DerivativeType::ValueType; + using LocalDerivativeType = FixedArray; + + /** Types for the virtual domain */ + using VirtualImageType = typename Superclass::VirtualImageType; + using typename Superclass::VirtualImagePointer; + using typename Superclass::VirtualPixelType; + using typename Superclass::VirtualRegionType; + using typename Superclass::VirtualSizeType; + using typename Superclass::VirtualSpacingType; + using VirtualOriginType = typename Superclass::VirtualPointType; + using typename Superclass::VirtualPointType; + using typename Superclass::VirtualDirectionType; + using VirtualRadiusType = typename Superclass::VirtualSizeType; + using typename Superclass::VirtualIndexType; + using typename Superclass::VirtualPointSetType; + using typename Superclass::VirtualPointSetPointer; + + // Create ranges over the point set for multithreaded computation of value and derivatives + // using PointIdentifierPair = std::pair; + // using PointIdentifierRanges = std::vector; /** * Calculates the local metric value for a single point. @@ -82,6 +169,14 @@ class ITK_TEMPLATE_EXPORT PointToPlanePointSetToPointSetMetricv4 LocalDerivativeType &, const PixelType & pixel) const override; + /** + * Overide it to handle the change jacobian due to normal vector. + */ + void + CalculateValueAndDerivative(MeasureType & calculatedValue, + DerivativeType & derivative, + bool calculateValue) const override; + protected: PointToPlanePointSetToPointSetMetricv4(); ~PointToPlanePointSetToPointSetMetricv4() override = default; @@ -95,6 +190,14 @@ class ITK_TEMPLATE_EXPORT PointToPlanePointSetToPointSetMetricv4 /** PrintSelf function */ void PrintSelf(std::ostream & os, Indent indent) const override; + + +private: + // Create ranges over the point set for multithreaded computation of value and derivatives + using PointIdentifierPair = std::pair; + using PointIdentifierRanges = std::vector; + const PointIdentifierRanges + CreateRanges() const; }; } // end namespace itk diff --git a/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.hxx b/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.hxx index 922c02d99e7..c2f1fdc8326 100644 --- a/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.hxx +++ b/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.hxx @@ -76,6 +76,204 @@ PointToPlanePointSetToPointSetMetricv4 +void +PointToPlanePointSetToPointSetMetricv4:: + CalculateValueAndDerivative(MeasureType & calculatedValue, DerivativeType & derivative, bool calculateValue) const +{ + this->InitializeForIteration(); + + // Virtual point set will be the same size as fixed point set as long as it's + // generated from the fixed point set. + if (this->m_VirtualTransformedPointSet->GetNumberOfPoints() != this->m_FixedTransformedPointSet->GetNumberOfPoints()) + { + itkExceptionMacro("Expected FixedTransformedPointSet to be the same size as VirtualTransformedPointSet."); + } + + derivative.SetSize(this->GetNumberOfParameters()); + if (!this->GetStoreDerivativeAsSparseFieldForLocalSupportTransforms()) + { + derivative.SetSize(PointDimension * this->m_FixedTransformedPointSet->GetNumberOfPoints()); + } + derivative.Fill(NumericTraits::ZeroValue()); + + /* + * Split pointset in nWorkUnits ranges and sum individually + * This splitting is required in order to avoid having the threads + * repeatedly write to same location causing false sharing + */ + // GetNumberOfLocalParameters is not trhead safe in itkCompositeTransform + NumberOfParametersType numberOfLocalParameters = this->GetNumberOfLocalParameters(); + PointIdentifierRanges ranges = this->CreateRanges(); + std::vector> threadValues(ranges.size()); + using CompensatedDerivative = typename std::vector>; + std::vector threadDerivatives(ranges.size()); + std::function sumNeighborhoodValues = + [this, &derivative, &threadDerivatives, &threadValues, &ranges, &calculateValue, &numberOfLocalParameters]( + SizeValueType rangeIndex) { + // Use STL container to make sure no unesecarry checks are performed + using FixedTransformedVectorContainer = typename FixedPointsContainer::STLContainerType; + using VirtualPointsContainer = typename VirtualPointSetType::PointsContainer; + using VirtualVectorContainer = typename VirtualPointsContainer::STLContainerType; + const VirtualVectorContainer & virtualTransformedPointSet = + this->m_VirtualTransformedPointSet->GetPoints()->CastToSTLConstContainer(); + const FixedTransformedVectorContainer & fixedTransformedPointSet = + this->m_FixedTransformedPointSet->GetPoints()->CastToSTLConstContainer(); + + MovingTransformJacobianType jacobian(MovingPointDimension, numberOfLocalParameters); + MovingTransformJacobianType jacobianCache; + + DerivativeType threadLocalTransformDerivative(numberOfLocalParameters); + threadLocalTransformDerivative.Fill(NumericTraits::ZeroValue()); + + CompensatedDerivative threadDerivativeSum(numberOfLocalParameters); + + CompensatedSummation threadValue; + PixelType pixel; + // NumericTraits::SetLength(pixel, 1); + for (PointIdentifier index = ranges[rangeIndex].first; index < ranges[rangeIndex].second; ++index) + { + MeasureType pointValue = NumericTraits::ZeroValue(); + LocalDerivativeType pointDerivative; + + /* Verify the virtual point is in the virtual domain. + * If user hasn't defined a virtual space, and the active transform is not + * a displacement field transform type, then this will always return true. */ + if (!this->IsInsideVirtualDomain(virtualTransformedPointSet[index])) + { + continue; + } + + if (this->m_UsePointSetData) + { + bool doesPointDataExist = this->m_FixedPointSet->GetPointData(index, &pixel); + if (!doesPointDataExist) + { + itkExceptionMacro("The corresponding data for point with id " << index << " does not exist."); + } + } + + if (calculateValue) + { + this->GetLocalNeighborhoodValueAndDerivativeWithIndex( + index, fixedTransformedPointSet[index], pointValue, pointDerivative, pixel); + threadValue += pointValue; + } + else + { + pointDerivative = + this->GetLocalNeighborhoodDerivativeWithIndex(index, fixedTransformedPointSet[index], pixel); + } + + // Map into parameter space + threadLocalTransformDerivative.Fill(NumericTraits::ZeroValue()); + + if (this->m_CalculateValueAndDerivativeInTangentSpace) + { + for (DimensionType d = 0; d < PointDimension; ++d) + { + threadLocalTransformDerivative[d] += pointDerivative[d]; + } + } + else + { + this->GetMovingTransform()->ComputeJacobianWithRespectToParametersCachedTemporaries( + virtualTransformedPointSet[index], jacobian, jacobianCache); + + float new_jacobian[numberOfLocalParameters] = { 0 }; + + for (NumberOfParametersType par = 0; par < numberOfLocalParameters; ++par) + { + // for (DimensionType d = 0; d < PointDimension; ++d) + // { + // auto temp_jd = jacobian(d, par); + // threadLocalTransformDerivative[par] += temp_jd * pointDerivative[d]; + // } + + // Writing new jacobian by taking dot product with the normal + // auto checking_pixel = pixel; + + for (DimensionType d = 0; d < PointDimension; ++d) + { + // Use pixel here instead of pointDerivative + // Override this method in the new class + new_jacobian[par] = new_jacobian[par] + jacobian(d, par) * pointDerivative[d]; + // threadLocalTransformDerivative[par] += temp_jd * pointDerivative[d]; + } + + // perform dot product summation here of the dot product error + // threadLocalTransformDerivative[par] += temp_jd * (pointDerivative[0] + pointDerivative[1]); + } + + for (NumberOfParametersType par = 0; par < numberOfLocalParameters; ++par) + { + // perform dot product summation here of the dot product error with new jacobian + threadLocalTransformDerivative[par] += new_jacobian[par] * (pointDerivative[0] + pointDerivative[1]); + } + } + // For local-support transforms, store the per-point result + if (this->HasLocalSupport() || this->m_CalculateValueAndDerivativeInTangentSpace) + { + if (this->GetStoreDerivativeAsSparseFieldForLocalSupportTransforms()) + { + this->StorePointDerivative(virtualTransformedPointSet[index], threadLocalTransformDerivative, derivative); + } + else + { + for (NumberOfParametersType par = 0; par < numberOfLocalParameters; ++par) + { + derivative[this->GetNumberOfLocalParameters() * index + par] = threadLocalTransformDerivative[par]; + } + } + } + for (NumberOfParametersType par = 0; par < numberOfLocalParameters; ++par) + { + threadDerivativeSum[par] += threadLocalTransformDerivative[par]; + } + } + threadValues[rangeIndex] = threadValue; + threadDerivatives[rangeIndex] = threadDerivativeSum; + }; + + // Sum per thread + MultiThreaderBase::New()->ParallelizeArray( + (SizeValueType)0, (SizeValueType)ranges.size(), sumNeighborhoodValues, nullptr); + + // Sum thread results + CompensatedSummation value = 0; + for (unsigned int i = 0; i < threadValues.size(); ++i) + { + value += threadValues[i]; + } + MeasureType valueSum = value.GetSum(); + + if (this->VerifyNumberOfValidPoints(valueSum, derivative)) + { + // For global-support transforms, average the accumulated derivative result + if (!this->HasLocalSupport() && !this->m_CalculateValueAndDerivativeInTangentSpace) + { + CompensatedDerivative localTransformDerivative(numberOfLocalParameters); + for (unsigned int i = 0; i < threadDerivatives.size(); ++i) + { + for (NumberOfParametersType par = 0; par < numberOfLocalParameters; ++par) + { + localTransformDerivative[par] += threadDerivatives[i][par]; + } + } + derivative.SetSize(numberOfLocalParameters); + for (NumberOfParametersType par = 0; par < numberOfLocalParameters; ++par) + { + derivative[par] = + localTransformDerivative[par].GetSum() / static_cast(this->m_NumberOfValidPoints); + } + } + valueSum /= static_cast(this->m_NumberOfValidPoints); + } + calculatedValue = valueSum; + this->m_Value = valueSum; +} + /** PrintSelf method */ template void diff --git a/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricRegistrationTest.cxx b/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricRegistrationTest.cxx index 6e9f19ad07a..fd1f9d0d59d 100644 --- a/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricRegistrationTest.cxx +++ b/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricRegistrationTest.cxx @@ -152,7 +152,7 @@ itkEuclideanDistancePointSetMetricRegistrationTestRun(unsigned int optimizer->SetMetric(metric); optimizer->SetLearningRate(0.0001); optimizer->SetNumberOfIterations(numberOfIterations); - optimizer->SetScalesEstimator(shiftScaleEstimator); + // optimizer->SetScalesEstimator(shiftScaleEstimator); optimizer->SetMaximumStepSizeInPhysicalUnits(maximumPhysicalStepSize); using CommandType = itkEuclideanDistancePointSetMetricRegistrationTestCommandIterationUpdate; @@ -225,7 +225,7 @@ itkEuclideanDistancePointSetMetricRegistrationTest(int argc, char * argv[]) int finalResult = EXIT_SUCCESS; - unsigned int numberOfIterations = 500; + unsigned int numberOfIterations = 200; auto maximumPhysicalStepSize = static_cast(0.01); if (argc > 1) { @@ -248,9 +248,12 @@ itkEuclideanDistancePointSetMetricRegistrationTest(int argc, char * argv[]) // using PointSetMetricType = itk::EuclideanDistancePointSetToPointSetMetricv4; auto metric = PointSetMetricType::New(); + std::cout << "Metric is " << std::endl; + std::cout << metric << std::endl; + // transform - // using AffineTransformType = itk::AffineTransform; - using AffineTransformType = itk::Rigid2DTransform; + using AffineTransformType = itk::AffineTransform; + // using AffineTransformType = itk::Rigid2DTransform; auto affineTransform = AffineTransformType::New(); affineTransform->SetIdentity(); std::cout << "XX Test with affine transform: " << std::endl;