Skip to content

Commit 27607ef

Browse files
committed
Implemented some image comparison methods: MSE, PSNR and SSIM
1 parent 3918874 commit 27607ef

11 files changed

Lines changed: 870 additions & 0 deletions

File tree

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
fast_add_sources(
2+
MSE.cpp
3+
MSE.hpp
4+
PSNR.cpp
5+
SSIM.cpp
6+
)
7+
fast_add_process_object(MeanSquaredError MSE.hpp)
8+
fast_add_process_object(PeakSignalToNoiseRatio PSNR.hpp)
9+
fast_add_process_object(StructuralSimilarityIndexMeasure SSIM.hpp)
10+
fast_add_test_sources(ImageComparisonTests.cpp)
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
#include <FAST/Testing.hpp>
2+
#include <FAST/Importers/ImageFileImporter.hpp>
3+
#include <FAST/Importers/WholeSlideImageImporter.hpp>
4+
#include <FAST/Data/ImagePyramid.hpp>
5+
#include <FAST/Algorithms/GaussianSmoothing/GaussianSmoothing.hpp>
6+
#include <FAST/Visualization/Plotting/LinePlotter.hpp>
7+
#include <FAST/Algorithms/ImageCropper/ImageCropper.hpp>
8+
#include "PSNR.hpp"
9+
#include "SSIM.hpp"
10+
11+
using namespace fast;
12+
13+
TEST_CASE("MSE on 2D images", "[fast][MSE][ImageComparison]") {
14+
auto importer1 = ImageFileImporter::create(Config::getTestDataPath() + "US/Heart/ApicalFourChamber/US-2D_0.mhd");
15+
auto image1 = importer1->run()->getOutput<Image>();
16+
auto importer2 = ImageFileImporter::create(Config::getTestDataPath() + "US/Heart/ApicalFourChamber/US-2D_20.mhd");
17+
18+
auto mse = MSE::create()
19+
->connect(0, importer1)
20+
->connect(1, importer2);
21+
CHECK_NOTHROW(
22+
mse->run();
23+
);
24+
25+
auto output = mse->getOutput<Image>();
26+
CHECK(mse->get() == Approx(676.59));
27+
CHECK(output->getWidth() == image1->getWidth());
28+
CHECK(output->getHeight() == image1->getHeight());
29+
CHECK(output->getDataType() == TYPE_FLOAT);
30+
CHECK(output->getNrOfChannels() == image1->getNrOfChannels());
31+
}
32+
33+
TEST_CASE("MSE on 2D color images", "[fast][MSE][ImageComparison]") {
34+
auto importer1 = WholeSlideImageImporter::create(Config::getTestDataPath() + "WSI/CMU-1.svs");
35+
auto wsi = importer1->run()->getOutput<ImagePyramid>();
36+
auto access = wsi->getAccess(ACCESS_READ);
37+
auto image1 = access->getLevelAsImage(wsi->getNrOfLevels()-1);
38+
auto image2 = GaussianSmoothing::create()->connect(image1);
39+
40+
auto mse = MSE::create()->connect(image1)->connect(1, image2);
41+
CHECK_NOTHROW(
42+
mse->run();
43+
);
44+
45+
auto output = mse->getOutput<Image>();
46+
CHECK(output->getWidth() == image1->getWidth());
47+
CHECK(output->getHeight() == image1->getHeight());
48+
CHECK(output->getDataType() == TYPE_FLOAT);
49+
CHECK(output->getNrOfChannels() == image1->getNrOfChannels());
50+
}
51+
52+
TEST_CASE("MSE on 3D images", "[fast][MSE][ImageComparison]") {
53+
auto importer1 = ImageFileImporter::create(Config::getTestDataPath() + "US/Ball/US-3Dt_0.mhd");
54+
auto image1 = importer1->run()->getOutput<Image>();
55+
auto importer2 = ImageFileImporter::create(Config::getTestDataPath() + "US/Ball/US-3Dt_20.mhd");
56+
57+
auto mse = MSE::create()
58+
->connect(0, importer1)
59+
->connect(1, importer2);
60+
CHECK_NOTHROW(
61+
mse->run();
62+
);
63+
64+
auto output = mse->getOutput<Image>();
65+
CHECK(mse->get() == Approx(214.87));
66+
CHECK(output->getWidth() == image1->getWidth());
67+
CHECK(output->getHeight() == image1->getHeight());
68+
CHECK(output->getDepth() == image1->getDepth());
69+
CHECK(output->getDataType() == TYPE_FLOAT);
70+
CHECK(output->getNrOfChannels() == image1->getNrOfChannels());
71+
}
72+
73+
74+
TEST_CASE("PSNR on 2D image", "[fast][PSNR][ImageComparison]") {
75+
auto importer1 = ImageFileImporter::create(Config::getTestDataPath() + "US/Heart/ApicalFourChamber/US-2D_0.mhd");
76+
auto image1 = importer1->run()->getOutput<Image>();
77+
auto importer2 = ImageFileImporter::create(Config::getTestDataPath() + "US/Heart/ApicalFourChamber/US-2D_20.mhd");
78+
79+
auto psnr = PSNR::create(255)
80+
->connect(0, importer1)
81+
->connect(1, importer2);
82+
CHECK_NOTHROW(
83+
psnr->run();
84+
);
85+
86+
auto output = psnr->getOutput<Image>();
87+
CHECK(psnr->get() == Approx(19.8275));
88+
CHECK(output->getWidth() == image1->getWidth());
89+
CHECK(output->getHeight() == image1->getHeight());
90+
CHECK(output->getDepth() == image1->getDepth());
91+
CHECK(output->getDataType() == TYPE_FLOAT);
92+
CHECK(output->getNrOfChannels() == image1->getNrOfChannels());
93+
}
94+
95+
TEST_CASE("SSIM on 2D image", "[fast][SSIM][ImageComparison]") {
96+
auto importer1 = ImageFileImporter::create(Config::getTestDataPath() + "US/Heart/ApicalFourChamber/US-2D_0.mhd");
97+
auto image1 = importer1->run()->getOutput<Image>();
98+
auto importer2 = ImageFileImporter::create(Config::getTestDataPath() + "US/Heart/ApicalFourChamber/US-2D_20.mhd");
99+
100+
// Crop to be able to compare value with other implementations
101+
auto cropper1 = ImageCropper::create(Vector2i(image1->getWidth()-10, image1->getHeight()-10), Vector2i(5, 5))->connect(image1);
102+
auto cropper2 = ImageCropper::create(Vector2i(image1->getWidth()-10, image1->getHeight()-10), Vector2i(5, 5))->connect(importer2);
103+
image1 = cropper1->run()->getOutput<Image>();
104+
105+
auto ssim = SSIM::create(255, 0, Vector3i::Constant(11), Vector3f::Constant(1.5f))
106+
->connect(0, image1)
107+
->connect(1, cropper2);
108+
ssim->enableRuntimeMeasurements();
109+
ssim->run();
110+
111+
auto output = ssim->getOutput<Image>();
112+
auto value = ssim->getOutput<FloatScalar>(1);
113+
std::cout << ssim->get() << std::endl;
114+
ssim->getAllRuntimes()->printAll();
115+
116+
CHECK(value->get() == ssim->get());
117+
CHECK(ssim->get() == Approx(0.59).epsilon(0.01));
118+
CHECK(output->getWidth() == image1->getWidth());
119+
CHECK(output->getHeight() == image1->getHeight());
120+
CHECK(output->getDepth() == image1->getDepth());
121+
CHECK(output->getDataType() == TYPE_FLOAT);
122+
CHECK(output->getNrOfChannels() == image1->getNrOfChannels());
123+
}
124+
125+
TEST_CASE("SSIM on 2D color image", "[fast][SSIM][ImageComparison]") {
126+
auto importer1 = WholeSlideImageImporter::create(Config::getTestDataPath() + "WSI/CMU-1.svs");
127+
auto wsi = importer1->run()->getOutput<ImagePyramid>();
128+
auto access = wsi->getAccess(ACCESS_READ);
129+
auto image1 = access->getLevelAsImage(wsi->getNrOfLevels()-1);
130+
auto image2 = GaussianSmoothing::create()->connect(image1);
131+
132+
auto ssim = SSIM::create(255)->connect(image1)->connect(1, image2);
133+
ssim->enableRuntimeMeasurements();
134+
135+
CHECK_NOTHROW(
136+
ssim->run();
137+
);
138+
139+
auto output = ssim->getOutput<Image>(0);
140+
auto value = ssim->getOutput<FloatScalar>(1);
141+
ssim->getAllRuntimes()->printAll();
142+
143+
CHECK(value->get() == ssim->get());
144+
CHECK(output->getWidth() == image1->getWidth());
145+
CHECK(output->getHeight() == image1->getHeight());
146+
CHECK(output->getDepth() == image1->getDepth());
147+
CHECK(output->getDataType() == TYPE_FLOAT);
148+
CHECK(output->getNrOfChannels() == image1->getNrOfChannels());
149+
}
150+
151+
TEST_CASE("SSIM on 3D images", "[fast][SSIM][ImageComparison]") {
152+
auto importer1 = ImageFileImporter::create(Config::getTestDataPath() + "US/Ball/US-3Dt_0.mhd");
153+
auto image1 = importer1->run()->getOutput<Image>();
154+
auto importer2 = ImageFileImporter::create(Config::getTestDataPath() + "US/Ball/US-3Dt_20.mhd");
155+
156+
// Crop to be able to compare value with other implementations
157+
auto cropper1 = ImageCropper::create(Vector3i(image1->getWidth()-10, image1->getHeight()-10, image1->getDepth()-10), Vector3i(5, 5, 5))->connect(image1);
158+
auto cropper2 = ImageCropper::create(Vector3i(image1->getWidth()-10, image1->getHeight()-10, image1->getDepth()-10), Vector3i(5, 5, 5))->connect(importer2);
159+
image1 = cropper1->run()->getOutput<Image>();
160+
161+
auto ssim = SSIM::create(255)
162+
->connect(0, image1)
163+
->connect(1, cropper2);
164+
CHECK_NOTHROW(
165+
ssim->run();
166+
);
167+
std::cout << ssim->get() << std::endl;
168+
169+
auto output = ssim->getOutput<Image>(0);
170+
auto value = ssim->getOutput<FloatScalar>(1);
171+
172+
CHECK(value->get() == ssim->get());
173+
CHECK(ssim->get() == Approx(0.73).epsilon(0.01));
174+
CHECK(output->getWidth() == image1->getWidth());
175+
CHECK(output->getHeight() == image1->getHeight());
176+
CHECK(output->getDepth() == image1->getDepth());
177+
CHECK(output->getDataType() == TYPE_FLOAT);
178+
CHECK(output->getNrOfChannels() == image1->getNrOfChannels());
179+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
__const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
2+
3+
float4 readImageAsFloat2D(__read_only image2d_t image, sampler_t sampler, int2 position) {
4+
int dataType = get_image_channel_data_type(image);
5+
if(dataType == CLK_FLOAT || dataType == CLK_SNORM_INT16 || dataType == CLK_UNORM_INT16) {
6+
return read_imagef(image, sampler, position);
7+
} else if(dataType == CLK_SIGNED_INT8 || dataType == CLK_SIGNED_INT16 || dataType == CLK_SIGNED_INT32) {
8+
return convert_float4(read_imagei(image, sampler, position));
9+
} else {
10+
return convert_float4(read_imageui(image, sampler, position));
11+
}
12+
}
13+
14+
__kernel void squaredError2D(
15+
__read_only image2d_t input1,
16+
__read_only image2d_t input2,
17+
__write_only image2d_t output
18+
) {
19+
const int2 pos = {get_global_id(0), get_global_id(1)};
20+
float4 value1 = readImageAsFloat2D(input1, sampler, pos);
21+
float4 value2 = readImageAsFloat2D(input2, sampler, pos);
22+
write_imagef(output, pos, (value1 - value2)*(value1 - value2));
23+
}
24+
25+
float4 readImageAsFloat3D(__read_only image3d_t image, sampler_t sampler, int4 position) {
26+
int dataType = get_image_channel_data_type(image);
27+
if(dataType == CLK_FLOAT || dataType == CLK_SNORM_INT16 || dataType == CLK_UNORM_INT16) {
28+
return read_imagef(image, sampler, position);
29+
} else if(dataType == CLK_SIGNED_INT8 || dataType == CLK_SIGNED_INT16 || dataType == CLK_SIGNED_INT32) {
30+
return convert_float4(read_imagei(image, sampler, position));
31+
} else {
32+
return convert_float4(read_imageui(image, sampler, position));
33+
}
34+
}
35+
36+
#ifdef fast_3d_image_writes
37+
__kernel void squaredError3D(
38+
__read_only image3d_t input1,
39+
__read_only image3d_t input2,
40+
__write_only image3d_t output
41+
) {
42+
const int4 pos = {get_global_id(0), get_global_id(1), get_global_id(2), 0};
43+
float4 value1 = readImageAsFloat3D(input1, sampler, pos);
44+
float4 value2 = readImageAsFloat3D(input2, sampler, pos);
45+
write_imagef(output, pos, (value1 - value2)*(value1 - value2));
46+
}
47+
#else
48+
__kernel void squaredError3DBuffer(
49+
__read_only image3d_t input1,
50+
__read_only image3d_t input2,
51+
__global float* output,
52+
__private const int channels
53+
) {
54+
const int4 pos = {get_global_id(0), get_global_id(1), get_global_id(2), 0};
55+
float4 value1 = readImageAsFloat3D(input1, sampler, pos);
56+
float4 value2 = readImageAsFloat3D(input2, sampler, pos);
57+
float4 squaredError = (value1 - value2)*(value1 - value2);
58+
output[(pos.x + pos.y*get_global_size(0) + pos.z*get_global_size(0)*get_global_size(1))*channels] = squaredError.x;
59+
if(channels > 1)
60+
output[(pos.x + pos.y*get_global_size(0) + pos.z*get_global_size(0)*get_global_size(1))*channels + 1] = squaredError.y;
61+
if(channels > 2)
62+
output[(pos.x + pos.y*get_global_size(0) + pos.z*get_global_size(0)*get_global_size(1))*channels + 2] = squaredError.z;
63+
if(channels > 3)
64+
output[(pos.x + pos.y*get_global_size(0) + pos.z*get_global_size(0)*get_global_size(1))*channels + 3] = squaredError.w;
65+
}
66+
#endif
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
#include <FAST/Visualization/Plotting/LinePlotter.hpp>
2+
#include <FAST/Algorithms/ImageChannelConverter/ImageChannelConverter.hpp>
3+
#include "MSE.hpp"
4+
5+
namespace fast {
6+
7+
MeanSquaredError::MeanSquaredError() {
8+
createInputPort(0);
9+
createInputPort(1);
10+
createOutputPort(0);
11+
createOutputPort(1);
12+
13+
createOpenCLProgram(Config::getKernelSourcePath() + "Algorithms/ImageComparison/MSE.cl");
14+
}
15+
16+
void MeanSquaredError::execute() {
17+
auto image1 = getInputData<Image>(0);
18+
auto image2 = getInputData<Image>(1);
19+
20+
auto outputImage = calculateSquaredDiffImage(image1, image2);
21+
auto MSE = calculateMSE(outputImage);
22+
23+
m_value = MSE;
24+
auto outputValue = FloatScalar::create(MSE);
25+
26+
addOutputData(0, outputImage);
27+
addOutputData(1, outputValue);
28+
}
29+
30+
float MeanSquaredError::get() const {
31+
return m_value;
32+
}
33+
34+
Image::pointer MeanSquaredError::calculateSquaredDiffImage(Image::pointer image1, Image::pointer image2) {
35+
if(image1->getSize() != image2->getSize())
36+
throw Exception("Images must be the same size");
37+
38+
if(image1->getNrOfChannels() != image2->getNrOfChannels())
39+
throw Exception("Images must have same number of channels");
40+
41+
auto output = Image::create(image1->getSize(), TYPE_FLOAT, image1->getNrOfChannels());
42+
SceneGraph::setParentNode(output, image1);
43+
output->setSpacing(image1->getSpacing());
44+
45+
Kernel kernel;
46+
if(output->getDimensions() == 2) {
47+
kernel = getKernel("squaredError2D");
48+
} else {
49+
if(getMainOpenCLDevice()->isWritingTo3DTexturesSupported()) {
50+
kernel = getKernel("squaredError3D");
51+
} else {
52+
kernel = getKernel("squaredError3DBuffer");
53+
kernel.setArg("channels", output->getNrOfChannels());
54+
}
55+
}
56+
kernel.setArg("input1", image1);
57+
kernel.setArg("input2", image2);
58+
kernel.setArg("output", output);
59+
60+
getQueue().add(kernel, output->getSize());
61+
62+
return output;
63+
}
64+
65+
float MeanSquaredError::calculateMSE(Image::pointer output) {
66+
float MSE = 0.0f;
67+
if(output->getNrOfChannels() > 1) {
68+
// FIXME calculateAverageIntensity doesn't support multi-channels yet, so to this for now:
69+
float MSEsum = 0.0f;
70+
for(int i = 0; i < output->getNrOfChannels(); ++i) {
71+
std::vector<int> toRemove;
72+
for(int j = 0; j < output->getNrOfChannels(); ++j) {
73+
if(i != j)
74+
toRemove.push_back(j);
75+
}
76+
auto singleChannelImage = ImageChannelConverter::create(toRemove)
77+
->connect(output)
78+
->run()
79+
->getOutput<Image>();
80+
MSEsum += singleChannelImage->calculateAverageIntensity();
81+
}
82+
MSE = MSEsum / output->getNrOfChannels();
83+
} else {
84+
MSE = output->calculateAverageIntensity();
85+
}
86+
return MSE;
87+
}
88+
89+
}

0 commit comments

Comments
 (0)