Skip to content

Commit 38baaed

Browse files
committed
python padding
1 parent c88027e commit 38baaed

2 files changed

Lines changed: 53 additions & 3 deletions

File tree

python/tests/test_metadata.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,51 @@ def test_round_trip_with_struct_and_json(self):
671671
out = ms.decode_row(encoded)
672672
assert out == row
673673

674+
def test_blob_bytes_aligned(self):
    """Check that the struct blob is 8-byte aligned and still round-trips.

    The part of the encoded metadata that precedes the struct blob
    (header + JSON + padding) must have a length that is a multiple of 8.
    We verify this the pedantic way: work out how many bytes each integer
    adds to the struct part and subtract that off the total length.

    The label length is varied over eight consecutive values so that the
    unpadded ``header + JSON`` length takes every residue modulo 8. A
    single fixed payload would only ever exercise one padding length
    (possibly zero), leaving the non-trivial padding paths of both the
    encoder and the decoder untested.
    """

    def schema_with_blobs(k):
        """Return a json+struct schema whose struct part holds ``k`` ints."""
        schema = {
            "codec": "json+struct",
            "json": {
                "type": "object",
                "properties": {
                    "label": {"type": "string"},
                    "count": {"type": "number"},
                },
                "required": ["label"],
            },
            "struct": {
                "type": "object",
                "properties": {
                    f"b{j}": {"type": "integer", "binaryFormat": "i"}
                    for j in range(k)
                },
            },
        }
        return tskit.MetadataSchema(schema)

    k_list = (0, 1, 2, 3)
    schemas = [schema_with_blobs(k) for k in k_list]
    # Each extra ASCII character lengthens the JSON by exactly one byte, so
    # eight consecutive label lengths cover every possible padding length.
    for label_length in range(1, 9):
        label = "a" * label_length
        rows = [
            {"label": label, "count": 7, **{f"b{j}": j for j in range(k)}}
            for k in k_list
        ]
        encoded = [
            ms.validate_and_encode_row(row) for ms, row in zip(schemas, rows)
        ]
        # Bytes contributed by one struct integer; must be constant.
        dbytes = len(encoded[2]) - len(encoded[1])
        assert len(encoded[3]) - len(encoded[2]) == dbytes
        for k, en in zip(k_list, encoded):
            # Everything before the struct blob must be 8-byte aligned.
            assert (len(en) - k * dbytes) % 8 == 0
        for ms, en, row in zip(schemas, encoded, rows):
            # Padding must be transparent to the decoder.
            assert ms.decode_row(en) == row
718+
674719
def test_json_defaults_applied(self):
675720
schema = {
676721
"codec": "json+struct",

python/tskit/metadata.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -294,20 +294,25 @@ def encode(self, obj: Any) -> bytes:
294294
header = self._HDR.pack(
295295
self.MAGIC, self.VERSION, len(json_bytes), len(blob_bytes)
296296
)
297-
return header + json_bytes + blob_bytes
297+
padding_bytes = bytes((-(len(header) + len(json_bytes))) % 8)
298+
return header + json_bytes + padding_bytes + blob_bytes
298299

299300
def decode(self, encoded: bytes) -> Any:
300301
if len(encoded) >= self._HDR.size and encoded[:4] == self.MAGIC:
301302
_, version, jlen, blen = self._HDR.unpack_from(encoded)
302303
if version != self.VERSION:
303304
raise ValueError("Unsupported json+struct version")
304305
start = self._HDR.size
305-
if jlen > len(encoded) - start or blen > len(encoded) - start - jlen:
306+
padding_length = (-(start + jlen)) % 8
307+
if (
308+
jlen > len(encoded) - start
309+
or blen > len(encoded) - start - jlen - padding_length
310+
):
306311
raise ValueError(
307312
"Invalid json+struct payload: declared lengths exceed buffer size"
308313
)
309314
json_bytes = encoded[start : start + jlen]
310-
blob_bytes = encoded[start + jlen : start + jlen + blen]
315+
blob_bytes = encoded[start + jlen : start + jlen + blen + padding_length]
311316
json_data = self.json_codec.decode(json_bytes)
312317
struct_data = self.struct_codec.decode(blob_bytes)
313318
overlap = set(json_data).intersection(struct_data)

0 commit comments

Comments
 (0)