Skip to content

Commit 48f9c6a

Browse files
authored
Merge pull request #2528 from SCIInstitute/amorris/fix-dwd
Fix DWD group analysis producing broken plots
2 parents 7afb08c + 8d9c5a8 commit 48f9c6a

3 files changed

Lines changed: 44 additions & 8 deletions

File tree

Python/shapeworks/shapeworks/stats.py

Lines changed: 34 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -92,14 +92,18 @@ def lda_loadings(group1_data, group2_data):
9292
return _project_and_pdf(diffVect, group1_data, group2_data, combined_data)
9393

9494

95-
def _project_and_pdf(diffVect, group1_data, group2_data, combined_data):
95+
def _project_and_pdf(diffVect, group1_data, group2_data, combined_data, normalize_projections=True):
9696
"""Shared logic for projecting groups onto a discriminant direction and fitting PDFs.
9797
9898
Args:
9999
diffVect: Discriminant direction vector (features,)
100100
group1_data: PCA loadings for group 1 (features x samples)
101101
group2_data: PCA loadings for group 2 (features x samples)
102102
combined_data: Concatenation of group1_data and group2_data (features x all_samples)
103+
normalize_projections: If True, normalize so group means map to -1 and +1.
104+
This works well when diffVect is aligned with the mean difference (e.g. LDA).
105+
Set to False for directions that may not be aligned with the mean difference
106+
(e.g. DWD), which would cause the normalization to produce extreme values.
103107
104108
Returns: 6-tuple (group1_x, group2_x, group1_pdf, group2_pdf, group1_map, group2_map)
105109
"""
@@ -122,12 +126,14 @@ def _project_and_pdf(diffVect, group1_data, group2_data, combined_data):
122126
for ii in range(group1_num):
123127
subjDiff = group1_data[:, ii] - overall_mean
124128
group1_map[ii] = np.dot(diffVect, subjDiff)
125-
group1_map[ii] = normalize(group1_map[ii], group1_mean_map, group2_mean_map)
129+
if normalize_projections:
130+
group1_map[ii] = normalize(group1_map[ii], group1_mean_map, group2_mean_map)
126131

127132
for ii in range(group2_num):
128133
subjDiff = group2_data[:, ii] - overall_mean
129134
group2_map[ii] = np.dot(diffVect, subjDiff)
130-
group2_map[ii] = normalize(group2_map[ii], group1_mean_map, group2_mean_map)
135+
if normalize_projections:
136+
group2_map[ii] = normalize(group2_map[ii], group1_mean_map, group2_mean_map)
131137

132138
group1_map_mean = group1_map.mean()
133139
group2_map_mean = group2_map.mean()
@@ -142,8 +148,20 @@ def _project_and_pdf(diffVect, group1_data, group2_data, combined_data):
142148
if group2_map_std < min_std:
143149
group2_map_std = min_std
144150

145-
group1_x = np.linspace(group1_map_mean - 6, group1_map_mean + 6, num=300)
146-
group2_x = np.linspace(group2_map_mean - 6, group2_map_mean + 6, num=300)
151+
if normalize_projections:
152+
group1_x = np.linspace(group1_map_mean - 6, group1_map_mean + 6, num=300)
153+
group2_x = np.linspace(group2_map_mean - 6, group2_map_mean + 6, num=300)
154+
else:
155+
# Common x-range covering both groups and all shape mappings so PDF
156+
# tails extend smoothly across the full plot
157+
all_maps = np.concatenate([group1_map, group2_map])
158+
max_std = max(group1_map_std, group2_map_std)
159+
x_min = min(all_maps.min(), group1_map_mean - 6 * group1_map_std,
160+
group2_map_mean - 6 * group2_map_std) - max_std
161+
x_max = max(all_maps.max(), group1_map_mean + 6 * group1_map_std,
162+
group2_map_mean + 6 * group2_map_std) + max_std
163+
group1_x = np.linspace(x_min, x_max, num=300)
164+
group2_x = np.linspace(x_min, x_max, num=300)
147165

148166
group1_pdf = stats.norm.pdf(group1_x, group1_map_mean, group1_map_std)
149167
group2_pdf = stats.norm.pdf(group2_x, group2_map_mean, group2_map_std)
@@ -172,7 +190,17 @@ def dwd_loadings(group1_data, group2_data):
172190

173191
diffVect = model.coef_.flatten()
174192

175-
return _project_and_pdf(diffVect, group1_data, group2_data, combined_data)
193+
# Normalize to unit length so projections reflect data geometry, not solver scale
194+
norm = np.linalg.norm(diffVect)
195+
if norm > 1e-12:
196+
diffVect = diffVect / norm
197+
198+
# DWD's direction optimizes for margin, not mean separation, so it may be
199+
# nearly orthogonal to the mean difference. The mean-based normalization in
200+
# _project_and_pdf divides by the projection of the mean difference onto
201+
# diffVect, which can be near-zero, producing extreme values.
202+
# Use raw projections with adaptive PDF ranges instead.
203+
return _project_and_pdf(diffVect, group1_data, group2_data, combined_data, normalize_projections=False)
176204

177205

178206
def lda(data):

Studio/Job/StatsGroupDWDJob.cpp

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -112,11 +112,15 @@ void StatsGroupDWDJob::plot(JKQTPlotter* plot, QString group_1_name, QString gro
112112
draw_line_plot(group1_x_, group1_pdf_, group_1_name, QColor(239, 133, 54));
113113
draw_line_plot(group2_x_, group2_pdf_, group_2_name, Qt::blue);
114114

115+
// Place shape mapping dots near the bottom of the plot, scaled to peak PDF height
116+
double peak_pdf = std::max(group1_pdf_.maxCoeff(), group2_pdf_.maxCoeff());
117+
double scatter_y = peak_pdf * 0.03;
118+
115119
auto draw_scatter_plot = [&](Eigen::MatrixXd map, QString name, QColor color) {
116120
QVector<double> x, y;
117121
for (int i = 0; i < map.size(); i++) {
118122
x << map(i);
119-
y << 0.01;
123+
y << scatter_y;
120124
}
121125

122126
int column_x = ds->addCopiedColumn(x, name + "scatter x");

Studio/Job/StatsGroupLDAJob.cpp

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -111,11 +111,15 @@ void StatsGroupLDAJob::plot(JKQTPlotter* plot, QString group_1_name, QString gro
111111
draw_line_plot(group1_x_, group1_pdf_, group_1_name, QColor(239, 133, 54));
112112
draw_line_plot(group2_x_, group2_pdf_, group_2_name, Qt::blue);
113113

114+
// Place shape mapping dots near the bottom of the plot, scaled to peak PDF height
115+
double peak_pdf = std::max(group1_pdf_.maxCoeff(), group2_pdf_.maxCoeff());
116+
double scatter_y = peak_pdf * 0.03;
117+
114118
auto draw_scatter_plot = [&](Eigen::MatrixXd map, QString name, QColor color) {
115119
QVector<double> x, y;
116120
for (int i = 0; i < map.size(); i++) {
117121
x << map(i);
118-
y << 0.01;
122+
y << scatter_y;
119123
}
120124

121125
int column_x = ds->addCopiedColumn(x, name + "scatter x");

0 commit comments

Comments
 (0)