|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | from datetime import datetime, timezone |
| 4 | +from typing import Any |
4 | 5 |
|
5 | | -from sqlalchemy import insert |
| 6 | +from sqlalchemy import func, insert, select |
6 | 7 |
|
7 | | -from ..utils import BaseSQLDB |
| 8 | +from diracx.core.exceptions import InvalidQueryError |
| 9 | +from diracx.core.models import ( |
| 10 | + SearchSpec, |
| 11 | + SortSpec, |
| 12 | +) |
| 13 | + |
| 14 | +from ..utils import BaseSQLDB, apply_search_filters, apply_sort_constraints, get_columns |
8 | 15 | from .schema import PilotAgents, PilotAgentsDBBase |
9 | 16 |
|
10 | 17 |
|
@@ -44,3 +51,46 @@ async def add_pilot_references( |
44 | 51 | stmt = insert(PilotAgents).values(values) |
45 | 52 | await self.conn.execute(stmt) |
46 | 53 | return |
| 54 | + |
| 55 | + async def search( |
| 56 | + self, |
| 57 | + parameters: list[str] | None, |
| 58 | + search: list[SearchSpec], |
| 59 | + sorts: list[SortSpec], |
| 60 | + *, |
| 61 | + distinct: bool = False, |
| 62 | + per_page: int = 100, |
| 63 | + page: int | None = None, |
| 64 | + ) -> tuple[int, list[dict[Any, Any]]]: |
| 65 | + # Find which columns to select |
| 66 | + columns = get_columns(PilotAgents.__table__, parameters) |
| 67 | + |
| 68 | + stmt = select(*columns) |
| 69 | + |
| 70 | + stmt = apply_search_filters( |
| 71 | + PilotAgents.__table__.columns.__getitem__, stmt, search |
| 72 | + ) |
| 73 | + stmt = apply_sort_constraints( |
| 74 | + PilotAgents.__table__.columns.__getitem__, stmt, sorts |
| 75 | + ) |
| 76 | + |
| 77 | + if distinct: |
| 78 | + stmt = stmt.distinct() |
| 79 | + |
| 80 | + # Calculate total count before applying pagination |
| 81 | + total_count_subquery = stmt.alias() |
| 82 | + total_count_stmt = select(func.count()).select_from(total_count_subquery) |
| 83 | + total = (await self.conn.execute(total_count_stmt)).scalar_one() |
| 84 | + |
| 85 | + # Apply pagination |
| 86 | + if page is not None: |
| 87 | + if page < 1: |
| 88 | + raise InvalidQueryError("Page must be a positive integer") |
| 89 | + if per_page < 1: |
| 90 | + raise InvalidQueryError("Per page must be a positive integer") |
| 91 | + stmt = stmt.offset((page - 1) * per_page).limit(per_page) |
| 92 | + |
| 93 | + # Execute the query |
| 94 | + return total, [ |
| 95 | + dict(row._mapping) async for row in (await self.conn.stream(stmt)) |
| 96 | + ] |
0 commit comments