Skip to content

Commit 5bb1366

Browse files
authored
chore: Better Validation of the DHIS2 widgets (#245)
* chore(dhis2): add an example pipeline for dhis2 widgets * refactor(Parameter): rename validate_parameters_with_connection to validate_parameters and enhance validation for DHIS2 widgets requiring a connection parameter * Update pipeline.py docstring to clarify purpose as an example for retrieving data elements from DHIS2. * refactor: replace ParameterWidget with DHIS2Widget for improved clarity and organization in DHIS2 pipeline parameters * refactor: remove comment
1 parent 5b3e9e2 commit 5bb1366

5 files changed

Lines changed: 71 additions & 59 deletions

File tree

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
workspace/
2+
workspace.yaml
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""Example pipeline to get data elements from DHIS2."""
2+
3+
from openhexa.sdk import current_run, parameter, pipeline
4+
from openhexa.sdk.pipelines.parameter import DHIS2Widget
5+
from openhexa.sdk.workspaces.connection import DHIS2Connection
6+
7+
8+
@pipeline("dhis2")
9+
@parameter("dhis2_con", type=DHIS2Connection, required=True)
10+
@parameter(
11+
"data_elements",
12+
type=str,
13+
required=True,
14+
widget=DHIS2Widget.DATA_ELEMENTS,
15+
multiple=True,
16+
connection="dhis2_con",
17+
)
18+
def dhis2(dhis2_con, data_elements):
19+
"""Get data elements from DHIS2."""
20+
print_data_elements(dhis2_con, data_elements)
21+
22+
23+
@dhis2.task
24+
def print_data_elements(dhis2_con, data_elements):
25+
"""Print data elements."""
26+
current_run.log_info("Printing data elements")
27+
28+
current_run.log_info(f"Data elements: {data_elements}")
29+
30+
31+
if __name__ == "__main__":
32+
dhis2()

openhexa/sdk/pipelines/parameter.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -359,23 +359,17 @@ def validate(self, value: typing.Any | None) -> Dataset:
359359
}
360360

361361

362-
class ParameterWidget(StrEnum):
363-
"""
364-
Enum for parameter available parameter widgets.
365-
366-
The list of supported widgets can be found in the OpenHEXA documentation or GraphQL schema.
367-
/graphql/ -> __schema -> types -> ParameterWidget
368-
https://github.com/blsq/openhexa/wiki/Writing-OpenHEXA-pipelines#using-widget-parameters
369-
"""
362+
class DHIS2Widget(StrEnum):
363+
"""Enum for DHIS2 widgets."""
370364

371-
DHIS2_ORG_UNITS = "DHIS2_ORG_UNITS"
372-
DHIS2_ORG_UNIT_GROUPS = "DHIS2_ORG_UNIT_GROUPS"
373-
DHIS2_ORG_UNIT_LEVELS = "DHIS2_ORG_UNIT_LEVELS"
374-
DHIS2_DATASETS = "DHIS2_DATASETS"
375-
DHIS2_DATA_ELEMENTS = "DHIS2_DATA_ELEMENTS"
376-
DHIS2_DATA_ELEMENT_GROUPS = "DHIS2_DATA_ELEMENT_GROUPS"
377-
DHIS2_INDICATORS = "DHIS2_INDICATORS"
378-
DHIS2_INDICATOR_GROUPS = "DHIS2_INDICATOR_GROUPS"
365+
ORG_UNITS = "DHIS2_ORG_UNITS"
366+
ORG_UNIT_GROUPS = "DHIS2_ORG_UNIT_GROUPS"
367+
ORG_UNIT_LEVELS = "DHIS2_ORG_UNIT_LEVELS"
368+
DATASETS = "DHIS2_DATASETS"
369+
DATA_ELEMENTS = "DHIS2_DATA_ELEMENTS"
370+
DATA_ELEMENT_GROUPS = "DHIS2_DATA_ELEMENT_GROUPS"
371+
INDICATORS = "DHIS2_INDICATORS"
372+
INDICATOR_GROUPS = "DHIS2_INDICATOR_GROUPS"
379373

380374

381375
class Parameter:
@@ -399,13 +393,12 @@ def __init__(
399393
choices: typing.Sequence | None = None,
400394
help: str | None = None,
401395
default: typing.Any | None = None,
402-
widget: ParameterWidget | None = None,
396+
widget: DHIS2Widget | None = None,
403397
connection: str | None = None,
404398
required: bool = True,
405399
multiple: bool = False,
406400
):
407401
validate_pipeline_parameter_code(code)
408-
409402
self.code = code
410403

411404
try:
@@ -460,7 +453,7 @@ def to_dict(self) -> dict[str, typing.Any]:
460453
"choices": self.choices,
461454
"help": self.help,
462455
"default": self.default,
463-
"widget": self.widget.value if self.widget else None,
456+
"widget": self.widget if self.widget else None,
464457
"connection": self.connection,
465458
"required": self.required,
466459
"multiple": self.multiple,
@@ -536,7 +529,7 @@ def _validate_default(self, default: typing.Any, multiple: bool):
536529
)
537530

538531

539-
def validate_parameters_with_connection(parameters: [Parameter]):
532+
def validate_parameters(parameters: list[Parameter]):
540533
"""Validate the provided connection parameters if they relate to existing connection parameter."""
541534
supported_connection_types = {DHIS2ConnectionType}
542535
connection_parameters = {p.code for p in parameters if type(p.type) in supported_connection_types}
@@ -546,6 +539,11 @@ def validate_parameters_with_connection(parameters: [Parameter]):
546539
raise InvalidParameterError(
547540
f"Connection field '{parameter.code}' references a non-existing connection parameter '{parameter.connection}'"
548541
)
542+
if parameter.widget and parameter.widget in DHIS2Widget and not parameter.connection:
543+
raise InvalidParameterError(
544+
f"DHIS2 widgets require a connection parameter. Please provide a connection parameter for {parameter.code}. "
545+
f"Example: @parameter('{parameter.code}', type=str, widget=DHIS2Widget.{parameter.widget}, connection='my_connection')"
546+
)
549547

550548

551549
def parameter(
@@ -565,7 +563,7 @@ def parameter(
565563
name: str | None = None,
566564
choices: typing.Sequence | None = None,
567565
help: str | None = None,
568-
widget: ParameterWidget | None = None,
566+
widget: DHIS2Widget | None = None,
569567
default: typing.Any | None = None,
570568
required: bool = True,
571569
multiple: bool = False,

openhexa/sdk/pipelines/runtime.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
from openhexa.sdk.pipelines.exceptions import InvalidParameterError, PipelineNotFound
1717
from openhexa.sdk.pipelines.parameter import (
1818
TYPES_BY_PYTHON_TYPE,
19+
DHIS2Widget,
1920
Parameter,
20-
ParameterWidget,
21-
validate_parameters_with_connection,
21+
validate_parameters,
2222
)
2323

2424
from .pipeline import Pipeline
@@ -168,7 +168,10 @@ def _get_decorator_arg_value(decorator: ast.Call, arg: Argument, index: int) ->
168168
elif isinstance(keyword.value, ast.List):
169169
return ([el.value for el in keyword.value.elts], True)
170170
elif isinstance(keyword.value, ast.Attribute):
171-
return (ParameterWidget(keyword.value.attr), True)
171+
if keyword.value.attr in DHIS2Widget.__members__:
172+
return getattr(DHIS2Widget, keyword.value.attr), True
173+
else:
174+
raise ValueError(f"Unsupported widget: {keyword.value.attr}")
172175

173176
# Then check for positional arguments
174177
try:
@@ -303,8 +306,7 @@ def get_pipeline(pipeline_path: Path) -> Pipeline:
303306
except KeyError as e:
304307
raise InvalidParameterError(f"Missing required parameter attribute: {e}")
305308

306-
# Validate parameters with connections
307-
validate_parameters_with_connection(pipeline_parameters)
309+
validate_parameters(pipeline_parameters)
308310

309311
# Create and return the pipeline
310312
return Pipeline(

tests/test_ast.py

Lines changed: 11 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from unittest.mock import patch
77

88
from openhexa.sdk.pipelines.exceptions import InvalidParameterError, PipelineNotFound
9-
from openhexa.sdk.pipelines.parameter import ParameterWidget
9+
from openhexa.sdk.pipelines.parameter import DHIS2Widget
1010
from openhexa.sdk.pipelines.runtime import get_pipeline
1111

1212

@@ -436,10 +436,10 @@ def test_pipeline_with_connection_parameter(self):
436436
"\n".join(
437437
[
438438
"from openhexa.sdk.pipelines import pipeline, parameter",
439-
"from openhexa.sdk.pipelines.widgets import ParameterWidget",
439+
"from openhexa.sdk.pipelines.widgets import DHIS2Widget",
440440
"",
441441
"@parameter('dhis_con', name='DHIS2 Connection', type=DHIS2Connection, required=True)",
442-
"@parameter('data_element_ids', name='Data Elements id', type=str, widget=ParameterWidget.DHIS2_ORG_UNITS, connection='dhis_con', required=True)",
442+
"@parameter('data_element_ids', name='Data Elements id', type=str, widget=DHIS2Widget.ORG_UNITS, connection='dhis_con', required=True)",
443443
"@pipeline('Test pipeline')",
444444
"def test_pipeline():",
445445
" pass",
@@ -472,7 +472,7 @@ def test_pipeline_with_connection_parameter(self):
472472
"code": "data_element_ids",
473473
"type": "str",
474474
"name": "Data Elements id",
475-
"widget": ParameterWidget.DHIS2_ORG_UNITS.value,
475+
"widget": DHIS2Widget.ORG_UNITS.value,
476476
"connection": "dhis_con",
477477
"default": None,
478478
"multiple": False,
@@ -493,11 +493,11 @@ def test_pipeline_wit_wrong_connection_parameter(self):
493493
"\n".join(
494494
[
495495
"from openhexa.sdk.pipelines import pipeline, parameter",
496-
"from openhexa.sdk.pipelines.parameter import ParameterWidget",
496+
"from openhexa.sdk.pipelines.parameter import DHIS2Widget",
497497
"",
498498
"@parameter('dhis_con', name='DHIS2 Connection', type=DHIS2Connection, required=True)",
499499
"@pipeline('Test pipeline')",
500-
"@parameter('data_element_ids', name='Data Elements id', type=str, widget=ParameterWidget.DHIS2_ORG_UNITS, connection='sds_con', required=True)",
500+
"@parameter('data_element_ids', name='Data Elements id', type=str, widget=DHIS2Widget.ORG_UNITS, connection='sds_con', required=True)",
501501
"def test_pipeline():",
502502
" pass",
503503
"",
@@ -507,48 +507,26 @@ def test_pipeline_wit_wrong_connection_parameter(self):
507507
with self.assertRaises(InvalidParameterError):
508508
get_pipeline(tmpdirname)
509509

510-
def test_pipeline_with_widget_without_connection(self):
510+
def test_pipeline_with_dhis2_widget_without_connection(self):
511511
"""The file contains a @pipeline decorator and a @parameter decorator with a widget parameter field."""
512512
with tempfile.TemporaryDirectory() as tmpdirname:
513513
with open(f"{tmpdirname}/pipeline.py", "w") as f:
514514
f.write(
515515
"\n".join(
516516
[
517517
"from openhexa.sdk.pipelines import pipeline, parameter",
518-
"from openhexa.sdk.pipelines.parameter import ParameterWidget",
518+
"from openhexa.sdk.pipelines.parameter import DHIS2Widget",
519519
"",
520-
"@parameter('test_field_for_widget', name='Widget Param', type=str, widget=ParameterWidget.DHIS2_ORG_UNITS, help='Param help')",
520+
"@parameter('test_field_for_widget', name='Widget Param', type=str, widget=DHIS2Widget.ORG_UNITS, help='Param help')",
521521
"@pipeline('Test pipeline')",
522522
"def test_pipeline():",
523523
" pass",
524524
"",
525525
]
526526
)
527527
)
528-
pipeline = get_pipeline(tmpdirname)
529-
self.assertEqual(
530-
pipeline.to_dict(),
531-
{
532-
"name": "Test pipeline",
533-
"function": None,
534-
"tasks": [],
535-
"parameters": [
536-
{
537-
"code": "test_field_for_widget",
538-
"type": "str",
539-
"name": "Widget Param",
540-
"default": None,
541-
"multiple": False,
542-
"choices": None,
543-
"widget": ParameterWidget.DHIS2_ORG_UNITS.value,
544-
"connection": None,
545-
"help": "Param help",
546-
"required": True,
547-
}
548-
],
549-
"timeout": None,
550-
},
551-
)
528+
with self.assertRaises(InvalidParameterError):
529+
get_pipeline(tmpdirname)
552530

553531
def test_pipeline_with_deprecated_code_argument_with_name(self):
554532
"""The file contains a @pipeline decorator with the deprecated 'code' argument."""

0 commit comments

Comments
 (0)