Skip to content

Commit 92e9076

Browse files
- Cleaned up and adapted tests to new model pipeline. God this took a long time...
1 parent c402c52 commit 92e9076

18 files changed

Lines changed: 342 additions & 1189 deletions

config.yaml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,16 @@ training:
1111
num_genes: 1000
1212
batch_size: 8
1313
learning_rate: 0.0001
14-
output_dir: "./checkpoints"
14+
output_dir: "./runs"
1515

1616
# MSigDB Pathway Settings
1717
pathways:
1818
default_collection: "hallmarks"
1919
cache_dir: ".cache"
20+
21+
# Quality Control Defaults
22+
qc:
23+
min_umis: 500
24+
min_genes: 200
25+
max_mt: 0.15
26+
min_pathways: 25

tests/data/test_augmentation.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Merged tests: test_spatial_augment.py, test_augmentation_sync.py, test_spatial_alignment.py
2+
Tests for spatial augmentation logic (dihedral groups) and coordinate alignment.
33
"""
44

55
import torch
@@ -13,7 +13,9 @@
1313
)
1414
from spatial_transcript_former.models import SpatialTranscriptFormer
1515

16-
# --- From test_spatial_augment.py ---
16+
# ---------------------------------------------------------------------------
17+
# Dihedral Logic
18+
# ---------------------------------------------------------------------------
1719

1820

1921
def test_apply_dihedral_augmentation_torch():
@@ -46,7 +48,9 @@ def test_apply_dihedral_augmentation_numpy():
4648
assert out.shape == coords.shape
4749

4850

49-
# --- From test_augmentation_sync.py ---
51+
# ---------------------------------------------------------------------------
52+
# Synchronization
53+
# ---------------------------------------------------------------------------
5054

5155

5256
def test_sync_logic():
@@ -92,7 +96,9 @@ def test_sync_logic():
9296
assert False, f"Pixel lost in op {op}"
9397

9498

95-
# --- From test_spatial_alignment.py ---
99+
# ---------------------------------------------------------------------------
100+
# Spatial Alignment
101+
# ---------------------------------------------------------------------------
96102

97103

98104
def test_spatial_mixing_with_large_coordinates():
@@ -104,7 +110,7 @@ def test_spatial_mixing_with_large_coordinates():
104110
token_dim = 64
105111

106112
model = SpatialTranscriptFormer(
107-
num_genes=10, token_dim=token_dim, n_layers=2, use_spatial_pe=True
113+
token_dim=token_dim, n_layers=2, use_spatial_pe=True
108114
)
109115

110116
# Create two patches that are physically adjacent (256px apart) but logically neighbors
Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
"""
2-
Merged tests: test_data_management.py, test_data_integrity.py
2+
Tests for biological coverage and spatial data integrity.
33
"""
44

55
import os
6-
76
import pytest
87
import torch
98
import numpy as np
10-
import h5py
119

1210
from spatial_transcript_former.recipes.hest.io import (
1311
get_hest_data_dir,
@@ -20,8 +18,6 @@
2018
MSIGDB_URLS,
2119
)
2220

23-
# --- From test_data_integrity.py ---
24-
2521

2622
@pytest.fixture
2723
def data_dir():

tests/data/test_pathways.py

Lines changed: 2 additions & 187 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Merged tests: test_pathways.py, test_pathways_robust.py, test_pathway_stability.py
2+
Tests for MSigDB pathway parsing and membership matrix construction.
33
"""
44

55
import pytest
@@ -13,13 +13,6 @@
1313
MSIGDB_URLS,
1414
)
1515
from spatial_transcript_former.data.pathways import build_membership_matrix
16-
from spatial_transcript_former.models.interaction import SpatialTranscriptFormer
17-
from spatial_transcript_former.training.losses import (
18-
AuxiliaryPathwayLoss,
19-
MaskedMSELoss,
20-
)
21-
22-
# --- From test_pathways.py ---
2316

2417

2518
@pytest.fixture(scope="module")
@@ -133,182 +126,4 @@ def test_core_pathways_exist(self, pathway_result):
133126

134127
# ---------------------------------------------------------------------------
135128
# Pathway ground truth
136-
# ---------------------------------------------------------------------------
137-
138-
139-
class TestPathwayTruth:
140-
def test_consistent_across_calls(self, gene_list):
141-
"""Ground truth from MSigDB membership should be identical across calls."""
142-
from spatial_transcript_former.visualization import _compute_pathway_truth
143-
from unittest.mock import MagicMock
144-
145-
args = MagicMock()
146-
args.sparsity_lambda = 0.0
147-
args.pathways = None
148-
149-
np.random.seed(42)
150-
gene_truth = np.random.rand(200, len(gene_list)).astype(np.float32)
151-
152-
result1, names1 = _compute_pathway_truth(gene_truth, gene_list, args)
153-
result2, names2 = _compute_pathway_truth(gene_truth, gene_list, args)
154-
155-
np.testing.assert_array_equal(result1, result2)
156-
assert names1 == names2
157-
158-
def test_output_shape(self, gene_list):
159-
"""Pathway truth should be (N, P) where P=50 (Hallmarks default)."""
160-
from spatial_transcript_former.visualization import _compute_pathway_truth
161-
from unittest.mock import MagicMock
162-
163-
args = MagicMock()
164-
args.sparsity_lambda = 0.0
165-
args.pathways = None
166-
167-
N = 150
168-
gene_truth = np.random.rand(N, len(gene_list)).astype(np.float32)
169-
result, names = _compute_pathway_truth(gene_truth, gene_list, args)
170-
171-
assert result.shape == (N, 50)
172-
assert len(names) == 50
173-
174-
def test_spatial_variation(self, gene_list):
175-
"""Pathway truth should have spatial variation (non-zero std)."""
176-
from spatial_transcript_former.visualization import _compute_pathway_truth
177-
from unittest.mock import MagicMock
178-
179-
args = MagicMock()
180-
args.sparsity_lambda = 0.0
181-
args.pathways = None
182-
183-
# Create gene expression with spatial patterns
184-
N = 200
185-
gene_truth = np.random.rand(N, len(gene_list)).astype(np.float32)
186-
# Add spatial structure to first few genes
187-
gene_truth[:100, 0] += 5.0
188-
gene_truth[100:, 1] += 5.0
189-
190-
result, _ = _compute_pathway_truth(gene_truth, gene_list, args)
191-
192-
# At least some pathways should have non-trivial spatial variation
193-
stds = np.std(result, axis=0)
194-
assert np.any(stds > 0.01), "Pathway truth has no spatial variation"
195-
196-
197-
# --- From test_pathways_robust.py ---
198-
199-
200-
def test_build_membership_matrix_integrity():
201-
"""Verify that the membership matrix correctly maps genes to pathways."""
202-
pathway_dict = {
203-
"PATHWAY_A": ["GENE_1", "GENE_2"],
204-
"PATHWAY_B": ["GENE_2", "GENE_3"],
205-
}
206-
gene_list = ["GENE_1", "GENE_2", "GENE_3", "GENE_4"]
207-
208-
matrix, names = build_membership_matrix(pathway_dict, gene_list)
209-
210-
assert names == ["PATHWAY_A", "PATHWAY_B"]
211-
assert matrix.shape == (2, 4)
212-
213-
# Pathway A: GENE_1, GENE_2
214-
assert matrix[0, 0] == 1.0
215-
assert matrix[0, 1] == 1.0
216-
assert matrix[0, 2] == 0.0
217-
assert matrix[0, 3] == 0.0
218-
219-
# Pathway B: GENE_2, GENE_3
220-
assert matrix[1, 0] == 0.0
221-
assert matrix[1, 1] == 1.0
222-
assert matrix[1, 2] == 1.0
223-
assert matrix[1, 3] == 0.0
224-
225-
226-
def test_build_membership_matrix_empty():
227-
"""Check behavior with no matches."""
228-
pathway_dict = {"EMPTY": ["XYZ"]}
229-
gene_list = ["ABC", "DEF"]
230-
matrix, names = build_membership_matrix(pathway_dict, gene_list)
231-
assert matrix.sum() == 0
232-
assert names == ["EMPTY"]
233-
234-
235-
# --- From test_pathway_stability.py ---
236-
237-
238-
def test_pathway_initialization_stability_and_gradients():
239-
"""
240-
Verifies that initializing the model with a binary pathway matrix:
241-
1. Does not cause predictions to exponentially explode (numerical stability).
242-
2. Allows gradients to flow properly when using AuxiliaryPathwayLoss.
243-
"""
244-
torch.manual_seed(42)
245-
num_pathways = 50
246-
num_genes = 100
247-
248-
# Create a synthetic MSigDB-style binary matrix
249-
pathway_matrix = (torch.rand(num_pathways, num_genes) > 0.8).float()
250-
# Ensure no empty pathways to avoid division by zero
251-
pathway_matrix[:, 0] = 1.0
252-
253-
# Initialize model with pathway_init
254-
model = SpatialTranscriptFormer(
255-
num_genes=num_genes,
256-
num_pathways=num_pathways,
257-
pathway_init=pathway_matrix,
258-
use_spatial_pe=False,
259-
output_mode="counts",
260-
pretrained=False,
261-
)
262-
263-
# Dummy inputs
264-
B, S, D = (
265-
2,
266-
10,
267-
2048,
268-
) # Using D=2048 since backbone='resnet50' requires it natively, or provided features
269-
feats = torch.randn(B, S, D, requires_grad=True)
270-
coords = torch.randn(B, S, 2)
271-
target_genes = torch.randn(B, S, num_genes).abs()
272-
mask = torch.zeros(B, S, dtype=torch.bool)
273-
274-
# Forward pass
275-
# return_pathways=True is needed to get the intermediate pathway preds for Auxiliary loss
276-
gene_preds, pathway_preds = model(
277-
feats, rel_coords=coords, return_dense=True, return_pathways=True
278-
)
279-
280-
# 1. Numerical Stability Check
281-
# Without L1 normalization and removing temperature, predictions would explode.
282-
# With the fix, Softplus should keep outputs reasonably small.
283-
max_pred = gene_preds.max().item()
284-
print(f"Max prediction value at initialization: {max_pred:.2f}")
285-
assert (
286-
max_pred < 100.0
287-
), f"Predictions exploded! Max value: {max_pred}. Check L1 normalization."
288-
assert not torch.isnan(gene_preds).any(), "Found NaNs in initial predictions."
289-
290-
# 2. Gradient Flow Check (Compatibility with Training)
291-
loss_fn = AuxiliaryPathwayLoss(pathway_matrix, MaskedMSELoss(), lambda_pathway=1.0)
292-
loss = loss_fn(gene_preds, target_genes, mask=mask, pathway_preds=pathway_preds)
293-
294-
assert loss.isfinite(), "Loss is not finite."
295-
296-
loss.backward()
297-
298-
# Verify gradients reached the core transformer layers
299-
target_layer_grad = model.fusion_engine.layers[0].linear1.weight.grad
300-
assert target_layer_grad is not None, "Gradients did not reach the fusion engine."
301-
assert target_layer_grad.norm() > 0, "Vanishing gradients in the fusion engine."
302-
assert torch.isfinite(
303-
target_layer_grad
304-
).all(), "Exploding/NaN gradients in fusion engine."
305-
306-
# Verify gradients reached the final reconstructor layer
307-
recon_grad = model.gene_reconstructor.weight.grad
308-
assert recon_grad is not None, "Gradients did not reach the gene reconstructor."
309-
assert recon_grad.norm() > 0, "Vanishing gradients in the gene reconstructor."
310-
assert torch.isfinite(
311-
recon_grad
312-
).all(), "Exploding/NaN gradients in gene reconstructor."
313-
314-
print("Pathway initialization is fully stable and compatible with NN training.")
129+
# ---------------------------------------------------------------------------

tests/data/test_visualization.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Merged tests: test_visualization.py, test_spatial_stats.py
2+
Tests for training summary visualization and spatial statistics (Moran's I, Coherence).
33
"""
44

55
import os
@@ -16,7 +16,9 @@
1616
_build_knn_weights,
1717
)
1818

19-
# --- From test_visualization.py ---
19+
# ---------------------------------------------------------------------------
20+
# Training Summary
21+
# ---------------------------------------------------------------------------
2022

2123

2224
matplotlib.use("Agg")
@@ -172,7 +174,9 @@ def test_constant_input_handled(self):
172174
assert np.allclose(z, 0.0, atol=1e-4)
173175

174176

175-
# --- From test_spatial_stats.py ---
177+
# ---------------------------------------------------------------------------
178+
# Spatial Statistics
179+
# ---------------------------------------------------------------------------
176180

177181

178182
def _make_grid(rows=10, cols=10):

0 commit comments

Comments
 (0)