Skip to content

Commit ca6f5b1

Browse files
Refactor: Centralize data management and reorganize test suite
Data Management Framework Refactoring: - Created default `GeneVocab` class in `data/gene_vocab.py` to act as the single source of truth for gene list handling, replacing duplicated loading logic. - Centralized path resolution into `data/paths.py` (`resolve_feature_dir`, `resolve_gene_vocab_path`), eliminating hardcoded dataset paths. - Extracted HEST-specific patient-aware dataset splitting logic from core train.py into `recipes/hest/utils.py::get_train_val_ids()`, decoupling the training script from benchmark-specific data properties. - Updated predict.py to handle `anndata >= 0.12` strict shape enforcement by dynamically rebuilding the AnnData object in-place when filtering variables. Test Suite Restructuring & Bug Fixes: - Consolidated 37 flat, root-level test files into 15 focused module files organized across `tests/models/`, `tests/training/`, `tests/data/`, and `tests/recipes/hest/`. - Resolved 6 shadowed test name collisions that caused tests to silently overwrite each other during collection (e.g., `test_interaction_output_shape`). - Fixed broken `download_hest` imports in the `test_io.py` script by correctly resolving the root `scripts/` directory via `sys.path`. - Fixed `test_pathways.py` Mock failures by explicitly defining `args.pathways = None` to enforce default filtering behavior. - Total active test suite passes locally (199/199 collected tests).
1 parent f34e8c3 commit ca6f5b1

62 files changed

Lines changed: 2792 additions & 2018 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

docs/SVG_PROTOTYPE_RESULTS.md

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Spatially Variable Gene (SVG) Selection & Validation Prototype Results
2+
3+
## Overview
4+
5+
This document summarizes the prototype results for integrating spatially variable gene (SVG) scoring into the SpatialTranscriptFormer pipeline. Two major components were built and validated:
6+
7+
1. **SVG-Aware Gene Selection**: Using Moran's I to bias the 1000-gene vocabulary towards genes that exhibit strong spatial autocorrelation, improving the biological relevance of the model's bottleneck.
8+
2. **Spatial Coherence Validation**: An end-to-end validation metric that calculates Moran's I on the predicted vs. ground-truth expression vectors for the top-50 spatially variable genes, reported as a Pearson correlation score.
9+
10+
This prototype was run on a colorectal/bowel cancer cohort from the HEST-1k dataset (84 human samples) for ~400 epochs.
11+
12+
---
13+
14+
## 1. Training Progress & Stability
15+
16+
Training the `stf_small` model (4 layers, 384 dim, 8 heads) on the bowel cancer subset with SVG weighting (`--svg-weight 0.5`) showed steady convergence and stability.
17+
18+
![Training Loss Landscape](./assets/bowel_svg_loss_curve.png)
19+
20+
* **Learning Schedule**: 10 warm-up epochs followed by cosine annealing.
21+
* **Overfitting**: The gap between training and validation loss is characteristic of the small dataset size (84 samples), but the model finds a strong minimum.
22+
* **Best Checkpoint**: Reached a best validation loss of **1.611 at epoch 367**, an improvement over early training stages (~1.835 at epoch 130).
23+
24+
![Validation Metrics](./assets/bowel_svg_val_metrics.png)
25+
26+
* **Mean Absolute Error (MAE)**: Decreased smoothly from ~0.80 down strictly towards 0.60, indicating improved pixel-wise prediction accuracy.
27+
* **Prediction Variance**: Evaluated as a collapse detector. The variance ramped up from near-zero to ~0.03, confirming the model successfully escaped "mean-prediction" collapse and learned to output highly differentiated spatial patterns. *(Note: A variance of 0.0 indicates a collapsed model predicting the same average value everywhere).*
28+
29+
---
30+
31+
## 2. Pathway Spatial Coherence (Bowel Cancer)
32+
33+
To validate the biological plausibility of the predictions, we visualised clinically relevant Hallmarks pathways for colorectal/bowel cancer using the best epoch checkpoint (epoch 367) on sample `TENX29`.
34+
35+
Pathways selected for their relevance to colorectal cancer progression, invasion, and tumor microenvironment:
36+
37+
* **Wnt/β-Catenin Signaling**: Mutated in >80% of sporadic colorectal cancers (APC mutation).
38+
* **Epithelial-Mesenchymal Transition (EMT)**: Key marker of invasion and metastasis.
39+
* **TNF-α Signaling via NF-κB** & **Inflammatory Response**: Chronic inflammation is a hallmark driver of CRC.
40+
* **KRAS Signaling (Up)**: KRAS mutations occurring in ~40% of CRCs.
41+
* **Angiogenesis**: Critical for tumor vasculature, the target of anti-VEGF therapies.
42+
43+
### Pathway Predictions vs Ground Truth (Epoch 524)
44+
45+
![Bowel Cancer Pathway Predictions](./assets/TENX29_epoch_524.png)
46+
47+
**Observations:**
48+
49+
1. **Spatial Pattern Matching**: The model successfully reconstructs complex, heterogeneous spatial patterns across the tissue architecture. High-expression regions (yellow/green) in the ground truth strongly correlate with high-expression regions in the predictions.
50+
2. **Biological Gradients**: Crucially, the model does not just predict average expression but captures the relative spatial *gradients* of these pathways, confirming the health of the non-collapsed prediction variance metric.
51+
3. **Tumor Microenvironment (TME)**: Inflammatory pathways (TNF-α, Inflammatory Response) show distinct spatial localization differing from tumor-intrinsic pathways (Wnt, KRAS), accurately reflecting TME heterogeneity.
52+
53+
---
54+
55+
## Conclusion
56+
57+
The integration of Moran's I for both **SVG-aware gene selection** and **Spatial Coherence Validation** provides a significantly more robust, biologically grounded training pipeline. The model demonstrates the ability to learn and reconstruct clinically relevant spatial pathway patterns from H&E imaging alone, laying a strong foundation for scaling the model to the full HEST-1k dataset and evaluating patient-level stratification.

docs/TRAINING_GUIDE.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,17 @@ python scripts/run_preset.py --preset stf_medium
145145
python scripts/run_preset.py --preset stf_large
146146
```
147147

148+
#### Disease-Specific Priors (Colorectal Cancer)
149+
150+
To learn representations specifically constrained to phenotypes of a target disease, you can explicitly filter the initialization pathway bottleneck using the `--pathways` argument. The `crc` presets demonstrate this by shrinking the dimensionality down from 50 generic hallmarks to 14 CRC-specific pathways (e.g. Wnt/Beta-catenin, EMT, Angiogenesis).
151+
152+
```bash
153+
# Small CRC Variant (14 explicit pathways)
154+
python scripts/run_preset.py --preset stf_crc_small
155+
```
156+
157+
These presets are defined directly in `scripts/run_preset.py`, serving as a template for how you can introduce your own biological priors for other diseases.
158+
148159
### Choosing Interaction Modes
149160

150161
By default, the model runs in **Full Interaction** mode (`p2p p2h h2p h2h`) where all token types attend to each other. You can selectively disable interactions using the `--interactions` flag for ablation or to enforce specific architectural constraints.

scripts/batch_qc_stats.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import os
2+
import h5py
3+
import numpy as np
4+
import matplotlib.pyplot as plt
5+
from scipy.sparse import csr_matrix
6+
7+
def calculate_qc_stats(h5ad_path, min_umis=500, min_genes=200, max_mt=0.15):
8+
with h5py.File(h5ad_path, "r") as f:
9+
barcodes = [b.decode("utf-8") if isinstance(b, bytes) else b for b in f["obs"]["_index"][:]]
10+
gene_names = [g.decode("utf-8") if isinstance(g, bytes) else g for g in f["var"]["_index"][:]]
11+
X_group = f["X"]
12+
if isinstance(X_group, h5py.Group):
13+
X = csr_matrix((X_group["data"][:], X_group["indices"][:], X_group["indptr"][:]), shape=(len(barcodes), len(gene_names)))
14+
else:
15+
X = f["X"][:]
16+
17+
n_counts = np.array(X.sum(axis=1)).flatten()
18+
n_genes = np.array((X > 0).sum(axis=1)).flatten()
19+
mt_genes = [i for i, name in enumerate(gene_names) if "mt-" in name.lower() or "mt:" in name.lower()]
20+
if mt_genes:
21+
mt_counts = np.array(X[:, mt_genes].sum(axis=1)).flatten()
22+
pct_counts_mt = mt_counts / (n_counts + 1e-9)
23+
else:
24+
pct_counts_mt = np.zeros_like(n_counts)
25+
26+
keep_mask = (n_counts >= min_umis) & (n_genes >= min_genes) & (pct_counts_mt <= max_mt)
27+
return len(barcodes), np.sum(keep_mask), np.sum(n_counts < min_umis), np.sum(n_genes < min_genes), np.sum(pct_counts_mt > max_mt)
28+
29+
if __name__ == "__main__":
30+
st_dir = r"A:\hest_data\st"
31+
32+
# Get all h5ad files
33+
samples = [f.replace(".h5ad", "") for f in os.listdir(st_dir) if f.endswith(".h5ad")]
34+
samples.sort()
35+
36+
print(f"Analyzing {len(samples)} samples...")
37+
print(f"{'Sample':<15} | {'Total':<6} | {'Kept':<6} | {'% Kept':<8} | {'Low UMI':<8} | {'Low Gene':<8} | {'High MT':<8}")
38+
print("-" * 80)
39+
40+
all_results = []
41+
for s in samples:
42+
path = os.path.join(st_dir, f"{s}.h5ad")
43+
try:
44+
total, kept, l_umi, l_gene, h_mt = calculate_qc_stats(path)
45+
pct = kept / total if total > 0 else 0
46+
print(f"{s:<15} | {total:<6} | {kept:<6} | {pct:7.1%} | {l_umi:<8} | {l_gene:<8} | {h_mt:<8}")
47+
all_results.append((total, kept, pct))
48+
except Exception as e:
49+
print(f"Error processing {s}: {e}")
50+
51+
if all_results:
52+
avg_kept = np.mean([r[2] for r in all_results])
53+
min_kept = np.min([r[2] for r in all_results])
54+
max_kept = np.max([r[2] for r in all_results])
55+
print("-" * 80)
56+
print(f"GLOBAL SUMMARY ({len(all_results)} samples):")
57+
print(f"Average Kept: {avg_kept:.1%}")
58+
print(f"Minimum Kept: {min_kept:.1%}")
59+
print(f"Maximum Kept: {max_kept:.1%}")

scripts/diagnose_qc.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import os
2+
import h5py
3+
import numpy as np
4+
import matplotlib.pyplot as plt
5+
from scipy.sparse import csr_matrix
6+
7+
def diagnose_qc(h5ad_path, output_path, min_umis=500, min_genes=200, max_mt=0.15):
8+
print(f"Loading {h5ad_path}...")
9+
10+
with h5py.File(h5ad_path, "r") as f:
11+
# Load barcodes and gene names
12+
barcodes = [b.decode("utf-8") if isinstance(b, bytes) else b for b in f["obs"]["_index"][:]]
13+
gene_names = [g.decode("utf-8") if isinstance(g, bytes) else g for g in f["var"]["_index"][:]]
14+
15+
# Load expression matrix X
16+
X_group = f["X"]
17+
if isinstance(X_group, h5py.Group):
18+
data = X_group["data"][:]
19+
indices = X_group["indices"][:]
20+
indptr = X_group["indptr"][:]
21+
X = csr_matrix((data, indices, indptr), shape=(len(barcodes), len(gene_names)))
22+
else:
23+
X = f["X"][:]
24+
25+
# Spatial coordinates (usually in obsm/spatial)
26+
if "obsm" in f and "spatial" in f["obsm"]:
27+
coords = f["obsm"]["spatial"][:]
28+
else:
29+
# Fallback if spatial not found
30+
coords = np.zeros((len(barcodes), 2))
31+
print("Warning: Spatial coordinates not found in obsm/spatial")
32+
33+
# Calculate QC metrics
34+
n_counts = np.array(X.sum(axis=1)).flatten()
35+
n_genes = np.array((X > 0).sum(axis=1)).flatten()
36+
37+
# MT fraction
38+
# Robust MT detection: check for mt- or MT- anywhere, but prioritize common patterns
39+
mt_genes = [i for i, name in enumerate(gene_names) if "mt-" in name.lower() or "mt:" in name.lower()]
40+
if mt_genes:
41+
print(f"Found {len(mt_genes)} mitochondrial genes.")
42+
mt_counts = np.array(X[:, mt_genes].sum(axis=1)).flatten()
43+
pct_counts_mt = mt_counts / (n_counts + 1e-9)
44+
else:
45+
pct_counts_mt = np.zeros_like(n_counts)
46+
print("No mitochondrial genes found.")
47+
48+
# Filter mask
49+
pass_umi = n_counts >= min_umis
50+
pass_gene = n_genes >= min_genes
51+
pass_mt = pct_counts_mt <= max_mt
52+
53+
keep_mask = pass_umi & pass_gene & pass_mt
54+
55+
total_spots = len(barcodes)
56+
kept_spots = np.sum(keep_mask)
57+
58+
print(f"Total spots: {total_spots}")
59+
print(f"Kept spots: {kept_spots} ({kept_spots/total_spots:.1%})")
60+
print(f"Filtered: {total_spots - kept_spots}")
61+
print(f" - Low UMI (<{min_umis}): {np.sum(~pass_umi)}")
62+
print(f" - Low Genes (<{min_genes}): {np.sum(~pass_gene)}")
63+
print(f" - High MT (>{max_mt:.1%}): {np.sum(~pass_mt)}")
64+
65+
# Plotting
66+
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
67+
68+
# 1. UMI vs Genes
69+
axes[0, 0].scatter(n_counts, n_genes, c=keep_mask, cmap="RdYlGn", alpha=0.5, s=10)
70+
axes[0, 0].axvline(min_umis, color="red", linestyle="--", label=f"Min UMI={min_umis}")
71+
axes[0, 0].axhline(min_genes, color="blue", linestyle="--", label=f"Min Genes={min_genes}")
72+
axes[0, 0].set_xlabel("Total UMI counts")
73+
axes[0, 0].set_ylabel("Number of detected genes")
74+
axes[0, 0].set_title("QC: UMI vs Genes")
75+
axes[0, 0].legend()
76+
77+
# 2. MT Fraction distribution
78+
axes[0, 1].hist(pct_counts_mt, bins=50, color="gray", alpha=0.7)
79+
axes[0, 1].axvline(max_mt, color="red", linestyle="--", label=f"Max MT={max_mt:.0%}")
80+
axes[0, 1].set_xlabel("Mitochondrial Fraction")
81+
axes[0, 1].set_ylabel("Count")
82+
axes[0, 1].set_title("QC: MT Fraction Distribution")
83+
axes[0, 1].legend()
84+
85+
# 3. Spatial: Before (Total)
86+
axes[1, 0].scatter(coords[:, 0], coords[:, 1], c="lightgray", s=15, label="All Spots")
87+
axes[1, 0].scatter(coords[keep_mask, 0], coords[keep_mask, 1], c="green", s=15, label="Pass QC")
88+
axes[1, 0].set_title("Spatial Distribution: Kept vs Filtered")
89+
axes[1, 0].set_aspect("equal")
90+
axes[1, 0].legend()
91+
92+
# 4. Summary Table (as text)
93+
stats_text = (
94+
f"Sample: {os.path.basename(h5ad_path)}\n\n"
95+
f"Total Spots: {total_spots}\n"
96+
f"Kept Spots: {kept_spots} ({kept_spots/total_spots:.1%})\n"
97+
f"Filtered: {total_spots - kept_spots}\n\n"
98+
f"Thresholds:\n"
99+
f"Min UMI: {min_umis}\n"
100+
f"Min Genes: {min_genes}\n"
101+
f"Max MT: {max_mt:.0%}"
102+
)
103+
axes[1, 1].text(0.1, 0.5, stats_text, fontsize=14, family="monospace")
104+
axes[1, 1].axis("off")
105+
106+
plt.tight_layout()
107+
plt.savefig(output_path)
108+
print(f"Plot saved to {output_path}")
109+
110+
if __name__ == "__main__":
111+
import argparse
112+
parser = argparse.ArgumentParser()
113+
parser.add_argument("--sample", type=str, default="MEND29")
114+
args = parser.parse_args()
115+
116+
sample_id = args.sample
117+
h5ad_file = f"A:\\hest_data\\st\\{sample_id}.h5ad"
118+
output_file = f"qc_diagnosis_{sample_id}.png"
119+
120+
if not os.path.exists(h5ad_file):
121+
print(f"Error: {h5ad_file} not found.")
122+
else:
123+
diagnose_qc(h5ad_file, output_file)

scripts/find_outliers.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import os
2+
import sys
3+
sys.path.append(r"z:\Projects\SpatialTranscriptFormer\scripts")
4+
from batch_qc_stats import calculate_qc_stats
5+
import numpy as np
6+
7+
st_dir = r"A:\hest_data\st"
8+
samples = [f.replace(".h5ad", "") for f in os.listdir(st_dir) if f.endswith(".h5ad")]
9+
samples.sort()
10+
11+
results = []
12+
for s in samples:
13+
try:
14+
total, kept, l_umi, l_gene, h_mt = calculate_qc_stats(os.path.join(st_dir, f"{s}.h5ad"))
15+
results.append({
16+
"sample": s,
17+
"total": total,
18+
"kept": kept,
19+
"pct": kept / total if total > 0 else 0,
20+
"low_umi": l_umi,
21+
"low_gene": l_gene,
22+
"high_mt": h_mt
23+
})
24+
except Exception as e:
25+
print(f"Error {s}: {e}")
26+
27+
results.sort(key=lambda x: x["pct"])
28+
29+
print(f"{'Sample':<15} | {'Kept %':<8} | {'Filtered':<10} | {'Low UMI':<8} | {'Low Gene':<8} | {'High MT':<8}")
30+
print("-" * 75)
31+
for r in results[:15]:
32+
print(f"{r['sample']:<15} | {r['pct']:7.1%} | {r['total']-r['kept']:<10} | {r['low_umi']:<8} | {r['low_gene']:<8} | {r['high_mt']:<8}")

scripts/run_preset.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,24 @@
55

66
from spatial_transcript_former.config import get_config
77

8+
# Curated list of MSigDB Hallmarks with strong evidence of involvement in Colorectal/Bowel Cancer
9+
CRC_PATHWAYS = [
10+
"HALLMARK_WNT_BETA_CATENIN_SIGNALING",
11+
"HALLMARK_TGF_BETA_SIGNALING",
12+
"HALLMARK_KRAS_SIGNALING_UP",
13+
"HALLMARK_KRAS_SIGNALING_DN",
14+
"HALLMARK_PI3K_AKT_MTOR_SIGNALING",
15+
"HALLMARK_EPITHELIAL_MESENCHYMAL_TRANSITION",
16+
"HALLMARK_ANGIOGENESIS",
17+
"HALLMARK_APICAL_JUNCTION",
18+
"HALLMARK_INFLAMMATORY_RESPONSE",
19+
"HALLMARK_IL6_JAK_STAT3_SIGNALING",
20+
"HALLMARK_APOPTOSIS",
21+
"HALLMARK_P53_PATHWAY",
22+
"HALLMARK_DNA_REPAIR",
23+
"HALLMARK_HYPOXIA",
24+
]
25+
826

927
def make_stf_params(n_layers: int, token_dim: int, n_heads: int, batch_size: int):
1028
"""Helper to create standard SpatialTranscriptFormer parameters."""
@@ -14,6 +32,7 @@ def make_stf_params(n_layers: int, token_dim: int, n_heads: int, batch_size: int
1432
"precomputed": True,
1533
"whole-slide": True,
1634
"pathway-init": True,
35+
"pathway-loss-weight": 0.5,
1736
"use-amp": True,
1837
"log-transform": True,
1938
"loss": "mse_pcc",
@@ -55,6 +74,11 @@ def make_stf_params(n_layers: int, token_dim: int, n_heads: int, batch_size: int
5574
"stf_small": make_stf_params(n_layers=4, token_dim=384, n_heads=8, batch_size=8),
5675
"stf_medium": make_stf_params(n_layers=6, token_dim=512, n_heads=8, batch_size=8),
5776
"stf_large": make_stf_params(n_layers=12, token_dim=768, n_heads=12, batch_size=8),
77+
# --- Biologically-Prioritized Variants (e.g. Colorectal Cancer) ---
78+
"stf_crc_tiny": {**make_stf_params(2, 256, 4, 8), "pathways": CRC_PATHWAYS},
79+
"stf_crc_small": {**make_stf_params(4, 384, 8, 8), "pathways": CRC_PATHWAYS},
80+
"stf_crc_medium": {**make_stf_params(6, 512, 8, 8), "pathways": CRC_PATHWAYS},
81+
"stf_crc_large": {**make_stf_params(12, 768, 12, 8), "pathways": CRC_PATHWAYS},
5882
}
5983

6084

@@ -67,6 +91,9 @@ def params_to_args(params_dict):
6791
args.append(arg_name)
6892
elif value is False or value is None:
6993
continue
94+
elif isinstance(value, list) or isinstance(value, tuple):
95+
args.append(arg_name)
96+
args.extend([str(v) for v in value])
7097
else:
7198
args.extend([arg_name, str(value)])
7299
return args

src/spatial_transcript_former/data/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22
Data abstractions for SpatialTranscriptFormer.
33
44
Core exports:
5+
- :class:`GeneVocab` — single source of truth for gene vocabulary
56
- :class:`SpatialDataset` — abstract base class implementing the data contract
67
- :func:`apply_dihedral_augmentation` — D4 coordinate augmentation
78
- :func:`apply_dihedral_to_tensor` — D4 image augmentation
89
- :func:`normalize_coordinates` — auto-normalise spatial coordinates
10+
- :func:`resolve_feature_dir` — centralised feature directory discovery
11+
- :func:`resolve_gene_vocab_path` — find ``global_genes.json``
912
1013
HEST-specific exports (backward compatibility — prefer ``recipes.hest``):
1114
- :class:`HEST_Dataset`, :func:`get_hest_dataloader`
@@ -20,3 +23,5 @@
2023
apply_dihedral_to_tensor,
2124
normalize_coordinates,
2225
)
26+
from .gene_vocab import GeneVocab
27+
from .paths import resolve_feature_dir, resolve_gene_vocab_path

0 commit comments

Comments
 (0)