Skip to content

Commit d1303fe

Browse files
fix: address CI failures - format code, fix NaN check, skip test classes in public API test
1 parent b2211ba commit d1303fe

3 files changed

Lines changed: 71 additions & 22 deletions

File tree

pydsdl/_serdes.py

Lines changed: 67 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
import math
1415
import struct
1516
import typing
1617

@@ -32,7 +33,6 @@
3233
UnionType,
3334
ServiceType,
3435
DelimitedType,
35-
Field,
3636
PaddingField,
3737
)
3838

@@ -162,7 +162,11 @@ def deserialize(
162162
payload_bit_length = payload_byte_length * 8
163163

164164
if payload_bit_length > reader.remaining_bits:
165-
inner_type_name = schema.inner_type.full_name if hasattr(schema.inner_type, 'full_name') else type(schema.inner_type).__name__
165+
inner_type_name = (
166+
schema.inner_type.full_name
167+
if hasattr(schema.inner_type, "full_name")
168+
else type(schema.inner_type).__name__
169+
)
166170
raise DelimiterHeaderError(
167171
f"Delimiter header specifies {payload_byte_length} bytes ({payload_bit_length} bits) "
168172
+ f"but only {reader.remaining_bits} bits remain (delimited type: {inner_type_name})"
@@ -350,23 +354,32 @@ def _serialize_primitive(writer: _BitWriter, schema: PrimitiveType | VoidType, v
350354
"""
351355
if isinstance(schema, BooleanType):
352356
if not isinstance(value, (bool, int, float)):
353-
raise ValueError(f"Boolean requires numeric input, got {type(value).__name__} (schema: {type(schema).__name__}, bit_length: {schema.bit_length})")
357+
raise ValueError(
358+
f"Boolean requires numeric input, got {type(value).__name__} "
359+
f"(schema: {type(schema).__name__}, bit_length: {schema.bit_length})"
360+
)
354361
if isinstance(value, float):
355362
if not (-float("inf") < value < float("inf")):
356-
raise ValueError(f"Non-finite float cannot be converted to bool (schema: {type(schema).__name__}, bit_length: {schema.bit_length})")
363+
raise ValueError(
364+
f"Non-finite float cannot be converted to bool "
365+
f"(schema: {type(schema).__name__}, bit_length: {schema.bit_length})"
366+
)
357367
bit_value = 1 if value else 0
358368
writer.write_bits(bit_value, 1)
359369

360370
elif isinstance(schema, FloatType):
361371
if not isinstance(value, (bool, int, float)):
362-
raise ValueError(f"Float requires numeric input, got {type(value).__name__} (schema: {type(schema).__name__}, bit_length: {schema.bit_length})")
372+
raise ValueError(
373+
f"Float requires numeric input, got {type(value).__name__} "
374+
f"(schema: {type(schema).__name__}, bit_length: {schema.bit_length})"
375+
)
363376
float_value = float(value)
364377

365378
if schema.cast_mode == PrimitiveType.CastMode.SATURATED:
366379
range_val = schema.inclusive_value_range
367380
min_bound = float(range_val.min)
368381
max_bound = float(range_val.max)
369-
if float_value != float_value:
382+
if math.isnan(float_value):
370383
pass
371384
elif float_value == float("inf"):
372385
pass
@@ -389,10 +402,16 @@ def _serialize_primitive(writer: _BitWriter, schema: PrimitiveType | VoidType, v
389402

390403
elif isinstance(schema, SignedIntegerType):
391404
if not isinstance(value, (bool, int, float)):
392-
raise ValueError(f"Integer requires numeric input, got {type(value).__name__} (schema: {type(schema).__name__}, bit_length: {schema.bit_length})")
405+
raise ValueError(
406+
f"Integer requires numeric input, got {type(value).__name__} "
407+
f"(schema: {type(schema).__name__}, bit_length: {schema.bit_length})"
408+
)
393409
if isinstance(value, float):
394410
if not (-float("inf") < value < float("inf")):
395-
raise ValueError(f"Non-finite float cannot be converted to int (schema: {type(schema).__name__}, bit_length: {schema.bit_length})")
411+
raise ValueError(
412+
f"Non-finite float cannot be converted to int "
413+
f"(schema: {type(schema).__name__}, bit_length: {schema.bit_length})"
414+
)
396415
int_value = int(round(value))
397416
else:
398417
int_value = int(value)
@@ -411,10 +430,16 @@ def _serialize_primitive(writer: _BitWriter, schema: PrimitiveType | VoidType, v
411430

412431
elif isinstance(schema, UnsignedIntegerType):
413432
if not isinstance(value, (bool, int, float)):
414-
raise ValueError(f"Integer requires numeric input, got {type(value).__name__} (schema: {type(schema).__name__}, bit_length: {schema.bit_length})")
433+
raise ValueError(
434+
f"Integer requires numeric input, got {type(value).__name__} "
435+
f"(schema: {type(schema).__name__}, bit_length: {schema.bit_length})"
436+
)
415437
if isinstance(value, float):
416438
if not (-float("inf") < value < float("inf")):
417-
raise ValueError(f"Non-finite float cannot be converted to int (schema: {type(schema).__name__}, bit_length: {schema.bit_length})")
439+
raise ValueError(
440+
f"Non-finite float cannot be converted to int "
441+
f"(schema: {type(schema).__name__}, bit_length: {schema.bit_length})"
442+
)
418443
int_value = int(round(value))
419444
else:
420445
int_value = int(value)
@@ -496,7 +521,11 @@ def _serialize_array(writer: _BitWriter, schema: ArrayType, value: _Value) -> No
496521
elif isinstance(value, (bytes, bytearray)):
497522
_ = value.decode("utf-8")
498523
else:
499-
raise TypeError(f"UTF-8 array requires str, bytes, or bytearray input, got {type(value).__name__} (array type: {type(schema).__name__}, capacity: {schema.capacity})")
524+
raise TypeError(
525+
f"UTF-8 array requires str, bytes, or bytearray input, "
526+
f"got {type(value).__name__} (array type: {type(schema).__name__}, "
527+
f"capacity: {schema.capacity})"
528+
)
500529
value = list(value)
501530

502531
elif isinstance(schema.element_type, ByteType):
@@ -518,14 +547,22 @@ def _serialize_array(writer: _BitWriter, schema: ArrayType, value: _Value) -> No
518547

519548
if isinstance(schema, FixedLengthArrayType):
520549
if len(value) != schema.capacity:
521-
raise ArrayLengthError(f"Fixed-length array requires exactly {schema.capacity} elements, got {len(value)} (array type: {type(schema).__name__}, capacity: {schema.capacity})")
550+
raise ArrayLengthError(
551+
f"Fixed-length array requires exactly {schema.capacity} elements, "
552+
f"got {len(value)} (array type: {type(schema).__name__}, "
553+
f"capacity: {schema.capacity})"
554+
)
522555

523556
for element in value:
524557
_serialize_element(writer, schema.element_type, element)
525558

526559
elif isinstance(schema, VariableLengthArrayType):
527560
if not (0 <= len(value) <= schema.capacity):
528-
raise ArrayLengthError(f"Variable-length array length {len(value)} exceeds capacity {schema.capacity} (array type: {type(schema).__name__}, capacity: {schema.capacity})")
561+
raise ArrayLengthError(
562+
f"Variable-length array length {len(value)} exceeds capacity "
563+
f"{schema.capacity} (array type: {type(schema).__name__}, "
564+
f"capacity: {schema.capacity})"
565+
)
529566

530567
writer.write_bits(len(value), schema.length_field_type.bit_length)
531568

@@ -547,7 +584,9 @@ def _deserialize_array(reader: _BitReader, schema: ArrayType) -> _Value:
547584
elif isinstance(schema, VariableLengthArrayType):
548585
length = reader.read_bits(schema.length_field_type.bit_length)
549586
if length > schema.capacity:
550-
raise ArrayLengthError(f"Variable-length array length {length} exceeds capacity {schema.capacity} (array type: {type(schema).__name__}, capacity: {schema.capacity})")
587+
raise ArrayLengthError(
588+
f"Variable-length array length {length} exceeds capacity {schema.capacity} (array type: {type(schema).__name__}, capacity: {schema.capacity})"
589+
)
551590
else:
552591
raise ValueError(f"Unknown array type: {type(schema).__name__}")
553592

@@ -611,9 +650,9 @@ def _serialize_composite(writer: _BitWriter, schema: CompositeType, obj: _Obj) -
611650
if not isinstance(obj, dict):
612651
raise ValueError("Union value must be a dict")
613652
if len(obj) == 0:
614-
raise ValueError(f"Union must have exactly one field, got none (union type: {schema.full_name})")
653+
raise ValueError(f"Union must have exactly one field, got none " f"(union type: {schema.full_name})")
615654
if len(obj) > 1:
616-
raise ValueError(f"Union must have exactly one field, got multiple (union type: {schema.full_name})")
655+
raise ValueError(f"Union must have exactly one field, got multiple " f"(union type: {schema.full_name})")
617656

618657
key = next(iter(obj.keys()))
619658
value = obj[key]
@@ -627,7 +666,10 @@ def _serialize_composite(writer: _BitWriter, schema: CompositeType, obj: _Obj) -
627666
break
628667

629668
if tag_index is None:
630-
raise UnionFieldError(f"Unknown union variant: {key} (union type: {schema.full_name}, valid variants: {[f.name for f in schema.fields]})")
669+
valid_variants = [f.name for f in schema.fields]
670+
raise UnionFieldError(
671+
f"Unknown union variant: {key} (union type: {schema.full_name}, " f"valid variants: {valid_variants})"
672+
)
631673

632674
assert field is not None
633675
writer.write_bits(tag_index, schema.tag_field_type.bit_length)
@@ -669,7 +711,11 @@ def _deserialize_composite(reader: _BitReader, schema: CompositeType) -> _Obj:
669711
payload_bit_length = payload_byte_length * 8
670712

671713
if payload_bit_length > reader.remaining_bits:
672-
inner_type_name = schema.inner_type.full_name if hasattr(schema.inner_type, 'full_name') else type(schema.inner_type).__name__
714+
inner_type_name = (
715+
schema.inner_type.full_name
716+
if hasattr(schema.inner_type, "full_name")
717+
else type(schema.inner_type).__name__
718+
)
673719
raise DelimiterHeaderError(
674720
f"Delimiter header specifies {payload_byte_length} bytes ({payload_bit_length} bits) "
675721
+ f"but only {reader.remaining_bits} bits remain (delimited type: {inner_type_name})"
@@ -681,7 +727,9 @@ def _deserialize_composite(reader: _BitReader, schema: CompositeType) -> _Obj:
681727
elif isinstance(schema, UnionType):
682728
tag = reader.read_bits(schema.tag_field_type.bit_length)
683729
if tag >= len(schema.fields):
684-
raise UnionTagError(f"Invalid union tag: {tag} (union type: {schema.full_name}, valid range: 0-{len(schema.fields)-1})")
730+
raise UnionTagError(
731+
f"Invalid union tag: {tag} (union type: {schema.full_name}, valid range: 0-{len(schema.fields)-1})"
732+
)
685733

686734
field = schema.fields[tag]
687735
value = _deserialize_field_value(reader, field.data_type)

pydsdl/_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from pathlib import Path
1212
from textwrap import dedent
1313
import pytest # This is only safe to import in test files!
14-
from . import InvalidDefinitionError
1514
from . import _expression
1615
from . import _error
1716
from . import _parser
@@ -2158,4 +2157,7 @@ def _unittest_public_api() -> None:
21582157
for root in public_roots:
21592158
expected_types = {root} | set(_collect_descendants(root))
21602159
for t in expected_types:
2160+
# Skip test classes (defined in test modules)
2161+
if t.__module__.startswith("pydsdl._test"):
2162+
continue
21612163
assert t.__name__ in dir(pydsdl), "Data type %r is not exported" % t

pydsdl/_test_serdes.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,7 @@ class MockServiceType(ServiceType):
104104
deserialize(mock_service, bytes([0]))
105105

106106
# Test 3: with_delimiter_header=True on non-delimited type raises ValueError
107-
# Create a mock StructureType
108-
class MockStructureType(StructureType):
107+
class MockStructureType(StructureType): # type: ignore[misc]
109108
pass
110109

111110
mock_struct = MockStructureType.__new__(MockStructureType)

0 commit comments

Comments
 (0)