Skip to content

Commit 4842022

Browse files
- Applied black formatting.
1 parent ca6f5b1 commit 4842022

33 files changed

Lines changed: 226 additions & 128 deletions

scripts/batch_qc_stats.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,46 +4,75 @@
44
import matplotlib.pyplot as plt
55
from scipy.sparse import csr_matrix
66

7+
78
def calculate_qc_stats(h5ad_path, min_umis=500, min_genes=200, max_mt=0.15):
89
with h5py.File(h5ad_path, "r") as f:
9-
barcodes = [b.decode("utf-8") if isinstance(b, bytes) else b for b in f["obs"]["_index"][:]]
10-
gene_names = [g.decode("utf-8") if isinstance(g, bytes) else g for g in f["var"]["_index"][:]]
10+
barcodes = [
11+
b.decode("utf-8") if isinstance(b, bytes) else b
12+
for b in f["obs"]["_index"][:]
13+
]
14+
gene_names = [
15+
g.decode("utf-8") if isinstance(g, bytes) else g
16+
for g in f["var"]["_index"][:]
17+
]
1118
X_group = f["X"]
1219
if isinstance(X_group, h5py.Group):
13-
X = csr_matrix((X_group["data"][:], X_group["indices"][:], X_group["indptr"][:]), shape=(len(barcodes), len(gene_names)))
20+
X = csr_matrix(
21+
(X_group["data"][:], X_group["indices"][:], X_group["indptr"][:]),
22+
shape=(len(barcodes), len(gene_names)),
23+
)
1424
else:
1525
X = f["X"][:]
1626

1727
n_counts = np.array(X.sum(axis=1)).flatten()
1828
n_genes = np.array((X > 0).sum(axis=1)).flatten()
19-
mt_genes = [i for i, name in enumerate(gene_names) if "mt-" in name.lower() or "mt:" in name.lower()]
29+
mt_genes = [
30+
i
31+
for i, name in enumerate(gene_names)
32+
if "mt-" in name.lower() or "mt:" in name.lower()
33+
]
2034
if mt_genes:
2135
mt_counts = np.array(X[:, mt_genes].sum(axis=1)).flatten()
2236
pct_counts_mt = mt_counts / (n_counts + 1e-9)
2337
else:
2438
pct_counts_mt = np.zeros_like(n_counts)
2539

26-
keep_mask = (n_counts >= min_umis) & (n_genes >= min_genes) & (pct_counts_mt <= max_mt)
27-
return len(barcodes), np.sum(keep_mask), np.sum(n_counts < min_umis), np.sum(n_genes < min_genes), np.sum(pct_counts_mt > max_mt)
40+
keep_mask = (
41+
(n_counts >= min_umis) & (n_genes >= min_genes) & (pct_counts_mt <= max_mt)
42+
)
43+
return (
44+
len(barcodes),
45+
np.sum(keep_mask),
46+
np.sum(n_counts < min_umis),
47+
np.sum(n_genes < min_genes),
48+
np.sum(pct_counts_mt > max_mt),
49+
)
50+
2851

2952
if __name__ == "__main__":
3053
st_dir = r"A:\hest_data\st"
31-
54+
3255
# Get all h5ad files
33-
samples = [f.replace(".h5ad", "") for f in os.listdir(st_dir) if f.endswith(".h5ad")]
56+
samples = [
57+
f.replace(".h5ad", "") for f in os.listdir(st_dir) if f.endswith(".h5ad")
58+
]
3459
samples.sort()
35-
60+
3661
print(f"Analyzing {len(samples)} samples...")
37-
print(f"{'Sample':<15} | {'Total':<6} | {'Kept':<6} | {'% Kept':<8} | {'Low UMI':<8} | {'Low Gene':<8} | {'High MT':<8}")
62+
print(
63+
f"{'Sample':<15} | {'Total':<6} | {'Kept':<6} | {'% Kept':<8} | {'Low UMI':<8} | {'Low Gene':<8} | {'High MT':<8}"
64+
)
3865
print("-" * 80)
39-
66+
4067
all_results = []
4168
for s in samples:
4269
path = os.path.join(st_dir, f"{s}.h5ad")
4370
try:
4471
total, kept, l_umi, l_gene, h_mt = calculate_qc_stats(path)
4572
pct = kept / total if total > 0 else 0
46-
print(f"{s:<15} | {total:<6} | {kept:<6} | {pct:7.1%} | {l_umi:<8} | {l_gene:<8} | {h_mt:<8}")
73+
print(
74+
f"{s:<15} | {total:<6} | {kept:<6} | {pct:7.1%} | {l_umi:<8} | {l_gene:<8} | {h_mt:<8}"
75+
)
4776
all_results.append((total, kept, pct))
4877
except Exception as e:
4978
print(f"Error processing {s}: {e}")

scripts/diagnose_qc.py

Lines changed: 48 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,33 @@
44
import matplotlib.pyplot as plt
55
from scipy.sparse import csr_matrix
66

7+
78
def diagnose_qc(h5ad_path, output_path, min_umis=500, min_genes=200, max_mt=0.15):
89
print(f"Loading {h5ad_path}...")
9-
10+
1011
with h5py.File(h5ad_path, "r") as f:
1112
# Load barcodes and gene names
12-
barcodes = [b.decode("utf-8") if isinstance(b, bytes) else b for b in f["obs"]["_index"][:]]
13-
gene_names = [g.decode("utf-8") if isinstance(g, bytes) else g for g in f["var"]["_index"][:]]
14-
13+
barcodes = [
14+
b.decode("utf-8") if isinstance(b, bytes) else b
15+
for b in f["obs"]["_index"][:]
16+
]
17+
gene_names = [
18+
g.decode("utf-8") if isinstance(g, bytes) else g
19+
for g in f["var"]["_index"][:]
20+
]
21+
1522
# Load expression matrix X
1623
X_group = f["X"]
1724
if isinstance(X_group, h5py.Group):
1825
data = X_group["data"][:]
1926
indices = X_group["indices"][:]
2027
indptr = X_group["indptr"][:]
21-
X = csr_matrix((data, indices, indptr), shape=(len(barcodes), len(gene_names)))
28+
X = csr_matrix(
29+
(data, indices, indptr), shape=(len(barcodes), len(gene_names))
30+
)
2231
else:
2332
X = f["X"][:]
24-
33+
2534
# Spatial coordinates (usually in obsm/spatial)
2635
if "obsm" in f and "spatial" in f["obsm"]:
2736
coords = f["obsm"]["spatial"][:]
@@ -33,10 +42,14 @@ def diagnose_qc(h5ad_path, output_path, min_umis=500, min_genes=200, max_mt=0.15
3342
# Calculate QC metrics
3443
n_counts = np.array(X.sum(axis=1)).flatten()
3544
n_genes = np.array((X > 0).sum(axis=1)).flatten()
36-
45+
3746
# MT fraction
3847
# Robust MT detection: check for mt- or MT- anywhere, but prioritize common patterns
39-
mt_genes = [i for i, name in enumerate(gene_names) if "mt-" in name.lower() or "mt:" in name.lower()]
48+
mt_genes = [
49+
i
50+
for i, name in enumerate(gene_names)
51+
if "mt-" in name.lower() or "mt:" in name.lower()
52+
]
4053
if mt_genes:
4154
print(f"Found {len(mt_genes)} mitochondrial genes.")
4255
mt_counts = np.array(X[:, mt_genes].sum(axis=1)).flatten()
@@ -49,12 +62,12 @@ def diagnose_qc(h5ad_path, output_path, min_umis=500, min_genes=200, max_mt=0.15
4962
pass_umi = n_counts >= min_umis
5063
pass_gene = n_genes >= min_genes
5164
pass_mt = pct_counts_mt <= max_mt
52-
65+
5366
keep_mask = pass_umi & pass_gene & pass_mt
54-
67+
5568
total_spots = len(barcodes)
5669
kept_spots = np.sum(keep_mask)
57-
70+
5871
print(f"Total spots: {total_spots}")
5972
print(f"Kept spots: {kept_spots} ({kept_spots/total_spots:.1%})")
6073
print(f"Filtered: {total_spots - kept_spots}")
@@ -64,31 +77,41 @@ def diagnose_qc(h5ad_path, output_path, min_umis=500, min_genes=200, max_mt=0.15
6477

6578
# Plotting
6679
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
67-
80+
6881
# 1. UMI vs Genes
6982
axes[0, 0].scatter(n_counts, n_genes, c=keep_mask, cmap="RdYlGn", alpha=0.5, s=10)
70-
axes[0, 0].axvline(min_umis, color="red", linestyle="--", label=f"Min UMI={min_umis}")
71-
axes[0, 0].axhline(min_genes, color="blue", linestyle="--", label=f"Min Genes={min_genes}")
83+
axes[0, 0].axvline(
84+
min_umis, color="red", linestyle="--", label=f"Min UMI={min_umis}"
85+
)
86+
axes[0, 0].axhline(
87+
min_genes, color="blue", linestyle="--", label=f"Min Genes={min_genes}"
88+
)
7289
axes[0, 0].set_xlabel("Total UMI counts")
7390
axes[0, 0].set_ylabel("Number of detected genes")
7491
axes[0, 0].set_title("QC: UMI vs Genes")
7592
axes[0, 0].legend()
76-
93+
7794
# 2. MT Fraction distribution
7895
axes[0, 1].hist(pct_counts_mt, bins=50, color="gray", alpha=0.7)
79-
axes[0, 1].axvline(max_mt, color="red", linestyle="--", label=f"Max MT={max_mt:.0%}")
96+
axes[0, 1].axvline(
97+
max_mt, color="red", linestyle="--", label=f"Max MT={max_mt:.0%}"
98+
)
8099
axes[0, 1].set_xlabel("Mitochondrial Fraction")
81100
axes[0, 1].set_ylabel("Count")
82101
axes[0, 1].set_title("QC: MT Fraction Distribution")
83102
axes[0, 1].legend()
84-
103+
85104
# 3. Spatial: Before (Total)
86-
axes[1, 0].scatter(coords[:, 0], coords[:, 1], c="lightgray", s=15, label="All Spots")
87-
axes[1, 0].scatter(coords[keep_mask, 0], coords[keep_mask, 1], c="green", s=15, label="Pass QC")
105+
axes[1, 0].scatter(
106+
coords[:, 0], coords[:, 1], c="lightgray", s=15, label="All Spots"
107+
)
108+
axes[1, 0].scatter(
109+
coords[keep_mask, 0], coords[keep_mask, 1], c="green", s=15, label="Pass QC"
110+
)
88111
axes[1, 0].set_title("Spatial Distribution: Kept vs Filtered")
89112
axes[1, 0].set_aspect("equal")
90113
axes[1, 0].legend()
91-
114+
92115
# 4. Summary Table (as text)
93116
stats_text = (
94117
f"Sample: {os.path.basename(h5ad_path)}\n\n"
@@ -102,21 +125,23 @@ def diagnose_qc(h5ad_path, output_path, min_umis=500, min_genes=200, max_mt=0.15
102125
)
103126
axes[1, 1].text(0.1, 0.5, stats_text, fontsize=14, family="monospace")
104127
axes[1, 1].axis("off")
105-
128+
106129
plt.tight_layout()
107130
plt.savefig(output_path)
108131
print(f"Plot saved to {output_path}")
109132

133+
110134
if __name__ == "__main__":
111135
import argparse
136+
112137
parser = argparse.ArgumentParser()
113138
parser.add_argument("--sample", type=str, default="MEND29")
114139
args = parser.parse_args()
115-
140+
116141
sample_id = args.sample
117142
h5ad_file = f"A:\\hest_data\\st\\{sample_id}.h5ad"
118143
output_file = f"qc_diagnosis_{sample_id}.png"
119-
144+
120145
if not os.path.exists(h5ad_file):
121146
print(f"Error: {h5ad_file} not found.")
122147
else:

scripts/find_outliers.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import sys
3+
34
sys.path.append(r"z:\Projects\SpatialTranscriptFormer\scripts")
45
from batch_qc_stats import calculate_qc_stats
56
import numpy as np
@@ -11,22 +12,30 @@
1112
results = []
1213
for s in samples:
1314
try:
14-
total, kept, l_umi, l_gene, h_mt = calculate_qc_stats(os.path.join(st_dir, f"{s}.h5ad"))
15-
results.append({
16-
"sample": s,
17-
"total": total,
18-
"kept": kept,
19-
"pct": kept / total if total > 0 else 0,
20-
"low_umi": l_umi,
21-
"low_gene": l_gene,
22-
"high_mt": h_mt
23-
})
15+
total, kept, l_umi, l_gene, h_mt = calculate_qc_stats(
16+
os.path.join(st_dir, f"{s}.h5ad")
17+
)
18+
results.append(
19+
{
20+
"sample": s,
21+
"total": total,
22+
"kept": kept,
23+
"pct": kept / total if total > 0 else 0,
24+
"low_umi": l_umi,
25+
"low_gene": l_gene,
26+
"high_mt": h_mt,
27+
}
28+
)
2429
except Exception as e:
2530
print(f"Error {s}: {e}")
2631

2732
results.sort(key=lambda x: x["pct"])
2833

29-
print(f"{'Sample':<15} | {'Kept %':<8} | {'Filtered':<10} | {'Low UMI':<8} | {'Low Gene':<8} | {'High MT':<8}")
34+
print(
35+
f"{'Sample':<15} | {'Kept %':<8} | {'Filtered':<10} | {'Low UMI':<8} | {'Low Gene':<8} | {'High MT':<8}"
36+
)
3037
print("-" * 75)
3138
for r in results[:15]:
32-
print(f"{r['sample']:<15} | {r['pct']:7.1%} | {r['total']-r['kept']:<10} | {r['low_umi']:<8} | {r['low_gene']:<8} | {r['high_mt']:<8}")
39+
print(
40+
f"{r['sample']:<15} | {r['pct']:7.1%} | {r['total']-r['kept']:<10} | {r['low_umi']:<8} | {r['low_gene']:<8} | {r['high_mt']:<8}"
41+
)

scripts/monitor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"""
33
Real-time Training Monitor Entrypoint for SpatialTranscriptFormer.
44
"""
5+
56
import argparse
67
import logging
78
from spatial_transcript_former.dashboard.app import init_app, app

src/spatial_transcript_former/checkpoint.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
import torch
1515

16-
1716
# Keys serialized into config.json. These correspond to
1817
# SpatialTranscriptFormer.__init__ arguments (minus runtime-only
1918
# arguments like ``pathway_init`` and ``pretrained``).

src/spatial_transcript_former/data/pathways.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"c2_cgp": "https://data.broadinstitute.org/gsea-msigdb/msigdb/release/2024.1.Hs/c2.cgp.v2024.1.Hs.symbols.gmt",
1919
}
2020

21+
2122
def download_msigdb_gmt(url: str, filename: str, cache_dir: str = ".cache") -> str:
2223
"""
2324
Download an MSigDB GMT file if not already cached.

src/spatial_transcript_former/data/spatial_stats.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def morans_i(x: np.ndarray, W: csr_matrix) -> float:
7171
x_mean = x.mean()
7272
z = x - x_mean
7373

74-
denominator = np.sum(z ** 2)
74+
denominator = np.sum(z**2)
7575
if denominator < 1e-12:
7676
return 0.0 # Constant expression → no spatial pattern
7777

src/spatial_transcript_former/models/interaction.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -348,9 +348,7 @@ def forward(
348348
x_layer = x_layer + layer.dropout1(attn_output)
349349
x_layer = x_layer + layer._ff_block(layer.norm2(x_layer))
350350
else:
351-
x_layer = layer.norm1(
352-
x_layer + layer.dropout1(attn_output)
353-
)
351+
x_layer = layer.norm1(x_layer + layer.dropout1(attn_output))
354352
x_layer = layer.norm2(x_layer + layer._ff_block(x_layer))
355353
out = x_layer
356354
else:

src/spatial_transcript_former/predict.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import torch.nn as nn
2323
from torchvision import transforms
2424

25-
2625
# ═══════════════════════════════════════════════════════════════════════
2726
# FeatureExtractor
2827
# ═══════════════════════════════════════════════════════════════════════

src/spatial_transcript_former/recipes/hest/build_vocab.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,12 @@ def scan_h5ad_files(data_dir):
4444

4545

4646
def calculate_global_genes(
47-
data_dir, ids, num_genes=1000, target_pathways=None,
48-
svg_weight=0.0, svg_k=6,
47+
data_dir,
48+
ids,
49+
num_genes=1000,
50+
target_pathways=None,
51+
svg_weight=0.0,
52+
svg_k=6,
4953
):
5054
st_dir = os.path.join(data_dir, "st")
5155
if not ids:
@@ -158,15 +162,13 @@ def calculate_global_genes(
158162
# Hybrid score: weighted sum of ranks (lower = better)
159163
alpha = svg_weight
160164
hybrid_score = {
161-
g: (1 - alpha) * expr_rank[g] + alpha * mi_rank[g]
162-
for g in all_genes
165+
g: (1 - alpha) * expr_rank[g] + alpha * mi_rank[g] for g in all_genes
163166
}
164167
sorted_all_genes = sorted(all_genes, key=lambda g: hybrid_score[g])
165168

166169
# Build stats list with Moran's I column
167170
sorted_all = [
168-
(g, gene_totals[g], gene_morans_avg.get(g, 0.0))
169-
for g in sorted_all_genes
171+
(g, gene_totals[g], gene_morans_avg.get(g, 0.0)) for g in sorted_all_genes
170172
]
171173
print(
172174
f"Hybrid ranking: expression weight={(1 - alpha):.1f}, "
@@ -221,7 +223,7 @@ def main():
221223
type=float,
222224
default=0.0,
223225
help="Weight for spatial variability (Moran's I) in gene ranking. "
224-
"0.0=expression-only (default), 1.0=SVG-only, 0.5=balanced.",
226+
"0.0=expression-only (default), 1.0=SVG-only, 0.5=balanced.",
225227
)
226228
parser.add_argument(
227229
"--svg-k",
@@ -242,8 +244,12 @@ def main():
242244
sys.exit(1)
243245

244246
top_genes, all_stats = calculate_global_genes(
245-
args.data_dir, ids, args.num_genes, target_pathways=args.pathways,
246-
svg_weight=args.svg_weight, svg_k=args.svg_k,
247+
args.data_dir,
248+
ids,
249+
args.num_genes,
250+
target_pathways=args.pathways,
251+
svg_weight=args.svg_weight,
252+
svg_k=args.svg_k,
247253
)
248254

249255
print(f"Saving top {len(top_genes)} genes to {output_path}")

0 commit comments

Comments
 (0)