|
2 | 2 | import hashlib |
3 | 3 | import logging |
4 | 4 | from collections import namedtuple |
5 | | -from collections.abc import MutableMapping, MutableSequence, Sequence |
| 5 | +from collections.abc import MutableMapping, MutableSequence, Sequence, Mapping |
6 | 6 | from io import StringIO |
7 | 7 | from pathlib import Path |
8 | | -from typing import IO, Any, TypeAlias, TypeVar, cast |
| 8 | +from typing import IO, Any, TypeAlias, TypeVar |
9 | 9 | from urllib.parse import urldefrag |
10 | 10 |
|
11 | 11 | from schema_salad.exceptions import ValidationException |
@@ -100,39 +100,6 @@ def in_output_type_schema_to_output_type_schema( |
100 | 100 | return _in_output_type_schema_to_output_type_schema(schema_type, loading_options) |
101 | 101 |
|
102 | 102 |
|
def _compare_records(
    src: cwl.RecordSchema, sink: cwl.RecordSchema, strict: bool = False
) -> bool:
    """
    Tell whether every field of the sink record can be fed from the source record.

    Field names are reduced with ``cwl.shortname`` before they are matched, so
    two records whose field identifiers differ only by their prefix are
    compared field by field.

    :param src: record schema offered by the source
    :param sink: record schema expected by the sink
    :param strict: forwarded unchanged to ``can_assign_src_to_sink``
    :return: ``False`` as soon as one sink field cannot be assigned
        (an info message is logged for that field), ``True`` otherwise
    """

    def _types_by_shortname(record: cwl.RecordSchema) -> dict[str, Any]:
        # Map each field's short name to its declared type.
        return {
            cwl.shortname(entry.name): entry.type_ for entry in (record.fields or {})
        }

    src_types = _types_by_shortname(src)
    sink_types = _types_by_shortname(sink)

    for field_name, sink_type in sink_types.items():
        # A field missing on the source side is treated as the "null" type.
        assignable = can_assign_src_to_sink(
            src_types.get(field_name, "null"), sink_type, strict
        )
        # Sink fields whose type is None never cause a failure.
        if assignable or sink_type is None:
            continue
        _logger.info(
            "Record comparison failure for %s and %s\n"
            "Did not match fields for %s: %s and %s",
            cast(cwl.InputRecordSchema | cwl.CommandOutputRecordSchema, src).name,
            cast(cwl.InputRecordSchema | cwl.CommandOutputRecordSchema, sink).name,
            field_name,
            src_types.get(field_name),
            sink_type,
        )
        return False
    return True
134 | | - |
135 | | - |
136 | 103 | def _compare_type(type1: Any, type2: Any) -> bool: |
137 | 104 | match (type1, type1): |
138 | 105 | case cwl.ArraySchema() as t1, cwl.ArraySchema() as t2: |
@@ -219,45 +186,15 @@ def _inputfile_load( |
219 | 186 | ) |
220 | 187 |
|
221 | 188 |
|
def can_assign_src_to_sink(src: Any, sink: Any, strict: bool = False) -> bool:
    """
    Decide whether a value of type ``src`` may be passed to a ``sink`` type.

    Only the type specifications themselves are compared.

    :param src: admissible source type, or a list of admissible source types
    :param sink: admissible sink type, or a list of admissible sink types
    :param strict: when ``src`` is a list, require every source type to be
        assignable; otherwise a single assignable source type other than
        ``"null"`` is enough
    :return: ``True`` when the assignment is acceptable
    """
    # "Any" on either side is compatible with everything.
    if "Any" in (src, sink):
        return True

    # Arrays are compatible when their item types are.
    if isinstance(src, cwl.ArraySchema) and isinstance(sink, cwl.ArraySchema):
        return can_assign_src_to_sink(src.items, sink.items, strict)

    # Records are compared field by field.
    if isinstance(src, cwl.RecordSchema) and isinstance(sink, cwl.RecordSchema):
        return _compare_records(src, sink, strict)

    if isinstance(src, MutableSequence):
        # The per-member checks below run with the default (non-strict) mode.
        if strict:
            return all(can_assign_src_to_sink(candidate, sink) for candidate in src)
        return any(
            candidate != "null" and can_assign_src_to_sink(candidate, sink)
            for candidate in src
        )

    if isinstance(sink, MutableSequence):
        # One matching alternative on the sink side is sufficient.
        return any(can_assign_src_to_sink(src, candidate) for candidate in sink)

    return bool(src == sink)
255 | | - |
256 | | - |
257 | 189 | def check_all_types( |
258 | | - src_dict: dict[str, Any], |
| 190 | + src_dict: Mapping[str, cwl.InputParameter | cwl.WorkflowStepOutput], |
259 | 191 | sinks: Sequence[cwl.WorkflowStepInput | cwl.WorkflowOutputParameter], |
260 | | - type_dict: dict[str, Any], |
| 192 | + type_dict: Mapping[ |
| 193 | + str, |
| 194 | + cwl_utils.parser.utils.InputTypeSchemas |
| 195 | + | cwl_utils.parser.utils.OutputTypeSchemas |
| 196 | + | None, |
| 197 | + ], |
261 | 198 | ) -> dict[str, list[SrcSink]]: |
262 | 199 | """Given a list of sinks, check if their types match with the types of their sources.""" |
263 | 200 | validation: dict[str, list[SrcSink]] = {"warning": [], "exception": []} |
@@ -288,45 +225,17 @@ def check_all_types( |
288 | 225 | srcs_of_sink = [src_dict[parm_id]] |
289 | 226 | linkMerge = None |
290 | 227 | for src in srcs_of_sink: |
291 | | - check_result = check_types( |
292 | | - type_dict[cast(str, src.id)], |
| 228 | + check_result = cwl_utils.parser.utils.check_types( |
| 229 | + type_dict[src.id], |
293 | 230 | type_dict[sink.id], |
294 | 231 | linkMerge, |
295 | | - getattr(sink, "valueFrom", None), |
| 232 | + sink.valueFrom if isinstance(sink, cwl.WorkflowStepInput) else None, |
296 | 233 | ) |
297 | 234 | if check_result in ("warning", "exception"): |
298 | 235 | validation[check_result].append(SrcSink(src, sink, linkMerge, None)) |
299 | 236 | return validation |
300 | 237 |
|
301 | 238 |
|
302 | | -def check_types( |
303 | | - srctype: Any, |
304 | | - sinktype: Any, |
305 | | - linkMerge: str | None, |
306 | | - valueFrom: str | None = None, |
307 | | -) -> str: |
308 | | - """ |
309 | | - Check if the source and sink types are correct. |
310 | | -
|
311 | | - Acceptable types are "pass", "warning", or "exception". |
312 | | - """ |
313 | | - if valueFrom is not None: |
314 | | - return "pass" |
315 | | - if linkMerge is None: |
316 | | - if can_assign_src_to_sink(srctype, sinktype, strict=True): |
317 | | - return "pass" |
318 | | - if can_assign_src_to_sink(srctype, sinktype, strict=False): |
319 | | - return "warning" |
320 | | - return "exception" |
321 | | - if linkMerge == "merge_nested": |
322 | | - return check_types( |
323 | | - cwl.ArraySchema(items=srctype, type_="array"), sinktype, None, None |
324 | | - ) |
325 | | - if linkMerge == "merge_flattened": |
326 | | - return check_types(merge_flatten_type(srctype), sinktype, None, None) |
327 | | - raise ValidationException(f"Invalid value {linkMerge} for linkMerge field.") |
328 | | - |
329 | | - |
330 | 239 | def content_limit_respected_read_bytes(f: IO[bytes]) -> bytes: |
331 | 240 | """ |
332 | 241 | Read file content up to 64 kB as a byte array. |
@@ -430,15 +339,6 @@ def load_inputfile_by_yaml( |
430 | 339 | return result |
431 | 340 |
|
432 | 341 |
|
def merge_flatten_type(src: Any) -> Any:
    """
    Compute the ``merge_flattened`` form of a source type.

    A list of types is converted member by member, an existing array schema
    is returned as is, and any other type is wrapped in a new array schema.
    """
    if isinstance(src, MutableSequence):
        return list(map(merge_flatten_type, src))
    already_array = isinstance(src, cwl.ArraySchema)
    return src if already_array else cwl.ArraySchema(type_="array", items=src)
440 | | - |
441 | | - |
def to_input_array(type_: InputTypeSchemas) -> cwl.InputArraySchema:
    """Wrap ``type_`` in an input array schema whose items are of that type."""
    wrapped = cwl.InputArraySchema(items=type_, type_="array")
    return wrapped
444 | 344 |
|
@@ -542,7 +442,7 @@ def type_for_source( |
542 | 442 | type_="array", |
543 | 443 | ) |
544 | 444 | elif linkMerge == "merge_flattened": |
545 | | - new_type = merge_flatten_type(new_type) |
| 445 | + new_type = cwl_utils.parser.utils.merge_flatten_type(new_type) |
546 | 446 | return new_type |
547 | 447 | new_types: MutableSequence[InputTypeSchemas | OutputTypeSchemas] = [] |
548 | 448 | for p, sc in zip(params, scatter_context): |
@@ -586,7 +486,7 @@ def type_for_source( |
586 | 486 | type_="array", |
587 | 487 | ) |
588 | 488 | elif linkMerge == "merge_flattened": |
589 | | - final_type = merge_flatten_type(final_type) |
| 489 | + final_type = cwl_utils.parser.utils.merge_flatten_type(final_type) |
590 | 490 | elif isinstance(sourcenames, list) and len(sourcenames) > 1: |
591 | 491 | return cwl.OutputArraySchema( |
592 | 492 | items=final_type, |
|
0 commit comments