Skip to content

Commit 7495c4c

Browse files
committed
Fix lint
1 parent 299d8ea commit 7495c4c

1 file changed

Lines changed: 14 additions & 13 deletions

File tree

quantflow_tests/test_interpolated_curve.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414

1515
_REF = datetime(2026, 6, 7, tzinfo=timezone.utc)
1616
_TTM = np.array([0.25, 1.0, 2.0, 5.0, 10.0])
17-
_RATES = [0.02, 0.025, 0.03, 0.035, 0.04]
17+
_RATES = [Decimal(r) for r in ("0.02", "0.025", "0.03", "0.035", "0.04")]
18+
_RATES_F = np.array([float(r) for r in _RATES])
1819
_YEAR = 365.0 * 86400.0
20+
_ADAPTER: TypeAdapter[AnyYieldCurve] = TypeAdapter(AnyYieldCurve)
1921

2022

2123
def _dates(ttm: np.ndarray = _TTM) -> list[datetime]:
@@ -24,7 +26,7 @@ def _dates(ttm: np.ndarray = _TTM) -> list[datetime]:
2426

2527
def _curve(
2628
interpolation_type: InterpolationType = InterpolationType.MONOTONE_CUBIC,
27-
rates: list[float] = _RATES,
29+
rates: list[Decimal] = _RATES,
2830
) -> InterpolatedYieldCurve:
2931
return InterpolatedYieldCurve(
3032
ref_date=_REF,
@@ -58,9 +60,8 @@ def test_anchor_rates_coerced_to_decimal() -> None:
5860

5961
def test_private_attrs_populated() -> None:
6062
curve = _curve()
61-
rf = np.array(_RATES)
6263
assert curve._ttm == pytest.approx(_TTM)
63-
assert curve._log_discount == pytest.approx(-rf * _TTM)
64+
assert curve._log_discount == pytest.approx(-_RATES_F * _TTM)
6465

6566

6667
# ---------------------------------------------------------------------------
@@ -76,14 +77,14 @@ def test_discount_factor_at_zero_is_one(interpolation_type: InterpolationType) -
7677
def test_reprices_nodes_exactly(interpolation_type: InterpolationType) -> None:
7778
curve = _curve(interpolation_type)
7879
fitted = curve.continuously_compounded_rate(_TTM)
79-
assert np.asarray(fitted) == pytest.approx(np.array(_RATES))
80+
assert np.asarray(fitted) == pytest.approx(_RATES_F)
8081

8182

8283
def test_discount_factor_matches_node_rates(
8384
interpolation_type: InterpolationType,
8485
) -> None:
8586
curve = _curve(interpolation_type)
86-
for t, r in zip(_TTM, _RATES):
87+
for t, r in zip(_TTM, _RATES_F):
8788
assert float(curve.discount_factor(t)) == pytest.approx(math.exp(-r * t))
8889

8990

@@ -119,11 +120,12 @@ def test_discount_factor_monotone_decreasing(
119120

120121
def test_monotone_cubic_introduces_no_new_extrema() -> None:
121122
# a non-monotone forward profile that would make a natural cubic overshoot
122-
rates = [0.05, 0.01, 0.05, 0.01, 0.05]
123+
rates = [Decimal(r) for r in ("0.05", "0.01", "0.05", "0.01", "0.05")]
123124
curve = _curve(InterpolationType.MONOTONE_CUBIC, rates=rates)
124125
g = np.log(np.asarray(curve.discount_factor(np.linspace(0.0, 10.0, 500))))
125126
# log discount factor must stay within the envelope of the node values
126-
node_g = np.concatenate([[0.0], -np.array(rates) * _TTM])
127+
rates_f = np.array([float(r) for r in rates])
128+
node_g = np.concatenate([[0.0], -rates_f * _TTM])
127129
assert g.min() >= node_g.min() - 1e-9
128130
assert g.max() <= node_g.max() + 1e-9
129131

@@ -167,8 +169,7 @@ def test_forward_rate_consistent_with_discount_factor(
167169

168170
def test_json_round_trip_via_union(interpolation_type: InterpolationType) -> None:
169171
curve = _curve(interpolation_type)
170-
adapter = TypeAdapter(AnyYieldCurve)
171-
restored = adapter.validate_json(adapter.dump_json(curve))
172+
restored = _ADAPTER.validate_json(_ADAPTER.dump_json(curve))
172173
assert type(restored) is InterpolatedYieldCurve
173174
assert restored.interpolation_type is interpolation_type
174175
assert restored._ttm == pytest.approx(_TTM)
@@ -199,7 +200,7 @@ def test_non_increasing_dates_raise() -> None:
199200
def test_anchor_before_ref_date_raises() -> None:
200201
with pytest.raises((ValidationError, ValueError)):
201202
InterpolatedYieldCurve(
202-
ref_date=_REF, anchor_dates=[_REF], anchor_rates=[0.02]
203+
ref_date=_REF, anchor_dates=[_REF], anchor_rates=[Decimal("0.02")]
203204
)
204205

205206

@@ -209,7 +210,7 @@ def test_anchor_before_ref_date_raises() -> None:
209210

210211

211212
def test_calibrate_from_ttm_reprices_exactly() -> None:
212-
target = np.array(_RATES) * 1.1
213+
target = _RATES_F * 1.1
213214
curve = _curve().calibrator().calibrate(_TTM, target)
214215
assert np.asarray(curve.continuously_compounded_rate(_TTM)) == pytest.approx(target)
215216
assert all(isinstance(r, Decimal) for r in curve.anchor_rates)
@@ -218,7 +219,7 @@ def test_calibrate_from_ttm_reprices_exactly() -> None:
218219
def test_set_params_updates_log_discount() -> None:
219220
curve = _curve()
220221
calibrator = curve.calibrator()
221-
new_rates = np.array(_RATES) * 0.5
222+
new_rates = _RATES_F * 0.5
222223
calibrator.set_params(new_rates)
223224
assert curve._log_discount == pytest.approx(-new_rates * _TTM)
224225
assert np.asarray(curve.continuously_compounded_rate(_TTM)) == pytest.approx(

0 commit comments

Comments
 (0)