diff --git a/mypy/typeshed/stubs/librt/librt/vecs.pyi b/mypy/typeshed/stubs/librt/librt/vecs.pyi index a4e8c88e0e596..dbf199da9465c 100644 --- a/mypy/typeshed/stubs/librt/librt/vecs.pyi +++ b/mypy/typeshed/stubs/librt/librt/vecs.pyi @@ -14,6 +14,7 @@ class vec(Generic[T]): @overload def __getitem__(self, i: slice, /) -> vec[T]: ... def __setitem__(self, i: i64, o: T, /) -> None: ... + def __contains__(self, o: object, /) -> bool: ... def __iter__(self) -> Iterator[T]: ... def append(v: vec[T], o: T, /) -> vec[T]: ... diff --git a/mypyc/lib-rt/vecs/vec_nested.c b/mypyc/lib-rt/vecs/vec_nested.c index 459a40735744a..f877566c17e24 100644 --- a/mypyc/lib-rt/vecs/vec_nested.c +++ b/mypyc/lib-rt/vecs/vec_nested.c @@ -178,6 +178,24 @@ static int vec_ass_item(PyObject *self, Py_ssize_t i, PyObject *o) { } } +static int vec_contains(PyObject *self, PyObject *value) { + VecNested v = ((VecNestedObject *)self)->vec; + for (Py_ssize_t i = 0; i < v.len; i++) { + PyObject *item = box_vec_item_by_index(v, i); + if (item == NULL) + return -1; + if (item == value) { + Py_DECREF(item); + return 1; + } + int cmp = PyObject_RichCompareBool(item, value, Py_EQ); + Py_DECREF(item); + if (cmp != 0) + return cmp; // 1 if equal, -1 on error + } + return 0; +} + static PyObject *compare_vec_eq(VecNested x, VecNested y, int op) { int cmp = 1; PyObject *res; @@ -414,6 +432,7 @@ static PyMappingMethods VecNestedMapping = { static PySequenceMethods VecNestedSequence = { .sq_item = vec_get_item, .sq_ass_item = vec_ass_item, + .sq_contains = vec_contains, }; static PyMethodDef vec_methods[] = { diff --git a/mypyc/lib-rt/vecs/vec_t.c b/mypyc/lib-rt/vecs/vec_t.c index 79938cb9aade3..56188a42410e8 100644 --- a/mypyc/lib-rt/vecs/vec_t.c +++ b/mypyc/lib-rt/vecs/vec_t.c @@ -210,6 +210,22 @@ static int vec_ass_item(PyObject *self, Py_ssize_t i, PyObject *o) { } } +static int vec_contains(PyObject *self, PyObject *value) { + VecT v = ((VecTObject *)self)->vec; + for (Py_ssize_t i = 0; i < v.len; i++) { + PyObject *item = v.buf->items[i]; + if (item == value) { + return 1; + } + Py_INCREF(item); + int cmp = PyObject_RichCompareBool(item, value, Py_EQ); + Py_DECREF(item); + if (cmp != 0) + return cmp; // 1 if equal, -1 on error + } + return 0; +} + static PyObject *vec_richcompare(PyObject *self, PyObject *other, int op) { PyObject *res; if (op == Py_EQ || op == Py_NE) { @@ -410,6 +426,7 @@ static PyMappingMethods VecTMapping = { static PySequenceMethods VecTSequence = { .sq_item = vec_get_item, .sq_ass_item = vec_ass_item, + .sq_contains = vec_contains, }; static PyMethodDef vec_methods[] = { diff --git a/mypyc/lib-rt/vecs/vec_template.c b/mypyc/lib-rt/vecs/vec_template.c index 7881bb496a88c..5978f4af5fb27 100644 --- a/mypyc/lib-rt/vecs/vec_template.c +++ b/mypyc/lib-rt/vecs/vec_template.c @@ -29,6 +29,7 @@ #include #include "librt_vecs.h" #include "vecs_internal.h" +#include "mypyc_util.h" inline static VEC vec_error() { VEC v = { .len = -1 }; @@ -235,6 +236,32 @@ static int vec_ass_item(PyObject *self, Py_ssize_t i, PyObject *o) { } } +static int vec_contains(PyObject *self, PyObject *value) { + ITEM_C_TYPE x = UNBOX_ITEM(value); + if (unlikely(IS_UNBOX_ERROR(x))) { + if (PyErr_Occurred()) + PyErr_Clear(); + // Fall back to boxed comparison (e.g. 2.0 == 2) + VEC v = ((VEC_OBJECT *)self)->vec; + for (Py_ssize_t i = 0; i < v.len; i++) { + PyObject *boxed = BOX_ITEM(v.buf->items[i]); + if (boxed == NULL) + return -1; + int cmp = PyObject_RichCompareBool(boxed, value, Py_EQ); + Py_DECREF(boxed); + if (cmp != 0) + return cmp; // 1 if equal, -1 on error + } + return 0; + } + VEC v = ((VEC_OBJECT *)self)->vec; + for (Py_ssize_t i = 0; i < v.len; i++) { + if (v.buf->items[i] == x) + return 1; + } + return 0; +} + static Py_ssize_t vec_length(PyObject *o) { return ((VEC_OBJECT *)o)->vec.len; } @@ -348,6 +375,7 @@ static PyMappingMethods vec_mapping_methods = { static PySequenceMethods vec_sequence_methods = { .sq_item = vec_get_item, .sq_ass_item = vec_ass_item, + .sq_contains = vec_contains, }; static PyMethodDef vec_methods[] = { diff --git a/mypyc/test-data/run-vecs-i64-interp.test b/mypyc/test-data/run-vecs-i64-interp.test index 166d6818a724c..5fd2c00d984fa 100644 --- a/mypyc/test-data/run-vecs-i64-interp.test +++ b/mypyc/test-data/run-vecs-i64-interp.test @@ -312,6 +312,41 @@ def test_contains() -> None: vv = vec[i64]() assert 3 not in vv +def test_contains_boundary_values() -> None: + v = vec[i64]([0, 2**63 - 1, -2**63]) + assert 0 in v + assert 2**63 - 1 in v + assert -2**63 in v + assert 1 not in v + assert -1 not in v + +def test_contains_wrong_type() -> None: + # Wrong type should return False, not raise + v: Any = vec[i64]([1, 2, 3]) + assert 'x' not in v + assert None not in v + assert [] not in v + # Overflow should also return False + assert 2**63 not in v + assert -2**63 - 1 not in v + +def test_contains_float() -> None: + # float == int equality should work via boxed fallback + v: Any = vec[i64]([1, 2, 3]) + assert 2.0 in v + assert 3.0 in v + assert 1.5 not in v + assert 0.0 not in v + +def test_contains_bool() -> None: + # bool is a subclass of int, so True/False should work + v: Any = vec[i64]([0, 1, 5]) + assert True in v + assert False in v + v2: Any = vec[i64]([2, 3]) + assert True not in v2 + assert False not in v2 + def test_remove() -> None: a = [4, 7, 9] for i in a: diff --git a/mypyc/test-data/run-vecs-nested-interp.test b/mypyc/test-data/run-vecs-nested-interp.test index bd621e99498d1..50f8bca600b95 100644 --- a/mypyc/test-data/run-vecs-nested-interp.test +++ b/mypyc/test-data/run-vecs-nested-interp.test @@ -410,3 +410,48 @@ def test_iterator_length_hint() -> None: next(it) next(it) assert length_hint(it) == 0 + +def test_contains() -> None: + v1 = vec[str](['a']) + v2 = vec[str](['b']) + v3 = vec[str](['c']) + v = vec[vec[str]]([v1, v2, v3]) + assert v1 in v + assert v2 in v + assert v3 in v + assert vec[str]([]) not in v + assert vec[str](['d']) not in v + +def test_contains_empty() -> None: + v = vec[vec[str]]() + assert vec[str]([]) not in v + assert vec[str](['x']) not in v + +def test_contains_equality() -> None: + # Test that equality is used, not just identity + v = vec[vec[str]]([vec[str](['x', 'y'])]) + assert vec[str](['x', 'y']) in v + assert vec[str](['x']) not in v + assert vec[str](['y']) not in v + +def test_contains_nested_i64() -> None: + v = vec[vec[i64]]([vec[i64]([1, 2]), vec[i64]([3])]) + assert vec[i64]([1, 2]) in v + assert vec[i64]([3]) in v + assert vec[i64]([1]) not in v + assert vec[i64]([]) not in v + +def test_contains_wrong_type() -> None: + v: Any = vec[vec[str]]([vec[str](['x'])]) + assert 'x' not in v + assert 123 not in v + assert None not in v + assert [] not in v + +def test_contains_deeply_nested() -> None: + inner = vec[vec[str]]([vec[str](['a'])]) + v = vec[vec[vec[str]]]([inner]) + assert inner in v + assert vec[vec[str]]([vec[str](['a'])]) in v + assert vec[vec[str]]([vec[str](['b'])]) not in v + assert vec[vec[str]]() not in v diff --git a/mypyc/test-data/run-vecs-t-interp.test b/mypyc/test-data/run-vecs-t-interp.test index 881530b50f7a6..76bb57bfdd571 100644 --- a/mypyc/test-data/run-vecs-t-interp.test +++ b/mypyc/test-data/run-vecs-t-interp.test @@ -419,6 +419,41 @@ def test_contains() -> None: assert None in v2 assert '4' not in v2 +def test_contains_identity() -> None: + # Test that identity check works (short-circuits before __eq__) + o = ClassWithRaisingEq() + v = vec[object]([o]) + assert o in v # Should use identity, not __eq__ + +def test_contains_equality() -> None: + # Test that equality is used when identity doesn't match + s = 'hello world!' + v = vec[str]([s]) + # Create a new string object with same value (different identity) + s2 = 'hello' + str() + ' world!' + assert s2 in v + +def test_contains_exception_from_eq() -> None: + o1 = ClassWithRaisingEq() + o2 = ClassWithRaisingEq() + v = vec[object]([o1]) + with assertRaises(RuntimeError): + o2 in v + +def test_contains_wrong_type() -> None: + # Wrong type should return False, not raise + v: Any = vec[str](['x']) + assert 123 not in v + assert b'x' not in v + assert None not in v + assert [] not in v + +def test_contains_subclass() -> None: + v = vec[str]([StringSubclass('abc')]) + assert 'abc' in v + assert StringSubclass('abc') in v + assert 'xyz' not in v + def test_remove() -> None: a = ['4x', '7x', '9x'] for i in a: