|
| 1 | +"""Parameter class, decorator, and validation logic for pipeline parameters.""" |
| 2 | + |
| 3 | +import typing |
| 4 | + |
| 5 | +from openhexa.sdk.datasets import Dataset |
| 6 | +from openhexa.sdk.files import File |
| 7 | +from openhexa.sdk.pipelines.exceptions import InvalidParameterError, ParameterValueError |
| 8 | +from openhexa.sdk.pipelines.utils import validate_pipeline_parameter_code |
| 9 | +from openhexa.sdk.workspaces.connection import ( |
| 10 | + CustomConnection, |
| 11 | + DHIS2Connection, |
| 12 | + GCSConnection, |
| 13 | + IASOConnection, |
| 14 | + PostgreSQLConnection, |
| 15 | + S3Connection, |
| 16 | +) |
| 17 | + |
| 18 | +from .types import TYPES_BY_PYTHON_TYPE, Boolean, DHIS2ConnectionType, IASOConnectionType, Secret |
| 19 | +from .widgets import DHIS2Widget, IASOWidget |
| 20 | + |
| 21 | + |
| 22 | +class Parameter: |
| 23 | + """Pipeline parameter class. Contains validation logic specs generation logic.""" |
| 24 | + |
| 25 | + def __init__( |
| 26 | + self, |
| 27 | + code: str, |
| 28 | + *, |
| 29 | + type: type[ |
| 30 | + str |
| 31 | + | int |
| 32 | + | bool |
| 33 | + | float |
| 34 | + | Secret |
| 35 | + | DHIS2Connection |
| 36 | + | IASOConnection |
| 37 | + | PostgreSQLConnection |
| 38 | + | GCSConnection |
| 39 | + | S3Connection |
| 40 | + | CustomConnection |
| 41 | + | Dataset |
| 42 | + | File |
| 43 | + ], |
| 44 | + name: str | None = None, |
| 45 | + choices: typing.Sequence | None = None, |
| 46 | + help: str | None = None, |
| 47 | + default: typing.Any | None = None, |
| 48 | + widget: DHIS2Widget | IASOWidget | None = None, |
| 49 | + connection: str | None = None, |
| 50 | + required: bool = True, |
| 51 | + multiple: bool = False, |
| 52 | + directory: str | None = None, |
| 53 | + ): |
| 54 | + validate_pipeline_parameter_code(code) |
| 55 | + self.code = code |
| 56 | + |
| 57 | + try: |
| 58 | + self.type = TYPES_BY_PYTHON_TYPE[type.__name__]() |
| 59 | + except (KeyError, AttributeError): |
| 60 | + valid_parameter_types = [k for k in TYPES_BY_PYTHON_TYPE.keys()] |
| 61 | + raise InvalidParameterError( |
| 62 | + f"Invalid parameter type provided ({type}). " |
| 63 | + f"Valid parameter types are {', '.join(valid_parameter_types)}" |
| 64 | + ) |
| 65 | + |
| 66 | + if choices is not None: |
| 67 | + if not self.type.accepts_choices: |
| 68 | + raise InvalidParameterError(f"Parameters of type {self.type} don't accept choices.") |
| 69 | + elif len(choices) == 0: |
| 70 | + raise InvalidParameterError("Choices, if provided, cannot be empty.") |
| 71 | + |
| 72 | + try: |
| 73 | + for choice in choices: |
| 74 | + self.type.validate(choice) |
| 75 | + except ParameterValueError: |
| 76 | + raise InvalidParameterError(f"The provided choices are not valid for the {self.type} parameter type.") |
| 77 | + self.choices = choices |
| 78 | + |
| 79 | + self.name = name |
| 80 | + self.help = help |
| 81 | + self.required = required |
| 82 | + |
| 83 | + if multiple is True and not self.type.accepts_multiple: |
| 84 | + raise InvalidParameterError(f"Parameters of type {self.type} can't have multiple values.") |
| 85 | + self.multiple = multiple |
| 86 | + |
| 87 | + self.widget = widget |
| 88 | + self.connection = connection |
| 89 | + self.directory = directory |
| 90 | + |
| 91 | + self._validate_default(default, multiple) |
| 92 | + self.default = default |
| 93 | + |
| 94 | + def validate(self, value: typing.Any) -> typing.Any: |
| 95 | + """Validate the provided value against the parameter, taking required / default options into account.""" |
| 96 | + if self.multiple: |
| 97 | + return self._validate_multiple(value) |
| 98 | + else: |
| 99 | + return self._validate_single(value) |
| 100 | + |
| 101 | + def to_dict(self) -> dict[str, typing.Any]: |
| 102 | + """Return a dictionary representation of the Parameter instance.""" |
| 103 | + return { |
| 104 | + "code": self.code, |
| 105 | + "type": self.type.spec_type, |
| 106 | + "name": self.name, |
| 107 | + "choices": self.choices, |
| 108 | + "help": self.help, |
| 109 | + "default": self.default, |
| 110 | + "widget": self.widget.value if self.widget else None, |
| 111 | + "connection": self.connection, |
| 112 | + "required": self.required, |
| 113 | + "multiple": self.multiple, |
| 114 | + "directory": self.directory, |
| 115 | + } |
| 116 | + |
| 117 | + def _validate_single(self, value: typing.Any): |
| 118 | + # Normalize empty values to None and handles default |
| 119 | + normalized_value = self.type.normalize(value) |
| 120 | + if normalized_value is None and self.default is not None: |
| 121 | + normalized_value = self.default |
| 122 | + |
| 123 | + if normalized_value is None: |
| 124 | + if isinstance(self.type, Boolean): |
| 125 | + normalized_value = False |
| 126 | + elif self.required: |
| 127 | + raise ParameterValueError(f"{self.code} is required") |
| 128 | + else: |
| 129 | + return None |
| 130 | + |
| 131 | + pre_validated = self.type.validate(normalized_value) |
| 132 | + if self.choices is not None and pre_validated not in self.choices: |
| 133 | + raise ParameterValueError(f"The provided value for {self.code} is not included in the provided choices.") |
| 134 | + |
| 135 | + return pre_validated |
| 136 | + |
| 137 | + def _validate_multiple(self, value: typing.Any): |
| 138 | + # Reject values that are not lists |
| 139 | + if value is not None and not isinstance(value, list): |
| 140 | + raise InvalidParameterError("If provided, value should be a list when parameter is multiple.") |
| 141 | + |
| 142 | + # Normalize empty values to an empty list |
| 143 | + if value is None: |
| 144 | + normalized_value = [] |
| 145 | + else: |
| 146 | + normalized_value = [self.type.normalize(v) for v in value] |
| 147 | + normalized_value = list(filter(lambda v: v is not None, normalized_value)) |
| 148 | + if len(normalized_value) == 0 and self.default is not None: |
| 149 | + normalized_value = self.default |
| 150 | + |
| 151 | + if len(normalized_value) == 0 and self.required: |
| 152 | + raise ParameterValueError(f"{self.code} is required") |
| 153 | + |
| 154 | + pre_validated = [self.type.validate(single_value) for single_value in normalized_value] |
| 155 | + if self.choices is not None and any(v not in self.choices for v in pre_validated): |
| 156 | + raise ParameterValueError( |
| 157 | + f"One of the provided values for {self.code} is not included in the provided choices." |
| 158 | + ) |
| 159 | + |
| 160 | + return pre_validated |
| 161 | + |
| 162 | + def _validate_default(self, default: typing.Any, multiple: bool): |
| 163 | + if default is None: |
| 164 | + return |
| 165 | + |
| 166 | + try: |
| 167 | + if multiple: |
| 168 | + if not isinstance(default, list): |
| 169 | + raise InvalidParameterError("Default values should be lists when using multiple=True") |
| 170 | + for default_value in default: |
| 171 | + self.type.validate_default(default_value) |
| 172 | + else: |
| 173 | + self.type.validate_default(default) |
| 174 | + except ParameterValueError: |
| 175 | + raise InvalidParameterError(f"The default value for {self.code} is not valid.") |
| 176 | + |
| 177 | + if self.choices is not None: |
| 178 | + if isinstance(default, list): |
| 179 | + if not all(d in self.choices for d in default): |
| 180 | + raise InvalidParameterError( |
| 181 | + f"The default list of values for {self.code} is not included in the provided choices." |
| 182 | + ) |
| 183 | + elif default not in self.choices: |
| 184 | + raise InvalidParameterError( |
| 185 | + f"The default value for {self.code} is not included in the provided choices." |
| 186 | + ) |
| 187 | + |
| 188 | + |
| 189 | +def validate_parameters(parameters: list[Parameter]): |
| 190 | + """Validate the provided connection parameters if they relate to existing connection parameter.""" |
| 191 | + supported_connection_types = {DHIS2ConnectionType, IASOConnectionType} |
| 192 | + connection_parameters = {p.code for p in parameters if type(p.type) in supported_connection_types} |
| 193 | + |
| 194 | + for parameter in parameters: |
| 195 | + if parameter.connection and parameter.connection not in connection_parameters: |
| 196 | + raise InvalidParameterError( |
| 197 | + f"Connection field '{parameter.code}' references a non-existing connection parameter '{parameter.connection}'" |
| 198 | + ) |
| 199 | + if ( |
| 200 | + parameter.widget |
| 201 | + and (parameter.widget in DHIS2Widget or parameter.widget in IASOWidget) |
| 202 | + and not parameter.connection |
| 203 | + ): |
| 204 | + raise InvalidParameterError( |
| 205 | + f"Widgets require a connection parameter. Please provide a connection parameter for {parameter.code}. " |
| 206 | + f"Example: @parameter('my_connection', ...)" |
| 207 | + f"Example: @parameter('{parameter.code}', widget = ..., connection='my_connection')" |
| 208 | + ) |
| 209 | + |
| 210 | + |
| 211 | +def parameter( |
| 212 | + code: str, |
| 213 | + *, |
| 214 | + type: type[ |
| 215 | + str |
| 216 | + | int |
| 217 | + | bool |
| 218 | + | float |
| 219 | + | Secret |
| 220 | + | DHIS2Connection |
| 221 | + | IASOConnection |
| 222 | + | PostgreSQLConnection |
| 223 | + | GCSConnection |
| 224 | + | S3Connection |
| 225 | + | CustomConnection |
| 226 | + | Dataset |
| 227 | + | File |
| 228 | + ], |
| 229 | + name: str | None = None, |
| 230 | + choices: typing.Sequence | None = None, |
| 231 | + help: str | None = None, |
| 232 | + widget: DHIS2Widget | IASOWidget | None = None, |
| 233 | + connection: str | None = None, |
| 234 | + default: typing.Any | None = None, |
| 235 | + required: bool = True, |
| 236 | + multiple: bool = False, |
| 237 | + directory: str | None = None, |
| 238 | +): |
| 239 | + """Decorate a pipeline function by attaching a parameter to it.. |
| 240 | +
|
| 241 | + This decorator must be used on a function decorated by the @pipeline decorator. |
| 242 | +
|
| 243 | + Parameters |
| 244 | + ---------- |
| 245 | + code : str |
| 246 | + The parameter identifier (must be unique for a given pipeline) |
| 247 | + type : {str, int, bool, float, DHIS2Connection, IASOConnection, PostgreSQLConnection, GCSConnection, S3Connection, CustomConnection, Dataset, File} |
| 248 | + The parameter Python type |
| 249 | + name : str, optional |
| 250 | + A name for the parameter (will be used instead of the code in the web interface) |
| 251 | + choices : list, optional |
| 252 | + An optional list or tuple of choices for the parameter (will be used to build a choice widget in the web |
| 253 | + interface) |
| 254 | + help : str, optional |
| 255 | + An optional help text to be displayed in the web interface |
| 256 | + widget : DHIS2Widget|IASOWidget, optional |
| 257 | + An optional widget type for the parameter (only used if the parameter type is DHIS2Connection, IASOConnection) |
| 258 | + connection : str, optional |
| 259 | + An optional connection parameter that will be used to link widget to the connection. |
| 260 | + default : any, optional |
| 261 | + An optional default value for the parameter (should be of the type defined by the type parameter) |
| 262 | + required : bool, default=True |
| 263 | + Whether the parameter is mandatory |
| 264 | + multiple : bool, default=True |
| 265 | + Whether this parameter should be provided multiple values (if True, the value must be provided as a list of |
| 266 | + values of the chosen type) |
| 267 | + directory : str, optional |
| 268 | + An optional parameter to force file selection to specific directory (only used for parameter type File). If the directory does not exist, it will be ignored. |
| 269 | +
|
| 270 | + Returns |
| 271 | + ------- |
| 272 | + typing.Callable |
| 273 | + A decorator that returns the Pipeline with the parameter attached |
| 274 | +
|
| 275 | + """ |
| 276 | + |
| 277 | + def decorator(fun): |
| 278 | + return FunctionWithParameter( |
| 279 | + fun, |
| 280 | + Parameter( |
| 281 | + code, |
| 282 | + type=type, |
| 283 | + name=name, |
| 284 | + choices=choices, |
| 285 | + help=help, |
| 286 | + default=default, |
| 287 | + required=required, |
| 288 | + widget=widget, |
| 289 | + connection=connection, |
| 290 | + multiple=multiple, |
| 291 | + directory=directory, |
| 292 | + ), |
| 293 | + ) |
| 294 | + |
| 295 | + return decorator |
| 296 | + |
| 297 | + |
| 298 | +class FunctionWithParameter: |
| 299 | + """Wrapper class for pipeline functions decorated with the @parameter decorator.""" |
| 300 | + |
| 301 | + def __init__(self, function, added_parameter: Parameter): |
| 302 | + self.function = function |
| 303 | + self.parameter = added_parameter |
| 304 | + |
| 305 | + def get_all_parameters(self) -> list[Parameter]: |
| 306 | + """Go through the decorators chain to find all pipeline parameters.""" |
| 307 | + if isinstance(self.function, FunctionWithParameter): |
| 308 | + return [self.parameter, *self.function.get_all_parameters()] |
| 309 | + |
| 310 | + return [self.parameter] |
| 311 | + |
| 312 | + def __call__(self, *args, **kwargs): |
| 313 | + """Call the decorated pipeline function.""" |
| 314 | + return self.function(*args, **kwargs) |
0 commit comments