Skip to content

Commit b7a709d

Browse files
authored
Merge pull request #2521 from SCIInstitute/amorris/2520-dwd
Add DWD group analysis and fix LDA/DWD stale results
2 parents 5e5d10e + ca59f2d commit b7a709d

12 files changed

Lines changed: 463 additions & 29 deletions

File tree

Python/shapeworks/shapeworks/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
from shapeworks_py import *
66
from .conversion import sw2vtkImage, sw2vtkMesh
77
from .plot import plot_meshes, plot_volumes, plot_meshes_volumes_mix, add_mesh_to_plotter, add_volume_to_plotter, plot_mesh_contour,plot_pca_metrics,\
8-
pca_loadings_violinplot,plot_mode_line,visualize_reconstruction,lda_plot
8+
pca_loadings_violinplot,plot_mode_line,visualize_reconstruction,lda_plot,dwd_plot
99
from .utils import num_subplots, positive_factors, save_images, get_file_with_ext, find_reference_image_index, find_reference_mesh_index, load_mesh
1010
from .data import get_file_list, sample_images, sample_meshes
11-
from .stats import compute_pvalues_for_group_difference,lda
11+
from .stats import compute_pvalues_for_group_difference,lda,dwd_loadings
1212
from .network_analysis import NetworkAnalysis
1313
from .portal import download_dataset
1414
from .shape_scalars import run_mbpls

Python/shapeworks/shapeworks/plot.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ def lda_plot(group1_x,group2_x,group1_pdf,group2_pdf,group1_map,group2_map,lda_d
481481
group2_num = len(group2_map)
482482
plt.plot(group1_x, group1_pdf, label = labels[0] + ' PDF',linewidth=10)
483483
plt.plot(group2_x, group2_pdf, label = labels[1] + ' PDF',linewidth=10)
484-
484+
485485
plt.scatter(group1_map, 0.01*np.ones((group1_num)), s=330, label = labels[0] + ' Shape Mappings', edgecolors='black',linewidths=5)
486486
plt.scatter(group2_map, 0.01*np.ones((group2_num)), s=330, label = labels[1] + ' Shape Mappings', edgecolors='black',linewidths=5)
487487
plt.ylabel("Probability Density")
@@ -491,4 +491,25 @@ def lda_plot(group1_x,group2_x,group1_pdf,group2_pdf,group1_map,group2_map,lda_d
491491
plt.close(fig)
492492

493493
print("Figure saved in directory -" + lda_dir)
494-
print()
494+
print()
495+
496+
def dwd_plot(group1_x,group2_x,group1_pdf,group2_pdf,group1_map,group2_map,dwd_dir,labels):
497+
498+
plt.figure(dpi=50,figsize=(14,14))
499+
fig = plt.gcf()
500+
plt.rcParams['font.size'] = '20'
501+
group1_num = len(group1_map)
502+
group2_num = len(group2_map)
503+
plt.plot(group1_x, group1_pdf, label = labels[0] + ' PDF',linewidth=10)
504+
plt.plot(group2_x, group2_pdf, label = labels[1] + ' PDF',linewidth=10)
505+
506+
plt.scatter(group1_map, 0.01*np.ones((group1_num)), s=330, label = labels[0] + ' Shape Mappings', edgecolors='black',linewidths=5)
507+
plt.scatter(group2_map, 0.01*np.ones((group2_num)), s=330, label = labels[1] + ' Shape Mappings', edgecolors='black',linewidths=5)
508+
plt.ylabel("Probability Density")
509+
plt.xlabel('Shape mapping to DWD discrimination of variation between population means')
510+
plt.legend(loc='upper right')
511+
plt.savefig(dwd_dir+"/DWD.png")
512+
plt.close(fig)
513+
514+
print("Figure saved in directory -" + dwd_dir)
515+
print()

Python/shapeworks/shapeworks/stats.py

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,33 +71,49 @@ def compute_pvalues_for_group_difference_data(group_0_data, group_1_data, permut
7171

7272

7373
def normalize(subj_map, group1_mean_map, group2_mean_map):
74-
slope = (2.0 / (group2_mean_map - group1_mean_map))
74+
denom = group2_mean_map - group1_mean_map
75+
if abs(denom) < 1e-12:
76+
return 0.0
77+
slope = 2.0 / denom
7578
subj_diff = subj_map - group1_mean_map
7679
subj_map_normalized = slope * subj_diff - 1
7780
return subj_map_normalized
7881

7982

8083
def lda_loadings(group1_data, group2_data):
81-
group1_num = np.shape(group1_data)[1]
82-
group2_num = np.shape(group2_data)[1]
83-
8484
combined_data = np.concatenate((group1_data, group2_data), axis=1)
8585
group1_mean = np.mean(group1_data, axis=1)
8686
group2_mean = np.mean(group2_data, axis=1)
8787

88-
overall_mean = np.mean(combined_data, axis=1)
89-
9088
diffVect = group1_mean - group2_mean
9189

90+
return _project_and_pdf(diffVect, group1_data, group2_data, combined_data)
91+
92+
93+
def _project_and_pdf(diffVect, group1_data, group2_data, combined_data):
94+
"""Shared logic for projecting groups onto a discriminant direction and fitting PDFs.
95+
96+
Args:
97+
diffVect: Discriminant direction vector (features,)
98+
group1_data: PCA loadings for group 1 (features x samples)
99+
group2_data: PCA loadings for group 2 (features x samples)
100+
combined_data: Concatenation of group1_data and group2_data (features x all_samples)
101+
102+
Returns: 6-tuple (group1_x, group2_x, group1_pdf, group2_pdf, group1_map, group2_map)
103+
"""
104+
group1_num = group1_data.shape[1]
105+
group2_num = group2_data.shape[1]
106+
107+
group1_mean = np.mean(group1_data, axis=1)
108+
group2_mean = np.mean(group2_data, axis=1)
109+
overall_mean = np.mean(combined_data, axis=1)
110+
92111
group1_mean_diff = group1_mean - overall_mean
93112
group2_mean_diff = group2_mean - overall_mean
94113

95114
group1_mean_map = np.dot(diffVect, group1_mean_diff)
96115
group2_mean_map = np.dot(diffVect, group2_mean_diff)
97116

98-
group1_mean_map_normalized = normalize(group1_mean_map, group1_mean_map, group2_mean_map)
99-
group2_mean_map_normalized = normalize(group2_mean_map, group1_mean_map, group2_mean_map)
100-
101117
group1_map = np.zeros((group1_num,))
102118
group2_map = np.zeros((group2_num,))
103119

@@ -117,6 +133,13 @@ def lda_loadings(group1_data, group2_data):
117133
group1_map_std = group1_map.std()
118134
group2_map_std = group2_map.std()
119135

136+
# Guard against zero std (all samples project to same point)
137+
min_std = 1e-6
138+
if group1_map_std < min_std:
139+
group1_map_std = min_std
140+
if group2_map_std < min_std:
141+
group2_map_std = min_std
142+
120143
group1_x = np.linspace(group1_map_mean - 6, group1_map_mean + 6, num=300)
121144
group2_x = np.linspace(group2_map_mean - 6, group2_map_mean + 6, num=300)
122145

@@ -125,6 +148,31 @@ def lda_loadings(group1_data, group2_data):
125148
return group1_x, group2_x, group1_pdf, group2_pdf, group1_map, group2_map
126149

127150

151+
def dwd_loadings(group1_data, group2_data):
152+
from dwd.gen_dwd import GenDWD
153+
group1_num = np.shape(group1_data)[1]
154+
group2_num = np.shape(group2_data)[1]
155+
156+
if group1_num < 2 or group2_num < 2:
157+
raise ValueError(f"DWD requires at least 2 samples per group (got {group1_num} and {group2_num})")
158+
159+
combined_data = np.concatenate((group1_data, group2_data), axis=1)
160+
161+
# Fit GenDWD (samples x features)
162+
X = combined_data.T
163+
y = np.array([1]*group1_num + [-1]*group2_num)
164+
165+
try:
166+
model = GenDWD(lambd=1.0)
167+
model.fit(X, y)
168+
except Exception as e:
169+
raise RuntimeError(f"DWD fitting failed: {e}") from e
170+
171+
diffVect = model.coef_.flatten()
172+
173+
return _project_and_pdf(diffVect, group1_data, group2_data, combined_data)
174+
175+
128176
def lda(data):
129177
group_id = data["group_ids"].unique()
130178
group1_idxs = data.index[data['group_ids'] == 0].tolist()

Studio/Analysis/AnalysisTool.cpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <Job/NetworkAnalysisJob.h>
1212
#include <Job/ParticleNormalEvaluationJob.h>
1313
#include <Job/StatsGroupLDAJob.h>
14+
#include <Job/StatsGroupDWDJob.h>
1415
#include <Libs/Application/Job/PythonWorker.h>
1516
#include <Groom/GroomParameters.h>
1617
#include <Logging.h>
@@ -168,6 +169,12 @@ AnalysisTool::AnalysisTool(Preferences& prefs) : preferences_(prefs) {
168169
connect(group_lda_job_.data(), &StatsGroupLDAJob::progress, this, &AnalysisTool::handle_lda_progress);
169170
connect(group_lda_job_.data(), &StatsGroupLDAJob::finished, this, &AnalysisTool::handle_lda_complete);
170171

172+
ui_->dwd_graph->hide();
173+
ui_->dwd_hint_label->hide();
174+
group_dwd_job_ = QSharedPointer<StatsGroupDWDJob>::create();
175+
connect(group_dwd_job_.data(), &StatsGroupDWDJob::progress, this, &AnalysisTool::handle_dwd_progress);
176+
connect(group_dwd_job_.data(), &StatsGroupDWDJob::finished, this, &AnalysisTool::handle_dwd_complete);
177+
171178
connect(ui_->show_difference_to_mean, &QPushButton::clicked, this, &AnalysisTool::show_difference_to_mean_clicked);
172179

173180
connect(ui_->group_analysis_combo, qOverload<int>(&QComboBox::currentIndexChanged), this,
@@ -1620,6 +1627,24 @@ void AnalysisTool::update_lda_graph() {
16201627
}
16211628
}
16221629

1630+
//---------------------------------------------------------------------------
1631+
void AnalysisTool::update_dwd_graph() {
1632+
if (groups_active()) {
1633+
if (!dwd_computed_ && !group_dwd_job_running_) {
1634+
group_dwd_job_running_ = true;
1635+
ui_->dwd_label->show();
1636+
ui_->dwd_progress->setValue(0);
1637+
ui_->dwd_progress->setMaximum(0);
1638+
ui_->dwd_progress->update();
1639+
group_dwd_job_->set_stats(stats_);
1640+
app_->get_py_worker()->run_job(group_dwd_job_);
1641+
}
1642+
} else {
1643+
ui_->dwd_graph->setVisible(false);
1644+
ui_->dwd_hint_label->setVisible(false);
1645+
}
1646+
}
1647+
16231648
//---------------------------------------------------------------------------
16241649
void AnalysisTool::update_difference_particles() {
16251650
if (!stats_ready_) {
@@ -1679,7 +1704,10 @@ void AnalysisTool::group_changed() {
16791704
stats_ready_ = false;
16801705
group_pvalue_job_ = nullptr;
16811706
lda_computed_ = false;
1707+
dwd_computed_ = false;
16821708
compute_stats();
1709+
// Re-trigger LDA/DWD if currently visible
1710+
group_analysis_combo_changed();
16831711
}
16841712

16851713
//---------------------------------------------------------------------------
@@ -1909,12 +1937,64 @@ void AnalysisTool::handle_lda_complete() {
19091937
QString left_group = ui_->group_left->currentText();
19101938
QString right_group = ui_->group_right->currentText();
19111939

1940+
if (!group_lda_job_->succeeded()) {
1941+
ui_->lda_graph->setVisible(false);
1942+
if (left_group == right_group) {
1943+
ui_->lda_hint_label->setText("LDA requires two distinct groups.");
1944+
} else {
1945+
ui_->lda_hint_label->setText("LDA computation failed. Check log for details.");
1946+
}
1947+
ui_->lda_hint_label->setVisible(true);
1948+
QTimer::singleShot(0, this, &AnalysisTool::resize_tab_to_current);
1949+
return;
1950+
}
1951+
19121952
group_lda_job_->plot(ui_->lda_graph, left_group, right_group);
19131953
ui_->lda_graph->setVisible(true);
19141954
ui_->lda_hint_label->setVisible(true);
19151955
QTimer::singleShot(0, this, &AnalysisTool::resize_tab_to_current);
19161956
}
19171957

1958+
//---------------------------------------------------------------------------
1959+
void AnalysisTool::handle_dwd_progress(double progress) {
1960+
if (progress > 0) {
1961+
ui_->dwd_progress->setMaximum(100);
1962+
} else {
1963+
ui_->dwd_progress->setMaximum(0);
1964+
}
1965+
ui_->dwd_progress_widget->setVisible(progress < 1);
1966+
ui_->dwd_progress->setValue(progress * 100);
1967+
ui_->dwd_progress->update();
1968+
}
1969+
1970+
//---------------------------------------------------------------------------
1971+
void AnalysisTool::handle_dwd_complete() {
1972+
ui_->dwd_progress_widget->setVisible(false);
1973+
ui_->dwd_label->setVisible(false);
1974+
group_dwd_job_running_ = false;
1975+
dwd_computed_ = true;
1976+
1977+
QString left_group = ui_->group_left->currentText();
1978+
QString right_group = ui_->group_right->currentText();
1979+
1980+
if (!group_dwd_job_->succeeded()) {
1981+
ui_->dwd_graph->setVisible(false);
1982+
if (left_group == right_group) {
1983+
ui_->dwd_hint_label->setText("DWD requires two distinct groups.");
1984+
} else {
1985+
ui_->dwd_hint_label->setText("DWD computation failed. Check log for details.");
1986+
}
1987+
ui_->dwd_hint_label->setVisible(true);
1988+
QTimer::singleShot(0, this, &AnalysisTool::resize_tab_to_current);
1989+
return;
1990+
}
1991+
1992+
group_dwd_job_->plot(ui_->dwd_graph, left_group, right_group);
1993+
ui_->dwd_graph->setVisible(true);
1994+
ui_->dwd_hint_label->setVisible(true);
1995+
QTimer::singleShot(0, this, &AnalysisTool::resize_tab_to_current);
1996+
}
1997+
19181998
void AnalysisTool::handle_network_analysis_progress(int progress) {
19191999
if (progress > 0) {
19202000
ui_->network_progress->setMaximum(100);
@@ -1960,6 +2040,9 @@ void AnalysisTool::group_analysis_combo_changed() {
19602040
if (ui_->group_analysis_stacked_widget->currentWidget() == ui_->lda_page) {
19612041
update_lda_graph();
19622042
}
2043+
if (ui_->group_analysis_stacked_widget->currentWidget() == ui_->dwd_page) {
2044+
update_dwd_graph();
2045+
}
19632046
}
19642047
// Recalculate tab height since analysis content changed
19652048
QTimer::singleShot(0, this, &AnalysisTool::resize_tab_to_current);

Studio/Analysis/AnalysisTool.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class ShapeWorksStudioApp;
2828
class GroupPvalueJob;
2929
class NetworkAnalysisJob;
3030
class StatsGroupLDAJob;
31+
class StatsGroupDWDJob;
3132
class ParticleAreaPanel;
3233
class ShapeScalarPanel;
3334

@@ -37,7 +38,7 @@ class AnalysisTool : public QWidget {
3738
public:
3839
using AlignmentType = Analyze::AlignmentType;
3940

40-
enum GroupAnalysisType { None = 0, Pvalues = 1, NetworkAnalysis = 2, LDA = 3 };
41+
enum GroupAnalysisType { None = 0, Pvalues = 1, NetworkAnalysis = 2, LDA = 3, DWD = 4 };
4142

4243
enum McaMode { Vanilla, Within, Between };
4344

@@ -187,6 +188,9 @@ class AnalysisTool : public QWidget {
187188
void handle_lda_progress(double progress);
188189
void handle_lda_complete();
189190

191+
void handle_dwd_progress(double progress);
192+
void handle_dwd_complete();
193+
190194
void handle_network_analysis_progress(int progress);
191195
void handle_network_analysis_complete();
192196

@@ -243,6 +247,7 @@ class AnalysisTool : public QWidget {
243247
void handle_pca_group_list_item_changed();
244248

245249
void update_lda_graph();
250+
void update_dwd_graph();
246251

247252
void update_difference_particles();
248253

@@ -294,10 +299,13 @@ class AnalysisTool : public QWidget {
294299

295300
QSharedPointer<GroupPvalueJob> group_pvalue_job_;
296301
QSharedPointer<StatsGroupLDAJob> group_lda_job_;
302+
QSharedPointer<StatsGroupDWDJob> group_dwd_job_;
297303
QSharedPointer<NetworkAnalysisJob> network_analysis_job_;
298304

299305
bool group_lda_job_running_ = false;
300306
bool lda_computed_ = false;
307+
bool group_dwd_job_running_ = false;
308+
bool dwd_computed_ = false;
301309
bool block_group_change_ = false;
302310

303311
ParticleAreaPanel* particle_area_panel_{nullptr};

0 commit comments

Comments
 (0)