Skip to content

Commit ca59f2d

Browse files
committed
Added missing file
1 parent e73c00a commit ca59f2d

2 files changed

Lines changed: 183 additions & 0 deletions

File tree

Studio/Job/StatsGroupDWDJob.cpp

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
#include <pybind11/eigen.h>
2+
#include <pybind11/embed.h>
3+
#include <pybind11/stl.h>
4+
namespace py = pybind11;
5+
using namespace pybind11::literals; // to bring in the `_a` literal
6+
7+
#include <Job/StatsGroupDWDJob.h>
8+
#include <Logging.h>
9+
#include <jkqtplotter/graphs/jkqtpscatter.h>
10+
#include <jkqtplotter/jkqtplotter.h>
11+
12+
namespace shapeworks {
13+
14+
//---------------------------------------------------------------------------
15+
StatsGroupDWDJob::StatsGroupDWDJob() {}
16+
17+
//---------------------------------------------------------------------------
18+
void StatsGroupDWDJob::set_stats(ParticleShapeStatistics stats) { stats_ = stats; }
19+
20+
//---------------------------------------------------------------------------
21+
void StatsGroupDWDJob::run() {
22+
succeeded_ = false;
23+
Q_EMIT progress(0.1);
24+
stats_.principal_component_projections();
25+
auto pca_loadings = stats_.get_pca_loadings();
26+
Q_EMIT progress(0.2);
27+
28+
auto& group_ids = stats_.GroupID();
29+
30+
int num_samples = pca_loadings.rows();
31+
32+
Eigen::MatrixXd group_1_data;
33+
Eigen::MatrixXd group_2_data;
34+
35+
int group_1_count = std::count(group_ids.begin(), group_ids.end(), 1);
36+
int group_2_count = num_samples - group_1_count;
37+
if (group_1_count == 0 || group_2_count == 0) {
38+
return;
39+
}
40+
41+
group_1_data.resize(group_1_count, pca_loadings.cols());
42+
group_2_data.resize(group_2_count, pca_loadings.cols());
43+
44+
int group_1_idx = 0;
45+
int group_2_idx = 0;
46+
for (int i = 0; i < num_samples; i++) {
47+
if (group_ids[i] == 1) {
48+
group_1_data.row(group_1_idx++) = pca_loadings.row(i);
49+
} else {
50+
group_2_data.row(group_2_idx++) = pca_loadings.row(i);
51+
}
52+
}
53+
54+
try {
55+
py::module sw = py::module::import("shapeworks");
56+
py::object dwd_loadings = sw.attr("stats").attr("dwd_loadings");
57+
Q_EMIT progress(0.5);
58+
59+
using ResultType =
60+
std::tuple<Eigen::MatrixXd, Eigen::MatrixXd, Eigen::MatrixXd, Eigen::MatrixXd, Eigen::MatrixXd, Eigen::MatrixXd>;
61+
ResultType result = dwd_loadings(group_1_data.transpose(), group_2_data.transpose()).cast<ResultType>();
62+
63+
group1_x_ = std::get<0>(result);
64+
group2_x_ = std::get<1>(result);
65+
group1_pdf_ = std::get<2>(result);
66+
group2_pdf_ = std::get<3>(result);
67+
group1_map_ = std::get<4>(result);
68+
group2_map_ = std::get<5>(result);
69+
} catch (const std::exception& e) {
70+
SW_ERROR("DWD computation failed: {}", e.what());
71+
succeeded_ = false;
72+
return;
73+
}
74+
75+
succeeded_ = true;
76+
Q_EMIT progress(1.0);
77+
}
78+
79+
//---------------------------------------------------------------------------
80+
QString StatsGroupDWDJob::name() { return "Group DWD"; }
81+
82+
//---------------------------------------------------------------------------
83+
void StatsGroupDWDJob::plot(JKQTPlotter* plot, QString group_1_name, QString group_2_name) {
84+
JKQTPDatastore* ds = plot->getDatastore();
85+
ds->clear();
86+
plot->clearGraphs();
87+
88+
QString title = "DWD";
89+
90+
auto draw_line_plot = [&](Eigen::MatrixXd x, Eigen::MatrixXd y, QString name, QColor color) {
91+
QVector<double> xv, yv;
92+
for (int i = 0; i < x.size(); i++) {
93+
xv << x(i);
94+
yv << y(i);
95+
}
96+
97+
QString x_label = name + " PDF";
98+
QString y_label = name + " y";
99+
100+
size_t column_x = ds->addCopiedColumn(xv, x_label);
101+
size_t column_y = ds->addCopiedColumn(yv, y_label);
102+
103+
JKQTPXYLineGraph* graph = new JKQTPXYLineGraph(plot);
104+
graph->setColor(color);
105+
graph->setSymbolType(JKQTPNoSymbol);
106+
graph->setXColumn(column_x);
107+
graph->setYColumn(column_y);
108+
graph->setTitle(name + " PDF");
109+
plot->addGraph(graph);
110+
};
111+
112+
draw_line_plot(group1_x_, group1_pdf_, group_1_name, QColor(239, 133, 54));
113+
draw_line_plot(group2_x_, group2_pdf_, group_2_name, Qt::blue);
114+
115+
auto draw_scatter_plot = [&](Eigen::MatrixXd map, QString name, QColor color) {
116+
QVector<double> x, y;
117+
for (int i = 0; i < map.size(); i++) {
118+
x << map(i);
119+
y << 0.01;
120+
}
121+
122+
int column_x = ds->addCopiedColumn(x, name + "scatter x");
123+
int column_y = ds->addCopiedColumn(y, name + "scatter y");
124+
125+
auto scatter = new JKQTPXYParametrizedScatterGraph(plot);
126+
scatter->setColor(color);
127+
scatter->setXColumn(column_x);
128+
scatter->setYColumn(column_y);
129+
scatter->setTitle(name + " Shape Mappings");
130+
plot->addGraph(scatter);
131+
};
132+
133+
draw_scatter_plot(group1_map_, group_1_name, QColor(239, 133, 54));
134+
draw_scatter_plot(group2_map_, group_2_name, Qt::blue);
135+
136+
plot->getPlotter()->setUseAntiAliasingForGraphs(true);
137+
plot->getPlotter()->setUseAntiAliasingForSystem(true);
138+
plot->getPlotter()->setUseAntiAliasingForText(true);
139+
plot->getPlotter()->setPlotLabelFontSize(18);
140+
plot->getPlotter()->setPlotLabel("\\textbf{" + title + "}");
141+
plot->getPlotter()->setDefaultTextSize(14);
142+
plot->getPlotter()->setShowKey(true);
143+
144+
plot->getXAxis()->setAxisLabel("Shape mapping to DWD discrimination of variation between population means");
145+
plot->getXAxis()->setLabelFontSize(8);
146+
plot->getYAxis()->setAxisLabel("Probability Density");
147+
plot->getYAxis()->setLabelFontSize(14);
148+
149+
plot->clearAllMouseWheelActions();
150+
plot->setMousePositionShown(false);
151+
plot->setMinimumSize(250, 250);
152+
plot->zoomToFit();
153+
}
154+
} // namespace shapeworks

Studio/Job/StatsGroupDWDJob.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#pragma once
2+
#include <Job/Job.h>
3+
#include <ParticleShapeStatistics.h>
4+
5+
class JKQTPlotter;
6+
7+
namespace shapeworks {
8+
9+
class StatsGroupDWDJob : public Job {
10+
Q_OBJECT
11+
public:
12+
StatsGroupDWDJob();
13+
14+
void set_stats(ParticleShapeStatistics stats);
15+
16+
void run() override;
17+
18+
QString name() override;
19+
20+
void plot(JKQTPlotter* plot, QString group_1_name, QString group_2_name);
21+
22+
bool succeeded() const { return succeeded_; }
23+
24+
private:
25+
bool succeeded_ = false;
26+
ParticleShapeStatistics stats_;
27+
Eigen::MatrixXd group1_x_, group2_x_, group1_pdf_, group2_pdf_, group1_map_, group2_map_;
28+
};
29+
} // namespace shapeworks

0 commit comments

Comments
 (0)