Skip to content

Commit e44ad39

Browse files
committed
Enhance the relationship code
1 parent 3013563 commit e44ad39

5 files changed

Lines changed: 178 additions & 49 deletions

File tree

sqlalchemy_crud_plus/crud.py

Lines changed: 77 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,15 @@ async def count(
131131
self,
132132
session: AsyncSession,
133133
*whereclause: ColumnExpressionArgument[bool],
134+
join_conditions: JoinConditionsConfig | None = None,
134135
**kwargs,
135136
) -> int:
136137
"""
137138
Count records that match specified filters.
138139
139-
:param session: The SQLAlchemy async session
140-
:param whereclause: Additional WHERE clauses to apply to the query
140+
:param session: SQLAlchemy async session
141+
:param whereclause: Additional WHERE clauses
142+
:param join_conditions: JOIN conditions for relationships
141143
:param kwargs: Filter expressions using field__operator=value syntax
142144
:return:
143145
"""
@@ -150,6 +152,9 @@ async def count(
150152
if filters:
151153
stmt = stmt.where(*filters)
152154

155+
if join_conditions:
156+
stmt = apply_join_conditions(self.model, stmt, join_conditions)
157+
153158
query = await session.execute(stmt)
154159
total_count = query.scalar()
155160
return total_count if total_count is not None else 0
@@ -158,13 +163,15 @@ async def exists(
158163
self,
159164
session: AsyncSession,
160165
*whereclause: ColumnExpressionArgument[bool],
166+
join_conditions: JoinConditionsConfig | None = None,
161167
**kwargs,
162168
) -> bool:
163169
"""
164170
Check whether records that match the specified filters exist.
165171
166-
:param session: The SQLAlchemy async session
167-
:param whereclause: Additional WHERE clauses to apply to the query
172+
:param session: SQLAlchemy async session
173+
:param whereclause: Additional WHERE clauses
174+
:param join_conditions: JOIN conditions for relationships
168175
:param kwargs: Filter expressions using field__operator=value syntax
169176
:return:
170177
"""
@@ -174,6 +181,10 @@ async def exists(
174181
filters.extend(parse_filters(self.model, **kwargs))
175182

176183
stmt = select(self.model).where(*filters).limit(1)
184+
185+
if join_conditions:
186+
stmt = apply_join_conditions(self.model, stmt, join_conditions)
187+
177188
query = await session.execute(stmt)
178189
return query.scalars().first() is not None
179190

@@ -241,44 +252,60 @@ async def select_model_by_column(
241252
:param kwargs: Filter expressions using field__operator=value syntax
242253
:return:
243254
"""
244-
filters = list(whereclause)
245-
246-
if kwargs:
247-
filters.extend(parse_filters(self.model, **kwargs))
248-
249-
stmt = select(self.model).where(*filters)
250-
251-
if options:
252-
stmt = stmt.options(*options)
253-
254-
if join_conditions:
255-
stmt = apply_join_conditions(self.model, stmt, join_conditions)
256-
257-
if load_strategies:
258-
rel_options = build_load_strategies(self.model, load_strategies)
259-
if rel_options:
260-
stmt = stmt.options(*rel_options)
255+
stmt = await self.select(
256+
*whereclause,
257+
options=options,
258+
load_strategies=load_strategies,
259+
join_conditions=join_conditions,
260+
**kwargs
261+
)
261262

262263
query = await session.execute(stmt)
263264
return query.scalars().first()
264265

265-
async def select(self, *whereclause: ColumnExpressionArgument[bool], **kwargs) -> Select:
266+
async def select(
267+
self,
268+
*whereclause: ColumnExpressionArgument[bool],
269+
options: QueryOptions | None = None,
270+
load_strategies: LoadStrategiesConfig | None = None,
271+
join_conditions: JoinConditionsConfig | None = None,
272+
**kwargs
273+
) -> Select:
266274
"""
267275
Construct the SQLAlchemy selection.
268276
269277
:param whereclause: WHERE clauses to apply to the query
278+
:param options: SQLAlchemy loading options
279+
:param load_strategies: Relationship loading strategies
280+
:param join_conditions: JOIN conditions for relationships
270281
:param kwargs: Query expressions
271282
:return:
272283
"""
273-
filters = parse_filters(self.model, **kwargs) + list(whereclause)
284+
filters = list(whereclause)
285+
filters.extend(parse_filters(self.model, **kwargs))
274286
stmt = select(self.model).where(*filters)
287+
288+
if join_conditions:
289+
stmt = apply_join_conditions(self.model, stmt, join_conditions)
290+
291+
if options:
292+
stmt = stmt.options(*options)
293+
294+
if load_strategies:
295+
rel_options = build_load_strategies(self.model, load_strategies)
296+
if rel_options:
297+
stmt = stmt.options(*rel_options)
298+
275299
return stmt
276300

277301
async def select_order(
278302
self,
279303
sort_columns: SortColumns,
280304
sort_orders: SortOrders = None,
281305
*whereclause: ColumnExpressionArgument[bool],
306+
options: QueryOptions | None = None,
307+
load_strategies: LoadStrategiesConfig | None = None,
308+
join_conditions: JoinConditionsConfig | None = None,
282309
**kwargs: Any,
283310
) -> Select:
284311
"""
@@ -287,10 +314,19 @@ async def select_order(
287314
:param sort_columns: Column names to sort by
288315
:param sort_orders: Sort orders ('asc' or 'desc')
289316
:param whereclause: WHERE clauses to apply to the query
317+
:param options: SQLAlchemy loading options
318+
:param load_strategies: Relationship loading strategies
319+
:param join_conditions: JOIN conditions for relationships
290320
:param kwargs: Query expressions
291321
:return:
292322
"""
293-
stmt = await self.select(*whereclause, **kwargs)
323+
stmt = await self.select(
324+
*whereclause,
325+
options=options,
326+
load_strategies=load_strategies,
327+
join_conditions=join_conditions,
328+
**kwargs
329+
)
294330
sorted_stmt = apply_sorting(self.model, stmt, sort_columns, sort_orders)
295331
return sorted_stmt
296332

@@ -306,7 +342,7 @@ async def select_models(
306342
**kwargs: Any,
307343
) -> Sequence[Model]:
308344
"""
309-
Query all rows that match the specified filters with optional relationship loading and joins.
345+
Query all rows that match the specified filters with optional relationship loading and joins.
310346
311347
:param session: SQLAlchemy async session
312348
:param whereclause: Additional WHERE clauses
@@ -318,18 +354,13 @@ async def select_models(
318354
:param kwargs: Filter expressions using field__operator=value syntax
319355
:return:
320356
"""
321-
stmt = await self.select(*whereclause, **kwargs)
322-
323-
if options:
324-
stmt = stmt.options(*options)
325-
326-
if join_conditions:
327-
stmt = apply_join_conditions(self.model, stmt, join_conditions)
328-
329-
if load_strategies:
330-
rel_options = build_load_strategies(self.model, load_strategies)
331-
if rel_options:
332-
stmt = stmt.options(*rel_options)
357+
stmt = await self.select(
358+
*whereclause,
359+
options=options,
360+
load_strategies=load_strategies,
361+
join_conditions=join_conditions,
362+
**kwargs
363+
)
333364

334365
if limit is not None:
335366
stmt = stmt.limit(limit)
@@ -368,18 +399,15 @@ async def select_models_order(
368399
:param kwargs: Filter expressions using field__operator=value syntax
369400
:return:
370401
"""
371-
stmt = await self.select_order(sort_columns, sort_orders, *whereclause, **kwargs)
372-
373-
if options:
374-
stmt = stmt.options(*options)
375-
376-
if join_conditions:
377-
stmt = apply_join_conditions(self.model, stmt, join_conditions)
378-
379-
if load_strategies:
380-
rel_options = build_load_strategies(self.model, load_strategies)
381-
if rel_options:
382-
stmt = stmt.options(*rel_options)
402+
stmt = await self.select_order(
403+
sort_columns,
404+
sort_orders,
405+
*whereclause,
406+
options=options,
407+
load_strategies=load_strategies,
408+
join_conditions=join_conditions,
409+
**kwargs
410+
)
383411

384412
if limit is not None:
385413
stmt = stmt.limit(limit)

sqlalchemy_crud_plus/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,8 @@ def apply_join_conditions(model: type[Model], stmt: Select, join_conditions: Joi
374374
elif join_type == 'inner':
375375
stmt = stmt.join(rel_attr)
376376
elif join_type == 'right':
377+
# RIGHT OUTER JOIN: SQLAlchemy doesn't directly support RIGHT JOIN
378+
# The standard approach is to reverse the table order and use LEFT JOIN For ORM relationships
377379
stmt = stmt.join(rel_attr, isouter=True)
378380
elif join_type == 'full':
379381
stmt = stmt.join(rel_attr, full=True)

tests/test_options.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
13
import pytest
24

35
from sqlalchemy.ext.asyncio import AsyncSession
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,3 +199,90 @@ async def test_many_to_many_role_users(
199199

200200
assert role is not None
201201
assert len(role.users) > 0
202+
203+
204+
class TestCountWithRelationships:
205+
"""Test count method with relationship queries."""
206+
207+
@pytest.mark.asyncio
208+
async def test_count_with_inner_join(
209+
self, db_session: AsyncSession, rel_sample_data: dict, rel_user_crud: CRUDPlus[RelUser]
210+
):
211+
"""Test count with INNER JOIN."""
212+
count = await rel_user_crud.count(
213+
db_session,
214+
join_conditions=['posts']
215+
)
216+
217+
assert count > 0
218+
total_users = await rel_user_crud.count(db_session)
219+
assert count >= total_users
220+
221+
@pytest.mark.asyncio
222+
async def test_count_with_left_join(
223+
self, db_session: AsyncSession, rel_sample_data: dict, rel_user_crud: CRUDPlus[RelUser]
224+
):
225+
"""Test count with LEFT JOIN."""
226+
count = await rel_user_crud.count(
227+
db_session,
228+
join_conditions={'posts': 'left'}
229+
)
230+
231+
total_users = await rel_user_crud.count(db_session)
232+
assert count >= total_users
233+
234+
@pytest.mark.asyncio
235+
async def test_count_with_filters(
236+
self, db_session: AsyncSession, rel_sample_data: dict, rel_user_crud: CRUDPlus[RelUser]
237+
):
238+
"""Test count with JOIN conditions and filters."""
239+
users = rel_sample_data['users']
240+
241+
count = await rel_user_crud.count(
242+
db_session,
243+
join_conditions=['posts'],
244+
name=users[0].name
245+
)
246+
247+
assert count >= 0
248+
249+
250+
class TestExistsWithRelationships:
251+
"""Test exists method with relationship queries."""
252+
253+
@pytest.mark.asyncio
254+
async def test_exists_with_inner_join(
255+
self, db_session: AsyncSession, rel_sample_data: dict, rel_user_crud: CRUDPlus[RelUser]
256+
):
257+
"""Test exists with INNER JOIN."""
258+
exists = await rel_user_crud.exists(
259+
db_session,
260+
join_conditions=['posts']
261+
)
262+
263+
assert exists is True
264+
265+
@pytest.mark.asyncio
266+
async def test_exists_with_left_join(
267+
self, db_session: AsyncSession, rel_sample_data: dict, rel_post_crud: CRUDPlus[RelPost]
268+
):
269+
"""Test exists with LEFT JOIN."""
270+
exists = await rel_post_crud.exists(
271+
db_session,
272+
join_conditions={'category': 'left'}
273+
)
274+
275+
assert exists is True
276+
277+
@pytest.mark.asyncio
278+
async def test_exists_false_case(
279+
self, db_session: AsyncSession, rel_sample_data: dict, rel_user_crud: CRUDPlus[RelUser]
280+
):
281+
"""Test exists returns False when no records match."""
282+
exists = await rel_user_crud.exists(
283+
db_session,
284+
join_conditions=['posts'],
285+
name='nonexistent_user'
286+
)
287+
288+
assert exists is False

tests/test_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,16 @@ def test_apply_join_conditions_dict(self):
237237
sql_str = str(new_stmt)
238238
assert 'JOIN' in sql_str.upper()
239239

240+
def test_apply_join_conditions_right_join_limitation(self):
241+
"""Test that RIGHT JOIN is treated as LEFT JOIN due to SQLAlchemy limitations."""
242+
stmt = select(RelPost)
243+
# RIGHT JOIN should work but behave like LEFT JOIN
244+
new_stmt = apply_join_conditions(RelPost, stmt, {'author': 'right'})
245+
246+
sql_str = str(new_stmt)
247+
assert 'JOIN' in sql_str.upper()
248+
# Note: RIGHT JOIN is treated as LEFT JOIN in SQLAlchemy ORM
249+
240250
def test_apply_join_conditions_invalid_relation(self):
241251
"""Test applying JOIN conditions with invalid relationship."""
242252
stmt = select(RelPost)

0 commit comments

Comments
 (0)