1313
1414#include " google/protobuf/dynamic_message.h"
1515#include " google/protobuf/pyext/descriptor.h"
16+ #include " google/protobuf/pyext/free_threading_mutex.h"
1617#include " google/protobuf/pyext/message.h"
1718#include " google/protobuf/pyext/message_factory.h"
1819#include " google/protobuf/pyext/scoped_pyobject_ptr.h"
@@ -31,7 +32,8 @@ namespace python {
3132
3233namespace message_factory {
3334
34- PyMessageFactory* NewMessageFactory (PyTypeObject* type, PyDescriptorPool* pool) {
35+ PyMessageFactory* NewMessageFactory (PyTypeObject* type,
36+ PyDescriptorPool* pool) {
3537 PyMessageFactory* factory = reinterpret_cast <PyMessageFactory*>(
3638 PyType_GenericAlloc (type, 0 ));
3739 if (factory == nullptr ) {
@@ -46,7 +48,15 @@ PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool)
4648 factory->pool = pool;
4749 Py_INCREF (pool);
4850
49- factory->classes_by_descriptor = new PyMessageFactory::ClassesByMessageMap ();
51+ // Explicitly construct the mutex using placement new because
52+ // PyType_GenericAlloc only zero-initializes memory.
53+ new (&factory->mutex ) FreeThreadingMutex ();
54+
55+ {
56+ FreeThreadingLockGuard lock (factory->mutex );
57+ factory->classes_by_descriptor =
58+ new PyMessageFactory::ClassesByMessageMap ();
59+ }
5060
5161 return factory;
5262}
@@ -81,21 +91,32 @@ PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
8191static void Dealloc (PyObject* pself) {
8292 PyMessageFactory* self = reinterpret_cast <PyMessageFactory*>(pself);
8393
84- typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
85- for (iterator it = self->classes_by_descriptor ->begin ();
86- it != self->classes_by_descriptor ->end (); ++it) {
87- Py_CLEAR (it->second );
94+ {
95+ FreeThreadingLockGuard lock (self->mutex );
96+ typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
97+ for (iterator it = self->classes_by_descriptor ->begin ();
98+ it != self->classes_by_descriptor ->end (); ++it) {
99+ Py_CLEAR (it->second );
100+ }
101+ delete self->classes_by_descriptor ;
102+ self->classes_by_descriptor = nullptr ;
88103 }
89- delete self-> classes_by_descriptor ;
104+
90105 delete self->message_factory ;
91106 Py_CLEAR (self->pool );
107+
108+ // Explicitly call the destructor for the placement-new'd mutex.
109+ self->mutex .~FreeThreadingMutex ();
110+
92111 PyObject_GC_UnTrack (pself);
93112 Py_TYPE (self)->tp_free (pself);
94113}
95114
96115static int GcTraverse (PyObject* pself, visitproc visit, void * arg) {
97116 PyMessageFactory* self = reinterpret_cast <PyMessageFactory*>(pself);
98117 Py_VISIT (self->pool );
118+
119+ FreeThreadingLockGuard lock (self->mutex );
99120 for (const auto & desc_and_class : *self->classes_by_descriptor ) {
100121 Py_VISIT (desc_and_class.second );
101122 }
@@ -106,6 +127,8 @@ static int GcClear(PyObject* pself) {
106127 PyMessageFactory* self = reinterpret_cast <PyMessageFactory*>(pself);
107128 // Here it's important to not clear self->pool, so that the C++ DescriptorPool
108129 // is still alive when self->message_factory is destructed.
130+
131+ FreeThreadingLockGuard lock (self->mutex );
109132 for (auto & desc_and_class : *self->classes_by_descriptor ) {
110133 Py_CLEAR (desc_and_class.second );
111134 }
@@ -118,13 +141,20 @@ int RegisterMessageClass(PyMessageFactory* self,
118141 const Descriptor* message_descriptor,
119142 CMessageClass* message_class) {
120143 Py_INCREF (message_class);
121- typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
122- std::pair<iterator, bool > ret = self->classes_by_descriptor ->insert (
123- std::make_pair (message_descriptor, message_class));
124- if (!ret.second ) {
125- // Update case: DECREF the previous value.
126- Py_DECREF (ret.first ->second );
127- ret.first ->second = message_class;
144+ CMessageClass* old_class = nullptr ;
145+ {
146+ FreeThreadingLockGuard lock (self->mutex );
147+ typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
148+ std::pair<iterator, bool > ret = self->classes_by_descriptor ->insert (
149+ std::make_pair (message_descriptor, message_class));
150+ if (!ret.second ) {
151+ // Update case: save the previous value to DECREF later.
152+ old_class = ret.first ->second ;
153+ ret.first ->second = message_class;
154+ }
155+ }
156+ if (old_class != nullptr ) {
157+ Py_DECREF (old_class);
128158 }
129159 return 0 ;
130160}
@@ -134,11 +164,14 @@ CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
134164 // This is the same implementation as MessageFactory.GetPrototype().
135165
136166 // Do not create a MessageClass that already exists.
137- std::unordered_map<const Descriptor*, CMessageClass*>::iterator it =
138- self->classes_by_descriptor ->find (descriptor);
139- if (it != self->classes_by_descriptor ->end ()) {
140- Py_INCREF (it->second );
141- return it->second ;
167+ {
168+ FreeThreadingLockGuard lock (self->mutex );
169+ std::unordered_map<const Descriptor*, CMessageClass*>::iterator it =
170+ self->classes_by_descriptor ->find (descriptor);
171+ if (it != self->classes_by_descriptor ->end ()) {
172+ Py_INCREF (it->second );
173+ return it->second ;
174+ }
142175 }
143176 ScopedPyObjectPtr py_descriptor (
144177 PyMessageDescriptor_FromDescriptor (descriptor));
@@ -192,13 +225,15 @@ CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
192225// Retrieve the message class added to our database.
193226CMessageClass* GetMessageClass (PyMessageFactory* self,
194227 const Descriptor* message_descriptor) {
228+ FreeThreadingLockGuard lock (self->mutex );
195229 typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
196230 iterator ret = self->classes_by_descriptor ->find (message_descriptor);
197231 if (ret == self->classes_by_descriptor ->end ()) {
198232 PyErr_Format (PyExc_TypeError, " No message class registered for '%s'" ,
199233 std::string (message_descriptor->full_name ()).c_str ());
200234 return nullptr ;
201235 } else {
236+ Py_INCREF (ret->second );
202237 return ret->second ;
203238 }
204239}
0 commit comments