1818logger = logging .getLogger (__name__ )
1919
2020
21- def test_imports ():
21+ def _check_imports ():
2222 """Test all imports work correctly."""
2323 logger .info ("Testing imports..." )
2424
2525 try :
2626 import nodelens
2727
2828 # Core / registry
29- from nodelens .core import ModelWrapper # noqa: F401
30- from nodelens .metrics import METRIC_REGISTRY # noqa: F401
31- from nodelens .metrics . base import MetricComputer # noqa: F401
29+ from nodelens .core import METRIC_REGISTRY # noqa: F401
30+ from nodelens .metrics import get_metric , list_metrics # noqa: F401
31+ from nodelens .models import ModelWrapper # noqa: F401
3232
3333 # Pruning + services
3434 from nodelens .pruning import get_pruning_strategy # noqa: F401
@@ -43,51 +43,43 @@ def test_imports():
4343 return False
4444
4545
46- def test_metric_computer ():
46+ def _check_metric_computer ():
4747 """Test MetricComputer is functional."""
4848 logger .info ("\n Testing MetricComputer..." )
4949
5050 try :
51- from nodelens .metrics import METRIC_REGISTRY
52- from nodelens .metrics .base import MetricComputer
53-
54- # Create metrics
55- metrics = {
56- "rayleigh_quotient" : METRIC_REGISTRY .get_metric ("rayleigh_quotient" ),
57- "mutual_information" : METRIC_REGISTRY .get_metric ("mutual_information" ),
58- }
59-
60- # Create computer
61- computer = MetricComputer (metrics )
51+ from nodelens .metrics import get_metric
6252
63- # Test computation
6453 weights = torch .randn (10 , 20 )
54+ inputs = torch .randn (32 , 20 )
6555 outputs = torch .randn (32 , 10 )
6656
67- results = computer .compute_all (weights = weights , outputs = outputs )
57+ rq = get_metric ("rayleigh_quotient" ).compute (inputs = inputs , weights = weights )
58+ act = get_metric ("activation_l2_norm" ).compute (outputs = outputs )
6859
69- assert len (results ) == 2
70- assert "rayleigh_quotient" in results
71- assert "mutual_information" in results
60+ assert rq .shape == (weights .shape [0 ],)
61+ assert act .shape == (outputs .shape [1 ],)
62+ assert torch .all (torch .isfinite (rq ))
63+ assert torch .all (torch .isfinite (act ))
7264
73- logger .info ("OK MetricComputer is functional" )
65+ logger .info ("OK metric registry and metric computation are functional" )
7466 return True
7567 except Exception as e :
7668 logger .error (f"FAIL MetricComputer test failed: { e } " )
7769 return False
7870
7971
80- def test_parallel_processing ():
72+ def _check_parallel_processing ():
8173 """Test parallel processing is implemented."""
8274 logger .info ("\n Testing parallel processing..." )
8375
8476 try :
8577 import torch .nn as nn
8678 from torch .utils .data import DataLoader , TensorDataset
8779
88- from nodelens .core import ModelWrapper
89- from nodelens .metrics import METRIC_REGISTRY
90- from nodelens .utils . batch_processing import compute_metrics_parallel
80+ from nodelens .dataops . processing . batch import compute_metrics_parallel
81+ from nodelens .metrics import get_metric
82+ from nodelens .models import ModelWrapper
9183
9284 # Create simple model and data
9385 model = nn .Sequential (nn .Linear (10 , 20 ), nn .ReLU (), nn .Linear (20 , 5 ))
@@ -96,63 +88,53 @@ def test_parallel_processing():
9688 dataloader = DataLoader (dataset , batch_size = 10 )
9789
9890 wrapper = ModelWrapper (model , tracked_layers = ["0" , "2" ])
99- metrics = {"rayleigh_quotient " : METRIC_REGISTRY [ "rayleigh_quotient" ]( )}
91+ metrics = {"activation_l2_norm " : get_metric ( "activation_l2_norm" )}
10092
101- # Test parallel computation (will use single worker if only 1 GPU)
102- results = compute_metrics_parallel (wrapper , dataloader , metrics , num_workers = 2 )
93+ # Force the single-device path so this remains a lightweight CI smoke test.
94+ results = compute_metrics_parallel (wrapper , dataloader , metrics , num_workers = 1 , devices = [ torch . device ( "cpu" )] )
10395
10496 assert isinstance (results , dict )
105- logger .info ("OK Parallel processing is implemented" )
97+ assert set (results ) == {"0" , "2" }
98+ logger .info ("OK batch metric processing is functional" )
10699 return True
107100 except Exception as e :
108101 logger .error (f"FAIL Parallel processing test failed: { e } " )
109102 return False
110103
111104
112- def test_pruning_utilities ():
105+ def _check_pruning_utilities ():
113106 """Test pruning utilities are complete."""
114107 logger .info ("\n Testing pruning utilities..." )
115108
116109 try :
117110 import torch .nn as nn
118111
119- from nodelens .utils . pruning import PruningUtilities , create_pruning_schedule
112+ from nodelens .pruning import get_pruning_strategy
120113
121114 # Create test layer
122115 layer = nn .Linear (10 , 20 )
123116
124- # Test different pruning methods
125- methods = [
126- ("magnitude" , PruningUtilities .get_pruning_mask_magnitude ),
127- ("random" , PruningUtilities .get_pruning_mask_random ),
128- ]
129-
130- for name , method in methods :
131- mask = method (layer .weight .data , amount = 0.5 )
117+ for name in ["magnitude" , "random" ]:
118+ strategy = get_pruning_strategy (name )
119+ scores = strategy .compute_importance_scores (layer )
120+ mask = strategy .create_pruning_mask (scores , amount = 0.5 )
132121 assert mask .shape == layer .weight .shape
133122 assert 0.4 < (mask == 0 ).float ().mean () < 0.6 # Roughly 50% pruned
134123 logger .info (f" OK { name } pruning works" )
135124
136- # Test pruning schedule
137- schedule = create_pruning_schedule (0.0 , 0.9 , 0 , 100 , 10 , "polynomial" )
138- assert schedule (0 ) == 0.0
139- assert schedule (100 ) == 0.9
140- assert 0.0 < schedule (50 ) < 0.9
141- logger .info (" OK Pruning schedules work" )
142-
143125 logger .info ("OK All pruning utilities functional" )
144126 return True
145127 except Exception as e :
146128 logger .error (f"FAIL Pruning utilities test failed: { e } " )
147129 return False
148130
149131
150- def test_experiment_tracking ():
132+ def _check_experiment_tracking ():
151133 """Test experiment tracking is functional."""
152134 logger .info ("\n Testing experiment tracking..." )
153135
154136 try :
155- from nodelens .utils . experiment_tracking import ExperimentTracker , create_tracker
137+ from nodelens .experiments . tracking import ExperimentTracker , create_tracker
156138
157139 # Test base tracker (doesn't raise NotImplementedError anymore)
158140 tracker = ExperimentTracker ("test" , {"key" : "value" })
@@ -175,11 +157,16 @@ def test_experiment_tracking():
175157 return False
176158
177159
178- def test_examples_exist ():
160+ def _check_examples_exist ():
179161 """Test that comprehensive examples exist."""
180162 logger .info ("\n Checking examples..." )
181163
182- example_files = ["examples/quick_demo.py" , "examples/advanced_analysis.py" , "examples/comprehensive_demo.py" , "examples/pruning_demo.py" ]
164+ example_files = [
165+ "configs/examples/alexnet_pruning.yaml" ,
166+ "configs/examples/resnet_pruning.yaml" ,
167+ "configs/examples/llama3_extended_analysis.yaml" ,
168+ "projects/supernodes_scar/README.md" ,
169+ ]
183170
184171 all_exist = True
185172 for file in example_files :
@@ -192,19 +179,43 @@ def test_examples_exist():
192179 return all_exist
193180
194181
182+ def test_imports ():
183+ assert _check_imports ()
184+
185+
186+ def test_metric_computer ():
187+ assert _check_metric_computer ()
188+
189+
190+ def test_parallel_processing ():
191+ assert _check_parallel_processing ()
192+
193+
194+ def test_pruning_utilities ():
195+ assert _check_pruning_utilities ()
196+
197+
198+ def test_experiment_tracking ():
199+ assert _check_experiment_tracking ()
200+
201+
202+ def test_examples_exist ():
203+ assert _check_examples_exist ()
204+
205+
195206def main ():
196207 """Run all tests."""
197208 logger .info ("=" * 60 )
198209 logger .info ("TESTING ALL IMPLEMENTATIONS" )
199210 logger .info ("=" * 60 )
200211
201212 tests = [
202- ("Imports" , test_imports ),
203- ("MetricComputer" , test_metric_computer ),
204- ("Parallel Processing" , test_parallel_processing ),
205- ("Pruning Utilities" , test_pruning_utilities ),
206- ("Experiment Tracking" , test_experiment_tracking ),
207- ("Examples" , test_examples_exist ),
213+ ("Imports" , _check_imports ),
214+ ("MetricComputer" , _check_metric_computer ),
215+ ("Parallel Processing" , _check_parallel_processing ),
216+ ("Pruning Utilities" , _check_pruning_utilities ),
217+ ("Experiment Tracking" , _check_experiment_tracking ),
218+ ("Examples" , _check_examples_exist ),
208219 ]
209220
210221 results = {}
0 commit comments