Skip to content

Commit 472bd5f

Browse files
Implementing svg (#7)
* feat: add SVG-aware gene selection and spatial coherence validation Introduce Moran's I-based spatial variability scoring to improve gene vocabulary selection and add a spatial coherence metric for validation. Spatial Statistics Module (NEW): - data/spatial_stats.py: lightweight Moran's I via KNN weights (numpy + scipy only, no new dependencies) - morans_i(), morans_i_batch(), spatial_coherence_score() SVG-aware Gene Selection: - build_vocab.py: --svg-weight (0-1) blends expression rank with Moran's I rank; --svg-k controls KNN graph size - Default svg_weight=0.0 preserves original behaviour - Stats CSV now includes morans_i column Spatial Coherence Validation: - engine.py: computes Moran's I correlation between predicted and ground-truth expression on top-50 SVGs during validation - train.py: logs spatial_coherence to SQLite Tests: - test_spatial_stats.py: 14 tests covering Moran's I (uniform, clustered, checkerboard, gradient) and coherence scoring Docs: - SC_BEST_PRACTICES.md: marked SVG selection and spatial coherence as implemented * Refactor: Centralize data management and reorganize test suite Data Management Framework Refactoring: - Created default `GeneVocab` class in `data/gene_vocab.py` to act as the single source of truth for gene list handling, replacing duplicated loading logic. - Centralized path resolution into `data/paths.py` (`resolve_feature_dir`, `resolve_gene_vocab_path`), eliminating hardcoded dataset paths. - Extracted HEST-specific patient-aware dataset splitting logic from core train.py into `recipes/hest/utils.py::get_train_val_ids()`, decoupling the training script from benchmark-specific data properties. - Updated predict.py to handle `anndata >= 0.12` strict shape enforcement by dynamically rebuilding the AnnData object in-place when filtering variables. Test Suite Restructuring & Bug Fixes: - Consolidated 37 flat, root-level test files into 15 focused module files organized across `tests/models/`, `tests/training/`, `tests/data/`, and `tests/recipes/hest/`. - Resolved 6 shadowed test name collisions that caused tests to silently overwrite each other during collection (e.g., `test_interaction_output_shape`). - Fixed broken `download_hest` imports in the `test_io.py` script by correctly resolving the root `scripts/` directory via `sys.path`. - Fixed `test_pathways.py` Mock failures by explicitly defining `args.pathways = None` to enforce default filtering behavior. - Total active test suite passes locally (199/199 collected tests). * - Applied black formatting. * - added .gitattributes * - forgot this... * - maybe now?
1 parent 7d75160 commit 472bd5f

71 files changed

Lines changed: 3407 additions & 2027 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.gitattributes

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
* text=auto eol=lf
2+

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ Visualization plots and spatial expression maps will be saved to the `./results`
113113
- **[Pathway Mapping](docs/PATHWAY_MAPPING.md)**: Clinical interpretability, pathway bottleneck design, and MSigDB integration.
114114
- **[Gene Analysis](docs/GENE_ANALYSIS.md)**: Modeling strategies for mapping morphology to high-dimensional gene spaces.
115115
- **[Data Structure](docs/DATA_STRUCTURE.md)**: Detailed breakdown of the HEST data structure on disk, metadata conventions, and preprocessing invariants.
116-
- **[Single-cell Best Practices](docs/SC_BEST_PRACTICES.md)**: Gap analysis and roadmap for alignment with industry standard recommendations.
116+
- **[Single-cell Best Practices](docs/SC_BEST_PRACTICES.md)**: Gap analysis and roadmap for alignment with standard recommendations.
117117

118118
## Development
119119

docs/SC_BEST_PRACTICES.md

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@ These are areas where the project already follows industry best practices:
2121

2222
The following items are recommended for future sprints to improve model robustness and biological accuracy.
2323

24-
### 1. SVG-aware Gene Selection (Moran's I)
24+
### 1. SVG-aware Gene Selection (Moran's I)
2525

26-
**Priority: High**
26+
**Priority: High** **Implemented**
2727
**Rationale**: Currently, genes are selected based on total expression or pathway membership. However, the model's primary task is to learn spatial patterns. Selecting genes based on **Spatially Variable Gene (SVG)** metrics like Moran's I (available in Squidpy) would prioritise genes that have learned spatial coherence over those that are just highly expressed (like housekeeping genes).
2828

29+
**Usage**: `stf-build-vocab --svg-weight 0.5 --svg-k 6` enables a hybrid ranking that blends total expression with Moran's I spatial variability. See `data/spatial_stats.py` for the implementation.
30+
2931
### 2. Standardised Preprocessing Pipeline
3032

3133
**Priority: Medium-High**
@@ -41,10 +43,12 @@ The following items are recommended for future sprints to improve model robustne
4143
**Priority: Medium**
4244
**Rationale**: Adding explicit QC thresholds (e.g., minimum UMI count, minimum detected genes, maximum mitochondrial fraction) to the dataset loading scripts would protect the model from training on low-quality "noise" spots.
4345

44-
### 5. Spatial Coherence Validation Metrics
46+
### 5. Spatial Coherence Validation Metrics
4547

46-
**Priority: Medium**
47-
**Rationale**: Aggregate metrics like MSE or PCC don't capture whether the *spatial distribution* of predictions is realistic. Adding a validation step that compares the Moran's I of predicted vs. ground-truth expression would provide a much stronger biological validation signal.
48+
**Priority: Medium****Implemented**
49+
**Rationale**: Aggregate metrics like MSE or PCC don't capture whether the *spatial distribution* of predictions is realistic. A validation step now compares the Moran's I of predicted vs. ground-truth expression for the top-50 spatially variable genes, reporting a Pearson correlation as the **Spatial Coherence Score**.
50+
51+
**Integration**: Computed automatically during validation in `training/engine.py` and logged to SQLite as `spatial_coherence`. See `data/spatial_stats.py:spatial_coherence_score()`.
4852

4953
### 6. Preprocessing Documentation
5054

docs/SVG_PROTOTYPE_RESULTS.md

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Spatially Variable Gene (SVG) Selection & Validation Prototype Results
2+
3+
## Overview
4+
5+
This document summarizes the prototype results for integrating spatially variable gene (SVG) scoring into the SpatialTranscriptFormer pipeline. Two major components were built and validated:
6+
7+
1. **SVG-Aware Gene Selection**: Using Moran's I to bias the 1000-gene vocabulary towards genes that exhibit strong spatial autocorrelation, improving the biological relevance of the model's bottleneck.
8+
2. **Spatial Coherence Validation**: An end-to-end validation metric that calculates Moran's I on the predicted vs. ground-truth expression vectors for the top-50 spatially variable genes, reported as a Pearson correlation score.
9+
10+
This prototype was run on a colorectal/bowel cancer cohort from the HEST-1k dataset (84 human samples) for ~400 epochs.
11+
12+
---
13+
14+
## 1. Training Progress & Stability
15+
16+
Training the `stf_small` model (4 layers, 384 dim, 8 heads) on the bowel cancer subset with SVG weighting (`--svg-weight 0.5`) showed steady convergence and stability.
17+
18+
![Training Loss Landscape](./assets/bowel_svg_loss_curve.png)
19+
20+
* **Learning Schedule**: 10 warm-up epochs followed by cosine annealing.
21+
* **Overfitting**: The gap between training and validation loss is characteristic of the small dataset size (84 samples), but the model finds a strong minimum.
22+
* **Best Checkpoint**: Reached a best validation loss of **1.611 at epoch 367**, an improvement over early training stages (~1.835 at epoch 130).
23+
24+
![Validation Metrics](./assets/bowel_svg_val_metrics.png)
25+
26+
* **Mean Absolute Error (MAE)**: Decreased smoothly from ~0.80 down strictly towards 0.60, indicating improved pixel-wise prediction accuracy.
27+
* **Prediction Variance**: Evaluated as a collapse detector. The variance ramped up from near-zero to ~0.03, confirming the model successfully escaped "mean-prediction" collapse and learned to output highly differentiated spatial patterns. *(Note: A variance of 0.0 indicates a collapsed model predicting the same average value everywhere).*
28+
29+
---
30+
31+
## 2. Pathway Spatial Coherence (Bowel Cancer)
32+
33+
To validate the biological plausibility of the predictions, we visualised clinically relevant Hallmarks pathways for colorectal/bowel cancer using the best epoch checkpoint (epoch 367) on sample `TENX29`.
34+
35+
Pathways selected for their relevance to colorectal cancer progression, invasion, and tumor microenvironment:
36+
37+
* **Wnt/β-Catenin Signaling**: Mutated in >80% of sporadic colorectal cancers (APC mutation).
38+
* **Epithelial-Mesenchymal Transition (EMT)**: Key marker of invasion and metastasis.
39+
* **TNF-α Signaling via NF-κB** & **Inflammatory Response**: Chronic inflammation is a hallmark driver of CRC.
40+
* **KRAS Signaling (Up)**: KRAS mutations occurring in ~40% of CRCs.
41+
* **Angiogenesis**: Critical for tumor vasculature, the target of anti-VEGF therapies.
42+
43+
### Pathway Predictions vs Ground Truth (Epoch 524)
44+
45+
![Bowel Cancer Pathway Predictions](./assets/TENX29_epoch_524.png)
46+
47+
**Observations:**
48+
49+
1. **Spatial Pattern Matching**: The model successfully reconstructs complex, heterogeneous spatial patterns across the tissue architecture. High-expression regions (yellow/green) in the ground truth strongly correlate with high-expression regions in the predictions.
50+
2. **Biological Gradients**: Crucially, the model does not just predict average expression but captures the relative spatial *gradients* of these pathways, confirming the health of the non-collapsed prediction variance metric.
51+
3. **Tumor Microenvironment (TME)**: Inflammatory pathways (TNF-α, Inflammatory Response) show distinct spatial localization differing from tumor-intrinsic pathways (Wnt, KRAS), accurately reflecting TME heterogeneity.
52+
53+
---
54+
55+
## Conclusion
56+
57+
The integration of Moran's I for both **SVG-aware gene selection** and **Spatial Coherence Validation** provides a significantly more robust, biologically grounded training pipeline. The model demonstrates the ability to learn and reconstruct clinically relevant spatial pathway patterns from H&E imaging alone, laying a strong foundation for scaling the model to the full HEST-1k dataset and evaluating patient-level stratification.

docs/TRAINING_GUIDE.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,17 @@ python scripts/run_preset.py --preset stf_medium
145145
python scripts/run_preset.py --preset stf_large
146146
```
147147

148+
#### Disease-Specific Priors (Colorectal Cancer)
149+
150+
To learn representations specifically constrained to phenotypes of a target disease, you can explicitly filter the initialization pathway bottleneck using the `--pathways` argument. The `crc` presets demonstrate this by shrinking the dimensionality down from 50 generic hallmarks to 14 CRC-specific pathways (e.g. Wnt/Beta-catenin, EMT, Angiogenesis).
151+
152+
```bash
153+
# Small CRC Variant (14 explicit pathways)
154+
python scripts/run_preset.py --preset stf_crc_small
155+
```
156+
157+
These presets are defined directly in `scripts/run_preset.py`, serving as a template for how you can introduce your own biological priors for other diseases.
158+
148159
### Choosing Interaction Modes
149160

150161
By default, the model runs in **Full Interaction** mode (`p2p p2h h2p h2h`) where all token types attend to each other. You can selectively disable interactions using the `--interactions` flag for ablation or to enforce specific architectural constraints.

scripts/batch_qc_stats.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import os
2+
import h5py
3+
import numpy as np
4+
import matplotlib.pyplot as plt
5+
from scipy.sparse import csr_matrix
6+
7+
8+
def calculate_qc_stats(h5ad_path, min_umis=500, min_genes=200, max_mt=0.15):
9+
with h5py.File(h5ad_path, "r") as f:
10+
barcodes = [
11+
b.decode("utf-8") if isinstance(b, bytes) else b
12+
for b in f["obs"]["_index"][:]
13+
]
14+
gene_names = [
15+
g.decode("utf-8") if isinstance(g, bytes) else g
16+
for g in f["var"]["_index"][:]
17+
]
18+
X_group = f["X"]
19+
if isinstance(X_group, h5py.Group):
20+
X = csr_matrix(
21+
(X_group["data"][:], X_group["indices"][:], X_group["indptr"][:]),
22+
shape=(len(barcodes), len(gene_names)),
23+
)
24+
else:
25+
X = f["X"][:]
26+
27+
n_counts = np.array(X.sum(axis=1)).flatten()
28+
n_genes = np.array((X > 0).sum(axis=1)).flatten()
29+
mt_genes = [
30+
i
31+
for i, name in enumerate(gene_names)
32+
if "mt-" in name.lower() or "mt:" in name.lower()
33+
]
34+
if mt_genes:
35+
mt_counts = np.array(X[:, mt_genes].sum(axis=1)).flatten()
36+
pct_counts_mt = mt_counts / (n_counts + 1e-9)
37+
else:
38+
pct_counts_mt = np.zeros_like(n_counts)
39+
40+
keep_mask = (
41+
(n_counts >= min_umis) & (n_genes >= min_genes) & (pct_counts_mt <= max_mt)
42+
)
43+
return (
44+
len(barcodes),
45+
np.sum(keep_mask),
46+
np.sum(n_counts < min_umis),
47+
np.sum(n_genes < min_genes),
48+
np.sum(pct_counts_mt > max_mt),
49+
)
50+
51+
52+
if __name__ == "__main__":
53+
st_dir = r"A:\hest_data\st"
54+
55+
# Get all h5ad files
56+
samples = [
57+
f.replace(".h5ad", "") for f in os.listdir(st_dir) if f.endswith(".h5ad")
58+
]
59+
samples.sort()
60+
61+
print(f"Analyzing {len(samples)} samples...")
62+
print(
63+
f"{'Sample':<15} | {'Total':<6} | {'Kept':<6} | {'% Kept':<8} | {'Low UMI':<8} | {'Low Gene':<8} | {'High MT':<8}"
64+
)
65+
print("-" * 80)
66+
67+
all_results = []
68+
for s in samples:
69+
path = os.path.join(st_dir, f"{s}.h5ad")
70+
try:
71+
total, kept, l_umi, l_gene, h_mt = calculate_qc_stats(path)
72+
pct = kept / total if total > 0 else 0
73+
print(
74+
f"{s:<15} | {total:<6} | {kept:<6} | {pct:7.1%} | {l_umi:<8} | {l_gene:<8} | {h_mt:<8}"
75+
)
76+
all_results.append((total, kept, pct))
77+
except Exception as e:
78+
print(f"Error processing {s}: {e}")
79+
80+
if all_results:
81+
avg_kept = np.mean([r[2] for r in all_results])
82+
min_kept = np.min([r[2] for r in all_results])
83+
max_kept = np.max([r[2] for r in all_results])
84+
print("-" * 80)
85+
print(f"GLOBAL SUMMARY ({len(all_results)} samples):")
86+
print(f"Average Kept: {avg_kept:.1%}")
87+
print(f"Minimum Kept: {min_kept:.1%}")
88+
print(f"Maximum Kept: {max_kept:.1%}")

scripts/diagnose_qc.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import os
2+
import h5py
3+
import numpy as np
4+
import matplotlib.pyplot as plt
5+
from scipy.sparse import csr_matrix
6+
7+
8+
def diagnose_qc(h5ad_path, output_path, min_umis=500, min_genes=200, max_mt=0.15):
9+
print(f"Loading {h5ad_path}...")
10+
11+
with h5py.File(h5ad_path, "r") as f:
12+
# Load barcodes and gene names
13+
barcodes = [
14+
b.decode("utf-8") if isinstance(b, bytes) else b
15+
for b in f["obs"]["_index"][:]
16+
]
17+
gene_names = [
18+
g.decode("utf-8") if isinstance(g, bytes) else g
19+
for g in f["var"]["_index"][:]
20+
]
21+
22+
# Load expression matrix X
23+
X_group = f["X"]
24+
if isinstance(X_group, h5py.Group):
25+
data = X_group["data"][:]
26+
indices = X_group["indices"][:]
27+
indptr = X_group["indptr"][:]
28+
X = csr_matrix(
29+
(data, indices, indptr), shape=(len(barcodes), len(gene_names))
30+
)
31+
else:
32+
X = f["X"][:]
33+
34+
# Spatial coordinates (usually in obsm/spatial)
35+
if "obsm" in f and "spatial" in f["obsm"]:
36+
coords = f["obsm"]["spatial"][:]
37+
else:
38+
# Fallback if spatial not found
39+
coords = np.zeros((len(barcodes), 2))
40+
print("Warning: Spatial coordinates not found in obsm/spatial")
41+
42+
# Calculate QC metrics
43+
n_counts = np.array(X.sum(axis=1)).flatten()
44+
n_genes = np.array((X > 0).sum(axis=1)).flatten()
45+
46+
# MT fraction
47+
# Robust MT detection: check for mt- or MT- anywhere, but prioritize common patterns
48+
mt_genes = [
49+
i
50+
for i, name in enumerate(gene_names)
51+
if "mt-" in name.lower() or "mt:" in name.lower()
52+
]
53+
if mt_genes:
54+
print(f"Found {len(mt_genes)} mitochondrial genes.")
55+
mt_counts = np.array(X[:, mt_genes].sum(axis=1)).flatten()
56+
pct_counts_mt = mt_counts / (n_counts + 1e-9)
57+
else:
58+
pct_counts_mt = np.zeros_like(n_counts)
59+
print("No mitochondrial genes found.")
60+
61+
# Filter mask
62+
pass_umi = n_counts >= min_umis
63+
pass_gene = n_genes >= min_genes
64+
pass_mt = pct_counts_mt <= max_mt
65+
66+
keep_mask = pass_umi & pass_gene & pass_mt
67+
68+
total_spots = len(barcodes)
69+
kept_spots = np.sum(keep_mask)
70+
71+
print(f"Total spots: {total_spots}")
72+
print(f"Kept spots: {kept_spots} ({kept_spots/total_spots:.1%})")
73+
print(f"Filtered: {total_spots - kept_spots}")
74+
print(f" - Low UMI (<{min_umis}): {np.sum(~pass_umi)}")
75+
print(f" - Low Genes (<{min_genes}): {np.sum(~pass_gene)}")
76+
print(f" - High MT (>{max_mt:.1%}): {np.sum(~pass_mt)}")
77+
78+
# Plotting
79+
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
80+
81+
# 1. UMI vs Genes
82+
axes[0, 0].scatter(n_counts, n_genes, c=keep_mask, cmap="RdYlGn", alpha=0.5, s=10)
83+
axes[0, 0].axvline(
84+
min_umis, color="red", linestyle="--", label=f"Min UMI={min_umis}"
85+
)
86+
axes[0, 0].axhline(
87+
min_genes, color="blue", linestyle="--", label=f"Min Genes={min_genes}"
88+
)
89+
axes[0, 0].set_xlabel("Total UMI counts")
90+
axes[0, 0].set_ylabel("Number of detected genes")
91+
axes[0, 0].set_title("QC: UMI vs Genes")
92+
axes[0, 0].legend()
93+
94+
# 2. MT Fraction distribution
95+
axes[0, 1].hist(pct_counts_mt, bins=50, color="gray", alpha=0.7)
96+
axes[0, 1].axvline(
97+
max_mt, color="red", linestyle="--", label=f"Max MT={max_mt:.0%}"
98+
)
99+
axes[0, 1].set_xlabel("Mitochondrial Fraction")
100+
axes[0, 1].set_ylabel("Count")
101+
axes[0, 1].set_title("QC: MT Fraction Distribution")
102+
axes[0, 1].legend()
103+
104+
# 3. Spatial: Before (Total)
105+
axes[1, 0].scatter(
106+
coords[:, 0], coords[:, 1], c="lightgray", s=15, label="All Spots"
107+
)
108+
axes[1, 0].scatter(
109+
coords[keep_mask, 0], coords[keep_mask, 1], c="green", s=15, label="Pass QC"
110+
)
111+
axes[1, 0].set_title("Spatial Distribution: Kept vs Filtered")
112+
axes[1, 0].set_aspect("equal")
113+
axes[1, 0].legend()
114+
115+
# 4. Summary Table (as text)
116+
stats_text = (
117+
f"Sample: {os.path.basename(h5ad_path)}\n\n"
118+
f"Total Spots: {total_spots}\n"
119+
f"Kept Spots: {kept_spots} ({kept_spots/total_spots:.1%})\n"
120+
f"Filtered: {total_spots - kept_spots}\n\n"
121+
f"Thresholds:\n"
122+
f"Min UMI: {min_umis}\n"
123+
f"Min Genes: {min_genes}\n"
124+
f"Max MT: {max_mt:.0%}"
125+
)
126+
axes[1, 1].text(0.1, 0.5, stats_text, fontsize=14, family="monospace")
127+
axes[1, 1].axis("off")
128+
129+
plt.tight_layout()
130+
plt.savefig(output_path)
131+
print(f"Plot saved to {output_path}")
132+
133+
134+
if __name__ == "__main__":
135+
import argparse
136+
137+
parser = argparse.ArgumentParser()
138+
parser.add_argument("--sample", type=str, default="MEND29")
139+
args = parser.parse_args()
140+
141+
sample_id = args.sample
142+
h5ad_file = f"A:\\hest_data\\st\\{sample_id}.h5ad"
143+
output_file = f"qc_diagnosis_{sample_id}.png"
144+
145+
if not os.path.exists(h5ad_file):
146+
print(f"Error: {h5ad_file} not found.")
147+
else:
148+
diagnose_qc(h5ad_file, output_file)

scripts/find_outliers.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import os
2+
import sys
3+
4+
sys.path.append(r"z:\Projects\SpatialTranscriptFormer\scripts")
5+
from batch_qc_stats import calculate_qc_stats
6+
import numpy as np
7+
8+
st_dir = r"A:\hest_data\st"
9+
samples = [f.replace(".h5ad", "") for f in os.listdir(st_dir) if f.endswith(".h5ad")]
10+
samples.sort()
11+
12+
results = []
13+
for s in samples:
14+
try:
15+
total, kept, l_umi, l_gene, h_mt = calculate_qc_stats(
16+
os.path.join(st_dir, f"{s}.h5ad")
17+
)
18+
results.append(
19+
{
20+
"sample": s,
21+
"total": total,
22+
"kept": kept,
23+
"pct": kept / total if total > 0 else 0,
24+
"low_umi": l_umi,
25+
"low_gene": l_gene,
26+
"high_mt": h_mt,
27+
}
28+
)
29+
except Exception as e:
30+
print(f"Error {s}: {e}")
31+
32+
results.sort(key=lambda x: x["pct"])
33+
34+
print(
35+
f"{'Sample':<15} | {'Kept %':<8} | {'Filtered':<10} | {'Low UMI':<8} | {'Low Gene':<8} | {'High MT':<8}"
36+
)
37+
print("-" * 75)
38+
for r in results[:15]:
39+
print(
40+
f"{r['sample']:<15} | {r['pct']:7.1%} | {r['total']-r['kept']:<10} | {r['low_umi']:<8} | {r['low_gene']:<8} | {r['high_mt']:<8}"
41+
)

0 commit comments

Comments
 (0)