Skip to content

Commit dbe81b8

Browse files
Moved out version-independent functions
1 parent 9fd4031 commit dbe81b8

5 files changed

Lines changed: 288 additions & 374 deletions

File tree

cwl_utils/parser/__init__.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,13 @@ class NoType(ABC):
2525
Saveable: TypeAlias = cwl_v1_0.Saveable | cwl_v1_1.Saveable | cwl_v1_2.Saveable
2626
"""Type union for a CWL v1.x Saveable object."""
2727
InputParameter: TypeAlias = (
28-
cwl_v1_0.InputParameter
29-
| cwl_v1_0.CommandInputParameter
30-
| cwl_v1_1.InputParameter
31-
| cwl_v1_2.InputParameter
28+
cwl_v1_0.InputParameter | cwl_v1_1.InputParameter | cwl_v1_2.InputParameter
3229
)
3330
"""Type union for a CWL v1.x InputParameter object."""
3431
InputRecordField: TypeAlias = (
3532
cwl_v1_0.InputRecordField | cwl_v1_1.InputRecordField | cwl_v1_2.InputRecordField
3633
)
37-
"""Type union for a CWL v1.x InputRecordSchema object."""
34+
"""Type union for a CWL v1.x InputRecordField object."""
3835
InputSchema: TypeAlias = (
3936
cwl_v1_0.InputSchema | cwl_v1_1.InputSchema | cwl_v1_2.InputSchema
4037
)
@@ -166,6 +163,17 @@ class NoType(ABC):
166163
| cwl_v1_2.CommandOutputRecordField
167164
)
168165
"""Type union for a CWL v1.x CommandOutputRecordField object."""
166+
CommandOutputRecordSchema: TypeAlias = (
167+
cwl_v1_0.CommandOutputRecordSchema
168+
| cwl_v1_1.CommandOutputRecordSchema
169+
| cwl_v1_2.CommandOutputRecordSchema
170+
)
171+
CommandOutputRecordSchemaTypes = (
172+
cwl_v1_0.CommandOutputRecordSchema,
173+
cwl_v1_1.CommandOutputRecordSchema,
174+
cwl_v1_2.CommandOutputRecordSchema,
175+
)
176+
"""Type Union for a CWL v1.x CommandOutputRecordSchema object."""
169177
ExpressionTool: TypeAlias = (
170178
cwl_v1_0.ExpressionTool | cwl_v1_1.ExpressionTool | cwl_v1_2.ExpressionTool
171179
)
@@ -242,7 +250,7 @@ class NoType(ABC):
242250
cwl_v1_1.InputRecordSchema,
243251
cwl_v1_2.InputRecordSchema,
244252
)
245-
"""Type Union for a CWL v1.x RecordSchema object."""
253+
"""Type Union for a CWL v1.x InputRecordSchema object."""
246254

247255
File: TypeAlias = cwl_v1_0.File | cwl_v1_1.File | cwl_v1_2.File
248256
"""Type Union for a CWL v1.x File object."""

cwl_utils/parser/cwl_v1_0_utils.py

Lines changed: 14 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
import hashlib
33
import logging
44
from collections import namedtuple
5-
from collections.abc import MutableMapping, MutableSequence, Sequence
5+
from collections.abc import MutableMapping, MutableSequence, Sequence, Mapping
66
from io import StringIO
77
from pathlib import Path
8-
from typing import IO, Any, TypeAlias, TypeVar, cast
8+
from typing import IO, Any, TypeAlias, TypeVar
99
from urllib.parse import urldefrag
1010

1111
from schema_salad.exceptions import ValidationException
@@ -100,39 +100,6 @@ def in_output_type_schema_to_output_type_schema(
100100
return _in_output_type_schema_to_output_type_schema(schema_type, loading_options)
101101

102102

103-
def _compare_records(
104-
src: cwl.RecordSchema, sink: cwl.RecordSchema, strict: bool = False
105-
) -> bool:
106-
"""
107-
Compare two records, ensuring they have compatible fields.
108-
109-
This handles normalizing record names, which will be relative to workflow
110-
step, so that they can be compared.
111-
"""
112-
srcfields = {cwl.shortname(field.name): field.type_ for field in (src.fields or {})}
113-
sinkfields = {
114-
cwl.shortname(field.name): field.type_ for field in (sink.fields or {})
115-
}
116-
for key in sinkfields.keys():
117-
if (
118-
not can_assign_src_to_sink(
119-
srcfields.get(key, "null"), sinkfields.get(key, "null"), strict
120-
)
121-
and sinkfields.get(key) is not None
122-
):
123-
_logger.info(
124-
"Record comparison failure for %s and %s\n"
125-
"Did not match fields for %s: %s and %s",
126-
cast(cwl.InputRecordSchema | cwl.CommandOutputRecordSchema, src).name,
127-
cast(cwl.InputRecordSchema | cwl.CommandOutputRecordSchema, sink).name,
128-
key,
129-
srcfields.get(key),
130-
sinkfields.get(key),
131-
)
132-
return False
133-
return True
134-
135-
136103
def _compare_type(type1: Any, type2: Any) -> bool:
137104
match (type1, type1):
138105
case cwl.ArraySchema() as t1, cwl.ArraySchema() as t2:
@@ -219,45 +186,15 @@ def _inputfile_load(
219186
)
220187

221188

222-
def can_assign_src_to_sink(src: Any, sink: Any, strict: bool = False) -> bool:
223-
"""
224-
Check for identical type specifications, ignoring extra keys like inputBinding.
225-
226-
src: admissible source types
227-
sink: admissible sink types
228-
229-
In non-strict comparison, at least one source type must match one sink type,
230-
except for 'null'.
231-
In strict comparison, all source types must match at least one sink type.
232-
"""
233-
if "Any" in (src, sink):
234-
return True
235-
if isinstance(src, cwl.ArraySchema) and isinstance(sink, cwl.ArraySchema):
236-
return can_assign_src_to_sink(src.items, sink.items, strict)
237-
if isinstance(src, cwl.RecordSchema) and isinstance(sink, cwl.RecordSchema):
238-
return _compare_records(src, sink, strict)
239-
if isinstance(src, MutableSequence):
240-
if strict:
241-
for this_src in src:
242-
if not can_assign_src_to_sink(this_src, sink):
243-
return False
244-
return True
245-
for this_src in src:
246-
if this_src != "null" and can_assign_src_to_sink(this_src, sink):
247-
return True
248-
return False
249-
if isinstance(sink, MutableSequence):
250-
for this_sink in sink:
251-
if can_assign_src_to_sink(src, this_sink):
252-
return True
253-
return False
254-
return bool(src == sink)
255-
256-
257189
def check_all_types(
258-
src_dict: dict[str, Any],
190+
src_dict: Mapping[str, cwl.InputParameter | cwl.WorkflowStepOutput],
259191
sinks: Sequence[cwl.WorkflowStepInput | cwl.WorkflowOutputParameter],
260-
type_dict: dict[str, Any],
192+
type_dict: Mapping[
193+
str,
194+
cwl_utils.parser.utils.InputTypeSchemas
195+
| cwl_utils.parser.utils.OutputTypeSchemas
196+
| None,
197+
],
261198
) -> dict[str, list[SrcSink]]:
262199
"""Given a list of sinks, check if their types match with the types of their sources."""
263200
validation: dict[str, list[SrcSink]] = {"warning": [], "exception": []}
@@ -288,45 +225,17 @@ def check_all_types(
288225
srcs_of_sink = [src_dict[parm_id]]
289226
linkMerge = None
290227
for src in srcs_of_sink:
291-
check_result = check_types(
292-
type_dict[cast(str, src.id)],
228+
check_result = cwl_utils.parser.utils.check_types(
229+
type_dict[src.id],
293230
type_dict[sink.id],
294231
linkMerge,
295-
getattr(sink, "valueFrom", None),
232+
sink.valueFrom if isinstance(sink, cwl.WorkflowStepInput) else None,
296233
)
297234
if check_result in ("warning", "exception"):
298235
validation[check_result].append(SrcSink(src, sink, linkMerge, None))
299236
return validation
300237

301238

302-
def check_types(
303-
srctype: Any,
304-
sinktype: Any,
305-
linkMerge: str | None,
306-
valueFrom: str | None = None,
307-
) -> str:
308-
"""
309-
Check if the source and sink types are correct.
310-
311-
Acceptable types are "pass", "warning", or "exception".
312-
"""
313-
if valueFrom is not None:
314-
return "pass"
315-
if linkMerge is None:
316-
if can_assign_src_to_sink(srctype, sinktype, strict=True):
317-
return "pass"
318-
if can_assign_src_to_sink(srctype, sinktype, strict=False):
319-
return "warning"
320-
return "exception"
321-
if linkMerge == "merge_nested":
322-
return check_types(
323-
cwl.ArraySchema(items=srctype, type_="array"), sinktype, None, None
324-
)
325-
if linkMerge == "merge_flattened":
326-
return check_types(merge_flatten_type(srctype), sinktype, None, None)
327-
raise ValidationException(f"Invalid value {linkMerge} for linkMerge field.")
328-
329-
330239
def content_limit_respected_read_bytes(f: IO[bytes]) -> bytes:
331240
"""
332241
Read file content up to 64 kB as a byte array.
@@ -430,15 +339,6 @@ def load_inputfile_by_yaml(
430339
return result
431340

432341

433-
def merge_flatten_type(src: Any) -> Any:
434-
"""Return the merge flattened type of the source type."""
435-
if isinstance(src, MutableSequence):
436-
return [merge_flatten_type(t) for t in src]
437-
if isinstance(src, cwl.ArraySchema):
438-
return src
439-
return cwl.ArraySchema(type_="array", items=src)
440-
441-
442342
def to_input_array(type_: InputTypeSchemas) -> cwl.InputArraySchema:
443343
return cwl.InputArraySchema(type_="array", items=type_)
444344

@@ -542,7 +442,7 @@ def type_for_source(
542442
type_="array",
543443
)
544444
elif linkMerge == "merge_flattened":
545-
new_type = merge_flatten_type(new_type)
445+
new_type = cwl_utils.parser.utils.merge_flatten_type(new_type)
546446
return new_type
547447
new_types: MutableSequence[InputTypeSchemas | OutputTypeSchemas] = []
548448
for p, sc in zip(params, scatter_context):
@@ -586,7 +486,7 @@ def type_for_source(
586486
type_="array",
587487
)
588488
elif linkMerge == "merge_flattened":
589-
final_type = merge_flatten_type(final_type)
489+
final_type = cwl_utils.parser.utils.merge_flatten_type(final_type)
590490
elif isinstance(sourcenames, list) and len(sourcenames) > 1:
591491
return cwl.OutputArraySchema(
592492
items=final_type,

0 commit comments

Comments
 (0)