forked from VariationalRegistration/VariationalRegistration
-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathitkVariationalRegistrationFunction.h
More file actions
307 lines (256 loc) · 8.79 KB
/
itkVariationalRegistrationFunction.h
File metadata and controls
307 lines (256 loc) · 8.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
/*=========================================================================
*
* Copyright NumFOCUS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*=========================================================================*/
#ifndef itkVariationalRegistrationFunction_h
#define itkVariationalRegistrationFunction_h
#include "itkFiniteDifferenceFunction.h"
// #include "itkWarpImageFilter.h"
#include "itkContinuousBorderWarpImageFilter.h"
#include <mutex>
namespace itk
{
/** \class itk::VariationalRegistrationFunction
 *
 * \brief Base class for force calculation in the variational registration framework.
 *
 * This class is templated over fixed image type, moving image type and deformation field type.
 * This function has the fixed image, the moving image and the current displacement field as input
 * and computes an update value in ComputeUpdate().
 *
 * Implement a concrete force type in a subclass; override the methods
 * InitializeIteration() and ComputeUpdate().
 *
 * \sa VariationalRegistrationFilter
 *
 * \ingroup FiniteDifferenceFunctions
 * \ingroup VariationalRegistration
 *
 * \note This class was developed with funding from the German Research
 * Foundation (DFG: EH 224/3-1 and HA 235/9-1).
 * \author Alexander Schmidt-Richberg
 * \author Rene Werner
 * \author Jan Ehrhardt
 */
template <typename TFixedImage, typename TMovingImage, typename TDisplacementField>
class VariationalRegistrationFunction : public FiniteDifferenceFunction<TDisplacementField>
{
public:
  ITK_DISALLOW_COPY_AND_MOVE(VariationalRegistrationFunction);

  /** Standard class type aliases. */
  using Self = VariationalRegistrationFunction;
  using Superclass = FiniteDifferenceFunction<TDisplacementField>;
  using Pointer = SmartPointer<Self>;
  using ConstPointer = SmartPointer<const Self>;
  using TimeStepType = typename Superclass::TimeStepType;

  /** Run-time type information (and related methods). */
  itkTypeMacro(VariationalRegistrationFunction, FiniteDifferenceFunction);

  /** Get image dimension (taken from the superclass, i.e. the displacement field). */
  static constexpr unsigned int ImageDimension = Superclass::ImageDimension;

  /** Moving image type. */
  using MovingImageType = TMovingImage;
  using MovingImagePointer = typename MovingImageType::ConstPointer;

  /** Fixed image type. */
  using FixedImageType = TFixedImage;
  using FixedImagePointer = typename FixedImageType::ConstPointer;

  /** Warped image type. The moving image is warped into the domain of the
   * fixed image, so the warped image uses the fixed image type. */
  using WarpedImageType = TFixedImage;
  using WarpedImagePointer = typename WarpedImageType::ConstPointer;

  /** Deformation field type. */
  using DisplacementFieldType = TDisplacementField;
  using DisplacementFieldTypePointer = typename DisplacementFieldType::ConstPointer;

  /** Mask image type (unsigned char pixels, same dimension as the images). */
  using MaskImagePixelType = unsigned char;
  using MaskImageType = Image<MaskImagePixelType, ImageDimension>;
  using MaskImagePointer = typename MaskImageType::ConstPointer;

  // To use the standard ITK warper instead (not recommended), include
  // "itkWarpImageFilter.h" and replace the alias below with:
  // using MovingImageWarperType = itk::WarpImageFilter<FixedImageType, WarpedImageType, DisplacementFieldType>;

  /** Type alias of the warp image filter.
   * NOTE(review): the first template argument (the warper's input image type)
   * is FixedImageType although the filter warps the moving image; this appears
   * to assume TMovingImage and TFixedImage are the same type — confirm. */
  using MovingImageWarperType =
    itk::ContinuousBorderWarpImageFilter<FixedImageType, WarpedImageType, DisplacementFieldType>;
  using MovingImageWarperPointer = typename MovingImageWarperType::Pointer;

  /** Set the moving image. */
  virtual void
  SetMovingImage(const MovingImageType * ptr)
  {
    m_MovingImage = ptr;
  }

  /** Get the moving image. */
  virtual const MovingImageType *
  GetMovingImage() const
  {
    return m_MovingImage;
  }

  /** Set the fixed image. */
  virtual void
  SetFixedImage(const FixedImageType * ptr)
  {
    m_FixedImage = ptr;
  }

  /** Get the fixed image. */
  virtual const FixedImageType *
  GetFixedImage() const
  {
    return m_FixedImage;
  }

  /** Set the deformation field. The field is stored as a const pointer. */
  virtual void
  SetDisplacementField(DisplacementFieldType * ptr)
  {
    m_DisplacementField = ptr;
  }

  /** Get the deformation field. */
  virtual const DisplacementFieldType *
  GetDisplacementField() const
  {
    return m_DisplacementField;
  }

  /** Set the mask image. */
  virtual void
  SetMaskImage(const MaskImageType * ptr)
  {
    m_MaskImage = ptr;
  }

  /** Get the mask image. */
  virtual const MaskImageType *
  GetMaskImage() const
  {
    return m_MaskImage;
  }

  /** Set the filter used to warp the moving image. */
  virtual void
  SetMovingImageWarper(MovingImageWarperType * ptr)
  {
    m_MovingImageWarper = ptr;
  }

  /** Get the filter used to warp the moving image. */
  virtual const MovingImageWarperType *
  GetMovingImageWarper() const
  {
    return m_MovingImageWarper;
  }

  /** Set the time step. This time step is returned by ComputeGlobalTimeStep(). */
  virtual void
  SetTimeStep(TimeStepType timeStep)
  {
    m_TimeStep = timeStep;
  }

  /** Get the time step. */
  virtual const TimeStepType
  GetTimeStep() const
  {
    return m_TimeStep;
  }

  /** Set the MaskBackgroundThreshold. All pixels of the mask image will be
   * treated as background if they are <= this threshold. */
  virtual void
  SetMaskBackgroundThreshold(MaskImagePixelType threshold)
  {
    m_MaskBackgroundThreshold = threshold;
  }

  /** Get the MaskBackgroundThreshold. All pixels of the mask image will be
   * treated as background if they are <= this threshold. */
  virtual MaskImagePixelType
  GetMaskBackgroundThreshold() const
  {
    return m_MaskBackgroundThreshold;
  }

  /** Set the object's state before each iteration. */
  void
  InitializeIteration() override;

  /** Computes the time step for an update.
   * Returns the constant time step set via SetTimeStep(); the global data
   * argument is ignored.
   * \sa SetTimeStep() */
  TimeStepType
  ComputeGlobalTimeStep(void * itkNotUsed(GlobalData)) const override
  {
    return m_TimeStep;
  }

  /** Return a pointer to a global data structure that is passed to
   * this object from the solver at each calculation. */
  void *
  GetGlobalDataPointer() const override;

  /** Release memory for a global data structure obtained from
   * GetGlobalDataPointer(). */
  void
  ReleaseGlobalDataPointer(void * GlobalData) const override;

  //
  // Metric accessor methods

  /** Get the metric value. The metric value is the mean square difference
   * in intensity between the fixed image and the transformed moving image,
   * computed over the overlapping region between the two images. */
  virtual double
  GetMetric() const
  {
    return m_Metric;
  }

  /** Get the RMS change in the deformation field. */
  virtual double
  GetRMSChange() const
  {
    return m_RMSChange;
  }

protected:
  VariationalRegistrationFunction();
  ~VariationalRegistrationFunction() override = default;

  /** Print information about the filter. */
  void
  PrintSelf(std::ostream & os, Indent indent) const override;

  /** Warp the moving image into the domain of the fixed image using the
   * deformation field. */
  virtual void
  WarpMovingImage();

  /** Get the warped image. */
  virtual const WarpedImagePointer
  GetWarpedImage() const;

  /** A global data type for this class of equation. Used to store
   * information for computing the metric. Members are value-initialized so
   * a freshly created instance starts with zeroed accumulators. */
  struct GlobalDataStruct
  {
    double        m_SumOfMetricValues{ 0.0 };
    SizeValueType m_NumberOfPixelsProcessed{ 0 };
    double        m_SumOfSquaredChange{ 0.0 };
  };

private:
  /** The moving image. */
  MovingImagePointer m_MovingImage;

  /** The fixed image. */
  FixedImagePointer m_FixedImage;

  /** The deformation field. */
  DisplacementFieldTypePointer m_DisplacementField;

  /** The mask image. */
  MaskImagePointer m_MaskImage;

  /** A filter to warp the moving image into the domain of the fixed image. */
  MovingImageWarperPointer m_MovingImageWarper;

  /** The global time step. */
  TimeStepType m_TimeStep{};

  /** Threshold to define the background in the mask image. */
  MaskImagePixelType m_MaskBackgroundThreshold{};

  /** The metric value is the mean square difference in intensity between
   * the fixed image and the transformed moving image, computed over the
   * overlapping region between the two images. The accumulators below are
   * mutable because they are updated from const methods. Default member
   * initializers only apply when the constructor does not set a member. */
  mutable double        m_Metric{};
  mutable double        m_SumOfMetricValues{};
  mutable SizeValueType m_NumberOfPixelsProcessed{};
  mutable double        m_RMSChange{};
  mutable double        m_SumOfSquaredChange{};

  /** Mutex to protect modification of the metric accumulators. */
  mutable std::mutex m_MetricCalculationLock;
};
} // end namespace itk
#ifndef ITK_MANUAL_INSTANTIATION
# include "itkVariationalRegistrationFunction.hxx"
#endif
#endif