Skip to content

Commit eb581d7

Browse files
authored
feat: Conditional parameters (HEXA-1687) (#396)
* feat: Conditional parameters (HEXA-1687) * move check to param init * add option to let user decide if disables when true or false * Rename 'disable_when' to 'disableWhen' for consistency * fix tests
1 parent 9638745 commit eb581d7

6 files changed

Lines changed: 264 additions & 0 deletions

File tree

openhexa/sdk/pipelines/parameter/decorator.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def __init__(
5151
required: bool = True,
5252
multiple: bool = False,
5353
directory: str | None = None,
54+
disables: typing.Sequence[str] | None = None,
55+
disable_when: bool = True,
5456
):
5557
validate_pipeline_parameter_code(code)
5658
self.code = code
@@ -92,6 +94,16 @@ def __init__(
9294
self.widget = widget
9395
self.connection = connection
9496
self.directory = directory
97+
self.disables = list(dict.fromkeys(disables)) if disables else None
98+
self.disable_when = disable_when
99+
if self.disables and not isinstance(self.type, Boolean):
100+
raise InvalidParameterError(
101+
f"Only boolean parameters can use 'disables'. Parameter '{self.code}' is of type {self.type}."
102+
)
103+
if not isinstance(self.disable_when, bool):
104+
raise InvalidParameterError(
105+
f"'disable_when' must be a boolean for parameter '{self.code}' (got {disable_when!r})."
106+
)
95107

96108
self._validate_default(default, multiple)
97109
self.default = default
@@ -117,6 +129,8 @@ def to_dict(self) -> dict[str, typing.Any]:
117129
"required": self.required,
118130
"multiple": self.multiple,
119131
"directory": self.directory,
132+
"disables": self.disables,
133+
"disableWhen": self.disable_when,
120134
}
121135
if isinstance(self.choices, ChoicesFromFile):
122136
d["choicesFromFile"] = self.choices.to_dict()
@@ -207,6 +221,24 @@ def validate_parameters(parameters: list[Parameter]):
207221
supported_connection_types = {DHIS2ConnectionType, IASOConnectionType}
208222
connection_parameters = {p.code for p in parameters if type(p.type) in supported_connection_types}
209223

224+
parameters_by_code = {p.code: p for p in parameters}
225+
controllers = {p.code for p in parameters if p.disables}
226+
for parameter in parameters:
227+
if not parameter.disables:
228+
continue
229+
for target_code in parameter.disables:
230+
if target_code == parameter.code:
231+
raise InvalidParameterError(f"Parameter '{parameter.code}' cannot disable itself.")
232+
if target_code not in parameters_by_code:
233+
raise InvalidParameterError(
234+
f"Parameter '{parameter.code}' disables a non-existing parameter '{target_code}'."
235+
)
236+
if target_code in controllers:
237+
raise InvalidParameterError(
238+
f"Parameter '{parameter.code}' disables '{target_code}', which is itself a disabling "
239+
f"parameter. Chaining disabling parameters is not supported."
240+
)
241+
210242
for parameter in parameters:
211243
if parameter.connection and parameter.connection not in connection_parameters:
212244
raise InvalidParameterError(
@@ -251,6 +283,8 @@ def parameter(
251283
required: bool = True,
252284
multiple: bool = False,
253285
directory: str | None = None,
286+
disables: typing.Sequence[str] | None = None,
287+
disable_when: bool = True,
254288
):
255289
"""Decorate a pipeline function by attaching a parameter to it..
256290
@@ -282,6 +316,14 @@ def parameter(
282316
values of the chosen type)
283317
directory : str, optional
284318
An optional parameter to force file selection to specific directory (only used for parameter type File). If the directory does not exist, it will be ignored.
319+
disables : sequence of str, optional
320+
An optional list of parameter codes to disable when this (boolean) parameter's value matches ``disable_when``.
321+
Disabled parameters are hidden/greyed out in the run form, their required check is skipped, and they are
322+
omitted from the run config (the pipeline function receives their default value). Only boolean parameters can
323+
use this.
324+
disable_when : bool, default=True
325+
The boolean value of this parameter that triggers the disabling of the parameters listed in ``disables``.
326+
Use ``disable_when=False`` for an "enable" toggle (the listed parameters are disabled while it is unticked).
285327
286328
Returns
287329
-------
@@ -305,6 +347,8 @@ def decorator(fun):
305347
connection=connection,
306348
multiple=multiple,
307349
directory=directory,
350+
disables=disables,
351+
disable_when=disable_when,
308352
),
309353
)
310354

openhexa/sdk/pipelines/pipeline.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,16 @@ def _validate_config(self, config: dict[str, typing.Any]) -> dict[str, typing.An
123123
ParameterValueError
124124
If the config contains invalid keys or parameter validation fails.
125125
"""
126+
disabled_codes = self._get_disabled_codes(config)
127+
126128
validated_config = {}
127129
for parameter in self.parameters:
128130
value = config.pop(parameter.code, None)
131+
if parameter.code in disabled_codes:
132+
# Parameter is disabled by an active controller: ignore the (possibly dummy or missing)
133+
# value, skip required/type validation, and fall back to its default.
134+
validated_config[parameter.code] = parameter.default
135+
continue
129136
validated_value = parameter.validate(value)
130137
validated_config[parameter.code] = validated_value
131138

@@ -134,6 +141,22 @@ def _validate_config(self, config: dict[str, typing.Any]) -> dict[str, typing.An
134141

135142
return validated_config
136143

144+
def _get_disabled_codes(self, config: dict[str, typing.Any]) -> set[str]:
145+
"""Return the codes of parameters disabled by an active controller in the given config.
146+
147+
A controller is a boolean parameter declaring ``disables=[...]``. It is "active" when its effective
148+
value (from the config, falling back to its default) equals its ``disable_when`` (``True`` by default).
149+
A parameter is disabled if any active controller lists it.
150+
"""
151+
disabled_codes: set[str] = set()
152+
for parameter in self.parameters:
153+
if not parameter.disables:
154+
continue
155+
effective_value = config.get(parameter.code, parameter.default)
156+
if bool(effective_value) == parameter.disable_when:
157+
disabled_codes.update(parameter.disables)
158+
return disabled_codes
159+
137160
def _execute_tasks(self, pool):
138161
"""Execute all tasks using the provided multiprocessing pool.
139162

openhexa/sdk/pipelines/runtime.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,8 @@ def get_pipeline(pipeline_path: Path) -> Pipeline:
316316
Argument("required", [ast.Constant], default_value=True),
317317
Argument("multiple", [ast.Constant], default_value=False),
318318
Argument("directory", [ast.Constant]),
319+
Argument("disables", [ast.List]),
320+
Argument("disable_when", [ast.Constant], default_value=True),
319321
),
320322
)
321323

tests/test_ast.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,13 +154,62 @@ def test_pipeline_with_int_param(self):
154154
"help": "Param help",
155155
"required": True,
156156
"directory": None,
157+
"disables": None,
158+
"disableWhen": True,
157159
}
158160
],
159161
"timeout": None,
160162
"functional_type": None,
161163
},
162164
)
163165

166+
def test_pipeline_with_disables_param(self):
167+
"""The @parameter decorator's 'disables' list is parsed from the pipeline code."""
168+
with tempfile.TemporaryDirectory() as tmpdirname:
169+
with open(f"{tmpdirname}/pipeline.py", "w") as f:
170+
f.write(
171+
"\n".join(
172+
[
173+
"from openhexa.sdk.pipelines import pipeline, parameter",
174+
"",
175+
"@parameter('run_report_only', type=bool, default=False, disables=['data_input'])",
176+
"@parameter('data_input', type=str)",
177+
"@pipeline('Test pipeline')",
178+
"def test_pipeline():",
179+
" pass",
180+
"",
181+
]
182+
)
183+
)
184+
pipeline = get_pipeline(tmpdirname)
185+
params = {p["code"]: p for p in pipeline.to_dict()["parameters"]}
186+
self.assertEqual(params["run_report_only"]["disables"], ["data_input"])
187+
self.assertEqual(params["run_report_only"]["disableWhen"], True)
188+
self.assertIsNone(params["data_input"]["disables"])
189+
190+
def test_pipeline_with_disable_when_false(self):
191+
"""The @parameter decorator's 'disable_when' is parsed from the pipeline code."""
192+
with tempfile.TemporaryDirectory() as tmpdirname:
193+
with open(f"{tmpdirname}/pipeline.py", "w") as f:
194+
f.write(
195+
"\n".join(
196+
[
197+
"from openhexa.sdk.pipelines import pipeline, parameter",
198+
"",
199+
"@parameter('enable_advanced', type=bool, default=False, disables=['tuning'], disable_when=False)",
200+
"@parameter('tuning', type=str)",
201+
"@pipeline('Test pipeline')",
202+
"def test_pipeline():",
203+
" pass",
204+
"",
205+
]
206+
)
207+
)
208+
pipeline = get_pipeline(tmpdirname)
209+
params = {p["code"]: p for p in pipeline.to_dict()["parameters"]}
210+
self.assertEqual(params["enable_advanced"]["disables"], ["tuning"])
211+
self.assertEqual(params["enable_advanced"]["disableWhen"], False)
212+
164213
def test_pipeline_with_multiple_param(self):
165214
"""The file contains a @pipeline decorator and a @parameter decorator with multiple=True."""
166215
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -198,6 +247,8 @@ def test_pipeline_with_multiple_param(self):
198247
"help": "Param help",
199248
"required": True,
200249
"directory": None,
250+
"disables": None,
251+
"disableWhen": True,
201252
}
202253
],
203254
"timeout": None,
@@ -243,6 +294,8 @@ def test_pipeline_with_dataset(self):
243294
"help": "Dataset",
244295
"required": False,
245296
"directory": None,
297+
"disables": None,
298+
"disableWhen": True,
246299
}
247300
],
248301
"timeout": None,
@@ -287,6 +340,8 @@ def test_pipeline_with_choices(self):
287340
"help": "Param help",
288341
"required": True,
289342
"directory": None,
343+
"disables": None,
344+
"disableWhen": True,
290345
}
291346
],
292347
"timeout": None,
@@ -359,6 +414,8 @@ def test_pipeline_with_bool(self):
359414
"help": "Param help",
360415
"required": True,
361416
"directory": None,
417+
"disables": None,
418+
"disableWhen": True,
362419
}
363420
],
364421
"timeout": None,
@@ -404,6 +461,8 @@ def test_pipeline_with_multiple_parameters(self):
404461
"help": "Param help",
405462
"required": True,
406463
"directory": None,
464+
"disables": None,
465+
"disableWhen": True,
407466
},
408467
{
409468
"choices": ["a", "b"],
@@ -417,6 +476,8 @@ def test_pipeline_with_multiple_parameters(self):
417476
"help": "Param help 2",
418477
"required": True,
419478
"directory": None,
479+
"disables": None,
480+
"disableWhen": True,
420481
},
421482
],
422483
"timeout": None,
@@ -484,6 +545,8 @@ def test_pipeline_with_connection_parameter_for_dhis2(self):
484545
"help": None,
485546
"required": True,
486547
"directory": None,
548+
"disables": None,
549+
"disableWhen": True,
487550
},
488551
{
489552
"code": "data_element_ids",
@@ -497,6 +560,8 @@ def test_pipeline_with_connection_parameter_for_dhis2(self):
497560
"help": None,
498561
"required": True,
499562
"directory": None,
563+
"disables": None,
564+
"disableWhen": True,
500565
},
501566
],
502567
"timeout": None,
@@ -546,6 +611,8 @@ def test_pipeline_with_connection_parameter_for_iaso(self):
546611
"help": None,
547612
"required": True,
548613
"directory": None,
614+
"disables": None,
615+
"disableWhen": True,
549616
},
550617
{
551618
"code": "org_units",
@@ -559,6 +626,8 @@ def test_pipeline_with_connection_parameter_for_iaso(self):
559626
"help": None,
560627
"required": True,
561628
"directory": None,
629+
"disables": None,
630+
"disableWhen": True,
562631
},
563632
{
564633
"code": "projects",
@@ -572,6 +641,8 @@ def test_pipeline_with_connection_parameter_for_iaso(self):
572641
"help": None,
573642
"required": True,
574643
"directory": None,
644+
"disables": None,
645+
"disableWhen": True,
575646
},
576647
{
577648
"code": "forms",
@@ -585,6 +656,8 @@ def test_pipeline_with_connection_parameter_for_iaso(self):
585656
"help": None,
586657
"required": True,
587658
"directory": None,
659+
"disables": None,
660+
"disableWhen": True,
588661
},
589662
],
590663
"timeout": None,

tests/test_parameter.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
SecretType,
3737
StringType,
3838
parameter,
39+
validate_parameters,
3940
)
4041
from openhexa.utils import stringcase
4142

@@ -422,3 +423,63 @@ def a_function():
422423
assert function_parameters[1].default == ["yo"]
423424
assert function_parameters[1].required is False
424425
assert function_parameters[1].multiple is True
426+
427+
428+
def test_parameter_disables_serialization():
429+
"""The 'disables' option is normalized to a list and serialized in to_dict."""
430+
no_disables = Parameter("plain", type=str)
431+
assert no_disables.disables is None
432+
assert no_disables.to_dict()["disables"] is None
433+
434+
controller = Parameter("run_report_only", type=bool, disables=["data_input", "year"])
435+
assert controller.disables == ["data_input", "year"]
436+
assert controller.to_dict()["disables"] == ["data_input", "year"]
437+
assert controller.to_dict()["disableWhen"] is True
438+
439+
440+
def test_parameter_disables_dedup_preserves_order():
441+
"""Duplicate disables targets are removed while keeping declaration order."""
442+
controller = Parameter("toggle", type=bool, disables=["b", "a", "b", "a"])
443+
assert controller.disables == ["b", "a"]
444+
445+
446+
def test_disable_when_must_be_boolean():
447+
"""'disable_when' must be a boolean — rejected at construction time."""
448+
with pytest.raises(InvalidParameterError):
449+
Parameter("toggle", type=bool, disables=["x_param"], disable_when="yes")
450+
451+
452+
def test_validate_parameters_disables_ok():
453+
"""A valid disabling setup passes validation."""
454+
controller = Parameter("run_report_only", type=bool, default=False, disables=["data_input"])
455+
data_input = Parameter("data_input", type=str, required=True)
456+
validate_parameters([controller, data_input])
457+
458+
459+
def test_disables_must_be_boolean():
460+
"""Only boolean parameters can use 'disables' — rejected at construction time."""
461+
with pytest.raises(InvalidParameterError):
462+
Parameter("mode", type=str, disables=["data_input"])
463+
464+
465+
def test_validate_parameters_disables_unknown_target():
466+
"""Disabling a non-existing parameter raises."""
467+
controller = Parameter("run_report_only", type=bool, disables=["does_not_exist"])
468+
with pytest.raises(InvalidParameterError):
469+
validate_parameters([controller])
470+
471+
472+
def test_validate_parameters_disables_self_reference():
473+
"""A parameter cannot disable itself."""
474+
controller = Parameter("run_report_only", type=bool, disables=["run_report_only"])
475+
with pytest.raises(InvalidParameterError):
476+
validate_parameters([controller])
477+
478+
479+
def test_validate_parameters_disables_no_chaining():
480+
"""A disabling parameter cannot disable another disabling parameter."""
481+
controller_a = Parameter("toggle_a", type=bool, disables=["toggle_b"])
482+
controller_b = Parameter("toggle_b", type=bool, disables=["plain_c"])
483+
plain_c = Parameter("plain_c", type=str)
484+
with pytest.raises(InvalidParameterError):
485+
validate_parameters([controller_a, controller_b, plain_c])

0 commit comments

Comments
 (0)