Skip to content

Commit 8c9de86

Browse files
authored
Increase test coverage (#69)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
1 parent a15d2b0 commit 8c9de86

10 files changed

Lines changed: 706 additions & 2 deletions

quantflow_tests/test_ai.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pytest
1111
from mcp.server.fastmcp import FastMCP
1212

13+
from quantflow.ai import server as ai_server
1314
from quantflow.ai.tools import charts, crypto, fred, stocks, vault
1415
from quantflow.ai.tools.base import McpTool
1516
from quantflow.data.vault import Vault
@@ -459,3 +460,28 @@ async def test_ascii_chart_empty(
459460
with patch("quantflow.ai.tools.base.FMP", return_value=mock_fmp):
460461
result = await charts_server.call_tool("ascii_chart", {"symbol": "FAKE"})
461462
assert "No price data" in text(result)
463+
464+
465+
def test_create_server_registers_all_tools() -> None:
466+
fake_tool = MagicMock()
467+
with (
468+
patch("quantflow.ai.server.McpTool", return_value=fake_tool),
469+
patch("quantflow.ai.server.vault.register") as vault_register,
470+
patch("quantflow.ai.server.crypto.register") as crypto_register,
471+
patch("quantflow.ai.server.stocks.register") as stocks_register,
472+
patch("quantflow.ai.server.fred.register") as fred_register,
473+
patch("quantflow.ai.server.charts.register") as charts_register,
474+
):
475+
mcp = ai_server.create_server()
476+
vault_register.assert_called_once_with(mcp, fake_tool)
477+
crypto_register.assert_called_once_with(mcp, fake_tool)
478+
stocks_register.assert_called_once_with(mcp, fake_tool)
479+
fred_register.assert_called_once_with(mcp, fake_tool)
480+
charts_register.assert_called_once_with(mcp, fake_tool)
481+
482+
483+
def test_main_runs_server() -> None:
484+
mock_server = MagicMock()
485+
with patch("quantflow.ai.server.create_server", return_value=mock_server):
486+
ai_server.main()
487+
mock_server.run.assert_called_once_with()

quantflow_tests/test_data.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
from datetime import date
12
from typing import AsyncIterator
3+
from unittest.mock import AsyncMock
24

35
import pytest
46

@@ -59,3 +61,38 @@ async def __test_fiscal_data() -> None:
5961
assert df is not None
6062
assert df.shape[0] > 0
6163
assert df.shape[1] == 2
64+
65+
66+
async def test_fiscal_securities_builds_previous_month_filter() -> None:
67+
fd = FiscalData()
68+
fd.get_all = AsyncMock(return_value=[{"a": 1}]) # type: ignore[method-assign]
69+
df = await fd.securities(record_date=date(2024, 3, 15))
70+
fd.get_all.assert_awaited_once_with(
71+
"/v1/debt/mspd/mspd_table_3_market",
72+
{"filter": "record_date:eq:2024-02-29"},
73+
)
74+
assert len(df) == 1
75+
76+
77+
async def test_fiscal_get_all_multi_page() -> None:
78+
fd = FiscalData()
79+
fd.get = AsyncMock( # type: ignore[method-assign]
80+
side_effect=[
81+
{
82+
"data": [{"id": 1}],
83+
"links": {"next": "/v1/debt/mspd/mspd_table_3_market?page=2"},
84+
},
85+
{"data": [{"id": 2}], "links": {"next": None}},
86+
]
87+
)
88+
data = await fd.get_all("/v1/debt/mspd/mspd_table_3_market", {"a": "b"})
89+
assert data == [{"id": 1}, {"id": 2}]
90+
assert fd.get.await_count == 2
91+
92+
93+
async def test_fiscal_get_all_single_page_without_links() -> None:
94+
fd = FiscalData()
95+
fd.get = AsyncMock(return_value={"data": [{"id": 7}]}) # type: ignore[method-assign]
96+
data = await fd.get_all("/v1/debt/mspd/mspd_table_3_market", {"a": "b"})
97+
assert data == [{"id": 7}]
98+
fd.get.assert_awaited_once()

quantflow_tests/test_divfm.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,12 @@
1010
try:
1111
import torch
1212

13-
from quantflow.options.divfm.network import DIVFMNetwork
14-
from quantflow.options.divfm.trainer import DayData, DIVFMTrainer
13+
from quantflow.options.divfm.network import (
14+
DIVFMNetwork,
15+
)
16+
from quantflow.options.divfm.network import _extract_subnet as extract_subnet_torch
17+
from quantflow.options.divfm.network import _make_subnet as make_subnet_torch
18+
from quantflow.options.divfm.trainer import DayData, DIVFMTrainer, _day_loss
1519

1620
has_torch = True
1721
except ImportError:
@@ -191,6 +195,12 @@ def test_network_default_construction() -> None:
191195
assert net.extra_features == 0
192196

193197

198+
@pytest.mark.skipif(not has_torch, reason="torch not installed")
199+
def test_network_minimum_factors_validation() -> None:
200+
with pytest.raises(ValueError, match="at least 3"):
201+
DIVFMNetwork(num_factors=2)
202+
203+
194204
@pytest.mark.skipif(not has_torch, reason="torch not installed")
195205
def test_network_forward_shape() -> None:
196206
net = DIVFMNetwork(num_factors=NUM_FACTORS, hidden_size=HIDDEN_SIZE)
@@ -202,6 +212,28 @@ def test_network_forward_shape() -> None:
202212
assert (out[:, 0] == 1.0).all() # f_1 = 1
203213

204214

215+
@pytest.mark.skipif(not has_torch, reason="torch not installed")
216+
def test_make_subnet_layout() -> None:
217+
subnet = make_subnet_torch(2, 4, 2, 3)
218+
modules = list(subnet.children())
219+
assert isinstance(modules[0], torch.nn.Linear)
220+
assert isinstance(modules[1], torch.nn.Sigmoid)
221+
assert isinstance(modules[2], torch.nn.BatchNorm1d)
222+
assert isinstance(modules[-1], torch.nn.BatchNorm1d)
223+
assert modules[-1].affine is False
224+
225+
226+
@pytest.mark.skipif(not has_torch, reason="torch not installed")
227+
def test_extract_subnet_output_structure() -> None:
228+
subnet = make_subnet_torch(2, 4, 1, 1)
229+
subnet.eval()
230+
extracted = extract_subnet_torch(subnet)
231+
assert isinstance(extracted, SubnetWeights)
232+
assert len(extracted.layers) == 2
233+
assert extracted.layers[0].apply_activation is True
234+
assert extracted.layers[-1].apply_activation is False
235+
236+
205237
@pytest.mark.skipif(not has_torch, reason="torch not installed")
206238
def test_to_weights_forward_matches_network() -> None:
207239
net = DIVFMNetwork(num_factors=NUM_FACTORS, hidden_size=HIDDEN_SIZE)
@@ -220,6 +252,14 @@ def test_to_weights_forward_matches_network() -> None:
220252
np.testing.assert_allclose(torch_out, numpy_out, atol=1e-5)
221253

222254

255+
@pytest.mark.skipif(not has_torch, reason="torch not installed")
256+
def test_to_weights_without_joint_subnet() -> None:
257+
net = DIVFMNetwork(num_factors=3, hidden_size=HIDDEN_SIZE)
258+
weights = net.to_weights()
259+
assert weights.subnet_joint is None
260+
assert weights.num_factors == 3
261+
262+
223263
# ---------------------------------------------------------------------------
224264
# DIVFMTrainer tests (requires torch)
225265
# ---------------------------------------------------------------------------
@@ -252,6 +292,14 @@ def test_trainer_construction() -> None:
252292
assert trainer.network is net
253293

254294

295+
@pytest.mark.skipif(not has_torch, reason="torch not installed")
296+
def test_day_loss_non_negative() -> None:
297+
net = DIVFMNetwork(num_factors=NUM_FACTORS, hidden_size=HIDDEN_SIZE)
298+
day = _make_days(num_days=1)[0]
299+
loss = _day_loss(net, day, ridge=1e-6)
300+
assert float(loss.detach().item()) >= 0.0
301+
302+
255303
@pytest.mark.skipif(not has_torch, reason="torch not installed")
256304
def test_trainer_step_returns_loss() -> None:
257305
net = DIVFMNetwork(num_factors=NUM_FACTORS, hidden_size=HIDDEN_SIZE)
@@ -272,6 +320,13 @@ def test_trainer_evaluate() -> None:
272320
assert val_loss >= 0.0
273321

274322

323+
@pytest.mark.skipif(not has_torch, reason="torch not installed")
324+
def test_trainer_evaluate_empty_days() -> None:
325+
net = DIVFMNetwork(num_factors=NUM_FACTORS, hidden_size=HIDDEN_SIZE)
326+
trainer = DIVFMTrainer(net)
327+
assert trainer.evaluate([]) == 0.0
328+
329+
275330
@pytest.mark.skipif(not has_torch, reason="torch not installed")
276331
def test_trainer_fit_loss_decreases() -> None:
277332
"""Loss should decrease over training steps on a structured IV surface.
@@ -296,6 +351,16 @@ def test_trainer_fit_loss_decreases() -> None:
296351
assert np.mean(losses[-10:]) < np.mean(losses[:10])
297352

298353

354+
@pytest.mark.skipif(not has_torch, reason="torch not installed")
355+
def test_trainer_fit_with_validation_days() -> None:
356+
torch.manual_seed(1)
357+
net = DIVFMNetwork(num_factors=NUM_FACTORS, hidden_size=HIDDEN_SIZE)
358+
trainer = DIVFMTrainer(net, lr=1e-2, batch_days=4)
359+
days = _make_days(num_days=8)
360+
losses = trainer.fit(days, num_steps=5, val_days=days[:2], log_every=2)
361+
assert len(losses) == 5
362+
363+
299364
@pytest.mark.skipif(not has_torch, reason="torch not installed")
300365
def test_trainer_to_weights_produces_pricer() -> None:
301366
net = DIVFMNetwork(num_factors=NUM_FACTORS, hidden_size=HIDDEN_SIZE)

quantflow_tests/test_fmp_unit.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from __future__ import annotations
2+
3+
from datetime import date
4+
from unittest.mock import AsyncMock
5+
6+
import pandas as pd
7+
8+
from quantflow.data.fmp import FMP, nice_sector_performance, summary_sector_performance
9+
10+
11+
def test_freq_crate_and_join_and_params() -> None:
12+
assert FMP.freq.crate(None) == FMP.freq.daily
13+
assert FMP.freq.crate("1hour") == FMP.freq.one_hour
14+
assert FMP.freq.crate("bad") == FMP.freq.daily
15+
16+
fmp = FMP(key="k")
17+
assert fmp.join("AAPL", "MSFT") == "AAPL,MSFT"
18+
assert fmp.params({"a": 1}) == {"params": {"a": 1, "apikey": "k"}}
19+
20+
21+
async def test_prices_daily_and_intraday_paths() -> None:
22+
fmp = FMP(key="k")
23+
fmp.get_path = AsyncMock(return_value=[{"date": "2024-01-01", "close": 100}]) # type: ignore[method-assign]
24+
df = await fmp.prices("AAPL", frequency=None, convert_to_date=True)
25+
assert isinstance(df, pd.DataFrame)
26+
assert str(df["date"].dtype).startswith("datetime64")
27+
fmp.get_path.assert_awaited_with(
28+
"historical-price-eod/full",
29+
params={"symbol": "AAPL"},
30+
)
31+
32+
fmp.get_path = AsyncMock(return_value={"historical": [{"date": "2024-01-01"}]}) # type: ignore[method-assign]
33+
await fmp.prices("AAPL", frequency="1hour")
34+
fmp.get_path.assert_awaited_with(
35+
"historical-chart/1hour",
36+
params={"frequency": "1hour", "symbol": "AAPL"},
37+
)
38+
39+
40+
async def test_sector_performance_summary_and_timeseries() -> None:
41+
fmp = FMP(key="k")
42+
fmp.get_path = AsyncMock( # type: ignore[method-assign]
43+
side_effect=[
44+
[{"sector": "Tech", "changesPercentage": "1.23%"}],
45+
[
46+
{"date": "2024-01-01", "technologyChangesPercentage": 1.0},
47+
{"date": "2024-01-02", "technologyChangesPercentage": 2.0},
48+
],
49+
]
50+
)
51+
snapshot = await fmp.sector_performance()
52+
assert isinstance(snapshot, dict)
53+
assert str(snapshot["Tech"]) == "1.23"
54+
55+
summary = await fmp.sector_performance(from_date=date(2024, 1, 1), summary=True)
56+
assert isinstance(summary, dict)
57+
assert str(summary["Technology"]) == "3.02"
58+
59+
60+
def test_sector_helpers() -> None:
61+
nice = dict(
62+
nice_sector_performance(
63+
{"date": "2024-01-01", "consumerStaplesChangesPercentage": 1.5}
64+
)
65+
)
66+
assert nice["date"] == date(2024, 1, 1)
67+
assert nice["Consumer Staples"] == 1.5
68+
69+
summary = summary_sector_performance(
70+
[
71+
{"date": date(2024, 1, 1), "Tech": 1.0},
72+
{"date": date(2024, 1, 2), "Tech": 2.0},
73+
]
74+
)
75+
assert str(summary["Tech"]) == "3.02"
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from __future__ import annotations
2+
3+
import numpy as np
4+
5+
from quantflow.options.calibration import (
6+
DoubleHestonCalibration,
7+
DoubleHestonJCalibration,
8+
HestonCalibration,
9+
HestonJCalibration,
10+
)
11+
from quantflow.options.pricer import OptionPricer
12+
from quantflow.sp.heston import DoubleHeston, DoubleHestonJ, Heston, HestonJ
13+
from quantflow.utils.distributions import DoubleExponential
14+
15+
16+
def test_heston_calibration_get_set_and_penalize(vol_surface) -> None:
17+
cal: HestonCalibration = HestonCalibration(
18+
pricer=OptionPricer(
19+
model=Heston.create(vol=0.4, kappa=2.0, sigma=0.6, rho=-0.4)
20+
),
21+
vol_surface=vol_surface,
22+
)
23+
params = cal.get_params()
24+
assert len(params) == 5
25+
updated = np.array([0.1, 0.1, 1.0, 1.0, -0.2], dtype=float)
26+
cal.set_params(updated)
27+
assert np.allclose(cal.get_params(), updated)
28+
assert cal.penalize() >= 0.0
29+
bounds = cal.get_bounds()
30+
assert len(bounds.lb) == 5
31+
assert len(bounds.ub) == 5
32+
33+
34+
def test_hestonj_calibration_get_set_and_bounds(vol_surface) -> None:
35+
model = HestonJ.create(
36+
DoubleExponential,
37+
vol=0.3,
38+
kappa=1.5,
39+
sigma=0.4,
40+
rho=-0.3,
41+
jump_fraction=0.2,
42+
jump_asymmetry=0.1,
43+
)
44+
cal: HestonJCalibration = HestonJCalibration(
45+
pricer=OptionPricer(model=model), vol_surface=vol_surface
46+
)
47+
params = cal.get_params()
48+
cal.set_params(params)
49+
assert np.allclose(cal.get_params(), params)
50+
bounds = cal.get_bounds()
51+
assert len(bounds.lb) == len(bounds.ub) == len(params)
52+
53+
54+
def test_double_heston_calibration_param_logic(vol_surface) -> None:
55+
model = DoubleHeston(
56+
heston1=Heston.create(vol=0.3, kappa=2.0, sigma=0.4, rho=-0.2),
57+
heston2=Heston.create(vol=0.25, kappa=1.0, sigma=0.3, rho=-0.4),
58+
)
59+
cal: DoubleHestonCalibration = DoubleHestonCalibration(
60+
pricer=OptionPricer(model=model), vol_surface=vol_surface
61+
)
62+
params = cal.get_params()
63+
assert len(params) == 10
64+
cal.set_params(params)
65+
assert (
66+
cal.model.heston1.variance_process.kappa
67+
>= cal.model.heston2.variance_process.kappa
68+
)
69+
assert cal.penalize() >= 0.0
70+
assert len(cal.feller_residuals()) == 2
71+
assert cal.maturity_split() > 0.0
72+
73+
74+
def test_double_hestonj_calibration_get_set_and_bounds(vol_surface) -> None:
75+
model = DoubleHestonJ(
76+
heston1=HestonJ.create(
77+
DoubleExponential,
78+
vol=0.3,
79+
kappa=2.5,
80+
sigma=0.4,
81+
rho=-0.2,
82+
jump_fraction=0.2,
83+
jump_asymmetry=0.1,
84+
),
85+
heston2=Heston.create(vol=0.25, kappa=1.2, sigma=0.3, rho=-0.4),
86+
)
87+
cal: DoubleHestonJCalibration = DoubleHestonJCalibration(
88+
pricer=OptionPricer(model=model),
89+
vol_surface=vol_surface,
90+
)
91+
params = cal.get_params()
92+
cal.set_params(params)
93+
assert np.allclose(cal.get_params(), params)
94+
bounds = cal.get_bounds()
95+
assert len(bounds.lb) == len(bounds.ub) == len(params)

0 commit comments

Comments
 (0)