Skip to content

Commit beceeaf

Browse files
ryuwdStellatsuu
authored andcommitted
perf: use sargable (Search ARGument ABLE) range predicates for datetime search filters (DIRACGrid#809)
* perf: use sargable range predicates for datetime search filters Replace date_trunc() (which wraps columns in date_format() on MySQL) with range-based comparisons on the raw column. This allows the database to use indexes on datetime columns instead of performing full table scans.
1 parent 8dbf288 commit beceeaf

2 files changed

Lines changed: 125 additions & 88 deletions

File tree

diracx-db/src/diracx/db/sql/utils/base.py

Lines changed: 124 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
from abc import ABCMeta
88
from collections.abc import AsyncIterator
99
from contextvars import ContextVar
10-
from datetime import datetime, timezone
10+
from datetime import datetime, timedelta, timezone
11+
from enum import StrEnum
1112
from typing import Any, Self, cast
1213
from uuid import UUID as StdUUID # noqa: N811
1314

1415
from pydantic import TypeAdapter
15-
from sqlalchemy import DateTime, MetaData, func, inspect, select
16+
from sqlalchemy import DateTime, MetaData, and_, func, inspect, or_, select
1617
from sqlalchemy.exc import OperationalError
1718
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
1819
from sqlalchemy.orm import DeclarativeBase
@@ -21,16 +22,16 @@
2122
from diracx.core.exceptions import InvalidQueryError
2223
from diracx.core.extensions import DiracEntryPoint, select_from_extension
2324
from diracx.core.models.search import (
25+
ScalarSearchOperator,
2426
SearchSpec,
2527
SortDirection,
2628
SortSpec,
29+
VectorSearchOperator,
2730
)
2831
from diracx.core.settings import SqlalchemyDsn
2932
from diracx.db.exceptions import DBUnavailableError
3033
from diracx.db.sql.utils.types import SmarterDateTime
3134

32-
from .functions import date_trunc
33-
3435
logger = logging.getLogger(__name__)
3536

3637

@@ -318,24 +319,33 @@ async def _summary(
318319
]
319320

320321

322+
class TimeResolution(StrEnum):
323+
YEAR = "YEAR"
324+
MONTH = "MONTH"
325+
DAY = "DAY"
326+
HOUR = "HOUR"
327+
MINUTE = "MINUTE"
328+
SECOND = "SECOND"
329+
330+
321331
def find_time_resolution(value):
322332
if isinstance(value, datetime):
323333
return None, value
324334
if match := re.fullmatch(
325335
r"\d{4}(-\d{2}(-\d{2}(([ T])\d{2}(:\d{2}(:\d{2}(\.\d{1,6}Z?)?)?)?)?)?)?", value
326336
):
327337
if match.group(6):
328-
precision, pattern = "SECOND", r"\1-\2-\3 \4:\5:\6"
338+
precision, pattern = TimeResolution.SECOND, r"\1-\2-\3 \4:\5:\6"
329339
elif match.group(5):
330-
precision, pattern = "MINUTE", r"\1-\2-\3 \4:\5"
340+
precision, pattern = TimeResolution.MINUTE, r"\1-\2-\3 \4:\5"
331341
elif match.group(3):
332-
precision, pattern = "HOUR", r"\1-\2-\3 \4"
342+
precision, pattern = TimeResolution.HOUR, r"\1-\2-\3 \4"
333343
elif match.group(2):
334-
precision, pattern = "DAY", r"\1-\2-\3"
344+
precision, pattern = TimeResolution.DAY, r"\1-\2-\3"
335345
elif match.group(1):
336-
precision, pattern = "MONTH", r"\1-\2"
346+
precision, pattern = TimeResolution.MONTH, r"\1-\2"
337347
else:
338-
precision, pattern = "YEAR", r"\1"
348+
precision, pattern = TimeResolution.YEAR, r"\1"
339349
return (
340350
precision,
341351
re.sub(
@@ -359,6 +369,84 @@ def _get_columns(table, parameters):
359369
return columns
360370

361371

372+
def _datetime_period_bounds(
373+
value_str: str, precision: TimeResolution
374+
) -> tuple[datetime, datetime]:
375+
"""Compute the inclusive start and exclusive end of a datetime period.
376+
377+
For example, precision=TimeResolution.DAY and value_str="2025-08-25" returns:
378+
(datetime(2025, 8, 25, 0, 0), datetime(2025, 8, 26, 0, 0))
379+
"""
380+
parse_formats = {
381+
TimeResolution.YEAR: "%Y",
382+
TimeResolution.MONTH: "%Y-%m",
383+
TimeResolution.DAY: "%Y-%m-%d",
384+
TimeResolution.HOUR: "%Y-%m-%d %H",
385+
TimeResolution.MINUTE: "%Y-%m-%d %H:%M",
386+
TimeResolution.SECOND: "%Y-%m-%d %H:%M:%S",
387+
}
388+
start = datetime.strptime(value_str, parse_formats[precision]).replace(
389+
tzinfo=timezone.utc
390+
)
391+
392+
if precision == TimeResolution.YEAR:
393+
end = start.replace(year=start.year + 1)
394+
elif precision == TimeResolution.MONTH:
395+
end = (
396+
start.replace(year=start.year + 1, month=1)
397+
if start.month == 12
398+
else start.replace(month=start.month + 1)
399+
)
400+
elif precision == TimeResolution.DAY:
401+
end = start + timedelta(days=1)
402+
elif precision == TimeResolution.HOUR:
403+
end = start + timedelta(hours=1)
404+
elif precision == TimeResolution.MINUTE:
405+
end = start + timedelta(minutes=1)
406+
else: # SECOND
407+
end = start + timedelta(seconds=1)
408+
409+
return start, end
410+
411+
412+
def _build_datetime_range_expr(
413+
column,
414+
operator: ScalarSearchOperator,
415+
start: datetime,
416+
end: datetime,
417+
):
418+
"""Build a sargable range expression for a single datetime period.
419+
420+
Uses range predicates on the raw column so database indexes can be used.
421+
"""
422+
if operator == ScalarSearchOperator.EQUAL:
423+
return and_(column >= start, column < end)
424+
if operator == ScalarSearchOperator.NOT_EQUAL:
425+
return or_(column < start, column >= end)
426+
if operator == ScalarSearchOperator.GREATER_THAN:
427+
return column >= end
428+
if operator == ScalarSearchOperator.LESS_THAN:
429+
return column < start
430+
raise InvalidQueryError(
431+
f"Operator '{operator}' is not supported for partial datetime values"
432+
)
433+
434+
435+
def _build_datetime_range_multi_expr(
436+
column,
437+
operator: VectorSearchOperator,
438+
bounds: list[tuple[datetime, datetime]],
439+
):
440+
"""Build a sargable range expression for multiple datetime periods (IN/NOT IN)."""
441+
if operator == VectorSearchOperator.IN:
442+
return or_(*[and_(column >= s, column < e) for s, e in bounds])
443+
if operator == VectorSearchOperator.NOT_IN:
444+
return and_(*[or_(column < s, column >= e) for s, e in bounds])
445+
raise InvalidQueryError(
446+
f"Operator '{operator}' is not supported for partial datetime values"
447+
)
448+
449+
362450
def apply_search_filters(column_mapping, stmt, search):
363451
for query in search:
364452
try:
@@ -370,7 +458,13 @@ def apply_search_filters(column_mapping, stmt, search):
370458
if "value" in query and isinstance(query["value"], str):
371459
resolution, value = find_time_resolution(query["value"])
372460
if resolution:
373-
column = date_trunc(column, time_resolution=resolution)
461+
start, end = _datetime_period_bounds(value, resolution)
462+
stmt = stmt.where(
463+
_build_datetime_range_expr(
464+
column, query["operator"], start, end
465+
)
466+
)
467+
continue
374468
query["value"] = value
375469

376470
if query.get("values"):
@@ -382,28 +476,37 @@ def apply_search_filters(column_mapping, stmt, search):
382476
f"Cannot mix different time resolutions in {query=}"
383477
)
384478
if resolution := resolutions[0]:
385-
column = date_trunc(column, time_resolution=resolution)
479+
bounds = [
480+
_datetime_period_bounds(cast(str, v), resolution)
481+
for v in values
482+
]
483+
stmt = stmt.where(
484+
_build_datetime_range_multi_expr(
485+
column, query["operator"], bounds
486+
)
487+
)
488+
continue
386489
query["values"] = values
387490

388-
if query["operator"] == "eq":
491+
if query["operator"] == ScalarSearchOperator.EQUAL:
389492
expr = column == query["value"]
390-
elif query["operator"] == "neq":
493+
elif query["operator"] == ScalarSearchOperator.NOT_EQUAL:
391494
expr = column != query["value"]
392-
elif query["operator"] == "gt":
495+
elif query["operator"] == ScalarSearchOperator.GREATER_THAN:
393496
expr = column > query["value"]
394-
elif query["operator"] == "lt":
497+
elif query["operator"] == ScalarSearchOperator.LESS_THAN:
395498
expr = column < query["value"]
396-
elif query["operator"] == "in":
499+
elif query["operator"] == VectorSearchOperator.IN:
397500
expr = column.in_(query["values"])
398-
elif query["operator"] == "not in":
501+
elif query["operator"] == VectorSearchOperator.NOT_IN:
399502
expr = column.notin_(query["values"])
400-
elif query["operator"] in "like":
503+
elif query["operator"] == ScalarSearchOperator.LIKE:
401504
expr = column.like(query["value"])
402505
elif query["operator"] in "ilike":
403506
expr = column.ilike(query["value"])
404-
elif query["operator"] == "not like":
507+
elif query["operator"] == ScalarSearchOperator.NOT_LIKE:
405508
expr = column.not_like(query["value"])
406-
elif query["operator"] == "regex":
509+
elif query["operator"] == ScalarSearchOperator.REGEX:
407510
# We check the regex validity here
408511
try:
409512
re.compile(query["value"])

diracx-db/src/diracx/db/sql/utils/functions.py

Lines changed: 1 addition & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from datetime import datetime, timedelta, timezone
55
from typing import TYPE_CHECKING
66

7-
from sqlalchemy import DateTime, func
7+
from sqlalchemy import DateTime
88
from sqlalchemy.ext.compiler import compiles
99
from sqlalchemy.sql import expression
1010

@@ -37,72 +37,6 @@ def sqlite_utcnow(element, compiler, **kw) -> str:
3737
return "DATETIME('now')"
3838

3939

40-
class date_trunc(expression.FunctionElement): # noqa: N801
41-
"""Sqlalchemy function to truncate a date to a given resolution.
42-
43-
Primarily used to be able to query for a specific resolution of a date e.g.
44-
45-
select * from table where date_trunc('day', date_column) = '2021-01-01'
46-
select * from table where date_trunc('year', date_column) = '2021'
47-
select * from table where date_trunc('minute', date_column) = '2021-01-01 12:00'
48-
"""
49-
50-
type = DateTime()
51-
# Cache does not work as intended with time resolution values, so we disable it
52-
inherit_cache = False
53-
54-
def __init__(self, *args, time_resolution, **kwargs) -> None:
55-
super().__init__(*args, **kwargs)
56-
self._time_resolution = time_resolution
57-
58-
59-
@compiles(date_trunc, "postgresql")
60-
def pg_date_trunc(element, compiler, **kw):
61-
res = {
62-
"SECOND": "second",
63-
"MINUTE": "minute",
64-
"HOUR": "hour",
65-
"DAY": "day",
66-
"MONTH": "month",
67-
"YEAR": "year",
68-
}[element._time_resolution]
69-
return f"date_trunc('{res}', {compiler.process(element.clauses)})"
70-
71-
72-
@compiles(date_trunc, "mysql")
73-
def mysql_date_trunc(element, compiler, **kw):
74-
pattern = {
75-
"SECOND": "%Y-%m-%d %H:%i:%S",
76-
"MINUTE": "%Y-%m-%d %H:%i",
77-
"HOUR": "%Y-%m-%d %H",
78-
"DAY": "%Y-%m-%d",
79-
"MONTH": "%Y-%m",
80-
"YEAR": "%Y",
81-
}[element._time_resolution]
82-
83-
(dt_col,) = list(element.clauses)
84-
return compiler.process(func.date_format(dt_col, pattern))
85-
86-
87-
@compiles(date_trunc, "sqlite")
88-
def sqlite_date_trunc(element, compiler, **kw):
89-
pattern = {
90-
"SECOND": "%Y-%m-%d %H:%M:%S",
91-
"MINUTE": "%Y-%m-%d %H:%M",
92-
"HOUR": "%Y-%m-%d %H",
93-
"DAY": "%Y-%m-%d",
94-
"MONTH": "%Y-%m",
95-
"YEAR": "%Y",
96-
}[element._time_resolution]
97-
(dt_col,) = list(element.clauses)
98-
return compiler.process(
99-
func.strftime(
100-
pattern,
101-
dt_col,
102-
)
103-
)
104-
105-
10640
class days_since(expression.FunctionElement): # noqa: N801
10741
"""Sqlalchemy function to get the number of days since a given date.
10842

0 commit comments

Comments
 (0)