11"""
2- Merged tests: test_pathways.py, test_pathways_robust.py, test_pathway_stability.py
2+ Tests for MSigDB pathway parsing and membership matrix construction.
33"""
44
55import pytest
1313 MSIGDB_URLS ,
1414)
1515from spatial_transcript_former .data .pathways import build_membership_matrix
16- from spatial_transcript_former .models .interaction import SpatialTranscriptFormer
17- from spatial_transcript_former .training .losses import (
18- AuxiliaryPathwayLoss ,
19- MaskedMSELoss ,
20- )
21-
22- # --- From test_pathways.py ---
2316
2417
2518@pytest .fixture (scope = "module" )
@@ -133,182 +126,4 @@ def test_core_pathways_exist(self, pathway_result):
133126
134127# ---------------------------------------------------------------------------
135128# Pathway ground truth
136- # ---------------------------------------------------------------------------
137-
138-
139- class TestPathwayTruth :
140- def test_consistent_across_calls (self , gene_list ):
141- """Ground truth from MSigDB membership should be identical across calls."""
142- from spatial_transcript_former .visualization import _compute_pathway_truth
143- from unittest .mock import MagicMock
144-
145- args = MagicMock ()
146- args .sparsity_lambda = 0.0
147- args .pathways = None
148-
149- np .random .seed (42 )
150- gene_truth = np .random .rand (200 , len (gene_list )).astype (np .float32 )
151-
152- result1 , names1 = _compute_pathway_truth (gene_truth , gene_list , args )
153- result2 , names2 = _compute_pathway_truth (gene_truth , gene_list , args )
154-
155- np .testing .assert_array_equal (result1 , result2 )
156- assert names1 == names2
157-
158- def test_output_shape (self , gene_list ):
159- """Pathway truth should be (N, P) where P=50 (Hallmarks default)."""
160- from spatial_transcript_former .visualization import _compute_pathway_truth
161- from unittest .mock import MagicMock
162-
163- args = MagicMock ()
164- args .sparsity_lambda = 0.0
165- args .pathways = None
166-
167- N = 150
168- gene_truth = np .random .rand (N , len (gene_list )).astype (np .float32 )
169- result , names = _compute_pathway_truth (gene_truth , gene_list , args )
170-
171- assert result .shape == (N , 50 )
172- assert len (names ) == 50
173-
174- def test_spatial_variation (self , gene_list ):
175- """Pathway truth should have spatial variation (non-zero std)."""
176- from spatial_transcript_former .visualization import _compute_pathway_truth
177- from unittest .mock import MagicMock
178-
179- args = MagicMock ()
180- args .sparsity_lambda = 0.0
181- args .pathways = None
182-
183- # Create gene expression with spatial patterns
184- N = 200
185- gene_truth = np .random .rand (N , len (gene_list )).astype (np .float32 )
186- # Add spatial structure to first few genes
187- gene_truth [:100 , 0 ] += 5.0
188- gene_truth [100 :, 1 ] += 5.0
189-
190- result , _ = _compute_pathway_truth (gene_truth , gene_list , args )
191-
192- # At least some pathways should have non-trivial spatial variation
193- stds = np .std (result , axis = 0 )
194- assert np .any (stds > 0.01 ), "Pathway truth has no spatial variation"
195-
196-
197- # --- From test_pathways_robust.py ---
198-
199-
200- def test_build_membership_matrix_integrity ():
201- """Verify that the membership matrix correctly maps genes to pathways."""
202- pathway_dict = {
203- "PATHWAY_A" : ["GENE_1" , "GENE_2" ],
204- "PATHWAY_B" : ["GENE_2" , "GENE_3" ],
205- }
206- gene_list = ["GENE_1" , "GENE_2" , "GENE_3" , "GENE_4" ]
207-
208- matrix , names = build_membership_matrix (pathway_dict , gene_list )
209-
210- assert names == ["PATHWAY_A" , "PATHWAY_B" ]
211- assert matrix .shape == (2 , 4 )
212-
213- # Pathway A: GENE_1, GENE_2
214- assert matrix [0 , 0 ] == 1.0
215- assert matrix [0 , 1 ] == 1.0
216- assert matrix [0 , 2 ] == 0.0
217- assert matrix [0 , 3 ] == 0.0
218-
219- # Pathway B: GENE_2, GENE_3
220- assert matrix [1 , 0 ] == 0.0
221- assert matrix [1 , 1 ] == 1.0
222- assert matrix [1 , 2 ] == 1.0
223- assert matrix [1 , 3 ] == 0.0
224-
225-
226- def test_build_membership_matrix_empty ():
227- """Check behavior with no matches."""
228- pathway_dict = {"EMPTY" : ["XYZ" ]}
229- gene_list = ["ABC" , "DEF" ]
230- matrix , names = build_membership_matrix (pathway_dict , gene_list )
231- assert matrix .sum () == 0
232- assert names == ["EMPTY" ]
233-
234-
235- # --- From test_pathway_stability.py ---
236-
237-
238- def test_pathway_initialization_stability_and_gradients ():
239- """
240- Verifies that initializing the model with a binary pathway matrix:
241- 1. Does not cause predictions to exponentially explode (numerical stability).
242- 2. Allows gradients to flow properly when using AuxiliaryPathwayLoss.
243- """
244- torch .manual_seed (42 )
245- num_pathways = 50
246- num_genes = 100
247-
248- # Create a synthetic MSigDB-style binary matrix
249- pathway_matrix = (torch .rand (num_pathways , num_genes ) > 0.8 ).float ()
250- # Ensure no empty pathways to avoid division by zero
251- pathway_matrix [:, 0 ] = 1.0
252-
253- # Initialize model with pathway_init
254- model = SpatialTranscriptFormer (
255- num_genes = num_genes ,
256- num_pathways = num_pathways ,
257- pathway_init = pathway_matrix ,
258- use_spatial_pe = False ,
259- output_mode = "counts" ,
260- pretrained = False ,
261- )
262-
263- # Dummy inputs
264- B , S , D = (
265- 2 ,
266- 10 ,
267- 2048 ,
268- ) # Using D=2048 since backbone='resnet50' requires it natively, or provided features
269- feats = torch .randn (B , S , D , requires_grad = True )
270- coords = torch .randn (B , S , 2 )
271- target_genes = torch .randn (B , S , num_genes ).abs ()
272- mask = torch .zeros (B , S , dtype = torch .bool )
273-
274- # Forward pass
275- # return_pathways=True is needed to get the intermediate pathway preds for Auxiliary loss
276- gene_preds , pathway_preds = model (
277- feats , rel_coords = coords , return_dense = True , return_pathways = True
278- )
279-
280- # 1. Numerical Stability Check
281- # Without L1 normalization and removing temperature, predictions would explode.
282- # With the fix, Softplus should keep outputs reasonably small.
283- max_pred = gene_preds .max ().item ()
284- print (f"Max prediction value at initialization: { max_pred :.2f} " )
285- assert (
286- max_pred < 100.0
287- ), f"Predictions exploded! Max value: { max_pred } . Check L1 normalization."
288- assert not torch .isnan (gene_preds ).any (), "Found NaNs in initial predictions."
289-
290- # 2. Gradient Flow Check (Compatibility with Training)
291- loss_fn = AuxiliaryPathwayLoss (pathway_matrix , MaskedMSELoss (), lambda_pathway = 1.0 )
292- loss = loss_fn (gene_preds , target_genes , mask = mask , pathway_preds = pathway_preds )
293-
294- assert loss .isfinite (), "Loss is not finite."
295-
296- loss .backward ()
297-
298- # Verify gradients reached the core transformer layers
299- target_layer_grad = model .fusion_engine .layers [0 ].linear1 .weight .grad
300- assert target_layer_grad is not None , "Gradients did not reach the fusion engine."
301- assert target_layer_grad .norm () > 0 , "Vanishing gradients in the fusion engine."
302- assert torch .isfinite (
303- target_layer_grad
304- ).all (), "Exploding/NaN gradients in fusion engine."
305-
306- # Verify gradients reached the final reconstructor layer
307- recon_grad = model .gene_reconstructor .weight .grad
308- assert recon_grad is not None , "Gradients did not reach the gene reconstructor."
309- assert recon_grad .norm () > 0 , "Vanishing gradients in the gene reconstructor."
310- assert torch .isfinite (
311- recon_grad
312- ).all (), "Exploding/NaN gradients in gene reconstructor."
313-
314- print ("Pathway initialization is fully stable and compatible with NN training." )
129+ # ---------------------------------------------------------------------------
0 commit comments