@@ -92,14 +92,18 @@ def lda_loadings(group1_data, group2_data):
9292 return _project_and_pdf (diffVect , group1_data , group2_data , combined_data )
9393
9494
95- def _project_and_pdf (diffVect , group1_data , group2_data , combined_data ):
95+ def _project_and_pdf (diffVect , group1_data , group2_data , combined_data , normalize_projections = True ):
9696 """Shared logic for projecting groups onto a discriminant direction and fitting PDFs.
9797
9898 Args:
9999 diffVect: Discriminant direction vector (features,)
100100 group1_data: PCA loadings for group 1 (features x samples)
101101 group2_data: PCA loadings for group 2 (features x samples)
102102 combined_data: Concatenation of group1_data and group2_data (features x all_samples)
103+ normalize_projections: If True, normalize so group means map to -1 and +1.
104+ This works well when diffVect is aligned with the mean difference (e.g. LDA).
105+ Set to False for directions that may not be aligned with the mean difference
106+ (e.g. DWD), which would cause the normalization to produce extreme values.
103107
104108 Returns: 6-tuple (group1_x, group2_x, group1_pdf, group2_pdf, group1_map, group2_map)
105109 """
@@ -122,12 +126,14 @@ def _project_and_pdf(diffVect, group1_data, group2_data, combined_data):
122126 for ii in range (group1_num ):
123127 subjDiff = group1_data [:, ii ] - overall_mean
124128 group1_map [ii ] = np .dot (diffVect , subjDiff )
125- group1_map [ii ] = normalize (group1_map [ii ], group1_mean_map , group2_mean_map )
129+ if normalize_projections :
130+ group1_map [ii ] = normalize (group1_map [ii ], group1_mean_map , group2_mean_map )
126131
127132 for ii in range (group2_num ):
128133 subjDiff = group2_data [:, ii ] - overall_mean
129134 group2_map [ii ] = np .dot (diffVect , subjDiff )
130- group2_map [ii ] = normalize (group2_map [ii ], group1_mean_map , group2_mean_map )
135+ if normalize_projections :
136+ group2_map [ii ] = normalize (group2_map [ii ], group1_mean_map , group2_mean_map )
131137
132138 group1_map_mean = group1_map .mean ()
133139 group2_map_mean = group2_map .mean ()
@@ -142,8 +148,20 @@ def _project_and_pdf(diffVect, group1_data, group2_data, combined_data):
142148 if group2_map_std < min_std :
143149 group2_map_std = min_std
144150
145- group1_x = np .linspace (group1_map_mean - 6 , group1_map_mean + 6 , num = 300 )
146- group2_x = np .linspace (group2_map_mean - 6 , group2_map_mean + 6 , num = 300 )
151+ if normalize_projections :
152+ group1_x = np .linspace (group1_map_mean - 6 , group1_map_mean + 6 , num = 300 )
153+ group2_x = np .linspace (group2_map_mean - 6 , group2_map_mean + 6 , num = 300 )
154+ else :
155+ # Common x-range covering both groups and all shape mappings so PDF
156+ # tails extend smoothly across the full plot
157+ all_maps = np .concatenate ([group1_map , group2_map ])
158+ max_std = max (group1_map_std , group2_map_std )
159+ x_min = min (all_maps .min (), group1_map_mean - 6 * group1_map_std ,
160+ group2_map_mean - 6 * group2_map_std ) - max_std
161+ x_max = max (all_maps .max (), group1_map_mean + 6 * group1_map_std ,
162+ group2_map_mean + 6 * group2_map_std ) + max_std
163+ group1_x = np .linspace (x_min , x_max , num = 300 )
164+ group2_x = np .linspace (x_min , x_max , num = 300 )
147165
148166 group1_pdf = stats .norm .pdf (group1_x , group1_map_mean , group1_map_std )
149167 group2_pdf = stats .norm .pdf (group2_x , group2_map_mean , group2_map_std )
@@ -172,7 +190,17 @@ def dwd_loadings(group1_data, group2_data):
172190
173191 diffVect = model .coef_ .flatten ()
174192
175- return _project_and_pdf (diffVect , group1_data , group2_data , combined_data )
193+ # Normalize to unit length so projections reflect data geometry, not solver scale
194+ norm = np .linalg .norm (diffVect )
195+ if norm > 1e-12 :
196+ diffVect = diffVect / norm
197+
198+ # DWD's direction optimizes for margin, not mean separation, so it may be
199+ # nearly orthogonal to the mean difference. The mean-based normalization in
200+ # _project_and_pdf divides by the projection of the mean difference onto
201+ # diffVect, which can be near-zero, producing extreme values.
202+ # Use raw projections with adaptive PDF ranges instead.
203+ return _project_and_pdf (diffVect , group1_data , group2_data , combined_data , normalize_projections = False )
176204
177205
178206def lda (data ):
0 commit comments