Skip to content
This repository was archived by the owner on Dec 9, 2025. It is now read-only.

Commit 79b735a

Browse files
Make upb numpy type checks consistent with pure python and cpp.
PiperOrigin-RevId: 464907203
1 parent e09d6fc commit 79b735a

4 files changed

Lines changed: 235 additions & 6 deletions

File tree

.github/workflows/python_tests.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ jobs:
8989
run: pip install tzdata
9090
# Only needed on Windows, Linux ships with tzdata.
9191
if: ${{ contains(matrix.os, 'windows') }}
92+
- name: Install numpy
93+
run: pip install numpy
9294
- name: Install Protobuf Wheels
9395
run: pip install -vvv --no-index --find-links wheels protobuf protobuftests
9496
- name: Test that module is importable

python/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ py_extension(
205205
deps = [
206206
"//:collections",
207207
"//:descriptor_upb_proto_reflection",
208+
"//:port",
208209
"//:reflection",
209210
"//:table_internal",
210211
"//:textformat",

python/convert.c

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
#include "upb/reflection.h"
3333
#include "upb/util/compare.h"
3434

35+
// Must be last.
36+
#include "upb/port_def.inc"
37+
3538
PyObject* PyUpb_UpbToPy(upb_MessageValue val, const upb_FieldDef* f,
3639
PyObject* arena) {
3740
switch (upb_FieldDef_CType(f)) {
@@ -150,6 +153,34 @@ static upb_MessageValue PyUpb_MaybeCopyString(const char* ptr, size_t size,
150153
return ret;
151154
}
152155

156+
const char* upb_FieldDef_TypeString(const upb_FieldDef* f) {
157+
switch (upb_FieldDef_CType(f)) {
158+
case kUpb_CType_Double:
159+
return "double";
160+
case kUpb_CType_Float:
161+
return "float";
162+
case kUpb_CType_Int64:
163+
return "int64";
164+
case kUpb_CType_Int32:
165+
return "int32";
166+
case kUpb_CType_UInt64:
167+
return "uint64";
168+
case kUpb_CType_UInt32:
169+
return "uint32";
170+
case kUpb_CType_Enum:
171+
return "enum";
172+
case kUpb_CType_Bool:
173+
return "bool";
174+
case kUpb_CType_String:
175+
return "string";
176+
case kUpb_CType_Bytes:
177+
return "bytes";
178+
case kUpb_CType_Message:
179+
return "message";
180+
}
181+
UPB_UNREACHABLE();
182+
}
183+
153184
static bool PyUpb_PyToUpbEnum(PyObject* obj, const upb_EnumDef* e,
154185
upb_MessageValue* val) {
155186
if (PyUnicode_Check(obj)) {
@@ -176,6 +207,20 @@ static bool PyUpb_PyToUpbEnum(PyObject* obj, const upb_EnumDef* e,
176207
}
177208
}
178209

210+
// Reports whether `obj` is an instance of a type whose `__name__` is
// "ndarray" (i.e. a numpy array), which must be rejected for scalar fields.
//
// Returns true and sets a Python TypeError naming the expected type of field
// `f` when `obj` is an ndarray. Returns false with no Python error pending
// otherwise, so the caller can continue with its normal conversion.
bool PyUpb_IsNumpyNdarray(PyObject* obj, const upb_FieldDef* f) {
  PyObject* type_name_obj =
      PyObject_GetAttrString((PyObject*)Py_TYPE(obj), "__name__");
  if (type_name_obj == NULL) {
    // The attribute lookup failed and left an exception set. We only wanted
    // to probe the type name, so discard that error and treat the object as
    // "not an ndarray"; the regular conversion path will report any real
    // type problem.
    PyErr_Clear();
    return false;
  }

  bool is_ndarray = false;
  // Guard against a NULL result before handing it to strcmp().
  // NOTE(review): PyUpb_GetStrData is defined elsewhere; presumably it yields
  // NULL when `__name__` is not a str/bytes object -- confirm.
  const char* type_name = PyUpb_GetStrData(type_name_obj);
  if (type_name == NULL) {
    // Drop any error raised while extracting the name; see rationale above.
    PyErr_Clear();
  } else if (strcmp(type_name, "ndarray") == 0) {
    PyErr_Format(PyExc_TypeError,
                 "%S has type ndarray, but expected one of: %s", obj,
                 upb_FieldDef_TypeString(f));
    is_ndarray = true;
  }

  Py_DECREF(type_name_obj);
  return is_ndarray;
}
223+
179224
bool PyUpb_PyToUpb(PyObject* obj, const upb_FieldDef* f, upb_MessageValue* val,
180225
upb_Arena* arena) {
181226
switch (upb_FieldDef_CType(f)) {
@@ -190,12 +235,15 @@ bool PyUpb_PyToUpb(PyObject* obj, const upb_FieldDef* f, upb_MessageValue* val,
190235
case kUpb_CType_UInt64:
191236
return PyUpb_GetUint64(obj, &val->uint64_val);
192237
case kUpb_CType_Float:
238+
if (PyUpb_IsNumpyNdarray(obj, f)) return false;
193239
val->float_val = PyFloat_AsDouble(obj);
194240
return !PyErr_Occurred();
195241
case kUpb_CType_Double:
242+
if (PyUpb_IsNumpyNdarray(obj, f)) return false;
196243
val->double_val = PyFloat_AsDouble(obj);
197244
return !PyErr_Occurred();
198245
case kUpb_CType_Bool:
246+
if (PyUpb_IsNumpyNdarray(obj, f)) return false;
199247
val->bool_val = PyLong_AsLong(obj);
200248
return !PyErr_Occurred();
201249
case kUpb_CType_Bytes: {
@@ -223,6 +271,7 @@ bool PyUpb_PyToUpb(PyObject* obj, const upb_FieldDef* f, upb_MessageValue* val,
223271
return true;
224272
}
225273
case kUpb_CType_Message:
274+
// TODO(b/238226055): Include ctype in error message.
226275
PyErr_Format(PyExc_ValueError, "Message objects may not be assigned",
227276
upb_FieldDef_CType(f));
228277
return false;
@@ -392,3 +441,5 @@ bool upb_Message_IsEqual(const upb_Message* msg1, const upb_Message* msg2,
392441
return upb_Message_UnknownFieldsAreEqual(uf1, usize1, uf2, usize2, 100) ==
393442
kUpb_UnknownCompareResult_Equal;
394443
}
444+
445+
#include "upb/port_undef.inc"

python/pb_unit_tests/numpy_test_wrapper.py

Lines changed: 181 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,188 @@
2727

2828
# begin:google_only
2929
# from google.protobuf.internal.numpy_test import *
30-
#
31-
# # TODO(b/227379846): upb does not match pure-python and fast cpp behavior for
32-
# # assignment of numpy arrays to proto float or multidimensional arrays to
33-
# # repeated fields yet.
34-
# NumpyFloatProtoTest.testNumpyFloatArrayToScalar_RaisesTypeError.__unittest_expecting_failure__ = True
35-
# NumpyFloatProtoTest.testNumpyDim2FloatArrayToRepeated_RaisesTypeError.__unittest_expecting_failure__ = True
3630
# end:google_only
3731

32+
# begin:github_only
33+
# TODO(b/240447513) Delete workaround after numpy_test is open-sourced in
34+
# protobuf github.
35+
import unittest
36+
37+
import numpy as np
38+
39+
from google.protobuf import unittest_pb2
40+
from google.protobuf.internal import testing_refleaks
41+
42+
# Single TestAllTypes instance that the assignment tests write into.
message = unittest_pb2.TestAllTypes()

# For each dtype: one numpy scalar, two 1-D arrays (length 1 and 2) and two
# 2-D arrays (1x1 and 2x2), all zero-filled.

# float64
np_float_scalar = np.float64(0.0)
np_1_float_array = np.zeros((1,), dtype=np.float64)
np_2_float_array = np.zeros((2,), dtype=np.float64)
np_11_float_array = np.zeros((1, 1), dtype=np.float64)
np_22_float_array = np.zeros((2, 2), dtype=np.float64)

# int64
np_int_scalar = np.int64(0)
np_1_int_array = np.zeros((1,), dtype=np.int64)
np_2_int_array = np.zeros((2,), dtype=np.int64)
np_11_int_array = np.zeros((1, 1), dtype=np.int64)
np_22_int_array = np.zeros((2, 2), dtype=np.int64)

# uint64
np_uint_scalar = np.uint64(0)
np_1_uint_array = np.zeros((1,), dtype=np.uint64)
np_2_uint_array = np.zeros((2,), dtype=np.uint64)
np_11_uint_array = np.zeros((1, 1), dtype=np.uint64)
np_22_uint_array = np.zeros((2, 2), dtype=np.uint64)

# bool
np_bool_scalar = np.bool_(False)
np_1_bool_array = np.zeros((1,), dtype=np.bool_)
np_2_bool_array = np.zeros((2,), dtype=np.bool_)
np_11_bool_array = np.zeros((1, 1), dtype=np.bool_)
np_22_bool_array = np.zeros((2, 2), dtype=np.bool_)
66+
67+
@testing_refleaks.TestCase
class NumpyIntProtoTest(unittest.TestCase):

  # A 1-D ndarray of ints is accepted by a repeated integer field.
  def testNumpyDim1IntArrayToRepeated_IsValid(self):
    for values in (np_1_int_array, np_2_int_array):
      message.repeated_int64[:] = values

    for values in (np_1_uint_array, np_2_uint_array):
      message.repeated_uint64[:] = values

  # A 2-D ndarray of ints is rejected by a repeated integer field.
  def testNumpyDim2IntArrayToRepeated_RaisesTypeError(self):
    for values in (np_11_int_array, np_22_int_array):
      with self.assertRaises(TypeError):
        message.repeated_int64[:] = values

    for values in (np_11_uint_array, np_22_uint_array):
      with self.assertRaises(TypeError):
        message.repeated_uint64[:] = values

  # An ndarray of floats, whatever its shape, is rejected by a repeated
  # integer field.
  def testNumpyFloatArrayToRepeated_RaisesTypeError(self):
    for values in (np_1_float_array, np_11_float_array, np_22_float_array):
      with self.assertRaises(TypeError):
        message.repeated_int64[:] = values

  # A numpy integer scalar is accepted by a scalar integer field.
  def testNumpyIntScalarToScalar_IsValid(self):
    message.optional_int64 = np_int_scalar
    message.optional_uint64 = np_uint_scalar

  # An ndarray of ints, whatever its shape, is rejected by a scalar integer
  # field.
  def testNumpyIntArrayToScalar_RaisesTypeError(self):
    for value in (np_1_int_array, np_11_int_array, np_22_int_array):
      with self.assertRaises(TypeError):
        message.optional_int64 = value

    for value in (np_1_uint_array, np_11_uint_array, np_22_uint_array):
      with self.assertRaises(TypeError):
        message.optional_uint64 = value

  # An ndarray of floats, whatever its shape, is rejected by a scalar integer
  # field.
  def testNumpyFloatArrayToScalar_RaisesTypeError(self):
    for value in (np_1_float_array, np_11_float_array, np_22_float_array):
      with self.assertRaises(TypeError):
        message.optional_int64 = value
128+
129+
@testing_refleaks.TestCase
class NumpyFloatProtoTest(unittest.TestCase):

  # A 1-D ndarray of floats is accepted by a repeated float field.
  def testNumpyDim1FloatArrayToRepeated_IsValid(self):
    for values in (np_1_float_array, np_2_float_array):
      message.repeated_float[:] = values

  # A 2-D ndarray of floats is rejected by a repeated float field.
  def testNumpyDim2FloatArrayToRepeated_RaisesTypeError(self):
    for values in (np_11_float_array, np_22_float_array):
      with self.assertRaises(TypeError):
        message.repeated_float[:] = values

  # A numpy float scalar is accepted by a scalar float field.
  def testNumpyFloatScalarToScalar_IsValid(self):
    message.optional_float = np_float_scalar

  # An ndarray of floats, whatever its shape, is rejected by a scalar float
  # field.
  def testNumpyFloatArrayToScalar_RaisesTypeError(self):
    for value in (np_1_float_array, np_11_float_array, np_22_float_array):
      with self.assertRaises(TypeError):
        message.optional_float = value
156+
157+
@testing_refleaks.TestCase
class NumpyBoolProtoTest(unittest.TestCase):

  # A 1-D ndarray of bools is accepted by a repeated bool field.
  def testNumpyDim1BoolArrayToRepeated_IsValid(self):
    for values in (np_1_bool_array, np_2_bool_array):
      message.repeated_bool[:] = values

  # A 2-D ndarray of bools is rejected by a repeated bool field.
  def testNumpyDim2BoolArrayToRepeated_RaisesTypeError(self):
    for values in (np_11_bool_array, np_22_bool_array):
      with self.assertRaises(TypeError):
        message.repeated_bool[:] = values

  # A numpy bool scalar is accepted by a scalar bool field.
  def testNumpyBoolScalarToScalar_IsValid(self):
    message.optional_bool = np_bool_scalar

  # An ndarray of bools, whatever its shape, is rejected by a scalar bool
  # field.
  def testNumpyBoolArrayToScalar_RaisesTypeError(self):
    for value in (np_1_bool_array, np_11_bool_array, np_22_bool_array):
      with self.assertRaises(TypeError):
        message.optional_bool = value
184+
185+
@testing_refleaks.TestCase
class NumpyProtoIndexingTest(unittest.TestCase):

  # A numpy integer scalar works as an index into a repeated field.
  def testNumpyIntScalarIndexing_Passes(self):
    data = unittest_pb2.TestAllTypes(repeated_int64=[0, 1, 2])
    first = data.repeated_int64[np.int64(0)]
    self.assertEqual(0, first)

  # A negative numpy integer scalar indexes from the end.
  def testNumpyNegative1IntScalarIndexing_Passes(self):
    data = unittest_pb2.TestAllTypes(repeated_int64=[0, 1, 2])
    last = data.repeated_int64[np.int64(-1)]
    self.assertEqual(2, last)

  # A numpy float scalar is not a valid index.
  def testNumpyFloatScalarIndexing_Fails(self):
    data = unittest_pb2.TestAllTypes(repeated_int64=[0, 1, 2])
    with self.assertRaises(TypeError):
      _ = data.repeated_int64[np.float64(0.0)]

  # An ndarray is not a valid index, regardless of how it was constructed or
  # of its shape.
  def testNumpyIntArrayIndexing_Fails(self):
    data = unittest_pb2.TestAllTypes(repeated_int64=[0, 1, 2])
    with self.assertRaises(TypeError):
      _ = data.repeated_int64[np.array([0])]
    # The index is built inside the `with` block, once per shape.
    for shape in ((1,), (1, 1)):
      with self.assertRaises(TypeError):
        _ = data.repeated_int64[np.ndarray(shape,
                                           buffer=np.array([0]),
                                           dtype=int)]
211+
# end:github_only
212+
38213
if __name__ == '__main__':
39214
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)