|
| 1 | +# Copyright (c) 2024 The Pybind Development Team. All rights reserved. |
| 2 | +# |
| 3 | +# All rights reserved. Use of this source code is governed by a |
| 4 | +# BSD-style license that can be found in the LICENSE file. |
| 5 | +"""Regression test for use-after-free in FindFileContainingExtension. |
| 6 | +
|
| 7 | +When a dynamic Python proto with extensions is passed through pybind11_protobuf, |
| 8 | +the C++ bridge creates a DescriptorPool backed by the Python pool. During |
| 9 | +proto parsing on the C++ side, extension lookup calls back into the Python |
| 10 | +pool via FindExtensionByNumber. If the intermediate FieldDescriptor wrapper |
| 11 | +returned by FindExtensionByNumber is not kept alive while its `.file` attribute |
| 12 | +is accessed and serialized, the UPB runtime may free the wrapper before |
| 13 | +DescriptorPoolDatabase::CopyToFileDescriptorProto finishes, leading to a |
| 14 | +heap-use-after-free. |
| 15 | +
|
| 16 | +This test creates a fully dynamic proto schema (base message + extension in a |
| 17 | +separate file), serializes a message with the extension set, then roundtrips |
| 18 | +it through a C++ pybind11 function that accepts `const proto2::Message&`. |
| 19 | +Without the fix in proto_cast_util.cc, this crashes under AddressSanitizer. |
| 20 | +""" |
| 21 | + |
| 22 | +import gc |
| 23 | + |
| 24 | +from absl.testing import absltest |
| 25 | +from google.protobuf import descriptor_pb2 |
| 26 | +from google.protobuf import descriptor_pool |
| 27 | +from google.protobuf import message_factory |
| 28 | + |
| 29 | +from pybind11_protobuf.tests import dynamic_message_module as m |
| 30 | + |
| 31 | + |
| 32 | +def _make_base_file(): |
| 33 | + """Create a FileDescriptorProto with an extendable message.""" |
| 34 | + return descriptor_pb2.FileDescriptorProto( |
| 35 | + name='_dynamic_ext_test_base.proto', |
| 36 | + package='pybind11.dynamic_ext_test', |
| 37 | + syntax='proto2', |
| 38 | + message_type=[ |
| 39 | + descriptor_pb2.DescriptorProto( |
| 40 | + name='Extendable', |
| 41 | + field=[ |
| 42 | + descriptor_pb2.FieldDescriptorProto( |
| 43 | + name='id', |
| 44 | + number=1, |
| 45 | + type=descriptor_pb2.FieldDescriptorProto.TYPE_INT32, |
| 46 | + label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL, |
| 47 | + ), |
| 48 | + ], |
| 49 | + extension_range=[ |
| 50 | + descriptor_pb2.DescriptorProto.ExtensionRange( |
| 51 | + start=100, end=10000 |
| 52 | + ), |
| 53 | + ], |
| 54 | + ), |
| 55 | + ], |
| 56 | + ) |
| 57 | + |
| 58 | + |
| 59 | +def _make_ext_file(): |
| 60 | + """Create a FileDescriptorProto defining an extension in a separate file.""" |
| 61 | + return descriptor_pb2.FileDescriptorProto( |
| 62 | + name='_dynamic_ext_test_ext.proto', |
| 63 | + package='pybind11.dynamic_ext_test', |
| 64 | + syntax='proto2', |
| 65 | + dependency=['_dynamic_ext_test_base.proto'], |
| 66 | + message_type=[ |
| 67 | + descriptor_pb2.DescriptorProto( |
| 68 | + name='Annotation', |
| 69 | + field=[ |
| 70 | + descriptor_pb2.FieldDescriptorProto( |
| 71 | + name='value', |
| 72 | + number=1, |
| 73 | + type=descriptor_pb2.FieldDescriptorProto.TYPE_STRING, |
| 74 | + label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL, |
| 75 | + ), |
| 76 | + ], |
| 77 | + ), |
| 78 | + ], |
| 79 | + extension=[ |
| 80 | + descriptor_pb2.FieldDescriptorProto( |
| 81 | + name='annotation_ext', |
| 82 | + number=200, |
| 83 | + type=descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, |
| 84 | + type_name='.pybind11.dynamic_ext_test.Annotation', |
| 85 | + extendee='.pybind11.dynamic_ext_test.Extendable', |
| 86 | + label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL, |
| 87 | + ), |
| 88 | + ], |
| 89 | + ) |
| 90 | + |
| 91 | + |
| 92 | +class DynamicExtensionRoundtripTest(absltest.TestCase): |
| 93 | + """Regression test: dynamic proto extensions via pybind11 roundtrip.""" |
| 94 | + |
| 95 | + def test_roundtrip_dynamic_extension(self): |
| 96 | + """A dynamic proto with extensions must survive a pybind11 roundtrip.""" |
| 97 | + base_file = _make_base_file() |
| 98 | + ext_file = _make_ext_file() |
| 99 | + |
| 100 | + # Create a pool with both the base and extension files. |
| 101 | + pool = descriptor_pool.DescriptorPool() |
| 102 | + pool.Add(base_file) |
| 103 | + pool.Add(ext_file) |
| 104 | + |
| 105 | + classes = message_factory.GetMessageClassesForFiles( |
| 106 | + [base_file.name, ext_file.name], pool |
| 107 | + ) |
| 108 | + |
| 109 | + # Build a message with the extension set. |
| 110 | + extendable_cls = classes['pybind11.dynamic_ext_test.Extendable'] |
| 111 | + ext_desc = pool.FindExtensionByName( |
| 112 | + 'pybind11.dynamic_ext_test.annotation_ext' |
| 113 | + ) |
| 114 | + msg = extendable_cls() |
| 115 | + msg.id = 42 |
| 116 | + msg.Extensions[ext_desc].value = 'hello' |
| 117 | + |
| 118 | + # Drop reference to descriptor and force GC to ensure the wrapper is |
| 119 | + # collected if not held elsewhere, triggering the UAF when C++ tries |
| 120 | + # to look it up again. |
| 121 | + del ext_desc |
| 122 | + gc.collect() |
| 123 | + |
| 124 | + # Roundtrip through C++ via `const proto2::Message&`. |
| 125 | + # This triggers AllocateCProtoFromPythonSymbolDatabase which wraps the |
| 126 | + # Python pool in a C++ DescriptorPool, and ParsePartialFromString which |
| 127 | + # triggers FindFileContainingExtension to look up the extension. |
| 128 | + result = m.roundtrip(msg) |
| 129 | + |
| 130 | + # Verify the message survived the roundtrip. |
| 131 | + # The result is a Python proto from the same dynamic pool. |
| 132 | + self.assertEqual(result.id, 42) |
| 133 | + |
| 134 | + # The extension data should be preserved (either as a known extension |
| 135 | + # or as unknown fields that re-serialize identically). |
| 136 | + self.assertEqual( |
| 137 | + result.SerializeToString(), msg.SerializeToString() |
| 138 | + ) |
| 139 | + |
| 140 | + |
| 141 | +if __name__ == '__main__': |
| 142 | + absltest.main() |
0 commit comments