1010try :
1111 import torch
1212
13- from quantflow .options .divfm .network import DIVFMNetwork
14- from quantflow .options .divfm .trainer import DayData , DIVFMTrainer
13+ from quantflow .options .divfm .network import (
14+ DIVFMNetwork ,
15+ )
16+ from quantflow .options .divfm .network import _extract_subnet as extract_subnet_torch
17+ from quantflow .options .divfm .network import _make_subnet as make_subnet_torch
18+ from quantflow .options .divfm .trainer import DayData , DIVFMTrainer , _day_loss
1519
1620 has_torch = True
1721except ImportError :
@@ -191,6 +195,12 @@ def test_network_default_construction() -> None:
191195 assert net .extra_features == 0
192196
193197
198+ @pytest .mark .skipif (not has_torch , reason = "torch not installed" )
199+ def test_network_minimum_factors_validation () -> None :
200+ with pytest .raises (ValueError , match = "at least 3" ):
201+ DIVFMNetwork (num_factors = 2 )
202+
203+
194204@pytest .mark .skipif (not has_torch , reason = "torch not installed" )
195205def test_network_forward_shape () -> None :
196206 net = DIVFMNetwork (num_factors = NUM_FACTORS , hidden_size = HIDDEN_SIZE )
@@ -202,6 +212,28 @@ def test_network_forward_shape() -> None:
202212 assert (out [:, 0 ] == 1.0 ).all () # f_1 = 1
203213
204214
215+ @pytest .mark .skipif (not has_torch , reason = "torch not installed" )
216+ def test_make_subnet_layout () -> None :
217+ subnet = make_subnet_torch (2 , 4 , 2 , 3 )
218+ modules = list (subnet .children ())
219+ assert isinstance (modules [0 ], torch .nn .Linear )
220+ assert isinstance (modules [1 ], torch .nn .Sigmoid )
221+ assert isinstance (modules [2 ], torch .nn .BatchNorm1d )
222+ assert isinstance (modules [- 1 ], torch .nn .BatchNorm1d )
223+ assert modules [- 1 ].affine is False
224+
225+
226+ @pytest .mark .skipif (not has_torch , reason = "torch not installed" )
227+ def test_extract_subnet_output_structure () -> None :
228+ subnet = make_subnet_torch (2 , 4 , 1 , 1 )
229+ subnet .eval ()
230+ extracted = extract_subnet_torch (subnet )
231+ assert isinstance (extracted , SubnetWeights )
232+ assert len (extracted .layers ) == 2
233+ assert extracted .layers [0 ].apply_activation is True
234+ assert extracted .layers [- 1 ].apply_activation is False
235+
236+
205237@pytest .mark .skipif (not has_torch , reason = "torch not installed" )
206238def test_to_weights_forward_matches_network () -> None :
207239 net = DIVFMNetwork (num_factors = NUM_FACTORS , hidden_size = HIDDEN_SIZE )
@@ -220,6 +252,14 @@ def test_to_weights_forward_matches_network() -> None:
220252 np .testing .assert_allclose (torch_out , numpy_out , atol = 1e-5 )
221253
222254
255+ @pytest .mark .skipif (not has_torch , reason = "torch not installed" )
256+ def test_to_weights_without_joint_subnet () -> None :
257+ net = DIVFMNetwork (num_factors = 3 , hidden_size = HIDDEN_SIZE )
258+ weights = net .to_weights ()
259+ assert weights .subnet_joint is None
260+ assert weights .num_factors == 3
261+
262+
223263# ---------------------------------------------------------------------------
224264# DIVFMTrainer tests (requires torch)
225265# ---------------------------------------------------------------------------
@@ -252,6 +292,14 @@ def test_trainer_construction() -> None:
252292 assert trainer .network is net
253293
254294
295+ @pytest .mark .skipif (not has_torch , reason = "torch not installed" )
296+ def test_day_loss_non_negative () -> None :
297+ net = DIVFMNetwork (num_factors = NUM_FACTORS , hidden_size = HIDDEN_SIZE )
298+ day = _make_days (num_days = 1 )[0 ]
299+ loss = _day_loss (net , day , ridge = 1e-6 )
300+ assert float (loss .detach ().item ()) >= 0.0
301+
302+
255303@pytest .mark .skipif (not has_torch , reason = "torch not installed" )
256304def test_trainer_step_returns_loss () -> None :
257305 net = DIVFMNetwork (num_factors = NUM_FACTORS , hidden_size = HIDDEN_SIZE )
@@ -272,6 +320,13 @@ def test_trainer_evaluate() -> None:
272320 assert val_loss >= 0.0
273321
274322
323+ @pytest .mark .skipif (not has_torch , reason = "torch not installed" )
324+ def test_trainer_evaluate_empty_days () -> None :
325+ net = DIVFMNetwork (num_factors = NUM_FACTORS , hidden_size = HIDDEN_SIZE )
326+ trainer = DIVFMTrainer (net )
327+ assert trainer .evaluate ([]) == 0.0
328+
329+
275330@pytest .mark .skipif (not has_torch , reason = "torch not installed" )
276331def test_trainer_fit_loss_decreases () -> None :
277332 """Loss should decrease over training steps on a structured IV surface.
@@ -296,6 +351,16 @@ def test_trainer_fit_loss_decreases() -> None:
296351 assert np .mean (losses [- 10 :]) < np .mean (losses [:10 ])
297352
298353
354+ @pytest .mark .skipif (not has_torch , reason = "torch not installed" )
355+ def test_trainer_fit_with_validation_days () -> None :
356+ torch .manual_seed (1 )
357+ net = DIVFMNetwork (num_factors = NUM_FACTORS , hidden_size = HIDDEN_SIZE )
358+ trainer = DIVFMTrainer (net , lr = 1e-2 , batch_days = 4 )
359+ days = _make_days (num_days = 8 )
360+ losses = trainer .fit (days , num_steps = 5 , val_days = days [:2 ], log_every = 2 )
361+ assert len (losses ) == 5
362+
363+
299364@pytest .mark .skipif (not has_torch , reason = "torch not installed" )
300365def test_trainer_to_weights_produces_pricer () -> None :
301366 net = DIVFMNetwork (num_factors = NUM_FACTORS , hidden_size = HIDDEN_SIZE )
0 commit comments