|
20 | 20 |
|
21 | 21 | logger = getLogger(__name__) |
22 | 22 |
|
# Metric names that all denote the Manhattan (L1) distance. Frozen so these
# shared module-level lookup tables cannot be mutated by accident at runtime.
_L1_METRICS = frozenset({"l1", "manhattan", "cityblock"})
# Metric names that all denote the Euclidean (L2) distance.
_L2_METRICS = frozenset({"l2", "euclidean"})
| 25 | + |
23 | 26 |
|
def pearson_delta(
    data: PerturbationAnndataPair, embed_key: str | None = None
) -> dict[str, float]:
    """Compute Pearson correlation between mean differences from control.

    The real and predicted effect matrices are obtained from
    ``_bulk_effect_matrices`` and correlated row by row, giving one score per
    entry of ``data.perts``.

    Args:
        data: Paired real/predicted perturbation data.
        embed_key: Forwarded unchanged to ``_bulk_effect_matrices``.

    Returns:
        Mapping from ``str(perturbation)`` to its correlation as a Python float.
    """
    observed, predicted = _bulk_effect_matrices(data, embed_key=embed_key)
    per_pert_r = _rowwise_pearson(predicted, observed)

    scores: dict[str, float] = {}
    for name, r_value in zip(data.perts, per_pert_r):
        scores[str(name)] = float(r_value)
    return scores
34 | 37 |
|
35 | 38 |
|
36 | 39 | def mse( |
@@ -149,53 +152,157 @@ def discrimination_score( |
149 | 152 | # Ignore the embedding key for L1 |
150 | 153 | embed_key = None |
151 | 154 |
|
152 | | - # Compute perturbation effects for all perturbations |
153 | | - real_effects = np.vstack( |
154 | | - [ |
155 | | - d.perturbation_effect(which="real", abs=False) |
156 | | - for d in data.iter_bulk_arrays(embed_key=embed_key) |
157 | | - ] |
| 155 | + real_effects, pred_effects = _bulk_effect_matrices(data, embed_key=embed_key) |
| 156 | + excluded_indices = _excluded_gene_indices( |
| 157 | + data, |
| 158 | + embed_key=embed_key, |
| 159 | + exclude_target_gene=exclude_target_gene, |
158 | 160 | ) |
159 | | - pred_effects = np.vstack( |
160 | | - [ |
161 | | - d.perturbation_effect(which="pred", abs=False) |
162 | | - for d in data.iter_bulk_arrays(embed_key=embed_key) |
163 | | - ] |
| 161 | + distances = _pairwise_distances_with_exclusions( |
| 162 | + pred_effects=pred_effects, |
| 163 | + real_effects=real_effects, |
| 164 | + metric=metric, |
| 165 | + excluded_indices=excluded_indices, |
164 | 166 | ) |
| 167 | + order = np.argsort(distances, axis=1) |
| 168 | + ranks = np.argmax(order == np.arange(data.perts.size)[:, None], axis=1) |
165 | 169 |
|
166 | | - norm_ranks = {} |
167 | | - for p_idx, p in enumerate(data.perts): |
168 | | - # Determine which features to include in the comparison |
169 | | - if exclude_target_gene and not embed_key: |
170 | | - # For expression data, exclude the target gene |
171 | | - include_mask = np.flatnonzero(data.genes != p) |
172 | | - else: |
173 | | - # For embedding data or when not excluding target gene, use all features |
174 | | - include_mask = np.ones(real_effects.shape[1], dtype=bool) |
| 170 | + return { |
| 171 | + str(pert): 1 - float(rank) / data.perts.size |
| 172 | + for pert, rank in zip(data.perts, ranks) |
| 173 | + } |
175 | 174 |
|
176 | | - # Compute distances to all real effects |
177 | | - distances = skm.pairwise_distances( |
178 | | - real_effects[ |
179 | | - :, include_mask |
180 | | - ], # compare to all real effects across perturbations |
181 | | - pred_effects[p_idx, include_mask].reshape( |
182 | | - 1, -1 |
183 | | - ), # select pred effect for current perturbation |
184 | | - metric=metric, |
185 | | - ).flatten() |
186 | 175 |
|
187 | | - # Sort by distance (ascending - lower distance = better match) |
188 | | - sorted_indices = np.argsort(distances) |
def _bulk_effect_matrices(
    data: PerturbationAnndataPair,
    embed_key: str | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Return real/pred perturbation-minus-control effects in data.perts order.

    Args:
        data: Paired perturbation data; its bulk caches are initialised here
            via ``data._initialize_bulk_arrays(embed_key)``.
        embed_key: Selects the cached bulk entry; a falsy value selects the
            ``"_default"`` entry.

    Returns:
        ``(real_effects, pred_effects)``. Row ``i`` of each is the bulk row
        of ``data.perts[i]`` minus the bulk row of ``data.control_pert``.

    Raises:
        KeyError: If a perturbation or the control label is not among the
            cached bulk labels.
    """
    data._initialize_bulk_arrays(embed_key)
    assert data.bulk_real is not None
    assert data.bulk_pred is not None

    slot = embed_key or "_default"
    labels, observed_bulk = data.bulk_real[slot]
    # NOTE(review): the predicted cache is assumed to use the same label order
    # as the real cache; only the real labels are used to build the index.
    _, predicted_bulk = data.bulk_pred[slot]

    # Labels are compared through str() so differing label types still match.
    row_of: dict[str, int] = {}
    for row, label in enumerate(labels):
        row_of[str(label)] = row

    pert_rows = np.array([row_of[str(name)] for name in data.perts])
    control_row = row_of[str(data.control_pert)]

    def _minus_control(bulk):
        # Subtract the control row from every selected perturbation row.
        return np.asarray(bulk[pert_rows] - bulk[control_row])

    return _minus_control(observed_bulk), _minus_control(predicted_bulk)
| 193 | + |
| 194 | + |
def _rowwise_pearson(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Return the Pearson correlation between matching rows of ``x`` and ``y``.

    Both inputs are converted to float64 and the correlation is taken along
    axis 1, so row ``i`` of ``x`` is correlated with row ``i`` of ``y``.

    Args:
        x: 2-D array, one observation vector per row.
        y: 2-D array expected to have the same shape as ``x``.

    Returns:
        A float64 array of length ``x.shape[0]``. Rows where either input has
        zero variance (denominator not greater than zero) are NaN. Defined
        values are clamped to the closed interval [-1, 1].
    """
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    x_centered = x - x.mean(axis=1, keepdims=True)
    y_centered = y - y.mean(axis=1, keepdims=True)
    numerator = np.sum(x_centered * y_centered, axis=1)
    denominator = np.sqrt(
        np.sum(x_centered * x_centered, axis=1)
        * np.sum(y_centered * y_centered, axis=1)
    )
    # Start from NaN so rows skipped by ``where`` stay undefined.
    correlations = np.full(x.shape[0], np.nan, dtype=np.float64)
    np.divide(numerator, denominator, out=correlations, where=denominator > 0)
    # Rounding in the product/sqrt can push |r| a few ulps past 1 for
    # (anti-)collinear rows; clamp so callers always see a valid correlation.
    # np.clip leaves NaN entries untouched.
    np.clip(correlations, -1.0, 1.0, out=correlations)
    return correlations
189 | 208 |
|
190 | | - # Find rank of the correct perturbation |
191 | | - p_index = np.flatnonzero(data.perts == p)[0] |
192 | | - rank = np.flatnonzero(sorted_indices == p_index)[0] |
193 | 209 |
|
194 | | - # Normalize rank by total number of perturbations |
195 | | - norm_rank = rank / data.perts.size |
196 | | - norm_ranks[str(p)] = 1 - norm_rank |
def _excluded_gene_indices(
    data: PerturbationAnndataPair,
    embed_key: str | None,
    exclude_target_gene: bool,
) -> list[np.ndarray]:
    """Return, per perturbation, the feature indices to leave out of distances.

    Args:
        data: Provides ``perts`` (perturbation labels) and ``genes`` (feature
            labels).
        embed_key: When truthy, nothing is excluded.
        exclude_target_gene: When falsy, nothing is excluded.

    Returns:
        One integer index array per entry of ``data.perts``. With exclusion
        active, each array holds the positions where ``data.genes`` equals
        that perturbation's label; otherwise each array is empty.
    """
    exclusion_active = exclude_target_gene and not embed_key
    if exclusion_active:
        return [np.flatnonzero(data.genes == name) for name in data.perts]
    # A fresh empty array per perturbation keeps the return shape uniform.
    return [np.array([], dtype=np.int64) for _ in data.perts]
| 218 | + |
| 219 | + |
def _pairwise_distances_with_exclusions(
    pred_effects: np.ndarray,
    real_effects: np.ndarray,
    metric: str,
    excluded_indices: list[np.ndarray],
) -> np.ndarray:
    """Distances from each predicted effect (rows) to each real effect (columns).

    ``excluded_indices[i]`` holds the feature columns ignored for predicted
    row ``i``. For L1 and L2 metrics the full distance matrix is computed once
    and the excluded columns' contribution is subtracted per row; cosine is
    delegated to ``_cosine_distances_with_exclusions``; any other metric is
    recomputed row by row on the masked features when exclusions exist.

    Args:
        pred_effects: Predicted effects, one row per perturbation.
        real_effects: Real effects, one row per perturbation.
        metric: Metric name; L1/L2 aliases come from ``_L1_METRICS`` and
            ``_L2_METRICS``, other values are passed to
            ``skm.pairwise_distances``.
        excluded_indices: Per-row arrays of column indices to ignore.

    Returns:
        Array of shape ``(n_pred_rows, n_real_rows)``.
    """
    pred = np.asarray(pred_effects, dtype=np.float64)
    real = np.asarray(real_effects, dtype=np.float64)
    any_excluded = any(cols.size > 0 for cols in excluded_indices)

    if metric in _L1_METRICS:
        result = skm.pairwise_distances(pred, real, metric="manhattan")
        for row, cols in enumerate(excluded_indices):
            if not cols.size:
                continue
            # Remove the excluded columns' share of the absolute differences.
            gap = np.abs(real[:, cols] - pred[row, cols])
            result[row] -= gap.sum(axis=1)
        # Subtraction can leave tiny negatives from rounding; floor at zero.
        np.maximum(result, 0, out=result)
        return result

    if metric in _L2_METRICS:
        pred_norms = (pred * pred).sum(axis=1)
        real_norms = (real * real).sum(axis=1)
        cross = pred @ real.T
        # ||p - r||^2 = ||p||^2 + ||r||^2 - 2 p.r
        squared = pred_norms[:, None] + real_norms[None, :] - 2 * cross
        for row, cols in enumerate(excluded_indices):
            if not cols.size:
                continue
            gap = real[:, cols] - pred[row, cols]
            squared[row] -= (gap * gap).sum(axis=1)
        # Guard the square root against small negative rounding residue.
        np.maximum(squared, 0, out=squared)
        return np.sqrt(squared)

    if metric == "cosine":
        return _cosine_distances_with_exclusions(
            pred_effects=pred,
            real_effects=real,
            excluded_indices=excluded_indices,
        )

    if not any_excluded:
        return skm.pairwise_distances(pred, real, metric=metric)

    # Generic fallback: one masked distance computation per predicted row.
    n_features = real.shape[1]
    result = np.empty((pred.shape[0], real.shape[0]))
    for row, cols in enumerate(excluded_indices):
        keep = np.ones(n_features, dtype=bool)
        keep[cols] = False
        row_distances = skm.pairwise_distances(
            pred[row, keep].reshape(1, -1),
            real[:, keep],
            metric=metric,
        )
        result[row] = row_distances.ravel()
    return result
| 275 | + |
| 276 | + |
def _cosine_distances_with_exclusions(
    pred_effects: np.ndarray,
    real_effects: np.ndarray,
    excluded_indices: list[np.ndarray],
) -> np.ndarray:
    """Cosine distance from each predicted row to each real row, minus columns.

    The full dot products and squared norms are computed once; for every
    predicted row the contribution of that row's excluded columns is then
    subtracted from the dot product and from both squared norms.

    Args:
        pred_effects: Predicted effects, shape ``(n_pred, n_features)``.
        real_effects: Real effects, shape ``(n_real, n_features)``.
        excluded_indices: One index array per predicted row naming the
            feature columns to ignore for that row (may be empty).

    Returns:
        Float64 array of shape ``(n_pred, n_real)`` with values clipped to
        [0, 2]. Pairs where either remaining vector has zero norm get
        similarity 0, i.e. distance 1.

    Raises:
        ValueError: If ``excluded_indices`` does not have exactly one entry
            per predicted row.
    """
    # Cast here as well as in the caller: with integer input the in-place
    # divide into an integer ``out`` buffer would raise a casting error.
    pred_effects = np.asarray(pred_effects, dtype=np.float64)
    real_effects = np.asarray(real_effects, dtype=np.float64)
    if len(excluded_indices) != pred_effects.shape[0]:
        # A short list would leave rows of the np.empty buffer uninitialised.
        raise ValueError(
            "excluded_indices must contain one entry per row of pred_effects"
        )

    dot = pred_effects @ real_effects.T
    pred_sq = np.sum(pred_effects * pred_effects, axis=1)
    real_sq = np.sum(real_effects * real_effects, axis=1)
    distances = np.empty_like(dot)

    for idx, excluded in enumerate(excluded_indices):
        row_dot = dot[idx].copy()
        row_pred_sq = pred_sq[idx]
        row_real_sq = real_sq.copy()
        if excluded.size:
            row_dot -= (real_effects[:, excluded] * pred_effects[idx, excluded]).sum(
                axis=1
            )
            row_pred_sq -= float(np.sum(pred_effects[idx, excluded] ** 2))
            row_real_sq -= np.sum(real_effects[:, excluded] ** 2, axis=1)
        # Floor at zero: the subtractions can leave tiny negative residue.
        denominator = np.sqrt(max(row_pred_sq, 0.0)) * np.sqrt(
            np.maximum(row_real_sq, 0.0)
        )
        # Zero-norm pairs keep similarity 0 instead of dividing by zero.
        similarity = np.zeros_like(row_dot)
        np.divide(row_dot, denominator, out=similarity, where=denominator > 0)
        distances[idx] = 1 - similarity

    np.clip(distances, 0, 2, out=distances)
    return distances
199 | 306 |
|
200 | 307 |
|
201 | 308 | def _generic_evaluation( |
|
0 commit comments