Skip to content

Commit 43c7631

Browse files
committed
test(pipeline): convert pipeline tests from pytest to unittest
CI uses 'unittest discover' and does not install pytest as a dev dependency, so 'import pytest' caused ImportError on the runner. Rewrites all assertions/raises to native unittest.TestCase APIs. No production code changes; behavior covered is identical.
1 parent 5b8aad4 commit 43c7631

3 files changed

Lines changed: 314 additions & 291 deletions

File tree

tests/domain/pipeline/test_factors.py

Lines changed: 151 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from __future__ import annotations
33

44
import math
5+
import unittest
56
from datetime import datetime, timedelta
67

78
import polars as pl
8-
import pytest
99

1010
from investing_algorithm_framework import (
1111
AverageDollarVolume,
@@ -43,140 +43,153 @@ def _bar(dt_idx, close, volume=1.0):
4343
return (dt, close, close, close, close, volume)
4444

4545

46-
def test_returns_simple_percent_return():
47-
panel = _panel({"X": [_bar(i, c) for i, c in enumerate([10, 11, 12, 13])]})
48-
series = Returns(window=2).compute_panel(panel).to_list()
49-
# Bar 0,1 → null; bar 2 → 12/10 - 1; bar 3 → 13/11 - 1
50-
assert series[0] is None and series[1] is None
51-
assert series[2] == pytest.approx(0.2)
52-
assert series[3] == pytest.approx(13.0 / 11.0 - 1.0)
53-
54-
55-
def test_average_dollar_volume_rolling_mean():
56-
panel = _panel(
57-
{
58-
"X": [
59-
(datetime(2024, 1, 1) + timedelta(days=i), c, c, c, c, vol)
60-
for i, (c, vol) in enumerate(
61-
[(10, 1), (20, 2), (30, 3), (40, 4)]
62-
)
63-
]
64-
}
65-
)
66-
series = AverageDollarVolume(window=2).compute_panel(panel).to_list()
67-
# close*volume = [10, 40, 90, 160]; rolling mean window=2
68-
assert series[0] is None
69-
assert series[1] == pytest.approx(25.0)
70-
assert series[2] == pytest.approx(65.0)
71-
assert series[3] == pytest.approx(125.0)
72-
73-
74-
def test_sma_rolling_mean():
75-
panel = _panel({"X": [_bar(i, c) for i, c in enumerate([1, 2, 3, 4, 5])]})
76-
series = SMA(window=3).compute_panel(panel).to_list()
77-
assert series[0] is None and series[1] is None
78-
assert series[2] == pytest.approx(2.0)
79-
assert series[3] == pytest.approx(3.0)
80-
assert series[4] == pytest.approx(4.0)
81-
82-
83-
def test_volatility_log_return_stdev_scaled():
84-
closes = [100.0, 101.0, 99.0, 102.0, 100.0, 103.0]
85-
panel = _panel({"X": [_bar(i, c) for i, c in enumerate(closes)]})
86-
window = 4
87-
pp_year = 252
88-
series = (
89-
Volatility(window=window, periods_per_year=pp_year)
90-
.compute_panel(panel)
91-
.to_list()
92-
)
93-
# Manually compute the last value
94-
log_rets = [math.log(closes[i] / closes[i - 1]) for i in range(1, len(closes))]
95-
last_window = log_rets[-window:]
96-
mean = sum(last_window) / window
97-
var = sum((x - mean) ** 2 for x in last_window) / (window - 1)
98-
expected = math.sqrt(var) * math.sqrt(pp_year)
99-
assert series[-1] == pytest.approx(expected)
100-
101-
102-
def test_rsi_all_gains_returns_100():
103-
panel = _panel({"X": [_bar(i, c) for i, c in enumerate(range(1, 20))]})
104-
series = RSI(window=4).compute_panel(panel).to_list()
105-
# All gains, no losses → avg_loss == 0 → RSI clamped to 100
106-
assert series[-1] == pytest.approx(100.0)
107-
108-
109-
def test_rsi_with_losses_strictly_between_0_and_100():
110-
closes = [100, 102, 101, 103, 99, 104, 100, 106, 101]
111-
panel = _panel({"X": [_bar(i, c) for i, c in enumerate(closes)]})
112-
series = RSI(window=4).compute_panel(panel).to_list()
113-
last = series[-1]
114-
assert last is not None
115-
assert 0.0 < last < 100.0
116-
117-
118-
def test_factor_rank_orders_within_each_bar():
119-
# 3 symbols, 1 bar of meaningful data — but rank needs Returns(window=1).
120-
panel = _panel(
121-
{
122-
"AAA": [_bar(0, 100), _bar(1, 110)], # +10%
123-
"BBB": [_bar(0, 100), _bar(1, 105)], # +5%
124-
"CCC": [_bar(0, 100), _bar(1, 120)], # +20%
125-
}
126-
)
127-
ranked = Returns(window=1).rank().compute_panel(panel)
128-
df = panel.select(["datetime", "symbol"]).with_columns(
129-
ranked.alias("rk")
130-
).filter(pl.col("datetime") == datetime(2024, 1, 2))
131-
out = {row["symbol"]: row["rk"] for row in df.to_dicts()}
132-
# Ascending ordinal ranks: BBB=1, AAA=2, CCC=3
133-
assert out["BBB"] == 1.0
134-
assert out["AAA"] == 2.0
135-
assert out["CCC"] == 3.0
136-
137-
138-
def test_factor_top_filter_keeps_highest():
139-
panel = _panel(
140-
{
141-
"AAA": [_bar(0, 100), _bar(1, 110)],
142-
"BBB": [_bar(0, 100), _bar(1, 105)],
143-
"CCC": [_bar(0, 100), _bar(1, 120)],
144-
}
145-
)
146-
mask = Returns(window=1).top(2).compute_panel(panel)
147-
df = panel.select(["datetime", "symbol"]).with_columns(
148-
mask.alias("m")
149-
).filter(pl.col("datetime") == datetime(2024, 1, 2))
150-
out = {row["symbol"]: row["m"] for row in df.to_dicts()}
151-
# Top 2 by descending returns: CCC (20%) and AAA (10%)
152-
assert out["AAA"] is True
153-
assert out["CCC"] is True
154-
assert out["BBB"] is False
155-
156-
157-
def test_factor_bottom_filter_keeps_lowest():
158-
panel = _panel(
159-
{
160-
"AAA": [_bar(0, 100), _bar(1, 110)],
161-
"BBB": [_bar(0, 100), _bar(1, 105)],
162-
"CCC": [_bar(0, 100), _bar(1, 120)],
163-
}
164-
)
165-
mask = Returns(window=1).bottom(1).compute_panel(panel)
166-
df = panel.select(["datetime", "symbol"]).with_columns(
167-
mask.alias("m")
168-
).filter(pl.col("datetime") == datetime(2024, 1, 2))
169-
out = {row["symbol"]: row["m"] for row in df.to_dicts()}
170-
assert out["BBB"] is True
171-
assert out["AAA"] is False
172-
assert out["CCC"] is False
173-
174-
175-
def test_factor_invalid_window_raises():
176-
with pytest.raises(ValueError):
177-
Returns(window=0)
178-
179-
180-
def test_volatility_invalid_periods_raises():
181-
with pytest.raises(ValueError):
182-
Volatility(window=10, periods_per_year=0)
46+
class TestPipelineFactors(unittest.TestCase):
47+
48+
def test_returns_simple_percent_return(self):
49+
panel = _panel(
50+
{"X": [_bar(i, c) for i, c in enumerate([10, 11, 12, 13])]}
51+
)
52+
series = Returns(window=2).compute_panel(panel).to_list()
53+
# Bar 0,1 → null; bar 2 → 12/10 - 1; bar 3 → 13/11 - 1
54+
self.assertIsNone(series[0])
55+
self.assertIsNone(series[1])
56+
self.assertAlmostEqual(series[2], 0.2)
57+
self.assertAlmostEqual(series[3], 13.0 / 11.0 - 1.0)
58+
59+
def test_average_dollar_volume_rolling_mean(self):
60+
panel = _panel(
61+
{
62+
"X": [
63+
(datetime(2024, 1, 1) + timedelta(days=i), c, c, c, c, vol)
64+
for i, (c, vol) in enumerate(
65+
[(10, 1), (20, 2), (30, 3), (40, 4)]
66+
)
67+
]
68+
}
69+
)
70+
series = AverageDollarVolume(window=2).compute_panel(panel).to_list()
71+
# close*volume = [10, 40, 90, 160]; rolling mean window=2
72+
self.assertIsNone(series[0])
73+
self.assertAlmostEqual(series[1], 25.0)
74+
self.assertAlmostEqual(series[2], 65.0)
75+
self.assertAlmostEqual(series[3], 125.0)
76+
77+
def test_sma_rolling_mean(self):
78+
panel = _panel(
79+
{"X": [_bar(i, c) for i, c in enumerate([1, 2, 3, 4, 5])]}
80+
)
81+
series = SMA(window=3).compute_panel(panel).to_list()
82+
self.assertIsNone(series[0])
83+
self.assertIsNone(series[1])
84+
self.assertAlmostEqual(series[2], 2.0)
85+
self.assertAlmostEqual(series[3], 3.0)
86+
self.assertAlmostEqual(series[4], 4.0)
87+
88+
def test_volatility_log_return_stdev_scaled(self):
89+
closes = [100.0, 101.0, 99.0, 102.0, 100.0, 103.0]
90+
panel = _panel({"X": [_bar(i, c) for i, c in enumerate(closes)]})
91+
window = 4
92+
pp_year = 252
93+
series = (
94+
Volatility(window=window, periods_per_year=pp_year)
95+
.compute_panel(panel)
96+
.to_list()
97+
)
98+
# Manually compute the last value
99+
log_rets = [
100+
math.log(closes[i] / closes[i - 1]) for i in range(1, len(closes))
101+
]
102+
last_window = log_rets[-window:]
103+
mean = sum(last_window) / window
104+
var = sum((x - mean) ** 2 for x in last_window) / (window - 1)
105+
expected = math.sqrt(var) * math.sqrt(pp_year)
106+
self.assertAlmostEqual(series[-1], expected)
107+
108+
def test_rsi_all_gains_returns_100(self):
109+
panel = _panel(
110+
{"X": [_bar(i, c) for i, c in enumerate(range(1, 20))]}
111+
)
112+
series = RSI(window=4).compute_panel(panel).to_list()
113+
# All gains, no losses → avg_loss == 0 → RSI clamped to 100
114+
self.assertAlmostEqual(series[-1], 100.0)
115+
116+
def test_rsi_with_losses_strictly_between_0_and_100(self):
117+
closes = [100, 102, 101, 103, 99, 104, 100, 106, 101]
118+
panel = _panel({"X": [_bar(i, c) for i, c in enumerate(closes)]})
119+
series = RSI(window=4).compute_panel(panel).to_list()
120+
last = series[-1]
121+
self.assertIsNotNone(last)
122+
self.assertGreater(last, 0.0)
123+
self.assertLess(last, 100.0)
124+
125+
def test_factor_rank_orders_within_each_bar(self):
126+
# 3 symbols, 1 bar of meaningful data — rank needs Returns(window=1).
127+
panel = _panel(
128+
{
129+
"AAA": [_bar(0, 100), _bar(1, 110)], # +10%
130+
"BBB": [_bar(0, 100), _bar(1, 105)], # +5%
131+
"CCC": [_bar(0, 100), _bar(1, 120)], # +20%
132+
}
133+
)
134+
ranked = Returns(window=1).rank().compute_panel(panel)
135+
df = (
136+
panel.select(["datetime", "symbol"])
137+
.with_columns(ranked.alias("rk"))
138+
.filter(pl.col("datetime") == datetime(2024, 1, 2))
139+
)
140+
out = {row["symbol"]: row["rk"] for row in df.to_dicts()}
141+
# Ascending ordinal ranks: BBB=1, AAA=2, CCC=3
142+
self.assertEqual(out["BBB"], 1.0)
143+
self.assertEqual(out["AAA"], 2.0)
144+
self.assertEqual(out["CCC"], 3.0)
145+
146+
def test_factor_top_filter_keeps_highest(self):
147+
panel = _panel(
148+
{
149+
"AAA": [_bar(0, 100), _bar(1, 110)],
150+
"BBB": [_bar(0, 100), _bar(1, 105)],
151+
"CCC": [_bar(0, 100), _bar(1, 120)],
152+
}
153+
)
154+
mask = Returns(window=1).top(2).compute_panel(panel)
155+
df = (
156+
panel.select(["datetime", "symbol"])
157+
.with_columns(mask.alias("m"))
158+
.filter(pl.col("datetime") == datetime(2024, 1, 2))
159+
)
160+
out = {row["symbol"]: row["m"] for row in df.to_dicts()}
161+
# Top 2 by descending returns: CCC (20%) and AAA (10%)
162+
self.assertTrue(out["AAA"])
163+
self.assertTrue(out["CCC"])
164+
self.assertFalse(out["BBB"])
165+
166+
def test_factor_bottom_filter_keeps_lowest(self):
167+
panel = _panel(
168+
{
169+
"AAA": [_bar(0, 100), _bar(1, 110)],
170+
"BBB": [_bar(0, 100), _bar(1, 105)],
171+
"CCC": [_bar(0, 100), _bar(1, 120)],
172+
}
173+
)
174+
mask = Returns(window=1).bottom(1).compute_panel(panel)
175+
df = (
176+
panel.select(["datetime", "symbol"])
177+
.with_columns(mask.alias("m"))
178+
.filter(pl.col("datetime") == datetime(2024, 1, 2))
179+
)
180+
out = {row["symbol"]: row["m"] for row in df.to_dicts()}
181+
self.assertTrue(out["BBB"])
182+
self.assertFalse(out["AAA"])
183+
self.assertFalse(out["CCC"])
184+
185+
def test_factor_invalid_window_raises(self):
186+
with self.assertRaises(ValueError):
187+
Returns(window=0)
188+
189+
def test_volatility_invalid_periods_raises(self):
190+
with self.assertRaises(ValueError):
191+
Volatility(window=10, periods_per_year=0)
192+
193+
194+
if __name__ == "__main__":
195+
unittest.main()

tests/domain/pipeline/test_pipeline.py

Lines changed: 38 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66
from __future__ import annotations
77

8-
import pytest
8+
import unittest
99

1010
from investing_algorithm_framework import (
1111
AverageDollarVolume,
@@ -22,48 +22,50 @@ class _Screener(Pipeline):
2222
alpha = momentum.rank(mask=universe)
2323

2424

25-
def test_pipeline_collects_columns_excluding_universe():
26-
cols = _Screener.get_columns()
27-
assert list(cols.keys()) == ["dollar_volume", "momentum", "alpha"]
28-
assert _Screener.get_universe() is _Screener.universe
25+
class TestPipelineIntrospection(unittest.TestCase):
2926

27+
def test_pipeline_collects_columns_excluding_universe(self):
28+
cols = _Screener.get_columns()
29+
self.assertEqual(
30+
list(cols.keys()), ["dollar_volume", "momentum", "alpha"]
31+
)
32+
self.assertIs(_Screener.get_universe(), _Screener.universe)
3033

31-
def test_pipeline_required_columns_union():
32-
required = _Screener.required_columns()
33-
# AverageDollarVolume needs close+volume, Returns needs close
34-
assert "close" in required
35-
assert "volume" in required
34+
def test_pipeline_required_columns_union(self):
35+
required = _Screener.required_columns()
36+
# AverageDollarVolume needs close+volume, Returns needs close
37+
self.assertIn("close", required)
38+
self.assertIn("volume", required)
3639

40+
def test_pipeline_required_window_is_max(self):
41+
self.assertEqual(_Screener.required_window(), 5)
3742

38-
def test_pipeline_required_window_is_max():
39-
assert _Screener.required_window() == 5
43+
def test_pipeline_name_defaults_to_class_name(self):
44+
self.assertEqual(_Screener.name(), "_Screener")
4045

46+
def test_pipeline_with_no_columns_raises(self):
47+
with self.assertRaisesRegex(TypeError, "declares no factor columns"):
48+
class _Empty(Pipeline):
49+
pass
4150

42-
def test_pipeline_name_defaults_to_class_name():
43-
assert _Screener.name() == "_Screener"
51+
def test_pipeline_universe_must_be_filter(self):
52+
with self.assertRaisesRegex(TypeError, "must be a Filter"):
53+
class _BadUniverse(Pipeline):
54+
momentum = Returns(window=3)
55+
# Returns is a Factor, not a Filter
56+
universe = momentum
4457

58+
def test_pipeline_inheritance_collects_parent_columns(self):
59+
class _Child(_Screener):
60+
sma = SMA(window=4)
4561

46-
def test_pipeline_with_no_columns_raises():
47-
with pytest.raises(TypeError, match="declares no factor columns"):
48-
class _Empty(Pipeline):
49-
pass
62+
cols = _Child.get_columns()
63+
# Child columns + parent columns
64+
self.assertIn("sma", cols)
65+
self.assertIn("dollar_volume", cols)
66+
self.assertIn("momentum", cols)
67+
self.assertIn("alpha", cols)
5068

5169

52-
def test_pipeline_universe_must_be_filter():
53-
with pytest.raises(TypeError, match="must be a Filter"):
54-
class _BadUniverse(Pipeline):
55-
momentum = Returns(window=3)
56-
# Returns is a Factor, not a Filter
57-
universe = momentum
58-
59-
60-
def test_pipeline_inheritance_collects_parent_columns():
61-
class _Child(_Screener):
62-
sma = SMA(window=4)
63-
64-
cols = _Child.get_columns()
65-
# Child columns + parent columns
66-
assert "sma" in cols
67-
assert "dollar_volume" in cols
68-
assert "momentum" in cols
69-
assert "alpha" in cols
70+
if __name__ == "__main__":
71+
unittest.main()

0 commit comments

Comments
 (0)