Skip to content

Commit 3cfa396

Browse files
authored
Fix thread safety issues and add better concurrent tests (#27)
1 parent d6a0134 commit 3cfa396

4 files changed

Lines changed: 232 additions & 43 deletions

File tree

igor2/binarywave.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Read IGOR Binary Wave files into Numpy arrays."""
22
import logging
3+
import threading as _threading
34
# Based on WaveMetric's Technical Note 003, "Igor Binary Format"
45
# ftp://ftp.wavemetrics.net/IgorPro/Technical_Notes/TN003.zip
56
# From ftp://ftp.wavemetrics.net/IgorPro/Technical_Notes/TN000.txt
@@ -13,11 +14,13 @@
1314
from .struct import DynamicStructure as _DynamicStructure
1415
from .struct import Field as _Field
1516
from .struct import DynamicField as _DynamicField
17+
from .struct import clone_structure as _clone_structure
1618
from .util import byte_order as _byte_order
1719
from .util import need_to_reorder_bytes as _need_to_reorder_bytes
1820

1921

2022
logger = logging.getLogger(__name__)
23+
_thread_local = _threading.local()
2124

2225
# Numpy doesn't support complex integers by default, see
2326
# http://mail.python.org/pipermail/python-dev/2002-April/022408.html
@@ -625,15 +628,15 @@ def post_unpack(self, parents, data):
625628
else:
626629
need_to_reorder_bytes = False
627630

631+
version_map = getattr(wave_structure, '_wave_versions', {
632+
1: Wave1,
633+
2: Wave2,
634+
3: Wave3,
635+
5: Wave5,
636+
})
628637
old_format = wave_structure.fields[-1].format
629-
if version == 1:
630-
wave_structure.fields[-1].format = Wave1
631-
elif version == 2:
632-
wave_structure.fields[-1].format = Wave2
633-
elif version == 3:
634-
wave_structure.fields[-1].format = Wave3
635-
elif version == 5:
636-
wave_structure.fields[-1].format = Wave5
638+
if version in version_map:
639+
wave_structure.fields[-1].format = version_map[version]
637640
elif not need_to_reorder_bytes:
638641
raise ValueError(
639642
'invalid binary wave version: {}'.format(version))
@@ -795,6 +798,10 @@ def post_unpack(self, parents, data):
795798

796799

797800
def setup_wave(byte_order='='):
801+
wave1 = _clone_structure(Wave1)
802+
wave2 = _clone_structure(Wave2)
803+
wave3 = _clone_structure(Wave3)
804+
wave5 = _clone_structure(Wave5)
798805
wave = _DynamicStructure(
799806
name='Wave',
800807
fields=[
@@ -803,22 +810,43 @@ def setup_wave(byte_order='='):
803810
'version',
804811
help='Version number for backwards compatibility.'),
805812
DynamicWaveField(
806-
Wave1,
813+
wave1,
807814
'wave',
808815
help='The rest of the wave data.'),
809816
],
810817
byte_order=byte_order)
818+
wave._wave_versions = {
819+
1: wave1,
820+
2: wave2,
821+
3: wave3,
822+
5: wave5,
823+
}
811824
wave.setup()
812825
return wave
813826

814827

828+
def _get_thread_local_wave():
    """Return this thread's cached Wave parser, building it on first use.

    Each thread gets an independent parser instance so concurrent loads
    cannot trample each other's dynamic state.
    """
    try:
        cached = _thread_local.wave
    except AttributeError:
        cached = None
    if cached is None:
        cached = setup_wave(byte_order='=')
        _thread_local.wave = cached
    return cached
834+
835+
836+
def _reset_wave_parser(wave):
837+
wave.byte_order = '='
838+
wave.fields[-1].format = wave._wave_versions[1]
839+
wave.setup()
840+
841+
815842
def load(filename):
816843
if hasattr(filename, 'read'):
817844
f = filename # filename is actually a stream object
818845
else:
819846
f = open(filename, 'rb')
820847
try:
821-
wave = setup_wave()
848+
wave = _get_thread_local_wave()
849+
_reset_wave_parser(wave)
822850
data = wave.unpack_stream(f)
823851
finally:
824852
if not hasattr(filename, 'read'):

igor2/record/variables.py

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import io as _io
22
import logging
3+
import threading as _threading
34

45
from ..binarywave import TYPE_TABLE as _TYPE_TABLE
56
from ..binarywave import NullStaticStringField as _NullStaticStringField
@@ -8,12 +9,14 @@
89
from ..struct import DynamicStructure as _DynamicStructure
910
from ..struct import Field as _Field
1011
from ..struct import DynamicField as _DynamicField
12+
from ..struct import clone_structure as _clone_structure
1113
from ..util import byte_order as _byte_order
1214
from ..util import need_to_reorder_bytes as _need_to_reorder_bytes
1315
from .base import Record
1416

1517

1618
logger = logging.getLogger(__name__)
19+
_thread_local = _threading.local()
1720

1821

1922
class ListedStaticStringField(_NullStaticStringField):
@@ -297,11 +300,13 @@ def post_unpack(self, parents, data):
297300
else:
298301
need_to_reorder_bytes = False
299302

303+
version_map = getattr(variables_structure, '_version_structures', {
304+
1: Variables1,
305+
2: Variables2,
306+
})
300307
old_format = variables_structure.fields[-1].format
301-
if version == 1:
302-
variables_structure.fields[-1].format = Variables1
303-
elif version == 2:
304-
variables_structure.fields[-1].format = Variables2
308+
if version in version_map:
309+
variables_structure.fields[-1].format = version_map[version]
305310
elif not need_to_reorder_bytes:
306311
raise ValueError(
307312
'invalid variables record version: {}'.format(version))
@@ -318,26 +323,52 @@ def post_unpack(self, parents, data):
318323
return need_to_reorder_bytes
319324

320325

321-
VariablesRecordStructure = _DynamicStructure(
322-
name='VariablesRecord',
323-
fields=[
324-
DynamicVersionField(
325-
'h', 'version', help='Version number for this header.'),
326-
_Field(
327-
Variables1,
328-
'variables',
329-
help='The rest of the variables data.'),
330-
])
326+
def setup_variables_record(byte_order='='):
    """Build a fresh, independent VariablesRecord parser structure.

    Clones the Variables1/Variables2 sub-structures so each parser instance
    owns its own copies, and records them in ``_version_structures`` so the
    dynamic version field can swap the payload format during unpacking.
    """
    v1_struct = _clone_structure(Variables1)
    v2_struct = _clone_structure(Variables2)
    version_field = DynamicVersionField(
        'h', 'version', help='Version number for this header.')
    payload_field = _Field(
        v1_struct,
        'variables',
        help='The rest of the variables data.')
    record = _DynamicStructure(
        name='VariablesRecord',
        fields=[version_field, payload_field],
        byte_order=byte_order)
    record._version_structures = {1: v1_struct, 2: v2_struct}
    record.setup()
    return record
346+
347+
348+
def _get_thread_local_variables_record():
    """Return this thread's cached VariablesRecord parser, creating it lazily.

    A per-thread instance keeps the mutable dynamic-structure state from
    being shared across concurrent record parses.
    """
    try:
        cached = _thread_local.variables_record_structure
    except AttributeError:
        cached = None
    if cached is None:
        cached = setup_variables_record(byte_order='=')
        _thread_local.variables_record_structure = cached
    return cached
355+
356+
357+
def _reset_variables_record_parser(variables_record_structure):
358+
variables_record_structure.byte_order = '='
359+
variables_record_structure.fields[-1].format = (
360+
variables_record_structure._version_structures[1])
361+
variables_record_structure.setup()
331362

332363

333364
class VariablesRecord (Record):
334365
def __init__(self, *args, **kwargs):
335366
super(VariablesRecord, self).__init__(*args, **kwargs)
336367
# self.header['version'] # record version always 0?
337-
VariablesRecordStructure.byte_order = '='
338-
VariablesRecordStructure.setup()
368+
variables_record_structure = _get_thread_local_variables_record()
369+
_reset_variables_record_parser(variables_record_structure)
339370
stream = _io.BytesIO(bytes(self.data))
340-
self.variables = VariablesRecordStructure.unpack_stream(stream)
371+
self.variables = variables_record_structure.unpack_stream(stream)
341372
self.namespace = {}
342373
for key, value in self.variables['variables'].items():
343374
if key not in ['var_header']:

igor2/struct.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,3 +826,42 @@ def unpack_from(self, buffer, offset=0, *args, **kwargs):
826826
args = super(Structure, self).unpack_from(
827827
buffer, offset, *args, **kwargs)
828828
return self._unpack_item(args)
829+
830+
831+
def clone_structure(structure, _memo=None):
    """Recursively clone a Structure/DynamicStructure tree.

    struct.Struct-derived instances cannot be copied with copy.copy/deepcopy,
    but we need independent parser instances for thread-safe dynamic unpacking.
    """
    memo = {} if _memo is None else _memo
    key = id(structure)
    if key in memo:
        return memo[key]

    duplicate = structure.__class__(
        name=structure.name,
        fields=[],
        byte_order=structure.byte_order,
    )
    # Register before descending so self-referential trees terminate.
    memo[key] = duplicate

    def _clone_field(original):
        # Nested Structure formats get clones of their own; other formats
        # are shared as-is, matching the original copy semantics.
        fmt = original.format
        if isinstance(fmt, Structure):
            fmt = clone_structure(fmt, _memo=memo)
        return original.__class__(
            fmt,
            original.name,
            default=original.default,
            help=original.help,
            count=original.count,
            array=original.array,
        )

    duplicate.fields = [_clone_field(field) for field in structure.fields]
    duplicate.setup()
    return duplicate

tests/test_pxp.py

Lines changed: 107 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import hashlib
12
import threading
23

34
import numpy as np
45

6+
from igor2.binarywave import load as loadibw
57
from igor2.packed import load as loadpxp
68

79
from helper import data_dir
@@ -13,6 +15,67 @@ def tostr(data):
1315
return data
1416

1517

18+
def _array_fingerprint(data):
19+
# Array signature for equality checks
20+
array = np.ascontiguousarray(data)
21+
return (
22+
str(array.dtype),
23+
tuple(int(i) for i in array.shape),
24+
hashlib.sha256(array.tobytes()).hexdigest(),
25+
)
26+
27+
28+
def _ibw_fingerprint(path):
    """Load an IBW file and summarize it as (version, wave name, payload signature)."""
    loaded = loadibw(path)
    wave_dict = loaded["wave"]
    version = int(loaded["version"])
    name = tostr(wave_dict["wave_header"]["bname"])
    # Metadata plus payload signature so repeated loads can be compared.
    return (version, name, _array_fingerprint(wave_dict["wData"]))
38+
39+
40+
def _pxp_fingerprint(path, initial_byte_order):
    """Load a packed experiment and summarize its tree shape and wave payloads."""
    records, filesystem = loadpxp(path, initial_byte_order=initial_byte_order)
    # Capture both the directory structure and every wave payload so that
    # concurrent loads can be compared against a serial baseline.
    root_keys = tuple(sorted(tostr(key) for key in filesystem["root"].keys()))
    wave_sigs = []
    for rec in records:
        if not hasattr(rec, "wave"):
            continue
        inner = rec.wave["wave"]
        bname = tostr(inner["wave_header"]["bname"])
        wave_sigs.append((bname,) + _array_fingerprint(inner["wData"]))
    return (len(records), root_keys, tuple(wave_sigs))
51+
52+
53+
def _run_concurrent_workload(worker_count, iterations_per_worker, task):
54+
barrier = threading.Barrier(worker_count)
55+
errors = []
56+
lock = threading.Lock()
57+
58+
def worker(thread_id):
59+
try:
60+
barrier.wait()
61+
for iteration in range(iterations_per_worker):
62+
task(thread_id, iteration)
63+
except Exception as exc:
64+
with lock:
65+
errors.append(f"thread {thread_id}: {exc!r}")
66+
67+
threads = []
68+
for thread_id in range(worker_count):
69+
thread = threading.Thread(target=worker, args=(thread_id,))
70+
threads.append(thread)
71+
thread.start()
72+
73+
for thread in threads:
74+
thread.join()
75+
76+
assert not errors, "\n".join(errors[:10])
77+
78+
1679
def test_pxp():
1780
data = loadpxp(data_dir / 'polar-graphs-demo.pxp')
1881
records = data[0]
@@ -157,22 +220,50 @@ def test_pxt():
157220

158221

159222
def test_thread_safe():
    """Concurrently re-parse PXP/PXT files and compare against serial results."""
    jobs = [
        (data_dir / "polar-graphs-demo.pxp", None),
        (data_dir / "packed-byteorder.pxt", ">"),
    ]
    # Serial baseline fingerprints, computed before any thread starts.
    expected = {}
    for job in jobs:
        expected[job] = _pxp_fingerprint(*job)

    def task(thread_id, iteration):
        # Alternate jobs so neighboring threads parse different files.
        selected = jobs[(thread_id + iteration) % len(jobs)]
        assert _pxp_fingerprint(*selected) == expected[selected]

    _run_concurrent_workload(
        worker_count=32,
        iterations_per_worker=12,
        task=task,
    )
240+
def test_thread_safe_mixed():
241+
ibw_jobs = [
242+
data_dir / "mac-double.ibw",
243+
data_dir / "win-double.ibw",
244+
data_dir / "mac-version5.ibw",
245+
data_dir / "win-version5.ibw",
246+
]
247+
pxp_jobs = [
248+
(data_dir / "polar-graphs-demo.pxp", None),
249+
(data_dir / "packed-byteorder.pxt", ">"),
250+
]
251+
252+
expected_ibw = {job: _ibw_fingerprint(job) for job in ibw_jobs}
253+
expected_pxp = {job: _pxp_fingerprint(*job) for job in pxp_jobs}
254+
all_jobs = (
255+
[("ibw", job) for job in ibw_jobs] + [("pxp", job) for job in pxp_jobs]
256+
)
257+
258+
def task(thread_id, iteration):
259+
kind, payload = all_jobs[(thread_id * 3 + iteration) % len(all_jobs)]
260+
if kind == "ibw":
261+
assert _ibw_fingerprint(payload) == expected_ibw[payload]
262+
else:
263+
assert _pxp_fingerprint(*payload) == expected_pxp[payload]
264+
265+
_run_concurrent_workload(
266+
worker_count=32,
267+
iterations_per_worker=10,
268+
task=task,
269+
)

0 commit comments

Comments
 (0)