Skip to content

Commit 8d9c5a8

Browse files
committed
Fix DWD group analysis producing broken plots
DWD's discriminant direction optimizes for margin, not mean separation, so it can be nearly orthogonal to the group mean difference. The normalize() function divides by the mean projection gap, which blows up when that gap is near-zero, producing shape mappings at ±100 and flat PDFs. Skip mean-based normalization for DWD and use raw unit-length projections with adaptive PDF x-ranges instead. Also scale scatter dot y-position to peak PDF height so dots stay at the plot bottom regardless of PDF scale.
1 parent 7afb08c commit 8d9c5a8

3 files changed

Lines changed: 44 additions & 8 deletions

File tree

Python/shapeworks/shapeworks/stats.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,18 @@ def lda_loadings(group1_data, group2_data):
9292
return _project_and_pdf(diffVect, group1_data, group2_data, combined_data)
9393

9494

95-
def _project_and_pdf(diffVect, group1_data, group2_data, combined_data):
95+
def _project_and_pdf(diffVect, group1_data, group2_data, combined_data, normalize_projections=True):
9696
"""Shared logic for projecting groups onto a discriminant direction and fitting PDFs.
9797
9898
Args:
9999
diffVect: Discriminant direction vector (features,)
100100
group1_data: PCA loadings for group 1 (features x samples)
101101
group2_data: PCA loadings for group 2 (features x samples)
102102
combined_data: Concatenation of group1_data and group2_data (features x all_samples)
103+
normalize_projections: If True, normalize so group means map to -1 and +1.
104+
This works well when diffVect is aligned with the mean difference (e.g. LDA).
105+
Set to False for directions that may not be aligned with the mean difference
106+
(e.g. DWD), which would cause the normalization to produce extreme values.
103107
104108
Returns: 6-tuple (group1_x, group2_x, group1_pdf, group2_pdf, group1_map, group2_map)
105109
"""
@@ -122,12 +126,14 @@ def _project_and_pdf(diffVect, group1_data, group2_data, combined_data):
122126
for ii in range(group1_num):
123127
subjDiff = group1_data[:, ii] - overall_mean
124128
group1_map[ii] = np.dot(diffVect, subjDiff)
125-
group1_map[ii] = normalize(group1_map[ii], group1_mean_map, group2_mean_map)
129+
if normalize_projections:
130+
group1_map[ii] = normalize(group1_map[ii], group1_mean_map, group2_mean_map)
126131

127132
for ii in range(group2_num):
128133
subjDiff = group2_data[:, ii] - overall_mean
129134
group2_map[ii] = np.dot(diffVect, subjDiff)
130-
group2_map[ii] = normalize(group2_map[ii], group1_mean_map, group2_mean_map)
135+
if normalize_projections:
136+
group2_map[ii] = normalize(group2_map[ii], group1_mean_map, group2_mean_map)
131137

132138
group1_map_mean = group1_map.mean()
133139
group2_map_mean = group2_map.mean()
@@ -142,8 +148,20 @@ def _project_and_pdf(diffVect, group1_data, group2_data, combined_data):
142148
if group2_map_std < min_std:
143149
group2_map_std = min_std
144150

145-
group1_x = np.linspace(group1_map_mean - 6, group1_map_mean + 6, num=300)
146-
group2_x = np.linspace(group2_map_mean - 6, group2_map_mean + 6, num=300)
151+
if normalize_projections:
152+
group1_x = np.linspace(group1_map_mean - 6, group1_map_mean + 6, num=300)
153+
group2_x = np.linspace(group2_map_mean - 6, group2_map_mean + 6, num=300)
154+
else:
155+
# Common x-range covering both groups and all shape mappings so PDF
156+
# tails extend smoothly across the full plot
157+
all_maps = np.concatenate([group1_map, group2_map])
158+
max_std = max(group1_map_std, group2_map_std)
159+
x_min = min(all_maps.min(), group1_map_mean - 6 * group1_map_std,
160+
group2_map_mean - 6 * group2_map_std) - max_std
161+
x_max = max(all_maps.max(), group1_map_mean + 6 * group1_map_std,
162+
group2_map_mean + 6 * group2_map_std) + max_std
163+
group1_x = np.linspace(x_min, x_max, num=300)
164+
group2_x = np.linspace(x_min, x_max, num=300)
147165

148166
group1_pdf = stats.norm.pdf(group1_x, group1_map_mean, group1_map_std)
149167
group2_pdf = stats.norm.pdf(group2_x, group2_map_mean, group2_map_std)
@@ -172,7 +190,17 @@ def dwd_loadings(group1_data, group2_data):
172190

173191
diffVect = model.coef_.flatten()
174192

175-
return _project_and_pdf(diffVect, group1_data, group2_data, combined_data)
193+
# Normalize to unit length so projections reflect data geometry, not solver scale
194+
norm = np.linalg.norm(diffVect)
195+
if norm > 1e-12:
196+
diffVect = diffVect / norm
197+
198+
# DWD's direction optimizes for margin, not mean separation, so it may be
199+
# nearly orthogonal to the mean difference. The mean-based normalization in
200+
# _project_and_pdf divides by the projection of the mean difference onto
201+
# diffVect, which can be near-zero, producing extreme values.
202+
# Use raw projections with adaptive PDF ranges instead.
203+
return _project_and_pdf(diffVect, group1_data, group2_data, combined_data, normalize_projections=False)
176204

177205

178206
def lda(data):

Studio/Job/StatsGroupDWDJob.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,15 @@ void StatsGroupDWDJob::plot(JKQTPlotter* plot, QString group_1_name, QString gro
112112
draw_line_plot(group1_x_, group1_pdf_, group_1_name, QColor(239, 133, 54));
113113
draw_line_plot(group2_x_, group2_pdf_, group_2_name, Qt::blue);
114114

115+
// Place shape mapping dots near the bottom of the plot, scaled to peak PDF height
116+
double peak_pdf = std::max(group1_pdf_.maxCoeff(), group2_pdf_.maxCoeff());
117+
double scatter_y = peak_pdf * 0.03;
118+
115119
auto draw_scatter_plot = [&](Eigen::MatrixXd map, QString name, QColor color) {
116120
QVector<double> x, y;
117121
for (int i = 0; i < map.size(); i++) {
118122
x << map(i);
119-
y << 0.01;
123+
y << scatter_y;
120124
}
121125

122126
int column_x = ds->addCopiedColumn(x, name + "scatter x");

Studio/Job/StatsGroupLDAJob.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,15 @@ void StatsGroupLDAJob::plot(JKQTPlotter* plot, QString group_1_name, QString gro
111111
draw_line_plot(group1_x_, group1_pdf_, group_1_name, QColor(239, 133, 54));
112112
draw_line_plot(group2_x_, group2_pdf_, group_2_name, Qt::blue);
113113

114+
// Place shape mapping dots near the bottom of the plot, scaled to peak PDF height
115+
double peak_pdf = std::max(group1_pdf_.maxCoeff(), group2_pdf_.maxCoeff());
116+
double scatter_y = peak_pdf * 0.03;
117+
114118
auto draw_scatter_plot = [&](Eigen::MatrixXd map, QString name, QColor color) {
115119
QVector<double> x, y;
116120
for (int i = 0; i < map.size(); i++) {
117121
x << map(i);
118-
y << 0.01;
122+
y << scatter_y;
119123
}
120124

121125
int column_x = ds->addCopiedColumn(x, name + "scatter x");

0 commit comments

Comments
 (0)