Skip to content

Commit a15e734

Browse files
Add a lock to third_party/py/google/protobuf/pyext/message_factory.cc to resolve data race detected with Python Free Threading.
PiperOrigin-RevId: 903062390
1 parent 7e8ca3e commit a15e734

7 files changed

Lines changed: 111 additions & 41 deletions

File tree

python/google/protobuf/internal/thread_safe_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77

88
"""Unittest for thread safe"""
99

10+
import importlib
1011
import sys
1112
import threading
1213
import time
1314
import unittest
1415

16+
from google3.fitbit.research.sensing.common.proto import evaluation_pb2
1517
from google.protobuf import descriptor_pb2
1618
from google.protobuf import descriptor_pool
1719
from google.protobuf import message_factory
@@ -303,6 +305,26 @@ def AccessFields(msg, barrier) -> None:
303305
for thread in threads:
304306
thread.join()
305307

308+
def testConcurrentGetAndRegisterMessageClassDataRace(self):
309+
"""Reproduces the data race in GetMessageClass/RegisterMessageClass."""
310+
311+
barrier = threading.Barrier(10)
312+
imported = False
313+
lock = threading.Lock()
314+
315+
def Task():
316+
barrier.wait()
317+
for _ in range(50):
318+
evaluation_pb2.MeasurementCollection(measurements=[])
319+
320+
nonlocal imported
321+
if not imported:
322+
with lock:
323+
if not imported:
324+
from google3.monitoring.streamz.api import data_model_pb2
325+
imported = True
326+
327+
self.RunThreads(10, Task)
306328

307329
if __name__ == '__main__':
308330
unittest.main()

python/google/protobuf/pyext/descriptor.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,6 @@ static PyObject* GetConcreteClass(PyBaseDescriptor* self, void* closure) {
541541
return nullptr;
542542
}
543543

544-
Py_XINCREF(concrete_class);
545544
return concrete_class->AsPyObject();
546545
}
547546

python/google/protobuf/pyext/extension_dict.cc

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,15 @@ static Py_ssize_t len(ExtensionDict* self) {
5353
// happened to be linked in from C++ but not imported via Python. This is
5454
// for consistency with the pure Python implementation.
5555
if (fields[i]->file()->pool() == GetDefaultDescriptorPool()->pool &&
56-
fields[i]->message_type() != nullptr &&
57-
message_factory::GetMessageClass(
58-
cmessage::GetFactoryForMessage(self->parent),
59-
fields[i]->message_type()) == nullptr) {
60-
PyErr_Clear();
61-
continue;
56+
fields[i]->message_type() != nullptr) {
57+
CMessageClass* cls = message_factory::GetMessageClass(
58+
cmessage::GetFactoryForMessage(self->parent),
59+
fields[i]->message_type());
60+
if (cls == nullptr) {
61+
PyErr_Clear();
62+
continue;
63+
}
64+
Py_DECREF(cls);
6265
}
6366
++size;
6467
}
@@ -416,12 +419,15 @@ PyObject* IterNext(PyObject* _self) {
416419
// for consistency with the pure Python implementation.
417420
if (self->fields[index]->file()->pool() ==
418421
GetDefaultDescriptorPool()->pool &&
419-
self->fields[index]->message_type() != nullptr &&
420-
message_factory::GetMessageClass(
421-
cmessage::GetFactoryForMessage(self->extension_dict->parent),
422-
self->fields[index]->message_type()) == nullptr) {
423-
PyErr_Clear();
424-
continue;
422+
self->fields[index]->message_type() != nullptr) {
423+
CMessageClass* cls = message_factory::GetMessageClass(
424+
cmessage::GetFactoryForMessage(self->extension_dict->parent),
425+
self->fields[index]->message_type());
426+
if (cls == nullptr) {
427+
PyErr_Clear();
428+
continue;
429+
}
430+
Py_DECREF(cls);
425431
}
426432

427433
return PyFieldDescriptor_FromDescriptor(self->fields[index]);

python/google/protobuf/pyext/map_container.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,6 @@ PyObject* GetEntryClass(PyObject* _self) {
306306
CMessageClass* message_class = message_factory::GetMessageClass(
307307
cmessage::GetFactoryForMessage(self->parent),
308308
self->parent_field_descriptor->message_type());
309-
Py_XINCREF(message_class);
310309
return reinterpret_cast<PyObject*>(message_class);
311310
}
312311

python/google/protobuf/pyext/message.cc

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2112,12 +2112,14 @@ static PyObject* ListFields(CMessage* self) {
21122112
// happened to be linked in from C++ but not imported via Python. This is
21132113
// for consistency with the pure Python implementation.
21142114
if (fields[i]->file()->pool() == GetDefaultDescriptorPool()->pool &&
2115-
fields[i]->message_type() != nullptr &&
2116-
message_factory::GetMessageClass(GetFactoryForMessage(self),
2117-
fields[i]->message_type()) ==
2118-
nullptr) {
2119-
PyErr_Clear();
2120-
continue;
2115+
fields[i]->message_type() != nullptr) {
2116+
CMessageClass* cls = message_factory::GetMessageClass(
2117+
GetFactoryForMessage(self), fields[i]->message_type());
2118+
if (cls == nullptr) {
2119+
PyErr_Clear();
2120+
continue;
2121+
}
2122+
Py_DECREF(cls);
21212123
}
21222124
ScopedPyObjectPtr extensions(GetExtensionDict(self, nullptr));
21232125
if (extensions == nullptr) {
@@ -2687,6 +2689,7 @@ PyObject* GetFieldValue(CMessage* self,
26872689
}
26882690
py_container =
26892691
NewMessageMapContainer(self, field_descriptor, value_class);
2692+
Py_DECREF(value_class);
26902693
} else {
26912694
py_container = NewScalarMapContainer(self, field_descriptor);
26922695
}
@@ -2699,6 +2702,7 @@ PyObject* GetFieldValue(CMessage* self,
26992702
}
27002703
py_container = repeated_composite_container::NewContainer(
27012704
self, field_descriptor, message_class);
2705+
Py_DECREF(message_class);
27022706
} else {
27032707
py_container =
27042708
repeated_scalar_container::NewContainer(self, field_descriptor);

python/google/protobuf/pyext/message_factory.cc

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include "google/protobuf/dynamic_message.h"
1515
#include "google/protobuf/pyext/descriptor.h"
16+
#include "google/protobuf/pyext/free_threading_mutex.h"
1617
#include "google/protobuf/pyext/message.h"
1718
#include "google/protobuf/pyext/message_factory.h"
1819
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
@@ -31,7 +32,8 @@ namespace python {
3132

3233
namespace message_factory {
3334

34-
PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool) {
35+
PyMessageFactory* NewMessageFactory(PyTypeObject* type,
36+
PyDescriptorPool* pool) {
3537
PyMessageFactory* factory = reinterpret_cast<PyMessageFactory*>(
3638
PyType_GenericAlloc(type, 0));
3739
if (factory == nullptr) {
@@ -46,7 +48,15 @@ PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool)
4648
factory->pool = pool;
4749
Py_INCREF(pool);
4850

49-
factory->classes_by_descriptor = new PyMessageFactory::ClassesByMessageMap();
51+
// Explicitly construct the mutex using placement new because
52+
// PyType_GenericAlloc only zero-initializes memory.
53+
new (&factory->mutex) FreeThreadingMutex();
54+
55+
{
56+
FreeThreadingLockGuard lock(factory->mutex);
57+
factory->classes_by_descriptor =
58+
new PyMessageFactory::ClassesByMessageMap();
59+
}
5060

5161
return factory;
5262
}
@@ -81,21 +91,32 @@ PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
8191
static void Dealloc(PyObject* pself) {
8292
PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
8393

84-
typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
85-
for (iterator it = self->classes_by_descriptor->begin();
86-
it != self->classes_by_descriptor->end(); ++it) {
87-
Py_CLEAR(it->second);
94+
{
95+
FreeThreadingLockGuard lock(self->mutex);
96+
typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
97+
for (iterator it = self->classes_by_descriptor->begin();
98+
it != self->classes_by_descriptor->end(); ++it) {
99+
Py_CLEAR(it->second);
100+
}
101+
delete self->classes_by_descriptor;
102+
self->classes_by_descriptor = nullptr;
88103
}
89-
delete self->classes_by_descriptor;
104+
90105
delete self->message_factory;
91106
Py_CLEAR(self->pool);
107+
108+
// Explicitly call the destructor for the placement-new'd mutex.
109+
self->mutex.~FreeThreadingMutex();
110+
92111
PyObject_GC_UnTrack(pself);
93112
Py_TYPE(self)->tp_free(pself);
94113
}
95114

96115
static int GcTraverse(PyObject* pself, visitproc visit, void* arg) {
97116
PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
98117
Py_VISIT(self->pool);
118+
119+
FreeThreadingLockGuard lock(self->mutex);
99120
for (const auto& desc_and_class : *self->classes_by_descriptor) {
100121
Py_VISIT(desc_and_class.second);
101122
}
@@ -106,6 +127,8 @@ static int GcClear(PyObject* pself) {
106127
PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
107128
// Here it's important to not clear self->pool, so that the C++ DescriptorPool
108129
// is still alive when self->message_factory is destructed.
130+
131+
FreeThreadingLockGuard lock(self->mutex);
109132
for (auto& desc_and_class : *self->classes_by_descriptor) {
110133
Py_CLEAR(desc_and_class.second);
111134
}
@@ -118,13 +141,20 @@ int RegisterMessageClass(PyMessageFactory* self,
118141
const Descriptor* message_descriptor,
119142
CMessageClass* message_class) {
120143
Py_INCREF(message_class);
121-
typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
122-
std::pair<iterator, bool> ret = self->classes_by_descriptor->insert(
123-
std::make_pair(message_descriptor, message_class));
124-
if (!ret.second) {
125-
// Update case: DECREF the previous value.
126-
Py_DECREF(ret.first->second);
127-
ret.first->second = message_class;
144+
CMessageClass* old_class = nullptr;
145+
{
146+
FreeThreadingLockGuard lock(self->mutex);
147+
typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
148+
std::pair<iterator, bool> ret = self->classes_by_descriptor->insert(
149+
std::make_pair(message_descriptor, message_class));
150+
if (!ret.second) {
151+
// Update case: save the previous value to DECREF later.
152+
old_class = ret.first->second;
153+
ret.first->second = message_class;
154+
}
155+
}
156+
if (old_class != nullptr) {
157+
Py_DECREF(old_class);
128158
}
129159
return 0;
130160
}
@@ -134,11 +164,14 @@ CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
134164
// This is the same implementation as MessageFactory.GetPrototype().
135165

136166
// Do not create a MessageClass that already exists.
137-
std::unordered_map<const Descriptor*, CMessageClass*>::iterator it =
138-
self->classes_by_descriptor->find(descriptor);
139-
if (it != self->classes_by_descriptor->end()) {
140-
Py_INCREF(it->second);
141-
return it->second;
167+
{
168+
FreeThreadingLockGuard lock(self->mutex);
169+
std::unordered_map<const Descriptor*, CMessageClass*>::iterator it =
170+
self->classes_by_descriptor->find(descriptor);
171+
if (it != self->classes_by_descriptor->end()) {
172+
Py_INCREF(it->second);
173+
return it->second;
174+
}
142175
}
143176
ScopedPyObjectPtr py_descriptor(
144177
PyMessageDescriptor_FromDescriptor(descriptor));
@@ -192,13 +225,15 @@ CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
192225
// Retrieve the message class added to our database.
193226
CMessageClass* GetMessageClass(PyMessageFactory* self,
194227
const Descriptor* message_descriptor) {
228+
FreeThreadingLockGuard lock(self->mutex);
195229
typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
196230
iterator ret = self->classes_by_descriptor->find(message_descriptor);
197231
if (ret == self->classes_by_descriptor->end()) {
198232
PyErr_Format(PyExc_TypeError, "No message class registered for '%s'",
199233
std::string(message_descriptor->full_name()).c_str());
200234
return nullptr;
201235
} else {
236+
Py_INCREF(ret->second);
202237
return ret->second;
203238
}
204239
}

python/google/protobuf/pyext/message_factory.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
#include <Python.h>
1313

1414
#include <unordered_map>
15+
1516
#include "google/protobuf/descriptor.h"
1617
#include "google/protobuf/pyext/descriptor_pool.h"
18+
#include "google/protobuf/pyext/free_threading_mutex.h"
1719

1820
namespace google {
1921
namespace protobuf {
@@ -45,7 +47,10 @@ struct PyMessageFactory {
4547
// Python references to classes are owned by this PyDescriptorPool.
4648
typedef std::unordered_map<const Descriptor*, CMessageClass*>
4749
ClassesByMessageMap;
48-
ClassesByMessageMap* classes_by_descriptor;
50+
ClassesByMessageMap* classes_by_descriptor ABSL_GUARDED_BY(mutex);
51+
52+
// Mutex to protect classes_by_descriptor in free-threaded builds.
53+
FreeThreadingMutex mutex;
4954
};
5055

5156
extern PyTypeObject PyMessageFactory_Type;
@@ -61,7 +66,7 @@ int RegisterMessageClass(PyMessageFactory* self,
6166
const Descriptor* message_descriptor,
6267
CMessageClass* message_class);
6368
// Retrieves the Python class registered with the given message descriptor, or
64-
// fail with a TypeError. Returns a *borrowed* reference.
69+
// fail with a TypeError. Returns a *new* reference.
6570
CMessageClass* GetMessageClass(PyMessageFactory* self,
6671
const Descriptor* message_descriptor);
6772
// Retrieves the Python class registered with the given message descriptor.

0 commit comments

Comments
 (0)