Skip to content

Commit c5af8af

Browse files
authored
[mypyc] Add __contains__ method to vec (#21138)
This is used in interpreted/non-native code only, where it will speed up 'in' operations significantly. This is also added for consistency with built-in sequence types. I used coding agents assist but did changes in small increments and reviewed each change manually.
1 parent 24b4a81 commit c5af8af

File tree

7 files changed

+180
-0
lines changed

7 files changed

+180
-0
lines changed

mypy/typeshed/stubs/librt/librt/vecs.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class vec(Generic[T]):
1414
@overload
1515
def __getitem__(self, i: slice, /) -> vec[T]: ...
1616
def __setitem__(self, i: i64, o: T, /) -> None: ...
17+
def __contains__(self, o: object, /) -> bool: ...
1718
def __iter__(self) -> Iterator[T]: ...
1819

1920
def append(v: vec[T], o: T, /) -> vec[T]: ...

mypyc/lib-rt/vecs/vec_nested.c

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,24 @@ static int vec_ass_item(PyObject *self, Py_ssize_t i, PyObject *o) {
178178
}
179179
}
180180

181+
static int vec_contains(PyObject *self, PyObject *value) {
182+
VecNested v = ((VecNestedObject *)self)->vec;
183+
for (Py_ssize_t i = 0; i < v.len; i++) {
184+
PyObject *item = box_vec_item_by_index(v, i);
185+
if (item == NULL)
186+
return -1;
187+
if (item == value) {
188+
Py_DECREF(item);
189+
return 1;
190+
}
191+
int cmp = PyObject_RichCompareBool(item, value, Py_EQ);
192+
Py_DECREF(item);
193+
if (cmp != 0)
194+
return cmp; // 1 if equal, -1 on error
195+
}
196+
return 0;
197+
}
198+
181199
static PyObject *compare_vec_eq(VecNested x, VecNested y, int op) {
182200
int cmp = 1;
183201
PyObject *res;
@@ -414,6 +432,7 @@ static PyMappingMethods VecNestedMapping = {
414432
static PySequenceMethods VecNestedSequence = {
415433
.sq_item = vec_get_item,
416434
.sq_ass_item = vec_ass_item,
435+
.sq_contains = vec_contains,
417436
};
418437

419438
static PyMethodDef vec_methods[] = {

mypyc/lib-rt/vecs/vec_t.c

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,22 @@ static int vec_ass_item(PyObject *self, Py_ssize_t i, PyObject *o) {
210210
}
211211
}
212212

213+
static int vec_contains(PyObject *self, PyObject *value) {
214+
VecT v = ((VecTObject *)self)->vec;
215+
for (Py_ssize_t i = 0; i < v.len; i++) {
216+
PyObject *item = v.buf->items[i];
217+
if (item == value) {
218+
return 1;
219+
}
220+
Py_INCREF(item);
221+
int cmp = PyObject_RichCompareBool(item, value, Py_EQ);
222+
Py_DECREF(item);
223+
if (cmp != 0)
224+
return cmp; // 1 if equal, -1 on error
225+
}
226+
return 0;
227+
}
228+
213229
static PyObject *vec_richcompare(PyObject *self, PyObject *other, int op) {
214230
PyObject *res;
215231
if (op == Py_EQ || op == Py_NE) {
@@ -410,6 +426,7 @@ static PyMappingMethods VecTMapping = {
410426
static PySequenceMethods VecTSequence = {
411427
.sq_item = vec_get_item,
412428
.sq_ass_item = vec_ass_item,
429+
.sq_contains = vec_contains,
413430
};
414431

415432
static PyMethodDef vec_methods[] = {

mypyc/lib-rt/vecs/vec_template.c

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <Python.h>
3030
#include "librt_vecs.h"
3131
#include "vecs_internal.h"
32+
#include "mypyc_util.h"
3233

3334
inline static VEC vec_error() {
3435
VEC v = { .len = -1 };
@@ -235,6 +236,32 @@ static int vec_ass_item(PyObject *self, Py_ssize_t i, PyObject *o) {
235236
}
236237
}
237238

239+
static int vec_contains(PyObject *self, PyObject *value) {
240+
ITEM_C_TYPE x = UNBOX_ITEM(value);
241+
if (unlikely(IS_UNBOX_ERROR(x))) {
242+
if (PyErr_Occurred())
243+
PyErr_Clear();
244+
// Fall back to boxed comparison (e.g. 2.0 == 2)
245+
VEC v = ((VEC_OBJECT *)self)->vec;
246+
for (Py_ssize_t i = 0; i < v.len; i++) {
247+
PyObject *boxed = BOX_ITEM(v.buf->items[i]);
248+
if (boxed == NULL)
249+
return -1;
250+
int cmp = PyObject_RichCompareBool(boxed, value, Py_EQ);
251+
Py_DECREF(boxed);
252+
if (cmp != 0)
253+
return cmp; // 1 if equal, -1 on error
254+
}
255+
return 0;
256+
}
257+
VEC v = ((VEC_OBJECT *)self)->vec;
258+
for (Py_ssize_t i = 0; i < v.len; i++) {
259+
if (v.buf->items[i] == x)
260+
return 1;
261+
}
262+
return 0;
263+
}
264+
238265
static Py_ssize_t vec_length(PyObject *o) {
239266
return ((VEC_OBJECT *)o)->vec.len;
240267
}
@@ -348,6 +375,7 @@ static PyMappingMethods vec_mapping_methods = {
348375
static PySequenceMethods vec_sequence_methods = {
349376
.sq_item = vec_get_item,
350377
.sq_ass_item = vec_ass_item,
378+
.sq_contains = vec_contains,
351379
};
352380

353381
static PyMethodDef vec_methods[] = {

mypyc/test-data/run-vecs-i64-interp.test

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,41 @@ def test_contains() -> None:
312312
vv = vec[i64]()
313313
assert 3 not in vv
314314

315+
def test_contains_boundary_values() -> None:
316+
v = vec[i64]([0, 2**63 - 1, -2**63])
317+
assert 0 in v
318+
assert 2**63 - 1 in v
319+
assert -2**63 in v
320+
assert 1 not in v
321+
assert -1 not in v
322+
323+
def test_contains_wrong_type() -> None:
324+
# Wrong type should return False, not raise
325+
v: Any = vec[i64]([1, 2, 3])
326+
assert 'x' not in v
327+
assert None not in v
328+
assert [] not in v
329+
# Overflow should also return False
330+
assert 2**63 not in v
331+
assert -2**63 - 1 not in v
332+
333+
def test_contains_float() -> None:
334+
# float == int equality should work via boxed fallback
335+
v: Any = vec[i64]([1, 2, 3])
336+
assert 2.0 in v
337+
assert 3.0 in v
338+
assert 1.5 not in v
339+
assert 0.0 not in v
340+
341+
def test_contains_bool() -> None:
342+
# bool is a subclass of int, so True/False should work
343+
v: Any = vec[i64]([0, 1, 5])
344+
assert True in v
345+
assert False in v
346+
v2: Any = vec[i64]([2, 3])
347+
assert True not in v2
348+
assert False not in v2
349+
315350
def test_remove() -> None:
316351
a = [4, 7, 9]
317352
for i in a:

mypyc/test-data/run-vecs-nested-interp.test

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,3 +410,48 @@ def test_iterator_length_hint() -> None:
410410
next(it)
411411
next(it)
412412
assert length_hint(it) == 0
413+
414+
def test_contains() -> None:
415+
v1 = vec[str](['a'])
416+
v2 = vec[str](['b'])
417+
v3 = vec[str](['c'])
418+
v = vec[vec[str]]([v1, v2, v3])
419+
assert v1 in v
420+
assert v2 in v
421+
assert v3 in v
422+
assert vec[str]([]) not in v
423+
assert vec[str](['d']) not in v
424+
425+
def test_contains_empty() -> None:
426+
v = vec[vec[str]]()
427+
assert vec[str]([]) not in v
428+
assert vec[str](['x']) not in v
429+
430+
def test_contains_equality() -> None:
431+
# Test that equality is used, not just identity
432+
v = vec[vec[str]]([vec[str](['x', 'y'])])
433+
assert vec[str](['x', 'y']) in v
434+
assert vec[str](['x']) not in v
435+
assert vec[str](['y']) not in v
436+
437+
def test_contains_nested_i64() -> None:
438+
v = vec[vec[i64]]([vec[i64]([1, 2]), vec[i64]([3])])
439+
assert vec[i64]([1, 2]) in v
440+
assert vec[i64]([3]) in v
441+
assert vec[i64]([1]) not in v
442+
assert vec[i64]([]) not in v
443+
444+
def test_contains_wrong_type() -> None:
445+
v: Any = vec[vec[str]]([vec[str](['x'])])
446+
assert 'x' not in v
447+
assert 123 not in v
448+
assert None not in v
449+
assert [] not in v
450+
451+
def test_contains_deeply_nested() -> None:
452+
inner = vec[vec[str]]([vec[str](['a'])])
453+
v = vec[vec[vec[str]]]([inner])
454+
assert inner in v
455+
assert vec[vec[str]]([vec[str](['a'])]) in v
456+
assert vec[vec[str]]([vec[str](['b'])]) not in v
457+
assert vec[vec[str]]() not in v

mypyc/test-data/run-vecs-t-interp.test

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,41 @@ def test_contains() -> None:
419419
assert None in v2
420420
assert '4' not in v2
421421

422+
def test_contains_identity() -> None:
423+
# Test that identity check works (short-circuits before __eq__)
424+
o = ClassWithRaisingEq()
425+
v = vec[object]([o])
426+
assert o in v # Should use identity, not __eq__
427+
428+
def test_contains_equality() -> None:
429+
# Test that equality is used when identity doesn't match
430+
s = 'hello world!'
431+
v = vec[str]([s])
432+
# Create a new string object with same value (different identity)
433+
s2 = 'hello' + str() + ' world!'
434+
assert s2 in v
435+
436+
def test_contains_exception_from_eq() -> None:
437+
o1 = ClassWithRaisingEq()
438+
o2 = ClassWithRaisingEq()
439+
v = vec[object]([o1])
440+
with assertRaises(RuntimeError):
441+
o2 in v
442+
443+
def test_contains_wrong_type() -> None:
444+
# Wrong type should return False, not raise
445+
v: Any = vec[str](['x'])
446+
assert 123 not in v
447+
assert b'x' not in v
448+
assert None not in v
449+
assert [] not in v
450+
451+
def test_contains_subclass() -> None:
452+
v = vec[str]([StringSubclass('abc')])
453+
assert 'abc' in v
454+
assert StringSubclass('abc') in v
455+
assert 'xyz' not in v
456+
422457
def test_remove() -> None:
423458
a = ['4x', '7x', '9x']
424459
for i in a:

0 commit comments

Comments
 (0)