diff --git a/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.h b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.h index a400de127d6..fc34c67f071 100644 --- a/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.h +++ b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.h @@ -370,7 +370,7 @@ class ITK_TEMPLATE_EXPORT PointSetToPointSetMetricWithIndexv4 /** Helper method allows for code reuse while skipping the metric value * calculation when appropriate */ - void + virtual void CalculateValueAndDerivative(MeasureType & calculatedValue, DerivativeType & derivative, bool calculateValue) const; /** diff --git a/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx index 03d07be9453..62d6dc6a168 100644 --- a/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx +++ b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx @@ -302,7 +302,7 @@ PointSetToPointSetMetricWithIndexv4 threadValue; PixelType pixel; - NumericTraits::SetLength(pixel, 1); + // NumericTraits::SetLength(pixel, 1); for (PointIdentifier index = ranges[rangeIndex].first; index < ranges[rangeIndex].second; ++index) { MeasureType pointValue = NumericTraits::ZeroValue(); @@ -352,12 +352,34 @@ PointSetToPointSetMetricWithIndexv4GetMovingTransform()->ComputeJacobianWithRespectToParametersCachedTemporaries( virtualTransformedPointSet[index], jacobian, jacobianCache); + float new_jacobian[numberOfLocalParameters] = { 0 }; + for (NumberOfParametersType par = 0; par < numberOfLocalParameters; ++par) { + // for (DimensionType d = 0; d < PointDimension; ++d) + // { + // auto temp_jd = jacobian(d, par); + // threadLocalTransformDerivative[par] += temp_jd * pointDerivative[d]; + // } + + // Writing new jacobian by taking dot product with the normal + // auto checking_pixel = pixel; + for (DimensionType d = 0; d < PointDimension; ++d) { - threadLocalTransformDerivative[par] += jacobian(d, par) * pointDerivative[d]; + // Use pixel here instead of pointDerivative + // Override this method in the new class + new_jacobian[par] = new_jacobian[par] + jacobian(d, par) * pointDerivative[d]; } + + // perform dot product summation here of the dot product error + // threadLocalTransformDerivative[par] += temp_jd * (pointDerivative[0] + pointDerivative[1]); + } + + for (NumberOfParametersType par = 0; par < numberOfLocalParameters; ++par) + { + // perform dot product summation here of the dot product error with new jacobian + threadLocalTransformDerivative[par] += new_jacobian[par] * (pointDerivative[0] + pointDerivative[1]); } } // For local-support transforms, store the per-point result diff --git a/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.h b/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.h new file mode 100644 index 00000000000..353b151e40d --- /dev/null +++ b/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.h @@ -0,0 +1,208 @@ +/*========================================================================= + * + * Copyright NumFOCUS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + *=========================================================================*/ +#ifndef itkPointToPlanePointSetToPointSetMetricv4_h +#define itkPointToPlanePointSetToPointSetMetricv4_h + +#include "itkPointSetToPointSetMetricv4.h" + +namespace itk +{ +/** \class PointToPlanePointSetToPointSetMetricv4 + * \brief Computes the Euclidan distance metric between two point sets. + * + * Given two point sets the Euclidean distance metric (i.e. ICP) is + * defined to be the aggregate of all shortest distances between all + * possible pairings of points between the two sets. + * + * We only have to handle the individual point case as the parent + * class handles the aggregation. + * + * Reference: + * PJ Besl and ND McKay, "A Method for Registration of 3-D Shapes", + * IEEE PAMI, Vol 14, No. 2, February 1992 + * + * \ingroup ITKMetricsv4 + */ +template +class ITK_TEMPLATE_EXPORT PointToPlanePointSetToPointSetMetricv4 + : public PointSetToPointSetMetricv4 +{ +public: + ITK_DISALLOW_COPY_AND_MOVE(PointToPlanePointSetToPointSetMetricv4); + + /** Standard class type aliases. */ + using Self = PointToPlanePointSetToPointSetMetricv4; + using Superclass = PointSetToPointSetMetricv4; + using Pointer = SmartPointer; + using ConstPointer = SmartPointer; + + /** Method for creation through the object factory. */ + itkNewMacro(Self); + + /** Run-time type information (and related methods). */ + itkTypeMacro(PointToPlanePointSetToPointSetMetricv4, PointSetToPointSetMetricv4); + + /** Types transferred from the base class */ + using typename Superclass::MeasureType; + + /** Type of the parameters. */ + using typename Superclass::ParametersType; + using typename Superclass::ParametersValueType; + using typename Superclass::NumberOfParametersType; + + /** Type of the derivative. */ + using typename Superclass::DerivativeType; + + /** Transform types from Superclass*/ + using typename Superclass::FixedTransformType; + using typename Superclass::FixedTransformPointer; + using typename Superclass::FixedInputPointType; + using typename Superclass::FixedOutputPointType; + using typename Superclass::FixedTransformParametersType; + + using typename Superclass::MovingTransformType; + using typename Superclass::MovingTransformPointer; + using typename Superclass::MovingInputPointType; + using typename Superclass::MovingOutputPointType; + using typename Superclass::MovingTransformParametersType; + + using typename Superclass::JacobianType; + using typename Superclass::FixedTransformJacobianType; + using typename Superclass::MovingTransformJacobianType; + + using DisplacementFieldTransformType = typename Superclass::MovingDisplacementFieldTransformType; + + using ObjectType = typename Superclass::ObjectType; + + /** Dimension type */ + using typename Superclass::DimensionType; + + /** Type of the fixed point set. */ + using FixedPointSetType = TFixedPointSet; + using FixedPointType = typename TFixedPointSet::PointType; + using FixedPixelType = typename TFixedPointSet::PixelType; + using FixedPointsContainer = typename TFixedPointSet::PointsContainer; + + static constexpr DimensionType FixedPointDimension = Superclass::FixedDimension; + + /** Type of the moving point set. */ + using MovingPointSetType = TMovingPointSet; + using MovingPointType = typename TMovingPointSet::PointType; + using MovingPixelType = typename TMovingPointSet::PixelType; + using MovingPointsContainer = typename TMovingPointSet::PointsContainer; + + static constexpr DimensionType MovingPointDimension = Superclass::MovingDimension; + + /** + * typedefs for the data types used in the point set metric calculations. + * It is assumed that the constants of the fixed point set, such as the + * point dimension, are the same for the "common space" in which the metric + * calculation occurs. + */ + static constexpr DimensionType PointDimension = Superclass::FixedDimension; + + using PointType = FixedPointType; + using PixelType = FixedPixelType; + using CoordRepType = typename PointType::CoordRepType; + using PointsContainer = FixedPointsContainer; + using PointsConstIterator = typename PointsContainer::ConstIterator; + using PointIdentifier = typename PointsContainer::ElementIdentifier; + + /** Typedef for points locator class to speed up finding neighboring points */ + using PointsLocatorType = PointsLocator; + using NeighborsIdentifierType = typename PointsLocatorType::NeighborsIdentifierType; + + using FixedTransformedPointSetType = PointSet; + using MovingTransformedPointSetType = PointSet; + + using DerivativeValueType = typename DerivativeType::ValueType; + using LocalDerivativeType = FixedArray; + + /** Types for the virtual domain */ + using VirtualImageType = typename Superclass::VirtualImageType; + using typename Superclass::VirtualImagePointer; + using typename Superclass::VirtualPixelType; + using typename Superclass::VirtualRegionType; + using typename Superclass::VirtualSizeType; + using typename Superclass::VirtualSpacingType; + using VirtualOriginType = typename Superclass::VirtualPointType; + using typename Superclass::VirtualPointType; + using typename Superclass::VirtualDirectionType; + using VirtualRadiusType = typename Superclass::VirtualSizeType; + using typename Superclass::VirtualIndexType; + using typename Superclass::VirtualPointSetType; + using typename Superclass::VirtualPointSetPointer; + + // Create ranges over the point set for multithreaded computation of value and derivatives + // using PointIdentifierPair = std::pair; + // using PointIdentifierRanges = std::vector; + + /** + * Calculates the local metric value for a single point. + */ + MeasureType + GetLocalNeighborhoodValue(const PointType &, const PixelType & pixel) const override; + + /** + * Calculates the local value and derivative for a single point. + */ + void + GetLocalNeighborhoodValueAndDerivative(const PointType &, + MeasureType &, + LocalDerivativeType &, + const PixelType & pixel) const override; + + /** + * Overide it to handle the change jacobian due to normal vector. + */ + void + CalculateValueAndDerivative(MeasureType & calculatedValue, + DerivativeType & derivative, + bool calculateValue) const override; + +protected: + PointToPlanePointSetToPointSetMetricv4(); + ~PointToPlanePointSetToPointSetMetricv4() override = default; + + bool + RequiresFixedPointsLocator() const override + { + return false; + } + + /** PrintSelf function */ + void + PrintSelf(std::ostream & os, Indent indent) const override; + + +private: + // Create ranges over the point set for multithreaded computation of value and derivatives + using PointIdentifierPair = std::pair; + using PointIdentifierRanges = std::vector; + const PointIdentifierRanges + CreateRanges() const; +}; +} // end namespace itk + +#ifndef ITK_MANUAL_INSTANTIATION +# include "itkPointToPlanePointSetToPointSetMetricv4.hxx" +#endif + +#endif diff --git a/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.hxx b/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.hxx new file mode 100644 index 00000000000..c2f1fdc8326 --- /dev/null +++ b/Modules/Registration/Metricsv4/include/itkPointToPlanePointSetToPointSetMetricv4.hxx @@ -0,0 +1,289 @@ +/*========================================================================= + * + * Copyright NumFOCUS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + *=========================================================================*/ +#ifndef itkPointToPlanePointSetToPointSetMetricv4_hxx +#define itkPointToPlanePointSetToPointSetMetricv4_hxx + + +namespace itk +{ + +template +PointToPlanePointSetToPointSetMetricv4:: + PointToPlanePointSetToPointSetMetricv4() +{ + this->m_UsePointSetData = true; +} + +template +typename PointToPlanePointSetToPointSetMetricv4:: + MeasureType + PointToPlanePointSetToPointSetMetricv4:: + GetLocalNeighborhoodValue(const PointType & point, const PixelType & itkNotUsed(pixel)) const +{ + PointType closestPoint; + closestPoint.Fill(0.0); + + PointIdentifier pointId = this->m_MovingTransformedPointsLocator->FindClosestPoint(point); + closestPoint = this->m_MovingTransformedPointSet->GetPoint(pointId); + + const MeasureType distance = point.EuclideanDistanceTo(closestPoint); + return distance; +} + +template +void +PointToPlanePointSetToPointSetMetricv4:: + GetLocalNeighborhoodValueAndDerivative(const PointType & point, + MeasureType & measure, + LocalDerivativeType & localDerivative, + const PixelType & pixel) const +{ + PointType closestPoint; + closestPoint.Fill(0.0); + + PointIdentifier pointId = this->m_MovingTransformedPointsLocator->FindClosestPoint(point); + closestPoint = this->m_MovingTransformedPointSet->GetPoint(pointId); + + measure = 0; + auto temp_diff = closestPoint - point; + for (int i = 0; i < pixel.Size(); ++i) + { + measure = measure + temp_diff[i] * pixel[i]; + } + measure = measure * measure; + + localDerivative = closestPoint - point; + + // Perform dot product with the normal vector + for (int i = 0; i < pixel.Size(); ++i) + { + localDerivative[i] = localDerivative[i] * pixel[i]; + } +} + + +template +void +PointToPlanePointSetToPointSetMetricv4:: + CalculateValueAndDerivative(MeasureType & calculatedValue, DerivativeType & derivative, bool calculateValue) const +{ + this->InitializeForIteration(); + + // Virtual point set will be the same size as fixed point set as long as it's + // generated from the fixed point set. + if (this->m_VirtualTransformedPointSet->GetNumberOfPoints() != this->m_FixedTransformedPointSet->GetNumberOfPoints()) + { + itkExceptionMacro("Expected FixedTransformedPointSet to be the same size as VirtualTransformedPointSet."); + } + + derivative.SetSize(this->GetNumberOfParameters()); + if (!this->GetStoreDerivativeAsSparseFieldForLocalSupportTransforms()) + { + derivative.SetSize(PointDimension * this->m_FixedTransformedPointSet->GetNumberOfPoints()); + } + derivative.Fill(NumericTraits::ZeroValue()); + + /* + * Split pointset in nWorkUnits ranges and sum individually + * This splitting is required in order to avoid having the threads + * repeatedly write to same location causing false sharing + */ + // GetNumberOfLocalParameters is not trhead safe in itkCompositeTransform + NumberOfParametersType numberOfLocalParameters = this->GetNumberOfLocalParameters(); + PointIdentifierRanges ranges = this->CreateRanges(); + std::vector> threadValues(ranges.size()); + using CompensatedDerivative = typename std::vector>; + std::vector threadDerivatives(ranges.size()); + std::function sumNeighborhoodValues = + [this, &derivative, &threadDerivatives, &threadValues, &ranges, &calculateValue, &numberOfLocalParameters]( + SizeValueType rangeIndex) { + // Use STL container to make sure no unesecarry checks are performed + using FixedTransformedVectorContainer = typename FixedPointsContainer::STLContainerType; + using VirtualPointsContainer = typename VirtualPointSetType::PointsContainer; + using VirtualVectorContainer = typename VirtualPointsContainer::STLContainerType; + const VirtualVectorContainer & virtualTransformedPointSet = + this->m_VirtualTransformedPointSet->GetPoints()->CastToSTLConstContainer(); + const FixedTransformedVectorContainer & fixedTransformedPointSet = + this->m_FixedTransformedPointSet->GetPoints()->CastToSTLConstContainer(); + + MovingTransformJacobianType jacobian(MovingPointDimension, numberOfLocalParameters); + MovingTransformJacobianType jacobianCache; + + DerivativeType threadLocalTransformDerivative(numberOfLocalParameters); + threadLocalTransformDerivative.Fill(NumericTraits::ZeroValue()); + + CompensatedDerivative threadDerivativeSum(numberOfLocalParameters); + + CompensatedSummation threadValue; + PixelType pixel; + // NumericTraits::SetLength(pixel, 1); + for (PointIdentifier index = ranges[rangeIndex].first; index < ranges[rangeIndex].second; ++index) + { + MeasureType pointValue = NumericTraits::ZeroValue(); + LocalDerivativeType pointDerivative; + + /* Verify the virtual point is in the virtual domain. + * If user hasn't defined a virtual space, and the active transform is not + * a displacement field transform type, then this will always return true. */ + if (!this->IsInsideVirtualDomain(virtualTransformedPointSet[index])) + { + continue; + } + + if (this->m_UsePointSetData) + { + bool doesPointDataExist = this->m_FixedPointSet->GetPointData(index, &pixel); + if (!doesPointDataExist) + { + itkExceptionMacro("The corresponding data for point with id " << index << " does not exist."); + } + } + + if (calculateValue) + { + this->GetLocalNeighborhoodValueAndDerivativeWithIndex( + index, fixedTransformedPointSet[index], pointValue, pointDerivative, pixel); + threadValue += pointValue; + } + else + { + pointDerivative = + this->GetLocalNeighborhoodDerivativeWithIndex(index, fixedTransformedPointSet[index], pixel); + } + + // Map into parameter space + threadLocalTransformDerivative.Fill(NumericTraits::ZeroValue()); + + if (this->m_CalculateValueAndDerivativeInTangentSpace) + { + for (DimensionType d = 0; d < PointDimension; ++d) + { + threadLocalTransformDerivative[d] += pointDerivative[d]; + } + } + else + { + this->GetMovingTransform()->ComputeJacobianWithRespectToParametersCachedTemporaries( + virtualTransformedPointSet[index], jacobian, jacobianCache); + + float new_jacobian[numberOfLocalParameters] = { 0 }; + + for (NumberOfParametersType par = 0; par < numberOfLocalParameters; ++par) + { + // for (DimensionType d = 0; d < PointDimension; ++d) + // { + // auto temp_jd = jacobian(d, par); + // threadLocalTransformDerivative[par] += temp_jd * pointDerivative[d]; + // } + + // Writing new jacobian by taking dot product with the normal + // auto checking_pixel = pixel; + + for (DimensionType d = 0; d < PointDimension; ++d) + { + // Use pixel here instead of pointDerivative + // Override this method in the new class + new_jacobian[par] = new_jacobian[par] + jacobian(d, par) * pointDerivative[d]; + // threadLocalTransformDerivative[par] += temp_jd * pointDerivative[d]; + } + + // perform dot product summation here of the dot product error + // threadLocalTransformDerivative[par] += temp_jd * (pointDerivative[0] + pointDerivative[1]); + } + + for (NumberOfParametersType par = 0; par < numberOfLocalParameters; ++par) + { + // perform dot product summation here of the dot product error with new jacobian + threadLocalTransformDerivative[par] += new_jacobian[par] * (pointDerivative[0] + pointDerivative[1]); + } + } + // For local-support transforms, store the per-point result + if (this->HasLocalSupport() || this->m_CalculateValueAndDerivativeInTangentSpace) + { + if (this->GetStoreDerivativeAsSparseFieldForLocalSupportTransforms()) + { + this->StorePointDerivative(virtualTransformedPointSet[index], threadLocalTransformDerivative, derivative); + } + else + { + for (NumberOfParametersType par = 0; par < numberOfLocalParameters; ++par) + { + derivative[this->GetNumberOfLocalParameters() * index + par] = threadLocalTransformDerivative[par]; + } + } + } + for (NumberOfParametersType par = 0; par < numberOfLocalParameters; ++par) + { + threadDerivativeSum[par] += threadLocalTransformDerivative[par]; + } + } + threadValues[rangeIndex] = threadValue; + threadDerivatives[rangeIndex] = threadDerivativeSum; + }; + + // Sum per thread + MultiThreaderBase::New()->ParallelizeArray( + (SizeValueType)0, (SizeValueType)ranges.size(), sumNeighborhoodValues, nullptr); + + // Sum thread results + CompensatedSummation value = 0; + for (unsigned int i = 0; i < threadValues.size(); ++i) + { + value += threadValues[i]; + } + MeasureType valueSum = value.GetSum(); + + if (this->VerifyNumberOfValidPoints(valueSum, derivative)) + { + // For global-support transforms, average the accumulated derivative result + if (!this->HasLocalSupport() && !this->m_CalculateValueAndDerivativeInTangentSpace) + { + CompensatedDerivative localTransformDerivative(numberOfLocalParameters); + for (unsigned int i = 0; i < threadDerivatives.size(); ++i) + { + for (NumberOfParametersType par = 0; par < numberOfLocalParameters; ++par) + { + localTransformDerivative[par] += threadDerivatives[i][par]; + } + } + derivative.SetSize(numberOfLocalParameters); + for (NumberOfParametersType par = 0; par < numberOfLocalParameters; ++par) + { + derivative[par] = + localTransformDerivative[par].GetSum() / static_cast(this->m_NumberOfValidPoints); + } + } + valueSum /= static_cast(this->m_NumberOfValidPoints); + } + calculatedValue = valueSum; + this->m_Value = valueSum; +} + +/** PrintSelf method */ +template +void +PointToPlanePointSetToPointSetMetricv4::PrintSelf( + std::ostream & os, + Indent indent) const +{ + Superclass::PrintSelf(os, indent); +} + +} // end namespace itk + +#endif diff --git a/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricRegistrationTest.cxx b/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricRegistrationTest.cxx index 946ee5e4e29..fd1f9d0d59d 100644 --- a/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricRegistrationTest.cxx +++ b/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricRegistrationTest.cxx @@ -17,9 +17,11 @@ *=========================================================================*/ #include "itkEuclideanDistancePointSetToPointSetMetricv4.h" +#include "itkPointToPlanePointSetToPointSetMetricv4.h" #include "itkGradientDescentOptimizerv4.h" #include "itkRegistrationParameterScalesFromPhysicalShift.h" #include "itkAffineTransform.h" +#include "itkRigid2DTransform.h" #include "itkCommand.h" #include "itkMath.h" @@ -85,23 +87,39 @@ itkEuclideanDistancePointSetMetricRegistrationTestRun(unsigned int // Create a few points and apply a small rotation to make the moving point set - float theta = itk::Math::pi / static_cast(180.0) * static_cast(1.0); + int num_of_points = 30; + + float x[num_of_points] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, + 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0 }; + + float y[num_of_points] = { 0.0, 0.096, 0.337, 0.598, 0.727, 0.598, 0.169, -0.491, -1.211, -1.76, + -1.918, -1.552, -0.671, 0.559, 1.84, 2.814, 3.166, 2.715, 1.484, -0.286, + -2.176, -3.695, -4.4, -4.027, -2.576, -0.332, 2.185, 4.34, 5.547, 5.422 }; + + float nx[num_of_points] = { 0.0, -0.166, -0.2437, -0.1918, 0.0, 0.2688, 0.4784, 0.568, 0.5356, 0.3333, + -0.1031, -0.5292, -0.726, -0.7821, -0.7481, -0.5527, 0.0495, 0.6437, 0.8321, 0.8775, + 0.8625, 0.7435, 0.1639, -0.6739, -0.8795, -0.9219, -0.9193, -0.8595, -0.4758, 0.0 }; + + float ny[num_of_points] = { 0.0, 0.9861, 0.9698, 0.9814, 1.0, 0.9632, 0.8781, 0.823, 0.8445, 0.9428, + 0.9947, 0.8485, 0.6877, 0.6231, 0.6636, 0.8334, 0.9988, 0.7653, 0.5546, 0.4796, + 0.506, 0.6687, 0.9865, 0.7388, 0.476, 0.3873, 0.3935, 0.5112, 0.8795, 0.0 }; + + float theta = itk::Math::pi / static_cast(180.0) * static_cast(25.0); PointType fixedPoint; - fixedPoint[0] = static_cast(0.0); - fixedPoint[1] = static_cast(0.0); - fixedPoints->SetPoint(0, fixedPoint); - fixedPoint[0] = pointMax; - fixedPoint[1] = static_cast(0.0); - fixedPoints->SetPoint(1, fixedPoint); - fixedPoint[0] = pointMax; - fixedPoint[1] = pointMax; - fixedPoints->SetPoint(2, fixedPoint); - fixedPoint[0] = static_cast(0.0); - fixedPoint[1] = pointMax; - fixedPoints->SetPoint(3, fixedPoint); - fixedPoint[0] = pointMax / static_cast(2.0); - fixedPoint[1] = pointMax / static_cast(2.0); - fixedPoints->SetPoint(4, fixedPoint); + + using VectorType = itk::Vector; + VectorType v; + + for (int i = 0; i < num_of_points; ++i) + { + fixedPoint[0] = static_cast(x[i]); + fixedPoint[1] = static_cast(y[i]); + fixedPoints->SetPoint(i, fixedPoint); + v[0] = nx[i]; + v[1] = ny[i]; + fixedPoints->SetPointData(i, v); + } + unsigned int numberOfPoints = fixedPoints->GetNumberOfPoints(); PointType movingPoint; @@ -132,13 +150,14 @@ itkEuclideanDistancePointSetMetricRegistrationTestRun(unsigned int using OptimizerType = itk::GradientDescentOptimizerv4; auto optimizer = OptimizerType::New(); optimizer->SetMetric(metric); + optimizer->SetLearningRate(0.0001); optimizer->SetNumberOfIterations(numberOfIterations); - optimizer->SetScalesEstimator(shiftScaleEstimator); + // optimizer->SetScalesEstimator(shiftScaleEstimator); optimizer->SetMaximumStepSizeInPhysicalUnits(maximumPhysicalStepSize); using CommandType = itkEuclideanDistancePointSetMetricRegistrationTestCommandIterationUpdate; auto observer = CommandType::New(); - // optimizer->AddObserver( itk::IterationEvent(), observer ); + optimizer->AddObserver(itk::IterationEvent(), observer); // start optimizer->StartOptimization(); @@ -182,7 +201,7 @@ itkEuclideanDistancePointSetMetricRegistrationTestRun(unsigned int movingPoint = movingPoints->GetPoint(n); difference[0] = movingPoint[0] - transformedFixedPoint[0]; difference[1] = movingPoint[1] - transformedFixedPoint[1]; - std::cout << fixedPoints->GetPoint(n) << "\t" << movingPoint << "\t" << transformedFixedPoint << "\t" << difference + std::cout << fixedPoints->GetPoint(n) << "->" << movingPoint << "->" << transformedFixedPoint << "->" << difference << std::endl; if (itk::Math::abs(difference[0]) > tolerance || itk::Math::abs(difference[1]) > tolerance) { @@ -206,8 +225,8 @@ itkEuclideanDistancePointSetMetricRegistrationTest(int argc, char * argv[]) int finalResult = EXIT_SUCCESS; - unsigned int numberOfIterations = 100; - auto maximumPhysicalStepSize = static_cast(2.0); + unsigned int numberOfIterations = 200; + auto maximumPhysicalStepSize = static_cast(0.01); if (argc > 1) { numberOfIterations = std::stoi(argv[1]); @@ -224,12 +243,17 @@ itkEuclideanDistancePointSetMetricRegistrationTest(int argc, char * argv[]) // // metric - using PointSetType = itk::PointSet; - using PointSetMetricType = itk::EuclideanDistancePointSetToPointSetMetricv4; + using PointSetType = itk::PointSet, Dimension>; + using PointSetMetricType = itk::PointToPlanePointSetToPointSetMetricv4; + // using PointSetMetricType = itk::EuclideanDistancePointSetToPointSetMetricv4; auto metric = PointSetMetricType::New(); + std::cout << "Metric is " << std::endl; + std::cout << metric << std::endl; + // transform using AffineTransformType = itk::AffineTransform; + // using AffineTransformType = itk::Rigid2DTransform; auto affineTransform = AffineTransformType::New(); affineTransform->SetIdentity(); std::cout << "XX Test with affine transform: " << std::endl; @@ -242,71 +266,71 @@ itkEuclideanDistancePointSetMetricRegistrationTest(int argc, char * argv[]) std::cerr << "Failed for affine transform." << std::endl; } - // - // Displacement field transform - // - - using DisplacementFieldTransformType = itk::DisplacementFieldTransform; - auto displacementTransform = DisplacementFieldTransformType::New(); - - // Setup the physical space to match the point set virtual domain, - // which is defined by the fixed point set since the fixed transform - // is identity. - using FieldType = DisplacementFieldTransformType::DisplacementFieldType; - using RegionType = FieldType::RegionType; - using RealType = DisplacementFieldTransformType::ScalarType; - - FieldType::SpacingType spacing; - spacing.Fill(static_cast(1.0)); - - FieldType::DirectionType direction; - direction.Fill(static_cast(0.0)); - for (unsigned int d = 0; d < Dimension; ++d) - { - direction[d][d] = static_cast(1.0); - } - - FieldType::PointType origin; - origin.Fill(static_cast(0.0)); - - RegionType::SizeType regionSize; - regionSize.Fill(static_cast(pointMax) + 1); - - RegionType::IndexType regionIndex; - regionIndex.Fill(0); - - RegionType region; - region.SetSize(regionSize); - region.SetIndex(regionIndex); - - auto displacementField = FieldType::New(); - displacementField->SetOrigin(origin); - displacementField->SetDirection(direction); - displacementField->SetSpacing(spacing); - displacementField->SetRegions(region); - displacementField->Allocate(); - DisplacementFieldTransformType::OutputVectorType zeroVector; - zeroVector.Fill(static_cast(0.0)); - displacementField->FillBuffer(zeroVector); - displacementTransform->SetDisplacementField(displacementField); - - // metric - using PointSetMetricType = itk::EuclideanDistancePointSetToPointSetMetricv4; - auto metric2 = PointSetMetricType::New(); - // If we don't set the virtual domain when using a displacement field transform, the - // metric takes it from the transform during initialization. - // metric2->SetVirtualDomain( spacing, origin, direction, region ); - - std::cout << "XX Testing with displacement field transform." << std::endl; - oneResult = itkEuclideanDistancePointSetMetricRegistrationTestRun( - numberOfIterations, maximumPhysicalStepSize, pointMax, displacementTransform, metric2); - if (oneResult == EXIT_FAILURE) - { - finalResult = EXIT_FAILURE; - std::cerr << "Failed for displacement transform." << std::endl; - } + // // + // // Displacement field transform + // // + + // using DisplacementFieldTransformType = itk::DisplacementFieldTransform; + // auto displacementTransform = DisplacementFieldTransformType::New(); + + // // Setup the physical space to match the point set virtual domain, + // // which is defined by the fixed point set since the fixed transform + // // is identity. + // using FieldType = DisplacementFieldTransformType::DisplacementFieldType; + // using RegionType = FieldType::RegionType; + // using RealType = DisplacementFieldTransformType::ScalarType; + + // FieldType::SpacingType spacing; + // spacing.Fill(static_cast(1.0)); + + // FieldType::DirectionType direction; + // direction.Fill(static_cast(0.0)); + // for (unsigned int d = 0; d < Dimension; ++d) + // { + // direction[d][d] = static_cast(1.0); + // } + + // FieldType::PointType origin; + // origin.Fill(static_cast(0.0)); + + // RegionType::SizeType regionSize; + // regionSize.Fill(static_cast(pointMax) + 1); + + // RegionType::IndexType regionIndex; + // regionIndex.Fill(0); + + // RegionType region; + // region.SetSize(regionSize); + // region.SetIndex(regionIndex); + + // auto displacementField = FieldType::New(); + // displacementField->SetOrigin(origin); + // displacementField->SetDirection(direction); + // displacementField->SetSpacing(spacing); + // displacementField->SetRegions(region); + // displacementField->Allocate(); + // DisplacementFieldTransformType::OutputVectorType zeroVector; + // zeroVector.Fill(static_cast(0.0)); + // displacementField->FillBuffer(zeroVector); + // displacementTransform->SetDisplacementField(displacementField); + + // // metric + // using PointSetMetricType = itk::PointToPlanePointSetToPointSetMetricv4; + // auto metric2 = PointSetMetricType::New(); + // // If we don't set the virtual domain when using a displacement field transform, the + // // metric takes it from the transform during initialization. + // // metric2->SetVirtualDomain( spacing, origin, direction, region ); + + // std::cout << "XX Testing with displacement field transform." << std::endl; + // oneResult = itkEuclideanDistancePointSetMetricRegistrationTestRun( + // numberOfIterations, maximumPhysicalStepSize, pointMax, displacementTransform, metric2); + // if (oneResult == EXIT_FAILURE) + // { + // finalResult = EXIT_FAILURE; + // std::cerr << "Failed for displacement transform." << std::endl; + // } return finalResult; } diff --git a/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricTest.cxx b/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricTest.cxx index 5cdb8acb447..4f3d0c603a1 100644 --- a/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricTest.cxx +++ b/Modules/Registration/Metricsv4/test/itkEuclideanDistancePointSetMetricTest.cxx @@ -17,6 +17,7 @@ *=========================================================================*/ #include "itkEuclideanDistancePointSetToPointSetMetricv4.h" +#include "itkPointToPlanePointSetToPointSetMetricv4.h" #include "itkTranslationTransform.h" #include @@ -26,7 +27,7 @@ template int itkEuclideanDistancePointSetMetricTestRun() { - using PointSetType = itk::PointSet; + using PointSetType = itk::PointSet; using PointType = typename PointSetType::PointType;