diff --git a/python/google/protobuf/internal/thread_safe_test.py b/python/google/protobuf/internal/thread_safe_test.py index eefe40dbf4998..3cae7887a547b 100644 --- a/python/google/protobuf/internal/thread_safe_test.py +++ b/python/google/protobuf/internal/thread_safe_test.py @@ -7,11 +7,13 @@ """Unittest for thread safe""" +import importlib import sys import threading import time import unittest +from google3.fitbit.research.sensing.common.proto import evaluation_pb2 from google.protobuf import descriptor_pb2 from google.protobuf import descriptor_pool from google.protobuf import message_factory @@ -303,6 +305,26 @@ def AccessFields(msg, barrier) -> None: for thread in threads: thread.join() + def testConcurrentGetAndRegisterMessageClassDataRace(self): + """Reproduces the data race in GetMessageClass/RegisterMessageClass.""" + + barrier = threading.Barrier(10) + imported = False + lock = threading.Lock() + + def Task(): + barrier.wait() + for _ in range(50): + evaluation_pb2.MeasurementCollection(measurements=[]) + + nonlocal imported + if not imported: + with lock: + if not imported: + from google3.monitoring.streamz.api import data_model_pb2 + imported = True + + self.RunThreads(10, Task) if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc index 5b04e3f0c96a7..c8b32f582aa26 100644 --- a/python/google/protobuf/pyext/descriptor.cc +++ b/python/google/protobuf/pyext/descriptor.cc @@ -541,7 +541,6 @@ static PyObject* GetConcreteClass(PyBaseDescriptor* self, void* closure) { return nullptr; } - Py_XINCREF(concrete_class); return concrete_class->AsPyObject(); } diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc index 76b01adbe11a6..0d8cd02cdfecf 100644 --- a/python/google/protobuf/pyext/extension_dict.cc +++ b/python/google/protobuf/pyext/extension_dict.cc @@ -53,12 +53,15 @@ static Py_ssize_t len(ExtensionDict* self) { // happened to be linked in from C++ but not imported via Python. This is // for consistency with the pure Python implementation. if (fields[i]->file()->pool() == GetDefaultDescriptorPool()->pool && - fields[i]->message_type() != nullptr && - message_factory::GetMessageClass( - cmessage::GetFactoryForMessage(self->parent), - fields[i]->message_type()) == nullptr) { - PyErr_Clear(); - continue; + fields[i]->message_type() != nullptr) { + CMessageClass* cls = message_factory::GetMessageClass( + cmessage::GetFactoryForMessage(self->parent), + fields[i]->message_type()); + if (cls == nullptr) { + PyErr_Clear(); + continue; + } + Py_DECREF(cls); } ++size; } @@ -416,12 +419,15 @@ PyObject* IterNext(PyObject* _self) { // for consistency with the pure Python implementation. if (self->fields[index]->file()->pool() == GetDefaultDescriptorPool()->pool && - self->fields[index]->message_type() != nullptr && - message_factory::GetMessageClass( - cmessage::GetFactoryForMessage(self->extension_dict->parent), - self->fields[index]->message_type()) == nullptr) { - PyErr_Clear(); - continue; + self->fields[index]->message_type() != nullptr) { + CMessageClass* cls = message_factory::GetMessageClass( + cmessage::GetFactoryForMessage(self->extension_dict->parent), + self->fields[index]->message_type()); + if (cls == nullptr) { + PyErr_Clear(); + continue; + } + Py_DECREF(cls); } return PyFieldDescriptor_FromDescriptor(self->fields[index]); diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc index e1a98809d5bb5..83e8067e9357b 100644 --- a/python/google/protobuf/pyext/map_container.cc +++ b/python/google/protobuf/pyext/map_container.cc @@ -306,7 +306,6 @@ PyObject* GetEntryClass(PyObject* _self) { CMessageClass* message_class = message_factory::GetMessageClass( cmessage::GetFactoryForMessage(self->parent), self->parent_field_descriptor->message_type()); - Py_XINCREF(message_class); return reinterpret_cast(message_class); } diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index 63f44145d08b0..01eb864039e84 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -2112,12 +2112,14 @@ static PyObject* ListFields(CMessage* self) { // happened to be linked in from C++ but not imported via Python. This is // for consistency with the pure Python implementation. if (fields[i]->file()->pool() == GetDefaultDescriptorPool()->pool && - fields[i]->message_type() != nullptr && - message_factory::GetMessageClass(GetFactoryForMessage(self), - fields[i]->message_type()) == - nullptr) { - PyErr_Clear(); - continue; + fields[i]->message_type() != nullptr) { + CMessageClass* cls = message_factory::GetMessageClass( + GetFactoryForMessage(self), fields[i]->message_type()); + if (cls == nullptr) { + PyErr_Clear(); + continue; + } + Py_DECREF(cls); } ScopedPyObjectPtr extensions(GetExtensionDict(self, nullptr)); if (extensions == nullptr) { @@ -2687,6 +2689,7 @@ PyObject* GetFieldValue(CMessage* self, } py_container = NewMessageMapContainer(self, field_descriptor, value_class); + Py_DECREF(value_class); } else { py_container = NewScalarMapContainer(self, field_descriptor); } @@ -2699,6 +2702,7 @@ PyObject* GetFieldValue(CMessage* self, } py_container = repeated_composite_container::NewContainer( self, field_descriptor, message_class); + Py_DECREF(message_class); } else { py_container = repeated_scalar_container::NewContainer(self, field_descriptor); diff --git a/python/google/protobuf/pyext/message_factory.cc b/python/google/protobuf/pyext/message_factory.cc index d8d68312d7ea0..22d114e877bd9 100644 --- a/python/google/protobuf/pyext/message_factory.cc +++ b/python/google/protobuf/pyext/message_factory.cc @@ -13,6 +13,7 @@ #include "google/protobuf/dynamic_message.h" #include "google/protobuf/pyext/descriptor.h" +#include "google/protobuf/pyext/free_threading_mutex.h" #include "google/protobuf/pyext/message.h" #include "google/protobuf/pyext/message_factory.h" #include "google/protobuf/pyext/scoped_pyobject_ptr.h" @@ -31,7 +32,8 @@ namespace python { namespace message_factory { -PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool) { +PyMessageFactory* NewMessageFactory(PyTypeObject* type, + PyDescriptorPool* pool) { PyMessageFactory* factory = reinterpret_cast( PyType_GenericAlloc(type, 0)); if (factory == nullptr) { @@ -46,7 +48,15 @@ PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool) factory->pool = pool; Py_INCREF(pool); - factory->classes_by_descriptor = new PyMessageFactory::ClassesByMessageMap(); + // Explicitly construct the mutex using placement new because + // PyType_GenericAlloc only zero-initializes memory. + new (&factory->mutex) FreeThreadingMutex(); + + { + FreeThreadingLockGuard lock(factory->mutex); + factory->classes_by_descriptor = + new PyMessageFactory::ClassesByMessageMap(); + } return factory; } @@ -81,14 +91,23 @@ PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) { static void Dealloc(PyObject* pself) { PyMessageFactory* self = reinterpret_cast(pself); - typedef PyMessageFactory::ClassesByMessageMap::iterator iterator; - for (iterator it = self->classes_by_descriptor->begin(); - it != self->classes_by_descriptor->end(); ++it) { - Py_CLEAR(it->second); + { + FreeThreadingLockGuard lock(self->mutex); + typedef PyMessageFactory::ClassesByMessageMap::iterator iterator; + for (iterator it = self->classes_by_descriptor->begin(); + it != self->classes_by_descriptor->end(); ++it) { + Py_CLEAR(it->second); + } + delete self->classes_by_descriptor; + self->classes_by_descriptor = nullptr; } - delete self->classes_by_descriptor; + delete self->message_factory; Py_CLEAR(self->pool); + + // Explicitly call the destructor for the placement-new'd mutex. + self->mutex.~FreeThreadingMutex(); + PyObject_GC_UnTrack(pself); Py_TYPE(self)->tp_free(pself); } @@ -96,6 +115,8 @@ static void Dealloc(PyObject* pself) { static int GcTraverse(PyObject* pself, visitproc visit, void* arg) { PyMessageFactory* self = reinterpret_cast(pself); Py_VISIT(self->pool); + + FreeThreadingLockGuard lock(self->mutex); for (const auto& desc_and_class : *self->classes_by_descriptor) { Py_VISIT(desc_and_class.second); } @@ -106,6 +127,8 @@ static int GcClear(PyObject* pself) { PyMessageFactory* self = reinterpret_cast(pself); // Here it's important to not clear self->pool, so that the C++ DescriptorPool // is still alive when self->message_factory is destructed. + + FreeThreadingLockGuard lock(self->mutex); for (auto& desc_and_class : *self->classes_by_descriptor) { Py_CLEAR(desc_and_class.second); } @@ -118,13 +141,20 @@ int RegisterMessageClass(PyMessageFactory* self, const Descriptor* message_descriptor, CMessageClass* message_class) { Py_INCREF(message_class); - typedef PyMessageFactory::ClassesByMessageMap::iterator iterator; - std::pair ret = self->classes_by_descriptor->insert( - std::make_pair(message_descriptor, message_class)); - if (!ret.second) { - // Update case: DECREF the previous value. - Py_DECREF(ret.first->second); - ret.first->second = message_class; + CMessageClass* old_class = nullptr; + { + FreeThreadingLockGuard lock(self->mutex); + typedef PyMessageFactory::ClassesByMessageMap::iterator iterator; + std::pair ret = self->classes_by_descriptor->insert( + std::make_pair(message_descriptor, message_class)); + if (!ret.second) { + // Update case: save the previous value to DECREF later. + old_class = ret.first->second; + ret.first->second = message_class; + } + } + if (old_class != nullptr) { + Py_DECREF(old_class); } return 0; } @@ -134,11 +164,14 @@ CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self, // This is the same implementation as MessageFactory.GetPrototype(). // Do not create a MessageClass that already exists. - std::unordered_map::iterator it = - self->classes_by_descriptor->find(descriptor); - if (it != self->classes_by_descriptor->end()) { - Py_INCREF(it->second); - return it->second; + { + FreeThreadingLockGuard lock(self->mutex); + std::unordered_map::iterator it = + self->classes_by_descriptor->find(descriptor); + if (it != self->classes_by_descriptor->end()) { + Py_INCREF(it->second); + return it->second; + } } ScopedPyObjectPtr py_descriptor( PyMessageDescriptor_FromDescriptor(descriptor)); @@ -192,6 +225,7 @@ CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self, // Retrieve the message class added to our database. CMessageClass* GetMessageClass(PyMessageFactory* self, const Descriptor* message_descriptor) { + FreeThreadingLockGuard lock(self->mutex); typedef PyMessageFactory::ClassesByMessageMap::iterator iterator; iterator ret = self->classes_by_descriptor->find(message_descriptor); if (ret == self->classes_by_descriptor->end()) { @@ -199,6 +233,7 @@ CMessageClass* GetMessageClass(PyMessageFactory* self, std::string(message_descriptor->full_name()).c_str()); return nullptr; } else { + Py_INCREF(ret->second); return ret->second; } } diff --git a/python/google/protobuf/pyext/message_factory.h b/python/google/protobuf/pyext/message_factory.h index 3225a130f74fc..fd1306bc0e93e 100644 --- a/python/google/protobuf/pyext/message_factory.h +++ b/python/google/protobuf/pyext/message_factory.h @@ -12,8 +12,10 @@ #include #include + #include "google/protobuf/descriptor.h" #include "google/protobuf/pyext/descriptor_pool.h" +#include "google/protobuf/pyext/free_threading_mutex.h" namespace google { namespace protobuf { @@ -45,7 +47,10 @@ struct PyMessageFactory { // Python references to classes are owned by this PyDescriptorPool. typedef std::unordered_map ClassesByMessageMap; - ClassesByMessageMap* classes_by_descriptor; + ClassesByMessageMap* classes_by_descriptor ABSL_GUARDED_BY(mutex); + + // Mutex to protect classes_by_descriptor in free-threaded builds. + FreeThreadingMutex mutex; }; extern PyTypeObject PyMessageFactory_Type; @@ -61,7 +66,7 @@ int RegisterMessageClass(PyMessageFactory* self, const Descriptor* message_descriptor, CMessageClass* message_class); // Retrieves the Python class registered with the given message descriptor, or -// fail with a TypeError. Returns a *borrowed* reference. +// fail with a TypeError. Returns a *new* reference. CMessageClass* GetMessageClass(PyMessageFactory* self, const Descriptor* message_descriptor); // Retrieves the Python class registered with the given message descriptor.