Skip to content

Commit b71ce2e

Browse files
committed
Vectorize anndata metric bottlenecks
1 parent 965c1bf commit b71ce2e

2 files changed

Lines changed: 264 additions & 45 deletions

File tree

src/cell_eval/metrics/_anndata.py

Lines changed: 152 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,20 @@
2020

2121
logger = getLogger(__name__)
2222

23+
_L1_METRICS = {"l1", "manhattan", "cityblock"}
24+
_L2_METRICS = {"l2", "euclidean"}
25+
2326

2427
def pearson_delta(
2528
data: PerturbationAnndataPair, embed_key: str | None = None
2629
) -> dict[str, float]:
2730
"""Compute Pearson correlation between mean differences from control."""
28-
return _generic_evaluation(
29-
data,
30-
pearsonr,
31-
use_delta=True,
32-
embed_key=embed_key,
33-
)
31+
real_effects, pred_effects = _bulk_effect_matrices(data, embed_key=embed_key)
32+
correlations = _rowwise_pearson(pred_effects, real_effects)
33+
return {
34+
str(pert): float(correlation)
35+
for pert, correlation in zip(data.perts, correlations)
36+
}
3437

3538

3639
def mse(
@@ -149,53 +152,157 @@ def discrimination_score(
149152
# Ignore the embedding key for L1
150153
embed_key = None
151154

152-
# Compute perturbation effects for all perturbations
153-
real_effects = np.vstack(
154-
[
155-
d.perturbation_effect(which="real", abs=False)
156-
for d in data.iter_bulk_arrays(embed_key=embed_key)
157-
]
155+
real_effects, pred_effects = _bulk_effect_matrices(data, embed_key=embed_key)
156+
excluded_indices = _excluded_gene_indices(
157+
data,
158+
embed_key=embed_key,
159+
exclude_target_gene=exclude_target_gene,
158160
)
159-
pred_effects = np.vstack(
160-
[
161-
d.perturbation_effect(which="pred", abs=False)
162-
for d in data.iter_bulk_arrays(embed_key=embed_key)
163-
]
161+
distances = _pairwise_distances_with_exclusions(
162+
pred_effects=pred_effects,
163+
real_effects=real_effects,
164+
metric=metric,
165+
excluded_indices=excluded_indices,
164166
)
167+
order = np.argsort(distances, axis=1)
168+
ranks = np.argmax(order == np.arange(data.perts.size)[:, None], axis=1)
165169

166-
norm_ranks = {}
167-
for p_idx, p in enumerate(data.perts):
168-
# Determine which features to include in the comparison
169-
if exclude_target_gene and not embed_key:
170-
# For expression data, exclude the target gene
171-
include_mask = np.flatnonzero(data.genes != p)
172-
else:
173-
# For embedding data or when not excluding target gene, use all features
174-
include_mask = np.ones(real_effects.shape[1], dtype=bool)
170+
return {
171+
str(pert): 1 - float(rank) / data.perts.size
172+
for pert, rank in zip(data.perts, ranks)
173+
}
175174

176-
# Compute distances to all real effects
177-
distances = skm.pairwise_distances(
178-
real_effects[
179-
:, include_mask
180-
], # compare to all real effects across perturbations
181-
pred_effects[p_idx, include_mask].reshape(
182-
1, -1
183-
), # select pred effect for current perturbation
184-
metric=metric,
185-
).flatten()
186175

187-
# Sort by distance (ascending - lower distance = better match)
188-
sorted_indices = np.argsort(distances)
176+
def _bulk_effect_matrices(
177+
data: PerturbationAnndataPair,
178+
embed_key: str | None = None,
179+
) -> tuple[np.ndarray, np.ndarray]:
180+
"""Return real/pred perturbation-control effects in data.perts order."""
181+
data._initialize_bulk_arrays(embed_key)
182+
cache_key = embed_key or "_default"
183+
assert data.bulk_real is not None
184+
assert data.bulk_pred is not None
185+
keys, real_bulk = data.bulk_real[cache_key]
186+
_, pred_bulk = data.bulk_pred[cache_key]
187+
positions = {str(key): idx for idx, key in enumerate(keys)}
188+
pert_positions = np.array([positions[str(pert)] for pert in data.perts])
189+
ctrl_position = positions[str(data.control_pert)]
190+
real_effects = real_bulk[pert_positions] - real_bulk[ctrl_position]
191+
pred_effects = pred_bulk[pert_positions] - pred_bulk[ctrl_position]
192+
return np.asarray(real_effects), np.asarray(pred_effects)
193+
194+
195+
def _rowwise_pearson(x: np.ndarray, y: np.ndarray) -> np.ndarray:
196+
x = np.asarray(x, dtype=np.float64)
197+
y = np.asarray(y, dtype=np.float64)
198+
x_centered = x - x.mean(axis=1, keepdims=True)
199+
y_centered = y - y.mean(axis=1, keepdims=True)
200+
numerator = np.sum(x_centered * y_centered, axis=1)
201+
denominator = np.sqrt(
202+
np.sum(x_centered * x_centered, axis=1)
203+
* np.sum(y_centered * y_centered, axis=1)
204+
)
205+
correlations = np.full(x.shape[0], np.nan, dtype=np.float64)
206+
np.divide(numerator, denominator, out=correlations, where=denominator > 0)
207+
return correlations
189208

190-
# Find rank of the correct perturbation
191-
p_index = np.flatnonzero(data.perts == p)[0]
192-
rank = np.flatnonzero(sorted_indices == p_index)[0]
193209

194-
# Normalize rank by total number of perturbations
195-
norm_rank = rank / data.perts.size
196-
norm_ranks[str(p)] = 1 - norm_rank
210+
def _excluded_gene_indices(
211+
data: PerturbationAnndataPair,
212+
embed_key: str | None,
213+
exclude_target_gene: bool,
214+
) -> list[np.ndarray]:
215+
if embed_key or not exclude_target_gene:
216+
return [np.array([], dtype=np.int64) for _ in data.perts]
217+
return [np.flatnonzero(data.genes == pert) for pert in data.perts]
218+
219+
220+
def _pairwise_distances_with_exclusions(
221+
pred_effects: np.ndarray,
222+
real_effects: np.ndarray,
223+
metric: str,
224+
excluded_indices: list[np.ndarray],
225+
) -> np.ndarray:
226+
pred_effects = np.asarray(pred_effects, dtype=np.float64)
227+
real_effects = np.asarray(real_effects, dtype=np.float64)
228+
has_exclusions = any(indices.size > 0 for indices in excluded_indices)
229+
230+
if metric in _L1_METRICS:
231+
distances = skm.pairwise_distances(
232+
pred_effects, real_effects, metric="manhattan"
233+
)
234+
for idx, excluded in enumerate(excluded_indices):
235+
if excluded.size:
236+
distances[idx] -= np.abs(
237+
real_effects[:, excluded] - pred_effects[idx, excluded]
238+
).sum(axis=1)
239+
np.maximum(distances, 0, out=distances)
240+
return distances
241+
242+
if metric in _L2_METRICS:
243+
pred_sq = np.sum(pred_effects * pred_effects, axis=1)
244+
real_sq = np.sum(real_effects * real_effects, axis=1)
245+
distances_sq = (
246+
pred_sq[:, None] + real_sq[None, :] - 2 * (pred_effects @ real_effects.T)
247+
)
248+
for idx, excluded in enumerate(excluded_indices):
249+
if excluded.size:
250+
excluded_delta = real_effects[:, excluded] - pred_effects[idx, excluded]
251+
distances_sq[idx] -= np.sum(excluded_delta * excluded_delta, axis=1)
252+
np.maximum(distances_sq, 0, out=distances_sq)
253+
return np.sqrt(distances_sq)
254+
255+
if metric == "cosine":
256+
return _cosine_distances_with_exclusions(
257+
pred_effects=pred_effects,
258+
real_effects=real_effects,
259+
excluded_indices=excluded_indices,
260+
)
261+
262+
if not has_exclusions:
263+
return skm.pairwise_distances(pred_effects, real_effects, metric=metric)
264+
265+
distances = np.empty((pred_effects.shape[0], real_effects.shape[0]))
266+
for idx, excluded in enumerate(excluded_indices):
267+
include_mask = np.ones(real_effects.shape[1], dtype=bool)
268+
include_mask[excluded] = False
269+
distances[idx] = skm.pairwise_distances(
270+
pred_effects[idx, include_mask].reshape(1, -1),
271+
real_effects[:, include_mask],
272+
metric=metric,
273+
).ravel()
274+
return distances
275+
276+
277+
def _cosine_distances_with_exclusions(
278+
pred_effects: np.ndarray,
279+
real_effects: np.ndarray,
280+
excluded_indices: list[np.ndarray],
281+
) -> np.ndarray:
282+
dot = pred_effects @ real_effects.T
283+
pred_sq = np.sum(pred_effects * pred_effects, axis=1)
284+
real_sq = np.sum(real_effects * real_effects, axis=1)
285+
distances = np.empty_like(dot)
286+
287+
for idx, excluded in enumerate(excluded_indices):
288+
row_dot = dot[idx].copy()
289+
row_pred_sq = pred_sq[idx]
290+
row_real_sq = real_sq.copy()
291+
if excluded.size:
292+
row_dot -= (real_effects[:, excluded] * pred_effects[idx, excluded]).sum(
293+
axis=1
294+
)
295+
row_pred_sq -= float(np.sum(pred_effects[idx, excluded] ** 2))
296+
row_real_sq -= np.sum(real_effects[:, excluded] ** 2, axis=1)
297+
denominator = np.sqrt(max(row_pred_sq, 0.0)) * np.sqrt(
298+
np.maximum(row_real_sq, 0.0)
299+
)
300+
similarity = np.zeros_like(row_dot)
301+
np.divide(row_dot, denominator, out=similarity, where=denominator > 0)
302+
distances[idx] = 1 - similarity
197303

198-
return norm_ranks
304+
np.clip(distances, 0, 2, out=distances)
305+
return distances
199306

200307

201308
def _generic_evaluation(

tests/test_anndata_metrics.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import numpy as np
2+
import sklearn.metrics as skm
3+
from scipy.stats import pearsonr
4+
5+
from cell_eval._types import PerturbationAnndataPair
6+
from cell_eval.data import CONTROL_VAR, PERT_COL, build_random_anndata
7+
from cell_eval.metrics._anndata import discrimination_score, pearson_delta
8+
9+
10+
def _metric_pair() -> PerturbationAnndataPair:
11+
real = build_random_anndata(
12+
n_cells=500,
13+
n_genes=12,
14+
n_perts=5,
15+
n_celltypes=1,
16+
random_state=17,
17+
)
18+
pert_names = [f"pert_{idx}" for idx in range(5)]
19+
real.var_names = pert_names + [f"gene_{idx}" for idx in range(7)]
20+
labels = np.resize(np.array([CONTROL_VAR, *pert_names]), real.n_obs)
21+
real.obs[PERT_COL] = labels
22+
23+
pred = real.copy()
24+
rng = np.random.default_rng(23)
25+
pred.X = np.clip(
26+
np.asarray(pred.X) + rng.normal(0, 0.01, size=pred.X.shape), 0, None
27+
)
28+
29+
return PerturbationAnndataPair(
30+
real=real,
31+
pred=pred,
32+
control_pert=CONTROL_VAR,
33+
pert_col=PERT_COL,
34+
)
35+
36+
37+
def _reference_pearson_delta(data: PerturbationAnndataPair) -> dict[str, float]:
38+
res = {}
39+
for bulk_array in data.iter_bulk_arrays():
40+
x = bulk_array.perturbation_effect(which="pred", abs=False)
41+
y = bulk_array.perturbation_effect(which="real", abs=False)
42+
res[bulk_array.key] = float(pearsonr(x, y).correlation)
43+
return res
44+
45+
46+
def _reference_discrimination_score(
47+
data: PerturbationAnndataPair,
48+
metric: str,
49+
exclude_target_gene: bool = True,
50+
) -> dict[str, float]:
51+
real_effects = np.vstack(
52+
[
53+
d.perturbation_effect(which="real", abs=False)
54+
for d in data.iter_bulk_arrays()
55+
]
56+
)
57+
pred_effects = np.vstack(
58+
[
59+
d.perturbation_effect(which="pred", abs=False)
60+
for d in data.iter_bulk_arrays()
61+
]
62+
)
63+
64+
norm_ranks = {}
65+
for p_idx, pert in enumerate(data.perts):
66+
if exclude_target_gene:
67+
include_mask = np.flatnonzero(data.genes != pert)
68+
else:
69+
include_mask = np.ones(real_effects.shape[1], dtype=bool)
70+
distances = skm.pairwise_distances(
71+
real_effects[:, include_mask],
72+
pred_effects[p_idx, include_mask].reshape(1, -1),
73+
metric=metric,
74+
).flatten()
75+
sorted_indices = np.argsort(distances)
76+
pert_index = np.flatnonzero(data.perts == pert)[0]
77+
rank = np.flatnonzero(sorted_indices == pert_index)[0]
78+
norm_ranks[str(pert)] = 1 - rank / data.perts.size
79+
return norm_ranks
80+
81+
82+
def test_pearson_delta_matches_reference() -> None:
83+
data = _metric_pair()
84+
85+
expected = _reference_pearson_delta(data)
86+
actual = pearson_delta(data)
87+
88+
assert actual.keys() == expected.keys()
89+
np.testing.assert_allclose(
90+
list(actual.values()),
91+
list(expected.values()),
92+
rtol=1e-12,
93+
atol=1e-12,
94+
equal_nan=True,
95+
)
96+
97+
98+
def test_discrimination_score_matches_reference() -> None:
99+
data = _metric_pair()
100+
101+
for metric in ["l1", "l2", "cosine"]:
102+
expected = _reference_discrimination_score(data, metric=metric)
103+
actual = discrimination_score(data, metric=metric)
104+
105+
assert actual.keys() == expected.keys()
106+
np.testing.assert_allclose(
107+
list(actual.values()),
108+
list(expected.values()),
109+
rtol=1e-12,
110+
atol=1e-12,
111+
equal_nan=True,
112+
)

0 commit comments

Comments
 (0)