Skip to content

Commit 7f443a5

Browse files
alongdclaude
andcommitted
level: raise InputError on user CBS formula division by zero
Wrap safe_eval_formula in _user_fn so degenerate inputs surface the formula text and cardinal map instead of a raw ZeroDivisionError. Also add the missing return annotations on _resolve_formula/_user_fn. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
1 parent 1eb9430 commit 7f443a5

2 files changed

Lines changed: 23 additions & 5 deletions

File tree

arc/level/protocol.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242
import copy
4343
from abc import ABC, abstractmethod
44-
from collections.abc import Iterable
44+
from collections.abc import Callable, Iterable
4545
from typing import Any
4646

4747
from arc.exceptions import InputError
@@ -295,7 +295,7 @@ def __init__(
295295
_USER_FORMULA_MAX_LEVELS = 3
296296

297297
@staticmethod
298-
def _resolve_formula(formula: str, n_levels: int):
298+
def _resolve_formula(formula: str, n_levels: int) -> Callable[[dict[int, float]], float]:
299299
"""Validate ``formula`` against the built-in registry and (if user-supplied)
300300
the safe-eval whitelist; return a callable taking ``{cardinal: energy}``.
301301
@@ -324,13 +324,19 @@ def _resolve_formula(formula: str, n_levels: int):
324324
allowed.update({var for var in ("X", "Y", "Z")[:n_levels]})
325325
validate_formula(formula, allowed)
326326

327-
def _user_fn(energies):
328-
env = {}
327+
def _user_fn(energies: dict[int, float]) -> float:
328+
env: dict[str, float] = {}
329329
for idx, (X, E) in enumerate(sorted(energies.items())):
330330
var = ("X", "Y", "Z")[idx]
331331
env[var] = X
332332
env[f"E_{var}"] = E
333-
return safe_eval_formula(formula, env)
333+
try:
334+
return safe_eval_formula(formula, env)
335+
except ZeroDivisionError as exc:
336+
raise InputError(
337+
f"User CBS formula {formula!r} raised division by zero "
338+
f"for inputs {energies}."
339+
) from exc
334340

335341
return _user_fn
336342

arc/level/protocol_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,18 @@ def test_evaluate_user_formula(self):
128128
expected = (27 * -0.30 - 64 * -0.31) / (27 - 64)
129129
self.assertAlmostEqual(result, expected, places=12)
130130

131+
def test_evaluate_user_formula_zero_division_reports_formula_and_inputs(self):
132+
term = CBSExtrapolationTerm(
133+
label="cbs_user",
134+
formula="E_X / (E_Y - E_X)",
135+
levels=[self.tz, self.qz],
136+
)
137+
with self.assertRaises(InputError) as cm:
138+
term.evaluate({"cbs_user__card_3": -0.30, "cbs_user__card_4": -0.30})
139+
message = str(cm.exception)
140+
self.assertIn("E_X / (E_Y - E_X)", message)
141+
self.assertIn("{3: -0.3, 4: -0.3}", message)
142+
131143
def test_three_point_martin(self):
132144
term = CBSExtrapolationTerm(
133145
label="cbs_m", formula="martin_3pt", levels=[self.tz, self.qz, self.fz]

0 commit comments

Comments
 (0)