Skip to content

Commit 6dabe86

Browse files
authored
feat: add contains() operation to QuerySet (#2163)
* Add `contains()` operation to `QuerySet` * Update the documentation and the changelog * Trigger CI * Add docstrings * Trigger CI * Update changelog * Update changelog * Address PR comments * Verify the object's model is similar to the queryset's model * Use `to_db_value()` to get the value
1 parent bfb9e81 commit 6dabe86

4 files changed

Lines changed: 79 additions & 0 deletions

File tree

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Changelog
1414
Added
1515
^^^^^
1616
- ``QuerySet.union()`` — SQL UNION query support for combining results from multiple QuerySets, including support for union across different models, ``union(all=True)`` for duplicates, ``order_by()``, ``limit()``, and ``count()``.
17+
- ``QuerySet.contains()`` method to check if an object exists in a queryset.
1718
- Added comprehensive EXPLAIN support for MySQL and PostgreSQL.
1819

1920
Fixed

docs/query.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,19 @@ It's also possible to filter your queries with ``.exclude()``:
9595
9696
await Team.exclude(name__icontains='junior')
9797
98+
You can also check if a specific object exists in a queryset:
99+
100+
.. code-block:: python3
101+
102+
obj = await Team.filter(name='My Team').first()
103+
exists = await Team.all().contains(obj)
104+
105+
Or simply check if any record matches a filter:
106+
107+
.. code-block:: python3
108+
109+
exists = await Team.filter(name='My Team').exists()
110+
98111
As more interesting case, when you are working with related data, you could also build your
99112
query around related entities:
100113

tests/test_queryset.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,30 @@ async def test_exists(db, intfields_data):
5959
assert not ret
6060

6161

62+
@pytest.mark.asyncio
63+
async def test_contains(db, intfields_data):
64+
obj = await IntFields.filter(intnum=10).first()
65+
assert await IntFields.all().contains(obj)
66+
67+
assert await IntFields.filter(intnum__lt=50).contains(obj)
68+
69+
assert not await IntFields.filter(intnum__gt=50).contains(obj)
70+
71+
72+
@pytest.mark.asyncio
73+
async def test_contains_when_no_pk(db, intfields_data):
74+
with pytest.raises(ParamsError, match="The given object does not have a primary key."):
75+
await IntFields.all().contains(IntFields(intnum=99))
76+
77+
78+
@pytest.mark.asyncio
79+
async def test_contains_with_wrong_model(db, intfields_data):
80+
with pytest.raises(
81+
ParamsError, match="The given object is not an instance of the queryset's model."
82+
):
83+
await IntFields.all().contains(Tournament(name="test"))
84+
85+
6286
@pytest.mark.asyncio
6387
async def test_limit_count(db, intfields_data):
6488
assert await IntFields.all().limit(10).count() == 10

tortoise/queryset.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,31 @@ def exists(self) -> ExistsQuery:
822822
use_indexes=self._use_indexes,
823823
)
824824

825+
def contains(self, obj: MODEL) -> ContainsQuery:
826+
"""
827+
Check if the QuerySet contains the given instance.
828+
829+
:param obj: The model instance to check for.
830+
:return: True if the QuerySet contains the instance, False otherwise.
831+
"""
832+
833+
if not isinstance(obj, self.model):
834+
raise ParamsError("The given object is not an instance of the queryset's model.")
835+
836+
if not obj.pk:
837+
raise ParamsError("The given object does not have a primary key.")
838+
839+
return ContainsQuery(
840+
db=self._db,
841+
model=self.model,
842+
q_objects=self._q_objects,
843+
annotations=self._annotations,
844+
custom_filters=self._custom_filters,
845+
force_indexes=self._force_indexes,
846+
use_indexes=self._use_indexes,
847+
obj=obj,
848+
)
849+
825850
def all(self) -> QuerySet[MODEL]:
826851
"""
827852
Return the whole QuerySet.
@@ -1473,6 +1498,22 @@ async def _execute(
14731498
return bool(result)
14741499

14751500

1501+
class ContainsQuery(ExistsQuery):
1502+
def __init__(
1503+
self,
1504+
obj: MODEL,
1505+
**kwargs,
1506+
) -> None:
1507+
super().__init__(**kwargs)
1508+
self._obj = obj
1509+
1510+
def _make_query(self) -> None:
1511+
super()._make_query()
1512+
pk_field = Field(self.model._meta.db_pk_column)
1513+
pk_value = self.model._meta.pk.to_db_value(self._obj.pk, self._obj)
1514+
self.query = self.query.where(pk_field.eq(pk_value))
1515+
1516+
14761517
class CountQuery(AwaitableQuery):
14771518
__slots__ = (
14781519
"_limit",

0 commit comments

Comments
 (0)