Skip to content

Commit e0323ca

Browse files
committed
feat: add string shorthand choices
1 parent 556bd89 commit e0323ca

3 files changed

Lines changed: 204 additions & 5 deletions

File tree

openhexa/sdk/pipelines/parameter/decorator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(
4343
| File
4444
],
4545
name: str | None = None,
46-
choices: typing.Sequence | ChoicesFromFile | None = None,
46+
choices: typing.Sequence | ChoicesFromFile | str | None = None,
4747
help: str | None = None,
4848
default: typing.Any | None = None,
4949
widget: DHIS2Widget | IASOWidget | None = None,
@@ -67,6 +67,8 @@ def __init__(
6767
if choices is not None:
6868
if not self.type.accepts_choices:
6969
raise InvalidParameterError(f"Parameters of type {self.type} don't accept choices.")
70+
if isinstance(choices, str):
71+
choices = ChoicesFromFile(choices)
7072
if isinstance(choices, ChoicesFromFile):
7173
# validate_spec() already ran in ChoicesFromFile.__init__; nothing more to check here
7274
pass
@@ -244,7 +246,7 @@ def parameter(
244246
| File
245247
],
246248
name: str | None = None,
247-
choices: typing.Sequence | ChoicesFromFile | None = None,
249+
choices: typing.Sequence | ChoicesFromFile | str | None = None,
248250
help: str | None = None,
249251
widget: DHIS2Widget | IASOWidget | None = None,
250252
connection: str | None = None,

openhexa/sdk/pipelines/runtime.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import sys
99
from dataclasses import dataclass, field
1010
from pathlib import Path
11-
from typing import Any
11+
from typing import Any, Callable
1212
from zipfile import ZipFile
1313

1414
import requests
@@ -37,9 +37,10 @@
3737
class Argument:
3838
"""Argument of a decorator."""
3939

40-
name: str # Use str instead of string
40+
name: str
4141
types: list[type] = field(default_factory=list)
4242
default_value: Any = None
43+
transform: Callable | None = None
4344

4445

4546
def import_pipeline(pipeline_dir_path: str) -> Pipeline:
@@ -214,6 +215,8 @@ def _get_decorator_spec(decorator: ast.Call, args: tuple[Argument, ...]) -> dict
214215
args_spec = {}
215216
for i, arg in enumerate(args):
216217
value, is_keyword = _get_decorator_arg_value(decorator, arg, i)
218+
if arg.transform is not None:
219+
value = arg.transform(value)
217220
args_spec[arg.name] = {"value": value, "is_keyword": is_keyword}
218221
return args_spec
219222

@@ -300,7 +303,11 @@ def get_pipeline(pipeline_path: Path) -> Pipeline:
300303
Argument("code", [ast.Constant]),
301304
Argument("type", [ast.Name]),
302305
Argument("name", [ast.Constant]),
303-
Argument("choices", [ast.List, ast.Call]),
306+
Argument(
307+
"choices",
308+
[ast.List, ast.Call, ast.Constant],
309+
transform=lambda v: ChoicesFromFile(v) if isinstance(v, str) else v,
310+
),
304311
Argument("help", [ast.Constant]),
305312
Argument("default", [ast.Constant, ast.List]),
306313
Argument("widget", [ast.Attribute]),

tests/test_choices.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,196 @@ def test_to_dict_no_column(self):
5757
assert fc.to_dict() == {"format": "csv", "path": "districts.csv", "column": None}
5858

5959

60+
# ---------------------------------------------------------------------------
61+
# String shorthand — Parameter.__init__
62+
# ---------------------------------------------------------------------------
63+
64+
65+
class TestStringShorthand:
66+
# --- happy paths ---
67+
68+
def test_string_shorthand_csv(self):
69+
p = Parameter(code="district", type=str, choices="districts.csv")
70+
assert isinstance(p.choices, ChoicesFromFile)
71+
assert p.choices.path == "districts.csv"
72+
assert p.choices.format == "csv"
73+
assert p.choices.column is None
74+
75+
def test_string_shorthand_json(self):
76+
p = Parameter(code="district", type=str, choices="data/regions.json")
77+
assert isinstance(p.choices, ChoicesFromFile)
78+
assert p.choices.format == "json"
79+
80+
def test_string_shorthand_yaml(self):
81+
p = Parameter(code="district", type=str, choices="list.yaml")
82+
assert isinstance(p.choices, ChoicesFromFile)
83+
assert p.choices.format == "yaml"
84+
85+
def test_string_shorthand_yml(self):
86+
p = Parameter(code="district", type=str, choices="list.yml")
87+
assert isinstance(p.choices, ChoicesFromFile)
88+
assert p.choices.format == "yaml"
89+
90+
def test_string_shorthand_leading_slash_stripped(self):
91+
p = Parameter(code="district", type=str, choices="/choices.csv")
92+
assert p.choices.path == "/choices.csv" # ChoicesFromFile stores as-is; stripping is app-side
93+
94+
def test_string_shorthand_serialises_same_as_explicit(self):
95+
shorthand = Parameter(code="district", type=str, choices="districts.csv").to_dict()
96+
explicit = Parameter(code="district", type=str, choices=ChoicesFromFile("districts.csv")).to_dict()
97+
assert shorthand == explicit
98+
99+
# --- static list still works ---
100+
101+
def test_static_list_unaffected(self):
102+
p = Parameter(code="country", type=str, choices=["UG", "KE"])
103+
assert p.choices == ["UG", "KE"]
104+
105+
def test_explicit_choices_from_file_unaffected(self):
106+
p = Parameter(code="district", type=str, choices=ChoicesFromFile("districts.csv", column="code"))
107+
assert p.choices.column == "code"
108+
109+
# --- invalid strings raise clearly ---
110+
111+
def test_string_no_extension_raises(self):
112+
with pytest.raises(InvalidParameterError, match="Supported extensions"):
113+
Parameter(code="district", type=str, choices="nodot")
114+
115+
def test_string_unsupported_extension_raises(self):
116+
with pytest.raises(InvalidParameterError, match="Supported extensions"):
117+
Parameter(code="district", type=str, choices="file.xlsx")
118+
119+
def test_empty_string_raises(self):
120+
with pytest.raises(InvalidParameterError):
121+
Parameter(code="district", type=str, choices="")
122+
123+
# --- column cannot be specified via shorthand ---
124+
125+
def test_shorthand_has_no_column(self):
126+
p = Parameter(code="district", type=str, choices="districts.csv")
127+
assert p.choices.column is None
128+
129+
def test_decorator_with_string_shorthand(self):
130+
@parameter(code="district", type=str, choices="districts.csv")
131+
def my_pipeline(district):
132+
pass
133+
134+
params = my_pipeline.get_all_parameters()
135+
assert isinstance(params[0].choices, ChoicesFromFile)
136+
137+
138+
# ---------------------------------------------------------------------------
139+
# String shorthand — AST round-trip
140+
# ---------------------------------------------------------------------------
141+
142+
143+
class TestAstStringShorthand(TestCase):
144+
def _write_pipeline(self, tmpdir, param_line):
145+
with open(f"{tmpdir}/pipeline.py", "w") as f:
146+
f.write(
147+
"\n".join(
148+
[
149+
"from openhexa.sdk.pipelines import pipeline, parameter",
150+
"",
151+
param_line,
152+
"@pipeline(name='Test pipeline')",
153+
"def test_pipeline(district):",
154+
" pass",
155+
]
156+
)
157+
)
158+
159+
def test_ast_string_shorthand_csv(self):
160+
with tempfile.TemporaryDirectory() as tmpdir:
161+
self._write_pipeline(
162+
tmpdir,
163+
"@parameter('district', type=str, choices='districts.csv')",
164+
)
165+
p = get_pipeline(tmpdir)
166+
param_dict = p.to_dict()["parameters"][0]
167+
assert param_dict["choices"] is None
168+
assert param_dict["choices_from_file"] == {"format": "csv", "path": "districts.csv", "column": None}
169+
170+
def test_ast_string_shorthand_json(self):
171+
with tempfile.TemporaryDirectory() as tmpdir:
172+
self._write_pipeline(
173+
tmpdir,
174+
"@parameter('district', type=str, choices='regions.json')",
175+
)
176+
p = get_pipeline(tmpdir)
177+
assert p.to_dict()["parameters"][0]["choices_from_file"]["format"] == "json"
178+
179+
def test_ast_string_shorthand_yaml(self):
180+
with tempfile.TemporaryDirectory() as tmpdir:
181+
self._write_pipeline(
182+
tmpdir,
183+
"@parameter('district', type=str, choices='list.yml')",
184+
)
185+
p = get_pipeline(tmpdir)
186+
assert p.to_dict()["parameters"][0]["choices_from_file"]["format"] == "yaml"
187+
188+
def test_ast_string_shorthand_same_output_as_explicit(self):
189+
with tempfile.TemporaryDirectory() as tmpdir:
190+
self._write_pipeline(
191+
tmpdir,
192+
"@parameter('district', type=str, choices='districts.csv')",
193+
)
194+
shorthand_dict = get_pipeline(tmpdir).to_dict()["parameters"][0]
195+
196+
with tempfile.TemporaryDirectory() as tmpdir:
197+
self._write_pipeline(
198+
tmpdir,
199+
"@parameter('district', type=str, choices=ChoicesFromFile('districts.csv'))",
200+
)
201+
# need the import for the explicit form
202+
with open(f"{tmpdir}/pipeline.py", "w") as f:
203+
f.write(
204+
"\n".join(
205+
[
206+
"from openhexa.sdk.pipelines import pipeline, parameter",
207+
"from openhexa.sdk.pipelines.parameter import ChoicesFromFile",
208+
"",
209+
"@parameter('district', type=str, choices=ChoicesFromFile('districts.csv'))",
210+
"@pipeline(name='Test pipeline')",
211+
"def test_pipeline(district):",
212+
" pass",
213+
]
214+
)
215+
)
216+
explicit_dict = get_pipeline(tmpdir).to_dict()["parameters"][0]
217+
218+
assert shorthand_dict == explicit_dict
219+
220+
def test_ast_static_list_unaffected(self):
221+
with tempfile.TemporaryDirectory() as tmpdir:
222+
self._write_pipeline(
223+
tmpdir,
224+
"@parameter('country', type=str, choices=['UG', 'KE'])",
225+
)
226+
p = get_pipeline(tmpdir)
227+
param_dict = p.to_dict()["parameters"][0]
228+
assert param_dict["choices"] == ["UG", "KE"]
229+
assert "choices_from_file" not in param_dict
230+
231+
def test_ast_string_no_extension_raises(self):
232+
with tempfile.TemporaryDirectory() as tmpdir:
233+
self._write_pipeline(
234+
tmpdir,
235+
"@parameter('district', type=str, choices='nodot')",
236+
)
237+
with self.assertRaises(InvalidParameterError):
238+
get_pipeline(tmpdir)
239+
240+
def test_ast_string_unsupported_extension_raises(self):
241+
with tempfile.TemporaryDirectory() as tmpdir:
242+
self._write_pipeline(
243+
tmpdir,
244+
"@parameter('district', type=str, choices='file.xlsx')",
245+
)
246+
with self.assertRaises(InvalidParameterError):
247+
get_pipeline(tmpdir)
248+
249+
60250
# ---------------------------------------------------------------------------
61251
# Parameter integration
62252
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)