77from abc import ABCMeta
88from collections .abc import AsyncIterator
99from contextvars import ContextVar
10- from datetime import datetime , timezone
10+ from datetime import datetime , timedelta , timezone
11+ from enum import StrEnum
1112from typing import Any , Self , cast
1213from uuid import UUID as StdUUID # noqa: N811
1314
1415from pydantic import TypeAdapter
15- from sqlalchemy import DateTime , MetaData , func , inspect , select
16+ from sqlalchemy import DateTime , MetaData , and_ , func , inspect , or_ , select
1617from sqlalchemy .exc import OperationalError
1718from sqlalchemy .ext .asyncio import AsyncConnection , AsyncEngine , create_async_engine
1819from sqlalchemy .orm import DeclarativeBase
2122from diracx .core .exceptions import InvalidQueryError
2223from diracx .core .extensions import DiracEntryPoint , select_from_extension
2324from diracx .core .models .search import (
25+ ScalarSearchOperator ,
2426 SearchSpec ,
2527 SortDirection ,
2628 SortSpec ,
29+ VectorSearchOperator ,
2730)
2831from diracx .core .settings import SqlalchemyDsn
2932from diracx .db .exceptions import DBUnavailableError
3033from diracx .db .sql .utils .types import SmarterDateTime
3134
32- from .functions import date_trunc
33-
3435logger = logging .getLogger (__name__ )
3536
3637
@@ -318,24 +319,33 @@ async def _summary(
318319 ]
319320
320321
322+ class TimeResolution (StrEnum ):
323+ YEAR = "YEAR"
324+ MONTH = "MONTH"
325+ DAY = "DAY"
326+ HOUR = "HOUR"
327+ MINUTE = "MINUTE"
328+ SECOND = "SECOND"
329+
330+
321331def find_time_resolution (value ):
322332 if isinstance (value , datetime ):
323333 return None , value
324334 if match := re .fullmatch (
325335 r"\d{4}(-\d{2}(-\d{2}(([ T])\d{2}(:\d{2}(:\d{2}(\.\d{1,6}Z?)?)?)?)?)?)?" , value
326336 ):
327337 if match .group (6 ):
328- precision , pattern = " SECOND" , r"\1-\2-\3 \4:\5:\6"
338+ precision , pattern = TimeResolution . SECOND , r"\1-\2-\3 \4:\5:\6"
329339 elif match .group (5 ):
330- precision , pattern = " MINUTE" , r"\1-\2-\3 \4:\5"
340+ precision , pattern = TimeResolution . MINUTE , r"\1-\2-\3 \4:\5"
331341 elif match .group (3 ):
332- precision , pattern = " HOUR" , r"\1-\2-\3 \4"
342+ precision , pattern = TimeResolution . HOUR , r"\1-\2-\3 \4"
333343 elif match .group (2 ):
334- precision , pattern = " DAY" , r"\1-\2-\3"
344+ precision , pattern = TimeResolution . DAY , r"\1-\2-\3"
335345 elif match .group (1 ):
336- precision , pattern = " MONTH" , r"\1-\2"
346+ precision , pattern = TimeResolution . MONTH , r"\1-\2"
337347 else :
338- precision , pattern = " YEAR" , r"\1"
348+ precision , pattern = TimeResolution . YEAR , r"\1"
339349 return (
340350 precision ,
341351 re .sub (
@@ -359,6 +369,84 @@ def _get_columns(table, parameters):
359369 return columns
360370
361371
372+ def _datetime_period_bounds (
373+ value_str : str , precision : TimeResolution
374+ ) -> tuple [datetime , datetime ]:
375+ """Compute the inclusive start and exclusive end of a datetime period.
376+
377+ For example, precision=TimeResolution.DAY and value_str="2025-08-25" returns:
378+ (datetime(2025, 8, 25, 0, 0), datetime(2025, 8, 26, 0, 0))
379+ """
380+ parse_formats = {
381+ TimeResolution .YEAR : "%Y" ,
382+ TimeResolution .MONTH : "%Y-%m" ,
383+ TimeResolution .DAY : "%Y-%m-%d" ,
384+ TimeResolution .HOUR : "%Y-%m-%d %H" ,
385+ TimeResolution .MINUTE : "%Y-%m-%d %H:%M" ,
386+ TimeResolution .SECOND : "%Y-%m-%d %H:%M:%S" ,
387+ }
388+ start = datetime .strptime (value_str , parse_formats [precision ]).replace (
389+ tzinfo = timezone .utc
390+ )
391+
392+ if precision == TimeResolution .YEAR :
393+ end = start .replace (year = start .year + 1 )
394+ elif precision == TimeResolution .MONTH :
395+ end = (
396+ start .replace (year = start .year + 1 , month = 1 )
397+ if start .month == 12
398+ else start .replace (month = start .month + 1 )
399+ )
400+ elif precision == TimeResolution .DAY :
401+ end = start + timedelta (days = 1 )
402+ elif precision == TimeResolution .HOUR :
403+ end = start + timedelta (hours = 1 )
404+ elif precision == TimeResolution .MINUTE :
405+ end = start + timedelta (minutes = 1 )
406+ else : # SECOND
407+ end = start + timedelta (seconds = 1 )
408+
409+ return start , end
410+
411+
412+ def _build_datetime_range_expr (
413+ column ,
414+ operator : ScalarSearchOperator ,
415+ start : datetime ,
416+ end : datetime ,
417+ ):
418+ """Build a sargable range expression for a single datetime period.
419+
420+ Uses range predicates on the raw column so database indexes can be used.
421+ """
422+ if operator == ScalarSearchOperator .EQUAL :
423+ return and_ (column >= start , column < end )
424+ if operator == ScalarSearchOperator .NOT_EQUAL :
425+ return or_ (column < start , column >= end )
426+ if operator == ScalarSearchOperator .GREATER_THAN :
427+ return column >= end
428+ if operator == ScalarSearchOperator .LESS_THAN :
429+ return column < start
430+ raise InvalidQueryError (
431+ f"Operator '{ operator } ' is not supported for partial datetime values"
432+ )
433+
434+
435+ def _build_datetime_range_multi_expr (
436+ column ,
437+ operator : VectorSearchOperator ,
438+ bounds : list [tuple [datetime , datetime ]],
439+ ):
440+ """Build a sargable range expression for multiple datetime periods (IN/NOT IN)."""
441+ if operator == VectorSearchOperator .IN :
442+ return or_ (* [and_ (column >= s , column < e ) for s , e in bounds ])
443+ if operator == VectorSearchOperator .NOT_IN :
444+ return and_ (* [or_ (column < s , column >= e ) for s , e in bounds ])
445+ raise InvalidQueryError (
446+ f"Operator '{ operator } ' is not supported for partial datetime values"
447+ )
448+
449+
362450def apply_search_filters (column_mapping , stmt , search ):
363451 for query in search :
364452 try :
@@ -370,7 +458,13 @@ def apply_search_filters(column_mapping, stmt, search):
370458 if "value" in query and isinstance (query ["value" ], str ):
371459 resolution , value = find_time_resolution (query ["value" ])
372460 if resolution :
373- column = date_trunc (column , time_resolution = resolution )
461+ start , end = _datetime_period_bounds (value , resolution )
462+ stmt = stmt .where (
463+ _build_datetime_range_expr (
464+ column , query ["operator" ], start , end
465+ )
466+ )
467+ continue
374468 query ["value" ] = value
375469
376470 if query .get ("values" ):
@@ -382,28 +476,37 @@ def apply_search_filters(column_mapping, stmt, search):
382476 f"Cannot mix different time resolutions in { query = } "
383477 )
384478 if resolution := resolutions [0 ]:
385- column = date_trunc (column , time_resolution = resolution )
479+ bounds = [
480+ _datetime_period_bounds (cast (str , v ), resolution )
481+ for v in values
482+ ]
483+ stmt = stmt .where (
484+ _build_datetime_range_multi_expr (
485+ column , query ["operator" ], bounds
486+ )
487+ )
488+ continue
386489 query ["values" ] = values
387490
388- if query ["operator" ] == "eq" :
491+ if query ["operator" ] == ScalarSearchOperator . EQUAL :
389492 expr = column == query ["value" ]
390- elif query ["operator" ] == "neq" :
493+ elif query ["operator" ] == ScalarSearchOperator . NOT_EQUAL :
391494 expr = column != query ["value" ]
392- elif query ["operator" ] == "gt" :
495+ elif query ["operator" ] == ScalarSearchOperator . GREATER_THAN :
393496 expr = column > query ["value" ]
394- elif query ["operator" ] == "lt" :
497+ elif query ["operator" ] == ScalarSearchOperator . LESS_THAN :
395498 expr = column < query ["value" ]
396- elif query ["operator" ] == "in" :
499+ elif query ["operator" ] == VectorSearchOperator . IN :
397500 expr = column .in_ (query ["values" ])
398- elif query ["operator" ] == "not in" :
501+ elif query ["operator" ] == VectorSearchOperator . NOT_IN :
399502 expr = column .notin_ (query ["values" ])
400- elif query ["operator" ] in "like" :
503+ elif query ["operator" ] == ScalarSearchOperator . LIKE :
401504 expr = column .like (query ["value" ])
402505 elif query ["operator" ] in "ilike" :
403506 expr = column .ilike (query ["value" ])
404- elif query ["operator" ] == "not like" :
507+ elif query ["operator" ] == ScalarSearchOperator . NOT_LIKE :
405508 expr = column .not_like (query ["value" ])
406- elif query ["operator" ] == "regex" :
509+ elif query ["operator" ] == ScalarSearchOperator . REGEX :
407510 # We check the regex validity here
408511 try :
409512 re .compile (query ["value" ])
0 commit comments