Skip to content

Commit 58c255b

Browse files
committed
Add some more tests.
1 parent 37e4556 commit 58c255b

File tree

3 files changed

+61
-23
lines changed

3 files changed

+61
-23
lines changed

Lib/concurrent/interpreters/__init__.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# aliases:
88
from _interpreters import (
99
InterpreterError, InterpreterNotFoundError, NotShareableError,
10-
is_shareable, SharedObjectProxy
10+
is_shareable, SharedObjectProxy, share
1111
)
1212
from ._queues import (
1313
create as create_queue,
@@ -245,19 +245,3 @@ def call_in_thread(self, callable, /, *args, **kwargs):
245245
t = threading.Thread(target=self._call, args=(callable, args, kwargs))
246246
t.start()
247247
return t
248-
249-
250-
def _can_natively_share(obj):
251-
if isinstance(obj, SharedObjectProxy):
252-
return False
253-
254-
return _interpreters.is_shareable(obj)
255-
256-
257-
def share(obj):
258-
"""Wrap the object in a shareable object proxy that allows cross-interpreter
259-
access.
260-
"""
261-
if _can_natively_share(obj):
262-
return obj
263-
return _interpreters.share(obj)

Lib/test/test_interpreters/test_object_proxy.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
# Raise SkipTest if subinterpreters not supported.
77
import_helper.import_module("_interpreters")
8-
from concurrent.interpreters import share, SharedObjectProxy
8+
from concurrent.interpreters import NotShareableError, share, SharedObjectProxy
99
from test.test_interpreters.utils import TestBase
1010
from threading import Barrier, Thread, Lock, local
1111
from concurrent import interpreters
@@ -287,6 +287,56 @@ def foo():
287287

288288
self.assertTrue(called)
289289

290+
def test_proxy_reshare_does_not_copy(self):
291+
class Test:
292+
pass
293+
294+
proxy = share(Test())
295+
reproxy = share(proxy)
296+
self.assertIs(proxy, reproxy)
297+
298+
def test_object_share_method(self):
299+
class Test:
300+
def __share__(self):
301+
return 42
302+
303+
shared = share(Test())
304+
self.assertEqual(shared, 42)
305+
306+
def test_object_share_method_failure(self):
307+
class Test:
308+
def __share__(self):
309+
return self
310+
311+
exception = RuntimeError("ouch")
312+
class Evil:
313+
def __share__(self):
314+
raise exception
315+
316+
with self.assertRaises(NotShareableError):
317+
share(Test())
318+
319+
with self.assertRaises(RuntimeError) as exc:
320+
share(Evil())
321+
322+
self.assertIs(exc.exception, exception)
323+
324+
def test_proxy_manual_construction(self):
325+
called = False
326+
327+
class Test:
328+
def __init__(self):
329+
self.attr = 24
330+
331+
def __share__(self):
332+
nonlocal called
333+
called = True
334+
return 42
335+
336+
proxy = SharedObjectProxy(Test())
337+
self.assertIsInstance(proxy, SharedObjectProxy)
338+
self.assertFalse(called)
339+
self.assertEqual(proxy.attr, 24)
290340

291341
if __name__ == "__main__":
292342
unittest.main()

Modules/_interpretersmodule.c

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2548,18 +2548,16 @@ _interpreters_capture_exception_impl(PyObject *module, PyObject *exc_arg)
25482548
}
25492549

25502550
static PyObject *
2551-
call_share_method_steal(PyObject *method)
2551+
call_share_method_steal(PyThreadState *tstate, PyObject *method)
25522552
{
2553+
assert(tstate != NULL);
25532554
assert(method != NULL);
25542555
PyObject *result = PyObject_CallNoArgs(method);
25552556
Py_DECREF(method);
25562557
if (result == NULL) {
25572558
return NULL;
25582559
}
25592560

2560-
PyThreadState *tstate = _PyThreadState_GET();
2561-
assert(tstate != NULL);
2562-
25632561
if (_PyObject_CheckXIData(tstate, result) < 0) {
25642562
PyObject *exc = _PyErr_GetRaisedException(tstate);
25652563
_PyXIData_FormatNotShareableError(tstate, "__share__() returned unshareable object: %R", result);
@@ -2586,12 +2584,18 @@ static PyObject *
25862584
_interpreters_share(PyObject *module, PyObject *op)
25872585
/*[clinic end generated code: output=e2ce861ae3b58508 input=5fb300b5598bb7d2]*/
25882586
{
2587+
PyThreadState *tstate = _PyThreadState_GET();
2588+
if (_PyObject_CheckXIData(tstate, op) == 0) {
2589+
return Py_NewRef(op);
2590+
}
2591+
PyErr_Clear();
2592+
25892593
PyObject *share_method;
25902594
if (PyObject_GetOptionalAttrString(op, "__share__", &share_method) < 0) {
25912595
return NULL;
25922596
}
25932597
if (share_method != NULL) {
2594-
return call_share_method_steal(share_method /* stolen */);
2598+
return call_share_method_steal(tstate, share_method /* stolen */);
25952599
}
25962600

25972601
return _sharedobjectproxy_create(op, _PyInterpreterState_GET());

0 commit comments

Comments
 (0)