Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions python/google/protobuf/internal/thread_safe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@

"""Unittest for thread safe"""

import importlib
import sys
import threading
import time
import unittest

from google3.fitbit.research.sensing.common.proto import evaluation_pb2
from google.protobuf import descriptor_pb2
from google.protobuf import descriptor_pool
from google.protobuf import message_factory
Expand Down Expand Up @@ -303,6 +305,26 @@ def AccessFields(msg, barrier) -> None:
for thread in threads:
thread.join()

def testConcurrentGetAndRegisterMessageClassDataRace(self):
"""Reproduces the data race in GetMessageClass/RegisterMessageClass."""

barrier = threading.Barrier(10)
imported = False
lock = threading.Lock()

def Task():
barrier.wait()
for _ in range(50):
evaluation_pb2.MeasurementCollection(measurements=[])

nonlocal imported
if not imported:
with lock:
if not imported:
from google3.monitoring.streamz.api import data_model_pb2
imported = True

self.RunThreads(10, Task)

if __name__ == '__main__':
unittest.main()
1 change: 0 additions & 1 deletion python/google/protobuf/pyext/descriptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,6 @@ static PyObject* GetConcreteClass(PyBaseDescriptor* self, void* closure) {
return nullptr;
}

Py_XINCREF(concrete_class);
return concrete_class->AsPyObject();
}

Expand Down
30 changes: 18 additions & 12 deletions python/google/protobuf/pyext/extension_dict.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,15 @@ static Py_ssize_t len(ExtensionDict* self) {
// happened to be linked in from C++ but not imported via Python. This is
// for consistency with the pure Python implementation.
if (fields[i]->file()->pool() == GetDefaultDescriptorPool()->pool &&
fields[i]->message_type() != nullptr &&
message_factory::GetMessageClass(
cmessage::GetFactoryForMessage(self->parent),
fields[i]->message_type()) == nullptr) {
PyErr_Clear();
continue;
fields[i]->message_type() != nullptr) {
CMessageClass* cls = message_factory::GetMessageClass(
cmessage::GetFactoryForMessage(self->parent),
fields[i]->message_type());
if (cls == nullptr) {
PyErr_Clear();
continue;
}
Py_DECREF(cls);
}
++size;
}
Expand Down Expand Up @@ -416,12 +419,15 @@ PyObject* IterNext(PyObject* _self) {
// for consistency with the pure Python implementation.
if (self->fields[index]->file()->pool() ==
GetDefaultDescriptorPool()->pool &&
self->fields[index]->message_type() != nullptr &&
message_factory::GetMessageClass(
cmessage::GetFactoryForMessage(self->extension_dict->parent),
self->fields[index]->message_type()) == nullptr) {
PyErr_Clear();
continue;
self->fields[index]->message_type() != nullptr) {
CMessageClass* cls = message_factory::GetMessageClass(
cmessage::GetFactoryForMessage(self->extension_dict->parent),
self->fields[index]->message_type());
if (cls == nullptr) {
PyErr_Clear();
continue;
}
Py_DECREF(cls);
}

return PyFieldDescriptor_FromDescriptor(self->fields[index]);
Expand Down
1 change: 0 additions & 1 deletion python/google/protobuf/pyext/map_container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,6 @@ PyObject* GetEntryClass(PyObject* _self) {
CMessageClass* message_class = message_factory::GetMessageClass(
cmessage::GetFactoryForMessage(self->parent),
self->parent_field_descriptor->message_type());
Py_XINCREF(message_class);
return reinterpret_cast<PyObject*>(message_class);
}

Expand Down
16 changes: 10 additions & 6 deletions python/google/protobuf/pyext/message.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2112,12 +2112,14 @@ static PyObject* ListFields(CMessage* self) {
// happened to be linked in from C++ but not imported via Python. This is
// for consistency with the pure Python implementation.
if (fields[i]->file()->pool() == GetDefaultDescriptorPool()->pool &&
fields[i]->message_type() != nullptr &&
message_factory::GetMessageClass(GetFactoryForMessage(self),
fields[i]->message_type()) ==
nullptr) {
PyErr_Clear();
continue;
fields[i]->message_type() != nullptr) {
CMessageClass* cls = message_factory::GetMessageClass(
GetFactoryForMessage(self), fields[i]->message_type());
if (cls == nullptr) {
PyErr_Clear();
continue;
}
Py_DECREF(cls);
}
ScopedPyObjectPtr extensions(GetExtensionDict(self, nullptr));
if (extensions == nullptr) {
Expand Down Expand Up @@ -2687,6 +2689,7 @@ PyObject* GetFieldValue(CMessage* self,
}
py_container =
NewMessageMapContainer(self, field_descriptor, value_class);
Py_DECREF(value_class);
} else {
py_container = NewScalarMapContainer(self, field_descriptor);
}
Expand All @@ -2699,6 +2702,7 @@ PyObject* GetFieldValue(CMessage* self,
}
py_container = repeated_composite_container::NewContainer(
self, field_descriptor, message_class);
Py_DECREF(message_class);
} else {
py_container =
repeated_scalar_container::NewContainer(self, field_descriptor);
Expand Down
73 changes: 54 additions & 19 deletions python/google/protobuf/pyext/message_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/pyext/descriptor.h"
#include "google/protobuf/pyext/free_threading_mutex.h"
#include "google/protobuf/pyext/message.h"
#include "google/protobuf/pyext/message_factory.h"
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
Expand All @@ -31,7 +32,8 @@ namespace python {

namespace message_factory {

PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool) {
PyMessageFactory* NewMessageFactory(PyTypeObject* type,
PyDescriptorPool* pool) {
PyMessageFactory* factory = reinterpret_cast<PyMessageFactory*>(
PyType_GenericAlloc(type, 0));
if (factory == nullptr) {
Expand All @@ -46,7 +48,15 @@ PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool)
factory->pool = pool;
Py_INCREF(pool);

factory->classes_by_descriptor = new PyMessageFactory::ClassesByMessageMap();
// Explicitly construct the mutex using placement new because
// PyType_GenericAlloc only zero-initializes memory.
new (&factory->mutex) FreeThreadingMutex();

{
FreeThreadingLockGuard lock(factory->mutex);
factory->classes_by_descriptor =
new PyMessageFactory::ClassesByMessageMap();
}

return factory;
}
Expand Down Expand Up @@ -81,21 +91,32 @@ PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
static void Dealloc(PyObject* pself) {
PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);

typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
for (iterator it = self->classes_by_descriptor->begin();
it != self->classes_by_descriptor->end(); ++it) {
Py_CLEAR(it->second);
{
FreeThreadingLockGuard lock(self->mutex);
typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
for (iterator it = self->classes_by_descriptor->begin();
it != self->classes_by_descriptor->end(); ++it) {
Py_CLEAR(it->second);
}
delete self->classes_by_descriptor;
self->classes_by_descriptor = nullptr;
}
delete self->classes_by_descriptor;

delete self->message_factory;
Py_CLEAR(self->pool);

// Explicitly call the destructor for the placement-new'd mutex.
self->mutex.~FreeThreadingMutex();

PyObject_GC_UnTrack(pself);
Py_TYPE(self)->tp_free(pself);
}

static int GcTraverse(PyObject* pself, visitproc visit, void* arg) {
PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
Py_VISIT(self->pool);

FreeThreadingLockGuard lock(self->mutex);
for (const auto& desc_and_class : *self->classes_by_descriptor) {
Py_VISIT(desc_and_class.second);
}
Expand All @@ -106,6 +127,8 @@ static int GcClear(PyObject* pself) {
PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
// Here it's important to not clear self->pool, so that the C++ DescriptorPool
// is still alive when self->message_factory is destructed.

FreeThreadingLockGuard lock(self->mutex);
for (auto& desc_and_class : *self->classes_by_descriptor) {
Py_CLEAR(desc_and_class.second);
}
Expand All @@ -118,13 +141,20 @@ int RegisterMessageClass(PyMessageFactory* self,
const Descriptor* message_descriptor,
CMessageClass* message_class) {
Py_INCREF(message_class);
typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
std::pair<iterator, bool> ret = self->classes_by_descriptor->insert(
std::make_pair(message_descriptor, message_class));
if (!ret.second) {
// Update case: DECREF the previous value.
Py_DECREF(ret.first->second);
ret.first->second = message_class;
CMessageClass* old_class = nullptr;
{
FreeThreadingLockGuard lock(self->mutex);
typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
std::pair<iterator, bool> ret = self->classes_by_descriptor->insert(
std::make_pair(message_descriptor, message_class));
if (!ret.second) {
// Update case: save the previous value to DECREF later.
old_class = ret.first->second;
ret.first->second = message_class;
}
}
if (old_class != nullptr) {
Py_DECREF(old_class);
}
return 0;
}
Expand All @@ -134,11 +164,14 @@ CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
// This is the same implementation as MessageFactory.GetPrototype().

// Do not create a MessageClass that already exists.
std::unordered_map<const Descriptor*, CMessageClass*>::iterator it =
self->classes_by_descriptor->find(descriptor);
if (it != self->classes_by_descriptor->end()) {
Py_INCREF(it->second);
return it->second;
{
FreeThreadingLockGuard lock(self->mutex);
std::unordered_map<const Descriptor*, CMessageClass*>::iterator it =
self->classes_by_descriptor->find(descriptor);
if (it != self->classes_by_descriptor->end()) {
Py_INCREF(it->second);
return it->second;
}
}
ScopedPyObjectPtr py_descriptor(
PyMessageDescriptor_FromDescriptor(descriptor));
Expand Down Expand Up @@ -192,13 +225,15 @@ CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
// Retrieve the message class added to our database.
CMessageClass* GetMessageClass(PyMessageFactory* self,
const Descriptor* message_descriptor) {
FreeThreadingLockGuard lock(self->mutex);
typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
iterator ret = self->classes_by_descriptor->find(message_descriptor);
if (ret == self->classes_by_descriptor->end()) {
PyErr_Format(PyExc_TypeError, "No message class registered for '%s'",
std::string(message_descriptor->full_name()).c_str());
return nullptr;
} else {
Py_INCREF(ret->second);
return ret->second;
}
}
Expand Down
9 changes: 7 additions & 2 deletions python/google/protobuf/pyext/message_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
#include <Python.h>

#include <unordered_map>

#include "google/protobuf/descriptor.h"
#include "google/protobuf/pyext/descriptor_pool.h"
#include "google/protobuf/pyext/free_threading_mutex.h"

namespace google {
namespace protobuf {
Expand Down Expand Up @@ -45,7 +47,10 @@ struct PyMessageFactory {
// Python references to classes are owned by this PyDescriptorPool.
typedef std::unordered_map<const Descriptor*, CMessageClass*>
ClassesByMessageMap;
ClassesByMessageMap* classes_by_descriptor;
ClassesByMessageMap* classes_by_descriptor ABSL_GUARDED_BY(mutex);

// Mutex to protect classes_by_descriptor in free-threaded builds.
FreeThreadingMutex mutex;
};

extern PyTypeObject PyMessageFactory_Type;
Expand All @@ -61,7 +66,7 @@ int RegisterMessageClass(PyMessageFactory* self,
const Descriptor* message_descriptor,
CMessageClass* message_class);
// Retrieves the Python class registered with the given message descriptor, or
// fail with a TypeError. Returns a *borrowed* reference.
// fail with a TypeError. Returns a *new* reference.
CMessageClass* GetMessageClass(PyMessageFactory* self,
const Descriptor* message_descriptor);
// Retrieves the Python class registered with the given message descriptor.
Expand Down
Loading