Skip to content

Commit 34b0826

Browse files
(attempt at) pickle Buffers
1 parent b4a65b3 commit 34b0826

2 files changed

Lines changed: 96 additions & 0 deletions

File tree

pyopencl/__init__.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2407,4 +2407,80 @@ def fsvm_empty_like(ctx, ary, alignment=None):
24072407
_KERNEL_ARG_CLASSES = (*_KERNEL_ARG_CLASSES, SVM)
24082408

24092409

2410+
# {{{ pickling support
2411+
2412+
import threading
2413+
from contextlib import contextmanager
2414+
2415+
2416+
_QUEUE_FOR_PICKLING_TLS = threading.local()
2417+
2418+
2419+
@contextmanager
2420+
def queue_for_pickling(queue):
2421+
r"""A context manager that, for the current thread, sets the command queue
2422+
to be used for pickling and unpickling :class:`Buffer`\ s to *queue*."""
2423+
try:
2424+
existing_pickle_queue = _QUEUE_FOR_PICKLING_TLS.queue
2425+
except AttributeError:
2426+
existing_pickle_queue = None
2427+
2428+
if existing_pickle_queue is not None:
2429+
raise RuntimeError("queue_for_pickling should not be called "
2430+
"inside the context of its own invocation.")
2431+
2432+
_QUEUE_FOR_PICKLING_TLS.queue = queue
2433+
try:
2434+
yield None
2435+
finally:
2436+
_QUEUE_FOR_PICKLING_TLS.queue = None
2437+
2438+
2439+
def _getstate_buffer(self):
2440+
import pyopencl as cl
2441+
state = {}
2442+
state["size"] = self.size
2443+
state["flags"] = self.flags
2444+
2445+
try:
2446+
queue = _QUEUE_FOR_PICKLING_TLS.queue
2447+
except AttributeError:
2448+
queue = None
2449+
2450+
if queue is None:
2451+
raise RuntimeError("CL Buffer instances can only be pickled while "
2452+
"queue_for_pickling is active.")
2453+
2454+
a = bytearray(self.size)
2455+
cl.enqueue_copy(queue, a, self)
2456+
2457+
state["_pickle_data"] = a
2458+
2459+
return state
2460+
2461+
2462+
def _setstate_buffer(self, state):
2463+
try:
2464+
queue = _QUEUE_FOR_PICKLING_TLS.queue
2465+
except AttributeError:
2466+
queue = None
2467+
2468+
if queue is None:
2469+
raise RuntimeError("CL Buffer instances can only be unpickled while "
2470+
"queue_for_pickling is active.")
2471+
2472+
size = state["size"]
2473+
flags = state["flags"]
2474+
2475+
import pyopencl as cl
2476+
2477+
a = state["_pickle_data"]
2478+
Buffer.__init__(self, queue.context, flags | cl.mem_flags.COPY_HOST_PTR, size, a)
2479+
2480+
2481+
Buffer.__getstate__ = _getstate_buffer
2482+
Buffer.__setstate__ = _setstate_buffer
2483+
2484+
# }}}
2485+
24102486
# vim: foldmethod=marker

test/test_array.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2453,6 +2453,26 @@ def test_array_pickling(ctx_factory):
24532453
# }}}
24542454

24552455

2456+
def test_buffer_pickling(ctx_factory):
2457+
context = ctx_factory()
2458+
queue = cl.CommandQueue(context)
2459+
2460+
a = np.array([1, 2, 3, 4, 5]).astype(np.float32)
2461+
a_gpu = cl.Buffer(context, cl.mem_flags.READ_WRITE, a.nbytes)
2462+
cl.enqueue_copy(queue, a_gpu, a)
2463+
2464+
import pickle
2465+
2466+
with pytest.raises(cl.RuntimeError):
2467+
pickle.dumps(a_gpu)
2468+
2469+
with cl.queue_for_pickling(queue):
2470+
a_gpu_pickled = pickle.loads(pickle.dumps(a_gpu))
2471+
2472+
a_new = np.empty_like(a)
2473+
cl.enqueue_copy(queue, a_new, a_gpu_pickled)
2474+
assert np.all(a_new == a)
2475+
24562476
# }}}
24572477

24582478

0 commit comments

Comments
 (0)