Skip to content

Commit 490bb94

Browse files
(attempt at) pickle Buffers
1 parent eab5edc commit 490bb94

2 files changed

Lines changed: 96 additions & 0 deletions

File tree

pyopencl/__init__.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2404,4 +2404,80 @@ def fsvm_empty_like(ctx, ary, alignment=None):
24042404
_KERNEL_ARG_CLASSES = (*_KERNEL_ARG_CLASSES, SVM)
24052405

24062406

2407+
# {{{ pickling support
2408+
2409+
import threading
2410+
from contextlib import contextmanager
2411+
2412+
2413+
_QUEUE_FOR_PICKLING_TLS = threading.local()
2414+
2415+
2416+
@contextmanager
2417+
def queue_for_pickling(queue):
2418+
r"""A context manager that, for the current thread, sets the command queue
2419+
to be used for pickling and unpickling :class:`Buffer`\ s to *queue*."""
2420+
try:
2421+
existing_pickle_queue = _QUEUE_FOR_PICKLING_TLS.queue
2422+
except AttributeError:
2423+
existing_pickle_queue = None
2424+
2425+
if existing_pickle_queue is not None:
2426+
raise RuntimeError("queue_for_pickling should not be called "
2427+
"inside the context of its own invocation.")
2428+
2429+
_QUEUE_FOR_PICKLING_TLS.queue = queue
2430+
try:
2431+
yield None
2432+
finally:
2433+
_QUEUE_FOR_PICKLING_TLS.queue = None
2434+
2435+
2436+
def _getstate_buffer(self):
2437+
import pyopencl as cl
2438+
state = {}
2439+
state["size"] = self.size
2440+
state["flags"] = self.flags
2441+
2442+
try:
2443+
queue = _QUEUE_FOR_PICKLING_TLS.queue
2444+
except AttributeError:
2445+
queue = None
2446+
2447+
if queue is None:
2448+
raise RuntimeError("CL Buffer instances can only be pickled while "
2449+
"queue_for_pickling is active.")
2450+
2451+
a = bytearray(self.size)
2452+
cl.enqueue_copy(queue, a, self)
2453+
2454+
state["_pickle_data"] = a
2455+
2456+
return state
2457+
2458+
2459+
def _setstate_buffer(self, state):
2460+
try:
2461+
queue = _QUEUE_FOR_PICKLING_TLS.queue
2462+
except AttributeError:
2463+
queue = None
2464+
2465+
if queue is None:
2466+
raise RuntimeError("CL Buffer instances can only be unpickled while "
2467+
"queue_for_pickling is active.")
2468+
2469+
size = state["size"]
2470+
flags = state["flags"]
2471+
2472+
import pyopencl as cl
2473+
2474+
a = state["_pickle_data"]
2475+
Buffer.__init__(self, queue.context, flags | cl.mem_flags.COPY_HOST_PTR, size, a)
2476+
2477+
2478+
Buffer.__getstate__ = _getstate_buffer
2479+
Buffer.__setstate__ = _setstate_buffer
2480+
2481+
# }}}
2482+
24072483
# vim: foldmethod=marker

test/test_array.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2453,6 +2453,26 @@ def test_array_pickling(ctx_factory):
24532453
# }}}
24542454

24552455

2456+
def test_buffer_pickling(ctx_factory):
2457+
context = ctx_factory()
2458+
queue = cl.CommandQueue(context)
2459+
2460+
a = np.array([1, 2, 3, 4, 5]).astype(np.float32)
2461+
a_gpu = cl.Buffer(context, cl.mem_flags.READ_WRITE, a.nbytes)
2462+
cl.enqueue_copy(queue, a_gpu, a)
2463+
2464+
import pickle
2465+
2466+
with pytest.raises(cl.RuntimeError):
2467+
pickle.dumps(a_gpu)
2468+
2469+
with cl.queue_for_pickling(queue):
2470+
a_gpu_pickled = pickle.loads(pickle.dumps(a_gpu))
2471+
2472+
a_new = np.empty_like(a)
2473+
cl.enqueue_copy(queue, a_new, a_gpu_pickled)
2474+
assert np.all(a_new == a)
2475+
24562476
# }}}
24572477

24582478

0 commit comments

Comments
 (0)