Skip to content

Commit 08533b3

Browse files
committed
feat: Conditional parameters (HEXA-1687)
1 parent ad4e725 commit 08533b3

6 files changed

Lines changed: 176 additions & 0 deletions

File tree

openhexa/sdk/pipelines/parameter/decorator.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(
5151
required: bool = True,
5252
multiple: bool = False,
5353
directory: str | None = None,
54+
disables: typing.Sequence[str] | None = None,
5455
):
5556
validate_pipeline_parameter_code(code)
5657
self.code = code
@@ -92,6 +93,7 @@ def __init__(
9293
self.widget = widget
9394
self.connection = connection
9495
self.directory = directory
96+
self.disables = list(disables) if disables else None
9597

9698
self._validate_default(default, multiple)
9799
self.default = default
@@ -117,6 +119,7 @@ def to_dict(self) -> dict[str, typing.Any]:
117119
"required": self.required,
118120
"multiple": self.multiple,
119121
"directory": self.directory,
122+
"disables": self.disables,
120123
}
121124
if isinstance(self.choices, ChoicesFromFile):
122125
d["choices_from_file"] = self.choices.to_dict()
@@ -207,6 +210,28 @@ def validate_parameters(parameters: list[Parameter]):
207210
supported_connection_types = {DHIS2ConnectionType, IASOConnectionType}
208211
connection_parameters = {p.code for p in parameters if type(p.type) in supported_connection_types}
209212

213+
parameters_by_code = {p.code: p for p in parameters}
214+
controllers = {p.code for p in parameters if p.disables}
215+
for parameter in parameters:
216+
if not parameter.disables:
217+
continue
218+
if not isinstance(parameter.type, Boolean):
219+
raise InvalidParameterError(
220+
f"Only boolean parameters can use 'disables'. Parameter '{parameter.code}' is of type {parameter.type}."
221+
)
222+
for target_code in parameter.disables:
223+
if target_code == parameter.code:
224+
raise InvalidParameterError(f"Parameter '{parameter.code}' cannot disable itself.")
225+
if target_code not in parameters_by_code:
226+
raise InvalidParameterError(
227+
f"Parameter '{parameter.code}' disables a non-existing parameter '{target_code}'."
228+
)
229+
if target_code in controllers:
230+
raise InvalidParameterError(
231+
f"Parameter '{parameter.code}' disables '{target_code}', which is itself a disabling "
232+
f"parameter. Chaining disabling parameters is not supported."
233+
)
234+
210235
for parameter in parameters:
211236
if parameter.connection and parameter.connection not in connection_parameters:
212237
raise InvalidParameterError(
@@ -251,6 +276,7 @@ def parameter(
251276
required: bool = True,
252277
multiple: bool = False,
253278
directory: str | None = None,
279+
disables: typing.Sequence[str] | None = None,
254280
):
255281
"""Decorate a pipeline function by attaching a parameter to it..
256282
@@ -282,6 +308,10 @@ def parameter(
282308
values of the chosen type)
283309
directory : str, optional
284310
An optional parameter to force file selection to specific directory (only used for parameter type File). If the directory does not exist, it will be ignored.
311+
disables : sequence of str, optional
312+
An optional list of parameter codes to disable when this (boolean) parameter is set to ``True``. Disabled
313+
parameters are hidden/greyed out in the run form, their required check is skipped, and they are omitted from
314+
the run config (the pipeline function receives their default value). Only boolean parameters can use this.
285315
286316
Returns
287317
-------
@@ -305,6 +335,7 @@ def decorator(fun):
305335
connection=connection,
306336
multiple=multiple,
307337
directory=directory,
338+
disables=disables,
308339
),
309340
)
310341

openhexa/sdk/pipelines/pipeline.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,16 @@ def _validate_config(self, config: dict[str, typing.Any]) -> dict[str, typing.An
123123
ParameterValueError
124124
If the config contains invalid keys or parameter validation fails.
125125
"""
126+
disabled_codes = self._get_disabled_codes(config)
127+
126128
validated_config = {}
127129
for parameter in self.parameters:
128130
value = config.pop(parameter.code, None)
131+
if parameter.code in disabled_codes:
132+
# Parameter is disabled by an active controller: ignore the (possibly dummy or missing)
133+
# value, skip required/type validation, and fall back to its default.
134+
validated_config[parameter.code] = parameter.default
135+
continue
129136
validated_value = parameter.validate(value)
130137
validated_config[parameter.code] = validated_value
131138

@@ -134,6 +141,22 @@ def _validate_config(self, config: dict[str, typing.Any]) -> dict[str, typing.An
134141

135142
return validated_config
136143

144+
def _get_disabled_codes(self, config: dict[str, typing.Any]) -> set[str]:
145+
"""Return the codes of parameters disabled by an active controller in the given config.
146+
147+
A controller is a boolean parameter declaring ``disables=[...]``. It is "active" when its effective
148+
value (from the config, falling back to its default) is truthy. A parameter is disabled if any active
149+
controller lists it.
150+
"""
151+
disabled_codes: set[str] = set()
152+
for parameter in self.parameters:
153+
if not parameter.disables:
154+
continue
155+
effective_value = config.get(parameter.code, parameter.default)
156+
if effective_value:
157+
disabled_codes.update(parameter.disables)
158+
return disabled_codes
159+
137160
def _execute_tasks(self, pool):
138161
"""Execute all tasks using the provided multiprocessing pool.
139162

openhexa/sdk/pipelines/runtime.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,7 @@ def get_pipeline(pipeline_path: Path) -> Pipeline:
316316
Argument("required", [ast.Constant], default_value=True),
317317
Argument("multiple", [ast.Constant], default_value=False),
318318
Argument("directory", [ast.Constant]),
319+
Argument("disables", [ast.List]),
319320
),
320321
)
321322

tests/test_ast.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,13 +154,37 @@ def test_pipeline_with_int_param(self):
154154
"help": "Param help",
155155
"required": True,
156156
"directory": None,
157+
"disables": None,
157158
}
158159
],
159160
"timeout": None,
160161
"functional_type": None,
161162
},
162163
)
163164

165+
def test_pipeline_with_disables_param(self):
166+
"""The @parameter decorator's 'disables' list is parsed from the pipeline code."""
167+
with tempfile.TemporaryDirectory() as tmpdirname:
168+
with open(f"{tmpdirname}/pipeline.py", "w") as f:
169+
f.write(
170+
"\n".join(
171+
[
172+
"from openhexa.sdk.pipelines import pipeline, parameter",
173+
"",
174+
"@parameter('run_report_only', type=bool, default=False, disables=['data_input'])",
175+
"@parameter('data_input', type=str)",
176+
"@pipeline('Test pipeline')",
177+
"def test_pipeline():",
178+
" pass",
179+
"",
180+
]
181+
)
182+
)
183+
pipeline = get_pipeline(tmpdirname)
184+
params = {p["code"]: p for p in pipeline.to_dict()["parameters"]}
185+
self.assertEqual(params["run_report_only"]["disables"], ["data_input"])
186+
self.assertIsNone(params["data_input"]["disables"])
187+
164188
def test_pipeline_with_multiple_param(self):
165189
"""The file contains a @pipeline decorator and a @parameter decorator with multiple=True."""
166190
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -198,6 +222,7 @@ def test_pipeline_with_multiple_param(self):
198222
"help": "Param help",
199223
"required": True,
200224
"directory": None,
225+
"disables": None,
201226
}
202227
],
203228
"timeout": None,
@@ -243,6 +268,7 @@ def test_pipeline_with_dataset(self):
243268
"help": "Dataset",
244269
"required": False,
245270
"directory": None,
271+
"disables": None,
246272
}
247273
],
248274
"timeout": None,
@@ -287,6 +313,7 @@ def test_pipeline_with_choices(self):
287313
"help": "Param help",
288314
"required": True,
289315
"directory": None,
316+
"disables": None,
290317
}
291318
],
292319
"timeout": None,
@@ -359,6 +386,7 @@ def test_pipeline_with_bool(self):
359386
"help": "Param help",
360387
"required": True,
361388
"directory": None,
389+
"disables": None,
362390
}
363391
],
364392
"timeout": None,
@@ -404,6 +432,7 @@ def test_pipeline_with_multiple_parameters(self):
404432
"help": "Param help",
405433
"required": True,
406434
"directory": None,
435+
"disables": None,
407436
},
408437
{
409438
"choices": ["a", "b"],
@@ -417,6 +446,7 @@ def test_pipeline_with_multiple_parameters(self):
417446
"help": "Param help 2",
418447
"required": True,
419448
"directory": None,
449+
"disables": None,
420450
},
421451
],
422452
"timeout": None,
@@ -484,6 +514,7 @@ def test_pipeline_with_connection_parameter_for_dhis2(self):
484514
"help": None,
485515
"required": True,
486516
"directory": None,
517+
"disables": None,
487518
},
488519
{
489520
"code": "data_element_ids",
@@ -497,6 +528,7 @@ def test_pipeline_with_connection_parameter_for_dhis2(self):
497528
"help": None,
498529
"required": True,
499530
"directory": None,
531+
"disables": None,
500532
},
501533
],
502534
"timeout": None,
@@ -546,6 +578,7 @@ def test_pipeline_with_connection_parameter_for_iaso(self):
546578
"help": None,
547579
"required": True,
548580
"directory": None,
581+
"disables": None,
549582
},
550583
{
551584
"code": "org_units",
@@ -559,6 +592,7 @@ def test_pipeline_with_connection_parameter_for_iaso(self):
559592
"help": None,
560593
"required": True,
561594
"directory": None,
595+
"disables": None,
562596
},
563597
{
564598
"code": "projects",
@@ -572,6 +606,7 @@ def test_pipeline_with_connection_parameter_for_iaso(self):
572606
"help": None,
573607
"required": True,
574608
"directory": None,
609+
"disables": None,
575610
},
576611
{
577612
"code": "forms",
@@ -585,6 +620,7 @@ def test_pipeline_with_connection_parameter_for_iaso(self):
585620
"help": None,
586621
"required": True,
587622
"directory": None,
623+
"disables": None,
588624
},
589625
],
590626
"timeout": None,

tests/test_parameter.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
SecretType,
3737
StringType,
3838
parameter,
39+
validate_parameters,
3940
)
4041
from openhexa.utils import stringcase
4142

@@ -422,3 +423,52 @@ def a_function():
422423
assert function_parameters[1].default == ["yo"]
423424
assert function_parameters[1].required is False
424425
assert function_parameters[1].multiple is True
426+
427+
428+
def test_parameter_disables_serialization():
429+
"""The 'disables' option is normalized to a list and serialized in to_dict."""
430+
no_disables = Parameter("plain", type=str)
431+
assert no_disables.disables is None
432+
assert no_disables.to_dict()["disables"] is None
433+
434+
controller = Parameter("run_report_only", type=bool, disables=["data_input", "year"])
435+
assert controller.disables == ["data_input", "year"]
436+
assert controller.to_dict()["disables"] == ["data_input", "year"]
437+
438+
439+
def test_validate_parameters_disables_ok():
440+
"""A valid disabling setup passes validation."""
441+
controller = Parameter("run_report_only", type=bool, default=False, disables=["data_input"])
442+
data_input = Parameter("data_input", type=str, required=True)
443+
validate_parameters([controller, data_input])
444+
445+
446+
def test_validate_parameters_disables_must_be_boolean():
447+
"""Only boolean parameters can use 'disables'."""
448+
controller = Parameter("mode", type=str, disables=["data_input"])
449+
data_input = Parameter("data_input", type=str)
450+
with pytest.raises(InvalidParameterError):
451+
validate_parameters([controller, data_input])
452+
453+
454+
def test_validate_parameters_disables_unknown_target():
455+
"""Disabling a non-existing parameter raises."""
456+
controller = Parameter("run_report_only", type=bool, disables=["does_not_exist"])
457+
with pytest.raises(InvalidParameterError):
458+
validate_parameters([controller])
459+
460+
461+
def test_validate_parameters_disables_self_reference():
462+
"""A parameter cannot disable itself."""
463+
controller = Parameter("run_report_only", type=bool, disables=["run_report_only"])
464+
with pytest.raises(InvalidParameterError):
465+
validate_parameters([controller])
466+
467+
468+
def test_validate_parameters_disables_no_chaining():
469+
"""A disabling parameter cannot disable another disabling parameter."""
470+
controller_a = Parameter("toggle_a", type=bool, disables=["toggle_b"])
471+
controller_b = Parameter("toggle_b", type=bool, disables=["plain_c"])
472+
plain_c = Parameter("plain_c", type=str)
473+
with pytest.raises(InvalidParameterError):
474+
validate_parameters([controller_a, controller_b, plain_c])

tests/test_pipeline.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,41 @@ def test_pipeline_run_extra_config():
5252
pipeline.run({"arg1": "ok", "arg2": "extra"})
5353

5454

55+
def test_pipeline_run_disabled_required_parameter_skipped():
56+
"""A required parameter disabled by an active controller is skipped and receives its default."""
57+
pipeline_func = Mock()
58+
controller = Parameter("run_report_only", type=bool, default=False, disables=["data_input"])
59+
data_input = Parameter("data_input", type=str, required=True)
60+
pipeline = Pipeline("pipeline", pipeline_func, [controller, data_input])
61+
62+
pipeline.run({"run_report_only": True})
63+
64+
pipeline_func.assert_called_once_with(run_report_only=True, data_input=None)
65+
66+
67+
def test_pipeline_run_disabled_parameter_value_ignored():
68+
"""A dummy value provided for a disabled parameter is ignored in favor of its default."""
69+
pipeline_func = Mock()
70+
controller = Parameter("run_report_only", type=bool, default=False, disables=["year"])
71+
year = Parameter("year", type=int, required=True, default=2024)
72+
pipeline = Pipeline("pipeline", pipeline_func, [controller, year])
73+
74+
pipeline.run({"run_report_only": True, "year": 1})
75+
76+
pipeline_func.assert_called_once_with(run_report_only=True, year=2024)
77+
78+
79+
def test_pipeline_run_inactive_controller_still_validates():
80+
"""When the controller is not active, disabled parameters are still validated as usual."""
81+
pipeline_func = Mock()
82+
controller = Parameter("run_report_only", type=bool, default=False, disables=["data_input"])
83+
data_input = Parameter("data_input", type=str, required=True)
84+
pipeline = Pipeline("pipeline", pipeline_func, [controller, data_input])
85+
86+
with pytest.raises(ParameterValueError):
87+
pipeline.run({"run_report_only": False})
88+
89+
5590
@patch.dict(
5691
os.environ,
5792
{"HEXA_SERVER_URL": "https://test.openhexa.org"},

0 commit comments

Comments
 (0)